alignment.py 18.5 KB
Newer Older
Nicola Vigano's avatar
Nicola Vigano committed
1
import numpy as np
myron's avatar
myron committed
2
import math
3
from numpy.polynomial.polynomial import polyroots
4

5
from nabu.utils import previouspow2
6

7
8
try:
    from scipy.ndimage.filters import median_filter
myron's avatar
myron committed
9
    import scipy.special as spspe
10

11
12
13
    __have_scipy__ = True
except ImportError:
    from silx.math.medianfilter import medfilt2d as median_filter
14

15
    __have_scipy__ = False
16

myron's avatar
myron committed
17
18
19
20
21
22
23
24
25
26
27
28
29
30

def _get_lowpass_filter(img_shape, cutoff_pix, cutoff_trans_pix):
    """Computes a low pass filter with the erfc.

    :param img_shape: Shape of the image
    :type img_shape: tuple
    :param cutoff_pix: Position of the cut off in pixels
    :type cutoff_pix: float
    :param cutoff_trans_pix: Size of the cut off transition in pixels
    :type cutoff_trans_pix: float

    :return: The computed filter
    :rtype: `numpy.array_like`
    """
    # Frequency coordinates (cycles / pixel) along each axis of the image
    coords = [np.fft.fftfreq(s, 1) for s in img_shape]
    coords = np.meshgrid(*coords, indexing="ij")

    # Convert the pixel-space cut-off parameters into frequency space
    Kcut = 0.5 / cutoff_pix
    Kcut_width = 0.5 / cutoff_trans_pix

    # Radial frequency, recentered on the cut-off and rescaled by the
    # transition width: erfc / 2 then falls smoothly from 1 to 0 across it
    r = np.sqrt(np.sum(np.array(coords) ** 2, axis=0))
    myargs = (r - Kcut) / Kcut_width

    if __have_scipy__:
        res = spspe.erfc(myargs) / 2
    else:
        # math.erfc is scalar-only: the original map() over the 2D `myargs`
        # iterated over rows and passed whole arrays to it, raising TypeError.
        # Apply it element-wise over the flat iterator and restore the shape.
        res = np.reshape([math.erfc(v) for v in myargs.flat], myargs.shape) / 2

    return res
myron's avatar
myron committed
44

45

