Skip to content
Snippets Groups Projects
Gt6DBlobReconstructor.m 36.6 KiB
Newer Older
classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector
% Gt6DBlobReconstructor  Reconstructs a set of volumes (one per geometry)
% from detector blobs with the primal-dual iterations that the timing
% statistics label "CP" (presumably Chambolle-Pock -- confirm).
% Projection machinery is inherited from Gt6DVolumeToBlobProjector.
%
% Algorithms handled by cp(): '6DLS', '6DL1', '6DTV', '6DTVL1'.
%
% NOTE(review): this copy of the file appears to be missing lines: several
% if/for/switch/function blocks have no closing `end`, some `case` labels
% and guards are absent, and some locals are used without a visible
% definition. The NOTE(review) comments below flag the visible gaps; the
% missing lines must be restored from the canonical source.
    properties
        % Solution
        % Cell array holding one volume per geometry (see the constructor).
        currentSolution = {};

        % Variables to plot
        % (normInitialResidual is not read or written by the methods here.)
        normInitialResidual;
        % Functional values sampled every `sample_rate` iterations of cp()
        % when `verbose` is true.
        normResiduals;

        % When true, cp() samples the functional value and plots it at the end.
        verbose = false;
        % Data-fidelity term, one of possible_detector_norms (default 'l2').
        detector_norm = Gt6DBlobReconstructor.possible_detector_norms{1};
        % TV norm, one of possible_tv_norms (default 'l12').
        tv_norm = Gt6DBlobReconstructor.possible_tv_norms{1};
        % How volumes are coupled in the TV term, one of
        % possible_tv_strategies (default 'groups').
        tv_strategy = Gt6DBlobReconstructor.possible_tv_strategies{1};
        % Weight of the l1 term.
        lambda_l1 = 1e-2;
        % Weight of the TV term.
        lambda_tv = 1;
        % When true, the compiled gtCxx* helpers are used wherever a
        % plain-MATLAB fallback also exists.
        algo_ops_c_functions = true;
        % NOTE(review): the `end` closing this properties block is missing
        % in this copy.
    properties (Constant)
        % Allowed option values; the first entry of each list is the
        % default used above.
        possible_detector_norms = {'l2', 'kl', 'l1'};
        possible_tv_norms = {'l12', 'l1', 'ln'};
        possible_tv_strategies = {'groups', 'volume'};
    end

    methods (Access = public)
        function self = Gt6DBlobReconstructor(volumes, blobs, proj_sizes_uv, varargin)
        % Builds the reconstructor: validates blobs and options, registers
        % the timing statistics, allocates the starting solution, the PSFs
        % and the projection weights.
        %
        % Inputs:
        %   volumes       - either a cell array of initial volumes (used as
        %                   the starting solution; the volume size is taken
        %                   from volumes{1}) or a size vector, in which case
        %                   one zero volume of that size is allocated per
        %                   geometry.
        %   blobs         - struct array, one element per detector. Every
        %                   entry d of its `data` cell must satisfy
        %                   [size(d, 1), size(d, 3)] == size_uv; a non-empty
        %                   `psf` cell creates GtPSF objects.
        %   proj_sizes_uv - forwarded to Gt6DVolumeToBlobProjector.
        %   varargin      - forwarded to Gt6DVolumeToBlobProjector
        %                   (presumably sets the options validated below).
        %
        % Errors:
        %   'Gt6DBlobReconstructor:wrong_argument' for a malformed blob;
        %   '<mfilename>:wrong_argument' for an invalid tv_norm,
        %   tv_strategy or detector_norm.
            fprintf('Initializing BLOB Recontruction:\n - Setup..');
            ct = tic();
            c = ct;

            % Every blob of every detector must match its size_uv.
            % NOTE(review): the error message reports the blob index `n`
            % but not the detector index `ii_d`.
            for ii_d = 1:numel(blobs)
                for n = 1:numel(blobs(ii_d).data)
                    if (any(blobs(ii_d).size_uv ~= [size(blobs(ii_d).data{n}, 1), size(blobs(ii_d).data{n}, 3)]))
                        error('Gt6DBlobReconstructor:wrong_argument', ...
                            'Blob: %d is malformed!', n)
                    end
                end
            end
            % Volume size: from the first initial volume, or the size
            % vector passed directly.
            if (iscell(volumes))
                [vols_size(1), vols_size(2), vols_size(3)] = size(volumes{1});
            else
                vols_size = volumes;
            end
            self = self@Gt6DVolumeToBlobProjector(vols_size, blobs, proj_sizes_uv, varargin{:});
            % Let's complain about wrong options
            % NOTE(review): the `end` closing each of the three `if` blocks
            % below is missing in this copy.
            if (~ismember(self.tv_norm, self.possible_tv_norms))
                error([mfilename ':wrong_argument'], ...
                    'TV-norm: %s is not allowed! Use one of [%s] instead', ...
                    self.tv_norm, sprintf(' "%s"', self.possible_tv_norms{:}))
            if (~ismember(self.tv_strategy, self.possible_tv_strategies))
                error([mfilename ':wrong_argument'], ...
                    'TV-strategy: %s is not allowed! Use one of [%s] instead', ...
                    self.tv_strategy, sprintf(' "%s"', self.possible_tv_strategies{:}))
            if (~ismember(self.detector_norm, self.possible_detector_norms))
                error([mfilename ':wrong_argument'], ...
                    'Detector-norm: %s is not allowed! Use one of [%s] instead', ...
                    self.detector_norm, sprintf(' "%s"', self.possible_detector_norms{:}))
            % Register the timing tasks and partial timers that the CP
            % update methods fill in through self.statistics.
            self.statistics.add_task('cp_dual_update_detector', 'CP Dual variable (detector) update');
            self.statistics.add_task_partial('cp_dual_update_detector', 'cp_dual_detector_FP', 'Forward Projection');
            self.statistics.add_task_partial('cp_dual_update_detector', 'cp_dual_detector_SB', 'Sinograms -> Blobs');
            self.statistics.add_task_partial('cp_dual_update_detector', 'cp_dual_detector_IN', 'Projected blobs initialization');
            self.statistics.add_task_partial('cp_dual_update_detector', 'cp_dual_detector_PSF', 'PSF application');
            self.statistics.add_task_partial('cp_dual_update_detector', 'cp_dual_detector_RS', 'Projection pixelsize rescaling');
            self.statistics.add_task_partial('cp_dual_update_detector', 'cp_dual_detector_prox', 'Proximal application');
            self.statistics.add_task('cp_dual_update_l1', 'CP Dual variable (l1) update');
            self.statistics.add_task('cp_dual_update_tv', 'CP Dual variable (TV) update');
            self.statistics.add_task_partial('cp_dual_update_tv', 'cp_dual_tv_reduction', 'Volumes reduction');
            self.statistics.add_task_partial('cp_dual_update_tv', 'cp_dual_tv_gradient', 'Gradient');
            self.statistics.add_task_partial('cp_dual_update_tv', 'cp_dual_tv_proximal', 'Proximal application');
            self.statistics.add_task_partial('cp_dual_update_tv', 'cp_dual_tv_divergence', 'Divergence');
            self.statistics.add_task('cp_primal_update', 'CP Primal variable update');
            self.statistics.add_task_partial('cp_primal_update', 'cp_primal_BP', 'Back Projection');
            self.statistics.add_task_partial('cp_primal_update', 'cp_primal_BS', 'Blobs -> Sinograms');
            self.statistics.add_task_partial('cp_primal_update', 'cp_primal_PSF', 'PSF application');
            self.statistics.add_task_partial('cp_primal_update', 'cp_primal_CORR', 'Primal correction computation');
            self.statistics.add_task_partial('cp_primal_update', 'cp_primal_RS', 'Volume contribution rescaling');
            self.statistics.add_task_partial('cp_primal_update', 'cp_primal_APP', 'Primal update application');
            % Starting solution: the given volumes, or one zero volume of
            % class self.data_type per geometry (with a progress counter).
            if (iscell(volumes))
                self.currentSolution = volumes;
            else
                num_geoms = self.get_number_geometries();
                fprintf('\b\b: Done. (%2.1f s)\n - Creating initial volumes (size: [%s] x %d bytes x %d vols = %g GB): ', ...
                    toc(c), strjoin(arrayfun(@(x){sprintf('%d', x)}, vols_size), ', '), ...
                    4, num_geoms, prod(vols_size) * 4 * num_geoms / 2^30)
                c = tic();
                self.currentSolution = cell(num_geoms, 1);
                for ii = 1:num_geoms
                    num_chars = fprintf('%03d/%03d', ii, num_geoms);
                    self.currentSolution{ii} = zeros(vols_size, self.data_type);
                    fprintf(repmat('\b', [1 num_chars]));
                end
            end
            % One GtPSF per PSF provided for each detector; detectors
            % without a PSF get an empty cell.
            num_det = self.get_number_detectors();
            self.psf = cell(num_det, 1);
            for ii_d = 1:num_det
                if (~isempty(blobs(ii_d).psf))
                    fprintf('\b\b: Done. (%2.1f s)\n - Initializing PSF(s)..', toc(c))
                    c = tic();
                    num_psfs_det = numel(blobs(ii_d).psf);
                    self.psf{ii_d} = cell(num_psfs_det, 1);
                    for ii_p = 1:num_psfs_det
                        psf_d = blobs(ii_d).psf{ii_p};
                        self.psf{ii_d}{ii_p} = GtPSF();
                        self.psf{ii_d}{ii_p}.set_psf_direct(psf_d, blobs(ii_d).size_uv);
                    end
                else
                    self.psf{ii_d} = {};
            % NOTE(review): the `end`s closing the if/else and the detector
            % loop above are missing in this copy. The 'Initializing PSF(s)'
            % message is printed once per detector that has a PSF.
            fprintf('\b\b: Done. (%2.1f s)\n - Weights:\n', toc(c));
            c = tic();
            self.initializeWeights();
            fprintf('   Done. (%2.1f s)\nTot (%3.1f s)).\n', toc(c), toc(ct));
        end

        function solution = getCurrentSolution(self)
        % Returns the current solution (cell array, one volume per geometry).
            solution = self.currentSolution;
        % NOTE(review): the `end` closing getCurrentSolution is missing in
        % this copy.
        function [proj_blobs, proj_spots] = getProjectionOfCurrentSolution(self)
        % Forward-projects the current solution (without recording timing
        % statistics) and scales the result by self.intensity_scale.
        %   proj_blobs - cell (one entry per detector) of cells of blobs
        %   proj_spots - per detector, each blob summed along dimension 2;
        %                the sums are concatenated along dim 2 and permuted
        %                with [1 3 2], so the blob index ends up in dim 3
            proj_blobs = self.compute_fwd_projection(self.currentSolution, false);
            proj_spots = cell(size(proj_blobs));
            num_det = self.get_number_detectors();
            for ii_d = 1:num_det
                proj_spots{ii_d} = cell(size(proj_blobs{ii_d}));
                for ii_b = 1:numel(proj_blobs{ii_d})
                    proj_blobs{ii_d}{ii_b} = proj_blobs{ii_d}{ii_b} * self.intensity_scale;
                    proj_spots{ii_d}{ii_b} = sum(proj_blobs{ii_d}{ii_b}, 2);
                end
                proj_spots{ii_d} = permute(cat(2, proj_spots{ii_d}{:}), [1 3 2]);
            end
            % NOTE(review): getProjectionOfCurrentSolution has no closing
            % `end` here, and the lines below use `volumes`, which is not
            % one of its inputs. They look like the body of a separate
            % method (header missing in this copy) that replaces the current
            % solution with `volumes` (validated for count and size) or with
            % zeros, recomputes the weights, and clears currentDetDual and
            % the statistics.
            if (exist('volumes', 'var'))
                num_geoms = self.get_number_geometries();
                num_vols = numel(volumes);
                if (num_vols ~= num_geoms)
                    error('Gt6DBlobReconstructor:wrong_argument', ...
                        'Wrong new volumes number. Expected: %d, got: %d', ...
                        num_geoms, num_vols)
                end
                for n = 1:num_geoms
                    if (any(size(volumes{n}) ~= self.volume_geometry))
                        error('Gt6DBlobReconstructor:wrong_argument', ...
                            'Wrong volume size for new volume: %d. Expected: (%s), got: (%s)', ...
                            n, sprintf(' %d', self.volume_geometry), ...
                            sprintf(' %d', size(volumes{n})) )
                    end
                end
                self.currentSolution = volumes;
            else
                self.currentSolution = gtMathsGetSameSizeZeros(self.currentSolution);
            end

            fprintf('Recomputing Weights:\n');
            c = tic();
            self.initializeWeights();
            fprintf('- Done in %3.1f s.\n', toc(c));

            self.currentDetDual = {};
            self.statistics.clear();
        % NOTE(review): the `end` closing the method above is missing in
        % this copy, as are the `end`s of cp_l1, cp_tv and cp_tvl1 below.
        function cp_l1(self, numIters)
        % Runs cp() with the '6DL1' algorithm.
            self.cp('6DL1', numIters)
        function cp_tv(self, numIters)
        % Runs cp() with the '6DTV' algorithm.
            self.cp('6DTV', numIters);
        function cp_tvl1(self, numIters)
        % Runs cp() with the '6DTVL1' algorithm.
            self.cp('6DTVL1', numIters);
        function cp(self, algo, numIters)
        % Runs the primal-dual iterations for the chosen algorithm.
        %   algo     - '6DLS', '6DL1', '6DTV' or '6DTVL1' (compared after
        %              upper()); the TV and/or l1 dual updates are meant to
        %              run when the name contains 'TV' and/or 'L1'.
        %   numIters - scalar N (iterations 1..N), a two-element
        %              [first last] range, or an explicit list of indices.
        % Updates self.currentSolution (through compute_bwd_projection)
        % and, when self.verbose, self.normResiduals; prints timing stats.
            do_tv_update = ~isempty(strfind(upper(algo), 'TV'));
            do_l1_update = ~isempty(strfind(upper(algo), 'L1'));

            self.initializeVariables(numIters);
            % The functional value is sampled every sample_rate iterations
            % (only when self.verbose).
            sample_rate = 5;

            fprintf('Initializing CP_%s weights: ', upper(algo))
            c = tic();
            [sigma1, sigma1_1, tau] = self.init_cp_weights(algo);
            fprintf('Done (%g seconds).\nInitializing CP_%s variables: ', toc(c), upper(algo))
            c = tic();
            [p, nextEnhancedSolution, q_l1, q_tv] = self.init_cp_vars(algo);
            fprintf('Done (%g seconds).\n', toc(c))
            fprintf('Reconstruction using algorithm: %s\n', upper(algo));
            fprintf(' - Detector data-fidelity term: %s\n', self.detector_norm);
            % NOTE(review): the TV lines below are printed under the l1
            % guard; an `end` / `if (do_tv_update)` pair appears to be
            % missing in this copy.
            if (do_l1_update)
                fprintf(' - l1-term lambda: %g\n', self.lambda_l1);
                fprintf(' - TV-term lambda: %g\n', self.lambda_tv);
                fprintf(' - TV norm: %s\n', self.tv_norm);
                fprintf(' - TV strategy: %s\n', self.tv_strategy);
            end
            fprintf('Iteration: ');

            c = tic();
            % Expand numIters into the list of iteration indices.
            switch (numel(numIters))
                case 1
                    iters = 1:numIters;
                case 2
                    iters = numIters(1):numIters(2);
                otherwise
                    iters = numIters;
            end

            tot_iters = numel(iters);
            for ii = iters
                numchars = fprintf('%03d/%03d (in %f s)', ii-min(iters), tot_iters, toc(c));

                % Computing update of dual detector
                self.statistics.tic('cp_dual_update_detector');
                proj_bls = self.compute_fwd_projection(nextEnhancedSolution);
                [proj_bls, psf_time] = self.apply_psf(proj_bls, true);
                self.statistics.add_timestamp(psf_time, 'cp_dual_update_detector', 'cp_dual_detector_PSF');
                p = self.update_dual_detector(p, proj_bls, sigma1, sigma1_1);
                self.statistics.toc('cp_dual_update_detector');

                % Computing update dual TV
                % NOTE(review): the `if (do_tv_update)` opening this block
                % (closed by the `end` below) appears to be missing.
                    self.statistics.tic('cp_dual_update_tv');
                    [q_tv, mdiv_tv] = self.update_dual_TV(q_tv, nextEnhancedSolution);
                    self.statistics.toc('cp_dual_update_tv');
                end

                % Computing update dual l1
                % NOTE(review): the `if (do_l1_update)` opening this block
                % (closed by the `end` below) appears to be missing.
                    self.statistics.tic('cp_dual_update_l1');
                    q_l1 = self.update_dual_l1(q_l1, nextEnhancedSolution);
                    self.statistics.toc('cp_dual_update_l1');
                end
                % Function for computing the update to apply
                % compute_update_primal(ii) returns the regularisation part
                % of the primal correction for orientation indices ii.
                % NOTE(review): only the '6DLS' case label is visible; the
                % labels for the other branches, the '6DLS' body itself and
                % the switch's closing `end` appear to be missing.
                switch (upper(algo))
                    case '6DLS'
                        % q_l1 is just a volume of zeros
                        compute_update_primal = @(ii)gtCxxMathsCellTimes(...
                            q_l1(ii), -self.lambda_l1, 'threads', self.num_threads);
                        compute_update_primal = @(ii)get_divergence(self, mdiv_tv, ii);
                        compute_update_primal = @(ii)gtCxxMathsCellPlus(...
                            get_divergence(self, mdiv_tv, ii), q_l1(ii), ...
                            'scale', -self.lambda_l1, 'threads', self.num_threads);
                % Computing update primal
                self.statistics.tic('cp_primal_update');
                [p_blurr, psf_time] = self.apply_psf(p, false);
                self.statistics.add_timestamp(psf_time, 'cp_primal_update', 'cp_primal_PSF');
                nextEnhancedSolution = self.compute_bwd_projection(...
                    nextEnhancedSolution, p_blurr, compute_update_primal, tau);
                self.statistics.toc('cp_primal_update');

                fprintf('%s', repmat(sprintf('\b'), 1, numchars))

                if (self.verbose && (mod(ii, sample_rate) == 0))
                    self.normResiduals(ii / sample_rate) ...
                        = self.compute_functional_value(p, algo);
                % NOTE(review): the `end` closing this `if` appears to be
                % missing in this copy.
            end
            fprintf('(%04d) Done in %f seconds.\n', tot_iters, toc(c) )

            self.printStats();

            if (self.verbose)
                figure;
                subplot(1, 2, 1), plot(self.normResiduals(1:floor(max(iters) / sample_rate)));
                subplot(1, 2, 2), semilogy(abs(self.normResiduals));
            end

        % NOTE(review): the `end` closing cp is missing in this copy.
        function num_bytes = get_peak_memory_consumption(self, algo)
        % Returns an estimate, in bytes, of the memory a cp() run with
        % `algo` needs, built from the size of the blobs and of the current
        % solution. Raises 'Gt6DBlobReconstructor:wrong_argument' for an
        % unknown algorithm.
        % NOTE(review): `base_bytes`, the `switch` on algo and the case
        % label of the first branch are not visible in this copy.
            bytes_blobs = self.get_memory_consumption_blobs();
            bytes_vols = self.get_memory_consumption_volumes();

            % currentSolution, blobs, and ASTRA's copies
            % Contribution of the projection weigths
            % Dual variable on detector
            num_bytes = base_bytes + bytes_blobs;
            % EnhancedSolution
            num_bytes = num_bytes + bytes_vols;
                    % Dual variable in real-space
                    num_bytes = num_bytes + bytes_vols;
                    % Computation temporaries
                case {'6DLS', '6DTV'}
                    % Computation temporaries
                    num_bytes = num_bytes + bytes_blobs;
                otherwise
                    error('Gt6DBlobReconstructor:wrong_argument', ...
                        'No such algorithm as "%s"', algo)
            end
        end

        function num_bytes = get_memory_consumption_blobs(self)
        % Bytes used by self.blobs, as measured by GtBenchmarks.getSizeVariable.
            num_bytes = GtBenchmarks.getSizeVariable(self.blobs);
        end

        function num_bytes = get_memory_consumption_volumes(self)
        % Bytes used by self.currentSolution, as measured by
        % GtBenchmarks.getSizeVariable.
            num_bytes = GtBenchmarks.getSizeVariable(self.currentSolution);
        end

        function num_bytes = get_memory_consumption_sinograms(self)
        % Bytes needed for one single-precision sinogram per detector and
        % geometry, each of size
        % [proj_sizes_uv(1), size(geometries{d}{n}, 1), proj_sizes_uv(2)].
        % NOTE(review): the `end`s closing both loops and the function are
        % missing in this copy.
            num_bytes = 0;
            float_bytes = GtBenchmarks.getSizeVariable(zeros(1, 1, 'single'));
            num_geoms = self.get_number_geometries();
            num_det = self.get_number_detectors();
            for ii_d = 1:num_det
                for n = 1:num_geoms
                    sinogram_size = [ ...
                        self.proj_sizes_uv(1), size(self.geometries{ii_d}{n}, 1), self.proj_sizes_uv(2)];
                    num_bytes = num_bytes + float_bytes * prod(sinogram_size);
    %%% Primal and Duals updating functions
        function [q, mdivq] = update_dual_TV(self, q, new_enh_sol)
        % Dual update for the TV term.
        %   q           - cell(num_groups, 3): one dual volume per group and
        %                 gradient component (cf. init_cp_vars)
        %   new_enh_sol - cell array of per-orientation volumes (the
        %                 extrapolated primal solution from update_primal)
        %   mdivq       - cell(num_groups, 1) of -lambda_tv * div(q(g, :))
        % With tv_strategy 'volume' all orientations form one group; with
        % 'groups' each row [first last] of self.orientation_groups is a
        % group. Each group's volumes are summed and differentiated, and
        % the gradient is added to q scaled by 1 / (2 * volumes in group);
        % then the proximal step for self.tv_norm is applied.
            switch (self.tv_strategy)
                case 'volume'
                    num_groups = 1;
                    or_ranges = {1:self.orientation_groups(end)};
                    sigmas = 1 ./  (2 * numel(new_enh_sol));
                case 'groups'
                    num_groups = size(self.orientation_groups, 1);
                    or_ranges = cell(num_groups, 1);
                    sigmas = 1 ./ (2 .* (self.orientation_groups(:, 2) - self.orientation_groups(:, 1) + 1));
                    for ii_g = 1:num_groups
                        or_ranges{ii_g} = self.orientation_groups(ii_g, 1):self.orientation_groups(ii_g, 2);
                    end
            end
            reduction_time = 0;
            gradient_time = 0;
            proximal_time = 0;
            dsES = cell(num_groups, 3);
            mdivq = cell(num_groups, 1);

            for ii_g = 1:num_groups
                c = tic();
                if (self.algo_ops_c_functions)
                    sES = gtCxxMathsSumCellVolumes(new_enh_sol(or_ranges{ii_g}));
                else
                    sES = gtMathsSumCellVolumes(new_enh_sol(or_ranges{ii_g}));
                end
                reduction_time = reduction_time + toc(c);

                c = tic();
                dsES(ii_g, :) = gtMathsGradient(sES);
                gradient_time = gradient_time + toc(c);

                c = tic();
                for ii_d = 1:3
                    q{ii_g, ii_d} = q{ii_g, ii_d} + sigmas(ii_g) * dsES{ii_g, ii_d};
                end
                proximal_time = proximal_time + toc(c);
            end
            % Proximal step for self.tv_norm. The first visible branch
            % divides each group's 3 components by max(1, per-voxel l2
            % magnitude); the second clips every component to |q| <= 1.
            % NOTE(review): the `switch (self.tv_norm)` header, the labels
            % of the first two cases (presumably 'l12' and 'l1'), and
            % several `end`s are missing in this copy.
                    for ii_g = 1:num_groups
                        grad_l2 = sqrt(q{ii_g, 1} .^ 2 + q{ii_g, 2} .^ 2 + q{ii_g, 3} .^ 2);
                        for ii_d = 1:3
                            q{ii_g, ii_d} = q{ii_g, ii_d} ./ max(1, grad_l2);
                        end
                    for n = 1:numel(q)
                        q{n} = q{n} ./ max(1, abs(q{n}));
                    end
                case 'ln'
                    size_vols = size(q{1});
                    num_voxels = numel(q{1});
                    z = zeros(num_groups, 3, num_voxels, self.data_type);
                    % NOTE(review): the `for ii_g` / `for ii_d` loop headers
                    % that fill z appear to be missing in this copy.
                            z(ii_g, ii_d, :) = reshape(q{ii_g, ii_d}, [], 1);
                        end
                    end
                    % Per voxel, ZZ = z' * z (3 x 3), presumably a page-wise
                    % product; the square roots of its eigenvalues are the
                    % singular values of z.
                    ZZ = gtMathsMatrixProduct(permute(z, [2 1 3]), z);
                    if (false) % <- this is always correct
                        Vts = zeros(3, 3, num_voxels);
                        es = zeros(1, 3, num_voxels);
                        for ii_v = 1:num_voxels
                            [Vts(:, :, ii_v), E] = eigs(ZZ(:, :, ii_v));
                            es(1, :, ii_v) = diag(E);
                        end
                    else % <- this can have problems with only one volume
                        [es, Vts] = gtMathsEig3x3SymmPosDef(ZZ);
                    end
                    % z is rescaled with factors 1 / max(1, sval) (0 where
                    % sval == 0) combined with the eigenvectors; presumably a
                    % projection onto the spectral-norm unit ball -- confirm.
                    svals = sqrt(es); % <- they're always positive
                    Vs = permute(Vts, [2 1 3]);
                    sigmas = 1 ./ max(1, svals);
                    sigmas(svals == 0) = 0;
                    sigmas = bsxfun(@times, sigmas, Vts);
                    sigmas = gtMathsMatrixProduct(Vs, sigmas);
                    z = gtMathsMatrixProduct(z, sigmas);
                    for ii_g = 1:num_groups
                        for ii_d = 1:3
                            q{ii_g, ii_d} = reshape(z(ii_g, ii_d, :), size_vols);
                % NOTE(review): the `end`s closing the loops above and the
                % switch, plus the header of the divergence loop (judging
                % from the use of ii_g and c, `c = tic(); for ii_g = ...`),
                % appear to be missing in this copy.
                mdivq{ii_g} = - self.lambda_tv * gtMathsDivergence(q(ii_g, :));
            end
            divergence_time = toc(c);

            self.statistics.add_timestamp(reduction_time, 'cp_dual_update_tv', 'cp_dual_tv_reduction');
            self.statistics.add_timestamp(proximal_time, 'cp_dual_update_tv', 'cp_dual_tv_proximal');
            self.statistics.add_timestamp(gradient_time, 'cp_dual_update_tv', 'cp_dual_tv_gradient');
            self.statistics.add_timestamp(divergence_time, 'cp_dual_update_tv', 'cp_dual_tv_divergence');
        end

        function mdivq_ii = get_divergence(self, mdivq, sel_inds)
        % Returns the entries of mdivq (one per TV group) that apply to the
        % orientation indices sel_inds. For the 'groups' strategy, each
        % selected index is matched to the group whose [first last] row of
        % self.orientation_groups contains it (one match per index when
        % the groups do not overlap).
        % NOTE(review): the body of the 'volume'/single-group branch and the
        % `end`s closing the if/elseif and the function are missing in this
        % copy.
            num_groups = size(self.orientation_groups, 1);
            if (strcmpi(self.tv_strategy, 'volume') || (num_groups == 1))
            elseif (strcmpi(self.tv_strategy, 'groups'))
                sel_inds = reshape(sel_inds, 1, []);
                ind = bsxfun(@ge, sel_inds, self.orientation_groups(:, 1)) ...
                    & bsxfun(@le, sel_inds, self.orientation_groups(:, 2));
                [ind, ~] = find(ind);
%                 ind = find( ...
%                     (sel_inds >= self.orientation_groups(:, 1)) ...
%                     & (sel_inds <= self.orientation_groups(:, 2)), 1);
            mdivq_ii = mdivq(ind);
        function p = update_dual_detector(self, p, proj_bls, sigma, sigma_1)
        % Dual update for the detector data-fidelity term, per detector d
        % and blob n, with b = self.blobs(d).data{n} and Ax = proj_bls{d}{n}:
        %   'kl': t = p + sigma .* Ax;
        %         p = (1 + t - sqrt((t - 1).^2 + 4 * sigma .* b)) / 2
        %   'l2': p = (p + sigma .* (Ax - b)) .* sigma_1
        %   third branch (label missing, presumably 'l1'):
        %         p = p + sigma .* (Ax - b), then clipped to |p| <= 1
        % The compiled path (gtCxx6DUpdateDualDetector) receives sigma_1
        % only for 'l2'.
        % NOTE(review): the `if (self.algo_ops_c_functions)` / `else` lines,
        % the definitions of num_det and of the timer `c`, the third case
        % label and several `end`s are missing in this copy.
                for ii_d = 1:num_det
                    if (strcmpi(self.detector_norm, 'l2'))
                        p{ii_d} = gtCxx6DUpdateDualDetector(self.detector_norm, ...
                            p{ii_d}, self.blobs(ii_d).data, proj_bls{ii_d}, sigma{ii_d}, sigma_1{ii_d}, 'threads', self.num_threads);
                    else
                        p{ii_d} = gtCxx6DUpdateDualDetector(self.detector_norm, ...
                            p{ii_d}, self.blobs(ii_d).data, proj_bls{ii_d}, sigma{ii_d}, 'threads', self.num_threads);
                    end
                for ii_d = 1:num_det
                    switch (lower(self.detector_norm))
                        case 'kl'
                            for n = 1:numel(p{ii_d})
                                temp_p = p{ii_d}{n} + sigma{ii_d}{n} .* proj_bls{ii_d}{n};
                                temp_bls = 4 .* sigma{ii_d}{n} .* self.blobs(ii_d).data{n};
                                temp_d = (temp_p - 1) .^ 2 + temp_bls;
                                temp_d = sqrt(temp_d);
                                p{ii_d}{n} = (1 + temp_p - temp_d) * 0.5;
                            end
                        case 'l2'
                            for n = 1:numel(p{ii_d})
                                p{ii_d}{n} = (p{ii_d}{n} + sigma{ii_d}{n} .* (proj_bls{ii_d}{n} - self.blobs(ii_d).data{n})) .* sigma_1{ii_d}{n};
                            for n = 1:numel(p{ii_d})
                                p{ii_d}{n} = p{ii_d}{n} + sigma{ii_d}{n} .* (proj_bls{ii_d}{n} - self.blobs(ii_d).data{n});
                                p{ii_d}{n} = p{ii_d}{n} ./ max(1, abs(p{ii_d}{n}));
                            end
                    end
            self.statistics.add_timestamp(toc(c), 'cp_dual_update_detector', 'cp_dual_detector_prox');
        function q = update_dual_l1(self, q, new_enh_sol)
        % Dual update for the l1 term: q = q + new_enh_sol, then (compiled
        % path) gtCxxMathsCellBoxOneL1, presumably clipping to [-1, 1].
        % NOTE(review): the algo_ops_c_functions if/else lines, the rest of
        % the plain-MATLAB fallback loop and the closing `end`s are missing
        % in this copy.
                q = gtCxxMathsCellPlus(q, new_enh_sol, 'copy', false, 'threads', self.num_threads);
                q = gtCxxMathsCellBoxOneL1(q, 'copy', false, 'threads', self.num_threads);
                for n = 1:numel(q)
                    q{n} = q{n} + new_enh_sol{n};
        function [curr_sol, curr_enh_sol] = update_primal(self, curr_sol, curr_enh_sol, correction, tau)
        % Primal update of a chunk of volumes (compiled path:
        % gtCxx6DUpdatePrimal). The plain-MATLAB fallback computes
        %   v = curr_sol + correction .* tau, clamped at 0, and
        %   curr_enh_sol = v + (v - curr_sol)   (i.e. 2 * v - curr_sol).
        % NOTE(review): the `else`, the loop header over ii, the write-back
        % of v into curr_sol (presumably) and the closing `end`s are missing
        % in this copy.
            if (self.algo_ops_c_functions)
                [curr_sol, curr_enh_sol] = gtCxx6DUpdatePrimal( ...
                    curr_sol, curr_enh_sol, correction, tau, 'threads', self.num_threads);
                    v = curr_sol{ii} + correction{ii} .* tau{ii};
                    v(v < 0) = 0;
                    curr_enh_sol{ii} = v + (v - curr_sol{ii});
    end

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    %%% Projection functions
    methods (Access = protected)
        function proj_bls = compute_fwd_projection(self, volumes, apply_stats)
        % Forward-projects `volumes` (one per geometry) into blobs, per
        % detector, in chunks of get_jobs_chunk_size() geometries, and
        % multiplies each detector's blobs by renorm_factors(d) when that
        % factor differs from 1 by more than eps('single').
        %   apply_stats - record timings in self.statistics (default true)
        % Returns a cell (num_det x 1) of cells of projected blobs.
        % NOTE(review): renorm_factors and the timing accumulators
        % (timing_in, timing_fp, timing_sb, timing_rs) have no visible
        % definition, and the `end`s closing the detector loop and
        % `if (apply_stats)` are missing in this copy.
            if (~exist('apply_stats', 'var') || isempty(apply_stats))
                apply_stats = true;
            end

            chunk_size = self.get_jobs_chunk_size();
            num_geoms = self.get_number_geometries();
            num_det = self.get_number_detectors();
            proj_bls = cell(num_det, 1);
            do_renorm = abs(renorm_factors - 1) > eps('single');

            for ii_d = 1:num_det
                c = tic();
                if (self.algo_ops_c_functions)
                    proj_bls{ii_d} = gtCxx6DCreateEmptyBlobs(self.blobs(ii_d).data, 'threads', self.num_threads);
                else
                    proj_bls{ii_d} = gtMathsGetSameSizeZeros(self.blobs(ii_d).data);
                end
                timing_in = timing_in + toc(c);

                for n = 1:chunk_size:num_geoms
                    chunk_safe_size = min(chunk_size, (num_geoms - n + 1));
                    inds = n:(n+chunk_safe_size-1);
                    [proj_bls{ii_d}, fp_time, sb_time] = self.fwd_project_volume(proj_bls{ii_d}, volumes(inds), ii_d, inds);
                    timing_fp = timing_fp + fp_time;
                    timing_sb = timing_sb + sb_time;
                end

                if (do_renorm(ii_d))
                    c = tic();
                    for ii_b = 1:numel(proj_bls{ii_d})
                        proj_bls{ii_d}{ii_b} = proj_bls{ii_d}{ii_b} * renorm_factors(ii_d);
                    end
                    timing_rs = timing_rs + toc(c);
                end
            if (apply_stats)
                self.statistics.add_timestamp(timing_fp, 'cp_dual_update_detector', 'cp_dual_detector_FP');
                self.statistics.add_timestamp(timing_sb, 'cp_dual_update_detector', 'cp_dual_detector_SB');
                self.statistics.add_timestamp(timing_in, 'cp_dual_update_detector', 'cp_dual_detector_IN');
                if (any(do_renorm))
                    self.statistics.add_timestamp(timing_rs, 'cp_dual_update_detector', 'cp_dual_detector_RS');
                end
        end

        function nextEnhancedSolution = compute_bwd_projection(self, nextEnhancedSolution, p_blurr, compute_update_primal, tau)
        % Back-projects the detector duals p_blurr, adds the regularisation
        % term compute_update_primal(inds), and applies update_primal to
        % self.currentSolution and nextEnhancedSolution, in chunks of
        % get_jobs_chunk_size() geometries. Each detector's back-projection
        % is scaled by lambda_det(d) / detector_ss(d)^2 when that factor
        % differs from 1 by more than eps('single').
        % NOTE(review): num_geoms, num_det, timing_rs, timing_corr and
        % timing_app have no visible definition in this copy, and a
        % `c = tic()` before the correction step appears to be missing
        % (timing_corr otherwise uses whatever timer `c` last held).
            timing_bp = 0;
            timing_bs = 0;
            chunk_size = self.get_jobs_chunk_size();
            renorm_factors = self.lambda_det ./ (self.detector_ss .^ 2);
%             renorm_factors = 1 ./ (self.detector_ss .^ 2);
            do_renorm = abs(renorm_factors - 1) > eps('single');
            for n = 1:chunk_size:num_geoms
                chunk_safe_size = min(chunk_size, (num_geoms - n + 1));
                inds = n:(n+chunk_safe_size-1);
                v = cell(chunk_safe_size, num_det);
                for ii_d = 1:num_det
                    [v(:, ii_d), bp_time, bs_time] = self.bwd_project_volume(p_blurr{ii_d}, ii_d, inds);
                    timing_bp = timing_bp + bp_time;
                    timing_bs = timing_bs + bs_time;

                    if (do_renorm(ii_d))
                        c = tic();
                        v(:, ii_d) = gtCxxMathsCellTimes(v(:, ii_d), renorm_factors(ii_d), 'threads', self.num_threads, 'copy', false);
                        timing_rs = timing_rs + toc(c);
                    end
                end
                % Correction = regularisation term + back-projections of
                % all detectors, summed along the second cell dimension.
                up_prim = [compute_update_primal(inds), v];
                if (self.algo_ops_c_functions)
                    correction = gtCxxMathsSumCellVolumes(up_prim, 2);
                else
                    correction = gtMathsSumCellVolumes(up_prim, 2);
                end
                timing_corr = timing_corr + toc(c);

                c = tic();
                [self.currentSolution(inds), nextEnhancedSolution(inds)] ...
                    = self.update_primal(self.currentSolution(inds), nextEnhancedSolution(inds), correction, tau(inds));
                timing_app = timing_app + toc(c);
            end

            self.statistics.add_timestamp(timing_bp, 'cp_primal_update', 'cp_primal_BP')
            self.statistics.add_timestamp(timing_bs, 'cp_primal_update', 'cp_primal_BS')
            self.statistics.add_timestamp(timing_corr, 'cp_primal_update', 'cp_primal_CORR')
            self.statistics.add_timestamp(timing_app, 'cp_primal_update', 'cp_primal_APP')
            if (any(do_renorm))
                self.statistics.add_timestamp(timing_rs, 'cp_primal_update', 'cp_primal_RS')
            end
    end

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    %%% Initialization functions
    methods (Access = protected)
        function initializeWeights(self)
        % Computes the projection weights used by init_cp_weights:
        %   fwd_weights - abs(PSF applied to the forward projection of
        %                 all-ones volumes), per detector and blob
        %   bwd_weights - per geometry, the sum over detectors of
        %                 lambda_det(d) * sum(offsets{d}{ii}.proj_coeffs)
        % NOTE(review): num_det has no visible definition, and the loop over
        % ii plus several `end`s are missing in this copy.
            vols_ones = cell(size(self.currentSolution));
            vols_ones(:) = {ones(self.volume_geometry, self.data_type)};
            self.fwd_weights = self.compute_fwd_projection(vols_ones, false);
            self.fwd_weights = self.apply_psf(self.fwd_weights, true);
            for ii_d = 1:num_det
                % NOTE(review): the bound below is numel(self.fwd_weights)
                % (the number of detectors), not numel(self.fwd_weights{ii_d})
                % (the number of blobs of detector ii_d); when the counts
                % differ, abs() is not applied to every blob -- likely a bug.
                for ii_b = 1:numel(self.fwd_weights)
                    self.fwd_weights{ii_d}{ii_b} = abs(self.fwd_weights{ii_d}{ii_b});
                end
            fprintf('   + Computing back-projection weights..');
            num_ors = self.get_number_geometries();
            self.bwd_weights = cell(num_ors, 1);
            self.bwd_weights(1:num_ors) = {0};

            for ii_d = 1:num_det
