cxx11_tensor_concatenation.cpp (4624B)
1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #include "main.h" 11 12 #include <Eigen/CXX11/Tensor> 13 14 using Eigen::Tensor; 15 16 template<int DataLayout> 17 static void test_dimension_failures() 18 { 19 Tensor<int, 3, DataLayout> left(2, 3, 1); 20 Tensor<int, 3, DataLayout> right(3, 3, 1); 21 left.setRandom(); 22 right.setRandom(); 23 24 // Okay; other dimensions are equal. 25 Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0); 26 27 // Dimension mismatches. 28 VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 1)); 29 VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 2)); 30 31 // Axis > NumDims or < 0. 32 VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 3)); 33 VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, -1)); 34 } 35 36 template<int DataLayout> 37 static void test_static_dimension_failure() 38 { 39 Tensor<int, 2, DataLayout> left(2, 3); 40 Tensor<int, 3, DataLayout> right(2, 3, 1); 41 42 #ifdef CXX11_TENSOR_CONCATENATION_STATIC_DIMENSION_FAILURE 43 // Technically compatible, but we static assert that the inputs have same 44 // NumDims. 45 Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0); 46 #endif 47 48 // This can be worked around in this case. 49 Tensor<int, 3, DataLayout> concatenation = left 50 .reshape(Tensor<int, 3>::Dimensions(2, 3, 1)) 51 .concatenate(right, 0); 52 Tensor<int, 2, DataLayout> alternative = left 53 // Clang compiler break with {{{}}} with an ambiguous error on copy constructor 54 // the variadic DSize constructor added for #ifndef EIGEN_EMULATE_CXX11_META_H. 55 // Solution: 56 // either the code should change to 57 // Tensor<int, 2>::Dimensions{{2, 3}} 58 // or Tensor<int, 2>::Dimensions{Tensor<int, 2>::Dimensions{{2, 3}}} 59 .concatenate(right.reshape(Tensor<int, 2>::Dimensions(2, 3)), 0); 60 } 61 62 template<int DataLayout> 63 static void test_simple_concatenation() 64 { 65 Tensor<int, 3, DataLayout> left(2, 3, 1); 66 Tensor<int, 3, DataLayout> right(2, 3, 1); 67 left.setRandom(); 68 right.setRandom(); 69 70 Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0); 71 VERIFY_IS_EQUAL(concatenation.dimension(0), 4); 72 VERIFY_IS_EQUAL(concatenation.dimension(1), 3); 73 VERIFY_IS_EQUAL(concatenation.dimension(2), 1); 74 for (int j = 0; j < 3; ++j) { 75 for (int i = 0; i < 2; ++i) { 76 VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0)); 77 } 78 for (int i = 2; i < 4; ++i) { 79 VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i - 2, j, 0)); 80 } 81 } 82 83 concatenation = left.concatenate(right, 1); 84 VERIFY_IS_EQUAL(concatenation.dimension(0), 2); 85 VERIFY_IS_EQUAL(concatenation.dimension(1), 6); 86 VERIFY_IS_EQUAL(concatenation.dimension(2), 1); 87 for (int i = 0; i < 2; ++i) { 88 for (int j = 0; j < 3; ++j) { 89 VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0)); 90 } 91 for (int j = 3; j < 6; ++j) { 92 VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i, j - 3, 0)); 93 } 94 } 95 96 concatenation = left.concatenate(right, 2); 97 VERIFY_IS_EQUAL(concatenation.dimension(0), 2); 98 VERIFY_IS_EQUAL(concatenation.dimension(1), 3); 99 VERIFY_IS_EQUAL(concatenation.dimension(2), 2); 100 for (int i = 0; i < 2; ++i) { 101 for (int j = 0; j < 3; ++j) { 102 VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0)); 103 VERIFY_IS_EQUAL(concatenation(i, j, 1), right(i, j, 0)); 104 } 105 } 106 } 107 108 109 // TODO(phli): Add test once we have a real vectorized implementation. 110 // static void test_vectorized_concatenation() {} 111 112 static void test_concatenation_as_lvalue() 113 { 114 Tensor<int, 2> t1(2, 3); 115 Tensor<int, 2> t2(2, 3); 116 t1.setRandom(); 117 t2.setRandom(); 118 119 Tensor<int, 2> result(4, 3); 120 result.setRandom(); 121 t1.concatenate(t2, 0) = result; 122 123 for (int i = 0; i < 2; ++i) { 124 for (int j = 0; j < 3; ++j) { 125 VERIFY_IS_EQUAL(t1(i, j), result(i, j)); 126 VERIFY_IS_EQUAL(t2(i, j), result(i+2, j)); 127 } 128 } 129 } 130 131 132 EIGEN_DECLARE_TEST(cxx11_tensor_concatenation) 133 { 134 CALL_SUBTEST(test_dimension_failures<ColMajor>()); 135 CALL_SUBTEST(test_dimension_failures<RowMajor>()); 136 CALL_SUBTEST(test_static_dimension_failure<ColMajor>()); 137 CALL_SUBTEST(test_static_dimension_failure<RowMajor>()); 138 CALL_SUBTEST(test_simple_concatenation<ColMajor>()); 139 CALL_SUBTEST(test_simple_concatenation<RowMajor>()); 140 // CALL_SUBTEST(test_vectorized_concatenation()); 141 CALL_SUBTEST(test_concatenation_as_lvalue()); 142 143 }