diff --git a/darfix/core/dataset.py b/darfix/core/dataset.py index 54f714da1ea4b08a02ffefcdf24c94b68a8c3008..fcd860b69383ee7b2f291cf9d3379a6db3d7b4be 100644 --- a/darfix/core/dataset.py +++ b/darfix/core/dataset.py @@ -285,6 +285,8 @@ class Dataset(): return data.flatten() else: data = self.data.flatten() + if return_indices: + return (data, numpy.arange(len(data))) if indices is None else (data[indices], indices) return data if indices is None else data[indices] @property @@ -754,6 +756,28 @@ class Dataset(): """ return shift_detection(self.get_data(indices, dimension), steps) + def find_shift_along_dimension(self, dimension, steps=50, indices=None): + shift = [] + for value in range(self.dims.get(dimension[0]).size): + s = self.find_shift([dimension[0], value], steps, indices) + shift.append(s) + + return numpy.array(shift) + + def apply_shift_along_dimension(self, shift, dimension, shift_approach="fft", indices=None, + callback=None, _dir=None): + + dataset = self + for value in range(self.dims.get(dimension[0]).size): + data, rindices = self.get_data(indices=indices, dimension=[dimension[0], value], + return_indices=True) + frames = numpy.arange(self.get_data(indices=indices, + dimension=[dimension[0], value]).shape[0]) + dataset = dataset.apply_shift(numpy.outer(shift[value], frames), [dimension[0], value], + shift_approach, indices, callback, _dir) + + return dataset + def apply_shift(self, shift, dimension=None, shift_approach="fft", indices=None, callback=None, _dir=None): """ @@ -782,7 +806,7 @@ class Dataset(): if not os.path.isdir(_dir): os.mkdir(_dir) - data = self.get_data(indices, dimension) + data, rindices = self.get_data(indices, dimension, return_indices=True) self._lock.acquire() self.operations_state[Operation.SHIFT] = 1 self._lock.release() @@ -792,7 +816,7 @@ class Dataset(): _file.create_dataset("update_dataset", data=_file["dataset"]) dataset_name = "update_dataset" else: - _file.create_dataset("dataset", data.shape, dtype=data.dtype) + _file.create_dataset("dataset", self.get_data().shape, dtype=self.data.dtype) io_utils.advancement_display(0, len(data), "Applying shift") if dimension is not None: @@ -802,16 +826,16 @@ class Dataset(): dimension[0] = [dimension[0]] dimension[1] = [dimension[1]] urls = [] - for i in range(len(data)): + for i, idx in enumerate(rindices): if not self.operations_state[Operation.SHIFT]: del _file["update_dataset"] return img = apply_shift(data[i], shift[:, i], shift_approach) if shift[:, i].all() > 1: shift_approach = "linear" - _file[dataset_name][i] = img - urls.append(DataUrl(file_path=_dir + '/data.hdf5', data_path="/dataset", data_slice=i, scheme='silx')) - io_utils.advancement_display(i + 1, len(data), "Applying shift") + _file[dataset_name][idx] = img + urls.append(DataUrl(file_path=_dir + '/data.hdf5', data_path="/dataset", data_slice=idx, scheme='silx')) + io_utils.advancement_display(i + 1, len(rindices), "Applying shift") # Replace specific urls that correspond to the modified data new_urls = numpy.array(self.data.urls, dtype=object) @@ -1339,18 +1363,16 @@ class Data(numpy.ndarray): def __new__(cls, urls, metadata, in_memory=True, data=None): urls = numpy.asarray(urls) if in_memory: - if data is not None and urls.shape == data.shape[:-2]: - input_data = data - else: + if data is None or urls.shape != data.shape[:-2]: # Create array as stack of images input_data = [] for url in urls.flatten(): - input_data += [utils.get_data(url)] - input_data = numpy.asarray(input_data) + input_data.append(utils.get_data(url)) + data = numpy.asarray(input_data) shape = list(urls.shape) - shape.append(input_data.shape[-2]) - shape.append(input_data.shape[-1]) - obj = input_data.reshape(shape).view(cls) + shape.append(data.shape[-2]) + shape.append(data.shape[-1]) + obj = data.reshape(shape).view(cls) else: # Access image one at a time using url obj = super(Data, cls).__new__(cls, urls.shape) diff --git a/darfix/core/imageRegistration.py b/darfix/core/imageRegistration.py index 376c2d9496ba00ca6a49e5298f9f2b5d020a5e2e..ab817b83ff7884b9a147d7d3dff604e110f8680a 100644 --- a/darfix/core/imageRegistration.py +++ b/darfix/core/imageRegistration.py @@ -231,6 +231,8 @@ def shift_detection(data, steps, shift_approach="linear"): v = normalize(shift) v_ = 2 * shift / len(data) h = numpy.sqrt(v_[0]**2 + v_[1]**2) + if not h: + return numpy.outer([0, 0], numpy.arange(len(data))) epsilon = 2 * h h = improve_linear_shift(data, v, h, epsilon, steps, shift_approach=shift_approach) return numpy.outer(h * v, numpy.arange(len(data))) diff --git a/darfix/core/test/test_dimension.py b/darfix/core/test/test_dimension.py index 4f6b603beed9578ff55fd1ffae427622af7ee4f5..838c21f161483f145284667a600a5a2a251441e5 100644 --- a/darfix/core/test/test_dimension.py +++ b/darfix/core/test/test_dimension.py @@ -229,6 +229,46 @@ class TestDimension(unittest.TestCase): self.assertEqual(new_dataset.data.urls[0, 0, 0], dataset.data.urls[0, 0, 0]) self.assertNotEqual(new_dataset.data.urls[0, 0, 1], dataset.data.urls[0, 0, 1]) + def test_find_shift_along_dimension(self): + """ Tests the shift detection along a dimension""" + + # In memory + self.in_memory_dataset.find_dimensions(POSITIONER_METADATA) + dataset = self.in_memory_dataset.reshape_data() + indices = numpy.arange(10) + shift = dataset.find_shift_along_dimension(dimension=[1], indices=indices) + self.assertEqual(shift.shape, (2, 2, 5)) + shift = dataset.find_shift_along_dimension(dimension=[0], indices=indices) + self.assertEqual(shift.shape, (5, 2, 2)) + + # In disk + self.in_disk_dataset.find_dimensions(POSITIONER_METADATA) + dataset = self.in_disk_dataset.reshape_data() + indices = numpy.arange(10) + shift = dataset.find_shift_along_dimension(dimension=[1], indices=indices) + self.assertEqual(shift.shape, (2, 2, 5)) + shift = dataset.find_shift_along_dimension(dimension=[0], indices=indices) + self.assertEqual(shift.shape, (5, 2, 2)) + + def test_apply_shift_along_dimension(self): + """ Tests the shift correction with dimensions and indices""" + + # In memory + self.in_memory_dataset.find_dimensions(POSITIONER_METADATA) + dataset = self.in_memory_dataset.reshape_data() + shift = numpy.random.random((4, 2, 2)) + new_dataset = dataset.apply_shift_along_dimension(shift=shift, dimension=[1], indices=[1, 2, 3, 4]) + self.assertEqual(new_dataset.data.urls[0, 0, 0], dataset.data.urls[0, 0, 0]) + self.assertNotEqual(new_dataset.data.urls[0, 0, 1], dataset.data.urls[0, 0, 1]) + # In disk + self.in_disk_dataset.find_dimensions(POSITIONER_METADATA) + dataset = self.in_disk_dataset.reshape_data() + shift = numpy.random.random((4, 2, 2)) + new_dataset = dataset.apply_shift_along_dimension(shift=shift, dimension=[1], indices=[1, 2, 3, 4]) + + self.assertEqual(new_dataset.data.urls[0, 0, 0], dataset.data.urls[0, 0, 0]) + self.assertNotEqual(new_dataset.data.urls[0, 0, 1], dataset.data.urls[0, 0, 1]) + def test_zsum(self): """ Tests the shift detection with dimensions and indices""" diff --git a/darfix/gui/shiftCorrectionWidget.py b/darfix/gui/shiftCorrectionWidget.py index c20f68441dbb2df7e522b1af1f2194a609aafbb7..6c8e4edc875b7653096cb7c8ea45a84282cdf06b 100644 --- a/darfix/gui/shiftCorrectionWidget.py +++ b/darfix/gui/shiftCorrectionWidget.py @@ -87,7 +87,8 @@ class ShiftCorrectionWidget(qt.QMainWindow): qt.QMainWindow.__init__(self, parent) self.setWindowFlags(qt.Qt.Widget) - self._shift = [0, 0] + self._shift = numpy.array([0, 0]) + self._filtered_shift = None self._dimension = None self._update_dataset = None self.indices = None @@ -116,6 +117,8 @@ class ShiftCorrectionWidget(qt.QMainWindow): self._inputDock.widget.correctionB.clicked.connect(self.correct) self._inputDock.widget.abortB.clicked.connect(self.abort) + self._inputDock.widget.dxLE.editingFinished.connect(self._updateShiftValue) + self._inputDock.widget.dyLE.editingFinished.connect(self._updateShiftValue) self._inputDock.widget._findShiftB.clicked.connect(self._findShift) self._chooseDimensionDock.widget.filterChanged.connect(self._filterStack) self._chooseDimensionDock.widget.stateDisabled.connect(self._wholeStack) @@ -149,11 +152,14 @@ class ShiftCorrectionWidget(qt.QMainWindow): """ dx = self._inputDock.widget.getDx() dy = self._inputDock.widget.getDy() - self.shift = [dy, dx] - dimension = self._dimension if not self._inputDock.widget.checkbox.isChecked() else None - frames = numpy.arange(self._update_dataset.get_data(indices=self.indices, dimension=dimension).shape[0]) - self.thread_correction = OperationThread(self, self._update_dataset.apply_shift) - self.thread_correction.setArgs(numpy.outer(self.shift, frames), dimension, indices=self.indices) + self.shift = numpy.array([dy, dx]) + if self._filtered_shift is None or self._inputDock.widget.checkbox.isChecked(): + frames = numpy.arange(self._update_dataset.get_data(indices=self.indices, dimension=self._dimension).shape[0]) + self.thread_correction = OperationThread(self, self._update_dataset.apply_shift) + self.thread_correction.setArgs(numpy.outer(self.shift, frames), self._dimension, indices=self.indices) + else: + self.thread_correction = OperationThread(self, self._update_dataset.apply_shift_along_dimension) + self.thread_correction.setArgs(self._filtered_shift, self._dimension[0], indices=self.indices) self.thread_correction.finished.connect(self._updateData) self._inputDock.widget.correctionB.setEnabled(False) self._inputDock.widget.abortB.show() @@ -172,17 +178,29 @@ class ShiftCorrectionWidget(qt.QMainWindow): self.sigProgressChanged.emit(progress) def _findShift(self): - self.thread_detection = OperationThread(self, self._update_dataset.find_shift) + if self._filtered_shift is not None: + self.thread_detection = OperationThread(self, self._update_dataset.find_shift_along_dimension) + self.thread_detection.setArgs(self._dimension[0], indices=self.indices) + else: + self.thread_detection = OperationThread(self, self._update_dataset.find_shift) + self.thread_detection.setArgs(self._dimension, indices=self.indices) self._inputDock.widget._findShiftB.setEnabled(False) - self.thread_detection.setArgs(self._dimension, indices=self.indices) self.thread_detection.finished.connect(self._updateShift) self.thread_detection.start() self.computingSignal.emit(True) + def _updateShiftValue(self): + if self._filtered_shift is not None: + self._filtered_shift[self._dimension[1]] = [self._inputDock.widget.getDy(), self._inputDock.widget.getDx()] + def _updateShift(self): self._inputDock.widget._findShiftB.setEnabled(True) self.thread_detection.finished.disconnect(self._updateShift) - self.shift = numpy.round(self.thread_detection.data[:, 1], 5) + if self._filtered_shift is None: + self.shift = numpy.round(self.thread_detection.data[:, 1], 5) + else: + self._filtered_shift = numpy.round(self.thread_detection.data[:, :, 1], 5) + self.shift = self._filtered_shift[self._dimension[1][0]] self.computingSignal.emit(False) def _updateData(self): @@ -197,8 +215,6 @@ class ShiftCorrectionWidget(qt.QMainWindow): if self.thread_correction.data: self._update_dataset = self.thread_correction.data assert self._update_dataset is not None - if self._inputDock.widget.checkbox.isChecked(): - self._chooseDimensionDock.widget._checkbox.setChecked(False) self.setStack(self._update_dataset) else: print("\nCorrection aborted") @@ -213,10 +229,7 @@ class ShiftCorrectionWidget(qt.QMainWindow): if dataset is None: dataset = self.dataset nframe = self._sv.getFrameNumber() - if self.indices is None: - self._sv.setStack(dataset.get_data() if dataset is not None else None) - else: - self._sv.setStack(dataset.get_data(self.indices) if dataset is not None else None) + self._sv.setStack(dataset.get_data(self.indices, self._dimension) if dataset is not None else None) self._sv.setFrameNumber(nframe) def clearStack(self): @@ -224,9 +237,15 @@ class ShiftCorrectionWidget(qt.QMainWindow): self._inputDock.widget.correctionB.setEnabled(False) def _filterStack(self, dim=0, val=0): - self._inputDock.widget.checkbox.show() self._dimension = [dim, val] + data = self._update_dataset.get_data(self.indices, self._dimension) + if self.dataset.dims.ndim == 2: + stack_size = self.dataset.dims.get(dim[0]).size + reset_shift = self._filtered_shift is None or self._filtered_shift.shape[0] != stack_size + self._inputDock.widget.checkbox.show() + self._filtered_shift = numpy.zeros((stack_size, 2)) if reset_shift else self._filtered_shift + self.shift = self._filtered_shift[val[0]] if data.shape[0]: self._sv.setStack(data) else: @@ -234,6 +253,8 @@ class ShiftCorrectionWidget(qt.QMainWindow): def _wholeStack(self): self._dimension = None + self._filtered_shift = None + self.shift = numpy.array([0, 0]) self._inputDock.widget.checkbox.hide() self.setStack(self._update_dataset) @@ -296,8 +317,8 @@ class _InputWidget(qt.QWidget): self.correctionB = qt.QPushButton("Correct") self.abortB = qt.QPushButton("Abort") self.abortB.hide() - self.checkbox = qt.QCheckBox("Apply to whole dataset") - self.checkbox.setChecked(True) + self.checkbox = qt.QCheckBox("Apply only to selected value") + self.checkbox.setChecked(False) self.checkbox.hide() self.dxLE.setValidator(qt.QDoubleValidator()) diff --git a/examples/shift_correction.py b/examples/shift_correction.py index 9d01c4845a992065363a6d1656e0133b87309a3e..e5614afe736cee044c724e8731b56c0a79a23536 100644 --- a/examples/shift_correction.py +++ b/examples/shift_correction.py @@ -36,7 +36,8 @@ import sys import numpy from silx.gui import qt -from darfix.test.utils import createDataset +from darfix.test.utils import createRandomDataset +from darfix.core.dimension import POSITIONER_METADATA from darfix.gui.shiftCorrectionWidget import ShiftCorrectionWidget @@ -61,8 +62,9 @@ def exec_(): data = numpy.repeat(data, 10, axis=0) - dataset = createDataset(data=data) - w.setDataset(dataset) + dataset = createRandomDataset((100, 100), nb_data_files=10, header=True) + dataset.find_dimensions(POSITIONER_METADATA) + w.setDataset(dataset.reshape_data()) w.show() qapp.exec_()