import numpy as np

import logging
from numpy.polynomial.polynomial import Polynomial

from nabu.utils import previouspow2
from nabu.misc import fourier_filters

try:
    from scipy.ndimage.filters import median_filter

    __have_scipy__ = True
except ImportError:
    from silx.math.medianfilter import medfilt2d as median_filter

    __have_scipy__ = False


class AlignmentBase(object):
    @staticmethod
    def refine_max_position_2d(f_vals: np.ndarray, fy=None, fx=None):
        """Computes the sub-pixel max position of the given function sampling.

        Parameters
        ----------
        f_vals: numpy.ndarray
            Function values of the sampled points
        fy: numpy.ndarray, optional
            Vertical coordinates of the sampled points
        fx: numpy.ndarray, optional
            Horizontal coordinates of the sampled points

        Raises
        ------
        ValueError
            In case position and values do not have the same size, or in case
            the fitted maximum is outside the fitting region.

        Returns
        -------
        tuple(float, float)
            Estimated (vertical, horizontal) function max, according to the
            coordinates in fy and fx.
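
        Examples
        --------
        A minimal, illustrative sketch (the sampled values are made up): refine
        the peak position of a 3x3 neighborhood, using the default integer
        coordinates centered on the middle pixel.

        >>> f_vals = np.array([[0.5, 0.6, 0.5], [0.7, 1.0, 0.9], [0.6, 0.8, 0.7]])
        ... AlignmentBase.refine_max_position_2d(f_vals)

        The fitted vertex lies slightly towards the larger neighboring values,
        at roughly (0.18, 0.18) pixels from the central sample.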
        """
        if len(f_vals.shape) > 2:
            raise ValueError(
                "The fitted values should form a 2-dimensional array. Array of shape: [%s] was given."
                % (" ".join(("%d" % s for s in f_vals.shape)))
            )
        if fy is None:
            fy_half_size = (f_vals.shape[0] - 1) / 2
            fy = np.linspace(-fy_half_size, fy_half_size, f_vals.shape[0])
        elif not (len(fy.shape) == 1 and np.all(fy.size == f_vals.shape[0])):
            raise ValueError(
                "Vertical coordinates should have the same length as values matrix. Sizes of fy: %d, f_vals: [%s]"
                % (fy.size, " ".join(("%d" % s for s in f_vals.shape)))
            )
        if fx is None:
            fx_half_size = (f_vals.shape[1] - 1) / 2
            fx = np.linspace(-fx_half_size, fx_half_size, f_vals.shape[1])
        elif not (len(fx.shape) == 1 and np.all(fx.size == f_vals.shape[1])):
            raise ValueError(
                "Horizontal coordinates should have the same length as values matrix. Sizes of fx: %d, f_vals: [%s]"
                % (fx.size, " ".join(("%d" % s for s in f_vals.shape)))
            )

        fy, fx = np.meshgrid(fy, fx, indexing="ij")
        fy = fy.flatten()
        fx = fx.flatten()
        coords = np.array([np.ones(f_vals.size), fy, fx, fy * fx, fy ** 2, fx ** 2])

        coeffs = np.linalg.lstsq(coords.T, f_vals.flatten(), rcond=None)[0]

        # For a 1D parabola `f(x) = ax^2 + bx + c`, the vertex position is:
        # x_v = -b / 2a. For a 2D parabola, the vertex position is:
        # (y, x)_v = - b / A, where:
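        # Here f(y, x) = c0 + c1*y + c2*x + c3*y*x + c4*y^2 + c5*x^2, so setting
        # the gradient to zero gives the linear system A * (y, x)_v = -b below.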
        A = [[2 * coeffs[4], coeffs[3]], [coeffs[3], 2 * coeffs[5]]]
        b = coeffs[1:3]
        vertex_yx = np.linalg.lstsq(A, -b, rcond=None)[0]

        vertex_min_yx = [np.min(fy), np.min(fx)]
        vertex_max_yx = [np.max(fy), np.max(fx)]
        if np.any(vertex_yx < vertex_min_yx) or np.any(vertex_yx > vertex_max_yx):
            raise ValueError(
                "Fitted (y: {}, x: {}) positions are outide the margins of input: y: [{}, {}], x: [{}, {}]".format(
                    vertex_yx[0], vertex_yx[1], vertex_min_yx[0], vertex_max_yx[0], vertex_min_yx[1], vertex_max_yx[1]
                )
            )
        return vertex_yx

    @staticmethod
    def refine_max_position_1d(f_vals, fx=None):
        """Computes the sub-pixel max position of the given function sampling.

        Parameters
        ----------
        f_vals: numpy.ndarray
            Function values of the sampled points
        fx: numpy.ndarray, optional
            Coordinates of the sampled points

        Raises
        ------
        ValueError
            In case position and values do not have the same size, or in case
            the fitted maximum is outside the fitting region.

        Returns
        -------
        float
            Estimated function max, according to the coordinates in fx.
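
        Examples
        --------
        A minimal, illustrative sketch (the sampled values are made up): refine
        the peak position of three samples taken at the default coordinates
        (-1, 0, 1).

        >>> AlignmentBase.refine_max_position_1d(np.array([0.5, 1.0, 0.9]))

        The fitted parabola vertex is at roughly x = 0.33.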
        """
        if not len(f_vals.shape) == 1:
            raise ValueError(
                "The fitted values should form a 1-dimensional array. Array of shape: [%s] was given."
                % (" ".join(("%d" % s for s in f_vals.shape)))
            )
        if fx is None:
            fx_half_size = (f_vals.size - 1) / 2
            fx = np.linspace(-fx_half_size, fx_half_size, f_vals.size)
        elif not (len(fx.shape) == 1 and np.all(fx.size == f_vals.size)):
            raise ValueError(
                "Base coordinates should have the same length as values array. Sizes of fx: %d, f_vals: %d"
                % (fx.size, f_vals.size)
            )

        # using Polynomial.fit, because it is supposed to be more numerically
        # stable than the older numpy.polyfit (according to numpy).
        poly = Polynomial.fit(fx, f_vals, deg=2)
        coeffs = poly.convert().coef

        # For a 1D parabola `f(x) = c + bx + ax^2`, the vertex position is:
        # x_v = -b / 2a.
        vertex_x = -coeffs[1] / (2 * coeffs[2])

        vertex_min_x = np.min(fx)
        vertex_max_x = np.max(fx)
        if not (vertex_min_x < vertex_x < vertex_max_x):
            raise ValueError(
                "Fitted x: {} position is outide the margins of input: x: [{}, {}]".format(
                    vertex_x, vertex_min_x, vertex_max_x
                )
            )
        return vertex_x

    @staticmethod
    def extract_peak_region(cc, peak_radius=1, cc_vs=None, cc_hs=None):
        """
        Extracts a region around the maximum value.

        Parameters
        ----------
        cc: numpy.ndarray
            Correlation image.
        peak_radius: int, optional
            The l_inf radius of the area to extract around the peak. The default is 1.
        cc_vs: numpy.ndarray, optional
            The vertical coordinates of `cc`. The default is None.
        cc_hs: numpy.ndarray, optional
            The horizontal coordinates of `cc`. The default is None.

        Returns
        -------
        f_vals: numpy.ndarray
            The extracted function values.
        fv: numpy.ndarray
            The vertical coordinates of the extracted values.
        fh: numpy.ndarray
            The horizontal coordinates of the extracted values.
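
        Examples
        --------
        A minimal, illustrative sketch (the correlation image is made up): take
        the 3x3 neighborhood around the maximum of a correlation image.

        >>> cc = np.random.rand(16, 16)
        ... f_vals, fv, fh = AlignmentBase.extract_peak_region(cc, peak_radius=1)

        Here `f_vals` has shape (3, 3), while `fv` and `fh` are None because no
        coordinate vectors were given.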
        """
        img_shape = np.array(cc.shape)
        # get pixel having the maximum value of the correlation array
        pix_max_corr = np.argmax(cc)
        pv, ph = np.unravel_index(pix_max_corr, img_shape)

        # select a n x n neighborhood for the sub-pixel fitting (with wrapping)
        pv = np.arange(pv - peak_radius, pv + peak_radius + 1) % img_shape[-2]
        ph = np.arange(ph - peak_radius, ph + peak_radius + 1) % img_shape[-1]

        # extract the (v, h) pixel coordinates
        fv = None if cc_vs is None else cc_vs[pv]
        fh = None if cc_hs is None else cc_hs[ph]

        # extract the correlation values
        pv, ph = np.meshgrid(pv, ph, indexing="ij")
        f_vals = cc[pv, ph]

        return (f_vals, fv, fh)

    @staticmethod
    def _determine_roi(img_shape, roi_yxhw, do_truncate_horz_pow2):
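        """Compute a 4-element (y, x, height, width) RoI from the image shape.

        If `roi_yxhw` is None the full image is used, with the vertical size
        (and the horizontal size, if `do_truncate_horz_pow2` is True) truncated
        to the previous power of 2 to speed up the FFTs. A 2-element
        (height, width) RoI is converted into a centered 4-element RoI.
        """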
        if roi_yxhw is None:
            # vertical window size is reduced to a power of 2 to accelerate fft
            # same thing horizontal window - if requested. Default is not
            roi_yxhw = previouspow2(img_shape)
            if not do_truncate_horz_pow2:
                roi_yxhw[1] = img_shape[1]

        if len(roi_yxhw) == 2:  # Convert centered 2-element roi into 4-element
            roi_yxhw = np.array(roi_yxhw, dtype=int)
            roi_yxhw = np.concatenate(((img_shape - roi_yxhw) // 2, roi_yxhw))
        return roi_yxhw

    @staticmethod
    def _prepare_image(img, invalid_val=1e-5, roi_yxhw=None, median_filt_shape=None, high_pass=None, low_pass=None):
        """
        Prepare and return a cropped and filtered image, or a stack of filtered
        images if the input is a stack of images.

        Parameters
        ----------
        img: numpy.ndarray
            Image or stack of images.
        invalid_val: float
            Value used to replace nan and inf values.
        roi_yxhw: (4, ) numpy.ndarray, tuple, or array, optional
            Vertical and horizontal coordinates of the first pixel, plus height
            and width of the Region of Interest (RoI) to crop.
            Default is None -> no cropping.
        median_filt_shape: int or sequence of int
            The width (or widths) of the median filter window.
        high_pass: float or sequence of two floats
            Position of the cut-off in pixels. When a single float is given, a
            gaussian of that sigma is used and subtracted from 1 to obtain the
            high-pass filter. When a sequence of two floats is given, the second
            float is the width of the transition region, expressed as a fraction
            of the cut-off frequency: the filter is 1 (no filtering) above the
            cut-off frequency, with a smooth transition to zero for lower frequencies.
        low_pass: float or sequence of two floats
            Position of the cut-off in pixels. When a single float is given, a
            gaussian of that sigma is used. When a sequence of two floats is
            given, the second float is the width of the transition region,
            expressed as a fraction of the cut-off frequency: the filter is 1
            (no filtering) up to the cut-off frequency, followed by a smooth
            erfc transition to zero.

        Returns
        -------
        numpy.ndarray
            The prepared (cropped and filtered) image, or stack of images.
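
        Examples
        --------
        A minimal, illustrative sketch (the image and parameter values are made
        up): crop a 512x512 image to a centered 256x256 RoI and smooth it with
        a 3x3 median filter. The `high_pass` / `low_pass` filters can be passed
        in the same way.

        >>> img = np.random.rand(512, 512)
        ... img_p = AlignmentBase._prepare_image(img, roi_yxhw=(128, 128, 256, 256), median_filt_shape=(3, 3))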
        """
        img = np.squeeze(img)  # Removes singleton dimensions, but does a shallow copy
        img = np.ascontiguousarray(img)

        if roi_yxhw is not None:
            img = img[
                ..., roi_yxhw[0] : roi_yxhw[0] + roi_yxhw[2], roi_yxhw[1] : roi_yxhw[1] + roi_yxhw[3],
            ]

        img = img.copy()

        img[np.isnan(img)] = invalid_val
        img[np.isinf(img)] = invalid_val

        if high_pass is not None or low_pass is not None:
            img_filter = np.ones(img.shape[-2:], dtype=img.dtype)
            if low_pass is not None:
                img_filter[:] *= fourier_filters.get_lowpass_filter(img.shape[-2:], low_pass)
            if high_pass is not None:
                img_filter[:] *= fourier_filters.get_highpass_filter(img.shape[-2:], high_pass)
            # fft2 and ifft2 use axes=(-2, -1) by default
            img = np.fft.ifft2(np.fft.fft2(img) * img_filter).real

        if median_filt_shape is not None:
            img_shape = img.shape
            if __have_scipy__:
                # expanding filter shape with ones, to cover the stack of images
                # but disabling inter-image filtering
                median_filt_shape = np.concatenate(
                    (np.ones((len(img_shape) - len(median_filt_shape),), dtype=int), median_filt_shape,)
                )
                img = median_filter(img, size=median_filt_shape)
            else:
                if len(img_shape) == 2:
                    img = median_filter(img, kernel_size=median_filt_shape)
                elif len(img_shape) > 2:
                    # if dealing with a stack of images, we have to do them one by one
                    img = np.reshape(img, (-1,) + tuple(img_shape[-2:]))
                    for ii in range(img.shape[0]):
                        img[ii, ...] = median_filter(img[ii, ...], kernel_size=median_filt_shape)
                    img = np.reshape(img, img_shape)

        return img

    @staticmethod
    def _compute_correlation_fft(img_1, img_2, padding_mode):
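        """Compute the cross-correlation of two images in Fourier space:
        cc = IFFT(FFT(img_1) * conj(FFT(img_2))).

        If `padding_mode` is None or 'wrap', the convolution is circular;
        otherwise both images are padded by half their shape with the given
        numpy.pad mode (linear convolution), and the padded margins are
        cropped away from the result.
        """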
        do_circular_conv = padding_mode is None or padding_mode == "wrap"
        if not do_circular_conv:
            padding = np.ceil(np.array(img_2.shape) / 2).astype(int)
            img_1 = np.pad(img_1, ((padding[0],), (padding[1],)), mode=padding_mode)
            img_2 = np.pad(img_2, ((padding[0],), (padding[1],)), mode=padding_mode)

        # compute fft's of the 2 images
        img_fft_1 = np.fft.fft2(img_1)
        img_fft_2 = np.conjugate(np.fft.fft2(img_2))
        # inverse fft of the product to get cross_correlation of the 2 images
        cc = np.real(np.fft.ifft2(img_fft_1 * img_fft_2))

        if not do_circular_conv:
            cc = np.fft.fftshift(cc, axes=(-2, -1))
            cc = cc[padding[0] : -padding[0], padding[1] : -padding[1]]
            cc = np.fft.ifftshift(cc, axes=(-2, -1))

        return cc


class CenterOfRotation(AlignmentBase):
    def __init__(self, horz_fft_width=False):
        """
        Center of Rotation (CoR) computation object.
        This class is used on radios.

        Parameters
        ----------
        horz_fft_width: boolean, optional
            If True, restrict the horizontal size to a power of 2:

            >>> new_x_dim = 2 ** math.floor(math.log2(x_dim))
        """
        self._init_parameters(horz_fft_width)

    def _init_parameters(self, horz_fft_width):
        self.truncate_horz_pow2 = horz_fft_width

    @staticmethod
    def _check_img_sizes(img_1: np.ndarray, img_2: np.ndarray):
        shape_1 = np.squeeze(img_1).shape
        shape_2 = np.squeeze(img_2).shape
        if not len(shape_1) == 2:
            raise ValueError(
                "Images need to be 2-dimensional. Shape of image #1: %s" % (" ".join(("%d" % x for x in shape_1)))
            )
        if not len(shape_2) == 2:
            raise ValueError(
                "Images need to be 2-dimensional. Shape of image #2: %s" % (" ".join(("%d" % x for x in shape_2)))
            )
        if not np.all(shape_1 == shape_2):
            raise ValueError(
                "Images need to be of the same shape. Shape of image #1: %s, image #2: %s"
                % (" ".join(("%d" % x for x in shape_1)), " ".join(("%d" % x for x in shape_2)),)
            )

    def find_shift(
        self,
        img_1: np.ndarray,
        img_2: np.ndarray,
        roi_yxhw=None,
        median_filt_shape=None,
        padding_mode=None,
        peak_fit_radius=1,
        high_pass=None,
        low_pass=None,
    ):
        """Find the Center of Rotation (CoR), given to images.

        This method finds the half-shift between two opposite images by means
        of correlation computed in Fourier space.

        The output of this function allows one to compute the motor movement
        needed to align the sample rotation axis. Given the following values:

        - L1: distance from source to motor
        - L2: distance from source to detector
        - ps: physical pixel size
        - v: output of this function

        displacement of motor = (L1 / L2 * ps) * v
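
        For example (with purely hypothetical values), if L1 = 1 m, L2 = 2 m,
        ps = 1 um and v = 3 pixels, the motor displacement would be
        (1 / 2 * 1 um) * 3 = 1.5 um.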

        Parameters
        ----------
        img_1: numpy.ndarray
            First image
        img_2: numpy.ndarray
            Second image; it must have been flipped already (e.g. using numpy.fliplr).
        roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional
            4-element vector containing the vertical and horizontal coordinates
            of the first pixel, plus the height and width of the Region of
            Interest (RoI). Or a 2-element vector containing the height and
            width of the centered Region of Interest (RoI).
            Default is None -> deactivated.
        median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional
            Shape of the median filter window. Default is None -> deactivated.
        padding_mode: str in numpy.pad's mode list, optional
            Padding mode, which determines the type of convolution. If None or
            'wrap' are passed, this resorts to the traditional circular convolution.
            If 'edge' or 'constant' are passed, it results in a linear convolution.
            Default is the circular convolution.
            All options are:
                None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean'
                | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap'
        peak_fit_radius: int, optional
            Radius size around the max correlation pixel, for sub-pixel fitting.
            Minimum and default value is 1.

        high_pass: float or sequence of two floats
            Position of the cut-off in pixels. When a single float is given, a
            gaussian of that sigma is used and subtracted from 1 to obtain the
            high-pass filter. When a sequence of two floats is given, the second
            float is the width of the transition region, expressed as a fraction
            of the cut-off frequency: the filter is 1 (no filtering) above the
            cut-off frequency, with a smooth transition to zero for lower frequencies.
        low_pass: float or sequence of two floats
            Position of the cut-off in pixels. When a single float is given, a
            gaussian of that sigma is used. When a sequence of two floats is
            given, the second float is the width of the transition region,
            expressed as a fraction of the cut-off frequency: the filter is 1
            (no filtering) up to the cut-off frequency, followed by a smooth
            erfc transition to zero.

        Raises
        ------
        ValueError
            In case images are not 2-dimensional or have different sizes.

        Returns
        -------
        float
            Estimated center of rotation position from the center of the RoI in pixels.

        Examples
        --------
        The following code computes the center of rotation position for two
        given images in a tomography scan, where the second image is taken at
        180 degrees from the first.

        >>> radio1 = data[0, :, :]
        ... radio2 = np.fliplr(data[1, :, :])
        ... CoR_calc = CenterOfRotation()
        ... cor_position = CoR_calc.find_shift(radio1, radio2)

        Or for noisy images:

        >>> cor_position = CoR_calc.find_shift(radio1, radio2, median_filt_shape=(3, 3))
        """
        self._check_img_sizes(img_1, img_2)

        if peak_fit_radius < 1:
            logging.getLogger(__name__).warning(
                "Parameter peak_fit_radius should be at least 1, given: %d instead." % peak_fit_radius
            )
            peak_fit_radius = 1

        img_shape = img_2.shape
        roi_yxhw = self._determine_roi(img_shape, roi_yxhw, self.truncate_horz_pow2)

        img_1 = self._prepare_image(
            img_1, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape, high_pass=high_pass, low_pass=low_pass
        )
        img_2 = self._prepare_image(
            img_2, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape, high_pass=high_pass, low_pass=low_pass
        )

        cc = self._compute_correlation_fft(img_1, img_2, padding_mode)
        img_shape = img_2.shape
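        # fftfreq(n, 1 / n) gives the signed pixel coordinates (0, 1, ..., -2, -1),
        # so the correlation peak position directly yields the signed shift, even
        # when negative shifts wrap around to the end of the array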
        cc_vs = np.fft.fftfreq(img_shape[-2], 1 / img_shape[-2])
        cc_hs = np.fft.fftfreq(img_shape[-1], 1 / img_shape[-1])

        (f_vals, fv, fh) = self.extract_peak_region(cc, peak_radius=peak_fit_radius, cc_vs=cc_vs, cc_hs=cc_hs)
        fitted_shifts_vh = self.refine_max_position_2d(f_vals, fv, fh)

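        # The shift between the two (flipped) opposite radios is twice the offset
        # of the rotation axis from the image center, hence the division by 2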
        return fitted_shifts_vh[-1] / 2.0

    __call__ = find_shift


class DetectorTranslationAlongBeam(AlignmentBase):
    @staticmethod
    def _check_img_sizes(img_stack: np.ndarray, img_pos: np.ndarray):
        shape_stack = np.squeeze(img_stack).shape
        shape_pos = np.squeeze(img_pos).shape
        if not len(shape_stack) == 3:
            raise ValueError(
                "A stack of 2-dimensional images is required. Shape of stack: %s" % (" ".join(("%d" % x for x in shape_stack)))
            )
        if not len(shape_pos) == 1:
            raise ValueError(
                "Positions need to be a 1-dimensional array. Shape of the positions variable: %s"
                % (" ".join(("%d" % x for x in shape_pos)))
            )
        if not shape_stack[0] == shape_pos[0]:
            raise ValueError(
                "The same number of images and positions is required."
                + " Shape of stack: %s, shape of positions variable: %s"
                % (" ".join(("%d" % x for x in shape_stack)), " ".join(("%d" % x for x in shape_pos)),)
            )

    @staticmethod
    def _check_equispaced(img_pos, supposed_equispaced):
        img_pos_incrs = np.diff(img_pos)
        detected_equispaced = np.all(np.isclose(img_pos_incrs[0], img_pos_incrs[1:]))

        if detected_equispaced and not supposed_equispaced:
            logging.getLogger(__name__).warning(
                "The image position increments were supposed to NOT be equispaced, "
                + "but they seem to be equispaced: "
                + " ".join(("%f" % x for x in img_pos))
                + ". Forcing behavior according to the detected condition."
            )
        if not detected_equispaced and supposed_equispaced:
            logging.getLogger(__name__).warning(
                "The image position increments were supposed to be equispaced, "
                + "but they seem to NOT be equispaced: "
                + " ".join(("%f" % x for x in img_pos))
                + ". Forcing behavior according to the detected condition."
            )
        return detected_equispaced

    def find_shift(
        self,
        img_stack: np.ndarray,
        img_pos: np.array,
        roi_yxhw=None,
        median_filt_shape=None,
        padding_mode=None,
        peak_fit_radius=1,
        equispaced_increments=True,
        return_shifts=False,
    ):
        """Find the deviation of the translation axis of the area detector
        along the beam propagation direction.

        TODO: Add more information here! Including interpretation of the result
        This means also giving an example of how to convert the returned values
        into meaningful quantities. See "Returns" for more details.

        Parameters
        ----------
        img_stack: numpy.ndarray
            A stack of images (usually 4) at different distances
        img_pos: numpy.ndarray
            Position of the images along the translation axis
        roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional
            4-element vector containing the vertical and horizontal coordinates
            of the first pixel, plus the height and width of the Region of
            Interest (RoI). Or a 2-element vector containing the height and
            width of the centered Region of Interest (RoI).
            Default is None -> deactivated.
        median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional
            Shape of the median filter window. Default is None -> deactivated.
        padding_mode: str in numpy.pad's mode list, optional
            Padding mode, which determines the type of convolution. If None or
            'wrap' are passed, this resorts to the traditional circular convolution.
            If 'edge' or 'constant' are passed, it results in a linear convolution.
            Default is the circular convolution.
            All options are:
                None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean'
                | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap'
        peak_fit_radius: int, optional
            Radius size around the max correlation pixel, for sub-pixel fitting.
            Minimum and default value is 1.
        equispaced_increments: boolean, optional
            Tells whether the position increments are equispaced or not. If
            equispaced increments are used, we have to compute the correlation
            images with respect to the first image, otherwise we can do it
            against adjacent images.
            The advantage of doing it between adjacent images is that we do not
            build up large displacements in the correlation.
            However, if this is done for equispaced images, the linear fitting
            becomes unstable.
        return_shifts: boolean, optional
            If True, the list of the fitted (vertical, horizontal) shifts, one
            per correlated image pair, is also returned. This is intended for
            introspection. Default is False.

        Returns
        -------
        tuple(float, float)
            Estimated (vertical, horizontal) increment per unit-distance of the
            ratio between pixel-size and detector translation.
            If `return_shifts` is True, the list of the fitted (vertical,
            horizontal) shifts is also returned.

        Examples
        --------
        TODO: Add examples here!
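
        A minimal, illustrative sketch (assuming a stack `img_stack` of shape
        (4, ny, nx), acquired at positions 0, 1, 2 and 3 along the translation
        axis):

        >>> tr_calc = DetectorTranslationAlongBeam()
        ... coeff_v, coeff_h = tr_calc.find_shift(img_stack, np.array([0.0, 1.0, 2.0, 3.0]))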
        """
        self._check_img_sizes(img_stack, img_pos)
        equispaced_increments = self._check_equispaced(img_pos, equispaced_increments)

        if peak_fit_radius < 1:
            logging.getLogger(__name__).warning(
                "Parameter peak_fit_radius should be at least 1, given: %d instead." % peak_fit_radius
            )
            peak_fit_radius = 1

        num_imgs = img_stack.shape[0]
        img_shape = img_stack.shape[-2:]
        roi_yxhw = self._determine_roi(img_shape, roi_yxhw, False)

        img_stack = self._prepare_image(img_stack, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape)

        cc_vs = np.fft.fftfreq(img_shape[-2], 1 / img_shape[-2])
        cc_hs = np.fft.fftfreq(img_shape[-1], 1 / img_shape[-1])

        # do correlations
        ccs = [
            self._compute_correlation_fft(
                img_stack[0 if equispaced_increments else ii - 1, ...], img_stack[ii, ...], padding_mode
            )
            for ii in range(1, num_imgs)
        ]

        shifts_vh = np.empty((num_imgs - 1, 2))
        for ii, cc in enumerate(ccs):
            (f_vals, fv, fh) = self.extract_peak_region(cc, peak_radius=peak_fit_radius, cc_vs=cc_vs, cc_hs=cc_hs)
            shifts_vh[ii, :] = self.refine_max_position_2d(f_vals, fv, fh)

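        # For equispaced positions the shifts were computed against the first
        # image, so the fit abscissae are the distances from the first position;
        # otherwise they are the increments between adjacent positions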
        if equispaced_increments:
            img_pos_increments = img_pos[1:] - img_pos[0]
        else:
            img_pos_increments = np.diff(img_pos)

        # Polynomial.fit is supposed to be more numerically stable than polyfit
        # (according to numpy)
        coeffs_v = Polynomial.fit(img_pos_increments, shifts_vh[:, 0], deg=1).convert().coef
        coeffs_h = Polynomial.fit(img_pos_increments, shifts_vh[:, 1], deg=1).convert().coef

        if return_shifts:
            r_shifts = []
            for vh in shifts_vh:
                r_shifts.append(vh)
            return coeffs_v[1], coeffs_h[1], r_shifts
        else:
            return coeffs_v[1], coeffs_h[1]