Commit 74b1c6c1 authored by Henri Payno's avatar Henri Payno
Browse files

normalization: update to fit nabu approach

# Conflicts:
#	tomwer/core/scan/scanbase.py
parent baf41912
......@@ -116,7 +116,6 @@ class NormIOW(WidgetLongProcessing, SuperviseOW):
description = "Define normalization on intensity to be applied on projections"
icon = "icons/norm_I.svg"
priority = 28
category = "esrfWidgets"
keywords = [
"tomography",
"normalization",
......
......@@ -46,11 +46,17 @@ class IntensityNormalizationThread(qt.QThread):
self._result = None
def run(self) -> None:
process = IntensityNormalizationTask(process_id=None)
process.set_properties(self._configuration)
self._result = process.process(
self.scan
).intensity_normalization.tomwer_processing_res
process = IntensityNormalizationTask(
process_id=None,
inputs={
"data": self.scan,
"configuration": self._configuration,
},
varinfo=None,
)
process.run()
self._result = self.scan.intensity_normalization.tomwer_processing_res
class NormIntensityWindow(_NormIntensityWindow):
......
......@@ -41,6 +41,7 @@ from .params import (
_CalculationArea,
_ValueCalculationMethod,
_ValueCalculationFct,
_ValueSource,
)
from processview.core.superviseprocess import SuperviseProcess
from silx.io.url import DataUrl
......@@ -85,6 +86,8 @@ class IntensityNormalizationTask(
)
SuperviseProcess.__init__(self, process_id=process_id)
self._dry_run = False
if "configuration" in inputs:
self.set_configuration(inputs["configuration"])
def set_properties(self, properties):
if isinstance(properties, IntensityNormalizationParams):
......@@ -102,23 +105,24 @@ class IntensityNormalizationTask(
scan.intensity_normalization.tomwer_processing_res_code = None
params = IntensityNormalizationParams.from_dict(self._settings)
try:
if params.method is Method.MANUAL_ROI:
if params.source is _ValueSource.MANUAL_ROI:
res = self._compute_from_manual_roi(scan)
need_conversion_to_tomoscan = True
elif params.method is Method.AUTO_ROI:
# need_conversion_to_tomoscan = True
elif params.source is _ValueSource.AUTO_ROI:
res = self._compute_from_automatic_roi(scan)
need_conversion_to_tomoscan = True
elif params.method is Method.DATASET:
# need_conversion_to_tomoscan = True
elif params.source is _ValueSource.DATASET:
res = self._compute_from_dataset()
need_conversion_to_tomoscan = True
# need_conversion_to_tomoscan = True
elif params.method.value in TomoScanMethod.values():
need_conversion_to_tomoscan = False
# need_conversion_to_tomoscan = False
res = None
else:
raise ValueError("method {} is not handled".format(params.method))
except Exception as e:
_logger.error(e)
scan.intensity_normalization.tomwer_processing_res_code = False
res = None
else:
scan.intensity_normalization.tomwer_processing_res_code = True
# insure this could be hashable (for caches)
......@@ -126,13 +130,13 @@ class IntensityNormalizationTask(
res = tuple(res)
scan.intensity_normalization.tomwer_processing_res = res
if need_conversion_to_tomoscan:
results_to_tomoscan_norm(
scan=scan,
method=params.method,
results=res,
extra_infos=params.extra_infos,
)
# if need_conversion_to_tomoscan:
# results_to_tomoscan_norm(
# scan=scan,
# method=params.method,
# results=res,
# extra_infos=params.extra_infos,
# )
self.outputs.data = scan
......@@ -420,27 +424,35 @@ class IntensityNormalizationTask(
return "Normalize intensity."
def results_to_tomoscan_norm(
scan,
method: Method,
results: typing.Union[numpy.ndarray, float, None],
extra_infos: dict,
):
"""
Util function to copy results from tomwer normalization to tomoscan
normalization parameters
:param TomwerScanBase scan:
:param Method method:
:param dict results:
:param dict extra_infos:
:return:
"""
if method in TomoScanMethod.values():
scan.intensity_normalization = method
scan.intensity_normalization.set_extra_infos(extra_infos)
else:
scan.intensity_normalization = TomoScanMethod.SCALAR
if results is not None:
extra_infos["value"] = results
scan.intensity_normalization.set_extra_infos(extra_infos)
# def results_to_tomoscan_norm(
# scan,
# method: Method,
# results: typing.Union[numpy.ndarray, float, None],
# extra_infos: dict,
# ):
# """
# Util function to copy results from tomwer normalization to tomoscan
# normalization parameters
# :param TomwerScanBase scan:
# :param Method method:
# :param dict results:
# :param dict extra_infos:
# :return:
# """
# print("results to tomo scan norm will be")
# print(scan)
# print("method is", method)
# print("results is", results)
# print("extra_infos is", extra_infos)
# scan.intensity_normalization = method
# extra_infos["value"] = results
# scan.intensity_normalization.set_extra_infos(extra_infos)
# # if method in TomoScanMethod.values():
# # scan.intensity_normalization = method
# # scan.intensity_normalization.set_extra_infos(extra_infos)
# # else:
# # scan.intensity_normalization = TomoScanMethod.SCALAR
# # if results is not None:
# # extra_infos["value"] = results
# # scan.intensity_normalization.set_extra_infos(extra_infos)
......@@ -35,10 +35,20 @@ __date__ = "25/06/2021"
from silx.utils.enum import Enum as _Enum
from tomwer.core.scan.scanbase import NormMethod as Method
from tomoscan.normalization import Method
import typing
class _ValueSource(_Enum):
MONITOR = "intensity monitor"
MANUAL_ROI = "manual ROI"
AUTO_ROI = "automatic ROI"
DATASET = "from dataset"
MANUAL_SCALAR = "scalar"
NONE = "none"
class _ValueCalculationFct(_Enum):
MEAN = "mean"
MEDIAN = "median"
......@@ -101,7 +111,7 @@ class _ROIInfo:
class IntensityNormalizationParams:
"""Information regarding the intensity normalization to be done"""
def __init__(self, method=Method.NONE, extra_infos=None):
def __init__(self, method=Method.NONE, source=_ValueSource.NONE, extra_infos=None):
if not isinstance(method, (str, Method)):
raise TypeError(
"method is expected to be a str or an instance " "of Method"
......@@ -113,6 +123,7 @@ class IntensityNormalizationParams:
)
self._method = Method.from_value(method)
self._extra_infos = extra_infos if extra_infos is not None else {}
self._source = _ValueSource.from_value(source)
@property
def method(self):
......@@ -124,6 +135,16 @@ class IntensityNormalizationParams:
method = Method.NONE
self._method = Method.from_value(method)
@property
def source(self):
return self._source
@source.setter
def source(self, source):
if source is None:
source = _ValueSource.NONE
self._source = _ValueSource.from_value(source)
@property
def extra_infos(self):
return self._extra_infos
......@@ -139,7 +160,8 @@ class IntensityNormalizationParams:
def to_dict(self):
_dict = self._extra_infos
_dict["method"] = self._method.value
_dict["method"] = self.method.value
_dict["source"] = self.source.value
return _dict
@staticmethod
......@@ -153,4 +175,7 @@ class IntensityNormalizationParams:
if "method" in tmp_dict:
self.method = tmp_dict["method"]
del tmp_dict["method"]
if "source" in tmp_dict:
self.source = tmp_dict["source"]
del tmp_dict["source"]
self.extra_infos = tmp_dict
......@@ -28,7 +28,7 @@ __date__ = "11/07/2021"
from tomwer.core.process.reconstruction.normalization import params, normalization
from tomwer.core.scan.scanbase import NormMethod as Method
from tomoscan.normalization import Method
from tomwer.core.utils.scanutils import MockHDF5
import unittest
import tempfile
......@@ -70,9 +70,9 @@ class TestNormalization(unittest.TestCase):
100 * 100 * 2
).reshape(2, 100, 100)
process = normalization.IntensityNormalizationTask()
process_params = normalization.IntensityNormalizationParams()
process_params.method = Method.MANUAL_ROI
process_params.method = Method.SUBSTRACTION
process_params.source = params._ValueSource.MANUAL_ROI
expected_results = {
"mean": {
"scalar": 5800.5,
......@@ -96,11 +96,14 @@ class TestNormalization(unittest.TestCase):
"calc_method": calc_method,
"calc_area": "volume",
}
process.set_properties(process_params)
res = process.process(
self.scan
).intensity_normalization.tomwer_processing_res
process = normalization.IntensityNormalizationTask(
inputs={
"data": self.scan,
"configuration": process_params,
}
)
process.run()
res = process.results.intensity_normalization.tomwer_processing_res
if isinstance(res, numpy.ndarray):
numpy.testing.assert_array_equal(
res, expected_results[calc_fct][calc_method]
......
......@@ -42,7 +42,6 @@ from tomoscan.io import HDF5File
from processview.core.dataset import Dataset
from typing import Optional
from silx.utils.enum import Enum as _Enum
from tomoscan.normalization import Method as _RawMethod
logger = logging.getLogger(__name__)
......@@ -824,10 +823,6 @@ def _is_reconstructed_slice_file(
class NormMethod(_Enum):
NONE = _RawMethod.NONE.value
MANUAL_SCALAR = _RawMethod.SCALAR.value
CHEBYSHEV = _RawMethod.CHEBYSHEV.value
LSQR_SPLINE = _RawMethod.LSQR_SPLINE.value
MANUAL_ROI = "manual ROI"
AUTO_ROI = "automatic ROI"
DATASET = "from dataset"
......@@ -47,6 +47,7 @@ from tomwer.core.scan.scanbase import TomwerScanBase
from tomwer.gui.visualization.sinogramviewer import SinogramViewer as _SinogramViewer
from tomwer.core.process.reconstruction.normalization.normalization import Method
from tomwer.core.process.reconstruction.normalization import params as _normParams
from tomwer.core.process.reconstruction.normalization.params import _ValueSource
from tomoscan.normalization import Method as TomoScanMethod
from tomwer.gui.utils.buttons import PadlockButton
import weakref
......@@ -86,9 +87,10 @@ class NormIntensityWindow(qt.QMainWindow):
self.addDockWidget(qt.Qt.RightDockWidgetArea, self._dockWidgetCtrl)
# connect signal / slot
self._optsWidget.sigModeChanged.connect(self._modeChanged)
# self._optsWidget.sigModeChanged.connect(self._modeChanged)
self._optsWidget.sigValueUpdated.connect(self.setResult)
self._optsWidget.sigConfigurationChanged.connect(self._configurationChanged)
self._optsWidget.sigSourceChanged.connect(self._sourceChanged)
self._crtWidget.sigValidateRequest.connect(self._validated)
# set up
......@@ -113,17 +115,31 @@ class NormIntensityWindow(qt.QMainWindow):
def getCurrentMethod(self):
return self._optsWidget.getCurrentMethod()
def getCurrentSource(self):
return self._optsWidget.getCurrentSource()
def _modeChanged(self):
self._sourceChanged()
def _sourceChanged(self):
source = self.getCurrentSource()
method = self.getCurrentMethod()
scan = self.getScan()
methods_using_manual_roi = (Method.DIVISION, Method.SUBSTRACTION)
self._centralWidget.setManualROIVisible(
self.getCurrentMethod() == Method.MANUAL_ROI
source is _ValueSource.MANUAL_ROI and method in methods_using_manual_roi
)
scan = self.getScan()
if scan:
# if the normed sinogram can be obtained `directly` from tomoscan
if self.getCurrentMethod().value in TomoScanMethod.values():
scan.intensity_normalization = self.getCurrentMethod().value
scan.intensity_normalization.set_extra_infos(self.getExtraArgs())
scan.intensity_normalization.tomwer_processing_res_code = True
methods_requesting_calculation = (Method.DIVISION, Method.SUBSTRACTION)
if method in methods_requesting_calculation:
# if the normed sinogram can be obtained `directly`
if source in (_ValueSource.MANUAL_SCALAR, _ValueSource.DATASET):
scan.intensity_normalization = self.getCurrentMethod().value
scan.intensity_normalization.set_extra_infos(self.getExtraArgs())
scan.intensity_normalization.tomwer_processing_res_code = True
self._centralWidget._updateSinogramROI()
def getScan(self):
if self._scan is not None:
......@@ -199,10 +215,15 @@ class _Viewer(qt.QTabWidget):
self._sinoView.setScan(scan, update=False)
def _updateSinogramROI(self):
display_sino_roi = self.parent().getCurrentMethod()
if display_sino_roi != Method.MANUAL_ROI:
self._sinoView.setROIVisible(False)
else:
source = self.parent().getCurrentSource()
method = self.parent().getCurrentMethod()
display_sino_roi = source is _ValueSource.MANUAL_ROI and method in (
Method.DIVISION,
Method.SUBSTRACTION,
)
if display_sino_roi:
roi = self._projView.getROI()
sinogram_line = self._sinoView.getLine()
y_min = roi.getOrigin()[1]
......@@ -214,6 +235,8 @@ class _Viewer(qt.QTabWidget):
self._sinoView.setROIVisible(True)
else:
self._sinoView.setROIVisible(False)
else:
self._sinoView.setROIVisible(False)
def setManualROIVisible(self, visible):
self._projView.setManualROIVisible(visible=visible)
......@@ -241,7 +264,7 @@ class _ProjPlotWithROI(DataViewer):
"""signal emit when ROI change"""
def __init__(self, *args, **kwargs):
DataViewer.__init__(self, *args, **kwargs)
DataViewer.__init__(self, *args, **kwargs, show_overview=False)
self._sinogramLine = 0
self._roiVisible = False
self.setScanInfoVisible(False)
......@@ -390,6 +413,9 @@ class _NormIntensityOptions(qt.QWidget):
sigModeChanged = qt.Signal()
"""signal emitted when the mode change"""
sigSourceChanged = qt.Signal()
"""signal emitted when the source change"""
sigValueUpdated = qt.Signal(object)
"""Signal emit when user defines manually the value"""
......@@ -407,13 +433,7 @@ class _NormIntensityOptions(qt.QWidget):
# mode
self._modeCB = qt.QComboBox(self)
for mode in Method:
if mode in (
# Method.LSQR_SPLINE,
# Method.MANUAL_SCALAR,
# Method.DATASET,
# Method.AUTO_ROI,
# Method.MANUAL_ROI,
):
if mode in (Method.LSQR_SPLINE,):
continue
else:
self._modeCB.addItem(mode.value)
......@@ -422,12 +442,23 @@ class _NormIntensityOptions(qt.QWidget):
self._lockButton = PadlockButton(self)
self._lockButton.setFixedWidth(25)
self.layout().addWidget(self._lockButton, 0, 2, 1, 1)
# source
self._sourceCB = qt.QComboBox(self)
for mode in _ValueSource:
if mode == _ValueSource.NONE:
# filter this value because does not have much sense for the GUI
continue
if mode == _ValueSource.AUTO_ROI:
continue
self._sourceCB.addItem(mode.value)
self._sourceLabel = qt.QLabel("source:", self)
self.layout().addWidget(self._sourceLabel, 1, 0, 1, 1)
self.layout().addWidget(self._sourceCB, 1, 1, 1, 1)
# method
self._optsMethod = qt.QGroupBox(self)
self._optsMethod.setTitle("options")
self._optsMethod.setLayout(qt.QVBoxLayout())
self.layout().addWidget(self._optsMethod, 1, 0, 1, 3)
self.layout().addWidget(self._optsMethod, 2, 0, 1, 3)
# intensity calculation options
self._intensityCalcOpts = _NormIntensityCalcOpts(self)
self._optsMethod.layout().addWidget(self._intensityCalcOpts)
......@@ -436,7 +467,7 @@ class _NormIntensityOptions(qt.QWidget):
self._optsMethod.layout().addWidget(self._datasetWidget)
# scalar value
self._scalarValueWidget = _NormIntensityScalarValue(self)
self.layout().addWidget(self._scalarValueWidget, 2, 0, 1, 3)
self.layout().addWidget(self._scalarValueWidget, 3, 0, 1, 3)
# buttons
self._buttonsGrp = qt.QWidget(self)
self._buttonsGrp.setLayout(qt.QGridLayout())
......@@ -446,13 +477,15 @@ class _NormIntensityOptions(qt.QWidget):
spacer = qt.QWidget(self)
spacer.setSizePolicy(qt.QSizePolicy.Expanding, qt.QSizePolicy.Minimum)
self._buttonsGrp.layout().addWidget(spacer)
self.layout().addWidget(self._buttonsGrp, 3, 0, 1, 3)
self.layout().addWidget(self._buttonsGrp, 4, 0, 1, 3)
self._modeChanged()
# connect signal / slot
self._modeCB.currentIndexChanged.connect(self._modeChanged)
self._modeCB.currentIndexChanged.connect(self._configurationChanged)
self._sourceCB.currentIndexChanged.connect(self._sourceChanged)
self._sourceCB.currentIndexChanged.connect(self._configurationChanged)
self._computeButton.released.connect(self._computationRequested)
self._scalarValueWidget.sigValueChanged.connect(self._valueUpdated)
self._scalarValueWidget.sigValueChanged.connect(self._configurationChanged)
......@@ -466,6 +499,33 @@ class _NormIntensityOptions(qt.QWidget):
def _configurationChanged(self):
self.sigConfigurationChanged.emit()
def _sourceChanged(self):
source = self.getCurrentSource()
method = self.getCurrentMethod()
interactive_methods = (Method.DIVISION, Method.SUBSTRACTION)
interactive_sources = (
_ValueSource.MANUAL_ROI,
_ValueSource.AUTO_ROI,
_ValueSource.DATASET,
)
self._intensityCalcOpts._calculationAreaCB.setVisible(
source in (_ValueSource.MANUAL_ROI,)
)
self._intensityCalcOpts._calculationAreaLabel.setVisible(
source in (_ValueSource.MANUAL_ROI,)
)
self._datasetWidget.setVisible(source == _ValueSource.DATASET)
self.setManualROIVisible(source == _ValueSource.MANUAL_ROI)
self._optsMethod.setVisible(
method in interactive_methods and source in interactive_sources
)
self._scalarValueWidget.setVisible(source == _ValueSource.MANUAL_SCALAR)
self._buttonsGrp.setVisible(
method in interactive_methods and source in interactive_sources
)
self.sigSourceChanged.emit()
def _lockChanged(self):
self._scalarValueWidget.setEnabled(not self.isLocked())
self._datasetWidget.setEnabled(not self.isLocked())
......@@ -490,27 +550,23 @@ class _NormIntensityOptions(qt.QWidget):
idx = self._modeCB.findText(method.value)
self._modeCB.setCurrentIndex(idx)
def getCurrentSource(self):
return _ValueSource.from_value(self._sourceCB.currentText())
def setCurrentSource(self, source):
source = _ValueSource.from_value(source)
idx = self._sourceCB.findText(source.value)
self._sourceCB.setCurrentIndex(idx)
def _modeChanged(self, *args, **kwargs):
mode = self.getCurrentMethod()
self._intensityCalcOpts.setVisible(mode in (Method.MANUAL_ROI, Method.DATASET))
self._intensityCalcOpts._calculationAreaCB.setVisible(
mode in (Method.MANUAL_ROI,)
)
self._intensityCalcOpts._calculationAreaLabel.setVisible(
mode in (Method.MANUAL_ROI,)
)
self._datasetWidget.setVisible(mode == Method.DATASET)
self.setManualROIVisible(mode == Method.MANUAL_ROI)
self._optsMethod.setVisible(
mode in (Method.MANUAL_ROI, Method.AUTO_ROI, Method.DATASET)
)
self._scalarValueWidget.setVisible(mode == Method.MANUAL_SCALAR)
self._buttonsGrp.setVisible(
mode in (Method.AUTO_ROI, Method.DATASET, Method.MANUAL_ROI)
)
self.sigValueCanBeLocked.emit(mode == Method.MANUAL_SCALAR)
mode_with_calculations = (Method.DIVISION, Method.SUBSTRACTION)
self._intensityCalcOpts.setVisible(mode in mode_with_calculations)
self._sourceCB.setVisible(mode in mode_with_calculations)
self._sourceLabel.setVisible(mode in mode_with_calculations)
self._sourceChanged()
self.sigModeChanged.emit()
self.sigSourceChanged.emit()
def _valueUpdated(self, *args):
self.sigValueUpdated.emit(args)
......@@ -521,6 +577,7 @@ class _NormIntensityOptions(qt.QWidget):
def getConfiguration(self) -> dict:
return _normParams.IntensityNormalizationParams(
method=self.getCurrentMethod(),
source=self.getCurrentSource(),
extra_infos=self.getExtraInfos(),
).to_dict()
......@@ -546,7 +603,7 @@ class _NormIntensityOptions(qt.QWidget):
self._intensityCalcOpts.setCalculationArea(extra_infos["calc_area"])
if "calc_method" in extra_infos:
self._intensityCalcOpts.setCalculationMethod(extra_infos["calc_method"])
if params.method is Method.MANUAL_SCALAR:
if params.source is _ValueSource.MANUAL_SCALAR:
if "value" in extra_infos:
self._scalarValueWidget.setValue(extra_infos["value"])
......@@ -555,30 +612,34 @@ class _NormIntensityOptions(qt.QWidget):
def getExtraInfos(self):
method = self.getCurrentMethod()
if method is Method.MANUAL_SCALAR:
return {"value": self._scalarValueWidget.getValue()}
elif method is Method.AUTO_ROI:
raise NotImplementedError("auto roi not implemented yet")
elif method is Method.MANUAL_ROI:
roi = self._getROI()
return {
"start_x": roi.getOrigin()[0],
"end_x": roi.getOrigin()[0] + roi.getSize()[0],
"start_y": roi.getOrigin()[1],
"end_y": roi.getOrigin()[1] + roi.getSize()[1],
"calc_fct": self._intensityCalcOpts.getCalculationFct().value,
"calc_area": self._intensityCalcOpts.getCalculationArea().value,
"calc_method": self._intensityCalcOpts.getCalculationMethod().value,
}
elif method is Method.DATASET:
return {
"calc_fct": self._intensityCalcOpts.getCalculationFct().value,
"calc_area": self._intensityCalcOpts.getCalculationArea().value,
"calc_method": self._intensityCalcOpts.getCalculationMethod().value,
"dataset_url": self._datasetWidget.getDatasetUrl().path(),
}
else:
source = self.getCurrentSource()
if method in (Method.CHEBYSHEV, Method.NONE):
return {}
else:
if source is _ValueSource.MANUAL_SCALAR:
return {"value": self._scalarValueWidget.getValue()}
elif source is _ValueSource.AUTO_ROI:
raise NotImplementedError("auto roi not implemented yet")
elif source is _ValueSource.MANUAL_ROI:
roi = self._getROI()
return {
"start_x": roi.getOrigin()[0],