Commit 65fee165 authored by Valentin Valls
Browse files

Create a PreparedScanEvent

parent 1337cd66
......@@ -333,6 +333,11 @@ class StartEvent(TimeEvent):
TYPE = b"START"
class PreparedEvent(TimeEvent):
    # Generic stream event emitted when an entity reaches its "prepared"
    # state; payload/timestamp handling comes from TimeEvent.
    # NOTE(review): base-class behavior not visible here — confirm in TimeEvent.
    TYPE = b"PREPARED"
class EndEvent(TimeEvent):
    # Generic stream event emitted when an entity ends; payload/timestamp
    # handling comes from TimeEvent.
    TYPE = b"END"
......
......@@ -8,7 +8,7 @@
from bliss.config import streaming_events
# Public API of this module. The previous single-name assignment was a
# stale pre-change line left by the stripped diff; only this one is live.
__all__ = ["EndScanEvent", "PreparedScanEvent"]
class EndScanEvent(streaming_events.EndEvent):
......@@ -28,3 +28,22 @@ class EndScanEvent(streaming_events.EndEvent):
def description(self):
"""Used to generate EventData description"""
return self.exception
class PreparedScanEvent(streaming_events.PreparedEvent):
    """Stream event published when a scan reaches its PREPARED state."""

    TYPE = b"PREPARED_SCAN"

    @classmethod
    def merge(cls, events):
        """Collapse a list of raw events into a single one, keeping only
        the first occurrence.

        :param list((index, raw)) events:
        :returns PreparedScanEvent:
        """
        _index, first_raw = events[0]
        return cls(raw=first_raw)

    @property
    def description(self):
        """Used to generate EventData description"""
        return None
......@@ -21,6 +21,7 @@ class EventType(Enum):
NEW_NODE = 1
NEW_DATA = 2
END_SCAN = 3
PREPARED_SCAN = 4
class EventData(NamedTuple):
......
......@@ -41,6 +41,7 @@ A ScanNode is represented by 4 Redis keys:
{db_name}_info -> see DataNodeContainer
{db_name}_children -> see DataNodeContainer
{db_name}_end -> contains the END event
{db_name}_prepared -> contains the PREPARED event
A ChannelDataNode is represented by 3 Redis keys:
......@@ -1407,6 +1408,7 @@ class DataNode(metaclass=DataNodeMetaClass):
parent = self.parent
if parent is None:
return None
# Higher priority than PREPARED scan
children_stream = parent._create_stream("children_list")
self_db_name = self.db_name
for index, raw in children_stream.rev_range():
......@@ -1653,6 +1655,7 @@ class DataNodeContainer(DataNode):
last_node = None
active_streams = dict()
excluded_stream_names = set()
# Higher priority than PREPARED scan
children_stream = self._create_stream("children_list")
first_index = children_stream.before_last_index()
if first_index is None:
......
......@@ -20,7 +20,7 @@ class _ChannelDataNodeBase(DataNode):
def __init__(self, name, **kwargs):
    """Create the channel data node and its Redis data stream.

    Fix: the stripped diff left both the pre-change and post-change
    `_create_stream("data", ...)` calls, creating the stream twice and
    discarding the first handle; only the priority=2 version is kept.
    """
    super().__init__(self._NODE_TYPE, name, **kwargs)
    # priority=2: data events are consumed after node-creation events.
    # NOTE(review): exact priority semantics defined by _create_stream — confirm.
    self._queue = self._create_stream("data", maxlen=CHANNEL_MAX_LEN, priority=2)
    self._last_index = self._idx_to_streamid(0)
