Commit 7eed11f3 authored by payno, committed by Henri Payno

normalization: clean up things

* remove calculation area (volume for now)
* remove calculation method (one value per frame...)
parent 672ffab8
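In practice this clean-up shrinks what the intensity-normalization task reads from its extra_infos for a manual ROI: "calc_area" and "calc_method" are no longer validated or used, only "calc_fct" remains together with the ROI bounds. A minimal sketch of that configuration change, with made-up ROI coordinates (the "start_x"/"end_x" keys are assumed, they sit outside the visible hunks):

# Sketch only: illustrative extra_infos for the manual-ROI source,
# before and after this commit. ROI bounds are example values.

extra_infos_before = {
    "start_x": 0, "end_x": 128,   # hypothetical ROI bounds
    "start_y": 0, "end_y": 128,
    "calc_fct": "median",         # "mean" or "median"
    "calc_area": "volume",        # removed by this commit
    "calc_method": "scalar",      # removed by this commit
}

extra_infos_after = {
    "start_x": 0, "end_x": 128,
    "start_y": 0, "end_y": 128,
    "calc_fct": "median",         # one value is now computed per frame
}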
@@ -37,8 +37,6 @@ __date__ = "25/06/2021"
 from tomwer.core.process.task import Task
 from .params import (
     IntensityNormalizationParams,
-    _CalculationArea,
-    _ValueCalculationMethod,
     _ValueCalculationFct,
     _ValueSource,
 )
@@ -167,20 +165,6 @@ class IntensityNormalizationTask(
             raise ValueError("calc_fct should be provided")
         else:
             calc_fct = _ValueCalculationFct.from_value(calc_fct)
-        calc_area = extra_info.get("calc_area", None)
-        if calc_area is None:
-            raise ValueError("calc_area should be provided")
-        else:
-            calc_area = _CalculationArea.from_value(calc_area)
-        calc_method = extra_info.get("calc_method", None)
-        if calc_method is None:
-            raise ValueError("calc_method should be provided")
-        else:
-            calc_method = _ValueCalculationMethod.from_value(calc_method)
-        # # For now we removed the _CalculationArea.PROJECTION option
-        # if calc_method is _ValueCalculationMethod.ONE_PER_FRAME and calc_area is _CalculationArea.PROJECTION:
-        #     raise ValueError("You cannot get one value per frame and "
-        #                      "computing ROI on a single projection")
         try:
             value = self._cache_compute_from_manual_roi(
@@ -189,8 +173,7 @@ class IntensityNormalizationTask(
                 start_y=start_y,
                 end_x=end_x,
                 end_y=end_y,
-                calc_area=calc_area,
-            )[calc_method.value][calc_fct.value]
+            )[calc_fct.value]
         except Exception as e:
             _logger.error(e)
             return None
@@ -207,7 +190,6 @@ class IntensityNormalizationTask(
         end_x,
         start_y,
         end_y,
-        calc_area: _CalculationArea,
     ) -> dict:
         """
         compute mean and median on a volume or a slice.
@@ -237,193 +219,133 @@ class IntensityNormalizationTask(
         end_y = max(0, end_y)
         scan = dataset_identifier.recreate_dataset()
         projections = scan.projections
-        if calc_area is _CalculationArea.VOLUME:
         roi_area = numpy.zeros(
             (len(projections), int(end_y - start_y), int(end_x - start_x))
         )
         # as dataset can be large we try to split the load of
         # the data
         proj_compacted = tomoscan.esrf.utils.get_compacted_dataslices(
             projections,
             max_grp_size=20,
         )
         proj_indexes = sorted(proj_compacted.keys())
         # small hack to avoid loading several time the same url
         url_treated = set()

         def url_has_been_treated(url: DataUrl):
             return (
                 url.file_path(),
                 url.data_path(),
                 url.data_slice().start,
                 url.data_slice().stop,
                 url.data_slice().step,
                 url.scheme(),
             ) in url_treated

         def append_url(url: DataUrl):
             url_treated.add(
                 (
                     url.file_path(),
                     url.data_path(),
                     url.data_slice().start,
                     url.data_slice().stop,
                     url.data_slice().step,
                     url.scheme(),
                 )
             )

         current_idx = 0
         url_idxs = {v.path(): k for k, v in scan.projections.items()}
         for proj_index in proj_indexes:
             url = proj_compacted[proj_index]
             if url_has_been_treated(url):
                 continue
             append_url(url)
             data = silx.io.get_data(url)
             if data.ndim < 2:
                 raise ValueError("data is expected to be at least 2D")
             # clamp ROI with frame size
             start_x = min(data.shape[-1], start_x)
             start_y = min(data.shape[-2], start_y)
             end_x = min(data.shape[-1], end_x)
             end_y = min(data.shape[-2], end_y)

             def retrieve_data_proj_indexes(url_):
                 urls = []
                 for slice in range(
                     url_.data_slice().start,
                     url_.data_slice().stop,
                     url_.data_slice().step or 1,
                 ):
                     urls.append(
                         DataUrl(
                             file_path=url_.file_path(),
                             data_path=url_.data_path(),
                             scheme=url_.scheme(),
                             data_slice=slice,
                         )
                     )
                 # try to retrieve the index from the projections else
                 # keep the slice index as the frame index (should be the
                 # case in most case
                 res = []
                 for my_url in urls:
                     my_url_path = my_url.path()
                     if my_url_path in url_idxs:
                         res.append(url_idxs[my_url_path])
                     else:
                         _logger.warning(
                             "unable to retrieve frame index from url {}. "
                             "Take the slice index as frame index".format(my_url_path)
                         )
                 return res

             data_indexes = retrieve_data_proj_indexes(url)
             # apply flat field correction
             if data.ndim == 2:
                 projs = (data,)
             else:
                 projs = list(data)
             data = scan.flat_field_correction(projs=projs, proj_indexes=data_indexes)
             if data is None:
                 continue
             data = numpy.asarray(data)
             if data.ndim == 2:
                 roi_area[current_idx] = data[start_y:end_y, start_x:end_x]
                 current_idx += 1
             elif data.ndim == 3:
                 length = data.shape[0]
                 roi_area[current_idx : current_idx + length] = data[
                     :, start_y:end_y, start_x:end_x
                 ]
                 current_idx += length
             else:
                 raise ValueError(
                     "Frame where expected and not a " "{}D object".format(data.ndim)
                 )
-        else:
-            raise ValueError("{} is not handled".format(calc_area))
         return IntensityNormalizationTask.compute_stats(roi_area)
     @staticmethod
     def compute_stats(data):
         results = {}
-        for calc_method in _ValueCalculationMethod:
-            results[calc_method.value] = {}
-            for calc_fct in _ValueCalculationFct:
-                if calc_method is _ValueCalculationMethod.SCALAR_VALUE:
-                    res = getattr(numpy, calc_fct.value)(data)
-                elif calc_method is _ValueCalculationMethod.ONE_PER_FRAME:
-                    if data.ndim == 3:
-                        res = getattr(numpy, calc_fct.value)(data, axis=(-2, -1))
-                    elif data.ndim == 2:
-                        res = getattr(numpy, calc_fct.value)(data, axis=(-1))
-                    elif data.ndim in (0, 1):
-                        res = data
-                    else:
-                        raise ValueError(
-                            "dataset dimension not handled ({})".format(data.ndim)
-                        )
-                else:
-                    raise ValueError("{} is not handled".format(calc_method.value))
-                results[calc_method.value][calc_fct.value] = res
+        for calc_fct in _ValueCalculationFct:
+            if data.ndim == 3:
+                res = getattr(numpy, calc_fct.value)(data, axis=(-2, -1))
+            elif data.ndim == 2:
+                res = getattr(numpy, calc_fct.value)(data, axis=(-1))
+            elif data.ndim in (0, 1):
+                res = data
+            else:
+                raise ValueError("dataset dimension not handled ({})".format(data.ndim))
+            results[calc_fct.value] = res
         return results
     def _compute_from_automatic_roi(self, scan):
         raise NotImplementedError("Not implemented yet")
-    def _compute_from_dataset(self):
-        params = IntensityNormalizationParams.from_dict(self.get_configuration())
-        extra_info = params.extra_infos
-        calc_fct = extra_info.get("calc_fct", None)
-        if calc_fct is None:
-            raise ValueError("calc_fct should be provided")
-        else:
-            calc_fct = _ValueCalculationFct.from_value(calc_fct)
-        calc_area = extra_info.get("calc_area", None)
-        if calc_area is None:
-            raise ValueError("calc_area should be provided")
-        else:
-            calc_area = _CalculationArea.from_value(calc_area)
-        calc_method = extra_info.get("calc_method", None)
-        if calc_method is None:
-            raise ValueError("calc_method should be provided")
-        else:
-            calc_method = _ValueCalculationMethod.from_value(calc_method)
-        url = extra_info.get("dataset_url", None)
-        if url is None:
-            raise ValueError("dataset_url should be provided")
-        elif isinstance(url, DataUrl):
-            url = url.path()
-        try:
-            return self._cache_compute_from_dataset(calc_area=calc_area, url=url,)[
-                calc_method.value
-            ][calc_fct.value]
-        except Exception as e:
-            _logger.error(e)
-            return None
-
-    @staticmethod
-    @functools.lru_cache(
-        maxsize=6
-    )  # maxsize=6 to at most keep info at volume and at frame level for 3 scans
-    def _cache_compute_from_dataset(calc_area: _CalculationArea, url: str) -> dict:
-        if calc_area is _CalculationArea.VOLUME:
-            url = DataUrl(url)
-            data = get_data(url)
-            return IntensityNormalizationTask.compute_stats(data=data)
-        else:
-            raise ValueError("{} is not managed".format(calc_area))
     @staticmethod
     def program_name():
         """Name of the program used for this processing"""
@@ -54,21 +54,11 @@ class _ValueCalculationFct(_Enum):
     MEDIAN = "median"


-class _ValueCalculationMethod(_Enum):
-    ONE_PER_FRAME = "one value per fame"
-    SCALAR_VALUE = "scalar"
-
-
 class _DatasetScope(_Enum):
     LOCAL = "local"
     GLOBAL = "global"


-class _CalculationArea(_Enum):
-    VOLUME = "volume"
-    # PROJECTION = "projection"  # for now just do it on the entire volume
-
-
 class _DatasetInfos:
     def __init__(self):
         self._scope = _DatasetScope.GLOBAL
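The surviving _ValueCalculationFct enum is what drives both the GUI combo box and compute_stats, since its values double as numpy function names resolved with getattr. A small sketch of that mechanism, assuming a MEAN = "mean" member defined alongside MEDIAN = "median" (it sits above the visible hunk):

import numpy

# Sketch only: enum values are used as numpy function names.
for fct_name in ("mean", "median"):        # assumed _ValueCalculationFct values
    fct = getattr(numpy, fct_name)         # numpy.mean / numpy.median
    print(fct_name, fct([[1, 2], [3, 4]], axis=-1))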
@@ -514,17 +514,16 @@ class _NormIntensityOptions(qt.QWidget):
             _ValueSource.AUTO_ROI,
             _ValueSource.DATASET,
         )
-        self._intensityCalcOpts._calculationAreaCB.setVisible(
-            source in (_ValueSource.MANUAL_ROI,)
-        )
-        self._intensityCalcOpts._calculationAreaLabel.setVisible(
-            source in (_ValueSource.MANUAL_ROI,)
-        )
         self._datasetWidget.setVisible(source == _ValueSource.DATASET)
         self.setManualROIVisible(source == _ValueSource.MANUAL_ROI)
         self._optsMethod.setVisible(
             method in interactive_methods and source in interactive_sources
         )
+        self._intensityCalcOpts.setCalculationFctVisible(
+            method in interactive_methods
+            and source in interactive_sources
+            and source is not _ValueSource.DATASET
+        )
         self._scalarValueWidget.setVisible(source == _ValueSource.MANUAL_SCALAR)
         self._buttonsGrp.setVisible(
             method in interactive_methods and source in interactive_sources
@@ -637,8 +636,6 @@ class _NormIntensityOptions(qt.QWidget):
                 "start_y": roi.getOrigin()[1],
                 "end_y": roi.getOrigin()[1] + roi.getSize()[1],
                 "calc_fct": self._intensityCalcOpts.getCalculationFct().value,
-                "calc_area": self._intensityCalcOpts.getCalculationArea().value,
-                "calc_method": self._intensityCalcOpts.getCalculationMethod().value,
             }
         elif source is _ValueSource.DATASET:
             return {
@@ -664,26 +661,15 @@ class _NormIntensityCalcOpts(qt.QWidget):
         self._calculationModeCB = qt.QComboBox(self)
         for fct in _normParams._ValueCalculationFct.values():
             self._calculationModeCB.addItem(fct)
-        self.layout().addRow("calculation fct", self._calculationModeCB)
-        # calculation depth
-        self._calculationAreaCB = qt.QComboBox(self)
-        for area in _normParams._CalculationArea.values():
-            self._calculationAreaCB.addItem(area)
-        self._calculationAreaLabel = qt.QLabel("calculation area", self)
-        self.layout().addRow(self._calculationAreaLabel, self._calculationAreaCB)
-        # calculation 'axis'
-        self._calculationMethodCB = qt.QComboBox(self)
-        for opt in _normParams._ValueCalculationMethod.values():
-            self._calculationMethodCB.addItem(opt)
-        self.layout().addRow("calculation method", self._calculationMethodCB)
+        self._calculationModeLabel = qt.QLabel("calculation fct", self)
+        self.layout().addRow(self._calculationModeLabel, self._calculationModeCB)

         # connect signal / slot
-        self._calculationAreaCB.currentTextChanged.connect(self._areaChanged)
         self._calculationModeCB.currentIndexChanged.connect(self._configurationChanged)
-        self._calculationAreaCB.currentIndexChanged.connect(self._configurationChanged)
-        self._calculationMethodCB.currentIndexChanged.connect(
-            self._configurationChanged
-        )
+
+    def setCalculationFctVisible(self, visible):
+        self._calculationModeLabel.setVisible(visible)
+        self._calculationModeCB.setVisible(visible)

     def _configurationChanged(self):
         self.sigConfigurationChanged.emit()
@@ -699,38 +685,6 @@ class _NormIntensityCalcOpts(qt.QWidget):
         )
         self._calculationModeCB.setCurrentIndex(idx)

-    def getCalculationArea(self):
-        return _normParams._CalculationArea.from_value(
-            self._calculationAreaCB.currentText()
-        )
-
-    def setCalculationArea(self, area):
-        idx = self._calculationAreaCB.findText(
-            _normParams._CalculationArea.from_value(area)
-        )
-        self._calculationAreaCB.setCurrentIndex(idx)
-
-    def getCalculationMethod(self):
-        return _normParams._ValueCalculationMethod.from_value(
-            self._calculationMethodCB.currentText()
-        )
-
-    def setCalculationMethod(self, method):
-        idx = self._calculationMethodCB.findText(
-            _normParams._ValueCalculationMethod.from_value(method)
-        )
-        self._calculationMethodCB.setCurrentIndex(idx)
-
-    def _areaChanged(self):
-        if self.getCalculationArea() == _normParams._CalculationArea.PROJECTION:
-            idx = self._calculationMethodCB.findText(
-                _normParams._ValueCalculationMethod.SCALAR_VALUE.value
-            )
-            self._calculationMethodCB.setCurrentIndex(idx)
-            self._calculationMethodCB.setEnabled(False)
-        else:
-            self._calculationMethodCB.setEnabled(True)

 class _NormIntensityControl(ControlWidget):
     def __init__(self, parent=None):