% GtGrainODFwSolver.m
classdef GtGrainODFwSolver < handle
    properties
        % Parameters structure passed to the constructor. The field
        % acq.nproj is read when building the projection matrices.
        parameters;

        % Cell array of sampled orientations, as returned by the
        % get_orientations() method of the orientation sampler.
        grid_gr;
        % Solution vector written by the solve_* methods (empty until then).
        volume = [];
        % Size of the orientation lattice; used to reshape volume and grid_gr.
        size_volume = [];

        % Sinogram, stored flattened as a column vector by build_sinogram().
        sino = [];
        % Size of the un-flattened sinogram: [num_ws, num_blobs].
        size_sino = [];
        % Per-blob row index, in the sinogram, of the first slice of the blob.
        pre_paddings;

        % Sparse projection matrix (numel(sino) x number of orientations).
        S;
        % Transpose of S, stored explicitly.
        St;

        % Element-wise inverse of the forward projection of a vector of ones
        % (zeros replaced by ones before inverting); see build_projection_weights().
        S__ws;
        % Element-wise inverse of the back-projection of a vector of ones
        % (zeros replaced by ones before inverting); see build_projection_weights().
        St_ws;
        % Primal step weights, 1 ./ (1 ./ St_ws + 1), set by solve_cplsl1nn()
        % and solve_cpl1nn().
        tau_ws;

        % Number of iterations. When 0, each solver stores its own default
        % here on first use (100 for SIRT, 50 for the others).
        num_iter = 0;
    end

    methods (Access = public)
        function self = GtGrainODFwSolver(parameters)
        % GTGRAINODFWSOLVER  Creates a solver and stores the given parameters.
        %
        %   parameters : structure kept in self.parameters; acq.nproj is read
        %                later by build_projection_matrices().
            self.parameters = parameters;
        end

        function vol = solve_synthetic(self, ref_gr, gvdm, algorithm, lambda)
        % SOLVE_SYNTHETIC  Sets up and solves the problem for a synthetic grain.
        %
        %   vol = solve_synthetic(self, ref_gr, gvdm[, algorithm[, lambda]])
        %
        %   ref_gr    : reference grain; its fields "bl" and "selected" are
        %               used here, and it is forwarded to the orientation sampler
        %   gvdm      : forwarded to build_orientation_sampling_synthetic()
        %   algorithm : 'sirt' | 'cplsnn' (default) | 'cplsl1nn' | 'cpl1nn',
        %               case insensitive. Empty selects the default.
        %   lambda    : forwarded to solve_cplsl1nn(); required only by 'cplsl1nn'
        %
        %   Returns the solution reshaped to the size of the orientation grid.
        %
        %   Raises GtGrainODFwSolver:wrong_argument on an unknown algorithm,
        %   or when 'cplsl1nn' is requested without lambda.

            if (~exist('algorithm', 'var') || isempty(algorithm))
                algorithm = 'cplsnn';
            end

            % Validate the request before the expensive set-up: an unknown
            % name would otherwise skip every solver and hand back a stale
            % (or empty) volume.
            known_algorithms = {'sirt', 'cplsnn', 'cplsl1nn', 'cpl1nn'};
            if (~any(strcmpi(algorithm, known_algorithms)))
                error('GtGrainODFwSolver:wrong_argument', ...
                    'Unknown algorithm, expected one of: %s', ...
                    sprintf('"%s" ', known_algorithms{:}));
            end
            if (strcmpi(algorithm, 'cplsl1nn') ...
                    && (~exist('lambda', 'var') || isempty(lambda)))
                error('GtGrainODFwSolver:wrong_argument', ...
                    'Algorithm "cplsl1nn" requires the argument "lambda"');
            end

            bls = ref_gr.bl(ref_gr.selected);

            self.build_sinogram(bls, 1.1);
            self.build_orientation_sampling_synthetic(ref_gr, gvdm, self.get_num_ws()/5, 1.3);
            self.build_projection_matrices(bls, ref_gr.selected);
            self.build_projection_weights();

            switch(lower(algorithm))
                case 'sirt'
                    self.solve_sirt();
                case 'cplsnn'
                    self.solve_cplsnn();
                case 'cplsl1nn'
                    self.solve_cplsl1nn(lambda);
                case 'cpl1nn'
                    self.solve_cpl1nn();
            end
            vol = self.get_volume();
        end

        function vol = solve_experimental(self, ref_gr, algorithm, lambda)
        % SOLVE_EXPERIMENTAL  Sets up and solves the problem for a measured grain.
        %
        %   vol = solve_experimental(self, ref_gr[, algorithm[, lambda]])
        %
        %   ref_gr    : reference grain; its fields "bl" and "selected" are
        %               used here, and it is forwarded to the orientation sampler
        %   algorithm : 'sirt' | 'cplsnn' (default) | 'cplsl1nn' | 'cpl1nn',
        %               case insensitive. Empty selects the default.
        %   lambda    : forwarded to solve_cplsl1nn(); required only by 'cplsl1nn'
        %
        %   Returns the solution reshaped to the size of the orientation grid.
        %
        %   Raises GtGrainODFwSolver:wrong_argument on an unknown algorithm,
        %   or when 'cplsl1nn' is requested without lambda.

            if (~exist('algorithm', 'var') || isempty(algorithm))
                algorithm = 'cplsnn';
            end

            % Validate the request before the expensive set-up: an unknown
            % name would otherwise skip every solver and hand back a stale
            % (or empty) volume.
            known_algorithms = {'sirt', 'cplsnn', 'cplsl1nn', 'cpl1nn'};
            if (~any(strcmpi(algorithm, known_algorithms)))
                error('GtGrainODFwSolver:wrong_argument', ...
                    'Unknown algorithm, expected one of: %s', ...
                    sprintf('"%s" ', known_algorithms{:}));
            end
            if (strcmpi(algorithm, 'cplsl1nn') ...
                    && (~exist('lambda', 'var') || isempty(lambda)))
                error('GtGrainODFwSolver:wrong_argument', ...
                    'Algorithm "cplsl1nn" requires the argument "lambda"');
            end

            bls = ref_gr.bl(ref_gr.selected);

            self.build_sinogram(bls, 1.2);
            self.build_orientation_sampling_experimental(ref_gr, floor(self.get_num_ws()/1.2*1.1), 1.1);
            self.build_projection_matrices(bls, ref_gr.selected);
            self.build_projection_weights();

            switch(lower(algorithm))
                case 'sirt'
                    self.solve_sirt();
                case 'cplsnn'
                    self.solve_cplsnn();
                case 'cplsl1nn'
                    self.solve_cplsl1nn(lambda);
                case 'cpl1nn'
                    self.solve_cpl1nn();
            end
            vol = self.get_volume();
        end
    end

    methods (Access = public) % Low Level API
        function build_orientation_sampling_synthetic(self, ref_gr, gvdm, grid_edge, oversize)
        % BUILD_ORIENTATION_SAMPLING_SYNTHETIC  Samples orientation space for a
        % synthetic grain and stores the result in grid_gr / size_volume.
        %
        %   ref_gr    : reference grain, handed to GtOrientationSampling
        %   gvdm, grid_edge, oversize : forwarded unchanged to
        %               GtOrientationSampling.make_even_simple_grid()
        %               (NOTE(review): their exact meaning is defined by the
        %               sampler, not visible here)
            sampler = GtOrientationSampling(ref_gr.bl, self.parameters, ref_gr);
            % Previous variant, kept for reference:
%             sampler.make_simple_grid('cubic', grid_edge, gvdm, oversize);
            sampler.make_even_simple_grid('cubic', grid_edge, gvdm, oversize);
            self.grid_gr = sampler.get_orientations();
            % The lattice size is what get_volume()/get_orientations() reshape to
            self.size_volume = size(sampler.lattice.gr);
        end

        function build_orientation_sampling_experimental(self, ref_gr, grid_edge, oversize)
        % BUILD_ORIENTATION_SAMPLING_EXPERIMENTAL  Samples orientation space for
        % a measured grain and stores the result in grid_gr / size_volume.
        %
        %   ref_gr    : reference grain, handed to GtOrientationSampling
        %   grid_edge, oversize : forwarded unchanged to
        %               GtOrientationSampling.make_simple_grid_estim_ODF()
        %               (NOTE(review): the meaning of the literal "true"
        %               argument is defined by the sampler - confirm there)
            sampler = GtOrientationSampling(ref_gr.bl, self.parameters, ref_gr);
            sampler.make_simple_grid_estim_ODF('cubic', grid_edge, true, oversize);
            self.grid_gr = sampler.get_orientations();
            % The lattice size is what get_volume()/get_orientations() reshape to
            self.size_volume = size(sampler.lattice.gr);
        end

        function build_sinogram(self, bls, oversize)
        % BUILD_SINOGRAM  Builds the sinogram from the blobs' integrated intensities.
        %
        %   bls      : struct array of blobs; the fields "intm" (intensities)
        %              and "mask" (same size as intm) are read
        %   oversize : optional multiplicative factor applied to the number
        %              of sinogram rows
        %
        %   Every blob becomes one column: its masked intensities are summed
        %   over the first two dimensions, and the resulting profile is
        %   placed roughly in the middle of the column. The sinogram is
        %   stored flattened in self.sino, its 2D size in self.size_sino,
        %   and the first row used by each blob in self.pre_paddings.
            num_blobs = numel(bls);

            % Idea (not implemented): restrict each blob to the slices that
            % actually contain signal, e.g.
            %   ints_w = arrayfun(@(x){squeeze(sum(sum(x.intm, 1), 2))}, bls);
            %   limits = cellfun(@(x){[max(find(x, 1, 'first'), 1), ...
            %       min(find(x, 1, 'last'), numel(x))]}, ints_w);
            %   limits = cat(1, limits{:});
            %   depths = limits(:, 2) - limits(:, 1) + 1;

            % Extent of each blob along the third dimension, as a column
            depths = arrayfun(@(blob)size(blob.intm, 3), bls(:));

            % One spare row on each side of the deepest blob
            num_ws = max(depths) + 2;
            if (exist('oversize', 'var'))
                num_ws = round(num_ws * oversize);
            end

            sinogram = zeros(num_ws, num_blobs);
            self.size_sino = size(sinogram);
            self.pre_paddings = floor((num_ws - depths) / 2) + 1;

            for ii_b = 1:num_blobs
                first_w = self.pre_paddings(ii_b);
                last_w = first_w + depths(ii_b) - 1;

                blob = bls(ii_b);
                intensities = blob.intm;
                intensities(~blob.mask) = 0;

                profile = sum(sum(intensities, 1), 2);
                sinogram(first_w:last_w, ii_b) = profile(:);
            end

            self.sino = sinogram(:);
        end

        function sino = get_sinogram(self)
        % GET_SINOGRAM  Returns the stored sinogram reshaped to size_sino
        % ([num_ws, num_blobs] as set by build_sinogram()).
            sino = reshape(self.sino, self.size_sino);
        end

        function comp_sino = get_computed_sinogram(self)
        % GET_COMPUTED_SINOGRAM  Forward projects the current solution.
        %
        %   Returns S * volume reshaped to size_sino. Raises
        %   GtGrainODFSolver:no_reconstruction when no solution is stored.
            if (isempty(self.volume))
                error('GtGrainODFSolver:no_reconstruction', ...
                    'No reconstruction performed!');
            end

            projected = self.S * self.volume;
            comp_sino = reshape(projected, self.size_sino);
        end

        function vol = get_volume(self)
        % GET_VOLUME  Returns the stored solution reshaped to size_volume
        % (the size of the orientation lattice).
            vol = reshape(self.volume, self.size_volume);
        end

        function or = get_orientations(self)
        % GET_ORIENTATIONS  Returns the sampled orientations (grid_gr)
        % reshaped to size_volume.
            or = reshape(self.grid_gr, self.size_volume);
        end

        function r_vecs = get_R_vectors(self)
        % GET_R_VECTORS  Collects the R_vector field of all sampled orientations.
        %
        %   Returns a 1x1 cell holding the R_vector entries of grid_gr
        %   concatenated along the first dimension (one per row when each
        %   R_vector is a row vector).
            grains = [self.grid_gr{:}];
            stacked = cat(1, grains(:).R_vector);
            r_vecs = {stacked};
        end

        function build_projection_matrices(self, bls, bl_selected)
        % BUILD_PROJECTION_MATRICES  Assembles the sparse projection matrix S
        % and its transpose St.
        %
        %   bls         : struct array of blobs; only the first column of
        %                 the field "bbwim" is read, as the position of the
        %                 first slice of each blob
        %                 (NOTE(review): assumed to be in units of om_step,
        %                 i.e. image numbers - confirm)
        %   bl_selected : index/mask that picks, from allblobs.omega of each
        %                 orientation, the entries matching bls
        %                 (assumes the selection yields a column vector with
        %                 one entry per blob - TODO confirm)
        %
        %   Requires size_sino and pre_paddings (set by build_sinogram())
        %   and grid_gr (set by one of the build_orientation_sampling_*()
        %   methods).
        %
        %   For each orientation and blob, omega / om_step is split between
        %   the two neighbouring integer positions with linear interpolation
        %   weights. Positions falling outside the rows of that blob's
        %   sinogram column are dropped. Row index of S: linear index into
        %   the sinogram; column index: orientation number.
            fprintf('Computing projection matrices..')
            c = tic();
            om_step = 180 / self.parameters.acq.nproj;
            num_ws = self.get_num_ws();

            num_orients = numel(self.grid_gr);

            bls_bbws = cat(1, bls(:).bbwim);
            first_ws = bls_bbws(:, 1);
            % Range of positions that map onto rows 1..num_ws of each column
            min_conds = first_ws - self.pre_paddings + 1;
            max_conds = min_conds + num_ws - 1;

            % One chunk per orientation, concatenated once after the loop,
            % instead of re-allocating the growing arrays at every iteration.
            b_ws = cell(num_orients, 1);
            b_cs = cell(num_orients, 1);
            b_is = cell(num_orients, 1);
            b_os = cell(num_orients, 1);

            for ii_o = 1:num_orients
                ab = self.grid_gr{ii_o}.allblobs;
                ws = ab.omega(bl_selected) / om_step;

                min_ws = floor(ws);
                max_ws = min_ws + 1;

                % Linear interpolation weights of the two neighbours
                max_cs = ws - min_ws;
                min_cs = 1 - max_cs;

                ok_mins = (min_ws >= min_conds) & (min_ws <= max_conds);
                % Upper neighbours with a negligible weight are skipped
                ok_maxs = (max_ws >= min_conds) & (max_ws <= max_conds) & (max_cs > eps('single'));

                indx_mins = find(ok_mins);
                indx_maxs = find(ok_maxs);

                b_ws{ii_o} = [ ...
                    min_ws(indx_mins) - first_ws(indx_mins) + self.pre_paddings(indx_mins); ...
                    max_ws(indx_maxs) - first_ws(indx_maxs) + self.pre_paddings(indx_maxs)];
                b_cs{ii_o} = [min_cs(indx_mins); max_cs(indx_maxs)];
                b_is{ii_o} = [indx_mins; indx_maxs];
                b_os{ii_o} = repmat(ii_o, [numel(b_is{ii_o}), 1]);
            end

            b_ws = vertcat(b_ws{:});
            b_cs = vertcat(b_cs{:});
            b_is = vertcat(b_is{:});
            b_os = vertcat(b_os{:});

            sino_indx = sub2ind(self.size_sino, b_ws, b_is);

            self.S = sparse( ...
                    sino_indx, b_os, b_cs, ...
                    numel(self.sino), num_orients);

            self.St = self.S';
            fprintf('\b\b: Done in %f seconds.\n', toc(c));
        end

        function build_projection_weights(self)
        % BUILD_PROJECTION_WEIGHTS  Computes the weights used by fw() and bw().
        %
        %   St_ws is the element-wise inverse of the back-projection of a
        %   "ones" sinogram, S__ws the element-wise inverse of the forward
        %   projection of a "ones" volume. Entries that are exactly zero are
        %   replaced by one before inverting, so the weights stay finite.
        %   (gtMathsGetSameSizeOnes is presumably ones() of matching size.)
            bp_of_ones = self.bp(gtMathsGetSameSizeOnes(self.sino));
            fp_of_ones = self.fp(gtMathsGetSameSizeOnes(bp_of_ones));

            self.St_ws = 1 ./ (bp_of_ones + (bp_of_ones == 0));
            self.S__ws = 1 ./ (fp_of_ones + (fp_of_ones == 0));
        end

        function solve_sirt(self)
        % SOLVE_SIRT  Weighted, non-negativity constrained SIRT-like iteration.
        %
        %   Starts from the weighted back-projection of the sinogram, then
        %   repeatedly adds the weighted back-projection of the residual,
        %   clipping negative values to zero after every update. The result
        %   is stored in self.volume and the relative residuals are plotted.
        %
        %   If num_iter is 0 it is set (and left) to 100.
            timer = tic();
            if (~self.num_iter)
                self.num_iter = 100;
            end
            num_iters = self.num_iter;
            residuals = zeros(num_iters, 1);

            % Initial guess: weighted back-projection, clipped at zero
            x = self.bw(self.bp(self.fw(self.sino)));
            x(x < 0) = 0;

            sino_norm = gtMathsNorm_l2(self.sino);

            fprintf('Solving SIRT: ')
            for ii = 1:num_iters
                num_chars = fprintf('%03d/%03d', ii, num_iters);

                res = self.sino - self.fp(x);
                residuals(ii) = gtMathsNorm_l2(res) / sino_norm;

                x = x + self.bw(self.bp(self.fw(res)));
                x(x < 0) = 0;

                % Erase the progress counter
                fprintf(repmat('\b', [1 num_chars]));
            end

            final_res = self.sino - self.fp(x);
            res_norm = gtMathsNorm_l2(final_res) / sino_norm;

            fprintf('Done %d iterations in %f seconds: residual %f\n', num_iters, toc(timer), res_norm);
            figure, semilogy(residuals)
            self.volume = x;
        end

        function solve_cplsnn(self)
        % SOLVE_CPLSNN  Non-negative least-squares solve with a primal-dual
        % iteration (dual update, primal update, extrapolation).
        %
        %   Uses S__ws as dual step weights (through fw) and St_ws as primal
        %   step weights (through bw). Negative values of the primal variable
        %   are clipped to zero after every update. The result is stored in
        %   self.volume and the relative residuals are plotted.
        %
        %   If num_iter is 0 it is set (and left) to 50.
        %   (gtMathsGetSameSizeZeros is presumably zeros() of matching size.)
            c = tic();
            if (~self.num_iter)
                self.num_iter = 50;
            end
            residuals = zeros(self.num_iter, 1);

            p = gtMathsGetSameSizeZeros(self.S__ws);
            x = gtMathsGetSameSizeZeros(self.St_ws);

            res_norm_0 = gtMathsNorm_l2(self.sino);

            xe = x;
            fprintf('Solving CPLSNN: ')
            for ii = 1:self.num_iter
                num_chars = fprintf('%03d/%03d', ii, self.num_iter);

                % Dual update for the least-squares data term
                p = (p + self.fw(self.fp(xe) - self.sino)) ./ (1 + self.S__ws);

                % Keep the previous iterate: the extrapolation below needs it
                xo = x;
                x = x - self.bw(self.bp(p));
                x(x < 0) = 0;

                % Extrapolated point used by the next dual update
                xe = x + (x - xo);

                residuals(ii) = gtMathsNorm_l2(self.sino - self.fp(x)) / res_norm_0;

                % Erase the progress counter
                fprintf(repmat('\b', [1 num_chars]));
            end

            fprintf('Done %d iterations in %f seconds: residual %f\n', self.num_iter, toc(c), residuals(end));
            figure, semilogy(residuals)
            self.volume = x;
        end

        function solve_cplsl1nn(self, lambda)
        % SOLVE_CPLSL1NN  Non-negative least-squares solve with an additional
        % term weighted by lambda, using a primal-dual iteration.
        %
        %   lambda : weight multiplying the second dual variable (q)
        %
        %   Two dual variables are updated per iteration: p for the data
        %   term and q, which is the running sum q + xe divided element-wise
        %   by max(1, |q + xe|) and scaled by lambda. The primal step uses
        %   tau_ws = 1 ./ (1 ./ St_ws + 1), which is stored on the object.
        %   Negative values of the primal variable are clipped to zero. The
        %   result is stored in self.volume and the residuals are plotted.
        %
        %   If num_iter is 0 it is set (and left) to 50.
            timer = tic();
            if (~self.num_iter)
                self.num_iter = 50;
            end
            num_iters = self.num_iter;
            residuals = zeros(num_iters, 1);

            p = gtMathsGetSameSizeZeros(self.S__ws);
            q = gtMathsGetSameSizeZeros(self.St_ws);
            unit_bound = gtMathsGetSameSizeOnes(self.St_ws);
            x = gtMathsGetSameSizeZeros(self.St_ws);

            sino_norm = gtMathsNorm_l2(self.sino);

            self.tau_ws = 1 ./ (1 ./ self.St_ws + 1);

            x_bar = x;
            fprintf('Solving CPLSL1NN: ')
            for ii = 1:num_iters
                num_chars = fprintf('%03d/%03d', ii, num_iters);

                % Dual update for the least-squares data term
                p = (p + self.fw(self.fp(x_bar) - self.sino)) ./ (1 + self.S__ws);

                % Dual update for the lambda-weighted term
                q_trial = q + x_bar;
                q = lambda .* (q_trial ./ max(unit_bound, abs(q_trial)));

                % Primal update with non-negativity constraint
                x_prev = x;
                x = x - (self.bp(p) + q) .* self.tau_ws;
                x(x < 0) = 0;

                % Extrapolated point used by the next dual updates
                x_bar = x + (x - x_prev);

                residuals(ii) = gtMathsNorm_l2(self.sino - self.fp(x)) / sino_norm;

                % Erase the progress counter
                fprintf(repmat('\b', [1 num_chars]));
            end

            fprintf('Done %d iterations in %f seconds: residual %f\n', num_iters, toc(timer), residuals(end));
            figure, semilogy(residuals)
            self.volume = x;
        end

        function solve_cpl1nn(self)
        % SOLVE_CPL1NN  Non-negative solve with a bounded dual variable for
        % the data term, using a primal-dual iteration.
        %
        %   The sinogram is first rescaled to norm(sino) == num_ws; the
        %   solution is scaled back before being stored. The dual variable
        %   is divided element-wise by max(1, |.|) at every iteration, and
        %   negative values of the primal variable are clipped to zero.
        %   tau_ws = 1 ./ (1 ./ St_ws + 1) is stored on the object, the
        %   primal step itself goes through bw(). The result is stored in
        %   self.volume and the residuals (relative to the rescaled
        %   sinogram) are plotted.
        %
        %   If num_iter is 0 it is set (and left) to 50.
            sino_l2 = norm(self.sino(:));
            num_ws = self.get_num_ws();
            rescaled_sino = self.sino ./ sino_l2 .* num_ws;

            timer = tic();
            if (~self.num_iter)
                self.num_iter = 50;
            end
            num_iters = self.num_iter;
            residuals = zeros(num_iters, 1);

            p = gtMathsGetSameSizeZeros(self.S__ws);
            unit_bound = gtMathsGetSameSizeOnes(self.S__ws);
            x = gtMathsGetSameSizeZeros(self.St_ws);

            sino_norm = gtMathsNorm_l2(rescaled_sino);

            self.tau_ws = 1 ./ (1 ./ self.St_ws + 1);

            x_bar = x;
            fprintf('Solving CPL1NN: ')
            for ii = 1:num_iters
                num_chars = fprintf('%03d/%03d', ii, num_iters);

                % Dual update, bounded element-wise to magnitude <= 1
                p_trial = p + self.fw(self.fp(x_bar) - rescaled_sino);
                p = p_trial ./ max(unit_bound, abs(p_trial));

                % Primal update with non-negativity constraint
                x_prev = x;
                x = x - self.bw(self.bp(p));
                x(x < 0) = 0;

                % Extrapolated point used by the next dual update
                x_bar = x + (x - x_prev);

                residuals(ii) = gtMathsNorm_l2(rescaled_sino - self.fp(x)) / sino_norm;

                % Erase the progress counter
                fprintf(repmat('\b', [1 num_chars]));
            end

            % Undo the initial rescaling of the sinogram
            x = x .* sino_l2 ./ num_ws;

            fprintf('Done %d iterations in %f seconds: residual %f\n', num_iters, toc(timer), residuals(end));
            figure, semilogy(residuals)
            self.volume = x;
        end
    end

    methods (Access = protected)
        function x = fp(self, x)
        % FP  Forward projection: multiplies x by the projection matrix S.
            x = self.S * x;
        end

        function x = bp(self, x)
        % BP  Back-projection: multiplies x by St, the stored transpose of S.
            x = self.St * x;
        end

        function x = fw(self, x)
        % FW  Applies the sinogram-side weights: element-wise product with S__ws.
            x = x .* self.S__ws;
        end

        function x = bw(self, x)
        % BW  Applies the volume-side weights: element-wise product with St_ws.
            x = x .* self.St_ws;
        end

        function num_ws = get_num_ws(self)
        % GET_NUM_WS  Returns the number of rows of the sinogram
        % (first entry of size_sino, set by build_sinogram()).
            num_ws = self.size_sino(1);
        end
    end
end