xref: /aosp_15_r20/external/pytorch/c10/util/Half.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 /// Defines the Half type (half-precision floating-point) including conversions
4 /// to standard C types and basic arithmetic operations. Note that arithmetic
5 /// operations are implemented by converting to floating point and
6 /// performing the operation in float32, instead of using CUDA half intrinsics.
7 /// Most uses of this type within ATen are memory bound, including the
8 /// element-wise kernels, and the half intrinsics aren't efficient on all GPUs.
9 /// If you are writing a compute bound kernel, you can use the CUDA half
10 /// intrinsics directly on the Half type from device code.
11 
12 #include <c10/macros/Export.h>
13 #include <c10/macros/Macros.h>
14 #include <c10/util/TypeSafeSignMath.h>
15 #include <c10/util/bit_cast.h>
16 #include <c10/util/complex.h>
17 #include <c10/util/floating_point_utils.h>
18 #include <type_traits>
19 
20 #if defined(__cplusplus)
21 #include <cmath>
22 #elif !defined(__OPENCL_VERSION__)
23 #include <math.h>
24 #endif
25 
26 #ifdef _MSC_VER
27 #include <intrin.h>
28 #endif
29 
30 #include <cstdint>
31 #include <cstring>
32 #include <iosfwd>
33 #include <limits>
34 #include <ostream>
35 
36 #ifdef __CUDACC__
37 #include <cuda_fp16.h>
38 #endif
39 
40 #ifdef __HIPCC__
41 #include <hip/hip_fp16.h>
42 #endif
43 
44 #if defined(CL_SYCL_LANGUAGE_VERSION)
45 #include <CL/sycl.hpp> // for SYCL 1.2.1
46 #elif defined(SYCL_LANGUAGE_VERSION)
47 #include <sycl/sycl.hpp> // for SYCL 2020
48 #endif
49 
50 #if defined(__aarch64__) && !defined(__CUDACC__)
51 #include <arm_neon.h>
52 #endif
53 
54 namespace c10 {
55 
56 namespace detail {
57 
/*
 * Convert a 16-bit floating-point number in IEEE half-precision format, in bit
 * representation, to a 32-bit floating-point number in IEEE single-precision
 * format, in bit representation.
 *
 * @param h  IEEE binary16 bit pattern.
 * @return   IEEE binary32 bit pattern of the same value. Signed zeros,
 *           denormals (which are renormalized), infinities and NaNs (whose
 *           mantissa bits are carried over) are all handled.
 *
 * @note The implementation doesn't use any floating-point operations.
 */
inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) {
  /*
   * Extend the half-precision floating-point number to 32 bits and shift to the
   * upper part of the 32-bit word:
   *      +---+-----+------------+-------------------+
   *      | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
   *      +---+-----+------------+-------------------+
   * Bits  31  26-30    16-25            0-15
   *
   * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
   * - zero bits.
   */
  const uint32_t w = static_cast<uint32_t>(h) << 16;
  /*
   * Extract the sign of the input number into the high bit of the 32-bit word:
   *
   *      +---+----------------------------------+
   *      | S |0000000 00000000 00000000 00000000|
   *      +---+----------------------------------+
   * Bits  31                 0-30
   */
  const uint32_t sign = w & UINT32_C(0x80000000);
  /*
   * Extract mantissa and biased exponent of the input number into the bits 0-30
   * of the 32-bit word:
   *
   *      +---+-----+------------+-------------------+
   *      | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
   *      +---+-----+------------+-------------------+
   * Bits  31  26-30    16-25            0-15
   */
  const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
  /*
   * Renorm shift is the number of bits to shift mantissa left to make the
   * half-precision number normalized. If the initial number is normalized, some
   * of its high 6 bits (sign == 0 and 5-bit exponent) equals one. In this case
   * renorm_shift == 0. If the number is denormalized, renorm_shift > 0. Note
   * that if we shift denormalized nonsign by renorm_shift, the unit bit of
   * mantissa will shift into exponent, turning the biased exponent into 1, and
   * making mantissa normalized (i.e. without leading 1).
   *
   * Both __builtin_clz and _BitScanReverse are undefined for a zero argument,
   * and nonsign is zero for +0.0h and -0.0h. Bits 0-15 of nonsign are always
   * zero, so setting bit 0 never changes the leading-zero count of a non-zero
   * nonsign, while it turns the zero case into a well-defined count (31). The
   * value computed for a zero input is discarded by zero_mask below anyway.
   */
  const uint32_t clz_operand = nonsign | UINT32_C(1);
#ifdef _MSC_VER
  unsigned long nonsign_bsr = 0;
  _BitScanReverse(&nonsign_bsr, static_cast<unsigned long>(clz_operand));
  uint32_t renorm_shift = static_cast<uint32_t>(nonsign_bsr) ^ 31;
#else
  uint32_t renorm_shift = static_cast<uint32_t>(__builtin_clz(clz_operand));
#endif
  renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0;
  /*
   * Iff the half-precision number has a biased exponent of 31 (0x1F, all
   * ones), the addition overflows it into bit 31, and the subsequent
   * arithmetic shift turns the high 9 bits into 1. Thus inf_nan_mask ==
   * 0x7F800000 if the half-precision number was NaN or infinity, and
   * 0x00000000 otherwise.
   */
  const int32_t inf_nan_mask =
      (static_cast<int32_t>(nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000);
  /*
   * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31
   * into 1. Otherwise, bit 31 remains 0. The signed shift right by 31
   * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask ==
   * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h)
   * 0x00000000 otherwise
   */
  const int32_t zero_mask = static_cast<int32_t>(nonsign - 1) >> 31;
  /*
   * 1. Shift nonsign left by renorm_shift to normalize it (if the input
   * was denormal)
   * 2. Shift nonsign right by 3 so the exponent (5 bits originally)
   * becomes an 8-bit field and 10-bit mantissa shifts into the 10 high
   * bits of the 23-bit mantissa of IEEE single-precision number.
   * 3. Add 0x70 to the exponent (starting at bit 23) to compensate the
   * difference in exponent bias (0x7F for single-precision number less 0xF
   * for half-precision number).
   * 4. Subtract renorm_shift from the exponent (starting at bit 23) to
   * account for renormalization. As renorm_shift is less than 0x70, this
   * can be combined with step 3.
   * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the
   * input was NaN or infinity.
   * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent
   * into zero if the input was zero.
   * 7. Combine with the sign of the input number.
   */
  return sign |
      ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) |
        inf_nan_mask) &
       ~zero_mask);
}
153 
154 /*
155  * Convert a 16-bit floating-point number in IEEE half-precision format, in bit
156  * representation, to a 32-bit floating-point number in IEEE single-precision
157  * format.
158  *
159  * @note The implementation relies on IEEE-like (no assumption about rounding
160  * mode and no operations on denormals) floating-point operations and bitcasts
161  * between integer and floating-point variables.
162  */
fp16_ieee_to_fp32_value(uint16_t h)163 C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) {
164   /*
165    * Extend the half-precision floating-point number to 32 bits and shift to the
166    * upper part of the 32-bit word:
167    *      +---+-----+------------+-------------------+
168    *      | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
169    *      +---+-----+------------+-------------------+
170    * Bits  31  26-30    16-25            0-15
171    *
172    * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
173    * - zero bits.
174    */
175   const uint32_t w = (uint32_t)h << 16;
176   /*
177    * Extract the sign of the input number into the high bit of the 32-bit word:
178    *
179    *      +---+----------------------------------+
180    *      | S |0000000 00000000 00000000 00000000|
181    *      +---+----------------------------------+
182    * Bits  31                 0-31
183    */
184   const uint32_t sign = w & UINT32_C(0x80000000);
185   /*
186    * Extract mantissa and biased exponent of the input number into the high bits
187    * of the 32-bit word:
188    *
189    *      +-----+------------+---------------------+
190    *      |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000|
191    *      +-----+------------+---------------------+
192    * Bits  27-31    17-26            0-16
193    */
194   const uint32_t two_w = w + w;
195 
196   /*
197    * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become
198    * mantissa and exponent of a single-precision floating-point number:
199    *
200    *       S|Exponent |          Mantissa
201    *      +-+---+-----+------------+----------------+
202    *      |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000|
203    *      +-+---+-----+------------+----------------+
204    * Bits   | 23-31   |           0-22
205    *
206    * Next, there are some adjustments to the exponent:
207    * - The exponent needs to be corrected by the difference in exponent bias
208    * between single-precision and half-precision formats (0x7F - 0xF = 0x70)
209    * - Inf and NaN values in the inputs should become Inf and NaN values after
210    * conversion to the single-precision number. Therefore, if the biased
211    * exponent of the half-precision input was 0x1F (max possible value), the
212    * biased exponent of the single-precision output must be 0xFF (max possible
213    * value). We do this correction in two steps:
214    *   - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset
215    * below) rather than by 0x70 suggested by the difference in the exponent bias
216    * (see above).
217    *   - Then we multiply the single-precision result of exponent adjustment by
218    * 2**(-112) to reverse the effect of exponent adjustment by 0xE0 less the
219    * necessary exponent adjustment by 0x70 due to difference in exponent bias.
220    *     The floating-point multiplication hardware would ensure than Inf and
221    * NaN would retain their value on at least partially IEEE754-compliant
222    * implementations.
223    *
224    * Note that the above operations do not handle denormal inputs (where biased
225    * exponent == 0). However, they also do not operate on denormal inputs, and
226    * do not produce denormal results.
227    */
228   constexpr uint32_t exp_offset = UINT32_C(0xE0) << 23;
229   // const float exp_scale = 0x1.0p-112f;
230   constexpr uint32_t scale_bits = (uint32_t)15 << 23;
231   float exp_scale_val = 0;
232   std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val));
233   const float exp_scale = exp_scale_val;
234   const float normalized_value =
235       fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
236 
237   /*
238    * Convert denormalized half-precision inputs into single-precision results
239    * (always normalized). Zero inputs are also handled here.
240    *
241    * In a denormalized number the biased exponent is zero, and mantissa has
242    * on-zero bits. First, we shift mantissa into bits 0-9 of the 32-bit word.
243    *
244    *                  zeros           |  mantissa
245    *      +---------------------------+------------+
246    *      |0000 0000 0000 0000 0000 00|MM MMMM MMMM|
247    *      +---------------------------+------------+
248    * Bits             10-31                0-9
249    *
250    * Now, remember that denormalized half-precision numbers are represented as:
251    *    FP16 = mantissa * 2**(-24).
252    * The trick is to construct a normalized single-precision number with the
253    * same mantissa and thehalf-precision input and with an exponent which would
254    * scale the corresponding mantissa bits to 2**(-24). A normalized
255    * single-precision floating-point number is represented as: FP32 = (1 +
256    * mantissa * 2**(-23)) * 2**(exponent - 127) Therefore, when the biased
257    * exponent is 126, a unit change in the mantissa of the input denormalized
258    * half-precision number causes a change of the constructed single-precision
259    * number by 2**(-24), i.e. the same amount.
260    *
261    * The last step is to adjust the bias of the constructed single-precision
262    * number. When the input half-precision number is zero, the constructed
263    * single-precision number has the value of FP32 = 1 * 2**(126 - 127) =
264    * 2**(-1) = 0.5 Therefore, we need to subtract 0.5 from the constructed
265    * single-precision number to get the numerical equivalent of the input
266    * half-precision number.
267    */
268   constexpr uint32_t magic_mask = UINT32_C(126) << 23;
269   constexpr float magic_bias = 0.5f;
270   const float denormalized_value =
271       fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
272 
273   /*
274    * - Choose either results of conversion of input as a normalized number, or
275    * as a denormalized number, depending on the input exponent. The variable
276    * two_w contains input exponent in bits 27-31, therefore if its smaller than
277    * 2**27, the input is either a denormal number, or zero.
278    * - Combine the result of conversion of exponent and mantissa with the sign
279    * of the input number.
280    */
281   constexpr uint32_t denormalized_cutoff = UINT32_C(1) << 27;
282   const uint32_t result = sign |
283       (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value)
284                                    : fp32_to_bits(normalized_value));
285   return fp32_from_bits(result);
286 }
287 
288 /*
289  * Convert a 32-bit floating-point number in IEEE single-precision format to a
290  * 16-bit floating-point number in IEEE half-precision format, in bit
291  * representation.
292  *
293  * @note The implementation relies on IEEE-like (no assumption about rounding
294  * mode and no operations on denormals) floating-point operations and bitcasts
295  * between integer and floating-point variables.
296  */
fp16_ieee_from_fp32_value(float f)297 inline uint16_t fp16_ieee_from_fp32_value(float f) {
298   // const float scale_to_inf = 0x1.0p+112f;
299   // const float scale_to_zero = 0x1.0p-110f;
300   constexpr uint32_t scale_to_inf_bits = (uint32_t)239 << 23;
301   constexpr uint32_t scale_to_zero_bits = (uint32_t)17 << 23;
302   float scale_to_inf_val = 0, scale_to_zero_val = 0;
303   std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val));
304   std::memcpy(
305       &scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val));
306   const float scale_to_inf = scale_to_inf_val;
307   const float scale_to_zero = scale_to_zero_val;
308 
309 #if defined(_MSC_VER) && _MSC_VER == 1916
310   float base = ((signbit(f) != 0 ? -f : f) * scale_to_inf) * scale_to_zero;
311 #else
312   float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
313 #endif
314 
315   const uint32_t w = fp32_to_bits(f);
316   const uint32_t shl1_w = w + w;
317   const uint32_t sign = w & UINT32_C(0x80000000);
318   uint32_t bias = shl1_w & UINT32_C(0xFF000000);
319   if (bias < UINT32_C(0x71000000)) {
320     bias = UINT32_C(0x71000000);
321   }
322 
323   base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
324   const uint32_t bits = fp32_to_bits(base);
325   const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
326   const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
327   const uint32_t nonsign = exp_bits + mantissa_bits;
328   return static_cast<uint16_t>(
329       (sign >> 16) |
330       (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign));
331 }
332 
#if defined(__aarch64__) && !defined(__CUDACC__)
// Native binary16 helpers for AArch64 host code. float16_t is provided by
// <arm_neon.h>.

// Reinterpret a raw binary16 bit pattern as float16_t; the bits are unchanged.
inline float16_t fp16_from_bits(uint16_t h) {
  return c10::bit_cast<float16_t>(h);
}

// Reinterpret a float16_t as its raw binary16 bit pattern.
inline uint16_t fp16_to_bits(float16_t f) {
  return c10::bit_cast<uint16_t>(f);
}

// Widen a binary16 bit pattern to float through the native float16_t type.
// According to https://godbolt.org/z/frExdbsWG this compiles to a single
// `fcvt s0, h0`.
inline float native_fp16_to_fp32_value(uint16_t h) {
  const float16_t half_value = fp16_from_bits(h);
  return static_cast<float>(half_value);
}

// Narrow a float through the native float16_t type and return its bit pattern.
inline uint16_t native_fp16_from_fp32_value(float f) {
  const auto narrowed = static_cast<float16_t>(f);
  return fp16_to_bits(narrowed);
}
#endif
352 
353 } // namespace detail
354 
355 struct alignas(2) Half {
356   unsigned short x;
357 
358   struct from_bits_t {};
from_bitsHalf359   C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
360     return from_bits_t();
361   }
362 
363   // HIP wants __host__ __device__ tag, CUDA does not
364 #if defined(USE_ROCM)
365   C10_HOST_DEVICE Half() = default;
366 #else
367   Half() = default;
368 #endif
369 
HalfHalf370   constexpr C10_HOST_DEVICE Half(unsigned short bits, from_bits_t) : x(bits) {}
371 #if defined(__aarch64__) && !defined(__CUDACC__)
372   inline Half(float16_t value);
373   inline operator float16_t() const;
374 #else
375   inline C10_HOST_DEVICE Half(float value);
376   inline C10_HOST_DEVICE operator float() const;
377 #endif
378 
379 #if defined(__CUDACC__) || defined(__HIPCC__)
380   inline C10_HOST_DEVICE Half(const __half& value);
381   inline C10_HOST_DEVICE operator __half() const;
382 #endif
383 #ifdef SYCL_LANGUAGE_VERSION
384   inline C10_HOST_DEVICE Half(const sycl::half& value);
385   inline C10_HOST_DEVICE operator sycl::half() const;
386 #endif
387 };
388 
389 // TODO : move to complex.h
390 template <>
391 struct alignas(4) complex<Half> {
392   Half real_;
393   Half imag_;
394 
395   // Constructors
396   complex() = default;
397   // Half constructor is not constexpr so the following constructor can't
398   // be constexpr
399   C10_HOST_DEVICE explicit inline complex(const Half& real, const Half& imag)
400       : real_(real), imag_(imag) {}
401   C10_HOST_DEVICE inline complex(const c10::complex<float>& value)
402       : real_(value.real()), imag_(value.imag()) {}
403 
404   // Conversion operator
405   inline C10_HOST_DEVICE operator c10::complex<float>() const {
406     return {real_, imag_};
407   }
408 
409   constexpr C10_HOST_DEVICE Half real() const {
410     return real_;
411   }
412   constexpr C10_HOST_DEVICE Half imag() const {
413     return imag_;
414   }
415 
416   C10_HOST_DEVICE complex<Half>& operator+=(const complex<Half>& other) {
417     real_ = static_cast<float>(real_) + static_cast<float>(other.real_);
418     imag_ = static_cast<float>(imag_) + static_cast<float>(other.imag_);
419     return *this;
420   }
421 
422   C10_HOST_DEVICE complex<Half>& operator-=(const complex<Half>& other) {
423     real_ = static_cast<float>(real_) - static_cast<float>(other.real_);
424     imag_ = static_cast<float>(imag_) - static_cast<float>(other.imag_);
425     return *this;
426   }
427 
428   C10_HOST_DEVICE complex<Half>& operator*=(const complex<Half>& other) {
429     auto a = static_cast<float>(real_);
430     auto b = static_cast<float>(imag_);
431     auto c = static_cast<float>(other.real());
432     auto d = static_cast<float>(other.imag());
433     real_ = a * c - b * d;
434     imag_ = a * d + b * c;
435     return *this;
436   }
437 };
438 
// Some versions of MSVC fail to build the overflow checks below because of
// the following diagnostics:
// C4146: unary minus operator applied to unsigned type, result still unsigned
// C4804: unsafe use of type 'bool' in operation
// C4018: signed/unsigned mismatch
// They are disabled for this section and restored after it.
443 #ifdef _MSC_VER
444 #pragma warning(push)
445 #pragma warning(disable : 4146)
446 #pragma warning(disable : 4804)
447 #pragma warning(disable : 4018)
448 #endif
449 
// The overflow checks may involve float-to-int conversions, which can
// trigger clang's precision-loss warning (-Wimplicit-float-conversion,
// suppressed below). Re-enable the warning once the code is fixed.
// See T58053069.
453 C10_CLANG_DIAGNOSTIC_PUSH()
454 #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
455 C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
456 #endif
457 
// Overload for `bool` sources. A bool converts to any destination type, so
// the conversion never overflows. Without this overload, the
// pytorch_linux_trusty_py2_7_9_build job failed with
// `error: comparison of constant '255' with boolean expression is always false`
// on the `f > limit::max()` comparison of the generic checks below.
template <typename To, typename From>
std::enable_if_t<std::is_same_v<From, bool>, bool> overflows(
    From /*f*/,
    bool strict_unsigned [[maybe_unused]] = false) {
  constexpr bool kNeverOverflows = false;
  return kNeverOverflows;
}
468 
469 // skip isnan and isinf check for integral types
470 template <typename To, typename From>
471 std::enable_if_t<std::is_integral_v<From> && !std::is_same_v<From, bool>, bool>
472 overflows(From f, bool strict_unsigned = false) {
473   using limit = std::numeric_limits<typename scalar_value_type<To>::type>;
474   if constexpr (!limit::is_signed && std::numeric_limits<From>::is_signed) {
475     // allow for negative numbers to wrap using two's complement arithmetic.
476     // For example, with uint8, this allows for `a - b` to be treated as
477     // `a + 255 * b`.
478     if (!strict_unsigned) {
479       return greater_than_max<To>(f) ||
480           (c10::is_negative(f) &&
481            -static_cast<uint64_t>(f) > static_cast<uint64_t>(limit::max()));
482     }
483   }
484   return c10::less_than_lowest<To>(f) || greater_than_max<To>(f);
485 }
486 
487 template <typename To, typename From>
488 std::enable_if_t<std::is_floating_point_v<From>, bool> overflows(
489     From f,
490     bool strict_unsigned [[maybe_unused]] = false) {
491   using limit = std::numeric_limits<typename scalar_value_type<To>::type>;
492   if (limit::has_infinity && std::isinf(static_cast<double>(f))) {
493     return false;
494   }
495   if (!limit::has_quiet_NaN && (f != f)) {
496     return true;
497   }
498   return f < limit::lowest() || f > limit::max();
499 }
500 
501 C10_CLANG_DIAGNOSTIC_POP()
502 
503 #ifdef _MSC_VER
504 #pragma warning(pop)
505 #endif
506 
507 template <typename To, typename From>
508 std::enable_if_t<is_complex<From>::value, bool> overflows(
509     From f,
510     bool strict_unsigned = false) {
511   // casts from complex to real are considered to overflow if the
512   // imaginary component is non-zero
513   if (!is_complex<To>::value && f.imag() != 0) {
514     return true;
515   }
516   // Check for overflow componentwise
517   // (Technically, the imag overflow check is guaranteed to be false
518   // when !is_complex<To>, but any optimizer worth its salt will be
519   // able to figure it out.)
520   return overflows<
521              typename scalar_value_type<To>::type,
522              typename From::value_type>(f.real(), strict_unsigned) ||
523       overflows<
524              typename scalar_value_type<To>::type,
525              typename From::value_type>(f.imag(), strict_unsigned);
526 }
527 
528 C10_API inline std::ostream& operator<<(std::ostream& out, const Half& value) {
529   out << (float)value;
530   return out;
531 }
532 
533 } // namespace c10
534 
535 #include <c10/util/Half-inl.h> // IWYU pragma: keep
536