Commit 627385a5 authored by Matias Guijarro

Merge branch '2860-ttl-and-key-eviction-of-scan-data' into 'master'

Resolve "TTL and key eviction of scan data"

Closes #2860

See merge request !3873
Parents: 6e931862, 5ded7f31
Pipeline #51635 passed with stages in 102 minutes and 43 seconds
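In short, the merge replaces the garbage-collection-driven TTL mechanism (`set_ttl`, `enable_ttl`, `disable_ttl` in `bliss.data.node`) with explicit expiration handling in a new `bliss.data.expiration` module, driven by the data policy objects. A minimal sketch of the new surface, assuming a running Beacon/Redis setup (the key names are hypothetical):

# Sketch only: hypothetical key names; the two helpers come from the new module below.
from bliss.data.expiration import set_expiration_time, remove_expiration_time

keys = ["demo_session:scan1", "demo_session:scan1_children_list"]
set_expiration_time(keys, 600)      # Redis evicts the keys after 600 seconds
remove_expiration_time(keys)        # make the keys persistent again

The first hunk below is the new `bliss.data.expiration` module: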
from bliss.config.conductor import client
def set_expiration_time(keys, seconds):
"""Set the expiration time of all Redis keys
"""
async_proxy = client.get_redis_proxy(db=1).pipeline()
try:
for name in keys:
async_proxy.expire(name, seconds)
finally:
async_proxy.execute()
def remove_expiration_time(keys):
"""Remove the expiration time of all Redis keys
"""
async_proxy = client.get_redis_proxy(db=1).pipeline()
try:
for name in keys:
async_proxy.persist(name)
finally:
async_proxy.execute()
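Both helpers queue all EXPIRE/PERSIST commands on a Redis pipeline and flush them in a finally block, so every key is updated in a single round trip. The same pattern with plain redis-py, as a standalone sketch (assumes a local Redis server on the default port):

import redis

r = redis.Redis()                   # assumption: Redis on localhost:6379
pipe = r.pipeline()
try:
    for name in ("key1", "key2"):   # hypothetical key names
        pipe.expire(name, 600)      # queued locally, nothing sent yet
finally:
    pipe.execute()                  # one round trip for all queued commands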
......@@ -101,13 +101,11 @@ Use the following utility functions to instantiate a DataNode:
return None
"""
import time
import inspect
import pkgutil
import os
import weakref
import warnings
from numbers import Number
from bliss.common.event import dispatcher
from bliss.common.utils import grouped
from bliss.common.greenlet_utils import protect_from_kill, AllowKill
......@@ -632,35 +630,6 @@ class DataNodeAsyncHelper:
self._nodes = None
def set_ttl(db_name):
"""Set the time-to-live upon garbage collection of DataNode
which was instantiated with `create==True` (also affects the parents).
"""
if DataNode._TIMETOLIVE is None:
return
# Do not create a Redis connection pool during garbage collection
connection = client.get_existing_redis_proxy(db=1, timeout=10)
if connection is None:
return
# A new instance needs to be created because we are in the garbage
# collection of the original instance
node = get_node(db_name, state="exists", connection=connection)
if node is not None:
node.set_ttl()
def enable_ttl(ttl: Number = 24 * 3600):
"""Enable `set_ttl`
"""
DataNode._TIMETOLIVE = ttl
def disable_ttl():
"""Disable `set_ttl`
"""
DataNode._TIMETOLIVE = None
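The functions removed above were triggered through `weakref.finalize`: each node registered a finalizer that set the TTL once the node was garbage collected (see the `_ttl_setter` lines below). The underlying pattern in isolation, as a generic sketch rather than actual BLISS code:

import weakref

def set_ttl(db_name):
    # Stand-in for the removed module-level set_ttl() above.
    print(f"setting TTL for {db_name}")

class Node:
    def __init__(self, db_name):
        self.db_name = db_name
        # Run set_ttl(db_name) when this instance is garbage collected.
        self._ttl_setter = weakref.finalize(self, set_ttl, db_name)

    def detach_ttl_setter(self):
        # Cancel the finalizer so no TTL is set on collection.
        self._ttl_setter.detach()

node = Node("scan:42")
del node  # finalizer fires here and prints "setting TTL for scan:42"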
class DataNodeMetaClass(type):
def __call__(cls, *args, **kwargs):
"""This wraps the __init__ execution
......@@ -741,7 +710,6 @@ class DataNode(metaclass=DataNodeMetaClass):
self._struct = self._create_struct(db_name, node_type)
else:
self.__new_node = False
self._ttl_setter = None
self._struct = self._get_struct(db_name, connection=self.db_connection)
def _register_stream_priority(self, fullname: str, priority: int):
......@@ -782,8 +750,6 @@ class DataNode(metaclass=DataNodeMetaClass):
self._struct.parent = parent.db_name
if add_to_parent:
parent.add_children(self)
# Set TTL on garbage collection
self._ttl_setter = weakref.finalize(self, set_ttl, self.__db_name)
def get_nodes(self, *db_names, **kw):
"""
......@@ -990,34 +956,6 @@ class DataNode(metaclass=DataNodeMetaClass):
def connect(self, signal, callback):
dispatcher.connect(callback, signal, self)
@protect_from_kill
def set_ttl(self, include_parents=True):
"""Set the time-to-live for all Redis objects associated to this node
"""
if self._TIMETOLIVE is not None:
self.apply_ttl(set(self.get_db_names(include_parents=include_parents)))
self.detach_ttl_setter()
def detach_ttl_setter(self):
"""Make sure ttl is not set upon garbage collection.
"""
if self._ttl_setter is not None:
self._ttl_setter.detach()
def apply_ttl(self, db_names):
"""Set time-to-live for a list of Redis objects
:param list(str) db_names:
"""
if self._TIMETOLIVE is None:
return
p = self.connection.pipeline()
try:
for name in db_names:
p.expire(name, self._TIMETOLIVE)
finally:
p.execute()
def get_db_names(self, include_parents=True):
"""All associated Redis keys, including the associated keys of the parents.
"""
......
......@@ -127,6 +127,7 @@ class Dataset(DataPolicyObject):
self._store_in_icat(icat_client)
self.freeze_inherited_icat_metadata()
self._node.info["__closed__"] = True
self.set_expiration_time(include_parents=False)
self._log_debug("closed dataset")
def _store_in_icat(self, icat_client):
......
......@@ -9,6 +9,9 @@ from bliss.common.logtools import log_debug, log_warning
from bliss.common.utils import autocomplete_property
from bliss.common.namespace_wrapper import NamespaceWrapper
from bliss.icat.definitions import Definitions
from bliss.data.expiration import set_expiration_time
from bliss.data.expiration import remove_expiration_time
from bliss.config.settings import scan as scan_redis
class DataPolicyObject:
......@@ -18,6 +21,7 @@ class DataPolicyObject:
_REQUIRED_INFO = {"__name__", "__path__"}
_NODE_TYPE = NotImplemented
DATA_EXPIRATION_TIME = 600 # seconds
def __init__(self, node):
"""
......@@ -234,3 +238,19 @@ class DataPolicyObject:
@property
def metadata_is_complete(self):
return not self.missing_fields
def set_expiration_time(self, include_parents=True):
"""Includes node and children nodes. Parent nodes are optional.
"""
names = set(
scan_redis(self.node.db_name + "*", connection=self.node.connection)
)
if include_parents:
names |= set(self.node.get_db_names(include_parents=True))
set_expiration_time(names, self.DATA_EXPIRATION_TIME)
def remove_expiration_time(self):
"""Includes node and parents nodes. Child nodes are not included.
"""
names = self.node.get_db_names(include_parents=True)
remove_expiration_time(names)
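`scan_redis` gathers every key whose name starts with the node's `db_name`, which sweeps up child keys (streams, info hashes, and so on) without walking the node tree. The equivalent with plain redis-py, as a sketch (server and pattern are assumptions):

import redis

r = redis.Redis()  # assumption: local Redis
# SCAN with a MATCH pattern iterates without blocking the server (unlike KEYS).
names = set(r.scan_iter(match="demo_session:dataset*"))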
......@@ -8,7 +8,6 @@
from gevent.queue import Queue
import gevent
from contextlib import contextmanager
import numpy
from bliss.scanning.chain import (
AcquisitionMaster,
AcquisitionSlave,
......@@ -20,6 +19,9 @@ from bliss.data.nodes.scan import ScanNode
from bliss.scanning.scan import ScanState, ScanPreset
from bliss.scanning.scan_info import ScanInfo
from bliss.common.logtools import user_warning
from bliss import current_session
from bliss.config.settings import scan as scan_redis
from bliss.config.conductor import client
class ScanGroup(Scan):
......@@ -250,7 +252,7 @@ class Sequence:
if self._scan is None:
return self._scan_info
else:
self.scan.scan_info
return self._scan.scan_info
@property
def state(self):
......@@ -296,14 +298,8 @@ class Group(Sequence):
class GroupingMaster(AcquisitionMaster):
def __init__(self):
AcquisitionMaster.__init__(
self,
None,
name="GroupingMaster",
npoints=0,
prepare_once=True,
start_once=True,
super().__init__(
None, name="GroupingMaster", npoints=0, prepare_once=True, start_once=True
)
self.scan_queue = Queue()
......@@ -348,12 +344,8 @@ class GroupingMaster(AcquisitionMaster):
self._number_channel.emit(int(scan.info["scan_nb"]))
self._node_channel.emit(scan.db_name)
# Reset the node TTLs
if scan.connection.ttl(scan.db_name) > 0:
scan.set_ttl()
for n in scan.walk(wait=False):
if n.connection.ttl(n.db_name) > 0:
n.set_ttl(include_parents=False)
self._reset_expiration_time(scan)
except BaseException:
self._publish_success &= False
raise
......@@ -362,6 +354,15 @@ class GroupingMaster(AcquisitionMaster):
finally:
self._publish_event.set()
def _reset_expiration_time(self, scan_node):
proxy = client.get_redis_proxy(db=1)
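# ttl() returns -1 when a key exists but has no expiration: the scan is persistent, so there is nothing to reset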
if proxy.ttl(scan_node.db_name) == -1:
return
scan_keys = set(scan_redis(scan_node.db_name + ":*", connection=proxy))
scan_keys |= set(scan_node.get_db_names(include_parents=False))
parent_keys = set(scan_node.parent.get_db_names(include_parents=True))
current_session.scan_saving.set_expiration_time(scan_keys, parent_keys)
def wait_all_published(self, timeout=None):
"""Wait until `_publish_new_subscan` is called for all subscans
that are queued. Publishing is done by iterating over this
......@@ -393,12 +394,12 @@ class GroupingMaster(AcquisitionMaster):
pass
class GroupingSlave(
AcquisitionSlave
): # one instance of this for channels published `on the fly` and one that is called after the scan?
def __init__(self, name, channels):
AcquisitionSlave.__init__(self, None, name=name)
class GroupingSlave(AcquisitionSlave):
"""For custom sequence channels
"""
def __init__(self, name, channels):
super().__init__(None, name=name)
self.start_event = gevent.event.Event()
for channel in channels:
self.channels.append(channel)
......
......@@ -831,10 +831,32 @@ class Scan:
self.node.end(self._scan_info, exception=_exception)
with capture():
self.set_ttl()
self.set_expiration_time()
self._update_scan_info_in_redis()
def set_expiration_time(self):
"""Set the expiration time of all Redis keys associated to this scan
"""
scan_keys = self.get_db_names()
parent_keys = self.get_parent_db_names()
self.__scan_saving.set_expiration_time(scan_keys, parent_keys)
def get_db_names(self):
"""Get all Redis keys associated to this scan
"""
db_names = set()
nodes = list(self.nodes.values())
for node in nodes:
db_names |= set(node.get_db_names(include_parents=False))
db_names |= set(self.node.get_db_names(include_parents=False))
return db_names
def get_parent_db_names(self):
"""Get all Redis keys associated to the parents of this scan
"""
return set(self.node.parent.get_db_names(include_parents=True))
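Taken together, the Scan delegates the expiration policy to the active scan-saving object; roughly (a condensed paraphrase of the methods above, with attribute names simplified):

# Condensed sketch; the real code uses self.__scan_saving and error capture.
def on_scan_end(scan):
    scan_keys = scan.get_db_names()           # scan node plus all child nodes
    parent_keys = scan.get_parent_db_names()  # session/proposal/... ancestors
    # BasicScanSaving expires both sets; ESRFScanSaving expires only
    # scan_keys and leaves the parents to the data policy objects.
    scan.scan_saving.set_expiration_time(scan_keys, parent_keys)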
def _init_scan_number(self):
self.writer.template.update(
{
......@@ -1250,18 +1272,6 @@ class Scan:
self._current_pipeline_stream = self.root_connection.pipeline()
return self._stream_pipeline_task
def set_ttl(self):
# node.get_db_names takes the most time
db_names = set()
nodes = list(self.nodes.values())
for node in nodes:
db_names |= set(node.get_db_names(include_parents=False))
db_names |= set(self.node.get_db_names())
self.node.apply_ttl(db_names)
for node in nodes:
node.detach_ttl_setter()
self.node.detach_ttl_setter()
def _device_event(self, event_dict=None, signal=None, sender=None):
if signal == "end":
if self._USE_PIPELINE_MGR:
......
......@@ -37,6 +37,7 @@ from bliss.common.utils import autocomplete_property
from bliss.icat.proposal import Proposal
from bliss.icat.dataset_collection import DatasetCollection
from bliss.icat.dataset import Dataset
from bliss.data.expiration import set_expiration_time
_SCAN_SAVING_CLASS = None
......@@ -331,6 +332,7 @@ class BasicScanSaving(EvalParametersWardrobe):
]
REDIS_SETTING_PREFIX = "scan_saving"
SLOTS = []
DATA_EXPIRATION_TIME = 600 # seconds
def __init__(self, name=None):
"""
......@@ -731,6 +733,13 @@ class BasicScanSaving(EvalParametersWardrobe):
"""
pass
def set_expiration_time(self, scan_keys, parent_keys):
"""Set the expiration time of all Redis keys associated to this scan
"""
set_expiration_time(
itertools.chain(scan_keys, parent_keys), self.DATA_EXPIRATION_TIME
)
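`itertools.chain` streams both key sets as one iterable without building an intermediate list; trivially:

import itertools

scan_keys = {"scan:1", "scan:1:info"}  # hypothetical
parent_keys = {"demo_session"}         # hypothetical
for key in itertools.chain(scan_keys, parent_keys):
    print(key)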
class ESRFScanSaving(BasicScanSaving):
"""Parameterized representation of the scan data file path
......@@ -1247,6 +1256,8 @@ class ESRFScanSaving(BasicScanSaving):
self._proposal = name
self._freeze_date()
self._reset_collection()
if name:
self.proposal.remove_expiration_time()
if not isinstance(self.icat_client, IcatTangoProxy):
self.icat_client.start_investigation(
proposal=self.proposal_name, beamline=self.beamline
......@@ -1513,6 +1524,8 @@ class ESRFScanSaving(BasicScanSaving):
def _close_proposal(self):
"""Close the current proposal.
"""
if self._proposal:
self.proposal.set_expiration_time(include_parents=True)
self._proposal_object = None
self._proposal = ""
......@@ -1590,3 +1603,9 @@ class ESRFScanSaving(BasicScanSaving):
if dataset.is_closed:
raise RuntimeError("Dataset is already closed (choose a different name)")
dataset.gather_metadata(on_exists="skip")
def set_expiration_time(self, scan_keys, _):
"""Set the expiration time of the Redis keys of the scan only.
The expiration of the parent keys is handled by `DataPolicyObject`.
"""
set_expiration_time(scan_keys, self.DATA_EXPIRATION_TIME)
......@@ -48,8 +48,6 @@ from bliss.tango.clients.utils import wait_tango_device, wait_tango_db
from bliss.shell.cli.repl import BlissRepl
from bliss import logging_startup
from bliss.scanning import scan_meta
from bliss.data.node import enable_ttl as _enable_ttl
from bliss.data.node import disable_ttl as _disable_ttl
import socket
BLISS = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
......@@ -413,22 +411,7 @@ def wait_ports(ports, timeout=10):
@pytest.fixture
def disable_ttl():
_disable_ttl()
@pytest.fixture
def enable_ttl(disable_ttl):
# We use `disable_ttl` to make sure enable has priority over disable,
# regardless of the fixture order
ttl = 24 * 3600
_enable_ttl(ttl)
yield ttl
_disable_ttl()
@pytest.fixture
def beacon(ports, disable_ttl):
def beacon(ports):
redis_db = redis.Redis(port=ports.redis_port)
redis_db.flushall()
redis_data_db = redis.Redis(port=ports.redis_data_port)
......
from bliss.common.session import set_current_session
from bliss.scanning.scan_saving import ESRFScanSaving, ESRFDataPolicyEvent
def set_esrf_config(scan_saving, base_path):
# Make sure all data saving mount points
# have base_path as root in the session's
# scan saving config (in memory)
assert isinstance(scan_saving, ESRFScanSaving)
scan_saving_config = scan_saving.scan_saving_config
roots = ["inhouse_data_root", "visitor_data_root", "tmp_data_root"]
for root in roots:
for prefix in ["", "icat_"]:
key = prefix + root
mount_points = scan_saving_config.get(key, None)
if mount_points is None:
continue
elif isinstance(mount_points, str):
scan_saving_config[key] = mount_points.replace("/tmp/scans", base_path)
else:
for mp in mount_points:
mount_points[mp] = mount_points[mp].replace("/tmp/scans", base_path)
def set_esrf_data_policy(session):
# SCAN_SAVING uses the `current_session`
set_current_session(session, force=True)
assert session.name == session.scan_saving.session
# TODO: cannot use enable_esrf_data_policy directly because
# we need to modify the in-memory config before setting the proposal.
# If enable_esrf_data_policy changes, however, we are in trouble.
tmpdir = session.scan_saving.base_path
session._set_scan_saving(cls=ESRFScanSaving)
set_esrf_config(session.scan_saving, tmpdir)
# session.scan_saving.get_path() sets the proposal to the default
# proposal and notify ICAT. When using the `icat_subscriber` fixture,
# this will be the first event.
session._emit_event(
ESRFDataPolicyEvent.Enable, data_path=session.scan_saving.get_path()
)
def set_basic_data_policy(session):
session.disable_esrf_data_policy()
def set_data_policy(session, policy):
if policy == "basic":
set_basic_data_policy(session)
elif policy == "esrf":
set_esrf_data_policy(session)
else:
raise ValueError(policy, "Unsupported data policy")
......@@ -6,63 +6,17 @@
# Distributed under the GNU LGPLv3. See LICENSE for more info.
import pytest
from bliss.common.session import set_current_session
from bliss.scanning.scan_saving import ESRFScanSaving, ESRFDataPolicyEvent
def modify_esrf_policy_mount_points(scan_saving, base_path):
# Make sure all data saving mount points
# have base_path as root in the session's
# scan saving config (in memory)
assert isinstance(scan_saving, ESRFScanSaving)
scan_saving_config = scan_saving.scan_saving_config
roots = ["inhouse_data_root", "visitor_data_root", "tmp_data_root"]
for root in roots:
for prefix in ["", "icat_"]:
key = prefix + root
mount_points = scan_saving_config.get(key, None)
if mount_points is None:
continue
elif isinstance(mount_points, str):
scan_saving_config[key] = mount_points.replace("/tmp/scans", base_path)
else:
for mp in mount_points:
mount_points[mp] = mount_points[mp].replace("/tmp/scans", base_path)
def _esrf_data_policy(session):
# SCAN_SAVING uses the `current_session`
set_current_session(session, force=True)
assert session.name == session.scan_saving.session
# TODO: cannot use enable_esrf_data_policy directly because
# we need to modify the in-memory config before setting the proposal.
# If enable_esrf_data_policy changes, however, we are in trouble.
tmpdir = session.scan_saving.base_path
session._set_scan_saving(cls=ESRFScanSaving)
modify_esrf_policy_mount_points(session.scan_saving, tmpdir)
# session.scan_saving.get_path() sets the proposal to the default
# proposal and notify ICAT. When using the `icat_subscriber` fixture,
# this will be the first event.
session._emit_event(
ESRFDataPolicyEvent.Enable, data_path=session.scan_saving.get_path()
)
yield session.scan_saving.scan_saving_config
session.disable_esrf_data_policy()
from ..data_policies import set_data_policy
@pytest.fixture
def esrf_data_policy(session, icat_backend):
yield from _esrf_data_policy(session)
set_data_policy(session, "esrf")
@pytest.fixture
def esrf_data_policy_tango(session, icat_tango_backend):
yield from _esrf_data_policy(session)
set_data_policy(session, "esrf")
@pytest.fixture
......@@ -76,9 +30,9 @@ def session2(beacon, scan_tmpdir):
@pytest.fixture
def esrf_data_policy2(session2, icat_backend):
yield from _esrf_data_policy(session2)
set_data_policy(session2, "esrf")
@pytest.fixture
def esrf_data_policy2_tango(session2, icat_tango_backend):
yield from _esrf_data_policy(session2)
set_data_policy(session2, "esrf")
......@@ -11,6 +11,7 @@ import gevent
import os
import itertools
import numpy
import gevent.queue
from bliss.common.standard import loopscan
from bliss.common.session import set_current_session
......@@ -78,7 +79,7 @@ def test_icat_backends(
def test_inhouse_scan_saving(session, icat_subscriber, esrf_data_policy):
scan_saving = session.scan_saving
scan_saving_config = esrf_data_policy
scan_saving_config = scan_saving.scan_saving_config
icat_test_utils.assert_icat_received_current_proposal(scan_saving, icat_subscriber)
for bset in [False, True]:
......@@ -109,7 +110,7 @@ def test_visitor_scan_saving(session, icat_subscriber, esrf_data_policy):
icat_test_utils.assert_icat_received_current_proposal(scan_saving, icat_subscriber)
scan_saving.mount_point = "fs1"
scan_saving_config = esrf_data_policy
scan_saving_config = scan_saving.scan_saving_config
scan_saving.proposal_name = "mx415"
assert scan_saving.base_path == scan_saving_config["visitor_data_root"]["fs1"]
assert scan_saving.icat_base_path == scan_saving_config["visitor_data_root"]["fs1"]
......@@ -122,7 +123,7 @@ def test_tmp_scan_saving(session, icat_subscriber, esrf_data_policy):
icat_test_utils.assert_icat_received_current_proposal(scan_saving, icat_subscriber)
scan_saving.mount_point = "fs1"
scan_saving_config = esrf_data_policy
scan_saving_config = scan_saving.scan_saving_config
scan_saving.proposal_name = "test123"
expected = scan_saving_config["tmp_data_root"]["fs1"].format(
beamline=scan_saving.beamline
......@@ -913,7 +914,7 @@ def test_session_scan_saving_clone(session, esrf_data_policy):
def test_mount_points(session, esrf_data_policy):
scan_saving = session.scan_saving
scan_saving_config = esrf_data_policy
scan_saving_config = scan_saving.scan_saving_config
# Test setting mount points
assert scan_saving.mount_points == {"", "fs1", "fs2", "fs3"}
......
......@@ -8,15 +8,15 @@
import pytest
import os
import gevent
from gevent import subprocess
from contextlib import contextmanager
from bliss.common import measurementgroup
from bliss.common.tango import DevState, Database
from nexus_writer_service.subscribers.session_writer import all_cli_saveoptions
from bliss.tango.clients.utils import wait_tango_device
from tests.nexus_writer.helpers import nxw_test_config
from tests.nexus_writer.helpers import nxw_test_utils
from .helpers import nxw_test_config
from .helpers import nxw_test_utils
from ..data_policies import set_esrf_config
@pytest.fixture
......@@ -189,22 +189,8 @@ def prepare_scan_saving(session=None, tmpdir=None, policy=True, **kwargs):
tmpdir = str(tmpdir.join(session.name))
session.enable_esrf_data_policy()
scan_saving = session.scan_saving
scan_saving.writer = "nexus"
scan_saving_config = scan_saving.scan_saving_config
roots = ["inhouse_data_root", "visitor_data_root", "tmp_data_root"]
for root in roots:
for prefix in ["", "icat_"]:
key = prefix + root
mount_points = scan_saving_config.get(key, None)
if mount_points is None: