Commit abd39d24 authored by Valentin Valls's avatar Valentin Valls

Create a gaussian fit item

parent 74d1b345
......@@ -159,6 +159,76 @@ class DerivativeItem(
assert False
class GaussianFitData(NamedTuple):
xx: numpy.ndarray
yy: numpy.ndarray
fit: mathutils.GaussianFitResult
class GaussianFitItem(plot_model.AbstractComputableItem, plot_item_model.CurveMixIn):
"""This item use the scan data to process result before displaying it."""
def __reduce__(self):
return (self.__class__, (), self.__getstate__())
def __getstate__(self):
state = super(GaussianFitItem, self).__getstate__()
assert "y_axis" not in state
state["y_axis"] = self.yAxis()
return state
def __setstate__(self, state):
super(GaussianFitItem, self).__setstate__(state)
self.setYAxis(state.pop("y_axis"))
def isResultValid(self, result):
return result is not None
def compute(self, scan: scan_model.Scan) -> Optional[GaussianFitData]:
sourceItem = self.source()
xx = sourceItem.xArray(scan)
yy = sourceItem.yArray(scan)
if xx is None or yy is None:
return None
try:
fit = mathutils.fit_gaussian(xx, yy)
except Exception as e:
_logger.debug("Error while computing gaussian fit", exc_info=True)
result = GaussianFitData(numpy.array([]), numpy.array([]), None)
raise plot_model.ComputeError(
"Error while creating gaussian fit.\n" + str(e), result=result
)
yy = fit.transform(xx)
return GaussianFitData(xx, yy, fit)
def xData(self, scan: scan_model.Scan) -> Optional[scan_model.Data]:
result = self.reachResult(scan)
if not self.isResultValid(result):
return None
data = result.xx
return scan_model.Data(self, data)
def yData(self, scan: scan_model.Scan) -> Optional[scan_model.Data]:
result = self.reachResult(scan)
if not self.isResultValid(result):
return None
data = result.yy
return scan_model.Data(self, data)
def displayName(self, axisName, scan: scan_model.Scan) -> str:
"""Helper to reach the axis display name"""
sourceItem = self.source()
if axisName == "x":
return sourceItem.displayName("x", scan)
elif axisName == "y":
return "gaussian(%s)" % sourceItem.displayName("y", scan)
else:
assert False
class MaxData(NamedTuple):
max_index: int
max_location_y: float
......
......@@ -72,7 +72,8 @@ def fit_gaussian(xx: numpy.ndarray, yy: numpy.ndarray) -> GaussianFitResult:
ipos = numpy.argmax(yy)
pos = xx[ipos]
# FIXME: It would be good to provide a better guess for sigma
p0 = [pos, 1, height, background]
std = abs(xx[-1] - xx[0]) * 0.5
p0 = [pos, std, height, background]
# Distance to the target function
errfunc = lambda p, x, y: _gaussian(x, p) - y
......
......@@ -250,6 +250,13 @@ class _AddItemAction(qt.QWidgetAction):
action.setIcon(icon)
action.triggered.connect(self.__createDerivative)
menu.addAction(action)
action = qt.QAction(self)
action.setText("Gaussian fit")
icon = icons.getQIcon("flint:icons/item-stats")
action.setIcon(icon)
action.triggered.connect(self.__createGaussianFit)
menu.addAction(action)
else:
action = qt.QAction(self)
action.setText("No available items")
......@@ -292,6 +299,15 @@ class _AddItemAction(qt.QWidgetAction):
with plot.transaction():
plot.addItem(newItem)
def __createGaussianFit(self):
parentItem = self.__getSelectedPlotItem()
if parentItem is not None:
plot = parentItem.plot()
newItem = plot_state_model.GaussianFitItem(plot)
newItem.setSource(parentItem)
with plot.transaction():
plot.addItem(newItem)
class _DataItem(_property_tree_helper.ScanRowItem):
def __init__(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment