xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/CPUBlas.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/CPUBlas.h>
3 #include <ATen/native/mkl/LinearAlgebra.h>
4 #include <ATen/native/mkldnn/Matmul.h>
5 #include <ATen/Config.h>
6 
7 #include <c10/util/SmallBuffer.h>
8 #include <c10/util/irange.h>
9 
10 #include <climits>
11 
12 #if AT_BUILD_WITH_BLAS()
13 #if C10_IOS
14 #include <Accelerate/Accelerate.h>
15 #else
16 extern "C" void dgemm_(char *transa, char *transb, int *m, int *n, int *k, double *alpha, const double *a, int *lda, const double *b, int *ldb, double *beta, double *c, int *ldc);
17 extern "C" void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, const float *a, int *lda, const float *b, int *ldb, float *beta, float *c, int *ldc);
18 extern "C" void cgemm_(char *transa, char *transb, int *m, int *n, int *k, void *alpha, const void *a, int *lda, const void *b, int *ldb, void *beta, void *c, int *ldc);
19 extern "C" void zgemm_(char *transa, char *transb, int *m, int *n, int *k, void *alpha, const void *a, int *lda, const void *b, int *ldb, void *beta, void *c, int *ldc);
20 #ifdef BLAS_HAS_SBGEMM
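// sbgemm: bf16 GEMM extension (as provided by OpenBLAS) in which A and B hold
// bfloat16 values while alpha, beta, and the C matrix are single-precision floats.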
21 extern "C" void sbgemm_(char *transa, char *transb, int *m, int *n, int *k,
22                 float *alpha,
23                 const at::BFloat16 *a, int *lda,
24                 const at::BFloat16 *b, int *ldb,
25                 float *beta,
26                 float *c, int *ldc);
27 #endif  // BLAS_HAS_SBGEMM
28 extern "C" void cswap_(int *n, const void *x, int *incx, void *y, int *incy);
29 extern "C" void dcopy_(int *n, const double *x, int *incx, double *y, int *incy);
30 extern "C" void scopy_(int *n, const float *x, int *incx, float *y, int *incy);
31 extern "C" void zcopy_(int *n, const void *x, int *incx, void *y, int *incy);
32 extern "C" void ccopy_(int *n, const void *x, int *incx, void *y, int *incy);
33 extern "C" void daxpy_(int *n, double *a, const double *x, int *incx, double *y, int *incy);
34 extern "C" void saxpy_(int *n, float *a, const float *x, int *incx, float *y, int *incy);
35 extern "C" void caxpy_(int *n, void *a, const void *x, int *incx, void *y, int *incy);
36 extern "C" void zaxpy_(int *n, void *a, const void *x, int *incx, void *y, int *incy);
37 #endif  // C10_IOS
38 #endif  // AT_BUILD_WITH_BLAS
39 
40 #ifdef USE_FBGEMM
41 #include <fbgemm/FbgemmI64.h>
42 #endif  // USE_FBGEMM
43 
44 #if AT_MKLDNN_ENABLED()
45 #include <oneapi/dnnl/dnnl_version.h>
46 #endif // oneDNN
47 
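// Gate the oneDNN ukernel (brgemm) path on oneDNN >= 3.5, where the ukernel API is available.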
48 #define ONEDNN_UKERNEL_ENABLED (DNNL_VERSION_MAJOR >=3 && DNNL_VERSION_MINOR >=5)
49 
50 #if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
51 #include <oneapi/dnnl/dnnl_ukernel.hpp>
52 #include <oneapi/dnnl/dnnl.hpp>
53 #endif // oneDNN BRGEMM
54 
55 namespace at::native::cpublas {
56 namespace internal {
57 
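// When one of the problem dimensions is 1, the corresponding leading dimension does not
// affect the result but is still validated by the BLAS interface, so normalize it to a
// legal value and keep callers with degenerate shapes on the fast path.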
58 void normalize_last_dims(
59     TransposeType transa, TransposeType transb,
60     int64_t m, int64_t n, int64_t k,
61     int64_t *lda, int64_t *ldb, int64_t *ldc) {
62   if (n == 1) {
63     *ldc = m;
64   }
65 
66   if(transa != TransposeType::NoTranspose) {
67     if (m == 1) {
68       *lda = k;
69     }
70   } else if(k == 1) {
71     *lda = m;
72   }
73 
74   if(transb != TransposeType::NoTranspose) {
75     if (k == 1) {
76       *ldb = n;
77     }
78   } else if (n == 1) {
79     *ldb = k;
80   }
81 }
82 }  // namespace internal
83 
84 namespace {
85 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunneeded-internal-declaration")
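// The Fortran BLAS interface takes 32-bit integer arguments and requires leading
// dimensions of at least max(1, rows of op(A)/op(B)/C), so only dispatch to BLAS
// when every size and leading dimension satisfies those constraints.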
86 bool use_blas_gemm(
87     TransposeType transa, TransposeType transb,
88     int64_t m, int64_t n, int64_t k,
89     int64_t lda, int64_t ldb, int64_t ldc) {
90   const bool transa_ = transa != TransposeType::NoTranspose;
91   const bool transb_ = transb != TransposeType::NoTranspose;
92   return (
93       (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) &&
94       (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) &&
95       (lda >= std::max(int64_t{1}, (transa_ ? k : m))) &&
96       (ldb >= std::max(int64_t{1}, (transb_ ? n : k))) &&
97       (ldc >= std::max(int64_t{1}, m)));
98 }
99 C10_DIAGNOSTIC_POP()
100 
101 #ifdef USE_FBGEMM
102 fbgemm::matrix_op_t to_fbgemm(TransposeType trans) {
103   switch (trans) {
104     case TransposeType::Transpose: return fbgemm::matrix_op_t::Transpose;
105     case TransposeType::NoTranspose: return fbgemm::matrix_op_t::NoTranspose;
106     case TransposeType::ConjTranspose: TORCH_INTERNAL_ASSERT(false, "ConjTranspose type is not supported in fbgemm");
107   }
108   TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
109 }
110 #endif  // USE_FBGEMM
111 
112 #if (AT_BUILD_WITH_BLAS() && C10_IOS)
113 CBLAS_TRANSPOSE to_apple_accelerate_transpose(TransposeType trans) {
114   switch (trans) {
115     case TransposeType::Transpose: return CblasTrans;
116     case TransposeType::NoTranspose: return CblasNoTrans;
117     case TransposeType::ConjTranspose: return CblasConjTrans;
118   }
119   TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
120 }
121 #endif
122 
123 }  // namespace (anonymous)
124 
125 DEFINE_DISPATCH(gemm_stub);
126 
127 void gemm(
128     TransposeType transa, TransposeType transb,
129     int64_t m, int64_t n, int64_t k,
130     const double alpha,
131     const double *a, int64_t lda,
132     const double *b, int64_t ldb,
133     const double beta,
134     double *c, int64_t ldc) {
135   internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
136 #if AT_BUILD_WITH_BLAS()
137   if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
138     int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
139     double alpha_ = alpha, beta_ = beta;
140     #if C10_IOS
141     CBLAS_TRANSPOSE transa_ = to_apple_accelerate_transpose(transa);
142     CBLAS_TRANSPOSE transb_ = to_apple_accelerate_transpose(transb);
143     cblas_dgemm(CblasColMajor,
144       transa_, transb_,
145       m_, n_, k_,
146       alpha_,
147       a, lda_,
148       b, ldb_,
149       beta_,
150       c, ldc_);
151     #else
152     char transa_ = to_blas(transa), transb_ = to_blas(transb);
153     dgemm_(
154         &transa_, &transb_,
155         &m_, &n_, &k_,
156         &alpha_,
157         a, &lda_,
158         b, &ldb_,
159         &beta_,
160         c, &ldc_);
161     #endif
162     return;
163   }
164 #endif
165   gemm_stub(
166       at::kCPU, at::kDouble,
167       transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
168 }
169 
170 void gemm(
171     TransposeType transa, TransposeType transb,
172     int64_t m, int64_t n, int64_t k,
173     const float alpha,
174     const float *a, int64_t lda,
175     const float *b, int64_t ldb,
176     const float beta,
177     float *c, int64_t ldc) {
178   internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
179 #if AT_MKLDNN_ENABLED()
180    if (mkldnn_bf32_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)) {
181      return;
182    }
183 #endif
184 #if AT_BUILD_WITH_BLAS()
185   if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
186     int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
187     float alpha_ = alpha, beta_ = beta;
188     #if C10_IOS
189     CBLAS_TRANSPOSE transa_ = to_apple_accelerate_transpose(transa);
190     CBLAS_TRANSPOSE transb_ = to_apple_accelerate_transpose(transb);
191     cblas_sgemm(CblasColMajor,
192       transa_, transb_,
193       m_, n_, k_,
194       alpha_,
195       a, lda_,
196       b, ldb_,
197       beta_,
198       c, ldc_);
199     #else
200     char transa_ = to_blas(transa), transb_ = to_blas(transb);
201     sgemm_(
202         &transa_, &transb_,
203         &m_, &n_, &k_,
204         &alpha_,
205         a, &lda_,
206         b, &ldb_,
207         &beta_,
208         c, &ldc_);
209     #endif
210     return;
211   }
212 #endif
213   gemm_stub(
214       at::kCPU, at::kFloat,
215       transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
216 }
217 
218 void gemm(
219     TransposeType transa, TransposeType transb,
220     int64_t m, int64_t n, int64_t k,
221     const c10::complex<double> alpha,
222     const c10::complex<double> *a, int64_t lda,
223     const c10::complex<double> *b, int64_t ldb,
224     const c10::complex<double> beta,
225     c10::complex<double> *c, int64_t ldc) {
226   internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
227 #if AT_BUILD_WITH_BLAS()
228   if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
229     int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
230     c10::complex<double> alpha_ = alpha, beta_ = beta;
231     #if C10_IOS
232     CBLAS_TRANSPOSE transa_ = to_apple_accelerate_transpose(transa);
233     CBLAS_TRANSPOSE transb_ = to_apple_accelerate_transpose(transb);
234     cblas_zgemm(CblasColMajor,
235       transa_, transb_,
236       m_, n_, k_,
237       &alpha_,
238       a, lda_,
239       b, ldb_,
240       &beta_,
241       c, ldc_);
242     #else
243     char transa_ = to_blas(transa), transb_ = to_blas(transb);
244     zgemm_(
245         &transa_, &transb_,
246         &m_, &n_, &k_,
247         &alpha_,
248         a, &lda_,
249         b, &ldb_,
250         &beta_,
251         c, &ldc_);
252     #endif
253     return;
254   }
255 #endif
256   gemm_stub(
257       at::kCPU, at::kComplexDouble,
258       transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
259 }
260 
261 void gemm(
262     TransposeType transa, TransposeType transb,
263     int64_t m, int64_t n, int64_t k,
264     const c10::complex<float> alpha,
265     const c10::complex<float> *a, int64_t lda,
266     const c10::complex<float> *b, int64_t ldb,
267     const c10::complex<float> beta,
268     c10::complex<float> *c, int64_t ldc) {
269   internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
270 #if AT_BUILD_WITH_BLAS()
271   if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
272     int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
273     c10::complex<float> alpha_ = alpha, beta_ = beta;
274     #if C10_IOS
275     CBLAS_TRANSPOSE transa_ = to_apple_accelerate_transpose(transa);
276     CBLAS_TRANSPOSE transb_ = to_apple_accelerate_transpose(transb);
277     cblas_cgemm(CblasColMajor,
278       transa_, transb_,
279       m_, n_, k_,
280       &alpha_,
281       a, lda_,
282       b, ldb_,
283       &beta_,
284       c, ldc_);
285     #else
286     char transa_ = to_blas(transa), transb_ = to_blas(transb);
287     cgemm_(
288         &transa_, &transb_,
289         &m_, &n_, &k_,
290         &alpha_,
291         a, &lda_,
292         b, &ldb_,
293         &beta_,
294         c, &ldc_);
295     #endif
296     return;
297   }
298 #endif
299   gemm_stub(
300       at::kCPU, at::kComplexFloat,
301       transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
302 }
303 
304 void gemm(
305    TransposeType transa, TransposeType transb,
306    int64_t m, int64_t n, int64_t k,
307    const float alpha,
308    const at::BFloat16 *a, int64_t lda,
309    const at::BFloat16 *b, int64_t ldb,
310    const float beta,
311    at::BFloat16 *c, int64_t ldc) {
312    internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
313 #if AT_BUILD_WITH_BLAS() && defined(BLAS_HAS_SBGEMM)
314    if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
315       int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
316       char transa_ = to_blas(transa), transb_ = to_blas(transb);
317       float alpha_ = alpha, beta_ = beta;
318       int c_size = n_ * ldc_;
319       // The C matrix in OpenBLAS sbgemm is of type "float", so we have to convert, copy, and copy back.
320       std::vector<float> float_v(c, c + c_size);
321       sbgemm_(&transa_, &transb_,
322               &m_, &n_, &k_,
323               &alpha_,
324               a, &lda_,
325               b, &ldb_,
326               &beta_,
327               float_v.data(), &ldc_);
328       for (auto cv: float_v) {
329         *(c++) = c10::convert<at::BFloat16>(cv);
330       }
331       return;
332    }
333 #endif
334 #if AT_MKLDNN_ENABLED()
335    if (mkldnn_bf16_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)) {
336      return;
337    }
338 #endif
339    gemm_stub(
340       at::kCPU, at::kBFloat16,
341       transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
342 }
343 
344 void gemm(
345    TransposeType transa, TransposeType transb,
346    int64_t m, int64_t n, int64_t k,
347    const float alpha,
348    const at::Half *a, int64_t lda,
349    const at::Half *b, int64_t ldb,
350    const float beta,
351    at::Half *c, int64_t ldc) {
352    internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
353 #if AT_MKLDNN_ENABLED()
354    if (mkldnn_fp16_gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)) {
355      return;
356    }
357 #endif
358    gemm_stub(
359       at::kCPU, at::kHalf,
360       transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
361 }
362 
363 void gemm(
364     TransposeType transa, TransposeType transb,
365     int64_t m, int64_t n, int64_t k,
366     const float alpha,
367     const at::BFloat16 *a, int64_t lda,
368     const at::BFloat16 *b, int64_t ldb,
369     const float beta,
370     float *c, int64_t ldc) {
371   internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
372 #if AT_BUILD_WITH_BLAS() && defined(BLAS_HAS_SBGEMM)
373    if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
374       int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
375       char transa_ = to_blas(transa), transb_ = to_blas(transb);
376       float alpha_ = alpha, beta_ = beta;
377       sbgemm_(&transa_, &transb_,
378               &m_, &n_, &k_,
379               &alpha_,
380               a, &lda_,
381               b, &ldb_,
382               &beta_,
383               c, &ldc_);
384       return;
385    }
386 #endif
387 #ifdef MKL_HAS_SBGEMM
388   if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
389     int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
390     mkl_gemm_bf16bf16f32(transa, transb, m_, n_, k_, alpha, a, lda_, b, ldb_, beta, c, ldc_);
391     return;
392   }
393 #endif
394   // for the fallback path, first compute gemm with beta = 0,
395   // and then add c in full precision.
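  // That is: tmp(bf16) = alpha * op(A) * op(B) with beta = 0, followed by
  // C = beta * C + float(tmp), so the scaling of C by beta happens in fp32.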
396   int64_t c_size = n * m;
397   std::vector<at::BFloat16> bfloat_c(c_size, 0.f);
398   gemm_stub(
399       at::kCPU, at::kBFloat16,
400       transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, bfloat_c.data(), m);
401   for (const auto j : c10::irange(n)) {
402     for (const auto i : c10::irange(m)) {
403       auto offset = j * ldc + i;
404       // beta == 0 won't propagate NaN from C
405       if (beta == 0.f) {
406         c[offset] = c10::convert<float>(bfloat_c[j * m + i]);
407       } else {
408         c[offset] = beta * c[offset] + c10::convert<float>(bfloat_c[j * m + i]);
409       }
410     }
411   }
412 }
413 
414 void gemm(
415     TransposeType transa, TransposeType transb,
416     int64_t m, int64_t n, int64_t k,
417     const float alpha,
418     const at::Half *a, int64_t lda,
419     const at::Half *b, int64_t ldb,
420     const float beta,
421     float *c, int64_t ldc) {
422   internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
423 #ifdef MKL_HAS_SHGEMM
424   if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
425     int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
426     mkl_gemm_f16f16f32(transa, transb, m_, n_, k_, alpha, a, lda_, b, ldb_, beta, c, ldc_);
427     return;
428   }
429 #endif
430   // for the fallback path, first compute gemm with beta = 0,
431   // and then add c in full precision.
432   int64_t c_size = n * m;
433   std::vector<at::Half> float16_c(c_size, 0.f);
434   gemm_stub(
435       at::kCPU, at::kHalf,
436       transa, transb, m, n, k, alpha, a, lda, b, ldb, 0.f, float16_c.data(), m);
437   for (const auto j : c10::irange(n)) {
438     for (const auto i : c10::irange(m)) {
439       auto offset = j * ldc + i;
440       // beta == 0 won't propagate NaN from C
441       if (beta == 0.f) {
442         c[offset] = c10::convert<float>(float16_c[j * m + i]);
443       } else {
444         c[offset] = beta * c[offset] + c10::convert<float>(float16_c[j * m + i]);
445       }
446     }
447   }
448 }
449 
450 void gemm(
451     TransposeType transa, TransposeType transb,
452     int64_t m, int64_t n, int64_t k,
453     const int64_t alpha,
454     const int64_t *a, int64_t lda,
455     const int64_t *b, int64_t ldb,
456     const int64_t beta,
457     int64_t *c, int64_t ldc) {
458   internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
459 #ifdef USE_FBGEMM
460   if (alpha == 1 && (beta == 0 || beta == 1)) {
461     // FBGEMM assumes row-major ordering, whereas this function follows the
462     // column-major (FORTRAN) convention of the BLAS interface: the layout
463     // (row/column-major ordering) of A and B can be configured through
464     // transa_ and transb_, but the layout of C cannot be changed through
465     // this FORTRAN-style BLAS interface.
466     //
467     // The workaround is that we compute
468     // C^T (n x m) = B^T (n x k) * A^T (k x m) instead.
469     //
470     // In this way we view C^T as the row-major ordering when passing to FBGEMM.
471     fbgemm::cblas_gemm_i64_i64acc(
472         to_fbgemm(transb),
473         to_fbgemm(transa),
474         n,
475         m,
476         k,
477         b,
478         ldb,
479         a,
480         lda,
481         beta == 1,
482         c,
483         ldc);
484     return;
485   }
486 #endif
487 
488   gemm_stub(
489       kCPU, kLong,
490       transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
491 }
492 
493 template <typename scalar_t>
494 static void gemm_batched_mkl_impl(
495       TransposeType transa, TransposeType transb,
496       int64_t batch_size, int64_t m, int64_t n, int64_t k,
497       scalar_t alpha,
498       const scalar_t **a, int64_t lda,
499       const scalar_t **b, int64_t ldb,
500       scalar_t beta,
501       scalar_t **c, int64_t ldc) {
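  // Process the batch in chunks of at most INT_MAX entries so that each sub-batch
  // count fits in an int.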
502   for (int64_t i = 0; i < batch_size;) {
503     int sub_batch = std::min(batch_size - i, int64_t{INT_MAX});
504     mkl_gemm_batched(transa, transb, sub_batch, m, n, k, alpha,
505                      &a[i], lda, &b[i], ldb, beta, &c[i], ldc);
506     i += sub_batch;
507   }
508 }
509 
510 template <typename scalar_t>
511 using is_blas_library_type = std::integral_constant<bool,
512     std::is_same_v<scalar_t, double> ||
513     std::is_same_v<scalar_t, float> ||
514     std::is_same_v<scalar_t, c10::complex<double>> ||
515     std::is_same_v<scalar_t, c10::complex<float>>>;
516 
517 template <typename scalar_t>
518 void gemm_batched_generic(
519     TransposeType transa, TransposeType transb,
520     int64_t batch_size, int64_t m, int64_t n, int64_t k,
521     scalar_t alpha,
522     const scalar_t **a, int64_t lda,
523     const scalar_t **b, int64_t ldb,
524     scalar_t beta,
525     scalar_t **c, int64_t ldc) {
526   for (const auto batch : c10::irange(batch_size)) {
527     gemm(transa, transb, m, n, k, alpha, a[batch], lda, b[batch], ldb, beta, c[batch], ldc);
528   }
529 }
530 
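// gemm_batched: a single batch goes straight to gemm(); for BLAS-native element types
// with MKL enabled, use the batched MKL path when the arguments fit, otherwise fall
// back to a per-batch loop.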
531 template <typename scalar_t>
532 void gemm_batched(
533     TransposeType transa, TransposeType transb,
534     int64_t batch_size, int64_t m, int64_t n, int64_t k,
535     scalar_t alpha,
536     const scalar_t **a, int64_t lda,
537     const scalar_t **b, int64_t ldb,
538     scalar_t beta,
539     scalar_t **c, int64_t ldc) {
540   if (batch_size == 1) {
541     return gemm(transa, transb, m, n, k, alpha, a[0], lda, b[0], ldb, beta, c[0], ldc);
542   }
543 
544   if constexpr (AT_MKL_ENABLED() && is_blas_library_type<scalar_t>::value) {
545     internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
546     if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
547       gemm_batched_mkl_impl(
548           transa, transb, batch_size, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
549     } else {
550       gemm_batched_generic(
551           transa, transb, batch_size, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
552     }
553   } else {
554     gemm_batched_generic(
555         transa, transb, batch_size, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
556   }
557 }
558 
559 template <typename scalar_t>
560 void gemm_batched_with_stride_generic(
561     TransposeType transa, TransposeType transb,
562     int64_t batch_size, int64_t m, int64_t n, int64_t k,
563     scalar_t alpha,
564     const scalar_t *a, int64_t lda, int64_t batch_stride_a,
565     const scalar_t *b, int64_t ldb, int64_t batch_stride_b,
566     scalar_t beta,
567     scalar_t *c, int64_t ldc, int64_t batch_stride_c) {
568   for (const auto batch : c10::irange(batch_size)) {
569     const auto a_batch = a + batch_stride_a * batch;
570     const auto b_batch = b + batch_stride_b * batch;
571     const auto c_batch = c + batch_stride_c * batch;
572     gemm(transa, transb, m, n, k, alpha, a_batch, lda, b_batch, ldb, beta, c_batch, ldc);
573   }
574 }
575 
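// Strided variant: when the batched MKL path applies, materialize per-batch pointer
// arrays from the batch strides and reuse gemm_batched_mkl_impl; otherwise loop.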
576 template <typename scalar_t>
577 void gemm_batched_with_stride(
578     TransposeType transa, TransposeType transb,
579     int64_t batch_size, int64_t m, int64_t n, int64_t k,
580     scalar_t alpha,
581     const scalar_t *a, int64_t lda, int64_t batch_stride_a,
582     const scalar_t *b, int64_t ldb, int64_t batch_stride_b,
583     scalar_t beta,
584     scalar_t *c, int64_t ldc, int64_t batch_stride_c) {
585   if (batch_size == 1) {
586     return gemm(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
587   }
588 
589   if constexpr (AT_MKL_ENABLED() && is_blas_library_type<scalar_t>::value) {
590     internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
591     if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
592       c10::SmallBuffer<const scalar_t*, 16> a_ptrs(batch_size);
593       c10::SmallBuffer<const scalar_t*, 16> b_ptrs(batch_size);
594       c10::SmallBuffer<scalar_t*, 16> c_ptrs(batch_size);
595 
596       for (const auto batch : c10::irange(batch_size)) {
597         a_ptrs[batch] = a + batch_stride_a * batch;
598         b_ptrs[batch] = b + batch_stride_b * batch;
599         c_ptrs[batch] = c + batch_stride_c * batch;
600       }
601       gemm_batched_mkl_impl(
602           transa, transb, batch_size, m, n, k, alpha, a_ptrs.data(), lda,
603           b_ptrs.data(), ldb, beta, c_ptrs.data(), ldc);
604     } else {
605       gemm_batched_with_stride_generic(
606           transa, transb, batch_size, m, n, k, alpha, a, lda, batch_stride_a,
607           b, ldb, batch_stride_b, beta, c, ldc, batch_stride_c);
608     }
609   } else {
610     gemm_batched_with_stride_generic(transa, transb, batch_size, m, n, k, alpha,
611                                      a, lda, batch_stride_a, b, ldb, batch_stride_b,
612                                      beta, c, ldc, batch_stride_c);
613   }
614 }
615 
616 #define INSTANTIATE_BATCHED_GEMM(scalar_t, DType)               \
617   template void gemm_batched(                                   \
618       TransposeType transa, TransposeType transb,               \
619       int64_t batch_size, int64_t m, int64_t n, int64_t k,      \
620       scalar_t alpha,                                           \
621       const scalar_t **a, int64_t lda,                          \
622       const scalar_t **b, int64_t ldb,                          \
623       scalar_t beta,                                            \
624       scalar_t **c, int64_t ldc);                               \
625   template void gemm_batched_with_stride(                       \
626       TransposeType transa, TransposeType transb,               \
627       int64_t batch_size, int64_t m, int64_t n, int64_t k,      \
628       scalar_t alpha,                                           \
629       const scalar_t *a, int64_t lda, int64_t batch_stride_a,   \
630       const scalar_t *b, int64_t ldb, int64_t batch_stride_b,   \
631       scalar_t beta,                                            \
632       scalar_t *c, int64_t ldc, int64_t batch_stride_c);
633 
634 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(INSTANTIATE_BATCHED_GEMM)
635 
636 DEFINE_DISPATCH(axpy_stub);
637 
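// For n == 1 the strides are irrelevant, so the wrappers below normalize incx/incy
// before checking whether the arguments fit the 32-bit BLAS interface.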
638 void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t incy) {
639   if(n == 1)
640   {
641     incx = 1;
642     incy = 1;
643   }
644   #if AT_BUILD_WITH_BLAS()
645   if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
646   {
647     int i_n = (int)n;
648     int i_incx = (int)incx;
649     int i_incy = (int)incy;
650     #if C10_IOS
651     cblas_daxpy(i_n, a, x, i_incx, y, i_incy);
652     #else
653     daxpy_(&i_n, &a, x, &i_incx, y, &i_incy);
654     #endif
655     return;
656   }
657   #endif
658   axpy_stub(
659       kCPU, at::kDouble,
660       n, a, x, incx, y, incy);
661 }
662 
663 void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t incy) {
664   if(n == 1)
665   {
666     incx = 1;
667     incy = 1;
668   }
669   #if AT_BUILD_WITH_BLAS()
670   if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
671   {
672     int i_n = (int)n;
673     int i_incx = (int)incx;
674     int i_incy = (int)incy;
675     #if C10_IOS
676     cblas_saxpy(i_n, a, x, i_incx, y, i_incy);
677     #else
678     saxpy_(&i_n, &a, x, &i_incx, y, &i_incy);
679     #endif
680     return;
681   }
682   #endif
683   axpy_stub(
684       kCPU, at::kFloat,
685       n, a, x, incx, y, incy);
686 }
687 
688 void axpy(int64_t n, c10::complex<double> a, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy) {
689   if(n == 1)
690   {
691     incx = 1;
692     incy = 1;
693   }
694   #if AT_BUILD_WITH_BLAS()
695   if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
696   {
697     int i_n = (int)n;
698     int i_incx = (int)incx;
699     int i_incy = (int)incy;
700     #if C10_IOS
701     cblas_zaxpy(i_n, &a, x, i_incx, y, i_incy);
702     #else
703     zaxpy_(&i_n, &a, x, &i_incx, y, &i_incy);
704     #endif
705     return;
706   }
707   #endif
708   axpy_stub(
709       kCPU, at::kComplexDouble,
710       n, a, x, incx, y, incy);
711 }
712 
713 void axpy(int64_t n, c10::complex<float> a, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy) {
714   if(n == 1)
715   {
716     incx = 1;
717     incy = 1;
718   }
719   #if AT_BUILD_WITH_BLAS()
720   if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
721   {
722     int i_n = (int)n;
723     int i_incx = (int)incx;
724     int i_incy = (int)incy;
725     #if C10_IOS
726     cblas_caxpy(i_n, &a, x, i_incx, y, i_incy);
727     #else
728     caxpy_(&i_n, &a, x, &i_incx, y, &i_incy);
729     #endif
730     return;
731   }
732   #endif
733   axpy_stub(
734       kCPU, at::kComplexFloat,
735       n, a, x, incx, y, incy);
736 }
737 
738 DEFINE_DISPATCH(copy_stub);
739 
740 void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy) {
741   if(n == 1)
742   {
743     incx = 1;
744     incy = 1;
745   }
746   #if AT_BUILD_WITH_BLAS()
747   if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) {
748     int i_n = (int)n;
749     int i_incx = (int)incx;
750     int i_incy = (int)incy;
751     #if C10_IOS
752     cblas_dcopy(i_n, x, i_incx, y, i_incy);
753     #else
754     dcopy_(&i_n, x, &i_incx, y, &i_incy);
755     #endif
756     return;
757   }
758   #endif
759   copy_stub(
760       kCPU, at::kDouble,
761       n, x, incx, y, incy);
762 }
763 
764 void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy) {
765   if(n == 1)
766   {
767     incx = 1;
768     incy = 1;
769   }
770   #if AT_BUILD_WITH_BLAS()
771   if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) {
772     int i_n = (int)n;
773     int i_incx = (int)incx;
774     int i_incy = (int)incy;
775     #if C10_IOS
776     cblas_scopy(i_n, x, i_incx, y, i_incy);
777     #else
778     scopy_(&i_n, x, &i_incx, y, &i_incy);
779     #endif
780     return;
781   }
782   #endif
783   copy_stub(
784       kCPU, at::kFloat,
785       n, x, incx, y, incy);
786 }
787 
788 void copy(int64_t n, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy) {
789   if(n == 1)
790   {
791     incx = 1;
792     incy = 1;
793   }
794   #if AT_BUILD_WITH_BLAS()
795   if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) {
796     int i_n = (int)n;
797     int i_incx = (int)incx;
798     int i_incy = (int)incy;
799     #if C10_IOS
800     cblas_zcopy(i_n, x, i_incx, y, i_incy);
801     #else
802     zcopy_(&i_n, x, &i_incx, y, &i_incy);
803     #endif
804     return;
805   }
806   #endif
807   copy_stub(
808       kCPU, at::kComplexDouble,
809       n, x, incx, y, incy);
810 }
811 
812 void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy) {
813   if(n == 1)
814   {
815     incx = 1;
816     incy = 1;
817   }
818   #if AT_BUILD_WITH_BLAS()
819   if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) {
820     int i_n = (int)n;
821     int i_incx = (int)incx;
822     int i_incy = (int)incy;
823     #if C10_IOS
824     cblas_ccopy(i_n, x, i_incx, y, i_incy);
825     #else
826     ccopy_(&i_n, x, &i_incx, y, &i_incy);
827     #endif
828     return;
829   }
830   #endif
831   copy_stub(
832       kCPU, at::kComplexFloat,
833       n, x, incx, y, incy);
834 }
835 
836 // oneDNN BRGEMM
837 #if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
838 struct BrgemmKey {
839   int64_t M;
840   int64_t N;
841   int64_t K;
842   int64_t batch_size;
843   int64_t lda;
844   int64_t ldb;
845   int64_t ldc;
846   ScalarType dt_a;
847   ScalarType dt_b;
848   ScalarType dt_c;
849   float alpha;
850   float beta;
851   BrgemmKey(
852       int64_t M,
853       int64_t N,
854       int64_t K,
855       int64_t batch_size,
856       int64_t lda,
857       int64_t ldb,
858       int64_t ldc,
859       ScalarType dt_a,
860       ScalarType dt_b,
861       ScalarType dt_c,
862       float alpha,
863       float beta)
864       : M(M),
865         N(N),
866         K(K),
867         batch_size(batch_size),
868         lda(lda),
869         ldb(ldb),
870         ldc(ldc),
871         dt_a(dt_a),
872         dt_b(dt_b),
873         dt_c(dt_c),
874         alpha(alpha),
875         beta(beta) {}
876   bool operator==(const BrgemmKey& other) const {
877     return M == other.M && N == other.N && K == other.K &&
878         batch_size == other.batch_size && lda == other.lda &&
879         ldb == other.ldb && ldc == other.ldc && dt_a == other.dt_a &&
880         dt_b == other.dt_b && dt_c == other.dt_c && alpha == other.alpha &&
881         beta == other.beta;
882   }
883 };
884 
885 struct PackKey {
886   int64_t K;
887   int64_t N;
888   int64_t ld_in;
889   int64_t ld_out;
890   ScalarType dt_in;
891   ScalarType dt_out;
892   PackKey(
893       int64_t K,
894       int64_t N,
895       int64_t ld_in,
896       int64_t ld_out,
897       ScalarType dt_in,
898       ScalarType dt_out)
899       : K(K),
900         N(N),
901         ld_in(ld_in),
902         ld_out(ld_out),
903         dt_in(dt_in),
904         dt_out(dt_out) {}
905   bool operator==(const PackKey& other) const {
906     return N == other.N && K == other.K && ld_in == other.ld_in &&
907         ld_out == other.ld_out && dt_in == other.dt_in &&
908         dt_out == other.dt_out;
909   }
910 };
911 
912 inline dnnl::memory::data_type get_dnnl_dtype(ScalarType dtype) {
913   if (dtype == ScalarType::Float) {
914     return dnnl::memory::data_type::f32;
915   } else if (dtype == ScalarType::BFloat16) {
916     return dnnl::memory::data_type::bf16;
917   } else if (dtype == ScalarType::Half) {
918     return dnnl::memory::data_type::f16;
919   } else if (dtype == ScalarType::Byte) {
920     return dnnl::memory::data_type::u8;
921   } else if (dtype == ScalarType::Char) {
922     return dnnl::memory::data_type::s8;
923   } else {
924     TORCH_CHECK(false, "get_dnnl_dtype expects float/bfloat16/half/int8 tensor input");
925   }
926 }
927 
928 template<typename key_t>
929 struct UnsafeUkernelKeyHasher {
930   std::size_t operator()(const key_t& key) const;
931 };
932 
933 template<>
934 std::size_t UnsafeUkernelKeyHasher<BrgemmKey>::operator()(const BrgemmKey& key) const {
935   // Hash only beta, M, N, K, and ldc to keep the overhead low: batch size, alpha,
936   // and the data types are unlikely to change within the same kernel, and the
937   // leading dimensions are typically tied to M, K, N or use fixed values.
938   std::size_t h = std::hash<float>()(key.beta + 1);
939   h = std::hash<int64_t>()(key.M) ^ (h << 1);
940   h = std::hash<int64_t>()(key.N) ^ (h << 1);
941   h = std::hash<int64_t>()(key.K) ^ (h << 1);
942   h = std::hash<int64_t>()(key.ldc) ^ (h << 1);
943   return h;
944 }
945 
946 template<>
947 std::size_t UnsafeUkernelKeyHasher<PackKey>::operator()(const PackKey& key) const {
948   // Hash only K and N to keep the overhead low: the data types are unlikely
949   // to change, and ld_in/ld_out are typically tied to K and N or use fixed
950   // values.
951   std::size_t h = std::hash<int64_t>()(key.K);
952   h = std::hash<int64_t>()(key.N) ^ (h << 1);
953   return h;
954 }
955 
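// Thread-local cache mapping ukernel configuration keys to generated kernel objects,
// so repeated calls with the same shapes and dtypes reuse already-generated ukernels.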
956 template <typename key_t, typename value_t>
957 struct KernelCache  {
958   using kstore_t = std::unordered_map<key_t, std::shared_ptr<value_t>, UnsafeUkernelKeyHasher<key_t>>;
959   static inline std::shared_ptr<value_t>&& fetch_or_create(
960       const key_t& key,
961       const std::function<std::shared_ptr<value_t>()>& callback) {
962     auto&& search = get_store().find(key);
963     if (search != get_store().end()) {
964       return std::move(search->second);
965     } else {
966       get_store().insert({key, callback()});
967       return std::move(get_store()[key]);
968     }
969   }
970 
971   static inline kstore_t& get_store() {
972     static thread_local kstore_t cache_kernels;
973     return cache_kernels;
974   }
975 };
976 
977 // Helper struct for convenient brgemm configuration
978 struct GemmHelper {
979   GemmHelper(
980       int64_t M,
981       int64_t N,
982       int64_t K,
983       int64_t bs,
984       int64_t ld_a,
985       int64_t ld_b,
986       int64_t ld_c,
987       ScalarType dt_a,
988       ScalarType dt_b,
989       ScalarType dt_c,
990       const float alpha,
991       const float beta) {
992     // Create brgemm
993     brg = dnnl::ukernel::brgemm(
994         M,
995         N,
996         K,
997         bs,
998         ld_a,
999         ld_b,
1000         ld_c,
1001         get_dnnl_dtype(dt_a),
1002         get_dnnl_dtype(dt_b),
1003         get_dnnl_dtype(dt_c),
1004         alpha,
1005         beta);
1006     // Create a scratchpad buffer for the brgemm execution
1007     scratchpad = std::vector<uint8_t>(brg.get_scratchpad_size());
1008     // Prepare the default pair of A and B offsets for the single batch entry.
1009     A_B_offsets.resize(1);
1010     A_B_offsets[0] = std::make_pair(0, 0);
1011   }
1012   dnnl::ukernel::brgemm brg;
1013   std::vector<uint8_t> scratchpad;
1014   std::vector<std::pair<int64_t, int64_t>> A_B_offsets;
1015 };
1016 
1017 struct Brgemm : public KernelCache <BrgemmKey, GemmHelper> {
1018   // Fetch/create GemmHelper object and execute brgemm with batch size = 1
1019   template <typename scalar_t_a, typename scalar_t_b, typename scalar_t_c>
1020   static inline void call(
1021       int64_t M,
1022       int64_t N,
1023       int64_t K,
1024       int64_t ld_a,
1025       int64_t ld_b,
1026       int64_t ld_c,
1027       const float alpha,
1028       const float beta,
1029       const scalar_t_a* A,
1030       const scalar_t_b* B,
1031       scalar_t_c* C) {
1032     auto&& key = BrgemmKey(
1033         M,
1034         N,
1035         K,
1036         int64_t(1),
1037         ld_a,
1038         ld_b,
1039         ld_c,
1040         c10::CppTypeToScalarType<scalar_t_a>::value,
1041         c10::CppTypeToScalarType<scalar_t_b>::value,
1042         c10::CppTypeToScalarType<scalar_t_c>::value,
1043         alpha,
1044         beta);
1045     // Fetch/create GemmHelper object
1046     auto&& value = fetch_or_create(key, [&]() {
1047       auto&& v = std::make_shared<GemmHelper>(
1048           M,
1049           N,
1050           K,
1051           1,
1052           ld_a,
1053           ld_b,
1054           ld_c,
1055           c10::CppTypeToScalarType<scalar_t_a>::value,
1056           c10::CppTypeToScalarType<scalar_t_b>::value,
1057           c10::CppTypeToScalarType<scalar_t_c>::value,
1058           alpha,
1059           beta);
1060       (*v).brg.generate();
1061       return std::move(v);
1062     });
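    // Only switch the hardware context (e.g. the AMX tile configuration) when the
    // active kernel changes.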
1063     if (get_current() != value) {
1064       dnnl::ukernel::brgemm::release_hw_context();
1065       ((*value).brg).set_hw_context();
1066       get_current() = value;
1067     }
1068     ((*value).brg)
1069         .execute(A, B, (*value).A_B_offsets, C, (*value).scratchpad.data());
1070   }
1071 
1072   static inline std::shared_ptr<GemmHelper>& get_current() {
1073     static thread_local std::shared_ptr<GemmHelper> current;
1074     return current;
1075   }
1076 
1077   static inline bool device_check(ScalarType dtype) {
1078     if (!at::globalContext().userEnabledMkldnn()) {
1079       return false;
1080     }
1081     if (dtype == ScalarType::Half) {
1082       static bool fp16_support = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_fp16;
1083       return fp16_support;
1084     }
1085     return false;
1086   }
1087 };
1088 
1089 using pack_t = dnnl::ukernel::brgemm_pack_B;
1090 struct Pack : public KernelCache <PackKey, pack_t> {
1091   static inline void call(
1092       int64_t K,
1093       int64_t N,
1094       int64_t ld_in,
1095       int64_t ld_out,
1096       ScalarType dt_in,
1097       ScalarType dt_out,
1098       const void* in,
1099       void* out) {
1100     auto&& key = PackKey(K, N, ld_in, ld_out, dt_in, dt_out);
1101     auto&& pack = fetch_or_create(key, [&]() {
1102       auto&& p = std::make_shared<pack_t>(
1103           K, N, ld_in, ld_out, get_dnnl_dtype(dt_in), get_dnnl_dtype(dt_out));
1104       if (need_pack(dt_in)) {
1105         (*p).generate();
1106       }
1107       return std::move(p);
1108     });
1109     if (need_pack(dt_in)) {
1110       (*pack).execute(in, out);
1111     } else {
1112       TORCH_CHECK(false, "No need to pack");
1113     }
1114   }
1115 
1116   static inline bool need_pack(ScalarType dtype) {
1117     if (!at::globalContext().userEnabledMkldnn()) {
1118       return false;
1119     }
1120     if (dtype == ScalarType::Half) {
1121       static bool fp16_pack = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_amx_fp16;
1122       return fp16_pack;
1123     }
1124     return false;
1125   }
1126 };
1127 #endif
1128 
1129 void brgemm(
1130     int64_t M,
1131     int64_t N,
1132     int64_t K,
1133     int64_t ld_a,
1134     int64_t ld_b,
1135     int64_t ld_c,
1136     const float alpha,
1137     const float beta,
1138     const at::Half* A,
1139     const at::Half* B,
1140     float* C) {
1141 #if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
1142   if (Brgemm::device_check(ScalarType::Half)) {
1143     Brgemm::call<at::Half, at::Half, float>(
1144       M, N, K, ld_a, ld_b, ld_c, alpha, beta, A, B, C);
1145     return;
1146   }
1147 #endif
1148   TORCH_CHECK(false,
1149   "Half Brgemm is only supported on X64 when oneDNN ukernel is enabled and avx512_fp16 is supported");
1150 }
1151 
1152 void brgemm_release() {
1153 #if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
1154   dnnl::ukernel::brgemm::release_hw_context();
1155 #endif
1156 }
1157 
1158 void pack(
1159     int64_t K,
1160     int64_t N,
1161     int64_t ld_in,
1162     int64_t ld_out,
1163     ScalarType dt_in,
1164     ScalarType dt_out,
1165     const void* in,
1166     void* out) {
1167 #if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
1168   Pack::call(K, N, ld_in, ld_out, dt_in, dt_out, in, out);
1169 #else
1170   TORCH_CHECK(false, "pack is only supported on X64 with oneDNN ukernel enabled");
1171 #endif
1172 }
1173 
1174 bool need_pack(ScalarType dt_in) {
1175 #if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
1176   return Pack::need_pack(dt_in);
1177 #else
1178   return false;
1179 #endif
1180 }
1181 
1182 } // namespace at::native::cpublas
1183