Skip to content
Snippets Groups Projects
Gt6DVolumeProjector.m 5.81 KiB
classdef Gt6DVolumeProjector < handle
% GT6DVOLUMEPROJECTOR  Base (handle) class that owns the ASTRA 'cuda3d'
% projectors used to forward- and back-project volumes.
%
%   One ASTRA volume geometry is created from the volume size, and one
%   'parallel3d_vec' projection geometry plus one 'cuda3d' projector are
%   created for every entry of 'geometries' (or of 'ss_geometries', when
%   super-sampling geometries are provided). Subclasses perform the actual
%   projections through the protected methods 'fwd_project_volumes_to_sinos'
%   and 'bwd_project_sinos_to_volumes', which call 'astra_mex_direct_c'.
%   The ASTRA projectors are released by 'delete'.

    properties
        % Detector size: proj_size(2) is passed to ASTRA as the detector row
        % count and proj_size(1) as the detector column count
        proj_size = [];

        volume_ss = 1; % Volume downscaling option
        rspace_ss = 1; % Real-space oversampling

        % Volume size as given to the constructor ([x, y, z]) and the
        % corresponding ASTRA volume geometry
        volume_geometry = {};
        astra_volume_geometry = {};

        % Per-geometry vector arrays (ASTRA 'parallel3d_vec' format). When
        % 'ss_geometries' is non-empty, entry n is a cell of sub-geometries
        % that are concatenated along dimension 1 and used instead of
        % 'geometries{n}'
        geometries = {};
        ss_geometries = {};

        % ASTRA projection geometries and projector ids, one per geometry
        astra_projection_geometries = {};
        astra_projector_ids = {};

        % Not used by the methods of this base class
        algo_fproj_ids = {};
        algo_bproj_ids = {};

        data_type = 'single';

        num_threads = 1;
        num_gpus = 1;
        jobs_bunch_size = 16;

        % Task statistics. Created in the constructor rather than as a
        % default value: MATLAB evaluates property defaults only once per
        % class load, so a handle object used as default value would be
        % shared by every instance of the class
        statistics = [];
    end

    properties (Constant)
        messageNoGPU = 'This machine cannot be used to forward and back project';
    end

    methods (Access = public)
        function self = Gt6DVolumeProjector(vols_size, detector_size, varargin)
        % GT6DVOLUMEPROJECTOR  Build the projector and its ASTRA objects.
        %
        %   vols_size     : volume size as [x, y, z]
        %   detector_size : detector size, stored in 'proj_size'
        %   varargin      : property/value pairs, applied with parse_pv_pairs
        %                   after the defaults computed here, so they take
        %                   precedence ('num_gpus' may still be overridden by
        %                   the XML configuration below)
        %
        %   Raises 'Gt6DVolumeProjector:bad_astra_installation' when the MEX
        %   file 'astra_mex_direct_c' is not available.

            % Check ASTRA before making any ASTRA call, so that a missing or
            % outdated installation is reported with the explicit message
            if (exist('astra_mex_direct_c', 'file') ~= 3)
                error('Gt6DVolumeProjector:bad_astra_installation', ...
                    '"astra_mex_direct_c" is not available! Please update or recompile ASTRA')
            end

            % Per-instance statistics object (see the 'statistics' property)
            self.statistics = GtTasksStatistics();

            % Default threading from the number of cores. Done before parsing
            % varargin, so that user-provided values are not overwritten
            try
                self.num_threads = feature('NumCores');
                self.jobs_bunch_size = self.num_threads;
            catch
                % Best effort: keep the class defaults if 'feature' fails
            end

            % volume geometry (x, y, z): ASTRA takes (rows, cols, slices),
            % hence the (y, x, z) ordering
            self.volume_geometry = vols_size;
            self.astra_volume_geometry = astra_create_vol_geom(vols_size(2), vols_size(1), vols_size(3));

            self.proj_size = detector_size;

            self = parse_pv_pairs(self, varargin);

            self.initProjectionGeometry();

            % GPU setup from the XML configuration: the first 'astra.gpu'
            % entry kept by the 'hostname' attribute filter (presumably the
            % one describing this machine) gives the GPU count and, when
            % present, the GPU indices handed to ASTRA. Any failure falls
            % back to a single GPU.
            try
                xml_conf = gtConfLoadXML();
                astra_gpus = gtConfGetField(xml_conf, 'astra.gpu');
                for ii_c = 1:numel(astra_gpus)
                    gpus = gtConfFilterAttribute(astra_gpus(ii_c), 'hostname');
                    if (~isempty(gpus))
                        self.num_gpus = gtConfGetField(gpus, 'count');
                        try
                            gpus_indx = gtConfGetField(gpus, 'indx');
                        catch
                            % No explicit indices: use the first num_gpus GPUs
                            gpus_indx = 0:self.num_gpus-1;
                        end
                        astra_mex('set_gpu_index', gpus_indx);
                        break;
                    end
                end
            catch mexc
                self.num_gpus = 1;
                gtPrintException(mexc, ...
                    'No Astra gpu information, defaulting to 1 gpu only');
            end
        end

        function delete(self)
        % DELETE  Destructor: releases the ASTRA projectors of this object.
            self.reset_geometry();
        end

        function initProjectionGeometry(self)
        % INITPROJECTIONGEOMETRY  (Re)create the ASTRA projection geometries
        % and 'cuda3d' projectors, one per geometry. Previously created
        % projectors are deleted first. With super-sampling geometries, the
        % sub-geometries of each entry are concatenated into one
        % 'parallel3d_vec' geometry.

            % Let's clean up previous geometries
            self.reset_geometry();

            % Volume Downscaling option and similar.
            % NOTE(review): 'GPUindex' = -1 is interpreted by ASTRA; it
            % presumably defers to the GPU selected via
            % astra_mex('set_gpu_index') - confirm
            opts = struct( ...
                'VoxelSuperSampling', self.volume_ss * self.rspace_ss, ...
                'DetectorSuperSampling', self.rspace_ss, ...
                'GPUindex', -1 );

            use_ss = self.using_super_sampling();
            if (use_ss)
                num_geoms = numel(self.ss_geometries);
            else
                num_geoms = numel(self.geometries);
            end

            for n = 1:num_geoms
                if (use_ss)
                    geom = cat(1, self.ss_geometries{n}{:});
                else
                    geom = self.geometries{n};
                end

                self.astra_projection_geometries{n} = astra_create_proj_geom(...
                    'parallel3d_vec', self.proj_size(2), self.proj_size(1), geom);

                % Ids are appended one at a time, so that only projectors that
                % were actually created get deleted by reset_geometry
                self.astra_projector_ids{n} = astra_create_projector('cuda3d', ...
                    self.astra_projection_geometries{n}, self.astra_volume_geometry, opts);
            end
        end

        function num_geometries = get_number_geometries(self)
        % Number of ASTRA projectors currently created.
            num_geometries = numel(self.astra_projector_ids);
        end

        function is_ss = using_super_sampling(self)
        % True when super-sampling geometries ('ss_geometries') are provided.
            is_ss = ~isempty(self.ss_geometries);
        end

        function chunk_size = get_jobs_chunk_size(self)
        % Number of jobs per chunk: jobs_bunch_size for each GPU.
            chunk_size = self.num_gpus * self.jobs_bunch_size;
        end

        function printStats(self)
        % Print the collected task statistics.
            self.statistics.printStats()
        end

        function stats = get_statistics(self)
        % Return the task statistics object.
            stats = self.statistics;
        end
    end

    methods (Access = protected)
        function sinogram = fwd_project_volumes_to_sinos(self, volume, n)
        % Basic Fwd-Projection function
        %   volume : volume(s) to project; presumably a cell with one volume
        %            per element of n when n is a vector - TODO confirm
        %   n      : index (or indices) of the projectors to use
        % With super-sampling, each volume is first divided by the number of
        % sub-geometries of its geometry.

            if (self.using_super_sampling())
                volume = self.rescale_volume_ss(volume, n);
            end

            sinogram = astra_mex_direct_c('FP3D', [self.astra_projector_ids{n}], volume);
        end

        function volume = bwd_project_sinos_to_volumes(self, sinogram, n)
        % Basic Bwd-Projection function
        %   sinogram : projection data for the projectors selected by n
        %   n        : index (or indices) of the projectors to use
        % With super-sampling, each back-projected volume is divided by the
        % number of sub-geometries of its geometry.

            volume = astra_mex_direct_c('BP3D', [self.astra_projector_ids{n}], sinogram);

            if (self.using_super_sampling())
                volume = self.rescale_volume_ss(volume, n);
            end
        end

        function reset_geometry(self)
        % Delete the ASTRA projectors and forget the projection geometries.
            self.astra_projection_geometries = {};

            if (~isempty(self.astra_projector_ids))
                astra_mex_projector3d('delete', self.astra_projector_ids{:});
                self.astra_projector_ids = {};
            end
        end
    end

    methods (Access = private)
        function volume = rescale_single_volume_ss(self, volume, n)
        % Divide a volume by the number of super-sampling sub-geometries of
        % geometry n.
            num_ss_geoms = numel(self.ss_geometries{n});
            volume = volume ./ num_ss_geoms;
        end

        function volume = rescale_volume_ss(self, volume, n)
        % Apply rescale_single_volume_ss to every volume. 'volume' is a cell
        % (one entry per element of n) when several indices are given; a
        % single volume may be passed directly or in a one-element cell.
            if (numel(n) > 1 || iscell(volume))
                for ii_o = 1:numel(n)
                    volume{ii_o} = self.rescale_single_volume_ss(volume{ii_o}, n(ii_o));
                end
            else
                volume = self.rescale_single_volume_ss(volume, n);
            end
        end
    end
end