Commit 989e4361 authored by Thomas Vincent's avatar Thomas Vincent

add from_fi_h5 methods and equality support to FitResult

parent b04bcd99
......@@ -41,7 +41,7 @@ from import snip1d
from ... import config
from import QSpaceH5
from import BackgroundTypes, FitH5Writer
from import BackgroundTypes, FitH5, FitH5Writer
from ...util import gaussian, project
......@@ -112,7 +112,7 @@ class FitResult(object):
"""QSpace axis names (List[str])"""
self.roi_indices = roi_indices
"""Roi indices (List[List[str]])"""
"""Roi indices (List[List[int]])"""
self.fit_mode = fit_mode
"""Fit type (FitTypes)"""
......@@ -145,20 +145,22 @@ class FitResult(object):
return numpy.array(self._fit_results[dimension][parameter], copy=copy)
FitTypes.GAUSSIAN: 'Gaussian',
FitTypes.CENTROID: 'Centroid'}
FitTypes.GAUSSIAN: 'gauss_0',
FitTypes.CENTROID: 'centroid'}
def to_fit_h5(self, fit_h5, mode=None):
"""Write fit results to an HDF5 file
:param str fit_h5: Filename where to save fit results
:param Union[None,str] mode: HDF5 file opening mode
if self.fit_mode == FitTypes.GAUSSIAN:
fit_name = 'Gaussian'
result_name = 'gauss_0'
elif self.fit_mode == FitTypes.CENTROID:
fit_name = 'Centroid'
result_name = 'centroid'
raise RuntimeError('Unknown Fit Type')
fit_name = self._FIT_ENTRY_NAMES[self.fit_mode]
result_name = self._FIT_PROCESS_NAMES[self.fit_mode]
with FitH5Writer(fit_h5,
......@@ -179,6 +181,90 @@ class FitResult(object):
fitH5.set_result(result_name, dimension, name, results)
def from_fit_h5(cls, filename):
"""Create a FitResult from content of a HDF5 fit file
:param str filename: HDF% fit results file name
:rtype: FitResult
with FitH5(filename) as fith5:
# Retrieve entry
entries = fith5.entries()
if len(entries) != 1:
raise RuntimeError("Only one entry in fit result is supported")
entry = fith5.entries()[0]
# Get fit mode corresponding to entry name
for key, value in cls._FIT_ENTRY_NAMES.items():
if value == entry:
fit_mode = key
raise RuntimeError("Unsupported fit entry name: %s" % entry)
# Get corresponding NXProcess name
process = cls._FIT_PROCESS_NAMES[fit_mode]
if process not in fith5.processes(entry):
raise RuntimeError("Cannot find relevant NXProcess group in file")
# Retrieve fit results
parameters = list(fith5.get_result_names(entry, process))
fit_results = []
for axis in range(len(fith5.get_qspace_dimension_names(entry))):
result = [fith5.get_axis_result(entry, process, param, axis)
for param in parameters[:-1]] # All but Status
result.append(fith5.get_status(entry, axis))
# Get dtype from result arrays
result_dtype = [(name, array.dtype)
for name, array in zip(parameters, fit_results[0])]
# Convert from list by axis of list by param of list by sample point
# To list by sample point of list by axis of list by param
nb_axes = len(fit_results)
nb_params = len(fit_results[0])
nb_points = len(fit_results[0][0])
fit_results = [[[fit_results[axis][param][point]
for param in range(nb_params)]
for axis in range(nb_axes)]
for point in range(nb_points)]
# Convert to record array
fit_results = numpy.array(fit_results, dtype=result_dtype)
# Return FitResult object
sample_x, sample_y = fith5.sample_positions(entry)
return FitResult(
def __eq__(self, other):
"""Implement equality, useful for tests"""
return (
isinstance(other, FitResult) and
numpy.array_equal(self.sample_x, other.sample_x) and
numpy.array_equal(self.sample_y, other.sample_y) and
numpy.all([numpy.array_equal(a1, a2)
for a1, a2 in zip(self.qspace_dimension_values,
other.qspace_dimension_values)]) and
(tuple(self.qspace_dimension_names) ==
tuple(other.qspace_dimension_names)) and
numpy.array_equal(self.roi_indices, self.roi_indices) and
self.fit_mode == other.fit_mode and
self.background_mode == self.background_mode and
numpy.array_equal(self._fit_results, other._fit_results))
class PeakFitter(object):
"""Class performing fit/com processing
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment