connection.py 33.1 KB
Newer Older
1
2
3
4
# -*- coding: utf-8 -*-
#
# This file is part of the bliss project
#
Benoit Formet's avatar
Benoit Formet committed
5
# Copyright (c) 2015-2020 Beamline Control Unit, ESRF
6
7
# Distributed under the GNU LGPLv3. See LICENSE for more info.

8
import time
9
import weakref
10
import os, sys
11
import gevent
12
import gevent.lock
13
from gevent import socket, select, event, queue
14
from . import protocol
15
import netifaces
16
from functools import wraps
17
import warnings
18
from collections import namedtuple
19

20
from bliss.common.greenlet_utils import protect_from_kill, AllowKill
21
from bliss.config.conductor import redis_connection
22

23

24
class StolenLockException(RuntimeError):
    """Raised in a greenlet holding a Beacon lock that another client stole.

    The connection's lock handler kills every greenlet that still holds one
    of the stolen names with this exception.
    """
26

27
28

def ip4_broadcast_addresses(default_route_only=False):
    """Return the IPv4 broadcast addresses of the local network interfaces.

    The lookup is best-effort: a missing default route or an interface that
    cannot be queried is skipped instead of aborting the whole lookup.

    :param bool default_route_only: only consider the interface of the
        default IPv4 route (empty result when there is none)
    :returns list: broadcast addresses (str), possibly empty
    """
    if default_route_only:
        try:
            interfaces = [netifaces.gateways()["default"][netifaces.AF_INET][1]]
        except (KeyError, IndexError, TypeError):
            # no default IPv4 route configured
            return []
    else:
        interfaces = netifaces.interfaces()

    addresses = []
    for interface in interfaces:
        try:
            links = netifaces.ifaddresses(interface).get(netifaces.AF_INET, [])
        except (ValueError, OSError):
            # interface vanished or cannot be queried: skip only this one
            continue
        addresses.extend(
            link["broadcast"] for link in links if link.get("broadcast")
        )
    return addresses
44

45

46
def ip4_broadcast_discovery(udp):
    """Send a Beacon discovery probe to every local IPv4 broadcast address.

    :param udp: UDP socket used to send the ``b"Hello"`` datagrams
    """
    destination_port = protocol.DEFAULT_UDP_SERVER_PORT
    for broadcast_addr in ip4_broadcast_addresses():
        udp.sendto(b"Hello", (broadcast_addr, destination_port))
49

50

51
52
53
54
55
56
57
58
59
60
61
62
def compare_hosts(host1, host2):
    """Tell whether two host names designate the same machine.

    ``"localhost"`` matches this machine's host name; otherwise both names
    are resolved and their IPv4 addresses compared.

    :returns bool: True when the hosts match; False when they differ or when
        one of the names cannot be resolved
    """
    if host1 == host2:
        return True
    local_name = socket.gethostname()
    if (host1, host2) in (("localhost", local_name), (local_name, "localhost")):
        return True
    try:
        return socket.gethostbyname(host1) == socket.gethostbyname(host2)
    except socket.gaierror:
        # An unresolvable name cannot designate the same host; returning
        # False keeps e.g. a stray discovery reply from aborting the caller.
        return False


63
def check_connect(func):
    """Decorator for `Connection` methods: call ``self.connect()`` before
    running the wrapped method, so the Beacon socket is available.
    """

    @wraps(func)
    def connected_call(self, *args, **kwargs):
        self.connect()
        return func(self, *args, **kwargs)

    return connected_call

71

72
class ConnectionException(Exception):
    """Beacon connection error, e.g. a host name unknown to DNS during
    discovery."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
75

76

77
78
# Key of a shared Redis connection pool: one pool per Redis database
RedisPoolId = namedtuple("RedisPoolId", ["db"])
# Key of a shared Redis proxy: database + client-side caching flag
RedisProxyId = namedtuple("RedisProxyId", ["db", "caching"])
79
80


81
82
83
84
85
86
87
88
89
90
91
92
class Connection:
    """A Beacon connection is created and destroyed like this:

        connection = Connection(host=..., port=...)
        connection.connect()  # not required
        connection.close()  # closes all Redis connections as well

    When `host` is not provided, it falls back to environment variable BEACON_HOST.
    When `port` is not provided, it falls back to environment variable BEACON_PORT.
    When either does not have a fallback, use UDP broadcasting to find Beacon.

    The Beacon connection also manages all Redis connections.
93
    Use `get_redis_proxy` to create a connection or use an existing one.
94
95
96
97
98
99
100
101
102
    Use `close_all_redis_connections` to close all Redis connections.

    Beacon locks: the methods `lock`, `unlock` and  `who_locked` provide
    a mechanism to acquire and release locks in the Beacon server.

    Beacon manages configuration files (YAML) and python modules. This class
    allows fetching and manipulating those.
    """

103
104
    CLIENT_NAME = f"{socket.gethostname()}:{os.getpid()}"

105
    class WaitingLock:
        """Waiter for the reply to a LOCK request.

        Inside the ``with`` block, the waiter's queue is listed in
        ``cnt._pending_lock`` under the request payload, where the lock
        message handler delivers replies to it.
        """

        def __init__(self, cnt, priority, device_name):
            # weak reference: a waiter must not keep the connection alive
            self._cnt = weakref.ref(cnt)
            raw_names = [name.encode() for name in device_name]
            self._msg = b"%d|%s" % (priority, b"|".join(raw_names))
            self._queue = queue.Queue()

        def msg(self):
            """Payload of the LOCK request: ``b"<priority>|<name>|..."``."""
            return self._msg

        def get(self):
            """Block until a lock reply is delivered and return it."""
            return self._queue.get()

        def __enter__(self):
            cnt = self._cnt()
            cnt._pending_lock.setdefault(self._msg, []).append(self._queue)
            return self

        def __exit__(self, *args):
            cnt = self._cnt()
            waiters = cnt._pending_lock.pop(self._msg, [])
            try:
                waiters.remove(self._queue)
            except ValueError:
                pass  # already popped when a LOCK_OK reply was delivered
            # Only put the list back when other waiters remain: re-storing an
            # empty list would leave one stale entry per distinct request.
            if waiters:
                cnt._pending_lock[self._msg] = waiters

    class WaitingQueue(object):
        """Private queue for replies tagged with a unique message key.

        Each instance consumes the next value of ``cnt._message_key``. Inside
        the ``with`` block its queue is registered in ``cnt._message_queue``
        under that key, where the reader routes keyed replies.
        """

        def __init__(self, cnt):
            self._cnt = weakref.ref(cnt)
            key_number = cnt._message_key
            self._message_key = str(key_number).encode()
            cnt._message_key = key_number + 1
            self._queue = queue.Queue()

        def message_key(self):
            """Key (bytes) to prefix requests with, so replies come back here."""
            return self._message_key

        def get(self):
            """Block until the next reply arrives and return it."""
            return self._queue.get()

        def queue(self):
            """The underlying queue, e.g. to iterate over multi-part replies."""
            return self._queue

        def __enter__(self):
            registry = self._cnt()._message_queue
            registry[self._message_key] = self._queue
            return self

        def __exit__(self, *args):
            registry = self._cnt()._message_queue
            registry.pop(self._message_key, None)
161

162
    def __init__(self, host=None, port=None):
        """
        :param str host: Beacon host name, possibly given as "host:port".
            Falls back to the BEACON_HOST environment variable.
        :param port: Beacon TCP port (int or numeric str) or path of a Unix
            domain socket. Falls back to the BEACON_PORT environment variable.
            When host or port stays undefined, Beacon is located by UDP
            discovery on `connect`.
        :raises RuntimeError: when `port` is neither a number nor a readable path
        """
        if host is None:
            host = os.environ.get("BEACON_HOST")
        if host is not None and ":" in host:
            # Split on the last colon only: a value with several colons then
            # reaches the port validation below instead of failing with an
            # obscure "too many values to unpack" error.
            host, port = host.rsplit(":", 1)
        if port is None:
            port = os.environ.get("BEACON_PORT")
        if port is not None:
            try:
                port = int(port)
            except ValueError:
                # not a number: it must be the path of a Unix domain socket
                if not os.access(port, os.R_OK):
                    raise RuntimeError("port can be a tcp port (int) or unix socket")

        # Beacon connection
        self._host = host
        self._port = port
        # self._port_number is here to keep trace of port number
        # as self._port can be replaced by unix socket name.
        self._port_number = port
        self._socket = None
        self._connect_lock = gevent.lock.Semaphore()
        self._connected = gevent.event.Event()
        self._send_lock = gevent.lock.Semaphore()
        self._uds_query_event = event.Event()
        self._redis_query_event = event.Event()
        # Routing of keyed replies to WaitingQueue instances
        self._message_key = 0
        self._message_queue = {}  # {message key (bytes): queue}
        self._clean_beacon_cache()
        self._raw_read_task = None

        # Beacon locks
        self._pending_lock = {}  # {LOCK payload: [queue]} (see WaitingLock)
        # Count how many times an object has been locked in the
        # current process per greenlet:
        self._lock_counters = weakref.WeakKeyDictionary()  # {Greenlet -> {str: int}}

        # Redis connections
        self._get_redis_lock = gevent.lock.Semaphore()

        # Keep hard references to all shared Redis proxies
        # (these proxies don't hold a `redis.Redis.Connection` instance)
        self._shared_redis_proxies = {}  # {RedisProxyId: RedisDbProxyBase}

        # Keep weak references to all shared Redis connection pools:
        self._redis_connection_pools = (
            weakref.WeakValueDictionary()
        )  # {RedisPoolId: RedisDbConnectionPool}

        # Keep weak references to all cached Redis proxies which are not
        # reused (although they could be, but their cache would keep growing)
        self._non_shared_redis_proxies = weakref.WeakSet()  # {RedisDbProxyBase}

        # Hard references to the connection pools are held by the
        # Redis proxies themselves. Connections of RedisDbConnectionPool
        # are closed upon garbage collection of RedisDbConnectionPool. So
        # when the proxies to a pool are the only ones having a hard
        # reference to that pool, the connections are closed when all
        # proxies are garbage collected.
221
222
223
224

    def close(self, timeout=None):
        """Disconnection from Beacon and Redis
        """
225
        if self._raw_read_task is not None:
226
            self._raw_read_task.kill(timeout=timeout)
227
228
            self._raw_read_task = None

229
230
    @property
    def uds(self):
231
232
233
234
235
        """
        False: UDS not supported by this platform
        None: Port not defined
        str: Port number
        """
236
237
238
239
240
241
242
243
244
        if sys.platform in ["win32", "cygwin"]:
            return False
        else:
            try:
                int(self._port)
            except ValueError:
                return self._port
            else:
                return None
245

246
    def connect(self):
        """Connect to Beacon unless already connected.

        Beacon is first located with `_discovery` when the host or the port
        is unknown. A Unix domain socket is used when `uds` gives a path,
        TCP otherwise. Then the reader task is started, the UDS query is sent
        when `uds` is None, and `on_connected` runs before the connection is
        flagged as established.
        """
        with self._connect_lock:
            if self._connected.is_set():
                return

            if self._host is None or self._port is None:
                self._host, self._port = self._discovery(self._host)

            uds_path = self.uds
            if uds_path:
                self._socket = self._uds_connect(uds_path)
            else:
                self._socket = self._tcp_connect(self._host, self._port)

            reader = gevent.spawn(self._raw_read)
            reader.name = "BeaconListenTask"
            self._raw_read_task = reader

            if self.uds is None:
                self._uds_query()

            self.on_connected()
            self._connected.set()
Sebastien Petitdemange's avatar
Sebastien Petitdemange committed
275

276
277
278
279
280
    def on_connected(self):
        """Hook run for every new connection: register `CLIENT_NAME` with
        Beacon (3 s timeout)."""
        client_name = self.CLIENT_NAME
        self._set_get_clientname(timeout=3, name=client_name)

281
    def _discovery(self, host, timeout=3.0):
282
283
284
285
        # Manage timeout
        if timeout < 0:
            if host is not None:
                raise RuntimeError(
286
                    f"Conductor server on host `{host}' does not reply (check beacon server)"
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
                )
            raise RuntimeError(
                "Could not find any conductor "
                "(check Beacon server and BEACON_HOST environment variable)"
            )
        started = time.time()

        # Create UDP socket
        udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        udp.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
        udp.settimeout(0.2)

        # Send discovery
        address_list = [host] if host is not None else ip4_broadcast_addresses(True)
        for addr in address_list:
            try:
                udp.sendto(b"Hello", (addr, protocol.DEFAULT_UDP_SERVER_PORT))
            except socket.gaierror:
                raise ConnectionException("Host `%s' is not found in DNS" % addr)
306

307
308
309
        # Loop over UDP messages
        try:
            for message in iter(lambda: udp.recv(8192), None):
310

311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
                # Decode message
                raw_host, raw_port = message.split(b"|")
                received_host = raw_host.decode()
                received_port = int(raw_port.decode())

                # Received host doesn't match the host
                if host is not None and not compare_hosts(host, received_host):
                    continue

                # A matching host has been found
                return received_host, received_port

        # Try again
        except socket.timeout:
            timeout -= time.time() - started
            return self._discovery(host, timeout=timeout)

    def _tcp_connect(self, host, port):
        """Open a TCP connection to the Beacon server.

        :param str host: Beacon host
        :param int port: Beacon TCP port
        :returns: the connected socket
        :raises RuntimeError: when the server cannot be reached
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # 0x10 is IPTOS_LOWDELAY
        sock.setsockopt(socket.SOL_IP, socket.IP_TOS, 0x10)
        try:
            sock.connect((host, port))
        except IOError as exc:
            # don't leak the socket of a failed connection
            sock.close()
            raise RuntimeError(
                "Conductor server on host `{}:{}' does not reply (check beacon server)".format(
                    host, port
                )
            ) from exc
        return sock
340
341

    def _uds_connect(self, uds_path):
342
343
344
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        sock.connect(uds_path)
        return sock
345

346
    def _uds_query(self, timeout=3.0):
        """Send UDS_QUERY with this machine's host name, then wait at most
        `timeout` seconds for `_uds_query_event`.

        The event is presumably set by the reply handler of the reader task
        (not visible here). Waiting does not raise when the timeout expires.
        """
        payload = socket.gethostname().encode()
        self._uds_query_event.clear()
        self._sendall(protocol.message(protocol.UDS_QUERY, payload))
        self._uds_query_event.wait(timeout)
352

353
    @check_connect
    def lock(self, *devices_name, **params):
        """Acquire Beacon locks on the given names.

        :param devices_name: names of the objects to lock
        :param int priority: lock priority (keyword, default 50)
        :param float timeout: seconds before giving up (keyword, default 10)
        :raises RuntimeError: when the lock is not obtained in time
        """
        priority = params.get("priority", 50)
        timeout = params.get("timeout", 10)
        if not devices_name:
            return

        with self.WaitingLock(self, priority, devices_name) as waiter:
            timeout_error = RuntimeError("lock timeout (%s)" % str(devices_name))
            with gevent.Timeout(timeout, timeout_error):
                # LOCK_RETRY replies also wake us up: re-send the request
                # until the server grants the lock
                status = None
                while status != protocol.LOCK_OK_REPLY:
                    self._sendall(protocol.message(protocol.LOCK, waiter.msg()))
                    status = waiter.get()
        self._increment_lock_counters(devices_name)
369
370

    @check_connect
    def unlock(self, *devices_name, **params):
        """Release Beacon locks on the given names.

        :param devices_name: names of the objects to unlock
        :param int priority: priority sent with the request (keyword, default 50)
        :param float timeout: seconds allowed to send the request (keyword, default 1)
        :raises RuntimeError: when the request cannot be sent in time
        """
        timeout = params.get("timeout", 1)
        priority = params.get("priority", 50)
        if not devices_name:
            return

        fields = [b"%d" % priority] + [name.encode() for name in devices_name]
        payload = b"|".join(fields)
        timeout_error = RuntimeError("unlock timeout (%s)" % str(devices_name))
        with gevent.Timeout(timeout, timeout_error):
            self._sendall(protocol.message(protocol.UNLOCK, payload))
        self._decrement_lock_counters(devices_name)

    def _increment_lock_counters(self, devices_name):
        """Add one lock to the current greenlet's counter of each name."""
        counters = self._lock_counters.setdefault(gevent.getcurrent(), dict())
        for name in devices_name:
            counters[name] = counters.get(name, 0) + 1

    def _decrement_lock_counters(self, devices_name):
        """Remove one lock from the current greenlet's counter of each name.

        Names whose count reaches zero are forgotten, and the greenlet entry
        is dropped once the greenlet holds no lock at all. Names released
        without a matching `lock` in this greenlet are ignored.
        """
        current = gevent.getcurrent()
        counters = self._lock_counters.get(current)
        if counters is None:
            return
        for name in devices_name:
            remaining = counters.get(name, 0) - 1
            if remaining > 0:
                counters[name] = remaining
            else:
                counters.pop(name, None)
        # Forget the greenlet only when *none* of its names is still locked:
        # names locked earlier but not released by this call must stay
        # tracked, so that a LOCK_STOLEN notification still reaches it.
        if not counters:
            self._lock_counters.pop(current, None)
405
406

    @check_connect
    def get_redis_connection_address(self, timeout=3.0):
        """Return the (host, port) of the Redis server, as given by Beacon.

        The answer is cached for the duration of the Beacon connection.

        :raises RuntimeError: when no answer arrives within `timeout` seconds
        """
        if self._redis_host is not None:
            return self._redis_host, self._redis_port

        timeout_error = RuntimeError("Can't get redis connection information")
        with gevent.Timeout(timeout, timeout_error):
            # Re-query until `_redis_host` is filled; presumably done by the
            # REDIS_QUERY reply handler of the reader task (not visible here).
            while self._redis_host is None:
                self._redis_query_event.clear()
                self._sendall(protocol.message(protocol.REDIS_QUERY))
                self._redis_query_event.wait()
        return self._redis_host, self._redis_port
421

422
    def _get_redis_conn_pool(self, proxyid: RedisProxyId):
        """Return the connection pool for the proxy's Redis database,
        creating one when no live pool exists.

        Pools are only weakly referenced by this connection.

        :param RedisProxyId proxyid:
        :returns RedisDbConnectionPool:
        """
        poolid = RedisPoolId(db=proxyid.db)
        try:
            return self._redis_connection_pools[poolid]
        except KeyError:
            pass
        new_pool = self._create_redis_conn_pool(poolid)
        self._redis_connection_pools[poolid] = new_pool
        return new_pool

436
437
438
439
440
441
442
443
444
445
446
    def _create_redis_conn_pool(self, proxyid: RedisPoolId):
        """Create a new Redis connection pool for one database.

        The address is the one from `get_redis_connection_address`, except
        for database 1. There the Redis data server address is used when
        `get_redis_data_server_connection_address` succeeds.

        :param RedisPoolId proxyid: pool identifier; only ``db`` is used
            (`_get_redis_conn_pool` passes a `RedisPoolId`, not a proxy id)
        :returns RedisDbConnectionPool:
        """
        address = self.get_redis_connection_address()
        if proxyid.db == 1:
            try:
                address = self.get_redis_data_server_connection_address()
            except RuntimeError:  # Service not running on beacon server
                pass
        host, port = address
        # For "localhost" the port is used as a Unix socket path
        if host == "localhost":
            redis_url = f"unix://{port}"
        else:
            redis_url = f"redis://{host}:{port}"
        return redis_connection.create_connection_pool(
            redis_url, proxyid.db, client_name=self.CLIENT_NAME
        )
456

457
    def _get_shared_redis_proxy(self, proxyid: RedisProxyId):
        """Return the shared proxy for `proxyid`, creating it on first use.

        Shared proxies are held by this connection until
        `close_all_redis_connections` is called.
        """
        with self._get_redis_lock:
            existing = self._shared_redis_proxies.get(proxyid)
            if existing is not None:
                return existing
            pool = self._get_redis_conn_pool(proxyid)
            created = pool.create_proxy(caching=proxyid.caching)
            self._shared_redis_proxies[proxyid] = created
            return created
