# -*- coding: utf-8 -*-
#
# This file is part of the bliss project
#
# Copyright (c) 2015-2020 Beamline Control Unit, ESRF
# Distributed under the GNU LGPLv3. See LICENSE for more info.
"""
Provides helper to read scan_info.
"""
from __future__ import annotations
from typing import Any
from typing import Dict
13
from typing import List
14
from typing import Optional
15
from typing import MutableMapping
16
from typing import NamedTuple
17

18
import numpy
19
import weakref
20
21
import logging
from ..model import scan_model
22
23
from ..model import plot_model
from ..model import plot_item_model
24
from . import model_helper
25
from bliss.controllers.lima import roi as lima_roi
26
27


28
29
_logger = logging.getLogger(__name__)

30
31
32

class ChannelInfo(NamedTuple):
    """Description of a single channel, as yielded by `iter_channels`."""

    # Full name of the channel
    name: str
    # Description of the channel, read from the `channels` field of the scan_info
    info: Dict
    # Device part of the channel name (text before the last `:`), or None
    # when the channel name contains no `:`
    device: Optional[str]
    # Name of the top master the channel was found under, or "custom" when
    # the channel is not reachable from the acquisition chain
    master: str

37

Valentin Valls's avatar
Valentin Valls committed
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Mapping from the scan type (`type` field of the scan_info) to its category.
# Built from the category groups to make each family explicit.
_SCAN_CATEGORY = {
    scan_type: category
    for category, scan_types in (
        # A single measurement
        ("point", ("ct",)),
        # Many measurements
        (
            "nscan",
            (
                "timescan",
                "loopscan",
                "lookupscan",
                "pointscan",
                "ascan",
                "a2scan",
                "a3scan",
                "a4scan",
                "anscan",
                "dscan",
                "d2scan",
                "d3scan",
                "d4scan",
                "dnscan",
            ),
        ),
        # Many measurements using 2 correlated axes
        ("mesh", ("amesh", "dmesh")),
    )
    for scan_type in scan_types
}


def get_scan_category(scan_info: Dict = None, scan_type: str = None) -> Optional[str]:
    """
    Returns the category of a scan.

    Arguments:
        scan_info: Scan info dict. When provided, its `type` field is used
            and `scan_type` is ignored.
        scan_type: Type of the scan, only used when `scan_info` is None.

    Returns:
        One of "point", "nscan", "mesh" or None if nothing matches.
    """
    kind = scan_type if scan_info is None else scan_info.get("type", None)
    return _SCAN_CATEGORY.get(kind, None)

73

74
75
76
def _get_channels(
    scan_info: Dict,
    top_master_name: Optional[str] = None,
    dim: Optional[int] = None,
    master: Optional[bool] = None,
) -> List[str]:
    """
    Returns channels from top_master_name and optionally filtered by dim and master.

    The channels are listed following the order of the devices of the
    acquisition chain, and the order of the channels of each device.

    Missing `acquisition_chain`, `devices` or `channels` fields are read as
    empty descriptions.

    Arguments:
        scan_info: Scan info dict
        top_master_name: If not None, a specific top master is read
        dim: If not None, only includes the channels with the requested dim
            (a channel without `dim` description is read as dim 0)
        master: If not None, only includes channels from a master / or not

    Returns:
        The list of the selected channel names.
    """
    acquisition_chain = scan_info.get("acquisition_chain", {})
    devices_description = scan_info.get("devices", {})
    channels_description = scan_info.get("channels", {})

    names: List[str] = []

    # This counter is shared by every top master which is iterated: only the
    # first triggering device found during the whole iteration is seen as
    # a master
    master_count = 0
    for top_master, meta in acquisition_chain.items():
        if top_master_name is not None and top_master != top_master_name:
            # If the filter mismatch
            continue
        for device_name in meta["devices"]:
            device_info = devices_description.get(device_name, None)
            if device_info is None:
                continue

            if master is not None:
                is_triggering = "triggered_devices" in device_info
                if is_triggering:
                    master_count += 1
                is_master = is_triggering and master_count == 1
                if master != is_master:
                    # If the filter mismatch
                    continue

            for c in device_info.get("channels", []):
                if dim is not None:
                    if channels_description.get(c, {}).get("dim", 0) != dim:
                        # If the filter mismatch
                        continue
                names.append(c)

    return names
119
120


121
def iter_channels(scan_info: Dict[str, Any]):
    """
    Iterate the channels described by a scan_info.

    The channels reachable from the acquisition chain are yielded first, top
    master after top master. Then the channels which are only described in
    the `channels` field are yielded, attached to a master named `"custom"`.

    Arguments:
        scan_info: Scan info dict

    Yields:
        A `ChannelInfo` per channel. Its `info` is an empty dict when the
        channel has no description, and its `device` is None when the channel
        name contains no `:`.
    """
    acquisition_chain_description = scan_info.get("acquisition_chain", {})
    channels_description = scan_info.get("channels", {})
    if not isinstance(channels_description, dict):
        # Checked before any use, else the lookups below would fail
        _logger.warning("scan_info.channels is not a dict")
        channels_description = {}

    def get_device_from_channel_name(channel_name):
        """Returns the device name from the channel name, else None"""
        if ":" in channel_name:
            return channel_name.rsplit(":", 1)[0]
        return None

    # Names already yielded from the acquisition chain
    channels = set()

    for master_name in acquisition_chain_description:
        for channel_name in _get_channels(scan_info, master_name):
            info = channels_description.get(channel_name, {})
            device_name = get_device_from_channel_name(channel_name)
            yield ChannelInfo(channel_name, info, device_name, master_name)
            channels.add(channel_name)

    for channel_name, info in channels_description.items():
        if channel_name in channels:
            continue
        device_name = get_device_from_channel_name(channel_name)
        # FIXME: For now, let say everything is scalar here
        yield ChannelInfo(channel_name, info, device_name, "custom")
154
155


156
157
158
class ScanModelReader:
    """Object reading a scan_info and generating a scan model

    A reader is made for a single use: `parse` can only be called once.
    """

    # Mapping from the `type` field of a device description to the model type.
    # An unlisted type is read as `DeviceType.UNKNOWN`.
    DEVICE_TYPES = {
        None: scan_model.DeviceType.NONE,
        "lima": scan_model.DeviceType.LIMA,
        "mca": scan_model.DeviceType.MCA,
    }

    def __init__(self, scan_info):
        """
        Arguments:
            scan_info: Scan info dict. Missing `acquisition_chain`, `devices`
                and `channels` fields are read as empty descriptions.
        """
        self._scan_info = scan_info
        self._acquisition_chain_description = scan_info.get("acquisition_chain", {})
        self._device_description = scan_info.get("devices", {})
        self._channel_description = scan_info.get("channels", {})

        is_group = scan_info.get("is-scan-sequence", False)
        if is_group:
            scan = scan_model.ScanGroup()
        else:
            scan = scan_model.Scan()

        scan.setScanInfo(scan_info)
        self._scan = scan
        # Names of the devices already handled, to not parse them twice
        self._parsed_devices = set()

    def parse(self):
        """Parse the whole scan info and return scan model"""
        assert self._scan is not None, "The scan was already parsed"
        self._parse_scan()
        self._precache_scatter_constraints()
        scan = self._scan
        # The reader releases the scan, which forbids a second parsing
        self._scan = None
        scan.seal()
        return scan

    def _parse_scan(self):
        """Parse the whole scan structure"""
        for top_master_name, meta in self._acquisition_chain_description.items():
            self._parse_top_device(top_master_name, meta)

    def _parse_top_device(self, name, meta) -> scan_model.Device:
        """Create the device of a top master and parse the devices it lists.

        Arguments:
            name: Name of the top master
            meta: Description of the top master from the acquisition chain.
                It must contain a `devices` field.

        Returns:
            The device created for the top master.
        """
        top_master = scan_model.Device(self._scan)
        top_master.setName(name)

        sub_device_names = meta["devices"]

        for i, sub_device_name in enumerate(sub_device_names):
            if sub_device_name in self._parsed_devices:
                continue
            self._parsed_devices.add(sub_device_name)
            sub_meta = self._device_description.get(sub_device_name, None)
            if sub_meta is None:
                _logger.error(
                    "scan_info mismatch. Device name %s metadata not found",
                    sub_device_name,
                )
                continue
            sub_name = sub_device_name.rsplit(":", 1)[-1]
            # The first listed device is parsed without its sub devices
            if i == 0:
                parser_class = self.TopDeviceParser
            else:
                parser_class = None
            self._parse_device(
                sub_name, sub_meta, parent=top_master, parser_class=parser_class
            )

        return top_master

    class DefaultDeviceParser:
        """Parse a device description: the device, its sub devices and its
        channels."""

        def __init__(self, reader):
            self.reader = reader

        def parse(self, name, meta, parent):
            """Create the device `name` under `parent` and parse its content"""
            device = self.create_device(name, meta, parent)
            self.parse_sub_devices(device, meta)
            self.parse_channels(device, meta)

        def create_device(self, name, meta, parent):
            """Create and return the device described by `meta`"""
            device = scan_model.Device(self.reader._scan)
            device.setName(name)
            device.setMaster(parent)
            device_type = meta.get("type", None)
            device_type = self.reader.DEVICE_TYPES.get(
                device_type, scan_model.DeviceType.UNKNOWN
            )
            device.setType(device_type)
            metadata = scan_model.DeviceMetadata(info=meta, roi=None)
            device.setMetadata(metadata)
            return device

        def parse_sub_devices(self, device, meta):
            """Parse the devices listed by the `triggered_devices` field"""
            device_ids = meta.get("triggered_devices", [])
            for device_id in device_ids:
                self.reader._parsed_devices.add(device_id)
                sub_meta = self.reader._device_description.get(device_id, None)
                if sub_meta is None:
                    _logger.error(
                        "scan_info mismatch. Device name %s metadata not found",
                        device_id,
                    )
                    continue
                sub_name = device_id.rsplit(":", 1)[-1]
                self.reader._parse_device(sub_name, sub_meta, parent=device)

        def get_channel_meta(self, channel_fullname: str):
            """Returns the description of a channel.

            If the scan_info does not provide it, an error is logged and None
            is returned.
            """
            channel_meta = self.reader._channel_description.get(
                channel_fullname, None
            )
            if channel_meta is None:
                _logger.error(
                    "scan_info mismatch. Channel name %s metadata not found",
                    channel_fullname,
                )
            return channel_meta

        def parse_channels(self, device: scan_model.Device, meta):
            """Create the channels of the device.

            If the device description contains a `xaxis_array`, a virtual
            spectrum channel already fed with this data is also created.
            """
            channel_names = meta.get("channels", [])
            for channel_fullname in channel_names:
                channel_meta = self.get_channel_meta(channel_fullname)
                if channel_meta is None:
                    continue
                self.parse_channel(channel_fullname, channel_meta, parent=device)

            xaxis_array = meta.get("xaxis_array", None)
            if xaxis_array is not None:
                # Create a virtual channel already feed with data
                try:
                    xaxis_array = numpy.array(xaxis_array)
                    if len(xaxis_array.shape) != 1:
                        raise RuntimeError("scan_info xaxis_array expect a 1D data")
                except Exception:
                    # Best effort: a wrong description is replaced by no data
                    _logger.warning(
                        "scan_info contains wrong xaxis_array: %s", xaxis_array
                    )
                    xaxis_array = numpy.array([])

                unit = meta.get("xaxis_array_unit", None)
                label = meta.get("xaxis_array_label", None)
                channel = scan_model.Channel(device)
                channel.setType(scan_model.ChannelType.SPECTRUM)
                if unit is not None:
                    channel.setUnit(unit)
                if label is not None:
                    channel.setDisplayName(label)
                data = scan_model.Data(array=xaxis_array)
                channel.setData(data)
                fullname = device.name()
                channel.setName(f"{fullname}:#:xaxis_array")

        def parse_channel(self, channel_fullname: str, meta, parent: scan_model.Device):
            """Create the channel `channel_fullname` under the device `parent`"""
            channel = scan_model.Channel(parent)
            channel.setName(channel_fullname)

            # protect mutation of the original object, with the following `pop`
            meta = dict(meta)

            # FIXME: This have to be cleaned up (unit and display name are part of the metadata)
            unit = meta.pop("unit", None)
            if unit is not None:
                channel.setUnit(unit)
            display_name = meta.pop("display_name", None)
            if display_name is not None:
                channel.setDisplayName(display_name)

            metadata = parse_channel_metadata(meta)
            channel.setMetadata(metadata)

    class TopDeviceParser(DefaultDeviceParser):
        """Parser used for the first device of a top master"""

        def parse_sub_devices(self, device, meta):
            # Ignore sub devices to make it a bit more flat
            pass

    class LimaRoiDeviceParser(DefaultDeviceParser):
        """Parser for the `roi_counters` and `roi_profiles` devices of a Lima
        device. Channels are grouped into a virtual device per ROI."""

        def parse_channels(self, device: scan_model.Device, meta: Dict):
            """Create the virtual ROI devices, then the channels"""
            # cache virtual roi devices
            virtual_rois = {}

            # FIXME: It would be good to have a real ROI concept in BLISS
            # Here we iterate the set of metadata to try to find something interesting
            for roi_name, roi_dict in meta.items():
                if not isinstance(roi_dict, dict):
                    continue
                if "kind" not in roi_dict:
                    continue
                roi_device = self.create_virtual_roi(roi_name, roi_dict, device)
                virtual_rois[roi_name] = roi_device

            def get_virtual_roi(channel_fullname):
                """Retrieve roi device from channel name"""
                short_name = channel_fullname.rsplit(":", 1)[-1]

                # The ROI name is the text before the last `_`, if there is one
                if "_" in short_name:
                    roi_name, _ = short_name.rsplit("_", 1)
                else:
                    roi_name = short_name

                return virtual_rois.get(roi_name, None)

            channel_names = meta.get("channels", [])
            for channel_fullname in channel_names:
                channel_meta = self.get_channel_meta(channel_fullname)
                if channel_meta is None:
                    continue
                roi_device = get_virtual_roi(channel_fullname)
                # Without matching ROI the channel stays on the device itself
                if roi_device is not None:
                    parent_channel = roi_device
                else:
                    parent_channel = device
                self.parse_channel(
                    channel_fullname, channel_meta, parent=parent_channel
                )

        def create_virtual_roi(self, roi_name, roi_dict, parent):
            """Create and return a virtual device for a ROI.

            If the ROI description cannot be read, a warning is logged and
            the device metadata holds no ROI.
            """
            device = scan_model.Device(self.reader._scan)
            device.setName(roi_name)
            device.setMaster(parent)
            device.setType(scan_model.DeviceType.VIRTUAL_ROI)

            # Read metadata
            roi = None
            if roi_dict is not None:
                try:
                    roi = lima_roi.dict_to_roi(roi_dict)
                except Exception:
                    _logger.warning(
                        "Error while reading roi '%s' from '%s'",
                        roi_name,
                        device.fullName(),
                        exc_info=True,
                    )

            metadata = scan_model.DeviceMetadata({}, roi)
            device.setMetadata(metadata)
            return device

    class McaDeviceParser(DefaultDeviceParser):
        """Parser for MCA devices. Channels are grouped into a virtual device
        per detector."""

        def parse_channels(self, device, meta):
            """Create the channels, each one under a virtual detector device"""
            # cache virtual detector devices
            virtual_detectors = {}

            def get_virtual_detector(channel_fullname):
                """Some magic to create virtual device for each ROIs"""
                short_name = channel_fullname.rsplit(":", 1)[-1]

                # FIXME: It would be good to have a real detector concept in BLISS
                # The detector name is the text after the last `_`, if there is one
                if "_" in short_name:
                    _, detector_name = short_name.rsplit("_", 1)
                else:
                    detector_name = short_name

                key = f"{device.name()}:{detector_name}"
                if key in virtual_detectors:
                    return virtual_detectors[key]

                detector_device = scan_model.Device(self.reader._scan)
                detector_device.setName(detector_name)
                detector_device.setMaster(device)
                detector_device.setType(scan_model.DeviceType.VIRTUAL_MCA_DETECTOR)
                virtual_detectors[key] = detector_device
                return detector_device

            channel_names = meta.get("channels", [])
            for channel_fullname in channel_names:
                channel_meta = self.get_channel_meta(channel_fullname)
                if channel_meta is None:
                    continue
                roi_device = get_virtual_detector(channel_fullname)
                self.parse_channel(channel_fullname, channel_meta, parent=roi_device)

    def _parse_device(
        self, name: str, meta: Dict, parent: scan_model.Device, parser_class=None
    ):
        """Parse a device description with the parser matching the device.

        Arguments:
            name: Name of the device
            meta: Description of the device
            parent: Device owning this device
            parser_class: Parser to use. If None it is selected from the
                `type` field of the description. In any case the ROI devices
                of a Lima device use the `LimaRoiDeviceParser`.
        """
        if parent.type() == scan_model.DeviceType.LIMA:
            if name in ("roi_counters", "roi_profiles"):
                parser_class = self.LimaRoiDeviceParser
        if parser_class is None:
            device_type = meta.get("type")
            if device_type == "mca":
                parser_class = self.McaDeviceParser
            else:
                parser_class = self.DefaultDeviceParser

        node_parser = parser_class(self)
        node_parser.parse(name, meta, parent=parent)

    def _precache_scatter_constraints(self):
        """Precache information about group of data and available scatter axis"""
        scan = self._scan
        scatterDataDict: Dict[str, scan_model.ScatterData] = {}
        for device in scan.devices():
            for channel in device.channels():
                metadata = channel.metadata()
                if metadata.group is not None:
                    scatterData = scatterDataDict.get(metadata.group, None)
                    if scatterData is None:
                        scatterData = scan_model.ScatterData()
                        scatterDataDict[metadata.group] = scatterData
                    # A channel describing an axis kind or an axis id is an axis
                    if metadata.axisKind is not None or metadata.axisId is not None:
                        scatterData.addAxisChannel(channel, metadata.axisId)
                    else:
                        scatterData.addCounterChannel(channel)

        for scatterData in scatterDataDict.values():
            scan.addScatterData(scatterData)
468
469


470
471
472
473
def create_scan_model(scan_info: Dict) -> scan_model.Scan:
    """Create a scan model from a scan_info.

    The scan_info is parsed with a `ScanModelReader`.
    """
    return ScanModelReader(scan_info).parse()
474
475


476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
def _pop_and_convert(meta, key, func):
    """Remove `key` from `meta` and return its value converted with `func`.

    Arguments:
        meta: Mutable mapping. The key is removed from it in any case.
        key: Key to read
        func: Callable converting the raw value

    Returns:
        The converted value, or None if the key is missing, if its value is
        None, or if the value cannot be converted (a warning is then logged).
    """
    value = meta.pop(key, None)
    if value is None:
        return None
    try:
        return func(value)
    except (ValueError, TypeError):
        # TypeError is raised when the value has an unexpected type, for
        # example `int([1])`. It is ignored the same way as a wrong content.
        _logger.warning("%s %s is not a valid value. Field ignored.", key, value)
        return None


def parse_channel_metadata(meta: Dict) -> scan_model.ChannelMetadata:
    """Create the channel metadata from the description of a channel.

    The fields which cannot be converted are ignored, and a warning is logged
    for each unknown key. The provided dict is not modified.

    Arguments:
        meta: Description of the channel from the scan_info

    Returns:
        The parsed metadata. A missing or ignored field is set to None.
    """
    # Work on a copy, as each handled key is removed from the dict
    meta = meta.copy()

    # Compatibility Bliss 1.0
    if "axes-points" in meta and "axis-points" not in meta:
        _logger.warning("Metadata axes-points have to be replaced by axis-points.")
        meta["axis-points"] = meta.pop("axes-points")
    if "axes-kind" in meta and "axis-kind" not in meta:
        _logger.warning("Metadata axes-kind have to be replaced by axis-kind.")
        meta["axis-kind"] = meta.pop("axes-kind")

    start = _pop_and_convert(meta, "start", float)
    stop = _pop_and_convert(meta, "stop", float)
    vmin = _pop_and_convert(meta, "min", float)
    vmax = _pop_and_convert(meta, "max", float)
    points = _pop_and_convert(meta, "points", int)
    axisPoints = _pop_and_convert(meta, "axis-points", int)
    axisPointsHint = _pop_and_convert(meta, "axis-points-hint", int)
    axisKind = _pop_and_convert(meta, "axis-kind", scan_model.AxisKind)
    axisId = _pop_and_convert(meta, "axis-id", int)
    group = _pop_and_convert(meta, "group", str)
    dim = _pop_and_convert(meta, "dim", int)

    # Compatibility code with existing user scripts written for BLISS 1.4
    # The fast/slow kinds are converted into an axis id and a forth/backnforth kind
    mapping = {
        scan_model.AxisKind.FAST: (0, scan_model.AxisKind.FORTH),
        scan_model.AxisKind.FAST_BACKNFORTH: (0, scan_model.AxisKind.BACKNFORTH),
        scan_model.AxisKind.SLOW: (1, scan_model.AxisKind.FORTH),
        scan_model.AxisKind.SLOW_BACKNFORTH: (1, scan_model.AxisKind.BACKNFORTH),
    }
    if axisKind in mapping:
        if axisId is not None:
            _logger.warning(
                "Both axis-id and axis-kind with flat/slow is used. axis-id will be ignored"
            )
        axisId, axisKind = mapping[axisKind]

    # Every known key was removed, the remaining ones are unknown
    for key in meta:
        _logger.warning("Metadata key %s is unknown. Field ignored.", key)

    return scan_model.ChannelMetadata(
        start,
        stop,
        vmin,
        vmax,
        points,
        axisId,
        axisPoints,
        axisKind,
        group,
        axisPointsHint,
        dim,
    )


543
544
545
546
547
def get_device_from_channel(channel_name) -> str:
    """Return the part of the channel name located before the first `:`.

    The whole name is returned when it contains no `:`.
    """
    device_name, _, _ = channel_name.partition(":")
    return device_name


548
549
550
551
552
553
554
def _select_default_counter(scan, plot):
    """Select a default counter if needed.

    For each scatter item of the plot having an axis but no value, a value
    channel is picked from the counters of the scatter data of the axis, else
    from the scalar channels of the acquisition chain which are not from a
    master. Channels with a unit different from "s" are preferred.

    Arguments:
        scan: Scan providing the channels
        plot: Plot to update
    """

    def is_quantity(channel_name: str) -> bool:
        """True if the channel is not known as a time (unit "s")"""
        channel = scan.getChannelByName(channel_name)
        if channel is None:
            # NOTE(review): assumes `getChannelByName` returns None for a
            # name which is not part of the scan - TODO confirm. Such a
            # channel is handled like a channel without unit.
            return True
        return channel.unit() != "s"

    for item in plot.items():
        if not isinstance(item, plot_item_model.ScatterItem):
            continue
        if item.valueChannel() is not None:
            continue

        # If there is an axis but no value
        # Pick a value
        axisChannelRef = item.xChannel()
        if axisChannelRef is None:
            axisChannelRef = item.yChannel()
        if axisChannelRef is None:
            continue
        axisChannel = axisChannelRef.channel(scan)

        scatterData = scan.getScatterDataByChannel(axisChannel)
        names: List[str]
        if scatterData is not None:
            names = [c.name() for c in scatterData.counterChannels()]
        else:
            names = []
            acquisition_chain = scan.scanInfo().get("acquisition_chain", None)
            if acquisition_chain is not None:
                for master_name in acquisition_chain.keys():
                    counter_scalars = _get_channels(
                        scan.scanInfo(), master_name, master=False, dim=0
                    )
                    names.extend(counter_scalars)

        if len(names) == 0:
            continue

        # Try to use a default counter which is not an elapse time
        quantityNames = [n for n in names if is_quantity(n)]
        if len(quantityNames) > 0:
            names = quantityNames
        channelRef = plot_model.ChannelRef(plot, names[0])
        item.setValueChannel(channelRef)
585
586


587
588
589
590
591
592
class DisplayExtra(NamedTuple):
    """Content read from the `_display_extra` field of a scan_info, as
    returned by `parse_display_extra`."""

    # `displayed_channels` list, or None when it is missing or invalid
    displayed_channels: Optional[List[str]]
    # `plotselect` list, or None when it is missing or invalid
    plotselect: Optional[List[str]]


def parse_display_extra(scan_info: Dict) -> DisplayExtra:
    """Return the list of the displayed channels stored in the scan

    The `displayed_channels` and `plotselect` keys are read from the
    `_display_extra` field of the scan_info. A value which is not a list of
    strings is ignored, with a warning.

    Arguments:
        scan_info: Scan info dict

    Returns:
        A `DisplayExtra` in which each missing or invalid key is None.
    """

    def parse_optional_list_of_string(data, name):
        """Sanitize data from scan_info protocol"""
        if data is None:
            return None

        if not isinstance(data, list):
            _logger.warning("%s is not a list: Key ignored", name)
            return None

        if not all(isinstance(i, str) for i in data):
            _logger.warning("%s must only contains strings: Key ignored", name)
            return None

        return data

    display_extra = scan_info.get("_display_extra", None)
    if display_extra is None:
        return DisplayExtra(None, None)
    if not isinstance(display_extra, dict):
        # A wrong content is ignored like the other invalid keys
        _logger.warning("_display_extra is not a dict: Key ignored")
        return DisplayExtra(None, None)

    displayed_channels = parse_optional_list_of_string(
        display_extra.get("displayed_channels", None),
        "_display_extra.displayed_channels",
    )
    plotselect = parse_optional_list_of_string(
        display_extra.get("plotselect", None), "_display_extra.plotselect"
    )
    return DisplayExtra(displayed_channels, plotselect)
622
623


624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
def removed_same_plots(plots, remove_plots) -> List[plot_model.Plot]:
    """Returns a new list with the plots from `plots` which have no plot
    with the same target in `remove_plots`.

    The order of `plots` is preserved.
    """
    if remove_plots == []:
        return list(plots)
    return [
        plot
        for plot in plots
        if not any(plot.hasSameTarget(other) for other in remove_plots)
    ]


640
641
642
def create_plot_model(
    scan_info: Dict, scan: Optional[scan_model.Scan] = None
) -> List[plot_model.Plot]:
    """Create plot models from a scan_info.

    Use the `plots` description or infer the plots from the `acquisition_chain`.
    Finally update the selection using `_display_extra`.

    Arguments:
        scan_info: Scan info dict. It is not modified.
        scan: Scan model of this scan_info. If None, it is created from the
            scan_info.

    Returns:
        The list of the created plots.
    """
    if scan is None:
        scan = create_scan_model(scan_info)

    if "plots" in scan_info:
        plots = read_plot_models(scan_info)
        for plot in plots:
            _select_default_counter(scan, plot)

        # Complete the described plots with the inferred plots, except if
        # a plot with the same target is already there
        for plot in infer_plot_models(scan):
            if not any(p.hasSameTarget(plot) for p in plots):
                plots.append(plot)
    else:
        plots = infer_plot_models(scan)

    def filter_with_scan_content(channel_names):
        """Returns the channel names which are part of the scan.

        The order is preserved and the input list is not modified. None is
        returned if there is no input, or if every name was removed.
        """
        if channel_names is None:
            return None
        # Filter selection by available channels
        available = set(scan.getChannelNames())
        kept = []
        for name in channel_names:
            if name in available:
                kept.append(name)
            else:
                _logger.warning(
                    "Skip display of channel '%s' from scan_info. Not part of the scan",
                    name,
                )
        if len(kept) == 0 and len(channel_names) != 0:
            return None
        return kept

    display_extra = parse_display_extra(scan_info)
    displayed_channels = filter_with_scan_content(display_extra.displayed_channels)

    # The plotselect is only filtered if a curve plot needs it, and only
    # once, to not log the same warnings for each plot
    plotselect = None
    plotselect_filtered = False

    for plot in plots:
        channel_names = None
        if isinstance(plot, plot_item_model.CurvePlot):
            if displayed_channels is None:
                if not plotselect_filtered:
                    plotselect = filter_with_scan_content(display_extra.plotselect)
                    plotselect_filtered = True
                channel_names = plotselect
            else:
                channel_names = displayed_channels
        elif isinstance(plot, plot_item_model.ScatterPlot):
            if displayed_channels:
                channel_names = displayed_channels
        if channel_names:
            model_helper.updateDisplayedChannelNames(plot, scan, channel_names)

    return plots
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731


def read_plot_models(scan_info: Dict) -> List[plot_model.Plot]:
    """Read description of plot models from a scan_info

    Only the plots of kind `scatter-plot` containing items of kind `scatter`
    are supported. Anything else is skipped with a warning.

    Arguments:
        scan_info: Scan info dict

    Returns:
        The list of the plots which were read. It is empty if the `plots`
        field is not a list.
    """
    result: List[plot_model.Plot] = []

    plots = scan_info.get("plots", None)
    if not isinstance(plots, list):
        return []

    for plot_description in plots:
        if not isinstance(plot_description, dict):
            _logger.warning("Plot description is not a dict. Skipped.")
            continue

        kind = plot_description.get("kind", None)
        if kind != "scatter-plot":
            _logger.warning("Kind %s unsupported. Skipped.", kind)
            continue

        plot = plot_item_model.ScatterPlot()

        name = plot_description.get("name", None)
        if name is not None:
            plot.setName(name)

        items = plot_description.get("items", None)
        if not isinstance(items, list):
            _logger.warning("'items' not using the right type. List expected. Ignored.")
            items = []

        for item_description in items:
            if not isinstance(item_description, dict):
                # Same handling as a wrong plot description
                _logger.warning("Item description is not a dict. Item ignored.")
                continue

            kind = item_description.get("kind", None)
            if kind != "scatter":
                _logger.warning("Item 'kind' %s unsupported. Item ignored.", kind)
                continue

            item = plot_item_model.ScatterItem(plot)

            # Each channel is optional
            xname = item_description.get("x", None)
            if xname is not None:
                x_channel = plot_model.ChannelRef(plot, xname)
                item.setXChannel(x_channel)
            yname = item_description.get("y", None)
            if yname is not None:
                y_channel = plot_model.ChannelRef(plot, yname)
                item.setYChannel(y_channel)
            valuename = item_description.get("value", None)
            if valuename is not None:
                value_channel = plot_model.ChannelRef(plot, valuename)
                item.setValueChannel(value_channel)
            plot.addItem(item)
        result.append(plot)

    return result


765
766
767
768
def _infer_default_curve_plot(
    scan_info: Dict, have_scatter: bool
) -> Optional[plot_model.Plot]:
    """Build the default curve plot from the content of the acquisition chain.

    A single `CurvePlot` is created and filled while iterating over every
    master of `scan_info["acquisition_chain"]`.

    - With `have_scatter`, the x-axis is the first non-master scalar channel
      whose unit is `"s"`. The first non-master scalar channel with another
      unit is plotted on the left axis and every master channel on the right
      axis. When no time channel is found, nothing is added for this master.
    - Without `have_scatter`, only the first displayable scalar channel is
      plotted, against the first master channel (or without x-channel when
      the master exposes none).

    Arguments:
        scan_info: Scan description, read through its `acquisition_chain`
            and `channels` keys.
        have_scatter: True if a scatter is already used as main plot.

    Returns:
        The curve plot, which may contain no item.
    """
    plot = plot_item_model.CurvePlot()

    def unit_of(name: str) -> Optional[str]:
        # Looked up lazily, only for the channels which are really inspected
        return scan_info["channels"][name].get("unit", None)

    def add_curve(x_name: str, y_name: str, y_axis: str):
        """Append to the plot a curve `y_name` = f(`x_name`) on `y_axis`."""
        curve = plot_item_model.CurveItem(plot)
        x_ref = plot_model.ChannelRef(plot, x_name)
        y_ref = plot_model.ChannelRef(plot, y_name)
        curve.setXChannel(x_ref)
        curve.setYChannel(y_ref)
        curve.setYAxis(y_axis)
        plot.addItem(curve)

    chain = scan_info.get("acquisition_chain", None)

    for master_name in chain.keys():
        scalars = _get_channels(scan_info, master_name, dim=0, master=False)
        master_channels = _get_channels(scan_info, master_name, dim=0, master=True)

        if have_scatter:
            # With a scatter, the curve plot displays the time in x,
            # the first value in the left axis and the masters in the
            # right axis

            # First scalar which is not a master and which is a time base
            timer = next(
                (
                    name
                    for name in scalars
                    if name not in master_channels and unit_of(name) == "s"
                ),
                None,
            )
            # First scalar which is neither a master nor a time base
            scalar = next(
                (
                    name
                    for name in scalars
                    if name not in master_channels and unit_of(name) != "s"
                ),
                None,
            )

            if timer is None:
                # Without time base nothing is plotted for this master
                continue

            if scalar is not None:
                add_curve(timer, scalar, "left")
            for channel_name in master_channels:
                add_curve(timer, channel_name, "right")
            continue

        has_master_channel = len(master_channels) > 0

        # A motor scan is driven by an axis which is not a time base
        is_motor_scan = (
            has_master_channel
            and master_channels[0].startswith("axis:")
            and unit_of(master_channels[0]) != "s"
        )

        # Only the first counter is displayed. The time base is not
        # displayed for motor based scans
        first_counter = next(
            (
                name
                for name in scalars
                if not (is_motor_scan and unit_of(name) == "s")
            ),
            None,
        )
        if first_counter is None:
            continue

        curve = plot_item_model.CurveItem(plot)
        y_ref = plot_model.ChannelRef(plot, first_counter)
        if has_master_channel:
            x_ref = plot_model.ChannelRef(plot, master_channels[0])
        else:
            x_ref = None
        curve.setXChannel(x_ref)
        curve.setYChannel(y_ref)
        plot.addItem(curve)

    return plot
856

857

858
859
860
861
def _infer_default_scatter_plot(scan_info: Dict) -> List[plot_model.Plot]:
    """Build one scatter plot per master of the acquisition chain.

    For each master, the first and the second master channels are used as
    x and y axes, and the first scalar channel which is neither a master
    channel nor a time base (unit `"s"`) is used as value. Each of these
    channels is left unset (`None`) when it is not available.

    Arguments:
        scan_info: Scan description, read through its `acquisition_chain`
            and `channels` keys.

    Returns:
        The created scatter plots, in the order of the acquisition chain.
    """
    result: List[plot_model.Plot] = []

    def unit_of(name: str) -> Optional[str]:
        return scan_info["channels"][name].get("unit", None)

    def ref_at(plot, names, index):
        """Returns a reference to `names[index]`, else None if it is missing."""
        if len(names) > index:
            return plot_model.ChannelRef(plot, names[index])
        return None

    chain = scan_info.get("acquisition_chain", None)

    for master_name in chain.keys():
        scatter_plot = plot_item_model.ScatterPlot()

        scalars = _get_channels(scan_info, master_name, dim=0, master=False)
        axes_channels = _get_channels(scan_info, master_name, dim=0, master=True)

        # First scalar which is neither a master nor a time base
        value_name = next(
            (
                name
                for name in scalars
                if name not in axes_channels and unit_of(name) != "s"
            ),
            None,
        )

        x_ref = ref_at(scatter_plot, axes_channels, 0)
        y_ref = ref_at(scatter_plot, axes_channels, 1)
        if value_name is None:
            value_ref = None
        else:
            value_ref = plot_model.ChannelRef(scatter_plot, value_name)

        scatter = plot_item_model.ScatterItem(scatter_plot)
        scatter.setXChannel(x_ref)
        scatter.setYChannel(y_ref)
        scatter.setValueChannel(value_ref)
        scatter_plot.addItem(scatter)
        result.append(scatter_plot)

    return result


911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
def _initialize_image_plot_from_device(device: scan_model.Device) -> plot_model.Plot:
    """Initialize ImagePlot with default information which can be used
    structurally.

    The plot is tagged with a device name derived from `device`. If the
    device type is LIMA, a `RoiItem` is added for each `VIRTUAL_ROI` device
    found under its `roi_counters` and `roi_profiles` sub-devices.

    Arguments:
        device: Device from the scan model the image plot is created for.

    Returns:
        The initialized image plot.
    """
    plot = plot_item_model.ImagePlot()

    # Reach a name which is stable between 2 scans
    # FIXME: This have to be provided by the scan_info
    def get_stable_name(device):
        for channel in device.channels():
            # Only the first channel is used: its name without the last
            # ":" separated element
            name = channel.name()
            return name.rsplit(":", 1)[0]
        # No channel: use the full name without its first ":" separated
        # element. `[-1]` returns the whole name when there is no separator,
        # where `[1]` was raising an IndexError.
        return device.fullName().split(":", 1)[-1]

    stable_name = get_stable_name(device)
    plot.setDeviceName(stable_name)

    if device.type() == scan_model.DeviceType.LIMA:
        for sub_device in device.devices():
            if sub_device.name() in ("roi_counters", "roi_profiles"):
                for roi_device in sub_device.devices():
                    # Other kinds of devices are not displayed as ROI
                    if roi_device.type() != scan_model.DeviceType.VIRTUAL_ROI:
                        continue
                    item = plot_item_model.RoiItem(plot)
                    item.setDeviceName(roi_device.fullName())
                    plot.addItem(item)
    return plot


def infer_plot_models(scan: scan_model.Scan) -> List[plot_model.Plot]:
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
    """Infer description of plot models from a scan_info using
    `acquisition_chain`.

    - Dedicated default plot is created for 0D channels according to the kind
      of scan. It could be:
        - ct plot
        - curve plot
        - scatter plot
    - A dedicated image plot is created per lima detectors
    - A dedicated MCA plot is created per mca detectors
    - Remaining 2D channels are displayed as an image widget
    - Remaining 1D channels are displayed as a 1D plot
    """
    result: List[plot_model.Plot] = []

    default_plot = None
956
    scan_info = scan.scanInfo()
957
958
959
960
961
962
963

    acquisition_chain = scan_info.get("acquisition_chain", None)
    if len(acquisition_chain.keys()) == 1:
        first_key = list(acquisition_chain.keys())[0]
        if first_key == "GroupingMaster":
            # Make sure groups does not generate any plots
            return []
964

965
    # ct / curve / scatter
966

967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
    if scan_info.get("type", None) == "ct":
        plot = plot_item_model.ScalarPlot()
        result.append(plot)
    else:
        have_scalar = False
        have_scatter = False
        for master_name in acquisition_chain.keys():
            scalars = _get_channels(scan_info, master_name, dim=0, master=False)
            if len(scalars) > 0:
                have_scalar = True
            if scan_info.get("data_dim", 1) == 2 or scan_info.get("dim", 1) == 2:
                have_scatter = True

        if have_scalar:
            plot = _infer_default_curve_plot(scan_info, have_scatter)
            if plot is not None:
                result.append(plot)
                if not have_scalar:
                    default_plot = plot
        if have_scatter:
            plots = _infer_default_scatter_plot(scan_info)
            if len(plots) > 0:
                result.extend(plots)
                if default_plot is None:
                    default_plot = plots[0]
992

993
    # MCA devices
994

995
996
997
    for device_id, device_info in scan_info.get("devices", {}).items():
        device_type = device_info.get("type")
        device_name = device_id.rsplit(":", 1)[-1]
998
999
1000

        if device_type != "mca":
            continue
For faster browsing, not all history is shown. View entire blame