1 #pragma once
2
3 #include <ATen/OpMathType.h>
4 #include <ATen/native/DispatchStub.h>
5 #include <ATen/native/TransposeType.h>
6 #include <c10/util/complex.h>
7 #include <c10/core/ScalarType.h>
8 #include <c10/core/Scalar.h>
9
10
namespace at::native::cpublas {

namespace internal {
// Adjusts the leading dimensions (lda/ldb/ldc, in-out) for degenerate
// m/n/k sizes before handing the call to a BLAS backend.
// NOTE(review): the exact normalization rules live in the .cpp — confirm
// there; this header only shows that the ld values may be rewritten.
void normalize_last_dims(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    int64_t *lda, int64_t *ldb, int64_t *ldc);
}  // namespace internal
19
// Type-erased gemm entry point used by the per-dtype dispatch stub below.
// `type` identifies the real element type behind the void pointers; alpha
// and beta travel as Scalar so one signature serves every dtype.
using gemm_fn = void(*)(
    at::ScalarType type,
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const Scalar& alpha,
    const void *a, int64_t lda,
    const void *b, int64_t ldb,
    const Scalar& beta,
    void *c, int64_t ldc);

DECLARE_DISPATCH(gemm_fn, gemm_stub);
31
32 template <typename scalar_t>
gemm(TransposeType transa,TransposeType transb,int64_t m,int64_t n,int64_t k,at::opmath_type<scalar_t> alpha,const scalar_t * a,int64_t lda,const scalar_t * b,int64_t ldb,at::opmath_type<scalar_t> beta,scalar_t * c,int64_t ldc)33 void gemm(
34 TransposeType transa, TransposeType transb,
35 int64_t m, int64_t n, int64_t k,
36 at::opmath_type<scalar_t> alpha,
37 const scalar_t *a, int64_t lda,
38 const scalar_t *b, int64_t ldb,
39 at::opmath_type<scalar_t> beta,
40 scalar_t *c, int64_t ldc) {
41 internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
42 gemm_stub(
43 kCPU, c10::CppTypeToScalarType<scalar_t>::value,
44 transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
45 }
46
// Concrete gemm overloads implemented in the backing .cpp. Overloads whose
// A/B element type differs from the alpha/beta/C type are mixed-precision
// variants: reduced-precision inputs with float scaling and float output.

// Double precision.
void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    double alpha,
    const double *a, int64_t lda,
    const double *b, int64_t ldb,
    double beta,
    double *c, int64_t ldc);

// Single precision.
void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const float *a, int64_t lda,
    const float *b, int64_t ldb,
    float beta,
    float *c, int64_t ldc);

// BFloat16 inputs and output; alpha/beta in float.
void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const at::BFloat16 *a, int64_t lda,
    const at::BFloat16 *b, int64_t ldb,
    float beta,
    at::BFloat16 *c, int64_t ldc);

// BFloat16 inputs, float output (C accumulated in float).
void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const float alpha,
    const at::BFloat16 *a, int64_t lda,
    const at::BFloat16 *b, int64_t ldb,
    const float beta,
    float *c, int64_t ldc);

// Half inputs and output; alpha/beta in float.
void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const at::Half *a, int64_t lda,
    const at::Half *b, int64_t ldb,
    float beta,
    at::Half *c, int64_t ldc);

// Half inputs, float output (C accumulated in float).
void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const float alpha,
    const at::Half *a, int64_t lda,
    const at::Half *b, int64_t ldb,
    const float beta,
    float *c, int64_t ldc);

// Complex double precision.
void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    c10::complex<double> alpha,
    const c10::complex<double> *a, int64_t lda,
    const c10::complex<double> *b, int64_t ldb,
    c10::complex<double> beta,
    c10::complex<double> *c, int64_t ldc);

// Complex single precision.
void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    c10::complex<float> alpha,
    const c10::complex<float> *a, int64_t lda,
    const c10::complex<float> *b, int64_t ldb,
    c10::complex<float> beta,
    c10::complex<float> *c, int64_t ldc);

// 64-bit integer gemm (no BLAS equivalent; implemented in the .cpp).
void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    int64_t alpha,
    const int64_t *a, int64_t lda,
    const int64_t *b, int64_t ldb,
    int64_t beta,
    int64_t *c, int64_t ldc);
127
// Batched gemm over pointer arrays: for each i in [0, batch_size),
//   C[i] = alpha * op(A[i]) * op(B[i]) + beta * C[i]
// a/b/c are arrays of batch_size matrix base pointers sharing one set of
// leading dimensions.
template <typename scalar_t>
void gemm_batched(
    TransposeType transa, TransposeType transb,
    int64_t batch_size, int64_t m, int64_t n, int64_t k,
    scalar_t alpha,
    const scalar_t * const *a, int64_t lda,
    const scalar_t * const *b, int64_t ldb,
    const scalar_t beta,
    scalar_t * const *c, int64_t ldc);
137
// Batched gemm over strided storage: matrix i of each operand starts at
// base_ptr + i * batch_stride_{a,b,c}, avoiding per-batch pointer arrays.
template <typename scalar_t>
void gemm_batched_with_stride(
    TransposeType transa, TransposeType transb,
    int64_t batch_size, int64_t m, int64_t n, int64_t k,
    scalar_t alpha,
    const scalar_t *a, int64_t lda, int64_t batch_stride_a,
    const scalar_t *b, int64_t ldb, int64_t batch_stride_b,
    scalar_t beta,
    scalar_t *c, int64_t ldc, int64_t batch_stride_c);
147
// Type-erased axpy (y += a*x over n strided elements) plus its dispatch stub;
// `type` identifies the element type behind the void pointers.
using axpy_fn = void(*)(at::ScalarType type, int64_t n, const Scalar& a, const void *x, int64_t incx, void *y, int64_t incy);

DECLARE_DISPATCH(axpy_fn, axpy_stub);
151
152 template<typename scalar_t>
axpy(int64_t n,scalar_t a,const scalar_t * x,int64_t incx,scalar_t * y,int64_t incy)153 void axpy(int64_t n, scalar_t a, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy){
154 if(n == 1)
155 {
156 incx = 1;
157 incy = 1;
158 }
159 axpy_stub(
160 kCPU, c10::CppTypeToScalarType<scalar_t>::value,
161 n, a, x, incx, y, incy);
162 }
163
// Concrete axpy overloads implemented in the backing .cpp.
void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t incy);
void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t incy);
void axpy(int64_t n, c10::complex<double> a, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
void axpy(int64_t n, c10::complex<float> a, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
168
// Type-erased strided copy (y = x over n elements) plus its dispatch stub;
// `type` identifies the element type behind the void pointers.
using copy_fn = void(*)(at::ScalarType type, int64_t n, const void *x, int64_t incx, void *y, int64_t incy);

DECLARE_DISPATCH(copy_fn, copy_stub);
172
173 template<typename scalar_t>
copy(int64_t n,const scalar_t * x,int64_t incx,scalar_t * y,int64_t incy)174 void copy(int64_t n, const scalar_t *x, int64_t incx, scalar_t *y, int64_t incy) {
175 if(n == 1)
176 {
177 incx = 1;
178 incy = 1;
179 }
180 copy_stub(
181 kCPU, c10::CppTypeToScalarType<scalar_t>::value,
182 n, x, incx, y, incy);
183 }
184
// Concrete copy overloads implemented in the backing .cpp.
void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy);
void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy);
void copy(int64_t n, const c10::complex<double> *x, int64_t incx, c10::complex<double> *y, int64_t incy);
void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<float> *y, int64_t incy);
189
// Batch-reduce GEMM
// Operates by the following formula:
//   C = alpha * SUM(A[i] x B[i]) + beta * C, i = 0 to batch size
// A    Base pointer to a tensor A.
// B    Base pointer to a tensor B.
// C    Pointer to a tensor C (accumulation buffer).
// NOTE(review): this overload exposes no batch-count parameter, so it
// presumably performs a single accumulation step (batch size 1) — confirm
// against the implementation in the .cpp.
TORCH_API void brgemm(
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t ld_a,
    int64_t ld_b,
    int64_t ld_c,
    const float alpha,
    const float beta,
    const at::Half* A,
    const at::Half* B,
    float* C);

// Release brgemm hardware context. Call after a sequence of brgemm calls;
// the specific hardware state involved is not visible from this header.
void brgemm_release();
211
// Pack B matrix to get better performance if needed
// K, N           logical dimensions of B.
// ld_in, ld_out  leading dimensions of the source and packed buffers.
// dt_in, dt_out  element types before and after packing.
// in, out        source buffer and destination packed buffer.
void pack(
    int64_t K,
    int64_t N,
    int64_t ld_in,
    int64_t ld_out,
    ScalarType dt_in,
    ScalarType dt_out,
    const void* in,
    void* out);

// Whether pack is needed in the platform for inputs of type dt_in.
bool need_pack(ScalarType dt_in);
225
226 } // namespace at::native::cpublas
227