/*
Provides the implementations of CUDA BLAS function templates.
*/

#include <ATen/ATen.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/CUDADataType.h>
#include <ATen/cuda/tunable/Tunable.h>
#include <ATen/cuda/tunable/TunableGemm.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/macros/Export.h>
#include <c10/util/irange.h>

#ifdef USE_ROCM
#include <hipblaslt/hipblaslt-ext.hpp>
// until hipblas has an API to accept flags, we must use rocblas here
#include <hipblas/hipblas.h>
#include <rocblas/rocblas.h>
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
// needed to work around calling rocblas API instead of hipblas API
static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op)
{
  switch(op)
  {
    case HIPBLAS_OP_N:
      return rocblas_operation_none;
    case HIPBLAS_OP_T:
      return rocblas_operation_transpose;
    case HIPBLAS_OP_C:
      return rocblas_operation_conjugate_transpose;
  }
  AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
{
  switch(error)
  {
    case rocblas_status_size_unchanged:
    case rocblas_status_size_increased:
    case rocblas_status_success:
      return HIPBLAS_STATUS_SUCCESS;
    case rocblas_status_invalid_handle:
      return HIPBLAS_STATUS_NOT_INITIALIZED;
    case rocblas_status_not_implemented:
      return HIPBLAS_STATUS_NOT_SUPPORTED;
    case rocblas_status_invalid_pointer:
    case rocblas_status_invalid_size:
    case rocblas_status_invalid_value:
      return HIPBLAS_STATUS_INVALID_VALUE;
    case rocblas_status_memory_error:
      return HIPBLAS_STATUS_ALLOC_FAILED;
    case rocblas_status_internal_error:
      return HIPBLAS_STATUS_INTERNAL_ERROR;
  }
  AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
// hipblas does not have hipblasSetMathMode
#define hipblasSetMathMode(handle, flags) HIPBLAS_STATUS_SUCCESS
// until we use hipblas v2
// hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
// however hipblas v1 is still using its custom type
#ifndef HIPBLAS_V2
#define HIP_R_16F  HIPBLAS_R_16F
#define HIP_R_32F  HIPBLAS_R_32F
#define HIP_R_64F  HIPBLAS_R_64F
#define HIP_C_16F  HIPBLAS_C_16F
#define HIP_C_32F  HIPBLAS_C_32F
#define HIP_C_64F  HIPBLAS_C_64F
#define HIP_R_8I   HIPBLAS_R_8I
#define HIP_R_8U   HIPBLAS_R_8U
#define HIP_R_32I  HIPBLAS_R_32I
#define HIP_R_32U  HIPBLAS_R_32U
#define HIP_C_8I   HIPBLAS_C_8I
#define HIP_C_8U   HIPBLAS_C_8U
#define HIP_C_32I  HIPBLAS_C_32I
#define HIP_C_32U  HIPBLAS_C_32U
#define HIP_R_16BF HIPBLAS_R_16B
#define HIP_C_16BF HIPBLAS_C_16B
#endif
#endif

#define CUDABLAS_POSINT_CHECK(FD, X)         \
  TORCH_CHECK(                               \
      (X > 0 && X <= INT_MAX),               \
      "at::cuda::blas::" #FD " argument " #X \
      " must be positive and at most ",      \
      INT_MAX,                               \
      " but got ",                           \
      X)

#define CUDABLAS_NONNEGINT_CHECK(FD, X)        \
  TORCH_CHECK(                                 \
      (X >= 0 && X <= INT_MAX),                \
      "at::cuda::blas::" #FD " argument " #X   \
      " must be non-negative and at most ",    \
      INT_MAX,                                 \
      " but got ",                             \
      X)

namespace {

static cublasOperation_t _cublasOpFromChar(char op) {
  switch (op) {
    case 'n':
    case 'N':
      return CUBLAS_OP_N;
    case 't':
    case 'T':
      return CUBLAS_OP_T;
    case 'c':
    case 'C':
      return CUBLAS_OP_C;
  }
  AT_ERROR(
      "_cublasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
}

static void _cublasAdjustLdLevel2(int64_t m, int64_t n, int64_t* lda) {
  // Note: leading dimensions are generally checked to be > 0 and at
  // least as big as the result requires (even if the value won't be
  // used).

  // Q: Why does Level3 check trans but this doesn't?
  // A: In level 2, the sizes (m, n) specify the size of A
  // (independent of trans value). In level 3, the sizes (m, n, k)
  // specify the sizes of op(A), op(B), where op depends on the trans
  // values.
  if (n <= 1)
    *lda = std::max<int64_t>(m, 1);
}
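
// Worked example: for a 5 x 1 matrix A (m = 5, n = 1), a caller may pass
// lda = 1 because the leading dimension is never used to step between
// columns. cuBLAS would still reject lda < max(1, m), so we report lda = 5
// instead; with a single column this cannot change the result.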

static void _cublasAdjustLdLevel3(
    char transa,
    char transb,
    int64_t m,
    int64_t n,
    int64_t k,
    int64_t* lda,
    int64_t* ldb,
    int64_t* ldc) {
  bool transa_ = ((transa != 'n') && (transa != 'N'));
  bool transb_ = ((transb != 'n') && (transb != 'N'));

  // Note: leading dimensions are generally checked to be > 0 and at
  // least as big as the result requires (even if the value won't be
  // used).
  if (n <= 1)
    *ldc = std::max<int64_t>(m, 1);

  if (transa_) {
    if (m <= 1)
      *lda = std::max<int64_t>(k, 1);
  } else {
    if (k <= 1)
      *lda = std::max<int64_t>(m, 1);
  }

  if (transb_) {
    if (k <= 1)
      *ldb = std::max<int64_t>(n, 1);
  } else {
    if (n <= 1)
      *ldb = std::max<int64_t>(k, 1);
  }
}
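
// Worked example: gemm('n', 'n', m=4, n=3, k=1) stores A as a 4 x 1 matrix.
// A has a single column, so a caller-supplied lda is never used to step
// between columns and can safely be raised to max(1, m) = 4 to pass cuBLAS
// validation; ldb and ldc stay untouched here since with n > 1 they are
// genuinely exercised.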

#ifndef USE_ROCM
uint32_t _getAlignment(uintptr_t address) {
  // alignment is in bytes
  uint32_t alignment = 256;
  for (; ; alignment /= 2) {
    if (!(address % alignment)) {
      return alignment;
    }
  }
}
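// e.g. _getAlignment(0xF0) == 16, since 240 is divisible by 16 but not 32.
// The loop always terminates because every address is divisible by 1.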
#endif

static size_t _parseChosenWorkspaceSize() {
  const char * val = getenv("CUBLASLT_WORKSPACE_SIZE");
#ifdef USE_ROCM
  if (!val) {
    // accept either env var
    val = getenv("HIPBLASLT_WORKSPACE_SIZE");
  }
#endif
  size_t workspace_size = 1024; /* default size in KiB according to #73328 */
  if (val) {
    try {
      workspace_size = std::stoi(val);
    } catch(std::invalid_argument const& e) {
      TORCH_WARN("invalid CUBLASLT_WORKSPACE_SIZE,",
          " using default workspace size of ", workspace_size, " KiB.");
    } catch(std::out_of_range const& e) {
      TORCH_WARN("CUBLASLT_WORKSPACE_SIZE out of range,",
          " using default workspace size of ", workspace_size, " KiB.");
    }
  }
  return workspace_size * 1024;
}
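
// Usage example: CUBLASLT_WORKSPACE_SIZE=4096 requests a 4096 KiB (4 MiB)
// cuBLASLt workspace; unparsable or out-of-range values fall back to the
// 1024 KiB default above.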

static size_t _getWorkspaceSize() {
  static size_t workspace_size = _parseChosenWorkspaceSize();
  return workspace_size;
}

} // anonymous namespace

namespace at::cuda::blas {

/* LEVEL 3 BLAS FUNCTIONS */

#define GEMM_CHECK_ARGVALUES(Dtype)           \
  do {                                        \
    CUDABLAS_NONNEGINT_CHECK(gemm<Dtype>, m); \
    CUDABLAS_NONNEGINT_CHECK(gemm<Dtype>, n); \
    CUDABLAS_NONNEGINT_CHECK(gemm<Dtype>, k); \
    CUDABLAS_POSINT_CHECK(gemm<Dtype>, lda);  \
    CUDABLAS_POSINT_CHECK(gemm<Dtype>, ldb);  \
    CUDABLAS_POSINT_CHECK(gemm<Dtype>, ldc);  \
  } while (0)

#define BGEMM_CHECK_ARGVALUES(Dtype)                     \
  do {                                                   \
    CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, m);           \
    CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, n);           \
    CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, k);           \
    CUDABLAS_POSINT_CHECK(bgemm<Dtype>, lda);            \
    CUDABLAS_POSINT_CHECK(bgemm<Dtype>, ldb);            \
    CUDABLAS_POSINT_CHECK(bgemm<Dtype>, ldc);            \
    CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, num_batches); \
  } while (0)


namespace {
// Following the pattern of CuSparseDescriptor
// Defined here for now because this is the only place cublas_lt interface is
// used but can be moved to a header once cublas_lt interface is used in
// multiple places.
template <typename T, cublasStatus_t (*destructor)(T*)>
struct CuBlasLtDeleter {
  void operator()(T* x) {
    if (x != nullptr) {
      TORCH_CUDABLAS_CHECK(destructor(x));
    }
  }
};

template <typename T, cublasStatus_t (*destructor)(T*)>
class CuBlasLtDescriptor {
 public:
  T* descriptor() const {
    return descriptor_.get();
  }
  T* descriptor() {
    return descriptor_.get();
  }

 protected:
  std::unique_ptr<T, CuBlasLtDeleter<T, destructor>> descriptor_;
};
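
// Together these provide RAII ownership of cublasLt opaque handles: the
// unique_ptr's deleter invokes the matching cublasLt*Destroy function and
// surfaces failures via TORCH_CUDABLAS_CHECK. Subclasses only create the
// raw descriptor and reset() it into descriptor_.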

class CuBlasLtMatmulDescriptor : public CuBlasLtDescriptor<
                                     cublasLtMatmulDescOpaque_t,
                                     &cublasLtMatmulDescDestroy> {
 public:
  CuBlasLtMatmulDescriptor(
      cublasComputeType_t compute_type,
      cudaDataType_t scale_type) {
    cublasLtMatmulDesc_t raw_descriptor = nullptr;
    TORCH_CUDABLAS_CHECK(
        cublasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
    descriptor_.reset(raw_descriptor);
  }
  template <typename T>
  inline void setAttribute(cublasLtMatmulDescAttributes_t attr, const T value) {
    TORCH_CUDABLAS_CHECK(::cublasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(T)));
  }
};

class CuBlasLtMatrixLayout : public CuBlasLtDescriptor<
                                 cublasLtMatrixLayoutOpaque_t,
                                 &cublasLtMatrixLayoutDestroy> {
 public:
  CuBlasLtMatrixLayout(
      cudaDataType_t type,
      uint64_t rows,
      uint64_t cols,
      int64_t ld,
      bool t = false) {
    cublasLtMatrixLayout_t raw_descriptor = nullptr;
    TORCH_CUDABLAS_CHECK(
        cublasLtMatrixLayoutCreate(&raw_descriptor, type, t ? cols : rows, t ? rows : cols, ld));
    descriptor_.reset(raw_descriptor);
  }
  template <typename T>
  inline void setAttribute(cublasLtMatrixLayoutAttribute_t attr, const T value) {
    TORCH_CUDABLAS_CHECK(::cublasLtMatrixLayoutSetAttribute(descriptor(), attr, &value, sizeof(T)));
  }
};
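
// Note the `t` flag: a cublasLt matrix layout describes the matrix as it is
// stored in memory, so for an operand that will be consumed transposed the
// stored extents are the swap of the logical (rows, cols) passed in, which
// is why the constructor exchanges them.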

class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
                                     cublasLtMatmulPreferenceOpaque_t,
                                     &cublasLtMatmulPreferenceDestroy> {
 public:
  CuBlasLtMatmulPreference() {
    cublasLtMatmulPreference_t raw_descriptor = nullptr;
    TORCH_CUDABLAS_CHECK(cublasLtMatmulPreferenceCreate(&raw_descriptor));
    descriptor_.reset(raw_descriptor);
  }
  template <typename T>
  inline void setAttribute(cublasLtMatmulPreferenceAttributes_t attr, const T value) {
    TORCH_CUDABLAS_CHECK(::cublasLtMatmulPreferenceSetAttribute(descriptor(), attr, &value, sizeof(T)));
  }
};
} // namespace


template <typename Dtype>
inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
  cudaDataType_t abcType = CUDA_R_32F;
  cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
  cudaDataType_t scaleType = CUDA_R_32F;
  if constexpr (std::is_same_v<Dtype, double>) {
    abcType = CUDA_R_64F;
    computeType = CUBLAS_COMPUTE_64F;
    scaleType = CUDA_R_64F;
  } else if constexpr (std::is_same_v<Dtype, float>) {
#ifndef USE_ROCM
    if (at::globalContext().allowTF32CuBLAS()) {
      computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
    }
#endif
  } else if constexpr (std::is_same_v<Dtype, c10::complex<double>>) {
    abcType = CUDA_C_64F;
    computeType = CUBLAS_COMPUTE_64F;
    scaleType = CUDA_C_64F;
  } else if constexpr (std::is_same_v<Dtype, c10::complex<float>>) {
    abcType = CUDA_C_32F;
    scaleType = CUDA_C_32F;
  } else if constexpr (std::is_same_v<Dtype, at::Half>) {
    abcType = CUDA_R_16F;
  } else if constexpr (std::is_same_v<Dtype, at::BFloat16>) {
    abcType = CUDA_R_16BF;
  } else {
    static_assert(false && sizeof(Dtype), "at::cuda::blas::bgemm_internal_cublaslt: not implemented");
  }

  globalContext().alertCuBLASConfigNotDeterministic();
  cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);

  CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, opa);
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, opb);
  CuBlasLtMatrixLayout Adesc(abcType, m, k, lda, opa == CUBLAS_OP_T);
  CuBlasLtMatrixLayout Bdesc(abcType, k, n, ldb, opb == CUBLAS_OP_T);
  CuBlasLtMatrixLayout Cdesc(abcType, m, n, ldc);

  if (num_batches > 1) {
    int num_batches_as_int = static_cast<int>(num_batches);
    Adesc.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, num_batches_as_int);
    Bdesc.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, num_batches_as_int);
    Cdesc.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, num_batches_as_int);
    Adesc.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stridea);
    Bdesc.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, strideb);
    Cdesc.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stridec);
  }
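  // With these attributes one layout describes every batch entry: matrix i
  // of A starts at a + i * stridea (likewise for B and C). Leaving them
  // unset keeps cublasLt's default batch count of 1, which is how the
  // unbatched gemm path below reuses this function.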

  CuBlasLtMatmulPreference preference;
  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
  // setting this to 1M.
  size_t workspaceSize = _getWorkspaceSize();
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);

#ifndef USE_ROCM
  uint32_t a_alignment = _getAlignment(reinterpret_cast<uintptr_t>(a));
  uint32_t b_alignment = _getAlignment(reinterpret_cast<uintptr_t>(b));
  uint32_t c_alignment = _getAlignment(reinterpret_cast<uintptr_t>(c));
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, a_alignment);
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, b_alignment);
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
#endif

  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
  auto workspace = allocator.allocate(workspaceSize);
  TORCH_CHECK(workspace.get() != nullptr, "OOM trying to allocate workspace for cublaslt");

  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  int returnedResult = 0;
  TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
      ltHandle,
      computeDesc.descriptor(),
      Adesc.descriptor(),
      Bdesc.descriptor(),
      Cdesc.descriptor(),
      Cdesc.descriptor(),
      preference.descriptor(),
      1,
      &heuristicResult,
      &returnedResult));
  if (returnedResult == 0) {
    TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
  }

  cublasStatus_t cublasStatus = cublasLtMatmul(
      ltHandle,
      computeDesc.descriptor(),
      &alpha,
      a,
      Adesc.descriptor(),
      b,
      Bdesc.descriptor(),
      &beta,
      c,
      Cdesc.descriptor(),
      c,
      Cdesc.descriptor(),
      &heuristicResult.algo,
      workspace.mutable_get(),
      workspaceSize,
      at::cuda::getCurrentCUDAStream());
  TORCH_CHECK(
      cublasStatus == CUBLAS_STATUS_SUCCESS,
      "CUDA error: ",
      at::cuda::blas::_cublasGetErrorEnum(cublasStatus),
      " when calling cublasLtMatmul with transpose_mat1 ",
      (opa == CUBLAS_OP_T),
      " transpose_mat2 ",
      (opb == CUBLAS_OP_T),
      " m ",
      m,
      " n ",
      n,
      " k ",
      k,
      " lda ",
      lda,
      " ldb ",
      ldb,
      " ldc ",
      ldc,
      " abcType ",
      abcType,
      " computeType ",
      computeType,
      " scaleType ",
      scaleType);
}

template <typename Dtype>
inline void bgemm_internal_cublas(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
  static_assert(false && sizeof(Dtype), "at::cuda::blas::bgemm_internal_cublas: not implemented");
}

template <>
void bgemm_internal_cublas<double>(CUDABLAS_BGEMM_ARGTYPES(double)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  BGEMM_CHECK_ARGVALUES(double);
  TORCH_CUDABLAS_CHECK(cublasDgemmStridedBatched(
      handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches));
}

template <>
void bgemm_internal_cublas<float>(CUDABLAS_BGEMM_ARGTYPES(float)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  BGEMM_CHECK_ARGVALUES(float);
  TORCH_CUDABLAS_CHECK(cublasSgemmStridedBatched(
      handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches));
}

template <>
void bgemm_internal_cublas<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  BGEMM_CHECK_ARGVALUES(c10::complex<double>);
  TORCH_CUDABLAS_CHECK(cublasZgemmStridedBatched(
      handle, opa, opb, m, n, k, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
      lda, stridea, reinterpret_cast<const cuDoubleComplex*>(b), ldb, strideb, reinterpret_cast<const cuDoubleComplex*>(&beta),
      reinterpret_cast<cuDoubleComplex*>(c), ldc, stridec, num_batches));
}

template <>
void bgemm_internal_cublas<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  BGEMM_CHECK_ARGVALUES(c10::complex<float>);
  TORCH_CUDABLAS_CHECK(cublasCgemmStridedBatched(
      handle, opa, opb, m, n, k, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
      lda, stridea, reinterpret_cast<const cuComplex*>(b), ldb, strideb, reinterpret_cast<const cuComplex*>(&beta),
      reinterpret_cast<cuComplex*>(c), ldc, stridec, num_batches));
}

template <>
void bgemm_internal_cublas<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  BGEMM_CHECK_ARGVALUES(at::Half);
  float falpha = alpha;
  float fbeta = beta;
#ifdef USE_ROCM
  int flag = 0;
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
  flag = at::ROCmBackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
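  // NOTE: rocblas_gemm_flags_fp16_alt_impl selects rocBLAS's alternate fp16
  // implementation for backward-pass GEMMs; per ROCm documentation it keeps
  // fp32 accumulation on GPUs (e.g. MI200) whose default fp16 path can flush
  // denormal values, trading some speed for gradient accuracy.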
  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_strided_batched_ex((rocblas_handle)handle,
                                   hipOperationToRocOperation(opa),
                                   hipOperationToRocOperation(opb), (int)m, (int)n, (int)k,
                                   (void*)&falpha, a, rocblas_datatype_f16_r, (int)lda, stridea,
                                   b, rocblas_datatype_f16_r, (int)ldb, strideb,
                                   (void*)&fbeta, c, rocblas_datatype_f16_r, (int)ldc, stridec,
                                   c, rocblas_datatype_f16_r, (int)ldc, stridec,
                                   (int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
                                   0, flag)));
#else
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  if (prop->major >= 5){
    TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedEx(
        handle, opa, opb, m, n, k,
        (void*)(&falpha), a, CUDA_R_16F, lda, stridea,
        b, CUDA_R_16F, ldb, strideb, (void*)(&fbeta),
        c, CUDA_R_16F, ldc, stridec,
        num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
  } else {
    for (const auto i : c10::irange(num_batches)) {
      at::cuda::blas::gemm<at::Half>(
          transa, transb,
          m, n, k,
          alpha, (a + i * stridea), lda,
          (b + i * strideb), ldb, beta,
          (c + i * stridec), ldc);
    }
  }
#endif // USE_ROCM
}

template <>
void bgemm_internal_cublas<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  BGEMM_CHECK_ARGVALUES(at::BFloat16);
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  const float falpha = alpha;
  const float fbeta = beta;
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);

#if defined(USE_ROCM)
  auto compute_type = CUBLAS_COMPUTE_32F;
#else
  auto compute_type = CUDA_R_32F;
#endif
  TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedEx(handle,
                                  opa, opb, (int)m, (int)n, (int)k,
                                  (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea,
                                  b, CUDA_R_16BF, (int)ldb, strideb,
                                  (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
                                  (int)num_batches,
                                  compute_type,
                                  CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}

template <>
void bgemm_internal<double>(CUDABLAS_BGEMM_ARGTYPES(double))
{
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
#ifdef USE_ROCM
    // hipblaslt does not support double gemm yet
    bgemm_internal_cublas<double>(CUDABLAS_BGEMM_ARGS(double));
#else
    bgemm_internal_cublaslt<double>(CUDABLAS_BGEMM_ARGS(double));
#endif
  }
  else {
    bgemm_internal_cublas<double>(CUDABLAS_BGEMM_ARGS(double));
  }
}

template <>
void bgemm_internal<float>(CUDABLAS_BGEMM_ARGTYPES(float))
{
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
    bgemm_internal_cublaslt<float>(CUDABLAS_BGEMM_ARGS(float));
  }
  else {
    bgemm_internal_cublas<float>(CUDABLAS_BGEMM_ARGS(float));
  }
}

template <>
void bgemm_internal<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>))
{
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
#ifdef USE_ROCM
    // hipblaslt does not support complex<double> gemm yet
    bgemm_internal_cublas<c10::complex<double>>(CUDABLAS_BGEMM_ARGS(c10::complex<double>));
#else
    bgemm_internal_cublaslt<c10::complex<double>>(CUDABLAS_BGEMM_ARGS(c10::complex<double>));
#endif
  }
  else {
    bgemm_internal_cublas<c10::complex<double>>(CUDABLAS_BGEMM_ARGS(c10::complex<double>));
  }
}

template <>
void bgemm_internal<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>))
{
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
#ifdef USE_ROCM
    // hipblaslt does not support complex<float> gemm yet
    bgemm_internal_cublas<c10::complex<float>>(CUDABLAS_BGEMM_ARGS(c10::complex<float>));
#else
    bgemm_internal_cublaslt<c10::complex<float>>(CUDABLAS_BGEMM_ARGS(c10::complex<float>));
#endif
  }
  else {
    bgemm_internal_cublas<c10::complex<float>>(CUDABLAS_BGEMM_ARGS(c10::complex<float>));
  }
}

template <>
void bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half))
{
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
    bgemm_internal_cublaslt<at::Half>(CUDABLAS_BGEMM_ARGS(at::Half));
  }
  else {
    bgemm_internal_cublas<at::Half>(CUDABLAS_BGEMM_ARGS(at::Half));
  }
}

template <>
void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16))
{
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
    bgemm_internal_cublaslt<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
  }
  else {
    bgemm_internal_cublas<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
  }
}

template <typename DType>
inline void bgemm_tunable(CUDABLAS_BGEMM_ARGTYPES(DType)) {
  tunable::GemmStridedBatchedParams<DType> params;
  params.transa = transa;
  params.transb = transb;
  params.m = m;
  params.n = n;
  params.k = k;
  params.alpha = alpha;
  params.a = a;
  params.lda = lda;
  params.stride_a = stridea;
  params.b = b;
  params.ldb = ldb;
  params.stride_b = strideb;
  params.beta = beta;
  params.c = c;
  params.ldc = ldc;
  params.stride_c = stridec;
  params.batch = num_batches;

  bool transa_ = ((transa != 'n') && (transa != 'N'));
  bool transb_ = ((transb != 'n') && (transb != 'N'));

  if (transa_ && transb_) {
    static tunable::GemmStridedBatchedTunableOp<DType, tunable::BlasOp::T, tunable::BlasOp::T> bgemm{};
    bgemm(&params);
  }
  else if (transa_ && !transb_) {
    static tunable::GemmStridedBatchedTunableOp<DType, tunable::BlasOp::T, tunable::BlasOp::N> bgemm{};
    bgemm(&params);
  }
  else if (!transa_ && transb_) {
    static tunable::GemmStridedBatchedTunableOp<DType, tunable::BlasOp::N, tunable::BlasOp::T> bgemm{};
    bgemm(&params);
  }
  else if (!transa_ && !transb_) {
    static tunable::GemmStridedBatchedTunableOp<DType, tunable::BlasOp::N, tunable::BlasOp::N> bgemm{};
    bgemm(&params);
  }
  else {
    TORCH_CHECK(false, "unreachable");
  }
}
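
// Each (DType, transa, transb) combination above owns its own function-local
// static TunableOp instance, so tuning state for a given layout persists and
// is reused across subsequent calls.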

template <>
void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double)) {
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
  if (tuning_ctx->IsTunableOpEnabled()) {
    bgemm_tunable<double>(CUDABLAS_BGEMM_ARGS(double));
  }
  else {
    bgemm_internal<double>(CUDABLAS_BGEMM_ARGS(double));
  }
}

template <>
void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float)) {
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
  if (tuning_ctx->IsTunableOpEnabled()) {
    bgemm_tunable<float>(CUDABLAS_BGEMM_ARGS(float));
  }
  else {
    bgemm_internal<float>(CUDABLAS_BGEMM_ARGS(float));
  }
}

template <>
void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>)) {
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
  if (tuning_ctx->IsTunableOpEnabled()) {
    bgemm_tunable<c10::complex<double>>(CUDABLAS_BGEMM_ARGS(c10::complex<double>));
  }
  else {
    bgemm_internal<c10::complex<double>>(CUDABLAS_BGEMM_ARGS(c10::complex<double>));
  }
}

template <>
void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>)) {
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
  if (tuning_ctx->IsTunableOpEnabled()) {
    bgemm_tunable<c10::complex<float>>(CUDABLAS_BGEMM_ARGS(c10::complex<float>));
  }
  else {
    bgemm_internal<c10::complex<float>>(CUDABLAS_BGEMM_ARGS(c10::complex<float>));
  }
}

template <>
void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
  if (tuning_ctx->IsTunableOpEnabled()) {
    bgemm_tunable<at::Half>(CUDABLAS_BGEMM_ARGS(at::Half));
  }
  else {
    bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGS(at::Half));
  }
}

template <>
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
  if (tuning_ctx->IsTunableOpEnabled()) {
    bgemm_tunable<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
  }
  else {
    bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
  }
}

template <typename Dtype>
inline void gemm_internal_cublaslt(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
  // forward to bgemm implementation but set strides and batches to 0
  bgemm_internal_cublaslt(transa, transb, m, n, k, alpha, a, lda, 0, b, ldb, 0, beta, c, ldc, 0, 0);
}
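// Passing num_batches = 0 (with zero strides) keeps the batch branch in
// bgemm_internal_cublaslt disabled, so cublasLt performs a plain single
// matmul; this lets gemm and bgemm share one cublasLt code path.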

template <typename Dtype>
inline void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
  static_assert(false && sizeof(Dtype), "at::cuda::blas::gemm_internal_cublas: not implemented");
}

template <>
void gemm_internal_cublas<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(double);
  TORCH_CUDABLAS_CHECK(cublasDgemm(
      handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
}

template <>
void gemm_internal_cublas<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(float);
  TORCH_CUDABLAS_CHECK(cublasSgemm(
      handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
}

template <>
void gemm_internal_cublas<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(c10::complex<double>);
  TORCH_CUDABLAS_CHECK(cublasZgemm(
      handle, opa, opb, m, n, k, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
      lda, reinterpret_cast<const cuDoubleComplex*>(b), ldb, reinterpret_cast<const cuDoubleComplex*>(&beta),
      reinterpret_cast<cuDoubleComplex*>(c), ldc));
}

template <>
void gemm_internal_cublas<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(c10::complex<float>);
  TORCH_CUDABLAS_CHECK(cublasCgemm(
      handle, opa, opb, m, n, k, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
      lda, reinterpret_cast<const cuComplex*>(b), ldb, reinterpret_cast<const cuComplex*>(&beta),
      reinterpret_cast<cuComplex*>(c), ldc));
}

template <>
void gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  float falpha = alpha;
  float fbeta = beta;
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(at::Half);
#ifdef USE_ROCM
  int flag = 0;
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
  flag = at::ROCmBackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(
      (rocblas_handle)handle,
      hipOperationToRocOperation(opa),
      hipOperationToRocOperation(opb),
      m,
      n,
      k,
      &falpha,
      a,
      rocblas_datatype_f16_r,
      lda,
      b,
      rocblas_datatype_f16_r,
      ldb,
      &fbeta,
      c,
      rocblas_datatype_f16_r,
      ldc,
      c,
      rocblas_datatype_f16_r,
      ldc,
      rocblas_datatype_f32_r,
      rocblas_gemm_algo_standard,
      0,
      flag)));
#else
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  if (prop->major >= 5) {
#ifndef USE_ROCM
    cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
    if (!at::globalContext().allowFP16ReductionCuBLAS()) {
      cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
    }
#endif
    // Disallow fp16 reductions that could lead to unexpected overflow issues.
    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
    TORCH_CUDABLAS_CHECK(cublasGemmEx(
        handle,
        opa,
        opb,
        m,
        n,
        k,
        &falpha,
        a,
        CUDA_R_16F,
        lda,
        b,
        CUDA_R_16F,
        ldb,
        &fbeta,
        c,
        CUDA_R_16F,
        ldc,
        CUDA_R_32F,
        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
  } else {
    TORCH_CUDABLAS_CHECK(cublasSgemmEx(
        handle,
        opa,
        opb,
        m,
        n,
        k,
        &falpha,
        a,
        CUDA_R_16F,
        lda,
        b,
        CUDA_R_16F,
        ldb,
        &fbeta,
        c,
        CUDA_R_16F,
        ldc));
  }
#endif
}

template <>
void gemm_internal_cublas<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  float falpha = alpha;
  float fbeta = beta;
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(at::BFloat16);
#ifndef USE_ROCM
  cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
  if (!at::globalContext().allowBF16ReductionCuBLAS()) {
    cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
  }
#endif
#if defined(USE_ROCM)
  auto compute_type = CUBLAS_COMPUTE_32F;
#else
  auto compute_type = CUDA_R_32F;
#endif
  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle,
      opa,
      opb,
      m,
      n,
      k,
      &falpha,
      a,
      CUDA_R_16BF,
      lda,
      b,
      CUDA_R_16BF,
      ldb,
      &fbeta,
      c,
      CUDA_R_16BF,
      ldc,
      compute_type,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
}

template <>
void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
{
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
#ifdef USE_ROCM
    // hipblaslt does not support double gemm yet
    gemm_internal_cublas<double>(CUDABLAS_GEMM_ARGS(double));
#else
    gemm_internal_cublaslt<double>(CUDABLAS_GEMM_ARGS(double));
#endif
  }
  else {
    gemm_internal_cublas<double>(CUDABLAS_GEMM_ARGS(double));
  }
}

template <>
void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
{
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
    gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
  }
  else {
    gemm_internal_cublas<float>(CUDABLAS_GEMM_ARGS(float));
  }
}

template <>
void gemm_internal<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>))
{
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
#ifdef USE_ROCM
    // hipblaslt does not support complex gemm yet
    gemm_internal_cublas<c10::complex<double>>(CUDABLAS_GEMM_ARGS(c10::complex<double>));
#else
    gemm_internal_cublaslt<c10::complex<double>>(CUDABLAS_GEMM_ARGS(c10::complex<double>));
#endif
  }
  else {
    gemm_internal_cublas<c10::complex<double>>(CUDABLAS_GEMM_ARGS(c10::complex<double>));
  }
}

template <>
void gemm_internal<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>))
{
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
#ifdef USE_ROCM
    // hipblaslt does not support complex gemm yet
    gemm_internal_cublas<c10::complex<float>>(CUDABLAS_GEMM_ARGS(c10::complex<float>));
#else
    gemm_internal_cublaslt<c10::complex<float>>(CUDABLAS_GEMM_ARGS(c10::complex<float>));
#endif
  }
  else {
    gemm_internal_cublas<c10::complex<float>>(CUDABLAS_GEMM_ARGS(c10::complex<float>));
  }
}

template <>
void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
{
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
    gemm_internal_cublaslt<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
  }
  else {
    gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
  }
}

template <>
void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16))
{
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
    gemm_internal_cublaslt<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
  }
  else {
    gemm_internal_cublas<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
  }
}

template <typename DType>
inline void gemm_tunable(CUDABLAS_GEMM_ARGTYPES(DType)) {
  tunable::GemmParams<DType> params;
  params.transa = transa;
  params.transb = transb;
  params.m = m;
  params.n = n;
  params.k = k;
  params.alpha = alpha;
  params.a = a;
  params.lda = lda;
  params.b = b;
  params.ldb = ldb;
  params.beta = beta;
  params.c = c;
  params.ldc = ldc;

  bool transa_ = ((transa != 'n') && (transa != 'N'));
  bool transb_ = ((transb != 'n') && (transb != 'N'));

  if (transa_ && transb_) {
    static tunable::GemmTunableOp<DType, tunable::BlasOp::T, tunable::BlasOp::T> gemm{};
    gemm(&params);
  }
  else if (transa_ && !transb_) {
    static tunable::GemmTunableOp<DType, tunable::BlasOp::T, tunable::BlasOp::N> gemm{};
    gemm(&params);
  }
  else if (!transa_ && transb_) {
    static tunable::GemmTunableOp<DType, tunable::BlasOp::N, tunable::BlasOp::T> gemm{};
    gemm(&params);
  }
  else if (!transa_ && !transb_) {
    static tunable::GemmTunableOp<DType, tunable::BlasOp::N, tunable::BlasOp::N> gemm{};
    gemm(&params);
  }
  else {
    TORCH_CHECK(false, "unreachable");
  }
}

template <>
void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
  if (tuning_ctx->IsTunableOpEnabled()) {
    gemm_tunable<double>(CUDABLAS_GEMM_ARGS(double));
  }
  else {
    gemm_internal<double>(CUDABLAS_GEMM_ARGS(double));
  }
}

template <>
void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
  if (tuning_ctx->IsTunableOpEnabled()) {
    gemm_tunable<float>(CUDABLAS_GEMM_ARGS(float));
  }
  else {
    gemm_internal<float>(CUDABLAS_GEMM_ARGS(float));
  }
}

template <>
void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>)) {
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
  if (tuning_ctx->IsTunableOpEnabled()) {
    gemm_tunable<c10::complex<double>>(CUDABLAS_GEMM_ARGS(c10::complex<double>));
  }
  else {
    gemm_internal<c10::complex<double>>(CUDABLAS_GEMM_ARGS(c10::complex<double>));
  }
}

template <>
void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>)) {
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
  if (tuning_ctx->IsTunableOpEnabled()) {
    gemm_tunable<c10::complex<float>>(CUDABLAS_GEMM_ARGS(c10::complex<float>));
  }
  else {
    gemm_internal<c10::complex<float>>(CUDABLAS_GEMM_ARGS(c10::complex<float>));
  }
}

template <>
void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
  if (tuning_ctx->IsTunableOpEnabled()) {
    gemm_tunable<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
  }
  else {
    gemm_internal<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
  }
}

template <>
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
  if (tuning_ctx->IsTunableOpEnabled()) {
    gemm_tunable<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
  }
  else {
    gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
  }
}


template <typename Dtype>
void gemm_and_bias(
    bool transpose_mat1,
    bool transpose_mat2,
    int64_t m,
    int64_t n,
    int64_t k,
    at::opmath_type<Dtype> alpha_val,
    const Dtype* mat1_ptr,
    int64_t mat1_ld,
    const Dtype* mat2_ptr,
    int64_t mat2_ld,
    const Dtype* bias,
    Dtype* result_ptr,
    int64_t result_ld,
    GEMMAndBiasActivationEpilogue activation) {
  using opmath_t = at::opmath_type<Dtype>;
  opmath_t beta_val = 0; // bias is added in epilogue

  cudaDataType_t abcType = CUDA_R_32F;
  cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
  cudaDataType_t scaleType = CUDA_R_32F;
  if constexpr (std::is_same_v<Dtype, double>) {
    abcType = CUDA_R_64F;
    computeType = CUBLAS_COMPUTE_64F;
    scaleType = CUDA_R_64F;
  } else if constexpr (std::is_same_v<Dtype, float>) {
#ifndef USE_ROCM
    if (at::globalContext().allowTF32CuBLAS()) {
      computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
    }
#endif
    abcType = CUDA_R_32F;
  } else if constexpr (std::is_same_v<Dtype, at::Half>) {
    abcType = CUDA_R_16F;
  } else if constexpr (std::is_same_v<Dtype, at::BFloat16>) {
    abcType = CUDA_R_16BF;
  }

  CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
  cublasOperation_t transa = transpose_mat1 ? CUBLAS_OP_T : CUBLAS_OP_N;
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
  cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
  if (activation == GEMMAndBiasActivationEpilogue::RELU) {
    epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
  } else if (activation == GEMMAndBiasActivationEpilogue::GELU) {
#if CUDA_VERSION >= 11040 || defined(USE_ROCM)
    epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
#endif
  }
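  // The chosen epilogue runs fused inside the cublasLt kernel: *_BIAS
  // broadcasts the bias vector over the output columns, and RELU_BIAS /
  // GELU_BIAS apply the activation after the bias add, avoiding a separate
  // elementwise kernel launch.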

  if (bias != nullptr) {
    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, epilogue);
    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_POINTER, bias);
  }

  CuBlasLtMatrixLayout Adesc(abcType, m, k, mat1_ld, transpose_mat1);
  CuBlasLtMatrixLayout Bdesc(abcType, k, n, mat2_ld, transpose_mat2);
  CuBlasLtMatrixLayout Cdesc(abcType, m, n, result_ld);

  CuBlasLtMatmulPreference preference;
  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
  // setting this to 1M.
  size_t workspaceSize = _getWorkspaceSize();
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);

#ifndef USE_ROCM
  uint32_t a_alignment = _getAlignment(reinterpret_cast<uintptr_t>(mat1_ptr));
  uint32_t b_alignment = _getAlignment(reinterpret_cast<uintptr_t>(mat2_ptr));
  uint32_t c_alignment = _getAlignment(reinterpret_cast<uintptr_t>(result_ptr));
  uint32_t d_alignment = _getAlignment(reinterpret_cast<uintptr_t>(bias));
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, a_alignment);
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, b_alignment);
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, d_alignment);
#endif

  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
  auto workspace = allocator.allocate(workspaceSize);
  TORCH_CHECK(workspace.get() != nullptr, "OOM trying to allocate workspace for cublaslt");

  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  int returnedResult = 0;
  cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
  TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
      ltHandle,
      computeDesc.descriptor(),
      Adesc.descriptor(),
      Bdesc.descriptor(),
      Cdesc.descriptor(),
      Cdesc.descriptor(),
      preference.descriptor(),
      1,
      &heuristicResult,
      &returnedResult));
  if (returnedResult == 0) {
    TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
  }

  cublasStatus_t cublasStatus = cublasLtMatmul(
      ltHandle,
      computeDesc.descriptor(),
      &alpha_val,
      mat1_ptr,
      Adesc.descriptor(),
      mat2_ptr,
      Bdesc.descriptor(),
      &beta_val,
      result_ptr,
      Cdesc.descriptor(),
      result_ptr,
      Cdesc.descriptor(),
      &heuristicResult.algo,
      workspace.mutable_get(),
      workspaceSize,
      at::cuda::getCurrentCUDAStream());
  TORCH_CHECK(
      cublasStatus == CUBLAS_STATUS_SUCCESS,
      "CUDA error: ",
      at::cuda::blas::_cublasGetErrorEnum(cublasStatus),
      " when calling cublasLtMatmul with transpose_mat1 ",
      transpose_mat1,
      " transpose_mat2 ",
      transpose_mat2,
      " m ",
      m,
      " n ",
      n,
      " k ",
      k,
      " mat1_ld ",
      mat1_ld,
      " mat2_ld ",
      mat2_ld,
      " result_ld ",
      result_ld,
      " abcType ",
      abcType,
      " computeType ",
      computeType,
      " scaleType ",
      scaleType);
}

template void gemm_and_bias(
    bool transpose_mat1,
    bool transpose_mat2,
    int64_t m,
    int64_t n,
    int64_t k,
    at::opmath_type<double> alpha_val,
    const double* mat1_ptr,
    int64_t mat1_ld,
    const double* mat2_ptr,
    int64_t mat2_ld,
    const double* bias,
    double* result_ptr,
    int64_t result_ld,
    GEMMAndBiasActivationEpilogue activation);

template void gemm_and_bias(
    bool transpose_mat1,
    bool transpose_mat2,
    int64_t m,
    int64_t n,
    int64_t k,
    at::opmath_type<float> alpha_val,
    const float* mat1_ptr,
    int64_t mat1_ld,
    const float* mat2_ptr,
    int64_t mat2_ld,
    const float* bias,
    float* result_ptr,
    int64_t result_ld,
    GEMMAndBiasActivationEpilogue activation);

template void gemm_and_bias(
    bool transpose_mat1,
    bool transpose_mat2,
    int64_t m,
    int64_t n,
    int64_t k,
    at::opmath_type<at::Half> alpha_val,
    const at::Half* mat1_ptr,
    int64_t mat1_ld,
    const at::Half* mat2_ptr,
    int64_t mat2_ld,
    const at::Half* bias,
    at::Half* result_ptr,
    int64_t result_ld,
    GEMMAndBiasActivationEpilogue activation);

template void gemm_and_bias(
    bool transpose_mat1,
    bool transpose_mat2,
    int64_t m,
    int64_t n,
    int64_t k,
    at::opmath_type<at::BFloat16> alpha_val,
    const at::BFloat16* mat1_ptr,
    int64_t mat1_ld,
    const at::BFloat16* mat2_ptr,
    int64_t mat2_ld,
    const at::BFloat16* bias,
    at::BFloat16* result_ptr,
    int64_t result_ld,
    GEMMAndBiasActivationEpilogue activation);

void scaled_gemm(
    char transa,
    char transb,
    int64_t m,
    int64_t n,
    int64_t k,
    const void* mat1_ptr,
    const void* mat1_scale_ptr,
    int64_t mat1_ld,
    ScalarType mat1_dtype,
    const void* mat2_ptr,
    const void* mat2_scale_ptr,
    int64_t mat2_ld,
    ScalarType mat2_dtype,
    const void* bias_ptr,
    ScalarType bias_dtype,
    void* result_ptr,
    const void* result_scale_ptr,
    int64_t result_ld,
    ScalarType result_dtype,
    void* amax_ptr,
    bool use_fast_accum) {
#if CUDA_VERSION >= 11080 || defined(USE_ROCM)
  const auto computeType = CUBLAS_COMPUTE_32F;
  const auto scaleType = CUDA_R_32F;
  const int8_t fastAccuMode = use_fast_accum ? 1 : 0;
  const float alpha_val = 1.0;
  const float beta_val = 0.0;
  CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, _cublasOpFromChar(transa));
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60200)
  // Amax support in ROCm as of 6.2
  if (isFloat8Type(result_dtype)) {
    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, amax_ptr);
  }
#endif
#ifndef USE_ROCM
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, fastAccuMode);
#endif
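  // Scale-pointer semantics (per the cublasLt fp8 matmul documentation):
  // the A/B scale pointers are device pointers to single fp32 factors used
  // to dequantize the fp8 operands, the D scale rescales the output, and
  // amax_ptr, when set above, receives the absolute maximum of the raw
  // output so callers can derive the next quantization scale.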
  CuBlasLtMatrixLayout Adesc(ScalarTypeToCudaDataType(mat1_dtype), m, k, mat1_ld, transa == 't');
  CuBlasLtMatrixLayout Bdesc(ScalarTypeToCudaDataType(mat2_dtype), k, n, mat2_ld, transb == 't');
#ifdef USE_ROCM
  // Cdesc is unused, beta is 0. But hipblaslt needs this set to something reasonable.
  CuBlasLtMatrixLayout Cdesc(ScalarTypeToCudaDataType(result_dtype), m, n, result_ld);
#else
  CuBlasLtMatrixLayout Cdesc(ScalarTypeToCudaDataType(bias_dtype), m, n, result_ld);
#endif
  CuBlasLtMatrixLayout Ddesc(ScalarTypeToCudaDataType(result_dtype), m, n, result_ld);
  if (bias_ptr) {
    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr);
    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_BIAS);
    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
  }
  size_t workspaceSize = _getWorkspaceSize();
  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
  auto workspace = allocator.allocate(workspaceSize);
  TORCH_CHECK(workspace.get() != nullptr, "OOM trying to allocate workspace for cublaslt");

  CuBlasLtMatmulPreference preference;
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  int returnedResult = 0;
  cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
  TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
      ltHandle,
      computeDesc.descriptor(),
      Adesc.descriptor(),
      Bdesc.descriptor(),
      Cdesc.descriptor(),
      Ddesc.descriptor(),
      preference.descriptor(),
      1,
      &heuristicResult,
      &returnedResult));
  if (returnedResult == 0) {
#ifndef USE_ROCM
    TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
#else
    // hipblaslt might be able to recover by returning all algos
    std::vector<hipblasLtMatmulHeuristicResult_t> all_algos;
    TORCH_CUDABLAS_CHECK(hipblaslt_ext::getAllAlgos(
        ltHandle,
        hipblaslt_ext::GemmType::HIPBLASLT_GEMM,
        _cublasOpFromChar(transa),
        _cublasOpFromChar(transb),
        ScalarTypeToCudaDataType(mat1_dtype),
        ScalarTypeToCudaDataType(mat2_dtype),
        // C is nullptr and beta=0, so set to something reasonable. See above.
        //ScalarTypeToCudaDataType(bias_dtype),
        ScalarTypeToCudaDataType(result_dtype),
        ScalarTypeToCudaDataType(result_dtype),
        CUBLAS_COMPUTE_32F,
        all_algos));
    if (all_algos.size() == 0) {
      TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
    }
    // pick first valid solution
    bool found = false;
    for (size_t i = 0; i < all_algos.size(); i++) {
      size_t ret_workspace_size = 0;
      auto is_valid_status = hipblaslt_ext::matmulIsAlgoSupported(
          ltHandle,
          computeDesc.descriptor(),
          &alpha_val,
          Adesc.descriptor(),
          Bdesc.descriptor(),
          &beta_val,
          Cdesc.descriptor(),
          Ddesc.descriptor(),
          all_algos[i].algo,
          ret_workspace_size);
      if (is_valid_status == HIPBLAS_STATUS_SUCCESS) {
        if (ret_workspace_size <= workspaceSize) {
          heuristicResult = all_algos[i];
          found = true;
          break;
        }
      }
    }
    TORCH_CHECK(found, "could not find valid hipblaslt solution");
#endif
  }
  cublasStatus_t cublasStatus = cublasLtMatmul(
      ltHandle,
      computeDesc.descriptor(),
      &alpha_val,
      mat1_ptr,
      Adesc.descriptor(),
      mat2_ptr,
      Bdesc.descriptor(),
      &beta_val,
#ifdef USE_ROCM
      result_ptr, // unused, since beta_val is 0, but hipblaslt can't handle nullptr
#else
      nullptr,
#endif
      Cdesc.descriptor(),
      result_ptr,
      Ddesc.descriptor(),
      &heuristicResult.algo,
      workspace.mutable_get(),
      workspaceSize,
      at::cuda::getCurrentCUDAStream());
  TORCH_CHECK(
      cublasStatus == CUBLAS_STATUS_SUCCESS,
      "CUDA error: ",
      at::cuda::blas::_cublasGetErrorEnum(cublasStatus),
      " when calling cublasLtMatmul with transpose_mat1 ",
      transa,
      " transpose_mat2 ",
      transb,
      " m ",
      m,
      " n ",
      n,
      " k ",
      k,
      " mat1_ld ",
      mat1_ld,
      " mat2_ld ",
      mat2_ld,
      " result_ld ",
      result_ld,
      " computeType ",
      computeType,
      " scaleType ",
      scaleType);
  return;
#endif // CUDA_VERSION >= 11080 || defined(USE_ROCM)
  TORCH_CHECK(false, "scaled_gemm is only supported for CUDA 11.8 and above");
}
1566
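// int8 x int8 -> int32 matrix multiply via cublasLt, accumulating and scaling
// in 32-bit integer arithmetic (CUBLAS_COMPUTE_32I / CUDA_R_32I).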
void int8_gemm(
    bool transpose_mat1,
    bool transpose_mat2,
    int64_t m,
    int64_t n,
    int64_t k,
    const int8_t* mat1_ptr,
    int64_t mat1_ld,
    const int8_t* mat2_ptr,
    int64_t mat2_ld,
    int32_t* result_ptr,
    int64_t result_ld) {

  cublasComputeType_t computeType = CUBLAS_COMPUTE_32I;
  cudaDataType_t scaleType = CUDA_R_32I;

  cudaDataType_t abType = CUDA_R_8I;
  cudaDataType_t cType = CUDA_R_32I;

  CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
  cublasOperation_t transa = transpose_mat1 ? CUBLAS_OP_T : CUBLAS_OP_N;
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
  cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);


  CuBlasLtMatrixLayout Adesc(abType, m, k, mat1_ld, transpose_mat1);
  CuBlasLtMatrixLayout Bdesc(abType, k, n, mat2_ld, transpose_mat2);
  CuBlasLtMatrixLayout Cdesc(cType, m, n, result_ld);

  // Per the cuBLAS team: alpha and beta need to have the same dtype as scaleType
  at::opmath_type<int32_t> alpha_val = 1;
  int32_t beta_val = 0;
  cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();

#ifdef USE_ROCM
  CuBlasLtMatmulPreference preference;
  size_t workspaceSize = _getWorkspaceSize();
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
  auto workspace = allocator.allocate(workspaceSize);
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  int returnedResult = 0;
  TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
      ltHandle,
      computeDesc.descriptor(),
      Adesc.descriptor(),
      Bdesc.descriptor(),
      Cdesc.descriptor(),
      Cdesc.descriptor(),
      preference.descriptor(),
      1,
      &heuristicResult,
      &returnedResult));
  if (returnedResult == 0) {
    TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
  }
#endif

  cublasStatus_t cublasStatus = cublasLtMatmul(
      ltHandle,
      computeDesc.descriptor(),
      &alpha_val,
      mat1_ptr,
      Adesc.descriptor(),
      mat2_ptr,
      Bdesc.descriptor(),
      &beta_val,
      result_ptr,
      Cdesc.descriptor(),
      result_ptr,
      Cdesc.descriptor(),
#ifdef USE_ROCM
      &heuristicResult.algo,
#else
      nullptr, // Heuristics don't seem to work for int8
#endif
#ifdef USE_ROCM
      workspace.mutable_get(),
#else
      nullptr, // Non-zero workspace doesn't seem to work.
#endif
#ifdef USE_ROCM
      workspaceSize,
#else
      0,
#endif
      at::cuda::getCurrentCUDAStream());
  TORCH_CHECK(
      cublasStatus == CUBLAS_STATUS_SUCCESS,
      "CUDA error: ",
      at::cuda::blas::_cublasGetErrorEnum(cublasStatus),
      " when calling cublasLtMatmul with transpose_mat1 ",
      transpose_mat1,
      " transpose_mat2 ",
      transpose_mat2,
      " m ",
      m,
      " n ",
      n,
      " k ",
      k,
      " mat1_ld ",
      mat1_ld,
      " mat2_ld ",
      mat2_ld,
      " result_ld ",
      result_ld,
      " abType ",
      abType,
      " cType ",
      cType,
      " computeType ",
      computeType,
      " scaleType ",
      scaleType);
}

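// trsm solves the triangular system op(A) * X = alpha * B (or
// X * op(A) = alpha * B, depending on `side`), overwriting B with X; each
// specialization below dispatches to the matching cublas<t>trsm routine.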
template <>
void trsm<float>(CUDABLAS_TRSM_ARGTYPES(float)) {
  TORCH_CUDABLAS_CHECK(cublasStrsm(
      handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb));
}

template <>
void trsm<double>(CUDABLAS_TRSM_ARGTYPES(double)) {
  TORCH_CUDABLAS_CHECK(cublasDtrsm(
      handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb));
}

template <>
void trsm<c10::complex<float>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<float>)) {
  TORCH_CUDABLAS_CHECK(cublasCtrsm(
      handle,
      side,
      uplo,
      trans,
      diag,
      m,
      n,
      reinterpret_cast<const cuComplex*>(alpha),
      reinterpret_cast<const cuComplex*>(A),
      lda,
      reinterpret_cast<cuComplex*>(B),
      ldb));
}

template <>
void trsm<c10::complex<double>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<double>)) {
  TORCH_CUDABLAS_CHECK(cublasZtrsm(
      handle,
      side,
      uplo,
      trans,
      diag,
      m,
      n,
      reinterpret_cast<const cuDoubleComplex*>(alpha),
      reinterpret_cast<const cuDoubleComplex*>(A),
      lda,
      reinterpret_cast<cuDoubleComplex*>(B),
      ldb));
}

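// trsmBatched solves the same triangular systems for a batch of independent
// matrices; A and B are device arrays of pointers, one per batch entry.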
template <>
void trsmBatched<float>(CUDABLAS_TRSM_BATCHED_ARGTYPES(float)) {
  TORCH_CUDABLAS_CHECK(cublasStrsmBatched(
      handle,
      side,
      uplo,
      trans,
      diag,
      m,
      n,
      alpha,
      A,
      lda,
      B,
      ldb,
      batchCount));
}

template <>
void trsmBatched<double>(CUDABLAS_TRSM_BATCHED_ARGTYPES(double)) {
  TORCH_CUDABLAS_CHECK(cublasDtrsmBatched(
      handle,
      side,
      uplo,
      trans,
      diag,
      m,
      n,
      alpha,
      A,
      lda,
      B,
      ldb,
      batchCount));
}

template <>
void trsmBatched<c10::complex<float>>(
    CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<float>)) {
  TORCH_CUDABLAS_CHECK(cublasCtrsmBatched(
      handle,
      side,
      uplo,
      trans,
      diag,
      m,
      n,
      reinterpret_cast<const cuComplex*>(alpha),
      reinterpret_cast<cuComplex**>(A),
      lda,
      reinterpret_cast<cuComplex**>(B),
      ldb,
      batchCount));
}

template <>
void trsmBatched<c10::complex<double>>(
    CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<double>)) {
  TORCH_CUDABLAS_CHECK(cublasZtrsmBatched(
      handle,
      side,
      uplo,
      trans,
      diag,
      m,
      n,
      reinterpret_cast<const cuDoubleComplex*>(alpha),
      reinterpret_cast<cuDoubleComplex**>(A),
      lda,
      reinterpret_cast<cuDoubleComplex**>(B),
      ldb,
      batchCount));
}

/* LEVEL 2 BLAS FUNCTIONS */

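// cuBLAS level-2 routines take 32-bit int sizes and strides, so validate that
// the 64-bit arguments we receive are in range before they get narrowed.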
#define GEMV_CHECK_ARGVALUES(Dtype)           \
  do {                                        \
    CUDABLAS_NONNEGINT_CHECK(gemv<Dtype>, m); \
    CUDABLAS_NONNEGINT_CHECK(gemv<Dtype>, n); \
    CUDABLAS_POSINT_CHECK(gemv<Dtype>, lda);  \
    CUDABLAS_POSINT_CHECK(gemv<Dtype>, incx); \
    CUDABLAS_POSINT_CHECK(gemv<Dtype>, incy); \
  } while (0)

template <>
void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t op = _cublasOpFromChar(trans);
  _cublasAdjustLdLevel2(m, n, &lda);
  GEMV_CHECK_ARGVALUES(c10::complex<double>);
  TORCH_CUDABLAS_CHECK(
      cublasZgemv(handle, op, m, n, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
                  lda, reinterpret_cast<const cuDoubleComplex*>(x), incx, reinterpret_cast<const cuDoubleComplex*>(&beta),
                  reinterpret_cast<cuDoubleComplex*>(y), incy));
}

template <>
void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>)) {
  // gemv is bandwidth bound and does not benefit from TF32, but it would
  // still suffer TF32's precision loss, so we disable TF32 here.
  NoTF32Guard disable_tf32;
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t op = _cublasOpFromChar(trans);
  _cublasAdjustLdLevel2(m, n, &lda);
  GEMV_CHECK_ARGVALUES(c10::complex<float>);
  TORCH_CUDABLAS_CHECK(
      cublasCgemv(handle, op, m, n, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
                  lda, reinterpret_cast<const cuComplex*>(x), incx, reinterpret_cast<const cuComplex*>(&beta),
                  reinterpret_cast<cuComplex*>(y), incy));
}

template <>
void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t op = _cublasOpFromChar(trans);
  _cublasAdjustLdLevel2(m, n, &lda);
  GEMV_CHECK_ARGVALUES(double);
  TORCH_CUDABLAS_CHECK(
      cublasDgemv(handle, op, m, n, &alpha, a, lda, x, incx, &beta, y, incy));
}

template <>
void gemv<float>(CUDABLAS_GEMV_ARGTYPES(float)) {
  // gemv is bandwidth bound and does not benefit from TF32, but it would
  // still suffer TF32's precision loss, so we disable TF32 here.
  NoTF32Guard disable_tf32;
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t op = _cublasOpFromChar(trans);
  _cublasAdjustLdLevel2(m, n, &lda);
  GEMV_CHECK_ARGVALUES(float);
  TORCH_CUDABLAS_CHECK(
      cublasSgemv(handle, op, m, n, &alpha, a, lda, x, incx, &beta, y, incy));
}

template <>
void gemv<at::Half>(CUDABLAS_GEMV_ARGTYPES(at::Half)) {
  // In general, cublas regards matrices as column-major.
  // The cublasS/Dgemv usages in cuda::blas::gemv<float>/<double> above
  // require that external blas::gemv callers obey the following convention:
  //
  // If "a" is row-major with shape (output, summed) in blas::gemv's caller,
  // the caller interprets it as column-major with shape (summed, output),
  // passes summed and output respectively to our local vars m, n, and
  // requests that cublas internally transpose ("trans") the column-major
  // interpretation of a.
  //
  // There's no such thing as "cublasHalfgemv", so here we hack gemv with a
  // gemm. However, we must honor the same calling convention, because the
  // caller shouldn't have to swap args based on whether it's calling
  // blas::gemv<at::Half> or <float>.

  bool trans_bool = (_cublasOpFromChar(trans) != CUBLAS_OP_N);
  if (trans_bool) {
    std::swap(m, n);
  }
  // After the swap, the local vars m, n contain the output and summed sizes
  // respectively, regardless of whether "a" was row-major or column-major in
  // gemv<>'s caller.

  // To handle the possibility that incy > 1, we interpret vector y as a
  // column-major matrix with one row (shape (1, output)) and leading dim incy.
  // trans(a)*x would compute a matrix with one column (shape (output, 1)),
  // which wouldn't match y. So instead, we interpret x similarly to y, as a
  // column-major matrix with one row (shape (1, summed)) and leading dim incx.
  // The gemm then carries out x*transpose(trans(a)) to produce a matrix with
  // one row (shape (1, output)), matching y.
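  //
  // Illustrative example (hypothetical sizes, not from the original comment):
  // with trans == 'n' and a column-major "a" of shape (4, 3), gemv computes
  // y(4) = a(4x3) * x(3). Below, trans_flipped is 't' and the gemm computes
  // y_row(1x4) = x_row(1x3) * a^T(3x4), the same result laid out along
  // leading dim incy.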
  char trans_flipped = (trans_bool ? 'n' : 't');
  gemm<at::Half>(
      'n', trans_flipped, 1, m, n, alpha, x, incx, a, lda, beta, y, incy);
}

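// Same gemm-based fallback as gemv<at::Half> above; see the comments there.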
template <>
void gemv<at::BFloat16>(CUDABLAS_GEMV_ARGTYPES(at::BFloat16)) {
  bool trans_bool = (_cublasOpFromChar(trans) != CUBLAS_OP_N);
  if (trans_bool) {
    std::swap(m, n);
  }
  char trans_flipped = (trans_bool ? 'n' : 't');
  gemm<at::BFloat16>(
      'n', trans_flipped, 1, m, n, alpha, x, incx, a, lda, beta, y, incy);
}

/* LEVEL 1 BLAS FUNCTIONS */

template <>
void dot<double>(CUDABLAS_DOT_ARGTYPES(double)) {
  TORCH_CUDABLAS_CHECK(cublasDdot(handle, n, x, incx, y, incy, result));
}

template <>
void dot<float>(CUDABLAS_DOT_ARGTYPES(float)) {
  TORCH_CUDABLAS_CHECK(cublasSdot(handle, n, x, incx, y, incy, result));
}

template <>
void dot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>)) {
  TORCH_CUDABLAS_CHECK(cublasZdotu(handle, n, reinterpret_cast<const cuDoubleComplex*>(x),
                                   incx, reinterpret_cast<const cuDoubleComplex*>(y), incy,
                                   reinterpret_cast<cuDoubleComplex*>(result)));
}

template <>
void dot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>)) {
  TORCH_CUDABLAS_CHECK(cublasCdotu(handle, n, reinterpret_cast<const cuComplex*>(x),
                                   incx, reinterpret_cast<const cuComplex*>(y), incy,
                                   reinterpret_cast<cuComplex*>(result)));
}

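// There is no dedicated half/bfloat16 dot routine, so use cublasDotEx: the
// inputs and the result stay 16-bit while accumulation happens in fp32
// (the trailing CUDA_R_32F execution type).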
template <>
void dot<at::Half>(CUDABLAS_DOT_ARGTYPES(at::Half)) {
  TORCH_CUDABLAS_CHECK(cublasDotEx(
      handle,
      n,
      x,
      CUDA_R_16F,
      incx,
      y,
      CUDA_R_16F,
      incy,
      result,
      CUDA_R_16F,
      CUDA_R_32F));
}

template <>
void dot<at::BFloat16>(CUDABLAS_DOT_ARGTYPES(at::BFloat16)) {
  TORCH_CUDABLAS_CHECK(cublasDotEx(
      handle,
      n,
      x,
      CUDA_R_16BF,
      incx,
      y,
      CUDA_R_16BF,
      incy,
      result,
      CUDA_R_16BF,
      CUDA_R_32F));
}

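// vdot is the conjugated dot product: it uses cublasCdotc/cublasZdotc, which
// conjugate the first vector, whereas dot above uses the unconjugated
// cublasCdotu/cublasZdotu.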
template <>
void vdot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>)) {
  TORCH_CUDABLAS_CHECK(cublasCdotc(handle, n, reinterpret_cast<const cuComplex*>(x),
                                   incx, reinterpret_cast<const cuComplex*>(y), incy,
                                   reinterpret_cast<cuComplex*>(result)));
}

template <>
void vdot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>)) {
  TORCH_CUDABLAS_CHECK(cublasZdotc(handle, n, reinterpret_cast<const cuDoubleComplex*>(x),
                                   incx, reinterpret_cast<const cuDoubleComplex*>(y), incy,
                                   reinterpret_cast<cuDoubleComplex*>(result)));
}

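// getrsBatched solves batches of linear systems op(A) X = B, using the LU
// factors and pivots produced by getrfBatched below.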
template <>
void getrsBatched<float>(CUDABLAS_GETRS_ARGTYPES(float)) {
  TORCH_CUDABLAS_CHECK(cublasSgetrsBatched(
      handle,
      trans,
      n,
      nrhs,
      dA_array,
      lda,
      ipiv_array,
      dB_array,
      ldb,
      info_array,
      batchsize));
}

template <>
void getrsBatched<double>(CUDABLAS_GETRS_ARGTYPES(double)) {
  TORCH_CUDABLAS_CHECK(cublasDgetrsBatched(
      handle,
      trans,
      n,
      nrhs,
      dA_array,
      lda,
      ipiv_array,
      dB_array,
      ldb,
      info_array,
      batchsize));
}

template <>
void getrsBatched<c10::complex<float>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<float>)) {
  TORCH_CUDABLAS_CHECK(cublasCgetrsBatched(
      handle,
      trans,
      n,
      nrhs,
      reinterpret_cast<cuComplex**>(dA_array),
      lda,
      ipiv_array,
      reinterpret_cast<cuComplex**>(dB_array),
      ldb,
      info_array,
      batchsize));
}

template <>
void getrsBatched<c10::complex<double>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<double>)) {
  TORCH_CUDABLAS_CHECK(cublasZgetrsBatched(
      handle,
      trans,
      n,
      nrhs,
      reinterpret_cast<cuDoubleComplex**>(dA_array),
      lda,
      ipiv_array,
      reinterpret_cast<cuDoubleComplex**>(dB_array),
      ldb,
      info_array,
      batchsize));
}

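// geqrfBatched computes QR factorizations for a batch of matrices; tau_array
// receives the Householder scalars of each factorization.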
template <>
void geqrfBatched<float>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(float)) {
  TORCH_CUDABLAS_CHECK(cublasSgeqrfBatched(
      handle, m, n, A_array, lda, tau_array, info, batchsize));
}

template <>
void geqrfBatched<double>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(double)) {
  TORCH_CUDABLAS_CHECK(cublasDgeqrfBatched(
      handle, m, n, A_array, lda, tau_array, info, batchsize));
}

template <>
void geqrfBatched<c10::complex<float>>(
    CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<float>)) {
  TORCH_CUDABLAS_CHECK(cublasCgeqrfBatched(
      handle,
      m,
      n,
      reinterpret_cast<cuComplex**>(A_array),
      lda,
      reinterpret_cast<cuComplex**>(tau_array),
      info,
      batchsize));
}

template <>
void geqrfBatched<c10::complex<double>>(
    CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<double>)) {
  TORCH_CUDABLAS_CHECK(cublasZgeqrfBatched(
      handle,
      m,
      n,
      reinterpret_cast<cuDoubleComplex**>(A_array),
      lda,
      reinterpret_cast<cuDoubleComplex**>(tau_array),
      info,
      batchsize));
}

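// getrfBatched computes in-place LU factorizations (with partial pivoting)
// for a batch of square n x n matrices.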
template <>
void getrfBatched<double>(
    int n, double** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize) {
  auto handle = at::cuda::getCurrentCUDABlasHandle();
  TORCH_CUDABLAS_CHECK(cublasDgetrfBatched(
      handle, n, dA_array, ldda, ipiv_array, info_array, batchsize));
}

template <>
void getrfBatched<float>(
    int n, float** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize) {
  auto handle = at::cuda::getCurrentCUDABlasHandle();
  TORCH_CUDABLAS_CHECK(cublasSgetrfBatched(
      handle, n, dA_array, ldda, ipiv_array, info_array, batchsize));
}

template <>
void getrfBatched<c10::complex<double>>(
    int n,
    c10::complex<double>** dA_array,
    int ldda,
    int* ipiv_array,
    int* info_array,
    int batchsize) {
  auto handle = at::cuda::getCurrentCUDABlasHandle();
  TORCH_CUDABLAS_CHECK(cublasZgetrfBatched(
      handle,
      n,
      reinterpret_cast<cuDoubleComplex**>(dA_array),
      ldda,
      ipiv_array,
      info_array,
      batchsize));
}

template <>
void getrfBatched<c10::complex<float>>(
    int n,
    c10::complex<float>** dA_array,
    int ldda,
    int* ipiv_array,
    int* info_array,
    int batchsize) {
  auto handle = at::cuda::getCurrentCUDABlasHandle();
  TORCH_CUDABLAS_CHECK(cublasCgetrfBatched(
      handle,
      n,
      reinterpret_cast<cuComplex**>(dA_array),
      ldda,
      ipiv_array,
      info_array,
      batchsize));
}

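// gelsBatched solves batches of overdetermined least-squares problems,
// minimizing ||A X - C|| via QR factorization; devInfoArray reports
// per-matrix status.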
template <>
void gelsBatched<double>(CUDABLAS_GELS_BATCHED_ARGTYPES(double)) {
  TORCH_CUDABLAS_CHECK(cublasDgelsBatched(
      handle, trans, m, n, nrhs, dA_array, ldda, dC_array, lddc, info, devInfoArray, batchSize));
}

template <>
void gelsBatched<float>(CUDABLAS_GELS_BATCHED_ARGTYPES(float)) {
  TORCH_CUDABLAS_CHECK(cublasSgelsBatched(
      handle, trans, m, n, nrhs, dA_array, ldda, dC_array, lddc, info, devInfoArray, batchSize));
}

template <>
void gelsBatched<c10::complex<double>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<double>)) {
  TORCH_CUDABLAS_CHECK(cublasZgelsBatched(
      handle, trans,
      m, n, nrhs,
      reinterpret_cast<cuDoubleComplex**>(dA_array),
      ldda,
      reinterpret_cast<cuDoubleComplex**>(dC_array),
      lddc,
      info,
      devInfoArray,
      batchSize));
}

template <>
void gelsBatched<c10::complex<float>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<float>)) {
  TORCH_CUDABLAS_CHECK(cublasCgelsBatched(
      handle, trans,
      m, n, nrhs,
      reinterpret_cast<cuComplex**>(dA_array),
      ldda,
      reinterpret_cast<cuComplex**>(dC_array),
      lddc,
      info,
      devInfoArray,
      batchSize));
}

} // namespace at::cuda::blas