xref: /aosp_15_r20/external/eigen/unsupported/test/special_functions.cpp (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2016 Gael Guennebaud <[email protected]>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #include <limits.h>
11 #include "main.h"
12 #include "../Eigen/SpecialFunctions"
13 
14 // Hack to allow "implicit" conversions from double to Scalar via comma-initialization.
15 template<typename Derived>
operator <<(Eigen::DenseBase<Derived> & dense,double v)16 Eigen::CommaInitializer<Derived> operator<<(Eigen::DenseBase<Derived>& dense, double v) {
17   return (dense << static_cast<typename Derived::Scalar>(v));
18 }
19 
20 template<typename XprType>
operator ,(Eigen::CommaInitializer<XprType> & ci,double v)21 Eigen::CommaInitializer<XprType>& operator,(Eigen::CommaInitializer<XprType>& ci, double v) {
22   return (ci, static_cast<typename XprType::Scalar>(v));
23 }
24 
25 template<typename X, typename Y>
verify_component_wise(const X & x,const Y & y)26 void verify_component_wise(const X& x, const Y& y)
27 {
28   for(Index i=0; i<x.size(); ++i)
29   {
30     if((numext::isfinite)(y(i)))
31       VERIFY_IS_APPROX( x(i), y(i) );
32     else if((numext::isnan)(y(i)))
33       VERIFY((numext::isnan)(x(i)));
34     else
35       VERIFY_IS_EQUAL( x(i), y(i) );
36   }
37 }
38 
array_special_functions()39 template<typename ArrayType> void array_special_functions()
40 {
41   using std::abs;
42   using std::sqrt;
43   typedef typename ArrayType::Scalar Scalar;
44   typedef typename NumTraits<Scalar>::Real RealScalar;
45 
46   Scalar plusinf = std::numeric_limits<Scalar>::infinity();
47   Scalar nan = std::numeric_limits<Scalar>::quiet_NaN();
48 
49   Index rows = internal::random<Index>(1,30);
50   Index cols = 1;
51 
52   // API
53   {
54     ArrayType m1 = ArrayType::Random(rows,cols);
55 #if EIGEN_HAS_C99_MATH
56     VERIFY_IS_APPROX(m1.lgamma(), lgamma(m1));
57     VERIFY_IS_APPROX(m1.digamma(), digamma(m1));
58     VERIFY_IS_APPROX(m1.erf(), erf(m1));
59     VERIFY_IS_APPROX(m1.erfc(), erfc(m1));
60 #endif  // EIGEN_HAS_C99_MATH
61   }
62 
63 
64 #if EIGEN_HAS_C99_MATH
65   // check special functions (comparing against numpy implementation)
66   if (!NumTraits<Scalar>::IsComplex)
67   {
68 
69     {
70       ArrayType m1 = ArrayType::Random(rows,cols);
71       ArrayType m2 = ArrayType::Random(rows,cols);
72 
73       // Test various propreties of igamma & igammac.  These are normalized
74       // gamma integrals where
75       //   igammac(a, x) = Gamma(a, x) / Gamma(a)
76       //   igamma(a, x) = gamma(a, x) / Gamma(a)
77       // where Gamma and gamma are considered the standard unnormalized
78       // upper and lower incomplete gamma functions, respectively.
79       ArrayType a = m1.abs() + Scalar(2);
80       ArrayType x = m2.abs() + Scalar(2);
81       ArrayType zero = ArrayType::Zero(rows, cols);
82       ArrayType one = ArrayType::Constant(rows, cols, Scalar(1.0));
83       ArrayType a_m1 = a - one;
84       ArrayType Gamma_a_x = Eigen::igammac(a, x) * a.lgamma().exp();
85       ArrayType Gamma_a_m1_x = Eigen::igammac(a_m1, x) * a_m1.lgamma().exp();
86       ArrayType gamma_a_x = Eigen::igamma(a, x) * a.lgamma().exp();
87       ArrayType gamma_a_m1_x = Eigen::igamma(a_m1, x) * a_m1.lgamma().exp();
88 
89 
90       // Gamma(a, 0) == Gamma(a)
91       VERIFY_IS_APPROX(Eigen::igammac(a, zero), one);
92 
93       // Gamma(a, x) + gamma(a, x) == Gamma(a)
94       VERIFY_IS_APPROX(Gamma_a_x + gamma_a_x, a.lgamma().exp());
95 
96       // Gamma(a, x) == (a - 1) * Gamma(a-1, x) + x^(a-1) * exp(-x)
97       VERIFY_IS_APPROX(Gamma_a_x, (a - Scalar(1)) * Gamma_a_m1_x + x.pow(a-Scalar(1)) * (-x).exp());
98 
99       // gamma(a, x) == (a - 1) * gamma(a-1, x) - x^(a-1) * exp(-x)
100       VERIFY_IS_APPROX(gamma_a_x, (a - Scalar(1)) * gamma_a_m1_x - x.pow(a-Scalar(1)) * (-x).exp());
101     }
102     {
103       // Verify for large a and x that values are between 0 and 1.
104       ArrayType m1 = ArrayType::Random(rows,cols);
105       ArrayType m2 = ArrayType::Random(rows,cols);
106       int max_exponent = std::numeric_limits<Scalar>::max_exponent10;
107       ArrayType a = m1.abs() *  Scalar(pow(10., max_exponent - 1));
108       ArrayType x = m2.abs() *  Scalar(pow(10., max_exponent - 1));
109       for (int i = 0; i < a.size(); ++i) {
110         Scalar igam = numext::igamma(a(i), x(i));
111         VERIFY(0 <= igam);
112         VERIFY(igam <= 1);
113       }
114     }
115 
116     {
117       // Check exact values of igamma and igammac against a third party calculation.
118       Scalar a_s[] = {Scalar(0), Scalar(1), Scalar(1.5), Scalar(4), Scalar(0.0001), Scalar(1000.5)};
119       Scalar x_s[] = {Scalar(0), Scalar(1), Scalar(1.5), Scalar(4), Scalar(0.0001), Scalar(1000.5)};
120 
121       // location i*6+j corresponds to a_s[i], x_s[j].
122       Scalar igamma_s[][6] = {
123           {Scalar(0.0), nan, nan, nan, nan, nan},
124           {Scalar(0.0), Scalar(0.6321205588285578), Scalar(0.7768698398515702),
125            Scalar(0.9816843611112658), Scalar(9.999500016666262e-05),
126            Scalar(1.0)},
127           {Scalar(0.0), Scalar(0.4275932955291202), Scalar(0.608374823728911),
128            Scalar(0.9539882943107686), Scalar(7.522076445089201e-07),
129            Scalar(1.0)},
130           {Scalar(0.0), Scalar(0.01898815687615381),
131            Scalar(0.06564245437845008), Scalar(0.5665298796332909),
132            Scalar(4.166333347221828e-18), Scalar(1.0)},
133           {Scalar(0.0), Scalar(0.9999780593618628), Scalar(0.9999899967080838),
134            Scalar(0.9999996219837988), Scalar(0.9991370418689945), Scalar(1.0)},
135           {Scalar(0.0), Scalar(0.0), Scalar(0.0), Scalar(0.0), Scalar(0.0),
136            Scalar(0.5042041932513908)}};
137       Scalar igammac_s[][6] = {
138           {nan, nan, nan, nan, nan, nan},
139           {Scalar(1.0), Scalar(0.36787944117144233),
140            Scalar(0.22313016014842982), Scalar(0.018315638888734182),
141            Scalar(0.9999000049998333), Scalar(0.0)},
142           {Scalar(1.0), Scalar(0.5724067044708798), Scalar(0.3916251762710878),
143            Scalar(0.04601170568923136), Scalar(0.9999992477923555),
144            Scalar(0.0)},
145           {Scalar(1.0), Scalar(0.9810118431238462), Scalar(0.9343575456215499),
146            Scalar(0.4334701203667089), Scalar(1.0), Scalar(0.0)},
147           {Scalar(1.0), Scalar(2.1940638138146658e-05),
148            Scalar(1.0003291916285e-05), Scalar(3.7801620118431334e-07),
149            Scalar(0.0008629581310054535), Scalar(0.0)},
150           {Scalar(1.0), Scalar(1.0), Scalar(1.0), Scalar(1.0), Scalar(1.0),
151            Scalar(0.49579580674813944)}};
152 
153       for (int i = 0; i < 6; ++i) {
154         for (int j = 0; j < 6; ++j) {
155           if ((std::isnan)(igamma_s[i][j])) {
156             VERIFY((std::isnan)(numext::igamma(a_s[i], x_s[j])));
157           } else {
158             VERIFY_IS_APPROX(numext::igamma(a_s[i], x_s[j]), igamma_s[i][j]);
159           }
160 
161           if ((std::isnan)(igammac_s[i][j])) {
162             VERIFY((std::isnan)(numext::igammac(a_s[i], x_s[j])));
163           } else {
164             VERIFY_IS_APPROX(numext::igammac(a_s[i], x_s[j]), igammac_s[i][j]);
165           }
166         }
167       }
168     }
169   }
170 #endif  // EIGEN_HAS_C99_MATH
171 
172   // Check the ndtri function against scipy.special.ndtri
173   {
174     ArrayType x(7), res(7), ref(7);
175     x << 0.5, 0.2, 0.8, 0.9, 0.1, 0.99, 0.01;
176     ref << 0., -0.8416212335729142, 0.8416212335729142, 1.2815515655446004, -1.2815515655446004, 2.3263478740408408, -2.3263478740408408;
177     CALL_SUBTEST( verify_component_wise(ref, ref); );
178     CALL_SUBTEST( res = x.ndtri(); verify_component_wise(res, ref); );
179     CALL_SUBTEST( res = ndtri(x); verify_component_wise(res, ref); );
180 
181     // ndtri(normal_cdf(x)) ~= x
182     CALL_SUBTEST(
183         ArrayType m1 = ArrayType::Random(32);
184         using std::sqrt;
185 
186         ArrayType cdf_val = (m1 / Scalar(sqrt(2.))).erf();
187         cdf_val = (cdf_val + Scalar(1)) / Scalar(2);
188         verify_component_wise(cdf_val.ndtri(), m1););
189 
190   }
191 
192   // Check the zeta function against scipy.special.zeta
193   {
194     ArrayType x(10), q(10), res(10), ref(10);
195     x << 1.5,   4, 10.5, 10000.5,    3,      1,    0.9,  2,  3,  4;
196     q <<   2, 1.5,    3,  1.0001, -2.5, 1.2345, 1.2345, -1, -2, -3;
197     ref << 1.61237534869, 0.234848505667, 1.03086757337e-5, 0.367879440865, 0.054102025820864097, plusinf, nan, plusinf, nan, plusinf;
198     CALL_SUBTEST( verify_component_wise(ref, ref); );
199     CALL_SUBTEST( res = x.zeta(q); verify_component_wise(res, ref); );
200     CALL_SUBTEST( res = zeta(x,q); verify_component_wise(res, ref); );
201   }
202 
203   // digamma
204   {
205     ArrayType x(9), res(9), ref(9);
206     x << 1, 1.5, 4, -10.5, 10000.5, 0, -1, -2, -3;
207     ref << -0.5772156649015329, 0.03648997397857645, 1.2561176684318, 2.398239129535781, 9.210340372392849, nan, nan, nan, nan;
208     CALL_SUBTEST( verify_component_wise(ref, ref); );
209 
210     CALL_SUBTEST( res = x.digamma(); verify_component_wise(res, ref); );
211     CALL_SUBTEST( res = digamma(x);  verify_component_wise(res, ref); );
212   }
213 
214 #if EIGEN_HAS_C99_MATH
215   {
216     ArrayType n(16), x(16), res(16), ref(16);
217     n << 1, 1,    1, 1.5,   17,   31,   28,    8,   42,  147, 170, -1,  0,  1,  2,  3;
218     x << 2, 3, 25.5, 1.5,  4.7, 11.8, 17.7, 30.2, 15.8, 54.1,  64, -1, -2, -3, -4, -5;
219     ref << 0.644934066848, 0.394934066848, 0.0399946696496, nan, 293.334565435, 0.445487887616, -2.47810300902e-07, -8.29668781082e-09, -0.434562276666, 0.567742190178, -0.0108615497927, nan, nan, plusinf, nan, plusinf;
220     CALL_SUBTEST( verify_component_wise(ref, ref); );
221 
222     if(sizeof(RealScalar)>=8) {  // double
223       // Reason for commented line: http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1232
224       //       CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res, ref); );
225       CALL_SUBTEST( res = polygamma(n,x);  verify_component_wise(res, ref); );
226     }
227     else {
228       //       CALL_SUBTEST( res = x.polygamma(n); verify_component_wise(res.head(8), ref.head(8)); );
229       CALL_SUBTEST( res = polygamma(n,x); verify_component_wise(res.head(8), ref.head(8)); );
230     }
231   }
232 #endif
233 
234 #if EIGEN_HAS_C99_MATH
235   {
236     // Inputs and ground truth generated with scipy via:
237     //   a = np.logspace(-3, 3, 5) - 1e-3
238     //   b = np.logspace(-3, 3, 5) - 1e-3
239     //   x = np.linspace(-0.1, 1.1, 5)
240     //   (full_a, full_b, full_x) = np.vectorize(lambda a, b, x: (a, b, x))(*np.ix_(a, b, x))
241     //   full_a = full_a.flatten().tolist()  # same for full_b, full_x
242     //   v = scipy.special.betainc(full_a, full_b, full_x).flatten().tolist()
243     //
244     // Note in Eigen, we call betainc with arguments in the order (x, a, b).
245     ArrayType a(125);
246     ArrayType b(125);
247     ArrayType x(125);
248     ArrayType v(125);
249     ArrayType res(125);
250 
251     a << 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
252         0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
253         0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
254         0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
255         0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
256         0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
257         0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
258         0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
259         0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
260         0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
261         0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
262         0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
263         0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999, 0.999,
264         31.62177660168379, 31.62177660168379, 31.62177660168379,
265         31.62177660168379, 31.62177660168379, 31.62177660168379,
266         31.62177660168379, 31.62177660168379, 31.62177660168379,
267         31.62177660168379, 31.62177660168379, 31.62177660168379,
268         31.62177660168379, 31.62177660168379, 31.62177660168379,
269         31.62177660168379, 31.62177660168379, 31.62177660168379,
270         31.62177660168379, 31.62177660168379, 31.62177660168379,
271         31.62177660168379, 31.62177660168379, 31.62177660168379,
272         31.62177660168379, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
273         999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
274         999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999, 999.999,
275         999.999, 999.999, 999.999;
276 
277     b << 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379, 0.03062277660168379,
278         0.03062277660168379, 0.03062277660168379, 0.03062277660168379, 0.999,
279         0.999, 0.999, 0.999, 0.999, 31.62177660168379, 31.62177660168379,
280         31.62177660168379, 31.62177660168379, 31.62177660168379, 999.999,
281         999.999, 999.999, 999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0,
282         0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
283         0.03062277660168379, 0.03062277660168379, 0.999, 0.999, 0.999, 0.999,
284         0.999, 31.62177660168379, 31.62177660168379, 31.62177660168379,
285         31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
286         999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
287         0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
288         0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
289         31.62177660168379, 31.62177660168379, 31.62177660168379,
290         31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
291         999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
292         0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
293         0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
294         31.62177660168379, 31.62177660168379, 31.62177660168379,
295         31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
296         999.999, 999.999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03062277660168379,
297         0.03062277660168379, 0.03062277660168379, 0.03062277660168379,
298         0.03062277660168379, 0.999, 0.999, 0.999, 0.999, 0.999,
299         31.62177660168379, 31.62177660168379, 31.62177660168379,
300         31.62177660168379, 31.62177660168379, 999.999, 999.999, 999.999,
301         999.999, 999.999;
302 
303     x << -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
304         0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2,
305         0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1,
306         0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1,
307         -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8,
308         1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
309         0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2,
310         0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1,
311         0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5, 0.8, 1.1, -0.1, 0.2, 0.5,
312         0.8, 1.1;
313 
314     v << nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
315         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
316         nan, nan, nan, 0.47972119876364683, 0.5, 0.5202788012363533, nan, nan,
317         0.9518683957740043, 0.9789663010413743, 0.9931729188073435, nan, nan,
318         0.999995949033062, 0.9999999999993698, 0.9999999999999999, nan, nan,
319         0.9999999999999999, 0.9999999999999999, 0.9999999999999999, nan, nan,
320         nan, nan, nan, nan, nan, 0.006827081192655869, 0.0210336989586256,
321         0.04813160422599567, nan, nan, 0.20014344256217678, 0.5000000000000001,
322         0.7998565574378232, nan, nan, 0.9991401428435834, 0.999999999698403,
323         0.9999999999999999, nan, nan, 0.9999999999999999, 0.9999999999999999,
324         0.9999999999999999, nan, nan, nan, nan, nan, nan, nan,
325         1.0646600232370887e-25, 6.301722877826246e-13, 4.050966937974938e-06,
326         nan, nan, 7.864342668429763e-23, 3.015969667594166e-10,
327         0.0008598571564165444, nan, nan, 6.031987710123844e-08,
328         0.5000000000000007, 0.9999999396801229, nan, nan, 0.9999999999999999,
329         0.9999999999999999, 0.9999999999999999, nan, nan, nan, nan, nan, nan,
330         nan, 0.0, 7.029920380986636e-306, 2.2450728208591345e-101, nan, nan,
331         0.0, 9.275871147869727e-302, 1.2232913026152827e-97, nan, nan, 0.0,
332         3.0891393081932924e-252, 2.9303043666183996e-60, nan, nan,
333         2.248913486879199e-196, 0.5000000000004947, 0.9999999999999999, nan;
334 
335     CALL_SUBTEST(res = betainc(a, b, x);
336                  verify_component_wise(res, v););
337   }
338 
339   // Test various properties of betainc
340   {
341     ArrayType m1 = ArrayType::Random(32);
342     ArrayType m2 = ArrayType::Random(32);
343     ArrayType m3 = ArrayType::Random(32);
344     ArrayType one = ArrayType::Constant(32, Scalar(1.0));
345     const Scalar eps = std::numeric_limits<Scalar>::epsilon();
346     ArrayType a = (m1 * Scalar(4)).exp();
347     ArrayType b = (m2 * Scalar(4)).exp();
348     ArrayType x = m3.abs();
349 
350     // betainc(a, 1, x) == x**a
351     CALL_SUBTEST(
352         ArrayType test = betainc(a, one, x);
353         ArrayType expected = x.pow(a);
354         verify_component_wise(test, expected););
355 
356     // betainc(1, b, x) == 1 - (1 - x)**b
357     CALL_SUBTEST(
358         ArrayType test = betainc(one, b, x);
359         ArrayType expected = one - (one - x).pow(b);
360         verify_component_wise(test, expected););
361 
362     // betainc(a, b, x) == 1 - betainc(b, a, 1-x)
363     CALL_SUBTEST(
364         ArrayType test = betainc(a, b, x) + betainc(b, a, one - x);
365         ArrayType expected = one;
366         verify_component_wise(test, expected););
367 
368     // betainc(a+1, b, x) = betainc(a, b, x) - x**a * (1 - x)**b / (a * beta(a, b))
369     CALL_SUBTEST(
370         ArrayType num = x.pow(a) * (one - x).pow(b);
371         ArrayType denom = a * (a.lgamma() + b.lgamma() - (a + b).lgamma()).exp();
372         // Add eps to rhs and lhs so that component-wise test doesn't result in
373         // nans when both outputs are zeros.
374         ArrayType expected = betainc(a, b, x) - num / denom + eps;
375         ArrayType test = betainc(a + one, b, x) + eps;
376         if (sizeof(Scalar) >= 8) { // double
377           verify_component_wise(test, expected);
378         } else {
379           // Reason for limited test: http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1232
380           verify_component_wise(test.head(8), expected.head(8));
381         });
382 
383     // betainc(a, b+1, x) = betainc(a, b, x) + x**a * (1 - x)**b / (b * beta(a, b))
384     CALL_SUBTEST(
385         // Add eps to rhs and lhs so that component-wise test doesn't result in
386         // nans when both outputs are zeros.
387         ArrayType num = x.pow(a) * (one - x).pow(b);
388         ArrayType denom = b * (a.lgamma() + b.lgamma() - (a + b).lgamma()).exp();
389         ArrayType expected = betainc(a, b, x) + num / denom + eps;
390         ArrayType test = betainc(a, b + one, x) + eps;
391         verify_component_wise(test, expected););
392   }
393 #endif  // EIGEN_HAS_C99_MATH
394 
395     /* Code to generate the data for the following two test cases.
396     N = 5
397     np.random.seed(3)
398 
399     a = np.logspace(-2, 3, 6)
400     a = np.ravel(np.tile(np.reshape(a, [-1, 1]), [1, N]))
401     x = np.random.gamma(a, 1.0)
402     x = np.maximum(x, np.finfo(np.float32).tiny)
403 
404     def igamma(a, x):
405       return mpmath.gammainc(a, 0, x, regularized=True)
406 
407     def igamma_der_a(a, x):
408       res = mpmath.diff(lambda a_prime: igamma(a_prime, x), a)
409       return np.float64(res)
410 
411     def gamma_sample_der_alpha(a, x):
412       igamma_x = igamma(a, x)
413       def igammainv_of_igamma(a_prime):
414         return mpmath.findroot(lambda x_prime: igamma(a_prime, x_prime) -
415             igamma_x, x, solver='newton')
416       return np.float64(mpmath.diff(igammainv_of_igamma, a))
417 
418     v_igamma_der_a = np.vectorize(igamma_der_a)(a, x)
419     v_gamma_sample_der_alpha = np.vectorize(gamma_sample_der_alpha)(a, x)
420   */
421 
422 #if EIGEN_HAS_C99_MATH
423   // Test igamma_der_a
424   {
425     ArrayType a(30);
426     ArrayType x(30);
427     ArrayType res(30);
428     ArrayType v(30);
429 
430     a << 0.01, 0.01, 0.01, 0.01, 0.01, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0, 1.0,
431         1.0, 1.0, 10.0, 10.0, 10.0, 10.0, 10.0, 100.0, 100.0, 100.0, 100.0,
432         100.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0;
433 
434     x << 1.25668890405e-26, 1.17549435082e-38, 1.20938905072e-05,
435         1.17549435082e-38, 1.17549435082e-38, 5.66572070696e-16,
436         0.0132865061065, 0.0200034203853, 6.29263709118e-17, 1.37160367764e-06,
437         0.333412038288, 1.18135687766, 0.580629033777, 0.170631439426,
438         0.786686768458, 7.63873279537, 13.1944344379, 11.896042354,
439         10.5830172417, 10.5020942233, 92.8918587747, 95.003720371,
440         86.3715926467, 96.0330217672, 82.6389930677, 968.702906754,
441         969.463546828, 1001.79726022, 955.047416547, 1044.27458568;
442 
443     v << -32.7256441441, -36.4394150514, -9.66467612263, -36.4394150514,
444         -36.4394150514, -1.0891900302, -2.66351229645, -2.48666868596,
445         -0.929700494428, -3.56327722764, -0.455320135314, -0.391437214323,
446         -0.491352055991, -0.350454834292, -0.471773162921, -0.104084440522,
447         -0.0723646747909, -0.0992828975532, -0.121638215446, -0.122619605294,
448         -0.0317670267286, -0.0359974812869, -0.0154359225363, -0.0375775365921,
449         -0.00794899153653, -0.00777303219211, -0.00796085782042,
450         -0.0125850719397, -0.00455500206958, -0.00476436993148;
451 
452     CALL_SUBTEST(res = igamma_der_a(a, x); verify_component_wise(res, v););
453   }
454 
455   // Test gamma_sample_der_alpha
456   {
457     ArrayType alpha(30);
458     ArrayType sample(30);
459     ArrayType res(30);
460     ArrayType v(30);
461 
462     alpha << 0.01, 0.01, 0.01, 0.01, 0.01, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 1.0,
463         1.0, 1.0, 1.0, 10.0, 10.0, 10.0, 10.0, 10.0, 100.0, 100.0, 100.0, 100.0,
464         100.0, 1000.0, 1000.0, 1000.0, 1000.0, 1000.0;
465 
466     sample << 1.25668890405e-26, 1.17549435082e-38, 1.20938905072e-05,
467         1.17549435082e-38, 1.17549435082e-38, 5.66572070696e-16,
468         0.0132865061065, 0.0200034203853, 6.29263709118e-17, 1.37160367764e-06,
469         0.333412038288, 1.18135687766, 0.580629033777, 0.170631439426,
470         0.786686768458, 7.63873279537, 13.1944344379, 11.896042354,
471         10.5830172417, 10.5020942233, 92.8918587747, 95.003720371,
472         86.3715926467, 96.0330217672, 82.6389930677, 968.702906754,
473         969.463546828, 1001.79726022, 955.047416547, 1044.27458568;
474 
475     v << 7.42424742367e-23, 1.02004297287e-34, 0.0130155240738,
476         1.02004297287e-34, 1.02004297287e-34, 1.96505168277e-13, 0.525575786243,
477         0.713903991771, 2.32077561808e-14, 0.000179348049886, 0.635500453302,
478         1.27561284917, 0.878125852156, 0.41565819538, 1.03606488534,
479         0.885964824887, 1.16424049334, 1.10764479598, 1.04590810812,
480         1.04193666963, 0.965193152414, 0.976217589464, 0.93008035061,
481         0.98153216096, 0.909196397698, 0.98434963993, 0.984738050206,
482         1.00106492525, 0.97734200649, 1.02198794179;
483 
484     CALL_SUBTEST(res = gamma_sample_der_alpha(alpha, sample);
485                  verify_component_wise(res, v););
486   }
487 #endif  // EIGEN_HAS_C99_MATH
488 }
489 
// Registers the special_functions test: each CALL_SUBTEST_N runs the full
// battery above for one array scalar type (selected via EIGEN_TEST_PART_N).
EIGEN_DECLARE_TEST(special_functions)
{
  // Single precision (reference data is partially restricted for float).
  CALL_SUBTEST_1(array_special_functions<Eigen::ArrayXf>());
  // Double precision.
  CALL_SUBTEST_2(array_special_functions<Eigen::ArrayXd>());
  // TODO(cantonios): half/bfloat16 don't have enough precision to reproduce results above.
  // CALL_SUBTEST_3(array_special_functions<ArrayX<Eigen::half>>());
  // CALL_SUBTEST_4(array_special_functions<ArrayX<Eigen::bfloat16>>());
}
498