Skip to content
alignment.py 24.6 KiB
Newer Older
import numpy as np
import logging
from numpy.polynomial.polynomial import Polynomial
from nabu.utils import previouspow2
from nabu.misc import fourier_filters
try:
    from scipy.ndimage.filters import median_filter
    __have_scipy__ = True
except ImportError:
    from silx.math.medianfilter import medfilt2d as median_filter
    __have_scipy__ = False
myron's avatar
myron committed
class AlignmentBase(object):
    """Common tools for correlation-based image alignment.

    Provides image preparation (cropping, invalid-value cleaning, Fourier and
    median filtering), FFT-based cross-correlation, and sub-pixel refinement
    of the correlation peak position.
    """

    @staticmethod
    def refine_max_position_2d(f_vals: np.ndarray, fy=None, fx=None):
        """Computes the sub-pixel max position of the given function sampling.

        A second-order 2D polynomial is fitted (least squares) to the sampled
        values, and the position of its vertex is returned.

        Parameters
        ----------
        f_vals: numpy.ndarray
            Function values of the sampled points, as a 2-dimensional array.
        fy: numpy.ndarray, optional
            Vertical coordinates of the sampled points. If None, unit-spaced
            coordinates centered on 0 are used.
        fx: numpy.ndarray, optional
            Horizontal coordinates of the sampled points. If None, unit-spaced
            coordinates centered on 0 are used.

        Raises
        ------
        ValueError
            In case the values are not 2-dimensional, in case position and
            values do not have the same size, or in case the fitted maximum is
            outside the fitting region.

        Returns
        -------
        numpy.ndarray
            Estimated (vertical, horizontal) function max, according to the
            coordinates in fy and fx.
        """
        f_vals = np.asarray(f_vals)
        # Exactly 2 dimensions are required: the code below indexes shape[0] and shape[1]
        if f_vals.ndim != 2:
            raise ValueError(
                "The fitted values should form a 2-dimensional array. Array of shape: [%s] was given."
                % (" ".join(("%d" % s for s in f_vals.shape)))
            )
        if fy is None:
            fy_half_size = (f_vals.shape[0] - 1) / 2
            fy = np.linspace(-fy_half_size, fy_half_size, f_vals.shape[0])
        else:
            fy = np.asarray(fy)
            if not (fy.ndim == 1 and fy.size == f_vals.shape[0]):
                raise ValueError(
                    "Vertical coordinates should have the same length as values matrix. Sizes of fy: %d, f_vals: [%s]"
                    % (fy.size, " ".join(("%d" % s for s in f_vals.shape)))
                )
        if fx is None:
            fx_half_size = (f_vals.shape[1] - 1) / 2
            fx = np.linspace(-fx_half_size, fx_half_size, f_vals.shape[1])
        else:
            fx = np.asarray(fx)
            if not (fx.ndim == 1 and fx.size == f_vals.shape[1]):
                raise ValueError(
                    "Horizontal coordinates should have the same length as values matrix. Sizes of fx: %d, f_vals: [%s]"
                    % (fx.size, " ".join(("%d" % s for s in f_vals.shape)))
                )

        fy, fx = np.meshgrid(fy, fx, indexing="ij")
        fy = fy.flatten()
        fx = fx.flatten()
        # Model: f(y, x) = c0 + c1*y + c2*x + c3*y*x + c4*y^2 + c5*x^2
        coords = np.array([np.ones(f_vals.size), fy, fx, fy * fx, fy ** 2, fx ** 2])
        coeffs = np.linalg.lstsq(coords.T, f_vals.flatten(), rcond=None)[0]
        # For a 1D parabola `f(x) = ax^2 + bx + c`, the vertex position is:
        # x_v = -b / 2a. For a 2D parabola, the vertex solves the zero-gradient
        # system A @ (y, x)_v = -b, with A the Hessian of the model.
        A = [[2 * coeffs[4], coeffs[3]], [coeffs[3], 2 * coeffs[5]]]
        b = coeffs[1:3]
        # lstsq (rather than solve) tolerates a singular Hessian
        vertex_yx = np.linalg.lstsq(A, -b, rcond=None)[0]

        vertex_min_yx = [np.min(fy), np.min(fx)]
        vertex_max_yx = [np.max(fy), np.max(fx)]
        if np.any(vertex_yx < vertex_min_yx) or np.any(vertex_yx > vertex_max_yx):
            raise ValueError(
                "Fitted (y: {}, x: {}) positions are outside the margins of input: y: [{}, {}], x: [{}, {}]".format(
                    vertex_yx[0], vertex_yx[1], vertex_min_yx[0], vertex_max_yx[0], vertex_min_yx[1], vertex_max_yx[1]
                )
            )
        return vertex_yx

    @staticmethod
    def refine_max_position_1d(f_vals, fx=None):
        """Computes the sub-pixel max position of the given function sampling.

        A parabola is fitted to the sampled values, and the position of its
        vertex is returned.

        Parameters
        ----------
        f_vals: numpy.ndarray
            Function values of the sampled points, as a 1-dimensional array.
        fx: numpy.ndarray, optional
            Coordinates of the sampled points. If None, unit-spaced
            coordinates centered on 0 are used.

        Raises
        ------
        ValueError
            In case the values are not 1-dimensional, in case position and
            values do not have the same size, or in case the fitted maximum is
            outside the fitting region.

        Returns
        -------
        float
            Estimated function max, according to the coordinates in fx.
        """
        f_vals = np.asarray(f_vals)
        if f_vals.ndim != 1:
            raise ValueError(
                "The fitted values should form a 1-dimensional array. Array of shape: [%s] was given."
                % (" ".join(("%d" % s for s in f_vals.shape)))
            )
        if fx is None:
            fx_half_size = (f_vals.size - 1) / 2
            fx = np.linspace(-fx_half_size, fx_half_size, f_vals.size)
        else:
            fx = np.asarray(fx)
            if not (fx.ndim == 1 and fx.size == f_vals.size):
                raise ValueError(
                    "Base coordinates should have the same length as values array. Sizes of fx: %d, f_vals: %d"
                    % (fx.size, f_vals.size)
                )

        # using Polynomial.fit, because supposed to be more numerically stable
        # than previous solutions (according to numpy).
        poly = Polynomial.fit(fx, f_vals, deg=2)
        # convert() drops trailing coefficients that are exactly zero: pad back to 3 terms
        converted = poly.convert().coef
        coeffs = np.zeros(3)
        coeffs[: converted.size] = converted
        # For a 1D parabola `f(x) = c + bx + ax^2`, the vertex position is:
        # x_v = -b / 2a. A zero curvature gives inf/nan, rejected by the range check below.
        with np.errstate(divide="ignore", invalid="ignore"):
            vertex_x = -coeffs[1] / (2 * coeffs[2])
        vertex_min_x = np.min(fx)
        vertex_max_x = np.max(fx)
        if not (vertex_min_x < vertex_x < vertex_max_x):
            raise ValueError(
                "Fitted x: {} position is outside the margins of input: x: [{}, {}]".format(
                    vertex_x, vertex_min_x, vertex_max_x
                )
            )
        return vertex_x

    @staticmethod
    def extract_peak_region(cc, peak_radius=1, cc_vs=None, cc_hs=None):
        """
        Extracts a region around the maximum value.

        Parameters
        ----------
        cc: numpy.ndarray
            Correlation image.
        peak_radius: int, optional
            The l_inf radius of the area to extract around the peak. The default is 1.
        cc_vs: numpy.ndarray, optional
            The vertical coordinates of `cc`. The default is None.
        cc_hs: numpy.ndarray, optional
            The horizontal coordinates of `cc`. The default is None.

        Returns
        -------
        f_vals: numpy.ndarray
            The extracted function values.
        fv: numpy.ndarray
            The vertical coordinates of the extracted values (None if `cc_vs` is None).
        fh: numpy.ndarray
            The horizontal coordinates of the extracted values (None if `cc_hs` is None).
        """
        img_shape = np.array(cc.shape)
        # get pixel having the maximum value of the correlation array
        pix_max_corr = np.argmax(cc)
        pv, ph = np.unravel_index(pix_max_corr, img_shape)

        # select a n x n neighborhood for the sub-pixel fitting (with wrapping,
        # because the correlation image is periodic)
        pv = np.arange(pv - peak_radius, pv + peak_radius + 1) % img_shape[-2]
        ph = np.arange(ph - peak_radius, ph + peak_radius + 1) % img_shape[-1]

        # extract the (v, h) pixel coordinates
        fv = None if cc_vs is None else cc_vs[pv]
        fh = None if cc_hs is None else cc_hs[ph]

        # extract the correlation values
        pv, ph = np.meshgrid(pv, ph, indexing="ij")
        f_vals = cc[pv, ph]

        return (f_vals, fv, fh)

    @staticmethod
    def _determine_roi(img_shape, roi_yxhw, do_truncate_horz_pow2):
        """Returns the 4-element (y, x, height, width) region of interest.

        Parameters
        ----------
        img_shape: sequence of 2 int
            Shape of the (2D) images.
        roi_yxhw: None, or sequence of 2 or 4 int
            None: a default RoI is derived from `img_shape` via `previouspow2`.
            2 elements: (height, width) of a RoI centered in the image.
            4 elements: returned unchanged.
        do_truncate_horz_pow2: bool
            When the default RoI is used, whether to also reduce the horizontal
            size to a power of 2.
        """
        if roi_yxhw is None:
            # vertical window size is reduced to a power of 2 to accelerate fft
            # same thing horizontal window - if requested. Default is not
            roi_yxhw = previouspow2(img_shape)
            if not do_truncate_horz_pow2:
                roi_yxhw[1] = img_shape[1]
        if len(roi_yxhw) == 2:  # Convert centered 2-element roi into 4-element
            roi_yxhw = np.array(roi_yxhw, dtype=int)
            roi_yxhw = np.concatenate(((np.asarray(img_shape) - roi_yxhw) // 2, roi_yxhw))
        return roi_yxhw

    @staticmethod
    def _prepare_image(
        img, invalid_val=1e-5, roi_yxhw=None, median_filt_shape=None, high_pass=None, low_pass=None
    ):
        """Prepare and returns a cropped and filtered image, or array of filtered images if the input is an array of images.

        Parameters
        ----------
        img: numpy.ndarray
            image or stack of images
        invalid_val: float
            value to be used in replacement of nan and inf values
        roi_yxhw: sequence of 4 int, optional
            vertical and horizontal coordinates of the first pixel, plus height
            and width of the region of interest, applied to the last two axes.
        median_filt_shape: int or sequence of int
            the width or the widths of the median window. A single int is
            interpreted as a square window over the last two axes.
        high_pass: float or sequence of two floats
            Position of the cut off in pixels, if a sequence is given the second float expresses the
            width of the transition region which is given as a fraction of the cutoff frequency.
            When only one float is given for this argument a gaussian is applied whose sigma is the
            parameter, and the result is subtracted from 1 to obtain the high pass filter
            When a sequence of two numbers is given then the filter is 1 ( no filtering) above the cutoff
            frequency while a smooth  transition to zero is done for smaller frequency
        low_pass: float or sequence of two floats
            Position of the cut off in pixels, if a sequence is given the second float expresses the
            width of the transition region which is given as a fraction of the cutoff frequency.
            When only one float is given for this argument a gaussian is applied whose sigma is the
            parameter.
            When a sequence of two numbers is given then the filter is 1 ( no filtering) till the cutoff
            frequency and then a smooth erfc transition to zero is done

        Returns
        -------
        numpy.ndarray
            The prepared (cropped, cleaned and filtered) image or stack of
            images. The input array is never modified.
        """
        img = np.squeeze(img)  # Removes singleton dimensions, but does a shallow copy
        img = np.ascontiguousarray(img)

        if roi_yxhw is not None:
            img = img[
                ..., roi_yxhw[0] : roi_yxhw[0] + roi_yxhw[2], roi_yxhw[1] : roi_yxhw[1] + roi_yxhw[3],
            ]

        # The slicing above may return a view: copy before the in-place replacements
        img = img.copy()

        img[np.isnan(img)] = invalid_val
        img[np.isinf(img)] = invalid_val

        if high_pass is not None or low_pass is not None:
            # The filter must be floating point, so that integer images can be
            # multiplied in place by the (floating point) filter profiles.
            img_filter = np.ones(img.shape[-2:], dtype=np.result_type(img.dtype, np.float32))

            if low_pass is not None:
                img_filter[:] *= fourier_filters.get_lowpass_filter(img.shape[-2:], low_pass)

            if high_pass is not None:
                img_filter[:] *= fourier_filters.get_highpass_filter(img.shape[-2:], high_pass)

            # fft2 and ifft2 use axes=(-2, -1) by default
            img = np.fft.ifft2(np.fft.fft2(img) * img_filter).real

        if median_filt_shape is not None:
            if np.isscalar(median_filt_shape):
                median_filt_shape = (median_filt_shape, median_filt_shape)
            img_shape = img.shape
            if __have_scipy__:
                # expanding filter shape with ones, to cover the stack of images
                # but disabling inter-image filtering
                median_filt_shape = np.concatenate(
                    (np.ones((len(img_shape) - len(median_filt_shape),), dtype=int), median_filt_shape)
                )
                img = median_filter(img, size=median_filt_shape)
            else:
                if len(img_shape) == 2:
                    img = median_filter(img, kernel_size=median_filt_shape)
                elif len(img_shape) > 2:
                    # if dealing with a stack of images, we have to do them one by one
                    img = np.reshape(img, (-1,) + tuple(img_shape[-2:]))
                    for ii in range(img.shape[0]):
                        img[ii, ...] = median_filter(img[ii, ...], kernel_size=median_filt_shape)
                    img = np.reshape(img, img_shape)

        return img

    @staticmethod
    def _compute_correlation_fft(img_1, img_2, padding_mode):
        """Computes the cross-correlation of two 2D images, via FFTs.

        Parameters
        ----------
        img_1: numpy.ndarray
            First (2D) image.
        img_2: numpy.ndarray
            Second (2D) image, with the same shape as `img_1`.
        padding_mode: str or None
            None or 'wrap' give a circular correlation. Any other numpy.pad
            mode pads both images by half their size before correlating, and
            crops the result back to the input shape.

        Returns
        -------
        numpy.ndarray
            Real correlation image, with the same shape as the inputs, and the
            zero shift at index (0, 0).
        """
        do_circular_conv = padding_mode is None or padding_mode == "wrap"
        if not do_circular_conv:
            padding = np.ceil(np.array(img_2.shape) / 2).astype(int)
            img_1 = np.pad(img_1, ((padding[0],), (padding[1],)), mode=padding_mode)
            img_2 = np.pad(img_2, ((padding[0],), (padding[1],)), mode=padding_mode)

        # compute fft's of the 2 images
        img_fft_1 = np.fft.fft2(img_1)
        img_fft_2 = np.conjugate(np.fft.fft2(img_2))
        # inverse fft of the product to get cross_correlation of the 2 images
        cc = np.real(np.fft.ifft2(img_fft_1 * img_fft_2))

        if not do_circular_conv:
            # crop the central part (around the zero shift) back to the input shape
            cc = np.fft.fftshift(cc, axes=(-2, -1))
            cc = cc[padding[0] : -padding[0], padding[1] : -padding[1]]
            cc = np.fft.ifftshift(cc, axes=(-2, -1))

        return cc

class CenterOfRotation(AlignmentBase):
    def __init__(self, horz_fft_width=False):
        """
        Center of Rotation (CoR) computation object.
        This class is used on radios.

        Parameters
        ----------
        horz_fft_width: boolean, optional
            If True, restrict the horizontal size to a power of 2:

            >>> new_x_dim = 2 ** math.floor(math.log2(x_dim))
        """
        self._init_parameters(horz_fft_width)

    def _init_parameters(self, horz_fft_width):
        # Whether the default RoI also truncates the horizontal size to a power of 2
        self.truncate_horz_pow2 = horz_fft_width

    @staticmethod
    def _check_img_sizes(img_1: np.ndarray, img_2: np.ndarray):
        """Checks that both images are 2-dimensional (after squeezing) and share the same shape.

        Raises
        ------
        ValueError
            In case images are not 2-dimensional or have different shapes.
        """
        shape_1 = np.squeeze(img_1).shape
        shape_2 = np.squeeze(img_2).shape
        if not len(shape_1) == 2:
            raise ValueError(
                "Images need to be 2-dimensional. Shape of image #1: %s" % (" ".join(("%d" % x for x in shape_1)))
            )
        if not len(shape_2) == 2:
            raise ValueError(
                "Images need to be 2-dimensional. Shape of image #2: %s" % (" ".join(("%d" % x for x in shape_2)))
            )
        if not np.all(shape_1 == shape_2):
            raise ValueError(
                "Images need to be of the same shape. Shape of image #1: %s, image #2: %s"
                % (" ".join(("%d" % x for x in shape_1)), " ".join(("%d" % x for x in shape_2)),)
            )

    def find_shift(
        self,
        img_1: np.ndarray,
        img_2: np.ndarray,
        roi_yxhw=None,
        median_filt_shape=None,
        padding_mode=None,
        peak_fit_radius=1,
        high_pass=None,
        low_pass=None,
    ):
        """Find the Center of Rotation (CoR), given two images.

        This method finds the half-shift between two opposite images, by
        means of correlation computed in Fourier space.

        The output of this function allows to compute motor movements for
        aligning the sample rotation axis. Given the following values:

        - L1: distance from source to motor
        - L2: distance from source to detector
        - ps: physical pixel size
        - v: output of this function

        displacement of motor = (L1 / L2 * ps) * v

        Parameters
        ----------
        img_1: numpy.ndarray
            First image
        img_2: numpy.ndarray
            Second image, it needs to have been flipped already (e.g. using numpy.fliplr).
        roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional
            4 elements vector containing: vertical and horizontal coordinates
            of first pixel, plus height and width of the Region of Interest (RoI).
            Or a 2 elements vector containing the height and width of the
            centered Region of Interest (RoI).
            Default is None -> deactivated.
        median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional
            Shape of the median filter window. Default is None -> deactivated.
        padding_mode: str in numpy.pad's mode list, optional
            Padding mode, which determines the type of convolution. If None or
            'wrap' are passed, this resorts to the traditional circular convolution.
            If 'edge' or 'constant' are passed, it results in a linear convolution.
            Default is the circular convolution.
            All options are:
                None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean'
                | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap'
        peak_fit_radius: int, optional
            Radius size around the max correlation pixel, for sub-pixel fitting.
            Minimum and default value is 1 (smaller values are replaced by 1).
        high_pass: float or sequence of two floats
            Position of the cut off in pixels, if a sequence is given the second float expresses the
            width of the transition region which is given as a fraction of the cutoff frequency.
            When only one float is given for this argument a gaussian is applied whose sigma is the
            parameter, and the result is subtracted from 1 to obtain the high pass filter
            When a sequence of two numbers is given then the filter is 1 ( no filtering) above the cutoff
            frequency while a smooth  transition to zero is done for smaller frequency
        low_pass: float or sequence of two floats
            Position of the cut off in pixels, if a sequence is given the second float expresses the
            width of the transition region which is given as a fraction of the cutoff frequency.
            When only one float is given for this argument a gaussian is applied whose sigma is the
            parameter.
            When a sequence of two numbers is given then the filter is 1 ( no filtering) till the cutoff
            frequency and then a smooth erfc transition to zero is done

        Raises
        ------
        ValueError
            In case images are not 2-dimensional or have different sizes.

        Returns
        -------
        float
            Estimated center of rotation position from the center of the RoI in pixels.

        Examples
        --------
        The following code computes the center of rotation position for two
        given images in a tomography scan, where the second image is taken at
        180 degrees from the first.

        >>> radio1 = data[0, :, :]
        >>> radio2 = np.fliplr(data[1, :, :])
        >>> CoR_calc = CenterOfRotation()
        >>> cor_position = CoR_calc.find_shift(radio1, radio2)

        Or for noisy images:

        >>> cor_position = CoR_calc.find_shift(radio1, radio2, median_filt_shape=(3, 3))
        """
        self._check_img_sizes(img_1, img_2)

        if peak_fit_radius < 1:
            logging.getLogger(__name__).warning(
                "Parameter peak_fit_radius should be at least 1, given: %d instead.", peak_fit_radius
            )
            # A radius below 1 leaves too few samples for the 6-coefficient 2D fit
            peak_fit_radius = 1

        # Use the squeezed (2D) shape, consistent with _check_img_sizes and _prepare_image
        img_shape = np.squeeze(img_2).shape
        roi_yxhw = self._determine_roi(img_shape, roi_yxhw, self.truncate_horz_pow2)

        img_1 = self._prepare_image(
            img_1, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape, high_pass=high_pass, low_pass=low_pass,
        )
        img_2 = self._prepare_image(
            img_2, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape, high_pass=high_pass, low_pass=low_pass,
        )

        cc = self._compute_correlation_fft(img_1, img_2, padding_mode)

        # Shift coordinates of each correlation pixel, in the (cropped) correlation image
        img_shape = img_2.shape
        cc_vs = np.fft.fftfreq(img_shape[-2], 1 / img_shape[-2])
        cc_hs = np.fft.fftfreq(img_shape[-1], 1 / img_shape[-1])

        (f_vals, fv, fh) = self.extract_peak_region(cc, peak_radius=peak_fit_radius, cc_vs=cc_vs, cc_hs=cc_hs)
        fitted_shifts_vh = self.refine_max_position_2d(f_vals, fv, fh)

        # The two images are opposite views: the CoR offset is half the horizontal shift
        return fitted_shifts_vh[-1] / 2.0

    __call__ = find_shift


class DetectorTranslationAlongBeam(AlignmentBase):
    @staticmethod
    def _check_img_sizes(img_stack: np.ndarray, img_pos: np.ndarray):
        """Checks that a 3D stack of images is given, together with one position per image.

        Raises
        ------
        ValueError
            In case the stack is not 3-dimensional (after squeezing), the
            positions are not 1-dimensional, or their counts differ.
        """
        shape_stack = np.squeeze(img_stack).shape
        shape_pos = np.squeeze(img_pos).shape
        if not len(shape_stack) == 3:
            raise ValueError(
                "A stack of 2-dimensional images is required. Shape of stack: %s" % (" ".join(("%d" % x for x in shape_stack)))
            )
        if not len(shape_pos) == 1:
            raise ValueError(
                "Positions need to be a 1-dimensional array. Shape of the positions variable: %s"
                % (" ".join(("%d" % x for x in shape_pos)))
            )
        if not shape_stack[0] == shape_pos[0]:
            raise ValueError(
                "The same number of images and positions is required."
                + " Shape of stack: %s, shape of positions variable: %s"
                % (" ".join(("%d" % x for x in shape_stack)), " ".join(("%d" % x for x in shape_pos)),)
            )

    def find_shift(
        self,
        img_stack: np.ndarray,
        img_pos: np.array,
        roi_yxhw=None,
        median_filt_shape=None,
        padding_mode=None,
        peak_fit_radius=1,
        equispaced_increments=True,
    ):
        """Find the deviation of the translation axis of the area detector
        along the beam propagation direction.

        TODO: Add more information here! Including interpretation of the result
        This means giving also an example on how to convert the returned values
        into meaningful quantities. See "Returns" for more details.

        Parameters
        ----------
        img_stack: numpy.ndarray
            A stack of images (usually 4) at different distances
        img_pos: numpy.ndarray
            Position of the images along the translation axis
        roi_yxhw: (2, ) or (4, ) numpy.ndarray, tuple, or array, optional
            4 elements vector containing: vertical and horizontal coordinates
            of first pixel, plus height and width of the Region of Interest (RoI).
            Or a 2 elements vector containing the height and width of the
            centered Region of Interest (RoI).
            Default is None -> deactivated.
        median_filt_shape: (2, ) numpy.ndarray, tuple, or array, optional
            Shape of the median filter window. Default is None -> deactivated.
        padding_mode: str in numpy.pad's mode list, optional
            Padding mode, which determines the type of convolution. If None or
            'wrap' are passed, this resorts to the traditional circular convolution.
            If 'edge' or 'constant' are passed, it results in a linear convolution.
            Default is the circular convolution.
            All options are:
                None | 'constant' | 'edge' | 'linear_ramp' | 'maximum' | 'mean'
                | 'median' | 'minimum' | 'reflect' | 'symmetric' |'wrap'
        peak_fit_radius: int, optional
            Radius size around the max correlation pixel, for sub-pixel fitting.
            Minimum and default value is 1 (smaller values are replaced by 1).
        equispaced_increments: boolean, optional
            Tells whether the position increments are equispaced or not. If
            equispaced increments are used, we have to compute the correlation
            images with respect to the first image, otherwise we can do it
            against adjacent images.
            The advantage of doing it between adjacent images is that we do not
            build up large displacements in the correlation.
            However, if this is done for equispaced images, the linear fitting
            becomes unstable.
            Default is True.

        Returns
        -------
        tuple(float, float)
            Estimated (vertical, horizontal) increment per unit-distance of the
            ratio between pixel-size and detector translation.

        Examples
        --------
        TODO: Add examples here!
        """
        self._check_img_sizes(img_stack, img_pos)

        if peak_fit_radius < 1:
            logging.getLogger(__name__).warning(
                "Parameter peak_fit_radius should be at least 1, given: %d instead.", peak_fit_radius
            )
            # A radius below 1 leaves too few samples for the 6-coefficient 2D fit
            peak_fit_radius = 1

        img_pos = np.squeeze(img_pos)

        img_shape = np.squeeze(img_stack).shape[-2:]
        roi_yxhw = self._determine_roi(img_shape, roi_yxhw, False)

        img_stack = self._prepare_image(img_stack, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape)
        num_imgs = img_stack.shape[0]

        # The correlation images have the shape of the prepared (RoI-cropped)
        # images: the shift coordinates must be derived from that shape.
        img_shape = img_stack.shape[-2:]
        cc_vs = np.fft.fftfreq(img_shape[-2], 1 / img_shape[-2])
        cc_hs = np.fft.fftfreq(img_shape[-1], 1 / img_shape[-1])

        shifts_vh = np.empty((num_imgs - 1, 2))
        for ii in range(1, num_imgs):
            ref_ind = 0 if equispaced_increments else ii - 1
            cc = self._compute_correlation_fft(img_stack[ref_ind, ...], img_stack[ii, ...], padding_mode)
            (f_vals, fv, fh) = self.extract_peak_region(cc, peak_radius=peak_fit_radius, cc_vs=cc_vs, cc_hs=cc_hs)
            shifts_vh[ii - 1, :] = self.refine_max_position_2d(f_vals, fv, fh)

        # In both modes the increment is (target position - reference position),
        # matching the order used in the correlations above.
        if equispaced_increments:
            img_pos_increments = img_pos[1:] - img_pos[0]
        else:
            img_pos_increments = np.diff(img_pos)

        def _fit_slope(shifts):
            # Polynomial.fit is supposed to be more numerically stable than polyfit
            # (according to numpy). convert() drops trailing coefficients that are
            # exactly zero, so a perfectly flat fit returns a single coefficient.
            coeffs = Polynomial.fit(img_pos_increments, shifts, deg=1).convert().coef
            return coeffs[1] if coeffs.size > 1 else 0.0

        return _fit_slope(shifts_vh[:, 0]), _fit_slope(shifts_vh[:, 1])