cart-elc

Source code for CART-ELC
git clone git://git.laack.co/cart-elc.git
Log | Files | Refs | README | LICENSE

cxx11_tensor_broadcasting.cpp (9209B)


      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #include "main.h"
     11 
     12 #include <Eigen/CXX11/Tensor>
     13 
     14 using Eigen::Tensor;
     15 
     16 template <int DataLayout>
     17 static void test_simple_broadcasting()
     18 {
     19   Tensor<float, 4, DataLayout> tensor(2,3,5,7);
     20   tensor.setRandom();
     21   array<ptrdiff_t, 4> broadcasts;
     22   broadcasts[0] = 1;
     23   broadcasts[1] = 1;
     24   broadcasts[2] = 1;
     25   broadcasts[3] = 1;
     26 
     27   Tensor<float, 4, DataLayout> no_broadcast;
     28   no_broadcast = tensor.broadcast(broadcasts);
     29 
     30   VERIFY_IS_EQUAL(no_broadcast.dimension(0), 2);
     31   VERIFY_IS_EQUAL(no_broadcast.dimension(1), 3);
     32   VERIFY_IS_EQUAL(no_broadcast.dimension(2), 5);
     33   VERIFY_IS_EQUAL(no_broadcast.dimension(3), 7);
     34 
     35   for (int i = 0; i < 2; ++i) {
     36     for (int j = 0; j < 3; ++j) {
     37       for (int k = 0; k < 5; ++k) {
     38         for (int l = 0; l < 7; ++l) {
     39           VERIFY_IS_EQUAL(tensor(i,j,k,l), no_broadcast(i,j,k,l));
     40         }
     41       }
     42     }
     43   }
     44 
     45   broadcasts[0] = 2;
     46   broadcasts[1] = 3;
     47   broadcasts[2] = 1;
     48   broadcasts[3] = 4;
     49   Tensor<float, 4, DataLayout> broadcast;
     50   broadcast = tensor.broadcast(broadcasts);
     51 
     52   VERIFY_IS_EQUAL(broadcast.dimension(0), 4);
     53   VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
     54   VERIFY_IS_EQUAL(broadcast.dimension(2), 5);
     55   VERIFY_IS_EQUAL(broadcast.dimension(3), 28);
     56 
     57   for (int i = 0; i < 4; ++i) {
     58     for (int j = 0; j < 9; ++j) {
     59       for (int k = 0; k < 5; ++k) {
     60         for (int l = 0; l < 28; ++l) {
     61           VERIFY_IS_EQUAL(tensor(i%2,j%3,k%5,l%7), broadcast(i,j,k,l));
     62         }
     63       }
     64     }
     65   }
     66 }
     67 
     68 
     69 template <int DataLayout>
     70 static void test_vectorized_broadcasting()
     71 {
     72   Tensor<float, 3, DataLayout> tensor(8,3,5);
     73   tensor.setRandom();
     74   array<ptrdiff_t, 3> broadcasts;
     75   broadcasts[0] = 2;
     76   broadcasts[1] = 3;
     77   broadcasts[2] = 4;
     78 
     79   Tensor<float, 3, DataLayout> broadcast;
     80   broadcast = tensor.broadcast(broadcasts);
     81 
     82   VERIFY_IS_EQUAL(broadcast.dimension(0), 16);
     83   VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
     84   VERIFY_IS_EQUAL(broadcast.dimension(2), 20);
     85 
     86   for (int i = 0; i < 16; ++i) {
     87     for (int j = 0; j < 9; ++j) {
     88       for (int k = 0; k < 20; ++k) {
     89         VERIFY_IS_EQUAL(tensor(i%8,j%3,k%5), broadcast(i,j,k));
     90       }
     91     }
     92   }
     93 
     94 #if EIGEN_HAS_VARIADIC_TEMPLATES
     95   tensor.resize(11,3,5);
     96 #else
     97   array<Index, 3> new_dims;
     98   new_dims[0] = 11;
     99   new_dims[1] = 3;
    100   new_dims[2] = 5;
    101   tensor.resize(new_dims);
    102 #endif
    103 
    104   tensor.setRandom();
    105   broadcast = tensor.broadcast(broadcasts);
    106 
    107   VERIFY_IS_EQUAL(broadcast.dimension(0), 22);
    108   VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
    109   VERIFY_IS_EQUAL(broadcast.dimension(2), 20);
    110 
    111   for (int i = 0; i < 22; ++i) {
    112     for (int j = 0; j < 9; ++j) {
    113       for (int k = 0; k < 20; ++k) {
    114         VERIFY_IS_EQUAL(tensor(i%11,j%3,k%5), broadcast(i,j,k));
    115       }
    116     }
    117   }
    118 }
    119 
    120 
    121 template <int DataLayout>
    122 static void test_static_broadcasting()
    123 {
    124   Tensor<float, 3, DataLayout> tensor(8,3,5);
    125   tensor.setRandom();
    126 
    127 #if defined(EIGEN_HAS_INDEX_LIST)
    128   Eigen::IndexList<Eigen::type2index<2>, Eigen::type2index<3>, Eigen::type2index<4>> broadcasts;
    129 #else
    130   Eigen::array<int, 3> broadcasts;
    131   broadcasts[0] = 2;
    132   broadcasts[1] = 3;
    133   broadcasts[2] = 4;
    134 #endif
    135 
    136   Tensor<float, 3, DataLayout> broadcast;
    137   broadcast = tensor.broadcast(broadcasts);
    138 
    139   VERIFY_IS_EQUAL(broadcast.dimension(0), 16);
    140   VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
    141   VERIFY_IS_EQUAL(broadcast.dimension(2), 20);
    142 
    143   for (int i = 0; i < 16; ++i) {
    144     for (int j = 0; j < 9; ++j) {
    145       for (int k = 0; k < 20; ++k) {
    146         VERIFY_IS_EQUAL(tensor(i%8,j%3,k%5), broadcast(i,j,k));
    147       }
    148     }
    149   }
    150 
    151 #if EIGEN_HAS_VARIADIC_TEMPLATES
    152   tensor.resize(11,3,5);
    153 #else
    154   array<Index, 3> new_dims;
    155   new_dims[0] = 11;
    156   new_dims[1] = 3;
    157   new_dims[2] = 5;
    158   tensor.resize(new_dims);
    159 #endif
    160 
    161   tensor.setRandom();
    162   broadcast = tensor.broadcast(broadcasts);
    163 
    164   VERIFY_IS_EQUAL(broadcast.dimension(0), 22);
    165   VERIFY_IS_EQUAL(broadcast.dimension(1), 9);
    166   VERIFY_IS_EQUAL(broadcast.dimension(2), 20);
    167 
    168   for (int i = 0; i < 22; ++i) {
    169     for (int j = 0; j < 9; ++j) {
    170       for (int k = 0; k < 20; ++k) {
    171         VERIFY_IS_EQUAL(tensor(i%11,j%3,k%5), broadcast(i,j,k));
    172       }
    173     }
    174   }
    175 }
    176 
    177 
    178 template <int DataLayout>
    179 static void test_fixed_size_broadcasting()
    180 {
    181   // Need to add a [] operator to the Size class for this to work
    182 #if 0
    183   Tensor<float, 1, DataLayout> t1(10);
    184   t1.setRandom();
    185   TensorFixedSize<float, Sizes<1>, DataLayout> t2;
    186   t2 = t2.constant(20.0f);
    187 
    188   Tensor<float, 1, DataLayout> t3 = t1 + t2.broadcast(Eigen::array<int, 1>{{10}});
    189   for (int i = 0; i < 10; ++i) {
    190     VERIFY_IS_APPROX(t3(i), t1(i) + t2(0));
    191   }
    192 
    193   TensorMap<TensorFixedSize<float, Sizes<1>, DataLayout> > t4(t2.data(), {{1}});
    194   Tensor<float, 1, DataLayout> t5 = t1 + t4.broadcast(Eigen::array<int, 1>{{10}});
    195   for (int i = 0; i < 10; ++i) {
    196     VERIFY_IS_APPROX(t5(i), t1(i) + t2(0));
    197   }
    198 #endif
    199 }
    200 
    201 template <int DataLayout>
    202 static void test_simple_broadcasting_one_by_n()
    203 {
    204   Tensor<float, 4, DataLayout> tensor(1,13,5,7);
    205   tensor.setRandom();
    206   array<ptrdiff_t, 4> broadcasts;
    207   broadcasts[0] = 9;
    208   broadcasts[1] = 1;
    209   broadcasts[2] = 1;
    210   broadcasts[3] = 1;
    211   Tensor<float, 4, DataLayout> broadcast;
    212   broadcast = tensor.broadcast(broadcasts);
    213 
    214   VERIFY_IS_EQUAL(broadcast.dimension(0), 9);
    215   VERIFY_IS_EQUAL(broadcast.dimension(1), 13);
    216   VERIFY_IS_EQUAL(broadcast.dimension(2), 5);
    217   VERIFY_IS_EQUAL(broadcast.dimension(3), 7);
    218 
    219   for (int i = 0; i < 9; ++i) {
    220     for (int j = 0; j < 13; ++j) {
    221       for (int k = 0; k < 5; ++k) {
    222         for (int l = 0; l < 7; ++l) {
    223           VERIFY_IS_EQUAL(tensor(i%1,j%13,k%5,l%7), broadcast(i,j,k,l));
    224         }
    225       }
    226     }
    227   }
    228 }
    229 
    230 template <int DataLayout>
    231 static void test_simple_broadcasting_n_by_one()
    232 {
    233   Tensor<float, 4, DataLayout> tensor(7,3,5,1);
    234   tensor.setRandom();
    235   array<ptrdiff_t, 4> broadcasts;
    236   broadcasts[0] = 1;
    237   broadcasts[1] = 1;
    238   broadcasts[2] = 1;
    239   broadcasts[3] = 19;
    240   Tensor<float, 4, DataLayout> broadcast;
    241   broadcast = tensor.broadcast(broadcasts);
    242 
    243   VERIFY_IS_EQUAL(broadcast.dimension(0), 7);
    244   VERIFY_IS_EQUAL(broadcast.dimension(1), 3);
    245   VERIFY_IS_EQUAL(broadcast.dimension(2), 5);
    246   VERIFY_IS_EQUAL(broadcast.dimension(3), 19);
    247 
    248   for (int i = 0; i < 7; ++i) {
    249     for (int j = 0; j < 3; ++j) {
    250       for (int k = 0; k < 5; ++k) {
    251         for (int l = 0; l < 19; ++l) {
    252           VERIFY_IS_EQUAL(tensor(i%7,j%3,k%5,l%1), broadcast(i,j,k,l));
    253         }
    254       }
    255     }
    256   }
    257 }
    258 
    259 template <int DataLayout>
    260 static void test_simple_broadcasting_one_by_n_by_one_1d()
    261 {
    262   Tensor<float, 3, DataLayout> tensor(1,7,1);
    263   tensor.setRandom();
    264   array<ptrdiff_t, 3> broadcasts;
    265   broadcasts[0] = 5;
    266   broadcasts[1] = 1;
    267   broadcasts[2] = 13;
    268   Tensor<float, 3, DataLayout> broadcasted;
    269   broadcasted = tensor.broadcast(broadcasts);
    270 
    271   VERIFY_IS_EQUAL(broadcasted.dimension(0), 5);
    272   VERIFY_IS_EQUAL(broadcasted.dimension(1), 7);
    273   VERIFY_IS_EQUAL(broadcasted.dimension(2), 13);
    274 
    275   for (int i = 0; i < 5; ++i) {
    276     for (int j = 0; j < 7; ++j) {
    277       for (int k = 0; k < 13; ++k) {
    278         VERIFY_IS_EQUAL(tensor(0,j%7,0), broadcasted(i,j,k));
    279       }
    280     }
    281   }
    282 }
    283 
    284 template <int DataLayout>
    285 static void test_simple_broadcasting_one_by_n_by_one_2d()
    286 {
    287   Tensor<float, 4, DataLayout> tensor(1,7,13,1);
    288   tensor.setRandom();
    289   array<ptrdiff_t, 4> broadcasts;
    290   broadcasts[0] = 5;
    291   broadcasts[1] = 1;
    292   broadcasts[2] = 1;
    293   broadcasts[3] = 19;
    294   Tensor<float, 4, DataLayout> broadcast;
    295   broadcast = tensor.broadcast(broadcasts);
    296 
    297   VERIFY_IS_EQUAL(broadcast.dimension(0), 5);
    298   VERIFY_IS_EQUAL(broadcast.dimension(1), 7);
    299   VERIFY_IS_EQUAL(broadcast.dimension(2), 13);
    300   VERIFY_IS_EQUAL(broadcast.dimension(3), 19);
    301 
    302   for (int i = 0; i < 5; ++i) {
    303     for (int j = 0; j < 7; ++j) {
    304       for (int k = 0; k < 13; ++k) {
    305         for (int l = 0; l < 19; ++l) {
    306           VERIFY_IS_EQUAL(tensor(0,j%7,k%13,0), broadcast(i,j,k,l));
    307         }
    308       }
    309     }
    310   }
    311 }
    312 
    313 EIGEN_DECLARE_TEST(cxx11_tensor_broadcasting)
    314 {
    315   CALL_SUBTEST(test_simple_broadcasting<ColMajor>());
    316   CALL_SUBTEST(test_simple_broadcasting<RowMajor>());
    317   CALL_SUBTEST(test_vectorized_broadcasting<ColMajor>());
    318   CALL_SUBTEST(test_vectorized_broadcasting<RowMajor>());
    319   CALL_SUBTEST(test_static_broadcasting<ColMajor>());
    320   CALL_SUBTEST(test_static_broadcasting<RowMajor>());
    321   CALL_SUBTEST(test_fixed_size_broadcasting<ColMajor>());
    322   CALL_SUBTEST(test_fixed_size_broadcasting<RowMajor>());
    323   CALL_SUBTEST(test_simple_broadcasting_one_by_n<RowMajor>());
    324   CALL_SUBTEST(test_simple_broadcasting_n_by_one<RowMajor>());
    325   CALL_SUBTEST(test_simple_broadcasting_one_by_n<ColMajor>());
    326   CALL_SUBTEST(test_simple_broadcasting_n_by_one<ColMajor>());
    327   CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_1d<ColMajor>());
    328   CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_2d<ColMajor>());
    329   CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_1d<RowMajor>());
    330   CALL_SUBTEST(test_simple_broadcasting_one_by_n_by_one_2d<RowMajor>());
    331 }