Commit ba35ef1a authored by Pierre Paleo's avatar Pierre Paleo
Browse files

CTFPhaseRetrieval: add support for R2C FFT

parent 8454cab5
Pipeline #51810 passed with stage
in 7 minutes and 12 seconds
......@@ -85,7 +85,7 @@ class GeoPars:
class CTFPhaseRetrieval:
def __init__(self, geo_pars, delta_beta, lim1=1.0e-5, lim2=0.2, logger=None):
def __init__(self, geo_pars, delta_beta, lim1=1.0e-5, lim2=0.2, use_rfft=False, logger=None):
"""
This class implements the CTF formula (see equation 8 in Boliang Yu et al
Optic Express Vol 26, No 9, 2018) in its regularised form which avoids the zeros
......@@ -102,13 +102,15 @@ class CTFPhaseRetrieval:
the regulariser strenght at low frequencies
lim2: float >0
the regulariser strenght at high frequencies
use_rfft: bool, optional
Whether to use real-to-complex (R2C) FFT instead of usual complex-to-complex (C2C).
logger: optional
a logger object
"""
self.logger = LoggerOrPrint(logger)
self.setup(geo_pars, delta_beta, lim1=lim1, lim2=lim2)
self.setup(geo_pars, delta_beta, lim1=lim1, lim2=lim2, use_rfft=use_rfft)
def setup(self, geo_pars, delta_beta, lim1=None, lim2=None):
def setup(self, geo_pars, delta_beta, lim1=None, lim2=None, use_rfft=False):
"""
Set up the current geometrical parameters.
......@@ -135,6 +137,9 @@ class CTFPhaseRetrieval:
self.geo_pars = geo_pars
self.filter_is_ready = False
self.filter_shape = None
self.use_rfft = use_rfft
self._fft_func = np.fft.rfft2 if use_rfft else np.fft.fft2
self._ifft_func = np.fft.irfft2 if use_rfft else np.fft.ifft2
def require_filter_for_shape(self, padded_img_shape):
......@@ -160,7 +165,10 @@ class CTFPhaseRetrieval:
[self.geo_pars.length_scale / self.geo_pars.pix_size_rec, self.geo_pars.length_scale / self.geo_pars.pix_size_rec]
)
ff_index_vh = list(map(np.fft.fftfreq, padded_img_shape))
if not self.use_rfft:
ff_index_vh = list(map(np.fft.fftfreq, padded_img_shape))
else:
ff_index_vh = [np.fft.fftfreq(padded_img_shape[0]), np.fft.rfftfreq(padded_img_shape[1])]
# if padded_img_shape[1]%2 == 0 : # change to holotomo_slave indexing (by a transparent 2pi shift)
# ff_index_x[ ff_index_x == -0.5 ] = +0.5
......@@ -231,6 +239,7 @@ class CTFPhaseRetrieval:
0.5 / (self.cut_v + 1.0 / padded_img_shape[0]),
0.01 * padded_img_shape[0] / (1 + self.cut_v * padded_img_shape[0]),
),
use_rfft=self.use_rfft
)
self.r /= self.r[0, 0]
......@@ -260,17 +269,17 @@ class CTFPhaseRetrieval:
def _apply_filter(self, img):
self.require_filter_for_shape(img.shape)
img_f = np.fft.fft2(img)
img_f = self._fft_func(img)
firec0 = self.unreg_filter_denom * img_f
unreg_filter_denom_0_mean = self.unreg_filter_denom[0, 0]
nf, mf = firec0.shape
nf, mf = img.shape
# here it is assumed that the average of img is 1 and the DC component is removed
firec0[0, 0] -= nf * mf * unreg_filter_denom_0_mean
## formula 8, with regularisation to stay at a safe distance from the poles
ph = firec0 / (2 * self.unreg_filter_denom * self.unreg_filter_denom + self.lim)
ph = np.fft.ifft2(ph).real
ph = self._ifft_func(ph).real
return ph
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment