special_functions.cpp (22854B)
1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2016 Gael Guennebaud <gael.guennebaud@inria.fr> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #include <limits.h> 11 #include "main.h" 12 #include "../Eigen/SpecialFunctions" 13 14 // Hack to allow "implicit" conversions from double to Scalar via comma-initialization. 15 template<typename Derived> 16 Eigen::CommaInitializer<Derived> operator<<(Eigen::DenseBase<Derived>& dense, double v) { 17 return (dense << static_cast<typename Derived::Scalar>(v)); 18 } 19 20 template<typename XprType> 21 Eigen::CommaInitializer<XprType>& operator,(Eigen::CommaInitializer<XprType>& ci, double v) { 22 return (ci, static_cast<typename XprType::Scalar>(v)); 23 } 24 25 template<typename X, typename Y> 26 void verify_component_wise(const X& x, const Y& y) 27 { 28 for(Index i=0; i<x.size(); ++i) 29 { 30 if((numext::isfinite)(y(i))) 31 VERIFY_IS_APPROX( x(i), y(i) ); 32 else if((numext::isnan)(y(i))) 33 VERIFY((numext::isnan)(x(i))); 34 else 35 VERIFY_IS_EQUAL( x(i), y(i) ); 36 } 37 } 38 39 template<typename ArrayType> void array_special_functions() 40 { 41 using std::abs; 42 using std::sqrt; 43 typedef typename ArrayType::Scalar Scalar; 44 typedef typename NumTraits<Scalar>::Real RealScalar; 45 46 Scalar plusinf = std::numeric_limits<Scalar>::infinity(); 47 Scalar nan = std::numeric_limits<Scalar>::quiet_NaN(); 48 49 Index rows = internal::random<Index>(1,30); 50 Index cols = 1; 51 52 // API 53 { 54 ArrayType m1 = ArrayType::Random(rows,cols); 55 #if EIGEN_HAS_C99_MATH 56 VERIFY_IS_APPROX(m1.lgamma(), lgamma(m1)); 57 VERIFY_IS_APPROX(m1.digamma(), digamma(m1)); 58 VERIFY_IS_APPROX(m1.erf(), erf(m1)); 59 VERIFY_IS_APPROX(m1.erfc(), erfc(m1)); 60 #endif // EIGEN_HAS_C99_MATH 61 } 62 63 64 #if EIGEN_HAS_C99_MATH 65 // check special functions (comparing against numpy implementation) 66 if (!NumTraits<Scalar>::IsComplex) 67 { 68 69 { 70 ArrayType m1 = ArrayType::Random(rows,cols); 71 ArrayType m2 = ArrayType::Random(rows,cols); 72 73 // Test various propreties of igamma & igammac. These are normalized 74 // gamma integrals where 75 // igammac(a, x) = Gamma(a, x) / Gamma(a) 76 // igamma(a, x) = gamma(a, x) / Gamma(a) 77 // where Gamma and gamma are considered the standard unnormalized 78 // upper and lower incomplete gamma functions, respectively. 79 ArrayType a = m1.abs() + Scalar(2); 80 ArrayType x = m2.abs() + Scalar(2); 81 ArrayType zero = ArrayType::Zero(rows, cols); 82 ArrayType one = ArrayType::Constant(rows, cols, Scalar(1.0)); 83 ArrayType a_m1 = a - one; 84 ArrayType Gamma_a_x = Eigen::igammac(a, x) * a.lgamma().exp(); 85 ArrayType Gamma_a_m1_x = Eigen::igammac(a_m1, x) * a_m1.lgamma().exp(); 86 ArrayType gamma_a_x = Eigen::igamma(a, x) * a.lgamma().exp(); 87 ArrayType gamma_a_m1_x = Eigen::igamma(a_m1, x) * a_m1.lgamma().exp(); 88 89 90 // Gamma(a, 0) == Gamma(a) 91 VERIFY_IS_APPROX(Eigen::igammac(a, zero), one); 92 93 // Gamma(a, x) + gamma(a, x) == Gamma(a) 94 VERIFY_IS_APPROX(Gamma_a_x + gamma_a_x, a.lgamma().exp()); 95 96 // Gamma(a, x) == (a - 1) * Gamma(a-1, x) + x^(a-1) * exp(-x) 97 VERIFY_IS_APPROX(Gamma_a_x, (a - Scalar(1)) * Gamma_a_m1_x + x.pow(a-Scalar(1)) * (-x).exp()); 98 99 // gamma(a, x) == (a - 1) * gamma(a-1, x) - x^(a-1) * exp(-x) 100 VERIFY_IS_APPROX(gamma_a_x, (a - Scalar(1)) * gamma_a_m1_x - x.pow(a-Scalar(1)) * (-x).exp()); 101 } 102 { 103 // Verify for large a and x that values are between 0 and 1. 104 ArrayType m1 = ArrayType::Random(rows,cols); 105 ArrayType m2 = ArrayType::Random(rows,cols); 106 int max_exponent = std::numeric_limits<Scalar>::max_exponent10; 107 ArrayType a = m1.abs() * Scalar(pow(10., max_exponent - 1)); 108 ArrayType x = m2.abs() * Scalar(pow(10., max_exponent - 1)); 109 for (int i = 0; i < a.size(); ++i) { 110 Scalar igam = numext::igamma(a(i), x(i)); 111 VERIFY(0 <= igam); 112 VERIFY(igam <= 1); 113 } 114 } 115 116 { 117 // Check exact values of igamma and igammac against a third party calculation. 118 Scalar a_s[] = {Scalar(0), Scalar(1), Scalar(1.5), Scalar(4), Scalar(0.0001), Scalar(1000.5)}; 119 Scalar x_s[] = {Scalar(0), Scalar(1), Scalar(1.5), Scalar(4), Scalar(0.0001), Scalar(1000.5)}; 120 121 // location i*6+j corresponds to a_s[i], x_s[j]. 122 Scalar igamma_s[][6] = { 123 {Scalar(0.0), nan, nan, nan, nan, nan}, 124 {Scalar(0.0), Scalar(0.6321205588285578), Scalar(0.7768698398515702), 125 Scalar(0.9816843611112658), Scalar(9.999500016666262e-05), 126 Scalar(1.0)}, 127 {Scalar(0.0), Scalar(0.4275932955291202), Scalar(0.608374823728911), 128 Scalar(0.9539882943107686), Scalar(7.522076445089201e-07), 129 Scalar(1.0)}, 130 {Scalar(0.0), Scalar(0.01898815687615381), 131 Scalar(0.06564245437845008), Scalar(0.5665298796332909), 132 Scalar(4.166333347221828e-18), Scalar(1.0)}, 133 {Scalar(0.0), Scalar(0.9999780593618628), Scalar(0.9999899967080838), 134 Scalar(0.9999996219837988), Scalar(0.9991370418689945), Scalar(1.0)}, 135 {Scalar(0.0), Scalar(0.0), Scalar(0.0), Scalar(0.0), Scalar(0.0), 136 Scalar(0.5042041932513908)}}; 137 Scalar igammac_s[][6] = { 138 {nan, nan, nan, nan, nan, nan}, 139 {Scalar(1.0), Scalar(0.36787944117144233), 140 Scalar(0.22313016014842982), Scalar(0.018315638888734182), 141 Scalar(0.9999000049998333), Scalar(0.0)}, 142 {Scalar(1.0), Scalar(0.5724067044708798), Scalar(0.3916251762710878), 143 Scalar(0.04601170568923136), Scalar(0.9999992477923555), 144 Scalar(0.0)}, 145 {Scalar(1.0), Scalar(0.9810118431238462), Scalar(0.9343575456215499), 146 Scalar(0.4334701203667089), Scalar(1.0), Scalar(0.0)}, 147 {Scalar(1.0), Scalar(2.1940638138146658e-05), 148 Scalar(1.0003291916285e-05), Scalar(3.7801620118431334e-07), 149 Scalar(0.0008629581310054535), Scalar(0.0)}, 150 {Scalar(1.0), Scalar(1.0), Scalar(1.0), Scalar(1.0), Scalar(1.0), 151 Scalar(0.49579580674813944)}}; 152 153 for (int i = 0; i < 6; ++i) { 154 for (int j = 0; j < 6; ++j) { 155 if ((std::isnan)(igamma_s[i][j])) { 156 VERIFY((std::isnan)(numext::igamma(a_s[i], x_s[j]))); 157 } else { 158 VERIFY_IS_APPROX(numext::igamma(a_s[i], x_s[j]), igamma_s[i][j]); 159 } 160 161 if ((std::isnan)(igammac_s[i][j])) { 162 VERIFY((std::isnan)(numext::igammac(a_s[i], x_s[j]))); 163 } else { 164 VERIFY_IS_APPROX(numext::igammac(a_s[i], x_s[j]), igammac_s[i][j]); 165 } 166 } 167 } 168 } 169 } 170 #endif // EIGEN_HAS_C99_MATH 171 172 // Check the ndtri function against scipy.special.ndtri 173 { 174 ArrayType x(7), res(7), ref(7); 175 x << 0.5, 0.2, 0.8, 0.9, 0.1, 0.99, 0.01; 176 ref << 0., -0.8416212335729142, 0.8416212335729142, 1.2815515655446004, -1.2815515655446004, 2.3263478740408408, -2.3263478740408408; 177 CALL_SUBTEST( verify_component_wise(ref, ref); ); 178 CALL_SUBTEST( res = x.ndtri(); verify_component_wise(res, ref); ); 179 CALL_SUBTEST( res = ndtri(x); verify_component_wise(res, ref); ); 180 181 // ndtri(normal_cdf(x)) ~= x 182 CALL_SUBTEST( 183 ArrayType m1 = ArrayType::Random(32); 184 using std::sqrt; 185 186 ArrayType cdf_val = (m1 / Scalar(sqrt(2.))).erf(); 187 cdf_val = (cdf_val + Scalar(1)) / Scalar(2); 188 verify_component_wise(cdf_val.ndtri(), m1);); 189 190 } 191 192 // Check the zeta function against scipy.special.zeta 193 { 194 ArrayType x(10), q(10), res(10), ref(10); 195 x << 1.5, 4, 10.5, 10000.5, 3, 1, 0.9, 2, 3, 4; 196 q << 2, 1.5, 3, 1.0001, -2.5, 1.2345, 1.2345, -1, -2, -3; 197 ref << 1.61237534869, 0.234848505667, 1.03086757337e-5, 0.367879440865, 0.054102025820864097, plusinf, nan, plusinf, nan, plusinf; 198 CALL_SUBTEST( verify_component_wise(ref, ref); ); 199 CALL_SUBTEST( res = x.zeta(q); verify_component_wise(res, ref); ); 200 CALL_SUBTEST( res = zeta(x,q); verify_component_wise(res, ref); ); 201 } 202 203 // digamma 204 { 205 ArrayType x(9), res(9), ref(9); 206 x << 1, 1.5, 4, -10.5, 10000.5, 0, -1, -2, -3; 207 ref << -0.5772156649015329, 0.03648997397857645, 1.2561176684318, 2.398239129535781, 9.210340372392849, nan, nan, nan, nan; 208 CALL_SUBTEST( verify_component_wise(ref, ref); ); 209 210 CALL_SUBTEST( res = x.digamma(); verify_component_wise(res, ref); ); 211 CALL_SUBTEST( res = digamma(x); verify_component_wise(res, ref); ); 212 } 213 214 #if EIGEN_HAS_C99_MATH 215 { 216 ArrayType n(16), x(16), res(16), ref(16); 217 n << 1, 1, 1, 1.5, 17, 31, 28, 8, 42, 147, 170, -1, 0, 1, 2, 3; 218 x << 2, 3, 25.5, 1.5, 4.7, 11.8, 17.7, 30.2, 15.8, 54.1, 64, -1, -2, -3, -4, -5; 219 ref << 0.644934066848, 0.394934066848, 0.0399946696496, nan, 293.334565435, 0.445487887616, -2.47810300902e-07, -8.29668781082e-09, -0.434562276666, 0.567742190178, -0.0108615497927, nan, nan, plusinf, nan, plusinf; 220 CALL_SUBTEST( verify_component_wise(ref, ref); ); 221 222 if(sizeof(RealScalar)>=8) { // double 223 // Reason for commented line: http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1232 224 // CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res, ref); ); 225 CALL_SUBTEST( res = polygamma(n,x); verify_component_wise(res, ref); ); 226 } 227 else { 228 // CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res.head(8), ref.head(8)); ); 229 CALL_SUBTEST( res = polygamma(n,x); verify_component_wise(res.head(8), ref.head(8)); ); 230 } 231 } 232 #endif 233 234 #if EIGEN_HAS_C99_MATH 235 { 236 // Inputs and ground truth generated with scipy via: 237 // a = np.logspace(-3, 3, 5) - 1e-3 238 // b = np.logspace(-3, 3, 5) - 1e-3 239 // x = np.linspace(-0.1, 1.1, 5) 240 // (full_a, full_b, full_x) = np.vectorize(lambda a, b, x: (a, b, x))(*np.ix_(a, b, x)) 241 // full_a = full_a.flatten().tolist() # same for full_b, full_x 242 // v = scipy.special.betainc(full_a, full_b, full_x).flatten().tolist() 243 // 244 // Note in Eigen, we call betainc with arguments in the order (x, a, b). 245 ArrayType a(125); 246 ArrayType b(125); 247 ArrayType x(125); 248 ArrayType v(125); 249 ArrayType res(125); 250 251 a << 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 252 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 253 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 254 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 255 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 256 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 257 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 258 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 259 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 260 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 261 0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 262 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 263 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 264 31.62177660168379, 31.62177660168379, 31.62177660168379, 265 31.62177660168379, 31.62177660168379, 31.62177660168379, 266 31.62177660168379, 31.62177660168379, 31.62177660168379, 267 31.62177660168379, 31.62177660168379, 31.62177660168379, 268 31.62177660168379, 31.62177660168379, 31.62177660168379, 269 31.62177660168379, 31.62177660168379, 31.62177660168379, 270 31.62177660168379, 31.62177660168379, 31.62177660168379, 271 31.62177660168379, 31.62177660168379, 31.62177660168379, 272 31.62177660168379, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 273 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 274 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 275 999.999, 999.999, 999.999; 276 277 b << 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379, 0.03062277660168379, 278 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 0.999, 279 0.999, 0.999, 0.999, 0.999, 31.62177660168379, 31.62177660168379, 280 31.62177660168379, 31.62177660168379, 31.62177660168379, 999.999, 281 999.999, 999.999, 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 282 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 283 0.03062277660168379, 0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 284 0.999, 31.62177660168379, 31.62177660168379, 31.62177660168379, 285 31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999, 286 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379, 287 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 288 0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999, 289 31.62177660168379, 31.62177660168379, 31.62177660168379, 290 31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999, 291 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379, 292 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 293 0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999, 294 31.62177660168379, 31.62177660168379, 31.62177660168379, 295 31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999, 296 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379, 297 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 298 0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999, 299 31.62177660168379, 31.62177660168379, 31.62177660168379, 300 31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999, 301 999.999, 999.999; 302 303 x << -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 304 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 305 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 306 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, 307 -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 308 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 309 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 310 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 311 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 312 0.8, 1.1; 313 314 v << nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, 315 nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, 316 nan, nan, nan, 0.47972119876364683, 0.5, 0.5202788012363533, nan, nan, 317 0.9518683957740043, 0.9789663010413743, 0.9931729188073435, nan, nan, 318 0.999995949033062, 0.9999999999993698, 0.9999999999999999, nan, nan, 319 0.9999999999999999, 0.9999999999999999, 0.9999999999999999, nan, nan, 320 nan, nan, nan, nan, nan, 0.006827081192655869, 0.0210336989586256, 321 0.04813160422599567, nan, nan, 0.20014344256217678, 0.5000000000000001, 322 0.7998565574378232, nan, nan, 0.9991401428435834, 0.999999999698403, 323 0.9999999999999999, nan, nan, 0.9999999999999999, 0.9999999999999999, 324 0.9999999999999999, nan, nan, nan, nan, nan, nan, nan, 325 1.0646600232370887e-25, 6.301722877826246e-13, 4.050966937974938e-06, 326 nan, nan, 7.864342668429763e-23, 3.015969667594166e-10, 327 0.0008598571564165444, nan, nan, 6.031987710123844e-08, 328 0.5000000000000007, 0.9999999396801229, nan, nan, 0.9999999999999999, 329 0.9999999999999999, 0.9999999999999999, nan, nan, nan, nan, nan, nan, 330 nan, 0.0, 7.029920380986636e-306, 2.2450728208591345e-101, nan, nan, 331 0.0, 9.275871147869727e-302, 1.2232913026152827e-97, nan, nan, 0.0, 332 3.0891393081932924e-252, 2.9303043666183996e-60, nan, nan, 333 2.248913486879199e-196, 0.5000000000004947, 0.9999999999999999, nan; 334 335 CALL_SUBTEST(res = betainc(a, b, x); 336 verify_component_wise(res, v);); 337 } 338 339 // Test various properties of betainc 340 { 341 ArrayType m1 = ArrayType::Random(32); 342 ArrayType m2 = ArrayType::Random(32); 343 ArrayType m3 = ArrayType::Random(32); 344 ArrayType one = ArrayType::Constant(32, Scalar(1.0)); 345 const Scalar eps = std::numeric_limits<Scalar>::epsilon(); 346 ArrayType a = (m1 * Scalar(4)).exp(); 347 ArrayType b = (m2 * Scalar(4)).exp(); 348 ArrayType x = m3.abs(); 349 350 // betainc(a, 1, x) == x**a 351 CALL_SUBTEST( 352 ArrayType test = betainc(a, one, x); 353 ArrayType expected = x.pow(a); 354 verify_component_wise(test, expected);); 355 356 // betainc(1, b, x) == 1 - (1 - x)**b 357 CALL_SUBTEST( 358 ArrayType test = betainc(one, b, x); 359 ArrayType expected = one - (one - x).pow(b); 360 verify_component_wise(test, expected);); 361 362 // betainc(a, b, x) == 1 - betainc(b, a, 1-x) 363 CALL_SUBTEST( 364 ArrayType test = betainc(a, b, x) + betainc(b, a, one - x); 365 ArrayType expected = one; 366 verify_component_wise(test, expected);); 367 368 // betainc(a+1, b, x) = betainc(a, b, x) - x**a * (1 - x)**b / (a * beta(a, b)) 369 CALL_SUBTEST( 370 ArrayType num = x.pow(a) * (one - x).pow(b); 371 ArrayType denom = a * (a.lgamma() + b.lgamma() - (a + b).lgamma()).exp(); 372 // Add eps to rhs and lhs so that component-wise test doesn't result in 373 // nans when both outputs are zeros. 374 ArrayType expected = betainc(a, b, x) - num / denom + eps; 375 ArrayType test = betainc(a + one, b, x) + eps; 376 if (sizeof(Scalar) >= 8) { // double 377 verify_component_wise(test, expected); 378 } else { 379 // Reason for limited test: http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1232 380 verify_component_wise(test.head(8), expected.head(8)); 381 }); 382 383 // betainc(a, b+1, x) = betainc(a, b, x) + x**a * (1 - x)**b / (b * beta(a, b)) 384 CALL_SUBTEST( 385 // Add eps to rhs and lhs so that component-wise test doesn't result in 386 // nans when both outputs are zeros. 387 ArrayType num = x.pow(a) * (one - x).pow(b); 388 ArrayType denom = b * (a.lgamma() + b.lgamma() - (a + b).lgamma()).exp(); 389 ArrayType expected = betainc(a, b, x) + num / denom + eps; 390 ArrayType test = betainc(a, b + one, x) + eps; 391 verify_component_wise(test, expected);); 392 } 393 #endif // EIGEN_HAS_C99_MATH 394 395 /* Code to generate the data for the following two test cases. 396 N = 5 397 np.random.seed(3) 398 399 a = np.logspace(-2, 3, 6) 400 a = np.ravel(np.tile(np.reshape(a, [-1, 1]), [1, N])) 401 x = np.random.gamma(a, 1.0) 402 x = np.maximum(x, np.finfo(np.float32).tiny) 403 404 def igamma(a, x): 405 return mpmath.gammainc(a, 0, x, regularized=True) 406 407 def igamma_der_a(a, x): 408 res = mpmath.diff(lambda a_prime: igamma(a_prime, x), a) 409 return np.float64(res) 410 411 def gamma_sample_der_alpha(a, x): 412 igamma_x = igamma(a, x) 413 def igammainv_of_igamma(a_prime): 414 return mpmath.findroot(lambda x_prime: igamma(a_prime, x_prime) - 415 igamma_x, x, solver='newton') 416 return np.float64(mpmath.diff(igammainv_of_igamma, a)) 417 418 v_igamma_der_a = np.vectorize(igamma_der_a)(a, x) 419 v_gamma_sample_der_alpha = np.vectorize(gamma_sample_der_alpha)(a, x) 420 */ 421 422 #if EIGEN_HAS_C99_MATH 423 // Test igamma_der_a 424 { 425 ArrayType a(30); 426 ArrayType x(30); 427 ArrayType res(30); 428 ArrayType v(30); 429 430 a << 0.01, 0.01, 0.01, 0.01, 0.01, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0, 1.0, 431 1.0, 1.0, 10.0, 10.0, 10.0, 10.0, 10.0, 100.0, 100.0, 100.0, 100.0, 432 100.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0; 433 434 x << 1.25668890405e-26, 1.17549435082e-38, 1.20938905072e-05, 435 1.17549435082e-38, 1.17549435082e-38, 5.66572070696e-16, 436 0.0132865061065, 0.0200034203853, 6.29263709118e-17, 1.37160367764e-06, 437 0.333412038288, 1.18135687766, 0.580629033777, 0.170631439426, 438 0.786686768458, 7.63873279537, 13.1944344379, 11.896042354, 439 10.5830172417, 10.5020942233, 92.8918587747, 95.003720371, 440 86.3715926467, 96.0330217672, 82.6389930677, 968.702906754, 441 969.463546828, 1001.79726022, 955.047416547, 1044.27458568; 442 443 v << -32.7256441441, -36.4394150514, -9.66467612263, -36.4394150514, 444 -36.4394150514, -1.0891900302, -2.66351229645, -2.48666868596, 445 -0.929700494428, -3.56327722764, -0.455320135314, -0.391437214323, 446 -0.491352055991, -0.350454834292, -0.471773162921, -0.104084440522, 447 -0.0723646747909, -0.0992828975532, -0.121638215446, -0.122619605294, 448 -0.0317670267286, -0.0359974812869, -0.0154359225363, -0.0375775365921, 449 -0.00794899153653, -0.00777303219211, -0.00796085782042, 450 -0.0125850719397, -0.00455500206958, -0.00476436993148; 451 452 CALL_SUBTEST(res = igamma_der_a(a, x); verify_component_wise(res, v);); 453 } 454 455 // Test gamma_sample_der_alpha 456 { 457 ArrayType alpha(30); 458 ArrayType sample(30); 459 ArrayType res(30); 460 ArrayType v(30); 461 462 alpha << 0.01, 0.01, 0.01, 0.01, 0.01, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0, 463 1.0, 1.0, 1.0, 10.0, 10.0, 10.0, 10.0, 10.0, 100.0, 100.0, 100.0, 100.0, 464 100.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0; 465 466 sample << 1.25668890405e-26, 1.17549435082e-38, 1.20938905072e-05, 467 1.17549435082e-38, 1.17549435082e-38, 5.66572070696e-16, 468 0.0132865061065, 0.0200034203853, 6.29263709118e-17, 1.37160367764e-06, 469 0.333412038288, 1.18135687766, 0.580629033777, 0.170631439426, 470 0.786686768458, 7.63873279537, 13.1944344379, 11.896042354, 471 10.5830172417, 10.5020942233, 92.8918587747, 95.003720371, 472 86.3715926467, 96.0330217672, 82.6389930677, 968.702906754, 473 969.463546828, 1001.79726022, 955.047416547, 1044.27458568; 474 475 v << 7.42424742367e-23, 1.02004297287e-34, 0.0130155240738, 476 1.02004297287e-34, 1.02004297287e-34, 1.96505168277e-13, 0.525575786243, 477 0.713903991771, 2.32077561808e-14, 0.000179348049886, 0.635500453302, 478 1.27561284917, 0.878125852156, 0.41565819538, 1.03606488534, 479 0.885964824887, 1.16424049334, 1.10764479598, 1.04590810812, 480 1.04193666963, 0.965193152414, 0.976217589464, 0.93008035061, 481 0.98153216096, 0.909196397698, 0.98434963993, 0.984738050206, 482 1.00106492525, 0.97734200649, 1.02198794179; 483 484 CALL_SUBTEST(res = gamma_sample_der_alpha(alpha, sample); 485 verify_component_wise(res, v);); 486 } 487 #endif // EIGEN_HAS_C99_MATH 488 } 489 490 EIGEN_DECLARE_TEST(special_functions) 491 { 492 CALL_SUBTEST_1(array_special_functions<ArrayXf>()); 493 CALL_SUBTEST_2(array_special_functions<ArrayXd>()); 494 // TODO(cantonios): half/bfloat16 don't have enough precision to reproduce results above. 495 // CALL_SUBTEST_3(array_special_functions<ArrayX<Eigen::half>>()); 496 // CALL_SUBTEST_4(array_special_functions<ArrayX<Eigen::bfloat16>>()); 497 }