xref: /aosp_15_r20/external/llvm-libc/src/math/generic/expxf16.h (revision 71db0c75aadcf003ffe3238005f61d7618a3fead)
1 //===-- Common utilities for half-precision exponential functions ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIBC_SRC_MATH_GENERIC_EXPXF16_H
10 #define LLVM_LIBC_SRC_MATH_GENERIC_EXPXF16_H
11 
12 #include "src/__support/CPP/array.h"
13 #include "src/__support/FPUtil/FPBits.h"
14 #include "src/__support/FPUtil/PolyEval.h"
15 #include "src/__support/FPUtil/cast.h"
16 #include "src/__support/FPUtil/multiply_add.h"
17 #include "src/__support/FPUtil/nearest_integer.h"
18 #include "src/__support/macros/attributes.h"
19 #include "src/__support/macros/config.h"
20 #include <stdint.h>
21 
22 namespace LIBC_NAMESPACE_DECL {
23 
24 // Generated by Sollya with the following commands:
25 //   > display = hexadecimal;
26 //   > for i from -18 to 12 do print(round(exp(i), SG, RN));
27 static constexpr cpp::array<float, 31> EXP_HI = {
28     0x1.05a628p-26f, 0x1.639e32p-25f, 0x1.e355bcp-24f, 0x1.4875cap-22f,
29     0x1.be6c7p-21f,  0x1.2f6054p-19f, 0x1.9c54c4p-18f, 0x1.183542p-16f,
30     0x1.7cd79cp-15f, 0x1.02cf22p-13f, 0x1.5fc21p-12f,  0x1.de16bap-11f,
31     0x1.44e52p-9f,   0x1.b993fep-8f,  0x1.2c155cp-6f,  0x1.97db0cp-5f,
32     0x1.152aaap-3f,  0x1.78b564p-2f,  0x1p+0f,         0x1.5bf0a8p+1f,
33     0x1.d8e64cp+2f,  0x1.415e5cp+4f,  0x1.b4c902p+5f,  0x1.28d38ap+7f,
34     0x1.936dc6p+8f,  0x1.122886p+10f, 0x1.749ea8p+11f, 0x1.fa7158p+12f,
35     0x1.5829dcp+14f, 0x1.d3c448p+15f, 0x1.3de166p+17f,
36 };
37 
38 // Generated by Sollya with the following commands:
39 //   > display = hexadecimal;
40 //   > for i from 0 to 7 do print(round(exp(i * 2^-3), SG, RN));
41 static constexpr cpp::array<float, 8> EXP_MID = {
42     0x1p+0f,        0x1.221604p+0f, 0x1.48b5e4p+0f, 0x1.747a52p+0f,
43     0x1.a61298p+0f, 0x1.de455ep+0f, 0x1.0ef9dcp+1f, 0x1.330e58p+1f,
44 };
45 
// Result of the range-reduction helpers in this header.  The final value of
// the exponential is the product of the two members.
struct ExpRangeReduction {
  // Table-driven factor: b^hi * b^mid for the base b being evaluated.
  float exp_hi_mid;
  // Polynomial factor: b^lo for the small remainder lo.
  float exp_lo;
};
50 
exp_range_reduction(float16 x)51 LIBC_INLINE ExpRangeReduction exp_range_reduction(float16 x) {
52   // For -18 < x < 12, to compute exp(x), we perform the following range
53   // reduction: find hi, mid, lo, such that:
54   //   x = hi + mid + lo, in which
55   //     hi is an integer,
56   //     mid * 2^3 is an integer,
57   //     -2^(-4) <= lo < 2^(-4).
58   // In particular,
59   //   hi + mid = round(x * 2^3) * 2^(-3).
60   // Then,
61   //   exp(x) = exp(hi + mid + lo) = exp(hi) * exp(mid) * exp(lo).
62   // We store exp(hi) and exp(mid) in the lookup tables EXP_HI and EXP_MID
63   // respectively.  exp(lo) is computed using a degree-3 minimax polynomial
64   // generated by Sollya.
65 
66   float xf = x;
67   float kf = fputil::nearest_integer(xf * 0x1.0p+3f);
68   int x_hi_mid = static_cast<int>(kf);
69   int x_hi = x_hi_mid >> 3;
70   int x_mid = x_hi_mid & 0x7;
71   // lo = x - (hi + mid) = round(x * 2^3) * (-2^(-3)) + x
72   float lo = fputil::multiply_add(kf, -0x1.0p-3f, xf);
73 
74   float exp_hi = EXP_HI[x_hi + 18];
75   float exp_mid = EXP_MID[x_mid];
76   // Degree-3 minimax polynomial generated by Sollya with the following
77   // commands:
78   //   > display = hexadecimal;
79   //   > P = fpminimax(expm1(x)/x, 2, [|SG...|], [-2^-4, 2^-4]);
80   //   > 1 + x * P;
81   float exp_lo =
82       fputil::polyeval(lo, 0x1p+0f, 0x1p+0f, 0x1.001p-1f, 0x1.555ddep-3f);
83   return {exp_hi * exp_mid, exp_lo};
84 }
85 
86 // Generated by Sollya with the following commands:
87 //   > display = hexadecimal;
88 //   > for i from 0 to 7 do printsingle(round(2^(i * 2^-3), SG, RN));
89 constexpr cpp::array<uint32_t, 8> EXP2_MID_BITS = {
90     0x3f80'0000U, 0x3f8b'95c2U, 0x3f98'37f0U, 0x3fa5'fed7U,
91     0x3fb5'04f3U, 0x3fc5'672aU, 0x3fd7'44fdU, 0x3fea'c0c7U,
92 };
93 
exp2_range_reduction(float16 x)94 LIBC_INLINE ExpRangeReduction exp2_range_reduction(float16 x) {
95   // For -25 < x < 16, to compute 2^x, we perform the following range reduction:
96   // find hi, mid, lo, such that:
97   //   x = hi + mid + lo, in which
98   //     hi is an integer,
99   //     mid * 2^3 is an integer,
100   //     -2^(-4) <= lo < 2^(-4).
101   // In particular,
102   //   hi + mid = round(x * 2^3) * 2^(-3).
103   // Then,
104   //   2^x = 2^(hi + mid + lo) = 2^hi * 2^mid * 2^lo.
105   // We store 2^mid in the lookup table EXP2_MID_BITS, and compute 2^hi * 2^mid
106   // by adding hi to the exponent field of 2^mid.  2^lo is computed using a
107   // degree-3 minimax polynomial generated by Sollya.
108 
109   float xf = x;
110   float kf = fputil::nearest_integer(xf * 0x1.0p+3f);
111   int x_hi_mid = static_cast<int>(kf);
112   unsigned x_hi = static_cast<unsigned>(x_hi_mid) >> 3;
113   unsigned x_mid = static_cast<unsigned>(x_hi_mid) & 0x7;
114   // lo = x - (hi + mid) = round(x * 2^3) * (-2^(-3)) + x
115   float lo = fputil::multiply_add(kf, -0x1.0p-3f, xf);
116 
117   uint32_t exp2_hi_mid_bits =
118       EXP2_MID_BITS[x_mid] +
119       static_cast<uint32_t>(x_hi << fputil::FPBits<float>::FRACTION_LEN);
120   float exp2_hi_mid = fputil::FPBits<float>(exp2_hi_mid_bits).get_val();
121   // Degree-3 minimax polynomial generated by Sollya with the following
122   // commands:
123   //   > display = hexadecimal;
124   //   > P = fpminimax((2^x - 1)/x, 2, [|SG...|], [-2^-4, 2^-4]);
125   //   > 1 + x * P;
126   float exp2_lo = fputil::polyeval(lo, 0x1p+0f, 0x1.62e43p-1f, 0x1.ec0aa6p-3f,
127                                    0x1.c6b4a6p-5f);
128   return {exp2_hi_mid, exp2_lo};
129 }
130 
// log2(10) (about 3.3219) rounded to the nearest single-precision value.
// Generated by Sollya with the following commands:
//   > display = hexadecimal;
//   > round(log2(10), SG, RN);
static constexpr float LOG2F_10 = 0x1.a934fp+1f;
135 
// log10(2) (about 0.30103) rounded to the nearest single-precision value.
// Generated by Sollya with the following commands:
//   > display = hexadecimal;
//   > round(log10(2), SG, RN);
static constexpr float LOG10F_2 = 0x1.344136p-2f;
140 
exp10_range_reduction(float16 x)141 LIBC_INLINE ExpRangeReduction exp10_range_reduction(float16 x) {
142   // For -8 < x < 5, to compute 10^x, we perform the following range reduction:
143   // find hi, mid, lo, such that:
144   //   x = (hi + mid) * log2(10) + lo, in which
145   //     hi is an integer,
146   //     mid * 2^3 is an integer,
147   //     -2^(-4) <= lo < 2^(-4).
148   // In particular,
149   //   hi + mid = round(x * 2^3) * 2^(-3).
150   // Then,
151   //   10^x = 10^(hi + mid + lo) = 2^((hi + mid) * log2(10)) + 10^lo
152   // We store 2^mid in the lookup table EXP2_MID_BITS, and compute 2^hi * 2^mid
153   // by adding hi to the exponent field of 2^mid.  10^lo is computed using a
154   // degree-4 minimax polynomial generated by Sollya.
155 
156   float xf = x;
157   float kf = fputil::nearest_integer(xf * (LOG2F_10 * 0x1.0p+3f));
158   int x_hi_mid = static_cast<int>(kf);
159   unsigned x_hi = static_cast<unsigned>(x_hi_mid) >> 3;
160   unsigned x_mid = static_cast<unsigned>(x_hi_mid) & 0x7;
161   // lo = x - (hi + mid) = round(x * 2^3 * log2(10)) * log10(2) * (-2^(-3)) + x
162   float lo = fputil::multiply_add(kf, LOG10F_2 * -0x1.0p-3f, xf);
163 
164   uint32_t exp2_hi_mid_bits =
165       EXP2_MID_BITS[x_mid] +
166       static_cast<uint32_t>(x_hi << fputil::FPBits<float>::FRACTION_LEN);
167   float exp2_hi_mid = fputil::FPBits<float>(exp2_hi_mid_bits).get_val();
168   // Degree-4 minimax polynomial generated by Sollya with the following
169   // commands:
170   //   > display = hexadecimal;
171   //   > P = fpminimax((10^x - 1)/x, 3, [|SG...|], [-2^-4, 2^-4]);
172   //   > 1 + x * P;
173   float exp10_lo = fputil::polyeval(lo, 0x1p+0f, 0x1.26bb14p+1f, 0x1.53526p+1f,
174                                     0x1.04b434p+1f, 0x1.2bcf9ep+0f);
175   return {exp2_hi_mid, exp10_lo};
176 }
177 
// log2(e) (about 1.4427) rounded to the nearest single-precision value.
// Generated by Sollya with the following commands:
//   > display = hexadecimal;
//   > round(log2(exp(1)), SG, RN);
static constexpr float LOG2F_E = 0x1.715476p+0f;
182 
// log(2) (about 0.69315) rounded to the nearest single-precision value.
// Generated by Sollya with the following commands:
//   > display = hexadecimal;
//   > round(log(2), SG, RN);
static constexpr float LOGF_2 = 0x1.62e43p-1f;
187 
188 // Generated by Sollya with the following commands:
189 //   > display = hexadecimal;
190 //   > for i from 0 to 31 do printsingle(round(2^(i * 2^-5), SG, RN));
191 static constexpr cpp::array<uint32_t, 32> EXP2_MID_5_BITS = {
192     0x3f80'0000U, 0x3f82'cd87U, 0x3f85'aac3U, 0x3f88'980fU, 0x3f8b'95c2U,
193     0x3f8e'a43aU, 0x3f91'c3d3U, 0x3f94'f4f0U, 0x3f98'37f0U, 0x3f9b'8d3aU,
194     0x3f9e'f532U, 0x3fa2'7043U, 0x3fa5'fed7U, 0x3fa9'a15bU, 0x3fad'583fU,
195     0x3fb1'23f6U, 0x3fb5'04f3U, 0x3fb8'fbafU, 0x3fbd'08a4U, 0x3fc1'2c4dU,
196     0x3fc5'672aU, 0x3fc9'b9beU, 0x3fce'248cU, 0x3fd2'a81eU, 0x3fd7'44fdU,
197     0x3fdb'fbb8U, 0x3fe0'ccdfU, 0x3fe5'b907U, 0x3fea'c0c7U, 0x3fef'e4baU,
198     0x3ff5'257dU, 0x3ffa'83b3U,
199 };
200 
201 // This function correctly calculates sinh(x) and cosh(x) by calculating exp(x)
202 // and exp(-x) simultaneously.
203 // To compute e^x, we perform the following range reduction:
204 // find hi, mid, lo such that:
205 //   x = (hi + mid) * log(2) + lo, in which
206 //     hi is an integer,
207 //     0 <= mid * 2^5 < 32 is an integer
208 //     -2^(-5) <= lo * log2(e) <= 2^-5.
209 // In particular,
210 //   hi + mid = round(x * log2(e) * 2^5) * 2^(-5).
211 // Then,
212 //   e^x = 2^(hi + mid) * e^lo = 2^hi * 2^mid * e^lo.
213 // We store 2^mid in the lookup table EXP2_MID_5_BITS, and compute 2^hi * 2^mid
214 // by adding hi to the exponent field of 2^mid.
215 // e^lo is computed using a degree-3 minimax polynomial generated by Sollya:
216 //   e^lo ~ P(lo)
217 //        = 1 + lo + c2 * lo^2 + ... + c5 * lo^5
218 //        = (1 + c2*lo^2 + c4*lo^4) + lo * (1 + c3*lo^2 + c5*lo^4)
219 //        = P_even + lo * P_odd
220 // To compute e^(-x), notice that:
221 //   e^(-x) = 2^(-(hi + mid)) * e^(-lo)
222 //          ~ 2^(-(hi + mid)) * P(-lo)
223 //          = 2^(-(hi + mid)) * (P_even - lo * P_odd)
224 // So:
225 //   sinh(x) = (e^x - e^(-x)) / 2
226 //           ~ 0.5 * (2^(hi + mid) * (P_even + lo * P_odd) -
227 //                    2^(-(hi + mid)) * (P_even - lo * P_odd))
228 //           = 0.5 * (P_even * (2^(hi + mid) - 2^(-(hi + mid))) +
229 //                    lo * P_odd * (2^(hi + mid) + 2^(-(hi + mid))))
230 // And similarly:
231 //   cosh(x) = (e^x + e^(-x)) / 2
232 //           ~ 0.5 * (P_even * (2^(hi + mid) + 2^(-(hi + mid))) +
233 //                    lo * P_odd * (2^(hi + mid) - 2^(-(hi + mid))))
234 // The main point of these formulas is that the expensive part of calculating
235 // the polynomials approximating lower parts of e^x and e^(-x) is shared and
236 // only done once.
eval_sinh_or_cosh(float16 x)237 template <bool IsSinh> LIBC_INLINE float16 eval_sinh_or_cosh(float16 x) {
238   float xf = x;
239   float kf = fputil::nearest_integer(xf * (LOG2F_E * 0x1.0p+5f));
240   int x_hi_mid_p = static_cast<int>(kf);
241   int x_hi_mid_m = -x_hi_mid_p;
242 
243   unsigned x_hi_p = static_cast<unsigned>(x_hi_mid_p) >> 5;
244   unsigned x_hi_m = static_cast<unsigned>(x_hi_mid_m) >> 5;
245   unsigned x_mid_p = static_cast<unsigned>(x_hi_mid_p) & 0x1f;
246   unsigned x_mid_m = static_cast<unsigned>(x_hi_mid_m) & 0x1f;
247 
248   uint32_t exp2_hi_mid_bits_p =
249       EXP2_MID_5_BITS[x_mid_p] +
250       static_cast<uint32_t>(x_hi_p << fputil::FPBits<float>::FRACTION_LEN);
251   uint32_t exp2_hi_mid_bits_m =
252       EXP2_MID_5_BITS[x_mid_m] +
253       static_cast<uint32_t>(x_hi_m << fputil::FPBits<float>::FRACTION_LEN);
254   // exp2_hi_mid_p = 2^(hi + mid)
255   float exp2_hi_mid_p = fputil::FPBits<float>(exp2_hi_mid_bits_p).get_val();
256   // exp2_hi_mid_m = 2^(-(hi + mid))
257   float exp2_hi_mid_m = fputil::FPBits<float>(exp2_hi_mid_bits_m).get_val();
258 
259   // exp2_hi_mid_sum = 2^(hi + mid) + 2^(-(hi + mid))
260   float exp2_hi_mid_sum = exp2_hi_mid_p + exp2_hi_mid_m;
261   // exp2_hi_mid_diff = 2^(hi + mid) - 2^(-(hi + mid))
262   float exp2_hi_mid_diff = exp2_hi_mid_p - exp2_hi_mid_m;
263 
264   // lo = x - (hi + mid) = round(x * log2(e) * 2^5) * log(2) * (-2^(-5)) + x
265   float lo = fputil::multiply_add(kf, LOGF_2 * -0x1.0p-5f, xf);
266   float lo_sq = lo * lo;
267 
268   // Degree-3 minimax polynomial generated by Sollya with the following
269   // commands:
270   //   > display = hexadecimal;
271   //   > P = fpminimax(expm1(x)/x, 2, [|SG...|], [-2^-5, 2^-5]);
272   //   > 1 + x * P;
273   constexpr cpp::array<float, 4> COEFFS = {0x1p+0f, 0x1p+0f, 0x1.0004p-1f,
274                                            0x1.555778p-3f};
275   float half_p_odd =
276       fputil::polyeval(lo_sq, COEFFS[1] * 0.5f, COEFFS[3] * 0.5f);
277   float half_p_even =
278       fputil::polyeval(lo_sq, COEFFS[0] * 0.5f, COEFFS[2] * 0.5f);
279 
280   // sinh(x) = lo * (0.5 * P_odd * (2^(hi + mid) + 2^(-(hi + mid)))) +
281   //                (0.5 * P_even * (2^(hi + mid) - 2^(-(hi + mid))))
282   if constexpr (IsSinh)
283     return fputil::cast<float16>(fputil::multiply_add(
284         lo, half_p_odd * exp2_hi_mid_sum, half_p_even * exp2_hi_mid_diff));
285   // cosh(x) = lo * (0.5 * P_odd * (2^(hi + mid) - 2^(-(hi + mid)))) +
286   //                (0.5 * P_even * (2^(hi + mid) + 2^(-(hi + mid))))
287   return fputil::cast<float16>(fputil::multiply_add(
288       lo, half_p_odd * exp2_hi_mid_diff, half_p_even * exp2_hi_mid_sum));
289 }
290 
291 // Generated by Sollya with the following commands:
292 //   > display = hexadecimal;
293 //   > for i from 0 to 31 do print(round(log(1 + i * 2^-5), SG, RN));
294 constexpr cpp::array<float, 32> LOGF_F = {
295     0x0p+0f,        0x1.f829bp-6f,  0x1.f0a30cp-5f, 0x1.6f0d28p-4f,
296     0x1.e27076p-4f, 0x1.29553p-3f,  0x1.5ff308p-3f, 0x1.9525aap-3f,
297     0x1.c8ff7cp-3f, 0x1.fb9186p-3f, 0x1.1675cap-2f, 0x1.2e8e2cp-2f,
298     0x1.4618bcp-2f, 0x1.5d1bdcp-2f, 0x1.739d8p-2f,  0x1.89a338p-2f,
299     0x1.9f323ep-2f, 0x1.b44f78p-2f, 0x1.c8ff7cp-2f, 0x1.dd46ap-2f,
300     0x1.f128f6p-2f, 0x1.02552ap-1f, 0x1.0be72ep-1f, 0x1.154c3ep-1f,
301     0x1.1e85f6p-1f, 0x1.2795e2p-1f, 0x1.307d74p-1f, 0x1.393e0ep-1f,
302     0x1.41d8fep-1f, 0x1.4a4f86p-1f, 0x1.52a2d2p-1f, 0x1.5ad404p-1f,
303 };
304 
305 // Generated by Sollya with the following commands:
306 //   > display = hexadecimal;
307 //   > for i from 0 to 31 do print(round(log2(1 + i * 2^-5), SG, RN));
308 constexpr cpp::array<float, 32> LOG2F_F = {
309     0x0p+0f,        0x1.6bad38p-5f, 0x1.663f7p-4f,  0x1.08c588p-3f,
310     0x1.5c01a4p-3f, 0x1.acf5e2p-3f, 0x1.fbc16cp-3f, 0x1.24407ap-2f,
311     0x1.49a784p-2f, 0x1.6e221cp-2f, 0x1.91bba8p-2f, 0x1.b47ecp-2f,
312     0x1.d6753ep-2f, 0x1.f7a856p-2f, 0x1.0c105p-1f,  0x1.1bf312p-1f,
313     0x1.2b8034p-1f, 0x1.3abb4p-1f,  0x1.49a784p-1f, 0x1.584822p-1f,
314     0x1.66a008p-1f, 0x1.74b1fep-1f, 0x1.82809ep-1f, 0x1.900e62p-1f,
315     0x1.9d5dap-1f,  0x1.aa709p-1f,  0x1.b74948p-1f, 0x1.c3e9cap-1f,
316     0x1.d053f6p-1f, 0x1.dc899ap-1f, 0x1.e88c6cp-1f, 0x1.f45e08p-1f,
317 };
318 
319 // Generated by Sollya with the following commands:
320 //   > display = hexadecimal;
321 //   > for i from 0 to 31 do print(round(log10(1 + i * 2^-5), SG, RN));
322 constexpr cpp::array<float, 32> LOG10F_F = {
323     0x0p+0f,        0x1.b5e908p-7f, 0x1.af5f92p-6f, 0x1.3ed11ap-5f,
324     0x1.a30a9ep-5f, 0x1.02428cp-4f, 0x1.31b306p-4f, 0x1.5fe804p-4f,
325     0x1.8cf184p-4f, 0x1.b8de4ep-4f, 0x1.e3bc1ap-4f, 0x1.06cbd6p-3f,
326     0x1.1b3e72p-3f, 0x1.2f3b6ap-3f, 0x1.42c7e8p-3f, 0x1.55e8c6p-3f,
327     0x1.68a288p-3f, 0x1.7af974p-3f, 0x1.8cf184p-3f, 0x1.9e8e7cp-3f,
328     0x1.afd3e4p-3f, 0x1.c0c514p-3f, 0x1.d1653p-3f,  0x1.e1b734p-3f,
329     0x1.f1bdeep-3f, 0x1.00be06p-2f, 0x1.087a08p-2f, 0x1.101432p-2f,
330     0x1.178da6p-2f, 0x1.1ee778p-2f, 0x1.2622bp-2f,  0x1.2d404cp-2f,
331 };
332 
333 // Generated by Sollya with the following commands:
334 //   > display = hexadecimal;
335 //   > for i from 0 to 31 do print(round(1 / (1 + i * 2^-5), SG, RN));
336 constexpr cpp::array<float, 32> ONE_OVER_F_F = {
337     0x1p+0f,        0x1.f07c2p-1f,  0x1.e1e1e2p-1f, 0x1.d41d42p-1f,
338     0x1.c71c72p-1f, 0x1.bacf92p-1f, 0x1.af286cp-1f, 0x1.a41a42p-1f,
339     0x1.99999ap-1f, 0x1.8f9c18p-1f, 0x1.861862p-1f, 0x1.7d05f4p-1f,
340     0x1.745d18p-1f, 0x1.6c16c2p-1f, 0x1.642c86p-1f, 0x1.5c9882p-1f,
341     0x1.555556p-1f, 0x1.4e5e0ap-1f, 0x1.47ae14p-1f, 0x1.414142p-1f,
342     0x1.3b13b2p-1f, 0x1.3521dp-1f,  0x1.2f684cp-1f, 0x1.29e412p-1f,
343     0x1.24924ap-1f, 0x1.1f7048p-1f, 0x1.1a7b96p-1f, 0x1.15b1e6p-1f,
344     0x1.111112p-1f, 0x1.0c9714p-1f, 0x1.08421p-1f,  0x1.041042p-1f,
345 };
346 
347 } // namespace LIBC_NAMESPACE_DECL
348 
349 #endif // LLVM_LIBC_SRC_MATH_GENERIC_EXPXF16_H
350