1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/native/mkl/LinearAlgebra.h>
3 #include <ATen/Config.h>
4
5 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
6 #if !AT_MKL_ENABLED()
7
8 namespace at { namespace native {
9
mkl_gemm_batched(const TransposeType trans_A,const TransposeType trans_B,const MKL_INT batch_size,const MKL_INT M,const MKL_INT N,const MKL_INT K,const float alpha,const float ** A,const MKL_INT lda,const float ** B,const MKL_INT ldb,const float beta,float ** C,const MKL_INT ldc)10 void mkl_gemm_batched(
11 const TransposeType trans_A, const TransposeType trans_B,
12 const MKL_INT batch_size, const MKL_INT M, const MKL_INT N, const MKL_INT K, const float alpha,
13 const float** A, const MKL_INT lda, const float** B, const MKL_INT ldb, const float beta,
14 float** C, const MKL_INT ldc) {
15 TORCH_INTERNAL_ASSERT(false, "mkl_gemm_batched: ATen not compiled with MKL support");
16 }
17
mkl_gemm_batched(const TransposeType trans_A,const TransposeType trans_B,const MKL_INT batch_size,const MKL_INT M,const MKL_INT N,const MKL_INT K,const double alpha,const double ** A,const MKL_INT lda,const double ** B,const MKL_INT ldb,const double beta,double ** C,const MKL_INT ldc)18 void mkl_gemm_batched(
19 const TransposeType trans_A, const TransposeType trans_B,
20 const MKL_INT batch_size, const MKL_INT M, const MKL_INT N, const MKL_INT K, const double alpha,
21 const double** A, const MKL_INT lda, const double** B, const MKL_INT ldb, const double beta,
22 double** C, const MKL_INT ldc) {
23 TORCH_INTERNAL_ASSERT(false, "mkl_gemm_batched: ATen not compiled with MKL support");
24 }
25
mkl_gemm_batched(const TransposeType trans_A,const TransposeType trans_B,const MKL_INT batch_size,const MKL_INT M,const MKL_INT N,const MKL_INT K,const c10::complex<float> alpha,const c10::complex<float> ** A,const MKL_INT lda,const c10::complex<float> ** B,const MKL_INT ldb,const c10::complex<float> beta,c10::complex<float> ** C,const MKL_INT ldc)26 void mkl_gemm_batched(
27 const TransposeType trans_A, const TransposeType trans_B,
28 const MKL_INT batch_size, const MKL_INT M, const MKL_INT N, const MKL_INT K, const c10::complex<float> alpha,
29 const c10::complex<float>** A, const MKL_INT lda, const c10::complex<float>** B, const MKL_INT ldb,
30 const c10::complex<float> beta, c10::complex<float>** C, const MKL_INT ldc) {
31 TORCH_INTERNAL_ASSERT(false, "mkl_gemm_batched: ATen not compiled with MKL support");
32 }
33
mkl_gemm_batched(const TransposeType trans_A,const TransposeType trans_B,const MKL_INT batch_size,const MKL_INT M,const MKL_INT N,const MKL_INT K,const c10::complex<double> alpha,const c10::complex<double> ** A,const MKL_INT lda,const c10::complex<double> ** B,const MKL_INT ldb,const c10::complex<double> beta,c10::complex<double> ** C,const MKL_INT ldc)34 void mkl_gemm_batched(
35 const TransposeType trans_A, const TransposeType trans_B,
36 const MKL_INT batch_size, const MKL_INT M, const MKL_INT N, const MKL_INT K, const c10::complex<double> alpha,
37 const c10::complex<double>** A, const MKL_INT lda, const c10::complex<double>** B, const MKL_INT ldb,
38 const c10::complex<double> beta, c10::complex<double>** C, const MKL_INT ldc) {
39 TORCH_INTERNAL_ASSERT(false, "mkl_gemm_batched: ATen not compiled with MKL support");
40 }
41
mkl_gemm_bf16bf16f32(TransposeType trans_A,TransposeType trans_B,MKL_INT M,MKL_INT N,MKL_INT K,const float alpha,const c10::BFloat16 * A,MKL_INT lda,const c10::BFloat16 * B,MKL_INT ldb,const float beta,float * C,MKL_INT ldc)42 void mkl_gemm_bf16bf16f32(
43 TransposeType trans_A, TransposeType trans_B,
44 MKL_INT M, MKL_INT N, MKL_INT K, const float alpha,
45 const c10::BFloat16* A, MKL_INT lda, const c10::BFloat16* B, MKL_INT ldb,
46 const float beta, float* C, MKL_INT ldc) {
47 TORCH_INTERNAL_ASSERT(false, "mkl_gemm_bf16bf16f32: ATen not compiled with MKL support");
48 }
49
mkl_gemm_f16f16f32(TransposeType trans_A,TransposeType trans_B,int M,int N,int K,const float alpha,const c10::Half * A,int lda,const c10::Half * B,int ldb,const float beta,float * C,int ldc)50 void mkl_gemm_f16f16f32(
51 TransposeType trans_A, TransposeType trans_B,
52 int M, int N, int K, const float alpha,
53 const c10::Half* A, int lda, const c10::Half* B, int ldb,
54 const float beta, float* C, int ldc) {
55 TORCH_INTERNAL_ASSERT(false, "mkl_gemm_f16f16f32: ATen not compiled with MKL support");
56 }
57
58 }}
59
60 #else // AT_MKL_ENABLED
61
62 #include <mkl.h>
63 #include <c10/util/irange.h>
64
65 namespace at { namespace native {
66
67 static CBLAS_TRANSPOSE to_cblas(TransposeType x) {
68 switch (x) {
69 case TransposeType::NoTranspose: return CblasNoTrans;
70 case TransposeType::Transpose: return CblasTrans;
71 case TransposeType::ConjTranspose: return CblasConjTrans;
72 }
73 TORCH_INTERNAL_ASSERT(false, "Unknown TransposeType");
74 }
75
76 void mkl_gemm_batched(
77 const TransposeType trans_A, const TransposeType trans_B,
78 const MKL_INT batch_size, const MKL_INT M, const MKL_INT N, const MKL_INT K, const float alpha,
79 const float** A, const MKL_INT lda, const float** B, const MKL_INT ldb, const float beta,
80 float** C, const MKL_INT ldc) {
81 auto transa_cblas = to_cblas(trans_A);
82 auto transb_cblas = to_cblas(trans_B);
83 cblas_sgemm_batch(CblasColMajor, &transa_cblas, &transb_cblas, &M, &N, &K, &alpha,
84 A, &lda, B, &ldb, &beta, C, &ldc, 1, &batch_size);
85 }
86
87 void mkl_gemm_batched(
88 const TransposeType trans_A, const TransposeType trans_B,
89 const MKL_INT batch_size, const MKL_INT M, const MKL_INT N, const MKL_INT K, const double alpha,
90 const double** A, const MKL_INT lda, const double** B, const MKL_INT ldb, const double beta,
91 double** C, const MKL_INT ldc) {
92 auto transa_cblas = to_cblas(trans_A);
93 auto transb_cblas = to_cblas(trans_B);
94 cblas_dgemm_batch(CblasColMajor, &transa_cblas, &transb_cblas, &M, &N, &K, &alpha,
95 A, &lda, B, &ldb, &beta, C, &ldc, 1, &batch_size);
96 }
97
98 void mkl_gemm_batched(
99 const TransposeType trans_A, const TransposeType trans_B,
100 const MKL_INT batch_size, const MKL_INT M, const MKL_INT N, const MKL_INT K, const c10::complex<float> alpha,
101 const c10::complex<float>** A, const MKL_INT lda, const c10::complex<float>** B, const MKL_INT ldb,
102 const c10::complex<float> beta, c10::complex<float>** C, const MKL_INT ldc) {
103 auto transa_cblas = to_cblas(trans_A);
104 auto transb_cblas = to_cblas(trans_B);
105 cblas_cgemm_batch(CblasColMajor, &transa_cblas, &transb_cblas, &M, &N, &K,
106 reinterpret_cast<const void*>(&alpha),
107 reinterpret_cast<const void**>(A), &lda, reinterpret_cast<const void**>(B), &ldb,
108 reinterpret_cast<const void*>(&beta), reinterpret_cast<void**>(C), &ldc, 1, &batch_size);
109 }
110
111 void mkl_gemm_batched(
112 const TransposeType trans_A, const TransposeType trans_B,
113 const MKL_INT batch_size, const MKL_INT M, const MKL_INT N, const MKL_INT K, const c10::complex<double> alpha,
114 const c10::complex<double>** A, const MKL_INT lda, const c10::complex<double>** B, const MKL_INT ldb,
115 const c10::complex<double> beta, c10::complex<double>** C, const MKL_INT ldc) {
116 auto transa_cblas = to_cblas(trans_A);
117 auto transb_cblas = to_cblas(trans_B);
118 cblas_zgemm_batch(CblasColMajor, &transa_cblas, &transb_cblas, &M, &N, &K,
119 reinterpret_cast<const void*>(&alpha),
120 reinterpret_cast<const void**>(A), &lda, reinterpret_cast<const void**>(B), &ldb,
121 reinterpret_cast<const void*>(&beta), reinterpret_cast<void**>(C), &ldc, 1, &batch_size);
122 }
123
124 void mkl_gemm_bf16bf16f32(
125 TransposeType trans_A, TransposeType trans_B,
126 MKL_INT M, MKL_INT N, MKL_INT K, const float alpha,
127 const c10::BFloat16* A, MKL_INT lda, const c10::BFloat16* B, MKL_INT ldb,
128 const float beta, float* C, MKL_INT ldc) {
129 #ifdef MKL_HAS_SBGEMM
130 auto transa_cblas = to_cblas(trans_A);
131 auto transb_cblas = to_cblas(trans_B);
132 cblas_gemm_bf16bf16f32(CblasColMajor, transa_cblas, transb_cblas, M, N, K, alpha,
133 (const MKL_BF16*)A, lda, (const MKL_BF16*)B, ldb, beta, C, ldc);
134 #else
135 TORCH_INTERNAL_ASSERT(false, "mkl_gemm_bf16bf16f32 requires mkl version > 2021.0");
136 #endif
137 }
138
139 void mkl_gemm_f16f16f32(
140 TransposeType trans_A, TransposeType trans_B,
141 int M, int N, int K, const float alpha,
142 const c10::Half* A, int lda, const c10::Half* B, int ldb,
143 const float beta, float* C, int ldc) {
144 #ifdef MKL_HAS_SHGEMM
145 auto transa_cblas = to_cblas(trans_A);
146 auto transb_cblas = to_cblas(trans_B);
147 cblas_gemm_f16f16f32(CblasColMajor, transa_cblas, transb_cblas, M, N, K, alpha,
148 (const MKL_F16*)A, lda, (const MKL_F16*)B, ldb, beta, C, ldc);
149 #else
150 TORCH_INTERNAL_ASSERT(false, "mkl_gemm_f16f16f32 requires mkl version >= 2024.0");
151 #endif
152 }
153
154 }} // namespace at::native
155
156 #endif
157 C10_DIAGNOSTIC_POP()
158