Commit 1853223c authored by Matias Guijarro's avatar Matias Guijarro

Merge branch 'support-computed-items' into 'master'

Support computed items

See merge request !1987
parents e3586ca0 fa7cba25
Pipeline #21268 passed with stages
in 58 minutes and 33 seconds
......@@ -711,6 +711,16 @@ def anscan(
scan_info["start"] = starts_list
scan_info["stop"] = stops_list
requests = {}
for motor, start, stop in motor_tuple_list:
d = mot.position if scan_type == "dscan" else 0
requests[f"axis:{motor.name}"] = {
"start": start + d,
"stop": stop + d,
"points": npoints,
}
scan_info["requests"] = requests
_update_with_scan_display_meta(scan_info)
scan_params = dict()
......
......@@ -13,7 +13,6 @@ from typing import Optional
from typing import List
from typing import Dict
from typing import Tuple
from typing import Union
from silx.gui import colors
......@@ -501,7 +500,7 @@ def updateDisplayedChannelNames(
item, _updated = createCurveItem(plot, channel, yAxis="left")
else:
assert False
assert _updated == False
assert not _updated
item.setVisible(True)
for item in unneeded_items:
plot.removeItem(item)
......@@ -19,6 +19,7 @@ from bliss.flint.model import scan_model
from bliss.flint.model import flint_model
from bliss.flint.model import plot_model
from bliss.flint.model import plot_item_model
from bliss.flint.model import plot_state_model
class DefaultStyleStrategy(plot_model.StyleStrategy):
......@@ -119,23 +120,32 @@ class DefaultStyleStrategy(plot_model.StyleStrategy):
def computeItemStyleFromCurvePlot(self, plot, scans):
i = 0
countBase = 0
if len(scans) == 1:
countBase = 0
for item in plot.items():
if not isinstance(item, plot_model.ComputableMixIn):
countBase += 1
for scan in scans:
for item in plot.items():
if isinstance(item, plot_item_model.ScanItem):
continue
if isinstance(item, plot_model.AbstractComputableItem):
if isinstance(item, plot_item_model.CurveStatisticMixIn):
source = item.source()
baseStyle = self.getStyleFromItem(source, scan)
style = plot_model.Style(
lineStyle=":", lineColor=baseStyle.lineColor
)
if isinstance(item, plot_model.ComputableMixIn):
if countBase == 1:
# Allocate a new color for everything
color = self.pickColor(i)
i += 1
else:
# Reuse the color
source = item.source()
baseStyle = self.getStyleFromItem(source, scan)
style = plot_model.Style(
lineStyle="-.", lineColor=baseStyle.lineColor
)
color = baseStyle.lineColor
if isinstance(item, plot_state_model.CurveStatisticItem):
style = plot_model.Style(lineStyle=":", lineColor=color)
else:
style = plot_model.Style(lineStyle="-.", lineColor=color)
else:
color = self.pickColor(i)
style = plot_model.Style(lineStyle="-", lineColor=color)
......
......@@ -20,7 +20,6 @@ from bliss.config.conductor.client import get_default_connection
from bliss.config.conductor.client import get_redis_connection
from bliss.flint import config
import pickle
import logging
from silx.gui import qt
......@@ -262,7 +261,6 @@ class ManageMainBehaviours(qt.QObject):
updatePlotModel = enforceDisplay or not sameScan
else:
updatePlotModel = True
_logger.error(updatePlotModel)
if len(plots) > 0:
defaultPlot = plots[0]
......
......@@ -142,8 +142,8 @@ class ScanManager:
def exception_orrured(future_exception):
try:
future_exception.get()
except:
_logger.error("Error orrured in watch_session_scans", exc_info=True)
except Exception:
_logger.error("Error occurred in watch_session_scans", exc_info=True)
self._spawn_scans_session_watch(session_name, clean_redis=True)
task.link_exception(exception_orrured)
......
......@@ -17,14 +17,13 @@ Here is a list of plot and item inheritance.
"""
from __future__ import annotations
from typing import Optional
from typing import Tuple
from typing import Dict
from typing import Any
import numpy
import collections
from . import scan_model
from . import plot_model
from ..utils import mathutils
class CurvePlot(plot_model.Plot):
......@@ -71,10 +70,20 @@ class CurveMixIn:
def __init__(self):
self.__yAxis = "left"
def __getstate__(self):
state = {}
state["y_axis"] = self.yAxis()
return state
def __setstate__(self, state):
self.setYAxis(state.pop("y_axis"))
def yAxis(self) -> str:
return self.__yAxis
def setYAxis(self, yAxis: str):
if self.__yAxis == yAxis:
return
self.__yAxis = yAxis
self._emitValueChanged(plot_model.ChangeEventType.YAXIS)
......@@ -96,14 +105,9 @@ class CurveMixIn:
return None
return data.array()
class CurveStatisticMixIn:
"""This item use the scan data to process result before displaying it."""
def yAxis(self) -> str:
"""Returns the name of the y-axis in which the statistic have to be displayed"""
source = self.source()
return source.yAxis()
def displayName(self, axisName, scan: scan_model.Scan) -> str:
"""Helper to reach the axis display name"""
raise NotImplementedError()
class CurveItem(plot_model.Item, CurveMixIn):
......@@ -113,29 +117,26 @@ class CurveItem(plot_model.Item, CurveMixIn):
"""
def __init__(self, parent: plot_model.Plot = None):
super(CurveItem, self).__init__(parent=parent)
plot_model.Item.__init__(self, parent=parent)
CurveMixIn.__init__(self)
self.__x: Optional[plot_model.ChannelRef] = None
self.__y: Optional[plot_model.ChannelRef] = None
self.__yAxis: str = "left"
def __reduce__(self):
return (self.__class__, (), self.__getstate__())
def __getstate__(self):
state = super(CurveItem, self).__getstate__()
state: Dict[str, Any] = {}
state.update(plot_model.Item.__getstate__(self))
state.update(CurveMixIn.__getstate__(self))
assert "x" not in state
assert "y" not in state
assert "y-axis" not in state
state["x"] = self.__x
state["y"] = self.__y
state["y_axis"] = self.__yAxis
return state
def __setstate__(self, state):
super(CurveItem, self).__setstate__(state)
plot_model.Item.__setstate__(self, state)
CurveMixIn.__setstate__(self, state)
self.__x = state.pop("x")
self.__y = state.pop("y")
self.__yAxis = state.pop("y_axis")
def isValid(self):
return self.__x is not None and self.__y is not None
......@@ -175,15 +176,6 @@ class CurveItem(plot_model.Item, CurveMixIn):
self.__y = channel
self._emitValueChanged(plot_model.ChangeEventType.Y_CHANNEL)
def yAxis(self) -> str:
return self.__yAxis
def setYAxis(self, yAxis: str):
if self.__yAxis == yAxis:
return
self.__yAxis = yAxis
self._emitValueChanged(plot_model.ChangeEventType.YAXIS)
def xData(self, scan: scan_model.Scan) -> Optional[scan_model.Data]:
channel = self.xChannel()
if channel is None:
......@@ -202,151 +194,14 @@ class CurveItem(plot_model.Item, CurveMixIn):
return None
return data
class DerivativeItem(plot_model.AbstractComputableItem, CurveMixIn):
"""This item use the scan data to process result before displaying it."""
def __reduce__(self):
return (self.__class__, (), self.__getstate__())
def __getstate__(self):
state = super(DerivativeItem, self).__getstate__()
assert "y_axis" not in state
state["y_axis"] = self.yAxis()
return state
def __setstate__(self, state):
super(DerivativeItem, self).__setstate__(state)
self.setYAxis(state.pop("y_axis"))
def isResultValid(self, result):
return result is not None
def compute(
self, scan: scan_model.Scan
) -> Optional[Tuple[numpy.ndarray, numpy.ndarray]]:
sourceItem = self.source()
x = sourceItem.xData(scan)
y = sourceItem.yData(scan)
if x is None or y is None:
return None
x = x.array()
y = y.array()
if x is None or y is None:
return None
try:
result = mathutils.derivate(x, y)
except Exception as e:
# FIXME: Maybe it is better to return a special type and then return
# Managed outside to store it into the validation cache
scan.setCacheValidation(
self, self.version(), "Error while creating derivative.\n" + str(e)
)
return None
return result
def xData(self, scan: scan_model.Scan) -> Optional[scan_model.Data]:
result = self.reachResult(scan)
if not self.isResultValid(result):
return None
data = result[0]
return scan_model.Data(self, data)
def yData(self, scan: scan_model.Scan) -> Optional[scan_model.Data]:
result = self.reachResult(scan)
if not self.isResultValid(result):
return None
data = result[1]
return scan_model.Data(self, data)
MaxData = collections.namedtuple(
"MaxData",
["max_index", "max_location_y", "max_location_x", "min_y_value", "nb_points"],
)
class MaxCurveItem(plot_model.AbstractIncrementalComputableItem, CurveStatisticMixIn):
"""Implement a statistic which identify the maximum location of a curve."""
def isResultValid(self, result):
return result is not None
def setSource(self, source: plot_model.Item):
previousSource = self.source()
if previousSource is not None:
previousSource.valueChanged.disconnect(self.__sourceChanged)
plot_model.AbstractIncrementalComputableItem.setSource(self, source)
if source is not None:
source.valueChanged.connect(self.__sourceChanged)
self.__sourceChanged(plot_model.ChangeEventType.YAXIS)
def __sourceChanged(self, eventType):
if eventType == plot_model.ChangeEventType.YAXIS:
self.valueChanged.emit(plot_model.ChangeEventType.YAXIS)
def compute(self, scan: scan_model.Scan) -> Optional[MaxData]:
sourceItem = self.source()
xx = sourceItem.xArray(scan)
yy = sourceItem.yArray(scan)
if xx is None or yy is None:
return None
max_index = numpy.argmax(yy)
min_y_value = numpy.min(yy)
max_location_x, max_location_y = xx[max_index], yy[max_index]
result = MaxData(
max_index, max_location_y, max_location_x, min_y_value, len(xx)
)
return result
def incrementalCompute(
self, previousResult: MaxData, scan: scan_model.Scan
) -> MaxData:
sourceItem = self.source()
xx = sourceItem.xArray(scan)
yy = sourceItem.yArray(scan)
if xx is None or yy is None:
raise ValueError("Non empty data is expected")
nb = previousResult.nb_points
if nb == len(xx):
# obviously nothing to compute
return previousResult
xx = xx[nb:]
yy = yy[nb:]
max_index = numpy.argmax(yy)
min_y_value = numpy.min(yy)
max_location_x, max_location_y = xx[max_index], yy[max_index]
max_index = max_index + nb
if previousResult.min_y_value < min_y_value:
min_y_value = previousResult.min_y_value
if previousResult.max_location_y > max_location_y:
# Update and return the previous result
return MaxData(
previousResult.max_index,
previousResult.max_location_y,
previousResult.max_location_x,
min_y_value,
nb + len(xx),
)
# Update and new return the previous result
result = MaxData(
max_index, max_location_y, max_location_x, min_y_value, nb + len(xx)
)
return result
def displayName(self, axisName, scan: scan_model.Scan) -> str:
"""Helper to reach the axis display name"""
if axisName == "x":
return self.xChannel().displayName(scan)
elif axisName == "y":
return self.yChannel().displayName(scan)
else:
assert False
class McaPlot(plot_model.Plot):
......
......@@ -26,13 +26,14 @@ of the implementation.
"""
from __future__ import annotations
from typing import Tuple
from typing import List
from typing import Any
from typing import Dict
from typing import Optional
import numpy
import enum
import logging
import contextlib
from silx.gui import qt
......@@ -41,6 +42,9 @@ from .style_model import Style
from . import style_model
_logger = logging.getLogger(__name__)
class ChangeEventType(enum.Enum):
"""Enumerate the list of attributes which can emit a change event."""
......@@ -92,7 +96,7 @@ class Plot(qt.QObject):
def __init__(self, parent=None):
super(Plot, self).__init__(parent=parent)
self.__items: List[Item] = []
self.__styleStrategy: StyleStrategy = None
self.__styleStrategy: Optional[StyleStrategy] = None
self.__inTransaction: int = 0
def __reduce__(self):
......@@ -131,7 +135,7 @@ class Plot(qt.QObject):
self.__inTransaction += 1
self.transactionStarted.emit()
try:
yield
yield self
finally:
self.__inTransaction -= 1
self.transactionFinished.emit()
......@@ -151,11 +155,12 @@ class Plot(qt.QObject):
def removeItem(self, item: Item):
items = self.__itemTree(item)
for i in items:
item._setPlot(None)
self.__items.remove(i)
for i in items:
self.itemRemoved.emit(i)
with self.transaction():
for i in items:
item._setPlot(None)
self.__items.remove(i)
for i in items:
self.itemRemoved.emit(i)
self.invalidateStructure()
def items(self) -> List[Item]:
......@@ -383,26 +388,32 @@ _NotComputed = object()
"""Allow to flag an attribute as not computed"""
class AbstractComputableItem(Item):
"""This item use the scan data to process result before displaying it."""
class ComputeError(Exception):
"""Raised when the `compute` method of ComputableMixIn or
IncrementalComputableMixIn can't compute any output"""
resultAvailable = qt.Signal(object)
def __init__(self, msg: str, result=None):
super(ComputeError, self).__init__(self, msg)
self.msg = msg
self.result = result
class ChildItem(Item):
"""An item with a source"""
def __init__(self, parent=None):
Item.__init__(self, parent=parent)
super(ChildItem, self).__init__(parent=parent)
self.__source: Optional[Item] = None
def __reduce__(self):
return (self.__class__, (), self.__getstate__())
def __getstate__(self):
state = super(AbstractComputableItem, self).__getstate__()
state: Dict[str, Any] = {}
state.update(Item.__getstate__(self))
assert "source" not in state
state["source"] = self.__source
return state
def __setstate__(self, state):
super(AbstractComputableItem, self).__setstate__(state)
Item.__setstate__(self, state)
self.__source = state.pop("source")
def isChildOf(self, parent: Item) -> bool:
......@@ -420,17 +431,48 @@ class AbstractComputableItem(Item):
def source(self) -> Optional[Item]:
return self.__source
class ComputableMixIn:
"""This item use the scan data to process result before displaying it."""
resultAvailable = qt.Signal(object)
def inputData(self) -> Any:
"""Needed to invalidate the data according to the configuration"""
return None
def isResultComputed(self, scan: scan_model.Scan) -> bool:
return scan.hasCachedResult(self)
def reachResult(self, scan: scan_model.Scan):
# FIXME: implement an asynchronous the cache system
# FIXME: cache system have to be invalidated when self config changes
if scan.hasCachedResult(self):
result = scan.getCachedResult(self)
key = (self, self.inputData())
if scan.hasCachedResult(key):
result = scan.getCachedResult(key)
else:
result = self.compute(scan)
scan.setCachedResult(self, result)
try:
result = self.compute(scan)
except ComputeError as e:
try:
# FIXME: This messages should be stored at the same place
scan.setCacheValidation(self, self.version(), e.msg)
except KeyError:
_logger.error(
"Computation message lost: %s, %s, %s",
self,
self.version(),
e.msg,
)
result = e.result
except Exception as e:
scan.setCacheValidation(
self, self.version(), "Error while computing:" + str(e)
)
result = None
scan.setCachedResult(key, result)
if not self.isResultValid(result):
return None
return result
......@@ -442,7 +484,7 @@ class AbstractComputableItem(Item):
raise NotImplementedError()
class AbstractIncrementalComputableItem(AbstractComputableItem):
class IncrementalComputableMixIn(ComputableMixIn):
def incrementalCompute(self, previousResult: Any, scan: scan_model.Scan) -> Any:
"""Compute a data using the previous value as basis"""
raise NotImplementedError()
......
This diff is collapsed.
<?xml version="1.0" encoding="UTF-8"?>
<!-- Created with Inkscape (http://www.inkscape.org/) -->
<svg id="svg6" version="1.1" viewBox="0 0 32 32" xmlns="http://www.w3.org/2000/svg" xmlns:cc="http://creativecommons.org/ns#" xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#">
<metadata id="metadata2">
<rdf:RDF>
<cc:Work rdf:about="">
<dc:format>image/svg+xml</dc:format>
<dc:type rdf:resource="http://purl.org/dc/dcmitype/StillImage"/>
<dc:title/>
</cc:Work>
</rdf:RDF>
</metadata>
<path id="path4" d="m4.0413 27.281c8.6016-0.008824 3.1105-23.891 11.959-23.917 8.8482-0.026233 3.2362 23.908 11.959 23.917" fill="none" stroke="#00a14b" stroke-width="2.1827"/>
<g id="g843" transform="translate(-34.36 -7.0998)">
<path id="path4693-1" d="m50.401 10.659v26.711" fill="none" stroke="#00a14b" stroke-dasharray="2.37510014, 2.37510014" stroke-width="2.3751"/>
<g id="g899" transform="translate(39.007 -3.5796)" stroke="#000">
<g id="g888" transform="translate(-3.2784 2.6943)" fill="none" stroke="#000" stroke-width=".15">
<g id="g865" transform="translate(.8631 13.263)" fill="#000" stroke="#000" stroke-miterlimit="10" stroke-width=".5" style="font-variant-east_asian:normal">
<rect id="rect861" x="24.483" y="7.225" width="1.239" height="8.379" style="font-variant-east_asian:normal"/>
<rect id="rect863" x="20.913" y="10.796" width="8.38" height="1.237" style="font-variant-east_asian:normal"/>
</g>
</g>
</g>
</g>
</svg>
......@@ -308,7 +308,7 @@ class AcquisitionSimulator(qt.QObject):
return height * stepData
pos = numpy.random.rand() * (nbPoints1 // 2) + nbPoints1 // 4
height = 5 * numpy.random.rand() * 5
height = 5 + numpy.random.rand() * 5
stepData = (
step(pos, nbPoints1, 6, height=height) + numpy.random.random(nbPoints1) * 1
)
......@@ -330,6 +330,92 @@ class AcquisitionSimulator(qt.QObject):
) + 0.3 * numpy.random.random(nbPoints2)
scan.registerData(3, device3_channel1, data)
def __createSlit(self, scan: _VirtualScan, interval, duration, includeMasters=True):
master_time1 = scan_model.Device(scan.scan())
master_time1.setName("timer")
master_time1_index = scan_model.Channel(master_time1)
master_time1_index.setName("timer:elapsed_time")
device1 = scan_model.Device(scan.scan())
device1.setName("dev1")
device1.setMaster(master_time1)
device1_channel1 = scan_model.Channel(device1)
device1_channel1.setName("dev1:sy")
device2 = scan_model.Device(scan.scan())
device2.setName("dev2")
device2.setMaster(master_time1)
device2_channel1 = scan_model.Channel(device2)
device2_channel1.setName("dev2:diode1")
device2_channel2 = scan_model.Channel(device2)
device2_channel2.setName("dev2:diode2")
scan_info = {
"display_names": {},
"master": {
"display_names": {},
"images": [],
"scalars": [],
"scalars_units": {},
"spectra": [],
},
"scalars": [
device2_channel1.name(),
device2_channel2.name(),
master_time1_index.name(),
],
"scalars_units": {},
}
start, stop = -10, 20
if includeMasters:
scan_info["master"]["scalars"].append(device1_channel1.name())
scan_info["master"]["scalars_units"][device1_channel1.name()] = "mm"
else:
scan_info["scalars"].append(device1_channel1.name())
scan_info["scalars_units"][device1_channel1.name()] = "mm"
scan.scan_info["acquisition_chain"][master_time1.name()] = scan_info
requests = {}
requests[device1_channel1.name()] = {"start": start, "stop": stop}
scan.scan_info["requests"] = requests
# Every 2 ticks
nbPoints1 = (duration // interval) // 2
index1 = numpy.linspace(0, duration, nbPoints1)
def step(position, nbPoints, gaussianStd, height=1):
gaussianSize = int(gaussianStd) * 10
gaussianData = scipy.signal.gaussian(gaussianSize, gaussianStd)
stepData = numpy.zeros(len(index1) + gaussianSize)
stepData[int(position) :] = 1
stepData = scipy.signal.convolve(stepData, gaussianData, mode="same")[
0:nbPoints
]
stepData *= 1 / stepData[-1]
return height * stepData
pos = numpy.random.rand() * (nbPoints1 // 2) + nbPoints1 // 4
height = 5 + numpy.random.rand() * 5
stepData = (
step(pos, nbPoints1, 6, height=height) + numpy.random.random(nbPoints1) * 1
)
gaussianData = (
scipy.signal.gaussian(nbPoints1, 6) * height
+ numpy.random.random(nbPoints1) * 1
)
motorData = (
numpy.linspace(start, stop, nbPoints1)
+ numpy.random.random(nbPoints1) * 0.2
)
scan.registerData(2, master_time1_index, index1)
scan.registerData(2, device1_channel1, motorData)
scan.registerData(2, device2_channel1, stepData)
scan.registerData(2, device2_channel2, gaussianData)
def __createMcas(self, scan: _VirtualScan, interval, duration):
master_time1 = scan_model.Device(scan.scan())
master_time1.setName("timer_mca")
......@@ -589,6 +675,8 @@ class AcquisitionSimulator(qt.QObject):
if name is None or name == "counter":
self.__createCounters(scan, interval, duration)
elif name is None or name == "slit":
self.__createSlit(scan, interval, duration)
elif name == "counter-no-master":
self.__createCounters(scan, interval, duration, includeMasters=False)
if name is None or name == "mca":
......
......@@ -36,6 +36,11 @@ class SimulatorWidget(qt.QMainWindow):
button.clicked.connect(lambda: self.__startScan(10, 2000, "counter"))
layout.addWidget(button)
button = qt.QPushButton(self)
button.setText("Slit scan")
button.clicked.connect(lambda: self.__startScan(10, 2000, "slit"))
layout.addWidget(button)
button = qt.QPushButton(self)
button.setText("Counter scan (no masters)")
button.clicked.connect(lambda: self.__startScan(10, 2000, "counter-no-master"))
......