From 0ddf15ccec4af64f1494d930664c01c4f06cb8fe Mon Sep 17 00:00:00 2001
From: Nicola Vigano <vigano@yoda.esrf.fr>
Date: Tue, 3 May 2016 18:42:12 +0200
Subject: [PATCH] 6D-reconstruction: next round of coefficient fixing and
 simplifications

Signed-off-by: Nicola Vigano <nicola.vigano@esrf.fr>
---
 zUtil_Cxx/6D_ops/gt6DUpdateDualL1_c.cpp   |  22 +--
 zUtil_Cxx/6D_ops/gt6DUpdateDualTV_c.cpp   |  88 -----------
 zUtil_Cxx/6D_ops/gt6DUpdatePrimal_c.cpp   |  22 +--
 zUtil_Cxx/include/gt6DUpdateDualL1Ops.h   |  17 +--
 zUtil_Cxx/include/gt6DUpdateDualTVOps.h   | 175 ----------------------
 zUtil_Cxx/include/gt6DUpdatePrimalOps.h   |  73 ++++-----
 zUtil_Deformation/Gt6DBlobReconstructor.m |  77 +++++-----
 7 files changed, 92 insertions(+), 382 deletions(-)
 delete mode 100644 zUtil_Cxx/6D_ops/gt6DUpdateDualTV_c.cpp
 delete mode 100644 zUtil_Cxx/include/gt6DUpdateDualTVOps.h

diff --git a/zUtil_Cxx/6D_ops/gt6DUpdateDualL1_c.cpp b/zUtil_Cxx/6D_ops/gt6DUpdateDualL1_c.cpp
index 7dec63af..6e634415 100644
--- a/zUtil_Cxx/6D_ops/gt6DUpdateDualL1_c.cpp
+++ b/zUtil_Cxx/6D_ops/gt6DUpdateDualL1_c.cpp
@@ -9,7 +9,7 @@
 
 void mexFunction( int nlhs, mxArray * plhs[], int nrhs, const mxArray * prhs[] )
 {
-  if (nrhs < 4) {
+  if (nrhs < 3) {
     mexErrMsgIdAndTxt(GT6D::dual_l1_error_id,
         "Not enough arguments!");
     return;
@@ -17,8 +17,7 @@ void mexFunction( int nlhs, mxArray * plhs[], int nrhs, const mxArray * prhs[] )
 
   const mxArray * const dual = prhs[0];
   const mxArray * const new_enh_sol = prhs[1];
-  const mxArray * const lambda_a = prhs[2];
-  const mxArray * const sigma_a = prhs[3];
+  const mxArray * const sigma_a = prhs[2];
 
   if (!mxIsCell(dual)) {
     mexErrMsgIdAndTxt(GT6D::dual_l1_error_id,
@@ -30,24 +29,17 @@ void mexFunction( int nlhs, mxArray * plhs[], int nrhs, const mxArray * prhs[] )
         "The second argument should be a Cell array (New Enhancement of Solution)");
     return;
   }
-  if (!mxIsNumeric(lambda_a)) {
+  if (!mxIsNumeric(sigma_a)) {
     mexErrMsgIdAndTxt(GT6D::dual_l1_error_id,
-        "The third argument should be a scalar (lambda)");
+        "The third argument should be a scalar (sigma)");
     return;
   }
-  if (!mxIsDouble(lambda_a) && !mxIsSingle(lambda_a)) {
+  if (!mxIsDouble(sigma_a) && !mxIsSingle(sigma_a)) {
     mexErrMsgIdAndTxt(GT6D::dual_l1_error_id,
         "The third argument should be either a 'double' or a 'single'");
     return;
   }
 
-  double lambda;
-  if (mxIsDouble(lambda_a)) {
-    lambda = *mxGetPr(lambda_a);
-  } else {
-    lambda = *(float *)mxGetData(lambda_a);
-  }
-
   double sigma;
   if (mxIsDouble(sigma_a)) {
     sigma = *mxGetPr(sigma_a);
@@ -76,13 +68,13 @@ void mexFunction( int nlhs, mxArray * plhs[], int nrhs, const mxArray * prhs[] )
   {
     case mxDOUBLE_CLASS:
     {
-      GT6D::update_dual_l1<double> func(lambda, sigma);
+      GT6D::update_dual_l1<double> func(sigma);
       cell_iteration<double, false>(plhs[0], prhs[1], func);
       break;
     }
     case mxSINGLE_CLASS:
     {
-      GT6D::update_dual_l1<float> func(lambda, sigma);
+      GT6D::update_dual_l1<float> func(sigma);
       cell_iteration<float, false>(plhs[0], prhs[1], func);
       break;
     }
diff --git a/zUtil_Cxx/6D_ops/gt6DUpdateDualTV_c.cpp b/zUtil_Cxx/6D_ops/gt6DUpdateDualTV_c.cpp
deleted file mode 100644
index 81a59a28..00000000
--- a/zUtil_Cxx/6D_ops/gt6DUpdateDualTV_c.cpp
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * gt6DUpdateDualL1_c.cpp
- *
- *  Created on: Oct 31, 2014
- *      Author: vigano
- */
-
-#include <gt6DUpdateDualTVOps.h>
-
-void mexFunction( int nlhs, mxArray * plhs[], int nrhs, const mxArray * prhs[] )
-{
-  if (nrhs < 3) {
-    mexErrMsgIdAndTxt(GT6D::dual_tv_error_id,
-        "Not enough arguments!");
-    return;
-  }
-
-  const mxArray * const dual = prhs[0];
-  const mxArray * const new_enh_sol = prhs[1];
-  const mxArray * const lambda_a = prhs[2];
-
-  if (!mxIsCell(dual)) {
-    mexErrMsgIdAndTxt(GT6D::dual_tv_error_id,
-        "The first argument should be a Cell array (dual variable)");
-    return;
-  }
-  if (!mxIsCell(new_enh_sol)) {
-    mexErrMsgIdAndTxt(GT6D::dual_tv_error_id,
-        "The second argument should be a Cell array (New Enhancement of Solution)");
-    return;
-  }
-  if (!mxIsNumeric(lambda_a)) {
-    mexErrMsgIdAndTxt(GT6D::dual_tv_error_id,
-        "The third argument should be a scalar (lambda)");
-    return;
-  }
-  if (!mxIsDouble(lambda_a) && !mxIsSingle(lambda_a)) {
-    mexErrMsgIdAndTxt(GT6D::dual_tv_error_id,
-        "The third argument should be either a 'double' or a 'single'");
-    return;
-  }
-
-  double lambda;
-  if (mxIsDouble(lambda_a)) {
-    lambda = *mxGetPr(lambda_a);
-  } else {
-    lambda = *(float *)mxGetData(lambda_a);
-  }
-
-  if (nrhs >= 4)
-  {
-    initialize_multithreading(*mxGetPr(prhs[3]));
-  }
-  else
-  {
-    initialize_multithreading();
-  }
-
-  const mxClassID data_type = GT6D::check_dual_tv_arguments(dual, new_enh_sol);
-  if (data_type == mxUNKNOWN_CLASS)
-  {
-    return;
-  }
-
-  plhs[0] = mxCreateSharedDataCopy(dual);
-
-  switch (data_type)
-  {
-    case mxDOUBLE_CLASS:
-    {
-      GT6D::update_dual_tv<double> func(lambda);
-      cell_iteration<double, false>(plhs[0], prhs[1], func);
-      break;
-    }
-    case mxSINGLE_CLASS:
-    {
-      GT6D::update_dual_tv<float> func(lambda);
-      cell_iteration<float, false>(plhs[0], prhs[1], func);
-      break;
-    }
-    default:
-    {
-      mexErrMsgIdAndTxt(GT6D::dual_tv_error_id,
-          "The argument needs to be a Cell array of floating point numbers");
-      return;
-    }
-  }
-}
diff --git a/zUtil_Cxx/6D_ops/gt6DUpdatePrimal_c.cpp b/zUtil_Cxx/6D_ops/gt6DUpdatePrimal_c.cpp
index 9495e21b..71a5fc23 100644
--- a/zUtil_Cxx/6D_ops/gt6DUpdatePrimal_c.cpp
+++ b/zUtil_Cxx/6D_ops/gt6DUpdatePrimal_c.cpp
@@ -10,7 +10,7 @@
 
 void mexFunction( int nlhs, mxArray * plhs[], int nrhs, const mxArray * prhs[] )
 {
-  if (nrhs < 5) {
+  if (nrhs < 4) {
     mexErrMsgIdAndTxt(GT6D::primal_error_id,
         "Not enough arguments!");
     return;
@@ -23,9 +23,8 @@ void mexFunction( int nlhs, mxArray * plhs[], int nrhs, const mxArray * prhs[] )
 
   const mxArray * const curr_sol = prhs[0];
   const mxArray * const curr_enh_sol = prhs[1];
-  const mxArray * const corr_tomo = prhs[2];
-  const mxArray * const corr_l1_tv = prhs[3];
-  const mxArray * const tau = prhs[4];
+  const mxArray * const corrections = prhs[2];
+  const mxArray * const tau = prhs[3];
 
   if (!mxIsNumeric(curr_sol)) {
     mexErrMsgIdAndTxt(GT6D::primal_error_id,
@@ -37,16 +36,11 @@ void mexFunction( int nlhs, mxArray * plhs[], int nrhs, const mxArray * prhs[] )
         "The first argument should be an array (old enhanced solution variable)");
     return;
   }
-  if (!mxIsNumeric(corr_tomo)) {
+  if (!mxIsNumeric(corrections)) {
     mexErrMsgIdAndTxt(GT6D::primal_error_id,
         "The second argument should be an array (Tomo correction)");
     return;
   }
-  if (!mxIsNumeric(corr_l1_tv)) {
-    mexErrMsgIdAndTxt(GT6D::primal_error_id,
-        "The third argument should be an array (l1 correction)");
-    return;
-  }
   if (!mxIsNumeric(tau)) {
     mexErrMsgIdAndTxt(GT6D::primal_error_id,
         "The fourth argument should be a scalar (tau)");
@@ -74,8 +68,8 @@ void mexFunction( int nlhs, mxArray * plhs[], int nrhs, const mxArray * prhs[] )
     initialize_multithreading();
   }
 
-  if (!GT6D::check_primal_arguments(curr_sol, corr_tomo, corr_l1_tv) ||
-      !GT6D::check_primal_arguments(curr_enh_sol, corr_tomo, corr_l1_tv))
+  if (!GT6D::check_primal_arguments(curr_sol, corrections) ||
+      !GT6D::check_primal_arguments(curr_enh_sol, corrections))
   {
     return;
   }
@@ -90,13 +84,13 @@ void mexFunction( int nlhs, mxArray * plhs[], int nrhs, const mxArray * prhs[] )
     case mxDOUBLE_CLASS:
     {
       GT6D::update_primal<double> func(scale);
-      GT6D::primal_iteration<double>(plhs[0], plhs[1], curr_sol, corr_tomo, corr_l1_tv, func);
+      GT6D::primal_iteration<double>(plhs[0], plhs[1], curr_sol, corrections, func);
       break;
     }
     case mxSINGLE_CLASS:
     {
       GT6D::update_primal<float> func(scale);
-      GT6D::primal_iteration<float>(plhs[0], plhs[1], curr_sol, corr_tomo, corr_l1_tv, func);
+      GT6D::primal_iteration<float>(plhs[0], plhs[1], curr_sol, corrections, func);
       break;
     }
     case mxUNKNOWN_CLASS:
diff --git a/zUtil_Cxx/include/gt6DUpdateDualL1Ops.h b/zUtil_Cxx/include/gt6DUpdateDualL1Ops.h
index 9fe55f19..38ad2642 100644
--- a/zUtil_Cxx/include/gt6DUpdateDualL1Ops.h
+++ b/zUtil_Cxx/include/gt6DUpdateDualL1Ops.h
@@ -22,23 +22,22 @@ namespace GT6D {
   public:
     typedef typename SIMDUnrolling<Type>::vVvf vVvf;
 
-    update_dual_l1(const Type & _lambda, const Type & _sigma);
+    update_dual_l1(const Type & _sigma);
 
     const Type
     operator()(const Type & dual, const Type & new_enh_sol) const throw()
     {
       const Type temp = dual + new_enh_sol * sigma;
-      return lambda * temp / std::max(lambda, this->abs(temp));
+      return temp / std::max((Type)1.0, this->abs(temp));
     }
     const vVvf
     operator()(const vVvf & dual, const vVvf & new_enh_sol) const throw()
     {
       const vVvf temp = dual + new_enh_sol * sigma_v;
-      return lambda_v * temp / this->max(lambda_v, this->abs(temp));
+      return temp / this->max(ones_v, this->abs(temp));
     }
   protected:
-    const Type lambda;
-    const vVvf lambda_v;
+    const vVvf ones_v;
 
     const Type sigma;
     const vVvf sigma_v;
@@ -54,15 +53,15 @@ namespace GT6D {
   };
 
   template<>
-  update_dual_l1<float>::update_dual_l1(const float & _lambda, const float & _sigma)
-    : lambda(_lambda), lambda_v(Coeff<float>::get(_lambda))
+  update_dual_l1<float>::update_dual_l1(const float & _sigma)
+    : ones_v(Coeff<float>::get(1.0))
     , sigma(_sigma), sigma_v(Coeff<float>::get(_sigma))
     , abs_mask_v(_mm_castsi128_ps(_mm_set1_epi32(0x7fffffff)))
   { }
 
   template<>
-  update_dual_l1<double>::update_dual_l1(const double & _lambda, const double & _sigma)
-    : lambda(_lambda), lambda_v(Coeff<double>::get(_lambda))
+  update_dual_l1<double>::update_dual_l1(const double & _sigma)
+    : ones_v(Coeff<double>::get(1))
     , sigma(_sigma), sigma_v(Coeff<double>::get(_sigma))
     , abs_mask_v(_mm_castsi128_pd(_mm_set1_epi64x(0x7fffffffffffffffL)))
   { }
diff --git a/zUtil_Cxx/include/gt6DUpdateDualTVOps.h b/zUtil_Cxx/include/gt6DUpdateDualTVOps.h
deleted file mode 100644
index b405834b..00000000
--- a/zUtil_Cxx/include/gt6DUpdateDualTVOps.h
+++ /dev/null
@@ -1,175 +0,0 @@
-/*
- * gt6DUpdateDualL1Ops.h
- *
- *  Created on: Oct 31, 2014
- *      Author: vigano
- */
-
-#ifndef ZUTIL_CXX_INCLUDE_GT6DUPDATEDUALL1OPS_H_
-#define ZUTIL_CXX_INCLUDE_GT6DUPDATEDUALL1OPS_H_
-
-#include "internal_cell_defs.h"
-
-#include <algorithm>
-#include <cmath>
-
-namespace GT6D {
-
-  const char * dual_tv_error_id = "C_FUN:gt6DUpdateDualL1:wrong_argument";
-
-  template<typename Type>
-  class update_dual_tv {
-  public:
-    typedef typename SIMDUnrolling<Type>::vVvf vVvf;
-
-    update_dual_tv(const Type & _lambda);
-
-    const Type
-    operator()(const Type & dual, const Type & new_enh_sol) const throw()
-    {
-      const Type temp = dual + new_enh_sol;
-      return lambda * temp / std::max(lambda, this->abs(temp));
-    }
-    const vVvf
-    operator()(const vVvf & dual, const vVvf & new_enh_sol) const throw()
-    {
-      const vVvf temp = dual + new_enh_sol;
-      return lambda_v * temp / this->max(lambda_v, this->abs(temp));
-    }
-  protected:
-    const Type lambda;
-    const vVvf lambda_v;
-
-    const vVvf abs_mask_v;
-
-    const Type
-    abs(const Type & val) const throw();
-    const vVvf
-    abs(const vVvf & val) const throw();
-    const vVvf
-    max(const vVvf & val1, const vVvf & val2) const throw();
-  };
-
-  template<>
-  update_dual_tv<float>::update_dual_tv(const float & _lambda)
-    : lambda(_lambda), lambda_v(Coeff<float>::get(_lambda))
-    , abs_mask_v(_mm_castsi128_ps(_mm_set1_epi32(0x7fffffff)))
-  { }
-
-  template<>
-  update_dual_tv<double>::update_dual_tv(const double & _lambda)
-    : lambda(_lambda), lambda_v(Coeff<double>::get(_lambda))
-    , abs_mask_v(_mm_castsi128_pd(_mm_set1_epi64x(0x7fffffffffffffffL)))
-  { }
-
-  template<>
-  inline const update_dual_tv<float>::vVvf
-  update_dual_tv<float>::abs(
-      const update_dual_tv<float>::vVvf & val)
-  const throw()
-  {
-    typedef update_dual_tv<float>::vVvf vVvf;
-    return _mm_and_ps(val, abs_mask_v);
-  }
-
-  template<>
-  inline const update_dual_tv<double>::vVvf
-  update_dual_tv<double>::abs(
-      const update_dual_tv<double>::vVvf & val)
-  const throw()
-  {
-    typedef update_dual_tv<double>::vVvf vVvf;
-    return _mm_and_pd(val, abs_mask_v);
-  }
-
-  template<>
-  inline const float
-  update_dual_tv<float>::abs(const float & val)
-  const throw()
-  {
-    return std::abs(val);
-  }
-
-  template<>
-  inline const double
-  update_dual_tv<double>::abs(const double & val)
-  const throw()
-  {
-    return std::fabs(val);
-  }
-
-  template<>
-  inline const update_dual_tv<float>::vVvf
-  update_dual_tv<float>::max(
-      const update_dual_tv<float>::vVvf & val1,
-      const update_dual_tv<float>::vVvf & val2)
-  const throw()
-  {
-    typedef update_dual_tv<float>::vVvf vVvf;
-    return _mm_max_ps(val1, val2);
-  }
-
-  template<>
-  inline const update_dual_tv<double>::vVvf
-  update_dual_tv<double>::max(
-      const update_dual_tv<double>::vVvf & val1,
-      const update_dual_tv<double>::vVvf & val2)
-  const throw()
-  {
-    typedef update_dual_tv<double>::vVvf vVvf;
-    return _mm_max_pd(val1, val2);
-  }
-
-
-  inline mxClassID
-  check_dual_tv_arguments(const mxArray * const cells_dual,
-      const mxArray * const cells_new_enh_sol)
-  {
-    const mwSize num_cells = mxGetNumberOfElements(cells_dual);
-    if (num_cells != mxGetNumberOfElements(cells_new_enh_sol))
-    {
-      mexErrMsgIdAndTxt(dual_tv_error_id,
-          "The Cell arrays should have the same number of elements");
-      return mxUNKNOWN_CLASS;
-    }
-
-    mxClassID expected_type = mxUNKNOWN_CLASS;
-
-    for(mwIndex cell_idx = 0; cell_idx < num_cells; cell_idx++)
-    {
-      const mxArray * const cell_dual = mxGetCell(cells_dual, cell_idx);
-      const mxArray * const cell_new_enh_sol = mxGetCell(cells_new_enh_sol, cell_idx);
-
-      const mxClassID type_dual = mxGetClassID(cell_dual);
-      if (cell_idx == 0)
-      {
-        expected_type = type_dual;
-      }
-
-      if (type_dual != expected_type)
-      {
-        mexErrMsgIdAndTxt(dual_tv_error_id,
-            "The arguments need to be Cell arrays of coherent floating point types");
-        return mxUNKNOWN_CLASS;
-      }
-      if (type_dual != mxGetClassID(cell_new_enh_sol))
-      {
-        mexErrMsgIdAndTxt(dual_tv_error_id,
-            "The arguments need to be Cell arrays of coherent floating point types");
-        return mxUNKNOWN_CLASS;
-      }
-
-      const mwSize num_elems = mxGetNumberOfElements(cell_dual);
-      if (num_elems != mxGetNumberOfElements(cell_new_enh_sol))
-      {
-        mexErrMsgIdAndTxt(dual_tv_error_id,
-                "Blobs volumes should have the same number of elements as the Dual");
-        return mxUNKNOWN_CLASS;
-      }
-    }
-
-    return expected_type;
-  }
-};
-
-#endif /* ZUTIL_CXX_INCLUDE_GT6DUPDATEDUALL1OPS_H_ */
diff --git a/zUtil_Cxx/include/gt6DUpdatePrimalOps.h b/zUtil_Cxx/include/gt6DUpdatePrimalOps.h
index 6231930c..1959b149 100644
--- a/zUtil_Cxx/include/gt6DUpdatePrimalOps.h
+++ b/zUtil_Cxx/include/gt6DUpdatePrimalOps.h
@@ -14,14 +14,13 @@
 
 namespace GT6D {
 
-#define APPLY_FUNC_5FOLD_PRIMAL(shift_val) \
+#define APPLY_FUNC_4FOLD_PRIMAL(shift_val) \
 { \
   vVvf inV11 = access.load(out1 + shift_val * simd_8.shift);\
   vVvf inV21 = access.load(out2 + shift_val * simd_8.shift);\
   const vVvf inV31 = access.load(in3 + shift_val * simd_8.shift);\
   const vVvf inV41 = access.load(in4 + shift_val * simd_8.shift);\
-  const vVvf inV51 = access.load(in5 + shift_val * simd_8.shift);\
-  func(inV11, inV21, inV31, inV41, inV51);\
+  func(inV11, inV21, inV31, inV41);\
   access.store(out1 + shift_val * simd_8.shift, inV11);\
   access.store(out2 + shift_val * simd_8.shift, inV21);\
 }
@@ -38,20 +37,18 @@ namespace GT6D {
 
     void
     operator()(Type & new_solution, Type & new_enh_solution,
-        const Type & old_solution, const Type & correction_tomo,
-        const Type & correction_l1) const throw()
+        const Type & old_solution, const Type & correction_tomo) const throw()
     {
 
-      new_solution = non_neg(old_solution + (correction_tomo + correction_l1) * tau, Type());
-      new_enh_solution = (new_solution + new_solution) - old_solution;
+      new_solution = non_neg(old_solution + correction_tomo * tau, Type());
+      new_enh_solution = new_solution + (new_solution - old_solution);
     }
     void
     operator()(vVvf & new_solution, vVvf & new_enh_solution,
-        const vVvf & old_solution, const vVvf & correction_tomo,
-        const vVvf & correction_l1) const throw()
+        const vVvf & old_solution, const vVvf & correction_tomo) const throw()
     {
-      new_solution = non_neg(old_solution + (correction_tomo + correction_l1) * tau_v, vVvf());
-      new_enh_solution = (new_solution + new_solution) - old_solution;
+      new_solution = non_neg(old_solution + correction_tomo * tau_v, vVvf());
+      new_enh_solution = new_solution + (new_solution - old_solution);
     }
   protected:
     const Type tau;
@@ -62,8 +59,7 @@ namespace GT6D {
 
   inline bool
   check_primal_arguments(const mxArray * const mat_old_solution,
-      const mxArray * const mat_correction_tomo,
-      const mxArray * const mat_correction_l1)
+      const mxArray * const mat_correction_tomo)
   {
     const mxClassID type = mxGetClassID(mat_old_solution);
     const mwSize num_elems = mxGetNumberOfElements(mat_old_solution);
@@ -74,12 +70,6 @@ namespace GT6D {
               "tomo correction should have the same type as the Primal");
       return false;
     }
-    if (type != mxGetClassID(mat_correction_l1))
-    {
-      mexErrMsgIdAndTxt(primal_error_id,
-              "l1 correction blobs should have the same type as the Primal");
-      return false;
-    }
 
     if (num_elems != mxGetNumberOfElements(mat_correction_tomo))
     {
@@ -87,12 +77,6 @@ namespace GT6D {
               "tomo correction should have the same number of elements as the Primal");
       return false;
     }
-    if (num_elems != mxGetNumberOfElements(mat_correction_l1))
-    {
-      mexErrMsgIdAndTxt(primal_error_id,
-              "l1 correction should have the same number of elements as the Primal");
-      return false;
-    }
 
     return true;
   }
@@ -103,7 +87,6 @@ namespace GT6D {
       Type * const __restrict out_data2,
       const Type * const __restrict in_data3,
       const Type * const __restrict in_data4,
-      const Type * const __restrict in_data5,
       const mwSize & num_elems, Function & func)
   {
     typedef typename SIMDUnrolling<Type>::vVvf vVvf;
@@ -123,16 +106,15 @@ namespace GT6D {
       Type * const out2 = out_data2 + elemIdx;
       const Type * const in3 = in_data3 + elemIdx;
       const Type * const in4 = in_data4 + elemIdx;
-      const Type * const in5 = in_data5 + elemIdx;
-
-      APPLY_FUNC_5FOLD_PRIMAL(0);
-      APPLY_FUNC_5FOLD_PRIMAL(1);
-      APPLY_FUNC_5FOLD_PRIMAL(2);
-      APPLY_FUNC_5FOLD_PRIMAL(3);
-      APPLY_FUNC_5FOLD_PRIMAL(4);
-      APPLY_FUNC_5FOLD_PRIMAL(5);
-      APPLY_FUNC_5FOLD_PRIMAL(6);
-      APPLY_FUNC_5FOLD_PRIMAL(7);
+
+      APPLY_FUNC_4FOLD_PRIMAL(0);
+      APPLY_FUNC_4FOLD_PRIMAL(1);
+      APPLY_FUNC_4FOLD_PRIMAL(2);
+      APPLY_FUNC_4FOLD_PRIMAL(3);
+      APPLY_FUNC_4FOLD_PRIMAL(4);
+      APPLY_FUNC_4FOLD_PRIMAL(5);
+      APPLY_FUNC_4FOLD_PRIMAL(6);
+      APPLY_FUNC_4FOLD_PRIMAL(7);
     }
   #pragma omp for nowait
     for(mwIndex elemIdx = num_elems_unroll_8; elemIdx < num_elems_unroll_1; elemIdx += simd_1.block)
@@ -142,9 +124,8 @@ namespace GT6D {
 
       const vVvf inV31 = access.load(&in_data3[elemIdx]);
       const vVvf inV41 = access.load(&in_data4[elemIdx]);
-      const vVvf inV51 = access.load(&in_data5[elemIdx]);
 
-      func(inV11, inV21, inV31, inV41, inV51);
+      func(inV11, inV21, inV31, inV41);
 
       access.store(&out_data1[elemIdx], inV11);
       access.store(&out_data2[elemIdx], inV21);
@@ -152,7 +133,7 @@ namespace GT6D {
   #pragma omp for nowait
     for(mwIndex elemIdx = num_elems_unroll_1; elemIdx < num_elems; elemIdx++)
     {
-       func(out_data1[elemIdx], out_data2[elemIdx], in_data3[elemIdx], in_data4[elemIdx], in_data5[elemIdx]);
+       func(out_data1[elemIdx], out_data2[elemIdx], in_data3[elemIdx], in_data4[elemIdx]);
     }
   }
 
@@ -162,21 +143,19 @@ namespace GT6D {
       mxArray * mat_new_enh_solution,
       const mxArray * const mat_old_solution,
       const mxArray * const mat_correction_tomo,
-      const mxArray * const mat_correction_l1,
       const Function & func)
   {
     Type * const new_solution = (Type *) mxGetData(mat_new_solution);
     Type * const new_enh_solution = (Type *) mxGetData(mat_new_enh_solution);
     const Type * const old_solution = (const Type *) mxGetData(mat_old_solution);
     const Type * const corr_tomo = (const Type *) mxGetData(mat_correction_tomo);
-    const Type * const corr_l1 = (const Type *) mxGetData(mat_correction_l1);
 
     const mwSize num_elems = mxGetNumberOfElements(mat_new_solution);
 
   #pragma omp parallel
     {
       primal_inner_cycle_sse< Type, const Function, AccessAligned<Type> >(
-          new_solution, new_enh_solution, old_solution, corr_tomo, corr_l1, num_elems, func);
+          new_solution, new_enh_solution, old_solution, corr_tomo, num_elems, func);
     }
   }
 
@@ -185,8 +164,6 @@ namespace GT6D {
       mxArray * & mat_new_enh_solution,
       const mxArray * const mat_old_solution)
   {
-    const mwSize zero_dims[2] = {0, 0};
-
     const mxClassID type = mxGetClassID(mat_old_solution);
     const mwSize numDims = mxGetNumberOfDimensions(mat_old_solution);
     const mwSize * dims = mxGetDimensions(mat_old_solution);
@@ -194,13 +171,15 @@ namespace GT6D {
     const mwSize elem_size = mxGetElementSize(mat_old_solution);
 
     mat_new_solution = mxCreateSharedDataCopy(mat_old_solution);
-//    mat_new_solution = mxCreateNumericArray(2, zero_dims, type, mxREAL);
-//    mxSetDimensions(mat_new_solution, dims, numDims);
-//    mxSetData(mat_new_solution, mxMalloc(elem_size * num_elems));
 
+//#if defined(EMLRT_VERSION_INFO) && EMLRT_VERSION_INFO >= 0x2015a
+//    mat_new_enh_solution = mxCreateUninitNumericArray(numDims, (mwSize *)dims, type, mxREAL);
+//#else
+    const mwSize zero_dims[2] = {0, 0};
     mat_new_enh_solution = mxCreateNumericArray(2, zero_dims, type, mxREAL);
     mxSetDimensions(mat_new_enh_solution, dims, numDims);
     mxSetData(mat_new_enh_solution, mxMalloc(elem_size * num_elems));
+//#endif
   }
 };
 
diff --git a/zUtil_Deformation/Gt6DBlobReconstructor.m b/zUtil_Deformation/Gt6DBlobReconstructor.m
index 43c9288e..c1a352c0 100644
--- a/zUtil_Deformation/Gt6DBlobReconstructor.m
+++ b/zUtil_Deformation/Gt6DBlobReconstructor.m
@@ -52,6 +52,7 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector
             self.statistics.add_task_partial('cp_primal_update', 'cp_primal_BS', 'Blobs -> Sinograms');
             self.statistics.add_task_partial('cp_primal_update', 'cp_primal_SF', 'Transposed Shape Functions');
             self.statistics.add_task_partial('cp_primal_update', 'cp_primal_ODF', 'ODF Correction');
+            self.statistics.add_task_partial('cp_primal_update', 'cp_primal_CORR', 'Primal correction computation');
             self.statistics.add_task_partial('cp_primal_update', 'cp_primal_APP', 'Primal update application');
 
             self.blobs = blobs;
@@ -147,7 +148,7 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector
 
             fprintf('Initializing CP_%s weights: ', upper(algo))
             c = tic();
-            [sigma1, sigma1_1, sigma2, sigma2_1, tau] = self.init_cp_weights(algo);
+            [sigma1, sigma1_1, sigma2, tau] = self.init_cp_weights(algo);
             fprintf('Done (%g seconds).\nInitializing CP_%s variables: ', toc(c), upper(algo))
             c = tic();
             [p, nextEnhancedSolution, q_l1, q_tv, q_odf] = self.init_cp_vars(algo);
@@ -183,14 +184,14 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector
                 % Computing update dual TV
                 if (do_tv_update)
                     self.statistics.tic('cp_dual_update_tv');
-                    [q_tv, mdiv_tv] = self.update_dual_TV(q_tv, nextEnhancedSolution, 1);
+                    [q_tv, mdiv_tv] = self.update_dual_TV(q_tv, nextEnhancedSolution);
                     self.statistics.toc('cp_dual_update_tv');
                 end
 
                 % Computing update dual l1
                 if (do_l1_update)
                     self.statistics.tic('cp_dual_update_l1');
-                    q_l1 = self.update_dual_l1(q_l1, nextEnhancedSolution, lambda);
+                    q_l1 = self.update_dual_l1(q_l1, nextEnhancedSolution);
                     self.statistics.toc('cp_dual_update_l1');
                 end
 
@@ -198,25 +199,25 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector
                 switch (upper(algo))
                     case '6DLS'
                         % q_l1 is just a volume of zeros
-                        compute_update_primal = @(ii)q_l1;
+                        compute_update_primal = @(ii){};
                     case '6DL1'
-                        compute_update_primal = @(ii)q_l1{ii};
+                        compute_update_primal = @(ii){lambda * q_l1{ii}};
                     case '6DTV'
-                        compute_update_primal = @(ii)mdiv_tv;
+                        compute_update_primal = @(ii){mdiv_tv};
                     case '6DTVL1'
-                        compute_update_primal = @(ii)(q_l1{ii} + mdiv_tv);
+                        compute_update_primal = @(ii){lambda * q_l1{ii}, mdiv_tv};
                 end
 
                 % Computing update dual ODF
                 if (do_odf_update)
                     self.statistics.tic('cp_dual_update_ODF');
-                    q_odf = self.update_dual_ODF(q_odf, nextEnhancedSolution, sigma2, sigma2_1);
+                    q_odf = self.update_dual_ODF(q_odf, nextEnhancedSolution, sigma2);
                     self.statistics.toc('cp_dual_update_ODF');
 
 %                     [self.ODF, temp_odf, q_odf]
 
                     base_update_primal = compute_update_primal;
-                    compute_update_primal = @(ii)(base_update_primal(ii) + q_odf(ii));
+                    compute_update_primal = @(ii)[base_update_primal(ii), {q_odf(ii)}];
                 end
 
                 % Computing update primal
@@ -233,7 +234,7 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector
                     self.normResiduals(ii / sample_rate) ...
                         = self.compute_functional_value(p, algo, lambda);
                 end
-                drawnow(); %pause
+                drawnow();
             end
             fprintf('(%04d) Done in %f seconds.\n', tot_iters, toc(c) )
 
@@ -299,11 +300,11 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector
     %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
     %%% Primal and Duals updating functions
     methods (Access = public)
-        function [q, mdivq] = update_dual_TV(self, q, new_enh_sol, lambda)
+        function [q, mdivq] = update_dual_TV(self, q, new_enh_sol)
             sigma = 1 ./  (2 * numel(new_enh_sol));
 
             c = tic();
-            sES = sigma * gtMathsSumCellVolumes(new_enh_sol);
+            sES = gtMathsSumCellVolumes(new_enh_sol);
             self.statistics.add_timestamp(toc(c), 'cp_dual_update_tv', 'cp_dual_tv_reduction');
 
             c = tic();
@@ -311,16 +312,18 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector
             self.statistics.add_timestamp(toc(c), 'cp_dual_update_tv', 'cp_dual_tv_gradient');
 
             if (self.algo_ops_c_functions)
-                q = gt6DUpdateDualTV_c(q, dsES, lambda, self.num_threads);
+                % Reusing the l1 dual update kernel: TV is the l1 norm of
+                % the gradient, so the same element-wise clamp applies
+                q = gt6DUpdateDualL1_c(q, dsES, sigma, self.num_threads);
             else
                 for n = 1:numel(q)
                     q{n} = q{n} + sigma * dsES{n};
-                    q{n} = lambda * q{n} ./ max(lambda, abs(q{n}));
+                    q{n} = q{n} ./ max(1, abs(q{n}));
                 end
             end
 
             c = tic();
-            mdivq = -gtMathsDivergence(q);
+            mdivq = - gtMathsDivergence(q);
             self.statistics.add_timestamp(toc(c), 'cp_dual_update_tv', 'cp_dual_tv_divergence');
         end
 
@@ -341,39 +344,44 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector
             end
         end
 
-        function q = update_dual_l1(self, q, new_enh_sol, lambda)
+        function q = update_dual_l1(self, q, new_enh_sol)
             num_vols = numel(q);
-            sigma = 1 / num_vols;
+%             sigma = 1 / num_vols;
+            sigma = 1;
 
             if (self.algo_ops_c_functions)
-                q = gt6DUpdateDualL1_c(q, new_enh_sol, lambda, sigma, self.num_threads);
+                q = gt6DUpdateDualL1_c(q, new_enh_sol, sigma, self.num_threads);
             else
                 for n = 1:num_vols
                     q{n} = q{n} + new_enh_sol{n} * sigma;
-                    q{n} = lambda * q{n} ./ max(lambda, abs(q{n}));
+                    q{n} = q{n} ./ max(1, abs(q{n}));
                 end
             end
         end
 
-        function q_odf = update_dual_ODF(self, q_odf, new_enh_sol, sigma, sigma_1)
+        function q_odf = update_dual_ODF(self, q_odf, new_enh_sol, sigma)
             c = tic();
             temp_odf = self.compute_solution_ODF(new_enh_sol);
             self.statistics.add_timestamp(toc(c), 'cp_dual_update_ODF', 'cp_dual_ODF_compute');
 
-            q_odf = (q_odf + sigma .* (temp_odf - self.ODF)) .* sigma_1;
+            q_odf = q_odf + sigma .* (temp_odf - self.ODF);
         end
 
-        function [curr_sol, curr_enh_sol, app_time] = update_primal(self, curr_sol, curr_enh_sol, corr_tomo, corr_l1, tau)
+        function [curr_sol, curr_enh_sol, app_time, corr_time] = update_primal(self, curr_sol, curr_enh_sol, corrections, tau)
+            c = tic();
+            correction = gtMathsSumCellVolumes(corrections);
+            corr_time = toc(c);
+
             c = tic();
             if (self.algo_ops_c_functions)
                 % We actually re-use the same allocated volumes, saving
                 % time and reducing problems with matlab's garbage
                 % collection
-                [curr_sol, curr_enh_sol] = gt6DUpdatePrimal_c(curr_sol, curr_enh_sol, corr_tomo, corr_l1, tau, self.num_threads);
+                [curr_sol, curr_enh_sol] = gt6DUpdatePrimal_c(curr_sol, curr_enh_sol, correction, tau, self.num_threads);
             else
                 theta = 1;
 
-                v = curr_sol + (corr_tomo + corr_l1) .* tau;
+                v = curr_sol + correction .* tau;
                 v(v < 0) = 0;
                 curr_enh_sol = v + theta .* (v - curr_sol);
                 curr_sol = v;
@@ -435,6 +443,7 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector
         function nextEnhancedSolution = compute_bwd_projection(self, nextEnhancedSolution, p_blurr, compute_update_primal, tau)
             timing_bp = 0;
             timing_bs = 0;
+            timing_corr = 0;
             timing_app = 0;
             timing_sf_bp = 0;
 
@@ -455,14 +464,14 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector
                     for ii_gpu = 1:chunk_size
                         % Computing the update to apply
                         ii_v = inds(ii_gpu);
-                        up_prim = compute_update_primal(ii_v);
+                        up_prim = [compute_update_primal(ii_v), v(ii_gpu)];
 
-                        [self.currentSolution{ii_v}, nextEnhancedSolution{ii_v}, app_time] ...
+                        [self.currentSolution{ii_v}, nextEnhancedSolution{ii_v}, app_time, corr_time] ...
                             = self.update_primal(self.currentSolution{ii_v}, ...
-                            nextEnhancedSolution{ii_v}, v{ii_gpu}, up_prim, tau{ii_v});
+                            nextEnhancedSolution{ii_v}, up_prim, tau{ii_v});
+                        timing_corr = timing_corr + corr_time;
                         timing_app = timing_app + app_time;
                     end
-                    clear v
                 end
             else
                 chunk_end = 0;
@@ -475,18 +484,19 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector
                 timing_sf_bp = timing_sf_bp + sf_bp_time;
 
                 % Computing the update to apply
-                up_prim = compute_update_primal(n);
+                up_prim = [compute_update_primal(n), v];
 
-                [self.currentSolution{n}, nextEnhancedSolution{n}, app_time] ...
+                [self.currentSolution{n}, nextEnhancedSolution{n}, app_time, corr_time] ...
                     = self.update_primal(self.currentSolution{n}, ...
-                    nextEnhancedSolution{n}, v, up_prim, tau{n});
+                    nextEnhancedSolution{n}, up_prim, tau{n});
+                timing_corr = timing_corr + corr_time;
                 timing_app = timing_app + app_time;
-                clear v
             end
 
             self.statistics.add_timestamp(timing_bp, 'cp_primal_update', 'cp_primal_BP')
             self.statistics.add_timestamp(timing_bs, 'cp_primal_update', 'cp_primal_BS')
             self.statistics.add_timestamp(timing_sf_bp, 'cp_primal_update', 'cp_primal_SF')
+            self.statistics.add_timestamp(timing_corr, 'cp_primal_update', 'cp_primal_CORR')
             self.statistics.add_timestamp(timing_app, 'cp_primal_update', 'cp_primal_APP')
         end
     end
@@ -510,7 +520,7 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector
             fprintf('\b\b (%2.1f s)\n', toc(c));
         end
 
-        function [sigma1, sigma1_1, sigma2, sigma2_1, tau] = init_cp_weights(self, algo)
+        function [sigma1, sigma1_1, sigma2, tau] = init_cp_weights(self, algo)
             sigma1 = cell(size(self.fwd_weights));
             for n = 1:numel(sigma1)
                 sigma1{n} = 1 ./ (self.fwd_weights{n} + (self.fwd_weights{n} == 0));
@@ -524,7 +534,6 @@ classdef Gt6DBlobReconstructor < Gt6DVolumeToBlobProjector
             use_ODF = ~isempty(self.ODF);
 
             sigma2 = use_ODF / prod(self.volume_geometry);
-            sigma2_1 = 1 / (1 + sigma2);
 
             num_geoms = self.get_number_geometries();
             tau = cell(size(self.bwd_weights));
-- 
GitLab