// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/int4mm_kernel.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <type_traits>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>

#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/int_mm_kernel.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>
#include <c10/util/Unroll.h>

#if (defined(_WIN32) || defined(_WIN64))
#define RESTRICT __restrict
#else
#define RESTRICT __restrict__
#endif

C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-variable")
namespace at::native {

namespace {

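// Helper: true when `index` is the first element of a BLOCK_SIZE-sized block.
// Assumes BLOCK_SIZE is a power of two, so (index & (BLOCK_SIZE - 1)) is
// equivalent to index % BLOCK_SIZE.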
inline bool is_block_start(int index, int BLOCK_SIZE) {
  return !(index & (BLOCK_SIZE - 1));
}

#if (defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER)
// convert 16x int4 to int8, handle 64 bits at a time
// used in avx2 and avx512
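// Output layout: byte 2*i holds the low nibble and byte 2*i+1 the high nibble
// of input byte i, i.e. 16 unpacked values, one per byte, still in the
// unsigned range [0, 15].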
inline __m128i conver_int4_to_int8(const uint8_t* data) {
  __m128i tmp = _mm_loadu_si64((const __m128i*)data);
  __m128i bytes = _mm_cvtepu8_epi16(tmp);
  const __m128i lowMask = _mm_set1_epi8(0xF);
  __m128i high = _mm_andnot_si128(lowMask, bytes);
  __m128i low = _mm_and_si128(lowMask, bytes);
  high = _mm_slli_epi16(high, 4);
  bytes = _mm_or_si128(low, high);
  return bytes;
}
#endif

#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)

// A block : {BLOCK_M, BLOCK_K}, lda = K
// B block : {BLOCK_K, BLOCK_N / 2}, ldb = BLOCK_N / 2
// C block : {BLOCK_M, BLOCK_N}, ldc = N
//
// ScaleAndZeros block : {1, BLOCK_N, 2}
//
template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const BFloat16* RESTRICT A,
    const uint8_t* RESTRICT B,
    const BFloat16* RESTRICT ScaleAndZeros,
    BFloat16* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K,
    int BLOCK_K) {

  constexpr int ROWS = BLOCK_M;
  constexpr int COLS = BLOCK_N / 16;

  const int PREFETCH_SIZE_K = 16 * 4;
  const int PREFETCH_SIZE_KB = (PREFETCH_SIZE_K + BLOCK_K - 1) / BLOCK_K;

  // number of blocks on K
  const int KB = K / BLOCK_K;

  __m512 va;
  __m512 vb[COLS];
  __m512 vc[ROWS * COLS];
  __m512 scale[COLS];
  __m512 zero[COLS];

  // Lookup table to de-quantize int4 values to bf16.
  // Values are dequantized as truly int4 [-8, 7] range;
  //
  // dequant = (bf16(int4_value) * bf16_scale) + bf16_zero
  //
  static const __m512 lut = _mm512_set_ps(
      7.0f, 6.0f, 5.0f, 4.0f,
      3.0f, 2.0f, 1.0f, 0.0f,
      -1.0f, -2.0f, -3.0f, -4.0f,
      -5.0f, -6.0f, -7.0f, -8.0f);
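  // _mm512_permutexvar_ps only uses the low 4 bits of each 32-bit index, so
  // looking up a byte whose low nibble is n yields lut[n] == n - 8 directly;
  // the high nibble is handled by shifting right by 4 first.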

  // index for transpose
  static const __m512i idx1 = _mm512_set_epi32(
      30, 28, 26, 24, 22, 20, 18, 16,
      14, 12, 10, 8, 6, 4, 2, 0);
  static const __m512i idx2 = _mm512_set_epi32(
      31, 29, 27, 25, 23, 21, 19, 17,
      15, 13, 11, 9, 7, 5, 3, 1);

  // load scale and zero point
  auto load_scale_and_zeros = [&](int i, int _kb) {
    // load 2x bfloat16 vector
    __m512i t = _mm512_loadu_si512((__m512i*)(ScaleAndZeros + _kb * ldc * 2 + 32 * i));
    if (_kb + PREFETCH_SIZE_KB < KB) {
      _mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * ldc * 2 + 32 * i, _MM_HINT_T0);
    }

    // convert to 2x f32 vector
    __m512 a, b;
    vec::cvtbf16_fp32(t, a, b);

    // transpose scale_and_zero from {16, 2} to {2, 16}
    // inputs:
    //   a: {s0, z0, s1, z1, ..., s7, z7}
    //   b: {s8, z8, s9, z9, ..., s15, z15}
    // output:
    //   scale: {s0, s1, s2, ..., s15}
    //   zero:  {z0, z1, z2, ..., z15}
    scale[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b);
    zero[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b);
  };

  auto loadc = [&](auto i) {
    vc[i] = _mm512_setzero_ps();
  };
  c10::ForcedUnroll<ROWS * COLS>{}(loadc);

  auto compute = [&, COLS](auto i, int k) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;

    if constexpr (col == 0) {
      float aa = static_cast<float>(A[row * lda + k]);
      if (k + PREFETCH_SIZE_K < K) {
        _mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0);
      }
      va = _mm512_set1_ps(aa);
    }

    if constexpr (row == 0) {
      if constexpr (COLS == 4) {
        // when BLOCK_N = 64, handle each row at a time
        // to reduce de-quantize overhead.
        if constexpr (col == 0) {
          __m256i b4 = _mm256_loadu_si256((__m256i*)(B + k * ldb));
          if (k + PREFETCH_SIZE_K < K) {
            _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb, _MM_HINT_T0);
          }

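          // The 32 packed bytes follow the Lane02/Lane13 layout described in
          // the packing note further down: the low nibble of each byte feeds
          // vb[0]/vb[1] (columns 0..31), and the high nibble, after a 4-bit
          // shift, feeds vb[2]/vb[3] (columns 32..63).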
          __m512i b32 = _mm512_cvtepu8_epi32(_mm256_castsi256_si128(b4));
          vb[0] = _mm512_permutexvar_ps(b32, lut);
          vb[0] = _mm512_fmadd_ps(vb[0], scale[0], zero[0]);
          vb[2] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
          vb[2] = _mm512_fmadd_ps(vb[2], scale[2], zero[2]);

          b32 = _mm512_cvtepu8_epi32(_mm256_extracti128_si256(b4, 1));
          vb[1] = _mm512_permutexvar_ps(b32, lut);
          vb[1] = _mm512_fmadd_ps(vb[1], scale[1], zero[1]);
          vb[3] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
          vb[3] = _mm512_fmadd_ps(vb[3], scale[3], zero[3]);
        }
      } else {
        __m128i b8 = conver_int4_to_int8(B + k * ldb + col * 8);
        __m512i b32 = _mm512_cvtepu8_epi32(b8);
        vb[col] = _mm512_permutexvar_ps(b32, lut);
        vb[col] = _mm512_fmadd_ps(vb[col], scale[col], zero[col]);
      }
    }

    constexpr int idx = row * COLS + col;
    vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);
  };

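  // Main loop over K: scales and zero points are reloaded at every
  // quantization-group (BLOCK_K) boundary, then each k contributes one FMA
  // per (row, col) accumulator.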
  for (int k = 0, kb = 0; k < K; ++k) {
    if (is_block_start(k, BLOCK_K)) {
      c10::ForcedUnroll<COLS>{}(load_scale_and_zeros, kb++);
    }
    c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
  }

  // store to C
  auto storec = [&, COLS](auto i) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;
    if constexpr (COLS == 4) {
      // when BLOCK_N = 64, handle each row at a time
      // to reduce `cvtfp32_bf16` overhead.
      if constexpr (col == 0) {
        __m512i c01 = vec::cvtfp32_bf16(vc[row * 4 + 0], vc[row * 4 + 1]);
        __m512i c23 = vec::cvtfp32_bf16(vc[row * 4 + 2], vc[row * 4 + 3]);
        _mm512_storeu_si512((__m512i*)(C + row * ldc + 0 * 32), c01);
        _mm512_storeu_si512((__m512i*)(C + row * ldc + 1 * 32), c23);
      }
    } else {
      __m256i ci = vec::cvtfp32_bf16(vc[i]);
      _mm256_storeu_si256((__m256i*)(C + row * ldc + col * 16), ci);
    }
  };
  c10::ForcedUnroll<ROWS * COLS>{}(storec);
}

#elif defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const BFloat16* RESTRICT A,
    const uint8_t* RESTRICT B,
    const BFloat16* RESTRICT ScaleAndZeros,
    BFloat16* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K,
    int BLOCK_K) {

  constexpr int ROWS = BLOCK_M;
  constexpr int COLS = BLOCK_N / 8;

  const int PREFETCH_SIZE_K = 16 * 4;
  const int PREFETCH_SIZE_KB = (PREFETCH_SIZE_K + BLOCK_K - 1) / BLOCK_K;

  // number of blocks on K
  const int KB = K / BLOCK_K;

  __m256 va;
  __m256 vb[COLS];
  __m256 vc[ROWS * COLS];
  __m256 scale[COLS];
  __m256 zero[COLS];

  static const __m256i idx1 = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);

  // offset to shift from range [0, 15] to [-8, 7]
  const __m256 offset = _mm256_set1_ps(-8.0f);
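  // Folding the offset into the zero point once per group (zero - 8 * scale)
  // lets the inner loop use the unsigned nibble n directly:
  // (n - 8) * scale + zero == n * scale + (zero - 8 * scale).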

  // load scale and zero point
  auto load_scale_and_zeros = [&](int i, int _kb) {
    // load 2x bfloat16 vector
    __m256i t = _mm256_loadu_si256((__m256i*)(ScaleAndZeros + _kb * ldc * 2 + 16 * i));
    if (_kb + PREFETCH_SIZE_KB < KB) {
      _mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * ldc * 2 + 16 * i, _MM_HINT_T0);
    }

    // convert to 2x f32 vector
    __m256 a, b;
    vec::cvtbf16_fp32(t, a, b);

    // transpose scale_and_zero from {8, 2} to {2, 8}
    // inputs:
    //   a: {s0, z0, s1, z1, s2, z2, s3, z3}
    //   b: {s4, z4, s5, z5, s6, z6, s7, z7}
    // output:
    //   scale: {s0, s1, s2, s3, s4, s5, s6, s7}
    //   zero:  {z0, z1, z2, z3, z4, z5, z6, z7}
    a = _mm256_permutevar8x32_ps(a, idx1);
    b = _mm256_permutevar8x32_ps(b, idx1);
    scale[i] = _mm256_permute2f128_ps(a, b, 0b0100000);
    zero[i] = _mm256_permute2f128_ps(a, b, 0b0110001);

    // zero = -8 * scale + zero
    zero[i] = _mm256_fmadd_ps(scale[i], offset, zero[i]);
  };

  auto loadc = [&](auto i) {
    vc[i] = _mm256_setzero_ps();
  };
  c10::ForcedUnroll<ROWS * COLS>{}(loadc);

  auto compute = [&, COLS](auto i, int k) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;

    if constexpr (col == 0) {
      float aa = static_cast<float>(A[row * lda + k]);
      if (k + PREFETCH_SIZE_K < K) {
        _mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0);
      }
      va = _mm256_set1_ps(aa);
    }

    if constexpr (row == 0) {
      if constexpr (COLS == 4) {
        // when BLOCK_N = 32, handle each row at a time
        if constexpr (col == 0) {
          __m256i mask = _mm256_set1_epi32(0xF);
          __m128i b4 = _mm_loadu_si128((__m128i*)(B + k * ldb));
          if (k + PREFETCH_SIZE_K < K) {
            _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb, _MM_HINT_T0);
          }

          __m256i b32 = _mm256_cvtepu8_epi32(b4);
          vb[0] = _mm256_cvtepi32_ps(_mm256_and_si256(b32, mask));
          vb[0] = _mm256_fmadd_ps(vb[0], scale[0], zero[0]);
          vb[2] = _mm256_cvtepi32_ps(_mm256_srli_epi32(b32, 4));
          vb[2] = _mm256_fmadd_ps(vb[2], scale[2], zero[2]);

          b32 = _mm256_cvtepu8_epi32(_mm_shuffle_epi32(b4, _MM_SHUFFLE(3, 2, 3, 2)));
          vb[1] = _mm256_cvtepi32_ps(_mm256_and_si256(b32, mask));
          vb[1] = _mm256_fmadd_ps(vb[1], scale[1], zero[1]);
          vb[3] = _mm256_cvtepi32_ps(_mm256_srli_epi32(b32, 4));
          vb[3] = _mm256_fmadd_ps(vb[3], scale[3], zero[3]);
        }
      } else {
        if constexpr (col % 2 == 0) {
          // de-quantize per 64 bits (16x int4)
          __m128i b8 = conver_int4_to_int8(B + k * ldb + col * 4);
          __m128i b8_val0 = _mm_set1_epi64x(_mm_extract_epi64(b8, 0));
          __m128i b8_val1 = _mm_set1_epi64x(_mm_extract_epi64(b8, 1));
          if (k + PREFETCH_SIZE_K < K) {
            _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb + col * 4, _MM_HINT_T0);
          }

          vb[col] = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(b8_val0));
          vb[col] = _mm256_fmadd_ps(vb[col], scale[col], zero[col]);
          vb[col + 1] = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(b8_val1));
          vb[col + 1] = _mm256_fmadd_ps(vb[col + 1], scale[col + 1], zero[col + 1]);
        }
      }
    }

    constexpr int idx = row * COLS + col;
    vc[idx] = _mm256_fmadd_ps(va, vb[col], vc[idx]);
  };
  for (int k = 0, kb = 0; k < K; ++k) {
    if (is_block_start(k, BLOCK_K)) {
      c10::ForcedUnroll<COLS>{}(load_scale_and_zeros, kb++);
    }
    c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
  }

  // store to C
  auto storec = [&](auto i) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;
    if constexpr (col % 2 == 0) {
      __m256i ci = vec::cvtfp32_bf16(vc[row * COLS + col], vc[row * COLS + col + 1]);
      _mm256_storeu_si256((__m256i*)(C + row * ldc + col * 8), ci);
    }
  };
  c10::ForcedUnroll<ROWS * COLS>{}(storec);
}

#endif

#if !defined(C10_MOBILE) && defined(__aarch64__)
#include <arm_neon.h>

inline float32x4x2_t load_as_float32x4x2(const Half* ptr) {
  float16x4x2_t f16_val = vld2_f16(reinterpret_cast<const float16_t *>(ptr));
  auto val_low = vcvt_f32_f16(f16_val.val[0]);
  auto val_high = vcvt_f32_f16(f16_val.val[1]);
  return {val_low, val_high};
}

inline void store_float32x4(Half* ptr, float32x4_t val) {
    vst1_f16(reinterpret_cast<float16_t*>(ptr), vcvt_f16_f32(val));
}

inline float32x4x2_t load_as_float32x4x2(const BFloat16* ptr) {
  int32x4_t shift = vdupq_n_s32(16);
  uint16x4x2_t u16_val = vld2_u16(reinterpret_cast<const uint16_t *>(ptr));
  uint32x4_t int_low = vmovl_u16(u16_val.val[0]);
  uint32x4_t int_high = vmovl_u16(u16_val.val[1]);
  return {vreinterpretq_f32_u32(vshlq_u32(int_low, shift)), vreinterpretq_f32_u32(vshlq_u32(int_high, shift))};
}

inline void store_float32x4(BFloat16* ptr, float32x4_t val) {
    int32x4_t shift = vdupq_n_s32(-16);
    uint32x4_t uint32_val = vshlq_u32(vreinterpretq_u32_f32(val), shift);
    vst1_u16(reinterpret_cast<uint16_t*>(ptr), vmovn_u32(uint32_val));
}

inline float32x4x2_t load_as_float32x4x2(const float* ptr) {
  return vld2q_f32(ptr);
}

inline void store_float32x4(float* ptr, float32x4_t val) {
    vst1q_f32(ptr, val);
}

template <int BLOCK_M, int BLOCK_N, typename T>
inline void tinygemm_kernel_(
    const T* RESTRICT A,
    const uint8_t* RESTRICT B,
    const T* RESTRICT ScaleAndZeros,
    T* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K,
    int BLOCK_K) {
  int16_t shift_vals[4] = {0, -4, -8, -12};
  int16x4_t shifts = vld1_s16(shift_vals);
  int16x4_t offs = vdup_n_s16(8);
  uint16x4_t mask = vdup_n_u16(0x0F);
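  // shifts/offs/mask implement the int4 unpack below: each uint16_t read from
  // B holds four packed int4 values; vshl_u16 with the negative shifts
  // {0, -4, -8, -12} right-shifts that word per lane, the 0x0F mask isolates
  // one nibble per lane, and subtracting 8 maps [0, 15] onto the signed int4
  // range [-8, 7].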
  for (const auto m : c10::irange(BLOCK_M)) {
    for (int n = 0; n < BLOCK_N; n += 16) {
      float32x4_t c_val[4];
      float32x4_t scales[4], zeros[4];
      c10::ForcedUnroll<4>{}([&](auto i) {
          c_val[i] = vdupq_n_f32(0.0);
      });
      for (const auto k : c10::irange(K)) {
        const auto a_val = vdupq_n_f32(static_cast<float>(A[m * lda + k]));
        if (is_block_start(k, BLOCK_K)) {
          int kb = k / BLOCK_K;
          c10::ForcedUnroll<4>{}([&](auto i) {
            auto scales_and_zeros = load_as_float32x4x2(ScaleAndZeros + kb * ldc * 2 + n * 2 + i * 8);
            scales[i] = scales_and_zeros.val[0];
            zeros[i] = scales_and_zeros.val[1];
          });
        }
        c10::ForcedUnroll<4>{}([&](auto i) {
          uint16_t b_pack = reinterpret_cast<const uint16_t*>(B + k * ldb + n / 2)[i];
          uint16x4_t b_masked = vand_u16(vshl_u16(vdup_n_u16(b_pack), shifts), mask);
          int16x4_t b_ints = vsub_s16(vreinterpret_s16_u16(b_masked), offs);
          float32x4_t b_vals = vcvtq_f32_s32(vmovl_s16(b_ints));
          b_vals = vaddq_f32(zeros[i], vmulq_f32(scales[i], b_vals));
          c_val[i] = vfmaq_f32(c_val[i], b_vals, a_val);
        });
      }
      c10::ForcedUnroll<4>{}([&](auto i) {
        store_float32x4(C + m * ldc + n + i * 4, c_val[i]);
      });
    }
  }
}

template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const Half* RESTRICT A,
    const uint8_t* RESTRICT B,
    const Half* RESTRICT ScaleAndZeros,
    Half* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K,
    int BLOCK_K) {
  tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, ScaleAndZeros, C, lda, ldb, ldc, K, BLOCK_K);
}

template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const BFloat16* RESTRICT A,
    const uint8_t* RESTRICT B,
    const BFloat16* RESTRICT ScaleAndZeros,
    BFloat16* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K,
    int BLOCK_K) {
  tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, ScaleAndZeros, C, lda, ldb, ldc, K, BLOCK_K);
}

template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const float* RESTRICT A,
    const uint8_t* RESTRICT B,
    const float* RESTRICT ScaleAndZeros,
    float* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K,
    int BLOCK_K) {
  tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, ScaleAndZeros, C, lda, ldb, ldc, K, BLOCK_K);
}
#endif

template<int BLOCK_N>
inline float convert_int4_to_float(const uint8_t* b, int n) {
  static constexpr float lut[16] = {
    -8.0f, -7.0f, -6.0f, -5.0f,
    -4.0f, -3.0f, -2.0f, -1.0f,
    0.0f, 1.0f, 2.0f, 3.0f,
    4.0f, 5.0f, 6.0f, 7.0f
  };
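  // lut maps an unsigned nibble to its signed int4 value (lut[i] == i - 8),
  // matching the vectorized de-quantization paths above.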
  int index;
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
  if constexpr (BLOCK_N == 64) {
    const int nb = n / BLOCK_N;
    n -= nb * BLOCK_N;
    if (n < 32) {
      auto val = b[nb * BLOCK_N / 2 + n];
      index = val & 0x0f;
    } else {
      auto val = b[nb * BLOCK_N / 2 + (n - 32)];
      index = val >> 4;
    }
  } else
#elif defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
  if constexpr (BLOCK_N == 32) {
    const int nb = n / BLOCK_N;
    n -= nb * BLOCK_N;
    if (n < 16) {
      auto val = b[nb * BLOCK_N / 2 + n];
      index = val & 0x0f;
    } else {
      auto val = b[nb * BLOCK_N / 2 + (n - 16)];
      index = val >> 4;
    }
  } else
#endif
  {
    const auto is_even = (n & 1) == 0;
    auto val = b[n / 2];
    index = is_even ? (val & 0x0F) : (val >> 4);
  }
  return lut[index];
}

// non-vectorized version
template <int BLOCK_M, int BLOCK_N, typename T>
inline void tinygemm_kernel(
    const T* RESTRICT A,
    const uint8_t* RESTRICT B,
    const T* RESTRICT ScaleAndZeros,
    T* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K,
    int BLOCK_K) {

  for (const auto m : c10::irange(BLOCK_M)) {
    for (const auto n : c10::irange(BLOCK_N)) {
      float c_val = 0;
      for (const auto k : c10::irange(K)) {
        int kb = k / BLOCK_K;
        const auto scale = static_cast<float>(ScaleAndZeros[kb * ldc * 2 + n * 2]);
        const auto zero = static_cast<float>(ScaleAndZeros[kb * ldc * 2 + n * 2 + 1]);
        const auto a_val = static_cast<float>(A[m * lda + k]);
        float b_val = convert_int4_to_float<BLOCK_N>(B + k * ldb, n);
        b_val = b_val * scale + zero;

        c_val += a_val * b_val;
      }
      C[m * ldc + n] = c_val;
    }
  }
}


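// Dispatch helpers: each call computes an {MB_SIZE, NB_SIZE} tile of C.
// A is addressed with lda = K, the packed B block with ldb = NB_SIZE / 2
// (two int4 values per byte), and C with ldc = N.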
#define LAUNCH_TINYGEMM_KERNEL(MB_SIZE, NB_SIZE)                 \
  tinygemm_kernel<MB_SIZE, NB_SIZE>(                             \
      A_ptr, B_ptr, S_ptr, C_ptr,                                \
      K, NB_SIZE / 2, N, K, BLOCK_K);

#define LAUNCH_TINYGEMM_NB_SIZE(MB_SIZE)                         \
  switch (nb_size) {                                             \
    case 16:                                                     \
      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 16);                       \
      break;                                                     \
    case 32:                                                     \
      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 32);                       \
      break;                                                     \
    case 48:                                                     \
      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 48);                       \
      break;                                                     \
    case 64:                                                     \
      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 64);                       \
      break;                                                     \
    default:                                                     \
      TORCH_CHECK(false, "Unsupported n block size: ", nb_size); \
      break;                                                     \
  }

// NB: int4 weight pack (with BLOCK_N 64)
//   weight (int32): {N/64, 64, K}
//   packed (uint8): {N/64, K, 32}
//
// 1. avx512 packed format:
//   When N is 64, to do 256-bit unpacking at a time, we pack Lane0 with Lane2,
//   Lane1 with Lane3 since we can only do shift on a 128-bit basis.
//
//   weight:
//     [Lane0] N0...15:  {a00, a01, a02, ...}
//     [Lane1] N16...31: {a10, a11, a12, ...}
//     [Lane2] N32...47: {a20, a21, a22, ...}
//     [Lane3] N48...63: {a30, a31, a32, ...}
//
//  packed:
//     [Lane02] N0...31:  {a20|a00, a21|a01, a22|a02, ...}
//     [Lane13] N32...63: {a30|a10, a31|a11, a32|a12, ...}
//
//  Note: when N is 16, 32 or 48, pack with 64-bit format.
//
// 2. avx2 packed format:
//   When N is 32, to do 128-bit unpacking at a time.
//
//   weight:
//     [Lane0] N0...15:  { a0,  a1,  a2, ...}
//     [Lane1] N16...31: {a16, a17, a18, ...}
//
//  packed:
//    [Lane01] N0...31: {a16|a0, a17|a1, a18|a2, ...}
//
//  Note: When N is 16, pack with 64-bit format.
//
// 3. non-vectorized packed format:
//   Do 64-bit unpacking at a time.
//
//   weight: {a0, a1, a2, a3, ..., a14, a15}
//   packed: {a1|a0, a3|a2, ..., a15|a14}
//
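// In the non-vectorized layout, the low nibble of packed byte n/2 therefore
// belongs to even column n and the high nibble to odd column n + 1, which is
// exactly how convert_int4_to_float() reads it back in the fallback path.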
void weight_to_int4pack_kernel(
    const Tensor& weight_packed,
    const Tensor& weight,
    int N, int K) {

  auto weight_packed_data = reinterpret_cast<uint8_t*>(weight_packed.data_ptr());
  const auto weight_data = weight.data_ptr<uint8_t>();

  // 64 for avx512 and 32 for avx2/non-vectorized
  constexpr int BLOCK_N = vec::Vectorized<float>::size() * 4;
  const int NB = (N + BLOCK_N - 1) / BLOCK_N;
  int K_div_2 = K / 2;

  // parallel on NB blocks
  at::parallel_for(0, NB, 0, [&](int begin, int end) {
    for (const auto i : c10::irange(begin, end)) {
      int nb_size = std::min(BLOCK_N, N - i * BLOCK_N);

      const uint8_t* src = weight_data + i * BLOCK_N * K_div_2;
      uint8_t* dst = weight_packed_data + i * K * BLOCK_N / 2;
      for (const auto k : c10::irange(K_div_2)) {
#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
        if (nb_size == BLOCK_N) {
          for (const auto d : c10::irange(16)) {
            uint8_t val0 = src[(d + 0) * K_div_2 + k];
            uint8_t val1 = src[(d + 16) * K_div_2 + k];
            uint8_t val2 = src[(d + 32) * K_div_2 + k];
            uint8_t val3 = src[(d + 48) * K_div_2 + k];

            uint8_t packed02_0 = (val2 & 0xF0) | ((val0 & 0xF0) >> 4);
            uint8_t packed13_0 = (val3 & 0xF0) | ((val1 & 0xF0) >> 4);
            uint8_t packed02_1 = ((val2 & 0xF) << 4) | (val0 & 0xF);
            uint8_t packed13_1 = ((val3 & 0xF) << 4) | (val1 & 0xF);

            dst[k * 2 * 32 + d] = packed02_0;
            dst[k * 2 * 32 + 16 + d] = packed13_0;
            dst[(k * 2 + 1) * 32 + d] = packed02_1;
            dst[(k * 2 + 1) * 32 + 16 + d] = packed13_1;
          }
        } else {
          // for nb_size 16, 32, 48
          for (int n = 0; n < nb_size; n += 2) {
            uint8_t val0 = src[n * K_div_2 + k];
            uint8_t val1 = src[n * K_div_2 + K_div_2 + k];

            uint8_t packed_0 = ((val1 & 0xF0)) | ((val0 & 0xF0) >> 4);
            uint8_t packed_1 = ((val1 & 0xF) << 4) | (val0 & 0xF);
            dst[k * 2 * nb_size / 2 + n / 2] = packed_0;
            dst[(k * 2 + 1) * nb_size / 2 + n / 2] = packed_1;
          }
        }
#elif defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
        if (nb_size == BLOCK_N) {
          // for nb_size 32
          for (const auto d : c10::irange(16)) {
            uint8_t val0 = src[(d + 0) * K_div_2 + k];
            uint8_t val1 = src[(d + 16) * K_div_2 + k];

            uint8_t packed01_0 = ((val1 & 0xF0) | ((val0 & 0xF0) >> 4));
            uint8_t packed01_1 = ((val1 & 0xF) << 4) | (val0 & 0xF);
            dst[k * 2 * 16 + d] = packed01_0;
            dst[(k * 2 + 1) * 16 + d] = packed01_1;
          }
        } else {
          // for nb_size 16
          for (int n = 0; n < nb_size; n += 2) {
            int32_t val0 = src[n * K_div_2 + k];
            int32_t val1 = src[n * K_div_2 + K_div_2 + k];

            uint8_t packed_0 = ((val1 & 0xF0)) | ((val0 & 0xF0) >> 4);
            uint8_t packed_1 = ((val1 & 0xF) << 4) | (val0 & 0xF);
            dst[k * 2 * nb_size / 2 + n / 2] = packed_0;
            dst[(k * 2 + 1) * nb_size / 2 + n / 2] = packed_1;
          }
        }
#else
        for (int n = 0; n < nb_size; n += 2) {
          uint8_t val0 = src[n * K_div_2 + k];
          uint8_t val1 = src[n * K_div_2 + K_div_2 + k];

          uint8_t packed_0 = ((val1 & 0xF0)) | ((val0 & 0xF0) >> 4);
          uint8_t packed_1 = ((val1 & 0xF) << 4) | (val0 & 0xF);
          dst[k * 2 * nb_size / 2 + n / 2] = packed_0;
          dst[(k * 2 + 1) * nb_size / 2 + n / 2] = packed_1;
        }
#endif
      }
    }
  });
}

template<typename T>
void int4pack_mm_kernel_(
    const Tensor& C,
    const Tensor& A,
    const Tensor& B,
    int qGroupSize,
    const Tensor& qScaleAndZeros,
    int N, int K) {

  const auto* A_data = A.const_data_ptr<T>();
  const auto* B_data = reinterpret_cast<const uint8_t*>(B.const_data_ptr());
  auto* C_data = C.data_ptr<T>();
  const auto* S_data = qScaleAndZeros.const_data_ptr<T>();

  int M = A.size(0);

  constexpr int BLOCK_M = 4;
  // 64 for avx512 and 32 for avx2/non-vectorized
  constexpr int BLOCK_N = vec::Vectorized<float>::size() * 4;
  // 32, 64, 128, 256
  const int BLOCK_K = qGroupSize;

  const int MB = (M + BLOCK_M - 1) / BLOCK_M;
  const int NB = (N + BLOCK_N - 1) / BLOCK_N;

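  // Parallelize over the flattened MB x NB grid of output tiles; each task
  // walks its contiguous range of (mb, nb) indices and computes one
  // BLOCK_M x BLOCK_N tile of C per step.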
  at::parallel_for(0, MB * NB, 0, [&](int begin, int end) {
    int mb{0}, nb{0};
    data_index_init(begin, mb, MB, nb, NB);

    for (C10_UNUSED const auto i : c10::irange(begin, end)) {
      int mb_start = mb * BLOCK_M;
      int mb_size = std::min(BLOCK_M, M - mb_start);
      int nb_start = nb * BLOCK_N;
      int nb_size = std::min(BLOCK_N, N - nb_start);

      const auto* A_ptr = A_data + mb_start * K;
      const auto* B_ptr = B_data + nb_start * K / 2;
      const auto* S_ptr = S_data + nb_start * 2;
      auto* C_ptr = C_data + mb_start * N + nb_start;

      switch (mb_size) {
        case 1:
          LAUNCH_TINYGEMM_NB_SIZE(1);
          break;
        case 2:
          LAUNCH_TINYGEMM_NB_SIZE(2);
          break;
        case 3:
          LAUNCH_TINYGEMM_NB_SIZE(3);
          break;
        case 4:
          LAUNCH_TINYGEMM_NB_SIZE(4);
          break;
        default:
          TORCH_CHECK(false, "Unsupported m block size: ", mb_size);
      }

      // move to the next index
      data_index_step(mb, MB, nb, NB);
    }
  });
}

void int4pack_mm_kernel(
    const Tensor& C,
    const Tensor& A,
    const Tensor& B,
    int qGroupSize,
    const Tensor& qScaleAndZeros,
    int N, int K) {
  if (C.scalar_type() == kBFloat16) {
    int4pack_mm_kernel_<BFloat16>(C, A, B, qGroupSize, qScaleAndZeros, N, K);
  } else if (C.scalar_type() == kHalf) {
    int4pack_mm_kernel_<Half>(C, A, B, qGroupSize, qScaleAndZeros, N, K);
  } else {
    int4pack_mm_kernel_<float>(C, A, B, qGroupSize, qScaleAndZeros, N, K);
  }
}

} // anonymous namespace

ALSO_REGISTER_AVX512_DISPATCH(weight_to_int4pack_stub, &weight_to_int4pack_kernel);
ALSO_REGISTER_AVX512_DISPATCH(int4pack_mm_stub, &int4pack_mm_kernel);

} // at::native
C10_DIAGNOSTIC_POP()