Commit 67c7b9bd authored by payno
Browse files

Merge branch 'speed_up_hdf5_sinogram_load' into 'master'

Speed up hdf5 sinogram load

See merge request !36
parents 90894a21 e8698e96
Pipeline #43918 passed with stages
in 14 minutes and 24 seconds
import os
import h5py
from tomoscan.esrf.hdf5scan import HDF5TomoScan, ImageKey
import numpy
import time
# Synthetic acquisition shared by the whole script.
#
# 30 projections of 10 x 20 pixels filled with consecutive integers starting
# at 1000, acquired at 30 evenly spaced angles between 0 and 180.
proj_data = 1000 + numpy.arange(30 * 10 * 20).reshape(30, 10, 20)
proj_angle = numpy.linspace(0, 180, num=30)

# constant value of every dark / flat pixel
dark_value = 0.5
flat_value = 1

# one dark frame, registered at angle 0
dark_data = numpy.full((10, 20), dark_value, dtype=numpy.float64)
dark_angle = numpy.array([0])

# three flat frames, registered at angles 0, 90 and 180
flat_data_1 = numpy.full((10, 20), flat_value, dtype=numpy.float64)
flat_angle_1 = numpy.array([0])
flat_data_2 = numpy.full((10, 20), flat_value, dtype=numpy.float64)
flat_angle_2 = numpy.array([90])
flat_data_3 = numpy.full((10, 20), flat_value, dtype=numpy.float64)
flat_angle_3 = numpy.array([180])

# data dataset: frames stacked in acquisition order
#   0: dark | 1: flat | 2-16: projections | 17: flat | 18-32: projections | 33: flat
data = numpy.concatenate(
    (
        dark_data[numpy.newaxis],
        flat_data_1[numpy.newaxis],
        proj_data[:15],
        flat_data_2[numpy.newaxis],
        proj_data[15:],
        flat_data_3[numpy.newaxis],
    )
).astype(numpy.float64, copy=False)
def create_arange_dataset(file_path):
    """Write the synthetic acquisition defined at module level to an HDF5 file.

    Any existing file at ``file_path`` is replaced. The file contains a single
    ``entry0000`` group holding:

    * ``instrument/detector/data``: the module-level ``data`` stack
    * ``sample/rotation_angle``: one angle per frame, in the same order
    * ``instrument/detector/image_key`` and
      ``instrument/detector/image_key_control``: the type of each frame

    :param str file_path: path of the HDF5 file to (re)create
    """
    # one angle per frame, following the frame order of `data`.
    # The pieces are concatenated rather than assigned one by one: storing a
    # size-1 1D array in a scalar slot (`rotation_angle[0] = dark_angle`)
    # relies on an array-to-scalar conversion deprecated since NumPy 1.25.
    rotation_angle = numpy.concatenate(
        (
            dark_angle,
            flat_angle_1,
            proj_angle[:15],
            flat_angle_2,
            proj_angle[15:],
            flat_angle_3,
        )
    ).astype(numpy.float64)

    # image key / images keys: dark, flat, 15 projections, flat,
    # 15 projections, flat
    image_keys = numpy.array(
        [ImageKey.DARK_FIELD.value, ImageKey.FLAT_FIELD.value]
        + [ImageKey.PROJECTION.value] * 15
        + [ImageKey.FLAT_FIELD.value]
        + [ImageKey.PROJECTION.value] * 15
        + [ImageKey.FLAT_FIELD.value]
    )

    assert data.ndim == 3
    # mode "w" creates the file and truncates an existing one, so no explicit
    # removal of a previous file is needed
    with h5py.File(file_path, mode="w") as h5f:
        entry = h5f.require_group("entry0000")
        entry["instrument/detector/data"] = data
        entry["sample/rotation_angle"] = rotation_angle
        entry["instrument/detector/image_key"] = image_keys
        entry["instrument/detector/image_key_control"] = image_keys
# ---------------------------------------------------------------------------
# Benchmark and visual comparison of the sinogram implementations.
#
# 1. write the synthetic acquisition to "test.h5" and load it
# 2. time the reference implementation against the new `get_sinogram`
# 3. display the sinograms (old, new, raw and, when available, tomwer's one)
# 4. check the reference result against the analytically corrected sinogram
# ---------------------------------------------------------------------------
file_path = "test.h5"
create_arange_dataset(file_path)

scan = HDF5TomoScan(file_path, "entry0000")
# sanity check: the scan must expose the frames written in the file
assert len(scan.projections) == 30
assert len(scan.flats) == 3
assert len(scan.darks) == 1

# register the darks / flats under the index of their frame in the dataset
scan.set_normed_darks(
    {
        0: dark_data,
    }
)
scan.set_normed_flats(
    {
        1: flat_data_1,
        17: flat_data_2,
        33: flat_data_3,
    }
)
scan._flats_weights = scan._get_flats_weights()

t0 = time.time()
sinogram_old = scan._get_sinogram_ref_imp(line=5)
print("execution time of the old implementation: {}".format(time.time() - t0))

t0 = time.time()
sinogram_new = scan.get_sinogram(line=5)
print("execution time of the new implementation: {}".format(time.time() - t0))

# plot the sinogram
# (GUI imports are kept here so that the timing above does not require a
# display)
from silx.gui import qt
from silx.gui.plot import Plot2D

app = qt.QApplication([])

plot_old = Plot2D()
plot_old.addImage(sinogram_old)
plot_old.setWindowTitle("old sinogram")
plot_old.show()

plot_new = Plot2D()
plot_new.addImage(sinogram_new)
plot_new.setWindowTitle("new sinogram")
plot_new.show()

raw_sinogram = proj_data[:, 5, :]
plot_raw = Plot2D()
plot_raw.addImage(raw_sinogram)
plot_raw.setWindowTitle("raw sinogram")
plot_raw.show()

# TODO: get the one from tomwer to compare as well
# tomwer is optional: only a failed import is tolerated. Any other error
# (including KeyboardInterrupt) must not be silently swallowed.
try:
    from tomwer.core.scan.hdf5scan import HDF5TomoScan as TomwerHDF5TomoScan
except ImportError:
    print("tomwer is not available: skipping the comparison with tomwer")
else:
    scan_t = TomwerHDF5TomoScan(file_path, "entry0000")
    scan_t.set_normed_darks(
        {
            0: dark_data,
        }
    )
    scan_t.set_normed_flats(
        {
            1: flat_data_1,
            17: flat_data_2,
            33: flat_data_3,
        }
    )
    sinogram_tomwer = scan_t.get_sinogram(line=5)
    plot_tomwer = Plot2D()
    plot_tomwer.addImage(sinogram_tomwer)
    plot_tomwer.setWindowTitle("tomwer sinogram")
    plot_tomwer.show()

# TODO: check all lines of the sinogram
# darks and flats are constant, so the expected corrected sinogram is known
corrected = (raw_sinogram - dark_value) / (flat_value - dark_value)
numpy.testing.assert_array_equal(corrected, sinogram_old)

app.exec_()
...@@ -833,7 +833,7 @@ class HDF5TomoScan(TomoScanBase): ...@@ -833,7 +833,7 @@ class HDF5TomoScan(TomoScanBase):
url = DataUrl( url = DataUrl(
file_path=self.master_file, file_path=self.master_file,
data_slice=(i_frame), data_slice=(i_frame),
data_path=self.entry + "/instrument/detector/data", data_path=self.get_detector_data_path(),
scheme="silx", scheme="silx",
) )
...@@ -878,6 +878,78 @@ class HDF5TomoScan(TomoScanBase): ...@@ -878,6 +878,78 @@ class HDF5TomoScan(TomoScanBase):
else: else:
return None return None
def _get_sinogram_ref_imp(self, line, subsampling=1):
    """Compute a sinogram with the reference implementation.

    Delegates to ``TomoScanBase.get_sinogram``. Used by the unit tests to
    ensure that the result is the same as the one of ``get_sinogram``.

    :param line: line of the sinogram, forwarded as is
    :param subsampling: subsampling, forwarded as is
    :return: whatever ``TomoScanBase.get_sinogram`` returns
    """
    reference_implementation = TomoScanBase.get_sinogram
    return reference_implementation(self, line=line, subsampling=subsampling)
# NOTE(review): no docstring is written here on purpose: the `docstring`
# decorator presumably copies the one of `TomoScanBase.get_sinogram` - confirm.
#
# Build the flat field corrected sinogram of detector line `line`.
# The requested line of every frame is read from the detector dataset in a
# single HDF5 access, the frames which are not "pure" projections are
# dropped and the remaining lines are flat field corrected one by one.
#
# Returns a 2D array (one row per kept projection, subsampled with
# `subsampling`) or None when the scan has no projections.
# Raises ValueError if `line` is out of the scan or `subsampling` is lower
# than 1, TypeError if `subsampling` is not an int.
@docstring(TomoScanBase)
def get_sinogram(self, line, subsampling=1) -> numpy.ndarray:
    # valid lines are 0 .. dim_2 - 1: `line == dim_2` is already out of range
    if line < 0 or (
        self.tomo_n is not None and self.dim_2 is not None and line >= self.dim_2
    ):
        raise ValueError("requested line {} is not in the scan".format(line))
    if not isinstance(subsampling, int):
        raise TypeError("subsampling expected to be an int")
    if subsampling <= 0:
        raise ValueError("subsampling expected to be higher than or equal to 1")

    if self.projections is None:
        return None

    # get the z line: one read for all the frames of the dataset
    with HDF5File(self.master_file, mode="r") as h5f:
        raw_sinogram = h5f[self.get_detector_data_path()][:, line, :]
    assert raw_sinogram.ndim == 2

    ignored_projs = self.ignore_projections
    if ignored_projs is None:
        ignored_projs = []

    def is_pure_projection(frame: Frame):
        """True for a projection which is neither a control nor ignored"""
        return (
            frame.image_key == ImageKey.PROJECTION
            and not frame.is_control
            and frame.index not in ignored_projs
        )

    # assumes self.frames holds one frame per row of the detector dataset,
    # in the same order - TODO confirm
    frames = self.frames
    is_projection_array = numpy.array(
        [is_pure_projection(frame) for frame in frames], dtype=bool
    )
    proj_indexes = [
        frame.index for frame, keep in zip(frames, is_projection_array) if keep
    ]

    raw_sinogram = raw_sinogram[is_projection_array, :]
    assert len(raw_sinogram) == len(proj_indexes)
    assert raw_sinogram.ndim == 2

    # now apply flat field correction on each line
    res = []
    for z_frame_raw_sino, proj_index in zip(raw_sinogram, proj_indexes):
        assert z_frame_raw_sino.ndim == 1
        line_corrected = self.flat_field_correction(
            projs=(z_frame_raw_sino,),
            proj_indexes=[proj_index],
            line=line,
        )[0]
        assert isinstance(line_corrected, numpy.ndarray)
        assert line_corrected.ndim == 1
        res.append(line_corrected)

    sinogram = numpy.array(res)
    assert sinogram.ndim == 2
    # apply subsampling (could be sped up but not sure this is worth the
    # complexity it would add)
    return sinogram[::subsampling]
def get_detector_data_path(self) -> str:
    """Return the path of the detector data dataset inside this entry.

    :return: ``self.entry`` followed by "/instrument/detector/data"
    """
    detector_data_path = self.entry
    detector_data_path += "/instrument/detector/data"
    return detector_data_path
@property @property
def projections_compacted(self): def projections_compacted(self):
""" """
......
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
__authors__ = ["H. Payno"] __authors__ = ["H. Payno"]
__license__ = "MIT" __license__ = "MIT"
__date__ = "16/09/2019" __date__ = "26/03/2021"
import unittest import unittest
import shutil import shutil
...@@ -37,6 +37,7 @@ from tomoscan.esrf.hdf5scan import HDF5TomoScan, ImageKey, Frame ...@@ -37,6 +37,7 @@ from tomoscan.esrf.hdf5scan import HDF5TomoScan, ImageKey, Frame
from tomoscan.unitsystem import metricsystem from tomoscan.unitsystem import metricsystem
from silx.io.utils import get_data from silx.io.utils import get_data
import numpy import numpy
import h5py
class HDF5TestBaseClass(unittest.TestCase): class HDF5TestBaseClass(unittest.TestCase):
...@@ -285,7 +286,32 @@ class TestGetSinogram(HDF5TestBaseClass): ...@@ -285,7 +286,32 @@ class TestGetSinogram(HDF5TestBaseClass):
1520: numpy.random.random(20 * 20).reshape((20, 20)), 1520: numpy.random.random(20 * 20).reshape((20, 20)),
} }
) )
self.scan.set_normed_darks({0: numpy.random.random(20 * 20).reshape((20, 20))}) dark = numpy.random.random(20 * 20).reshape((20, 20))
self.scan.set_normed_darks({0: dark})
self.scan._flats_weights = self.scan._get_flats_weights()
self._raw_frame = []
for index, url in self.scan.projections.items():
self._raw_frame.append(get_data(url))
self._raw_frame = numpy.asarray(self._raw_frame)
assert self._raw_frame.ndim == 3
normed_frames = []
for proj_i, z_frame in enumerate(self._raw_frame):
normed_frames.append(
self.scan._frame_flat_field_correction(
data=z_frame,
dark=dark,
flat_weights=self.scan._flats_weights[proj_i]
if proj_i in self.scan._flats_weights
else None,
)
)
self._normed_volume = numpy.array(normed_frames)
assert self._normed_volume.ndim == 3
self._normed_sinogram_12 = self._normed_volume[:, 12, :]
assert self._normed_sinogram_12.ndim is 2
assert self._normed_sinogram_12.shape == (1500, 20)
def testGetSinogram1(self): def testGetSinogram1(self):
sinogram = self.scan.get_sinogram(line=12, subsampling=1) sinogram = self.scan.get_sinogram(line=12, subsampling=1)
...@@ -331,6 +357,116 @@ class TestIgnoredProjections(HDF5TestBaseClass): ...@@ -331,6 +357,116 @@ class TestIgnoredProjections(HDF5TestBaseClass):
) )
class TestGetSinogramLegacy(unittest.TestCase):
    """Check `HDF5TomoScan.get_sinogram` against the reference implementation.

    A small synthetic acquisition (1 dark, 3 flats, 30 projections) with
    constant darks and flats is written to a temporary HDF5 file, so the
    expected flat field corrected sinogram can be computed analytically.
    """

    def setUp(self) -> None:
        self.test_dir = tempfile.mkdtemp()

        # 30 projections of 10 x 20 pixels, filled with consecutive integers
        self.proj_data = numpy.arange(1000, 1000 + 10 * 20 * 30).reshape(30, 10, 20)
        self.proj_angle = numpy.linspace(0, 180, 30)

        # constant dark and flats
        self.dark_value = 0.5
        self.flat_value = 1
        self.dark_data = numpy.ones((10, 20)) * self.dark_value
        self.dark_angle = numpy.array([0])
        self.flat_data_1 = numpy.ones((10, 20)) * self.flat_value
        self.flat_angle_1 = numpy.array([0])
        self.flat_data_2 = numpy.ones((10, 20)) * self.flat_value
        self.flat_angle_2 = numpy.array([90])
        self.flat_data_3 = numpy.ones((10, 20)) * self.flat_value
        self.flat_angle_3 = numpy.array([180])

        # data dataset, frames in acquisition order:
        # dark, flat, 15 projections, flat, 15 projections, flat
        self.data = numpy.empty((34, 10, 20))
        self.data[0] = self.dark_data
        self.data[1] = self.flat_data_1
        self.data[2:17] = self.proj_data[:15]
        self.data[17] = self.flat_data_2
        self.data[18:33] = self.proj_data[15:]
        self.data[33] = self.flat_data_3

        self.file_path = os.path.join(self.test_dir, "test.h5")
        self.create_arange_dataset(self.file_path)

    def create_arange_dataset(self, file_path):
        """Write the synthetic acquisition to `file_path`.

        Any existing file is replaced. The file contains a single
        "entry0000" group with the detector data, one rotation angle per
        frame and the image keys describing the type of each frame.
        """
        # rotation angle: one per frame, in the frame order of `self.data`.
        # The pieces are concatenated rather than assigned one by one:
        # storing a size-1 1D array in a scalar slot relies on an
        # array-to-scalar conversion deprecated since NumPy 1.25.
        rotation_angle = numpy.concatenate(
            (
                self.dark_angle,
                self.flat_angle_1,
                self.proj_angle[:15],
                self.flat_angle_2,
                self.proj_angle[15:],
                self.flat_angle_3,
            )
        ).astype(numpy.float64)

        # image key / images keys
        image_keys = numpy.array(
            [ImageKey.DARK_FIELD.value, ImageKey.FLAT_FIELD.value]
            + [ImageKey.PROJECTION.value] * 15
            + [ImageKey.FLAT_FIELD.value]
            + [ImageKey.PROJECTION.value] * 15
            + [ImageKey.FLAT_FIELD.value]
        )

        # mode "w" creates the file and truncates an existing one
        with h5py.File(file_path, mode="w") as h5f:
            entry = h5f.require_group("entry0000")
            entry["instrument/detector/data"] = self.data
            entry["sample/rotation_angle"] = rotation_angle
            entry["instrument/detector/image_key"] = image_keys
            entry["instrument/detector/image_key_control"] = image_keys

    def tearDown(self) -> None:
        shutil.rmtree(self.test_dir)

    def testImplementations(self):
        """The new implementation, the reference one and the analytical
        result must all be equal"""
        scan = HDF5TomoScan(self.file_path, "entry0000")
        # unittest assertions: not stripped under `python -O` and they
        # report the mismatching values
        self.assertEqual(len(scan.projections), 30)
        self.assertEqual(len(scan.flats), 3)
        self.assertEqual(len(scan.darks), 1)

        # darks / flats are registered under their frame index
        scan.set_normed_darks(
            {
                0: self.dark_data,
            }
        )
        scan.set_normed_flats(
            {
                1: self.flat_data_1,
                17: self.flat_data_2,
                33: self.flat_data_3,
            }
        )
        scan._flats_weights = scan._get_flats_weights()

        sinogram_old = scan._get_sinogram_ref_imp(line=5)
        sinogram_new = scan.get_sinogram(line=5)

        # darks and flats are constant: the corrected sinogram is known
        raw_sinogram = self.proj_data[:, 5, :]
        corrected = (raw_sinogram - self.dark_value) / (
            self.flat_value - self.dark_value
        )
        numpy.testing.assert_array_equal(corrected, sinogram_new)
        numpy.testing.assert_array_equal(sinogram_old, sinogram_new)
def suite(): def suite():
test_suite = unittest.TestSuite() test_suite = unittest.TestSuite()
for ui in ( for ui in (
...@@ -338,6 +474,7 @@ def suite(): ...@@ -338,6 +474,7 @@ def suite():
TestFlatFieldCorrection, TestFlatFieldCorrection,
TestGetSinogram, TestGetSinogram,
TestIgnoredProjections, TestIgnoredProjections,
TestGetSinogramLegacy,
): ):
test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(ui)) test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(ui))
return test_suite return test_suite
......
...@@ -460,7 +460,7 @@ class TomoScanBase: ...@@ -460,7 +460,7 @@ class TomoScanBase:
dim1, dim2 = self.dim_1, self.dim_2 dim1, dim2 = self.dim_1, self.dim_2
y_dim = ceil(self.tomo_n / subsampling) y_dim = ceil(self.tomo_n / subsampling)
sinogram = numpy.empty((y_dim, dim1)) sinogram = numpy.empty((y_dim, dim1))
_logger.info( _logger.debug(
"compute sinogram for line {} of {} (subsampling: {})".format( "compute sinogram for line {} of {} (subsampling: {})".format(
line, self.path, subsampling line, self.path, subsampling
) )
...@@ -473,12 +473,12 @@ class TomoScanBase: ...@@ -473,12 +473,12 @@ class TomoScanBase:
projections = self.projections projections = self.projections
o_keys = list(projections.keys()) o_keys = list(projections.keys())
o_keys.sort() o_keys.sort()
for i_proj, proj_key in enumerate(o_keys): for i_proj, proj_index in enumerate(o_keys):
if i_proj % subsampling == 0: if i_proj % subsampling == 0:
proj_url = projections[proj_key] proj_url = projections[proj_index]
proj = silx.io.utils.get_data(proj_url) proj = silx.io.utils.get_data(proj_url)
proj = self.flat_field_correction( proj = self.flat_field_correction(
projs=[proj], proj_indexes=[i_proj] projs=[proj], proj_indexes=[proj_index]
)[0] )[0]
sinogram[i_proj // subsampling] = proj[line] sinogram[i_proj // subsampling] = proj[line]
advancement.increaseAdvancement(1) advancement.increaseAdvancement(1)
...@@ -489,9 +489,9 @@ class TomoScanBase: ...@@ -489,9 +489,9 @@ class TomoScanBase:
def _frame_flat_field_correction( def _frame_flat_field_correction(
self, self,
data: typing.Union[numpy.ndarray, DataUrl], data: typing.Union[numpy.ndarray, DataUrl],
index_proj: typing.Union[int, None],
dark, dark,
flat_weights: dict, flat_weights: dict,
line: Union[None, int] = None,
): ):
""" """
compute flat field correction for a provided data from is index compute flat field correction for a provided data from is index
...@@ -501,7 +501,6 @@ class TomoScanBase: ...@@ -501,7 +501,6 @@ class TomoScanBase:
if isinstance(data, DataUrl): if isinstance(data, DataUrl):
data = get_data(data) data = get_data(data)
can_process = True can_process = True
if flat_weights in (None, {}): if flat_weights in (None, {}):
if self._notify_ffc_rsc_missing: if self._notify_ffc_rsc_missing:
_logger.error("cannot make flat field correction, flat not found") _logger.error("cannot make flat field correction, flat not found")
...@@ -526,7 +525,10 @@ class TomoScanBase: ...@@ -526,7 +525,10 @@ class TomoScanBase:
if can_process is False: if can_process is False:
self._notify_ffc_rsc_missing = False self._notify_ffc_rsc_missing = False
return data if line is None:
return data
else:
return data[line]
if len(flat_weights) == 1: if len(flat_weights) == 1:
flat_value = self.normed_flats[list(flat_weights.keys())[0]] flat_value = self.normed_flats[list(flat_weights.keys())[0]]
...@@ -534,7 +536,6 @@ class TomoScanBase: ...@@ -534,7 +536,6 @@ class TomoScanBase:
flat_keys = list(flat_weights.keys()) flat_keys = list(flat_weights.keys())
flat_1 = flat_keys[0] flat_1 = flat_keys[0]
flat_2 = flat_keys[1] flat_2 = flat_keys[1]
flat_value = ( flat_value = (
self.normed_flats[flat_1] * flat_weights[flat_1] self.normed_flats[flat_1] * flat_weights[flat_1]
+ self.normed_flats[flat_2] * flat_weights[flat_2] + self.normed_flats[flat_2] * flat_weights[flat_2]
...@@ -544,13 +545,22 @@ class TomoScanBase: ...@@ -544,13 +545,22 @@ class TomoScanBase:
"no more than two flats are expected and" "no more than two flats are expected and"
"at least one shuold be provided" "at least one shuold be provided"
) )
if line is None:
div = flat_value - dark assert data.ndim == 2
div[div == 0] = 1 div = flat_value - dark
return (data - dark) / div div[div == 0] = 1.0
return (data - dark) / div
else:
assert data.ndim == 1
div = flat_value[line] - dark[line]
div[div == 0] = 1
return (data - dark[line]) / div
def flat_field_correction( def flat_field_correction(
self, projs: typing.Iterable, proj_indexes: typing.Iterable self,
projs: typing.Iterable,
proj_indexes: typing.Iterable,
line: Union[None, int] = None,
): ):
"""Apply flat field correction on the given data """Apply flat field correction on the given data
...@@ -561,11 +571,15 @@ class TomoScanBase: ...@@ -561,11 +571,15 @@ class TomoScanBase:
be int or None. If None then the be int or None. If None then the
index take will be the one in the index take will be the one in the
middle of the flats taken. middle of the flats taken.
:param line: index of the line on which to apply the flat field correction.
If not provided, the correction is applied on the entire frame
:type line: None or int
:return: corrected data: list of numpy array :return: corrected data: list of numpy array