MathFunctionsImpl.h (7156B)
1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2014 Pedro Gonnet (pedro.gonnet@gmail.com) 5 // Copyright (C) 2016 Gael Guennebaud <gael.guennebaud@inria.fr> 6 // 7 // This Source Code Form is subject to the terms of the Mozilla 8 // Public License v. 2.0. If a copy of the MPL was not distributed 9 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 10 11 #ifndef EIGEN_MATHFUNCTIONSIMPL_H 12 #define EIGEN_MATHFUNCTIONSIMPL_H 13 14 namespace Eigen { 15 16 namespace internal { 17 18 /** \internal \returns the hyperbolic tan of \a a (coeff-wise) 19 Doesn't do anything fancy, just a 13/6-degree rational interpolant which 20 is accurate up to a couple of ulps in the (approximate) range [-8, 8], 21 outside of which tanh(x) = +/-1 in single precision. The input is clamped 22 to the range [-c, c]. The value c is chosen as the smallest value where 23 the approximation evaluates to exactly 1. In the reange [-0.0004, 0.0004] 24 the approxmation tanh(x) ~= x is used for better accuracy as x tends to zero. 25 26 This implementation works on both scalars and packets. 27 */ 28 template<typename T> 29 T generic_fast_tanh_float(const T& a_x) 30 { 31 // Clamp the inputs to the range [-c, c] 32 #ifdef EIGEN_VECTORIZE_FMA 33 const T plus_clamp = pset1<T>(7.99881172180175781f); 34 const T minus_clamp = pset1<T>(-7.99881172180175781f); 35 #else 36 const T plus_clamp = pset1<T>(7.90531110763549805f); 37 const T minus_clamp = pset1<T>(-7.90531110763549805f); 38 #endif 39 const T tiny = pset1<T>(0.0004f); 40 const T x = pmax(pmin(a_x, plus_clamp), minus_clamp); 41 const T tiny_mask = pcmp_lt(pabs(a_x), tiny); 42 // The monomial coefficients of the numerator polynomial (odd). 43 const T alpha_1 = pset1<T>(4.89352455891786e-03f); 44 const T alpha_3 = pset1<T>(6.37261928875436e-04f); 45 const T alpha_5 = pset1<T>(1.48572235717979e-05f); 46 const T alpha_7 = pset1<T>(5.12229709037114e-08f); 47 const T alpha_9 = pset1<T>(-8.60467152213735e-11f); 48 const T alpha_11 = pset1<T>(2.00018790482477e-13f); 49 const T alpha_13 = pset1<T>(-2.76076847742355e-16f); 50 51 // The monomial coefficients of the denominator polynomial (even). 52 const T beta_0 = pset1<T>(4.89352518554385e-03f); 53 const T beta_2 = pset1<T>(2.26843463243900e-03f); 54 const T beta_4 = pset1<T>(1.18534705686654e-04f); 55 const T beta_6 = pset1<T>(1.19825839466702e-06f); 56 57 // Since the polynomials are odd/even, we need x^2. 58 const T x2 = pmul(x, x); 59 60 // Evaluate the numerator polynomial p. 61 T p = pmadd(x2, alpha_13, alpha_11); 62 p = pmadd(x2, p, alpha_9); 63 p = pmadd(x2, p, alpha_7); 64 p = pmadd(x2, p, alpha_5); 65 p = pmadd(x2, p, alpha_3); 66 p = pmadd(x2, p, alpha_1); 67 p = pmul(x, p); 68 69 // Evaluate the denominator polynomial q. 70 T q = pmadd(x2, beta_6, beta_4); 71 q = pmadd(x2, q, beta_2); 72 q = pmadd(x2, q, beta_0); 73 74 // Divide the numerator by the denominator. 75 return pselect(tiny_mask, x, pdiv(p, q)); 76 } 77 78 template<typename RealScalar> 79 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE 80 RealScalar positive_real_hypot(const RealScalar& x, const RealScalar& y) 81 { 82 // IEEE IEC 6059 special cases. 83 if ((numext::isinf)(x) || (numext::isinf)(y)) 84 return NumTraits<RealScalar>::infinity(); 85 if ((numext::isnan)(x) || (numext::isnan)(y)) 86 return NumTraits<RealScalar>::quiet_NaN(); 87 88 EIGEN_USING_STD(sqrt); 89 RealScalar p, qp; 90 p = numext::maxi(x,y); 91 if(p==RealScalar(0)) return RealScalar(0); 92 qp = numext::mini(y,x) / p; 93 return p * sqrt(RealScalar(1) + qp*qp); 94 } 95 96 template<typename Scalar> 97 struct hypot_impl 98 { 99 typedef typename NumTraits<Scalar>::Real RealScalar; 100 static EIGEN_DEVICE_FUNC 101 inline RealScalar run(const Scalar& x, const Scalar& y) 102 { 103 EIGEN_USING_STD(abs); 104 return positive_real_hypot<RealScalar>(abs(x), abs(y)); 105 } 106 }; 107 108 // Generic complex sqrt implementation that correctly handles corner cases 109 // according to https://en.cppreference.com/w/cpp/numeric/complex/sqrt 110 template<typename T> 111 EIGEN_DEVICE_FUNC std::complex<T> complex_sqrt(const std::complex<T>& z) { 112 // Computes the principal sqrt of the input. 113 // 114 // For a complex square root of the number x + i*y. We want to find real 115 // numbers u and v such that 116 // (u + i*v)^2 = x + i*y <=> 117 // u^2 - v^2 + i*2*u*v = x + i*v. 118 // By equating the real and imaginary parts we get: 119 // u^2 - v^2 = x 120 // 2*u*v = y. 121 // 122 // For x >= 0, this has the numerically stable solution 123 // u = sqrt(0.5 * (x + sqrt(x^2 + y^2))) 124 // v = y / (2 * u) 125 // and for x < 0, 126 // v = sign(y) * sqrt(0.5 * (-x + sqrt(x^2 + y^2))) 127 // u = y / (2 * v) 128 // 129 // Letting w = sqrt(0.5 * (|x| + |z|)), 130 // if x == 0: u = w, v = sign(y) * w 131 // if x > 0: u = w, v = y / (2 * w) 132 // if x < 0: u = |y| / (2 * w), v = sign(y) * w 133 134 const T x = numext::real(z); 135 const T y = numext::imag(z); 136 const T zero = T(0); 137 const T w = numext::sqrt(T(0.5) * (numext::abs(x) + numext::hypot(x, y))); 138 139 return 140 (numext::isinf)(y) ? std::complex<T>(NumTraits<T>::infinity(), y) 141 : x == zero ? std::complex<T>(w, y < zero ? -w : w) 142 : x > zero ? std::complex<T>(w, y / (2 * w)) 143 : std::complex<T>(numext::abs(y) / (2 * w), y < zero ? -w : w ); 144 } 145 146 // Generic complex rsqrt implementation. 147 template<typename T> 148 EIGEN_DEVICE_FUNC std::complex<T> complex_rsqrt(const std::complex<T>& z) { 149 // Computes the principal reciprocal sqrt of the input. 150 // 151 // For a complex reciprocal square root of the number z = x + i*y. We want to 152 // find real numbers u and v such that 153 // (u + i*v)^2 = 1 / (x + i*y) <=> 154 // u^2 - v^2 + i*2*u*v = x/|z|^2 - i*v/|z|^2. 155 // By equating the real and imaginary parts we get: 156 // u^2 - v^2 = x/|z|^2 157 // 2*u*v = y/|z|^2. 158 // 159 // For x >= 0, this has the numerically stable solution 160 // u = sqrt(0.5 * (x + |z|)) / |z| 161 // v = -y / (2 * u * |z|) 162 // and for x < 0, 163 // v = -sign(y) * sqrt(0.5 * (-x + |z|)) / |z| 164 // u = -y / (2 * v * |z|) 165 // 166 // Letting w = sqrt(0.5 * (|x| + |z|)), 167 // if x == 0: u = w / |z|, v = -sign(y) * w / |z| 168 // if x > 0: u = w / |z|, v = -y / (2 * w * |z|) 169 // if x < 0: u = |y| / (2 * w * |z|), v = -sign(y) * w / |z| 170 171 const T x = numext::real(z); 172 const T y = numext::imag(z); 173 const T zero = T(0); 174 175 const T abs_z = numext::hypot(x, y); 176 const T w = numext::sqrt(T(0.5) * (numext::abs(x) + abs_z)); 177 const T woz = w / abs_z; 178 // Corner cases consistent with 1/sqrt(z) on gcc/clang. 179 return 180 abs_z == zero ? std::complex<T>(NumTraits<T>::infinity(), NumTraits<T>::quiet_NaN()) 181 : ((numext::isinf)(x) || (numext::isinf)(y)) ? std::complex<T>(zero, zero) 182 : x == zero ? std::complex<T>(woz, y < zero ? woz : -woz) 183 : x > zero ? std::complex<T>(woz, -y / (2 * w * abs_z)) 184 : std::complex<T>(numext::abs(y) / (2 * w * abs_z), y < zero ? woz : -woz ); 185 } 186 187 template<typename T> 188 EIGEN_DEVICE_FUNC std::complex<T> complex_log(const std::complex<T>& z) { 189 // Computes complex log. 190 T a = numext::abs(z); 191 EIGEN_USING_STD(atan2); 192 T b = atan2(z.imag(), z.real()); 193 return std::complex<T>(numext::log(a), b); 194 } 195 196 } // end namespace internal 197 198 } // end namespace Eigen 199 200 #endif // EIGEN_MATHFUNCTIONSIMPL_H