Commit d5e1cbbc authored by payno's avatar payno
Browse files

[unit test] fix unit test for guessing shift

- now OverwritingOperation can set data_flatten but not data. Might evolve, it depends how we saw things.
parent b27ce07a
Pipeline #6589 passed with stage
in 2 minutes and 20 seconds
...@@ -151,7 +151,7 @@ def main(argv): ...@@ -151,7 +151,7 @@ def main(argv):
test_suite = unittest.TestSuite() test_suite = unittest.TestSuite()
test_suite.addTest(id06workflow.test.suite()) test_suite.addTest(id06workflow.test.suite())
import orangecontrib.id06workflow.test import orangecontrib.id06workflow.test
test_suite.addTest(orangecontrib.id06workflow.test) test_suite.addTest(orangecontrib.id06workflow.test.suite())
result = runner.run(test_suite) result = runner.run(test_suite)
if result.wasSuccessful(): if result.wasSuccessful():
......
...@@ -193,8 +193,8 @@ class Experiment(object): ...@@ -193,8 +193,8 @@ class Experiment(object):
self.__data = self.getRawData() self.__data = self.getRawData()
return self.__data return self.__data
@data.setter @data_flatten.setter
def data(self, data): def data_flatten(self, data):
assert data.ndim > 2 assert data.ndim > 2
self.__data = data.reshape(-1, data.shape[-2], data.shape[-1]) self.__data = data.reshape(-1, data.shape[-2], data.shape[-1])
......
...@@ -63,6 +63,12 @@ class _BaseOperation(qt.QObject): ...@@ -63,6 +63,12 @@ class _BaseOperation(qt.QObject):
data.setflags(write=self._can_overwrite_data) data.setflags(write=self._can_overwrite_data)
return data return data
@property
def data_flatten(self):
data_flatten = self._experiment.data_flatten.view()
data_flatten.setflags(write=self._can_overwrite_data)
return data_flatten
@property @property
def experiment(self): def experiment(self):
return self._experiment return self._experiment
...@@ -97,9 +103,9 @@ class OverwritingOperation(_BaseOperation): ...@@ -97,9 +103,9 @@ class OverwritingOperation(_BaseOperation):
def __init__(self, experiment, name): def __init__(self, experiment, name):
_BaseOperation.__init__(self, experiment, name, can_overwrite_data=True) _BaseOperation.__init__(self, experiment, name, can_overwrite_data=True)
@_BaseOperation.data.setter @_BaseOperation.data_flatten.setter
def data(self, data): def data_flatten(self, data):
self._experiment.data = data self._experiment.data_flatten = data
def dry_run(self, cache_data=None): def dry_run(self, cache_data=None):
""" """
......
...@@ -49,7 +49,7 @@ class RoiOperation(OverwritingOperation): ...@@ -49,7 +49,7 @@ class RoiOperation(OverwritingOperation):
str(self._size))) str(self._size)))
def compute(self): def compute(self):
self.data = self._compute(self.data) self.data_flatten = self._compute(self.data)
self.registerOperation() self.registerOperation()
return self.data return self.data
...@@ -64,7 +64,7 @@ class RoiOperation(OverwritingOperation): ...@@ -64,7 +64,7 @@ class RoiOperation(OverwritingOperation):
def apply(self): def apply(self):
if self._cache_data is None: if self._cache_data is None:
raise ValueError('No data in cache') raise ValueError('No data in cache')
self.data = self._cache_data self.data_flatten = self._cache_data
self.clear_cache() self.clear_cache()
self.registerOperation() self.registerOperation()
return self.data return self.data
......
...@@ -55,7 +55,6 @@ class Shift(OverwritingOperation): ...@@ -55,7 +55,6 @@ class Shift(OverwritingOperation):
OverwritingOperation.__init__(self, experiment, name='shift') OverwritingOperation.__init__(self, experiment, name='shift')
if dz != 0: if dz != 0:
raise NotImplementedError('z shift not taken into account yet') raise NotImplementedError('z shift not taken into account yet')
assert self.data_flatten.ndim is 3
self.dx = dx self.dx = dx
self.dy = dy self.dy = dy
self.dz = dz self.dz = dz
...@@ -110,9 +109,10 @@ class Shift(OverwritingOperation): ...@@ -110,9 +109,10 @@ class Shift(OverwritingOperation):
self._cache_data = None self._cache_data = None
def _compute(self, data): def _compute(self, data):
assert data.ndim is 3
res = [] res = []
nImg = data.shape[0] nImg = data.shape[0]
for iImg, img in enumerate(self.data_flatten[:]): for iImg, img in enumerate(data[:]):
self.updateProgress(int(iImg / nImg * 100.0)) self.updateProgress(int(iImg / nImg * 100.0))
_dx = self.dx * iImg _dx = self.dx * iImg
_dy = self.dy * iImg _dy = self.dy * iImg
......
...@@ -85,7 +85,7 @@ class TestExperiement(unittest.TestCase): ...@@ -85,7 +85,7 @@ class TestExperiement(unittest.TestCase):
back_sub_data = experiment.apply_background_subtraction() back_sub_data = experiment.apply_background_subtraction()
self.assertTrue(numpy.array_equal(back_sub_data[3], numpy.zeros(dims))) self.assertTrue(numpy.array_equal(back_sub_data[3], numpy.zeros(dims)))
@unittest.skipUnless(os.path.exists('/nobackup/linazimov/payno/dev/esrf/ID06/dataset/for_mapping')) @unittest.skipUnless(os.path.exists('/nobackup/linazimov/payno/dev/esrf/ID06/dataset/for_mapping'), reason='Data files not available')
def testMapping(self): def testMapping(self):
# TODO: store datasets on the web and retrieve them using utils.getDataset() # TODO: store datasets on the web and retrieve them using utils.getDataset()
_dir = '/nobackup/linazimov/payno/dev/esrf/ID06/dataset/for_mapping' _dir = '/nobackup/linazimov/payno/dev/esrf/ID06/dataset/for_mapping'
......
...@@ -76,13 +76,12 @@ class TestShift(unittest.TestCase): ...@@ -76,13 +76,12 @@ class TestShift(unittest.TestCase):
dataset = createDataset() dataset = createDataset()
_experiment = experiment.Experiment(dataset=dataset) _experiment = experiment.Experiment(dataset=dataset)
# print(image.guess_shift(data=_experiment.data_flatten, axis=0))
res = (image.guess_shift(data=_experiment.data_flatten, axis=1, res = (image.guess_shift(data=_experiment.data_flatten, axis=1,
start=-0.0005, stop=0.0005, step=0.0001)) start=-0.0005, stop=0.0005, step=0.0001))
self.assertTrue(numpy.isclose( self.assertTrue(numpy.isclose(res, -2.0e-04))
image.guess_shift(data=_experiment.data_flatten, axis=0), res)) res = image.guess_shift(data=_experiment.data_flatten, axis=0,
self.assertTrue(numpy.isclose( start=-1.0, stop=1.0, step=0.5)
image.guess_shift(data=_experiment.data_flatten, axis=1), 0.0)) self.assertTrue(numpy.isclose(res, 0.0))
......
...@@ -91,6 +91,12 @@ class ShiftCorrectionWidget(qt.QWidget): ...@@ -91,6 +91,12 @@ class ShiftCorrectionWidget(qt.QWidget):
# and show some information on advancement # and show some information on advancement
if self._operation_thread.isRunning() is True: if self._operation_thread.isRunning() is True:
_logger.error('Cannot process shift until the current one is finished') _logger.error('Cannot process shift until the current one is finished')
return
if self._experiment is None:
_logger.error(
'Cannot process shift because no experiment has been set')
return
_shift = self.getShift() _shift = self.getShift()
if _shift == self._lastShift: if _shift == self._lastShift:
......
...@@ -134,6 +134,13 @@ class TestFirstSetup(OrangeWorflowTest): ...@@ -134,6 +134,13 @@ class TestFirstSetup(OrangeWorflowTest):
# manage shift correction # manage shift correction
self.assertTrue(self._shiftCorrectionWidget._editedExperiment is not None) self.assertTrue(self._shiftCorrectionWidget._editedExperiment is not None)
self._shiftCorrectionWidget.setShift(dx=1.0, dy=-1.0, dz=0.0) self._shiftCorrectionWidget.setShift(dx=1.0, dy=-1.0, dz=0.0)
timeout = 10*1000 # timeout in millisec
waiting_time = 0
# wait for shift to be processed
while (self._shiftCorrectionWidget._widget._operation_thread.isRunning() and
waiting_time < timeout):
self.qWait(200)
waiting_time += 200
self._shiftCorrectionWidget.validate() self._shiftCorrectionWidget.validate()
self._moveToNextStep() self._moveToNextStep()
...@@ -141,7 +148,6 @@ class TestFirstSetup(OrangeWorflowTest): ...@@ -141,7 +148,6 @@ class TestFirstSetup(OrangeWorflowTest):
self.assertTrue(self._noiseReductionWidget._editedExperiment is not None) self.assertTrue(self._noiseReductionWidget._editedExperiment is not None)
self._noiseReductionWidget.validate() self._noiseReductionWidget.validate()
self._moveToNextStep() self._moveToNextStep()
# TODO: check value of the next input
noise_reducted_data = self._saveWidget._editedExperiment.data noise_reducted_data = self._saveWidget._editedExperiment.data
self.assertTrue(noise_reducted_data.min() >= 0.2) self.assertTrue(noise_reducted_data.min() >= 0.2)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment