1 #pragma once
2 /*
3 Provides a subset of CUDA BLAS functions as templates:
4
5 gemm<Dtype>(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c,
6 ldc)
7
8 gemv<Dtype>(transa, m, n, alpha, a, lda, x, incx, beta, y, incy)
9
10 dot<Dtype>(n, x, incx, y, incy, result)
11
12 where Dtype is double, float, at::Half or at::BFloat16 (ROCm, NOT for dot).
13 The functions are available in at::cuda::blas namespace.
14 */
15
16 #include <ATen/cuda/CUDAContext.h>
17 #include <ATen/OpMathType.h>
18
19 namespace at::cuda::blas {
20
21 // RAII guard that sets the CuBLAS pointer mode and restores it to
22 // its previous value when the guard is destroyed
23 class PointerModeGuard {
24 public:
PointerModeGuard(cublasHandle_t handle,cublasPointerMode_t mode)25 PointerModeGuard(cublasHandle_t handle, cublasPointerMode_t mode) :
26 handle(handle) {
27 TORCH_CUDABLAS_CHECK(cublasGetPointerMode(handle, &previous_mode));
28 TORCH_CUDABLAS_CHECK(cublasSetPointerMode(handle, mode));
29 }
30
~PointerModeGuard()31 ~PointerModeGuard() {
32 cublasSetPointerMode(handle, previous_mode);
33 }
34
35 private:
36 cublasHandle_t handle;
37 cublasPointerMode_t previous_mode;
38 };
39
40 /* LEVEL 3 BLAS FUNCTIONS */
41
// Shared parameter list for every gemm<Dtype> overload. alpha/beta use
// at::opmath_type<Dtype>, so reduced-precision Dtypes (e.g. at::Half,
// at::BFloat16) take their scaling factors in the wider math type.
#define CUDABLAS_GEMM_ARGTYPES(Dtype)  \
  char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type<Dtype> alpha, \
  const Dtype *a, int64_t lda, const Dtype *b, int64_t ldb, at::opmath_type<Dtype> beta,\
  Dtype *c, int64_t ldc

// Forwards exactly the names declared by CUDABLAS_GEMM_ARGTYPES.
#define CUDABLAS_GEMM_ARGS(Dtype) transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc
48
49 template <typename Dtype>
gemm(CUDABLAS_GEMM_ARGTYPES (Dtype))50 inline void gemm(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
51 static_assert(false&&sizeof(Dtype),"at::cuda::blas::gemm: not implemented");
52 }
53
54 template <>
55 void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double));
56 template <>
57 void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float));
58 template <>
59 void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>));
60 template <>
61 void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>));
62 template <>
63 void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
64 template <>
65 void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
66
67 template <typename Dtype>
gemm_internal(CUDABLAS_GEMM_ARGTYPES (Dtype))68 inline void gemm_internal(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
69 static_assert(false&&sizeof(Dtype),"at::cuda::blas::gemm_internal: not implemented");
70 }
71
72 template <>
73 void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double));
74 template <>
75 void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float));
76 template <>
77 void gemm_internal<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>));
78 template <>
79 void gemm_internal<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>));
80 template <>
81 void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
82 template <>
83 void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
84
// Activation applied by gemm_and_bias after the matmul + bias step.
// NOTE: this is an unscoped enum, so None/RELU/GELU also become names in
// the enclosing namespace — keep that in mind when adding enumerators.
enum GEMMAndBiasActivationEpilogue {
  None,
  RELU,
  GELU,
};
90
// NOTE: GELU activation is not supported prior to CUDA 11.4 and will
// do nothing if passed in that case.
//
// GEMM with a fused bias addition and an optional activation epilogue
// (default: no activation). m/n/k are the usual BLAS problem sizes;
// mat1_ld / mat2_ld / result_ld are the leading dimensions of the
// corresponding matrices. NOTE(review): the required length and layout of
// `bias` are determined by the implementation (.cpp) — presumably n
// elements broadcast across the output; confirm there before relying on it.
template <typename Dtype>
void gemm_and_bias(
    bool transpose_mat1,
    bool transpose_mat2,
    int64_t m,
    int64_t n,
    int64_t k,
    at::opmath_type<Dtype> alpha_val,
    const Dtype* mat1_ptr,
    int64_t mat1_ld,
    const Dtype* mat2_ptr,
    int64_t mat2_ld,
    const Dtype* bias,
    Dtype* result_ptr,
    int64_t result_ld,
    GEMMAndBiasActivationEpilogue activation = GEMMAndBiasActivationEpilogue::None);
