xref: /aosp_15_r20/external/pytorch/c10/util/Float8_fnuz_cvt.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/util/floating_point_utils.h>
4 
5 #include <cstdint>
6 
7 #if defined(SYCL_LANGUAGE_VERSION)
8 #include <sycl/sycl.hpp>
9 #endif
10 
11 namespace c10::detail {
12 
13 /*
14  * Convert a 8-bit floating-point number in either f8 E4M3FNUZ or bf8 E5M2FNUZ
15  * format, in bit representation, to a 32-bit floating-point number.
16  */
17 template <uint32_t we, uint32_t wm>
fp8_fnuz_to_fp32_value(uint8_t x)18 inline C10_HOST_DEVICE float fp8_fnuz_to_fp32_value(uint8_t x) {
19   static_assert((we == 4 && wm == 3) || (we == 5 && wm == 2));
20   constexpr uint32_t weo = 8;
21   constexpr uint32_t wmo = 23;
22 
23   if (x == 0) {
24     return 0;
25   }
26 
27   if (x == 0x80) {
28     constexpr uint32_t ifNaN = 0x7F800001;
29     return fp32_from_bits(ifNaN);
30   }
31 
32   uint32_t mantissa = x & ((1 << wm) - 1);
33   uint32_t exponent = (x & 0x7F) >> wm;
34 
35   // subnormal input
36   if (exponent == 0) {
37     // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
38 #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
39     uint32_t renorm_shift = __clz(mantissa);
40 #elif defined(__SYCL_DEVICE_ONLY__)
41     uint32_t renorm_shift = sycl::clz(mantissa);
42 #elif defined(_MSC_VER)
43     unsigned long nonsign_bsr;
44     _BitScanReverse(&nonsign_bsr, (unsigned long)mantissa);
45     uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31;
46 #else
47     uint32_t renorm_shift = __builtin_clz(mantissa);
48 #endif
49     uint32_t sh = 1 + renorm_shift - (32 - wm);
50     mantissa <<= sh;
51     exponent += 1 - sh;
52     mantissa &= ((1 << wm) - 1);
53   }
54 
55   const uint32_t exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1));
56   exponent += exp_low_cutoff - 1;
57   mantissa <<= wmo - wm;
58 
59   uint32_t sign = x >> 7;
60   uint32_t retval = (sign << 31) | (exponent << 23) | mantissa;
61   return fp32_from_bits(retval);
62 }
63 
64 } // namespace c10::detail
65