# -*- coding: utf-8 -*-
#
# This file is part of the bliss project
#
# Copyright (c) 2015-2020 Beamline Control Unit, ESRF
# Distributed under the GNU LGPLv3. See LICENSE for more info.

from __future__ import annotations
from typing import Tuple
from typing import Dict
from typing import List
12
from typing import Sequence
13
14
from typing import Optional

15
import logging
16
import numpy
17

18
from silx.gui import qt
19
from silx.gui import icons
20
from silx.gui import colors
Valentin Valls's avatar
Valentin Valls committed
21
from silx.gui.plot.actions import histogram
22
from silx.gui.plot.items.shape import BoundingRect
23
from silx.gui.plot.items.shape import Shape
24
from silx.gui.plot.items.scatter import Scatter
25
from silx.gui.plot.items.curve import Curve
26
27
28
29

from bliss.flint.model import scan_model
from bliss.flint.model import flint_model
from bliss.flint.model import plot_model
30
from bliss.flint.model import style_model
31
from bliss.flint.model import plot_item_model
Valentin Valls's avatar
Valentin Valls committed
32
from bliss.flint.helper import scan_info_helper
Valentin Valls's avatar
Valentin Valls committed
33
from bliss.flint.helper import model_helper
Valentin Valls's avatar
Valentin Valls committed
34
from bliss.flint.utils import signalutils
Valentin Valls's avatar
Valentin Valls committed
35
36
37
38
39
40
41
42
from .utils import plot_helper
from .utils import view_helper
from .utils import refresh_helper
from .utils import tooltip_helper
from .utils import marker_action
from .utils import export_action
from .utils import profile_action
from .utils import plot_action
43
from .utils import style_action
44
45


46
47
48
_logger = logging.getLogger(__name__)


49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
class _Title:
    """Keep the graph title of a plot in sync with the displayed scan.

    The title is the full scan title, optionally followed on a second line by
    a subtitle listing the last value of each group-by channel of the item.
    """

    def __init__(self, plot):
        self.__plot = plot
        # True when data from a previous scan is still displayed, so that the
        # title is not overridden at startup while waiting for the first image
        self.__hasPreviousImage: bool = False
        # Last computed subtitle, reused when titling data of a previous scan
        self.__lastSubTitle = None

    def itemUpdated(self, scan, item):
        """Recompute both the title and the subtitle for this scan/item."""
        self.__updateAll(scan, item)

    def scanRemoved(self, scan):
        """Called with the removed scan, just before another scan is used."""
        if scan is None:
            self.__hasPreviousImage = False
            return
        self.__updateTitle("From previous scan")
        self.__hasPreviousImage = True

    def scanStarted(self, scan):
        """Retitle the plot, unless data from a previous scan is displayed."""
        if self.__hasPreviousImage:
            return
        self.__updateAll(scan)

    def scanFinished(self, scan):
        """Retitle the plot, flagging scans which reached the FINISHED state."""
        title = scan_info_helper.get_full_title(scan)
        finished = scan.state() == scan_model.ScanState.FINISHED
        self.__updateTitle(title + " (finished)" if finished else title)

    def __formatItemTitle(self, scan: scan_model.Scan, item=None):
        """Returns "name = last value" pairs of the item group-by channels,
        joined with ", ", or None if there is nothing to display."""
        if item is None:
            return None
        channelRefs = item.groupByChannels()
        if channelRefs is None:
            return None

        lastValues = {}
        for ref in channelRefs:
            channel = ref.channel(scan)
            if channel is None:
                continue
            data = channel.array()
            if data is None or len(data) == 0:
                continue
            lastValues[channel.name()] = data[-1]

        if not lastValues:
            return None
        return ", ".join(f"{name} = {value}" for name, value in lastValues.items())

    def __setTitle(self, title, subtitle):
        """Display the title, with the subtitle on a second line if any."""
        if subtitle is not None:
            title = f"{title}\n{subtitle}"
        self.__plot.setGraphTitle(title)

    def __updateTitle(self, title):
        """Display a custom title, keeping the last known subtitle."""
        self.__setTitle(title, self.__lastSubTitle)

    def __updateAll(self, scan: scan_model.Scan, item=None):
        """Display the full scan title and a freshly computed subtitle."""
        title = scan_info_helper.get_full_title(scan)
        subtitle = self.__formatItemTitle(scan, item)
        self.__lastSubTitle = subtitle
        self.__setTitle(title, subtitle)


125
126
127
128
class ScatterNormalization:
    """Transform raw scatter data into displayable normalized scatter.

    Two normalizations can be applied by :meth:`normalize`:

    - when every axis kind of the scatter is known and the scatter is not a
      plain forth scatter (for example when an axis is ``BACKNFORTH``), the
      points are reordered through an index table into a regular image;
    - when the item defines group-by channels, only the points sharing the
      last value of every group-by channel are kept.

    Both can't be combined: when group-by channels are used, the index table
    is dropped and the image rendering is disabled.
    """

    def __init__(self, scan: scan_model.Scan, item: plot_model.Item, scatterSize: int):
        """
        Arguments:
            scan: Scan providing the channels and the scatter structure
            item: Plot item to normalize
            scatterSize: Number of points of the raw scatter data
        """
        self.__axisKind: List[scan_model.AxisKind] = []
        self.__indexes = None
        self.__skipImage = False
        self.__mask = None

        # Normalize backnforth into regular image
        channel = item.xChannel().channel(scan)
        scatter = scan.getScatterDataByChannel(channel)
        if scatter:
            self.__initBacknforthIndexes(scatter, scatterSize)

        # Filter to display the last frame
        groupByChannels = item.groupByChannels()
        if groupByChannels is not None:
            self.__mask = self.__createLastGroupMask(scan, groupByChannels, scatterSize)
            if self.__indexes is not None:
                # Both normalizations can't be combined
                self.__skipImage = True
                self.__indexes = None

        if self.__indexes is not None:
            self.__max = numpy.nanmax(self.__indexes) + 1

    def __initBacknforthIndexes(self, scatter, scatterSize: int):
        """Compute the index table reverting backnforth axes into a regular
        image. Sets ``__skipImage`` if the metadata is inconsistent."""
        for axisId in range(scatter.maxDim()):
            channel = scatter.channelsAt(axisId)[0]
            self.__axisKind.append(channel.metadata().axisKind)
        shape = scatter.shape()
        # Unknown dimensions (-1 or None) do not constrain the minimal size
        self.__nbmin = numpy.prod([(1 if i in [-1, None] else i) for i in shape])

        kinds = set(self.__axisKind)
        hasNone = None in kinds
        hasForth = scan_model.AxisKind.FORTH in kinds
        hasBacknforth = scan_model.AxisKind.BACKNFORTH in kinds
        isForth = hasForth and not hasNone and not hasBacknforth
        isBacknforth = not isForth and not hasNone
        if not isBacknforth:
            return

        # Round the size up to the next multiple of the known part of the
        # shape, so that partial data can still be reshaped
        remainder = scatterSize % self.__nbmin
        if remainder == 0:
            size = scatterSize
        else:
            size = scatterSize + (self.__nbmin - remainder)

        indexes = numpy.arange(size, dtype=int)
        try:
            indexes.shape = shape
            # Compute the index transformation to revert
            # backnforth into a regular image
            # Use numpy order
            self.__axisKind = list(reversed(self.__axisKind))
            for i in reversed(range(len(self.__axisKind))):
                kind = self.__axisKind[i]
                if kind == scan_model.AxisKind.BACKNFORTH:
                    indexes.shape = (-1,) + shape[i:]
                    indexes[1::2, :] = indexes[1::2, ::-1]
            self.__indexes = indexes.flatten()
        except Exception:
            # There could be a lot of inconsistencies with meta info
            self.__skipImage = True

    def __createLastGroupMask(self, scan, groupByChannels, scatterSize: int):
        """Returns a boolean mask selecting the points whose group-by channels
        are equal to their last acquired value."""
        mask = numpy.ones(scatterSize, dtype=bool)
        for channelRef in groupByChannels:
            channel = channelRef.channel(scan)
            if channel is None:
                continue
            array = channel.array()
            if array is None or len(array) == 0:
                continue
            mask = numpy.logical_and(mask, array == array[-1])
        return mask

    def hasNormalization(self) -> bool:
        """True if :meth:`normalize` transforms the data."""
        return self.__indexes is not None or self.__mask is not None

    def normalize(self, array: numpy.ndarray) -> numpy.ndarray:
        """Apply the normalization to a channel data array.

        None and empty arrays are returned as they are.
        """
        if array is None:
            return None
        if len(array) == 0:
            return array

        # Normalize backnforth into regular image
        if self.__indexes is not None:
            extraSize = self.__max - len(array)
            if extraSize > 0:
                # Pad missing points with NaN so every index is valid
                array = numpy.append(array, numpy.full(extraSize, numpy.nan))
            return array[self.__indexes]

        # Only display last frame
        if self.__mask is None:
            return array
        return array[self.__mask]

    def setupScatterItem(
        self,
        scatter: Scatter,
        xChannel: scan_model.Channel,
        yChannel: scan_model.Channel,
    ):
        """Feed the scatter plot item with metadata from the channels to
        optimize the rendering"""
        xmeta = xChannel.metadata()
        ymeta = yChannel.metadata()
        param = scatter.VisualizationParameter

        if (
            not self.__skipImage
            and ymeta.axisPoints is not None
            and xmeta.axisPoints is not None
        ):
            scatter.setVisualizationParameter(
                param.GRID_SHAPE,
                (ymeta.axisPoints, xmeta.axisPoints),
            )

        hasBounds = (
            xmeta.start is not None
            and xmeta.stop is not None
            and ymeta.start is not None
            and ymeta.stop is not None
        )
        if hasBounds:
            scatter.setVisualizationParameter(
                param.GRID_BOUNDS,
                ((xmeta.start, ymeta.start), (xmeta.stop, ymeta.stop)),
            )

        hasAxisPoints = xmeta.axisPoints is not None and ymeta.axisPoints is not None
        hasAxisPointsHint = (
            xmeta.axisPointsHint is not None and ymeta.axisPointsHint is not None
        )

        # The major order can only be deduced from 2 distinct axis ids
        initialized = False
        if (
            xmeta.axisKind is not None
            and ymeta.axisKind is not None
            and xmeta.axisId is not None
            and ymeta.axisId is not None
            and xmeta.axisId != ymeta.axisId
        ):
            order = "row" if xmeta.axisId < ymeta.axisId else "column"
            scatter.setVisualizationParameter(param.GRID_MAJOR_ORDER, order)
            initialized = True

        # NOTE(review): this condition preserves the historical behavior (the
        # previous local names were swapped). Confirm whether hints should
        # always configure the binned statistic instead.
        if self.__skipImage or hasAxisPoints or (hasAxisPointsHint and not initialized):
            width, height = xmeta.axisPointsHint, ymeta.axisPointsHint
            if width is None:
                width = xmeta.axisPoints
            if height is None:
                height = ymeta.axisPoints
            if height is not None and width is not None:
                scatter.setVisualizationParameter(
                    param.BINNED_STATISTIC_SHAPE,
                    (height, width),
                )

            # FIXME: Clean up in few time: part of silx 0.14 and 0.13.bugfix
            if hasattr(param, "DATA_BOUNDS_HINT") and hasBounds:
                xrange = min(xmeta.start, xmeta.stop), max(xmeta.start, xmeta.stop)
                yrange = min(ymeta.start, ymeta.stop), max(ymeta.start, ymeta.stop)
                # Extend the range by half a pixel on each side, when the
                # number of pixels is known
                if width is not None and height is not None and width > 1 and height > 1:
                    x_half_px = abs(xmeta.start - xmeta.stop) / (width - 1) * 0.5
                    y_half_px = abs(ymeta.start - ymeta.stop) / (height - 1) * 0.5
                    xrange = xrange[0] - x_half_px, xrange[1] + x_half_px
                    yrange = yrange[0] - y_half_px, yrange[1] + y_half_px

                scatter.setVisualizationParameter(
                    param.DATA_BOUNDS_HINT,
                    (yrange, xrange),
                )

    def isImageRenderingSupported(
        self, xChannel: scan_model.Channel, yChannel: scan_model.Channel
    ):
        """True if there is enough metadata to display this 2 axis as an image.

        The scatter data also have to be structured in order to display it.
        """
        if self.__skipImage:
            return False
        if self.__indexes is not None:
            return True

        xmeta = xChannel.metadata()
        ymeta = yChannel.metadata()
        if xmeta.axisKind != scan_model.AxisKind.FORTH:
            return False
        if ymeta.axisKind != scan_model.AxisKind.FORTH:
            return False
        return {xmeta.axisId, ymeta.axisId} == {0, 1}

    def isHistogramingRenderingSupported(
        self, xChannel: scan_model.Channel, yChannel: scan_model.Channel
    ):
        """True if there is enough metadata to display this 2 axis as an
        histogram.
        """
        xmeta = xChannel.metadata()
        ymeta = yChannel.metadata()
        if xmeta.axisPoints is None and xmeta.axisPointsHint is None:
            return False
        if ymeta.axisPoints is None and ymeta.axisPointsHint is None:
            return False
        return True

338

339
class ScatterPlotWidget(plot_helper.PlotWidget):
340
341
342
343
344
345
346
347
348
    def __init__(self, parent=None):
        """Create the silx plot, its toolbar, the helper plot items and the
        internal signal wiring."""
        super().__init__(parent=parent)
        self.__scan: Optional[scan_model.Scan] = None
        self.__flintModel: Optional[flint_model.FlintState] = None
        self.__plotModel: Optional[plot_model.Plot] = None

        # Model item -> keys of the silx plot items created for it
        self.__items: Dict[plot_model.Item, List[Tuple[str, str]]] = {}

        self.__plotWasUpdated: bool = False

        plot = plot_helper.FlintPlot(parent=self)
        plot.setActiveCurveStyle(linewidth=2)
        plot.setDataMargins(0.05, 0.05, 0.05, 0.05)
        self.__plot = plot

        self.__colormap = colors.Colormap("viridis")
        self.__title = _Title(self.__plot)

        # Clicks on the plot are caught by eventFilter to activate the widget
        self.setFocusPolicy(qt.Qt.StrongFocus)
        self.__plot.installEventFilter(self)
        self.__plot.getWidgetHandle().installEventFilter(self)
        self.__view = view_helper.ViewManager(self.__plot)

        self.__aggregator = signalutils.EventAggregator(self)
        self.__refreshManager = refresh_helper.RefreshManager(self)
        self.__refreshManager.setAggregator(self.__aggregator)

        toolbar = self.__createToolBar()

        # Try to improve the look and feel
        # FIXME: This should be done with stylesheet
        separator = qt.QFrame(self)
        separator.setFrameShape(qt.QFrame.HLine)
        separator.setFrameShadow(qt.QFrame.Sunken)

        panel = qt.QFrame(self)
        panel.setFrameShape(qt.QFrame.StyledPanel)
        panel.setAutoFillBackground(True)
        panelLayout = qt.QVBoxLayout(panel)
        panelLayout.setContentsMargins(0, 0, 0, 0)
        panelLayout.setSpacing(0)
        panelLayout.addWidget(toolbar)
        panelLayout.addWidget(separator)
        panelLayout.addWidget(self.__plot)

        holder = qt.QFrame(self)
        holderLayout = qt.QVBoxLayout(holder)
        holderLayout.addWidget(panel)
        holderLayout.setContentsMargins(0, 1, 0, 0)
        self.setWidget(holder)

        self.__tooltipManager = tooltip_helper.TooltipItemManager(self, self.__plot)
        self.__tooltipManager.setFilter(plot_helper.FlintScatter)

        self.__syncAxisTitle = signalutils.InvalidatableSignal(self)
        self.__syncAxisTitle.triggered.connect(self.__updateAxesLabel)
        self.__syncAxis = signalutils.InvalidatableSignal(self)
        self.__syncAxis.triggered.connect(self.__scatterAxesUpdated)

        self.__bounding = BoundingRect()
        self.__bounding.setName("bound")

        # Only visible while a scan is running (see __scanStarted/__scanFinished)
        lastValue = Scatter()
        lastValue.setSymbol(",")
        lastValue.setName("cursor_last_value")
        lastValue.setVisible(False)
        lastValue.setZValue(10)
        self.__lastValue = lastValue

        # Rectangle drawn around the expected range of the axes
        rect = Shape("rectangle")
        rect.setName("rect")
        rect.setVisible(False)
        rect.setFill(False)
        rect.setColor("#E0E0E0")
        rect.setZValue(0.1)
        self.__rect = rect

        self.__plot.addItem(self.__bounding)
        self.__plot.addItem(self.__tooltipManager.marker())
        self.__plot.addItem(self.__lastValue)
        self.__plot.addItem(self.__rect)

        self.widgetActivated.connect(self.__activated)

    def __activated(self):
        """Slot of ``widgetActivated``: bind the shared colormap widget."""
        return self.__initColormapWidget()

    def __initColormapWidget(self):
        """Bind the colormap widget of the live window to this plot.

        The first scatter item of the plot is used if any, else the default
        colormap of this widget. Does nothing if there is no flint model,
        no live window or no available colormap widget.
        """
        flintModel = self.flintModel()
        if flintModel is None:
            # Can be activated before a flint model was set
            return
        live = flintModel.liveWindow()
        if live is None:
            return
        colormapWidget = live.acquireColormapWidget(self)
        if colormapWidget is None:
            return

        scatter = next(
            (
                item
                for item in self.__plot.getItems()
                if isinstance(item, plot_helper.FlintScatter)
            ),
            None,
        )
        if scatter is not None:
            colormapWidget.setItem(scatter)
        else:
            colormapWidget.setColormap(self.__colormap)

    def configuration(self):
        """Returns the widget configuration, including the colormap."""
        config = super().configuration()
        try:
            colormapDict = self.__colormap._toDict()
            config.colormap = colormapDict
        except Exception:
            # _toDict is a private silx API: a failure must not break saving
            _logger.error("Impossible to save colormap preference", exc_info=True)
        return config

    def setConfiguration(self, config):
        """Restore the colormap, then the rest of the widget configuration."""
        try:
            self.__colormap._setFromDict(config.colormap)
        except Exception:
            # _setFromDict is a private silx API: a failure must not break restoring
            _logger.error("Impossible to restore colormap preference", exc_info=True)
        super().setConfiguration(config)

450
451
452
    def defaultColormap(self):
        """Returns the colormap owned by this widget (saved and restored with
        its configuration)."""
        colormap = self.__colormap
        return colormap

453
454
455
    def getRefreshManager(self) -> refresh_helper.RefreshManager:
        """Returns the refresh manager created by this widget.

        Its refresh action is also exposed in the toolbar.
        """
        return self.__refreshManager

456
457
    def __createToolBar(self):
        """Create the toolbar: interaction, axis, item, tools and export
        actions, separated by groups."""
        from silx.gui.plot.actions import mode
        from silx.gui.plot.actions import control
        from silx.gui.widgets.MultiModeAction import MultiModeAction

        toolbar = qt.QToolBar(self)
        toolbar.setMovable(False)

        # Interaction
        interactionModes = MultiModeAction(self)
        interactionModes.addAction(mode.ZoomModeAction(self.__plot, self))
        interactionModes.addAction(mode.PanModeAction(self.__plot, self))
        toolbar.addAction(interactionModes)
        toolbar.addAction(self.__view.createResetZoomAction(parent=self))
        toolbar.addSeparator()

        # Axis
        toolbar.addAction(self.__refreshManager.createRefreshAction(self))
        toolbar.addAction(
            plot_action.CustomAxisAction(self.__plot, self, kind="scatter")
        )
        toolbar.addSeparator()

        # Item
        self.__styleAction = style_action.FlintItemStyleAction(self.__plot, self)
        toolbar.addAction(self.__styleAction)

        self.__contrastAction = style_action.FlintSharedColormapAction(
            self.__plot, self
        )
        self.__contrastAction.setInitColormapWidgetCallback(self.__initColormapWidget)
        toolbar.addAction(self.__contrastAction)
        toolbar.addSeparator()

        # Tools
        crosshair = control.CrosshairAction(self.__plot, parent=self)
        crosshair.setIcon(icons.getQIcon("flint:icons/crosshair"))
        toolbar.addAction(crosshair)

        histo = histogram.PixelIntensitiesHistoAction(self.__plot, self)
        histo.setIcon(icons.getQIcon("flint:icons/histogram"))
        toolbar.addAction(histo)

        toolbar.addAction(profile_action.ProfileAction(self.__plot, self, "scatter"))

        self.__markerAction = marker_action.MarkerAction(
            plot=self.__plot, parent=self, kind="scatter"
        )
        toolbar.addAction(self.__markerAction)

        colorbar = control.ColorBarAction(self.__plot, self)
        colorbar.setIcon(icons.getQIcon("flint:icons/colorbar"))
        toolbar.addAction(colorbar)
        toolbar.addSeparator()

        # Export
        self.__exportAction = export_action.ExportAction(self.__plot, self)
        toolbar.addAction(self.__exportAction)

        return toolbar

522
523
524
525
    def logbookAction(self):
        """Returns the logbook action provided by the export action, if any."""
        exportAction = self.__exportAction
        return exportAction.logbookAction()