467

468
    def _get_non_shared_redis_proxy(self, proxyid: RedisProxyId):
        """Create a new proxy for `proxyid`; it is never reused.

        The underlying connection pool is shared per database. The proxy
        itself is only weakly tracked, so `close_all_redis_connections` can
        still reach it.
        """
        with self._get_redis_lock:
            pool = self._get_redis_conn_pool(proxyid)
            fresh_proxy = pool.create_proxy(caching=proxyid.caching)
            self._non_shared_redis_proxies.add(fresh_proxy)
        return fresh_proxy

    def get_redis_connection(self, **kw):
        warnings.warn("Use 'get_redis_proxy' instead", FutureWarning)
        return self.get_redis_proxy(**kw)

481
482
    def get_redis_proxy(self, db=0, caching=False, shared=True):
        """Return a greenlet-safe proxy to a Redis database.

        :param int db: Redis database to which we need a proxy
        :param bool caching: enable client-side caching
        :param bool shared: reuse the proxy held by this Beacon connection
            (True) or create a new, private one (False)
        """
        proxyid = RedisProxyId(db=db, caching=caching)
        if not shared:
            return self._get_non_shared_redis_proxy(proxyid)
        return self._get_shared_redis_proxy(proxyid)
493

494
    def close_all_redis_connections(self):
495
496
497
498
499
500
501
502
        # To close `redis.connection.Connection` you need to call its
        # `disconnect` method (also called on garbage collection).
        #
        # Connection pools have a `disconnect` method that disconnect
        # all their connections, which means close and destroy their
        # socket instances.
        #
        # Note: closing a proxy will not close any connections
503
504
505
506
        proxies = list(self._non_shared_redis_proxies)
        proxies.extend(self._shared_redis_proxies.values())
        self._shared_redis_proxies = dict()
        self._non_shared_redis_proxies = weakref.WeakSet()
507
        for proxy in proxies:
508
            proxy.close()
509
            proxy.connection_pool.disconnect()
510

511
512
513
514
    def clean_all_redis_connection(self):
        warnings.warn("Use 'close_all_redis_connections' instead", FutureWarning)
        self.close_all_redis_connections()

515
    @check_connect
    def get_config_file(self, file_path, timeout=3.0):
        """Return the content of a configuration file served by Beacon.

        :param str file_path: path of the file in the Beacon configuration
        :raises RuntimeError: on timeout, or the error reported by Beacon
        """
        with gevent.Timeout(timeout, RuntimeError("Can't get configuration file")):
            with self.WaitingQueue(self) as wq:
                request = b"%s|%s" % (wq.message_key(), file_path.encode())
                self._sendall(protocol.message(protocol.CONFIG_GET_FILE, request))
                reply = wq.get()
                if not isinstance(reply, RuntimeError):
                    return reply
                raise reply
527

528
    @check_connect
    def get_config_db_tree(self, base_path="", timeout=3.0):
        """Return the configuration tree under `base_path`, decoded from the
        JSON reply of Beacon.

        :raises RuntimeError: on timeout, or the error reported by Beacon
        """
        with gevent.Timeout(timeout, RuntimeError("Can't get configuration tree")):
            with self.WaitingQueue(self) as wq:
                request = b"%s|%s" % (wq.message_key(), base_path.encode())
                self._sendall(protocol.message(protocol.CONFIG_GET_DB_TREE, request))
                reply = wq.get()
                if isinstance(reply, RuntimeError):
                    raise reply
                import json

                return json.loads(reply)

coutinho's avatar
coutinho committed
542
    @check_connect
    def remove_config_file(self, file_path, timeout=3.0):
        """Ask Beacon to remove a configuration file; every reply is printed.

        NOTE(review): replies, including RuntimeError instances, are printed
        rather than raised, unlike the other request methods. Confirm this
        is intended.
        """
        with gevent.Timeout(timeout, RuntimeError("Can't remove configuration file")):
            with self.WaitingQueue(self) as wq:
                request = b"%s|%s" % (wq.message_key(), file_path.encode())
                self._sendall(protocol.message(protocol.CONFIG_REMOVE_FILE, request))
                for reply in wq.queue():
                    print(reply)
coutinho's avatar
coutinho committed
550

551
    @check_connect
    def move_config_path(self, src_path, dst_path, timeout=3.0):
        """Ask Beacon to move a configuration path; every reply is printed.

        NOTE(review): replies, including RuntimeError instances, are printed
        rather than raised. Confirm this is intended.
        """
        with gevent.Timeout(timeout, RuntimeError("Can't move configuration file")):
            with self.WaitingQueue(self) as wq:
                request = b"|".join((wq.message_key(), src_path.encode(), dst_path.encode()))
                self._sendall(protocol.message(protocol.CONFIG_MOVE_PATH, request))
                for reply in wq.queue():
                    print(reply)
563

564
    @check_connect
    def get_config_db(self, base_path="", timeout=30.0):
        """Return the configuration files under `base_path`.

        :returns list: (file path, file content) pairs, as str
        :raises RuntimeError: on timeout, or the error reported by Beacon
        """
        files = []
        with gevent.Timeout(timeout, RuntimeError("Can't get configuration file")):
            with self.WaitingQueue(self) as wq:
                request = b"%s|%s" % (wq.message_key(), base_path.encode())
                self._sendall(protocol.message(protocol.CONFIG_GET_DB_BASE_PATH, request))
                for reply in wq.queue():
                    if isinstance(reply, RuntimeError):
                        raise reply
                    path, content = self._get_msg_key(reply)
                    if path is not None:
                        files.append((path.decode(), content.decode()))
        return files

Sebastien Petitdemange's avatar
Sebastien Petitdemange committed
580
    @check_connect
    def set_config_db_file(self, file_path, content, timeout=3.0):
        """Write `content` into the configuration file `file_path` on Beacon.

        :raises RuntimeError: on timeout; any reply received is raised as an error
        """
        with gevent.Timeout(timeout, RuntimeError("Can't set config file")):
            with self.WaitingQueue(self) as wq:
                request = b"|".join((wq.message_key(), file_path.encode(), content.encode()))
                self._sendall(protocol.message(protocol.CONFIG_SET_DB_FILE, request))
                # every message delivered on this queue is an error report
                for error in wq.queue():
                    raise error

593
    @check_connect
    def get_python_modules(self, base_path="", timeout=3.0):
        """Return the python modules found under `base_path`.

        :returns list: (module name, full path) pairs, as str
        :raises RuntimeError: on timeout, or the error reported by Beacon
        """
        modules = []
        with gevent.Timeout(timeout, RuntimeError("Can't get python modules")):
            with self.WaitingQueue(self) as wq:
                request = b"%s|%s" % (wq.message_key(), base_path.encode())
                self._sendall(protocol.message(protocol.CONFIG_GET_PYTHON_MODULE, request))
                for reply in wq.queue():
                    if isinstance(reply, RuntimeError):
                        raise reply
                    module_name, full_path = self._get_msg_key(reply)
                    modules.append((module_name.decode(), full_path.decode()))
        return modules
Vincent Michel's avatar
Vincent Michel committed
606

607
    @check_connect
    def get_log_server_address(self, timeout=3.0):
        """Get the log host and port from Beacon. Cached for the duration
        of the Beacon connection.

        :returns tuple: (host, port), both as str
        :raises RuntimeError: on timeout, or the error reported by Beacon
        """
        if self._log_server_host is None:
            with gevent.Timeout(
                timeout, RuntimeError("Can't retrieve log server port")
            ):
                with self.WaitingQueue(self) as wq:
                    msg = b"%s|" % wq.message_key()
                    # Go through `_sendall` (guarded by `_send_lock`) so this
                    # request cannot interleave with other greenlets' sends.
                    self._sendall(
                        protocol.message(protocol.LOG_SERVER_ADDRESS_QUERY, msg)
                    )
                    for rx_msg in wq.queue():
                        if isinstance(rx_msg, RuntimeError):
                            raise rx_msg
                        host, port = self._get_msg_key(rx_msg)
                        self._log_server_host = host.decode()
                        self._log_server_port = port.decode()
                        break
        return self._log_server_host, self._log_server_port
629

630
631
    @check_connect
    def get_redis_data_server_connection_address(self, timeout=3.):
        """Get the Redis data host and port from Beacon. Cached for the duration
        of the Beacon connection.

        :returns tuple: (host, port), both as str
        :raises RuntimeError: on timeout, or the error reported by Beacon
        """
        if self._redis_data_host is None:
            with gevent.Timeout(
                timeout, RuntimeError("Can't get redis data server information")
            ):
                with self.WaitingQueue(self) as wq:
                    msg = b"%s|" % wq.message_key()
                    # Go through `_sendall` (guarded by `_send_lock`) so this
                    # request cannot interleave with other greenlets' sends.
                    self._sendall(
                        protocol.message(protocol.REDIS_DATA_SERVER_QUERY, msg)
                    )
                    for rx_msg in wq.queue():
                        if isinstance(rx_msg, RuntimeError):
                            raise rx_msg
                        host, port = rx_msg.split(b"|")
                        self._redis_data_host = host.decode()
                        self._redis_data_port = port.decode()
                        break
        return self._redis_data_host, self._redis_data_port
652

653
    @check_connect
    def set_client_name(self, name, timeout=3.0):
        """Register `name` as the name of this client in Beacon."""
        self._set_get_clientname(name, timeout)
Sebastien Petitdemange's avatar
Sebastien Petitdemange committed
656

657
    @check_connect
    def get_client_name(self, timeout=3.0):
        """Return the name under which Beacon knows this client."""
        client_name = self._set_get_clientname(timeout=timeout)
        return client_name
Sebastien Petitdemange's avatar
Sebastien Petitdemange committed
660
661
662
663
664
665
666

    # Like every other request method, make sure the socket exists before
    # `_sendall` is used (otherwise `self._socket` may still be None).
    @check_connect
    def who_locked(self, *names, timeout=3.0):
        """Ask Beacon which clients hold the locks on `names`.

        :param names: names of the locked objects to query
        :param float timeout: seconds allowed for the complete answer
        :returns dict: {name: client info} for each entry Beacon replies with
        :raises RuntimeError: on timeout, or the error reported by Beacon
        """
        name2client = dict()
        with gevent.Timeout(timeout, RuntimeError("Can't get who lock client name")):
            with self.WaitingQueue(self) as wq:
                raw_names = [b"%s" % wq.message_key()] + [n.encode() for n in names]
                msg = b"|".join(raw_names)
                self._sendall(protocol.message(protocol.WHO_LOCKED, msg))
                for rx_msg in wq.queue():
                    if isinstance(rx_msg, RuntimeError):
                        raise rx_msg
                    name, client_info = rx_msg.split(b"|")
                    name2client[name.decode()] = client_info.decode()
        return name2client

675
676
677
678
679
680
681
682
683
684
685
686
687
    def _set_get_clientname(self, name=None, timeout=3.):
        """Optionally give this client a name in Beacon (when `name` is
        truthy), and return the name under which Beacon knows this client.

        :raises RuntimeError: on timeout, or the error reported by Beacon
        """
        setting = bool(name)
        timeout_msg = "Can't set client name" if setting else "Can't get client name"
        msg_type = protocol.CLIENT_SET_NAME if setting else protocol.CLIENT_GET_NAME
        raw_name = name.encode() if setting else b""

        with gevent.Timeout(timeout, RuntimeError(timeout_msg)):
            with self.WaitingQueue(self) as wq:
                request = b"%s|%s" % (wq.message_key(), raw_name)
                self._sendall(protocol.message(msg_type, request))
                reply = wq.get()
                if isinstance(reply, RuntimeError):
                    raise reply
                return reply.decode()

696
    def _lock_mgt(self, fd, messageType, message):
        """Handle a lock-related message received from Beacon.

        :param fd: Beacon socket, used to answer the server directly
        :param messageType: protocol message type
        :param bytes message: message payload
        :returns bool: True if the message was a lock message (and has been
            handled), False otherwise
        """
        if messageType == protocol.LOCK_OK_REPLY:
            events = self._pending_lock.get(message, [])
            if not events:
                # nobody waits for this lock anymore (e.g. `lock` timed out
                # and left its WaitingLock): give it back immediately
                fd.sendall(protocol.message(protocol.UNLOCK, message))
            else:
                # wake up the oldest waiter of this exact request
                e = events.pop(0)
                e.put(messageType)
            return True
        elif messageType == protocol.LOCK_RETRY:
            # wake up every waiter: each one re-sends its LOCK request
            for waiters in self._pending_lock.values():
                for e in waiters:
                    e.put(messageType)
            return True
        elif messageType == protocol.LOCK_STOLEN:
            # Lock counters are keyed by the *str* names given to `lock`,
            # while the payload is bytes: decode before comparing, otherwise
            # the intersection is always empty and no greenlet gets killed.
            stolen_object_lock = {
                name.decode(errors="replace") for name in message.split(b"|")
            }
            greenlet_to_objects = self._lock_counters.copy()
            for greenlet, locked_objects in greenlet_to_objects.items():
                locked_object_name = {
                    name for name, nb_lock in locked_objects.items() if nb_lock > 0
                }
                if locked_object_name & stolen_object_lock:
                    try:
                        greenlet.kill(exception=StolenLockException)
                    except AttributeError:
                        # NOTE(review): presumably guards keys that are not
                        # greenlets — confirm
                        pass
            fd.sendall(protocol.message(protocol.LOCK_STOLEN_OK_REPLY, message))
            return True
        return False

726
    def _get_msg_key(self, message):
727
        pos = message.find(b"|")
728
        if pos < 0:
Sebastien Petitdemange's avatar
Sebastien Petitdemange committed
729
            return message, None
730
        return message[:pos], message[pos + 1 :]
731

732
733
734
735
    def _sendall(self, msg):
        with self._send_lock:
            self._socket.sendall(msg)

736
    def _raw_read(self):
        """Run the Beacon reader loop in the calling greenlet.

        Thin public-ish entry point for the name-mangled, ``protect_from_kill``
        decorated ``__raw_read``; returns None once that loop has finished.
        """
        reader_loop = self.__raw_read
        reader_loop()

    @protect_from_kill
    def __raw_read(self):
        """Listen to Beacon indefinitely (until killed or a socket error).

        Received bytes are accumulated until ``protocol.unpack_message`` can
        extract complete messages; each message is handed to
        ``_dispatch_beacon_message``. Kills are only accepted while blocked in
        ``recv`` (``AllowKill``). Whatever ends the loop, the Beacon and Redis
        connections are closed under ``_connect_lock``.
        """
        try:
            data = b""
            while True:
                # Only the blocking receive may be interrupted by a kill.
                with AllowKill():
                    raw_data = self._socket.recv(16 * 1024)
                if not raw_data:
                    break  # peer closed the connection
                data += raw_data

                while data:
                    try:
                        message_type, message, data = protocol.unpack_message(data)
                    except protocol.IncompleteMessage:
                        break  # wait for more bytes
                    try:
                        self._dispatch_beacon_message(message_type, message)
                    except BaseException:
                        # A faulty message must not stop the reader: report it
                        # and go on with the next one.
                        sys.excepthook(*sys.exc_info())
        except socket.error:
            pass
        except gevent.GreenletExit:
            pass
        except BaseException:
            sys.excepthook(*sys.exc_info())
        finally:
            with self._connect_lock:
                self._close_beacon_connection()
                self.close_all_redis_connections()

    def _dispatch_beacon_message(self, message_type, message):
        """Route one decoded Beacon message to whoever is waiting for it.

        Lock messages are handled by ``_lock_mgt``. Replies to requests are
        pushed onto the queue registered under their key in
        ``_message_queue``: the payload value on success, a ``RuntimeError``
        on failure, ``StopIteration`` at end of stream. Redis/UDS answers
        update the connection attributes and set the matching event.
        Message types not listed here are ignored.
        """
        if self._lock_mgt(self._socket, message_type, message):
            return

        if message_type in (
            protocol.CONFIG_GET_FILE_OK,
            protocol.CONFIG_GET_DB_TREE_OK,
            protocol.CONFIG_DB_FILE_RX,
            protocol.CONFIG_GET_PYTHON_MODULE_RX,
            protocol.CLIENT_NAME_OK,
            protocol.WHO_LOCKED_RX,
            protocol.LOG_SERVER_ADDRESS_OK,
            protocol.REDIS_DATA_SERVER_OK,
        ):
            message_key, value = self._get_msg_key(message)
            reply_queue = self._message_queue.get(message_key)
            if reply_queue is not None:
                reply_queue.put(value)
        elif message_type in (
            protocol.CONFIG_GET_FILE_FAILED,
            protocol.CONFIG_DB_FAILED,
            protocol.CONFIG_SET_DB_FILE_FAILED,
            protocol.CONFIG_GET_DB_TREE_FAILED,
            protocol.CONFIG_REMOVE_FILE_FAILED,
            protocol.CONFIG_MOVE_PATH_FAILED,
            protocol.CONFIG_GET_PYTHON_MODULE_FAILED,
            protocol.WHO_LOCKED_FAILED,
            protocol.LOG_SERVER_ADDRESS_FAIL,
            protocol.REDIS_DATA_SERVER_FAILED,
        ):
            message_key, value = self._get_msg_key(message)
            reply_queue = self._message_queue.get(message_key)
            if reply_queue is not None:
                # Never fail while building the error: a missing or
                # undecodable payload must still wake the waiting requester.
                error_text = (
                    value.decode(errors="replace") if value is not None else ""
                )
                reply_queue.put(RuntimeError(error_text))
        elif message_type in (
            protocol.CONFIG_DB_END,
            protocol.CONFIG_SET_DB_FILE_OK,
            protocol.CONFIG_REMOVE_FILE_OK,
            protocol.CONFIG_MOVE_PATH_OK,
            protocol.CONFIG_GET_PYTHON_MODULE_END,
            protocol.WHO_LOCKED_END,
        ):
            message_key, _ = self._get_msg_key(message)
            reply_queue = self._message_queue.get(message_key)
            if reply_queue is not None:
                reply_queue.put(StopIteration)
        elif message_type == protocol.REDIS_QUERY_ANSWER:
            host, port = message.split(b":", 1)
            self._redis_host = host.decode()
            self._redis_port = port.decode()
            self._redis_query_event.set()
        elif message_type == protocol.UDS_OK:
            try:
                uds_path = message.decode()
                sock = self._uds_connect(uds_path)
                # Switch the connection over to the Unix-domain socket; the
                # reader loop picks up the new ``self._socket`` on its next recv.
                self._socket.close()
                self._socket = sock
                self._port = uds_path
            finally:
                self._uds_query_event.set()
        elif message_type == protocol.UDS_FAILED:
            self._uds_query_event.set()
        elif message_type == protocol.UNKNOW_MESSAGE:
            message_key, value = self._get_msg_key(message)
            reply_queue = self._message_queue.get(message_key)
            if reply_queue is not None:
                reply_queue.put(
                    RuntimeError("Beacon server don't know this command (%s)" % value)
                )

    def _close_beacon_connection(self):
844
845
        """Result of `close` of a socket error (perhaps closed)
        """
846
847
848
849
850
851
852
853
854
        if self._socket:
            self._socket.close()
            self._socket = None
        self._connected.clear()
        self._clean_beacon_cache()

    def _clean_beacon_cache(self):
        """Clean all cached results from Beacon queries
        """
855
856
        self._redis_host = None
        self._redis_port = None
857
858
859
860
        self._redis_data_host = None
        self._redis_data_port = None
        self._log_server_host = None
        self._log_server_port = None
861
862
863

    @check_connect
    def __str__(self):
        """Return ``Connection(<host>:<port>)``.

        Note: decorated with ``check_connect``, so converting the object to a
        string goes through that decorator first.
        """
        return f"Connection({self._host}:{self._port})"