Commit b4ceafe3 authored by Pierre Paleo's avatar Pierre Paleo
Browse files

WIP Component: phase

parent 0cf594ff
Pipeline #19772 passed with stage
in 1 minute and 11 seconds
......@@ -64,3 +64,5 @@ class Component:
def execute(self, *args, **kwargs):
raise ValueError("This must be implemented by child class")
__call__ = execute
......@@ -13,17 +13,16 @@ def PhaseRetrieval(Component):
def _get_backend(self):
self.backend = "numpy"
if self.options["use_cuda"]:
self.backend = self.check_requirement(
__has_pycuda__, "pycuda must be installed for use_cuda=1",
"cuda", "numpy"
)
if self.options["use_opencl"]:
elif self.options["use_opencl"]: # mutually exclusive with use_cuda - see #72
# Not implemented yet
self.logger.warning("use_opencl: OpenCL backend is not available yet for phase retrieval")
self.backend = "numpy"
else:
self.backend = "numpy"
def _init_phase_retrieval(self):
......@@ -52,7 +51,30 @@ def PhaseRetrieval(Component):
self.logger.debug("Phase retrieval object created with %s backend" % self.backend)
def execute(self):
def _get_radio_dims(self, radios):
if radios.ndim == 2:
n_z, n_x = radios.shape
n_a = 1
elif radios.ndim == 3:
n_a, n_z, n_x = radios.shape
else:
raise ValueError("radios dims must be 2 or 3")
return (n_a, n_z, n_x)
def execute(self, radios, output=None):
"""
Perform a phase retrieval on a chunk of radios.
"""
# should "execute()" always be applied on a radios chunk ?
# If so, it should be mentioned in the documentation.
# It makes sense to do some checks on the output once, and then launch heavy calculations.
# Done in-place by default ! i.e input data is erased
n_a, n_z, n_x = self._get_radio_dims(radios)
self.logger.info("Phase retrieval on chunk %d: start" % self.current_chunk)
self.phase_retrieval.apply_filter(radio, output=XXX) # depends CPU/GPU
for i in range(n_a):
self.phase_retrieval.apply_filter(
radios[i], output=XXX # TODO
)
self.logger.info("Phase retrieval on chunk %d: end" % self.current_chunk)
......@@ -127,6 +127,7 @@ class CudaPaganinPhaseRetrieval(PaganinPhaseRetrieval, CudaProcessing):
s0, s1 = self.shape_inner
((U, _), (L, _)) = self.margin
if output is None:
# copy D2H
return self.d_radio_padded[U:U + s0, L:L + s1].get()
assert output.shape == self.shape_inner
assert output.dtype == np.float32
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment