Commit 19e4178a authored by Pierre Paleo's avatar Pierre Paleo
Browse files

WIP Backprojector API

parent 466842d1
......@@ -13,21 +13,52 @@ from .filtering import SinoFilter
class Backprojector(CudaProcessing):
"""
TODO docstring
Cuda Backprojector.
"""
def __init__(
self,
slice_shape,
angles,
dwidth_x=None,
dwidth_z=None,
sino_shape,
slice_shape=None,
angles=None,
rot_center=None,
filter_name=None,
extra_options={},
cuda_options={},
):
"""
Initialize a Cuda Backprojector.
Parameters
-----------
sino_shape: tuple
Shape of the sinogram, in the form `(n_angles, detector_width)`
(for backprojecting one sinogram) or `(n_sinos, n_angles, detector_width)`.
slice_shape: int or tuple, optional
Shape of the slice. By default, the slice shape is (n_x, n_x) where
`n_x = detector_width`
angles: array-like, optional
Rotation anles in radians.
By default, angles are equispaced between [0, pi[.
rot_center: float, optional
Rotation axis position. Default is `(detector_width - 1)/2.0`
filter_name: str, optional
Name of the filter for filtered-backprojection.
extra_options: dict, optional
Advanced extra options.
See the "Extra options" section for more information.
cuda_options: dict, optional
Cuda options.
Extra options
--------------
The parameter `extra_options` is a dictionary with the following defaults:
- "padding_mode": "zeros"
- "axis_correction": None
"""
super().__init__(**cuda_options)
self.configure_extra_options(extra_options=extra_options)
self.init_geometry(slice_shape, angles, dwidth_x, dwidth_z, rot_center)
self.init_geometry(sino_shape, slice_shape, angles, rot_center)
self._init_filter(filter_name)
self.allocate_memory()
self.compute_angles()
self.compile_kernels()
......@@ -37,7 +68,6 @@ class Backprojector(CudaProcessing):
def configure_extra_options(self, extra_options={}):
self._axis_array = None
self.extra_options = {
"filter_name": "ramlak",
"padding_mode": "zeros",
"axis_correction": None,
}
......@@ -45,21 +75,36 @@ class Backprojector(CudaProcessing):
self._axis_array = self.extra_options["axis_correction"]
def init_geometry(self, slice_shape, angles, dwidth_x, dwidth_z, rot_center):
if np.isscalar(slice_shape):
slice_shape = (slice_shape, slice_shape)
self.slice_shape = slice_shape
n_y, n_x = slice_shape
def init_geometry(self, slice_shape, angles, rot_center):
self.sino_shape = sino_shape
if len(sino_shape) == 2:
n_angles, dwidth = sino_shape
n_slices = 1
elif len(sino_shape) == 3:
n_slices, n_angles, dwidth = sino_shape
else:
raise ValueError("Expected 2D or 3D sinogram")
n_sinos = n_slices
n_y = dwidth
n_x = dwidth
if slice_shape is not None:
if np.isscalar(slice_shape):
slice_shape = (slice_shape, slice_shape)
n_y, n_x = slice_shape
self.n_x = n_x
self.n_y = n_y
self.dwidth_x = dwidth_x or max(n_x, n_y)
self.slice_shape = (n_y, n_x)
self.n_slices = n_slices
self.n_sinos = n_sinos
self.n_angles = n_angles
self.dwidth = dwidth
if np.isscalar(angles):
angles = np.linspace(0, np.pi, angles, False)
self.angles = angles
else:
assert len(angles) == self.n_angles
self.angles = angles
self.n_angles = len(self.angles)
self.sino_shape = (self.n_angles, self.dwidth_x)
self.rot_center = rot_center or (self.dwidth_x -1)/2.
self.sino_shape = (self.n_angles, self.dwidth)
self.rot_center = rot_center or (self.dwidth - 1)/2.
self.axis_pos = self.rot_center
......@@ -77,15 +122,16 @@ class Backprojector(CudaProcessing):
self._d_msin = garray.to_gpu(self.h_msin[0])
self._d_cos = garray.to_gpu(self.h_cos[0])
def compile_kernels(self):
# Configure sinogram filter
def _init_filter(self, filter_name):
self.filter_name = filter_name
self.sino_filter = SinoFilter(
self.sino_shape,
filter_name=self.extra_options["filter_name"],
filter_name=self.filter_name,
padding_mode=self.extra_options["padding_mode"],
ctx=self.ctx,
)
def compile_kernels(self):
# Configure backprojector
fname = get_cuda_srcfile("backproj.cu")
self.gpu_projector = CudaKernel("backproj", filename=fname)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment