Nicola Vigano authored: by removing the mex functions, and using matlab functions instead
Signed-off-by: Nicola Vigano <nicola.vigano@esrf.fr>
Gt6DBlobReconstructor.m 26.88 KiB
classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector
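% Gt6DBlobReconstructor: reconstructs 6D orientation-space volumes from
% detector blobs with a CP (Chambolle-Pock-style) primal-dual scheme,
% with optional l1, TV and ODF terms (see the cp_* methods below).
%
% Minimal usage sketch (illustrative only: variable names and projector
% options are assumptions, not taken from this file):
%
%   rec = Gt6DBlobReconstructor(vols_size, blobs);
%   rec.cp_tv(100);                   % 100 CP iterations with a TV term
%   sol = rec.getCurrentSolution();   % cell array of solution volumes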
properties
% Data
blobs;
psf = {};
% Solution
currentSolution = {};
currentDetDual = {};
currentOdfDual = [];
% Variables to plot
normInitialResidual;
normResiduals;
ODF = [];
verbose = false;
algo_ops_c_functions = true;
end
methods (Access = public)
function self = Gt6DBlobReconstructor(volumes, blobs, varargin)
fprintf('Initializing BLOB Reconstruction:\n - Setup..');
ct = tic();
c = ct;
blobs_depths = cellfun(@(x)size(x, 2), blobs);
proj_size = Gt6DBlobReconstructor.getProjSize(blobs);
if (iscell(volumes))
[vols_size(1), vols_size(2), vols_size(3)] = size(volumes{1});
else
vols_size = volumes;
end
self = self@Gt6DVolumeToBlobProjector(vols_size, proj_size, blobs_depths, varargin{:});
self.statistics.add_task('cp_dual_update_detector', 'CP Dual variable (detector) update');
self.statistics.add_task_partial('cp_dual_update_detector', 'cp_dual_detector_FP', 'Forward Projection');
self.statistics.add_task_partial('cp_dual_update_detector', 'cp_dual_detector_SB', 'Sinograms -> Blobs');
self.statistics.add_task_partial('cp_dual_update_detector', 'cp_dual_detector_SF', 'Shape Functions');
self.statistics.add_task('cp_dual_update_l1', 'CP Dual variable (l1) update');
self.statistics.add_task('cp_dual_update_tv', 'CP Dual variable (TV) update');
self.statistics.add_task_partial('cp_dual_update_tv', 'cp_dual_tv_reduction', 'Volumes reduction');
self.statistics.add_task_partial('cp_dual_update_tv', 'cp_dual_tv_gradient', 'Gradient');
self.statistics.add_task_partial('cp_dual_update_tv', 'cp_dual_tv_divergence', 'Divergence');
self.statistics.add_task('cp_dual_update_ODF', 'CP Dual variable (ODF) update');
self.statistics.add_task_partial('cp_dual_update_ODF', 'cp_dual_ODF_compute', 'Compute Reconstructed ODF');
self.statistics.add_task('cp_primal_update', 'CP Primal variable update');
self.statistics.add_task_partial('cp_primal_update', 'cp_primal_BP', 'Back Projection');
self.statistics.add_task_partial('cp_primal_update', 'cp_primal_BS', 'Blobs -> Sinograms');
self.statistics.add_task_partial('cp_primal_update', 'cp_primal_SF', 'Transposed Shape Functions');
self.statistics.add_task_partial('cp_primal_update', 'cp_primal_ODF', 'ODF Correction');
self.statistics.add_task_partial('cp_primal_update', 'cp_primal_CORR', 'Primal correction computation');
self.statistics.add_task_partial('cp_primal_update', 'cp_primal_APP', 'Primal update application');
self.blobs = blobs;
if (iscell(volumes))
self.currentSolution = volumes;
else
num_geoms = self.get_number_geometries();
fprintf('\b\b: Done. (%2.1f s)\n - Creating initial volumes (size: [%s] x %d bytes x %d vols = %g GB): ', ...
toc(c), strjoin(arrayfun(@(x){sprintf('%d', x)}, vols_size), ', '), ...
4, num_geoms, prod(vols_size) * 4 * num_geoms / 2^30)
c = tic();
self.currentSolution = cell(num_geoms, 1);
for ii = 1:num_geoms
num_chars = fprintf('%03d/%03d', ii, num_geoms);
self.currentSolution{ii} = zeros(vols_size, self.data_type);
fprintf(repmat('\b', [1 num_chars]));
end
end
fprintf('\b\b: Done. (%2.1f s)\n - Weights:\n', toc(c));
c = tic();
self.initializeWeights();
fprintf(' Done. (%2.1f s)\nTot (%3.1f s).\n', toc(c), toc(ct));
end
function solution = getCurrentSolution(self)
solution = self.currentSolution;
end
function [proj_blobs, proj_spots] = getProjectionOfCurrentSolution(self)
proj_blobs = self.compute_fwd_projection(self.currentSolution, false);
proj_spots = cellfun(@(x){sum(x, 2)}, proj_blobs);
proj_spots = permute(cat(2, proj_spots{:}), [1 3 2]);
end
function reInit(self, blobs, volumes)
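% Re-initializes the reconstructor: optionally replaces the blobs and/or
% the solution volumes (whose number and size must match the geometries),
% recomputes the projection weights, and clears the stored dual variables
% and the timing statistics.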
if (exist('volumes', 'var'))
num_geoms = self.get_number_geometries();
num_vols = numel(volumes);
if (num_vols ~= num_geoms)
error('Gt6DBlobReconstructor:wrong_argument', ...
'Wrong new volumes number. Expected: %d, got: %d', ...
num_geoms, num_vols)
end
for n = 1:num_geoms
if (any(size(volumes{n}) ~= self.volume_geometry))
error('Gt6DBlobReconstructor:wrong_argument', ...
'Wrong volume size for new volume: %d. Expected: (%s), got: (%s)', ...
n, sprintf(' %d', self.volume_geometry), ...
sprintf(' %d', size(volumes{n})) )
end
end
self.currentSolution = volumes;
else
self.currentSolution = gtMathsGetSameSizeZeros(self.currentSolution);
end
if (exist('blobs', 'var') && ~isempty(blobs))
self.blobs = blobs;
end
fprintf('Recomputing Weights:\n');
c = tic();
self.initializeWeights();
fprintf('- Done in %3.1f s.\n', toc(c));
self.currentDetDual = {};
self.currentOdfDual = [];
self.statistics.clear();
end
function cp_ls(self, numIters)
self.cp('6DLS', numIters);
end
function cp_l1(self, numIters, lambda)
self.cp('6DL1', numIters, lambda)
end
function cp_tv(self, numIters)
self.cp('6DTV', numIters);
end
function cp_tvl1(self, numIters, lambda)
self.cp('6DTVL1', numIters, lambda);
end
function cp(self, algo, numIters, lambda)
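% Main CP loop. Each iteration:
%  1. updates the detector dual variable (forward projection + PSF),
%  2. updates the TV and/or l1 dual variables, if requested by 'algo',
%  3. updates the ODF dual variable, if self.ODF is set,
%  4. back-projects the duals and applies the primal update with
%     over-relaxation (see update_primal).
% 'numIters' may be a scalar, a [first, last] pair, or an explicit list
% of iteration indices; 'lambda' weights the l1 term in the *L1 variants.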
self.statistics.clear();
self.initializeVariables(numIters);
sample_rate = 5;
fprintf('Initializing CP_%s weights: ', upper(algo))
c = tic();
[sigma1, sigma1_1, sigma2, tau] = self.init_cp_weights(algo);
fprintf('Done (%g seconds).\nInitializing CP_%s variables: ', toc(c), upper(algo))
c = tic();
[p, nextEnhancedSolution, q_l1, q_tv, q_odf] = self.init_cp_vars(algo);
fprintf('Done (%g seconds).\n', toc(c))
fprintf('Reconstruction: ')
c = tic();
switch (numel(numIters))
case 1
iters = 1:numIters;
case 2
iters = numIters(1):numIters(2);
otherwise
iters = numIters;
end
do_tv_update = ~isempty(strfind(upper(algo), 'TV'));
do_l1_update = ~isempty(strfind(upper(algo), 'L1'));
do_odf_update = ~isempty(self.ODF);
tot_iters = numel(iters);
for ii = iters
numchars = fprintf('%03d/%03d (in %f s)', ii-min(iters), tot_iters, toc(c));
% Computing update of dual detector
self.statistics.tic('cp_dual_update_detector');
proj_bls = self.compute_fwd_projection(nextEnhancedSolution);
proj_bls = self.apply_psf(proj_bls);
p = self.update_dual_detector(p, proj_bls, sigma1, sigma1_1);
self.statistics.toc('cp_dual_update_detector');
% Computing update dual TV
if (do_tv_update)
self.statistics.tic('cp_dual_update_tv');
[q_tv, mdiv_tv] = self.update_dual_TV(q_tv, nextEnhancedSolution);
self.statistics.toc('cp_dual_update_tv');
end
% Computing update dual l1
if (do_l1_update)
self.statistics.tic('cp_dual_update_l1');
q_l1 = self.update_dual_l1(q_l1, nextEnhancedSolution);
self.statistics.toc('cp_dual_update_l1');
end
% Function for computing the update to apply
switch (upper(algo))
case '6DLS'
% q_l1 is just a volume of zeros
compute_update_primal = @(ii){};
case '6DL1'
compute_update_primal = @(ii){lambda * q_l1{ii}};
case '6DTV'
compute_update_primal = @(ii){mdiv_tv};
case '6DTVL1'
compute_update_primal = @(ii){lambda * q_l1{ii}, mdiv_tv};
end
% Computing update dual ODF
if (do_odf_update)
self.statistics.tic('cp_dual_update_ODF');
q_odf = self.update_dual_ODF(q_odf, nextEnhancedSolution, sigma2);
self.statistics.toc('cp_dual_update_ODF');
% [self.ODF, temp_odf, q_odf]
base_update_primal = compute_update_primal;
compute_update_primal = @(ii)[base_update_primal(ii), {q_odf(ii)}];
end
% Computing update primal
self.statistics.tic('cp_primal_update');
p_blurr = self.apply_psf(p);
nextEnhancedSolution = self.compute_bwd_projection(...
nextEnhancedSolution, p_blurr, compute_update_primal, tau);
self.statistics.toc('cp_primal_update');
fprintf('%s', repmat(sprintf('\b'), 1, numchars))
if (self.verbose && (mod(ii, sample_rate) == 0))
self.normResiduals(ii / sample_rate) ...
= self.compute_functional_value(p, algo, lambda);
end
drawnow();
end
fprintf('(%04d) Done in %f seconds.\n', tot_iters, toc(c) )
self.printStats();
if (self.verbose)
figure;
subplot(1, 2, 1), plot(self.normResiduals(1:floor(max(iters) / sample_rate)));
subplot(1, 2, 2), semilogy(abs(self.normResiduals));
end
self.currentDetDual = p;
self.currentOdfDual = q_odf;
end
function num_bytes = get_peak_memory_consumption(self, algo)
bytes_blobs = self.get_memory_consumption_blobs();
bytes_vols = self.get_memory_consumption_volumes();
% currentSolution, blobs, and ASTRA's copies
base_bytes = bytes_vols + bytes_blobs;
% Contribution of the projection weights
base_bytes = base_bytes + bytes_blobs;
% Dual variable on detector
num_bytes = base_bytes + bytes_blobs;
% EnhancedSolution
num_bytes = num_bytes + bytes_vols;
switch (upper(algo))
case {'6DL1', '6DTVL1'}
% Dual variable in real-space
num_bytes = num_bytes + bytes_vols;
% Computation temporaries
num_bytes = num_bytes + bytes_blobs;
case {'6DLS', '6DTV'}
% Computation temporaries
num_bytes = num_bytes + bytes_blobs;
otherwise
error('Gt6DBlobReconstructor:wrong_argument', ...
'No such algorithm as "%s"', algo)
end
end
function num_bytes = get_memory_consumption_blobs(self)
num_bytes = GtBenchmarks.getSizeVariable(self.blobs);
end
function num_bytes = get_memory_consumption_volumes(self)
num_bytes = GtBenchmarks.getSizeVariable(self.currentSolution);
end
function num_bytes = get_memory_consumption_sinograms(self)
num_bytes = 0;
float_bytes = GtBenchmarks.getSizeVariable(zeros(1, 1, 'single'));
num_geoms = self.get_number_geometries();
if (isempty(self.ss_geometries))
for n = 1:num_geoms
sinogram_size = [ ...
self.proj_size(1), size(self.geometries{n}, 1), self.proj_size(2)];
num_bytes = num_bytes + float_bytes * prod(sinogram_size);
end
else
for n = 1:num_geoms
sinogram_size = [ ...
self.proj_size(1), size(self.ss_geometries{n}, 1), self.proj_size(2)];
num_bytes = num_bytes + float_bytes * prod(sinogram_size);
end
end
end
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% Primal and Duals updating functions
methods (Access = public)
function [q, mdivq] = update_dual_TV(self, q, new_enh_sol)
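% Dual update for the TV term, acting on the sum of all volumes:
%   q     <- P(q + sigma * grad(sum_i x_bar_i))
%   mdivq <- -div(q)
% where P is the component-wise projection q ./ max(1, |q|).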
sigma = 1 ./ (2 * numel(new_enh_sol));
c = tic();
sES = gtMathsSumCellVolumes(new_enh_sol);
self.statistics.add_timestamp(toc(c), 'cp_dual_update_tv', 'cp_dual_tv_reduction');
c = tic();
dsES = gtMathsGradient(sES);
self.statistics.add_timestamp(toc(c), 'cp_dual_update_tv', 'cp_dual_tv_gradient');
if (self.algo_ops_c_functions)
% Using the l1 update function, because TV is the l1 norm of
% the gradient
q = gt6DUpdateDualL1_c(q, dsES, sigma, self.num_threads);
else
for n = 1:numel(q)
q{n} = q{n} + sigma * dsES{n};
q{n} = q{n} ./ max(1, abs(q{n}));
end
end
c = tic();
mdivq = - gtMathsDivergence(q);
self.statistics.add_timestamp(toc(c), 'cp_dual_update_tv', 'cp_dual_tv_divergence');
end
function p = update_dual_detector(self, p, proj_bls, sigma, sigma_1)
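% Dual update for the detector (least-squares) data term:
%   p <- (p + sigma .* (A(x_bar) - blobs)) ./ (1 + sigma)
% i.e. the resolvent of the conjugate of 0.5 * ||A(x) - b||^2, using the
% diagonal step sizes sigma and sigma_1 = 1 ./ (1 + sigma) prepared in
% init_cp_weights.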
if (self.algo_ops_c_functions)
p = gt6DUpdateDualDetector_c(p, self.blobs, proj_bls, sigma, sigma_1, self.num_threads);
% % Or equivalently:
% proj_bls = internal_cell_sub_assign(proj_bls, self.blobs);
% proj_bls = internal_cell_prod_assign(proj_bls, sigma);
% p = internal_cell_sum_assign(p, proj_bls);
% p = internal_cell_prod_assign(p, sigma_1);
else
for n = 1:numel(p)
p{n} = p{n} + sigma{n} .* (proj_bls{n} - self.blobs{n});
p{n} = p{n} .* sigma_1{n};
end
end
end
function q = update_dual_l1(self, q, new_enh_sol)
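% Dual update for the l1 (sparsity) term:
%   q <- P(q + sigma * x_bar)
% with the same component-wise projection q ./ max(1, |q|) used for TV.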
num_vols = numel(q);
% sigma = 1 / num_vols;
sigma = 1;
if (self.algo_ops_c_functions)
q = gt6DUpdateDualL1_c(q, new_enh_sol, sigma, self.num_threads);
else
for n = 1:num_vols
q{n} = q{n} + new_enh_sol{n} * sigma;
q{n} = q{n} ./ max(1, abs(q{n}));
end
end
end
function q_odf = update_dual_ODF(self, q_odf, new_enh_sol, sigma)
c = tic();
temp_odf = self.compute_solution_ODF(new_enh_sol);
self.statistics.add_timestamp(toc(c), 'cp_dual_update_ODF', 'cp_dual_ODF_compute');
q_odf = q_odf + sigma .* (temp_odf - self.ODF);
end
function [curr_sol, curr_enh_sol, app_time, corr_time] = update_primal(self, curr_sol, curr_enh_sol, corrections, tau)
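% Primal update with over-relaxation (theta = 1):
%   v          = max(0, x + tau .* sum(corrections))
%   x_bar_next = v + theta * (v - x)
%   x          = v
% The C path applies the same update in-place on the existing volumes.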
c = tic();
correction = gtMathsSumCellVolumes(corrections);
corr_time = toc(c);
c = tic();
if (self.algo_ops_c_functions)
% We actually re-use the same allocated volumes, saving
% time and reducing problems with MATLAB's garbage
% collection
[curr_sol, curr_enh_sol] = gt6DUpdatePrimal_c(curr_sol, curr_enh_sol, correction, tau, self.num_threads);
else
theta = 1;
v = curr_sol + correction .* tau;
v(v < 0) = 0;
curr_enh_sol = v + theta .* (v - curr_sol);
curr_sol = v;
end
app_time = toc(c);
end
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% Projection functions
methods (Access = protected)
function proj_bls = compute_fwd_projection(self, volumes, apply_stats)
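% Forward-projects all volumes into blob space. When more than one
% GPU/job slot is available (num_gpus * jobs_bunch_size > 1), the
% volumes are processed in chunks of that size and the remainder one by
% one; timings are optionally accumulated into the statistics object.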
if (~exist('apply_stats', 'var') || isempty(apply_stats))
apply_stats = true;
end
proj_bls = gtMathsGetSameSizeZeros(self.blobs);
timing_fp = 0;
timing_sb = 0;
timing_sf_fp = 0;
num_geoms = self.get_number_geometries();
chunk_size = self.num_gpus * self.jobs_bunch_size;
if (chunk_size > 1)
chunk_end = num_geoms - rem(num_geoms, chunk_size);
for n = 1:chunk_size:chunk_end
inds = n:(n+chunk_size-1);
[proj_bls, fp_time, sb_time, sf_fp_time] = self.fwd_project_volume( ...
proj_bls, volumes(inds), inds);
timing_fp = timing_fp + fp_time;
timing_sb = timing_sb + sb_time;
timing_sf_fp = timing_sf_fp + sf_fp_time;
end
else
chunk_end = 0;
end
for n = (chunk_end+1):num_geoms
[proj_bls, fp_time, sb_time, sf_fp_time] = self.fwd_project_volume( ...
proj_bls, volumes{n}, n);
timing_fp = timing_fp + fp_time;
timing_sb = timing_sb + sb_time;
timing_sf_fp = timing_sf_fp + sf_fp_time;
end
if (apply_stats)
self.statistics.add_timestamp(timing_fp, 'cp_dual_update_detector', 'cp_dual_detector_FP');
self.statistics.add_timestamp(timing_sb, 'cp_dual_update_detector', 'cp_dual_detector_SB');
self.statistics.add_timestamp(timing_sf_fp, 'cp_dual_update_detector', 'cp_dual_detector_SF');
end
end
function nextEnhancedSolution = compute_bwd_projection(self, nextEnhancedSolution, p_blurr, compute_update_primal, tau)
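% Back-projects the (PSF-blurred) detector dual into volume space and
% immediately applies the primal update to each volume, using the same
% GPU/job chunking as compute_fwd_projection; the per-phase timings are
% accumulated into the statistics object.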
timing_bp = 0;
timing_bs = 0;
timing_corr = 0;
timing_app = 0;
timing_sf_bp = 0;
num_geoms = self.get_number_geometries();
chunk_size = self.num_gpus * self.jobs_bunch_size;
if (chunk_size > 1)
chunk_end = num_geoms - rem(num_geoms, chunk_size);
for n = 1:chunk_size:chunk_end
inds = n:(n+chunk_size-1);
[v, bp_time, bs_time, sf_bp_time] = self.bwd_project_volume(p_blurr, inds);
timing_bp = timing_bp + bp_time;
timing_bs = timing_bs + bs_time;
timing_sf_bp = timing_sf_bp + sf_bp_time;
for ii_gpu = 1:chunk_size
% Computing the update to apply
ii_v = inds(ii_gpu);
up_prim = [compute_update_primal(ii_v), v(ii_gpu)];
[self.currentSolution{ii_v}, nextEnhancedSolution{ii_v}, app_time, corr_time] ...
= self.update_primal(self.currentSolution{ii_v}, ...
nextEnhancedSolution{ii_v}, up_prim, tau{ii_v});
timing_corr = timing_corr + corr_time;
timing_app = timing_app + app_time;
end
end
else
chunk_end = 0;
end
for n = (chunk_end+1):num_geoms
[v, bp_time, bs_time, sf_bp_time] = self.bwd_project_volume(p_blurr, n);
timing_bp = timing_bp + bp_time;
timing_bs = timing_bs + bs_time;
timing_sf_bp = timing_sf_bp + sf_bp_time;
% Computing the update to apply
up_prim = [compute_update_primal(n), v];
[self.currentSolution{n}, nextEnhancedSolution{n}, app_time, corr_time] ...
= self.update_primal(self.currentSolution{n}, ...
nextEnhancedSolution{n}, up_prim, tau{n});
timing_corr = timing_corr + corr_time;
timing_app = timing_app + app_time;
end
self.statistics.add_timestamp(timing_bp, 'cp_primal_update', 'cp_primal_BP')
self.statistics.add_timestamp(timing_bs, 'cp_primal_update', 'cp_primal_BS')
self.statistics.add_timestamp(timing_sf_bp, 'cp_primal_update', 'cp_primal_SF')
self.statistics.add_timestamp(timing_corr, 'cp_primal_update', 'cp_primal_CORR')
self.statistics.add_timestamp(timing_app, 'cp_primal_update', 'cp_primal_APP')
end
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% Initialization functions
methods (Access = protected)
function initializeWeights(self)
c = tic();
fprintf(' + Projecting ones vols..');
self.fwd_weights = self.getRowsSum();
self.fwd_weights = self.apply_psf(self.fwd_weights);
fprintf('\b\b (%3.1f s)\n', toc(c));
fprintf(' + Backprojecting ones sinos (fake bproj)..');
c = tic();
num_geoms = self.get_number_geometries();
if (~isempty(self.geometries))
for ii = 1:num_geoms
self.bwd_weights{ii} = size(self.geometries{ii}, 1);
end
else
% If this happens, we are using oversampling with
% shape functions
for ii = 1:num_geoms
self.bwd_weights{ii} = numel(self.blobs);
end
end
fprintf('\b\b (%2.1f s)\n', toc(c));
end
function [sigma1, sigma1_1, sigma2, tau] = init_cp_weights(self, algo)
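% Diagonal CP step sizes (seemingly in the spirit of Pock & Chambolle's
% diagonal preconditioning): sigma1 is the inverse of the projector
% row-sums (fwd_weights), and tau is -1 over the back-projection
% weights plus one unit per extra dual term (1 for l1, 6 for the TV
% gradient, 1 for the optional ODF term), depending on 'algo'.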
sigma1 = cell(size(self.fwd_weights));
for n = 1:numel(sigma1)
sigma1{n} = 1 ./ (self.fwd_weights{n} + (self.fwd_weights{n} == 0));
end
sigma1_1 = sigma1;
for n = 1:numel(sigma1_1)
sigma1_1{n} = 1 ./ (1 + sigma1_1{n});
end
use_ODF = ~isempty(self.ODF);
sigma2 = use_ODF / prod(self.volume_geometry);
num_geoms = self.get_number_geometries();
tau = cell(size(self.bwd_weights));
switch (upper(algo))
case '6DLS'
for n = 1:num_geoms
tau{n} = - 1 ./ (self.bwd_weights{n} + use_ODF);
end
case '6DL1'
for n = 1:num_geoms
tau{n} = - 1 ./ (self.bwd_weights{n} + 1 + use_ODF);
end
case '6DTV'
for n = 1:num_geoms
tau{n} = - 1 ./ (self.bwd_weights{n} + 6 + use_ODF);
end
case '6DTVL1'
for n = 1:num_geoms
tau{n} = - 1 ./ (self.bwd_weights{n} + 6 + 1 + use_ODF);
end
otherwise
end
end
function [p, p1, q_l1, q_tv, q_odf] = init_cp_vars(self, algo)
if (isempty(self.currentDetDual))
p = gtMathsGetSameSizeZeros(self.blobs);
else
p = self.currentDetDual;
end
p1 = internal_cell_copy(self.currentSolution);
switch (upper(algo))
case '6DLS'
q_tv = [];
q_l1 = gtMathsGetSameSizeZeros(self.currentSolution{1});
case '6DL1'
q_tv = [];
q_l1 = gtMathsGetSameSizeZeros(self.currentSolution);
case '6DTV'
q_tv = gtMathsGetSameSizeZeros(self.currentSolution([1 1 1]));
q_l1 = [];
case '6DTVL1'
q_tv = gtMathsGetSameSizeZeros(self.currentSolution([1 1 1]));
q_l1 = gtMathsGetSameSizeZeros(self.currentSolution);
otherwise
end
if (isempty(self.ODF))
q_odf = [];
else
if (isempty(self.currentOdfDual))
q_odf = gtMathsGetSameSizeZeros(self.ODF);
else
q_odf = self.currentOdfDual;
end
end
end
function proj_data = apply_psf(self, proj_data)
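% Applies the detector point spread function to each projected blob via
% n-dimensional convolution ('same' size): a single PSF in self.psf is
% shared by all blobs, otherwise one PSF per blob is expected.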
if (~isempty(self.psf))
for n = 1:numel(proj_data)
if (numel(self.psf) == 1)
proj_data{n} = convn(proj_data{n}, self.psf{1}, 'same');
else
proj_data{n} = convn(proj_data{n}, self.psf{n}, 'same');
end
end
end
end
function calc_odf = compute_solution_ODF(self, volumes)
calc_odf = zeros(size(self.ODF), self.data_type);
for ii = 1:numel(volumes)
calc_odf(ii) = gtMathsSumNDVol(volumes{ii});
end
end
function [value, summed_vols] = compute_functional_value(self, p, algo, lambda)
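% Monitored functional value used for the residual plots: the inner
% product between the detector dual p and the measured blobs, plus the
% l2 norm of the projection residual for the LS/L1 variants, the l1
% norm of the gradient of the averaged summed volume for the TV
% variants, and lambda times the l1 norm of the solution for the l1
% variants.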
value = gtMathsDotProduct(p, self.blobs);
num_geoms = self.get_number_geometries();
if (any(strcmpi(algo, {'6DLS', '6DL1'})))
proj_bls = self.compute_fwd_projection(self.currentSolution, false);
proj_bls = internal_cell_sub_assign(proj_bls, self.blobs);
value = value + gtMathsNorm_l2(proj_bls);
end
summed_vols = gtMathsSumCellVolumes(self.currentSolution);
if (any(strcmpi(algo, {'6DTV', '6DTVL1'})))
sES = gtMathsGradient(summed_vols / num_geoms);
value = value + gtMathsNorm_l1(sES);
end
if (any(strcmpi(algo, {'6DL1', '6DTVL1'})))
value = value + lambda * gtMathsNorm_l1(self.currentSolution);
end
end
function initializeVariables(self, numIters)
switch (numel(numIters))
case 1
self.normResiduals = zeros(numIters, 1);
case 2
self.normResiduals(numIters(1):numIters(2)) = 0;
otherwise
self.normResiduals(numIters) = 0;
end
end
end
methods (Static, Access = public)
function proj_size = getProjSize(blobs)
proj_size = [size(blobs{1}, 1), size(blobs{1}, 3)];
for n = 2:numel(blobs)
if (any(proj_size ~= [size(blobs{n}, 1), size(blobs{n}, 3)]))
error('Gt6DBlobReconstructor:wrong_argument', ...
'Blob: %d is malformed!', n)
end
end
end
end
end