Valentin Valls's avatar
Valentin Valls committed
526
527
528
529
530
531
532
    def _silxPlot(self):
        """Returns the underlying silx plot of this view.

        Provided as-is, without any warranty of stability.
        """
        plot = self.__plot
        return plot

533
534
535
536
537
538
539
540
541
542
543
544
    def eventFilter(self, widget, event):
        """Emit ``widgetActivated`` when the plot receives a mouse press.

        This filter is installed on the silx plot and on its widget handle.
        Events of any other object are not filtered (returns False, as Qt
        expects a bool from ``eventFilter``).
        """
        if widget is not self.__plot and widget is not self.__plot.getWidgetHandle():
            return False
        if event.type() == qt.QEvent.MouseButtonPress:
            self.widgetActivated.emit(self)
        return widget.eventFilter(widget, event)

    def createPropertyWidget(self, parent: qt.QWidget):
        """Create the property widget of this plot, bound to the current
        flint model and focused on this widget."""
        from . import scatter_plot_property

        widget = scatter_plot_property.ScatterPlotPropertyWidget(parent)
        widget.setFlintModel(self.__flintModel)
        widget.setFocusWidget(self)
        return widget

548
549
550
    def flintModel(self) -> Optional[flint_model.FlintState]:
        """Returns the flint model used by this widget, if any."""
        model = self.__flintModel
        return model

551
552
    def setFlintModel(self, flintModel: Optional[flint_model.FlintState]):
        """Set the flint model and propagate it to the toolbar actions
        which depend on it."""
        self.__flintModel = flintModel
        for action in (self.__exportAction, self.__styleAction, self.__contrastAction):
            action.setFlintModel(flintModel)
556
557
558

    def setPlotModel(self, plotModel: plot_model.Plot):
        """Display another plot model.

        Signals of the previous model are disconnected, those of the new one
        are connected through the event aggregator, then the whole plot is
        redrawn and the axes are resynchronized.
        """
        previous = self.__plotModel
        if previous is not None:
            for signal, slot in (
                (previous.structureChanged, self.__structureChanged),
                (previous.itemValueChanged, self.__itemValueChanged),
                (previous.transactionFinished, self.__transactionFinished),
            ):
                signal.disconnect(self.__aggregator.callbackTo(slot))

        self.__plotModel = plotModel
        if plotModel is not None:
            for signal, slot in (
                (plotModel.structureChanged, self.__structureChanged),
                (plotModel.itemValueChanged, self.__itemValueChanged),
                (plotModel.transactionFinished, self.__transactionFinished),
            ):
                signal.connect(self.__aggregator.callbackTo(slot))

        self.plotModelUpdated.emit(plotModel)
        self.__sanitizeItems()
        self.__redrawAll()
        self.__syncAxisTitle.trigger()
        self.__syncAxis.trigger()
584
585
586
587
588
589

    def plotModel(self) -> plot_model.Plot:
        """Returns the displayed plot model (can be None)."""
        model = self.__plotModel
        return model

    def __structureChanged(self):
        """Redraw everything and resynchronize the axes on structural changes."""
        self.__redrawAll()
        for signal in (self.__syncAxisTitle, self.__syncAxis):
            signal.trigger()
592
593
594
595

    def __transactionFinished(self):
        """End of a plot model transaction: notify the view if the plot was
        updated, and validate the axis synchronization signals."""
        if self.__plotWasUpdated:
            self.__plotWasUpdated = False
            self.__view.plotUpdated()
        for signal in (self.__syncAxisTitle, self.__syncAxis):
            signal.validate()
599
600
601
602

    def __itemValueChanged(
        self, item: plot_model.Item, eventType: plot_model.ChangeEventType
    ):
        """Refresh what depends on the changed property of a plot item.

        Axis synchronizations are only triggered immediately outside of a
        transaction; otherwise they are validated when it finishes.
        """
        outsideTransaction = not self.__plotModel.isInTransaction()
        Event = plot_model.ChangeEventType

        if eventType == Event.VISIBILITY:
            self.__updateItem(item)
            self.__syncAxisTitle.triggerIf(outsideTransaction)
        elif eventType == Event.CUSTOM_STYLE:
            self.__updateItem(item)
        elif eventType in (Event.X_CHANNEL, Event.Y_CHANNEL):
            self.__sanitizeItem(item)
            self.__updateItem(item)
            self.__syncAxisTitle.triggerIf(outsideTransaction)
            self.__syncAxis.triggerIf(outsideTransaction)
        elif eventType == Event.VALUE_CHANNEL:
            self.__sanitizeItem(item)
            self.__updateItem(item)

623
624
625
626
627
628
629
630
631
    def __scatterAxesUpdated(self):
        """Update the bounding item and the background rectangle from the
        start/stop/min/max metadata of the x and y channels of all items.

        Both are reset/hidden when no range is known for one of the axes.
        """
        scan = self.__scan
        plot = self.__plotModel

        def valueRange(channels):
            values = set([])
            for channel in channels:
                meta = channel.metadata()
                values.update({meta.start, meta.stop, meta.min, meta.max})
            values.discard(None)
            if not values:
                return None, None
            return min(values), max(values)

        bound = None
        if plot is not None:
            xChannels = set([])
            yChannels = set([])
            for item in plot.items():
                xRef, yRef = item.xChannel(), item.yChannel()
                if xRef is not None:
                    xChannels.add(xRef.channel(scan))
                if yRef is not None:
                    yChannels.add(yRef.channel(scan))
            # Channels which are not part of the scan are resolved as None
            xChannels.discard(None)
            yChannels.discard(None)

            xMin, xMax = valueRange(list(xChannels))
            yMin, yMax = valueRange(list(yChannels))
            if xMin is not None and yMin is not None:
                bound = (xMin, xMax, yMin, yMax)

        self.__bounding.setBounds(bound)

        if bound is None:
            self.__rect.setVisible(False)
        else:
            xMin, xMax, yMin, yMax = bound
            self.__rect.setVisible(True)
            self.__rect.setPoints([(xMin, yMin), (xMax, yMax)])

