cart-elc

Source code for CART-ELC
git clone git://git.laack.co/cart-elc.git
Log | Files | Refs | README | LICENSE

cxx11_tensor_concatenation.cpp (4624B)


      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #include "main.h"
     11 
     12 #include <Eigen/CXX11/Tensor>
     13 
     14 using Eigen::Tensor;
     15 
     16 template<int DataLayout>
     17 static void test_dimension_failures()
     18 {
     19   Tensor<int, 3, DataLayout> left(2, 3, 1);
     20   Tensor<int, 3, DataLayout> right(3, 3, 1);
     21   left.setRandom();
     22   right.setRandom();
     23 
     24   // Okay; other dimensions are equal.
     25   Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
     26 
     27   // Dimension mismatches.
     28   VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 1));
     29   VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 2));
     30 
     31   // Axis > NumDims or < 0.
     32   VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 3));
     33   VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, -1));
     34 }
     35 
     36 template<int DataLayout>
     37 static void test_static_dimension_failure()
     38 {
     39   Tensor<int, 2, DataLayout> left(2, 3);
     40   Tensor<int, 3, DataLayout> right(2, 3, 1);
     41 
     42 #ifdef CXX11_TENSOR_CONCATENATION_STATIC_DIMENSION_FAILURE
     43   // Technically compatible, but we static assert that the inputs have same
     44   // NumDims.
     45   Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
     46 #endif
     47 
     48   // This can be worked around in this case.
     49   Tensor<int, 3, DataLayout> concatenation = left
     50       .reshape(Tensor<int, 3>::Dimensions(2, 3, 1))
     51       .concatenate(right, 0);
     52   Tensor<int, 2, DataLayout> alternative = left
     53    // Clang compiler break with {{{}}} with an ambiguous error on copy constructor
     54   // the variadic DSize constructor added for #ifndef EIGEN_EMULATE_CXX11_META_H.
     55   // Solution:
     56   // either the code should change to 
     57   //  Tensor<int, 2>::Dimensions{{2, 3}}
     58   // or Tensor<int, 2>::Dimensions{Tensor<int, 2>::Dimensions{{2, 3}}}
     59       .concatenate(right.reshape(Tensor<int, 2>::Dimensions(2, 3)), 0);
     60 }
     61 
     62 template<int DataLayout>
     63 static void test_simple_concatenation()
     64 {
     65   Tensor<int, 3, DataLayout> left(2, 3, 1);
     66   Tensor<int, 3, DataLayout> right(2, 3, 1);
     67   left.setRandom();
     68   right.setRandom();
     69 
     70   Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
     71   VERIFY_IS_EQUAL(concatenation.dimension(0), 4);
     72   VERIFY_IS_EQUAL(concatenation.dimension(1), 3);
     73   VERIFY_IS_EQUAL(concatenation.dimension(2), 1);
     74   for (int j = 0; j < 3; ++j) {
     75     for (int i = 0; i < 2; ++i) {
     76       VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
     77     }
     78     for (int i = 2; i < 4; ++i) {
     79       VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i - 2, j, 0));
     80     }
     81   }
     82 
     83   concatenation = left.concatenate(right, 1);
     84   VERIFY_IS_EQUAL(concatenation.dimension(0), 2);
     85   VERIFY_IS_EQUAL(concatenation.dimension(1), 6);
     86   VERIFY_IS_EQUAL(concatenation.dimension(2), 1);
     87   for (int i = 0; i < 2; ++i) {
     88     for (int j = 0; j < 3; ++j) {
     89       VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
     90     }
     91     for (int j = 3; j < 6; ++j) {
     92       VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i, j - 3, 0));
     93     }
     94   }
     95 
     96   concatenation = left.concatenate(right, 2);
     97   VERIFY_IS_EQUAL(concatenation.dimension(0), 2);
     98   VERIFY_IS_EQUAL(concatenation.dimension(1), 3);
     99   VERIFY_IS_EQUAL(concatenation.dimension(2), 2);
    100   for (int i = 0; i < 2; ++i) {
    101     for (int j = 0; j < 3; ++j) {
    102       VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
    103       VERIFY_IS_EQUAL(concatenation(i, j, 1), right(i, j, 0));
    104     }
    105   }
    106 }
    107 
    108 
    109 // TODO(phli): Add test once we have a real vectorized implementation.
    110 // static void test_vectorized_concatenation() {}
    111 
    112 static void test_concatenation_as_lvalue()
    113 {
    114   Tensor<int, 2> t1(2, 3);
    115   Tensor<int, 2> t2(2, 3);
    116   t1.setRandom();
    117   t2.setRandom();
    118 
    119   Tensor<int, 2> result(4, 3);
    120   result.setRandom();
    121   t1.concatenate(t2, 0) = result;
    122 
    123   for (int i = 0; i < 2; ++i) {
    124     for (int j = 0; j < 3; ++j) {
    125       VERIFY_IS_EQUAL(t1(i, j), result(i, j));
    126       VERIFY_IS_EQUAL(t2(i, j), result(i+2, j));
    127     }
    128   }
    129 }
    130 
    131 
    132 EIGEN_DECLARE_TEST(cxx11_tensor_concatenation)
    133 {
    134    CALL_SUBTEST(test_dimension_failures<ColMajor>());
    135    CALL_SUBTEST(test_dimension_failures<RowMajor>());
    136    CALL_SUBTEST(test_static_dimension_failure<ColMajor>());
    137    CALL_SUBTEST(test_static_dimension_failure<RowMajor>());
    138    CALL_SUBTEST(test_simple_concatenation<ColMajor>());
    139    CALL_SUBTEST(test_simple_concatenation<RowMajor>());
    140    // CALL_SUBTEST(test_vectorized_concatenation());
    141    CALL_SUBTEST(test_concatenation_as_lvalue());
    142 
    143 }