Commit 1db95e77 authored by Pierre Paleo's avatar Pierre Paleo
Browse files

Rewrite cuda flatfield kernel

parent 5b605233
......@@ -9,7 +9,7 @@
/**
* In-place flat-field normalization.
* In-place flat-field normalization with linear interpolation.
* This kernel assumes that all the radios are loaded into memory
* (although not necessarily the full radios images)
* and in radios[x, y z], z in the radio index
......@@ -20,7 +20,8 @@
* Nx: number of pixel horizontally in the radios
* Nx: number of pixel vertically in the radios
* Nx: number of radios
* flats_indices: indices of flats, in sorted order
* flats_indices: indices of flats to fetch for each radio
* flats_weights: weights of flats for each radio
* darks_indices: indices of darks, in sorted order
**/
__global__ void flatfield_normalization(
......@@ -31,43 +32,30 @@ __global__ void flatfield_normalization(
int Ny,
int Nz,
int* flats_indices,
int* darks_indices,
int* radios_indices
float* flats_weights
) {
uint x = blockDim.x * blockIdx.x + threadIdx.x;
uint y = blockDim.y * blockIdx.y + threadIdx.y;
uint z = blockDim.z * blockIdx.z + threadIdx.z;
if ((x >= Nx) || (y >= Ny) || (z >= Nz)) return;
uint pos = (z*Ny+y)*Nx + x;
int radio_idx = radios_indices[z];
float dark_val = 0.0f, flat_val = 1.0f;
#if N_FLATS == 1
flat_val = flats[y*Nx + x];
#else
// interpolation between 2 flats
for (int i = 0; i < N_FLATS-1; i++) {
int ind_prev = flats_indices[i];
int ind_next = flats_indices[i+1];
if (ind_prev >= radio_idx) {
flat_val = flats[(i*Ny+y)*Nx + x];
break;
}
else if (ind_prev < radio_idx && radio_idx < ind_next) {
// Linear interpolation
// TODO nearest interpolation
int delta = ind_next - ind_prev;
float w1 = 1.0f - (radio_idx*1.0f - ind_prev) / delta;
float w2 = 1.0f - (ind_next*1.0f - radio_idx) / delta;
flat_val = w1 * flats[(i*Ny+y)*Nx + x] + w2 * flats[((i+1)*Ny+y)*Nx + x];
break;
}
else if (ind_next <= radio_idx) {
flat_val = flats[((i+1)*Ny+y)*Nx + x];
break;
}
int prev_idx = flats_indices[z*2 + 0];
int next_idx = flats_indices[z*2 + 1];
float w1 = flats_weights[z*2 + 0];
float w2 = flats_weights[z*2 + 1];
if (next_idx == -1) {
flat_val = flats[(prev_idx*Ny+y)*Nx + x];
}
else {
flat_val = w1 * flats[(prev_idx*Ny+y)*Nx + x] + w2 * flats[(next_idx*Ny+y)*Nx + x];
}
#endif
#if (N_DARKS == 1)
dark_val = darks[y*Nx + x];
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment