Commit dfda1469 authored by Nicola Vigano's avatar Nicola Vigano
Browse files

Alignment: allow specifying Fourier filter type and shape



This allows to automatically create filters that are compatible with rfft
Signed-off-by: Nicola Vigano's avatarNicola VIGANÒ <nicola.vigano@esrf.fr>
parent c8ddff07
......@@ -15,7 +15,7 @@ except ImportError:
__have_scipy__ = False
def get_lowpass_filter(img_shape, cutoff_par=None):
def get_lowpass_filter(img_shape, cutoff_par=None, use_rfft=False, data_type=np.float64):
"""Computes a low pass filter using the erfc function.
Parameters
......@@ -29,6 +29,10 @@ def get_lowpass_filter(img_shape, cutoff_par=None):
parameter.
When a sequence of two numbers is given then the filter is 1 ( no filtering) till the cutoff
frequency while a smooth erfc transition to zero is done
use_rfft: boolean, optional
Creates a filter to be used with the result of a rfft type of Fourier transform. Defaults to False.
data_type: `numpy.dtype`, optional
Specifies the data type of the computed filter. It defaults to `numpy.float64`
Raises
------
......@@ -62,7 +66,7 @@ def get_lowpass_filter(img_shape, cutoff_par=None):
coords = [np.fft.fftfreq(s, 1) for s in img_shape]
coords = np.meshgrid(*coords, indexing="ij")
r = np.sqrt(np.sum(np.array(coords) ** 2, axis=0))
r = np.sqrt(np.sum(np.array(coords, dtype=data_type) ** 2, axis=0))
if cutoff_trans_fact is not None:
k_cut = 0.5 / cutoff_pix
......@@ -75,10 +79,17 @@ def get_lowpass_filter(img_shape, cutoff_par=None):
else:
res = np.exp(-(np.pi ** 2) * (r ** 2) * (cutoff_pix ** 2) * 2)
return res
# Making sure to force result to chosen data type
res = res.astype(data_type)
if use_rfft:
slicelist = tuple(slice(0, (N + 1) // 2) for N in res.shape)
return res[slicelist]
else:
return res
def get_highpass_filter(img_shape, cutoff_par=None):
def get_highpass_filter(img_shape, cutoff_par=None, use_rfft=False, data_type=np.float64):
"""Computes a high pass filter using the erfc function.
Parameters
......@@ -92,6 +103,10 @@ def get_highpass_filter(img_shape, cutoff_par=None):
parameter, and the result is subtracted from 1 to obtain the high pass filter
When a sequence of two numbers is given then the filter is 1 ( no filtering) above the cutoff
frequency and then a smooth transition to zero is done for smaller frequency
use_rfft: boolean, optional
Creates a filter to be used with the result of a rfft type of Fourier transform. Defaults to False.
data_type: `numpy.dtype`, optional
Specifies the data type of the computed filter. It defaults to `numpy.float64`
Raises
------
......@@ -106,10 +121,10 @@ def get_highpass_filter(img_shape, cutoff_par=None):
if cutoff_par is None:
return 1
else:
return 1 - get_lowpass_filter(img_shape, cutoff_par)
return 1 - get_lowpass_filter(img_shape, cutoff_par, use_rfft=use_rfft, data_type=data_type)
def get_bandpass_filter(img_shape, cutoff_lowpass=None, cutoff_highpass=None):
def get_bandpass_filter(img_shape, cutoff_lowpass=None, cutoff_highpass=None, use_rfft=False, data_type=np.float64):
"""Computes a band pass filter using the erfc function.
The cutoff structures should be formed as follows:
......@@ -127,6 +142,10 @@ def get_bandpass_filter(img_shape, cutoff_lowpass=None, cutoff_highpass=None):
Cutoff parameters for the low-pass filter
cutoff_highpass: float or sequence of two floats
Cutoff parameters for the high-pass filter
use_rfft: boolean, optional
Creates a filter to be used with the result of a rfft type of Fourier transform. Defaults to False.
data_type: `numpy.dtype`, optional
Specifies the data type of the computed filter. It defaults to `numpy.float64`
Raises
------
......@@ -138,6 +157,6 @@ def get_bandpass_filter(img_shape, cutoff_lowpass=None, cutoff_highpass=None):
numpy.array_like
The computed filter
"""
return get_lowpass_filter(img_shape, cutoff_par=cutoff_lowpass) * get_highpass_filter(
img_shape, cutoff_par=cutoff_highpass
)
return get_lowpass_filter(
img_shape, cutoff_par=cutoff_lowpass, use_rfft=use_rfft, data_type=data_type
) * get_highpass_filter(img_shape, cutoff_par=cutoff_highpass, use_rfft=use_rfft, data_type=data_type)
import numpy as np
try:
import scipy.fft
my_fftn = scipy.fft.rfftn
my_ifftn = scipy.fft.irfftn
my_fft2 = scipy.fft.rfft2
my_ifft2 = scipy.fft.irfft2
def my_fft_layout_adapt(x):
slicelist = tuple(slice(0, (N + 1) // 2) for N in x.shape)
return x[slicelist]
except ImportError:
my_fftn = np.fft.fftn
my_ifftn = np.fft.ifftn
my_fft2 = np.fft.fft2
my_ifft2 = np.fft.ifft2
def my_fft_layout_adapt(x):
return x
import logging
from numpy.polynomial.polynomial import Polynomial, polyval
......@@ -32,10 +9,18 @@ from nabu.misc import fourier_filters
try:
from scipy.ndimage.filters import median_filter
import scipy.fft
local_fftn = scipy.fft.rfftn
local_ifftn = scipy.fft.irfftn
__have_scipy__ = True
except ImportError:
from silx.math.medianfilter import medfilt2d as median_filter
local_fftn = np.fft.fftn
local_ifftn = np.fft.ifftn
__have_scipy__ = False
try:
......@@ -61,6 +46,8 @@ class AlignmentBase(object):
>>> new_x_dim = 2 ** math.floor(math.log2(x_dim))
verbose: boolean, optional
When True it will produce verbose output, including plots.
data_type: `numpy.float32`
Computation data type.
"""
self._init_parameters(horz_fft_width, verbose, data_type)
......@@ -318,10 +305,7 @@ class AlignmentBase(object):
roi_yxhw = np.concatenate(((img_shape - roi_yxhw) // 2, roi_yxhw))
return roi_yxhw
@staticmethod
def _prepare_image(
img, invalid_val=1e-5, roi_yxhw=None, median_filt_shape=None, low_pass=None, high_pass=None, data_type=None
):
def _prepare_image(self, img, invalid_val=1e-5, roi_yxhw=None, median_filt_shape=None, low_pass=None, high_pass=None):
"""
Prepare and returns a cropped and filtered image, or array of filtered images if the input is an array of images.
......@@ -344,7 +328,7 @@ class AlignmentBase(object):
The computed filter
"""
img = np.squeeze(img) # Removes singleton dimensions, but does a shallow copy
img = np.ascontiguousarray(img, dtype=data_type)
img = np.ascontiguousarray(img, dtype=self.data_type)
if roi_yxhw is not None:
img = img[
......@@ -358,10 +342,14 @@ class AlignmentBase(object):
if high_pass is not None or low_pass is not None:
img_filter = fourier_filters.get_bandpass_filter(
img.shape[-2:], cutoff_lowpass=low_pass, cutoff_highpass=high_pass
img.shape[-2:],
cutoff_lowpass=low_pass,
cutoff_highpass=high_pass,
use_rfft=__have_scipy__,
data_type=self.data_type,
)
# fft2 and iff2 use axes=(-2, -1) by default
img = my_ifft2(my_fft2(img) * my_fft_layout_adapt(img_filter.astype(self.data_type))).real
img = local_ifftn(local_fftn(img, axes=(-2, -1)) * img_filter, axes=(-2, -1)).real
if median_filt_shape is not None:
img_shape = img.shape
......@@ -384,8 +372,7 @@ class AlignmentBase(object):
return img
@staticmethod
def _compute_correlation_fft(img_1, img_2, padding_mode, axes=(-2, -1), low_pass=None, high_pass=None):
def _compute_correlation_fft(self, img_1, img_2, padding_mode, axes=(-2, -1), low_pass=None, high_pass=None):
do_circular_conv = padding_mode is None or padding_mode == "wrap"
if not do_circular_conv:
img_shape = img_2.shape
......@@ -398,17 +385,23 @@ class AlignmentBase(object):
img_2 = np.pad(img_2, pad_array, mode=padding_mode)
# compute fft's of the 2 images
img_fft_1 = my_fftn(img_1, axes=axes)
img_fft_2 = np.conjugate(my_fftn(img_2, axes=axes))
img_fft_1 = local_fftn(img_1, axes=axes)
img_fft_2 = np.conjugate(local_fftn(img_2, axes=axes))
img_prod = img_fft_1 * img_fft_2
if low_pass is not None or high_pass is not None:
filt = fourier_filters.get_bandpass_filter(img_prod.shape[-2:], cutoff_lowpass=low_pass, cutoff_highpass=high_pass)
filt = fourier_filters.get_bandpass_filter(
img_prod.shape[-2:],
cutoff_lowpass=low_pass,
cutoff_highpass=high_pass,
use_rfft=__have_scipy__,
data_type=self.data_type,
)
img_prod *= filt
# inverse fft of the product to get cross_correlation of the 2 images
cc = np.real(my_ifftn(img_prod, axes=axes))
cc = np.real(local_ifftn(img_prod, axes=axes))
if not do_circular_conv:
cc = np.fft.fftshift(cc, axes=axes)
......@@ -534,8 +527,8 @@ class CenterOfRotation(AlignmentBase):
img_shape = img_2.shape
roi_yxhw = self._determine_roi(img_shape, roi_yxhw)
img_1 = self._prepare_image(img_1, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape, data_type=self.data_type)
img_2 = self._prepare_image(img_2, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape, data_type=self.data_type)
img_1 = self._prepare_image(img_1, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape)
img_2 = self._prepare_image(img_2, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape)
cc = self._compute_correlation_fft(img_1, img_2, padding_mode, high_pass=high_pass, low_pass=low_pass)
......@@ -665,9 +658,6 @@ class DetectorTranslationAlongBeam(AlignmentBase):
... )
... print(shifts_v, shifts_h)
>>> ( -2.47 , -1.236 )
"""
self._check_img_sizes(img_stack, img_pos)
......@@ -681,9 +671,7 @@ class DetectorTranslationAlongBeam(AlignmentBase):
img_shape = img_stack.shape[-2:]
roi_yxhw = self._determine_roi(img_shape, roi_yxhw)
img_stack = self._prepare_image(
img_stack, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape, data_type=self.data_type
)
img_stack = self._prepare_image(img_stack, roi_yxhw=roi_yxhw, median_filt_shape=median_filt_shape)
# do correlations
ccs = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment