Commit 17481245 authored by Thomas Vincent's avatar Thomas Vincent
Browse files

rework fitH5: clean-up, update doc, updat writer API

parent c9323822
......@@ -22,6 +22,7 @@
# THE SOFTWARE.
#
# ###########################################################################*/
"""Class handle read/write of fit/COM results in HDF5 files"""
from __future__ import absolute_import
......@@ -29,17 +30,12 @@ __authors__ = ["D. Naudet"]
__license__ = "MIT"
__date__ = "15/09/2016"
from collections import namedtuple
import numpy as np
import numpy
from .XsocsH5Base import XsocsH5Base
from ..util import text_type
FitResult = namedtuple('FitResult', ['name', 'qx', 'qy', 'qz'])
class BackgroundTypes(object):
"""Enum of background subtraction types:
......@@ -66,8 +62,8 @@ class FitH5QAxis(object):
class FitH5(XsocsH5Base):
"""
File containing fit results.
"""File containing fit results.
Requirements :
- the number of sample position is defined at entry level : all processes
within the same entry are applied to the same sample points.
......@@ -86,9 +82,9 @@ class FitH5(XsocsH5Base):
_BACKGROUND_MODE_PATH = '{entry}/background_mode'
def entries(self):
"""
Return the entry names.
:return:
"""Return the entry names.
:rtype: List[str]
"""
with self._get_file() as h5_file:
# TODO : this isnt pretty but for some reason the attrs.get() fails
......@@ -100,10 +96,10 @@ class FitH5(XsocsH5Base):
'NX_class'].decode() == 'NXentry')])
def processes(self, entry):
"""
Return the processes names for the given entry.
:param entry:
:return:
"""Return the processes names for the given entry.
:param str entry:
:rtype: List[str]
"""
with self._get_file() as h5_file:
entry_grp = h5_file[entry]
......@@ -114,53 +110,26 @@ class FitH5(XsocsH5Base):
return processes
def get_result_names(self, entry, process):
"""
Returns the result names for the given process. Names are ordered
alphabetically.
:param entry:
:param process:
:return:
"""Returns the result names for the given process.
Names are ordered alphabetically.
:param str entry:
:param str process:
:rtype: List[str]
"""
results_path = self._RESULT_GRP_PATH.format(entry=entry,
process=process)
with self._get_file() as h5_file:
return sorted(h5_file[results_path].keys())
def get_qx_status(self, entry):
"""
Returns the Qx fit status for the given entry/process.
:param entry:
:param process:
:return:
"""
return self.get_status(entry, FitH5QAxis.qx_axis)
def get_qy_status(self, entry):
"""
Returns the Qy fit status for the given entry/process.
:param entry:
:param process:
:return:
"""
return self.get_status(entry, FitH5QAxis.qy_axis)
def get_qz_status(self, entry):
"""
Returns the Qz fit status for the given entry/process.
:param entry:
:param process:
:return:
"""
return self.get_status(entry, FitH5QAxis.qz_axis)
def get_status(self, entry, axis):
"""
Returns the fit status for the given entry/process/axis
:param entry:
:param process:
"""Returns the fit status for the given entry/process/axis
:param str entry:
:param axis: FitH5QAxis.qx_axis, FitH5QAxis.qy_axis
or FitH5QAxis.qz_axis
:return:
:rtype: numpy.ndarray
"""
axis_name = FitH5QAxis.axis_name(axis)
status_path = self._STATUS_PATH.format(entry=entry,
......@@ -183,68 +152,68 @@ class FitH5(XsocsH5Base):
return status
def scan_x(self, entry):
"""
Return the sample points coordinates along x for the given entry.
:param entry:
:return:
"""Return the sample points coordinates along x for the given entry.
:param str entry:
:rtype: numpy.ndarray
"""
dset_path = self._SCAN_X_PATH.format(entry=entry)
return self._get_array_data(dset_path)
def scan_y(self, entry):
"""
Return the sample points coordinates along y for the given entry.
:param entry:
:return:
"""Return the sample points coordinates along y for the given entry.
:param str entry:
:rtype: numpy.ndarray
"""
dset_path = self._SCAN_Y_PATH.format(entry=entry)
return self._get_array_data(dset_path)
def get_qx(self, entry):
"""
Returns the axis values for qx for the given entry.
:param entry:
:return:
"""Returns the axis values for qx for the given entry.
:param str entry:
:rtype: numpy.ndarray
"""
return self.__get_axis_values(entry, FitH5QAxis.qx_axis)
def get_qy(self, entry):
"""
Returns the axis values for qy for the given entry.
:param entry:
:return:
"""Returns the axis values for qy for the given entry.
:param str entry:
:rtype: numpy.ndarray
"""
return self.__get_axis_values(entry, FitH5QAxis.qy_axis)
def get_qz(self, entry):
"""
Returns the axis values for qz for the given entry.
:param entry:
:return:
"""Returns the axis values for qz for the given entry.
:param str entry:
:rtype: numpy.ndarray
"""
return self.__get_axis_values(entry, FitH5QAxis.qz_axis)
def __get_axis_values(self, entry, axis):
"""
Returns the axis values.
:param entry:
"""Returns the axis values.
:param str entry:
:param axis: FitH5QAxis.qx_axis, FitH5QAxis.qy_axis
or FitH5QAxis.qz_axis
:return:
:rtype: numpy.ndarray
"""
axis_name = FitH5QAxis.axis_name(axis)
return self._get_array_data(self._QSPACE_AXIS_PATH.format(
entry=entry, axis=axis_name))
def get_axis_result(self, entry, process, result, axis):
"""
Returns the results for the given entry/process/result name/axis.
:param entry:
:param process:
:param result:
"""Returns the results for the given entry/process/result name/axis.
:param str entry:
:param str process:
:param str result:
:param axis: FitH5QAxis.qx_axis, FitH5QAxis.qy_axis or
FitH5QAxis.qz_axis
:return:
:rtype: numpy.ndarray
"""
assert axis in FitH5QAxis.axis_values
axis_name = FitH5QAxis.axis_name(axis)
......@@ -255,60 +224,34 @@ class FitH5(XsocsH5Base):
return self._get_array_data(result_path)
def get_qx_result(self, entry, process, result):
"""
Returns the results (qx) for the given entry/process/result name.
:param entry:
:param process:
:param result:
:return:
"""Returns the results (qx) for the given entry/process/result name.
:param str entry:
:param str process:
:param str result:
:return: numpy.ndarray
"""
return self.get_axis_result(entry, process, result, FitH5QAxis.qx_axis)
def get_qy_result(self, entry, process, result):
"""
Returns the results (qy) for the given entry/process/result name.
:param entry:
:param process:
:param result:
:return:
"""Returns the results (qy) for the given entry/process/result name.
:param str entry:
:param str process:
:param str result:
:return: numpy.ndarray
"""
return self.get_axis_result(entry, process, result, FitH5QAxis.qy_axis)
def get_qz_result(self, entry, process, result):
"""
Returns the results (qz) for the given entry/process/result name.
:param entry:
:param process:
:param result:
:return:
"""
return self.get_axis_result(entry, process, result, FitH5QAxis.qz_axis)
def get_result(self, entry, process, result):
"""
Returns the results values (qx, qy, qz) for
the given entry/process/result name.
:param entry:
:param process:
:param result:
:return: a FitResult instance.
"""
with self:
results = {}
for axis in FitH5QAxis.axis_values:
results[FitH5QAxis.axis_name(axis)] = \
self.get_axis_result(entry, process, result, axis)
return FitResult(name=result, **results)
"""Returns the results (qz) for the given entry/process/result name.
def get_n_points(self, entry):
:param str entry:
:param str process:
:param str result:
:return: numpy.ndarray
"""
Returns the number of sample positions for this entry.
:param entry:
:return:
"""
dset_path = self._SCAN_X_PATH.format(entry=entry)
shape = self._get_array_data(dset_path, shape=True)
return shape[0]
return self.get_axis_result(entry, process, result, FitH5QAxis.qz_axis)
def get_background_mode(self, entry):
"""Returns the background subtraction mode used
......@@ -321,11 +264,10 @@ class FitH5(XsocsH5Base):
return mode if mode is not None else BackgroundTypes.NONE
def export_csv(self, entry, filename):
"""
Exports an entry results as csv.
:param entry:
:param filename:
:return:
"""Exports an entry results as csv.
:param str entry:
:param str filename:
"""
x, y = self.scan_x(entry), self.scan_y(entry)
......@@ -351,7 +293,7 @@ class FitH5(XsocsH5Base):
header = '; '.join(header_process) + '\n' + '; '.join(header_list)
results = np.zeros((len(x), len(header_list)))
results = numpy.zeros((len(x), len(header_list)))
results[:, 0] = x
results[:, 1] = y
......@@ -371,17 +313,22 @@ class FitH5(XsocsH5Base):
axis)
col_idx += 1
np.savetxt(filename,
results,
fmt='%.10g',
header=header,
comments='',
delimiter='; ')
numpy.savetxt(filename,
results,
fmt='%.10g',
header=header,
comments='',
delimiter='; ')
class FitH5Writer(FitH5):
"""Class to write fit/COM results in a HDF5 file"""
def create_entry(self, entry):
"""Create group to store result for entry
:param str entry:
"""
with self._get_file() as h5_file:
entries = self.entries()
if len(entries) > 0:
......@@ -390,9 +337,14 @@ class FitH5Writer(FitH5):
''.format(self.filename, entries))
# TODO : check if it already exists
entry_grp = h5_file.require_group(entry)
entry_grp.attrs['NX_class'] = np.string_('NXentry')
entry_grp.attrs['NX_class'] = numpy.string_('NXentry')
def create_process(self, entry, process):
"""Create group to store a process in entry
:param str entry:
:param str process:
"""
# TODO : check that there isn't already an existing process
with self._get_file() as h5_file:
......@@ -407,56 +359,63 @@ class FitH5Writer(FitH5):
# TODO : check if it exists
process_grp = entry_grp.require_group(process)
process_grp.attrs['NX_class'] = np.string_('NXprocess')
process_grp.attrs['NX_class'] = numpy.string_('NXprocess')
results_grp = process_grp.require_group('results')
results_grp.attrs['NX_class'] = np.string_('NXcollection')
results_grp.attrs['NX_class'] = numpy.string_('NXcollection')
def set_scan_x(self, entry, x):
dset_path = self._SCAN_X_PATH.format(entry=entry)
return self._set_array_data(dset_path, x)
def set_sample_positions(self, entry, x, y):
"""Write sample positions (x, y) in file
def set_scan_y(self, entry, y):
dset_path = self._SCAN_Y_PATH.format(entry=entry)
return self._set_array_data(dset_path, y)
:param str entry:
:param numpy.ndarray x:
:param numpy.ndarray y:
"""
self._set_array_data(self._SCAN_X_PATH.format(entry=entry), x)
self._set_array_data(self._SCAN_Y_PATH.format(entry=entry), y)
def set_status(self, entry, axis, data):
axis_name = FitH5QAxis.axis_name(axis)
def set_status(self, entry, dimension, data):
"""Write fit/COM status in the file
:param str entry:
:param int dimension:
:param numpy.ndarray data:
"""
# TODO get axis name from index
axis_name = FitH5QAxis.axis_name(dimension)
status_path = self._STATUS_PATH.format(entry=entry,
axis=axis_name)
self._set_array_data(status_path, data)
def __set_axis_result(self, entry, process, name, q_axis, data):
assert q_axis in FitH5QAxis.axis_values
axis_name = FitH5QAxis.axis_name(q_axis)
def set_result(self, entry, process, dimension, name, data):
"""Write a fit/COM result parameter to the HDF5 file
:param str entry:
:param str process:
:param int dimension:
:param str name:
:param numpy.ndarray data:
"""
assert dimension in FitH5QAxis.axis_values
axis_name = FitH5QAxis.axis_name(dimension)
result_path = self._RESULT_PATH.format(entry=entry,
process=process,
result=name,
axis=axis_name)
self._set_array_data(result_path, data)
def set_qx_result(self, entry, process, name, data):
self.__set_axis_result(entry, process, name, FitH5QAxis.qx_axis, data)
def set_qy_result(self, entry, process, name, data):
self.__set_axis_result(entry, process, name, FitH5QAxis.qy_axis, data)
def set_qspace_dimension_values(self, entry, dim0, dim1, dim2):
"""Write qspace axes coordinates for each dimension
def set_qz_result(self, entry, process, name, data):
self.__set_axis_result(entry, process, name, FitH5QAxis.qz_axis, data)
def __set_axis_values(self, entry, axis, values):
axis_name = FitH5QAxis.axis_name(axis)
self._set_array_data(self._QSPACE_AXIS_PATH.format(entry=entry,
axis=axis_name),
values)
def set_qx(self, entry, values):
self.__set_axis_values(entry, FitH5QAxis.qx_axis, values)
def set_qy(self, entry, values):
self.__set_axis_values(entry, FitH5QAxis.qy_axis, values)
def set_qz(self, entry, values):
self.__set_axis_values(entry, FitH5QAxis.qz_axis, values)
:param str entry:
:param numpy.ndarray dim0:
:param numpy.ndarray dim1:
:param numpy.ndarray dim2:
"""
for index, values in enumerate((dim0, dim1, dim2)):
axis_name = FitH5QAxis.axis_name(index)
self._set_array_data(self._QSPACE_AXIS_PATH.format(entry=entry,
axis=axis_name),
values)
def set_background_mode(self, entry, mode):
"""Returns the background subtraction mode used
......
......@@ -157,28 +157,21 @@ class FitResult(object):
with FitH5Writer(fit_h5, mode=mode) as fitH5:
fitH5.create_entry(fit_name)
fitH5.create_process(fit_name, result_name)
fitH5.set_scan_x(fit_name, self.sample_x)
fitH5.set_scan_y(fit_name, self.sample_y)
q_dim0, q_dim1, q_dim2 = self.qspace_dimension_values
fitH5.set_qx(fit_name, q_dim0)
fitH5.set_qy(fit_name, q_dim1)
fitH5.set_qz(fit_name, q_dim2)
fitH5.set_sample_positions(fit_name, self.sample_x, self.sample_y)
fitH5.set_qspace_dimension_values(
fit_name, *self.qspace_dimension_values)
fitH5.set_background_mode(fit_name, self.background_mode)
fitH5.create_process(fit_name, result_name)
for array, func, axis in zip(
self._fit_results,
(fitH5.set_qx_result, fitH5.set_qy_result, fitH5.set_qz_result),
(0, 1, 2)):
for dimension, array in enumerate(self._fit_results):
for name in self.available_result_names:
results = self.get_results(axis, name, copy=False)
results = self.get_results(dimension, name, copy=False)
if name == 'Status':
fitH5.set_status(fit_name, axis, results)
fitH5.set_status(fit_name, dimension, results)
else:
func(fit_name, result_name, name, results)
fitH5.set_result(
fit_name, result_name, dimension, name, results)
class PeakFitter(object):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment