Commit df67db23 authored by Carsten Richter

Merge branch 'fix-hdf-text' into 'master'

Fix hdf text

See merge request !96
parents 36e02056 438aa036
Pipeline #6712 canceled with stages
......@@ -59,7 +59,9 @@ def getNodeClass(nodeType, attrs=None):
for key, value in attrs.items():
info = _registeredAttributes.get(key)
if info:
klass = info.get(value.decode())
if hasattr(value, 'decode'):
value = value.decode()
klass = info.get(value)
if klass:
break
if not klass:
......@@ -291,7 +293,7 @@ class H5DatasetNode(H5Base):
ndims = len(item.shape)
if ndims == 0:
data = item[()]
if isinstance(data, (np.string_, bytes_type)):
if hasattr(data, 'decode'):
text = data.decode().replace('\n', ' ')
else:
text = str(data)
......
......@@ -59,10 +59,10 @@ def projectItemFactory(h5File, h5Path, mode=None):
klass = None
with h5py.File(h5File, mode='r') as h5f:
attrs = h5f[h5Path].attrs
# For some reason attrs.get sometimes fails,
# using "in" seems a bit more robust.
if 'XsocsClass' in attrs:
xsocsClass = attrs['XsocsClass'].decode()
xsocsClass = attrs['XsocsClass']
if hasattr(xsocsClass, 'decode'):
xsocsClass = xsocsClass.decode()
klass = getItemClass(xsocsClass)
del attrs
......
......@@ -34,7 +34,7 @@ import numpy
from .XsocsH5Base import XsocsH5Base
from .QSpaceH5 import QSpaceCoordinates
from ._utils import str_to_h5_utf8
from ._utils import str_to_h5_utf8, find_NX_class
from ..util import text_type
......@@ -53,21 +53,6 @@ class BackgroundTypes(object):
ALLOWED = NONE, CONSTANT, LINEAR, SNIP
def _find_NX_class(group, nx_class):
"""Yield name of items in group of nx_class NX_class
:param h5py.Group group:
:param str nx_class:
:rtype: Iterable[str]
"""
for key, item in group.items():
cls = item.attrs.get('NX_class', '')
if hasattr(cls, 'decode'):
cls = cls.decode()
if cls == nx_class:
yield key
class FitH5(XsocsH5Base):
"""File containing fit results.
......@@ -95,7 +80,7 @@ class FitH5(XsocsH5Base):
:rtype: List[str]
"""
with self._get_file() as h5_file:
return sorted(_find_NX_class(h5_file, 'NXentry'))
return sorted(find_NX_class(h5_file, 'NXentry'))
def processes(self, entry):
"""Return the processes names for the given entry.
......@@ -104,7 +89,7 @@ class FitH5(XsocsH5Base):
:rtype: List[str]
"""
with self._get_file() as h5_file:
return sorted(_find_NX_class(h5_file[entry], 'NXprocess'))
return sorted(find_NX_class(h5_file[entry], 'NXprocess'))
def get_result_names(self, entry, process):
"""Returns the result names for the given process.
......
......@@ -156,10 +156,8 @@ class QSpaceH5(XsocsH5Base):
def selected_entries(self):
"""Returns the input entries used for the conversion."""
path = self._ENTRIES_PATH + '/selected'
entries = self._get_array_data(path)
if entries is not None:
return [entry.decode() for entry in entries]
return []
entries = self._get_array_str(path)
return entries if entries is not None else []
@property
def shifts(self):
......@@ -207,10 +205,8 @@ class QSpaceH5(XsocsH5Base):
Returns the input entries that were not used for the conversion.
"""
path = self._ENTRIES_PATH + '/discarded'
entries = self._get_array_data(path)
if entries is not None:
return [entry.decode() for entry in entries]
return []
entries = self._get_array_str(path)
return entries if entries is not None else []
@property
def image_mask(self):
......@@ -434,11 +430,10 @@ class QSpaceH5Writer(QSpaceH5):
number of selected entries. The shift is in grid coordinates.
"""
path = self._ENTRIES_PATH + '/selected'
selected = _np.array(selected, dtype=_np.string_)
selected = str_to_h5_utf8(selected)
self._set_array_data(path, selected)
path = self._ENTRIES_PATH + '/discarded'
discarded = _np.array((discarded is not None and discarded) or [],
dtype=_np.string_)
discarded = str_to_h5_utf8(discarded if discarded is not None else [])
self._set_array_data(path, discarded)
if sample_shifts is not None:
......
......@@ -37,6 +37,7 @@ import h5py as _h5py
import numpy as _np
from .XsocsH5Base import XsocsH5Base
from ._utils import str_to_h5_utf8, find_NX_class
class InvalidEntryError(Exception):
......@@ -46,9 +47,9 @@ class InvalidEntryError(Exception):
ScanPositions = namedtuple('ScanPositions',
['motor_0', 'pos_0', 'motor_1', 'pos_1', 'shape'])
MOTORCOLS = {"pix":"adcY",
"piy":"adcX",
"piz":"adcZ"}
MOTORCOLS = {"pix": "adcY",
"piy": "adcX",
"piz": "adcZ"}
def _process_entry(method):
......@@ -61,7 +62,6 @@ def _process_entry(method):
return _method
class XsocsH5(XsocsH5Base):
TOP_ENTRY = 'global'
......@@ -89,13 +89,7 @@ class XsocsH5(XsocsH5Base):
def _update_entries(self):
with self._get_file() as h5_file:
# TODO : this isnt pretty but for some reason the attrs.get() fails
# when there is no attribute NX_class (should return the default
# None)
self.__entries = sorted([key for key in h5_file
if ('NX_class' in h5_file[key].attrs and
h5_file[key].attrs[
'NX_class'].decode() == 'NXentry')])
self.__entries = sorted(find_NX_class(h5_file, 'NXentry'))
def entries(self):
if self.__entries is None:
......@@ -152,7 +146,6 @@ class XsocsH5(XsocsH5Base):
"""
return self.__detector_params(entry, 'image_roi_offset')
@_process_entry
def n_images(self, entry):
# TODO : make sure that data.ndims = 3
......@@ -176,8 +169,9 @@ class XsocsH5(XsocsH5Base):
@_process_entry
def image_cumul(self, entry, dtype=None):
"""
Returns the summed intensity for each image.
"""Returns the summed intensity for each image.
:param str entry:
:param dtype: dtype passed to the numpy.sum function.
Default is numpy.double.
:type dtype: numpy.dtype
......@@ -198,16 +192,20 @@ class XsocsH5(XsocsH5Base):
def scan_positions(self, entry):
path = self.measurement_tpl.format(entry)
params = self.scan_params(entry)
m0 = '/{0}'.format(MOTORCOLS[params['motor_0'].decode()])
m1 = '/{0}'.format(MOTORCOLS[params['motor_1'].decode()])
motors = [m.decode() if hasattr(m, 'decode') else m
for m in (params['motor_0'], params['motor_1'])]
m0 = '/{0}'.format(MOTORCOLS[motors[0]])
m1 = '/{0}'.format(MOTORCOLS[motors[1]])
n_0 = params['motor_0_steps']
n_1 = params['motor_1_steps']
x_pos = self._get_array_data(path + m0)
y_pos = self._get_array_data(path + m1)
return ScanPositions(motor_0=params['motor_0'],
return ScanPositions(motor_0=motors[0],
pos_0=x_pos,
motor_1=params['motor_1'],
motor_1=motors[1],
pos_1=y_pos,
shape=(n_0, n_1))
......@@ -226,12 +224,12 @@ class XsocsH5(XsocsH5Base):
@_process_entry
def is_regular_grid(self, entry):
# TODO
"""
For now grids are always regular
:param entry:
:return:
"""For now grids are always regular
:param str entry:
:rtype: bool
"""
# TODO
return True
@_process_entry
......@@ -245,9 +243,13 @@ class XsocsH5(XsocsH5Base):
'delay']
with self._get_file() as h5_file:
path = self.scan_params_tpl.format(entry) + '/{0}'
return OrderedDict([(param, h5_file.get(path.format(param),
_np.array(None))[()])
for param in param_names])
result = OrderedDict()
for param in param_names:
value = h5_file.get(path.format(param), _np.array(None))[()]
if hasattr(value, 'decode'):
value = value.decode()
result[param] = value
return result
@_process_entry
def positioner(self, entry, positioner):
......@@ -331,8 +333,8 @@ class XsocsH5(XsocsH5Base):
class XsocsH5Writer(XsocsH5):
def __init__(self, h5_f, mode='a', **kwargs):
super(XsocsH5Writer, self).__init__(h5_f, mode=mode, **kwargs)
def __init__(self, h5_f, mode='a'):
super(XsocsH5Writer, self).__init__(h5_f, mode=mode)
def __set_detector_params(self, entry, params):
with self._get_file() as h5_file:
......@@ -375,11 +377,11 @@ class XsocsH5Writer(XsocsH5):
delay,
**kwargs):
params = OrderedDict([('motor_0', _np.string_(motor_0)),
params = OrderedDict([('motor_0', str_to_h5_utf8(motor_0)),
('motor_0_start', float(motor_0_start)),
('motor_0_end', float(motor_0_end)),
('motor_0_steps', int(motor_0_steps)),
('motor_1', _np.string_(motor_1)),
('motor_1', str_to_h5_utf8(motor_1)),
('motor_1_start', float(motor_1_start)),
('motor_1_end', float(motor_1_end)),
('motor_1_steps', int(motor_1_steps)),
......@@ -392,29 +394,29 @@ class XsocsH5Writer(XsocsH5):
def create_entry(self, entry):
with self._get_file() as h5_file:
entry_grp = h5_file.require_group(entry)
entry_grp.attrs['NX_class'] = _np.string_('NXentry')
entry_grp.attrs['NX_class'] = str_to_h5_utf8('NXentry')
# creating mandatory groups and setting their Nexus attributes
grp = entry_grp.require_group('measurement/image')
grp.attrs['interpretation'] = _np.string_('image')
grp.attrs['interpretation'] = str_to_h5_utf8('image')
# setting the nexus classes
#entry_grp.attrs['NX_class'] = _np.string_('NXentry')
# entry_grp.attrs['NX_class'] = str_to_h5_utf8('NXentry')
grp = entry_grp.require_group('instrument')
grp.attrs['NX_class'] = _np.string_('NXinstrument')
grp.attrs['NX_class'] = str_to_h5_utf8('NXinstrument')
grp = entry_grp.require_group('instrument/detector')
grp.attrs['NX_class'] = _np.string_('NXdetector')
grp.attrs['NX_class'] = str_to_h5_utf8('NXdetector')
grp = entry_grp.require_group('instrument/positioners')
grp.attrs['NX_class'] = _np.string_('NXcollection')
grp.attrs['NX_class'] = str_to_h5_utf8('NXcollection')
grp = entry_grp.require_group('measurement')
grp.attrs['NX_class'] = _np.string_('NXcollection')
grp.attrs['NX_class'] = str_to_h5_utf8('NXcollection')
grp = entry_grp.require_group('measurement/image')
grp.attrs['NX_class'] = _np.string_('NXcollection')
grp.attrs['NX_class'] = str_to_h5_utf8('NXcollection')
# creating some links
grp = entry_grp.require_group('measurement/image')
......
......@@ -36,6 +36,25 @@ from contextlib import contextmanager
import h5py as _h5py
import numpy as _np
from ..util import text_type
from ._utils import str_to_h5_utf8
# We have to work around a limitation of the h5py.Group.copy method
# that fails when a group already exists in the destination file.
def _copy_obj(name, obj, src_grp=None, dest_grp=None):
    """Replicate one item of src_grp into dest_grp under the same name.

    A group is only (re)created in the destination with ``require_group``,
    so an already existing group is not an error; anything else is copied
    with ``src_grp.copy``.

    NOTE(review): the (name, obj) leading arguments look meant for
    ``h5py.Group.visititems`` with src_grp/dest_grp bound beforehand -
    confirm against the caller.

    :param str name: Name of the item, relative to src_grp and dest_grp
    :param obj: The visited item
    :param h5py.Group src_grp: Group the item is copied from
    :param h5py.Group dest_grp: Group the item is copied to
    """
    if not isinstance(obj, _h5py.Group):
        copy_options = {
            'name': name,
            'shallow': False,
            'expand_soft': True,
            'expand_external': True,
            'expand_refs': True,
            'without_attrs': False,
        }
        src_grp.copy(name, dest_grp, **copy_options)
        return

    dest_grp.require_group(name)
class XsocsH5Base(object):
# TODO : mechanism to test file type (isValid whatever)
......@@ -144,7 +163,27 @@ class XsocsH5Base(object):
except KeyError:
return None
def _get_array_str(self, path):
"""Returns the array of string contained in the dataset.
:param str path: The path of the dataset in the HDF5 file
:rtype: List[str]
"""
strings = self._get_array_data(path=path)
if strings is None:
return None
else:
return [s.decode() if hasattr(s, 'decode') else s for s in strings]
def _set_scalar_data(self, path, value):
"""Write a scalar or string value to a given dataset
:param str path: Dataset path
:param Union[float,int,str] value:
"""
if isinstance(value, text_type):
value = str_to_h5_utf8(value)
with self._get_file() as h5_f:
value_np = _np.array(value)
dset = h5_f.require_dataset(path,
......@@ -153,11 +192,10 @@ class XsocsH5Base(object):
dset[()] = value
def _set_array_data(self, path, value):
"""
Sets the given numpy array at the given path in this HDF5 file.
"""Sets the given numpy array at the given path in this HDF5 file.
:param path:
:param value:
:return:
"""
with self._get_file() as h5_f:
dset = h5_f.require_dataset(path,
......@@ -169,6 +207,7 @@ class XsocsH5Base(object):
"""
Creates a dataset as the given path. All extra arguments are passed
to h5py.DataSet.create_dataset.
:param path:
:param args:
:param kwargs:
......@@ -187,8 +226,8 @@ class XsocsH5Base(object):
@contextmanager
def item_context(self, item_path, **kwargs):
"""
Context manager for the image dataset.
"""Context manager for the image dataset.
WARNING: only to be used as a context manager!
WARNING: the data set must exist. see also QSpaceH5Writer.init_cube
"""
......@@ -202,26 +241,11 @@ class XsocsH5Base(object):
del item
def copy_group(self, src_h5f, src_path, dest_path):
"""
Recursively copies an object from one HDF5 file to another.
"""Recursively copies an object from one HDF5 file to another.
Warning : it fails if it finds a conflict with an already existing
dataset.
"""
# We have to work around a limitation of the h5py.Group.copy method
# that fails when a group already exists in the destination file.
def _copy_obj(name, obj, src_grp=None, dest_grp=None):
if isinstance(obj, _h5py.Group):
dest_grp.require_group(name)
else:
src_grp.copy(name,
dest_grp,
name=name,
shallow=False,
expand_soft=True,
expand_external=True,
expand_refs=True,
without_attrs=False)
with _h5py.File(src_h5f, 'r') as src_h5:
with self._get_file() as h5_file:
src_grp = src_h5[src_path]
......
......@@ -37,3 +37,18 @@ def str_to_h5_utf8(text):
:rtype: numpy.ndarray
"""
return numpy.array(text, dtype=h5py.special_dtype(vlen=text_type))
def find_NX_class(group, nx_class):
    """Yield the names of the items of group whose NX_class equals nx_class.

    The ``NX_class`` attribute of each item is compared as text: a value
    exposing ``decode`` (e.g. bytes) is decoded before the comparison.
    Items without an ``NX_class`` attribute never match a non-empty
    nx_class.

    :param h5py.Group group: Container whose ``items()`` are inspected
    :param str nx_class: The NX_class value to look for
    :rtype: Iterable[str]
    """
    for key, item in group.items():
        # Skip entries that expose no attributes at all instead of raising
        # AttributeError. NOTE(review): h5py appears to yield None for
        # links it cannot resolve (dangling soft/external links) - confirm.
        attrs = getattr(item, 'attrs', None)
        if attrs is None:
            continue

        cls = attrs.get('NX_class', '')
        if hasattr(cls, 'decode'):
            cls = cls.decode()
        if cls == nx_class:
            yield key
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment