Commit e624fd83 authored by Thomas Vincent's avatar Thomas Vincent
Browse files

Move find_NX_class to _utils

parent 6d5dc956
...@@ -34,7 +34,7 @@ import numpy ...@@ -34,7 +34,7 @@ import numpy
from .XsocsH5Base import XsocsH5Base from .XsocsH5Base import XsocsH5Base
from .QSpaceH5 import QSpaceCoordinates from .QSpaceH5 import QSpaceCoordinates
from ._utils import str_to_h5_utf8 from ._utils import str_to_h5_utf8, find_NX_class
from ..util import text_type from ..util import text_type
...@@ -53,21 +53,6 @@ class BackgroundTypes(object): ...@@ -53,21 +53,6 @@ class BackgroundTypes(object):
ALLOWED = NONE, CONSTANT, LINEAR, SNIP ALLOWED = NONE, CONSTANT, LINEAR, SNIP
def _find_NX_class(group, nx_class):
"""Yield name of items in group of nx_class NX_class
:param h5py.Group group:
:param str nx_class:
:rtype: Iterable[str]
"""
for key, item in group.items():
cls = item.attrs.get('NX_class', '')
if hasattr(cls, 'decode'):
cls = cls.decode()
if cls == nx_class:
yield key
class FitH5(XsocsH5Base): class FitH5(XsocsH5Base):
"""File containing fit results. """File containing fit results.
...@@ -95,7 +80,7 @@ class FitH5(XsocsH5Base): ...@@ -95,7 +80,7 @@ class FitH5(XsocsH5Base):
:rtype: List[str] :rtype: List[str]
""" """
with self._get_file() as h5_file: with self._get_file() as h5_file:
return sorted(_find_NX_class(h5_file, 'NXentry')) return sorted(find_NX_class(h5_file, 'NXentry'))
def processes(self, entry): def processes(self, entry):
"""Return the processes names for the given entry. """Return the processes names for the given entry.
...@@ -104,7 +89,7 @@ class FitH5(XsocsH5Base): ...@@ -104,7 +89,7 @@ class FitH5(XsocsH5Base):
:rtype: List[str] :rtype: List[str]
""" """
with self._get_file() as h5_file: with self._get_file() as h5_file:
return sorted(_find_NX_class(h5_file[entry], 'NXprocess')) return sorted(find_NX_class(h5_file[entry], 'NXprocess'))
def get_result_names(self, entry, process): def get_result_names(self, entry, process):
"""Returns the result names for the given process. """Returns the result names for the given process.
......
...@@ -37,3 +37,18 @@ def str_to_h5_utf8(text): ...@@ -37,3 +37,18 @@ def str_to_h5_utf8(text):
:rtype: numpy.ndarray :rtype: numpy.ndarray
""" """
return numpy.array(text, dtype=h5py.special_dtype(vlen=text_type)) return numpy.array(text, dtype=h5py.special_dtype(vlen=text_type))
def find_NX_class(group, nx_class):
"""Yield name of items in group of nx_class NX_class
:param h5py.Group group:
:param str nx_class:
:rtype: Iterable[str]
"""
for key, item in group.items():
cls = item.attrs.get('NX_class', '')
if hasattr(cls, 'decode'):
cls = cls.decode()
if cls == nx_class:
yield key
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment