Commit a1ae1f9d authored by myron's avatar myron Committed by Pierre Paleo
Browse files

black

parent 747d9e61
......@@ -477,54 +477,50 @@ class AlignmentBase(object):
return cc
class CenterOfRotation(AlignmentBase):
class CenterOfRotation(AlignmentBase):
def global_search_master_(
self,
img_1: np.ndarray,
img_2: np.ndarray,
roi_yxhw=None,
median_filt_shape=None,
padding_mode=None,
peak_fit_radius=1,
high_pass=None,
low_pass=None,
half_tomo_cor_guess=None,
half_tomo_return_stats=False,
limits=None,
):
dim_radio = img_1.shape[1]
def global_search_master_(
self,
img_1: np.ndarray,
img_2: np.ndarray,
roi_yxhw=None,
median_filt_shape=None,
padding_mode=None,
peak_fit_radius=1,
high_pass=None,
low_pass=None,
half_tomo_cor_guess=None,
half_tomo_return_stats = False,
limits = None
):
dim_radio = img_1.shape[1]
if limits is None:
lim1, lim2 = 10.0, dim_radio - 10
else:
lim1, lim2 = limits
found_centers = []
Xcor = lim1
while Xcor < lim2:
cor_position, merit, energy = self.find_shift(
img_1,
img_2,
low_pass=low_pass,
high_pass=high_pass,
half_tomo_cor_guess=(Xcor - (img_1.shape[1] // 2)),
half_tomo_return_stats=True,
)
if not np.isnan(merit):
found_centers.append([merit, cor_position, energy])
if limits is None:
lim1, lim2 = 10.0, dim_radio-10
else:
lim1, lim2 = limits
found_centers = []
Xcor = lim1
while( Xcor < lim2 ):
cor_position, merit , energy = self.find_shift(img_1,
img_2,
low_pass=low_pass,
high_pass=high_pass,
half_tomo_cor_guess= (Xcor - (img_1.shape[1]//2) ) ,
half_tomo_return_stats = True )
if not np.isnan(merit):
found_centers.append( [ merit, cor_position, energy ] )
Xcor = min(
Xcor + Xcor / 6.0 ,
Xcor + ( dim_radio - Xcor ) / 6.0
)
Xcor = min(Xcor + Xcor / 6.0, Xcor + (dim_radio - Xcor) / 6.0)
found_centers.sort()
cor_position = found_centers[0][1]
return cor_position
found_centers.sort()
cor_position = found_centers[0][1]
return cor_position
def find_shift(
self,
img_1: np.ndarray,
......@@ -536,8 +532,8 @@ class CenterOfRotation(AlignmentBase):
high_pass=None,
low_pass=None,
half_tomo_cor_guess=None,
half_tomo_return_stats = False,
global_search = False
half_tomo_return_stats=False,
global_search=False,
):
"""Find the Center of Rotation (CoR), given two images.
......@@ -613,26 +609,16 @@ class CenterOfRotation(AlignmentBase):
>>> cor_position = CoR_calc.find_shift(radio1, radio2, median_filt_shape=(3, 3))
"""
if global_search:
if global_search is True:
limits = None
else:
limits = global_search
return self.global_search_master_( img_1,
img_2,
roi_yxhw,
median_filt_shape,
padding_mode,
peak_fit_radius,
high_pass,
low_pass,
limits
)
return self.global_search_master_(
img_1, img_2, roi_yxhw, median_filt_shape, padding_mode, peak_fit_radius, high_pass, low_pass, limits
)
self._check_img_pair_sizes(img_1, img_2)
if peak_fit_radius < 1:
......@@ -649,17 +635,13 @@ class CenterOfRotation(AlignmentBase):
if half_tomo_cor_guess is not None:
cor_in_img = img_1.shape[1] // 2 + half_tomo_cor_guess
tmpsigma = min(
(img_1.shape[1] - cor_in_img) / 4.0,
(cor_in_img) / 4.0,
)
tmpsigma = min((img_1.shape[1] - cor_in_img) / 4.0, (cor_in_img) / 4.0,)
tmpx = (np.arange(img_1.shape[1]) - cor_in_img) / tmpsigma
apodis = np.exp(-tmpx * tmpx / 2.0)
if half_tomo_return_stats:
img_1_orig = np.array(img_1)
img_1_orig = np.array(img_1)
img_1[:] = img_1 * apodis
cc = self._compute_correlation_fft(img_1, img_2, padding_mode, high_pass=high_pass, low_pass=low_pass)
img_shape = img_2.shape
......@@ -670,7 +652,7 @@ class CenterOfRotation(AlignmentBase):
fitted_shifts_vh = self.refine_max_position_2d(f_vals, fv, fh)
return_value = fitted_shifts_vh[-1] / 2.0
if half_tomo_cor_guess is not None:
# fftfreqs and fftshifts are introducing jumps in the results
# we correct for that here
......@@ -687,29 +669,18 @@ class CenterOfRotation(AlignmentBase):
if half_tomo_return_stats:
cor_in_img = img_1.shape[1] // 2 + return_value
tmpsigma = min (
(img_1.shape[1] - cor_in_img) / 4.0 ,
( cor_in_img ) / 4.0 ,
)
tmpsigma = min((img_1.shape[1] - cor_in_img) / 4.0, (cor_in_img) / 4.0,)
M1 = int(round(half_tomo_cor_guess + img_1.shape[1] // 2 )) - int(round(tmpsigma))
M2 = int(round(half_tomo_cor_guess + img_1.shape[1] // 2 )) + int(round(tmpsigma))
M1 = int(round(half_tomo_cor_guess + img_1.shape[1] // 2)) - int(round(tmpsigma))
M2 = int(round(half_tomo_cor_guess + img_1.shape[1] // 2)) + int(round(tmpsigma))
piece1 = np.log( np.maximum(1.0e-3, img_1_orig[ : ,
M1:M2
] ) )
piece2 = np.log( np.maximum(1.0e-3, img_2 [ : ,
img_1.shape[1] - M2 :
img_1.shape[1] - M1 ] )
)
piece1 = np.log(np.maximum(1.0e-3, img_1_orig[:, M1:M2]))
piece2 = np.log(np.maximum(1.0e-3, img_2[:, img_1.shape[1] - M2 : img_1.shape[1] - M1]))
energy = np.array(piece1 * piece1 + piece2 * piece2, "d").sum()
diff_energy = np.array((piece1 - piece2) * (piece1 - piece2), "d").sum()
return return_value, diff_energy / energy, energy
energy = np.array( piece1*piece1 + piece2*piece2 ,"d" ).sum()
diff_energy = np.array( (piece1-piece2) * (piece1-piece2) ,"d" ).sum()
return return_value , diff_energy/energy, energy
return return_value
__call__ = find_shift
......@@ -886,10 +857,10 @@ class DetectorTranslationAlongBeam(AlignmentBase):
print("Fitted pixel shifts per unit-length: vertical = %f, horizontal = %f" % (coeffs_v[1], coeffs_h[1]))
f, axs = plt.subplots(1, 2)
axs[0].scatter(img_pos, shifts_vh[:, 0])
axs[0].plot(img_pos, polyval(img_pos, coeffs_v), '-C1')
axs[0].plot(img_pos, polyval(img_pos, coeffs_v), "-C1")
axs[0].set_title("Vertical shifts")
axs[1].scatter(img_pos, shifts_vh[:, 1])
axs[1].plot(img_pos, polyval(img_pos, coeffs_h), '-C1')
axs[1].plot(img_pos, polyval(img_pos, coeffs_h), "-C1")
axs[1].set_title("Horizontal shifts")
plt.show(block=False)
......@@ -1019,7 +990,7 @@ class CameraTilt(CenterOfRotation):
)
f, ax = plt.subplots(1, 1)
ax.scatter(cc_v_coords, fitted_shifts_h)
ax.plot(cc_v_coords, polyval(cc_v_coords, coeffs_h), '-C1')
ax.plot(cc_v_coords, polyval(cc_v_coords, coeffs_h), "-C1")
ax.set_title("Correlation peaks")
plt.show(block=False)
......@@ -1130,7 +1101,7 @@ class CameraFocus(CenterOfRotation):
print("Fitted focus motor position:", focus_pos, "and corresponding image position:", focus_ind)
f, ax = plt.subplots(1, 1)
ax.stem(img_pos, img_stds)
ax.stem(focus_pos, img_std_max, linefmt='C1-', markerfmt='C1o')
ax.stem(focus_pos, img_std_max, linefmt="C1-", markerfmt="C1o")
ax.set_title("Images std")
plt.show(block=False)
......@@ -1312,9 +1283,7 @@ class CameraFocus(CenterOfRotation):
img_shape = img_stack.shape[-2:]
block_size = np.array(img_shape) / regions_number
block_stack_size = np.array([
num_imgs, regions_number, block_size[-2], regions_number, block_size[-1]
], dtype=np.int)
block_stack_size = np.array([num_imgs, regions_number, block_size[-2], regions_number, block_size[-1]], dtype=np.int)
img_stack = np.reshape(img_stack, block_stack_size)
img_stds = np.std(img_stack, axis=(-3, -1)) / np.mean(img_stack, axis=(-3, -1))
......@@ -1344,13 +1313,13 @@ class CameraFocus(CenterOfRotation):
print("Fitted focus motor position:", focus_pos, "and corresponding image position:", focus_ind)
print("Fitted tilts (to be divided by pixel size, and converted to deg): (v, h) %s" % tilts_vh)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax = fig.add_subplot(111, projection="3d")
ax.plot_wireframe(fx, fy, focus_poss)
regions_half_shape = (regions_number - 1) / 2
base_points = np.linspace(-regions_half_shape, regions_half_shape, regions_number)
ax.plot(np.zeros((regions_number, )), base_points, np.polyval([tg_v, focus_pos], base_points), 'C2')
ax.plot(base_points, np.zeros((regions_number, )), np.polyval([tg_h, focus_pos], base_points), 'C2')
ax.scatter(0, 0, focus_pos, marker='o', c='C1')
ax.plot(np.zeros((regions_number,)), base_points, np.polyval([tg_v, focus_pos], base_points), "C2")
ax.plot(base_points, np.zeros((regions_number,)), np.polyval([tg_h, focus_pos], base_points), "C2")
ax.scatter(0, 0, focus_pos, marker="o", c="C1")
ax.set_title("Images std")
plt.show(block=False)
......
......@@ -92,7 +92,7 @@ def get_cor_data_half_tomo():
im1 = hf["im1"][()]
im2 = hf["im2"][()]
cor = hf["cor"][()]
return im1, im2, cor
......@@ -350,8 +350,6 @@ class TestCor(object):
)
assert np.isclose(cor_pos, cor_position, atol=self.abs_tol), message
def test_half_tomo_cor_exp(self):
""" test the hal_tomo algorithm on experimental data and global search
"""
......@@ -360,30 +358,22 @@ class TestCor(object):
radio1 = radios["radio1"]
radio2 = radios["radio2"]
radio2 = np.fliplr( radio2 )
cor_pos = 983.038
radio2 = np.fliplr(radio2)
cor_pos = 983.038
CoR_calc = alignment.CenterOfRotation()
cor_position = CoR_calc.find_shift(radio1,
radio2,
low_pass=1,
high_pass=20,
global_search = True)
print( cor_position)
cor_position = CoR_calc.find_shift(radio1, radio2, low_pass=1, high_pass=20, global_search=True)
print("Found cor_position", cor_position)
message = (
"Computed CoR %f " % cor_position
+ " and real CoR %f should coincide when using the halftomo algorithm with hald tomo data" % cor_pos
)
assert np.isclose(cor_pos, cor_position, atol=0.1), message
def test_cor_posx_linear(self):
radio1 = self.data[0, :, :]
radio2 = np.fliplr(self.data[1, :, :])
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment