Commit 8f152eb3 authored by Henri Payno

[validator] improve: speed up and handle partially broken vds / virtual sources

- now we work on compacted urls to avoid opening / closing files several times
- we look at the dataset-specific virtual sources and not all sources.
  This is handled only at level 1, not at the virtual source src_space.
parent 4aaebacc
Pipeline #53589 passed with stages in 6 minutes and 32 seconds
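The main speed-up described in the message comes from compacting contiguous per-frame urls before validating them. Below is a minimal sketch of that idea, not the validator's actual code: it assumes `get_compacted_dataslices` is importable from `tomoscan.esrf.utils` (the module path may differ between tomoscan versions) and uses a hypothetical `frame_urls` dict.

```python
# Minimal sketch (not the library code): compact contiguous per-frame DataUrls
# so that each HDF5 file / dataset is opened once per block instead of once per
# frame. `frame_urls` is a hypothetical input; the import path of
# get_compacted_dataslices may differ between tomoscan versions.
import h5py
from silx.io.url import DataUrl
from tomoscan.esrf.utils import get_compacted_dataslices

frame_urls = {
    i: DataUrl(
        file_path="scan_0000.h5",
        data_path="/entry/instrument/detector/data",
        data_slice=i,
        scheme="silx",
    )
    for i in range(100)
}

# `compacted` maps every original frame index to a DataUrl whose data_slice
# covers a whole contiguous block; `url_set` holds each unique block only once.
compacted, url_set = get_compacted_dataslices(frame_urls, return_url_set=True)

for url in url_set.values():
    with h5py.File(url.file_path(), mode="r") as h5f:
        block = h5f[url.data_path()][url.data_slice()]  # one read per block
```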
@@ -38,6 +38,7 @@ import numpy
import logging
import sys
from tomoscan.io import HDF5File
import warnings
_logger = logging.getLogger(__name__)
@@ -104,7 +105,9 @@ def extract_urls_from_edf(
return res
-def get_compacted_dataslices(urls, max_grp_size=None, return_merged_indices=False):
+def get_compacted_dataslices(
+    urls: dict, max_grp_size=None, return_merged_indices=False, return_url_set=False
+):
"""
Regroup urls to get the data more efficiently.
Build a structure mapping files indices to information on
@@ -115,6 +118,10 @@ def get_compacted_dataslices(urls, max_grp_size=None, return_merged_indices=Fals
a silx `DataUrl`.
:param max_grp_size: maximum size of url grps
:type max_grp_size: None or int
:param bool return_merged_indices: if True return the last merged indices.
Deprecated
:param bool return_url_set: return a set with url containing `urls` slices
and data path
:return: Dictionary where the key is a list of indices, and the value is
the corresponding `silx.io.url.DataUrl` with merged data_slice
@@ -137,6 +144,11 @@ def get_compacted_dataslices(urls, max_grp_size=None, return_merged_indices=Fals
def merge_slices(slice1, slice2):
return slice(slice1.start, slice2.stop)
if return_merged_indices is True:
warnings.warn(
"return_merged_indices is deprecated. It will be removed in version 0.8"
)
if max_grp_size is None:
max_grp_size = sys.maxsize
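As a side note, the warning added earlier in this hunk does not pass a warning category. A more explicit variant (an assumption on my part, not what the commit does) would look like:

```python
import warnings

if return_merged_indices is True:
    warnings.warn(
        "return_merged_indices is deprecated. It will be removed in version 0.8",
        DeprecationWarning,
        stacklevel=2,  # report the caller's line rather than this one
    )
```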
@@ -195,6 +207,18 @@ def get_compacted_dataslices(urls, max_grp_size=None, return_merged_indices=Fals
),
)
)
if return_url_set:
url_set = {}
for _, url in res.items():
path = url.file_path(), url.data_path(), str(url.data_slice())
url_set[path] = url
if return_merged_indices:
return res, merge_slices, url_set
else:
return res, url_set
if return_merged_indices:
return res, merged_slices
else:
@@ -297,6 +321,15 @@ def get_datasets_linked_to_vds(url: DataUrl):
start_file_path = url.file_path()
start_dataset_path = url.data_path()
start_dataset_slice = url.data_slice()
if isinstance(start_dataset_slice, slice):
start_dataset_slice = tuple(
range(
start_dataset_slice.start,
start_dataset_slice.stop,
start_dataset_slice.step or 1,
)
)
print("inital dataset_slice is", start_dataset_slice[0], start_dataset_slice[-1])
virtual_dataset_to_treat = set()
final_dataset = set()
@@ -317,7 +350,6 @@ def get_datasets_linked_to_vds(url: DataUrl):
if os.path.exists(file_path):
with HDF5File(file_path, mode="r") as h5f:
dataset = h5f[dataset_path]
if dataset.is_virtual:
for vs_info in dataset.virtual_sources():
min_frame_bound = vs_info.vspace.get_select_bounds()[0][0]
@@ -329,18 +361,23 @@ def get_datasets_linked_to_vds(url: DataUrl):
<= max_frame_bound
):
continue
-elif isinstance(dataset_slice, slice):
+elif isinstance(dataset_slice, tuple):
if (
-data_slice.end <= min_frame_bound
-and data_slice.start >= max_frame_bound
+min_frame_bound > dataset_slice[-1]
+or max_frame_bound < dataset_slice[0]
):
continue
os.chdir(os.path.dirname(file_path))
# FIXME: for now we look at the entire dataset of the n+1 file.
# If those files can also contain virtual datasets and we want to handle
# the case where a part of them is broken but not ours, this should handle
# hyperslabs.
virtual_dataset_to_treat.add(
(
os.path.realpath(vs_info.file_name),
vs_info.dset_name,
vs_info.vspace.get_select_bounds(),
None,
)
)
else:
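The second change listed in the commit message is that only the virtual sources overlapping the requested frames are followed. Here is a sketch of that bound check under stated assumptions: `file_path`, `dataset_path` and `requested` (a tuple of frame indices) are hypothetical inputs, and plain `h5py.File` stands in for tomoscan's `HDF5File`.

```python
# Hedged sketch of the bound check the commit performs: only virtual sources
# whose frame range overlaps the requested frames are followed.
import h5py

def overlapping_virtual_sources(file_path, dataset_path, requested):
    with h5py.File(file_path, mode="r") as h5f:
        dataset = h5f[dataset_path]
        if not dataset.is_virtual:
            return []
        kept = []
        for vs_info in dataset.virtual_sources():
            # vspace bounds give the first / last frame this source maps to
            min_frame = vs_info.vspace.get_select_bounds()[0][0]
            max_frame = vs_info.vspace.get_select_bounds()[1][0]
            if min_frame > requested[-1] or max_frame < requested[0]:
                continue  # source entirely outside the requested frames
            kept.append((vs_info.file_name, vs_info.dset_name))
        return kept
```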
@@ -133,6 +133,9 @@ def test_frame_broken_vds(validator_cls):
scan_path=os.path.join(tempfile.mkdtemp(), "scan_test"),
n_proj=10,
n_ini_proj=10,
create_ini_dark=True,
create_ini_ref=True,
create_final_ref=False,
) as scan:
validator = validator_cls(scan=scan, check_vds=True, check_values=False)
assert (
@@ -143,14 +146,15 @@ def test_frame_broken_vds(validator_cls):
# modify the 'data' dataset to set a virtual dataset with a broken link (file does not exist)
with h5py.File(scan.master_file, mode="a") as h5f:
detector_grp = h5f[scan.entry]["instrument/detector"]
shape = detector_grp["data"].shape
del detector_grp["data"]
# create invalid VDS
-layout = h5py.VirtualLayout(shape=(1, 100), dtype="i4")
+layout = h5py.VirtualLayout(shape=shape, dtype="i4")
filename = "toto.h5"
-vsource = h5py.VirtualSource(filename, "data", shape=(100,))
-layout[0] = vsource
+vsource = h5py.VirtualSource(filename, "data", shape=shape)
+layout[0 : shape[0]] = vsource
detector_grp.create_virtual_dataset("data", layout)
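The updated test keeps the original dataset shape so that only the virtual-source link is broken, not the layout. A sketch of that setup, using a hypothetical master file and entry path:

```python
# Hedged sketch of what the updated test builds: a virtual dataset with the
# same shape as the original "data" dataset but whose single virtual source
# points at a file that does not exist, so only the VDS link is broken.
import h5py

with h5py.File("master_file.h5", mode="a") as h5f:  # hypothetical file name
    detector_grp = h5f["entry0000/instrument/detector"]  # hypothetical entry
    shape = detector_grp["data"].shape
    del detector_grp["data"]

    layout = h5py.VirtualLayout(shape=shape, dtype="i4")
    vsource = h5py.VirtualSource("toto.h5", "data", shape=shape)  # missing file
    layout[0 : shape[0]] = vsource
    detector_grp.create_virtual_dataset("data", layout)
```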
@@ -159,10 +159,15 @@ class _VdsAndValuesValidatorMixIn:
def check_vds(self):
return self._check_vds
-def check_urls(self, urls):
+def check_urls(self, urls: dict):
if urls is None:
return True
_, compacted_urls = get_compacted_dataslices(urls, return_url_set=True)
if self.check_vds:
# compact urls to speed up
-for url in urls:
+for _, url in compacted_urls.items():
if dataset_has_broken_vds(url=url):
self._vds_ok = False
return False
@@ -171,7 +176,7 @@ class _VdsAndValuesValidatorMixIn:
if self.check_values:
self._no_nan = True
-for url in urls:
+for _, url in compacted_urls.items():
data = get_data(url)
self._no_nan = self._no_nan and not numpy.isnan(data).any()
return self._no_nan
@@ -220,7 +225,7 @@ class DarkDatasetValidator(DarkEntryValidator, _VdsAndValuesValidatorMixIn):
if self._has_data is False:
return False
-return _VdsAndValuesValidatorMixIn.check_urls(self, self.scan.darks.values())
+return _VdsAndValuesValidatorMixIn.check_urls(self, self.scan.darks)
def info(self, with_scan=True):
return _VdsAndValuesValidatorMixIn.info(self, with_scan)
@@ -255,7 +260,7 @@ class FlatDatasetValidator(FlatEntryValidator, _VdsAndValuesValidatorMixIn):
if self._has_data is False:
return False
-return _VdsAndValuesValidatorMixIn.check_urls(self, self.scan.flats.values())
+return _VdsAndValuesValidatorMixIn.check_urls(self, self.scan.flats)
def info(self, with_scan=True):
return _VdsAndValuesValidatorMixIn.info(self, with_scan)
@@ -292,9 +297,7 @@ class ProjectionDatasetValidator(ProjectionEntryValidator, _VdsAndValuesValidato
if self._has_data is False:
return False
-return _VdsAndValuesValidatorMixIn.check_urls(
-self, self.scan.projections.values()
-)
+return _VdsAndValuesValidatorMixIn.check_urls(self, self.scan.projections)
def info(self, with_scan=True):
return _VdsAndValuesValidatorMixIn.info(self, with_scan)
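Taken together, the validator-side changes boil down to running both the VDS check and the NaN check over the compacted url set instead of over every frame url. A condensed sketch of that flow, assuming `get_compacted_dataslices` and `dataset_has_broken_vds` are importable from `tomoscan.esrf.utils` (paths may differ between versions) and using a hypothetical `frame_urls` dict:

```python
import numpy
from silx.io.utils import get_data
# module path may differ between tomoscan versions
from tomoscan.esrf.utils import dataset_has_broken_vds, get_compacted_dataslices

def urls_are_valid(frame_urls: dict, check_vds=True, check_values=True) -> bool:
    """Return False if any compacted url has a broken VDS or contains NaN."""
    if frame_urls is None:
        return True
    _, compacted_urls = get_compacted_dataslices(frame_urls, return_url_set=True)
    if check_vds:
        for url in compacted_urls.values():
            if dataset_has_broken_vds(url=url):
                return False
    if check_values:
        for url in compacted_urls.values():
            if numpy.isnan(get_data(url)).any():
                return False
    return True
```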