xref: /aosp_15_r20/external/pytorch/c10/util/BFloat16.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
// Defines the bfloat16 type (brain floating-point). This representation uses
// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa.
5 
6 #include <c10/macros/Macros.h>
7 #include <cmath>
8 #include <cstdint>
9 #include <cstring>
10 #include <iosfwd>
11 #include <ostream>
12 
13 #if defined(__CUDACC__) && !defined(USE_ROCM)
14 #include <cuda_bf16.h>
15 #endif
16 #if defined(__HIPCC__) && defined(USE_ROCM)
17 #include <hip/hip_bf16.h>
18 #endif
19 
20 #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
21 #if defined(CL_SYCL_LANGUAGE_VERSION)
22 #include <CL/sycl.hpp> // for SYCL 1.2.1
23 #else
24 #include <sycl/sycl.hpp> // for SYCL 2020
25 #endif
26 #include <ext/oneapi/bfloat16.hpp>
27 #endif
28 
29 namespace c10 {
30 
31 namespace detail {
// Expands 16 bfloat16 bits into the float that has those bits as its upper
// half and zeros as its lower half.
inline C10_HOST_DEVICE float f32_from_bits(uint16_t src) {
  // Place the 16 input bits in the high half of a 32-bit word.
  uint32_t widened = static_cast<uint32_t>(src) << 16;
  float value = 0;

#if defined(USE_ROCM)
  // memcpy would be the way to respect the strict aliasing rule, but it
  // fails in the HIP environment, so the word is read through a float
  // pointer instead.
  float* as_float = reinterpret_cast<float*>(&widened);
  value = *as_float;
#else
  std::memcpy(&value, &widened, sizeof(widened));
#endif

  return value;
}
50 
// Returns the upper 16 bits of a float's 32-bit representation. The lower
// 16 bits are simply dropped (truncation, no rounding).
inline C10_HOST_DEVICE uint16_t bits_from_f32(float src) {
  uint32_t raw = 0;

#if defined(USE_ROCM)
  // memcpy would be the way to respect the strict aliasing rule, but it
  // fails in the HIP environment, so the float is read through a uint32_t
  // pointer instead.
  uint32_t* as_word = reinterpret_cast<uint32_t*>(&src);
  raw = *as_word;
#else
  std::memcpy(&raw, &src, sizeof(raw));
#endif

  // Keep only the high half of the word.
  return static_cast<uint16_t>(raw >> 16);
}
65 
// Converts a float to 16 bfloat16 bits, rounding to nearest with ties going
// to the value whose lowest kept bit is zero.
//
// Any NaN input produces the fixed bit pattern 0x7FC0. For every other input
// the upper 16 bits of the float are kept and the lower 16 bits are rounded
// away.
inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) {
#if defined(USE_ROCM)
  if (src != src) {
#elif defined(_MSC_VER)
  if (isnan(src)) {
#else
  if (std::isnan(src)) {
#endif
    return UINT16_C(0x7FC0);
  }

  // Fetch the raw bits of the float. Reading the inactive member of a union
  // is undefined behaviour in C++, so the bits are copied out instead.
  uint32_t bits = 0;
#if defined(USE_ROCM)
  // memcpy would be the way to respect the strict aliasing rule, but it
  // fails in the HIP environment, so the float is read through a uint32_t
  // pointer instead.
  uint32_t* as_word = reinterpret_cast<uint32_t*>(&src);
  bits = *as_word;
#else
  std::memcpy(&bits, &src, sizeof(bits));
#endif

  // Adding 0x7FFF plus the lowest kept bit makes an exact tie (low half ==
  // 0x8000) carry into the kept half only when that half is currently odd.
  const uint32_t rounding_bias = ((bits >> 16) & 1) + UINT32_C(0x7FFF);
  return static_cast<uint16_t>((bits + rounding_bias) >> 16);
}
87 } // namespace detail
88 
// 16-bit storage type for bfloat16 values. The only data member is the raw
// bit pattern `x`; the struct is aligned to 2 bytes.
struct alignas(2) BFloat16 {
  // Raw 16-bit pattern of the value.
  uint16_t x;

  // The defaulted constructor needs the __host__ __device__ tag under HIP,
  // whereas CUDA does not want it.
#if defined(USE_ROCM)
  C10_HOST_DEVICE BFloat16() = default;
#else
  BFloat16() = default;
#endif

  // Tag type selecting the constructor that takes a raw bit pattern rather
  // than a numeric value.
  struct from_bits_t {};
  static constexpr C10_HOST_DEVICE from_bits_t from_bits() {
    return {};
  }

  // Stores `bits` verbatim as the bit pattern; no conversion is performed.
  constexpr C10_HOST_DEVICE BFloat16(unsigned short bits, from_bits_t)
      : x(bits) {}

  // Implicit conversions from and to float. Only declared here; the
  // definitions are presumably supplied by BFloat16-inl.h, which this header
  // includes at its end.
  /* implicit */ inline C10_HOST_DEVICE BFloat16(float value);
  inline C10_HOST_DEVICE operator float() const;

  // Interop with the CUDA bfloat16 type (CUDA compiler, non-ROCm builds).
#if defined(__CUDACC__) && !defined(USE_ROCM)
  inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value);
  explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const;
#endif
  // Interop with the HIP bfloat16 type (HIP compiler, ROCm builds).
#if defined(__HIPCC__) && defined(USE_ROCM)
  inline C10_HOST_DEVICE BFloat16(const __hip_bfloat16& value);
  explicit inline C10_HOST_DEVICE operator __hip_bfloat16() const;
#endif

  // Interop with the SYCL oneAPI bfloat16 type.
#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
  inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16& value);
  explicit inline C10_HOST_DEVICE operator sycl::ext::oneapi::bfloat16() const;
#endif
};
123 
124 C10_API inline std::ostream& operator<<(
125     std::ostream& out,
126     const BFloat16& value) {
127   out << (float)value;
128   return out;
129 }
130 
131 } // namespace c10
132 
133 #include <c10/util/BFloat16-inl.h> // IWYU pragma: keep
134