cxx11_tensor_trace.cpp (5129B)
1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2017 Gagan Goel <gagan.nith@gmail.com> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #include "main.h" 11 12 #include <Eigen/CXX11/Tensor> 13 14 using Eigen::Tensor; 15 using Eigen::array; 16 17 template <int DataLayout> 18 static void test_0D_trace() { 19 Tensor<float, 0, DataLayout> tensor; 20 tensor.setRandom(); 21 array<ptrdiff_t, 0> dims; 22 Tensor<float, 0, DataLayout> result = tensor.trace(dims); 23 VERIFY_IS_EQUAL(result(), tensor()); 24 } 25 26 27 template <int DataLayout> 28 static void test_all_dimensions_trace() { 29 Tensor<float, 3, DataLayout> tensor1(5, 5, 5); 30 tensor1.setRandom(); 31 Tensor<float, 0, DataLayout> result1 = tensor1.trace(); 32 VERIFY_IS_EQUAL(result1.rank(), 0); 33 float sum = 0.0f; 34 for (int i = 0; i < 5; ++i) { 35 sum += tensor1(i, i, i); 36 } 37 VERIFY_IS_EQUAL(result1(), sum); 38 39 Tensor<float, 5, DataLayout> tensor2(7, 7, 7, 7, 7); 40 tensor2.setRandom(); 41 array<ptrdiff_t, 5> dims = { { 2, 1, 0, 3, 4 } }; 42 Tensor<float, 0, DataLayout> result2 = tensor2.trace(dims); 43 VERIFY_IS_EQUAL(result2.rank(), 0); 44 sum = 0.0f; 45 for (int i = 0; i < 7; ++i) { 46 sum += tensor2(i, i, i, i, i); 47 } 48 VERIFY_IS_EQUAL(result2(), sum); 49 } 50 51 52 template <int DataLayout> 53 static void test_simple_trace() { 54 Tensor<float, 3, DataLayout> tensor1(3, 5, 3); 55 tensor1.setRandom(); 56 array<ptrdiff_t, 2> dims1 = { { 0, 2 } }; 57 Tensor<float, 1, DataLayout> result1 = tensor1.trace(dims1); 58 VERIFY_IS_EQUAL(result1.rank(), 1); 59 VERIFY_IS_EQUAL(result1.dimension(0), 5); 60 float sum = 0.0f; 61 for (int i = 0; i < 5; ++i) { 62 sum = 0.0f; 63 for (int j = 0; j < 3; ++j) { 64 sum += tensor1(j, i, j); 65 } 66 VERIFY_IS_EQUAL(result1(i), sum); 67 } 68 69 Tensor<float, 4, DataLayout> tensor2(5, 5, 7, 7); 70 tensor2.setRandom(); 71 array<ptrdiff_t, 2> dims2 = { { 2, 3 } }; 72 Tensor<float, 2, DataLayout> result2 = tensor2.trace(dims2); 73 VERIFY_IS_EQUAL(result2.rank(), 2); 74 VERIFY_IS_EQUAL(result2.dimension(0), 5); 75 VERIFY_IS_EQUAL(result2.dimension(1), 5); 76 for (int i = 0; i < 5; ++i) { 77 for (int j = 0; j < 5; ++j) { 78 sum = 0.0f; 79 for (int k = 0; k < 7; ++k) { 80 sum += tensor2(i, j, k, k); 81 } 82 VERIFY_IS_EQUAL(result2(i, j), sum); 83 } 84 } 85 86 array<ptrdiff_t, 2> dims3 = { { 1, 0 } }; 87 Tensor<float, 2, DataLayout> result3 = tensor2.trace(dims3); 88 VERIFY_IS_EQUAL(result3.rank(), 2); 89 VERIFY_IS_EQUAL(result3.dimension(0), 7); 90 VERIFY_IS_EQUAL(result3.dimension(1), 7); 91 for (int i = 0; i < 7; ++i) { 92 for (int j = 0; j < 7; ++j) { 93 sum = 0.0f; 94 for (int k = 0; k < 5; ++k) { 95 sum += tensor2(k, k, i, j); 96 } 97 VERIFY_IS_EQUAL(result3(i, j), sum); 98 } 99 } 100 101 Tensor<float, 5, DataLayout> tensor3(3, 7, 3, 7, 3); 102 tensor3.setRandom(); 103 array<ptrdiff_t, 3> dims4 = { { 0, 2, 4 } }; 104 Tensor<float, 2, DataLayout> result4 = tensor3.trace(dims4); 105 VERIFY_IS_EQUAL(result4.rank(), 2); 106 VERIFY_IS_EQUAL(result4.dimension(0), 7); 107 VERIFY_IS_EQUAL(result4.dimension(1), 7); 108 for (int i = 0; i < 7; ++i) { 109 for (int j = 0; j < 7; ++j) { 110 sum = 0.0f; 111 for (int k = 0; k < 3; ++k) { 112 sum += tensor3(k, i, k, j, k); 113 } 114 VERIFY_IS_EQUAL(result4(i, j), sum); 115 } 116 } 117 118 Tensor<float, 5, DataLayout> tensor4(3, 7, 4, 7, 5); 119 tensor4.setRandom(); 120 array<ptrdiff_t, 2> dims5 = { { 1, 3 } }; 121 Tensor<float, 3, DataLayout> result5 = tensor4.trace(dims5); 122 VERIFY_IS_EQUAL(result5.rank(), 3); 123 VERIFY_IS_EQUAL(result5.dimension(0), 3); 124 VERIFY_IS_EQUAL(result5.dimension(1), 4); 125 VERIFY_IS_EQUAL(result5.dimension(2), 5); 126 for (int i = 0; i < 3; ++i) { 127 for (int j = 0; j < 4; ++j) { 128 for (int k = 0; k < 5; ++k) { 129 sum = 0.0f; 130 for (int l = 0; l < 7; ++l) { 131 sum += tensor4(i, l, j, l, k); 132 } 133 VERIFY_IS_EQUAL(result5(i, j, k), sum); 134 } 135 } 136 } 137 } 138 139 140 template<int DataLayout> 141 static void test_trace_in_expr() { 142 Tensor<float, 4, DataLayout> tensor(2, 3, 5, 3); 143 tensor.setRandom(); 144 array<ptrdiff_t, 2> dims = { { 1, 3 } }; 145 Tensor<float, 2, DataLayout> result(2, 5); 146 result = result.constant(1.0f) - tensor.trace(dims); 147 VERIFY_IS_EQUAL(result.rank(), 2); 148 VERIFY_IS_EQUAL(result.dimension(0), 2); 149 VERIFY_IS_EQUAL(result.dimension(1), 5); 150 float sum = 0.0f; 151 for (int i = 0; i < 2; ++i) { 152 for (int j = 0; j < 5; ++j) { 153 sum = 0.0f; 154 for (int k = 0; k < 3; ++k) { 155 sum += tensor(i, k, j, k); 156 } 157 VERIFY_IS_EQUAL(result(i, j), 1.0f - sum); 158 } 159 } 160 } 161 162 163 EIGEN_DECLARE_TEST(cxx11_tensor_trace) { 164 CALL_SUBTEST(test_0D_trace<ColMajor>()); 165 CALL_SUBTEST(test_0D_trace<RowMajor>()); 166 CALL_SUBTEST(test_all_dimensions_trace<ColMajor>()); 167 CALL_SUBTEST(test_all_dimensions_trace<RowMajor>()); 168 CALL_SUBTEST(test_simple_trace<ColMajor>()); 169 CALL_SUBTEST(test_simple_trace<RowMajor>()); 170 CALL_SUBTEST(test_trace_in_expr<ColMajor>()); 171 CALL_SUBTEST(test_trace_in_expr<RowMajor>()); 172 }