Skip to content
Snippets Groups Projects
Commit 91155818 authored by Nicola Vigano's avatar Nicola Vigano
Browse files

6D-Reconstructions: added a new detector data-fidelity term (KL), and two new...

6D-Reconstructions: added a new detector data-fidelity term (KL), and two new TV-norms (Isotropic TV and Nuclear TV)

Signed-off-by: default avatarNicola Vigano <nicola.vigano@esrf.fr>
parent 1b81d61e
No related branches found
No related tags found
No related merge requests found
......@@ -16,7 +16,11 @@ function par_rec = gtRecGrainsDefaultParameters(algo)
'lambda', 1e-2, 'volume_downscaling', 1, ...
'ospace_super_sampling', 1, 'rspace_super_sampling', 1, ...
'ospace_oversize', 1.1, 'rspace_oversize', 1.2, ...
'shape_functions_type', 'none');
'shape_functions_type', 'none', ...
'detector_norm', 'KL', ... % Possibilities are: {'KL'} | 'l2'
'tv_norm', 'l12', ... % Possibilities are: {'l12'} | 'l1' | 'ln'
'tv_strategy', 'groups' ... % Possibilities are: {'groups'} | 'volume'
);
par_rec = struct(...
'algorithm', upper(algo), 'num_iter', 100, 'list', [], ...
'options', par_6D_rec_opts);
......
......@@ -3,58 +3,66 @@ function rec_opts = gtReconstruct6DGetParamenters(parameters)
%
rec = parameters.rec;
if (isfield(rec, 'grains') && isfield(rec.grains, 'options') ...
&& ~isempty(rec.grains.options))
rec_opts = rec.grains.options;
if (~isfield(rec_opts, 'volume_downscaling') ...
|| isempty(rec_opts.volume_downscaling))
rec_opts.volume_downscaling = 1;
end
if (~isfield(rec_opts, 'ospace_super_sampling') ...
|| isempty(rec_opts.ospace_super_sampling))
if (~isfield(rec_opts, 'super_sampling') ...
|| isempty(rec_opts.super_sampling))
rec_opts.ospace_super_sampling = 1;
else
rec_opts.ospace_super_sampling = rec_opts.super_sampling;
end
end
if (~isfield(rec_opts, 'rspace_super_sampling') ...
|| isempty(rec_opts.rspace_super_sampling))
rec_opts.rspace_super_sampling = 1;
end
if (~isfield(rec_opts, 'ospace_oversize') ...
|| isempty(rec_opts.ospace_oversize))
rec_opts.ospace_oversize = 1.1;
end
if (~isfield(rec_opts, 'rspace_oversize') ...
|| isempty(rec_opts.rspace_oversize))
rec_opts.rspace_oversize = 1.2;
end
if (~isfield(rec_opts, 'use_predicted_scatter_ints') ...
|| isempty(rec_opts.use_predicted_scatter_ints))
rec_opts.use_predicted_scatter_ints = false;
end
if (~isfield(rec_opts, 'shape_functions_type') ...
|| isempty(rec_opts.shape_functions_type))
rec_opts.shape_functions_type = 'none';
end
else
if (~isfield(rec, 'grains'))
error([mfilename ':bad_structure'], ...
'The parameters are too old or malformed: there is no ".grains" in .rec!')
end
num_iter = rec.grains.num_iter;
if (~isfield(rec.grains, 'options') || isempty(rec.grains.options))
warning('gtReconstruct6DGetParamenters:wrong_parameters', ...
'The rec.grains structure doesn''t seem to be valid. Falling back to defaults')
rec_opts = struct( ...
'grid_edge', 7, 'num_interp', 1, 'lambda', 1e-1, ...
'super_sampling', 1, 'rspace_super_sampling', 1, ...
'volume_downscaling', 1, ...
'ospace_oversize', 1.1, 'rspace_oversize', 1.2, ...
'use_predicted_scatter_ints', false, ...
'shape_functions_type', 'none' );
end
if (isfield(rec, 'grains'))
rec_opts.num_iter = rec.grains.num_iter;
rec_opts.algorithm = rec.grains.algorithm;
else
rec_opts.num_iter = rec.num_iter;
rec_opts.algorithm = '6DL1';
rec.grains = gtRecGrainsDefaultParameters(rec.grains.algorithm);
end
rec_opts = rec.grains.options;
rec_opts.num_iter = num_iter;
rec_opts.algorithm = rec.grains.algorithm;
if (~isfield(rec_opts, 'volume_downscaling') ...
|| isempty(rec_opts.volume_downscaling))
rec_opts.volume_downscaling = 1;
end
if (~isfield(rec_opts, 'ospace_super_sampling') ...
|| isempty(rec_opts.ospace_super_sampling))
if (~isfield(rec_opts, 'super_sampling') ...
|| isempty(rec_opts.super_sampling))
rec_opts.ospace_super_sampling = 1;
else
rec_opts.ospace_super_sampling = rec_opts.super_sampling;
end
end
if (~isfield(rec_opts, 'rspace_super_sampling') ...
|| isempty(rec_opts.rspace_super_sampling))
rec_opts.rspace_super_sampling = 1;
end
if (~isfield(rec_opts, 'ospace_oversize') ...
|| isempty(rec_opts.ospace_oversize))
rec_opts.ospace_oversize = 1.1;
end
if (~isfield(rec_opts, 'rspace_oversize') ...
|| isempty(rec_opts.rspace_oversize))
rec_opts.rspace_oversize = 1.2;
end
if (~isfield(rec_opts, 'use_predicted_scatter_ints') ...
|| isempty(rec_opts.use_predicted_scatter_ints))
rec_opts.use_predicted_scatter_ints = false;
end
if (~isfield(rec_opts, 'shape_functions_type') ...
|| isempty(rec_opts.shape_functions_type))
rec_opts.shape_functions_type = 'none';
end
if (~isfield(rec_opts, 'detector_norm') ...
|| isempty(rec_opts.detector_norm))
rec_opts.detector_norm = 'KL';
end
if (~isfield(rec_opts, 'tv_norm') ...
|| isempty(rec_opts.tv_norm))
rec_opts.detector_norm = 'l12';
end
if (~isfield(rec_opts, 'tv_strategy') ...
|| isempty(rec_opts.tv_strategy))
rec_opts.tv_strategy, 'groups';
end
end
......@@ -15,9 +15,14 @@ function algo = gtReconstruct6DLaunchAlgorithm(sampler, rec_opts, parameters, va
algo = rec_factory.getReconstructionAlgo(sampler, rec_opts.num_interp);
% Adding extra parameters/constraints
% Adding extra constraints
algo.ODF = conf.ODF;
% Adding extra parameters
algo.detector_norm = rec_opts.detector_norm;
algo.tv_norm = rec_opts.tv_norm;
algo.tv_strategy = rec_opts.tv_strategy;
geom_memory = sum(arrayfun(@(x)x.get_memory_consumption_geometry(), sampler));
grain_memory = sum(arrayfun(@(x)x.get_memory_consumption_graindata(), sampler));
......
......@@ -19,7 +19,9 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector
verbose = false;
detector_norm = 'KL'; % Possibilities are: {'KL'} | 'l2'
tv_norm = 'l12'; % Possibilities are: {'l12'} | 'l1'
tv_norm = 'l12'; % Possibilities are: {'l12'} | 'l1' | 'ln'
tv_strategy = 'groups'; % Possibilities are: {'groups'} | 'volume'
orientation_groups = [];
algo_ops_c_functions = true;
......@@ -46,6 +48,7 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector
self.statistics.add_task('cp_dual_update_l1', 'CP Dual variable (l1) update');
self.statistics.add_task('cp_dual_update_tv', 'CP Dual variable (TV) update');
self.statistics.add_task_partial('cp_dual_update_tv', 'cp_dual_tv_reduction', 'Volumes reduction');
self.statistics.add_task_partial('cp_dual_update_tv', 'cp_dual_tv_proximal', 'Proximal application');
self.statistics.add_task_partial('cp_dual_update_tv', 'cp_dual_tv_gradient', 'Gradient');
self.statistics.add_task_partial('cp_dual_update_tv', 'cp_dual_tv_divergence', 'Divergence');
self.statistics.add_task('cp_dual_update_ODF', 'CP Dual variable (ODF) update');
......@@ -145,6 +148,10 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector
function cp(self, algo, numIters, lambda)
self.statistics.clear();
do_tv_update = ~isempty(strfind(upper(algo), 'TV'));
do_l1_update = ~isempty(strfind(upper(algo), 'L1'));
do_odf_update = ~isempty(self.ODF);
self.initializeVariables(numIters);
sample_rate = 5;
......@@ -156,7 +163,17 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector
[p, nextEnhancedSolution, q_l1, q_tv, q_odf] = self.init_cp_vars(algo);
fprintf('Done (%g seconds).\n', toc(c))
fprintf('Reconstruction: ')
fprintf('Reconstruction using algorithm: %s\n', upper(algo));
fprintf(' - Detector data-fidelity term: %s\n', self.detector_norm);
if (do_l1_update)
fprintf(' - l1-term lambda: %g\n', lambda);
end
if (do_tv_update)
fprintf(' - TV norm: %s\n', self.tv_norm);
fprintf(' - TV strategy: %s\n', self.tv_strategy);
end
fprintf('Iteration: ');
c = tic();
switch (numel(numIters))
case 1
......@@ -167,10 +184,6 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector
iters = numIters;
end
do_tv_update = ~isempty(strfind(upper(algo), 'TV'));
do_l1_update = ~isempty(strfind(upper(algo), 'L1'));
do_odf_update = ~isempty(self.ODF);
tot_iters = numel(iters);
for ii = iters
numchars = fprintf('%03d/%03d (in %f s)', ii-min(iters), tot_iters, toc(c));
......@@ -205,9 +218,9 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector
case '6DL1'
compute_update_primal = @(ii){lambda * q_l1{ii}};
case '6DTV'
compute_update_primal = @(ii){mdiv_tv};
compute_update_primal = @(ii){get_divergence(self, mdiv_tv, ii)};
case '6DTVL1'
compute_update_primal = @(ii){lambda * q_l1{ii}, mdiv_tv};
compute_update_primal = @(ii){lambda * q_l1{ii}, get_divergence(self, mdiv_tv, ii)};
end
% Computing update dual ODF
......@@ -311,43 +324,111 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector
%%% Primal and Duals updating functions
methods (Access = public)
function [q, mdivq] = update_dual_TV(self, q, new_enh_sol)
sigma = 1 ./ (2 * numel(new_enh_sol));
switch (self.tv_strategy)
case 'volume'
num_groups = 1;
or_ranges = {1:self.orientation_groups(end)};
sigmas = 1 ./ (2 * numel(new_enh_sol));
case 'groups'
num_groups = size(self.orientation_groups, 1);
or_ranges = cell(num_groups, 1);
sigmas = 1 ./ (2 .* (self.orientation_groups(:, 2) - self.orientation_groups(:, 1) + 1));
for ii_g = 1:num_groups
or_ranges{ii_g} = self.orientation_groups(ii_g, 1):self.orientation_groups(ii_g, 2);
end
end
c = tic();
sES = gtMathsSumCellVolumes(new_enh_sol);
self.statistics.add_timestamp(toc(c), 'cp_dual_update_tv', 'cp_dual_tv_reduction');
reduction_time = 0;
gradient_time = 0;
proximal_time = 0;
c = tic();
dsES = gtMathsGradient(sES);
self.statistics.add_timestamp(toc(c), 'cp_dual_update_tv', 'cp_dual_tv_gradient');
dsES = cell(num_groups, 3);
mdivq = cell(num_groups, 1);
for ii_g = 1:num_groups
c = tic();
sES = gtMathsSumCellVolumes(new_enh_sol(or_ranges{ii_g}));
reduction_time = reduction_time + toc(c);
c = tic();
dsES(ii_g, :) = gtMathsGradient(sES);
gradient_time = gradient_time + toc(c);
c = tic();
for ii_d = 1:3
q{ii_g, ii_d} = q{ii_g, ii_d} + sigmas(ii_g) * dsES{ii_g, ii_d};
end
proximal_time = proximal_time + toc(c);
end
c = tic();
switch (self.tv_norm)
case 'l12'
for n = 1:numel(q)
q{n} = q{n} + sigma * dsES{n};
end
grad_l2 = sqrt(q{1} .^ 2 + q{2} .^ 2 + q{3} .^ 2);
for ii_g = 1:num_groups
grad_l2 = sqrt(q{ii_g, 1} .^ 2 + q{ii_g, 2} .^ 2 + q{ii_g, 3} .^ 2);
for n = 1:numel(q)
q{n} = q{n} ./ max(1, grad_l2);
for ii_d = 1:3
q{ii_g, ii_d} = q{ii_g, ii_d} ./ max(1, grad_l2);
end
end
case 'l1'
if (self.algo_ops_c_functions)
% Using the l1 update function because TV is the l1 of the
% gradient
q = gt6DUpdateDualL1_c(q, dsES, sigma, self.num_threads);
else
for n = 1:numel(q)
q{n} = q{n} + sigma * dsES{n};
q{n} = q{n} ./ max(1, abs(q{n}));
for n = 1:numel(q)
q{n} = q{n} ./ max(1, abs(q{n}));
end
case 'ln'
size_vols = size(q{1});
num_voxels = numel(q{1});
z = zeros(num_voxels, 3, num_groups, self.data_type);
for ii_g = 1:num_groups
for ii_d = 1:3
z(:, ii_d, ii_g) = reshape(q{ii_g, ii_d}, [], 1);
end
end
z = permute(z, [3 2 1]);
ZZ = gtMathsMatrixProduct(permute(z, [2 1 3]), z);
[es, Vs] = gtMathsEig3x3SymmPosDef(ZZ);
svals = sqrt(es);
Vts = permute(Vs, [2 1 3]);
svals_p = sign(svals) .* min(1, abs(svals));
svals_d = 1 ./ svals;
svals_d(svals == 0) = 0;
sigmas = svals_d .* svals_p;
sigmas = bsxfun(@times, Vs, sigmas);
sigmas = gtMathsMatrixProduct(sigmas, Vts);
z = gtMathsMatrixProduct(z, sigmas);
z = permute(z, [3 2 1]);
for ii_g = 1:num_groups
for ii_d = 1:3
q{ii_g, ii_d} = reshape(z(:, ii_d, ii_g), size_vols);
end
end
end
proximal_time = proximal_time + toc(c);
c = tic();
mdivq = - gtMathsDivergence(q);
self.statistics.add_timestamp(toc(c), 'cp_dual_update_tv', 'cp_dual_tv_divergence');
for ii_g = 1:num_groups
mdivq{ii_g} = - gtMathsDivergence(q(ii_g, :));
end
divergence_time = toc(c);
self.statistics.add_timestamp(reduction_time, 'cp_dual_update_tv', 'cp_dual_tv_reduction');
self.statistics.add_timestamp(proximal_time, 'cp_dual_update_tv', 'cp_dual_tv_proximal');
self.statistics.add_timestamp(gradient_time, 'cp_dual_update_tv', 'cp_dual_tv_gradient');
self.statistics.add_timestamp(divergence_time, 'cp_dual_update_tv', 'cp_dual_tv_divergence');
end
function mdivq_ii = get_divergence(self, mdivq, ii)
num_groups = size(self.orientation_groups, 1);
if (strcmpi(self.tv_strategy, 'volume') || (num_groups == 1))
ind = 1;
elseif (strcmpi(self.tv_strategy, 'groups'))
ind = find( ...
(ii >= self.orientation_groups(:, 1)) ...
& (ii <= self.orientation_groups(:, 2)), 1);
end
mdivq_ii = mdivq{ind};
end
function p = update_dual_detector(self, p, proj_bls, sigma, sigma_1)
......@@ -599,22 +680,30 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector
p = self.currentDetDual;
end
do_tv_update = ~isempty(strfind(upper(algo), 'TV'));
do_l1_update = ~isempty(strfind(upper(algo), 'L1'));
do_ls_update = ~isempty(strfind(upper(algo), 'LS'));
p1 = internal_cell_copy(self.currentSolution);
switch (upper(algo))
case '6DLS'
q_tv = [];
q_l1 = gtMathsGetSameSizeZeros(self.currentSolution{1});
case '6DL1'
q_tv = [];
q_l1 = gtMathsGetSameSizeZeros(self.currentSolution);
case '6DTV'
q_tv = gtMathsGetSameSizeZeros(self.currentSolution([1 1 1]));
q_l1 = [];
case '6DTVL1'
if (do_l1_update)
q_l1 = gtMathsGetSameSizeZeros(self.currentSolution);
elseif (do_ls_update)
q_l1 = gtMathsGetSameSizeZeros(self.currentSolution{1});
else
q_l1 = [];
end
if (do_tv_update)
if (strcmpi(self.tv_strategy, 'groups'))
num_groups = size(self.orientation_groups, 1);
q_tv = gtMathsGetSameSizeZeros(self.currentSolution(ones(num_groups, 3)));
q_tv = reshape(q_tv, num_groups, 3);
else
q_tv = gtMathsGetSameSizeZeros(self.currentSolution([1 1 1]));
q_l1 = gtMathsGetSameSizeZeros(self.currentSolution);
otherwise
end
else
q_tv = [];
end
if (isempty(self.ODF))
......
function d = gtMathsDet3x3SymmPosDef(m)
% d = m(1, 1, :) .* (m(2, 2, :) .* m(3, 3, :) - m(2, 3, :) .^ 2) ...
% - m(1, 2, :) .* (m(2, 1, :) .* m(3, 3, :) - m(2, 3, :) .* m(3, 1, :)) ...
% + m(1, 3, :) .* (m(2, 1, :) .* m(3, 2, :) - m(2, 2, :) .* m(3, 1, :));
mdiag = gtMathsDiag3x3(m);
moffdiag = cat(1, m(2, 3, :), m(1, 3, :), m(1, 2, :));
d = prod(mdiag, 1) + 2 .* prod(moffdiag, 1) - sum(mdiag .* (moffdiag .^ 2), 1);
end
\ No newline at end of file
function D = gtMathsDiag3x3(D, d_in)
[size_D(1), size_D(2), size_D(3)] = size(D);
if (any(size_D(1:2) == 1)) % vector to matrix
D_out(1:4:9, 1:size_D(3)) = D;
D = reshape(D_out, 3, 3, []);
elseif (nargin > 1) % Matrix with vector that goes on the diag
D = reshape(D, 9, []);
D(1:4:9, 1:size_D(3)) = d_in;
D = reshape(D, 3, 3, []);
else % matrix to vector
D = reshape(D, 9, []);
D = D(1:4:9, 1:size_D(3));
D = reshape(D, 3, 1, []);
end
end
function [e, V] = gtMathsEig3x3SymmPosDef(M)
num_matrs = size(M, 3);
e = zeros(1, 3, num_matrs, class(M));
diag_M = gtMathsDiag3x3(M);
diag_M = reshape(diag_M, 1, 3, []);
% Trace
trM = sum(diag_M, 2);
M23_M13 = M([2 1], 3, :);
M12_2 = M(1, 2, :) .^ 2;
p1 = sum(M23_M13 .^ 2, 1) + M12_2;
trivials = p1 == 0;
if (any(trivials))
e(1, :, trivials) = diag_M(1, :, trivials);
end
others = ~trivials;
if (any(others))
q_3 = trM(1, 1, others);
q = q_3 ./ 3;
diag_M_minus_q = diag_M(1, :, others) - q(1, [1 1 1], :);
tr_M_minusqI_2 = sum(diag_M_minus_q .^ 2, 2);
p2 = tr_M_minusqI_2 + 2 .* p1(1, 1, others);
p = sqrt(p2 ./ 6);
p_1 = 1 ./ (p + (p == 0));
M_minusqI = M(:, :, others);
M_minusqI = gtMathsDiag3x3(M_minusqI, diag_M_minus_q);
B = bsxfun(@times, p_1, M_minusqI);
r = gtMathsDet3x3SymmPosDef(B) / 2;
phi = acos(r) / 3;
phi(r <= -1) = pi / 3;
phi(r >= 1) = 0;
p_2 = 2 .* p;
e(1, 1, others) = q + p_2 .* cos(phi);
e(1, 3, others) = q + p_2 .* cos(phi + (2 * pi / 3));
e(1, 2, others) = q_3 - e(1, 1, others) - e(1, 3, others);
end
V = zeros(3, 3, num_matrs, 'like', M);
M23M12 = M(2, 3, :) .* M(1, 2, :);
M12M13 = M(1, 2, :) .* M(1, 3, :);
M22 = bsxfun(@minus, M(2, 2, :), e);
M11 = bsxfun(@minus, M(1, 1, :), e);
V(1, :, :) = bsxfun(@minus, bsxfun(@times, M(1, 3, :), M22), M23M12);
V(2, :, :) = bsxfun(@minus, bsxfun(@times, M(2, 3, :), M11), M12M13);
V(3, :, :) = bsxfun(@minus, M12_2, M11 .* M22);
norm_V = sqrt(sum(V .^ 2, 1));
norm_V = 1 ./ (norm_V + (norm_V == 0));
V = bsxfun(@times, V, norm_V);
end
function C = gtMathsMatrixProduct(A, B)
[size_A(1), size_A(2), size_A(3)] = size(A);
[size_B(1), size_B(2), size_B(3)] = size(B);
A = reshape(A, size_A(1), size_A(2), 1, size_A(3));
B = reshape(B, 1, size_B(1), size_B(2), size_B(3));
C = bsxfun(@times, A, B);
C = sum(C, 2);
C = reshape(C, size_A(1), size_B(2), size_A(3));
end
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment