xref: /aosp_15_r20/external/pytorch/c10/util/BFloat16-inl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/macros/Macros.h>
4 #include <c10/util/bit_cast.h>
5 
6 #include <limits>
7 
8 C10_CLANG_DIAGNOSTIC_PUSH()
9 #if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
10 C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
11 #endif
12 
13 #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
14 #if defined(CL_SYCL_LANGUAGE_VERSION)
15 #include <CL/sycl.hpp> // for SYCL 1.2.1
16 #else
17 #include <sycl/sycl.hpp> // for SYCL 2020
18 #endif
19 #include <ext/oneapi/bfloat16.hpp>
20 #endif
21 
22 namespace c10 {
23 
24 /// Constructors
BFloat16(float value)25 inline C10_HOST_DEVICE BFloat16::BFloat16(float value)
26     :
27 #if defined(__CUDACC__) && !defined(USE_ROCM) && defined(__CUDA_ARCH__) && \
28     __CUDA_ARCH__ >= 800
29       x(__bfloat16_as_ushort(__float2bfloat16(value)))
30 #elif defined(__SYCL_DEVICE_ONLY__) && \
31     defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
32       x(c10::bit_cast<uint16_t>(sycl::ext::oneapi::bfloat16(value)))
33 #else
34       // RNE by default
35       x(detail::round_to_nearest_even(value))
36 #endif
37 {
38 }
39 
40 /// Implicit conversions
41 inline C10_HOST_DEVICE BFloat16::operator float() const {
42 #if defined(__CUDACC__) && !defined(USE_ROCM)
43   return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&x));
44 #elif defined(__SYCL_DEVICE_ONLY__) && \
45     defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
46   return float(*reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(&x));
47 #else
48   return detail::f32_from_bits(x);
49 #endif
50 }
51 
52 #if defined(__CUDACC__) && !defined(USE_ROCM)
BFloat16(const __nv_bfloat16 & value)53 inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value) {
54   x = *reinterpret_cast<const unsigned short*>(&value);
55 }
__nv_bfloat16()56 inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const {
57   return *reinterpret_cast<const __nv_bfloat16*>(&x);
58 }
59 #endif
60 #if defined(__HIPCC__) && defined(USE_ROCM)
61 // 6.2.0 introduced __hip_bfloat16_raw
62 #if defined(__BF16_HOST_DEVICE__)
BFloat16(const __hip_bfloat16 & value)63 inline C10_HOST_DEVICE BFloat16::BFloat16(const __hip_bfloat16& value) {
64   x = __hip_bfloat16_raw(value).x;
65 }
__hip_bfloat16()66 inline C10_HOST_DEVICE BFloat16::operator __hip_bfloat16() const {
67   return __hip_bfloat16(__hip_bfloat16_raw{x});
68 }
69 #else // !defined(__BF16_HOST_DEVICE__)
BFloat16(const __hip_bfloat16 & value)70 inline C10_HOST_DEVICE BFloat16::BFloat16(const __hip_bfloat16& value) {
71   x = value.data;
72 }
__hip_bfloat16()73 inline C10_HOST_DEVICE BFloat16::operator __hip_bfloat16() const {
74   return __hip_bfloat16{x};
75 }
76 #endif // !defined(__BF16_HOST_DEVICE__)
77 #endif // defined(__HIPCC__) && defined(USE_ROCM)
78 
79 #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
BFloat16(const sycl::ext::oneapi::bfloat16 & value)80 inline C10_HOST_DEVICE BFloat16::BFloat16(
81     const sycl::ext::oneapi::bfloat16& value) {
82   x = *reinterpret_cast<const unsigned short*>(&value);
83 }
bfloat16()84 inline C10_HOST_DEVICE BFloat16::operator sycl::ext::oneapi::bfloat16() const {
85   return *reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(&x);
86 }
87 #endif
88 
89 // CUDA intrinsics
90 
91 #if defined(__CUDACC__) || defined(__HIPCC__)
__ldg(const BFloat16 * ptr)92 inline C10_DEVICE BFloat16 __ldg(const BFloat16* ptr) {
93 #if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
94   return __ldg(reinterpret_cast<const __nv_bfloat16*>(ptr));
95 #else
96   return *ptr;
97 #endif
98 }
99 #endif
100 
101 /// Arithmetic
102 
103 inline C10_HOST_DEVICE BFloat16
104 operator+(const BFloat16& a, const BFloat16& b) {
105   return static_cast<float>(a) + static_cast<float>(b);
106 }
107 
108 inline C10_HOST_DEVICE BFloat16
109 operator-(const BFloat16& a, const BFloat16& b) {
110   return static_cast<float>(a) - static_cast<float>(b);
111 }
112 
113 inline C10_HOST_DEVICE BFloat16
114 operator*(const BFloat16& a, const BFloat16& b) {
115   return static_cast<float>(a) * static_cast<float>(b);
116 }
117 
118 inline C10_HOST_DEVICE BFloat16 operator/(const BFloat16& a, const BFloat16& b)
119     __ubsan_ignore_float_divide_by_zero__ {
120   return static_cast<float>(a) / static_cast<float>(b);
121 }
122 
123 inline C10_HOST_DEVICE BFloat16 operator-(const BFloat16& a) {
124   return -static_cast<float>(a);
125 }
126 
127 inline C10_HOST_DEVICE BFloat16& operator+=(BFloat16& a, const BFloat16& b) {
128   a = a + b;
129   return a;
130 }
131 
132 inline C10_HOST_DEVICE BFloat16& operator-=(BFloat16& a, const BFloat16& b) {
133   a = a - b;
134   return a;
135 }
136 
137 inline C10_HOST_DEVICE BFloat16& operator*=(BFloat16& a, const BFloat16& b) {
138   a = a * b;
139   return a;
140 }
141 
142 inline C10_HOST_DEVICE BFloat16& operator/=(BFloat16& a, const BFloat16& b) {
143   a = a / b;
144   return a;
145 }
146 
147 inline C10_HOST_DEVICE BFloat16& operator|(BFloat16& a, const BFloat16& b) {
148   a.x = a.x | b.x;
149   return a;
150 }
151 
152 inline C10_HOST_DEVICE BFloat16& operator^(BFloat16& a, const BFloat16& b) {
153   a.x = a.x ^ b.x;
154   return a;
155 }
156 
157 inline C10_HOST_DEVICE BFloat16& operator&(BFloat16& a, const BFloat16& b) {
158   a.x = a.x & b.x;
159   return a;
160 }
161 
162 /// Arithmetic with floats
163 
164 inline C10_HOST_DEVICE float operator+(BFloat16 a, float b) {
165   return static_cast<float>(a) + b;
166 }
167 inline C10_HOST_DEVICE float operator-(BFloat16 a, float b) {
168   return static_cast<float>(a) - b;
169 }
170 inline C10_HOST_DEVICE float operator*(BFloat16 a, float b) {
171   return static_cast<float>(a) * b;
172 }
173 inline C10_HOST_DEVICE float operator/(BFloat16 a, float b) {
174   return static_cast<float>(a) / b;
175 }
176 
177 inline C10_HOST_DEVICE float operator+(float a, BFloat16 b) {
178   return a + static_cast<float>(b);
179 }
180 inline C10_HOST_DEVICE float operator-(float a, BFloat16 b) {
181   return a - static_cast<float>(b);
182 }
183 inline C10_HOST_DEVICE float operator*(float a, BFloat16 b) {
184   return a * static_cast<float>(b);
185 }
186 inline C10_HOST_DEVICE float operator/(float a, BFloat16 b) {
187   return a / static_cast<float>(b);
188 }
189 
190 inline C10_HOST_DEVICE float& operator+=(float& a, const BFloat16& b) {
191   return a += static_cast<float>(b);
192 }
193 inline C10_HOST_DEVICE float& operator-=(float& a, const BFloat16& b) {
194   return a -= static_cast<float>(b);
195 }
196 inline C10_HOST_DEVICE float& operator*=(float& a, const BFloat16& b) {
197   return a *= static_cast<float>(b);
198 }
199 inline C10_HOST_DEVICE float& operator/=(float& a, const BFloat16& b) {
200   return a /= static_cast<float>(b);
201 }
202 
203 /// Arithmetic with doubles
204 
205 inline C10_HOST_DEVICE double operator+(BFloat16 a, double b) {
206   return static_cast<double>(a) + b;
207 }
208 inline C10_HOST_DEVICE double operator-(BFloat16 a, double b) {
209   return static_cast<double>(a) - b;
210 }
211 inline C10_HOST_DEVICE double operator*(BFloat16 a, double b) {
212   return static_cast<double>(a) * b;
213 }
214 inline C10_HOST_DEVICE double operator/(BFloat16 a, double b) {
215   return static_cast<double>(a) / b;
216 }
217 
218 inline C10_HOST_DEVICE double operator+(double a, BFloat16 b) {
219   return a + static_cast<double>(b);
220 }
221 inline C10_HOST_DEVICE double operator-(double a, BFloat16 b) {
222   return a - static_cast<double>(b);
223 }
224 inline C10_HOST_DEVICE double operator*(double a, BFloat16 b) {
225   return a * static_cast<double>(b);
226 }
227 inline C10_HOST_DEVICE double operator/(double a, BFloat16 b) {
228   return a / static_cast<double>(b);
229 }
230 
231 /// Arithmetic with ints
232 
233 inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int b) {
234   return a + static_cast<BFloat16>(b);
235 }
236 inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int b) {
237   return a - static_cast<BFloat16>(b);
238 }
239 inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int b) {
240   return a * static_cast<BFloat16>(b);
241 }
242 inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int b) {
243   return a / static_cast<BFloat16>(b);
244 }
245 
246 inline C10_HOST_DEVICE BFloat16 operator+(int a, BFloat16 b) {
247   return static_cast<BFloat16>(a) + b;
248 }
249 inline C10_HOST_DEVICE BFloat16 operator-(int a, BFloat16 b) {
250   return static_cast<BFloat16>(a) - b;
251 }
252 inline C10_HOST_DEVICE BFloat16 operator*(int a, BFloat16 b) {
253   return static_cast<BFloat16>(a) * b;
254 }
255 inline C10_HOST_DEVICE BFloat16 operator/(int a, BFloat16 b) {
256   return static_cast<BFloat16>(a) / b;
257 }
258 
/// Arithmetic with int64_t
260 
261 inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int64_t b) {
262   return a + static_cast<BFloat16>(b);
263 }
264 inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int64_t b) {
265   return a - static_cast<BFloat16>(b);
266 }
267 inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int64_t b) {
268   return a * static_cast<BFloat16>(b);
269 }
270 inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int64_t b) {
271   return a / static_cast<BFloat16>(b);
272 }
273 
274 inline C10_HOST_DEVICE BFloat16 operator+(int64_t a, BFloat16 b) {
275   return static_cast<BFloat16>(a) + b;
276 }
277 inline C10_HOST_DEVICE BFloat16 operator-(int64_t a, BFloat16 b) {
278   return static_cast<BFloat16>(a) - b;
279 }
280 inline C10_HOST_DEVICE BFloat16 operator*(int64_t a, BFloat16 b) {
281   return static_cast<BFloat16>(a) * b;
282 }
283 inline C10_HOST_DEVICE BFloat16 operator/(int64_t a, BFloat16 b) {
284   return static_cast<BFloat16>(a) / b;
285 }
286 
287 // Overloading < and > operators, because std::max and std::min use them.
288 
289 inline C10_HOST_DEVICE bool operator>(BFloat16& lhs, BFloat16& rhs) {
290   return float(lhs) > float(rhs);
291 }
292 
293 inline C10_HOST_DEVICE bool operator<(BFloat16& lhs, BFloat16& rhs) {
294   return float(lhs) < float(rhs);
295 }
296 
297 } // namespace c10
298 
299 namespace std {
300 
301 template <>
302 class numeric_limits<c10::BFloat16> {
303  public:
304   static constexpr bool is_signed = true;
305   static constexpr bool is_specialized = true;
306   static constexpr bool is_integer = false;
307   static constexpr bool is_exact = false;
308   static constexpr bool has_infinity = true;
309   static constexpr bool has_quiet_NaN = true;
310   static constexpr bool has_signaling_NaN = true;
311   static constexpr auto has_denorm = numeric_limits<float>::has_denorm;
312   static constexpr auto has_denorm_loss =
313       numeric_limits<float>::has_denorm_loss;
314   static constexpr auto round_style = numeric_limits<float>::round_style;
315   static constexpr bool is_iec559 = false;
316   static constexpr bool is_bounded = true;
317   static constexpr bool is_modulo = false;
318   static constexpr int digits = 8;
319   static constexpr int digits10 = 2;
320   static constexpr int max_digits10 = 4;
321   static constexpr int radix = 2;
322   static constexpr int min_exponent = -125;
323   static constexpr int min_exponent10 = -37;
324   static constexpr int max_exponent = 128;
325   static constexpr int max_exponent10 = 38;
326   static constexpr auto traps = numeric_limits<float>::traps;
327   static constexpr auto tinyness_before =
328       numeric_limits<float>::tinyness_before;
329 
min()330   static constexpr c10::BFloat16 min() {
331     return c10::BFloat16(0x0080, c10::BFloat16::from_bits());
332   }
lowest()333   static constexpr c10::BFloat16 lowest() {
334     return c10::BFloat16(0xFF7F, c10::BFloat16::from_bits());
335   }
max()336   static constexpr c10::BFloat16 max() {
337     return c10::BFloat16(0x7F7F, c10::BFloat16::from_bits());
338   }
epsilon()339   static constexpr c10::BFloat16 epsilon() {
340     return c10::BFloat16(0x3C00, c10::BFloat16::from_bits());
341   }
round_error()342   static constexpr c10::BFloat16 round_error() {
343     return c10::BFloat16(0x3F00, c10::BFloat16::from_bits());
344   }
infinity()345   static constexpr c10::BFloat16 infinity() {
346     return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
347   }
quiet_NaN()348   static constexpr c10::BFloat16 quiet_NaN() {
349     return c10::BFloat16(0x7FC0, c10::BFloat16::from_bits());
350   }
signaling_NaN()351   static constexpr c10::BFloat16 signaling_NaN() {
352     return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
353   }
denorm_min()354   static constexpr c10::BFloat16 denorm_min() {
355     return c10::BFloat16(0x0001, c10::BFloat16::from_bits());
356   }
357 };
358 
359 } // namespace std
360 
361 C10_CLANG_DIAGNOSTIC_POP()
362