cxx11_tensor_broadcasting.cpp (9209B)
1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #include "main.h" 11 12 #include <Eigen/CXX11/Tensor> 13 14 using Eigen::Tensor; 15 16 template <int DataLayout> 17 static void test_simple_broadcasting() 18 { 19 Tensor<float, 4, DataLayout> tensor(2,3,5,7); 20 tensor.setRandom(); 21 array<ptrdiff_t, 4> broadcasts; 22 broadcasts[0] = 1; 23 broadcasts[1] = 1; 24 broadcasts[2] = 1; 25 broadcasts[3] = 1; 26 27 Tensor<float, 4, DataLayout> no_broadcast; 28 no_broadcast = tensor.broadcast(broadcasts); 29 30 VERIFY_IS_EQUAL(no_broadcast.dimension(0), 2); 31 VERIFY_IS_EQUAL(no_broadcast.dimension(1), 3); 32 VERIFY_IS_EQUAL(no_broadcast.dimension(2), 5); 33 VERIFY_IS_EQUAL(no_broadcast.dimension(3), 7); 34 35 for (int i = 0; i < 2; ++i) { 36 for (int j = 0; j < 3; ++j) { 37 for (int k = 0; k < 5; ++k) { 38 for (int l = 0; l < 7; ++l) { 39 VERIFY_IS_EQUAL(tensor(i,j,k,l), no_broadcast(i,j,k,l)); 40 } 41 } 42 } 43 } 44 45 broadcasts[0] = 2; 46 broadcasts[1] = 3; 47 broadcasts[2] = 1; 48 broadcasts[3] = 4; 49 Tensor<float, 4, DataLayout> broadcast; 50 broadcast = tensor.broadcast(broadcasts); 51 52 VERIFY_IS_EQUAL(broadcast.dimension(0), 4); 53 VERIFY_IS_EQUAL(broadcast.dimension(1), 9); 54 VERIFY_IS_EQUAL(broadcast.dimension(2), 5); 55 VERIFY_IS_EQUAL(broadcast.dimension(3), 28); 56 57 for (int i = 0; i < 4; ++i) { 58 for (int j = 0; j < 9; ++j) { 59 for (int k = 0; k < 5; ++k) { 60 for (int l = 0; l < 28; ++l) { 61 VERIFY_IS_EQUAL(tensor(i%2,j%3,k%5,l%7), broadcast(i,j,k,l)); 62 } 63 } 64 } 65 } 66 } 67 68 69 template <int DataLayout> 70 static void test_vectorized_broadcasting() 71 { 72 Tensor<float, 3, DataLayout> tensor(8,3,5); 73 tensor.setRandom(); 74 array<ptrdiff_t, 3> broadcasts; 75 broadcasts[0] = 2; 76 broadcasts[1] = 3; 77 broadcasts[2] = 4; 78 79 Tensor<float, 3, DataLayout> broadcast; 80 broadcast = tensor.broadcast(broadcasts); 81 82 VERIFY_IS_EQUAL(broadcast.dimension(0), 16); 83 VERIFY_IS_EQUAL(broadcast.dimension(1), 9); 84 VERIFY_IS_EQUAL(broadcast.dimension(2), 20); 85 86 for (int i = 0; i < 16; ++i) { 87 for (int j = 0; j < 9; ++j) { 88 for (int k = 0; k < 20; ++k) { 89 VERIFY_IS_EQUAL(tensor(i%8,j%3,k%5), broadcast(i,j,k)); 90 } 91 } 92 } 93 94 #if EIGEN_HAS_VARIADIC_TEMPLATES 95 tensor.resize(11,3,5); 96 #else 97 array<Index, 3> new_dims; 98 new_dims[0] = 11; 99 new_dims[1] = 3; 100 new_dims[2] = 5; 101 tensor.resize(new_dims); 102 #endif 103 104 tensor.setRandom(); 105 broadcast = tensor.broadcast(broadcasts); 106 107 VERIFY_IS_EQUAL(broadcast.dimension(0), 22); 108 VERIFY_IS_EQUAL(broadcast.dimension(1), 9); 109 VERIFY_IS_EQUAL(broadcast.dimension(2), 20); 110 111 for (int i = 0; i < 22; ++i) { 112 for (int j = 0; j < 9; ++j) { 113 for (int k = 0; k < 20; ++k) { 114 VERIFY_IS_EQUAL(tensor(i%11,j%3,k%5), broadcast(i,j,k)); 115 } 116 } 117 } 118 } 119 120 121 template <int DataLayout> 122 static void test_static_broadcasting() 123 { 124 Tensor<float, 3, DataLayout> tensor(8,3,5); 125 tensor.setRandom(); 126 127 #if defined(EIGEN_HAS_INDEX_LIST) 128 Eigen::IndexList<Eigen::type2index<2>, Eigen::type2index<3>, Eigen::type2index<4>> broadcasts; 129 #else 130 Eigen::array<int, 3> broadcasts; 131 broadcasts[0] = 2; 132 broadcasts[1] = 3; 133 broadcasts[2] = 4; 134 #endif 135 136 Tensor<float, 3, DataLayout> broadcast; 137 broadcast = tensor.broadcast(broadcasts); 138 139 VERIFY_IS_EQUAL(broadcast.dimension(0), 16); 140 VERIFY_IS_EQUAL(broadcast.dimension(1), 9); 141 VERIFY_IS_EQUAL(broadcast.dimension(2), 20); 142 143 for (int i = 0; i < 16; ++i) { 144 for (int j = 0; j < 9; ++j) { 145 for (int k = 0; k < 20; ++k) { 146 VERIFY_IS_EQUAL(tensor(i%8,j%3,k%5), broadcast(i,j,k)); 147 } 148 } 149 } 150 151 #if EIGEN_HAS_VARIADIC_TEMPLATES 152 tensor.resize(11,3,5); 153 #else 154 array<Index, 3> new_dims; 155 new_dims[0] = 11; 156 new_dims[1] = 3; 157 new_dims[2] = 5; 158 tensor.resize(new_dims); 159 #endif 160 161 tensor.setRandom(); 162 broadcast = tensor.broadcast(broadcasts); 163 164 VERIFY_IS_EQUAL(broadcast.dimension(0), 22); 165 VERIFY_IS_EQUAL(broadcast.dimension(1), 9); 166 VERIFY_IS_EQUAL(broadcast.dimension(2), 20); 167 168 for (int i = 0; i < 22; ++i) { 169 for (int j = 0; j < 9; ++j) { 170 for (int k = 0; k < 20; ++k) { 171 VERIFY_IS_EQUAL(tensor(i%11,j%3,k%5), broadcast(i,j,k)); 172 } 173 } 174 } 175 } 176 177 178 template <int DataLayout> 179 static void test_fixed_size_broadcasting() 180 { 181 // Need to add a [] operator to the Size class for this to work 182 #if 0 183 Tensor<float, 1, DataLayout> t1(10); 184 t1.setRandom(); 185 TensorFixedSize<float, Sizes<1>, DataLayout> t2; 186 t2 = t2.constant(20.0f); 187 188 Tensor<float, 1, DataLayout> t3 = t1 + t2.broadcast(Eigen::array<int, 1>{{10}}); 189 for (int i = 0; i < 10; ++i) { 190 VERIFY_IS_APPROX(t3(i), t1(i) + t2(0)); 191 } 192 193 TensorMap<TensorFixedSize<float, Sizes<1>, DataLayout> > t4(t2.data(), {{1}}); 194 Tensor<float, 1, DataLayout> t5 = t1 + t4.broadcast(Eigen::array<int, 1>{{10}}); 195 for (int i = 0; i < 10; ++i) { 196 VERIFY_IS_APPROX(t5(i), t1(i) + t2(0)); 197 } 198 #endif 199 } 200 201 template <int DataLayout> 202 static void test_simple_broadcasting_one_by_n() 203 { 204 Tensor<float, 4, DataLayout> tensor(1,13,5,7); 205 tensor.setRandom(); 206 array<ptrdiff_t, 4> broadcasts; 207 broadcasts[0] = 9; 208 broadcasts[1] = 1; 209 broadcasts[2] = 1; 210 broadcasts[3] = 1; 211 Tensor<float, 4, DataLayout> broadcast; 212 broadcast = tensor.broadcast(broadcasts); 213 214 VERIFY_IS_EQUAL(broadcast.dimension(0), 9); 215 VERIFY_IS_EQUAL(broadcast.dimension(1), 13); 216 VERIFY_IS_EQUAL(broadcast.dimension(2), 5); 217 VERIFY_IS_EQUAL(broadcast.dimension(3), 7); 218 219 for (int i = 0; i < 9; ++i) { 220 for (int j = 0; j < 13; ++j) { 221 for (int k = 0; k < 5; ++k) { 222 for (int l = 0; l < 7; ++l) { 223 VERIFY_IS_EQUAL(tensor(i%1,j%13,k%5,l%7), broadcast(i,j,k,l)); 224 } 225 } 226 } 227 } 228 } 229 230 template <int DataLayout> 231 static void test_simple_broadcasting_n_by_one() 232 { 233 Tensor<float, 4, DataLayout> tensor(7,3,5,1); 234 tensor.setRandom(); 235 array<ptrdiff_t, 4> broadcasts; 236 broadcasts[0] = 1; 237 broadcasts[1] = 1; 238 broadcasts[2] = 1; 239 broadcasts[3] = 19; 240 Tensor<float, 4, DataLayout> broadcast; 241 broadcast = tensor.broadcast(broadcasts); 242 243 VERIFY_IS_EQUAL(broadcast.dimension(0), 7); 244 VERIFY_IS_EQUAL(broadcast.dimension(1), 3); 245 VERIFY_IS_EQUAL(broadcast.dimension(2), 5); 246 VERIFY_IS_EQUAL(broadcast.dimension(3), 19); 247 248 for (int i = 0; i < 7; ++i) { 249 for (int j = 0; j < 3; ++j) { 250 for (int k = 0; k < 5; ++k) { 251 for (int l = 0; l < 19; ++l) { 252 VERIFY_IS_EQUAL(tensor(i%7,j%3,k%5,l%1), broadcast(i,j,k,l)); 253 } 254 } 255 } 256 } 257 } 258 259 template <int DataLayout> 260 static void test_simple_broadcasting_one_by_n_by_one_1d() 261 { 262 Tensor<float, 3, DataLayout> tensor(1,7,1); 263 tensor.setRandom(); 264 array<ptrdiff_t, 3> broadcasts; 265 broadcasts[0] = 5; 266 broadcasts[1] = 1; 267 broadcasts[2] = 13; 268 Tensor<float, 3, DataLayout> broadcasted; 269 broadcasted = tensor.broadcast(broadcasts); 270 271 VERIFY_IS_EQUAL(broadcasted.dimension(0), 5); 272 VERIFY_IS_EQUAL(broadcasted.dimension(1), 7); 273 VERIFY_IS_EQUAL(broadcasted.dimension(2), 13); 274 275 for (int i = 0; i < 5; ++i) { 276 for (int j = 0; j < 7; ++j) { 277 for (int k = 0; k < 13; ++k) { 278 VERIFY_IS_EQUAL(tensor(0,j%7,0), broadcasted(i,j,k)); 279 } 280 } 281 } 282 } 283 284 template <int DataLayout> 285 static void test_simple_broadcasting_one_by_n_by_one_2d() 286 { 287 Tensor<float, 4, DataLayout> tensor(1,7,13,1); 288 tensor.setRandom(); 289 array<ptrdiff_t, 4> broadcasts; 290 broadcasts[0] = 5; 291 broadcasts[1] = 1; 292 broadcasts[2] = 1; 293 broadcasts[3] = 19; 294 Tensor<float, 4, DataLayout> broadcast; 295 broadcast = tensor.broadcast(broadcasts); 296 297 VERIFY_IS_EQUAL(broadcast.dimension(0), 5); 298 VERIFY_IS_EQUAL(broadcast.dimension(1), 7); 299 VERIFY_IS_EQUAL(broadcast.dimension(2), 13); 300 VERIFY_IS_EQUAL(broadcast.dimension(3), 19); 301 302 for (int i = 0; i < 5; ++i) { 303 for (int j = 0; j < 7; ++j) { 304 for (int k = 0; k < 13; ++k) { 305 for (int l = 0; l < 19; ++l) { 306 VERIFY_IS_EQUAL(tensor(0,j%7,k%13,0), broadcast(i,j,k,l)); 307 } 308 } 309 } 310 } 311 } 312 313 EIGEN_DECLARE_TEST(cxx11_tensor_broadcasting) 314 { 315 CALL_SUBTEST(test_simple_broadcasting<ColMajor>()); 316 CALL_SUBTEST(test_simple_broadcasting<RowMajor>()); 317 CALL_SUBTEST(test_vectorized_broadcasting<ColMajor>()); 318 CALL_SUBTEST(test_vectorized_broadcasting<RowMajor>()); 319 CALL_SUBTEST(test_static_broadcasting<ColMajor>()); 320 CALL_SUBTEST(test_static_broadcasting<RowMajor>()); 321 CALL_SUBTEST(test_fixed_size_broadcasting<ColMajor>()); 322 CALL_SUBTEST(test_fixed_size_broadcasting<RowMajor>()); 323 CALL_SUBTEST(test_simple_broadcasting_one_by_n<RowMajor>()); 324 CALL_SUBTEST(test_simple_broadcasting_n_by_one<RowMajor>()); 325 CALL_SUBTEST(test_simple_broadcasting_one_by_n<ColMajor>()); 326 CALL_SUBTEST(test_simple_broadcasting_n_by_one<ColMajor>()); 327 CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_1d<ColMajor>()); 328 CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_2d<ColMajor>()); 329 CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_1d<RowMajor>()); 330 CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_2d<RowMajor>()); 331 }