Commit 857aac85 authored by Pierre Paleo's avatar Pierre Paleo
Browse files

Add new test for multiple flatfield

parent f55e6f1e
......@@ -330,3 +330,102 @@ class TestFlatFieldH5:
cuda_flatfield.normalize_radios(d_projs)
projs = d_projs.get()
self.check_normalization(projs)
#
# Another test with more than two flats.
#
# Here we have
#
# F_i = i + 2
# R_i = i*(F_i - 1) + 1
# N_i = (R_i - D)/(F_i - D) = i*(F_i - 1)/( F_i - 1) = i
#
def generate_test_flatfield(n_radios, radio_shape, flat_interval, h5_fname):
radios = np.zeros((n_radios, ) + radio_shape, "f")
dark_data = np.ones(radios.shape[1:], "f")
tempdir = mkdtemp(prefix="nabu_")
testffname = os.path.join(tempdir, h5_fname)
flats = {}
flats_urls = {}
# F_i = i + 2
# R_i = i*(F_i - 1) + 1
# N_i = (R_i - D)/(F_i - D) = i*(F_i - 1)/( F_i - 1) = i
for i in range(n_radios):
f_i = i + 2
if (i % flat_interval) == 0:
flats["flats_%04d" % i] = np.zeros(radio_shape, "f") + f_i
flats_urls[i] = DataUrl(file_path=testffname, data_path=str("/flats/flats_%04d" % i), scheme="silx")
radios[i] = i * (f_i - 1) + 1
dark = {"dark_0000": dark_data}
dicttoh5(flats, testffname, h5path="/flats", mode="w")
dicttoh5(dark, testffname, h5path="/dark", mode="a")
dark_url = {0: DataUrl(file_path=testffname, data_path="/dark/dark_0000", scheme="silx")}
return radios, flats_urls, dark_url
@pytest.fixture(scope='class')
def bootstrap_multiflats(request):
cls = request.cls
n_radios = 50
radio_shape = (20, 21)
cls.flat_interval = 11
h5_fname = "testff.h5"
radios, flats, dark = generate_test_flatfield(
n_radios, radio_shape, cls.flat_interval, h5_fname
)
cls.radios = radios
cls.flats_urls = flats
cls.darks_urls = dark
cls.expected_results = np.arange(n_radios)
cls.tol = 5e-4
cls.tol_std = 1e-4
@pytest.mark.usefixtures('bootstrap_multiflats')
class TestFlatFieldMultiFlat:
def check_normalization(self, projs):
# Check that each projection is filled with the same values
std_projs = np.std(projs, axis=(-2, -1))
assert np.max(np.abs(std_projs)) < self.tol_std
# Check that the normalized radios are equal to 0, 1, 2, ...
stop = (projs.shape[0] // self.flat_interval) * self.flat_interval
errs = projs[:stop, 0, 0] - self.expected_results[:stop]
assert np.max(np.abs(errs)) < self.tol, "Something wrong with flat-field normalization"
def test_flatfield(self):
flatfield = FlatField(
self.radios.shape,
self.flats_urls,
self.darks_urls,
interpolation="linear"
)
projs = np.copy(self.radios)
flatfield.normalize_radios(projs)
print(projs[:, 0, 0])
self.check_normalization(projs)
@pytest.mark.skipif(not(__has_pycuda__), reason="Need cuda/pycuda for this test")
def test_cuda_flatfield(self):
d_projs = garray.to_gpu(self.radios)
cuda_flatfield = CudaFlatField(
self.radios.shape,
self.flats_urls,
self.darks_urls,
)
cuda_flatfield.normalize_radios(d_projs)
projs = d_projs.get()
self.check_normalization(projs)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment