xref: /aosp_15_r20/external/llvm-libc/src/__support/FPUtil/dyadic_float.h (revision 71db0c75aadcf003ffe3238005f61d7618a3fead)
1 //===-- A class to store high precision floating point numbers --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIBC_SRC___SUPPORT_FPUTIL_DYADIC_FLOAT_H
10 #define LLVM_LIBC_SRC___SUPPORT_FPUTIL_DYADIC_FLOAT_H
11 
12 #include "FEnvImpl.h"
13 #include "FPBits.h"
14 #include "hdr/errno_macros.h"
15 #include "hdr/fenv_macros.h"
16 #include "multiply_add.h"
17 #include "rounding_mode.h"
18 #include "src/__support/CPP/type_traits.h"
19 #include "src/__support/big_int.h"
20 #include "src/__support/macros/config.h"
21 #include "src/__support/macros/optimization.h" // LIBC_UNLIKELY
22 #include "src/__support/macros/properties/types.h"
23 
24 #include <stddef.h>
25 
26 namespace LIBC_NAMESPACE_DECL {
27 namespace fputil {
28 
29 // A generic class to perform computations of high precision floating points.
30 // We store the value in dyadic format, including 3 fields:
31 //   sign    : boolean value - false means positive, true means negative
32 //   exponent: the exponent value of the least significant bit of the mantissa.
33 //   mantissa: unsigned integer of length `Bits`.
34 // So the real value that is stored is:
35 //   real value = (-1)^sign * 2^exponent * (mantissa as unsigned integer)
36 // The stored data is normal if for non-zero mantissa, the leading bit is 1.
37 // The outputs of the constructors and most functions will be normalized.
38 // To simplify and improve the efficiency, many functions will assume that the
39 // inputs are normal.
40 template <size_t Bits> struct DyadicFloat {
41   using MantissaType = LIBC_NAMESPACE::UInt<Bits>;
42 
43   Sign sign = Sign::POS;
44   int exponent = 0;
45   MantissaType mantissa = MantissaType(0);
46 
47   LIBC_INLINE constexpr DyadicFloat() = default;
48 
49   template <typename T, cpp::enable_if_t<cpp::is_floating_point_v<T>, int> = 0>
DyadicFloatDyadicFloat50   LIBC_INLINE constexpr DyadicFloat(T x) {
51     static_assert(FPBits<T>::FRACTION_LEN < Bits);
52     FPBits<T> x_bits(x);
53     sign = x_bits.sign();
54     exponent = x_bits.get_explicit_exponent() - FPBits<T>::FRACTION_LEN;
55     mantissa = MantissaType(x_bits.get_explicit_mantissa());
56     normalize();
57   }
58 
DyadicFloatDyadicFloat59   LIBC_INLINE constexpr DyadicFloat(Sign s, int e, MantissaType m)
60       : sign(s), exponent(e), mantissa(m) {
61     normalize();
62   }
63 
64   // Normalizing the mantissa, bringing the leading 1 bit to the most
65   // significant bit.
normalizeDyadicFloat66   LIBC_INLINE constexpr DyadicFloat &normalize() {
67     if (!mantissa.is_zero()) {
68       int shift_length = cpp::countl_zero(mantissa);
69       exponent -= shift_length;
70       mantissa <<= static_cast<size_t>(shift_length);
71     }
72     return *this;
73   }
74 
75   // Used for aligning exponents.  Output might not be normalized.
shift_leftDyadicFloat76   LIBC_INLINE constexpr DyadicFloat &shift_left(unsigned shift_length) {
77     if (shift_length < Bits) {
78       exponent -= static_cast<int>(shift_length);
79       mantissa <<= shift_length;
80     } else {
81       exponent = 0;
82       mantissa = MantissaType(0);
83     }
84     return *this;
85   }
86 
87   // Used for aligning exponents.  Output might not be normalized.
shift_rightDyadicFloat88   LIBC_INLINE constexpr DyadicFloat &shift_right(unsigned shift_length) {
89     if (shift_length < Bits) {
90       exponent += static_cast<int>(shift_length);
91       mantissa >>= shift_length;
92     } else {
93       exponent = 0;
94       mantissa = MantissaType(0);
95     }
96     return *this;
97   }
98 
99   // Assume that it is already normalized.  Output the unbiased exponent.
get_unbiased_exponentDyadicFloat100   LIBC_INLINE constexpr int get_unbiased_exponent() const {
101     return exponent + (Bits - 1);
102   }
103 
104 #ifdef LIBC_TYPES_HAS_FLOAT16
105   template <typename T, bool ShouldSignalExceptions>
106   LIBC_INLINE constexpr cpp::enable_if_t<
107       cpp::is_floating_point_v<T> && (FPBits<T>::FRACTION_LEN < Bits), T>
generic_asDyadicFloat108   generic_as() const {
109     using FPBits = FPBits<float16>;
110     using StorageType = typename FPBits::StorageType;
111 
112     constexpr int EXTRA_FRACTION_LEN = Bits - 1 - FPBits::FRACTION_LEN;
113 
114     if (mantissa == 0)
115       return FPBits::zero(sign).get_val();
116 
117     int unbiased_exp = get_unbiased_exponent();
118 
119     if (unbiased_exp + FPBits::EXP_BIAS >= FPBits::MAX_BIASED_EXPONENT) {
120       if constexpr (ShouldSignalExceptions) {
121         set_errno_if_required(ERANGE);
122         raise_except_if_required(FE_OVERFLOW | FE_INEXACT);
123       }
124 
125       switch (quick_get_round()) {
126       case FE_TONEAREST:
127         return FPBits::inf(sign).get_val();
128       case FE_TOWARDZERO:
129         return FPBits::max_normal(sign).get_val();
130       case FE_DOWNWARD:
131         if (sign.is_pos())
132           return FPBits::max_normal(Sign::POS).get_val();
133         return FPBits::inf(Sign::NEG).get_val();
134       case FE_UPWARD:
135         if (sign.is_neg())
136           return FPBits::max_normal(Sign::NEG).get_val();
137         return FPBits::inf(Sign::POS).get_val();
138       default:
139         __builtin_unreachable();
140       }
141     }
142 
143     StorageType out_biased_exp = 0;
144     StorageType out_mantissa = 0;
145     bool round = false;
146     bool sticky = false;
147     bool underflow = false;
148 
149     if (unbiased_exp < -FPBits::EXP_BIAS - FPBits::FRACTION_LEN) {
150       sticky = true;
151       underflow = true;
152     } else if (unbiased_exp == -FPBits::EXP_BIAS - FPBits::FRACTION_LEN) {
153       round = true;
154       MantissaType sticky_mask = (MantissaType(1) << (Bits - 1)) - 1;
155       sticky = (mantissa & sticky_mask) != 0;
156     } else {
157       int extra_fraction_len = EXTRA_FRACTION_LEN;
158 
159       if (unbiased_exp < 1 - FPBits::EXP_BIAS) {
160         underflow = true;
161         extra_fraction_len += 1 - FPBits::EXP_BIAS - unbiased_exp;
162       } else {
163         out_biased_exp =
164             static_cast<StorageType>(unbiased_exp + FPBits::EXP_BIAS);
165       }
166 
167       MantissaType round_mask = MantissaType(1) << (extra_fraction_len - 1);
168       round = (mantissa & round_mask) != 0;
169       MantissaType sticky_mask = round_mask - 1;
170       sticky = (mantissa & sticky_mask) != 0;
171 
172       out_mantissa = static_cast<StorageType>(mantissa >> extra_fraction_len);
173     }
174 
175     bool lsb = (out_mantissa & 1) != 0;
176 
177     StorageType result =
178         FPBits::create_value(sign, out_biased_exp, out_mantissa).uintval();
179 
180     switch (quick_get_round()) {
181     case FE_TONEAREST:
182       if (round && (lsb || sticky))
183         ++result;
184       break;
185     case FE_DOWNWARD:
186       if (sign.is_neg() && (round || sticky))
187         ++result;
188       break;
189     case FE_UPWARD:
190       if (sign.is_pos() && (round || sticky))
191         ++result;
192       break;
193     default:
194       break;
195     }
196 
197     if (ShouldSignalExceptions && (round || sticky)) {
198       int excepts = FE_INEXACT;
199       if (FPBits(result).is_inf()) {
200         set_errno_if_required(ERANGE);
201         excepts |= FE_OVERFLOW;
202       } else if (underflow) {
203         set_errno_if_required(ERANGE);
204         excepts |= FE_UNDERFLOW;
205       }
206       raise_except_if_required(excepts);
207     }
208 
209     return FPBits(result).get_val();
210   }
211 #endif // LIBC_TYPES_HAS_FLOAT16
212 
213   template <typename T, bool ShouldSignalExceptions,
214             typename = cpp::enable_if_t<cpp::is_floating_point_v<T> &&
215                                             (FPBits<T>::FRACTION_LEN < Bits),
216                                         void>>
fast_asDyadicFloat217   LIBC_INLINE constexpr T fast_as() const {
218     if (LIBC_UNLIKELY(mantissa.is_zero()))
219       return FPBits<T>::zero(sign).get_val();
220 
221     // Assume that it is normalized, and output is also normal.
222     constexpr uint32_t PRECISION = FPBits<T>::FRACTION_LEN + 1;
223     using output_bits_t = typename FPBits<T>::StorageType;
224     constexpr output_bits_t IMPLICIT_MASK =
225         FPBits<T>::SIG_MASK - FPBits<T>::FRACTION_MASK;
226 
227     int exp_hi = exponent + static_cast<int>((Bits - 1) + FPBits<T>::EXP_BIAS);
228 
229     if (LIBC_UNLIKELY(exp_hi > 2 * FPBits<T>::EXP_BIAS)) {
230       // Results overflow.
231       T d_hi =
232           FPBits<T>::create_value(sign, 2 * FPBits<T>::EXP_BIAS, IMPLICIT_MASK)
233               .get_val();
234       // volatile prevents constant propagation that would result in infinity
235       // always being returned no matter the current rounding mode.
236       volatile T two = static_cast<T>(2.0);
237       T r = two * d_hi;
238 
239       // TODO: Whether rounding down the absolute value to max_normal should
240       // also raise FE_OVERFLOW and set ERANGE is debatable.
241       if (ShouldSignalExceptions && FPBits<T>(r).is_inf())
242         set_errno_if_required(ERANGE);
243 
244       return r;
245     }
246 
247     bool denorm = false;
248     uint32_t shift = Bits - PRECISION;
249     if (LIBC_UNLIKELY(exp_hi <= 0)) {
250       // Output is denormal.
251       denorm = true;
252       shift = (Bits - PRECISION) + static_cast<uint32_t>(1 - exp_hi);
253 
254       exp_hi = FPBits<T>::EXP_BIAS;
255     }
256 
257     int exp_lo = exp_hi - static_cast<int>(PRECISION) - 1;
258 
259     MantissaType m_hi =
260         shift >= MantissaType::BITS ? MantissaType(0) : mantissa >> shift;
261 
262     T d_hi = FPBits<T>::create_value(
263                  sign, static_cast<output_bits_t>(exp_hi),
264                  (static_cast<output_bits_t>(m_hi) & FPBits<T>::SIG_MASK) |
265                      IMPLICIT_MASK)
266                  .get_val();
267 
268     MantissaType round_mask =
269         shift > MantissaType::BITS ? 0 : MantissaType(1) << (shift - 1);
270     MantissaType sticky_mask = round_mask - MantissaType(1);
271 
272     bool round_bit = !(mantissa & round_mask).is_zero();
273     bool sticky_bit = !(mantissa & sticky_mask).is_zero();
274     int round_and_sticky = int(round_bit) * 2 + int(sticky_bit);
275 
276     T d_lo;
277 
278     if (LIBC_UNLIKELY(exp_lo <= 0)) {
279       // d_lo is denormal, but the output is normal.
280       int scale_up_exponent = 1 - exp_lo;
281       T scale_up_factor =
282           FPBits<T>::create_value(Sign::POS,
283                                   static_cast<output_bits_t>(
284                                       FPBits<T>::EXP_BIAS + scale_up_exponent),
285                                   IMPLICIT_MASK)
286               .get_val();
287       T scale_down_factor =
288           FPBits<T>::create_value(Sign::POS,
289                                   static_cast<output_bits_t>(
290                                       FPBits<T>::EXP_BIAS - scale_up_exponent),
291                                   IMPLICIT_MASK)
292               .get_val();
293 
294       d_lo = FPBits<T>::create_value(
295                  sign, static_cast<output_bits_t>(exp_lo + scale_up_exponent),
296                  IMPLICIT_MASK)
297                  .get_val();
298 
299       return multiply_add(d_lo, T(round_and_sticky), d_hi * scale_up_factor) *
300              scale_down_factor;
301     }
302 
303     d_lo = FPBits<T>::create_value(sign, static_cast<output_bits_t>(exp_lo),
304                                    IMPLICIT_MASK)
305                .get_val();
306 
307     // Still correct without FMA instructions if `d_lo` is not underflow.
308     T r = multiply_add(d_lo, T(round_and_sticky), d_hi);
309 
310     if (LIBC_UNLIKELY(denorm)) {
311       // Exponent before rounding is in denormal range, simply clear the
312       // exponent field.
313       output_bits_t clear_exp = static_cast<output_bits_t>(
314           output_bits_t(exp_hi) << FPBits<T>::SIG_LEN);
315       output_bits_t r_bits = FPBits<T>(r).uintval() - clear_exp;
316 
317       if (!(r_bits & FPBits<T>::EXP_MASK)) {
318         // Output is denormal after rounding, clear the implicit bit for 80-bit
319         // long double.
320         r_bits -= IMPLICIT_MASK;
321 
322         // TODO: IEEE Std 754-2019 lets implementers choose whether to check for
323         // "tininess" before or after rounding for base-2 formats, as long as
324         // the same choice is made for all operations. Our choice to check after
325         // rounding might not be the same as the hardware's.
326         if (ShouldSignalExceptions && round_and_sticky) {
327           set_errno_if_required(ERANGE);
328           raise_except_if_required(FE_UNDERFLOW);
329         }
330       }
331 
332       return FPBits<T>(r_bits).get_val();
333     }
334 
335     return r;
336   }
337 
338   // Assume that it is already normalized.
339   // Output is rounded correctly with respect to the current rounding mode.
340   template <typename T, bool ShouldSignalExceptions,
341             typename = cpp::enable_if_t<cpp::is_floating_point_v<T> &&
342                                             (FPBits<T>::FRACTION_LEN < Bits),
343                                         void>>
asDyadicFloat344   LIBC_INLINE constexpr T as() const {
345 #if defined(LIBC_TYPES_HAS_FLOAT16) && !defined(__LIBC_USE_FLOAT16_CONVERSION)
346     if constexpr (cpp::is_same_v<T, float16>)
347       return generic_as<T, ShouldSignalExceptions>();
348 #endif
349     return fast_as<T, ShouldSignalExceptions>();
350   }
351 
352   template <typename T,
353             typename = cpp::enable_if_t<cpp::is_floating_point_v<T> &&
354                                             (FPBits<T>::FRACTION_LEN < Bits),
355                                         void>>
TDyadicFloat356   LIBC_INLINE explicit constexpr operator T() const {
357     return as<T, /*ShouldSignalExceptions=*/false>();
358   }
359 
as_mantissa_typeDyadicFloat360   LIBC_INLINE constexpr MantissaType as_mantissa_type() const {
361     if (mantissa.is_zero())
362       return 0;
363 
364     MantissaType new_mant = mantissa;
365     if (exponent > 0) {
366       new_mant <<= exponent;
367     } else {
368       new_mant >>= (-exponent);
369     }
370 
371     if (sign.is_neg()) {
372       new_mant = (~new_mant) + 1;
373     }
374 
375     return new_mant;
376   }
377 };
378 
379 // Quick add - Add 2 dyadic floats with rounding toward 0 and then normalize the
380 // output:
381 //   - Align the exponents so that:
382 //     new a.exponent = new b.exponent = max(a.exponent, b.exponent)
383 //   - Add or subtract the mantissas depending on the signs.
384 //   - Normalize the result.
385 // The absolute errors compared to the mathematical sum is bounded by:
386 //   | quick_add(a, b) - (a + b) | < MSB(a + b) * 2^(-Bits + 2),
387 // i.e., errors are up to 2 ULPs.
388 // Assume inputs are normalized (by constructors or other functions) so that we
389 // don't need to normalize the inputs again in this function.  If the inputs are
390 // not normalized, the results might lose precision significantly.
391 template <size_t Bits>
quick_add(DyadicFloat<Bits> a,DyadicFloat<Bits> b)392 LIBC_INLINE constexpr DyadicFloat<Bits> quick_add(DyadicFloat<Bits> a,
393                                                   DyadicFloat<Bits> b) {
394   if (LIBC_UNLIKELY(a.mantissa.is_zero()))
395     return b;
396   if (LIBC_UNLIKELY(b.mantissa.is_zero()))
397     return a;
398 
399   // Align exponents
400   if (a.exponent > b.exponent)
401     b.shift_right(static_cast<unsigned>(a.exponent - b.exponent));
402   else if (b.exponent > a.exponent)
403     a.shift_right(static_cast<unsigned>(b.exponent - a.exponent));
404 
405   DyadicFloat<Bits> result;
406 
407   if (a.sign == b.sign) {
408     // Addition
409     result.sign = a.sign;
410     result.exponent = a.exponent;
411     result.mantissa = a.mantissa;
412     if (result.mantissa.add_overflow(b.mantissa)) {
413       // Mantissa addition overflow.
414       result.shift_right(1);
415       result.mantissa.val[DyadicFloat<Bits>::MantissaType::WORD_COUNT - 1] |=
416           (uint64_t(1) << 63);
417     }
418     // Result is already normalized.
419     return result;
420   }
421 
422   // Subtraction
423   if (a.mantissa >= b.mantissa) {
424     result.sign = a.sign;
425     result.exponent = a.exponent;
426     result.mantissa = a.mantissa - b.mantissa;
427   } else {
428     result.sign = b.sign;
429     result.exponent = b.exponent;
430     result.mantissa = b.mantissa - a.mantissa;
431   }
432 
433   return result.normalize();
434 }
435 
436 // Quick Mul - Slightly less accurate but efficient multiplication of 2 dyadic
437 // floats with rounding toward 0 and then normalize the output:
438 //   result.exponent = a.exponent + b.exponent + Bits,
439 //   result.mantissa = quick_mul_hi(a.mantissa + b.mantissa)
440 //                   ~ (full product a.mantissa * b.mantissa) >> Bits.
441 // The errors compared to the mathematical product is bounded by:
442 //   2 * errors of quick_mul_hi = 2 * (UInt<Bits>::WORD_COUNT - 1) in ULPs.
443 // Assume inputs are normalized (by constructors or other functions) so that we
444 // don't need to normalize the inputs again in this function.  If the inputs are
445 // not normalized, the results might lose precision significantly.
446 template <size_t Bits>
quick_mul(const DyadicFloat<Bits> & a,const DyadicFloat<Bits> & b)447 LIBC_INLINE constexpr DyadicFloat<Bits> quick_mul(const DyadicFloat<Bits> &a,
448                                                   const DyadicFloat<Bits> &b) {
449   DyadicFloat<Bits> result;
450   result.sign = (a.sign != b.sign) ? Sign::NEG : Sign::POS;
451   result.exponent = a.exponent + b.exponent + static_cast<int>(Bits);
452 
453   if (!(a.mantissa.is_zero() || b.mantissa.is_zero())) {
454     result.mantissa = a.mantissa.quick_mul_hi(b.mantissa);
455     // Check the leading bit directly, should be faster than using clz in
456     // normalize().
457     if (result.mantissa.val[DyadicFloat<Bits>::MantissaType::WORD_COUNT - 1] >>
458             63 ==
459         0)
460       result.shift_left(1);
461   } else {
462     result.mantissa = (typename DyadicFloat<Bits>::MantissaType)(0);
463   }
464   return result;
465 }
466 
467 // Simple polynomial approximation.
468 template <size_t Bits>
469 LIBC_INLINE constexpr DyadicFloat<Bits>
multiply_add(const DyadicFloat<Bits> & a,const DyadicFloat<Bits> & b,const DyadicFloat<Bits> & c)470 multiply_add(const DyadicFloat<Bits> &a, const DyadicFloat<Bits> &b,
471              const DyadicFloat<Bits> &c) {
472   return quick_add(c, quick_mul(a, b));
473 }
474 
475 // Simple exponentiation implementation for printf. Only handles positive
476 // exponents, since division isn't implemented.
477 template <size_t Bits>
pow_n(const DyadicFloat<Bits> & a,uint32_t power)478 LIBC_INLINE constexpr DyadicFloat<Bits> pow_n(const DyadicFloat<Bits> &a,
479                                               uint32_t power) {
480   DyadicFloat<Bits> result = 1.0;
481   DyadicFloat<Bits> cur_power = a;
482 
483   while (power > 0) {
484     if ((power % 2) > 0) {
485       result = quick_mul(result, cur_power);
486     }
487     power = power >> 1;
488     cur_power = quick_mul(cur_power, cur_power);
489   }
490   return result;
491 }
492 
493 template <size_t Bits>
mul_pow_2(const DyadicFloat<Bits> & a,int32_t pow_2)494 LIBC_INLINE constexpr DyadicFloat<Bits> mul_pow_2(const DyadicFloat<Bits> &a,
495                                                   int32_t pow_2) {
496   DyadicFloat<Bits> result = a;
497   result.exponent += pow_2;
498   return result;
499 }
500 
501 } // namespace fputil
502 } // namespace LIBC_NAMESPACE_DECL
503 
504 #endif // LLVM_LIBC_SRC___SUPPORT_FPUTIL_DYADIC_FLOAT_H
505