1 #pragma once 2 3 // Defines the bloat16 type (brain floating-point). This representation uses 4 // 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa. 5 6 #include <c10/macros/Macros.h> 7 #include <cmath> 8 #include <cstdint> 9 #include <cstring> 10 #include <iosfwd> 11 #include <ostream> 12 13 #if defined(__CUDACC__) && !defined(USE_ROCM) 14 #include <cuda_bf16.h> 15 #endif 16 #if defined(__HIPCC__) && defined(USE_ROCM) 17 #include <hip/hip_bf16.h> 18 #endif 19 20 #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) 21 #if defined(CL_SYCL_LANGUAGE_VERSION) 22 #include <CL/sycl.hpp> // for SYCL 1.2.1 23 #else 24 #include <sycl/sycl.hpp> // for SYCL 2020 25 #endif 26 #include <ext/oneapi/bfloat16.hpp> 27 #endif 28 29 namespace c10 { 30 31 namespace detail { f32_from_bits(uint16_t src)32inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) { 33 float res = 0; 34 uint32_t tmp = src; 35 tmp <<= 16; 36 37 #if defined(USE_ROCM) 38 float* tempRes; 39 40 // We should be using memcpy in order to respect the strict aliasing rule 41 // but it fails in the HIP environment. 42 tempRes = reinterpret_cast<float*>(&tmp); 43 res = *tempRes; 44 #else 45 std::memcpy(&res, &tmp, sizeof(tmp)); 46 #endif 47 48 return res; 49 } 50 bits_from_f32(float src)51inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) { 52 uint32_t res = 0; 53 54 #if defined(USE_ROCM) 55 // We should be using memcpy in order to respect the strict aliasing rule 56 // but it fails in the HIP environment. 57 uint32_t* tempRes = reinterpret_cast<uint32_t*>(&src); 58 res = *tempRes; 59 #else 60 std::memcpy(&res, &src, sizeof(res)); 61 #endif 62 63 return res >> 16; 64 } 65 round_to_nearest_even(float src)66inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) { 67 #if defined(USE_ROCM) 68 if (src != src) { 69 #elif defined(_MSC_VER) 70 if (isnan(src)) { 71 #else 72 if (std::isnan(src)) { 73 #endif 74 return UINT16_C(0x7FC0); 75 } else { 76 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) 77 union { 78 uint32_t U32; // NOLINT(facebook-hte-BadMemberName) 79 float F32; // NOLINT(facebook-hte-BadMemberName) 80 }; 81 82 F32 = src; 83 uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); 84 return static_cast<uint16_t>((U32 + rounding_bias) >> 16); 85 } 86 } 87 } // namespace detail 88 89 struct alignas(2) BFloat16 { 90 uint16_t x; 91 92 // HIP wants __host__ __device__ tag, CUDA does not 93 #if defined(USE_ROCM) 94 C10_HOST_DEVICE BFloat16() = default; 95 #else 96 BFloat16() = default; 97 #endif 98 99 struct from_bits_t {}; 100 static constexpr C10_HOST_DEVICE from_bits_t from_bits() { 101 return from_bits_t(); 102 } 103 104 constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t) 105 : x(bits) {} 106 /* implicit */ inline C10_HOST_DEVICE BFloat16(float value); 107 inline C10_HOST_DEVICE operator float() const; 108 109 #if defined(__CUDACC__) && !defined(USE_ROCM) 110 inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value); 111 explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const; 112 #endif 113 #if defined(__HIPCC__) && defined(USE_ROCM) 114 inline C10_HOST_DEVICE BFloat16(const __hip_bfloat16& value); 115 explicit inline C10_HOST_DEVICE operator __hip_bfloat16() const; 116 #endif 117 118 #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) 119 inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16& value); 120 explicit inline C10_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() const; 121 #endif 122 }; 123 124 C10_API inline std::ostream& operator<<( 125 std::ostream& out, 126 const BFloat16& value) { 127 out << (float)value; 128 return out; 129 } 130 131 } // namespace c10 132 133 #include <c10/util/BFloat16-inl.h> // IWYU pragma: keep 134