cxx11_tensor_casts.cpp (5526B)
1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #include "main.h" 11 #include "random_without_cast_overflow.h" 12 13 #include <Eigen/CXX11/Tensor> 14 15 using Eigen::Tensor; 16 using Eigen::array; 17 18 static void test_simple_cast() 19 { 20 Tensor<float, 2> ftensor(20,30); 21 ftensor = ftensor.random() * 100.f; 22 Tensor<char, 2> chartensor(20,30); 23 chartensor.setRandom(); 24 Tensor<std::complex<float>, 2> cplextensor(20,30); 25 cplextensor.setRandom(); 26 27 chartensor = ftensor.cast<char>(); 28 cplextensor = ftensor.cast<std::complex<float> >(); 29 30 for (int i = 0; i < 20; ++i) { 31 for (int j = 0; j < 30; ++j) { 32 VERIFY_IS_EQUAL(chartensor(i,j), static_cast<char>(ftensor(i,j))); 33 VERIFY_IS_EQUAL(cplextensor(i,j), static_cast<std::complex<float> >(ftensor(i,j))); 34 } 35 } 36 } 37 38 39 static void test_vectorized_cast() 40 { 41 Tensor<int, 2> itensor(20,30); 42 itensor = itensor.random() / 1000; 43 Tensor<float, 2> ftensor(20,30); 44 ftensor.setRandom(); 45 Tensor<double, 2> dtensor(20,30); 46 dtensor.setRandom(); 47 48 ftensor = itensor.cast<float>(); 49 dtensor = itensor.cast<double>(); 50 51 for (int i = 0; i < 20; ++i) { 52 for (int j = 0; j < 30; ++j) { 53 VERIFY_IS_EQUAL(itensor(i,j), static_cast<int>(ftensor(i,j))); 54 VERIFY_IS_EQUAL(dtensor(i,j), static_cast<double>(ftensor(i,j))); 55 } 56 } 57 } 58 59 60 static void test_float_to_int_cast() 61 { 62 Tensor<float, 2> ftensor(20,30); 63 ftensor = ftensor.random() * 1000.0f; 64 Tensor<double, 2> dtensor(20,30); 65 dtensor = dtensor.random() * 1000.0; 66 67 Tensor<int, 2> i1tensor = ftensor.cast<int>(); 68 Tensor<int, 2> i2tensor = dtensor.cast<int>(); 69 70 for (int i = 0; i < 20; ++i) { 71 for (int j = 0; j < 30; ++j) { 72 VERIFY_IS_EQUAL(i1tensor(i,j), static_cast<int>(ftensor(i,j))); 73 VERIFY_IS_EQUAL(i2tensor(i,j), static_cast<int>(dtensor(i,j))); 74 } 75 } 76 } 77 78 79 static void test_big_to_small_type_cast() 80 { 81 Tensor<double, 2> dtensor(20, 30); 82 dtensor.setRandom(); 83 Tensor<float, 2> ftensor(20, 30); 84 ftensor = dtensor.cast<float>(); 85 86 for (int i = 0; i < 20; ++i) { 87 for (int j = 0; j < 30; ++j) { 88 VERIFY_IS_APPROX(dtensor(i,j), static_cast<double>(ftensor(i,j))); 89 } 90 } 91 } 92 93 94 static void test_small_to_big_type_cast() 95 { 96 Tensor<float, 2> ftensor(20, 30); 97 ftensor.setRandom(); 98 Tensor<double, 2> dtensor(20, 30); 99 dtensor = ftensor.cast<double>(); 100 101 for (int i = 0; i < 20; ++i) { 102 for (int j = 0; j < 30; ++j) { 103 VERIFY_IS_APPROX(dtensor(i,j), static_cast<double>(ftensor(i,j))); 104 } 105 } 106 } 107 108 template <typename FromType, typename ToType> 109 static void test_type_cast() { 110 Tensor<FromType, 2> ftensor(100, 200); 111 // Generate random values for a valid cast. 112 for (int i = 0; i < 100; ++i) { 113 for (int j = 0; j < 200; ++j) { 114 ftensor(i, j) = internal::random_without_cast_overflow<FromType,ToType>::value(); 115 } 116 } 117 118 Tensor<ToType, 2> ttensor(100, 200); 119 ttensor = ftensor.template cast<ToType>(); 120 121 for (int i = 0; i < 100; ++i) { 122 for (int j = 0; j < 200; ++j) { 123 const ToType ref = internal::cast<FromType,ToType>(ftensor(i, j)); 124 VERIFY_IS_APPROX(ttensor(i, j), ref); 125 } 126 } 127 } 128 129 template<typename Scalar, typename EnableIf = void> 130 struct test_cast_runner { 131 static void run() { 132 test_type_cast<Scalar, bool>(); 133 test_type_cast<Scalar, int8_t>(); 134 test_type_cast<Scalar, int16_t>(); 135 test_type_cast<Scalar, int32_t>(); 136 test_type_cast<Scalar, int64_t>(); 137 test_type_cast<Scalar, uint8_t>(); 138 test_type_cast<Scalar, uint16_t>(); 139 test_type_cast<Scalar, uint32_t>(); 140 test_type_cast<Scalar, uint64_t>(); 141 test_type_cast<Scalar, half>(); 142 test_type_cast<Scalar, bfloat16>(); 143 test_type_cast<Scalar, float>(); 144 test_type_cast<Scalar, double>(); 145 test_type_cast<Scalar, std::complex<float>>(); 146 test_type_cast<Scalar, std::complex<double>>(); 147 } 148 }; 149 150 // Only certain types allow cast from std::complex<>. 151 template<typename Scalar> 152 struct test_cast_runner<Scalar, typename internal::enable_if<NumTraits<Scalar>::IsComplex>::type> { 153 static void run() { 154 test_type_cast<Scalar, half>(); 155 test_type_cast<Scalar, bfloat16>(); 156 test_type_cast<Scalar, std::complex<float>>(); 157 test_type_cast<Scalar, std::complex<double>>(); 158 } 159 }; 160 161 162 EIGEN_DECLARE_TEST(cxx11_tensor_casts) 163 { 164 CALL_SUBTEST(test_simple_cast()); 165 CALL_SUBTEST(test_vectorized_cast()); 166 CALL_SUBTEST(test_float_to_int_cast()); 167 CALL_SUBTEST(test_big_to_small_type_cast()); 168 CALL_SUBTEST(test_small_to_big_type_cast()); 169 170 CALL_SUBTEST(test_cast_runner<bool>::run()); 171 CALL_SUBTEST(test_cast_runner<int8_t>::run()); 172 CALL_SUBTEST(test_cast_runner<int16_t>::run()); 173 CALL_SUBTEST(test_cast_runner<int32_t>::run()); 174 CALL_SUBTEST(test_cast_runner<int64_t>::run()); 175 CALL_SUBTEST(test_cast_runner<uint8_t>::run()); 176 CALL_SUBTEST(test_cast_runner<uint16_t>::run()); 177 CALL_SUBTEST(test_cast_runner<uint32_t>::run()); 178 CALL_SUBTEST(test_cast_runner<uint64_t>::run()); 179 CALL_SUBTEST(test_cast_runner<half>::run()); 180 CALL_SUBTEST(test_cast_runner<bfloat16>::run()); 181 CALL_SUBTEST(test_cast_runner<float>::run()); 182 CALL_SUBTEST(test_cast_runner<double>::run()); 183 CALL_SUBTEST(test_cast_runner<std::complex<float>>::run()); 184 CALL_SUBTEST(test_cast_runner<std::complex<double>>::run()); 185 186 }