Skip to content
Snippets Groups Projects
gtMathsSIRTSolveCell.m 7.05 KiB
function [bestSol, flag, bestRes, bestIter, squareResiduals] = ...
    gtMathsSIRTSolveCell(A, At, b, x0, toll, maxIter, verbose, use_c_functions)
% GTMATHSSIRTSOLVECELL SIRT solver for Cell arrays
%   [bestSol, flag, bestRes, bestIter, squareResiduals] = gtMathsSIRTSolveCell(A, At, b, x0, toll, maxIter[, verbose[, use_c_functions]])
%
%   Inputs:
%       - A: a matrix, a cell of matrices or a function handle
%       - At: the transpose of A (a matrix, a cell of matrices or a function
%             handle). If A is a matrix or a cell of matrices, it is ignored
%             and rebuilt from A
%       - b: a vector or a cell of vectors
%       - x0: initial solution, as a cell array (a bare numeric array is
%             wrapped into a one-element cell, like b). If empty, zeros
%             shaped like the output of At are used
%       - toll: tolerance on the residual norm, relative to the norm of b
%       - maxIter: maximum number of iterations
%       - verbose: boolean flag to be verbose or not (default: false)
%       - use_c_functions: boolean flag to activate accelerated C++ functions
%             (default: false)
%
%   Outputs:
%       - bestSol: solution with the lowest residual found by the algorithm
%       - flag: Exit condition (extended from matlab's 'pcg' routine)
%       - bestRes: relative residual of best solution
%       - bestIter: iteration at which the best solution was found (0 means
%             the initial solution)
%       - squareResiduals: maxIter-by-1 vector, entry jj is the square
%             residual at the beginning of iteration jj. Entries of
%             iterations that were not performed are left to 0
%
%   Flags:
%       - 0: terminated with success
%       - 1: reached max number of iterations
%       - 2: Preconditioner is ill-conditioned (Not used)
%       - 3: Not returned, yet. (Stagnation of the method)
%       - 4: One of the scalar quantities became either too big or too small
%            (the square residual became 0, Inf or NaN)
%

    if (~exist('verbose', 'var'))
        out = GtConditionalOutput(false);
    else
        out = GtConditionalOutput(verbose);
    end

    if (~exist('use_c_functions', 'var'))
        use_c_functions = false;
    end

    if (~iscell(b))
        if (isnumeric(b))
            b = { b };
        else
            error('gtMathsSIRTSolveCell:wrong_argument', ...
                  'This function works on Cell arrays');
        end
    end

    % Turn A and At into function handles that act on cell arrays
    if (~isa(A, 'function_handle'))
        if (iscell(A))
            if (length(A) == length(b))
                % The given At is discarded: it is rebuilt cell by cell from A
                At = A;
                for ii = 1:length(A)
                    At{ii} = At{ii}';
                end
                At = @(x)multiplyCellMatrix(At, x);
                A = @(x)multiplyCellMatrix(A, x);
            else
                error('gtMathsSIRTSolveCell:wrong_argument', ...
                    'Size of cell A matrix, and cell vector b mismatch!');
            end
        elseif (isnumeric(A))
            warning('gtMathsSIRTSolveCell:parameter', ...
                    ['I''m assuming that the input A is a square matrix, that' ...
                     ' suits for the vectors in  b']);
            At = @(x)multiplyMatrix(A', x);
            A = @(x)multiplyMatrix(A, x);
        end
    elseif (~isa(At, 'function_handle'))
        % A is a function handle in this branch, so it is the type of At
        % (and not the one of A) that selects the wrapper
        if (iscell(At))
            if (length(At) == length(b))
                At = @(x)multiplyCellMatrix(At, x);
            else
                error('gtMathsSIRTSolveCell:wrong_argument', ...
                    'Size of cell At matrix, and cell vector b mismatch!');
            end
        elseif (isnumeric(At))
            At = @(x)multiplyMatrix(At, x);
        end
    end

    numCells = length(b);

    % To see the trend of residuals
    squareResiduals = zeros(maxIter, 1);

    % Back-projection of ones shaped like b (also gives the shape of the
    % solution, because it is an output of At)
    Atweights = getColumnsSum(At, b);

    if (isempty(x0))
        % Default initial solution: zeros, with the shape of the output of At
        x0 = cell(1, length(Atweights));
        for ii = 1:length(Atweights)
            x0{ii} = zeros(size(Atweights{ii}));
        end
    elseif (~iscell(x0) && isnumeric(x0))
        % Same convenience offered for b
        x0 = { x0 };
    end

    % Projection of ones shaped like the solution
    Aweights = getRowsSum(A, x0);

    % Solution to the system
    if (use_c_functions)
        z = internal_cell_copy(x0);
    else
        z = x0;
    end
    r = measureResidual(A, b, z, use_c_functions);
    nextSquareResidual = squareNorm(r, use_c_functions);

    % Best solutions in the iterations (just in case we run till the end)
    initialResidualNorm = sqrt(squareNorm(b, use_c_functions));

    nextResidualNorm = sqrt(nextSquareResidual);
    bestRes = nextResidualNorm / initialResidualNorm;
    % NOTE(review): if internal_cell_sum_assign updates z in place, bestSol
    % may need an explicit copy when use_c_functions is true - TODO confirm
    bestSol = z;
    bestIter = 0;

    for jj = 1:maxIter
        % Next iteration: store the residual of the current solution
        squareResiduals(jj) = nextSquareResidual;

        d = backProjectResidual(At, r, Aweights, Atweights, use_c_functions);

        % Update the solution
        if (use_c_functions)
            z = internal_cell_sum_assign(z, d);
        else
            for ii = 1:numCells
                z{ii} = z{ii} + d{ii};
            end
        end

        % Evaluation of the residual
        r = measureResidual(A, b, z, use_c_functions);

        nextSquareResidual = squareNorm(r, use_c_functions);
        % ~isfinite catches both Inf and NaN: a NaN would otherwise fail
        % every comparison below and silently run until maxIter
        if (nextSquareResidual == 0 || ~isfinite(nextSquareResidual))
            flag = 4;
            return;
        end

        % Test if it is the best, and save it
        nextResidualNorm = sqrt(nextSquareResidual);
        if (nextResidualNorm < bestRes * initialResidualNorm)
            bestRes = nextResidualNorm / initialResidualNorm;
            bestSol = z;
            bestIter = jj;
        end

        out.fprintf('Iteration %03d/%03d: BestResidualNorm %f, ResidualNorm %5.20f\n', ...
                    jj, maxIter, bestRes, nextResidualNorm);

        % Exit condition
        if (nextResidualNorm < toll * initialResidualNorm)
            out.fprintf('\nExit by tollerance\n');
            flag = 0;
            return;
        end
    end

    flag = 1;
end

function value = squareNorm(x, use_c_functions)
% SQUARENORM Sum of the squares of all the elements in the cells of x
%   If use_c_functions is true, the computation is delegated to
%   internal_cell_square_norm.
    if (use_c_functions)
        value = internal_cell_square_norm(x);
        return;
    end

    value = 0;
    for cellIdx = 1:length(x)
        squaredCell = x{cellIdx} .* x{cellIdx};
        % Reduce along the first non-singleton dimension, then over the rest
        partialSums = sum(squaredCell);
        value = value + sum(partialSums(:));
    end
end

function result = multiplyCellMatrix(M, x)
% MULTIPLYCELLMATRIX Multiplies each matrix in the cell M by the matching
% cell of x, and returns the products in a 1-by-length(M) cell.
    result = cell(1, length(M));
    for cellIdx = 1:length(M)
        matrix = M{cellIdx};
        vector = x{cellIdx};
        result{cellIdx} = matrix * vector;
    end
end

function result = multiplyMatrix(M, x)
% MULTIPLYMATRIX Multiplies the same matrix M by every cell of x, and
% returns the products in a 1-by-length(x) cell.
    result = cell(1, length(x));
    for cellIdx = 1:length(x)
        vector = x{cellIdx};
        result{cellIdx} = M * vector;
    end
end

function is_finite = isFinite(x)
% ISFINITE Tells whether every element of every cell of x is finite
%   Returns a scalar logical: false as soon as one cell contains an Inf or
%   a NaN, true otherwise (also for an empty cell array).
    is_finite = true;
    numCells = length(x);
    for ii = 1:numCells
        % Flatten with (:) so that matrices and volumes give a scalar: a
        % plain all() only reduces along the first non-singleton dimension
        is_finite = all(isfinite(x{ii}(:)));
        if (~is_finite), break; end
    end
end

function r = measureResidual(A, b, z, use_c_functions)
% MEASURERESIDUAL Evaluation of the true residual: r = b - A(z)
%   Inputs:
%       - A: function handle that takes and returns a cell array
%       - b: cell array of the right-hand sides
%       - z: cell array with the current solution
%       - use_c_functions: if true, the difference is delegated to
%             internal_cell_sub_copy
%   Output:
%       - r: cell array of residuals. In the non accelerated path it is a
%             1-by-length(b) cell
    tempAz = A(z);
    if (use_c_functions)
        r = internal_cell_sub_copy(b, tempAz);
    else
        numCells = length(b);
        % Preallocated, so that r is defined even when b has no cells
        r = cell(1, numCells);
        for ii = 1:numCells
            r{ii} = b{ii} - tempAz{ii};
        end
    end
end

function p = backProjectResidual(At, r, Aweights, Atweights, use_c_functions)
% BACKPROJECTRESIDUAL Weighted back-projection of the residual
%   The residual r is divided element-wise by Aweights, passed through At,
%   and the result is divided element-wise by Atweights.
    if (use_c_functions)
        % Accelerated path: same three steps, through the internal_cell_* helpers
        p = internal_cell_div_copy(r, Aweights);
        p = At(p);
        p = internal_cell_div_assign(p, Atweights);
        return;
    end

    cellCount = length(r);

    % Normalisation of the residual before the back-projection
    p = cell(1, cellCount);
    for cellIdx = 1:cellCount
        p{cellIdx} = r{cellIdx} ./ Aweights{cellIdx};
    end

    p = At(p);

    % Normalisation of the back-projected residual
    for cellIdx = 1:cellCount
        p{cellIdx} = p{cellIdx} ./ Atweights{cellIdx};
    end
end

function Aweights = getRowsSum(A, x0)
% GETROWSSUM Applies A to a cell array of ones, shaped like the cells of x0
%   A is a function handle that takes a cell array; its output is returned
%   as it is.
    unitVols = cell(1, length(x0));
    for cellIdx = 1:length(x0)
        volSize = size(x0{cellIdx});
        unitVols{cellIdx} = ones(volSize);
    end

    Aweights = A(unitVols);
end

function Atweights = getColumnsSum(At, b)
% GETCOLUMNSSUM Applies At to a cell array of ones, shaped like the cells of b
%   At is a function handle that takes a cell array; its output is returned
%   as it is.
    unitVols = cell(1, length(b));
    for cellIdx = 1:length(b)
        volSize = size(b{cellIdx});
        unitVols{cellIdx} = ones(volSize);
    end

    Atweights = At(unitVols);
end