#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>

#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/int_mm_kernel.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>
#include <c10/util/Unroll.h>

#if (defined(_WIN32) || defined(_WIN64))
#define RESTRICT __restrict
#else
#define RESTRICT __restrict__
#endif

namespace at::native {

namespace {

#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)

// A block : {BLOCK_M, BLOCK_K}, lda = K
// B block : {BLOCK_K, BLOCK_N}, ldb = K
// C block : {BLOCK_M, BLOCK_N}, ldc = N
//
// scales block: {BLOCK_N}
//
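// Each kernel instance computes a BLOCK_M x BLOCK_N tile of C as
//   C[m][n] = sum_k A[m][k] * (B[n][k] * scales[n])
// where B holds the int8-quantized weight in [N, K] layout (hence ldb == K)
// and scales holds one dequantization factor per output column.
//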
template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const BFloat16* RESTRICT A,
    const int8_t* RESTRICT B,
    const BFloat16* RESTRICT scales,
    BFloat16* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K) {

  constexpr int ROWS = BLOCK_M;
  constexpr int COLS = BLOCK_N;

  const int PREFETCH_SIZE_K = 16 * 4;

  __m512 va;
  __m512 vb[COLS];
  __m512 vc[ROWS * COLS];
  __m512 scale[COLS];

  auto load_scale = [&](int i) {
    float ss = static_cast<float>(scales[i]);
    scale[i] = _mm512_set1_ps(ss);
  };
  c10::ForcedUnroll<COLS>{}(load_scale);

  auto loadc = [&](auto i) {
    vc[i] = _mm512_setzero_ps();
  };
  c10::ForcedUnroll<ROWS * COLS>{}(loadc);

  auto compute = [&](auto i, int k) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;

    if constexpr (col == 0) {
      __m256i a16 = _mm256_load_si256((__m256i*)(A + row * lda + k));
      if (k + PREFETCH_SIZE_K < K) {
        _mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0);
      }
      vec::cvtbf16_fp32(a16, va);
    }

    if constexpr (row == 0) {
      __m128i b8 = _mm_load_si128((__m128i*)(B + col * ldb + k));
      if (k + PREFETCH_SIZE_K < K) {
        _mm_prefetch(B + col * ldb + k + PREFETCH_SIZE_K, _MM_HINT_T0);
      }
      __m512i b32 = _mm512_cvtepi8_epi32(b8);
      vb[col] = _mm512_cvtepi32_ps(b32);
      vb[col] = _mm512_mul_ps(vb[col], scale[col]);
    }

    constexpr int idx = row * COLS + col;
    vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);
  };

  for (int k = 0; k < K; k += 16) {
    c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
  }

  auto storec = [&](auto i) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;
    C[row * ldc + col] = static_cast<BFloat16>(_mm512_reduce_add_ps(vc[i]));
  };
  c10::ForcedUnroll<ROWS * COLS>{}(storec);
}

#elif defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

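// Horizontal sum of the eight fp32 lanes of an __m256: fold the upper and
// lower 128-bit halves, then the 64-bit pairs, then adjacent lanes.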
static inline float _mm256_reduce_add_ps(__m256& v) {
  __m256 v1 = _mm256_permute2f128_ps(v, v, 0x1);
  v = _mm256_add_ps(v, v1);
  v1 = _mm256_shuffle_ps(v, v, 0x4E);
  v = _mm256_add_ps(v, v1);
  v1 = _mm256_shuffle_ps(v, v, 0xB1);
  v = _mm256_add_ps(v, v1);
  return _mm256_cvtss_f32(v);
}

template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const BFloat16* RESTRICT A,
    const int8_t* RESTRICT B,
    const BFloat16* RESTRICT scales,
    BFloat16* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K) {

  constexpr int ROWS = BLOCK_M;
  constexpr int COLS = BLOCK_N;

  const int PREFETCH_SIZE_K = 16 * 4;

  __m256 va;
  __m256 vb[COLS];
  __m256 vc[ROWS * COLS];
  __m256 scale[COLS];

  auto load_scale = [&](int i) {
    float ss = static_cast<float>(scales[i]);
    scale[i] = _mm256_set1_ps(ss);
  };
  c10::ForcedUnroll<COLS>{}(load_scale);

  auto loadc = [&](auto i) {
    vc[i] = _mm256_setzero_ps();
  };
  c10::ForcedUnroll<ROWS * COLS>{}(loadc);

  auto compute = [&](auto i, int k) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;

    if constexpr (col == 0) {
      __m128i a16 = _mm_load_si128((__m128i*)(A + row * lda + k));
      if (k + PREFETCH_SIZE_K < K) {
        _mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0);
      }
      vec::cvtbf16_fp32(a16, va);
    }

    if constexpr (row == 0) {
      __m128i b8 = _mm_loadu_si64((__m128i*)(B + col * ldb + k));
      if (k + PREFETCH_SIZE_K < K) {
        _mm_prefetch(B + col * ldb + k + PREFETCH_SIZE_K, _MM_HINT_T0);
      }
      __m256i b32 = _mm256_cvtepi8_epi32(b8);
      vb[col] = _mm256_cvtepi32_ps(b32);
      vb[col] = _mm256_mul_ps(vb[col], scale[col]);
    }

    constexpr int idx = row * COLS + col;
    vc[idx] = _mm256_fmadd_ps(va, vb[col], vc[idx]);
  };

  for (int k = 0; k < K; k += 8) {
    c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
  }

  auto storec = [&](auto i) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;
    C[row * ldc + col] = static_cast<BFloat16>(_mm256_reduce_add_ps(vc[i]));
  };
  c10::ForcedUnroll<ROWS * COLS>{}(storec);
}

#endif

#if !defined(C10_MOBILE) && defined(__aarch64__)
#include <arm_neon.h>

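// Horizontal sum of the four fp32 lanes of a NEON register via two pairwise adds.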
inline float reduce(float32x4_t x) {
  auto sum = vpaddq_f32(x, x);
  return vgetq_lane_f32(vpaddq_f32(sum, sum), 0);
}

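// Widening loads for the NEON kernel below: each overload converts 8
// (float32x4x2_t) or 4 (float32x4_t) elements of the source dtype to fp32.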
inline float32x4x2_t load_as_float32x4x2(const Half* ptr) {
  float16x8_t f16_val = vld1q_f16(reinterpret_cast<const float16_t *>(ptr));
  auto val_low = vcvt_f32_f16(vget_low_f16(f16_val));
  auto val_high = vcvt_f32_f16(vget_high_f16(f16_val));
  return {val_low, val_high};
}

inline float32x4_t load_as_float32x4(const Half* ptr) {
  return vcvt_f32_f16(vld1_f16(reinterpret_cast<const float16_t *>(ptr)));
}

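// BFloat16 is the upper 16 bits of an IEEE fp32 value, so widening amounts to
// a zero-extend to 32 bits followed by a left shift of 16.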
inline float32x4x2_t load_as_float32x4x2(const BFloat16* ptr) {
  int32x4_t shift = vdupq_n_s32(16);
  uint16x8_t u16_val = vld1q_u16(reinterpret_cast<const uint16_t *>(ptr));
  uint32x4_t int_low = vmovl_u16(vget_low_u16(u16_val));
  uint32x4_t int_high = vmovl_u16(vget_high_u16(u16_val));
  return {vreinterpretq_f32_u32(vshlq_u32(int_low, shift)), vreinterpretq_f32_u32(vshlq_u32(int_high, shift))};
}

inline float32x4_t load_as_float32x4(const BFloat16* ptr) {
  int32x4_t shift = vdupq_n_s32(16);
  uint32x4_t as_int = vmovl_u16(vld1_u16(reinterpret_cast<const uint16_t *>(ptr)));
  return vreinterpretq_f32_u32(vshlq_u32(as_int, shift));
}

inline float32x4_t load_as_float32x4(const float* ptr) {
  return vld1q_f32(ptr);
}

inline float32x4x2_t load_as_float32x4x2(const float* ptr) {
  return {vld1q_f32(ptr), vld1q_f32(ptr + 4)};
}

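// Generic NEON tile kernel: for each of the BLOCK_M rows, accumulate BLOCK_N
// dot products in fp32 over K in chunks of 8, then reduce each accumulator and
// apply the per-column scale.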
template <int BLOCK_M, int BLOCK_N, typename T>
inline void tinygemm_kernel_(
    const T* RESTRICT A,
    const int8_t* RESTRICT B,
    const T* RESTRICT scales,
    T* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K) {

  for (const auto m : c10::irange(BLOCK_M)) {
    float32x4_t c_val[BLOCK_N];
    c10::ForcedUnroll<BLOCK_N>{}([&](auto i) {
      c_val[i] = vdupq_n_f32(0.0);
    });
    for (int k = 0; k < K; k += 8) {
      auto a_val = load_as_float32x4x2(A + m * lda + k);
      c10::ForcedUnroll<BLOCK_N>{}([&](auto i) {
        int16x8_t b_val = vmovl_s8(vld1_s8(B + i * ldb + k));
        auto b_val_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(b_val)));
        auto b_val_high = vcvtq_f32_s32(vmovl_s16(vget_high_s16(b_val)));
        c_val[i] = vfmaq_f32(c_val[i], a_val.val[1], b_val_high);
        c_val[i] = vfmaq_f32(c_val[i], a_val.val[0], b_val_low);
      });
    }

#if __OPTIMIZE__
    float32x4_t scale_val = load_as_float32x4(scales);
    c10::ForcedUnroll<BLOCK_N>{}([&](auto i) {
      C[m * ldc + i] = reduce(c_val[i]) * vgetq_lane_f32(scale_val, i);
    });
#else
    // Workaround for GCC's inability to infer the lane index at compile time
    // See https://github.com/pytorch/pytorch/issues/126283
    c10::ForcedUnroll<BLOCK_N>{}([&](auto i) {
      C[m * ldc + i] = reduce(c_val[i]) * float(scales[i]);
    });
#endif
  }
}

template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const Half* RESTRICT A,
    const int8_t* RESTRICT B,
    const Half* RESTRICT scales,
    Half* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K) {
  tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, scales, C, lda, ldb, ldc, K);
}

template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const BFloat16* RESTRICT A,
    const int8_t* RESTRICT B,
    const BFloat16* RESTRICT scales,
    BFloat16* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K) {
  tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, scales, C, lda, ldb, ldc, K);
}

template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const float* RESTRICT A,
    const int8_t* RESTRICT B,
    const float* RESTRICT scales,
    float* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K) {
  tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, scales, C, lda, ldb, ldc, K);
}
#endif

// Non-vectorized fallback, used when no SIMD specialization above matches the
// dtype or the build's ISA.
template <int BLOCK_M, int BLOCK_N, typename T>
inline void tinygemm_kernel(
    const T* RESTRICT A,
    const int8_t* RESTRICT B,
    const T* RESTRICT scales,
    T* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K) {

  for (const auto m : c10::irange(BLOCK_M)) {
    for (const auto n : c10::irange(BLOCK_N)) {
      float c_val = 0;
      float scale_val = static_cast<float>(scales[n]);
      for (const auto k : c10::irange(K)) {
        float a_val = static_cast<float>(A[m * lda + k]);
        float b_val = static_cast<float>(B[n * ldb + k]);
        c_val += a_val * (b_val * scale_val);
      }
      C[m * ldc + n] = c_val;
    }
  }
}

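// A is indexed [M, K] row-major and B is the packed [N, K] int8 weight, so
// both use K as their leading dimension, while C uses N; hence the
// (K, K, N, K) argument order below.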
#define LAUNCH_TINYGEMM_KERNEL(MB_SIZE, NB_SIZE)                 \
  tinygemm_kernel<MB_SIZE, NB_SIZE>(                             \
      A_ptr, B_ptr, S_ptr, C_ptr,                                \
      K, K, N, K);

#define LAUNCH_TINYGEMM_NB_SIZE(MB_SIZE)                         \
  switch (nb_size) {                                             \
    case 1:                                                      \
      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 1);                        \
      break;                                                     \
    case 2:                                                      \
      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 2);                        \
      break;                                                     \
    case 3:                                                      \
      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 3);                        \
      break;                                                     \
    case 4:                                                      \
      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 4);                        \
      break;                                                     \
    default:                                                     \
      TORCH_CHECK(false, "Unsupported n block size: ", nb_size); \
      break;                                                     \
  }

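// Main driver: partitions C into BLOCK_M x BLOCK_N tiles, parallelizes over
// all MB x NB tiles with at::parallel_for, and launches the tile kernel that
// matches the residual tile size at the M/N boundaries.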
template<typename T>
void int8pack_mm_kernel_(
    const Tensor& C,
    const Tensor& A,
    const Tensor& B,
    const Tensor& scales) {

  const auto* A_data = A.const_data_ptr<T>();
  const auto* B_data = B.const_data_ptr<int8_t>();
  auto* C_data = C.data_ptr<T>();
  const auto* S_data = scales.const_data_ptr<T>();

  int M = A.size(0);
  int N = B.size(0);
  int K = A.size(1);

  constexpr int BLOCK_M = 4;
  constexpr int BLOCK_N = 4;

  const int MB = (M + BLOCK_M - 1) / BLOCK_M;
  const int NB = (N + BLOCK_N - 1) / BLOCK_N;

  at::parallel_for(0, MB * NB, 0, [&](int begin, int end) {
    int mb{0}, nb{0};
    data_index_init(begin, mb, MB, nb, NB);

    for (const auto i : c10::irange(begin, end)) {
      (void)i;

      int mb_start = mb * BLOCK_M;
      int mb_size = std::min(BLOCK_M, M - mb_start);
      int nb_start = nb * BLOCK_N;
      int nb_size = std::min(BLOCK_N, N - nb_start);

      const auto* A_ptr = A_data + mb_start * K;
      const auto* B_ptr = B_data + nb_start * K;
      const auto* S_ptr = S_data + nb_start;
      auto* C_ptr = C_data + mb_start * N + nb_start;

      switch (mb_size) {
        case 1:
          LAUNCH_TINYGEMM_NB_SIZE(1);
          break;
        case 2:
          LAUNCH_TINYGEMM_NB_SIZE(2);
          break;
        case 3:
          LAUNCH_TINYGEMM_NB_SIZE(3);
          break;
        case 4:
          LAUNCH_TINYGEMM_NB_SIZE(4);
          break;
        default:
          TORCH_CHECK(false, "Unsupported m block size: ", mb_size);
      }

      // move to the next index
      data_index_step(mb, MB, nb, NB);
    }
  });
}

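// Entry point registered below: dispatches on the output dtype (Half,
// BFloat16, or float) to the templated implementation above.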
void int8pack_mm_kernel(
    const Tensor& C,
    const Tensor& A,
    const Tensor& B,
    const Tensor& scales) {
  if (C.dtype() == kHalf) {
    int8pack_mm_kernel_<Half>(C, A, B, scales);
  } else if (C.dtype() == kBFloat16) {
    int8pack_mm_kernel_<BFloat16>(C, A, B, scales);
  } else {
    int8pack_mm_kernel_<float>(C, A, B, scales);
  }
}

} // anonymous namespace

ALSO_REGISTER_AVX512_DISPATCH(int8pack_mm_stub, &int8pack_mm_kernel);
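// Usage sketch (context only, not defined in this file; the caller name is an
// assumption about the dispatching ATen op rather than something this file
// establishes):
//   at::Tensor C = at::_weight_int8pack_mm(A, B, scales);  // assumed caller
// with A: [M, K] Half/BFloat16/float, B: [N, K] int8, scales: [N], C: [M, N].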

} // at::native