Commit ae02d46b authored by payno

Merge branch 'add_concatenate_op_to_nexus' into 'master'

nexus: add function to concatenate several NXtomo together

See merge request !109
parents 24c377d8 2b95e4e4
Pipeline #76435 passed with stages
in 7 minutes and 5 seconds
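In practice the new API is reachable through the `concatenate` re-export in `nxtomomill.nexus` or through the per-class static methods added below (`NXtomo.concatenate`, `NXinstrument.concatenate`, ...). A minimal usage sketch, based only on the signatures visible in this diff and assuming `NXtomo.save` accepts the same `file_path`/`data_path` keywords the tests use for `NXdetector.save`; `scan_a`, `scan_b` and the output path are hypothetical:

from nxtomomill.nexus import NXtomo

# scan_a and scan_b: already populated NXtomo instances, loaded or built elsewhere (hypothetical)
merged = NXtomo.concatenate((scan_a, scan_b), node_name="entry")

# scalar metadata (energy, title, pixel sizes, source name...) is taken from the first scan,
# array fields (rotation_angle, image_key_control, translations, detector frames) are concatenated
merged.save(file_path="concatenated.nx", data_path="entry")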
......@@ -3,3 +3,4 @@ from .nxobject import NXobject # noqa F401
from .nxsample import NXsample # noqa F401
from .nxsource import NXsource # noqa F401
from .nxtomo import NXtomo # noqa F401
from .utils import concatenate # noqa F401
......@@ -29,6 +29,8 @@ __license__ = "MIT"
__date__ = "03/02/2022"
from functools import partial
from operator import is_not
from silx.utils.proxy import docstring
from silx.io.url import DataUrl
from h5py import VirtualSource
......@@ -50,7 +52,17 @@ from tomoscan.unitsystem import TimeSystem
from tomoscan.unitsystem import Unit
from tomoscan.io import HDF5File
from h5py import h5s as h5py_h5s
try:
from h5py._hl.vds import VDSmap
except ImportError:
has_VDSmap = False
else:
has_VDSmap = True
import h5py._hl.selections as selection
import logging
_logger = logging.getLogger(__name__)
class NXdetector(NXobject):
......@@ -101,10 +113,11 @@ class NXdetector(NXobject):
)
elif isinstance(data, (tuple, list)):
for elmt in data:
if not isinstance(elmt, (DataUrl, VirtualSource)):
raise TypeError(
f"element of 'data' are expected to be silx DataUrl or h5py virtualSource. Not {type(elmt)}"
)
if has_VDSmap:
    accepted = (DataUrl, VirtualSource, VDSmap)
else:
    accepted = (DataUrl, VirtualSource)
if not isinstance(elmt, accepted):
    raise TypeError(
        f"elements of 'data' are expected to be silx DataUrl or h5py VirtualSource. Not {type(elmt)}"
    )
data = tuple(data)
elif data is None:
pass
......@@ -447,6 +460,116 @@ class NXdetector(NXobject):
data_path="/".join([data_path, nexus_detector_paths.IMAGE_KEY_CONTROL]),
)
@staticmethod
def _concatenate_except_data(nx_detector, nx_objects: tuple):
image_key_ctrl = [
nx_obj.image_key_control
for nx_obj in nx_objects
if nx_obj.image_key_control is not None
]
if len(image_key_ctrl) > 0:
nx_detector.image_key_control = numpy.concatenate(image_key_ctrl)
# note: image_key is deduced from image_key_control
nx_detector.x_pixel_size = nx_objects[0].x_pixel_size.value
nx_detector.y_pixel_size = nx_objects[0].y_pixel_size.value
nx_detector.distance = nx_objects[0].distance.value
nx_detector.field_of_view = nx_objects[0].field_of_view
nx_detector.estimated_cor_from_motor = nx_objects[0].estimated_cor_from_motor
for nx_obj in nx_objects[1:]:
if nx_detector.x_pixel_size.value and not numpy.isclose(
nx_detector.x_pixel_size.value, nx_obj.x_pixel_size.value
):
_logger.warning(
f"found different x pixel size value. ({nx_detector.x_pixel_size.value} vs {nx_obj.x_pixel_size.value}). Pick the first one"
)
if nx_detector.y_pixel_size.value and not numpy.isclose(
nx_detector.y_pixel_size.value, nx_obj.y_pixel_size.value
):
_logger.warning(
f"found different y pixel size value. ({nx_detector.y_pixel_size.value} vs {nx_obj.y_pixel_size.value}). Pick the first one"
)
if nx_detector.distance.value and not numpy.isclose(
nx_detector.distance.value, nx_obj.distance.value
):
_logger.warning(
f"found different distance value. ({nx_detector.distance.value} vs {nx_obj.distance.value}). Pick the first one"
)
if (
nx_detector.field_of_view
and nx_detector.field_of_view != nx_obj.field_of_view
):
_logger.warning(
f"found different field_of_view value. ({nx_detector.field_of_view} vs {nx_obj.field_of_view}). Pick the first one"
)
if nx_detector.estimated_cor_from_motor and not numpy.isclose(
nx_detector.estimated_cor_from_motor, nx_obj.estimated_cor_from_motor
):
_logger.warning(
f"found different estimated_cor_from_motor value. ({nx_detector.estimated_cor_from_motor} vs {nx_obj.estimated_cor_from_motor}). Pick the first one"
)
@docstring(NXobject)
def concatenate(nx_objects: tuple, node_name="detector"):
# filter None obj
nx_objects = tuple(filter(partial(is_not, None), nx_objects))
if len(nx_objects) == 0:
return None
# warning: later we make the assumption that nx_objects contains at least one element
for nx_obj in nx_objects:
if not isinstance(nx_obj, NXdetector):
raise TypeError("Cannot concatenate non NXinstrument object")
nx_detector = NXdetector(node_name=node_name)
NXdetector._concatenate_except_data(
nx_objects=nx_objects, nx_detector=nx_detector
)
# now handle data on its own
detector_data = [
nx_obj.data for nx_obj in nx_objects if nx_obj.data is not None
]
if len(detector_data) > 0:
if isinstance(detector_data[0], numpy.ndarray):
# store_as = "as_numpy_array"
expected = numpy.ndarray
elif isinstance(detector_data[0], Iterable):
if isinstance(detector_data[0][0], h5py.VirtualSource):
# store_as = "as_virtual_source"
expected = h5py.VirtualSource
elif isinstance(detector_data[0][0], DataUrl):
# store_as = "as_data_url"
expected = DataUrl
else:
raise TypeError(
    f"detector data is expected to be a numpy array, an iterable of h5py.VirtualSource, or an iterable of silx DataUrl. {type(detector_data[0][0])} is not handled."
)
else:
raise TypeError(
    f"detector data is expected to be a numpy array, an iterable of h5py.VirtualSource, or an iterable of silx DataUrl. {type(detector_data[0])} is not handled."
)
for data in detector_data:
if expected in (DataUrl, h5py.VirtualSource):
# for DataUrl and VirtualSource check type of the element
cond = isinstance(data[0], expected)
else:
cond = isinstance(data, expected)
if not cond:
raise TypeError(
f"Incoherent data type cross detector data ({type(data)} when {expected} expected)"
)
if expected in (DataUrl, h5py.VirtualSource):
new_data = []
for data in detector_data:
    new_data.extend(data)
else:
new_data = numpy.concatenate(detector_data)
nx_detector.data = new_data
return nx_detector
class NXdetectorWithUnit(NXdetector):
def __init__(
......@@ -513,3 +636,82 @@ class NXdetectorWithUnit(NXdetector):
"linear",
]
return nx_dict
@docstring(NXobject)
def concatenate(
nx_objects: tuple, default_unit, expected_dim, node_name="detector"
):
# filter None obj
nx_objects = tuple(filter(partial(is_not, None), nx_objects))
if len(nx_objects) == 0:
return None
# warning: later we make the assumption that nx_objects contains at least one element
for nx_obj in nx_objects:
if not isinstance(nx_obj, NXdetector):
raise TypeError("Cannot concatenate non NXinstrument object")
nx_detector = NXdetectorWithUnit(
node_name=node_name, default_unit=default_unit, expected_dim=expected_dim
)
NXdetector._concatenate_except_data(
nx_objects=nx_objects, nx_detector=nx_detector
)
# now handle data on its own
detector_data = [
nx_obj.data.value
for nx_obj in nx_objects
if (nx_obj.data is not None and nx_obj.data.value is not None)
]
detector_units = set(
[
nx_obj.data.unit
for nx_obj in nx_objects
if (nx_obj.data is not None and nx_obj.data.value is not None)
]
)
if len(detector_units) > 1:
# with DataUrl and Virtual Sources we are not able to do conversion
raise ValueError("More than one units found. Unagle to build the detector")
if len(detector_data) > 0:
if isinstance(detector_data[0], numpy.ndarray):
# store_as = "as_numpy_array"
expected = numpy.ndarray
elif isinstance(detector_data[0], Iterable):
if isinstance(detector_data[0][0], h5py.VirtualSource):
# store_as = "as_virtual_source"
expected = h5py.VirtualSource
elif isinstance(detector_data[0][0], DataUrl):
# store_as = "as_data_url"
expected = DataUrl
else:
raise TypeError(
    f"detector data is expected to be a numpy array, an iterable of h5py.VirtualSource, or an iterable of silx DataUrl. {type(detector_data[0][0])} is not handled."
)
else:
raise TypeError(
    f"detector data is expected to be a numpy array, an iterable of h5py.VirtualSource, or an iterable of silx DataUrl. {type(detector_data[0])} is not handled."
)
for data in detector_data:
if expected in (DataUrl, h5py.VirtualSource):
# for DataUrl and VirtualSource check type of the element
cond = isinstance(data[0], expected)
else:
cond = isinstance(data, expected)
if not cond:
raise TypeError(
f"Incoherent data type cross detector data ({type(data)} when {expected} expected)"
)
if expected in (DataUrl, h5py.VirtualSource):
new_data = []
for data in detector_data:
    new_data.extend(data)
else:
new_data = numpy.concatenate(detector_data)
nx_detector.data.value = new_data
nx_detector.data.unit = list(detector_units)[0]
return nx_detector
......@@ -29,6 +29,8 @@ __license__ = "MIT"
__date__ = "10/02/2022"
from functools import partial
from operator import is_not
from silx.utils.proxy import docstring
from typing import Optional
from nxtomomill.nexus.nxdetector import NXdetector, NXdetectorWithUnit
......@@ -39,6 +41,9 @@ from tomoscan.unitsystem.voltagesystem import VoltageSystem
from .utils import get_data
from tomoscan.io import HDF5File
from .nxobject import NXobject
import logging
_logger = logging.getLogger(__name__)
class NXinstrument(NXobject):
......@@ -63,16 +68,6 @@ class NXinstrument(NXobject):
self._name = None
self._set_freeze(True)
@property
def name(self) -> Optional[str]:
return self._name
@name.setter
def name(self, name: Optional[str]):
if not isinstance(name, (str, type(None))):
raise TypeError(f"name is expected to be None or a str. Not {type(name)}")
self._name = name
@property
def detector(self) -> Optional[NXdetector]:
return self._detector
......@@ -205,3 +200,41 @@ class NXinstrument(NXobject):
file_path=file_path,
data_path="/".join([data_path, nexus_instrument_paths.NAME]),
)
@docstring(NXobject)
def concatenate(nx_objects: tuple, node_name="instrument"):
# filter None obj
nx_objects = tuple(filter(partial(is_not, None), nx_objects))
if len(nx_objects) == 0:
return None
# warning: later we make the assumption that nx_objects contains at least one element
for nx_obj in nx_objects:
if not isinstance(nx_obj, NXinstrument):
raise TypeError("Cannot concatenate non NXinstrument object")
nx_instrument = NXinstrument(node_name=node_name)
nx_instrument.name = nx_objects[0].name
_logger.info(f"instrument name {nx_objects[0].name} will be picked")
nx_instrument.source = NXsource.concatenate(
[nx_obj.source for nx_obj in nx_objects],
node_name="source",
)
nx_instrument.source.parent = nx_instrument
nx_instrument.diode = NXdetectorWithUnit.concatenate(
[nx_obj.diode for nx_obj in nx_objects],
node_name="diode",
expected_dim=(1,),
default_unit=VoltageSystem.VOLT,
)
nx_instrument.diode.parent = nx_instrument
nx_instrument.detector = NXdetector.concatenate(
[nx_obj.detector for nx_obj in nx_objects],
node_name="detector",
)
nx_instrument.detector.parent = nx_instrument
return nx_instrument
......@@ -309,3 +309,12 @@ class NXobject:
raise AttributeError("can't set attribute", __name)
else:
super().__setattr__(__name, __value)
@staticmethod
def concatenate(nx_objects: tuple, node_name: str):
"""
concatenate a tuple of NXobject into a single NXobject
:param Iterable Nx-objects: nx object to concatenate
:param str node_name: name of the node to create. Parent must be handled manually for now.
"""
raise NotImplementedError("Base class")
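As a caller-side illustration of this contract, here is a small sketch using NXsource (the simplest implementation added in this merge request); the "ESRF" value is purely illustrative, and the parent is left to the caller, as the docstring states:

from nxtomomill.nexus import NXsource

source_a = NXsource(node_name="source")
source_a.name = "ESRF"  # illustrative value
source_b = NXsource(node_name="source")
source_b.name = "ESRF"

# None entries are filtered out; the first remaining object's metadata is kept
merged_source = NXsource.concatenate((source_a, None, source_b), node_name="source")
assert merged_source.name == "ESRF"
# parent is not set by concatenate: attach merged_source to its NXinstrument manually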
......@@ -29,6 +29,8 @@ __license__ = "MIT"
__date__ = "03/02/2022"
from functools import partial
from operator import is_not
from typing import Iterable, Optional
import numpy
......@@ -37,6 +39,9 @@ from .utils import cast_and_check_array_1D, get_data_and_unit, get_data
from silx.utils.proxy import docstring
from tomoscan.nexus.paths.nxtomo import get_paths as get_nexus_paths
from tomoscan.unitsystem.metricsystem import MetricSystem
import logging
_logger = logging.getLogger(__name__)
class NXsample(NXobject):
......@@ -190,3 +195,61 @@ class NXsample(NXobject):
data_path="/".join([data_path, nexus_sample_paths.Z_TRANSLATION]),
default_unit=MetricSystem.METER,
)
@docstring(NXobject)
def concatenate(nx_objects: tuple, node_name="sample"):
# filter None obj
nx_objects = tuple(filter(partial(is_not, None), nx_objects))
if len(nx_objects) == 0:
return None
# warning: later we make the assumption that nx_objects contains at least one element
for nx_obj in nx_objects:
if not isinstance(nx_obj, NXsample):
raise TypeError("Cannot concatenate non NXsample object")
nx_sample = NXsample(node_name)
_logger.info(f"sample name {nx_objects[0].name} will be picked")
nx_sample.name = nx_objects[0].name
rotation_angles = [
nx_obj.rotation_angle
for nx_obj in nx_objects
if nx_obj.rotation_angle is not None
]
if len(rotation_angles) > 0:
nx_sample.rotation_angle = numpy.concatenate(rotation_angles)
x_translations = [
nx_obj.x_translation.value * nx_obj.x_translation.unit.value
for nx_obj in nx_objects
if nx_obj.x_translation.value is not None
]
if len(x_translations) > 0:
nx_sample.x_translation = numpy.concatenate(x_translations)
y_translations = [
nx_obj.y_translation.value * nx_obj.y_translation.unit.value
for nx_obj in nx_objects
if nx_obj.y_translation.value is not None
]
if len(y_translations) > 0:
nx_sample.y_translation = numpy.concatenate(y_translations)
z_translations = [
nx_obj.z_translation.value * nx_obj.z_translation.unit.value
for nx_obj in nx_objects
if nx_obj.z_translation.value is not None
]
if len(z_translations) > 0:
nx_sample.z_translation = numpy.concatenate(z_translations)
rocking_list = list(
filter(
partial(is_not, None),
[nx_obj.rocking for nx_obj in nx_objects],
)
)
if len(rocking_list) > 0:
nx_sample.rocking = numpy.concatenate(rocking_list)
return nx_sample
......@@ -28,12 +28,17 @@ __license__ = "MIT"
__date__ = "03/02/2022"
from functools import partial
from operator import is_not
from typing import Optional, Union
from silx.utils.proxy import docstring
from .nxobject import NXobject
from silx.utils.enum import Enum as _Enum
from tomoscan.nexus.paths.nxtomo import get_paths as get_nexus_paths
from .utils import get_data
import logging
_logger = logging.getLogger(__name__)
class SourceType(_Enum):
......@@ -127,6 +132,24 @@ class NXsource(NXobject):
data_path="/".join([data_path, nexus_source_paths.TYPE]),
)
@docstring(NXobject)
def concatenate(nx_objects: tuple, node_name="source"):
# filter None obj
nx_objects = tuple(filter(partial(is_not, None), nx_objects))
if len(nx_objects) == 0:
return None
# warning: later we make the assumption that nx_objects contains at least one element
for nx_obj in nx_objects:
if not isinstance(nx_obj, NXsource):
raise TypeError("Cannot concatenate non NXsource object")
nx_source = NXsource(node_name=node_name)
nx_source.name = nx_objects[0].name
_logger.info(f"Taking the first source name {nx_objects[0].name}")
nx_source.type = nx_objects[0].type
_logger.info(f"Taking the first source type {nx_objects[0].type}")
return nx_source
class DefaultESRFSource(NXsource):
def __init__(self, node_name="source", parent=None) -> None:
......
......@@ -28,6 +28,8 @@ __license__ = "MIT"
__date__ = "03/02/2022"
from functools import partial
from operator import is_not
import os
from typing import Optional, Union
from tomoscan.io import HDF5File
......@@ -126,6 +128,7 @@ class NXtomo(NXobject):
raise TypeError(
f"sample is expected ot be an instance of {NXsample} or None. Not {type(sample)}"
)
self._sample = sample
@property
def energy(self) -> Optional[float]:
......@@ -332,3 +335,51 @@ class NXtomo(NXobject):
detector_data_as=detector_data_as,
)
return self
def concatenate(nx_objects: tuple, node_name=""):
"""
concatenate a tuple of NXobject into a single NXobject
:param tuple nx_objects:
:return: NXtomo instance which is the concatenation of the nx_objects
"""
nx_objects = tuple(filter(partial(is_not, None), nx_objects))
# filter None obj
if len(nx_objects) == 0:
return None
# warning: later we make the assumption that nx_objects contains at least one element
for nx_obj in nx_objects:
if not isinstance(nx_obj, NXtomo):
raise TypeError("Cannot concatenate non NXtomo object")
nx_tomo = NXtomo(node_name)
# check object concatenation can be handled
nx_tomo.energy = (
nx_objects[0].energy.value
* nx_objects[0].energy.unit.value
/ nx_tomo.energy.unit.value
)
for nx_obj in nx_objects:
if not numpy.isclose(nx_tomo.energy.value, nx_obj.energy.value):
_logger.warning(f"{nx_obj} and {nx_objects[0]} have different energy")
_logger.info(f"title {nx_objects[0].title} will be picked")
nx_tomo.title = nx_objects[0].title
start_times = list(filter(partial(is_not, None), [nx_obj.start_time for nx_obj in nx_objects]))
end_times = list(filter(partial(is_not, None), [nx_obj.end_time for nx_obj in nx_objects]))
# ignore scans without start/end time to avoid comparing None with datetime
nx_tomo.start_time = min(start_times) if len(start_times) > 0 else None
nx_tomo.end_time = max(end_times) if len(end_times) > 0 else None
nx_tomo.sample = NXsample.concatenate(
tuple([nx_obj.sample for nx_obj in nx_objects])
)
nx_tomo.sample.parent = nx_tomo
nx_tomo.instrument = NXinstrument.concatenate(
tuple([nx_obj.instrument for nx_obj in nx_objects]),
)
nx_tomo.instrument.parent = nx_tomo
return nx_tomo
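A short sketch of what the merged NXtomo ends up holding, following the code above; nx_a and nx_b stand for two fully populated NXtomo scans (hypothetical):

merged = NXtomo.concatenate((nx_a, nx_b), node_name="entry")

# scalar metadata comes from the first scan (a warning is logged if the energies differ)
assert merged.title == nx_a.title
# the time range spans both scans
assert merged.start_time == min(nx_a.start_time, nx_b.start_time)
assert merged.end_time == max(nx_a.end_time, nx_b.end_time)
# sample and instrument children are themselves concatenated and re-parented
assert merged.sample.parent is merged
assert merged.instrument.parent is merged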
......@@ -30,6 +30,7 @@ __date__ = "04/02/2022"
import tempfile
from nxtomomill.nexus.nxdetector import NXdetector, NXdetectorWithUnit, FieldOfView
from tomoscan.unitsystem.voltagesystem import VoltageSystem
from nxtomomill.utils import ImageKey
......@@ -99,6 +100,18 @@ def test_nx_detector():
with pytest.raises(AttributeError):
nx_detector.test = 12
# test nx_detector concatenation
concatenated_nx_detector = NXdetector.concatenate([nx_detector, nx_detector])
numpy.testing.assert_array_equal(
concatenated_nx_detector.image_key_control, [ImageKey.PROJECTION] * 10
)
assert concatenated_nx_detector.x_pixel_size.value == 1e-7
assert concatenated_nx_detector.y_pixel_size.value == 2e-7
assert concatenated_nx_detector.distance.value == 0.02
nx_detector.field_of_view = FieldOfView.HALF
nx_detector.count_time = [0.1] * 10
nx_detector.estimated_cor_from_motor = 0.5
def test_nx_detector_with_unit():
diode = NXdetectorWithUnit(
......@@ -116,6 +129,15 @@ def test_nx_detector_with_unit():
diode.data = numpy.random.random(12)
diode.data = (DataUrl(),)
# test NXdetectorWithUnit concatenation
concatenated_nx_detector = NXdetectorWithUnit.concatenate(
[diode, diode],
expected_dim=(1,),
default_unit=VoltageSystem.VOLT,
)
assert len(concatenated_nx_detector.data.value) == 2
assert isinstance(concatenated_nx_detector.data.value[1], DataUrl)
def test_nx_detector_with_virtual_source():
"""Insure detector data can be write from Virtual sources"""
......@@ -141,11 +163,11 @@ def test_nx_detector_with_virtual_source():
).reshape(base_raw_dataset_shape)
v_sources.append(h5py.VirtualSource(h5f["data"]))
nx_detecteur = NXdetector()
nx_detecteur.data = v_sources
nx_detector = NXdetector()
nx_detector.data = v_sources
detector_file = os.path.join(tmp_folder, "detector_file.hdf5")
nx_detecteur.save(file_path=detector_file, data_path="/")
nx_detector.save(file_path=detector_file, data_path="/")
# check the virtual dataset has been properly created and linked
with h5py.File(detector_file, mode="r") as h5f_master:
......@@ -165,6 +187,11 @@ def test_nx_detector_with_virtual_source():
assert vs_info.file_name.startswith("./")
assert cwd == os.getcwd()
# check concatenation
concatenated_nx_detector = NXdetector.concatenate([nx_detector, nx_detector])
assert isinstance(concatenated_nx_detector.data[1], h5py.VirtualSource)
assert len(concatenated_nx_detector.data) == len(raw_files) * 2
def test_nx_detector_with_local_urls():
"""Insure detector data can be write from DataUrl linking to local dataset (in the same file)"""
......@@ -192,9 +219,9 @@ def test_nx_detector_with_local_urls():
scheme="silx",
)