# -*- coding: utf-8 -*-
#
# This file is part of the bliss project
#
# Copyright (c) 2015-2022 Beamline Control Unit, ESRF
# Distributed under the GNU LGPLv3. See LICENSE for more info.

8
import itertools
9
import os
10
from time import time
11

12
from bliss.scanning.chain import AcquisitionSlave, AcquisitionMaster
13
from bliss.common.event import connect, disconnect
14
from blissdata import settings
15
from bliss import current_session
16

17

18
class _EventReceiver:
    """Relay event-dispatcher signals of one device to a writer callback.

    The instance itself is the receiver registered with ``connect``; each
    received signal results in
    ``callback(parent_entry, event_dict, signal, sender)``.
    """

    def __init__(self, device, parent_entry, callback):
        # Device whose signals are relayed; reset to None once disconnected.
        self.device = device
        # Passed as first argument of every callback invocation.
        self.parent_entry = parent_entry
        # Target function; events are silently ignored when not callable.
        self.callback = callback

    def __call__(self, event_dict=None, signal=None, sender=None):
        """Forward one event to the callback, prefixed by the parent entry."""
        handler = self.callback
        if not callable(handler):
            return
        handler(self.parent_entry, event_dict, signal, sender)

    def connect(self):
        """Subscribe to the device "start"/"end" signals and to the
        "new_data" signal of every device channel."""
        dev = self.device
        for sig_name in ("start", "end"):
            connect(dev, sig_name, self)
        for chan in dev.channels:
            connect(chan, "new_data", self)

    def disconnect(self):
        """Undo the subscriptions made by connect; a no-op once disconnected."""
        dev = self.device
        if dev is not None:
            for sig_name in ("start", "end"):
                disconnect(dev, sig_name, self)
            for chan in dev.channels:
                disconnect(chan, "new_data", self)
            self.device = None


44
class FileWriter:
    """Base class of scan data writers.

    A writer resolves its output location from path templates formatted with
    :attr:`template`. It creates the file and scan entries in :meth:`prepare`
    and relays acquisition-chain events to its callbacks through
    ``_EventReceiver`` objects. Subclasses provide the file format: they
    implement :meth:`new_file` and :meth:`new_scan` (both raise
    ``NotImplementedError`` here) and set :attr:`FILE_EXTENSION`.
    """

    # Extension joined to `data_filename` by the `filename` property. It is
    # None here, so subclasses must set it before `filename` can be used.
    FILE_EXTENSION = None

    def __init__(
        self,
        root_path,
        images_root_path,
        data_filename,
        master_event_callback=None,
        device_event_callback=None,
        connection=None,
        **_,
    ):
        """A default way to organize file structure.

        :param root_path: directory template, formatted with :attr:`template`
        :param images_root_path: image path template, formatted by
            :meth:`images_path` with ``scan_name``, ``img_acq_device`` and
            ``scan_number``
        :param data_filename: file name template (without extension),
            formatted with :attr:`template`
        :param master_event_callback: called as
            ``callback(master_entry, event_dict, signal, sender)`` for events
            of acquisition masters; defaults to :meth:`_master_event_callback`
        :param device_event_callback: same signature, for acquisition slaves;
            slaves are only connected when this is callable
        :param connection: forwarded to ``settings.HashObjSetting``, which
            stores the writer options
        """
        self._save_images = True
        self._root_path_template = root_path
        self._data_filename_template = data_filename
        # Values substituted into the root path / data filename templates.
        self._template_dict = {}
        self._images_root_path_template = images_root_path

        if master_event_callback is None:
            # Bound method captured before the instance attribute assigned
            # below shadows it.
            master_event_callback = self._master_event_callback
        self._master_event_callback = master_event_callback
        self._device_event_callback = device_event_callback
        self._event_receivers = []

        # Throttling state used by `_master_event_callback_tick`
        # (timestamp and period, in seconds).
        self._master_event_callback_time = time()
        self._master_event_callback_period = 3

        try:
            name = current_session.name
        except AttributeError:
            # current_session has no usable `name` (presumably no active
            # session): fall back to a shared default key.
            name = "default"
        db_name = f"writers:{self.writer_type()}:{name}"
        self.__options = settings.HashObjSetting(db_name, connection=connection)

    @classmethod
    def writer_type(cls):
        """Return the writer type: the last component of the class module name."""
        return cls.__module__.split(".")[-1]

    @property
    def options(self):
        """Persistent writer options, stored under the key
        ``writers:<writer_type>:<session name>``."""
        return self.__options

    def writer_options(self) -> dict:
        """Return all writer options as a dict."""
        return self.options.get_all()

    @property
    def template(self):
        """Values used to format the root path and data filename.

        The dict itself is returned (not a copy), so callers fill it in place.
        """
        return self._template_dict

    @property
    def root_path(self):
        """File directory"""
        return self._root_path_template.format(**self._template_dict)

    @property
    def data_filename(self):
        """File name without extension"""
        return self._data_filename_template.format(**self._template_dict)

    @property
    def filename(self):
        """Full file path: ``<root_path>/<data_filename>.<FILE_EXTENSION>``."""
        return os.path.join(
            self.root_path,
            os.path.extsep.join((self.data_filename, self.FILE_EXTENSION)),
        )

    def create_path(self, path: str) -> bool:
        """Create directory ``path`` (and its parents) if missing; return True."""
        os.makedirs(path, exist_ok=True)
        return True

    def new_file(self, scan_name, scan_info):
        """Create a new scan file

        Filename is stored in the class as the 'filename' property
        """
        raise NotImplementedError

    def finalize_scan_entry(self, scan):
        """Called at the end of a scan (no-op in the base class)."""
        pass

    def new_scan(self, scan_name, scan_info):
        """Create and return the entry of a (sub-)scan; implemented by subclasses."""
        raise NotImplementedError

    def new_master(self, master, scan_entry):
        """Return the entry used for ``master``; the base class reuses ``scan_entry``."""
        return scan_entry

    def prepare_saving(self, device, images_path):
        """Configure the image saving of ``device``.

        Saving goes to the directory/prefix derived from ``images_path``
        when image saving is enabled and at least one device channel is a
        2-D reference channel. Otherwise saving is explicitly disabled.
        """
        any_image = any(
            channel.reference and len(channel.shape) == 2 for channel in device.channels
        )
        if any_image and self._save_images:
            directory = os.path.dirname(images_path)
            prefix = os.path.basename(images_path)
            device.set_image_saving(directory, prefix)
        else:
            device.set_image_saving(None, None, force_no_saving=True)

    def _prepare_callbacks(self, device, master_entry, callback):
        """Connect ``callback`` to the events of ``device`` and keep track of
        the receiver so that :meth:`_remove_callbacks` can disconnect it."""
        ev_receiver = _EventReceiver(device, master_entry, callback)
        ev_receiver.connect()
        self._event_receivers.append(ev_receiver)

    def _remove_callbacks(self):
        """Disconnect every registered event receiver and forget them."""
        for ev_receiver in self._event_receivers:
            ev_receiver.disconnect()
        self._event_receivers = []

    def prepare(self, scan):
        """Create the file and scan entries and connect the event callbacks.

        Each ``AcquisitionMaster`` in ``scan.nodes`` gets the following:

        - a master entry from :meth:`new_master`;
        - the master event callback;
        - its image-saving configuration.

        Its ``AcquisitionSlave`` children get the device event callback when
        that callback is callable. Every top-level master after the first one
        gets its own sub-scan entry (named by :meth:`_subscan_name`).
        """
        self._master_event_callback_time = time()
        self.create_path(self.root_path)
        self.new_file(scan.node.name, scan.scan_info)
        scan_entry = self.new_scan(scan.node.name, scan.scan_info)

        # Disconnect receivers left over from an earlier preparation instead of
        # merely dropping them: dropped receivers would stay connected to their
        # devices and could no longer be disconnected by close().
        self._remove_callbacks()

        device_callback = self._device_event_callback
        scan_counter = itertools.count()
        for dev, _node in scan.nodes.items():
            if not isinstance(dev, AcquisitionMaster):
                continue

            if dev.parent is None:
                # top-level master
                scan_index = next(scan_counter)
                if scan_index > 0:
                    # multiple top-level masters: create a new sub-scan entry
                    scan_entry = self.new_scan(
                        self._subscan_name(scan.node.name, scan_index), scan.scan_info
                    )

            master_entry = self.new_master(dev, scan_entry)
            self._prepare_callbacks(dev, master_entry, self._master_event_callback)

            images_path = self.images_path(scan, img_acq_device=dev.name)
            self.prepare_saving(dev, images_path)

            if callable(device_callback):
                for slave in dev.slaves:
                    if isinstance(slave, AcquisitionSlave):
                        self._prepare_callbacks(slave, master_entry, device_callback)

    @staticmethod
    def _subscan_name(scan_name, index):
        """Return the name of the ``index``-th sub-scan of ``scan_name``.

        Convention: the scan number gets a ``.<index>`` suffix, e.g.
        ``"12_ascan"`` with index 1 gives ``"12.1_ascan"``.

        :raises ValueError: if ``scan_name`` has no ``_`` separating the scan
            number from the rest of the name
        """
        scan_number, sep, name = scan_name.partition("_")
        if not sep:
            raise ValueError(
                f"cannot derive a sub-scan name from {scan_name!r}: no '_' separator"
            )
        return f"{scan_number}.{index}_{name}"

    def images_path(self, scan, img_acq_device: str = "{img_acq_device}"):
        """Format the image root path template for ``scan``.

        When ``img_acq_device`` is not given, the ``{img_acq_device}``
        placeholder is kept as-is in the result.
        """
        return self._images_root_path_template.format(
            scan_name=scan.name,
            img_acq_device=img_acq_device,
            scan_number=scan.scan_number,
        )

    def close(self):
        """Disconnect all event callbacks registered by :meth:`prepare`."""
        self._remove_callbacks()

    def get_scan_entries(self):
        """
        Should return all scan entries from this path
        """
        return []

    @property
    def last_scan_number(self):
        """Scans start numbering from 1 so 0 indicates
        no scan exists in the file.
        """
        return 0

    @property
    def _master_event_callback_tick(self):
        """True at most once per ``_master_event_callback_period`` seconds.

        Reading this property while it is True also resets the reference
        timestamp, so the next True comes one full period later.
        """
        tm = time()
        tmmax = self._master_event_callback_time + self._master_event_callback_period
        if tm > tmmax:
            self._master_event_callback_time = tm
            return True
        return False

    def _master_event_callback(self, parent, event_dict, signal, sender):
        """Default master event callback: does nothing.

        Subclasses may override it; `_master_event_callback_tick` is presumably
        meant to throttle periodic work done here.
        """
        pass