Commit d92952f8 authored by Thomas Vincent's avatar Thomas Vincent
Browse files

Use coordinates to trigger dataset names + rework reading attributes

parent 1de706bb
......@@ -33,6 +33,8 @@ __date__ = "15/09/2016"
import numpy
from .XsocsH5Base import XsocsH5Base
from .QSpaceH5 import QSpaceCoordinates
from ._utils import str_to_h5_utf8
from ..util import text_type
......@@ -61,6 +63,21 @@ class FitH5QAxis(object): # TODO remove
return FitH5QAxis.axis_names[axis]
def _find_NX_class(group, nx_class):
"""Yield name of items in group of nx_class NX_class
:param h5py.Group group:
:param str nx_class:
:rtype: Iterable[str]
for key, item in group.items():
cls = item.attrs.get('NX_class', '')
if hasattr(cls, 'decode'):
cls = cls.decode()
if cls == nx_class:
yield key
class FitH5(XsocsH5Base):
"""File containing fit results.
......@@ -90,10 +107,7 @@ class FitH5(XsocsH5Base):
# TODO : this isnt pretty but for some reason the attrs.get() fails
# when there is no attribute NX_class (should return the default
# None)
return sorted([key for key in h5_file
if ('NX_class' in h5_file[key].attrs and
'NX_class'].decode() == 'NXentry')])
return sorted(_find_NX_class(h5_file, 'NXentry'))
def processes(self, entry):
"""Return the processes names for the given entry.
......@@ -102,12 +116,7 @@ class FitH5(XsocsH5Base):
:rtype: List[str]
with self._get_file() as h5_file:
entry_grp = h5_file[entry]
processes = sorted([key for key in entry_grp
if ('NX_class' in entry_grp[key].attrs and
'NX_class'].decode() == 'NXprocess')])
return processes
return sorted(_find_NX_class(h5_file[entry], 'NXprocess'))
def get_result_names(self, entry, process):
"""Returns the result names for the given process.
......@@ -331,7 +340,27 @@ class FitH5(XsocsH5Base):
class FitH5Writer(FitH5):
"""Class to write fit/COM results in a HDF5 file"""
"""Class to write fit/COM results in a HDF5 file
:param str h5_f: Filename where to write
:param QSpaceCoordinates coordinates: Kind of QSpace coordinates
:param str mode: File opening mode
def __init__(self, h5_f,
assert coordinates in QSpaceCoordinates.ALLOWED
self.__coordinates = coordinates
super(FitH5Writer, self).__init__(h5_f, mode)
def __axis_name(self, dimension):
"""Returns the names of the QSpace axis
:param int dimension: Dimension index of the axis
:rtype: str
return QSpaceCoordinates.axes_names(self.__coordinates)[dimension]
def create_entry(self, entry):
"""Create group to store result for entry
......@@ -346,7 +375,7 @@ class FitH5Writer(FitH5):
''.format(self.filename, entries))
# TODO : check if it already exists
entry_grp = h5_file.require_group(entry)
entry_grp.attrs['NX_class'] = numpy.string_('NXentry')
entry_grp.attrs['NX_class'] = str_to_h5_utf8('NXentry')
def create_process(self, entry, process):
"""Create group to store a process in entry
......@@ -368,9 +397,9 @@ class FitH5Writer(FitH5):
# TODO : check if it exists
process_grp = entry_grp.require_group(process)
process_grp.attrs['NX_class'] = numpy.string_('NXprocess')
process_grp.attrs['NX_class'] = str_to_h5_utf8('NXprocess')
results_grp = process_grp.require_group('results')
results_grp.attrs['NX_class'] = numpy.string_('NXcollection')
results_grp.attrs['NX_class'] = str_to_h5_utf8('NXcollection')
def set_sample_positions(self, entry, x, y):
"""Write sample positions (x, y) in file
......@@ -389,10 +418,8 @@ class FitH5Writer(FitH5):
:param int dimension:
:param numpy.ndarray data:
# TODO get axis name from index
axis_name = FitH5QAxis.axis_name(dimension)
status_path = self._STATUS_PATH.format(entry=entry,
status_path = self._STATUS_PATH.format(
entry=entry, axis=self.__axis_name(dimension))
self._set_array_data(status_path, data)
def set_result(self, entry, process, dimension, name, data):
......@@ -404,8 +431,7 @@ class FitH5Writer(FitH5):
:param str name:
:param numpy.ndarray data:
assert dimension in FitH5QAxis.axis_values
axis_name = FitH5QAxis.axis_name(dimension)
axis_name = self.__axis_name(dimension)
result_path = self._RESULT_PATH.format(entry=entry,
......@@ -421,10 +447,10 @@ class FitH5Writer(FitH5):
:param numpy.ndarray dim2:
for index, values in enumerate((dim0, dim1, dim2)):
axis_name = FitH5QAxis.axis_name(index)
def set_background_mode(self, entry, mode):
"""Returns the background subtraction mode used
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment