Commit fc6c70ba authored by Matias Guijarro

Merge branch '2798-redis-subscriber-does-not-get-new-scan-events-on-existing-parents' into 'master'

Resolve "Redis subscriber does not get new scan events on existing parents"

Closes #2798

See merge request !3770
parents 01db9c19 7ec1020e
Pipeline #48143 failed in 101 minutes and 18 seconds
"""Utilities for Redis server-side scripts
"""
_SCRIPTS = dict()
def register_script(redisproxy, script_name: str, script: str) -> None:
"""Local registration. The registration with the Redis server is
done on first usage.
"""
if script_name in _SCRIPTS:
return
scriptobj = redisproxy.register_script(script)
scriptobj.registered_client = None
# scriptobj only contains the script hash and code
_SCRIPTS[script_name] = scriptobj
def evaluate_script(redisproxy, script_name: str, keys=tuple(), args=tuple()):
"""Evaluate a server-side Redis script
"""
if script_name not in _SCRIPTS:
raise RuntimeError(f"Redis script {repr(script_name)} is not registered")
scriptobj = _SCRIPTS[script_name]
return scriptobj(keys=keys, args=args, client=redisproxy)
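As a reading aid (not part of the merge request), a minimal usage sketch of these two helpers, assuming a plain redis-py connection and a made-up script name:

# Hypothetical usage sketch; script name and key are illustrative only.
import redis
from bliss.config.conductor.redis_scripts import register_script, evaluate_script

incr_by_two_script = """
-- Increment KEYS[1] by 2 and return the new value
return redis.call("INCRBY", KEYS[1], 2)
"""

proxy = redis.Redis()
# Local registration only; the script is sent to the Redis server on first use.
register_script(proxy, "incr_by_two_script", incr_by_two_script)
print(evaluate_script(proxy, "incr_by_two_script", keys=("demo_counter",)))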
......@@ -25,6 +25,7 @@ from .conductor import client
from bliss.config.conductor.client import set_config_db_file, remote_open
from bliss.common.utils import Null, auto_coerce
from bliss import current_session
from bliss.config.conductor.redis_scripts import register_script, evaluate_script
logger = logging.getLogger(__name__)
......@@ -566,8 +567,7 @@ class BaseHashSetting(BaseSetting):
return f"<{type(self).__name__} name=%s value=%s>" % (self.name, value)
def __delitem__(self, key):
cnx = self.connection
cnx.hdel(self.name, key)
self.remove(key)
def __len__(self):
cnx = self.connection
......@@ -680,7 +680,7 @@ class BaseHashSetting(BaseSetting):
return False
lua_script = """
orderedhashsetting_helper_script = """
-- Atomic addition of a key to a hash and to a list
-- to keep track of insertion order
......@@ -730,7 +730,7 @@ else
-- attribute does exist
return redis.call("HSET", hashkey, attribute, value)
end
""".encode()
"""
class OrderedHashSetting(BaseHashSetting):
......@@ -745,39 +745,18 @@ class OrderedHashSetting(BaseHashSetting):
write_type_conversion: conversion of data applied before writing
"""
add_key_script_sha1 = None
def __init__(
self,
name,
connection=None,
read_type_conversion=auto_coerce,
write_type_conversion=str,
):
super().__init__(name, connection, read_type_conversion, write_type_conversion)
if (
self.add_key_script_sha1 is None
): # class attribute to execute only the first time
# calculate sha1 of the script
sha1 = hashlib.sha1(lua_script).hexdigest()
# check if script already exists in Redis and if not loads it
if not self.connection.script_exists(sha1)[0]:
sha1_from_redis = self.connection.script_load(lua_script)
if sha1_from_redis != sha1:
raise RuntimeError("Exception in sending lua script to Redis")
type(self).add_key_script_sha1 = sha1
def __init__(self, *args, **kw):
super().__init__(*args, **kw)
register_script(
self.connection,
"orderedhashsetting_helper_script",
orderedhashsetting_helper_script,
)
@property
def _name_order(self):
return self._name + ":creation_order"
def __delitem__(self, key):
self.remove(key)
def __len__(self):
return super().__len__()
def ttl(self, value=-1):
hash_ttl = super().ttl(value)
set_ttl = ttl_func(self._cnx(), self._name_order, value)
......@@ -826,13 +805,11 @@ class OrderedHashSetting(BaseHashSetting):
cnx.delete(self._name_order)
if mapping is not None:
for k, v in mapping.items():
cnx.evalsha(
self.add_key_script_sha1,
2,
self._name,
self._name + ":creation_order",
k,
v,
evaluate_script(
cnx,
"orderedhashsetting_helper_script",
keys=(self._name, self._name + ":creation_order"),
args=(k, v),
)
cnx.execute()
......@@ -841,13 +818,11 @@ class OrderedHashSetting(BaseHashSetting):
with pipeline(self) as p:
if values:
for k, v in values.items():
p.evalsha(
self.add_key_script_sha1,
2,
self._name,
self._name + ":creation_order",
k,
v,
evaluate_script(
p,
"orderedhashsetting_helper_script",
keys=(self._name, self._name + ":creation_order"),
args=(k, v),
)
def has_key(self, key):
......@@ -883,13 +858,11 @@ class OrderedHashSetting(BaseHashSetting):
if self._write_type_conversion:
value = self._write_type_conversion(value)
cnx = self._cnx()
cnx.evalsha(
self.add_key_script_sha1,
2,
self._name,
self._name + ":creation_order",
key,
value,
evaluate_script(
cnx,
"orderedhashsetting_helper_script",
keys=(self._name, self._name + ":creation_order"),
args=(key, value),
)
def __contains__(self, key):
......@@ -916,7 +889,7 @@ class HashSetting(BaseHashSetting):
connection=None,
read_type_conversion=auto_coerce,
write_type_conversion=str,
default_values={},
default_values=None,
):
super().__init__(
name,
......@@ -924,6 +897,8 @@ class HashSetting(BaseHashSetting):
read_type_conversion=read_type_conversion,
write_type_conversion=write_type_conversion,
)
if default_values is None:
default_values = dict()
self._default_values = default_values
@read_decorator
......@@ -975,10 +950,12 @@ class HashSettingProp(BaseSetting):
connection=None,
read_type_conversion=auto_coerce,
write_type_conversion=str,
default_values={},
default_values=None,
use_object_name=True,
):
super().__init__(name, connection, read_type_conversion, write_type_conversion)
if default_values is None:
default_values = dict()
self._default_values = default_values
self._use_object_name = use_object_name
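The replacement of default_values={} with default_values=None plus the in-body dict() is the standard fix for Python's mutable default argument pitfall: a default dict literal is created once at function definition time and shared between all callers. A generic illustration of the pitfall (sketch, unrelated to the BLISS classes above):

# Generic illustration of the mutable-default pitfall fixed in this hunk.
class Broken:
    def __init__(self, defaults={}):   # one dict shared by every instance
        self.defaults = defaults

a, b = Broken(), Broken()
a.defaults["x"] = 1
print(b.defaults)  # {'x': 1} -- b sees a's modification

class Fixed:
    def __init__(self, defaults=None):
        if defaults is None:
            defaults = {}              # fresh dict per instance
        self.defaults = defaults

c, d = Fixed(), Fixed()
c.defaults["x"] = 1
print(d.defaults)  # {}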
......
......@@ -5,6 +5,7 @@
# Copyright (c) 2015-2019 Beamline Control Unit, ESRF
# Distributed under the GNU LGPLv3. See LICENSE for more info.
import sys
import gevent
import uuid
import enum
......@@ -14,6 +15,7 @@ import logging
from contextlib import contextmanager
from bliss.config.settings import BaseSetting, pipeline
from bliss.config import streaming_events
from bliss.config.conductor.redis_scripts import register_script, evaluate_script
logger = logging.getLogger(__name__)
......@@ -23,6 +25,25 @@ class CustomLogger(logging.LoggerAdapter):
return "[{}] {}".format(str(self.extra), msg), kwargs
create_stream_script = """
-- Atomic creation of an empty STREAM in Redis
-- KEYS[1]: redis-key of the STREAM
local streamkey = KEYS[1]
if (redis.call("EXISTS", streamkey)==0) then
redis.call("XADD", streamkey, "0-1", "key", "value")
redis.call("XDEL", streamkey, "0-1")
end
"""
def create_data_stream(name, connection):
register_script(connection, "create_stream_script", create_stream_script)
evaluate_script(connection, "create_stream_script", keys=(name,))
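For context, a hedged sketch (plain redis-py, not BLISS code) of what the script achieves: an XADD immediately followed by XDEL leaves a stream key that exists but is empty, so a subscriber can discover and read the stream before any real event is published. The MULTI/EXEC pipeline below stands in for the Lua script's atomicity (the real script also checks EXISTS first); the key name is made up.

# Sketch of the "atomic empty stream" idea, assuming a local Redis server.
import redis

cnx = redis.Redis()

pipe = cnx.pipeline(transaction=True)     # atomic, like the Lua script
pipe.xadd("demo_stream", {"key": "value"}, id="0-1")
pipe.xdel("demo_stream", "0-1")
pipe.execute()

assert cnx.exists("demo_stream") == 1     # the stream key is now visible
assert cnx.xlen("demo_stream") == 0       # but it contains no events yet

# Events added later are still delivered to a reader starting at 0-0:
cnx.xadd("demo_stream", {"data": b"1"})
print(cnx.xread({"demo_stream": "0-0"}))
cnx.delete("demo_stream")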
class DataStream(BaseSetting):
"""An ordered dictionary of dictionaries in Redis with optionally
a maximal number of items.
......@@ -38,7 +59,9 @@ class DataStream(BaseSetting):
The dictionary values are dictionaries which represent encoded StreamEvent's.
"""
def __init__(self, name, connection=None, maxlen=None, approximate=True):
def __init__(
self, name, connection=None, maxlen=None, approximate=True, create=False
):
"""
:param str name:
:param connection:
......@@ -48,6 +71,8 @@ class DataStream(BaseSetting):
super().__init__(name, connection, None, None)
self._maxlen = maxlen
self._approximate = approximate
if create:
create_data_stream(self.name, self.connection)
def __str__(self):
return f"{self.__class__.__name__}({self.name}, maxlen={self._maxlen})"
......@@ -350,7 +375,7 @@ class DataStreamReader:
self.stop_handler = stop_handler
def __str__(self):
return "{}({} subscribed, {} activate, {} consumer".format(
return "{}({} subscribed, {} active, {} consumer)".format(
self.__class__.__name__,
self.n_subscribed_streams,
self.n_active_streams,
......@@ -433,7 +458,9 @@ class DataStreamReader:
raise TypeError("All streams must have the same redis connection")
# Create the synchronization stream
self.__synchro_stream = DataStream(str(uuid.uuid4()), maxlen=16, connection=cnx)
self.__synchro_stream = DataStream(
str(uuid.uuid4()), maxlen=16, connection=cnx, create=True
)
return self.__synchro_stream
@property
......@@ -448,10 +475,10 @@ class DataStreamReader:
return
with pipeline(synchro_stream):
if end:
self._logger.debug("SYNC_END")
self._logger.debug("PUBLISH SYNC_END")
synchro_stream.add(self.SYNC_END)
else:
self._logger.debug("SYNC_EVENT")
self._logger.debug("PUBLISH SYNC_EVENT")
synchro_stream.add(self.SYNC_EVENT)
synchro_stream.ttl(60)
......@@ -501,7 +528,7 @@ class DataStreamReader:
continue
if not ignore_excluded and stream.name in self.excluded_stream_names:
continue
self._logger.debug(f"ADD STREAM {stream.name}")
self._logger.debug("ADD STREAM %s", stream.name)
self.check_stream_connection(stream)
sinfo = self._compile_stream_info(
stream, first_index=first_index, priority=priority, **info
......@@ -656,6 +683,7 @@ class DataStreamReader:
with self._read_task_context():
keep_reading = True
synchro_name = self._synchro_stream.name
self._logger.debug("READING events starts.")
while keep_reading:
# When not waiting for new events (wait=False)
# will stop reading after reading all current
......@@ -664,10 +692,15 @@ class DataStreamReader:
keep_reading = self._wait
# When wait=True: wait indefinitely when no events
self._logger.debug("READING ...")
self._logger.debug("READING events ...")
lst = self._read_active_streams()
self._logger.debug("RECEIVED events %d streams", len(lst))
read_priority = None
for name, events in lst:
if not events:
# This happens because of empty stream creation
# in create_stream_script.
continue
name = name.decode()
sinfo = self._active_streams[name]
if read_priority is None:
......@@ -676,16 +709,16 @@ class DataStreamReader:
# Lower priority streams are never read
# while higher priority streams have unread data
keep_reading = True
self._logger.debug("SKIP %s: %d events", name, len(events))
self._logger.debug("SKIP %d events from %s", len(events), name)
break
self._logger.debug("PROCESS %s: %d events", name, len(events))
self._logger.debug("PROCESS %d events from %s", len(events), name)
if name == synchro_name:
self._process_synchro_events(events)
keep_reading = True
else:
self._process_consumer_events(sinfo, events)
gevent.idle()
self._logger.debug("READING DONE.")
self._logger.debug("EVENTS processed.")
# Keep reading when active streams are modified
# by the consumer. This ensures that all streams
......@@ -693,6 +726,7 @@ class DataStreamReader:
self._wait_no_consuming()
if not keep_reading:
keep_reading = self.has_new_synchro_events()
self._logger.debug("READING events finished.")
def _wait_no_consuming(self):
"""Wait until the consumer is not processing an event
......@@ -715,7 +749,7 @@ class DataStreamReader:
for index, raw in events:
if streaming_events.EndEvent.istype(raw):
# stop reader loop (does not stop consumer)
self._logger.debug("STOP reading event")
self._logger.debug("RECEIVED stop event")
raise StopIteration
self._synchro_index = index
self._update_active_streams()
......@@ -723,10 +757,19 @@ class DataStreamReader:
def _log_events(self, task, stream, events):
if self._logger.getEffectiveLevel() > logging.DEBUG:
return
content = "\n ".join(
[f"{raw[b'__EVENT__']}: {raw.get(b'db_name')}" for idx, raw in events]
)
self._logger.debug(f"{task} {stream.name}:\n {content}")
content = "\n ".join(self._log_events_content(events))
self._logger.debug(f"{task} from {stream.name}:\n {content}")
@staticmethod
def _log_events_content(events):
for _, raw in events:
evttype = raw[b"__EVENT__"].decode()
db_name = raw.get(b"db_name", b"").decode()
nbytes = sys.getsizeof(raw)
if db_name:
yield f"{evttype}: {db_name} ({nbytes} bytes)"
else:
yield f"{evttype}: {nbytes} bytes"
def _process_consumer_events(self, sinfo, events):
"""Queue stream events and progress the index
......@@ -735,7 +778,7 @@ class DataStreamReader:
:param dict sinfo: stream info
:param list events: list((index, raw)))
"""
self._log_events("QUEUE", sinfo["stream"], events)
self._log_events("BUFFER events", sinfo["stream"], events)
self._queue.put((sinfo["stream"], events))
sinfo["first_index"] = events[-1][0]
......@@ -769,14 +812,15 @@ class DataStreamReader:
if not self._streams and not self._wait:
self._queue.put(StopIteration)
self._logger.debug("CONSUMING ...")
self._logger.debug("CONSUMING events starts.")
for item in self._queue:
if isinstance(item, Exception):
raise item
self._log_events("QUEUE", item[0], item[1])
self._log_events("CONSUME events", item[0], item[1])
self._consumer_state = self.ConsumerState.PROCESSING
yield item
self._consumer_state = self.ConsumerState.WAITING
finally:
self._consumer_state = self.ConsumerState.FINISHED
self._has_consumer = False
self._logger.debug("CONSUMING events finished.")
......@@ -821,6 +821,7 @@ class DataNode(metaclass=DataNodeMetaClass):
:returns DataStream:
"""
kw.setdefault("connection", self.db_connection)
kw.setdefault("create", self.__new_node)
return streaming.DataStream(name, **kw)
def _create_stream(self, suffix, **kw):
......@@ -1428,8 +1429,7 @@ class DataNodeContainer(DataNode):
def __init__(
self, node_type, name, parent=None, connection=None, create=False, **kwargs
):
DataNode.__init__(
self,
super().__init__(
node_type,
name,
parent=parent,
......
......@@ -23,7 +23,7 @@ def test_data_stream(wait_all_created, beacon):
"""Create stream and publish nevents
"""
nonlocal nstreams, stream_created, start_streaming
stream = streaming.DataStream(f"stream_{nevents}")
stream = streaming.DataStream(f"stream_{nevents}", create=True)
nstreams += 1
stream_created.set()
start_streaming.wait()
......@@ -77,11 +77,13 @@ def test_data_stream(wait_all_created, beacon):
order = []
# block=0: wait indefinitely for new events
for stream_name, events in connection.xread(streams_to_read, block=0):
if not events:
continue
nevents = int(stream_name.split(b"_")[1])
lst = data.setdefault(nevents, [])
for index, value in events:
for _, value in events:
lst.append(int(value[b"data"]))
streams_to_read[stream_name.decode()] = index
streams_to_read[stream_name.decode()] = events[-1][0]
order.append(nevents)
assert order == sorted(order, reverse=True), "read order not preserved"
if len(data) == len(streams_to_read):
......@@ -176,7 +178,7 @@ class DataStreamTestPublishers:
:param str stream_name:
"""
try:
stream = streaming.DataStream(stream_name)
stream = streaming.DataStream(stream_name, create=True)
idata = 0
while True:
gevent.sleep(random.random() / 1000)
......
......@@ -40,6 +40,13 @@ from bliss.scanning.acquisition.timer import SoftwareTimerMaster
from bliss.scanning.channel import AcquisitionChannel
@pytest.fixture
def streaming_debug_logging():
# Show streaming logs when test fails
logger = logging.getLogger("bliss.config.streaming")
logger.setLevel(logging.DEBUG)
@pytest.fixture
def lima_session(beacon, scan_tmpdir, lima_simulator):
session = beacon.get("lima_test_session")
......@@ -206,7 +213,7 @@ def _validate_node_indexing(node, shape, dtype, npoints, expected_data, extract)
extract(node[3:1])
def test_scan_data_0d(session, redis_data_conn):
def test_scan_data_0d(session, redis_data_conn, streaming_debug_logging):
simul_counter = session.env_dict.get("sim_ct_gauss")
npoints = 10
s = scans.timescan(
......@@ -239,7 +246,7 @@ def test_scan_data_0d(session, redis_data_conn):
_validate_node_indexing(node, tuple(), float, npoints, expected_data, extract)
def test_lima_data_channel_node(lima_session, redis_data_conn):
def test_lima_data_channel_node(lima_session, redis_data_conn, streaming_debug_logging):
lima_sim = lima_session.env_dict["lima_simulator"]
npoints = 10
s = scans.timescan(0.1, lima_sim, npoints=npoints)
......@@ -282,7 +289,7 @@ def test_lima_data_channel_node(lima_session, redis_data_conn):
)
def test_data_iterator_event(beacon, redis_data_conn, session):
def test_data_iterator_event(beacon, redis_data_conn, session, streaming_debug_logging):
def iterate_channel_events(scan_db_name, channels):
for e, n, data in get_node(scan_db_name).walk_events():
if e == e.NEW_DATA:
......@@ -327,7 +334,9 @@ def test_data_iterator_event(beacon, redis_data_conn, session):
@pytest.mark.parametrize("with_roi", [False, True], ids=["without ROI", "with ROI"])
def test_reference_with_lima(redis_data_conn, lima_session, with_roi):
def test_reference_with_lima(
redis_data_conn, lima_session, with_roi, streaming_debug_logging
):
lima_sim = getattr(setup_globals, "lima_simulator")
# Roi handling
......@@ -353,7 +362,9 @@ def test_reference_with_lima(redis_data_conn, lima_session, with_roi):
@pytest.mark.parametrize("with_roi", [False, True], ids=["without ROI", "with ROI"])
def test_iterator_over_reference_with_lima(redis_data_conn, lima_session, with_roi):
def test_iterator_over_reference_with_lima(
redis_data_conn, lima_session, with_roi, streaming_debug_logging
):
npoints = 5
exp_time = 1
lima_sim = getattr(setup_globals, "lima_simulator")
......@@ -443,7 +454,7 @@ def test_ttl_setter(session, capsys, enable_ttl):
assert err == ""
def test_walk_after_nodes_disappeared(session):
def test_walk_after_nodes_disappeared(session, streaming_debug_logging):
detector = session.env_dict["diode"]
s = scans.loopscan(1, 0.1, detector)
session_db_name = session.name
......@@ -513,7 +524,7 @@ def test_walk_after_nodes_disappeared(session):
validate_count(nnodes, nevents)
def test_children_timing(beacon, session):
def test_children_timing(beacon, session, streaming_debug_logging):
diode2 = session.env_dict["diode2"]
def walker(db_name):
......@@ -606,7 +617,7 @@ def test_scan_end_timing(
# if this raises "END_SCAN" event was not emitted
def test_data_shape_of_get(default_session):
def test_data_shape_of_get(default_session, streaming_debug_logging):
class myAcqDev(AcquisitionSlave):
def __init__(self):
class dev:
......@@ -658,7 +669,7 @@ def test_data_shape_of_get(default_session):
assert numpy.array(mynode.get_as_array(0, 2)).dtype == numpy.float64
def test_stop_before_any_walk_event(default_session):
def test_stop_before_any_walk_event(default_session, streaming_debug_logging):
session_node = get_session_node(default_session.name)
event = gevent.event.Event()
......@@ -677,7 +688,7 @@ def test_stop_before_any_walk_event(default_session):
task.get()
def test_stop_after_first_walk_event(session):
def test_stop_after_first_walk_event(session, streaming_debug_logging):
session_node = get_session_node(session.name)
event = gevent.event.Event()
......@@ -720,10 +731,6 @@ def _count_node_events(
:param num overhead: per event
:returns dict or list, int: events or nodes, number of detectors in scan
"""
# Show streaming logs when test fails
l = logging.getLogger("bliss.config.streaming")
l.setLevel(logging.DEBUG)
if beforestart:
wait = True
......@@ -871,7 +878,9 @@ _count_parameters = [
@pytest.mark.parametrize("beforestart, wait, include_filter", _count_parameters)
def test_walk_events_on_session_node(beforestart, wait, include_filter, session):
def test_walk_events_on_session_node(
beforestart, wait, include_filter, session, streaming_debug_logging
):
events, nmasters, nchannels = _count_node_events(
beforestart, session, session.name, include_filter=include_filter, wait=wait
)
......@@ -906,7 +915,9 @@ def test_walk_events_on_session_node(beforestart, wait, include_filter, session)
@pytest.mark.parametrize("beforestart, wait, include_filter", _count_parameters)
def test_walk_nodes_on_session_node(beforestart, wait, include_filter, session):
def test_walk_nodes_on_session_node(
beforestart, wait, include_filter, session, streaming_debug_logging
):
nodes, nmasters, nchannels = _count_nodes(
beforestart, session, session.name, include_filter=include_filter, wait=wait
)
......@@ -927,7 +938,9 @@ def test_walk_nodes_on_session_node(beforestart, wait, include_filter, session):
@pytest.mark.parametrize("beforestart", [True, False])
def test_walk_events_on_wrong_session_node(beforestart, session):
def test_walk_events_on_wrong_session_node(
beforestart, session, streaming_debug_logging
):
events, nmasters, nchannels = _count_node_events(
beforestart, session, session.name[:-1]
)
......@@ -935,13 +948,17 @@ def test_walk_events_on_wrong_session_node(beforestart, session):
@pytest.mark.parametrize("beforestart", [True, False])
def test_walk_nodes_on_wrong_session_node(beforestart, session):
def test_walk_nodes_on_wrong_session_node(
beforestart, session, streaming_debug_logging
):
nodes, nmasters, nchannels = _count_nodes(beforestart, session, session.name[:-1])
assert not nodes
@pytest.mark.parametrize("beforestart, wait, include_filter", _count_parameters)
def test_walk_events_on_dataset_node(beforestart, wait, include_filter, session):
def test_walk_events_on_dataset_node(
beforestart, wait, include_filter, session, streaming_debug_logging
):