session.py 34.5 KB
Newer Older
1
2
3
4
# -*- coding: utf-8 -*-
#
# This file is part of the bliss project
#
Benoit Formet's avatar
Benoit Formet committed
5
# Copyright (c) 2015-2020 Beamline Control Unit, ESRF
6
7
8
9
# Distributed under the GNU LGPLv3. See LICENSE for more info.

import os
import sys
Matias Guijarro's avatar
Matias Guijarro committed
10
import typing
11
import warnings
12
import collections
13
import functools
14
import inspect
15
import contextlib
16
import shutil
17
from treelib import Tree
18
from types import ModuleType
19
from weakref import WeakKeyDictionary
20
from tabulate import tabulate
21

22
from bliss import setup_globals, global_map, is_bliss_shell
23
from bliss.config import static
24
from bliss.config.settings import SimpleSetting
25
from bliss.config.channels import EventChannel
26
from bliss.config.conductor.client import get_text_file, get_python_modules, get_file
27
from bliss.common.proxy import Proxy
28
from bliss.common.logtools import log_warning
29
from bliss.common.utils import UserNamespace, chunk_list
30
from bliss.common import constants
31
from bliss.scanning import scan_saving
32
from bliss.scanning import scan_display
33

34

35
# Names of the sessions whose remote "scripts" importer is already installed
# on sys.meta_path (see Session._register_session_importers).
_SESSION_IMPORTERS = set()
# Session made current by set_current_session(); None until then.
CURRENT_SESSION = None
37

38

39
40
41
42
43
44
45
46
47
48
49
50
def set_current_session(session, force=True):
    """Register *session* as the module-wide current session.

    Args:
        session: the session object to make current.
        force: must be true; a false value is refused.

    Raises:
        RuntimeError: when *force* is false.
    """
    global CURRENT_SESSION
    if not force:
        raise RuntimeError("It is not allowed to set another current session.")
    CURRENT_SESSION = session


def get_current_session():
    """Return the session last passed to set_current_session(), or None if
    no session has been made current yet."""
    session = CURRENT_SESSION
    return session


51
class _StringImporter(object):
    """Meta-path importer serving python modules stored in the beacon config.

    Every module found under *path* is exposed as
    ``bliss.session.<session_name>.<module_name>``. When *in_load_script* is
    true, it is also exposed under its bare ``<module_name>``. Empty in-memory
    packages are registered for ``bliss.session`` and
    ``bliss.session.<session_name>`` so the dotted names can be resolved.
    """

    BASE_MODULE_NAMESPACE = "bliss.session"

    def __init__(self, path, session_name, in_load_script=False):
        # full module name -> beacon file path (None for in-memory packages)
        self._modules = dict()
        session_module_namespace = "%s.%s" % (self.BASE_MODULE_NAMESPACE, session_name)
        for module_name, file_path in get_python_modules(path):
            self._modules["%s.%s" % (session_module_namespace, module_name)] = file_path
            if in_load_script:
                self._modules[module_name] = file_path
        if self._modules:
            self._modules[self.BASE_MODULE_NAMESPACE] = None
            self._modules[session_module_namespace] = None

    def find_module(self, fullname, path):
        """Return self when *fullname* is served by this importer, else None."""
        if fullname in self._modules:
            return self
        return None

    def load_module(self, fullname):
        """Create (or reuse) the module *fullname* and execute its code.

        Modules without a file are empty packages (``__path__ = []``).

        Raises:
            ImportError: if *fullname* is not served by this importer.
            Any exception raised while compiling or executing the module
            code. In that case, a module that this call inserted into
            ``sys.modules`` is removed again (PEP 302).
        """
        if fullname not in self._modules:
            raise ImportError(fullname)

        filename = self._modules.get(fullname)
        if filename:
            s_code = get_text_file(filename)
        else:
            filename = "%s (__init__ memory)" % fullname
            s_code = ""  # empty __init__.py

        # PEP 302: an existing module object must be reused (reload case)
        # and must be left alone if loading fails.
        was_in_sys_modules = fullname in sys.modules
        new_module = sys.modules.get(fullname, ModuleType(fullname))
        new_module.__loader__ = self
        module_filename = "beacon://%s" % filename
        new_module.__file__ = module_filename
        new_module.__name__ = fullname
        if filename.find("__init__") > -1:
            # package: an empty __path__ marks it as such
            new_module.__path__ = []
            new_module.__package__ = fullname
        else:
            new_module.__package__ = fullname.rpartition(".")[0]
        sys.modules.setdefault(fullname, new_module)
        try:
            c_code = compile(s_code, module_filename, "exec")
            exec(c_code, new_module.__dict__)
        except BaseException:
            # do not leave a half-initialised module importable
            if not was_in_sys_modules:
                sys.modules.pop(fullname, None)
            raise
        return new_module

    def get_source(self, fullname):
        """Return the source text of *fullname* ("" for in-memory packages).

        Raises:
            ImportError: if *fullname* is not served by this importer.
        """
        if fullname not in self._modules:
            raise ImportError(fullname)

        filename = self._modules.get(fullname)
        return get_text_file(filename) if filename else ""
102

103

104
105
106
107
class ConfigProxy(Proxy):
    """Proxy around the static configuration that records fetched objects.

    It behaves like the wrapped config object, except that :meth:`get` also
    stores each fetched object in the env dict given at construction.
    """

    def __init__(self, target, env_dict):
        # stored with object.__setattr__, presumably because Proxy forwards
        # attribute assignment to the target -- TODO confirm in bliss Proxy
        object.__setattr__(self, "_ConfigProxy__env_dict", env_dict)
        super().__init__(target, init_once=True)

    def get(self, name):
        """Same as the canonical static config ``get``, but the returned
        object is also added to the session env dict under *name*."""
        self.__env_dict[name] = fetched = self.__wrapped__.get(name)
        return fetched
115

116

117
class Session:
Sebastien Petitdemange's avatar
Sebastien Petitdemange committed
118
119
120
    """
    Bliss session.

121
    Sessions group objects with a setup.
Vincent Michel's avatar
Vincent Michel committed
122

123
    YAML file example:
Valentin Valls's avatar
Valentin Valls committed
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160

    .. code-block::

         - plugin: session          # could be defined in parents
           class: Session
           name: super_mario        # session name

           # 'config-objects' contains
           # object name you want to export
           # either in yaml compact list
           config-objects: [seby,diode2]
           # or standard yaml list
           config-objects:
           - seby
           - diode2
           # if config-objects key doesn't exist,
           # session will export all objects;
           # 'exclude-objects' can be used to exclude objects
           exclude-objects: [seby]

           # you can also include other session
           # with the 'include-sessions'
           include-sessions: [luigi]

           # finally a setup file can be defined to be
           # executed for the session.
           # All objects or functions defined in the
           # setup file will be exported in the environment.
           # The file path is relative to the session yaml file
           # location if it starts with a './'
           # otherwise it is absolute from the root of the
           # beacon file data base.
           setup-file: ./super_mario.py

           # A svg synoptic (Web shell) can be added:
           synoptic:
             svg-file: super_mario.svg
Sebastien Petitdemange's avatar
Sebastien Petitdemange committed
161
    """
162

163
164
    def __init__(self, name, config_tree):
        """Create the session *name* and load its settings from *config_tree*.

        Every attribute first gets a neutral default. :meth:`init` then fills
        the configuration-derived ones.
        """
        self.__name = name
        self.__env_dict = {}

        # configuration-derived values, (re)computed by init()
        self.__scripts_module_path = self.__setup_file = self.__synoptic_file = None
        self.__children_tree = None
        self.__config_objects_names, self.__exclude_objects_names, self.__include_sessions = (
            [],
            [],
            [],
        )

        self.__map = None
        self.__log = None
        # only the 20 most recently appended scans are retained
        self.__scans = collections.deque(maxlen=20)
        self.__user_script_homedir = SimpleSetting("%s:user_script_homedir" % self.name)
        self._script_source_cache = WeakKeyDictionary()
        self.__data_policy_events = EventChannel(f"{self.name}:esrf_data_policy")

        self.scan_saving = None
        self.scan_display = None
        self.is_loading_config = False

        self.init(config_tree)

    def init(self, config_tree):
        """(Re)initialise the session settings from its configuration node.

        Reads the ``setup-file``, ``synoptic``, ``config-objects``,
        ``exclude-objects``, ``include-sessions``, ``aliases``,
        ``icat-mapping``, ``default-userscript-dir`` and ``scan_saving`` keys.

        Args:
            config_tree: the session configuration node. Relative paths are
                resolved against the directory of its ``filename`` attribute.
                When that attribute is missing or is not a path, the
                scripts path and the setup file are left unset (None).
        """
        # directory of the session YAML file; None when config_tree has no
        # usable .filename (missing attribute or None value)
        try:
            config_dir = os.path.dirname(config_tree.filename)
        except (AttributeError, TypeError):
            config_dir = None

        if config_dir is None:
            self.__scripts_module_path = None
        else:
            self.__scripts_module_path = os.path.normpath(
                os.path.join(config_dir, "scripts")
            )

        try:
            setup_file_path = config_tree["setup-file"]
        except KeyError:
            self.__setup_file = None
        else:
            try:
                self.__setup_file = os.path.normpath(
                    os.path.join(config_dir, setup_file_path)
                )
            except TypeError:
                # no config directory, or 'setup-file' is not a path string
                self.__setup_file = None
            else:
                # the scripts directory lives next to the setup file
                self.__scripts_module_path = os.path.join(
                    os.path.dirname(self.__setup_file), "scripts"
                )

        # convert windows-style path to linux-style
        if self.__scripts_module_path:
            self.__scripts_module_path = self.__scripts_module_path.replace("\\", "/")

        try:
            self.__synoptic_file = config_tree.get("synoptic").get("svg-file")
        except AttributeError:
            # no 'synoptic' section
            self.__synoptic_file = None

        self.__config_objects_names = config_tree.get("config-objects")
        self.__exclude_objects_names = config_tree.get("exclude-objects", list())
        self.__children_tree = None
        self.__include_sessions = config_tree.get("include-sessions")
        self.__config_aliases = config_tree.get("aliases", [])

        self.__icat_mapping = None
        self.__icat_mapping_config = config_tree.get("icat-mapping")

        # the configured default only applies when no directory is stored yet
        self.__default_user_script_homedir = config_tree.get("default-userscript-dir")
        if self.__default_user_script_homedir and not self._get_user_script_home():
            self._set_user_script_home(self.__default_user_script_homedir)

        # session-level 'scan_saving' overrides the root-level one
        self.__scan_saving_config = config_tree.get(
            "scan_saving", self.config.root.get("scan_saving", {})
        )
232

233
234
235
236
    @property
    def name(self):
        return self.__name

237
238
239
240
    @property
    def scans(self):
        return self.__scans

241
242
    @property
    def config(self):
        """A new ConfigProxy over the static configuration on each access.

        Objects fetched through its ``get`` are also stored in this session's
        env dict.
        """
        proxy = ConfigProxy(static.get_config, self.env_dict)
        return proxy
244

245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
    @property
    @contextlib.contextmanager
    def temporary_config(self):
        """
        Context in which config objects can be instantiated temporarily.

        On exit, objects instantiated inside the context are dropped from the
        config and from the session env dict. Config cache entries added
        inside the context are removed, and cache entries removed inside it
        are restored. Calling ``reload`` on the config raises RuntimeError
        while the context is active.

        Yields:
            the session config proxy (same as ``self.config``)
        """
        # store current config status
        cfg = static.get_config()
        name2instancekey = set(cfg._name2instance.keys())
        name2cache = cfg._name2cache.copy()

        # reload is not permitted in temporary config
        previous_reload = cfg.reload

        def reload(*args, **kwargs):
            # accept any call signature so callers always get this error
            raise RuntimeError("Not permitted under temporary config context")

        cfg.reload = reload

        try:
            yield self.config
        finally:
            # rollback config
            cfg.reload = previous_reload
            for key in set(cfg._name2instance.keys()) - name2instancekey:
                cfg._name2instance.pop(key)
                self.__env_dict.pop(key, None)
            current_cache_keys = set(cfg._name2cache)
            previous_cache_keys = set(name2cache)
            # remove added cache
            for key in current_cache_keys - previous_cache_keys:
                cfg._name2cache.pop(key)
            # re-insert removed cache
            for key in previous_cache_keys - current_cache_keys:
                cfg._name2cache[key] = name2cache[key]

284
285
286
287
288
289
290
291
    @property
    def setup_file(self):
        return self.__setup_file

    @property
    def synoptic_file(self):
        return self.__synoptic_file

292
293
294
295
    @property
    def _scripts_module_path(self):
        return self.__scripts_module_path

Linus Pithan's avatar
Linus Pithan committed
296
297
    @property
    def icat_mapping(self):
298
299
300
301
302
        if self.__icat_mapping is not None:
            return self.__icat_mapping
        if self.__icat_mapping_config:
            self.__icat_mapping = self.config.get(self.__icat_mapping_config)
            return self.__icat_mapping
Linus Pithan's avatar
Linus Pithan committed
303

304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
    def _child_session_iter(self):
        """Yield the included session objects, last tree level first.

        The session tree is walked breadth-first, the root (this session)
        is dropped, and the order is reversed.
        """
        nodes = list(self.sessions_tree.expand_tree(mode=Tree.WIDTH))
        yield from reversed(nodes[1:])

    def _aliases_info(self, cache=None):
        """Return the alias configuration of this session and its includes.

        Maps each aliased object's ``original_name`` to a clone of its alias
        config entry, with ``original_name`` removed. Included sessions are
        processed first, so this session's own entries win on conflicts.

        The result is cached per session. It is recomputed when the
        ``aliases`` config value is a different object than last time, or
        when the cached result is empty.

        Args:
            cache: optional external cache dict with keys ``aliases`` and
                ``config_id``. By default a per-session cache is used. The
                former shared mutable default let each included session
                clear the dict that its parent was still filling.
        """
        if cache is None:
            try:
                cache = self.__aliases_cache
            except AttributeError:
                cache = self.__aliases_cache = {"aliases": {}, "config_id": None}

        aliases = cache["aliases"]
        config_id = id(self.__config_aliases)
        if cache["config_id"] != config_id:
            aliases.clear()
            cache["config_id"] = config_id
        if aliases:
            return aliases

        for child_session in self._child_session_iter():
            aliases.update(child_session._aliases_info())

        for alias_cfg in self.__config_aliases:
            cfg = alias_cfg.clone()
            aliases[cfg.pop("original_name")] = cfg

        return aliases

329
    @property
    def object_names(self, cache=None):
        """Names of the configuration objects exported by this session.

        With ``config-objects`` set, the list holds the names exported by
        included sessions, followed by this session's ``config-objects``.
        Unknown objects are logged and dropped. Sessions are warned about
        and dropped.

        Without ``config-objects``, the list holds every configured object
        that is neither a session nor a ``default``-plugin object.

        ``exclude-objects`` entries are then removed, and duplicates are
        dropped, keeping the first occurrence.

        The result is cached per session, keyed on the identity of the
        ``config-objects`` value. An empty result is recomputed each time.
        The returned list is the cache itself, so callers should not mutate it.

        Args:
            cache: optional external cache dict with keys ``objects_names``
                and ``config_id``. A property getter never receives it, so
                the per-session cache is normally used.
        """
        if cache is None:
            try:
                cache = self.__object_names_cache
            except AttributeError:
                cache = self.__object_names_cache = {
                    "objects_names": [],
                    "config_id": None,
                }
        objects_names = cache["objects_names"]
        config_id = id(self.__config_objects_names)
        if cache["config_id"] != config_id:
            objects_names.clear()
            cache["config_id"] = config_id
        if objects_names:
            return objects_names

        names_list = list()
        for child_session in self._child_session_iter():
            names_list.extend(child_session.object_names)

        session_config = self.config.get_config(self.name)

        if self.__config_objects_names is None:
            # no explicit list: export every non-session, non-default object
            names_list = list()
            for name in self.config.names_list:
                cfg = self.config.get_config(name)
                if cfg.get("class", "").lower() == "session":
                    continue
                if cfg.get_inherited("plugin") == "default":
                    continue
                names_list.append(name)
        else:
            names_list.extend(self.__config_objects_names)
            # Check if other session in config-objects.
            # Iterate over a snapshot: removing from the list being iterated
            # would silently skip the element following each removed name.
            for name in list(names_list):
                object_config = self.config.get_config(name)
                if object_config is None:
                    log_warning(
                        self,
                        f"In {session_config.filename} of session '{self.name}':"
                        + f" object '{name}' does not exist. Ignoring it.",
                    )
                    names_list.remove(name)
                else:
                    class_name = object_config.get("class", "")
                    if class_name.lower() == "session":
                        warnings.warn(
                            f"Session {self.name} 'config-objects' list contains session "
                            + f"{name}, ignoring (hint: add session in 'include-sessions' list)",
                            RuntimeWarning,
                        )
                        names_list.remove(name)

        for name in self.__exclude_objects_names:
            try:
                names_list.remove(name)
            except (ValueError, AttributeError):
                pass

        # order-preserving de-duplication
        objects_names.clear()
        objects_names.extend(dict.fromkeys(names_list))
        return objects_names
386

387
388
389
390
391
392
    @property
    def sessions_tree(self):
        """
        Tree of included sessions rooted at this session (built once, cached).

        Node identifiers are the session objects themselves.

        Raises:
            RuntimeError: if a session is referenced more than once in the
                include hierarchy (this also covers a reference back to
                this session).
        """
        if self.__children_tree is None:
            # session name -> (number of references, names of referencing sessions)
            children = {self.name: (1, list())}
            tree = Tree()
            tree.create_node(tag=self.name, identifier=self)
            tree = self._build_children_tree(tree, self, children)
            multiple_ref_child = [
                (name, parents) for name, (ref, parents) in children.items() if ref > 1
            ]
            if multiple_ref_child:
                msg = "Session %s has cyclic references to sessions:\n" % self.name
                msg += "\n".join(
                    "session %s is referenced in %r" % (session_name, parents)
                    for session_name, parents in multiple_ref_child
                )
                raise RuntimeError(msg)
            self.__children_tree = tree
        return self.__children_tree

410
    def _build_children_tree(self, tree, parent, children):
        """Attach the sessions listed in ``include-sessions`` below *parent*.

        The walk recurses into each included session. *children* maps a
        session name to ``(reference count, [names of including sessions])``
        and is updated in place. A session that was already referenced is
        counted but not expanded again, which stops cycles.

        Returns:
            *tree*, mutated in place
        """
        included = self.__include_sessions
        if included is None:
            return tree

        for session_name in included:
            count, parents = children.get(session_name, (0, list()))
            count += 1
            children[session_name] = (count, parents)
            parents.append(self.name)
            if count > 1:
                continue  # avoid cyclic reference

            child = self.config.get(session_name)
            tree.create_node(tag=session_name, identifier=child, parent=parent)
            child._build_children_tree(tree, child, children)

        return tree

427
428
    @property
    def env_dict(self):
429
        return self.__env_dict
430

431
432
433
434
435
436
    def _emit_event(self, event, **kwargs):
        """Post a data-policy *event* (with *kwargs* as its value) on the
        session's ``esrf_data_policy`` event channel.

        Raises:
            NotImplementedError: if *event* is not in
                ``scan_saving.ESRFDataPolicyEvent``.
        """
        if event not in scan_saving.ESRFDataPolicyEvent:
            raise NotImplementedError
        self.__data_policy_events.post(dict(event_type=event, value=kwargs))

437
    def _set_scan_saving(self, cls=None):
        """Defines the data policy, which includes the electronic logbook.

        Selects *cls* as the scan saving class and creates this session's
        ScanSaving object. In a BLISS shell, the object is also exposed as
        ``SCAN_SAVING`` in the env dict.
        """
        scan_saving.set_scan_saving_class(cls)
        self.scan_saving = scan_saving.ScanSaving(self.name)
        if not is_bliss_shell():
            return
        self.env_dict["SCAN_SAVING"] = self.scan_saving

445
446
447
448
449
450
451
452
    @property
    def _config_scan_saving_class(self):
        """Scan saving class named by the ``scan_saving: class`` config key.

        Returns None when the key is unset or when the scan_saving module has
        no attribute of that name.
        """
        class_name = self.__scan_saving_config.get("class")
        try:
            found = getattr(scan_saving, class_name)
        except (AttributeError, TypeError):
            found = None
        return found

453
454
455
456
457
    def _set_scan_display(self):
        """Create this session's ScanDisplay. In a BLISS shell it is also
        exposed as ``SCAN_DISPLAY`` in the env dict."""
        display = scan_display.ScanDisplay(self.name)
        self.scan_display = display
        if is_bliss_shell():
            self.env_dict["SCAN_DISPLAY"] = display

458
    def enable_esrf_data_policy(self):
        """Switch to ESRFScanSaving and post an ``Enable`` data-policy event
        carrying the new data path."""
        self._set_scan_saving(cls=scan_saving.ESRFScanSaving)
        new_path = self.scan_saving.get_path()
        self._emit_event(scan_saving.ESRFDataPolicyEvent.Enable, data_path=new_path)
464
465

    def disable_esrf_data_policy(self):
        """Switch back to the default scan saving class (``cls=None``) and
        post a ``Disable`` data-policy event carrying the data path."""
        self._set_scan_saving()
        new_path = self.scan_saving.get_path()
        self._emit_event(scan_saving.ESRFDataPolicyEvent.Disable, data_path=new_path)
471

472
473
474
475
476
477
478
    def _cache_script_source(self, obj):
        """ Store source code of obj in cache for prdef """
        try:
            self._script_source_cache[obj] = inspect.getsourcelines(obj)
        except Exception:
            pass

479
480
481
482
483
484
485
486
    def load_script(self, script_module_name, session=None):
        """
        Run a script from a session's ``scripts`` directory and export its
        public names (those not starting with "_") into this session's env dict.

        Exceptions raised by the script itself are printed through
        ``sys.excepthook`` and are not propagated. Whatever the script
        defined before failing is still exported.

        Args:
            script_module_name: file name of the script; any extension is
                stripped to build the module name
            session (optional): the session, as an object or a name, whose
                scripts directory is used. Defaults to this session.

        Raises:
            RuntimeError: if the session has no script module path, or if
                the script cannot be found
        """
        if session is None:
            session = self
        elif isinstance(session, str):
            session = self.config.get(session)

        if not session._scripts_module_path:
            raise RuntimeError(f"{session.name} session has no script module path")

        importer = _StringImporter(
            session._scripts_module_path, session.name, in_load_script=True
        )
        # the session's other scripts are importable while this one runs
        sys.meta_path.insert(0, importer)
        try:
            module_name = "%s.%s.%s" % (
                _StringImporter.BASE_MODULE_NAMESPACE,
                session.name,
                os.path.splitext(script_module_name)[0],
            )
            filename = importer._modules.get(module_name)
            if not filename:
                raise RuntimeError("Cannot find module %s" % module_name)

            code = compile(get_text_file(filename), filename, "exec")

            namespace = self.env_dict.copy()
            namespace["__file__"] = filename
            try:
                exec(code, namespace)
            except Exception:
                sys.excepthook(*sys.exc_info())

            for key, value in namespace.items():
                if key.startswith("_"):
                    continue
                self.env_dict[key] = value
                self._cache_script_source(value)
        finally:
            sys.meta_path.remove(importer)
529

530
531
532
533
534
535
    def _get_user_script_home(self):
        return self.__user_script_homedir.get()

    def _set_user_script_home(self, dir):
        self.__user_script_homedir.set(dir)

536
537
538
539
540
    def _reset_user_script_home(self):
        if self.__default_user_script_homedir:
            self.__user_script_homedir.set(self.__default_user_script_homedir)
        else:
            self.__user_script_homedir.clear()
541

542
    def user_script_homedir(self, new_dir=None, reset=False):
543
544
545
546
547
548
        """
        Set or get local user script home directory

        Args:
            None -> returns current user script home directory
            new_dir (optional) -> set user script home directory to new_dir
549
            reset (optional) -> reset previously set user script home directory
550
        """
551
552
        if reset:
            self._reset_user_script_home()
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
        elif new_dir is not None:
            if not os.path.isabs(new_dir):
                raise RuntimeError(f"Directory path must be absolute [{new_dir}]")
            if not os.path.isdir(new_dir):
                raise RuntimeError(f"Invalid directory [{new_dir}]")
            self._set_user_script_home(new_dir)
        else:
            return self._get_user_script_home()

    def user_script_list(self):
        """List python scripts from user script home directory"""
        rootdir = self._get_user_script_home()
        if not rootdir:
            print(
                "First, you need to set a directory with `user_script_homedir(path_to_dir)`"
            )
            raise RuntimeError("User scripts home directory not configured")
        if not os.path.isdir(rootdir):
            raise RuntimeError(f"Invalid directory [{rootdir}]")

        print(f"List of python scripts in [{rootdir}]:")
        for (dirpath, dirnames, filenames) in os.walk(rootdir):
            dirname = dirpath.replace(rootdir, "")
            dirname = dirname.lstrip(os.path.sep)
            for filename in filenames:
                _, ext = os.path.splitext(filename)
                if ext != ".py":
                    continue
                print(f" - {os.path.join(dirname, filename)}")

583
    def user_script_load(self, scriptname=None, export_global="user"):
584
585
        """
        load a script and export all public (= not starting with _)
586
587
        objects and functions to current environment or to a namespace.
        (exceptions are printed but not thrown, execution is stopped)
588
589

        Args:
590
591
            scriptname: the python file to load (script path can be absolute relative to script_homedir)
        Optional args:
592
593
            export_global="user" (default): export objects to "user" namespace in session env dict (eg. user.myfunc())
            export_global=False: return a namespace
594
            export_global=True: export objects to session env dict
595
596
        """
        return self._user_script_exec(
597
            scriptname, load=True, export_global=export_global
598
599
600
601
602
        )

    def user_script_run(self, scriptname=None):
        """
        Execute a script without exporting objects or functions to current environment.
603
        (exceptions are printed but not thrown, execution is stopped)
604
605

        Args:
606
            scriptname: the python file to run (script path can be absolute or relative to script_homedir)
607
        """
608
        self._user_script_exec(scriptname, load=False)
609

610
    def _user_script_exec(self, scriptname, load=False, export_global=False):
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
        if not scriptname:
            self.user_script_list()
            return

        if os.path.isabs(scriptname):
            filepath = scriptname
        else:
            if not self._get_user_script_home():
                print(
                    "First, you need to set a directory with `user_script_homedir(path_to_dir)`"
                )
                raise RuntimeError("User scripts home directory not configured")

            homedir = os.path.abspath(self._get_user_script_home())
            filepath = os.path.join(homedir, scriptname)

        _, ext = os.path.splitext(scriptname)
        if not ext:
            filepath += ".py"
        if not os.path.isfile(filepath):
            raise RuntimeError(f"Cannot find [{filepath}] !")
        try:
            script = open(filepath).read()
        except Exception:
            raise RuntimeError(f"Failed to read [{filepath}] !")

637
        if load is True:
638
            print(f"Loading [{filepath}]")
639
        else:
640
            print(f"Running [{filepath}]")
641
642

        globals_dict = self.env_dict.copy()
643
        globals_dict["__file__"] = filepath
644

645
        c_code = compile(script, filepath, "exec")
646
        try:
647
            exec(c_code, globals_dict)
648
649
650
        except Exception:
            sys.excepthook(*sys.exc_info())

651
        def safe_save_to_env_dict(env_dict, key, value):
652
            """ Print warning if env_dict[key] already exists """
653
            if key in env_dict and value is not env_dict[key]:
654
                print(f"Replaced [{key}] in session env")
655
656
            env_dict[key] = value

657
        # case #1: run file
658
659
660
        if not load:
            return

661
        # case #2: export to global env dict
662
663
664
665
        if export_global is True:
            for k in globals_dict.keys():
                if k.startswith("_"):
                    continue
666
                safe_save_to_env_dict(self.env_dict, k, globals_dict[k])
667
                self._cache_script_source(globals_dict[k])
668

669
        else:
670
            env_dict = dict()
671
            for k in c_code.co_names:
672
673
                if k.startswith("_"):
                    continue
674
675
                if k not in globals_dict:
                    continue
676
                env_dict[k] = globals_dict[k]
677
            ns = UserNamespace(env_dict)
678

679
680
681
            for obj in env_dict.values():
                self._cache_script_source(obj)

682
            if isinstance(export_global, str):
683
684
685
686
                if (
                    getattr(self.env_dict.get(export_global), "__module__", None)
                    == "bliss.common.utils.namespace"
                ):
687
                    # case #3: export and merge to existing namespace in env dict
688
689
690
                    d = self.env_dict[export_global]._asdict()
                    d.update(env_dict)
                    self.env_dict[export_global] = UserNamespace(d)
691
                    print(f"Merged [{export_global}] namespace in session.")
692
                else:
693
                    # case #4: export to given (non existing) namespace in env dict
694
                    safe_save_to_env_dict(self.env_dict, export_global, ns)
695
696
                    print(f"Exported [{export_global}] namespace in session.")

697
            else:
698
                # case #5: export_global is False, return the namespace
699
                return ns
700

Matias Guijarro's avatar
Matias Guijarro committed
701
    def _do_setup(self, env_dict: typing.Union[dict, None], verbose: bool) -> bool:
        """
        Load configuration, and execute the setup script(s)

        The steps are:
        1. Make this session the current one.
        2. Select the scan saving class.
        3. Instantiate the config objects (``_load_config``).
        4. Fill the environment.
        5. Run the setup files of the included sessions.
        6. Run this session's own setup file.

        env_dict: globals dictionary (or None to use current session env. dict)
        verbose: boolean flag passed to `_load_config`

        Return: True if setup went without error, False otherwise
        """
        ret = True

        set_current_session(self, force=True)

        # Session environment
        if env_dict is None:
            # self has just been made current, so this is self's own dict
            env_dict = get_current_session().env_dict
        self.__env_dict = env_dict

        # Data policy needs to be defined before instantiating the
        # session objects
        self._set_scan_saving(cls=self._config_scan_saving_class)

        # Instantiate the session objects
        try:
            # The flag is raised and lowered on this session itself, not via
            # the CURRENT_SESSION global, so it is reset on the same object
            # even if the current session changes during _load_config.
            self.is_loading_config = True
            self._load_config(verbose)
        except Exception:
            ret = False
            sys.excepthook(*sys.exc_info())
        finally:
            self.is_loading_config = False
            env_dict["config"] = self.config

        self._register_session_importers(self)

        self._set_scan_display()

        self._additional_env_variables(env_dict)

        for child_session in self._child_session_iter():
            self._register_session_importers(child_session)
            # every child is set up, even after a previous failure
            child_session_ret = child_session._setup(env_dict, nested=True)
            ret = ret and child_session_ret

        setup_ret = self._setup(env_dict)
        ret = ret and setup_ret

        return ret

    def setup(
        self,
        env_dict: typing.Optional[dict] = None,
        verbose: typing.Optional[bool] = False,
    ) -> bool:
        """Call _do_setup, but catch exception to display error message via except hook
755

Matias Guijarro's avatar
Matias Guijarro committed
756
757
758
759
760
761
762
763
764
765
766
767
        In case of SystemExit: the exception is propagated.

        Return: True if setup went without error, False otherwise
        """
        try:
            ret = self._do_setup(env_dict, verbose)
        except SystemExit:
            raise
        except BaseException:
            sys.excepthook(*sys.exc_info())
            return False
        return ret
768

769
770
771
772
773
774
775
776
777
778
779
780
781
782
    @staticmethod
    def _register_session_importers(session):
        """Allows remote scripts to be registered and executed locally
        """
        if session.__scripts_module_path and session.name not in _SESSION_IMPORTERS:
            sys.meta_path.append(
                _StringImporter(session.__scripts_module_path, session.name)
            )
            _SESSION_IMPORTERS.add(session.name)

    def _additional_env_variables(self, env_dict):
        """Add additional variables to the session environment.

        ``ALIASES`` and ``ACTIVE_MG`` are always (re)assigned. The script
        helper functions are only added when the name is not already defined.
        """
        from bliss.common.measurementgroup import ACTIVE_MG

        env_dict["ALIASES"] = global_map.aliases
        env_dict["ACTIVE_MG"] = ACTIVE_MG

        script_helpers = (
            ("load_script", self.load_script),
            ("user_script_homedir", self.user_script_homedir),
            ("user_script_list", self.user_script_list),
            ("user_script_load", self.user_script_load),
            ("user_script_run", self.user_script_run),
        )
        for helper_name, helper in script_helpers:
            env_dict.setdefault(helper_name, helper)
796

797
    def _setup(self, env_dict, nested=False):
798
799
800
801
802
803
        """
        Load an execute setup file.

        Called by _do_setup() which is called by setup().
        Must return True in case of success.
        """
804
        if self.setup_file is None:
805
            return True
806

807
808
        print("%s: Executing setup file..." % self.name)

809
810
811
        with get_file(
            {"setup_file": self.setup_file}, "setup_file", text=True
        ) as setup_file:
812

813
814
815
816
817
818
819
            if nested:
                # in case of nested sessions, execute load_script from the child session
                env_dict["load_script"] = functools.partial(
                    env_dict["load_script"], session=self.name
                )
            else:
                env_dict["load_script"] = self.load_script
820

Matias Guijarro's avatar
Matias Guijarro committed
821
822
823
824
825
826
            try:
                code = compile(setup_file.read(), self.setup_file, "exec")
                exec(code, env_dict)
            except Exception:
                sys.excepthook(*sys.exc_info())
                return False
827

Vincent Michel's avatar
Vincent Michel committed
828
            for obj_name, obj in env_dict.items():
829
                setattr(setup_globals, obj_name, obj)
830

831
            return True
832

833
    def close(self):
        """Tear the session down.

        Empties ``setup_globals``, then calls ``__close__()`` on every value
        of the session namespace except the session itself and its config.
        Any exception raised there (including a missing ``__close__``) is
        ignored. Finally, the namespace is emptied and the module-level
        current session is reset to None.
        """
        global CURRENT_SESSION

        setup_globals.__dict__.clear()
        for obj in self.env_dict.values():
            if obj is not self and obj is not self.config:
                # Best effort: not every object implements __close__.
                with contextlib.suppress(Exception):
                    obj.__close__()
        self.env_dict.clear()
        CURRENT_SESSION = None

    def _load_config(self, verbose=True):
        """Initialize the objects declared in the session config and report on them.

        Handling of each name of ``self.object_names``:
        - if ``setup_globals`` already has an attribute with that name, it is
          reused (copied into ``self.env_dict``) and not re-initialized;
        - otherwise it is initialized with ``self.config.get()``.

        In verbose mode, objects are classified as successfully initialized,
        initialized with the default plugin (their config node has no
        ``plugin``) or failed, and one summary table per category is printed.
        Each failure is reported through sys.excepthook. When *verbose* is
        False, initialization failures are neither reported nor counted.

        Then the aliases declared in the config are created, the session
        object itself is loaded from the config, and ``self.env_dict`` is
        copied into ``setup_globals``.
        """
        warning_item_list = list()
        success_item_list = list()
        error_item_list = list()
        error_count = 0
        item_count = 0

        for item_name in self.object_names:
            item_count += 1

            # Skip initialization of existing objects.
            if hasattr(setup_globals, item_name):
                self.env_dict[item_name] = getattr(setup_globals, item_name)
                continue

            print(f"Initializing: {item_name}                  ", end="", flush=True)

            try:
                self.config.get(item_name)
            except Exception:
                # NOTE: failures are only displayed and counted in verbose mode.
                if verbose:
                    print("\r", end="", flush=True)  # return to beginning of line.
                    print(" " * 80, flush=True)
                    print(
                        f"Initialization of {item_name} \033[91mFAILED\033[0m ",
                        flush=True,
                    )

                    # 0-based failure index, consistent with the
                    # 'last_error[X]' range printed in the summary below.
                    print(f"[{error_count}] ", end="", flush=True)
                    sys.excepthook(*sys.exc_info())
                    error_count += 1
                    error_item_list.append(item_name)
            else:
                print("\r", end="", flush=True)  # return to beginning of line.
                if verbose:
                    item_node = self.config.get_config(item_name)
                    if item_node.plugin is None:
                        warning_item_list.append(item_name)
                    else:
                        success_item_list.append(item_name)

        # Clear the line.
        print(" " * 80, flush=True)

        display_width = shutil.get_terminal_size().columns
        if len(self.object_names) == 0:
            max_length = 5
            print("There are no objects declared in the session's config file.")
        else:
            # Length of the longest object name.
            max_length = max(len(name) for name in self.object_names)

        # Number of items displayable on one line. The "plain" tabulate format
        # puts 2 spaces between columns, so n columns of width max_length take
        # n * (max_length + 2) - 2 characters. Always show at least one item.
        column_separator = 2
        item_number = max(
            1, (display_width + column_separator) // (max_length + column_separator)
        )

        def print_item_table(item_list):
            """Print *item_list* sorted case-insensitively, item_number per row."""
            item_list.sort(key=str.casefold)
            print(tabulate(chunk_list(item_list, item_number), tablefmt="plain"))
            print("")

        # SUCCESS
        success_count = len(success_item_list)
        if success_count > 0:
            print(
                f"OK: {success_count}/{item_count}"
                f" object{'s' if success_count > 1 else ''} successfully initialized.",
                flush=True,
            )
            print_item_table(success_item_list)

        # WARNING
        warning_count = len(warning_item_list)
        if warning_count > 0:
            print(
                f"WARNING: {warning_count} object{'s' if warning_count > 1 else ''}"
                f" initialized with **default** plugin:"
            )
            print_item_table(warning_item_list)

        # ERROR
        if error_count > 0:
            print(
                f"ERROR: {error_count} object{'s' if error_count > 1 else ''} failed to initialize:"
            )
            print_item_table(error_item_list)

            if error_count == 1:
                print("To learn about failure, type: 'last_error'")
            else:
                print(
                    f"To learn about failures, type: 'last_error[X]' for X in [0..{error_count-1}]"
                )
            print("")

        # Make aliases.
        for item_name, alias_cfg in self._aliases_info().items():
            alias_name = alias_cfg["alias_name"]
            try:
                global_map.aliases.add(alias_name, item_name, verbose=verbose)
            except Exception:
                sys.excepthook(*sys.exc_info())

        # Load the session object itself.
        try:
            self.config.get(self.name)
        except Exception:
            sys.excepthook(*sys.exc_info())

        setup_globals.__dict__.update(self.env_dict)

    def resetup(self, verbose=False):
        """Close the session, reload its configuration and set it up again.

        The session is re-initialized from its freshly reloaded config node,
        then ``setup()`` is called with ``self.env_dict`` and *verbose*.
        The result of ``setup()`` is not returned.
        """
        self.close()

        self.config.reload()
        session_node = self.config.get_config(self.name)
        self.init(session_node)

        self.setup(self.env_dict, verbose)


class DefaultSession(Session):
    """Minimal session: no config objects, no setup scripts, no data policy.
    """

    def __init__(self):
        # Default session name, with an empty list of config objects.
        Session.__init__(self, constants.DEFAULT_SESSION_NAME, {"config-objects": []})

    def _set_scan_saving(self, cls=None):
        """Delegate to the parent implementation, always with ``cls=None``.

        Requesting a specific class only triggers a warning.
        """
        if cls is not None:
            log_warning(self, "No data policy allowed in this session.")
        super()._set_scan_saving(None)

    def enable_esrf_data_policy(self):
        """No-op: data policies are not allowed in this session."""

    def disable_esrf_data_policy(self):
        """No-op: data policies are not allowed in this session."""

    def _load_config(self, verbose=True):
        """No-op: this session declares no config objects."""

    def resetup(self, verbose=False):
        """No-op for the default session."""