Commit 745bcfbb authored by Sebastien Petitdemange

scan: improved watch_session_scans function.

- Grouped data reads for zerod channels (see the sketch below)
- Read events by block
- Follow multiple scans
parent 3b5be780
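For context, the grouped zerod read amounts to the pattern below: every per-channel read is queued on one Redis pipeline and resolved with a single `execute()` round-trip. This is a minimal sketch with hypothetical key names, using plain redis-py rather than the BLISS node API.

```python
import redis

def read_channels_grouped(client, channels_and_start):
    """Queue one LRANGE per channel, then resolve all reads in a
    single network round-trip (hypothetical list-per-channel layout)."""
    pipe = client.pipeline()
    for key, start_index in channels_and_start:
        pipe.lrange(key, start_index, -1)  # queued, not executed yet
    results = pipe.execute()  # one round-trip for every queued read
    return {key: data for (key, _), data in zip(channels_and_start, results)}

client = redis.Redis()
data = read_channels_grouped(client, [("scan:diode", 0), ("scan:timer", 0)])
```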
@@ -8,9 +8,11 @@
import os
import time
import datetime
import enum
import numpy
import pickle
import gevent
from gevent.threadpool import ThreadPool
from bliss.common.cleanup import excepthook
from bliss.data.node import DataNodeIterator, _get_or_create_node, DataNodeContainer
from bliss.config import settings
@@ -93,11 +95,19 @@ def get_data(scan):
Return a dictionary of { channel_name: numpy array }
"""
dtype = list()
scan_channel_get_data_func = dict() # { channel_name: function }
max_channel_len = 0
connection = scan.node.db_connection
pipeline = connection.pipeline()
for device, node in scan.nodes.items():
data = dict()
nodes_and_index = [(node, 0) for node in scan.nodes]
for channel_name, channel_data in get_data_from_nodes(pipeline, *nodes_and_index):
data[channel_name] = channel_data
return data
def get_data_from_nodes(pipeline, *nodes_and_start_index):
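# queue every channel read on the shared Redis pipeline; the data only
# arrives after .execute(), when each stored conversion function is
# applied and the result is yielded per channel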
scan_channel_get_data_func = dict() # { channel_name: function }
for node, start_index in nodes_and_start_index:
if node.type == "channel":
channel_name = node.name
i = 2
@@ -114,7 +124,7 @@ def get_data(scan):
# as it is in a Redis pipeline, get returns the
# conversion function only - data will be received
# after .execute()
scan_channel_get_data_func[channel_name] = chan.get(0, -1)
scan_channel_get_data_func[channel_name] = chan.get(start_index, -1)
finally:
chan.db_connection = saved_db_connection
@@ -124,102 +134,162 @@ def get_data(scan):
for i, (channel_name, get_data_func) in enumerate(
scan_channel_get_data_func.items()
):
data[channel_name] = get_data_func(result[i])
yield channel_name, get_data_func(result[i])
return data
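# scan lifecycle events forwarded from the session watcher to the callback greenlet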
_SCAN_EVENT = enum.IntEnum("SCAN_EVENT", "NEW NEW_CHILD NEW_DATA END")
def _watch_data(
scan_node, scan_info, scan_new_child_callback, scan_data_callback, read_pipe
def _watch_data_callback(
event,
events_dict,
scan_new_callback,
scan_new_child_callback,
scan_data_callback,
scan_end_callback,
):
scan_data = dict()
data_indexes = dict()
scan_data_iterator = DataNodeIterator(scan_node, wakeup_fd=read_pipe)
for event_type, data_channel in scan_data_iterator.walk_events():
if event_type == scan_data_iterator.EVENTS.EXTERNAL_EVENT:
break
if event_type == scan_data_iterator.EVENTS.NEW_CHILD:
scan_new_child_callback(scan_info, data_channel)
elif event_type == scan_data_iterator.EVENTS.NEW_DATA_IN_CHANNEL:
data = data_channel.get(
data_indexes.setdefault(data_channel.db_name, 0), -1
)
if not data: # already received
continue
data_indexes[data_channel.db_name] += len(data)
for master, channels in scan_info["acquisition_chain"].items():
master_channels = channels["master"]
scalars = channels.get("scalars", [])
spectra = channels.get("spectra", [])
images = channels.get("images", [])
try:
for channel_name in master_channels["scalars"]:
scan_data.setdefault(channel_name, [])
if data_channel.fullname == channel_name:
scan_data[channel_name] = numpy.concatenate(
(scan_data[channel_name], data)
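# one worker thread keeps user data callbacks off the gevent loop and strictly ordered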
pool = ThreadPool(1)
running_scans = dict()
while True:
event.wait()
event.clear()
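# snapshot and clear pending events so producers can keep enqueueing
# while this block is processed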
local_events = events_dict.copy()
events_dict.clear()
for event_type, event_data in local_events.items():
if event_type == _SCAN_EVENT.NEW:
for db_name, scan_info in event_data:
scan_new_callback(scan_info)
running_scans.setdefault(db_name, dict())
elif event_type == _SCAN_EVENT.NEW_CHILD:
for (
scan_db_name,
(scan_info, data_channels_event),
) in event_data.items():
scan_dict = running_scans[scan_db_name]
nodes_info = scan_dict.setdefault("nodes_info", dict())
scan_dict.setdefault("nodes_data", dict())
for (
channel_db_name,
channel_data_node,
) in data_channels_event.items():
scan_new_child_callback(scan_info, channel_data_node)
try:
fullname = channel_data_node.fullname
nodes_info.setdefault(
channel_db_name,
(fullname, len(channel_data_node.shape), 0),
)
raise StopIteration
for i, channel_name in enumerate(scalars):
scan_data.setdefault(channel_name, [])
if data_channel.fullname == channel_name:
scan_data[channel_name] = numpy.concatenate(
(scan_data.get(channel_name, []), data)
except AttributeError:
nodes_info.setdefault(
channel_db_name, (channel_data_node.name, -1, 0)
)
with excepthook():
scan_data_callback(
"0d",
master,
{
"master_channels": master_channels["scalars"],
"channel_index": i,
"channel_name": channel_name,
"data": scan_data,
"scan_info": scan_info,
},
)
raise StopIteration
for i, channel_name in enumerate(spectra):
if data_channel.db_name.endswith(channel_name):
with excepthook():
scan_data_callback(
"1d",
elif event_type == _SCAN_EVENT.NEW_DATA:
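# split updated channels: 0d data is fetched grouped below,
# higher dimensions keep their data node for the callback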
zerod_nodes = list()
other_nodes = dict()
for (
scan_db_name,
(scan_info, data_channels_event),
) in event_data.items():
scan_dict = running_scans[scan_db_name]
nodes_info = scan_dict["nodes_info"]
nodes_data = scan_dict["nodes_data"]
for (
channel_db_name,
channel_data_node,
) in data_channels_event.items():
fullname, dim, last_index = nodes_info.get(channel_db_name)
if dim == 0:
zerod_nodes.append(
(
fullname,
channel_db_name,
channel_data_node,
last_index,
)
)
else:
other_nodes[fullname] = (
channel_db_name,
dim,
channel_data_node,
)
# fetching all zerod in one go
zerod_nodes_index = [
(channel_node, start_index)
for _, _, channel_node, start_index in zerod_nodes
]
try:
connection = zerod_nodes_index[0][0].db_connection
pipeline = connection.pipeline()
except IndexError:
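# no 0d channel was updated in this block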
connection = pipeline = None
new_data_flags = False
for (
(fullname, channel_db_name, _, last_index),
(_, channel_data),
) in zip(
zerod_nodes, get_data_from_nodes(pipeline, *zerod_nodes_index)
):
new_data_flags = (
True if len(channel_data) > 0 else new_data_flags
)
prev_data = nodes_data.get(fullname, [])
nodes_data[fullname] = numpy.concatenate(
(prev_data, channel_data)
)
nodes_info[channel_db_name] = (
fullname,
0,
last_index + len(channel_data),
)
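# fire the 0d callback only when at least one channel produced new points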
if zerod_nodes and new_data_flags:
event_channels_full_name = set(
(fullname for fullname, _, _, _ in zerod_nodes)
)
for master, channels in scan_info["acquisition_chain"].items():
channels_set = set(
channels["master"]["scalars"]
+ channels.get("scalars", [])
)
if event_channels_full_name.intersection(channels_set):
t = pool.spawn(
scan_data_callback,
"0d",
master,
{
"channel_index": i,
"channel_name": channel_name,
"data": data,
"scan_info": scan_info,
},
{"data": nodes_data, "scan_info": scan_info},
)
raise StopIteration
for i, channel_name in enumerate(images):
if data_channel.db_name.endswith(channel_name):
with excepthook():
t.get()
gevent.idle()
elif zerod_nodes:
gevent.sleep(0.1)  # relax a little bit
for master, channels in scan_info["acquisition_chain"].items():
other_names = channels.get("spectra", []) + channels.get(
"images", []
)
for i, channel_name in enumerate(other_names):
channel_db_name, dim, channel_data_node = other_nodes.get(
channel_name, (None, -1, None)
)
if channel_db_name:
scan_data_callback(
"2d",
f"{dim}d",
master,
{
"channel_index": i,
"channel_name": channel_name,
"data": data,
"channel_data_node": channel_data_node,
"scan_info": scan_info,
},
)
raise StopIteration
except StopIteration:
break
def safe_watch_data(*args):
with excepthook():
_watch_data(*args)
elif event_type == _SCAN_EVENT.END:
for db_name, scan_info in event_data:
if scan_end_callback:
scan_end_callback(scan_info)
running_scans.pop(db_name, None)
gevent.idle()
def watch_session_scans(
@@ -237,8 +307,17 @@ def watch_session_scans(
return
data_iterator = DataNodeIterator(session_node, wakeup_fd=exit_read_fd)
watch_data_task = None
rpipe, wpipe = os.pipe()
events_dict = dict()
callback_event = gevent.event.Event()
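# all node events are coalesced into events_dict and consumed by a single greenlet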
watch_data_callback = gevent.spawn(
_watch_data_callback,
callback_event,
events_dict,
scan_new_callback,
scan_new_child_callback,
scan_data_callback,
scan_end_callback,
)
try:
pubsub = data_iterator.children_event_register()
@@ -250,49 +329,69 @@
]
current_scan_node = None
for event_type, scan_node in data_iterator.wait_for_event(
pubsub, filter="scan"
):
running_scans = dict()
def _get_scan_info(db_name):
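# a node belongs to the running scan whose db_name is a prefix of its own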
for key, scan_dict in running_scans.items():
if db_name.startswith(key):
return scan_dict["info"], key
return None, None
for event_type, node in data_iterator.wait_for_event(pubsub):
if event_type == data_iterator.EVENTS.EXTERNAL_EVENT:
if watch_data_task:
os.write(wpipe, b".")
watch_data_task.join()
break
elif event_type == data_iterator.EVENTS.NEW_CHILD:
if (
current_scan_node is not None
and current_scan_node.db_name == scan_node.db_name
):
continue
current_scan_node = scan_node
scan_info = scan_node.info.get_all()
if watch_data_task:
os.write(wpipe, b".")
watch_data_task.join()
# call user callbacks and start data watch task for this scan
with excepthook():
# call 'scan_new' callback, if an exception happens in user
# code the data watch task is *not* started -- it will be
# retried at next scan
scan_new_callback(scan_info)
# spawn watching task: incoming scan data triggers
# corresponding user callbacks (see code in '_watch_data')
watch_data_task = gevent.spawn(
safe_watch_data,
scan_node,
scan_info,
scan_new_child_callback,
scan_data_callback,
rpipe,
node_type = node.type
db_name = node.db_name
if node_type == "scan":
# New scan was created
scan_dictionary = running_scans.setdefault(db_name, dict())
if not scan_dictionary:
scan_info = node.info.get_all()
scan_dictionary["info"] = scan_info
new_event = events_dict.setdefault(_SCAN_EVENT.NEW, list())
new_event.append((db_name, scan_info))
callback_event.set()
else:
scan_info, scan_db_name = _get_scan_info(db_name)
if scan_info:  # scan found
new_child_event = events_dict.setdefault(
_SCAN_EVENT.NEW_CHILD, dict()
)
_, scan_data_event = new_child_event.setdefault(
scan_db_name, (scan_info, dict())
)
scan_data_event.setdefault(db_name, node)
callback_event.set()
elif event_type == data_iterator.EVENTS.NEW_DATA_IN_CHANNEL:
db_name = node.db_name
scan_info, scan_db_name = _get_scan_info(db_name)
if scan_info:
new_data_event = events_dict.setdefault(
_SCAN_EVENT.NEW_DATA, dict()
)
_, new_event = new_data_event.setdefault(
scan_db_name, (scan_info, dict())
)
new_event.setdefault(db_name, node)
callback_event.set()
elif event_type == data_iterator.EVENTS.END_SCAN:
scan_info = scan_node.info.get_all()
current_scan_node = None
if scan_data_callback is not None:
scan_end_callback(scan_info)
db_name = node.db_name
scan_dict = running_scans.pop(db_name)
if scan_dict:
scan_info = scan_dict["info"]
new_event = events_dict.setdefault(_SCAN_EVENT.END, list())
new_event.append((db_name, scan_info))
callback_event.set()
# check watch_data_callback is still running
try:
watch_data_callback.get(block=False)
except gevent.Timeout:
# get(block=False) raises gevent.Timeout while the greenlet is alive
pass
finally:
if watch_data_task:
watch_data_task.kill()
watch_data_callback.kill()
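The hand-off between the session watcher and the callback greenlet reduces to the pattern below: producers coalesce events into a shared dict keyed by event type and signal an Event; the consumer snapshots and clears the dict, then handles each block at once. A minimal, self-contained sketch with a hypothetical event type and payloads, not the BLISS code itself:

```python
import gevent
from gevent.event import Event

events = {}       # pending events, coalesced by type
wakeup = Event()  # producer -> consumer signal

def producer(n):
    # hypothetical producer: enqueue n NEW_DATA payloads
    for i in range(n):
        events.setdefault("NEW_DATA", []).append(i)
        wakeup.set()
        gevent.sleep(0.01)

def consumer():
    while True:
        wakeup.wait()
        wakeup.clear()
        batch = events.copy()  # snapshot so producers can keep enqueueing
        events.clear()
        for event_type, payloads in batch.items():
            print(event_type, payloads)  # handle a whole block at once
        gevent.idle()

c = gevent.spawn(consumer)
gevent.spawn(producer, 5).join()
gevent.sleep(0.05)  # let the consumer drain the last block
c.kill()
```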