#pragma once

#include <ATen/OpMathType.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TransposeType.h>
#include <c10/util/complex.h>
#include <c10/core/ScalarType.h>
#include <c10/core/Scalar.h>


namespace at::native::cpublas {

namespace internal {
void normalize_last_dims(
  TransposeType transa, TransposeType transb,
  int64_t m, int64_t n, int64_t k,
  int64_t *lda, int64_t *ldb, int64_t *ldc);
}  // namespace internal

using gemm_fn = void(*)(
    at::ScalarType type,
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const Scalar& alpha,
    const void *a, int64_t lda,
    const void *b, int64_t ldb,
    const Scalar& beta,
    void *c, int64_t ldc);

DECLARE_DISPATCH(gemm_fn, gemm_stub);

template <typename scalar_t>
void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    at::opmath_type<scalar_t> alpha,
    const scalar_t *a, int64_t lda,
    const scalar_t *b, int64_t ldb,
    at::opmath_type<scalar_t> beta,
    scalar_t *c, int64_t ldc) {
  // Fix up leading dimensions for degenerate shapes, then dispatch on the
  // runtime scalar type.
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
  gemm_stub(
    kCPU, c10::CppTypeToScalarType<scalar_t>::value,
    transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
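
// Illustrative usage (a sketch, not part of this header's API surface):
// compute C = alpha * A * B + beta * C for 2x2 float matrices, assuming the
// usual BLAS convention of column-major matrices, where lda/ldb/ldc are the
// leading dimensions.
//
//   float A[4] = {1.f, 2.f, 3.f, 4.f};  // 2x2, column-major
//   float B[4] = {5.f, 6.f, 7.f, 8.f};
//   float C[4] = {0.f, 0.f, 0.f, 0.f};
//   at::native::cpublas::gemm(
//       at::native::TransposeType::NoTranspose,
//       at::native::TransposeType::NoTranspose,
//       /*m=*/2, /*n=*/2, /*k=*/2,
//       /*alpha=*/1.f, A, /*lda=*/2, B, /*ldb=*/2,
//       /*beta=*/0.f, C, /*ldc=*/2);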

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    double alpha,
    const double *a, int64_t lda,
    const double *b, int64_t ldb,
    double beta,
    double *c, int64_t ldc);

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const float *a, int64_t lda,
    const float *b, int64_t ldb,
    float beta,
    float *c, int64_t ldc);

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const at::BFloat16 *a, int64_t lda,
    const at::BFloat16 *b, int64_t ldb,
    float beta,
    at::BFloat16 *c, int64_t ldc);

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const float alpha,
    const at::BFloat16 *a, int64_t lda,
    const at::BFloat16 *b, int64_t ldb,
    const float beta,
    float *c, int64_t ldc);

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const at::Half *a, int64_t lda,
    const at::Half *b, int64_t ldb,
    float beta,
    at::Half *c, int64_t ldc);

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const float alpha,
    const at::Half *a, int64_t lda,
    const at::Half *b, int64_t ldb,
    const float beta,
    float *c, int64_t ldc);

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    c10::complex<double> alpha,
    const c10::complex<double> *a, int64_t lda,
    const c10::complex<double> *b, int64_t ldb,
    c10::complex<double> beta,
    c10::complex<double> *c, int64_t ldc);

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    c10::complex<float> alpha,
    const c10::complex<float> *a, int64_t lda,
    const c10::complex<float> *b, int64_t ldb,
    c10::complex<float> beta,
    c10::complex<float> *c, int64_t ldc);

void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    int64_t alpha,
    const int64_t *a, int64_t lda,
    const int64_t *b, int64_t ldb,
    int64_t beta,
    int64_t *c, int64_t ldc);

// Batched GEMM taking arrays of pointers: a[i], b[i] and c[i] address the
// i-th matrices of the batch.
template <typename scalar_t>
void gemm_batched(
    TransposeType transa, TransposeType transb,
    int64_t batch_size, int64_t m, int64_t n, int64_t k,
    scalar_t alpha,
    const scalar_t * const *a, int64_t lda,
    const scalar_t * const *b, int64_t ldb,
    const scalar_t beta,
    scalar_t * const *c, int64_t ldc);

// Batched GEMM over contiguous buffers: the i-th matrices live at
// a + i * batch_stride_a, b + i * batch_stride_b and c + i * batch_stride_c.
template <typename scalar_t>
void gemm_batched_with_stride(
    TransposeType transa, TransposeType transb,
    int64_t batch_size, int64_t m, int64_t n, int64_t k,
    scalar_t alpha,
    const scalar_t *a, int64_t lda, int64_t batch_stride_a,
    const scalar_t *b, int64_t ldb, int64_t batch_stride_b,
    scalar_t beta,
    scalar_t *c, int64_t ldc, int64_t batch_stride_c);
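
// Illustrative sketch (assumed semantics, mirroring the single-matrix gemm
// above): two independent 2x2 float multiplies packed back to back in one
// buffer per operand.
//
//   float A[8], B[8], C[8];  // batch of 2, each 2x2 column-major
//   at::native::cpublas::gemm_batched_with_stride(
//       at::native::TransposeType::NoTranspose,
//       at::native::TransposeType::NoTranspose,
//       /*batch_size=*/2, /*m=*/2, /*n=*/2, /*k=*/2,
//       /*alpha=*/1.f,
//       A, /*lda=*/2, /*batch_stride_a=*/4,
//       B, /*ldb=*/2, /*batch_stride_b=*/4,
//       /*beta=*/0.f,
//       C, /*ldc=*/2, /*batch_stride_c=*/4);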

using axpy_fn = void(*)(at::ScalarType type, int64_t n, const Scalar& a, const void *x, int64_t incx, void *y, int64_t incy);

DECLARE_DISPATCH(axpy_fn, axpy_stub);

template<typename scalar_t>
void axpy(int64_t n, scalar_t a, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy) {
  if (n == 1) {
    // With a single element the strides are irrelevant; canonicalize them.
    incx = 1;
    incy = 1;
  }
  axpy_stub(
      kCPU, c10::CppTypeToScalarType<scalar_t>::value,
      n, a, x, incx, y, incy);
}

void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t incy);
void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t incy);
void axpy(int64_t n, c10::complex<double> a, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
void axpy(int64_t n, c10::complex<float> a, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
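
// Illustrative usage (sketch): y = a * x + y over 3 doubles with unit
// strides, following the standard BLAS axpy semantics.
//
//   double x[3] = {1.0, 2.0, 3.0};
//   double y[3] = {0.0, 0.0, 0.0};
//   at::native::cpublas::axpy(/*n=*/3, /*a=*/2.0, x, /*incx=*/1, y, /*incy=*/1);
//   // y is now {2.0, 4.0, 6.0}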

using copy_fn = void(*)(at::ScalarType type, int64_t n, const void *x, int64_t incx, void *y, int64_t incy);

DECLARE_DISPATCH(copy_fn, copy_stub);

template<typename scalar_t>
void copy(int64_t n, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy) {
  if (n == 1) {
    // With a single element the strides are irrelevant; canonicalize them.
    incx = 1;
    incy = 1;
  }
  copy_stub(
      kCPU, c10::CppTypeToScalarType<scalar_t>::value,
      n, x, incx, y, incy);
}

void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy);
void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy);
void copy(int64_t n, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
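
// Illustrative usage (sketch): strided copy, taking every other element of
// x into y, as in BLAS copy.
//
//   float x[4] = {1.f, 2.f, 3.f, 4.f};
//   float y[2];
//   at::native::cpublas::copy(/*n=*/2, x, /*incx=*/2, y, /*incy=*/1);
//   // y is now {1.f, 3.f}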

// Batch-reduce GEMM.
// Computes:
//   C = alpha * SUM(A[i] x B[i]) + beta * C, for i = 0 .. batch_size - 1
// A: base pointer to tensor A.
// B: base pointer to tensor B.
// C: pointer to tensor C (the accumulation buffer).
TORCH_API void brgemm(
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t ld_a,
    int64_t ld_b,
    int64_t ld_c,
    const float alpha,
    const float beta,
    const at::Half* A,
    const at::Half* B,
    float* C);

// Release the hardware context used by brgemm.
void brgemm_release();

// Pack the B matrix into a blocked layout for better performance, if needed.
void pack(
    int64_t K,
    int64_t N,
    int64_t ld_in,
    int64_t ld_out,
    ScalarType dt_in,
    ScalarType dt_out,
    const void* in,
    void* out);

// Whether packing is needed on the current platform for the given input dtype.
bool need_pack(ScalarType dt_in);
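
// Illustrative flow (a sketch; semantics and buffer sizes assumed from the
// comments above, with leading dimensions for plain MxK / KxN / MxN
// buffers): pack B when the platform wants it, run the reduction, then
// release the hardware context.
//
//   std::vector<at::Half> B_packed(K * N);  // size is illustrative only
//   const at::Half* B_used = B;
//   if (at::native::cpublas::need_pack(at::kHalf)) {
//     at::native::cpublas::pack(K, N, /*ld_in=*/N, /*ld_out=*/N,
//                               at::kHalf, at::kHalf, B, B_packed.data());
//     B_used = B_packed.data();
//   }
//   at::native::cpublas::brgemm(M, N, K, /*ld_a=*/K, /*ld_b=*/N, /*ld_c=*/N,
//                               /*alpha=*/1.f, /*beta=*/0.f, A, B_used, C);
//   at::native::cpublas::brgemm_release();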

} // namespace at::native::cpublas