Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tomotools
Nabu
Commits
f55e6f1e
Commit
f55e6f1e
authored
Jun 22, 2020
by
Pierre Paleo
Browse files
Rewrite CudaFlatField
parent
460e2f34
Changes
1
Hide whitespace changes
Inline
Side-by-side
nabu/preproc/ccd_cuda.py
View file @
f55e6f1e
...
...
@@ -33,6 +33,7 @@ class CudaFlatField(FlatField):
)
self
.
_set_cuda_options
(
cuda_options
)
self
.
_init_cuda_kernels
()
self
.
_precompute_flats_indices_weights
()
self
.
_load_flats_and_darks_on_gpu
()
def
_set_cuda_options
(
self
,
user_cuda_options
):
...
...
@@ -51,7 +52,7 @@ class CudaFlatField(FlatField):
self
.
cuda_kernel
=
CudaKernel
(
"flatfield_normalization"
,
self
.
_cuda_fname
,
signature
=
"PPPiiiPP
P
"
,
signature
=
"PPPiiiPP"
,
options
=
[
"-DN_FLATS=%d"
%
self
.
n_flats
,
"-DN_DARKS=%d"
%
self
.
n_darks
,
...
...
@@ -60,26 +61,41 @@ class CudaFlatField(FlatField):
self
.
_nx
=
np
.
int32
(
self
.
shape
[
1
])
self
.
_ny
=
np
.
int32
(
self
.
shape
[
0
])
def
_precompute_flats_indices_weights
(
self
):
flats_idx
=
np
.
zeros
((
self
.
n_radios
,
2
),
dtype
=
np
.
int32
)
flats_weights
=
np
.
zeros
((
self
.
n_radios
,
2
),
dtype
=
np
.
float32
)
for
i
,
idx
in
enumerate
(
self
.
radios_indices
):
prev_next
=
self
.
get_previous_next_indices
(
self
.
_sorted_flat_indices
,
idx
)
if
len
(
prev_next
)
==
1
:
# current index corresponds to an acquired flat
weights
=
(
1
,
0
)
f_idx
=
(
self
.
_flat2arrayidx
[
prev_next
[
0
]],
-
1
)
else
:
# interpolate
prev_idx
,
next_idx
=
prev_next
delta
=
next_idx
-
prev_idx
w1
=
1
-
(
idx
-
prev_idx
)
/
delta
w2
=
1
-
(
next_idx
-
idx
)
/
delta
weights
=
(
w1
,
w2
)
f_idx
=
(
self
.
_flat2arrayidx
[
prev_idx
],
self
.
_flat2arrayidx
[
next_idx
])
flats_idx
[
i
]
=
f_idx
flats_weights
[
i
]
=
weights
self
.
flats_idx
=
flats_idx
self
.
flats_weights
=
flats_weights
def
_load_flats_and_darks_on_gpu
(
self
):
# Flats
self
.
d_flats
=
garray
.
zeros
((
self
.
n_flats
,)
+
self
.
shape
,
np
.
float32
)
# ~ for i, flat_arr in enumerate(self.flats_arr.values()):
for
i
,
flat_idx
in
enumerate
(
self
.
_sorted_flat_indices
):
self
.
d_flats
[
i
].
set
(
np
.
ascontiguousarray
(
self
.
flats_arr
[
flat_idx
],
dtype
=
np
.
float32
))
self
.
d_flats_indices
=
garray
.
to_gpu
(
np
.
array
(
self
.
_sorted_flat_indices
,
dtype
=
np
.
int32
)
)
# Darks
self
.
d_darks
=
garray
.
zeros
((
self
.
n_darks
,)
+
self
.
shape
,
np
.
float32
)
# ~ for i, dark_arr in enumerate(self.darks_arr.values()):
for
i
,
dark_idx
in
enumerate
(
self
.
_sorted_dark_indices
):
self
.
d_darks
[
i
].
set
(
np
.
ascontiguousarray
(
self
.
darks_arr
[
dark_idx
],
dtype
=
np
.
float32
))
self
.
d_darks_indices
=
garray
.
to_gpu
(
np
.
array
(
self
.
_sorted_dark_indices
,
dtype
=
np
.
int32
)
)
#
Radios i
ndices
self
.
d_
radio
s_indices
=
garray
.
to_gpu
(
self
.
radios_indices
)
#
I
ndices
self
.
d_
flat
s_indices
=
garray
.
to_gpu
(
self
.
flats_idx
)
self
.
d_flats_weights
=
garray
.
to_gpu
(
self
.
flats_weights
)
def
normalize_radios
(
self
,
radios
):
"""
...
...
@@ -105,8 +121,7 @@ class CudaFlatField(FlatField):
self
.
_ny
,
np
.
int32
(
self
.
n_radios
),
self
.
d_flats_indices
,
self
.
d_darks_indices
,
self
.
d_radios_indices
self
.
d_flats_weights
,
)
return
radios
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment