Commit 747d9e61 authored by myron's avatar myron Committed by Pierre Paleo
Browse files

Test with global search for half tomo passes

parent ad8562e5
import numpy as np
import math
import logging
from numpy.polynomial.polynomial import Polynomial, polyval
......@@ -476,8 +477,54 @@ class AlignmentBase(object):
return cc
class CenterOfRotation(AlignmentBase):
def global_search_master_(
self,
img_1: np.ndarray,
img_2: np.ndarray,
roi_yxhw=None,
median_filt_shape=None,
padding_mode=None,
peak_fit_radius=1,
high_pass=None,
low_pass=None,
half_tomo_cor_guess=None,
half_tomo_return_stats = False,
limits = None
):
dim_radio = img_1.shape[1]
if limits is None:
lim1, lim2 = 10.0, dim_radio-10
else:
lim1, lim2 = limits
found_centers = []
Xcor = lim1
while( Xcor < lim2 ):
cor_position, merit , energy = self.find_shift(img_1,
img_2,
low_pass=low_pass,
high_pass=high_pass,
half_tomo_cor_guess= (Xcor - (img_1.shape[1]//2) ) ,
half_tomo_return_stats = True )
if not np.isnan(merit):
found_centers.append( [ merit, cor_position, energy ] )
Xcor = min(
Xcor + Xcor / 6.0 ,
Xcor + ( dim_radio - Xcor ) / 6.0
)
found_centers.sort()
cor_position = found_centers[0][1]
return cor_position
def find_shift(
self,
img_1: np.ndarray,
......@@ -489,6 +536,8 @@ class CenterOfRotation(AlignmentBase):
high_pass=None,
low_pass=None,
half_tomo_cor_guess=None,
half_tomo_return_stats = False,
global_search = False
):
"""Find the Center of Rotation (CoR), given two images.
......@@ -535,7 +584,7 @@ class CenterOfRotation(AlignmentBase):
high_pass: float or sequence of two floats
High-pass filter properties, as described in `nabu.misc.fourier_filters`
half_tomo_cor_guess: float or None
The approximate position of the rotation axis from the iage center. Optional.
The approximate position of the rotation axis from the image center. Optional.
When given a special algorithm is used which can work also in half-tomo conditions
Raises
......@@ -563,6 +612,27 @@ class CenterOfRotation(AlignmentBase):
>>> cor_position = CoR_calc.find_shift(radio1, radio2, median_filt_shape=(3, 3))
"""
if global_search:
if global_search is True:
limits = None
else:
limits = global_search
return self.global_search_master_( img_1,
img_2,
roi_yxhw,
median_filt_shape,
padding_mode,
peak_fit_radius,
high_pass,
low_pass,
limits
)
self._check_img_pair_sizes(img_1, img_2)
if peak_fit_radius < 1:
......@@ -579,10 +649,17 @@ class CenterOfRotation(AlignmentBase):
if half_tomo_cor_guess is not None:
cor_in_img = img_1.shape[1] // 2 + half_tomo_cor_guess
tmpsigma = (img_1.shape[1] - cor_in_img) / 4.0
tmpsigma = min(
(img_1.shape[1] - cor_in_img) / 4.0,
(cor_in_img) / 4.0,
)
tmpx = (np.arange(img_1.shape[1]) - cor_in_img) / tmpsigma
apodis = np.exp(-tmpx * tmpx / 2.0)
if half_tomo_return_stats:
img_1_orig = np.array(img_1)
img_1[:] = img_1 * apodis
cc = self._compute_correlation_fft(img_1, img_2, padding_mode, high_pass=high_pass, low_pass=low_pass)
img_shape = img_2.shape
......@@ -592,6 +669,8 @@ class CenterOfRotation(AlignmentBase):
(f_vals, fv, fh) = self.extract_peak_region_2d(cc, peak_radius=peak_fit_radius, cc_vs=cc_vs, cc_hs=cc_hs)
fitted_shifts_vh = self.refine_max_position_2d(f_vals, fv, fh)
return_value = fitted_shifts_vh[-1] / 2.0
if half_tomo_cor_guess is not None:
# fftfreqs and fftshifts are introducing jumps in the results
# we correct for that here
......@@ -602,11 +681,36 @@ class CenterOfRotation(AlignmentBase):
else:
p2 = -cc.shape[1] + tmp
if abs(half_tomo_cor_guess - p1 / 2) < abs(half_tomo_cor_guess - p2 / 2):
return p1 / 2
return_value = p1 / 2
else:
return p2 / 2
return_value = p2 / 2
if half_tomo_return_stats:
cor_in_img = img_1.shape[1] // 2 + return_value
tmpsigma = min (
(img_1.shape[1] - cor_in_img) / 4.0 ,
( cor_in_img ) / 4.0 ,
)
M1 = int(round(half_tomo_cor_guess + img_1.shape[1] // 2 )) - int(round(tmpsigma))
M2 = int(round(half_tomo_cor_guess + img_1.shape[1] // 2 )) + int(round(tmpsigma))
piece1 = np.log( np.maximum(1.0e-3, img_1_orig[ : ,
M1:M2
] ) )
piece2 = np.log( np.maximum(1.0e-3, img_2 [ : ,
img_1.shape[1] - M2 :
img_1.shape[1] - M1 ] )
)
return fitted_shifts_vh[-1] / 2.0
energy = np.array( piece1*piece1 + piece2*piece2 ,"d" ).sum()
diff_energy = np.array( (piece1-piece2) * (piece1-piece2) ,"d" ).sum()
return return_value , diff_energy/energy, energy
return return_value
__call__ = find_shift
......
......@@ -4,6 +4,7 @@ import os
import h5py
from silx.resources import ExternalResources
from nabu.testutils import get_data as nabu_get_data
try:
import scipy.ndimage
......@@ -91,6 +92,7 @@ def get_cor_data_half_tomo():
im1 = hf["im1"][()]
im2 = hf["im2"][()]
cor = hf["cor"][()]
return im1, im2, cor
......@@ -348,6 +350,40 @@ class TestCor(object):
)
assert np.isclose(cor_pos, cor_position, atol=self.abs_tol), message
def test_half_tomo_cor_exp(self):
""" test the hal_tomo algorithm on experimental data and global search
"""
radios = nabu_get_data("ha_autocor_radios.npz")
radio1 = radios["radio1"]
radio2 = radios["radio2"]
radio2 = np.fliplr( radio2 )
cor_pos = 983.038
CoR_calc = alignment.CenterOfRotation()
cor_position = CoR_calc.find_shift(radio1,
radio2,
low_pass=1,
high_pass=20,
global_search = True)
print( cor_position)
message = (
"Computed CoR %f " % cor_position
+ " and real CoR %f should coincide when using the halftomo algorithm with hald tomo data" % cor_pos
)
assert np.isclose(cor_pos, cor_position, atol=0.1), message
def test_cor_posx_linear(self):
radio1 = self.data[0, :, :]
radio2 = np.fliplr(self.data[1, :, :])
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment