From 2800b4442ed05277dabd6a73e5b6c57e043eabd1 Mon Sep 17 00:00:00 2001
From: Nicola Vigano <nicola.vigano@esrf.fr>
Date: Thu, 30 Oct 2014 19:02:43 +0100
Subject: [PATCH] 6D-reconstructor: added functions that implement the update
 of some of the variables in the 6D algorithm

Signed-off-by: Nicola Vigano <nicola.vigano@esrf.fr>
---
 zUtil_Cxx/6D_ops/gt6DUpdateDualDetector_c.cpp |  81 ++++++
 zUtil_Cxx/6D_ops/gt6DUpdatePrimal_c.cpp       |  92 +++++++
 zUtil_Cxx/include/gt6DUpdateDualOps.h         | 243 ++++++++++++++++++
 zUtil_Cxx/include/gt6DUpdatePrimalOps.h       | 202 +++++++++++++++
 4 files changed, 618 insertions(+)
 create mode 100644 zUtil_Cxx/6D_ops/gt6DUpdateDualDetector_c.cpp
 create mode 100644 zUtil_Cxx/6D_ops/gt6DUpdatePrimal_c.cpp
 create mode 100644 zUtil_Cxx/include/gt6DUpdateDualOps.h
 create mode 100644 zUtil_Cxx/include/gt6DUpdatePrimalOps.h

diff --git a/zUtil_Cxx/6D_ops/gt6DUpdateDualDetector_c.cpp b/zUtil_Cxx/6D_ops/gt6DUpdateDualDetector_c.cpp
new file mode 100644
index 00000000..8c2a7599
--- /dev/null
+++ b/zUtil_Cxx/6D_ops/gt6DUpdateDualDetector_c.cpp
@@ -0,0 +1,81 @@
+/*
+ * gt6DUpdateDualDetector_c.cpp
+ *
+ *  Created on: Oct 29, 2014
+ *
+ * Nicola Vigano', 2014, INSA Lyon / ESRF ID11, vigano@esrf.eu
+ */
+
+#include <gt6DUpdateDualOps.h>
+/** MEX entry point: validates the five cell-array inputs, allocates plhs[0] and runs GT6D::dual_detector_iteration on double or single data. */
+void mexFunction( int nlhs, mxArray * plhs[], int nrhs, const mxArray * prhs[] )
+{
+  if (nrhs != 5) {
+    mexErrMsgIdAndTxt(GT6D::dual_detector_error_id,
+        "Wrong number of arguments: 5 are expected!");
+    return;
+  }
+  if (!mxIsCell(prhs[0])) {
+    mexErrMsgIdAndTxt(GT6D::dual_detector_error_id,
+        "The first argument should be a Cell array (dual variable)");
+    return;
+  }
+  if (!mxIsCell(prhs[1])) {
+    mexErrMsgIdAndTxt(GT6D::dual_detector_error_id,
+        "The second argument should be a Cell array (computed blobs)");
+    return;
+  }
+  if (!mxIsCell(prhs[2])) {
+    mexErrMsgIdAndTxt(GT6D::dual_detector_error_id,
+        "The third argument should be a Cell array (real blobs)");
+    return;
+  }
+  if (!mxIsCell(prhs[3])) {
+    mexErrMsgIdAndTxt(GT6D::dual_detector_error_id,
+        "The fourth argument should be a Cell array (sigma1)");
+    return;
+  }
+  if (!mxIsCell(prhs[4])) {
+    mexErrMsgIdAndTxt(GT6D::dual_detector_error_id,
+        "The fifth argument should be a Cell array (sigma1_1)");
+    return;
+  }
+  // NOTE(review): prhs[1]/prhs[2] are labelled computed/real blobs above, but land in parameters named blobs/computed_blobs - confirm the order.
+  if (!GT6D::check_dual_detector_arguments(prhs[0], prhs[1], prhs[2], prhs[3], prhs[4]))
+  {
+    return;
+  }
+  // The output is prepared from the dual variable by cell_allocate (helper defined outside this file).
+  cell_allocate(plhs[0], prhs[0]);
+  // Select the template instantiation from the class that getMxClassOfCell reports for the dual variable.
+  const mxClassID firstClass = getMxClassOfCell(prhs[0]);
+  switch (firstClass)
+  {
+    case mxDOUBLE_CLASS:
+    {
+      GT6D::update_dual_detector<double> func;
+      GT6D::dual_detector_iteration<double>(plhs[0], prhs[0], prhs[1], prhs[2], prhs[3], prhs[4], func);
+      break;
+    }
+    case mxSINGLE_CLASS:
+    {
+      GT6D::update_dual_detector<float> func;
+      GT6D::dual_detector_iteration<float>(plhs[0], prhs[0], prhs[1], prhs[2], prhs[3], prhs[4], func);
+      break;
+    }
+    case mxUNKNOWN_CLASS:
+    {
+      mexErrMsgIdAndTxt(GT6D::dual_detector_error_id,
+          "The argument needs to be a non empty Cell array of coherent floating point types");
+      return;
+    }
+    default:
+    {
+      mexErrMsgIdAndTxt(GT6D::dual_detector_error_id,
+          "The argument needs to be a Cell array of floating point numbers");
+      return;
+    }
+  }
+}
+
+
diff --git a/zUtil_Cxx/6D_ops/gt6DUpdatePrimal_c.cpp b/zUtil_Cxx/6D_ops/gt6DUpdatePrimal_c.cpp
new file mode 100644
index 00000000..5c622cb6
--- /dev/null
+++ b/zUtil_Cxx/6D_ops/gt6DUpdatePrimal_c.cpp
@@ -0,0 +1,92 @@
+/*
+ * gt6DUpdatePrimal_c.cpp
+ *
+ *  Created on: Oct 29, 2014
+ *
+ * Nicola Vigano', 2014, INSA Lyon / ESRF ID11, vigano@esrf.eu
+ */
+
+#include "gt6DUpdatePrimalOps.h"
+/** MEX entry point: checks the 4 inputs and 2 outputs, allocates the outputs and runs GT6D::primal_iteration on double or single data. */
+void mexFunction( int nlhs, mxArray * plhs[], int nrhs, const mxArray * prhs[] )
+{
+  if (nrhs != 4) {
+    mexErrMsgIdAndTxt(GT6D::primal_error_id,
+        "Wrong number of arguments: 4 are expected!");
+    return;
+  }
+  if (nlhs != 2) {
+    mexErrMsgIdAndTxt(GT6D::primal_error_id,
+        "Wrong number of outputs: 2 are expected!");
+    return;
+  }
+  if (!mxIsNumeric(prhs[0])) {
+    mexErrMsgIdAndTxt(GT6D::primal_error_id,
+        "The first argument should be an array (old solution variable)");
+    return;
+  }
+  if (!mxIsNumeric(prhs[1])) {
+    mexErrMsgIdAndTxt(GT6D::primal_error_id,
+        "The second argument should be an array (Tomo correction)");
+    return;
+  }
+  if (!mxIsNumeric(prhs[2])) {
+    mexErrMsgIdAndTxt(GT6D::primal_error_id,
+        "The third argument should be an array (l1 correction)");
+    return;
+  }
+  if (!mxIsNumeric(prhs[3]) || mxIsEmpty(prhs[3])) {
+    mexErrMsgIdAndTxt(GT6D::primal_error_id,
+        "The fourth argument should be a scalar (tau)");
+    return;
+  }
+  // tau is non-empty here, so reading its first element below is safe; it is kept as a double whatever its storage class.
+  double scale;
+  if (mxIsDouble(prhs[3])) {
+    scale = *mxGetPr(prhs[3]);
+  } else if (mxIsSingle(prhs[3])) {
+    scale = *(float *)mxGetData(prhs[3]);
+  } else {
+    mexErrMsgIdAndTxt(GT6D::primal_error_id,
+        "The tau needs to be of a floating point type");
+    return;
+  }
+  // The three arrays must share class and number of elements; the checker raises the MEX error itself.
+  if (!GT6D::check_primal_arguments(prhs[0], prhs[1], prhs[2]))
+  {
+    return;
+  }
+  // Both outputs take the class and dimensions of the old solution.
+  GT6D::primal_allocate_output(plhs[0], plhs[1], prhs[0]);
+  // Select the template instantiation from the class of the old solution.
+  const mxClassID firstClass = mxGetClassID(prhs[0]);
+  switch (firstClass)
+  {
+    case mxDOUBLE_CLASS:
+    {
+      GT6D::update_primal<double> func(scale);
+      GT6D::primal_iteration<double>(plhs[0], plhs[1], prhs[0], prhs[1], prhs[2], func);
+      break;
+    }
+    case mxSINGLE_CLASS:
+    {
+      GT6D::update_primal<float> func(scale);
+      GT6D::primal_iteration<float>(plhs[0], plhs[1], prhs[0], prhs[1], prhs[2], func);
+      break;
+    }
+    case mxUNKNOWN_CLASS:
+    {
+      mexErrMsgIdAndTxt(GT6D::primal_error_id,
+          "The first argument is of an unknown class: an array of floating point numbers is expected");
+      return;
+    }
+    default:
+    {
+      mexErrMsgIdAndTxt(GT6D::primal_error_id,
+          "The arguments need to be arrays of floating point numbers");
+      return;
+    }
+  }
+}
+
+
diff --git a/zUtil_Cxx/include/gt6DUpdateDualOps.h b/zUtil_Cxx/include/gt6DUpdateDualOps.h
new file mode 100644
index 00000000..94a63ab8
--- /dev/null
+++ b/zUtil_Cxx/include/gt6DUpdateDualOps.h
@@ -0,0 +1,243 @@
+/*
+ * gt6DUpdateDualOps.h
+ *
+ *  Created on: Oct 29, 2014
+ *      Author: vigano
+ */
+
+#ifndef ZUTIL_CXX_INCLUDE_GT6DUPDATEDUALOPS_H_
+#define ZUTIL_CXX_INCLUDE_GT6DUPDATEDUALOPS_H_
+
+#include "internal_cell_defs.h"
+
+#include <algorithm>
+
+namespace GT6D {
+/* Loads one SIMD vector from each of the five inputs at block offset shift_val, applies func and stores the result in out_data. Relies on elemIdx, access, simd_unroll, func, vVvf and the data pointers being visible at the expansion site. */
+#define APPLY_FUNC_5FOLD_DUAL_DETECTOR(shift_val) \
+{ \
+  const vVvf inV11 = access.load(&in_data1[elemIdx + shift_val * simd_unroll.shift]);\
+  const vVvf inV21 = access.load(&in_data2[elemIdx + shift_val * simd_unroll.shift]);\
+  const vVvf inV31 = access.load(&in_data3[elemIdx + shift_val * simd_unroll.shift]);\
+  const vVvf inV41 = access.load(&in_data4[elemIdx + shift_val * simd_unroll.shift]);\
+  const vVvf inV51 = access.load(&in_data5[elemIdx + shift_val * simd_unroll.shift]);\
+  access.store(&out_data[elemIdx + shift_val * simd_unroll.shift], func(inV11, inV21, inV31, inV41, inV51));\
+}
+
+  const char * dual_detector_error_id = "C_FUN:gt6DUpdateDualDetector:wrong_argument";
+  /** Stateless functor returning (dual + (computed_blobs - blobs) * sigma1) * sigma1_1, with one overload for scalars and one for SIMD vectors. */
+  template<typename Type>
+  class update_dual_detector {
+  public:
+    typedef typename SIMDUnrolling<Type>::vVvf vVvf; // vector type that SIMDUnrolling associates with Type
+
+    update_dual_detector() { }
+    // Scalar overload. NOTE(review): the MEX gateway labels its 2nd/3rd inputs computed/real blobs, the reverse of these parameter names - confirm.
+    const Type
+    operator()(const Type & dual, const Type & blobs,
+        const Type & computed_blobs, const Type & sigma1,
+        const Type & sigma1_1) const throw()
+    {
+      return (dual + (computed_blobs - blobs) * sigma1) * sigma1_1;
+    }
+    const vVvf // SIMD overload: same expression applied to whole vectors
+    operator()(const vVvf & dual, const vVvf & blobs,
+        const vVvf & computed_blobs, const vVvf & sigma1,
+        const vVvf & sigma1_1) const throw()
+    {
+      return (dual + (computed_blobs - blobs) * sigma1) * sigma1_1;
+    }
+  protected:
+  };
+
+  inline bool
+  check_dual_detector_arguments(const mxArray * const cells_dual,
+      const mxArray * const cells_blobs,
+      const mxArray * const cells_computed_blobs,
+      const mxArray * const cells_sigma1,
+      const mxArray * const cells_sigma1_1)
+  {
+    const mwSize num_cells = mxGetNumberOfElements(cells_dual);
+    if (num_cells != mxGetNumberOfElements(cells_blobs))
+    {
+      mexErrMsgIdAndTxt(dual_detector_error_id,
+              "Blobs should have the same number of cells as the Dual");
+      return false;
+    }
+    if (num_cells != mxGetNumberOfElements(cells_computed_blobs))
+    {
+      mexErrMsgIdAndTxt(dual_detector_error_id,
+              "Computed blobs should have the same number of cells as the Dual");
+      return false;
+    }
+    if (num_cells != mxGetNumberOfElements(cells_sigma1))
+    {
+      mexErrMsgIdAndTxt(dual_detector_error_id,
+              "Sigma1 should have the same number of cells as the Dual");
+      return false;
+    }
+    if (num_cells != mxGetNumberOfElements(cells_sigma1_1))
+    {
+      mexErrMsgIdAndTxt(dual_detector_error_id,
+              "Sigma1_1 should have the same number of cells as the Dual");
+      return false;
+    }
+
+    for(mwIndex cell_idx = 0; cell_idx < num_cells; cell_idx++)
+    {
+      const mxArray * const cell_dual = mxGetCell(cells_dual, cell_idx);
+      const mxArray * const cell_blob = mxGetCell(cells_blobs, cell_idx);
+      const mxArray * const cell_comp_blob = mxGetCell(cells_computed_blobs, cell_idx);
+      const mxArray * const cell_sigma1 = mxGetCell(cells_sigma1, cell_idx);
+      const mxArray * const cell_sigma1_1 = mxGetCell(cells_sigma1_1, cell_idx);
+
+      const mxClassID type = mxGetClassID(cell_dual);
+      const mwSize num_elems = mxGetNumberOfElements(cell_dual);
+
+      if (type != mxGetClassID(cell_blob))
+      {
+        mexErrMsgIdAndTxt(dual_detector_error_id,
+                "Blobs volumes should have the same type as the Dual");
+        return false;
+      }
+      if (type != mxGetClassID(cell_comp_blob))
+      {
+        mexErrMsgIdAndTxt(dual_detector_error_id,
+                "Computed volumes blobs should have the same type as the Dual");
+        return false;
+      }
+      if (type != mxGetClassID(cell_sigma1))
+      {
+        mexErrMsgIdAndTxt(dual_detector_error_id,
+                "Sigma1 volumes should have the same type as the Dual");
+        return false;
+      }
+      if (type != mxGetClassID(cell_sigma1_1))
+      {
+        mexErrMsgIdAndTxt(dual_detector_error_id,
+                "Sigma1_1 volumes should have the same type as the Dual");
+        return false;
+      }
+
+      if (num_elems != mxGetNumberOfElements(cell_blob))
+      {
+        mexErrMsgIdAndTxt(dual_detector_error_id,
+                "Blobs volumes should have the same number of elements as the Dual");
+        return false;
+      }
+      if (num_elems != mxGetNumberOfElements(cell_comp_blob))
+      {
+        mexErrMsgIdAndTxt(dual_detector_error_id,
+                "Blobs volumes should have the same number of elements as the Dual");
+        return false;
+      }
+      if (num_elems != mxGetNumberOfElements(cell_sigma1))
+      {
+        mexErrMsgIdAndTxt(dual_detector_error_id,
+                "Sgima1 volumes should have the same number of elements as the Dual");
+        return false;
+      }
+      if (num_elems != mxGetNumberOfElements(cell_sigma1_1))
+      {
+        mexErrMsgIdAndTxt(dual_detector_error_id,
+                "Sigma1_1 volumes should have the same number of elements as the Dual");
+        return false;
+      }
+    }
+
+    return true;
+  }
+  /** Applies func element-wise to five input arrays of numElems elements and writes out_data; built from OpenMP worksharing loops, so it is meant to be called from inside a parallel region. */
+  template<typename Type, typename Function, class AccessType>
+  inline void
+  detector_dual_inner_cycle_sse(Type * const __restrict out_data,
+      const Type * const __restrict in_data1,
+      const Type * const __restrict in_data2,
+      const Type * const __restrict in_data3,
+      const Type * const __restrict in_data4,
+      const Type * const __restrict in_data5,
+      const mwSize & numElems, Function & func)
+  {
+    typedef typename SIMDUnrolling<Type>::vVvf vVvf;
+    // Two iteration granularities: an 8-fold unrolled SIMD step (simd_unroll) and a single SIMD step (simd).
+    const mwSize unrolling = 8;
+    const SIMDUnrolling<Type> simd_unroll(unrolling);
+    const SIMDUnrolling<Type> simd(1);
+    // AccessType supplies the load/store primitives (the caller in this header passes AccessAligned<Type>).
+    AccessType access;
+    // Pass 1: unrolled blocks, one macro expansion per shift; nowait drops the barrier at the end of each loop.
+  #pragma omp for nowait
+    for(mwIndex elemIdx = 0; elemIdx < simd_unroll.get_unroll(numElems);
+        elemIdx += simd_unroll.block)
+    {
+      APPLY_FUNC_5FOLD_DUAL_DETECTOR(0);
+      APPLY_FUNC_5FOLD_DUAL_DETECTOR(1);
+      APPLY_FUNC_5FOLD_DUAL_DETECTOR(2);
+      APPLY_FUNC_5FOLD_DUAL_DETECTOR(3);
+      APPLY_FUNC_5FOLD_DUAL_DETECTOR(4);
+      APPLY_FUNC_5FOLD_DUAL_DETECTOR(5);
+      APPLY_FUNC_5FOLD_DUAL_DETECTOR(6);
+      APPLY_FUNC_5FOLD_DUAL_DETECTOR(7);
+    }
+  #pragma omp for nowait
+    for(mwIndex elemIdx = simd_unroll.get_unroll(numElems); // pass 2: single SIMD vectors after the unrolled range
+        elemIdx < simd.get_unroll(numElems); elemIdx += simd.block)
+    {
+      const vVvf inV11 = access.load(&in_data1[elemIdx]);
+      const vVvf inV21 = access.load(&in_data2[elemIdx]);
+      const vVvf inV31 = access.load(&in_data3[elemIdx]);
+      const vVvf inV41 = access.load(&in_data4[elemIdx]);
+      const vVvf inV51 = access.load(&in_data5[elemIdx]);
+
+      access.store(&out_data[elemIdx], func(inV11, inV21, inV31, inV41, inV51));
+    }
+  #pragma omp for nowait
+    for(mwIndex elemIdx = simd.get_unroll(numElems); elemIdx < numElems; elemIdx++) // pass 3: remaining elements, scalar overload of func
+    {
+      out_data[elemIdx] = func(in_data1[elemIdx], in_data2[elemIdx], in_data3[elemIdx], in_data4[elemIdx], in_data5[elemIdx]);
+    }
+  }
+  /** Applies func to every cell: for each index, the data of the five input cells is combined element-wise into the data of the matching cell of outArray. */
+  template<typename Type, typename Function>
+  inline void
+  dual_detector_iteration(mxArray * outArray,
+      const mxArray * const cells_dual,
+      const mxArray * const cells_blobs,
+      const mxArray * const cells_computed_blobs,
+      const mxArray * const cells_sigma1,
+      const mxArray * const cells_sigma1_1,
+      const Function & func)
+  {
+    const mwSize num_cells = mxGetNumberOfElements(cells_dual);
+    // Half of the threads reported by omp_get_max_threads(), but never less than one.
+    const mwSize num_threads = std::max( omp_get_max_threads()/2 , 1);
+    // Every thread walks all the cells; the elements of each cell are shared out by the worksharing loops of the inner cycle.
+  #pragma omp parallel num_threads(num_threads)
+    {
+      for(mwIndex cell_idx = 0; cell_idx < num_cells; cell_idx++)
+      {
+        const mxArray * const cell_dual = mxGetCell(cells_dual, cell_idx);
+        const mxArray * const cell_blob = mxGetCell(cells_blobs, cell_idx);
+        const mxArray * const cell_comp_blob = mxGetCell(cells_computed_blobs, cell_idx);
+        const mxArray * const cell_sigma1 = mxGetCell(cells_sigma1, cell_idx);
+        const mxArray * const cell_sigma1_1 = mxGetCell(cells_sigma1_1, cell_idx);
+
+        const Type * const data_dual = (const Type *) mxGetData(cell_dual);
+        const Type * const data_blob = (const Type *) mxGetData(cell_blob);
+        const Type * const data_comp_blob = (const Type *) mxGetData(cell_comp_blob);
+        const Type * const data_sigma1 = (const Type *) mxGetData(cell_sigma1);
+        const Type * const data_sigma1_1 = (const Type *) mxGetData(cell_sigma1_1);
+
+        const mwSize num_elems = mxGetNumberOfElements(cell_dual);
+        // The output cell must already exist and hold num_elems elements: nothing is allocated here.
+        mxArray * const out_cell = mxGetCell(outArray, cell_idx);
+        Type * const out_data = (Type *) mxGetData(out_cell);
+
+        detector_dual_inner_cycle_sse< Type, const Function, AccessAligned<Type> >(
+            out_data, data_dual, data_blob, data_comp_blob, data_sigma1, data_sigma1_1, num_elems, func);
+      }
+    }
+  }
+};
+
+#endif /* ZUTIL_CXX_INCLUDE_GT6DUPDATEDUALOPS_H_ */
diff --git a/zUtil_Cxx/include/gt6DUpdatePrimalOps.h b/zUtil_Cxx/include/gt6DUpdatePrimalOps.h
new file mode 100644
index 00000000..6908e15c
--- /dev/null
+++ b/zUtil_Cxx/include/gt6DUpdatePrimalOps.h
@@ -0,0 +1,202 @@
+/*
+ * gt6DUpdatePrimalOps.h
+ *
+ *  Created on: Oct 29, 2014
+ *      Author: vigano
+ */
+
+#ifndef ZUTIL_CXX_INCLUDE_GT6DOPS_H_
+#define ZUTIL_CXX_INCLUDE_GT6DOPS_H_
+
+#include "internal_cell_defs.h"
+
+#include <algorithm>
+
+namespace GT6D {
+/* Loads both outputs and the three inputs at block offset shift_val, lets func update the two output vectors, then stores them back. NOTE(review): the two output loads look redundant if func only writes its first two arguments - confirm. */
+#define APPLY_FUNC_5FOLD_PRIMAL(shift_val) \
+{ \
+  vVvf inV11 = access.load(&out_data1[elemIdx + shift_val * simd_unroll.shift]);\
+  vVvf inV21 = access.load(&out_data2[elemIdx + shift_val * simd_unroll.shift]);\
+  const vVvf inV31 = access.load(&in_data3[elemIdx + shift_val * simd_unroll.shift]);\
+  const vVvf inV41 = access.load(&in_data4[elemIdx + shift_val * simd_unroll.shift]);\
+  const vVvf inV51 = access.load(&in_data5[elemIdx + shift_val * simd_unroll.shift]);\
+  func(inV11, inV21, inV31, inV41, inV51);\
+  access.store(&out_data1[elemIdx + shift_val * simd_unroll.shift], inV11);\
+  access.store(&out_data2[elemIdx + shift_val * simd_unroll.shift], inV21);\
+}
+
+  const char * primal_error_id = "C_FUN:gt6DUpdatePrimal:wrong_argument";
+
+  /** Functor for the primal step: new = non_neg(old + (tomo + l1) * tau, 0) and enhanced = 2 * new - old, with scalar and SIMD overloads. */
+  template<typename Type>
+  class update_primal {
+  public:
+    typedef typename SIMDUnrolling<Type>::vVvf vVvf;
+
+    update_primal(const Type & _tau) : tau(_tau), tau_v(Coeff<Type>::get(_tau)) { }
+
+    void
+    operator()(Type & new_solution, Type & new_enh_solution,
+        const Type & old_solution, const Type & correction_tomo,
+        const Type & correction_l1) const throw()
+    {
+      // Scalar overload: both outputs are assigned before being read, so their incoming values are irrelevant.
+      new_solution = non_neg(old_solution + (correction_tomo + correction_l1) * tau, Type());
+      new_enh_solution = (new_solution + new_solution) - old_solution;
+    }
+    void
+    operator()(vVvf & new_solution, vVvf & new_enh_solution,
+        const vVvf & old_solution, const vVvf & correction_tomo,
+        const vVvf & correction_l1) const throw()
+    {
+      new_solution = non_neg(old_solution + (correction_tomo + correction_l1) * tau_v, vVvf());
+      new_enh_solution = (new_solution + new_solution) - old_solution;
+    }
+  protected:
+    const Type tau;
+    const vVvf tau_v; // SIMD counterpart of tau built by Coeff<Type>::get (presumably tau in every lane - confirm); used by the vector overload
+
+    inner_non_neg<Type> non_neg;
+  };
+
+  inline bool
+  check_primal_arguments(const mxArray * const mat_old_solution,
+      const mxArray * const mat_correction_tomo,
+      const mxArray * const mat_correction_l1)
+  {
+    const mxClassID type = mxGetClassID(mat_old_solution);
+    const mwSize num_elems = mxGetNumberOfElements(mat_old_solution);
+
+    if (type != mxGetClassID(mat_correction_tomo))
+    {
+      mexErrMsgIdAndTxt(primal_error_id,
+              "tomo correction should have the same type as the Primal");
+      return false;
+    }
+    if (type != mxGetClassID(mat_correction_l1))
+    {
+      mexErrMsgIdAndTxt(primal_error_id,
+              "l1 correction blobs should have the same type as the Primal");
+      return false;
+    }
+
+    if (num_elems != mxGetNumberOfElements(mat_correction_tomo))
+    {
+      mexErrMsgIdAndTxt(primal_error_id,
+              "tomo correction should have the same number of elements as the Primal");
+      return false;
+    }
+    if (num_elems != mxGetNumberOfElements(mat_correction_l1))
+    {
+      mexErrMsgIdAndTxt(primal_error_id,
+              "l1 correction should have the same number of elements as the Primal");
+      return false;
+    }
+
+    return true;
+  }
+  /** Calls func on every element: out_data1/out_data2 are passed as its two writable arguments, in_data3..5 as its read-only ones. Built from OpenMP worksharing loops, so it is meant to be called from inside a parallel region. */
+  template<typename Type, typename Function, class AccessType>
+  inline void
+  primal_inner_cycle_sse(Type * const __restrict out_data1,
+      Type * const __restrict out_data2,
+      const Type * const __restrict in_data3,
+      const Type * const __restrict in_data4,
+      const Type * const __restrict in_data5,
+      const mwSize & numElems, Function & func)
+  {
+    typedef typename SIMDUnrolling<Type>::vVvf vVvf;
+    // Two iteration granularities: an 8-fold unrolled SIMD step (simd_unroll) and a single SIMD step (simd).
+    const mwSize unrolling = 8;
+    const SIMDUnrolling<Type> simd_unroll(unrolling);
+    const SIMDUnrolling<Type> simd(1);
+    // AccessType supplies the load/store primitives (the caller in this header passes AccessAligned<Type>).
+    AccessType access;
+    // Pass 1: unrolled blocks, one macro expansion per shift; nowait drops the barrier at the end of each loop.
+  #pragma omp for nowait
+    for(mwIndex elemIdx = 0; elemIdx < simd_unroll.get_unroll(numElems);
+        elemIdx += simd_unroll.block)
+    {
+      APPLY_FUNC_5FOLD_PRIMAL(0);
+      APPLY_FUNC_5FOLD_PRIMAL(1);
+      APPLY_FUNC_5FOLD_PRIMAL(2);
+      APPLY_FUNC_5FOLD_PRIMAL(3);
+      APPLY_FUNC_5FOLD_PRIMAL(4);
+      APPLY_FUNC_5FOLD_PRIMAL(5);
+      APPLY_FUNC_5FOLD_PRIMAL(6);
+      APPLY_FUNC_5FOLD_PRIMAL(7);
+    }
+  #pragma omp for nowait
+    for(mwIndex elemIdx = simd_unroll.get_unroll(numElems); // pass 2: single SIMD vectors after the unrolled range
+        elemIdx < simd.get_unroll(numElems); elemIdx += simd.block)
+    {
+      vVvf inV11 = access.load(&out_data1[elemIdx]);
+      vVvf inV21 = access.load(&out_data2[elemIdx]);
+      // NOTE(review): the two loads above read the output buffers before anything was written to them - confirm they are needed.
+      const vVvf inV31 = access.load(&in_data3[elemIdx]);
+      const vVvf inV41 = access.load(&in_data4[elemIdx]);
+      const vVvf inV51 = access.load(&in_data5[elemIdx]);
+
+      func(inV11, inV21, inV31, inV41, inV51);
+
+      access.store(&out_data1[elemIdx], inV11);
+      access.store(&out_data2[elemIdx], inV21);
+    }
+  #pragma omp for nowait
+    for(mwIndex elemIdx = simd.get_unroll(numElems); elemIdx < numElems; elemIdx++) // pass 3: remaining elements, scalar overload of func
+    {
+       func(out_data1[elemIdx], out_data2[elemIdx], in_data3[elemIdx], in_data4[elemIdx], in_data5[elemIdx]);
+    }
+  }
+  /** Runs the primal inner cycle over whole arrays: func fills mat_new_solution and mat_new_enh_solution from the old solution and the two corrections. */
+  template<typename Type, typename Function>
+  inline void
+  primal_iteration(mxArray * mat_new_solution,
+      mxArray * mat_new_enh_solution,
+      const mxArray * const mat_old_solution,
+      const mxArray * const mat_correction_tomo,
+      const mxArray * const mat_correction_l1,
+      const Function & func)
+  {
+    const mwSize num_threads = std::max( omp_get_max_threads()/2 , 1); // half of the available threads, at least one
+    // Each thread fetches the same pointers; the elements are shared out by the worksharing loops of the inner cycle.
+  #pragma omp parallel num_threads(num_threads)
+    {
+      Type * const new_solution = (Type *) mxGetData(mat_new_solution);
+      Type * const new_enh_solution = (Type *) mxGetData(mat_new_enh_solution);
+      const Type * const old_solution = (const Type *) mxGetData(mat_old_solution);
+      const Type * const corr_tomo = (const Type *) mxGetData(mat_correction_tomo);
+      const Type * const corr_l1 = (const Type *) mxGetData(mat_correction_l1);
+      // The element count is taken from the output, which must therefore already be sized like the inputs.
+      const mwSize num_elems = mxGetNumberOfElements(mat_new_solution);
+
+      primal_inner_cycle_sse< Type, const Function, AccessAligned<Type> >(
+          new_solution, new_enh_solution, old_solution, corr_tomo, corr_l1, num_elems, func);
+    }
+  }
+  /** Creates the two output arrays with the class and dimensions of mat_old_solution; their data buffers come from mxMalloc and are not initialized here. */
+  inline void
+  primal_allocate_output(mxArray * & mat_new_solution,
+      mxArray * & mat_new_enh_solution,
+      const mxArray * const mat_old_solution)
+  {
+    const mwSize zero_dims[2] = {0, 0};
+    // Layout of the reference array, replicated on both outputs.
+    const mxClassID type = mxGetClassID(mat_old_solution);
+    const mwSize numDims = mxGetNumberOfDimensions(mat_old_solution);
+    const mwSize * dims = mxGetDimensions(mat_old_solution);
+    const mwSize num_elems = mxGetNumberOfElements(mat_old_solution);
+    const mwSize elem_size = mxGetElementSize(mat_old_solution);
+    // Start from an empty 0x0 array, then attach the real dimensions and a raw buffer (presumably to skip zero-filling - confirm).
+    mat_new_solution = mxCreateNumericArray(2, zero_dims, type, mxREAL);
+    mxSetDimensions(mat_new_solution, dims, numDims);
+    mxSetData(mat_new_solution, mxMalloc(elem_size * num_elems));
+    // Same construction for the second output.
+    mat_new_enh_solution = mxCreateNumericArray(2, zero_dims, type, mxREAL);
+    mxSetDimensions(mat_new_enh_solution, dims, numDims);
+    mxSetData(mat_new_enh_solution, mxMalloc(elem_size * num_elems));
+  }
+};
+
+#endif /* ZUTIL_CXX_INCLUDE_GT6DOPS_H_ */
-- 
GitLab