#pragma once
/*
  Provides a subset of CUDA BLAS functions as templates:

    gemm<Dtype>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)

    gemv<Dtype>(transa, m, n, alpha, a, lda, x, incx, beta, y, incy)

    dot<Dtype>(n, x, incx, y, incy, result)

  where Dtype is double, float, at::Half, or at::BFloat16 (ROCm, NOT for dot).
  The functions are available in the at::cuda::blas namespace.
 */

#include <ATen/cuda/CUDAContext.h>
#include <ATen/OpMathType.h>

namespace at::cuda::blas {

// RAII guard that sets the cuBLAS pointer mode and restores it to
// its previous value when the guard is destroyed.
class PointerModeGuard {
public:
  PointerModeGuard(cublasHandle_t handle, cublasPointerMode_t mode) :
      handle(handle) {
    TORCH_CUDABLAS_CHECK(cublasGetPointerMode(handle, &previous_mode));
    TORCH_CUDABLAS_CHECK(cublasSetPointerMode(handle, mode));
  }

  ~PointerModeGuard() {
    cublasSetPointerMode(handle, previous_mode);
  }

private:
  cublasHandle_t handle;
  cublasPointerMode_t previous_mode;
};
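
// Example (an illustrative sketch, not part of this header): forcing device
// pointer mode for a region of cuBLAS calls, assuming `handle` is the current
// handle (e.g. from at::cuda::getCurrentCUDABlasHandle()):
//
//   {
//     at::cuda::blas::PointerModeGuard guard(handle, CUBLAS_POINTER_MODE_DEVICE);
//     // cuBLAS calls here read alpha/beta scalars from device memory
//   } // destructor restores the previous pointer mode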

/* LEVEL 3 BLAS FUNCTIONS */

#define CUDABLAS_GEMM_ARGTYPES(Dtype)                                                       \
  char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type<Dtype> alpha,  \
      const Dtype *a, int64_t lda, const Dtype *b, int64_t ldb, at::opmath_type<Dtype> beta,\
      Dtype *c, int64_t ldc

#define CUDABLAS_GEMM_ARGS(Dtype) transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc

template <typename Dtype>
inline void gemm(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
  static_assert(false && sizeof(Dtype), "at::cuda::blas::gemm: not implemented");
}

template <>
void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double));
template <>
void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float));
template <>
void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>));
template <>
void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>));
template <>
void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
template <>
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
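
// Example (a hedged sketch, assuming column-major device buffers d_a (m x k),
// d_b (k x n), d_c (m x n)): C = alpha * A * B + beta * C in single precision:
//
//   float alpha = 1.0f;
//   float beta = 0.0f;
//   at::cuda::blas::gemm<float>(
//       'n', 'n', m, n, k, alpha, d_a, m, d_b, k, beta, d_c, m);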

template <typename Dtype>
inline void gemm_internal(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
  static_assert(false && sizeof(Dtype), "at::cuda::blas::gemm_internal: not implemented");
}

template <>
void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double));
template <>
void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float));
template <>
void gemm_internal<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>));
template <>
void gemm_internal<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>));
template <>
void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
template <>
void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));

enum GEMMAndBiasActivationEpilogue {
  None,
  RELU,
  GELU,
};

// NOTE: GELU activation is not supported prior to CUDA 11.4, and is
// silently ignored if requested in that case.
template <typename Dtype>
void gemm_and_bias(
    bool transpose_mat1,
    bool transpose_mat2,
    int64_t m,
    int64_t n,
    int64_t k,
    at::opmath_type<Dtype> alpha_val,
    const Dtype* mat1_ptr,
    int64_t mat1_ld,
    const Dtype* mat2_ptr,
    int64_t mat2_ld,
    const Dtype* bias,
    Dtype* result_ptr,
    int64_t result_ld,
    GEMMAndBiasActivationEpilogue activation = GEMMAndBiasActivationEpilogue::None);
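
// Example (an illustrative sketch, assuming device pointers d_w, d_x, d_bias,
// d_out sized for an (m x k) * (k x n) product): a fused linear layer with a
// ReLU epilogue, out = relu(w @ x + bias):
//
//   at::cuda::blas::gemm_and_bias<float>(
//       /*transpose_mat1=*/false, /*transpose_mat2=*/false,
//       m, n, k, /*alpha_val=*/1.0f,
//       d_w, m, d_x, k, d_bias, d_out, m,
//       at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU);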

void int8_gemm(
    bool transpose_mat1,
    bool transpose_mat2,
    int64_t m,
    int64_t n,
    int64_t k,
    const int8_t* mat1_ptr,
    int64_t mat1_ld,
    const int8_t* mat2_ptr,
    int64_t mat2_ld,
    int32_t* result_ptr,
    int64_t result_ld);
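
// Example (a minimal sketch; layout requirements are left to the caller,
// assuming int8 device inputs d_a, d_b and an int32 device output d_c with
// compatible leading dimensions):
//
//   at::cuda::blas::int8_gemm(
//       /*transpose_mat1=*/false, /*transpose_mat2=*/false,
//       m, n, k, d_a, m, d_b, k, d_c, m);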

void scaled_gemm(
    char transa,
    char transb,
    int64_t m,
    int64_t n,
    int64_t k,
    const void* mat1_ptr,
    const void* mat1_scale_ptr,
    int64_t mat1_ld,
    ScalarType mat1_dtype,
    const void* mat2_ptr,
    const void* mat2_scale_ptr,
    int64_t mat2_ld,
    ScalarType mat2_dtype,
    const void* bias_ptr,
    ScalarType bias_dtype,
    void* result_ptr,
    const void* result_scale_ptr,
    int64_t result_ld,
    ScalarType result_dtype,
    void* amax_ptr,
    bool use_fast_accum);
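
// Example (a speculative sketch; exact dtype/layout support depends on the
// cuBLASLt version and hardware): an FP8 matmul with per-tensor scales, where
// d_a, d_b, d_out, the scale pointers, and d_amax are assumed to be device
// buffers allocated by the caller:
//
//   at::cuda::blas::scaled_gemm(
//       't', 'n', m, n, k,
//       d_a, d_scale_a, k, at::kFloat8_e4m3fn,
//       d_b, d_scale_b, k, at::kFloat8_e4m3fn,
//       /*bias_ptr=*/nullptr, at::kBFloat16,
//       d_out, d_scale_out, m, at::kBFloat16,
//       d_amax, /*use_fast_accum=*/false);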

#define CUDABLAS_BGEMM_ARGTYPES(Dtype)                                                        \
  char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type<Dtype> alpha,    \
      const Dtype *a, int64_t lda, int64_t stridea,                                           \
      const Dtype *b, int64_t ldb, int64_t strideb,                                           \
      at::opmath_type<Dtype> beta, Dtype *c, int64_t ldc, int64_t stridec, int64_t num_batches

#define CUDABLAS_BGEMM_ARGS(Dtype) \
  transa, transb, m, n, k, alpha, a, lda, stridea, b, ldb, strideb, beta, c, ldc, stridec, num_batches

template <typename Dtype>
inline void bgemm(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
  static_assert(false && sizeof(Dtype), "at::cuda::blas::bgemm: not implemented");
}

template <>
void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double));
template <>
void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float));
template <>
void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>));
template <>
void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>));
template <>
void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
template <>
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
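
// Example (a minimal sketch, assuming d_a, d_b, d_c each hold `batch`
// column-major matrices packed contiguously): C_i = A_i * B_i for every i:
//
//   at::cuda::blas::bgemm<float>(
//       'n', 'n', m, n, k, /*alpha=*/1.0f,
//       d_a, m, m * k,   // lda, stridea
//       d_b, k, k * n,   // ldb, strideb
//       /*beta=*/0.0f, d_c, m, m * n, batch);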

template <typename Dtype>
inline void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
  static_assert(false && sizeof(Dtype), "at::cuda::blas::bgemm_internal: not implemented");
}

template <>
void bgemm_internal<double>(CUDABLAS_BGEMM_ARGTYPES(double));
template <>
void bgemm_internal<float>(CUDABLAS_BGEMM_ARGTYPES(float));
template <>
void bgemm_internal<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>));
template <>
void bgemm_internal<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>));
template <>
void bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
template <>
void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
#define CUDABLAS_TRSM_ARGTYPES(Dtype)                                  \
  cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, \
      cublasOperation_t trans, cublasDiagType_t diag, int m, int n,    \
      const Dtype *alpha, const Dtype *A, int lda, Dtype *B, int ldb

template <typename Dtype>
inline void trsm(CUDABLAS_TRSM_ARGTYPES(Dtype)) {
  static_assert(false && sizeof(Dtype), "at::cuda::blas::trsm: not implemented");
}

template <>
TORCH_CUDA_CU_API void trsm<float>(CUDABLAS_TRSM_ARGTYPES(float));
template <>
TORCH_CUDA_CU_API void trsm<double>(CUDABLAS_TRSM_ARGTYPES(double));
template <>
TORCH_CUDA_CU_API void trsm<c10::complex<float>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<float>));
template <>
TORCH_CUDA_CU_API void trsm<c10::complex<double>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<double>));
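
// Example (a hedged sketch): solving A * X = alpha * B in place of B, with A
// lower-triangular and non-unit diagonal, assuming `handle` and the device
// pointers are set up by the caller:
//
//   float alpha = 1.0f;
//   at::cuda::blas::trsm<float>(
//       handle, CUBLAS_SIDE_LEFT, CUBLAS_FILL_MODE_LOWER,
//       CUBLAS_OP_N, CUBLAS_DIAG_NON_UNIT, m, n, &alpha, d_a, m, d_b, m);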

#define CUDABLAS_TRSM_BATCHED_ARGTYPES(Dtype)                          \
  cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, \
      cublasOperation_t trans, cublasDiagType_t diag, int m, int n,    \
      const Dtype *alpha, Dtype *A[], int lda, Dtype *B[], int ldb,    \
      int batchCount

template <typename Dtype>
inline void trsmBatched(CUDABLAS_TRSM_BATCHED_ARGTYPES(Dtype)) {
  static_assert(false && sizeof(Dtype), "at::cuda::blas::trsmBatched: not implemented");
}

template <>
TORCH_CUDA_CU_API void trsmBatched<float>(CUDABLAS_TRSM_BATCHED_ARGTYPES(float));
template <>
TORCH_CUDA_CU_API void trsmBatched<double>(CUDABLAS_TRSM_BATCHED_ARGTYPES(double));
template <>
TORCH_CUDA_CU_API void trsmBatched<c10::complex<float>>(CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<float>));
template <>
TORCH_CUDA_CU_API void trsmBatched<c10::complex<double>>(CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<double>));
/* LEVEL 2 BLAS FUNCTIONS */

#define CUDABLAS_GEMV_ARGTYPES(Dtype)                                         \
  char trans, int64_t m, int64_t n, Dtype alpha, const Dtype *a, int64_t lda, \
      const Dtype *x, int64_t incx, Dtype beta, Dtype *y, int64_t incy

template <typename Dtype>
inline void gemv(CUDABLAS_GEMV_ARGTYPES(Dtype)) {
  static_assert(false && sizeof(Dtype), "at::cuda::blas::gemv: not implemented");
}

template <>
void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double));
template <>
void gemv<float>(CUDABLAS_GEMV_ARGTYPES(float));
template <>
void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>));
template <>
void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>));
template <>
void gemv<at::Half>(CUDABLAS_GEMV_ARGTYPES(at::Half));
template <>
void gemv<at::BFloat16>(CUDABLAS_GEMV_ARGTYPES(at::BFloat16));
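
// Example (a minimal sketch, assuming a column-major (m x n) device matrix
// d_a and device vectors d_x, d_y with unit strides): y = alpha * A * x + beta * y:
//
//   at::cuda::blas::gemv<float>(
//       'n', m, n, /*alpha=*/1.0f, d_a, m, d_x, 1, /*beta=*/0.0f, d_y, 1);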

/* LEVEL 1 BLAS FUNCTIONS */

#define CUDABLAS_DOT_ARGTYPES(Dtype)                                      \
  cublasHandle_t handle, int n, const Dtype *x, int incx, const Dtype *y, \
      int incy, Dtype *result

template <typename Dtype>
inline void dot(CUDABLAS_DOT_ARGTYPES(Dtype)) {
  static_assert(false && sizeof(Dtype), "at::cuda::blas::dot: not implemented");
}

template <>
void dot<double>(CUDABLAS_DOT_ARGTYPES(double));
template <>
void dot<float>(CUDABLAS_DOT_ARGTYPES(float));
template <>
void dot<at::Half>(CUDABLAS_DOT_ARGTYPES(at::Half));
template <>
void dot<at::BFloat16>(CUDABLAS_DOT_ARGTYPES(at::BFloat16));
template <>
void dot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>));
template <>
void dot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>));
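
// Example (an illustrative sketch): result = sum_i x[i] * y[i]; whether
// `result` points to host or device memory must follow the handle's current
// pointer mode (see PointerModeGuard above):
//
//   float result = 0.0f;
//   at::cuda::blas::dot<float>(handle, n, d_x, 1, d_y, 1, &result);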

// Dot product with the first vector conjugated (x^H y); defined for
// complex types only.
template <typename Dtype>
inline void vdot(CUDABLAS_DOT_ARGTYPES(Dtype)) {
  static_assert(false && sizeof(Dtype), "at::cuda::blas::vdot: not implemented");
}

template <>
void vdot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>));
template <>
void vdot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>));

#define CUDABLAS_GETRS_ARGTYPES(Dtype)  \
  cublasHandle_t handle, cublasOperation_t trans, \
  int n, int nrhs, Dtype** dA_array, int lda, int* ipiv_array, \
  Dtype** dB_array, int ldb, int* info_array, int batchsize

template<class Dtype>
void getrsBatched(CUDABLAS_GETRS_ARGTYPES(Dtype)) {
  static_assert(false && sizeof(Dtype), "at::cuda::blas::getrsBatched: not implemented");
}
template<>
TORCH_CUDA_CU_API void getrsBatched<float>(CUDABLAS_GETRS_ARGTYPES(float));
template<>
TORCH_CUDA_CU_API void getrsBatched<double>(CUDABLAS_GETRS_ARGTYPES(double));
template<>
TORCH_CUDA_CU_API void getrsBatched<c10::complex<float>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<float>));
template<>
TORCH_CUDA_CU_API void getrsBatched<c10::complex<double>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<double>));
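
// Example (a hedged sketch): solving A_i * X_i = B_i for a batch of
// LU-factored matrices, assuming d_a_array / d_b_array are device arrays of
// device pointers, and d_ipiv / info come from a preceding factorization
// (e.g. getrfBatched below):
//
//   at::cuda::blas::getrsBatched<float>(
//       handle, CUBLAS_OP_N, n, nrhs,
//       d_a_array, n, d_ipiv, d_b_array, n, info, batch_size);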

#define CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype)                   \
  cublasHandle_t handle, int m, int n, Dtype **A_array, int lda, \
      Dtype **tau_array, int *info, int batchsize

template <class Dtype>
void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype)) {
  static_assert(false && sizeof(Dtype), "at::cuda::blas::geqrfBatched: not implemented");
}
template <>
TORCH_CUDA_CU_API void geqrfBatched<float>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(float));
template <>
TORCH_CUDA_CU_API void geqrfBatched<double>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(double));
template <>
TORCH_CUDA_CU_API void geqrfBatched<c10::complex<double>>(
    CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<double>));
template <>
TORCH_CUDA_CU_API void geqrfBatched<c10::complex<float>>(
    CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<float>));
#define CUDABLAS_GETRF_ARGTYPES(Dtype)  \
  int n, Dtype** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize

template<class Dtype>
void getrfBatched(CUDABLAS_GETRF_ARGTYPES(Dtype)) {
  TORCH_CHECK(false, "at::cuda::blas::getrfBatched: not implemented");
}
template<>
TORCH_CUDA_CU_API void getrfBatched<float>(CUDABLAS_GETRF_ARGTYPES(float));
template<>
TORCH_CUDA_CU_API void getrfBatched<double>(CUDABLAS_GETRF_ARGTYPES(double));
template<>
TORCH_CUDA_CU_API void getrfBatched<c10::complex<double>>(CUDABLAS_GETRF_ARGTYPES(c10::complex<double>));
template<>
TORCH_CUDA_CU_API void getrfBatched<c10::complex<float>>(CUDABLAS_GETRF_ARGTYPES(c10::complex<float>));
#define CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype)                             \
  cublasHandle_t handle, cublasOperation_t trans, int m, int n, int nrhs, \
      Dtype** dA_array, int ldda, Dtype** dC_array, int lddc, int* info,  \
      int* devInfoArray, int batchSize

template <class Dtype>
void gelsBatched(CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype)) {
  static_assert(false && sizeof(Dtype), "at::cuda::blas::gelsBatched: not implemented");
}

template<>
TORCH_CUDA_CU_API void gelsBatched<double>(CUDABLAS_GELS_BATCHED_ARGTYPES(double));
template<>
TORCH_CUDA_CU_API void gelsBatched<float>(CUDABLAS_GELS_BATCHED_ARGTYPES(float));
template<>
TORCH_CUDA_CU_API void gelsBatched<c10::complex<double>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<double>));
template<>
TORCH_CUDA_CU_API void gelsBatched<c10::complex<float>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<float>));

} // namespace at::cuda::blas