xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/mkl/LinearAlgebra.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/Config.h>
3 #include <ATen/native/TransposeType.h>
4 #include <c10/util/complex.h>
5 #include <c10/core/ScalarType.h>
6 
7 #if !AT_MKL_ENABLED()
8 #define MKL_INT int
9 #else
10 #include <mkl.h>
11 #endif
12 
13 namespace at {
14 namespace native {
15 
16 void mkl_gemm_batched(
17     TransposeType trans_A, TransposeType trans_B,
18     MKL_INT batch_size, MKL_INT M, MKL_INT N, MKL_INT K, float alpha,
19     const float** A, MKL_INT lda, const float** B, MKL_INT ldb, float beta,
20     float** C, MKL_INT ldc);
21 
22 void mkl_gemm_batched(
23     TransposeType trans_A, TransposeType trans_B,
24     MKL_INT batch_size, MKL_INT M, MKL_INT N, MKL_INT K, double alpha,
25     const double** A, MKL_INT lda, const double** B, MKL_INT ldb, double beta,
26     double** C, MKL_INT ldc);
27 
28 void mkl_gemm_batched(
29     TransposeType trans_A, TransposeType trans_B,
30     MKL_INT batch_size, MKL_INT M, MKL_INT N, MKL_INT K, c10::complex<float> alpha,
31     const c10::complex<float>** A, MKL_INT lda, const c10::complex<float>** B, MKL_INT ldb,
32     c10::complex<float> beta, c10::complex<float>** C, MKL_INT ldc);
33 
34 void mkl_gemm_batched(
35     TransposeType trans_A, TransposeType trans_B,
36     MKL_INT batch_size, MKL_INT M, MKL_INT N, MKL_INT K, c10::complex<double> alpha,
37     const c10::complex<double>** A, MKL_INT lda, const c10::complex<double>** B, MKL_INT ldb,
38     c10::complex<double> beta, c10::complex<double>** C, MKL_INT ldc);
39 
40 void mkl_gemm_bf16bf16f32(
41     TransposeType trans_A, TransposeType trans_B,
42     MKL_INT M, MKL_INT N, MKL_INT K, const float alpha,
43     const c10::BFloat16* A, MKL_INT lda, const c10::BFloat16* B, MKL_INT ldb,
44     const float beta, float* C, MKL_INT ldc);
45 
46 void mkl_gemm_f16f16f32(
47     TransposeType trans_A, TransposeType trans_B,
48     int M, int N, int K, const float alpha,
49     const c10::Half* A, int lda, const c10::Half* B, int ldb,
50     const float beta, float* C, int ldc);
51 }}  // namespace at::native
52