utils.py 32.8 KB
Newer Older
1
2
3
4
# -*- coding: utf-8 -*-
#
# This file is part of the bliss project
#
Benoit Formet's avatar
Benoit Formet committed
5
# Copyright (c) 2015-2020 Beamline Control Unit, ESRF
6
7
# Distributed under the GNU LGPLv3. See LICENSE for more info.

Laurent Claustre's avatar
Laurent Claustre committed
8
9
import os
import sys
10
import builtins
11
import inspect
12
import gevent
13
from gevent import threadpool
14
import types
15
import itertools
16
import functools
Sebastien Petitdemange's avatar
Sebastien Petitdemange committed
17
import numpy
18
import collections.abc
19
import importlib.util
20
import distutils.util
21
from collections.abc import MutableMapping, MutableSequence
22
import socket
Perceval Guillou's avatar
Perceval Guillou committed
23
import fnmatch
24
import contextlib
Perceval Guillou's avatar
Perceval Guillou committed
25

26
from itertools import zip_longest
27
from bliss.common.event import saferef
Perceval Guillou's avatar
Perceval Guillou committed
28

29
import typeguard
30

31

32
33
34
class ErrorWithTraceback:
    """Error placeholder whose text is shown instead of a value.

    ``str()`` of an instance gives ``error_txt`` (``"!ERR"`` by default).
    ``exc_info`` starts as ``None``; :func:`safe_get` fills it with
    ``sys.exc_info()`` when this object is used as its ``on_error`` value.
    """

    def __init__(self, error_txt="!ERR"):
        self.exc_info = None
        self._ERR = error_txt

    def __str__(self):
        return self._ERR


41
class WrappedMethod(object):
    """Callable forwarding every call to ``control.<method_name>``.

    Meant to be bound to another object with ``types.MethodType`` (see
    :func:`wrap_methods`); the instance it is bound to is received as
    ``this`` and ignored.
    """

    def __init__(self, control, method_name):
        self.control = control
        self.method_name = method_name

    def __call__(self, this, *args, **kwargs):
        # the method is looked up at call time, so rebinding it on
        # ``control`` is picked up by the wrapper
        target = getattr(self.control, self.method_name)
        return target(*args, **kwargs)
48

49

50
def wrap_methods(from_object, target_object):
    """Expose the bound methods of ``from_object`` on ``target_object``.

    For each name in ``dir(from_object)`` that is a bound method, a
    :class:`WrappedMethod` forwarding to ``from_object`` is bound to
    ``target_object`` under the same name. Names for which
    ``target_object`` already has a bound method are left untouched.
    """
    for attr_name in dir(from_object):
        if not inspect.ismethod(getattr(from_object, attr_name)):
            continue
        already_bound = hasattr(target_object, attr_name) and inspect.ismethod(
            getattr(target_object, attr_name)
        )
        if already_bound:
            continue
        forwarder = WrappedMethod(from_object, attr_name)
        setattr(target_object, attr_name, types.MethodType(forwarder, target_object))
62

63

64
def add_property(inst, name, method):
    """Add a read-only property ``name`` (computed by ``method``) to ``inst`` only.

    The first call swaps the class of ``inst`` for a private subclass with
    the same name and module, marked with a ``__perinstance`` flag. Later
    calls on the same instance reuse that subclass. The property is then
    set on the subclass, so other instances of the original class are not
    affected.
    """
    klass = type(inst)
    if not hasattr(klass, "__perinstance"):
        original_module = klass.__module__
        klass = type(klass.__name__, (klass,), {})
        klass.__perinstance = True
        klass.__module__ = original_module
        inst.__class__ = klass
    setattr(klass, name, property(method))
73
74
75


def grouped(iterable, n):
    """Group the elements of ``iterable`` into tuples of ``n`` items.

    Returns a lazy ``zip`` object::

        s -> (s0, s1, ..., sn-1), (sn, sn+1, ..., s2n-1), (s2n, ..., s3n-1), ...

    Trailing elements that cannot fill a complete group are dropped.

    Example::

        DEMO [5]: list(grouped([1, 2, 3, 4, 5], 2))
        Out  [5]: [(1, 2), (3, 4)]
    """
    shared_iterator = iter(iterable)
    # zip receives the same iterator n times, so each output tuple pulls
    # n consecutive items from it
    return zip(*(shared_iterator,) * n)
Sebastien Petitdemange's avatar
pep8    
Sebastien Petitdemange committed
86

87

88
def grouped_with_tail(iterable, n):
    """Like :func:`grouped`, but keep the trailing elements.

    Yields lists of ``n`` consecutive elements of ``iterable``; the last
    list is shorter when the number of elements is not a multiple of ``n``.
    Nothing is yielded for an empty iterable.

    Example::

        >>> list(grouped_with_tail([1, 2, 3, 4, 5], 2))
        [[1, 2], [3, 4], [5]]

    Raises:
        ValueError: if ``n`` is smaller than 1 (raised when iteration
            starts); such a group size cannot consume the input and would
            otherwise produce empty lists forever.
    """
    if n < 1:
        raise ValueError(f"group size must be >= 1, got {n}")
    iterator = iter(iterable)
    while True:
        chunk = list(itertools.islice(iterator, n))
        if not chunk:
            return
        yield chunk


106
107
108
109
110
111
def chunk_list(lst, n):
    """Yield consecutive slices of ``lst`` of length ``n`` (the last may be shorter).

    ``lst`` must support ``len()`` and slicing; each chunk has the type
    produced by slicing (list slices for lists, substrings for strings).
    """
    total = len(lst)
    for begin in range(0, total, n):
        end = begin + n
        yield lst[begin:end]


112
113
114
def flatten_gen(items):
    """Recursively yield the leaves of arbitrarily nested iterables.

    Strings and bytes are treated as leaves (they are not iterated character
    by character). Any other iterable, dicts included (their keys), is
    descended into. Nested iterables are consumed lazily.

    Example::

        >>> list(flatten_gen([1, [2, (3, "ab")], [[4]]]))
        [1, 2, 3, 'ab', 4]
    """
    for item in items:
        if isinstance(item, collections.abc.Iterable) and not isinstance(
            item, (str, bytes)
        ):
            # delegate to a nested generator instead of building an
            # intermediate list at every nesting level
            yield from flatten_gen(item)
        else:
            yield item


def flatten(items):
    """Return the leaves of the nested iterables ``items`` as a flat list.

    See :func:`flatten_gen` for what is considered a leaf.
    """
    return list(flatten_gen(items))
125
126


Benoit Formet's avatar
Benoit Formet committed
127
def merge(items):
    """Concatenate a list of lists, flattening only the first level.

    Example:

    .. code-block:: python

        merge([[1, 2], [3]])            # -> [1, 2, 3]
        merge([[1, 2], [[3, 4]], [5]])  # -> [1, 2, [3, 4], 5]
    """
    merged = []
    for sublist in items:
        merged.extend(sublist)
    return merged


140
141
142
def all_equal(iterable):
    """Return True if all elements of ``iterable`` compare equal (True when empty)."""
    groups = itertools.groupby(iterable)
    if next(groups, None) is None:
        # empty iterable
        return True
    # equal elements form a single group: there must be no second one
    return next(groups, None) is None
143

Sebastien Petitdemange's avatar
pep8    
Sebastien Petitdemange committed
144

Perceval Guillou's avatar
Perceval Guillou committed
145
def split_keys_to_tree(dico, separator):
    """Build a nested dict (a tree) from the keys of a flat dict.

    Every key of ``dico`` is split around ``separator`` into tags. Each tag
    becomes a node of the tree, and the last tag holds the value.

    Example:

    .. code-block:: python

        dico = {'A_B_C': 1, 'A_B_D': 2}
        result = split_keys_to_tree(dico, '_')
        assert result == {'A': {'B': {'C': 1, 'D': 2}}}

    Raises:
        ValueError: if a key has to descend through a tag that already holds
            a non-mapping value, e.g. ``{'A': 1, 'A_B': 2}``.
    """
    tree = {}
    for key, value in dico.items():
        # str.split always returns at least one element, so ``leaf`` exists
        *parents, leaf = key.split(separator)
        node = tree
        for tag in parents:
            node = node.setdefault(tag, {})
            if not isinstance(node, collections.abc.MutableMapping):
                raise ValueError(
                    f"cannot handle key {key}: '{tag}' already holds a non-dict value"
                )
        node[leaf] = value
    return tree


177
178
179
"""
functions to add custom attributes and commands to an object.
"""
Sebastien Petitdemange's avatar
pep8    
Sebastien Petitdemange committed
180
181


182
183
184
def add_object_method(
    obj, method, pre_call, name=None, args=[], types_info=(None, None)
):
    """Export the bound ``method`` as a custom method of ``obj``.

    The exported callable is registered through
    ``obj._add_custom_method(callable, name, types_info)``. ``name`` defaults
    to the name of ``method``'s underlying function.

    The exported callable is ``functools.partial(_forward, obj, *args)``. It
    receives ``method``'s metadata through ``functools.update_wrapper`` and
    is bound to ``obj`` with ``types.MethodType``. Calling it with
    ``(x, ...)`` therefore invokes ``method(*args, obj, x, ...)``: the target
    object is passed right after the extra ``args``. When ``pre_call`` is
    callable, it is first called as ``pre_call(obj, *args, obj, x, ...)``.
    ``args`` must be a list (it is concatenated to ``[obj]``); it is not
    modified.
    """
    exported_name = method.__func__.__name__ if name is None else name

    def _forward(self, *call_args, **call_kwargs):
        # optional hook run before every call, with the same arguments
        if callable(pre_call):
            pre_call(self, *call_args, **call_kwargs)
        return method.__func__(method.__self__, *call_args, **call_kwargs)

    leading_args = [obj] + args
    forwarder = functools.partial(_forward, *leading_args)
    functools.update_wrapper(forwarder, method)
    obj._add_custom_method(
        types.MethodType(forwarder, obj), exported_name, types_info
    )
202
203


204
205
206
def object_method(
    method=None, name=None, args=[], types_info=(None, None), filter=None
):
    """Decorator marking a controller method to be exported as an object method.

    Decorator counterpart of :func:`add_object_method`. It can be used bare
    (``@object_method``) or with keyword arguments
    (``@object_method(types_info=("float", "None"))``).

    The returned wrapper calls the decorated function unchanged. It carries
    an ``_object_method_`` dict (``name``, ``args``, ``types_info``,
    ``filter``) describing how to export it. Its ``__signature__`` omits the
    first parameter, which is the ``self`` of the instance the method will
    be attached to.
    """

    def _make_wrapper(func):
        @functools.wraps(func)
        def wrapper(*call_args, **call_kwargs):
            return func(*call_args, **call_kwargs)

        signature = inspect.signature(func)
        remaining = list(signature.parameters.values())[1:]
        wrapper.__signature__ = signature.replace(parameters=remaining)
        wrapper._object_method_ = {
            "name": name,
            "args": args,
            "types_info": types_info,
            "filter": filter,
        }
        return wrapper

    if method is not None:
        # bare usage: @object_method
        return _make_wrapper(method)
    # usage with arguments: @object_method(...)
    return _make_wrapper
242
243


244
245
246
247
248
def object_method_type(
    method=None, name=None, args=[], types_info=(None, None), type=None
):
    """Like :func:`object_method`, but only export to instances of ``type``.

    The ``filter`` stored in the metadata accepts a target object only when
    ``isinstance(target, type)`` is true.
    """
    accepted_type = type

    def _is_accepted(target):
        return isinstance(target, accepted_type)

    return object_method(
        method=method,
        name=name,
        args=args,
        types_info=types_info,
        filter=_is_accepted,
    )
253

254
255
256
257

def add_object_attribute(
    obj, name=None, fget=None, fset=None, args=[], type_info=None, filter=None
):
    """Register a custom attribute ``name`` on ``obj``.

    Delegates to ``obj._add_custom_attribute(name, fget, fset, type_info)``.
    ``args`` and ``filter`` are accepted so that the ``_object_attribute_``
    metadata dicts can be passed as keyword arguments; they are not used
    here.
    """
    register = obj._add_custom_attribute
    register(name, fget, fset, type_info)
259

Sebastien Petitdemange's avatar
pep8    
Sebastien Petitdemange committed
260

261
262
263
"""
decorators for set/get methods to access to custom attributes
"""
264

Sebastien Petitdemange's avatar
pep8    
Sebastien Petitdemange committed
265

266
267
268
269
270
271
272
273
274
def object_attribute_type_get(
    get_method=None, name=None, args=[], type_info=None, type=None
):
    """Like :func:`object_attribute_get`, but only for instances of ``type``.

    The ``filter`` stored in the metadata accepts a target object only when
    ``isinstance(target, type)`` is true.
    """
    target_type = type

    def _matches(target):
        return isinstance(target, target_type)

    return object_attribute_get(
        get_method=get_method,
        name=name,
        args=args,
        type_info=type_info,
        filter=_matches,
    )
275

Sebastien Petitdemange's avatar
pep8    
Sebastien Petitdemange committed
276

277
278
279
def object_attribute_get(
    get_method=None, name=None, args=[], type_info=None, filter=None
):
    """Decorator declaring the getter of a custom attribute.

    It can be used bare (``@object_attribute_get``) or with keyword
    arguments (``@object_attribute_get(type_info="float")``). The attribute
    name is ``name`` (default: the method name) without a leading ``"get_"``.

    The decorated method itself is returned, with two metadata dicts set:

    * ``_object_method_``: the getter is also exported as an object method
      taking nothing (``"None"``) and returning ``type_info``;
    * ``_object_attribute_``: created if missing, then updated with the
      attribute name, ``fget``, ``args``, ``type_info`` and ``filter``.
      Keys already present, such as ``fset`` set by
      :func:`object_attribute_set`, are kept.
    """
    if get_method is None:
        # called with arguments only: return the actual decorator
        return functools.partial(
            object_attribute_get,
            name=name,
            args=args,
            type_info=type_info,
            filter=filter,
        )

    method_name = get_method.__name__ if name is None else name
    if method_name.startswith("get_"):
        attr_name = method_name[4:]
    else:
        attr_name = method_name

    get_method._object_method_ = {
        "name": method_name,
        "args": args,
        "types_info": ("None", type_info),
        "filter": filter,
    }

    try:
        attribute_info = get_method._object_attribute_
    except AttributeError:
        attribute_info = get_method._object_attribute_ = {}
    attribute_info.update(
        name=attr_name, fget=get_method, args=args, type_info=type_info, filter=filter
    )
    return get_method

307

308
309
310
311
312
def object_attribute_type_set(
    set_method=None, name=None, args=[], type_info=None, type=None
):
    """Like :func:`object_attribute_set`, but only for instances of ``type``.

    The ``filter`` stored in the metadata accepts a target object only when
    ``isinstance(target, type)`` is true.
    """
    target_type = type

    def _matches(target):
        return isinstance(target, target_type)

    return object_attribute_set(
        set_method=set_method,
        name=name,
        args=args,
        type_info=type_info,
        filter=_matches,
    )
Sebastien Petitdemange's avatar
pep8    
Sebastien Petitdemange committed
317

318
319
320
321

def object_attribute_set(
    set_method=None, name=None, args=[], type_info=None, filter=None
):
    """Decorator declaring the setter of a custom attribute.

    It can be used bare (``@object_attribute_set``) or with keyword
    arguments (``@object_attribute_set(type_info="float")``). The attribute
    name is ``name`` (default: the method name) without a leading ``"set_"``.

    The decorated method itself is returned, with two metadata dicts set:

    * ``_object_method_``: the setter is also exported as an object method
      taking ``type_info`` and returning nothing (``"None"``);
    * ``_object_attribute_``: created if missing, then updated with the
      attribute name, ``fset``, ``args``, ``type_info`` and ``filter``.
      Keys already present, such as ``fget`` set by
      :func:`object_attribute_get`, are kept.
    """
    if set_method is None:
        # called with arguments only: return the actual decorator
        return functools.partial(
            object_attribute_set,
            name=name,
            args=args,
            type_info=type_info,
            filter=filter,
        )

    method_name = set_method.__name__ if name is None else name
    if method_name.startswith("set_"):
        attr_name = method_name[4:]
    else:
        attr_name = method_name

    set_method._object_method_ = {
        "name": method_name,
        "args": args,
        "types_info": (type_info, "None"),
        "filter": filter,
    }

    try:
        attribute_info = set_method._object_attribute_
    except AttributeError:
        attribute_info = set_method._object_attribute_ = {}
    attribute_info.update(
        name=attr_name, fset=set_method, args=args, type_info=type_info, filter=filter
    )
    return set_method


def set_custom_members(src_obj, target_obj, pre_call=None):
    """Create custom methods and attributes on ``target_obj`` from ``src_obj``.

    Every method of ``src_obj`` carrying ``_object_attribute_`` and/or
    ``_object_method_`` metadata is exported to ``target_obj``. Such metadata
    is set by the ``object_method*`` / ``object_attribute_*`` decorators.
    Attributes go through :func:`add_object_attribute` and methods through
    :func:`add_object_method`. An entry is skipped when its ``filter`` is
    set and returns a false value for ``target_obj``. This populates the
    custom methods list and custom attributes dict of ``target_obj``, which
    the Tango device server uses.

    Args:
        src_obj: object whose class defines the decorated methods.
        target_obj: object receiving the members; it must provide
            ``_add_custom_method`` and ``_add_custom_attribute``.
        pre_call: optional callable run before each exported method call.
    """
    for name, _ in inspect.getmembers(src_obj.__class__, inspect.isfunction):
        # Only plain functions of the class are listed, so properties are
        # never executed here; getattr then gives the bound method, and
        # anything that is not a method (e.g. a staticmethod) is ignored.
        member = getattr(src_obj, name)
        if not inspect.ismethod(member):
            continue

        attribute_info = getattr(member, "_object_attribute_", None)
        if attribute_info is not None:
            # work on a copy: the metadata dict is shared by all instances
            attribute_info = dict(attribute_info)
            attr_filter = attribute_info.pop("filter", None)
            if attr_filter is None or attr_filter(target_obj):
                add_object_attribute(target_obj, **attribute_info)

        method_info = getattr(member, "_object_method_", None)
        if method_info is not None:
            method_info = dict(method_info)
            method_filter = method_info.pop("filter", None)
            if method_filter is None or method_filter(target_obj):
                add_object_method(target_obj, member, pre_call, **method_info)
381

382

383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
def with_custom_members(klass):
    """Class decorator enabling custom attributes and custom methods.

    Adds to ``klass``:

    * ``_add_custom_method(method, name, types_info)``: sets ``method`` as the
      instance attribute ``name`` and records ``(name, types_info)``;
    * ``custom_methods_list`` (property): a copy of the recorded
      ``(name, types_info)`` pairs;
    * ``_add_custom_attribute(name, fget, fset, type_info)``: records the
      attribute type and access mode (``"r"``, ``"w"`` or ``"rw"``);
    * ``custom_attributes_list`` (property): a list of
      ``(name, type_info, access_mode)`` tuples.

    The bookkeeping containers are created lazily on each instance.
    Returns ``klass`` itself.
    """

    def _get_custom_methods(self):
        try:
            return self.__custom_methods_list
        except AttributeError:
            self.__custom_methods_list = []
            return self.__custom_methods_list

    def custom_methods_list(self):
        """ Returns a *copy* of the custom methods """
        return self._get_custom_methods()[:]

    def _add_custom_method(self, method, name, types_info=(None, None)):
        setattr(self, name, method)
        self._get_custom_methods().append((name, types_info))

    def _get_custom_attributes(self):
        try:
            return self.__custom_attributes_dict
        except AttributeError:
            self.__custom_attributes_dict = {}
            return self.__custom_attributes_dict

    def custom_attributes_list(self):
        """
        List of custom attributes defined for this axis.
        Internal usage only
        """
        ad = self._get_custom_attributes()
        # convert the {name: (type_info, access_mode)} dict into a list
        return [(a_name, a_type, a_mode) for a_name, (a_type, a_mode) in ad.items()]

    def _add_custom_attribute(self, name, fget=None, fset=None, type_info=None):
        custom_attrs = self._get_custom_attributes()
        attr_info = custom_attrs.get(name)
        if attr_info:
            # attribute already known (e.g. getter registered, now the
            # setter): widen the access mode and check both sides agree
            orig_type_info, access_mode = attr_info
            if fget and "r" not in access_mode:
                access_mode = "rw"
            if fset and "w" not in access_mode:
                access_mode = "rw"
            if not type_info == orig_type_info:
                # explicit raise instead of ``assert``, so that the check is
                # not stripped when running with ``python -O``
                raise AssertionError("%s get/set types mismatch" % name)
        else:
            access_mode = "r" if fget else ""
            access_mode += "w" if fset else ""
            if fget is None and fset is None:
                raise RuntimeError("impossible case: must have fget or fset...")
        custom_attrs[name] = type_info, access_mode

    klass._get_custom_methods = _get_custom_methods
    klass.custom_methods_list = property(custom_methods_list)
    klass._add_custom_method = _add_custom_method
    klass._get_custom_attributes = _get_custom_attributes
    klass.custom_attributes_list = property(custom_attributes_list)
    klass._add_custom_attribute = _add_custom_attribute

    return klass


445
446
447
class Null(object):
    """Callable accepting any arguments that does nothing and returns None.

    With empty ``__slots__`` its instances have no attributes, so attribute
    assignment raises ``AttributeError``.
    """

    __slots__ = []

    def __call__(self, *args, **kwargs):
        return None

451

452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
class StripIt(object):
    """
    Encapsulate object with a short str/repr/format.
    Useful to have in log messages since it only computes the representation
    if the log message is recorded. Example::

        >>> import logging
        >>> logging.basicConfig(level=logging.DEBUG)

        >>> from bliss.common.utils import StripIt

        >>> msg_from_socket = 'Here it is my testament: ' + 50*'bla '
        >>> logging.debug('Received: %s', StripIt(msg_from_socket))
        DEBUG:root:Received: Here it is my testament: bla bla bla bla bla [...]

    Representations longer than ``max_len`` are cut so that, with the
    ``" [...]"`` marker appended, they are exactly ``max_len`` characters
    long. If ``max_len`` is too small to hold the marker, the text is only
    truncated to ``max_len`` characters.
    """

    __slots__ = "obj", "max_len"

    def __init__(self, obj, max_len=50):
        self.obj = obj
        self.max_len = max_len

    def __strip(self, s):
        max_len = self.max_len
        if len(s) <= max_len:
            return s
        suffix = " [...]"
        if max_len < len(suffix):
            # no room for the marker: a negative slice bound would keep
            # almost the whole string, so truncate hard instead
            return s[: max(max_len, 0)]
        return s[: max_len - len(suffix)] + suffix

    def __str__(self):
        return self.__strip(str(self.obj))

    def __repr__(self):
        return self.__strip(repr(self.obj))

    def __format__(self, format_spec):
        return self.__strip(format(self.obj, format_spec))
489

490

491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
class periodic_exec(object):
    """Context manager running ``func`` periodically in a gevent greenlet.

    On ``__enter__`` a greenlet is spawned. It calls ``func`` right away and
    then every ``period_in_s`` seconds, until it is killed on ``__exit__``.
    Nothing is spawned when ``func`` is not callable or when the period is
    not strictly positive.

    ``func`` is only held through ``saferef.safe_ref``, presumably a weak
    reference. The loop stops by itself once dereferencing it returns
    ``None``.
    """

    def __init__(self, period_in_s, func):
        self.func_ref = saferef.safe_ref(func) if callable(func) else None
        self.period = period_in_s
        self.__task = None

    def __enter__(self):
        must_run = self.period > 0 and self.func_ref
        if must_run:
            self.__task = gevent.spawn(self._timer)

    def __exit__(self, *args):
        task = self.__task
        if task is not None:
            task.kill()

    def _timer(self):
        while True:
            callback = self.func_ref()
            if callback is None:
                return
            callback()
            # drop the strong reference before sleeping, so that the
            # target can be garbage collected between two calls
            del callback
            gevent.sleep(self.period)

518

519
520
def safe_get(obj, member, on_error=None, **kwargs):
    """Read a property or call a method of ``obj`` without raising.

    If ``member`` is a property of ``obj``'s class, its value is returned.
    Otherwise ``getattr(obj, member)(**kwargs)`` is called and its result is
    returned.

    Any ``Exception`` makes the function return ``on_error`` (whatever its
    truthiness). This includes ``member`` not being an attribute of the
    class: instance-only attributes are not supported. When ``on_error`` is
    an :class:`ErrorWithTraceback`, its ``exc_info`` is first set to
    ``sys.exc_info()``.
    """
    try:
        if isinstance(getattr(obj.__class__, member), property):
            return getattr(obj, member)
        return getattr(obj, member)(**kwargs)
    except Exception:
        if isinstance(on_error, ErrorWithTraceback):
            on_error.exc_info = sys.exc_info()
        # returned even when falsy (0, "", False, ...), so callers get the
        # default they asked for instead of None
        return on_error

531

532
533
def common_prefix(paths, sep=os.path.sep):
    """Return the longest leading path shared by all ``paths``.

    Paths are compared component by component after splitting on ``sep``,
    so ``/a/bc`` and ``/a/bd`` share ``/a`` (not ``/a/b``). Returns an empty
    string when nothing is shared or when ``paths`` is empty.
    """
    shared_parts = []
    for parts_at_level in zip(*(path.split(sep) for path in paths)):
        reference = parts_at_level[0]
        if any(part != reference for part in parts_at_level[1:]):
            break
        shared_parts.append(reference)
    return sep.join(shared_parts)
538

539

540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
class autocomplete_property(property):
    """Property that the BLISS shell evaluates for autocompletion.

    The shell adds this class to jedi's allowed descriptors::

        from jedi.evaluate.compiled import access
        access.ALLOWED_DESCRIPTOR_ACCESS += (autocomplete_property,)

    so that::

        @property               --> not evaluated for autocompletion
        @autocomplete_property  --> evaluated for autocompletion

    The ``@autocomplete_property`` decorator is especially useful for
    counter namespaces or similar objects.
    """
557
558


Linus Pithan's avatar
Linus Pithan committed
559
560
561
562
563
564
565
566
567
568
569
570
# The following code around `UserNamespace` is about a namespace that
# has autocomplete_property properties itself. It provides signature
# completion in the bliss shell for its members too.
# More details in the bliss documentation.
#
# BLISS [1]: from bliss.common.utils import UserNamespace
# BLISS [2]: def a(self,kwarg=13):
#       ...:     print(a)
# BLISS [4]: c=UserNamespace({"a":a})
# BLISS [5]: c.a(
#               a(self, kwarg=13)   # signature suggestion

571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
# Create a private copy of the `collections` module. Its `namedtuple` can
# then be patched (below) without touching the real `collections.namedtuple`
# used by the rest of the program.
__SPEC = importlib.util.find_spec("collections")
mycollections = importlib.util.module_from_spec(__SPEC)
__SPEC.loader.exec_module(mycollections)
# register the copy so that it can be imported by name just below
sys.modules["mycollections"] = mycollections

from mycollections import namedtuple as UserNamedtuple  # noqa E402

# Patch `property` in the globals of the copied module, so that namedtuple
# classes built by UserNamedtuple use autocomplete_property. This triggers
# jedi signature hints in the shell.
# NOTE(review): this only has an effect if namedtuple creates its field
# accessors through the global name `property`; recent CPython versions use
# `_collections._tuplegetter` when available. Confirm on the Python version
# in use.
UserNamedtuple.__globals__["property"] = autocomplete_property


def UserNamespace(env_dict=None):
    """Return a read-only namespace exposing the items of ``env_dict`` as attributes.

    The namespace is an instance of a class built with ``UserNamedtuple``
    (this module's patched copy of ``namedtuple``), with one field per key
    of ``env_dict``. Its ``dir()`` hides the tuple methods ``count`` and
    ``index`` unless they are keys of ``env_dict``, so that they do not show
    up in jedi completion.

    Args:
        env_dict: mapping of member names to values (default: no members).
    """
    if env_dict is None:
        env_dict = {}
    klass = UserNamedtuple("namespace", env_dict, module=__name__ + ".namespace")

    def namespace_dir(self):
        # super(klass, ...) rather than super(self.__class__, ...): the latter
        # recurses forever if the namespace class is ever subclassed
        listing = super(klass, self).__dir__()
        hidden = [m for m in ("count", "index") if m not in env_dict]
        return [entry for entry in listing if entry not in hidden]

    # hide the "count" & "index" built-in tuple functions from jedi completion
    klass.__dir__ = namespace_dir
    return klass(**env_dict)


Linus Pithan's avatar
Linus Pithan committed
601
602
def deep_update(d, u):
    """Deep-merge dict ``u`` into dict ``d``, in place.

    Keys of ``d`` that are absent from ``u`` are kept, at any depth. When
    both ``d[k]`` and ``u[k]`` are mappings, they are merged recursively.
    Otherwise ``d[k]`` is simply set to ``u[k]``.

    Args:
        d: destination dict, modified in place.
        u: source dict, only read by this function.

    Note:
        values taken from ``u`` are inserted as-is (not copied). For example,
        when ``d[k]`` exists but is not a mapping, ``d[k]`` becomes the very
        object ``u[k]``.

    Returns:
        None
    """
    # breadth-first traversal; deque.popleft() is O(1), unlike list.pop(0)
    pending = collections.deque([(d, u)])
    while pending:
        dest, src = pending.popleft()
        for key, value in src.items():
            if not isinstance(value, collections.abc.Mapping):
                # nothing to merge: overwrite, even if dest[key] was a dict
                dest[key] = value
                continue
            # value is a mapping: get dest[key], creating an empty dict if missing
            existing = dest.setdefault(key, {})
            if isinstance(existing, collections.abc.Mapping):
                # both sides are mappings: merge them later
                pending.append((existing, value))
            else:
                # dest[key] is not a mapping: replace it
                dest[key] = value
638
639


640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
def is_basictype(val):
    """Return True for scalar leaf values: None, int (bool included), str or float."""
    return val is None or isinstance(val, (int, str, float))


def is_complextype(val):
    """Return True for mutable containers: MutableMapping or MutableSequence."""
    return is_mutmapping(val) or is_mutsequence(val)


def is_mutsequence(val):
    """Return True if ``val`` is a MutableSequence (e.g. a list)."""
    return isinstance(val, MutableSequence)


def is_mutmapping(val):
    """Return True if ``val`` is a MutableMapping (e.g. a dict)."""
    return isinstance(val, MutableMapping)


def is_sametype(val1, val2):
    """Return whether ``val1`` and ``val2`` are of the same kind, as a bool.

    Same kind means two basic values of exactly the same type, two mutable
    mappings, or two mutable sequences.
    """
    if is_basictype(val1) and is_basictype(val2):
        return type(val1) is type(val2)
    if is_mutmapping(val1) and is_mutmapping(val2):
        return True
    return is_mutsequence(val1) and is_mutsequence(val2)


# Placeholder used by prudent_update (as the zip_longest fill value) for
# sequence elements present on one side only. Being a plain string, a real
# value equal to it is treated as missing by prudent_update.
MISSING = "---missing---"


def prudent_update(d, u):
    """Update a MutableMapping or MutableSequence ``d`` from another one ``u``.

    The update tries to minimize changes: when possible only the leaves of
    the tree are updated, to preserve the original object as much as
    possible.

    Rules, as implemented below:

    * two basic values (see ``is_basictype``): ``d`` is kept when equal to
      ``u``; otherwise ``u`` wins, except that the ``MISSING`` placeholder
      never replaces a real value;
    * two mappings: ``d`` is updated in place key by key, recursively for
      keys present in both, and returned;
    * two sequences: if ``u`` is shorter than ``d``, ``u`` is returned
      (issue 2348); otherwise ``d`` is updated in place element by element,
      extra elements of ``u`` are appended, and ``d`` is returned;
    * containers of different kinds, or a basic value against a container:
      ``u`` is returned.

    Args:
        d: destination value (may be modified in place).
        u: source value.

    Returns:
        the updated value. Callers must use the return value, since it is
        not always ``d`` itself.

    Raises:
        NotImplementedError: for values that are neither basic types nor
            mutable containers (e.g. tuples, sets).
    """
    if is_basictype(d) and is_basictype(u):
        if d != u:
            if d == MISSING:
                return u
            elif u == MISSING:
                # never replace an existing value by the placeholder
                return d
            return u
        else:
            return d  # prefer not to update
    elif is_complextype(d) and is_complextype(u):
        if is_sametype(d, u):
            # same type
            if is_mutmapping(d):
                for k, v in u.items():
                    if k in d:
                        d[k] = prudent_update(d[k], v)
                    else:
                        d[k] = v
            elif is_mutsequence(d):
                if len(u) < len(d):
                    # issue 2348: if updated element is smaller than existing one, updated replaces existing
                    d = u
                else:
                    # here len(u) >= len(d): only el1 can be the fill value
                    for num, (el1, el2) in enumerate(
                        zip_longest(d, u, fillvalue=MISSING)
                    ):
                        if el2 == MISSING:
                            # Nothing to do (only reachable if u itself
                            # contains the MISSING string)
                            pass
                        else:
                            # missing el1 is managed by prudent_update
                            # when el1==MISSING el2!=MISSING -> el2 returned
                            value = prudent_update(el1, el2)
                            try:
                                d[num] = value
                            except IndexError:
                                # past the end of d: extra element of u
                                d.append(value)
            else:
                raise NotImplementedError
            return d
        else:
            # not same type so the destination will be replaced
            return u
    elif is_basictype(d) and is_complextype(u):
        return u
    elif is_complextype(d) and is_basictype(u):
        return u
    else:
        raise NotImplementedError


Linus Pithan's avatar
Linus Pithan committed
726
727
728
729
730
731
732
733
734
735
736
737
738
def update_node_info(node, d):
    """Write the items of dict ``d`` into ``node.info``.

    When the new value and the current ``node.info[key]`` are both of type
    ``dict``, and the current one is non-empty, the new value is deep-merged
    into the current one with :func:`deep_update` and the result written
    back. Otherwise the new value simply replaces the entry.

    Args:
        node: DataNode or DataNodeContainer (only its ``info`` mapping is used).
        d: dict of updates (asserted to be exactly a ``dict``).
    """
    assert type(d) == dict
    for key, new_value in d.items():
        old_value = node.info.get(key)
        if not (old_value and type(new_value) == dict and type(old_value) == dict):
            node.info[key] = new_value
            continue
        # both are plain dicts: merge instead of overwriting
        deep_update(old_value, new_value)
        node.info[key] = old_value


739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
def rounder(template_number, number):
    """Round a number according to a template number; return a string.

    ``number`` is rounded to as many decimals as ``template_number`` has
    digits after its decimal point (0 when it is an integral value). It is
    formatted with ``numpy.format_float_positional``, trimming trailing
    zeros and a trailing decimal point.

    assert rounder(0.0001, 16.12345) == "16.1234"
    assert rounder(1, 16.123) == "16"
    assert rounder(0.1, 8.5) == "8.5"
    assert rounder(2.5e-05, 1.123456789) == "1.123457"
    """
    if float(template_number).is_integer():
        precision = 0
    else:
        text = str(template_number)
        if "e" in text.lower():
            # small templates are printed in exponent notation
            # (str(2.5e-05) == "2.5e-05"), which would make the digit count
            # wrong: switch to positional notation ("0.000025")
            text = numpy.format_float_positional(float(text))
        precision = len(text.split(".")[-1])
    return numpy.format_float_positional(
        number, precision=precision, unique=False, trim="-"
    )
754
755


Linus Pithan's avatar
Linus Pithan committed
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
def round(a, decimals=None, out=None, precision=None):
    """
    Like ``numpy.round``, with an extended signature that can also take
    ``precision``: a template number giving the smallest significant
    increment.

    With ``precision``, ``a`` is rounded to ``ceil(log10(1 / precision))``
    decimals; this may be negative for large precisions (e.g. 10 rounds to
    tens). ``decimals`` takes priority over ``precision``. With neither,
    ``a`` is rounded to 0 decimals. ``out`` is forwarded to ``numpy.round``
    in every case.

    assert round(16.123, precision=.2) == 16.1
    assert round(16.123, precision=1) == 16
    assert round(16.123, precision=0.0001) == 16.123
    """
    if decimals is None:
        if precision is not None:
            decimals = int(numpy.ceil(numpy.log10(1 / precision)))
        else:
            decimals = 0
    return numpy.round(a, decimals=decimals, out=out)


775
776
777
778
779
class ShellStr(str):
    """``str`` subclass displayed as plain text in the BLISS shell.

    Its ``__info__`` (the shell representation hook) returns ``str(self)``.
    """

    def __info__(self):
        text = str(self)
        return text
780
781
782
783
784
785
786
787
788
789
790


def get_open_ports(n):
    """Return ``n`` TCP port numbers that were free at the time of the call.

    Each port is obtained by binding a fresh socket to port 0 on all
    interfaces and reading back the port chosen by the OS. The sockets are
    closed before returning, so another process may take a port before the
    caller uses it.
    """
    with contextlib.ExitStack() as stack:
        # each socket is registered as soon as it is created, so all sockets
        # opened so far are closed even if creating or binding one fails
        sockets = [stack.enter_context(socket.socket()) for _ in range(n)]
        for sock in sockets:
            sock.bind(("", 0))
        return [sock.getsockname()[1] for sock in sockets]
Laurent Claustre's avatar
Laurent Claustre committed
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843


