Commit 0fbae136 authored by payno's avatar payno

[feature] add guess_shift function to guess shift during acquisition

- add also unit test for it.
- this is quiet a raw version, has to be tested again and wildly optinized
parent 9c9af232
......@@ -61,3 +61,65 @@ def shift_img(data, dx, dy):
res = abs(numpy.fft.ifft2(numpy.fft.fft2(data) * numpy.exp(
1.0j * 2.0 * numpy.pi * (dy * ny / ynum + dx * nx / xnum))))
return res
def guess_shift(data, axis, start=-1.0, stop=1.0, step=0.1):
"""Try to guess the shift from the flatten data"""
def com_1d(ydata):
xdata = numpy.arange(len(ydata))
deno = numpy.sum(ydata).astype(numpy.float32)
if deno == 0.:
return numpy.nan
else:
return numpy.sum(xdata * ydata).astype(numpy.float32) / deno
def com_2d(data, axis):
sum = numpy.sum(data, axis=axis)
deno = numpy.sum(sum)
if deno == 0.0:
return numpy.nan
else:
return sum / deno
def com_var(x_shift, y_shift):
if x_shift is None:
assert y_shift is not None
if y_shift is None:
assert x_shift is not None
coms = []
for iImg, img in enumerate(data):
# if img is too large, reduce size
_img = img
if _img.shape[-1] > 1024:
_img = _img[::2, ::2]
if x_shift is None:
shifted_img = shift_img(_img, dx=0.0, dy=y_shift * iImg)
else:
shifted_img = shift_img(_img, dx=x_shift * iImg, dy=0.0)
if x_shift is None:
com = com_1d(ydata=com_2d(shifted_img, axis=1))
else:
com = com_1d(ydata=com_2d(shifted_img, axis=0))
if img.shape[-1] > 1024:
com = com * 0.5
coms.append(com)
return numpy.var(coms)
assert axis in (0, 1)
assert data.ndim is 3
vars = []
_range = numpy.arange(start=start, stop=stop, step=step)
if axis is 0:
y_shift = None
[vars.append(com_var(x_shift, y_shift)) for x_shift in _range]
elif axis is 1:
x_shift = None
[vars.append(com_var(x_shift, y_shift)) for y_shift in _range]
else:
raise not NotImplementedError('')
return _range[numpy.argmin(vars)]
......@@ -28,15 +28,18 @@ __license__ = "MIT"
__date__ = "20/11/2018"
import os
import unittest
import numpy
from id06workflow.core import image
from id06workflow.core import experiment
class TestShift(unittest.TestCase):
"""
Test that RefCopy process is correct
"""
data_dir = '/users/payno/datasets/id06/strain_scan'
def testShit(self):
"""
......@@ -47,6 +50,41 @@ class TestShift(unittest.TestCase):
shifted_data = image.shift_img(data, dx=1.0, dy=-1.0)
self.assertTrue(numpy.allclose(data[:-1, 1:], shifted_data[1:, :-1]))
@unittest.skipIf(os.path.exists(data_dir) is False, reason='Dataset source folder is not existing')
def testShiftGuess(self):
"""Test the algorithm to propose a shift"""
def createDataset():
data_file_pattern = os.path.join(self.data_dir,
'reduced_strain/strain_0000.edf')
assert os.path.exists(data_file_pattern)
ff_files = []
dir_ff = os.path.join(self.data_dir, "bg_ff_5s_1x1/")
[ff_files.append(os.path.join(dir_ff, _file)) for _file in
os.listdir(dir_ff)]
dim1 = experiment._Dim(kind=experiment.POSITIONER_METADATA,
name='diffry', relative_prev_val=True,
size=31)
dim2 = experiment._Dim(kind=experiment.POSITIONER_METADATA,
name='obpitch')
self._dims = {0: dim1, 1: dim2}
return experiment.Dataset(data_files_pattern=data_file_pattern,
ff_files=ff_files)
dataset = createDataset()
_experiment = experiment.Experiment(dataset=dataset)
# print(image.guess_shift(data=_experiment.data_flatten, axis=0))
res = (image.guess_shift(data=_experiment.data_flatten, axis=1,
start=-0.0005, stop=0.0005, step=0.0001))
self.assertTrue(numpy.isclose(
image.guess_shift(data=_experiment.data_flatten, axis=0), res))
self.assertTrue(numpy.isclose(
image.guess_shift(data=_experiment.data_flatten, axis=1), 0.0))
def suite():
test_suite = unittest.TestSuite()
......
......@@ -32,7 +32,7 @@ from silx.gui import qt
from silx.gui.plot import Plot2D
from id06workflow.core.experiment.operation.shift import Shift as ShiftOperation
from id06workflow.core.experiment.operation.ThreadedOperation import ThreadedOperation
from id06workflow.core.experiment import Experiment
from id06workflow.core.image import guess_shift
from silx.gui.plot.StackView import StackViewMainWindow
from id06workflow.gui.settings import DEFAULT_COLORMAP
from functools import partial
......@@ -78,6 +78,9 @@ class ShiftCorrectionWidget(qt.QWidget):
self.getShift = self._control.getShift
self.isProcessing = self._operation_thread.isRunning
# Signal?slot estimation
self._control._estimate_button.pressed.connect(self._make_estimation)
def setShift(self, dx, dy, dz):
self._control.setShift(dx=dx, dy=dy, dz=dz)
self._updateShift()
......@@ -134,6 +137,13 @@ class ShiftCorrectionWidget(qt.QWidget):
assert type(progress) is int
self.sigProgress.emit(progress)
def _make_estimation(self):
if self._experiment is None:
_logger.warning('No experiment set, unable to estimate fit')
x_shift = guess_shift(data=self._experiment.data_flatten, axis=-1)
y_shift = guess_shift(data=self._experiment.data_flatten, axis=-2)
self._control.setShift(dx=x_shift, dy=y_shift)
class _DxDyDzWidget(qt.QWidget):
"""
......@@ -160,16 +170,8 @@ class _DxDyDzWidget(qt.QWidget):
self.layout().addWidget(self._dzLE, 3, 1)
self._estimate_button = qt.QPushButton('estimate', parent=self)
self._estimate_button.pressed.connect(self._make_estimation)
self.layout().addWidget(self._estimate_button, 0, 0, 1, 2)
def _make_estimation(self):
msg = qt.QMessageBox(self)
msg.setIcon(qt.QMessageBox.Information)
text = "algorithm of the shift estimation has not been implemented yet"
msg.setText(text)
msg.exec_()
def getDx(self):
"""
......@@ -198,13 +200,16 @@ class _DxDyDzWidget(qt.QWidget):
"""
return (self.getDx(), self.getDy(), self.getDz())
def setShift(self, dx, dy, dz):
def setShift(self, dx=None, dy=None, dz=None):
"""
:param float dx: shift translation in x (for now on the stack of image)
:param float dy: shift translation in y (for now on the stack of image)
:param float dz: shift translation in z (for now on the stack of image)
:param None or float dx: shift translation in x (for now on the stack of image)
:param None or float dy: shift translation in y (for now on the stack of image)
:param None or float dz: shift translation in z (for now on the stack of image)
"""
self._dxLE.setText(str(dx))
self._dyLE.setText(str(dy))
self._dzLE.setText(str(dz))
if dx is not None:
self._dxLE.setText(str(dx))
if dy is not None:
self._dyLE.setText(str(dy))
if dz is not None:
self._dzLE.setText(str(dz))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment