diff --git a/doc/cli_tools.md b/doc/cli_tools.md index 7c5de29fb5e553ace985c8f91a3984fa63564c91..859849ede943e6d0a68b8b8432439ba4f1785011 100644 --- a/doc/cli_tools.md +++ b/doc/cli_tools.md @@ -353,7 +353,7 @@ detailed parameters: ## `nabu-stitching`: perform stitching on a set of volume or projections -See [](stitching) +See [](stitching/index) ## `nabu-stitching-config`: create a configuration file for a stitching (to provide to `nabu-stitching`) diff --git a/doc/stitching/data_duplication.md b/doc/stitching/data_duplication.md new file mode 100644 index 0000000000000000000000000000000000000000..1d5a9ae0e5af2b4d33d694f65d3bc710e338b825 --- /dev/null +++ b/doc/stitching/data_duplication.md @@ -0,0 +1,24 @@ +# Data duplication + +The `stitching/avoid_data_duplication` parameter allows you to activate / deactivate data duplication, with some limitations / rules (see below): + +```txt +[stitching] +... +avoid_data_duplication = True +``` + +## Pre-processing stitching + +If stitching is done over projections, the output provides flat-field normalized projections, so data will always be duplicated, except when the projections of the [NXtomo](https://manual.nexusformat.org/classes/applications/NXtomo.html) are already normalized and `avoid_data_duplication` is set to True. + +## Post-processing stitching + +If stitching is done over reconstructed volumes, the stitching areas between volumes will be created in any case (as this is new data). +But the remaining areas correspond to the raw volumes, and copying this part can be avoided (for HDF5 input and output volumes only). + +If `avoid_data_duplication` is set to True, the parts corresponding to the raw volumes will simply be links to the original reconstructed volumes. + +```{warning} +If `avoid_data_duplication` is True, this also means that the stitched reconstructed volume will contain relative links to the reconstructed volumes. So if those are moved / removed, or if the stitched volume itself is moved, the links will break. +``` diff --git a/doc/stitching/index.rst b/doc/stitching/index.rst index 33027f807b49b766005ef82c62635d733a62a25a..d638f01aa22b09f66fecf164724304279b7c0270 100644 --- a/doc/stitching/index.rst +++ b/doc/stitching/index.rst @@ -20,3 +20,4 @@ This is a brief overview of stitching methods and how to benefit from it. normalization_by_sample.md distribution_on_slurm.md design.md + data_duplication.md diff --git a/nabu/stitching/config.py b/nabu/stitching/config.py index a0fd7785587d72e1d2ebff578cba2dabc84ca3ee..3f30d7958b3b0418f6e02660ffc14e811ff68ce7 100644 --- a/nabu/stitching/config.py +++ b/nabu/stitching/config.py @@ -134,6 +134,8 @@ ALIGNMENT_AXIS_1_FIELD = "alignment_axis_1" PAD_MODE_FIELD = "pad_mode" +AVOID_DATA_DUPLICATION_FIELD = "avoid_data_duplication" + # SLURM SLURM_SECTION = "slurm" @@ -539,6 +541,10 @@ class StitchingConfiguration: normalization_by_sample: NormalizationBySample = None + duplicate_data: bool = True + """if False, avoid duplicating data when possible (for HDF5). The overlapping region between two frames will still be duplicated; the remaining part will be the 'raw data' for volumes. + For projections the flat field will be applied""" + @property def stitching_type(self): raise NotImplementedError("Base class") @@ -657,6 +663,12 @@ class StitchingConfiguration: "help": f"pad mode to use for frame alignment. Valid values are 'constant', 'edge', 'linear_ramp', maximum', 'mean', 'median', 'minimum', 'reflect', 'symmetric', 'wrap', and 'empty'. 
See nupy.pad documentation for details", "type": "advanced", }, + AVOID_DATA_DUPLICATION_FIELD: { + "default": "1", + "help": "When possible (stitching on reconstructed volumes with HDF5 volumes as input and output) create links to the original data instead of duplicating all of it. Warning: this will create relative links between the stitched volume and the original reconstructed volumes.", + "validator": boolean_validator, + "type": "advanced", + }, }, OUTPUT_SECTION: { OVERWRITE_RESULTS_FIELD: { @@ -771,6 +783,7 @@ class StitchingConfiguration: RESCALE_FRAMES: self.rescale_frames, RESCALE_PARAMS: _dict_to_str(self.rescale_params or {}), STITCHING_KERNELS_EXTRA_PARAMS: _dict_to_str(self.stitching_kernels_extra_params or {}), + AVOID_DATA_DUPLICATION_FIELD: not self.duplicate_data, }, OUTPUT_SECTION: { OVERWRITE_RESULTS_FIELD: int( @@ -1028,6 +1041,7 @@ class PreProcessedZStitchingConfiguration(ZStitchingConfiguration): config[STITCHING_SECTION].get(ALIGNMENT_AXIS_2_FIELD, AlignmentAxis2.CENTER) ), pad_mode=config[STITCHING_SECTION].get(PAD_MODE_FIELD, "constant"), + duplicate_data=not config[STITCHING_SECTION].get(AVOID_DATA_DUPLICATION_FIELD, False), normalization_by_sample=NormalizationBySample.from_dict(config.get(NORMALIZATION_BY_SAMPLE_SECTION, {})), ) @@ -1182,6 +1196,7 @@ class PostProcessedZStitchingConfiguration(ZStitchingConfiguration): config[STITCHING_SECTION].get(ALIGNMENT_AXIS_2_FIELD, AlignmentAxis2.CENTER) ), pad_mode=config[STITCHING_SECTION].get(PAD_MODE_FIELD, "constant"), + duplicate_data=not config[STITCHING_SECTION].get(AVOID_DATA_DUPLICATION_FIELD, False), normalization_by_sample=NormalizationBySample.from_dict(config.get(NORMALIZATION_BY_SAMPLE_SECTION, {})), ) diff --git a/nabu/stitching/tests/test_z_stitching.py b/nabu/stitching/tests/test_z_postprocessing_stitching.py similarity index 53% rename from nabu/stitching/tests/test_z_stitching.py rename to nabu/stitching/tests/test_z_postprocessing_stitching.py index e0e48ba9d3990f07005c3bd96d05b1b2c07d9715..d17732d5376d8c5693a5a3e15976f894b3466694 100644 --- a/nabu/stitching/tests/test_z_stitching.py +++ b/nabu/stitching/tests/test_z_postprocessing_stitching.py @@ -1,64 +1,20 @@ -# coding: utf-8 -# /*########################################################################## -# -# Copyright (c) 2016-2017 European Synchrotron Radiation Facility -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. -# -# ###########################################################################*/ - -__authors__ = ["H. 
Payno"] -__license__ = "MIT" -__date__ = "13/05/2022" - - import os -from silx.image.phantomgenerator import PhantomGenerator -from scipy.ndimage import shift as scipy_shift + +import h5py import numpy import pytest -from nabu.stitching.config import ( - PreProcessedZStitchingConfiguration, - PostProcessedZStitchingConfiguration, -) -from nabu.utils import Progress -from nabu.stitching.config import KEY_IMG_REG_METHOD, NormalizationBySample -from nabu.stitching.overlap import ZStichOverlapKernel, OverlapStitchingStrategy -from nabu.stitching.z_stitching import ( - PostProcessZStitcher, - PreProcessZStitcher, - stitch_vertically_raw_frames, - ZStitcher, -) -from nxtomo.nxobject.nxdetector import ImageKey -from nxtomo.utils.transformation import UDDetTransformation, LRDetTransformation -from nxtomo.application.nxtomo import NXtomo -from nabu.stitching.alignment import AlignmentAxis1, AlignmentAxis2 +from silx.image.phantomgenerator import PhantomGenerator +from tomoscan.esrf.volume import EDFVolume, HDF5Volume +from tomoscan.esrf.volume.tiffvolume import TIFFVolume, has_tifffile from tomoscan.factory import Factory as TomoscanFactory from tomoscan.utils.volume import concatenate as concatenate_volumes -from tomoscan.esrf.scan.nxtomoscan import NXtomoScan -from tomoscan.esrf.volume import HDF5Volume, EDFVolume -from tomoscan.esrf.volume.jp2kvolume import JP2KVolume, has_minimal_openjpeg -from tomoscan.esrf.volume.tiffvolume import TIFFVolume, has_tifffile -from nabu.stitching.utils import ShiftAlgorithm -import h5py +from nabu.stitching.alignment import AlignmentAxis1, AlignmentAxis2 +from nabu.stitching.config import NormalizationBySample, PostProcessedZStitchingConfiguration +from nabu.stitching.overlap import OverlapStitchingStrategy +from nabu.stitching.utils import ShiftAlgorithm +from nabu.stitching.z_stitching import PostProcessZStitcher +from nabu.utils import Progress strategies_to_test_weights = ( OverlapStitchingStrategy.CLOSEST, @@ -84,365 +40,6 @@ def build_raw_volume(): return raw_volume -@pytest.mark.parametrize("strategy", strategies_to_test_weights) -def test_overlap_z_stitcher(strategy): - frame_width = 128 - frame_height = frame_width - frame_1 = PhantomGenerator.get2DPhantomSheppLogan(n=frame_width) - stitcher = ZStichOverlapKernel( - stitching_strategy=strategy, - overlap_size=frame_height, - frame_width=128, - ) - stitched_frame = stitcher.stitch(frame_1, frame_1)[0] - assert stitched_frame.shape == (frame_height, frame_width) - # check result is close to the expected one - numpy.testing.assert_allclose(frame_1, stitched_frame, atol=10e-10) - - # check sum of weights ~ 1.0 - numpy.testing.assert_allclose( - stitcher.weights_img_1 + stitcher.weights_img_2, - numpy.ones_like(stitcher.weights_img_1), - ) - - -@pytest.mark.parametrize("dtype", (numpy.float16, numpy.float32)) -def test_z_stitch_raw_frames(dtype): - """ - test z_stitch_raw_frames: insure a stitching with 3 frames and different overlap can be done - """ - ref_frame_width = 256 - frame_ref = PhantomGenerator.get2DPhantomSheppLogan(n=ref_frame_width).astype(dtype) - - # split the frame into several part - frame_1 = frame_ref[0:100] - frame_2 = frame_ref[80:164] - frame_3 = frame_ref[154:] - - kernel_1 = ZStichOverlapKernel(frame_width=ref_frame_width, overlap_size=20) - kernel_2 = ZStichOverlapKernel(frame_width=ref_frame_width, overlap_size=10) - - stitched = stitch_vertically_raw_frames( - frames=(frame_1, frame_2, frame_3), - output_dtype=dtype, - overlap_kernels=(kernel_1, kernel_2), - 
raw_frames_compositions=None, - overlap_frames_compositions=None, - key_lines=( - ( - 90, # frame_1 height - kernel_1 height / 2.0 - 10, # kernel_1 height / 2.0 - ), - ( - 79, # frame_2 height - kernel_2 height / 2.0 ou 102-20 ? - 5, # kernel_2 height / 2.0 - ), - ), - ) - - assert stitched.shape == frame_ref.shape - numpy.testing.assert_array_almost_equal(frame_ref, stitched) - - -def test_z_stitch_raw_frames_2(): - """ - test z_stitch_raw_frames: insure a stitching with 3 frames and different overlap can be done - """ - ref_frame_width = 256 - frame_ref = PhantomGenerator.get2DPhantomSheppLogan(n=ref_frame_width).astype(numpy.float32) - - # split the frame into several part - frame_1 = frame_ref.copy() - frame_2 = frame_ref.copy() - frame_3 = frame_ref.copy() - - kernel_1 = ZStichOverlapKernel(frame_width=ref_frame_width, overlap_size=10) - kernel_2 = ZStichOverlapKernel(frame_width=ref_frame_width, overlap_size=10) - - stitched = stitch_vertically_raw_frames( - frames=(frame_1, frame_2, frame_3), - output_dtype=numpy.float32, - overlap_kernels=(kernel_1, kernel_2), - raw_frames_compositions=None, - overlap_frames_compositions=None, - key_lines=((20, 20), (105, 105)), - ) - - assert stitched.shape == frame_ref.shape - numpy.testing.assert_array_almost_equal(frame_ref, stitched) - - -_stitching_configurations = ( - # simple case where shifts are provided - { - "n_proj": 4, - "raw_pos": ((0, 0, 0), (-90, 0, 0), (-180, 0, 0)), # requested shift to - "input_pos": ((0, 0, 0), (-90, 0, 0), (-180, 0, 0)), # requested shift to - "raw_shifts": ((0, 0), (-90, 0), (-180, 0)), - }, - # simple case where shift is found from z position - { - "n_proj": 4, - "raw_pos": ((90, 0, 0), (0, 0, 0), (-90, 0, 0)), - "input_pos": ((90, 0, 0), (0, 0, 0), (-90, 0, 0)), - "check_bb": ((40, 140), (-50, 50), (-140, -40)), - "axis_0_params": { - KEY_IMG_REG_METHOD: ShiftAlgorithm.NONE, - }, - "axis_2_params": { - KEY_IMG_REG_METHOD: ShiftAlgorithm.NONE, - }, - "raw_shifts": ((0, 0), (-90, 0), (-180, 0)), - }, -) - - -@pytest.mark.parametrize("configuration", _stitching_configurations) -@pytest.mark.parametrize("dtype", (numpy.float32, numpy.int16)) -def test_PreProcessZStitcher(tmp_path, dtype, configuration): - """ - test PreProcessZStitcher class and insure a full stitching can be done automatically. 
- """ - n_proj = configuration["n_proj"] - ref_frame_width = 280 - raw_frame_height = 100 - ref_frame = PhantomGenerator.get2DPhantomSheppLogan(n=ref_frame_width).astype(dtype) * 256.0 - - # add some mark for image registration - ref_frame[:, 96] = -3.2 - ref_frame[:, 125] = 9.1 - ref_frame[:, 165] = 4.4 - ref_frame[:, 200] = -2.5 - # create raw data - frame_0_shift, frame_1_shift, frame_2_shift = configuration["raw_shifts"] - frame_0 = scipy_shift(ref_frame, shift=frame_0_shift)[:raw_frame_height] - frame_1 = scipy_shift(ref_frame, shift=frame_1_shift)[:raw_frame_height] - frame_2 = scipy_shift(ref_frame, shift=frame_2_shift)[:raw_frame_height] - - frames = frame_0, frame_1, frame_2 - frame_0_input_pos, frame_1_input_pos, frame_2_input_pos = configuration["input_pos"] - frame_0_raw_pos, frame_1_raw_pos, frame_2_raw_pos = configuration["raw_pos"] - - # create a Nxtomo for each of those raw data - raw_data_dir = tmp_path / "raw_data" - raw_data_dir.mkdir() - output_dir = tmp_path / "output_dir" - output_dir.mkdir() - z_position = ( - frame_0_raw_pos[0], - frame_1_raw_pos[0], - frame_2_raw_pos[0], - ) - scans = [] - for (i_frame, frame), z_pos in zip(enumerate(frames), z_position): - nx_tomo = NXtomo() - nx_tomo.sample.z_translation = [z_pos] * n_proj - nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_proj, endpoint=False) - nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_proj - nx_tomo.instrument.detector.x_pixel_size = 1.0 - nx_tomo.instrument.detector.y_pixel_size = 1.0 - nx_tomo.instrument.detector.distance = 2.3 - nx_tomo.energy = 19.2 - nx_tomo.instrument.detector.data = numpy.asarray([frame] * n_proj) - - file_path = os.path.join(raw_data_dir, f"nxtomo_{i_frame}.nx") - entry = f"entry000{i_frame}" - nx_tomo.save(file_path=file_path, data_path=entry) - scans.append(NXtomoScan(scan=file_path, entry=entry)) - - # if requested: check bounding box - check_bb = configuration.get("check_bb", None) - if check_bb is not None: - for scan, expected_bb in zip(scans, check_bb): - assert scan.get_bounding_box(axis="z") == expected_bb - output_file_path = os.path.join(output_dir, "stitched.nx") - output_data_path = "stitched" - z_stich_config = PreProcessedZStitchingConfiguration( - stitching_strategy=OverlapStitchingStrategy.LINEAR_WEIGHTS, - overwrite_results=True, - axis_0_pos_px=( - frame_0_input_pos[0], - frame_1_input_pos[0], - frame_2_input_pos[0], - ), - axis_1_pos_px=( - frame_0_input_pos[1], - frame_1_input_pos[1], - frame_2_input_pos[1], - ), - axis_2_pos_px=( - frame_0_input_pos[2], - frame_1_input_pos[2], - frame_2_input_pos[2], - ), - axis_0_pos_mm=None, - axis_1_pos_mm=None, - axis_2_pos_mm=None, - input_scans=scans, - output_file_path=output_file_path, - output_data_path=output_data_path, - axis_0_params=configuration.get("axis_0_params", {}), - axis_1_params=configuration.get("axis_1_params", {}), - axis_2_params=configuration.get("axis_2_params", {}), - output_nexus_version=None, - slices=None, - slurm_config=None, - slice_for_cross_correlation="middle", - pixel_size=None, - ) - stitcher = PreProcessZStitcher(z_stich_config) - output_identifier = stitcher.stitch() - assert output_identifier.file_path == output_file_path - assert output_identifier.data_path == output_data_path - - created_nx_tomo = NXtomo().load( - file_path=output_identifier.file_path, - data_path=output_identifier.data_path, - detector_data_as="as_numpy_array", - ) - - assert created_nx_tomo.instrument.detector.data.ndim == 3 - mean_abs_error = 
configuration.get("mean_abs_error", None) - if mean_abs_error is not None: - assert ( - numpy.mean(numpy.abs(ref_frame - created_nx_tomo.instrument.detector.data[0, :ref_frame_width, :])) - < mean_abs_error - ) - else: - numpy.testing.assert_array_almost_equal( - ref_frame, created_nx_tomo.instrument.detector.data[0, :ref_frame_width, :] - ) - - # check also other metadata are here - assert created_nx_tomo.instrument.detector.distance.value == 2.3 - assert created_nx_tomo.energy.value == 19.2 - numpy.testing.assert_array_equal( - created_nx_tomo.instrument.detector.image_key_control, - numpy.asarray([ImageKey.PROJECTION.PROJECTION] * n_proj), - ) - - # check configuration has been saved - with h5py.File(output_identifier.file_path, mode="r") as h5f: - assert "stitching_configuration" in h5f[output_identifier.data_path] - - -slices_to_test_pre = ( - { - "slices": (None,), - "complete": True, - }, - { - "slices": (("first",), ("middle",), ("last",)), - "complete": False, - }, - { - "slices": ((0, 1, 2), slice(3, -1, 1)), - "complete": True, - }, -) - - -@pytest.mark.parametrize("configuration_dist", slices_to_test_pre) -def test_DistributePreProcessZStitcher(tmp_path, configuration_dist): - slices = configuration_dist["slices"] - complete = configuration_dist["complete"] - - n_projs = 100 - raw_data = numpy.arange(100 * 128 * 128).reshape((100, 128, 128)) - - # create raw data - frame_0 = raw_data[:, 60:] - assert frame_0.ndim == 3 - frame_0_pos = 40 - frame_1 = raw_data[:, 0:80] - assert frame_1.ndim == 3 - frame_1_pos = 94 - frames = (frame_0, frame_1) - z_positions = (frame_0_pos, frame_1_pos) - - # create a Nxtomo for each of those raw data - raw_data_dir = tmp_path / "raw_data" - raw_data_dir.mkdir() - output_dir = tmp_path / "output_dir" - output_dir.mkdir() - - scans = [] - for (i_frame, frame), z_pos in zip(enumerate(frames), z_positions): - nx_tomo = NXtomo() - nx_tomo.sample.z_translation = [z_pos] * n_projs - nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_projs, endpoint=False) - nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_projs - nx_tomo.instrument.detector.x_pixel_size = 1.0 - nx_tomo.instrument.detector.y_pixel_size = 1.0 - nx_tomo.instrument.detector.distance = 2.3 - nx_tomo.energy = 19.2 - nx_tomo.instrument.detector.data = frame - - file_path = os.path.join(raw_data_dir, f"nxtomo_{i_frame}.nx") - entry = f"entry000{i_frame}" - nx_tomo.save(file_path=file_path, data_path=entry) - scans.append(NXtomoScan(scan=file_path, entry=entry)) - - stitched_nx_tomo = [] - for s in slices: - output_file_path = os.path.join(output_dir, "stitched_section.nx") - output_data_path = f"stitched_{s}" - z_stich_config = PreProcessedZStitchingConfiguration( - axis_0_pos_px=z_positions, - axis_1_pos_px=None, - axis_2_pos_px=(0, 0), - axis_0_pos_mm=None, - axis_1_pos_mm=None, - axis_2_pos_mm=None, - axis_0_params={}, - axis_1_params={}, - axis_2_params={}, - stitching_strategy=OverlapStitchingStrategy.CLOSEST, - overwrite_results=True, - input_scans=scans, - output_file_path=output_file_path, - output_data_path=output_data_path, - output_nexus_version=None, - slices=s, - slurm_config=None, - slice_for_cross_correlation="middle", - pixel_size=None, - ) - stitcher = PreProcessZStitcher(z_stich_config) - output_identifier = stitcher.stitch() - assert output_identifier.file_path == output_file_path - assert output_identifier.data_path == output_data_path - - created_nx_tomo = NXtomo().load( - file_path=output_identifier.file_path, - 
data_path=output_identifier.data_path, - detector_data_as="as_numpy_array", - ) - stitched_nx_tomo.append(created_nx_tomo) - assert len(stitched_nx_tomo) == len(slices) - final_nx_tomo = NXtomo.concatenate(stitched_nx_tomo) - assert isinstance(final_nx_tomo.instrument.detector.data, numpy.ndarray) - final_nx_tomo.save( - file_path=os.path.join(output_dir, "final_stitched.nx"), - data_path="entry0000", - ) - - if complete: - len(final_nx_tomo.instrument.detector.data) == 128 - # test middle - numpy.testing.assert_array_almost_equal(raw_data[1], final_nx_tomo.instrument.detector.data[1, :, :]) - else: - len(final_nx_tomo.instrument.detector.data) == 3 - # test middle - numpy.testing.assert_array_almost_equal(raw_data[49], final_nx_tomo.instrument.detector.data[1, :, :]) - # in the case of first, middle and last frames - # test first - numpy.testing.assert_array_almost_equal(raw_data[0], final_nx_tomo.instrument.detector.data[0, :, :]) - - # test last - numpy.testing.assert_array_almost_equal(raw_data[-1], final_nx_tomo.instrument.detector.data[-1, :, :]) - - _VOL_CLASSES_TO_TEST_FOR_POSTPROC_STITCHING = [HDF5Volume, EDFVolume] # avoid testing glymur because doesn't handle float # if has_minimal_openjpeg: @@ -706,124 +303,6 @@ def test_DistributePostProcessZStitcher(tmp_path, configuration_dist, flip_ud): ) -def test_get_overlap_areas(): - """test get_overlap_areas function""" - f_upper = numpy.linspace(7, 15, num=9, endpoint=True) - f_lower = numpy.linspace(0, 12, num=13, endpoint=True) - - o_1, o_2 = ZStitcher.get_overlap_areas( - upper_frame=f_upper, - lower_frame=f_lower, - upper_frame_key_line=3, - lower_frame_key_line=10, - overlap_size=4, - stitching_axis=0, - ) - - numpy.testing.assert_array_equal(o_1, o_2) - numpy.testing.assert_array_equal(o_1, numpy.linspace(8, 11, num=4, endpoint=True)) - - -def test_frame_flip(tmp_path): - """check it with some NXtomo fliped""" - pytest.skip(reason="Broken test") - ref_frame_width = 280 - n_proj = 10 - raw_frame_width = 100 - ref_frame = PhantomGenerator.get2DPhantomSheppLogan(n=ref_frame_width).astype(numpy.float32) * 256.0 - # create raw data - frame_0_shift = (0, 0) - frame_1_shift = (-90, 0) - frame_2_shift = (-180, 0) - - frame_0 = scipy_shift(ref_frame, shift=frame_0_shift)[:raw_frame_width] - frame_1 = scipy_shift(ref_frame, shift=frame_1_shift)[:raw_frame_width] - frame_2 = scipy_shift(ref_frame, shift=frame_2_shift)[:raw_frame_width] - frames = frame_0, frame_1, frame_2 - - x_flips = [False, True, True] - y_flips = [False, False, True] - - def apply_flip(args): - frame, flip_x, flip_y = args - if flip_x: - frame = numpy.fliplr(frame) - if flip_y: - frame = numpy.flipud(frame) - return frame - - frames = map(apply_flip, zip(frames, x_flips, y_flips)) - - # create a Nxtomo for each of those raw data - raw_data_dir = tmp_path / "raw_data" - raw_data_dir.mkdir() - output_dir = tmp_path / "output_dir" - output_dir.mkdir() - z_position = (90, 0, -90) - - scans = [] - for (i_frame, frame), z_pos, x_flip, y_flip in zip(enumerate(frames), z_position, x_flips, y_flips): - nx_tomo = NXtomo() - nx_tomo.sample.z_translation = [z_pos] * n_proj - nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_proj, endpoint=False) - nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_proj - nx_tomo.instrument.detector.x_pixel_size = 1.0 - nx_tomo.instrument.detector.y_pixel_size = 1.0 - nx_tomo.instrument.detector.distance = 2.3 - if x_flip: - nx_tomo.instrument.detector.transformations.add_transformation(LRDetTransformation()) - if 
y_flip: - nx_tomo.instrument.detector.transformations.add_transformation(UDDetTransformation()) - nx_tomo.energy = 19.2 - nx_tomo.instrument.detector.data = numpy.asarray([frame] * n_proj) - - file_path = os.path.join(raw_data_dir, f"nxtomo_{i_frame}.nx") - entry = f"entry000{i_frame}" - nx_tomo.save(file_path=file_path, data_path=entry) - scans.append(NXtomoScan(scan=file_path, entry=entry)) - - output_file_path = os.path.join(output_dir, "stitched.nx") - output_data_path = "stitched" - assert len(scans) == 3 - z_stich_config = PreProcessedZStitchingConfiguration( - axis_0_pos_px=(0, -90, -180), - axis_1_pos_px=None, - axis_2_pos_px=(0, 0, 0), - axis_0_pos_mm=None, - axis_1_pos_mm=None, - axis_2_pos_mm=None, - axis_0_params={}, - axis_1_params={}, - axis_2_params={}, - stitching_strategy=OverlapStitchingStrategy.LINEAR_WEIGHTS, - overwrite_results=True, - input_scans=scans, - output_file_path=output_file_path, - output_data_path=output_data_path, - output_nexus_version=None, - slices=None, - slurm_config=None, - slice_for_cross_correlation="middle", - pixel_size=None, - ) - stitcher = PreProcessZStitcher(z_stich_config) - output_identifier = stitcher.stitch() - assert output_identifier.file_path == output_file_path - assert output_identifier.data_path == output_data_path - - created_nx_tomo = NXtomo().load( - file_path=output_identifier.file_path, - data_path=output_identifier.data_path, - detector_data_as="as_numpy_array", - ) - - assert created_nx_tomo.instrument.detector.data.ndim == 3 - # insure flipping has been taking into account - numpy.testing.assert_array_almost_equal(ref_frame, created_nx_tomo.instrument.detector.data[0, :ref_frame_width, :]) - - assert len(created_nx_tomo.instrument.detector.transformations) == 0 - - @pytest.mark.parametrize("alignment_axis_2", ("left", "right", "center")) def test_vol_z_stitching_with_alignment_axis_2(tmp_path, alignment_axis_2): """ @@ -926,14 +405,6 @@ def test_vol_z_stitching_with_alignment_axis_2(tmp_path, alignment_axis_2): assert output_volume.data.shape == (120, 4, 120) - import h5py - - with h5py.File("input.h5", mode="w") as h5f: - h5f["data"] = raw_volume - - with h5py.File("output.h5", mode="w") as h5f: - h5f["data"] = output_volume.data - if alignment_axis_2 == "center": numpy.testing.assert_array_almost_equal(raw_volume[:, :, 10:-10], output_volume.data[:, :, 10:-10]) elif alignment_axis_2 == "left": @@ -1064,8 +535,6 @@ def test_normalization_by_sample(tmp_path): simple test of a volume stitching. 
Raw volumes have 'extra' values (+2, +5, +9) that must be removed at the end thanks to the normalization """ - from copy import deepcopy - raw_volume = build_raw_volume() # create folder to save data (and debug) raw_data_dir = tmp_path / "raw_data" @@ -1181,3 +650,115 @@ def test_normalization_by_sample(tmp_path): assert "configuration" in metadata assert output_volume.position[0] == -60.0 assert output_volume.pixel_size == (1.0, 1.0, 1.0) + + +@pytest.mark.parametrize("data_duplication", (True, False)) +def test_data_duplication(tmp_path, data_duplication): + raw_volume = build_raw_volume() + + # create folder to save data (and debug) + raw_data_dir = tmp_path / "raw_data" + raw_data_dir.mkdir() + output_dir = tmp_path / "output_dir" + output_dir.mkdir() + + volume_1 = HDF5Volume( + data=raw_volume[0:30], + metadata={ + "processing_options": { + "reconstruction": { + "position": (-15.0, 0.0, 0.0), + "voxel_size_cm": (100.0, 100.0, 100.0), + } + }, + }, + file_path=os.path.join(raw_data_dir, f"raw_volume_1.hdf5"), + data_path="volume", + ) + + volume_2 = HDF5Volume( + data=raw_volume[20:80], + metadata={ + "processing_options": { + "reconstruction": { + "position": (-50.0, 0.0, 0.0), + "voxel_size_cm": (100.0, 100.0, 100.0), + } + }, + }, + file_path=os.path.join(raw_data_dir, f"raw_volume_2.hdf5"), + data_path="volume", + ) + + volume_3 = HDF5Volume( + data=raw_volume[60:], + metadata={ + "processing_options": { + "reconstruction": { + "position": (-90.0, 0.0, 0.0), + "voxel_size_cm": (100.0, 100.0, 100.0), + } + }, + }, + file_path=os.path.join(raw_data_dir, f"raw_volume_3.hdf5"), + data_path="volume", + ) + + for volume in (volume_1, volume_2, volume_3): + volume.save() + volume.clear_cache() + + output_volume = HDF5Volume( + file_path=os.path.join(output_dir, "stitched_volume.hdf5"), + data_path="stitched_volume", + ) + + z_stich_config = PostProcessedZStitchingConfiguration( + stitching_strategy=OverlapStitchingStrategy.CLOSEST, + overwrite_results=True, + input_volumes=(volume_1, volume_2, volume_3), + output_volume=output_volume, + slices=None, + slurm_config=None, + axis_0_pos_px=None, + axis_0_pos_mm=None, + axis_0_params={"img_reg_method": ShiftAlgorithm.NONE}, + axis_1_pos_px=None, + axis_1_pos_mm=None, + axis_1_params={"img_reg_method": ShiftAlgorithm.NONE}, + axis_2_pos_px=None, + axis_2_pos_mm=None, + axis_2_params={"img_reg_method": ShiftAlgorithm.NONE}, + slice_for_cross_correlation="middle", + voxel_size=None, + duplicate_data=data_duplication, + ) + + stitcher = PostProcessZStitcher(z_stich_config, progress=None) + output_identifier = stitcher.stitch() + + assert output_identifier.file_path == output_volume.file_path + assert output_identifier.data_path == output_volume.data_path + + output_volume.data = None + output_volume.metadata = None + output_volume.load_data(store=True) + output_volume.load_metadata(store=True) + + assert raw_volume.shape == output_volume.data.shape + numpy.testing.assert_almost_equal(raw_volume.data, output_volume.data) + + with h5py.File(output_volume.file_path, mode="r") as h5f: + if data_duplication: + assert f"{output_volume.data_path}/stitching_regions" not in h5f + assert not h5f[f"{output_volume.data_path}/results/data"].is_virtual + else: + assert f"{output_volume.data_path}/stitching_regions" in h5f + assert h5f[f"{output_volume.data_path}/results/data"].is_virtual diff --git 
a/nabu/stitching/tests/test_z_preprocessing_stitching.py b/nabu/stitching/tests/test_z_preprocessing_stitching.py new file mode 100644 index 0000000000000000000000000000000000000000..5186411454ce471d1e9f6f81e9a2abe79f9126b5 --- /dev/null +++ b/nabu/stitching/tests/test_z_preprocessing_stitching.py @@ -0,0 +1,534 @@ +# coding: utf-8 +# /*########################################################################## +# +# Copyright (c) 2016-2017 European Synchrotron Radiation Facility +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# ###########################################################################*/ + +__authors__ = ["H. Payno"] +__license__ = "MIT" +__date__ = "13/05/2022" + + +import os +from silx.image.phantomgenerator import PhantomGenerator +from scipy.ndimage import shift as scipy_shift +import numpy +import pytest +from nabu.stitching.config import PreProcessedZStitchingConfiguration +from nabu.stitching.config import KEY_IMG_REG_METHOD +from nabu.stitching.overlap import ZStichOverlapKernel, OverlapStitchingStrategy +from nabu.stitching.z_stitching import ( + PreProcessZStitcher, + stitch_vertically_raw_frames, + ZStitcher, +) +from nxtomo.nxobject.nxdetector import ImageKey +from nxtomo.utils.transformation import UDDetTransformation, LRDetTransformation +from nxtomo.application.nxtomo import NXtomo +from tomoscan.esrf.scan.nxtomoscan import NXtomoScan +from nabu.stitching.utils import ShiftAlgorithm +import h5py + + +strategies_to_test_weights = ( + OverlapStitchingStrategy.CLOSEST, + OverlapStitchingStrategy.COSINUS_WEIGHTS, + OverlapStitchingStrategy.LINEAR_WEIGHTS, + OverlapStitchingStrategy.MEAN, +) + + +@pytest.mark.parametrize("strategy", strategies_to_test_weights) +def test_overlap_z_stitcher(strategy): + frame_width = 128 + frame_height = frame_width + frame_1 = PhantomGenerator.get2DPhantomSheppLogan(n=frame_width) + stitcher = ZStichOverlapKernel( + stitching_strategy=strategy, + overlap_size=frame_height, + frame_width=128, + ) + stitched_frame = stitcher.stitch(frame_1, frame_1)[0] + assert stitched_frame.shape == (frame_height, frame_width) + # check result is close to the expected one + numpy.testing.assert_allclose(frame_1, stitched_frame, atol=10e-10) + + # check sum of weights ~ 1.0 + numpy.testing.assert_allclose( + stitcher.weights_img_1 + stitcher.weights_img_2, + numpy.ones_like(stitcher.weights_img_1), + ) + + +@pytest.mark.parametrize("dtype", (numpy.float16, numpy.float32)) +def test_z_stitch_raw_frames(dtype): + """ + test 
z_stitch_raw_frames: insure a stitching with 3 frames and different overlap can be done + """ + ref_frame_width = 256 + frame_ref = PhantomGenerator.get2DPhantomSheppLogan(n=ref_frame_width).astype(dtype) + + # split the frame into several part + frame_1 = frame_ref[0:100] + frame_2 = frame_ref[80:164] + frame_3 = frame_ref[154:] + + kernel_1 = ZStichOverlapKernel(frame_width=ref_frame_width, overlap_size=20) + kernel_2 = ZStichOverlapKernel(frame_width=ref_frame_width, overlap_size=10) + + stitched = stitch_vertically_raw_frames( + frames=(frame_1, frame_2, frame_3), + output_dtype=dtype, + overlap_kernels=(kernel_1, kernel_2), + raw_frames_compositions=None, + overlap_frames_compositions=None, + key_lines=( + ( + 90, # frame_1 height - kernel_1 height / 2.0 + 10, # kernel_1 height / 2.0 + ), + ( + 79, # frame_2 height - kernel_2 height / 2.0 ou 102-20 ? + 5, # kernel_2 height / 2.0 + ), + ), + ) + + assert stitched.shape == frame_ref.shape + numpy.testing.assert_array_almost_equal(frame_ref, stitched) + + +def test_z_stitch_raw_frames_2(): + """ + test z_stitch_raw_frames: insure a stitching with 3 frames and different overlap can be done + """ + ref_frame_width = 256 + frame_ref = PhantomGenerator.get2DPhantomSheppLogan(n=ref_frame_width).astype(numpy.float32) + + # split the frame into several part + frame_1 = frame_ref.copy() + frame_2 = frame_ref.copy() + frame_3 = frame_ref.copy() + + kernel_1 = ZStichOverlapKernel(frame_width=ref_frame_width, overlap_size=10) + kernel_2 = ZStichOverlapKernel(frame_width=ref_frame_width, overlap_size=10) + + stitched = stitch_vertically_raw_frames( + frames=(frame_1, frame_2, frame_3), + output_dtype=numpy.float32, + overlap_kernels=(kernel_1, kernel_2), + raw_frames_compositions=None, + overlap_frames_compositions=None, + key_lines=((20, 20), (105, 105)), + ) + + assert stitched.shape == frame_ref.shape + numpy.testing.assert_array_almost_equal(frame_ref, stitched) + + +_stitching_configurations = ( + # simple case where shifts are provided + { + "n_proj": 4, + "raw_pos": ((0, 0, 0), (-90, 0, 0), (-180, 0, 0)), # requested shift to + "input_pos": ((0, 0, 0), (-90, 0, 0), (-180, 0, 0)), # requested shift to + "raw_shifts": ((0, 0), (-90, 0), (-180, 0)), + }, + # simple case where shift is found from z position + { + "n_proj": 4, + "raw_pos": ((90, 0, 0), (0, 0, 0), (-90, 0, 0)), + "input_pos": ((90, 0, 0), (0, 0, 0), (-90, 0, 0)), + "check_bb": ((40, 140), (-50, 50), (-140, -40)), + "axis_0_params": { + KEY_IMG_REG_METHOD: ShiftAlgorithm.NONE, + }, + "axis_2_params": { + KEY_IMG_REG_METHOD: ShiftAlgorithm.NONE, + }, + "raw_shifts": ((0, 0), (-90, 0), (-180, 0)), + }, +) + + +@pytest.mark.parametrize("configuration", _stitching_configurations) +@pytest.mark.parametrize("dtype", (numpy.float32, numpy.int16)) +def test_PreProcessZStitcher(tmp_path, dtype, configuration): + """ + test PreProcessZStitcher class and insure a full stitching can be done automatically. 
+ """ + n_proj = configuration["n_proj"] + ref_frame_width = 280 + raw_frame_height = 100 + ref_frame = PhantomGenerator.get2DPhantomSheppLogan(n=ref_frame_width).astype(dtype) * 256.0 + + # add some mark for image registration + ref_frame[:, 96] = -3.2 + ref_frame[:, 125] = 9.1 + ref_frame[:, 165] = 4.4 + ref_frame[:, 200] = -2.5 + # create raw data + frame_0_shift, frame_1_shift, frame_2_shift = configuration["raw_shifts"] + frame_0 = scipy_shift(ref_frame, shift=frame_0_shift)[:raw_frame_height] + frame_1 = scipy_shift(ref_frame, shift=frame_1_shift)[:raw_frame_height] + frame_2 = scipy_shift(ref_frame, shift=frame_2_shift)[:raw_frame_height] + + frames = frame_0, frame_1, frame_2 + frame_0_input_pos, frame_1_input_pos, frame_2_input_pos = configuration["input_pos"] + frame_0_raw_pos, frame_1_raw_pos, frame_2_raw_pos = configuration["raw_pos"] + + # create a Nxtomo for each of those raw data + raw_data_dir = tmp_path / "raw_data" + raw_data_dir.mkdir() + output_dir = tmp_path / "output_dir" + output_dir.mkdir() + z_position = ( + frame_0_raw_pos[0], + frame_1_raw_pos[0], + frame_2_raw_pos[0], + ) + scans = [] + for (i_frame, frame), z_pos in zip(enumerate(frames), z_position): + nx_tomo = NXtomo() + nx_tomo.sample.z_translation = [z_pos] * n_proj + nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_proj, endpoint=False) + nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_proj + nx_tomo.instrument.detector.x_pixel_size = 1.0 + nx_tomo.instrument.detector.y_pixel_size = 1.0 + nx_tomo.instrument.detector.distance = 2.3 + nx_tomo.energy = 19.2 + nx_tomo.instrument.detector.data = numpy.asarray([frame] * n_proj) + + file_path = os.path.join(raw_data_dir, f"nxtomo_{i_frame}.nx") + entry = f"entry000{i_frame}" + nx_tomo.save(file_path=file_path, data_path=entry) + scans.append(NXtomoScan(scan=file_path, entry=entry)) + + # if requested: check bounding box + check_bb = configuration.get("check_bb", None) + if check_bb is not None: + for scan, expected_bb in zip(scans, check_bb): + assert scan.get_bounding_box(axis="z") == expected_bb + output_file_path = os.path.join(output_dir, "stitched.nx") + output_data_path = "stitched" + z_stich_config = PreProcessedZStitchingConfiguration( + stitching_strategy=OverlapStitchingStrategy.LINEAR_WEIGHTS, + overwrite_results=True, + axis_0_pos_px=( + frame_0_input_pos[0], + frame_1_input_pos[0], + frame_2_input_pos[0], + ), + axis_1_pos_px=( + frame_0_input_pos[1], + frame_1_input_pos[1], + frame_2_input_pos[1], + ), + axis_2_pos_px=( + frame_0_input_pos[2], + frame_1_input_pos[2], + frame_2_input_pos[2], + ), + axis_0_pos_mm=None, + axis_1_pos_mm=None, + axis_2_pos_mm=None, + input_scans=scans, + output_file_path=output_file_path, + output_data_path=output_data_path, + axis_0_params=configuration.get("axis_0_params", {}), + axis_1_params=configuration.get("axis_1_params", {}), + axis_2_params=configuration.get("axis_2_params", {}), + output_nexus_version=None, + slices=None, + slurm_config=None, + slice_for_cross_correlation="middle", + pixel_size=None, + ) + stitcher = PreProcessZStitcher(z_stich_config) + output_identifier = stitcher.stitch() + assert output_identifier.file_path == output_file_path + assert output_identifier.data_path == output_data_path + + created_nx_tomo = NXtomo().load( + file_path=output_identifier.file_path, + data_path=output_identifier.data_path, + detector_data_as="as_numpy_array", + ) + + assert created_nx_tomo.instrument.detector.data.ndim == 3 + mean_abs_error = 
configuration.get("mean_abs_error", None) + if mean_abs_error is not None: + assert ( + numpy.mean(numpy.abs(ref_frame - created_nx_tomo.instrument.detector.data[0, :ref_frame_width, :])) + < mean_abs_error + ) + else: + numpy.testing.assert_array_almost_equal( + ref_frame, created_nx_tomo.instrument.detector.data[0, :ref_frame_width, :] + ) + + # check also other metadata are here + assert created_nx_tomo.instrument.detector.distance.value == 2.3 + assert created_nx_tomo.energy.value == 19.2 + numpy.testing.assert_array_equal( + created_nx_tomo.instrument.detector.image_key_control, + numpy.asarray([ImageKey.PROJECTION.PROJECTION] * n_proj), + ) + + # check configuration has been saved + with h5py.File(output_identifier.file_path, mode="r") as h5f: + assert "stitching_configuration" in h5f[output_identifier.data_path] + + +slices_to_test_pre = ( + { + "slices": (None,), + "complete": True, + }, + { + "slices": (("first",), ("middle",), ("last",)), + "complete": False, + }, + { + "slices": ((0, 1, 2), slice(3, -1, 1)), + "complete": True, + }, +) + + +@pytest.mark.parametrize("configuration_dist", slices_to_test_pre) +def test_DistributePreProcessZStitcher(tmp_path, configuration_dist): + slices = configuration_dist["slices"] + complete = configuration_dist["complete"] + + n_projs = 100 + raw_data = numpy.arange(100 * 128 * 128).reshape((100, 128, 128)) + + # create raw data + frame_0 = raw_data[:, 60:] + assert frame_0.ndim == 3 + frame_0_pos = 40 + frame_1 = raw_data[:, 0:80] + assert frame_1.ndim == 3 + frame_1_pos = 94 + frames = (frame_0, frame_1) + z_positions = (frame_0_pos, frame_1_pos) + + # create a Nxtomo for each of those raw data + raw_data_dir = tmp_path / "raw_data" + raw_data_dir.mkdir() + output_dir = tmp_path / "output_dir" + output_dir.mkdir() + + scans = [] + for (i_frame, frame), z_pos in zip(enumerate(frames), z_positions): + nx_tomo = NXtomo() + nx_tomo.sample.z_translation = [z_pos] * n_projs + nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_projs, endpoint=False) + nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_projs + nx_tomo.instrument.detector.x_pixel_size = 1.0 + nx_tomo.instrument.detector.y_pixel_size = 1.0 + nx_tomo.instrument.detector.distance = 2.3 + nx_tomo.energy = 19.2 + nx_tomo.instrument.detector.data = frame + + file_path = os.path.join(raw_data_dir, f"nxtomo_{i_frame}.nx") + entry = f"entry000{i_frame}" + nx_tomo.save(file_path=file_path, data_path=entry) + scans.append(NXtomoScan(scan=file_path, entry=entry)) + + stitched_nx_tomo = [] + for s in slices: + output_file_path = os.path.join(output_dir, "stitched_section.nx") + output_data_path = f"stitched_{s}" + z_stich_config = PreProcessedZStitchingConfiguration( + axis_0_pos_px=z_positions, + axis_1_pos_px=None, + axis_2_pos_px=(0, 0), + axis_0_pos_mm=None, + axis_1_pos_mm=None, + axis_2_pos_mm=None, + axis_0_params={}, + axis_1_params={}, + axis_2_params={}, + stitching_strategy=OverlapStitchingStrategy.CLOSEST, + overwrite_results=True, + input_scans=scans, + output_file_path=output_file_path, + output_data_path=output_data_path, + output_nexus_version=None, + slices=s, + slurm_config=None, + slice_for_cross_correlation="middle", + pixel_size=None, + ) + stitcher = PreProcessZStitcher(z_stich_config) + output_identifier = stitcher.stitch() + assert output_identifier.file_path == output_file_path + assert output_identifier.data_path == output_data_path + + created_nx_tomo = NXtomo().load( + file_path=output_identifier.file_path, + 
data_path=output_identifier.data_path, + detector_data_as="as_numpy_array", + ) + stitched_nx_tomo.append(created_nx_tomo) + assert len(stitched_nx_tomo) == len(slices) + final_nx_tomo = NXtomo.concatenate(stitched_nx_tomo) + assert isinstance(final_nx_tomo.instrument.detector.data, numpy.ndarray) + final_nx_tomo.save( + file_path=os.path.join(output_dir, "final_stitched.nx"), + data_path="entry0000", + ) + + if complete: + len(final_nx_tomo.instrument.detector.data) == 128 + # test middle + numpy.testing.assert_array_almost_equal(raw_data[1], final_nx_tomo.instrument.detector.data[1, :, :]) + else: + len(final_nx_tomo.instrument.detector.data) == 3 + # test middle + numpy.testing.assert_array_almost_equal(raw_data[49], final_nx_tomo.instrument.detector.data[1, :, :]) + # in the case of first, middle and last frames + # test first + numpy.testing.assert_array_almost_equal(raw_data[0], final_nx_tomo.instrument.detector.data[0, :, :]) + + # test last + numpy.testing.assert_array_almost_equal(raw_data[-1], final_nx_tomo.instrument.detector.data[-1, :, :]) + + +def test_get_overlap_areas(): + """test get_overlap_areas function""" + f_upper = numpy.linspace(7, 15, num=9, endpoint=True) + f_lower = numpy.linspace(0, 12, num=13, endpoint=True) + + o_1, o_2 = ZStitcher.get_overlap_areas( + upper_frame=f_upper, + lower_frame=f_lower, + upper_frame_key_line=3, + lower_frame_key_line=10, + overlap_size=4, + stitching_axis=0, + ) + + numpy.testing.assert_array_equal(o_1, o_2) + numpy.testing.assert_array_equal(o_1, numpy.linspace(8, 11, num=4, endpoint=True)) + + +def test_frame_flip(tmp_path): + """check it with some NXtomo fliped""" + pytest.skip(reason="Broken test") + ref_frame_width = 280 + n_proj = 10 + raw_frame_width = 100 + ref_frame = PhantomGenerator.get2DPhantomSheppLogan(n=ref_frame_width).astype(numpy.float32) * 256.0 + # create raw data + frame_0_shift = (0, 0) + frame_1_shift = (-90, 0) + frame_2_shift = (-180, 0) + + frame_0 = scipy_shift(ref_frame, shift=frame_0_shift)[:raw_frame_width] + frame_1 = scipy_shift(ref_frame, shift=frame_1_shift)[:raw_frame_width] + frame_2 = scipy_shift(ref_frame, shift=frame_2_shift)[:raw_frame_width] + frames = frame_0, frame_1, frame_2 + + x_flips = [False, True, True] + y_flips = [False, False, True] + + def apply_flip(args): + frame, flip_x, flip_y = args + if flip_x: + frame = numpy.fliplr(frame) + if flip_y: + frame = numpy.flipud(frame) + return frame + + frames = map(apply_flip, zip(frames, x_flips, y_flips)) + + # create a Nxtomo for each of those raw data + raw_data_dir = tmp_path / "raw_data" + raw_data_dir.mkdir() + output_dir = tmp_path / "output_dir" + output_dir.mkdir() + z_position = (90, 0, -90) + + scans = [] + for (i_frame, frame), z_pos, x_flip, y_flip in zip(enumerate(frames), z_position, x_flips, y_flips): + nx_tomo = NXtomo() + nx_tomo.sample.z_translation = [z_pos] * n_proj + nx_tomo.sample.rotation_angle = numpy.linspace(0, 180, num=n_proj, endpoint=False) + nx_tomo.instrument.detector.image_key_control = [ImageKey.PROJECTION] * n_proj + nx_tomo.instrument.detector.x_pixel_size = 1.0 + nx_tomo.instrument.detector.y_pixel_size = 1.0 + nx_tomo.instrument.detector.distance = 2.3 + if x_flip: + nx_tomo.instrument.detector.transformations.add_transformation(LRDetTransformation()) + if y_flip: + nx_tomo.instrument.detector.transformations.add_transformation(UDDetTransformation()) + nx_tomo.energy = 19.2 + nx_tomo.instrument.detector.data = numpy.asarray([frame] * n_proj) + + file_path = os.path.join(raw_data_dir, 
f"nxtomo_{i_frame}.nx") + entry = f"entry000{i_frame}" + nx_tomo.save(file_path=file_path, data_path=entry) + scans.append(NXtomoScan(scan=file_path, entry=entry)) + + output_file_path = os.path.join(output_dir, "stitched.nx") + output_data_path = "stitched" + assert len(scans) == 3 + z_stich_config = PreProcessedZStitchingConfiguration( + axis_0_pos_px=(0, -90, -180), + axis_1_pos_px=None, + axis_2_pos_px=(0, 0, 0), + axis_0_pos_mm=None, + axis_1_pos_mm=None, + axis_2_pos_mm=None, + axis_0_params={}, + axis_1_params={}, + axis_2_params={}, + stitching_strategy=OverlapStitchingStrategy.LINEAR_WEIGHTS, + overwrite_results=True, + input_scans=scans, + output_file_path=output_file_path, + output_data_path=output_data_path, + output_nexus_version=None, + slices=None, + slurm_config=None, + slice_for_cross_correlation="middle", + pixel_size=None, + ) + stitcher = PreProcessZStitcher(z_stich_config) + output_identifier = stitcher.stitch() + assert output_identifier.file_path == output_file_path + assert output_identifier.data_path == output_data_path + + created_nx_tomo = NXtomo().load( + file_path=output_identifier.file_path, + data_path=output_identifier.data_path, + detector_data_as="as_numpy_array", + ) + + assert created_nx_tomo.instrument.detector.data.ndim == 3 + # insure flipping has been taking into account + numpy.testing.assert_array_almost_equal(ref_frame, created_nx_tomo.instrument.detector.data[0, :ref_frame_width, :]) + + assert len(created_nx_tomo.instrument.detector.transformations) == 0 diff --git a/nabu/stitching/z_stitching.py b/nabu/stitching/z_stitching.py index db8dd2e2c3f0729986e2eb14bb50c5769368489f..e747b57178d0338ea1a2ea33526c04255d80a6f7 100644 --- a/nabu/stitching/z_stitching.py +++ b/nabu/stitching/z_stitching.py @@ -433,6 +433,32 @@ class ZStitcher: ] ) + @staticmethod + def _dump_frame(output_dataset, index, stitched_frame, axis): + if axis == 0: + output_dataset[index] = stitched_frame + elif axis == 1: + output_dataset[:, index, :] = stitched_frame + elif axis == 2: + output_dataset[:, :, index] = stitched_frame + else: + raise ValueError + + @staticmethod + def _dump_stitched_VS( + output_VL: h5py.VirtualLayout, index, stitched_VS: h5py.VirtualSource, axis, region_start, region_end + ): + assert isinstance(output_VL, h5py.VirtualLayout), "'output_VL' should be a 'h5py.VirtualLayout'" + assert isinstance(stitched_VS, h5py.VirtualSource), "'stitched_VS' should be a 'h5py.VirtualSource'" + if axis == 0: + output_VL[index, region_start:region_end] = stitched_VS + elif axis == 1: + output_VL[region_start:region_end, index, :] = stitched_VS + elif axis == 2: + output_VL[region_start:region_end, :, index] = stitched_VS + else: + raise ValueError + @staticmethod def stitch_frames( frames: Union[tuple, numpy.ndarray], @@ -442,7 +468,7 @@ class ZStitcher: stitching_axis: int, overlap_kernels: tuple, output_dataset: Optional[Union[h5py.Dataset, numpy.ndarray]] = None, - dump_frame_fct=None, + dump_frame_axis: Optional[int] = None, check_inputs=True, shift_mode="nearest", i_frame=None, @@ -450,6 +476,8 @@ class ZStitcher: alignment="center", pad_mode="constant", new_width: Optional[int] = None, + stitching_regions_hdf5_dataset: Optional[tuple] = None, + raw_regions_hdf5_dataset: Optional[tuple] = None, ) -> numpy.ndarray: """ shift frames according to provided `shifts` (as y, x tuples) then stitch all the shifted frames together and @@ -525,7 +553,7 @@ class ZStitcher: x_shifted_data.append(shifted_frame) # step 2: create stitched frame - res = 
stitch_vertically_raw_frames( + stitched_frame, composition_cls = stitch_vertically_raw_frames( frames=x_shifted_data, key_lines=( [ @@ -536,24 +564,97 @@ class ZStitcher: overlap_kernels=overlap_kernels, check_inputs=check_inputs, output_dtype=output_dtype, - return_composition_cls=return_composition_cls, + return_composition_cls=True, alignment=alignment, pad_mode=pad_mode, new_width=new_width, ) - if return_composition_cls: - stitched_frame, _ = res - else: - stitched_frame = res # step 3: dump stitched frame - if output_dataset is not None and i_frame is not None: - dump_frame_fct( + duplicate_data = stitching_regions_hdf5_dataset is None + # 3.1 full frame + if duplicate_data and output_dataset is not None and i_frame is not None: + ZStitcher._dump_frame( output_dataset=output_dataset, index=i_frame, stitched_frame=stitched_frame, + axis=dump_frame_axis, ) - return res + # 3.2 on stitching regions + if not duplicate_data: + assert isinstance( + output_dataset, h5py.VirtualLayout + ), "in the case we want to avoid data duplication 'output_dataset' must be a VirtualLayout" + # save stitched region + stitching_regions = composition_cls["overlap_compositon"] + for (_, _, region_start, region_end), stitching_region_hdf5_dataset in zip( + stitching_regions.browse(), stitching_regions_hdf5_dataset + ): + assert isinstance(output_dataset, h5py.VirtualLayout) + assert isinstance(stitching_region_hdf5_dataset, h5py.Dataset) + stitching_region_array = stitched_frame[region_start:region_end] + ZStitcher._dump_frame( + output_dataset=stitching_region_hdf5_dataset, + index=i_frame, + stitched_frame=stitching_region_array, + axis=1, + ) + vs = ZStitcher.create_subset_selection( + dataset=stitching_region_hdf5_dataset, + slices=( + slice(0, stitching_region_hdf5_dataset.shape[0]), + slice(i_frame, i_frame + 1), + slice(0, stitching_region_hdf5_dataset.shape[2]), + ), + ) + + ZStitcher._dump_stitched_VS( + output_VL=output_dataset, + index=i_frame, + axis=1, + region_start=region_start, + region_end=region_end, + stitched_VS=vs, + ) + + # create virtual source of the raw data + raw_regions = composition_cls["raw_compositon"] + for (frame_start, frame_end, region_start, region_end), raw_region_hdf5_dataset in zip( + raw_regions.browse(), raw_regions_hdf5_dataset + ): + vs = ZStitcher.create_subset_selection( + dataset=raw_region_hdf5_dataset, + slices=( + slice(frame_start, frame_end), + slice(i_frame, i_frame + 1), + slice(0, raw_region_hdf5_dataset.shape[2]), + ), + ) + + ZStitcher._dump_stitched_VS( + output_VL=output_dataset, + index=i_frame, + axis=1, + region_start=region_start, + region_end=region_end, + stitched_VS=vs, + ) + + if return_composition_cls: + return stitched_frame, composition_cls + else: + return stitched_frame + + @staticmethod + def create_subset_selection(dataset: h5py.Dataset, slices: tuple) -> h5py.VirtualSource: + assert isinstance(dataset, h5py.Dataset), f"dataset is expected to be a h5py.Dataset. Get {type(dataset)}" + assert isinstance(slices, tuple), f"slices is expected to be a tuple of slices. 
Get {type(slices)} instead" + import h5py._hl.selections as selection + + virtual_source = h5py.VirtualSource(dataset) + sel = selection.select(dataset.shape, slices, dataset=dataset) + virtual_source.sel = sel + return virtual_source @staticmethod @cache(maxsize=None) @@ -606,10 +707,6 @@ class PreProcessZStitcher(ZStitcher): if self.configuration.axis_2_params is None: self.configuration.axis_2_params = {} - @staticmethod - def _dump_frame(output_dataset: h5py.Dataset, index: int, stitched_frame: numpy.ndarray): - output_dataset[index] = stitched_frame - @property def reading_orders(self): """ @@ -1210,6 +1307,13 @@ class PreProcessZStitcher(ZStitcher): ), self._stitching_width, ) + if not self.configuration.duplicate_data: + # if we want to avoid data duplication we will need to create new dataset for the stitching area + stitching_sources_arr_shapes = tuple( + [(n_proj, abs(overlap), self._stitching_width) for overlap in self._axis_0_rel_shifts] + ) + else: + stitching_sources_arr_shapes = tuple() # get expected output dataset first (just in case output and input files are the same) first_proj_idx = sorted(self.z_serie[0].projections.keys())[0] @@ -1249,19 +1353,36 @@ class PreProcessZStitcher(ZStitcher): ) def get_output_data_type(): - return numpy.float32 # because we will apply flat field correction on it and they are not raw data - # scan = self.z_serie[0] - # radio_url = tuple(scan.projections.values())[0] - # assert isinstance(radio_url, DataUrl) - # data = get_data(radio_url) - # return data.dtype + """ + output data type for pre-processing is always float32. + Because flat field correction on it and they are not raw data + """ + return numpy.float32 output_dtype = get_output_data_type() + overlap_hdf5_datasets = [] + # create datasets to save overlap if needed + with HDF5File(filename=self.configuration.output_file_path, mode="a") as h5f: + for i_region, overlap_shape in enumerate(stitching_sources_arr_shapes): + overlap_data_path = "/".join( + [ + self.configuration.output_data_path, + "stitching_regions", + f"region_{i_region}", + ] + ) + overlap_hdf5_datasets.append( + h5f.create_dataset( + name=overlap_data_path, + shape=overlap_shape, + dtype=output_dtype, + ) + ) + # append frames ("instrument/detactor/data" dataset) with HDF5File(filename=self.configuration.output_file_path, mode="a") as h5f: # note: nx_tomo.save already handles the possible overwrite conflict by removing # self.configuration.output_file_path or raising an error - stitched_frame_path = "/".join( [ self.configuration.output_data_path, @@ -1311,13 +1432,14 @@ class PreProcessZStitcher(ZStitcher): overlap_kernels=self._overlap_kernels, i_frame=i_proj, output_dtype=output_dtype, - dump_frame_fct=self._dump_frame, + dump_frame_axis=1, return_composition_cls=store_composition if i_proj == 0 else False, stitching_axis=0, pad_mode=self.configuration.pad_mode, alignment=self.configuration.alignment_axis_2, new_width=self._stitching_width, check_inputs=i_proj == 0, # on process check on the first iteration + stitching_regions_hdf5_dataset=overlap_hdf5_datasets, ) if i_proj == 0 and store_composition: _, self._frame_composition = sf @@ -1645,6 +1767,13 @@ class PostProcessZStitcher(ZStitcher): if self.configuration.output_volume is None: raise ValueError("input volume should be provided") + if not self.configuration.duplicate_data and not ( + isinstance(self.configuration.output_volume, HDF5Volume) + and numpy.all([isinstance(volume, HDF5Volume) for volume in self._input_volumes]) + ): + _logger.warning("Unable to 
avoid data duplication. All volumes must be HDF5") + self.configuration.duplicate_data = True + n_volumes = len(self.z_serie) if n_volumes == 0: raise ValueError("no scan to stich together") @@ -1706,10 +1835,20 @@ class PostProcessZStitcher(ZStitcher): def _create_stitched_volume(self, store_composition: bool): overlap_kernels = self._overlap_kernels self._slices_to_stitch, n_slices = self.configuration.settle_slices() + data_type = self.get_output_data_type() # sync overwrite_results with volume overwrite parameter self.configuration.output_volume.overwrite = self.configuration.overwrite_results + # prepare HDF5 dataset for stitching region if requested + if not self.configuration.duplicate_data: + # if we want to avoid data duplication we will need to create new dataset for the stitching area + stitching_sources_arr_shapes = tuple( + [(abs(overlap), n_slices, self._stitching_width) for overlap in self._axis_0_rel_shifts] + ) + else: + stitching_sources_arr_shapes = tuple() + # init final volume final_volume = self.configuration.output_volume final_volume_shape = ( @@ -1721,8 +1860,6 @@ class PostProcessZStitcher(ZStitcher): self._stitching_width, ) - data_type = self.get_output_data_type() - if self.progress: self.progress.set_max_advancement(final_volume_shape[1]) @@ -1732,8 +1869,12 @@ class PostProcessZStitcher(ZStitcher): else: step = 1 with PostProcessZStitcher._FinalDatasetContext( - volume=final_volume, volume_shape=final_volume_shape, dtype=data_type - ) as output_dataset: + volume=final_volume, + volume_shape=final_volume_shape, + dtype=data_type, + stitching_sources_arr_shapes=stitching_sources_arr_shapes, + ) as output_datasets: + output_dataset, stitching_regions_hdf5_dataset = output_datasets # note: output_dataset is a HDF5 dataset if final volume is an HDF5 volume else is a numpy array with PostProcessZStitcher._RawDatasetsContext( self._input_volumes, @@ -1765,12 +1906,14 @@ class PostProcessZStitcher(ZStitcher): y_relative_shifts=self._axis_0_rel_shifts, overlap_kernels=overlap_kernels, output_dataset=output_dataset, - dump_frame_fct=self._dump_frame, + dump_frame_axis=0, i_frame=y_index, output_dtype=data_type, return_composition_cls=store_composition if y_index == 0 else False, stitching_axis=0, check_inputs=y_index == 0, # on process check on the first iteration + stitching_regions_hdf5_dataset=stitching_regions_hdf5_dataset, + raw_regions_hdf5_dataset=raw_datasets, ) if y_index == 0 and store_composition: _, self._frame_composition = sf @@ -1817,9 +1960,17 @@ class PostProcessZStitcher(ZStitcher): In the case of HDF5 we want to save this directly in the file to avoid keeping the full volume in memory. 
Insure also contain processing will be common between the different processing + + If stitching_sources_arr_shapes is provided this mean that we want to create stitching region and then create a VDS to avoid data duplication """ - def __init__(self, volume: VolumeBase, volume_shape: tuple, dtype: numpy.dtype) -> None: + def __init__( + self, + volume: VolumeBase, + volume_shape: tuple, + dtype: numpy.dtype, + stitching_sources_arr_shapes: Optional[tuple], + ) -> None: super().__init__() if not isinstance(volume, VolumeBase): raise TypeError( @@ -1830,8 +1981,10 @@ class PostProcessZStitcher(ZStitcher): self._volume_shape = volume_shape self.__file_handler = None self._dtype = dtype + self._stitching_sources_arr_shapes = stitching_sources_arr_shapes + self._duplicate_data = stitching_sources_arr_shapes is None - def __enter__(self): + def _create_stitched_volume_dataset(self): # handle the specific case of HDF5. Goal: avoid getting the full stitched volume in memory if isinstance(self._volume, HDF5Volume): self.__file_handler = HDF5File(filename=self._volume.data_url.file_path(), mode="a") @@ -1843,25 +1996,61 @@ class PostProcessZStitcher(ZStitcher): _logger.error(f"Fail to overwrite data. Reason is {e}") data = None self.__file_handler.close() + self._duplicate_data = True + # avoid creating a dataset for stitched volume as creation of the stitched_volume failed return data # create dataset try: - data = self.__file_handler.create_dataset( - self._volume.data_url.data_path(), - shape=self._volume_shape, - dtype=self._dtype, - ) + if self._duplicate_data: + data = self.__file_handler.create_dataset( + self._volume.data_url.data_path(), + shape=self._volume_shape, + dtype=self._dtype, + ) + else: + data = h5py.VirtualLayout( + shape=self._volume_shape, + dtype=self._dtype, + ) except Exception as e2: _logger.error(f"Fail to create final dataset. Reason is {e2}") data = None self.__file_handler.close() - # for other file format: create the full dataset in memory + self._duplicate_data = True + # avoid creating a dataset for stitched volume as creation of the stitched_volume failed else: + # for other file format: create the full dataset in memory data = numpy.empty(self._volume_shape, dtype=self._dtype) return data + def _create_stitched_sub_region_datasets(self): + # create datasets to store overlaps if needed + if isinstance(self._volume, HDF5Volume) and not self._duplicate_data: + stitching_regions_hdf5_dataset = [] + for i_region, overlap_shape in enumerate(self._stitching_sources_arr_shapes): + data_path = f"{self._volume.data_path}/stitching_regions/region_{i_region}" + if self._volume.overwrite and data_path in self.__file_handler: + del self.__file_handler[data_path] + stitching_regions_hdf5_dataset.append( + self.__file_handler.create_dataset( + name=data_path, + shape=overlap_shape, + dtype=self._dtype, + ) + ) + else: + stitching_regions_hdf5_dataset = None + return stitching_regions_hdf5_dataset + + def __enter__(self): + self._data = self._create_stitched_volume_dataset() + stitching_regions_hdf5_dataset = self._create_stitched_sub_region_datasets() + return self._data, stitching_regions_hdf5_dataset + def __exit__(self, *exc): + if isinstance(self._data, h5py.VirtualLayout): + self.__file_handler.create_virtual_dataset(self._volume.data_url.data_path(), self._data) if self.__file_handler is not None: return self.__file_handler.close() else: