alignment.py 24.1 KB
Newer Older
Nicola Vigano's avatar
Nicola Vigano committed
1
import numpy as np
2

3
import logging
4
from numpy.polynomial.polynomial import Polynomial
5

6
from nabu.utils import previouspow2
7
from nabu.misc import fourier_filters
8

9
10
# Median filter selection: prefer SciPy's N-dimensional implementation and fall
# back to the 2D-only one from silx. `__have_scipy__` records which one is bound
# to `median_filter`, because the two take different keyword arguments
# (`size=` for SciPy, `kernel_size=` for silx).
try:
    # `scipy.ndimage.filters` is a deprecated namespace; the function is
    # exposed directly by `scipy.ndimage`.
    from scipy.ndimage import median_filter

    __have_scipy__ = True
except ImportError:
    from silx.math.medianfilter import medfilt2d as median_filter

    __have_scipy__ = False
17

myron's avatar
myron committed
18

myron's avatar
myron committed
19
20
class AlignmentBase(object):
    """Shared building blocks for alignment by FFT cross-correlation.

    It groups static helpers for: sub-pixel refinement of a sampled maximum
    (parabolic fits), extraction of the region around a correlation peak,
    region-of-interest handling, image pre-processing, and the computation of
    the cross-correlation in Fourier space.
    """

    @staticmethod
    def refine_max_position_2d(f_vals: np.ndarray, fy=None, fx=None):
        """Computes the sub-pixel max position of the given function sampling.

        A 2D quadratic polynomial is least-squares fitted to the samples, and
        the position of its vertex is returned.

        Parameters
        ----------
        f_vals: numpy.ndarray
            Function values of the sampled points
        fy: numpy.ndarray, optional
            Vertical coordinates of the sampled points.
            Default is None -> coordinates centered on 0, with unit spacing.
        fx: numpy.ndarray, optional
            Horizontal coordinates of the sampled points.
            Default is None -> coordinates centered on 0, with unit spacing.

        Raises
        ------
        ValueError
            In case position and values do not have the same size, or in case
            the fitted maximum is outside the fitting region.

        Returns
        -------
        tuple(float, float)
            Estimated (vertical, horizontal) function max, according to the
            coordinates in fy and fx.
        """
        # Anything but a 2D array cannot be fitted below (shape[1] is needed)
        if f_vals.ndim != 2:
            raise ValueError(
                "The fitted values should form a 2-dimensional array. Array of shape: [%s] was given."
                % (" ".join(("%d" % s for s in f_vals.shape)))
            )
        if fy is None:
            fy_half_size = (f_vals.shape[0] - 1) / 2
            fy = np.linspace(-fy_half_size, fy_half_size, f_vals.shape[0])
        elif not (fy.ndim == 1 and fy.size == f_vals.shape[0]):
            raise ValueError(
                "Vertical coordinates should have the same length as values matrix. Sizes of fy: %d, f_vals: [%s]"
                % (fy.size, " ".join(("%d" % s for s in f_vals.shape)))
            )
        if fx is None:
            fx_half_size = (f_vals.shape[1] - 1) / 2
            fx = np.linspace(-fx_half_size, fx_half_size, f_vals.shape[1])
        elif not (fx.ndim == 1 and fx.size == f_vals.shape[1]):
            raise ValueError(
                "Horizontal coordinates should have the same length as values matrix. Sizes of fx: %d, f_vals: [%s]"
                % (fx.size, " ".join(("%d" % s for s in f_vals.shape)))
            )

        fy, fx = np.meshgrid(fy, fx, indexing="ij")
        fy = fy.flatten()
        fx = fx.flatten()
        # Basis of the fit: f(y, x) = c0 + c1 y + c2 x + c3 yx + c4 y^2 + c5 x^2
        coords = np.array([np.ones(f_vals.size), fy, fx, fy * fx, fy ** 2, fx ** 2])

        coeffs = np.linalg.lstsq(coords.T, f_vals.flatten(), rcond=None)[0]

        # For a 1D parabola `f(x) = ax^2 + bx + c`, the vertex position is:
        # x_v = -b / 2a. For a 2D parabola, the vertex position is:
        # (y, x)_v = - b / A, where:
        A = [[2 * coeffs[4], coeffs[3]], [coeffs[3], 2 * coeffs[5]]]
        b = coeffs[1:3]
        vertex_yx = np.linalg.lstsq(A, -b, rcond=None)[0]

        vertex_min_yx = [np.min(fy), np.min(fx)]
        vertex_max_yx = [np.max(fy), np.max(fx)]
        if np.any(vertex_yx < vertex_min_yx) or np.any(vertex_yx > vertex_max_yx):
            raise ValueError(
                "Fitted (y: {}, x: {}) positions are outside the input margins y: [{}, {}], and x: [{}, {}]".format(
                    vertex_yx[0], vertex_yx[1], vertex_min_yx[0], vertex_max_yx[0], vertex_min_yx[1], vertex_max_yx[1]
                )
            )
        return vertex_yx

    @staticmethod
    def refine_max_position_1d(f_vals, fx=None):
        """Computes the sub-pixel max position of the given function sampling.

        A parabola is fitted to the samples, and the position of its vertex is
        returned. A 2-dimensional `f_vals` is treated as a collection of
        curves sampled along its first axis (one curve per column).

        Parameters
        ----------
        f_vals: numpy.ndarray
            Function values of the sampled points
        fx: numpy.ndarray, optional
            Coordinates of the sampled points.
            Default is None -> coordinates centered on 0, with unit spacing.

        Raises
        ------
        ValueError
            In case position and values do not have the same size, or in case
            the fitted maximum is outside the fitting region.

        Returns
        -------
        float
            Estimated function max, according to the coordinates in fx.
            One value per curve, when `f_vals` is 2-dimensional.
        """
        if f_vals.ndim not in (1, 2):
            raise ValueError(
                "The fitted values should be either one or a collection of 1-dimensional arrays. Array of shape: [%s] was given."
                % (" ".join(("%d" % s for s in f_vals.shape)))
            )
        num_vals = f_vals.shape[0]

        if fx is None:
            fx_half_size = (num_vals - 1) / 2
            fx = np.linspace(-fx_half_size, fx_half_size, num_vals)
        elif not (fx.ndim == 1 and fx.size == num_vals):
            raise ValueError(
                "Base coordinates should have the same length as values array. Sizes of fx: %d, f_vals: %d"
                % (fx.size, num_vals)
            )

        if f_vals.ndim == 1:
            # using Polynomial.fit, because supposed to be more numerically
            # stable than previous solutions (according to numpy).
            poly = Polynomial.fit(fx, f_vals, deg=2)
            coeffs = poly.convert().coef
        else:
            coords = np.array([np.ones(num_vals), fx, fx ** 2])
            coeffs = np.linalg.lstsq(coords.T, f_vals, rcond=None)[0]

        # For a 1D parabola `f(x) = c + bx + ax^2`, the vertex position is:
        # x_v = -b / 2a.
        # The ellipsis lets the same expression serve both the single curve
        # (1D coefficients) and the collection of curves (2D coefficients).
        vertex_x = -coeffs[1, ...] / (2 * coeffs[2, ...])

        vertex_min_x = np.min(fx)
        vertex_max_x = np.max(fx)
        lower_bound_ok = vertex_min_x < vertex_x
        upper_bound_ok = vertex_x < vertex_max_x
        if not np.all(np.logical_and(lower_bound_ok, upper_bound_ok)):
            if f_vals.ndim == 1:
                message = "Fitted position {} is outside the input margins [{}, {}]".format(
                    vertex_x, vertex_min_x, vertex_max_x
                )
            else:
                message = "Fitted positions outside the input margins [{}, {}]: {} below and {} above".format(
                    vertex_min_x,
                    vertex_max_x,
                    np.count_nonzero(np.logical_not(lower_bound_ok)),
                    np.count_nonzero(np.logical_not(upper_bound_ok)),
                )
            raise ValueError(message)
        return vertex_x

    @staticmethod
    def extract_peak_region_2d(cc, peak_radius=1, cc_vs=None, cc_hs=None):
        """
        Extracts a region around the maximum value.

        The region wraps around the borders of `cc` (periodic indexing).

        Parameters
        ----------
        cc: numpy.ndarray
            Correlation image.
        peak_radius: int, optional
            The l_inf radius of the area to extract around the peak. The default is 1.
        cc_vs: numpy.ndarray, optional
            The vertical coordinates of `cc`. The default is None.
        cc_hs: numpy.ndarray, optional
            The horizontal coordinates of `cc`. The default is None.

        Returns
        -------
        f_vals: numpy.ndarray
            The extracted function values.
        fv: numpy.ndarray
            The vertical coordinates of the extracted values (None if `cc_vs` is None).
        fh: numpy.ndarray
            The horizontal coordinates of the extracted values (None if `cc_hs` is None).
        """
        img_shape = np.array(cc.shape)
        # get pixel having the maximum value of the correlation array
        pix_max_corr = np.argmax(cc)
        pv, ph = np.unravel_index(pix_max_corr, img_shape)

        # select a n x n neighborhood for the sub-pixel fitting (with wrapping)
        pv = np.arange(pv - peak_radius, pv + peak_radius + 1) % img_shape[-2]
        ph = np.arange(ph - peak_radius, ph + peak_radius + 1) % img_shape[-1]

        # extract the (v, h) pixel coordinates
        fv = None if cc_vs is None else cc_vs[pv]
        fh = None if cc_hs is None else cc_hs[ph]

        # extract the correlation values
        pv, ph = np.meshgrid(pv, ph, indexing="ij")
        f_vals = cc[pv, ph]

        return (f_vals, fv, fh)

    @staticmethod
    def _determine_roi(img_shape, roi_yxhw, do_truncate_horz_pow2):
        """Returns the region of interest to use for the given image shape.

        Parameters
        ----------
        img_shape: sequence of int
            Shape (height, width) of the images.
        roi_yxhw: sequence of int, or None
            Either None, a 2 elements (height, width) centered RoI, or a
            4 elements (y, x, height, width) RoI. The latter is returned as is.
        do_truncate_horz_pow2: boolean
            Whether the default RoI also truncates the width to a power of 2.

        Returns
        -------
        sequence of int
            The RoI, as (y, x, height, width) when it was None or had 2 elements.
        """
        if roi_yxhw is None:
            # vertical window size is reduced to a power of 2 to accelerate fft
            # same thing horizontal window - if requested. Default is not
            roi_yxhw = previouspow2(img_shape)
            if not do_truncate_horz_pow2:
                roi_yxhw[1] = img_shape[1]

        if len(roi_yxhw) == 2:  # Convert centered 2-element roi into 4-element
            # `np.int` was removed from NumPy: the builtin `int` is equivalent
            roi_yxhw = np.array(roi_yxhw, dtype=int)
            roi_yxhw = np.concatenate(((np.asarray(img_shape) - roi_yxhw) // 2, roi_yxhw))
        return roi_yxhw

    @staticmethod
    def _prepare_image(img, invalid_val=1e-5, roi_yxhw=None, median_filt_shape=None, low_pass=None, high_pass=None):
        """
        Prepare and returns  a cropped  and filtered image, or array of filtered images if the input is an  array of images.

        The input array is never modified: the processing happens on a copy.

        Parameters
        ----------
        img: numpy.ndarray
            image or stack of images
        invalid_val: float
            value to be used in replacement of nan and inf values
        roi_yxhw: sequence of 4 int, optional
            (y, x, height, width) region to crop from the last two dimensions.
            Default is None -> no cropping.
        median_filt_shape: int or sequence of int
            the width or the widths of the median window
        low_pass: float or sequence of two floats
            Low-pass filter properties, as described in `nabu.misc.fourier_filters`
        high_pass: float or sequence of two floats
            High-pass filter properties, as described in `nabu.misc.fourier_filters`

        Returns
        -------
        numpy.array_like
            The cropped and filtered image (or stack of images)
        """
        img = np.squeeze(img)  # Removes singleton dimensions, but does a shallow copy
        img = np.ascontiguousarray(img)

        if roi_yxhw is not None:
            img = img[..., roi_yxhw[0] : roi_yxhw[0] + roi_yxhw[2], roi_yxhw[1] : roi_yxhw[1] + roi_yxhw[3]]

        # From here on we modify values: work on a copy of the caller's data
        img = img.copy()

        # Replaces both nan and +/-inf
        img[np.logical_not(np.isfinite(img))] = invalid_val

        if high_pass is not None or low_pass is not None:
            img_filter = fourier_filters.get_bandpass_filter(
                img.shape[-2:], cutoff_lowpass=low_pass, cutoff_highpass=high_pass
            )
            # fft2 and iff2 use axes=(-2, -1) by default
            img = np.fft.ifft2(np.fft.fft2(img) * img_filter).real

        if median_filt_shape is not None:
            img_shape = img.shape
            if __have_scipy__:
                filt_shape = np.atleast_1d(median_filt_shape)
                if np.ndim(median_filt_shape) == 0:
                    # a single width means a square window over each image
                    filt_shape = np.repeat(filt_shape, 2)
                # expanding filter shape with ones, to cover the stack of images
                # but disabling inter-image filtering
                leading_ones = np.ones((len(img_shape) - len(filt_shape),), dtype=int)
                img = median_filter(img, size=np.concatenate((leading_ones, filt_shape)))
            else:
                if len(img_shape) == 2:
                    img = median_filter(img, kernel_size=median_filt_shape)
                elif len(img_shape) > 2:
                    # if dealing with a stack of images, we have to do them one by one
                    img = np.reshape(img, (-1,) + tuple(img_shape[-2:]))
                    for ii in range(img.shape[0]):
                        img[ii, ...] = median_filter(img[ii, ...], kernel_size=median_filt_shape)
                    img = np.reshape(img, img_shape)

        return img

    @staticmethod
    def _compute_correlation_fft(img_1, img_2, padding_mode, axes=(-2, -1), low_pass=None, high_pass=None):
        """Computes the cross-correlation of two images in Fourier space.

        The correlation is the real part of the inverse FFT of
        `fft(img_1) * conj(fft(img_2))`, optionally multiplied by a band-pass
        filter.

        Parameters
        ----------
        img_1: numpy.ndarray
            First image
        img_2: numpy.ndarray
            Second image
        padding_mode: str in numpy.pad's mode list, or None
            None or 'wrap' result in a circular correlation. With any other
            mode the images are padded by half their size on each side of the
            given axes, and the result is cropped back to the input size.
        axes: sequence of int, optional
            Axes along which the FFTs are computed. Default is (-2, -1).
        low_pass: float or sequence of two floats
            Low-pass filter properties, as described in `nabu.misc.fourier_filters`
        high_pass: float or sequence of two floats
            High-pass filter properties, as described in `nabu.misc.fourier_filters`

        Returns
        -------
        numpy.ndarray
            The cross-correlation, with the zero shift at index 0 of each axis.
        """
        do_circular_conv = padding_mode is None or padding_mode == "wrap"
        if not do_circular_conv:
            img_shape = img_2.shape
            # `np.int` was removed from NumPy: the builtin `int` is equivalent
            pad_size = np.ceil(np.array(img_shape) / 2).astype(int)
            pad_array = [(0,)] * len(img_shape)
            for a in axes:
                pad_array[a] = (pad_size[a],)

            img_1 = np.pad(img_1, pad_array, mode=padding_mode)
            img_2 = np.pad(img_2, pad_array, mode=padding_mode)

        # compute fft's of the 2 images
        img_fft_1 = np.fft.fftn(img_1, axes=axes)
        img_fft_2 = np.conjugate(np.fft.fftn(img_2, axes=axes))

        img_prod = img_fft_1 * img_fft_2

        if low_pass is not None or high_pass is not None:
            filt = fourier_filters.get_bandpass_filter(
                img_prod.shape[-2:], cutoff_lowpass=low_pass, cutoff_highpass=high_pass
            )
            img_prod *= filt

        # inverse fft of the product to get cross_correlation of the 2 images
        cc = np.real(np.fft.ifftn(img_prod, axes=axes))

        if not do_circular_conv:
            # Center the zero shift, remove the padding around it, and move
            # the zero shift back to index 0
            cc = np.fft.fftshift(cc, axes=axes)

            slicing = [slice(None)] * len(img_shape)
            for a in axes:
                slicing[a] = slice(pad_size[a], cc.shape[a] - pad_size[a])
            cc = cc[tuple(slicing)]

            cc = np.fft.ifftshift(cc, axes=axes)

        return cc

318
319

class CenterOfRotation(AlignmentBase):
    """Center of Rotation (CoR) estimation from two opposite radios."""

    def __init__(self, horz_fft_width=False):
        """
        Center of Rotation (CoR) computation object.
        This class is used on radios.

        Parameters
        ----------
        horz_fft_width: boolean, optional
            If True, restrict the horizontal size to a power of 2:

            >>> new_x_dim = 2 ** math.floor(math.log2(x_dim))
        """
        self._init_parameters(horz_fft_width)

    def _init_parameters(self, horz_fft_width):
        """Stores whether the default RoI truncates the width to a power of 2."""
        self.truncate_horz_pow2 = horz_fft_width

    @staticmethod
    def _check_img_sizes(img_1: np.ndarray, img_2: np.ndarray):
        """Checks that the two images are 2-dimensional and of the same shape.

        Singleton dimensions are ignored.

        Raises
        ------
        ValueError
            In case images are not 2-dimensional or have different sizes.
        """
        shape_1 = np.squeeze(img_1).shape
        shape_2 = np.squeeze(img_2).shape
        if not len(shape_1) == 2:
            raise ValueError(
                "Images need to be 2-dimensional. Shape of image #1: %s" % (" ".join(("%d" % x for x in shape_1)))
            )
        if not len(shape_2) == 2:
            raise ValueError(
                "Images need to be 2-dimensional. Shape of image #2: %s" % (" ".join(("%d" % x for x in shape_2)))
            )
        if shape_1 != shape_2:
            raise ValueError(
                "Images need to be of the same shape. Shape of image #1: %s, image #2: %s"
                % (" ".join(("%d" % x for x in shape_1)), " ".join(("%d" % x for x in shape_2)))
            )

    def find_shift(
        self,
        img_1: np.ndarray,
        img_2: np.ndarray,
        roi_yxhw=None,
        median_filt_shape=None,
        padding_mode=None,
        peak_fit_radius=1,
        high_pass=None,
        low_pass=None,
    ):
        """Find the Center of Rotation (CoR), given two images.

        This method finds the half-shift between two opposite images, by
        means of correlation computed in Fourier space.

        The output of this function, allows to compute motor movements for
        aligning the sample rotation axis. Given the following values:

        - L1: distance from source to motor
        - L2: distance from source to detector
        - ps: physical pixel size
        - v: output of this function

        displacement of motor = (L1 / L2 * ps) * v

        Parameters
        ----------
        img_1: numpy.ndarray
            First image
        img_2: numpy.ndarray
            Second image, it needs to have been flipped already (e.g. using numpy.fliplr).
        roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional
            4 elements vector containing: vertical and horizontal coordinates
            of first pixel, plus height and width of the Region of Interest (RoI).
            Or a 2 elements vector containing: plus height and width of the
            centered Region of Interest (RoI).
            Default is None -> deactivated.
        median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional
            Shape of the median filter window. Default is None -> deactivated.
        padding_mode: str in numpy.pad's mode list, optional
            Padding mode, which determines the type of convolution. If None or
            'wrap' are passed, this resorts to the traditional circular convolution.
            If 'edge' or 'constant' are passed, it results in a linear convolution.
            Default is the circular convolution.
            All options are:
                None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean'
                | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap'
        peak_fit_radius: int, optional
            Radius size around the max correlation pixel, for sub-pixel fitting.
            Minimum and default value is 1.
        low_pass: float or sequence of two floats
            Low-pass filter properties, as described in `nabu.misc.fourier_filters`
        high_pass: float or sequence of two floats
            High-pass filter properties, as described in `nabu.misc.fourier_filters`

        Raises
        ------
        ValueError
            In case images are not 2-dimensional or have different sizes.

        Returns
        -------
        float
            Estimated center of rotation position from the center of the RoI in pixels.

        Examples
        --------
        The following code computes the center of rotation position for two
        given images in a tomography scan, where the second image is taken at
        180 degrees from the first.

        >>> radio1 = data[0, :, :]
        ... radio2 = np.fliplr(data[1, :, :])
        ... CoR_calc = CenterOfRotation()
        ... cor_position = CoR_calc.find_shift(radio1, radio2)

        Or for noisy images:

        >>> cor_position = CoR_calc.find_shift(radio1, radio2, median_filt_shape=(3, 3))
        """
        self._check_img_sizes(img_1, img_2)

        if peak_fit_radius < 1:
            logging.getLogger(__name__).warning(
                "Parameter peak_fit_radius should be at least 1, given: %d instead." % peak_fit_radius
            )
            peak_fit_radius = 1

        # The validation and the preparation both ignore singleton dimensions:
        # the RoI has to be computed on the same (squeezed) shape.
        img_shape = np.squeeze(img_2).shape
        roi_yxhw = self._determine_roi(img_shape, roi_yxhw, self.truncate_horz_pow2)

        img_1 = self._prepare_image(img_1, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape)
        img_2 = self._prepare_image(img_2, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape)

        cc = self._compute_correlation_fft(img_1, img_2, padding_mode, high_pass=high_pass, low_pass=low_pass)

        # Shifts (in pixels) corresponding to each position of the correlation
        # image. They have to be based on the shape of the cropped images.
        img_shape = img_2.shape
        cc_vs = np.fft.fftfreq(img_shape[-2], 1 / img_shape[-2])
        cc_hs = np.fft.fftfreq(img_shape[-1], 1 / img_shape[-1])

        (f_vals, fv, fh) = self.extract_peak_region_2d(cc, peak_radius=peak_fit_radius, cc_vs=cc_vs, cc_hs=cc_hs)
        fitted_shifts_vh = self.refine_max_position_2d(f_vals, fv, fh)

        # The center of rotation is at half of the horizontal shift
        return fitted_shifts_vh[-1] / 2.0

    __call__ = find_shift
461
462
463
464
465
466
467
468


class DetectorTranslationAlongBeam(AlignmentBase):
    """Estimation of the detector drift along a translation axis, from a stack of images."""

    @staticmethod
    def _check_img_sizes(img_stack: np.ndarray, img_pos: np.ndarray):
        """Checks that there is one position per image of a 3-dimensional stack.

        Singleton dimensions are ignored.

        Raises
        ------
        ValueError
            In case the stack is not 3-dimensional, the positions are not
            1-dimensional, or their numbers do not match.
        """
        shape_stack = np.squeeze(img_stack).shape
        shape_pos = np.squeeze(img_pos).shape
        if not len(shape_stack) == 3:
            raise ValueError(
                "A stack of 2-dimensional images is required. Shape of stack: %s"
                % (" ".join(("%d" % x for x in shape_stack)))
            )
        if not len(shape_pos) == 1:
            raise ValueError(
                "Positions need to be a 1-dimensional array. Shape of the positions variable: %s"
                % (" ".join(("%d" % x for x in shape_pos)))
            )
        if not shape_stack[0] == shape_pos[0]:
            raise ValueError(
                "The same number of images and positions is required."
                + " Shape of stack: %s, shape of positions variable: %s"
                % (" ".join(("%d" % x for x in shape_stack)), " ".join(("%d" % x for x in shape_pos)))
            )

    def find_shift(
        self,
        img_stack: np.ndarray,
        img_pos: np.ndarray,
        roi_yxhw=None,
        median_filt_shape=None,
        padding_mode=None,
        peak_fit_radius=1,
        high_pass=None,
        low_pass=None,
        equispaced_increments=True,
        return_shifts=False,
        use_adjacent_imgs=False,
    ):
        """Find the deviation of the translation axis of the area detector
        along the beam propagation direction.

        The shift of each image with respect to the first one is found by
        correlation in Fourier space, and a line is then fitted to the shifts
        as a function of the image positions.

        TODO: Add more information here! Including interpretation of the result
        This means giving also an example on how to convert the returned values
        into meaningful quantities. See "Returns" for more details.

        Parameters
        ----------
        img_stack: numpy.ndarray
            A stack of images (usually 4) at different distances
        img_pos: numpy.ndarray
            Position of the images along the translation axis
        roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional
            4 elements vector containing: vertical and horizontal coordinates
            of first pixel, plus height and width of the Region of Interest (RoI).
            Or a 2 elements vector containing: plus height and width of the
            centered Region of Interest (RoI).
            Default is None -> deactivated.
        median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional
            Shape of the median filter window. Default is None -> deactivated.
        padding_mode: str in numpy.pad's mode list, optional
            Padding mode, which determines the type of convolution. If None or
            'wrap' are passed, this resorts to the traditional circular convolution.
            If 'edge' or 'constant' are passed, it results in a linear convolution.
            Default is the circular convolution.
            All options are:
                None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean'
                | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap'
        peak_fit_radius: int, optional
            Radius size around the max correlation pixel, for sub-pixel fitting.
            Minimum and default value is 1.
        low_pass: float or sequence of two floats
            Low-pass filter properties, as described in `nabu.misc.fourier_filters`.
        high_pass: float or sequence of two floats
            High-pass filter properties, as described in `nabu.misc.fourier_filters`.
        equispaced_increments: boolean, optional
            Currently not used by this method.
        return_shifts: boolean, optional
            If True, the shifts found for each image (with respect to the
            first one) are returned as well. This is intended for introspection.
        use_adjacent_imgs: boolean, optional
            Compute correlation between adjacent images. It is better when dealing with large shifts.

        Raises
        ------
        ValueError
            In case the stack is not 3-dimensional, or the number of positions
            does not match the number of images.

        Returns
        -------
        tuple(float, float)
            Estimated (vertical, horizontal) increment per unit-distance of the
            ratio between pixel-size and detector translation.
            If `return_shifts` is True, a third element is returned: the array
            of the (vertical, horizontal) shifts, one row per image after the first.

        Examples
        --------
        TODO: Add examples here!
        """
        self._check_img_sizes(img_stack, img_pos)

        if peak_fit_radius < 1:
            logging.getLogger(__name__).warning(
                "Parameter peak_fit_radius should be at least 1, given: %d instead." % peak_fit_radius
            )
            peak_fit_radius = 1

        # The validation ignores singleton dimensions: so does the processing
        img_pos = np.squeeze(img_pos)
        roi_yxhw = self._determine_roi(np.squeeze(img_stack).shape[-2:], roi_yxhw, False)

        img_stack = self._prepare_image(img_stack, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape)

        # The correlations have the shape of the cropped images: the shift
        # coordinates must be computed from it, and not from the input shape.
        num_imgs = img_stack.shape[0]
        img_shape = img_stack.shape[-2:]
        cc_vs = np.fft.fftfreq(img_shape[-2], 1 / img_shape[-2])
        cc_hs = np.fft.fftfreq(img_shape[-1], 1 / img_shape[-1])

        # do correlations
        ccs = [
            self._compute_correlation_fft(
                img_stack[ii - 1 if use_adjacent_imgs else 0, ...],
                img_stack[ii, ...],
                padding_mode,
                high_pass=high_pass,
                low_pass=low_pass,
            )
            for ii in range(1, num_imgs)
        ]

        shifts_vh = np.empty((num_imgs - 1, 2))
        for ii, cc in enumerate(ccs):
            (f_vals, fv, fh) = self.extract_peak_region_2d(cc, peak_radius=peak_fit_radius, cc_vs=cc_vs, cc_hs=cc_hs)
            shifts_vh[ii, :] = self.refine_max_position_2d(f_vals, fv, fh)

        if use_adjacent_imgs:
            # Shifts between adjacent images add up to shifts from the first one
            shifts_vh = np.cumsum(shifts_vh, axis=0)
        img_pos_increments = img_pos[1:] - img_pos[0]

        # Polynomial.fit is supposed to be more numerically stable than polyfit
        # (according to numpy)
        coeffs_v = Polynomial.fit(img_pos_increments, shifts_vh[:, 0], deg=1).convert().coef
        coeffs_h = Polynomial.fit(img_pos_increments, shifts_vh[:, 1], deg=1).convert().coef

        # The slopes of the fitted lines are the shifts per unit-distance
        if return_shifts:
            return coeffs_v[1], coeffs_h[1], shifts_vh
        else:
            return coeffs_v[1], coeffs_h[1]