cart-elc

Source code for CART-ELC
git clone git://git.laack.co/cart-elc.git
Log | Files | Refs | README | LICENSE

cxx11_tensor_trace.cpp (5129B)


      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2017 Gagan Goel <gagan.nith@gmail.com>
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #include "main.h"
     11 
     12 #include <Eigen/CXX11/Tensor>
     13 
     14 using Eigen::Tensor;
     15 using Eigen::array;
     16 
     17 template <int DataLayout>
     18 static void test_0D_trace() {
     19   Tensor<float, 0, DataLayout> tensor;
     20   tensor.setRandom();
     21   array<ptrdiff_t, 0> dims;
     22   Tensor<float, 0, DataLayout> result = tensor.trace(dims);
     23   VERIFY_IS_EQUAL(result(), tensor());
     24 }
     25 
     26 
     27 template <int DataLayout>
     28 static void test_all_dimensions_trace() {
     29   Tensor<float, 3, DataLayout> tensor1(5, 5, 5);
     30   tensor1.setRandom();
     31   Tensor<float, 0, DataLayout> result1 = tensor1.trace();
     32   VERIFY_IS_EQUAL(result1.rank(), 0);
     33   float sum = 0.0f;
     34   for (int i = 0; i < 5; ++i) {
     35     sum += tensor1(i, i, i);
     36   }
     37   VERIFY_IS_EQUAL(result1(), sum);
     38 
     39   Tensor<float, 5, DataLayout> tensor2(7, 7, 7, 7, 7);
     40   tensor2.setRandom();
     41   array<ptrdiff_t, 5> dims = { { 2, 1, 0, 3, 4 } };
     42   Tensor<float, 0, DataLayout> result2 = tensor2.trace(dims);
     43   VERIFY_IS_EQUAL(result2.rank(), 0);
     44   sum = 0.0f;
     45   for (int i = 0; i < 7; ++i) {
     46     sum += tensor2(i, i, i, i, i);
     47   }
     48   VERIFY_IS_EQUAL(result2(), sum);
     49 }
     50 
     51 
     52 template <int DataLayout>
     53 static void test_simple_trace() {
     54   Tensor<float, 3, DataLayout> tensor1(3, 5, 3);
     55   tensor1.setRandom();
     56   array<ptrdiff_t, 2> dims1 = { { 0, 2 } };
     57   Tensor<float, 1, DataLayout> result1 = tensor1.trace(dims1);
     58   VERIFY_IS_EQUAL(result1.rank(), 1);
     59   VERIFY_IS_EQUAL(result1.dimension(0), 5);
     60   float sum = 0.0f;
     61   for (int i = 0; i < 5; ++i) {
     62     sum = 0.0f;
     63     for (int j = 0; j < 3; ++j) {
     64       sum += tensor1(j, i, j);
     65     }
     66     VERIFY_IS_EQUAL(result1(i), sum);
     67   }
     68 
     69   Tensor<float, 4, DataLayout> tensor2(5, 5, 7, 7);
     70   tensor2.setRandom();
     71   array<ptrdiff_t, 2> dims2 = { { 2, 3 } };
     72   Tensor<float, 2, DataLayout> result2 = tensor2.trace(dims2);
     73   VERIFY_IS_EQUAL(result2.rank(), 2);
     74   VERIFY_IS_EQUAL(result2.dimension(0), 5);
     75   VERIFY_IS_EQUAL(result2.dimension(1), 5);
     76   for (int i = 0; i < 5; ++i) {
     77     for (int j = 0; j < 5; ++j) {
     78       sum = 0.0f;
     79       for (int k = 0; k < 7; ++k) {
     80         sum += tensor2(i, j, k, k);
     81       }
     82       VERIFY_IS_EQUAL(result2(i, j), sum);
     83     }
     84   }
     85 
     86   array<ptrdiff_t, 2> dims3 = { { 1, 0 } };
     87   Tensor<float, 2, DataLayout> result3 = tensor2.trace(dims3);
     88   VERIFY_IS_EQUAL(result3.rank(), 2);
     89   VERIFY_IS_EQUAL(result3.dimension(0), 7);
     90   VERIFY_IS_EQUAL(result3.dimension(1), 7);
     91   for (int i = 0; i < 7; ++i) {
     92     for (int j = 0; j < 7; ++j) {
     93       sum = 0.0f;
     94       for (int k = 0; k < 5; ++k) {
     95         sum += tensor2(k, k, i, j);
     96       }
     97       VERIFY_IS_EQUAL(result3(i, j), sum);
     98     }
     99   }
    100 
    101   Tensor<float, 5, DataLayout> tensor3(3, 7, 3, 7, 3);
    102   tensor3.setRandom();
    103   array<ptrdiff_t, 3> dims4 = { { 0, 2, 4 } };
    104   Tensor<float, 2, DataLayout> result4 = tensor3.trace(dims4);
    105   VERIFY_IS_EQUAL(result4.rank(), 2);
    106   VERIFY_IS_EQUAL(result4.dimension(0), 7);
    107   VERIFY_IS_EQUAL(result4.dimension(1), 7);
    108   for (int i = 0; i < 7; ++i) {
    109     for (int j = 0; j < 7; ++j) {
    110       sum = 0.0f;
    111       for (int k = 0; k < 3; ++k) {
    112         sum += tensor3(k, i, k, j, k);
    113       }
    114       VERIFY_IS_EQUAL(result4(i, j), sum);
    115     }
    116   }
    117 
    118   Tensor<float, 5, DataLayout> tensor4(3, 7, 4, 7, 5);
    119   tensor4.setRandom();
    120   array<ptrdiff_t, 2> dims5 = { { 1, 3 } };
    121   Tensor<float, 3, DataLayout> result5 = tensor4.trace(dims5);
    122   VERIFY_IS_EQUAL(result5.rank(), 3);
    123   VERIFY_IS_EQUAL(result5.dimension(0), 3);
    124   VERIFY_IS_EQUAL(result5.dimension(1), 4);
    125   VERIFY_IS_EQUAL(result5.dimension(2), 5);
    126   for (int i = 0; i < 3; ++i) {
    127     for (int j = 0; j < 4; ++j) {
    128       for (int k = 0; k < 5; ++k) {
    129         sum = 0.0f;
    130         for (int l = 0; l < 7; ++l) {
    131           sum += tensor4(i, l, j, l, k);
    132         }
    133         VERIFY_IS_EQUAL(result5(i, j, k), sum);
    134       }
    135     }
    136   }
    137 }
    138 
    139 
    140 template<int DataLayout>
    141 static void test_trace_in_expr() {
    142   Tensor<float, 4, DataLayout> tensor(2, 3, 5, 3);
    143   tensor.setRandom();
    144   array<ptrdiff_t, 2> dims = { { 1, 3 } };
    145   Tensor<float, 2, DataLayout> result(2, 5);
    146   result = result.constant(1.0f) - tensor.trace(dims);
    147   VERIFY_IS_EQUAL(result.rank(), 2);
    148   VERIFY_IS_EQUAL(result.dimension(0), 2);
    149   VERIFY_IS_EQUAL(result.dimension(1), 5);
    150   float sum = 0.0f;
    151   for (int i = 0; i < 2; ++i) {
    152     for (int j = 0; j < 5; ++j) {
    153       sum = 0.0f;
    154       for (int k = 0; k < 3; ++k) {
    155         sum += tensor(i, k, j, k);
    156       }
    157       VERIFY_IS_EQUAL(result(i, j), 1.0f - sum);
    158     }
    159   }
    160 }
    161 
    162 
// Test entry point for the tensor trace operation.
// Each scenario runs twice: first with ColMajor, then with RowMajor storage.
// The invocation order is kept as-is. All scenarios draw from setRandom(), so
// reordering would presumably change which random values each subtest sees.
EIGEN_DECLARE_TEST(cxx11_tensor_trace) {
  // Rank-0 input traced over no dimensions.
  CALL_SUBTEST(test_0D_trace<ColMajor>());
  CALL_SUBTEST(test_0D_trace<RowMajor>());
  // Trace over every dimension, with the argument-less and explicit overloads.
  CALL_SUBTEST(test_all_dimensions_trace<ColMajor>());
  CALL_SUBTEST(test_all_dimensions_trace<RowMajor>());
  // Partial traces leaving rank 1, 2 and 3 results.
  CALL_SUBTEST(test_simple_trace<ColMajor>());
  CALL_SUBTEST(test_simple_trace<RowMajor>());
  // trace() used inside a larger tensor expression.
  CALL_SUBTEST(test_trace_in_expr<ColMajor>());
  CALL_SUBTEST(test_trace_in_expr<RowMajor>());
}