cart-elc

Source code for CART-ELC
git clone git://git.laack.co/cart-elc.git
Log | Files | Refs | README | LICENSE

cxx11_tensor_shuffling.cpp (7692B)


      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #include "main.h"
     11 
     12 #include <Eigen/CXX11/Tensor>
     13 
     14 using Eigen::Tensor;
     15 using Eigen::array;
     16 
     17 template <int DataLayout>
     18 static void test_simple_shuffling()
     19 {
     20   Tensor<float, 4, DataLayout> tensor(2,3,5,7);
     21   tensor.setRandom();
     22   array<ptrdiff_t, 4> shuffles;
     23   shuffles[0] = 0;
     24   shuffles[1] = 1;
     25   shuffles[2] = 2;
     26   shuffles[3] = 3;
     27 
     28   Tensor<float, 4, DataLayout> no_shuffle;
     29   no_shuffle = tensor.shuffle(shuffles);
     30 
     31   VERIFY_IS_EQUAL(no_shuffle.dimension(0), 2);
     32   VERIFY_IS_EQUAL(no_shuffle.dimension(1), 3);
     33   VERIFY_IS_EQUAL(no_shuffle.dimension(2), 5);
     34   VERIFY_IS_EQUAL(no_shuffle.dimension(3), 7);
     35 
     36   for (int i = 0; i < 2; ++i) {
     37     for (int j = 0; j < 3; ++j) {
     38       for (int k = 0; k < 5; ++k) {
     39         for (int l = 0; l < 7; ++l) {
     40           VERIFY_IS_EQUAL(tensor(i,j,k,l), no_shuffle(i,j,k,l));
     41         }
     42       }
     43     }
     44   }
     45 
     46   shuffles[0] = 2;
     47   shuffles[1] = 3;
     48   shuffles[2] = 1;
     49   shuffles[3] = 0;
     50   Tensor<float, 4, DataLayout> shuffle;
     51   shuffle = tensor.shuffle(shuffles);
     52 
     53   VERIFY_IS_EQUAL(shuffle.dimension(0), 5);
     54   VERIFY_IS_EQUAL(shuffle.dimension(1), 7);
     55   VERIFY_IS_EQUAL(shuffle.dimension(2), 3);
     56   VERIFY_IS_EQUAL(shuffle.dimension(3), 2);
     57 
     58   for (int i = 0; i < 2; ++i) {
     59     for (int j = 0; j < 3; ++j) {
     60       for (int k = 0; k < 5; ++k) {
     61         for (int l = 0; l < 7; ++l) {
     62           VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(k,l,j,i));
     63         }
     64       }
     65     }
     66   }
     67 }
     68 
     69 
     70 template <int DataLayout>
     71 static void test_expr_shuffling()
     72 {
     73   Tensor<float, 4, DataLayout> tensor(2,3,5,7);
     74   tensor.setRandom();
     75 
     76   array<ptrdiff_t, 4> shuffles;
     77   shuffles[0] = 2;
     78   shuffles[1] = 3;
     79   shuffles[2] = 1;
     80   shuffles[3] = 0;
     81   Tensor<float, 4, DataLayout> expected;
     82   expected = tensor.shuffle(shuffles);
     83 
     84   Tensor<float, 4, DataLayout> result(5, 7, 3, 2);
     85 
     86   array<ptrdiff_t, 4> src_slice_dim{{2, 3, 1, 7}};
     87   array<ptrdiff_t, 4> src_slice_start{{0, 0, 0, 0}};
     88   array<ptrdiff_t, 4> dst_slice_dim{{1, 7, 3, 2}};
     89   array<ptrdiff_t, 4> dst_slice_start{{0, 0, 0, 0}};
     90 
     91   for (int i = 0; i < 5; ++i) {
     92     result.slice(dst_slice_start, dst_slice_dim) =
     93         tensor.slice(src_slice_start, src_slice_dim).shuffle(shuffles);
     94     src_slice_start[2] += 1;
     95     dst_slice_start[0] += 1;
     96   }
     97 
     98   VERIFY_IS_EQUAL(result.dimension(0), 5);
     99   VERIFY_IS_EQUAL(result.dimension(1), 7);
    100   VERIFY_IS_EQUAL(result.dimension(2), 3);
    101   VERIFY_IS_EQUAL(result.dimension(3), 2);
    102 
    103   for (int i = 0; i < expected.dimension(0); ++i) {
    104     for (int j = 0; j < expected.dimension(1); ++j) {
    105       for (int k = 0; k < expected.dimension(2); ++k) {
    106         for (int l = 0; l < expected.dimension(3); ++l) {
    107           VERIFY_IS_EQUAL(result(i,j,k,l), expected(i,j,k,l));
    108         }
    109       }
    110     }
    111   }
    112 
    113   dst_slice_start[0] = 0;
    114   result.setRandom();
    115   for (int i = 0; i < 5; ++i) {
    116     result.slice(dst_slice_start, dst_slice_dim) =
    117         tensor.shuffle(shuffles).slice(dst_slice_start, dst_slice_dim);
    118     dst_slice_start[0] += 1;
    119   }
    120 
    121   for (int i = 0; i < expected.dimension(0); ++i) {
    122     for (int j = 0; j < expected.dimension(1); ++j) {
    123       for (int k = 0; k < expected.dimension(2); ++k) {
    124         for (int l = 0; l < expected.dimension(3); ++l) {
    125           VERIFY_IS_EQUAL(result(i,j,k,l), expected(i,j,k,l));
    126         }
    127       }
    128     }
    129   }
    130 }
    131 
    132 
    133 template <int DataLayout>
    134 static void test_shuffling_as_value()
    135 {
    136   Tensor<float, 4, DataLayout> tensor(2,3,5,7);
    137   tensor.setRandom();
    138   array<ptrdiff_t, 4> shuffles;
    139   shuffles[2] = 0;
    140   shuffles[3] = 1;
    141   shuffles[1] = 2;
    142   shuffles[0] = 3;
    143   Tensor<float, 4, DataLayout> shuffle(5,7,3,2);
    144   shuffle.shuffle(shuffles) = tensor;
    145 
    146   VERIFY_IS_EQUAL(shuffle.dimension(0), 5);
    147   VERIFY_IS_EQUAL(shuffle.dimension(1), 7);
    148   VERIFY_IS_EQUAL(shuffle.dimension(2), 3);
    149   VERIFY_IS_EQUAL(shuffle.dimension(3), 2);
    150 
    151   for (int i = 0; i < 2; ++i) {
    152     for (int j = 0; j < 3; ++j) {
    153       for (int k = 0; k < 5; ++k) {
    154         for (int l = 0; l < 7; ++l) {
    155           VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(k,l,j,i));
    156         }
    157       }
    158     }
    159   }
    160 
    161   array<ptrdiff_t, 4> no_shuffle;
    162   no_shuffle[0] = 0;
    163   no_shuffle[1] = 1;
    164   no_shuffle[2] = 2;
    165   no_shuffle[3] = 3;
    166   Tensor<float, 4, DataLayout> shuffle2(5,7,3,2);
    167   shuffle2.shuffle(shuffles) = tensor.shuffle(no_shuffle);
    168   for (int i = 0; i < 5; ++i) {
    169     for (int j = 0; j < 7; ++j) {
    170       for (int k = 0; k < 3; ++k) {
    171         for (int l = 0; l < 2; ++l) {
    172           VERIFY_IS_EQUAL(shuffle2(i,j,k,l), shuffle(i,j,k,l));
    173         }
    174       }
    175     }
    176   }
    177 }
    178 
    179 
    180 template <int DataLayout>
    181 static void test_shuffle_unshuffle()
    182 {
    183   Tensor<float, 4, DataLayout> tensor(2,3,5,7);
    184   tensor.setRandom();
    185 
    186   // Choose a random permutation.
    187   array<ptrdiff_t, 4> shuffles;
    188   for (int i = 0; i < 4; ++i) {
    189     shuffles[i] = i;
    190   }
    191   array<ptrdiff_t, 4> shuffles_inverse;
    192   for (int i = 0; i < 4; ++i) {
    193     const ptrdiff_t index = internal::random<ptrdiff_t>(i, 3);
    194     shuffles_inverse[shuffles[index]] = i;
    195     std::swap(shuffles[i], shuffles[index]);
    196   }
    197 
    198   Tensor<float, 4, DataLayout> shuffle;
    199   shuffle = tensor.shuffle(shuffles).shuffle(shuffles_inverse);
    200 
    201   VERIFY_IS_EQUAL(shuffle.dimension(0), 2);
    202   VERIFY_IS_EQUAL(shuffle.dimension(1), 3);
    203   VERIFY_IS_EQUAL(shuffle.dimension(2), 5);
    204   VERIFY_IS_EQUAL(shuffle.dimension(3), 7);
    205 
    206   for (int i = 0; i < 2; ++i) {
    207     for (int j = 0; j < 3; ++j) {
    208       for (int k = 0; k < 5; ++k) {
    209         for (int l = 0; l < 7; ++l) {
    210           VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(i,j,k,l));
    211         }
    212       }
    213     }
    214   }
    215 }
    216 
    217 
    218 template <int DataLayout>
    219 static void test_empty_shuffling()
    220 {
    221   Tensor<float, 4, DataLayout> tensor(2,3,0,7);
    222   tensor.setRandom();
    223   array<ptrdiff_t, 4> shuffles;
    224   shuffles[0] = 0;
    225   shuffles[1] = 1;
    226   shuffles[2] = 2;
    227   shuffles[3] = 3;
    228 
    229   Tensor<float, 4, DataLayout> no_shuffle;
    230   no_shuffle = tensor.shuffle(shuffles);
    231 
    232   VERIFY_IS_EQUAL(no_shuffle.dimension(0), 2);
    233   VERIFY_IS_EQUAL(no_shuffle.dimension(1), 3);
    234   VERIFY_IS_EQUAL(no_shuffle.dimension(2), 0);
    235   VERIFY_IS_EQUAL(no_shuffle.dimension(3), 7);
    236 
    237   for (int i = 0; i < 2; ++i) {
    238     for (int j = 0; j < 3; ++j) {
    239       for (int k = 0; k < 0; ++k) {
    240         for (int l = 0; l < 7; ++l) {
    241           VERIFY_IS_EQUAL(tensor(i,j,k,l), no_shuffle(i,j,k,l));
    242         }
    243       }
    244     }
    245   }
    246 
    247   shuffles[0] = 2;
    248   shuffles[1] = 3;
    249   shuffles[2] = 1;
    250   shuffles[3] = 0;
    251   Tensor<float, 4, DataLayout> shuffle;
    252   shuffle = tensor.shuffle(shuffles);
    253 
    254   VERIFY_IS_EQUAL(shuffle.dimension(0), 0);
    255   VERIFY_IS_EQUAL(shuffle.dimension(1), 7);
    256   VERIFY_IS_EQUAL(shuffle.dimension(2), 3);
    257   VERIFY_IS_EQUAL(shuffle.dimension(3), 2);
    258 
    259   for (int i = 0; i < 2; ++i) {
    260     for (int j = 0; j < 3; ++j) {
    261       for (int k = 0; k < 0; ++k) {
    262         for (int l = 0; l < 7; ++l) {
    263           VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(k,l,j,i));
    264         }
    265       }
    266     }
    267   }
    268 }
    269 
    270 
// Test entry point for cxx11_tensor_shuffling: runs each shuffling scenario
// once with ColMajor and once with RowMajor storage.
// NOTE(review): keep the call order stable. The subtests call setRandom()
// and internal::random, so reordering presumably changes the random values
// each one sees; confirm against the test harness in main.h.
EIGEN_DECLARE_TEST(cxx11_tensor_shuffling)
{
  // Plain shuffle: identity and a fixed permutation.
  CALL_SUBTEST(test_simple_shuffling<ColMajor>());
  CALL_SUBTEST(test_simple_shuffling<RowMajor>());
  // Shuffle combined with slicing, in both orders.
  CALL_SUBTEST(test_expr_shuffling<ColMajor>());
  CALL_SUBTEST(test_expr_shuffling<RowMajor>());
  // Shuffle used as an assignment target.
  CALL_SUBTEST(test_shuffling_as_value<ColMajor>());
  CALL_SUBTEST(test_shuffling_as_value<RowMajor>());
  // Random permutation followed by its inverse.
  CALL_SUBTEST(test_shuffle_unshuffle<ColMajor>());
  CALL_SUBTEST(test_shuffle_unshuffle<RowMajor>());
  // Tensors with a zero-extent dimension.
  CALL_SUBTEST(test_empty_shuffling<ColMajor>());
  CALL_SUBTEST(test_empty_shuffling<RowMajor>());
}
    283 }