1 //===-- Square root of IEEE 754 floating point numbers ----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H 10 #define LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H 11 12 #include "sqrt_80_bit_long_double.h" 13 #include "src/__support/CPP/bit.h" // countl_zero 14 #include "src/__support/CPP/type_traits.h" 15 #include "src/__support/FPUtil/FEnvImpl.h" 16 #include "src/__support/FPUtil/FPBits.h" 17 #include "src/__support/FPUtil/cast.h" 18 #include "src/__support/FPUtil/dyadic_float.h" 19 #include "src/__support/common.h" 20 #include "src/__support/macros/config.h" 21 #include "src/__support/uint128.h" 22 23 #include "hdr/fenv_macros.h" 24 25 namespace LIBC_NAMESPACE_DECL { 26 namespace fputil { 27 28 namespace internal { 29 30 template <typename T> struct SpecialLongDouble { 31 static constexpr bool VALUE = false; 32 }; 33 34 #if defined(LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80) 35 template <> struct SpecialLongDouble<long double> { 36 static constexpr bool VALUE = true; 37 }; 38 #endif // LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80 39 40 template <typename T> 41 LIBC_INLINE void normalize(int &exponent, 42 typename FPBits<T>::StorageType &mantissa) { 43 const int shift = 44 cpp::countl_zero(mantissa) - 45 (8 * static_cast<int>(sizeof(mantissa)) - 1 - FPBits<T>::FRACTION_LEN); 46 exponent -= shift; 47 mantissa <<= shift; 48 } 49 50 #ifdef LIBC_TYPES_LONG_DOUBLE_IS_FLOAT64 51 template <> 52 LIBC_INLINE void normalize<long double>(int &exponent, uint64_t &mantissa) { 53 normalize<double>(exponent, mantissa); 54 } 55 #elif !defined(LIBC_TYPES_LONG_DOUBLE_IS_X86_FLOAT80) 56 template <> 57 LIBC_INLINE void normalize<long double>(int &exponent, UInt128 &mantissa) { 58 const uint64_t hi_bits = static_cast<uint64_t>(mantissa >> 64); 59 const int shift = 60 hi_bits ? (cpp::countl_zero(hi_bits) - 15) 61 : (cpp::countl_zero(static_cast<uint64_t>(mantissa)) + 49); 62 exponent -= shift; 63 mantissa <<= shift; 64 } 65 #endif 66 67 } // namespace internal 68 69 // Correctly rounded IEEE 754 SQRT for all rounding modes. 70 // Shift-and-add algorithm. 71 template <typename OutType, typename InType> 72 LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<OutType> && 73 cpp::is_floating_point_v<InType> && 74 sizeof(OutType) <= sizeof(InType), 75 OutType> 76 sqrt(InType x) { 77 if constexpr (internal::SpecialLongDouble<OutType>::VALUE && 78 internal::SpecialLongDouble<InType>::VALUE) { 79 // Special 80-bit long double. 80 return x86::sqrt(x); 81 } else { 82 // IEEE floating points formats. 83 using OutFPBits = FPBits<OutType>; 84 using InFPBits = FPBits<InType>; 85 using InStorageType = typename InFPBits::StorageType; 86 using DyadicFloat = 87 DyadicFloat<cpp::bit_ceil(static_cast<size_t>(InFPBits::STORAGE_LEN))>; 88 89 constexpr InStorageType ONE = InStorageType(1) << InFPBits::FRACTION_LEN; 90 constexpr auto FLT_NAN = OutFPBits::quiet_nan().get_val(); 91 92 InFPBits bits(x); 93 94 if (bits == InFPBits::inf(Sign::POS) || bits.is_zero() || bits.is_nan()) { 95 // sqrt(+Inf) = +Inf 96 // sqrt(+0) = +0 97 // sqrt(-0) = -0 98 // sqrt(NaN) = NaN 99 // sqrt(-NaN) = -NaN 100 return cast<OutType>(x); 101 } else if (bits.is_neg()) { 102 // sqrt(-Inf) = NaN 103 // sqrt(-x) = NaN 104 return FLT_NAN; 105 } else { 106 int x_exp = bits.get_exponent(); 107 InStorageType x_mant = bits.get_mantissa(); 108 109 // Step 1a: Normalize denormal input and append hidden bit to the mantissa 110 if (bits.is_subnormal()) { 111 ++x_exp; // let x_exp be the correct exponent of ONE bit. 112 internal::normalize<InType>(x_exp, x_mant); 113 } else { 114 x_mant |= ONE; 115 } 116 117 // Step 1b: Make sure the exponent is even. 118 if (x_exp & 1) { 119 --x_exp; 120 x_mant <<= 1; 121 } 122 123 // After step 1b, x = 2^(x_exp) * x_mant, where x_exp is even, and 124 // 1 <= x_mant < 4. So sqrt(x) = 2^(x_exp / 2) * y, with 1 <= y < 2. 125 // Notice that the output of sqrt is always in the normal range. 126 // To perform shift-and-add algorithm to find y, let denote: 127 // y(n) = 1.y_1 y_2 ... y_n, we can define the nth residue to be: 128 // r(n) = 2^n ( x_mant - y(n)^2 ). 129 // That leads to the following recurrence formula: 130 // r(n) = 2*r(n-1) - y_n*[ 2*y(n-1) + 2^(-n-1) ] 131 // with the initial conditions: y(0) = 1, and r(0) = x - 1. 132 // So the nth digit y_n of the mantissa of sqrt(x) can be found by: 133 // y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1) 134 // 0 otherwise. 135 InStorageType y = ONE; 136 InStorageType r = x_mant - ONE; 137 138 // TODO: Reduce iteration count to OutFPBits::FRACTION_LEN + 2 or + 3. 139 for (InStorageType current_bit = ONE >> 1; current_bit; 140 current_bit >>= 1) { 141 r <<= 1; 142 // 2*y(n - 1) + 2^(-n-1) 143 InStorageType tmp = static_cast<InStorageType>((y << 1) + current_bit); 144 if (r >= tmp) { 145 r -= tmp; 146 y += current_bit; 147 } 148 } 149 150 // We compute one more iteration in order to round correctly. 151 r <<= 2; 152 y <<= 2; 153 InStorageType tmp = y + 1; 154 if (r >= tmp) { 155 r -= tmp; 156 // Rounding bit. 157 y |= 2; 158 } 159 // Sticky bit. 160 y |= static_cast<unsigned int>(r != 0); 161 162 DyadicFloat yd(Sign::POS, (x_exp >> 1) - 2 - InFPBits::FRACTION_LEN, y); 163 return yd.template as<OutType, /*ShouldSignalExceptions=*/true>(); 164 } 165 } 166 } 167 168 } // namespace fputil 169 } // namespace LIBC_NAMESPACE_DECL 170 171 #endif // LLVM_LIBC_SRC___SUPPORT_FPUTIL_GENERIC_SQRT_H 172