Commit 90a88847 authored by Pierre Paleo's avatar Pierre Paleo
Browse files

Merge branch 'rework_flat_field_correction' into 'master'

Rework flat field correction

See merge request !27
parents a47f0562 6c121462
Pipeline #33237 passed with stages
in 8 minutes and 39 seconds
......@@ -41,6 +41,7 @@ from silx.io.utils import get_data
import silx.io.utils
from math import ceil
from .progress import Progress
from bisect import bisect_left
_logger = logging.getLogger(__name__)
......@@ -81,11 +82,18 @@ class TomoScanBase:
self._notify_ffc_rsc_missing = True
"""Should we notify the user if ffc fails because cannot find dark or
flat. Used to avoid several warnings. Only display one"""
self._projections = None
self._alignment_projections = None
self._flats_weights = None
"""list flats indexes to use for flat field correction and associate
weights"""
def clear_caches(self):
"""clear caches. Might be call if some data changed after
first read of data or metadata"""
self._notify_ffc_rsc_missing = True
self._alignment_projections = None
self._flats_weights = None
@property
def normed_darks(self):
......@@ -447,15 +455,12 @@ class TomoScanBase:
else:
return None
def _flat_field_correction(
def _frame_flat_field_correction(
self,
data: typing.Union[numpy.ndarray, DataUrl],
index_proj: typing.Union[int, None],
dark,
flat1,
flat2,
index_flat1: int,
index_flat2: int,
flat_weights: dict,
):
"""
compute flat field correction for a provided data from is index
......@@ -466,57 +471,48 @@ class TomoScanBase:
data = get_data(data)
can_process = True
if dark is None:
if self._notify_ffc_rsc_missing:
_logger.error("cannot make flat field correction, dark not found")
can_process = False
if dark is not None and dark.ndim != 2:
_logger.error(
"cannot make flat field correction, dark should be of " "dimension 2"
)
can_process = False
if flat1 is None:
if flat_weights in (None, {}):
if self._notify_ffc_rsc_missing:
_logger.error("cannot make flat field correction, flat not found")
can_process = False
else:
if flat1.ndim != 2:
_logger.error(
"cannot make flat field correction, flat should be of "
"dimension 2"
)
can_process = False
if flat2 is not None and flat1.shape != flat2.shape:
_logger.error("the tow flats provided have different shapes.")
can_process = False
if dark is not None and flat1 is not None and dark.shape != flat1.shape:
_logger.error("Given dark and flat have incoherent dimension")
can_process = False
if dark is not None and data.shape != dark.shape:
_logger.error(
"Image has invalid shape. Cannot apply flat field" "correction it"
)
can_process = False
for flat_index, _ in flat_weights.items():
if flat_index not in self.normed_flats:
_logger.error(
"flat {} has been removed, unable to apply flat field"
"".format(flat_index)
)
can_process = False
elif (
self.normed_flats is not None
and self.normed_flats[flat_index].ndim != 2
):
_logger.error(
"cannot make flat field correction, flat should be of "
"dimension 2"
)
can_process = False
if can_process is False:
self._notify_ffc_rsc_missing = False
return data
if flat2 is None:
flat_value = flat1
if len(flat_weights) == 1:
flat_value = self.normed_flats[list(flat_weights.keys())[0]]
elif len(flat_weights) == 2:
flat_keys = list(flat_weights.keys())
flat_1 = flat_keys[0]
flat_2 = flat_keys[1]
flat_value = (
self.normed_flats[flat_1] * flat_weights[flat_1]
+ self.normed_flats[flat_2] * flat_weights[flat_2]
)
else:
# compute weight and clip it if necessary
if index_proj is None:
w = 0.5
else:
w = (index_proj - index_flat1) / (index_flat2 - index_flat1)
w = min(1, w)
w = max(0, w)
flat_value = flat1 * w + flat2 * (1 - w)
raise ValueError(
"no more than two flats are expected and"
"at least one shuold be provided"
)
div = flat_value - dark
div[div == 0] = 1
......@@ -539,31 +535,83 @@ class TomoScanBase:
"""
assert isinstance(projs, typing.Iterable)
assert isinstance(proj_indexes, typing.Iterable)
flats = self.normed_flats
flat1 = flat2 = None
index_flat1 = index_flat2 = None
if flats is not None:
flat_indexes = sorted(list(flats.keys()))
if len(flats) > 0:
index_flat1 = flat_indexes[0]
flat1 = flats[index_flat1]
if len(flats) > 1:
index_flat2 = flat_indexes[-1]
flat2 = flats[index_flat2]
darks = self.normed_darks
dark = None
def has_missing_keys():
if proj_indexes is None:
return False
for proj_index in proj_indexes:
if proj_index not in self._flats_weights:
return False
return True
if self._flats_weights in (None, {}) or has_missing_keys():
self._flats_weights = self._get_flats_weights()
if self._flats_weights in (None, {}):
_logger.error("Unable to compute flat weights")
darks = self._normed_darks
if darks is not None and len(darks) > 0:
# take only one dark into account for now
dark = list(darks.values())[0]
else:
dark = None
if dark is None:
if self._notify_ffc_rsc_missing:
_logger.error("cannot make flat field correction, dark not found")
return [
get_data(proj) if isinstance(proj, DataUrl) else proj
for proj in projs
]
if dark is not None and dark.ndim != 2:
_logger.error(
"cannot make flat field correction, dark should be of " "dimension 2"
)
return [
get_data(proj) if isinstance(proj, DataUrl) else proj for proj in projs
]
return [
self._flat_field_correction(
self._frame_flat_field_correction(
data=frame,
dark=dark,
flat1=flat1,
flat2=flat2,
index_flat1=index_flat1,
index_flat2=index_flat2,
index_proj=proj_i,
flat_weights=self._flats_weights[proj_i]
if proj_i in self._flats_weights
else None,
)
for frame, proj_i in zip(projs, proj_indexes)
]
def _get_flats_weights(self):
"""compute flats indexes to use and weights for each projection"""
if self.normed_flats is None:
return None
flats_indexes = sorted(self.normed_flats.keys())
def get_weights(proj_index):
if proj_index in flats_indexes:
return {proj_index: 1.0}
pos = bisect_left(flats_indexes, proj_index)
left_pos = flats_indexes[pos - 1]
if pos == 0:
return {flats_indexes[0]: 1.0}
elif pos > len(flats_indexes) - 1:
return {flats_indexes[-1]: 1.0}
else:
right_pos = flats_indexes[pos]
delta = right_pos - left_pos
return {
left_pos: 1 - (proj_index - left_pos) / delta,
right_pos: 1 - (right_pos - proj_index) / delta,
}
if self.normed_flats is None or len(self.normed_flats) == 0:
return {}
else:
res = {}
for proj_index in self.projections:
res[proj_index] = get_weights(proj_index=proj_index)
return res
......@@ -31,12 +31,14 @@ __date__ = "15/05/2017"
import unittest
from ..esrf import test as esrf_test
from . import test_factory
from . import test_scanbase
def suite(loader=None):
test_suite = unittest.TestSuite()
test_suite.addTest(esrf_test.suite())
test_suite.addTest(test_factory.suite())
test_suite.addTest(test_scanbase.suite())
return test_suite
......
# coding: utf-8
# /*##########################################################################
#
# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# ###########################################################################*/
__authors__ = ["H. Payno"]
__license__ = "MIT"
__date__ = "08/09/2020"
import unittest
import numpy.random
from tomoscan.scanbase import TomoScanBase
import shutil
import tempfile
from silx.io.url import DataUrl
import h5py
import os
class TestFlatFieldCorrection(unittest.TestCase):
def setUp(self):
self.data_dir = tempfile.mkdtemp()
self.scan = TomoScanBase(None, None)
self.scan.set_normed_darks(
{
0: numpy.random.random(100).reshape((10, 10)),
}
)
self.scan.set_normed_flats(
{
1: numpy.random.random(100).reshape((10, 10)),
12: numpy.random.random(100).reshape((10, 10)),
21: numpy.random.random(100).reshape((10, 10)),
}
)
self._data_urls = {}
projections = {}
file_path = os.path.join(self.data_dir, "data_file.h5")
for i in range(-2, 30):
projections[i] = numpy.random.random(100).reshape((10, 10))
data_path = "/".join(("data", str(i)))
self._data_urls[i] = DataUrl(
file_path=file_path, data_path=data_path, scheme="silx"
)
with h5py.File(file_path, mode="a") as h5s:
h5s[data_path] = projections[i]
self.scan.projections = projections
def tearDown(self):
shutil.rmtree(self.data_dir)
def test_get_flats_weights(self):
"""test the _get_flats_weights function and insure flat weights
are correct"""
flat_weights = self.scan._get_flats_weights()
self.assertTrue(isinstance(flat_weights, dict))
self.assertEqual(len(flat_weights), 32)
self.assertEqual(flat_weights.keys(), self.scan.projections.keys())
self.assertEqual(flat_weights[-2], {1: 1.0})
self.assertEqual(flat_weights[0], {1: 1.0})
self.assertEqual(flat_weights[1], {1: 1.0})
self.assertEqual(flat_weights[12], {12: 1.0})
self.assertEqual(flat_weights[21], {21: 1.0})
self.assertEqual(flat_weights[24], {21: 1.0})
def assertAlmostEqual(ddict1, ddict2):
self.assertEqual(ddict1.keys(), ddict2.keys())
for key in ddict1.keys():
self.assertAlmostEqual(ddict1[key], ddict2[key])
assertAlmostEqual(flat_weights[2], {1: 10.0 / 11.0, 12: 1.0 / 11.0})
assertAlmostEqual(flat_weights[10], {1: 2.0 / 11.0, 12: 9.0 / 11.0})
assertAlmostEqual(flat_weights[18], {12: 3.0 / 9.0, 21: 6.0 / 9.0})
def test_flat_field_data_url(self):
"""insure the flat_field is computed. Simple processing test when
provided data is a DataUrl"""
projections_keys = [key for key in self.scan.projections.keys()]
projections_urls = [self.scan.projections[key] for key in projections_keys]
self.scan.flat_field_correction(projections_urls, projections_keys)
def test_flat_field_data_numpy_array(self):
"""insure the flat_field is computed. Simple processing test when
provided data is a numpy array"""
self.scan.projections = self._data_urls
projections_keys = [key for key in self.scan.projections.keys()]
projections_urls = [self.scan.projections[key] for key in projections_keys]
self.scan.flat_field_correction(projections_urls, projections_keys)
def suite():
test_suite = unittest.TestSuite()
for ui in (TestFlatFieldCorrection,):
test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(ui))
return test_suite
if __name__ == "__main__":
unittest.main(defaultTest="suite")
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment