Commit c951ffc0 authored by Wout De Nolf's avatar Wout De Nolf
Browse files

add a different autofilter chain patching approach to solve lima issues

parent cf956a88
Pipeline #51804 passed with stages
in 102 minutes and 40 seconds
This diff is collapsed.
# This file is part of the bliss project
#
# Copyright (c) 2015-2020 Beamline Control Unit, ESRF
# Distributed under the GNU LGPLv3. See LICENSE for more info.
"""
Module to manage scan with automatic filter.
Yaml config may look like this:
- plugin: bliss
class: AutoFilter
name: autof_eh1
package: bliss.common.auto_filter
detector_counter_name: roi1
monitor_counter_name: mon
min_count_rate: 20000
max_count_rate: 50000
energy_axis: $eccmono
filterset: $filtW1
# optionnal parameters
always_back: True
counters:
- counter_name: curratt
tag: fiteridx
- counter_name: transm
tag: transmission
- counter_name: ratio
tag: ratio
suffix_for_corr_counter: "_corr"
counters_for_correction:
- det
- apdcnt
"""
from tabulate import tabulate
from bliss.config.beacon_object import BeaconObject
from bliss.config import static
from bliss.common.event import connect, disconnect
from bliss.common.measurementgroup import _get_counters_from_names
from bliss.common.counter import SamplingCounter
from bliss.common.utils import autocomplete_property
from bliss import global_map
from bliss.common.session import get_current_session
from bliss.common.axis import Axis
from bliss.common.types import _countable
from bliss.common.protocols import counter_namespace
from bliss.common.auto_filter.filterset import FilterSet
from bliss.common.auto_filter.counters import FilterSetCounterController
from bliss.common.auto_filter.counters import AutoFilterCalcCounterController
def _unmarshalling_energy_axis(_, value):
if isinstance(value, str):
config = static.get_config()
return config.get(value)
else:
return value
def _marshalling_energy_axis(_, value):
return value.name
class AutoFilter(BeaconObject):
detector_counter_name = BeaconObject.property_setting(
"detector_counter_name", doc="Detector counter name"
)
monitor_counter_name = BeaconObject.property_setting(
"monitor_counter_name", doc="Monitor counter name"
)
@detector_counter_name.setter
def detector_counter_name(self, counter_name):
assert isinstance(counter_name, str)
return counter_name
@monitor_counter_name.setter
def monitor_counter_name(self, counter_name):
assert isinstance(counter_name, str)
return counter_name
@property
def detector_counter(self):
return self.__counter_getter(self.detector_counter_name)
@detector_counter.setter
def detector_counter(self, counter):
self.detector_counter_name = self.__counter_setter(counter)
@property
def monitor_counter(self):
return self.__counter_getter(self.monitor_counter_name)
@monitor_counter.setter
def monitor_counter(self, counter):
self.monitor_counter_name = self.__parse_counter_setter_value(counter)
def __counter_getter(self, counter_name):
if not counter_name:
raise RuntimeError("Counter missing from configuration")
counters, missing = _get_counters_from_names([counter_name])
if missing:
raise RuntimeError(f"Counter {repr(counter_name)} does not exist")
return counters[0]
def __counter_setter(self, counter):
if isinstance(counter, str):
# check that counter exists ... not sure if the next lines work in all cases
try:
global_map.get_counter_from_fullname(counter)
return counter
except AttributeError:
raise RuntimeError(f"Counter {repr(counter)} does not exist") from None
elif isinstance(counter, _countable):
return counter.fullname
else:
raise RuntimeError(f"Unknown counter {counter}")
min_count_rate = BeaconObject.property_setting(
"min_count_rate",
must_be_in_config=True,
doc="Minimum allowed count rate on monitor",
)
@min_count_rate.setter
def min_count_rate(self, value):
self.__filterset_needs_sync = True
max_count_rate = BeaconObject.property_setting(
"max_count_rate",
must_be_in_config=True,
doc="Maximum allowed count rate on monitor",
)
@max_count_rate.setter
def max_count_rate(self, value):
self.__filterset_needs_sync = True
always_back = BeaconObject.property_setting(
"always_back",
must_be_in_config=False,
default=True,
doc="Always move back the filter to the original position at the end of the scan",
)
corr_suffix = BeaconObject.property_setting(
"corr_suffix",
must_be_in_config=False,
default="_corr",
doc="suffix to be added to the corrected counters",
)
filterset = BeaconObject.config_obj_property_setting(
"filterset", doc="filterset to attached to the autofilter"
)
@filterset.setter
def filterset(self, new_filterset):
assert isinstance(new_filterset, FilterSet)
self.__filterset_needs_sync = True
# as this is a config_obj_property_setting
# the setter has to return the name of the
# corresponding beacon object
return new_filterset
def __init__(self, name, config):
super().__init__(config, share_hardware=False)
global_map.register(self, tag=self.name, parents_list=["counters"])
self.__create_counters(config)
self.__counters_for_corr = set()
self.counters_for_correction = config.get("counters_for_correction", [])
self.__filterset_is_synchronized = False
self.__filterset_needs_sync = True
self._max_nb_iter = None
def _set_energy_changed(self, new_energy):
self.__filterset_needs_sync = True
def __close__(self):
energy_axis = self.energy_axis
if energy_axis is not None:
disconnect(energy_axis, "position", self._set_energy_changed)
def synchronize_filterset(self):
if not self.__filterset_needs_sync:
return
filterset = self.filterset
# Synchronize the filterset with countrate range and energy
# and tell it to store back filter if necessary
energy = self.energy_axis.position
if energy <= 0:
unit = self.energy_axis.unit
raise RuntimeError(f"The current energy is not valid: {energy} {unit}")
# filterset sync. method return the maximum effective number of filters
# which will correspond to the maximum number of filter changes
self._max_nb_iter = filterset.sync(
self.min_count_rate, self.max_count_rate, energy, self.always_back
)
self.__filterset_needs_sync = False
self.__filterset_is_synchronized = True
@property
def max_nb_iter(self):
self.synchronize_filterset()
return self._max_nb_iter
def maximum_number_of_tries(self, scan_npoints):
# only add twice max number of filter iteration to the total nb points
# to be programed to counter devices.
return scan_npoints + (4 * self.max_nb_iter)
energy_axis = BeaconObject.property_setting(
"energy_axis",
must_be_in_config=True,
set_marshalling=_marshalling_energy_axis,
set_unmarshalling=_unmarshalling_energy_axis,
)
@energy_axis.setter
def energy_axis(self, energy_axis):
previous_energy_axis = self.energy_axis
if self._in_initialize_with_setting or energy_axis != previous_energy_axis:
if isinstance(energy_axis, Axis):
if previous_energy_axis is not None:
disconnect(
previous_energy_axis, "position", self._set_energy_changed
)
connect(energy_axis, "position", self._set_energy_changed)
self._set_energy_changed(energy_axis.position)
else:
raise ValueError(f"{energy_axis} is not a Bliss Axis")
@property
def counters_for_correction(self):
"""These counters will have an additional correction counter
"""
return list(self.__counters_for_corr)
@counters_for_correction.setter
def counters_for_correction(self, counters):
if not isinstance(counters, list):
counters = list(counters)
# The monitor counter is the default, remove missing counters.
cnts, missing = _get_counters_from_names(counters)
for cnt in cnts:
self.__counters_for_corr.add(cnt.fullname)
@autocomplete_property
def counters(self):
counters = []
if self.filterset_counter_controller is not None:
counters += list(self.filterset_counter_controller.counters)
if self.calc_counter_controller is not None:
counters += list(self.calc_counter_controller.outputs)
return counter_namespace(counters)
@property
def transmission(self):
self.synchronize_filterset()
return self.filterset.transmission
@property
def filter(self):
return self.filterset.filter
@filter.setter
def filter(self, new_filter):
self.__filterset_needs_sync = True
self.filterset.filter = new_filter
def __info__(self):
table_info = []
for sname in (
"monitor_counter_name",
"detector_counter_name",
"min_count_rate",
"max_count_rate",
"always_back",
):
table_info.append([sname, getattr(self, sname)])
info = str(tabulate(table_info, headers=["Parameter", "Value"]))
info += "\n\n" + f"Active filterset: {self.filterset.name}"
info += (
"\n"
+ f"Energy axis {self.energy_axis.name}: {self.energy_axis.position:.5g} keV"
)
# calling transmission can update the filterset info_table if the energy has changed
transm = self.transmission
info += (
"\n\n"
+ f"Active filter idx {self.filterset.filter}, transmission {transm:g}"
)
info += "\n\n" + "Table of Effective Filters :"
if self.__filterset_is_synchronized:
info += "\n" + self.filterset.info_table()
else:
info += "\n Cannot get effective filters, check your energy, please !!!"
return info
def __create_counters(self, config, export_to_session=True):
cnts_conf = config.get("counters")
if cnts_conf is None:
self.filterset_counter_controller = None
self.calc_counter_controller = None
return
self.filterset_counter_controller = FilterSetCounterController(self)
for counter_name, tag in self.iter_counter_names(config, skip_tags=["ratio"]):
counter = self.filterset_counter_controller.create_counter(
SamplingCounter, counter_name, mode="SINGLE"
)
counter.tag = tag
if export_to_session:
self.__add_to_bliss_session(counter_name, counter)
self.calc_counter_controller = AutoFilterCalcCounterController(self, config)
def iter_counter_names(self, config, skip_tags=tuple(), only_tags=None):
for conf in config.get("counters", list()):
tag = conf["tag"].strip()
if only_tags and tag not in only_tags:
continue
if tag in skip_tags:
continue
counter_name = conf["counter_name"].strip()
yield counter_name, tag
def __add_to_bliss_session(self, name, obj):
current_session = get_current_session()
if current_session is None:
return
if (
name in current_session.config.names_list
or name in current_session.env_dict.keys()
):
raise ValueError(
f"Cannot export object to session with the name '{name}', name is already taken! "
)
current_session.env_dict[name] = obj
def beam_attenuation_correction(self, point_nb, name, data):
"""Calculate signal that would have been measured when the primary
beam was not attenuation.
"""
return data / self.transmission
......@@ -3,16 +3,26 @@
# Copyright (c) 2015-2019 Beamline Control Unit, ESRF
# Distributed under the GNU LGPLv3. See LICENSE for more info.
import weakref
import logging
import gevent
import numpy
from bliss.common import event
from bliss.common.event import dispatcher
from bliss.scanning.chain import AcquisitionMaster
from bliss.scanning.chain import AcquisitionSlave
from bliss.scanning.chain import duplicate_channel
from bliss.scanning.chain import AcquisitionChain
from bliss.scanning.channel import AcquisitionChannel
from bliss.scanning.acquisition.timer import SoftwareTimerMaster
from bliss.scanning.acquisition.motor import (
VariableStepTriggerMaster as _VariableStepTriggerMaster
)
from bliss.scanning import chain
from bliss.scanning.acquisition import lima
from bliss.scanning.channel import AcquisitionChannel
from bliss.common import event
from bliss.scanning.acquisition.lima import LimaAcquisitionMaster
from bliss.common.auto_filter.base_controller import AutoFilter
logger = logging.getLogger(__name__)
class VariableStepTriggerMaster(_VariableStepTriggerMaster):
......@@ -24,13 +34,21 @@ class VariableStepTriggerMaster(_VariableStepTriggerMaster):
def __iter__(self):
position_iter = zip(*self._motor_pos)
positions = next(position_iter)
msg_attempt = "\nAutofilter measure point %d (attempt %d)"
msg_done = "\nAutofilter completed point %d in %d attempts"
attempt_nb = 1
point_nb = 0
while True:
self.next_mv_cmd_arg = []
for axis, position in zip(self._axes, positions):
self.next_mv_cmd_arg += [axis, position]
self.reset_point_valid()
logger.debug(msg_attempt, point_nb, attempt_nb)
yield self
if self.is_point_valid():
logger.debug(msg_done, point_nb, attempt_nb)
point_nb += 1
attempt_nb = 1
try:
positions = next(position_iter)
except StopIteration:
......@@ -44,6 +62,7 @@ class VariableStepTriggerMaster(_VariableStepTriggerMaster):
self.wait_slaves()
def validate_point(self, point_nb, valid):
logger.debug("GET VALID MASTER point_nb %d %s", point_nb, valid)
self._valid_point = valid
self._event.set()
if valid:
......@@ -60,13 +79,13 @@ class VariableStepTriggerMaster(_VariableStepTriggerMaster):
return self._valid_point
class _Base:
class AcquisitionObjectWrapper:
def __init__(self, auto_filter):
self._name_2_channel = weakref.WeakValueDictionary()
self._name_2_corr_chan = weakref.WeakValueDictionary()
# copy all channel from the slave.
for channel in self.device.channels:
new_channel, _, _ = chain.duplicate_channel(channel)
new_channel, _, _ = duplicate_channel(channel)
self._name_2_channel[new_channel.name] = new_channel
event.connect(channel, "new_data", self.new_data_received)
self.channels.append(new_channel)
......@@ -103,7 +122,6 @@ class _Base:
self.__received_event.wait()
except gevent.Timeout:
pass
# print(f"stop {self.device}")
try:
return self.device.stop()
finally:
......@@ -137,14 +155,20 @@ class _Base:
return min_last_point == current_point + 1 and not self.__pending_data
def new_data_received(self, event_dict=None, signal=None, sender=None):
channel_data = event_dict.get("data")
if channel_data is None:
return
channel = sender
channel_name = channel.name
last_point_rx = self.__last_point_rx.setdefault(channel_name, 0)
valid = self.__valid_point.get(last_point_rx)
logger.debug(
"GET VALID %s point_nb %d %s",
channel_name,
self.__last_point_rx[channel_name],
valid,
)
channel_data = event_dict.get("data")
if channel_data is None:
return
# three cases
# valid = False -> not valid
# valid = True -> is valid
......@@ -160,29 +184,38 @@ class _Base:
self.__pending_data[channel_name] = numpy.append(
previous_data, channel_data
)
else: # valid is True or False
if valid:
my_channel = self._name_2_channel[channel_name]
# print(f"emit {channel_name} {self._auto_filter.current_point}")
my_channel.emit(channel_data)
elif valid:
logger.debug(
"PUBLISH(1) %s point_nb %d %s",
channel_name,
self.__last_point_rx[channel_name],
valid,
)
my_channel = self._name_2_channel[channel_name]
my_channel.emit(channel_data)
corr_chan = self._name_2_corr_chan.get(channel_name)
if corr_chan is not None:
corrected_data = self._auto_filter.corr_func(
last_point_rx, channel_name, channel_data
)
corr_chan.emit(corrected_data)
corr_chan = self._name_2_corr_chan.get(channel_name)
if corr_chan is not None:
corrected_data = self._auto_filter.beam_attenuation_correction(
last_point_rx, channel_name, channel_data
)
corr_chan.emit(corrected_data)
if isinstance(channel_data, dict): # Lima
self.__last_point_rx[channel_name] = last_point_rx + 1
else:
self.__last_point_rx[channel_name] = last_point_rx + len(channel_data)
logger.debug(
"INC %s point_nb %d", channel_name, self.__last_point_rx[channel_name]
)
self.__received_event.set()
def validate_point(self, point_nb, valid_flag):
# for now we just do simple thing we remove all the
# pending... to check if it too simple. Doesn't take into
# account data block receiving...
logger.debug("GET VALID point_nb %d %s", point_nb, valid_flag)
self.__valid_point[point_nb] = valid_flag
if not valid_flag:
# clean pending_data
......@@ -192,73 +225,78 @@ class _Base:
self.__pending_data = dict()
for channel_name, data in pending_data.items():
channel = self._name_2_channel[channel_name]
# print(f"emit {channel_name} {self._auto_filter.current_point}")
logger.debug(
"PUBLISH(2) %s point_nb %d %s",
channel_name,
self.__last_point_rx[channel_name],
valid_flag,
)
channel.emit(data)
corr_chan = self._name_2_corr_chan.get(channel_name)
if corr_chan is not None:
corrected_data = self._auto_filter.corr_func(
corrected_data = self._auto_filter.beam_attenuation_correction(
point_nb, channel_name, data
)
corr_chan.emit(corrected_data)
self.__received_event.set()
class _Slave(_Base, chain.AcquisitionSlave):
def __init__(self, auto_filter, slave, npoints=1):
chain.AcquisitionSlave.__init__(
class AcquisitionSlaveWrapper(AcquisitionObjectWrapper, AcquisitionSlave):
def __init__(self, auto_filter, slave):
AcquisitionSlave.__init__(
self,
slave,
name=slave.name,
npoints=npoints,
npoints=slave.npoints,
trigger_type=slave.trigger_type,
prepare_once=slave.prepare_once,
start_once=slave.start_once,
)
_Base.__init__(self, auto_filter)
AcquisitionObjectWrapper.__init__(self, auto_filter)
class _SlaveIter(_Slave):
class AcquisitionSlaveWrapperIter(AcquisitionSlaveWrapper):
def __iter__(self):
for i in self.device:
yield self
class _Master(_Base, chain.AcquisitionMaster):
def __init__(self, auto_filter, master, npoints=1):
chain.AcquisitionMaster.__init__(
class AcquisitionMasterWrapper(AcquisitionObjectWrapper, AcquisitionMaster):
def __init__(self, auto_filter, master):
AcquisitionMaster.__init__(
self,
master,
name=master.name,
npoints=npoints,
npoints=master.npoints,
trigger_type=master.trigger_type,
prepare_once=master.prepare_once,
start_once=master.start_once,
)
_Base.__init__(self, auto_filter)
AcquisitionObjectWrapper.__init__(self, auto_filter)
# hack slaves of master
# replace buy our slaves
master._AcquisitionMaster__slaves = self.slaves
@property
def parent(self):
return chain.AcquisitionMaster.parent.fget(self)
return AcquisitionMaster.parent.fget(self)
@parent.setter
def parent(self, new_parent):
chain.AcquisitionMaster.parent.fset(self, new_parent)
AcquisitionMaster.parent.fset(self, new_parent)
# give to the embeded AcqMaster the same parent
# to avoid to trig on start (i.e:Timer)
self.device.parent = new_parent
class _MasterIter(_Master):
class AcquisitionMasterWrapperIter(AcquisitionMasterWrapper):