Valentin Valls's avatar
Valentin Valls committed
667
    def __updateAxesLabel(self):
        """Label each axis with the sorted, unique display names of the
        channels of the valid and visible scatter items."""
        scan = self.__scan
        plot = self.__plotModel
        if plot is None:
            xLabel = yLabel = ""
        else:
            xNames = set()
            yNames = set()
            for item in plot.items():
                if not item.isValid() or not item.isVisible():
                    continue
                if not isinstance(item, plot_item_model.ScatterItem):
                    continue
                xNames.add(item.xChannel().displayName(scan))
                yNames.add(item.yChannel().displayName(scan))
            xLabel = " + ".join(sorted(xNames))
            yLabel = " + ".join(sorted(yNames))

        self.__plot.getXAxis().setLabel(xLabel)
        self.__plot.getYAxis().setLabel(yLabel)

689
690
691
    def scan(self) -> Optional[scan_model.Scan]:
        """Returns the displayed scan, if any."""
        current = self.__scan
        return current

692
    def setScan(self, scan: Optional[scan_model.Scan] = None):
        """Display another scan.

        Signals of the previous scan are disconnected, those of the new one
        are connected through the event aggregator, the title is updated and
        the plot is redrawn. Setting the same scan again does nothing.
        """
        if self.__scan is scan:
            return

        previous = self.__scan
        if previous is not None:
            for signal, slot in (
                (previous.scanDataUpdated[object], self.__scanDataUpdated),
                (previous.scanStarted, self.__scanStarted),
                (previous.scanFinished, self.__scanFinished),
            ):
                signal.disconnect(self.__aggregator.callbackTo(slot))
        self.__title.scanRemoved(previous)

        self.__scan = scan
        if scan is not None:
            for signal, slot in (
                (scan.scanDataUpdated[object], self.__scanDataUpdated),
                (scan.scanStarted, self.__scanStarted),
                (scan.scanFinished, self.__scanFinished),
            ):
                signal.connect(self.__aggregator.callbackTo(slot))
            # A scan which is already running will not emit scanStarted
            if scan.state() != scan_model.ScanState.INITIALIZED:
                self.__title.scanStarted(scan)

        self.scanModelUpdated.emit(scan)
        self.__sanitizeItems()
        self.__redrawAll()

    def __scanStarted(self):
        """Reset the live display when the displayed scan starts."""
        self.__refreshManager.scanStarted()

        # Easter egg: a heart symbol for the last point when the flint model
        # date is "0214" (presumably February 14th)
        flintModel = self.__flintModel
        isValentine = flintModel is not None and flintModel.getDate() == "0214"
        self.__lastValue.setSymbol("\u2665" if isValentine else ",")

        self.__markerAction.clear()

        # The last acquired point is highlighted while the scan is running
        lastValue = self.__lastValue
        lastValue.setData(x=[], y=[], value=[])
        lastValue.setVisible(True)

        self.__view.scanStarted()
        self.__syncAxis.trigger()
        self.__title.scanStarted(self.__scan)
736
737

    def __scanFinished(self):
        """Finalize the live display once the displayed scan is over."""
        refreshManager = self.__refreshManager
        refreshManager.scanFinished()
        # The last point is only highlighted while the scan is running
        self.__lastValue.setVisible(False)
        self.__title.scanFinished(self.__scan)
741

742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
    def __scanDataUpdated(self, event: scan_model.ScanDataUpdateEvent):
        """Redraw the valid scatter items whose channels are updated by `event`.

        Arguments:
            event: Description of the channels updated in the scan
        """
        plotModel = self.__plotModel
        if plotModel is None:
            return
        for item in plotModel.items():
            isScatter = isinstance(item, plot_item_model.ScatterItem)
            if not isScatter or not item.isValid():
                continue
            # TODO: Create an API to return the involved channel names
            channelNames = (
                item.xChannel().name(),
                item.yChannel().name(),
                item.valueChannel().name(),
            )
            if any(event.isUpdatedChannelName(name) for name in channelNames):
                self.__updateItem(item)
761
762
763
764
765

    def __cleanAll(self):
        """Remove from the plot everything drawn for the model items.

        The selection rectangle and the last value marker are hidden, and the
        view is notified that the plot was cleared.

        The item -> plot keys mapping is emptied as well. Otherwise stale keys
        would make a following `__cleanItem` remove plot items which no
        longer exist, and report them as updated. It would also keep
        references to model items which are no longer part of the plot.
        """
        for itemKeys in self.__items.values():
            for key in itemKeys:
                self.__plot.remove(*key)
        self.__items.clear()
        self.__rect.setVisible(False)
        self.__lastValue.setVisible(False)
        self.__view.plotCleared()
769

