XasObjectViewer.py 26.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# coding: utf-8
# /*##########################################################################
#
# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# ###########################################################################*/
"""Tools to visualize spectra"""


__authors__ = ["H. Payno"]
__license__ = "MIT"
__date__ = "04/07/2019"


payno's avatar
payno committed
33
from est.core.types import XASObject, Spectrum
34
from est.core.utils.symbol import MU_CHAR
35
36
37
38
39
from silx.gui import qt
from silx.gui.plot import Plot1D
from silx.gui.plot.StackView import StackViewMainWindow
from silx.utils.enum import Enum
from silx.gui.widgets.FrameBrowser import HorizontalSliderWithBrowser
payno's avatar
payno committed
40
from est.gui import icons
41
from silx.gui import icons as silx_icons
42
from silx.gui.colors import Colormap
43
from typing import Iterable
44
45
import numpy
import silx
46
47
48
import logging

_logger = logging.getLogger(__name__)

# mean spectrum view: curves drawn with a ``baseline`` (mean +/- std envelope
# computed from the whole XAS object) need silx >= 0.12
silx_version = silx.version.split(".")

if not (int(silx_version[0]) == 0 and int(silx_version[1]) <= 11):
    silx_plot_has_baseline_feature = True
else:
    silx_plot_has_baseline_feature = False
    _logger.warning(
        "a more recent version of silx is required to display mean spectrum (0.12)"
    )


class ViewType(Enum):
    """Kind of view displayed by :class:`XasObjectViewer`."""

    # values are one-element tuples
    map = (0,)
    spectrum = (1,)


class _SpectrumViewAction(qt.QAction):
    """Checkable action selecting one of the spectrum views.

    :param parent: Qt parent of the action
    :param int iView: index of the spectrum view activated by this action.
        Each index needs its own icon, so only 0 and 1 are supported.
    :raises NotImplementedError: if no icon is defined for ``iView``
    """

    # icon name per spectrum view index; add entries to support more views
    _ICONS = {0: "item-1dim", 1: "item-1dim-black"}

    def __init__(self, parent=None, iView=0):
        qt.QAction.__init__(self, "spectrum view", parent=parent)
        icon = self._ICONS.get(iView)
        if icon is None:
            # the only limitation is the number of available icons
            raise NotImplementedError("Only two spectrum views are managed")
        self._iView = iView
        self.setIcon(icons.getQIcon(icon))
        self.setCheckable(True)


class _MapViewAction(qt.QAction):
    """Checkable action selecting the map view (silx "image" icon)."""

    def __init__(self, parent=None):
        super().__init__("map view", parent=parent)
        self.setIcon(silx_icons.getQIcon("image"))
        self.setCheckable(True)


class XasObjectViewer(qt.QMainWindow):
    """Viewer dedicated to view a XAS object

    It holds a :class:`MapViewer` and one :class:`SpectrumViewer` per entry
    of ``spectrumPlots``. A left toolbar of mutually exclusive actions
    selects which of them is visible.

    :param QObject parent: Qt parent
    :param list mapKeys: list of str keys to propose for the map display
    :param list spectrumPlots: list of keys if several spectrum plot should be
                               proposed. Each key is used as the tooltip of
                               the matching toolbar action. If None or empty,
                               a single spectrum view is created.
    """

    viewTypeChanged = qt.Signal()
    # emitted when the view type change

    def __init__(self, parent=None, mapKeys=None, spectrumPlots=None):
        qt.QMainWindow.__init__(self, parent)
        self.setWindowFlags(qt.Qt.Widget)

        # at least one spectrum view is required: the first spectrum action
        # is checked at the end of the constructor
        if spectrumPlots is not None:
            spectrum_views_ = tuple(spectrumPlots)
        else:
            spectrum_views_ = ()
        if len(spectrum_views_) == 0:
            spectrum_views_ = ("",)

        # main stack widget
        self._mainWidget = qt.QWidget(parent=self)
        self._mainWidget.setLayout(qt.QVBoxLayout())
        self.setCentralWidget(self._mainWidget)
        # map view
        self._mapView = MapViewer(parent=self, keys=mapKeys)
        self._mainWidget.layout().addWidget(self._mapView)

        # spectrum views: one per requested spectrum plot
        self._spectrumViews = []
        for _ in spectrum_views_:
            spectrumView = SpectrumViewer(parent=self)
            self._mainWidget.layout().addWidget(spectrumView)
            self._spectrumViews.append(spectrumView)

        # add toolbar
        toolbar = qt.QToolBar("")
        toolbar.setIconSize(qt.QSize(32, 32))
        self._spectrumViewActions = []
        # QActionGroup is exclusive by default: one view checked at a time
        self.view_actions = qt.QActionGroup(self)
        for iSpectrumView, tooltip in enumerate(spectrum_views_):
            spectrumViewAction = _SpectrumViewAction(parent=self, iView=iSpectrumView)
            self.view_actions.addAction(spectrumViewAction)
            self._spectrumViewActions.append(spectrumViewAction)
            spectrumViewAction.setToolTip(tooltip)
            toolbar.addAction(spectrumViewAction)

        self._mapViewAction = _MapViewAction(parent=self)
        toolbar.addAction(self._mapViewAction)
        self.view_actions.addAction(self._mapViewAction)

        self.addToolBar(qt.Qt.LeftToolBarArea, toolbar)
        toolbar.setMovable(False)

        # connect signal / Slot
        self._mapViewAction.triggered.connect(self._updateView)
        for spectrumAction in self._spectrumViewActions:
            spectrumAction.triggered.connect(self._updateView)

        # initialize: start on the first spectrum view
        self._spectrumViewActions[0].setChecked(True)
        self._updateView()

    def _updateView(self, *arg, **kwargs):
        """Show only the view matching the checked action, then emit
        :attr:`viewTypeChanged`."""
        index, view_type = self.getViewType()
        self._mapView.setVisible(view_type is ViewType.map)
        for iView, spectrumView in enumerate(self._spectrumViews):
            spectrumView.setVisible(view_type is ViewType.spectrum and iView == index)
        self.viewTypeChanged.emit()

    def getViewType(self):
        """Return the currently selected view.

        :return: ``(None, ViewType.map)`` if the map view is selected,
                 ``(index, ViewType.spectrum)`` for the selected spectrum
                 view, ``(None, None)`` if no action is checked
        """
        if self._mapViewAction.isChecked():
            return None, ViewType.map
        for spectrumViewAction in self._spectrumViewActions:
            if spectrumViewAction.isChecked():
                return spectrumViewAction._iView, ViewType.spectrum
        return None, None

    def setXASObj(self, xas_obj):
        """Display ``xas_obj`` in the map view and in every spectrum view."""
        self._mapView.clear()
        self._mapView.setXasObject(xas_obj)
        for spectrumView in self._spectrumViews:
            spectrumView.clear()
            spectrumView.setXasObject(xas_obj)


class MapViewer(qt.QWidget):
    """
    Widget to display different map of the spectra

    The displayed stack is ``xas_obj.spectra.map_to(key=<selected key>)``,
    the key being chosen in a combobox.
    """

    sigFrameChanged = qt.Signal(int)
    # Signal emitted when the frame number has changed.

    def __init__(self, parent=None, keys=None):
        """

        :param parent: Qt parent
        :param keys: volume keys to display for the xasObject (Mu,
        NormalizedMu...)
        """
        assert keys is not None
        self._xasObj = None
        qt.QWidget.__init__(self, parent=parent)
        self.setLayout(qt.QVBoxLayout())
        self._mainWindow = StackViewMainWindow(parent=parent)
        self.layout().addWidget(self._mainWindow)
        self.layout().setContentsMargins(0, 0, 0, 0)
        # QLayout.setSpacing takes an int: recent PyQt5 / Python versions
        # reject a float argument with a TypeError
        self.layout().setSpacing(0)

        self._mainWindow.setKeepDataAspectRatio(True)
        self._mainWindow.setColormap(Colormap(name="temperature"))

        # define the keys combobox
        self._keyWidget = qt.QWidget(parent=self)
        self._keyWidget.setLayout(qt.QHBoxLayout())
        self._keyComboBox = qt.QComboBox(parent=self._keyWidget)
        for key in keys:
            self._keyComboBox.addItem(key)
        self._keyWidget.layout().addWidget(qt.QLabel("view: "))
        self._keyWidget.layout().addWidget(self._keyComboBox)
        self.keySelectionDocker = qt.QDockWidget(parent=self)
        self.keySelectionDocker.setContentsMargins(0, 0, 0, 0)
        self._keyWidget.layout().setContentsMargins(0, 0, 0, 0)
        self._keyWidget.layout().setSpacing(0)
        self.keySelectionDocker.setWidget(self._keyWidget)
        # NOTE(review): the docker is not added to any window here;
        # presumably the caller adds ``keySelectionDocker`` - confirm
        self.keySelectionDocker.setAllowedAreas(qt.Qt.TopDockWidgetArea)
        self.keySelectionDocker.setFeatures(qt.QDockWidget.NoDockWidgetFeatures)

        # expose API
        self.getActiveImage = self._mainWindow.getActiveImage
        self.menuBar = self._mainWindow.menuBar

        # connect signal / slot
        self._keyComboBox.currentTextChanged.connect(self._updateView)
        self._mainWindow.sigFrameChanged.connect(self._shareFrameChangedSignal)

    def clear(self):
        """Clear the displayed stack."""
        self._mainWindow.clear()

    def getActiveKey(self):
        """Return the key currently selected in the combobox."""
        return self._keyComboBox.currentText()

    def setXasObject(self, xas_obj):
        """Set the XAS object to display and refresh the stack."""
        self._xasObj = xas_obj
        self._updateView()

    def _updateView(self, *args, **kwargs):
        """Display the volume of the selected key (no-op without object)."""
        if self._xasObj is None:
            return
        # set the map view
        spectra_volume = self._xasObj.spectra.map_to(
            key=self.getActiveKey(),
        )
        self._mainWindow.setStack(spectra_volume)

    def getPlot(self):
        """Return the plot of the underlying stack view."""
        return self._mainWindow.getPlot()

    def _shareFrameChangedSignal(self, frame):
        # forward the stack view frame change through sigFrameChanged
        self.sigFrameChanged.emit(frame)

    def setPerspectiveVisible(self, b):
        """Show / hide the stack view options (dimension selection).

        The frame browser is kept visible whatever the value of ``b``.
        """
        self._mainWindow.setOptionVisible(b)
        self._mainWindow._browser.setVisible(True)


class _ExtendedSliderWithBrowser(HorizontalSliderWithBrowser):
    """Horizontal slider with browser, preceded by a ``"<name>:"`` label.

    :param parent: Qt parent
    :param name: text of the leading label. If None, no label is added.
    """

    def __init__(self, parent=None, name=None):
        HorizontalSliderWithBrowser.__init__(self, parent)
        # ``name + ":"`` would raise TypeError for the default name (None)
        # and for non-str names
        if name is not None:
            self.layout().insertWidget(0, qt.QLabel("{}:".format(name)))


class _CurveOperation(object):
payno's avatar
payno committed
268
269
270
271
272
273
274
275
276
277
278
279
280
    def __init__(
        self,
        x,
        y,
        legend,
        yaxis=None,
        linestyle=None,
        symbol=None,
        color=None,
        ylabel=None,
        baseline=None,
        alpha=1.0,
    ):
281
282
283
284
285
286
287
288
        self.x = x
        self.y = y
        self.legend = legend
        self.yaxis = yaxis
        self.linestyle = linestyle
        self.symbol = symbol
        self.color = color
        self.ylabel = ylabel
289
290
        self.baseline = baseline
        self.alpha = alpha
291
292


293
class _XMarkerOperation(object):
payno's avatar
payno committed
294
    def __init__(self, x, legend, color="blue"):
295
296
297
298
299
        self.x = x
        self.legend = legend
        self.color = color


300
301
302
class _RawDataList(qt.QTableWidget):
    """Two-column table displaying (x, y) values of a spectrum as text."""

    _HEADER_LABELS = ("X (energy converted to eV)", "Y ({})".format(MU_CHAR))

    def __init__(self, parent):
        qt.QTableWidget.__init__(self, parent)
        self.clear()

    def setData(self, x: Iterable, y: Iterable):
        """Replace the table content by the given values (one row per pair).

        :param x: x values
        :param y: y values, same length as ``x``
        :raises ValueError: if ``x`` and ``y`` have different lengths
        """
        if len(x) != len(y):
            raise ValueError("x and y should have the same number of element")
        # With sorting enabled, QTableWidget may move a row as soon as the
        # item of the sorted column is set, so the y item set right after
        # could end up next to another row's x. Fill unsorted, then restore.
        sorting_enabled = self.isSortingEnabled()
        self.setSortingEnabled(False)
        try:
            self.setRowCount(len(x))
            self.setHorizontalHeaderLabels(list(self._HEADER_LABELS))
            for i_row, (x_value, y_value) in enumerate(zip(x, y)):
                self.setItem(i_row, 0, qt.QTableWidgetItem(str(x_value)))
                self.setItem(i_row, 1, qt.QTableWidgetItem(str(y_value)))
        finally:
            self.setSortingEnabled(sorting_enabled)
        self.resizeColumnsToContents()

    def clear(self):
        """Remove every row and reset the two labelled columns."""
        super().clear()
        # header labels are only applied to existing columns
        self.setColumnCount(2)
        self.setHorizontalHeaderLabels(list(self._HEADER_LABELS))
        self.setRowCount(0)
        self.setSortingEnabled(True)
        self.verticalHeader().hide()


class SpectrumViewer(qt.QMainWindow):
    """Display one spectrum of a XAS object, selected by two frame browsers.

    The spectrum is shown as a plot built from the registered curve
    operations and as a text table of its energy / mu values.
    """

    def __init__(self, parent=None):
        # Callbacks producing plot items. Each is called with the displayed
        # Spectrum and, when silx supports baselines, with the XASObject and
        # ``index=<dim 1 index>``. A callback returns None, a
        # _CurveOperation / _XMarkerOperation, or an iterable of them.
        self._curveOperations = []
        qt.QMainWindow.__init__(self, parent)

        self.xas_obj = None
        # (legend, kind) of the plot items drawn for the displayed spectrum
        self._drawnItems = set()

        self._plotWidget = Plot1D(parent=self)
        self._rawDataWidget = _RawDataList(parent=self)

        self._tabWidget = qt.QTabWidget(self)
        self._tabWidget.addTab(self._plotWidget, "plot")
        self._tabWidget.addTab(self._rawDataWidget, "data as text")
        self.setCentralWidget(self._tabWidget)

        # frame browsers
        dockWidget = qt.QDockWidget(self)
        frameBrowsers = qt.QWidget(parent=self)
        frameBrowsers.setLayout(qt.QVBoxLayout())
        frameBrowsers.layout().setContentsMargins(0, 0, 0, 0)

        self._dim1FrameBrowser = _ExtendedSliderWithBrowser(parent=self, name="dim 1")
        frameBrowsers.layout().addWidget(self._dim1FrameBrowser)
        self._dim2FrameBrowser = _ExtendedSliderWithBrowser(parent=self, name="dim 2")
        frameBrowsers.layout().addWidget(self._dim2FrameBrowser)
        dockWidget.setWidget(frameBrowsers)

        self.addDockWidget(qt.Qt.BottomDockWidgetArea, dockWidget)
        dockWidget.setAllowedAreas(qt.Qt.BottomDockWidgetArea)
        dockWidget.setFeatures(qt.QDockWidget.NoDockWidgetFeatures)

        # connect signal / slot
        self._dim1FrameBrowser.valueChanged.connect(self._updateSpectrumDisplayed)
        self._dim2FrameBrowser.valueChanged.connect(self._updateSpectrumDisplayed)

    def addCurveOperation(self, callbacks):
        """register an curve to display from Xasobject keys, and a legend

        :param callbacks: callback to call when displaying a specific curve
        :type: Union[list,tuple,function]
        """
        if isinstance(callbacks, (list, tuple)):
            for callback in callbacks:
                self.addCurveOperation(callback)
        else:
            self._curveOperations.append(callbacks)

    def clearCurveOperations(self):
        """Remove all defined curve operation"""
        self._curveOperations.clear()

    def setXasObject(self, xas_obj):
        """Set the XAS object to browse; None clears the viewer."""
        self.xas_obj = xas_obj
        if self.xas_obj is None:
            self.clear()
        else:
            self._dim1FrameBrowser.setRange(0, self.xas_obj.spectra.shape[0] - 1)
            self._dim2FrameBrowser.setRange(0, self.xas_obj.spectra.shape[1] - 1)
            self._updateSpectrumDisplayed()

    def _updateSpectrumDisplayed(self, *args, **kwargs):
        """Display the spectrum selected by the two frame browsers.

        Updates the raw data tab and draws the items produced by the curve
        operations. Items drawn for the previous spectrum that are not
        produced again are removed.
        """
        if self.xas_obj is None:
            return
        dim1_index = self._dim1FrameBrowser.value()
        dim2_index = self._dim2FrameBrowser.value()
        if dim1_index < 0 or dim2_index < 0:
            return

        spectrum = self.xas_obj.get_spectrum(dim1_index, dim2_index)
        # update raw data tab
        self._rawDataWidget.setData(x=spectrum.energy, y=spectrum.mu)
        # update plot tab
        drawn = set()
        for operation in self._curveOperations:
            for res in self._getOperationResults(operation, spectrum, dim1_index):
                item = self._displayResult(res)
                if item is not None:
                    drawn.add(item)
        # an operation may produce nothing for this spectrum (e.g. failed
        # processing): do not leave the previous spectrum's item displayed
        for legend, kind in self._drawnItems - drawn:
            self._plotWidget.remove(legend=legend, kind=kind)
        self._drawnItems = drawn

    def _getOperationResults(self, operation, spectrum, dim1_index):
        """Return the list of results produced by ``operation``.

        ``operation`` is applied to the displayed spectrum and, when silx
        supports curve baselines, to the whole XAS object with
        ``index=dim1_index``.
        """
        results = [operation(spectrum)]
        if silx_plot_has_baseline_feature is True:
            extra = operation(self.xas_obj, index=dim1_index)
            if isinstance(extra, (_CurveOperation, _XMarkerOperation)):
                results.append(extra)
            elif extra is not None:
                results.extend(extra)
        return results

    def _displayResult(self, res):
        """Draw a single operation result on the plot.

        :return: ``(legend, kind)`` of the drawn item, None if skipped
        :raises TypeError: if ``res`` is not a known operation type
        """
        # result can be None or nan if the processing fails. So in this
        # case we won't display anything
        if res is None or res.x is None:
            return None
        if isinstance(res.x, float) and numpy.isnan(res.x):
            return None
        if isinstance(res, _CurveOperation):
            curve_kwargs = {
                "x": res.x,
                "y": res.y,
                "legend": res.legend,
                "yaxis": res.yaxis,
                "linestyle": res.linestyle,
                "symbol": res.symbol,
                "color": res.color,
                "ylabel": res.ylabel,
            }
            if silx_plot_has_baseline_feature:
                # the baseline itself (None, number or array), not a tuple
                curve_kwargs["baseline"] = res.baseline
            legend = self._plotWidget.addCurve(**curve_kwargs)
            self._plotWidget.getCurve(legend).setAlpha(res.alpha)
            return res.legend, "curve"
        elif isinstance(res, _XMarkerOperation):
            self._plotWidget.addXMarker(x=res.x, color=res.color, legend=res.legend)
            return res.legend, "marker"
        else:
            raise TypeError("this type of operation is not " "recognized", type(res))

    def clear(self):
        """Clear the plot and the raw data table, reset the frame browsers."""
        self._plotWidget.clear()
        self._drawnItems = set()
        self._dim1FrameBrowser.setMaximum(-1)
        self._dim2FrameBrowser.setMaximum(-1)
        self._rawDataWidget.clear()


# colors of the mean curve and of the std envelope computed from a XAS object
COLOR_MEAN, COLOR_STD = "black", "grey"


def _plot_norm(obj, **kwargs):
    """Build the curve(s) displaying the normalized mu.

    :param obj: a XASObject (``index`` keyword required) or a Spectrum
    :return: for a XASObject, a tuple (mean curve, mean +/- std curve)
             computed from the "normalized_mu" volume sliced at ``index`` on
             its second axis and reduced over its last axis; for a Spectrum,
             a single curve (None if not normalized yet); None otherwise
    """
    if isinstance(obj, XASObject):
        assert "index" in kwargs
        volume = obj.spectra.map_to("normalized_mu", relative_to="energy")
        selection = volume[:, kwargs["index"], :]
        mean = numpy.mean(selection, axis=1)
        std = numpy.std(selection, axis=1)
        mean_curve = _CurveOperation(
            x=obj.energy, y=mean, color=COLOR_MEAN, legend="mean norm", alpha=0.5
        )
        std_curve = _CurveOperation(
            x=obj.energy,
            y=mean + std,
            baseline=mean - std,
            color=COLOR_STD,
            legend="std norm",
            alpha=0.5,
        )
        return mean_curve, std_curve
    if not isinstance(obj, Spectrum):
        return None
    if obj.normalized_mu is None:
        _logger.error("norm has not been computed yet")
        return None
    assert len(obj.energy) == len(obj.normalized_mu)
    return _CurveOperation(
        x=obj.energy, y=obj.normalized_mu, legend="norm", color="black"
    )


def _plot_norm_area(obj, **kwargs):
    """Build the orange "norm_area" curve of a Spectrum.

    :return: _CurveOperation, or None if ``obj`` is not a Spectrum or if
             ``norm_area`` is missing or None
    """
    if not isinstance(obj, Spectrum):
        return None
    # an attribute left to None cannot be plotted either
    norm_area = getattr(obj, "norm_area", None)
    if norm_area is None:
        _logger.error("norm_area has not been computed yet")
        return None
    assert len(obj.energy) == len(norm_area)
    return _CurveOperation(
        x=obj.energy, y=norm_area, legend="norm_area", color="orange"
    )


def _plot_mback_mu(obj, **kwargs):
    """Build the orange "mback_mu" curve of a Spectrum.

    :return: _CurveOperation, or None if ``obj`` is not a Spectrum or if
             ``mback_mu`` is missing or None
    """
    if not isinstance(obj, Spectrum):
        return None
    # an attribute left to None cannot be plotted either
    mback_mu = getattr(obj, "mback_mu", None)
    if mback_mu is None:
        _logger.error("mback_mu has not been computed yet")
        return None
    return _CurveOperation(
        x=obj.energy, y=mback_mu, legend="mback_mu", color="orange"
    )


def _plot_pre_edge(obj, **kwargs):
    """Build the green "pre edge" curve of a Spectrum.

    :return: _CurveOperation, or None if ``obj`` is not a Spectrum or if
             its pre edge has not been computed
    """
    if isinstance(obj, Spectrum):
        pre_edge = obj.pre_edge
        if pre_edge is not None:
            assert len(obj.energy) == len(pre_edge)
            return _CurveOperation(
                x=obj.energy, y=pre_edge, legend="pre edge", color="green"
            )
        _logger.error("pre_edge has not been computed yet")
    return None


def _plot_post_edge(obj, **kwargs):
    """Build the black "post edge" curve of a Spectrum.

    :return: _CurveOperation, or None if ``obj`` is not a Spectrum or if
             its post edge has not been computed
    """
    if isinstance(obj, Spectrum):
        post_edge = obj.post_edge
        if post_edge is not None:
            assert len(obj.energy) == len(post_edge)
            return _CurveOperation(
                x=obj.energy, y=post_edge, legend="post edge", color="black"
            )
        _logger.error("post_edge has not been computed yet")
    return None


def _plot_edge(obj, **kwargs):
    """Build the yellow vertical marker at the edge position ``e0``.

    :return: _XMarkerOperation, or None if ``obj`` is not a Spectrum or if
             ``e0`` is missing or None
    """
    if not isinstance(obj, Spectrum):
        return None
    # treat an e0 left to None like a missing one, as the other plot helpers do
    e0 = getattr(obj, "e0", None)
    if e0 is None:
        _logger.error("e0 has not been computed yet")
        return None
    return _XMarkerOperation(x=e0, legend="edge", color="yellow")


def _plot_raw(obj, **kwargs):
    """Build the curve(s) displaying the raw mu.

    :param obj: a Spectrum, or a XASObject (``index`` keyword required)
    :return: for a Spectrum, a red "mu" curve (None if mu is missing); for a
             XASObject, a tuple (mean mu curve, mean +/- std curve) computed
             from the "mu" volume sliced at ``index`` on its second axis and
             reduced over its last axis
    :raises ValueError: if ``obj`` is neither a Spectrum nor a XASObject
    """
    if isinstance(obj, Spectrum):
        if obj.mu is None:
            _logger.error("mu not existing")
            return None
        return _CurveOperation(x=obj.energy, y=obj.mu, legend="mu", color="red")
    elif isinstance(obj, XASObject):
        assert "index" in kwargs
        index_dim1 = kwargs["index"]
        # curves are labelled "mean mu" / "mu std": build them from mu
        # (normalized_mu is handled by _plot_norm)
        spectra = obj.spectra.map_to("mu", relative_to="energy")
        spectra = spectra[:, index_dim1, :]
        mean = numpy.mean(spectra, axis=1)
        std = numpy.std(spectra, axis=1)
        return (
            _CurveOperation(
                x=obj.energy, y=mean, color=COLOR_MEAN, legend="mean mu", alpha=0.5
            ),
            _CurveOperation(
                x=obj.energy,
                y=mean + std,
                baseline=mean - std,
                color=COLOR_STD,
                legend="mu std",
                alpha=0.5,
            ),
        )
    else:
        raise ValueError("input type is not managed")


def _plot_fpp(obj, **kwargs):
    """Build the blue "fpp" curve of a Spectrum.

    :return: _CurveOperation, or None if ``obj`` is not a Spectrum or if
             ``fpp`` is missing or None
    """
    if not isinstance(obj, Spectrum):
        return None
    # an attribute left to None cannot be plotted either
    fpp = getattr(obj, "fpp", None)
    if fpp is None:
        _logger.error("fpp has not been computed yet")
        return None
    return _CurveOperation(x=obj.energy, y=fpp, legend="fpp", color="blue")


def _plot_f2(obj, **kwargs):
    """Build the orange "f2" curve of a Spectrum.

    :return: _CurveOperation, or None if ``obj`` is not a Spectrum or if
             ``f2`` is missing or None
    """
    if not isinstance(obj, Spectrum):
        return None
    # an attribute left to None cannot be plotted either
    f2 = getattr(obj, "f2", None)
    if f2 is None:
        _logger.error("f2 has not been computed yet")
        return None
    return _CurveOperation(x=obj.energy, y=f2, legend="f2", color="orange")


def _plot_chir_mag(obj, **kwargs):
    """Build the curve of ``chir_mag`` versus ``r`` of a Spectrum.

    :return: _CurveOperation, or None if ``obj`` is not a Spectrum or if
             ``r`` / ``chir_mag`` is missing or None
    """
    if not isinstance(obj, Spectrum):
        return None
    # attributes left to None cannot be plotted either
    r = getattr(obj, "r", None)
    if r is None:
        _logger.error("r not computed, unable to display it")
        return None
    chir_mag = getattr(obj, "chir_mag", None)
    if chir_mag is None:
        _logger.error("chir_mag not computed, unable to display it")
        return None
    return _CurveOperation(x=r, y=chir_mag, legend="chi k (mag)")


def _plot_chir_re(obj, **kwargs):
    """Build the curve of ``chir_re`` versus ``r`` of a Spectrum.

    :return: _CurveOperation, or None if ``obj`` is not a Spectrum or if
             ``r`` / ``chir_re`` is missing or None
    """
    if not isinstance(obj, Spectrum):
        return None
    # attributes left to None cannot be plotted either
    r = getattr(obj, "r", None)
    if r is None:
        _logger.error("r not computed, unable to display it")
        return None
    chir_re = getattr(obj, "chir_re", None)
    if chir_re is None:
        _logger.error("chir_re not computed, unable to display it")
        return None
    return _CurveOperation(x=r, y=chir_re, legend="chi k (real)")


def _plot_chir_imag(obj, **kwargs):
    """Build the curve of ``chir_im`` versus ``r`` of a Spectrum.

    :return: _CurveOperation, or None if ``obj`` is not a Spectrum or if
             ``r`` / ``chir_im`` is missing or None
    """
    if not isinstance(obj, Spectrum):
        return None
    # attributes left to None cannot be plotted either
    r = getattr(obj, "r", None)
    if r is None:
        _logger.error("r not computed, unable to display it")
        return None
    chir_im = getattr(obj, "chir_im", None)
    if chir_im is None:
        _logger.error("chir_im not computed, unable to display it")
        return None
    return _CurveOperation(x=r, y=chir_im, legend="chi k (imag)")


def _plot_spectrum(obj, **kwargs):
    """Build the blue "spectrum" curve (mu versus energy) of a Spectrum.

    :return: _CurveOperation, or None if ``obj`` is not a Spectrum
    """
    if isinstance(obj, Spectrum):
        return _CurveOperation(
            x=obj.energy,
            y=obj.mu,
            legend="spectrum",
            yaxis=None,
            color="blue",
        )
    return None


def _plot_bkg(obj, **kwargs):
    """Build the red "bkg" (background) curve of a Spectrum.

    :return: _CurveOperation, or None if ``obj`` is not a Spectrum or if
             ``bkg`` is missing or None
    """
    if not isinstance(obj, Spectrum):
        return None
    # an attribute left to None cannot be plotted either
    bkg = getattr(obj, "bkg", None)
    if bkg is None:
        _logger.error("missing bkg parameter, unable to compute bkg plot")
        return None
    return _CurveOperation(x=obj.energy, y=bkg, legend="bkg", yaxis=None, color="red")


def _plot_knots(obj, **kwargs):
    """Build the green "knots" scatter (symbols, no line) of a Spectrum.

    :return: _CurveOperation from ``autobk_details.knots_e`` /
             ``autobk_details.knots_y``, or None if ``obj`` is not a
             Spectrum or if ``autobk_details`` is missing or None
    """
    if not isinstance(obj, Spectrum):
        return None
    # a None value would fail on the knots_e / knots_y attribute access
    autobk_details = getattr(obj, "autobk_details", None)
    if autobk_details is None:
        _logger.error("missing autobk_details parameter, unable to compute knots plot")
        return None
    return _CurveOperation(
        x=autobk_details.knots_e,
        y=autobk_details.knots_y,
        legend="knots",
        yaxis=None,
        color="green",
        linestyle="",
        symbol="o",
    )


def _exafs_signal_plot(obj, **kwargs):
    """Build the "EXAFSSignal" curve restricted to k in [KMin, KMax].

    Side effect: if ``KMin`` / ``KMax`` are not set on ``obj``, they are set
    to the min / max of ``EXAFSKValues``.

    :return: _CurveOperation, or None if ``obj`` is not a Spectrum or if a
             required key is missing
    """
    if not isinstance(obj, Spectrum):
        return None
    missing_keys = obj.get_missing_keys(("EXAFSKValues", "EXAFSSignal"))
    if len(missing_keys) > 0:
        # logging merges extra args with %-placeholders: without them the
        # record cannot be formatted and the message is lost
        _logger.error(
            "missing keys: %s, unable to compute exafs signal plot", missing_keys
        )
        return None

    k = obj["EXAFSKValues"]
    if "KMin" not in obj:
        obj["KMin"] = k.min()
    if "KMax" not in obj:
        obj["KMax"] = k.max()

    idx = (obj["EXAFSKValues"] >= obj["KMin"]) & (obj["EXAFSKValues"] <= obj["KMax"])
    x = obj["EXAFSKValues"][idx]
    y = obj["EXAFSSignal"][idx]
    return _CurveOperation(x=x, y=y, legend="EXAFSSignal")


def _exafs_postedge_plot(obj, **kwargs):
    """Build the "PostEdge" curve (``PostEdgeB``) for k in [KMin, KMax].

    Side effect: if ``KMin`` / ``KMax`` are not set on ``obj``, they are set
    to the min / max of ``EXAFSKValues``.

    :return: _CurveOperation, or None if ``obj`` is not a Spectrum or if a
             required key is missing
    """
    if not isinstance(obj, Spectrum):
        return None
    missing_keys = obj.get_missing_keys(("EXAFSKValues", "PostEdgeB"))
    if len(missing_keys) > 0:
        # logging merges extra args with %-placeholders: without them the
        # record cannot be formatted and the message is lost
        _logger.error(
            "missing keys: %s, unable to compute exafs postedge plot", missing_keys
        )
        return None

    k = obj["EXAFSKValues"]
    if "KMin" not in obj:
        obj["KMin"] = k.min()
    if "KMax" not in obj:
        obj["KMax"] = k.max()

    idx = (obj["EXAFSKValues"] >= obj["KMin"]) & (obj["EXAFSKValues"] <= obj["KMax"])
    x = obj["EXAFSKValues"][idx]
    y = obj["PostEdgeB"][idx]
    return _CurveOperation(x=x, y=y, legend="PostEdge")


def _exafs_knots_plot(obj, **kwargs):
    """Build the "Knots" scatter (symbols, no line) from KnotsX / KnotsY.

    :return: _CurveOperation, or None if ``obj`` is not a Spectrum or if a
             required key is missing
    """
    if not isinstance(obj, Spectrum):
        return None
    missing_keys = obj.get_missing_keys(("KnotsX", "KnotsY"))
    if len(missing_keys) > 0:
        # logging merges extra args with %-placeholders: without them the
        # record cannot be formatted and the message is lost
        _logger.error(
            "missing keys: %s, unable to compute exafs knot plot", missing_keys
        )
        return None
    x = obj["KnotsX"]
    y = obj["KnotsY"]
    return _CurveOperation(x=x, y=y, legend="Knots", linestyle="", symbol="o")


def _normalized_exafs(obj, **kwargs):
    """Build the "Normalized EXAFS" curve for k in [KMin, KMax].

    The y label reflects ``KWeight`` ("EXAFS Signal", "... * k" or
    "... * k^<KWeight>"). ``KWeight``, ``KMin`` and ``KMax`` are read from
    ``obj`` and are not checked for presence.

    :return: _CurveOperation, or None if ``obj`` is not a Spectrum or if a
             required key is missing
    """
    if not isinstance(obj, Spectrum):
        return None
    missing_keys = obj.get_missing_keys(("EXAFSKValues", "EXAFSNormalized"))
    if len(missing_keys) > 0:
        # logging merges extra args with %-placeholders: without them the
        # record cannot be formatted and the message is lost
        _logger.error(
            "missing keys: %s, unable to compute normalized EXAFS", missing_keys
        )
        return None

    if obj["KWeight"]:
        if obj["KWeight"] == 1:
            ylabel = "EXAFS Signal * k"
        else:
            ylabel = "EXAFS Signal * k^%d" % obj["KWeight"]
    else:
        ylabel = "EXAFS Signal"

    idx = (obj["EXAFSKValues"] >= obj["KMin"]) & (obj["EXAFSKValues"] <= obj["KMax"])
    return _CurveOperation(
        x=obj["EXAFSKValues"][idx],
        y=obj["EXAFSNormalized"][idx],
        legend="Normalized EXAFS",
        ylabel=ylabel,
    )


def _ft_window_plot(obj, **kwargs):
    """Build the red "FT Window" curve (``ft["WindowWeight"]`` versus
    ``ft["K"]``) on the right y axis.

    :return: _CurveOperation, or None if ``obj`` is not a Spectrum or if a
             required FT key is missing
    """
    if not isinstance(obj, Spectrum):
        return None
    missing_keys = obj.ft.get_missing_keys(("K", "WindowWeight"))
    if len(missing_keys) > 0:
        # logging merges extra args with %-placeholders: without them the
        # record cannot be formatted and the message is lost
        _logger.error(
            "missing keys: %s, unable to compute FT window plot", missing_keys
        )
        return None

    return _CurveOperation(
        x=obj.ft["K"],
        y=obj.ft["WindowWeight"],
        legend="FT Window",
        yaxis="right",
        color="red",
    )


def _ft_intensity_plot(obj, **kwargs):
    """Build the "FT Intensity" curve (``ft["FTIntensity"]`` versus
    ``ft["FTRadius"]``).

    :return: _CurveOperation, or None if ``obj`` is not a Spectrum or if a
             required FT key is missing
    """
    if not isinstance(obj, Spectrum):
        return None
    missing_keys = obj.ft.get_missing_keys(("FTRadius", "FTIntensity"))
    if len(missing_keys) > 0:
        # logging merges extra args with %-placeholders: without them the
        # record cannot be formatted and the message is lost
        _logger.error(
            "missing keys: %s, unable to compute FT intensity plot", missing_keys
        )
        return None
    return _CurveOperation(
        x=obj.ft["FTRadius"], y=obj.ft["FTIntensity"], legend="FT Intensity"
    )


def _ft_imaginary_plot(obj, **kwargs):
    """Build the red "FT Imaginary" curve (``ft["FTImaginary"]`` versus
    ``ft["FTRadius"]``).

    :return: _CurveOperation, or None if ``obj`` is not a Spectrum or if a
             required FT key is missing
    """
    if not isinstance(obj, Spectrum):
        return None
    missing_keys = obj.ft.get_missing_keys(("FTRadius", "FTImaginary"))
    if len(missing_keys) > 0:
        # logging merges extra args with %-placeholders: without them the
        # record cannot be formatted and the message is lost
        _logger.error(
            "missing keys: %s, unable to compute FT imaginary plot", missing_keys
        )
        return None
    return _CurveOperation(
        x=obj.ft["FTRadius"],
        y=obj.ft["FTImaginary"],
        legend="FT Imaginary",
        color="red",
    )