# test_alignment.py: tests for nabu.preproc.alignment
import pytest
import numpy as np
import os
import h5py
import scipy.ndimage
from silx.third_party.EdfFile import EdfFile

from nabu.preproc.alignment import CenterOfRotation, DetectorTranslationAlongBeam, AlignmentBase
from nabu.testutils import utilstest


@pytest.fixture(scope="class")
def bootstrap_base(request):
    """Class-scoped fixture: attach the peak-fitting tolerance ``abs_tol`` to the requesting test class."""
    setattr(request.cls, "abs_tol", 2.5e-2)


@pytest.fixture(scope="class")
def bootstrap_cor(request):
    """Class-scoped fixture for the center-of-rotation tests.

    Sets on the requesting test class:
      - ``abs_tol``: maximum accepted distance between computed and reference CoR,
      - ``data``: the radios stack read from "tworadios.h5",
      - ``px``: the reference rotation-axis pixel position read from the same file.
    """
    target = request.cls
    target.abs_tol = 0.2

    radios, cor_reference = get_data_h5("tworadios.h5")
    target.data = radios
    target.px = cor_reference


@pytest.fixture(scope="class")
def bootstrap_dtr(request):
    """Class-scoped fixture for the detector-translation tests.

    Loads the seven "alignxc%04d.edf" frames (provided by Christian): the first
    six are the alignment images, one per detector translation position, and
    the last one is a dark frame that is subtracted from the others.

    Sets on the requesting test class:
      - ``abs_tol``: absolute tolerance used by the shift comparisons,
      - ``align_images``: dark-corrected stack of the six alignment images,
      - ``img_pos``: position of each image, 0.01, 0.02, ..., 0.06,
      - ``expected_shifts_vh``: reference (vertical, horizontal) shift coefficients,
      - ``reference_shifts_list``: reference per-image shifts as returned by
        ``find_shift(..., return_shifts=True)``; the first entry is [0, 0]
        (presumably shifts relative to the first image -- TODO confirm).
    """
    cls = request.cls
    cls.abs_tol = 1e-1

    n_align_images = 6
    # n_align_images alignment frames followed by one dark frame.
    images = np.array(
        [
            EdfFile(utilstest.getfile("alignxc%04d.edf" % i)).GetData(0, DataType="FloatValue")
            for i in range(n_align_images + 1)
        ]
    )
    align_imgs, dark_img = images[:-1], images[-1]

    # Dark-frame correction.
    cls.align_images = align_imgs - dark_img
    # Image i (0-based) corresponds to translation position (i + 1) * 0.01.
    cls.img_pos = (1 + np.arange(n_align_images)) * 0.01

    cls.expected_shifts_vh = np.array((129.93, 353.18))

    cls.reference_shifts_list = [
        [0, 0],
        [-9.39, 11.29],
        [-5.02, 3.81],
        [-3.67, 9.73],
        [-4.72, 16.71],
        [6.02, 20.28],
    ]


def get_data_h5(*dataset_path):
    """
    Get a dataset file from silx.org/pub/nabu/data and read its detector data.

    Parameters
    ----------
    *dataset_path : str
        Path components of the (possibly nested) dataset location, passed as
        separate arguments and joined with ``os.path.join``, e.g.
        ``get_data_h5("path", "to", "my", "dataset.h5")``.
        Do not pass a single list: unpack it with ``*`` instead.

    Returns
    -------
    data : numpy.ndarray
        Full content of ``entry/instrument/detector/data``.
    px
        Full content of ``entry/instrument/detector/x_rotation_axis_pixel_position``
        (read with ``[()]``, so a NumPy scalar if the dataset is scalar).
    """
    dataset_relpath = os.path.join(*dataset_path)
    # utilstest.getfile resolves (and presumably downloads/caches) the test file.
    dataset_downloaded_path = utilstest.getfile(dataset_relpath)

    with h5py.File(dataset_downloaded_path, "r") as hf:
        nxentry = "entry/instrument/detector"
        # [()] reads the whole dataset into memory before the file is closed.
        px = hf[nxentry + "/x_rotation_axis_pixel_position"][()]
        data = hf[nxentry + "/data"][()]

    return data, px


@pytest.mark.usefixtures("bootstrap_base")
class TestAlignmentBase(object):
    """Tests for the static helpers of ``AlignmentBase``.

    ``abs_tol`` is attached to the class by the ``bootstrap_base`` fixture.
    Random inputs come from fixed-seed generators so that a failure can be
    reproduced exactly instead of depending on the run.
    """

    def test_peak_fitting_2d_3x3(self):
        rng = np.random.RandomState(0)

        # Fit a 3 x 3 grid
        fy = np.linspace(-1, 1, 3)
        fx = np.linspace(-1, 1, 3)
        yy, xx = np.meshgrid(fy, fx, indexing="ij")

        # True peak position, drawn in [-0.8, 0.8) along each axis (inside the grid).
        peak_pos_yx = rng.rand(2) * 1.6 - 0.8
        # Broad Gaussian centered on the true peak, sampled on the grid.
        f_vals = np.exp(-((yy - peak_pos_yx[0]) ** 2 + (xx - peak_pos_yx[1]) ** 2) / 100)

        fitted_peak_pos_yx = AlignmentBase.refine_max_position_2d(f_vals, fy, fx)

        message = (
            "Computed peak position: (%f, %f) " % (*fitted_peak_pos_yx,)
            + " and real peak position (%f, %f) do not coincide." % (*peak_pos_yx,)
            + " Difference: (%f, %f)," % (*(fitted_peak_pos_yx - peak_pos_yx),)
            + " tolerance: %f" % self.abs_tol
        )
        assert np.all(np.isclose(peak_pos_yx, fitted_peak_pos_yx, atol=self.abs_tol)), message

    def test_peak_fitting_2d_error_checking(self):
        rng = np.random.RandomState(1)

        # Fit a 3 x 3 grid
        fy = np.linspace(-1, 1, 3)
        fx = np.linspace(-1, 1, 3)
        yy, xx = np.meshgrid(fy, fx, indexing="ij")

        # True peak position in [1.5, 2.5) along each axis: outside the [-1, 1] grid.
        peak_pos_yx = rng.rand(2) + 1.5
        f_vals = np.exp(-((yy - peak_pos_yx[0]) ** 2 + (xx - peak_pos_yx[1]) ** 2) / 100)

        with pytest.raises(ValueError) as ex:
            AlignmentBase.refine_max_position_2d(f_vals, fy, fx)

        message = (
            "Error should have been raised about the peak being fitted outside margins, "
            + "other error raised instead:\n%s" % str(ex.value)
        )
        # NOTE(review): "outide" presumably mirrors the exact (misspelled) text raised by
        # refine_max_position_2d; confirm before correcting the spelling here alone.
        assert "positions are outide the input margins" in str(ex.value), message

    def test_extract_peak_regions_1d(self):
        rng = np.random.RandomState(2)
        img = rng.randint(0, 10, size=(8, 8))

        peaks_pos = np.argmax(img, axis=-1)
        peaks_val = np.max(img, axis=-1)

        cc_coords = np.arange(0, 8)

        (found_peaks_val, found_peaks_pos) = AlignmentBase.extract_peak_regions_1d(img, axis=-1, cc_coords=cc_coords)
        # Row 1 of the returned arrays is compared with the per-row maxima
        # (presumably rows 0 and 2 hold the neighbouring samples -- TODO confirm).
        # The check is on the peak values; positions are reported for diagnosis only.
        message = (
            "The found peak values do not correspond to the expected peak values:"
            "\n  Expected values: %s\n  Found values: %s"
            "\n  Expected positions: %s\n  Found positions: %s"
            % (peaks_val, found_peaks_val[1, :], peaks_pos, found_peaks_pos[1, :])
        )
        assert np.all(peaks_val == found_peaks_val[1, :]), message

