Commit 989e4361 authored by Thomas Vincent's avatar Thomas Vincent

add from_fi_h5 methods and equality support to FitResult

parent b04bcd99
...@@ -41,7 +41,7 @@ from silx.math.fit import snip1d ...@@ -41,7 +41,7 @@ from silx.math.fit import snip1d
from ... import config from ... import config
from ...io import QSpaceH5 from ...io import QSpaceH5
from ...io.FitH5 import BackgroundTypes, FitH5Writer from ...io.FitH5 import BackgroundTypes, FitH5, FitH5Writer
from ...util import gaussian, project from ...util import gaussian, project
...@@ -112,7 +112,7 @@ class FitResult(object): ...@@ -112,7 +112,7 @@ class FitResult(object):
"""QSpace axis names (List[str])""" """QSpace axis names (List[str])"""
self.roi_indices = roi_indices self.roi_indices = roi_indices
"""Roi indices (List[List[str]])""" """Roi indices (List[List[int]])"""
self.fit_mode = fit_mode self.fit_mode = fit_mode
"""Fit type (FitTypes)""" """Fit type (FitTypes)"""
...@@ -145,20 +145,22 @@ class FitResult(object): ...@@ -145,20 +145,22 @@ class FitResult(object):
""" """
return numpy.array(self._fit_results[dimension][parameter], copy=copy) return numpy.array(self._fit_results[dimension][parameter], copy=copy)
_FIT_ENTRY_NAMES = {
FitTypes.GAUSSIAN: 'Gaussian',
FitTypes.CENTROID: 'Centroid'}
_FIT_PROCESS_NAMES = {
FitTypes.GAUSSIAN: 'gauss_0',
FitTypes.CENTROID: 'centroid'}
def to_fit_h5(self, fit_h5, mode=None): def to_fit_h5(self, fit_h5, mode=None):
"""Write fit results to an HDF5 file """Write fit results to an HDF5 file
:param str fit_h5: Filename where to save fit results :param str fit_h5: Filename where to save fit results
:param Union[None,str] mode: HDF5 file opening mode :param Union[None,str] mode: HDF5 file opening mode
""" """
if self.fit_mode == FitTypes.GAUSSIAN: fit_name = self._FIT_ENTRY_NAMES[self.fit_mode]
fit_name = 'Gaussian' result_name = self._FIT_PROCESS_NAMES[self.fit_mode]
result_name = 'gauss_0'
elif self.fit_mode == FitTypes.CENTROID:
fit_name = 'Centroid'
result_name = 'centroid'
else:
raise RuntimeError('Unknown Fit Type')
with FitH5Writer(fit_h5, with FitH5Writer(fit_h5,
entry=fit_name, entry=fit_name,
...@@ -179,6 +181,90 @@ class FitResult(object): ...@@ -179,6 +181,90 @@ class FitResult(object):
else: else:
fitH5.set_result(result_name, dimension, name, results) fitH5.set_result(result_name, dimension, name, results)
@classmethod
def from_fit_h5(cls, filename):
"""Create a FitResult from content of a HDF5 fit file
:param str filename: HDF% fit results file name
:rtype: FitResult
"""
with FitH5(filename) as fith5:
# Retrieve entry
entries = fith5.entries()
if len(entries) != 1:
raise RuntimeError("Only one entry in fit result is supported")
entry = fith5.entries()[0]
# Get fit mode corresponding to entry name
for key, value in cls._FIT_ENTRY_NAMES.items():
if value == entry:
fit_mode = key
break
else:
raise RuntimeError("Unsupported fit entry name: %s" % entry)
# Get corresponding NXProcess name
process = cls._FIT_PROCESS_NAMES[fit_mode]
if process not in fith5.processes(entry):
raise RuntimeError("Cannot find relevant NXProcess group in file")
# Retrieve fit results
parameters = list(fith5.get_result_names(entry, process))
parameters.append('Status')
fit_results = []
for axis in range(len(fith5.get_qspace_dimension_names(entry))):
result = [fith5.get_axis_result(entry, process, param, axis)
for param in parameters[:-1]] # All but Status
result.append(fith5.get_status(entry, axis))
fit_results.append(result)
# Get dtype from result arrays
result_dtype = [(name, array.dtype)
for name, array in zip(parameters, fit_results[0])]
# Convert from list by axis of list by param of list by sample point
# To list by sample point of list by axis of list by param
nb_axes = len(fit_results)
nb_params = len(fit_results[0])
nb_points = len(fit_results[0][0])
fit_results = [[[fit_results[axis][param][point]
for param in range(nb_params)]
for axis in range(nb_axes)]
for point in range(nb_points)]
# Convert to record array
fit_results = numpy.array(fit_results, dtype=result_dtype)
# Return FitResult object
sample_x, sample_y = fith5.sample_positions(entry)
return FitResult(
sample_x=sample_x,
sample_y=sample_y,
q_dim_values=fith5.get_qspace_dimension_values(entry),
q_dim_names=fith5.get_qspace_dimension_names(entry),
roi_indices=fith5.get_roi_indices(entry),
fit_mode=fit_mode,
background_mode=fith5.get_background_mode(entry),
fit_results=fit_results)
def __eq__(self, other):
"""Implement equality, useful for tests"""
return (
isinstance(other, FitResult) and
numpy.array_equal(self.sample_x, other.sample_x) and
numpy.array_equal(self.sample_y, other.sample_y) and
numpy.all([numpy.array_equal(a1, a2)
for a1, a2 in zip(self.qspace_dimension_values,
other.qspace_dimension_values)]) and
(tuple(self.qspace_dimension_names) ==
tuple(other.qspace_dimension_names)) and
numpy.array_equal(self.roi_indices, self.roi_indices) and
self.fit_mode == other.fit_mode and
self.background_mode == self.background_mode and
numpy.array_equal(self._fit_results, other._fit_results))
class PeakFitter(object): class PeakFitter(object):
"""Class performing fit/com processing """Class performing fit/com processing
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment