1 #pragma once
2
3 #include <c10/macros/Macros.h>
4 #include <cstdint>
5 #include <limits>
6
7 C10_CLANG_DIAGNOSTIC_PUSH()
8 #if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
9 C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
10 #endif
11
12 namespace c10 {
13
14 /// Constructors
15
/// Builds a Float8_e4m3fn from a 32-bit float. The stored bit pattern `x` is
/// whatever detail::fp8e4m3fn_from_fp32_value returns for `value`.
inline C10_HOST_DEVICE Float8_e4m3fn::Float8_e4m3fn(float value)
    : x(detail::fp8e4m3fn_from_fp32_value(value)) {}
18
19 /// Implicit conversions
20
21 inline C10_HOST_DEVICE Float8_e4m3fn::operator float() const {
22 return detail::fp8e4m3fn_to_fp32_value(x);
23 }
24
25 /// Special values helper
26
isnan()27 inline C10_HOST_DEVICE bool Float8_e4m3fn::isnan() const {
28 return (x & 0b01111111) == 0b01111111;
29 }
30
31 /// Arithmetic
32
33 inline C10_HOST_DEVICE Float8_e4m3fn
34 operator+(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
35 return static_cast<float>(a) + static_cast<float>(b);
36 }
37
38 inline C10_HOST_DEVICE Float8_e4m3fn
39 operator-(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
40 return static_cast<float>(a) - static_cast<float>(b);
41 }
42
43 inline C10_HOST_DEVICE Float8_e4m3fn
44 operator*(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
45 return static_cast<float>(a) * static_cast<float>(b);
46 }
47
48 inline C10_HOST_DEVICE Float8_e4m3fn operator/(
49 const Float8_e4m3fn& a,
50 const Float8_e4m3fn& b) __ubsan_ignore_float_divide_by_zero__ {
51 return static_cast<float>(a) / static_cast<float>(b);
52 }
53
54 inline C10_HOST_DEVICE Float8_e4m3fn operator-(const Float8_e4m3fn& a) {
55 return -static_cast<float>(a);
56 }
57
58 inline C10_HOST_DEVICE Float8_e4m3fn& operator+=(
59 Float8_e4m3fn& a,
60 const Float8_e4m3fn& b) {
61 a = a + b;
62 return a;
63 }
64
65 inline C10_HOST_DEVICE Float8_e4m3fn& operator-=(
66 Float8_e4m3fn& a,
67 const Float8_e4m3fn& b) {
68 a = a - b;
69 return a;
70 }
71
72 inline C10_HOST_DEVICE Float8_e4m3fn& operator*=(
73 Float8_e4m3fn& a,
74 const Float8_e4m3fn& b) {
75 a = a * b;
76 return a;
77 }
78
79 inline C10_HOST_DEVICE Float8_e4m3fn& operator/=(
80 Float8_e4m3fn& a,
81 const Float8_e4m3fn& b) {
82 a = a / b;
83 return a;
84 }
85
86 /// Arithmetic with floats
87
88 inline C10_HOST_DEVICE float operator+(Float8_e4m3fn a, float b) {
89 return static_cast<float>(a) + b;
90 }
91 inline C10_HOST_DEVICE float operator-(Float8_e4m3fn a, float b) {
92 return static_cast<float>(a) - b;
93 }
94 inline C10_HOST_DEVICE float operator*(Float8_e4m3fn a, float b) {
95 return static_cast<float>(a) * b;
96 }
97 inline C10_HOST_DEVICE float operator/(Float8_e4m3fn a, float b)
98 __ubsan_ignore_float_divide_by_zero__ {
99 return static_cast<float>(a) / b;
100 }
101
102 inline C10_HOST_DEVICE float operator+(float a, Float8_e4m3fn b) {
103 return a + static_cast<float>(b);
104 }
105 inline C10_HOST_DEVICE float operator-(float a, Float8_e4m3fn b) {
106 return a - static_cast<float>(b);
107 }
108 inline C10_HOST_DEVICE float operator*(float a, Float8_e4m3fn b) {
109 return a * static_cast<float>(b);
110 }
111 inline C10_HOST_DEVICE float operator/(float a, Float8_e4m3fn b)
112 __ubsan_ignore_float_divide_by_zero__ {
113 return a / static_cast<float>(b);
114 }
115
116 inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e4m3fn& b) {
117 return a += static_cast<float>(b);
118 }
119 inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e4m3fn& b) {
120 return a -= static_cast<float>(b);
121 }
122 inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e4m3fn& b) {
123 return a *= static_cast<float>(b);
124 }
125 inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e4m3fn& b) {
126 return a /= static_cast<float>(b);
127 }
128
129 /// Arithmetic with doubles
130
131 inline C10_HOST_DEVICE double operator+(Float8_e4m3fn a, double b) {
132 return static_cast<double>(a) + b;
133 }
134 inline C10_HOST_DEVICE double operator-(Float8_e4m3fn a, double b) {
135 return static_cast<double>(a) - b;
136 }
137 inline C10_HOST_DEVICE double operator*(Float8_e4m3fn a, double b) {
138 return static_cast<double>(a) * b;
139 }
140 inline C10_HOST_DEVICE double operator/(Float8_e4m3fn a, double b)
141 __ubsan_ignore_float_divide_by_zero__ {
142 return static_cast<double>(a) / b;
143 }
144
145 inline C10_HOST_DEVICE double operator+(double a, Float8_e4m3fn b) {
146 return a + static_cast<double>(b);
147 }
148 inline C10_HOST_DEVICE double operator-(double a, Float8_e4m3fn b) {
149 return a - static_cast<double>(b);
150 }
151 inline C10_HOST_DEVICE double operator*(double a, Float8_e4m3fn b) {
152 return a * static_cast<double>(b);
153 }
154 inline C10_HOST_DEVICE double operator/(double a, Float8_e4m3fn b)
155 __ubsan_ignore_float_divide_by_zero__ {
156 return a / static_cast<double>(b);
157 }
158
159 /// Arithmetic with ints
160
161 inline C10_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int b) {
162 return a + static_cast<Float8_e4m3fn>(b);
163 }
164 inline C10_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int b) {
165 return a - static_cast<Float8_e4m3fn>(b);
166 }
167 inline C10_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int b) {
168 return a * static_cast<Float8_e4m3fn>(b);
169 }
170 inline C10_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int b) {
171 return a / static_cast<Float8_e4m3fn>(b);
172 }
173
174 inline C10_HOST_DEVICE Float8_e4m3fn operator+(int a, Float8_e4m3fn b) {
175 return static_cast<Float8_e4m3fn>(a) + b;
176 }
177 inline C10_HOST_DEVICE Float8_e4m3fn operator-(int a, Float8_e4m3fn b) {
178 return static_cast<Float8_e4m3fn>(a) - b;
179 }
180 inline C10_HOST_DEVICE Float8_e4m3fn operator*(int a, Float8_e4m3fn b) {
181 return static_cast<Float8_e4m3fn>(a) * b;
182 }
183 inline C10_HOST_DEVICE Float8_e4m3fn operator/(int a, Float8_e4m3fn b) {
184 return static_cast<Float8_e4m3fn>(a) / b;
185 }
186
/// Arithmetic with int64_t
188
189 inline C10_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int64_t b) {
190 return a + static_cast<Float8_e4m3fn>(b);
191 }
192 inline C10_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int64_t b) {
193 return a - static_cast<Float8_e4m3fn>(b);
194 }
195 inline C10_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int64_t b) {
196 return a * static_cast<Float8_e4m3fn>(b);
197 }
198 inline C10_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int64_t b) {
199 return a / static_cast<Float8_e4m3fn>(b);
200 }
201
202 inline C10_HOST_DEVICE Float8_e4m3fn operator+(int64_t a, Float8_e4m3fn b) {
203 return static_cast<Float8_e4m3fn>(a) + b;
204 }
205 inline C10_HOST_DEVICE Float8_e4m3fn operator-(int64_t a, Float8_e4m3fn b) {
206 return static_cast<Float8_e4m3fn>(a) - b;
207 }
208 inline C10_HOST_DEVICE Float8_e4m3fn operator*(int64_t a, Float8_e4m3fn b) {
209 return static_cast<Float8_e4m3fn>(a) * b;
210 }
211 inline C10_HOST_DEVICE Float8_e4m3fn operator/(int64_t a, Float8_e4m3fn b) {
212 return static_cast<Float8_e4m3fn>(a) / b;
213 }
214
215 /// NOTE: we do not define comparisons directly and instead rely on the implicit
216 /// conversion from c10::Float8_e4m3fn to float.
217
218 } // namespace c10
219
220 namespace std {
221
222 template <>
223 class numeric_limits<c10::Float8_e4m3fn> {
224 public:
225 static constexpr bool is_specialized = true;
226 static constexpr bool is_signed = true;
227 static constexpr bool is_integer = false;
228 static constexpr bool is_exact = false;
229 static constexpr bool has_infinity = false;
230 static constexpr bool has_quiet_NaN = true;
231 static constexpr bool has_signaling_NaN = false;
232 static constexpr auto has_denorm = true;
233 static constexpr auto has_denorm_loss = true;
234 static constexpr auto round_style = numeric_limits<float>::round_style;
235 static constexpr bool is_iec559 = false;
236 static constexpr bool is_bounded = true;
237 static constexpr bool is_modulo = false;
238 static constexpr int digits = 4;
239 static constexpr int digits10 = 0;
240 static constexpr int max_digits10 = 3;
241 static constexpr int radix = 2;
242 static constexpr int min_exponent = -5;
243 static constexpr int min_exponent10 = -1;
244 static constexpr int max_exponent = 8;
245 static constexpr int max_exponent10 = 2;
246 static constexpr auto traps = numeric_limits<float>::traps;
247 static constexpr auto tinyness_before = false;
248
min()249 static constexpr c10::Float8_e4m3fn min() {
250 return c10::Float8_e4m3fn(0x08, c10::Float8_e4m3fn::from_bits());
251 }
lowest()252 static constexpr c10::Float8_e4m3fn lowest() {
253 return c10::Float8_e4m3fn(0xFE, c10::Float8_e4m3fn::from_bits());
254 }
max()255 static constexpr c10::Float8_e4m3fn max() {
256 return c10::Float8_e4m3fn(0x7E, c10::Float8_e4m3fn::from_bits());
257 }
epsilon()258 static constexpr c10::Float8_e4m3fn epsilon() {
259 return c10::Float8_e4m3fn(0x20, c10::Float8_e4m3fn::from_bits());
260 }
round_error()261 static constexpr c10::Float8_e4m3fn round_error() {
262 return c10::Float8_e4m3fn(0x30, c10::Float8_e4m3fn::from_bits());
263 }
quiet_NaN()264 static constexpr c10::Float8_e4m3fn quiet_NaN() {
265 return c10::Float8_e4m3fn(0x7F, c10::Float8_e4m3fn::from_bits());
266 }
denorm_min()267 static constexpr c10::Float8_e4m3fn denorm_min() {
268 return c10::Float8_e4m3fn(0x01, c10::Float8_e4m3fn::from_bits());
269 }
270 };
271
272 } // namespace std
273
274 C10_CLANG_DIAGNOSTIC_POP()
275