%                     self.bwd_weights{ii} = self.bwd_weights{ii} + sum(self.offsets{ii_d}{ii}.proj_coeffs);
                    self.bwd_weights{ii} = self.bwd_weights{ii} + self.lambda_det(ii_d) * sum(self.offsets{ii_d}{ii}.proj_coeffs);
        function [sigma1, sigma1_1, tau] = init_cp_weights(self, algo)
        % Computes the step sizes for cp():
        %   sigma1{d}{n}   = 1 ./ fwd_weights{d}{n}, where entries below
        %                    eps('single') * max(weights) get 1 added
        %                    first, avoiding division by ~0
        %   sigma1_1{d}{n} = 1 ./ (1 + sigma1{d}{n})
        %   tau{n}         = -1 ./ (bwd_weights{n} + regularisation), cast
        %                    to self.data_type; the visible variants add
        %                    1 * lambda_l1 and/or 6 * lambda_tv
        % NOTE(review): num_det, the preallocation of sigma1/sigma1_1, every
        % case label except '6DLS', and the closing `end`s are missing in
        % this copy.
            for ii_d = 1:num_det
                sigma1{ii_d} = cell(size(self.fwd_weights{ii_d}));
                for n = 1:numel(sigma1{ii_d})
                    % Necessary in case of use of OTF
                    tol = eps('single') * max(self.fwd_weights{ii_d}{n}(:));
                    sigma1{ii_d}{n} = 1 ./ (self.fwd_weights{ii_d}{n} + (self.fwd_weights{ii_d}{n} < tol));
                end

                sigma1_1{ii_d} = sigma1{ii_d};
                for n = 1:numel(sigma1_1{ii_d})
                    sigma1_1{ii_d}{n} = 1 ./ (1 + sigma1_1{ii_d}{n});
                end
            end

            num_geoms = self.get_number_geometries();
            tau = cell(size(self.bwd_weights));
            switch (upper(algo))
                case '6DLS'
                    for n = 1:num_geoms
                        tau{n} = cast(- 1 ./ (self.bwd_weights{n}), self.data_type);
                        tau{n} = cast(- 1 ./ (self.bwd_weights{n} + 1 * self.lambda_l1), self.data_type);
                        tau{n} = cast(- 1 ./ (self.bwd_weights{n} + 6 * self.lambda_tv), self.data_type);
                        tau{n} = cast(- 1 ./ (self.bwd_weights{n} + 6 * self.lambda_tv + 1 * self.lambda_l1), self.data_type);
        function [p, p1, q_l1, q_tv] = init_cp_vars(self, algo)
        % Allocates the variables for cp():
        %   p    - per detector, zeros shaped like self.blobs(d).data
        %   p1   - a copy of the current solution (used by cp() as the
        %          initial nextEnhancedSolution)
        %   q_l1 - zeros like the solution (l1 algorithms), a single zero
        %          volume ('LS' algorithms), or [] otherwise
        %   q_tv - zero volumes, num_groups x 3 for the 'groups' strategy,
        %          otherwise 1 x 3
        % NOTE(review): the guard around the allocation of p, the remaining
        % branch(es) for q_tv and the closing `end`s are missing in this copy.
                num_det = self.get_number_detectors();
                p = cell(num_det, 1);
                for ii_d = 1:num_det
                    p{ii_d} = gtMathsGetSameSizeZeros(self.blobs(ii_d).data);
                end
            do_tv_update = ~isempty(strfind(upper(algo), 'TV'));
            do_l1_update = ~isempty(strfind(upper(algo), 'L1'));
            do_ls_update = ~isempty(strfind(upper(algo), 'LS'));

            if (self.algo_ops_c_functions)
                p1 = gtCxxMathsCellCopy(self.currentSolution, 'threads', self.num_threads);
            else
                p1 = self.currentSolution;
            end
            if (do_l1_update)
                q_l1 = gtMathsGetSameSizeZeros(self.currentSolution);
            elseif (do_ls_update)
                q_l1 = gtMathsGetSameSizeZeros(self.currentSolution{1});
            else
                q_l1 = [];
            end

            if (do_tv_update)
                if (strcmpi(self.tv_strategy, 'groups'))
                    num_groups = size(self.orientation_groups, 1);
                    q_tv = gtMathsGetSameSizeZeros(self.currentSolution(ones(num_groups, 3)));
                    q_tv = reshape(q_tv, num_groups, 3);
                else
                    q_tv = gtMathsGetSameSizeZeros(self.currentSolution([1 1 1]));
        function [proj_data, psf_time] = apply_psf(self, proj_data, is_direct)
        % Applies each detector's PSF(s) to proj_data{d}{n}: when a
        % detector has a single PSF it is used for all blobs, otherwise PSF
        % n is used for blob n. is_direct is forwarded to GtPSF.apply.
        % Detectors without PSF are left unchanged. psf_time starts as []
        % and is presumably the accumulated elapsed time (callers pass it
        % to statistics.add_timestamp).
        % NOTE(review): the lines accumulating psf_time and the closing
        % `end`s are missing in this copy.
            num_det = self.get_number_detectors();

            psf_time = [];
            for ii_d = 1:num_det
                if (~isempty(self.psf{ii_d}))
                    c = tic();
                    for n = 1:numel(proj_data{ii_d})
                        if (numel(self.psf{ii_d}) == 1)
                            proj_data{ii_d}{n} = self.psf{ii_d}{1}.apply(proj_data{ii_d}{n}, is_direct);
                        else
                            proj_data{ii_d}{n} = self.psf{ii_d}{n}.apply(proj_data{ii_d}{n}, is_direct);
                        end
        function [value, summed_vols] = compute_functional_value(self, p, algo)
        % Evaluates the objective for `algo` at the current solution:
        %   starts from gtMathsDotProduct(p, self.blobs);
        %   '6DLS' / '6DL1'  : + l2 norm of (forward projection - blobs)
        %   '6DTV' / '6DTVL1': + lambda_tv * l1 norm of the gradient of the
        %                      mean volume (sum of volumes / num_geoms)
        %   '6DL1' / '6DTVL1': + lambda_l1 * l1 norm of the solution
        % Also returns summed_vols, the sum of all current volumes.
        % NOTE(review): the algo_ops_c_functions guard (and its fallback)
        % around the subtraction and several `end`s are missing in this copy.
            value = gtMathsDotProduct(p, self.blobs);
            num_geoms = self.get_number_geometries();
            if (any(strcmpi(algo, {'6DLS', '6DL1'})))
                proj_bls = self.compute_fwd_projection(self.currentSolution, false);
                proj_bls = cat(1, proj_bls{:});
                det_blobs = cat(1, self.blobs(:).data);
                    proj_bls = gtCxxMathsCellMinus(proj_bls, det_blobs, 'copy', false, 'threads', self.num_threads);
                value = value + gtMathsNorm_l2(proj_bls);
            if (self.algo_ops_c_functions)
                summed_vols = gtCxxMathsSumCellVolumes(self.currentSolution);
            else
                summed_vols = gtMathsSumCellVolumes(self.currentSolution);
            end
            if (any(strcmpi(algo, {'6DTV', '6DTVL1'})))
                sES = gtMathsGradient(summed_vols / num_geoms);
                value = value + self.lambda_tv * gtMathsNorm_l1(sES);
            if (any(strcmpi(algo, {'6DL1', '6DTVL1'})))
                value = value + self.lambda_l1 * gtMathsNorm_l1(self.currentSolution);
            end
        end

        function initializeVariables(self, numIters)
        % Prepares self.normResiduals for the requested iterations: a new
        % numIters x 1 zero vector for a scalar count; otherwise zeros are
        % written at the requested indices (a [first last] range or an
        % explicit list), keeping any other existing entries.
            switch (numel(numIters))
                case 1
                    self.normResiduals = zeros(numIters, 1);
                case 2
                    self.normResiduals(numIters(1):numIters(2)) = 0;
                otherwise
                    self.normResiduals(numIters) = 0;
            end
        end
    end
end