import logging

import numpy as np
from numpy.polynomial.polynomial import Polynomial, polyval

from nabu.utils import previouspow2
from nabu.misc import fourier_filters

try:
    from scipy.ndimage import median_filter

    import scipy.fft

    local_fftn = scipy.fft.rfftn
    local_ifftn = scipy.fft.irfftn

    __have_scipy__ = True
except ImportError:
    from silx.math.medianfilter import medfilt2d as median_filter

    local_fftn = np.fft.fftn
    local_ifftn = np.fft.ifftn

    __have_scipy__ = False

try:
    import matplotlib.pyplot as plt

    __have_matplotlib__ = True
except ImportError:
    logging.getLogger(__name__).warning("Matplotlib not available. Plotting disabled")

    __have_matplotlib__ = False


class AlignmentBase(object):
    def __init__(self, horz_fft_width=False, verbose=False, data_type=np.float32):
        """
        Alignment basic functions.

        Parameters
        ----------
        horz_fft_width: boolean, optional
            If True, restrict the horizontal size to a power of 2:

            >>> new_x_dim = 2 ** math.floor(math.log2(x_dim))
        verbose: boolean, optional
            When True it will produce verbose output, including plots.
        data_type: `numpy.float32`
            Computation data type.
        """
        self._init_parameters(horz_fft_width, verbose, data_type)

    def _init_parameters(self, horz_fft_width, verbose, data_type):
        self.truncate_horz_pow2 = horz_fft_width

        if verbose and not __have_matplotlib__:
            logging.getLogger(__name__).warning("Matplotlib not available. Plotting disabled, despite being activated by user")
            verbose = False
        self.verbose = verbose
        self.data_type = data_type

    @staticmethod
    def refine_max_position_2d(f_vals: np.ndarray, fy=None, fx=None):
        """Computes the sub-pixel max position of the given function sampling.

        Parameters
        ----------
        f_vals: numpy.ndarray
            Function values of the sampled points
        fy: numpy.ndarray, optional
            Vertical coordinates of the sampled points
        fx: numpy.ndarray, optional
            Horizontal coordinates of the sampled points

        Raises
        ------
        ValueError
            In case position and values do not have the same size, or in case
            the fitted maximum is outside the fitting region.

        Returns
        -------
        tuple(float, float)
            Estimated (vertical, horizontal) function max, according to the
            coordinates in fy and fx.
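
        Examples
        --------
        A minimal sketch of the intended usage (the values below are purely
        illustrative), fitting a 3 x 3 neighborhood of correlation values
        around an integer peak position:

        >>> f_vals = np.array([[0.5, 0.6, 0.5], [0.7, 1.0, 0.8], [0.4, 0.6, 0.5]])
        ... max_yx = AlignmentBase.refine_max_position_2d(f_vals)

        When `fy` and `fx` are omitted, the coordinates default to a centered
        grid, so the result is a fractional (y, x) offset from the central sample.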
        """
        if not (len(f_vals.shape) == 2):
            raise ValueError(
                "The fitted values should form a 2-dimensional array. Array of shape: [%s] was given."
                % (" ".join(("%d" % s for s in f_vals.shape)))
            )
        if fy is None:
            fy_half_size = (f_vals.shape[0] - 1) / 2
            fy = np.linspace(-fy_half_size, fy_half_size, f_vals.shape[0])
        elif not (len(fy.shape) == 1 and np.all(fy.size == f_vals.shape[0])):
            raise ValueError(
                "Vertical coordinates should have the same length as values matrix. Sizes of fy: %d, f_vals: [%s]"
                % (fy.size, " ".join(("%d" % s for s in f_vals.shape)))
            )
        if fx is None:
            fx_half_size = (f_vals.shape[1] - 1) / 2
            fx = np.linspace(-fx_half_size, fx_half_size, f_vals.shape[1])
        elif not (len(fx.shape) == 1 and np.all(fx.size == f_vals.shape[1])):
            raise ValueError(
                "Horizontal coordinates should have the same length as values matrix. Sizes of fx: %d, f_vals: [%s]"
                % (fx.size, " ".join(("%d" % s for s in f_vals.shape)))
            )

        fy, fx = np.meshgrid(fy, fx, indexing="ij")
        fy = fy.flatten()
        fx = fx.flatten()
        coords = np.array([np.ones(f_vals.size), fy, fx, fy * fx, fy ** 2, fx ** 2])

        coeffs = np.linalg.lstsq(coords.T, f_vals.flatten(), rcond=None)[0]

        # For a 1D parabola `f(x) = ax^2 + bx + c`, the vertex position is:
        # x_v = -b / 2a. For a 2D paraboloid, the vertex (y, x)_v solves
        # A (y, x)_v = -b, where:
        A = [[2 * coeffs[4], coeffs[3]], [coeffs[3], 2 * coeffs[5]]]
        b = coeffs[1:3]
        vertex_yx = np.linalg.lstsq(A, -b, rcond=None)[0]

        vertex_min_yx = [np.min(fy), np.min(fx)]
        vertex_max_yx = [np.max(fy), np.max(fx)]
        if np.any(vertex_yx < vertex_min_yx) or np.any(vertex_yx > vertex_max_yx):
            raise ValueError(
                "Fitted (y: {}, x: {}) positions are outside the input margins y: [{}, {}], and x: [{}, {}]".format(
                    vertex_yx[0], vertex_yx[1], vertex_min_yx[0], vertex_max_yx[0], vertex_min_yx[1], vertex_max_yx[1],
                )
            )
        return vertex_yx

    @staticmethod
    def refine_max_position_1d(f_vals, fx=None):
        """Computes the sub-pixel max position of the given function sampling.

        Parameters
        ----------
        f_vals: numpy.ndarray
            Function values of the sampled points
        fx: numpy.ndarray, optional
            Coordinates of the sampled points

        Raises
        ------
        ValueError
            In case position and values do not have the same size, or in case
            the fitted maximum is outside the fitting region.

        Returns
        -------
        float
            Estimated function max, according to the coordinates in fx.
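
        Examples
        --------
        A minimal sketch (with illustrative values), fitting three samples
        around an integer peak position:

        >>> f_vals = np.array([0.7, 1.0, 0.9])
        ... max_x = AlignmentBase.refine_max_position_1d(f_vals)

        With the default centered coordinates [-1, 0, 1], the fitted parabola
        peaks at 0.25, i.e. a quarter of a sample to the right of the center.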
        """
        if len(f_vals.shape) not in (1, 2):
            raise ValueError(
                "The fitted values should be either one or a collection of 1-dimensional arrays. Array of shape: [%s] was given."
                % (" ".join(("%d" % s for s in f_vals.shape)))
            )
        num_vals = f_vals.shape[0]

        if fx is None:
            fx_half_size = (num_vals - 1) / 2
            fx = np.linspace(-fx_half_size, fx_half_size, num_vals)
        elif not (len(fx.shape) == 1 and np.all(fx.size == num_vals)):
            raise ValueError(
                "Base coordinates should have the same length as values array. Sizes of fx: %d, f_vals: %d"
                % (fx.size, num_vals)
            )

        if len(f_vals.shape) == 1:
            # using Polynomial.fit, because supposed to be more numerically
            # stable than previous solutions (according to numpy).
            poly = Polynomial.fit(fx, f_vals, deg=2)
            coeffs = poly.convert().coef
        else:
            coords = np.array([np.ones(num_vals), fx, fx ** 2])
            coeffs = np.linalg.lstsq(coords.T, f_vals, rcond=None)[0]

        # For a 1D parabola `f(x) = c + bx + ax^2`, the vertex position is:
        # x_v = -b / 2a.
        vertex_x = -coeffs[1] / (2 * coeffs[2])

        vertex_min_x = np.min(fx)
        vertex_max_x = np.max(fx)

        lower_bound_ok = vertex_min_x < vertex_x
        upper_bound_ok = vertex_x < vertex_max_x
        if not np.all(lower_bound_ok * upper_bound_ok):
            if len(f_vals.shape) == 1:
                message = "Fitted position {} is outside the input margins [{}, {}]".format(
                    vertex_x, vertex_min_x, vertex_max_x
                )
            else:
                message = "Fitted positions outside the input margins [{}, {}]: {} below and {} above".format(
                    vertex_min_x, vertex_max_x, np.sum(1 - lower_bound_ok), np.sum(1 - upper_bound_ok),
                )
            raise ValueError(message)
        return vertex_x

    @staticmethod
    def extract_peak_region_2d(cc, peak_radius=1, cc_vs=None, cc_hs=None):
        """
        Extracts a region around the maximum value.

        Parameters
        ----------
        cc: numpy.ndarray
            Correlation image.
        peak_radius: int, optional
            The l_inf radius of the area to extract around the peak. The default is 1.
        cc_vs: numpy.ndarray, optional
            The vertical coordinates of `cc`. The default is None.
        cc_hs: numpy.ndarray, optional
            The horizontal coordinates of `cc`. The default is None.

        Returns
        -------
        f_vals: numpy.ndarray
            The extracted function values.
        fv: numpy.ndarray
            The vertical coordinates of the extracted values.
        fh: numpy.ndarray
            The horizontal coordinates of the extracted values.
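
        Examples
        --------
        A sketch of how this is typically chained with `refine_max_position_2d`
        (here `cc` is assumed to be a correlation image computed elsewhere):

        >>> f_vals, fv, fh = AlignmentBase.extract_peak_region_2d(cc, peak_radius=1)
        ... max_vh = AlignmentBase.refine_max_position_2d(f_vals, fv, fh)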
        """
        img_shape = np.array(cc.shape)
        # get pixel having the maximum value of the correlation array
        pix_max_corr = np.argmax(cc)
        pv, ph = np.unravel_index(pix_max_corr, img_shape)

        # select a n x n neighborhood for the sub-pixel fitting (with wrapping)
        pv = np.arange(pv - peak_radius, pv + peak_radius + 1) % img_shape[-2]
        ph = np.arange(ph - peak_radius, ph + peak_radius + 1) % img_shape[-1]

        # extract the (v, h) pixel coordinates
        fv = None if cc_vs is None else cc_vs[pv]
        fh = None if cc_hs is None else cc_hs[ph]

        # extract the correlation values
        pv, ph = np.meshgrid(pv, ph, indexing="ij")
        f_vals = cc[pv, ph]

        return (f_vals, fv, fh)

    @staticmethod
    def extract_peak_regions_1d(cc, axis=-1, peak_radius=1, cc_coords=None):
        """
        Extracts a region around the maximum value.

        Parameters
        ----------
        cc: numpy.ndarray
            Correlation image.
        axis: int, optional
            Find the max values along the specified direction. The default is -1.
        peak_radius: int, optional
            The l_inf radius of the area to extract around the peak. The default is 1.
        cc_coords: numpy.ndarray, optional
            The coordinates of `cc` along the selected axis. The default is None.

        Returns
        -------
        f_vals: numpy.ndarray
            The extracted function values.
        fc_ax: numpy.ndarray
            The coordinates of the extracted values, along the selected axis.
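
        Examples
        --------
        A sketch of one way this can be combined with `refine_max_position_1d`
        (here `cc` is assumed to be a correlation image and `cc_hs` its
        horizontal coordinates, e.g. from `numpy.fft.fftfreq`); the per-row
        sub-pixel offsets are added to the central coordinate of each window:

        >>> f_vals, fc_ax = AlignmentBase.extract_peak_regions_1d(cc, axis=-1, peak_radius=1, cc_coords=cc_hs)
        ... shifts = AlignmentBase.refine_max_position_1d(f_vals) + fc_ax[1, :]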
        """
        img_shape = np.array(cc.shape)
        if not (len(img_shape) == 2):
            raise ValueError(
                "The input image should be a 2-dimensional array. Array of shape: [%s] was given."
                % (" ".join(("%d" % s for s in cc.shape)))
            )
        other_axis = (axis + 1) % 2
        # get pixel having the maximum value of the correlation array
        pix_max = np.argmax(cc, axis=axis)

        # select a n neighborhood for the many 1D sub-pixel fittings (with wrapping)
        p_ax_range = np.arange(-peak_radius, +peak_radius + 1)
        p_ax = (pix_max[None, :] + p_ax_range[:, None]) % img_shape[axis]

        p_ln = np.tile(np.arange(0, img_shape[other_axis])[None, :], [2 * peak_radius + 1, 1])

        # extract the pixel coordinates along the axis
        fc_ax = None if cc_coords is None else cc_coords[p_ax.flatten()].reshape(p_ax.shape)

        # extract the correlation values
        if other_axis == 0:
            f_vals = cc[p_ln, p_ax]
        else:
            f_vals = cc[p_ax, p_ln]

        return (f_vals, fc_ax)

    def _determine_roi(self, img_shape, roi_yxhw):
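        """
        Determine the RoI to use, as a 4-element (y, x, height, width) vector.

        When `roi_yxhw` is None, a centered window is used, with the vertical
        size reduced to the largest power of 2 fitting the image (the
        horizontal size is only reduced when `truncate_horz_pow2` is set).
        A 2-element (height, width) RoI is converted into a centered 4-element
        one: for a (512, 612) image, (256, 256) becomes (128, 178, 256, 256).
        """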
        if roi_yxhw is None:
            # The vertical window size is reduced to a power of 2 to accelerate the FFT.
            # The same is done for the horizontal window only if requested (default: no).
            roi_yxhw = previouspow2(img_shape)
            if not self.truncate_horz_pow2:
                roi_yxhw[1] = img_shape[1]

        if len(roi_yxhw) == 2:  # Convert a centered 2-element RoI into a 4-element one
            roi_yxhw = np.array(roi_yxhw, dtype=int)
            roi_yxhw = np.concatenate(((img_shape - roi_yxhw) // 2, roi_yxhw))
        return roi_yxhw

    def _prepare_image(self, img, invalid_val=1e-5, roi_yxhw=None, median_filt_shape=None, low_pass=None, high_pass=None):
        """
        Prepare and return a cropped and filtered image, or an array of
        filtered images if the input is an array of images.

        Parameters
        ----------
        img: numpy.ndarray
            image or stack of images
        invalid_val: float
            value to be used in replacement of nan and inf values
        roi_yxhw: (4, ) numpy.ndarray, tuple, or array, optional
            Crop region, in (y, x, height, width) form. Default is None -> no cropping
        median_filt_shape: int or sequence of int
            the width or the widths of the median window
        low_pass: float or sequence of two floats
            Low-pass filter properties, as described in `nabu.misc.fourier_filters`
        high_pass: float or sequence of two floats
            High-pass filter properties, as described in `nabu.misc.fourier_filters`

        Returns
        -------
        numpy.array_like
            The prepared (cropped and filtered) image, or stack of images.
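
        Examples
        --------
        A sketch of typical usage, where `align` is an instance of a concrete
        subclass (e.g. CenterOfRotation) and `radio` a 2-dimensional projection:

        >>> img = align._prepare_image(radio, roi_yxhw=(0, 0, 256, 256), median_filt_shape=(3, 3))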
        """
        img = np.squeeze(img)  # Removes singleton dimensions, but does a shallow copy
        img = np.ascontiguousarray(img, dtype=self.data_type)

        if roi_yxhw is not None:
            img = img[
                ..., roi_yxhw[0] : roi_yxhw[0] + roi_yxhw[2], roi_yxhw[1] : roi_yxhw[1] + roi_yxhw[3],
            ]

        img = img.copy()

        img[np.isnan(img)] = invalid_val
        img[np.isinf(img)] = invalid_val

        if high_pass is not None or low_pass is not None:
            img_filter = fourier_filters.get_bandpass_filter(
                img.shape[-2:],
                cutoff_lowpass=low_pass,
                cutoff_highpass=high_pass,
                use_rfft=__have_scipy__,
                data_type=self.data_type,
            )
            # fft2 and ifft2 use axes=(-2, -1) by default
            img = local_ifftn(local_fftn(img, axes=(-2, -1)) * img_filter, axes=(-2, -1)).real

        if median_filt_shape is not None:
            img_shape = img.shape
            if __have_scipy__:
                # expanding the filter shape with ones, to cover the stack of images,
                # but disabling inter-image filtering
                median_filt_shape = np.concatenate(
                    (np.ones((len(img_shape) - len(median_filt_shape),), dtype=int), median_filt_shape,)
                )
                img = median_filter(img, size=median_filt_shape)
            else:
                if len(img_shape) == 2:
                    img = median_filter(img, kernel_size=median_filt_shape)
                elif len(img_shape) > 2:
                    # if dealing with a stack of images, we have to do them one by one
                    img = np.reshape(img, (-1,) + tuple(img_shape[-2:]))
                    for ii in range(img.shape[0]):
                        img[ii, ...] = median_filter(img[ii, ...], kernel_size=median_filt_shape)
                    img = np.reshape(img, img_shape)

        return img

    def _compute_correlation_fft(self, img_1, img_2, padding_mode, axes=(-2, -1), low_pass=None, high_pass=None):
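        """
        Compute the cross-correlation of two images through the Fourier domain.

        If `padding_mode` is None or "wrap", a circular convolution is used.
        Any other `numpy.pad` mode results in a linear convolution: both images
        are padded by half their size along `axes` and the result is cropped
        back after the inverse transform. An optional band-pass filter can be
        applied to the product of the two transforms.
        """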
        do_circular_conv = padding_mode is None or padding_mode == "wrap"
        img_shape = img_2.shape
        if not do_circular_conv:
            pad_size = np.ceil(np.array(img_shape) / 2).astype(int)
            pad_array = [(0,)] * len(img_shape)
            for a in axes:
                pad_array[a] = (pad_size[a],)

            img_1 = np.pad(img_1, pad_array, mode=padding_mode)
            img_2 = np.pad(img_2, pad_array, mode=padding_mode)

        # compute fft's of the 2 images
        img_fft_1 = local_fftn(img_1, axes=axes)
        img_fft_2 = np.conjugate(local_fftn(img_2, axes=axes))

        img_prod = img_fft_1 * img_fft_2

        if low_pass is not None or high_pass is not None:
            filt = fourier_filters.get_bandpass_filter(
                img_2.shape[-2:],  # use the (possibly padded) shape, so that the filter matches img_prod
                cutoff_lowpass=low_pass,
                cutoff_highpass=high_pass,
                use_rfft=__have_scipy__,
                data_type=self.data_type,
            )
            img_prod *= filt

        # inverse fft of the product to get cross_correlation of the 2 images
        cc = np.real(local_ifftn(img_prod, axes=axes))

        if not do_circular_conv:
            cc = np.fft.fftshift(cc, axes=axes)

            slicing = [slice(None)] * len(img_shape)
            for a in axes:
                slicing[a] = slice(pad_size[a], cc.shape[a] - pad_size[a])
            cc = cc[tuple(slicing)]

            cc = np.fft.ifftshift(cc, axes=axes)

        return cc


class CenterOfRotation(AlignmentBase):
    @staticmethod
    def _check_img_sizes(img_1: np.ndarray, img_2: np.ndarray):
        shape_1 = np.squeeze(img_1).shape
        shape_2 = np.squeeze(img_2).shape
        if not len(shape_1) == 2:
            raise ValueError(
                "Images need to be 2-dimensional. Shape of image #1: %s" % (" ".join(("%d" % x for x in shape_1)))
            )
        if not len(shape_2) == 2:
            raise ValueError(
                "Images need to be 2-dimensional. Shape of image #2: %s" % (" ".join(("%d" % x for x in shape_2)))
            )
        if not np.all(shape_1 == shape_2):
            raise ValueError(
                "Images need to be of the same shape. Shape of image #1: %s, image #2: %s"
                % (" ".join(("%d" % x for x in shape_1)), " ".join(("%d" % x for x in shape_2)),)
            )

    def find_shift(
        self,
        img_1: np.ndarray,
        img_2: np.ndarray,
        roi_yxhw=None,
        median_filt_shape=None,
        padding_mode=None,
        peak_fit_radius=1,
        high_pass=None,
        low_pass=None,
    ):
        """Find the Center of Rotation (CoR), given to images.
Nicola Vigano's avatar
Nicola Vigano committed
450

Pierre Paleo's avatar
Fix doc    
Pierre Paleo committed
451
452
        This method finds the half-shift between two opposite images, by
        means of correlation computed in Fourier space.
453

454
455
        The output of this function, allows to compute motor movements for
        aligning the sample rotation axis. Given the following values:
456

        - L1: distance from source to motor
        - L2: distance from source to detector
        - ps: physical pixel size
        - v: output of this function

        displacement of motor = (L1 / L2 * ps) * v
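
        For instance, with L1 = 0.5 m, L2 = 1 m, ps = 1e-6 m and v = 10 pixels
        (illustrative values), the motor displacement would be 5e-6 m.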

        Parameters
        ----------
        img_1: numpy.ndarray
            First image
        img_2: numpy.ndarray
            Second image, it needs to have been flipped already (e.g. using numpy.fliplr).
        roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional
            4 elements vector containing: vertical and horizontal coordinates
            of first pixel, plus height and width of the Region of Interest (RoI).
            Or a 2 elements vector containing: height and width of the
            centered Region of Interest (RoI).
            Default is None -> deactivated.
        median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional
            Shape of the median filter window. Default is None -> deactivated.
        padding_mode: str in numpy.pad's mode list, optional
            Padding mode, which determines the type of convolution. If None or
            'wrap' are passed, this resorts to the traditional circular convolution.
            If 'edge' or 'constant' are passed, it results in a linear convolution.
            Default is the circular convolution.
            All options are:
                None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean'
                | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap'
        peak_fit_radius: int, optional
            Radius size around the max correlation pixel, for sub-pixel fitting.
            Minimum and default value is 1.
        low_pass: float or sequence of two floats
            Low-pass filter properties, as described in `nabu.misc.fourier_filters`
        high_pass: float or sequence of two floats
            High-pass filter properties, as described in `nabu.misc.fourier_filters`

        Raises
        ------
        ValueError
            In case images are not 2-dimensional or have different sizes.

        Returns
        -------
        float
            Estimated center of rotation position from the center of the RoI in pixels.

        Examples
        --------
        The following code computes the center of rotation position for two
        given images in a tomography scan, where the second image is taken at
        180 degrees from the first.

        >>> radio1 = data[0, :, :]
        ... radio2 = np.fliplr(data[1, :, :])
        ... CoR_calc = CenterOfRotation()
        ... cor_position = CoR_calc.find_shift(radio1, radio2)

        Or for noisy images:

        >>> cor_position = CoR_calc.find_shift(radio1, radio2, median_filt_shape=(3, 3))
        """
        self._check_img_sizes(img_1, img_2)

        if peak_fit_radius < 1:
            logging.getLogger(__name__).warning(
                "Parameter peak_fit_radius should be at least 1, given: %d instead." % peak_fit_radius
            )
            peak_fit_radius = 1

        img_shape = img_2.shape
        roi_yxhw = self._determine_roi(img_shape, roi_yxhw)

        img_1 = self._prepare_image(img_1, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape)
        img_2 = self._prepare_image(img_2, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape)

        cc = self._compute_correlation_fft(img_1, img_2, padding_mode, high_pass=high_pass, low_pass=low_pass)

        img_shape = img_2.shape
        cc_vs = np.fft.fftfreq(img_shape[-2], 1 / img_shape[-2])
        cc_hs = np.fft.fftfreq(img_shape[-1], 1 / img_shape[-1])

        (f_vals, fv, fh) = self.extract_peak_region_2d(cc, peak_radius=peak_fit_radius, cc_vs=cc_vs, cc_hs=cc_hs)
        fitted_shifts_vh = self.refine_max_position_2d(f_vals, fv, fh)

        return fitted_shifts_vh[-1] / 2.0

    __call__ = find_shift


class DetectorTranslationAlongBeam(AlignmentBase):
    @staticmethod
    def _check_img_sizes(img_stack: np.ndarray, img_pos: np.ndarray):
        shape_stack = np.squeeze(img_stack).shape
        shape_pos = np.squeeze(img_pos).shape
        if not len(shape_stack) == 3:
            raise ValueError(
                "A stack of 2-dimensional images is required. Shape of stack: %s" % (" ".join(("%d" % x for x in shape_stack)))
            )
        if not len(shape_pos) == 1:
            raise ValueError(
                "Positions need to be a 1-dimensional array. Shape of the positions variable: %s"
                % (" ".join(("%d" % x for x in shape_pos)))
            )
        if not shape_stack[0] == shape_pos[0]:
            raise ValueError(
                "The same number of images and positions is required."
                + " Shape of stack: %s, shape of positions variable: %s"
                % (" ".join(("%d" % x for x in shape_stack)), " ".join(("%d" % x for x in shape_pos)),)
            )

    def find_shift(
        self,
        img_stack: np.ndarray,
        img_pos: np.ndarray,
        roi_yxhw=None,
        median_filt_shape=None,
        padding_mode=None,
        peak_fit_radius=1,
        high_pass=None,
        low_pass=None,
        return_shifts=False,
        use_adjacent_imgs=False,
    ):
        """Find the vertical and horizontal shifts for translations of the
        detector along the beam direction.

        These shifts are in pixels-per-unit-translation, and they are due to
        the misalignment of the translation stage, with respect to the beam
        propagation direction.

        To compute the vertical and horizontal tilt angles from the obtained `shift_pix`:

        >>> tilt_deg = np.rad2deg(np.arctan(shift_pix * pixel_size))

        where `pixel_size` and the input parameter `img_pos` have to be
        expressed in the same units.

        Parameters
        ----------
        img_stack: numpy.ndarray
            A stack of images (usually 4) at different distances
        img_pos: numpy.ndarray
            Position of the images along the translation axis
        roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional
            4 elements vector containing: vertical and horizontal coordinates
            of first pixel, plus height and width of the Region of Interest (RoI).
            Or a 2 elements vector containing: height and width of the
            centered Region of Interest (RoI).
            Default is None -> deactivated.
        median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional
            Shape of the median filter window. Default is None -> deactivated.
        padding_mode: str in numpy.pad's mode list, optional
            Padding mode, which determines the type of convolution. If None or
            'wrap' are passed, this resorts to the traditional circular convolution.
            If 'edge' or 'constant' are passed, it results in a linear convolution.
            Default is the circular convolution.
            All options are:
                None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean'
                | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap'
        peak_fit_radius: int, optional
            Radius size around the max correlation pixel, for sub-pixel fitting.
            Minimum and default value is 1.
        low_pass: float or sequence of two floats
            Low-pass filter properties, as described in `nabu.misc.fourier_filters`.
        high_pass: float or sequence of two floats
            High-pass filter properties, as described in `nabu.misc.fourier_filters`.
        return_shifts: boolean, optional
            If True, the shifts found for each image of the stack are returned
            together with the increment per unit-distance.
            This option is intended for introspection. Defaults to False.
        use_adjacent_imgs: boolean, optional
            Compute correlation between adjacent images.
            It can be used when dealing with large shifts, to avoid overflowing the shift.
            This option allows to replicate the behavior of the reference function `alignxc.m`.
            However, it is detrimental to shift fitting accuracy. Defaults to False.

        Returns
        -------
        tuple(float, float)
            Estimated (vertical, horizontal) increment per unit-distance of the
            ratio between pixel-size and detector translation.
            Optionally, the per-image (vertical, horizontal) shifts are also
            returned, if the argument `return_shifts` is True.

        Examples
        --------
        The following example creates a stack of shifted images, and retrieves the computed shift.
        Here we use a high-pass filter, due to the presence of some low-frequency noise component.

        >>> import numpy as np
        ... import scipy as sp
        ... import scipy.ndimage
        ... from nabu.preproc.alignment import DetectorTranslationAlongBeam
        ...
        ... tr_calc = DetectorTranslationAlongBeam()
        ...
        ... stack = np.zeros([4, 512, 512])
        ...
        ... # Add low frequency spurious component
        ... for i in range(4):
        ...     stack[i, 200 - i * 10, 200 - i * 10] = 1
        ... stack = sp.ndimage.gaussian_filter(stack, [0, 10, 10.0]) * 100
        ...
        ... # Add the feature
        ... x, y = np.meshgrid(np.arange(stack.shape[-1]), np.arange(stack.shape[-2]))
        ... for i in range(4):
        ...     xc = x - (250 + i * 1.234)
        ...     yc = y - (250 + i * 1.234 * 2)
        ...     stack[i] += np.exp(-(xc * xc + yc * yc) * 0.5)
        ...
        ... # Image translation along the beam
        ... img_pos = np.arange(4)
        ...
        ... # Find the shifts from the features
        ... shifts_v, shifts_h = tr_calc.find_shift(stack, img_pos, high_pass=1.0)
        ... print(shifts_v, shifts_h)
        >>> ( -2.47 , -1.236 )

        and the following commands convert the shifts into angular tilts:

        >>> tilt_v_deg = np.rad2deg(np.arctan(shifts_v * pixel_size))
        >>> tilt_h_deg = np.rad2deg(np.arctan(shifts_h * pixel_size))

        To enable the legacy behavior of `alignxc.m` (correlation between adjacent images):

        >>> shifts_v, shifts_h = tr_calc.find_shift(stack, img_pos, use_adjacent_imgs=True)

        To plot the correlation shifts and the fitted straight lines for both directions:

        >>> tr_calc = DetectorTranslationAlongBeam(verbose=True)
        ... shifts_v, shifts_h = tr_calc.find_shift(stack, img_pos)
        """
        self._check_img_sizes(img_stack, img_pos)

        if peak_fit_radius < 1:
            logging.getLogger(__name__).warning(
                "Parameter peak_fit_radius should be at least 1, given: %d instead." % peak_fit_radius
            )
            peak_fit_radius = 1

        num_imgs = img_stack.shape[0]
        img_shape = img_stack.shape[-2:]
        roi_yxhw = self._determine_roi(img_shape, roi_yxhw)

        img_stack = self._prepare_image(img_stack, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape)

        # do correlations
        ccs = [
            self._compute_correlation_fft(
                img_stack[ii - 1 if use_adjacent_imgs else 0, ...],
                img_stack[ii, ...],
                padding_mode,
                high_pass=high_pass,
                low_pass=low_pass,
            )
            for ii in range(1, num_imgs)
        ]

        img_shape = img_stack.shape[-2:]
        cc_vs = np.fft.fftfreq(img_shape[-2], 1 / img_shape[-2])
        cc_hs = np.fft.fftfreq(img_shape[-1], 1 / img_shape[-1])

        # The first image is the reference: its shift is zero by definition
        shifts_vh = np.zeros((num_imgs, 2))
        for ii, cc in enumerate(ccs):
            (f_vals, fv, fh) = self.extract_peak_region_2d(cc, peak_radius=peak_fit_radius, cc_vs=cc_vs, cc_hs=cc_hs)
            shifts_vh[ii + 1, :] = self.refine_max_position_2d(f_vals, fv, fh)

        if use_adjacent_imgs:
            shifts_vh = np.cumsum(shifts_vh, axis=0)

        # Polynomial.fit is supposed to be more numerically stable than polyfit
        # (according to numpy)
        coeffs_v = Polynomial.fit(img_pos, shifts_vh[:, 0], deg=1).convert().coef
        coeffs_h = Polynomial.fit(img_pos, shifts_vh[:, 1], deg=1).convert().coef

        if self.verbose:
            f, axs = plt.subplots(1, 2)
            axs[0].scatter(img_pos, shifts_vh[:, 0])
            axs[0].plot(img_pos, polyval(img_pos, coeffs_v))
            axs[0].set_title("Vertical shifts")
            axs[1].scatter(img_pos, shifts_vh[:, 1])
            axs[1].plot(img_pos, polyval(img_pos, coeffs_h))
            axs[1].set_title("Horizontal shifts")
            plt.show(block=False)

        if return_shifts:
            return coeffs_v[1], coeffs_h[1], shifts_vh
        else:
            return coeffs_v[1], coeffs_h[1]