Skip to content
Snippets Groups Projects
Gt6DBlobReconstructor.m 33.1 KiB
Newer Older
classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector
    properties
        % Data
        % Detector blobs (cell array) the reconstruction is fitted to.
        blobs;

        % Solution
        % Current reconstructed volumes (cell array, one entry per geometry when
        % allocated by the constructor).
        currentSolution = {};

        % Variables to plot
        normInitialResidual;
        % Functional values sampled during verbose runs of cp.
        normResiduals;

        % Options
        verbose = false;
        % Data-fidelity norm on the detector; one of possible_detector_norms.
        detector_norm = Gt6DBlobReconstructor.possible_detector_norms{1};
        % Norm used by the TV term; one of possible_tv_norms.
        tv_norm = Gt6DBlobReconstructor.possible_tv_norms{1};
        % How orientations are grouped for the TV term; one of possible_tv_strategies.
        tv_strategy = Gt6DBlobReconstructor.possible_tv_strategies{1};
        % Use the compiled gtCxx* helpers instead of the MATLAB fallbacks.
        algo_ops_c_functions = true;
        % NOTE(review): currentDetDual is used by this class but not declared
        % here; presumably it is inherited from Gt6DVolumeToBlobProjector - confirm.
    end

    properties (Constant)
        % Values accepted for detector_norm, tv_norm and tv_strategy.
        possible_detector_norms = {'l2', 'kl', 'l1'};
        possible_tv_norms = {'l12', 'l1', 'ln'};
        possible_tv_strategies = {'groups', 'volume'};
    end

    methods (Access = public)
        function self = Gt6DBlobReconstructor(volumes, blobs, varargin)
            % Gt6DBlobReconstructor  Build a blob reconstructor.
            %
            % Inputs:
            %   volumes  - either a cell array of initial volumes (the size of the
            %              first one gives the volume size, and the cell array
            %              becomes the initial solution), or a volume size vector,
            %              in which case one zero volume of self.data_type is
            %              allocated per geometry.
            %   blobs    - cell array of blobs; all must agree in dimensions 1
            %              and 3 (checked by getProjSize). Their dimension 2 is
            %              passed on as the blob depths.
            %   varargin - options forwarded to Gt6DVolumeToBlobProjector.
            %
            % Raises [mfilename ':wrong_argument'] when tv_norm, tv_strategy or
            % detector_norm is not one of the allowed values.
            fprintf('Initializing BLOB Recontruction:\n - Setup..');
            ct = tic();
            c = ct;
            blobs_depths = cellfun(@(x)size(x, 2), blobs);
            proj_size = Gt6DBlobReconstructor.getProjSize(blobs);

            if (iscell(volumes))
                [vols_size(1), vols_size(2), vols_size(3)] = size(volumes{1});
            else
                vols_size = volumes;
            end
            self = self@Gt6DVolumeToBlobProjector(vols_size, proj_size, blobs_depths, varargin{:});

            % Let's complain about wrong options
            if (~ismember(self.tv_norm, self.possible_tv_norms))
                error([mfilename ':wrong_argument'], ...
                    'TV-norm: %s is not allowed! Use one of [%s] instead', ...
                    self.tv_norm, sprintf(' "%s"', self.possible_tv_norms{:}))
            end
            if (~ismember(self.tv_strategy, self.possible_tv_strategies))
                error([mfilename ':wrong_argument'], ...
                    'TV-strategy: %s is not allowed! Use one of [%s] instead', ...
                    self.tv_strategy, sprintf(' "%s"', self.possible_tv_strategies{:}))
            end
            if (~ismember(self.detector_norm, self.possible_detector_norms))
                error([mfilename ':wrong_argument'], ...
                    'Detector-norm: %s is not allowed! Use one of [%s] instead', ...
                    self.detector_norm, sprintf(' "%s"', self.possible_detector_norms{:}))
            end

            % Timing tasks and sub-tasks reported by the CP update functions.
            self.statistics.add_task('cp_dual_update_detector', 'CP Dual variable (detector) update');
            self.statistics.add_task_partial('cp_dual_update_detector', 'cp_dual_detector_FP', 'Forward Projection');
            self.statistics.add_task_partial('cp_dual_update_detector', 'cp_dual_detector_SB', 'Sinograms -> Blobs');
            self.statistics.add_task_partial('cp_dual_update_detector', 'cp_dual_detector_IN', 'Projected blobs initialization');
            self.statistics.add_task_partial('cp_dual_update_detector', 'cp_dual_detector_PSF', 'PSF application');
            self.statistics.add_task_partial('cp_dual_update_detector', 'cp_dual_detector_prox', 'Proximal application');
            self.statistics.add_task('cp_dual_update_l1', 'CP Dual variable (l1) update');
            self.statistics.add_task('cp_dual_update_tv', 'CP Dual variable (TV) update');
            self.statistics.add_task_partial('cp_dual_update_tv', 'cp_dual_tv_reduction', 'Volumes reduction');
            self.statistics.add_task_partial('cp_dual_update_tv', 'cp_dual_tv_gradient', 'Gradient');
            self.statistics.add_task_partial('cp_dual_update_tv', 'cp_dual_tv_proximal', 'Proximal application');
            self.statistics.add_task_partial('cp_dual_update_tv', 'cp_dual_tv_divergence', 'Divergence');
            self.statistics.add_task('cp_primal_update', 'CP Primal variable update');
            self.statistics.add_task_partial('cp_primal_update', 'cp_primal_BP', 'Back Projection');
            self.statistics.add_task_partial('cp_primal_update', 'cp_primal_BS', 'Blobs -> Sinograms');
            self.statistics.add_task_partial('cp_primal_update', 'cp_primal_PSF', 'PSF application');
            self.statistics.add_task_partial('cp_primal_update', 'cp_primal_CORR', 'Primal correction computation');
            self.statistics.add_task_partial('cp_primal_update', 'cp_primal_APP', 'Primal update application');
            self.blobs = blobs;
            if (iscell(volumes))
                self.currentSolution = volumes;
            else
                num_geoms = self.get_number_geometries();
                % Bytes per voxel follow the configured data type (4 for
                % 'single', 8 for 'double') instead of always assuming 4.
                bytes_per_voxel = numel(typecast(zeros(1, 1, self.data_type), 'uint8'));
                fprintf('\b\b: Done. (%2.1f s)\n - Creating initial volumes (size: [%s] x %d bytes x %d vols = %g GB): ', ...
                    toc(c), strjoin(arrayfun(@(x){sprintf('%d', x)}, vols_size), ', '), ...
                    bytes_per_voxel, num_geoms, prod(vols_size) * bytes_per_voxel * num_geoms / 2^30)
                c = tic();
                self.currentSolution = cell(num_geoms, 1);
                for ii = 1:num_geoms
                    num_chars = fprintf('%03d/%03d', ii, num_geoms);
                    self.currentSolution{ii} = zeros(vols_size, self.data_type);
                    fprintf(repmat('\b', [1 num_chars]));
                end
            end
            % Replace each raw PSF kernel in self.psf by an initialized GtPSF object.
            if (~isempty(self.psf))
                fprintf('\b\b: Done. (%2.1f s)\n - Initializing PSF(s)..', toc(c))
                c = tic();
                for ii_p = 1:numel(self.psf)
                    psf_d = self.psf{ii_p};
                    self.psf{ii_p} = GtPSF();
                    self.psf{ii_p}.set_psf_direct(psf_d, proj_size);
                end
            end

            fprintf('\b\b: Done. (%2.1f s)\n - Weights:\n', toc(c));
            c = tic();
            self.initializeWeights();
            fprintf('   Done. (%2.1f s)\nTot (%3.1f s)).\n', toc(c), toc(ct));
        end

        function solution = getCurrentSolution(self)
            % getCurrentSolution  Return the current cell array of reconstructed volumes.
            solution = self.currentSolution;
        end
        function [proj_blobs, proj_spots] = getProjectionOfCurrentSolution(self)
            % getProjectionOfCurrentSolution  Forward projection of the current solution.
            %
            % Outputs:
            %   proj_blobs - blobs returned by compute_fwd_projection (statistics
            %                are not recorded for this call).
            %   proj_spots - each blob summed over its dimension 2. The sums are
            %                stacked along dimension 3, one slice per blob.
            proj_blobs = self.compute_fwd_projection(self.currentSolution, false);

            num_blobs = numel(proj_blobs);
            spots = cell(1, num_blobs);
            for ii_b = 1:num_blobs
                spots{ii_b} = sum(proj_blobs{ii_b}, 2);
            end
            stacked = cat(2, spots{:});
            proj_spots = permute(stacked, [1 3 2]);
        end

        function reInit(self, blobs, volumes)
            % reInit  Reset the reconstruction state, optionally with new data.
            %
            % Inputs:
            %   blobs   - (optional) new blobs; ignored when omitted or empty.
            %   volumes - (optional) new initial volumes, exactly one per geometry,
            %             each of size self.volume_geometry. When omitted or empty,
            %             the current solution is reset to zeros of the same sizes.
            %
            % Afterwards the weights are recomputed, the stored detector dual
            % (self.currentDetDual) is cleared and the statistics are reset.
            % Raises 'Gt6DBlobReconstructor:wrong_argument' on a wrong number or
            % size of volumes.
            if (exist('volumes', 'var') && ~isempty(volumes))
                num_geoms = self.get_number_geometries();
                num_vols = numel(volumes);
                if (num_vols ~= num_geoms)
                    error('Gt6DBlobReconstructor:wrong_argument', ...
                        'Wrong new volumes number. Expected: %d, got: %d', ...
                        num_geoms, num_vols)
                end
                for n = 1:num_geoms
                    if (any(size(volumes{n}) ~= self.volume_geometry))
                        error('Gt6DBlobReconstructor:wrong_argument', ...
                            'Wrong volume size for new volume: %d. Expected: (%s), got: (%s)', ...
                            n, sprintf(' %d', self.volume_geometry), ...
                            sprintf(' %d', size(volumes{n})) )
                    end
                end
                self.currentSolution = volumes;
            else
                self.currentSolution = gtMathsGetSameSizeZeros(self.currentSolution);
            end

            if (exist('blobs', 'var') && ~isempty(blobs))
                self.blobs = blobs;
            end

            fprintf('Recomputing Weights:\n');
            c = tic();
            self.initializeWeights();
            fprintf('- Done in %3.1f s.\n', toc(c));

            self.currentDetDual = {};
            self.statistics.clear();
        end
        function cp_l1(self, numIters, lambda)
            % cp_l1  Run the CP reconstruction with an l1 term ('6DL1').
            %   numIters - iteration specification, as accepted by cp.
            %   lambda   - weight of the l1 term.
            self.cp('6DL1', numIters, 'lambda_l1', lambda);
        end
        function cp_tv(self, numIters, lambda)
            % cp_tv  Run the CP reconstruction with a TV term ('6DTV').
            %   numIters - iteration specification, as accepted by cp.
            %   lambda   - weight of the TV term.
            self.cp('6DTV', numIters, 'lambda_tv', lambda);
        end
        function cp_tvl1(self, numIters, lambda)
            % cp_tvl1  Entry point for a combined TV + l1 reconstruction.
            % NOTE(review): the body and closing `end` of this method are missing
            % from this copy. Presumably it forwards to self.cp with the '6DTVL1'
            % algorithm name (which compute_functional_value recognizes) and
            % sets both 'lambda_l1' and 'lambda_tv' from lambda - confirm against
            % the upstream source before use.
        function cp(self, algo, numIters, varargin)
            % cp  Run the CP primal-dual reconstruction loop.
            %
            % Inputs:
            %   algo     - algorithm name (e.g. '6DLS'); the TV / l1 dual updates
            %              are enabled when the upper-cased name contains 'TV' / 'L1'.
            %   numIters - scalar N (iterations 1:N), two-element [first last], or
            %              an explicit list of iteration indices.
            %   varargin - property/value pairs 'lambda_l1' and 'lambda_tv'
            %              (parsed with parse_pv_pairs; both default to []).
            %
            % Each iteration forward-projects the extrapolated solution, applies
            % the PSF and updates the detector dual, optionally updates the TV and
            % l1 duals, then back-projects and updates the primal
            % (self.currentSolution) via compute_bwd_projection. With
            % self.verbose, the functional value is sampled every sample_rate
            % iterations into self.normResiduals and plotted at the end.
            %
            % NOTE(review): this copy of the method is incomplete (missing `if`
            % guards, `case` labels and `end` keywords, flagged below) and will
            % not parse as-is - restore it from the upstream source.
            conf = struct('lambda_l1', [], 'lambda_tv', []);
            conf = parse_pv_pairs(conf, varargin);

            do_tv_update = ~isempty(strfind(upper(algo), 'TV'));
            do_l1_update = ~isempty(strfind(upper(algo), 'L1'));

            self.initializeVariables(numIters);
            % normResiduals is sampled once every sample_rate iterations.
            sample_rate = 5;

            fprintf('Initializing CP_%s weights: ', upper(algo))
            c = tic();
            [sigma1, sigma1_1, tau] = self.init_cp_weights(algo);
            fprintf('Done (%g seconds).\nInitializing CP_%s variables: ', toc(c), upper(algo))
            c = tic();
            [p, nextEnhancedSolution, q_l1, q_tv] = self.init_cp_vars(algo);
            fprintf('Done (%g seconds).\n', toc(c))
            fprintf('Reconstruction using algorithm: %s\n', upper(algo));
            fprintf(' - Detector data-fidelity term: %s\n', self.detector_norm);
            % NOTE(review): the TV settings below are printed under
            % do_l1_update; presumably they belonged to a separate
            % do_tv_update check - confirm.
            if (do_l1_update)
                fprintf(' - l1-term lambda: %g\n', conf.lambda_l1);
                fprintf(' - TV-term lambda: %g\n', conf.lambda_tv);
                fprintf(' - TV norm: %s\n', self.tv_norm);
                fprintf(' - TV strategy: %s\n', self.tv_strategy);
            end
            fprintf('Iteration: ');

            c = tic();
            % Expand the iteration specification into the list of indices.
            switch (numel(numIters))
                case 1
                    iters = 1:numIters;
                case 2
                    iters = numIters(1):numIters(2);
                otherwise
                    iters = numIters;
            end

            tot_iters = numel(iters);
            for ii = iters
                numchars = fprintf('%03d/%03d (in %f s)', ii-min(iters), tot_iters, toc(c));

                % Computing update of dual detector
                self.statistics.tic('cp_dual_update_detector');
                proj_bls = self.compute_fwd_projection(nextEnhancedSolution);
                [proj_bls, psf_time] = self.apply_psf(proj_bls, true);
                self.statistics.add_timestamp(psf_time, 'cp_dual_update_detector', 'cp_dual_detector_PSF');
                p = self.update_dual_detector(p, proj_bls, sigma1, sigma1_1);
                self.statistics.toc('cp_dual_update_detector');

                % Computing update dual TV
                % NOTE(review): the `if (do_tv_update)` opening this section is
                % missing; the `end` below has no matching `if`.
                    self.statistics.tic('cp_dual_update_tv');
                    [q_tv, mdiv_tv] = self.update_dual_TV(q_tv, nextEnhancedSolution, conf.lambda_tv);
                    self.statistics.toc('cp_dual_update_tv');
                end

                % Computing update dual l1
                % NOTE(review): the `if (do_l1_update)` opening this section is
                % missing; the `end` below has no matching `if`.
                    self.statistics.tic('cp_dual_update_l1');
                    q_l1 = self.update_dual_l1(q_l1, nextEnhancedSolution);
                    self.statistics.toc('cp_dual_update_l1');
                end
                % Function for computing the update to apply
                % NOTE(review): only the '6DLS' label survives below. Missing:
                % the '6DLS' assignment itself (its comment says q_l1 is a single
                % zero volume), the case labels of the three assignments that
                % follow (by their operands: l1-only, TV-only, TV + l1), and the
                % closing `end` of this switch.
                switch (upper(algo))
                    case '6DLS'
                        % q_l1 is just a volume of zeros
                        compute_update_primal = @(ii)gtCxxMathsCellTimes(...
                            q_l1(ii), conf.lambda_l1, 'threads', self.num_threads);
                        compute_update_primal = @(ii)get_divergence(self, mdiv_tv, ii);
                        compute_update_primal = @(ii)gtCxxMathsCellPlus(...
                            get_divergence(self, mdiv_tv, ii), q_l1(ii), ...
                            'scale', conf.lambda_l1, 'threads', self.num_threads);
                % Computing update primal
                self.statistics.tic('cp_primal_update');
                [p_blurr, psf_time] = self.apply_psf(p, false);
                self.statistics.add_timestamp(psf_time, 'cp_primal_update', 'cp_primal_PSF');
                nextEnhancedSolution = self.compute_bwd_projection(...
                    nextEnhancedSolution, p_blurr, compute_update_primal, tau);
                self.statistics.toc('cp_primal_update');

                % Erase the progress counter printed at the top of the iteration.
                fprintf('%s', repmat(sprintf('\b'), 1, numchars))

                if (self.verbose && (mod(ii, sample_rate) == 0))
                    self.normResiduals(ii / sample_rate) ...
                        = self.compute_functional_value(p, algo, conf.lambda_l1);
                % NOTE(review): the `end` closing this `if` appears to be missing;
                % by indentation, the `end` below closes the iteration loop.
            end
            fprintf('(%04d) Done in %f seconds.\n', tot_iters, toc(c) )

            self.printStats();

            if (self.verbose)
                figure;
                subplot(1, 2, 1), plot(self.normResiduals(1:floor(max(iters) / sample_rate)));
                subplot(1, 2, 2), semilogy(abs(self.normResiduals));
            end
            % NOTE(review): nothing visible stores the final detector dual `p`
            % back into self.currentDetDual (which init_cp_vars reads); confirm
            % whether that assignment was lost. The method's closing `end` is
            % also missing.

        function num_bytes = get_peak_memory_consumption(self, algo)
            % get_peak_memory_consumption  Estimated peak memory (bytes) for algo.
            %
            % Adds blob- and volume-sized contributions (detector dual, enhanced
            % solution, algorithm-dependent extras) to a base amount. Unknown
            % algorithms raise 'Gt6DBlobReconstructor:wrong_argument'.
            %
            % NOTE(review): base_bytes is never computed in this copy (the three
            % comments below presumably described its terms). The switch header
            % on algo, the first `case` label before "Dual variable in
            % real-space", and possibly a statement after its "Computation
            % temporaries" comment are also missing.
            bytes_blobs = self.get_memory_consumption_blobs();
            bytes_vols = self.get_memory_consumption_volumes();

            % currentSolution, blobs, and ASTRA's copies
            % Contribution of the projection weights
            % Dual variable on detector
            num_bytes = base_bytes + bytes_blobs;
            % EnhancedSolution
            num_bytes = num_bytes + bytes_vols;
                    % Dual variable in real-space
                    num_bytes = num_bytes + bytes_vols;
                    % Computation temporaries
                case {'6DLS', '6DTV'}
                    % Computation temporaries
                    num_bytes = num_bytes + bytes_blobs;
                otherwise
                    error('Gt6DBlobReconstructor:wrong_argument', ...
                        'No such algorithm as "%s"', algo)
            end
        end

        function num_bytes = get_memory_consumption_blobs(self)
            % get_memory_consumption_blobs  Bytes held by self.blobs, as reported
            % by GtBenchmarks.getSizeVariable.
            stored_blobs = self.blobs;
            num_bytes = GtBenchmarks.getSizeVariable(stored_blobs);
        end

        function num_bytes = get_memory_consumption_volumes(self)
            % get_memory_consumption_volumes  Bytes held by self.currentSolution,
            % as reported by GtBenchmarks.getSizeVariable.
            solution_vols = self.currentSolution;
            num_bytes = GtBenchmarks.getSizeVariable(solution_vols);
        end

        function num_bytes = get_memory_consumption_sinograms(self)
            % get_memory_consumption_sinograms  Estimated bytes of one
            % single-precision sinogram per geometry.
            %
            % Each sinogram is sized [proj_size(1), size(geometry, 1), proj_size(2)].
            % The super-sampled geometries (self.ss_geometries) are used when they
            % are set, otherwise self.geometries.
            float_bytes = GtBenchmarks.getSizeVariable(zeros(1, 1, 'single'));
            if (isempty(self.ss_geometries))
                geoms = self.geometries;
            else
                geoms = self.ss_geometries;
            end

            num_bytes = 0;
            num_geoms = self.get_number_geometries();
            for n = 1:num_geoms
                sinogram_size = [ ...
                    self.proj_size(1), size(geoms{n}, 1), self.proj_size(2)];
                num_bytes = num_bytes + float_bytes * prod(sinogram_size);
            end
        end
    %%% Primal and Duals updating functions
        function [q, mdivq] = update_dual_TV(self, q, new_enh_sol, lambda)
            % update_dual_TV  Dual update for the TV term.
            %
            % Inputs:
            %   q           - TV dual, a cell indexed q{group, component} with 3
            %                 components per group.
            %   new_enh_sol - cell array of extrapolated primal volumes, indexed
            %                 by the ranges in self.orientation_groups.
            %   lambda      - (optional, default 1) TV weight applied to the
            %                 returned divergence.
            %
            % Outputs:
            %   q     - dual after the gradient step and the projection.
            %   mdivq - per group, -lambda * gtMathsDivergence(q(group, :)).
            %
            % Groups come from self.tv_strategy: 'volume' treats indices
            % 1:self.orientation_groups(end) as a single group, 'groups' uses each
            % row [first, last] of self.orientation_groups. For each group the
            % member volumes are summed, the gradient of the sum is scaled by the
            % group's sigma and added to q, then q is projected according to
            % self.tv_norm.
            %
            % NOTE(review): this copy is incomplete - the switch on self.tv_norm
            % and its first two case labels, several `end`s, two loop headers in
            % the 'ln' branch and the setup of the divergence loop are missing
            % (flagged below).
            if (~exist('lambda', 'var') || isempty(lambda))
                lambda = 1;
            end

            switch (self.tv_strategy)
                case 'volume'
                    num_groups = 1;
                    or_ranges = {1:self.orientation_groups(end)};
                    sigmas = 1 ./  (2 * numel(new_enh_sol));
                case 'groups'
                    num_groups = size(self.orientation_groups, 1);
                    or_ranges = cell(num_groups, 1);
                    % One step per group, inversely proportional to its size.
                    sigmas = 1 ./ (2 .* (self.orientation_groups(:, 2) - self.orientation_groups(:, 1) + 1));
                    for ii_g = 1:num_groups
                        or_ranges{ii_g} = self.orientation_groups(ii_g, 1):self.orientation_groups(ii_g, 2);
                    end
            end
            reduction_time = 0;
            gradient_time = 0;
            proximal_time = 0;
            dsES = cell(num_groups, 3);
            mdivq = cell(num_groups, 1);

            for ii_g = 1:num_groups
                % Sum the volumes belonging to this group.
                c = tic();
                if (self.algo_ops_c_functions)
                    sES = gtCxxMathsSumCellVolumes(new_enh_sol(or_ranges{ii_g}));
                else
                    sES = gtMathsSumCellVolumes(new_enh_sol(or_ranges{ii_g}));
                end
                reduction_time = reduction_time + toc(c);

                c = tic();
                dsES(ii_g, :) = gtMathsGradient(sES);
                gradient_time = gradient_time + toc(c);

                % Dual ascent step along the 3 gradient components.
                c = tic();
                for ii_d = 1:3
                    q{ii_g, ii_d} = q{ii_g, ii_d} + sigmas(ii_g) * dsES{ii_g, ii_d};
                end
                proximal_time = proximal_time + toc(c);
            end
            % NOTE(review): missing here: the switch on self.tv_norm and,
            % presumably, its 'l12' case. The lines below divide each group's
            % 3 components by max(1, their point-wise l2 norm).
                    for ii_g = 1:num_groups
                        grad_l2 = sqrt(q{ii_g, 1} .^ 2 + q{ii_g, 2} .^ 2 + q{ii_g, 3} .^ 2);
                        for ii_d = 1:3
                            q{ii_g, ii_d} = q{ii_g, ii_d} ./ max(1, grad_l2);
                        end
                    % NOTE(review): missing `end` of the group loop above and,
                    % presumably, the 'l1' case label. The lines below clip every
                    % element to [-1, 1].
                    for n = 1:numel(q)
                        q{n} = q{n} ./ max(1, abs(q{n}));
                    end
                case 'ln'
                    % Per voxel, the num_groups x 3 matrix of dual values is
                    % rescaled through the eigen-decomposition of z' * z.
                    size_vols = size(q{1});
                    num_voxels = numel(q{1});
                    z = zeros(num_groups, 3, num_voxels, self.data_type);
                    % NOTE(review): the `for ii_g = 1:num_groups` and
                    % `for ii_d = 1:3` headers matching the next line and its two
                    % `end`s are missing.
                            z(ii_g, ii_d, :) = reshape(q{ii_g, ii_d}, [], 1);
                        end
                    end
                    ZZ = gtMathsMatrixProduct(permute(z, [2 1 3]), z);
                    if (false) % <- this is always correct
                        Vts = zeros(3, 3, num_voxels);
                        es = zeros(1, 3, num_voxels);
                        for ii_v = 1:num_voxels
                            [Vts(:, :, ii_v), E] = eigs(ZZ(:, :, ii_v));
                            es(1, :, ii_v) = diag(E);
                        end
                    else % <- this can have problems with only one volume
                        [es, Vts] = gtMathsEig3x3SymmPosDef(ZZ);
                    end
                    svals = sqrt(es); % <- they're always positive
                    Vs = permute(Vts, [2 1 3]);
                    % Singular values above 1 are scaled down to 1; zero ones stay zero.
                    sigmas = 1 ./ max(1, svals);
                    sigmas(svals == 0) = 0;
                    sigmas = bsxfun(@times, sigmas, Vts);
                    sigmas = gtMathsMatrixProduct(Vs, sigmas);
                    z = gtMathsMatrixProduct(z, sigmas);
                    for ii_g = 1:num_groups
                        for ii_d = 1:3
                            q{ii_g, ii_d} = reshape(z(ii_g, ii_d, :), size_vols);
                        % NOTE(review): missing here: the two `end`s of the loops
                        % above, the `end` of the tv_norm switch, presumably
                        % `c = tic();` (read by toc(c) below), and the
                        % `for ii_g = 1:num_groups` header of the divergence loop.
                mdivq{ii_g} = - lambda * gtMathsDivergence(q(ii_g, :));
            end
            divergence_time = toc(c);

            self.statistics.add_timestamp(reduction_time, 'cp_dual_update_tv', 'cp_dual_tv_reduction');
            self.statistics.add_timestamp(proximal_time, 'cp_dual_update_tv', 'cp_dual_tv_proximal');
            self.statistics.add_timestamp(gradient_time, 'cp_dual_update_tv', 'cp_dual_tv_gradient');
            self.statistics.add_timestamp(divergence_time, 'cp_dual_update_tv', 'cp_dual_tv_divergence');
        end

        function mdivq_ii = get_divergence(self, mdivq, sel_inds)
            % get_divergence  Select the TV divergence term(s) for indices sel_inds.
            %
            % For the 'groups' strategy, each index in sel_inds is matched to the
            % rows of self.orientation_groups whose [first, last] range contains
            % it. find() scans column by column, so the result follows the order
            % of sel_inds (one entry per index if the ranges do not overlap -
            % NOTE(review): overlap is assumed not to happen, confirm).
            %
            % NOTE(review): the 'volume' / single-group branch has no body in this
            % copy, so `ind` is never assigned there. The `end` of the if-chain
            % and of the function are also missing.
            num_groups = size(self.orientation_groups, 1);
            if (strcmpi(self.tv_strategy, 'volume') || (num_groups == 1))
                % NOTE(review): presumably every index maps to the single group
                % here (e.g. ones(numel(sel_inds), 1)) - body missing, confirm.
            elseif (strcmpi(self.tv_strategy, 'groups'))
                sel_inds = reshape(sel_inds, 1, []);
                ind = bsxfun(@ge, sel_inds, self.orientation_groups(:, 1)) ...
                    & bsxfun(@le, sel_inds, self.orientation_groups(:, 2));
                [ind, ~] = find(ind);
%                 ind = find( ...
%                     (sel_inds >= self.orientation_groups(:, 1)) ...
%                     & (sel_inds <= self.orientation_groups(:, 2)), 1);
            mdivq_ii = mdivq(ind);
        function p = update_dual_detector(self, p, proj_bls, sigma, sigma_1)
            % update_dual_detector  Dual step plus proximal map for the detector term.
            %
            % Inputs:
            %   p        - current detector dual (cell array, one entry per blob).
            %   proj_bls - forward projection of the extrapolated solution (same
            %              layout as p).
            %   sigma    - per-blob dual step sizes.
            %   sigma_1  - per-blob scaling, used only by the 'l2' norm.
            %
            % The proximal map follows self.detector_norm:
            %   'kl' - closed-form Kullback-Leibler proximal map;
            %   'l2' - (p + sigma .* (proj - blobs)) .* sigma_1;
            %   'l1' - p + sigma .* (proj - blobs), then clipped to [-1, 1].
            % With self.algo_ops_c_functions the work is delegated to
            % gtCxx6DUpdateDualDetector. The elapsed time is recorded under
            % 'cp_dual_detector_prox'.
            c = tic();
            if (self.algo_ops_c_functions)
                if (strcmpi(self.detector_norm, 'l2'))
                    p = gtCxx6DUpdateDualDetector(self.detector_norm, p, self.blobs, proj_bls, sigma, sigma_1, 'threads', self.num_threads);
                else
                    p = gtCxx6DUpdateDualDetector(self.detector_norm, p, self.blobs, proj_bls, sigma, 'threads', self.num_threads);
                end
            else
                switch (lower(self.detector_norm))
                    case 'kl'
                        for n = 1:numel(p)
                            temp_p = p{n} + sigma{n} .* proj_bls{n};
                            temp_bls = 4 .* sigma{n} .* self.blobs{n};
                            temp_d = (temp_p - 1) .^ 2 + temp_bls;
                            temp_d = sqrt(temp_d);
                            p{n} = (1 + temp_p - temp_d) * 0.5;
                        end
                    case 'l2'
                        for n = 1:numel(p)
                            p{n} = p{n} + sigma{n} .* (proj_bls{n} - self.blobs{n});
                            p{n} = p{n} .* sigma_1{n};
                        end
                    case 'l1'
                        for n = 1:numel(p)
                            p{n} = p{n} + sigma{n} .* (proj_bls{n} - self.blobs{n});
                            p{n} = p{n} ./ max(1, abs(p{n}));
                        end
                    otherwise
                        % Previously an unknown norm silently left p unchanged.
                        error('Gt6DBlobReconstructor:wrong_argument', ...
                            'Detector-norm: %s is not allowed! Use one of [%s] instead', ...
                            self.detector_norm, sprintf(' "%s"', self.possible_detector_norms{:}))
                end
            end
            self.statistics.add_timestamp(toc(c), 'cp_dual_update_detector', 'cp_dual_detector_prox');
        end
        function q = update_dual_l1(self, q, new_enh_sol)
            % update_dual_l1  Dual step plus projection for the l1 term.
            %
            % Adds the extrapolated solution to each dual volume and projects the
            % result element-wise onto [-1, 1] (q ./ max(1, |q|)).
            % NOTE(review): assumes gtCxxMathsCellBoxOneL1 performs the same
            % unit-box projection as the MATLAB branch - confirm.
            if (self.algo_ops_c_functions)
                q = gtCxxMathsCellPlus(q, new_enh_sol, 'copy', false, 'threads', self.num_threads);
                q = gtCxxMathsCellBoxOneL1(q, 'copy', false, 'threads', self.num_threads);
            else
                for n = 1:numel(q)
                    q{n} = q{n} + new_enh_sol{n};
                    q{n} = q{n} ./ max(1, abs(q{n}));
                end
            end
        end
        function [curr_sol, curr_enh_sol] = update_primal(self, curr_sol, curr_enh_sol, correction, tau)
            % update_primal  Non-negative primal step with extrapolation.
            %
            % For each volume ii:
            %   v = max(0, curr_sol{ii} + correction{ii} .* tau{ii})
            %   curr_enh_sol{ii} = v + (v - curr_sol{ii})   (uses the old curr_sol)
            %   curr_sol{ii} = v
            % With self.algo_ops_c_functions the work is delegated to
            % gtCxx6DUpdatePrimal.
            if (self.algo_ops_c_functions)
                [curr_sol, curr_enh_sol] = gtCxx6DUpdatePrimal( ...
                    curr_sol, curr_enh_sol, correction, tau, 'threads', self.num_threads);
            else
                for ii = 1:numel(curr_sol)
                    v = curr_sol{ii} + correction{ii} .* tau{ii};
                    v(v < 0) = 0;
                    curr_enh_sol{ii} = v + (v - curr_sol{ii});
                    curr_sol{ii} = v;
                end
            end
        end
    end

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    %%% Projection functions
    methods (Access = protected)
        function proj_bls = compute_fwd_projection(self, volumes, apply_stats)
            % compute_fwd_projection  Forward-project volumes into blob space.
            %
            % Inputs:
            %   volumes     - cell array of volumes, one per geometry.
            %   apply_stats - (optional, default true) when true, the timings are
            %                 recorded under the 'cp_dual_update_detector' task.
            %
            % Output:
            %   proj_bls - cell array of projected blobs. It is allocated after
            %              self.blobs (zeros in the MATLAB path) and accumulated by
            %              fwd_project_volume, self.get_jobs_chunk_size()
            %              geometries at a time.
            if (~exist('apply_stats', 'var') || isempty(apply_stats))
                apply_stats = true;
            end

            % Allocation of the output blobs ('cp_dual_detector_IN').
            c = tic();
            if (self.algo_ops_c_functions)
                proj_bls = gtCxx6DCreateEmptyBlobs(self.blobs, 'threads', self.num_threads);
            else
                proj_bls = gtMathsGetSameSizeZeros(self.blobs);
            end
            timing_in = toc(c);

            timing_fp = 0;
            timing_sb = 0;

            num_geoms = self.get_number_geometries();
            chunk_size = self.get_jobs_chunk_size();
            for n = 1:chunk_size:num_geoms
                % The last chunk may be shorter than chunk_size.
                chunk_safe_size = min(chunk_size, (num_geoms - n + 1));
                inds = n:(n+chunk_safe_size-1);
                [proj_bls, fp_time, sb_time] = self.fwd_project_volume(proj_bls, volumes(inds), inds);

                timing_fp = timing_fp + fp_time;
                timing_sb = timing_sb + sb_time;
            end

            if (apply_stats)
                self.statistics.add_timestamp(timing_fp, 'cp_dual_update_detector', 'cp_dual_detector_FP');
                self.statistics.add_timestamp(timing_sb, 'cp_dual_update_detector', 'cp_dual_detector_SB');
                self.statistics.add_timestamp(timing_in, 'cp_dual_update_detector', 'cp_dual_detector_IN');
            end
        end

        function nextEnhancedSolution = compute_bwd_projection(self, nextEnhancedSolution, p_blurr, compute_update_primal, tau)
            % compute_bwd_projection  Back-project the detector dual and apply the primal step.
            %
            % Inputs:
            %   nextEnhancedSolution  - extrapolated primal volumes, updated chunk by chunk.
            %   p_blurr               - detector dual after the PSF (non-direct) step.
            %   compute_update_primal - function handle; given geometry indices it
            %                           returns a cell array with one row per index,
            %                           concatenated horizontally with the
            %                           back-projection before summing.
            %   tau                   - per-geometry primal step sizes.
            %
            % For each chunk of self.get_jobs_chunk_size() geometries, the
            % back-projection and compute_update_primal(inds) are summed along the
            % second cell dimension into the correction. update_primal then
            % updates self.currentSolution and nextEnhancedSolution. Timings are
            % reported under the 'cp_primal_update' task.
            timing_bp = 0;
            timing_bs = 0;
            timing_corr = 0;
            timing_app = 0;

            num_geoms = self.get_number_geometries();
            chunk_size = self.get_jobs_chunk_size();
            for n = 1:chunk_size:num_geoms
                % The last chunk may be shorter than chunk_size.
                chunk_safe_size = min(chunk_size, (num_geoms - n + 1));
                inds = n:(n+chunk_safe_size-1);
                [v, bp_time, bs_time] = self.bwd_project_volume(p_blurr, inds);
                timing_bp = timing_bp + bp_time;
                timing_bs = timing_bs + bs_time;

                % Correction = regularization contributions + back-projection.
                c = tic();
                up_prim = [compute_update_primal(inds), v];
                if (self.algo_ops_c_functions)
                    correction = gtCxxMathsSumCellVolumes(up_prim, 2);
                else
                    correction = gtMathsSumCellVolumes(up_prim, 2);
                end
                timing_corr = timing_corr + toc(c);

                c = tic();
                [self.currentSolution(inds), nextEnhancedSolution(inds)] ...
                    = self.update_primal(self.currentSolution(inds), nextEnhancedSolution(inds), correction, tau(inds));
                timing_app = timing_app + toc(c);
            end

            self.statistics.add_timestamp(timing_bp, 'cp_primal_update', 'cp_primal_BP');
            self.statistics.add_timestamp(timing_bs, 'cp_primal_update', 'cp_primal_BS');
            self.statistics.add_timestamp(timing_corr, 'cp_primal_update', 'cp_primal_CORR');
            self.statistics.add_timestamp(timing_app, 'cp_primal_update', 'cp_primal_APP');
        end
    end

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    %%% Initialization functions
    methods (Access = protected)
        function initializeWeights(self)
            % initializeWeights  Compute forward and back-projection weights.
            %
            % Forward weights: the absolute value of self.fwd_weights after the
            % PSF (direct) has been applied. Back-projection weights: per
            % geometry, the sum of the projection coefficients from self.offsets
            % or, with super-sampling, of all entries of self.ss_offsets{ii}.
            %
            % NOTE(review): this copy is incomplete. The log line announces a
            % projection of "ones" volumes, but the statement that fills
            % self.fwd_weights before the PSF step is missing. So are the `end`
            % and `else` between the two bwd_weights loops and the closing `end`s.
            fprintf('   + Projecting ones vols..');
            self.fwd_weights = self.apply_psf(self.fwd_weights, true);
            for ii_b = 1:numel(self.fwd_weights)
                self.fwd_weights{ii_b} = abs(self.fwd_weights{ii_b});
            end
            fprintf('   + Computing back-projection weights..');
            num_ors = self.get_number_geometries();
            if (~self.using_super_sampling())
                for ii = 1:num_ors
                    self.bwd_weights{ii} = sum(self.offsets{ii}.proj_coeffs);
                for ii = 1:num_ors
                    proj_structs = [self.ss_offsets{ii}{:}];
                    proj_weights = cat(2, proj_structs(:).proj_coeffs);
                    self.bwd_weights{ii} = sum(proj_weights);
        function [sigma1, sigma1_1, tau] = init_cp_weights(self, algo)
            % init_cp_weights  Step sizes for the CP iterations.
            %
            % Outputs:
            %   sigma1   - per-blob dual steps 1 ./ fwd_weights, with zero weights
            %              mapped to a step of 1 (avoids division by zero).
            %   sigma1_1 - per-blob 1 ./ (1 + sigma1), used by the 'l2' detector prox.
            %   tau      - per-geometry primal steps -1 ./ (bwd_weights + k), cast
            %              to self.data_type. k accounts for the extra operators of
            %              each algorithm: '6DLS' -> 0, '6DL1' -> 1, '6DTV' -> 6,
            %              '6DTVL1' -> 6 + 1.
            %
            % Raises 'Gt6DBlobReconstructor:wrong_argument' for an unknown algo.
            sigma1 = cell(size(self.fwd_weights));
            for n = 1:numel(sigma1)
                sigma1{n} = 1 ./ (self.fwd_weights{n} + (self.fwd_weights{n} == 0));
            end

            sigma1_1 = sigma1;
            for n = 1:numel(sigma1_1)
                sigma1_1{n} = 1 ./ (1 + sigma1_1{n});
            end

            switch (upper(algo))
                case '6DLS'
                    extra_weight = 0;
                case '6DL1'
                    extra_weight = 1;
                case '6DTV'
                    extra_weight = 6;
                case '6DTVL1'
                    extra_weight = 6 + 1;
                otherwise
                    error('Gt6DBlobReconstructor:wrong_argument', ...
                        'No such algorithm as "%s"', algo)
            end

            num_geoms = self.get_number_geometries();
            tau = cell(size(self.bwd_weights));
            for n = 1:num_geoms
                tau{n} = cast(- 1 ./ (self.bwd_weights{n} + extra_weight), self.data_type);
            end
        end
        function [p, p1, q_l1, q_tv] = init_cp_vars(self, algo)
            % init_cp_vars  Initial primal and dual variables for the CP iterations.
            %
            % Outputs:
            %   p    - detector dual: self.currentDetDual when set, otherwise
            %          zeros shaped like self.blobs.
            %   p1   - extrapolated primal, a copy of self.currentSolution.
            %   q_l1 - l1 dual: zeros like currentSolution when algo contains 'L1',
            %          a zero volume like currentSolution{1} when it contains 'LS',
            %          [] otherwise.
            %   q_tv - TV dual when algo contains 'TV': a num_groups x 3 cell of
            %          zero volumes for the 'groups' strategy, 1 x 3 otherwise;
            %          [] when algo has no TV term.
            if (isempty(self.currentDetDual))
                p = gtMathsGetSameSizeZeros(self.blobs);
            else
                p = self.currentDetDual;
            end

            do_tv_update = ~isempty(strfind(upper(algo), 'TV'));
            do_l1_update = ~isempty(strfind(upper(algo), 'L1'));
            do_ls_update = ~isempty(strfind(upper(algo), 'LS'));

            if (self.algo_ops_c_functions)
                p1 = gtCxxMathsCellCopy(self.currentSolution, 'threads', self.num_threads);
            else
                p1 = self.currentSolution;
            end

            if (do_l1_update)
                q_l1 = gtMathsGetSameSizeZeros(self.currentSolution);
            elseif (do_ls_update)
                q_l1 = gtMathsGetSameSizeZeros(self.currentSolution{1});
            else
                q_l1 = [];
            end

            if (do_tv_update)
                if (strcmpi(self.tv_strategy, 'groups'))
                    num_groups = size(self.orientation_groups, 1);
                    q_tv = gtMathsGetSameSizeZeros(self.currentSolution(ones(num_groups, 3)));
                    q_tv = reshape(q_tv, num_groups, 3);
                else
                    % update_dual_TV indexes q{1, ii_d} and q(1, :). Indexing a
                    % column cell with [1 1 1] yields 3 x 1, so force 1 x 3.
                    q_tv = gtMathsGetSameSizeZeros(self.currentSolution([1 1 1]));
                    q_tv = reshape(q_tv, 1, 3);
                end
            else
                q_tv = [];
            end
        end
        function [proj_data, psf_time] = apply_psf(self, proj_data, is_direct)
            % apply_psf  Apply the configured PSF(s) to each blob in proj_data.
            %
            % A single PSF in self.psf is applied to every blob; otherwise the n-th
            % PSF is applied to the n-th blob. is_direct is forwarded to each
            % PSF's apply method. When no PSF is configured, proj_data is returned
            % unchanged and psf_time is [].
            if (isempty(self.psf))
                psf_time = [];
                return;
            end

            c = tic();
            single_psf = (numel(self.psf) == 1);
            for n = 1:numel(proj_data)
                if (single_psf)
                    psf_n = self.psf{1};
                else
                    psf_n = self.psf{n};
                end
                proj_data{n} = psf_n.apply(proj_data{n}, is_direct);
            end
            psf_time = toc(c);
        end
        function [value, summed_vols] = compute_functional_value(self, p, algo, lambda)
            % compute_functional_value  Objective value sampled during verbose runs.
            %
            % Starts from gtMathsDotProduct(p, self.blobs). For '6DLS' / '6DL1' it
            % adds gtMathsNorm_l2 of (forward projection - blobs). For '6DL1' /
            % '6DTVL1' it adds lambda * gtMathsNorm_l1(self.currentSolution).
            % summed_vols is the sum of all current volumes.
            %
            % NOTE(review): the `end` closing the data-term `if`, the TV
            % contribution (sES is computed but unused in this copy) and the
            % `end` of the TV `if` are missing.
            value = gtMathsDotProduct(p, self.blobs);
            num_geoms = self.get_number_geometries();
            if (any(strcmpi(algo, {'6DLS', '6DL1'})))
                proj_bls = self.compute_fwd_projection(self.currentSolution, false);
                if (self.algo_ops_c_functions)
                    proj_bls = gtCxxMathsCellMinus(proj_bls, self.blobs, 'copy', false, 'threads', self.num_threads);
                else
                    for ii_b = 1:numel(proj_bls)
                        proj_bls{ii_b} = proj_bls{ii_b} - self.blobs{ii_b};
                    end
                end
                value = value + gtMathsNorm_l2(proj_bls);
                % NOTE(review): missing `end` closing the data-term `if` above.
            if (self.algo_ops_c_functions)
                summed_vols = gtCxxMathsSumCellVolumes(self.currentSolution);
            else
                summed_vols = gtMathsSumCellVolumes(self.currentSolution);
            end
            if (any(strcmpi(algo, {'6DTV', '6DTVL1'})))
                sES = gtMathsGradient(summed_vols / num_geoms);
                % NOTE(review): the TV term built from sES and the `end` closing
                % this `if` are missing.
            if (any(strcmpi(algo, {'6DL1', '6DTVL1'})))
                value = value + lambda * gtMathsNorm_l1(self.currentSolution);
            end
        end

        function initializeVariables(self, numIters)
            % initializeVariables  Prepare self.normResiduals for a run.
            %
            % A scalar numIters replaces normResiduals with a numIters x 1 zero
            % column. A two-element [first last] zeroes that index range. Any
            % other list zeroes exactly those indices, growing the array when
            % needed.
            num_specs = numel(numIters);
            if (num_specs == 1)
                self.normResiduals = zeros(numIters, 1);
            elseif (num_specs == 2)
                self.normResiduals(numIters(1):numIters(2)) = 0;
            else
                self.normResiduals(numIters) = 0;
            end
        end
    end

    methods (Static, Access = public)
        function proj_size = getProjSize(blobs)
            % getProjSize  Common size of the blobs along dimensions 1 and 3.
            %
            % Takes [size(blob, 1), size(blob, 3)] of the first blob and requires
            % every other blob to match it; dimension 2 may differ between blobs.
            % Raises 'Gt6DBlobReconstructor:wrong_argument' naming the first
            % mismatching blob.
            first_blob = blobs{1};
            proj_size = [size(first_blob, 1), size(first_blob, 3)];
            for ii_b = 2:numel(blobs)
                blob_size = [size(blobs{ii_b}, 1), size(blobs{ii_b}, 3)];
                if (~isequal(blob_size, proj_size))
                    error('Gt6DBlobReconstructor:wrong_argument', ...
                        'Blob: %d is malformed!', ii_b)
                end
            end
        end
    end
end