diff --git a/bliss/common/scans/step_by_step.py b/bliss/common/scans/step_by_step.py index af5f4adf23c60f9aec16f753e55b094198ae7bfe..31385149ffd7f18567ac4682f36a0c68c209820e 100644 --- a/bliss/common/scans/step_by_step.py +++ b/bliss/common/scans/step_by_step.py @@ -327,6 +327,7 @@ def lookupscan( return_scan: bool = True, scan_info: Optional[dict] = None, scan_params: Optional[dict] = None, + restore_motor_positions: bool = False, ): """Lookupscan usage: lookupscan([(m0,numpy.arange(0,2,0.5)),(m1,numpy.linspace(1,3,4))],0.1,diode2) @@ -349,6 +350,7 @@ def lookupscan( return_scan: bool = True, scan_info: Optional[dict] = None, scan_params: Optional[dict] = None, + restore_motor_positions: bool = False, """ scan_info = ScanInfo.normalize(scan_info) if scan_params is None: @@ -406,6 +408,9 @@ def lookupscan( data_watch_callback=StepScanDataWatch(), ) + if restore_motor_positions: + scan.restore_motor_positions = True + if run: scan.run() @@ -431,6 +436,7 @@ def anscan( run: bool = True, return_scan: bool = True, scan_info: Optional[dict] = None, + restore_motor_positions: bool = False, ): """ anscan usage: @@ -456,6 +462,7 @@ def anscan( return_scan: bool = True, scan_info: Optional[dict] = None, scan_params: Optional[dict] = None, + restore_motor_positions: bool = False, example: anscan( [(m1, 1, 2), (m2, 3, 7)], 10, 0.1, diode2) @@ -534,6 +541,7 @@ def anscan( return_scan=return_scan, scan_info=scan_info, scan_params=scan_params, + restore_motor_positions=restore_motor_positions, ) @@ -593,18 +601,9 @@ def dnscan( sleep_time=sleep_time, run=False, scan_info=scan_info, + restore_motor_positions=True, ) - def run_with_cleanup(self, __run__=scan.run): - with cleanup( - *[m[0] for m in motor_tuple_list], - restore_list=(cleanup_axis.POS,), - verbose=True, - ): - __run__() - - scan.run = types.MethodType(run_with_cleanup, scan) - if run: scan.run() diff --git a/bliss/scanning/scan.py b/bliss/scanning/scan.py index 03454db92c3defdeabdb57b016ddf8d0306165c8..81a6aa2093568321301b760dad59136c3c881099 100755 --- a/bliss/scanning/scan.py +++ b/bliss/scanning/scan.py @@ -19,13 +19,12 @@ from typing import Callable, Any import typeguard import logging import numpy +import itertools from bliss.common.types import _countable from bliss import current_session, is_bliss_shell -from bliss.common.axis import Axis from bliss.common.motor_group import is_motor_group from bliss.common.hook import group_hooks, execute_pre_scan_hooks -from bliss.common.event import connect, disconnect from bliss.common import event from bliss.common.cleanup import error_cleanup, axis as cleanup_axis, capture_exceptions from bliss.common.greenlet_utils import KillMask @@ -623,6 +622,7 @@ class Scan: self.__nodes = dict() self._devices = [] self._axes_in_scan = [] # for pre_scan, post_scan in axes hooks + self._restore_motor_positions = False self._data_watch_task = None self._data_watch_callback = data_watch_callback @@ -913,6 +913,18 @@ class Scan: else: return "{scan_number}" + @property + def restore_motor_positions(self): + """Weither to restore the initial motor positions at the end of scan run (for dscans). + """ + return self._restore_motor_positions + + @restore_motor_positions.setter + def restore_motor_positions(self, restore): + """Weither to restore the initial motor positions at the end of scan run (for dscans). + """ + self._restore_motor_positions = restore + def get_plot( self, channel_item, plot_type, as_axes=False, wait=False, silent=False ): @@ -1420,6 +1432,12 @@ class Scan: "Scan state is not idle. Scan objects can only be used once." ) + if self.restore_motor_positions: + # store initial positions + motor_positions = [ + (mot, mot._set_position) for mot in self._get_data_axes() + ] + # check if watch callback has to be called in "prepare" and "stop" phases data_watch_call_on_prepare = data_watch_call_on_stop = False if self._data_watch_callback is not None: @@ -1591,6 +1609,18 @@ class Scan: raise ScanAbort from e raise e + # restore motors initial position + if self.restore_motor_positions: + with capture(): + if is_bliss_shell(): + from bliss.shell.standard import umv as move + + event.send(self, "close_progress_bar") + else: + from bliss.common.standard import move + + move(*itertools.chain(*motor_positions)) + # execute post scan hooks hooks = group_hooks(self._axes_in_scan) for hook in reversed(list(hooks)): diff --git a/bliss/shell/data/display.py b/bliss/shell/data/display.py index 10107373632b1b350acbf5df7ee5e130bff1afeb..48ae4784cae53057ace322588c83a3678b9691e4 100644 --- a/bliss/shell/data/display.py +++ b/bliss/shell/data/display.py @@ -16,6 +16,8 @@ import typing import gevent import numbers +from louie import Any + from bliss.data import scan as scan_mdl from bliss.common.utils import nonblocking_print from bliss.common.event import dispatcher @@ -626,10 +628,17 @@ class ScanPrinterWithProgressBar(ScanPrinter): self.progress_bar.set_description(", ".join(self.labels)) self.progress_bar.refresh() + def _close_progress_bar(self): + if self.progress_bar is not None: + self.progress_bar.close() + def on_scan_new(self, scan, scan_info): super().on_scan_new(scan, scan_info) total = scan_info["npoints"] self.progress_bar = tqdm(total=total, leave=False) + dispatcher.connect( + self._close_progress_bar, signal="close_progress_bar", sender=Any + ) def on_scan_data(self, scan_info, data): nb_rows = self.scan_renderer.nb_data_rows diff --git a/tests/scans/test_exception.py b/tests/scans/test_exception.py index 0f3606de027d041f2e7e4ce07ad72260a9b59e7c..ba55146c05b4f8b8ddedc88fa45ce5ad906c8d9c 100644 --- a/tests/scans/test_exception.py +++ b/tests/scans/test_exception.py @@ -1,7 +1,9 @@ import gevent import pytest +from unittest import mock from bliss.common import scans +from bliss.common.standard import move from bliss.common.counter import SamplingCounter from bliss.controllers.counter import SamplingCounterController from bliss.common.soft_axis import SoftAxis @@ -162,6 +164,24 @@ def test_exception_in_preset(default_session): assert s.node.info["state"] == ScanState.KILLED +def test_exception_in_dscan_move_back(default_session): + diode = default_session.config.get("diode") + bad = default_session.config.get("bad") + bad.move(0) + s = scans.dscan(bad, 1, 2, 1, .1, diode, save=False, run=False) + + def patched_move(*args): + bad.controller.bad_start = True + move(*args) + + with mock.patch("bliss.common.standard.move", new=patched_move): + with pytest.raises(RuntimeError): + s.run() + + assert bad.position > 1 + assert bad.state.READY + + def test_sequence_state(default_session): diode = default_session.config.get("diode") roby = default_session.config.get("roby")