scanbase.py 21.6 KB
Newer Older
payno's avatar
payno committed
1
# coding: utf-8
payno's avatar
payno committed
2
# /*##########################################################################
payno's avatar
payno committed
3
# Copyright (C) 2016- 2020 European Synchrotron Radiation Facility
payno's avatar
payno committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#############################################################################*/
payno's avatar
payno committed
24
"""This module contains the base class for tomography scans (TomoScanBase)"""
payno's avatar
payno committed
25
26
27
28
29
30
31

__authors__ = ["H.Payno"]
__license__ = "MIT"
__date__ = "09/10/2019"


import os
payno's avatar
payno committed
32
import typing
payno's avatar
payno committed
33
import logging
payno's avatar
payno committed
34
import numpy
Pierre Paleo's avatar
Pierre Paleo committed
35
from typing import Union, Iterable
36
from collections import OrderedDict
payno's avatar
payno committed
37
from .unitsystem.metricsystem import MetricSystem
payno's avatar
payno committed
38
from silx.utils.enum import Enum as _Enum
payno's avatar
payno committed
39
40
from silx.io.url import DataUrl
from silx.io.utils import get_data
41
42
43
import silx.io.utils
from math import ceil
from .progress import Progress
44
from bisect import bisect_left
payno's avatar
payno committed
45

46
_logger = logging.getLogger(__name__)
payno's avatar
payno committed
47
48


payno's avatar
payno committed
49
50
class _FOV(_Enum):
    """Enumeration of the field of view values a scan can report.

    Member values are the plain strings ``"Full"`` and ``"Half"``.
    """

    FULL = "Full"
    HALF = "Half"
payno's avatar
payno committed
54
55


payno's avatar
payno committed
56
57
58
59
60
61
class TomoScanBase:
    """
    Base Class representing a scan.
    It is used to obtain core information regarding an acquisition like
    projections, dark and flat field...

    :param scan: path to the root folder containing the scan.
    :type scan: Union[str,None]
    :param str type_: type of the scan, returned by the `type` property.
    :param ignore_projections: stored as is in `self.ignore_projections`
                               (this base class only stores it).
    """

    DICT_TYPE_KEY = "type"

    DICT_PATH_KEY = "path"

    # scheme to read data url for this type of acquisition
    _SCHEME = None

    def __init__(
        self,
        scan: Union[None, str],
        type_: str,
        ignore_projections: Union[None, Iterable] = None,
    ):
        self.path = scan
        self._type = type_
        # normed flats. When set, a dict is expected with the acquisition
        # index as key and the reduced (e.g. mean or median) flat serie as
        # value
        self._normed_flats = None
        # normed darks. Same layout as `_normed_flats`
        self._normed_darks = None
        # Should we notify the user if ffc fails because dark or flat cannot
        # be found. Used to avoid several warnings: only display one.
        self._notify_ffc_rsc_missing = True
        self._projections = None
        self._alignment_projections = None
        # cache: for each projection index, dict of the flat indexes to use
        # for flat field correction associated with their weights
        self._flats_weights = None
        self.ignore_projections = ignore_projections

    def clear_caches(self):
        """clear caches. Might be called if some data changed after
        first read of data or metadata"""
        self._notify_ffc_rsc_missing = True
        self._alignment_projections = None
        self._flats_weights = None

    @property
    def normed_darks(self):
        """normed darks (dict acquisition index -> dark) or None"""
        return self._normed_darks

    def set_normed_darks(self, darks):
        """set the normed darks (dict acquisition index -> dark) or None"""
        self._normed_darks = darks

    @property
    def normed_flats(self):
        """normed flats (dict acquisition index -> flat) or None"""
        return self._normed_flats

    def set_normed_flats(self, flats):
        """set the normed flats (dict acquisition index -> flat) or None"""
        self._normed_flats = flats

    @property
    def path(self) -> Union[None, str]:
        """

        :return: path of the scan root folder.
        :rtype: Union[str,None]
        """
        return self._path

    @path.setter
    def path(self, path: Union[str, None]) -> None:
        """
        :param path: None or path to the scan root folder. Stored as an
                     absolute path.
        :raises TypeError: if path is neither None, a str nor a path-like
        """
        if path is None:
            self._path = None
            return
        if isinstance(path, os.PathLike):
            path = os.fspath(path)
        if not isinstance(path, str):
            raise TypeError(
                "path is expected to be None or a str, not {}".format(type(path))
            )
        self._path = os.path.abspath(path)

    @property
    def type(self) -> str:
        """

        :return: type of the scanBase (can be 'edf' or 'hdf5' for now).
        :rtype: str
        """
        return self._type

    @staticmethod
    def is_tomoscan_dir(directory: str, **kwargs) -> bool:
        """
        Check if the given directory is holding an acquisition

        :param str directory:
        :return: does the given directory contains any acquisition
        :rtype: bool
        """
        raise NotImplementedError("Base class")

    def is_abort(self, **kwargs) -> bool:
        """

        :return: True if the acquisition has been aborted
        :rtype: bool
        """
        raise NotImplementedError("Base class")

    @property
    def flats(self) -> Union[None, dict]:
        """flats of the acquisition as a dict, or None.
        Presumably keyed by acquisition index (set by subclasses)."""
        return self._flats

    @flats.setter
    def flats(self, flats: Union[None, dict]) -> None:
        self._flats = flats

    @property
    def darks(self) -> Union[None, dict]:
        """darks of the acquisition as a dict, or None.
        Presumably keyed by acquisition index (set by subclasses)."""
        return self._darks

    @darks.setter
    def darks(self, darks: Union[None, dict]) -> None:
        self._darks = darks

    @property
    def projections(self) -> Union[None, dict]:
        """if found dict of projections urls with index during acquisition as
        key"""
        return self._projections

    @projections.setter
    def projections(self, projections: dict) -> None:
        self._projections = projections

    @property
    def alignment_projections(self) -> Union[None, dict]:
        """
        dict of projections made for alignment with acquisition index as key
        None if not found
        """
        return self._alignment_projections

    @alignment_projections.setter
    def alignment_projections(self, alignment_projs):
        self._alignment_projections = alignment_projs

    @property
    def dark_n(self) -> Union[None, int]:
        raise NotImplementedError("Base class")

    @property
    def tomo_n(self) -> Union[None, int]:
        """number of projection WITHOUT the return projections"""
        raise NotImplementedError("Base class")

    @property
    def ref_n(self) -> Union[None, int]:
        raise NotImplementedError("Base class")

    @property
    def pixel_size(self) -> Union[None, float]:
        """pixel size. Presumably in meter, since `get_pixel_size` uses "m"
        as its default unit."""
        raise NotImplementedError("Base class")

    def get_pixel_size(self, unit="m") -> Union[None, float]:
        """

        :param Union[MetricSystem, str] unit: unit requested for the pixel size
        :return: pixel size with the requested unit, or None if the pixel
                 size is unknown (or evaluates to False, e.g. 0)
        """
        pixel_size = self.pixel_size
        if pixel_size:
            return pixel_size / MetricSystem.from_value(unit).value
        return None

    @property
    def dim_1(self) -> Union[None, int]:
        raise NotImplementedError("Base class")

    @property
    def dim_2(self) -> Union[None, int]:
        raise NotImplementedError("Base class")

    @property
    def ff_interval(self) -> Union[None, int]:
        raise NotImplementedError("Base class")

    @property
    def scan_range(self) -> Union[None, int]:
        raise NotImplementedError("Base class")

    @property
    def energy(self) -> Union[None, float]:
        """

        :return: incident beam energy in keV
        """
        raise NotImplementedError("Base class")

    @property
    def distance(self) -> Union[None, float]:
        """

        :return: sample / detector distance in meter
        """
        raise NotImplementedError("Base class")

    @property
    def field_of_view(self):
        """

        :return: field of view of the scan. None if unknown else Full or Half
        """
        raise NotImplementedError("Base class")

    @property
    def estimated_cor_frm_motor(self):
        """

        :return: center of rotation estimated from motor position, or None.
                 Presumably in [-frame_width, +frame_width].
        :rtype: Union[None, float]
        """
        raise NotImplementedError("Base class")

    @property
    def x_translation(self) -> typing.Union[None, tuple]:
        """x translation values (tuple) or None. Defined by subclasses."""
        raise NotImplementedError("Base class")

    @property
    def y_translation(self) -> typing.Union[None, tuple]:
        """y translation values (tuple) or None. Defined by subclasses."""
        raise NotImplementedError("Base class")

    @property
    def z_translation(self) -> typing.Union[None, tuple]:
        """z translation values (tuple) or None. Defined by subclasses."""
        raise NotImplementedError("Base class")

    def get_distance(self, unit="m") -> Union[None, float]:
        """

        :param Union[MetricSystem, str] unit: unit requested for the distance
        :return: sample / detector distance with the requested unit, or None
                 if the distance is unknown (or evaluates to False, e.g. 0)
        """
        distance = self.distance
        if distance:
            return distance / MetricSystem.from_value(unit).value
        return None

    def update(self) -> None:
        """Parse the root folder and files to update information"""
        raise NotImplementedError("Base class")

    def to_dict(self) -> dict:
        """

        :return: convert the TomoScanBase object to a dictionary.
                 Used to serialize the object for example.
        :rtype: dict
        """
        return {
            self.DICT_TYPE_KEY: self.type,
            self.DICT_PATH_KEY: self.path,
        }

    def load_from_dict(self, _dict: dict):
        """
        Load properties contained in the dictionary.

        :param _dict: dictionary to load
        :type _dict: dict
        :return: self
        :raises: ValueError if dict is invalid
        """
        raise NotImplementedError("Base class")

    def equal(self, other) -> bool:
        """

        :param :class:`.ScanBase` other: instance to compare with
        :return: True if both instances belong to the same class hierarchy
                 and have the same type and the same path

        ..note:: we cannot use the __eq__ function because this object needs
                 to be picklable
        """
        # parenthesized on purpose: `and` binds tighter than `or`, so without
        # them any instance of the same class compared equal whatever its path
        same_family = isinstance(other, self.__class__) or isinstance(
            self, other.__class__
        )
        return same_family and self.type == other.type and self.path == other.path

    def get_proj_angle_url(self) -> dict:
        """
        return a dictionary of all the projections. key is the angle of the
        projection and value is the url.

        Keys are int for 'standard' projections and strings for return
        projections.

        :return dict: angles as keys, radios as value.
        """
        raise NotImplementedError("Base class")

    @staticmethod
    def map_urls_on_scan_range(urls, n_projection, scan_range) -> dict:
        """
        map given urls to an angle regarding scan_range and number of projections.
        We take the hypothesis that 'extra projections' are taken regarding the
        'id19' policy:

         * If the acquisition has a scan range of 360 then:

            * if 4 extra projections, the angles are (270, 180, 90, 0)

            * if 5 extra projections, the angles are (360, 270, 180, 90, 0)

         * If the acquisition has a scan range of 180 then:

            * if 2 extra projections: the angles are (90, 0)

            * if 3 extra projections: the angles are (180, 90, 0)

        ..warning:: each url should contain only one radio.

        :param urls: dict with all the urls. Once sorted by key, the first url
                     should be the first radio acquired and the last url the
                     last radio acquired.
        :type urls: dict
        :param n_projection: number of projections for the sample. If fewer
                             urls are provided, only the available ones are
                             mapped.
        :type n_projection: int
        :param scan_range: acquisition range (usually 180 or 360)
        :type scan_range: float
        :return: angle in degree as key and url as value
        :rtype: dict

        :raises: ValueError if the number of extra images found and scan_range
                 are incoherent
        """
        assert n_projection is not None
        # labels given to the extra images, depending on their number
        extra_labels = {
            1: ("0(1)",),
            2: ("90(1)", "0(1)"),
            3: ("180(1)", "90(1)", "0(1)"),
            4: ("270(1)", "180(1)", "90(1)", "0(1)"),
            5: ("360(1)", "270(1)", "180(1)", "90(1)", "0(1)"),
        }
        ordered_url = OrderedDict(sorted(urls.items(), key=lambda item: item[0]))
        keys = list(ordered_url.keys())

        res = {}
        # deal with the 'standard' acquisitions
        for proj_i, key in enumerate(keys[:n_projection]):
            if n_projection == 1:
                angle = 0.0
            else:
                angle = proj_i * scan_range / (n_projection - 1)
            res[angle] = ordered_url[key]

        # deal with extra images (used to check if the sample has moved for
        # example)
        extra_keys = keys[n_projection:]
        n_extra = len(extra_keys)
        if n_extra == 0:
            return res
        if (n_extra in (4, 5) and scan_range < 360) or (
            n_extra in (2, 3) and scan_range > 180
        ):
            _logger.warning(
                "incoherent data information to retrieve "
                "scan extra images angle"
            )
        elif n_extra in extra_labels:
            for label, key in zip(extra_labels[n_extra], extra_keys):
                res[label] = ordered_url[key]
        else:
            raise ValueError(
                "incoherent data information to retrieve scan extra images angle"
            )
        return res

    def get_sinogram(self, line, subsampling=1):
        """
        extract the sinogram from projections

        :param int line: which sinogram we want (row index in the projection
                         frames, must be in [0, dim_2[)
        :param int subsampling: subsampling to apply. Allows to skip some io.
                                Must be >= 1.

        :return: computed sinogram from projections, or None if there is no
                 projection
        :rtype: numpy.array
        :raises ValueError: if line is out of the frame or subsampling < 1
        """
        dim_2 = self.dim_2
        if line < 0 or (dim_2 is not None and line >= dim_2):
            raise ValueError("requested line {} is not in the scan".format(line))

        projections = self.projections
        if projections is None:
            return None
        if subsampling < 1:
            raise ValueError(
                "subsampling should be >= 1, got {}".format(subsampling)
            )

        tomo_n = self.tomo_n
        sinogram = numpy.empty((ceil(tomo_n / subsampling), self.dim_1))
        _logger.info(
            "compute sinogram for line {} of {} (subsampling: {})".format(
                line, self.path, subsampling
            )
        )
        advancement = Progress(
            name="compute sinogram for {}, line={},"
            "sampling={}".format(
                os.path.basename(self.path or ""), line, subsampling
            )
        )
        advancement.setMaxAdvancement(tomo_n)
        # only the first `tomo_n` projections (sorted by acquisition index)
        # fit in the sinogram: extra ones (e.g. return projections) would
        # overflow it.
        proj_indexes = sorted(projections.keys())[:tomo_n]
        for i_proj, proj_index in enumerate(proj_indexes):
            if i_proj % subsampling == 0:
                proj = silx.io.utils.get_data(projections[proj_index])
                # flat weights are keyed by acquisition index, not by the
                # position of the projection in the sorted sequence
                proj = self.flat_field_correction(
                    projs=[proj], proj_indexes=[proj_index]
                )[0]
                sinogram[i_proj // subsampling] = proj[line]
            advancement.increaseAdvancement(1)
        return sinogram

    def _frame_flat_field_correction(
        self,
        data: typing.Union[numpy.ndarray, DataUrl],
        index_proj: typing.Union[int, None],
        dark,
        flat_weights: dict,
    ):
        """
        compute flat field correction for a provided data from its index,
        one dark and one or two flats (with their weights).

        :param data: frame to correct, or url to load it from
        :param index_proj: acquisition index of the frame (not used by the
                           computation)
        :param dark: dark to subtract (expected 2D array)
        :param flat_weights: flat indexes (keys of `normed_flats`) to use,
                             with their weights
        :return: corrected frame, or the (loaded) data unchanged if the
                 correction cannot be applied
        :raises ValueError: if more than two flats are requested
        """
        assert isinstance(data, (numpy.ndarray, DataUrl))
        if isinstance(data, DataUrl):
            data = get_data(data)

        normed_flats = self.normed_flats
        can_process = True
        if not flat_weights or normed_flats is None:
            if self._notify_ffc_rsc_missing:
                _logger.error("cannot make flat field correction, flat not found")
            can_process = False
        else:
            for flat_index in flat_weights:
                if flat_index not in normed_flats:
                    _logger.error(
                        "flat {} has been removed, unable to apply flat field"
                        "".format(flat_index)
                    )
                    can_process = False
                elif normed_flats[flat_index].ndim != 2:
                    _logger.error(
                        "cannot make flat field correction, flat should be of "
                        "dimension 2"
                    )
                    can_process = False

        if not can_process:
            # only warn once about missing resources
            self._notify_ffc_rsc_missing = False
            return data

        if len(flat_weights) == 1:
            (flat_index,) = flat_weights
            flat_value = normed_flats[flat_index]
        elif len(flat_weights) == 2:
            flat_1, flat_2 = flat_weights
            # linear interpolation between the two surrounding flats
            flat_value = (
                normed_flats[flat_1] * flat_weights[flat_1]
                + normed_flats[flat_2] * flat_weights[flat_2]
            )
        else:
            raise ValueError(
                "no more than two flats are expected and "
                "at least one should be provided"
            )

        div = flat_value - dark
        # avoid division by zero where flat and dark are equal
        div[div == 0] = 1
        return (data - dark) / div

    def flat_field_correction(
        self, projs: typing.Iterable, proj_indexes: typing.Iterable
    ):
        """Apply flat field correction on the given data

        :param Iterable projs: list of projection (numpy array or DataUrl) to
                               apply correction on
        :param Iterable data proj_indexes: list of indexes of the projection in
                                         the acquisition sequence. Values can
                                         be int or None. A None index has no
                                         flat weights: the matching frame is
                                         returned uncorrected.
        :return: corrected data: list of numpy array. If no dark (or an
                 invalid one) is available, frames are returned uncorrected.
        :rtype: list
        """
        assert isinstance(projs, typing.Iterable)
        assert isinstance(proj_indexes, typing.Iterable)
        # materialize: both are iterated more than once below
        projs = list(projs)
        proj_indexes = list(proj_indexes)

        weights = self._flats_weights
        if not weights or any(
            index is not None and index not in weights for index in proj_indexes
        ):
            weights = self._flats_weights = self._get_flats_weights()

        if not weights:
            _logger.error("Unable to compute flat weights")
            weights = {}

        darks = self._normed_darks
        if darks is not None and len(darks) > 0:
            # take only one dark into account for now
            dark = next(iter(darks.values()))
        else:
            dark = None

        if dark is None:
            if self._notify_ffc_rsc_missing:
                _logger.error("cannot make flat field correction, dark not found")
                # only warn once about missing resources
                self._notify_ffc_rsc_missing = False
            return [
                get_data(proj) if isinstance(proj, DataUrl) else proj
                for proj in projs
            ]

        if dark.ndim != 2:
            _logger.error(
                "cannot make flat field correction, dark should be of " "dimension 2"
            )
            return [
                get_data(proj) if isinstance(proj, DataUrl) else proj for proj in projs
            ]

        return [
            self._frame_flat_field_correction(
                data=frame,
                dark=dark,
                index_proj=proj_i,
                flat_weights=weights.get(proj_i),
            )
            for frame, proj_i in zip(projs, proj_indexes)
        ]

    def _get_flats_weights(self):
        """compute flats indexes to use and weights for each projection

        :return: None if there is no normed flats, else dict with projection
                 index as key and a dict {flat index: weight} as value (empty
                 if there is no flat or no projection). A projection between
                 two flats gets both, linearly weighted; outside the flat
                 range the closest flat is used with weight 1.
        """
        normed_flats = self.normed_flats
        if normed_flats is None:
            return None
        projections = self.projections
        if len(normed_flats) == 0 or projections is None:
            return {}

        flats_indexes = sorted(normed_flats.keys())
        n_flats = len(flats_indexes)

        def get_weights(proj_index):
            pos = bisect_left(flats_indexes, proj_index)
            if pos < n_flats and flats_indexes[pos] == proj_index:
                return {proj_index: 1.0}
            if pos == 0:
                return {flats_indexes[0]: 1.0}
            if pos == n_flats:
                return {flats_indexes[-1]: 1.0}
            left_pos = flats_indexes[pos - 1]
            right_pos = flats_indexes[pos]
            delta = right_pos - left_pos
            return {
                left_pos: 1 - (proj_index - left_pos) / delta,
                right_pos: 1 - (right_pos - proj_index) / delta,
            }

        return {proj_index: get_weights(proj_index) for proj_index in projections}