Commit 13c635ff authored by Sebastien Petitdemange's avatar Sebastien Petitdemange Committed by Matias Guijarro
Browse files

service: added a way to serve bliss object.

Can serve counters and temperature objects.
parent a288264f
# -*- coding: utf-8 -*-
#
# This file is part of the bliss project
#
# Copyright (c) 2015-2020 Beamline Control Unit, ESRF
# Distributed under the GNU LGPLv3. See LICENSE for more info.
import os
import time
import argparse
import weakref
import gevent
from bliss.config import settings, static
from bliss.comm import rpc
from bliss.common import protocols
from bliss.common import counter, logtools
from . import plugins
_Port2Object = weakref.WeakValueDictionary()
class ConnectionError(RuntimeError):
pass
def get_object_from_port(port):
return _Port2Object.get(port)
# --- CLIENT ---
class Client:
def __init__(self, name, config_node):
self.__name = name
self._config_node = config_node
self._proxy = None
self._connexion_info = settings.Struct(f"service:{name}")
try:
self.connect()
except:
logtools.log_warning(self, "Service %s not running", name)
def __dir__(self):
attributes = ["close", "connect", "config", "__info__"]
try:
self.connect()
except ConnectionError:
pass
else:
try:
attr_proxy = dir(self._proxy)
except ConnectionRefusedError:
pass
else:
attributes.extend(attr_proxy)
return attributes
def __getattr__(self, name):
try:
self.connect()
except ConnectionError:
raise AttributeError(name)
return getattr(self._proxy, name)
def __setattr__(self, name, value):
if name.startswith("_Client") or name in (
"_proxy",
"_config_node",
"_connexion_info",
):
return super().__setattr__(name, value)
self.connect()
return setattr(self._proxy, name, value)
def __info__(self):
info = self._connexion_info
pid = info.pid
hostname = info.hostname
if not pid:
pid_info = "**Not Started**"
if hostname:
pid_info = pid_info + f" (previously Running on host **{hostname}**)"
else:
pid_info = f"with pid {pid}"
info_str = f"Service {self.__name} on host {hostname}:{info.port} {pid_info}"
try:
self.connect()
except ConnectionError:
pass
if self._proxy is not None:
try:
extra_info = self._proxy.__info__()
except:
pass
else:
info_str += "\n\n" + extra_info
return info_str
@property
def config(self):
return self._config_node
def close(self):
try:
if self._proxy:
self._proxy.close()
finally:
self._proxy = None
def connect(self):
if self._proxy is None:
hostname = self._connexion_info.hostname
port = self._connexion_info.port
if hostname is None or port is None:
raise ConnectionError(
f"Server service **{self.__name}** has never been started"
)
pid = self._connexion_info.pid
if not pid:
raise ConnectionError(
f"Server service **{self.__name}** is Down, "
f"previously started on **{hostname}**"
)
client = rpc.Client(
f"tcp://{hostname}:{port}", disconnect_callback=self.close
)
self._proxy = plugins.get_local_client(client, port, self.config)
@property
def counters(self):
self.connect()
if isinstance(self._proxy, protocols.CounterContainer):
return self._proxy.counters
if isinstance(self._proxy, counter.Counter):
return protocols.counter_namespace([self._proxy])
raise NotImplementedError
# --- SERVER ---
def _set_info(info, port):
if info is not None:
info._proxy.ttl(None)
info.hostname = gevent.socket.gethostname()
info.port = port
info.pid = os.getpid()
info.started = time.time()
def _start_server(obj, name, info, services, server_loop, obj_to_server):
server = rpc.Server(obj)
server.bind("tcp://localhost:0")
port = server._socket.getsockname()[1]
if name is not None:
services[name] = info, server
obj_to_server[obj] = info, server
_set_info(info, port)
_Port2Object[port] = obj
print(f"Staring service {name} for object {obj} at port {port}")
server_loop.append(gevent.spawn(server.run))
return port
def main():
"""
Server service
"""
# Argument parsing
parser = argparse.ArgumentParser()
parser.add_argument("name", nargs="+", help="named objects to export as a service")
args = parser.parse_args()
config = static.get_config()
services = dict()
server_loop = list()
obj_to_server = dict()
def start_sub_server(obj):
_, server = obj_to_server.get(obj, (None, None))
if server is None:
port = _start_server(obj, None, None, services, server_loop, obj_to_server)
else:
port = server._socket.getsockname()[1]
return port
def _start_service(name):
info = settings.Struct(f"service:{name}")
obj = config._get(name, direct_access=True)
if obj in obj_to_server:
prev_info, server = obj_to_server[obj]
if prev_info is None:
_set_info(info, server._socket.getsockname()[1])
obj_to_server[obj] = info, server
service[name] = info, server
return obj
obj = plugins.get_local_server(obj, start_sub_server)
_start_server(obj, name, info, services, server_loop, obj_to_server)
return obj
# Patch config get to start a service on any dependencies.
config.get = _start_service
try:
for name in set(args.name):
# check that is defined as a service
config_node = config.get_config(name)
if config_node is None:
raise ValueError(
f"Can't get object named **{name}**, not in configuration"
)
if not config_node.is_service:
raise ValueError(f"object **{name}** is not defined as a service")
_start_service(name)
try:
gevent.wait()
except KeyboardInterrupt:
gevent.killall(server_loop)
finally:
for info, server in services.values():
try:
server.close()
except:
pass
# Set redis key to a time to live of 2 months
# Just to clean it in case it's no more used.
if info is not None:
info._proxy.ttl(2 * 30 * 24 * 3600)
info.pid = 0
from . import main
main()
# -*- coding: utf-8 -*-
#
# This file is part of the bliss project
#
# Copyright (c) 2015-2020 Beamline Control Unit, ESRF
# Distributed under the GNU LGPLv3. See LICENSE for more info.
import os
import functools
import traceback
import pkgutil
import copyreg
_CLIENT_LOCAL_CALLBACK = dict()
_SERVER_LOCAL_CALLBACK = dict()
_INIT_FLAG = False
def _init_plugins():
for importer, module_name, _ in pkgutil.iter_modules(
[os.path.dirname(__file__)], prefix="bliss.comm.service.plugins."
):
m = __import__(module_name, globals(), locals(), [""], 0)
if not hasattr(m, "init"):
continue
try:
m.init()
except:
traceback.print_exc()
def _check_init(func):
@functools.wraps(func)
def f(*args, **keys):
global _INIT_FLAG
if not _INIT_FLAG:
_init_plugins()
_INIT_FLAG = True
return func(*args, **keys)
return f
def _prio_sort(item):
return item[1][0]
def register_local_client_callback(object_type, cbk, priority=0):
_CLIENT_LOCAL_CALLBACK[object_type] = priority, cbk
def register_local_server_callback(object_type, cbk, priority=0):
_SERVER_LOCAL_CALLBACK[object_type] = priority, cbk
@_check_init
def get_local_client(client, port, config):
for object_type, (priority, cbk) in sorted(
_CLIENT_LOCAL_CALLBACK.items(), key=_prio_sort, reverse=True
):
if isinstance(client, object_type):
return cbk(client, port, config)
return client
@_check_init
def get_local_server(obj, start_sub_server):
for object_type, (priority, cbk) in sorted(
_SERVER_LOCAL_CALLBACK.items(), key=_prio_sort, reverse=True
):
if isinstance(obj, object_type):
return cbk(obj, start_sub_server)
return obj
# Local server object
def _LocalServerObject(port):
from .. import get_object_from_port
return get_object_from_port(port)
def _pickle_LocalServerObject(local_object):
return _LocalServerObject, (local_object._port,)
def add_local_server_object(klass):
"""
If some objects need to be pickle from the client to
the server.
The matching is done by the port number locally.
So on the client side the object need to store the
distant server port for this object as **_port**
"""
copyreg.pickle(klass, _pickle_LocalServerObject)
# -*- coding: utf-8 -*-
#
# This file is part of the bliss project
#
# Copyright (c) 2015-2020 Beamline Control Unit, ESRF
# Distributed under the GNU LGPLv3. See LICENSE for more info.
import copyreg
import gevent
import weakref
from bliss.common import counter, proxy
from bliss.comm import rpc
from . import (
register_local_client_callback,
register_local_server_callback,
add_local_server_object,
)
class _LocalCounterController(proxy.Proxy):
__slots__ = list(proxy.Proxy.__slots__) + ["_counters"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._counters = weakref.WeakValueDictionary()
def create_chain_node(self):
# This has to be Local not remote
klass = getattr(self.__target__, "__class__")
return klass.create_chain_node(self)
def get_acquisition_object(self, acq_params, ctrl_params, parent_acq_params):
# This has to be Local not remote
klass = getattr(self.__target__, "__class__")
return klass.get_acquisition_object(
self, acq_params, ctrl_params, parent_acq_params
)
def __eq__(self, other):
try:
obj = other.__wrapped__
except AttributeError:
return False
else:
return obj.connection_address == self.__target__.connection_address
def __hash__(self):
return id(self.__target__)
class _LocalClientCounter(proxy.Proxy):
__slots__ = list(proxy.Proxy.__slots__) + ["_port"]
def __init__(self, client, port, config):
super().__init__(None)
self._port = port
self.__target__ = client
@property
def _counter_controller(self):
cc_client = self.__target__._counter_controller
controller = _LocalCounterController(None)
controller.__target__ = cc_client
return controller
@property
def shape(self):
return tuple(self.__target__.shape)
add_local_server_object(_LocalClientCounter)
def _server_counter_controller(obj, start_sub_server):
cc = obj._counter_controller
port = start_sub_server(cc)
hostname = gevent.socket.gethostname()
class Cnt(proxy.Proxy):
__target__ = obj
@property
def _counter_controller(self):
return rpc._SubServer(f"tcp://{hostname}:{port}")
return Cnt(None)
def init():
register_local_client_callback(counter.Counter, _LocalClientCounter)
register_local_server_callback(counter.Counter, _server_counter_controller)
# -*- coding: utf-8 -*-
#
# This file is part of the bliss project
#
# Copyright (c) 2015-2020 Beamline Control Unit, ESRF
# Distributed under the GNU LGPLv3. See LICENSE for more info.
import copyreg
import gevent
from bliss.common import protocols
from bliss.common import temperature, proxy
from bliss.common import counter
from bliss.comm import rpc
from bliss import global_map
from . import (
register_local_client_callback,
register_local_server_callback,
add_local_server_object,
)
from .counter import _LocalCounterController
# Input and Output temperature object
# --- Client Side ---
class _LocalSamplingCounter(counter.SamplingCounter):
pass
# Don't need any information in **read** so counter can be None
def _None(*args):
return None
def _pickle_None(*args):
return _None, ()
copyreg.pickle(_LocalSamplingCounter, _pickle_None)
class _LocalTempObject(_LocalCounterController):
__slots__ = list(_LocalCounterController.__slots__) + ["_port"]
def __init__(self, client, port, config):
super().__init__(None)
self._port = port
self.__target__ = client
global_map.register(self, parents_list=["counters"])
@property
def counters(self):
return protocols.counter_namespace(
[_LocalSamplingCounter(self.name, self, unit=self.config.get("unit"))]
)
add_local_server_object(_LocalTempObject)
# --- Server side ---
def _server_temp_controller(obj, start_sub_server):
ctrl = obj.controller
port = start_sub_server(ctrl)
hostname = gevent.socket.gethostname()
class Cnt(proxy.Proxy):
__target__ = obj
@property
def controller(self):
return rpc._SubServer(f"tcp://{hostname}:{port}")
return Cnt(None)
# Loop object
# --- Client side ---
class _LocalLoopObject(_LocalCounterController):
__slots__ = list(proxy.Proxy.__slots__) + ["_port"]
def __init__(self, client, port, config):
super().__init__(None)
self._port = port
self.__target__ = client
global_map.register(self, parents_list=["counters"])
@property
def input(self):
remote_input = self.__target__.input
connection_address = self.__target__.connection_address
port = int(connection_address.split(":")[-1])
return _LocalTempObject(remote_input, port, remote_input.config)
@property
def output(self):
remote_output = self.__target__.output
connection_address = self.__target__.connection_address
port = int(connection_address.split(":")[-1])
return _LocalTempObject(remote_output, port, remote_output.config)
@property
def counters(self):
input_obj = self.input
output_obj = self.output
return protocols.counter_namespace(input_obj.counters + output_obj.counters)
add_local_server_object(_LocalLoopObject)
# --- Server side ---
def _server_loop_temp_controller(obj, start_sub_server):
ctrl = obj.controller
controller_port = start_sub_server(ctrl)
hostname = gevent.socket.gethostname()
input_object = obj.input
input_port = start_sub_server(input_object)
output_object = obj.output
output_port = start_sub_server(output_object)
class Cnt(proxy.Proxy):
__target__ = obj
@property
def controller(self):
return rpc._SubServer(f"tcp://{hostname}:{controller_port}")
@property
def input(self):
return rpc._SubServer(f"tcp://{hostname}:{input_port}")
@property
def output(self):
return rpc._SubServer(f"tcp://{hostname}:{output_port}")
return Cnt(None)
def init():
# Input
register_local_client_callback(temperature.Input, _LocalTempObject)
register_local_server_callback(temperature.Input, _server_temp_controller)
# Output
register_local_client_callback(temperature.Output, _LocalTempObject)
register_local_server_callback(temperature.Output, _server_temp_controller)
# Loop
register_local_client_callback(temperature.Loop, _LocalLoopObject)
register_local_server_callback(temperature.Loop, _server_loop_temp_controller)
......@@ -62,6 +62,7 @@ from bliss.config.conductor import client
from bliss.config import channels
from bliss.common.utils import prudent_update, Singleton
from bliss import global_map
from bliss.comm import service
def get_config(base_path="", timeout=3., raise_yaml_exc=True):
......@@ -249,13 +250,17 @@ class ConfigNode(MutableMapping):
# key which triggers a YAML_ collection to be identified as a bliss named item
NAME_KEY = "name"
USER_TAG_KEY = "user_tag"
RPC_SERVICE_KEY = "service"
indexed_nodes = weakref.WeakValueDictionary()
tagged_nodes = defaultdict(weakref.WeakSet)
services = weakref.WeakSet()
@staticmethod
def reset_cache():
ConfigNode.indexed_nodes = weakref.WeakValueDictionary()
ConfigNode.tagged_nodes = defaultdict(weakref.WeakSet)
ConfigNode.services = weakref.WeakSet()
@staticmethod
def goto_path(d, path_as_list, key_error_exception=True):
......@@ -383,6 +388,8 @@ class ConfigNode(MutableMapping):
user_tags = value if isinstance(value, MutableSequence) else [value]
for tag in user_tags:
ConfigNode.tagged_nodes[tag].add(node)
elif key == ConfigNode.RPC_SERVICE_KEY:
ConfigNode.services.add(self)
self._data[key] = convert_value(value, self)
def setdefault(self, key, value):
......@@ -438,6 +445,14 @@ class ConfigNode(MutableMapping):