classdef GtGrainODFwSolver < handle
% GtGrainODFwSolver  Solves for intensity weights on a grid of sampled
% orientations (the grain ODF) that reproduce the omega ('w') intensity
% profiles of the grain's selected diffraction blobs.
%
% Typical use:
%   solver = GtGrainODFwSolver(parameters, 'shape_functions_type', 'w');
%   vol = solver.solve(ref_gr, 'algorithm', 'cplsnn');
    properties
        % Experiment parameters, passed to GtOrientationSampling and to
        % gtAcqGetOmegaStep
        parameters;

        % GtOrientationSampling object holding the orientation grid
        sampler;

        % Reconstructed orientation weights (column vector) and grid shape
        volume = [];
        size_volume = [];

        % Measured sinogram (column vector) and its [num_ws, num_blobs] shape
        sino = [];
        size_sino = [];

        % First sinogram row of each blob's omega profile
        pre_paddings;

        % Sparse projection matrix (sinogram entries x orientations) and
        % its transpose
        S;
        St;

        % Diagonal weights used by fw/bw, and the primal step weights set
        % by the CP-L1 solvers
        S__ws;
        St_ws;
        tau_ws;

        % Shape-function model: 'none' | 'w' ('nw2uvw' / 'uvw' unsupported)
        shape_functions_type = 'none';
        shape_functions = [];

        verbose = false;

        data_type = 'double';

        % Passed to make_grid_estim_ODF_resoluion when building the grid
        ospace_oversize = 1;
        ospace_oversampling = 1;

        % Number of iterations performed by the solve_* methods.
        % NOTE(review): every solver reads self.num_iter but it was not
        % declared; the default of 100 is a placeholder - confirm.
        num_iter = 100;
    end

    methods (Access = public)
function self = GtGrainODFwSolver(parameters, varargin)
% FUNCTION self = GtGrainODFwSolver(parameters, varargin)
%
% INPUT:
%   parameters : experiment parameters, later handed to the orientation
%                sampler and used for the omega-step lookup
%   varargin   : property/value pairs overriding the default properties
%                (e.g. 'shape_functions_type', 'verbose', 'num_iter')
    self.parameters = parameters;
    self = parse_pv_pairs(self, varargin);
end
function conf = initialize(self, ref_gr, varargin)
% FUNCTION conf = initialize(self, ref_gr, varargin)
%
% Builds, in order, the orientation sampling around ref_gr, the measured
% sinogram and the projection matrices.
%
% INPUT (varargin):
%   'algorithm': {'cplsnn'} | 'sirt' | 'cplsl1nn' | 'cpl1nn'
%   'lambda': 1e-2
%   'det_index': 1
%
% OUTPUT:
%   conf : the options structure after parsing varargin
    conf = struct( ...
        'algorithm', 'cplsnn', ...
        'lambda', 1e-2, ...
        'det_index', 1 );
    [conf, ~] = parse_pv_pairs(conf, varargin);

    self.build_orientation_sampling(ref_gr, conf.det_index);
    self.build_sinogram();
    self.build_projection_matrices();
end
function vol = solve(self, ref_gr, varargin)
% FUNCTION vol = solve(self, ref_gr, varargin)
%
% Initializes the problem and runs the requested reconstruction algorithm.
%
% INPUT (varargin): same as initialize(self, ref_gr, varargin)
%
% OUTPUT:
%   vol : reconstructed orientation weights, shaped like the orientation grid
%
% Throws 'GtGrainODFwSolver:wrong_argument' for an unknown 'algorithm'.
    conf = self.initialize(ref_gr, varargin{:});

    switch (lower(conf.algorithm))
        case 'sirt'
            self.solve_sirt();
        case 'cplsnn'
            self.solve_cplsnn();
        case 'cplsl1nn'
            self.solve_cplsl1nn(conf.lambda);
        case 'cpl1nn'
            self.solve_cpl1nn();
        otherwise
            % Without this, an unknown name would silently skip solving and
            % get_volume would fail (or return a stale result)
            error('GtGrainODFwSolver:wrong_argument', ...
                'Unknown algorithm: ''%s''', conf.algorithm);
    end

    vol = self.get_volume();
end
end
methods (Access = public) % Low Level API
function build_orientation_sampling(self, ref_gr, det_index)
% FUNCTION build_orientation_sampling(self, ref_gr, det_index)
%
% Creates the orientation sampler around ref_gr for detector det_index
% (default: 1), builds its 'cubic' orientation grid and, for
% shape_functions_type 'w', computes the omega shape functions.
% Sets self.sampler, self.size_volume and self.shape_functions.
    if (~exist('det_index', 'var') || isempty(det_index))
        det_index = 1;
    end

    % NOTE: 'verbose', self.verbose is deliberately not forwarded here
    self.sampler = GtOrientationSampling(self.parameters, ref_gr, ...
        'detector_index', det_index);
    self.sampler.make_grid_estim_ODF_resoluion('cubic', ...
        -self.ospace_oversampling, self.ospace_oversize);

    self.size_volume = size(self.sampler.lattice.gr);

    switch (self.shape_functions_type)
        case 'none'
            self.shape_functions = [];
        case 'w'
            self.shape_functions = gtDefShapeFunctionsFwdProj(self.sampler, ...
                'shape_function_type', 'w', ...
                'data_type', self.data_type);
            % Alternative: gtDefShapeFunctionsCreateW(self.sampler)
    end
end
function build_sinogram(self)
% FUNCTION build_sinogram(self)
%
% Builds the measured sinogram. It has one column per selected blob,
% holding the blob's masked intensity summed over its first two
% dimensions, as a function of the third (omega) dimension. Every column
% is zero-padded to a common height large enough for both the blob and
% the omega range spanned by the sampled orientations.
% Sets self.sino (column vector), self.size_sino ([num_ws, num_blobs])
% and self.pre_paddings (first row of each blob's profile).
    sel_ref = self.sampler.selected;
    bls = self.sampler.ref_gr.proj(self.sampler.detector_index).bl(sel_ref);
    num_blobs = numel(bls);

    blob_depths = arrayfun(@(x)size(x.intm, 3), bls);
    blob_depths = reshape(blob_depths, [], 1);

    blob_lims = cat(1, bls(:).bbwim);

    % Omega extent reached by the projections of the orientation grid
    switch (self.shape_functions_type)
        case 'w'
            % Reverse loop so that proj_lims is allocated on first pass
            for ii = numel(self.shape_functions):-1:1
                proj_lims(:, :, ii) = cat(1, self.shape_functions{ii}(:).bbwim);
            end
            proj_lims = [min(proj_lims(:, 1, :), [], 3), max(proj_lims(:, 2, :), [], 3)];
            delta_omegas = proj_lims(:, 2) - proj_lims(:, 1) + 1;
        case 'none'
            [delta_omegas, proj_lims] = self.sampler.get_omega_deviations();
            delta_omegas = delta_omegas(sel_ref, :);
            proj_lims = proj_lims(sel_ref, :);
    end
    chosen_depths = max(blob_depths, delta_omegas);

    num_ws = max(chosen_depths) + 2;

    self.sino = zeros(num_ws, num_blobs, self.data_type);

    % In the case that the sampled orientations determine the shift, we
    % have to take it into account!
    additional_shift = blob_lims(:, 1) - proj_lims(:, 1);
    additional_shift(additional_shift < 0) = 0;

    self.pre_paddings = floor((num_ws - chosen_depths) / 2) + 1 + additional_shift;

    for ii_b = 1:num_blobs
        ints_interval = self.pre_paddings(ii_b):(self.pre_paddings(ii_b) + blob_depths(ii_b) -1);

        masked_blob = bls(ii_b).intm;
        masked_blob(~bls(ii_b).mask) = 0;

        self.sino(ints_interval, ii_b) = squeeze(sum(sum(masked_blob, 1), 2));
    end

    % Record the 2-D shape before flattening: accessors and sub2ind rely on it
    self.size_sino = [num_ws, num_blobs];
    self.sino = reshape(self.sino, [], 1);

    if (self.verbose)
        fprintf('Sino size: [%d, %d]\n', self.size_sino);
    end
end
function sino = get_sinogram(self)
% FUNCTION sino = get_sinogram(self)
%
% OUTPUT:
%   sino : the measured sinogram, reshaped from its stored column-vector
%          form to self.size_sino
    sino_shape = self.size_sino;
    sino = reshape(self.sino, sino_shape);
end
function comp_sino = get_computed_sinogram(self)
% FUNCTION comp_sino = get_computed_sinogram(self)
%
% Forward projects the current reconstruction, shaped like get_sinogram(),
% for comparison with the measured data.
%
% Throws an error if no reconstruction is stored (self.volume is empty).
    if (isempty(self.volume))
        % NOTE(review): identifier says 'GtGrainODFSolver', not
        % 'GtGrainODFwSolver'; kept unchanged in case callers match on it.
        error('GtGrainODFSolver:no_reconstruction', ...
            'No reconstruction performed!');
    end

    comp_sino = reshape(self.fp(self.volume), self.size_sino);
end
function vol = get_volume(self)
% FUNCTION vol = get_volume(self)
%
% OUTPUT:
%   vol : the stored solution vector reshaped to the orientation-grid
%         shape (self.size_volume)
    vol_shape = self.size_volume;
    vol = reshape(self.volume, vol_shape);
end
function orients = get_orientations(self)
% FUNCTION orients = get_orientations(self)
%
% OUTPUT:
%   orients : cell array of the sampled orientations, as returned by the
%             sampler, reshaped to the orientation grid (size_volume)
%
% The output is not called `or`, to avoid shadowing the builtin function.
    grid_gr = self.sampler.get_orientations();
    orients = reshape(grid_gr, self.size_volume);
end
function r_vecs = get_R_vectors(self)
% FUNCTION r_vecs = get_R_vectors(self)
%
% OUTPUT:
%   r_vecs : the R_vector field of every sampled orientation, concatenated
%            along the first dimension in the sampler's orientation order
%            (one row each if R_vector is stored as a row vector)
    grid_gr = self.sampler.get_orientations();
    r_vecs = [grid_gr{:}];
    r_vecs = cat(1, r_vecs(:).R_vector);
end
function build_projection_matrices(self)
% FUNCTION build_projection_matrices(self)
%
% Builds the projection matrix self.S and its transpose self.St with the
% builder matching self.shape_functions_type, and reports the elapsed time.
%
% Throws 'GtGrainODFwSolver:wrong_argument' for an unknown type, instead
% of silently leaving S/St unbuilt.
    fprintf('Computing projection matrices..')
    c = tic();
    switch (self.shape_functions_type)
        case 'none'
            self.build_projection_matrices_sf_none();
        case 'w'
            self.build_projection_matrices_sf_w();
        case {'nw2uvw', 'uvw'}
            self.build_projection_matrices_sf_uvw();
        otherwise
            fprintf('\n')
            error('GtGrainODFwSolver:wrong_argument', ...
                'Unknown shape functions type: ''%s''', self.shape_functions_type);
    end
    fprintf('\b\b: Done in %f seconds.\n', toc(c));
end
function build_projection_matrices_sf_none(self)
% FUNCTION build_projection_matrices_sf_none(self)
%
% Builds self.S / self.St without shape functions. Each orientation
% projects to a single, generally fractional, omega position per selected
% blob (omega / omega step). Its unit weight is split linearly between
% the two neighbouring sinogram rows.
% Rows outside a blob's sinogram column are silently dropped. The upper
% neighbour is skipped when its weight is below eps('single').
    bls = self.sampler.ref_gr.proj(self.sampler.detector_index).bl(self.sampler.selected);
    num_ws = self.get_num_ws();

    % Image number of the first and last sinogram row of each blob column
    bls_bbws = cat(1, bls(:).bbwim);
    min_conds = bls_bbws(:, 1) - self.pre_paddings + 1;
    max_conds = min_conds + num_ws - 1;

    b_ws = [];
    b_cs = zeros(0, 1, self.data_type);
    b_is = [];
    b_os = [];

    grid_gr = self.sampler.get_orientations();
    num_orients = numel(grid_gr);

    om_step = gtAcqGetOmegaStep(self.parameters, self.sampler.detector_index);

    for ii_o = 1:num_orients
        ab = grid_gr{ii_o}.allblobs;
        ws = ab.omega(self.sampler.selected) / om_step;

        % Linear interpolation between the two neighbouring omega rows
        min_ws = floor(ws);
        max_ws = min_ws + 1;
        max_cs = ws - min_ws;
        min_cs = 1 - max_cs;

        ok_mins = (min_ws >= min_conds) & (min_ws <= max_conds);
        ok_maxs = (max_ws >= min_conds) & (max_ws <= max_conds) & (max_cs > eps('single'));

        indx_mins = find(ok_mins);
        indx_maxs = find(ok_maxs);

        b_ws = [b_ws; ...
            min_ws(indx_mins) - min_conds(indx_mins) + 1; ...
            max_ws(indx_maxs) - min_conds(indx_maxs) + 1]; %#ok<AGROW>
        b_cs = [b_cs; ...
            min_cs(indx_mins); max_cs(indx_maxs)]; %#ok<AGROW>
        % Blob (sinogram column) index of each entry
        b_is = [b_is; indx_mins; indx_maxs]; %#ok<AGROW>
        b_os = [b_os; ...
            ii_o(ones(numel(indx_mins) + numel(indx_maxs), 1), 1)]; %#ok<AGROW>
    end

    sino_indx = sub2ind(self.size_sino, b_ws, b_is);
    self.S = sparse( ...
        sino_indx, b_os, b_cs, ...
        numel(self.sino), num_orients);
    self.St = self.S';
end
function build_projection_matrices_sf_w(self)
% FUNCTION build_projection_matrices_sf_w(self)
%
% Builds self.S / self.St from the omega ('w') shape functions. For every
% sampled orientation and selected blob, the shape function's intensities
% (intm) go into the sinogram rows covering its omega range (bbwim).
%
% Throws 'GtGrainODFwSolver:out_of_range' when a shape function falls
% outside the omega range reserved for its blob in the sinogram. Such
% entries would otherwise make sub2ind fail with a less informative error.
    bls = self.sampler.ref_gr.proj(self.sampler.detector_index).bl(self.sampler.selected);
    num_ws = self.get_num_ws();
    num_blobs = self.get_num_blobs();

    % Image number of the first and last sinogram row of each blob column
    bls_bbws = cat(1, bls(:).bbwim);
    min_conds = bls_bbws(:, 1) - self.pre_paddings + 1;
    max_conds = min_conds + num_ws - 1;

    b_ws = [];
    b_cs = zeros(0, 1, self.data_type);
    b_is = [];
    b_os = [];

    num_orients = numel(self.shape_functions);
    for ii_o = 1:num_orients
        sf = self.shape_functions{ii_o};

        cs = cat(1, sf(:).intm);
        cs = reshape(cs, [], 1);

        ws = cell(num_blobs, 1);
        is = cell(num_blobs, 1);
        for ii_b = 1:num_blobs
            ws{ii_b} = sf(ii_b).bbwim(1):sf(ii_b).bbwim(2);
            is{ii_b} = ii_b(ones(sf(ii_b).bbsize(1), 1));
        end
        ws = reshape([ws{:}], [], 1);
        is = cat(1, is{:});

        wrong = (ws < min_conds(is)) | (ws > max_conds(is));
        if (any(wrong))
            bad_blobs = unique(is(wrong));
            error('GtGrainODFwSolver:out_of_range', ...
                ['Shape functions of orientation %d fall outside the ' ...
                'sinogram omega range of blob(s): %s'], ...
                ii_o, sprintf('%d ', bad_blobs));
        end

        b_ws = [b_ws; (ws - min_conds(is) + 1)]; %#ok<AGROW>
        b_cs = [b_cs; cs]; %#ok<AGROW>
        b_is = [b_is; is]; %#ok<AGROW>
        b_os = [b_os; ii_o(ones(numel(cs), 1), 1)]; %#ok<AGROW>
    end

    sino_indx = sub2ind(self.size_sino, b_ws, b_is);
    self.S = sparse( ...
        sino_indx, b_os, b_cs, ...
        numel(self.sino), num_orients);
    self.St = self.S';
end
function build_projection_matrices_sf_nw(self)
% FUNCTION build_projection_matrices_sf_nw(self)
%
% Placeholder: eta-omega shape functions are not implemented, so this
% always throws 'GtGrainODFwSolver:wrong_argument'.
    error('GtGrainODFwSolver:wrong_argument', ...
        'Eta-Omega shape functions not supported!')
end
function build_projection_matrices_sf_uvw(self)
% FUNCTION build_projection_matrices_sf_uvw(self)
%
% Placeholder for the 'nw2uvw' / 'uvw' shape-function types: they are not
% implemented, so this always raises an error.
    err_id = 'GtGrainODFwSolver:wrong_argument';
    err_msg = 'UV-Omega shape functions not supported!';
    error(err_id, err_msg)
end
% NOTE(review): the `function ...(self)` line that should open this block
% is missing from this copy of the file. The `end` just above closes
% build_projection_matrices_sf_uvw, so these statements are orphaned
% inside the methods block. Restore the header (and its call site) before use.
%
% Computes the diagonal weights used by fw/bw:
%   St_ws : 1 ./ (column sums of S), via back-projection (St * x) of a
%           (presumably all-ones) sinogram-sized array
%   S__ws : 1 ./ (row sums of S), via forward projection (S * x) of a
%           (presumably all-ones) volume-sized array
% Entries whose sum is exactly zero get weight 1 instead of dividing by 0.
fprintf('Computing projection weights..')
c = tic();
self.St_ws = self.bp(gtMathsGetSameSizeOnes(self.sino, self.data_type));
self.S__ws = self.fp(gtMathsGetSameSizeOnes(self.St_ws, self.data_type));
self.St_ws = 1 ./ (self.St_ws + (self.St_ws == 0));
self.S__ws = 1 ./ (self.S__ws + (self.S__ws == 0));
fprintf('\b\b: Done in %f seconds.\n', toc(c));
end
function solve_sirt(self)
% FUNCTION solve_sirt(self)
%
% Weighted SIRT reconstruction with a non-negativity clamp after every
% step:
%   x <- max(0, x + bw(bp(fw(sino - fp(x)))))
% Runs self.num_iter iterations and stores the result in self.volume.
% The relative residual history is plotted when self.verbose is true.
    c = tic();
    residuals = zeros(self.num_iter, 1);

    % Initial guess: weighted back-projection of the data, clamped to >= 0
    x0 = self.bw(self.bp(self.fw(self.sino)));
    x0(x0 < 0) = 0;

    res_norm_0 = gtMathsNorm_l2(self.sino);

    x = x0;

    fprintf('Solving SIRT: ')
    for ii = 1:self.num_iter
        current_time = toc(c);
        num_chars = fprintf('%03d/%03d (%g, ETA: %g)', ii, self.num_iter, ...
            current_time, current_time/(ii - 1)*self.num_iter - current_time);

        comp_ints = self.fp(x);
        res = self.sino - comp_ints;
        residuals(ii) = gtMathsNorm_l2(res) / res_norm_0;

        res = self.fw(res);
        res_vol = self.bp(res);
        res_vol = self.bw(res_vol);

        x = x + res_vol;
        x(x < 0) = 0;

        fprintf(repmat('\b', [1 num_chars]));
    end

    % Residual of the final iterate
    comp_ints = self.fp(x);
    res = self.sino - comp_ints;
    res_norm = gtMathsNorm_l2(res) / res_norm_0;

    fprintf('Done %d iterations in %f seconds: residual %f\n', self.num_iter, toc(c), res_norm);

    if (self.verbose)
        figure, semilogy(residuals)
    end

    self.volume = x;
end
function solve_cplsnn(self)
% FUNCTION solve_cplsnn(self)
%
% Preconditioned primal-dual (Chambolle-Pock style) reconstruction with a
% least-squares data term and a non-negativity constraint. Dual steps are
% weighted by S__ws and primal steps by St_ws.
% Runs self.num_iter iterations and stores the result in self.volume.
    c = tic();
    residuals = zeros(self.num_iter, 1);

    p = gtMathsGetSameSizeZeros(self.S__ws);
    x = gtMathsGetSameSizeZeros(self.St_ws);

    res_norm_0 = gtMathsNorm_l2(self.sino);

    xe = x;

    fprintf('Solving CPLSNN: ')
    for ii = 1:self.num_iter
        current_time = toc(c);
        num_chars = fprintf('%03d/%03d (%g, ETA: %g)', ii, self.num_iter, ...
            current_time, current_time/(ii - 1)*self.num_iter - current_time);

        % Dual update for the least-squares data term
        p = (p + self.fw(self.fp(xe) - self.sino)) ./ (1 + self.S__ws);

        % Primal update, projected onto non-negative values
        xo = x;
        x = x - self.bw(self.bp(p));
        x(x < 0) = 0;

        % Over-relaxation of the primal variable
        xe = x + (x - xo);

        residuals(ii) = gtMathsNorm_l2(self.sino - self.fp(x)) / res_norm_0;

        fprintf(repmat('\b', [1 num_chars]));
    end
    fprintf('Done %d iterations in %f seconds: residual %f\n', self.num_iter, toc(c), residuals(end));

    if (self.verbose)
        figure, semilogy(residuals)
    end

    self.volume = x;
end
function solve_cplsl1nn(self, lambda)
% FUNCTION solve_cplsl1nn(self, lambda)
%
% Preconditioned primal-dual (Chambolle-Pock style) reconstruction with a
% least-squares data term, an l1 penalty weighted by lambda and a
% non-negativity constraint.
% Runs self.num_iter iterations and stores the result in self.volume.
% Also sets self.tau_ws to the primal step weights.
    c = tic();
    residuals = zeros(self.num_iter, 1);

    p = gtMathsGetSameSizeZeros(self.S__ws);
    q = gtMathsGetSameSizeZeros(self.St_ws);
    q1 = gtMathsGetSameSizeOnes(self.St_ws);
    x = gtMathsGetSameSizeZeros(self.St_ws);

    res_norm_0 = gtMathsNorm_l2(self.sino);

    self.tau_ws = 1 ./ (1 ./ self.St_ws + 1);

    xe = x;

    fprintf('Solving CPLSL1NN: ')
    for ii = 1:self.num_iter
        current_time = toc(c);
        num_chars = fprintf('%03d/%03d (%g, ETA: %g)', ii, self.num_iter, ...
            current_time, current_time/(ii - 1)*self.num_iter - current_time);

        % Dual update for the least-squares data term
        p = (p + self.fw(self.fp(xe) - self.sino)) ./ (1 + self.S__ws);

        % Dual update for the l1 term (operator lambda*I, preconditioned
        % dual step 1/lambda), then projection onto [-1, 1]
        qn = q + xe;
        q = (qn ./ max(q1, abs(qn)));

        % Primal update, projected onto non-negative values
        xo = x;
        x = x - (self.bp(p) + lambda .* q) .* self.tau_ws;
        x(x < 0) = 0;

        % Over-relaxation of the primal variable
        xe = x + (x - xo);

        residuals(ii) = gtMathsNorm_l2(self.sino - self.fp(x)) / res_norm_0;

        fprintf(repmat('\b', [1 num_chars]));
    end
    fprintf('Done %d iterations in %f seconds: residual %f\n', self.num_iter, toc(c), residuals(end));

    if (self.verbose)
        figure, semilogy(residuals)
    end

    self.volume = x;
end
function solve_cpl1nn(self)
% FUNCTION solve_cpl1nn(self)
%
% Preconditioned primal-dual (Chambolle-Pock style) reconstruction with an
% l1 data term and a non-negativity constraint. The data are rescaled to
% have l2 norm equal to num_ws before solving, and the solution is
% scaled back afterwards.
% Runs self.num_iter iterations and stores the result in self.volume.
% Also sets self.tau_ws.
    num_ws = self.get_num_ws();
    sino_norm = norm(self.sino(:));
    rescaled_sino = self.sino ./ sino_norm .* num_ws;

    c = tic();
    residuals = zeros(self.num_iter, 1);

    p = gtMathsGetSameSizeZeros(self.S__ws);
    p1 = gtMathsGetSameSizeOnes(self.S__ws);
    x = gtMathsGetSameSizeZeros(self.St_ws);

    res_norm_0 = gtMathsNorm_l2(rescaled_sino);

    self.tau_ws = 1 ./ (1 ./ self.St_ws + 1);

    xe = x;

    fprintf('Solving CPL1NN: ')
    for ii = 1:self.num_iter
        current_time = toc(c);
        num_chars = fprintf('%03d/%03d (%g, ETA: %g)', ii, self.num_iter, ...
            current_time, current_time/(ii - 1)*self.num_iter - current_time);

        % Dual update for the l1 data term: ascent, then projection onto [-1, 1]
        pn = p + self.fw(self.fp(xe) - rescaled_sino);
        p = pn ./ max(p1, abs(pn));

        % Primal update, projected onto non-negative values
        xo = x;
        x = x - self.bw(self.bp(p));
        x(x < 0) = 0;

        % Over-relaxation of the primal variable
        xe = x + (x - xo);

        residuals(ii) = gtMathsNorm_l2(rescaled_sino - self.fp(x)) / res_norm_0;

        fprintf(repmat('\b', [1 num_chars]));
    end

    % Undo the data rescaling
    x = x .* sino_norm ./ num_ws;

    fprintf('Done %d iterations in %f seconds: residual %f\n', self.num_iter, toc(c), residuals(end));

    if (self.verbose)
        figure, semilogy(residuals)
    end

    self.volume = x;
end
end
methods (Access = protected)
function y = fp(self, x)
% Forward projection (orientation weights -> sinogram): y = S * x
    y = self.S * x;
end

function y = bp(self, x)
% Back-projection (sinogram -> orientation weights): y = St * x
    y = self.St * x;
end

function y = fw(self, x)
% Element-wise weighting in sinogram space by S__ws
    y = x .* self.S__ws;
end

function y = bw(self, x)
% Element-wise weighting in orientation space by St_ws
    y = x .* self.St_ws;
end

function num_ws = get_num_ws(self)
% Number of sinogram rows (omega positions): first entry of size_sino
    sino_shape = self.size_sino;
    num_ws = sino_shape(1);
end

function num_bs = get_num_blobs(self)
% Number of sinogram columns (selected blobs): second entry of size_sino
    sino_shape = self.size_sino;
    num_bs = sino_shape(2);
end