Commit c7277a15 authored by Pierre Paleo's avatar Pierre Paleo
Browse files

Merge branch 'sino_rings' into 'master'

Sinogram based rings removal

Closes #95

See merge request !100
parents b86109eb 072a2fbf
Pipeline #39332 passed with stages
in 5 minutes and 6 seconds
from os import path, mkdir
from time import time
import numpy as np
from ..resources.logger import LoggerOrPrint
from .utils import use_options, pipeline_step, WriterConfigurator
......@@ -9,9 +10,10 @@ from ..preproc.shift import VerticalShift
from ..preproc.double_flatfield import DoubleFlatField
from ..preproc.phase import PaganinPhaseRetrieval
from ..preproc.sinogram import SinoProcessing, SinoNormalization
from ..preproc.rings import MunchDeringer
from ..misc.unsharp import UnsharpMask
from ..misc.histogram import PartialHistogram, hist_as_2Darray
from ..resources.utils import is_hdf5_extension
from ..resources.utils import is_hdf5_extension, extract_parameters
class FullFieldPipeline:
......@@ -29,6 +31,7 @@ class FullFieldPipeline:
UnsharpMaskClass = UnsharpMask
VerticalShiftClass = VerticalShift
SinoProcessingClass = SinoProcessing
SinoDeringerClass = MunchDeringer
MLogClass = Log
SinoNormalizationClass = SinoNormalization
FBPClass = None # For now we don't have a plain python/numpy backend for reconstruction
......@@ -260,6 +263,8 @@ class FullFieldPipeline:
shape = self._radios_cropped_shape
elif step_name == "build_sino":
shape = self._radios_cropped_shape
elif step_name == "sino_rings_correction":
shape = self.sino_builder.output_shape
elif step_name == "reconstruction":
shape = self.sino_builder.output_shape[1:]
else:
......@@ -323,6 +328,7 @@ class FullFieldPipeline:
self._init_mlog()
self._init_sino_normalization()
self._init_sino_builder()
self._init_sino_rings_correction()
self._prepare_reconstruction()
self._init_reconstruction()
self._init_histogram()
......@@ -456,6 +462,17 @@ class FullFieldPipeline:
self.sinos = self._allocate_sinobuilder_output()
self._sinobuilder_output = self.sinos
@use_options("sino_rings_correction", "sino_deringer")
def _init_sino_rings_correction(self):
options = self.processing_options["sino_rings_correction"]
fw_params = extract_parameters(options["user_options"])
fw_sigma = fw_params.pop("sigma", 1.)
self.sino_deringer = self.SinoDeringerClass(
fw_sigma,
sinos_shape=self._get_shape("sino_rings_correction"),
**fw_params
)
@use_options("reconstruction", "reconstruction")
def _prepare_reconstruction(self):
options = self.processing_options["reconstruction"]
......@@ -569,7 +586,16 @@ class FullFieldPipeline:
@pipeline_step("chunk_reader", "Reading data")
def _read_data(self):
self.logger.debug("Region = %s" % str(self.sub_region))
t0 = time()
self.chunk_reader.load_files()
el = time() - t0
shp = self.chunk_reader.chunk_shape
GB = np.prod(shp) * self.chunk_reader.dtype.itemsize / 1e9
self.logger.info(
"Read subvolume %s (%.2f GB) in %.1f s: %.2f GB/s"
% (str(shp), GB, el, GB/el)
)
@pipeline_step("flatfield", "Applying flat-field")
def _flatfield(self):
......@@ -641,6 +667,12 @@ class FullFieldPipeline:
copy=self._sinobuilder_copy
)
@pipeline_step("sino_deringer", "Removing rings on sinograms")
def _destripe_sinos(self, sinos=None):
if sinos is None:
sinos = self.sinos
self.sino_deringer.remove_rings(sinos)
@pipeline_step("reconstruction", "Reconstruction")
def _reconstruct(self, sinos=None):
if sinos is None:
......@@ -690,6 +722,7 @@ class FullFieldPipeline:
self._radios_movements()
self._normalize_sinos()
self._build_sino()
self._destripe_sinos()
self._reconstruct()
self._compute_histogram()
self._write_data()
......
......@@ -8,6 +8,7 @@ from ..preproc.double_flatfield_cuda import CudaDoubleFlatField
from ..preproc.phase_cuda import CudaPaganinPhaseRetrieval
from ..preproc.sinogram_cuda import CudaSinoProcessing, CudaSinoNormalization
from ..preproc.sinogram import SinoProcessing, SinoNormalization
from ..preproc.rings_cuda import CudaMunchDeringer
from ..misc.unsharp_cuda import CudaUnsharpMask
from ..misc.histogram_cuda import CudaPartialHistogram
from ..reconstruction.fbp import Backprojector
......@@ -31,6 +32,7 @@ class CudaFullFieldPipeline(FullFieldPipeline):
UnsharpMaskClass = CudaUnsharpMask
VerticalShiftClass = CudaVerticalShift
SinoProcessingClass = CudaSinoProcessing
SinoDeringerClass = CudaMunchDeringer
MLogClass = CudaLog
FBPClass = Backprojector
HistogramClass = CudaPartialHistogram
......@@ -198,7 +200,6 @@ class CudaFullFieldPipelineLimitedMemory(CudaFullFieldPipeline):
raise ValueError("Binning in z is not supported with this class")
#
def _get_shape(self, step_name):
"""
Get the shape to provide to the class corresponding to step_name.
......@@ -227,6 +228,8 @@ class CudaFullFieldPipelineLimitedMemory(CudaFullFieldPipeline):
shape = self._radios_cropped_shape
elif step_name == "build_sino":
shape = (n_a, self.n_recs, self.radios_shape[-1])
elif step_name == "sino_rings_correction":
shape = self.sino_builder.output_shape[1:]
elif step_name == "reconstruction":
shape = self.sino_builder.output_shape[1:]
else:
......@@ -275,12 +278,11 @@ class CudaFullFieldPipelineLimitedMemory(CudaFullFieldPipeline):
self._get_phase_output_shape()
# Processing acting on sinograms will be done later
self._processing_steps = self.processing_steps.copy()
for step in ["build_sino", "reconstruction", "save"]:
for step in ["build_sino", "sino_rings_correction", "reconstruction", "save"]:
if step in self.processing_steps:
self.processing_steps.remove(step)
self._partial_histograms = []
def _allocate_radios(self):
self._allocate_array(self.radios_group_shape, "f", name="radios")
self._h_radios = self.radios # (n_angles, delta_z, width) (does not fit in GPU mem)
......@@ -299,12 +301,10 @@ class CudaFullFieldPipelineLimitedMemory(CudaFullFieldPipeline):
def _allocate_recs(self, ny, nx):
self.recs = self._allocate_array((self.chunk_size, ny, nx), "f", name="recs")
def _register_callbacks(self):
# No callbacks are registered for this subclass
pass
def _process_finalize(self):
# release cuda memory
if self._d_sinos is not None:
......@@ -333,12 +333,10 @@ class CudaFullFieldPipelineLimitedMemory(CudaFullFieldPipeline):
ff_opts["projs_indices"] = self._ff_proj_indices[start_idx:end_idx]
self._init_flatfield(shape=(transfer_size, ) + self.radios_shape[1:])
def _flatfield_radios_group(self, start_idx, end_idx, transfer_size):
self._reinit_flatfield(start_idx, end_idx, transfer_size)
self._flatfield()
def _apply_flatfield_and_dff(self, n_groups, group_size, n_images):
"""
If double flat-field is activated, apply flat-field + double flat-field.
......@@ -365,7 +363,6 @@ class CudaFullFieldPipelineLimitedMemory(CudaFullFieldPipeline):
self._double_flatfield(radios=self._h_radios)
self._flatfield_is_done = True
def _compute_histogram_partial(self, data=None):
if data is None:
data = self._d_recs
......@@ -373,13 +370,11 @@ class CudaFullFieldPipelineLimitedMemory(CudaFullFieldPipeline):
self._compute_histogram(data=data)
self._partial_histograms.append(self.recs_histogram)
def _merge_partial_histograms(self):
if self.histogram is None:
return
self.recs_histogram = self.histogram.merge_histograms(self._partial_histograms)
def _process_chunk_ccd(self):
"""
Perform the processing in the "CCD space" (on radios)
......@@ -428,7 +423,6 @@ class CudaFullFieldPipelineLimitedMemory(CudaFullFieldPipeline):
"f"
)
def _process_chunk_sinos(self):
"""
Perform the processing in the "sinograms space"
......@@ -437,6 +431,7 @@ class CudaFullFieldPipelineLimitedMemory(CudaFullFieldPipeline):
if "reconstruction" not in self.processing_steps:
return
self.logger.debug("Initializing processing on sinos")
self._init_sino_rings_correction()
self._prepare_reconstruction()
self._init_reconstruction()
self._init_writer()
......@@ -477,6 +472,7 @@ class CudaFullFieldPipelineLimitedMemory(CudaFullFieldPipeline):
#
self._d_sinos[:transfer_size, :, :] = sinos[:, :, :]
# Process stack of sinograms (chunk_size, n_angles, width)
self._destripe_sinos(sinos=self._d_sinos)
self._reconstruct(sinos=self._d_sinos)
self._compute_histogram_partial(data=self._d_recs[:transfer_size])
# Copy D2H
......@@ -487,7 +483,6 @@ class CudaFullFieldPipelineLimitedMemory(CudaFullFieldPipeline):
# Write
self._write_data(data=self._h_recs)
def _process_chunk(self):
self._process_chunk_ccd()
self._process_chunk_sinos()
......
import numpy as np
from .sinogram import SinoProcessing
from ..thirdparty.pore3d_deringer_munch import munchetal_filter
from ..utils import check_supported
class Deringer(SinoProcessing):
"""
Class aimed at wrapping several sinogram-based rings correction methods.
"""
class MunchDeringer(SinoProcessing):
_available_methods = {
"munch": munchetal_filter
}
def __init__(self, sinos_shape=None, radios_shape=None, method="munch", deringer_args=None, deringer_kwargs=None):
def __init__(self, sigma, levels=None, wname='db15', sinos_shape=None, radios_shape=None):
"""
Initialize a "Munch Et Al" sinogram deringer. See References for more information.
Parameters
-----------
method: str
Sinogram rings correction method.
deringer_args: list or tuple
List of options to pass (additionally to the current sinogram) to
the sinogram de-striping function.
deringer_kwargs: dict
Dictionary of named options to pass to the sinogram de-striping function.
sigma: float
Standard deviation of the damping parameter. The higher value of sigma,
the more important the filtering effect on the rings.
levels: int, optional
Number of wavelets decomposition levels.
By default (None), the maximum number of decomposition levels is used.
wname: str, optional
Default is "db15" (Daubechies, 15 vanishing moments)
sinos_shape: tuple, optional
Shape of the sinogram (or sinograms stack).
This class requires either sinos_shape or radios_shape.
radios_shape: tuple, optional
Shape of the projection (or projections stack)
This class requires either sinos_shape or radios_shape.
Please see the SinoProcessing documentation for other parameters
References
----------
B. Munch, P. Trtik, F. Marone, M. Stampanoni, Stripe and ring artifact removal with
combined wavelet-Fourier filtering, Optics Express 17(10):8567-8591, 2009.
"""
super().__init__(sinos_shape=sinos_shape, radios_shape=radios_shape)
self._init_destriping_function(method, deringer_args, deringer_kwargs)
self.sigma = sigma
self.levels = levels
self.wname = wname
self._check_can_use_wavelets()
def _check_can_use_wavelets(self):
if munchetal_filter is None:
raise ValueError("Need pywavelets to use this class")
def _init_destriping_function(self, method, deringer_args, deringer_kwargs):
check_supported(method, list(self._available_methods.keys()), "method")
self.method = method
self._destriping_func = self._available_methods[method]
self.deringer_args = []
if deringer_args is not None:
self.deringer_args = deringer_args
self.deringer_kwargs = {}
if deringer_kwargs is not None:
self.deringer_kwargs = deringer_kwargs
def _destripe_sinogram(self, sinogram):
return self._destriping_func(
sinogram,
*self.deringer_args,
**self.deringer_kwargs
)
def _destripe_2D(self, sino, output):
res = munchetal_filter(sino, self.levels, self.sigma, wname=self.wname)
output[:] = res
return output
def correct_rings(self, sinos, output=None):
def remove_rings(self, sinos, output=None):
"""
Correct the rings in sinogram.
Defaults to in-place processing !
Main function to performs rings artefacts removal on sinogram(s).
CAUTION: this function defaults to in-place processing, meaning that
the sinogram(s) you pass will be overwritten.
Parameters
-----------
sino: numpy.ndarray, optional
----------
sinos: numpy.ndarray
Sinogram or stack of sinograms.
output: numpy.ndarray, optional
Stack of sinograms. If not provided, the correction will overwrite
the sinogram passed as input.
Output array. If set to None (default), the output overwrites the input.
"""
if output is None:
output = sinos
# TODO more elegant
if sinos.ndim == 2:
output[:] = self._destripe_sinogram(sinos)
return output
return self._destripe_2D(sinos, output)
n_sinos = sinos.shape[0]
for i in range(n_sinos):
output[i] = self._destripe_sinogram(sinos[i])
self._destripe_2D(sinos[i], output[i])
return output
import numpy as np
import pycuda.gpuarray as garray
from .sinogram_cuda import CudaSinoProcessing
from .rings import MunchDeringer
from ..utils import get_cuda_srcfile, updiv
from ..cuda.processing import CudaProcessing
from ..cuda.kernel import CudaKernel
try:
from pypwt import Wavelets
__have_pypwt__ = True
from pycudwt import Wavelets
__have_pycudwt__ = True
except ImportError:
__have_pypwt__ = False
__have_pycudwt__ = False
try:
from skcuda.fft import Plan
from skcuda.fft import fft as cufft
......@@ -17,40 +18,51 @@ except ImportError:
__have_skcuda__ = False
def get_minor_version(semver_version):
return float(".".join(semver_version.split(".")[:2]))
class CudaMunchDeringer(MunchDeringer):
# "Get memory pointer" only available from pypwt 0.9
if __have_pypwt__:
__pypwt_version__ = get_minor_version(Wavelets.version())
if __pypwt_version__ < 0.9:
__have_pypwt__ = False
class CudaMunchDeringer(CudaSinoProcessing):
def __init__(self, levels, sigma, wname="db15", sinos_shape=None, radios_shape=None, cuda_options=None):
def __init__(self, sigma, levels=None, wname='db15', sinos_shape=None, radios_shape=None, cuda_options=None):
"""
Cuda implementation of Fourier-Wavelets de-striping method [1].
Please see the documentation of nabu.preproc.sinogram.SinoProcessing.
Initialize a "Munch Et Al" sinogram deringer with the Cuda backend.
See References for more information.
Parameters
-----------
levels: int
Number of Wavelets decomposition levels.
sigma: float
Damping factor in the Wavelets domain.
Standard deviation of the damping parameter. The higher value of sigma,
the more important the filtering effect on the rings.
levels: int, optional
Number of wavelets decomposition levels.
By default (None), the maximum number of decomposition levels is used.
wname: str, optional
Wavelets name
Default is "db15" (Daubechies, 15 vanishing moments)
sinos_shape: tuple, optional
Shape of the sinogram (or sinograms stack).
This class requires either sinos_shape or radios_shape.
radios_shape: tuple, optional
Shape of the projection (or projections stack)
This class requires either sinos_shape or radios_shape.
References
----------
B. Munch, P. Trtik, F. Marone, M. Stampanoni, Stripe and ring artifact removal with
combined wavelet-Fourier filtering, Optics Express 17(10):8567-8591, 2009.
"""
if not(__have_pypwt__ and __have_skcuda__):
raise ValueError("Needs pypwt and scikit-cuda to use this class")
super().__init__(sinos_shape=sinos_shape, radios_shape=radios_shape, cuda_options=cuda_options)
self._init_wavelets(levels, sigma, wname)
super().__init__(
sigma, levels=levels, wname=wname, sinos_shape=sinos_shape, radios_shape=radios_shape
)
self._check_can_use_wavelets()
cuda_options = cuda_options or {}
self.cuda_processing = CudaProcessing(**cuda_options)
self._init_pycudwt()
self._init_fft()
self._setup_fw_kernel()
def _check_can_use_wavelets(self):
if not(__have_pycudwt__ and __have_skcuda__):
raise ValueError("Needs pycudwt and scikit-cuda to use this class")
def _init_fft(self):
self._fft_plans = {}
for level, d_vcoeff in self._d_vertical_coeffs.items():
......@@ -81,10 +93,12 @@ class CudaMunchDeringer(CudaSinoProcessing):
)
self._fft_plans[level] = {"forward": p_f, "inverse": p_i}
def _init_wavelets(self, levels, sigma, wname):
self.sigma = float(sigma)
def _init_pycudwt(self):
if self.levels is None:
self.levels = 100 # will be clipped by pycudwt
self.sino_shape = self.sinos_shape[1:]
self.cudwt = Wavelets(np.zeros(self.sino_shape, "f"), wname, levels)
self.cudwt = Wavelets(np.zeros(self.sino_shape, "f"), self.wname, self.levels)
self.levels = self.cudwt.levels
# Access memory allocated by "pypwt" from pycuda
self._d_sino = garray.empty(self.sino_shape, np.float32, gpudata=self.cudwt.image_int_ptr())
self._get_vertical_coeffs()
......@@ -115,7 +129,7 @@ class CudaMunchDeringer(CudaSinoProcessing):
)
def destripe_munch(self, d_sino, output=None):
def _destripe_2D(self, d_sino, output):
# set the "image" for DWT (memcpy D2D)
self._d_sino.set(d_sino)
# perform forward DWT
......@@ -141,8 +155,7 @@ class CudaMunchDeringer(CudaSinoProcessing):
)
# Finally, inverse DWT
self.cudwt.inverse()
if output is None:
output = d_sino
output.set(self._d_sino)
return output
......@@ -2,14 +2,14 @@ import numpy as np
import pytest
from nabu.utils import clip_circle
from nabu.testutils import get_data, compare_arrays
from nabu.preproc.rings import Deringer
from nabu.preproc.rings import MunchDeringer
from nabu.thirdparty.pore3d_deringer_munch import munchetal_filter
from nabu.cuda.utils import __has_pycuda__
__have_gpuderinger__ = False
if __has_pycuda__:
import pycuda.gpuarray as garray
from nabu.preproc.rings_cuda import CudaMunchDeringer, __have_pypwt__, __have_skcuda__
if __have_pypwt__ and __have_skcuda__:
from nabu.preproc.rings_cuda import CudaMunchDeringer, __have_pycudwt__, __have_skcuda__
if __have_pycudwt__ and __have_skcuda__:
__have_gpuderinger__ = True
......@@ -28,7 +28,7 @@ def bootstrap(request):
@pytest.mark.usefixtures('bootstrap')
class TestDeringer:
class TestMunchDeringer:
@staticmethod
def add_stripes_to_sino(sino, rings_desc):
......@@ -52,31 +52,34 @@ class TestDeringer:
@pytest.mark.skipif(munchetal_filter is None, reason="Need PyWavelets for this test")
def test_munch_deringer(self):
deringer = Deringer(
sinos_shape=self.sino.shape, method="munch",
deringer_args=[self.fw_levels, self.fw_sigma],
deringer_kwargs={"wname": self.fw_wname}
deringer = MunchDeringer(
self.fw_sigma,
levels=self.fw_levels,
wname=self.fw_wname,
sinos_shape=self.sino.shape
)
sino = self.add_stripes_to_sino(self.sino, self.rings)
# Reference destriping with pore3d "munchetal_filter"
ref = munchetal_filter(sino, self.fw_levels, self.fw_sigma, wname=self.fw_wname)
# Wrapping with DeRinger
res = np.zeros((1, ) + sino.shape, dtype=np.float32)
deringer.correct_rings(sino, output=res)
deringer.remove_rings(sino, output=res)
err_max = np.max(np.abs(res[0] - ref))
assert err_max < self.tol, "Max error is too high"
@pytest.mark.skipif(not(__have_gpuderinger__), reason="Need pycuda, pypwt and scikit-cuda for this test")
@pytest.mark.skipif(
not(__have_gpuderinger__) or munchetal_filter is None,
reason="Need pycuda, pycudwt and scikit-cuda for this test"
)
def test_cuda_munch_deringer(self):
sino = self.add_stripes_to_sino(self.sino, self.rings)
deringer = CudaMunchDeringer(
self.fw_levels, self.fw_sigma, wname=self.fw_wname, sinos_shape=self.sino.shape
self.fw_sigma, levels=self.fw_levels, wname=self.fw_wname, sinos_shape=self.sino.shape
)
d_sino = garray.to_gpu(sino)
deringer.destripe_munch(d_sino)
deringer.remove_rings(d_sino)
res = d_sino.get()
ref = munchetal_filter(sino, self.fw_levels, self.fw_sigma, wname=self.fw_wname)
......
......@@ -73,6 +73,14 @@ def estimate_required_memory(process_config, chunk_size=None, radios_and_slices=
img_size_cplx = 2 * 8 * ((Nx_p * Ny_p) // 2 + 1)
total_memory_needed += 2 * img_size_real + 3 * img_size_cplx
# Sinogram de-ringing
# -------------------
if "sino_rings_correction" in processing_steps:
# Process is done image-wise.
# Needs one Discrete Wavelets transform and one FFT/IFFT plan for each scale
total_memory_needed += (Nx * Na * 4) * 5.5 # approx.
# Reconstruction
# ---------------
reconstructed_volume_size = 0
......
......@@ -115,6 +115,18 @@ nabu_config = {
"validator": optional_file_location_validator,
"type": "advanced",
},
"sino_rings_correction": {
"default": "",
"help": "Sinogram rings removal method. Default (empty) is None. Available are: None, munch. See also: sino_rings_options",
"validator": sino_deringer_methods,
"type": "optional",
},
"sino_rings_options": {
"default": "sigma=1.0 ; levels=10",
"help": "Options for sinogram rings correction methods. The parameters are separated by commas and passed as 'name=value', for example: sigma=1.0;levels=10. Mind the semicolon separator (;).",
"validator": cor_options_validator,
"type": "advanced",
},
},
"phase": {
"method": {
......
......@@ -152,3 +152,14 @@ cor_methods = {
"growing-window": "growing-window",
"growing window": "growing-window",
}
class RingsMethods(Enum):
NONE = None
MUNCH = "munch"
rings_methods = {
"none": None,
"": None,
"munch": "munch",
}
......@@ -235,6 +235,16 @@ class ProcessConfig:
options["sino_normalization"] = {
"method": nabu_config["preproc"]["sino_normalization"]
}
#
# Sinogram-based rings artefacts removal
#
if nabu_config["preproc"]["sino_rings_correction"]:
tasks.append("sino_rings_correction")
options["sino_rings_correction"] = {
"user_options": nabu_config["preproc"]["sino_rings_options"],
}
#
# Reconstruction
#
......
......@@ -339,6 +339,16 @@ def sino_normalization_validator(val):
)
return val
@validator
def sino_deringer_methods(val):
val = name_range_checker(
val,
RingsMethods.values(),
"sinogram rings artefacts correction method",
replacements=rings_methods,
)