Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tomotools
tomoscan
Commits
90a88847
Commit
90a88847
authored
Sep 10, 2020
by
Pierre Paleo
Browse files
Merge branch 'rework_flat_field_correction' into 'master'
Rework flat field correction See merge request
!27
parents
a47f0562
6c121462
Pipeline
#33237
passed with stages
in 8 minutes and 39 seconds
Changes
3
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
tomoscan/scanbase.py
View file @
90a88847
...
...
@@ -41,6 +41,7 @@ from silx.io.utils import get_data
import
silx.io.utils
from
math
import
ceil
from
.progress
import
Progress
from
bisect
import
bisect_left
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -81,11 +82,18 @@ class TomoScanBase:
self
.
_notify_ffc_rsc_missing
=
True
"""Should we notify the user if ffc fails because cannot find dark or
flat. Used to avoid several warnings. Only display one"""
self
.
_projections
=
None
self
.
_alignment_projections
=
None
self
.
_flats_weights
=
None
"""list flats indexes to use for flat field correction and associate
weights"""
def
clear_caches
(
self
):
"""clear caches. Might be call if some data changed after
first read of data or metadata"""
self
.
_notify_ffc_rsc_missing
=
True
self
.
_alignment_projections
=
None
self
.
_flats_weights
=
None
@
property
def
normed_darks
(
self
):
...
...
@@ -447,15 +455,12 @@ class TomoScanBase:
else
:
return
None
def
_flat_field_correction
(
def
_frame
_flat_field_correction
(
self
,
data
:
typing
.
Union
[
numpy
.
ndarray
,
DataUrl
],
index_proj
:
typing
.
Union
[
int
,
None
],
dark
,
flat1
,
flat2
,
index_flat1
:
int
,
index_flat2
:
int
,
flat_weights
:
dict
,
):
"""
compute flat field correction for a provided data from is index
...
...
@@ -466,57 +471,48 @@ class TomoScanBase:
data
=
get_data
(
data
)
can_process
=
True
if
dark
is
None
:
if
self
.
_notify_ffc_rsc_missing
:
_logger
.
error
(
"cannot make flat field correction, dark not found"
)
can_process
=
False
if
dark
is
not
None
and
dark
.
ndim
!=
2
:
_logger
.
error
(
"cannot make flat field correction, dark should be of "
"dimension 2"
)
can_process
=
False
if
flat1
is
None
:
if
flat_weights
in
(
None
,
{}):
if
self
.
_notify_ffc_rsc_missing
:
_logger
.
error
(
"cannot make flat field correction, flat not found"
)
can_process
=
False
else
:
if
flat1
.
ndim
!=
2
:
_logger
.
error
(
"cannot make flat field correction, flat should be of "
"dimension 2"
)
can_process
=
False
if
flat2
is
not
None
and
flat1
.
shape
!=
flat2
.
shape
:
_logger
.
error
(
"the tow flats provided have different shapes."
)
can_process
=
False
if
dark
is
not
None
and
flat1
is
not
None
and
dark
.
shape
!=
flat1
.
shape
:
_logger
.
error
(
"Given dark and flat have incoherent dimension"
)
can_process
=
False
if
dark
is
not
None
and
data
.
shape
!=
dark
.
shape
:
_logger
.
error
(
"Image has invalid shape. Cannot apply flat field"
"correction it"
)
can_process
=
False
for
flat_index
,
_
in
flat_weights
.
items
():
if
flat_index
not
in
self
.
normed_flats
:
_logger
.
error
(
"flat {} has been removed, unable to apply flat field"
""
.
format
(
flat_index
)
)
can_process
=
False
elif
(
self
.
normed_flats
is
not
None
and
self
.
normed_flats
[
flat_index
].
ndim
!=
2
):
_logger
.
error
(
"cannot make flat field correction, flat should be of "
"dimension 2"
)
can_process
=
False
if
can_process
is
False
:
self
.
_notify_ffc_rsc_missing
=
False
return
data
if
flat2
is
None
:
flat_value
=
flat1
if
len
(
flat_weights
)
==
1
:
flat_value
=
self
.
normed_flats
[
list
(
flat_weights
.
keys
())[
0
]]
elif
len
(
flat_weights
)
==
2
:
flat_keys
=
list
(
flat_weights
.
keys
())
flat_1
=
flat_keys
[
0
]
flat_2
=
flat_keys
[
1
]
flat_value
=
(
self
.
normed_flats
[
flat_1
]
*
flat_weights
[
flat_1
]
+
self
.
normed_flats
[
flat_2
]
*
flat_weights
[
flat_2
]
)
else
:
# compute weight and clip it if necessary
if
index_proj
is
None
:
w
=
0.5
else
:
w
=
(
index_proj
-
index_flat1
)
/
(
index_flat2
-
index_flat1
)
w
=
min
(
1
,
w
)
w
=
max
(
0
,
w
)
flat_value
=
flat1
*
w
+
flat2
*
(
1
-
w
)
raise
ValueError
(
"no more than two flats are expected and"
"at least one shuold be provided"
)
div
=
flat_value
-
dark
div
[
div
==
0
]
=
1
...
...
@@ -539,31 +535,83 @@ class TomoScanBase:
"""
assert
isinstance
(
projs
,
typing
.
Iterable
)
assert
isinstance
(
proj_indexes
,
typing
.
Iterable
)
flats
=
self
.
normed_flats
flat1
=
flat2
=
None
index_flat1
=
index_flat2
=
None
if
flats
is
not
None
:
flat_indexes
=
sorted
(
list
(
flats
.
keys
()))
if
len
(
flats
)
>
0
:
index_flat1
=
flat_indexes
[
0
]
flat1
=
flats
[
index_flat1
]
if
len
(
flats
)
>
1
:
index_flat2
=
flat_indexes
[
-
1
]
flat2
=
flats
[
index_flat2
]
darks
=
self
.
normed_darks
dark
=
None
def
has_missing_keys
():
if
proj_indexes
is
None
:
return
False
for
proj_index
in
proj_indexes
:
if
proj_index
not
in
self
.
_flats_weights
:
return
False
return
True
if
self
.
_flats_weights
in
(
None
,
{})
or
has_missing_keys
():
self
.
_flats_weights
=
self
.
_get_flats_weights
()
if
self
.
_flats_weights
in
(
None
,
{}):
_logger
.
error
(
"Unable to compute flat weights"
)
darks
=
self
.
_normed_darks
if
darks
is
not
None
and
len
(
darks
)
>
0
:
# take only one dark into account for now
dark
=
list
(
darks
.
values
())[
0
]
else
:
dark
=
None
if
dark
is
None
:
if
self
.
_notify_ffc_rsc_missing
:
_logger
.
error
(
"cannot make flat field correction, dark not found"
)
return
[
get_data
(
proj
)
if
isinstance
(
proj
,
DataUrl
)
else
proj
for
proj
in
projs
]
if
dark
is
not
None
and
dark
.
ndim
!=
2
:
_logger
.
error
(
"cannot make flat field correction, dark should be of "
"dimension 2"
)
return
[
get_data
(
proj
)
if
isinstance
(
proj
,
DataUrl
)
else
proj
for
proj
in
projs
]
return
[
self
.
_flat_field_correction
(
self
.
_
frame_
flat_field_correction
(
data
=
frame
,
dark
=
dark
,
flat1
=
flat1
,
flat2
=
flat2
,
index_flat1
=
index_flat1
,
index_flat2
=
index_flat2
,
index_proj
=
proj_i
,
flat_weights
=
self
.
_flats_weights
[
proj_i
]
if
proj_i
in
self
.
_flats_weights
else
None
,
)
for
frame
,
proj_i
in
zip
(
projs
,
proj_indexes
)
]
def
_get_flats_weights
(
self
):
"""compute flats indexes to use and weights for each projection"""
if
self
.
normed_flats
is
None
:
return
None
flats_indexes
=
sorted
(
self
.
normed_flats
.
keys
())
def
get_weights
(
proj_index
):
if
proj_index
in
flats_indexes
:
return
{
proj_index
:
1.0
}
pos
=
bisect_left
(
flats_indexes
,
proj_index
)
left_pos
=
flats_indexes
[
pos
-
1
]
if
pos
==
0
:
return
{
flats_indexes
[
0
]:
1.0
}
elif
pos
>
len
(
flats_indexes
)
-
1
:
return
{
flats_indexes
[
-
1
]:
1.0
}
else
:
right_pos
=
flats_indexes
[
pos
]
delta
=
right_pos
-
left_pos
return
{
left_pos
:
1
-
(
proj_index
-
left_pos
)
/
delta
,
right_pos
:
1
-
(
right_pos
-
proj_index
)
/
delta
,
}
if
self
.
normed_flats
is
None
or
len
(
self
.
normed_flats
)
==
0
:
return
{}
else
:
res
=
{}
for
proj_index
in
self
.
projections
:
res
[
proj_index
]
=
get_weights
(
proj_index
=
proj_index
)
return
res
tomoscan/test/__init__.py
View file @
90a88847
...
...
@@ -31,12 +31,14 @@ __date__ = "15/05/2017"
import
unittest
from
..esrf
import
test
as
esrf_test
from
.
import
test_factory
from
.
import
test_scanbase
def
suite
(
loader
=
None
):
test_suite
=
unittest
.
TestSuite
()
test_suite
.
addTest
(
esrf_test
.
suite
())
test_suite
.
addTest
(
test_factory
.
suite
())
test_suite
.
addTest
(
test_scanbase
.
suite
())
return
test_suite
...
...
tomoscan/test/test_scanbase.py
0 → 100644
View file @
90a88847
# coding: utf-8
# /*##########################################################################
#
# Copyright (c) 2016-2017 European Synchrotron Radiation Facility
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# ###########################################################################*/
__authors__
=
[
"H. Payno"
]
__license__
=
"MIT"
__date__
=
"08/09/2020"
import
unittest
import
numpy.random
from
tomoscan.scanbase
import
TomoScanBase
import
shutil
import
tempfile
from
silx.io.url
import
DataUrl
import
h5py
import
os
class
TestFlatFieldCorrection
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
data_dir
=
tempfile
.
mkdtemp
()
self
.
scan
=
TomoScanBase
(
None
,
None
)
self
.
scan
.
set_normed_darks
(
{
0
:
numpy
.
random
.
random
(
100
).
reshape
((
10
,
10
)),
}
)
self
.
scan
.
set_normed_flats
(
{
1
:
numpy
.
random
.
random
(
100
).
reshape
((
10
,
10
)),
12
:
numpy
.
random
.
random
(
100
).
reshape
((
10
,
10
)),
21
:
numpy
.
random
.
random
(
100
).
reshape
((
10
,
10
)),
}
)
self
.
_data_urls
=
{}
projections
=
{}
file_path
=
os
.
path
.
join
(
self
.
data_dir
,
"data_file.h5"
)
for
i
in
range
(
-
2
,
30
):
projections
[
i
]
=
numpy
.
random
.
random
(
100
).
reshape
((
10
,
10
))
data_path
=
"/"
.
join
((
"data"
,
str
(
i
)))
self
.
_data_urls
[
i
]
=
DataUrl
(
file_path
=
file_path
,
data_path
=
data_path
,
scheme
=
"silx"
)
with
h5py
.
File
(
file_path
,
mode
=
"a"
)
as
h5s
:
h5s
[
data_path
]
=
projections
[
i
]
self
.
scan
.
projections
=
projections
def
tearDown
(
self
):
shutil
.
rmtree
(
self
.
data_dir
)
def
test_get_flats_weights
(
self
):
"""test the _get_flats_weights function and insure flat weights
are correct"""
flat_weights
=
self
.
scan
.
_get_flats_weights
()
self
.
assertTrue
(
isinstance
(
flat_weights
,
dict
))
self
.
assertEqual
(
len
(
flat_weights
),
32
)
self
.
assertEqual
(
flat_weights
.
keys
(),
self
.
scan
.
projections
.
keys
())
self
.
assertEqual
(
flat_weights
[
-
2
],
{
1
:
1.0
})
self
.
assertEqual
(
flat_weights
[
0
],
{
1
:
1.0
})
self
.
assertEqual
(
flat_weights
[
1
],
{
1
:
1.0
})
self
.
assertEqual
(
flat_weights
[
12
],
{
12
:
1.0
})
self
.
assertEqual
(
flat_weights
[
21
],
{
21
:
1.0
})
self
.
assertEqual
(
flat_weights
[
24
],
{
21
:
1.0
})
def
assertAlmostEqual
(
ddict1
,
ddict2
):
self
.
assertEqual
(
ddict1
.
keys
(),
ddict2
.
keys
())
for
key
in
ddict1
.
keys
():
self
.
assertAlmostEqual
(
ddict1
[
key
],
ddict2
[
key
])
assertAlmostEqual
(
flat_weights
[
2
],
{
1
:
10.0
/
11.0
,
12
:
1.0
/
11.0
})
assertAlmostEqual
(
flat_weights
[
10
],
{
1
:
2.0
/
11.0
,
12
:
9.0
/
11.0
})
assertAlmostEqual
(
flat_weights
[
18
],
{
12
:
3.0
/
9.0
,
21
:
6.0
/
9.0
})
def
test_flat_field_data_url
(
self
):
"""insure the flat_field is computed. Simple processing test when
provided data is a DataUrl"""
projections_keys
=
[
key
for
key
in
self
.
scan
.
projections
.
keys
()]
projections_urls
=
[
self
.
scan
.
projections
[
key
]
for
key
in
projections_keys
]
self
.
scan
.
flat_field_correction
(
projections_urls
,
projections_keys
)
def
test_flat_field_data_numpy_array
(
self
):
"""insure the flat_field is computed. Simple processing test when
provided data is a numpy array"""
self
.
scan
.
projections
=
self
.
_data_urls
projections_keys
=
[
key
for
key
in
self
.
scan
.
projections
.
keys
()]
projections_urls
=
[
self
.
scan
.
projections
[
key
]
for
key
in
projections_keys
]
self
.
scan
.
flat_field_correction
(
projections_urls
,
projections_keys
)
def
suite
():
test_suite
=
unittest
.
TestSuite
()
for
ui
in
(
TestFlatFieldCorrection
,):
test_suite
.
addTest
(
unittest
.
defaultTestLoader
.
loadTestsFromTestCase
(
ui
))
return
test_suite
if
__name__
==
"__main__"
:
unittest
.
main
(
defaultTest
=
"suite"
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment