1 #pragma once 2 3 #include <c10/util/floating_point_utils.h> 4 5 #include <cstdint> 6 7 #if defined(SYCL_LANGUAGE_VERSION) 8 #include <sycl/sycl.hpp> 9 #endif 10 11 namespace c10::detail { 12 13 /* 14 * Convert a 8-bit floating-point number in either f8 E4M3FNUZ or bf8 E5M2FNUZ 15 * format, in bit representation, to a 32-bit floating-point number. 16 */ 17 template <uint32_t we, uint32_t wm> fp8_fnuz_to_fp32_value(uint8_t x)18inline C10_HOST_DEVICE float fp8_fnuz_to_fp32_value(uint8_t x) { 19 static_assert((we == 4 && wm == 3) || (we == 5 && wm == 2)); 20 constexpr uint32_t weo = 8; 21 constexpr uint32_t wmo = 23; 22 23 if (x == 0) { 24 return 0; 25 } 26 27 if (x == 0x80) { 28 constexpr uint32_t ifNaN = 0x7F800001; 29 return fp32_from_bits(ifNaN); 30 } 31 32 uint32_t mantissa = x & ((1 << wm) - 1); 33 uint32_t exponent = (x & 0x7F) >> wm; 34 35 // subnormal input 36 if (exponent == 0) { 37 // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above 38 #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) 39 uint32_t renorm_shift = __clz(mantissa); 40 #elif defined(__SYCL_DEVICE_ONLY__) 41 uint32_t renorm_shift = sycl::clz(mantissa); 42 #elif defined(_MSC_VER) 43 unsigned long nonsign_bsr; 44 _BitScanReverse(&nonsign_bsr, (unsigned long)mantissa); 45 uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31; 46 #else 47 uint32_t renorm_shift = __builtin_clz(mantissa); 48 #endif 49 uint32_t sh = 1 + renorm_shift - (32 - wm); 50 mantissa <<= sh; 51 exponent += 1 - sh; 52 mantissa &= ((1 << wm) - 1); 53 } 54 55 const uint32_t exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)); 56 exponent += exp_low_cutoff - 1; 57 mantissa <<= wmo - wm; 58 59 uint32_t sign = x >> 7; 60 uint32_t retval = (sign << 31) | (exponent << 23) | mantissa; 61 return fp32_from_bits(retval); 62 } 63 64 } // namespace c10::detail 65