1 #pragma once
2
/// Defines the Float8_e4m3fn type (8-bit floating-point) including conversions
/// to standard C types and basic arithmetic operations. Note that arithmetic
/// operations are implemented by converting to floating point and
/// performing the operation in float32.
/// Binary configuration:
/// s eeee mmm
/// 1 sign bit
/// 4 exponent bits
/// 3 mantissa bits
/// bias = 7
///
/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf
/// and inspired by the Half implementation from pytorch/c10/util/Half.h
16
17 #include <c10/macros/Macros.h>
18 #include <c10/util/floating_point_utils.h>
19
20 #if defined(__cplusplus)
21 #include <cmath>
22 #include <cstdint>
23 #elif !defined(__OPENCL_VERSION__)
24 #include <math.h>
25 #include <stdint.h>
26 #endif
27
28 #ifdef _MSC_VER
29 #include <intrin.h>
30 #endif
31
32 #include <climits>
33 #include <iostream>
34
35 namespace c10 {
36
37 namespace detail {
38
39 /*
40 * Convert a 8-bit floating-point number in fp8 E4M3FN format, in bit
41 * representation, to a 32-bit floating-point number in IEEE single-precision
42 * format, in bit representation.
43 *
44 * @note The implementation doesn't use any floating-point operations.
45 */
fp8e4m3fn_to_fp32_value(uint8_t input)46 inline C10_HOST_DEVICE float fp8e4m3fn_to_fp32_value(uint8_t input) {
47 /*
48 * Extend the fp8 E4M3FN number to 32 bits and shift to the
49 * upper part of the 32-bit word:
50 * +---+----+---+-----------------------------+
51 * | S |EEEE|MMM|0000 0000 0000 0000 0000 0000|
52 * +---+----+---+-----------------------------+
53 * Bits 31 27-30 24-26 0-23
54 *
55 * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
56 * - zero bits.
57 */
58 const uint32_t w = (uint32_t)input << 24;
59 /*
60 * Extract the sign of the input number into the high bit of the 32-bit word:
61 *
62 * +---+----------------------------------+
63 * | S |0000000 00000000 00000000 00000000|
64 * +---+----------------------------------+
65 * Bits 31 0-31
66 */
67 const uint32_t sign = w & UINT32_C(0x80000000);
68 /*
69 * Extract mantissa and biased exponent of the input number into the bits 0-30
70 * of the 32-bit word:
71 *
72 * +---+----+---+-----------------------------+
73 * | S |EEEE|MMM|0000 0000 0000 0000 0000 0000|
74 * +---+----+---+-----------------------------+
75 * Bits 31 27-30 24-26 0-23
76 */
77 const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
78 /*
79 * Renorm shift is the number of bits to shift mantissa left to make the
80 * half-precision number normalized. If the initial number is normalized, some
81 * of its high 5 bits (sign == 0 and 4-bit exponent) equals one. In this case
82 * renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note
83 * that if we shift denormalized nonsign by renorm_shift, the unit bit of
84 * mantissa will shift into exponent, turning the biased exponent into 1, and
85 * making mantissa normalized (i.e. without leading 1).
86 */
87 #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
88 uint32_t renorm_shift = __clz(nonsign);
89 #elif defined(__SYCL_DEVICE_ONLY__)
90 // Note: zero is not a supported input into `__builtin_clz`
91 uint32_t renorm_shift =
92 nonsign != 0 ? __builtin_clz(nonsign) : sizeof(uint32_t) * CHAR_BIT;
93 #elif defined(_MSC_VER)
94 unsigned long nonsign_bsr;
95 _BitScanReverse(&nonsign_bsr, (unsigned long)nonsign);
96 uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31;
97 #else
98 // Note: zero is not a supported input into `__builtin_clz`
99 uint32_t renorm_shift =
100 nonsign != 0 ? __builtin_clz(nonsign) : sizeof(uint32_t) * CHAR_BIT;
101 #endif
102 renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
103 /*
104 * Iff fp8e4m3fn number has all exponent and mantissa bits set to 1,
105 * the addition overflows it into bit 31, and the subsequent shift turns the
106 * high 9 bits into 1. Thus inf_nan_mask == 0x7F800000 if the fp8e4m3fn number
107 * is Nan, 0x00000000 otherwise
108 */
109 const int32_t inf_nan_mask =
110 ((int32_t)(nonsign + 0x01000000) >> 8) & INT32_C(0x7F800000);
111 /*
112 * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31
113 * into 1. Otherwise, bit 31 remains 0. The signed shift right by 31
114 * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask ==
115 * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h)
116 * 0x00000000 otherwise
117 */
118 const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;
119 /*
120 * 1. Shift nonsign left by renorm_shift to normalize it (if the input
121 * was denormal)
122 * 2. Shift nonsign right by 4 so the exponent (4 bits originally)
123 * becomes an 8-bit field and 3-bit mantissa shifts into the 3 high
124 * bits of the 23-bit mantissa of IEEE single-precision number.
125 * 3. Add 0x78 to the exponent (starting at bit 23) to compensate the
126 * different in exponent bias (0x7F for single-precision number less 0x07
127 * for fp8e4m3fn number).
128 * 4. Subtract renorm_shift from the exponent (starting at bit 23) to
129 * account for renormalization. As renorm_shift is less than 0x78, this
130 * can be combined with step 3.
131 * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the
132 * input was NaN or infinity.
133 * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent
134 * into zero if the input was zero.
135 * 7. Combine with the sign of the input number.
136 */
137 uint32_t result = sign |
138 ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
139 inf_nan_mask) &
140 ~zero_mask);
141 return fp32_from_bits(result);
142 }
143
144 /*
145 * Convert a 32-bit floating-point number in IEEE single-precision format to a
146 * 8-bit floating-point number in fp8 E4M3FN format, in bit representation.
147 */
fp8e4m3fn_from_fp32_value(float f)148 inline C10_HOST_DEVICE uint8_t fp8e4m3fn_from_fp32_value(float f) {
149 /*
150 * Binary representation of 480.0f, which is the first value
151 * not representable in fp8e4m3fn range:
152 * 0 1111 111 - fp8e4m3fn
153 * 0 10000111 11100000000000000000000 - fp32
154 */
155 constexpr uint32_t fp8_max = UINT32_C(1087) << 20;
156
157 /*
158 * A mask for converting fp32 numbers lower than fp8e4m3fn normal range
159 * into denorm representation
160 * magic number: ((127 - 7) + (23 - 3) + 1)
161 */
162 constexpr uint32_t denorm_mask = UINT32_C(141) << 23;
163
164 uint32_t f_bits = fp32_to_bits(f);
165
166 uint8_t result = 0u;
167
168 /*
169 * Extract the sign of the input number into the high bit of the 32-bit word:
170 *
171 * +---+----------------------------------+
172 * | S |0000000 00000000 00000000 00000000|
173 * +---+----------------------------------+
174 * Bits 31 0-31
175 */
176 const uint32_t sign = f_bits & UINT32_C(0x80000000);
177
178 /*
179 * Set sign bit to 0
180 */
181 f_bits ^= sign;
182
183 if (f_bits >= fp8_max) {
184 // NaN - all exponent and mantissa bits set to 1
185 result = 0x7f;
186 } else {
187 if (f_bits < (UINT32_C(121) << 23)) {
188 // Input number is smaller than 2^(-6), which is the smallest
189 // fp8e4m3fn normal number
190 f_bits =
191 fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
192 result = static_cast<uint8_t>(f_bits - denorm_mask);
193 } else {
194 // resulting mantissa is odd
195 uint8_t mant_odd = (f_bits >> 20) & 1;
196
197 // update exponent, rounding bias part 1
198 f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
199
200 // rounding bias part 2
201 f_bits += mant_odd;
202
203 // take the bits!
204 result = static_cast<uint8_t>(f_bits >> 20);
205 }
206 }
207
208 result |= static_cast<uint8_t>(sign >> 24);
209 return result;
210 }
211
212 } // namespace detail
213
214 struct alignas(1) Float8_e4m3fn {
215 uint8_t x;
216
217 struct from_bits_t {};
from_bitsFloat8_e4m3fn218 C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
219 return from_bits_t();
220 }
221
222 Float8_e4m3fn() = default;
223
Float8_e4m3fnFloat8_e4m3fn224 constexpr C10_HOST_DEVICE Float8_e4m3fn(uint8_t bits, from_bits_t)
225 : x(bits) {}
226 inline C10_HOST_DEVICE Float8_e4m3fn(float value);
227 inline C10_HOST_DEVICE operator float() const;
228 inline C10_HOST_DEVICE bool isnan() const;
229 };
230
231 C10_API inline std::ostream& operator<<(
232 std::ostream& out,
233 const Float8_e4m3fn& value) {
234 out << (float)value;
235 return out;
236 }
237
238 } // namespace c10
239
240 #include <c10/util/Float8_e4m3fn-inl.h> // IWYU pragma: keep
241