import numpy as np

try:
    import scipy.fft

    my_fftn = scipy.fft.rfftn
    my_ifftn = scipy.fft.irfftn
    my_fft2 = scipy.fft.rfft2
    my_ifft2 = scipy.fft.irfft2

    def my_fft_layout_adapt(x):
        # rfftn/rfft2 only store the non-negative frequencies along the last axis,
        # so a full-size (symmetric) filter has to be cropped accordingly.
        return x[..., : x.shape[-1] // 2 + 1]


except ImportError:
    my_fftn = np.fft.fftn
    my_ifftn = np.fft.ifftn
    my_fft2 = np.fft.fft2
    my_ifft2 = np.fft.ifft2

    def my_fft_layout_adapt(x):
        # The full complex FFT layout needs no adaptation.
        return x


import logging
from numpy.polynomial.polynomial import Polynomial, polyval

from nabu.utils import previouspow2
from nabu.misc import fourier_filters

try:
    from scipy.ndimage import median_filter

    __have_scipy__ = True
except ImportError:
    from silx.math.medianfilter import medfilt2d as median_filter

    __have_scipy__ = False

try:
    import matplotlib.pyplot as plt

    __have_matplotlib__ = True
except ImportError:
    logging.getLogger(__name__).warning("Matplotlib not available. Plotting disabled")

    __have_matplotlib__ = False


class AlignmentBase(object):
    def __init__(self, horz_fft_width=False, verbose=False):
        """
        Alignment basic functions.

        Parameters
        ----------
        horz_fft_width: boolean, optional
            If True, restrict the horizontal size to a power of 2:

            >>> new_x_dim = 2 ** math.floor(math.log2(x_dim))
        verbose: boolean, optional
            When True it will produce verbose output, including plots.
        """
        self._init_parameters(horz_fft_width, verbose)

    def _init_parameters(self, horz_fft_width, verbose):
        self.truncate_horz_pow2 = horz_fft_width

        if verbose and not __have_matplotlib__:
            logging.getLogger(__name__).warning("Matplotlib not available. Plotting disabled, despite being activated by user")
            verbose = False
        self.verbose = verbose

    @staticmethod
    def refine_max_position_2d(f_vals: np.ndarray, fy=None, fx=None):
        """Computes the sub-pixel max position of the given function sampling.

        Parameters
        ----------
        f_vals: numpy.ndarray
            Function values of the sampled points
        fy: numpy.ndarray, optional
            Vertical coordinates of the sampled points
        fx: numpy.ndarray, optional
            Horizontal coordinates of the sampled points

        Raises
        ------
        ValueError
            In case position and values do not have the same size, or in case
            the fitted maximum is outside the fitting region.

        Returns
        -------
        tuple(float, float)
            Estimated (vertical, horizontal) function max, according to the
            coordinates in fy and fx.
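
        Examples
        --------
        A minimal, illustrative sketch with synthetic values: a paraboloid
        whose maximum is at (y, x) = (0.5, -0.2), sampled on the default
        3x3 integer grid, is refined back to its sub-pixel vertex.

        >>> yy, xx = np.meshgrid(np.arange(-1, 2), np.arange(-1, 2), indexing="ij")
        ... f_vals = -((yy - 0.5) ** 2) - (xx + 0.2) ** 2
        ... AlignmentBase.refine_max_position_2d(f_vals)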
        """
        if not (len(f_vals.shape) == 2):
            raise ValueError(
                "The fitted values should form a 2-dimensional array. Array of shape: [%s] was given."
                % (" ".join(("%d" % s for s in f_vals.shape)))
            )
        if fy is None:
            fy_half_size = (f_vals.shape[0] - 1) / 2
            fy = np.linspace(-fy_half_size, fy_half_size, f_vals.shape[0])
        elif not (len(fy.shape) == 1 and np.all(fy.size == f_vals.shape[0])):
            raise ValueError(
                "Vertical coordinates should have the same length as values matrix. Sizes of fy: %d, f_vals: [%s]"
                % (fy.size, " ".join(("%d" % s for s in f_vals.shape)))
            )
        if fx is None:
            fx_half_size = (f_vals.shape[1] - 1) / 2
            fx = np.linspace(-fx_half_size, fx_half_size, f_vals.shape[1])
        elif not (len(fx.shape) == 1 and np.all(fx.size == f_vals.shape[1])):
            raise ValueError(
                "Horizontal coordinates should have the same length as values matrix. Sizes of fx: %d, f_vals: [%s]"
                % (fx.size, " ".join(("%d" % s for s in f_vals.shape)))
            )

        fy, fx = np.meshgrid(fy, fx, indexing="ij")
        fy = fy.flatten()
        fx = fx.flatten()
        coords = np.array([np.ones(f_vals.size), fy, fx, fy * fx, fy ** 2, fx ** 2])

        coeffs = np.linalg.lstsq(coords.T, f_vals.flatten(), rcond=None)[0]

        # For a 1D parabola `f(x) = ax^2 + bx + c`, the vertex position is
        # x_v = -b / (2a). For the fitted 2D paraboloid, the vertex (y, x)_v
        # solves the linear system A @ (y, x)_v = -b, where:
        A = [[2 * coeffs[4], coeffs[3]], [coeffs[3], 2 * coeffs[5]]]
        b = coeffs[1:3]
        vertex_yx = np.linalg.lstsq(A, -b, rcond=None)[0]

        vertex_min_yx = [np.min(fy), np.min(fx)]
        vertex_max_yx = [np.max(fy), np.max(fx)]
        if np.any(vertex_yx < vertex_min_yx) or np.any(vertex_yx > vertex_max_yx):
            raise ValueError(
                "Fitted (y: {}, x: {}) positions are outide the input margins y: [{}, {}], and x: [{}, {}]".format(
                    vertex_yx[0], vertex_yx[1], vertex_min_yx[0], vertex_max_yx[0], vertex_min_yx[1], vertex_max_yx[1],
                )
            )
        return vertex_yx

    @staticmethod
    def refine_max_position_1d(f_vals, fx=None):
        """Computes the sub-pixel max position of the given function sampling.

        Parameters
        ----------
        f_vals: numpy.ndarray
            Function values of the sampled points
        fx: numpy.ndarray, optional
            Coordinates of the sampled points

        Raises
        ------
        ValueError
            In case position and values do not have the same size, or in case
            the fitted maximum is outside the fitting region.

        Returns
        -------
        float
            Estimated function max, according to the coordinates in fx.
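
        Examples
        --------
        A minimal, illustrative sketch with synthetic values: a parabola with
        its maximum at x = 0.4, sampled on the default grid [-2, ..., 2].

        >>> fx = np.arange(-2, 3)
        ... f_vals = -((fx - 0.4) ** 2)
        ... AlignmentBase.refine_max_position_1d(f_vals)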
        """
        if not len(f_vals.shape) in (1, 2):
            raise ValueError(
                "The fitted values should be either one or a collection of 1-dimensional arrays. Array of shape: [%s] was given."
                % (" ".join(("%d" % s for s in f_vals.shape)))
            )
        num_vals = f_vals.shape[0]

        if fx is None:
            fx_half_size = (num_vals - 1) / 2
            fx = np.linspace(-fx_half_size, fx_half_size, num_vals)
        elif not (len(fx.shape) == 1 and np.all(fx.size == num_vals)):
            raise ValueError(
                "Base coordinates should have the same length as values array. Sizes of fx: %d, f_vals: %d"
                % (fx.size, num_vals)
            )

        if len(f_vals.shape) == 1:
            # using Polynomial.fit, because supposed to be more numerically
            # stable than previous solutions (according to numpy).
            poly = Polynomial.fit(fx, f_vals, deg=2)
            coeffs = poly.convert().coef
        else:
            coords = np.array([np.ones(num_vals), fx, fx ** 2])
            coeffs = np.linalg.lstsq(coords.T, f_vals, rcond=None)[0]

        # For a 1D parabola `f(x) = c + bx + ax^2`, the vertex position is:
        # x_v = -b / 2a. (This also works column-wise, when several 1D profiles are fitted at once.)
        vertex_x = -coeffs[1] / (2 * coeffs[2])

        vertex_min_x = np.min(fx)
        vertex_max_x = np.max(fx)
        lower_bound_ok = vertex_min_x < vertex_x
        upper_bound_ok = vertex_x < vertex_max_x
        if not np.all(lower_bound_ok * upper_bound_ok):
            if len(f_vals.shape) == 1:
                message = "Fitted position {} is outside the input margins [{}, {}]".format(
                    vertex_x, vertex_min_x, vertex_max_x
                )
            else:
                message = "Fitted positions outside the input margins [{}, {}]: {} below and {} above".format(
                    vertex_min_x, vertex_max_x, np.sum(1 - lower_bound_ok), np.sum(1 - upper_bound_ok),
                )
            raise ValueError(message)
        return vertex_x

    @staticmethod
    def extract_peak_region_2d(cc, peak_radius=1, cc_vs=None, cc_hs=None):
        """
        Extracts a region around the maximum value.

        Parameters
        ----------
        cc: numpy.ndarray
            Correlation image.
        peak_radius: int, optional
            The l_inf radius of the area to extract around the peak. The default is 1.
        cc_vs: numpy.ndarray, optional
            The vertical coordinates of `cc`. The default is None.
        cc_hs: numpy.ndarray, optional
            The horizontal coordinates of `cc`. The default is None.

        Returns
        -------
        f_vals: numpy.ndarray
            The extracted function values.
        fv: numpy.ndarray
            The vertical coordinates of the extracted values.
        fh: numpy.ndarray
            The horizontal coordinates of the extracted values.
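
        Examples
        --------
        An illustrative sketch, assuming `cc` is a correlation image (e.g. the
        output of `_compute_correlation_fft`): the 3x3 region around its global
        maximum is extracted, together with `fftfreq`-style coordinates, ready
        to be refined by `refine_max_position_2d`.

        >>> cc_vs = np.fft.fftfreq(cc.shape[-2], 1 / cc.shape[-2])
        ... cc_hs = np.fft.fftfreq(cc.shape[-1], 1 / cc.shape[-1])
        ... f_vals, fv, fh = AlignmentBase.extract_peak_region_2d(cc, peak_radius=1, cc_vs=cc_vs, cc_hs=cc_hs)
        ... AlignmentBase.refine_max_position_2d(f_vals, fv, fh)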
        """
        img_shape = np.array(cc.shape)
        # get pixel having the maximum value of the correlation array
        pix_max_corr = np.argmax(cc)
        pv, ph = np.unravel_index(pix_max_corr, img_shape)

        # select a n x n neighborhood for the sub-pixel fitting (with wrapping)
        pv = np.arange(pv - peak_radius, pv + peak_radius + 1) % img_shape[-2]
        ph = np.arange(ph - peak_radius, ph + peak_radius + 1) % img_shape[-1]

        # extract the (v, h) pixel coordinates
        fv = None if cc_vs is None else cc_vs[pv]
        fh = None if cc_hs is None else cc_hs[ph]

        # extract the correlation values
        pv, ph = np.meshgrid(pv, ph, indexing="ij")
        f_vals = cc[pv, ph]

        return (f_vals, fv, fh)

    @staticmethod
    def extract_peak_regions_1d(cc, axis=-1, peak_radius=1, cc_coords=None):
        """
        Extracts a region around the maximum value.

        Parameters
        ----------
        cc: numpy.ndarray
            Correlation image.
        axis: int, optional
            Find the max values along the specified direction. The default is -1.
        peak_radius: int, optional
            The l_inf radius of the area to extract around the peak. The default is 1.
        cc_coords: numpy.ndarray, optional
            The coordinates of `cc` along the selected axis. The default is None.

        Returns
        -------
        f_vals: numpy.ndarray
            The extracted function values.
        fc_ax: numpy.ndarray
            The coordinates of the extracted values, along the selected axis.
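
        Examples
        --------
        An illustrative sketch, assuming `cc` is a 2-dimensional correlation
        image: one peak region is extracted per image row along the horizontal
        axis, so `f_vals` has shape (2 * peak_radius + 1, cc.shape[0]).

        >>> cc_coords = np.fft.fftfreq(cc.shape[-1], 1 / cc.shape[-1])
        ... f_vals, fc_ax = AlignmentBase.extract_peak_regions_1d(cc, axis=-1, peak_radius=1, cc_coords=cc_coords)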
        """
        img_shape = np.array(cc.shape)
        if not (len(img_shape) == 2):
            raise ValueError(
                "The input image should be a 2-dimensional array. Array of shape: [%s] was given."
                % (" ".join(("%d" % s for s in cc.shape)))
            )
        other_axis = (axis + 1) % 2
        # get pixel having the maximum value of the correlation array
        pix_max = np.argmax(cc, axis=axis)

        # select an n-sized neighborhood for each of the 1D sub-pixel fittings (with wrapping)
        p_ax_range = np.arange(-peak_radius, +peak_radius + 1)
        p_ax = (pix_max[None, :] + p_ax_range[:, None]) % img_shape[axis]

        # line indices along the other axis (one 1D fitting per line)
        p_ln = np.tile(np.arange(0, img_shape[other_axis])[None, :], [2 * peak_radius + 1, 1])

        # extract the pixel coordinates along the axis
        fc_ax = None if cc_coords is None else cc_coords[p_ax.flatten()].reshape(p_ax.shape)

        # extract the correlation values
        if other_axis == 0:
            f_vals = cc[p_ln, p_ax]
        else:
            f_vals = cc[p_ax, p_ln]

        return (f_vals, fc_ax)

    def _determine_roi(self, img_shape, roi_yxhw):
        if roi_yxhw is None:
            # The vertical window size is reduced to a power of 2, to accelerate the FFT.
            # The same is done for the horizontal window, but only if requested (off by default).
            roi_yxhw = previouspow2(img_shape)
            if not self.truncate_horz_pow2:
                roi_yxhw[1] = img_shape[1]

        if len(roi_yxhw) == 2:  # Convert centered 2-element roi into 4-element
            roi_yxhw = np.array(roi_yxhw, dtype=int)
            roi_yxhw = np.concatenate(((img_shape - roi_yxhw) // 2, roi_yxhw))
        return roi_yxhw

    @staticmethod
    def _prepare_image(
        img, invalid_val=1e-5, roi_yxhw=None, median_filt_shape=None, low_pass=None, high_pass=None,
    ):
        """
        Prepare and return a cropped and filtered image, or an array of filtered images if the input is an array of images.

        Parameters
        ----------
        img: numpy.ndarray
            image or stack of images
        invalid_val: float
            value to be used in replacement of nan and inf values
        median_filt_shape: int or sequence of int
            the width or the widths of the median window
        low_pass: float or sequence of two floats
            Low-pass filter properties, as described in `nabu.misc.fourier_filters`
        high_pass: float or sequence of two floats
            High-pass filter properties, as described in `nabu.misc.fourier_filters`

        Returns
        -------
        numpy.array_like
            The prepared (cropped and filtered) image, or stack of images.
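
        Examples
        --------
        An illustrative sketch, assuming `img` is a 2-dimensional projection
        and that the chosen cutoff suits the data: low-frequency trends are
        removed with a high-pass filter, and outliers with a 3x3 median filter.

        >>> img_prep = AlignmentBase._prepare_image(img, median_filt_shape=(3, 3), high_pass=1.0)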
        """
        img = np.squeeze(img)  # Removes singleton dimensions, but does a shallow copy
        img = np.ascontiguousarray(img)

        if roi_yxhw is not None:
            img = img[
                ..., roi_yxhw[0] : roi_yxhw[0] + roi_yxhw[2], roi_yxhw[1] : roi_yxhw[1] + roi_yxhw[3],
            ]

        img = img.copy()

        img[np.isnan(img)] = invalid_val
        img[np.isinf(img)] = invalid_val

        if high_pass is not None or low_pass is not None:
            img_filter = fourier_filters.get_bandpass_filter(
                img.shape[-2:], cutoff_lowpass=low_pass, cutoff_highpass=high_pass
            )
            # fft2 and ifft2 use axes=(-2, -1) by default
            img = my_ifft2(my_fft2(img) * my_fft_layout_adapt(img_filter)).real

        if median_filt_shape is not None:
            img_shape = img.shape
            if __have_scipy__:
                # expanding filter shape with ones, to cover the stack of images
                # but disabling inter-image filtering
                median_filt_shape = np.concatenate(
                    (np.ones((len(img_shape) - len(median_filt_shape),), dtype=int), median_filt_shape,)
                )
                img = median_filter(img, size=median_filt_shape)
            else:
                if len(img_shape) == 2:
                    img = median_filter(img, kernel_size=median_filt_shape)
                elif len(img_shape) > 2:
                    # if dealing with a stack of images, we have to do them one by one
                    img = np.reshape(img, (-1,) + tuple(img_shape[-2:]))
                    for ii in range(img.shape[0]):
                        img[ii, ...] = median_filter(img[ii, ...], kernel_size=median_filt_shape)
                    img = np.reshape(img, img_shape)

        return img

    @staticmethod
    def _compute_correlation_fft(img_1, img_2, padding_mode, axes=(-2, -1), low_pass=None, high_pass=None):
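        # Cross-correlation theorem: correlate(img_1, img_2) = IFFT(FFT(img_1) * conj(FFT(img_2))).
        # With padding_mode in (None, "wrap") the correlation is circular; any other numpy.pad mode
        # pads both images by half their size along `axes`, which yields a linear correlation.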
        do_circular_conv = padding_mode is None or padding_mode == "wrap"
        if not do_circular_conv:
            img_shape = img_2.shape
            pad_size = np.ceil(np.array(img_shape) / 2).astype(int)
            pad_array = [(0,)] * len(img_shape)
            for a in axes:
                pad_array[a] = (pad_size[a],)

            img_1 = np.pad(img_1, pad_array, mode=padding_mode)
            img_2 = np.pad(img_2, pad_array, mode=padding_mode)

        # compute fft's of the 2 images
        img_fft_1 = my_fftn(img_1, axes=axes)
        img_fft_2 = np.conjugate(my_fftn(img_2, axes=axes))

        img_prod = img_fft_1 * img_fft_2

        if low_pass is not None or high_pass is not None:
            filt = fourier_filters.get_bandpass_filter(img_prod.shape[-2:], cutoff_lowpass=low_pass, cutoff_highpass=high_pass)
            img_prod *= filt

        # inverse fft of the product to get cross_correlation of the 2 images
        cc = np.real(my_ifftn(img_prod, axes=axes))

        if not do_circular_conv:
            cc = np.fft.fftshift(cc, axes=axes)

            slicing = [slice(None)] * len(img_shape)
            for a in axes:
                slicing[a] = slice(pad_size[a], cc.shape[a] - pad_size[a])
            cc = cc[tuple(slicing)]

            cc = np.fft.ifftshift(cc, axes=axes)

        return cc


class CenterOfRotation(AlignmentBase):
    @staticmethod
    def _check_img_sizes(img_1: np.ndarray, img_2: np.ndarray):
        shape_1 = np.squeeze(img_1).shape
        shape_2 = np.squeeze(img_2).shape
        if not len(shape_1) == 2:
            raise ValueError(
                "Images need to be 2-dimensional. Shape of image #1: %s" % (" ".join(("%d" % x for x in shape_1)))
            )
        if not len(shape_2) == 2:
            raise ValueError(
                "Images need to be 2-dimensional. Shape of image #2: %s" % (" ".join(("%d" % x for x in shape_2)))
            )
        if not np.all(shape_1 == shape_2):
            raise ValueError(
                "Images need to be of the same shape. Shape of image #1: %s, image #2: %s"
                % (" ".join(("%d" % x for x in shape_1)), " ".join(("%d" % x for x in shape_2)),)
            )

    def find_shift(
        self,
        img_1: np.ndarray,
        img_2: np.ndarray,
        roi_yxhw=None,
        median_filt_shape=None,
        padding_mode=None,
        peak_fit_radius=1,
        high_pass=None,
        low_pass=None,
    ):
        """Find the Center of Rotation (CoR), given to images.

        This method finds the half-shift between two opposite images, by
        means of correlation computed in Fourier space.

        The output of this function allows one to compute the motor movement
        needed to align the sample rotation axis. Given the following values:

        - L1: distance from source to motor
        - L2: distance from source to detector
        - ps: physical pixel size
        - v: output of this function

        displacement of motor = (L1 / L2 * ps) * v
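
        For instance, assuming (purely for illustration) L1 = 0.4 m, L2 = 2 m and ps = 3 um, a
        returned value of v = 2.5 pixels corresponds to a motor displacement of
        (0.4 / 2 * 3 um) * 2.5 = 1.5 um.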

        Parameters
        ----------
        img_1: numpy.ndarray
            First image
        img_2: numpy.ndarray
            Second image, it needs to have been flipped already (e.g. using numpy.fliplr).
        roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional
            4 elements vector containing: vertical and horizontal coordinates
            of first pixel, plus height and width of the Region of Interest (RoI).
            Or a 2 elements vector containing just the height and width of the
            centered Region of Interest (RoI).
            Default is None -> deactivated.
        median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional
            Shape of the median filter window. Default is None -> deactivated.
        padding_mode: str in numpy.pad's mode list, optional
            Padding mode, which determines the type of convolution. If None or
            'wrap' are passed, this resorts to the traditional circular convolution.
            If 'edge' or 'constant' are passed, it results in a linear convolution.
            Default is the circular convolution.
            All options are:
                None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean'
                | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap'
        peak_fit_radius: int, optional
            Radius size around the max correlation pixel, for sub-pixel fitting.
            Minimum and default value is 1.
        low_pass: float or sequence of two floats
            Low-pass filter properties, as described in `nabu.misc.fourier_filters`
        high_pass: float or sequence of two floats
            High-pass filter properties, as described in `nabu.misc.fourier_filters`

        Raises
        ------
        ValueError
            In case images are not 2-dimensional or have different sizes.

        Returns
        -------
        float
            Estimated center of rotation position from the center of the RoI in pixels.

        Examples
        --------
        The following code computes the center of rotation position for two
        given images in a tomography scan, where the second image is taken at
        180 degrees from the first.

        >>> radio1 = data[0, :, :]
        ... radio2 = np.fliplr(data[1, :, :])
        ... CoR_calc = CenterOfRotation()
        ... cor_position = CoR_calc.find_shift(radio1, radio2)

        Or for noisy images:

        >>> cor_position = CoR_calc.find_shift(radio1, radio2, median_filt_shape=(3, 3))
        """
        self._check_img_sizes(img_1, img_2)

        if peak_fit_radius < 1:
            logging.getLogger(__name__).warning(
                "Parameter peak_fit_radius should be at least 1, given: %d instead." % peak_fit_radius
            )
            peak_fit_radius = 1

        img_shape = img_2.shape
        roi_yxhw = self._determine_roi(img_shape, roi_yxhw)

        img_1 = self._prepare_image(img_1, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape)
        img_2 = self._prepare_image(img_2, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape)

        cc = self._compute_correlation_fft(img_1, img_2, padding_mode, high_pass=high_pass, low_pass=low_pass)

        img_shape = img_2.shape
        cc_vs = np.fft.fftfreq(img_shape[-2], 1 / img_shape[-2])
        cc_hs = np.fft.fftfreq(img_shape[-1], 1 / img_shape[-1])

        (f_vals, fv, fh) = self.extract_peak_region_2d(cc, peak_radius=peak_fit_radius, cc_vs=cc_vs, cc_hs=cc_hs)
        fitted_shifts_vh = self.refine_max_position_2d(f_vals, fv, fh)

        return fitted_shifts_vh[-1] / 2.0

    __call__ = find_shift


class DetectorTranslationAlongBeam(AlignmentBase):
    @staticmethod
    def _check_img_sizes(img_stack: np.ndarray, img_pos: np.ndarray):
        shape_stack = np.squeeze(img_stack).shape
        shape_pos = np.squeeze(img_pos).shape
        if not len(shape_stack) == 3:
            raise ValueError(
                "A stack of 2-dimensional images is required. Shape of stack: %s" % (" ".join(("%d" % x for x in shape_stack)))
            )
        if not len(shape_pos) == 1:
            raise ValueError(
                "Positions need to be a 1-dimensional array. Shape of the positions variable: %s"
                % (" ".join(("%d" % x for x in shape_pos)))
            )
        if not shape_stack[0] == shape_pos[0]:
            raise ValueError(
                "The same number of images and positions is required."
                + " Shape of stack: %s, shape of positions variable: %s"
                % (" ".join(("%d" % x for x in shape_stack)), " ".join(("%d" % x for x in shape_pos)),)
            )

    def find_shift(
        self,
        img_stack: np.ndarray,
        img_pos: np.ndarray,
        roi_yxhw=None,
        median_filt_shape=None,
        padding_mode=None,
        peak_fit_radius=1,
        high_pass=None,
        low_pass=None,
        equispaced_increments=True,
        return_shifts=False,
        use_adjacent_imgs=False,
    ):
        """Find the vertical and horizontal shift increments per unit-distance of detector translation along
        the translation axis. The units are pixel_unit / input_unit, where input_unit is the unit used for the
        img_pos argument. The output expresses the shift of the detector: if the image content moves in the
        positive direction (in pixel coordinates), the output will be negative, because it means that the
        detector as a whole is shifting in the opposite direction (taking the shaped beam as a reference).

        Parameters
        ----------
        img_stack: numpy.ndarray
            A stack of images (usually 4) at different distances
        img_pos: numpy.ndarray
            Position of the images along the translation axis
        roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional
            4 elements vector containing: vertical and horizontal coordinates
            of first pixel, plus height and width of the Region of Interest (RoI).
            Or a 2 elements vector containing just the height and width of the
            centered Region of Interest (RoI).
            Default is None -> deactivated.
        median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional
            Shape of the median filter window. Default is None -> deactivated.
        padding_mode: str in numpy.pad's mode list, optional
            Padding mode, which determines the type of convolution. If None or
            'wrap' are passed, this resorts to the traditional circular convolution.
            If 'edge' or 'constant' are passed, it results in a linear convolution.
            Default is the circular convolution.
            All options are:
                None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean'
                | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap'
        peak_fit_radius: int, optional
            Radius size around the max correlation pixel, for sub-pixel fitting.
            Minimum and default value is 1.
        low_pass: float or sequence of two floats
            Low-pass filter properties, as described in `nabu.misc.fourier_filters`.
        high_pass: float or sequence of two floats
            High-pass filter properties, as described in `nabu.misc.fourier_filters`.
        return_shifts: boolean, optional
            If True, a list containing the (vertical, horizontal) shift found for each
            image of the stack is returned together with the increments per unit-distance.
            This option is intended for introspection.
        use_adjacent_imgs: boolean, optional
            Compute correlation between adjacent images. It is better when dealing with large shifts.

        Returns
        -------
        tuple(float, float)
            Estimated (vertical, horizontal) shift increments per unit-distance of
            the detector translation, in pixels.
            Optionally, a list containing the shifts found for each image is also
            returned, if the return_shifts argument is True.

        Examples
        --------

        >>> import numpy as np
        ... import scipy.ndimage
        ... from nabu.preproc.alignment import  DetectorTranslationAlongBeam
        ...
        ... T_calc = DetectorTranslationAlongBeam()
        ...
        ... stack = np.zeros([4, 512, 512], "d")
        ...
        ... # Add low frequency spurious component
        ... for i in range(4):
        ...     stack[i, 200 - i * 10, 200 - i * 10] = 1
        ... stack = scipy.ndimage.gaussian_filter(stack, [0, 10, 10.0]) * 100
        ...
        ... # Add the feature
        ... x, y = np.meshgrid(np.arange(stack.shape[-1]), np.arange(stack.shape[-2]))
        ... for i in range(4):
        ...     xc = x - (250 + i * 1.234)
        ...     yc = y - (250 + i * 1.234 * 2)
        ...     stack[i] += np.exp(-(xc * xc + yc * yc) * 0.5)
        ...
        ... # Find the shifts from the features
        ... shifts_v, shifts_h, found_shifts_list = T_calc.find_shift(
        ...     stack, np.array([0.0, 1, 2, 3]), high_pass=1.0, return_shifts=True
        ... )
        ... print(shifts_v, shifts_h)
        >>> ( -2.47 , -1.236 )



        """
        self._check_img_sizes(img_stack, img_pos)

        if peak_fit_radius < 1:
            logging.getLogger(__name__).warning(
                "Parameter peak_fit_radius should be at least 1, given: %d instead." % peak_fit_radius
            )
            peak_fit_radius = 1

        num_imgs = img_stack.shape[0]
        img_shape = img_stack.shape[-2:]
        roi_yxhw = self._determine_roi(img_shape, roi_yxhw)

        img_stack = self._prepare_image(img_stack, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape)

        # do correlations
        ccs = [
            self._compute_correlation_fft(
                img_stack[ii - 1 if use_adjacent_imgs else 0, ...],
                img_stack[ii, ...],
                padding_mode,
                high_pass=high_pass,
                low_pass=low_pass,
            )
            for ii in range(1, num_imgs)
        ]

        img_shape = img_stack.shape[-2:]
        cc_vs = np.fft.fftfreq(img_shape[-2], 1 / img_shape[-2])
        cc_hs = np.fft.fftfreq(img_shape[-1], 1 / img_shape[-1])

        shifts_vh = np.empty((num_imgs - 1, 2))
        for ii, cc in enumerate(ccs):
            (f_vals, fv, fh) = self.extract_peak_region_2d(cc, peak_radius=peak_fit_radius, cc_vs=cc_vs, cc_hs=cc_hs)
            shifts_vh[ii, :] = self.refine_max_position_2d(f_vals, fv, fh)

        if use_adjacent_imgs:
            shifts_vh = np.cumsum(shifts_vh, axis=0)

        img_shifts_vh = np.concatenate(([[0, 0]], shifts_vh), axis=0)

        # Polynomial.fit is supposed to be more numerically stable than polyfit
        # (according to numpy)
        coeffs_v = Polynomial.fit(img_pos, img_shifts_vh[:, 0], deg=1).convert().coef
        coeffs_h = Polynomial.fit(img_pos, img_shifts_vh[:, 1], deg=1).convert().coef

        if self.verbose:
            f, axs = plt.subplots(1, 2)
            axs[0].scatter(img_pos, img_shifts_vh[:, 0])
            axs[0].plot(img_pos, polyval(img_pos, coeffs_v))
            axs[0].set_title("Vertical shifts")
            axs[1].scatter(img_pos, img_shifts_vh[:, 1])
            axs[1].plot(img_pos, polyval(img_pos, coeffs_h))
            axs[1].set_title("Horizontal shifts")
            plt.show(block=False)

        if return_shifts:
            return coeffs_v[1], coeffs_h[1], shifts_vh
        else:
            return coeffs_v[1], coeffs_h[1]