@classmethod
......
......@@ -10,26 +10,48 @@ from bliss.data.node import DataNodeContainer
from bliss.data.nodes.channel import ChannelDataNode
from bliss.data.nodes.lima import LimaImageChannelDataNode
from bliss.config.streaming_events import StreamEvent
from bliss.data.events import Event, EventType, EventData, EndScanEvent
from bliss.data.events import (
Event,
EventType,
EventData,
EndScanEvent,
PreparedScanEvent,
)
from bliss.config import settings
class ScanNode(DataNodeContainer):
    """Data node representing one scan; owns the scan-lifecycle Redis
    streams ("end" and "prepared") in addition to the container streams.

    Fix: the stripped diff left stale pre-change lines — a single-entry
    `_EVENT_TYPE_MAPPING` and an `_end_stream` with priority=1 — that were
    immediately overwritten; only the post-change versions are kept.
    """

    _NODE_TYPE = "scan"

    # Mapping from event name to EventType
    _EVENT_TYPE_MAPPING = {
        EndScanEvent.TYPE.decode("ascii"): EventType.END_SCAN,
        PreparedScanEvent.TYPE.decode("ascii"): EventType.PREPARED_SCAN,
    }

    def __init__(self, name, **kwargs):
        super().__init__(self._NODE_TYPE, name, **kwargs)
        # Lower priority than all other streams
        self._end_stream = self._create_stream("end", priority=3)
        # Lower priority than NEW_NODE, higher than NEW_DATA
        self._prepared_stream = self._create_stream("prepared", priority=1)
@property
def dataset(self):
    """The parent container of a scan node.

    NOTE(review): named `dataset` because a scan's parent is presumably
    the dataset node — confirm against the node hierarchy.
    """
    return self.parent
def prepared(self):
    """Publish PREPARED event in Redis
    """
    if not self.new_node:
        return
    # Group the stream append with the info setting in one Redis pipeline
    # so listeners do not see several separate modification events.
    with settings.pipeline(self._prepared_stream, self._info):
        self._prepared_stream.add_event(PreparedScanEvent())
def end(self, exception=None):
"""Publish END event in Redis
"""
......@@ -47,15 +69,6 @@ class ScanNode(DataNodeContainer):
self._info.update(add_info)
self._end_stream.add_event(event)
def _get_event_class(self, stream_event):
    """Return the event class matching a raw stream event.

    :param stream_event: sequence whose first item is an ``(index, raw)``
        pair; ``raw[b"__EVENT__"]`` holds the event kind.
    :returns: the matching class from ``self._SUPPORTED_EVENTS``
    :raises RuntimeError: when the kind is not supported
    """
    first = stream_event[0]
    kind = first[1][b"__EVENT__"]
    # FIXME: Use dict instead of iteration
    for event_class in self._SUPPORTED_EVENTS:
        if event_class.TYPE == kind:
            return event_class
    # Bug fix: the message was previously passed logging-style
    # ("…%s", kind) so the kind was never interpolated.
    raise RuntimeError("Unsupported event kind %s" % kind)
def decode_raw_events(self, events):
"""Decode raw stream data
......@@ -76,10 +89,11 @@ class ScanNode(DataNodeContainer):
def get_db_names(self, **kw):
    """Return the Redis key names of this node, including the scan's
    end and prepared event streams."""
    names = super().get_db_names(**kw)
    names.extend((self._end_stream.name, self._prepared_stream.name))
    return names
def get_settings(self):
    """Return all settings objects of this node, including both
    scan-lifecycle streams.

    Fix: the stripped diff left the stale pre-change `return` (without
    `_prepared_stream`) before the new one, making the new line
    unreachable; only the post-change return is kept.
    """
    return super().get_settings() + [self._end_stream, self._prepared_stream]
def _subscribe_streams(self, reader, first_index=None, **kw):
"""Subscribe to all associated streams of this node.
......@@ -92,6 +106,10 @@ class ScanNode(DataNodeContainer):
self._subscribe_stream(
suffix, reader, first_index=0, create=True, ignore_excluded=True
)
suffix = self._prepared_stream.name.rsplit("_", 1)[-1]
self._subscribe_stream(
suffix, reader, first_index=0, create=True, ignore_excluded=True
)
def get_stream_event_handler(self, stream):
    """Return the handler used to decode events of the given stream."""
    # Both scan-lifecycle streams (end, prepared) share one handler.
    lifecycle_streams = (self._end_stream.name, self._prepared_stream.name)
    if stream.name in lifecycle_streams:
        return self._iter_data_stream_events
    return super(ScanNode, self).get_stream_event_handler(stream)
def _iter_data_stream_events(
......@@ -130,9 +150,10 @@ class ScanNode(DataNodeContainer):
event_id = self._EVENT_TYPE_MAPPING[kind]
event = Event(type=event_id, node=self, data=data)
yield event
# Stop reading events from this node's streams
# and the streams of its children
reader.remove_matching_streams(f"{self.db_name}*")
if event_id is EventType.END_SCAN:
# Stop reading events from this node's streams
# and the streams of its children
reader.remove_matching_streams(f"{self.db_name}*")
def get_data_from_nodes(pipeline, *nodes):
......
......@@ -1266,6 +1266,8 @@ class Scan:
self._prepare_devices(devices_tree)
self.writer.prepare(self)
self.node.prepared()
self._axes_in_scan = self._get_data_axes(include_calc_reals=True)
with execute_pre_scan_hooks(self._axes_in_scan):
pass
......
......@@ -490,7 +490,7 @@ def test_walk_after_nodes_disappeared(session):
# Validate counting when all nodes are still present
nroot = len(session.scan_saving._db_path_keys) - 1
nnodes = nroot + 6 # + scan, master, epoch, elapsed, controller, diode
nevents = nnodes + 4 # + 3 x data + 1 x end
nevents = nnodes + 5 # + 3 x data + 1 x end + 1 x prepared
validate_count(nnodes, nevents)
# Scan incomplete
......@@ -502,7 +502,7 @@ def test_walk_after_nodes_disappeared(session):
names = list(s.node.search_redis(s.node.db_name + "*"))
nnodes = nroot + 1
nevents = nnodes + 1
nevents = nnodes + 2
validate_count(nnodes, nevents)
# Scan missing
......@@ -856,9 +856,10 @@ def test_walk_events_on_session_node(beforestart, wait, include_filter, session)
beforestart, session, session.name, include_filter=include_filter, wait=wait
)
if include_filter == "scan":
assert set(events.keys()) == {"NEW_NODE", "END_SCAN"}
assert set(events.keys()) == {"NEW_NODE", "END_SCAN", "PREPARED_SCAN"}
assert len(events["NEW_NODE"]) == 1
assert len(events["END_SCAN"]) == 1
assert len(events["PREPARED_SCAN"]) == 1
elif include_filter == "channel":
# New node events: epoch, elapsed_time, n x detector
assert set(events.keys()) == {"NEW_NODE", "NEW_DATA"}
......@@ -872,12 +873,18 @@ def test_walk_events_on_session_node(beforestart, wait, include_filter, session)
else:
# New node events: root nodes, scan, scan master (timer),
# epoch, elapsed_time, n x (controller, detector)
assert set(events.keys()) == {"NEW_NODE", "NEW_DATA", "END_SCAN"}
assert set(events.keys()) == {
"NEW_NODE",
"NEW_DATA",
"END_SCAN",
"PREPARED_SCAN",
}
# One less because the NEW_NODE event for session.name is
# not emitted on node session.name
nroot = len(session.scan_saving._db_path_keys) - 1
assert len(events["NEW_NODE"]) == nroot + 2 + nmasters + 2 * nchannels
assert len(events["NEW_DATA"]) == nmasters + nchannels
assert len(events["PREPARED_SCAN"]) == 1
assert len(events["END_SCAN"]) == 1
......@@ -929,9 +936,10 @@ def test_walk_events_on_dataset_node(beforestart, wait, include_filter, session)
)
if include_filter == "scan":
# New node events: scan
assert set(events.keys()) == {"NEW_NODE", "END_SCAN"}
assert set(events.keys()) == {"NEW_NODE", "END_SCAN", "PREPARED_SCAN"}
assert len(events["NEW_NODE"]) == 1
assert len(events["END_SCAN"]) == 1
assert len(events["PREPARED_SCAN"]) == 1
elif include_filter == "channel":
# New node events: epoch, elapsed_time, n x detector
assert set(events.keys()) == {"NEW_NODE", "NEW_DATA"}
......@@ -945,10 +953,16 @@ def test_walk_events_on_dataset_node(beforestart, wait, include_filter, session)
else:
# New node events: dataset, scan master (timer), epoch,
# elapsed_time, n x (controller, detector)
assert set(events.keys()) == {"NEW_NODE", "NEW_DATA", "END_SCAN"}
assert set(events.keys()) == {
"NEW_NODE",
"NEW_DATA",
"END_SCAN",
"PREPARED_SCAN",
}
assert len(events["NEW_NODE"]) == 2 + nmasters + 2 * nchannels
assert len(events["NEW_DATA"]) == nmasters + nchannels
assert len(events["END_SCAN"]) == 1
assert len(events["PREPARED_SCAN"]) == 1
@pytest.mark.parametrize("beforestart, wait, include_filter", _count_parameters)
......@@ -989,8 +1003,9 @@ def test_walk_events_on_scan_node(beforestart, wait, include_filter, session):
wait=wait,
)
if include_filter == "scan":
assert set(events.keys()) == {"END_SCAN"}
assert set(events.keys()) == {"END_SCAN", "PREPARED_SCAN"}
assert len(events["END_SCAN"]) == 1
assert len(events["PREPARED_SCAN"]) == 1
elif include_filter == "channel":
# New node events: epoch, elapsed_time, n x detector
assert set(events.keys()) == {"NEW_NODE", "NEW_DATA"}
......@@ -1004,10 +1019,16 @@ def test_walk_events_on_scan_node(beforestart, wait, include_filter, session):
else:
# New node events: scan master (timer), epoch, elapsed_time,
# n x (controller, detector)
assert set(events.keys()) == {"NEW_NODE", "NEW_DATA", "END_SCAN"}
assert set(events.keys()) == {
"NEW_NODE",
"NEW_DATA",
"END_SCAN",
"PREPARED_SCAN",
}
assert len(events["NEW_NODE"]) == 1 + nmasters + 2 * nchannels
assert len(events["NEW_DATA"]) == nmasters + nchannels
assert len(events["END_SCAN"]) == 1
assert len(events["PREPARED_SCAN"]) == 1
@pytest.mark.parametrize("beforestart, wait, include_filter", _count_parameters)
......@@ -1439,7 +1460,8 @@ def test_filter_nodes(session):
nroot = len(session.scan_saving._db_path_keys)
keys_per_channel = 3
keys_per_container = 2
keys_per_scan = 4
streams_per_scan = 2
keys_per_scan = streams_per_scan + 3
containers_per_scan = 2 # master, controller
channels_per_scan = 3 # epoch, elapsed, diode
keys_per_scan += (
......@@ -1624,13 +1646,14 @@ def test_walk_events_filter(session):
nevents = nroot + nscans * nodes_per_scan # NEW_NODE
nevents += nscans * channels_per_scan # NEW_DATA
nevents += nscans # END_SCAN
nevents += nscans # PREPARED_SCAN
db_names = list(_filter_walk_get_nodes(session_node.walk_events, wait=False))
assert len(db_names) == nevents
assert set(scan_db_names).issubset(db_names)
# Walk all scan events:
kw = {"include_filter": "scan"}
nevents = 2 * nscans # NEW_NODE + END_SCAN
nevents = 3 * nscans # NEW_NODE + PREPARED + END_SCAN
for a, b in itertools.product((None, "scan"), (None, "scan")):
kw = {
"include_filter": "scan",
......
......@@ -263,7 +263,7 @@ def test_sequence_events(session):
seq_context.add(s2)
event_dump = list()
nexpectedevents = 70
nexpectedevents = 74
started_event = gevent.event.Event()
finished_event = gevent.event.Event()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment