alignment.py 24.6 KB
Newer Older
Nicola Vigano's avatar
Nicola Vigano committed
1
import numpy as np
2

3
import logging
4
from numpy.polynomial.polynomial import Polynomial
5

6
from nabu.utils import previouspow2
7
from nabu.misc import fourier_filters
8

9
10
try:
    from scipy.ndimage.filters import median_filter
11

12
13
14
    __have_scipy__ = True
except ImportError:
    from silx.math.medianfilter import medfilt2d as median_filter
15

16
    __have_scipy__ = False
17

myron's avatar
myron committed
18

myron's avatar
myron committed
19
20
class AlignmentBase(object):
    @staticmethod
21
22
    def refine_max_position_2d(f_vals: np.ndarray, fy=None, fx=None):
        """Computes the sub-pixel max position of the given function sampling.
23

24
25
26
27
28
29
30
31
        Parameters
        ----------
        f_vals: numpy.ndarray
            Function values of the sampled points
        fy: numpy.ndarray, optional
            Vertical coordinates of the sampled points
        fx: numpy.ndarray, optional
            Horizontal coordinates of the sampled points
32

33
34
35
36
37
        Raises
        ------
        ValueError
            In case position and values do not have the same size, or in case
            the fitted maximum is outside the fitting region.
38

39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
        Returns
        -------
        tuple(float, float)
            Estimated (vertical, horizontal) function max, according to the
            coordinates in fy and fx.
        """
        if len(f_vals.shape) > 2:
            raise ValueError(
                "The fitted values should form a 2-dimensional array. Array of shape: [%s] was given."
                % (" ".join(("%d" % s for s in f_vals.shape)))
            )
        if fy is None:
            fy_half_size = (f_vals.shape[0] - 1) / 2
            fy = np.linspace(-fy_half_size, fy_half_size, f_vals.shape[0])
        elif not (len(fy.shape) == 1 and np.all(fy.size == f_vals.shape[0])):
            raise ValueError(
                "Vertical coordinates should have the same length as values matrix. Sizes of fy: %d, f_vals: [%s]"
                % (fy.size, " ".join(("%d" % s for s in f_vals.shape)))
            )
        if fx is None:
            fx_half_size = (f_vals.shape[1] - 1) / 2
            fx = np.linspace(-fx_half_size, fx_half_size, f_vals.shape[1])
        elif not (len(fx.shape) == 1 and np.all(fx.size == f_vals.shape[1])):
            raise ValueError(
                "Horizontal coordinates should have the same length as values matrix. Sizes of fx: %d, f_vals: [%s]"
                % (fx.size, " ".join(("%d" % s for s in f_vals.shape)))
            )

        fy, fx = np.meshgrid(fy, fx, indexing="ij")
        fy = fy.flatten()
        fx = fx.flatten()
        coords = np.array([np.ones(f_vals.size), fy, fx, fy * fx, fy ** 2, fx ** 2])
71

72
        coeffs = np.linalg.lstsq(coords.T, f_vals.flatten(), rcond=None)[0]
73

74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
        # For a 1D parabola `f(x) = ax^2 + bx + c`, the vertex position is:
        # x_v = -b / 2a. For a 2D parabola, the vertex position is:
        # (y, x)_v = - b / A, where:
        A = [[2 * coeffs[4], coeffs[3]], [coeffs[3], 2 * coeffs[5]]]
        b = coeffs[1:3]
        vertex_yx = np.linalg.lstsq(A, -b, rcond=None)[0]

        vertex_min_yx = [np.min(fy), np.min(fx)]
        vertex_max_yx = [np.max(fy), np.max(fx)]
        if np.any(vertex_yx < vertex_min_yx) or np.any(vertex_yx > vertex_max_yx):
            raise ValueError(
                "Fitted (y: {}, x: {}) positions are outide the margins of input: y: [{}, {}], x: [{}, {}]".format(
                    vertex_yx[0], vertex_yx[1], vertex_min_yx[0], vertex_max_yx[0], vertex_min_yx[1], vertex_max_yx[1]
                )
            )
        return vertex_yx

    @staticmethod
92
    def refine_max_position_1d(f_vals, fx=None):
93
94
95
96
97
98
        """Computes the sub-pixel max position of the given function sampling.

        Parameters
        ----------
        f_vals: numpy.ndarray
            Function values of the sampled points
99
100
        fx: numpy.ndarray, optional
            Coordinates of the sampled points
101
102
103
104

        Raises
        ------
        ValueError
105
106
            In case position and values do not have the same size, or in case
            the fitted maximum is outside the fitting region.
107
108
109
110
111
112

        Returns
        -------
        float
            Estimated function max, according to the coordinates in fx.
        """
113
        if not len(f_vals.shape) == 1:
114
            raise ValueError(
115
116
117
118
119
120
121
122
123
124
                "The fitted values should form a 1-dimensional array. Array of shape: [%s] was given."
                % (" ".join(("%d" % s for s in f_vals.shape)))
            )
        if fx is None:
            fx_half_size = (f_vals.size - 1) / 2
            fx = np.linspace(-fx_half_size, fx_half_size, f_vals.size)
        elif not (len(fx.shape) == 1 and np.all(fx.size == f_vals.size)):
            raise ValueError(
                "Base coordinates should have the same length as values array. Sizes of fx: %d, f_vals: %d"
                % (fx.size, f_vals.size)
125
126
            )

127
128
129
130
        # using Polynomial.fit, because supposed to be more numerically stable
        # than previous solutions (according to numpy).
        poly = Polynomial.fit(fx, f_vals, deg=2)
        coeffs = poly.convert().coef
131

132
133
134
        # For a 1D parabola `f(x) = c + bx + ax^2`, the vertex position is:
        # x_v = -b / 2a.
        vertex_x = - coeffs[1] / (2 * coeffs[2])
135

136
137
138
139
140
141
142
143
144
        vertex_min_x = np.min(fx)
        vertex_max_x = np.max(fx)
        if not (vertex_min_x < vertex_x < vertex_max_x):
            raise ValueError(
                "Fitted x: {} position is outide the margins of input: x: [{}, {}]".format(
                    vertex_x, vertex_min_x, vertex_max_x
                )
            )
        return vertex_x
145

146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
    @staticmethod
    def extract_peak_region(cc, peak_radius=1, cc_vs=None, cc_hs=None):
        """
        Extracts a region around the maximum value.

        Parameters
        ----------
        cc: numpy.ndarray
            Correlation image.
        peak_radius: int, optional
            The l_inf radius of the area to extract around the peak. The default is 1.
        cc_vs: numpy.ndarray, optional
            The vertical coordinates of `cc`. The default is None.
        cc_hs: numpy.ndarray, optional
            The horizontal coordinates of `cc`. The default is None.

        Returns
        -------
        f_vals: numpy.ndarray
            The extracted function values.
        fv: numpy.ndarray
            The vertical coordinates of the extracted values.
        fh: numpy.ndarray
            The horizontal coordinates of the extracted values.
        """
        img_shape = np.array(cc.shape)
        # get pixel having the maximum value of the correlation array
        pix_max_corr = np.argmax(cc)
        pv, ph = np.unravel_index(pix_max_corr, img_shape)

        # select a n x n neighborhood for the sub-pixel fitting (with wrapping)
        pv = np.arange(pv - peak_radius, pv + peak_radius + 1) % img_shape[-2]
        ph = np.arange(ph - peak_radius, ph + peak_radius + 1) % img_shape[-1]

        # extract the (v, h) pixel coordinates
        fv = None if cc_vs is None else cc_vs[pv]
        fh = None if cc_hs is None else cc_hs[ph]

        # extract the correlation values
        pv, ph = np.meshgrid(pv, ph, indexing="ij")
        f_vals = cc[pv, ph]

        return (f_vals, fv, fh)

190
191
192
193
194
195
196
197
    @staticmethod
    def _determine_roi(img_shape, roi_yxhw, do_truncate_horz_pow2):
        if roi_yxhw is None:
            # vertical window size is reduced to a power of 2 to accelerate fft
            # same thing horizontal window - if requested. Default is not
            roi_yxhw = previouspow2(img_shape)
            if not do_truncate_horz_pow2:
                roi_yxhw[1] = img_shape[1]
198

199
200
201
202
        if len(roi_yxhw) == 2:  # Convert centered 2-element roi into 4-element
            roi_yxhw = np.array(roi_yxhw, dtype=np.int)
            roi_yxhw = np.concatenate(((img_shape - roi_yxhw) // 2, roi_yxhw))
        return roi_yxhw
203

204
    @staticmethod
205
    def _prepare_image(
206
        img, invalid_val=1e-5, roi_yxhw=None, median_filt_shape=None, high_pass=None, low_pass=None
207
    ):
208
209
        """
        Prepare and returns  a cropped  and filtered image, or array of filtered images if the input is an  array of images.
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238

        Parameters
        ----------
        img: numpy.ndarray
            image or stack of images
        invalid_val: float
            value to be used in replacement of nan and inf values
        median_filt_shape: int or sequence of int
            the width or the widths of the median window
        high_pass: float or sequence of two floats
            Position of the cut off in pixels, if a sequence is given the second float expresses the
            width of the transition region which is given as a fraction of the cutoff frequency.
            When only one float is given for this argument a gaussian is applied whose sigma is the
            parameter, and the result is subtracted from 1 to obtain the high pass filter
            When a sequence of two numbers is given then the filter is 1 ( no filtering) above the cutoff
            frequency while a smooth  transition to zero is done for smaller frequency
         low_pass: float or sequence of two floats
            Position of the cut off in pixels, if a sequence is given the second float expresses the
            width of the transition region which is given as a fraction of the cutoff frequency.
            When only one float is given for this argument a gaussian is applied whose sigma is the
            parameter.
            When a sequence of two numbers is given then the filter is 1 ( no filtering) till the cutoff
            frequency and then a smooth erfc transition to zero is done

        Returns
        -------
        numpy.array_like
            The computed filter
        """
239
240
        img = np.squeeze(img)  # Removes singleton dimensions, but does a shallow copy
        img = np.ascontiguousarray(img)
241

242
        if roi_yxhw is not None:
243
            img = img[
244
                ..., roi_yxhw[0]: roi_yxhw[0] + roi_yxhw[2], roi_yxhw[1]: roi_yxhw[1] + roi_yxhw[3],
245
246
            ]

myron's avatar
myron committed
247
        img = img.copy()
248
249
250
251

        img[np.isnan(img)] = invalid_val
        img[np.isinf(img)] = invalid_val

252
        if high_pass is not None or low_pass is not None:
253
            img_filter = np.ones(img.shape[-2:], dtype=img.dtype)
254
            if low_pass is not None:
255
                img_filter[:] *= fourier_filters.get_lowpass_filter(img.shape[-2:], low_pass)
256
            if high_pass is not None:
257
258
259
                img_filter[:] *= fourier_filters.get_highpass_filter(img.shape[-2:], high_pass)
            # fft2 and iff2 use axes=(-2, -1) by default
            img = np.fft.ifft2(np.fft.fft2(img) * img_filter).real
260

261
        if median_filt_shape is not None:
262
263
264
265
            img_shape = img.shape
            if __have_scipy__:
                # expanding filter shape with ones, to cover the stack of images
                # but disabling inter-image filtering
266
                median_filt_shape = np.concatenate(
267
                    (np.ones((len(img_shape) - len(median_filt_shape),), dtype=np.int), median_filt_shape,)
268
                )
269
270
271
272
273
274
                img = median_filter(img, size=median_filt_shape)
            else:
                if len(img_shape) == 2:
                    img = median_filter(img, kernel_size=median_filt_shape)
                elif len(img_shape) > 2:
                    # if dealing with a stack of images, we have to do them one by one
275
                    img = np.reshape(img, (-1,) + tuple(img_shape[-2:]))
276
                    for ii in range(img.shape[0]):
277
                        img[ii, ...] = median_filter(img[ii, ...], kernel_size=median_filt_shape)
278
                    img = np.reshape(img, img_shape)
279
280
281

        return img

282
283
    @staticmethod
    def _compute_correlation_fft(img_1, img_2, padding_mode):
284
        do_circular_conv = padding_mode is None or padding_mode == "wrap"
285
        if not do_circular_conv:
Nicola Vigano's avatar
Nicola Vigano committed
286
            padding = np.ceil(np.array(img_2.shape) / 2).astype(np.int)
287
288
            img_1 = np.pad(img_1, ((padding[0],), (padding[1],)), mode=padding_mode)
            img_2 = np.pad(img_2, ((padding[0],), (padding[1],)), mode=padding_mode)
289
290
291
292
293
294
295
296

        # compute fft's of the 2 images
        img_fft_1 = np.fft.fft2(img_1)
        img_fft_2 = np.conjugate(np.fft.fft2(img_2))
        # inverse fft of the product to get cross_correlation of the 2 images
        cc = np.real(np.fft.ifft2(img_fft_1 * img_fft_2))

        if not do_circular_conv:
297
            cc = np.fft.fftshift(cc, axes=(-2, -1))
298
            cc = cc[padding[0]: -padding[0], padding[1]: -padding[1]]
299
            cc = np.fft.ifftshift(cc, axes=(-2, -1))
300
301
302

        return cc

303
304

class CenterOfRotation(AlignmentBase):
305
    def __init__(self, horz_fft_width=False):
306
307
308
309
310
311
312
313
314
315
316
        """
        Center of Rotation (CoR) computation object.
        This class is used on radios.

        Parameters
        ----------
        horz_fft_width: boolean, optional
            If True, restrict the horizontal size to a power of 2:

            >>> new_x_dim = 2 ** math.floor(math.log2(x_dim))
        """
317
        self._init_parameters(horz_fft_width)
318

319
    def _init_parameters(self, horz_fft_width):
320
321
322
323
324
325
326
        self.truncate_horz_pow2 = horz_fft_width

    @staticmethod
    def _check_img_sizes(img_1: np.ndarray, img_2: np.ndarray):
        shape_1 = np.squeeze(img_1).shape
        shape_2 = np.squeeze(img_2).shape
        if not len(shape_1) == 2:
327
            raise ValueError(
328
                "Images need to be 2-dimensional. Shape of image #1: %s" % (" ".join(("%d" % x for x in shape_1)))
329
            )
330
        if not len(shape_2) == 2:
331
            raise ValueError(
332
                "Images need to be 2-dimensional. Shape of image #2: %s" % (" ".join(("%d" % x for x in shape_2)))
333
            )
334
        if not np.all(shape_1 == shape_2):
335
336
            raise ValueError(
                "Images need to be of the same shape. Shape of image #1: %s, image #2: %s"
337
                % (" ".join(("%d" % x for x in shape_1)), " ".join(("%d" % x for x in shape_2)),)
338
            )
339

340
    def find_shift(
341
342
343
344
345
        self,
        img_1: np.ndarray,
        img_2: np.ndarray,
        roi_yxhw=None,
        median_filt_shape=None,
myron's avatar
myron committed
346
        padding_mode=None,
347
        peak_fit_radius=1,
348
349
        high_pass=None,
        low_pass=None
350
    ):
351
        """Find the Center of Rotation (CoR), given to images.
Nicola Vigano's avatar
Nicola Vigano committed
352

Pierre Paleo's avatar
Fix doc    
Pierre Paleo committed
353
354
        This method finds the half-shift between two opposite images, by
        means of correlation computed in Fourier space.
355

356
357
        The output of this function, allows to compute motor movements for
        aligning the sample rotation axis. Given the following values:
358

359
360
361
362
        - L1: distance from source to motor
        - L2: distance from source to detector
        - ps: physical pixel size
        - v: output of this function
363

Nicola Vigano's avatar
Nicola Vigano committed
364
        displacement of motor = (L1 / L2 * ps) * v
365

366
367
368
369
370
371
        Parameters
        ----------
        img_1: numpy.ndarray
            First image
        img_2: numpy.ndarray
            Second image, it needs to have been flipped already (e.g. using numpy.fliplr).
372
        roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional
373
374
            4 elements vector containing: vertical and horizontal coordinates
            of first pixel, plus height and width of the Region of Interest (RoI).
375
376
            Or a 2 elements vector containing: plus height and width of the
            centered Region of Interest (RoI).
377
378
379
            Default is None -> deactivated.
        median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional
            Shape of the median filter window. Default is None -> deactivated.
380
381
382
383
384
385
386
387
        padding_mode: str in numpy.pad's mode list, optional
            Padding mode, which determines the type of convolution. If None or
            'wrap' are passed, this resorts to the traditional circular convolution.
            If 'edge' or 'constant' are passed, it results in a linear convolution.
            Default is the circular convolution.
            All options are:
                None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean'
                | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap'
388
        peak_fit_radius: int, optional
389
390
            Radius size around the max correlation pixel, for sub-pixel fitting.
            Minimum and default value is 1.
391

392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
        high_pass: float or sequence of two floats
            Position of the cut off in pixels, if a sequence is given the second float expresses the
            width of the transition region which is given as a fraction of the cutoff frequency.
            When only one float is given for this argument a gaussian is applied whose sigma is the
            parameter, and the result is subtracted from 1 to obtain the high pass filter
            When a sequence of two numbers is given then the filter is 1 ( no filtering) above the cutoff
            frequency while a smooth  transition to zero is done for smaller frequency
        low_pass: float or sequence of two floats
            Position of the cut off in pixels, if a sequence is given the second float expresses the
            width of the transition region which is given as a fraction of the cutoff frequency.
            When only one float is given for this argument a gaussian is applied whose sigma is the
            parameter.
            When a sequence of two numbers is given then the filter is 1 ( no filtering) till the cutoff
            frequency and then a smooth erfc transition to zero is done

407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
        Raises
        ------
        ValueError
            In case images are not 2-dimensional or have different sizes.

        Returns
        -------
        float
            Estimated center of rotation position from the center of the RoI in pixels.

        Examples
        --------
        The following code computes the center of rotation position for two
        given images in a tomography scan, where the second image is taken at
        180 degrees from the first.

        >>> radio1 = data[0, :, :]
        ... radio2 = np.fliplr(data[1, :, :])
425
        ... CoR_calc = CenterOfRotation()
426
        ... cor_position = CoR_calc.find_shift(radio1, radio2)
427
428
429

        Or for noisy images:

430
        >>> cor_position = CoR_calc.find_shift(radio1, radio2, median_filt_shape=(3, 3))
431
432
433
        """
        self._check_img_sizes(img_1, img_2)

434
        if peak_fit_radius < 1:
435
436
437
            logging.getLogger(__name__).warning(
                "Parameter peak_fit_radius should be at least 1, given: %d instead." % peak_fit_radius
            )
438
439
            peak_fit_radius = 1

440
        img_shape = img_2.shape
441
        roi_yxhw = self._determine_roi(img_shape, roi_yxhw, self.truncate_horz_pow2)
442

443
444
445
446
        img_1 = self._prepare_image(
            img_1,
            roi_yxhw=roi_yxhw,
            median_filt_shape=median_filt_shape,
447
448
            high_pass=high_pass,
            low_pass=low_pass
449
450
451
452
453
        )
        img_2 = self._prepare_image(
            img_2,
            roi_yxhw=roi_yxhw,
            median_filt_shape=median_filt_shape,
454
455
            high_pass=high_pass,
            low_pass=low_pass
456
        )
457

458
        cc = self._compute_correlation_fft(img_1, img_2, padding_mode)
459
460
461
        img_shape = img_2.shape
        cc_vs = np.fft.fftfreq(img_shape[-2], 1 / img_shape[-2])
        cc_hs = np.fft.fftfreq(img_shape[-1], 1 / img_shape[-1])
462

463
        (f_vals, fv, fh) = self.extract_peak_region(cc, peak_radius=peak_fit_radius, cc_vs=cc_vs, cc_hs=cc_hs)
464
        fitted_shifts_vh = self.refine_max_position_2d(f_vals, fv, fh)
465

466
        return fitted_shifts_vh[-1] / 2.0
467

468
    __call__ = find_shift
469
470
471
472
473
474
475
476


class DetectorTranslationAlongBeam(AlignmentBase):
    @staticmethod
    def _check_img_sizes(img_stack: np.ndarray, img_pos: np.ndarray):
        shape_stack = np.squeeze(img_stack).shape
        shape_pos = np.squeeze(img_pos).shape
        if not len(shape_stack) == 3:
477
            raise ValueError(
478
                "A stack of 2-dimensional images is required. Shape of stack: %s" % (" ".join(("%d" % x for x in shape_stack)))
479
            )
480
        if not len(shape_pos) == 1:
481
482
483
484
            raise ValueError(
                "Positions need to be a 1-dimensional array. Shape of the positions variable: %s"
                % (" ".join(("%d" % x for x in shape_pos)))
            )
485
486
        if not shape_stack[0] == shape_pos[0]:
            raise ValueError(
487
488
                "The same number of images and positions is required."
                + " Shape of stack: %s, shape of positions variable: %s"
489
                % (" ".join(("%d" % x for x in shape_stack)), " ".join(("%d" % x for x in shape_pos)),)
490
            )
491
492

    def find_shift(
493
494
495
496
497
498
        self,
        img_stack: np.ndarray,
        img_pos: np.array,
        roi_yxhw=None,
        median_filt_shape=None,
        padding_mode=None,
Nicola Vigano's avatar
Nicola Vigano committed
499
        peak_fit_radius=1,
500
        equispaced_increments=False
501
502
503
504
505
    ):
        """Find the deviation of the translation axis of the area detector
        along the beam propagation direction.

        TODO: Add more information here! Including interpretation of the result
506
507
        This means giving also an example on how to convert the returned values
        into meaningful quantities. See "Returns" for more details.
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530

        Parameters
        ----------
        img_stack: numpy.ndarray
            A stack of images (usually 4) at different distances
        img_pos: numpy.ndarray
            Position of the images along the translation axis
        roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional
            4 elements vector containing: vertical and horizontal coordinates
            of first pixel, plus height and width of the Region of Interest (RoI).
            Or a 2 elements vector containing: plus height and width of the
            centered Region of Interest (RoI).
            Default is None -> deactivated.
        median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional
            Shape of the median filter window. Default is None -> deactivated.
        padding_mode: str in numpy.pad's mode list, optional
            Padding mode, which determines the type of convolution. If None or
            'wrap' are passed, this resorts to the traditional circular convolution.
            If 'edge' or 'constant' are passed, it results in a linear convolution.
            Default is the circular convolution.
            All options are:
                None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean'
                | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap'
531
        peak_fit_radius: int, optional
532
533
            Radius size around the max correlation pixel, for sub-pixel fitting.
            Minimum and default value is 1.
534
535
536
537
538
539
540
541
542
        equispaced_increments: boolean, optional
            Tells whether the position increments are equispaced or not. If
            equispaced increments are used, we have to compute the correlation
            images with respect to the first image, otherwise we can do it
            against adjacent images.
            The advantage of doing it between adjacent images is that we do not
            build up large displacements in the correlation.
            However, if this is done for equispaced images, the linear fitting
            becomes unstable.
543
544
545
546

        Returns
        -------
        tuple(float, float)
547
548
            Estimated (vertical, horizontal) increment per unit-distance of the
            ratio between pixel-size and detector translation.
549
550
551
552
553
554
555

        Examples
        --------
        TODO: Add examples here!
        """
        self._check_img_sizes(img_stack, img_pos)

556
        if peak_fit_radius < 1:
557
558
559
            logging.getLogger(__name__).warning(
                "Parameter peak_fit_radius should be at least 1, given: %d instead." % peak_fit_radius
            )
560
561
            peak_fit_radius = 1

562
563
564
565
        num_imgs = img_stack.shape[0]
        img_shape = img_stack.shape[-2:]
        roi_yxhw = self._determine_roi(img_shape, roi_yxhw, False)

566
        img_stack = self._prepare_image(img_stack, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape)
567

568
569
        cc_vs = np.fft.fftfreq(img_shape[-2], 1 / img_shape[-2])
        cc_hs = np.fft.fftfreq(img_shape[-1], 1 / img_shape[-1])
570

571
        # do correlations
572
        ccs = [
573
574
575
            self._compute_correlation_fft(
                img_stack[0 if equispaced_increments else ii - 1, ...], img_stack[ii, ...], padding_mode
            )
576
577
            for ii in range(1, num_imgs)
        ]
578

579
        shifts_vh = np.empty((num_imgs - 1, 2))
580
        for ii, cc in enumerate(ccs):
581
            (f_vals, fv, fh) = self.extract_peak_region(cc, peak_fit_radius=peak_fit_radius, cc_vs=cc_vs, cc_hs=cc_hs)
582
            shifts_vh[ii, :] = self.refine_max_position_2d(f_vals, fv, fh)
583

584
585
586
587
588
589
590
591
592
593
594
        if equispaced_increments:
            img_pos_increments = img_pos[1:] - img_pos[0]
        else:
            img_pos_increments = - np.diff(img_pos)

        # Polynomial.fit is supposed to be more numerically stable than polyfit
        # (according to numpy)
        coeffs_v = Polynomial.fit(img_pos_increments, shifts_vh[:, 0], deg=1).convert().coef
        coeffs_h = Polynomial.fit(img_pos_increments, shifts_vh[:, 1], deg=1).convert().coef

        return coeffs_v[1], coeffs_h[1]