@pytest.mark.usefixtures("bootstrap_cor")
class TestCor(object):
    """Tests for ``CenterOfRotation.find_shift``.

    ``data``, ``px`` and ``abs_tol`` are attached to the class by the
    ``bootstrap_cor`` fixture: ``data[0]`` and ``data[1]`` are the two radios,
    ``px`` is the reference center-of-rotation position.
    Noise is drawn from fixed-seed generators so that failures are reproducible.
    """

    def _assert_cor_matches_reference(self, cor_position):
        # Shared check: the computed CoR must be within abs_tol of the reference one.
        message = "Computed CoR %f " % cor_position + " and real CoR %f do not coincide" % self.px
        assert np.abs(self.px - cor_position) < self.abs_tol, message

    def test_cor_posx(self):
        radio1 = self.data[0, :, :]
        # The second radio is mirrored horizontally before the comparison.
        radio2 = np.fliplr(self.data[1, :, :])

        CoR_calc = CenterOfRotation()
        cor_position = CoR_calc.find_shift(radio1, radio2)

        self._assert_cor_matches_reference(cor_position)

    def test_noisy_cor_posx(self):
        rng = np.random.RandomState(0)

        # Clamp negative values: Poisson sampling needs non-negative rates.
        radio1 = np.fmax(self.data[0, :, :], 0)
        radio2 = np.fmax(self.data[1, :, :], 0)

        # Poisson noise on the intensities scaled by 400.
        radio1 = rng.poisson(radio1 * 400)
        radio2 = rng.poisson(np.fliplr(radio2) * 400)

        CoR_calc = CenterOfRotation()
        cor_position = CoR_calc.find_shift(radio1, radio2, median_filt_shape=(3, 3))

        self._assert_cor_matches_reference(cor_position)

    def test_noisyHF_cor_posx(self):
        """test with noise at high frequencies"""
        rng = np.random.RandomState(1)

        radio1 = self.data[0, :, :]
        radio2 = np.fliplr(self.data[1, :, :])

        noise_level = radio1.max() / 16.0
        noise_ima1 = rng.normal(0.0, size=radio1.shape) * noise_level
        noise_ima2 = rng.normal(0.0, size=radio2.shape) * noise_level

        # Keep only the high-frequency part of the noise by removing its smoothed version.
        noise_ima1 = noise_ima1 - scipy.ndimage.gaussian_filter(noise_ima1, 2.0)
        noise_ima2 = noise_ima2 - scipy.ndimage.gaussian_filter(noise_ima2, 2.0)

        radio1 = radio1 + noise_ima1
        radio2 = radio2 + noise_ima2

        CoR_calc = CenterOfRotation()
        cor_position = CoR_calc.find_shift(radio1, radio2, low_pass=(6.0, 0.3))

        self._assert_cor_matches_reference(cor_position)

    def test_cor_posx_linear(self):
        radio1 = self.data[0, :, :]
        radio2 = np.fliplr(self.data[1, :, :])

        CoR_calc = CenterOfRotation()
        cor_position = CoR_calc.find_shift(radio1, radio2, padding_mode="constant")

        self._assert_cor_matches_reference(cor_position)

    def test_error_checking_001(self):
        CoR_calc = CenterOfRotation()

        # Single-column image: shape (rows, 1). Expected to be rejected as
        # non-2D (presumably find_shift squeezes singleton axes -- TODO confirm).
        radio1 = self.data[0, :, :1]
        radio2 = self.data[1, :, :]

        with pytest.raises(ValueError) as ex:
            CoR_calc.find_shift(radio1, radio2)

        message = "Error should have been raised about img #1 shape, other error raised instead:\n%s" % str(ex.value)
        assert "Images need to be 2-dimensional. Shape of image #1" in str(ex.value), message

    def test_error_checking_002(self):
        CoR_calc = CenterOfRotation()

        radio1 = self.data[0, :, :]
        # The whole 3D stack is passed as the second image.
        radio2 = self.data

        with pytest.raises(ValueError) as ex:
            CoR_calc.find_shift(radio1, radio2)

        message = "Error should have been raised about img #2 shape, other error raised instead:\n%s" % str(ex.value)
        assert "Images need to be 2-dimensional. Shape of image #2" in str(ex.value), message

    def test_error_checking_003(self):
        CoR_calc = CenterOfRotation()

        radio1 = self.data[0, :, :]
        # Both images are 2D, but with different widths.
        radio2 = self.data[1, :, 0:10]

        with pytest.raises(ValueError) as ex:
            CoR_calc.find_shift(radio1, radio2)

        message = (
            "Error should have been raised about different image shapes, "
            "other error raised instead:\n%s" % str(ex.value)
        )
        assert "Images need to be of the same shape" in str(ex.value), message

@pytest.mark.usefixtures("bootstrap_dtr")
class TestDetectorTranslation(object):
    """Tests for ``DetectorTranslationAlongBeam.find_shift``.

    ``align_images``, ``img_pos``, ``expected_shifts_vh``, ``reference_shifts_list``
    and ``abs_tol`` are attached to the class by the ``bootstrap_dtr`` fixture.
    """

    def test_alignxc(self):
        T_calc = DetectorTranslationAlongBeam()

        shifts_v, shifts_h, found_shifts_list = T_calc.find_shift(self.align_images, self.img_pos, return_shifts=True)

        # Fitted (vertical, horizontal) shift coefficients.
        message = "Computed shifts coefficients %s and expected %s do not coincide" % (
            (shifts_v, shifts_h),
            self.expected_shifts_vh,
        )
        assert np.all(np.isclose(self.expected_shifts_vh, [shifts_v, shifts_h], atol=self.abs_tol)), message

        # Per-image shifts.
        message = "Computed shifts %s and expected %s do not coincide" % (found_shifts_list, self.reference_shifts_list)
        assert np.all(np.isclose(found_shifts_list, self.reference_shifts_list, atol=self.abs_tol)), message

    def test_alignxc_synth(self):
        T_calc = DetectorTranslationAlongBeam()

        # Synthetic stack of 4 images, one per position 0, 1, 2, 3.
        stack = np.zeros([4, 512, 512], "d")

        # Broad blob (Gaussian, sigma 10 px, no smoothing across images) moving by
        # -10 px per image along both axes; presumably meant to be suppressed by
        # the high_pass filter passed to find_shift below.
        for i in range(4):
            stack[i, 200 - i * 10, 200 - i * 10] = 1
        stack = scipy.ndimage.gaussian_filter(stack, [0, 10, 10.0]) * 100

        # Narrow Gaussian (sigma 1 px) moving by +1.234 px (horizontal) and
        # +2.468 px (vertical) per unit position: this is the feature to track.
        x, y = np.meshgrid(np.arange(stack.shape[-1]), np.arange(stack.shape[-2]))
        for i in range(4):
            xc = x - (250 + i * 1.234)
            yc = y - (250 + i * 1.234 * 2)
            stack[i] += np.exp(-(xc * xc + yc * yc) * 0.5)

        shifts_v, shifts_h, found_shifts_list = T_calc.find_shift(
            stack, np.array([0.0, 1, 2, 3]), high_pass=1.0, return_shifts=True
        )

        # The expected (vertical, horizontal) coefficients are the negated
        # per-position displacements of the narrow Gaussian.
        message = "Found shifts per units %s and reference %s do not coincide" % ((shifts_v, shifts_h), (-1.234 * 2, -1.234))
        assert np.all(np.isclose((shifts_v, shifts_h), (-1.234 * 2, -1.234), atol=self.abs_tol)), message