xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/mkl/LinearAlgebra.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/native/mkl/LinearAlgebra.h>
3 #include <ATen/Config.h>
4 
5 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
6 #if !AT_MKL_ENABLED()
7 
8 namespace at { namespace native {
9 
mkl_gemm_batched(const TransposeType trans_A,const TransposeType trans_B,const MKL_INT batch_size,const MKL_INT M,const MKL_INT N,const MKL_INT K,const float alpha,const float ** A,const MKL_INT lda,const float ** B,const MKL_INT ldb,const float beta,float ** C,const MKL_INT ldc)10 void mkl_gemm_batched(
11     const TransposeType trans_A, const TransposeType trans_B,
12     const MKL_INT batch_size, const MKL_INT M, const MKL_INT N, const MKL_INT K, const float alpha,
13     const float** A, const MKL_INT lda, const float** B, const MKL_INT ldb, const float beta,
14     float** C, const MKL_INT ldc) {
15   TORCH_INTERNAL_ASSERT(false, "mkl_gemm_batched: ATen not compiled with MKL support");
16 }
17 
mkl_gemm_batched(const TransposeType trans_A,const TransposeType trans_B,const MKL_INT batch_size,const MKL_INT M,const MKL_INT N,const MKL_INT K,const double alpha,const double ** A,const MKL_INT lda,const double ** B,const MKL_INT ldb,const double beta,double ** C,const MKL_INT ldc)18 void mkl_gemm_batched(
19     const TransposeType trans_A, const TransposeType trans_B,
20     const MKL_INT batch_size, const MKL_INT M, const MKL_INT N, const MKL_INT K, const double alpha,
21     const double** A, const MKL_INT lda, const double** B, const MKL_INT ldb, const double beta,
22     double** C, const MKL_INT ldc) {
23   TORCH_INTERNAL_ASSERT(false, "mkl_gemm_batched: ATen not compiled with MKL support");
24 }
25 
mkl_gemm_batched(const TransposeType trans_A,const TransposeType trans_B,const MKL_INT batch_size,const MKL_INT M,const MKL_INT N,const MKL_INT K,const c10::complex<float> alpha,const c10::complex<float> ** A,const MKL_INT lda,const c10::complex<float> ** B,const MKL_INT ldb,const c10::complex<float> beta,c10::complex<float> ** C,const MKL_INT ldc)26 void mkl_gemm_batched(
27     const TransposeType trans_A, const TransposeType trans_B,
28     const MKL_INT batch_size, const MKL_INT M, const MKL_INT N, const MKL_INT K, const c10::complex<float> alpha,
29     const c10::complex<float>** A, const MKL_INT lda, const c10::complex<float>** B, const MKL_INT ldb,
30     const c10::complex<float> beta, c10::complex<float>** C, const MKL_INT ldc) {
31   TORCH_INTERNAL_ASSERT(false, "mkl_gemm_batched: ATen not compiled with MKL support");
32 }
33 
mkl_gemm_batched(const TransposeType trans_A,const TransposeType trans_B,const MKL_INT batch_size,const MKL_INT M,const MKL_INT N,const MKL_INT K,const c10::complex<double> alpha,const c10::complex<double> ** A,const MKL_INT lda,const c10::complex<double> ** B,const MKL_INT ldb,const c10::complex<double> beta,c10::complex<double> ** C,const MKL_INT ldc)34 void mkl_gemm_batched(
35     const TransposeType trans_A, const TransposeType trans_B,
36     const MKL_INT batch_size, const MKL_INT M, const MKL_INT N, const MKL_INT K, const c10::complex<double> alpha,
37     const c10::complex<double>** A, const MKL_INT lda, const c10::complex<double>** B, const MKL_INT ldb,
38     const c10::complex<double> beta, c10::complex<double>** C, const MKL_INT ldc) {
39   TORCH_INTERNAL_ASSERT(false, "mkl_gemm_batched: ATen not compiled with MKL support");
40 }
41 
mkl_gemm_bf16bf16f32(TransposeType trans_A,TransposeType trans_B,MKL_INT M,MKL_INT N,MKL_INT K,const float alpha,const c10::BFloat16 * A,MKL_INT lda,const c10::BFloat16 * B,MKL_INT ldb,const float beta,float * C,MKL_INT ldc)42 void mkl_gemm_bf16bf16f32(
43     TransposeType trans_A, TransposeType trans_B,
44     MKL_INT M, MKL_INT N, MKL_INT K, const float alpha,
45     const c10::BFloat16* A, MKL_INT lda, const c10::BFloat16* B, MKL_INT ldb,
46     const float beta, float* C, MKL_INT ldc) {
47   TORCH_INTERNAL_ASSERT(false, "mkl_gemm_bf16bf16f32: ATen not compiled with MKL support");
48 }
49 
mkl_gemm_f16f16f32(TransposeType trans_A,TransposeType trans_B,int M,int N,int K,const float alpha,const c10::Half * A,int lda,const c10::Half * B,int ldb,const float beta,float * C,int ldc)50 void mkl_gemm_f16f16f32(
51     TransposeType trans_A, TransposeType trans_B,
52     int M, int N, int K, const float alpha,
53     const c10::Half* A, int lda, const c10::Half* B, int ldb,
54     const float beta, float* C, int ldc) {
55   TORCH_INTERNAL_ASSERT(false, "mkl_gemm_f16f16f32: ATen not compiled with MKL support");
56 }
57 
58 }}
59 
60 #else // AT_MKL_ENABLED
61 
62 #include <mkl.h>
63 #include <c10/util/irange.h>
64 
65 namespace at { namespace native {
66 
67 static CBLAS_TRANSPOSE to_cblas(TransposeType x) {
68   switch (x) {
69     case TransposeType::NoTranspose: return CblasNoTrans;
70     case TransposeType::Transpose: return CblasTrans;
71     case TransposeType::ConjTranspose: return CblasConjTrans;
72   }
73   TORCH_INTERNAL_ASSERT(false, "Unknown TransposeType");
74 }
75 
76 void mkl_gemm_batched(
77     const TransposeType trans_A, const TransposeType trans_B,
78     const MKL_INT batch_size, const MKL_INT M, const MKL_INT N, const MKL_INT K, const float alpha,
79     const float** A, const MKL_INT lda, const float** B, const MKL_INT ldb, const float beta,
80     float** C, const MKL_INT ldc) {
81   auto transa_cblas = to_cblas(trans_A);
82   auto transb_cblas = to_cblas(trans_B);
83   cblas_sgemm_batch(CblasColMajor, &transa_cblas, &transb_cblas, &M, &N, &K, &alpha,
84                     A, &lda, B, &ldb, &beta, C, &ldc, 1, &batch_size);
85 }
86 
87 void mkl_gemm_batched(
88     const TransposeType trans_A, const TransposeType trans_B,
89     const MKL_INT batch_size, const MKL_INT M, const MKL_INT N, const MKL_INT K, const double alpha,
90     const double** A, const MKL_INT lda, const double** B, const MKL_INT ldb, const double beta,
91     double** C, const MKL_INT ldc) {
92   auto transa_cblas = to_cblas(trans_A);
93   auto transb_cblas = to_cblas(trans_B);
94   cblas_dgemm_batch(CblasColMajor, &transa_cblas, &transb_cblas, &M, &N, &K, &alpha,
95                     A, &lda, B, &ldb, &beta, C, &ldc, 1, &batch_size);
96 }
97 
98 void mkl_gemm_batched(
99     const TransposeType trans_A, const TransposeType trans_B,
100     const MKL_INT batch_size, const MKL_INT M, const MKL_INT N, const MKL_INT K, const c10::complex<float> alpha,
101     const c10::complex<float>** A, const MKL_INT lda, const c10::complex<float>** B, const MKL_INT ldb,
102     const c10::complex<float> beta, c10::complex<float>** C, const MKL_INT ldc) {
103   auto transa_cblas = to_cblas(trans_A);
104   auto transb_cblas = to_cblas(trans_B);
105   cblas_cgemm_batch(CblasColMajor, &transa_cblas, &transb_cblas, &M, &N, &K,
106                     reinterpret_cast<const void*>(&alpha),
107                     reinterpret_cast<const void**>(A), &lda, reinterpret_cast<const void**>(B), &ldb,
108                     reinterpret_cast<const void*>(&beta), reinterpret_cast<void**>(C), &ldc, 1, &batch_size);
109 }
110 
111 void mkl_gemm_batched(
112     const TransposeType trans_A, const TransposeType trans_B,
113     const MKL_INT batch_size, const MKL_INT M, const MKL_INT N, const MKL_INT K, const c10::complex<double> alpha,
114     const c10::complex<double>** A, const MKL_INT lda, const c10::complex<double>** B, const MKL_INT ldb,
115     const c10::complex<double> beta, c10::complex<double>** C, const MKL_INT ldc) {
116   auto transa_cblas = to_cblas(trans_A);
117   auto transb_cblas = to_cblas(trans_B);
118   cblas_zgemm_batch(CblasColMajor, &transa_cblas, &transb_cblas, &M, &N, &K,
119                     reinterpret_cast<const void*>(&alpha),
120                     reinterpret_cast<const void**>(A), &lda, reinterpret_cast<const void**>(B), &ldb,
121                     reinterpret_cast<const void*>(&beta), reinterpret_cast<void**>(C), &ldc, 1, &batch_size);
122 }
123 
124 void mkl_gemm_bf16bf16f32(
125     TransposeType trans_A, TransposeType trans_B,
126     MKL_INT M, MKL_INT N, MKL_INT K, const float alpha,
127     const c10::BFloat16* A, MKL_INT lda, const c10::BFloat16* B, MKL_INT ldb,
128     const float beta, float* C, MKL_INT ldc) {
129 #ifdef MKL_HAS_SBGEMM
130   auto transa_cblas = to_cblas(trans_A);
131   auto transb_cblas = to_cblas(trans_B);
132   cblas_gemm_bf16bf16f32(CblasColMajor, transa_cblas, transb_cblas, M, N, K, alpha,
133                          (const MKL_BF16*)A, lda, (const MKL_BF16*)B, ldb, beta, C, ldc);
134 #else
135   TORCH_INTERNAL_ASSERT(false, "mkl_gemm_bf16bf16f32 requires mkl version > 2021.0");
136 #endif
137 }
138 
139 void mkl_gemm_f16f16f32(
140     TransposeType trans_A, TransposeType trans_B,
141     int M, int N, int K, const float alpha,
142     const c10::Half* A, int lda, const c10::Half* B, int ldb,
143     const float beta, float* C, int ldc) {
144 #ifdef MKL_HAS_SHGEMM
145   auto transa_cblas = to_cblas(trans_A);
146   auto transb_cblas = to_cblas(trans_B);
147   cblas_gemm_f16f16f32(CblasColMajor, transa_cblas, transb_cblas, M, N, K, alpha,
148                          (const MKL_F16*)A, lda, (const MKL_F16*)B, ldb, beta, C, ldc);
149 #else
150   TORCH_INTERNAL_ASSERT(false, "mkl_gemm_f16f16f32 requires mkl version >= 2024.0");
151 #endif
152 }
153 
154 }} // namespace at::native
155 
156 #endif
157 C10_DIAGNOSTIC_POP()
158