Skip to content
Snippets Groups Projects
Commit 3407c96e authored by Nicola Vigano's avatar Nicola Vigano
Browse files

CG-Cell: added initial exit flag support (limited matlab's pcg compatibility)


Signed-off-by: default avatarNicola Vigano <nicola.vigano@esrf.fr>

git-svn-id: https://svn.code.sf.net/p/dct/code/trunk@895 4c865b51-4357-4376-afb4-474e03ccb993
parent bd1a8e3a
No related branches found
No related tags found
No related merge requests found
function [bestSolution, bestRes, jj, squareResiduals] = ... function [bestSol, flag, bestRes, jj, squareResiduals] = ...
gtMathsCGSolveCell(A, b, toll, maxIter, precond, verbose, use_c_functions) gtMathsCGSolveCell(A, b, toll, maxIter, precond, verbose, use_c_functions)
% GTMATHSCGSOLVECELL Conjugate Gradient solver for Cell arrays % GTMATHSCGSOLVECELL Conjugate Gradient solver for Cell arrays
% [x, res, numIters, resvec] = gtMathsCGSolveCell(A, b, maxIter[, precond[, verbose[, use_c_functions]]]) % [bestSol, flag, res, numIters, resvec] = gtMathsCGSolveCell(A, b, maxIter[, precond[, verbose[, use_c_functions]]])
% %
% Inputs: % Inputs:
% - A: a matrix, a cell of matrices or a function handle % - A: a matrix, a cell of matrices or a function handle
...@@ -12,6 +12,13 @@ function [bestSolution, bestRes, jj, squareResiduals] = ... ...@@ -12,6 +12,13 @@ function [bestSolution, bestRes, jj, squareResiduals] = ...
% function handle % function handle
% - verbose: boolean flag to be verbose or not % - verbose: boolean flag to be verbose or not
% - use_c_functions: boolean flag to activate accelerated C++ functions % - use_c_functions: boolean flag to activate accelerated C++ functions
%
% Outputs:
% - bestSol: solution found by the algorithm
% - flag: Exit condition (extended from matlab's 'pcg' routine)
% - bestRes: relative residual of best solution
% - numIters: number of iterations performed
% - resvec: vector of square residuals for each iteration
% %
if (~exist('precond', 'var') || isempty(precond)) if (~exist('precond', 'var') || isempty(precond))
...@@ -97,7 +104,7 @@ function [bestSolution, bestRes, jj, squareResiduals] = ... ...@@ -97,7 +104,7 @@ function [bestSolution, bestRes, jj, squareResiduals] = ...
nextResidualNorm = sqrt(nextSquareResidual); nextResidualNorm = sqrt(nextSquareResidual);
initialResidualNorm = nextResidualNorm; initialResidualNorm = nextResidualNorm;
bestRes = 1; bestRes = 1;
bestSolution = d; bestSol = d;
for jj = 1:maxIter for jj = 1:maxIter
% Next iteration % Next iteration
...@@ -112,10 +119,12 @@ function [bestSolution, bestRes, jj, squareResiduals] = ... ...@@ -112,10 +119,12 @@ function [bestSolution, bestRes, jj, squareResiduals] = ...
if (squareAd < 0) if (squareAd < 0)
if (jj == 0) if (jj == 0)
out.fprintf('\nExit by bad Hessian\n') out.fprintf('\nExit by bad Hessian\n')
break; flag = 5;
return;
else else
out.fprintf('\nExit by negativeness\n') out.fprintf('\nExit by negativeness\n')
break; flag = 6;
return;
end end
end end
...@@ -176,7 +185,7 @@ function [bestSolution, bestRes, jj, squareResiduals] = ... ...@@ -176,7 +185,7 @@ function [bestSolution, bestRes, jj, squareResiduals] = ...
nextResidualNorm = sqrt(nextSquareResidual); nextResidualNorm = sqrt(nextSquareResidual);
if (nextResidualNorm < bestRes * initialResidualNorm) if (nextResidualNorm < bestRes * initialResidualNorm)
bestRes = nextResidualNorm / initialResidualNorm; bestRes = nextResidualNorm / initialResidualNorm;
bestSolution = z; bestSol = z;
end end
out.fprintf('Iteration %03d/%03d: BestResidualNorm %f, ResidualNorm %5.20f\n', ... out.fprintf('Iteration %03d/%03d: BestResidualNorm %f, ResidualNorm %5.20f\n', ...
...@@ -185,7 +194,8 @@ function [bestSolution, bestRes, jj, squareResiduals] = ... ...@@ -185,7 +194,8 @@ function [bestSolution, bestRes, jj, squareResiduals] = ...
% Exit condition % Exit condition
if (nextResidualNorm < toll * initialResidualNorm) if (nextResidualNorm < toll * initialResidualNorm)
out.fprintf('\nExit by tollerance\n'); out.fprintf('\nExit by tollerance\n');
break; flag = 0;
return;
end end
% Update the correction to apply at the next iteration % Update the correction to apply at the next iteration
...@@ -198,6 +208,8 @@ function [bestSolution, bestRes, jj, squareResiduals] = ... ...@@ -198,6 +208,8 @@ function [bestSolution, bestRes, jj, squareResiduals] = ...
end end
end end
end end
flag = 1;
end end
% function toll = getTollerance(residual, squareNormGrad, use_c_functions) % function toll = getTollerance(residual, squareNormGrad, use_c_functions)
......
...@@ -29,13 +29,13 @@ function [theoX, A, cgX] = gtMathsCGSolveCellExample ...@@ -29,13 +29,13 @@ function [theoX, A, cgX] = gtMathsCGSolveCellExample
fprintf('\nMatlab''s Gauss-Pivoting: %f seconds\n\n', time_to_compute); fprintf('\nMatlab''s Gauss-Pivoting: %f seconds\n\n', time_to_compute);
tic(); tic();
[cgX{1}, cgres, numIters, sr1] = gtMathsCGSolveCell(H11, b, toll, maxIters, [], verbose, false); [cgX{1}, flag, cgres, numIters, sr1] = gtMathsCGSolveCell(H11, b, toll, maxIters, [], verbose, false);
time_to_compute = toc(); time_to_compute = toc();
fprintf('\nMatlab''s implementation: %f seconds (NumIters: %3d, Res: %g)\n\n', time_to_compute, numIters, cgres); fprintf('\nMatlab''s implementation: %f seconds (NumIters: %3d, Res: %g, flag: %d)\n\n', time_to_compute, numIters, cgres, flag);
tic(); tic();
[cgX{2}, cgres, numIters] = gtMathsCGSolveCell(H11, b, toll, maxIters, [], verbose, true); [cgX{2}, flag, cgres, numIters] = gtMathsCGSolveCell(H11, b, toll, maxIters, [], verbose, true);
time_to_compute = toc(); time_to_compute = toc();
fprintf('\nNicola''s implementation: %f seconds (NumIters: %3d, Res: %g)\n\n', time_to_compute, numIters, cgres); fprintf('\nNicola''s implementation: %f seconds (NumIters: %3d, Res: %g, flag: %d)\n\n', time_to_compute, numIters, cgres, flag);
% tic(); % tic();
% cgX3 = cell(1, numCells); % cgX3 = cell(1, numCells);
...@@ -45,28 +45,39 @@ function [theoX, A, cgX] = gtMathsCGSolveCellExample ...@@ -45,28 +45,39 @@ function [theoX, A, cgX] = gtMathsCGSolveCellExample
% time_to_compute = toc(); % time_to_compute = toc();
% fprintf('\nL1magic''s implementation: %f seconds (NumIters: %d, %d)\n\n', time_to_compute, numIters); % fprintf('\nL1magic''s implementation: %f seconds (NumIters: %d, %d)\n\n', time_to_compute, numIters);
numIters = zeros(1, numCells);
cgres = zeros(1, numCells);
flag = zeros(1, numCells);
tic(); tic();
cgX{4} = cell(1, numCells); cgX{4} = cell(1, numCells);
for ii = 1:numCells for ii = 1:numCells
[cgX{4}{ii}, ~, cgres, numIters, resv] = pcg(H11{ii}, b{ii}, toll, maxIters); [cgX{4}{ii}, flag(ii), cgres(ii), numIters(ii), resv] = pcg(H11{ii}, b{ii}, toll, maxIters);
end end
time_to_compute = toc(); time_to_compute = toc();
fprintf('\nMatlab''s implementation: %f seconds (NumIters: %3d, Res: %g)\n\n', time_to_compute, numIters, cgres); fprintf('\nMatlab''s implementation: %f seconds (NumIters:%s, Res:%s, flag:%s)\n\n', ...
time_to_compute, sprintf(' %3d', numIters), sprintf(' %g', cgres), ...
sprintf(' %d', flag));
tic(); tic();
[cgX{5}, cgres, numIters, sr2] = gtMathsCGSolveCell(H11, b, toll, maxIters, diag(1 ./ diag(A)), verbose, false); [cgX{5}, flag, cgres, numIters, sr2] = gtMathsCGSolveCell(H11, b, toll, maxIters, diag(1 ./ diag(A)), verbose, false);
time_to_compute = toc(); time_to_compute = toc();
fprintf('\nMatlab''s implementation: %f seconds (NumIters: %3d, Res: %g)\n\n', time_to_compute, numIters, cgres); fprintf('\nMatlab''s implementation: %f seconds (NumIters: %3d, Res: %g, flag: %d)\n\n', time_to_compute, numIters, cgres, flag);
tic();
[cgX{6}, flag, cgres, numIters, sr3] = gtMathsCGSolveCell(H11, b, toll, maxIters, diag(1 ./ diag(A)), verbose, true);
time_to_compute = toc();
fprintf('\nMatlab''s implementation: %f seconds (NumIters: %3d, Res: %g, flag: %d)\n\n', time_to_compute, numIters, cgres, flag);
if (verbose) if (verbose)
disp([theoX{1}, cgX{1}{1}, cgX{2}{1}]) disp([theoX{1}, cgX{1}{1}, cgX{2}{1}])
end end
fprintf('Max relative errors:\n'); fprintf('Max relative errors:\n');
populatedCells = find(cellfun(@iscell, cgX));
for ii = 1:numCells for ii = 1:numCells
fprintf(' (%e, %e, %e, %e)\n', ... pattern = sprintf('%e, ', ...
arrayfun(@(x)max((theoX{ii} - cgX{x}{ii}) ./ theoX{ii}), ... arrayfun(@(x)max((theoX{ii} - cgX{x}{ii}) ./ theoX{ii}), ...
find(cellfun(@iscell, cgX))) ); populatedCells) );
fprintf([' (' pattern(1:end-2) ')\n']);
end end
fprintf('\n'); fprintf('\n');
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment