fullfield.py 34.5 KB
Newer Older
Pierre Paleo's avatar
Pierre Paleo committed
1
from os import path
Pierre Paleo's avatar
Pierre Paleo committed
2
from time import time
Pierre Paleo's avatar
Pierre Paleo committed
3
import numpy as np
Pierre Paleo's avatar
Pierre Paleo committed
4
from silx.io.url import DataUrl
5
from .utils import use_options, pipeline_step, WriterConfigurator
6
7
from ..utils import ArrayPlaceHolder
from ..resources.logger import LoggerOrPrint
Pierre Paleo's avatar
Pierre Paleo committed
8
from ..io.reader import ChunkReader, HDF5Loader, get_hdf5_dataset_shape
9
from ..preproc.ccd import FlatFieldDataUrls, Log, CCDCorrection
10
from ..preproc.distortion import DistortionCorrection
11
from ..preproc.shift import VerticalShift
Pierre Paleo's avatar
Pierre Paleo committed
12
from ..preproc.double_flatfield import DoubleFlatField
Pierre Paleo's avatar
Pierre Paleo committed
13
from ..preproc.phase import PaganinPhaseRetrieval
14
from ..preproc.sinogram import SinoProcessing, SinoNormalization
15
from ..misc.rotation import Rotation
16
from ..preproc.rings import MunchDeringer
17
from ..misc.unsharp import UnsharpMask
18
from ..misc.histogram import PartialHistogram, hist_as_2Darray
19
from ..resources.utils import is_hdf5_extension, extract_parameters
20
21
22
23
24
# For now we don't have a plain python/numpy backend for reconstruction.
# The OpenCL backend is optional: if importing it fails for any ordinary reason,
# Backprojector is set to None, and FullFieldPipeline._init_reconstruction raises
# "No usable FBP module was found" when reconstruction is requested.
# `except Exception` (rather than a bare `except`) lets KeyboardInterrupt and
# SystemExit propagate.
try:
    from ..reconstruction.fbp_opencl import Backprojector
except Exception:
    Backprojector = None
25

Pierre Paleo's avatar
Pierre Paleo committed
26

Pierre Paleo's avatar
Pierre Paleo committed
27
28
29
30
31
32
33
class FullFieldPipeline:
    """
    Pipeline for "regular" full-field tomography.
    Data is processed by chunks. A chunk consists in K contiguous lines of all the radios.
    In parallel geometry, a chunk of K radios lines gives K sinograms,
    and equivalently K reconstructed slices.
    """
34

35
    #FlatFieldClass = FlatField
Pierre Paleo's avatar
Pierre Paleo committed
36
    DoubleFlatFieldClass = DoubleFlatField
37
    CCDCorrectionClass = CCDCorrection
38
    PaganinPhaseRetrievalClass = PaganinPhaseRetrieval
Pierre Paleo's avatar
Pierre Paleo committed
39
    UnsharpMaskClass = UnsharpMask
40
    ImageRotationClass = Rotation
41
    VerticalShiftClass = VerticalShift
42
    SinoProcessingClass = SinoProcessing
43
    SinoDeringerClass = MunchDeringer
44
    MLogClass = Log
45
    SinoNormalizationClass = SinoNormalization
46
    FBPClass = Backprojector
47
    HistogramClass = PartialHistogram
48

49
    def __init__(self, process_config, sub_region, logger=None, extra_options=None, phase_margin=None):
        """
        Initialize a Full-Field pipeline.

        Parameters
        ----------
        process_config: `nabu.resources.processconfig.ProcessConfig`
            Process configuration. Its `dataset_infos`, `processing_steps` and
            `processing_options` are used by the pipeline.
        sub_region: tuple
            Sub-region to process in the volume for this worker, in the format
            `(start_x, end_x, start_z, end_z)` or `(start_z, end_z)`.
            start_z and end_z cannot be None.
        logger: Logger, optional
            Logger object. It is wrapped with `LoggerOrPrint`.
        extra_options: dict, optional
            Advanced extra options.
        phase_margin: tuple, optional
            Margin to use when performing phase retrieval, in the form ((up, down), (left, right)).
            See also the documentation of PaganinPhaseRetrieval.
            If not provided, no margin is applied.

        Notes
        -----
        The margin is only used when phase retrieval or unsharp mask is part of
        the processing steps. In that case, there are fewer reconstructed slices:
        if `phase_margin = ((up, down), (left, right))`, then
        `delta_z - (up + down)` slices are reconstructed (see `n_recs`).
        The horizontal margin is currently not implemented.
        """
        self.logger = LoggerOrPrint(logger)
        self._set_params(process_config, sub_region, extra_options, phase_margin)
        self.set_subregion(sub_region)
        self._init_pipeline()

Pierre Paleo's avatar
Pierre Paleo committed
83

84
    def _set_params(self, process_config, sub_region, extra_options, phase_margin):
        """
        Store the configuration and initialize the internal state of the pipeline:
        sub-region, chunk size, phase margin, extra options, callbacks and dumps.
        """
        config = process_config
        self.process_config = config
        self.dataset_infos = config.dataset_infos
        # Work on a copy of the steps list, leaving the one of process_config untouched
        self.processing_steps = config.processing_steps.copy()
        self.processing_options = config.processing_options

        self.sub_region = self._check_subregion(sub_region)
        z_start, z_end = sub_region[-2], sub_region[-1]
        self.delta_z = z_end - z_start
        # One chunk = all the rows of the sub-region
        self.chunk_size = self.delta_z

        self._set_phase_margin(phase_margin)
        self._set_extra_options(extra_options)

        # step name -> list of callbacks (see register_callback)
        self._callbacks = {}
        self._steps_name2component = {}
        self._steps_component2name = {}
        # step name -> writer configuration (see _configure_dump)
        self._data_dump = {}
        # Name of the dumped step to resume from, if any (set in _init_reader)
        self._resume_from_step = None
Pierre Paleo's avatar
Pierre Paleo committed
99

Pierre Paleo's avatar
Pierre Paleo committed
100
101
102
103
104
105
106
107
108
109
110

    @staticmethod
    def _check_subregion(sub_region):
        """
        Validate and normalize a sub-region specification.

        Parameters
        ----------
        sub_region: tuple or list
            Either `(start_z, end_z)` or `(start_x, end_x, start_z, end_z)`.

        Returns
        -------
        sub_region: tuple
            The sub-region as a tuple whose last two items are (start_z, end_z).
            A 2-element input is prefixed with (None, None).

        Raises
        ------
        ValueError
            If a sub-region with fewer than 4 items does not have exactly 2 items,
            or if start_z or end_z is None.
        """
        # Accept any sequence (e.g. a list): tuple concatenation below needs a tuple
        sub_region = tuple(sub_region)
        if len(sub_region) < 4:
            # Explicit check rather than `assert`, which is stripped under `python -O`
            if len(sub_region) != 2:
                raise ValueError(
                    "Expected sub_region as (start_z, end_z) or (start_x, end_x, start_z, end_z), got %s"
                    % (sub_region,)
                )
            sub_region = (None, None) + sub_region
        if None in sub_region[-2:]:
            raise ValueError("Cannot set z_min or z_max to None")
        return sub_region


Pierre Paleo's avatar
Pierre Paleo committed
111
112
113
114
115
116
117
118
    def _set_extra_options(self, extra_options):
        """
        Store a private copy of the advanced extra options (an empty dict if None).
        """
        options = {}
        if extra_options is not None:
            options.update(extra_options)
        self.extra_options = options


119
120
121
122
123
124
125
126
127
    def _set_phase_margin(self, phase_margin):
        """
        Store the requested margin, given as ((up, down), (left, right)).
        None means no margin on any side.
        """
        margin = ((0, 0), (0, 0)) if phase_margin is None else phase_margin
        vertical = margin[0]
        self._phase_margin_up, self._phase_margin_down = vertical[0], vertical[1]
        horizontal = margin[1]
        self._phase_margin_left, self._phase_margin_right = horizontal[0], horizontal[1]


Pierre Paleo's avatar
Pierre Paleo committed
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
    def set_subregion(self, sub_region):
        """
        Set the sub-region to process.

        The new sub-region must span the same number of rows (delta_z) as the
        one the pipeline was initialized with.

        Parameters
        ----------
        sub_region: tuple
            Sub-region to process in the volume, in the format
            `(start_x, end_x, start_z, end_z)` or `(start_z, end_z)`.

        Raises
        ------
        ValueError
            If the number of rows of `sub_region` differs from `delta_z`.
        """
        checked = self._check_subregion(sub_region)
        z_start, z_end = checked[-2], checked[-1]
        requested_dz = z_end - z_start
        if requested_dz != self.delta_z:
            raise ValueError(
                "Class was initialized for delta_z = %d but provided sub_region has delta_z = %d"
                % (self.delta_z, requested_dz)
            )
        self.sub_region = checked
        self.z_min = z_start
        self.z_max = z_end
Pierre Paleo's avatar
Pierre Paleo committed
148

149

Pierre Paleo's avatar
Pierre Paleo committed
150
    def _compute_phase_kernel_margin(self):
        """
        Compute the margin to pass to classes like PaganinPhaseRetrieval.

        Filter-based methods are more accurate when extra data is loaded around
        the edges of each image; otherwise a default padding is applied.
        The result is stored in `self._phase_margin` as ((up, down), (left, right)),
        or None when neither phase retrieval nor unsharp mask is enabled.
        """
        if self.use_radio_processing_margin:
            # Horizontal margin is not implemented: only the vertical one is kept
            self._phase_margin = (
                (self._phase_margin_up, self._phase_margin_down),
                (0, 0),
            )
        else:
            self._phase_margin = None
Pierre Paleo's avatar
Pierre Paleo committed
165
166


167
168
169
170
171
    @property
    def use_radio_processing_margin(self):
        """
        True when phase retrieval or unsharp mask is among the processing steps,
        i.e. when a margin around the radios is used.
        """
        steps = self.processing_steps
        return any(name in steps for name in ("phase", "unsharp_mask"))


Pierre Paleo's avatar
Pierre Paleo committed
172
    def _get_phase_margin(self):
        """
        Return the effective margin ((up, down), (left, right)): the computed
        margin when a margin is used, and zeros otherwise.
        """
        if self.use_radio_processing_margin:
            return self._phase_margin
        return ((0, 0), (0, 0))


Pierre Paleo's avatar
Pierre Paleo committed
178
179
180
181
182
183
184
185
    def _get_cropped_radios(self):
        """
        Return the "inner part" of `self.radios`, i.e. the radios with the
        margin removed. The result is also stored in `self._radios_cropped`.

        The effective margin is used, so this also works when no margin is
        applied (`self._phase_margin` is None in that case): nothing is cropped.
        """
        (up, down), (left, right) = self._get_phase_margin()
        # `m or None` turns a zero margin into an open-ended slice bound
        rows = slice(up or None, -down or None)
        cols = slice(left or None, -right or None)
        self._radios_cropped = self.radios[:, rows, cols]
        return self._radios_cropped


Pierre Paleo's avatar
Pierre Paleo committed
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
    @property
    def phase_margin(self):
        """
        Margin for phase retrieval, in the form ((up, down), (left, right)).
        All zeros when no margin is used.
        """
        effective_margin = self._get_phase_margin()
        return effective_margin


    @property
    def n_recs(self):
        """
        Final number of reconstructed slices: delta_z minus the vertical
        (up + down) margin, which is zero when no margin is used.
        """
        vertical_margin = self._get_phase_margin()[0]
        return self.delta_z - sum(vertical_margin)
202
203


204
205
206
207
208
209
210
    def _get_process_name(self, kind="reconstruction"):
        """
        Return the process name used for a given kind of output.

        For now the name is the kind itself ("reconstruction", "histogram", ...).
        In the future, it might be something like "reconstruction-<ID>".
        """
        return kind
211
212


213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
    def _configure_dump(self, step_name):
        """
        Configure a writer to dump the output of a processing step, if requested.

        - For a regular processing step, the dump is enabled by its "save" option,
          and the target file is given by its "save_steps_file" option.
        - "sinogram" is special: it is not a processing step. It is dumped when
          `process_config._dump_sinogram` is set, to `process_config._dump_sinogram_file`.

        A target file "<dir>/<name>.<ext>" gives the output directory "<dir>/<name>"
        and the file prefix "<name>_<image start index>". The writer configuration
        is stored in `self._data_dump[step_name]`. Nothing is done when no dump is requested.
        """
        if step_name in self.processing_steps:
            step_options = self.processing_options[step_name]
            if not step_options.get("save", False):
                return
            fname_full = step_options["save_steps_file"]
        elif step_name == "sinogram" and self.process_config._dump_sinogram:
            fname_full = self.process_config._dump_sinogram_file
        else:
            return

        base_name = path.splitext(fname_full)[0]
        dirname, file_prefix = path.split(base_name)
        output_dir = path.join(dirname, file_prefix)
        file_prefix = "%s_%04d" % (file_prefix, self._get_image_start_index())

        self._data_dump[step_name] = WriterConfigurator(
            output_dir, file_prefix,
            file_format="hdf5",
            overwrite=True,
            logger=self.logger,
            nx_info={
                "process_name": step_name,
                "processing_index": 0,  # TODO
                "config": {
                    "processing_options": self.processing_options,
                    "nabu_config": self.process_config.nabu_config,
                },
                "entry": getattr(self.dataset_infos.dataset_scanner, "entry", None),
            },
        )


246
    def _configure_data_dumps(self):
        """
        Configure the data dumps of all processing steps, plus the special
        "sinogram" dump when `process_config._dump_sinogram` is set.
        """
        steps_to_dump = list(self.processing_steps)
        # "sinogram" is a special keyword: not in processing_steps, but guaranteed
        # to be before sinogram generation
        if self.process_config._dump_sinogram:
            steps_to_dump.append("sinogram")
        for step_name in steps_to_dump:
            self._configure_dump(step_name)
252

Pierre Paleo's avatar
Pierre Paleo committed
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
    #
    # Callbacks
    #

    def register_callback(self, step_name, callback):
        """
        Register a callback for a pipeline processing step.

        Parameters
        ----------
        step_name: str
            Name of the processing step.
        callback: callable
            Function executed once the processing step `step_name` is finished.
            It takes only one argument: the class instance.
            Several callbacks can be registered for the same step; they are
            stored in registration order.

        Raises
        ------
        ValueError
            If `step_name` is not one of the processing steps.
        """
        if step_name not in self.processing_steps:
            raise ValueError(
                "'%s' is not in processing steps %s"
                % (step_name, self.processing_steps)
            )
        self._callbacks.setdefault(step_name, []).append(callback)

Pierre Paleo's avatar
Pierre Paleo committed
279

280
281
282
283
284
285
    def _reshape_radios_after_phase(self):
        """
        Callback registered after phase retrieval or unsharp mask.

        When a vertical margin is used, `self.radios` is replaced with the cropped
        array `self._radios_cropped`, so that further processing is done on the
        "inner part". The previous array is kept in `self._orig_radios`.
        """
        vertical_margin = self._get_phase_margin()[0]
        if sum(vertical_margin) > 0:
            self._orig_radios = self.radios
            self.logger.debug(
                "Reshaping radios from %s to %s"
                % (str(self.radios.shape), str(self._radios_cropped.shape))
            )
            self.radios = self._radios_cropped
294

295
296
297
298
299

    def _get_data_from_resumed_step(self):
        """
        "read_chunk" callback used when resuming from a dumped step:
        the data loaded by the chunk reader becomes `self.radios`.
        """
        if not self._resume_from_step:
            return
        self.radios = self.chunk_reader.data

Pierre Paleo's avatar
Pierre Paleo committed
300
    #
301
    # Overwritten in inheriting classes
Pierre Paleo's avatar
Pierre Paleo committed
302
    #
303
304
305
306
307
308
309
310

    def _get_shape(self, step_name):
        """
        Return the data shape to give to the class implementing `step_name`:

        - flat-field steps: the whole radios stack,
        - per-image steps: one radio (radios shape without its first axis),
        - steps after phase retrieval: the cropped radios stack (margin removed),
        - sinogram-related steps: derived from the sinogram builder output shape.

        Raises
        ------
        ValueError
            If `step_name` is not a known processing step.
        """
        if step_name in ("flatfield", "double_flatfield"):
            shape = self.radios.shape
        elif step_name in ("rotate_projections", "phase", "ccd_correction", "unsharp_mask"):
            shape = self.radios.shape[1:]
        elif step_name in ("take_log", "radios_movements", "sino_normalization", "build_sino"):
            shape = self._radios_cropped_shape
        elif step_name == "sino_rings_correction":
            shape = self.sino_builder.output_shape
        elif step_name == "reconstruction":
            shape = self.sino_builder.output_shape[1:]
        else:
            raise ValueError("Unknown processing step %s" % step_name)
        self.logger.debug("Data shape for %s is %s" % (step_name, str(shape)))
        return shape

Pierre Paleo's avatar
Pierre Paleo committed
337
    def _get_phase_output_shape(self):
        """
        Set `self._radios_cropped_shape`: the radios shape once the margin is
        removed. It is the radios shape itself when no margin is used.
        """
        if self.use_radio_processing_margin:
            ((up, down), (left, right)) = self._phase_margin
            radios_shape = self.radios.shape
            self._radios_cropped_shape = (
                radios_shape[0],
                radios_shape[1] - (up + down),
                radios_shape[2] - (left + right),
            )
        else:
            self._radios_cropped_shape = self.radios.shape

Pierre Paleo's avatar
Pierre Paleo committed
348
    def _allocate_array(self, shape, dtype, name=None):
        """
        Return a new zero-filled numpy array of the given shape and dtype.
        `name` is unused in this implementation.
        """
        array = np.zeros(shape, dtype=dtype)
        return array

351
352
353
354
    def _allocate_sinobuilder_output(self):
        """Allocate a float32 ("f") array with the sinogram builder output shape."""
        sinos_shape = self.sino_builder.output_shape
        return self._allocate_array(sinos_shape, "f", name="sinos")

    def _allocate_recs(self, ny, nx):
        """
        Allocate `self.recs`, a float32 volume of shape (n_slices, ny, nx).
        `self.n_slices` is the number of radios rows, minus the vertical
        margin when a margin is used.
        """
        n_slices = self.radios.shape[1]  # TODO modify with vertical shifts
        if self.use_radio_processing_margin:
            n_slices -= sum(self.phase_margin[0])
        self.n_slices = n_slices
        self.recs = self._allocate_array((self.n_slices, ny, nx), "f", name="recs")
Pierre Paleo's avatar
Pierre Paleo committed
359

Pierre Paleo's avatar
Pierre Paleo committed
360
361
362
    def _reset_memory(self):
        """
        Memory reset hook. Nothing to do in this class.
        """

363

364
365
366
367
368
    def _get_read_dump_subregion(self):
        """
        Return the sub-region to load from a dumped intermediate file, or None
        when the pipeline does not resume from such a file.

        The dumped data is laid out as (n_angles, n_z, n_x). The returned tuple
        gives (start, end) bounds along each axis. Only z is restricted: `delta_z`
        rows starting at `z_min`, expressed relative to "dump_start_z".
        """
        read_opts = self.processing_options["read_chunk"]
        if read_opts.get("process_file", None) is None:
            return None
        # "dump_end_z" is looked up but not used for now
        dump_start_z, _dump_end_z = read_opts["dump_start_z"], read_opts["dump_end_z"]
        z_start = self.z_min - dump_start_z
        z_end = z_start + self.delta_z
        return (None, None, z_start, z_end, None, None)


376
    def _configure_resume_from_step(self):
        """
        Set up `self.radios` depending on where the pipeline starts.

        - From raw data: `self.radios` is the chunk reader's `files_data`.
        - From a dumped step: `self.radios` is a placeholder with the shape of
          the dumped data. The actual data is taken from the reader by a
          "read_chunk" callback (`_get_data_from_resumed_step`).
        """
        if self._resume_from_step is not None:
            read_opts = self.processing_options["read_chunk"]
            dumped_shape = get_hdf5_dataset_shape(
                read_opts["process_file"],
                read_opts["process_h5_path"],
                sub_region=self._get_read_dump_subregion(),
            )
            self.radios = ArrayPlaceHolder(dumped_shape, "f", name="radios")
            self.register_callback("read_chunk", FullFieldPipeline._get_data_from_resumed_step)
            return
        self.radios = self.chunk_reader.files_data
388
389
390
391
392
393
394


    def _init_reader_finalize(self):
        """
        Called at the end of `_init_reader`: set up `self.radios`, then compute
        the margin and the shape of the cropped radios.
        """
        # self.radios must exist before the cropped shape can be computed
        self._configure_resume_from_step()
        self._compute_phase_kernel_margin()
        self._get_phase_output_shape()

Pierre Paleo's avatar
Pierre Paleo committed
398
399
400
401
402

    def _process_finalize(self):
        """
        Hook called once the pipeline has been executed.
        Nothing to do in this class.
        """

405
406
407
408

    def _get_slice_start_index(self):
        """
        Return the index, in the whole volume, of the first slice (or image)
        produced for the current sub-region.

        The top margin is skipped only when a margin is actually used (phase
        retrieval or unsharp mask enabled). This keeps the index consistent
        with `n_recs`, which ignores the margin otherwise.
        """
        up_margin = self._phase_margin_up if self.use_radio_processing_margin else 0
        return self.z_min + up_margin

    _get_image_start_index = _get_slice_start_index

411
412
413
414
415
416
417
418
    #
    # Pipeline initialization
    #

    def _init_pipeline(self):
        """
        Initialize all the pipeline components in a fixed order,
        then configure the data dumps.
        """
        init_steps = (
            self._init_reader,
            self._init_flatfield,
            self._init_double_flatfield,
            self._init_ccd_corrections,
            self._init_radios_rotation,
            self._init_phase,
            self._init_unsharp,
            self._init_radios_movements,
            self._init_mlog,
            self._init_sino_normalization,
            self._init_sino_builder,
            self._init_sino_rings_correction,
            self._prepare_reconstruction,
            self._init_reconstruction,
            self._init_histogram,
            self._init_writer,
            self._configure_data_dumps,
        )
        for init_step in init_steps:
            init_step()
433

434
    @use_options("read_chunk", "chunk_reader")
    def _init_reader(self):
        """
        Create `self.chunk_reader`, then finalize the reader setup.

        The reader is either a ChunkReader on the raw data, or a HDF5Loader
        when the pipeline resumes from an intermediate step dumped to a file
        (the "process_file" option). In the latter case, `self._resume_from_step`
        is set to the dumped step name.

        Raises
        ------
        ValueError
            If "read_chunk" is not in the processing steps.
        """
        if "read_chunk" not in self.processing_steps:
            raise ValueError("Cannot proceed without reading data")
        options = self.processing_options["read_chunk"]
        process_file = options.get("process_file", None)
        if process_file is not None:
            # Resume pipeline from dumped intermediate step
            self.chunk_reader = HDF5Loader(
                process_file,
                options["process_h5_path"],
                sub_region=self._get_read_dump_subregion()
            )
            self._resume_from_step = options["step_name"]
            self.logger.debug(
                "Load subregion %s from file %s" % (str(self.chunk_reader.sub_region), self.chunk_reader.fname)
            )
        else:
            # Standard case - start pipeline from raw data.
            # ChunkReader always takes a non-subsampled dictionary "files"
            self.chunk_reader = ChunkReader(
                options["files"],
                sub_region=self.sub_region,
                convert_float=True,
                binning=options["binning"],
                dataset_subsampling=options["dataset_subsampling"]
            )
        self._init_reader_finalize()
Pierre Paleo's avatar
Pierre Paleo committed
462

463
464
    # Will be removed once FlatField == FlatFieldArrays
    def _get_flatfield_cls(self):
        """Return the flat-field class instantiated by `_init_flatfield`."""
        flatfield_cls = FlatFieldDataUrls
        return flatfield_cls
466

467
    @use_options("flatfield", "flatfield")
    def _init_flatfield(self, shape=None):
        """
        Create `self.flatfield` from the dataset flats and darks.

        Parameters
        ----------
        shape: tuple, optional
            Shape of the data to process. Defaults to `_get_shape("flatfield")`.

        When the "do_flat_distortion" option is set, a DistortionCorrection
        ("fft-correlation" estimation, "interpn" correction) is passed along.
        """
        if shape is None:
            shape = self._get_shape("flatfield")
        options = self.processing_options["flatfield"]

        distortion_correction = None
        if options["do_flat_distortion"]:
            self.logger.info("Flats distortion correction will be applied")
            estimation_kwargs = dict(options["flat_distortion_params"])
            estimation_kwargs["logger"] = self.logger
            distortion_correction = DistortionCorrection(
                estimation_method="fft-correlation",
                estimation_kwargs=estimation_kwargs,
                correction_method="interpn"
            )

        # FlatField parameter "radios_indices" must account for subsampling
        flatfield_cls = self._get_flatfield_cls()
        self.flatfield = flatfield_cls(
            shape,
            flats=self.dataset_infos.flats,
            darks=self.dataset_infos.darks,
            radios_indices=options["projs_indices"],
            interpolation="linear",
            distortion_correction=distortion_correction,
            sub_region=self.sub_region,
            binning=options["binning"],
            convert_float=True
        )

499
    @use_options("double_flatfield", "double_flatfield")
    def _init_double_flatfield(self):
        """
        Create `self.double_flatfield`.

        When the "processes_file" option is set, a `result_url` pointing to
        "<entry>/double_flatfield/results/data" in that file is passed to the class.
        `average_is_on_log` is enabled exactly when a "sigma" filter is requested.
        """
        options = self.processing_options["double_flatfield"]
        sigma = options["sigma"]
        avg_is_on_log = sigma is not None

        result_url = None
        if options["processes_file"] not in (None, ""):
            h5_entry = self.dataset_infos.hdf5_entry or "entry"
            result_url = DataUrl(
                file_path=options["processes_file"],
                data_path=h5_entry + "/double_flatfield/results/data",
            )
            self.logger.info("Loading double flatfield from %s" % result_url.file_path())

        self.double_flatfield = self.DoubleFlatFieldClass(
            self._get_shape("double_flatfield"),
            result_url=result_url,
            sub_region=self.sub_region,
            input_is_mlog=False,
            output_is_mlog=False,
            average_is_on_log=avg_is_on_log,
            sigma_filter=sigma
        )

520
521
522
    @use_options("ccd_correction", "ccd_correction")
    def _init_ccd_corrections(self):
        """Create `self.ccd_correction`, configured with the "median_clip_thresh" option."""
        options = self.processing_options["ccd_correction"]
        radio_shape = self._get_shape("ccd_correction")
        self.ccd_correction = self.CCDCorrectionClass(
            radio_shape, median_clip_thresh=options["median_clip_thresh"]
        )

528
    @use_options("phase", "phase_retrieval")
    def _init_phase(self):
        """
        Create `self.phase_retrieval` (Paganin phase retrieval).

        If unsharp mask follows phase retrieval, cropping to the "inner part"
        must happen after unsharp mask. In that case, no margin is given to the
        phase retrieval and the cropping callback is registered by `_init_unsharp`.
        Otherwise, the data is cropped just after phase retrieval.
        """
        options = self.processing_options["phase"]
        unsharp_follows = "unsharp_mask" in self.processing_steps
        margin = None if unsharp_follows else self._phase_margin

        self.phase_retrieval = self.PaganinPhaseRetrievalClass(
            self._get_shape("phase"),
            distance=options["distance_m"],
            energy=options["energy_kev"],
            delta_beta=options["delta_beta"],
            pixel_size=options["pixel_size_m"],
            padding=options["padding_type"],
            margin=margin,
            fftw_num_threads=True,  # TODO tune in advanced params of nabu config file
        )
        if self.phase_retrieval.use_fftw:
            self.logger.debug(
                "PaganinPhaseRetrieval using FFTW with %d threads"
                % self.phase_retrieval.fftw.num_threads
            )
        if not unsharp_follows:
            self.register_callback("phase", FullFieldPipeline._reshape_radios_after_phase)
555

556
    @use_options("unsharp_mask", "unsharp_mask")
    def _init_unsharp(self):
        """
        Create `self.unsharp_mask` ("gaussian" method, "reflect" mode) from the
        "unsharp_sigma" and "unsharp_coeff" options. The radios are cropped to
        their "inner part" once unsharp mask is done.
        """
        options = self.processing_options["unsharp_mask"]
        radio_shape = self._get_shape("unsharp_mask")
        sigma, coeff = options["unsharp_sigma"], options["unsharp_coeff"]
        self.unsharp_mask = self.UnsharpMaskClass(
            radio_shape, sigma, coeff,
            mode="reflect", method="gaussian"
        )
        self.register_callback("unsharp_mask", FullFieldPipeline._reshape_radios_after_phase)
Pierre Paleo's avatar
Pierre Paleo committed
565

Pierre Paleo's avatar
Pierre Paleo committed
566
    @use_options("take_log", "mlog")
    def _init_mlog(self):
        """
        Create `self.mlog` (logarithm step) for the cropped radios shape,
        clipped with the "log_min_clip" / "log_max_clip" options.
        """
        options = self.processing_options["take_log"]
        cropped_shape = self._get_shape("take_log")
        self.mlog = self.MLogClass(
            cropped_shape,
            clip_min=options["log_min_clip"],
            clip_max=options["log_max_clip"]
        )
Pierre Paleo's avatar
Pierre Paleo committed
574

575
576
577
    @use_options("rotate_projections", "projs_rot")
    def _init_radios_rotation(self):
        """
        Create `self.projs_rot`, rotating each projection by the "angle" option,
        and a temporary float32 buffer `self._tmp_rotated_radio`.

        The rotation center defaults to the center of the full radio. Its vertical
        coordinate is made relative to the first row (`z_min`) of the sub-region.
        """
        options = self.processing_options["rotate_projections"]
        center = options["center"]
        if center is None:
            nx, ny = self.dataset_infos.radio_dims
            center = (nx/2 - 0.5, ny/2 - 0.5)
        center_x, center_z = center[0], center[1] - self.z_min

        self.projs_rot = self.ImageRotationClass(
            self._get_shape("rotate_projections"),
            options["angle"],
            center=(center_x, center_z),
            mode="edge",
            reshape=False
        )
        self._tmp_rotated_radio = self._allocate_array(
            self._get_shape("rotate_projections"), "f", name="tmp_rotated_radio"
        )
593

594
595
596
597
598
    @use_options("radios_movements", "radios_movements")
    def _init_radios_movements(self):
        """
        Create `self.radios_movements` from the vertical component (column 1)
        of the "translation_movements" option, kept in `self._vertical_shifts`.
        """
        translations = self.processing_options["radios_movements"]["translation_movements"]
        self._vertical_shifts = translations[:, 1]
        self.radios_movements = self.VerticalShiftClass(
            self._get_shape("radios_movements"),
            self._vertical_shifts
        )

603
604
605
606
607
608
609
610
    @use_options("sino_normalization", "sino_normalization")
    def _init_sino_normalization(self):
        """Create `self.sino_normalization` with the method given by the "method" option."""
        normalization_kind = self.processing_options["sino_normalization"]["method"]
        self.sino_normalization = self.SinoNormalizationClass(
            kind=normalization_kind,
            radios_shape=self._get_shape("sino_normalization"),
        )

611
    @use_options("build_sino", "sino_builder")
    def _init_sino_builder(self):
        """
        Create `self.sino_builder`.

        In half-tomography mode, an output buffer is allocated (`self.sinos`, also
        referenced by `self._sinobuilder_output`) and `_sinobuilder_copy` is True.
        Otherwise, no buffer is allocated and these attributes are None / False.
        """
        options = self.processing_options["build_sino"]
        halftomo = options["enable_halftomo"]
        self.sino_builder = self.SinoProcessingClass(
            radios_shape=self._get_shape("build_sino"),
            rot_center=options["rotation_axis_position"],
            halftomo=halftomo,
        )
        if halftomo:
            self._sinobuilder_copy = True
            self.sinos = self._allocate_sinobuilder_output()
            self._sinobuilder_output = self.sinos
        else:
            self._sinobuilder_copy = False
            self._sinobuilder_output = None
            self.sinos = None

