Commit a3b83d57 authored by Valentin Valls's avatar Valentin Valls

Refactor statistics into a plot_state_model

parent 94f835ea
......@@ -19,6 +19,7 @@ from bliss.flint.model import scan_model
from bliss.flint.model import flint_model
from bliss.flint.model import plot_model
from bliss.flint.model import plot_item_model
from bliss.flint.model import plot_state_model
class DefaultStyleStrategy(plot_model.StyleStrategy):
......@@ -124,7 +125,7 @@ class DefaultStyleStrategy(plot_model.StyleStrategy):
if isinstance(item, plot_item_model.ScanItem):
continue
if isinstance(item, plot_model.AbstractComputableItem):
if isinstance(item, plot_item_model.CurveStatisticMixIn):
if isinstance(item, plot_state_model.CurveStatisticMixIn):
source = item.source()
baseStyle = self.getStyleFromItem(source, scan)
style = plot_model.Style(
......
......@@ -17,14 +17,11 @@ Here is a list of plot and item inheritance.
"""
from __future__ import annotations
from typing import Optional
from typing import Tuple
import numpy
import collections
from . import scan_model
from . import plot_model
from ..utils import mathutils
class CurvePlot(plot_model.Plot):
......@@ -97,15 +94,6 @@ class CurveMixIn:
return data.array()
class CurveStatisticMixIn:
"""This item use the scan data to process result before displaying it."""
def yAxis(self) -> str:
"""Returns the name of the y-axis in which the statistic have to be displayed"""
source = self.source()
return source.yAxis()
class CurveItem(plot_model.Item, CurveMixIn):
"""Define a curve as part of a plot.
......@@ -203,152 +191,6 @@ class CurveItem(plot_model.Item, CurveMixIn):
return data
class DerivativeItem(plot_model.AbstractComputableItem, CurveMixIn):
"""This item use the scan data to process result before displaying it."""
def __reduce__(self):
return (self.__class__, (), self.__getstate__())
def __getstate__(self):
state = super(DerivativeItem, self).__getstate__()
assert "y_axis" not in state
state["y_axis"] = self.yAxis()
return state
def __setstate__(self, state):
super(DerivativeItem, self).__setstate__(state)
self.setYAxis(state.pop("y_axis"))
def isResultValid(self, result):
return result is not None
def compute(
self, scan: scan_model.Scan
) -> Optional[Tuple[numpy.ndarray, numpy.ndarray]]:
sourceItem = self.source()
x = sourceItem.xData(scan)
y = sourceItem.yData(scan)
if x is None or y is None:
return None
x = x.array()
y = y.array()
if x is None or y is None:
return None
try:
result = mathutils.derivate(x, y)
except Exception as e:
# FIXME: Maybe it is better to return a special type and then return
# Managed outside to store it into the validation cache
scan.setCacheValidation(
self, self.version(), "Error while creating derivative.\n" + str(e)
)
return None
return result
def xData(self, scan: scan_model.Scan) -> Optional[scan_model.Data]:
result = self.reachResult(scan)
if not self.isResultValid(result):
return None
data = result[0]
return scan_model.Data(self, data)
def yData(self, scan: scan_model.Scan) -> Optional[scan_model.Data]:
result = self.reachResult(scan)
if not self.isResultValid(result):
return None
data = result[1]
return scan_model.Data(self, data)
MaxData = collections.namedtuple(
"MaxData",
["max_index", "max_location_y", "max_location_x", "min_y_value", "nb_points"],
)
class MaxCurveItem(plot_model.AbstractIncrementalComputableItem, CurveStatisticMixIn):
"""Implement a statistic which identify the maximum location of a curve."""
def isResultValid(self, result):
return result is not None
def setSource(self, source: plot_model.Item):
previousSource = self.source()
if previousSource is not None:
previousSource.valueChanged.disconnect(self.__sourceChanged)
plot_model.AbstractIncrementalComputableItem.setSource(self, source)
if source is not None:
source.valueChanged.connect(self.__sourceChanged)
self.__sourceChanged(plot_model.ChangeEventType.YAXIS)
def __sourceChanged(self, eventType):
if eventType == plot_model.ChangeEventType.YAXIS:
self.valueChanged.emit(plot_model.ChangeEventType.YAXIS)
def compute(self, scan: scan_model.Scan) -> Optional[MaxData]:
sourceItem = self.source()
xx = sourceItem.xArray(scan)
yy = sourceItem.yArray(scan)
if xx is None or yy is None:
return None
max_index = numpy.argmax(yy)
min_y_value = numpy.min(yy)
max_location_x, max_location_y = xx[max_index], yy[max_index]
result = MaxData(
max_index, max_location_y, max_location_x, min_y_value, len(xx)
)
return result
def incrementalCompute(
self, previousResult: MaxData, scan: scan_model.Scan
) -> MaxData:
sourceItem = self.source()
xx = sourceItem.xArray(scan)
yy = sourceItem.yArray(scan)
if xx is None or yy is None:
raise ValueError("Non empty data is expected")
nb = previousResult.nb_points
if nb == len(xx):
# obviously nothing to compute
return previousResult
xx = xx[nb:]
yy = yy[nb:]
max_index = numpy.argmax(yy)
min_y_value = numpy.min(yy)
max_location_x, max_location_y = xx[max_index], yy[max_index]
max_index = max_index + nb
if previousResult.min_y_value < min_y_value:
min_y_value = previousResult.min_y_value
if previousResult.max_location_y > max_location_y:
# Update and return the previous result
return MaxData(
previousResult.max_index,
previousResult.max_location_y,
previousResult.max_location_x,
min_y_value,
nb + len(xx),
)
# Update and new return the previous result
result = MaxData(
max_index, max_location_y, max_location_x, min_y_value, nb + len(xx)
)
return result
class McaPlot(plot_model.Plot):
"""Define a plot which is specific for MCAs."""
......
# -*- coding: utf-8 -*-
#
# This file is part of the bliss project
#
# Copyright (c) 2015-2019 Beamline Control Unit, ESRF
# Distributed under the GNU LGPLv3. See LICENSE for more info.
"""Contains implementation of concrete objects used to model plots.
It exists 4 kinds of plots: curves, scatter, image, MCAs. Each plot contains
specific items. But it is not a constraint from the architecture.
Here is a list of plot and item inheritance.
.. image:: _static/flint/model/plot_model_item.png
:alt: Scan model
:align: center
"""
from __future__ import annotations
from typing import Optional
from typing import NamedTuple
import numpy
from . import scan_model
from . import plot_model
from . import plot_item_model
from ..utils import mathutils
class CurveStatisticMixIn:
"""This item use the scan data to process result before displaying it."""
def yAxis(self) -> str:
"""Returns the name of the y-axis in which the statistic have to be displayed"""
source = self.source()
return source.yAxis()
class DerivativeItem(plot_model.AbstractComputableItem, plot_item_model.CurveMixIn):
"""This item use the scan data to process result before displaying it."""
def __reduce__(self):
return (self.__class__, (), self.__getstate__())
def __getstate__(self):
state = super(DerivativeItem, self).__getstate__()
assert "y_axis" not in state
state["y_axis"] = self.yAxis()
return state
def __setstate__(self, state):
super(DerivativeItem, self).__setstate__(state)
self.setYAxis(state.pop("y_axis"))
def isResultValid(self, result):
return result is not None
def compute(
self, scan: scan_model.Scan
) -> Optional[Tuple[numpy.ndarray, numpy.ndarray]]:
sourceItem = self.source()
x = sourceItem.xData(scan)
y = sourceItem.yData(scan)
if x is None or y is None:
return None
x = x.array()
y = y.array()
if x is None or y is None:
return None
try:
result = mathutils.derivate(x, y)
except Exception as e:
# FIXME: Maybe it is better to return a special type and then return
# Managed outside to store it into the validation cache
scan.setCacheValidation(
self, self.version(), "Error while creating derivative.\n" + str(e)
)
return None
return result
def xData(self, scan: scan_model.Scan) -> Optional[scan_model.Data]:
result = self.reachResult(scan)
if not self.isResultValid(result):
return None
data = result[0]
return scan_model.Data(self, data)
def yData(self, scan: scan_model.Scan) -> Optional[scan_model.Data]:
result = self.reachResult(scan)
if not self.isResultValid(result):
return None
data = result[1]
return scan_model.Data(self, data)
class MaxData(NamedTuple):
max_index: int
max_location_y: float
max_location_x: float
min_y_value: float
nb_points: int
class MaxCurveItem(plot_model.AbstractIncrementalComputableItem, CurveStatisticMixIn):
"""Implement a statistic which identify the maximum location of a curve."""
def isResultValid(self, result):
return result is not None
def setSource(self, source: plot_model.Item):
previousSource = self.source()
if previousSource is not None:
previousSource.valueChanged.disconnect(self.__sourceChanged)
plot_model.AbstractIncrementalComputableItem.setSource(self, source)
if source is not None:
source.valueChanged.connect(self.__sourceChanged)
self.__sourceChanged(plot_model.ChangeEventType.YAXIS)
def __sourceChanged(self, eventType):
if eventType == plot_model.ChangeEventType.YAXIS:
self.valueChanged.emit(plot_model.ChangeEventType.YAXIS)
def compute(self, scan: scan_model.Scan) -> Optional[MaxData]:
sourceItem = self.source()
xx = sourceItem.xArray(scan)
yy = sourceItem.yArray(scan)
if xx is None or yy is None:
return None
max_index = numpy.argmax(yy)
min_y_value = numpy.min(yy)
max_location_x, max_location_y = xx[max_index], yy[max_index]
result = MaxData(
max_index, max_location_y, max_location_x, min_y_value, len(xx)
)
return result
def incrementalCompute(
self, previousResult: MaxData, scan: scan_model.Scan
) -> MaxData:
sourceItem = self.source()
xx = sourceItem.xArray(scan)
yy = sourceItem.yArray(scan)
if xx is None or yy is None:
raise ValueError("Non empty data is expected")
nb = previousResult.nb_points
if nb == len(xx):
# obviously nothing to compute
return previousResult
xx = xx[nb:]
yy = yy[nb:]
max_index = numpy.argmax(yy)
min_y_value = numpy.min(yy)
max_location_x, max_location_y = xx[max_index], yy[max_index]
max_index = max_index + nb
if previousResult.min_y_value < min_y_value:
min_y_value = previousResult.min_y_value
if previousResult.max_location_y > max_location_y:
# Update and return the previous result
return MaxData(
previousResult.max_index,
previousResult.max_location_y,
previousResult.max_location_x,
min_y_value,
nb + len(xx),
)
# Update and new return the previous result
result = MaxData(
max_index, max_location_y, max_location_x, min_y_value, nb + len(xx)
)
return result
......@@ -16,6 +16,7 @@ from silx.gui import icons
from bliss.flint.model import scan_model
from bliss.flint.model import plot_model
from bliss.flint.model import plot_item_model
from bliss.flint.model import plot_state_model
class StandardRowItem(qt.QStandardItem):
......@@ -88,7 +89,7 @@ class ScanRowItem(StandardRowItem):
icon = icons.getQIcon("flint:icons/channel-curve")
elif isinstance(plotItem, plot_item_model.CurveMixIn):
icon = icons.getQIcon("flint:icons/item-func")
elif isinstance(plotItem, plot_item_model.CurveStatisticMixIn):
elif isinstance(plotItem, plot_state_model.CurveStatisticMixIn):
icon = icons.getQIcon("flint:icons/item-stats")
else:
icon = icons.getQIcon("flint:icons/item-channel")
......
......@@ -24,6 +24,7 @@ from bliss.flint.model import scan_model
from bliss.flint.model import flint_model
from bliss.flint.model import plot_model
from bliss.flint.model import plot_item_model
from bliss.flint.model import plot_state_model
from bliss.flint.widgets.extended_dock_widget import ExtendedDockWidget
from bliss.flint.widgets.plot_helper import FlintPlot
from bliss.flint.helper import scan_info_helper
......@@ -482,8 +483,8 @@ class CurvePlotWidget(ExtendedDockWidget):
plot._add(curveItem)
plotItems.append((legend, "curve"))
elif isinstance(item, plot_item_model.CurveStatisticMixIn):
if isinstance(item, plot_item_model.MaxCurveItem):
elif isinstance(item, plot_state_model.CurveStatisticMixIn):
if isinstance(item, plot_state_model.MaxCurveItem):
legend = str(item) + "/" + str(scan)
result = item.reachResult(scan)
if item.isResultValid(result):
......
......@@ -19,6 +19,7 @@ from silx.gui import icons
from bliss.flint.model import flint_model
from bliss.flint.model import plot_model
from bliss.flint.model import plot_item_model
from bliss.flint.model import plot_state_model
from bliss.flint.model import scan_model
from bliss.flint.helper import model_helper
from . import delegates
......@@ -426,7 +427,7 @@ class _DataItem(_property_tree_helper.ScanRowItem):
# self.__updateXAxisStyle(False, None)
useXAxis = False
self.__updateXAxisStyle(False)
elif isinstance(plotItem, plot_item_model.CurveStatisticMixIn):
elif isinstance(plotItem, plot_state_model.CurveStatisticMixIn):
useXAxis = False
self.__updateXAxisStyle(False)
......
......@@ -7,6 +7,7 @@ import pickle
import pytest
from bliss.flint.model import plot_item_model
from bliss.flint.model import plot_state_model
from bliss.flint.model import plot_model
from bliss.flint.model import style_model
from bliss.flint.helper import style_helper
......@@ -27,9 +28,9 @@ from bliss.flint.helper import style_helper
plot_item_model.ScatterItem,
plot_item_model.McaItem,
plot_item_model.ImageItem,
plot_item_model.DerivativeItem,
plot_item_model.CurveStatisticMixIn,
plot_item_model.MaxCurveItem,
plot_state_model.DerivativeItem,
plot_state_model.CurveStatisticMixIn,
plot_state_model.MaxCurveItem,
plot_item_model.MotorPositionMarker,
style_helper.DefaultStyleStrategy,
style_model.Style,
......
"""Testing plot item model."""
import numpy
from silx.gui import qt
from bliss.flint.model import plot_item_model
from bliss.flint.model import plot_state_model
from bliss.flint.model import plot_model
class CurveMock(qt.QObject, plot_item_model.CurveMixIn):
valueChanged = qt.Signal()
def __init__(self, xx: numpy.ndarray, yy: numpy.ndarray):
super(CurveMock, self).__init__()
self._xx = xx
self._yy = yy
def xArray(self, scan) -> numpy.ndarray:
return self._xx
def yArray(self, scan) -> numpy.ndarray:
return self._yy
def test_max_compute():
scan = None
yy = [0, -10, 2, 5, 9, 500, 100]
xx = numpy.arange(len(yy)) * 10
item = plot_item_model.MaxCurveItem()
curveItem = CurveMock(xx=xx, yy=yy)
item.setSource(curveItem)
result = item.compute(scan)
assert result.nb_points == len(xx)
assert result.max_index == 5
assert result.max_location_x == 50
assert result.max_location_y == 500
assert result.min_y_value == -10
def test_max_incremental_compute_1():
"""The result is part of the increment"""
scan = None
yy = [0, -10, 2, 5, 9, 500, 100]
xx = numpy.arange(len(yy)) * 10
item = plot_item_model.MaxCurveItem()
curveItem = CurveMock(xx=xx[: len(xx) // 2], yy=yy[: len(xx) // 2])
item.setSource(curveItem)
result = item.compute(scan)
curveItem = CurveMock(xx=xx, yy=yy)
item.setSource(curveItem)
result = item.incrementalCompute(result, scan)
assert result.nb_points == len(xx)
assert result.max_index == 5
assert result.max_location_x == 50
assert result.max_location_y == 500
assert result.min_y_value == -10
def test_max_incremental_compute_2():
"""The result is NOT part of the increment"""
scan = None
yy = [0, 10, 500, 5, 9, -10, 100]
xx = numpy.arange(len(yy)) * 10
item = plot_item_model.MaxCurveItem()
curveItem = CurveMock(xx=xx[: len(xx) // 2], yy=yy[: len(xx) // 2])
item.setSource(curveItem)
result = item.compute(scan)
curveItem = CurveMock(xx=xx, yy=yy)
item.setSource(curveItem)
result = item.incrementalCompute(result, scan)
assert result.nb_points == len(xx)
assert result.max_index == 2
assert result.max_location_x == 20
assert result.max_location_y == 500
assert result.min_y_value == -10
def test_picklable():
plot = plot_item_model.CurvePlot()
plot.setScansStored(True)
......@@ -92,12 +14,12 @@ def test_picklable():
item.setYChannel(plot_model.ChannelRef(None, "y"))
plot.addItem(item)
item2 = plot_item_model.DerivativeItem(plot)
item2 = plot_state_model.DerivativeItem(plot)
item2.setYAxis("right")
item2.setSource(item)
plot.addItem(item2)
item3 = plot_item_model.MaxCurveItem(plot)
item3 = plot_state_model.MaxCurveItem(plot)
item3.setSource(item2)
plot.addItem(item3)
import pickle
......
"""Testing plot state model."""
import numpy
from silx.gui import qt
from bliss.flint.model import plot_state_model
from bliss.flint.model import plot_item_model
class CurveMock(qt.QObject, plot_item_model.CurveMixIn):
valueChanged = qt.Signal()
def __init__(self, xx: numpy.ndarray, yy: numpy.ndarray):
super(CurveMock, self).__init__()
self._xx = xx
self._yy = yy
def xArray(self, scan) -> numpy.ndarray:
return self._xx
def yArray(self, scan) -> numpy.ndarray:
return self._yy
def test_max_compute():
scan = None
yy = [0, -10, 2, 5, 9, 500, 100]
xx = numpy.arange(len(yy)) * 10
item = plot_state_model.MaxCurveItem()
curveItem = CurveMock(xx=xx, yy=yy)
item.setSource(curveItem)
result = item.compute(scan)
assert result.nb_points == len(xx)
assert result.max_index == 5
assert result.max_location_x == 50
assert result.max_location_y == 500
assert result.min_y_value == -10
def test_max_incremental_compute_1():
"""The result is part of the increment"""
scan = None
yy = [0, -10, 2, 5, 9, 500, 100]
xx = numpy.arange(len(yy)) * 10
item = plot_state_model.MaxCurveItem()
curveItem = CurveMock(xx=xx[: len(xx) // 2], yy=yy[: len(xx) // 2])
item.setSource(curveItem)
result = item.compute(scan)
curveItem = CurveMock(xx=xx, yy=yy)
item.setSource(curveItem)
result = item.incrementalCompute(result, scan)
assert result.nb_points == len(xx)
assert result.max_index == 5
assert result.max_location_x == 50
assert result.max_location_y == 500
assert result.min_y_value == -10