Skip to content
Snippets Groups Projects
Gt6DVolumeProjector.m 10.9 KiB
Newer Older
classdef Gt6DVolumeProjector < handle
    properties
        proj_size = [];

        volume_geometry = {};
        astra_volume_geometry = {};

        geometries = {};
        astra_projection_geometries = {};
        ss_geometries = {};
        astra_ss_projection_geometries = {};

        % ASTRA 'cuda3d' projector ids, one per entry of geometries
        % (only created when use_astra_projectors is true)
        astra_projectors = {};
        % Cell of cells: one inner cell of projector ids per ss geometry
        astra_ss_projectors = {};

        % ASTRA volume data id; only allocated on the algorithm path
        % (use_astra_projectors false). Empty means "nothing to delete".
        volume_id = [];
        sinogram_ids = {};

        algo_fproj_ids = {};
        algo_bproj_ids = {};

        % Backend selection. The constructor forces use_astra_projectors to
        % false when astra_mex_direct_c is not available or volume_ss ~= 1,
        % and only enables use_astra_oss_projectors when supported.
        use_astra_projectors = true;
        use_astra_oss_projectors = false;

        % Voxel super-sampling factor passed to the BP3D_CUDA algorithms
        volume_ss = 1;

        % NOTE(review): MATLAB evaluates property defaults once per class
        % load; if GtTasksStatistics is a handle class, all projector
        % instances share this same statistics object - confirm intended.
        statistics = GtTasksStatistics();
    end

    properties (Constant)
        % Message text about the machine being unable to forward/back
        % project; presumably reported by callers when no usable GPU is
        % found (not referenced in the visible part of this class).
        messageNoGPU = 'This machine cannot be used to forward and back project';
    end

    methods (Access = public)
        function self = Gt6DVolumeProjector(vols_size, detector_size, varargin)
            % Gt6DVolumeProjector  Set up the volume geometry and pick the
            % projection backend.
            %
            % Inputs:
            %   vols_size     - volume size; elements 2, 1, 3 are passed (in
            %                   that order) to astra_create_vol_geom
            %   detector_size - detector size, stored in proj_size
            %   varargin      - property/value pairs applied via parse_pv_pairs
            %
            % The direct-C projectors (astra_mex_direct_c) are used only when
            % that MEX file is present and volume_ss == 1; otherwise an ASTRA
            % volume data object is allocated for the algorithm-based path.

            % volume geometry (x, y, z)
            self.volume_geometry = vols_size;
            self.astra_volume_geometry = astra_create_vol_geom(vols_size(2), vols_size(1), vols_size(3));
            self.proj_size = detector_size;
            self = parse_pv_pairs(self, varargin);

            % exist(..., 'file') == 3 means a MEX file was found
            if ((exist('astra_mex_direct_c', 'file') ~= 3) || (self.volume_ss ~= 1))
                self.use_astra_projectors = false;
            end

            if (~self.use_astra_projectors)
                % The OSS kernels consume direct-C projector ids, which are
                % never created on the algorithm path
                self.use_astra_oss_projectors = false;
                self.volume_id = astra_mex_data3d('create', '-vol', self.astra_volume_geometry, 0);
            elseif (~isempty(self.ss_geometries))
                try
                    self.use_astra_oss_projectors = astra_mex_direct_c('supports', 'FP3D_OSS') ...
                        && astra_mex_direct_c('supports', 'BP3D_OSS');
                catch
                    % Presumably an astra_mex_direct_c build without the
                    % 'supports' query: fall back to per-sub-geometry projection
                    self.use_astra_oss_projectors = false;
                end
            end
        end

        function delete(self)
            % delete  Release every ASTRA object owned by this projector.
            %
            % The projection-side objects (algorithms, sinograms, projectors)
            % are released first: the preallocated algorithms were configured
            % with self.volume_id, so the volume data object goes last.
            self.reset_geometry();
            self.reset_ss_geometry();

            if (~isempty(self.volume_id))
                astra_mex_data3d('delete', self.volume_id);
                self.volume_id = [];
            end
        end

        function initProjectionGeometry(self)
            % initProjectionGeometry  (Re)build the ASTRA projection geometries.
            %
            % One 'parallel3d_vec' projection geometry is created per entry of
            % self.geometries, with a detector of self.proj_size. The
            % detector position is defined through a vector pointing to its
            % center. Then:
            %   - direct-C path: one 'cuda3d' projector per geometry;
            %   - algorithm path: one sinogram data object per geometry plus
            %     preallocated FP3D_CUDA / BP3D_CUDA algorithms bound to it
            %     and to self.volume_id.

            % Let's clean up previous geometries (and their ASTRA objects)
            if (~isempty(self.astra_projection_geometries))
                self.reset_geometry();
            end

            num_geoms = numel(self.geometries);
            for n = 1:num_geoms
                self.astra_projection_geometries{n} = astra_create_proj_geom(...
                    'parallel3d_vec', self.proj_size(2), self.proj_size(1), self.geometries{n});

                if (self.use_astra_projectors)
                    self.astra_projectors{n} = astra_create_projector('cuda3d', ...
                        self.astra_projection_geometries{n}, self.astra_volume_geometry);
                end
            end

            % Nothing to allocate on the algorithm path without geometries
            if (~self.use_astra_projectors && (num_geoms > 0))
                self.sinogram_ids = astra_mex_data3d('create', '-proj3d', self.astra_projection_geometries);

                % Let's preallocate the algorithms; the volume binding and
                % the super-sampling option are the same for every geometry
                fp_cfg = astra_struct('FP3D_CUDA');
                fp_cfg.VolumeDataId = self.volume_id;
                % For square weights
                % fp_cfg.ProjectionKernel = 'sum_square_weights';

                bp_cfg = astra_struct('BP3D_CUDA');
                bp_cfg.ReconstructionDataId = self.volume_id;
                bp_cfg.option.VoxelSuperSampling = self.volume_ss;

                num_sinos = numel(self.sinogram_ids);
                self.algo_fproj_ids = cell(1, num_sinos);
                self.algo_bproj_ids = cell(1, num_sinos);
                for ii = 1:num_sinos
                    fp_cfg.ProjectionDataId = self.sinogram_ids{ii};
                    self.algo_fproj_ids{ii} = astra_mex_algorithm('create', fp_cfg);

                    bp_cfg.ProjectionDataId = self.sinogram_ids{ii};
                    self.algo_bproj_ids{ii} = astra_mex_algorithm('create', bp_cfg);
                end
            end
        end

        function initProjectionGeometrySS(self)
            % initProjectionGeometrySS  (Re)build the sub-geometry projection
            % geometries.
            %
            % self.ss_geometries is a cell of cells: for each geometry g_ii,
            % one 'parallel3d_vec' projection geometry is created per
            % sub-geometry vector. On the direct-C path a matching 'cuda3d'
            % projector is created per sub-geometry as well. Results are
            % stored as (num_sub_ss_geoms x 1) cells per geometry.

            % Let's clean up previous geometries
            if (~isempty(self.astra_ss_projection_geometries))
                self.reset_ss_geometry();
            end

            num_ss_geoms = numel(self.ss_geometries);

            for g_ii = 1:num_ss_geoms
                sub_geoms = self.ss_geometries{g_ii};
                num_sub_ss_geoms = numel(sub_geoms);

                self.astra_ss_projection_geometries{g_ii} = cell(num_sub_ss_geoms, 1);
                if (self.use_astra_projectors)
                    self.astra_ss_projectors{g_ii} = cell(num_sub_ss_geoms, 1);
                end

                for g_ss_ii = 1:num_sub_ss_geoms
                    self.astra_ss_projection_geometries{g_ii}{g_ss_ii} = astra_create_proj_geom(...
                        'parallel3d_vec', self.proj_size(2), self.proj_size(1), sub_geoms{g_ss_ii});

                    if (self.use_astra_projectors)
                        self.astra_ss_projectors{g_ii}{g_ss_ii} = astra_create_projector('cuda3d', ...
                            self.astra_ss_projection_geometries{g_ii}{g_ss_ii}, self.astra_volume_geometry);
                    end
                end
            end
        end

        function printStats(self)
            % printStats  Print the task statistics gathered so far.
            %
            % Simply forwards to the printStats method of the object held in
            % the statistics property.
            stats_obj = self.statistics;
            stats_obj.printStats();
        end

        function stats = get_statistics(self)
            % get_statistics  Accessor for the statistics property.
            %
            % Output:
            %   stats - the value of self.statistics (if GtTasksStatistics is
            %           a handle class this is the same object, not a copy)
            stats = self.statistics;
        end
    end

    methods (Access = protected)
        function ss_sinogram = fwd_project_single_volume_ss(self, volume, ii)
            % fwd_project_single_volume_ss  Forward project a volume through
            % all sub-geometries of geometry ii.
            %
            % Inputs:
            %   volume - volume data (a scalar is expanded on the direct-C path)
            %   ii     - index into astra_ss_projection_geometries
            % Output:
            %   ss_sinogram - FP3D_OSS result when the OSS projectors are in
            %                 use; otherwise a (num_ss_geoms x 1) cell with one
            %                 sinogram per sub-geometry, each projected from
            %                 volume ./ num_ss_geoms.
            ss_geoms = self.astra_ss_projection_geometries{ii};
            num_ss_geoms = numel(ss_geoms);

            if (self.use_astra_projectors)
                if (numel(volume) == 1)
                    volume = volume * ones(self.volume_geometry, 'single');
                end
                ss_projectors = self.astra_ss_projectors{ii};
            end

            % ss_projectors only exists on the direct-C path
            if (self.use_astra_projectors && self.use_astra_oss_projectors)
                ss_sinogram = astra_mex_direct_c('FP3D_OSS', [ss_projectors{:}], volume);
            else
                volume = volume ./ num_ss_geoms;
                if (~self.use_astra_projectors)
                    % Store the already-scaled volume so that both backends
                    % produce identically scaled sub-sinograms
                    astra_mex_data3d('store', self.volume_id, volume);
                end

                ss_sinogram = cell(num_ss_geoms, 1);
                for ii_ss = 1:num_ss_geoms
                    if (self.use_astra_projectors)
                        ss_sinogram{ii_ss} = astra_mex_direct_c('FP3D', ss_projectors{ii_ss}, volume);
                    else
                        % Temporary sinogram + algorithm for this sub-geometry
                        sino_id = astra_mex_data3d('create', '-proj3d', ss_geoms{ii_ss}, 0);
                        cfg = astra_struct('FP3D_CUDA');
                        cfg.ProjectionDataId = sino_id;
                        cfg.VolumeDataId = self.volume_id;
                        algo_id = astra_mex_algorithm('create', cfg);
                        astra_mex_algorithm('iterate', algo_id);
                        astra_mex_algorithm('delete', algo_id);
                        ss_sinogram{ii_ss} = astra_mex_data3d('get_single', sino_id);
                        astra_mex_data3d('delete', sino_id);
                    end
                end
            end
        end
        function volume = bwd_project_single_volume_ss(self, ss_sinogram, ii)
            % bwd_project_single_volume_ss  Back project sub-sinograms of
            % geometry ii into a single volume.
            %
            % Inputs:
            %   ss_sinogram - FP3D_OSS-style data when the OSS projectors are
            %                 in use; otherwise a cell with one sinogram per
            %                 sub-geometry
            %   ii          - index into astra_ss_projection_geometries
            % Output:
            %   volume - on the non-OSS paths, the sum of the per-sub-geometry
            %            back projections divided by num_ss_geoms
            ss_geoms = self.astra_ss_projection_geometries{ii};
            num_ss_geoms = numel(ss_geoms);
            if (self.use_astra_projectors)
                ss_projectors = self.astra_ss_projectors{ii};
            end

            % ss_projectors only exists on the direct-C path
            if (self.use_astra_projectors && self.use_astra_oss_projectors)
                volume = astra_mex_direct_c('BP3D_OSS', [ss_projectors{:}], ss_sinogram);
            else
                volume = cell(num_ss_geoms, 1);
                for ii_ss = 1:num_ss_geoms
                    if (self.use_astra_projectors)
                        volume{ii_ss} = astra_mex_direct_c('BP3D', ss_projectors{ii_ss}, ss_sinogram{ii_ss});
                    else
                        % Temporary sinogram + algorithm, writing into self.volume_id
                        sino_id = astra_mex_data3d('create', '-proj3d', ss_geoms{ii_ss}, ss_sinogram{ii_ss});
                        cfg = astra_struct('BP3D_CUDA');
                        cfg.ProjectionDataId = sino_id;
                        cfg.ReconstructionDataId = self.volume_id;
                        algo_id = astra_mex_algorithm('create', cfg);
                        astra_mex_algorithm('iterate', algo_id);
                        astra_mex_algorithm('delete', algo_id);
                        volume{ii_ss} = astra_mex_data3d('get_single', self.volume_id);
                        astra_mex_data3d('delete', sino_id);
                    end
                end
                volume = gtMathsSumCellVolumes(volume) ./ num_ss_geoms;
            end
        end
        function sinogram = fwd_project_single_volume(self, volume, geom_idx)
            % fwd_project_single_volume  Forward project a volume with the
            % projection geometry geom_idx.
            %
            % Direct-C path: a scalar volume is first expanded to a full
            % single-precision volume of size self.volume_geometry.
            % Algorithm path: the volume is stored into self.volume_id, the
            % preallocated FP3D_CUDA algorithm for geom_idx is run and its
            % sinogram data object is read back.
            if (self.use_astra_projectors)
                if (numel(volume) == 1)
                    volume = volume * ones(self.volume_geometry, 'single');
                end
                sinogram = astra_mex_direct_c('FP3D', self.astra_projectors{geom_idx}, volume);
            else
                astra_mex_data3d('store', self.volume_id, volume);
                astra_mex_algorithm('iterate', self.algo_fproj_ids{geom_idx});
                sinogram = astra_mex_data3d('get_single', self.sinogram_ids{geom_idx});
            end
        end
        function volume = bwd_project_single_volume(self, sinogram, geom_idx)
            % bwd_project_single_volume  Back project a sinogram with the
            % projection geometry geom_idx.
            %
            % Algorithm path: the sinogram is stored into the preallocated
            % sinogram object, the matching BP3D_CUDA algorithm is run and
            % the result is read back from self.volume_id.
            % Direct-C path: a single BP3D call on the geometry's projector.
            if (~self.use_astra_projectors)
                sino_id = self.sinogram_ids{geom_idx};
                astra_mex_data3d('store', sino_id, sinogram);
                astra_mex_algorithm('iterate', self.algo_bproj_ids{geom_idx});
                volume = astra_mex_data3d('get_single', self.volume_id);
                return;
            end

            projector_id = self.astra_projectors{geom_idx};
            volume = astra_mex_direct_c('BP3D', projector_id, sinogram);
        end

        function reset_geometry(self)
            % reset_geometry  Drop the projection geometries and release the
            % ASTRA objects created by initProjectionGeometry.
            %
            % On the algorithm path the algorithms are deleted before the
            % sinogram data objects they were configured with
            % (ProjectionDataId), so no algorithm outlives its data.
            self.astra_projection_geometries = {};

            if (~self.use_astra_projectors)
                if (~isempty(self.algo_fproj_ids))
                    astra_mex_algorithm('delete', self.algo_fproj_ids{:});
                    self.algo_fproj_ids = {};
                end

                if (~isempty(self.algo_bproj_ids))
                    astra_mex_algorithm('delete', self.algo_bproj_ids{:});
                    self.algo_bproj_ids = {};
                end

                if (~isempty(self.sinogram_ids))
                    astra_mex_data3d('delete', self.sinogram_ids{:});
                    self.sinogram_ids = {};
                end
            elseif (~isempty(self.astra_projectors))
                astra_mex_projector3d('delete', self.astra_projectors{:});
                self.astra_projectors = {};
            end
        end

        function reset_ss_geometry(self)
            % reset_ss_geometry  Drop the sub-geometry projection geometries
            % and release their direct-C projectors.
            %
            % astra_ss_projectors is a cell of cells (one inner cell of
            % projector ids per geometry, filled by initProjectionGeometrySS),
            % so it is flattened into a single list of ids before being passed
            % to astra_mex_projector3d; passing the inner cells would not
            % delete anything.
            self.astra_ss_projection_geometries = {};

            if (self.use_astra_projectors && ~isempty(self.astra_ss_projectors))
                per_geom_ids = cellfun(@(ids) reshape(ids, [], 1), ...
                    self.astra_ss_projectors, 'UniformOutput', false);
                proj_ids = vertcat(per_geom_ids{:});
                if (~isempty(proj_ids))
                    astra_mex_projector3d('delete', proj_ids{:});
                end
                self.astra_ss_projectors = {};
            end
        end