#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>

#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/int_mm_kernel.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>
#include <c10/util/Unroll.h>

#if (defined(_WIN32) || defined(_WIN64))
#define RESTRICT __restrict
#else
#define RESTRICT __restrict__
#endif

namespace at::native {

namespace {

#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)

// A block : {BLOCK_M, BLOCK_K}, lda = K
// B block : {BLOCK_K, BLOCK_N}, ldb = K
// C block : {BLOCK_M, BLOCK_N}, ldc = N
//
// scales block: {BLOCK_N}
//
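// Each kernel instance computes C[m][n] = sum_k A[m][k] * (B[n][k] * scales[n])
// for an output tile of at most BLOCK_M x BLOCK_N, with both A and B read
// contiguously along K.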
template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const BFloat16* RESTRICT A,
    const int8_t* RESTRICT B,
    const BFloat16* RESTRICT scales,
    BFloat16* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K) {

  constexpr int ROWS = BLOCK_M;
  constexpr int COLS = BLOCK_N;

  const int PREFETCH_SIZE_K = 16 * 4;

  __m512 va;
  __m512 vb[COLS];
  __m512 vc[ROWS * COLS];
  __m512 scale[COLS];

  auto load_scale = [&](int i) {
    float ss = static_cast<float>(scales[i]);
    scale[i] = _mm512_set1_ps(ss);
  };
  c10::ForcedUnroll<COLS>{}(load_scale);

  auto loadc = [&](auto i) {
    vc[i] = _mm512_setzero_ps();
  };
  c10::ForcedUnroll<ROWS * COLS>{}(loadc);

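  // For each accumulator index i = row * COLS + col: load a 16-element bf16
  // slice of A once per row (when col == 0), load and dequantize a 16-element
  // int8 slice of B once per column (when row == 0), then accumulate with FMA.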
  auto compute = [&](auto i, int k) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;

    if constexpr (col == 0) {
      __m256i a16 = _mm256_load_si256((__m256i*)(A + row * lda + k));
      if (k + PREFETCH_SIZE_K < K) {
        _mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0);
      }
      vec::cvtbf16_fp32(a16, va);
    }

    if constexpr (row == 0) {
      __m128i b8 = _mm_load_si128((__m128i*)(B + col * ldb + k));
      if (k + PREFETCH_SIZE_K < K) {
        _mm_prefetch(B + col * ldb + k + PREFETCH_SIZE_K, _MM_HINT_T0);
      }
      __m512i b32 = _mm512_cvtepi8_epi32(b8);
      vb[col] = _mm512_cvtepi32_ps(b32);
      vb[col] = _mm512_mul_ps(vb[col], scale[col]);
    }

    constexpr int idx = row * COLS + col;
    vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);
  };

  for (int k = 0; k < K; k += 16) {
    c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
  }

  auto storec = [&](auto i) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;
    C[row * ldc + col] = static_cast<BFloat16>(_mm512_reduce_add_ps(vc[i]));
  };
  c10::ForcedUnroll<ROWS * COLS>{}(storec);
}

#elif defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

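// Horizontal sum of all 8 lanes of an __m256: fold the upper 128-bit half onto
// the lower half, then finish the reduction with two in-lane shuffles.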
static inline float _mm256_reduce_add_ps(__m256& v) {
  __m256 v1 = _mm256_permute2f128_ps(v, v, 0x1);
  v = _mm256_add_ps(v, v1);
  v1 = _mm256_shuffle_ps(v, v, 0x4E);
  v = _mm256_add_ps(v, v1);
  v1 = _mm256_shuffle_ps(v, v, 0xB1);
  v = _mm256_add_ps(v, v1);
  return _mm256_cvtss_f32(v);
}

template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const BFloat16* RESTRICT A,
    const int8_t* RESTRICT B,
    const BFloat16* RESTRICT scales,
    BFloat16* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K) {

  constexpr int ROWS = BLOCK_M;
  constexpr int COLS = BLOCK_N;

  const int PREFETCH_SIZE_K = 16 * 4;

  __m256 va;
  __m256 vb[COLS];
  __m256 vc[ROWS * COLS];
  __m256 scale[COLS];

  auto load_scale = [&](int i) {
    float ss = static_cast<float>(scales[i]);
    scale[i] = _mm256_set1_ps(ss);
  };
  c10::ForcedUnroll<COLS>{}(load_scale);

  auto loadc = [&](auto i) {
    vc[i] = _mm256_setzero_ps();
  };
  c10::ForcedUnroll<ROWS * COLS>{}(loadc);

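  // Same structure as the AVX512 path above, but with 8-wide fp32 vectors:
  // A is loaded 8 bf16 at a time and B is loaded 8 int8 at a time, then
  // dequantized before the FMA.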
  auto compute = [&](auto i, int k) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;

    if constexpr (col == 0) {
      __m128i a16 = _mm_load_si128((__m128i*)(A + row * lda + k));
      if (k + PREFETCH_SIZE_K < K) {
        _mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0);
      }
      vec::cvtbf16_fp32(a16, va);
    }

    if constexpr (row == 0) {
      __m128i b8 = _mm_loadu_si64((__m128i*)(B + col * ldb + k));
      if (k + PREFETCH_SIZE_K < K) {
        _mm_prefetch(B + col * ldb + k + PREFETCH_SIZE_K, _MM_HINT_T0);
      }
      __m256i b32 = _mm256_cvtepi8_epi32(b8);
      vb[col] = _mm256_cvtepi32_ps(b32);
      vb[col] = _mm256_mul_ps(vb[col], scale[col]);
    }

    constexpr int idx = row * COLS + col;
    vc[idx] = _mm256_fmadd_ps(va, vb[col], vc[idx]);
  };

  for (int k = 0; k < K; k += 8) {
    c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
  }

  auto storec = [&](auto i) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;
    C[row * ldc + col] = static_cast<BFloat16>(_mm256_reduce_add_ps(vc[i]));
  };
  c10::ForcedUnroll<ROWS * COLS>{}(storec);
}

#endif

#if !defined(C10_MOBILE) && defined(__aarch64__)
#include <arm_neon.h>

inline float reduce(float32x4_t x) {
  auto sum = vpaddq_f32(x, x);
  return vgetq_lane_f32(vpaddq_f32(sum, sum), 0);
}

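// Widening loads: convert 4 or 8 input elements (Half, BFloat16, or float)
// to one or two float32x4_t vectors.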
inline float32x4x2_t load_as_float32x4x2(const Half* ptr) {
  float16x8_t f16_val = vld1q_f16(reinterpret_cast<const float16_t *>(ptr));
  auto val_low = vcvt_f32_f16(vget_low_f16(f16_val));
  auto val_high = vcvt_f32_f16(vget_high_f16(f16_val));
  return {val_low, val_high};
}

inline float32x4_t load_as_float32x4(const Half* ptr) {
  return vcvt_f32_f16(vld1_f16(reinterpret_cast<const float16_t *>(ptr)));
}

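// BFloat16 is the upper 16 bits of an IEEE float32, so widening to fp32 is a
// zero-extend to 32 bits followed by a 16-bit left shift.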
inline float32x4x2_t load_as_float32x4x2(const BFloat16* ptr) {
  int32x4_t shift = vdupq_n_s32(16);
  uint16x8_t u16_val = vld1q_u16(reinterpret_cast<const uint16_t *>(ptr));
  uint32x4_t int_low = vmovl_u16(vget_low_u16(u16_val));
  uint32x4_t int_high = vmovl_u16(vget_high_u16(u16_val));
  return {vreinterpretq_f32_u32(vshlq_u32(int_low, shift)), vreinterpretq_f32_u32(vshlq_u32(int_high, shift))};
}

inline float32x4_t load_as_float32x4(const BFloat16* ptr) {
  int32x4_t shift = vdupq_n_s32(16);
  uint32x4_t as_int = vmovl_u16(vld1_u16(reinterpret_cast<const uint16_t *>(ptr)));
  return vreinterpretq_f32_u32(vshlq_u32(as_int, shift));
}

inline float32x4_t load_as_float32x4(const float* ptr) {
  return vld1q_f32(ptr);
}

inline float32x4x2_t load_as_float32x4x2(const float* ptr) {
  return {vld1q_f32(ptr), vld1q_f32(ptr + 4)};
}

template <int BLOCK_M, int BLOCK_N, typename T>
inline void tinygemm_kernel_(
    const T* RESTRICT A,
    const int8_t* RESTRICT B,
    const T* RESTRICT scales,
    T* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K) {

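  // Accumulate one output row at a time: widen 8 int8 weights to two fp32
  // vectors, FMA them against the matching 8 activations, then horizontally
  // reduce each accumulator and apply the per-column scale after the K loop.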
  for (const auto m : c10::irange(BLOCK_M)) {
    float32x4_t c_val[BLOCK_N];
    c10::ForcedUnroll<BLOCK_N>{}([&](auto i) {
      c_val[i] = vdupq_n_f32(0.0);
    });
    for (int k = 0; k < K; k += 8) {
      auto a_val = load_as_float32x4x2(A + m * lda + k);
      c10::ForcedUnroll<BLOCK_N>{}([&](auto i) {
        int16x8_t b_val = vmovl_s8(vld1_s8(B + i * ldb + k));
        auto b_val_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(b_val)));
        auto b_val_high = vcvtq_f32_s32(vmovl_s16(vget_high_s16(b_val)));
        c_val[i] = vfmaq_f32(c_val[i], a_val.val[1], b_val_high);
        c_val[i] = vfmaq_f32(c_val[i], a_val.val[0], b_val_low);
      });
    }

#if __OPTIMIZE__
    float32x4_t scale_val = load_as_float32x4(scales);
    c10::ForcedUnroll<BLOCK_N>{}([&](auto i) {
      C[m * ldc + i] = reduce(c_val[i]) * vgetq_lane_f32(scale_val, i);
    });
#else
    // Workaround for GCC's inability to infer the lane index at compile time.
    // See https://github.com/pytorch/pytorch/issues/126283
    c10::ForcedUnroll<BLOCK_N>{}([&](auto i) {
      C[m * ldc + i] = reduce(c_val[i]) * float(scales[i]);
    });
#endif
  }
}

template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const Half* RESTRICT A,
    const int8_t* RESTRICT B,
    const Half* RESTRICT scales,
    Half* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K) {
  tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, scales, C, lda, ldb, ldc, K);
}

template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const BFloat16* RESTRICT A,
    const int8_t* RESTRICT B,
    const BFloat16* RESTRICT scales,
    BFloat16* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K) {
  tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, scales, C, lda, ldb, ldc, K);
}

template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const float* RESTRICT A,
    const int8_t* RESTRICT B,
    const float* RESTRICT scales,
    float* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K) {
  tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, scales, C, lda, ldb, ldc, K);
}
#endif

// non-vectorized version
template <int BLOCK_M, int BLOCK_N, typename T>
inline void tinygemm_kernel(
    const T* RESTRICT A,
    const int8_t* RESTRICT B,
    const T* RESTRICT scales,
    T* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K) {

  for (const auto m : c10::irange(BLOCK_M)) {
    for (const auto n : c10::irange(BLOCK_N)) {
      float c_val = 0;
      float scale_val = static_cast<float>(scales[n]);
      for (const auto k : c10::irange(K)) {
        float a_val = static_cast<float>(A[m * lda + k]);
        float b_val = static_cast<float>(B[n * ldb + k]);
        c_val += a_val * (b_val * scale_val);
      }
      C[m * ldc + n] = c_val;
    }
  }
}

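// Map the runtime residual block sizes (1..4) onto the compile-time
// BLOCK_M / BLOCK_N template parameters of tinygemm_kernel.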
#define LAUNCH_TINYGEMM_KERNEL(MB_SIZE, NB_SIZE)                 \
  tinygemm_kernel<MB_SIZE, NB_SIZE>(                             \
      A_ptr, B_ptr, S_ptr, C_ptr,                                \
      K, K, N, K);

#define LAUNCH_TINYGEMM_NB_SIZE(MB_SIZE)                         \
  switch (nb_size) {                                             \
    case 1:                                                      \
      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 1);                        \
      break;                                                     \
    case 2:                                                      \
      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 2);                        \
      break;                                                     \
    case 3:                                                      \
      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 3);                        \
      break;                                                     \
    case 4:                                                      \
      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 4);                        \
      break;                                                     \
    default:                                                     \
      TORCH_CHECK(false, "Unsupported n block size: ", nb_size); \
      break;                                                     \
  }

template<typename T>
void int8pack_mm_kernel_(
    const Tensor& C,
    const Tensor& A,
    const Tensor& B,
    const Tensor& scales) {

  const auto* A_data = A.const_data_ptr<T>();
  const auto* B_data = B.const_data_ptr<int8_t>();
  auto* C_data = C.data_ptr<T>();
  const auto* S_data = scales.const_data_ptr<T>();

  int M = A.size(0);
  int N = B.size(0);
  int K = A.size(1);

  constexpr int BLOCK_M = 4;
  constexpr int BLOCK_N = 4;

  const int MB = (M + BLOCK_M - 1) / BLOCK_M;
  const int NB = (N + BLOCK_N - 1) / BLOCK_N;

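  // Parallelize over the MB x NB grid of output tiles; each iteration of the
  // flattened loop handles one {mb, nb} tile of at most BLOCK_M x BLOCK_N
  // elements of C.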
  at::parallel_for(0, MB * NB, 0, [&](int begin, int end) {
    int mb{0}, nb{0};
    data_index_init(begin, mb, MB, nb, NB);

    for (const auto i : c10::irange(begin, end)) {
      (void)i;

      int mb_start = mb * BLOCK_M;
      int mb_size = std::min(BLOCK_M, M - mb_start);
      int nb_start = nb * BLOCK_N;
      int nb_size = std::min(BLOCK_N, N - nb_start);

      const auto* A_ptr = A_data + mb_start * K;
      const auto* B_ptr = B_data + nb_start * K;
      const auto* S_ptr = S_data + nb_start;
      auto* C_ptr = C_data + mb_start * N + nb_start;

      switch (mb_size) {
        case 1:
          LAUNCH_TINYGEMM_NB_SIZE(1);
          break;
        case 2:
          LAUNCH_TINYGEMM_NB_SIZE(2);
          break;
        case 3:
          LAUNCH_TINYGEMM_NB_SIZE(3);
          break;
        case 4:
          LAUNCH_TINYGEMM_NB_SIZE(4);
          break;
        default:
          TORCH_CHECK(false, "Unsupported m block size: ", mb_size);
      }

      // move to the next index
      data_index_step(mb, MB, nb, NB);
    }
  });
}

void int8pack_mm_kernel(
    const Tensor& C,
    const Tensor& A,
    const Tensor& B,
    const Tensor& scales) {
  if (C.dtype() == kHalf) {
    int8pack_mm_kernel_<Half>(C, A, B, scales);
  } else if (C.dtype() == kBFloat16) {
    int8pack_mm_kernel_<BFloat16>(C, A, B, scales);
  } else {
    int8pack_mm_kernel_<float>(C, A, B, scales);
  }
}

} // anonymous namespace

ALSO_REGISTER_AVX512_DISPATCH(int8pack_mm_stub, &int8pack_mm_kernel);

} // at::native