Commit e8b82ce9 authored by Matias Guijarro

Merge branch '2185-protect-metadata-gathering-from-parallel-scans' into 'master'

Resolve "Protect metadata gathering from parallel scans"

Closes #2185

See merge request !2995
parents 0142f745 0c3dc13b
Pipeline #36672 passed with stages in 87 minutes and 36 seconds
@@ -13,7 +13,7 @@ import typing
 from bliss.common.counter import Counter
 from bliss.common.axis import Axis
 from bliss.data.nodes.scan import get_data_from_nodes
-from bliss.data.node import _get_or_create_node
+from bliss.data.node import get_or_create_node
 from bliss.common.utils import get_matching_names
@@ -94,7 +94,7 @@ def watch_session_scans(
     watch_scan_group: If True the scan groups are also listed like any other
                       scans
     """
-    session_node = _get_or_create_node(session_name, node_type="session")
+    session_node = get_or_create_node(session_name, node_type="session")
     if session_node is None:
         return
     running_scans = dict()
......
@@ -93,24 +93,37 @@ class Dataset(DataPolicyObject):
         super().__init__(node)
         self.definitions = Definitions()
 
-    def gather_metadata(self):
-        """Initialize the dataset node info"""
+    def gather_metadata(self, on_exists=None):
+        """Initialize the dataset node info.
+        When metadata already exists in Redis:
+            on_exists="skip": do nothing
+            on_exists="overwrite": overwrite in Redis
+            else: raise RuntimeError
+        """
         if self.is_closed:
             raise RuntimeError("The dataset is already closed")
-        if self._node.info.get("__metadata_gathered__"):
-            raise RuntimeError("metadata for this dataset has already been collected!")
+        if self.metadata_gathering_done:
+            if on_exists == "skip":
+                return
+            elif on_exists == "overwrite":
+                pass
+            else:
+                raise RuntimeError("Metadata gathering already done")
+
+        # Gather metadata
         if current_session.icat_mapping:
-            metadata = current_session.icat_mapping.get_metadata()
+            infodict = current_session.icat_mapping.get_metadata()
         else:
-            metadata = dict()
+            infodict = dict()
-        metadata["startDate"] = datetime.datetime.now().isoformat()
+        # Add additional metadata
+        infodict["startDate"] = datetime.datetime.now().isoformat()
 
-        assert isinstance(metadata, dict)
-        for k, v in metadata.items():
+        # Check metadata
+        for k, v in infodict.items():
             assert self.validate_fieldname(
                 k
             ), f"{k} is not an accepted key in this dataset!"
@@ -118,10 +131,11 @@ class Dataset(DataPolicyObject):
                 v, str
             ), f"{v} is not an accepted value for ICAT (only strings are allowed)!"
-        self._node.info["__closed__"] = False
-        self._node.info.update(metadata)
+        # Add other info keys (not metadata)
+        infodict["__metadata_gathered__"] = True
-        self._node.info["__metadata_gathered__"] = True
+        # Update the node's info
+        self._node.info.update(infodict)
 
     @property
     def metadata_gathering_done(self):
......
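The net effect of the `Dataset` changes above is that `gather_metadata()` becomes idempotent on demand. Below is a minimal standalone sketch of the new `on_exists` contract; it is not the BLISS class (a plain dict stands in for the Redis-backed `node.info`, and `DatasetSketch` is an illustrative name):

```python
import datetime

class DatasetSketch:
    """Toy model of the on_exists contract added in this MR."""

    def __init__(self):
        self.info = {}  # stands in for the Redis-backed node.info

    @property
    def metadata_gathering_done(self):
        return self.info.get("__metadata_gathered__", False)

    def gather_metadata(self, on_exists=None):
        if self.metadata_gathering_done:
            if on_exists == "skip":
                return  # no-op instead of raising
            elif on_exists == "overwrite":
                pass    # fall through and overwrite
            else:
                raise RuntimeError("Metadata gathering already done")
        infodict = {"startDate": datetime.datetime.now().isoformat()}
        infodict["__metadata_gathered__"] = True
        self.info.update(infodict)

ds = DatasetSketch()
ds.gather_metadata()                  # first call gathers the metadata
ds.gather_metadata(on_exists="skip")  # repeat call is a harmless no-op
```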
 from pprint import pprint
-from bliss.data.node import get_node
+from bliss.data.node import get_session_node
 
 
 def demo_listener(session_name):
-    n = get_node(session_name)
+    n = get_session_node(session_name)
     for dataset in n.walk(wait=False, filter="dataset"):
         if dataset.is_closed:
             print(f"dataset {dataset.db_name} is closed the collected metadata is:")
......
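For reference, a condensed runnable variant of the listener demo above. Everything it touches (`get_session_node`, `walk(wait=False, filter="dataset")`, `is_closed`, `db_name`, the `"startDate"` info key) appears elsewhere in this diff, but the choice of printed fields is purely illustrative:

```python
from bliss.data.node import get_session_node

def print_closed_datasets(session_name):
    # Walk the dataset nodes currently under the session (wait=False: do not block)
    session_node = get_session_node(session_name)
    for dataset in session_node.walk(wait=False, filter="dataset"):
        if dataset.is_closed:
            # "startDate" is one of the keys written by Dataset.gather_metadata()
            print(dataset.db_name, dataset.info.get("startDate"))
```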
@@ -19,14 +19,14 @@ from bliss.scanning.scan import Scan
 from bliss.data.nodes.scan import ScanNode
 from bliss.data.node import get_session_node
 from bliss.scanning.scan import ScanState, ScanPreset
-from bliss.data.node import _create_node
+from bliss.data.node import create_node
 from bliss import current_session
 from bliss.common.logtools import user_warning
 
 
 class ScanGroup(Scan):
     def _create_data_node(self, node_name):
-        self._Scan__node = _create_node(
+        self._Scan__node = create_node(
             node_name, "scan_group", parent=self.root_node, info=self._scan_info
         )
......
@@ -33,7 +33,7 @@ from bliss.common.utils import Null, update_node_info, round
 from bliss.common.profiling import Statistics, time_profile
 from bliss.controllers.motor import Controller
 from bliss.config.settings_cache import CacheConnection
-from bliss.data.node import _get_or_create_node, _create_node
+from bliss.data.node import get_or_create_node, create_node
 from bliss.data.scan import get_data
 from bliss.scanning.chain import AcquisitionSlave, AcquisitionMaster, StopChain
 from bliss.scanning.writer.null import Writer as NullWriter
@@ -791,9 +791,9 @@ class Scan:
         Important: has to be a method, since it can be overwritten in Scan subclasses (like Sequence)
         """
-        self.__node = _create_node(
+        self.__node = create_node(
             node_name,
-            "scan",
+            node_type="scan",
             parent=self.root_node,
             info=self._scan_info,
             connection=self._cache_cnx,
@@ -1216,10 +1216,10 @@ class Scan:
     def _prepare_channels(self, channels, parent_node):
         for channel in channels:
             chan_name = channel.short_name
-            channel_node = _get_or_create_node(
+            channel_node = get_or_create_node(
                 chan_name,
-                channel.data_node_type,
-                parent_node,
+                node_type=channel.data_node_type,
+                parent=parent_node,
                 shape=channel.shape,
                 dtype=channel.dtype,
                 unit=channel.unit,
@@ -1248,7 +1248,7 @@ class Scan:
         else:
             parent_node = self.nodes[dev_node.bpointer]
         if isinstance(dev, (AcquisitionSlave, AcquisitionMaster)):
-            data_container_node = _create_node(
+            data_container_node = create_node(
                 dev.name, parent=parent_node, connection=self._cache_cnx
             )
             self._cache_cnx.add_prefetch(data_container_node)
......
@@ -24,7 +24,7 @@ import enum
 from bliss import current_session
 from bliss.config.settings import ParametersWardrobe
 from bliss.config.settings_cache import get_redis_client_cache
-from bliss.data.node import _get_node, _get_or_create_node
+from bliss.data.node import datanode_factory
 from bliss.scanning.writer.null import Writer as NullWriter
 from bliss.scanning import writer as writer_module
 from bliss.common.proxy import Proxy
@@ -735,11 +735,13 @@ class BasicScanSaving(EvalParametersWardrobe):
         node = None
         if create:
             for item_name, node_type in db_path_items:
-                node = _get_or_create_node(item_name, node_type, parent=node)
+                node = datanode_factory(
+                    item_name, node_type, parent=node, on_not_state="create"
+                )
                 self._fill_node_info(node, node_type)
         else:
             for item_name, node_type in db_path_items:
-                node = _get_node(item_name, node_type, parent=node)
+                node = datanode_factory(item_name, parent=node, on_not_state=None)
         if node is None:
             return None
         return node
@@ -959,6 +961,7 @@ class ESRFScanSaving(BasicScanSaving):
     def _db_path_keys(self, eval_dict=None):
         session = self.session
         base_path = self.get_cached_property("base_path", eval_dict).split(os.sep)
+        base_path = [p for p in base_path if p]
         proposal = self.get_cached_property("proposal_name", eval_dict)
         sample = self.get_cached_property("sample_name", eval_dict)
         # When dataset="0001" the DataNode.name will be the integer 1
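The added filter guards against empty path components: on POSIX, splitting an absolute path on `os.sep` yields a leading empty string, which would otherwise turn into an empty node name in the Redis path. A quick illustration (`"/data/visitor"` is just an example `base_path`; the snippet assumes a POSIX `os.sep`):

```python
import os

base_path = "/data/visitor"      # illustrative value
parts = base_path.split(os.sep)  # ['', 'data', 'visitor'] on POSIX
parts = [p for p in parts if p]  # drop the empty components
assert parts == ["data", "visitor"]
```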
@@ -1540,5 +1543,4 @@ class ESRFScanSaving(BasicScanSaving):
         dataset = self.dataset  # Created in Redis when missing
         if dataset.is_closed:
             raise RuntimeError("Dataset is already closed (choose a different name)")
-        if not dataset.metadata_gathering_done:
-            dataset.gather_metadata()
+        dataset.gather_metadata(on_exists="skip")
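This replacement is the heart of the fix. The old check-then-act pair could interleave between greenlets running parallel scans: both see `metadata_gathering_done == False`, the first one gathers, and the second call then raises. Moving the decision inside `gather_metadata(on_exists="skip")` makes concurrent callers no-ops. A toy gevent reproduction of the shape of that race follows; it is not the BLISS implementation, and the Redis round-trip is modelled by a `gevent.sleep(0)` switch point:

```python
import gevent

class ToyDataset:
    def __init__(self):
        self._gathered = False

    @property
    def metadata_gathering_done(self):
        gevent.sleep(0)  # Redis read: a point where greenlets can switch
        return self._gathered

    def gather_metadata(self, on_exists=None):
        if self.metadata_gathering_done:
            if on_exists == "skip":
                return  # new behaviour: concurrent repeats are no-ops
            raise RuntimeError("Metadata gathering already done")
        self._gathered = True

dataset = ToyDataset()

def scan_run():
    # Old racy pattern was:
    #     if not dataset.metadata_gathering_done:
    #         dataset.gather_metadata()
    # which can raise when another greenlet gathers between check and call.
    dataset.gather_metadata(on_exists="skip")

glts = [gevent.spawn(scan_run) for _ in range(100)]
gevent.joinall(glts, raise_error=True, timeout=10)  # no greenlet raises
```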
@@ -21,8 +21,7 @@ import logging
 import traceback
 from gevent.time import time
 from contextlib import contextmanager
-from bliss.data.node import get_node as _get_node
-from bliss.data.node import _get_node_object
+from bliss.data.node import datanode_factory
 from bliss.config.streaming import DataStreamReaderStopHandler
 from ..utils.logging_utils import CustomLogger
 from ..io import io_utils
@@ -51,12 +50,11 @@ class PeriodicTask(object):
 def get_node(node_type, db_name):
     """
-    Get DataNode instance event if the Redis node does not exist yet
+    Get DataNode instance even if the Redis node does not exist yet
     """
-    node = _get_node(db_name)
-    if node is None:
-        node = _get_node_object(node_type, db_name, None, None)
-    return node
+    return datanode_factory(
+        db_name, node_type=node_type, state="exists", on_not_state="instantiate"
+    )
 
 
 class BaseSubscriber(object):
......
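Several call sites above converge on the new `datanode_factory` entry point. Its actual implementation lives in `bliss/data/node.py` and is not part of this excerpt; the sketch below only illustrates the dispatch the call sites imply (an existing node is returned as-is; `on_not_state="create"` persists a new node, `"instantiate"` builds an in-memory instance, and `None` returns `None`), with a dict standing in for Redis:

```python
# Hypothetical sketch inferred from the call sites in this diff;
# NOT the real bliss.data.node.datanode_factory.
_FAKE_REDIS = {}

class FakeNode:
    def __init__(self, name, node_type=None, parent=None):
        self.name, self.node_type, self.parent = name, node_type, parent

def datanode_factory(name, node_type=None, parent=None,
                     state="exists", on_not_state="create"):
    node = _FAKE_REDIS.get(name)       # state="exists": prefer the stored node
    if node is not None:
        return node
    if on_not_state == "create":       # create and persist the node
        node = FakeNode(name, node_type, parent)
        _FAKE_REDIS[name] = node
        return node
    if on_not_state == "instantiate":  # in-memory instance, nothing persisted
        return FakeNode(name, node_type, parent)
    return None                        # on_not_state=None: missing node -> None

mem_only = datanode_factory("tmp:scan", node_type="scan",
                            on_not_state="instantiate")
assert "tmp:scan" not in _FAKE_REDIS   # nothing was persisted
```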
@@ -1149,3 +1149,10 @@ def test_electronic_logbook(session, icat_logbook_subscriber, esrf_data_policy):
         category=category,
         scan_saving=scan_saving,
     )
+
+
+def test_parallel_scans(session, esrf_data_policy):
+    glts = [
+        gevent.spawn(session.scan_saving.clone().on_scan_run, True) for _ in range(100)
+    ]
+    gevent.joinall(glts, raise_error=True, timeout=10)
@@ -26,8 +26,8 @@ from bliss.data.node import (
     get_node,
     DataNode,
     DataNodeContainer,
-    _get_or_create_node,
-    _get_node_object,
+    get_or_create_node,
+    datanode_factory,
     sessions_list,
     get_last_saved_scan,
 )
@@ -359,7 +359,7 @@ def test_iterator_over_reference_with_lima(redis_data_conn, lima_session, with_roi):
     if with_roi:
         lima_sim.roi_counters["myroi"] = [0, 0, 1, 1]
 
-    session_node = _get_or_create_node(lima_session.name, node_type="session")
+    session_node = get_or_create_node(lima_session.name, node_type="session")
 
     with gevent.Timeout(10 + 2 * (npoints + 1) * exp_time):
@@ -779,7 +779,9 @@ def _count_node_events(
     def walk():
         """Stops walking if no event has been received for x seconds
         """
-        node = _get_node_object(node_type, db_name, None, None)
+        node = datanode_factory(
+            db_name, node_type=node_type, on_not_state="instantiate"
+        )
         startlistening_event.set()
         if count_nodes:
             evgen = node.walk(filter=filter, wait=wait)
......
@@ -12,7 +12,7 @@ import re
 from bliss.scanning.group import Sequence, Group
 from bliss.common import scans
-from bliss.data.node import get_node, _get_or_create_node
+from bliss.data.node import get_node, get_or_create_node
 from bliss.data.nodes.node_ref_channel import NodeRefChannel
 from bliss.data.nodes.scan import ScanNode
 from bliss.scanning.chain import AcquisitionChannel
@@ -221,7 +221,7 @@ def test_sequence_invalid_group(session):
     with pytest.raises(RuntimeError):
         g = Group(s1, s2)
 
-    n = _get_or_create_node("bla:bla:bla")
+    n = get_or_create_node("bla:bla:bla")
     with pytest.raises(RuntimeError):
         g = Group(s1, n)
......