From 0ddf15ccec4af64f1494d930664c01c4f06cb8fe Mon Sep 17 00:00:00 2001 From: Nicola Vigano <vigano@yoda.esrf.fr> Date: Tue, 3 May 2016 18:42:12 +0200 Subject: [PATCH] 6D-reconstruction: next round of coefficient fixing and simplifications Signed-off-by: Nicola Vigano <nicola.vigano@esrf.fr> --- zUtil_Cxx/6D_ops/gt6DUpdateDualL1_c.cpp | 22 +-- zUtil_Cxx/6D_ops/gt6DUpdateDualTV_c.cpp | 88 ----------- zUtil_Cxx/6D_ops/gt6DUpdatePrimal_c.cpp | 22 +-- zUtil_Cxx/include/gt6DUpdateDualL1Ops.h | 17 +-- zUtil_Cxx/include/gt6DUpdateDualTVOps.h | 175 ---------------------- zUtil_Cxx/include/gt6DUpdatePrimalOps.h | 73 ++++----- zUtil_Deformation/Gt6DBlobReconstructor.m | 77 +++++----- 7 files changed, 92 insertions(+), 382 deletions(-) delete mode 100644 zUtil_Cxx/6D_ops/gt6DUpdateDualTV_c.cpp delete mode 100644 zUtil_Cxx/include/gt6DUpdateDualTVOps.h diff --git a/zUtil_Cxx/6D_ops/gt6DUpdateDualL1_c.cpp b/zUtil_Cxx/6D_ops/gt6DUpdateDualL1_c.cpp index 7dec63af..6e634415 100644 --- a/zUtil_Cxx/6D_ops/gt6DUpdateDualL1_c.cpp +++ b/zUtil_Cxx/6D_ops/gt6DUpdateDualL1_c.cpp @@ -9,7 +9,7 @@ void mexFunction( int nlhs, mxArray * plhs[], int nrhs, const mxArray * prhs[] ) { - if (nrhs < 4) { + if (nrhs < 3) { mexErrMsgIdAndTxt(GT6D::dual_l1_error_id, "Not enough arguments!"); return; @@ -17,8 +17,7 @@ void mexFunction( int nlhs, mxArray * plhs[], int nrhs, const mxArray * prhs[] ) const mxArray * const dual = prhs[0]; const mxArray * const new_enh_sol = prhs[1]; - const mxArray * const lambda_a = prhs[2]; - const mxArray * const sigma_a = prhs[3]; + const mxArray * const sigma_a = prhs[2]; if (!mxIsCell(dual)) { mexErrMsgIdAndTxt(GT6D::dual_l1_error_id, @@ -30,24 +29,17 @@ void mexFunction( int nlhs, mxArray * plhs[], int nrhs, const mxArray * prhs[] ) "The second argument should be a Cell array (New Enhancement of Solution)"); return; } - if (!mxIsNumeric(lambda_a)) { + if (!mxIsNumeric(sigma_a)) { mexErrMsgIdAndTxt(GT6D::dual_l1_error_id, - "The third argument should be a scalar (lambda)"); + "The third argument should be a scalar (sigma)"); return; } - if (!mxIsDouble(lambda_a) && !mxIsSingle(lambda_a)) { + if (!mxIsDouble(sigma_a) && !mxIsSingle(sigma_a)) { mexErrMsgIdAndTxt(GT6D::dual_l1_error_id, "The third argument should be either a 'double' or a 'single'"); return; } - double lambda; - if (mxIsDouble(lambda_a)) { - lambda = *mxGetPr(lambda_a); - } else { - lambda = *(float *)mxGetData(lambda_a); - } - double sigma; if (mxIsDouble(sigma_a)) { sigma = *mxGetPr(sigma_a); @@ -76,13 +68,13 @@ void mexFunction( int nlhs, mxArray * plhs[], int nrhs, const mxArray * prhs[] ) { case mxDOUBLE_CLASS: { - GT6D::update_dual_l1<double> func(lambda, sigma); + GT6D::update_dual_l1<double> func(sigma); cell_iteration<double, false>(plhs[0], prhs[1], func); break; } case mxSINGLE_CLASS: { - GT6D::update_dual_l1<float> func(lambda, sigma); + GT6D::update_dual_l1<float> func(sigma); cell_iteration<float, false>(plhs[0], prhs[1], func); break; } diff --git a/zUtil_Cxx/6D_ops/gt6DUpdateDualTV_c.cpp b/zUtil_Cxx/6D_ops/gt6DUpdateDualTV_c.cpp deleted file mode 100644 index 81a59a28..00000000 --- a/zUtil_Cxx/6D_ops/gt6DUpdateDualTV_c.cpp +++ /dev/null @@ -1,88 +0,0 @@ -/* - * gt6DUpdateDualL1_c.cpp - * - * Created on: Oct 31, 2014 - * Author: vigano - */ - -#include <gt6DUpdateDualTVOps.h> - -void mexFunction( int nlhs, mxArray * plhs[], int nrhs, const mxArray * prhs[] ) -{ - if (nrhs < 3) { - mexErrMsgIdAndTxt(GT6D::dual_tv_error_id, - "Not enough arguments!"); - return; - } - - const mxArray * const dual = prhs[0]; - const mxArray * const new_enh_sol = prhs[1]; - const mxArray * const lambda_a = prhs[2]; - - if (!mxIsCell(dual)) { - mexErrMsgIdAndTxt(GT6D::dual_tv_error_id, - "The first argument should be a Cell array (dual variable)"); - return; - } - if (!mxIsCell(new_enh_sol)) { - mexErrMsgIdAndTxt(GT6D::dual_tv_error_id, - "The second argument should be a Cell array (New Enhancement of Solution)"); - return; - } - if (!mxIsNumeric(lambda_a)) { - mexErrMsgIdAndTxt(GT6D::dual_tv_error_id, - "The third argument should be a scalar (lambda)"); - return; - } - if (!mxIsDouble(lambda_a) && !mxIsSingle(lambda_a)) { - mexErrMsgIdAndTxt(GT6D::dual_tv_error_id, - "The third argument should be either a 'double' or a 'single'"); - return; - } - - double lambda; - if (mxIsDouble(lambda_a)) { - lambda = *mxGetPr(lambda_a); - } else { - lambda = *(float *)mxGetData(lambda_a); - } - - if (nrhs >= 4) - { - initialize_multithreading(*mxGetPr(prhs[3])); - } - else - { - initialize_multithreading(); - } - - const mxClassID data_type = GT6D::check_dual_tv_arguments(dual, new_enh_sol); - if (data_type == mxUNKNOWN_CLASS) - { - return; - } - - plhs[0] = mxCreateSharedDataCopy(dual); - - switch (data_type) - { - case mxDOUBLE_CLASS: - { - GT6D::update_dual_tv<double> func(lambda); - cell_iteration<double, false>(plhs[0], prhs[1], func); - break; - } - case mxSINGLE_CLASS: - { - GT6D::update_dual_tv<float> func(lambda); - cell_iteration<float, false>(plhs[0], prhs[1], func); - break; - } - default: - { - mexErrMsgIdAndTxt(GT6D::dual_tv_error_id, - "The argument needs to be a Cell array of floating point numbers"); - return; - } - } -} diff --git a/zUtil_Cxx/6D_ops/gt6DUpdatePrimal_c.cpp b/zUtil_Cxx/6D_ops/gt6DUpdatePrimal_c.cpp index 9495e21b..71a5fc23 100644 --- a/zUtil_Cxx/6D_ops/gt6DUpdatePrimal_c.cpp +++ b/zUtil_Cxx/6D_ops/gt6DUpdatePrimal_c.cpp @@ -10,7 +10,7 @@ void mexFunction( int nlhs, mxArray * plhs[], int nrhs, const mxArray * prhs[] ) { - if (nrhs < 5) { + if (nrhs < 4) { mexErrMsgIdAndTxt(GT6D::primal_error_id, "Not enough arguments!"); return; @@ -23,9 +23,8 @@ void mexFunction( int nlhs, mxArray * plhs[], int nrhs, const mxArray * prhs[] ) const mxArray * const curr_sol = prhs[0]; const mxArray * const curr_enh_sol = prhs[1]; - const mxArray * const corr_tomo = prhs[2]; - const mxArray * const corr_l1_tv = prhs[3]; - const mxArray * const tau = prhs[4]; + const mxArray * const corrections = prhs[2]; + const mxArray * const tau = prhs[3]; if (!mxIsNumeric(curr_sol)) { mexErrMsgIdAndTxt(GT6D::primal_error_id, @@ -37,16 +36,11 @@ void mexFunction( int nlhs, mxArray * plhs[], int nrhs, const mxArray * prhs[] ) "The first argument should be an array (old enhanced solution variable)"); return; } - if (!mxIsNumeric(corr_tomo)) { + if (!mxIsNumeric(corrections)) { mexErrMsgIdAndTxt(GT6D::primal_error_id, "The second argument should be an array (Tomo correction)"); return; } - if (!mxIsNumeric(corr_l1_tv)) { - mexErrMsgIdAndTxt(GT6D::primal_error_id, - "The third argument should be an array (l1 correction)"); - return; - } if (!mxIsNumeric(tau)) { mexErrMsgIdAndTxt(GT6D::primal_error_id, "The fourth argument should be a scalar (tau)"); @@ -74,8 +68,8 @@ void mexFunction( int nlhs, mxArray * plhs[], int nrhs, const mxArray * prhs[] ) initialize_multithreading(); } - if (!GT6D::check_primal_arguments(curr_sol, corr_tomo, corr_l1_tv) || - !GT6D::check_primal_arguments(curr_enh_sol, corr_tomo, corr_l1_tv)) + if (!GT6D::check_primal_arguments(curr_sol, corrections) || + !GT6D::check_primal_arguments(curr_enh_sol, corrections)) { return; } @@ -90,13 +84,13 @@ void mexFunction( int nlhs, mxArray * plhs[], int nrhs, const mxArray * prhs[] ) case mxDOUBLE_CLASS: { GT6D::update_primal<double> func(scale); - GT6D::primal_iteration<double>(plhs[0], plhs[1], curr_sol, corr_tomo, corr_l1_tv, func); + GT6D::primal_iteration<double>(plhs[0], plhs[1], curr_sol, corrections, func); break; } case mxSINGLE_CLASS: { GT6D::update_primal<float> func(scale); - GT6D::primal_iteration<float>(plhs[0], plhs[1], curr_sol, corr_tomo, corr_l1_tv, func); + GT6D::primal_iteration<float>(plhs[0], plhs[1], curr_sol, corrections, func); break; } case mxUNKNOWN_CLASS: diff --git a/zUtil_Cxx/include/gt6DUpdateDualL1Ops.h b/zUtil_Cxx/include/gt6DUpdateDualL1Ops.h index 9fe55f19..38ad2642 100644 --- a/zUtil_Cxx/include/gt6DUpdateDualL1Ops.h +++ b/zUtil_Cxx/include/gt6DUpdateDualL1Ops.h @@ -22,23 +22,22 @@ namespace GT6D { public: typedef typename SIMDUnrolling<Type>::vVvf vVvf; - update_dual_l1(const Type & _lambda, const Type & _sigma); + update_dual_l1(const Type & _sigma); const Type operator()(const Type & dual, const Type & new_enh_sol) const throw() { const Type temp = dual + new_enh_sol * sigma; - return lambda * temp / std::max(lambda, this->abs(temp)); + return temp / std::max((Type)1.0, this->abs(temp)); } const vVvf operator()(const vVvf & dual, const vVvf & new_enh_sol) const throw() { const vVvf temp = dual + new_enh_sol * sigma_v; - return lambda_v * temp / this->max(lambda_v, this->abs(temp)); + return temp / this->max(ones_v, this->abs(temp)); } protected: - const Type lambda; - const vVvf lambda_v; + const vVvf ones_v; const Type sigma; const vVvf sigma_v; @@ -54,15 +53,15 @@ namespace GT6D { }; template<> - update_dual_l1<float>::update_dual_l1(const float & _lambda, const float & _sigma) - : lambda(_lambda), lambda_v(Coeff<float>::get(_lambda)) + update_dual_l1<float>::update_dual_l1(const float & _sigma) + : ones_v(Coeff<float>::get(1.0)) , sigma(_sigma), sigma_v(Coeff<float>::get(_sigma)) , abs_mask_v(_mm_castsi128_ps(_mm_set1_epi32(0x7fffffff))) { } template<> - update_dual_l1<double>::update_dual_l1(const double & _lambda, const double & _sigma) - : lambda(_lambda), lambda_v(Coeff<double>::get(_lambda)) + update_dual_l1<double>::update_dual_l1(const double & _sigma) + : ones_v(Coeff<double>::get(1)) , sigma(_sigma), sigma_v(Coeff<double>::get(_sigma)) , abs_mask_v(_mm_castsi128_pd(_mm_set1_epi64x(0x7fffffffffffffffL))) { } diff --git a/zUtil_Cxx/include/gt6DUpdateDualTVOps.h b/zUtil_Cxx/include/gt6DUpdateDualTVOps.h deleted file mode 100644 index b405834b..00000000 --- a/zUtil_Cxx/include/gt6DUpdateDualTVOps.h +++ /dev/null @@ -1,175 +0,0 @@ -/* - * gt6DUpdateDualL1Ops.h - * - * Created on: Oct 31, 2014 - * Author: vigano - */ - -#ifndef ZUTIL_CXX_INCLUDE_GT6DUPDATEDUALL1OPS_H_ -#define ZUTIL_CXX_INCLUDE_GT6DUPDATEDUALL1OPS_H_ - -#include "internal_cell_defs.h" - -#include <algorithm> -#include <cmath> - -namespace GT6D { - - const char * dual_tv_error_id = "C_FUN:gt6DUpdateDualL1:wrong_argument"; - - template<typename Type> - class update_dual_tv { - public: - typedef typename SIMDUnrolling<Type>::vVvf vVvf; - - update_dual_tv(const Type & _lambda); - - const Type - operator()(const Type & dual, const Type & new_enh_sol) const throw() - { - const Type temp = dual + new_enh_sol; - return lambda * temp / std::max(lambda, this->abs(temp)); - } - const vVvf - operator()(const vVvf & dual, const vVvf & new_enh_sol) const throw() - { - const vVvf temp = dual + new_enh_sol; - return lambda_v * temp / this->max(lambda_v, this->abs(temp)); - } - protected: - const Type lambda; - const vVvf lambda_v; - - const vVvf abs_mask_v; - - const Type - abs(const Type & val) const throw(); - const vVvf - abs(const vVvf & val) const throw(); - const vVvf - max(const vVvf & val1, const vVvf & val2) const throw(); - }; - - template<> - update_dual_tv<float>::update_dual_tv(const float & _lambda) - : lambda(_lambda), lambda_v(Coeff<float>::get(_lambda)) - , abs_mask_v(_mm_castsi128_ps(_mm_set1_epi32(0x7fffffff))) - { } - - template<> - update_dual_tv<double>::update_dual_tv(const double & _lambda) - : lambda(_lambda), lambda_v(Coeff<double>::get(_lambda)) - , abs_mask_v(_mm_castsi128_pd(_mm_set1_epi64x(0x7fffffffffffffffL))) - { } - - template<> - inline const update_dual_tv<float>::vVvf - update_dual_tv<float>::abs( - const update_dual_tv<float>::vVvf & val) - const throw() - { - typedef update_dual_tv<float>::vVvf vVvf; - return _mm_and_ps(val, abs_mask_v); - } - - template<> - inline const update_dual_tv<double>::vVvf - update_dual_tv<double>::abs( - const update_dual_tv<double>::vVvf & val) - const throw() - { - typedef update_dual_tv<double>::vVvf vVvf; - return _mm_and_pd(val, abs_mask_v); - } - - template<> - inline const float - update_dual_tv<float>::abs(const float & val) - const throw() - { - return std::abs(val); - } - - template<> - inline const double - update_dual_tv<double>::abs(const double & val) - const throw() - { - return std::fabs(val); - } - - template<> - inline const update_dual_tv<float>::vVvf - update_dual_tv<float>::max( - const update_dual_tv<float>::vVvf & val1, - const update_dual_tv<float>::vVvf & val2) - const throw() - { - typedef update_dual_tv<float>::vVvf vVvf; - return _mm_max_ps(val1, val2); - } - - template<> - inline const update_dual_tv<double>::vVvf - update_dual_tv<double>::max( - const update_dual_tv<double>::vVvf & val1, - const update_dual_tv<double>::vVvf & val2) - const throw() - { - typedef update_dual_tv<double>::vVvf vVvf; - return _mm_max_pd(val1, val2); - } - - - inline mxClassID - check_dual_tv_arguments(const mxArray * const cells_dual, - const mxArray * const cells_new_enh_sol) - { - const mwSize num_cells = mxGetNumberOfElements(cells_dual); - if (num_cells != mxGetNumberOfElements(cells_new_enh_sol)) - { - mexErrMsgIdAndTxt(dual_tv_error_id, - "The Cell arrays should have the same number of elements"); - return mxUNKNOWN_CLASS; - } - - mxClassID expected_type = mxUNKNOWN_CLASS; - - for(mwIndex cell_idx = 0; cell_idx < num_cells; cell_idx++) - { - const mxArray * const cell_dual = mxGetCell(cells_dual, cell_idx); - const mxArray * const cell_new_enh_sol = mxGetCell(cells_new_enh_sol, cell_idx); - - const mxClassID type_dual = mxGetClassID(cell_dual); - if (cell_idx == 0) - { - expected_type = type_dual; - } - - if (type_dual != expected_type) - { - mexErrMsgIdAndTxt(dual_tv_error_id, - "The arguments need to be Cell arrays of coherent floating point types"); - return mxUNKNOWN_CLASS; - } - if (type_dual != mxGetClassID(cell_new_enh_sol)) - { - mexErrMsgIdAndTxt(dual_tv_error_id, - "The arguments need to be Cell arrays of coherent floating point types"); - return mxUNKNOWN_CLASS; - } - - const mwSize num_elems = mxGetNumberOfElements(cell_dual); - if (num_elems != mxGetNumberOfElements(cell_new_enh_sol)) - { - mexErrMsgIdAndTxt(dual_tv_error_id, - "Blobs volumes should have the same number of elements as the Dual"); - return mxUNKNOWN_CLASS; - } - } - - return expected_type; - } -}; - -#endif /* ZUTIL_CXX_INCLUDE_GT6DUPDATEDUALL1OPS_H_ */ diff --git a/zUtil_Cxx/include/gt6DUpdatePrimalOps.h b/zUtil_Cxx/include/gt6DUpdatePrimalOps.h index 6231930c..1959b149 100644 --- a/zUtil_Cxx/include/gt6DUpdatePrimalOps.h +++ b/zUtil_Cxx/include/gt6DUpdatePrimalOps.h @@ -14,14 +14,13 @@ namespace GT6D { -#define APPLY_FUNC_5FOLD_PRIMAL(shift_val) \ +#define APPLY_FUNC_4FOLD_PRIMAL(shift_val) \ { \ vVvf inV11 = access.load(out1 + shift_val * simd_8.shift);\ vVvf inV21 = access.load(out2 + shift_val * simd_8.shift);\ const vVvf inV31 = access.load(in3 + shift_val * simd_8.shift);\ const vVvf inV41 = access.load(in4 + shift_val * simd_8.shift);\ - const vVvf inV51 = access.load(in5 + shift_val * simd_8.shift);\ - func(inV11, inV21, inV31, inV41, inV51);\ + func(inV11, inV21, inV31, inV41);\ access.store(out1 + shift_val * simd_8.shift, inV11);\ access.store(out2 + shift_val * simd_8.shift, inV21);\ } @@ -38,20 +37,18 @@ namespace GT6D { void operator()(Type & new_solution, Type & new_enh_solution, - const Type & old_solution, const Type & correction_tomo, - const Type & correction_l1) const throw() + const Type & old_solution, const Type & correction_tomo) const throw() { - new_solution = non_neg(old_solution + (correction_tomo + correction_l1) * tau, Type()); - new_enh_solution = (new_solution + new_solution) - old_solution; + new_solution = non_neg(old_solution + correction_tomo * tau, Type()); + new_enh_solution = new_solution + (new_solution - old_solution); } void operator()(vVvf & new_solution, vVvf & new_enh_solution, - const vVvf & old_solution, const vVvf & correction_tomo, - const vVvf & correction_l1) const throw() + const vVvf & old_solution, const vVvf & correction_tomo) const throw() { - new_solution = non_neg(old_solution + (correction_tomo + correction_l1) * tau_v, vVvf()); - new_enh_solution = (new_solution + new_solution) - old_solution; + new_solution = non_neg(old_solution + correction_tomo * tau_v, vVvf()); + new_enh_solution = new_solution + (new_solution - old_solution); } protected: const Type tau; @@ -62,8 +59,7 @@ namespace GT6D { inline bool check_primal_arguments(const mxArray * const mat_old_solution, - const mxArray * const mat_correction_tomo, - const mxArray * const mat_correction_l1) + const mxArray * const mat_correction_tomo) { const mxClassID type = mxGetClassID(mat_old_solution); const mwSize num_elems = mxGetNumberOfElements(mat_old_solution); @@ -74,12 +70,6 @@ namespace GT6D { "tomo correction should have the same type as the Primal"); return false; } - if (type != mxGetClassID(mat_correction_l1)) - { - mexErrMsgIdAndTxt(primal_error_id, - "l1 correction blobs should have the same type as the Primal"); - return false; - } if (num_elems != mxGetNumberOfElements(mat_correction_tomo)) { @@ -87,12 +77,6 @@ namespace GT6D { "tomo correction should have the same number of elements as the Primal"); return false; } - if (num_elems != mxGetNumberOfElements(mat_correction_l1)) - { - mexErrMsgIdAndTxt(primal_error_id, - "l1 correction should have the same number of elements as the Primal"); - return false; - } return true; } @@ -103,7 +87,6 @@ namespace GT6D { Type * const __restrict out_data2, const Type * const __restrict in_data3, const Type * const __restrict in_data4, - const Type * const __restrict in_data5, const mwSize & num_elems, Function & func) { typedef typename SIMDUnrolling<Type>::vVvf vVvf; @@ -123,16 +106,15 @@ namespace GT6D { Type * const out2 = out_data2 + elemIdx; const Type * const in3 = in_data3 + elemIdx; const Type * const in4 = in_data4 + elemIdx; - const Type * const in5 = in_data5 + elemIdx; - - APPLY_FUNC_5FOLD_PRIMAL(0); - APPLY_FUNC_5FOLD_PRIMAL(1); - APPLY_FUNC_5FOLD_PRIMAL(2); - APPLY_FUNC_5FOLD_PRIMAL(3); - APPLY_FUNC_5FOLD_PRIMAL(4); - APPLY_FUNC_5FOLD_PRIMAL(5); - APPLY_FUNC_5FOLD_PRIMAL(6); - APPLY_FUNC_5FOLD_PRIMAL(7); + + APPLY_FUNC_4FOLD_PRIMAL(0); + APPLY_FUNC_4FOLD_PRIMAL(1); + APPLY_FUNC_4FOLD_PRIMAL(2); + APPLY_FUNC_4FOLD_PRIMAL(3); + APPLY_FUNC_4FOLD_PRIMAL(4); + APPLY_FUNC_4FOLD_PRIMAL(5); + APPLY_FUNC_4FOLD_PRIMAL(6); + APPLY_FUNC_4FOLD_PRIMAL(7); } #pragma omp for nowait for(mwIndex elemIdx = num_elems_unroll_8; elemIdx < num_elems_unroll_1; elemIdx += simd_1.block) @@ -142,9 +124,8 @@ namespace GT6D { const vVvf inV31 = access.load(&in_data3[elemIdx]); const vVvf inV41 = access.load(&in_data4[elemIdx]); - const vVvf inV51 = access.load(&in_data5[elemIdx]); - func(inV11, inV21, inV31, inV41, inV51); + func(inV11, inV21, inV31, inV41); access.store(&out_data1[elemIdx], inV11); access.store(&out_data2[elemIdx], inV21); @@ -152,7 +133,7 @@ namespace GT6D { #pragma omp for nowait for(mwIndex elemIdx = num_elems_unroll_1; elemIdx < num_elems; elemIdx++) { - func(out_data1[elemIdx], out_data2[elemIdx], in_data3[elemIdx], in_data4[elemIdx], in_data5[elemIdx]); + func(out_data1[elemIdx], out_data2[elemIdx], in_data3[elemIdx], in_data4[elemIdx]); } } @@ -162,21 +143,19 @@ namespace GT6D { mxArray * mat_new_enh_solution, const mxArray * const mat_old_solution, const mxArray * const mat_correction_tomo, - const mxArray * const mat_correction_l1, const Function & func) { Type * const new_solution = (Type *) mxGetData(mat_new_solution); Type * const new_enh_solution = (Type *) mxGetData(mat_new_enh_solution); const Type * const old_solution = (const Type *) mxGetData(mat_old_solution); const Type * const corr_tomo = (const Type *) mxGetData(mat_correction_tomo); - const Type * const corr_l1 = (const Type *) mxGetData(mat_correction_l1); const mwSize num_elems = mxGetNumberOfElements(mat_new_solution); #pragma omp parallel { primal_inner_cycle_sse< Type, const Function, AccessAligned<Type> >( - new_solution, new_enh_solution, old_solution, corr_tomo, corr_l1, num_elems, func); + new_solution, new_enh_solution, old_solution, corr_tomo, num_elems, func); } } @@ -185,8 +164,6 @@ namespace GT6D { mxArray * & mat_new_enh_solution, const mxArray * const mat_old_solution) { - const mwSize zero_dims[2] = {0, 0}; - const mxClassID type = mxGetClassID(mat_old_solution); const mwSize numDims = mxGetNumberOfDimensions(mat_old_solution); const mwSize * dims = mxGetDimensions(mat_old_solution); @@ -194,13 +171,15 @@ namespace GT6D { const mwSize elem_size = mxGetElementSize(mat_old_solution); mat_new_solution = mxCreateSharedDataCopy(mat_old_solution); -// mat_new_solution = mxCreateNumericArray(2, zero_dims, type, mxREAL); -// mxSetDimensions(mat_new_solution, dims, numDims); -// mxSetData(mat_new_solution, mxMalloc(elem_size * num_elems)); +//#if defined(EMLRT_VERSION_INFO) && EMLRT_VERSION_INFO >= 0x2015a +// mat_new_enh_solution = mxCreateUninitNumericArray(numDims, (mwSize *)dims, type, mxREAL); +//#else + const mwSize zero_dims[2] = {0, 0}; mat_new_enh_solution = mxCreateNumericArray(2, zero_dims, type, mxREAL); mxSetDimensions(mat_new_enh_solution, dims, numDims); mxSetData(mat_new_enh_solution, mxMalloc(elem_size * num_elems)); +//#endif } }; diff --git a/zUtil_Deformation/Gt6DBlobReconstructor.m b/zUtil_Deformation/Gt6DBlobReconstructor.m index 43c9288e..c1a352c0 100644 --- a/zUtil_Deformation/Gt6DBlobReconstructor.m +++ b/zUtil_Deformation/Gt6DBlobReconstructor.m @@ -52,6 +52,7 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector self.statistics.add_task_partial('cp_primal_update', 'cp_primal_BS', 'Blobs -> Sinograms'); self.statistics.add_task_partial('cp_primal_update', 'cp_primal_SF', 'Transposed Shape Functions'); self.statistics.add_task_partial('cp_primal_update', 'cp_primal_ODF', 'ODF Correction'); + self.statistics.add_task_partial('cp_primal_update', 'cp_primal_CORR', 'Primal correction computation'); self.statistics.add_task_partial('cp_primal_update', 'cp_primal_APP', 'Primal update application'); self.blobs = blobs; @@ -147,7 +148,7 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector fprintf('Initializing CP_%s weights: ', upper(algo)) c = tic(); - [sigma1, sigma1_1, sigma2, sigma2_1, tau] = self.init_cp_weights(algo); + [sigma1, sigma1_1, sigma2, tau] = self.init_cp_weights(algo); fprintf('Done (%g seconds).\nInitializing CP_%s variables: ', toc(c), upper(algo)) c = tic(); [p, nextEnhancedSolution, q_l1, q_tv, q_odf] = self.init_cp_vars(algo); @@ -183,14 +184,14 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector % Computing update dual TV if (do_tv_update) self.statistics.tic('cp_dual_update_tv'); - [q_tv, mdiv_tv] = self.update_dual_TV(q_tv, nextEnhancedSolution, 1); + [q_tv, mdiv_tv] = self.update_dual_TV(q_tv, nextEnhancedSolution); self.statistics.toc('cp_dual_update_tv'); end % Computing update dual l1 if (do_l1_update) self.statistics.tic('cp_dual_update_l1'); - q_l1 = self.update_dual_l1(q_l1, nextEnhancedSolution, lambda); + q_l1 = self.update_dual_l1(q_l1, nextEnhancedSolution); self.statistics.toc('cp_dual_update_l1'); end @@ -198,25 +199,25 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector switch (upper(algo)) case '6DLS' % q_l1 is just a volume of zeros - compute_update_primal = @(ii)q_l1; + compute_update_primal = @(ii){}; case '6DL1' - compute_update_primal = @(ii)q_l1{ii}; + compute_update_primal = @(ii){lambda * q_l1{ii}}; case '6DTV' - compute_update_primal = @(ii)mdiv_tv; + compute_update_primal = @(ii){mdiv_tv}; case '6DTVL1' - compute_update_primal = @(ii)(q_l1{ii} + mdiv_tv); + compute_update_primal = @(ii){lambda * q_l1{ii}, mdiv_tv}; end % Computing update dual ODF if (do_odf_update) self.statistics.tic('cp_dual_update_ODF'); - q_odf = self.update_dual_ODF(q_odf, nextEnhancedSolution, sigma2, sigma2_1); + q_odf = self.update_dual_ODF(q_odf, nextEnhancedSolution, sigma2); self.statistics.toc('cp_dual_update_ODF'); % [self.ODF, temp_odf, q_odf] base_update_primal = compute_update_primal; - compute_update_primal = @(ii)(base_update_primal(ii) + q_odf(ii)); + compute_update_primal = @(ii)[base_update_primal(ii), {q_odf(ii)}]; end % Computing update primal @@ -233,7 +234,7 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector self.normResiduals(ii / sample_rate) ... = self.compute_functional_value(p, algo, lambda); end - drawnow(); %pause + drawnow(); end fprintf('(%04d) Done in %f seconds.\n', tot_iters, toc(c) ) @@ -299,11 +300,11 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%% Primal and Duals updating functions methods (Access = public) - function [q, mdivq] = update_dual_TV(self, q, new_enh_sol, lambda) + function [q, mdivq] = update_dual_TV(self, q, new_enh_sol) sigma = 1 ./ (2 * numel(new_enh_sol)); c = tic(); - sES = sigma * gtMathsSumCellVolumes(new_enh_sol); + sES = gtMathsSumCellVolumes(new_enh_sol); self.statistics.add_timestamp(toc(c), 'cp_dual_update_tv', 'cp_dual_tv_reduction'); c = tic(); @@ -311,16 +312,18 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector self.statistics.add_timestamp(toc(c), 'cp_dual_update_tv', 'cp_dual_tv_gradient'); if (self.algo_ops_c_functions) - q = gt6DUpdateDualTV_c(q, dsES, lambda, self.num_threads); + % Using the l1 update function because TV is the l1 of the + % gradient + q = gt6DUpdateDualL1_c(q, dsES, sigma, self.num_threads); else for n = 1:numel(q) q{n} = q{n} + sigma * dsES{n}; - q{n} = lambda * q{n} ./ max(lambda, abs(q{n})); + q{n} = q{n} ./ max(1, abs(q{n})); end end c = tic(); - mdivq = -gtMathsDivergence(q); + mdivq = - gtMathsDivergence(q); self.statistics.add_timestamp(toc(c), 'cp_dual_update_tv', 'cp_dual_tv_divergence'); end @@ -341,39 +344,44 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector end end - function q = update_dual_l1(self, q, new_enh_sol, lambda) + function q = update_dual_l1(self, q, new_enh_sol) num_vols = numel(q); - sigma = 1 / num_vols; +% sigma = 1 / num_vols; + sigma = 1; if (self.algo_ops_c_functions) - q = gt6DUpdateDualL1_c(q, new_enh_sol, lambda, sigma, self.num_threads); + q = gt6DUpdateDualL1_c(q, new_enh_sol, sigma, self.num_threads); else for n = 1:num_vols q{n} = q{n} + new_enh_sol{n} * sigma; - q{n} = lambda * q{n} ./ max(lambda, abs(q{n})); + q{n} = q{n} ./ max(1, abs(q{n})); end end end - function q_odf = update_dual_ODF(self, q_odf, new_enh_sol, sigma, sigma_1) + function q_odf = update_dual_ODF(self, q_odf, new_enh_sol, sigma) c = tic(); temp_odf = self.compute_solution_ODF(new_enh_sol); self.statistics.add_timestamp(toc(c), 'cp_dual_update_ODF', 'cp_dual_ODF_compute'); - q_odf = (q_odf + sigma .* (temp_odf - self.ODF)) .* sigma_1; + q_odf = q_odf + sigma .* (temp_odf - self.ODF); end - function [curr_sol, curr_enh_sol, app_time] = update_primal(self, curr_sol, curr_enh_sol, corr_tomo, corr_l1, tau) + function [curr_sol, curr_enh_sol, app_time, corr_time] = update_primal(self, curr_sol, curr_enh_sol, corrections, tau) + c = tic(); + correction = gtMathsSumCellVolumes(corrections); + corr_time = toc(c); + c = tic(); if (self.algo_ops_c_functions) % We actually re-use the same allocated volumes, saving % time and reducing problems with matlab's garbage % collection - [curr_sol, curr_enh_sol] = gt6DUpdatePrimal_c(curr_sol, curr_enh_sol, corr_tomo, corr_l1, tau, self.num_threads); + [curr_sol, curr_enh_sol] = gt6DUpdatePrimal_c(curr_sol, curr_enh_sol, correction, tau, self.num_threads); else theta = 1; - v = curr_sol + (corr_tomo + corr_l1) .* tau; + v = curr_sol + correction .* tau; v(v < 0) = 0; curr_enh_sol = v + theta .* (v - curr_sol); curr_sol = v; @@ -435,6 +443,7 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector function nextEnhancedSolution = compute_bwd_projection(self, nextEnhancedSolution, p_blurr, compute_update_primal, tau) timing_bp = 0; timing_bs = 0; + timing_corr = 0; timing_app = 0; timing_sf_bp = 0; @@ -455,14 +464,14 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector for ii_gpu = 1:chunk_size % Computing the update to apply ii_v = inds(ii_gpu); - up_prim = compute_update_primal(ii_v); + up_prim = [compute_update_primal(ii_v), v(ii_gpu)]; - [self.currentSolution{ii_v}, nextEnhancedSolution{ii_v}, app_time] ... + [self.currentSolution{ii_v}, nextEnhancedSolution{ii_v}, app_time, corr_time] ... = self.update_primal(self.currentSolution{ii_v}, ... - nextEnhancedSolution{ii_v}, v{ii_gpu}, up_prim, tau{ii_v}); + nextEnhancedSolution{ii_v}, up_prim, tau{ii_v}); + timing_corr = timing_corr + corr_time; timing_app = timing_app + app_time; end - clear v end else chunk_end = 0; @@ -475,18 +484,19 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector timing_sf_bp = timing_sf_bp + sf_bp_time; % Computing the update to apply - up_prim = compute_update_primal(n); + up_prim = [compute_update_primal(n), v]; - [self.currentSolution{n}, nextEnhancedSolution{n}, app_time] ... + [self.currentSolution{n}, nextEnhancedSolution{n}, app_time, corr_time] ... = self.update_primal(self.currentSolution{n}, ... - nextEnhancedSolution{n}, v, up_prim, tau{n}); + nextEnhancedSolution{n}, up_prim, tau{n}); + timing_corr = timing_corr + corr_time; timing_app = timing_app + app_time; - clear v end self.statistics.add_timestamp(timing_bp, 'cp_primal_update', 'cp_primal_BP') self.statistics.add_timestamp(timing_bs, 'cp_primal_update', 'cp_primal_BS') self.statistics.add_timestamp(timing_sf_bp, 'cp_primal_update', 'cp_primal_SF') + self.statistics.add_timestamp(timing_corr, 'cp_primal_update', 'cp_primal_CORR') self.statistics.add_timestamp(timing_app, 'cp_primal_update', 'cp_primal_APP') end end @@ -510,7 +520,7 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector fprintf('\b\b (%2.1f s)\n', toc(c)); end - function [sigma1, sigma1_1, sigma2, sigma2_1, tau] = init_cp_weights(self, algo) + function [sigma1, sigma1_1, sigma2, tau] = init_cp_weights(self, algo) sigma1 = cell(size(self.fwd_weights)); for n = 1:numel(sigma1) sigma1{n} = 1 ./ (self.fwd_weights{n} + (self.fwd_weights{n} == 0)); @@ -524,7 +534,6 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector use_ODF = ~isempty(self.ODF); sigma2 = use_ODF / prod(self.volume_geometry); - sigma2_1 = 1 / (1 + sigma2); num_geoms = self.get_number_geometries(); tau = cell(size(self.bwd_weights)); -- GitLab