Commit b9bad08c authored by payno's avatar payno
Browse files

[scanbase] rework flat_field_normalization

parent 2bf4d2e8
......@@ -93,7 +93,7 @@ class TomoScanBase:
first read of data or metadata"""
self._notify_ffc_rsc_missing = True
self._alignment_projections = None
self._flats_weight = None
self._flats_weights = None
@property
def normed_darks(self):
......@@ -446,15 +446,12 @@ class TomoScanBase:
else:
return None
def _flat_field_correction(
def _frame_flat_field_correction(
self,
data: typing.Union[numpy.ndarray, DataUrl],
index_proj: typing.Union[int, None],
dark,
flat1,
flat2,
index_flat1: int,
index_flat2: int,
flat_weights: dict,
):
"""
compute flat field correction for a provided data from is index
......@@ -465,57 +462,40 @@ class TomoScanBase:
data = get_data(data)
can_process = True
if dark is None:
if self._notify_ffc_rsc_missing:
_logger.error("cannot make flat field correction, dark not found")
can_process = False
if dark is not None and dark.ndim != 2:
_logger.error(
"cannot make flat field correction, dark should be of " "dimension 2"
)
can_process = False
if flat1 is None:
if flat_weights in (None, {}):
if self._notify_ffc_rsc_missing:
_logger.error("cannot make flat field correction, flat not found")
can_process = False
else:
if flat1.ndim != 2:
_logger.error(
"cannot make flat field correction, flat should be of "
"dimension 2"
)
can_process = False
if flat2 is not None and flat1.shape != flat2.shape:
_logger.error("the tow flats provided have different shapes.")
can_process = False
if dark is not None and flat1 is not None and dark.shape != flat1.shape:
_logger.error("Given dark and flat have incoherent dimension")
can_process = False
if dark is not None and data.shape != dark.shape:
_logger.error(
"Image has invalid shape. Cannot apply flat field" "correction it"
)
can_process = False
for flat_index, _ in flat_weights.items():
if flat_index not in self.normed_flats:
_logger.error(
"flat {} has been removed, unable to apply flat field"
"".format(flat_index)
)
can_process = False
elif self.normed_flats is not None and self.normed_flats[flat_index].ndim != 2:
_logger.error(
"cannot make flat field correction, flat should be of "
"dimension 2"
)
can_process = False
if can_process is False:
self._notify_ffc_rsc_missing = False
return data
if flat2 is None:
flat_value = flat1
if len(flat_weights) == 1:
flat_value = self.normed_flats[list(flat_weights.keys())[0]]
elif len(flat_weights) == 2:
flat_keys = list(flat_weights.keys())
flat_1 = flat_keys[0]
flat_2 = flat_keys[1]
flat_value = self.normed_flats[flat_1] * flat_weights[flat_1] + self.normed_flats[flat_2] * flat_weights[flat_2]
else:
# compute weight and clip it if necessary
if index_proj is None:
w = 0.5
else:
w = (index_proj - index_flat1) / (index_flat2 - index_flat1)
w = min(1, w)
w = max(0, w)
flat_value = flat1 * w + flat2 * (1 - w)
raise ValueError('no more than two flats are expected and'
'at least one shuold be provided')
div = flat_value - dark
div[div == 0] = 1
......@@ -538,31 +518,45 @@ class TomoScanBase:
"""
assert isinstance(projs, typing.Iterable)
assert isinstance(proj_indexes, typing.Iterable)
flats = self.normed_flats
flat1 = flat2 = None
index_flat1 = index_flat2 = None
if flats is not None:
flat_indexes = sorted(list(flats.keys()))
if len(flats) > 0:
index_flat1 = flat_indexes[0]
flat1 = flats[index_flat1]
if len(flats) > 1:
index_flat2 = flat_indexes[-1]
flat2 = flats[index_flat2]
darks = self.normed_darks
dark = None
def has_missing_keys():
if proj_indexes is None:
return False
for proj_index in proj_indexes:
if proj_index not in self._flats_weights:
return False
return True
if self._flats_weights in (None, {}) or has_missing_keys():
self._flats_weights = self._get_flats_weights()
if self._flats_weights in (None, {}):
_logger.error('Unable to compute flat weights')
darks = self._normed_darks
if darks is not None and len(darks) > 0:
# take only one dark into account for now
dark = list(darks.values())[0]
else:
dark = None
if dark is None:
if self._notify_ffc_rsc_missing:
_logger.error("cannot make flat field correction, dark not found")
return
if dark is not None and dark.ndim != 2:
_logger.error(
"cannot make flat field correction, dark should be of " "dimension 2"
)
return
return [
self._flat_field_correction(
self._frame_flat_field_correction(
data=frame,
dark=dark,
flat1=flat1,
flat2=flat2,
index_flat1=index_flat1,
index_flat2=index_flat2,
index_proj=proj_i,
flat_weights=self._flats_weights[proj_i],
)
for frame, proj_i in zip(projs, proj_indexes)
]
......
......@@ -31,10 +31,16 @@ __date__ = "08/09/2020"
import unittest
import numpy.random
from tomoscan.scanbase import TomoScanBase
import shutil
import tempfile
from silx.io.url import DataUrl
import h5py
import os
class TestFlatFieldCorrection(unittest.TestCase):
def setUp(self):
self.data_dir = tempfile.mkdtemp()
self.scan = TomoScanBase(None, None)
self.scan.set_normed_darks(
{
......@@ -50,13 +56,27 @@ class TestFlatFieldCorrection(unittest.TestCase):
}
)
self._data_urls = {}
projections = {}
file_path = os.path.join(self.data_dir, 'data_file.h5')
for i in range(-2, 30):
projections[i] = numpy.random.random(100).reshape((10, 10))
data_path = '/'.join(('data', str(i)))
self._data_urls[i] = DataUrl(file_path=file_path,
data_path=data_path,
scheme='silx')
with h5py.File(file_path, mode='a') as h5s:
h5s[data_path] = projections[i]
self.scan.projections = projections
def tearDown(self):
shutil.rmtree(self.data_dir)
def test_get_flats_weights(self):
"""test the _get_flats_weights function and insure flat weights
are correct"""
flat_weights = self.scan._get_flats_weights()
self.assertTrue(isinstance(flat_weights, dict))
self.assertEqual(len(flat_weights), 32)
......@@ -77,6 +97,21 @@ class TestFlatFieldCorrection(unittest.TestCase):
assertAlmostEqual(flat_weights[10], {1: 2.0/11.0, 12: 9.0/11.0})
assertAlmostEqual(flat_weights[18], {12: 3.0/9.0, 21: 6.0/9.0})
def test_flat_field_data_url(self):
"""insure the flat_field is computed. Simple processing test when
provided data is a DataUrl"""
projections_keys = [key for key in self.scan.projections.keys()]
projections_urls = [self.scan.projections[key] for key in projections_keys]
self.scan.flat_field_correction(projections_urls, projections_keys)
def test_flat_field_data_numpy_array(self):
"""insure the flat_field is computed. Simple processing test when
provided data is a numpy array"""
self.scan.projections = self._data_urls
projections_keys = [key for key in self.scan.projections.keys()]
projections_urls = [self.scan.projections[key] for key in projections_keys]
self.scan.flat_field_correction(projections_urls, projections_keys)
def suite():
test_suite = unittest.TestSuite()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment