Commit 2bf4d2e8 authored by payno's avatar payno
Browse files

[tomoscanbase] add `_get_flats_weights`

parent 9cd81acf
......@@ -41,6 +41,7 @@ from silx.io.utils import get_data
import silx.io.utils
from math import ceil
from .progress import Progress
from bisect import bisect_left
_logger = logging.getLogger(__name__)
......@@ -81,11 +82,18 @@ class TomoScanBase:
self._notify_ffc_rsc_missing = True
"""Should we notify the user if ffc fails because cannot find dark or
flat. Used to avoid several warnings. Only display one"""
self._projections = None
self._alignment_projections = None
self._flats_weights = None
"""list flats indexes to use for flat field correction and associate
weights"""
def clear_caches(self):
"""clear caches. Might be call if some data changed after
first read of data or metadata"""
self._notify_ffc_rsc_missing = True
self._alignment_projections = None
self._flats_weight = None
@property
def normed_darks(self):
......@@ -558,3 +566,32 @@ class TomoScanBase:
)
for frame, proj_i in zip(projs, proj_indexes)
]
def _get_flats_weights(self):
"""compute flats indexes to use and weights for each projection"""
flats_indexes = sorted(self.normed_flats.keys())
def get_weights(proj_index):
if proj_index in flats_indexes:
return {proj_index: 1.0}
pos = bisect_left(flats_indexes, proj_index)
left_pos = flats_indexes[pos - 1]
if pos == 0:
return {flats_indexes[0]: 1.0}
elif pos > len(flats_indexes) -1:
return {flats_indexes[-1]: 1.0}
else:
right_pos = flats_indexes[pos]
delta = right_pos - left_pos
return {
left_pos: 1 - (proj_index - left_pos) / delta,
right_pos: 1 - (right_pos - proj_index) / delta,
}
if self.normed_flats is None or len(self.normed_flats) == 0:
return {}
else:
res = {}
for proj_index in self.projections:
res[proj_index] = get_weights(proj_index=proj_index)
return res
......@@ -31,12 +31,14 @@ __date__ = "15/05/2017"
import unittest
from ..esrf import test as esrf_test
from . import test_factory
from . import test_scanbase
def suite(loader=None):
test_suite = unittest.TestSuite()
test_suite.addTest(esrf_test.suite())
test_suite.addTest(test_factory.suite())
test_suite.addTest(test_scanbase.suite())
return test_suite
......
# coding: utf-8
# /*##########################################################################
#
# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# ###########################################################################*/
__authors__ = ["H. Payno"]
__license__ = "MIT"
__date__ = "08/09/2020"
import unittest
import numpy.random
from tomoscan.scanbase import TomoScanBase
class TestFlatFieldCorrection(unittest.TestCase):
def setUp(self):
self.scan = TomoScanBase(None, None)
self.scan.set_normed_darks(
{
0: numpy.random.random(100).reshape((10, 10)),
}
)
self.scan.set_normed_flats(
{
1: numpy.random.random(100).reshape((10, 10)),
12: numpy.random.random(100).reshape((10, 10)),
21: numpy.random.random(100).reshape((10, 10)),
}
)
projections = {}
for i in range(-2, 30):
projections[i] = numpy.random.random(100).reshape((10, 10))
self.scan.projections = projections
def test_get_flats_weights(self):
flat_weights = self.scan._get_flats_weights()
self.assertTrue(isinstance(flat_weights, dict))
self.assertEqual(len(flat_weights), 32)
self.assertEqual(flat_weights.keys(), self.scan.projections.keys())
self.assertEqual(flat_weights[-2], {1: 1.0})
self.assertEqual(flat_weights[0], {1: 1.0})
self.assertEqual(flat_weights[1], {1: 1.0})
self.assertEqual(flat_weights[12], {12: 1.0})
self.assertEqual(flat_weights[21], {21: 1.0})
self.assertEqual(flat_weights[24], {21: 1.0})
def assertAlmostEqual(ddict1, ddict2):
self.assertEqual(ddict1.keys(), ddict2.keys())
for key in ddict1.keys():
self.assertAlmostEqual(ddict1[key], ddict2[key])
assertAlmostEqual(flat_weights[2], {1: 10.0/11.0, 12: 1.0/11.0})
assertAlmostEqual(flat_weights[10], {1: 2.0/11.0, 12: 9.0/11.0})
assertAlmostEqual(flat_weights[18], {12: 3.0/9.0, 21: 6.0/9.0})
def suite():
test_suite = unittest.TestSuite()
for ui in (TestFlatFieldCorrection,):
test_suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(ui))
return test_suite
if __name__ == "__main__":
unittest.main(defaultTest="suite")
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment