xref: /aosp_15_r20/external/llvm-libc/src/__support/FPUtil/generic/FMA.h (revision 71db0c75aadcf003ffe3238005f61d7618a3fead)
1 //===-- Common header for FMA implementations -------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_FMA_H
10 #define LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_FMA_H
11 
12 #include "src/__support/CPP/bit.h"
13 #include "src/__support/CPP/limits.h"
14 #include "src/__support/CPP/type_traits.h"
15 #include "src/__support/FPUtil/BasicOperations.h"
16 #include "src/__support/FPUtil/FPBits.h"
17 #include "src/__support/FPUtil/cast.h"
18 #include "src/__support/FPUtil/dyadic_float.h"
19 #include "src/__support/FPUtil/rounding_mode.h"
20 #include "src/__support/big_int.h"
21 #include "src/__support/macros/attributes.h"   // LIBC_INLINE
22 #include "src/__support/macros/config.h"
23 #include "src/__support/macros/optimization.h" // LIBC_UNLIKELY
24 
25 #include "hdr/fenv_macros.h"
26 
27 namespace LIBC_NAMESPACE_DECL {
28 namespace fputil {
29 namespace generic {
30 
31 template <typename OutType, typename InType>
32 LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<OutType> &&
33                                  cpp::is_floating_point_v<InType> &&
34                                  sizeof(OutType) <= sizeof(InType),
35                              OutType>
36 fma(InType x, InType y, InType z);
37 
38 // TODO(lntue): Implement fmaf that is correctly rounded to all rounding modes.
39 // The implementation below only is only correct for the default rounding mode,
40 // round-to-nearest tie-to-even.
41 template <> LIBC_INLINE float fma<float>(float x, float y, float z) {
42   // Product is exact.
43   double prod = static_cast<double>(x) * static_cast<double>(y);
44   double z_d = static_cast<double>(z);
45   double sum = prod + z_d;
46   fputil::FPBits<double> bit_prod(prod), bitz(z_d), bit_sum(sum);
47 
48   if (!(bit_sum.is_inf_or_nan() || bit_sum.is_zero())) {
49     // Since the sum is computed in double precision, rounding might happen
50     // (for instance, when bitz.exponent > bit_prod.exponent + 5, or
51     // bit_prod.exponent > bitz.exponent + 40).  In that case, when we round
52     // the sum back to float, double rounding error might occur.
53     // A concrete example of this phenomenon is as follows:
54     //   x = y = 1 + 2^(-12), z = 2^(-53)
55     // The exact value of x*y + z is 1 + 2^(-11) + 2^(-24) + 2^(-53)
56     // So when rounding to float, fmaf(x, y, z) = 1 + 2^(-11) + 2^(-23)
57     // On the other hand, with the default rounding mode,
58     //   double(x*y + z) = 1 + 2^(-11) + 2^(-24)
59     // and casting again to float gives us:
60     //   float(double(x*y + z)) = 1 + 2^(-11).
61     //
62     // In order to correct this possible double rounding error, first we use
63     // Dekker's 2Sum algorithm to find t such that sum - t = prod + z exactly,
64     // assuming the (default) rounding mode is round-to-the-nearest,
65     // tie-to-even.  Moreover, t satisfies the condition that t < eps(sum),
66     // i.e., t.exponent < sum.exponent - 52. So if t is not 0, meaning rounding
67     // occurs when computing the sum, we just need to use t to adjust (any) last
68     // bit of sum, so that the sticky bits used when rounding sum to float are
69     // correct (when it matters).
70     fputil::FPBits<double> t(
71         (bit_prod.get_biased_exponent() >= bitz.get_biased_exponent())
72             ? ((bit_sum.get_val() - bit_prod.get_val()) - bitz.get_val())
73             : ((bit_sum.get_val() - bitz.get_val()) - bit_prod.get_val()));
74 
75     // Update sticky bits if t != 0.0 and the least (52 - 23 - 1 = 28) bits are
76     // zero.
77     if (!t.is_zero() && ((bit_sum.get_mantissa() & 0xfff'ffffULL) == 0)) {
78       if (bit_sum.sign() != t.sign())
79         bit_sum.set_mantissa(bit_sum.get_mantissa() + 1);
80       else if (bit_sum.get_mantissa())
81         bit_sum.set_mantissa(bit_sum.get_mantissa() - 1);
82     }
83   }
84 
85   return static_cast<float>(bit_sum.get_val());
86 }
87 
88 namespace internal {
89 
90 // Extract the sticky bits and shift the `mantissa` to the right by
91 // `shift_length`.
92 template <typename T>
93 LIBC_INLINE cpp::enable_if_t<is_unsigned_integral_or_big_int_v<T>, bool>
shift_mantissa(int shift_length,T & mant)94 shift_mantissa(int shift_length, T &mant) {
95   if (shift_length >= cpp::numeric_limits<T>::digits) {
96     mant = 0;
97     return true; // prod_mant is non-zero.
98   }
99   T mask = (T(1) << shift_length) - 1;
100   bool sticky_bits = (mant & mask) != 0;
101   mant >>= shift_length;
102   return sticky_bits;
103 }
104 
105 } // namespace internal
106 
107 template <typename OutType, typename InType>
108 LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<OutType> &&
109                                  cpp::is_floating_point_v<InType> &&
110                                  sizeof(OutType) <= sizeof(InType),
111                              OutType>
112 fma(InType x, InType y, InType z) {
113   using OutFPBits = FPBits<OutType>;
114   using OutStorageType = typename OutFPBits::StorageType;
115   using InFPBits = FPBits<InType>;
116   using InStorageType = typename InFPBits::StorageType;
117 
118   constexpr int IN_EXPLICIT_MANT_LEN = InFPBits::FRACTION_LEN + 1;
119   constexpr size_t PROD_LEN = 2 * IN_EXPLICIT_MANT_LEN;
120   constexpr size_t TMP_RESULT_LEN = cpp::bit_ceil(PROD_LEN + 1);
121   using TmpResultType = UInt<TMP_RESULT_LEN>;
122   using DyadicFloat = DyadicFloat<TMP_RESULT_LEN>;
123 
124   InFPBits x_bits(x), y_bits(y), z_bits(z);
125 
126   if (LIBC_UNLIKELY(x_bits.is_nan() || y_bits.is_nan() || z_bits.is_nan())) {
127     if (x_bits.is_nan() || y_bits.is_nan()) {
128       if (x_bits.is_signaling_nan() || y_bits.is_signaling_nan() ||
129           z_bits.is_signaling_nan())
130         raise_except_if_required(FE_INVALID);
131 
132       if (x_bits.is_quiet_nan()) {
133         InStorageType x_payload = x_bits.get_mantissa();
134         x_payload >>= InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
135         return OutFPBits::quiet_nan(x_bits.sign(),
136                                     static_cast<OutStorageType>(x_payload))
137             .get_val();
138       }
139 
140       if (y_bits.is_quiet_nan()) {
141         InStorageType y_payload = y_bits.get_mantissa();
142         y_payload >>= InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
143         return OutFPBits::quiet_nan(y_bits.sign(),
144                                     static_cast<OutStorageType>(y_payload))
145             .get_val();
146       }
147 
148       if (z_bits.is_quiet_nan()) {
149         InStorageType z_payload = z_bits.get_mantissa();
150         z_payload >>= InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
151         return OutFPBits::quiet_nan(z_bits.sign(),
152                                     static_cast<OutStorageType>(z_payload))
153             .get_val();
154       }
155 
156       return OutFPBits::quiet_nan().get_val();
157     }
158   }
159 
160   if (LIBC_UNLIKELY(x == 0 || y == 0 || z == 0))
161     return cast<OutType>(x * y + z);
162 
163   int x_exp = 0;
164   int y_exp = 0;
165   int z_exp = 0;
166 
167   // Denormal scaling = 2^(fraction length).
168   constexpr InStorageType IMPLICIT_MASK =
169       InFPBits::SIG_MASK - InFPBits::FRACTION_MASK;
170 
171   constexpr InType DENORMAL_SCALING =
172       InFPBits::create_value(
173           Sign::POS, InFPBits::FRACTION_LEN + InFPBits::EXP_BIAS, IMPLICIT_MASK)
174           .get_val();
175 
176   // Normalize denormal inputs.
177   if (LIBC_UNLIKELY(InFPBits(x).is_subnormal())) {
178     x_exp -= InFPBits::FRACTION_LEN;
179     x *= DENORMAL_SCALING;
180   }
181   if (LIBC_UNLIKELY(InFPBits(y).is_subnormal())) {
182     y_exp -= InFPBits::FRACTION_LEN;
183     y *= DENORMAL_SCALING;
184   }
185   if (LIBC_UNLIKELY(InFPBits(z).is_subnormal())) {
186     z_exp -= InFPBits::FRACTION_LEN;
187     z *= DENORMAL_SCALING;
188   }
189 
190   x_bits = InFPBits(x);
191   y_bits = InFPBits(y);
192   z_bits = InFPBits(z);
193   const Sign z_sign = z_bits.sign();
194   Sign prod_sign = (x_bits.sign() == y_bits.sign()) ? Sign::POS : Sign::NEG;
195   x_exp += x_bits.get_biased_exponent();
196   y_exp += y_bits.get_biased_exponent();
197   z_exp += z_bits.get_biased_exponent();
198 
199   if (LIBC_UNLIKELY(x_exp == InFPBits::MAX_BIASED_EXPONENT ||
200                     y_exp == InFPBits::MAX_BIASED_EXPONENT ||
201                     z_exp == InFPBits::MAX_BIASED_EXPONENT))
202     return cast<OutType>(x * y + z);
203 
204   // Extract mantissa and append hidden leading bits.
205   InStorageType x_mant = x_bits.get_explicit_mantissa();
206   InStorageType y_mant = y_bits.get_explicit_mantissa();
207   TmpResultType z_mant = z_bits.get_explicit_mantissa();
208 
209   // If the exponent of the product x*y > the exponent of z, then no extra
210   // precision beside the entire product x*y is needed.  On the other hand, when
211   // the exponent of z >= the exponent of the product x*y, the worst-case that
212   // we need extra precision is when there is cancellation and the most
213   // significant bit of the product is aligned exactly with the second most
214   // significant bit of z:
215   //      z :    10aa...a
216   // - prod :     1bb...bb....b
217   // In that case, in order to store the exact result, we need at least
218   //     (Length of prod) - (Fraction length of z)
219   //   = 2*(Length of input explicit mantissa) - (Fraction length of z) bits.
220   // Overall, before aligning the mantissas and exponents, we can simply left-
221   // shift the mantissa of z by that amount.  After that, it is enough to align
222   // the least significant bit, given that we keep track of the round and sticky
223   // bits after the least significant bit.
224 
225   TmpResultType prod_mant = TmpResultType(x_mant) * y_mant;
226   int prod_lsb_exp =
227       x_exp + y_exp - (InFPBits::EXP_BIAS + 2 * InFPBits::FRACTION_LEN);
228 
229   constexpr int RESULT_MIN_LEN = PROD_LEN - InFPBits::FRACTION_LEN;
230   z_mant <<= RESULT_MIN_LEN;
231   int z_lsb_exp = z_exp - (InFPBits::FRACTION_LEN + RESULT_MIN_LEN);
232   bool sticky_bits = false;
233   bool z_shifted = false;
234 
235   // Align exponents.
236   if (prod_lsb_exp < z_lsb_exp) {
237     sticky_bits = internal::shift_mantissa(z_lsb_exp - prod_lsb_exp, prod_mant);
238     prod_lsb_exp = z_lsb_exp;
239   } else if (z_lsb_exp < prod_lsb_exp) {
240     z_shifted = true;
241     sticky_bits = internal::shift_mantissa(prod_lsb_exp - z_lsb_exp, z_mant);
242   }
243 
244   // Perform the addition:
245   //   (-1)^prod_sign * prod_mant + (-1)^z_sign * z_mant.
246   // The final result will be stored in prod_sign and prod_mant.
247   if (prod_sign == z_sign) {
248     // Effectively an addition.
249     prod_mant += z_mant;
250   } else {
251     // Subtraction cases.
252     if (prod_mant >= z_mant) {
253       if (z_shifted && sticky_bits) {
254         // Add 1 more to the subtrahend so that the sticky bits remain
255         // positive. This would simplify the rounding logic.
256         ++z_mant;
257       }
258       prod_mant -= z_mant;
259     } else {
260       if (!z_shifted && sticky_bits) {
261         // Add 1 more to the subtrahend so that the sticky bits remain
262         // positive. This would simplify the rounding logic.
263         ++prod_mant;
264       }
265       prod_mant = z_mant - prod_mant;
266       prod_sign = z_sign;
267     }
268   }
269 
270   if (prod_mant == 0) {
271     // When there is exact cancellation, i.e., x*y == -z exactly, return -0.0 if
272     // rounding downward and +0.0 for other rounding modes.
273     if (quick_get_round() == FE_DOWNWARD)
274       prod_sign = Sign::NEG;
275     else
276       prod_sign = Sign::POS;
277   }
278 
279   DyadicFloat result(prod_sign, prod_lsb_exp - InFPBits::EXP_BIAS, prod_mant);
280   result.mantissa |= static_cast<unsigned int>(sticky_bits);
281   return result.template as<OutType, /*ShouldSignalExceptions=*/true>();
282 }
283 
284 } // namespace generic
285 } // namespace fputil
286 } // namespace LIBC_NAMESPACE_DECL
287 
288 #endif // LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_FMA_H
289