#include <type_traits>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>

#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/int_mm_kernel.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>
#include <c10/util/Unroll.h>

#if (defined(_WIN32) || defined(_WIN64))
#define RESTRICT __restrict
#else
#define RESTRICT __restrict__
#endif

C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-variable")
namespace at::native {

namespace {

inline bool is_block_start(int index, int BLOCK_SIZE) {
  return !(index & (BLOCK_SIZE - 1));
}
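// is_block_start(index, BLOCK_SIZE) is true when `index` is a multiple of
// BLOCK_SIZE (BLOCK_SIZE is assumed to be a power of two), e.g.
// is_block_start(64, 32) -> 64 & 31 == 0 -> true, is_block_start(65, 32) -> false.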

#if (defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER)
// convert 16x int4 to int8, handle 64 bits at a time
// used in avx2 and avx512
inline __m128i conver_int4_to_int8(const uint8_t* data) {
  __m128i tmp = _mm_loadu_si64((const __m128i*)data);
  __m128i bytes = _mm_cvtepu8_epi16(tmp);
  const __m128i lowMask = _mm_set1_epi8(0xF);
  __m128i high = _mm_andnot_si128(lowMask, bytes);
  __m128i low = _mm_and_si128(lowMask, bytes);
  high = _mm_slli_epi16(high, 4);
  bytes = _mm_or_si128(low, high);
  return bytes;
}
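// Illustration (not part of the kernel): for an input byte 0xB4, the 16-bit
// lane after _mm_cvtepu8_epi16 holds 0x00B4; the low nibble (0x4) stays in the
// low byte and the high nibble (0xB), shifted left by 4 within the lane, lands
// in the high byte, giving 0x0B04, i.e. the two uint8 values 4 and 11.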
#endif

#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)

// A block : {BLOCK_M, BLOCK_K}, lda = K
// B block : {BLOCK_K, BLOCK_N / 2}, ldb = BLOCK_N / 2
// C block : {BLOCK_M, BLOCK_N}, ldc = N
//
// ScaleAndZeros block : {1, BLOCK_N, 2}
//
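// Indexing illustration (assumed sizes): with K = 4096 and qGroupSize
// BLOCK_K = 128 there are KB = 32 quantization groups per column; the scale
// and zero point of group kb for output column n are read from
// ScaleAndZeros[kb * ldc * 2 + n * 2] and ScaleAndZeros[kb * ldc * 2 + n * 2 + 1].
//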
template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const BFloat16* RESTRICT A,
    const uint8_t* RESTRICT B,
    const BFloat16* RESTRICT ScaleAndZeros,
    BFloat16* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K,
    int BLOCK_K) {

  constexpr int ROWS = BLOCK_M;
  constexpr int COLS = BLOCK_N / 16;

  const int PREFETCH_SIZE_K = 16 * 4;
  const int PREFETCH_SIZE_KB = (PREFETCH_SIZE_K + BLOCK_K - 1) / BLOCK_K;

  // number of blocks on K
  const int KB = K / BLOCK_K;

  __m512 va;
  __m512 vb[COLS];
  __m512 vc[ROWS * COLS];
  __m512 scale[COLS];
  __m512 zero[COLS];

  // Lookup table to de-quantize int4 values to bf16.
  // Values are de-quantized over the true signed int4 range [-8, 7]:
  //
  //   dequant = (bf16(int4_value) * bf16_scale) + bf16_zero
  //
  static const __m512 lut = _mm512_set_ps(
      7.0f, 6.0f, 5.0f, 4.0f,
      3.0f, 2.0f, 1.0f, 0.0f,
      -1.0f, -2.0f, -3.0f, -4.0f,
      -5.0f, -6.0f, -7.0f, -8.0f);
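
  // Lane i of `lut` holds the float value (i - 8), so permuting `lut` with the
  // raw nibble as index maps an unsigned nibble q in [0, 15] to the signed
  // int4 value q - 8; e.g. q = 0xB selects lane 11 = 3.0f, which is then
  // scaled and shifted by the group's scale/zero via FMA.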

  // index for transpose
  static const __m512i idx1 = _mm512_set_epi32(
      30, 28, 26, 24, 22, 20, 18, 16,
      14, 12, 10, 8, 6, 4, 2, 0);
  static const __m512i idx2 = _mm512_set_epi32(
      31, 29, 27, 25, 23, 21, 19, 17,
      15, 13, 11, 9, 7, 5, 3, 1);

  // load scale and zero point
  auto load_scale_and_zeros = [&](int i, int _kb) {
    // load 2x bfloat16 vector
    __m512i t = _mm512_loadu_si512((__m512i*)(ScaleAndZeros + _kb * ldc * 2 + 32 * i));
    if (_kb + PREFETCH_SIZE_KB < KB) {
      _mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * ldc * 2 + 32 * i, _MM_HINT_T0);
    }

    // convert to 2x f32 vector
    __m512 a, b;
    vec::cvtbf16_fp32(t, a, b);

    // transpose scale_and_zero from {16, 2} to {2, 16}
    // inputs:
    //   a: {s0, z0, s1, z1, ..., s7, z7}
    //   b: {s8, z8, s9, z9, ..., s15, z15}
    // output:
    //   scale: {s0, s1, s2, ..., s15}
    //   zero:  {z0, z1, z2, ..., z15}
    scale[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b);
    zero[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b);
  };

  auto loadc = [&](auto i) {
    vc[i] = _mm512_setzero_ps();
  };
  c10::ForcedUnroll<ROWS * COLS>{}(loadc);

  auto compute = [&, COLS](auto i, int k) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;

    if constexpr (col == 0) {
      float aa = static_cast<float>(A[row * lda + k]);
      if (k + PREFETCH_SIZE_K < K) {
        _mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0);
      }
      va = _mm512_set1_ps(aa);
    }

    if constexpr (row == 0) {
      if constexpr (COLS == 4) {
        // when BLOCK_N = 64, handle a whole row at a time
        // to reduce de-quantize overhead.
        if constexpr (col == 0) {
          __m256i b4 = _mm256_loadu_si256((__m256i*)(B + k * ldb));
          if (k + PREFETCH_SIZE_K < K) {
            _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb, _MM_HINT_T0);
          }

          __m512i b32 = _mm512_cvtepu8_epi32(_mm256_castsi256_si128(b4));
          vb[0] = _mm512_permutexvar_ps(b32, lut);
          vb[0] = _mm512_fmadd_ps(vb[0], scale[0], zero[0]);
          vb[2] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
          vb[2] = _mm512_fmadd_ps(vb[2], scale[2], zero[2]);

          b32 = _mm512_cvtepu8_epi32(_mm256_extracti128_si256(b4, 1));
          vb[1] = _mm512_permutexvar_ps(b32, lut);
          vb[1] = _mm512_fmadd_ps(vb[1], scale[1], zero[1]);
          vb[3] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
          vb[3] = _mm512_fmadd_ps(vb[3], scale[3], zero[3]);
        }
      } else {
        __m128i b8 = conver_int4_to_int8(B + k * ldb + col * 8);
        __m512i b32 = _mm512_cvtepu8_epi32(b8);
        vb[col] = _mm512_permutexvar_ps(b32, lut);
        vb[col] = _mm512_fmadd_ps(vb[col], scale[col], zero[col]);
      }
    }

    constexpr int idx = row * COLS + col;
    vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);
  };

  for (int k = 0, kb = 0; k < K; ++k) {
    if (is_block_start(k, BLOCK_K)) {
      c10::ForcedUnroll<COLS>{}(load_scale_and_zeros, kb++);
    }
    c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
  }

  // store to C
  auto storec = [&, COLS](auto i) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;
    if constexpr (COLS == 4) {
      // when BLOCK_N = 64, handle a whole row at a time
      // to reduce `cvtfp32_bf16` overhead.
      if constexpr (col == 0) {
        __m512i c01 = vec::cvtfp32_bf16(vc[row * 4 + 0], vc[row * 4 + 1]);
        __m512i c23 = vec::cvtfp32_bf16(vc[row * 4 + 2], vc[row * 4 + 3]);
        _mm512_storeu_si512((__m512i*)(C + row * ldc + 0 * 32), c01);
        _mm512_storeu_si512((__m512i*)(C + row * ldc + 1 * 32), c23);
      }
    } else {
      __m256i ci = vec::cvtfp32_bf16(vc[i]);
      _mm256_storeu_si256((__m256i*)(C + row * ldc + col * 16), ci);
    }
  };
  c10::ForcedUnroll<ROWS * COLS>{}(storec);
}

#elif defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const BFloat16* RESTRICT A,
    const uint8_t* RESTRICT B,
    const BFloat16* RESTRICT ScaleAndZeros,
    BFloat16* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K,
    int BLOCK_K) {

  constexpr int ROWS = BLOCK_M;
  constexpr int COLS = BLOCK_N / 8;

  const int PREFETCH_SIZE_K = 16 * 4;
  const int PREFETCH_SIZE_KB = (PREFETCH_SIZE_K + BLOCK_K - 1) / BLOCK_K;

  // number of blocks on K
  const int KB = K / BLOCK_K;

  __m256 va;
  __m256 vb[COLS];
  __m256 vc[ROWS * COLS];
  __m256 scale[COLS];
  __m256 zero[COLS];

  static const __m256i idx1 = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);

  // offset to shift from range [0, 15] to [-8, 7]
  const __m256 offset = _mm256_set1_ps(-8.0f);
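  // Note: folding this shift into the zero point (below) lets the inner loop
  // use the raw unsigned nibble q directly, since
  //   (q - 8) * scale + zero == q * scale + (zero - 8 * scale).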

  // load scale and zero point
  auto load_scale_and_zeros = [&](int i, int _kb) {
    // load 2x bfloat16 vector
    __m256i t = _mm256_loadu_si256((__m256i*)(ScaleAndZeros + _kb * ldc * 2 + 16 * i));
    if (_kb + PREFETCH_SIZE_KB < KB) {
      _mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * ldc * 2 + 16 * i, _MM_HINT_T0);
    }

    // convert to 2x f32 vector
    __m256 a, b;
    vec::cvtbf16_fp32(t, a, b);

    // transpose scale_and_zero from {8, 2} to {2, 8}
    // inputs:
    //   a: {s0, z0, s1, z1, s2, z2, s3, z3}
    //   b: {s4, z4, s5, z5, s6, z6, s7, z7}
    // output:
    //   scale: {s0, s1, s2, s3, s4, s5, s6, s7}
    //   zero:  {z0, z1, z2, z3, z4, z5, z6, z7}
    a = _mm256_permutevar8x32_ps(a, idx1);
    b = _mm256_permutevar8x32_ps(b, idx1);
    scale[i] = _mm256_permute2f128_ps(a, b, 0b0100000);
    zero[i] = _mm256_permute2f128_ps(a, b, 0b0110001);

    // zero = -8 * scale + zero
    zero[i] = _mm256_fmadd_ps(scale[i], offset, zero[i]);
  };

  auto loadc = [&](auto i) {
    vc[i] = _mm256_setzero_ps();
  };
  c10::ForcedUnroll<ROWS * COLS>{}(loadc);

  auto compute = [&, COLS](auto i, int k) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;

    if constexpr (col == 0) {
      float aa = static_cast<float>(A[row * lda + k]);
      if (k + PREFETCH_SIZE_K < K) {
        _mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0);
      }
      va = _mm256_set1_ps(aa);
    }

    if constexpr (row == 0) {
      if constexpr (COLS == 4) {
        // when BLOCK_N = 32, handle a whole row at a time
        if constexpr (col == 0) {
          __m256i mask = _mm256_set1_epi32(0xF);
          __m128i b4 = _mm_loadu_si128((__m128i*)(B + k * ldb));
          if (k + PREFETCH_SIZE_K < K) {
            _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb, _MM_HINT_T0);
          }

          __m256i b32 = _mm256_cvtepu8_epi32(b4);
          vb[0] = _mm256_cvtepi32_ps(_mm256_and_si256(b32, mask));
          vb[0] = _mm256_fmadd_ps(vb[0], scale[0], zero[0]);
          vb[2] = _mm256_cvtepi32_ps(_mm256_srli_epi32(b32, 4));
          vb[2] = _mm256_fmadd_ps(vb[2], scale[2], zero[2]);

          b32 = _mm256_cvtepu8_epi32(_mm_shuffle_epi32(b4, _MM_SHUFFLE(3, 2, 3, 2)));
          vb[1] = _mm256_cvtepi32_ps(_mm256_and_si256(b32, mask));
          vb[1] = _mm256_fmadd_ps(vb[1], scale[1], zero[1]);
          vb[3] = _mm256_cvtepi32_ps(_mm256_srli_epi32(b32, 4));
          vb[3] = _mm256_fmadd_ps(vb[3], scale[3], zero[3]);
        }
      } else {
        if constexpr (col % 2 == 0) {
          // de-quantize per 64 bits (16x int4)
          __m128i b8 = conver_int4_to_int8(B + k * ldb + col * 4);
          __m128i b8_val0 = _mm_set1_epi64x(_mm_extract_epi64(b8, 0));
          __m128i b8_val1 = _mm_set1_epi64x(_mm_extract_epi64(b8, 1));
          if (k + PREFETCH_SIZE_K < K) {
            _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb + col * 4, _MM_HINT_T0);
          }

          vb[col] = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(b8_val0));
          vb[col] = _mm256_fmadd_ps(vb[col], scale[col], zero[col]);
          vb[col + 1] = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(b8_val1));
          vb[col + 1] = _mm256_fmadd_ps(vb[col + 1], scale[col + 1], zero[col + 1]);
        }
      }
    }

    constexpr int idx = row * COLS + col;
    vc[idx] = _mm256_fmadd_ps(va, vb[col], vc[idx]);
  };
  for (int k = 0, kb = 0; k < K; ++k) {
    if (is_block_start(k, BLOCK_K)) {
      c10::ForcedUnroll<COLS>{}(load_scale_and_zeros, kb++);
    }
    c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
  }

  // store to C
  auto storec = [&](auto i) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;
    if constexpr (col % 2 == 0) {
      __m256i ci = vec::cvtfp32_bf16(vc[row * COLS + col], vc[row * COLS + col + 1]);
      _mm256_storeu_si256((__m256i*)(C + row * ldc + col * 8), ci);
    }
  };
  c10::ForcedUnroll<ROWS * COLS>{}(storec);
}

#endif

#if !defined(C10_MOBILE) && defined(__aarch64__)
#include <arm_neon.h>

inline float32x4x2_t load_as_float32x4x2(const Half* ptr) {
  float16x4x2_t f16_val = vld2_f16(reinterpret_cast<const float16_t *>(ptr));
  auto val_low = vcvt_f32_f16(f16_val.val[0]);
  auto val_high = vcvt_f32_f16(f16_val.val[1]);
  return {val_low, val_high};
}

inline void store_float32x4(Half* ptr, float32x4_t val) {
  vst1_f16(reinterpret_cast<float16_t*>(ptr), vcvt_f16_f32(val));
}

inline float32x4x2_t load_as_float32x4x2(const BFloat16* ptr) {
  int32x4_t shift = vdupq_n_s32(16);
  uint16x4x2_t u16_val = vld2_u16(reinterpret_cast<const uint16_t *>(ptr));
  uint32x4_t int_low = vmovl_u16(u16_val.val[0]);
  uint32x4_t int_high = vmovl_u16(u16_val.val[1]);
  return {vreinterpretq_f32_u32(vshlq_u32(int_low, shift)), vreinterpretq_f32_u32(vshlq_u32(int_high, shift))};
}

inline void store_float32x4(BFloat16* ptr, float32x4_t val) {
  int32x4_t shift = vdupq_n_s32(-16);
  uint32x4_t uint32_val = vshlq_u32(vreinterpretq_u32_f32(val), shift);
  vst1_u16(reinterpret_cast<uint16_t*>(ptr), vmovn_u32(uint32_val));
}
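// Note: the BFloat16 helpers above convert via plain bit shifts, i.e. bf16 is
// treated as the upper 16 bits of the fp32 bit pattern; the store truncates
// the mantissa without rounding.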

inline float32x4x2_t load_as_float32x4x2(const float* ptr) {
  return vld2q_f32(ptr);
}

inline void store_float32x4(float* ptr, float32x4_t val) {
  vst1q_f32(ptr, val);
}

template <int BLOCK_M, int BLOCK_N, typename T>
inline void tinygemm_kernel_(
    const T* RESTRICT A,
    const uint8_t* RESTRICT B,
    const T* RESTRICT ScaleAndZeros,
    T* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K,
    int BLOCK_K) {
  int16_t shift_vals[4] = {0, -4, -8, -12};
  int16x4_t shifts = vld1_s16(shift_vals);
  int16x4_t offs = vdup_n_s16(8);
  uint16x4_t mask = vdup_n_u16(0x0F);
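  // Each 16-bit word of the packed B holds four consecutive int4 values along
  // N; shifting the word by {0, -4, -8, -12} (vshl with a negative count is a
  // right shift) and masking with 0x0F extracts them, and subtracting 8 maps
  // the unsigned range [0, 15] to the signed int4 range [-8, 7].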
  for (const auto m : c10::irange(BLOCK_M)) {
    for (int n = 0; n < BLOCK_N; n += 16) {
      float32x4_t c_val[4];
      float32x4_t scales[4], zeros[4];
      c10::ForcedUnroll<4>{}([&](auto i) {
        c_val[i] = vdupq_n_f32(0.0);
      });
      for (const auto k : c10::irange(K)) {
        const auto a_val = vdupq_n_f32(static_cast<float>(A[m * lda + k]));
        if (is_block_start(k, BLOCK_K)) {
          int kb = k / BLOCK_K;
          c10::ForcedUnroll<4>{}([&](auto i) {
            auto scales_and_zeros = load_as_float32x4x2(ScaleAndZeros + kb * ldc * 2 + n * 2 + i * 8);
            scales[i] = scales_and_zeros.val[0];
            zeros[i] = scales_and_zeros.val[1];
          });
        }
        c10::ForcedUnroll<4>{}([&](auto i) {
          uint16_t b_pack = reinterpret_cast<const uint16_t*>(B + k * ldb + n / 2)[i];
          uint16x4_t b_masked = vand_u16(vshl_u16(vdup_n_u16(b_pack), shifts), mask);
          int16x4_t b_ints = vsub_s16(vreinterpret_s16_u16(b_masked), offs);
          float32x4_t b_vals = vcvtq_f32_s32(vmovl_s16(b_ints));
          b_vals = vaddq_f32(zeros[i], vmulq_f32(scales[i], b_vals));
          c_val[i] = vfmaq_f32(c_val[i], b_vals, a_val);
        });
      }
      c10::ForcedUnroll<4>{}([&](auto i) {
        store_float32x4(C + m * ldc + n + i * 4, c_val[i]);
      });
    }
  }
}

template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const Half* RESTRICT A,
    const uint8_t* RESTRICT B,
    const Half* RESTRICT ScaleAndZeros,
    Half* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K,
    int BLOCK_K) {
  tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, ScaleAndZeros, C, lda, ldb, ldc, K, BLOCK_K);
}

template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const BFloat16* RESTRICT A,
    const uint8_t* RESTRICT B,
    const BFloat16* RESTRICT ScaleAndZeros,
    BFloat16* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K,
    int BLOCK_K) {
  tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, ScaleAndZeros, C, lda, ldb, ldc, K, BLOCK_K);
}

template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const float* RESTRICT A,
    const uint8_t* RESTRICT B,
    const float* RESTRICT ScaleAndZeros,
    float* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K,
    int BLOCK_K) {
  tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, ScaleAndZeros, C, lda, ldb, ldc, K, BLOCK_K);
}
#endif

template <int BLOCK_N>
inline float convert_int4_to_float(const uint8_t* b, int n) {
  static constexpr float lut[16] = {
      -8.0f, -7.0f, -6.0f, -5.0f,
      -4.0f, -3.0f, -2.0f, -1.0f,
      0.0f, 1.0f, 2.0f, 3.0f,
      4.0f, 5.0f, 6.0f, 7.0f
  };
  int index;
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
  if constexpr (BLOCK_N == 64) {
    const int nb = n / BLOCK_N;
    n -= nb * BLOCK_N;
    if (n < 32) {
      auto val = b[nb * BLOCK_N / 2 + n];
      index = val & 0x0f;
    } else {
      auto val = b[nb * BLOCK_N / 2 + (n - 32)];
      index = val >> 4;
    }
  } else
#elif defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
  if constexpr (BLOCK_N == 32) {
    const int nb = n / BLOCK_N;
    n -= nb * BLOCK_N;
    if (n < 16) {
      auto val = b[nb * BLOCK_N / 2 + n];
      index = val & 0x0f;
    } else {
      auto val = b[nb * BLOCK_N / 2 + (n - 16)];
      index = val >> 4;
    }
  } else
#endif
  {
    const auto is_even = (n & 1) == 0;
    auto val = b[n / 2];
    index = is_even ? (val & 0x0F) : (val >> 4);
  }
  return lut[index];
}
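// Illustration (hypothetical value): with the AVX512 layout (BLOCK_N == 64),
// column n = 40 of the first block falls in the upper half of the block, so
// the nibble is read from byte b[40 - 32] = b[8] and taken from the high four
// bits; lut[index] then gives the signed value index - 8.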

// non-vectorized version
template <int BLOCK_M, int BLOCK_N, typename T>
inline void tinygemm_kernel(
    const T* RESTRICT A,
    const uint8_t* RESTRICT B,
    const T* RESTRICT ScaleAndZeros,
    T* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K,
    int BLOCK_K) {

  for (const auto m : c10::irange(BLOCK_M)) {
    for (const auto n : c10::irange(BLOCK_N)) {
      float c_val = 0;
      for (const auto k : c10::irange(K)) {
        int kb = k / BLOCK_K;
        const auto scale = static_cast<float>(ScaleAndZeros[kb * ldc * 2 + n * 2]);
        const auto zero = static_cast<float>(ScaleAndZeros[kb * ldc * 2 + n * 2 + 1]);
        const auto a_val = static_cast<float>(A[m * lda + k]);
        float b_val = convert_int4_to_float<BLOCK_N>(B + k * ldb, n);
        b_val = b_val * scale + zero;

        c_val += a_val * b_val;
      }
      C[m * ldc + n] = c_val;
    }
  }
}


#define LAUNCH_TINYGEMM_KERNEL(MB_SIZE, NB_SIZE) \
  tinygemm_kernel<MB_SIZE, NB_SIZE>(             \
      A_ptr, B_ptr, S_ptr, C_ptr,                \
      K, NB_SIZE / 2, N, K, BLOCK_K);
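
// In the launch above: lda = K (A is row-major {M, K}), ldb = NB_SIZE / 2
// (each packed row of B stores NB_SIZE int4 values in NB_SIZE / 2 bytes) and
// ldc = N (C is row-major {M, N}).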

#define LAUNCH_TINYGEMM_NB_SIZE(MB_SIZE)                           \
  switch (nb_size) {                                               \
    case 16:                                                       \
      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 16);                         \
      break;                                                       \
    case 32:                                                       \
      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 32);                         \
      break;                                                       \
    case 48:                                                       \
      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 48);                         \
      break;                                                       \
    case 64:                                                       \
      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 64);                         \
      break;                                                       \
    default:                                                       \
      TORCH_CHECK(false, "Unsupported n block size: ", nb_size);   \
      break;                                                       \
  }

// NB: int4 weight pack (with BLOCK_N 64)
//   weight (int32): {N/64, 64, K}
//   packed (uint8): {N/64, K, 32}
//
// 1. avx512 packed format:
//   When N is 64, to unpack 256 bits at a time we pack Lane0 with Lane2 and
//   Lane1 with Lane3, since shifts can only be done on a 128-bit basis.
//
//   weight:
//     [Lane0] N0...15:  {a00, a01, a02, ...}
//     [Lane1] N16...31: {a10, a11, a12, ...}
//     [Lane2] N32...47: {a20, a21, a22, ...}
//     [Lane3] N48...63: {a30, a31, a32, ...}
//
//   packed:
//     [Lane02] N0...31:  {a20|a00, a21|a01, a22|a02, ...}
//     [Lane13] N32...63: {a30|a10, a31|a11, a32|a12, ...}
//
//   Note: when N is 16, 32 or 48, pack with the 64-bit format.
//
// 2. avx2 packed format:
//   When N is 32, unpack 128 bits at a time.
//
//   weight:
//     [Lane0] N0...15:  { a0,  a1,  a2, ...}
//     [Lane1] N16...31: {a16, a17, a18, ...}
//
//   packed:
//     [Lane01] N0...31: {a16|a0, a17|a1, a18|a2, ...}
//
//   Note: when N is 16, pack with the 64-bit format.
//
// 3. non-vectorized packed format:
//   Do 64-bit unpacking at a time.
//
//   weight: {a0, a1, a2, a3, ..., a14, a15}
//   packed: {a1|a0, a3|a2, ..., a15|a14}
//
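// Byte-level reading of the 64-bit format (as used in the packing loops
// below): the two nibbles of a source byte src[n][k] cover the two K positions
// 2k and 2k+1 of row n; the loops split them apart, writing the high nibbles
// to packed row 2k and the low nibbles to packed row 2k+1, and pair the
// neighbouring rows n and n+1 into one output byte with row n+1 in the high
// nibble (the a1|a0 notation above).
//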
void weight_to_int4pack_kernel(
    const Tensor& weight_packed,
    const Tensor& weight,
    int N, int K) {

  auto weight_packed_data = reinterpret_cast<uint8_t*>(weight_packed.data_ptr());
  const auto weight_data = weight.data_ptr<uint8_t>();

  // 64 for avx512 and 32 for avx2/non-vectorized
  constexpr int BLOCK_N = vec::Vectorized<float>::size() * 4;
  const int NB = (N + BLOCK_N - 1) / BLOCK_N;
  int K_div_2 = K / 2;

  // parallel on NB blocks
  at::parallel_for(0, NB, 0, [&](int begin, int end) {
    for (const auto i : c10::irange(begin, end)) {
      int nb_size = std::min(BLOCK_N, N - i * BLOCK_N);

      const uint8_t* src = weight_data + i * BLOCK_N * K_div_2;
      uint8_t* dst = weight_packed_data + i * K * BLOCK_N / 2;
      for (const auto k : c10::irange(K_div_2)) {
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
        if (nb_size == BLOCK_N) {
          for (const auto d : c10::irange(16)) {
            uint8_t val0 = src[(d + 0) * K_div_2 + k];
            uint8_t val1 = src[(d + 16) * K_div_2 + k];
            uint8_t val2 = src[(d + 32) * K_div_2 + k];
            uint8_t val3 = src[(d + 48) * K_div_2 + k];

            uint8_t packed02_0 = (val2 & 0xF0) | ((val0 & 0xF0) >> 4);
            uint8_t packed13_0 = (val3 & 0xF0) | ((val1 & 0xF0) >> 4);
            uint8_t packed02_1 = ((val2 & 0xF) << 4) | (val0 & 0xF);
            uint8_t packed13_1 = ((val3 & 0xF) << 4) | (val1 & 0xF);

            dst[k * 2 * 32 + d] = packed02_0;
            dst[k * 2 * 32 + 16 + d] = packed13_0;
            dst[(k * 2 + 1) * 32 + d] = packed02_1;
            dst[(k * 2 + 1) * 32 + 16 + d] = packed13_1;
          }
        } else {
          // for nb_size 16, 32, 48
          for (int n = 0; n < nb_size; n += 2) {
            uint8_t val0 = src[n * K_div_2 + k];
            uint8_t val1 = src[n * K_div_2 + K_div_2 + k];

            uint8_t packed_0 = ((val1 & 0xF0)) | ((val0 & 0xF0) >> 4);
            uint8_t packed_1 = ((val1 & 0xF) << 4) | (val0 & 0xF);
            dst[k * 2 * nb_size / 2 + n / 2] = packed_0;
            dst[(k * 2 + 1) * nb_size / 2 + n / 2] = packed_1;
          }
        }
#elif defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
        if (nb_size == BLOCK_N) {
          // for nb_size 32
          for (const auto d : c10::irange(16)) {
            uint8_t val0 = src[(d + 0) * K_div_2 + k];
            uint8_t val1 = src[(d + 16) * K_div_2 + k];

            uint8_t packed01_0 = ((val1 & 0xF0) | ((val0 & 0xF0) >> 4));
            uint8_t packed01_1 = ((val1 & 0xF) << 4) | (val0 & 0xF);
            dst[k * 2 * 16 + d] = packed01_0;
            dst[(k * 2 + 1) * 16 + d] = packed01_1;
          }
        } else {
          // for nb_size 16
          for (int n = 0; n < nb_size; n += 2) {
            int32_t val0 = src[n * K_div_2 + k];
            int32_t val1 = src[n * K_div_2 + K_div_2 + k];

            uint8_t packed_0 = ((val1 & 0xF0)) | ((val0 & 0xF0) >> 4);
            uint8_t packed_1 = ((val1 & 0xF) << 4) | (val0 & 0xF);
            dst[k * 2 * nb_size / 2 + n / 2] = packed_0;
            dst[(k * 2 + 1) * nb_size / 2 + n / 2] = packed_1;
          }
        }
#else
        for (int n = 0; n < nb_size; n += 2) {
          uint8_t val0 = src[n * K_div_2 + k];
          uint8_t val1 = src[n * K_div_2 + K_div_2 + k];

          uint8_t packed_0 = ((val1 & 0xF0)) | ((val0 & 0xF0) >> 4);
          uint8_t packed_1 = ((val1 & 0xF) << 4) | (val0 & 0xF);
          dst[k * 2 * nb_size / 2 + n / 2] = packed_0;
          dst[(k * 2 + 1) * nb_size / 2 + n / 2] = packed_1;
        }
#endif
      }
    }
  });
}

template <typename T>
void int4pack_mm_kernel_(
    const Tensor& C,
    const Tensor& A,
    const Tensor& B,
    int qGroupSize,
    const Tensor& qScaleAndZeros,
    int N, int K) {

  const auto* A_data = A.const_data_ptr<T>();
  const auto* B_data = reinterpret_cast<const uint8_t*>(B.const_data_ptr());
  auto* C_data = C.data_ptr<T>();
  const auto* S_data = qScaleAndZeros.const_data_ptr<T>();

  int M = A.size(0);

  constexpr int BLOCK_M = 4;
  // 64 for avx512 and 32 for avx2/non-vectorized
  constexpr int BLOCK_N = vec::Vectorized<float>::size() * 4;
  // 32, 64, 128, 256
  const int BLOCK_K = qGroupSize;

  const int MB = (M + BLOCK_M - 1) / BLOCK_M;
  const int NB = (N + BLOCK_N - 1) / BLOCK_N;
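  // The output is tiled into {BLOCK_M, BLOCK_N} blocks of C; at::parallel_for
  // distributes the MB * NB tiles across threads, and data_index_init /
  // data_index_step walk the (mb, nb) pair for each flat tile index.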

  at::parallel_for(0, MB * NB, 0, [&](int begin, int end) {
    int mb{0}, nb{0};
    data_index_init(begin, mb, MB, nb, NB);

    for (C10_UNUSED const auto i : c10::irange(begin, end)) {
      int mb_start = mb * BLOCK_M;
      int mb_size = std::min(BLOCK_M, M - mb_start);
      int nb_start = nb * BLOCK_N;
      int nb_size = std::min(BLOCK_N, N - nb_start);

      const auto* A_ptr = A_data + mb_start * K;
      const auto* B_ptr = B_data + nb_start * K / 2;
      const auto* S_ptr = S_data + nb_start * 2;
      auto* C_ptr = C_data + mb_start * N + nb_start;

      switch (mb_size) {
        case 1:
          LAUNCH_TINYGEMM_NB_SIZE(1);
          break;
        case 2:
          LAUNCH_TINYGEMM_NB_SIZE(2);
          break;
        case 3:
          LAUNCH_TINYGEMM_NB_SIZE(3);
          break;
        case 4:
          LAUNCH_TINYGEMM_NB_SIZE(4);
          break;
        default:
          TORCH_CHECK(false, "Unsupported m block size: ", mb_size);
      }

      // move to the next index
      data_index_step(mb, MB, nb, NB);
    }
  });
}

void int4pack_mm_kernel(
    const Tensor& C,
    const Tensor& A,
    const Tensor& B,
    int qGroupSize,
    const Tensor& qScaleAndZeros,
    int N, int K) {
  if (C.scalar_type() == kBFloat16) {
    int4pack_mm_kernel_<BFloat16>(C, A, B, qGroupSize, qScaleAndZeros, N, K);
  } else if (C.scalar_type() == kHalf) {
    int4pack_mm_kernel_<Half>(C, A, B, qGroupSize, qScaleAndZeros, N, K);
  } else {
    int4pack_mm_kernel_<float>(C, A, B, qGroupSize, qScaleAndZeros, N, K);
  }
}

} // anonymous namespace

ALSO_REGISTER_AVX512_DISPATCH(weight_to_int4pack_stub, &weight_to_int4pack_kernel);
ALSO_REGISTER_AVX512_DISPATCH(int4pack_mm_stub, &int4pack_mm_kernel);

} // at::native
C10_DIAGNOSTIC_POP()