Skip to content
Snippets Groups Projects
gtGetODF6DFromGvdm.m 4.05 KiB
Newer Older
function odf_6D = gtGetODF6DFromGvdm(gvdm, gvcs, ospace_grid_points, rspace_grid_points, gvint, mode)
% FUNCTION odf_6D = gtGetODF6DFromGvdm(gvdm, gvcs, ospace_grid_points, rspace_grid_points, gvint, mode)
%
% Accumulates per-voxel orientations (o-space) and positions (r-space) into
% a 6D histogram indexed by [o-space bin (3D), r-space bin (3D)].
%
% INPUT:
%   gvdm               : <n_voxels x 3> array, each row is the Rodrigues
%                        vector of a voxel. A <3 x n_voxels> array is also
%                        accepted and transposed (a 3x3 input is taken as
%                        <n_voxels x 3>).
%   gvcs               : <n_voxels x 3> array of r-space coordinates, one
%                        row per voxel (same layout rules as gvdm).
%   ospace_grid_points : either a 3D cell array of structures with a
%                        'R_vector' field (o-space voxel centres; the voxel
%                        size is taken from elements {2,2,2} - {1,1,1}), or
%                        a <1x9> vector [min_centre(1:3), max_centre(1:3),
%                        voxel_size(1:3)].
%   rspace_grid_points : <1x9> vector [min_centre(1:3), max_centre(1:3),
%                        voxel_size(1:3)] describing the r-space grid.
%   gvint              : (optional) <n_voxels x 1> voxel intensities.
%                        Default (missing or empty): ones.
%   mode               : (optional) 'linear' (default) or 'nearest'.
%                        'nearest' puts each voxel's intensity in the
%                        rounded 6D bin. 'linear' spreads it over 8 x 8 = 64
%                        neighbouring bins using the product of the o-space
%                        and r-space weights from
%                        gtMathsGetInterpolationIndices.
%
% OUTPUT:
%   odf_6D : array of size [size_odf, size_vol], where each size is
%            round((max_corner - min_corner) / voxel_size) of its grid.
%            Contributions with weight <= eps('single'), or falling outside
%            the grid, are discarded.
%
% A progress line with the elapsed time and the total vs. included
% intensity is printed to the console.

    if (~exist('mode', 'var') || isempty(mode))
        mode = 'linear';
    end
    % Validate the mode before starting the progress message, so an error
    % does not leave a dangling half-printed line
    if (~any(strcmp(mode, {'nearest', 'linear'})))
        error('gtGetODF6DFromGvdm:wrong_argument', ...
            'Unsupported interpolation mode: "%s"', mode);
    end

    % Accept both <n x 3> and <3 x n> layouts. A 3x3 input is ambiguous and
    % is kept as the documented <n_voxels x 3>.
    if (size(gvdm, 1) == 3 && size(gvdm, 2) ~= 3)
        gvdm = gvdm';
    end

    if (size(gvcs, 1) == 3 && size(gvcs, 2) ~= 3)
        gvcs = gvcs';
    end

    num_voxels = size(gvdm, 1);

    if (size(gvcs, 1) ~= num_voxels)
        error('gtGetODF6DFromGvdm:wrong_argument', ...
            'gvdm and gvcs must describe the same number of voxels (%d vs %d)', ...
            num_voxels, size(gvcs, 1));
    end

    if (~exist('gvint', 'var') || isempty(gvint))
        gvint = ones(num_voxels, 1);
    else
        gvint = reshape(gvint, [], 1);
    end

    if (numel(gvint) ~= num_voxels)
        error('gtGetODF6DFromGvdm:wrong_argument', ...
            'gvint must have one entry per voxel (%d vs %d)', ...
            numel(gvint), num_voxels);
    end

    c = tic();
    % The two trailing dots are erased by the '\b\b' of the final message
    fprintf('Computing 6D ODF..');

    % O-space grid description
    if (iscell(ospace_grid_points))
        o_voxel_size = ospace_grid_points{2, 2, 2}.R_vector - ospace_grid_points{1, 1, 1}.R_vector;

        min_o_voxel_center = ospace_grid_points{1, 1, 1}.R_vector;
        max_o_voxel_center = ospace_grid_points{end, end, end}.R_vector;
    else
        % We expect a <1x9> vector with <1x6> o-space bb, plus <1x3> the
        % resolution
        o_voxel_size = ospace_grid_points(7:9);

        min_o_voxel_center = ospace_grid_points(1:3);
        max_o_voxel_center = ospace_grid_points(4:6);
    end
    % Force row vectors: the row-replication below and the final
    % [size_odf, size_vol] concatenation both rely on it
    o_voxel_size = reshape(o_voxel_size, 1, []);
    min_o_voxel_center = reshape(min_o_voxel_center, 1, []);
    max_o_voxel_center = reshape(max_o_voxel_center, 1, []);

    half_o_voxel_size = o_voxel_size / 2;
    inv_o_voxel_size = 1 ./ o_voxel_size;
    min_o_corner = min_o_voxel_center - half_o_voxel_size;
    max_o_corner = max_o_voxel_center + half_o_voxel_size;
    size_odf = round((max_o_corner - min_o_corner) .* inv_o_voxel_size);

    % Convert to 1-based (fractional) bin coordinates: centre of the first
    % voxel maps to 1
    gvdm_rescaled = gvdm - min_o_voxel_center(ones(num_voxels, 1), :);
    gvdm_rescaled = gvdm_rescaled .* inv_o_voxel_size(ones(num_voxels, 1), :) + 1;

    % R-space grid: we expect a <1x9> vector with <1x6> r-space bb, plus
    % <1x3> the resolution
    r_voxel_size = reshape(rspace_grid_points(7:9), 1, []);
    half_r_voxel_size = r_voxel_size / 2;
    inv_r_voxel_size = 1 ./ r_voxel_size;
    min_r_voxel_center = reshape(rspace_grid_points(1:3), 1, []);
    max_r_voxel_center = reshape(rspace_grid_points(4:6), 1, []);
    min_r_corner = min_r_voxel_center - half_r_voxel_size;
    max_r_corner = max_r_voxel_center + half_r_voxel_size;
    size_vol = round((max_r_corner - min_r_corner) .* inv_r_voxel_size);

    gvcs_rescaled = gvcs - min_r_voxel_center(ones(num_voxels, 1), :);
    gvcs_rescaled = gvcs_rescaled .* inv_r_voxel_size(ones(num_voxels, 1), :) + 1;

    switch (mode)
        case 'nearest'
            inds_o = round(gvdm_rescaled);
            inds_r = round(gvcs_rescaled);

            inds = [inds_o, inds_r];
            ints = gvint;
        case 'linear'
            % Unit weights here: the actual intensities are applied once,
            % after combining the o-space and r-space weights
            ones_gvint = gtMathsGetSameSizeOnes(gvint);
            [inds8_o, ints8_o] = gtMathsGetInterpolationIndices(gvdm_rescaled, ones_gvint);
            [inds8_r, ints8_r] = gtMathsGetInterpolationIndices(gvcs_rescaled, ones_gvint);

            % inds8 is <(nx8) x 3> while ints8 is <(nx8) x 1>

            % Outer product of the 8 o-space and 8 r-space weights of each
            % voxel: <8 (r) x 8 (o) x n>
            ints8_o = permute(reshape(ints8_o, [], 8), [3 2 1]);
            ints8_r = permute(reshape(ints8_r, [], 8), [2 3 1]);
            ints = ints8_o(ones(8, 1), :, :) .* ints8_r(:, ones(8, 1), :);
            % Repeat each voxel's intensity for its 64 combinations
            gvint64 = reshape(gvint, 1, []);
            gvint64 = reshape(gvint64(ones(64, 1), :), [], 1);
            ints = reshape(ints, [], 1) .* gvint64;

            % Build the matching <(nx64) x 6> subscripts, with the same
            % (r corner fastest, then o corner, then voxel) ordering
            inds8_o = reshape(inds8_o, [], 8, 3);
            inds8_r = reshape(inds8_r, [], 8, 3);

            inds8_o = permute(inds8_o, [3 4 2 1]);
            inds8_r = permute(inds8_r, [3 2 4 1]);
            inds8_o = inds8_o(:, ones(8, 1), :, :);
            inds8_r = inds8_r(:, :, ones(8, 1), :);
            inds8_o = reshape(inds8_o, 3, [])';
            inds8_r = reshape(inds8_r, 3, [])';
            inds = [inds8_o, inds8_r];
    end

    % Drop negligible contributions and anything outside the 6D grid
    valid = ints > eps('single') & all(inds > 0, 2) ...
        & inds(:, 1) <= size_odf(1) & inds(:, 2) <= size_odf(2) & inds(:, 3) <= size_odf(3) ...
        & inds(:, 4) <= size_vol(1) & inds(:, 5) <= size_vol(2) & inds(:, 6) <= size_vol(3);
    inds = inds(valid, :);
    ints = ints(valid, :);
    odf_6D = accumarray(inds, ints, [size_odf, size_vol]);

    fprintf('\b\b (%3.1f s) Done. Total intensity: %g, included: %g\n', ...
        toc(c), sum(gvint), sum(ints));