1 #pragma once
2
3 #include <c10/macros/Macros.h>
4 #include <cstring>
5 #include <limits>
6
7 C10_CLANG_DIAGNOSTIC_PUSH()
8 #if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
9 C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
10 #endif
11
// Bit-layout parameters of the float8 e5m2 format: 5 exponent bits, 2 mantissa
// bits and an exponent bias of 15 (the remaining top bit is the sign).
// NOTE(review): these are unprefixed macros in a public header, so they leak
// into every includer, and they are not referenced in the visible part of this
// file — confirm whether anything else depends on them before renaming or
// replacing them with namespaced constexpr constants.
#define EXP_WIDTH_FP8 5
#define MAN_WIDTH_FP8 2
#define EXP_BIAS_FP8 15
15
16 namespace c10 {
17
18 /// Constructors
19
Float8_e5m2(float value)20 inline C10_HOST_DEVICE Float8_e5m2::Float8_e5m2(float value)
21 : x(detail::fp8e5m2_from_fp32_value(value)) {}
22
23 /// Implicit conversions
24
25 inline C10_HOST_DEVICE Float8_e5m2::operator float() const {
26 return detail::fp8e5m2_to_fp32_value(x);
27 }
28
29 /// Special values helpers
30
isnan()31 inline C10_HOST_DEVICE bool Float8_e5m2::isnan() const {
32 return (x & 0b01111111) > 0b01111100;
33 }
34
isinf()35 inline C10_HOST_DEVICE bool Float8_e5m2::isinf() const {
36 return (x & 0b01111111) == 0b01111100;
37 }
38
39 /// Arithmetic
40
41 inline C10_HOST_DEVICE Float8_e5m2
42 operator+(const Float8_e5m2& a, const Float8_e5m2& b) {
43 return static_cast<float>(a) + static_cast<float>(b);
44 }
45
46 inline C10_HOST_DEVICE Float8_e5m2
47 operator-(const Float8_e5m2& a, const Float8_e5m2& b) {
48 return static_cast<float>(a) - static_cast<float>(b);
49 }
50
51 inline C10_HOST_DEVICE Float8_e5m2
52 operator*(const Float8_e5m2& a, const Float8_e5m2& b) {
53 return static_cast<float>(a) * static_cast<float>(b);
54 }
55
56 inline C10_HOST_DEVICE Float8_e5m2 operator/(
57 const Float8_e5m2& a,
58 const Float8_e5m2& b) __ubsan_ignore_float_divide_by_zero__ {
59 return static_cast<float>(a) / static_cast<float>(b);
60 }
61
62 inline C10_HOST_DEVICE Float8_e5m2 operator-(const Float8_e5m2& a) {
63 return -static_cast<float>(a);
64 }
65
66 inline C10_HOST_DEVICE Float8_e5m2& operator+=(
67 Float8_e5m2& a,
68 const Float8_e5m2& b) {
69 a = a + b;
70 return a;
71 }
72
73 inline C10_HOST_DEVICE Float8_e5m2& operator-=(
74 Float8_e5m2& a,
75 const Float8_e5m2& b) {
76 a = a - b;
77 return a;
78 }
79
80 inline C10_HOST_DEVICE Float8_e5m2& operator*=(
81 Float8_e5m2& a,
82 const Float8_e5m2& b) {
83 a = a * b;
84 return a;
85 }
86
87 inline C10_HOST_DEVICE Float8_e5m2& operator/=(
88 Float8_e5m2& a,
89 const Float8_e5m2& b) {
90 a = a / b;
91 return a;
92 }
93
94 /// Arithmetic with floats
95
96 inline C10_HOST_DEVICE float operator+(Float8_e5m2 a, float b) {
97 return static_cast<float>(a) + b;
98 }
99 inline C10_HOST_DEVICE float operator-(Float8_e5m2 a, float b) {
100 return static_cast<float>(a) - b;
101 }
102 inline C10_HOST_DEVICE float operator*(Float8_e5m2 a, float b) {
103 return static_cast<float>(a) * b;
104 }
105 inline C10_HOST_DEVICE float operator/(Float8_e5m2 a, float b)
106 __ubsan_ignore_float_divide_by_zero__ {
107 return static_cast<float>(a) / b;
108 }
109
110 inline C10_HOST_DEVICE float operator+(float a, Float8_e5m2 b) {
111 return a + static_cast<float>(b);
112 }
113 inline C10_HOST_DEVICE float operator-(float a, Float8_e5m2 b) {
114 return a - static_cast<float>(b);
115 }
116 inline C10_HOST_DEVICE float operator*(float a, Float8_e5m2 b) {
117 return a * static_cast<float>(b);
118 }
119 inline C10_HOST_DEVICE float operator/(float a, Float8_e5m2 b)
120 __ubsan_ignore_float_divide_by_zero__ {
121 return a / static_cast<float>(b);
122 }
123
124 inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e5m2& b) {
125 return a += static_cast<float>(b);
126 }
127 inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e5m2& b) {
128 return a -= static_cast<float>(b);
129 }
130 inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e5m2& b) {
131 return a *= static_cast<float>(b);
132 }
133 inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e5m2& b) {
134 return a /= static_cast<float>(b);
135 }
136
137 /// Arithmetic with doubles
138
139 inline C10_HOST_DEVICE double operator+(Float8_e5m2 a, double b) {
140 return static_cast<double>(a) + b;
141 }
142 inline C10_HOST_DEVICE double operator-(Float8_e5m2 a, double b) {
143 return static_cast<double>(a) - b;
144 }
145 inline C10_HOST_DEVICE double operator*(Float8_e5m2 a, double b) {
146 return static_cast<double>(a) * b;
147 }
148 inline C10_HOST_DEVICE double operator/(Float8_e5m2 a, double b)
149 __ubsan_ignore_float_divide_by_zero__ {
150 return static_cast<double>(a) / b;
151 }
152
153 inline C10_HOST_DEVICE double operator+(double a, Float8_e5m2 b) {
154 return a + static_cast<double>(b);
155 }
156 inline C10_HOST_DEVICE double operator-(double a, Float8_e5m2 b) {
157 return a - static_cast<double>(b);
158 }
159 inline C10_HOST_DEVICE double operator*(double a, Float8_e5m2 b) {
160 return a * static_cast<double>(b);
161 }
162 inline C10_HOST_DEVICE double operator/(double a, Float8_e5m2 b)
163 __ubsan_ignore_float_divide_by_zero__ {
164 return a / static_cast<double>(b);
165 }
166
167 /// Arithmetic with ints
168
169 inline C10_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int b) {
170 return a + static_cast<Float8_e5m2>(b);
171 }
172 inline C10_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int b) {
173 return a - static_cast<Float8_e5m2>(b);
174 }
175 inline C10_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int b) {
176 return a * static_cast<Float8_e5m2>(b);
177 }
178 inline C10_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int b) {
179 return a / static_cast<Float8_e5m2>(b);
180 }
181
182 inline C10_HOST_DEVICE Float8_e5m2 operator+(int a, Float8_e5m2 b) {
183 return static_cast<Float8_e5m2>(a) + b;
184 }
185 inline C10_HOST_DEVICE Float8_e5m2 operator-(int a, Float8_e5m2 b) {
186 return static_cast<Float8_e5m2>(a) - b;
187 }
188 inline C10_HOST_DEVICE Float8_e5m2 operator*(int a, Float8_e5m2 b) {
189 return static_cast<Float8_e5m2>(a) * b;
190 }
191 inline C10_HOST_DEVICE Float8_e5m2 operator/(int a, Float8_e5m2 b) {
192 return static_cast<Float8_e5m2>(a) / b;
193 }
194
/// Arithmetic with int64_t
196
197 inline C10_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int64_t b) {
198 return a + static_cast<Float8_e5m2>(b);
199 }
200 inline C10_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int64_t b) {
201 return a - static_cast<Float8_e5m2>(b);
202 }
203 inline C10_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int64_t b) {
204 return a * static_cast<Float8_e5m2>(b);
205 }
206 inline C10_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int64_t b) {
207 return a / static_cast<Float8_e5m2>(b);
208 }
209
210 inline C10_HOST_DEVICE Float8_e5m2 operator+(int64_t a, Float8_e5m2 b) {
211 return static_cast<Float8_e5m2>(a) + b;
212 }
213 inline C10_HOST_DEVICE Float8_e5m2 operator-(int64_t a, Float8_e5m2 b) {
214 return static_cast<Float8_e5m2>(a) - b;
215 }
216 inline C10_HOST_DEVICE Float8_e5m2 operator*(int64_t a, Float8_e5m2 b) {
217 return static_cast<Float8_e5m2>(a) * b;
218 }
219 inline C10_HOST_DEVICE Float8_e5m2 operator/(int64_t a, Float8_e5m2 b) {
220 return static_cast<Float8_e5m2>(a) / b;
221 }
222
223 /// NOTE: we do not define comparisons directly and instead rely on the implicit
224 /// conversion from c10::Float8_e5m2 to float.
225
226 } // namespace c10
227
228 namespace std {
229
230 template <>
231 class numeric_limits<c10::Float8_e5m2> {
232 public:
233 static constexpr bool is_signed = true;
234 static constexpr bool is_integer = false;
235 static constexpr bool is_specialized = true;
236 static constexpr bool is_exact = false;
237 static constexpr bool has_infinity = true;
238 static constexpr bool has_quiet_NaN = true;
239 static constexpr bool has_signaling_NaN = false;
240 static constexpr auto has_denorm = true;
241 static constexpr auto has_denorm_loss = true;
242 static constexpr auto round_style = numeric_limits<float>::round_style;
243 static constexpr bool is_iec559 = false;
244 static constexpr bool is_bounded = true;
245 static constexpr bool is_modulo = false;
246 static constexpr int digits = 3;
247 static constexpr int digits10 = 0;
248 static constexpr int max_digits10 = 2;
249 static constexpr int radix = 2;
250 static constexpr int min_exponent = -13;
251 static constexpr int min_exponent10 = -4;
252 static constexpr int max_exponent = 16;
253 static constexpr int max_exponent10 = 4;
254 static constexpr auto traps = numeric_limits<float>::traps;
255 static constexpr auto tinyness_before =
256 numeric_limits<float>::tinyness_before;
257
min()258 static constexpr c10::Float8_e5m2 min() {
259 return c10::Float8_e5m2(0x4, c10::Float8_e5m2::from_bits());
260 }
max()261 static constexpr c10::Float8_e5m2 max() {
262 return c10::Float8_e5m2(0x7B, c10::Float8_e5m2::from_bits());
263 }
lowest()264 static constexpr c10::Float8_e5m2 lowest() {
265 return c10::Float8_e5m2(0xFB, c10::Float8_e5m2::from_bits());
266 }
epsilon()267 static constexpr c10::Float8_e5m2 epsilon() {
268 return c10::Float8_e5m2(0x34, c10::Float8_e5m2::from_bits());
269 }
round_error()270 static constexpr c10::Float8_e5m2 round_error() {
271 return c10::Float8_e5m2(0x38, c10::Float8_e5m2::from_bits());
272 }
infinity()273 static constexpr c10::Float8_e5m2 infinity() {
274 return c10::Float8_e5m2(0x7C, c10::Float8_e5m2::from_bits());
275 }
quiet_NaN()276 static constexpr c10::Float8_e5m2 quiet_NaN() {
277 return c10::Float8_e5m2(0x7F, c10::Float8_e5m2::from_bits());
278 }
denorm_min()279 static constexpr c10::Float8_e5m2 denorm_min() {
280 return c10::Float8_e5m2(0x01, c10::Float8_e5m2::from_bits());
281 }
282 };
283
284 } // namespace std
285
286 C10_CLANG_DIAGNOSTIC_POP()
287