Commit f6f7b961 authored by Julia Garriga Ferrer's avatar Julia Garriga Ferrer
Browse files

[core][dataset][shift] Fix typo when applying shift on dimension

parent 7c5f9793
...@@ -285,6 +285,8 @@ class Dataset(): ...@@ -285,6 +285,8 @@ class Dataset():
return data.flatten() return data.flatten()
else: else:
data = self.data.flatten() data = self.data.flatten()
if return_indices:
return (data, numpy.arange(len(data))) if indices is None else (data[indices], indices)
return data if indices is None else data[indices] return data if indices is None else data[indices]
@property @property
...@@ -767,6 +769,8 @@ class Dataset(): ...@@ -767,6 +769,8 @@ class Dataset():
dataset = self dataset = self
for value in range(self.dims.get(dimension[0]).size): for value in range(self.dims.get(dimension[0]).size):
data, rindices = self.get_data(indices=indices, dimension=[dimension[0], value],
return_indices=True)
frames = numpy.arange(self.get_data(indices=indices, frames = numpy.arange(self.get_data(indices=indices,
dimension=[dimension[0], value]).shape[0]) dimension=[dimension[0], value]).shape[0])
dataset = dataset.apply_shift(numpy.outer(shift[value], frames), [dimension[0], value], dataset = dataset.apply_shift(numpy.outer(shift[value], frames), [dimension[0], value],
...@@ -802,7 +806,7 @@ class Dataset(): ...@@ -802,7 +806,7 @@ class Dataset():
if not os.path.isdir(_dir): if not os.path.isdir(_dir):
os.mkdir(_dir) os.mkdir(_dir)
data = self.get_data(indices, dimension) data, rindices = self.get_data(indices, dimension, return_indices=True)
self._lock.acquire() self._lock.acquire()
self.operations_state[Operation.SHIFT] = 1 self.operations_state[Operation.SHIFT] = 1
self._lock.release() self._lock.release()
...@@ -812,7 +816,7 @@ class Dataset(): ...@@ -812,7 +816,7 @@ class Dataset():
_file.create_dataset("update_dataset", data=_file["dataset"]) _file.create_dataset("update_dataset", data=_file["dataset"])
dataset_name = "update_dataset" dataset_name = "update_dataset"
else: else:
_file.create_dataset("dataset", data.shape, dtype=data.dtype) _file.create_dataset("dataset", self.get_data().shape, dtype=self.data.dtype)
io_utils.advancement_display(0, len(data), "Applying shift") io_utils.advancement_display(0, len(data), "Applying shift")
if dimension is not None: if dimension is not None:
...@@ -822,16 +826,16 @@ class Dataset(): ...@@ -822,16 +826,16 @@ class Dataset():
dimension[0] = [dimension[0]] dimension[0] = [dimension[0]]
dimension[1] = [dimension[1]] dimension[1] = [dimension[1]]
urls = [] urls = []
for i in range(len(data)): for i, idx in enumerate(rindices):
if not self.operations_state[Operation.SHIFT]: if not self.operations_state[Operation.SHIFT]:
del _file["update_dataset"] del _file["update_dataset"]
return return
img = apply_shift(data[i], shift[:, i], shift_approach) img = apply_shift(data[i], shift[:, i], shift_approach)
if shift[:, i].all() > 1: if shift[:, i].all() > 1:
shift_approach = "linear" shift_approach = "linear"
_file[dataset_name][i] = img _file[dataset_name][idx] = img
urls.append(DataUrl(file_path=_dir + '/data.hdf5', data_path="/dataset", data_slice=i, scheme='silx')) urls.append(DataUrl(file_path=_dir + '/data.hdf5', data_path="/dataset", data_slice=idx, scheme='silx'))
io_utils.advancement_display(i + 1, len(data), "Applying shift") io_utils.advancement_display(i + 1, len(rindices), "Applying shift")
# Replace specific urls that correspond to the modified data # Replace specific urls that correspond to the modified data
new_urls = numpy.array(self.data.urls, dtype=object) new_urls = numpy.array(self.data.urls, dtype=object)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment