xref: /aosp_15_r20/external/pytorch/c10/util/Float8_e5m2-inl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/macros/Macros.h>
4 #include <cstring>
5 #include <limits>
6 
7 C10_CLANG_DIAGNOSTIC_PUSH()
8 #if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
9 C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
10 #endif
11 
// Bit-layout parameters of the 8-bit e5m2 floating-point format.
// NOTE(review): none of these macros is referenced in the visible part of this
// file; presumably they are consumed by code that includes it -- confirm before
// renaming or removing them.
// Number of exponent bits.
#define EXP_WIDTH_FP8 5
// Number of mantissa (fraction) bits.
#define MAN_WIDTH_FP8 2
// Exponent bias.
#define EXP_BIAS_FP8 15
15 
16 namespace c10 {
17 
18 /// Constructors
19 
Float8_e5m2(float value)20 inline C10_HOST_DEVICE Float8_e5m2::Float8_e5m2(float value)
21     : x(detail::fp8e5m2_from_fp32_value(value)) {}
22 
23 /// Implicit conversions
24 
25 inline C10_HOST_DEVICE Float8_e5m2::operator float() const {
26   return detail::fp8e5m2_to_fp32_value(x);
27 }
28 
29 /// Special values helpers
30 
isnan()31 inline C10_HOST_DEVICE bool Float8_e5m2::isnan() const {
32   return (x & 0b01111111) > 0b01111100;
33 }
34 
isinf()35 inline C10_HOST_DEVICE bool Float8_e5m2::isinf() const {
36   return (x & 0b01111111) == 0b01111100;
37 }
38 
39 /// Arithmetic
40 
41 inline C10_HOST_DEVICE Float8_e5m2
42 operator+(const Float8_e5m2& a, const Float8_e5m2& b) {
43   return static_cast<float>(a) + static_cast<float>(b);
44 }
45 
46 inline C10_HOST_DEVICE Float8_e5m2
47 operator-(const Float8_e5m2& a, const Float8_e5m2& b) {
48   return static_cast<float>(a) - static_cast<float>(b);
49 }
50 
51 inline C10_HOST_DEVICE Float8_e5m2
52 operator*(const Float8_e5m2& a, const Float8_e5m2& b) {
53   return static_cast<float>(a) * static_cast<float>(b);
54 }
55 
56 inline C10_HOST_DEVICE Float8_e5m2 operator/(
57     const Float8_e5m2& a,
58     const Float8_e5m2& b) __ubsan_ignore_float_divide_by_zero__ {
59   return static_cast<float>(a) / static_cast<float>(b);
60 }
61 
62 inline C10_HOST_DEVICE Float8_e5m2 operator-(const Float8_e5m2& a) {
63   return -static_cast<float>(a);
64 }
65 
66 inline C10_HOST_DEVICE Float8_e5m2& operator+=(
67     Float8_e5m2& a,
68     const Float8_e5m2& b) {
69   a = a + b;
70   return a;
71 }
72 
73 inline C10_HOST_DEVICE Float8_e5m2& operator-=(
74     Float8_e5m2& a,
75     const Float8_e5m2& b) {
76   a = a - b;
77   return a;
78 }
79 
80 inline C10_HOST_DEVICE Float8_e5m2& operator*=(
81     Float8_e5m2& a,
82     const Float8_e5m2& b) {
83   a = a * b;
84   return a;
85 }
86 
87 inline C10_HOST_DEVICE Float8_e5m2& operator/=(
88     Float8_e5m2& a,
89     const Float8_e5m2& b) {
90   a = a / b;
91   return a;
92 }
93 
94 /// Arithmetic with floats
95 
96 inline C10_HOST_DEVICE float operator+(Float8_e5m2 a, float b) {
97   return static_cast<float>(a) + b;
98 }
99 inline C10_HOST_DEVICE float operator-(Float8_e5m2 a, float b) {
100   return static_cast<float>(a) - b;
101 }
102 inline C10_HOST_DEVICE float operator*(Float8_e5m2 a, float b) {
103   return static_cast<float>(a) * b;
104 }
105 inline C10_HOST_DEVICE float operator/(Float8_e5m2 a, float b)
106     __ubsan_ignore_float_divide_by_zero__ {
107   return static_cast<float>(a) / b;
108 }
109 
110 inline C10_HOST_DEVICE float operator+(float a, Float8_e5m2 b) {
111   return a + static_cast<float>(b);
112 }
113 inline C10_HOST_DEVICE float operator-(float a, Float8_e5m2 b) {
114   return a - static_cast<float>(b);
115 }
116 inline C10_HOST_DEVICE float operator*(float a, Float8_e5m2 b) {
117   return a * static_cast<float>(b);
118 }
119 inline C10_HOST_DEVICE float operator/(float a, Float8_e5m2 b)
120     __ubsan_ignore_float_divide_by_zero__ {
121   return a / static_cast<float>(b);
122 }
123 
124 inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e5m2& b) {
125   return a += static_cast<float>(b);
126 }
127 inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e5m2& b) {
128   return a -= static_cast<float>(b);
129 }
130 inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e5m2& b) {
131   return a *= static_cast<float>(b);
132 }
133 inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e5m2& b) {
134   return a /= static_cast<float>(b);
135 }
136 
137 /// Arithmetic with doubles
138 
139 inline C10_HOST_DEVICE double operator+(Float8_e5m2 a, double b) {
140   return static_cast<double>(a) + b;
141 }
142 inline C10_HOST_DEVICE double operator-(Float8_e5m2 a, double b) {
143   return static_cast<double>(a) - b;
144 }
145 inline C10_HOST_DEVICE double operator*(Float8_e5m2 a, double b) {
146   return static_cast<double>(a) * b;
147 }
148 inline C10_HOST_DEVICE double operator/(Float8_e5m2 a, double b)
149     __ubsan_ignore_float_divide_by_zero__ {
150   return static_cast<double>(a) / b;
151 }
152 
153 inline C10_HOST_DEVICE double operator+(double a, Float8_e5m2 b) {
154   return a + static_cast<double>(b);
155 }
156 inline C10_HOST_DEVICE double operator-(double a, Float8_e5m2 b) {
157   return a - static_cast<double>(b);
158 }
159 inline C10_HOST_DEVICE double operator*(double a, Float8_e5m2 b) {
160   return a * static_cast<double>(b);
161 }
162 inline C10_HOST_DEVICE double operator/(double a, Float8_e5m2 b)
163     __ubsan_ignore_float_divide_by_zero__ {
164   return a / static_cast<double>(b);
165 }
166 
167 /// Arithmetic with ints
168 
169 inline C10_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int b) {
170   return a + static_cast<Float8_e5m2>(b);
171 }
172 inline C10_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int b) {
173   return a - static_cast<Float8_e5m2>(b);
174 }
175 inline C10_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int b) {
176   return a * static_cast<Float8_e5m2>(b);
177 }
178 inline C10_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int b) {
179   return a / static_cast<Float8_e5m2>(b);
180 }
181 
182 inline C10_HOST_DEVICE Float8_e5m2 operator+(int a, Float8_e5m2 b) {
183   return static_cast<Float8_e5m2>(a) + b;
184 }
185 inline C10_HOST_DEVICE Float8_e5m2 operator-(int a, Float8_e5m2 b) {
186   return static_cast<Float8_e5m2>(a) - b;
187 }
188 inline C10_HOST_DEVICE Float8_e5m2 operator*(int a, Float8_e5m2 b) {
189   return static_cast<Float8_e5m2>(a) * b;
190 }
191 inline C10_HOST_DEVICE Float8_e5m2 operator/(int a, Float8_e5m2 b) {
192   return static_cast<Float8_e5m2>(a) / b;
193 }
194 
/// Arithmetic with int64_t
196 
197 inline C10_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int64_t b) {
198   return a + static_cast<Float8_e5m2>(b);
199 }
200 inline C10_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int64_t b) {
201   return a - static_cast<Float8_e5m2>(b);
202 }
203 inline C10_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int64_t b) {
204   return a * static_cast<Float8_e5m2>(b);
205 }
206 inline C10_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int64_t b) {
207   return a / static_cast<Float8_e5m2>(b);
208 }
209 
210 inline C10_HOST_DEVICE Float8_e5m2 operator+(int64_t a, Float8_e5m2 b) {
211   return static_cast<Float8_e5m2>(a) + b;
212 }
213 inline C10_HOST_DEVICE Float8_e5m2 operator-(int64_t a, Float8_e5m2 b) {
214   return static_cast<Float8_e5m2>(a) - b;
215 }
216 inline C10_HOST_DEVICE Float8_e5m2 operator*(int64_t a, Float8_e5m2 b) {
217   return static_cast<Float8_e5m2>(a) * b;
218 }
219 inline C10_HOST_DEVICE Float8_e5m2 operator/(int64_t a, Float8_e5m2 b) {
220   return static_cast<Float8_e5m2>(a) / b;
221 }
222 
223 /// NOTE: we do not define comparisons directly and instead rely on the implicit
224 /// conversion from c10::Float8_e5m2 to float.
225 
226 } // namespace c10
227 
228 namespace std {
229 
230 template <>
231 class numeric_limits<c10::Float8_e5m2> {
232  public:
233   static constexpr bool is_signed = true;
234   static constexpr bool is_integer = false;
235   static constexpr bool is_specialized = true;
236   static constexpr bool is_exact = false;
237   static constexpr bool has_infinity = true;
238   static constexpr bool has_quiet_NaN = true;
239   static constexpr bool has_signaling_NaN = false;
240   static constexpr auto has_denorm = true;
241   static constexpr auto has_denorm_loss = true;
242   static constexpr auto round_style = numeric_limits<float>::round_style;
243   static constexpr bool is_iec559 = false;
244   static constexpr bool is_bounded = true;
245   static constexpr bool is_modulo = false;
246   static constexpr int digits = 3;
247   static constexpr int digits10 = 0;
248   static constexpr int max_digits10 = 2;
249   static constexpr int radix = 2;
250   static constexpr int min_exponent = -13;
251   static constexpr int min_exponent10 = -4;
252   static constexpr int max_exponent = 16;
253   static constexpr int max_exponent10 = 4;
254   static constexpr auto traps = numeric_limits<float>::traps;
255   static constexpr auto tinyness_before =
256       numeric_limits<float>::tinyness_before;
257 
min()258   static constexpr c10::Float8_e5m2 min() {
259     return c10::Float8_e5m2(0x4, c10::Float8_e5m2::from_bits());
260   }
max()261   static constexpr c10::Float8_e5m2 max() {
262     return c10::Float8_e5m2(0x7B, c10::Float8_e5m2::from_bits());
263   }
lowest()264   static constexpr c10::Float8_e5m2 lowest() {
265     return c10::Float8_e5m2(0xFB, c10::Float8_e5m2::from_bits());
266   }
epsilon()267   static constexpr c10::Float8_e5m2 epsilon() {
268     return c10::Float8_e5m2(0x34, c10::Float8_e5m2::from_bits());
269   }
round_error()270   static constexpr c10::Float8_e5m2 round_error() {
271     return c10::Float8_e5m2(0x38, c10::Float8_e5m2::from_bits());
272   }
infinity()273   static constexpr c10::Float8_e5m2 infinity() {
274     return c10::Float8_e5m2(0x7C, c10::Float8_e5m2::from_bits());
275   }
quiet_NaN()276   static constexpr c10::Float8_e5m2 quiet_NaN() {
277     return c10::Float8_e5m2(0x7F, c10::Float8_e5m2::from_bits());
278   }
denorm_min()279   static constexpr c10::Float8_e5m2 denorm_min() {
280     return c10::Float8_e5m2(0x01, c10::Float8_e5m2::from_bits());
281   }
282 };
283 
284 } // namespace std
285 
286 C10_CLANG_DIAGNOSTIC_POP()
287