Commit 18e971f0 authored by Pierre Paleo's avatar Pierre Paleo

Merge branch '198-simple-half-acq-cor' into 'master'

Resolve "Add simple half-acquisition CoR finding algorithms"

Closes #198

See merge request !93
parents 2e415f39 cd8d4103
Pipeline #37008 passed with stages
in 4 minutes and 59 seconds
This diff is collapsed.
......@@ -32,6 +32,16 @@ def bootstrap_cor(request):
cls.cor_gl_pix, cls.cor_hl_pix, cls.tilt_deg = calib_data
@pytest.fixture(scope="class")
def bootstrap_cor_win(request):
cls = request.cls
cls.abs_tol = 0.2
cls.data_ha_proj, cls.cor_ha_pr_pix = get_cor_win_proj_data_h5("ha_autocor_radios.npz")
cls.data_ha_sino, cls.cor_ha_sn_pix = get_cor_win_sino_data_h5("halftomo_1_sino.npz")
@pytest.fixture(scope="class")
def bootstrap_dtr(request):
cls = request.cls
......@@ -81,6 +91,58 @@ def get_cor_data_h5(*dataset_path):
return data, (cor_global_pix, cor_highlow_pix, tilt_deg)
def get_cor_win_proj_data_h5(*dataset_path):
"""
Get a dataset file from silx.org/pub/nabu/data
dataset_args is a list describing a nested folder structures, ex.
["path", "to", "my", "dataset.h5"]
"""
dataset_relpath = os.path.join(*dataset_path)
dataset_downloaded_path = utilstest.getfile(dataset_relpath)
data = np.load(dataset_downloaded_path)
radios = np.stack((data["radio1"], data["radio2"]), axis=0)
return radios, data["cor_pos"]
def get_cor_win_sino_data_h5(*dataset_path):
"""
Get a dataset file from silx.org/pub/nabu/data
dataset_args is a list describing a nested folder structures, ex.
["path", "to", "my", "dataset.h5"]
"""
dataset_relpath = os.path.join(*dataset_path)
dataset_downloaded_path = utilstest.getfile(dataset_relpath)
data = np.load(dataset_downloaded_path)
sino_shape = data["sino"].shape
sinos = np.stack((data["sino"][:sino_shape[0]//2], data["sino"][sino_shape[0]//2:]), axis=0)
return sinos, data["cor"] - sino_shape[1] / 2
def get_cor_data_half_tomo():
"""Obtains two weakly overlapping images with features plus spurious noise for a challenging test of cor retrieval for half tomo."""
datasource = ExternalResources(project="nabu", url_base=None)
myfile = os.path.join(datasource.data_home, "data_for_ht_cor.h5")
if not os.path.isfile(myfile):
im1, im2, cor = _get_challenging_ImsCouple_for_halftomo_cor()
f = h5py.File(myfile, "w")
f["im1"] = np.array(im1, "f")
f["im2"] = np.array(im2, "f")
f["cor"] = cor
assert os.path.isfile(myfile)
with h5py.File(myfile, "r") as hf:
im1 = hf["im1"][()]
im2 = hf["im2"][()]
cor = hf["cor"][()]
return im1, im2, cor
def get_alignxc_data(*dataset_path):
"""
Get a dataset file from silx.org/pub/nabu/data
......@@ -349,6 +411,112 @@ class TestCor(object):
assert np.isclose(self.cor_gl_pix, cor_position, atol=self.abs_tol), message
@pytest.mark.usefixtures("bootstrap_cor", "bootstrap_cor_win")
class TestCorWindowSlide(object):
def test_proj_center_axis_lft(self):
radio1 = self.data[0, :, :]
radio2 = np.fliplr(self.data[1, :, :])
CoR_calc = alignment.CenterOfRotationSlidingWindow()
cor_position = CoR_calc.find_shift(radio1, radio2, side="left", window_width=round(radio1.shape[-1] / 4.0 * 3.0))
message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % self.cor_gl_pix
assert np.isclose(self.cor_gl_pix, cor_position, atol=self.abs_tol), message
def test_proj_center_axis_cen(self):
radio1 = self.data[0, :, :]
radio2 = np.fliplr(self.data[1, :, :])
CoR_calc = alignment.CenterOfRotationSlidingWindow()
cor_position = CoR_calc.find_shift(radio1, radio2, side="center")
message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % self.cor_gl_pix
assert np.isclose(self.cor_gl_pix, cor_position, atol=self.abs_tol), message
def test_proj_right_axis_rgt(self):
radio1 = self.data_ha_proj[0, :, :]
radio2 = np.fliplr(self.data_ha_proj[1, :, :])
CoR_calc = alignment.CenterOfRotationSlidingWindow()
cor_position = CoR_calc.find_shift(radio1, radio2, side="right")
message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % self.cor_ha_pr_pix
assert np.isclose(self.cor_ha_pr_pix, cor_position, atol=self.abs_tol), message
def test_proj_left_axis_lft(self):
radio1 = np.fliplr(self.data_ha_proj[0, :, :])
radio2 = self.data_ha_proj[1, :, :]
CoR_calc = alignment.CenterOfRotationSlidingWindow()
cor_position = CoR_calc.find_shift(radio1, radio2, side="left")
message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % -self.cor_ha_pr_pix
assert np.isclose(-self.cor_ha_pr_pix, cor_position, atol=self.abs_tol), message
def test_sino_right_axis_rgt(self):
sino1 = self.data_ha_sino[0, :, :]
sino2 = np.fliplr(self.data_ha_sino[1, :, :])
CoR_calc = alignment.CenterOfRotationSlidingWindow()
cor_position = CoR_calc.find_shift(sino1, sino2, side="right")
message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % self.cor_ha_sn_pix
assert np.isclose(self.cor_ha_sn_pix, cor_position, atol=self.abs_tol * 5), message
@pytest.mark.usefixtures("bootstrap_cor", "bootstrap_cor_win")
class TestCorWindowGrow(object):
def test_proj_center_axis_cen(self):
radio1 = self.data[0, :, :]
radio2 = np.fliplr(self.data[1, :, :])
CoR_calc = alignment.CenterOfRotationGrowingWindow()
cor_position = CoR_calc.find_shift(radio1, radio2, side="center")
message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % self.cor_gl_pix
assert np.isclose(self.cor_gl_pix, cor_position, atol=self.abs_tol), message
def test_proj_right_axis_rgt(self):
radio1 = self.data_ha_proj[0, :, :]
radio2 = np.fliplr(self.data_ha_proj[1, :, :])
CoR_calc = alignment.CenterOfRotationGrowingWindow()
cor_position = CoR_calc.find_shift(radio1, radio2, side="right")
message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % self.cor_ha_pr_pix
assert np.isclose(self.cor_ha_pr_pix, cor_position, atol=self.abs_tol), message
def test_proj_left_axis_lft(self):
radio1 = np.fliplr(self.data_ha_proj[0, :, :])
radio2 = self.data_ha_proj[1, :, :]
CoR_calc = alignment.CenterOfRotationGrowingWindow()
cor_position = CoR_calc.find_shift(radio1, radio2, side="left")
message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % -self.cor_ha_pr_pix
assert np.isclose(-self.cor_ha_pr_pix, cor_position, atol=self.abs_tol), message
def test_proj_right_axis_all(self):
radio1 = self.data_ha_proj[0, :, :]
radio2 = np.fliplr(self.data_ha_proj[1, :, :])
CoR_calc = alignment.CenterOfRotationGrowingWindow()
cor_position = CoR_calc.find_shift(radio1, radio2, side="all")
message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % self.cor_ha_pr_pix
assert np.isclose(self.cor_ha_pr_pix, cor_position, atol=self.abs_tol), message
def test_sino_right_axis_rgt(self):
sino1 = self.data_ha_sino[0, :, :]
sino2 = np.fliplr(self.data_ha_sino[1, :, :])
CoR_calc = alignment.CenterOfRotationGrowingWindow()
cor_position = CoR_calc.find_shift(sino1, sino2, side="right")
message = "Computed CoR %f " % cor_position + " and expected CoR %f do not coincide" % self.cor_ha_sn_pix
assert np.isclose(self.cor_ha_sn_pix, cor_position, atol=self.abs_tol * 4), message
@pytest.mark.usefixtures("bootstrap_dtr")
class TestDetectorTranslation(object):
def test_alignxc(self):
......
import numpy as np
from silx.io import get_data
from ..preproc.ccd import FlatField
from ..preproc.alignment import CenterOfRotation, CenterOfRotationAdaptiveSearch
from ..preproc.alignment import CenterOfRotation, CenterOfRotationAdaptiveSearch, CenterOfRotationSlidingWindow, CenterOfRotationGrowingWindow
from .logger import LoggerOrPrint
from .utils import extract_parameters
from ..utils import check_supported
class CORFinder:
"""
An application-type class for finding the Center Of Rotation (COR).
"""
def __init__(self, dataset_info, angles=None, halftomo=False, do_flatfield=True, logger=None):
search_methods = {
"centered":{
"class": CenterOfRotation,
},
"global": {
"class": CenterOfRotationAdaptiveSearch,
"default_kwargs": {"low_pass": 1, "high_pass":20},
},
"sliding-window": {
"class": CenterOfRotationSlidingWindow,
"default_args": ["center"],
},
"growing-window": {
"class": CenterOfRotationGrowingWindow,
}
}
def __init__(self, dataset_info, angles=None, halftomo=False, do_flatfield=True, cor_options=None, logger=None):
"""
Initialize a CORFinder object.
......@@ -35,6 +54,7 @@ class CORFinder:
self._default_search_method = "centered"
if self.halftomo:
self._default_search_method = "global"
self._get_cor_options(cor_options)
def _get_angles(self, angles):
......@@ -99,14 +119,27 @@ class CORFinder:
self.flatfield.normalize_radios(self.radios)
def find_cor(self, search_method=None, **cor_kwargs):
def _get_cor_options(self, cor_options):
if cor_options is None:
self.cor_options = None
return
try:
cor_options = extract_parameters(cor_options)
except Exception as exc:
msg = "Could not extract parameters from cor_options: %s" % (str(exc))
self.logger.fatal(msg)
raise ValueError(msg)
self.cor_options = cor_options
def find_cor(self, search_method=None):
"""
Find the center of rotation.
Parameters
----------
search_method: str, optional
Which CoR search method to use. Default is "auto" (equivalent to "centered").
Which CoR search method to use. Default "centered".
Returns
-------
......@@ -118,24 +151,36 @@ class CORFinder:
This function passes the named parameters to nabu.preproc.alignment.CenterOfRotation.find_shift.
"""
search_method = search_method or self._default_search_method
if search_method == "global":
self.cor = CenterOfRotationAdaptiveSearch(logger=self.logger)
shift = self.cor.find_shift(
self.radios[0],
np.fliplr(self.radios[1]),
low_pass=1, high_pass=20
)
else:
self.cor = CenterOfRotation(logger=self.logger)
shift = self.cor.find_shift(
self.radios[0],
np.fliplr(self.radios[1]),
**cor_kwargs
)
check_supported(search_method, self.search_methods.keys(), "CoR estimation method")
cor_class = self.search_methods[search_method]["class"]
cor_finder = cor_class(logger=self.logger)
self.logger.info("Estimating center of rotation")
default_params = self.search_methods[search_method].get("default_kwargs", None) or {}
cor_exec_kwargs = default_params.copy()
cor_exec_kwargs.update(self.cor_options or {})
cor_exec_args = self.search_methods[search_method].get("default_args", None) or []
# Specific to CenterOfRotationSlidingWindow
if cor_class == CenterOfRotationSlidingWindow:
side_param = cor_exec_kwargs.pop("side", "center")
cor_exec_args = [side_param]
#
self.logger.debug("%s(%s)" % (get_class_name(cor_class), str(cor_exec_kwargs)))
shift = cor_finder.find_shift(
self.radios[0],
np.fliplr(self.radios[1]),
*cor_exec_args,
**cor_exec_kwargs
)
# find_shift returned a single scalar in 2020.1
# This should be the default after 2020.2 release
if hasattr(shift, "__iter__"):
shift = shift[0]
#
return self.shape[1]/2 + shift
res = self.shape[1]/2 + shift
self.logger.info("Estimated center of rotation: %.2f" % res)
return res
def get_class_name(class_object):
return str(class_object).split(".")[-1].strip(">").strip("'").strip('"')
......@@ -266,6 +266,7 @@ class NabuValidator(object):
self.dataset_infos._projections_subsampled = subsample_dict(self.dataset_infos.projections, subsampling_factor)
self.dataset_infos._projs_indices_subsampled = sorted(self.dataset_infos._projections_subsampled.keys())
self.dataset_infos.reconstruction_angles = self.dataset_infos.reconstruction_angles[::subsampling_factor]
# should be simply len(projections)... ?
self.dataset_infos.n_angles //= subsampling_factor
if self.binning != (1, 1):
bin_x, bin_z = self.binning
......@@ -295,7 +296,6 @@ class NabuValidator(object):
else:
rot_c = (nx - 1)/2.
self.dataset_infos.axis_position = rot_c
self.nabu_config["reconstruction"]["rotation_axis_position"] = rot_c
def check_output_file(self):
......
......@@ -163,10 +163,16 @@ nabu_config = {
},
"rotation_axis_position": {
"default": "",
"help": "Rotation axis position. Default is the middle of the detector width. If set to 'auto', nabu will attempt to determine it automatically.",
"help": "Rotation axis position. Default (empty) is the middle of the detector width.\nAdditionally, the following methods are available to find automaticall the Center of Rotation (CoR):\n - centered : a fast and simple auto-CoR method. It only works when the CoR is not far from the middle of the detector. It does not work for half-tomography.\n - global : a slow but robust auto-CoR.\n - sliding-window : semi-automatically find the CoR with a sliding window. You have to specify on which side the CoR is (left, center, right). Please see the 'cor_options' parameter.\n - growing-window : automatically find the CoR with a sliding-and-growing window. You can tune the option with the parameter 'cor_options'.",
"validator": cor_validator,
"type": "required",
},
"cor_options": {
"default": "",
"help": "Options for methods finding automatically the rotation axis position. The parameters are separated by commas and passed as 'name=value', for example: low_pass=1; high_pass=20. Mind the semicolon separator (;).",
"validator": cor_options_validator,
"type": "advanced",
},
"axis_correction_file": {
"default": "",
"help": "In the case where the axis position is specified for each slice",
......
......@@ -140,11 +140,15 @@ class CorMethods(Enum):
AUTO = "centered"
CENTERED = "centered"
GLOBAL = "global"
SLIDING = "sliding-window"
GROWING = "growing-window"
cor_methods = {
"auto": "centered",
"centered": "centered",
"global": "global",
"sliding-window": "sliding-window",
"sliding window": "sliding-window",
"growing-window": "growing-window",
"growing window": "growing-window",
}
......@@ -122,6 +122,7 @@ class ProcessConfig:
self.dataset_infos,
halftomo=self.nabu_config["reconstruction"]["enable_halftomo"],
do_flatfield=self.nabu_config["preproc"]["flatfield_enabled"],
cor_options=self.nabu_config["reconstruction"]["cor_options"],
logger=self.logger
)
cor = self.corfinder.find_cor(search_method=cor)
......
from ast import literal_eval
import numpy as np
from psutil import virtual_memory, cpu_count
from .params import files_formats, FileFormat
......@@ -82,6 +83,27 @@ def get_threads_per_node(max_threads, is_percentage=True):
return min(max_threads, sys_n_threads)
def extract_parameters(params_str, sep=";"):
"""
Extract the named parameters from a string.
Example
--------
The function can be used as follows:
>>> extract_parameters("window_width=None; median_filt_shape=(3,3); padding_mode='wrap'")
... {'window_width': None, 'median_filt_shape': (3, 3), 'padding_mode': 'wrap'}
"""
params_list = params_str.strip(sep).split(sep)
res = {}
for param_str in params_list:
param_name, param_val_str = param_str.strip().split("=")
param_name = param_name.strip()
param_val_str = param_val_str.strip()
param_val = literal_eval(param_val_str)
res[param_name] = param_val
return res
def is_hdf5_extension(ext):
return FileFormat.from_value(files_formats[ext]) == FileFormat.HDF5
......@@ -230,6 +230,11 @@ def cor_validator(val):
)
return val
@validator
def cor_options_validator(val):
if len(val.strip()) == 0:
return None
return val
@validator
def flatfield_enabled_validator(val):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment