numext.cpp (9042B)
1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #include "main.h" 11 12 template<typename T, typename U> 13 bool check_if_equal_or_nans(const T& actual, const U& expected) { 14 return ((actual == expected) || ((numext::isnan)(actual) && (numext::isnan)(expected))); 15 } 16 17 template<typename T, typename U> 18 bool check_if_equal_or_nans(const std::complex<T>& actual, const std::complex<U>& expected) { 19 return check_if_equal_or_nans(numext::real(actual), numext::real(expected)) 20 && check_if_equal_or_nans(numext::imag(actual), numext::imag(expected)); 21 } 22 23 template<typename T, typename U> 24 bool test_is_equal_or_nans(const T& actual, const U& expected) 25 { 26 if (check_if_equal_or_nans(actual, expected)) { 27 return true; 28 } 29 30 // false: 31 std::cerr 32 << "\n actual = " << actual 33 << "\n expected = " << expected << "\n\n"; 34 return false; 35 } 36 37 #define VERIFY_IS_EQUAL_OR_NANS(a, b) VERIFY(test_is_equal_or_nans(a, b)) 38 39 template<typename T> 40 void check_abs() { 41 typedef typename NumTraits<T>::Real Real; 42 Real zero(0); 43 44 if(NumTraits<T>::IsSigned) 45 VERIFY_IS_EQUAL(numext::abs(-T(1)), T(1)); 46 VERIFY_IS_EQUAL(numext::abs(T(0)), T(0)); 47 VERIFY_IS_EQUAL(numext::abs(T(1)), T(1)); 48 49 for(int k=0; k<100; ++k) 50 { 51 T x = internal::random<T>(); 52 if(!internal::is_same<T,bool>::value) 53 x = x/Real(2); 54 if(NumTraits<T>::IsSigned) 55 { 56 VERIFY_IS_EQUAL(numext::abs(x), numext::abs(-x)); 57 VERIFY( numext::abs(-x) >= zero ); 58 } 59 VERIFY( numext::abs(x) >= zero ); 60 VERIFY_IS_APPROX( numext::abs2(x), numext::abs2(numext::abs(x)) ); 61 } 62 } 63 64 template<typename T> 65 void check_arg() { 66 typedef typename NumTraits<T>::Real Real; 67 VERIFY_IS_EQUAL(numext::abs(T(0)), T(0)); 68 VERIFY_IS_EQUAL(numext::abs(T(1)), T(1)); 69 70 for(int k=0; k<100; ++k) 71 { 72 T x = internal::random<T>(); 73 Real y = numext::arg(x); 74 VERIFY_IS_APPROX( y, std::arg(x) ); 75 } 76 } 77 78 template<typename T> 79 struct check_sqrt_impl { 80 static void run() { 81 for (int i=0; i<1000; ++i) { 82 const T x = numext::abs(internal::random<T>()); 83 const T sqrtx = numext::sqrt(x); 84 VERIFY_IS_APPROX(sqrtx*sqrtx, x); 85 } 86 87 // Corner cases. 88 const T zero = T(0); 89 const T one = T(1); 90 const T inf = std::numeric_limits<T>::infinity(); 91 const T nan = std::numeric_limits<T>::quiet_NaN(); 92 VERIFY_IS_EQUAL(numext::sqrt(zero), zero); 93 VERIFY_IS_EQUAL(numext::sqrt(inf), inf); 94 VERIFY((numext::isnan)(numext::sqrt(nan))); 95 VERIFY((numext::isnan)(numext::sqrt(-one))); 96 } 97 }; 98 99 template<typename T> 100 struct check_sqrt_impl<std::complex<T> > { 101 static void run() { 102 typedef typename std::complex<T> ComplexT; 103 104 for (int i=0; i<1000; ++i) { 105 const ComplexT x = internal::random<ComplexT>(); 106 const ComplexT sqrtx = numext::sqrt(x); 107 VERIFY_IS_APPROX(sqrtx*sqrtx, x); 108 } 109 110 // Corner cases. 111 const T zero = T(0); 112 const T one = T(1); 113 const T inf = std::numeric_limits<T>::infinity(); 114 const T nan = std::numeric_limits<T>::quiet_NaN(); 115 116 // Set of corner cases from https://en.cppreference.com/w/cpp/numeric/complex/sqrt 117 const int kNumCorners = 20; 118 const ComplexT corners[kNumCorners][2] = { 119 {ComplexT(zero, zero), ComplexT(zero, zero)}, 120 {ComplexT(-zero, zero), ComplexT(zero, zero)}, 121 {ComplexT(zero, -zero), ComplexT(zero, zero)}, 122 {ComplexT(-zero, -zero), ComplexT(zero, zero)}, 123 {ComplexT(one, inf), ComplexT(inf, inf)}, 124 {ComplexT(nan, inf), ComplexT(inf, inf)}, 125 {ComplexT(one, -inf), ComplexT(inf, -inf)}, 126 {ComplexT(nan, -inf), ComplexT(inf, -inf)}, 127 {ComplexT(-inf, one), ComplexT(zero, inf)}, 128 {ComplexT(inf, one), ComplexT(inf, zero)}, 129 {ComplexT(-inf, -one), ComplexT(zero, -inf)}, 130 {ComplexT(inf, -one), ComplexT(inf, -zero)}, 131 {ComplexT(-inf, nan), ComplexT(nan, inf)}, 132 {ComplexT(inf, nan), ComplexT(inf, nan)}, 133 {ComplexT(zero, nan), ComplexT(nan, nan)}, 134 {ComplexT(one, nan), ComplexT(nan, nan)}, 135 {ComplexT(nan, zero), ComplexT(nan, nan)}, 136 {ComplexT(nan, one), ComplexT(nan, nan)}, 137 {ComplexT(nan, -one), ComplexT(nan, nan)}, 138 {ComplexT(nan, nan), ComplexT(nan, nan)}, 139 }; 140 141 for (int i=0; i<kNumCorners; ++i) { 142 const ComplexT& x = corners[i][0]; 143 const ComplexT sqrtx = corners[i][1]; 144 VERIFY_IS_EQUAL_OR_NANS(numext::sqrt(x), sqrtx); 145 } 146 } 147 }; 148 149 template<typename T> 150 void check_sqrt() { 151 check_sqrt_impl<T>::run(); 152 } 153 154 template<typename T> 155 struct check_rsqrt_impl { 156 static void run() { 157 const T zero = T(0); 158 const T one = T(1); 159 const T inf = std::numeric_limits<T>::infinity(); 160 const T nan = std::numeric_limits<T>::quiet_NaN(); 161 162 for (int i=0; i<1000; ++i) { 163 const T x = numext::abs(internal::random<T>()); 164 const T rsqrtx = numext::rsqrt(x); 165 const T invx = one / x; 166 VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx); 167 } 168 169 // Corner cases. 170 VERIFY_IS_EQUAL(numext::rsqrt(zero), inf); 171 VERIFY_IS_EQUAL(numext::rsqrt(inf), zero); 172 VERIFY((numext::isnan)(numext::rsqrt(nan))); 173 VERIFY((numext::isnan)(numext::rsqrt(-one))); 174 } 175 }; 176 177 template<typename T> 178 struct check_rsqrt_impl<std::complex<T> > { 179 static void run() { 180 typedef typename std::complex<T> ComplexT; 181 const T zero = T(0); 182 const T one = T(1); 183 const T inf = std::numeric_limits<T>::infinity(); 184 const T nan = std::numeric_limits<T>::quiet_NaN(); 185 186 for (int i=0; i<1000; ++i) { 187 const ComplexT x = internal::random<ComplexT>(); 188 const ComplexT invx = ComplexT(one, zero) / x; 189 const ComplexT rsqrtx = numext::rsqrt(x); 190 VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx); 191 } 192 193 // GCC and MSVC differ in their treatment of 1/(0 + 0i) 194 // GCC/clang = (inf, nan) 195 // MSVC = (nan, nan) 196 // and 1 / (x + inf i) 197 // GCC/clang = (0, 0) 198 // MSVC = (nan, nan) 199 #if (EIGEN_COMP_GNUC) 200 { 201 const int kNumCorners = 20; 202 const ComplexT corners[kNumCorners][2] = { 203 // Only consistent across GCC, clang 204 {ComplexT(zero, zero), ComplexT(zero, zero)}, 205 {ComplexT(-zero, zero), ComplexT(zero, zero)}, 206 {ComplexT(zero, -zero), ComplexT(zero, zero)}, 207 {ComplexT(-zero, -zero), ComplexT(zero, zero)}, 208 {ComplexT(one, inf), ComplexT(inf, inf)}, 209 {ComplexT(nan, inf), ComplexT(inf, inf)}, 210 {ComplexT(one, -inf), ComplexT(inf, -inf)}, 211 {ComplexT(nan, -inf), ComplexT(inf, -inf)}, 212 // Consistent across GCC, clang, MSVC 213 {ComplexT(-inf, one), ComplexT(zero, inf)}, 214 {ComplexT(inf, one), ComplexT(inf, zero)}, 215 {ComplexT(-inf, -one), ComplexT(zero, -inf)}, 216 {ComplexT(inf, -one), ComplexT(inf, -zero)}, 217 {ComplexT(-inf, nan), ComplexT(nan, inf)}, 218 {ComplexT(inf, nan), ComplexT(inf, nan)}, 219 {ComplexT(zero, nan), ComplexT(nan, nan)}, 220 {ComplexT(one, nan), ComplexT(nan, nan)}, 221 {ComplexT(nan, zero), ComplexT(nan, nan)}, 222 {ComplexT(nan, one), ComplexT(nan, nan)}, 223 {ComplexT(nan, -one), ComplexT(nan, nan)}, 224 {ComplexT(nan, nan), ComplexT(nan, nan)}, 225 }; 226 227 for (int i=0; i<kNumCorners; ++i) { 228 const ComplexT& x = corners[i][0]; 229 const ComplexT rsqrtx = ComplexT(one, zero) / corners[i][1]; 230 VERIFY_IS_EQUAL_OR_NANS(numext::rsqrt(x), rsqrtx); 231 } 232 } 233 #endif 234 } 235 }; 236 237 template<typename T> 238 void check_rsqrt() { 239 check_rsqrt_impl<T>::run(); 240 } 241 242 EIGEN_DECLARE_TEST(numext) { 243 for(int k=0; k<g_repeat; ++k) 244 { 245 CALL_SUBTEST( check_abs<bool>() ); 246 CALL_SUBTEST( check_abs<signed char>() ); 247 CALL_SUBTEST( check_abs<unsigned char>() ); 248 CALL_SUBTEST( check_abs<short>() ); 249 CALL_SUBTEST( check_abs<unsigned short>() ); 250 CALL_SUBTEST( check_abs<int>() ); 251 CALL_SUBTEST( check_abs<unsigned int>() ); 252 CALL_SUBTEST( check_abs<long>() ); 253 CALL_SUBTEST( check_abs<unsigned long>() ); 254 CALL_SUBTEST( check_abs<half>() ); 255 CALL_SUBTEST( check_abs<bfloat16>() ); 256 CALL_SUBTEST( check_abs<float>() ); 257 CALL_SUBTEST( check_abs<double>() ); 258 CALL_SUBTEST( check_abs<long double>() ); 259 CALL_SUBTEST( check_abs<std::complex<float> >() ); 260 CALL_SUBTEST( check_abs<std::complex<double> >() ); 261 262 CALL_SUBTEST( check_arg<std::complex<float> >() ); 263 CALL_SUBTEST( check_arg<std::complex<double> >() ); 264 265 CALL_SUBTEST( check_sqrt<float>() ); 266 CALL_SUBTEST( check_sqrt<double>() ); 267 CALL_SUBTEST( check_sqrt<std::complex<float> >() ); 268 CALL_SUBTEST( check_sqrt<std::complex<double> >() ); 269 270 CALL_SUBTEST( check_rsqrt<float>() ); 271 CALL_SUBTEST( check_rsqrt<double>() ); 272 CALL_SUBTEST( check_rsqrt<std::complex<float> >() ); 273 CALL_SUBTEST( check_rsqrt<std::complex<double> >() ); 274 } 275 }