Commit bae9a17f authored by payno's avatar payno
Browse files

NXDetector: fix issue loading data as virtual sources + add unit test for...

NXDetector: fix issue loading data as virtual sources + add unit test for loading and dumping the NXdetector data
parent 67a04be7
Pipeline #77877 passed with stages
in 5 minutes and 31 seconds
......@@ -37,6 +37,7 @@ from h5py import VirtualSource
from typing import Iterable, Optional, Union
import numpy
import h5py
from nxtomomill.utils.frameappender import FrameAppender
from nxtomomill.utils.h5pyutils import from_virtual_source_to_data_url
try:
......@@ -408,7 +409,13 @@ class NXdetector(NXobject):
self.data = urls
elif load_data_as == "as_virtual_source":
if dataset.is_virtual:
self.data = dataset.virtual_sources()
virtual_sources = []
for vs_info in dataset.virtual_sources():
_, vs = FrameAppender._recreate_vs(
vs_info=vs_info, vds_file=file_path
)
virtual_sources.append(vs)
self.data = virtual_sources
else:
raise ValueError(f"{data_dataset_path} is not virtual")
......
......@@ -308,3 +308,46 @@ def test_nx_detector_with_external_urls():
concatenated_nx_detector = NXdetector.concatenate([nx_detector, nx_detector])
assert isinstance(concatenated_nx_detector.data[1], DataUrl)
assert len(concatenated_nx_detector.data) == n_base_raw_dataset * 2
@pytest.mark.parametrize(
"load_data_as, expected_type",
[
("as_numpy_array", numpy.ndarray),
("as_virtual_source", h5py.VirtualSource),
("as_data_url", DataUrl),
],
)
def test_load_detector_data(tmp_path, load_data_as, expected_type):
print("load_data_as is", load_data_as)
print("expected_type is", expected_type)
layout = h5py.VirtualLayout(shape=(4 * 2, 100, 100), dtype="i4")
for n in range(0, 4):
filename = os.path.join(tmp_path, "{n}.h5")
with h5py.File(filename, "w") as f:
f["data"] = numpy.arange(100 * 100 * 2).reshape(2, 100, 100)
vsource = h5py.VirtualSource(filename, "data", shape=(2, 100, 100))
start_n = n * 2
end_n = start_n + 2
layout[start_n:end_n] = vsource
output_file = os.path.join(tmp_path, "VDS.h5")
with h5py.File(output_file, "w") as f:
f.create_virtual_dataset("data", layout, fillvalue=-5)
nx_detector = NXdetector()
nx_detector._load(
file_path=output_file,
data_path="/",
load_data_as=load_data_as,
nexus_version=None,
)
if expected_type is numpy.ndarray:
assert isinstance(nx_detector.data, expected_type)
else:
for elmt in nx_detector.data:
assert isinstance(elmt, expected_type)
nx_detector.save(os.path.join(tmp_path, "output_file.nx"))
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment