xref: /aosp_15_r20/external/pytorch/c10/util/Float8_e4m3fnuz-inl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/macros/Macros.h>
4 #include <c10/util/Float8_fnuz_cvt.h>
5 #include <cstring>
6 #include <limits>
7 
8 C10_CLANG_DIAGNOSTIC_PUSH()
9 #if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
10 C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
11 #endif
12 
13 namespace c10 {
14 
15 /// Constructors
16 
Float8_e4m3fnuz(float value)17 inline C10_HOST_DEVICE Float8_e4m3fnuz::Float8_e4m3fnuz(float value)
18     : x(detail::fp8e4m3fnuz_from_fp32_value(value)) {}
19 
20 /// Implicit conversions
21 
22 inline C10_HOST_DEVICE Float8_e4m3fnuz::operator float() const {
23   return detail::fp8_fnuz_to_fp32_value<4, 3>(x);
24 }
25 
26 /// Special values helper
27 
isnan()28 inline C10_HOST_DEVICE bool Float8_e4m3fnuz::isnan() const {
29   return x == 0b10000000;
30 }
31 
32 /// Arithmetic
33 
34 inline C10_HOST_DEVICE Float8_e4m3fnuz
35 operator+(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) {
36   return static_cast<float>(a) + static_cast<float>(b);
37 }
38 
39 inline C10_HOST_DEVICE Float8_e4m3fnuz
40 operator-(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) {
41   return static_cast<float>(a) - static_cast<float>(b);
42 }
43 
44 inline C10_HOST_DEVICE Float8_e4m3fnuz
45 operator*(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) {
46   return static_cast<float>(a) * static_cast<float>(b);
47 }
48 
49 inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(
50     const Float8_e4m3fnuz& a,
51     const Float8_e4m3fnuz& b) __ubsan_ignore_float_divide_by_zero__ {
52   return static_cast<float>(a) / static_cast<float>(b);
53 }
54 
55 inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(const Float8_e4m3fnuz& a) {
56   return -static_cast<float>(a);
57 }
58 
59 inline C10_HOST_DEVICE Float8_e4m3fnuz& operator+=(
60     Float8_e4m3fnuz& a,
61     const Float8_e4m3fnuz& b) {
62   a = a + b;
63   return a;
64 }
65 
66 inline C10_HOST_DEVICE Float8_e4m3fnuz& operator-=(
67     Float8_e4m3fnuz& a,
68     const Float8_e4m3fnuz& b) {
69   a = a - b;
70   return a;
71 }
72 
73 inline C10_HOST_DEVICE Float8_e4m3fnuz& operator*=(
74     Float8_e4m3fnuz& a,
75     const Float8_e4m3fnuz& b) {
76   a = a * b;
77   return a;
78 }
79 
80 inline C10_HOST_DEVICE Float8_e4m3fnuz& operator/=(
81     Float8_e4m3fnuz& a,
82     const Float8_e4m3fnuz& b) {
83   a = a / b;
84   return a;
85 }
86 
87 /// Arithmetic with floats
88 
89 inline C10_HOST_DEVICE float operator+(Float8_e4m3fnuz a, float b) {
90   return static_cast<float>(a) + b;
91 }
92 inline C10_HOST_DEVICE float operator-(Float8_e4m3fnuz a, float b) {
93   return static_cast<float>(a) - b;
94 }
95 inline C10_HOST_DEVICE float operator*(Float8_e4m3fnuz a, float b) {
96   return static_cast<float>(a) * b;
97 }
98 inline C10_HOST_DEVICE float operator/(Float8_e4m3fnuz a, float b)
99     __ubsan_ignore_float_divide_by_zero__ {
100   return static_cast<float>(a) / b;
101 }
102 
103 inline C10_HOST_DEVICE float operator+(float a, Float8_e4m3fnuz b) {
104   return a + static_cast<float>(b);
105 }
106 inline C10_HOST_DEVICE float operator-(float a, Float8_e4m3fnuz b) {
107   return a - static_cast<float>(b);
108 }
109 inline C10_HOST_DEVICE float operator*(float a, Float8_e4m3fnuz b) {
110   return a * static_cast<float>(b);
111 }
112 inline C10_HOST_DEVICE float operator/(float a, Float8_e4m3fnuz b)
113     __ubsan_ignore_float_divide_by_zero__ {
114   return a / static_cast<float>(b);
115 }
116 
117 inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e4m3fnuz& b) {
118   return a += static_cast<float>(b);
119 }
120 inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e4m3fnuz& b) {
121   return a -= static_cast<float>(b);
122 }
123 inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e4m3fnuz& b) {
124   return a *= static_cast<float>(b);
125 }
126 inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e4m3fnuz& b) {
127   return a /= static_cast<float>(b);
128 }
129 
130 /// Arithmetic with doubles
131 
132 inline C10_HOST_DEVICE double operator+(Float8_e4m3fnuz a, double b) {
133   return static_cast<double>(a) + b;
134 }
135 inline C10_HOST_DEVICE double operator-(Float8_e4m3fnuz a, double b) {
136   return static_cast<double>(a) - b;
137 }
138 inline C10_HOST_DEVICE double operator*(Float8_e4m3fnuz a, double b) {
139   return static_cast<double>(a) * b;
140 }
141 inline C10_HOST_DEVICE double operator/(Float8_e4m3fnuz a, double b)
142     __ubsan_ignore_float_divide_by_zero__ {
143   return static_cast<double>(a) / b;
144 }
145 
146 inline C10_HOST_DEVICE double operator+(double a, Float8_e4m3fnuz b) {
147   return a + static_cast<double>(b);
148 }
149 inline C10_HOST_DEVICE double operator-(double a, Float8_e4m3fnuz b) {
150   return a - static_cast<double>(b);
151 }
152 inline C10_HOST_DEVICE double operator*(double a, Float8_e4m3fnuz b) {
153   return a * static_cast<double>(b);
154 }
155 inline C10_HOST_DEVICE double operator/(double a, Float8_e4m3fnuz b)
156     __ubsan_ignore_float_divide_by_zero__ {
157   return a / static_cast<double>(b);
158 }
159 
160 /// Arithmetic with ints
161 
162 inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(Float8_e4m3fnuz a, int b) {
163   return a + static_cast<Float8_e4m3fnuz>(b);
164 }
165 inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(Float8_e4m3fnuz a, int b) {
166   return a - static_cast<Float8_e4m3fnuz>(b);
167 }
168 inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(Float8_e4m3fnuz a, int b) {
169   return a * static_cast<Float8_e4m3fnuz>(b);
170 }
171 inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(Float8_e4m3fnuz a, int b) {
172   return a / static_cast<Float8_e4m3fnuz>(b);
173 }
174 
175 inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(int a, Float8_e4m3fnuz b) {
176   return static_cast<Float8_e4m3fnuz>(a) + b;
177 }
178 inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(int a, Float8_e4m3fnuz b) {
179   return static_cast<Float8_e4m3fnuz>(a) - b;
180 }
181 inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(int a, Float8_e4m3fnuz b) {
182   return static_cast<Float8_e4m3fnuz>(a) * b;
183 }
184 inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(int a, Float8_e4m3fnuz b) {
185   return static_cast<Float8_e4m3fnuz>(a) / b;
186 }
187 
188 //// Arithmetic with int64_t
189 
190 inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(Float8_e4m3fnuz a, int64_t b) {
191   return a + static_cast<Float8_e4m3fnuz>(b);
192 }
193 inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(Float8_e4m3fnuz a, int64_t b) {
194   return a - static_cast<Float8_e4m3fnuz>(b);
195 }
196 inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(Float8_e4m3fnuz a, int64_t b) {
197   return a * static_cast<Float8_e4m3fnuz>(b);
198 }
199 inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(Float8_e4m3fnuz a, int64_t b) {
200   return a / static_cast<Float8_e4m3fnuz>(b);
201 }
202 
203 inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(int64_t a, Float8_e4m3fnuz b) {
204   return static_cast<Float8_e4m3fnuz>(a) + b;
205 }
206 inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(int64_t a, Float8_e4m3fnuz b) {
207   return static_cast<Float8_e4m3fnuz>(a) - b;
208 }
209 inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(int64_t a, Float8_e4m3fnuz b) {
210   return static_cast<Float8_e4m3fnuz>(a) * b;
211 }
212 inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(int64_t a, Float8_e4m3fnuz b) {
213   return static_cast<Float8_e4m3fnuz>(a) / b;
214 }
215 
216 /// NOTE: we do not define comparisons directly and instead rely on the implicit
217 /// conversion from c10::Float8_e4m3fnuz to float.
218 
219 } // namespace c10
220 
221 namespace std {
222 
223 template <>
224 class numeric_limits<c10::Float8_e4m3fnuz> {
225  public:
226   static constexpr bool is_specialized = true;
227   static constexpr bool is_signed = true;
228   static constexpr bool is_integer = false;
229   static constexpr bool is_exact = false;
230   static constexpr bool has_infinity = false;
231   static constexpr bool has_quiet_NaN = true;
232   static constexpr bool has_signaling_NaN = false;
233   static constexpr auto has_denorm = true;
234   static constexpr auto has_denorm_loss = true;
235   static constexpr auto round_style = numeric_limits<float>::round_style;
236   static constexpr bool is_iec559 = false;
237   static constexpr bool is_bounded = true;
238   static constexpr bool is_modulo = false;
239   static constexpr int digits = 4;
240   static constexpr int digits10 = 0;
241   static constexpr int max_digits10 = 3;
242   static constexpr int radix = 2;
243   static constexpr int min_exponent = -6;
244   static constexpr int min_exponent10 = -1;
245   static constexpr int max_exponent = 8;
246   static constexpr int max_exponent10 = 2;
247   static constexpr auto traps = numeric_limits<float>::traps;
248   static constexpr auto tinyness_before = false;
249 
min()250   static constexpr c10::Float8_e4m3fnuz min() {
251     return c10::Float8_e4m3fnuz(0x08, c10::Float8_e4m3fnuz::from_bits());
252   }
lowest()253   static constexpr c10::Float8_e4m3fnuz lowest() {
254     return c10::Float8_e4m3fnuz(0xFF, c10::Float8_e4m3fnuz::from_bits());
255   }
max()256   static constexpr c10::Float8_e4m3fnuz max() {
257     return c10::Float8_e4m3fnuz(0x7F, c10::Float8_e4m3fnuz::from_bits());
258   }
epsilon()259   static constexpr c10::Float8_e4m3fnuz epsilon() {
260     return c10::Float8_e4m3fnuz(0x28, c10::Float8_e4m3fnuz::from_bits());
261   }
round_error()262   static constexpr c10::Float8_e4m3fnuz round_error() {
263     return c10::Float8_e4m3fnuz(0x38, c10::Float8_e4m3fnuz::from_bits());
264   }
infinity()265   static constexpr c10::Float8_e4m3fnuz infinity() {
266     // NaN (no infinities)
267     return c10::Float8_e4m3fnuz(0x80, c10::Float8_e4m3fnuz::from_bits());
268   }
quiet_NaN()269   static constexpr c10::Float8_e4m3fnuz quiet_NaN() {
270     return c10::Float8_e4m3fnuz(0x80, c10::Float8_e4m3fnuz::from_bits());
271   }
denorm_min()272   static constexpr c10::Float8_e4m3fnuz denorm_min() {
273     return c10::Float8_e4m3fnuz(0x01, c10::Float8_e4m3fnuz::from_bits());
274   }
275 };
276 
277 } // namespace std
278 
279 C10_CLANG_DIAGNOSTIC_POP()
280