diff --git a/tomoscan/normalization.py b/tomoscan/normalization.py index 768483907ab0ac7653a91f7e2e520b90f5654fad..aa94d0dd34a79ec981d8738c0b8ba334ddffe6ae 100644 --- a/tomoscan/normalization.py +++ b/tomoscan/normalization.py @@ -44,13 +44,10 @@ _logger = logging.getLogger(__name__) class Method(_Enum): NONE = "none" - SCALAR = "scalar" + SUBTRACTION = "subtraction" + DIVISION = "division" CHEBYSHEV = "chebyshev" LSQR_SPLINE = "lsqr spline" - MONITOR = "intensity monitor" - # MANUAL_ROI = "manual ROI" - # AUTO_ROI = "automatic ROI" - # DATASET = "from dataset" class _MethodMode(_Enum): @@ -107,7 +104,7 @@ class _ROIInfo: self.y_max = y_max -class _IntensityNormalization: +class IntensityNormalization: """Information regarding the intensity normalization to be done""" def __init__(self): @@ -137,6 +134,27 @@ class _IntensityNormalization: def get_extra_infos(self) -> typing.Union[dict, _DatasetInfos, _ROIInfo]: return self._extra_info + def to_dict(self) -> dict: + res = { + "method": self.method.value, + } + if self._extra_info not in (None, {}): + res["extra_infos"] = self.get_extra_infos() + return res + + def load_from_dict(self, dict_): + if "method" in dict_: + self.method = dict_["method"] + if "extra_infos" in dict_: + self.set_extra_infos(dict_["extra_infos"]) + return self + + @staticmethod + def from_dict(dict_): + res = IntensityNormalization() + res.load_from_dict(dict_) + return res + def __str__(self): return "method: {}, extra-infos: {}".format(self.method, self.get_extra_infos()) diff --git a/tomoscan/scanbase.py b/tomoscan/scanbase.py index 86b4d8d29017588c094caad437debb3b2905a94b..72c08ec0c6ef4590ed0d9ccbc9bb6c15164e1d5b 100644 --- a/tomoscan/scanbase.py +++ b/tomoscan/scanbase.py @@ -48,7 +48,7 @@ from math import ceil from .progress import Progress from bisect import bisect_left from tomoscan.normalization import ( - _IntensityNormalization, + IntensityNormalization, Method as _IntensityMethod, normalize_chebyshev_2D, normalize_lsqr_spline_2D, @@ -181,7 +181,7 @@ class TomoScanBase(Dataset): """monitor of the intensity during acquisition. Can be a diode for example""" self._source = None - self._intensity_normalization = _IntensityNormalization() + self._intensity_normalization = IntensityNormalization() """Extra information for normalization""" def clear_caches(self): @@ -650,7 +650,9 @@ class TomoScanBase(Dataset): else: return None - def _apply_sino_norm(self, sinogram, norm_method, subsampling=1, **kwargs): + def _apply_sino_norm( + self, sinogram, norm_method: _IntensityMethod, subsampling=1, **kwargs + ): if norm_method is not None: norm_method = _IntensityMethod.from_value(norm_method) if norm_method in (None, _IntensityMethod.NONE): @@ -659,35 +661,44 @@ class TomoScanBase(Dataset): return normalize_chebyshev_2D(sinogram) elif norm_method is _IntensityMethod.LSQR_SPLINE: return normalize_lsqr_spline_2D(sinogram) - elif norm_method is _IntensityMethod.SCALAR: - if "value" not in kwargs: - raise KeyError("'value' should be provided to extra_infos") - elif numpy.isscalar(kwargs["value"]): - return sinogram - kwargs["value"] - else: - for sl, sc in zip(range(len(sinogram)), kwargs["value"][::subsampling]): - sinogram[sl] = sinogram[sl] - sc - return sinogram - elif norm_method is _IntensityMethod.MONITOR: - intensities = self.get_projections_intensity_monitor() - if intensities is None: - raise ValueError("No dataset for intensity monitoring found") - else: - i_values = set(intensities.values()) - if len(i_values) == 1 and list(i_values)[0] is None: - raise ValueError("No dataset for intensity monitoring found") - - intensities = [intensities[key] for key in sorted(intensities.keys())] - intensities = intensities[::subsampling] - for sl, sc in zip(range(len(sinogram)), intensities): - if sc is None: - _logger.warning( - "Intensity not found for line {sl}. won't normalize this line" - ) - sinogram[sl] = sinogram[sl] - else: + elif norm_method in (_IntensityMethod.DIVISION, _IntensityMethod.SUBTRACTION): + if "value" in kwargs: + _logger.info("Apply sinogram normalization from 'value' key") + if numpy.isscalar(kwargs["value"]): + return sinogram - kwargs["value"] + else: + for sl, sc in zip( + range(len(sinogram)), kwargs["value"][::subsampling] + ): sinogram[sl] = sinogram[sl] - sc - return sinogram + return sinogram + elif "dataset_url" in kwargs: + _logger.info("Apply sinogram normalization from 'dataset_url' key") + intensities = DataUrl(path=kwargs["dataset_url"]) + if intensities is None: + raise ValueError("No dataset for intensity monitoring found") + else: + i_values = set(intensities.values()) + if len(i_values) == 1 and list(i_values)[0] is None: + raise ValueError("No dataset for intensity monitoring found") + + intensities = [ + intensities[key] for key in sorted(intensities.keys()) + ] + intensities = intensities[::subsampling] + for sl, sc in zip(range(len(sinogram)), intensities): + if sc is None: + _logger.warning( + "Intensity not found for line {sl}. won't normalize this line" + ) + sinogram[sl] = sinogram[sl] + else: + sinogram[sl] = sinogram[sl] - sc + return sinogram + else: + raise KeyError( + f"{norm_method.value} requires a value or an url to be computed" + ) else: raise ValueError("norm method not handled", norm_method) diff --git a/tomoscan/test/test_normalization.py b/tomoscan/test/test_normalization.py index ff56abc91d6293cff6efe01bd3399c7b77c566ac..29404475645266f6f08389c7a6a2862c294da222 100644 --- a/tomoscan/test/test_normalization.py +++ b/tomoscan/test/test_normalization.py @@ -45,7 +45,8 @@ else: has_scipy = True -def test_normalization_scalar_normalization(): +@pytest.mark.parametrize("method", ["subtraction", "division"]) +def test_normalization_scalar_normalization(method): """test scalar normalization""" with HDF5MockContext( scan_path=os.path.join(tempfile.mkdtemp(), "scan_test"), @@ -53,30 +54,9 @@ def test_normalization_scalar_normalization(): n_ini_proj=10, ) as scan: with pytest.raises(KeyError): - scan.get_sinogram(line=2, norm_method="scalar") + scan.get_sinogram(line=2, norm_method=method) - scan.get_sinogram(line=2, norm_method="scalar", value=12.2) - - -def test_normalization_intensity_monitor_normalization(): - """test scalar normalization""" - scan_path = os.path.join(tempfile.mkdtemp(), "scan_test") - with HDF5MockContext( - scan_path=scan_path, - n_proj=10, - n_ini_proj=10, - ) as scan: - with pytest.raises(ValueError): - scan.get_sinogram(line=2, norm_method="intensity monitor") - - scan_path = os.path.join(tempfile.mkdtemp(), "scan_test") - with HDF5MockContext( - scan_path=scan_path, - n_proj=10, - n_ini_proj=10, - intensity_monitor=True, - ) as scan: - scan.get_sinogram(line=2, norm_method="intensity monitor") + scan.get_sinogram(line=2, norm_method=method, value=12.2) def test_normalize_chebyshev_2D(): @@ -92,7 +72,7 @@ def test_normalize_chebyshev_2D(): assert numpy.array_equal(sinogram, sinogram_2) -@pytest.mark.skip(condition=not has_scipy, reason="scipy missing") +@pytest.mark.skipif(condition=not has_scipy, reason="scipy missing") def test_normalize_lsqr_spline_2D(): """test lsqr_spline_2D normalization""" with HDF5MockContext(