% GtGrainODFwSolver.m
classdef GtGrainODFwSolver < handle
    properties
        parameters;

        % GtOrientationSampling object, created by build_orientation_sampling
        sampler = [];

        volume = [];
        size_volume = [];

        sino = [];
        size_sino = [];
        pre_paddings;

        S;
        St;

        S__ws;
        St_ws;
        tau_ws;

        shape_functions_type = 'none';
        shape_functions = [];

        % Number of iterations used by the solve_* methods.
        % NOTE(review): the original default is not recoverable from this
        % listing - confirm against upstream.
        num_iter = 100;
        % When true, prints extra information and plots the residual curves
        verbose = false;
        % Class of the sinogram / coefficient arrays. The coefficients are
        % passed to sparse(), which in most MATLAB releases only accepts
        % double (or logical) values.
        data_type = 'double';

        ospace_oversize = 1;
        ospace_oversampling = 1;
    end

    methods (Access = public)
        function self = GtGrainODFwSolver(parameters, varargin)
        % FUNCTION self = GtGrainODFwSolver(parameters, varargin)
        %   Stores the acquisition parameters. The remaining property/value
        %   pairs in varargin are handed to parse_pv_pairs, presumably to
        %   override the matching public properties (e.g. 'num_iter').
            self.parameters = parameters;

            self = parse_pv_pairs(self, varargin);
        end

        function conf = initialize(self, ref_gr, varargin)
        % FUNCTION conf = initialize(self, ref_gr, varargin)
        %     'algorithm': {'cplsnn'} | 'sirt' | 'cplsl1nn' | 'cpl1nn'
        %     'lambda': 1e-2
        %     'det_index': 1
        %
        % Builds the orientation sampling (together with the measured
        % sinogram), the projection matrices and the projection weights for
        % ref_gr. Returns the parsed options.

            conf = struct( ...
                'algorithm', 'cplsnn', ...
                'lambda', 1e-2, ...
                'det_index', 1 );
            [conf, ~] = parse_pv_pairs(conf, varargin);

            self.build_orientation_sampling(ref_gr, conf.det_index);
            self.build_projection_matrices();
            self.build_projection_weights();
        end

        function vol = solve(self, ref_gr, varargin)
        % FUNCTION vol = solve(self, ref_gr, varargin)
        %
        % INPUT (varargin): same as initialize(self, ref_gr, varargin)
        %
        % OUTPUT:
        %   vol : reconstructed coefficients, reshaped to size_volume
        %
        % Raises 'GtGrainODFwSolver:wrong_argument' for an unknown 'algorithm'.
        %

            conf = self.initialize(ref_gr, varargin{:});

            switch (conf.algorithm)
                case 'sirt'
                    self.solve_sirt();
                case 'cplsnn'
                    self.solve_cplsnn();
                case 'cplsl1nn'
                    self.solve_cplsl1nn(conf.lambda);
                case 'cpl1nn'
                    self.solve_cpl1nn();
                otherwise
                    error('GtGrainODFwSolver:wrong_argument', ...
                        'Unknown algorithm: "%s"', conf.algorithm);
            end

            vol = self.get_volume();
        end
    end

    methods (Access = public) % Low Level API
        function build_orientation_sampling(self, ref_gr, det_index)
        % FUNCTION build_orientation_sampling(self, ref_gr, det_index)
        %   Creates the orientation-space sampler around ref_gr, builds the
        %   shape functions (if requested), and fills the measured sinogram:
        %   one column per selected blob, each holding the masked blob
        %   intensities summed over their first two dimensions (one value per
        %   slice along the third). The sinogram is stored flattened in
        %   self.sino, and its 2D size in self.size_sino.
        %
        % INPUT:
        %   ref_gr    : reference grain
        %   det_index : detector index (default: 1)
        %
        % Raises 'GtGrainODFwSolver:wrong_argument' for an unsupported
        % shape_functions_type.

            if (~exist('det_index', 'var') || isempty(det_index))
                det_index = 1;
            end

            self.sampler = GtOrientationSampling(self.parameters, ref_gr, ...
                'detector_index', det_index);
            % 'verbose', self.verbose,

            self.sampler.make_grid_estim_ODF_resoluion('cubic', ...
                -self.ospace_oversampling, self.ospace_oversize);
            self.size_volume = size(self.sampler.lattice.gr);

            switch (self.shape_functions_type)
                case 'none'
                    self.shape_functions = [];
                case 'w'
                    self.shape_functions = gtDefShapeFunctionsFwdProj(self.sampler, ...
                        'shape_function_type', 'w', ...
                        'data_type', self.data_type);
%                     self.shape_functions = gtDefShapeFunctionsCreateW(self.sampler);
                otherwise
                    error('GtGrainODFwSolver:wrong_argument', ...
                        'Shape functions type "%s" not supported!', ...
                        self.shape_functions_type);
            end

            sel_ref = self.sampler.selected;

            bls = self.sampler.ref_gr.proj(self.sampler.detector_index).bl(sel_ref);
            num_blobs = numel(bls);
            blob_depths = arrayfun(@(x)size(x.intm, 3), bls);
            blob_depths = reshape(blob_depths, [], 1);
            blob_lims = cat(1, bls(:).bbwim);

            % Omega range [first, last] spanned by the projections of the
            % sampled orientations, per blob
            switch (self.shape_functions_type)
                case 'w'
                    for ii = numel(self.shape_functions):-1:1
                        proj_lims(:, :, ii) = cat(1, self.shape_functions{ii}(:).bbwim);
                    end
                    proj_lims = [min(proj_lims(:, 1, :), [], 3), max(proj_lims(:, 2, :), [], 3)];
                    delta_omegas = proj_lims(:, 2) - proj_lims(:, 1) + 1;
                case 'none'
                    [delta_omegas, proj_lims] = self.sampler.get_omega_deviations();
                    delta_omegas = delta_omegas(sel_ref, :);
                    proj_lims = proj_lims(sel_ref, :);
            end
            chosen_depths = max(blob_depths, delta_omegas);

            % In the case that the sampled orientations determine the shift, we
            % have to take it into account!
            additional_shift = blob_lims(:, 1) - proj_lims(:, 1);
            additional_shift(additional_shift < 0) = 0;

            % NOTE(review): the original definition of num_ws is missing from
            % this listing. This choice guarantees that every blob interval
            % placed at pre_paddings (below) fits inside the sinogram, so the
            % filling loop can never silently grow self.sino past size_sino.
            % Without any shift it reduces to max(chosen_depths) + 2.
            % Confirm against upstream.
            num_ws = max(chosen_depths + 2 * additional_shift) + 2;

            self.sino = zeros(num_ws, num_blobs, self.data_type);
            self.size_sino = size(self.sino);

            self.pre_paddings = floor((num_ws - chosen_depths) / 2) + 1 + additional_shift;

            for ii_b = 1:num_blobs
                ints_interval = self.pre_paddings(ii_b):(self.pre_paddings(ii_b) + blob_depths(ii_b) - 1);
                masked_blob = bls(ii_b).intm;
                masked_blob(~bls(ii_b).mask) = 0;
                % Sum over the first two dimensions: one value per slice of dim 3
                self.sino(ints_interval, ii_b) = squeeze(sum(sum(masked_blob, 1), 2));
            end
            self.sino = reshape(self.sino, [], 1);

            if (self.verbose)
                fprintf('Sino size: [%d, %d]\n', self.size_sino);
            end
        end

        function sino = get_sinogram(self)
        % FUNCTION sino = get_sinogram(self)
        %   Returns the measured sinogram in its 2D layout
        %   [num_ws x num_blobs] (it is stored flattened internally).
            sino_dims = self.size_sino;
            sino = reshape(self.sino, sino_dims);
        end

        function comp_sino = get_computed_sinogram(self)
        % FUNCTION comp_sino = get_computed_sinogram(self)
        %   Forward-projects the current reconstruction and returns it with the
        %   same 2D layout as get_sinogram ([num_ws x num_blobs]).
        %
        %   Raises 'GtGrainODFSolver:no_reconstruction' when no volume is
        %   available yet. The identifier is kept as-is for existing callers.
            if (isempty(self.volume))
                error('GtGrainODFSolver:no_reconstruction', ...
                    'No reconstruction performed!');
            end

            comp_sino = self.fp(self.volume);
            comp_sino = reshape(comp_sino, self.size_sino);
        end

        function vol = get_volume(self)
        % FUNCTION vol = get_volume(self)
        %   Returns the reconstructed coefficients, reshaped to the size of
        %   the orientation sampling grid (size_volume).
            vol = self.volume;
            vol = reshape(vol, self.size_volume);
        end

        function or = get_orientations(self)
        % FUNCTION or = get_orientations(self)
        %   Returns the sampled orientations as returned by the sampler,
        %   reshaped to the orientation grid size (size_volume).
            or = reshape(self.sampler.get_orientations(), self.size_volume);
        end

        function r_vecs = get_R_vectors(self)
        % FUNCTION r_vecs = get_R_vectors(self)
        %   Returns the R_vector of every sampled orientation, stacked along
        %   the first dimension (one row each, assuming R_vector is a row
        %   vector).
            grid_gr = self.sampler.get_orientations();
            r_vecs = [grid_gr{:}];
            r_vecs = cat(1, r_vecs(:).R_vector);
        end

        function build_projection_matrices(self)
        % FUNCTION build_projection_matrices(self)
        %   Builds the sparse projector self.S and its transpose self.St for
        %   the configured shape_functions_type, printing the elapsed time.
        %
        %   Raises 'GtGrainODFwSolver:wrong_argument' for unknown or
        %   unsupported shape function types.
            fprintf('Computing projection matrices..')
            c = tic();
            switch (self.shape_functions_type)
                case 'none'
                    self.build_projection_matrices_sf_none();
                case 'w'
                    self.build_projection_matrices_sf_w();
                case 'nw'
                    self.build_projection_matrices_sf_nw();
                case 'uvw'
                    self.build_projection_matrices_sf_uvw();
                otherwise
                    fprintf('\n')
                    error('GtGrainODFwSolver:wrong_argument', ...
                        'Unknown shape functions type: "%s"', ...
                        self.shape_functions_type);
            end
            fprintf('\b\b: Done in %f seconds.\n', toc(c));
        end

        function build_projection_matrices_sf_none(self)
        % FUNCTION build_projection_matrices_sf_none(self)
        %   Builds S without shape functions. For every sampled orientation and
        %   selected blob, the fractional omega index ws = omega / om_step is
        %   split between rows floor(ws) and floor(ws) + 1 of that blob's
        %   sinogram column, with linear-interpolation weights.
        %   Contributions outside the blob's sinogram window are dropped, and
        %   so are upper contributions with a negligible weight.
            sel = self.sampler.selected;
            bls = self.sampler.ref_gr.proj(self.sampler.detector_index).bl(sel);
            num_ws = self.get_num_ws();

            bls_bbws = cat(1, bls(:).bbwim);
            % Omega index mapped to sinogram row 1 and to row num_ws, per blob
            min_conds = bls_bbws(:, 1) - self.pre_paddings + 1;
            max_conds = min_conds + num_ws - 1;

            grid_gr = self.sampler.get_orientations();
            num_orients = numel(grid_gr);

            % Per-orientation contributions, concatenated once after the loop
            % instead of re-growing the arrays at every iteration
            b_ws = cell(num_orients, 1);
            b_cs = cell(num_orients, 1);
            b_is = cell(num_orients, 1);
            b_os = cell(num_orients, 1);

            om_step = gtAcqGetOmegaStep(self.parameters, self.sampler.detector_index);
            for ii_o = 1:num_orients
                % NOTE(review): the statement defining `ab` is missing from
                % this listing. It is assumed to be the orientation's allblobs
                % structure (field 'omega', indexed by the selected blobs).
                % Confirm against upstream.
                ab = grid_gr{ii_o}.allblobs;
                ws = ab.omega(sel) / om_step;

                min_ws = floor(ws);
                max_ws = min_ws + 1;

                max_cs = ws - min_ws;
                min_cs = 1 - max_cs;

                ok_mins = (min_ws >= min_conds) & (min_ws <= max_conds);
                ok_maxs = (max_ws >= min_conds) & (max_ws <= max_conds) & (max_cs > eps('single'));

                indx_mins = find(ok_mins);
                indx_maxs = find(ok_maxs);

                b_ws{ii_o} = [ ...
                    min_ws(indx_mins) - min_conds(indx_mins) + 1; ...
                    max_ws(indx_maxs) - min_conds(indx_maxs) + 1];
                b_cs{ii_o} = [min_cs(indx_mins); max_cs(indx_maxs)];
                b_is{ii_o} = [indx_mins; indx_maxs];
                b_os{ii_o} = ii_o(ones(numel(indx_mins) + numel(indx_maxs), 1), 1);
            end

            b_ws = cat(1, b_ws{:});
            % The typed empty keeps the coefficient class tied to data_type
            b_cs = cat(1, zeros(0, 1, self.data_type), b_cs{:});
            b_is = cat(1, b_is{:});
            b_os = cat(1, b_os{:});

            sino_indx = sub2ind(self.size_sino, b_ws, b_is);

            self.S = sparse( ...
                    sino_indx, b_os, b_cs, ...
                    numel(self.sino), num_orients);

            self.St = self.S';
        end

        function build_projection_matrices_sf_w(self)
        % FUNCTION build_projection_matrices_sf_w(self)
        %   Builds S from omega ('w') shape functions. For each selected blob,
        %   orientation ii_o contributes the values of
        %   shape_functions{ii_o}(blob) over that shape function's omega
        %   range bbwim.
        %
        %   Raises 'GtGrainODFwSolver:out_of_bounds' if a shape function falls
        %   outside the blob's sinogram window.
            bls = self.sampler.ref_gr.proj(self.sampler.detector_index).bl(self.sampler.selected);
            num_ws = self.get_num_ws();
            num_blobs = self.get_num_blobs();

            bls_bbws = cat(1, bls(:).bbwim);
            % Omega index mapped to sinogram row 1 and to row num_ws, per blob
            min_conds = bls_bbws(:, 1) - self.pre_paddings + 1;
            max_conds = min_conds + num_ws - 1;

            num_orients = numel(self.shape_functions);

            % Per-orientation contributions, concatenated once after the loop
            b_ws = cell(num_orients, 1);
            b_cs = cell(num_orients, 1);
            b_is = cell(num_orients, 1);
            b_os = cell(num_orients, 1);

            for ii_o = 1:num_orients
                sf = self.shape_functions{ii_o};

                % NOTE(review): the statement defining `cs` is missing from
                % this listing. It is assumed to be the per-blob shape-function
                % values (field 'intm', bbsize(1) values per blob),
                % concatenated in blob order. Confirm against upstream.
                cs = arrayfun(@(x)reshape(x.intm, [], 1), sf(:), 'UniformOutput', false);
                cs = cat(1, cs{:});

                ws = cell(num_blobs, 1);
                is = cell(num_blobs, 1);
                for ii_b = 1:num_blobs
                    ws{ii_b} = sf(ii_b).bbwim(1):sf(ii_b).bbwim(2);
                    is{ii_b} = ii_b(ones(sf(ii_b).bbsize(1), 1));
                end

                ws = reshape([ws{:}], [], 1);
                is = cat(1, is{:});

                % Out-of-window samples would make sub2ind fail later: report them clearly
                wrong = ws < min_conds(is) | ws > max_conds(is);
                if (any(wrong))
                    error('GtGrainODFwSolver:out_of_bounds', ...
                        ['Orientation %d: %d shape-function omega samples fall ' ...
                        'outside the sinogram window'], ii_o, nnz(wrong));
                end

                b_ws{ii_o} = ws - min_conds(is) + 1;
                b_cs{ii_o} = cs;
                b_is{ii_o} = is;
                b_os{ii_o} = ii_o(ones(numel(cs), 1), 1);
            end

            b_ws = cat(1, b_ws{:});
            % The typed empty keeps the coefficient class tied to data_type
            b_cs = cat(1, zeros(0, 1, self.data_type), b_cs{:});
            b_is = cat(1, b_is{:});
            b_os = cat(1, b_os{:});

            sino_indx = sub2ind(self.size_sino, b_ws, b_is);

            self.S = sparse( ...
                    sino_indx, b_os, b_cs, ...
                    numel(self.sino), num_orients);

            self.St = self.S';
        end

        function build_projection_matrices_sf_nw(self)
        % FUNCTION build_projection_matrices_sf_nw(self)
        %   Placeholder: Eta-Omega shape functions are not implemented.
        %   Always raises 'GtGrainODFwSolver:wrong_argument'.
            error('GtGrainODFwSolver:wrong_argument', ...
                'Eta-Omega shape functions not supported!')
        end

        function build_projection_matrices_sf_uvw(self)
        % FUNCTION build_projection_matrices_sf_uvw(self)
        %   Placeholder: UV-Omega shape functions are not implemented.
        %   Always raises 'GtGrainODFwSolver:wrong_argument'.
            err_id = 'GtGrainODFwSolver:wrong_argument';
            error(err_id, 'UV-Omega shape functions not supported!')
        end

        function build_projection_weights(self)
        % FUNCTION build_projection_weights(self)
        %   Computes the diagonal weights used by fw/bw:
        %     St_ws = 1 ./ (St * ones)   (inverse column sums of S)
        %     S__ws = 1 ./ (S * ones)    (inverse row sums of S)
        %   Zero sums are mapped to a weight of 1.
            fprintf('Computing projection weights..')
            timer = tic();
            inv_or_one = @(v) 1 ./ (v + (v == 0));

            col_sums = self.bp(gtMathsGetSameSizeOnes(self.sino, self.data_type));
            row_sums = self.fp(gtMathsGetSameSizeOnes(col_sums, self.data_type));

            self.St_ws = inv_or_one(col_sums);
            self.S__ws = inv_or_one(row_sums);
            fprintf('\b\b: Done in %f seconds.\n', toc(timer));
        end

        function solve_sirt(self)
        % FUNCTION solve_sirt(self)
        %   Non-negative SIRT. Starts from the clipped weighted back-projection
        %   of the sinogram and iterates
        %     x <- max(0, x + bw(bp(fw(sino - fp(x)))))
        %   for num_iter iterations. The result is stored in self.volume.
        %   When verbose, the relative residual curve is plotted.
            timer = tic();
            n_iter = self.num_iter;
            rel_res = zeros(n_iter, 1);

            sol = self.bw(self.bp(self.fw(self.sino)));
            sol(sol < 0) = 0;

            sino_norm = gtMathsNorm_l2(self.sino);

            fprintf('Solving SIRT: ')
            for it = 1:n_iter
                elapsed = toc(timer);
                n_chars = fprintf('%03d/%03d (%g, ETA: %g)', it, n_iter, ...
                    elapsed, elapsed/(it - 1)*n_iter - elapsed);

                residual = self.sino - self.fp(sol);
                rel_res(it) = gtMathsNorm_l2(residual) / sino_norm;

                correction = self.bw(self.bp(self.fw(residual)));
                sol = sol + correction;
                sol(sol < 0) = 0;

                fprintf(repmat('\b', [1 n_chars]));
            end
            final_res = gtMathsNorm_l2(self.sino - self.fp(sol)) / sino_norm;

            fprintf('Done %d iterations in %f seconds: residual %f\n', n_iter, toc(timer), final_res);
            if (self.verbose)
                figure, semilogy(rel_res)
            end
            self.volume = sol;
        end

        function solve_cplsnn(self)
        % FUNCTION solve_cplsnn(self)
        %   Preconditioned Chambolle-Pock iterations for a least-squares data
        %   term with a non-negativity constraint on x. The dual step uses the
        %   S__ws weights and the primal step the St_ws weights. Runs num_iter
        %   iterations and stores the result in self.volume.
            c = tic();
            residuals = zeros(self.num_iter, 1);

            p = gtMathsGetSameSizeZeros(self.S__ws);
            x = gtMathsGetSameSizeZeros(self.St_ws);

            res_norm_0 = gtMathsNorm_l2(self.sino);

            xe = x;
            fprintf('Solving CPLSNN: ')
            for ii = 1:self.num_iter
                current_time = toc(c);
                num_chars = fprintf('%03d/%03d (%g, ETA: %g)', ii, self.num_iter, ...
                    current_time, current_time/(ii - 1)*self.num_iter - current_time);
                % Dual step: prox of the conjugate of the least-squares term
                p = (p + self.fw(self.fp(xe) - self.sino)) ./ (1 + self.S__ws);

                % Keep the previous iterate for the extrapolation step
                xo = x;
                x = x - self.bw(self.bp(p));
                x(x < 0) = 0;

                % Extrapolation (theta = 1): xe = 2 x - x_previous
                xe = x + (x - xo);

                residuals(ii) = gtMathsNorm_l2(self.sino - self.fp(x)) / res_norm_0;

                fprintf(repmat('\b', [1 num_chars]));
            end

            fprintf('Done %d iterations in %f seconds: residual %f\n', self.num_iter, toc(c), residuals(end));
            if (self.verbose)
                figure, semilogy(residuals)
            end
            self.volume = x;
        end

        function solve_cplsl1nn(self, lambda)
        % FUNCTION solve_cplsl1nn(self, lambda)
        %   Preconditioned Chambolle-Pock iterations for
        %     least-squares data term + lambda * ||x||_1, subject to x >= 0.
        %
        % INPUT:
        %   lambda : weight of the l1 term (default: 1e-2)
        %
        %   The l1 term is handled with operator lambda * I and dual step
        %   1 / lambda. Its dual update is therefore q <- proj_[-1,1](q + xe),
        %   and the matching primal step is tau = 1 / (column sums of S + lambda).
        %   Runs num_iter iterations and stores the result in self.volume.
            if (~exist('lambda', 'var') || isempty(lambda))
                lambda = 1e-2;
            end

            c = tic();
            residuals = zeros(self.num_iter, 1);

            p = gtMathsGetSameSizeZeros(self.S__ws);
            q = gtMathsGetSameSizeZeros(self.St_ws);
            q1 = gtMathsGetSameSizeOnes(self.St_ws);
            x = gtMathsGetSameSizeZeros(self.St_ws);

            res_norm_0 = gtMathsNorm_l2(self.sino);

            % 1 ./ St_ws are the column sums of S; the l1 operator adds lambda
            % (with lambda = 0 this reduces to St_ws, as in solve_cplsnn)
            self.tau_ws = 1 ./ (1 ./ self.St_ws + lambda);

            xe = x;
            fprintf('Solving CPLSL1NN: ')
            for ii = 1:self.num_iter
                current_time = toc(c);
                num_chars = fprintf('%03d/%03d (%g, ETA: %g)', ii, self.num_iter, ...
                    current_time, current_time/(ii - 1)*self.num_iter - current_time);
                p = (p + self.fw(self.fp(xe) - self.sino)) ./ (1 + self.S__ws);

                % Dual step for the l1 term: projection onto [-1, 1]
                qn = q + xe;
                q = (qn ./ max(q1, abs(qn)));

                % Keep the previous iterate for the extrapolation step
                xo = x;
                x = x - (self.bp(p) + lambda .* q) .* self.tau_ws;
                x(x < 0) = 0;

                xe = x + (x - xo);

                residuals(ii) = gtMathsNorm_l2(self.sino - self.fp(x)) / res_norm_0;

                fprintf(repmat('\b', [1 num_chars]));
            end

            fprintf('Done %d iterations in %f seconds: residual %f\n', self.num_iter, toc(c), residuals(end));
            if (self.verbose)
                figure, semilogy(residuals)
            end
            self.volume = x;
        end

        function solve_cpl1nn(self)
        % FUNCTION solve_cpl1nn(self)
        %   Preconditioned Chambolle-Pock iterations for an l1 data term with a
        %   non-negativity constraint on x. The sinogram is rescaled to norm
        %   num_ws before solving, and the solution is scaled back afterwards.
        %   Runs num_iter iterations and stores the result in self.volume.
            sino_norm = norm(self.sino(:));
            if (sino_norm == 0)
                % All-zero sinogram: avoid the 0/0 rescaling (the solution
                % then stays identically zero)
                sino_norm = 1;
            end
            rescaled_sino = self.sino ./ sino_norm .* self.get_num_ws();

            c = tic();
            residuals = zeros(self.num_iter, 1);

            p = gtMathsGetSameSizeZeros(self.S__ws);
            p1 = gtMathsGetSameSizeOnes(self.S__ws);
            x = gtMathsGetSameSizeZeros(self.St_ws);

            res_norm_0 = gtMathsNorm_l2(rescaled_sino);

            % Stored for consistency with solve_cplsl1nn; the primal step
            % below uses St_ws (through bw) instead
            self.tau_ws = 1 ./ (1 ./ self.St_ws + 1);

            xe = x;
            fprintf('Solving CPL1NN: ')
            for ii = 1:self.num_iter
                current_time = toc(c);
                num_chars = fprintf('%03d/%03d (%g, ETA: %g)', ii, self.num_iter, ...
                    current_time, current_time/(ii - 1)*self.num_iter - current_time);
                % Dual step for the l1 data term: projection onto [-1, 1]
                pn = p + self.fw(self.fp(xe) - rescaled_sino);
                p = pn ./ max(p1, abs(pn));

                % Keep the previous iterate for the extrapolation step
                xo = x;
                x = x - self.bw(self.bp(p));
                x(x < 0) = 0;

                xe = x + (x - xo);

                residuals(ii) = gtMathsNorm_l2(rescaled_sino - self.fp(x)) / res_norm_0;

                fprintf(repmat('\b', [1 num_chars]));
            end

            x = x .* sino_norm ./ self.get_num_ws();

            fprintf('Done %d iterations in %f seconds: residual %f\n', self.num_iter, toc(c), residuals(end));
            if (self.verbose)
                figure, semilogy(residuals)
            end
            self.volume = x;
        end
    end

    methods (Access = protected)
        function x = fp(self, x)
        % Forward projection: applies the sparse projector S.
            x = mtimes(self.S, x);
        end

        function x = bp(self, x)
        % Back projection: applies St, the stored transpose of S.
            x = mtimes(self.St, x);
        end

        function x = fw(self, x)
        % Element-wise weighting in sinogram space by S__ws.
            x = times(x, self.S__ws);
        end

        function x = bw(self, x)
        % Element-wise weighting in orientation space by St_ws.
            x = times(x, self.St_ws);
        end

        function num_ws = get_num_ws(self)
        % Number of sinogram rows (first dimension of size_sino).
            sino_dims = self.size_sino;
            num_ws = sino_dims(1);
        end

        function num_bs = get_num_blobs(self)
        % Number of sinogram columns, i.e. selected blobs (second dimension of size_sino).
            sino_dims = self.size_sino;
            num_bs = sino_dims(2);
        end