class ColorTags:
    """ANSI escape sequences used to color or emphasize terminal text."""

    PURPLE = "\033[95m"
    CYAN = "\033[96m"
    DARKCYAN = "\033[36m"
    BLUE = "\033[94m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    RED = "\033[91m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
    END = "\033[0m"  # resets all attributes


def __color_message(tag, msg):
    """Return ``msg`` wrapped between the escape sequence ``tag`` and the reset sequence."""
    return f"{tag}{msg}{ColorTags.END}"


def PURPLE(msg):
    """Return ``msg`` in purple."""
    return __color_message(ColorTags.PURPLE, msg)


def CYAN(msg):
    """Return ``msg`` in cyan."""
    return __color_message(ColorTags.CYAN, msg)


def DARKCYAN(msg):
    """Return ``msg`` in dark cyan."""
    return __color_message(ColorTags.DARKCYAN, msg)


def BLUE(msg):
    """Return ``msg`` in blue."""
    return __color_message(ColorTags.BLUE, msg)


def GREEN(msg):
    """Return ``msg`` in green."""
    return __color_message(ColorTags.GREEN, msg)


def YELLOW(msg):
    """Return ``msg`` in yellow."""
    return __color_message(ColorTags.YELLOW, msg)


def RED(msg):
    """Return ``msg`` in red."""
    return __color_message(ColorTags.RED, msg)


def UNDERLINE(msg):
    """Return ``msg`` underlined."""
    return __color_message(ColorTags.UNDERLINE, msg)


def BOLD(msg):
    """Return ``msg`` in bold."""
    return __color_message(ColorTags.BOLD, msg)
844
845
846
847
848


def shorten_signature(original_function=None, *, annotations=None, hidden_kwargs=None):
    """Decorator simplifying the signature displayed in the BLISS shell.

    By default the annotation of every parameter is removed. ``annotations``
    gives replacement annotations for some parameters. Parameters listed in
    ``hidden_kwargs`` are removed from the displayed signature but remain
    usable. It can be used bare (``@shorten_signature``) or with keyword
    arguments (``@shorten_signature(annotations=..., hidden_kwargs=...)``).

    Args:
        annotations: dict with parameter names as keys and the annotation to
            display as values.
        hidden_kwargs: list of parameters that should not be displayed but
            remain usable.
    """

    def _decorate(function):
        @functools.wraps(function)
        def wrapped_function(*args, **kwargs):
            return function(*args, **kwargs)

        signature = inspect.signature(function)
        shown = []
        for param in signature.parameters.values():
            if hidden_kwargs and param.name in hidden_kwargs:
                continue
            if annotations and param.name in annotations.keys():
                shown.append(param.replace(annotation=annotations[param.name]))
            else:
                shown.append(param.replace(annotation=inspect.Parameter.empty))
        wrapped_function.__signature__ = signature.replace(parameters=shown)
        return wrapped_function

    return _decorate(original_function) if original_function else _decorate
881
882
883


def custom_error_msg(
    exception_type, message, new_exception_type=None, display_original_msg=False
):
    """Decorator to modify an exception and/or the corresponding message.

    Arguments:
        exception_type: exception class intercepted when raised by the
                        decorated function
        message: message of the exception raised instead
        new_exception_type: class of the exception raised instead; defaults
                            to ``exception_type``
        display_original_msg: if True, the original message is appended to
                              ``message``, separated by a space

    The replacement exception is chained to the original one (``raise ...
    from``), so the original traceback stays available. Exceptions that are
    not instances of ``exception_type`` propagate untouched.
    """
    replacement_type = new_exception_type or exception_type

    def _decorate(function):
        @functools.wraps(function)
        def wrapped_function(*args, **kwargs):
            try:
                return function(*args, **kwargs)
            except exception_type as e:
                # Catch exactly what was asked for (this also covers classes
                # deriving only from BaseException); anything else propagates.
                if display_original_msg:
                    raise replacement_type(f"{message} {e}") from e
                raise replacement_type(message) from e

        return wrapped_function

    return _decorate


class TypeguardTypeError(TypeError):
    """TypeError subclass dedicated to errors coming from the typeguard module.

    Having a distinct class lets helpers of this module (e.g. the
    ``custom_error_msg``-based hint decorator) target typeguard failures
    without catching every other TypeError.
    NOTE(review): this should eventually be pushed to the typeguard repository.
    """


# Monkeypatch: bind TypeguardTypeError as the module attribute
# `typeguard.TypeError`. Module globals shadow builtins during name lookup,
# so code inside typeguard that refers to the global name `TypeError`
# would resolve to this subclass.
# NOTE(review): presumably this is what makes typeguard's own errors
# distinguishable from other TypeErrors; confirm against the installed
# typeguard version.
typeguard.TypeError = TypeguardTypeError


def typeguardTypeError_to_hint(function):
    """Decorator that transforms a TypeguardTypeError into a simplified RuntimeError.

    Intended use: modifying the message when using @typeguard.typechecked.

    The RuntimeError message reads
    ``"Intended Usage: <name>(<params>)  Hint: <original message>"``, where
    ``<params>`` lists the parameters of ``function`` that have no default
    value. The original error is chained as ``__cause__``.
    """
    # The hint only depends on the signature, so build it (and the wrapping
    # decorator) once at decoration time instead of on every call.
    params = inspect.signature(function).parameters.values()
    # `is` rather than `==`: comparing arbitrary defaults (e.g. numpy arrays)
    # with `==` may not return a plain bool. Note that *args / **kwargs have
    # no default either, so they are listed too.
    required = ", ".join(
        param.name for param in params if param.default is inspect.Parameter.empty
    )
    msg = f"Intended Usage: {function.__name__}({required})  Hint:"
    return custom_error_msg(
        TypeguardTypeError,
        msg,
        new_exception_type=RuntimeError,
        display_original_msg=True,
    )(function)


def typecheck_var_args_pattern(args_pattern, empty_var_pos_args_allowed=False):
    """Decorator that can be used for typechecking of `*args` that have to
    follow a certain pattern e.g.

    .. code-block::

        @typecheck_var_args_pattern([Scannable,_float])
        def umv(*args):
            ...

    Each value received through ``*args`` is checked with
    ``typeguard.check_type`` against the pattern entry at the same position
    modulo ``len(args_pattern)``. A TypeguardTypeError is raised when no such
    value is given (unless ``empty_var_pos_args_allowed`` is True) or when
    their count is not a multiple of ``len(args_pattern)``. Functions without
    a ``*args`` parameter are called without any check.
    """

    def decorate(function):
        params = list(inspect.signature(function).parameters.values())
        # Locate the `*args` parameter once, at decoration time, instead of
        # re-inspecting the signature on every call (None: no `*args`).
        var_pos_index = next(
            (
                index
                for index, param in enumerate(params)
                if param.kind == inspect.Parameter.VAR_POSITIONAL
            ),
            None,
        )

        @functools.wraps(function)
        def wrapped_function(*args, **kwargs):
            if var_pos_index is not None:
                var_pos_name = params[var_pos_index].name
                # positional values beyond the named parameters land in *args
                var_args = args[var_pos_index:]
                if not empty_var_pos_args_allowed and len(var_args) == 0:
                    raise TypeguardTypeError(
                        f"Arguments of type {args_pattern} missing!"
                    )
                if len(var_args) % len(args_pattern) != 0:
                    raise TypeguardTypeError(
                        f"Wrong number of arguments (not a multiple of {len(args_pattern)} [{args_pattern}])"
                    )
                for j, value in enumerate(var_args):
                    typeguard.check_type(
                        f"{var_pos_name}[{j}]",
                        value,
                        args_pattern[j % len(args_pattern)],
                    )
            return function(*args, **kwargs)

        return wrapped_function

    return decorate


989
def modify_annotations(annotations):
Valentin Valls's avatar
Valentin Valls committed
990
991
992
993
994
995
996
    """Modify the annotation in an existing signature.

    .. code-block: python

        @modify_annotations({"args": "motor1, rel. pos1, motor2, rel. pos2, ..."})
        def umvr(*args):
            ...
997
998
999
1000
    """

    def decorate(function):
        def wrapped_function(*args, **kwargs):