// Skip to content
// Snippets Groups Projects
// gt6DUpdateDualL1Ops.h 4.85 KiB
/*
 * gt6DUpdateDualL1Ops.h
 *
 *  Created on: Oct 31, 2014
 *      Author: vigano
 */

#ifndef ZUTIL_CXX_INCLUDE_GT6DUPDATEDUALL1OPS_H_
#define ZUTIL_CXX_INCLUDE_GT6DUPDATEDUALL1OPS_H_

#include "internal_cell_defs.h"

#include <algorithm>
#include <cmath>

namespace GT6D {

  // Message identifier passed to mexErrMsgIdAndTxt for every argument
  // validation failure reported from this header.
  //
  // Declared `static` with a const pointer so that it has internal linkage:
  // each translation unit including this header gets its own copy instead of
  // emitting a second external definition of the same symbol.
  static const char * const dual_l1_error_id = "C_FUN:gt6DUpdateDualL1:wrong_argument";

  /**
   * Element-wise functor implementing the dual variable update of an l1 term.
   *
   * For every element it computes
   *     t = dual + new_enh_sol / n_orient
   *     result = lambda * t / max(lambda, |t|)
   * which leaves t unchanged while |t| <= lambda and rescales it to magnitude
   * lambda otherwise (assuming lambda > 0 -- the class does not check this).
   *
   * Two call operators are offered: one for a single scalar of type Type and
   * one for a SIMD vector of type vVvf. The constructor and the abs()/max()
   * helpers are only declared here; definitions are provided as explicit
   * specializations for float and double further down in this header.
   */
  template<typename Type>
  class update_dual_l1 {
  public:
    /// SIMD vector type associated with Type.
    typedef typename SIMDUnrolling<Type>::vVvf vVvf;

    /**
     * @param _lambda            threshold used in the update (see class notes)
     * @param _num_orientations  divisor applied to new_enh_sol
     */
    update_dual_l1(const Type & _lambda, const Type & _num_orientations);

    /// Scalar update of one element.
    const Type
    operator()(const Type & dual, const Type & new_enh_sol) const throw()
    {
      const Type shifted = dual + new_enh_sol / n_orient;
      const Type denom = std::max(lambda, this->abs(shifted));
      return lambda * shifted / denom;
    }

    /// Vector update; same arithmetic as the scalar overload, per lane.
    const vVvf
    operator()(const vVvf & dual, const vVvf & new_enh_sol) const throw()
    {
      const vVvf shifted = dual + new_enh_sol / n_orient_v;
      const vVvf denom = this->max(lambda_v, this->abs(shifted));
      return lambda_v * shifted / denom;
    }

  protected:
    // NOTE: member order is significant, the constructors' initializer lists
    // follow exactly this sequence.

    const Type lambda;      ///< scalar threshold
    const vVvf lambda_v;    ///< vector counterpart of lambda

    const Type n_orient;    ///< scalar divisor for new_enh_sol
    const vVvf n_orient_v;  ///< vector counterpart of n_orient

    const vVvf abs_mask_v;  ///< bit mask ANDed with a vector by the packed abs()

    /// Absolute value of a scalar.
    const Type
    abs(const Type & val) const throw();
    /// Lane-wise absolute value of a vector.
    const vVvf
    abs(const vVvf & val) const throw();
    /// Lane-wise maximum of two vectors.
    const vVvf
    max(const vVvf & val1, const vVvf & val2) const throw();
  };

  /**
   * Single precision constructor.
   *
   * Stores the scalar parameters and their vector counterparts obtained from
   * Coeff<float>::get (presumably a broadcast of the scalar to every lane --
   * confirm against its definition). abs_mask_v gets 0x7fffffff in every
   * 32-bit lane, i.e. all bits set except the IEEE-754 sign bit.
   *
   * Marked `inline`: an explicit specialization of a member function is an
   * ordinary function, so without it every translation unit including this
   * header would emit its own external definition (duplicate symbols at link
   * time).
   */
  template<>
  inline
  update_dual_l1<float>::update_dual_l1(const float & _lambda, const float & _n_orient)
    : lambda(_lambda), lambda_v(Coeff<float>::get(_lambda))
    , n_orient(_n_orient), n_orient_v(Coeff<float>::get(_n_orient))
    , abs_mask_v(_mm_castsi128_ps(_mm_set1_epi32(0x7fffffff)))
  { }

  /**
   * Double precision constructor.
   *
   * Stores the scalar parameters and their vector counterparts obtained from
   * Coeff<double>::get (presumably a broadcast of the scalar to every lane --
   * confirm against its definition). abs_mask_v gets 0x7fffffffffffffff in
   * every 64-bit lane, i.e. all bits set except the IEEE-754 sign bit.
   *
   * Marked `inline`: an explicit specialization of a member function is an
   * ordinary function, so without it every translation unit including this
   * header would emit its own external definition (duplicate symbols at link
   * time).
   */
  template<>
  inline
  update_dual_l1<double>::update_dual_l1(const double & _lambda, const double & _n_orient)
    : lambda(_lambda), lambda_v(Coeff<double>::get(_lambda))
    , n_orient(_n_orient), n_orient_v(Coeff<double>::get(_n_orient))
    , abs_mask_v(_mm_castsi128_pd(_mm_set1_epi64x(0x7fffffffffffffffL)))
  { }

  /**
   * Lane-wise absolute value of a packed single precision vector.
   *
   * Performs a bitwise AND of `val` with abs_mask_v, which the float
   * constructor fills with 0x7fffffff per 32-bit lane, thereby clearing the
   * sign bit of every lane.
   */
  template<>
  inline const update_dual_l1<float>::vVvf
  update_dual_l1<float>::abs(
      const update_dual_l1<float>::vVvf & val)
  const throw()
  {
    // The local vVvf typedef that used to live here was never referenced
    // (-Wunused-local-typedefs) and has been dropped.
    return _mm_and_ps(val, abs_mask_v);
  }

  /**
   * Lane-wise absolute value of a packed double precision vector.
   *
   * Performs a bitwise AND of `val` with abs_mask_v, which the double
   * constructor fills with 0x7fffffffffffffff per 64-bit lane, thereby
   * clearing the sign bit of every lane.
   */
  template<>
  inline const update_dual_l1<double>::vVvf
  update_dual_l1<double>::abs(
      const update_dual_l1<double>::vVvf & val)
  const throw()
  {
    // The local vVvf typedef that used to live here was never referenced
    // (-Wunused-local-typedefs) and has been dropped.
    return _mm_and_pd(val, abs_mask_v);
  }

  /**
   * Absolute value of a single precision scalar; scalar counterpart of the
   * packed abs() overload.
   */
  template<>
  inline const float
  update_dual_l1<float>::abs(const float & val)
  const throw()
  {
    // <cmath> supplies a float overload of std::fabs, so no promotion to
    // double takes place.
    return std::fabs(val);
  }

  /**
   * Absolute value of a double precision scalar; scalar counterpart of the
   * packed abs() overload.
   */
  template<>
  inline const double
  update_dual_l1<double>::abs(const double & val)
  const throw()
  {
    const double magnitude = std::fabs(val);
    return magnitude;
  }

  /**
   * Lane-wise maximum of two packed single precision vectors, forwarded to
   * _mm_max_ps with the operands in the given order.
   */
  template<>
  inline const update_dual_l1<float>::vVvf
  update_dual_l1<float>::max(
      const update_dual_l1<float>::vVvf & val1,
      const update_dual_l1<float>::vVvf & val2)
  const throw()
  {
    // The local vVvf typedef that used to live here was never referenced
    // (-Wunused-local-typedefs) and has been dropped.
    return _mm_max_ps(val1, val2);
  }

  /**
   * Lane-wise maximum of two packed double precision vectors, forwarded to
   * _mm_max_pd with the operands in the given order.
   */
  template<>
  inline const update_dual_l1<double>::vVvf
  update_dual_l1<double>::max(
      const update_dual_l1<double>::vVvf & val1,
      const update_dual_l1<double>::vVvf & val2)
  const throw()
  {
    // The local vVvf typedef that used to live here was never referenced
    // (-Wunused-local-typedefs) and has been dropped.
    return _mm_max_pd(val1, val2);
  }


  /**
   * Validates the two cell array arguments of the dual l1 update.
   *
   * Checks performed, in order:
   *   - both arguments are cell arrays;
   *   - both cell arrays hold the same number of cells;
   *   - every cell of both arrays is populated;
   *   - all cells of `cells_dual` share one class ID, and each cell of
   *     `cells_new_enh_sol` has the same class ID as its counterpart;
   *   - paired cells hold the same number of elements.
   *
   * Any violation is reported through mexErrMsgIdAndTxt with
   * dual_l1_error_id. That call is not expected to return to the caller; the
   * `return mxUNKNOWN_CLASS` statements after it only keep every control path
   * well-formed.
   *
   * Note that this function does not verify that the common class is a
   * floating point one; the caller has to dispatch on the returned ID.
   *
   * @param cells_dual         cell array of dual variable volumes
   * @param cells_new_enh_sol  cell array of the matching new solution volumes
   * @return the class ID shared by all cells, or mxUNKNOWN_CLASS when the
   *         cell arrays are empty.
   */
  inline mxClassID
  check_dual_l1_arguments(const mxArray * const cells_dual,
      const mxArray * const cells_new_enh_sol)
  {
    // mxGetCell is only meaningful on cell arrays, so reject anything else
    // up front with an explicit message.
    if (!mxIsCell(cells_dual) || !mxIsCell(cells_new_enh_sol))
    {
      mexErrMsgIdAndTxt(dual_l1_error_id,
          "The arguments need to be Cell arrays");
      return mxUNKNOWN_CLASS;
    }

    const mwSize num_cells = mxGetNumberOfElements(cells_dual);
    if (num_cells != mxGetNumberOfElements(cells_new_enh_sol))
    {
      mexErrMsgIdAndTxt(dual_l1_error_id,
          "The Cell arrays should have the same number of elements");
      return mxUNKNOWN_CLASS;
    }
    mxClassID expected_type = mxUNKNOWN_CLASS;

    for(mwIndex cell_idx = 0; cell_idx < num_cells; ++cell_idx)
    {
      const mxArray * const cell_dual = mxGetCell(cells_dual, cell_idx);
      const mxArray * const cell_new_enh_sol = mxGetCell(cells_new_enh_sol, cell_idx);

      // mxGetCell yields a null pointer for cells that were never populated;
      // querying such a pointer below would dereference it.
      if (!cell_dual || !cell_new_enh_sol)
      {
        mexErrMsgIdAndTxt(dual_l1_error_id,
            "The Cell arrays should not contain unpopulated elements");
        return mxUNKNOWN_CLASS;
      }

      const mxClassID type_dual = mxGetClassID(cell_dual);
      if (cell_idx == 0)
      {
        // The first cell defines the class every other cell has to match.
        expected_type = type_dual;
      }

      if (type_dual != expected_type
          || type_dual != mxGetClassID(cell_new_enh_sol))
      {
        mexErrMsgIdAndTxt(dual_l1_error_id,
            "The arguments need to be Cell arrays of coherent floating point types");
        return mxUNKNOWN_CLASS;
      }

      const mwSize num_elems = mxGetNumberOfElements(cell_dual);
      if (num_elems != mxGetNumberOfElements(cell_new_enh_sol))
      {
        mexErrMsgIdAndTxt(dual_l1_error_id,
                "Blobs volumes should have the same number of elements as the Dual");
        return mxUNKNOWN_CLASS;
      }
    }

    return expected_type;
  }
};

#endif /* ZUTIL_CXX_INCLUDE_GT6DUPDATEDUALL1OPS_H_ */