Commit dca8bd41 authored by Julia Garriga Ferrer's avatar Julia Garriga Ferrer
Browse files

Merge branch 'auto_shift_on_dimension' into 'master'

Auto shift on dimension

See merge request !134
parents d5dd3a11 2f8fe068
Pipeline #53242 passed with stage
in 4 minutes and 35 seconds
......@@ -285,6 +285,8 @@ class Dataset():
return data.flatten()
else:
data = self.data.flatten()
if return_indices:
return (data, numpy.arange(len(data))) if indices is None else (data[indices], indices)
return data if indices is None else data[indices]
@property
......@@ -754,6 +756,28 @@ class Dataset():
"""
return shift_detection(self.get_data(indices, dimension), steps)
def find_shift_along_dimension(self, dimension, steps=50, indices=None):
shift = []
for value in range(self.dims.get(dimension[0]).size):
s = self.find_shift([dimension[0], value], steps, indices)
shift.append(s)
return numpy.array(shift)
def apply_shift_along_dimension(self, shift, dimension, shift_approach="fft", indices=None,
callback=None, _dir=None):
dataset = self
for value in range(self.dims.get(dimension[0]).size):
data, rindices = self.get_data(indices=indices, dimension=[dimension[0], value],
return_indices=True)
frames = numpy.arange(self.get_data(indices=indices,
dimension=[dimension[0], value]).shape[0])
dataset = dataset.apply_shift(numpy.outer(shift[value], frames), [dimension[0], value],
shift_approach, indices, callback, _dir)
return dataset
def apply_shift(self, shift, dimension=None, shift_approach="fft", indices=None,
callback=None, _dir=None):
"""
......@@ -782,7 +806,7 @@ class Dataset():
if not os.path.isdir(_dir):
os.mkdir(_dir)
data = self.get_data(indices, dimension)
data, rindices = self.get_data(indices, dimension, return_indices=True)
self._lock.acquire()
self.operations_state[Operation.SHIFT] = 1
self._lock.release()
......@@ -792,7 +816,7 @@ class Dataset():
_file.create_dataset("update_dataset", data=_file["dataset"])
dataset_name = "update_dataset"
else:
_file.create_dataset("dataset", data.shape, dtype=data.dtype)
_file.create_dataset("dataset", self.get_data().shape, dtype=self.data.dtype)
io_utils.advancement_display(0, len(data), "Applying shift")
if dimension is not None:
......@@ -802,16 +826,16 @@ class Dataset():
dimension[0] = [dimension[0]]
dimension[1] = [dimension[1]]
urls = []
for i in range(len(data)):
for i, idx in enumerate(rindices):
if not self.operations_state[Operation.SHIFT]:
del _file["update_dataset"]
return
img = apply_shift(data[i], shift[:, i], shift_approach)
if shift[:, i].all() > 1:
shift_approach = "linear"
_file[dataset_name][i] = img
urls.append(DataUrl(file_path=_dir + '/data.hdf5', data_path="/dataset", data_slice=i, scheme='silx'))
io_utils.advancement_display(i + 1, len(data), "Applying shift")
_file[dataset_name][idx] = img
urls.append(DataUrl(file_path=_dir + '/data.hdf5', data_path="/dataset", data_slice=idx, scheme='silx'))
io_utils.advancement_display(i + 1, len(rindices), "Applying shift")
# Replace specific urls that correspond to the modified data
new_urls = numpy.array(self.data.urls, dtype=object)
......@@ -1339,18 +1363,16 @@ class Data(numpy.ndarray):
def __new__(cls, urls, metadata, in_memory=True, data=None):
urls = numpy.asarray(urls)
if in_memory:
if data is not None and urls.shape == data.shape[:-2]:
input_data = data
else:
if data is None or urls.shape != data.shape[:-2]:
# Create array as stack of images
input_data = []
for url in urls.flatten():
input_data += [utils.get_data(url)]
input_data = numpy.asarray(input_data)
input_data.append(utils.get_data(url))
data = numpy.asarray(input_data)
shape = list(urls.shape)
shape.append(input_data.shape[-2])
shape.append(input_data.shape[-1])
obj = input_data.reshape(shape).view(cls)
shape.append(data.shape[-2])
shape.append(data.shape[-1])
obj = data.reshape(shape).view(cls)
else:
# Access image one at a time using url
obj = super(Data, cls).__new__(cls, urls.shape)
......
......@@ -231,6 +231,8 @@ def shift_detection(data, steps, shift_approach="linear"):
v = normalize(shift)
v_ = 2 * shift / len(data)
h = numpy.sqrt(v_[0]**2 + v_[1]**2)
if not h:
return numpy.outer([0, 0], numpy.arange(len(data)))
epsilon = 2 * h
h = improve_linear_shift(data, v, h, epsilon, steps, shift_approach=shift_approach)
return numpy.outer(h * v, numpy.arange(len(data)))
......
......@@ -229,6 +229,46 @@ class TestDimension(unittest.TestCase):
self.assertEqual(new_dataset.data.urls[0, 0, 0], dataset.data.urls[0, 0, 0])
self.assertNotEqual(new_dataset.data.urls[0, 0, 1], dataset.data.urls[0, 0, 1])
def test_find_shift_along_dimension(self):
""" Tests the shift detection along a dimension"""
# In memory
self.in_memory_dataset.find_dimensions(POSITIONER_METADATA)
dataset = self.in_memory_dataset.reshape_data()
indices = numpy.arange(10)
shift = dataset.find_shift_along_dimension(dimension=[1], indices=indices)
self.assertEqual(shift.shape, (2, 2, 5))
shift = dataset.find_shift_along_dimension(dimension=[0], indices=indices)
self.assertEqual(shift.shape, (5, 2, 2))
# In disk
self.in_disk_dataset.find_dimensions(POSITIONER_METADATA)
dataset = self.in_disk_dataset.reshape_data()
indices = numpy.arange(10)
shift = dataset.find_shift_along_dimension(dimension=[1], indices=indices)
self.assertEqual(shift.shape, (2, 2, 5))
shift = dataset.find_shift_along_dimension(dimension=[0], indices=indices)
self.assertEqual(shift.shape, (5, 2, 2))
def test_apply_shift_along_dimension(self):
""" Tests the shift correction with dimensions and indices"""
# In memory
self.in_memory_dataset.find_dimensions(POSITIONER_METADATA)
dataset = self.in_memory_dataset.reshape_data()
shift = numpy.random.random((4, 2, 2))
new_dataset = dataset.apply_shift_along_dimension(shift=shift, dimension=[1], indices=[1, 2, 3, 4])
self.assertEqual(new_dataset.data.urls[0, 0, 0], dataset.data.urls[0, 0, 0])
self.assertNotEqual(new_dataset.data.urls[0, 0, 1], dataset.data.urls[0, 0, 1])
# In disk
self.in_disk_dataset.find_dimensions(POSITIONER_METADATA)
dataset = self.in_disk_dataset.reshape_data()
shift = numpy.random.random((4, 2, 2))
new_dataset = dataset.apply_shift_along_dimension(shift=shift, dimension=[1], indices=[1, 2, 3, 4])
self.assertEqual(new_dataset.data.urls[0, 0, 0], dataset.data.urls[0, 0, 0])
self.assertNotEqual(new_dataset.data.urls[0, 0, 1], dataset.data.urls[0, 0, 1])
def test_zsum(self):
""" Tests the shift detection with dimensions and indices"""
......
......@@ -87,7 +87,8 @@ class ShiftCorrectionWidget(qt.QMainWindow):
qt.QMainWindow.__init__(self, parent)
self.setWindowFlags(qt.Qt.Widget)
self._shift = [0, 0]
self._shift = numpy.array([0, 0])
self._filtered_shift = None
self._dimension = None
self._update_dataset = None
self.indices = None
......@@ -116,6 +117,8 @@ class ShiftCorrectionWidget(qt.QMainWindow):
self._inputDock.widget.correctionB.clicked.connect(self.correct)
self._inputDock.widget.abortB.clicked.connect(self.abort)
self._inputDock.widget.dxLE.editingFinished.connect(self._updateShiftValue)
self._inputDock.widget.dyLE.editingFinished.connect(self._updateShiftValue)
self._inputDock.widget._findShiftB.clicked.connect(self._findShift)
self._chooseDimensionDock.widget.filterChanged.connect(self._filterStack)
self._chooseDimensionDock.widget.stateDisabled.connect(self._wholeStack)
......@@ -149,11 +152,14 @@ class ShiftCorrectionWidget(qt.QMainWindow):
"""
dx = self._inputDock.widget.getDx()
dy = self._inputDock.widget.getDy()
self.shift = [dy, dx]
dimension = self._dimension if not self._inputDock.widget.checkbox.isChecked() else None
frames = numpy.arange(self._update_dataset.get_data(indices=self.indices, dimension=dimension).shape[0])
self.thread_correction = OperationThread(self, self._update_dataset.apply_shift)
self.thread_correction.setArgs(numpy.outer(self.shift, frames), dimension, indices=self.indices)
self.shift = numpy.array([dy, dx])
if self._filtered_shift is None or self._inputDock.widget.checkbox.isChecked():
frames = numpy.arange(self._update_dataset.get_data(indices=self.indices, dimension=self._dimension).shape[0])
self.thread_correction = OperationThread(self, self._update_dataset.apply_shift)
self.thread_correction.setArgs(numpy.outer(self.shift, frames), self._dimension, indices=self.indices)
else:
self.thread_correction = OperationThread(self, self._update_dataset.apply_shift_along_dimension)
self.thread_correction.setArgs(self._filtered_shift, self._dimension[0], indices=self.indices)
self.thread_correction.finished.connect(self._updateData)
self._inputDock.widget.correctionB.setEnabled(False)
self._inputDock.widget.abortB.show()
......@@ -172,17 +178,29 @@ class ShiftCorrectionWidget(qt.QMainWindow):
self.sigProgressChanged.emit(progress)
def _findShift(self):
self.thread_detection = OperationThread(self, self._update_dataset.find_shift)
if self._filtered_shift is not None:
self.thread_detection = OperationThread(self, self._update_dataset.find_shift_along_dimension)
self.thread_detection.setArgs(self._dimension[0], indices=self.indices)
else:
self.thread_detection = OperationThread(self, self._update_dataset.find_shift)
self.thread_detection.setArgs(self._dimension, indices=self.indices)
self._inputDock.widget._findShiftB.setEnabled(False)
self.thread_detection.setArgs(self._dimension, indices=self.indices)
self.thread_detection.finished.connect(self._updateShift)
self.thread_detection.start()
self.computingSignal.emit(True)
def _updateShiftValue(self):
if self._filtered_shift is not None:
self._filtered_shift[self._dimension[1]] = [self._inputDock.widget.getDy(), self._inputDock.widget.getDx()]
def _updateShift(self):
self._inputDock.widget._findShiftB.setEnabled(True)
self.thread_detection.finished.disconnect(self._updateShift)
self.shift = numpy.round(self.thread_detection.data[:, 1], 5)
if self._filtered_shift is None:
self.shift = numpy.round(self.thread_detection.data[:, 1], 5)
else:
self._filtered_shift = numpy.round(self.thread_detection.data[:, :, 1], 5)
self.shift = self._filtered_shift[self._dimension[1][0]]
self.computingSignal.emit(False)
def _updateData(self):
......@@ -197,8 +215,6 @@ class ShiftCorrectionWidget(qt.QMainWindow):
if self.thread_correction.data:
self._update_dataset = self.thread_correction.data
assert self._update_dataset is not None
if self._inputDock.widget.checkbox.isChecked():
self._chooseDimensionDock.widget._checkbox.setChecked(False)
self.setStack(self._update_dataset)
else:
print("\nCorrection aborted")
......@@ -213,10 +229,7 @@ class ShiftCorrectionWidget(qt.QMainWindow):
if dataset is None:
dataset = self.dataset
nframe = self._sv.getFrameNumber()
if self.indices is None:
self._sv.setStack(dataset.get_data() if dataset is not None else None)
else:
self._sv.setStack(dataset.get_data(self.indices) if dataset is not None else None)
self._sv.setStack(dataset.get_data(self.indices, self._dimension) if dataset is not None else None)
self._sv.setFrameNumber(nframe)
def clearStack(self):
......@@ -224,9 +237,15 @@ class ShiftCorrectionWidget(qt.QMainWindow):
self._inputDock.widget.correctionB.setEnabled(False)
def _filterStack(self, dim=0, val=0):
self._inputDock.widget.checkbox.show()
self._dimension = [dim, val]
data = self._update_dataset.get_data(self.indices, self._dimension)
if self.dataset.dims.ndim == 2:
stack_size = self.dataset.dims.get(dim[0]).size
reset_shift = self._filtered_shift is None or self._filtered_shift.shape[0] != stack_size
self._inputDock.widget.checkbox.show()
self._filtered_shift = numpy.zeros((stack_size, 2)) if reset_shift else self._filtered_shift
self.shift = self._filtered_shift[val[0]]
if data.shape[0]:
self._sv.setStack(data)
else:
......@@ -234,6 +253,8 @@ class ShiftCorrectionWidget(qt.QMainWindow):
def _wholeStack(self):
self._dimension = None
self._filtered_shift = None
self.shift = numpy.array([0, 0])
self._inputDock.widget.checkbox.hide()
self.setStack(self._update_dataset)
......@@ -296,8 +317,8 @@ class _InputWidget(qt.QWidget):
self.correctionB = qt.QPushButton("Correct")
self.abortB = qt.QPushButton("Abort")
self.abortB.hide()
self.checkbox = qt.QCheckBox("Apply to whole dataset")
self.checkbox.setChecked(True)
self.checkbox = qt.QCheckBox("Apply only to selected value")
self.checkbox.setChecked(False)
self.checkbox.hide()
self.dxLE.setValidator(qt.QDoubleValidator())
......
......@@ -36,7 +36,8 @@ import sys
import numpy
from silx.gui import qt
from darfix.test.utils import createDataset
from darfix.test.utils import createRandomDataset
from darfix.core.dimension import POSITIONER_METADATA
from darfix.gui.shiftCorrectionWidget import ShiftCorrectionWidget
......@@ -61,8 +62,9 @@ def exec_():
data = numpy.repeat(data, 10, axis=0)
dataset = createDataset(data=data)
w.setDataset(dataset)
dataset = createRandomDataset((100, 100), nb_data_files=10, header=True)
dataset.find_dimensions(POSITIONER_METADATA)
w.setDataset(dataset.reshape_data())
w.show()
qapp.exec_()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment