1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker #include <c10/macros/Macros.h>
4*da0073e9SAndroid Build Coastguard Worker #include <c10/util/bit_cast.h>
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Worker #include <limits>
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Worker C10_CLANG_DIAGNOSTIC_PUSH()
9*da0073e9SAndroid Build Coastguard Worker #if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
10*da0073e9SAndroid Build Coastguard Worker C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
11*da0073e9SAndroid Build Coastguard Worker #endif
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
14*da0073e9SAndroid Build Coastguard Worker #if defined(CL_SYCL_LANGUAGE_VERSION)
15*da0073e9SAndroid Build Coastguard Worker #include <CL/sycl.hpp> // for SYCL 1.2.1
16*da0073e9SAndroid Build Coastguard Worker #else
17*da0073e9SAndroid Build Coastguard Worker #include <sycl/sycl.hpp> // for SYCL 2020
18*da0073e9SAndroid Build Coastguard Worker #endif
19*da0073e9SAndroid Build Coastguard Worker #include <ext/oneapi/bfloat16.hpp>
20*da0073e9SAndroid Build Coastguard Worker #endif
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Worker namespace c10 {
23*da0073e9SAndroid Build Coastguard Worker
24*da0073e9SAndroid Build Coastguard Worker /// Constructors
BFloat16(float value)25*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16::BFloat16(float value)
26*da0073e9SAndroid Build Coastguard Worker :
27*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) && !defined(USE_ROCM) && defined(__CUDA_ARCH__) && \
28*da0073e9SAndroid Build Coastguard Worker __CUDA_ARCH__ >= 800
29*da0073e9SAndroid Build Coastguard Worker x(__bfloat16_as_ushort(__float2bfloat16(value)))
30*da0073e9SAndroid Build Coastguard Worker #elif defined(__SYCL_DEVICE_ONLY__) && \
31*da0073e9SAndroid Build Coastguard Worker defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
32*da0073e9SAndroid Build Coastguard Worker x(c10::bit_cast<uint16_t>(sycl::ext::oneapi::bfloat16(value)))
33*da0073e9SAndroid Build Coastguard Worker #else
34*da0073e9SAndroid Build Coastguard Worker // RNE by default
35*da0073e9SAndroid Build Coastguard Worker x(detail::round_to_nearest_even(value))
36*da0073e9SAndroid Build Coastguard Worker #endif
37*da0073e9SAndroid Build Coastguard Worker {
38*da0073e9SAndroid Build Coastguard Worker }
39*da0073e9SAndroid Build Coastguard Worker
40*da0073e9SAndroid Build Coastguard Worker /// Implicit conversions
41*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16::operator float() const {
42*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) && !defined(USE_ROCM)
43*da0073e9SAndroid Build Coastguard Worker return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&x));
44*da0073e9SAndroid Build Coastguard Worker #elif defined(__SYCL_DEVICE_ONLY__) && \
45*da0073e9SAndroid Build Coastguard Worker defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
46*da0073e9SAndroid Build Coastguard Worker return float(*reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(&x));
47*da0073e9SAndroid Build Coastguard Worker #else
48*da0073e9SAndroid Build Coastguard Worker return detail::f32_from_bits(x);
49*da0073e9SAndroid Build Coastguard Worker #endif
50*da0073e9SAndroid Build Coastguard Worker }
51*da0073e9SAndroid Build Coastguard Worker
#if defined(__CUDACC__) && !defined(USE_ROCM)
// Interop with CUDA's __nv_bfloat16. Both directions copy the 16 raw bits
// unchanged; no rounding or value conversion takes place.
inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value)
    : x(*reinterpret_cast<const unsigned short*>(&value)) {}

inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const {
  const __nv_bfloat16* const as_native =
      reinterpret_cast<const __nv_bfloat16*>(&x);
  return *as_native;
}
#endif
#if defined(__HIPCC__) && defined(USE_ROCM)
// Interop with HIP's __hip_bfloat16; both directions copy the raw bits.
// Version 6.2.0 introduced __hip_bfloat16_raw, and its availability is
// detected here through __BF16_HOST_DEVICE__. Older headers expose the bits
// through the `data` member instead.
#if defined(__BF16_HOST_DEVICE__)
inline C10_HOST_DEVICE BFloat16::BFloat16(const __hip_bfloat16& value)
    : x(__hip_bfloat16_raw(value).x) {}

inline C10_HOST_DEVICE BFloat16::operator __hip_bfloat16() const {
  return __hip_bfloat16(__hip_bfloat16_raw{x});
}
#else // !defined(__BF16_HOST_DEVICE__)
inline C10_HOST_DEVICE BFloat16::BFloat16(const __hip_bfloat16& value)
    : x(value.data) {}

inline C10_HOST_DEVICE BFloat16::operator __hip_bfloat16() const {
  return __hip_bfloat16{x};
}
#endif // !defined(__BF16_HOST_DEVICE__)
#endif // defined(__HIPCC__) && defined(USE_ROCM)
78*da0073e9SAndroid Build Coastguard Worker
#if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
// Interop with SYCL's oneAPI bfloat16. Both directions reinterpret the 16 raw
// bits, so the value is carried over unchanged.
inline C10_HOST_DEVICE BFloat16::BFloat16(
    const sycl::ext::oneapi::bfloat16& value)
    : x(*reinterpret_cast<const unsigned short*>(&value)) {}

inline C10_HOST_DEVICE BFloat16::operator sycl::ext::oneapi::bfloat16() const {
  const sycl::ext::oneapi::bfloat16* const as_sycl =
      reinterpret_cast<const sycl::ext::oneapi::bfloat16*>(&x);
  return *as_sycl;
}
#endif
88*da0073e9SAndroid Build Coastguard Worker
// CUDA intrinsics

