cart-elc

Source code for CART-ELC
git clone git://git.laack.co/cart-elc.git
Log | Files | Refs | README | LICENSE

TensorConcatenation.h (15665B)


      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H
     11 #define EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H
     12 
     13 namespace Eigen {
     14 
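/** \class TensorConcatenationOp
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor expression that concatenates two tensor expressions along an axis.
  *
  * Both operands must have the same rank, positive extents, and equal extents on
  * every dimension except the concatenation axis; the result's extent along that
  * axis is the sum of the operands' extents. Unless the rank is 1, both operands
  * must also use the same storage layout. Writing through a concatenation (using
  * it as an lvalue) is only supported for column-major layouts.
  */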
     22 namespace internal {
     23 template<typename Axis, typename LhsXprType, typename RhsXprType>
     24 struct traits<TensorConcatenationOp<Axis, LhsXprType, RhsXprType> >
     25 {
     26   // Type promotion to handle the case where the types of the lhs and the rhs are different.
     27   typedef typename promote_storage_type<typename LhsXprType::Scalar,
     28                                         typename RhsXprType::Scalar>::ret Scalar;
     29   typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
     30                                         typename traits<RhsXprType>::StorageKind>::ret StorageKind;
     31   typedef typename promote_index_type<typename traits<LhsXprType>::Index,
     32                                       typename traits<RhsXprType>::Index>::type Index;
     33   typedef typename LhsXprType::Nested LhsNested;
     34   typedef typename RhsXprType::Nested RhsNested;
     35   typedef typename remove_reference<LhsNested>::type _LhsNested;
     36   typedef typename remove_reference<RhsNested>::type _RhsNested;
     37   static const int NumDimensions = traits<LhsXprType>::NumDimensions;
     38   static const int Layout = traits<LhsXprType>::Layout;
     39   enum { Flags = 0 };
     40   typedef typename conditional<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
     41                                typename traits<LhsXprType>::PointerType, typename traits<RhsXprType>::PointerType>::type PointerType;
     42 };
     43 
     44 template<typename Axis, typename LhsXprType, typename RhsXprType>
     45 struct eval<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, Eigen::Dense>
     46 {
     47   typedef const TensorConcatenationOp<Axis, LhsXprType, RhsXprType>& type;
     48 };
     49 
     50 template<typename Axis, typename LhsXprType, typename RhsXprType>
     51 struct nested<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, 1, typename eval<TensorConcatenationOp<Axis, LhsXprType, RhsXprType> >::type>
     52 {
     53   typedef TensorConcatenationOp<Axis, LhsXprType, RhsXprType> type;
     54 };
     55 
     56 }  // end namespace internal
     57 
     58 
     59 template<typename Axis, typename LhsXprType, typename RhsXprType>
     60 class TensorConcatenationOp : public TensorBase<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, WriteAccessors>
     61 {
     62   public:
     63     typedef TensorBase<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, WriteAccessors> Base;
     64     typedef typename internal::traits<TensorConcatenationOp>::Scalar Scalar;
     65     typedef typename internal::traits<TensorConcatenationOp>::StorageKind StorageKind;
     66     typedef typename internal::traits<TensorConcatenationOp>::Index Index;
     67     typedef typename internal::nested<TensorConcatenationOp>::type Nested;
     68     typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
     69                                                     typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
     70     typedef typename NumTraits<Scalar>::Real RealScalar;
     71 
     72     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorConcatenationOp(const LhsXprType& lhs, const RhsXprType& rhs, Axis axis)
     73         : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_axis(axis) {}
     74 
     75     EIGEN_DEVICE_FUNC
     76     const typename internal::remove_all<typename LhsXprType::Nested>::type&
     77     lhsExpression() const { return m_lhs_xpr; }
     78 
     79     EIGEN_DEVICE_FUNC
     80     const typename internal::remove_all<typename RhsXprType::Nested>::type&
     81     rhsExpression() const { return m_rhs_xpr; }
     82 
     83     EIGEN_DEVICE_FUNC const Axis& axis() const { return m_axis; }
     84 
     85     EIGEN_TENSOR_INHERIT_ASSIGNMENT_OPERATORS(TensorConcatenationOp)
     86   protected:
     87     typename LhsXprType::Nested m_lhs_xpr;
     88     typename RhsXprType::Nested m_rhs_xpr;
     89     const Axis m_axis;
     90 };
     91 
     92 
     93 // Eval as rvalue
     94 template<typename Axis, typename LeftArgType, typename RightArgType, typename Device>
     95 struct TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
     96 {
     97   typedef TensorConcatenationOp<Axis, LeftArgType, RightArgType> XprType;
     98   typedef typename XprType::Index Index;
     99   static const int NumDims = internal::array_size<typename TensorEvaluator<LeftArgType, Device>::Dimensions>::value;
    100   static const int RightNumDims = internal::array_size<typename TensorEvaluator<RightArgType, Device>::Dimensions>::value;
    101   typedef DSizes<Index, NumDims> Dimensions;
    102   typedef typename XprType::Scalar Scalar;
    103   typedef typename XprType::CoeffReturnType CoeffReturnType;
    104   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
    105   typedef StorageMemory<CoeffReturnType, Device> Storage;
    106   typedef typename Storage::Type EvaluatorPointerType;
    107   enum {
    108     IsAligned         = false,
    109     PacketAccess      = TensorEvaluator<LeftArgType, Device>::PacketAccess &&
    110                         TensorEvaluator<RightArgType, Device>::PacketAccess,
    111     BlockAccess       = false,
    112     PreferBlockAccess = TensorEvaluator<LeftArgType, Device>::PreferBlockAccess ||
    113                         TensorEvaluator<RightArgType, Device>::PreferBlockAccess,
    114     Layout            = TensorEvaluator<LeftArgType, Device>::Layout,
    115     RawAccess         = false
    116   };
    117 
    118   //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
    119   typedef internal::TensorBlockNotImplemented TensorBlock;
    120   //===--------------------------------------------------------------------===//
    121 
    122   EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
    123     : m_leftImpl(op.lhsExpression(), device), m_rightImpl(op.rhsExpression(), device), m_axis(op.axis())
    124   {
    125     EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) == static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout) || NumDims == 1), YOU_MADE_A_PROGRAMMING_MISTAKE);
    126     EIGEN_STATIC_ASSERT((NumDims == RightNumDims), YOU_MADE_A_PROGRAMMING_MISTAKE);
    127     EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
    128 
    129     eigen_assert(0 <= m_axis && m_axis < NumDims);
    130     const Dimensions& lhs_dims = m_leftImpl.dimensions();
    131     const Dimensions& rhs_dims = m_rightImpl.dimensions();
    132     {
    133       int i = 0;
    134       for (; i < m_axis; ++i) {
    135         eigen_assert(lhs_dims[i] > 0);
    136         eigen_assert(lhs_dims[i] == rhs_dims[i]);
    137         m_dimensions[i] = lhs_dims[i];
    138       }
    139       eigen_assert(lhs_dims[i] > 0);  // Now i == m_axis.
    140       eigen_assert(rhs_dims[i] > 0);
    141       m_dimensions[i] = lhs_dims[i] + rhs_dims[i];
    142       for (++i; i < NumDims; ++i) {
    143         eigen_assert(lhs_dims[i] > 0);
    144         eigen_assert(lhs_dims[i] == rhs_dims[i]);
    145         m_dimensions[i] = lhs_dims[i];
    146       }
    147     }
    148 
    149     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
    150       m_leftStrides[0] = 1;
    151       m_rightStrides[0] = 1;
    152       m_outputStrides[0] = 1;
    153 
    154       for (int j = 1; j < NumDims; ++j) {
    155         m_leftStrides[j] = m_leftStrides[j-1] * lhs_dims[j-1];
    156         m_rightStrides[j] = m_rightStrides[j-1] * rhs_dims[j-1];
    157         m_outputStrides[j] = m_outputStrides[j-1] * m_dimensions[j-1];
    158       }
    159     } else {
    160       m_leftStrides[NumDims - 1] = 1;
    161       m_rightStrides[NumDims - 1] = 1;
    162       m_outputStrides[NumDims - 1] = 1;
    163 
    164       for (int j = NumDims - 2; j >= 0; --j) {
    165         m_leftStrides[j] = m_leftStrides[j+1] * lhs_dims[j+1];
    166         m_rightStrides[j] = m_rightStrides[j+1] * rhs_dims[j+1];
    167         m_outputStrides[j] = m_outputStrides[j+1] * m_dimensions[j+1];
    168       }
    169     }
    170   }
    171 
    172   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
    173 
    174   // TODO(phli): Add short-circuit memcpy evaluation if underlying data are linear?
    175   EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType)
    176   {
    177     m_leftImpl.evalSubExprsIfNeeded(NULL);
    178     m_rightImpl.evalSubExprsIfNeeded(NULL);
    179     return true;
    180   }
    181 
    182   EIGEN_STRONG_INLINE void cleanup()
    183   {
    184     m_leftImpl.cleanup();
    185     m_rightImpl.cleanup();
    186   }
    187 
    188   // TODO(phli): attempt to speed this up. The integer divisions and modulo are slow.
    189   // See CL/76180724 comments for more ideas.
    190   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
    191   {
    192     // Collect dimension-wise indices (subs).
    193     array<Index, NumDims> subs;
    194     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
    195       for (int i = NumDims - 1; i > 0; --i) {
    196         subs[i] = index / m_outputStrides[i];
    197         index -= subs[i] * m_outputStrides[i];
    198       }
    199       subs[0] = index;
    200     } else {
    201       for (int i = 0; i < NumDims - 1; ++i) {
    202         subs[i] = index / m_outputStrides[i];
    203         index -= subs[i] * m_outputStrides[i];
    204       }
    205       subs[NumDims - 1] = index;
    206     }
    207 
    208     const Dimensions& left_dims = m_leftImpl.dimensions();
    209     if (subs[m_axis] < left_dims[m_axis]) {
    210       Index left_index;
    211       if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
    212         left_index = subs[0];
    213         EIGEN_UNROLL_LOOP
    214         for (int i = 1; i < NumDims; ++i) {
    215           left_index += (subs[i] % left_dims[i]) * m_leftStrides[i];
    216         }
    217       } else {
    218         left_index = subs[NumDims - 1];
    219         EIGEN_UNROLL_LOOP
    220         for (int i = NumDims - 2; i >= 0; --i) {
    221           left_index += (subs[i] % left_dims[i]) * m_leftStrides[i];
    222         }
    223       }
    224       return m_leftImpl.coeff(left_index);
    225     } else {
    226       subs[m_axis] -= left_dims[m_axis];
    227       const Dimensions& right_dims = m_rightImpl.dimensions();
    228       Index right_index;
    229       if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
    230         right_index = subs[0];
    231         EIGEN_UNROLL_LOOP
    232         for (int i = 1; i < NumDims; ++i) {
    233           right_index += (subs[i] % right_dims[i]) * m_rightStrides[i];
    234         }
    235       } else {
    236         right_index = subs[NumDims - 1];
    237         EIGEN_UNROLL_LOOP
    238         for (int i = NumDims - 2; i >= 0; --i) {
    239           right_index += (subs[i] % right_dims[i]) * m_rightStrides[i];
    240         }
    241       }
    242       return m_rightImpl.coeff(right_index);
    243     }
    244   }
    245 
    246   // TODO(phli): Add a real vectorization.
    247   template<int LoadMode>
    248   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
    249   {
    250     const int packetSize = PacketType<CoeffReturnType, Device>::size;
    251     EIGEN_STATIC_ASSERT((packetSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    252     eigen_assert(index + packetSize - 1 < dimensions().TotalSize());
    253 
    254     EIGEN_ALIGN_MAX CoeffReturnType values[packetSize];
    255     EIGEN_UNROLL_LOOP
    256     for (int i = 0; i < packetSize; ++i) {
    257       values[i] = coeff(index+i);
    258     }
    259     PacketReturnType rslt = internal::pload<PacketReturnType>(values);
    260     return rslt;
    261   }
    262 
    263   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
    264   costPerCoeff(bool vectorized) const {
    265     const double compute_cost = NumDims * (2 * TensorOpCost::AddCost<Index>() +
    266                                            2 * TensorOpCost::MulCost<Index>() +
    267                                            TensorOpCost::DivCost<Index>() +
    268                                            TensorOpCost::ModCost<Index>());
    269     const double lhs_size = m_leftImpl.dimensions().TotalSize();
    270     const double rhs_size = m_rightImpl.dimensions().TotalSize();
    271     return (lhs_size / (lhs_size + rhs_size)) *
    272                m_leftImpl.costPerCoeff(vectorized) +
    273            (rhs_size / (lhs_size + rhs_size)) *
    274                m_rightImpl.costPerCoeff(vectorized) +
    275            TensorOpCost(0, 0, compute_cost);
    276   }
    277 
    278   EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }
    279 
    280   #ifdef EIGEN_USE_SYCL
    281   // binding placeholder accessors to a command group handler for SYCL
    282   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
    283     m_leftImpl.bind(cgh);
    284     m_rightImpl.bind(cgh);
    285   }
    286   #endif
    287 
    288   protected:
    289     Dimensions m_dimensions;
    290     array<Index, NumDims> m_outputStrides;
    291     array<Index, NumDims> m_leftStrides;
    292     array<Index, NumDims> m_rightStrides;
    293     TensorEvaluator<LeftArgType, Device> m_leftImpl;
    294     TensorEvaluator<RightArgType, Device> m_rightImpl;
    295     const Axis m_axis;
    296 };
    297 
    298 // Eval as lvalue
    299 template<typename Axis, typename LeftArgType, typename RightArgType, typename Device>
    300   struct TensorEvaluator<TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
    301   : public TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
    302 {
    303   typedef TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device> Base;
    304   typedef TensorConcatenationOp<Axis, LeftArgType, RightArgType> XprType;
    305   typedef typename Base::Dimensions Dimensions;
    306   enum {
    307     IsAligned         = false,
    308     PacketAccess      = TensorEvaluator<LeftArgType, Device>::PacketAccess &&
    309                         TensorEvaluator<RightArgType, Device>::PacketAccess,
    310     BlockAccess       = false,
    311     PreferBlockAccess = TensorEvaluator<LeftArgType, Device>::PreferBlockAccess ||
    312                         TensorEvaluator<RightArgType, Device>::PreferBlockAccess,
    313     Layout            = TensorEvaluator<LeftArgType, Device>::Layout,
    314     RawAccess         = false
    315   };
    316 
    317   //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
    318   typedef internal::TensorBlockNotImplemented TensorBlock;
    319   //===--------------------------------------------------------------------===//
    320 
    321   EIGEN_STRONG_INLINE TensorEvaluator(XprType& op, const Device& device)
    322     : Base(op, device)
    323   {
    324     EIGEN_STATIC_ASSERT((static_cast<int>(Layout) == static_cast<int>(ColMajor)), YOU_MADE_A_PROGRAMMING_MISTAKE);
    325   }
    326 
    327   typedef typename XprType::Index Index;
    328   typedef typename XprType::Scalar Scalar;
    329   typedef typename XprType::CoeffReturnType CoeffReturnType;
    330   typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
    331 
    332   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType& coeffRef(Index index)
    333   {
    334     // Collect dimension-wise indices (subs).
    335     array<Index, Base::NumDims> subs;
    336     for (int i = Base::NumDims - 1; i > 0; --i) {
    337       subs[i] = index / this->m_outputStrides[i];
    338       index -= subs[i] * this->m_outputStrides[i];
    339     }
    340     subs[0] = index;
    341 
    342     const Dimensions& left_dims = this->m_leftImpl.dimensions();
    343     if (subs[this->m_axis] < left_dims[this->m_axis]) {
    344       Index left_index = subs[0];
    345       for (int i = 1; i < Base::NumDims; ++i) {
    346         left_index += (subs[i] % left_dims[i]) * this->m_leftStrides[i];
    347       }
    348       return this->m_leftImpl.coeffRef(left_index);
    349     } else {
    350       subs[this->m_axis] -= left_dims[this->m_axis];
    351       const Dimensions& right_dims = this->m_rightImpl.dimensions();
    352       Index right_index = subs[0];
    353       for (int i = 1; i < Base::NumDims; ++i) {
    354         right_index += (subs[i] % right_dims[i]) * this->m_rightStrides[i];
    355       }
    356       return this->m_rightImpl.coeffRef(right_index);
    357     }
    358   }
    359 
    360   template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
    361   void writePacket(Index index, const PacketReturnType& x)
    362   {
    363     const int packetSize = PacketType<CoeffReturnType, Device>::size;
    364     EIGEN_STATIC_ASSERT((packetSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    365     eigen_assert(index + packetSize - 1 < this->dimensions().TotalSize());
    366 
    367     EIGEN_ALIGN_MAX CoeffReturnType values[packetSize];
    368     internal::pstore<CoeffReturnType, PacketReturnType>(values, x);
    369     for (int i = 0; i < packetSize; ++i) {
    370       coeffRef(index+i) = values[i];
    371     }
    372   }
    373 };
    374 
    375 } // end namespace Eigen
    376 
    377 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H