Commit 0667b851 authored by Pierre Paleo's avatar Pierre Paleo
Browse files

Merge branch 'fix_ff_nan' into 'master'

Fix NaNs in flat-field

Closes #302

See merge request !176
parents 0ea4e217 164d602d
Pipeline #73812 passed with stages
in 6 minutes and 19 seconds
__version__ = "2022.1.2"
__version__ = "2022.1.3"
__nabu_modules__ = [
"app",
"cuda",
......
......@@ -2,7 +2,7 @@ from ..utils import updiv
import pycuda.gpuarray as garray
from pycuda.compiler import SourceModule
class CudaKernel(object):
class CudaKernel:
"""
Helper class that wraps CUDA kernel through pycuda SourceModule.
......
......@@ -65,5 +65,11 @@ __global__ void flatfield_normalization(
#error "N_DARKS > 1 is not supported yet"
#endif
radios[pos] = (radios[pos] - dark_val) / (flat_val - dark_val);
float val = (radios[pos] - dark_val) / (flat_val - dark_val);
#ifdef NAN_VALUE
if (flat_val == dark_val) val = NAN_VALUE;
#endif
radios[pos] = val;
}
......@@ -18,6 +18,7 @@ class FlatFieldArrays:
radios_indices=None,
interpolation: str = "linear",
distortion_correction=None,
nan_value=1.0,
):
"""
Initialize a flat-field normalization process.
......@@ -40,6 +41,8 @@ class FlatFieldArrays:
Interpolation method for flat-field. See below for more details.
distortion_correction: DistortionCorrection, optional
A DistortionCorrection object. If provided, it is used to correct flat distortions based on each radio.
nan_value: float, optional
Which float value is used to replace nan/inf after flat-field.
Important
......@@ -62,7 +65,7 @@ class FlatFieldArrays:
If interpolation="linear", the normalization is done as a linear
function of the radio index.
"""
self._set_parameters(radios_shape, radios_indices, interpolation)
self._set_parameters(radios_shape, radios_indices, interpolation, nan_value)
self._set_flats_and_darks(flats, darks)
self.distortion_correction = distortion_correction
......@@ -106,7 +109,7 @@ class FlatFieldArrays:
)
def _set_parameters(self, radios_shape, radios_indices, interpolation):
def _set_parameters(self, radios_shape, radios_indices, interpolation, nan_value):
self._set_radios_shape(radios_shape)
if radios_indices is None:
radios_indices = np.arange(0, self.n_radios, dtype=np.int32)
......@@ -122,6 +125,7 @@ class FlatFieldArrays:
check_supported(
interpolation, self._supported_interpolations, "Interpolation mode"
)
self.nan_value = nan_value
@staticmethod
def get_previous_next_indices(arr, idx):
......@@ -186,6 +190,12 @@ class FlatFieldArrays:
self._dark = dark
return self._dark
def remove_invalid_values(self, img):
if self.nan_value is None:
return
invalid_mask = np.logical_not(np.isfinite(img))
img[invalid_mask] = self.nan_value
def normalize_radios(self, radios):
"""
Apply a flat-field normalization, with the current parameters, to a stack
......@@ -207,6 +217,7 @@ class FlatFieldArrays:
if do_flats_distortion_correction:
flat = self.distortion_correction.estimate_and_correct(flat, radio_data)
radios[i] = radio_data / flat
self.remove_invalid_values(radios[i])
return radios
......@@ -221,6 +232,7 @@ class FlatFieldArrays:
if self.distortion_correction is not None:
flat = self.distortion_correction.estimate_and_correct(flat, radio)
radio /= flat
self.remove_invalid_values(radio)
return radio
......@@ -236,6 +248,7 @@ class FlatFieldDataUrls(FlatField):
radios_indices=None,
interpolation: str = "linear",
distortion_correction=None,
nan_value=1.0,
**chunk_reader_kwargs
):
"""
......@@ -259,6 +272,8 @@ class FlatFieldDataUrls(FlatField):
Interpolation method for flat-field. See below for more details.
distortion_correction: DistortionCorrection, optional
A DistortionCorrection object. If provided, it is used to
nan_value: float, optional
Which float value is used to replace nan/inf after flat-field.
Other Parameters
......
......@@ -16,6 +16,7 @@ class CudaFlatFieldArrays(FlatFieldArrays):
radios_indices=None,
interpolation: str = "linear",
distortion_correction=None,
nan_value=1.0,
cuda_options: Union[dict, None] = None,
):
"""
......@@ -35,7 +36,8 @@ class CudaFlatFieldArrays(FlatFieldArrays):
darks,
radios_indices=radios_indices,
interpolation=interpolation,
distortion_correction=distortion_correction
distortion_correction=distortion_correction,
nan_value=nan_value
)
self._set_cuda_options(cuda_options)
self._init_cuda_kernels()
......@@ -54,15 +56,19 @@ class CudaFlatFieldArrays(FlatFieldArrays):
raise ValueError(
"Interpolation other than linar is not yet implemented in the cuda back-end"
)
#
self._cuda_fname = get_cuda_srcfile("flatfield.cu")
options = [
"-DN_FLATS=%d" % self.n_flats,
"-DN_DARKS=%d" % self.n_darks,
]
if self.nan_value is not None:
options.append("-DNAN_VALUE=%f" % self.nan_value)
self.cuda_kernel = CudaKernel(
"flatfield_normalization",
self._cuda_fname,
signature="PPPiiiPP",
options=[
"-DN_FLATS=%d" % self.n_flats,
"-DN_DARKS=%d" % self.n_darks,
]
options=options
)
self._nx = np.int32(self.shape[1])
self._ny = np.int32(self.shape[0])
......
This diff is collapsed.
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment