1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2016 Gael Guennebaud <[email protected]>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10 #include <limits.h>
11 #include "main.h"
12 #include "../Eigen/SpecialFunctions"
13
14 // Hack to allow "implicit" conversions from double to Scalar via comma-initialization.
15 template<typename Derived>
operator <<(Eigen::DenseBase<Derived> & dense,double v)16 Eigen::CommaInitializer<Derived> operator<<(Eigen::DenseBase<Derived>& dense, double v) {
17 return (dense << static_cast<typename Derived::Scalar>(v));
18 }
19
20 template<typename XprType>
operator ,(Eigen::CommaInitializer<XprType> & ci,double v)21 Eigen::CommaInitializer<XprType>& operator,(Eigen::CommaInitializer<XprType>& ci, double v) {
22 return (ci, static_cast<typename XprType::Scalar>(v));
23 }
24
25 template<typename X, typename Y>
verify_component_wise(const X & x,const Y & y)26 void verify_component_wise(const X& x, const Y& y)
27 {
28 for(Index i=0; i<x.size(); ++i)
29 {
30 if((numext::isfinite)(y(i)))
31 VERIFY_IS_APPROX( x(i), y(i) );
32 else if((numext::isnan)(y(i)))
33 VERIFY((numext::isnan)(x(i)));
34 else
35 VERIFY_IS_EQUAL( x(i), y(i) );
36 }
37 }
38
array_special_functions()39 template<typename ArrayType> void array_special_functions()
40 {
41 using std::abs;
42 using std::sqrt;
43 typedef typename ArrayType::Scalar Scalar;
44 typedef typename NumTraits<Scalar>::Real RealScalar;
45
46 Scalar plusinf = std::numeric_limits<Scalar>::infinity();
47 Scalar nan = std::numeric_limits<Scalar>::quiet_NaN();
48
49 Index rows = internal::random<Index>(1,30);
50 Index cols = 1;
51
52 // API
53 {
54 ArrayType m1 = ArrayType::Random(rows,cols);
55 #if EIGEN_HAS_C99_MATH
56 VERIFY_IS_APPROX(m1.lgamma(), lgamma(m1));
57 VERIFY_IS_APPROX(m1.digamma(), digamma(m1));
58 VERIFY_IS_APPROX(m1.erf(), erf(m1));
59 VERIFY_IS_APPROX(m1.erfc(), erfc(m1));
60 #endif // EIGEN_HAS_C99_MATH
61 }
62
63
64 #if EIGEN_HAS_C99_MATH
65 // check special functions (comparing against numpy implementation)
66 if (!NumTraits<Scalar>::IsComplex)
67 {
68
69 {
70 ArrayType m1 = ArrayType::Random(rows,cols);
71 ArrayType m2 = ArrayType::Random(rows,cols);
72
73 // Test various propreties of igamma & igammac. These are normalized
74 // gamma integrals where
75 // igammac(a, x) = Gamma(a, x) / Gamma(a)
76 // igamma(a, x) = gamma(a, x) / Gamma(a)
77 // where Gamma and gamma are considered the standard unnormalized
78 // upper and lower incomplete gamma functions, respectively.
79 ArrayType a = m1.abs() + Scalar(2);
80 ArrayType x = m2.abs() + Scalar(2);
81 ArrayType zero = ArrayType::Zero(rows, cols);
82 ArrayType one = ArrayType::Constant(rows, cols, Scalar(1.0));
83 ArrayType a_m1 = a - one;
84 ArrayType Gamma_a_x = Eigen::igammac(a, x) * a.lgamma().exp();
85 ArrayType Gamma_a_m1_x = Eigen::igammac(a_m1, x) * a_m1.lgamma().exp();
86 ArrayType gamma_a_x = Eigen::igamma(a, x) * a.lgamma().exp();
87 ArrayType gamma_a_m1_x = Eigen::igamma(a_m1, x) * a_m1.lgamma().exp();
88
89
90 // Gamma(a, 0) == Gamma(a)
91 VERIFY_IS_APPROX(Eigen::igammac(a, zero), one);
92
93 // Gamma(a, x) + gamma(a, x) == Gamma(a)
94 VERIFY_IS_APPROX(Gamma_a_x + gamma_a_x, a.lgamma().exp());
95
96 // Gamma(a, x) == (a - 1) * Gamma(a-1, x) + x^(a-1) * exp(-x)
97 VERIFY_IS_APPROX(Gamma_a_x, (a - Scalar(1)) * Gamma_a_m1_x + x.pow(a-Scalar(1)) * (-x).exp());
98
99 // gamma(a, x) == (a - 1) * gamma(a-1, x) - x^(a-1) * exp(-x)
100 VERIFY_IS_APPROX(gamma_a_x, (a - Scalar(1)) * gamma_a_m1_x - x.pow(a-Scalar(1)) * (-x).exp());
101 }
102 {
103 // Verify for large a and x that values are between 0 and 1.
104 ArrayType m1 = ArrayType::Random(rows,cols);
105 ArrayType m2 = ArrayType::Random(rows,cols);
106 int max_exponent = std::numeric_limits<Scalar>::max_exponent10;
107 ArrayType a = m1.abs() * Scalar(pow(10., max_exponent - 1));
108 ArrayType x = m2.abs() * Scalar(pow(10., max_exponent - 1));
109 for (int i = 0; i < a.size(); ++i) {
110 Scalar igam = numext::igamma(a(i), x(i));
111 VERIFY(0 <= igam);
112 VERIFY(igam <= 1);
113 }
114 }
115
116 {
117 // Check exact values of igamma and igammac against a third party calculation.
118 Scalar a_s[] = {Scalar(0), Scalar(1), Scalar(1.5), Scalar(4), Scalar(0.0001), Scalar(1000.5)};
119 Scalar x_s[] = {Scalar(0), Scalar(1), Scalar(1.5), Scalar(4), Scalar(0.0001), Scalar(1000.5)};
120
121 // location i*6+j corresponds to a_s[i], x_s[j].
122 Scalar igamma_s[][6] = {
123 {Scalar(0.0), nan, nan, nan, nan, nan},
124 {Scalar(0.0), Scalar(0.6321205588285578), Scalar(0.7768698398515702),
125 Scalar(0.9816843611112658), Scalar(9.999500016666262e-05),
126 Scalar(1.0)},
127 {Scalar(0.0), Scalar(0.4275932955291202), Scalar(0.608374823728911),
128 Scalar(0.9539882943107686), Scalar(7.522076445089201e-07),
129 Scalar(1.0)},
130 {Scalar(0.0), Scalar(0.01898815687615381),
131 Scalar(0.06564245437845008), Scalar(0.5665298796332909),
132 Scalar(4.166333347221828e-18), Scalar(1.0)},
133 {Scalar(0.0), Scalar(0.9999780593618628), Scalar(0.9999899967080838),
134 Scalar(0.9999996219837988), Scalar(0.9991370418689945), Scalar(1.0)},
135 {Scalar(0.0), Scalar(0.0), Scalar(0.0), Scalar(0.0), Scalar(0.0),
136 Scalar(0.5042041932513908)}};
137 Scalar igammac_s[][6] = {
138 {nan, nan, nan, nan, nan, nan},
139 {Scalar(1.0), Scalar(0.36787944117144233),
140 Scalar(0.22313016014842982), Scalar(0.018315638888734182),
141 Scalar(0.9999000049998333), Scalar(0.0)},
142 {Scalar(1.0), Scalar(0.5724067044708798), Scalar(0.3916251762710878),
143 Scalar(0.04601170568923136), Scalar(0.9999992477923555),
144 Scalar(0.0)},
145 {Scalar(1.0), Scalar(0.9810118431238462), Scalar(0.9343575456215499),
146 Scalar(0.4334701203667089), Scalar(1.0), Scalar(0.0)},
147 {Scalar(1.0), Scalar(2.1940638138146658e-05),
148 Scalar(1.0003291916285e-05), Scalar(3.7801620118431334e-07),
149 Scalar(0.0008629581310054535), Scalar(0.0)},
150 {Scalar(1.0), Scalar(1.0), Scalar(1.0), Scalar(1.0), Scalar(1.0),
151 Scalar(0.49579580674813944)}};
152
153 for (int i = 0; i < 6; ++i) {
154 for (int j = 0; j < 6; ++j) {
155 if ((std::isnan)(igamma_s[i][j])) {
156 VERIFY((std::isnan)(numext::igamma(a_s[i], x_s[j])));
157 } else {
158 VERIFY_IS_APPROX(numext::igamma(a_s[i], x_s[j]), igamma_s[i][j]);
159 }
160
161 if ((std::isnan)(igammac_s[i][j])) {
162 VERIFY((std::isnan)(numext::igammac(a_s[i], x_s[j])));
163 } else {
164 VERIFY_IS_APPROX(numext::igammac(a_s[i], x_s[j]), igammac_s[i][j]);
165 }
166 }
167 }
168 }
169 }
170 #endif // EIGEN_HAS_C99_MATH
171
172 // Check the ndtri function against scipy.special.ndtri
173 {
174 ArrayType x(7), res(7), ref(7);
175 x << 0.5, 0.2, 0.8, 0.9, 0.1, 0.99, 0.01;
176 ref << 0., -0.8416212335729142, 0.8416212335729142, 1.2815515655446004, -1.2815515655446004, 2.3263478740408408, -2.3263478740408408;
177 CALL_SUBTEST( verify_component_wise(ref, ref); );
178 CALL_SUBTEST( res = x.ndtri(); verify_component_wise(res, ref); );
179 CALL_SUBTEST( res = ndtri(x); verify_component_wise(res, ref); );
180
181 // ndtri(normal_cdf(x)) ~= x
182 CALL_SUBTEST(
183 ArrayType m1 = ArrayType::Random(32);
184 using std::sqrt;
185
186 ArrayType cdf_val = (m1 / Scalar(sqrt(2.))).erf();
187 cdf_val = (cdf_val + Scalar(1)) / Scalar(2);
188 verify_component_wise(cdf_val.ndtri(), m1););
189
190 }
191
192 // Check the zeta function against scipy.special.zeta
193 {
194 ArrayType x(10), q(10), res(10), ref(10);
195 x << 1.5, 4, 10.5, 10000.5, 3, 1, 0.9, 2, 3, 4;
196 q << 2, 1.5, 3, 1.0001, -2.5, 1.2345, 1.2345, -1, -2, -3;
197 ref << 1.61237534869, 0.234848505667, 1.03086757337e-5, 0.367879440865, 0.054102025820864097, plusinf, nan, plusinf, nan, plusinf;
198 CALL_SUBTEST( verify_component_wise(ref, ref); );
199 CALL_SUBTEST( res = x.zeta(q); verify_component_wise(res, ref); );
200 CALL_SUBTEST( res = zeta(x,q); verify_component_wise(res, ref); );
201 }
202
203 // digamma
204 {
205 ArrayType x(9), res(9), ref(9);
206 x << 1, 1.5, 4, -10.5, 10000.5, 0, -1, -2, -3;
207 ref << -0.5772156649015329, 0.03648997397857645, 1.2561176684318, 2.398239129535781, 9.210340372392849, nan, nan, nan, nan;
208 CALL_SUBTEST( verify_component_wise(ref, ref); );
209
210 CALL_SUBTEST( res = x.digamma(); verify_component_wise(res, ref); );
211 CALL_SUBTEST( res = digamma(x); verify_component_wise(res, ref); );
212 }
213
214 #if EIGEN_HAS_C99_MATH
215 {
216 ArrayType n(16), x(16), res(16), ref(16);
217 n << 1, 1, 1, 1.5, 17, 31, 28, 8, 42, 147, 170, -1, 0, 1, 2, 3;
218 x << 2, 3, 25.5, 1.5, 4.7, 11.8, 17.7, 30.2, 15.8, 54.1, 64, -1, -2, -3, -4, -5;
219 ref << 0.644934066848, 0.394934066848, 0.0399946696496, nan, 293.334565435, 0.445487887616, -2.47810300902e-07, -8.29668781082e-09, -0.434562276666, 0.567742190178, -0.0108615497927, nan, nan, plusinf, nan, plusinf;
220 CALL_SUBTEST( verify_component_wise(ref, ref); );
221
222 if(sizeof(RealScalar)>=8) { // double
223 // Reason for commented line: http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1232
224 // CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res, ref); );
225 CALL_SUBTEST( res = polygamma(n,x); verify_component_wise(res, ref); );
226 }
227 else {
228 // CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res.head(8), ref.head(8)); );
229 CALL_SUBTEST( res = polygamma(n,x); verify_component_wise(res.head(8), ref.head(8)); );
230 }
231 }
232 #endif
233
234 #if EIGEN_HAS_C99_MATH
235 {
236 // Inputs and ground truth generated with scipy via:
237 // a = np.logspace(-3, 3, 5) - 1e-3
238 // b = np.logspace(-3, 3, 5) - 1e-3
239 // x = np.linspace(-0.1, 1.1, 5)
240 // (full_a, full_b, full_x) = np.vectorize(lambda a, b, x: (a, b, x))(*np.ix_(a, b, x))
241 // full_a = full_a.flatten().tolist() # same for full_b, full_x
242 // v = scipy.special.betainc(full_a, full_b, full_x).flatten().tolist()
243 //
244 // Note in Eigen, we call betainc with arguments in the order (x, a, b).
245 ArrayType a(125);
246 ArrayType b(125);
247 ArrayType x(125);
248 ArrayType v(125);
249 ArrayType res(125);
250
251 a << 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
252 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
253 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
254 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
255 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
256 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
257 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
258 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
259 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
260 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
261 0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
262 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
263 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
264 31.62177660168379, 31.62177660168379, 31.62177660168379,
265 31.62177660168379, 31.62177660168379, 31.62177660168379,
266 31.62177660168379, 31.62177660168379, 31.62177660168379,
267 31.62177660168379, 31.62177660168379, 31.62177660168379,
268 31.62177660168379, 31.62177660168379, 31.62177660168379,
269 31.62177660168379, 31.62177660168379, 31.62177660168379,
270 31.62177660168379, 31.62177660168379, 31.62177660168379,
271 31.62177660168379, 31.62177660168379, 31.62177660168379,
272 31.62177660168379, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
273 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
274 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
275 999.999, 999.999, 999.999;
276
277 b << 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379, 0.03062277660168379,
278 0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 0.999,
279 0.999, 0.999, 0.999, 0.999, 31.62177660168379, 31.62177660168379,
280 31.62177660168379, 31.62177660168379, 31.62177660168379, 999.999,
281 999.999, 999.999, 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0,
282 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
283 0.03062277660168379, 0.03062277660168379, 0.999, 0.999, 0.999, 0.999,
284 0.999, 31.62177660168379, 31.62177660168379, 31.62177660168379,
285 31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
286 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
287 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
288 0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
289 31.62177660168379, 31.62177660168379, 31.62177660168379,
290 31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
291 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
292 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
293 0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
294 31.62177660168379, 31.62177660168379, 31.62177660168379,
295 31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
296 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
297 0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
298 0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
299 31.62177660168379, 31.62177660168379, 31.62177660168379,
300 31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
301 999.999, 999.999;
302
303 x << -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
304 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2,
305 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1,
306 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1,
307 -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8,
308 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
309 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2,
310 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1,
311 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
312 0.8, 1.1;
313
314 v << nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
315 nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
316 nan, nan, nan, 0.47972119876364683, 0.5, 0.5202788012363533, nan, nan,
317 0.9518683957740043, 0.9789663010413743, 0.9931729188073435, nan, nan,
318 0.999995949033062, 0.9999999999993698, 0.9999999999999999, nan, nan,
319 0.9999999999999999, 0.9999999999999999, 0.9999999999999999, nan, nan,
320 nan, nan, nan, nan, nan, 0.006827081192655869, 0.0210336989586256,
321 0.04813160422599567, nan, nan, 0.20014344256217678, 0.5000000000000001,
322 0.7998565574378232, nan, nan, 0.9991401428435834, 0.999999999698403,
323 0.9999999999999999, nan, nan, 0.9999999999999999, 0.9999999999999999,
324 0.9999999999999999, nan, nan, nan, nan, nan, nan, nan,
325 1.0646600232370887e-25, 6.301722877826246e-13, 4.050966937974938e-06,
326 nan, nan, 7.864342668429763e-23, 3.015969667594166e-10,
327 0.0008598571564165444, nan, nan, 6.031987710123844e-08,
328 0.5000000000000007, 0.9999999396801229, nan, nan, 0.9999999999999999,
329 0.9999999999999999, 0.9999999999999999, nan, nan, nan, nan, nan, nan,
330 nan, 0.0, 7.029920380986636e-306, 2.2450728208591345e-101, nan, nan,
331 0.0, 9.275871147869727e-302, 1.2232913026152827e-97, nan, nan, 0.0,
332 3.0891393081932924e-252, 2.9303043666183996e-60, nan, nan,
333 2.248913486879199e-196, 0.5000000000004947, 0.9999999999999999, nan;
334
335 CALL_SUBTEST(res = betainc(a, b, x);
336 verify_component_wise(res, v););
337 }
338
339 // Test various properties of betainc
340 {
341 ArrayType m1 = ArrayType::Random(32);
342 ArrayType m2 = ArrayType::Random(32);
343 ArrayType m3 = ArrayType::Random(32);
344 ArrayType one = ArrayType::Constant(32, Scalar(1.0));
345 const Scalar eps = std::numeric_limits<Scalar>::epsilon();
346 ArrayType a = (m1 * Scalar(4)).exp();
347 ArrayType b = (m2 * Scalar(4)).exp();
348 ArrayType x = m3.abs();
349
350 // betainc(a, 1, x) == x**a
351 CALL_SUBTEST(
352 ArrayType test = betainc(a, one, x);
353 ArrayType expected = x.pow(a);
354 verify_component_wise(test, expected););
355
356 // betainc(1, b, x) == 1 - (1 - x)**b
357 CALL_SUBTEST(
358 ArrayType test = betainc(one, b, x);
359 ArrayType expected = one - (one - x).pow(b);
360 verify_component_wise(test, expected););
361
362 // betainc(a, b, x) == 1 - betainc(b, a, 1-x)
363 CALL_SUBTEST(
364 ArrayType test = betainc(a, b, x) + betainc(b, a, one - x);
365 ArrayType expected = one;
366 verify_component_wise(test, expected););
367
368 // betainc(a+1, b, x) = betainc(a, b, x) - x**a * (1 - x)**b / (a * beta(a, b))
369 CALL_SUBTEST(
370 ArrayType num = x.pow(a) * (one - x).pow(b);
371 ArrayType denom = a * (a.lgamma() + b.lgamma() - (a + b).lgamma()).exp();
372 // Add eps to rhs and lhs so that component-wise test doesn't result in
373 // nans when both outputs are zeros.
374 ArrayType expected = betainc(a, b, x) - num / denom + eps;
375 ArrayType test = betainc(a + one, b, x) + eps;
376 if (sizeof(Scalar) >= 8) { // double
377 verify_component_wise(test, expected);
378 } else {
379 // Reason for limited test: http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1232
380 verify_component_wise(test.head(8), expected.head(8));
381 });
382
383 // betainc(a, b+1, x) = betainc(a, b, x) + x**a * (1 - x)**b / (b * beta(a, b))
384 CALL_SUBTEST(
385 // Add eps to rhs and lhs so that component-wise test doesn't result in
386 // nans when both outputs are zeros.
387 ArrayType num = x.pow(a) * (one - x).pow(b);
388 ArrayType denom = b * (a.lgamma() + b.lgamma() - (a + b).lgamma()).exp();
389 ArrayType expected = betainc(a, b, x) + num / denom + eps;
390 ArrayType test = betainc(a, b + one, x) + eps;
391 verify_component_wise(test, expected););
392 }
393 #endif // EIGEN_HAS_C99_MATH
394
395 /* Code to generate the data for the following two test cases.
396 N = 5
397 np.random.seed(3)
398
399 a = np.logspace(-2, 3, 6)
400 a = np.ravel(np.tile(np.reshape(a, [-1, 1]), [1, N]))
401 x = np.random.gamma(a, 1.0)
402 x = np.maximum(x, np.finfo(np.float32).tiny)
403
404 def igamma(a, x):
405 return mpmath.gammainc(a, 0, x, regularized=True)
406
407 def igamma_der_a(a, x):
408 res = mpmath.diff(lambda a_prime: igamma(a_prime, x), a)
409 return np.float64(res)
410
411 def gamma_sample_der_alpha(a, x):
412 igamma_x = igamma(a, x)
413 def igammainv_of_igamma(a_prime):
414 return mpmath.findroot(lambda x_prime: igamma(a_prime, x_prime) -
415 igamma_x, x, solver='newton')
416 return np.float64(mpmath.diff(igammainv_of_igamma, a))
417
418 v_igamma_der_a = np.vectorize(igamma_der_a)(a, x)
419 v_gamma_sample_der_alpha = np.vectorize(gamma_sample_der_alpha)(a, x)
420 */
421
422 #if EIGEN_HAS_C99_MATH
423 // Test igamma_der_a
424 {
425 ArrayType a(30);
426 ArrayType x(30);
427 ArrayType res(30);
428 ArrayType v(30);
429
430 a << 0.01, 0.01, 0.01, 0.01, 0.01, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0, 1.0,
431 1.0, 1.0, 10.0, 10.0, 10.0, 10.0, 10.0, 100.0, 100.0, 100.0, 100.0,
432 100.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0;
433
434 x << 1.25668890405e-26, 1.17549435082e-38, 1.20938905072e-05,
435 1.17549435082e-38, 1.17549435082e-38, 5.66572070696e-16,
436 0.0132865061065, 0.0200034203853, 6.29263709118e-17, 1.37160367764e-06,
437 0.333412038288, 1.18135687766, 0.580629033777, 0.170631439426,
438 0.786686768458, 7.63873279537, 13.1944344379, 11.896042354,
439 10.5830172417, 10.5020942233, 92.8918587747, 95.003720371,
440 86.3715926467, 96.0330217672, 82.6389930677, 968.702906754,
441 969.463546828, 1001.79726022, 955.047416547, 1044.27458568;
442
443 v << -32.7256441441, -36.4394150514, -9.66467612263, -36.4394150514,
444 -36.4394150514, -1.0891900302, -2.66351229645, -2.48666868596,
445 -0.929700494428, -3.56327722764, -0.455320135314, -0.391437214323,
446 -0.491352055991, -0.350454834292, -0.471773162921, -0.104084440522,
447 -0.0723646747909, -0.0992828975532, -0.121638215446, -0.122619605294,
448 -0.0317670267286, -0.0359974812869, -0.0154359225363, -0.0375775365921,
449 -0.00794899153653, -0.00777303219211, -0.00796085782042,
450 -0.0125850719397, -0.00455500206958, -0.00476436993148;
451
452 CALL_SUBTEST(res = igamma_der_a(a, x); verify_component_wise(res, v););
453 }
454
455 // Test gamma_sample_der_alpha
456 {
457 ArrayType alpha(30);
458 ArrayType sample(30);
459 ArrayType res(30);
460 ArrayType v(30);
461
462 alpha << 0.01, 0.01, 0.01, 0.01, 0.01, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0,
463 1.0, 1.0, 1.0, 10.0, 10.0, 10.0, 10.0, 10.0, 100.0, 100.0, 100.0, 100.0,
464 100.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0;
465
466 sample << 1.25668890405e-26, 1.17549435082e-38, 1.20938905072e-05,
467 1.17549435082e-38, 1.17549435082e-38, 5.66572070696e-16,
468 0.0132865061065, 0.0200034203853, 6.29263709118e-17, 1.37160367764e-06,
469 0.333412038288, 1.18135687766, 0.580629033777, 0.170631439426,
470 0.786686768458, 7.63873279537, 13.1944344379, 11.896042354,
471 10.5830172417, 10.5020942233, 92.8918587747, 95.003720371,
472 86.3715926467, 96.0330217672, 82.6389930677, 968.702906754,
473 969.463546828, 1001.79726022, 955.047416547, 1044.27458568;
474
475 v << 7.42424742367e-23, 1.02004297287e-34, 0.0130155240738,
476 1.02004297287e-34, 1.02004297287e-34, 1.96505168277e-13, 0.525575786243,
477 0.713903991771, 2.32077561808e-14, 0.000179348049886, 0.635500453302,
478 1.27561284917, 0.878125852156, 0.41565819538, 1.03606488534,
479 0.885964824887, 1.16424049334, 1.10764479598, 1.04590810812,
480 1.04193666963, 0.965193152414, 0.976217589464, 0.93008035061,
481 0.98153216096, 0.909196397698, 0.98434963993, 0.984738050206,
482 1.00106492525, 0.97734200649, 1.02198794179;
483
484 CALL_SUBTEST(res = gamma_sample_der_alpha(alpha, sample);
485 verify_component_wise(res, v););
486 }
487 #endif // EIGEN_HAS_C99_MATH
488 }
489
// Test entry point: runs array_special_functions for single precision
// (sub-test 1) and double precision (sub-test 2) dynamic arrays.
EIGEN_DECLARE_TEST(special_functions)
{
  CALL_SUBTEST_1( array_special_functions<ArrayXf>() );
  CALL_SUBTEST_2( array_special_functions<ArrayXd>() );

  // Disabled: the reference data above is tighter than what Eigen::half and
  // Eigen::bfloat16 can reproduce.
  // TODO(cantonios): half/bfloat16 don't have enough precision to reproduce results above.
  // CALL_SUBTEST_3(array_special_functions<ArrayX<Eigen::half>>());
  // CALL_SUBTEST_4(array_special_functions<ArrayX<Eigen::bfloat16>>());
}
498