Skip to content
Snippets Groups Projects
Commit 6f144612 authored by Nicola Vigano's avatar Nicola Vigano
Browse files

CG-Cell: fixed non-zero starting point and added reporting of best iteration number


Signed-off-by: default avatarNicola Vigano <nicola.vigano@esrf.fr>

git-svn-id: https://svn.code.sf.net/p/dct/code/trunk@899 4c865b51-4357-4376-afb4-474e03ccb993
parent 579233ea
No related branches found
No related tags found
No related merge requests found
function [bestSol, flag, bestRes, jj, squareResiduals] = ... function [bestSol, flag, bestRes, bestIter, squareResiduals] = ...
gtMathsCGSolveCell(A, b, toll, maxIter, precond, x0, verbose, use_c_functions) gtMathsCGSolveCell(A, b, toll, maxIter, precond, x0, verbose, use_c_functions)
% GTMATHSCGSOLVECELL Conjugate Gradient solver for Cell arrays % GTMATHSCGSOLVECELL Conjugate Gradient solver for Cell arrays
% [bestSol, flag, res, numIters, resvec] = gtMathsCGSolveCell(A, b, maxIter[, precond[, x0[, verbose[, use_c_functions]]]]) % [bestSol, flag, res, bestIter, resvec] = gtMathsCGSolveCell(A, b, maxIter[, precond[, x0[, verbose[, use_c_functions]]]])
% %
% Inputs: % Inputs:
% - A: a matrix, a cell of matrices or a function handle % - A: a matrix, a cell of matrices or a function handle
...@@ -84,18 +84,19 @@ function [bestSol, flag, bestRes, jj, squareResiduals] = ... ...@@ -84,18 +84,19 @@ function [bestSol, flag, bestRes, jj, squareResiduals] = ...
for ii = 1:numCells for ii = 1:numCells
z{ii} = zeros(size(b{ii})); z{ii} = zeros(size(b{ii}));
end end
% Initial residual which is exactly the b vector, since z = 0
if (use_c_functions)
r = internal_cell_copy(b);
else
r = b;
end
else else
if (use_c_functions) if (use_c_functions)
z = internal_cell_copy(x0); z = internal_cell_copy(x0);
else else
z = x0; z = x0;
end end
end r = measureResidual(A, b, z, use_c_functions);
% Initial residual which is exactly the b vector, since z = 0
if (use_c_functions)
r = internal_cell_copy(b);
else
r = b;
end end
% Computing preconditioner % Computing preconditioner
...@@ -123,10 +124,10 @@ function [bestSol, flag, bestRes, jj, squareResiduals] = ... ...@@ -123,10 +124,10 @@ function [bestSol, flag, bestRes, jj, squareResiduals] = ...
if (~isempty(precond)), nextDotRho = dotProduct(r, y, use_c_functions); end if (~isempty(precond)), nextDotRho = dotProduct(r, y, use_c_functions); end
% Best solutions in the iterations (just in case we run till the end) % Best solutions in the iterations (just in case we run till the end)
nextResidualNorm = sqrt(nextSquareResidual); initialResidualNorm = sqrt(squareNorm(b, use_c_functions));
initialResidualNorm = nextResidualNorm;
bestRes = 1; bestRes = 1;
bestSol = d; bestSol = d;
bestIter = 0;
for jj = 1:maxIter for jj = 1:maxIter
% Next iteration % Next iteration
...@@ -165,15 +166,7 @@ function [bestSol, flag, bestRes, jj, squareResiduals] = ... ...@@ -165,15 +166,7 @@ function [bestSol, flag, bestRes, jj, squareResiduals] = ...
% Evaluation of the residual % Evaluation of the residual
if (mod(jj+1, 50) == 0) && false % <-- Let's disable for now if (mod(jj+1, 50) == 0) && false % <-- Let's disable for now
% Evaluation of the true residual r = measureResidual(A, b, z, use_c_functions);
tempAz = A(z);
if (use_c_functions)
r = internal_cell_sub_copy(b, tempAz);
else
for ii = 1:numCells
r{ii} = b{ii} - tempAz{ii};
end
end
else else
% Let's update the residual % Let's update the residual
if (use_c_functions) if (use_c_functions)
...@@ -224,6 +217,7 @@ function [bestSol, flag, bestRes, jj, squareResiduals] = ... ...@@ -224,6 +217,7 @@ function [bestSol, flag, bestRes, jj, squareResiduals] = ...
if (nextResidualNorm < bestRes * initialResidualNorm) if (nextResidualNorm < bestRes * initialResidualNorm)
bestRes = nextResidualNorm / initialResidualNorm; bestRes = nextResidualNorm / initialResidualNorm;
bestSol = z; bestSol = z;
bestIter = jj;
end end
out.fprintf('Iteration %03d/%03d: BestResidualNorm %f, ResidualNorm %5.20f\n', ... out.fprintf('Iteration %03d/%03d: BestResidualNorm %f, ResidualNorm %5.20f\n', ...
...@@ -313,3 +307,16 @@ function is_finite = isFinite(x) ...@@ -313,3 +307,16 @@ function is_finite = isFinite(x)
if (~is_finite), break; end if (~is_finite), break; end
end end
end end
function r = measureResidual(A, b, z, use_c_functions)
% Evaluation of the true residual
numCells = length(b);
tempAz = A(z);
if (use_c_functions)
r = internal_cell_sub_copy(b, tempAz);
else
for ii = 1:numCells
r{ii} = b{ii} - tempAz{ii};
end
end
end
...@@ -67,6 +67,15 @@ function [theoX, A, cgX] = gtMathsCGSolveCellExample ...@@ -67,6 +67,15 @@ function [theoX, A, cgX] = gtMathsCGSolveCellExample
time_to_compute = toc(); time_to_compute = toc();
fprintf('\nMatlab''s implementation: %f seconds (NumIters: %3d, Res: %g, flag: %d)\n\n', time_to_compute, numIters, cgres, flag); fprintf('\nMatlab''s implementation: %f seconds (NumIters: %3d, Res: %g, flag: %d)\n\n', time_to_compute, numIters, cgres, flag);
tic();
[cgX{7}, flag, cgres, numIters, sr4] = gtMathsCGSolveCell(H11, b, toll*1e4, maxIters, diag(1 ./ diag(A)), [], verbose, false);
time_to_compute = toc();
fprintf('\nMatlab''s implementation: %f seconds (NumIters: %3d, Res: %g, flag: %d)\n\n', time_to_compute, numIters, cgres, flag);
tic();
[cgX{8}, flag, cgres, numIters, sr5] = gtMathsCGSolveCell(H11, b, toll, maxIters, diag(1 ./ diag(A)), cgX{7}, verbose, false);
time_to_compute = toc();
fprintf('\nMatlab''s implementation: %f seconds (NumIters: %3d, Res: %g, flag: %d)\n\n', time_to_compute, numIters, cgres, flag);
if (verbose) if (verbose)
disp([theoX{1}, cgX{1}{1}, cgX{2}{1}]) disp([theoX{1}, cgX{1}{1}, cgX{2}{1}])
end end
...@@ -86,4 +95,9 @@ function [theoX, A, cgX] = gtMathsCGSolveCellExample ...@@ -86,4 +95,9 @@ function [theoX, A, cgX] = gtMathsCGSolveCellExample
plot(sqrt(sr1), 'r') plot(sqrt(sr1), 'r')
plot(sqrt(sr2), 'g') plot(sqrt(sr2), 'g')
hold hold
figure, semilogy(sqrt(sr4))
hold
plot([zeros(find(sr4 ~= 0, 1, 'last'), 1); sqrt(sr5)], 'r')
hold
end end
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment