Commit e3ad70ca authored by Sebastien Petitdemange, committed by Vincent Michel

Major refactoring of data publishing.

parent c3831040
# -*- coding: utf-8 -*-
#
# This file is part of the bliss project
#
# Copyright (c) 2016 Beamline Control Unit, ESRF
# Distributed under the GNU LGPLv3. See LICENSE for more info.
from bliss.config.settings import QueueSetting
from bliss.data.node import DataNode
import numpy
import redis
import functools
import cPickle
def data_to_bytes(data):
if isinstance(data, numpy.ndarray):
return data.dumps()
else:
return data
def data_from_pipeline(data, shape=None, dtype=None):
if len(shape) == 0:
return numpy.array(data, dtype=dtype)
else:
a = numpy.array([numpy.loads(x) for x in data], dtype=dtype)
a.shape = (-1,)+shape
return a
def data_from_bytes(data, shape=None, dtype=None):
if isinstance(data, redis.client.Pipeline):
return functools.partial(data_from_pipeline, shape=shape, dtype=dtype)
try:
return numpy.loads(data)
except cPickle.UnpicklingError:
return float(data)
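# Illustrative round trip for the two helpers above (a sketch, not part of
# the original source): ndarrays are pickled with ndarray.dumps() on write
# and restored with numpy.loads() on read; anything else passes through
# data_to_bytes() unchanged and comes back from Redis as a string, which
# the except branch converts to float.
#
#   a = numpy.arange(6, dtype='f8').reshape((2, 3))
#   raw = data_to_bytes(a)                              # pickled bytes
#   b = data_from_bytes(raw, shape=(2, 3), dtype='f8')
#   assert (a == b).all()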
class ChannelDataNode(DataNode):
def __init__(self, name, **keys):
shape = keys.pop('shape', None)
dtype = keys.pop('dtype', None)
DataNode.__init__(self, 'channel', name, **keys)
if keys.get('create', False):
if shape is not None:
self.info["shape"] = shape
if dtype is not None:
self.info["dtype"] = dtype
cnx = self.db_connection
self._queue = QueueSetting("%s_data" % self.db_name, connection=cnx,
read_type_conversion=functools.partial(data_from_bytes, shape=shape, dtype=dtype),
write_type_conversion=data_to_bytes)
def store(self, signal, event_dict, cnx=None):
if signal == "new_data":
data = event_dict.get("data")
channel = event_dict.get("channel")
if len(channel.shape) == data.ndim:
self._queue.append(data, cnx=cnx)
else:
self._queue.extend(data, cnx=cnx)
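# The append/extend rule above distinguishes a single reading from a block
# of readings: one point has exactly as many dimensions as the channel
# shape, a block has one extra dimension. E.g. for a 0-d channel
# (shape == ()):
#
#   numpy.array(1.5).ndim        # 0 == len(shape) -> append one point
#   numpy.array([1., 2.]).ndim   # 1 != len(shape) -> extend, two points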
def get(self, from_index, to_index=None, cnx=None):
if to_index is None:
return self._queue.get(from_index, from_index, cnx=cnx)
else:
return self._queue.get(from_index, to_index, cnx=cnx)
def __len__(self, cnx=None):
return self._queue.__len__(cnx=cnx)
@property
def shape(self):
return self.info.get("shape")
@property
def dtype(self):
return self.info.get("dtype")
def _get_db_names(self):
db_names = DataNode._get_db_names(self)
db_names.append(self.db_name+"_data")
return db_names
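# Usage sketch (hypothetical node name): reading published points back from
# an existing channel node; get(i) returns point i, get(0, -1) everything:
#
#   from bliss.data.node import get_node
#   node = get_node('eh3:scan1:P201:c0', node_type='channel')
#   last_point = node.get(len(node) - 1)
#   all_points = node.get(0, -1)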
@@ -4,7 +4,28 @@
#
# Copyright (c) 2016 Beamline Control Unit, ESRF
# Distributed under the GNU LGPLv3. See LICENSE for more info.
"""
Redis structure
--eh3 (DataNodeContainer - inherits from DataNode)
|
--scan1 (Scan - inherits from DataNode)
|
--P201 (DataNodeContainer - inherits from DataNode)
|
--c0 (ChannelDataNode - inherits from DataNode)
DataNode is the base class.
A data node has 3 Redis keys to represent it:
{db_name} -> Struct { name, db_name, node_type, parent=(parent db_name) }
{db_name}_info -> HashObjSetting, free-form dictionary
{db_name}_children_list -> QueueSetting, list of children db names
The channel data node extends the structure above with:
{db_name}_data -> QueueSetting, list of channel values
"""
import pkgutil
import inspect
import re
@@ -16,6 +37,10 @@ from bliss.config.conductor import client
from bliss.config.settings import Struct, QueueSetting, HashObjSetting
def is_zerod(node):
return node.type == 'channel' and len(node.shape) == 0
def to_timestamp(dt, epoch=None):
if epoch is None:
epoch = datetime.datetime(1970, 1, 1)
@@ -23,23 +48,25 @@ def to_timestamp(dt, epoch=None):
return td.microseconds / float(10**6) + td.seconds + td.days * 86400
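# Example: one minute past the Unix epoch maps to 60.0 seconds.
#
#   to_timestamp(datetime.datetime(1970, 1, 1, 0, 1))   # -> 60.0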
# From continuous scan
# make list of available plugins for generating DataNode objects
node_plugins = dict()
for importer, module_name, _ in pkgutil.iter_modules([os.path.join(os.path.dirname(__file__), '..', 'data')]):
node_plugins[module_name] = importer
for importer, module_name, _ in pkgutil.iter_modules([os.path.dirname(__file__)],
prefix="bliss.data."):
node_type = module_name.replace("bliss.data.", "")
node_plugins[node_type] = module_name
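# With the bliss.data package holding e.g. channel.py and scan.py next to
# this module, node_plugins ends up mapping node types to module names
# (sketch):
#
#   {'channel': 'bliss.data.channel', 'scan': 'bliss.data.scan', ...}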
def _get_node_object(node_type, name, parent, connection, create=False):
importer = node_plugins.get(node_type)
if importer is None:
return DataNode(node_type, name, parent, connection=connection, create=create)
def _get_node_object(node_type, name, parent, connection, create=False, **keys):
module_name = node_plugins.get(node_type)
if module_name is None:
return DataNodeContainer(node_type, name, parent, connection=connection, create=create, **keys)
else:
m = importer.find_module(node_type).load_module(node_type)
classes = inspect.getmembers(m, lambda x: inspect.isclass(
x) and issubclass(x, DataNode) and x != DataNode)
m = __import__(module_name, globals(), locals(), [''], -1)
classes = inspect.getmembers(m, lambda x: inspect.isclass(x) and issubclass(x, DataNode) and
x not in (DataNode, DataNodeContainer))
# there should be only 1 class inheriting from DataNode in the plugin
klass = classes[0][-1]
return klass(name, parent=parent, connection=connection, create=create)
return klass(name, parent=parent, connection=connection, create=create, **keys)
def get_node(name, node_type=None, parent=None, connection=None):
@@ -54,20 +81,20 @@ def get_node(name, node_type=None, parent=None, connection=None):
return _get_node_object(node_type, name, parent, connection)
def _create_node(name, node_type=None, parent=None, connection=None):
def _create_node(name, node_type=None, parent=None, connection=None, **keys):
if connection is None:
connection = client.get_cache(db=1)
return _get_node_object(node_type, name, parent, connection, create=True)
return _get_node_object(node_type, name, parent, connection, create=True, **keys)
def _get_or_create_node(name, node_type=None, parent=None, connection=None):
def _get_or_create_node(name, node_type=None, parent=None, connection=None, **keys):
if connection is None:
connection = client.get_cache(db=1)
db_name = DataNode.exists(name, parent, connection)
if db_name:
return get_node(db_name, connection=connection)
else:
return _create_node(name, node_type, parent, connection)
return _create_node(name, node_type, parent, connection, **keys)
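# Sketch (hypothetical names): **keys forwards plugin-specific options to
# the node class, e.g. shape and dtype for a ChannelDataNode:
#
#   node = _get_or_create_node('c0', node_type='channel', parent=p201,
#                              shape=(), dtype='f8')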
class DataNodeIterator(object):
@@ -145,10 +172,11 @@ class DataNodeIterator(object):
pubsub = redis.pubsub()
pubsub.psubscribe("__keyspace@1__:%s*_children_list" %
self.node.db_name)
pubsub.psubscribe("__keyspace@1__:%s*_channels" % self.node.db_name)
return pubsub
def child_register_new_data(self, child_node, pubsub):
if child_node.type() == 'zerod':
if child_node.type == 'zerod':
for channel_name in child_node.channels_name():
zerod_db_name = child_node.db_name
event_key = "__keyspace@1__:%s_%s" % (
@@ -248,15 +276,20 @@ class DataNode(object):
db_name = '%s:%s' % (parent.db_name, name) if parent else name
return db_name if connection.exists(db_name) else None
def __init__(self, node_type, name, parent=None, connection=None, create=False):
@staticmethod
def _set_ttl(db_names):
redis_conn = client.get_cache(db=1)
pipeline = redis_conn.pipeline()
for name in db_names:
pipeline.expire(name, DataNode.default_time_to_live)
pipeline.execute()
def __init__(self, node_type, name, parent=None, connection=None, create=False, **keys):
if connection is None:
connection = client.get_cache(db=1)
db_name = '%s:%s' % (parent.db_name, name) if parent else name
self._data = Struct(db_name,
connection=connection)
children_queue_name = '%s_children_list' % db_name
self._children = QueueSetting(children_queue_name,
connection=connection)
info_hash_name = '%s_info' % db_name
self._info = HashObjSetting(info_hash_name,
connection=connection)
@@ -290,15 +323,6 @@
return DataNodeIterator(self)
@property
def add_children(self, *child):
if len(child) > 1:
children_no = self._children.extend([c.db_name() for c in child])
else:
children_no = self._children.append(child[0].db_name())
def connect(self, signal, callback):
dispatcher.connect(callback, signal, self)
def parent(self):
parent_name = self._data.parent
if parent_name:
@@ -307,34 +331,12 @@
del self._data.parent
return parent
#@brief iter over children
#@return an iterator
#@param from_id start child index
#@param to_id last child index
def children(self, from_id=0, to_id=-1):
for child_name in self._children.get(from_id, to_id):
new_child = get_node(child_name)
if new_child is not None:
yield new_child
else:
self._children.remove(child_name) # clean
def last_child(self):
return get_node(self._children.get(-1))
def set_info(self, key, values):
self._info[key] = values
if self._ttl > 0:
self._info.ttl(self._ttl)
def info_iteritems(self):
return self._info.iteritems()
def info_get(self, name):
return self._info.get(name)
@property
def info(self):
return self._info
def data_update(self, keys):
self._data.update(keys)
def connect(self, signal, callback):
dispatcher.connect(callback, signal, self)
def set_ttl(self):
db_names = set(self._get_db_names())
@@ -342,14 +344,6 @@
if self._ttl_setter is not None:
self._ttl_setter.disable()
@staticmethod
def _set_ttl(db_names):
redis_conn = client.get_cache(db=1)
pipeline = redis_conn.pipeline()
for name in db_names:
pipeline.expire(name, DataNode.default_time_to_live)
pipeline.execute()
def _get_db_names(self):
db_name = self.db_name
children_queue_name = '%s_children_list' % db_name
@@ -360,5 +354,36 @@
db_names.extend(parent._get_db_names())
return db_names
def store(self, signal, event_dict):
pass
class DataNodeContainer(DataNode):
def __init__(self, node_type, name, parent=None, connection=None, create=False):
DataNode.__init__(self, node_type, name,
parent=parent, connection=connection, create=create)
children_queue_name = '%s_children_list' % self.db_name
self._children = QueueSetting(
children_queue_name, connection=connection)
def add_children(self, *child):
if len(child) > 1:
self._children.extend([c.db_name for c in child])
else:
self._children.append(child[0].db_name)
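# Sketch (hypothetical nodes): registering children only records their
# db names in the Redis queue; the child objects are re-fetched lazily
# by children() below:
#
#   p201.add_children(c0_node, c1_node)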
def children(self, from_id=0, to_id=-1):
"""Iter over children.
@return an iterator
@param from_id start child index
@param to_id last child index
"""
for child_name in self._children.get(from_id, to_id):
new_child = get_node(child_name)
if new_child is not None:
yield new_child
else:
self._children.remove(child_name) # clean
@property
def last_child(self):
return get_node(self._children.get(-1))
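# Usage sketch (hypothetical names): walking the tree from the module
# docstring one level at a time:
#
#   session = get_node('eh3')
#   for scan_node in session.children():
#       print(scan_node.db_name)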
@@ -10,7 +10,7 @@ import datetime
import numpy
import pickle
from bliss.data.node import DataNode
from bliss.data.node import DataNodeContainer, is_zerod
def _transform_dict_obj(dict_object):
@@ -46,9 +46,9 @@ def pickle_dump(var):
return pickle.dumps(var)
class Scan(DataNode):
class Scan(DataNodeContainer):
def __init__(self, name, create=False, **keys):
DataNode.__init__(self, 'scan', name, create=create, **keys)
DataNodeContainer.__init__(self, 'scan', name, create=create, **keys)
self.__create = create
if create:
start_time_stamp = time.time()
@@ -82,19 +82,30 @@ def get_data(scan):
connection = scan.node.db_connection
pipeline = connection.pipeline()
for device, node in scan.nodes.iteritems():
if node.type() == 'zerod':
for channel_name in node.channels_name():
chan = node.get_channel(
channel_name, check_exists=False, cnx=pipeline)
chanlist.append(channel_name)
chan.get(0, -1) # all data
dtype.append((channel_name, 'f8'))
if node.type == 'channel':
channel_name = node.name
chan = node
# append (channel name, conversion function): since the read goes
# through a Redis pipeline, get() returns only the conversion
# function here - the actual data arrives after pipeline.execute()
chanlist.append((channel_name,
chan.get(0, -1, cnx=pipeline)))
result = pipeline.execute()
structured_array_dtype = []
for i, (channel_name, get_data_func) in enumerate(chanlist):
channel_data = get_data_func(result[i])
result[i] = channel_data
structured_array_dtype.append(
(channel_name, channel_data.dtype, channel_data.shape[1:]))
max_channel_len = max((len(values) for values in result))
data = numpy.zeros(max_channel_len, dtype=dtype)
for channel_name, values in zip(chanlist, result):
a = data[channel_name]
nb_data = len(values)
a[0:nb_data] = values[0:nb_data]
data = numpy.zeros(max_channel_len, dtype=structured_array_dtype)
for i, (channel_name, _) in enumerate(chanlist):
channel_data = result[i]
# slice so shorter channels stay zero-padded up to max_channel_len
data[channel_name][:len(channel_data)] = channel_data
return data
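# Usage sketch: get_data() returns one numpy structured array with a column
# per channel, zero-padded up to the longest channel (hypothetical column
# name):
#
#   data = get_data(scan)
#   data['elapsed_time']   # every point of one channel
#   data[0]                # first point of every channel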
@@ -33,7 +33,7 @@ class BaseCounterAcquisitionDevice(AcquisitionDevice):
if not isinstance(counter, GroupedReadMixin):
self.channels.append(AcquisitionChannel(
counter.name, numpy.double, (1,)))
counter.name, numpy.double, ()))
@property
def count_time(self):
@@ -50,7 +50,7 @@ class BaseCounterAcquisitionDevice(AcquisitionDevice):
self.__grouped_read_counters_list.append(counter)
self.channels.append(AcquisitionChannel(
counter.name, numpy.double, (1,)))
counter.name, numpy.double, ()))
def _emit_new_data(self, data):
self.channels.update_from_iterable(data)
......
@@ -76,7 +76,7 @@ class SoftwarePositionTriggerMaster(MotorMaster):
def __init__(self, axis, start, end, npoints=1, **kwargs):
self._positions = numpy.linspace(start, end, npoints + 1)[:-1]
MotorMaster.__init__(self, axis, start, end, **kwargs)
self.channels.append(AcquisitionChannel(axis.name, numpy.double, (1,)))
self.channels.append(AcquisitionChannel(axis.name, numpy.double, ()))
self.__nb_points = npoints
@property
@@ -110,8 +110,7 @@ class SoftwarePositionTriggerMaster(MotorMaster):
self.movable.stop(wait=False)
self.exception = sys.exc_info()
else:
self.channels[0].value = position
self.channels.update()
self.channels[0].emit(position)
def move_done(self, done):
if done:
@@ -221,7 +220,7 @@ class _StepTriggerMaster(AcquisitionMaster):
trigger_type=trigger_type, **keys)
self.channels.extend(
(AcquisitionChannel(axis.name, numpy.double, (1,)) for axis in self._axes))
(AcquisitionChannel(axis.name, numpy.double, ()) for axis in self._axes))
@property
def npoints(self):
......
@@ -36,7 +36,7 @@ class MusstAcquisitionDevice(AcquisitionDevice):
self.vars = vars if vars is not None else dict()
store_list = store_list if store_list is not None else list()
self.channels.extend(
(AcquisitionChannel(name, numpy.int32, (1,)) for name in store_list))
(AcquisitionChannel(name, numpy.int32, ()) for name in store_list))
self.next_vars = None
self._iter_index = 0
......
@@ -18,8 +18,7 @@ class SoftwareTimerMaster(AcquisitionMaster):
AcquisitionMaster.__init__(self, None, 'timer', **keys)
self.count_time = count_time
self.sleep_time = sleep_time
self.channels.append(AcquisitionChannel(
'timestamp', numpy.double, (1,)))
self.channels.append(AcquisitionChannel('timestamp', numpy.double, ()))
self._nb_point = 0
@@ -48,8 +47,7 @@ class SoftwareTimerMaster(AcquisitionMaster):
gevent.sleep(self.sleep_time)
start_trigger = time.time()
self.channels[0].value = start_trigger
self.channels.update()
self.channels[0].emit(start_trigger)
self.trigger_slaves()
elapsed_trigger = time.time() - start_trigger
......
@@ -8,69 +8,9 @@
from treelib import Tree
import gevent
from bliss.common.event import dispatcher
from .channel import AcquisitionChannelList, AcquisitionChannel
import time
import weakref
import numpy as np
class AcquisitionChannelList(list):
def __init__(self, parent, *args, **kwargs):
list.__init__(self, *args, **kwargs)
self.__parent = parent
def __emit_new_data(self):
dispatcher.send("new_data", self.__parent, {
"channel_data": dict(((c.name, c.value) for c in self))})
def update(self, values_dict=None):
"""Update all channels and emit the new_data event
Input:
values_dict - { channel_name: value, ... }
"""
if values_dict:
for channel in self:
channel.value = values_dict[channel.name]
self.__emit_new_data()
def update_from_iterable(self, iterable):
for i, channel in enumerate(self):
channel.value = iterable[i]
self.__emit_new_data()
def update_from_array(self, array):
for i, channel in enumerate(self):
channel.value = array[:, i]
self.__emit_new_data()
class AcquisitionChannel(object):
def __init__(self, name, dtype, shape):
self.__name = name
self.__value = None
self.dtype = dtype
self.shape = shape
@property
def name(self):
return self.__name
@property
def value(self):
return self.__value
@value.setter
def value(self, value):
value = np.atleast_1d(value).astype(self.dtype, copy=False)
if value.shape != self.shape:
raise ValueError("Channel value shape '%s` does not correspond to new value shape: %s" % (
self.shape, value.shape))
self.__value = value
class DeviceIterator(object):
@@ -133,7 +73,7 @@ class AcquisitionMaster(object):
self.__parent = None
self.__slaves = list()
self.__triggers = list()
self.__channels = AcquisitionChannelList(self)
self.__channels = AcquisitionChannelList()
self.__npoints = npoints
self.__trigger_type = trigger_type
self.__prepare_once = prepare_once
@@ -246,7 +186,7 @@ class AcquisitionDevice(object):
self.__parent = None
self.__name = name
self.__trigger_type = trigger_type
self.__channels = AcquisitionChannelList(self)
self.__channels = AcquisitionChannelList()
self.__npoints = npoints
self.__prepare_once = prepare_once
self.__start_once = start_once
......
# -*- coding: utf-8 -*-
#
# This file is part of the bliss project
#
# Copyright (c) 2016 Beamline Control Unit, ESRF
# Distributed under the GNU LGPLv3. See LICENSE for more info.
from bliss.common.event import dispatcher
from bliss.data.node import _get_or_create_node
import numpy
class AcquisitionChannelList(list):
def update(self, values_dict):
"""Update all channels and emit the new_data event
Input:
values_dict - { channel_name: value, ... }
"""
for channel in self:
channel.emit(values_dict[channel.name])
def update_from_iterable(self, iterable):
for i, channel in enumerate(self):
channel.emit(iterable[i])
def update_from_array(self, array):
for i, channel in enumerate(self):
channel.emit(array[:,i])
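# Usage sketch (illustrative names, assuming channel.emit() publishes the
# value as used elsewhere in this commit): with two scalar channels
# registered, update() emits one new_data event per channel:
#
#   channels = AcquisitionChannelList()
#   channels.append(AcquisitionChannel('timestamp', numpy.double, ()))
#   channels.append(AcquisitionChannel('diode', numpy.double, ()))
#   channels.update({'timestamp': time.time(), 'diode': 0.42})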
class AcquisitionChannel(object):
def __init__(self, name, dtype, shape, description=None, reference=False):
self.__name = name
self.__dtype = dtype
self.__shape = shape
self.__reference = reference
self.__description = { 'reference': reference }
if isinstance(description, dict):
self.__description.update(description)
@property
def name(self):