#if defined(__CUDACC__) || defined(__HIPCC__)
// __ldg overload for BFloat16 pointers. CUDA device code for sm_80 and newer
// forwards to the __nv_bfloat16 overload of __ldg. Every other configuration
// (ROCm, host pass, older architectures) performs a plain dereference.
inline C10_DEVICE BFloat16 __ldg(const BFloat16* ptr) {
#if defined(USE_ROCM) || !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
  return *ptr;
#else
  return __ldg(reinterpret_cast<const __nv_bfloat16*>(ptr));
#endif
}
#endif
100*da0073e9SAndroid Build Coastguard Worker
101*da0073e9SAndroid Build Coastguard Worker /// Arithmetic
102*da0073e9SAndroid Build Coastguard Worker
103*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16
104*da0073e9SAndroid Build Coastguard Worker operator+(const BFloat16& a, const BFloat16& b) {
105*da0073e9SAndroid Build Coastguard Worker return static_cast<float>(a) + static_cast<float>(b);
106*da0073e9SAndroid Build Coastguard Worker }
107*da0073e9SAndroid Build Coastguard Worker
108*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16
109*da0073e9SAndroid Build Coastguard Worker operator-(const BFloat16& a, const BFloat16& b) {
110*da0073e9SAndroid Build Coastguard Worker return static_cast<float>(a) - static_cast<float>(b);
111*da0073e9SAndroid Build Coastguard Worker }
112*da0073e9SAndroid Build Coastguard Worker
113*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16
114*da0073e9SAndroid Build Coastguard Worker operator*(const BFloat16& a, const BFloat16& b) {
115*da0073e9SAndroid Build Coastguard Worker return static_cast<float>(a) * static_cast<float>(b);
116*da0073e9SAndroid Build Coastguard Worker }
117*da0073e9SAndroid Build Coastguard Worker
118*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16 operator/(const BFloat16& a, const BFloat16& b)
119*da0073e9SAndroid Build Coastguard Worker __ubsan_ignore_float_divide_by_zero__ {
120*da0073e9SAndroid Build Coastguard Worker return static_cast<float>(a) / static_cast<float>(b);
121*da0073e9SAndroid Build Coastguard Worker }
122*da0073e9SAndroid Build Coastguard Worker
123*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16 operator-(const BFloat16& a) {
124*da0073e9SAndroid Build Coastguard Worker return -static_cast<float>(a);
125*da0073e9SAndroid Build Coastguard Worker }
126*da0073e9SAndroid Build Coastguard Worker
127*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16& operator+=(BFloat16& a, const BFloat16& b) {
128*da0073e9SAndroid Build Coastguard Worker a = a + b;
129*da0073e9SAndroid Build Coastguard Worker return a;
130*da0073e9SAndroid Build Coastguard Worker }
131*da0073e9SAndroid Build Coastguard Worker
132*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16& operator-=(BFloat16& a, const BFloat16& b) {
133*da0073e9SAndroid Build Coastguard Worker a = a - b;
134*da0073e9SAndroid Build Coastguard Worker return a;
135*da0073e9SAndroid Build Coastguard Worker }
136*da0073e9SAndroid Build Coastguard Worker
137*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16& operator*=(BFloat16& a, const BFloat16& b) {
138*da0073e9SAndroid Build Coastguard Worker a = a * b;
139*da0073e9SAndroid Build Coastguard Worker return a;
140*da0073e9SAndroid Build Coastguard Worker }
141*da0073e9SAndroid Build Coastguard Worker
142*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16& operator/=(BFloat16& a, const BFloat16& b) {
143*da0073e9SAndroid Build Coastguard Worker a = a / b;
144*da0073e9SAndroid Build Coastguard Worker return a;
145*da0073e9SAndroid Build Coastguard Worker }
146*da0073e9SAndroid Build Coastguard Worker
147*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16& operator|(BFloat16& a, const BFloat16& b) {
148*da0073e9SAndroid Build Coastguard Worker a.x = a.x | b.x;
149*da0073e9SAndroid Build Coastguard Worker return a;
150*da0073e9SAndroid Build Coastguard Worker }
151*da0073e9SAndroid Build Coastguard Worker
152*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16& operator^(BFloat16& a, const BFloat16& b) {
153*da0073e9SAndroid Build Coastguard Worker a.x = a.x ^ b.x;
154*da0073e9SAndroid Build Coastguard Worker return a;
155*da0073e9SAndroid Build Coastguard Worker }
156*da0073e9SAndroid Build Coastguard Worker
157*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16& operator&(BFloat16& a, const BFloat16& b) {
158*da0073e9SAndroid Build Coastguard Worker a.x = a.x & b.x;
159*da0073e9SAndroid Build Coastguard Worker return a;
160*da0073e9SAndroid Build Coastguard Worker }
161*da0073e9SAndroid Build Coastguard Worker
162*da0073e9SAndroid Build Coastguard Worker /// Arithmetic with floats
163*da0073e9SAndroid Build Coastguard Worker
164*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE float operator+(BFloat16 a, float b) {
165*da0073e9SAndroid Build Coastguard Worker return static_cast<float>(a) + b;
166*da0073e9SAndroid Build Coastguard Worker }
167*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE float operator-(BFloat16 a, float b) {
168*da0073e9SAndroid Build Coastguard Worker return static_cast<float>(a) - b;
169*da0073e9SAndroid Build Coastguard Worker }
170*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE float operator*(BFloat16 a, float b) {
171*da0073e9SAndroid Build Coastguard Worker return static_cast<float>(a) * b;
172*da0073e9SAndroid Build Coastguard Worker }
173*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE float operator/(BFloat16 a, float b) {
174*da0073e9SAndroid Build Coastguard Worker return static_cast<float>(a) / b;
175*da0073e9SAndroid Build Coastguard Worker }
176*da0073e9SAndroid Build Coastguard Worker
177*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE float operator+(float a, BFloat16 b) {
178*da0073e9SAndroid Build Coastguard Worker return a + static_cast<float>(b);
179*da0073e9SAndroid Build Coastguard Worker }
180*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE float operator-(float a, BFloat16 b) {
181*da0073e9SAndroid Build Coastguard Worker return a - static_cast<float>(b);
182*da0073e9SAndroid Build Coastguard Worker }
183*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE float operator*(float a, BFloat16 b) {
184*da0073e9SAndroid Build Coastguard Worker return a * static_cast<float>(b);
185*da0073e9SAndroid Build Coastguard Worker }
186*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE float operator/(float a, BFloat16 b) {
187*da0073e9SAndroid Build Coastguard Worker return a / static_cast<float>(b);
188*da0073e9SAndroid Build Coastguard Worker }
189*da0073e9SAndroid Build Coastguard Worker
190*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE float& operator+=(float& a, const BFloat16& b) {
191*da0073e9SAndroid Build Coastguard Worker return a += static_cast<float>(b);
192*da0073e9SAndroid Build Coastguard Worker }
193*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE float& operator-=(float& a, const BFloat16& b) {
194*da0073e9SAndroid Build Coastguard Worker return a -= static_cast<float>(b);
195*da0073e9SAndroid Build Coastguard Worker }
196*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE float& operator*=(float& a, const BFloat16& b) {
197*da0073e9SAndroid Build Coastguard Worker return a *= static_cast<float>(b);
198*da0073e9SAndroid Build Coastguard Worker }
199*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE float& operator/=(float& a, const BFloat16& b) {
200*da0073e9SAndroid Build Coastguard Worker return a /= static_cast<float>(b);
201*da0073e9SAndroid Build Coastguard Worker }
202*da0073e9SAndroid Build Coastguard Worker
203*da0073e9SAndroid Build Coastguard Worker /// Arithmetic with doubles
204*da0073e9SAndroid Build Coastguard Worker
205*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE double operator+(BFloat16 a, double b) {
206*da0073e9SAndroid Build Coastguard Worker return static_cast<double>(a) + b;
207*da0073e9SAndroid Build Coastguard Worker }
208*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE double operator-(BFloat16 a, double b) {
209*da0073e9SAndroid Build Coastguard Worker return static_cast<double>(a) - b;
210*da0073e9SAndroid Build Coastguard Worker }
211*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE double operator*(BFloat16 a, double b) {
212*da0073e9SAndroid Build Coastguard Worker return static_cast<double>(a) * b;
213*da0073e9SAndroid Build Coastguard Worker }
214*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE double operator/(BFloat16 a, double b) {
215*da0073e9SAndroid Build Coastguard Worker return static_cast<double>(a) / b;
216*da0073e9SAndroid Build Coastguard Worker }
217*da0073e9SAndroid Build Coastguard Worker
218*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE double operator+(double a, BFloat16 b) {
219*da0073e9SAndroid Build Coastguard Worker return a + static_cast<double>(b);
220*da0073e9SAndroid Build Coastguard Worker }
221*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE double operator-(double a, BFloat16 b) {
222*da0073e9SAndroid Build Coastguard Worker return a - static_cast<double>(b);
223*da0073e9SAndroid Build Coastguard Worker }
224*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE double operator*(double a, BFloat16 b) {
225*da0073e9SAndroid Build Coastguard Worker return a * static_cast<double>(b);
226*da0073e9SAndroid Build Coastguard Worker }
227*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE double operator/(double a, BFloat16 b) {
228*da0073e9SAndroid Build Coastguard Worker return a / static_cast<double>(b);
229*da0073e9SAndroid Build Coastguard Worker }
230*da0073e9SAndroid Build Coastguard Worker
231*da0073e9SAndroid Build Coastguard Worker /// Arithmetic with ints
232*da0073e9SAndroid Build Coastguard Worker
233*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int b) {
234*da0073e9SAndroid Build Coastguard Worker return a + static_cast<BFloat16>(b);
235*da0073e9SAndroid Build Coastguard Worker }
236*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int b) {
237*da0073e9SAndroid Build Coastguard Worker return a - static_cast<BFloat16>(b);
238*da0073e9SAndroid Build Coastguard Worker }
239*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int b) {
240*da0073e9SAndroid Build Coastguard Worker return a * static_cast<BFloat16>(b);
241*da0073e9SAndroid Build Coastguard Worker }
242*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int b) {
243*da0073e9SAndroid Build Coastguard Worker return a / static_cast<BFloat16>(b);
244*da0073e9SAndroid Build Coastguard Worker }
245*da0073e9SAndroid Build Coastguard Worker
246*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16 operator+(int a, BFloat16 b) {
247*da0073e9SAndroid Build Coastguard Worker return static_cast<BFloat16>(a) + b;
248*da0073e9SAndroid Build Coastguard Worker }
249*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16 operator-(int a, BFloat16 b) {
250*da0073e9SAndroid Build Coastguard Worker return static_cast<BFloat16>(a) - b;
251*da0073e9SAndroid Build Coastguard Worker }
252*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16 operator*(int a, BFloat16 b) {
253*da0073e9SAndroid Build Coastguard Worker return static_cast<BFloat16>(a) * b;
254*da0073e9SAndroid Build Coastguard Worker }
255*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16 operator/(int a, BFloat16 b) {
256*da0073e9SAndroid Build Coastguard Worker return static_cast<BFloat16>(a) / b;
257*da0073e9SAndroid Build Coastguard Worker }
258*da0073e9SAndroid Build Coastguard Worker
259*da0073e9SAndroid Build Coastguard Worker //// Arithmetic with int64_t
260*da0073e9SAndroid Build Coastguard Worker
261*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16 operator+(BFloat16 a, int64_t b) {
262*da0073e9SAndroid Build Coastguard Worker return a + static_cast<BFloat16>(b);
263*da0073e9SAndroid Build Coastguard Worker }
264*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16 operator-(BFloat16 a, int64_t b) {
265*da0073e9SAndroid Build Coastguard Worker return a - static_cast<BFloat16>(b);
266*da0073e9SAndroid Build Coastguard Worker }
267*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16 operator*(BFloat16 a, int64_t b) {
268*da0073e9SAndroid Build Coastguard Worker return a * static_cast<BFloat16>(b);
269*da0073e9SAndroid Build Coastguard Worker }
270*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16 operator/(BFloat16 a, int64_t b) {
271*da0073e9SAndroid Build Coastguard Worker return a / static_cast<BFloat16>(b);
272*da0073e9SAndroid Build Coastguard Worker }
273*da0073e9SAndroid Build Coastguard Worker
274*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16 operator+(int64_t a, BFloat16 b) {
275*da0073e9SAndroid Build Coastguard Worker return static_cast<BFloat16>(a) + b;
276*da0073e9SAndroid Build Coastguard Worker }
277*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16 operator-(int64_t a, BFloat16 b) {
278*da0073e9SAndroid Build Coastguard Worker return static_cast<BFloat16>(a) - b;
279*da0073e9SAndroid Build Coastguard Worker }
280*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16 operator*(int64_t a, BFloat16 b) {
281*da0073e9SAndroid Build Coastguard Worker return static_cast<BFloat16>(a) * b;
282*da0073e9SAndroid Build Coastguard Worker }
283*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE BFloat16 operator/(int64_t a, BFloat16 b) {
284*da0073e9SAndroid Build Coastguard Worker return static_cast<BFloat16>(a) / b;
285*da0073e9SAndroid Build Coastguard Worker }
286*da0073e9SAndroid Build Coastguard Worker
287*da0073e9SAndroid Build Coastguard Worker // Overloading < and > operators, because std::max and std::min use them.
288*da0073e9SAndroid Build Coastguard Worker
289*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE bool operator>(BFloat16& lhs, BFloat16& rhs) {
290*da0073e9SAndroid Build Coastguard Worker return float(lhs) > float(rhs);
291*da0073e9SAndroid Build Coastguard Worker }
292*da0073e9SAndroid Build Coastguard Worker
293*da0073e9SAndroid Build Coastguard Worker inline C10_HOST_DEVICE bool operator<(BFloat16& lhs, BFloat16& rhs) {
294*da0073e9SAndroid Build Coastguard Worker return float(lhs) < float(rhs);
295*da0073e9SAndroid Build Coastguard Worker }
296*da0073e9SAndroid Build Coastguard Worker
297*da0073e9SAndroid Build Coastguard Worker } // namespace c10
298*da0073e9SAndroid Build Coastguard Worker
299*da0073e9SAndroid Build Coastguard Worker namespace std {
300*da0073e9SAndroid Build Coastguard Worker
301*da0073e9SAndroid Build Coastguard Worker template <>
302*da0073e9SAndroid Build Coastguard Worker class numeric_limits<c10::BFloat16> {
303*da0073e9SAndroid Build Coastguard Worker public:
304*da0073e9SAndroid Build Coastguard Worker static constexpr bool is_signed = true;
305*da0073e9SAndroid Build Coastguard Worker static constexpr bool is_specialized = true;
306*da0073e9SAndroid Build Coastguard Worker static constexpr bool is_integer = false;
307*da0073e9SAndroid Build Coastguard Worker static constexpr bool is_exact = false;
308*da0073e9SAndroid Build Coastguard Worker static constexpr bool has_infinity = true;
309*da0073e9SAndroid Build Coastguard Worker static constexpr bool has_quiet_NaN = true;
310*da0073e9SAndroid Build Coastguard Worker static constexpr bool has_signaling_NaN = true;
311*da0073e9SAndroid Build Coastguard Worker static constexpr auto has_denorm = numeric_limits<float>::has_denorm;
312*da0073e9SAndroid Build Coastguard Worker static constexpr auto has_denorm_loss =
313*da0073e9SAndroid Build Coastguard Worker numeric_limits<float>::has_denorm_loss;
314*da0073e9SAndroid Build Coastguard Worker static constexpr auto round_style = numeric_limits<float>::round_style;
315*da0073e9SAndroid Build Coastguard Worker static constexpr bool is_iec559 = false;
316*da0073e9SAndroid Build Coastguard Worker static constexpr bool is_bounded = true;
317*da0073e9SAndroid Build Coastguard Worker static constexpr bool is_modulo = false;
318*da0073e9SAndroid Build Coastguard Worker static constexpr int digits = 8;
319*da0073e9SAndroid Build Coastguard Worker static constexpr int digits10 = 2;
320*da0073e9SAndroid Build Coastguard Worker static constexpr int max_digits10 = 4;
321*da0073e9SAndroid Build Coastguard Worker static constexpr int radix = 2;
322*da0073e9SAndroid Build Coastguard Worker static constexpr int min_exponent = -125;
323*da0073e9SAndroid Build Coastguard Worker static constexpr int min_exponent10 = -37;
324*da0073e9SAndroid Build Coastguard Worker static constexpr int max_exponent = 128;
325*da0073e9SAndroid Build Coastguard Worker static constexpr int max_exponent10 = 38;
326*da0073e9SAndroid Build Coastguard Worker static constexpr auto traps = numeric_limits<float>::traps;
327*da0073e9SAndroid Build Coastguard Worker static constexpr auto tinyness_before =
328*da0073e9SAndroid Build Coastguard Worker numeric_limits<float>::tinyness_before;
329*da0073e9SAndroid Build Coastguard Worker
min()330*da0073e9SAndroid Build Coastguard Worker static constexpr c10::BFloat16 min() {
331*da0073e9SAndroid Build Coastguard Worker return c10::BFloat16(0x0080, c10::BFloat16::from_bits());
332*da0073e9SAndroid Build Coastguard Worker }
lowest()333*da0073e9SAndroid Build Coastguard Worker static constexpr c10::BFloat16 lowest() {
334*da0073e9SAndroid Build Coastguard Worker return c10::BFloat16(0xFF7F, c10::BFloat16::from_bits());
335*da0073e9SAndroid Build Coastguard Worker }
max()336*da0073e9SAndroid Build Coastguard Worker static constexpr c10::BFloat16 max() {
337*da0073e9SAndroid Build Coastguard Worker return c10::BFloat16(0x7F7F, c10::BFloat16::from_bits());
338*da0073e9SAndroid Build Coastguard Worker }
epsilon()339*da0073e9SAndroid Build Coastguard Worker static constexpr c10::BFloat16 epsilon() {
340*da0073e9SAndroid Build Coastguard Worker return c10::BFloat16(0x3C00, c10::BFloat16::from_bits());
341*da0073e9SAndroid Build Coastguard Worker }
round_error()342*da0073e9SAndroid Build Coastguard Worker static constexpr c10::BFloat16 round_error() {
343*da0073e9SAndroid Build Coastguard Worker return c10::BFloat16(0x3F00, c10::BFloat16::from_bits());
344*da0073e9SAndroid Build Coastguard Worker }
infinity()345*da0073e9SAndroid Build Coastguard Worker static constexpr c10::BFloat16 infinity() {
346*da0073e9SAndroid Build Coastguard Worker return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
347*da0073e9SAndroid Build Coastguard Worker }
quiet_NaN()348*da0073e9SAndroid Build Coastguard Worker static constexpr c10::BFloat16 quiet_NaN() {
349*da0073e9SAndroid Build Coastguard Worker return c10::BFloat16(0x7FC0, c10::BFloat16::from_bits());
350*da0073e9SAndroid Build Coastguard Worker }
signaling_NaN()351*da0073e9SAndroid Build Coastguard Worker static constexpr c10::BFloat16 signaling_NaN() {
352*da0073e9SAndroid Build Coastguard Worker return c10::BFloat16(0x7F80, c10::BFloat16::from_bits());
353*da0073e9SAndroid Build Coastguard Worker }
denorm_min()354*da0073e9SAndroid Build Coastguard Worker static constexpr c10::BFloat16 denorm_min() {
355*da0073e9SAndroid Build Coastguard Worker return c10::BFloat16(0x0001, c10::BFloat16::from_bits());
356*da0073e9SAndroid Build Coastguard Worker }
357*da0073e9SAndroid Build Coastguard Worker };
358*da0073e9SAndroid Build Coastguard Worker
359*da0073e9SAndroid Build Coastguard Worker } // namespace std
360*da0073e9SAndroid Build Coastguard Worker
361*da0073e9SAndroid Build Coastguard Worker C10_CLANG_DIAGNOSTIC_POP()
362