alignment.py 25.9 KB
Newer Older
Nicola Vigano's avatar
Nicola Vigano committed
1
import numpy as np
2

3
import logging
4
from numpy.polynomial.polynomial import Polynomial
5

6
from nabu.utils import previouspow2
7
from nabu.misc import fourier_filters
8

9
10
try:
    from scipy.ndimage.filters import median_filter
11

12
13
14
    __have_scipy__ = True
except ImportError:
    from silx.math.medianfilter import medfilt2d as median_filter
15

16
    __have_scipy__ = False
17

myron's avatar
myron committed
18

myron's avatar
myron committed
19
20
class AlignmentBase(object):
    """Common tools for correlation-based alignment.

    Provides image preparation (cropping, invalid-value replacement,
    filtering), FFT cross-correlation, extraction of the region around the
    correlation peak, and sub-pixel refinement of the peak position through
    quadratic least-squares fits.
    """

    @staticmethod
    def refine_max_position_2d(f_vals: np.ndarray, fy=None, fx=None):
        """Computes the sub-pixel max position of the given function sampling.

        A 2D quadratic polynomial is fitted (least squares) to the samples,
        and the position of its vertex is returned.

        Parameters
        ----------
        f_vals: numpy.ndarray
            Function values of the sampled points (2-dimensional).
        fy: numpy.ndarray, optional
            Vertical coordinates of the sampled points. Defaults to pixel
            coordinates centered on the middle of `f_vals`.
        fx: numpy.ndarray, optional
            Horizontal coordinates of the sampled points. Defaults to pixel
            coordinates centered on the middle of `f_vals`.

        Raises
        ------
        ValueError
            In case position and values do not have the same size, or in case
            the fitted maximum is outside the fitting region.

        Returns
        -------
        numpy.ndarray
            Shape (2,): estimated (vertical, horizontal) function max,
            according to the coordinates in fy and fx.
        """
        if len(f_vals.shape) != 2:
            raise ValueError(
                "The fitted values should form a 2-dimensional array. Array of shape: [%s] was given."
                % (" ".join(("%d" % s for s in f_vals.shape)))
            )
        if fy is None:
            fy_half_size = (f_vals.shape[0] - 1) / 2
            fy = np.linspace(-fy_half_size, fy_half_size, f_vals.shape[0])
        elif not (len(fy.shape) == 1 and fy.size == f_vals.shape[0]):
            raise ValueError(
                "Vertical coordinates should have the same length as values matrix. Sizes of fy: %d, f_vals: [%s]"
                % (fy.size, " ".join(("%d" % s for s in f_vals.shape)))
            )
        if fx is None:
            fx_half_size = (f_vals.shape[1] - 1) / 2
            fx = np.linspace(-fx_half_size, fx_half_size, f_vals.shape[1])
        elif not (len(fx.shape) == 1 and fx.size == f_vals.shape[1]):
            raise ValueError(
                "Horizontal coordinates should have the same length as values matrix. Sizes of fx: %d, f_vals: [%s]"
                % (fx.size, " ".join(("%d" % s for s in f_vals.shape)))
            )

        fy, fx = np.meshgrid(fy, fx, indexing="ij")
        fy = fy.flatten()
        fx = fx.flatten()
        # Model: f(y, x) = c0 + c1 y + c2 x + c3 yx + c4 y^2 + c5 x^2
        coords = np.array([np.ones(f_vals.size), fy, fx, fy * fx, fy ** 2, fx ** 2])

        coeffs = np.linalg.lstsq(coords.T, f_vals.flatten(), rcond=None)[0]

        # The vertex is where the gradient of the fitted surface vanishes:
        #   df/dy = c1 + c3 x + 2 c4 y = 0
        #   df/dx = c2 + c3 y + 2 c5 x = 0
        # i.e. A @ (y, x) = -b, with:
        A = [[2 * coeffs[4], coeffs[3]], [coeffs[3], 2 * coeffs[5]]]
        b = coeffs[1:3]
        vertex_yx = np.linalg.lstsq(A, -b, rcond=None)[0]

        vertex_min_yx = [np.min(fy), np.min(fx)]
        vertex_max_yx = [np.max(fy), np.max(fx)]
        if np.any(vertex_yx < vertex_min_yx) or np.any(vertex_yx > vertex_max_yx):
            raise ValueError(
                "Fitted (y: {}, x: {}) positions are outide the input margins y: [{}, {}], and x: [{}, {}]".format(
                    vertex_yx[0], vertex_yx[1], vertex_min_yx[0], vertex_max_yx[0], vertex_min_yx[1], vertex_max_yx[1]
                )
            )
        return vertex_yx

    @staticmethod
    def refine_max_position_1d(f_vals, fx=None):
        """Computes the sub-pixel max position of the given function sampling.

        A parabola is least-squares fitted to the samples (along axis 0), and
        the position of its vertex is returned.

        Parameters
        ----------
        f_vals: numpy.ndarray
            Function values of the sampled points. Either a 1D array, or a 2D
            array whose columns are independent 1D samplings (the fit is
            performed along axis 0, one fit per column).
        fx: numpy.ndarray, optional
            Coordinates of the sampled points (along axis 0 of `f_vals`).
            Defaults to pixel coordinates centered on the middle sample.

        Raises
        ------
        ValueError
            In case position and values do not have the same size, or in case
            the fitted maximum is outside the fitting region.

        Returns
        -------
        float or numpy.ndarray
            Estimated function max, according to the coordinates in fx: a
            scalar for 1D input, one value per column for 2D input.
        """
        if len(f_vals.shape) not in (1, 2):
            raise ValueError(
                "The fitted values should be either one or a collection of 1-dimensional arrays. Array of shape: [%s] was given."
                % (" ".join(("%d" % s for s in f_vals.shape)))
            )
        num_vals = f_vals.shape[0]

        if fx is None:
            fx_half_size = (num_vals - 1) / 2
            fx = np.linspace(-fx_half_size, fx_half_size, num_vals)
        elif not (len(fx.shape) == 1 and fx.size == num_vals):
            raise ValueError(
                "Base coordinates should have the same length as values array. Sizes of fx: %d, f_vals: %d"
                % (fx.size, num_vals)
            )

        if len(f_vals.shape) == 1:
            # using Polynomial.fit, because supposed to be more numerically
            # stable than previous solutions (according to numpy).
            poly = Polynomial.fit(fx, f_vals, deg=2)
            coeffs = poly.convert().coef
            # `convert` trims exact trailing zero coefficients: pad back to
            # the 3 coefficients (c, b, a) expected below.
            coeffs = np.pad(coeffs, (0, 3 - coeffs.size))
        else:
            coords = np.array([np.ones(num_vals), fx, fx ** 2])
            coeffs = np.linalg.lstsq(coords.T, f_vals, rcond=None)[0]

        # For a 1D parabola `f(x) = c + bx + ax^2`, the vertex position is:
        # x_v = -b / 2a. Indexing only the first axis works for both the 1D
        # case (coeffs shape (3,)) and the 2D case (coeffs shape (3, N)).
        vertex_x = -coeffs[1] / (2 * coeffs[2])

        vertex_min_x = np.min(fx)
        vertex_max_x = np.max(fx)
        lower_bound_ok = vertex_min_x < vertex_x
        upper_bound_ok = vertex_x < vertex_max_x
        if not np.all(lower_bound_ok & upper_bound_ok):
            if len(f_vals.shape) == 1:
                message = "Fitted position {} is outide the input margins [{}, {}]".format(
                    vertex_x, vertex_min_x, vertex_max_x
                )
            else:
                message = "Fitted positions outide the input margins [{}, {}]: {} below and {} above".format(
                    vertex_min_x,
                    vertex_max_x,
                    np.count_nonzero(~lower_bound_ok),
                    np.count_nonzero(~upper_bound_ok),
                )
            raise ValueError(message)
        return vertex_x

    @staticmethod
    def extract_peak_region_2d(cc, peak_radius=1, cc_vs=None, cc_hs=None):
        """
        Extracts a region around the maximum value.

        Parameters
        ----------
        cc: numpy.ndarray
            Correlation image.
        peak_radius: int, optional
            The l_inf radius of the area to extract around the peak. The default is 1.
        cc_vs: numpy.ndarray, optional
            The vertical coordinates of `cc`. The default is None.
        cc_hs: numpy.ndarray, optional
            The horizontal coordinates of `cc`. The default is None.

        Returns
        -------
        f_vals: numpy.ndarray
            The extracted function values, of shape
            (2 * peak_radius + 1, 2 * peak_radius + 1).
        fv: numpy.ndarray
            The vertical coordinates of the extracted values (None if `cc_vs` is None).
        fh: numpy.ndarray
            The horizontal coordinates of the extracted values (None if `cc_hs` is None).
        """
        img_shape = np.array(cc.shape)
        # get pixel having the maximum value of the correlation array
        pix_max_corr = np.argmax(cc)
        pv, ph = np.unravel_index(pix_max_corr, img_shape)

        # select a n x n neighborhood for the sub-pixel fitting (with wrapping)
        pv = np.arange(pv - peak_radius, pv + peak_radius + 1) % img_shape[-2]
        ph = np.arange(ph - peak_radius, ph + peak_radius + 1) % img_shape[-1]

        # extract the (v, h) pixel coordinates
        fv = None if cc_vs is None else cc_vs[pv]
        fh = None if cc_hs is None else cc_hs[ph]

        # extract the correlation values
        pv, ph = np.meshgrid(pv, ph, indexing="ij")
        f_vals = cc[pv, ph]

        return (f_vals, fv, fh)

    @staticmethod
    def extract_peak_regions_1d(cc, axis=-1, peak_radius=1, cc_coords=None):
        """
        Extracts, for each line of `cc`, a region around its maximum value.

        Parameters
        ----------
        cc: numpy.ndarray
            Correlation image (2-dimensional).
        axis: int, optional
            Find the max values along the specified direction. The default is -1.
        peak_radius: int, optional
            The l_inf radius of the area to extract around the peak. The default is 1.
        cc_coords: numpy.ndarray, optional
            The coordinates of `cc` along the selected axis. The default is None.

        Returns
        -------
        f_vals: numpy.ndarray
            The extracted function values, of shape
            (2 * peak_radius + 1, number of lines along the other axis).
        fc_ax: numpy.ndarray
            The coordinates of the extracted values, along the selected axis
            (same shape as `f_vals`), or None if `cc_coords` is None.
        """
        img_shape = np.array(cc.shape)
        if len(img_shape) != 2:
            raise ValueError(
                "The input image should be a 2-dimensional array. Array of shape: [%s] was given."
                % (" ".join(("%d" % s for s in cc.shape)))
            )
        other_axis = (axis + 1) % 2
        # get pixel having the maximum value of the correlation array
        pix_max = np.argmax(cc, axis=axis)

        # select a n neighborhood for the many 1D sub-pixel fittings (with wrapping)
        p_ax_range = np.arange(-peak_radius, peak_radius + 1)
        p_ax = (pix_max[None, :] + p_ax_range[:, None]) % img_shape[axis]

        # One index per line: lines run along the *other* axis, whose length
        # generally differs from the searched axis for non-square images.
        # Broadcasting against `p_ax` replicates it for every peak offset.
        p_ln = np.arange(img_shape[other_axis])[None, :]

        # extract the pixel coordinates along the axis
        fc_ax = None if cc_coords is None else cc_coords[p_ax]

        # extract the correlation values
        if other_axis == 0:
            f_vals = cc[p_ln, p_ax]
        else:
            f_vals = cc[p_ax, p_ln]

        return (f_vals, fc_ax)

    @staticmethod
    def _determine_roi(img_shape, roi_yxhw, do_truncate_horz_pow2):
        """Builds the 4-element (y, x, height, width) Region of Interest.

        If `roi_yxhw` is None, a centered RoI is derived from `img_shape` via
        `previouspow2` (the horizontal size is kept unless
        `do_truncate_horz_pow2` is True). A 2-element (height, width) RoI is
        converted into a 4-element one centered in the image. A 4-element RoI
        is returned unchanged.
        """
        if roi_yxhw is None:
            # vertical window size is reduced to a power of 2 to accelerate fft
            # same thing horizontal window - if requested. Default is not
            roi_yxhw = previouspow2(img_shape)
            if not do_truncate_horz_pow2:
                roi_yxhw[1] = img_shape[1]

        if len(roi_yxhw) == 2:  # Convert centered 2-element roi into 4-element
            # `np.int` was removed from NumPy (1.24): use the builtin `int`.
            roi_yxhw = np.array(roi_yxhw, dtype=int)
            roi_yxhw = np.concatenate(((np.asarray(img_shape) - roi_yxhw) // 2, roi_yxhw))
        return roi_yxhw

    @staticmethod
    def _prepare_image(img, invalid_val=1e-5, roi_yxhw=None, median_filt_shape=None, low_pass=None, high_pass=None):
        """
        Prepares and returns a cropped and filtered image, or a stack of
        filtered images if the input is a stack of images.

        Parameters
        ----------
        img: numpy.ndarray
            image or stack of images
        invalid_val: float
            value to be used in replacement of nan and inf values
        roi_yxhw: 4-element sequence of int, optional
            (y, x, height, width) region of interest, applied to the last two axes.
        median_filt_shape: int or sequence of int
            the width or the widths of the median window. A single int is
            used for both image axes.
        low_pass: float or sequence of two floats
            Low-pass filter properties, as described in `nabu.misc.fourier_filters`
        high_pass: float or sequence of two floats
            High-pass filter properties, as described in `nabu.misc.fourier_filters`

        Returns
        -------
        numpy.ndarray
            The prepared image(s). This is always a new array: the input is not modified.
        """
        img = np.squeeze(img)  # Removes singleton dimensions (returns a view)
        img = np.ascontiguousarray(img)

        if roi_yxhw is not None:
            img = img[
                ..., roi_yxhw[0] : roi_yxhw[0] + roi_yxhw[2], roi_yxhw[1] : roi_yxhw[1] + roi_yxhw[3],
            ]

        # Copy, so that the in-place replacements below never touch the caller's data
        img = img.copy()

        img[np.isnan(img)] = invalid_val
        img[np.isinf(img)] = invalid_val

        if high_pass is not None or low_pass is not None:
            img_filter = fourier_filters.get_bandpass_filter(
                img.shape[-2:], cutoff_lowpass=low_pass, cutoff_highpass=high_pass
            )
            # fft2 and iff2 use axes=(-2, -1) by default
            img = np.fft.ifft2(np.fft.fft2(img) * img_filter).real

        if median_filt_shape is not None:
            if np.isscalar(median_filt_shape):
                # A single width applies to both image axes
                median_filt_shape = (median_filt_shape, median_filt_shape)
            img_shape = img.shape
            if __have_scipy__:
                # expanding filter shape with ones, to cover the stack of images
                # but disabling inter-image filtering
                num_stack_dims = len(img_shape) - len(median_filt_shape)
                median_filt_shape = np.concatenate((np.ones((num_stack_dims,), dtype=int), median_filt_shape))
                img = median_filter(img, size=median_filt_shape)
            else:
                if len(img_shape) == 2:
                    img = median_filter(img, kernel_size=median_filt_shape)
                elif len(img_shape) > 2:
                    # if dealing with a stack of images, we have to do them one by one
                    img = np.reshape(img, (-1,) + tuple(img_shape[-2:]))
                    for ii in range(img.shape[0]):
                        img[ii, ...] = median_filter(img[ii, ...], kernel_size=median_filt_shape)
                    img = np.reshape(img, img_shape)

        return img

    @staticmethod
    def _compute_correlation_fft(img_1, img_2, padding_mode, axes=(-2, -1), low_pass=None, high_pass=None):
        """Computes the cross-correlation of two images through FFTs.

        Parameters
        ----------
        img_1, img_2: numpy.ndarray
            The images to correlate (same shape).
        padding_mode: str or None
            None or 'wrap' give a circular correlation. Any other `numpy.pad`
            mode pads the images by half their size along `axes` before the
            FFTs, and the result is cropped back to the input shape.
        axes: tuple of int, optional
            Axes along which the correlation is computed.
        low_pass, high_pass: optional
            Band-pass filter properties (see `nabu.misc.fourier_filters`),
            applied to the product of the spectra over its last two dimensions.

        Returns
        -------
        numpy.ndarray
            Real correlation image, with the same shape as `img_2`, in FFT
            ordering (zero shift at index 0 along each correlated axis).
        """
        do_circular_conv = padding_mode is None or padding_mode == "wrap"
        if not do_circular_conv:
            img_shape = img_2.shape
            # `np.int` was removed from NumPy (1.24): use the builtin `int`.
            pad_size = np.ceil(np.array(img_shape) / 2).astype(int)
            pad_array = [(0,)] * len(img_shape)
            for a in axes:
                pad_array[a] = (pad_size[a],)

            img_1 = np.pad(img_1, pad_array, mode=padding_mode)
            img_2 = np.pad(img_2, pad_array, mode=padding_mode)

        # compute fft's of the 2 images
        img_fft_1 = np.fft.fftn(img_1, axes=axes)
        img_fft_2 = np.conjugate(np.fft.fftn(img_2, axes=axes))

        img_prod = img_fft_1 * img_fft_2

        if low_pass is not None or high_pass is not None:
            filt = fourier_filters.get_bandpass_filter(img_prod.shape[-2:], cutoff_lowpass=low_pass, cutoff_highpass=high_pass)
            img_prod *= filt

        # inverse fft of the product to get cross_correlation of the 2 images
        cc = np.real(np.fft.ifftn(img_prod, axes=axes))

        if not do_circular_conv:
            # Move zero shift to the center, crop away the padding, and move
            # zero shift back to index 0.
            cc = np.fft.fftshift(cc, axes=axes)

            slicing = [slice(None)] * len(img_shape)
            for a in axes:
                slicing[a] = slice(pad_size[a], cc.shape[a] - pad_size[a])
            cc = cc[tuple(slicing)]

            cc = np.fft.ifftshift(cc, axes=axes)

        return cc

368
369

class CenterOfRotation(AlignmentBase):
    def __init__(self, horz_fft_width=False):
        """
        Center of Rotation (CoR) computation object.
        This class is used on radios.

        Parameters
        ----------
        horz_fft_width: boolean, optional
            If True, restrict the horizontal size to a power of 2:

            >>> new_x_dim = 2 ** math.floor(math.log2(x_dim))
        """
        self._init_parameters(horz_fft_width)

    def _init_parameters(self, horz_fft_width):
        # Whether the default RoI also truncates the horizontal size to a power of 2
        self.truncate_horz_pow2 = horz_fft_width

    @staticmethod
    def _check_img_sizes(img_1: np.ndarray, img_2: np.ndarray):
        """Checks that both images are 2D (after squeezing) and of the same shape.

        Raises
        ------
        ValueError
            If either image is not 2-dimensional, or if their shapes differ.
        """
        shape_1 = np.squeeze(img_1).shape
        shape_2 = np.squeeze(img_2).shape
        if len(shape_1) != 2:
            raise ValueError(
                "Images need to be 2-dimensional. Shape of image #1: %s" % (" ".join(("%d" % x for x in shape_1)))
            )
        if len(shape_2) != 2:
            raise ValueError(
                "Images need to be 2-dimensional. Shape of image #2: %s" % (" ".join(("%d" % x for x in shape_2)))
            )
        if shape_1 != shape_2:
            raise ValueError(
                "Images need to be of the same shape. Shape of image #1: %s, image #2: %s"
                % (" ".join(("%d" % x for x in shape_1)), " ".join(("%d" % x for x in shape_2)),)
            )

    def find_shift(
        self,
        img_1: np.ndarray,
        img_2: np.ndarray,
        roi_yxhw=None,
        median_filt_shape=None,
        padding_mode=None,
        peak_fit_radius=1,
        high_pass=None,
        low_pass=None,
    ):
        """Find the Center of Rotation (CoR), given two images.

        This method finds the half-shift between two opposite images, by
        means of correlation computed in Fourier space.

        The output of this function allows one to compute motor movements for
        aligning the sample rotation axis. Given the following values:

        - L1: distance from source to motor
        - L2: distance from source to detector
        - ps: physical pixel size
        - v: output of this function

        displacement of motor = (L1 / L2 * ps) * v

        Parameters
        ----------
        img_1: numpy.ndarray
            First image
        img_2: numpy.ndarray
            Second image, it needs to have been flipped already (e.g. using numpy.fliplr).
        roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional
            4 elements vector containing: vertical and horizontal coordinates
            of first pixel, plus height and width of the Region of Interest (RoI).
            Or a 2 elements vector containing the height and width of the
            centered Region of Interest (RoI).
            Default is None -> a centered RoI whose height (and width, if
            `horz_fft_width` was set) is truncated to a power of 2.
        median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional
            Shape of the median filter window. Default is None -> deactivated.
        padding_mode: str in numpy.pad's mode list, optional
            Padding mode, which determines the type of convolution. If None or
            'wrap' are passed, this resorts to the traditional circular convolution.
            If 'edge' or 'constant' are passed, it results in a linear convolution.
            Default is the circular convolution.
            All options are:
                None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean'
                | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap'
        peak_fit_radius: int, optional
            Radius size around the max correlation pixel, for sub-pixel fitting.
            Minimum and default value is 1.
        low_pass: float or sequence of two floats
            Low-pass filter properties, as described in `nabu.misc.fourier_filters`
        high_pass: float or sequence of two floats
            High-pass filter properties, as described in `nabu.misc.fourier_filters`

        Raises
        ------
        ValueError
            In case images are not 2-dimensional or have different sizes.

        Returns
        -------
        float
            Estimated center of rotation position from the center of the RoI in pixels.

        Examples
        --------
        The following code computes the center of rotation position for two
        given images in a tomography scan, where the second image is taken at
        180 degrees from the first.

        >>> radio1 = data[0, :, :]
        ... radio2 = np.fliplr(data[1, :, :])
        ... CoR_calc = CenterOfRotation()
        ... cor_position = CoR_calc.find_shift(radio1, radio2)

        Or for noisy images:

        >>> cor_position = CoR_calc.find_shift(radio1, radio2, median_filt_shape=(3, 3))
        """
        self._check_img_sizes(img_1, img_2)

        if peak_fit_radius < 1:
            logging.getLogger(__name__).warning(
                "Parameter peak_fit_radius should be at least 1, given: %d instead." % peak_fit_radius
            )
            peak_fit_radius = 1

        # `_check_img_sizes` accepts singleton dimensions (it squeezes), and
        # `_prepare_image` squeezes as well: the RoI must be derived from the
        # squeezed 2D shape, not from e.g. a (1, H, W) raw shape.
        img_shape = np.squeeze(img_2).shape
        roi_yxhw = self._determine_roi(img_shape, roi_yxhw, self.truncate_horz_pow2)

        img_1 = self._prepare_image(img_1, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape)
        img_2 = self._prepare_image(img_2, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape)

        cc = self._compute_correlation_fft(img_1, img_2, padding_mode, high_pass=high_pass, low_pass=low_pass)

        # Shift coordinates of the correlation image (FFT ordering), computed
        # from the cropped images that `cc` was built from.
        img_shape = img_2.shape
        cc_vs = np.fft.fftfreq(img_shape[-2], 1 / img_shape[-2])
        cc_hs = np.fft.fftfreq(img_shape[-1], 1 / img_shape[-1])

        (f_vals, fv, fh) = self.extract_peak_region_2d(cc, peak_radius=peak_fit_radius, cc_vs=cc_vs, cc_hs=cc_hs)
        fitted_shifts_vh = self.refine_max_position_2d(f_vals, fv, fh)

        # The horizontal shift between the two opposite images is twice the
        # CoR offset: return half of it.
        return fitted_shifts_vh[-1] / 2.0

    __call__ = find_shift
511
512
513
514
515
516
517
518


class DetectorTranslationAlongBeam(AlignmentBase):
    @staticmethod
    def _check_img_sizes(img_stack: np.ndarray, img_pos: np.ndarray):
        """Checks that `img_stack` is a 3D stack and `img_pos` a matching 1D array
        (both after squeezing).

        Raises
        ------
        ValueError
            If the shapes are not consistent.
        """
        shape_stack = np.squeeze(img_stack).shape
        shape_pos = np.squeeze(img_pos).shape
        if len(shape_stack) != 3:
            raise ValueError(
                "A stack of 2-dimensional images is required. Shape of stack: %s" % (" ".join(("%d" % x for x in shape_stack)))
            )
        if len(shape_pos) != 1:
            raise ValueError(
                "Positions need to be a 1-dimensional array. Shape of the positions variable: %s"
                % (" ".join(("%d" % x for x in shape_pos)))
            )
        if shape_stack[0] != shape_pos[0]:
            raise ValueError(
                "The same number of images and positions is required."
                + " Shape of stack: %s, shape of positions variable: %s"
                % (" ".join(("%d" % x for x in shape_stack)), " ".join(("%d" % x for x in shape_pos)),)
            )

    def find_shift(
        self,
        img_stack: np.ndarray,
        img_pos: np.ndarray,
        roi_yxhw=None,
        median_filt_shape=None,
        padding_mode=None,
        peak_fit_radius=1,
        high_pass=None,
        low_pass=None,
        equispaced_increments=True,
        return_shifts=False,
        use_adjacent_imgs=False,
    ):
        """Find the deviation of the translation axis of the area detector
        along the beam propagation direction.

        Each image of the stack (after the first) is correlated with the first
        image (or with the previous one, see `use_adjacent_imgs`). The measured
        (vertical, horizontal) shifts are then linearly fitted against the
        position increments w.r.t. the first image, and the slopes are returned.

        TODO: Add more information here! Including interpretation of the result.
        This means giving also an example on how to convert the returned values
        into meaningful quantities. See "Returns" for more details.

        Parameters
        ----------
        img_stack: numpy.ndarray
            A stack of images (usually 4) at different distances
        img_pos: numpy.ndarray
            Position of the images along the translation axis
        roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional
            4 elements vector containing: vertical and horizontal coordinates
            of first pixel, plus height and width of the Region of Interest (RoI).
            Or a 2 elements vector containing the height and width of the
            centered Region of Interest (RoI).
            Default is None -> a centered RoI whose height is truncated to a power of 2.
        median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional
            Shape of the median filter window. Default is None -> deactivated.
        padding_mode: str in numpy.pad's mode list, optional
            Padding mode, which determines the type of convolution. If None or
            'wrap' are passed, this resorts to the traditional circular convolution.
            If 'edge' or 'constant' are passed, it results in a linear convolution.
            Default is the circular convolution.
            All options are:
                None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean'
                | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap'
        peak_fit_radius: int, optional
            Radius size around the max correlation pixel, for sub-pixel fitting.
            Minimum and default value is 1.
        low_pass: float or sequence of two floats
            Low-pass filter properties, as described in `nabu.misc.fourier_filters`.
        high_pass: float or sequence of two floats
            High-pass filter properties, as described in `nabu.misc.fourier_filters`.
        equispaced_increments: boolean, optional
            Currently unused; kept for API compatibility.
        return_shifts: boolean, optional
            If True, also return the (num_imgs - 1, 2) array of (vertical,
            horizontal) shifts measured for each image w.r.t. the first one.
            This is intended for introspection.
        use_adjacent_imgs: boolean, optional
            Compute correlation between adjacent images, and cumulate the
            resulting shifts. It is better when dealing with large shifts.

        Returns
        -------
        tuple(float, float) or tuple(float, float, numpy.ndarray)
            Estimated (vertical, horizontal) increment per unit-distance of the
            ratio between pixel-size and detector translation, i.e. the slopes
            of the linear fits of the measured shifts (in pixels) against the
            position increments. When `return_shifts` is True, the measured
            shifts are appended.

        Examples
        --------
        TODO: Add examples here!
        """
        self._check_img_sizes(img_stack, img_pos)

        if peak_fit_radius < 1:
            logging.getLogger(__name__).warning(
                "Parameter peak_fit_radius should be at least 1, given: %d instead." % peak_fit_radius
            )
            peak_fit_radius = 1

        # `_check_img_sizes` tolerates singleton dimensions: work on squeezed data
        img_pos = np.squeeze(img_pos)
        img_shape = np.squeeze(img_stack).shape[-2:]
        roi_yxhw = self._determine_roi(img_shape, roi_yxhw, False)

        img_stack = self._prepare_image(img_stack, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape)
        num_imgs = img_stack.shape[0]

        # The correlation images have the (cropped) shape of the prepared
        # stack: the shift coordinates must be computed from it, otherwise the
        # negative shifts get wrong coordinates whenever a RoI is applied.
        cc_shape = img_stack.shape[-2:]
        cc_vs = np.fft.fftfreq(cc_shape[-2], 1 / cc_shape[-2])
        cc_hs = np.fft.fftfreq(cc_shape[-1], 1 / cc_shape[-1])

        # do correlations
        ccs = [
            self._compute_correlation_fft(
                img_stack[ii - 1 if use_adjacent_imgs else 0, ...],
                img_stack[ii, ...],
                padding_mode,
                high_pass=high_pass,
                low_pass=low_pass,
            )
            for ii in range(1, num_imgs)
        ]

        shifts_vh = np.empty((num_imgs - 1, 2))
        for ii, cc in enumerate(ccs):
            (f_vals, fv, fh) = self.extract_peak_region_2d(cc, peak_radius=peak_fit_radius, cc_vs=cc_vs, cc_hs=cc_hs)
            shifts_vh[ii, :] = self.refine_max_position_2d(f_vals, fv, fh)

        if use_adjacent_imgs:
            # relative shifts between neighbors -> shifts w.r.t. the first image
            shifts_vh = np.cumsum(shifts_vh, axis=0)
        img_pos_increments = img_pos[1:] - img_pos[0]

        # Polynomial.fit is supposed to be more numerically stable than polyfit
        # (according to numpy).
        coeffs_v = Polynomial.fit(img_pos_increments, shifts_vh[:, 0], deg=1).convert().coef
        coeffs_h = Polynomial.fit(img_pos_increments, shifts_vh[:, 1], deg=1).convert().coef
        # `convert` trims exact trailing zero coefficients: an exactly null
        # slope (e.g. identical images) would otherwise leave no coef[1].
        coeffs_v = np.pad(coeffs_v, (0, 2 - coeffs_v.size))
        coeffs_h = np.pad(coeffs_h, (0, 2 - coeffs_h.size))

        if return_shifts:
            return coeffs_v[1], coeffs_h[1], shifts_vh
        return coeffs_v[1], coeffs_h[1]