109
// Integer GEMM: int8 operands, int32 result. Dimension and leading-dim
// parameters mirror gemm_and_bias above (no alpha/beta/bias here).
void int8_gemm(
    bool transpose_mat1,
    bool transpose_mat2,
    int64_t m,
    int64_t n,
    int64_t k,
    const int8_t* mat1_ptr,
    int64_t mat1_ld,
    const int8_t* mat2_ptr,
    int64_t mat2_ld,
    int32_t* result_ptr,
    int64_t result_ld);
122
// Mixed-dtype GEMM with device-side scale factors per operand and for the
// result (used for low-precision, e.g. FP8, matmuls). Each matrix carries
// its own ScalarType alongside its data and scale pointers; `amax_ptr`
// is an output written by the implementation. NOTE(review): the exact
// scaling / amax semantics and which pointers may be null are defined in
// the .cpp implementation — confirm there before relying on them.
void scaled_gemm(
    char transa,
    char transb,
    int64_t m,
    int64_t n,
    int64_t k,
    const void* mat1_ptr,
    const void* mat1_scale_ptr,
    int64_t mat1_ld,
    ScalarType mat1_dtype,
    const void* mat2_ptr,
    const void* mat2_scale_ptr,
    int64_t mat2_ld,
    ScalarType mat2_dtype,
    const void* bias_ptr,
    ScalarType bias_dtype,
    void* result_ptr,
    const void* result_scale_ptr,
    int64_t result_ld,
    ScalarType result_dtype,
    void* amax_ptr,
    bool use_fast_accum);
145
// Shared parameter list for the strided-batched gemm entry points; extends
// CUDABLAS_GEMM_ARGTYPES with per-matrix batch strides and a batch count.
#define CUDABLAS_BGEMM_ARGTYPES(Dtype)  \
  char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type<Dtype> alpha, \
  const Dtype *a, int64_t lda, int64_t stridea, \
  const Dtype *b, int64_t ldb, int64_t strideb, \
  at::opmath_type<Dtype> beta, Dtype *c, int64_t ldc, int64_t stridec, int64_t num_batches

// Forwards exactly the names declared by CUDABLAS_BGEMM_ARGTYPES.
#define CUDABLAS_BGEMM_ARGS(Dtype) \
  transa, transb, m, n, k, alpha, a, lda, stridea, b, ldb, strideb, beta, c, ldc, stridec, num_batches
154
155 template <typename Dtype>
bgemm(CUDABLAS_BGEMM_ARGTYPES (Dtype))156 inline void bgemm(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
157 static_assert(false&&sizeof(Dtype),"at::cuda::blas::bgemm: not implemented");
158 }
159
160 template <>
161 void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double));
162 template <>
163 void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float));
164 template <>
165 void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>));
166 template <>
167 void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>));
168 template <>
169 void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
170 template <>
171 void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
172
173 template <typename Dtype>
bgemm_internal(CUDABLAS_BGEMM_ARGTYPES (Dtype))174 inline void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
175 static_assert(false&&sizeof(Dtype),"at::cuda::blas::bgemm_internal: not implemented");
176 }
177
178 template <>
179 void bgemm_internal<double>(CUDABLAS_BGEMM_ARGTYPES(double));
180 template <>
181 void bgemm_internal<float>(CUDABLAS_BGEMM_ARGTYPES(float));
182 template <>
183 void bgemm_internal<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>));
184 template <>
185 void bgemm_internal<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>));
186 template <>
187 void bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half));
188 template <>
189 void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16));
190
191 #define CUDABLAS_TRSM_ARGTYPES(Dtype) \
192 cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, \
193 cublasOperation_t trans, cublasDiagType_t diag, int m, int n, \
194 const Dtype *alpha, const Dtype *A, int lda, Dtype *B, int ldb
195
196 template <typename Dtype>
trsm(CUDABLAS_TRSM_ARGTYPES (Dtype))197 inline void trsm(CUDABLAS_TRSM_ARGTYPES(Dtype)) {
198 static_assert(false&&sizeof(Dtype), "at::cuda::blas::trsm: not implemented");
199 }
200
201 template <>
202 TORCH_CUDA_CU_API void trsm<float>(CUDABLAS_TRSM_ARGTYPES(float));
203 template <>
204 TORCH_CUDA_CU_API void trsm<double>(CUDABLAS_TRSM_ARGTYPES(double));
205 template <>
206 TORCH_CUDA_CU_API void trsm<c10::complex<float>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<float>));
207 template <>
208 TORCH_CUDA_CU_API void trsm<c10::complex<double>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<double>));
209
210 #define CUDABLAS_TRSM_BATCHED_ARGTYPES(Dtype) \
211 cublasHandle_t handle, cublasSideMode_t side, cublasFillMode_t uplo, \
212 cublasOperation_t trans, cublasDiagType_t diag, int m, int n, \
213 const Dtype *alpha, Dtype *A[], int lda, Dtype *B[], int ldb, \
214 int batchCount
215
216 template <typename Dtype>
trsmBatched(CUDABLAS_TRSM_BATCHED_ARGTYPES (Dtype))217 inline void trsmBatched(CUDABLAS_TRSM_BATCHED_ARGTYPES(Dtype)) {
218 static_assert(false&&sizeof(Dtype), "at::cuda::blas::trsmBatched: not implemented");
219 }
220
221 template <>
222 TORCH_CUDA_CU_API void trsmBatched<float>(CUDABLAS_TRSM_BATCHED_ARGTYPES(float));
223 template <>
224 TORCH_CUDA_CU_API void trsmBatched<double>(CUDABLAS_TRSM_BATCHED_ARGTYPES(double));
225 template <>
226 TORCH_CUDA_CU_API void trsmBatched<c10::complex<float>>(CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<float>));
227 template <>
228 TORCH_CUDA_CU_API void trsmBatched<c10::complex<double>>(CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<double>));
229
230 /* LEVEL 2 BLAS FUNCTIONS */
231
232 #define CUDABLAS_GEMV_ARGTYPES(Dtype) \
233 char trans, int64_t m, int64_t n, Dtype alpha, const Dtype *a, int64_t lda, \
234 const Dtype *x, int64_t incx, Dtype beta, Dtype *y, int64_t incy
235
236 template <typename Dtype>
gemv(CUDABLAS_GEMV_ARGTYPES (Dtype))237 inline void gemv(CUDABLAS_GEMV_ARGTYPES(Dtype)) {
238 static_assert(false&&sizeof(Dtype), "at::cuda::blas::gemv: not implemented");
239 }
240
241 template <>
242 void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double));
243 template <>
244 void gemv<float>(CUDABLAS_GEMV_ARGTYPES(float));
245 template <>
246 void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>));
247 template <>
248 void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>));
249 template <>
250 void gemv<at::Half>(CUDABLAS_GEMV_ARGTYPES(at::Half));
251 template <>
252 void gemv<at::BFloat16>(CUDABLAS_GEMV_ARGTYPES(at::BFloat16));
253
254 /* LEVEL 1 BLAS FUNCTIONS */
255
256 #define CUDABLAS_DOT_ARGTYPES(Dtype) \
257 cublasHandle_t handle, int n, const Dtype *x, int incx, const Dtype *y, \
258 int incy, Dtype *result
259
260 template <typename Dtype>
dot(CUDABLAS_DOT_ARGTYPES (Dtype))261 inline void dot(CUDABLAS_DOT_ARGTYPES(Dtype)) {
262 static_assert(false&&sizeof(Dtype),"at::cuda::blas::dot: not implemented");
263 }
264
265 template <>
266 void dot<double>(CUDABLAS_DOT_ARGTYPES(double));
267 template <>
268 void dot<float>(CUDABLAS_DOT_ARGTYPES(float));
269 template <>
270 void dot<at::Half>(CUDABLAS_DOT_ARGTYPES(at::Half));
271 template <>
272 void dot<at::BFloat16>(CUDABLAS_DOT_ARGTYPES(at::BFloat16));
273 template <>
274 void dot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>));
275 template <>
276 void dot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>));
277
278 template <typename Dtype>
vdot(CUDABLAS_DOT_ARGTYPES (Dtype))279 inline void vdot(CUDABLAS_DOT_ARGTYPES(Dtype)) {
280 static_assert(false&&sizeof(Dtype),"at::cuda::blas::vdot: not implemented");
281 }
282
283 template <>
284 void vdot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>));
285 template <>
286 void vdot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>));
287
288 #define CUDABLAS_GETRS_ARGTYPES(Dtype) \
289 cublasHandle_t handle, cublasOperation_t trans, \
290 int n, int nrhs, Dtype** dA_array, int lda, int* ipiv_array, \
291 Dtype** dB_array, int ldb, int* info_array, int batchsize
292
293 template<class Dtype>
getrsBatched(CUDABLAS_GETRS_ARGTYPES (Dtype))294 void getrsBatched(CUDABLAS_GETRS_ARGTYPES(Dtype)) {
295 static_assert(false&&sizeof(Dtype),"at::cuda::blas::getrsBatched: not implemented");
296 }
297 template<>
298 TORCH_CUDA_CU_API void getrsBatched<float>(CUDABLAS_GETRS_ARGTYPES(float));
299 template<>
300 TORCH_CUDA_CU_API void getrsBatched<double>(CUDABLAS_GETRS_ARGTYPES(double));
301 template<>
302 TORCH_CUDA_CU_API void getrsBatched<c10::complex<float>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<float>));
303 template<>
304 TORCH_CUDA_CU_API void getrsBatched<c10::complex<double>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<double>));
305
306 #define CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype) \
307 cublasHandle_t handle, int m, int n, Dtype **A_array, int lda, \
308 Dtype **tau_array, int *info, int batchsize
309
310 template <class Dtype>
geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES (Dtype))311 void geqrfBatched(CUDABLAS_GEQRF_BATCHED_ARGTYPES(Dtype)) {
312 static_assert(false&&sizeof(Dtype), "at::cuda::blas::geqrfBatched: not implemented");
313 }
314 template <>
315 TORCH_CUDA_CU_API void geqrfBatched<float>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(float));
316 template <>
317 TORCH_CUDA_CU_API void geqrfBatched<double>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(double));
318 template <>
319 TORCH_CUDA_CU_API void geqrfBatched<c10::complex<double>>(
320 CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<double>));
321 template <>
322 TORCH_CUDA_CU_API void geqrfBatched<c10::complex<float>>(
323 CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<float>));
324
325 #define CUDABLAS_GETRF_ARGTYPES(Dtype) \
326 int n, Dtype** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize
327
328 template<class Dtype>
getrfBatched(CUDABLAS_GETRF_ARGTYPES (Dtype))329 void getrfBatched(CUDABLAS_GETRF_ARGTYPES(Dtype)) {
330 TORCH_CHECK(false, "at::cuda::blas::getrfBatched: not implemented");
331 }
332 template<>
333 TORCH_CUDA_CU_API void getrfBatched<float>(CUDABLAS_GETRF_ARGTYPES(float));
334 template<>
335 TORCH_CUDA_CU_API void getrfBatched<double>(CUDABLAS_GETRF_ARGTYPES(double));
336 template<>
337 TORCH_CUDA_CU_API void getrfBatched<c10::complex<double>>(CUDABLAS_GETRF_ARGTYPES(c10::complex<double>));
338 template<>
339 TORCH_CUDA_CU_API void getrfBatched<c10::complex<float>>(CUDABLAS_GETRF_ARGTYPES(c10::complex<float>));
340
341 #define CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype) \
342 cublasHandle_t handle, cublasOperation_t trans, int m, int n, int nrhs, Dtype** dA_array, int ldda, Dtype** dC_array, int lddc, int* info, int *devInfoArray, int batchSize
343
344 template <class Dtype>
gelsBatched(CUDABLAS_GELS_BATCHED_ARGTYPES (Dtype))345 void gelsBatched(CUDABLAS_GELS_BATCHED_ARGTYPES(Dtype)) {
346 static_assert(false&&sizeof(Dtype),"at::cuda::blas::gelsBatched: not implemented");
347 }
348
349 template<>
350 TORCH_CUDA_CU_API void gelsBatched<double>(CUDABLAS_GELS_BATCHED_ARGTYPES(double));
351 template<>
352 TORCH_CUDA_CU_API void gelsBatched<float>(CUDABLAS_GELS_BATCHED_ARGTYPES(float));
353 template<>
354 TORCH_CUDA_CU_API void gelsBatched<c10::complex<double>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<double>));
355 template<>
356 TORCH_CUDA_CU_API void gelsBatched<c10::complex<float>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<float>));
357
358 } // namespace at::cuda::blas
359