#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/CPUBlas.h>
#include <ATen/native/mkl/LinearAlgebra.h>
#include <ATen/native/mkldnn/Matmul.h>
#include <ATen/Config.h>

#include <c10/util/SmallBuffer.h>
#include <c10/util/irange.h>

#include <climits>

#if AT_BUILD_WITH_BLAS()
#if C10_IOS
#include <Accelerate/Accelerate.h>
#else
extern "C" void dgemm_(char *transa, char *transb, int *m, int *n, int *k, double *alpha, const double *a, int *lda, const double *b, int *ldb, double *beta, double *c, int *ldc);
extern "C" void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, const float *a, int *lda, const float *b, int *ldb, float *beta, float *c, int *ldc);
extern "C" void cgemm_(char *transa, char *transb, int *m, int *n, int *k, void *alpha, const void *a, int *lda, const void *b, int *ldb, void *beta, void *c, int *ldc);
extern "C" void zgemm_(char *transa, char *transb, int *m, int *n, int *k, void *alpha, const void *a, int *lda, const void *b, int *ldb, void *beta, void *c, int *ldc);
#ifdef BLAS_HAS_SBGEMM
extern "C" void sbgemm_(char *transa, char *transb, int *m, int *n, int *k,
                        float *alpha,
                        const at::BFloat16 *a, int *lda,
                        const at::BFloat16 *b, int *ldb,
                        float *beta,
                        float *c, int *ldc);
#endif // BLAS_HAS_SBGEMM
extern "C" void cswap_(int *n, const void *x, int *incx, void *y, int *incy);
extern "C" void dcopy_(int *n, const double *x, int *incx, double *y, int *incy);
extern "C" void scopy_(int *n, const float *x, int *incx, float *y, int *incy);
extern "C" void zcopy_(int *n, const void *x, int *incx, void *y, int *incy);
extern "C" void ccopy_(int *n, const void *x, int *incx, void *y, int *incy);
extern "C" void daxpy_(int *n, double *a, const double *x, int *incx, double *y, int *incy);
extern "C" void saxpy_(int *n, float *a, const float *x, int *incx, float *y, int *incy);
extern "C" void caxpy_(int *n, void *a, const void *x, int *incx, void *y, int *incy);
extern "C" void zaxpy_(int *n, void *a, const void *x, int *incx, void *y, int *incy);
#endif // C10_IOS
#endif // AT_BUILD_WITH_BLAS

#ifdef USE_FBGEMM
#include <fbgemm/FbgemmI64.h>
#endif // USE_FBGEMM

#if AT_MKLDNN_ENABLED()
#include <oneapi/dnnl/dnnl_version.h>
#endif // oneDNN

#define ONEDNN_UKERNEL_ENABLED (DNNL_VERSION_MAJOR >= 3 && DNNL_VERSION_MINOR >= 5)

#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
#include <oneapi/dnnl/dnnl_ukernel.hpp>
#include <oneapi/dnnl/dnnl.hpp>
#endif // oneDNN BRGEMM

namespace at::native::cpublas {
namespace internal {

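// When a matrix has a single column (or a single row, for transposed inputs),
// its leading dimension does not affect which elements are addressed, so it is
// rewritten here to the minimal value that the BLAS argument checks accept.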
void normalize_last_dims(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    int64_t *lda, int64_t *ldb, int64_t *ldc) {
  if (n == 1) {
    *ldc = m;
  }

  if (transa != TransposeType::NoTranspose) {
    if (m == 1) {
      *lda = k;
    }
  } else if (k == 1) {
    *lda = m;
  }

  if (transb != TransposeType::NoTranspose) {
    if (k == 1) {
      *ldb = n;
    }
  } else if (n == 1) {
    *ldb = k;
  }
}
} // namespace internal

namespace {
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunneeded-internal-declaration")
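// A BLAS call is only attempted when every dimension and leading dimension
// fits in a 32-bit int and the leading dimensions satisfy the usual BLAS
// minimums (at least the number of rows of the operand as stored).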
bool use_blas_gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    int64_t lda, int64_t ldb, int64_t ldc) {
  const bool transa_ = transa != TransposeType::NoTranspose;
  const bool transb_ = transb != TransposeType::NoTranspose;
  return (
      (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) &&
      (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) &&
      (lda >= std::max(int64_t{1}, (transa_ ? k : m))) &&
      (ldb >= std::max(int64_t{1}, (transb_ ? n : k))) &&
      (ldc >= std::max(int64_t{1}, m)));
}
C10_DIAGNOSTIC_POP()

#ifdef USE_FBGEMM
fbgemm::matrix_op_t to_fbgemm(TransposeType trans) {
  switch (trans) {
    case TransposeType::Transpose: return fbgemm::matrix_op_t::Transpose;
    case TransposeType::NoTranspose: return fbgemm::matrix_op_t::NoTranspose;
    case TransposeType::ConjTranspose: TORCH_INTERNAL_ASSERT(false, "ConjTranspose type is not supported in fbgemm");
  }
  TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
}
#endif // USE_FBGEMM

#if (AT_BUILD_WITH_BLAS() && C10_IOS)
CBLAS_TRANSPOSE to_apple_accelerate_transpose(TransposeType trans) {
  switch (trans) {
    case TransposeType::Transpose: return CblasTrans;
    case TransposeType::NoTranspose: return CblasNoTrans;
    case TransposeType::ConjTranspose: return CblasConjTrans;
  }
  TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
}
#endif

} // namespace (anonymous)

DEFINE_DISPATCH(gemm_stub);

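// Each gemm overload below follows the same pattern: normalize degenerate
// leading dimensions, try a vendor fast path (BLAS / MKL-DNN / FBGEMM) when
// the problem qualifies, and otherwise fall back to the generic gemm_stub
// kernel registered for the dtype.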
void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const double alpha,
    const double *a, int64_t lda,
    const double *b, int64_t ldb,
    const double beta,
    double *c, int64_t ldc) {
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#if AT_BUILD_WITH_BLAS()
  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
    int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
    double alpha_ = alpha, beta_ = beta;
#if C10_IOS
    CBLAS_TRANSPOSE transa_ = to_apple_accelerate_transpose(transa);
    CBLAS_TRANSPOSE transb_ = to_apple_accelerate_transpose(transb);
    cblas_dgemm(CblasColMajor,
        transa_, transb_,
        m_, n_, k_,
        alpha_,
        a, lda_,
        b, ldb_,
        beta_,
        c, ldc_);
#else
    char transa_ = to_blas(transa), transb_ = to_blas(transb);
    dgemm_(
        &transa_, &transb_,
        &m_, &n_, &k_,
        &alpha_,
        a, &lda_,
        b, &ldb_,
        &beta_,
        c, &ldc_);
#endif
    return;
  }
#endif
  gemm_stub(
      at::kCPU, at::kDouble,
      transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const float alpha,
    const float *a, int64_t lda,
    const float *b, int64_t ldb,
    const float beta,
    float *c, int64_t ldc) {
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#if AT_MKLDNN_ENABLED()
  if (mkldnn_bf32_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)) {
    return;
  }
#endif
#if AT_BUILD_WITH_BLAS()
  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
    int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
    float alpha_ = alpha, beta_ = beta;
#if C10_IOS
    CBLAS_TRANSPOSE transa_ = to_apple_accelerate_transpose(transa);
    CBLAS_TRANSPOSE transb_ = to_apple_accelerate_transpose(transb);
    cblas_sgemm(CblasColMajor,
        transa_, transb_,
        m_, n_, k_,
        alpha_,
        a, lda_,
        b, ldb_,
        beta_,
        c, ldc_);
#else
    char transa_ = to_blas(transa), transb_ = to_blas(transb);
    sgemm_(
        &transa_, &transb_,
        &m_, &n_, &k_,
        &alpha_,
        a, &lda_,
        b, &ldb_,
        &beta_,
        c, &ldc_);
#endif
    return;
  }
#endif
  gemm_stub(
      at::kCPU, at::kFloat,
      transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const c10::complex<double> alpha,
    const c10::complex<double> *a, int64_t lda,
    const c10::complex<double> *b, int64_t ldb,
    const c10::complex<double> beta,
    c10::complex<double> *c, int64_t ldc) {
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#if AT_BUILD_WITH_BLAS()
  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
    int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
    c10::complex<double> alpha_ = alpha, beta_ = beta;
#if C10_IOS
    CBLAS_TRANSPOSE transa_ = to_apple_accelerate_transpose(transa);
    CBLAS_TRANSPOSE transb_ = to_apple_accelerate_transpose(transb);
    cblas_zgemm(CblasColMajor,
        transa_, transb_,
        m_, n_, k_,
        &alpha_,
        a, lda_,
        b, ldb_,
        &beta_,
        c, ldc_);
#else
    char transa_ = to_blas(transa), transb_ = to_blas(transb);
    zgemm_(
        &transa_, &transb_,
        &m_, &n_, &k_,
        &alpha_,
        a, &lda_,
        b, &ldb_,
        &beta_,
        c, &ldc_);
#endif
    return;
  }
#endif
  gemm_stub(
      at::kCPU, at::kComplexDouble,
      transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const c10::complex<float> alpha,
    const c10::complex<float> *a, int64_t lda,
    const c10::complex<float> *b, int64_t ldb,
    const c10::complex<float> beta,
    c10::complex<float> *c, int64_t ldc) {
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#if AT_BUILD_WITH_BLAS()
  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
    int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
    c10::complex<float> alpha_ = alpha, beta_ = beta;
#if C10_IOS
    CBLAS_TRANSPOSE transa_ = to_apple_accelerate_transpose(transa);
    CBLAS_TRANSPOSE transb_ = to_apple_accelerate_transpose(transb);
    cblas_cgemm(CblasColMajor,
        transa_, transb_,
        m_, n_, k_,
        &alpha_,
        a, lda_,
        b, ldb_,
        &beta_,
        c, ldc_);
#else
    char transa_ = to_blas(transa), transb_ = to_blas(transb);
    cgemm_(
        &transa_, &transb_,
        &m_, &n_, &k_,
        &alpha_,
        a, &lda_,
        b, &ldb_,
        &beta_,
        c, &ldc_);
#endif
    return;
  }
#endif
  gemm_stub(
      at::kCPU, at::kComplexFloat,
      transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const float alpha,
    const at::BFloat16 *a, int64_t lda,
    const at::BFloat16 *b, int64_t ldb,
    const float beta,
    at::BFloat16 *c, int64_t ldc) {
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#if AT_BUILD_WITH_BLAS() && defined(BLAS_HAS_SBGEMM)
  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
    int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
    char transa_ = to_blas(transa), transb_ = to_blas(transb);
    float alpha_ = alpha, beta_ = beta;
    int c_size = n_ * ldc_;
    // The C matrix in OpenBLAS sbgemm is of type "float", so we have to
    // convert to float, compute, and copy the result back to BFloat16.
    std::vector<float> float_v(c, c + c_size);
    sbgemm_(&transa_, &transb_,
            &m_, &n_, &k_,
            &alpha_,
            a, &lda_,
            b, &ldb_,
            &beta_,
            float_v.data(), &ldc_);
    for (auto cv : float_v) {
      *(c++) = c10::convert<at::BFloat16>(cv);
    }
    return;
  }
#endif
#if AT_MKLDNN_ENABLED()
  if (mkldnn_bf16_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)) {
    return;
  }
#endif
  gemm_stub(
      at::kCPU, at::kBFloat16,
      transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

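// Half-precision gemm has no reference BLAS fast path here; it is routed
// through oneDNN (mkldnn_fp16_gemm) when that backend is enabled and
// otherwise falls back to the generic stub.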
void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const float alpha,
    const at::Half *a, int64_t lda,
    const at::Half *b, int64_t ldb,
    const float beta,
    at::Half *c, int64_t ldc) {
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#if AT_MKLDNN_ENABLED()
  if (mkldnn_fp16_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)) {
    return;
  }
#endif
  gemm_stub(
      at::kCPU, at::kHalf,
      transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const float alpha,
    const at::BFloat16 *a, int64_t lda,
    const at::BFloat16 *b, int64_t ldb,
    const float beta,
    float *c, int64_t ldc) {
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#if AT_BUILD_WITH_BLAS() && defined(BLAS_HAS_SBGEMM)
  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
    int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
    char transa_ = to_blas(transa), transb_ = to_blas(transb);
    float alpha_ = alpha, beta_ = beta;
    sbgemm_(&transa_, &transb_,
            &m_, &n_, &k_,
            &alpha_,
            a, &lda_,
            b, &ldb_,
            &beta_,
            c, &ldc_);
    return;
  }
#endif
#ifdef MKL_HAS_SBGEMM
  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
    int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
    mkl_gemm_bf16bf16f32(transa, transb, m_, n_, k_, alpha, a, lda_, b, ldb_, beta, c, ldc_);
    return;
  }
#endif
  // for the fallback path, first compute gemm with beta = 0,
  // and then add c in full precision.
  int64_t c_size = n * m;
  std::vector<at::BFloat16> bfloat_c(c_size, 0.f);
  gemm_stub(
      at::kCPU, at::kBFloat16,
      transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, bfloat_c.data(), m);
  for (const auto j : c10::irange(n)) {
    for (const auto i : c10::irange(m)) {
      auto offset = j * ldc + i;
      // beta == 0 won't propagate NaN from C
      if (beta == 0.f) {
        c[offset] = c10::convert<float>(bfloat_c[j * m + i]);
      } else {
        c[offset] = beta * c[offset] + c10::convert<float>(bfloat_c[j * m + i]);
      }
    }
  }
}

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const float alpha,
    const at::Half *a, int64_t lda,
    const at::Half *b, int64_t ldb,
    const float beta,
    float *c, int64_t ldc) {
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#ifdef MKL_HAS_SHGEMM
  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
    int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
    mkl_gemm_f16f16f32(transa, transb, m_, n_, k_, alpha, a, lda_, b, ldb_, beta, c, ldc_);
    return;
  }
#endif
  // for the fallback path, first compute gemm with beta = 0,
  // and then add c in full precision.
  int64_t c_size = n * m;
  std::vector<at::Half> float16_c(c_size, 0.f);
  gemm_stub(
      at::kCPU, at::kHalf,
      transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, float16_c.data(), m);
  for (const auto j : c10::irange(n)) {
    for (const auto i : c10::irange(m)) {
      auto offset = j * ldc + i;
      // beta == 0 won't propagate NaN from C
      if (beta == 0.f) {
        c[offset] = c10::convert<float>(float16_c[j * m + i]);
      } else {
        c[offset] = beta * c[offset] + c10::convert<float>(float16_c[j * m + i]);
      }
    }
  }
}

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const int64_t alpha,
    const int64_t *a, int64_t lda,
    const int64_t *b, int64_t ldb,
    const int64_t beta,
    int64_t *c, int64_t ldc) {
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#ifdef USE_FBGEMM
  if (alpha == 1 && (beta == 0 || beta == 1)) {
    // FBGEMM assumes row-major ordering; however, this function follows the
    // column-major FORTRAN tradition of the BLAS interface: we can configure
    // the layout (row/column-major ordering) of A and B by changing transa_
    // and transb_, but we cannot change the layout of C with this
    // FORTRAN-style BLAS interface.
    //
    // The workaround is to compute
    //   C^T (n x m) = B^T (n x k) * A^T (k x m) instead.
    //
    // In this way we can view C^T as row-major when passing it to FBGEMM.
    fbgemm::cblas_gemm_i64_i64acc(
        to_fbgemm(transb),
        to_fbgemm(transa),
        n,
        m,
        k,
        b,
        ldb,
        a,
        lda,
        beta == 1,
        c,
        ldc);
    return;
  }
#endif

  gemm_stub(
      kCPU, kLong,
      transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}

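// Splits a pointer-array batch into chunks of at most INT_MAX so that each
// underlying MKL batched-gemm call receives a batch count that fits in an int.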
template <typename scalar_t>
static void gemm_batched_mkl_impl(
    TransposeType transa, TransposeType transb,
    int64_t batch_size, int64_t m, int64_t n, int64_t k,
    scalar_t alpha,
    const scalar_t **a, int64_t lda,
    const scalar_t **b, int64_t ldb,
    scalar_t beta,
    scalar_t **c, int64_t ldc) {
  for (int64_t i = 0; i < batch_size;) {
    int sub_batch = std::min(batch_size - i, int64_t{INT_MAX});
    mkl_gemm_batched(transa, transb, sub_batch, m, n, k, alpha,
                     &a[i], lda, &b[i], ldb, beta, &c[i], ldc);
    i += sub_batch;
  }
}

template <typename scalar_t>
using is_blas_library_type = std::integral_constant<bool,
    std::is_same_v<scalar_t, double> ||
    std::is_same_v<scalar_t, float> ||
    std::is_same_v<scalar_t, c10::complex<double>> ||
    std::is_same_v<scalar_t, c10::complex<float>>>;

template <typename scalar_t>
void gemm_batched_generic(
    TransposeType transa, TransposeType transb,
    int64_t batch_size, int64_t m, int64_t n, int64_t k,
    scalar_t alpha,
    const scalar_t **a, int64_t lda,
    const scalar_t **b, int64_t ldb,
    scalar_t beta,
    scalar_t **c, int64_t ldc) {
  for (const auto batch : c10::irange(batch_size)) {
    gemm(transa, transb, m, n, k, alpha, a[batch], lda, b[batch], ldb, beta, c[batch], ldc);
  }
}

template <typename scalar_t>
void gemm_batched(
    TransposeType transa, TransposeType transb,
    int64_t batch_size, int64_t m, int64_t n, int64_t k,
    scalar_t alpha,
    const scalar_t **a, int64_t lda,
    const scalar_t **b, int64_t ldb,
    scalar_t beta,
    scalar_t **c, int64_t ldc) {
  if (batch_size == 1) {
    return gemm(transa, transb, m, n, k, alpha, a[0], lda, b[0], ldb, beta, c[0], ldc);
  }

  if constexpr (AT_MKL_ENABLED() && is_blas_library_type<scalar_t>::value) {
    internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
    if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
      gemm_batched_mkl_impl(
          transa, transb, batch_size, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
    } else {
      gemm_batched_generic(
          transa, transb, batch_size, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
    }
  } else {
    gemm_batched_generic(
        transa, transb, batch_size, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  }
}

template <typename scalar_t>
void gemm_batched_with_stride_generic(
    TransposeType transa, TransposeType transb,
    int64_t batch_size, int64_t m, int64_t n, int64_t k,
    scalar_t alpha,
    const scalar_t *a, int64_t lda, int64_t batch_stride_a,
    const scalar_t *b, int64_t ldb, int64_t batch_stride_b,
    scalar_t beta,
    scalar_t *c, int64_t ldc, int64_t batch_stride_c) {
  for (const auto batch : c10::irange(batch_size)) {
    const auto a_batch = a + batch_stride_a * batch;
    const auto b_batch = b + batch_stride_b * batch;
    const auto c_batch = c + batch_stride_c * batch;
    gemm(transa, transb, m, n, k, alpha, a_batch, lda, b_batch, ldb, beta, c_batch, ldc);
  }
}

template <typename scalar_t>
void gemm_batched_with_stride(
    TransposeType transa, TransposeType transb,
    int64_t batch_size, int64_t m, int64_t n, int64_t k,
    scalar_t alpha,
    const scalar_t *a, int64_t lda, int64_t batch_stride_a,
    const scalar_t *b, int64_t ldb, int64_t batch_stride_b,
    scalar_t beta,
    scalar_t *c, int64_t ldc, int64_t batch_stride_c) {
  if (batch_size == 1) {
    return gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  }

  if constexpr (AT_MKL_ENABLED() && is_blas_library_type<scalar_t>::value) {
    internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
    if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
      c10::SmallBuffer<const scalar_t*, 16> a_ptrs(batch_size);
      c10::SmallBuffer<const scalar_t*, 16> b_ptrs(batch_size);
      c10::SmallBuffer<scalar_t*, 16> c_ptrs(batch_size);

      for (const auto batch : c10::irange(batch_size)) {
        a_ptrs[batch] = a + batch_stride_a * batch;
        b_ptrs[batch] = b + batch_stride_b * batch;
        c_ptrs[batch] = c + batch_stride_c * batch;
      }
      gemm_batched_mkl_impl(
          transa, transb, batch_size, m, n, k, alpha, a_ptrs.data(), lda,
          b_ptrs.data(), ldb, beta, c_ptrs.data(), ldc);
    } else {
      gemm_batched_with_stride_generic(
          transa, transb, batch_size, m, n, k, alpha, a, lda, batch_stride_a,
          b, ldb, batch_stride_b, beta, c, ldc, batch_stride_c);
    }
  } else {
    gemm_batched_with_stride_generic(transa, transb, batch_size, m, n, k, alpha,
                                     a, lda, batch_stride_a, b, ldb, batch_stride_b,
                                     beta, c, ldc, batch_stride_c);
  }
}

#define INSTANTIATE_BATCHED_GEMM(scalar_t, DType) \
  template void gemm_batched( \
      TransposeType transa, TransposeType transb, \
      int64_t batch_size, int64_t m, int64_t n, int64_t k, \
      scalar_t alpha, \
      const scalar_t **a, int64_t lda, \
      const scalar_t **b, int64_t ldb, \
      scalar_t beta, \
      scalar_t **c, int64_t ldc); \
  template void gemm_batched_with_stride( \
      TransposeType transa, TransposeType transb, \
      int64_t batch_size, int64_t m, int64_t n, int64_t k, \
      scalar_t alpha, \
      const scalar_t *a, int64_t lda, int64_t batch_stride_a, \
      const scalar_t *b, int64_t ldb, int64_t batch_stride_b, \
      scalar_t beta, \
      scalar_t *c, int64_t ldc, int64_t batch_stride_c);

AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(INSTANTIATE_BATCHED_GEMM)

DEFINE_DISPATCH(axpy_stub);

void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t incy) {
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if AT_BUILD_WITH_BLAS()
  if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
    int i_n = (int)n;
    int i_incx = (int)incx;
    int i_incy = (int)incy;
#if C10_IOS
    cblas_daxpy(i_n, a, x, i_incx, y, i_incy);
#else
    daxpy_(&i_n, &a, x, &i_incx, y, &i_incy);
#endif
    return;
  }
#endif
  axpy_stub(
      kCPU, at::kDouble,
      n, a, x, incx, y, incy);
}

void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t incy) {
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if AT_BUILD_WITH_BLAS()
  if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
    int i_n = (int)n;
    int i_incx = (int)incx;
    int i_incy = (int)incy;
#if C10_IOS
    cblas_saxpy(i_n, a, x, i_incx, y, i_incy);
#else
    saxpy_(&i_n, &a, x, &i_incx, y, &i_incy);
#endif
    return;
  }
#endif
  axpy_stub(
      kCPU, at::kFloat,
      n, a, x, incx, y, incy);
}

void axpy(int64_t n, c10::complex<double> a, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy) {
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if AT_BUILD_WITH_BLAS()
  if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
    int i_n = (int)n;
    int i_incx = (int)incx;
    int i_incy = (int)incy;
#if C10_IOS
    cblas_zaxpy(i_n, &a, x, i_incx, y, i_incy);
#else
    zaxpy_(&i_n, &a, x, &i_incx, y, &i_incy);
#endif
    return;
  }
#endif
  axpy_stub(
      kCPU, at::kComplexDouble,
      n, a, x, incx, y, incy);
}

void axpy(int64_t n, c10::complex<float> a, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy) {
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if AT_BUILD_WITH_BLAS()
  if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
    int i_n = (int)n;
    int i_incx = (int)incx;
    int i_incy = (int)incy;
#if C10_IOS
    cblas_caxpy(i_n, &a, x, i_incx, y, i_incy);
#else
    caxpy_(&i_n, &a, x, &i_incx, y, &i_incy);
#endif
    return;
  }
#endif
  axpy_stub(
      kCPU, at::kComplexFloat,
      n, a, x, incx, y, incy);
}

DEFINE_DISPATCH(copy_stub);

void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy) {
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if AT_BUILD_WITH_BLAS()
  if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
    int i_n = (int)n;
    int i_incx = (int)incx;
    int i_incy = (int)incy;
#if C10_IOS
    cblas_dcopy(i_n, x, i_incx, y, i_incy);
#else
    dcopy_(&i_n, x, &i_incx, y, &i_incy);
#endif
    return;
  }
#endif
  copy_stub(
      kCPU, at::kDouble,
      n, x, incx, y, incy);
}

void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy) {
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if AT_BUILD_WITH_BLAS()
  if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
    int i_n = (int)n;
    int i_incx = (int)incx;
    int i_incy = (int)incy;
#if C10_IOS
    cblas_scopy(i_n, x, i_incx, y, i_incy);
#else
    scopy_(&i_n, x, &i_incx, y, &i_incy);
#endif
    return;
  }
#endif
  copy_stub(
      kCPU, at::kFloat,
      n, x, incx, y, incy);
}

void copy(int64_t n, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy) {
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if AT_BUILD_WITH_BLAS()
  if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
    int i_n = (int)n;
    int i_incx = (int)incx;
    int i_incy = (int)incy;
#if C10_IOS
    cblas_zcopy(i_n, x, i_incx, y, i_incy);
#else
    zcopy_(&i_n, x, &i_incx, y, &i_incy);
#endif
    return;
  }
#endif
  copy_stub(
      kCPU, at::kComplexDouble,
      n, x, incx, y, incy);
}

void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy) {
  if (n == 1) {
    incx = 1;
    incy = 1;
  }
#if AT_BUILD_WITH_BLAS()
  if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
    int i_n = (int)n;
    int i_incx = (int)incx;
    int i_incy = (int)incy;
#if C10_IOS
    // Pass the data pointer itself (not its address), matching the other copy overloads.
    cblas_ccopy(i_n, x, i_incx, y, i_incy);
#else
    ccopy_(&i_n, x, &i_incx, y, &i_incy);
#endif
    return;
  }
#endif
  copy_stub(
      kCPU, at::kComplexFloat,
      n, x, incx, y, incy);
}

// oneDNN BRGEMM
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
struct BrgemmKey {
  int64_t M;
  int64_t N;
  int64_t K;
  int64_t batch_size;
  int64_t lda;
  int64_t ldb;
  int64_t ldc;
  ScalarType dt_a;
  ScalarType dt_b;
  ScalarType dt_c;
  float alpha;
  float beta;
  BrgemmKey(
      int64_t M,
      int64_t N,
      int64_t K,
      int64_t batch_size,
      int64_t lda,
      int64_t ldb,
      int64_t ldc,
      ScalarType dt_a,
      ScalarType dt_b,
      ScalarType dt_c,
      float alpha,
      float beta)
      : M(M),
        N(N),
        K(K),
        batch_size(batch_size),
        lda(lda),
        ldb(ldb),
        ldc(ldc),
        dt_a(dt_a),
        dt_b(dt_b),
        dt_c(dt_c),
        alpha(alpha),
        beta(beta) {}
  bool operator==(const BrgemmKey& other) const {
    return M == other.M && N == other.N && K == other.K &&
        batch_size == other.batch_size && lda == other.lda &&
        ldb == other.ldb && ldc == other.ldc && dt_a == other.dt_a &&
        dt_b == other.dt_b && dt_c == other.dt_c && alpha == other.alpha &&
        beta == other.beta;
  }
};

struct PackKey {
  int64_t K;
  int64_t N;
  int64_t ld_in;
  int64_t ld_out;
  ScalarType dt_in;
  ScalarType dt_out;
  PackKey(
      int64_t K,
      int64_t N,
      int64_t ld_in,
      int64_t ld_out,
      ScalarType dt_in,
      ScalarType dt_out)
      : K(K),
        N(N),
        ld_in(ld_in),
        ld_out(ld_out),
        dt_in(dt_in),
        dt_out(dt_out) {}
  bool operator==(const PackKey& other) const {
    return N == other.N && K == other.K && ld_in == other.ld_in &&
        ld_out == other.ld_out && dt_in == other.dt_in &&
        dt_out == other.dt_out;
  }
};

inline dnnl::memory::data_type get_dnnl_dtype(ScalarType dtype) {
  if (dtype == ScalarType::Float) {
    return dnnl::memory::data_type::f32;
  } else if (dtype == ScalarType::BFloat16) {
    return dnnl::memory::data_type::bf16;
  } else if (dtype == ScalarType::Half) {
    return dnnl::memory::data_type::f16;
  } else if (dtype == ScalarType::Byte) {
    return dnnl::memory::data_type::u8;
  } else if (dtype == ScalarType::Char) {
    return dnnl::memory::data_type::s8;
  } else {
    TORCH_CHECK(false, "get_dnnl_dtype expects float/bfloat16/half/int8 tensor input");
  }
}

template <typename key_t>
struct UnsafeUkernelKeyHasher {
  std::size_t operator()(const key_t& key) const;
};

template <>
std::size_t UnsafeUkernelKeyHasher<BrgemmKey>::operator()(const BrgemmKey& key) const {
  // Hash only beta, M, N, K, and ldc to reduce overhead: batch size, alpha, and
  // data types are unlikely to change within the same kernel, and leading
  // dimensions are usually related to M, K, N or use fixed values.
  std::size_t h = std::hash<float>()(key.beta + 1);
  h = std::hash<int64_t>()(key.M) ^ (h << 1);
  h = std::hash<int64_t>()(key.N) ^ (h << 1);
  h = std::hash<int64_t>()(key.K) ^ (h << 1);
  h = std::hash<int64_t>()(key.ldc) ^ (h << 1);
  return h;
}

template <>
std::size_t UnsafeUkernelKeyHasher<PackKey>::operator()(const PackKey& key) const {
  // Hash only K and N to reduce overhead: the data types are unlikely to change,
  // and ld_in/ld_out are usually related to K, N or use fixed values.
  std::size_t h = std::hash<int64_t>()(key.K);
  h = std::hash<int64_t>()(key.N) ^ (h << 1);
  return h;
}

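// Thread-local cache of generated ukernel objects, keyed by problem shape.
// The hashers above deliberately fold in only a few fields ("unsafe"), so
// correctness relies on the full operator== comparison done by the map.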
template <typename key_t, typename value_t>
struct KernelCache {
  using kstore_t = std::unordered_map<key_t, std::shared_ptr<value_t>, UnsafeUkernelKeyHasher<key_t>>;
  static inline std::shared_ptr<value_t>&& fetch_or_create(
      const key_t& key,
      const std::function<std::shared_ptr<value_t>()>& callback) {
    auto&& search = get_store().find(key);
    if (search != get_store().end()) {
      return std::move(search->second);
    } else {
      get_store().insert({key, callback()});
      return std::move(get_store()[key]);
    }
  }

  static inline kstore_t& get_store() {
    static thread_local kstore_t cache_kernels;
    return cache_kernels;
  }
};

// Helper struct for convenient brgemm configuration
struct GemmHelper {
  GemmHelper(
      int64_t M,
      int64_t N,
      int64_t K,
      int64_t bs,
      int64_t ld_a,
      int64_t ld_b,
      int64_t ld_c,
      ScalarType dt_a,
      ScalarType dt_b,
      ScalarType dt_c,
      const float alpha,
      const float beta) {
    // Create brgemm
    brg = dnnl::ukernel::brgemm(
        M,
        N,
        K,
        bs,
        ld_a,
        ld_b,
        ld_c,
        get_dnnl_dtype(dt_a),
        get_dnnl_dtype(dt_b),
        get_dnnl_dtype(dt_c),
        alpha,
        beta);
    // Create a scratchpad buffer for the brgemm execution
    scratchpad = std::vector<uint8_t>(brg.get_scratchpad_size());
    // Prepare the default vector of A/B offset pairs for a single batch
    // (push_back rather than indexing, since reserve() does not change the size).
    A_B_offsets.reserve(1);
    A_B_offsets.push_back(std::make_pair(0, 0));
  }
  dnnl::ukernel::brgemm brg;
  std::vector<uint8_t> scratchpad;
  std::vector<std::pair<int64_t, int64_t>> A_B_offsets;
};

struct Brgemm : public KernelCache<BrgemmKey, GemmHelper> {
  // Fetch/create a GemmHelper object and execute brgemm with batch size = 1
  template <typename scalar_t_a, typename scalar_t_b, typename scalar_t_c>
  static inline void call(
      int64_t M,
      int64_t N,
      int64_t K,
      int64_t ld_a,
      int64_t ld_b,
      int64_t ld_c,
      const float alpha,
      const float beta,
      const scalar_t_a* A,
      const scalar_t_b* B,
      scalar_t_c* C) {
    auto&& key = BrgemmKey(
        M,
        N,
        K,
        int64_t(1),
        ld_a,
        ld_b,
        ld_c,
        c10::CppTypeToScalarType<scalar_t_a>::value,
        c10::CppTypeToScalarType<scalar_t_b>::value,
        c10::CppTypeToScalarType<scalar_t_c>::value,
        alpha,
        beta);
    // Fetch/create GemmHelper object
    auto&& value = fetch_or_create(key, [&]() {
      auto&& v = std::make_shared<GemmHelper>(
          M,
          N,
          K,
          1,
          ld_a,
          ld_b,
          ld_c,
          c10::CppTypeToScalarType<scalar_t_a>::value,
          c10::CppTypeToScalarType<scalar_t_b>::value,
          c10::CppTypeToScalarType<scalar_t_c>::value,
          alpha,
          beta);
      (*v).brg.generate();
      return std::move(v);
    });
    if (get_current() != value) {
      dnnl::ukernel::brgemm::release_hw_context();
      ((*value).brg).set_hw_context();
      get_current() = value;
    }
    ((*value).brg)
        .execute(A, B, (*value).A_B_offsets, C, (*value).scratchpad.data());
  }

  static inline std::shared_ptr<GemmHelper>& get_current() {
    static thread_local std::shared_ptr<GemmHelper> current;
    return current;
  }

  static inline bool device_check(ScalarType dtype) {
    if (!at::globalContext().userEnabledMkldnn()) {
      return false;
    }
    if (dtype == ScalarType::Half) {
      static bool fp16_support = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_fp16;
      return fp16_support;
    }
    return false;
  }
};

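// Pack caches oneDNN brgemm_pack_B objects that re-lay out the B operand into
// the blocked format the brgemm ukernel expects; packing is only performed
// when need_pack() reports that the current ISA requires it.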
using pack_t = dnnl::ukernel::brgemm_pack_B;
struct Pack : public KernelCache<PackKey, pack_t> {
  static inline void call(
      int64_t K,
      int64_t N,
      int64_t ld_in,
      int64_t ld_out,
      ScalarType dt_in,
      ScalarType dt_out,
      const void* in,
      void* out) {
    auto&& key = PackKey(K, N, ld_in, ld_out, dt_in, dt_out);
    auto&& pack = fetch_or_create(key, [&]() {
      auto&& p = std::make_shared<pack_t>(
          K, N, ld_in, ld_out, get_dnnl_dtype(dt_in), get_dnnl_dtype(dt_out));
      if (need_pack(dt_in)) {
        (*p).generate();
      }
      return std::move(p);
    });
    if (need_pack(dt_in)) {
      (*pack).execute(in, out);
    } else {
      TORCH_CHECK(false, "No need to pack");
    }
  }

  static inline bool need_pack(ScalarType dtype) {
    if (!at::globalContext().userEnabledMkldnn()) {
      return false;
    }
    if (dtype == ScalarType::Half) {
      static bool fp16_pack = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_amx_fp16;
      return fp16_pack;
    }
    return false;
  }
};
#endif

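// Public half-precision brgemm entry point. It is only usable on x86-64 builds
// where the oneDNN ukernel API is available and the CPU supports
// avx512_core_fp16; otherwise it raises an error.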
void brgemm(
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t ld_a,
    int64_t ld_b,
    int64_t ld_c,
    const float alpha,
    const float beta,
    const at::Half* A,
    const at::Half* B,
    float* C) {
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
  if (Brgemm::device_check(ScalarType::Half)) {
    Brgemm::call<at::Half, at::Half, float>(
        M, N, K, ld_a, ld_b, ld_c, alpha, beta, A, B, C);
    return;
  }
#endif
  TORCH_CHECK(false,
      "Half Brgemm is only supported on X64 when oneDNN ukernel is enabled and avx512_fp16 is supported");
}

void brgemm_release() {
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
  dnnl::ukernel::brgemm::release_hw_context();
#endif
}

void pack(
    int64_t K,
    int64_t N,
    int64_t ld_in,
    int64_t ld_out,
    ScalarType dt_in,
    ScalarType dt_out,
    const void* in,
    void* out) {
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
  Pack::call(K, N, ld_in, ld_out, dt_in, dt_out, in, out);
#else
  TORCH_CHECK(false, "pack is only supported on X64 with oneDNN ukernel enabled");
#endif
}

bool need_pack(ScalarType dt_in) {
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
  return Pack::need_pack(dt_in);
#else
  return false;
#endif
}

} // namespace at::native::cpublas