xref: /aosp_15_r20/external/llvm-libc/src/__support/FPUtil/generic/sqrt.h (revision 71db0c75aadcf003ffe3238005f61d7618a3fead)
1 //===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H
10 #define LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H
11 
12 #include "sqrt_80_bit_long_double.h"
13 #include "src/__support/CPP/bit.h" // countl_zero
14 #include "src/__support/CPP/type_traits.h"
15 #include "src/__support/FPUtil/FEnvImpl.h"
16 #include "src/__support/FPUtil/FPBits.h"
17 #include "src/__support/FPUtil/cast.h"
18 #include "src/__support/FPUtil/dyadic_float.h"
19 #include "src/__support/common.h"
20 #include "src/__support/macros/config.h"
21 #include "src/__support/uint128.h"
22 
23 #include "hdr/fenv_macros.h"
24 
25 namespace LIBC_NAMESPACE_DECL {
26 namespace fputil {
27 
28 namespace internal {
29 
// Type trait consulted by sqrt(): VALUE is true only for a type that must be
// routed to the dedicated 80-bit long double implementation (x86::sqrt)
// instead of the generic shift-and-add algorithm. Defaults to false.
template <typename T> struct SpecialLongDouble {
  static constexpr bool VALUE{false};
};

#ifdef LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80
// long double is the x86 80-bit extended format on this target, so it takes
// the dedicated path.
template <> struct SpecialLongDouble<long double> {
  static constexpr bool VALUE{true};
};
#endif // LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80
39 
40 template <typename T>
41 LIBC_INLINE void normalize(int &exponent,
42                            typename FPBits<T>::StorageType &mantissa) {
43   const int shift =
44       cpp::countl_zero(mantissa) -
45       (8 * static_cast<int>(sizeof(mantissa)) - 1 - FPBits<T>::FRACTION_LEN);
46   exponent -= shift;
47   mantissa <<= shift;
48 }
49 
50 #ifdef LIBC_TYPES_LONG_DOUBLE_IS_FLOAT64
51 template <>
52 LIBC_INLINE void normalize<long double>(int &exponent, uint64_t &mantissa) {
53   normalize<double>(exponent, mantissa);
54 }
55 #elif !defined(LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80)
56 template <>
57 LIBC_INLINE void normalize<long double>(int &exponent, UInt128 &mantissa) {
58   const uint64_t hi_bits = static_cast<uint64_t>(mantissa >> 64);
59   const int shift =
60       hi_bits ? (cpp::countl_zero(hi_bits) - 15)
61               : (cpp::countl_zero(static_cast<uint64_t>(mantissa)) + 49);
62   exponent -= shift;
63   mantissa <<= shift;
64 }
65 #endif
66 
67 } // namespace internal
68 
69 // Correctly rounded IEEE 754 SQRT for all rounding modes.
70 // Shift-and-add algorithm.
71 template <typename OutType, typename InType>
72 LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<OutType> &&
73                                  cpp::is_floating_point_v<InType> &&
74                                  sizeof(OutType) <= sizeof(InType),
75                              OutType>
76 sqrt(InType x) {
77   if constexpr (internal::SpecialLongDouble<OutType>::VALUE &&
78                 internal::SpecialLongDouble<InType>::VALUE) {
79     // Special 80-bit long double.
80     return x86::sqrt(x);
81   } else {
82     // IEEE floating points formats.
83     using OutFPBits = FPBits<OutType>;
84     using InFPBits = FPBits<InType>;
85     using InStorageType = typename InFPBits::StorageType;
86     using DyadicFloat =
87         DyadicFloat<cpp::bit_ceil(static_cast<size_t>(InFPBits::STORAGE_LEN))>;
88 
89     constexpr InStorageType ONE = InStorageType(1) << InFPBits::FRACTION_LEN;
90     constexpr auto FLT_NAN = OutFPBits::quiet_nan().get_val();
91 
92     InFPBits bits(x);
93 
94     if (bits == InFPBits::inf(Sign::POS) || bits.is_zero() || bits.is_nan()) {
95       // sqrt(+Inf) = +Inf
96       // sqrt(+0) = +0
97       // sqrt(-0) = -0
98       // sqrt(NaN) = NaN
99       // sqrt(-NaN) = -NaN
100       return cast<OutType>(x);
101     } else if (bits.is_neg()) {
102       // sqrt(-Inf) = NaN
103       // sqrt(-x) = NaN
104       return FLT_NAN;
105     } else {
106       int x_exp = bits.get_exponent();
107       InStorageType x_mant = bits.get_mantissa();
108 
109       // Step 1a: Normalize denormal input and append hidden bit to the mantissa
110       if (bits.is_subnormal()) {
111         ++x_exp; // let x_exp be the correct exponent of ONE bit.
112         internal::normalize<InType>(x_exp, x_mant);
113       } else {
114         x_mant |= ONE;
115       }
116 
117       // Step 1b: Make sure the exponent is even.
118       if (x_exp & 1) {
119         --x_exp;
120         x_mant <<= 1;
121       }
122 
123       // After step 1b, x = 2^(x_exp) * x_mant, where x_exp is even, and
124       // 1 <= x_mant < 4.  So sqrt(x) = 2^(x_exp / 2) * y, with 1 <= y < 2.
125       // Notice that the output of sqrt is always in the normal range.
126       // To perform shift-and-add algorithm to find y, let denote:
127       //   y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be:
128       //   r(n) = 2^n ( x_mant - y(n)^2 ).
129       // That leads to the following recurrence formula:
130       //   r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n-1) ]
131       // with the initial conditions: y(0) = 1, and r(0) = x - 1.
132       // So the nth digit y_n of the mantissa of sqrt(x) can be found by:
133       //   y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1)
134       //         0 otherwise.
135       InStorageType y = ONE;
136       InStorageType r = x_mant - ONE;
137 
138       // TODO: Reduce iteration count to OutFPBits::FRACTION_LEN + 2 or + 3.
139       for (InStorageType current_bit = ONE >> 1; current_bit;
140            current_bit >>= 1) {
141         r <<= 1;
142         // 2*y(n - 1) + 2^(-n-1)
143         InStorageType tmp = static_cast<InStorageType>((y << 1) + current_bit);
144         if (r >= tmp) {
145           r -= tmp;
146           y += current_bit;
147         }
148       }
149 
150       // We compute one more iteration in order to round correctly.
151       r <<= 2;
152       y <<= 2;
153       InStorageType tmp = y + 1;
154       if (r >= tmp) {
155         r -= tmp;
156         // Rounding bit.
157         y |= 2;
158       }
159       // Sticky bit.
160       y |= static_cast<unsigned int>(r != 0);
161 
162       DyadicFloat yd(Sign::POS, (x_exp >> 1) - 2 - InFPBits::FRACTION_LEN, y);
163       return yd.template as<OutType, /*ShouldSignalExceptions=*/true>();
164     }
165   }
166 }
167 
168 } // namespace fputil
169 } // namespace LIBC_NAMESPACE_DECL
170 
171 #endif // LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H
172