770
    def __cleanItem(self, item: plot_model.Item) -> bool:
        """Remove from the plot the plot items recorded for `item`.

        The record for `item` is dropped.

        Returns:
            True if plot keys were recorded for `item` (and were removed),
            else False.
        """
        keys = self.__items.pop(item, [])
        if not keys:
            return False
        plot = self.__plot
        for key in keys:
            plot.remove(*key)
        return True
777
778
779
780
781
782
783
784
785
786

    def __redrawAll(self):
        """Clear the plot, then draw again every item of the plot model."""
        self.__cleanAll()
        plotModel = self.__plotModel
        if plotModel is not None:
            for modelItem in plotModel.items():
                self.__updateItem(modelItem)

787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
    def __sanitizeItems(self):
        """Sanitize each scatter item of the plot model against the current scan.

        Nothing is done when there is no scan or no plot model.
        """
        if self.__scan is None or self.__plotModel is None:
            return
        for modelItem in self.__plotModel.items():
            if not isinstance(modelItem, plot_item_model.ScatterItem):
                continue
            self.__sanitizeItem(modelItem)

    def __sanitizeItem(self, item: plot_item_model.ScatterItem):
        """Set up the group-by channels of an N-dimensional scatter item.

        When the x and y channels of `item` belong to scatter data with more
        than 2 dimensions, each remaining dimension must be grouped by another
        channel so the data can be displayed as a 2D scatter. The item is left
        untouched when its channels cannot be resolved in the current scan,
        when they are not from the same group, or when no grouping is found.
        A warning is logged in the last two cases.

        Arguments:
            item: The scatter item to sanitize; a scan is expected to be set
        """
        if not item.isValid() or not isinstance(item, plot_item_model.ScatterItem):
            return

        xRef, yRef = item.xChannel(), item.yChannel()
        if xRef is None or yRef is None:
            return

        scan = self.__scan
        assert scan is not None
        xChannel, yChannel = xRef.channel(scan), yRef.channel(scan)
        if xChannel is None or yChannel is None:
            return

        if xChannel.metadata().group != yChannel.metadata().group:
            # FIXME: This should be cached... Try to display data not from the same group
            return

        scatterData = scan.getScatterDataByChannel(xChannel)
        if scatterData is None:
            # FIXME: This should be cached... Try to display data not from the same group
            return
        if scatterData.maxDim() <= 2:
            # Already displayable as a 2D scatter
            return

        xAxis = scatterData.channelAxis(xChannel)
        yAxis = scatterData.channelAxis(yChannel)
        if xAxis == yAxis:
            # FIXME: This should not be displayed anyway
            _logger.warning("ndim scatter using same axis dim for the 2 axis")
            return

        # Every dimension not used by x or y has to be grouped by a channel
        otherAxes = list(range(scatterData.maxDim()))
        otherAxes.remove(xAxis)
        otherAxes.remove(yAxis)
        groupBys = [scatterData.findGroupableAt(axis) for axis in otherAxes]
        if None in groupBys:
            # FIXME: Should not be displayed
            _logger.warning("ndim scatter can't be grouped to 2d scatter")
            return

        item.setGroupByChannels(
            [plot_model.ChannelRef(item, channel.name()) for channel in groupBys]
        )

850
851
852
853
854
855
856
857
858
859
    def __updateItem(self, item: plot_model.Item):
        """Redraw the plot items representing a single model item.

        Nothing is done if there is no plot model or no scan, or if `item`
        is not a valid `plot_item_model.ScatterItem`.

        Otherwise the plot items previously drawn for `item` are removed.
        New ones are created if the item is visible and its data is available
        in the current scan. Depending on the item style, this can be:

        - a scatter with a fill visualization (`<legend>_solid`), when a fill
          style is set and can be rendered with the available data;
        - a curve joining the points in data order (`<legend>_line`), for the
          `SCATTER_SEQUENCE` line style;
        - a point-based scatter (`<legend>_point`) when no fill is rendered,
          or otherwise an optional curve displaying only the symbols.

        The plot zoom is then updated. It is updated immediately, or only
        flagged when the plot model is in a transaction. When nothing is
        drawn, the zoom is only updated if plot items were removed.

        Arguments:
            item: The model item to redraw
        """
        if self.__plotModel is None:
            return
        if self.__scan is None:
            return
        if not item.isValid():
            return
        if not isinstance(item, plot_item_model.ScatterItem):
            return

        scan = self.__scan
        plot = self.__plot
        plotItems: List[Tuple[str, str]] = []

        updateZoomNow = not self.__plotModel.isInTransaction()

        wasUpdated = self.__cleanItem(item)

        if not item.isVisible():
            if wasUpdated:
                self.__updatePlotZoom(updateZoomNow)
            return

        if not item.isValidInScan(scan):
            if wasUpdated:
                self.__updatePlotZoom(updateZoomNow)
            return

        valueChannel = item.valueChannel()
        xChannelRef = item.xChannel()
        yChannelRef = item.yChannel()
        if valueChannel is None or xChannelRef is None or yChannelRef is None:
            if wasUpdated:
                self.__updatePlotZoom(updateZoomNow)
            return

        # Resolve the channel references into channels of this scan.
        # `channel()` can return None, which must not be dereferenced below.
        xChannel = xChannelRef.channel(scan)
        yChannel = yChannelRef.channel(scan)
        if xChannel is None or yChannel is None:
            if wasUpdated:
                self.__updatePlotZoom(updateZoomNow)
            return

        value = valueChannel.array(scan)
        xx = xChannel.array()
        yy = yChannel.array()
        if value is None or xx is None or yy is None:
            if wasUpdated:
                self.__updatePlotZoom(updateZoomNow)
            return

        # FIXME: This have to be cached and optimized
        scatterSize = len(xx)
        normalization = ScatterNormalization(scan, item, scatterSize)
        if normalization.hasNormalization():
            xx = normalization.normalize(xx)
            yy = normalization.normalize(yy)
            value = normalization.normalize(value)
            # Keep track of the index of each point in the original data
            indexes = numpy.arange(scatterSize)
            indexes = normalization.normalize(indexes)
        else:
            indexes = None

        self.__title.itemUpdated(scan, item)

        legend = valueChannel.name()
        style = item.getStyle(scan)
        colormap = model_helper.getColormapWithItemStyle(item, style, self.__colormap)

        scatter = None
        curve = None

        pointBased = True
        if style.fillStyle is not style_model.FillStyle.NO_FILL:
            pointBased = False
            fillStyle = style.fillStyle
            scatter = plot_helper.FlintScatter()
            scatter.setData(x=xx, y=yy, value=value, copy=False)
            scatter.setRealIndexes(indexes)
            scatter.setColormap(colormap)
            scatter.setCustomItem(item)
            scatter.setScan(scan)
            key = legend + "_solid"
            scatter.setName(key)

            if fillStyle == style_model.FillStyle.SCATTER_INTERPOLATION:
                scatter.setVisualization(scatter.Visualization.SOLID)
            elif normalization.isImageRenderingSupported(xChannel, yChannel):
                if fillStyle == style_model.FillStyle.SCATTER_REGULAR_GRID:
                    scatter.setVisualization(scatter.Visualization.REGULAR_GRID)
                elif fillStyle == style_model.FillStyle.SCATTER_IRREGULAR_GRID:
                    scatter.setVisualization(scatter.Visualization.IRREGULAR_GRID)
            elif normalization.isHistogramingRenderingSupported(xChannel, yChannel):
                # Fall back with an histogram
                scatter.setVisualization(scatter.Visualization.BINNED_STATISTIC)
            else:
                # The fill can't be rendered with this data: display points
                pointBased = True

            if not pointBased:
                plot.addItem(scatter)
                normalization.setupScatterItem(scatter, xChannel, yChannel)
                plotItems.append((key, "scatter"))

        if not pointBased and len(value) >= 1:
            # Highlight the last point using a copy of the colormap with a
            # fixed range computed from the whole data
            vmin, vmax = colormap.getColormapRange(value)
            colormap2 = colormap.copy()
            colormap2.setVRange(vmin, vmax)
            self.__lastValue.setData(x=xx[-1:], y=yy[-1:], value=value[-1:])
            self.__lastValue.setColormap(colormap2)

        if style.lineStyle == style_model.LineStyle.SCATTER_SEQUENCE:
            key = plot.addCurve(
                x=xx,
                y=yy,
                legend=legend + "_line",
                color=style.lineColor,
                linestyle="-",
                resetzoom=False,
            )
            plotItems.append((key, "curve"))

        if pointBased:
            symbolStyle = style_model.symbol_to_silx(style.symbolStyle)
            if symbolStyle == " ":
                # A point-based scatter without symbol would display nothing
                symbolStyle = "o"
            scatter = plot_helper.FlintScatter()
            scatter.setData(x=xx, y=yy, value=value, copy=False)
            scatter.setRealIndexes(indexes)
            scatter.setColormap(colormap)
            scatter.setSymbol(symbolStyle)
            scatter.setSymbolSize(style.symbolSize)
            scatter.setCustomItem(item)
            scatter.setScan(scan)
            key = legend + "_point"
            scatter.setName(key)
            plot.addItem(scatter)
            plotItems.append((key, "scatter"))
        elif (
            style.symbolStyle is not style_model.SymbolStyle.NO_SYMBOL
            and style.symbolColor is not None
        ):
            # Symbols drawn on top of the filled scatter, without line
            symbolStyle = style_model.symbol_to_silx(style.symbolStyle)
            curve = Curve()
            curve.setData(x=xx, y=yy, copy=False)
            curve.setColor(style.symbolColor)
            curve.setSymbol(symbolStyle)
            curve.setLineStyle(" ")
            curve.setSymbolSize(style.symbolSize)
            key = legend + "_point"
            curve.setName(key)
            plot.addItem(curve)
            plotItems.append((key, "curve"))

        # The flint model can be unset, as checked in `__scanStarted`
        flintModel = self.flintModel()
        live = flintModel.liveWindow() if flintModel is not None else None
        if live is not None:
            colormapWidget = live.ownedColormapWidget(self)
        else:
            colormapWidget = None

        if scatter is not None:
            # The scatter is not selectable,
            # so it does not interfere with the profile interaction
            scatter._setSelectable(False)
            plot._setActiveItem("scatter", scatter.getLegend())
        elif curve is not None:
            plot._setActiveItem("curve", curve.getLegend())

        if colormapWidget is not None:
            if scatter is not None:
                colormapWidget.setItem(scatter)
            else:
                colormapWidget.setItem(None)

        self.__items[item] = plotItems
        self.__updatePlotZoom(updateZoomNow)

    def __updatePlotZoom(self, updateZoomNow):
        if updateZoomNow:
            self.__view.plotUpdated()
1025
1026
        else:
            self.__plotWasUpdated = True