cart-elc

Source code for CART-ELC
git clone git://git.laack.co/cart-elc.git
Log | Files | Refs | README | LICENSE

TensorRef.h (14793B)


      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_REF_H
     11 #define EIGEN_CXX11_TENSOR_TENSOR_REF_H
     12 
     13 namespace Eigen {
     14 
     15 namespace internal {
     16 
     17 template <typename Dimensions, typename Scalar>
     18 class TensorLazyBaseEvaluator {
     19  public:
     20   TensorLazyBaseEvaluator() : m_refcount(0) { }
     21   virtual ~TensorLazyBaseEvaluator() { }
     22 
     23   EIGEN_DEVICE_FUNC virtual const Dimensions& dimensions() const = 0;
     24   EIGEN_DEVICE_FUNC virtual const Scalar* data() const = 0;
     25 
     26   EIGEN_DEVICE_FUNC virtual const Scalar coeff(DenseIndex index) const = 0;
     27   EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex index) = 0;
     28 
     29   void incrRefCount() { ++m_refcount; }
     30   void decrRefCount() { --m_refcount; }
     31   int refCount() const { return m_refcount; }
     32 
     33  private:
     34   // No copy, no assignment;
     35   TensorLazyBaseEvaluator(const TensorLazyBaseEvaluator& other);
     36   TensorLazyBaseEvaluator& operator = (const TensorLazyBaseEvaluator& other);
     37 
     38   int m_refcount;
     39 };
     40 
     41 
     42 template <typename Dimensions, typename Expr, typename Device>
     43 class TensorLazyEvaluatorReadOnly : public TensorLazyBaseEvaluator<Dimensions, typename TensorEvaluator<Expr, Device>::Scalar> {
     44  public:
     45   //  typedef typename TensorEvaluator<Expr, Device>::Dimensions Dimensions;
     46   typedef typename TensorEvaluator<Expr, Device>::Scalar Scalar;
     47   typedef StorageMemory<Scalar, Device> Storage;
     48   typedef typename Storage::Type EvaluatorPointerType;
     49   typedef  TensorEvaluator<Expr, Device> EvalType;
     50 
     51   TensorLazyEvaluatorReadOnly(const Expr& expr, const Device& device) : m_impl(expr, device), m_dummy(Scalar(0)) {
     52     m_dims = m_impl.dimensions();
     53     m_impl.evalSubExprsIfNeeded(NULL);
     54   }
     55   virtual ~TensorLazyEvaluatorReadOnly() {
     56     m_impl.cleanup();
     57   }
     58 
     59   EIGEN_DEVICE_FUNC virtual const Dimensions& dimensions() const {
     60     return m_dims;
     61   }
     62   EIGEN_DEVICE_FUNC virtual const Scalar* data() const {
     63     return m_impl.data();
     64   }
     65 
     66   EIGEN_DEVICE_FUNC virtual const Scalar coeff(DenseIndex index) const {
     67     return m_impl.coeff(index);
     68   }
     69   EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex /*index*/) {
     70     eigen_assert(false && "can't reference the coefficient of a rvalue");
     71     return m_dummy;
     72   };
     73 
     74  protected:
     75   TensorEvaluator<Expr, Device> m_impl;
     76   Dimensions m_dims;
     77   Scalar m_dummy;
     78 };
     79 
     80 template <typename Dimensions, typename Expr, typename Device>
     81 class TensorLazyEvaluatorWritable : public TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> {
     82  public:
     83   typedef TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> Base;
     84   typedef typename Base::Scalar Scalar;
     85   typedef StorageMemory<Scalar, Device> Storage;
     86   typedef typename Storage::Type EvaluatorPointerType;
     87 
     88   TensorLazyEvaluatorWritable(const Expr& expr, const Device& device) : Base(expr, device) {
     89   }
     90   virtual ~TensorLazyEvaluatorWritable() {
     91   }
     92 
     93   EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex index) {
     94     return this->m_impl.coeffRef(index);
     95   }
     96 };
     97 
     98 template <typename Dimensions, typename Expr, typename Device>
     99 class TensorLazyEvaluator : public internal::conditional<bool(internal::is_lvalue<Expr>::value),
    100                             TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
    101                             TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> >::type {
    102  public:
    103   typedef typename internal::conditional<bool(internal::is_lvalue<Expr>::value),
    104                                          TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
    105                                          TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> >::type Base;
    106   typedef typename Base::Scalar Scalar;
    107 
    108   TensorLazyEvaluator(const Expr& expr, const Device& device) : Base(expr, device) {
    109   }
    110   virtual ~TensorLazyEvaluator() {
    111   }
    112 };
    113 
    114 }  // namespace internal
    115 
    116 
    117 /** \class TensorRef
    118   * \ingroup CXX11_Tensor_Module
    119   *
    120   * \brief A reference to a tensor expression
    121   * The expression will be evaluated lazily (as much as possible).
    122   *
    123   */
    124 template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef<PlainObjectType> >
    125 {
    126   public:
    127     typedef TensorRef<PlainObjectType> Self;
    128     typedef typename PlainObjectType::Base Base;
    129     typedef typename Eigen::internal::nested<Self>::type Nested;
    130     typedef typename internal::traits<PlainObjectType>::StorageKind StorageKind;
    131     typedef typename internal::traits<PlainObjectType>::Index Index;
    132     typedef typename internal::traits<PlainObjectType>::Scalar Scalar;
    133     typedef typename NumTraits<Scalar>::Real RealScalar;
    134     typedef typename Base::CoeffReturnType CoeffReturnType;
    135     typedef Scalar* PointerType;
    136     typedef PointerType PointerArgType;
    137 
    138     static const Index NumIndices = PlainObjectType::NumIndices;
    139     typedef typename PlainObjectType::Dimensions Dimensions;
    140 
    141     enum {
    142       IsAligned = false,
    143       PacketAccess = false,
    144       BlockAccess = false,
    145       PreferBlockAccess = false,
    146       Layout = PlainObjectType::Layout,
    147       CoordAccess = false,  // to be implemented
    148       RawAccess = false
    149     };
    150 
    151     //===- Tensor block evaluation strategy (see TensorBlock.h) -----------===//
    152     typedef internal::TensorBlockNotImplemented TensorBlock;
    153     //===------------------------------------------------------------------===//
    154 
    155     EIGEN_STRONG_INLINE TensorRef() : m_evaluator(NULL) {
    156     }
    157 
    158     template <typename Expression>
    159     EIGEN_STRONG_INLINE TensorRef(const Expression& expr) : m_evaluator(new internal::TensorLazyEvaluator<Dimensions, Expression, DefaultDevice>(expr, DefaultDevice())) {
    160       m_evaluator->incrRefCount();
    161     }
    162 
    163     template <typename Expression>
    164     EIGEN_STRONG_INLINE TensorRef& operator = (const Expression& expr) {
    165       unrefEvaluator();
    166       m_evaluator = new internal::TensorLazyEvaluator<Dimensions, Expression, DefaultDevice>(expr, DefaultDevice());
    167       m_evaluator->incrRefCount();
    168       return *this;
    169     }
    170 
    171     ~TensorRef() {
    172       unrefEvaluator();
    173     }
    174 
    175     TensorRef(const TensorRef& other) : m_evaluator(other.m_evaluator) {
    176       eigen_assert(m_evaluator->refCount() > 0);
    177       m_evaluator->incrRefCount();
    178     }
    179 
    180     TensorRef& operator = (const TensorRef& other) {
    181       if (this != &other) {
    182         unrefEvaluator();
    183         m_evaluator = other.m_evaluator;
    184         eigen_assert(m_evaluator->refCount() > 0);
    185         m_evaluator->incrRefCount();
    186       }
    187       return *this;
    188     }
    189 
    190     EIGEN_DEVICE_FUNC
    191     EIGEN_STRONG_INLINE Index rank() const { return m_evaluator->dimensions().size(); }
    192     EIGEN_DEVICE_FUNC
    193     EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_evaluator->dimensions()[n]; }
    194     EIGEN_DEVICE_FUNC
    195     EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_evaluator->dimensions(); }
    196     EIGEN_DEVICE_FUNC
    197     EIGEN_STRONG_INLINE Index size() const { return m_evaluator->dimensions().TotalSize(); }
    198     EIGEN_DEVICE_FUNC
    199     EIGEN_STRONG_INLINE const Scalar* data() const { return m_evaluator->data(); }
    200 
    201     EIGEN_DEVICE_FUNC
    202     EIGEN_STRONG_INLINE const Scalar operator()(Index index) const
    203     {
    204       return m_evaluator->coeff(index);
    205     }
    206 
    207 #if EIGEN_HAS_VARIADIC_TEMPLATES
    208     template<typename... IndexTypes> EIGEN_DEVICE_FUNC
    209     EIGEN_STRONG_INLINE const Scalar operator()(Index firstIndex, IndexTypes... otherIndices) const
    210     {
    211       const std::size_t num_indices = (sizeof...(otherIndices) + 1);
    212       const array<Index, num_indices> indices{{firstIndex, otherIndices...}};
    213       return coeff(indices);
    214     }
    215     template<typename... IndexTypes> EIGEN_DEVICE_FUNC
    216     EIGEN_STRONG_INLINE Scalar& coeffRef(Index firstIndex, IndexTypes... otherIndices)
    217     {
    218       const std::size_t num_indices = (sizeof...(otherIndices) + 1);
    219       const array<Index, num_indices> indices{{firstIndex, otherIndices...}};
    220       return coeffRef(indices);
    221     }
    222 #else
    223 
    224     EIGEN_DEVICE_FUNC
    225     EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1) const
    226     {
    227       array<Index, 2> indices;
    228       indices[0] = i0;
    229       indices[1] = i1;
    230       return coeff(indices);
    231     }
    232     EIGEN_DEVICE_FUNC
    233     EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2) const
    234     {
    235       array<Index, 3> indices;
    236       indices[0] = i0;
    237       indices[1] = i1;
    238       indices[2] = i2;
    239       return coeff(indices);
    240     }
    241     EIGEN_DEVICE_FUNC
    242     EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2, Index i3) const
    243     {
    244       array<Index, 4> indices;
    245       indices[0] = i0;
    246       indices[1] = i1;
    247       indices[2] = i2;
    248       indices[3] = i3;
    249       return coeff(indices);
    250     }
    251     EIGEN_DEVICE_FUNC
    252     EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2, Index i3, Index i4) const
    253     {
    254       array<Index, 5> indices;
    255       indices[0] = i0;
    256       indices[1] = i1;
    257       indices[2] = i2;
    258       indices[3] = i3;
    259       indices[4] = i4;
    260       return coeff(indices);
    261     }
    262     EIGEN_DEVICE_FUNC
    263     EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1)
    264     {
    265       array<Index, 2> indices;
    266       indices[0] = i0;
    267       indices[1] = i1;
    268       return coeffRef(indices);
    269     }
    270     EIGEN_DEVICE_FUNC
    271     EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2)
    272     {
    273       array<Index, 3> indices;
    274       indices[0] = i0;
    275       indices[1] = i1;
    276       indices[2] = i2;
    277       return coeffRef(indices);
    278     }
    279     EIGEN_DEVICE_FUNC
    280     EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3)
    281     {
    282       array<Index, 4> indices;
    283       indices[0] = i0;
    284       indices[1] = i1;
    285       indices[2] = i2;
    286       indices[3] = i3;
    287       return coeffRef(indices);
    288     }
    289     EIGEN_DEVICE_FUNC
    290     EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2, Index i3, Index i4)
    291     {
    292       array<Index, 5> indices;
    293       indices[0] = i0;
    294       indices[1] = i1;
    295       indices[2] = i2;
    296       indices[3] = i3;
    297       indices[4] = i4;
    298       return coeffRef(indices);
    299     }
    300 #endif
    301 
    302     template <std::size_t NumIndices> EIGEN_DEVICE_FUNC
    303     EIGEN_STRONG_INLINE const Scalar coeff(const array<Index, NumIndices>& indices) const
    304     {
    305       const Dimensions& dims = this->dimensions();
    306       Index index = 0;
    307       if (PlainObjectType::Options & RowMajor) {
    308         index += indices[0];
    309         for (size_t i = 1; i < NumIndices; ++i) {
    310           index = index * dims[i] + indices[i];
    311         }
    312       } else {
    313         index += indices[NumIndices-1];
    314         for (int i = NumIndices-2; i >= 0; --i) {
    315           index = index * dims[i] + indices[i];
    316         }
    317       }
    318       return m_evaluator->coeff(index);
    319     }
    320     template <std::size_t NumIndices> EIGEN_DEVICE_FUNC
    321     EIGEN_STRONG_INLINE Scalar& coeffRef(const array<Index, NumIndices>& indices)
    322     {
    323       const Dimensions& dims = this->dimensions();
    324       Index index = 0;
    325       if (PlainObjectType::Options & RowMajor) {
    326         index += indices[0];
    327         for (size_t i = 1; i < NumIndices; ++i) {
    328           index = index * dims[i] + indices[i];
    329         }
    330       } else {
    331         index += indices[NumIndices-1];
    332         for (int i = NumIndices-2; i >= 0; --i) {
    333           index = index * dims[i] + indices[i];
    334         }
    335       }
    336       return m_evaluator->coeffRef(index);
    337     }
    338 
    339     EIGEN_DEVICE_FUNC
    340     EIGEN_STRONG_INLINE const Scalar coeff(Index index) const
    341     {
    342       return m_evaluator->coeff(index);
    343     }
    344 
    345     EIGEN_DEVICE_FUNC
    346     EIGEN_STRONG_INLINE Scalar& coeffRef(Index index)
    347     {
    348       return m_evaluator->coeffRef(index);
    349     }
    350 
    351   private:
    352     EIGEN_STRONG_INLINE void unrefEvaluator() {
    353       if (m_evaluator) {
    354         m_evaluator->decrRefCount();
    355         if (m_evaluator->refCount() == 0) {
    356           delete m_evaluator;
    357         }
    358       }
    359     }
    360 
    361   internal::TensorLazyBaseEvaluator<Dimensions, Scalar>* m_evaluator;
    362 };
    363 
    364 
    365 // evaluator for rvalues
    366 template<typename Derived, typename Device>
    367 struct TensorEvaluator<const TensorRef<Derived>, Device>
    368 {
    369   typedef typename Derived::Index Index;
    370   typedef typename Derived::Scalar Scalar;
    371   typedef typename Derived::Scalar CoeffReturnType;
    372   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
    373   typedef typename Derived::Dimensions Dimensions;
    374   typedef StorageMemory<CoeffReturnType, Device> Storage;
    375   typedef typename Storage::Type EvaluatorPointerType;
    376 
    377   enum {
    378     IsAligned = false,
    379     PacketAccess = false,
    380     BlockAccess = false,
    381     PreferBlockAccess = false,
    382     Layout = TensorRef<Derived>::Layout,
    383     CoordAccess = false,  // to be implemented
    384     RawAccess = false
    385   };
    386 
    387   //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
    388   typedef internal::TensorBlockNotImplemented TensorBlock;
    389   //===--------------------------------------------------------------------===//
    390 
    391   EIGEN_STRONG_INLINE TensorEvaluator(const TensorRef<Derived>& m, const Device&)
    392       : m_ref(m)
    393   { }
    394 
    395   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_ref.dimensions(); }
    396 
    397   EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType) {
    398     return true;
    399   }
    400 
    401   EIGEN_STRONG_INLINE void cleanup() { }
    402 
    403   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    404     return m_ref.coeff(index);
    405   }
    406 
    407   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
    408     return m_ref.coeffRef(index);
    409   }
    410 
    411   EIGEN_DEVICE_FUNC const Scalar* data() const { return m_ref.data(); }
    412 
    413  protected:
    414   TensorRef<Derived> m_ref;
    415 };
    416 
    417 
    418 // evaluator for lvalues
    419 template<typename Derived, typename Device>
    420 struct TensorEvaluator<TensorRef<Derived>, Device> : public TensorEvaluator<const TensorRef<Derived>, Device>
    421 {
    422   typedef typename Derived::Index Index;
    423   typedef typename Derived::Scalar Scalar;
    424   typedef typename Derived::Scalar CoeffReturnType;
    425   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
    426   typedef typename Derived::Dimensions Dimensions;
    427 
    428   typedef TensorEvaluator<const TensorRef<Derived>, Device> Base;
    429 
    430   enum {
    431     IsAligned = false,
    432     PacketAccess = false,
    433     BlockAccess = false,
    434     PreferBlockAccess = false,
    435     RawAccess = false
    436   };
    437 
    438   //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
    439   typedef internal::TensorBlockNotImplemented TensorBlock;
    440   //===--------------------------------------------------------------------===//
    441 
    442   EIGEN_STRONG_INLINE TensorEvaluator(TensorRef<Derived>& m, const Device& d) : Base(m, d)
    443   { }
    444 
    445   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
    446     return this->m_ref.coeffRef(index);
    447   }
    448 };
    449 
    450 
    451 
    452 } // end namespace Eigen
    453 
    454 #endif // EIGEN_CXX11_TENSOR_TENSOR_REF_H