function [atts_tot, atts, abs_vol] = gtGrainComputeBeamAttenuation(gr, p, det_ind, abs_vol, varargin)
% GTGRAINCOMPUTEBEAMATTENUATION  Estimates beam attenuation for the included
% reflections of a grain, by ray tracing through an absorption volume.
%
%   [atts_tot, atts, abs_vol] = gtGrainComputeBeamAttenuation(gr, p, det_ind, abs_vol, varargin)
%
%   INPUT:
%     gr       : grain structure. Fields read here: proj(det_ind).centerpix,
%                allblobs(det_ind).srot, allblobs(det_ind).dvecsam, and
%                either proj(det_ind).ondet / proj(det_ind).included or
%                gr.ondet / gr.included.
%     p        : parameters structure. Fields read here: labgeo.beamdir and
%                acq(det_ind).dir.
%     det_ind  : detector index <default: 1>
%     abs_vol  : 3D volume sampled along the rays. If missing or empty, it
%                is loaded from the sample's absVolFile and multiplied by
%                the mask in '5_reconstruction/volume_mask.mat'.
%
%   OPTIONAL INPUT (parse by pairs):
%     'verbose'         : show debugging figures for each reflection, and
%                         pause after each of them <default: false>
%     'sampling_factor' : multiplier on the number of sampling points along
%                         each ray <default: 1>
%
%   OUTPUT:
%     atts_tot : num_refl x 1, product of the two columns of atts
%     atts     : num_refl x 2, exp(-0.1 * line integral) for the incoming
%                (column 1) and outgoing (column 2) ray of each reflection
%     abs_vol  : the volume that was used (either passed in or loaded)

    conf = struct('verbose', false, 'sampling_factor', 1);
    conf = parse_pv_pairs(conf, varargin);

    % NOTE: exist('conf.sampling_factor', 'var') can never be true, because
    % exist() only looks up plain variable names, not struct fields. The
    % previous check therefore always overwrote the user's value with 1.
    if (~isfield(conf, 'sampling_factor') || isempty(conf.sampling_factor))
        conf.sampling_factor = 1;
    end

    if (~exist('det_ind', 'var') || isempty(det_ind))
        det_ind = 1;
    end

    % Indices of the reflections that are both on the detector and included
    if (isfield(gr.proj, 'ondet'))
        gr_included = gr.proj(det_ind).ondet(gr.proj(det_ind).included);
    else
        gr_included = gr.ondet(gr.included);
    end

    rot_l2s = gr.allblobs(det_ind).srot(:, :, gr_included);
    % They are reversed because we also use them to find the intersection
    % with the volume border
    beam_dirs_in = - p.labgeo.beamdir * reshape(rot_l2s, 3, []);
    % Rearranging the 1 x (3 * num_refl) result into num_refl x 3
    beam_dirs_in = permute(reshape(beam_dirs_in, 1, 3, []), [3 2 1]);

    beam_dirs_out = gr.allblobs(det_ind).dvecsam(gr_included, :);

    gr_center_pix = gr.proj(det_ind).centerpix;

    if (~exist('abs_vol', 'var') || isempty(abs_vol))
        sample = GtSample.loadFromFile();

        abs_rec = load(sample.absVolFile);

        mask_path = fullfile(p.acq(det_ind).dir, '5_reconstruction', 'volume_mask.mat');
        vol_mask = load(mask_path);

        abs_vol = abs_rec.abs .* single(vol_mask.vol);

        % If the loaded volume carries a rotation, the directions and the
        % grain center are transformed with the same rotation tensor
        if (isfield(abs_rec, 'rot_angle') && ~isempty(abs_rec.rot_angle))
            rot_comp = gtMathsRotationMatrixComp(abs_rec.rot_axis', 'col');
            rot_tensor = gtMathsRotationTensor(abs_rec.rot_angle, rot_comp);

            beam_dirs_in = gtMathsMatrixProduct(beam_dirs_in, rot_tensor);
            beam_dirs_out = gtMathsMatrixProduct(beam_dirs_out, rot_tensor);

            gr_center_pix = gr_center_pix * rot_tensor;
        end
    end
    % Grain center in (1-based) voxel coordinates of abs_vol
    gr_center = gr_center_pix + size(abs_vol) / 2 + 1;

    % ray tracing
    % Signed distances from the grain center to the lower ([1 1 1]) and
    % upper (size(abs_vol)) volume borders, stacked along the 3rd dimension
    dists_borders = cat(3, [1 1 1], size(abs_vol)) - gr_center(1, :, [1 1]);

    dists_renorm_in = compute_renormalized_distances(dists_borders, beam_dirs_in);
    dists_renorm_out = compute_renormalized_distances(dists_borders, beam_dirs_out);

    num_refl = numel(gr_included);
    atts = zeros(num_refl, 2);

    if (conf.verbose)
        % Volume rescaled to [0, 1], and slice grid at the grain center's
        % third coordinate, only used for the debugging figures
        rescaled_abs_vol = (abs_vol - min(abs_vol(:))) / (max(abs_vol(:)) - min(abs_vol(:)));
        rescaled_abs_vol = permute(rescaled_abs_vol, [ 2 1 3]);
        [sx, sy, sz] = ndgrid(1:size(abs_vol, 1), 1:size(abs_vol, 2), gr_center(3));
    end

    for ii = 1:num_refl
        % Sampling points along the incoming ray: from the grain center to
        % the volume border
        samp_factor_in = ceil(dists_renorm_in(ii) * conf.sampling_factor);

        samp_points_in = linspace(0, dists_renorm_in(ii), samp_factor_in);
        samp_points_in = bsxfun(@times, beam_dirs_in(ii, :), samp_points_in');
        samp_points_in = bsxfun(@plus, samp_points_in, gr_center);

        % Sampling points along the outgoing ray
        samp_factor_out = ceil(dists_renorm_out(ii) * conf.sampling_factor);

        samp_points_out = linspace(0, dists_renorm_out(ii), samp_factor_out);
        samp_points_out = bsxfun(@times, beam_dirs_out(ii, :), samp_points_out');
        samp_points_out = bsxfun(@plus, samp_points_out, gr_center);

        % Voxel indices and interpolation weights for the sampling points.
        % Entries with (numerically) zero weight are discarded
        [inds_in, ints_in] = gtMathsGetInterpolationIndices(samp_points_in);
        valid_in = ints_in > eps('single');
        inds_in = inds_in(valid_in, :);
        ints_in = ints_in(valid_in);

        [inds_out, ints_out] = gtMathsGetInterpolationIndices(samp_points_out);
        valid_out = ints_out > eps('single');
        inds_out = inds_out(valid_out, :);
        ints_out = ints_out(valid_out);

        inds_in = sub2ind(size(abs_vol), inds_in(:, 1), inds_in(:, 2), inds_in(:, 3));
        vals_in = abs_vol(inds_in) .* ints_in;

        inds_out = sub2ind(size(abs_vol), inds_out(:, 1), inds_out(:, 2), inds_out(:, 3));
        vals_out = abs_vol(inds_out) .* ints_out;

        % Mean sampled value times the path length, i.e. an approximation
        % of the line integral along each ray
        abs_in = sum(vals_in) / samp_factor_in * dists_renorm_in(ii);
        abs_out = sum(vals_out) / samp_factor_out * dists_renorm_out(ii);

        atts(ii, :) = [abs_in, abs_out];

        if (conf.verbose)
            f = figure();
            ax = axes('parent', f);
            hold(ax, 'on')
            hsl = slice(ax, rescaled_abs_vol, sx, sy, sz);
            hsl.FaceColor = 'interp';
            hsl.EdgeColor = 'none';
            scatter3(ax, samp_points_in(:, 1), samp_points_in(:, 2), samp_points_in(:, 3));
            scatter3(ax, samp_points_out(:, 1), samp_points_out(:, 2), samp_points_out(:, 3));
            show_bbox(ax, cat(1, [1 1 1], size(abs_vol)))
            hold(ax, 'off')
            pause
        end
    end

    % Might need better scaling, for meaningful attenuation!! (0.1 -> mm to cm)
    atts = exp(-atts * 0.1);
    atts_tot = prod(atts, 2);

    if (conf.verbose)
        figure, semilogy(atts)
        figure, semilogy(atts_tot)
    end
end

function [dists_renorm, d_inds_pm, d_inds_dir] = compute_renormalized_distances(dists_borders, beam_dirs)
% COMPUTE_RENORMALIZED_DISTANCES  Distance along each direction to the
% first volume border that gets crossed.
%
%   dists_borders : signed per-axis distances to the two borders, stacked
%                   along the 3rd dimension (the caller passes 1 x 3 x 2)
%   beam_dirs     : one direction per row (the caller passes num_rays x 3)
%
%   dists_renorm  : per row, the smallest over the axes of the per-axis
%                   maxima over the two borders
%   d_inds_pm     : per row and axis, which border (index along the 3rd
%                   dimension) gave the maximum
%   d_inds_dir    : per row, which axis gave the minimum

    % Per-axis travel length needed to reach each border: border distance
    % multiplied by the reciprocal of the direction component
    inv_dirs = 1 ./ beam_dirs;
    travel_per_axis = bsxfun(@times, dists_borders, inv_dirs);

    % Of the two borders on each axis, keep the larger value (this takes
    % the negative one out)
    [travel_fwd, d_inds_pm] = max(travel_per_axis, [], 3);

    % Across the axes, keep the shortest one (this also removes the last
    % remaining NaNs)
    [dists_renorm, d_inds_dir] = min(travel_fwd, [], 2);
end

function show_bbox(ax, borders)
% SHOW_BBOX  Draws the axis-aligned bounding box of the given points.
%
%   ax      : axes in which to draw
%   borders : one point per row (3 columns are used); the box spans the
%             column-wise minimum to the column-wise maximum
%
%   The 8 corners are plotted as filled yellow markers, and the 12 edges
%   are passed to patch() as two-vertex faces.

    % Row 1 holds the lower limits, row 2 the upper limits
    limits = [min(borders, [], 1); max(borders, [], 1)];

    % For each corner, which row of 'limits' to pick for x, y, and z
    % (1 = lower, 2 = upper). The third coordinate varies fastest.
    pick = [ ...
        1 1 1; 1 1 2; 1 2 1; 1 2 2; ...
        2 1 1; 2 1 2; 2 2 1; 2 2 2; ...
        ];
    corners = [ ...
        limits(pick(:, 1), 1), ...
        limits(pick(:, 2), 2), ...
        limits(pick(:, 3), 3) ];

    % Pairs of corner indices that differ in one coordinate only
    edges = [ ...
        1 5; 2 6; 3 7; 4 8; ...
        1 3; 2 4; 5 7; 6 8; ...
        1 2; 3 4; 5 6; 7 8; ...
        ];

    scatter3(ax, corners(:, 1), corners(:, 2), corners(:, 3), 30, 'y', 'filled');
    patch('parent', ax, 'Faces', edges, 'Vertices', corners, 'FaceColor', 'w');
end