Commit 67c7b9bd authored by payno's avatar payno
Browse files

Merge branch 'speed_up_hdf5_sinogram_load' into 'master'

Speed up hdf5 sinogram load

See merge request !36
parents 90894a21 e8698e96
Pipeline #43918 passed with stages
in 14 minutes and 24 seconds
import os
import h5py
from tomoscan.esrf.hdf5scan import HDF5TomoScan, ImageKey
import numpy
import time
# Synthetic acquisition used by the benchmark: 30 projections of 10 x 20
# pixels whose values form a simple ramp starting at 1000.
proj_data = numpy.arange(1000, 1000 + 10 * 20 * 30).reshape(30, 10, 20)
proj_angle = numpy.linspace(0, 180, 30)

# Constant dark / flat levels.
dark_value = 0.5
flat_value = 1

# One dark frame followed by three flat frames taken at 0, 90 and 180.
# Each flat gets its own array (no aliasing between them).
dark_data = numpy.full(proj_data.shape[1:], dark_value, dtype=numpy.float64)
dark_angle = numpy.array([0])
flat_data_1 = numpy.full(proj_data.shape[1:], flat_value, dtype=numpy.float64)
flat_angle_1 = numpy.array([0])
flat_data_2 = numpy.full(proj_data.shape[1:], flat_value, dtype=numpy.float64)
flat_angle_2 = numpy.array([90])
flat_data_3 = numpy.full(proj_data.shape[1:], flat_value, dtype=numpy.float64)
flat_angle_3 = numpy.array([180])

# data dataset: frames stacked in acquisition order
# dark, flat, 15 projections, flat, 15 projections, flat (34 frames)
data = numpy.concatenate(
    (
        dark_data[numpy.newaxis],
        flat_data_1[numpy.newaxis],
        proj_data[:15],
        flat_data_2[numpy.newaxis],
        proj_data[15:],
        flat_data_3[numpy.newaxis],
    )
)
def create_arange_dataset(file_path):
    """Write the synthetic acquisition to ``file_path`` as an HDF5 file.

    The file holds a single ``entry0000`` group with the detector frames,
    one rotation angle per frame and the image keys (written to both
    ``image_key`` and ``image_key_control``). Frame order is: dark, flat,
    15 projections, flat, 15 projections, flat.

    :param file_path: path of the file to create. An existing file at this
                      location is overwritten.
    :type file_path: str
    """
    assert data.ndim == 3

    # rotation angle: one value per frame, in acquisition order.
    # The dark / flat angles are one-element arrays, so they are concatenated
    # instead of being assigned to a scalar slot (array -> scalar conversion
    # is deprecated in recent NumPy versions).
    rotation_angle = numpy.concatenate(
        (
            dark_angle,
            flat_angle_1,
            proj_angle[:15],
            flat_angle_2,
            proj_angle[15:],
            flat_angle_3,
        )
    ).astype(numpy.float64)

    # image key / images keys
    image_keys = numpy.array(
        [ImageKey.DARK_FIELD.value]
        + [ImageKey.FLAT_FIELD.value]
        + [ImageKey.PROJECTION.value] * 15
        + [ImageKey.FLAT_FIELD.value]
        + [ImageKey.PROJECTION.value] * 15
        + [ImageKey.FLAT_FIELD.value]
    )

    # mode "w" creates the file and truncates any previous content, which
    # replaces the former exists / remove / append sequence.
    with h5py.File(file_path, mode="w") as h5f:
        entry = h5f.require_group("entry0000")
        entry["instrument/detector/data"] = data
        entry["sample/rotation_angle"] = rotation_angle
        entry["instrument/detector/image_key"] = image_keys
        entry["instrument/detector/image_key_control"] = image_keys
# Build the test file and open it as a scan.
file_path = "test.h5"
create_arange_dataset(file_path)
scan = HDF5TomoScan(file_path, "entry0000")
assert len(scan.projections) == 30
assert len(scan.flats) == 3
assert len(scan.darks) == 1

# Provide the reduced darks / flats, keyed by frame index in the dataset.
scan.set_normed_darks(
    {
        0: dark_data,
    }
)
scan.set_normed_flats(
    {
        1: flat_data_1,
        17: flat_data_2,
        33: flat_data_3,
    }
)
scan._flats_weights = scan._get_flats_weights()

# Time the reference implementation against the new one on the same line.
t0 = time.time()
sinogram_old = scan._get_sinogram_ref_imp(line=5)
print("execution time of the old implementation: {}".format(time.time() - t0))
t0 = time.time()
sinogram_new = scan.get_sinogram(line=5)
print("execution time of the new implementation: {}".format(time.time() - t0))

# plot the sinogram
from silx.gui import qt
from silx.gui.plot import Plot2D

app = qt.QApplication([])
plot_old = Plot2D()
plot_old.addImage(sinogram_old)
plot_old.setWindowTitle("old sinogram")
plot_old.show()
plot_new = Plot2D()
plot_new.addImage(sinogram_new)
plot_new.setWindowTitle("new sinogram")
plot_new.show()
raw_sinogram = proj_data[:, 5, :]
plot_raw = Plot2D()
plot_raw.addImage(raw_sinogram)
plot_raw.setWindowTitle("raw sinogram")
plot_raw.show()

# TODO: get the one from tomwer to compare as well
# tomwer is optional: only a failed import is tolerated here, any other
# error must surface instead of being silently swallowed.
try:
    from tomwer.core.scan.hdf5scan import HDF5TomoScan as TomwerHDF5TomoScan
except ImportError:
    pass
else:
    scan_t = TomwerHDF5TomoScan(file_path, "entry0000")
    scan_t.set_normed_darks(
        {
            0: dark_data,
        }
    )
    scan_t.set_normed_flats(
        {
            1: flat_data_1,
            17: flat_data_2,
            33: flat_data_3,
        }
    )
    sinogram_tomwer = scan_t.get_sinogram(line=5)
    plot_tomwer = Plot2D()
    plot_tomwer.addImage(sinogram_tomwer)
    plot_tomwer.setWindowTitle("tomwer sinogram")
    plot_tomwer.show()

# TODO: check all lines of the sinogram
# Expected result of the flat field correction with constant dark / flat.
corrected = (raw_sinogram - dark_value) / (flat_value - dark_value)
numpy.testing.assert_array_equal(corrected, sinogram_old)
# the new implementation is the one being benchmarked: check it as well
numpy.testing.assert_array_equal(corrected, sinogram_new)
app.exec_()
......@@ -833,7 +833,7 @@ class HDF5TomoScan(TomoScanBase):
url = DataUrl(
file_path=self.master_file,
data_slice=(i_frame),
data_path=self.entry + "/instrument/detector/data",
data_path=self.get_detector_data_path(),
scheme="silx",
)
......@@ -878,6 +878,78 @@ class HDF5TomoScan(TomoScanBase):
else:
return None
def _get_sinogram_ref_imp(self, line, subsampling=1):
    """Compute a sinogram with the reference implementation.

    Delegates to ``TomoScanBase.get_sinogram``. Used by the unit tests to
    ensure the result is the same as the one of ``get_sinogram``.

    :param line: index of the line to extract
    :param subsampling: subsampling forwarded to the reference implementation
    :return: whatever ``TomoScanBase.get_sinogram`` returns
    """
    reference_impl = TomoScanBase.get_sinogram
    return reference_impl(self, line=line, subsampling=subsampling)
@docstring(TomoScanBase)
def get_sinogram(self, line, subsampling=1) -> numpy.ndarray:
    # NOTE(review): the docstring is presumably provided by the `docstring`
    # decorator from TomoScanBase - confirm. Comments are used here for
    # that reason.
    #
    # Build the sinogram of `line` by reading that detector line for every
    # frame in a single HDF5 access, keeping the pure projections only and
    # applying the flat field correction line by line.
    #
    # Raises ValueError if `line` is outside of the frame or if
    # `subsampling` is lower than 1, TypeError if `subsampling` is not an
    # int. Returns None when the scan has no projections.

    # valid line indexes are [0, dim_2 - 1]: dim_2 itself is out of bounds
    if line < 0 or (
        self.tomo_n is not None and self.dim_2 is not None and line >= self.dim_2
    ):
        raise ValueError("requested line {} is not in the scan".format(line))
    if not isinstance(subsampling, int):
        raise TypeError("subsampling expected to be an int")
    if subsampling <= 0:
        raise ValueError("subsampling expected to be higher than 0")

    if self.projections is None:
        return None

    # get the z line for all the frames (darks and flats included)
    with HDF5File(self.master_file, mode="r") as h5f:
        raw_sinogram = h5f[self.get_detector_data_path()][:, line, :]
    assert raw_sinogram.ndim == 2

    ignored_projs = []
    if self.ignore_projections is not None:
        ignored_projs = self.ignore_projections

    def is_pure_projection(frame: Frame):
        return (
            frame.image_key == ImageKey.PROJECTION
            and not frame.is_control
            and frame.index not in ignored_projs
        )

    frames = self.frames
    # explicit bool dtype so the mask stays a valid boolean index even when
    # there is no frame at all
    is_projection_array = numpy.array(
        [is_pure_projection(frame) for frame in frames], dtype=bool
    )
    proj_indexes = [
        frame.index
        for frame, is_projection in zip(frames, is_projection_array)
        if is_projection
    ]

    # drop the lines coming from darks, flats, control and ignored frames
    raw_sinogram = raw_sinogram[is_projection_array, :]
    assert len(raw_sinogram) == len(proj_indexes)
    assert raw_sinogram.ndim == 2

    # now apply flat field correction on each line
    res = []
    for z_frame_raw_sino, proj_index in zip(raw_sinogram, proj_indexes):
        assert z_frame_raw_sino.ndim == 1
        line_corrected = self.flat_field_correction(
            projs=(z_frame_raw_sino,),
            proj_indexes=[
                proj_index,
            ],
            line=line,
        )[0]
        assert isinstance(line_corrected, numpy.ndarray)
        assert line_corrected.ndim == 1
        res.append(line_corrected)
    sinogram = numpy.array(res)
    assert sinogram.ndim == 2
    # apply subsampling (could be sped up, but the gain is probably not
    # worth the complexity it would add)
    return sinogram[::subsampling]
def get_detector_data_path(self) -> str:
    """Return the path of the detector data dataset.

    :return: ``self.entry`` followed by "/instrument/detector/data"
    :rtype: str
    """
    return "".join((self.entry, "/instrument/detector/data"))
@property
def projections_compacted(self):
"""
......
......@@ -26,7 +26,7 @@
__authors__ = ["H. Payno"]
__license__ = "MIT"
__date__ = "16/09/2019"
__date__ = "26/03/2021"
import unittest
import shutil
......@@ -37,6 +37,7 @@ from tomoscan.esrf.hdf5scan import HDF5TomoScan, ImageKey, Frame
from tomoscan.unitsystem import metricsystem
from silx.io.utils import get_data
import numpy
import h5py
class HDF5TestBaseClass(unittest.TestCase):
......@@ -285,7 +286,32 @@ class TestGetSinogram(HDF5TestBaseClass):
1520: numpy.random.random(20 * 20).reshape((20, 20)),
}
)
self.scan.set_normed_darks({0: numpy.random.random(20 * 20).reshape((20, 20))})
dark = numpy.random.random(20 * 20).reshape((20, 20))
self.scan.set_normed_darks({0: dark})
self.scan._flats_weights = self.scan._get_flats_weights()
self._raw_frame = []
for index, url in self.scan.projections.items():
self._raw_frame.append(get_data(url))
self._raw_frame = numpy.asarray(self._raw_frame)
assert self._raw_frame.ndim == 3
normed_frames = []
for proj_i, z_frame in enumerate(self._raw_frame):
normed_frames.append(
self.scan._frame_flat_field_correction(
data=z_frame,
dark=dark,
flat_weights=self.scan._flats_weights[proj_i]
if proj_i in self.scan._flats_weights
else None,
)
)
self._normed_volume = numpy.array(normed_frames)
assert self._normed_volume.ndim == 3
self._normed_sinogram_12 = self._normed_volume[:, 12, :]
assert self._normed_sinogram_12.ndim is 2
assert self._normed_sinogram_12.shape == (1500, 20)
def testGetSinogram1(self):
sinogram = self.scan.get_sinogram(line=12, subsampling=1)
......@@ -331,6 +357,116 @@ class TestIgnoredProjections(HDF5TestBaseClass):
)
class TestGetSinogramLegacy(unittest.TestCase):
    """Check that the reference implementation of the sinogram computation
    and ``HDF5TomoScan.get_sinogram`` return the same result.

    The dataset is synthetic: 30 projections of 10 x 20 pixels, one dark
    frame and three flat frames, all with constant dark / flat values so the
    expected flat field correction can be computed directly.
    """

    def setUp(self) -> None:
        self.test_dir = tempfile.mkdtemp()
        self.proj_data = numpy.arange(1000, 1000 + 10 * 20 * 30).reshape(30, 10, 20)
        self.proj_angle = numpy.linspace(0, 180, 30)
        self.dark_value = 0.5
        self.flat_value = 1
        self.dark_data = numpy.ones((10, 20)) * self.dark_value
        self.dark_angle = numpy.array([0])
        self.flat_data_1 = numpy.ones((10, 20)) * self.flat_value
        self.flat_angle_1 = numpy.array([0])
        self.flat_data_2 = numpy.ones((10, 20)) * self.flat_value
        self.flat_angle_2 = numpy.array([90])
        self.flat_data_3 = numpy.ones((10, 20)) * self.flat_value
        self.flat_angle_3 = numpy.array([180])

        # data dataset: dark, flat, 15 projections, flat, 15 projections, flat
        self.data = numpy.empty((34, 10, 20))
        self.data[0] = self.dark_data
        self.data[1] = self.flat_data_1
        self.data[2:17] = self.proj_data[:15]
        self.data[17] = self.flat_data_2
        self.data[18:33] = self.proj_data[15:]
        self.data[33] = self.flat_data_3

        self.file_path = os.path.join(self.test_dir, "test.h5")
        self.create_arange_dataset(self.file_path)

    def create_arange_dataset(self, file_path):
        """Write the synthetic acquisition to ``file_path`` (overwritten if
        it already exists) under the ``entry0000`` group.
        """
        # rotation angle: one value per frame, in acquisition order.
        # The dark / flat angles are one-element arrays, so they are
        # concatenated instead of being assigned to a scalar slot
        # (array -> scalar conversion is deprecated in recent NumPy versions).
        rotation_angle = numpy.concatenate(
            (
                self.dark_angle,
                self.flat_angle_1,
                self.proj_angle[:15],
                self.flat_angle_2,
                self.proj_angle[15:],
                self.flat_angle_3,
            )
        ).astype(numpy.float64)

        # image key / images keys
        image_keys = numpy.array(
            [ImageKey.DARK_FIELD.value]
            + [ImageKey.FLAT_FIELD.value]
            + [ImageKey.PROJECTION.value] * 15
            + [ImageKey.FLAT_FIELD.value]
            + [ImageKey.PROJECTION.value] * 15
            + [ImageKey.FLAT_FIELD.value]
        )

        # mode "w" creates the file and truncates any previous content
        with h5py.File(file_path, mode="w") as h5f:
            entry = h5f.require_group("entry0000")
            entry["instrument/detector/data"] = self.data
            entry["sample/rotation_angle"] = rotation_angle
            entry["instrument/detector/image_key"] = image_keys
            entry["instrument/detector/image_key_control"] = image_keys

    def tearDown(self) -> None:
        shutil.rmtree(self.test_dir)

    def testImplementations(self):
        """Both implementations must return the expected corrected sinogram."""
        scan = HDF5TomoScan(self.file_path, "entry0000")
        # unittest assertions: not stripped under -O and report both values
        self.assertEqual(len(scan.projections), 30)
        self.assertEqual(len(scan.flats), 3)
        self.assertEqual(len(scan.darks), 1)
        scan.set_normed_darks(
            {
                0: self.dark_data,
            }
        )
        scan.set_normed_flats(
            {
                1: self.flat_data_1,
                17: self.flat_data_2,
                33: self.flat_data_3,
            }
        )
        scan._flats_weights = scan._get_flats_weights()
        sinogram_old = scan._get_sinogram_ref_imp(line=5)
        sinogram_new = scan.get_sinogram(line=5)
        raw_sinogram = self.proj_data[:, 5, :]
        # expected flat field correction with constant dark and flat
        corrected = (raw_sinogram - self.dark_value) / (
            self.flat_value - self.dark_value
        )
        numpy.testing.assert_array_equal(corrected, sinogram_new)
        numpy.testing.assert_array_equal(sinogram_old, sinogram_new)
def suite():
test_suite = unittest.TestSuite()
for ui in (
......@@ -338,6 +474,7 @@ def suite():
TestFlatFieldCorrection,
TestGetSinogram,
TestIgnoredProjections,
TestGetSinogramLegacy,
):
test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(ui))
return test_suite
......
......@@ -460,7 +460,7 @@ class TomoScanBase:
dim1, dim2 = self.dim_1, self.dim_2
y_dim = ceil(self.tomo_n / subsampling)
sinogram = numpy.empty((y_dim, dim1))
_logger.info(
_logger.debug(
"compute sinogram for line {} of {} (subsampling: {})".format(
line, self.path, subsampling
)
......@@ -473,12 +473,12 @@ class TomoScanBase:
projections = self.projections
o_keys = list(projections.keys())
o_keys.sort()
for i_proj, proj_key in enumerate(o_keys):
for i_proj, proj_index in enumerate(o_keys):
if i_proj % subsampling == 0:
proj_url = projections[proj_key]
proj_url = projections[proj_index]
proj = silx.io.utils.get_data(proj_url)
proj = self.flat_field_correction(
projs=[proj], proj_indexes=[i_proj]
projs=[proj], proj_indexes=[proj_index]
)[0]
sinogram[i_proj // subsampling] = proj[line]
advancement.increaseAdvancement(1)
......@@ -489,9 +489,9 @@ class TomoScanBase:
def _frame_flat_field_correction(
self,
data: typing.Union[numpy.ndarray, DataUrl],
index_proj: typing.Union[int, None],
dark,
flat_weights: dict,
line: Union[None, int] = None,
):
"""
compute flat field correction for a provided data from is index
......@@ -501,7 +501,6 @@ class TomoScanBase:
if isinstance(data, DataUrl):
data = get_data(data)
can_process = True
if flat_weights in (None, {}):
if self._notify_ffc_rsc_missing:
_logger.error("cannot make flat field correction, flat not found")
......@@ -526,7 +525,10 @@ class TomoScanBase:
if can_process is False:
self._notify_ffc_rsc_missing = False
return data
if line is None:
return data
else:
return data[line]
if len(flat_weights) == 1:
flat_value = self.normed_flats[list(flat_weights.keys())[0]]
......@@ -534,7 +536,6 @@ class TomoScanBase:
flat_keys = list(flat_weights.keys())
flat_1 = flat_keys[0]
flat_2 = flat_keys[1]
flat_value = (
self.normed_flats[flat_1] * flat_weights[flat_1]
+ self.normed_flats[flat_2] * flat_weights[flat_2]
......@@ -544,13 +545,22 @@ class TomoScanBase:
"no more than two flats are expected and"
"at least one shuold be provided"
)
div = flat_value - dark
div[div == 0] = 1
return (data - dark) / div
if line is None:
assert data.ndim == 2
div = flat_value - dark
div[div == 0] = 1.0
return (data - dark) / div
else:
assert data.ndim == 1
div = flat_value[line] - dark[line]
div[div == 0] = 1
return (data - dark[line]) / div
def flat_field_correction(
self, projs: typing.Iterable, proj_indexes: typing.Iterable
self,
projs: typing.Iterable,
proj_indexes: typing.Iterable,
line: Union[None, int] = None,
):
"""Apply flat field correction on the given data
......@@ -561,11 +571,15 @@ class TomoScanBase:
be int or None. If None then the
index take will be the one in the
middle of the flats taken.
:param line: index of the line to apply flat filed. If not provided
consider we want to apply flat filed on the entire frame
:type line: None or int
:return: corrected data: list of numpy array
:rtype: list
"""
assert isinstance(projs, typing.Iterable)
assert isinstance(proj_indexes, typing.Iterable)
assert isinstance(line, (type(None), int))
def has_missing_keys():
if proj_indexes is None:
......@@ -575,11 +589,34 @@ class TomoScanBase:
return False
return True
def return_without_correction():
def load_data(proj):
if isinstance(proj, DataUrl):
return get_data(proj)
else:
return proj
if line is not None:
res = [
load_data(proj)[line] if isinstance(proj, DataUrl) else proj
for proj in projs
]
else:
res = [
load_data(proj) if isinstance(proj, DataUrl) else proj
for proj in projs
]
return res
if self._flats_weights in (None, {}) or has_missing_keys():
self._flats_weights = self._get_flats_weights()
if self._flats_weights in (None, {}):
_logger.error("Unable to compute flat weights")
if self._notify_ffc_rsc_missing:
_logger.error("Unable to compute flat weights")
self._notify_ffc_rsc_missing = False
return return_without_correction()
darks = self._normed_darks
if darks is not None and len(darks) > 0:
......@@ -591,30 +628,30 @@ class TomoScanBase:
if dark is None:
if self._notify_ffc_rsc_missing:
_logger.error("cannot make flat field correction, dark not found")
return [
get_data(proj) if isinstance(proj, DataUrl) else proj
for proj in projs
]
self._notify_ffc_rsc_missing = False
return return_without_correction()
if dark is not None and dark.ndim != 2:
_logger.error(
"cannot make flat field correction, dark should be of " "dimension 2"
)
return [
get_data(proj) if isinstance(proj, DataUrl) else proj for proj in projs
if self._notify_ffc_rsc_missing:
_logger.error(
"cannot make flat field correction, dark should be of "
"dimension 2"
)
self._notify_ffc_rsc_missing = False
return return_without_correction()
return numpy.array(
[
self._frame_flat_field_correction(
data=frame,
dark=dark,
flat_weights=