Skip to content
Snippets Groups Projects
Gt6DVolumeProjector.m 10.72 KiB
classdef Gt6DVolumeProjector < handle
% GT6DVOLUMEPROJECTOR  Base class wrapping ASTRA 3D forward/back projection
% of equally sized volumes through a list of 'parallel3d_vec' geometries.
%
% Two execution paths exist:
%   - use_astra_projectors == true: projections go through
%     astra_mex_direct_c using 'cuda3d' projector objects. This is only
%     kept when the astra_mex_direct_c MEX file exists and volume_ss == 1.
%   - otherwise: projections go through astra_mex_data3d /
%     astra_mex_algorithm with FP3D_CUDA / BP3D_CUDA algorithms that share
%     a single ASTRA volume object (volume_id).

    properties
        % Detector size: proj_size(2) and proj_size(1) are passed to
        % astra_create_proj_geom as the detector row and column counts.
        proj_size = [];

        % Voxel supersampling factor given to the BP3D_CUDA algorithms.
        % Any value other than 1 disables the astra_mex_direct_c path.
        volume_ss = 1;

        % Volume size [x y z] (from size() of the first volume) and the
        % matching ASTRA volume geometry.
        volume_geometry = {};
        astra_volume_geometry = {};

        % One 'parallel3d_vec' vector description per geometry, and the
        % ASTRA projection geometries built from them.
        geometries = {};
        astra_projection_geometries = {};

        % Oversampled geometries: one cell of sub-geometry vectors per
        % geometry, and the corresponding nested cell of ASTRA geometries.
        ss_geometries = {};
        astra_ss_projection_geometries = {};

        % 'cuda3d' projector ids (only used when use_astra_projectors is
        % true). astra_ss_projectors is a cell of cells of ids.
        astra_projectors = {};
        astra_ss_projectors = {};
        use_astra_projectors = false;

        % ASTRA data ids (only used when use_astra_projectors is false).
        volume_id = [];
        sinogram_ids = {};

        % Preallocated FP3D_CUDA / BP3D_CUDA algorithm ids, one per sinogram.
        algo_fproj_ids = {};
        algo_bproj_ids = {};

        % Task statistics object, created per instance in the constructor.
        statistics = [];
    end

    properties (Constant)
        messageNoGPU = 'This machine cannot be used to forward and back project';
    end

    methods (Access = public)
        function self = Gt6DVolumeProjector(volumes, detector_size, varargin)
            % Gt6DVolumeProjector  Validate the volumes and set up the ASTRA
            % volume geometry (and, for the data3d path, the volume object).
            %
            % Inputs:
            %   volumes       - non-empty cell array of equally sized volumes.
            %                   They are only checked here; storing them is
            %                   left to the derived classes.
            %   detector_size - optional detector size (see proj_size). When
            %                   omitted or empty, a square detector of side
            %                   round(norm(size(volumes{1}))) is used.
            %   varargin      - property/value pairs applied by parse_pv_pairs.
            %
            % Errors:
            %   Gt6DVolumeProjector:wrong_argument - no volumes were given
            %   PROJECTOR:wrong_argument           - the volume sizes differ

            if (~exist('volumes', 'var') || isempty(volumes))
                error('Gt6DVolumeProjector:wrong_argument', ...
                    'No initial volumes specified')
            end

            if (~exist('detector_size', 'var') || isempty(detector_size))
                diagonal = round(norm(size(volumes{1})));
                self.proj_size = [diagonal diagonal];
            else
                self.proj_size = detector_size;
            end

            % Created here rather than as a property default: a default
            % value expression is evaluated only once per class load, so a
            % handle object there would be shared by every projector.
            % Done before parse_pv_pairs so a user-supplied value wins.
            self.statistics = GtTasksStatistics();

            self = parse_pv_pairs(self, varargin);

            % The astra_mex_direct_c path is only kept when that MEX file is
            % available (exist == 3) and no voxel supersampling is requested.
            if ((exist('astra_mex_direct_c', 'file') ~= 3) || (self.volume_ss ~= 1))
                self.use_astra_projectors = false;
            end

            num_volumes = numel(volumes);

            % We check them now, but they will be stored in the derived
            % classes. Sizes are compared as 3-element vectors, so volumes
            % with a different number of dimensions are reported as a
            % mismatch instead of failing inside the comparison itself.
            [vol_size(1), vol_size(2), vol_size(3)] = size(volumes{1});
            for ii = 2:num_volumes
                [cur_size(1), cur_size(2), cur_size(3)] = size(volumes{ii});
                if (any(cur_size ~= vol_size))
                    error('PROJECTOR:wrong_argument', ...
                        'Size mismatch between the old and the new volume');
                end
            end

            % volume geometry (x, y, z): the first two dimensions are handed
            % to astra_create_vol_geom in swapped order (rows, cols, slices).
            self.volume_geometry = vol_size;
            self.astra_volume_geometry = astra_create_vol_geom(vol_size(2), vol_size(1), vol_size(3));
            if (~self.use_astra_projectors)
                self.volume_id = astra_mex_data3d('create', '-vol', self.astra_volume_geometry, 0);
            end
        end

        function delete(self)
            % delete  Release every ASTRA object owned by this projector.
            %
            % Algorithms and projectors (regular and oversampled) are
            % released before the volume data object they reference.
            self.reset_geometry();
            self.reset_ss_geometry();

            if (~isempty(self.volume_id))
                astra_mex_data3d('delete', self.volume_id);
                self.volume_id = [];
            end
        end

        function initProjectionGeometry(self)
            % initProjectionGeometry  Build the ASTRA objects for self.geometries.
            %
            % The detector position is defined through a vector pointing to
            % its center. For every entry of self.geometries, a
            % 'parallel3d_vec' geometry of size proj_size is created, followed
            % by either:
            %   - one 'cuda3d' projector per geometry (use_astra_projectors), or
            %   - the sinogram data objects plus one FP3D_CUDA and one
            %     BP3D_CUDA algorithm per sinogram, all bound to volume_id.

            % Let's clean up previous geometries
            if (~isempty(self.astra_projection_geometries))
                self.reset_geometry();
            end

            num_geoms = numel(self.geometries);
            for n = 1:num_geoms
                self.astra_projection_geometries{n} = astra_create_proj_geom(...
                    'parallel3d_vec', self.proj_size(2), self.proj_size(1), self.geometries{n});

                if (self.use_astra_projectors)
                    self.astra_projectors{n} = astra_create_projector('cuda3d', ...
                        self.astra_projection_geometries{n}, self.astra_volume_geometry);
                end
            end

            if (~self.use_astra_projectors)
                self.sinogram_ids = astra_mex_data3d('create', '-proj3d', self.astra_projection_geometries);

                % Let's preallocate the algorithms
                cfg = astra_struct('FP3D_CUDA');
                for ii = 1:numel(self.sinogram_ids)
                    cfg.ProjectionDataId = self.sinogram_ids{ii};
                    cfg.VolumeDataId = self.volume_id;
                    % For square weights
