Skip to content
Snippets Groups Projects
Gt6DVolumeProjector.m 12.4 KiB
Newer Older
classdef Gt6DVolumeProjector < handle
% GT6DVOLUMEPROJECTOR  Forward/back projector for a set of equally sized
% volumes, built on ASTRA 3D data objects and FP3D_CUDA / BP3D_CUDA
% algorithms with 'parallel3d_vec' projection geometries.
%
% ASTRA objects owned by an instance:
%   - constructor:              one '-vol' data object per input volume
%   - initProjectionGeometry(): one '-proj3d' data object plus one FP and one
%                               BP algorithm per entry of 'geometries'
%   - *SingleVolumeSS():        temporary objects, freed before returning
%   - delete():                 frees everything still allocated

    properties
        % Detector size: proj_size(2) is passed to astra_create_proj_geom as
        % the detector row count and proj_size(1) as the column count.
        proj_size = [];

        % Currently unused: the ASTRA options that would consume them are
        % commented out in the algorithm configurations below.
        angular_deviation = 0;
        angular_sampling = 1;

        % Per volume: size [x y z] and the matching ASTRA volume geometry.
        volume_geometries = {};
        astra_volume_geometries = {};

        % 'parallel3d_vec' geometry vectors (a cell with one entry per volume,
        % or a single array) and the ASTRA geometry structs built from them.
        geometries = {};
        astra_geometries = {};

        % Interpolation coefficients (fields 'indx' and 'coeff') and geometry
        % vectors for the "SS" projections, plus their ASTRA geometry structs.
        ss_coeffs = {};
        ss_geometries = {};
        astra_ss_geometries = {};

        % ASTRA data object ids (cells of ids).
        volume_ids = {};
        sinogram_ids = {};

        % ASTRA algorithm ids, one FP and one BP per sinogram.
        algo_fproj_ids = {};
        algo_bproj_ids = {};

        % Timing statistics. Created in the constructor: a handle object used
        % as a property default is evaluated once and shared by all instances.
        statistics = [];
    end

    properties (Constant)
        messageNoGPU = 'This machine cannot be used to forward and back project';
    end

    methods (Access = public)
        function self = Gt6DVolumeProjector(initialVolumes, detectorSize, varargin)
        % GT6DVOLUMEPROJECTOR  Creates one ASTRA volume object per input volume.
        %
        % Inputs:
        %   initialVolumes : cell array of equally sized volumes (a single
        %                    non-cell volume is wrapped into a 1x1 cell)
        %   detectorSize   : (optional) value for proj_size; when omitted or
        %                    empty, [d d] with d = round(norm(size(volume)))
        %   varargin       : property/value pairs passed to parse_pv_pairs
        %
        % Errors:
        %   'Gt6DVolumeProjector:wrong_argument' when no volume is given
        %   'PROJECTOR:wrong_argument' when the volumes differ in size

%             self.checkGPU()

            if (~exist('initialVolumes', 'var') || isempty(initialVolumes))
                error('Gt6DVolumeProjector:wrong_argument', ...
                    'No initial volumes specified')
            end
            if (~iscell(initialVolumes))
                initialVolumes = {initialVolumes};
            end

            numVolumes = numel(initialVolumes);

            % Per-instance statistics (see the property comment).
            self.statistics = GtTasksStatistics();

            self.statistics.add_task('fproj_raw_sum', 'Raw FP time');
            self.statistics.add_task('bproj_raw_sum', 'Raw BP time');

            self.statistics.add_task('fproj_raw_ss_single', 'Raw SS FP time (single)');
            self.statistics.add_task('bproj_raw_ss_single', 'Raw SS BP time (single)');

            for ii = 2:numVolumes
                self.testVolumeGeometry(initialVolumes{1}, initialVolumes{ii});
            end

            % volume geometry (x, y, z): identical for every volume
            volSize = [size(initialVolumes{1}, 1), ...
                size(initialVolumes{1}, 2), size(initialVolumes{1}, 3)];
            tempVolGeom = astra_create_vol_geom(volSize(1), volSize(2), volSize(3));

            self.volume_geometries = repmat({volSize}, numVolumes, 1);
            self.astra_volume_geometries = repmat({tempVolGeom}, numVolumes, 1);

            self.volume_ids = astra_mex_data3d('create', '-vol', self.astra_volume_geometries, initialVolumes);

            if (~exist('detectorSize', 'var') || isempty(detectorSize))
                diagonal = round(norm(size(initialVolumes{1})));
                self.proj_size = [diagonal diagonal];
            else
                self.proj_size = detectorSize;
            end

            self = parse_pv_pairs(self, varargin);
        end

        function delete(self)
        % DELETE  Frees every ASTRA algorithm and data object still owned.
        % Algorithms are freed first, because they reference the data objects.
            if (~isempty(self.algo_fproj_ids))
                astra_mex_algorithm('delete', self.algo_fproj_ids{:});
            end
            if (~isempty(self.algo_bproj_ids))
                astra_mex_algorithm('delete', self.algo_bproj_ids{:});
            end

            if (~isempty(self.sinogram_ids))
                astra_mex_data3d('delete', self.sinogram_ids{:});
            end
            if (~isempty(self.volume_ids))
                astra_mex_data3d('delete', self.volume_ids{:});
            end
        end

        function initProjectionGeometry(self)
        % INITPROJECTIONGEOMETRY  (Re)creates the ASTRA sinogram objects and
        % the FP3D_CUDA / BP3D_CUDA algorithms from self.geometries.
        %
        % self.geometries is either a cell of 'parallel3d_vec' vector arrays
        % or a single vector array. Each detector position is defined through
        % a vector pointing to its center. The n-th geometry is paired with
        % the n-th volume.

            % Free what a previous call allocated: algorithms first (they
            % reference the data), then the sinogram data objects. The ASTRA
            % geometry structs hold no ASTRA resources and are just dropped.
            if (~isempty(self.algo_fproj_ids))
                astra_mex_algorithm('delete', self.algo_fproj_ids{:});
            end
            if (~isempty(self.algo_bproj_ids))
                astra_mex_algorithm('delete', self.algo_bproj_ids{:});
            end
            if (~isempty(self.sinogram_ids))
                astra_mex_data3d('delete', self.sinogram_ids{:});
            end
            self.algo_fproj_ids = {};
            self.algo_bproj_ids = {};
            self.sinogram_ids = {};
            self.astra_geometries = {};

            geometry = self.geometries;
            if (~iscell(geometry))
                geometry = {geometry};
            end
            num_geoms = numel(geometry);

            % Each sinogram is paired with the volume of the same index, so
            % fail before allocating anything if there are too few volumes.
            if (num_geoms > numel(self.volume_ids))
                error('Gt6DVolumeProjector:wrong_argument', ...
                    'More projection geometries (%d) than volumes (%d)', ...
                    num_geoms, numel(self.volume_ids))
            end

            for n = 1:num_geoms
                self.astra_geometries{n} = astra_create_proj_geom(...
                    'parallel3d_vec', self.proj_size(2), self.proj_size(1), geometry{n});
            end
            self.sinogram_ids = astra_mex_data3d('create', '-proj3d', self.astra_geometries);

            num_sinos = numel(self.sinogram_ids);

            % Let's preallocate the algorithms
            cfg = astra_struct('FP3D_CUDA');
%             cfg.option.DetectorSuperSampling = self.angular_sampling;
%             cfg.option.AngularDeviation = self.angular_deviation;
            self.algo_fproj_ids = cell(1, num_sinos);
            for ii = 1:num_sinos
                cfg.ProjectionDataId = self.sinogram_ids{ii};
                cfg.VolumeDataId = self.volume_ids{ii};
                % For square weights
%                 cfg.ProjectionKernel = 'sum_square_weights';

                self.algo_fproj_ids{ii} = astra_mex_algorithm('create', cfg);
            end

            cfg = astra_struct('BP3D_CUDA');
%             cfg.option.VoxelSuperSampling = self.angular_sampling;
%             cfg.option.AngularDeviation = self.angular_deviation;
            self.algo_bproj_ids = cell(1, num_sinos);
            for ii = 1:num_sinos
                cfg.ProjectionDataId = self.sinogram_ids{ii};
                cfg.ReconstructionDataId = self.volume_ids{ii};

                self.algo_bproj_ids{ii} = astra_mex_algorithm('create', cfg);
            end
        end

        function initProjectionGeometrySS(self)
        % INITPROJECTIONGEOMETRYSS  Rebuilds the ASTRA geometry structs for
        % self.ss_geometries. These are plain structs (no ASTRA objects are
        % allocated here), so the previous ones are simply replaced.
            num_ss_geoms = numel(self.ss_geometries);

            self.astra_ss_geometries = cell(1, num_ss_geoms);
            for g_ii = 1:num_ss_geoms
                self.astra_ss_geometries{g_ii} = astra_create_proj_geom(...
                    'parallel3d_vec', self.proj_size(2), self.proj_size(1), self.ss_geometries{g_ii});
            end
        end

        function printStats(self)
        % PRINTSTATS  Prints the collected timing statistics.
            self.statistics.printStats()
        end

        function forwardWeights = getRowsSum(self)
        % GETROWSSUM  Forward projects volumes filled with ones, one per
        % volume geometry.
            numGeoms = numel(self.astra_volume_geometries);
            ones_cell = repmat({1}, 1, numGeoms);

            forwardWeights = self.projectVolume(ones_cell);
        end

        function backwardWeights = getColumnsSum(self)
        % GETCOLUMNSSUM  Backprojects sinograms filled with ones, one per
        % projection geometry (so the count matches backprojectVolume's check).
            numGeoms = numel(self.astra_geometries);
            ones_cell = repmat({1}, 1, numGeoms);

            backwardWeights = self.backprojectVolume(ones_cell);
        end
    end

    methods (Access = protected)
        function sinograms = projectVolume(self, volumes)
        % PROJECTVOLUME  Stores 'volumes' (one per projection geometry) into
        % the ASTRA volume objects, runs every FP algorithm and returns the
        % sinograms as single precision.
            numGeoms = numel(self.astra_geometries);
            if (numGeoms ~= numel(volumes))
                error('Gt6DVolumeProjector:wrong_argument', ...
                    'Number of volumes mismatch with allocated buffers')
            end
            astra_mex_data3d('store', self.volume_ids, volumes);

            self.statistics.tic('fproj_raw_sum');
%             astra_mex_algorithm('iterate', self.algo_fproj_ids, 1, 0, true);
            for ii = 1:numGeoms
                astra_mex_algorithm('iterate', self.algo_fproj_ids{ii});
            end
            self.statistics.toc('fproj_raw_sum');

%             sinograms = astra_mex_data3d('get', self.sinogram_ids);
            sinograms = astra_mex_data3d('get_single', self.sinogram_ids);
        end

        function volumes = backprojectVolume(self, sinograms)
        % BACKPROJECTVOLUME  Stores 'sinograms' (one per projection geometry)
        % into the ASTRA sinogram objects, runs every BP algorithm and returns
        % the volumes as single precision.
            numGeoms = numel(self.astra_geometries);
            if (numGeoms ~= numel(sinograms))
                error('Gt6DVolumeProjector:backprojVol:wrong_argument', ...
                    'Number of sinograms mismatch with allocated buffers')
            end
            astra_mex_data3d('store', self.sinogram_ids, sinograms);

            self.statistics.tic('bproj_raw_sum');
%             astra_mex_algorithm('iterate', self.algo_bproj_ids, 1, 0, true);
            for ii = 1:numGeoms
                astra_mex_algorithm('iterate', self.algo_bproj_ids{ii});
            end
            self.statistics.toc('bproj_raw_sum');

%             volumes = astra_mex_data3d('get', self.volume_ids);
            volumes = astra_mex_data3d('get_single', self.volume_ids);
        end

        function ss_sinograms = projectVolumeSS(self, volumes)
        % PROJECTVOLUMESS  Returns a column cell with one SS sinogram per
        % entry of self.ss_geometries.
            num_ss_geoms = numel(self.ss_geometries);
            ss_sinograms = cell(num_ss_geoms, 1);

            for ii = 1:num_ss_geoms
                ss_sinograms{ii} = self.projectSingleVolumeSS(volumes, ii);
            end
        end

        function ss_sinogram = projectSingleVolumeSS(self, volumes, ii)
        % PROJECTSINGLEVOLUMESS  Forward projects, with the ii-th SS geometry,
        % the volume sum_c ss_coeffs(ii).coeff(c) * volumes{ss_coeffs(ii).indx(c)}.
        % The ASTRA objects are temporary and are freed before returning,
        % also when an ASTRA call throws.
            self.statistics.tic('fproj_raw_ss_single');

            % NOTE(review): ss_coeffs is indexed with (), so it is presumably
            % assigned a struct array elsewhere (its default is a cell) - confirm.
            coeffs = self.ss_coeffs(ii);
            vol = zeros(self.volume_geometries{1}, class(volumes{1}));
            for c = 1:numel(coeffs.indx)
                vol = vol + coeffs.coeff(c) * volumes{coeffs.indx(c)};
            end

            cfg = astra_struct('FP3D_CUDA');
%             cfg.option.DetectorSuperSampling = self.angular_sampling;
%             cfg.option.AngularDeviation = self.angular_deviation;

            sino_id = astra_mex_data3d('create', '-proj3d', self.astra_ss_geometries{ii}, 0);
            vol_id = [];
            algo_id = [];
            try
                vol_id = astra_mex_data3d('create', '-vol', self.astra_volume_geometries{1}, vol);

                cfg.ProjectionDataId = sino_id;
                cfg.VolumeDataId = vol_id;

                algo_id = astra_mex_algorithm('create', cfg);
                astra_mex_algorithm('iterate', algo_id);

                ss_sinogram = astra_mex_data3d('get_single', sino_id);
            catch mexc
                % Free whatever was allocated before the failure, then rethrow.
                if (~isempty(algo_id))
                    astra_mex_algorithm('delete', algo_id);
                end
                if (~isempty(vol_id))
                    astra_mex_data3d('delete', vol_id);
                end
                astra_mex_data3d('delete', sino_id);
                rethrow(mexc)
            end
            astra_mex_algorithm('delete', algo_id);
            astra_mex_data3d('delete', sino_id, vol_id);

            self.statistics.toc('fproj_raw_ss_single');
        end

        function volumes = backprojectVolumeSS(self, ss_sinograms)
        % BACKPROJECTVOLUMESS  Backprojects every SS sinogram and accumulates
        % the result into the volumes ss_coeffs(ii).indx, weighted by
        % ss_coeffs(ii).coeff. Returns a column cell of single volumes.
            num_ss_geoms = numel(self.ss_geometries);

            num_vols = numel(self.volume_geometries);
            volumes = cell(num_vols, 1);
            for ii = 1:num_vols
                volumes{ii} = zeros(self.volume_geometries{ii}, 'single');
            end

            for ii = 1:num_ss_geoms
                vol = self.backprojectSingleVolumeSS(ss_sinograms{ii}, ii);

                coeffs = self.ss_coeffs(ii);
                for c = 1:numel(coeffs.indx)
                    volumes{coeffs.indx(c)} = volumes{coeffs.indx(c)} + coeffs.coeff(c) * vol;
                end
            end
        end

        function volume = backprojectSingleVolumeSS(self, ss_sinogram, ii)
        % BACKPROJECTSINGLEVOLUMESS  Backprojects 'ss_sinogram' with the ii-th
        % SS geometry into a volume of the first volume geometry. The ASTRA
        % objects are temporary and are freed before returning, also when an
        % ASTRA call throws.
            self.statistics.tic('bproj_raw_ss_single');

            cfg = astra_struct('BP3D_CUDA');
%             cfg.option.VoxelSuperSampling = self.angular_sampling;
%             cfg.option.AngularDeviation = self.angular_deviation;

            vol_geom = self.astra_volume_geometries{1};

            sino_id = astra_mex_data3d('create', '-proj3d', self.astra_ss_geometries{ii}, ss_sinogram);
            vol_id = [];
            algo_id = [];
            try
                vol_id = astra_mex_data3d('create', '-vol', vol_geom, 0);

                cfg.ProjectionDataId = sino_id;
                cfg.ReconstructionDataId = vol_id;

                algo_id = astra_mex_algorithm('create', cfg);
                astra_mex_algorithm('iterate', algo_id);

                volume = astra_mex_data3d('get_single', vol_id);
            catch mexc
                % Free whatever was allocated before the failure, then rethrow.
                if (~isempty(algo_id))
                    astra_mex_algorithm('delete', algo_id);
                end
                if (~isempty(vol_id))
                    astra_mex_data3d('delete', vol_id);
                end
                astra_mex_data3d('delete', sino_id);
                rethrow(mexc)
            end
            astra_mex_algorithm('delete', algo_id);
            astra_mex_data3d('delete', sino_id, vol_id);

            self.statistics.toc('bproj_raw_ss_single');
        end

        function assignVol(self, newVols)
        % ASSIGNVOL  Overwrites the ii-th ASTRA volume object with newVols{ii}.
        % Raises 'PROJECTOR:wrong_argument' when a size differs from the one
        % recorded in volume_geometries.
            for ii = 1:numel(newVols)
                newVol = newVols{ii};
                newSize = [size(newVol, 1), size(newVol, 2), size(newVol, 3)];
                if (~isequal(newSize, self.volume_geometries{ii}))
                    error('PROJECTOR:wrong_argument', ...
                        'Size mismatch between the old and the new volume');
                end

                astra_mex_data3d('store', self.volume_ids{ii}, newVol);
            end
        end

        function updateVol(self, diffVols)
        % UPDATEVOL  Subtracts diffVols{ii} from the ii-th ASTRA volume object.
        % Raises 'PROJECTOR:wrong_argument' when a size differs from the one
        % recorded in volume_geometries.
            for ii = 1:numel(diffVols)
                diffVol = diffVols{ii};
                diffSize = [size(diffVol, 1), size(diffVol, 2), size(diffVol, 3)];
                if (~isequal(diffSize, self.volume_geometries{ii}))
                    error('PROJECTOR:wrong_argument', ...
                        'Size mismatch between the old and the new volume');
                end

                currentVol = astra_mex_data3d('get_single', self.volume_ids{ii});
                astra_mex_data3d('store', self.volume_ids{ii}, currentVol - diffVol);
            end
        end

        function checkGPU(self)
        % CHECKGPU  Throws 'ASTRA:no_GPU' when no GPU device can be queried
        % (the original error is attached as cause) or when the device does
        % not support double precision.
            try
               d = gpuDevice;
               hasGPU = d.SupportsDouble;
            catch mexc1
               mexc = MException('ASTRA:no_GPU', self.messageNoGPU);
               mexc = addCause(mexc, mexc1);
               throw(mexc)
            end

            if (~hasGPU)
               mexc = MException('ASTRA:no_GPU', self.messageNoGPU);
               throw(mexc)
            end
        end

        function testVolumeGeometry(~, referenceVol, newVol)
        % TESTVOLUMEGEOMETRY  Throws 'PROJECTOR:wrong_argument' when the two
        % volumes do not have the same size.
            if (~isequal(size(referenceVol), size(newVol)))
                mexc = MException('PROJECTOR:wrong_argument', ...
                            'Size mismatch between the old and the new volume');
                throw(mexc)
            end
        end
    end
end