Commit 497b120f authored by Matias Guijarro's avatar Matias Guijarro

Merge branch 'add-device-meta-to-scan_info' into 'master'

Emit a PrepareScanEvent

See merge request !3655
parents 747d511f c3d1a8f4
Pipeline #46265 failed with stages in 95 minutes and 26 seconds
......@@ -585,10 +585,9 @@ class DataStreamReader:
self._publish_synchro_event()
self._start_read_task()
def _read_active_streams(self, priority_threshold=None):
def _read_active_streams(self):
"""Get data from the active streams
:param int priority_threshold: read only from this priority or higher
:returns list(2-tuple): list((name, events))
name: name of the stream
events: list((index, raw)))
......@@ -599,14 +598,8 @@ class DataStreamReader:
streams_to_read = sorted(
self._active_streams.items(), key=lambda item: item[1]["priority"]
)
if priority_threshold is None:
streams_to_read = {k: v["first_index"] for k, v in streams_to_read}
else:
streams_to_read = {
k: v["first_index"]
for k, v in streams_to_read
if v["priority"] <= priority_threshold
}
streams_to_read = {k: v["first_index"] for k, v in streams_to_read}
# first_index: yield events with stream ID larger than this
# block=None: yield nothing when no events
# block=0: always yield something (no timeout)
......
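For illustration, a minimal sketch (not the actual DataStreamReader code) of the priority-ordered read that replaces the removed priority_threshold filter; streams with a smaller priority number are served first:

def read_active_streams(active_streams):
    # active_streams maps stream name -> {"priority": int, "first_index": ...}
    ordered = sorted(active_streams.items(), key=lambda item: item[1]["priority"])
    return {name: props["first_index"] for name, props in ordered}

streams = {
    "scan_end": {"priority": 3, "first_index": 0},
    "scan_prepared": {"priority": 1, "first_index": 0},
    "scan_children_list": {"priority": 0, "first_index": 0},
}
assert list(read_active_streams(streams)) == [
    "scan_children_list", "scan_prepared", "scan_end"
]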
......@@ -333,6 +333,11 @@ class StartEvent(TimeEvent):
TYPE = b"START"
class PreparedEvent(TimeEvent):
TYPE = b"PREPARED"
class EndEvent(TimeEvent):
TYPE = b"END"
......
......@@ -8,7 +8,7 @@
from bliss.config import streaming_events
__all__ = ["EndScanEvent"]
__all__ = ["EndScanEvent", "PreparedScanEvent"]
class EndScanEvent(streaming_events.EndEvent):
......@@ -23,3 +23,27 @@ class EndScanEvent(streaming_events.EndEvent):
:returns EndScanEvent:
"""
return cls(raw=events[0][1])
@property
def description(self):
"""Used to generate EventData description"""
return self.exception
class PreparedScanEvent(streaming_events.PreparedEvent):
TYPE = b"PREPARED_SCAN"
@classmethod
def merge(cls, events):
"""Keep only the first event.
:param list((index, raw)) events:
:returns PreparedScanEvent:
"""
return cls(raw=events[0][1])
@property
def description(self):
"""Used to generate EventData description"""
return None
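A sketch of the merge semantics above, assuming raw Redis entries shaped as (streamID, payload) tuples (the payload fields below are invented): a scan is prepared exactly once, so only the first payload is kept.

def merge_first(events):
    # events: list of (streamID, raw) tuples in stream order
    return events[0][1]

events = [
    (b"1600000000000-0", {b"type": b"PREPARED_SCAN"}),
    (b"1600000000001-0", {b"type": b"PREPARED_SCAN"}),  # duplicate, dropped
]
assert merge_first(events) == {b"type": b"PREPARED_SCAN"}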
......@@ -6,7 +6,7 @@
# Distributed under the GNU LGPLv3. See LICENSE for more info.
"""
Events returned recieved when walking a `DataNode`,
Events received when walking a `DataNode`,
derived from raw Redis stream events.
"""
......@@ -21,6 +21,7 @@ class EventType(Enum):
NEW_NODE = 1
NEW_DATA = 2
END_SCAN = 3
PREPARED_SCAN = 4
class EventData(NamedTuple):
......
......@@ -40,7 +40,8 @@ A ScanNode is represented by 5 Redis keys:
{db_name} -> see DataNodeContainer
{db_name}_info -> see DataNodeContainer
{db_name}_children -> see DataNodeContainer
{db_name}_data -> contains the END event
{db_name}_end -> contains the END event
{db_name}_prepared -> contains the PREPARED event
A ChannelDataNode is represented by 3 Redis keys:
......@@ -681,7 +682,7 @@ class DataNode(metaclass=DataNodeMetaClass):
"""
_TIMETOLIVE = 24 * 3600 # 1 day
VERSION = (1, 0) # change major version for incompatible API changes
VERSION = (1, 1) # change major version for incompatible API changes
@staticmethod
def _principal_db_name(name, parent=None):
......@@ -724,6 +725,9 @@ class DataNode(metaclass=DataNodeMetaClass):
self.__db_name = db_name
self.node_type = node_type
self._priorities = {}
"""Hold priorities per streams."""
# The info dictionary associated to the DataNode
self._info = settings.HashObjSetting(f"{db_name}_info", connection=connection)
info_dict = self._init_info(create=create, **kwargs)
......@@ -740,6 +744,17 @@ class DataNode(metaclass=DataNodeMetaClass):
self._ttl_setter = None
self._struct = self._get_struct(db_name, connection=self.db_connection)
def _register_stream_priority(self, fullname: str, priority: int):
"""
Register the stream priority which will be used on the reader side.
:param str fullname: Full name of the stream
:param int priority: data from streams with a lower priority is never
yielded as long as higher priority streams have
data. Lower number means higher priority.
"""
self._priorities[fullname] = priority
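An assumed usage pattern (sketch only, mirroring what _ChannelDataNodeBase does further down in this diff): a DataNode subclass registers the priority of its stream at construction time.

class MyDataNode(DataNode):
    _NODE_TYPE = "mynode"  # hypothetical node type

    def __init__(self, name, **kwargs):
        super().__init__(self._NODE_TYPE, name, **kwargs)
        self._queue = self._create_stream("data")
        # events from this stream are yielded only when no stream with a
        # smaller priority number has pending events
        self._register_stream_priority(f"{self.db_name}_data", 2)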
def add_prefetch(self, async_proxy=None):
"""As long as caching on the proxy level exists in CachingRedisDbProxy,
we need to prefetch settings like this.
......@@ -815,7 +830,8 @@ class DataNode(metaclass=DataNodeMetaClass):
:param `**kw`: see `_create_nonassociated_stream`
:returns DataStream:
"""
return self._create_nonassociated_stream(f"{self.db_name}_{suffix}", **kw)
stream = self._create_nonassociated_stream(f"{self.db_name}_{suffix}", **kw)
return stream
@classmethod
def _streamid_to_idx(cls, streamID):
......@@ -1362,6 +1378,12 @@ class DataNode(metaclass=DataNodeMetaClass):
if not self.db_connection.exists(stream_name):
return
stream = self._create_nonassociated_stream(stream_name)
# Use the priority registered for this stream, if any
priority = self._priorities.get(stream.name)
if priority is not None:
kw["priority"] = priority
reader.add_streams(stream, node=self, first_index=first_index, **kw)
def _subscribe_streams(
......@@ -1390,6 +1412,7 @@ class DataNode(metaclass=DataNodeMetaClass):
parent = self.parent
if parent is None:
return None
# Higher priority than PREPARED scan
children_stream = parent._create_stream("children_list")
self_db_name = self.db_name
for index, raw in children_stream.rev_range():
......@@ -1535,31 +1558,32 @@ class DataNodeContainer(DataNode):
search_data_streams = False
# Subscribe to the streams of all children, not only the direct children.
# TODO: this assumes that all streams to subscribe too are called
# "*_children_list" and "*_data". Can be solved with DataNode
# derived class self-registration and each class adding
# stream suffixes and orders.
# TODO: this makes assumptions about the data nodes and their streams.
# Any change in streams (rename stream, add new streams, ...) will
# affect this code.
# Subscribe to streams found by a recursive search
nodes_with_data = list()
nodes_with_data = dict()
search_suffixes = {"data": ["data"], "end": ["prepared", "end"]}
nodes_with_children = list()
excluded_stream_names = set(reader.excluded_stream_names)
if search_data_streams:
# Make sure the NEW_NODE event always arrives before the NEW_DATA event:
# Make sure the NEW_NODE event always arrives before any other node event:
# - assume "...:parent_children_list" is created BEFORE "...parent:child_data"
# - search for *_children_list AFTER searching for *_data
# - subscribe to *_children_list BEFORE subscribing to *_data
node_names = self._search_nodes_with_streams(
"data", excluded_stream_names, include_parent=False
)
nodes_with_data = list(
self.get_filtered_nodes(
*node_names,
include_filter=include_filter,
recursive_exclude=exclude_children,
strict_recursive_exclude=False,
for suffix in search_suffixes:
node_names = self._search_nodes_with_streams(
suffix, excluded_stream_names, include_parent=False
)
nodes_with_data[suffix] = list(
self.get_filtered_nodes(
*node_names,
include_filter=include_filter,
recursive_exclude=exclude_children,
strict_recursive_exclude=False,
)
)
)
if not exclude_my_children:
node_names = self._search_nodes_with_streams(
"children_list", excluded_stream_names, include_parent=False
......@@ -1574,8 +1598,13 @@ class DataNodeContainer(DataNode):
# Subscribe to the streams that were searched
for node in nodes_with_children:
node._subscribe_stream("children_list", reader, first_index=first_index)
for node in nodes_with_data:
node._subscribe_stream("data", reader, first_index=first_index)
for search_suffix, nodes in nodes_with_data.items():
subscribe_suffixes = search_suffixes[search_suffix]
for node in nodes:
for subscribe_suffix in subscribe_suffixes:
node._subscribe_stream(
subscribe_suffix, reader, first_index=first_index
)
# Exclude searched Redis keys from further subscription attempts
reader.excluded_stream_names |= excluded_stream_names
......@@ -1636,6 +1665,7 @@ class DataNodeContainer(DataNode):
last_node = None
active_streams = dict()
excluded_stream_names = set()
# Higher priority than PREPARED scan
children_stream = self._create_stream("children_list")
first_index = children_stream.before_last_index()
if first_index is None:
......
......@@ -21,6 +21,7 @@ class _ChannelDataNodeBase(DataNode):
def __init__(self, name, **kwargs):
super().__init__(self._NODE_TYPE, name, **kwargs)
self._queue = self._create_stream("data", maxlen=CHANNEL_MAX_LEN)
self._register_stream_priority(f"{self.db_name}_data", 2)
self._last_index = self._idx_to_streamid(0)
@classmethod
......
......@@ -9,21 +9,50 @@ from bliss.common.greenlet_utils import AllowKill
from bliss.data.node import DataNodeContainer
from bliss.data.nodes.channel import ChannelDataNode
from bliss.data.nodes.lima import LimaImageChannelDataNode
from bliss.data.events import Event, EventType, EventData, EndScanEvent
from bliss.config.streaming_events import StreamEvent
from bliss.data.events import (
Event,
EventType,
EventData,
EndScanEvent,
PreparedScanEvent,
)
from bliss.config import settings
class ScanNode(DataNodeContainer):
_NODE_TYPE = "scan"
_EVENT_TYPE_MAPPING = {
EndScanEvent.TYPE.decode("ascii"): EventType.END_SCAN,
PreparedScanEvent.TYPE.decode("ascii"): EventType.PREPARED_SCAN,
}
"""Mapping from event name to EventType
"""
def __init__(self, name, **kwargs):
super().__init__(self._NODE_TYPE, name, **kwargs)
self._sync_stream = self._create_stream("data")
self._end_stream = self._create_stream("end")
self._prepared_stream = self._create_stream("prepared")
# Stream priorities (lower number = higher priority): NEW NODE > PREPARED > NEW DATA > END
self._register_stream_priority(self._end_stream.name, 3)
self._register_stream_priority(self._prepared_stream.name, 1)
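Together with the channel data priority of 2 registered in _ChannelDataNodeBase and a default of 0 for unregistered streams such as children_list (an assumption read from the reader side), the registrations above give this effective order:

PRIORITIES = {
    "children_list": 0,  # NEW_NODE
    "prepared": 1,       # PREPARED_SCAN
    "data": 2,           # NEW_DATA (see _ChannelDataNodeBase)
    "end": 3,            # END_SCAN
}
assert sorted(PRIORITIES, key=PRIORITIES.get) == [
    "children_list", "prepared", "data", "end"
]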
@property
def dataset(self):
return self.parent
def prepared(self):
"""Publish PREPARED event in Redis
"""
if not self.new_node:
return
# to avoid having multiple modification events
# TODO: what does the comment above mean?
with settings.pipeline(self._prepared_stream, self._info):
event = PreparedScanEvent()
self._prepared_stream.add_event(event)
def end(self, exception=None):
"""Publish END event in Redis
"""
......@@ -31,7 +60,7 @@ class ScanNode(DataNodeContainer):
return
# to avoid having multiple modification events
# TODO: what does the comment above mean?
with settings.pipeline(self._sync_stream, self._info):
with settings.pipeline(self._end_stream, self._info):
event = EndScanEvent()
add_info = {
"end_time": event.time,
......@@ -39,7 +68,7 @@ class ScanNode(DataNodeContainer):
"end_timestamp": event.timestamp,
}
self._info.update(add_info)
self._sync_stream.add_event(event)
self._end_stream.add_event(event)
def decode_raw_events(self, events):
"""Decode raw stream data
......@@ -49,31 +78,23 @@ class ScanNode(DataNodeContainer):
"""
if not events:
return None
first_index = self._streamid_to_idx(events[0][0])
ev = EndScanEvent.merge(events)
return EventData(
first_index=first_index, data=ev.TYPE.decode(), description=ev.exception
)
assert len(events) == 1 # otherwise events would be lost
event = events[0]
timestamp, raw_data = event
first_index = self._streamid_to_idx(timestamp)
ev = StreamEvent.factory(raw_data)
data = type(ev).TYPE.decode()
return EventData(first_index=first_index, data=data, description=ev.description)
def get_db_names(self, **kw):
db_names = super().get_db_names(**kw)
db_names.append(self.db_name + "_data")
db_names.append(self._end_stream.name)
db_names.append(self._prepared_stream.name)
return db_names
def get_settings(self):
return super().get_settings() + [self._sync_stream]
def _subscribe_stream(self, stream_suffix, reader, first_index=None, **kw):
"""Subscribe to a particular stream associated with this node.
:param str stream_suffix: stream to add is "{db_name}_{stream_suffix}"
:param DataStreamReader reader:
:param str or int first_index: Redis stream index (None is now)
"""
if stream_suffix == "data":
# Lower priority than all other streams
kw["priority"] = 1
super()._subscribe_stream(stream_suffix, reader, first_index=first_index, **kw)
return super().get_settings() + [self._end_stream, self._prepared_stream]
def _subscribe_streams(self, reader, first_index=None, **kw):
"""Subscribe to all associated streams of this node.
......@@ -82,10 +103,26 @@ class ScanNode(DataNodeContainer):
:param **kw: see DataNodeContainer
"""
super()._subscribe_streams(reader, first_index=first_index, **kw)
suffix = self._end_stream.name.rsplit("_", 1)[-1]
self._subscribe_stream(
suffix, reader, first_index=0, create=True, ignore_excluded=True
)
suffix = self._prepared_stream.name.rsplit("_", 1)[-1]
self._subscribe_stream(
"data", reader, first_index=0, create=True, ignore_excluded=True
suffix, reader, first_index=0, create=True, ignore_excluded=True
)
def get_stream_event_handler(self, stream):
"""
:param DataStream stream:
:returns callable:
"""
if stream.name in (self._end_stream.name, self._prepared_stream.name):
return self._iter_data_stream_events
return super().get_stream_event_handler(stream)
def _iter_data_stream_events(
self,
reader,
......@@ -104,15 +141,20 @@ class ScanNode(DataNodeContainer):
:param bool yield_events: yield Event or DataNode
:yields Event:
"""
data = self.decode_raw_events(events)
if data is None:
return
if yield_events and self._included(include_filter):
with AllowKill():
yield Event(type=EventType.END_SCAN, node=self, data=data)
# Stop reading events from this node's streams
# and the streams of its children
reader.remove_matching_streams(f"{self.db_name}*")
for event in events:
data = self.decode_raw_events([event])
if data is None:
return
if yield_events and self._included(include_filter):
with AllowKill():
kind = data.data
event_id = self._EVENT_TYPE_MAPPING[kind]
event = Event(type=event_id, node=self, data=data)
yield event
if event_id is EventType.END_SCAN:
# Stop reading events from this node's streams
# and the streams of its children
reader.remove_matching_streams(f"{self.db_name}*")
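For context, a hypothetical consumer of these events; walk_events is the usual DataNode iteration API, and session_node stands for any node obtained from the session (illustrative only):

for event in session_node.walk_events():
    if event.type == EventType.PREPARED_SCAN:
        print("scan prepared:", event.node.db_name)
    elif event.type == EventType.END_SCAN:
        print("scan ended:", event.node.db_name)
        break  # stop once the scan of interest has finished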
def get_data_from_nodes(pipeline, *nodes):
......
......@@ -105,9 +105,19 @@ class ScansObserver:
"""
pass
def on_scan_created(self, scan_db_name: str, scan_info: Dict):
"""
Called upon scan creation (devices are not yet prepared).
Arguments:
scan_db_name: Identifier of the scan
scan_info: Dictionary containing scan metadata
"""
pass
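A sketch of an observer written against the split lifecycle (method names come from ScansObserver above; the bodies are illustrative):

from typing import Dict

class LoggingObserver(ScansObserver):
    def on_scan_created(self, scan_db_name: str, scan_info: Dict):
        # NEW_NODE arrived: the scan exists, devices not yet prepared
        print("created:", scan_info.get("title", scan_db_name))

    def on_scan_started(self, scan_db_name: str, scan_info: Dict):
        # PREPARED_SCAN arrived: devices are prepared, acquisition starts
        print("started:", scan_db_name)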
def on_scan_started(self, scan_db_name: str, scan_info: Dict):
"""
Called upon scan start.
Called upon scan start (the devices have been prepared).
Arguments:
scan_db_name: Identifier of the scan
......@@ -336,13 +346,13 @@ class ScansWatcher:
# New scan was created
scan_info = node.info.get_all()
self._running_scans.add(db_name)
observer.on_scan_started(db_name, scan_info)
observer.on_scan_created(db_name, scan_info)
elif node_type == "scan_group":
if self._watch_scan_group:
# New scan was created
scan_info = node.info.get_all()
self._running_scans.add(db_name)
observer.on_scan_started(db_name, scan_info)
observer.on_scan_created(db_name, scan_info)
else:
scan_db_name = self._get_scan_db_name_from_child(db_name)
if scan_db_name is not None:
......@@ -403,6 +413,16 @@ class ScansWatcher:
)
except Exception:
sys.excepthook(*sys.exc_info())
elif event_type == EventType.PREPARED_SCAN:
node_type = node.type
if self._watch_scan_group or node_type == "scan":
db_name = node.db_name
if db_name in self._running_scans:
scan_info = node.info.get_all()
try:
observer.on_scan_started(db_name, scan_info)
except Exception:
sys.excepthook(*sys.exc_info())
elif event_type == EventType.END_SCAN:
node_type = node.type
if self._watch_scan_group or node_type == "scan":
......@@ -470,7 +490,7 @@ class DefaultScansObserver(ScansObserver):
"""
self._current_event = event
def on_scan_started(self, scan_db_name: str, scan_info: Dict):
def on_scan_created(self, scan_db_name: str, scan_info: Dict):
# Pre-compute mapping from each channels to its master
top_master_per_channels = {}
for top_master, meta in scan_info["acquisition_chain"].items():
......
......@@ -216,9 +216,9 @@ class ScanManager(bliss_scan.ScansObserver):
managed)."""
return scan_db_name in self.__cache
def on_scan_started(self, scan_db_name: str, scan_info: Dict):
def on_scan_created(self, scan_db_name: str, scan_info: Dict):
_logger.info("Scan started: %s", scan_info.get("title", scan_db_name))
_logger.debug("on_scan_started %s", scan_db_name)
_logger.debug("on_scan_created %s", scan_db_name)
if scan_db_name in self.__cache:
# We should receive a single new_scan per scan, but let's check anyway
_logger.debug("new_scan from %s ignored", scan_db_name)
......
......@@ -1266,6 +1266,8 @@ class Scan:
self._prepare_devices(devices_tree)
self.writer.prepare(self)
self.node.prepared()
self._axes_in_scan = self._get_data_axes(include_calc_reals=True)
with execute_pre_scan_hooks(self._axes_in_scan):
pass
......
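The publication order this hunk wires up, as a self-contained stub (FakeScanNode is invented for illustration; the real calls live on ScanNode):

class FakeScanNode:
    def __init__(self):
        self.events = ["NEW_NODE"]      # emitted when the node is created
    def prepared(self):
        self.events.append("PREPARED")  # after writer/device preparation
    def end(self, exception=None):
        self.events.append("END")       # on scan termination

node = FakeScanNode()
node.prepared()
node.end()
assert node.events == ["NEW_NODE", "PREPARED", "END"]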
......@@ -803,7 +803,7 @@ class ScanPrinterFromRedis(scan_mdl.ScansObserver):
requested_channels = scan_renderer.displayable_channel_names.copy()
scan_renderer.set_displayed_channels(requested_channels)
def on_scan_started(self, scan_db_name: str, scan_info: typing.Dict):
def on_scan_created(self, scan_db_name: str, scan_info: typing.Dict):
self.scan_renderer = ScanRenderer(scan_info)
# Update the displayed channels before printing the scan header
if self.scan_renderer.scan_type != "ct":
......@@ -866,13 +866,13 @@ class ScanDataListener(scan_mdl.ScansObserver):
return None
return ScanPrinterFromRedis(self.scan_display)
def on_scan_started(self, scan_db_name: str, scan_info: typing.Dict):
def on_scan_created(self, scan_db_name: str, scan_info: typing.Dict):
"""Called from Redis callback on scan started"""
if self._scan_displayer is None:
self._scan_displayer = self._create_scan_displayer(scan_info)
if self._scan_displayer is not None:
self._scan_id = scan_db_name
self._scan_displayer.on_scan_started(scan_db_name, scan_info)
self._scan_displayer.on_scan_created(scan_db_name, scan_info)
else:
self._warning_messages.append(
f"\nWarning: a new scan '{scan_db_name}' has been started while scan '{self._scan_id}' is running.\nNew scan outputs will be ignored."
......
......@@ -345,7 +345,7 @@ class PrintScanProgress(scan_mdl.ScansObserver):
def __init__(self):
self._data = {}
def on_scan_started(self, scan_db_name: str, scan_info: typing.Dict):
def on_scan_created(self, scan_db_name: str, scan_info: typing.Dict):
self._data = {}
def on_ndim_data_received(
......
......@@ -35,8 +35,8 @@ SCAN_INFO_3 = {
class MockedScanManager(scan_manager.ScanManager):
def emit_scan_started(self, scan_info):
self.on_scan_started(scan_info["node_name"], scan_info)
def emit_scan_created(self, scan_info):
self.on_scan_created(scan_info["node_name"], scan_info)
def emit_scan_finished(self, scan_info):
self.on_scan_finished(scan_info["node_name"], scan_info)
......@@ -67,12 +67,12 @@ def test_interleaved_scans():
scans = manager.get_alive_scans()
assert len(scans) == 0
manager.emit_scan_started(scan_info_1)
manager.emit_scan_created(scan_info_1)
scans = manager.get_alive_scans()
assert len(scans) == 1
assert scans[0].scanInfo() == scan_info_1
manager.emit_scan_started(scan_info_2)
manager.emit_scan_created(scan_info_2)
manager.emit_scalar_updated(scan_info_1, "axis:roby", numpy.arange(2))
manager.emit_scalar_updated(scan_info_2, "axis:robz", numpy.arange(3))
manager.wait_data_processed()
......@@ -95,7 +95,7 @@ def test_sequencial_scans():
manager = MockedScanManager(flintModel=None)
manager.emit_scan_started(scan_info_1)
manager.emit_scan_created(scan_info_1)
manager.emit_scalar_updated(scan_info_1, "axis:roby", numpy.arange(2))
manager.wait_data_processed()
scans = manager.get_alive_scans()
......@@ -104,7 +104,7 @@ def test_sequencial_scans():
assert manager.get_alive_scans() == []
assert scans[0].scanInfo() == scan_info_1
manager.emit_scan_started(scan_info_2)
manager.emit_scan_created(scan_info_2)
manager.emit_scalar_updated(scan_info_2, "axis:robz", numpy.arange(3))
manager.wait_data_processed()
scans = manager.get_alive_scans()
......@@ -119,7 +119,7 @@ def test_bad_sequence__end_before_new():
manager = MockedScanManager(flintModel=None)
manager.emit_scan_finished(scan_info_1)
manager.emit_scan_started(scan_info_1)
manager.emit_scan_created(scan_info_1)
# FIXME What to do anyway then? The manager is locked
......@@ -170,7 +170,7 @@ def test_image__default():
manager = MockedScanManager(flintModel=None)
manager.emit_scan_started(scan_info_3)
manager.emit_scan_created(scan_info_3)
scan = manager.get_alive_scans()[0]
image = numpy.arange(1).reshape(1, 1)
......@@ -191,7 +191,7 @@ def test_image__disable_video():
manager = MockedScanManager(flintModel=None)
manager.emit_scan_started(scan_info_3)
manager.emit_scan_created(scan_info_3)
scan = manager.get_alive_scans()[0]