%                     cfg.ProjectionKernel = 'sum_square_weights';

                    self.algo_fproj_ids{ii} = astra_mex_algorithm('create', cfg);
                end

                cfg = astra_struct('BP3D_CUDA');
                for ii = 1:numel(self.sinogram_ids)
                    cfg.ProjectionDataId = self.sinogram_ids{ii};
                    cfg.ReconstructionDataId = self.volume_id;
                    cfg.option.VoxelSuperSampling = self.volume_ss;

                    self.algo_bproj_ids{ii} = astra_mex_algorithm('create', cfg);
                end
            end
        end

        function initProjectionGeometrySS(self)
            % initProjectionGeometrySS  Build the ASTRA geometries (and, for
            % the direct path, 'cuda3d' projectors) for self.ss_geometries.
            %
            % Result: astra_ss_projection_geometries{g}{s} (and, if enabled,
            % astra_ss_projectors{g}{s}) for sub-geometry s of geometry g.

            % Let's clean up previous geometries
            if (~isempty(self.astra_ss_projection_geometries))
                self.reset_ss_geometry();
            end

            num_ss_geoms = numel(self.ss_geometries);

            for g_ii = 1:num_ss_geoms
                num_sub_ss_geoms = numel(self.ss_geometries{g_ii});
                self.astra_ss_projection_geometries{g_ii} = cell(num_sub_ss_geoms, 1);
                if (self.use_astra_projectors)
                    self.astra_ss_projectors{g_ii} = cell(num_sub_ss_geoms, 1);
                end
                for g_ss_ii = 1:num_sub_ss_geoms
                    self.astra_ss_projection_geometries{g_ii}{g_ss_ii} = astra_create_proj_geom(...
                        'parallel3d_vec', self.proj_size(2), self.proj_size(1), self.ss_geometries{g_ii}{g_ss_ii});

                    if (self.use_astra_projectors)
                        self.astra_ss_projectors{g_ii}{g_ss_ii} = astra_create_projector('cuda3d', ...
                            self.astra_ss_projection_geometries{g_ii}{g_ss_ii}, self.astra_volume_geometry);
                    end
                end
            end
        end

        function printStats(self)
            % printStats  Print the collected task statistics.
            self.statistics.printStats()
        end

        function stats = get_statistics(self)
            % get_statistics  Return the task statistics object.
            stats = self.statistics;
        end
    end

    methods (Access = protected)
        function ss_sinogram = fwd_project_single_volume_ss(self, volume, ii)
            % fwd_project_single_volume_ss  Forward project a volume through
            % every sub-geometry of oversampled geometry ii.
            %
            % Each sub-projection uses volume ./ num_sub_geometries. Returns a
            % (num_sub_geometries x 1) cell with one sinogram per sub-geometry.
            % Direct path: a scalar volume is expanded to a full single
            % volume of size volume_geometry. Data3d path: the volume is
            % stored into volume_id, and a temporary sinogram object and
            % FP3D_CUDA algorithm are created and deleted per sub-geometry.
            ss_geoms = self.astra_ss_projection_geometries{ii};
            num_ss_geoms = numel(ss_geoms);
            volume = volume ./ num_ss_geoms;
            if (self.use_astra_projectors)
                if (numel(volume) == 1)
                    volume = volume * ones(self.volume_geometry, 'single');
                end
                ss_projectors = self.astra_ss_projectors{ii};
            else
                astra_mex_data3d('store', self.volume_id, volume);
            end

            ss_sinogram = cell(num_ss_geoms, 1);
            for ii_ss = 1:num_ss_geoms
                if (self.use_astra_projectors)
                    ss_sinogram{ii_ss} = astra_mex_direct_c('FP3D', ss_projectors{ii_ss}, volume);
                else
                    sino_id = astra_mex_data3d('create', '-proj3d', ss_geoms{ii_ss}, 0);

                    cfg = astra_struct('FP3D_CUDA');
                    cfg.ProjectionDataId = sino_id;
                    cfg.VolumeDataId = self.volume_id;

                    algo_id = astra_mex_algorithm('create', cfg);
                    astra_mex_algorithm('iterate', algo_id);
                    astra_mex_algorithm('delete', algo_id);

                    ss_sinogram{ii_ss} = astra_mex_data3d('get_single', sino_id);
                    astra_mex_data3d('delete', sino_id);
                end
            end
        end

        function volume = bwd_project_single_volume_ss(self, ss_sinogram, ii)
            % bwd_project_single_volume_ss  Back project one sinogram per
            % sub-geometry of oversampled geometry ii.
            %
            % ss_sinogram must hold one sinogram per sub-geometry. The
            % individual back-projections are combined with
            % gtMathsSumCellVolumes (presumably an element-wise sum) and
            % divided by the number of sub-geometries.
            ss_geoms = self.astra_ss_projection_geometries{ii};
            num_ss_geoms = numel(ss_geoms);
            if (self.use_astra_projectors)
                ss_projectors = self.astra_ss_projectors{ii};
            end

            volume = cell(num_ss_geoms, 1);
            for ii_ss = 1:num_ss_geoms
                if (self.use_astra_projectors)
                    volume{ii_ss} = astra_mex_direct_c('BP3D', ss_projectors{ii_ss}, ss_sinogram{ii_ss});
                else
                    sino_id = astra_mex_data3d('create', '-proj3d', ss_geoms{ii_ss}, ss_sinogram{ii_ss});
                    cfg = astra_struct('BP3D_CUDA');
                    cfg.ProjectionDataId = sino_id;
                    cfg.ReconstructionDataId = self.volume_id;

                    algo_id = astra_mex_algorithm('create', cfg);
                    astra_mex_algorithm('iterate', algo_id);
                    astra_mex_algorithm('delete', algo_id);

                    volume{ii_ss} = astra_mex_data3d('get_single', self.volume_id);
                    astra_mex_data3d('delete', sino_id);
                end
            end
            volume = gtMathsSumCellVolumes(volume) ./ num_ss_geoms;
        end

        function sinogram = fwd_project_single_volume(self, volume, geom_idx)
            % fwd_project_single_volume  Forward project a volume through
            % geometry geom_idx.
            %
            % Direct path: a scalar volume is expanded to a full single
            % volume first. Data3d path: the volume is stored into volume_id
            % and the preallocated FP3D_CUDA algorithm is run.
            if (self.use_astra_projectors)
                if (numel(volume) == 1)
                    volume = volume * ones(self.volume_geometry, 'single');
                end
                sinogram = astra_mex_direct_c('FP3D', self.astra_projectors{geom_idx}, volume);
            else
                astra_mex_data3d('store', self.volume_id, volume);
                astra_mex_algorithm('iterate', self.algo_fproj_ids{geom_idx});
                sinogram = astra_mex_data3d('get_single', self.sinogram_ids{geom_idx});
            end
        end

        function volume = bwd_project_single_volume(self, sinogram, geom_idx)
            % bwd_project_single_volume  Back project a sinogram through
            % geometry geom_idx.
            %
            % Data3d path: the sinogram is stored into its preallocated data
            % object and the preallocated BP3D_CUDA algorithm is run.
            if (self.use_astra_projectors)
                volume = astra_mex_direct_c('BP3D', self.astra_projectors{geom_idx}, sinogram);
            else
                astra_mex_data3d('store', self.sinogram_ids{geom_idx}, sinogram);
                astra_mex_algorithm('iterate', self.algo_bproj_ids{geom_idx});
                volume = astra_mex_data3d('get_single', self.volume_id);
            end
        end

        function reset_geometry(self)
            % reset_geometry  Release the ASTRA objects created by
            % initProjectionGeometry and clear the corresponding ids.
            self.astra_projection_geometries = {};

            if (~self.use_astra_projectors)
                % Delete the algorithms before the sinogram data objects
                % they were created on.
                if (~isempty(self.algo_fproj_ids))
                    astra_mex_algorithm('delete', self.algo_fproj_ids{:});
                    self.algo_fproj_ids = {};
                end

                if (~isempty(self.algo_bproj_ids))
                    astra_mex_algorithm('delete', self.algo_bproj_ids{:});
                    self.algo_bproj_ids = {};
                end

                if (~isempty(self.sinogram_ids))
                    astra_mex_data3d('delete', self.sinogram_ids{:});
                    self.sinogram_ids = {};
                end
            else
                % Delete the projectors one id at a time.
                for p_ii = 1:numel(self.astra_projectors)
                    astra_mex_projector3d('delete', self.astra_projectors{p_ii});
                end
                self.astra_projectors = {};
            end
        end

        function reset_ss_geometry(self)
            % reset_ss_geometry  Release the ASTRA projectors created by
            % initProjectionGeometrySS and clear the oversampled geometries.
            self.astra_ss_projection_geometries = {};

            if (self.use_astra_projectors && ~isempty(self.astra_ss_projectors))
                % astra_ss_projectors is a cell of cells (one inner cell per
                % geometry), so the projector ids must be reached through
                % both levels. Expanding only the outer level would pass
                % cell arrays instead of ids.
                for g_ii = 1:numel(self.astra_ss_projectors)
                    sub_projectors = self.astra_ss_projectors{g_ii};
                    for g_ss_ii = 1:numel(sub_projectors)
                        astra_mex_projector3d('delete', sub_projectors{g_ss_ii});
                    end
                end
                self.astra_ss_projectors = {};
            end
        end
    end
end