Commit ca596003 authored by Pierre Paleo's avatar Pierre Paleo
Browse files

Merge branch 'cor_app' into 'master'

CoR application

Closes #114

See merge request !46
parents 0432fa3a b97a297f
Pipeline #27912 passed with stages
in 4 minutes and 48 seconds
from os import path
import numpy as np
from .logger import LoggerOrPrint
from ..resources.logger import LoggerOrPrint
from .utils import use_options, pipeline_step
from ..utils import check_supported
from ..io.reader import ChunkReader
......@@ -476,6 +476,7 @@ class FullFieldPipeline:
self._reset_sub_region(sub_region)
self._reset_memory()
self._init_writer()
self._init_double_flatfield()
self._read_data()
self._process_chunk()
......@@ -25,7 +25,7 @@ class CudaFullFieldPipeline(FullFieldPipeline):
FlatFieldClass = CudaFlatField
DoubleFlatFieldClass = CudaDoubleFlatField
PaganinPhaseRetrievalClass = CudaPaganinPhaseRetrieval
UnsharpMask = CudaUnsharpMask
UnsharpMaskClass = CudaUnsharpMask
VerticalShiftClass = CudaVerticalShift
SinoProcessingClass = CudaSinoProcessing
MLogClass = CudaLog
......
......@@ -108,14 +108,16 @@ def _extract_nabuconfig_keyvals():
def _handle_renamed_key(key, val, section):
if val is not None:
return key, val
if key in renamed_keys:
if key in renamed_keys and renamed_keys[key]["section"] == section:
info = renamed_keys[key]
print(
"Option '%s' has been renamed to '%s' since version %s. It will result in an error in version %s"
% (key, info["new_name"], info["since"], info["end_deprecation"])
)
val = nabu_config[section].get(info["new_name"], None)
return info["new_name"], val
val = nabu_config[section].get(info["new_name"], None)
return info["new_name"], val
else:
return key, None
def validate_nabu_config(config):
......
......@@ -69,7 +69,7 @@ class UnsharpMask(object):
res = convolve1d(res1, self._gaussian_kernel, axis=0, mode=self.mode)
return res
def unsharp(self, image):
def unsharp(self, image, output=None):
"""
Reference unsharp mask implementation.
"""
......@@ -80,5 +80,8 @@ class UnsharpMask(object):
res = (1 + self.coeff) * image - self.coeff * image_b
else: # LoG
res = image + self.coeff * image_b
if output is not None:
output[:] = res[:]
return output
return res
......@@ -10,15 +10,15 @@ class CudaUnsharpMask(UnsharpMask, CudaProcessing):
def __init__(self, shape, sigma, coeff, mode="reflect", method="gaussian",
cuda_options=None):
"""
NB: For now, this class is designed to use the lowest amount of GPU memory
as possible. Therefore, the input and output image/volumes are assumed
to be already on device.
Unsharp Mask, cuda backend.
"""
cuda_options = cuda_options or {}
CudaProcessing.__init__(self, **cuda_options)
UnsharpMask.__init__(self, shape, sigma, coeff, mode=mode, method=method)
self._init_convolution()
self._init_mad_kernel()
self._init_arrays_to_none(["_d_out"])
def _init_convolution(self):
self.convolution = Convolution(
......@@ -40,10 +40,10 @@ class CudaUnsharpMask(UnsharpMask, CudaProcessing):
name="mul_add"
)
def unsharp(self, image, output):
# For now image and output are assumed to be already allocated on device
assert isinstance(image, garray.GPUArray)
assert isinstance(output, garray.GPUArray)
def unsharp(self, image, output=None):
if output is None:
self._allocate_array("_d_out", self.shape, "f")
output = self._d_out
self.convolution(image, output=output)
if self.method == "gaussian":
self.mad_kernel(output, -self.coeff, image, 1. + self.coeff)
......
......@@ -236,6 +236,7 @@ class FlatField(CCDProcessing):
"""
Apply a flat-field correction, with the current parameters, to a stack
of radios.
The processing is done in-place, meaning that the radios content is overwritten.
Parameters
-----------
......
......@@ -7,8 +7,8 @@ from ...resources.processconfig import ProcessConfig
from ...app.fullfield import FullFieldPipeline
from ...cuda.utils import __has_pycuda__
if __has_pycuda__:
from ..app.fullfield_cuda import CudaFullFieldPipeline, CudaFullFieldPipelineLimitedMemory
from ...app.logger import Logger
from ...app.fullfield_cuda import CudaFullFieldPipeline, CudaFullFieldPipelineLimitedMemory
from ..logger import Logger
from ... import version
......@@ -57,6 +57,10 @@ def main():
logger.warning("Using user-provided energy %.2f keV" % args["energy"])
proc.dataset_infos.dataset_scanner._energy = args["energy"]
proc.processing_options["phase"]["energy_kev"] = args["energy"]
if proc.dataset_infos.energy < 1e-3 and proc.nabu_config["phase"]["method"] != None:
msg = "No information on energy. Cannot retrieve phase. Please use the --energy option"
logger.fatal(msg)
raise ValueError(msg)
#
if __has_pycuda__:
......
import numpy as np
from silx.io import get_data
from ..preproc.ccd import FlatField
from ..preproc.alignment import CenterOfRotation
class CORFinder:
"""
An application-type class for finding the Center Of Rotation (COR).
"""
def __init__(self, dataset_info, angles=None, halftomo=False):
"""
Initialize a CORFinder object.
Parameters
----------
dataset_info: `nabu.resources.dataset_analyzer.DatasetAnalyzer`
Dataset information structure
angles: array, optional
Information on rotation angles. If provided, it overwrites
the rotation angles available in `dataset_info`, if any.
halftomo: bool, optional
Whether the scan was performed in "half tomography" acquisition.
"""
self.halftomo = halftomo
self.dataset_info = dataset_info
self.shape = dataset_info._radio_dims_notbinned[::-1]
self._get_angles(angles)
self._init_radios()
self._init_flatfield()
self._apply_flatfield()
self.cor = CenterOfRotation()
def _get_angles(self, angles):
dataset_angles = self.dataset_info.rotation_angles
if dataset_angles is None:
if angles is None: # should not happen with hdf5
print("Warning: no information on angles was found for this dataset. Using default [0, 180[ range.")
angles = np.linspace(0, np.pi, len(self.dataset_info.projections), False)
dataset_angles = angles
self.angles = dataset_angles
def _init_radios(self):
# TODO
if self.halftomo:
raise NotImplementedError("Automatic COR with half tomo is not supported yet")
#
# We take 2 radios. It could be tuned for a 360 degrees scan.
self._n_radios = 2
self._radios_indices = []
radios_indices = sorted(self.dataset_info.projections.keys())
# Take angles 0 and 180 degrees. It might not work of there is an offset
i_0 = np.argmin(np.abs(self.angles))
i_180 = np.argmin(np.abs(self.angles - np.pi))
_min_indices = [i_0, i_180]
self._radios_indices = [
radios_indices[i_0],
radios_indices[i_180]
]
self.radios = np.zeros((self._n_radios, ) + self.shape, "f")
for i in range(self._n_radios):
radio_idx = self._radios_indices[i]
self.radios[i] = get_data(self.dataset_info.projections[radio_idx]).astype("f")
def _init_flatfield(self):
self.flatfield = FlatField(
self.radios.shape,
flats=self.dataset_info.flats,
darks=self.dataset_info.darks,
radios_indices=self._radios_indices,
interpolation="linear",
convert_float=True
)
def _apply_flatfield(self):
self.flatfield.normalize_radios(self.radios)
def find_cor(self, **cor_kwargs):
"""
Find the center of rotation.
Parameters
----------
This function passes the named parameters to nabu.preproc.alignment.CenterOfRotation.find_shift.
Returns
-------
cor: float
The estimated center of rotation for the current dataset.
"""
shift = self.cor.find_shift(
self.radios[0],
np.fliplr(self.radios[1]),
**cor_kwargs
)
# find_shift returned a single scalar in 2020.1
# This should be the default after 2020.2 release
if hasattr(shift, "__iter__"):
shift = shift[0]
#
return self.shape[1]/2 + shift
......@@ -42,6 +42,8 @@ class DatasetAnalyzer(object):
self.radio_dims = (self.dataset_scanner.dim_1, self.dataset_scanner.dim_2)
self._binning = (1, 1)
self.translations = None
self.axis_position = None
self._radio_dims_notbinned = self.radio_dims
@property
def energy(self):
......
......@@ -57,7 +57,9 @@ class NabuValidator(object):
nx, nz = self.dataset_infos.radio_dims
ny = nx
if self.nabu_config["reconstruction"]["enable_halftomo"]:
cor = int(round(self.nabu_config["reconstruction"]["rotation_axis_position"]))
if self.dataset_infos.axis_position is None:
raise ValueError("rotation_axis_position should be either a number or 'auto' for half tomo")
cor = int(round(self.dataset_infos.axis_position))
ny = nx = 2*cor
what = (
("reconstruction", "start_x", nx),
......@@ -104,7 +106,6 @@ class NabuValidator(object):
def _get_rotation_axis(self):
rec_params = self.nabu_config["reconstruction"]
self.dataset_infos.axis_position = rec_params["rotation_axis_position"]
axis_correction_file = rec_params["axis_correction_file"]
axis_correction = None
if axis_correction_file is not None:
......
......@@ -145,8 +145,8 @@ nabu_config = {
},
"rotation_axis_position": {
"default": "",
"help": "Rotation axis position. Default is the middle of the detector width.",
"validator": optional_float_validator,
"help": "Rotation axis position. Default is the middle of the detector width. If set to 'auto', nabu will attempt to determine it automatically.",
"validator": cor_validator,
"type": "required",
},
"axis_correction_file": {
......@@ -345,16 +345,17 @@ nabu_config = {
"default": "0",
"help": "What to do in the case where the output file exists.\nBy default, the output data is never overwritten and the process is interrupted if the file already exists.\nSet this option to 1 if you want to overwrite without asking.",
"validator": boolean_validator,
"type": "optional",
"type": "required",
},
},
}
renamed_keys = {
"marge": {
"section": "phase",
"new_name": "margin",
"since": "2020.2.0",
"end_deprecation": "2020.3.0",
"end_deprecation": "2020.4.0",
}
}
......
......@@ -3,7 +3,7 @@ from ..utils import PlaceHolder, DataPlaceHolder, copy_dict_items
from ..io.config import NabuConfigParser, validate_nabu_config
from .dataset_analyzer import analyze_dataset, EDFDatasetAnalyzer, HDF5DatasetAnalyzer
from .dataset_validator import NabuValidator
from .cor import CORFinder
class ProcessConfig:
......@@ -59,14 +59,26 @@ class ProcessConfig:
assert (isinstance(dataset_infos, EDFDatasetAnalyzer)) or (
isinstance(dataset_infos, HDF5DatasetAnalyzer)
)
self.dataset_infos = dataset_infos
self.nabu_config = validate_nabu_config(conf)
self.checks = checks
self.remove_unused_radios = remove_unused_radios
self._get_cor()
self.validation_stage2()
self.build_processing_steps()
def _get_cor(self):
cor = self.nabu_config["reconstruction"]["rotation_axis_position"]
if cor == "auto":
self.corfinder = CORFinder(
self.dataset_infos,
halftomo=self.nabu_config["reconstruction"]["enable_halftomo"],
)
cor = self.corfinder.find_cor()
self.dataset_infos.axis_position = cor
def validation_stage2(self):
validator = NabuValidator(self.nabu_config, self.dataset_infos)
if self.checks:
......@@ -85,7 +97,6 @@ class ProcessConfig:
tasks = []
options = {}
#
# Dataset / Get data
#
......@@ -184,6 +195,8 @@ class ProcessConfig:
"start_x", "end_x", "start_y", "end_y", "start_z", "end_z"]
)
rec_options = options["reconstruction"]
rec_options["rotation_axis_position"] = dataset_infos.axis_position
options["build_sino"]["rotation_axis_position"] = dataset_infos.axis_position
rec_options["axis_correction"] = dataset_infos.axis_correction
rec_options["angles"] = dataset_infos.reconstruction_angles
rec_options["radio_dims_y_x"] = dataset_infos.radio_dims[::-1]
......
......@@ -213,6 +213,19 @@ def optional_float_validator(val):
val_float = None
return val_float
@validator
def cor_validator(val):
if isinstance(val, float):
return val
elif len(val.strip()) >= 1:
if val.lower() == "auto":
return "auto"
val_float, error = convert_to_float(val)
assert error is None, "Invalid number"
return val_float
else:
return None
@validator
def phase_method_validator(val):
return name_range_checker(
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment