cart-elc

Source code for CART-ELC
git clone git://git.laack.co/cart-elc.git
Log | Files | Refs | README | LICENSE

TensorExpr.h (16115B)


      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
     11 #define EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
     12 
     13 namespace Eigen {
     14 
     15 /** \class TensorExpr
     16   * \ingroup CXX11_Tensor_Module
     17   *
     18   * \brief Tensor expression classes.
     19   *
  * The TensorCwiseNullaryOp class applies a nullary operator to an expression.
     21   * This is typically used to generate constants.
     22   *
     23   * The TensorCwiseUnaryOp class represents an expression where a unary operator
     24   * (e.g. cwiseSqrt) is applied to an expression.
     25   *
     26   * The TensorCwiseBinaryOp class represents an expression where a binary
     27   * operator (e.g. addition) is applied to a lhs and a rhs expression.
     28   *
     29   */
     30 namespace internal {
     31 template<typename NullaryOp, typename XprType>
     32 struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> >
     33     : traits<XprType>
     34 {
     35   typedef traits<XprType> XprTraits;
     36   typedef typename XprType::Scalar Scalar;
     37   typedef typename XprType::Nested XprTypeNested;
     38   typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
     39   static const int NumDimensions = XprTraits::NumDimensions;
     40   static const int Layout = XprTraits::Layout;
     41   typedef typename XprTraits::PointerType PointerType;
     42   enum {
     43     Flags = 0
     44   };
     45 };
     46 
     47 }  // end namespace internal
     48 
     49 
     50 
     51 template<typename NullaryOp, typename XprType>
     52 class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType>, ReadOnlyAccessors>
     53 {
     54   public:
     55     typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar;
     56     typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
     57     typedef typename XprType::CoeffReturnType CoeffReturnType;
     58     typedef TensorCwiseNullaryOp<NullaryOp, XprType> Nested;
     59     typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::StorageKind StorageKind;
     60     typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Index Index;
     61 
     62     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const XprType& xpr, const NullaryOp& func = NullaryOp())
     63         : m_xpr(xpr), m_functor(func) {}
     64 
     65     EIGEN_DEVICE_FUNC
     66     const typename internal::remove_all<typename XprType::Nested>::type&
     67     nestedExpression() const { return m_xpr; }
     68 
     69     EIGEN_DEVICE_FUNC
     70     const NullaryOp& functor() const { return m_functor; }
     71 
     72   protected:
     73     typename XprType::Nested m_xpr;
     74     const NullaryOp m_functor;
     75 };
     76 
     77 
     78 
     79 namespace internal {
     80 template<typename UnaryOp, typename XprType>
     81 struct traits<TensorCwiseUnaryOp<UnaryOp, XprType> >
     82     : traits<XprType>
     83 {
     84   // TODO(phli): Add InputScalar, InputPacket.  Check references to
     85   // current Scalar/Packet to see if the intent is Input or Output.
     86   typedef typename result_of<UnaryOp(typename XprType::Scalar)>::type Scalar;
     87   typedef traits<XprType> XprTraits;
     88   typedef typename XprType::Nested XprTypeNested;
     89   typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
     90   static const int NumDimensions = XprTraits::NumDimensions;
     91   static const int Layout = XprTraits::Layout;
     92   typedef typename TypeConversion<Scalar, 
     93                                   typename XprTraits::PointerType
     94                                   >::type 
     95                                   PointerType;
     96 };
     97 
     98 template<typename UnaryOp, typename XprType>
     99 struct eval<TensorCwiseUnaryOp<UnaryOp, XprType>, Eigen::Dense>
    100 {
    101   typedef const TensorCwiseUnaryOp<UnaryOp, XprType>& type;
    102 };
    103 
    104 template<typename UnaryOp, typename XprType>
    105 struct nested<TensorCwiseUnaryOp<UnaryOp, XprType>, 1, typename eval<TensorCwiseUnaryOp<UnaryOp, XprType> >::type>
    106 {
    107   typedef TensorCwiseUnaryOp<UnaryOp, XprType> type;
    108 };
    109 
    110 }  // end namespace internal
    111 
    112 
    113 
    114 template<typename UnaryOp, typename XprType>
    115 class TensorCwiseUnaryOp : public TensorBase<TensorCwiseUnaryOp<UnaryOp, XprType>, ReadOnlyAccessors>
    116 {
    117   public:
    118     // TODO(phli): Add InputScalar, InputPacket.  Check references to
    119     // current Scalar/Packet to see if the intent is Input or Output.
    120     typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Scalar Scalar;
    121     typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    122     typedef Scalar CoeffReturnType;
    123     typedef typename Eigen::internal::nested<TensorCwiseUnaryOp>::type Nested;
    124     typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::StorageKind StorageKind;
    125     typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Index Index;
    126 
    127     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseUnaryOp(const XprType& xpr, const UnaryOp& func = UnaryOp())
    128       : m_xpr(xpr), m_functor(func) {}
    129 
    130     EIGEN_DEVICE_FUNC
    131     const UnaryOp& functor() const { return m_functor; }
    132 
    133     /** \returns the nested expression */
    134     EIGEN_DEVICE_FUNC
    135     const typename internal::remove_all<typename XprType::Nested>::type&
    136     nestedExpression() const { return m_xpr; }
    137 
    138   protected:
    139     typename XprType::Nested m_xpr;
    140     const UnaryOp m_functor;
    141 };
    142 
    143 
    144 namespace internal {
    145 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
    146 struct traits<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >
    147 {
    148   // Type promotion to handle the case where the types of the lhs and the rhs
    149   // are different.
    150   // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket.  Check references to
    151   // current Scalar/Packet to see if the intent is Inputs or Output.
    152   typedef typename result_of<
    153       BinaryOp(typename LhsXprType::Scalar,
    154                typename RhsXprType::Scalar)>::type Scalar;
    155   typedef traits<LhsXprType> XprTraits;
    156   typedef typename promote_storage_type<
    157       typename traits<LhsXprType>::StorageKind,
    158       typename traits<RhsXprType>::StorageKind>::ret StorageKind;
    159   typedef typename promote_index_type<
    160       typename traits<LhsXprType>::Index,
    161       typename traits<RhsXprType>::Index>::type Index;
    162   typedef typename LhsXprType::Nested LhsNested;
    163   typedef typename RhsXprType::Nested RhsNested;
    164   typedef typename remove_reference<LhsNested>::type _LhsNested;
    165   typedef typename remove_reference<RhsNested>::type _RhsNested;
    166   static const int NumDimensions = XprTraits::NumDimensions;
    167   static const int Layout = XprTraits::Layout;
    168   typedef typename TypeConversion<Scalar,
    169                                   typename conditional<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
    170                                                       typename traits<LhsXprType>::PointerType,
    171                                                       typename traits<RhsXprType>::PointerType>::type
    172                                   >::type 
    173                                   PointerType;
    174   enum {
    175     Flags = 0
    176   };
    177 };
    178 
    179 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
    180 struct eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, Eigen::Dense>
    181 {
    182   typedef const TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>& type;
    183 };
    184 
    185 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
    186 struct nested<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, 1, typename eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >::type>
    187 {
    188   typedef TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> type;
    189 };
    190 
    191 }  // end namespace internal
    192 
    193 
    194 
    195 template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
    196 class TensorCwiseBinaryOp : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, ReadOnlyAccessors>
    197 {
    198   public:
    199     // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket.  Check references to
    200     // current Scalar/Packet to see if the intent is Inputs or Output.
    201     typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Scalar Scalar;
    202     typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    203     typedef Scalar CoeffReturnType;
    204     typedef typename Eigen::internal::nested<TensorCwiseBinaryOp>::type Nested;
    205     typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::StorageKind StorageKind;
    206     typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Index Index;
    207 
    208     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const BinaryOp& func = BinaryOp())
    209         : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_functor(func) {}
    210 
    211     EIGEN_DEVICE_FUNC
    212     const BinaryOp& functor() const { return m_functor; }
    213 
    214     /** \returns the nested expressions */
    215     EIGEN_DEVICE_FUNC
    216     const typename internal::remove_all<typename LhsXprType::Nested>::type&
    217     lhsExpression() const { return m_lhs_xpr; }
    218 
    219     EIGEN_DEVICE_FUNC
    220     const typename internal::remove_all<typename RhsXprType::Nested>::type&
    221     rhsExpression() const { return m_rhs_xpr; }
    222 
    223   protected:
    224     typename LhsXprType::Nested m_lhs_xpr;
    225     typename RhsXprType::Nested m_rhs_xpr;
    226     const BinaryOp m_functor;
    227 };
    228 
    229 
    230 namespace internal {
    231 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
    232 struct traits<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >
    233 {
    234   // Type promotion to handle the case where the types of the args are different.
    235   typedef typename result_of<
    236       TernaryOp(typename Arg1XprType::Scalar,
    237                 typename Arg2XprType::Scalar,
    238                 typename Arg3XprType::Scalar)>::type Scalar;
    239   typedef traits<Arg1XprType> XprTraits;
    240   typedef typename traits<Arg1XprType>::StorageKind StorageKind;
    241   typedef typename traits<Arg1XprType>::Index Index;
    242   typedef typename Arg1XprType::Nested Arg1Nested;
    243   typedef typename Arg2XprType::Nested Arg2Nested;
    244   typedef typename Arg3XprType::Nested Arg3Nested;
    245   typedef typename remove_reference<Arg1Nested>::type _Arg1Nested;
    246   typedef typename remove_reference<Arg2Nested>::type _Arg2Nested;
    247   typedef typename remove_reference<Arg3Nested>::type _Arg3Nested;
    248   static const int NumDimensions = XprTraits::NumDimensions;
    249   static const int Layout = XprTraits::Layout;
    250   typedef typename TypeConversion<Scalar,
    251                                   typename conditional<Pointer_type_promotion<typename Arg2XprType::Scalar, Scalar>::val,
    252                                                       typename traits<Arg2XprType>::PointerType,
    253                                                       typename traits<Arg3XprType>::PointerType>::type
    254                                   >::type 
    255                                   PointerType;
    256   enum {
    257     Flags = 0
    258   };
    259 };
    260 
    261 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
    262 struct eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, Eigen::Dense>
    263 {
    264   typedef const TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>& type;
    265 };
    266 
    267 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
    268 struct nested<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, 1, typename eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >::type>
    269 {
    270   typedef TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> type;
    271 };
    272 
    273 }  // end namespace internal
    274 
    275 
    276 
    277 template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
    278 class TensorCwiseTernaryOp : public TensorBase<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, ReadOnlyAccessors>
    279 {
    280   public:
    281     typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Scalar Scalar;
    282     typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    283     typedef Scalar CoeffReturnType;
    284     typedef typename Eigen::internal::nested<TensorCwiseTernaryOp>::type Nested;
    285     typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::StorageKind StorageKind;
    286     typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Index Index;
    287 
    288     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseTernaryOp(const Arg1XprType& arg1, const Arg2XprType& arg2, const Arg3XprType& arg3, const TernaryOp& func = TernaryOp())
    289         : m_arg1_xpr(arg1), m_arg2_xpr(arg2), m_arg3_xpr(arg3), m_functor(func) {}
    290 
    291     EIGEN_DEVICE_FUNC
    292     const TernaryOp& functor() const { return m_functor; }
    293 
    294     /** \returns the nested expressions */
    295     EIGEN_DEVICE_FUNC
    296     const typename internal::remove_all<typename Arg1XprType::Nested>::type&
    297     arg1Expression() const { return m_arg1_xpr; }
    298 
    299     EIGEN_DEVICE_FUNC
    300     const typename internal::remove_all<typename Arg2XprType::Nested>::type&
    301     arg2Expression() const { return m_arg2_xpr; }
    302 
    303     EIGEN_DEVICE_FUNC
    304     const typename internal::remove_all<typename Arg3XprType::Nested>::type&
    305     arg3Expression() const { return m_arg3_xpr; }
    306 
    307   protected:
    308     typename Arg1XprType::Nested m_arg1_xpr;
    309     typename Arg2XprType::Nested m_arg2_xpr;
    310     typename Arg3XprType::Nested m_arg3_xpr;
    311     const TernaryOp m_functor;
    312 };
    313 
    314 
    315 namespace internal {
    316 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
    317 struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
    318     : traits<ThenXprType>
    319 {
    320   typedef typename traits<ThenXprType>::Scalar Scalar;
    321   typedef traits<ThenXprType> XprTraits;
    322   typedef typename promote_storage_type<typename traits<ThenXprType>::StorageKind,
    323                                         typename traits<ElseXprType>::StorageKind>::ret StorageKind;
    324   typedef typename promote_index_type<typename traits<ElseXprType>::Index,
    325                                       typename traits<ThenXprType>::Index>::type Index;
    326   typedef typename IfXprType::Nested IfNested;
    327   typedef typename ThenXprType::Nested ThenNested;
    328   typedef typename ElseXprType::Nested ElseNested;
    329   static const int NumDimensions = XprTraits::NumDimensions;
    330   static const int Layout = XprTraits::Layout;
    331   typedef typename conditional<Pointer_type_promotion<typename ThenXprType::Scalar, Scalar>::val,
    332                                typename traits<ThenXprType>::PointerType,
    333                                typename traits<ElseXprType>::PointerType>::type PointerType;
    334 };
    335 
    336 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
    337 struct eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, Eigen::Dense>
    338 {
    339   typedef const TensorSelectOp<IfXprType, ThenXprType, ElseXprType>& type;
    340 };
    341 
    342 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
    343 struct nested<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, 1, typename eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >::type>
    344 {
    345   typedef TensorSelectOp<IfXprType, ThenXprType, ElseXprType> type;
    346 };
    347 
    348 }  // end namespace internal
    349 
    350 
    351 template<typename IfXprType, typename ThenXprType, typename ElseXprType>
    352 class TensorSelectOp : public TensorBase<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, ReadOnlyAccessors>
    353 {
    354   public:
    355     typedef typename Eigen::internal::traits<TensorSelectOp>::Scalar Scalar;
    356     typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    357     typedef typename internal::promote_storage_type<typename ThenXprType::CoeffReturnType,
    358                                                     typename ElseXprType::CoeffReturnType>::ret CoeffReturnType;
    359     typedef typename Eigen::internal::nested<TensorSelectOp>::type Nested;
    360     typedef typename Eigen::internal::traits<TensorSelectOp>::StorageKind StorageKind;
    361     typedef typename Eigen::internal::traits<TensorSelectOp>::Index Index;
    362 
    363     EIGEN_DEVICE_FUNC
    364     TensorSelectOp(const IfXprType& a_condition,
    365                    const ThenXprType& a_then,
    366                    const ElseXprType& a_else)
    367       : m_condition(a_condition), m_then(a_then), m_else(a_else)
    368     { }
    369 
    370     EIGEN_DEVICE_FUNC
    371     const IfXprType& ifExpression() const { return m_condition; }
    372 
    373     EIGEN_DEVICE_FUNC
    374     const ThenXprType& thenExpression() const { return m_then; }
    375 
    376     EIGEN_DEVICE_FUNC
    377     const ElseXprType& elseExpression() const { return m_else; }
    378 
    379   protected:
    380     typename IfXprType::Nested m_condition;
    381     typename ThenXprType::Nested m_then;
    382     typename ElseXprType::Nested m_else;
    383 };
    384 
    385 
    386 } // end namespace Eigen
    387 
    388 #endif // EIGEN_CXX11_TENSOR_TENSOR_EXPR_H