xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/Distributions.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/native/Math.h>
4 #include <c10/macros/Macros.h>
5 #include <c10/util/MathConstants.h>
6 
7 // ROCM hcc doesn't work well with using std:: in kernel functions
8 #if defined(__CUDA_ARCH__)
9 #include <c10/cuda/CUDAMathCompat.h>
10 #define compat_exp c10::cuda::compat::exp
11 #define compat_ceil c10::cuda::compat::ceil
12 #define compat_floor c10::cuda::compat::floor
13 #define compat_log c10::cuda::compat::log
14 #define compat_pow c10::cuda::compat::pow
15 #define compat_sqrt c10::cuda::compat::sqrt
16 #define compat_tan c10::cuda::compat::tan
17 #define compat_abs c10::cuda::compat::abs
18 #define compat_log1p c10::cuda::compat::log1p
19 #elif defined(__HIPCC__)
20 #include <c10/hip/HIPMathCompat.h>
21 #define compat_exp c10::hip::compat::exp
22 #define compat_ceil c10::hip::compat::ceil
23 #define compat_floor c10::hip::compat::floor
24 #define compat_log c10::hip::compat::log
25 #define compat_pow c10::hip::compat::pow
26 #define compat_sqrt c10::hip::compat::sqrt
27 #define compat_tan c10::hip::compat::tan
28 #define compat_abs c10::hip::compat::abs
29 #define compat_log1p c10::hip::compat::log1p
30 #else
31 #define compat_exp std::exp
32 #define compat_ceil std::ceil
33 #define compat_floor std::floor
34 #define compat_log std::log
35 #define compat_pow std::pow
36 #define compat_sqrt std::sqrt
37 #define compat_tan std::tan
38 #define compat_abs std::abs
39 #define compat_log1p std::log1p
40 #endif
41 
42 namespace {
43 
#if !(defined(__CUDA_ARCH__) || defined(__HIPCC__))
// Host-only compilation: make std::isnan visible unqualified in this scope.
// It is not imported for device compilation because of an incompatibility
// between gcc's constexpr treatment of std::isnan and nvcc.
using std::isnan;
#endif
49 
50 // Here sampler_t should be function type scalar_t(void). For gpu
51 // "sampler" is a device function, but since ROCM doesn't have
52 // equivalent to nvstd::function, we use a template type parameter to
53 // capture it.
54 template<typename scalar_t, typename sampler_t>
55 struct BaseSampler {
56   sampler_t sampler;
BaseSamplerBaseSampler57   C10_DEVICE BaseSampler(const sampler_t& sampler): sampler(sampler) {}
sampleBaseSampler58   C10_DEVICE scalar_t sample() {
59     return sampler();
60   }
61 };
62 
// The function `sample_gamma` is
// adapted from NumPy's distributions.c implementation.
65 // It is MIT licensed, so here is the copyright:
66 
/* Copyright 2005 Robert Kern (robert.kern@gmail.com)
68  *
69  * Permission is hereby granted, free of charge, to any person obtaining a
70  * copy of this software and associated documentation files (the
71  * "Software"), to deal in the Software without restriction, including
72  * without limitation the rights to use, copy, modify, merge, publish,
73  * distribute, sublicense, and/or sell copies of the Software, and to
74  * permit persons to whom the Software is furnished to do so, subject to
75  * the following conditions:
76  *
77  * The above copyright notice and this permission notice shall be included
78  * in all copies or substantial portions of the Software.
79  *
80  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
81  * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
82  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
83  * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
84  * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
85  * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
86  * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
87 */
88 
89 template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t, typename normal_sampler_t>
sample_gamma(scalar_t alpha,BaseSampler<accscalar_t,uniform_sampler_t> & standard_uniform,BaseSampler<accscalar_t,normal_sampler_t> & standard_normal)90 C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform, BaseSampler<accscalar_t, normal_sampler_t>& standard_normal) {
91   accscalar_t scale = 1.0f;
92 
93   // Boost alpha for higher acceptance probability.
94   if (alpha < 1.0f) {
95     if (alpha == 0.f) return 0.f;
96     scale *= compat_pow(1 - standard_uniform.sample(), 1.0f / alpha);
97     alpha += 1.0f;
98   }
99 
100   // This implements the acceptance-rejection method of Marsaglia and Tsang (2000)
101   // doi:10.1145/358407.358414
102   const accscalar_t d = alpha - 1.0f / 3.0f;
103   const accscalar_t c = 1.0f / compat_sqrt(9.0f * d);
104   for (;;) {
105     accscalar_t x, y;
106     do {
107       x = standard_normal.sample();
108       y = 1.0f + c * x;
109     } while (y <= 0);
110     const accscalar_t v = y * y * y;
111     const accscalar_t u = 1 - standard_uniform.sample();
112     const accscalar_t xx = x * x;
113     if (u < 1.0f - 0.0331f * xx * xx)
114       return static_cast<scalar_t>(scale * d * v);
115     if (compat_log(u) < 0.5f * xx + d * (1.0f - v + compat_log(v)))
116       return static_cast<scalar_t>(scale * d * v);
117   }
118 }
119 
120 /* the functions stirling_approx_tail, binomial_inversion, and btrs are adapted
121  * from TensorFlow's random_binomial_op.cc implementation. That code is under
122  * copyright: 2019 The TensorFlow Authors.
123  *
124  * It was released under the Apache License, Version 2.0 (the "License"), available at:
125  *    http://www.apache.org/licenses/LICENSE-2.0
126  */
127 
128 template<typename scalar_t>
stirling_approx_tail(scalar_t k)129 C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
130   const static scalar_t kTailValues[] = {
131     0.0810614667953272,
132     0.0413406959554092,
133     0.0276779256849983,
134     0.02079067210376509,
135     0.0166446911898211,
136     0.0138761288230707,
137     0.0118967099458917,
138     0.0104112652619720,
139     0.00925546218271273,
140     0.00833056343336287
141   };
142   if (k <= 9) {
143     return kTailValues[static_cast<size_t>(k)];
144   }
145   scalar_t kp1sq = (k + 1) * (k + 1);
146   return (1.0 / 12 - (1.0 / 360 - 1.0 / 1260 / kp1sq) / kp1sq) / (k + 1);
147 }
148 
149 
150 template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t>
binomial_inversion(scalar_t count,scalar_t prob,BaseSampler<accscalar_t,uniform_sampler_t> & standard_uniform)151 C10_DEVICE scalar_t binomial_inversion(scalar_t count, scalar_t prob, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform) {
152   accscalar_t U;
153   accscalar_t geom_sum = 0;
154   scalar_t num_geom = 0;
155 
156   accscalar_t logprob = compat_log1p(-prob);
157 
158   while (1) {
159     U = standard_uniform.sample();
160     accscalar_t geom = compat_ceil(compat_log(U) / logprob);
161     geom_sum += geom;
162     if (geom_sum > count) {
163       break;
164     }
165     num_geom = num_geom + 1;
166   }
167   return num_geom;
168 }
169 
170 template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t>
btrs(scalar_t count,scalar_t prob,BaseSampler<accscalar_t,uniform_sampler_t> & standard_uniform)171 C10_DEVICE scalar_t btrs(scalar_t count, scalar_t prob, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform) {
172   scalar_t k;
173   accscalar_t U, V, us;
174 
175   // This is spq in the paper.
176   const accscalar_t stddev = compat_sqrt(count * prob * (1 - prob));
177 
178   // Other coefficients for Transformed Rejection sampling.
179   const accscalar_t b = 1.15 + 2.53 * stddev;
180   const accscalar_t a = -0.0873 + 0.0248 * b + 0.01 * prob;
181   const accscalar_t c = count * prob + 0.5;
182   const accscalar_t v_r = 0.92 - 4.2 / b;
183   const accscalar_t r = prob / (1 - prob);
184 
185   const accscalar_t alpha = (2.83 + 5.1 / b) * stddev;
186   const accscalar_t m = compat_floor((count + 1) * prob);
187 
188   while (1) {
189     U = standard_uniform.sample() - 0.5;
190     V = standard_uniform.sample();
191 
192     us = 0.5 - compat_abs(U);
193     k = static_cast<scalar_t>(compat_floor((2 * a / us + b) * U + c));
194 
195     // Reject non-sensical answers.
196     if (k < 0 || k > count) {
197       continue;
198     }
199     // Region for which the box is tight, and we can return our calculated value.
200     // This should happen 0.86 * v_r times. In the limit as n * p is large,
201     // the acceptance rate converges to ~79% (and in the lower regime it is ~24%).
202     if (us >= 0.07 && V <= v_r) {
203       return k;
204     }
205 
206     // This deviates from Hormann's BTRS algorithm, as there is a log missing.
207     // For all (u, v) pairs outside of the bounding box, this calculates the
208     // transformed-reject ratio.
209     V = compat_log(V * alpha / (a / (us * us) + b));
210     accscalar_t upperbound =
211         ((m + 0.5) * compat_log((m + 1) / (r * (count - m + 1))) +
212          (count + 1) * compat_log((count - m + 1) / (count - k + 1)) +
213          (k + 0.5) * compat_log(r * (count - k + 1) / (k + 1)) +
214          stirling_approx_tail<accscalar_t>(m) + stirling_approx_tail<accscalar_t>(count - m) -
215          stirling_approx_tail<accscalar_t>(k) - stirling_approx_tail<accscalar_t>(count - k));
216 
217     if (V <= upperbound) {
218       return k;
219     }
220   }
221 }
222 
223 template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t>
sample_binomial(scalar_t count,scalar_t prob,BaseSampler<accscalar_t,uniform_sampler_t> & standard_uniform)224 C10_DEVICE scalar_t sample_binomial(scalar_t count, scalar_t prob, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform) {
225   if (count <= 0.0 || prob <= 0.0) {
226     return 0;
227   } else if (prob >= 1.0) {
228     return count;
229   } else if (prob <= 0.5) {
230     if (count * prob >= 10.0) {
231       // btrs
232       return btrs<scalar_t, accscalar_t, uniform_sampler_t>(count, prob, standard_uniform);
233     } else {
234       // binomial inversion
235       return binomial_inversion<scalar_t, accscalar_t, uniform_sampler_t>(count, prob, standard_uniform);
236     }
237   } else if (prob > 0.5) {
238     scalar_t qprob = 1.0 - prob;
239     if (count * qprob >= 10.0) {
240       // btrs
241       return count - btrs<scalar_t, accscalar_t, uniform_sampler_t>(count, qprob, standard_uniform);
242     } else {
243       // count - binomial inversion
244       return count - binomial_inversion<scalar_t, accscalar_t, uniform_sampler_t>(count, qprob, standard_uniform);
245     }
246   } else {
247     // prob is nan?
248     return static_cast<scalar_t>(NAN);
249   }
250 }
251 
252 /*
253  * This function is derived from the implementation of the digamma function in the Cephes Math Library.
254  * See note [3-Clause BSD License for the Cephes Math Library] in ATen/native/Math.h.
255  */
256 template<typename scalar_t, typename accscalar_t>
digamma_one(scalar_t x)257 C10_DEVICE inline scalar_t digamma_one(scalar_t x) {
258   constexpr accscalar_t PSI_10 = 2.25175258906672110764;
259   if (x == 0) {
260     return INFINITY;
261   }
262   accscalar_t additional_summand = 0;
263   int x_is_integer = x == compat_floor(x);
264   if (x < 0) {
265     if (x_is_integer) {
266       return INFINITY;
267     }
268     // it is more standard to write this as recursion, but
269     // nvcc does not like that
270     additional_summand = -c10::pi<scalar_t> /
271         compat_tan(c10::pi<scalar_t> * x);
272     x = 1 - x;
273   }
274 
275   // Push x to be >= 10
276   accscalar_t result = 0;
277   while (x < 10) {
278     result -= 1 / x;
279     x += 1;
280   }
281   if (x == 10) {
282     return result + PSI_10 + additional_summand;
283   }
284 
285   // Compute asymptotic digamma
286   static const accscalar_t A[] = {
287      8.33333333333333333333E-2,
288     -2.10927960927960927961E-2,
289      7.57575757575757575758E-3,
290     -4.16666666666666666667E-3,
291      3.96825396825396825397E-3,
292     -8.33333333333333333333E-3,
293      8.33333333333333333333E-2,
294   };
295 
296   accscalar_t y = 0;
297   if (x < 1.0e17f) {
298     accscalar_t z = 1.0 / (x * x);
299     y = z * polevl<accscalar_t>(z, A, 6);
300   }
301   return static_cast<scalar_t>(
302       result + compat_log(x) - (0.5f / x) - y + additional_summand);
303 }
304 
305 // Computes the reparameterized gradient -(d/dalpha cdf(x;alpha)) / pdf(x;alpha)
306 // for random number x drawn from a standard Gamma distribution Gamma(alpha).
307 template <typename scalar_t, typename accscalar_t>
standard_gamma_grad_one(scalar_t alpha_,scalar_t x_)308 C10_HOST_DEVICE scalar_t standard_gamma_grad_one(scalar_t alpha_, scalar_t x_) {
309   // Use a Taylor series expansion for small x.
310   accscalar_t x = static_cast<accscalar_t>(x_);
311   accscalar_t alpha = static_cast<accscalar_t>(alpha_);
312   if (x < 0.8f) {
313     accscalar_t numer = 1;
314     accscalar_t denom = alpha;
315     auto series1 = numer / denom;
316     auto series2 = numer / (denom * denom);
317     for (int i = 1; i <= 5; ++i) {
318       numer *= -x / static_cast<accscalar_t>(i);
319       denom += 1;
320       series1 += numer / denom;
321       series2 += numer / (denom * denom);
322     }
323     const auto pow_x_alpha = compat_pow(x, alpha);
324     const auto gamma_pdf = compat_pow(x, alpha - 1) * compat_exp(-x);
325     const auto gamma_cdf = pow_x_alpha * series1;
326     const auto gamma_cdf_alpha =
327         (compat_log(x) - digamma_one<accscalar_t, accscalar_t>(alpha)) *
328             gamma_cdf -
329         pow_x_alpha * series2;
330     const auto result = -gamma_cdf_alpha / gamma_pdf;
331     return isnan(result) ? static_cast<scalar_t>( 0.f ) : static_cast<scalar_t>(result);
332   }
333 
334   // Use a Rice saddle point expansion for large alpha.
335   if (alpha > 8.0f) {
336     if (0.9f * alpha <= x && x <= 1.1f * alpha) {
337       const auto numer_1 = 1 + 24 * alpha * (1 + 12 * alpha);
338       const auto numer_2 = 1440 * (alpha * alpha) + 6 * x * (53 - 120 * x)
339           - 65 * x * x / alpha + alpha * (107 + 3600 * x);
340       const auto denom = 1244160 * (alpha * alpha) * (alpha * alpha);
341       return static_cast<scalar_t>(numer_1 * numer_2 / denom);
342     }
343     const auto denom = compat_sqrt(8 * alpha);
344     const auto term2 = denom / (alpha - x);
345     const auto term3 = compat_pow(
346         x - alpha - alpha * compat_log(x / alpha),
347         static_cast<accscalar_t>(-1.5));
348     const auto term23 = (x < alpha) ? term2 - term3 : term2 + term3;
349     const auto term1 = compat_log(x / alpha) * term23 -
350         compat_sqrt(2 / alpha) * (alpha + x) / ((alpha - x) * (alpha - x));
351     const auto stirling = 1 + 1 / (12 * alpha) * (1 + 1 / (24 * alpha));
352     const auto numer = x * term1;
353     return static_cast<scalar_t>(-stirling * numer / denom);
354   }
355 
356   // Use a bivariate rational approximation to the reparameterized gradient.
357   const auto u = compat_log(x / alpha);
358   const auto v = compat_log(alpha);
359   static const accscalar_t coef_uv[3][8] = {
360     {0.16009398, -0.094634809, 0.025146376, -0.0030648343,
361      1, 0.32668115, 0.10406089, 0.0014179084},
362     {0.53487893, 0.1298071, 0.065735949, -0.0015649758,
363      0.16639465, 0.020070113, -0.0035938915, -0.00058392623},
364     {0.040121004, -0.0065914022, -0.0026286047, -0.0013441777,
365      0.017050642, -0.0021309326, 0.00085092367, -1.5247877e-07},
366   };
367   accscalar_t coef_v[8];
368   for (int i = 0; i < 8; ++ i) {
369     coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]);
370   }
371   const auto p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3]));
372   const auto q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7]));
373   return static_cast<scalar_t>(compat_exp(p / q));
374 }
375 
376 // Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha.
377 // Assumes x is close to zero and uses a Taylor expansion.
378 template <typename scalar_t, typename accscalar_t>
_beta_grad_alpha_small(scalar_t x,scalar_t alpha,scalar_t beta)379 C10_DEVICE inline scalar_t _beta_grad_alpha_small(scalar_t x, scalar_t alpha, scalar_t beta) {
380   const scalar_t factor = digamma_one<scalar_t, accscalar_t>(alpha)
381                         - digamma_one<scalar_t, accscalar_t>(alpha + beta) - compat_log(x);
382   scalar_t numer = 1;
383   scalar_t series = numer / alpha * (factor + 1 / alpha);
384   for (int i = 1; i <= 10; ++i) {
385     scalar_t casted_i = static_cast<scalar_t>(i);
386     numer *= (casted_i - beta) * x / casted_i;
387     const scalar_t denom = alpha + casted_i;
388     series += numer / denom * (factor + 1 / denom);
389   }
390   const scalar_t result = x * compat_pow(1 - x, -beta) * series;
391   return isnan(result) ? static_cast<scalar_t>( 0.f ) : result;
392 }
393 
394 // Approximate reparameterized gradient of Beta(x,alpha,beta) wrt beta.
395 // Assumes x is close to zero and uses a Taylor expansion.
396 template <typename scalar_t, typename accscalar_t>
_beta_grad_beta_small(scalar_t x,scalar_t alpha,scalar_t beta)397 C10_DEVICE inline scalar_t _beta_grad_beta_small(scalar_t x, scalar_t alpha, scalar_t beta) {
398   const scalar_t factor = digamma_one<scalar_t, accscalar_t>(alpha + beta) - digamma_one<scalar_t, accscalar_t>(beta);
399   scalar_t numer = 1, betas = 1, dbetas = 0, series = factor / alpha;
400   for (int i = 1; i <= 8; ++i) {
401     scalar_t casted_i = static_cast<scalar_t>(i);
402     numer *= -x / casted_i;
403     dbetas = dbetas * (beta - casted_i) + betas;
404     betas = betas * (beta - casted_i);
405     series += numer / (alpha + casted_i) * (dbetas + factor * betas);
406   }
407   const scalar_t result = -compat_pow(1 - x, 1 - beta) * series;
408   return isnan(result) ? static_cast<scalar_t>( 0.f ) : result;
409 }
410 
411 // Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha.
412 // Assumes alpha and beta are both large and uses a Rice saddle point expansion.
413 // To ensure numerical stability, this computation is performed at higher precision.
414 template<typename scalar_t, typename accscalar_t>
_beta_grad_alpha_mid(accscalar_t x,accscalar_t alpha,accscalar_t beta)415 C10_DEVICE inline scalar_t _beta_grad_alpha_mid(accscalar_t x, accscalar_t alpha, accscalar_t beta) {
416   const accscalar_t total = alpha + beta;
417   const accscalar_t mean = alpha / total;
418   const accscalar_t std = compat_sqrt(alpha * beta / (total + 1)) / total;
419   if (mean - 0.1 * std <= x && x <= mean + 0.1 * std) {
420     // Avoid the singularity at x = mean.
421     const accscalar_t poly = 47 * x * (beta * beta) * (beta * beta) + alpha * (
422                            (43 + 20 * (16 + 27 * beta) * x) * (beta * beta) * beta + alpha * (
423                            3 * (59 + 180 * beta - 90 * x) * (beta * beta) + alpha * (
424                            (453 + 1620 * beta * (1 - x) - 455 * x) * beta + alpha * (
425                            8 * (1 - x) * (135 * beta - 11)))));
426     const accscalar_t prefactor_num = (1 + 12 * alpha) * (1 + 12 * beta) / (total * total);
427     const accscalar_t prefactor_den = 12960 * alpha * alpha * alpha * beta * beta * (1 + 12 * total);
428     return prefactor_num / (1 - x) * poly / prefactor_den;
429   }
430   const accscalar_t prefactor = -x / compat_sqrt(2 * alpha * beta / total);
431   const accscalar_t stirling = (1 + 1 / (12 * alpha) + 1 / (288 * alpha * alpha))
432                              * (1 + 1 / (12 * beta) + 1 / (288 * beta * beta))
433                              / (1 + 1 / (12 * total) + 1 / (288 * total * total));
434   const accscalar_t term1_num = 2 * (alpha * alpha) * (x - 1) + alpha * beta * (x - 1) - x * (beta * beta);
435   const accscalar_t axbx = alpha * (x - 1) + beta * x;
436   const accscalar_t term1_den = compat_sqrt(2 * alpha / beta) * compat_pow(total, static_cast<accscalar_t>(1.5f)) * axbx * axbx;
437   const accscalar_t term1 = term1_num / term1_den;
438   const accscalar_t term2 = 0.5f * compat_log(alpha / (total * x));
439   const accscalar_t term3_num = compat_sqrt(8 * alpha * beta / total);
440   const accscalar_t term3_den = beta * x + alpha * (x - 1);
441   const accscalar_t term3 = term3_num / term3_den;
442   const accscalar_t term4_base = beta * compat_log(beta / (total * (1 - x))) +
443                                alpha * compat_log(alpha / (total * x));
444   const accscalar_t term4 = compat_pow(term4_base, static_cast<accscalar_t>(-1.5f));
445   const accscalar_t term1234 = term1 + term2 * (term3 + (x < mean ? term4 : -term4));
446   return static_cast<scalar_t>(stirling * prefactor * term1234);
447 }
448 
449 // Computes a scaled reparameterized gradient
450 //   -(d/dalpha cdf(x;alpha,beta)) / pdf(x;alpha,beta) / (1-x)
451 // for random number x drawn from a Beta distribution Beta(alpha,beta).
452 // This function inputs total=alpha+beta to make it easy to implement
453 // Dirichlet reparameterized gradients in terms of Betas.
454 template<typename scalar_t, typename accscalar_t>
dirichlet_grad_one(scalar_t x,scalar_t alpha,scalar_t total)455 C10_HOST_DEVICE inline scalar_t dirichlet_grad_one(scalar_t x, scalar_t alpha, scalar_t total) {
456   accscalar_t x_ = static_cast<accscalar_t>(x);
457   accscalar_t alpha_ = static_cast<accscalar_t>(alpha);
458   accscalar_t total_ = static_cast<accscalar_t>(total);
459 
460   const scalar_t beta = total - alpha;
461   const accscalar_t beta_ = total_ - alpha_;
462   const scalar_t boundary = total * x * (1 - x);
463 
464   // Use an asymptotic approximation for x close to 0.
465   if (x <= 0.5f && boundary < 2.5f) {
466     return _beta_grad_alpha_small<scalar_t, accscalar_t>(x, alpha, beta);
467   }
468 
469   // Use an asymptotic approximation for x close to 1.
470   if (x >= 0.5f && boundary < 0.75f) {
471     return -_beta_grad_beta_small<scalar_t, accscalar_t>(1 - x, beta, alpha);
472   }
473 
474   // Use an asymptotic approximation when alpha and (total - alpha) are both large.
475   if (alpha > 6 && beta > 6) {
476     return _beta_grad_alpha_mid<scalar_t, accscalar_t>(x_, alpha_, beta_);
477   }
478 
479   // Use a rational correction to an analytic approximation.
480   static const accscalar_t c[2][3][3][4] = {
481     {{{1.003668233, -0.01061107488, -0.0657888334, 0.01201642863},
482       {0.6336835991, -0.3557432599, 0.05486251648, -0.001465281033},
483       {-0.03276231906, 0.004474107445, 0.002429354597, -0.0001557569013}},
484      {{0.221950385, -0.3187676331, 0.01799915743, 0.01074823814},
485       {-0.2951249643, 0.06219954479, 0.01535556598, 0.001550077057},
486       {0.02155310298, 0.004170831599, 0.001292462449, 6.976601077e-05}},
487      {{-0.05980841433, 0.008441916499, 0.01085618172, 0.002319392565},
488       {0.02911413504, 0.01400243777, -0.002721828457, 0.000751041181},
489       {0.005900514878, -0.001936558688, -9.495446725e-06, 5.385558597e-05}}},
490     {{{1, -0.02924021934, -0.04438342661, 0.007285809825},
491       {0.6357567472, -0.3473456711, 0.05454656494, -0.002407477521},
492       {-0.03301322327, 0.004845219414, 0.00231480583, -0.0002307248149}},
493      {{0.5925320577, -0.1757678135, 0.01505928619, 0.000564515273},
494       {0.1014815858, -0.06589186703, 0.01272886114, -0.0007316646956},
495       {-0.007258481865, 0.001096195486, 0.0003934994223, -4.12701925e-05}},
496      {{0.06469649321, -0.0236701437, 0.002902096474, -5.896963079e-05},
497       {0.001925008108, -0.002869809258, 0.0008000589141, -6.063713228e-05},
498       {-0.0003477407336, 6.959756487e-05, 1.097287507e-05, -1.650964693e-06}}},
499   };
500   const accscalar_t u = compat_log(x_);
501   const accscalar_t a = compat_log(alpha_) - u;
502   const accscalar_t b = compat_log(total_) - a;
503   const accscalar_t pow_u[3] = {1, u, u * u};
504   const accscalar_t pow_a[3] = {1, a, a * a};
505   accscalar_t p = 0.0;
506   accscalar_t q = 0.0;
507   for (int i = 0; i < 3; ++i) {
508     for (int j = 0; j < 3; ++j) {
509       const accscalar_t ua = pow_u[i] * pow_a[j];
510       p += ua * (c[0][i][j][0] + b * (c[0][i][j][1] + b * (c[0][i][j][2] + b * c[0][i][j][3])));
511       q += ua * (c[1][i][j][0] + b * (c[1][i][j][1] + b * (c[1][i][j][2] + b * c[1][i][j][3])));
512     }
513   }
514   const accscalar_t approx = x_ * (digamma_one<scalar_t, accscalar_t>(total_) - digamma_one<scalar_t, accscalar_t>(alpha_)) / beta_;
515   return static_cast<scalar_t>(p / q * approx);
516 }
517 
518 } // namespace
519