628
629
630
631
632
633
634
635
636
637
638
    @use_options("sino_rings_correction", "sino_deringer")
    def _init_sino_rings_correction(self):
        """
        Create `self.sino_deringer`. The "user_options" are parsed with
        `extract_parameters`: "sigma" (default 1.) is passed positionally,
        the remaining parameters as keyword arguments.
        """
        options = self.processing_options["sino_rings_correction"]
        params = extract_parameters(options["user_options"])
        sigma = params.pop("sigma", 1.)
        self.sino_deringer = self.SinoDeringerClass(
            sigma,
            sinos_shape=self._get_shape("sino_rings_correction"),
            **params
        )

639
    # this should be renamed, as it could be confused with _init_reconstruction. What about _get_reconstruction_array ?
640
    @use_options("reconstruction", "reconstruction")
    def _prepare_reconstruction(self):
        """
        Set the reconstruction region of interest and allocate the output volume.

        The start/end x/y options are inclusive bounds. `self._rec_roi` stores
        them as (x_start, x_end + 1, y_start, y_end + 1).
        """
        options = self.processing_options["reconstruction"]
        x_start, x_stop = options["start_x"], options["end_x"] + 1
        y_start, y_stop = options["start_y"], options["end_y"] + 1
        self._rec_roi = (x_start, x_stop, y_start, y_stop)
        self._allocate_recs(y_stop - y_start, x_stop - x_start)
Pierre Paleo's avatar
Pierre Paleo committed
647

648

649
    @use_options("reconstruction", "reconstruction")
    def _init_reconstruction(self):
        """
        Create `self.reconstruction` (filtered backprojection).

        When "fbp_filter_type" is None, `fbp` is replaced with the plain
        backprojection (`backproj`).

        Raises
        ------
        ValueError
            If `self.sino_builder` is None, or if no FBP class is available.
        """
        options = self.processing_options["reconstruction"]
        # TODO account for reconstruction from already formed sinograms
        if self.sino_builder is None:
            raise ValueError("Reconstruction cannot be done without build_sino")
        if self.FBPClass is None:
            raise ValueError("No usable FBP module was found")

        if options["enable_halftomo"]:
            cor_key = "rotation_axis_position_halftomo"
        else:
            cor_key = "rotation_axis_position"
        rot_center = options[cor_key]
        if options.get("cor_estimated_auto", False):
            self.logger.info("Estimated center of rotation: %.2f" % rot_center)
        # When the sinogram builder reports `_halftomo_flip`, use its own rotation center
        if self.sino_builder._halftomo_flip:
            rot_center = self.sino_builder.rot_center

        self.reconstruction = self.FBPClass(
            self._get_shape("reconstruction"),
            angles=options["angles"],
            rot_center=rot_center,
            filter_name=options["fbp_filter_type"],
            slice_roi=self._rec_roi,
            scale_factor=1./options["pixel_size_cm"],
            extra_options={
                "padding_mode": options["padding_type"],
                "axis_correction": options["axis_correction"],
            }
        )
        if options["fbp_filter_type"] is None:
            self.reconstruction.fbp = self.reconstruction.backproj

682
683
    @use_options("histogram", "histogram")
    def _init_histogram(self):
        """
        Create the histogram computer (self.histogram).

        It uses the "fixed_bins_number" method, with the number of bins
        given by the "histogram_bins" option.
        """
        n_bins = self.processing_options["histogram"]["histogram_bins"]
        self.histogram = self.HistogramClass(
            method="fixed_bins_number",
            num_bins=n_bins,
        )
688

689
    @use_options("save", "writer")
    def _init_writer(self):
        """
        Configure the output writers from the "save" options.

        Sets self.writer together with the extra positional/keyword arguments
        to pass to its write() method, and self.histogram_writer. All of them
        are obtained from a WriterConfigurator.

        - HDF5 output: the value of _get_slice_start_index() is appended to
          the file prefix ("<prefix>_NNNN"), and processing metadata
          (process name, processing index 0, configuration, entry) is passed
          to the writer as nx_info. The histogram gets the next processing index.
        - Other formats: _get_slice_start_index() is passed as the start
          index of the output file names, and no nx_info is given.
        """
        save_opts = self.processing_options["save"]
        file_prefix = save_opts["file_prefix"]
        output_dir = path.join(save_opts["location"], file_prefix)

        self._hdf5_output = is_hdf5_extension(save_opts["file_format"])
        scanner = self.dataset_infos.dataset_scanner
        if not self._hdf5_output:
            nx_info = None
            fname_start_index = self._get_slice_start_index()
            self._histogram_processing_index = 1
        else:
            fname_start_index = None
            file_prefix += "_%04d" % self._get_slice_start_index()
            entry = getattr(scanner, "entry", None)
            nx_info = {
                "process_name": self._get_process_name(),
                "processing_index": 0,
                "config": {
                    "processing_options": self.processing_options,
                    "nabu_config": self.process_config.nabu_config
                },
                "entry": entry,
            }
            self._histogram_processing_index = nx_info["processing_index"] + 1

        configurator = WriterConfigurator(
            output_dir,
            file_prefix,
            file_format=save_opts["file_format"],
            overwrite=save_opts["overwrite"],
            start_index=fname_start_index,
            logger=self.logger,
            nx_info=nx_info,
            write_histogram=("histogram" in self.processing_steps),
            histogram_entry=getattr(scanner, "entry", "entry"),
        )
        self._writer_configurator = configurator
        self.writer = configurator.writer
        self._writer_exec_args = configurator._writer_exec_args
        self._writer_exec_kwargs = configurator._writer_exec_kwargs
        self.histogram_writer = configurator.get_histogram_writer()

731

Pierre Paleo's avatar
Pierre Paleo committed
732
733
734
735
736
737
    #
    # Pipeline re-initialization
    #

    def _reset_sub_region(self, sub_region):
        """
        Switch the pipeline to another sub-region.

        The new sub-region is set first. Then the chunk reader and the
        flat-field correction are re-initialized, in that order.
        """
        self.set_subregion(sub_region)
        for reset_step in (self._reset_reader_subregion, self._reset_flatfield):
            reset_step()

    def _reset_flatfield(self):
        """
        Re-initialize the flat-field correction by calling _init_flatfield().

        Called from _reset_sub_region() when the processed sub-region changes.
        """
        self._init_flatfield()
Pierre Paleo's avatar
Pierre Paleo committed
743

744
    #
Pierre Paleo's avatar
Pierre Paleo committed
745
    # Pipeline execution
746
    #
Pierre Paleo's avatar
Pierre Paleo committed
747

748
    @pipeline_step("chunk_reader", "Reading data")
    def _read_data(self):
        """
        Load the current sub-region through the chunk reader and log the read throughput.

        The logged size is the number of elements of chunk_reader.data times
        the reader's dtype item size, expressed in GB (1e9 bytes).
        """
        self.logger.debug("Region = %s" % str(self.sub_region))
        t0 = time()
        self.chunk_reader.load_data()
        elapsed = time() - t0

        shp = self.chunk_reader.data.shape
        GB = np.prod(shp) * self.chunk_reader.dtype.itemsize / 1e9
        # time() may have a coarse resolution on some platforms, and the wall
        # clock can be adjusted, so a fast read can measure 0 (or even a
        # negative) elapsed time: avoid a ZeroDivisionError / negative rate.
        rate = GB / elapsed if elapsed > 0 else float("inf")
        self.logger.info(
            "Read subvolume %s (%.2f GB) in %.2f s: %.3f GB/s"
            % (str(shp), GB, elapsed, rate)
        )
Pierre Paleo's avatar
Pierre Paleo committed
761

762
763
764
765
766
767
768
    def _reset_reader_subregion(self):
        """
        Point the chunk reader to the region to read next, and mark its data as not loaded.

        - If the pipeline does not resume from a previous step
          (self._resume_from_step is None), the reader gets the current
          sub-region and is re-initialized (_init_reader()).
        - Otherwise, the reader gets the sub-region returned by
          _get_read_dump_subregion() and is not re-initialized.
        """
        reader = self.chunk_reader
        if self._resume_from_step is not None:
            reader._set_subregion(self._get_read_dump_subregion())
        else:
            reader._set_subregion(self.sub_region)
            reader._init_reader()
        reader._loaded = False
770
771


772
    @pipeline_step("flatfield", "Applying flat-field")
    def _flatfield(self):
        """Apply the flat-field normalization to self.radios through self.flatfield.normalize_radios()."""
        radios = self.radios
        self.flatfield.normalize_radios(radios)

776
    @pipeline_step("double_flatfield", "Applying double flat-field")
    def _double_flatfield(self, radios=None):
        """
        Apply the double flat-field correction.

        Parameters
        ----------
        radios: array, optional
            Radios to process. Defaults to self.radios.
        """
        target = self.radios if radios is None else radios
        self.double_flatfield.apply_double_flatfield(target)
Pierre Paleo's avatar
Pierre Paleo committed
781

Pierre Paleo's avatar
Pierre Paleo committed
782
    @pipeline_step("ccd_correction", "Applying CCD corrections")
    def _ccd_corrections(self, radios=None):
        """
        Apply the CCD median-clip correction to every radio, writing the result back in place.

        Parameters
        ----------
        radios: array, optional
            Stack of radios indexed along the first axis. Defaults to self.radios.
        """
        if radios is None:
            radios = self.radios
        # Scratch buffer with the shape of one radio, obtained through
        # _allocate_array() under the name "tmp_ccdcorr_radio"
        scratch = self._allocate_array(radios.shape[1:], "f", name="tmp_ccdcorr_radio")
        correct = self.ccd_correction.median_clip_correction
        for idx in range(radios.shape[0]):
            correct(radios[idx], output=scratch)
            radios[idx][:] = scratch[:]
792

Pierre Paleo's avatar
Pierre Paleo committed
793
794
795
796
    @pipeline_step("projs_rot", "Rotating projections")
    def _rotate_projections(self, radios=None):
        """
        Rotate every projection, writing the result back in place.

        Each radio is rotated into the pre-allocated self._tmp_rotated_radio
        buffer, then copied back.

        Parameters
        ----------
        radios: array, optional
            Stack of radios indexed along the first axis. Defaults to self.radios.
        """
        if radios is None:
            radios = self.radios
        buffer = self._tmp_rotated_radio
        rotate = self.projs_rot.rotate
        for idx in range(radios.shape[0]):
            rotate(radios[idx], output=buffer)
            radios[idx][:] = buffer[:]
Pierre Paleo's avatar
Pierre Paleo committed
801

802
    @pipeline_step("phase_retrieval", "Performing phase retrieval")
    def _retrieve_phase(self):
        """
        Apply the phase retrieval filter to every radio.

        If an unsharp mask step follows, the filter output goes back into
        self.radios, and cropping is left to _apply_unsharp(). Otherwise the
        radios are cropped first (_get_cropped_radios()), and the output goes
        to self._radios_cropped.
        """
        if "unsharp_mask" not in self.processing_steps:
            self._get_cropped_radios()
            destination = self._radios_cropped
        else:
            destination = self.radios
        apply_filter = self.phase_retrieval.apply_filter
        for idx in range(self.radios.shape[0]):
            apply_filter(self.radios[idx], output=destination[idx])
Pierre Paleo's avatar
Pierre Paleo committed
813

814
    @pipeline_step("unsharp_mask", "Performing unsharp mask")
    def _apply_unsharp(self):
        """
        Apply the unsharp mask to every radio of self.radios, then crop the radios.

        Each radio is replaced by the array returned by unsharp_mask.unsharp().
        """
        radios = self.radios
        sharpen = self.unsharp_mask.unsharp
        for idx in range(radios.shape[0]):
            radios[idx] = sharpen(radios[idx])
        self._get_cropped_radios()
Pierre Paleo's avatar
Pierre Paleo committed
819

820
    @pipeline_step("mlog", "Taking logarithm")
    def _take_log(self):
        """Take the logarithm of self.radios through self.mlog.take_logarithm()."""
        radios = self.radios
        self.mlog.take_logarithm(radios)
Pierre Paleo's avatar
Pierre Paleo committed
823

824
    @pipeline_step("radios_movements", "Applying radios movements")
    def _radios_movements(self, radios=None):
        """
        Apply the vertical shifts to all the radios.

        Parameters
        ----------
        radios: array, optional
            Stack of radios indexed along the first axis. Defaults to self.radios.
            The indices 0 .. N-1 of all radios are passed along.
        """
        target = self.radios if radios is None else radios
        all_indices = list(range(target.shape[0]))
        self.radios_movements.apply_vertical_shifts(target, all_indices)

832
833
834
835
836
837
838
    @pipeline_step("sino_normalization", "Normalizing sinograms")
    def _normalize_sinos(self, radios=None):
        """
        Normalize the sinograms extracted from the radios.

        The sinograms are obtained by swapping the first two axes of the
        radios stack. For numpy arrays, transpose() returns a view, so the
        normalization acts on the radios data itself.

        Parameters
        ----------
        radios: array, optional
            Stack of radios. Defaults to self.radios.
        """
        stack = self.radios if radios is None else radios
        self.sino_normalization.normalize(stack.transpose((1, 0, 2)))

839
840
841
842
843
    def _dump_sinogram(self, radios=None):
        """
        Dump the data under the "sinogram" step name, if a dump is configured for it.

        Parameters
        ----------
        radios: array, optional
            Data to dump. Defaults to self.radios.
        """
        payload = self.radios if radios is None else radios
        self._dump_data_to_file("sinogram", data=payload)

844
    @pipeline_step("sino_builder", "Building sinograms")
    def _build_sino(self, radios=None):
        """
        Build the sinograms from the radios and store them in self.sinos.

        Parameters
        ----------
        radios: array, optional
            Stack of radios. Defaults to self.radios.
        """
        source = self.radios if radios is None else radios
        # Depending on self._sinobuilder_copy, the result is either a new array
        # (previously allocated in "_sinobuilder_output"), or a view of "radios"
        self.sinos = self.sino_builder.radios_to_sinos(
            source,
            output=self._sinobuilder_output,
            copy=self._sinobuilder_copy,
        )
Pierre Paleo's avatar
Pierre Paleo committed
855

856
    @pipeline_step("sino_deringer", "Removing rings on sinograms")
    def _destripe_sinos(self, sinos=None):
        """
        Remove rings artefacts from the sinograms through self.sino_deringer.remove_rings().

        Parameters
        ----------
        sinos: array, optional
            Sinograms to process. Defaults to self.sinos.
        """
        target = self.sinos if sinos is None else sinos
        self.sino_deringer.remove_rings(target)

862
    @pipeline_step("reconstruction", "Reconstruction")
    def _reconstruct(self, sinos=None):
        """
        Reconstruct each sinogram into the matching slice of self.recs.

        self.reconstruction.fbp is used. When no filter is configured, this
        attribute points to plain backprojection (see _init_reconstruction).

        Parameters
        ----------
        sinos: array, optional
            Sinograms indexed along the first axis. Defaults to self.sinos.
        """
        stack = self.sinos if sinos is None else sinos
        fbp = self.reconstruction.fbp
        recs = self.recs
        for idx in range(stack.shape[0]):
            fbp(stack[idx], output=recs[idx])
Pierre Paleo's avatar
Pierre Paleo committed
870

871
    @pipeline_step("histogram", "Computing histogram")
    def _compute_histogram(self, data=None):
        """
        Compute the histogram of the flattened data and store it in self.recs_histogram.

        Parameters
        ----------
        data: array, optional
            Data to histogram. Defaults to self.recs.
        """
        values = self.recs if data is None else data
        self.recs_histogram = self.histogram.compute_histogram(values.ravel())
876

877
    @pipeline_step("writer", "Saving data")
    def _write_data(self, data=None):
        """
        Write the data with the configured writer, then write the histogram (if enabled).

        Parameters
        ----------
        data: array, optional
            Data to write. Defaults to self.recs.
        """
        to_write = self.recs if data is None else data
        writer = self.writer
        writer.write(to_write, *self._writer_exec_args, **self._writer_exec_kwargs)
        self.logger.info("Wrote %s" % writer.get_filename())
        self._write_histogram()

    def _write_histogram(self):
        """
        Write self.recs_histogram with the histogram writer.

        Does nothing unless "histogram" is among the processing steps. The
        written config records the basename of the reconstruction file and
        the number of bins.
        """
        if "histogram" in self.processing_steps:
            self.logger.info("Saving histogram")
            hist_2d = hist_as_2Darray(self.recs_histogram)
            process_name = self._get_process_name(kind="histogram")
            hist_config = {
                "file": path.basename(self.writer.get_filename()),
                "bins": self.processing_options["histogram"]["histogram_bins"],
            }
            self.histogram_writer.write(
                hist_2d,
                process_name,
                processing_index=self._histogram_processing_index,
                config=hist_config,
            )
898

Pierre Paleo's avatar
Pierre Paleo committed
899
900
901
902
903
904
905
906
907
    def _dump_data_to_file(self, step_name, data=None):
        """
        Dump the data with the writer registered for step_name in self._data_dump.

        Does nothing when no writer is registered for that step.

        Parameters
        ----------
        step_name: str
            Key of the dump writer in self._data_dump.
        data: array, optional
            Data to dump. Defaults to self.radios.
        """
        if step_name in self._data_dump:
            payload = self.radios if data is None else data
            dump_writer = self._data_dump[step_name]
            self.logger.info("Dumping data to %s" % dump_writer.fname)
            dump_writer.write_data(payload)

Pierre Paleo's avatar
Pierre Paleo committed
908

909
    def _process_chunk(self):
        """
        Run all the processing steps on the currently loaded chunk.

        The steps are called in this fixed order. Most of them modify
        self.radios in place, and later ones work on self.sinos / self.recs,
        so the order must not be changed. Steps that are not enabled are
        presumably skipped by the pipeline_step decorator - TODO confirm in .utils.
        """
        # Radios preprocessing (in place on self.radios)
        self._flatfield()
        self._double_flatfield()
        self._ccd_corrections()
        self._rotate_projections()
        self._retrieve_phase()
        self._apply_unsharp()
        self._take_log()
        self._radios_movements()
        self._normalize_sinos()
        # Only writes something if a "sinogram" dump is configured
        self._dump_sinogram()
        # From here on, the steps work on self.sinos, then on self.recs
        self._build_sino()
        self._destripe_sinos()
        self._reconstruct()
        self._compute_histogram()
        self._write_data()
        self._process_finalize()
926
927
928
929
930
931
932


    def process_chunk(self, sub_region=None):
        if sub_region is not None:
            self._reset_sub_region(sub_region)
            self._reset_memory()
            self._init_writer()
933
            self._init_double_flatfield()
934
            self._configure_data_dumps()
935
936
        self._read_data()
        self._process_chunk()