myron's avatar
myron committed
46
47
48
class AlignmentBase(object):
    @staticmethod
    def _biquadratic_refinement(cc_vals):
        """Refines the position of the maximum of a 3x3 patch to sub-pixel precision.

        A bi-quadratic surface is least-squares fitted to the 9 samples, and
        the point where both partial derivatives vanish is returned, clipped
        to [-1, 1] (the true maximum cannot be further than one pixel away
        from the sampled maximum).

        :param cc_vals: 3x3 patch of correlation values, centered on the
            discrete maximum
        :type cc_vals: `numpy.array_like`

        :return: (delta_y, delta_x) sub-pixel offset of the fitted extremum
        :rtype: `numpy.array_like`
        """
        # Row (Y) / column (X) coordinates of the 3x3 patch. With
        # indexing="ij", Y varies along the first (row) axis, matching the
        # row-major flattening of cc_vals. The original default "xy" meshgrid
        # produced a solution ordered (x, y), while the caller unpacks it as
        # (Dpy, Dpx) — i.e. (y, x) — so the two axes were swapped.
        Y, X = np.meshgrid([-1, 0, 1], [-1, 0, 1], indexing="ij")
        X = X.flatten()
        Y = Y.flatten()
        coords = np.array([np.ones(X.size), Y, X, Y * X, Y * Y, X * X])

        # Least-squares fit: f(y, x) = a0 + a1 y + a2 x + a3 yx + a4 y^2 + a5 x^2
        a = np.linalg.lstsq(coords.T, cc_vals.flatten(), rcond=None)[0]
        # Stationary point of f: solve the 2x2 linear system grad f = 0
        zero_derivative_coords = np.linalg.lstsq(
            [[2 * a[4], a[3]], [a[3], 2 * a[5]]], [-a[1], -a[2]], rcond=None
        )[0]
        zero_derivative_coords = np.clip(zero_derivative_coords, -1, 1)

        return zero_derivative_coords

    @staticmethod
    def _determine_roi(img_shape, roi_yxhw, do_truncate_horz_pow2):
        """Computes a 4-element RoI (y, x, height, width) from the user input.

        :param img_shape: Shape of the images (vertical, horizontal)
        :param roi_yxhw: None, a 2-element centered (height, width) RoI, or a
            full 4-element (y, x, height, width) RoI
        :param do_truncate_horz_pow2: If True, also truncate the horizontal
            window to a power of 2 (the vertical one always is, for fft speed)

        :return: 4-element (y, x, height, width) RoI
        :rtype: `numpy.array_like`
        """
        if roi_yxhw is None:
            # vertical window size is reduced to a power of 2 to accelerate fft
            # same thing horizontal window - if requested. Default is not
            roi_yxhw = previouspow2(img_shape)
            if not do_truncate_horz_pow2:
                roi_yxhw[1] = img_shape[1]

        if len(roi_yxhw) == 2:  # Convert centered 2-element roi into 4-element
            # `np.int` was removed in numpy >= 1.24: plain `int` is equivalent
            roi_yxhw = np.array(roi_yxhw, dtype=int)
            roi_yxhw = np.concatenate(((img_shape - roi_yxhw) // 2, roi_yxhw))
        return roi_yxhw

    @staticmethod
    def _prepare_image(
        img,
        invalid_val=1e-5,
        roi_yxhw=None,
        median_filt_shape=None,
        cutoff_pix=None,
        cutoff_trans_pix=None,
    ):
        """Prepares an image (or a stack of images) for alignment.

        Squeezes singleton dimensions, optionally crops to a RoI, replaces
        invalid (NaN / infinite) values, then optionally applies a low-pass
        filter and a median filter.

        :param img: Image, or stack of images (stack axis first)
        :param invalid_val: Value replacing NaNs and infinities. Default 1e-5
        :param roi_yxhw: 4-element (y, x, height, width) RoI, or None
        :param median_filt_shape: 2D shape of the median filter window, or None
        :param cutoff_pix: Low-pass filter cut-off in pixels, or None
        :param cutoff_trans_pix: Low-pass cut-off transition in pixels, or None

        :return: The prepared image (or stack)
        :rtype: `numpy.array_like`
        """
        img = np.squeeze(img)  # Removes singleton dimensions, but does a shallow copy
        img = np.ascontiguousarray(img)

        if roi_yxhw is not None:
            img = img[
                ...,
                roi_yxhw[0] : roi_yxhw[0] + roi_yxhw[2],
                roi_yxhw[1] : roi_yxhw[1] + roi_yxhw[3],
            ]

        img = img.copy()  # Deep copy: invalid values are patched in place below

        img[np.isnan(img)] = invalid_val
        img[np.isinf(img)] = invalid_val

        if cutoff_pix is not None and cutoff_trans_pix is not None:
            myfilter = _get_lowpass_filter(img.shape[-2:], cutoff_pix, cutoff_trans_pix)
            img_shape = img.shape
            if len(img_shape) == 2:
                img = np.fft.ifft2(np.fft.fft2(img) * myfilter).real
            elif len(img_shape) > 2:
                # if dealing with a stack of images, we have to do them one by one
                img = np.reshape(img, (-1,) + tuple(img_shape[-2:]))
                for ii in range(img.shape[0]):
                    img[ii, ...] = np.fft.ifft2(
                        np.fft.fft2(img[ii, ...]) * myfilter
                    ).real
                img = np.reshape(img, img_shape)

        if median_filt_shape is not None:
            img_shape = img.shape
            if __have_scipy__:
                # expanding filter shape with ones, to cover the stack of images
                # but disabling inter-image filtering
                # (`np.int` was removed in numpy >= 1.24: plain `int` instead)
                median_filt_shape = np.concatenate(
                    (
                        np.ones((len(img_shape) - len(median_filt_shape),), dtype=int),
                        median_filt_shape,
                    )
                )
                img = median_filter(img, size=median_filt_shape)
            else:
                # silx's medfilt2d only handles 2D arrays
                if len(img_shape) == 2:
                    img = median_filter(img, kernel_size=median_filt_shape)
                elif len(img_shape) > 2:
                    # if dealing with a stack of images, we have to do them one by one
                    img = np.reshape(img, (-1,) + tuple(img_shape[-2:]))
                    for ii in range(img.shape[0]):
                        img[ii, ...] = median_filter(
                            img[ii, ...], kernel_size=median_filt_shape
                        )
                    img = np.reshape(img, img_shape)

        return img

    @staticmethod
    def _compute_correlation_fft(img_1, img_2, padding_mode):
        """Computes the cross-correlation of two images in Fourier space.

        :param img_1: First image
        :param img_2: Second image (same shape as img_1)
        :param padding_mode: None or "wrap" for a circular convolution; any
            other numpy.pad mode results in a linear convolution

        :return: Cross-correlation image, with the same shape as the inputs
        :rtype: `numpy.array_like`
        """
        do_circular_conv = padding_mode is None or padding_mode == "wrap"
        if not do_circular_conv:
            # Pad by half the image size to avoid wrap-around artifacts.
            # Fixes: ndarrays expose `.shape`, not `.img_shape` (the original
            # raised AttributeError here), and `np.int` was removed from numpy.
            padding = np.ceil(np.array(img_2.shape) / 2).astype(int)
            img_1 = np.pad(img_1, ((padding[0],), (padding[1],)), mode=padding_mode)
            img_2 = np.pad(img_2, ((padding[0],), (padding[1],)), mode=padding_mode)

        # compute fft's of the 2 images
        img_fft_1 = np.fft.fft2(img_1)
        img_fft_2 = np.conjugate(np.fft.fft2(img_2))
        # inverse fft of the product to get cross_correlation of the 2 images
        cc = np.real(np.fft.ifft2(img_fft_1 * img_fft_2))

        if not do_circular_conv:
            # Crop the padding back off, restoring the original shape
            cc = cc[padding[0] : -padding[0], padding[1] : -padding[1]]

        return cc

    def refine_max_position(self, xf, yf):
        """Computes the sub-pixel max position of the given function sampling.

        Parameters
        ----------
        xf: numpy.ndarray
            Coordinates of the sampled points
        yf: numpy.ndarray
            Function values of the sampled points

        Raises
        ------
        ValueError
            In case position and values do not have the same size.

        Returns
        -------
        float
            Estimated coordinate of the function maximum, expressed in the
            coordinates given by xf.
        """
        npix = xf.size
        if not npix == yf.size:
            raise ValueError(
                "Coordinates and values should have the same length. Sizes of xf: %d, yf: %d"
                % (xf.size, yf.size)
            )

        # Vandermonde matrix with decreasing powers: x^(npix-1) ... x^0
        matr = xf[:, None] ** (npix - np.arange(1, npix + 1))

        # kf is the coeffs of the polynom passing trough the npix pixels
        kf = np.linalg.inv(matr).dot(yf[:, None])

        # compute coefficients of the 1st order derivative of the polynom
        cf = ((npix - 1) - np.arange(npix - 1)) * kf[:-1, 0]

        # revert for polyroots (it expects ascending powers)
        cf = np.flip(cf)

        # maximum is the 0 of the 1st derivative of the polynom
        zeros_pol = polyroots(cf)

        # The pixel within the roots of the derivative of the polynom is
        # the one near the approximate center of rotation (CoR) px
        approx_max = xf[np.argmax(yf)]
        closest_pix_ind = np.argmin((zeros_pol - approx_max) ** 2)
        return np.real(zeros_pol[closest_pix_ind])

208
209

class CenterOfRotation(AlignmentBase):
    def __init__(self, poly_deg=2, horz_fft_width=False):
        """
        Center of Rotation (CoR) computation object.
        This class is used on radios.

        Parameters
        ----------
        poly_deg: int, optional
            Degree of polynom to interpolate the global value to a
            sub-pixel precision. The number of pixels involved will be

            >>> npix = poly_deg + 1

        horz_fft_width: boolean, optional
            If True, restrict the horizontal size to a power of 2:

            >>> new_x_dim = 2 ** math.floor(math.log2(x_dim))
        """
        self._init_parameters(poly_deg, horz_fft_width)

    def _init_parameters(self, poly_deg, horz_fft_width):
        # Stores the configuration; npix is the sampling width for the
        # polynomial sub-pixel refinement
        self.truncate_horz_pow2 = horz_fft_width
        self.npix = poly_deg + 1

    @staticmethod
    def _check_img_sizes(img_1: np.ndarray, img_2: np.ndarray):
        """Validates that both images are 2D and share the same shape.

        Raises a ValueError otherwise.
        """
        shape_1 = np.squeeze(img_1).shape
        shape_2 = np.squeeze(img_2).shape
        if not len(shape_1) == 2:
            raise ValueError(
                "Images need to be 2-dimensional. Shape of image #1: %s"
                % (" ".join(("%d" % x for x in shape_1)))
            )
        if not len(shape_2) == 2:
            raise ValueError(
                "Images need to be 2-dimensional. Shape of image #2: %s"
                % (" ".join(("%d" % x for x in shape_2)))
            )
        if not np.all(shape_1 == shape_2):
            raise ValueError(
                "Images need to be of the same shape. Shape of image #1: %s, image #2: %s"
                % (
                    " ".join(("%d" % x for x in shape_1)),
                    " ".join(("%d" % x for x in shape_2)),
                )
            )

    def find_shift(
        self,
        img_1: np.ndarray,
        img_2: np.ndarray,
        roi_yxhw=None,
        median_filt_shape=None,
        padding_mode=None,
        cutoff_pix=None,
        cutoff_trans_pix=None,
    ):
        """Find the Center of Rotation (CoR), given two images.

        This method finds the half-shift between two opposite images, by
        means of correlation computed in Fourier space.

        The output of this function, allows to compute motor movements for
        aligning the sample rotation axis. Given the following values:

        - L1: distance from source to motor
        - L2: distance from source to detector
        - ps: physical pixel size
        - v: output of this function

        displacement of motor = (L1 / L2 * ps) * v

        Parameters
        ----------
        img_1: numpy.ndarray
            First image
        img_2: numpy.ndarray
            Second image, it needs to have been flipped already (e.g. using numpy.fliplr).
        roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional
            4 elements vector containing: vertical and horizontal coordinates
            of first pixel, plus height and width of the Region of Interest (RoI).
            Or a 2 elements vector containing the height and width of the
            centered Region of Interest (RoI).
            Default is None -> deactivated.
        median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional
            Shape of the median filter window. Default is None -> deactivated.
        padding_mode: str in numpy.pad's mode list, optional
            Padding mode, which determines the type of convolution. If None or
            'wrap' are passed, this resorts to the traditional circular convolution.
            If 'edge' or 'constant' are passed, it results in a linear convolution.
            Default is the circular convolution.
            All options are:
                None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean'
                | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap'
        cutoff_pix: float, optional
            Position of the low-pass filter cut off, in pixels.
            Default is None -> deactivated.
        cutoff_trans_pix: float, optional
            Size of the low-pass filter cut off transition, in pixels.
            Default is None -> deactivated.

        Raises
        ------
        ValueError
            In case images are not 2-dimensional or have different sizes.

        Returns
        -------
        float
            Estimated center of rotation position from the center of the RoI in pixels.

        Examples
        --------
        The following code computes the center of rotation position for two
        given images in a tomography scan, where the second image is taken at
        180 degrees from the first.

        >>> radio1 = data[0, :, :]
        ... radio2 = np.fliplr(data[1, :, :])
        ... CoR_calc = CenterOfRotation(poly_deg=2)
        ... cor_position = CoR_calc.find_shift(radio1, radio2)

        Or for noisy images:

        >>> cor_position = CoR_calc.find_shift(radio1, radio2, median_filt_shape=(3, 3))
        """
        self._check_img_sizes(img_1, img_2)

        img_shape = img_2.shape
        roi_yxhw = self._determine_roi(img_shape, roi_yxhw, self.truncate_horz_pow2)

        img_1 = self._prepare_image(
            img_1,
            roi_yxhw=roi_yxhw,
            median_filt_shape=median_filt_shape,
            cutoff_pix=cutoff_pix,
            cutoff_trans_pix=cutoff_trans_pix,
        )
        img_2 = self._prepare_image(
            img_2,
            roi_yxhw=roi_yxhw,
            median_filt_shape=median_filt_shape,
            cutoff_pix=cutoff_pix,
            cutoff_trans_pix=cutoff_trans_pix,
        )

        cc = self._compute_correlation_fft(img_1, img_2, padding_mode)

        img_shape = cc.shape  # Re-define image shape after RoI

        # get pixel having the maximum value of the correlation array
        pix_max_corr = np.argmax(cc)
        NY, NX = cc.shape
        py, px = np.unravel_index(pix_max_corr, (NY, NX))

        payload_vals = np.zeros([3, 3], "f")

        # Clamp the 3x3 neighborhood of the maximum to the image bounds.
        # These are *inclusive* indices: the original `min(NY, py + 1)` /
        # `min(NX, px + 1)` overflowed by one when the maximum sat on the
        # last row/column, making the assignment below fail with a shape
        # mismatch.
        ly = max(0, py - 1)
        lx = max(0, px - 1)
        hy = min(NY - 1, py + 1)
        hx = min(NX - 1, px + 1)
        # Fill with the local minimum, so clipped-off cells do not bias the fit
        payload_vals[:] = (cc[ly : hy + 1, lx : hx + 1]).min()

        payload_vals[
            (ly - (py - 1)) : (hy + 1 - (py - 1)), (lx - (px - 1)) : (hx + 1 - (px - 1))
        ] = cc[ly : hy + 1, lx : hx + 1]

        # Sub-pixel (y, x) refinement of the maximum position
        Dpy, Dpx = self._biquadratic_refinement(payload_vals)

        ppy = py + Dpy
        ppx = px + Dpx

        # Positions past the center correspond to negative shifts (fft layout)
        if ppy > (cc.shape[0] - 1) / 2:
            ppy = ppy - cc.shape[0]
        if ppx > (cc.shape[1] - 1) / 2:
            ppx = ppx - cc.shape[1]

        fitted_shift = ppx

        # The CoR offset is half the shift between the two opposite images
        return fitted_shift / 2.0

    __call__ = find_shift
386
387
388
389
390
391
392
393


class DetectorTranslationAlongBeam(AlignmentBase):
    @staticmethod
    def _check_img_sizes(img_stack: np.ndarray, img_pos: np.ndarray):
        """Validates the stack (3D) and positions (1D, same leading size).

        Raises a ValueError otherwise.
        """
        shape_stack = np.squeeze(img_stack).shape
        shape_pos = np.squeeze(img_pos).shape
        if not len(shape_stack) == 3:
            raise ValueError(
                "A stack of 2-dimensional images is required. Shape of stack: %s"
                % (" ".join(("%d" % x for x in shape_stack)))
            )
        if not len(shape_pos) == 1:
            raise ValueError(
                "Positions need to be a 1-dimensional array. Shape of the positions variable: %s"
                % (" ".join(("%d" % x for x in shape_pos)))
            )
        if not shape_stack[0] == shape_pos[0]:
            raise ValueError(
                "The same number of images and positions is required."
                + " Shape of stack: %s, shape of positions variable: %s"
                % (
                    " ".join(("%d" % x for x in shape_stack)),
                    " ".join(("%d" % x for x in shape_pos)),
                )
            )

    def find_shift(
        self,
        img_stack: np.ndarray,
        img_pos: np.ndarray,
        roi_yxhw=None,
        median_filt_shape=None,
        padding_mode=None,
        poly_deg=2,
    ):
        """Find the deviation of the translation axis of the area detector
        along the beam propagation direction.

        TODO: Add more information here! Including interpretation of the result

        Parameters
        ----------
        img_stack: numpy.ndarray
            A stack of images (usually 4) at different distances
        img_pos: numpy.ndarray
            Position of the images along the translation axis
        roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional
            4 elements vector containing: vertical and horizontal coordinates
            of first pixel, plus height and width of the Region of Interest (RoI).
            Or a 2 elements vector containing the height and width of the
            centered Region of Interest (RoI).
            Default is None -> deactivated.
        median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional
            Shape of the median filter window. Default is None -> deactivated.
        padding_mode: str in numpy.pad's mode list, optional
            Padding mode, which determines the type of convolution. If None or
            'wrap' are passed, this resorts to the traditional circular convolution.
            If 'edge' or 'constant' are passed, it results in a linear convolution.
            Default is the circular convolution.
            All options are:
                None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean'
                | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap'
        poly_deg: int, optional
            Degree of the polynomial used for the sub-pixel refinement of the
            correlation maxima. Default is 2.

        Returns
        -------
        tuple(float, float)
            Estimated (vertical, horizontal) increment per unit-distance of the shift.

        Examples
        --------
        TODO: Add examples here!
        """
        self._check_img_sizes(img_stack, img_pos)

        num_imgs = img_stack.shape[0]
        img_shape = img_stack.shape[-2:]
        roi_yxhw = self._determine_roi(img_shape, roi_yxhw, False)

        img_stack = self._prepare_image(
            img_stack, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape
        )

        # Re-read the image shape after the RoI was applied: the correlation
        # images have the cropped size, not the original one (the original
        # code kept the stale, pre-crop shape, which broke unravel_index and
        # the frequency axes below whenever a RoI was in effect — which is
        # the default, since `_determine_roi` truncates to a power of 2).
        img_shape = img_stack.shape[-2:]

        # Shift axes in fft layout: 0, 1, ..., -2, -1 (pixels)
        cc_vs = np.fft.fftfreq(img_shape[-2], 1 / img_shape[-2])
        cc_hs = np.fft.fftfreq(img_shape[-1], 1 / img_shape[-1])

        # do correlations of every image against the first one
        ccs = [
            self._compute_correlation_fft(
                img_stack[0, ...], img_stack[ii, ...], padding_mode
            )
            for ii in range(1, num_imgs)
        ]

        npix = poly_deg + 1
        poly_center = npix // 2

        shifts_v, shifts_h = np.empty((num_imgs - 1,)), np.empty((num_imgs - 1,))
        for ii, cc in enumerate(ccs):
            # get pixel having the maximum value of the correlation array
            pix_max_corr = np.argmax(cc)
            pv, ph = np.unravel_index(pix_max_corr, img_shape)

            # horizontal sub-pixel refinement: sample npix points around the
            # maximum, wrapping at the image edges (`np.int` was removed in
            # numpy >= 1.24: plain `int` instead)
            xh_pos = ((ph - poly_center + np.arange(npix)) % img_shape[1]).astype(int)
            shifts_h[ii] = self.refine_max_position(cc_hs[xh_pos], cc[pv, xh_pos])

            # vertical sub-pixel refinement, same scheme
            xv_pos = ((pv - poly_center + np.arange(npix)) % img_shape[0]).astype(int)
            shifts_v[ii] = self.refine_max_position(cc_vs[xv_pos], cc[xv_pos, ph])

        # Linear fit of the shifts against the translation positions: the
        # slopes are the per-unit-distance increments promised in the
        # docstring (the original implementation computed the shifts but
        # never returned anything).
        # TODO: also report the corresponding angle corrections (thy / thz),
        # as sketched in the original MATLAB prototype.
        pos_diffs = np.asarray(img_pos[1:], dtype=float) - img_pos[0]
        coeffs_v = np.polyfit(pos_diffs, shifts_v, 1)
        coeffs_h = np.polyfit(pos_diffs, shifts_h, 1)
        return coeffs_v[0], coeffs_h[0]