xref: /aosp_15_r20/external/pytorch/c10/util/Float8_e4m3fn-inl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/macros/Macros.h>
4 #include <cstdint>
5 #include <limits>
6 
7 C10_CLANG_DIAGNOSTIC_PUSH()
8 #if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
9 C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
10 #endif
11 
12 namespace c10 {
13 
14 /// Constructors
15 
Float8_e4m3fn(float value)16 inline C10_HOST_DEVICE Float8_e4m3fn::Float8_e4m3fn(float value)
17     : x(detail::fp8e4m3fn_from_fp32_value(value)) {}
18 
19 /// Implicit conversions
20 
21 inline C10_HOST_DEVICE Float8_e4m3fn::operator float() const {
22   return detail::fp8e4m3fn_to_fp32_value(x);
23 }
24 
25 /// Special values helper
26 
isnan()27 inline C10_HOST_DEVICE bool Float8_e4m3fn::isnan() const {
28   return (x & 0b01111111) == 0b01111111;
29 }
30 
31 /// Arithmetic
32 
33 inline C10_HOST_DEVICE Float8_e4m3fn
34 operator+(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
35   return static_cast<float>(a) + static_cast<float>(b);
36 }
37 
38 inline C10_HOST_DEVICE Float8_e4m3fn
39 operator-(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
40   return static_cast<float>(a) - static_cast<float>(b);
41 }
42 
43 inline C10_HOST_DEVICE Float8_e4m3fn
44 operator*(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
45   return static_cast<float>(a) * static_cast<float>(b);
46 }
47 
48 inline C10_HOST_DEVICE Float8_e4m3fn operator/(
49     const Float8_e4m3fn& a,
50     const Float8_e4m3fn& b) __ubsan_ignore_float_divide_by_zero__ {
51   return static_cast<float>(a) / static_cast<float>(b);
52 }
53 
54 inline C10_HOST_DEVICE Float8_e4m3fn operator-(const Float8_e4m3fn& a) {
55   return -static_cast<float>(a);
56 }
57 
58 inline C10_HOST_DEVICE Float8_e4m3fn& operator+=(
59     Float8_e4m3fn& a,
60     const Float8_e4m3fn& b) {
61   a = a + b;
62   return a;
63 }
64 
65 inline C10_HOST_DEVICE Float8_e4m3fn& operator-=(
66     Float8_e4m3fn& a,
67     const Float8_e4m3fn& b) {
68   a = a - b;
69   return a;
70 }
71 
72 inline C10_HOST_DEVICE Float8_e4m3fn& operator*=(
73     Float8_e4m3fn& a,
74     const Float8_e4m3fn& b) {
75   a = a * b;
76   return a;
77 }
78 
79 inline C10_HOST_DEVICE Float8_e4m3fn& operator/=(
80     Float8_e4m3fn& a,
81     const Float8_e4m3fn& b) {
82   a = a / b;
83   return a;
84 }
85 
86 /// Arithmetic with floats
87 
88 inline C10_HOST_DEVICE float operator+(Float8_e4m3fn a, float b) {
89   return static_cast<float>(a) + b;
90 }
91 inline C10_HOST_DEVICE float operator-(Float8_e4m3fn a, float b) {
92   return static_cast<float>(a) - b;
93 }
94 inline C10_HOST_DEVICE float operator*(Float8_e4m3fn a, float b) {
95   return static_cast<float>(a) * b;
96 }
97 inline C10_HOST_DEVICE float operator/(Float8_e4m3fn a, float b)
98     __ubsan_ignore_float_divide_by_zero__ {
99   return static_cast<float>(a) / b;
100 }
101 
102 inline C10_HOST_DEVICE float operator+(float a, Float8_e4m3fn b) {
103   return a + static_cast<float>(b);
104 }
105 inline C10_HOST_DEVICE float operator-(float a, Float8_e4m3fn b) {
106   return a - static_cast<float>(b);
107 }
108 inline C10_HOST_DEVICE float operator*(float a, Float8_e4m3fn b) {
109   return a * static_cast<float>(b);
110 }
111 inline C10_HOST_DEVICE float operator/(float a, Float8_e4m3fn b)
112     __ubsan_ignore_float_divide_by_zero__ {
113   return a / static_cast<float>(b);
114 }
115 
116 inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e4m3fn& b) {
117   return a += static_cast<float>(b);
118 }
119 inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e4m3fn& b) {
120   return a -= static_cast<float>(b);
121 }
122 inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e4m3fn& b) {
123   return a *= static_cast<float>(b);
124 }
125 inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e4m3fn& b) {
126   return a /= static_cast<float>(b);
127 }
128 
129 /// Arithmetic with doubles
130 
131 inline C10_HOST_DEVICE double operator+(Float8_e4m3fn a, double b) {
132   return static_cast<double>(a) + b;
133 }
134 inline C10_HOST_DEVICE double operator-(Float8_e4m3fn a, double b) {
135   return static_cast<double>(a) - b;
136 }
137 inline C10_HOST_DEVICE double operator*(Float8_e4m3fn a, double b) {
138   return static_cast<double>(a) * b;
139 }
140 inline C10_HOST_DEVICE double operator/(Float8_e4m3fn a, double b)
141     __ubsan_ignore_float_divide_by_zero__ {
142   return static_cast<double>(a) / b;
143 }
144 
145 inline C10_HOST_DEVICE double operator+(double a, Float8_e4m3fn b) {
146   return a + static_cast<double>(b);
147 }
148 inline C10_HOST_DEVICE double operator-(double a, Float8_e4m3fn b) {
149   return a - static_cast<double>(b);
150 }
151 inline C10_HOST_DEVICE double operator*(double a, Float8_e4m3fn b) {
152   return a * static_cast<double>(b);
153 }
154 inline C10_HOST_DEVICE double operator/(double a, Float8_e4m3fn b)
155     __ubsan_ignore_float_divide_by_zero__ {
156   return a / static_cast<double>(b);
157 }
158 
159 /// Arithmetic with ints
160 
161 inline C10_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int b) {
162   return a + static_cast<Float8_e4m3fn>(b);
163 }
164 inline C10_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int b) {
165   return a - static_cast<Float8_e4m3fn>(b);
166 }
167 inline C10_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int b) {
168   return a * static_cast<Float8_e4m3fn>(b);
169 }
170 inline C10_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int b) {
171   return a / static_cast<Float8_e4m3fn>(b);
172 }
173 
174 inline C10_HOST_DEVICE Float8_e4m3fn operator+(int a, Float8_e4m3fn b) {
175   return static_cast<Float8_e4m3fn>(a) + b;
176 }
177 inline C10_HOST_DEVICE Float8_e4m3fn operator-(int a, Float8_e4m3fn b) {
178   return static_cast<Float8_e4m3fn>(a) - b;
179 }
180 inline C10_HOST_DEVICE Float8_e4m3fn operator*(int a, Float8_e4m3fn b) {
181   return static_cast<Float8_e4m3fn>(a) * b;
182 }
183 inline C10_HOST_DEVICE Float8_e4m3fn operator/(int a, Float8_e4m3fn b) {
184   return static_cast<Float8_e4m3fn>(a) / b;
185 }
186 
/// Arithmetic with int64_t
188 
189 inline C10_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int64_t b) {
190   return a + static_cast<Float8_e4m3fn>(b);
191 }
192 inline C10_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int64_t b) {
193   return a - static_cast<Float8_e4m3fn>(b);
194 }
195 inline C10_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int64_t b) {
196   return a * static_cast<Float8_e4m3fn>(b);
197 }
198 inline C10_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int64_t b) {
199   return a / static_cast<Float8_e4m3fn>(b);
200 }
201 
202 inline C10_HOST_DEVICE Float8_e4m3fn operator+(int64_t a, Float8_e4m3fn b) {
203   return static_cast<Float8_e4m3fn>(a) + b;
204 }
205 inline C10_HOST_DEVICE Float8_e4m3fn operator-(int64_t a, Float8_e4m3fn b) {
206   return static_cast<Float8_e4m3fn>(a) - b;
207 }
208 inline C10_HOST_DEVICE Float8_e4m3fn operator*(int64_t a, Float8_e4m3fn b) {
209   return static_cast<Float8_e4m3fn>(a) * b;
210 }
211 inline C10_HOST_DEVICE Float8_e4m3fn operator/(int64_t a, Float8_e4m3fn b) {
212   return static_cast<Float8_e4m3fn>(a) / b;
213 }
214 
215 /// NOTE: we do not define comparisons directly and instead rely on the implicit
216 /// conversion from c10::Float8_e4m3fn to float.
217 
218 } // namespace c10
219 
220 namespace std {
221 
222 template <>
223 class numeric_limits<c10::Float8_e4m3fn> {
224  public:
225   static constexpr bool is_specialized = true;
226   static constexpr bool is_signed = true;
227   static constexpr bool is_integer = false;
228   static constexpr bool is_exact = false;
229   static constexpr bool has_infinity = false;
230   static constexpr bool has_quiet_NaN = true;
231   static constexpr bool has_signaling_NaN = false;
232   static constexpr auto has_denorm = true;
233   static constexpr auto has_denorm_loss = true;
234   static constexpr auto round_style = numeric_limits<float>::round_style;
235   static constexpr bool is_iec559 = false;
236   static constexpr bool is_bounded = true;
237   static constexpr bool is_modulo = false;
238   static constexpr int digits = 4;
239   static constexpr int digits10 = 0;
240   static constexpr int max_digits10 = 3;
241   static constexpr int radix = 2;
242   static constexpr int min_exponent = -5;
243   static constexpr int min_exponent10 = -1;
244   static constexpr int max_exponent = 8;
245   static constexpr int max_exponent10 = 2;
246   static constexpr auto traps = numeric_limits<float>::traps;
247   static constexpr auto tinyness_before = false;
248 
min()249   static constexpr c10::Float8_e4m3fn min() {
250     return c10::Float8_e4m3fn(0x08, c10::Float8_e4m3fn::from_bits());
251   }
lowest()252   static constexpr c10::Float8_e4m3fn lowest() {
253     return c10::Float8_e4m3fn(0xFE, c10::Float8_e4m3fn::from_bits());
254   }
max()255   static constexpr c10::Float8_e4m3fn max() {
256     return c10::Float8_e4m3fn(0x7E, c10::Float8_e4m3fn::from_bits());
257   }
epsilon()258   static constexpr c10::Float8_e4m3fn epsilon() {
259     return c10::Float8_e4m3fn(0x20, c10::Float8_e4m3fn::from_bits());
260   }
round_error()261   static constexpr c10::Float8_e4m3fn round_error() {
262     return c10::Float8_e4m3fn(0x30, c10::Float8_e4m3fn::from_bits());
263   }
quiet_NaN()264   static constexpr c10::Float8_e4m3fn quiet_NaN() {
265     return c10::Float8_e4m3fn(0x7F, c10::Float8_e4m3fn::from_bits());
266   }
denorm_min()267   static constexpr c10::Float8_e4m3fn denorm_min() {
268     return c10::Float8_e4m3fn(0x01, c10::Float8_e4m3fn::from_bits());
269   }
270 };
271 
272 } // namespace std
273 
274 C10_CLANG_DIAGNOSTIC_POP()
275