Commit 43ee1430 authored by Matias Guijarro's avatar Matias Guijarro
Browse files

fix issue #1773: ensure reading task is killed when acq. stop fails, to allow...

fix issue #1773: ensure reading task is killed when acq. stop fails, to allow cleanup of all devices involved in the scan

Before the fix, if a `acq_stop` was failing, `wait_all_devices` was stuck
thus preventing cleaning up of other devices.
parent 009af896
Pipeline #46219 passed with stages
in 91 minutes and 57 seconds
......@@ -327,10 +327,12 @@ class SoftwarePositionTriggerMaster(MotorMaster):
self.started.set()
def stop(self):
self.movable.stop()
event.disconnect(self.movable, "internal_state", self.on_state_change)
if self.task:
self.task.kill()
try:
self.movable.stop()
finally:
event.disconnect(self.movable, "internal_state", self.on_state_change)
if self.task:
self.task.kill()
def trigger(self):
return self._start_move()
......
......@@ -426,7 +426,6 @@ class AcquisitionObject:
# The acquistion object is also considered to be
# ready when the reading task (if any) is not running.
if self.has_reading_task():
# No time profiling of wait_reading here!
tasks.append(gevent.spawn(self.wait_reading))
tasks.append(gevent.spawn(self.wait_ready))
join_tasks(tasks, count=1)
......@@ -834,7 +833,7 @@ class AcquisitionChainIter:
return self._tree.children("root")[0].identifier.acquisition_object
def apply_parameters(self):
for tasks in self._execute("apply_parameters", wait_between_levels=False):
for tasks, _ in self._execute("apply_parameters", wait_between_levels=False):
gevent.joinall(tasks, raise_error=True)
def prepare(self, scan, scan_info):
......@@ -868,7 +867,7 @@ class AcquisitionChainIter:
join_tasks(preset_tasks)
for tasks in self._execute(
for tasks, _ in self._execute(
"acq_prepare", wait_between_levels=not self._parallel_prepare
):
join_tasks(tasks)
......@@ -887,7 +886,7 @@ class AcquisitionChainIter:
join_tasks(preset_tasks)
for tasks in self._execute("acq_start"):
for tasks, _ in self._execute("acq_start"):
join_tasks(tasks)
def _acquisition_object_iterators(self):
......@@ -905,6 +904,7 @@ class AcquisitionChainIter:
def stop(self):
all_tasks = []
all_acq_objs = []
with capture_exceptions(raise_index=0) as capture:
# call before_stop on preset
with capture():
......@@ -916,14 +916,20 @@ class AcquisitionChainIter:
gevent.joinall(preset_tasks) # wait to call all before_stop on preset
gevent.joinall(preset_tasks, raise_error=True)
for tasks in self._execute("acq_stop", master_to_slave=True):
for tasks, acq_objs in self._execute("acq_stop", master_to_slave=True):
with KillMask(masked_kill_nb=1):
gevent.joinall(tasks)
all_tasks.extend(tasks)
with capture():
gevent.joinall(all_tasks, raise_error=True)
all_acq_objs.extend(acq_objs)
for i, task in enumerate(all_tasks):
with capture():
try:
task.get()
except BaseException:
acq_obj = all_acq_objs[i]
if hasattr(acq_obj, "_reading_task"):
acq_obj._reading_task.kill()
raise
with capture():
self.wait_all_devices()
......@@ -944,7 +950,7 @@ class AcquisitionChainIter:
if self.__sequence_index == 0:
self._start_time = time.time()
wait_ready_tasks = self._execute("acq_wait_ready", master_to_slave=True)
for tasks in wait_ready_tasks:
for tasks, _ in wait_ready_tasks:
join_tasks(tasks)
try:
if self.__sequence_index:
......@@ -961,7 +967,8 @@ class AcquisitionChainIter:
return self
def _execute(self, func_name, master_to_slave=False, wait_between_levels=True):
tasks = list()
tasks = []
acq_objs = []
prev_level = None
if master_to_slave:
acq_obj_iters = list(self._tree.expand_tree(mode=Tree.WIDTH))[1:]
......@@ -972,14 +979,16 @@ class AcquisitionChainIter:
node = self._tree.get_node(acq_obj_iter)
level = self._tree.depth(node)
if wait_between_levels and prev_level != level:
yield tasks
tasks = list()
yield tasks, acq_objs
acq_objs = []
tasks = []
prev_level = level
func = getattr(acq_obj_iter, func_name)
t = gevent.spawn(func)
_running_task_on_device[acq_obj_iter.acquisition_object] = t
acq_objs.append(acq_obj_iter.acquisition_object)
tasks.append(t)
yield tasks
yield tasks, acq_objs
def __iter__(self):
return self
......
......@@ -10,6 +10,7 @@ import time
import numpy as np
import pytest
import gevent
from unittest import mock
from bliss.common import event
from bliss.common import scans
......@@ -161,6 +162,40 @@ def test_interrupted_scan(session, diode_acq_device_factory):
assert acquisition_device_2.stop_flag
def test_timeout_error(session, diode_acq_device_factory):
robz = session.config.get("robz")
robz.velocity_limits = 0, 100
robz.velocity = 1
chain = AcquisitionChain()
acquisition_device_1, _ = diode_acq_device_factory.get(count_time=0.1, npoints=5)
acquisition_device_2, _ = diode_acq_device_factory.get(count_time=0.1, npoints=5)
master = SoftwarePositionTriggerMaster(robz, 0, 1, 5)
chain.add(master, acquisition_device_1)
chain.add(master, acquisition_device_2)
s = Scan(chain, save=False)
with mock.patch.object(robz, "stop", side_effect=gevent.Timeout) as stop_method:
with mock.patch.object(
acquisition_device_1, "stop", side_effect=gevent.Timeout
) as acq_dev_stop_method:
# Run scan
scan_task = gevent.spawn(s.run)
with gevent.Timeout(1):
s.wait_state(ScanState.STARTING)
gevent.sleep(0.2)
try:
scan_task.kill(KeyboardInterrupt)
except:
assert scan_task.ready()
assert stop_method.called_once
assert acq_dev_stop_method.called_once
assert acquisition_device_2.stop_flag
def test_scan_too_fast(session, diode_acq_device_factory):
robz = session.config.get("robz")
robz.velocity = 10
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment