1 #pragma once
2
3 #include <c10/macros/Macros.h>
4 #include <c10/util/Float8_fnuz_cvt.h>
5 #include <cstring>
6 #include <limits>
7
8 C10_CLANG_DIAGNOSTIC_PUSH()
9 #if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
10 C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
11 #endif
12
13 namespace c10 {
14
15 /// Constructors
16
Float8_e4m3fnuz(float value)17 inline C10_HOST_DEVICE Float8_e4m3fnuz::Float8_e4m3fnuz(float value)
18 : x(detail::fp8e4m3fnuz_from_fp32_value(value)) {}
19
20 /// Implicit conversions
21
22 inline C10_HOST_DEVICE Float8_e4m3fnuz::operator float() const {
23 return detail::fp8_fnuz_to_fp32_value<4, 3>(x);
24 }
25
26 /// Special values helper
27
isnan()28 inline C10_HOST_DEVICE bool Float8_e4m3fnuz::isnan() const {
29 return x == 0b10000000;
30 }
31
32 /// Arithmetic
33
34 inline C10_HOST_DEVICE Float8_e4m3fnuz
35 operator+(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) {
36 return static_cast<float>(a) + static_cast<float>(b);
37 }
38
39 inline C10_HOST_DEVICE Float8_e4m3fnuz
40 operator-(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) {
41 return static_cast<float>(a) - static_cast<float>(b);
42 }
43
44 inline C10_HOST_DEVICE Float8_e4m3fnuz
45 operator*(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) {
46 return static_cast<float>(a) * static_cast<float>(b);
47 }
48
49 inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(
50 const Float8_e4m3fnuz& a,
51 const Float8_e4m3fnuz& b) __ubsan_ignore_float_divide_by_zero__ {
52 return static_cast<float>(a) / static_cast<float>(b);
53 }
54
55 inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(const Float8_e4m3fnuz& a) {
56 return -static_cast<float>(a);
57 }
58
59 inline C10_HOST_DEVICE Float8_e4m3fnuz& operator+=(
60 Float8_e4m3fnuz& a,
61 const Float8_e4m3fnuz& b) {
62 a = a + b;
63 return a;
64 }
65
66 inline C10_HOST_DEVICE Float8_e4m3fnuz& operator-=(
67 Float8_e4m3fnuz& a,
68 const Float8_e4m3fnuz& b) {
69 a = a - b;
70 return a;
71 }
72
73 inline C10_HOST_DEVICE Float8_e4m3fnuz& operator*=(
74 Float8_e4m3fnuz& a,
75 const Float8_e4m3fnuz& b) {
76 a = a * b;
77 return a;
78 }
79
80 inline C10_HOST_DEVICE Float8_e4m3fnuz& operator/=(
81 Float8_e4m3fnuz& a,
82 const Float8_e4m3fnuz& b) {
83 a = a / b;
84 return a;
85 }
86
87 /// Arithmetic with floats
88
89 inline C10_HOST_DEVICE float operator+(Float8_e4m3fnuz a, float b) {
90 return static_cast<float>(a) + b;
91 }
92 inline C10_HOST_DEVICE float operator-(Float8_e4m3fnuz a, float b) {
93 return static_cast<float>(a) - b;
94 }
95 inline C10_HOST_DEVICE float operator*(Float8_e4m3fnuz a, float b) {
96 return static_cast<float>(a) * b;
97 }
98 inline C10_HOST_DEVICE float operator/(Float8_e4m3fnuz a, float b)
99 __ubsan_ignore_float_divide_by_zero__ {
100 return static_cast<float>(a) / b;
101 }
102
103 inline C10_HOST_DEVICE float operator+(float a, Float8_e4m3fnuz b) {
104 return a + static_cast<float>(b);
105 }
106 inline C10_HOST_DEVICE float operator-(float a, Float8_e4m3fnuz b) {
107 return a - static_cast<float>(b);
108 }
109 inline C10_HOST_DEVICE float operator*(float a, Float8_e4m3fnuz b) {
110 return a * static_cast<float>(b);
111 }
112 inline C10_HOST_DEVICE float operator/(float a, Float8_e4m3fnuz b)
113 __ubsan_ignore_float_divide_by_zero__ {
114 return a / static_cast<float>(b);
115 }
116
117 inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e4m3fnuz& b) {
118 return a += static_cast<float>(b);
119 }
120 inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e4m3fnuz& b) {
121 return a -= static_cast<float>(b);
122 }
123 inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e4m3fnuz& b) {
124 return a *= static_cast<float>(b);
125 }
126 inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e4m3fnuz& b) {
127 return a /= static_cast<float>(b);
128 }
129
130 /// Arithmetic with doubles
131
132 inline C10_HOST_DEVICE double operator+(Float8_e4m3fnuz a, double b) {
133 return static_cast<double>(a) + b;
134 }
135 inline C10_HOST_DEVICE double operator-(Float8_e4m3fnuz a, double b) {
136 return static_cast<double>(a) - b;
137 }
138 inline C10_HOST_DEVICE double operator*(Float8_e4m3fnuz a, double b) {
139 return static_cast<double>(a) * b;
140 }
141 inline C10_HOST_DEVICE double operator/(Float8_e4m3fnuz a, double b)
142 __ubsan_ignore_float_divide_by_zero__ {
143 return static_cast<double>(a) / b;
144 }
145
146 inline C10_HOST_DEVICE double operator+(double a, Float8_e4m3fnuz b) {
147 return a + static_cast<double>(b);
148 }
149 inline C10_HOST_DEVICE double operator-(double a, Float8_e4m3fnuz b) {
150 return a - static_cast<double>(b);
151 }
152 inline C10_HOST_DEVICE double operator*(double a, Float8_e4m3fnuz b) {
153 return a * static_cast<double>(b);
154 }
155 inline C10_HOST_DEVICE double operator/(double a, Float8_e4m3fnuz b)
156 __ubsan_ignore_float_divide_by_zero__ {
157 return a / static_cast<double>(b);
158 }
159
160 /// Arithmetic with ints
161
162 inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(Float8_e4m3fnuz a, int b) {
163 return a + static_cast<Float8_e4m3fnuz>(b);
164 }
165 inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(Float8_e4m3fnuz a, int b) {
166 return a - static_cast<Float8_e4m3fnuz>(b);
167 }
168 inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(Float8_e4m3fnuz a, int b) {
169 return a * static_cast<Float8_e4m3fnuz>(b);
170 }
171 inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(Float8_e4m3fnuz a, int b) {
172 return a / static_cast<Float8_e4m3fnuz>(b);
173 }
174
175 inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(int a, Float8_e4m3fnuz b) {
176 return static_cast<Float8_e4m3fnuz>(a) + b;
177 }
178 inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(int a, Float8_e4m3fnuz b) {
179 return static_cast<Float8_e4m3fnuz>(a) - b;
180 }
181 inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(int a, Float8_e4m3fnuz b) {
182 return static_cast<Float8_e4m3fnuz>(a) * b;
183 }
184 inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(int a, Float8_e4m3fnuz b) {
185 return static_cast<Float8_e4m3fnuz>(a) / b;
186 }
187
/// Arithmetic with int64_t
189
190 inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(Float8_e4m3fnuz a, int64_t b) {
191 return a + static_cast<Float8_e4m3fnuz>(b);
192 }
193 inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(Float8_e4m3fnuz a, int64_t b) {
194 return a - static_cast<Float8_e4m3fnuz>(b);
195 }
196 inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(Float8_e4m3fnuz a, int64_t b) {
197 return a * static_cast<Float8_e4m3fnuz>(b);
198 }
199 inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(Float8_e4m3fnuz a, int64_t b) {
200 return a / static_cast<Float8_e4m3fnuz>(b);
201 }
202
203 inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(int64_t a, Float8_e4m3fnuz b) {
204 return static_cast<Float8_e4m3fnuz>(a) + b;
205 }
206 inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(int64_t a, Float8_e4m3fnuz b) {
207 return static_cast<Float8_e4m3fnuz>(a) - b;
208 }
209 inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(int64_t a, Float8_e4m3fnuz b) {
210 return static_cast<Float8_e4m3fnuz>(a) * b;
211 }
212 inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(int64_t a, Float8_e4m3fnuz b) {
213 return static_cast<Float8_e4m3fnuz>(a) / b;
214 }
215
216 /// NOTE: we do not define comparisons directly and instead rely on the implicit
217 /// conversion from c10::Float8_e4m3fnuz to float.
218
219 } // namespace c10
220
221 namespace std {
222
223 template <>
224 class numeric_limits<c10::Float8_e4m3fnuz> {
225 public:
226 static constexpr bool is_specialized = true;
227 static constexpr bool is_signed = true;
228 static constexpr bool is_integer = false;
229 static constexpr bool is_exact = false;
230 static constexpr bool has_infinity = false;
231 static constexpr bool has_quiet_NaN = true;
232 static constexpr bool has_signaling_NaN = false;
233 static constexpr auto has_denorm = true;
234 static constexpr auto has_denorm_loss = true;
235 static constexpr auto round_style = numeric_limits<float>::round_style;
236 static constexpr bool is_iec559 = false;
237 static constexpr bool is_bounded = true;
238 static constexpr bool is_modulo = false;
239 static constexpr int digits = 4;
240 static constexpr int digits10 = 0;
241 static constexpr int max_digits10 = 3;
242 static constexpr int radix = 2;
243 static constexpr int min_exponent = -6;
244 static constexpr int min_exponent10 = -1;
245 static constexpr int max_exponent = 8;
246 static constexpr int max_exponent10 = 2;
247 static constexpr auto traps = numeric_limits<float>::traps;
248 static constexpr auto tinyness_before = false;
249
min()250 static constexpr c10::Float8_e4m3fnuz min() {
251 return c10::Float8_e4m3fnuz(0x08, c10::Float8_e4m3fnuz::from_bits());
252 }
lowest()253 static constexpr c10::Float8_e4m3fnuz lowest() {
254 return c10::Float8_e4m3fnuz(0xFF, c10::Float8_e4m3fnuz::from_bits());
255 }
max()256 static constexpr c10::Float8_e4m3fnuz max() {
257 return c10::Float8_e4m3fnuz(0x7F, c10::Float8_e4m3fnuz::from_bits());
258 }
epsilon()259 static constexpr c10::Float8_e4m3fnuz epsilon() {
260 return c10::Float8_e4m3fnuz(0x28, c10::Float8_e4m3fnuz::from_bits());
261 }
round_error()262 static constexpr c10::Float8_e4m3fnuz round_error() {
263 return c10::Float8_e4m3fnuz(0x38, c10::Float8_e4m3fnuz::from_bits());
264 }
infinity()265 static constexpr c10::Float8_e4m3fnuz infinity() {
266 // NaN (no infinities)
267 return c10::Float8_e4m3fnuz(0x80, c10::Float8_e4m3fnuz::from_bits());
268 }
quiet_NaN()269 static constexpr c10::Float8_e4m3fnuz quiet_NaN() {
270 return c10::Float8_e4m3fnuz(0x80, c10::Float8_e4m3fnuz::from_bits());
271 }
denorm_min()272 static constexpr c10::Float8_e4m3fnuz denorm_min() {
273 return c10::Float8_e4m3fnuz(0x01, c10::Float8_e4m3fnuz::from_bits());
274 }
275 };
276
277 } // namespace std
278
279 C10_CLANG_DIAGNOSTIC_POP()
280