1 #pragma once
2
3 #include <c10/macros/Macros.h>
4 #include <c10/util/bit_cast.h>
5
6 #include <limits>
7
8 C10_CLANG_DIAGNOSTIC_PUSH()
9 #if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
10 C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
11 #endif
12
13 #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
14 #if defined(CL_SYCL_LANGUAGE_VERSION)
15 #include <CL/sycl.hpp> // for SYCL 1.2.1
16 #else
17 #include <sycl/sycl.hpp> // for SYCL 2020
18 #endif
19 #include <ext/oneapi/bfloat16.hpp>
20 #endif
21
22 namespace c10 {
23
24 /// Constructors
BFloat16(float value)25 inline C10_HOST_DEVICE BFloat16::BFloat16(float value)
26 :
27 #if defined(__CUDACC__) && !defined(USE_ROCM) && defined(__CUDA_ARCH__) && \
28 __CUDA_ARCH__ >= 800
29 x(__bfloat16_as_ushort(__float2bfloat16(value)))
30 #elif defined(__SYCL_DEVICE_ONLY__) && \
31 defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
32 x(c10::bit_cast<uint16_t>(sycl::ext::oneapi::bfloat16(value)))
33 #else
34 // RNE by default
35 x(detail::round_to_nearest_even(value))
36 #endif
37 {
38 }
39
40 /// Implicit conversions
41 inline C10_HOST_DEVICE BFloat16::operator float() const {
42 #if defined(__CUDACC__) && !defined(USE_ROCM)
43 return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&x));
44 #elif defined(__SYCL_DEVICE_ONLY__) && \
45 defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
46 return float(*reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(&x));
47 #else
48 return detail::f32_from_bits(x);
49 #endif
50 }
51
52 #if defined(__CUDACC__) && !defined(USE_ROCM)
BFloat16(const __nv_bfloat16 & value)53 inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value) {
54 x = *reinterpret_cast<const unsigned short*>(&value);
55 }
__nv_bfloat16()56 inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const {
57 return *reinterpret_cast<const __nv_bfloat16*>(&x);
58 }
59 #endif
60 #if defined(__HIPCC__) && defined(USE_ROCM)
61 // 6.2.0 introduced __hip_bfloat16_raw
62 #if defined(__BF16_HOST_DEVICE__)
BFloat16(const __hip_bfloat16 & value)63 inline C10_HOST_DEVICE BFloat16::BFloat16(const __hip_bfloat16& value) {
64 x = __hip_bfloat16_raw(value).x;
65 }
__hip_bfloat16()66 inline C10_HOST_DEVICE BFloat16::operator __hip_bfloat16() const {
67 return __hip_bfloat16(__hip_bfloat16_raw{x});
68 }
69 #else // !defined(__BF16_HOST_DEVICE__)
BFloat16(const __hip_bfloat16 & value)70 inline C10_HOST_DEVICE BFloat16::BFloat16(const __hip_bfloat16& value) {
71 x = value.data;
72 }
__hip_bfloat16()73 inline C10_HOST_DEVICE BFloat16::operator __hip_bfloat16() const {
74 return __hip_bfloat16{x};
75 }
76 #endif // !defined(__BF16_HOST_DEVICE__)
77 #endif // defined(__HIPCC__) && defined(USE_ROCM)
78
79 #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
BFloat16(const sycl::ext::oneapi::bfloat16 & value)80 inline C10_HOST_DEVICE BFloat16::BFloat16(
81 const sycl::ext::oneapi::bfloat16& value) {
82 x = *reinterpret_cast<const unsigned short*>(&value);
83 }
bfloat16()84 inline C10_HOST_DEVICE BFloat16::operator sycl::ext::oneapi::bfloat16() const {
85 return *reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(&x);
86 }
87 #endif
88
89 // CUDA intrinsics
90
91 #if defined(__CUDACC__) || defined(__HIPCC__)
__ldg(const BFloat16 * ptr)92 inline C10_DEVICE BFloat16 __ldg(const BFloat16* ptr) {
93 #if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
94 return __ldg(reinterpret_cast<const __nv_bfloat16*>(ptr));
95 #else
96 return *ptr;
97 #endif
98 }
99 #endif
100
101 /// Arithmetic
102
103 inline C10_HOST_DEVICE BFloat16
104 operator+(const BFloat16& a, const BFloat16& b) {
105 return static_cast<float>(a) + static_cast<float>(b);
106 }
107
108 inline C10_HOST_DEVICE BFloat16
109 operator-(const BFloat16& a, const BFloat16& b) {
110 return static_cast<float>(a) - static_cast<float>(b);
111 }
112
113 inline C10_HOST_DEVICE BFloat16
114 operator*(const BFloat16& a, const BFloat16& b) {
115 return static_cast<float>(a) * static_cast<float>(b);
116 }
117
118 inline C10_HOST_DEVICE BFloat16 operator/(const BFloat16& a, const BFloat16& b)
119 __ubsan_ignore_float_divide_by_zero__ {
120 return static_cast<float>(a) / static_cast<float>(b);
121 }
122
123 inline C10_HOST_DEVICE BFloat16 operator-(const BFloat16& a) {
124 return -static_cast<float>(a);
125 }
126
127 inline C10_HOST_DEVICE BFloat16& operator+=(BFloat16& a, const BFloat16& b) {
128 a = a + b;
129 return a;
130 }
131
132 inline C10_HOST_DEVICE BFloat16& operator-=(BFloat16& a, const BFloat16& b) {
133 a = a - b;
134 return a;
135 }
136
137 inline C10_HOST_DEVICE BFloat16& operator*=(BFloat16& a, const BFloat16& b) {
138 a = a * b;
139 return a;
140 }
141
142 inline C10_HOST_DEVICE BFloat16& operator/=(BFloat16& a, const BFloat16& b) {
143 a = a / b;
144 return a;
145 }
146
147 inline C10_HOST_DEVICE BFloat16& operator|(BFloat16& a, const BFloat16& b) {
148 a.x = a.x | b.x;
149 return a;
150 }
151
152 inline C10_HOST_DEVICE BFloat16& operator^(BFloat16& a, const BFloat16& b) {
153 a.x = a.x ^ b.x;
154 return a;
155 }
156
157 inline C10_HOST_DEVICE BFloat16& operator&(BFloat16& a, const BFloat16& b) {
158 a.x = a.x & b.x;
159 return a;
160 }
161
162 /// Arithmetic with floats
163
164 inline C10_HOST_DEVICE float operator+(BFloat16 a, float b) {
165 return static_cast<float>(a) + b;
166 }
167 inline C10_HOST_DEVICE float operator-(BFloat16 a, float b) {
168 return static_cast<float>(a) - b;
169 }
170 inline C10_HOST_DEVICE float operator*(BFloat16 a, float b) {
171 return static_cast<float>(a) * b;
172 }
173 inline C10_HOST_DEVICE float operator/(BFloat16 a, float b) {
174 return static_cast<float>(a) / b;
175 }
176
177 inline C10_HOST_DEVICE float operator+(float a, BFloat16 b) {
178 return a + static_cast<float>(b);
179 }
180 inline C10_HOST_DEVICE float operator-(float a, BFloat16 b) {
181 return a - static_cast<float>(b);
182 }
183 inline C10_HOST_DEVICE float operator*(float a, BFloat16 b) {
184 return a * static_cast<float>(b);
185 }
186 inline C10_HOST_DEVICE float operator/(float a, BFloat16 b) {
187 return a / static_cast<float>(b);
188 }
189
190 inline C10_HOST_DEVICE float& operator+=(float& a, const BFloat16& b) {
191 return a += static_cast<float>(b);
192 }
193 inline C10_HOST_DEVICE float& operator-=(float& a, const BFloat16& b) {
194 return a -= static_cast<float>(b);
195 }
196 inline C10_HOST_DEVICE float& operator*=(float& a, const BFloat16& b) {
197 return a *= static_cast<float>(b);
198 }
199 inline C10_HOST_DEVICE float& operator/=(float& a, const BFloat16& b) {
200 return a /= static_cast<float>(b);
201 }
202
203 /// Arithmetic with doubles
204
205 inline C10_HOST_DEVICE double operator+(BFloat16 a, double b) {
206 return static_cast<double>(a) + b;
207 }
208 inline C10_HOST_DEVICE double operator-(BFloat16 a, double b) {
209 return static_cast<double>(a) - b;
210 }
211 inline C10_HOST_DEVICE double operator*(BFloat16 a, double b) {
212 return static_cast<double>(a) * b;
213 }
214 inline C10_HOST_DEVICE double operator/(BFloat16 a, double b) {
215 return static_cast<double>(a) / b;
216 }
217
218 inline C10_HOST_DEVICE double operator+(double a, BFloat16 b) {
219 return a + static_cast<double>(b);
220 }
221 inline C10_HOST_DEVICE double operator-(double a, BFloat16 b) {
222 return a - static_cast<double>(b);
223 }
224 inline C10_HOST_DEVICE double operator*(double a, BFloat16 b) {
225 return a * static_cast<double>(b);
226 }
227 inline C10_HOST_DEVICE double operator/(double a, BFloat16 b) {
228 return a / static_cast<double>(b);
229 }
230
231 /// Arithmetic with ints
232
233 inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int b) {
234 return a + static_cast<BFloat16>(b);
235 }
236 inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int b) {
237 return a - static_cast<BFloat16>(b);
238 }
239 inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int b) {
240 return a * static_cast<BFloat16>(b);
241 }
242 inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int b) {
243 return a / static_cast<BFloat16>(b);
244 }
245
246 inline C10_HOST_DEVICE BFloat16 operator+(int a, BFloat16 b) {
247 return static_cast<BFloat16>(a) + b;
248 }
249 inline C10_HOST_DEVICE BFloat16 operator-(int a, BFloat16 b) {
250 return static_cast<BFloat16>(a) - b;
251 }
252 inline C10_HOST_DEVICE BFloat16 operator*(int a, BFloat16 b) {
253 return static_cast<BFloat16>(a) * b;
254 }
255 inline C10_HOST_DEVICE BFloat16 operator/(int a, BFloat16 b) {
256 return static_cast<BFloat16>(a) / b;
257 }
258
259 //// Arithmetic with int64_t
260
261 inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int64_t b) {
262 return a + static_cast<BFloat16>(b);
263 }
264 inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int64_t b) {
265 return a - static_cast<BFloat16>(b);
266 }
267 inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int64_t b) {
268 return a * static_cast<BFloat16>(b);
269 }
270 inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int64_t b) {
271 return a / static_cast<BFloat16>(b);
272 }
273
274 inline C10_HOST_DEVICE BFloat16 operator+(int64_t a, BFloat16 b) {
275 return static_cast<BFloat16>(a) + b;
276 }
277 inline C10_HOST_DEVICE BFloat16 operator-(int64_t a, BFloat16 b) {
278 return static_cast<BFloat16>(a) - b;
279 }
280 inline C10_HOST_DEVICE BFloat16 operator*(int64_t a, BFloat16 b) {
281 return static_cast<BFloat16>(a) * b;
282 }
283 inline C10_HOST_DEVICE BFloat16 operator/(int64_t a, BFloat16 b) {
284 return static_cast<BFloat16>(a) / b;
285 }
286
287 // Overloading < and > operators, because std::max and std::min use them.
288
289 inline C10_HOST_DEVICE bool operator>(BFloat16& lhs, BFloat16& rhs) {
290 return float(lhs) > float(rhs);
291 }
292
293 inline C10_HOST_DEVICE bool operator<(BFloat16& lhs, BFloat16& rhs) {
294 return float(lhs) < float(rhs);
295 }
296
297 } // namespace c10
298
299 namespace std {
300
301 template <>
302 class numeric_limits<c10::BFloat16> {
303 public:
304 static constexpr bool is_signed = true;
305 static constexpr bool is_specialized = true;
306 static constexpr bool is_integer = false;
307 static constexpr bool is_exact = false;
308 static constexpr bool has_infinity = true;
309 static constexpr bool has_quiet_NaN = true;
310 static constexpr bool has_signaling_NaN = true;
311 static constexpr auto has_denorm = numeric_limits<float>::has_denorm;
312 static constexpr auto has_denorm_loss =
313 numeric_limits<float>::has_denorm_loss;
314 static constexpr auto round_style = numeric_limits<float>::round_style;
315 static constexpr bool is_iec559 = false;
316 static constexpr bool is_bounded = true;
317 static constexpr bool is_modulo = false;
318 static constexpr int digits = 8;
319 static constexpr int digits10 = 2;
320 static constexpr int max_digits10 = 4;
321 static constexpr int radix = 2;
322 static constexpr int min_exponent = -125;
323 static constexpr int min_exponent10 = -37;
324 static constexpr int max_exponent = 128;
325 static constexpr int max_exponent10 = 38;
326 static constexpr auto traps = numeric_limits<float>::traps;
327 static constexpr auto tinyness_before =
328 numeric_limits<float>::tinyness_before;
329
min()330 static constexpr c10::BFloat16 min() {
331 return c10::BFloat16(0x0080, c10::BFloat16::from_bits());
332 }
lowest()333 static constexpr c10::BFloat16 lowest() {
334 return c10::BFloat16(0xFF7F, c10::BFloat16::from_bits());
335 }
max()336 static constexpr c10::BFloat16 max() {
337 return c10::BFloat16(0x7F7F, c10::BFloat16::from_bits());
338 }
epsilon()339 static constexpr c10::BFloat16 epsilon() {
340 return c10::BFloat16(0x3C00, c10::BFloat16::from_bits());
341 }
round_error()342 static constexpr c10::BFloat16 round_error() {
343 return c10::BFloat16(0x3F00, c10::BFloat16::from_bits());
344 }
infinity()345 static constexpr c10::BFloat16 infinity() {
346 return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
347 }
quiet_NaN()348 static constexpr c10::BFloat16 quiet_NaN() {
349 return c10::BFloat16(0x7FC0, c10::BFloat16::from_bits());
350 }
signaling_NaN()351 static constexpr c10::BFloat16 signaling_NaN() {
352 return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
353 }
denorm_min()354 static constexpr c10::BFloat16 denorm_min() {
355 return c10::BFloat16(0x0001, c10::BFloat16::from_bits());
356 }
357 };
358
359 } // namespace std
360
361 C10_CLANG_DIAGNOSTIC_POP()
362