-
Nicola Vigano authored
Signed-off-by:
Nicola Vigano <nicola.vigano@esrf.fr> git-svn-id: https://svn.code.sf.net/p/dct/code/trunk@1009 4c865b51-4357-4376-afb4-474e03ccb993
Nicola Vigano authoredSigned-off-by:
Nicola Vigano <nicola.vigano@esrf.fr> git-svn-id: https://svn.code.sf.net/p/dct/code/trunk@1009 4c865b51-4357-4376-afb4-474e03ccb993
gtMathsSIRTSolveCell.m 7.05 KiB
function [bestSol, flag, bestRes, bestIter, squareResiduals] = ...
    gtMathsSIRTSolveCell(A, At, b, x0, toll, maxIter, verbose, use_c_functions)
% GTMATHSSIRTSOLVECELL SIRT solver for Cell arrays
% [bestSol, flag, bestRes, bestIter, squareResiduals] = gtMathsSIRTSolveCell(A, At, b, x0, toll, maxIter[, verbose[, use_c_functions]])
%
% Iterates the SIRT update
%     z <- z + C .* At( R .* (b - A(z)) )
% where R = 1 ./ A(ones) (row sums) and C = 1 ./ At(ones) (column sums).
% All vectors are cell arrays, so several systems can be handled at once.
%
% Inputs:
% - A: a matrix, a cell of matrices or a function handle
% - At: the transpose of A. Ignored when A is a matrix or a cell of
%       matrices; otherwise a function handle, a matrix or a cell of matrices
% - b: a vector or a cell of vectors
% - x0: initial solution, a vector or a cell of vectors (if empty, zeros
%       with the shape of At(b) are used)
% - toll: tolerance on the relative residual norm ||b - A(x)|| / ||b||
% - maxIter: maximum number of iterations
% - verbose: boolean flag to be verbose or not (default: false)
% - use_c_functions: boolean flag to activate accelerated C++ functions
%       (default: false)
%
% Outputs:
% - bestSol: solution with the lowest residual found by the algorithm
% - flag: Exit condition (extended from matlab's 'pcg' routine)
% - bestRes: relative residual of best solution
% - bestIter: iteration at which bestSol was found (0 = initial guess)
% - squareResiduals: square residual norm at the start of each iteration
%
% Flags:
% - 0: terminated with success
% - 1: reached max number of iterations
% - 2: Preconditioner is ill-conditioned (Not used)
% - 3: Not returned, yet. (Stagnation of the method)
% - 4: One of the scalar quantities became either too big or too small
%
    if (~exist('verbose', 'var') || isempty(verbose))
        verbose = false;
    end
    out = GtConditionalOutput(verbose);
    if (~exist('use_c_functions', 'var') || isempty(use_c_functions))
        use_c_functions = false;
    end

    if (~iscell(b))
        if (isnumeric(b))
            b = { b };
        else
            error('gtMathsSIRTSolveCell:wrong_argument', ...
                'This function works on Cell arrays');
        end
    end
    numCells = length(b);

    [A, At] = buildSIRTOperators(A, At, numCells);

    % Column sums of A (At applied to ones). They also have the shape of
    % the solution, which is reused to build the default initial guess.
    Atweights = getColumnsSum(At, b);
    if (~exist('x0', 'var') || isempty(x0))
        x0 = zerosLikeCells(Atweights);
    elseif (isnumeric(x0))
        x0 = { x0 };
    end
    Aweights = getRowsSum(A, x0);

    % To see the trend of residuals
    squareResiduals = zeros(maxIter, 1);

    initialResidualNorm = sqrt(squareNorm(b, use_c_functions));
    if (initialResidualNorm == 0)
        % b == 0: x = 0 is the exact solution, and relative residuals
        % would be 0/0. Same convention as MATLAB's pcg.
        bestSol = zerosLikeCells(x0);
        flag = 0;
        bestRes = 0;
        bestIter = 0;
        return;
    end

    % Solution to the system
    if (use_c_functions)
        % Private copy: the C++ "*_assign" helpers seem to update their
        % first argument in place.
        z = internal_cell_copy(x0);
    else
        z = x0;
    end
    r = measureResidual(A, b, z, use_c_functions);
    nextSquareResidual = squareNorm(r, use_c_functions);

    % Best solution in the iterations (just in case we run till the end)
    nextResidualNorm = sqrt(nextSquareResidual);
    bestRes = nextResidualNorm / initialResidualNorm;
    bestSol = snapshotSolution(z, use_c_functions);
    bestIter = 0;

    for jj = 1:maxIter
        squareResiduals(jj) = nextSquareResidual;

        % SIRT step: z <- z + C .* At(R .* r)
        d = backProjectResidual(At, r, Aweights, Atweights, use_c_functions);
        if (use_c_functions)
            z = internal_cell_sum_assign(z, d);
        else
            for ii = 1:numCells
                z{ii} = z{ii} + d{ii};
            end
        end

        % Evaluation of the residual
        r = measureResidual(A, b, z, use_c_functions);
        nextSquareResidual = squareNorm(r, use_c_functions);
        if (nextSquareResidual == 0)
            % Exact solution: return it instead of treating it as a breakdown
            bestSol = z;
            bestRes = 0;
            bestIter = jj;
            flag = 0;
            return;
        end
        if (~isfinite(nextSquareResidual))
            % Diverged (Inf) or produced invalid values (NaN): keep the best so far
            flag = 4;
            return;
        end

        % Test if it is the best, and save it
        nextResidualNorm = sqrt(nextSquareResidual);
        if (nextResidualNorm < bestRes * initialResidualNorm)
            bestRes = nextResidualNorm / initialResidualNorm;
            bestSol = snapshotSolution(z, use_c_functions);
            bestIter = jj;
        end
        out.fprintf('Iteration %03d/%03d: BestResidualNorm %f, ResidualNorm %5.20f\n', ...
            jj, maxIter, bestRes, nextResidualNorm);

        % Exit condition
        if (nextResidualNorm < toll * initialResidualNorm)
            out.fprintf('\nExit by tollerance\n');
            flag = 0;
            return;
        end
    end
    flag = 1;
end

function [A, At] = buildSIRTOperators(A, At, numCells)
% BUILDSIRTOPERATORS Wrap A and At as function handles acting on cell arrays.
%   When A is a matrix or a cell of matrices, At is derived from A and the
%   input At is ignored. Otherwise A must be a function handle, and At may be
%   a function handle, a matrix or a cell of numCells matrices.
    if (isa(A, 'function_handle'))
        if (isa(At, 'function_handle'))
            return;
        elseif (iscell(At))
            if (length(At) ~= numCells)
                error('gtMathsSIRTSolveCell:wrong_argument', ...
                    'Size of cell At matrix, and cell vector b mismatch!');
            end
            At = @(x)multiplyCellMatrix(At, x);
        elseif (isnumeric(At))
            At = @(x)multiplyMatrix(At, x);
        else
            error('gtMathsSIRTSolveCell:wrong_argument', ...
                'At must be a matrix, a cell of matrices or a function handle');
        end
    elseif (iscell(A))
        if (length(A) ~= numCells)
            error('gtMathsSIRTSolveCell:wrong_argument', ...
                'Size of cell A matrix, and cell vector b mismatch!');
        end
        AtCells = cellfun(@(M) M', A, 'UniformOutput', false);
        At = @(x)multiplyCellMatrix(AtCells, x);
        A = @(x)multiplyCellMatrix(A, x);
    elseif (isnumeric(A))
        warning('gtMathsSIRTSolveCell:parameter', ...
            ['I''m assuming that the input A is a square matrix, that' ...
            ' suits for the vectors in b']);
        At = @(x)multiplyMatrix(A', x);
        A = @(x)multiplyMatrix(A, x);
    else
        error('gtMathsSIRTSolveCell:wrong_argument', ...
            'A must be a matrix, a cell of matrices or a function handle');
    end
end

function zeroCells = zerosLikeCells(cells)
% ZEROSLIKECELLS 1-by-N cell array of zero arrays, sized like the input cells.
    zeroCells = cell(1, length(cells));
    for ii = 1:length(cells)
        zeroCells{ii} = zeros(size(cells{ii}));
    end
end

function saved = snapshotSolution(z, use_c_functions)
% SNAPSHOTSOLUTION Copy of z that later updates of z cannot modify.
%   NOTE(review): the C++ path copies explicitly because
%   internal_cell_sum_assign appears to update z in place (its name, and
%   the defensive internal_cell_copy of x0). Confirm against the MEX code.
    if (use_c_functions)
        saved = internal_cell_copy(z);
    else
        saved = z;
    end
end
function value = squareNorm(x, use_c_functions)
% SQUARENORM Sum of x{k} .* x{k} over every element of every cell of x.
%   With use_c_functions set, the computation is delegated to the
%   compiled helper internal_cell_square_norm.
    if (use_c_functions)
        value = internal_cell_square_norm(x);
        return;
    end

    value = 0;
    for cellIdx = 1:length(x)
        % Column-wise sums first, then the total (keeps summation order)
        columnSums = sum(x{cellIdx} .* x{cellIdx});
        value = value + sum(columnSums(:));
    end
end
function result = multiplyCellMatrix(M, x)
% MULTIPLYCELLMATRIX Pairwise products result{k} = M{k} * x{k}.
%   The output is a 1-by-length(M) cell array; cells of x beyond
%   length(M) are ignored.
    result = cell(1, length(M));
    for cellIdx = 1:numel(result)
        result{cellIdx} = M{cellIdx} * x{cellIdx};
    end
end
function result = multiplyMatrix(M, x)
% MULTIPLYMATRIX Apply the same matrix M to every cell: result{k} = M * x{k}.
%   The output is a 1-by-length(x) cell array.
    result = cell(1, length(x));
    for cellIdx = 1:numel(result)
        result{cellIdx} = M * x{cellIdx};
    end
end
function is_finite = isFinite(x)
% ISFINITE True when every element of every cell of x is finite.
%   Always returns a logical scalar, also when the cells hold matrices or
%   N-D arrays: each cell is flattened with (:) before all() is applied.
%   An empty cell array is considered finite.
    is_finite = true;
    for ii = 1:numel(x)
        if (~all(isfinite(x{ii}(:))))
            is_finite = false;
            return;
        end
    end
end
function r = measureResidual(A, b, z, use_c_functions)
% MEASURERESIDUAL True residual r{k} = b{k} - Az{k}, with Az = A(z).
%   A is a function handle mapping a cell array to a cell array. With
%   use_c_functions set, the subtraction is done by internal_cell_sub_copy.
    tempAz = A(z);
    if (use_c_functions)
        r = internal_cell_sub_copy(b, tempAz);
    else
        numCells = length(b);
        % Preallocated: avoids growing r in the loop, and guarantees r is
        % assigned (as an empty cell) when b has no cells
        r = cell(1, numCells);
        for ii = 1:numCells
            r{ii} = b{ii} - tempAz{ii};
        end
    end
end
function p = backProjectResidual(At, r, Aweights, Atweights, use_c_functions)
% BACKPROJECTRESIDUAL Weighted back-projection of the residual:
%   p{k} = At(r ./ Aweights){k} ./ Atweights{k}
%   i.e. the residual is normalised by the row sums, back-projected, and
%   normalised by the column sums. With use_c_functions set, the
%   element-wise divisions are done by the compiled helpers.
    if (use_c_functions)
        p = internal_cell_div_copy(r, Aweights);
        p = At(p);
        p = internal_cell_div_assign(p, Atweights);
        return;
    end

    cellCount = length(r);
    rowScaled = cell(1, cellCount);
    for cellIdx = 1:cellCount
        rowScaled{cellIdx} = r{cellIdx} ./ Aweights{cellIdx};
    end

    p = At(rowScaled);
    for cellIdx = 1:cellCount
        p{cellIdx} = p{cellIdx} ./ Atweights{cellIdx};
    end
end
function Aweights = getRowsSum(A, x0)
% GETROWSSUM Row sums of the operator A, obtained as A applied to ones.
%   x0 (a cell array) only provides the sizes of the input volumes.
%   Rows whose sum is exactly zero get weight 1 instead of 0, so that the
%   SIRT normalisation r ./ Aweights cannot produce Inf/NaN. Assuming A is
%   non-negative (as projection matrices usually are), a zero sum means an
%   all-zero row; At then ignores that residual entry, so the replacement
%   value does not change the update.
    numCells = length(x0);
    onesVols = cell(1, numCells);
    for ii = 1:numCells
        onesVols{ii} = ones(size(x0{ii}));
    end
    Aweights = A(onesVols);
    for ii = 1:length(Aweights)
        Aweights{ii}(Aweights{ii} == 0) = 1;
    end
end
function Atweights = getColumnsSum(At, b)
% GETCOLUMNSSUM Column sums of A, obtained as At applied to ones.
%   b (a cell array) only provides the sizes of the projection vectors.
%   Columns whose sum is exactly zero get weight 1 instead of 0. Otherwise
%   the SIRT update would compute 0/0 = NaN for the unknowns that A never
%   touches (assuming A is non-negative, a zero sum means an all-zero
%   column). With weight 1, those unknowns keep their value.
    numCells = length(b);
    onesVols = cell(1, numCells);
    for ii = 1:numCells
        onesVols{ii} = ones(size(b{ii}));
    end
    Atweights = At(onesVols);
    for ii = 1:length(Atweights)
        Atweights{ii}(Atweights{ii} == 0) = 1;
    end
end