Commit e624fd83 authored by Thomas Vincent's avatar Thomas Vincent

Move find_NX_class to _utils

parent 6d5dc956
......@@ -34,7 +34,7 @@ import numpy
from .XsocsH5Base import XsocsH5Base
from .QSpaceH5 import QSpaceCoordinates
from ._utils import str_to_h5_utf8
from ._utils import str_to_h5_utf8, find_NX_class
from ..util import text_type
......@@ -53,21 +53,6 @@ class BackgroundTypes(object):
ALLOWED = NONE, CONSTANT, LINEAR, SNIP
def _find_NX_class(group, nx_class):
"""Yield name of items in group of nx_class NX_class
:param h5py.Group group:
:param str nx_class:
:rtype: Iterable[str]
"""
for key, item in group.items():
cls = item.attrs.get('NX_class', '')
if hasattr(cls, 'decode'):
cls = cls.decode()
if cls == nx_class:
yield key
class FitH5(XsocsH5Base):
"""File containing fit results.
......@@ -95,7 +80,7 @@ class FitH5(XsocsH5Base):
:rtype: List[str]
"""
with self._get_file() as h5_file:
return sorted(_find_NX_class(h5_file, 'NXentry'))
return sorted(find_NX_class(h5_file, 'NXentry'))
def processes(self, entry):
"""Return the processes names for the given entry.
......@@ -104,7 +89,7 @@ class FitH5(XsocsH5Base):
:rtype: List[str]
"""
with self._get_file() as h5_file:
return sorted(_find_NX_class(h5_file[entry], 'NXprocess'))
return sorted(find_NX_class(h5_file[entry], 'NXprocess'))
def get_result_names(self, entry, process):
"""Returns the result names for the given process.
......
......@@ -37,3 +37,18 @@ def str_to_h5_utf8(text):
:rtype: numpy.ndarray
"""
return numpy.array(text, dtype=h5py.special_dtype(vlen=text_type))
def find_NX_class(group, nx_class):
"""Yield name of items in group of nx_class NX_class
:param h5py.Group group:
:param str nx_class:
:rtype: Iterable[str]
"""
for key, item in group.items():
cls = item.attrs.get('NX_class', '')
if hasattr(cls, 'decode'):
cls = cls.decode()
if cls == nx_class:
yield key
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment