alignment.py 28 KB
Newer Older
Nicola Vigano's avatar
Nicola Vigano committed
1
import numpy as np
myron's avatar
myron committed
2
import math
3
import logging
4
from numpy.polynomial.polynomial import Polynomial
5

6
from nabu.utils import previouspow2
7

8
9
try:
    from scipy.ndimage.filters import median_filter
myron's avatar
myron committed
10
    import scipy.special as spspe
11

12
13
14
    __have_scipy__ = True
except ImportError:
    from silx.math.medianfilter import medfilt2d as median_filter
15

16
    __have_scipy__ = False
17

myron's avatar
myron committed
18

19
def _get_lowpass_filter(img_shape, cutoff_par):
myron's avatar
myron committed
20
    """Computes a low pass filter with the erfc.
21

22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
    Parameters
    ----------
    img_shape: tuple
        Shape of the image
    cutoff_par: float or sequence of two floats
        Position of the cut off in pixels, if a sequence is given the second float expresses the
        width of the transition region which is given as a fraction of the cutoff frequency.
        When only one float is given for this argument a gaussian is applied whose sigma is the
        parameter.
        When a sequence of two numbers is given then the filter is 1 ( no filtering) till the cutoff
        frequency while a smooth erfc transition to zero is done

    Raises
    ------
    ValueError
            In case cutoff_par is not properly given

    Returns
    -------
    numpy.array_like
        The computed filter
myron's avatar
myron committed
43
    """
44
45
46
47
48
49
50
51
52
53
54
55
56
    if isinstance(cutoff_par, (int, float)):
        cutoff_pix = cutoff_par
        cutoff_trans_fact = None
    else:
        try:
            cutoff_pix, cutoff_trans_fact = cutoff_par
        except ValueError:
            raise ValueError("Argument cutoff_par  (which specifies the pass filter shape) must be either a scalar or a"
                             " sequence of two scalars")
        if (not isinstance(cutoff_pix, (int, float))) or (not isinstance(cutoff_trans_fact, (int, float))):
            raise ValueError("Argument cutoff_par  (which specifies the pass filter shape) must be  one number or a sequence"
                             "of two numbers")

myron's avatar
myron committed
57
    coords = [np.fft.fftfreq(s, 1) for s in img_shape]
58
    coords = np.meshgrid(*coords, indexing="ij")
myron's avatar
myron committed
59
60

    r = np.sqrt(np.sum(np.array(coords) ** 2, axis=0))
61
62
63
64
65
66
67
68
69

    if cutoff_trans_fact is not None:
        k_cut = 0.5 / cutoff_pix
        k_cut_width = k_cut * cutoff_trans_fact
        k_pos_rescaled = (r - k_cut) / k_cut_width
        if __have_scipy__:
            res = spspe.erfc(k_pos_rescaled) / 2
        else:
            res = np.array(list(map(math.erfc, k_pos_rescaled))) / 2
myron's avatar
myron committed
70
    else:
71
        res = np.exp(- np.pi*np.pi*r*r*cutoff_pix*cutoff_pix*2)
myron's avatar
myron committed
72

73
    return res
myron's avatar
myron committed
74

75

76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
def _get_highpass_filter(img_shape, cutoff_par):
    """Computes a high pass filter with the erfc.

    BUGFIX: the previous signature took a spurious extra `cutoff_pix`
    parameter and forwarded three arguments to the two-argument
    `_get_lowpass_filter`, so every call (e.g. from `_prepare_image`,
    which passes two arguments) raised a TypeError.

    Parameters
    ----------
    img_shape: tuple
        Shape of the image
    cutoff_par: float or sequence of two floats
        Position of the cut off in pixels, if a sequence is given the second float expresses the
        width of the transition region which is given as a fraction of the cutoff frequency.
        When only one float is given for this argument a gaussian is applied whose sigma is the
        parameter, and the result is subtracted from 1 to obtain the high pass filter
        When a sequence of two numbers is given then the filter is 1 ( no filtering) above the cutoff
        frequency and then a smooth  transition to zero is done for smaller frequency

    Returns
    -------
    numpy.array_like
        The computed filter
    """
    # the high-pass filter is the complement of the corresponding low-pass one
    res = 1 - _get_lowpass_filter(img_shape, cutoff_par)
    return res


myron's avatar
myron committed
100
101
class AlignmentBase(object):
    """Base class with the common tools for image-alignment computations:
    sub-pixel peak refinement, peak-region extraction, RoI handling, image
    preparation (cropping / filtering), and FFT-based cross-correlation.
    """

    @staticmethod
    def refine_max_position_2d(f_vals: np.ndarray, fy=None, fx=None):
        """Computes the sub-pixel max position of the given function sampling.

        Parameters
        ----------
        f_vals: numpy.ndarray
            Function values of the sampled points
        fy: numpy.ndarray, optional
            Vertical coordinates of the sampled points
        fx: numpy.ndarray, optional
            Horizontal coordinates of the sampled points

        Raises
        ------
        ValueError
            In case position and values do not have the same size, or in case
            the fitted maximum is outside the fitting region.

        Returns
        -------
        tuple(float, float)
            Estimated (vertical, horizontal) function max, according to the
            coordinates in fy and fx.
        """
        if len(f_vals.shape) > 2:
            raise ValueError(
                "The fitted values should form a 2-dimensional array. Array of shape: [%s] was given."
                % (" ".join(("%d" % s for s in f_vals.shape)))
            )
        if fy is None:
            # default to centered, unit-spaced coordinates
            fy_half_size = (f_vals.shape[0] - 1) / 2
            fy = np.linspace(-fy_half_size, fy_half_size, f_vals.shape[0])
        elif not (len(fy.shape) == 1 and np.all(fy.size == f_vals.shape[0])):
            raise ValueError(
                "Vertical coordinates should have the same length as values matrix. Sizes of fy: %d, f_vals: [%s]"
                % (fy.size, " ".join(("%d" % s for s in f_vals.shape)))
            )
        if fx is None:
            fx_half_size = (f_vals.shape[1] - 1) / 2
            fx = np.linspace(-fx_half_size, fx_half_size, f_vals.shape[1])
        elif not (len(fx.shape) == 1 and np.all(fx.size == f_vals.shape[1])):
            raise ValueError(
                "Horizontal coordinates should have the same length as values matrix. Sizes of fx: %d, f_vals: [%s]"
                % (fx.size, " ".join(("%d" % s for s in f_vals.shape)))
            )

        # least-squares fit of a 2D paraboloid to the sampled values,
        # basis: [1, y, x, xy, y^2, x^2]
        fy, fx = np.meshgrid(fy, fx, indexing="ij")
        fy = fy.flatten()
        fx = fx.flatten()
        coords = np.array([np.ones(f_vals.size), fy, fx, fy * fx, fy ** 2, fx ** 2])

        coeffs = np.linalg.lstsq(coords.T, f_vals.flatten(), rcond=None)[0]

        # For a 1D parabola `f(x) = ax^2 + bx + c`, the vertex position is:
        # x_v = -b / 2a. For a 2D parabola, the vertex position is:
        # (y, x)_v = - b / A, where:
        A = [[2 * coeffs[4], coeffs[3]], [coeffs[3], 2 * coeffs[5]]]
        b = coeffs[1:3]
        vertex_yx = np.linalg.lstsq(A, -b, rcond=None)[0]

        # reject vertices outside the sampled region (the fit extrapolated)
        vertex_min_yx = [np.min(fy), np.min(fx)]
        vertex_max_yx = [np.max(fy), np.max(fx)]
        if np.any(vertex_yx < vertex_min_yx) or np.any(vertex_yx > vertex_max_yx):
            raise ValueError(
                "Fitted (y: {}, x: {}) positions are outside the margins of input: y: [{}, {}], x: [{}, {}]".format(
                    vertex_yx[0], vertex_yx[1], vertex_min_yx[0], vertex_max_yx[0], vertex_min_yx[1], vertex_max_yx[1]
                )
            )
        return vertex_yx

    @staticmethod
    def refine_max_position_1d(f_vals, fx=None):
        """Computes the sub-pixel max position of the given function sampling.

        Parameters
        ----------
        f_vals: numpy.ndarray
            Function values of the sampled points
        fx: numpy.ndarray, optional
            Coordinates of the sampled points

        Raises
        ------
        ValueError
            In case position and values do not have the same size, or in case
            the fitted maximum is outside the fitting region.

        Returns
        -------
        float
            Estimated function max, according to the coordinates in fx.
        """
        if not len(f_vals.shape) == 1:
            raise ValueError(
                "The fitted values should form a 1-dimensional array. Array of shape: [%s] was given."
                % (" ".join(("%d" % s for s in f_vals.shape)))
            )
        if fx is None:
            # default to centered, unit-spaced coordinates
            fx_half_size = (f_vals.size - 1) / 2
            fx = np.linspace(-fx_half_size, fx_half_size, f_vals.size)
        elif not (len(fx.shape) == 1 and np.all(fx.size == f_vals.size)):
            raise ValueError(
                "Base coordinates should have the same length as values array. Sizes of fx: %d, f_vals: %d"
                % (fx.size, f_vals.size)
            )

        # using Polynomial.fit, because supposed to be more numerically stable
        # than previous solutions (according to numpy).
        poly = Polynomial.fit(fx, f_vals, deg=2)
        coeffs = poly.convert().coef

        # For a 1D parabola `f(x) = c + bx + ax^2`, the vertex position is:
        # x_v = -b / 2a.
        vertex_x = - coeffs[1] / (2 * coeffs[2])

        # reject a vertex outside the sampled region (the fit extrapolated)
        vertex_min_x = np.min(fx)
        vertex_max_x = np.max(fx)
        if not (vertex_min_x < vertex_x < vertex_max_x):
            raise ValueError(
                "Fitted x: {} position is outside the margins of input: x: [{}, {}]".format(
                    vertex_x, vertex_min_x, vertex_max_x
                )
            )
        return vertex_x

    @staticmethod
    def extract_peak_region(cc, peak_radius=1, cc_vs=None, cc_hs=None):
        """
        Extracts a region around the maximum value.

        Parameters
        ----------
        cc: numpy.ndarray
            Correlation image.
        peak_radius: int, optional
            The l_inf radius of the area to extract around the peak. The default is 1.
        cc_vs: numpy.ndarray, optional
            The vertical coordinates of `cc`. The default is None.
        cc_hs: numpy.ndarray, optional
            The horizontal coordinates of `cc`. The default is None.

        Returns
        -------
        f_vals: numpy.ndarray
            The extracted function values.
        fv: numpy.ndarray
            The vertical coordinates of the extracted values (None if cc_vs is None).
        fh: numpy.ndarray
            The horizontal coordinates of the extracted values (None if cc_hs is None).
        """
        img_shape = np.array(cc.shape)
        # get pixel having the maximum value of the correlation array
        pix_max_corr = np.argmax(cc)
        pv, ph = np.unravel_index(pix_max_corr, img_shape)

        # select a n x n neighborhood for the sub-pixel fitting (with wrapping)
        pv = np.arange(pv - peak_radius, pv + peak_radius + 1) % img_shape[-2]
        ph = np.arange(ph - peak_radius, ph + peak_radius + 1) % img_shape[-1]

        # extract the (v, h) pixel coordinates
        fv = None if cc_vs is None else cc_vs[pv]
        fh = None if cc_hs is None else cc_hs[ph]

        # extract the correlation values
        pv, ph = np.meshgrid(pv, ph, indexing="ij")
        f_vals = cc[pv, ph]

        return (f_vals, fv, fh)

    @staticmethod
    def _determine_roi(img_shape, roi_yxhw, do_truncate_horz_pow2):
        """Resolve the Region of Interest into the 4-element (y, x, h, w) form.

        If roi_yxhw is None, a default (power-of-2 sized, centered) RoI is
        computed; a 2-element (h, w) RoI is converted into a centered
        4-element one.
        """
        if roi_yxhw is None:
            # vertical window size is reduced to a power of 2 to accelerate fft
            # same thing horizontal window - if requested. Default is not
            roi_yxhw = previouspow2(img_shape)
            if not do_truncate_horz_pow2:
                roi_yxhw[1] = img_shape[1]

        if len(roi_yxhw) == 2:  # Convert centered 2-element roi into 4-element
            # BUGFIX: `np.int` was deprecated in numpy 1.20 and removed in 1.24;
            # the builtin `int` is the documented replacement.
            roi_yxhw = np.array(roi_yxhw, dtype=int)
            roi_yxhw = np.concatenate(((img_shape - roi_yxhw) // 2, roi_yxhw))
        return roi_yxhw

    @staticmethod
    def _prepare_image(
        img, invalid_val=1e-5, roi_yxhw=None, median_filt_shape=None, high_pass=None, low_pass=None
    ):
        """Prepare and returns  a cropped  and filtered image, or array of filtered images if the input is an  array of images.

        Parameters
        ----------
        img: numpy.ndarray
            image or stack of images
        invalid_val: float
            value to be used in replacement of nan and inf values
        roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional
            RoI to crop to, in the format produced by `_determine_roi`.
            Default is None -> no cropping.
        median_filt_shape: int or sequence of int
            the width or the widths of the median window
        high_pass: float or sequence of two floats
            Position of the cut off in pixels, if a sequence is given the second float expresses the
            width of the transition region which is given as a fraction of the cutoff frequency.
            When only one float is given for this argument a gaussian is applied whose sigma is the
            parameter, and the result is subtracted from 1 to obtain the high pass filter
            When a sequence of two numbers is given then the filter is 1 ( no filtering) above the cutoff
            frequency while a smooth  transition to zero is done for smaller frequency
        low_pass: float or sequence of two floats
            Position of the cut off in pixels, if a sequence is given the second float expresses the
            width of the transition region which is given as a fraction of the cutoff frequency.
            When only one float is given for this argument a gaussian is applied whose sigma is the
            parameter.
            When a sequence of two numbers is given then the filter is 1 ( no filtering) till the cutoff
            frequency and then a smooth erfc transition to zero is done

        Returns
        -------
        numpy.array_like
            The prepared image (or stack of images).
        """
        img = np.squeeze(img)  # Removes singleton dimensions, but does a shallow copy
        img = np.ascontiguousarray(img)

        if roi_yxhw is not None:
            img = img[
                ..., roi_yxhw[0]: roi_yxhw[0] + roi_yxhw[2], roi_yxhw[1]: roi_yxhw[1] + roi_yxhw[3],
            ]

        # the slicing above returns a view: copy before modifying in place
        img = img.copy()

        img[np.isnan(img)] = invalid_val
        img[np.isinf(img)] = invalid_val

        if high_pass is not None or low_pass is not None:
            # BUGFIX: the filter has to be 2D (the shape of a single image).
            # `np.ones_like(img)` produced a 3D filter for image stacks, which
            # broke the per-image filtering loop below with a broadcast error.
            myfilter = np.ones(img.shape[-2:])
            if low_pass is not None:
                myfilter[:] *= _get_lowpass_filter(img.shape[-2:], low_pass)
            if high_pass is not None:
                myfilter[:] *= _get_highpass_filter(img.shape[-2:], high_pass)

            img_shape = img.shape
            if len(img_shape) == 2:
                img = np.fft.ifft2(np.fft.fft2(img) * myfilter).real
            elif len(img_shape) > 2:
                # if dealing with a stack of images, we have to do them one by one
                img = np.reshape(img, (-1,) + tuple(img_shape[-2:]))
                for ii in range(img.shape[0]):
                    img[ii, ...] = np.fft.ifft2(np.fft.fft2(img[ii, ...]) * myfilter).real
                img = np.reshape(img, img_shape)

        if median_filt_shape is not None:
            img_shape = img.shape
            if __have_scipy__:
                # expanding filter shape with ones, to cover the stack of images
                # but disabling inter-image filtering
                # BUGFIX: `np.int` removed in numpy >= 1.24, use the builtin int
                median_filt_shape = np.concatenate(
                    (np.ones((len(img_shape) - len(median_filt_shape),), dtype=int), median_filt_shape,)
                )
                img = median_filter(img, size=median_filt_shape)
            else:
                if len(img_shape) == 2:
                    img = median_filter(img, kernel_size=median_filt_shape)
                elif len(img_shape) > 2:
                    # if dealing with a stack of images, we have to do them one by one
                    img = np.reshape(img, (-1,) + tuple(img_shape[-2:]))
                    for ii in range(img.shape[0]):
                        img[ii, ...] = median_filter(img[ii, ...], kernel_size=median_filt_shape)
                    img = np.reshape(img, img_shape)

        return img

    @staticmethod
    def _compute_correlation_fft(img_1, img_2, padding_mode):
        """Compute the cross-correlation of the two images in Fourier space.

        If padding_mode is None or 'wrap', a circular correlation is computed;
        otherwise the images are padded (by half of img_2's size) before the
        FFT to emulate a linear correlation, and the result is cropped back.
        """
        do_circular_conv = padding_mode is None or padding_mode == "wrap"
        if not do_circular_conv:
            # BUGFIX: `np.int` removed in numpy >= 1.24, use the builtin int
            padding = np.ceil(np.array(img_2.shape) / 2).astype(int)
            img_1 = np.pad(img_1, ((padding[0],), (padding[1],)), mode=padding_mode)
            img_2 = np.pad(img_2, ((padding[0],), (padding[1],)), mode=padding_mode)

        # compute fft's of the 2 images
        img_fft_1 = np.fft.fft2(img_1)
        img_fft_2 = np.conjugate(np.fft.fft2(img_2))
        # inverse fft of the product to get cross_correlation of the 2 images
        cc = np.real(np.fft.ifft2(img_fft_1 * img_fft_2))

        if not do_circular_conv:
            # remove the padded margins (in centered representation)
            cc = np.fft.fftshift(cc, axes=(-2, -1))
            cc = cc[padding[0]: -padding[0], padding[1]: -padding[1]]
            cc = np.fft.ifftshift(cc, axes=(-2, -1))

        return cc

392
393

class CenterOfRotation(AlignmentBase):
    def __init__(self, horz_fft_width=False):
        """
        Center of Rotation (CoR) computation object.
        This class is used on radios.

        Parameters
        ----------
        horz_fft_width: boolean, optional
            If True, restrict the horizontal size to a power of 2:

            >>> new_x_dim = 2 ** math.floor(math.log2(x_dim))
        """
        self._init_parameters(horz_fft_width)

    def _init_parameters(self, horz_fft_width):
        # Whether to shrink the horizontal FFT window to a power of two.
        self.truncate_horz_pow2 = horz_fft_width

    @staticmethod
    def _check_img_sizes(img_1: np.ndarray, img_2: np.ndarray):
        """Validate that both inputs are 2-dimensional and share the same shape."""
        def _fmt(shape):
            # render a shape as a space-separated list of integers
            return " ".join(("%d" % x for x in shape))

        shape_1 = np.squeeze(img_1).shape
        shape_2 = np.squeeze(img_2).shape

        if len(shape_1) != 2:
            raise ValueError(
                "Images need to be 2-dimensional. Shape of image #1: %s" % _fmt(shape_1)
            )
        if len(shape_2) != 2:
            raise ValueError(
                "Images need to be 2-dimensional. Shape of image #2: %s" % _fmt(shape_2)
            )
        if not np.all(shape_1 == shape_2):
            raise ValueError(
                "Images need to be of the same shape. Shape of image #1: %s, image #2: %s"
                % (_fmt(shape_1), _fmt(shape_2))
            )

    def find_shift(
        self,
        img_1: np.ndarray,
        img_2: np.ndarray,
        roi_yxhw=None,
        median_filt_shape=None,
        padding_mode=None,
        peak_fit_radius=1,
        high_pass=None,
        low_pass=None
    ):
        """Find the Center of Rotation (CoR), given to images.

        This method finds the half-shift between two opposite images, by
        means of correlation computed in Fourier space.

        The output of this function, allows to compute motor movements for
        aligning the sample rotation axis. Given the following values:

        - L1: distance from source to motor
        - L2: distance from source to detector
        - ps: physical pixel size
        - v: output of this function

        displacement of motor = (L1 / L2 * ps) * v

        Parameters
        ----------
        img_1: numpy.ndarray
            First image
        img_2: numpy.ndarray
            Second image, it needs to have been flipped already (e.g. using numpy.fliplr).
        roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional
            4 elements vector containing: vertical and horizontal coordinates
            of first pixel, plus height and width of the Region of Interest (RoI).
            Or a 2 elements vector containing: plus height and width of the
            centered Region of Interest (RoI).
            Default is None -> deactivated.
        median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional
            Shape of the median filter window. Default is None -> deactivated.
        padding_mode: str in numpy.pad's mode list, optional
            Padding mode, which determines the type of convolution. If None or
            'wrap' are passed, this resorts to the traditional circular convolution.
            If 'edge' or 'constant' are passed, it results in a linear convolution.
            Default is the circular convolution.
            All options are:
                None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean'
                | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap'
        peak_fit_radius: int, optional
            Radius size around the max correlation pixel, for sub-pixel fitting.
            Minimum and default value is 1.
        high_pass: float or sequence of two floats
            Position of the cut off in pixels, if a sequence is given the second float expresses the
            width of the transition region which is given as a fraction of the cutoff frequency.
            When only one float is given for this argument a gaussian is applied whose sigma is the
            parameter, and the result is subtracted from 1 to obtain the high pass filter
            When a sequence of two numbers is given then the filter is 1 ( no filtering) above the cutoff
            frequency while a smooth  transition to zero is done for smaller frequency
        low_pass: float or sequence of two floats
            Position of the cut off in pixels, if a sequence is given the second float expresses the
            width of the transition region which is given as a fraction of the cutoff frequency.
            When only one float is given for this argument a gaussian is applied whose sigma is the
            parameter.
            When a sequence of two numbers is given then the filter is 1 ( no filtering) till the cutoff
            frequency and then a smooth erfc transition to zero is done

        Raises
        ------
        ValueError
            In case images are not 2-dimensional or have different sizes.

        Returns
        -------
        float
            Estimated center of rotation position from the center of the RoI in pixels.

        Examples
        --------
        The following code computes the center of rotation position for two
        given images in a tomography scan, where the second image is taken at
        180 degrees from the first.

        >>> radio1 = data[0, :, :]
        ... radio2 = np.fliplr(data[1, :, :])
        ... CoR_calc = CenterOfRotation()
        ... cor_position = CoR_calc.find_shift(radio1, radio2)

        Or for noisy images:

        >>> cor_position = CoR_calc.find_shift(radio1, radio2, median_filt_shape=(3, 3))
        """
        self._check_img_sizes(img_1, img_2)

        if peak_fit_radius < 1:
            logging.getLogger(__name__).warning(
                "Parameter peak_fit_radius should be at least 1, given: %d instead." % peak_fit_radius
            )
            peak_fit_radius = 1

        roi_yxhw = self._determine_roi(img_2.shape, roi_yxhw, self.truncate_horz_pow2)

        # crop / filter both images with identical settings
        img_1, img_2 = [
            self._prepare_image(
                im,
                roi_yxhw=roi_yxhw,
                median_filt_shape=median_filt_shape,
                high_pass=high_pass,
                low_pass=low_pass,
            )
            for im in (img_1, img_2)
        ]

        cc = self._compute_correlation_fft(img_1, img_2, padding_mode)

        # pixel-coordinate axes of the (possibly cropped) correlation image
        num_v, num_h = img_2.shape[-2], img_2.shape[-1]
        cc_vs = np.fft.fftfreq(num_v, 1 / num_v)
        cc_hs = np.fft.fftfreq(num_h, 1 / num_h)

        f_vals, fv, fh = self.extract_peak_region(cc, peak_radius=peak_fit_radius, cc_vs=cc_vs, cc_hs=cc_hs)
        fitted_shifts_vh = self.refine_max_position_2d(f_vals, fv, fh)

        # the CoR offset is half the horizontal shift between the opposite images
        return fitted_shifts_vh[-1] / 2.0

    __call__ = find_shift
558
559
560
561
562
563
564
565


class DetectorTranslationAlongBeam(AlignmentBase):
    @staticmethod
    def _check_img_sizes(img_stack: np.ndarray, img_pos: np.ndarray):
        """Validate that img_stack is a 3D stack whose length matches img_pos.

        Raises
        ------
        ValueError
            If the stack is not 3-dimensional, the positions are not
            1-dimensional, or their lengths disagree.
        """
        shape_stack = np.squeeze(img_stack).shape
        shape_pos = np.squeeze(img_pos).shape
        if not len(shape_stack) == 3:
            raise ValueError(
                "A stack of 2-dimensional images is required. Shape of stack: %s" % (" ".join(("%d" % x for x in shape_stack)))
            )
        if not len(shape_pos) == 1:
            raise ValueError(
                "Positions need to be a 1-dimensional array. Shape of the positions variable: %s"
                % (" ".join(("%d" % x for x in shape_pos)))
            )
        if not shape_stack[0] == shape_pos[0]:
            raise ValueError(
                "The same number of images and positions is required."
                + " Shape of stack: %s, shape of positions variable: %s"
                % (" ".join(("%d" % x for x in shape_stack)), " ".join(("%d" % x for x in shape_pos)),)
            )

    def find_shift(
        self,
        img_stack: np.ndarray,
        img_pos: np.array,
        roi_yxhw=None,
        median_filt_shape=None,
        padding_mode=None,
        peak_fit_radius=1,
        equispaced_increments=False
    ):
        """Find the deviation of the translation axis of the area detector
        along the beam propagation direction.

        TODO: Add more information here! Including interpretation of the result
        This means giving also an example on how to convert the returned values
        into meaningful quantities. See "Returns" for more details.

        Parameters
        ----------
        img_stack: numpy.ndarray
            A stack of images (usually 4) at different distances
        img_pos: numpy.ndarray
            Position of the images along the translation axis
        roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional
            4 elements vector containing: vertical and horizontal coordinates
            of first pixel, plus height and width of the Region of Interest (RoI).
            Or a 2 elements vector containing: plus height and width of the
            centered Region of Interest (RoI).
            Default is None -> deactivated.
        median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional
            Shape of the median filter window. Default is None -> deactivated.
        padding_mode: str in numpy.pad's mode list, optional
            Padding mode, which determines the type of convolution. If None or
            'wrap' are passed, this resorts to the traditional circular convolution.
            If 'edge' or 'constant' are passed, it results in a linear convolution.
            Default is the circular convolution.
            All options are:
                None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean'
                | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap'
        peak_fit_radius: int, optional
            Radius size around the max correlation pixel, for sub-pixel fitting.
            Minimum and default value is 1.
        equispaced_increments: boolean, optional
            Tells whether the position increments are equispaced or not. If
            equispaced increments are used, we have to compute the correlation
            images with respect to the first image, otherwise we can do it
            against adjacent images.
            The advantage of doing it between adjacent images is that we do not
            build up large displacements in the correlation.
            However, if this is done for equispaced images, the linear fitting
            becomes unstable.

        Returns
        -------
        tuple(float, float)
            Estimated (vertical, horizontal) increment per unit-distance of the
            ratio between pixel-size and detector translation.

        Examples
        --------
        TODO: Add examples here!
        """
        self._check_img_sizes(img_stack, img_pos)

        if peak_fit_radius < 1:
            logging.getLogger(__name__).warning(
                "Parameter peak_fit_radius should be at least 1, given: %d instead." % peak_fit_radius
            )
            peak_fit_radius = 1

        num_imgs = img_stack.shape[0]
        img_shape = img_stack.shape[-2:]
        roi_yxhw = self._determine_roi(img_shape, roi_yxhw, False)

        img_stack = self._prepare_image(img_stack, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape)

        # BUGFIX: the preparation above crops the images to the RoI, so the
        # pixel-coordinate axes must be computed on the cropped shape
        # (consistently with CenterOfRotation.find_shift).
        img_shape = img_stack.shape[-2:]
        cc_vs = np.fft.fftfreq(img_shape[-2], 1 / img_shape[-2])
        cc_hs = np.fft.fftfreq(img_shape[-1], 1 / img_shape[-1])

        # do correlations
        ccs = [
            self._compute_correlation_fft(
                img_stack[0 if equispaced_increments else ii - 1, ...], img_stack[ii, ...], padding_mode
            )
            for ii in range(1, num_imgs)
        ]

        shifts_vh = np.empty((num_imgs - 1, 2))
        for ii, cc in enumerate(ccs):
            # BUGFIX: the keyword of extract_peak_region is `peak_radius`;
            # it was previously passed as `peak_fit_radius`, raising TypeError.
            (f_vals, fv, fh) = self.extract_peak_region(cc, peak_radius=peak_fit_radius, cc_vs=cc_vs, cc_hs=cc_hs)
            shifts_vh[ii, :] = self.refine_max_position_2d(f_vals, fv, fh)

        if equispaced_increments:
            img_pos_increments = img_pos[1:] - img_pos[0]
        else:
            img_pos_increments = - np.diff(img_pos)

        # Polynomial.fit is supposed to be more numerically stable than polyfit
        # (according to numpy)
        coeffs_v = Polynomial.fit(img_pos_increments, shifts_vh[:, 0], deg=1).convert().coef
        coeffs_h = Polynomial.fit(img_pos_increments, shifts_vh[:, 1], deg=1).convert().coef

        return coeffs_v[1], coeffs_h[1]