Commit f55e6f1e authored by Pierre Paleo's avatar Pierre Paleo
Browse files

Rewrite CudaFlatField

parent 460e2f34
......@@ -33,6 +33,7 @@ class CudaFlatField(FlatField):
)
self._set_cuda_options(cuda_options)
self._init_cuda_kernels()
self._precompute_flats_indices_weights()
self._load_flats_and_darks_on_gpu()
def _set_cuda_options(self, user_cuda_options):
......@@ -51,7 +52,7 @@ class CudaFlatField(FlatField):
self.cuda_kernel = CudaKernel(
"flatfield_normalization",
self._cuda_fname,
signature="PPPiiiPPP",
signature="PPPiiiPP",
options=[
"-DN_FLATS=%d" % self.n_flats,
"-DN_DARKS=%d" % self.n_darks,
......@@ -60,26 +61,41 @@ class CudaFlatField(FlatField):
self._nx = np.int32(self.shape[1])
self._ny = np.int32(self.shape[0])
def _precompute_flats_indices_weights(self):
flats_idx = np.zeros((self.n_radios, 2), dtype=np.int32)
flats_weights = np.zeros((self.n_radios, 2), dtype=np.float32)
for i, idx in enumerate(self.radios_indices):
prev_next = self.get_previous_next_indices(self._sorted_flat_indices, idx)
if len(prev_next) == 1: # current index corresponds to an acquired flat
weights = (1, 0)
f_idx = (self._flat2arrayidx[prev_next[0]], -1)
else: # interpolate
prev_idx, next_idx = prev_next
delta = next_idx - prev_idx
w1 = 1 - (idx - prev_idx) / delta
w2 = 1 - (next_idx - idx) / delta
weights = (w1, w2)
f_idx = (self._flat2arrayidx[prev_idx], self._flat2arrayidx[next_idx])
flats_idx[i] = f_idx
flats_weights[i] = weights
self.flats_idx = flats_idx
self.flats_weights = flats_weights
def _load_flats_and_darks_on_gpu(self):
# Flats
self.d_flats = garray.zeros((self.n_flats,) + self.shape, np.float32)
# ~ for i, flat_arr in enumerate(self.flats_arr.values()):
for i, flat_idx in enumerate(self._sorted_flat_indices):
self.d_flats[i].set(np.ascontiguousarray(self.flats_arr[flat_idx], dtype=np.float32))
self.d_flats_indices = garray.to_gpu(
np.array(self._sorted_flat_indices, dtype=np.int32)
)
# Darks
self.d_darks = garray.zeros((self.n_darks,) + self.shape, np.float32)
# ~ for i, dark_arr in enumerate(self.darks_arr.values()):
for i, dark_idx in enumerate(self._sorted_dark_indices):
self.d_darks[i].set(np.ascontiguousarray(self.darks_arr[dark_idx], dtype=np.float32))
self.d_darks_indices = garray.to_gpu(
np.array(self._sorted_dark_indices, dtype=np.int32)
)
# Radios indices
self.d_radios_indices = garray.to_gpu(self.radios_indices)
# Indices
self.d_flats_indices = garray.to_gpu(self.flats_idx)
self.d_flats_weights = garray.to_gpu(self.flats_weights)
def normalize_radios(self, radios):
"""
......@@ -105,8 +121,7 @@ class CudaFlatField(FlatField):
self._ny,
np.int32(self.n_radios),
self.d_flats_indices,
self.d_darks_indices,
self.d_radios_indices
self.d_flats_weights,
)
return radios
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment