Commit 03ed965b authored by Julia Garriga Ferrer

Merge branch 'improve_shift' into 'master'

Improve shift

See merge request !127
parents 89dd4664 93ec127f
Pipeline #50059 passed in 2 minutes and 11 seconds
@@ -26,7 +26,7 @@
__authors__ = ["J. Garriga"]
__license__ = "MIT"
__date__ = "07/06/2021"
__date__ = "05/07/2021"
import copy
import glob
@@ -737,7 +737,7 @@ class Dataset():
return Dataset(_dir=roi_dir, data=new_data, dims=self.__dims, transformation=transformation,
in_memory=self._in_memory)
def find_shift(self, dimension=None, h_max=0.5, h_step=0.01, indices=None):
def find_shift(self, dimension=None, steps=50, indices=None):
"""
Find the shift of the data or part of it.
@@ -751,7 +751,7 @@
:type indices: Union[None, array_like]
:returns: Array with shift per frame.
"""
return shift_detection(self.get_data(indices, dimension), h_max, h_step)
return shift_detection(self.get_data(indices, dimension), steps)
def apply_shift(self, shift, dimension=None, shift_approach="fft", indices=None,
callback=None, _dir=None):
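For orientation, a minimal usage sketch of the reworked shift API shown in the hunks above; the `dataset` object and the parameter values are placeholders, not part of this commit:

shift = dataset.find_shift(dimension=[1, 0], steps=50)    # per-frame shift, one column per frame
corrected = dataset.apply_shift(shift, dimension=[1, 0])  # returns a new Dataset with shifted frames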
@@ -797,6 +797,8 @@ class Dataset():
if not self.operations_state[Operation.SHIFT]:
return
filename = _dir + "/data" + str(i).zfill(4) + ".npy"
if (shift[:, i] > 1).all():
shift_approach = "linear"
img = apply_shift(data[i], shift[:, i], shift_approach)
numpy.save(filename, img)
urls.append(DataUrl(file_path=filename, scheme='fabio'))
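The guard added above falls back to linear interpolation for frames whose shift exceeds one pixel, presumably to avoid the wrap-around that Fourier-based shifting introduces for large displacements. For illustration, the two interpolation styles being chosen between can be sketched with scipy/numpy stand-ins (this is not the module's own apply_shift):

import numpy
import scipy.ndimage

frame = numpy.random.random((64, 64))
shift = numpy.array([2.3, -0.4])

# "fft" style: shift in Fourier space; exact for periodic images, but content
# leaving one edge re-enters on the opposite side.
shifted_fft = numpy.fft.ifftn(
    scipy.ndimage.fourier_shift(numpy.fft.fftn(frame), shift)).real

# "linear" style: first-order interpolation, no wrap-around.
shifted_linear = scipy.ndimage.shift(frame, shift, order=1)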
@@ -828,6 +830,8 @@ class Dataset():
if not self.operations_state[Operation.SHIFT]:
return
filename = _dir + "/data" + str(i).zfill(4) + ".npy"
if (shift[:, i] > 1).all():
shift_approach = "linear"
img = apply_shift(data[i], shift[:, i], shift_approach)
numpy.save(filename, img)
urls.append(DataUrl(file_path=filename, scheme='fabio'))
@@ -846,7 +850,7 @@ class Dataset():
return Dataset(_dir=_dir, data=data, dims=self.__dims, transformation=self.transformation,
in_memory=self._in_memory)
def find_and_apply_shift(self, dimension=None, h_max=0.5, h_step=0.01, shift_approach="fft",
def find_and_apply_shift(self, dimension=None, steps=100, shift_approach="fft",
indices=None, callback=None):
"""
Find the shift of the data or part of it and apply it.
@@ -863,7 +867,7 @@ class Dataset():
:param Union[function, None] callback: Callback
:returns: Dataset with the new data.
"""
shift = self.find_shift(dimension, h_max, h_step, indices=indices)
shift = self.find_shift(dimension, steps, indices=indices)
return self.apply_shift(shift, dimension, indices=indices)
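For completeness, the one-call form with the new parameter; the values and the `dataset` object are illustrative:

corrected = dataset.find_and_apply_shift(dimension=[1, 0], steps=100)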
def _cascade_nmf(self, num_components, iterations, vstep=None, hstep=None, indices=None):
@@ -949,7 +953,7 @@ class Dataset():
:return: (H, W): The components matrix and the mixing matrix.
"""
bss_dir = self.dir + "/bss/"
bss_dir = self.dir + "/bss"
if not os.path.isdir(bss_dir):
os.mkdir(bss_dir)
if self._in_memory:
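As a side note on the `bss_dir` hunk above, which drops the trailing separator before the directory check: a hedged alternative (not what this commit does) is to build the path with `os.path.join` and create it idempotently:

bss_dir = os.path.join(self.dir, "bss")  # same attributes as in the hunk above
os.makedirs(bss_dir, exist_ok=True)      # no-op if the directory already exists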
@@ -26,7 +26,7 @@
__authors__ = ["J. Garriga"]
__license__ = "MIT"
__date__ = "26/05/2021"
__date__ = "07/07/2021"
import enum
@@ -169,18 +169,15 @@ def normalize(x):
return x / numpy.linalg.norm(x)
def improve_linear_shift(data, v, h_max, h_step, nimages=None, shift_approach="linear"):
def improve_linear_shift(data, v, h, epsilon, steps, nimages=None, shift_approach="linear"):
"""
Function to find the best shift between the images. It loops ``h_max * h_step`` times,
Function to find the best shift between the images. It loops ``steps`` times,
applying a different shift each time and keeping the one that gives the best result.
:param array_like data: The stack of images.
:param 2-dimensional array_like v: The vector with the direction of the shift.
:param number h_max: The maximum value that h can achieve, being h the shift between
images divided by the vector v (i.e the coordinates of the shift in base v).
:param number h_step: Spacing between the ``h`` tried. For any `` shift = h * v * idx``,
where ``idx`` is the index of the image to apply the shift to, this is the distance
between two adjacent values of h.
:param number h: Estimate of the shift between consecutive images (the coordinate of the shift in base ``v``).
:param float epsilon: Margin added to ``h``; candidate values are taken from 0 to ``h + epsilon``.
:param int steps: Controls the sampling of the search range; the spacing between candidate values is ``epsilon / steps``.
:param int nimages: The number of images used to find the best shift. It has to
be smaller than or equal to the length of the data. If it is smaller, the images used
are chosen with `numpy.random.choice`, without replacement.
@@ -195,23 +192,23 @@ def improve_linear_shift(data, v, h_max, h_step, nimages=None, shift_approach="l
iData = numpy.random.choice(iData, nimages, False)
score = {}
utils.advancement_display(0, h_max, "Finding shift")
for h in numpy.arange(0, h_max, h_step):
utils.advancement_display(0, h + epsilon, "Finding shift")
step = epsilon / steps
for h_ in numpy.arange(0, h + epsilon, step):
result = numpy.zeros(data[0].shape)
for iFrame in iData:
shift = h * v * iFrame
shift = h_ * v * iFrame
result += apply_shift(data[iFrame], shift, shift_approach)
# Compute score using normalized variance
# TODO: add more autofocus options
score[h] = normalized_variance(result)
utils.advancement_display(h + h_step, h_max, "Finding shift")
score[h_] = normalized_variance(result)
utils.advancement_display(h_ + step, h + epsilon, "Finding shift")
optimal_h = max(score.keys(), key=(lambda k: score[k]))
return optimal_h
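To make the loop above concrete, here is a small self-contained sketch of the same idea: grid-search the shift magnitude and score each candidate with a normalized-variance autofocus criterion. The scoring helper and the integer `numpy.roll` are simplified stand-ins for the module's own `normalized_variance` and `apply_shift`:

import numpy

def sharpness(image):
    # Stand-in autofocus score: variance normalized by the mean intensity.
    return numpy.var(image) / numpy.mean(image)

def grid_search_h(data, v, h, epsilon, steps):
    # Try magnitudes h_ in [0, h + epsilon) with spacing epsilon / steps and
    # keep the one whose re-aligned stack sum looks sharpest.
    v = numpy.asarray(v, dtype=float)
    best_h, best_score = 0.0, -numpy.inf
    for h_ in numpy.arange(0, h + epsilon, epsilon / steps):
        stacked = numpy.zeros(data[0].shape)
        for i, frame in enumerate(data):
            dy, dx = h_ * v * i
            # Crude integer shift instead of fft/linear interpolation.
            stacked += numpy.roll(frame, (int(round(dy)), int(round(dx))), axis=(0, 1))
        score = sharpness(stacked)
        if score > best_score:
            best_h, best_score = h_, score
    return best_h

Called as grid_search_h(data, v=[0, 1], h=1.0, epsilon=2.0, steps=50), it mirrors the argument order the refactored improve_linear_shift now takes (minus nimages and shift_approach).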
def shift_detection(data, h_max=0.5, h_step=0.01):
def shift_detection(data, steps, shift_approach="linear"):
"""
Finds the linear shift from a set of images.
@@ -232,7 +229,10 @@ def shift_detection(data, h_max=0.5, h_step=0.01):
second_sum += data[i]
shift = find_shift(first_sum, second_sum, 1000)
v = normalize(shift)
h = improve_linear_shift(data, v, h_max, h_step)
v_ = 2 * shift / len(data)
h = numpy.sqrt(v_[0]**2 + v_[1]**2)
epsilon = 2 * h
h = improve_linear_shift(data, v, h, epsilon, steps, shift_approach=shift_approach)
return numpy.outer(h * v, numpy.arange(len(data)))
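The seed values computed above can be summarized with a small worked example (numbers are illustrative and assume a constant drift between consecutive frames): the correlation is taken between the sums of the first and second halves of the stack, which are separated by roughly N/2 frames, so the per-frame drift is about 2 * shift / N; its norm seeds h, and epsilon = 2 * h brackets the refinement:

import numpy

n_frames = 10
shift = numpy.array([-5.0, 0.0])       # e.g. displacement measured between the two half-stack sums
v = shift / numpy.linalg.norm(shift)   # unit direction of the drift
per_frame = 2 * shift / n_frames       # roughly the drift between consecutive frames: [-1, 0]
h = numpy.linalg.norm(per_frame)       # 1.0 pixel per frame in this example
epsilon = 2 * h                        # improve_linear_shift then searches over [0, h + epsilon)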
@@ -26,7 +26,7 @@
__authors__ = ["J. Garriga"]
__license__ = "MIT"
__date__ = "23/10/2020"
__date__ = "16/06/2021"
import unittest
@@ -88,14 +88,14 @@ class TestImageRegistration(unittest.TestCase):
def test_improve_shift(self):
""" Tests the shift improvement"""
h = imageRegistration.improve_linear_shift(self.data, [1, 1], 0.1, 0.1, shift_approach='fft')
numpy.testing.assert_allclose(h, [0, 0])
h = imageRegistration.improve_linear_shift(self.data, [1, 1], 0.1, 0.1, shift_approach='linear')
numpy.testing.assert_allclose([0, 0], [0, 0])
h = imageRegistration.improve_linear_shift(self.data, [1, 1], 0.1, 0.1, 1, shift_approach='fft')
numpy.testing.assert_allclose(h, 0.)
h = imageRegistration.improve_linear_shift(self.data, [1, 1], 0.1, 0.1, 1, shift_approach='linear')
numpy.testing.assert_allclose(h, 0.)
@unittest.skipUnless(scipy, "scipy is missing")
def test_shift_detection10(self):
""" Tests the shift detection with tolerance of 5 decimals"""
""" Tests the shift detection with tolerance of 3 decimals"""
first_frame = numpy.zeros((100, 100))
# Simulating a series of frame with information in the middle.
first_frame[25:75, 25:75] = numpy.random.randint(50, 300, size=(50, 50))
@@ -104,12 +104,12 @@ class TestImageRegistration(unittest.TestCase):
for i in range(9):
data += [numpy.fft.ifftn(scipy.ndimage.fourier_shift(numpy.fft.fftn(data[-1]), shift)).real]
data = numpy.asanyarray(data, dtype=numpy.int16)
optimal_shift = imageRegistration.shift_detection(data, 2)
optimal_shift = imageRegistration.shift_detection(data, 100, shift_approach="fft")
shift = [[0, -1, -2, -3, -4, -5, -6, -7, -8, -9],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
numpy.testing.assert_allclose(shift, optimal_shift, rtol=1e-05)
numpy.testing.assert_allclose(shift, optimal_shift, rtol=1e-03)
@unittest.skipUnless(scipy, "scipy is missing")
def test_shift_detection01(self):
@@ -122,12 +122,12 @@ class TestImageRegistration(unittest.TestCase):
for i in range(9):
data += [numpy.fft.ifftn(scipy.ndimage.fourier_shift(numpy.fft.fftn(data[-1]), shift)).real]
data = numpy.asanyarray(data, dtype=numpy.int16)
optimal_shift = imageRegistration.shift_detection(data, 2)
optimal_shift = imageRegistration.shift_detection(data, 100, shift_approach="fft")
shift = [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, -1, -2, -3, -4, -5, -6, -7, -8, -9]]
numpy.testing.assert_allclose(shift, optimal_shift, rtol=1e-05)
numpy.testing.assert_allclose(shift, optimal_shift, rtol=1e-03)
@unittest.skipUnless(scipy, "scipy is missing")
def test_shift_detection11(self):
@@ -141,7 +141,7 @@ class TestImageRegistration(unittest.TestCase):
data += [numpy.fft.ifftn(scipy.ndimage.fourier_shift(numpy.fft.fftn(data[-1]), shift)).real]
data = numpy.asanyarray(data, dtype=numpy.int16)
optimal_shift = imageRegistration.shift_detection(data, 2)
optimal_shift = imageRegistration.shift_detection(data, 100)
shift = [[0, -1, -2, -3, -4, -5, -6, -7, -8, -9],
[0, -1, -2, -3, -4, -5, -6, -7, -8, -9]]
@@ -159,8 +159,7 @@ class TestImageRegistration(unittest.TestCase):
for i in range(9):
data += [numpy.fft.ifftn(scipy.ndimage.fourier_shift(numpy.fft.fftn(data[-1]), shift)).real]
data = numpy.asanyarray(data, dtype=numpy.int16)
optimal_shift = imageRegistration.shift_detection(data, 1)
optimal_shift = imageRegistration.shift_detection(data, 100)
shift = [[0, -0.5, -1, -1.5, -2, -2.5, -3, -3.5, -4, -4.5],
[0, -0.2, -0.4, -0.6, -0.8, -1, -1.2, -1.4, -1.6, -1.8]]
@@ -170,7 +169,7 @@ class TestImageRegistration(unittest.TestCase):
""" Tests the shift correction of a [0,0] shift."""
data = imageRegistration.shift_correction(self.data, numpy.outer([0, 0], numpy.arange(3)))
numpy.testing.assert_allclose(data, self.data, rtol=1e-05)
numpy.testing.assert_allclose(data, self.data, rtol=1e-03)
def test_shift_correction01(self):
""" Tests the shift correction of a [0,1] shift."""
@@ -192,7 +191,7 @@ class TestImageRegistration(unittest.TestCase):
[1, 2, 3, 4, 5]]])
data = imageRegistration.shift_correction(self.data, numpy.outer([1, 0], numpy.arange(3)))
numpy.testing.assert_allclose(data, expected, rtol=1e-05)
numpy.testing.assert_allclose(data, expected, rtol=1e-03)
def test_shift_correction10(self):
""" Tests the shift correction of a [1,0] shift."""
@@ -313,7 +312,7 @@ class TestReshapedShift(unittest.TestCase):
dataset = self.dataset.reshape_data()
# Detects shift using only images where value 1 of dimension 1 is fixed
optimal_shift = dataset.find_shift(dimension=[1, 0], h_max=1)
optimal_shift = dataset.find_shift(dimension=[1, 0])
shift = [[0, -0.5, -1, -1.5, -2],
[0, -0.2, -0.4, -0.6, -0.8]]
@@ -334,7 +333,7 @@ class TestReshapedShift(unittest.TestCase):
dataset = self.dataset.reshape_data()
# Detects shift using only images where value 1 of dimension 1 is fixed
optimal_shift = dataset.find_shift(dimension=[0, 0], h_max=3)
optimal_shift = dataset.find_shift(dimension=[0, 0])
shift = [[0, -2.5],
[0, -1]]
@@ -358,11 +357,12 @@ class TestReshapedShift(unittest.TestCase):
self.dataset.find_dimensions(POSITIONER_METADATA)
dataset = self.dataset.reshape_data()
dataset = dataset.find_and_apply_shift(dimension=[1, 0], h_max=1)
dataset = dataset.find_and_apply_shift(dimension=[1, 0])
for frame in dataset.data.take(0, 0):
print(numpy.max(abs(data[0] - frame)))
# Check if the difference between the shifted frames and the sample frame is small enough
self.assertTrue((abs(data[0] - frame) < 5).all())
self.assertTrue((abs(data[0] - frame) < 6).all())
def tearDown(self):
shutil.rmtree(self._dir)
@@ -26,7 +26,7 @@
__authors__ = ["J. Garriga"]
__license__ = "MIT"
__date__ = "22/12/2020"
__date__ = "05/07/2021"
# import os
import numpy
@@ -42,12 +42,46 @@ from .operationThread import OperationThread
from .utils import ChooseDimensionDock
class ShiftCorrectionDialog(qt.QDialog):
"""
Dialog with `ShiftCorrectionWidget` as main window and standard buttons.
"""
okSignal = qt.Signal()
def __init__(self, parent=None):
qt.QDialog.__init__(self, parent)
self.setWindowFlags(qt.Qt.Widget)
types = qt.QDialogButtonBox.Ok
self._buttons = qt.QDialogButtonBox(parent=self)
self._buttons.setStandardButtons(types)
self._buttons.setEnabled(False)
resetB = self._buttons.addButton(self._buttons.Reset)
self.mainWindow = ShiftCorrectionWidget(parent=self)
self.mainWindow.setAttribute(qt.Qt.WA_DeleteOnClose)
self.setLayout(qt.QVBoxLayout())
self.layout().addWidget(self.mainWindow)
self.layout().addWidget(self._buttons)
self._buttons.accepted.connect(self.okSignal.emit)
resetB.clicked.connect(self.mainWindow.resetStack)
self.mainWindow.computingSignal.connect(self._toggleButton)
def setDataset(self, dataset, indices=None, bg_indices=None, bg_dataset=None):
if dataset is not None:
self._buttons.setEnabled(True)
self.mainWindow.setDataset(dataset, indices, bg_indices, bg_dataset)
def _toggleButton(self, state):
self._buttons.button(qt.QDialogButtonBox.Ok).setEnabled(not state)
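A brief usage sketch of the new dialog; the dataset object and the connected slot are placeholders, not part of this change:

dialog = ShiftCorrectionDialog()
dialog.setDataset(my_dataset)     # hypothetical Dataset instance; enables the Ok button
dialog.okSignal.connect(on_ok)    # hypothetical slot invoked when Ok is pressed
dialog.show()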
class ShiftCorrectionWidget(qt.QMainWindow):
"""
A widget to apply shift correction to a stack of images
"""
sigComputed = qt.Signal()
sigProgressChanged = qt.Signal(int)
computingSignal = qt.Signal(bool)
def __init__(self, parent=None):
qt.QMainWindow.__init__(self, parent)
@@ -124,11 +158,16 @@ class ShiftCorrectionWidget(qt.QMainWindow):
self._inputDock.widget.correctionB.setEnabled(False)
self._inputDock.widget.abortB.show()
self.thread_correction.start()
self.computingSignal.emit(True)
def abort(self):
self._inputDock.widget.abortB.setEnabled(False)
self._update_dataset.stop_operation(Operation.SHIFT)
def resetStack(self):
self._update_dataset = self.dataset
self.setStack()
def updateProgress(self, progress):
self.sigProgressChanged.emit(progress)
@@ -138,11 +177,13 @@ class ShiftCorrectionWidget(qt.QMainWindow):
self.thread_detection.setArgs(self._dimension, indices=self.indices)
self.thread_detection.finished.connect(self._updateShift)
self.thread_detection.start()
self.computingSignal.emit(True)
def _updateShift(self):
self._inputDock.widget._findShiftB.setEnabled(True)
self.thread_detection.finished.disconnect(self._updateShift)
self.shift = numpy.round(self.thread_detection.data[:, 1], 5)
self.computingSignal.emit(False)
def _updateData(self):
"""
@@ -152,13 +193,13 @@ class ShiftCorrectionWidget(qt.QMainWindow):
self._inputDock.widget.abortB.hide()
self._inputDock.widget.abortB.setEnabled(True)
self._inputDock.widget.correctionB.setEnabled(True)
self.computingSignal.emit(False)
if self.thread_correction.data:
self._update_dataset = self.thread_correction.data
assert self._update_dataset is not None
if self._inputDock.widget.checkbox.isChecked():
self._chooseDimensionDock.widget._checkbox.setChecked(False)
self.setStack(self._update_dataset)
self.sigComputed.emit()
else:
print("\nCorrection aborted")
@@ -32,7 +32,7 @@ __date__ = "11/12/2020"
from Orange.widgets.settings import Setting
from Orange.widgets.widget import OWWidget, Input, Output
from silx.gui.colors import Colormap
from darfix.gui.shiftCorrectionWidget import ShiftCorrectionWidget
from darfix.gui.shiftCorrectionWidget import ShiftCorrectionDialog
class ShiftCorrectionWidgetOW(OWWidget):
@@ -60,8 +60,8 @@ class ShiftCorrectionWidgetOW(OWWidget):
def __init__(self):
super().__init__()
self._widget = ShiftCorrectionWidget(parent=self)
self._widget.sigComputed.connect(self._sendSignal)
self._widget = ShiftCorrectionDialog(parent=self)
self._widget.okSignal.connect(self._sendSignal)
self.controlArea.layout().addWidget(self._widget)
if self.shift:
@@ -86,6 +86,6 @@ class ShiftCorrectionWidgetOW(OWWidget):
Function to emit the new dataset.
"""
self.shift = self._widget.shift
self.Outputs.dataset.send(self._widget.getDataset())
self.Outputs.colormap.send(self._widget.getStackViewColormap())
self.Outputs.dataset.send(self._widget.mainWindow.getDataset())
self.Outputs.colormap.send(self._widget.mainWindow.getStackViewColormap())
self.close()