1 #pragma once 2 #include <ATen/Config.h> 3 #include <ATen/native/TransposeType.h> 4 #include <c10/util/complex.h> 5 #include <c10/core/ScalarType.h> 6 7 #if !AT_MKL_ENABLED() 8 #define MKL_INT int 9 #else 10 #include <mkl.h> 11 #endif 12 13 namespace at { 14 namespace native { 15 16 void mkl_gemm_batched( 17 TransposeType trans_A, TransposeType trans_B, 18 MKL_INT batch_size, MKL_INT M, MKL_INT N, MKL_INT K, float alpha, 19 const float** A, MKL_INT lda, const float** B, MKL_INT ldb, float beta, 20 float** C, MKL_INT ldc); 21 22 void mkl_gemm_batched( 23 TransposeType trans_A, TransposeType trans_B, 24 MKL_INT batch_size, MKL_INT M, MKL_INT N, MKL_INT K, double alpha, 25 const double** A, MKL_INT lda, const double** B, MKL_INT ldb, double beta, 26 double** C, MKL_INT ldc); 27 28 void mkl_gemm_batched( 29 TransposeType trans_A, TransposeType trans_B, 30 MKL_INT batch_size, MKL_INT M, MKL_INT N, MKL_INT K, c10::complex<float> alpha, 31 const c10::complex<float>** A, MKL_INT lda, const c10::complex<float>** B, MKL_INT ldb, 32 c10::complex<float> beta, c10::complex<float>** C, MKL_INT ldc); 33 34 void mkl_gemm_batched( 35 TransposeType trans_A, TransposeType trans_B, 36 MKL_INT batch_size, MKL_INT M, MKL_INT N, MKL_INT K, c10::complex<double> alpha, 37 const c10::complex<double>** A, MKL_INT lda, const c10::complex<double>** B, MKL_INT ldb, 38 c10::complex<double> beta, c10::complex<double>** C, MKL_INT ldc); 39 40 void mkl_gemm_bf16bf16f32( 41 TransposeType trans_A, TransposeType trans_B, 42 MKL_INT M, MKL_INT N, MKL_INT K, const float alpha, 43 const c10::BFloat16* A, MKL_INT lda, const c10::BFloat16* B, MKL_INT ldb, 44 const float beta, float* C, MKL_INT ldc); 45 46 void mkl_gemm_f16f16f32( 47 TransposeType trans_A, TransposeType trans_B, 48 int M, int N, int K, const float alpha, 49 const c10::Half* A, int lda, const c10::Half* B, int ldb, 50 const float beta, float* C, int ldc); 51 }} // namespace at::native 52