Commit f61df3bd authored by Valentin Valls's avatar Valentin Valls
Browse files

Refactor the default plot API

- plot_scatter is not using anymore the previous API
- Add extra tests
parent e51ff7cf
......@@ -138,7 +138,6 @@ The return values are shown in the following example:
from typing import List
import numpy
import functools
import time
from bliss import current_session, is_bliss_shell, global_map
......@@ -174,79 +173,211 @@ close_flint = flint_proxy.close_flint
restart_flint = flint_proxy.restart_flint
def _create_plot(
plot_class,
data=None,
name=None,
existing_id=None,
selected=True,
closeable=True,
**kwargs,
def plot_curve(
data=None, x=None, name=None, existing_id=None, selected=True, closeable=True
):
"""
Display `data` as curve
Arguments:
data: A list, a numpy array a dict containing numpy arrays, and
structured numpy arrays
x: If specified, used as x-axis
name: Name of the plot
existing_id: A unique name for the plot
selected: If true the plot will be selected when created
closeable: If true the plot will be closeable
"""
flint = flint_proxy.get_flint()
plot = flint.get_plot(
plot_class,
p = flint.get_plot(
"curve",
name=name,
unique_name=existing_id,
selected=selected,
closeable=closeable,
)
if data is not None:
plot.plot(data=data, **kwargs)
return plot
p.clear_data()
if isinstance(data, list):
p.add_curve(data, x, legend="value")
elif isinstance(data, dict):
if x is None:
x_from_args = False
x = data.get("x")
else:
x_from_args = True
for k, v in data.items():
if not x_from_args and k == "x":
# This key is used as x axis
continue
p.add_curve(x, v, legend=k)
elif isinstance(data, numpy.ndarray):
if data.dtype.fields is not None:
if x is None:
x_from_args = False
if "x" in data.dtype.fields:
x = data["x"]
else:
x_from_args = True
for k in data.dtype.fields.keys():
if not x_from_args and k == "x":
# This key is used as x axis
continue
v = data[k]
p.add_curve(x, v, legend=k)
else:
p.add_curve(x, data, legend="value")
return p
plot_curve = functools.partial(_create_plot, flint_plots.CurvePlot)
plot_scatter = functools.partial(_create_plot, flint_plots.ScatterPlot)
plot_image = functools.partial(_create_plot, flint_plots.ImagePlot)
plot_image_with_histogram = functools.partial(
_create_plot, flint_plots.HistogramImagePlot
)
plot_image_stack = functools.partial(_create_plot, flint_plots.ImageStackPlot)
def plot_scatter(
x, y, value, name=None, existing_id=None, selected=True, closeable=True
):
"""
Display `data` as scatter
def default_plot(data=None, **kwargs):
kwargs["data"] = data
# No data available
if data is None:
return plot_curve(**kwargs)
# Assume a dict of curves
if isinstance(data, dict):
return plot_curve(**kwargs)
data = numpy.array(data)
# Unstructured data
if data.dtype.fields is None:
# Assume a single curve
if data.ndim == 1:
return plot_curve(**kwargs)
# Assume a single image
if data.ndim == 2:
return plot_image(**kwargs)
# Assume a colored image
if data.ndim == 3 and data.shape[2] in (3, 4):
return plot_image(**kwargs)
# Assume an image stack
if data.ndim == 3:
return plot_image_stack(**kwargs)
Arguments:
x: A list or 1D numpy array with the x-axis values
y: A list or 1D numpy array with the y-axis values
value: A list or 1D numpy array with the intensity
name: Name of the plot
existing_id: A unique name for the plot
selected: If true the plot will be selected when created
closeable: If true the plot will be closeable
"""
flint = flint_proxy.get_flint()
p = flint.get_plot(
"scatter",
name=name,
unique_name=existing_id,
selected=selected,
closeable=closeable,
)
p.set_data(x, y, value)
return p
def plot_image(data=None, name=None, existing_id=None, selected=True, closeable=True):
"""
Display `data` as an image
Arguments:
data: A 2D numpy array
name: Name of the plot
existing_id: A unique name for the plot
selected: If true the plot will be selected when created
closeable: If true the plot will be closeable
"""
flint = flint_proxy.get_flint()
p = flint.get_plot(
"image",
name=name,
unique_name=existing_id,
selected=selected,
closeable=closeable,
)
if data is not None:
p.set_data(data)
return p
plot_image_with_histogram = plot_image
"""Compatibility with BLISS <= 1.8"""
def plot_image_stack(
data=None, name=None, existing_id=None, selected=True, closeable=True
):
"""
Display `data` as a stack of images
Arguments:
data: A 3D numpy array
name: Name of the plot
existing_id: A unique name for the plot
selected: If true the plot will be selected when created
closeable: If true the plot will be closeable
"""
flint = flint_proxy.get_flint()
p = flint.get_plot(
"imagestack",
name=name,
unique_name=existing_id,
selected=selected,
closeable=closeable,
)
if data is not None:
p.set_data(data)
return p
def _plot_from_dict(data, **kwargs):
"""Create a plot from a dict.
Assume each key is a 1D array.
If a `x` key is used, it will be used as x-axis
"""
return plot_curve(data, **kwargs)
def _plot_from_structured_array(data, **kwargs):
# Assume a single struct of curves
if data.ndim == 0:
return plot_curve(**kwargs)
return plot_curve(data, **kwargs)
# A list of struct
if data.ndim == 1:
# Assume multiple curves
if all(data[field].ndim == 1 for field in data.dtype.fields):
return plot_curve(**kwargs)
return plot_curve(data, **kwargs)
# Assume multiple plots
return tuple(
default_plot(data=data[field], **kwargs) for field in data.dtype.fields
)
# Not recognized
raise ValueError("Not recognized data")
return tuple(plot(data=data[field], **kwargs) for field in data.dtype.fields)
raise ValueError(
f"No plot representation for this numpy structured array (dim={data.ndim})"
)
# Alias
plot = default_plot
def _plot_from_array(data, **kwargs):
# Assume a single curve
if data.ndim == 1:
return plot_curve(data, **kwargs)
# Assume a single image
if data.ndim == 2:
return plot_image(data, **kwargs)
# Assume a colored image
if data.ndim == 3 and data.shape[2] in (3, 4):
return plot_image(data, **kwargs)
# Assume an image stack
if data.ndim == 3:
return plot_image_stack(data, **kwargs)
raise ValueError(
f"No plot representation for this numpy array data (dim={data.ndim})"
)
def plot(data=None, **kwargs):
# No data available
if data is None:
return plot_curve(data=None, **kwargs)
if isinstance(data, dict):
return _plot_from_dict(data=data, **kwargs)
data = numpy.array(data)
if data.dtype.fields is not None:
return _plot_from_structured_array(data=data, **kwargs)
return _plot_from_array(data=data, **kwargs)
# Alias
default_plot = plot
"""Compatibility with BLISS <= 1.8"""
### plotselect etc.
......
......@@ -270,13 +270,6 @@ class BasePlot(object):
"""Returns the current data range used by this plot"""
return self.__remote.get_data_range()
# Plotting
def plot(self, data, **kwargs):
fields = list(self.add_data(data))
names = fields[: self.DATA_INPUT_NUMBER]
self.select_data(*names, **kwargs)
# Clean up
def is_open(self) -> bool:
......@@ -452,30 +445,6 @@ class Plot1D(BasePlot):
# Data input number for a single representation
DATA_INPUT_NUMBER = 2
# Specialized x data handling
def plot(self, data, **kwargs):
# Add data
data_dict = self.add_data(data)
# Get x field
x = kwargs.pop("x", None)
x_field = x if isinstance(x, str) else "x"
# Get provided x
if x_field in data_dict:
x = data_dict[x_field]
# Get default x
elif x is None:
key = next(iter(data_dict))
length = len(data_dict[key])
x = numpy.arange(length)
# Add x data
if x is not None:
self.add_single_data(x_field, x)
# Plot all curves
for field in data_dict:
if field != x_field:
self.select_data(x_field, field, **kwargs)
def update_axis_marker(
self, unique_name: str, channel_name, position: float, text: str
):
......
......@@ -11,6 +11,15 @@ def test_plot_list(flint_session):
assert vrange[0:2] == [[0, 5], [0, 2]]
def test_scatter_plot(flint_session):
x = numpy.array([0, 1, 2, 0, 1, 2, 0, 1, 2])
y = numpy.array([0, 0, 0, 1, 1, 1, 2, 2, 2])
value = numpy.array([0, 1, 0, 1, 2, 1, 0, 1, 0])
p = plot.plot_scatter(x, y, value)
vrange = p.get_data_range()
assert vrange[0:2] == [[0, 2], [0, 2]]
def test_plot_numpy_1d(flint_session):
data = numpy.array([0, 1, 2, 0, 1, 2])
p = plot.plot(data)
......@@ -18,6 +27,23 @@ def test_plot_numpy_1d(flint_session):
assert vrange[0:2] == [[0, 5], [0, 2]]
def test_plot_structured_numpy_1d(flint_session):
dtype = [("x", float), ("v1", float), ("v2", float)]
data = numpy.array([(0, 5, 6), (1, 6, 5), (2, 5, 6), (3, 6, 5)], dtype=dtype)
p = plot.plot(data)
vrange = p.get_data_range()
assert vrange[0:2] == [[0, 3], [5, 6]]
def test_plot_dict(flint_session):
x = numpy.array([0, 1, 2, 3, 4, 5])
y = numpy.array([5, 6, 5, 6, 5, 6])
data = {"x": x, "y": y}
p = plot.plot(data)
vrange = p.get_data_range()
assert vrange[0:2] == [[0, 5], [5, 6]]
def test_plot_numpy_2d(flint_session):
data = numpy.arange(10 * 10)
data.shape = 10, 10
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment