xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/CUDABlas.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
/*
  Provides the implementations of CUDA BLAS function templates.
 */

#include <ATen/ATen.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/CUDADataType.h>
#include <ATen/cuda/tunable/Tunable.h>
#include <ATen/cuda/tunable/TunableGemm.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/macros/Export.h>
#include <c10/util/irange.h>

#ifdef USE_ROCM
#include <hipblaslt/hipblaslt-ext.hpp>
// until hipblas has an API to accept flags, we must use rocblas here
#include <hipblas/hipblas.h>
#include <rocblas/rocblas.h>
#define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR)
#define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242)
// needed to work around calling rocblas API instead of hipblas API
static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op)
{
    switch(op)
    {
    case HIPBLAS_OP_N:
        return rocblas_operation_none;
    case HIPBLAS_OP_T:
        return rocblas_operation_transpose;
    case HIPBLAS_OP_C:
        return rocblas_operation_conjugate_transpose;
    }
    AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
{
    switch(error)
    {
    case rocblas_status_size_unchanged:
    case rocblas_status_size_increased:
    case rocblas_status_success:
        return HIPBLAS_STATUS_SUCCESS;
    case rocblas_status_invalid_handle:
        return HIPBLAS_STATUS_NOT_INITIALIZED;
    case rocblas_status_not_implemented:
        return HIPBLAS_STATUS_NOT_SUPPORTED;
    case rocblas_status_invalid_pointer:
    case rocblas_status_invalid_size:
    case rocblas_status_invalid_value:
        return HIPBLAS_STATUS_INVALID_VALUE;
    case rocblas_status_memory_error:
        return HIPBLAS_STATUS_ALLOC_FAILED;
    case rocblas_status_internal_error:
        return HIPBLAS_STATUS_INTERNAL_ERROR;
    }
    AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
// hipblas does not have hipblasSetMathMode
#define hipblasSetMathMode(handle, flags) HIPBLAS_STATUS_SUCCESS
// until we use hipblas v2
// hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
// however hipblas v1 is still using its own custom types
#ifndef HIPBLAS_V2
#define HIP_R_16F  HIPBLAS_R_16F
#define HIP_R_32F  HIPBLAS_R_32F
#define HIP_R_64F  HIPBLAS_R_64F
#define HIP_C_16F  HIPBLAS_C_16F
#define HIP_C_32F  HIPBLAS_C_32F
#define HIP_C_64F  HIPBLAS_C_64F
#define HIP_R_8I   HIPBLAS_R_8I
#define HIP_R_8U   HIPBLAS_R_8U
#define HIP_R_32I  HIPBLAS_R_32I
#define HIP_R_32U  HIPBLAS_R_32U
#define HIP_C_8I   HIPBLAS_C_8I
#define HIP_C_8U   HIPBLAS_C_8U
#define HIP_C_32I  HIPBLAS_C_32I
#define HIP_C_32U  HIPBLAS_C_32U
#define HIP_R_16BF HIPBLAS_R_16B
#define HIP_C_16BF HIPBLAS_C_16B
#endif
#endif

#define CUDABLAS_POSINT_CHECK(FD, X)         \
  TORCH_CHECK(                               \
      (X > 0 && X <= INT_MAX),               \
      "at::cuda::blas::" #FD " argument " #X \
      " must be positive and at most ",      \
      INT_MAX,                               \
      " but got ",                           \
      X)

#define CUDABLAS_NONNEGINT_CHECK(FD, X)       \
  TORCH_CHECK(                                \
      (X >= 0 && X <= INT_MAX),               \
      "at::cuda::blas::" #FD " argument " #X  \
      " must be non-negative and at most ",   \
      INT_MAX,                                \
      " but got ",                            \
      X)

namespace {

static cublasOperation_t _cublasOpFromChar(char op) {
  switch (op) {
    case 'n':
    case 'N':
      return CUBLAS_OP_N;
    case 't':
    case 'T':
      return CUBLAS_OP_T;
    case 'c':
    case 'C':
      return CUBLAS_OP_C;
  }
  AT_ERROR(
      "_cublasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
}

static void _cublasAdjustLdLevel2(int64_t m, int64_t n, int64_t* lda) {
  // Note: leading dimensions are generally checked to be > 0 and at
  // least as big as the result requires (even if the value won't be
  // used).

  // Q: Why does Level3 check trans but this doesn't?
  // A: In level 2, the sizes (m, n) specify the size of A
  // (independent of trans value). In level 3, the sizes (m, n, k)
  // specify the sizes of op(A), op(B), where op depends on the trans
  // values.
  if (n <= 1)
    *lda = std::max<int64_t>(m, 1);
}
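
// Illustrative example: for a level-2 call with n == 1, a caller may
// pass a degenerate lda (e.g. 0 from a size-1 dimension); the clamp
// above promotes it to max(m, 1), the smallest leading dimension
// cuBLAS accepts for an m-row operand.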

static void _cublasAdjustLdLevel3(
    char transa,
    char transb,
    int64_t m,
    int64_t n,
    int64_t k,
    int64_t* lda,
    int64_t* ldb,
    int64_t* ldc) {
  bool transa_ = ((transa != 'n') && (transa != 'N'));
  bool transb_ = ((transb != 'n') && (transb != 'N'));

  // Note: leading dimensions are generally checked to be > 0 and at
  // least as big as the result requires (even if the value won't be
  // used).
  if (n <= 1)
    *ldc = std::max<int64_t>(m, 1);

  if (transa_) {
    if (m <= 1)
      *lda = std::max<int64_t>(k, 1);
  } else {
    if (k <= 1)
      *lda = std::max<int64_t>(m, 1);
  }

  if (transb_) {
    if (k <= 1)
      *ldb = std::max<int64_t>(n, 1);
  } else {
    if (n <= 1)
      *ldb = std::max<int64_t>(k, 1);
  }
}
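
// Illustrative example: with transa == 'n' and k == 1 (an outer-product
// shape), *lda is promoted to max(m, 1); with transb == 'n' and n == 1,
// *ldb is promoted to max(k, 1). The ldc clamp depends only on n
// because C is always m x n regardless of the trans flags.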

#ifndef USE_ROCM
uint32_t _getAlignment(uintptr_t address) {
  // alignment is in bytes
  uint32_t alignment = 256;
  for (; ; alignment /= 2) {
    if (!(address % alignment)) {
      return alignment;
    }
  }
}
#endif
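
// Example (illustrative): a pointer whose address is divisible by 64
// but not by 128 yields _getAlignment == 64; a 256-byte-aligned pointer
// returns the 256-byte cap. The loop always terminates because every
// address is divisible by 1.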

static size_t _parseChosenWorkspaceSize() {
  const char * val = getenv("CUBLASLT_WORKSPACE_SIZE");
#ifdef USE_ROCM
  if (!val) {
    // accept either env var
    val = getenv("HIPBLASLT_WORKSPACE_SIZE");
  }
#endif
  size_t workspace_size = 1024; /* default size in KiB according to #73328 */
  if (val) {
    try {
      workspace_size = std::stoi(val);
    } catch(std::invalid_argument const& e) {
      TORCH_WARN("invalid CUBLASLT_WORKSPACE_SIZE,",
                 " using default workspace size of ", workspace_size, " KiB.");
    } catch(std::out_of_range const& e) {
      TORCH_WARN("CUBLASLT_WORKSPACE_SIZE out of range,",
                 " using default workspace size of ", workspace_size, " KiB.");
    }
  }
  return workspace_size * 1024;
}
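
// Usage sketch (illustrative): the value is interpreted in KiB, so
//   CUBLASLT_WORKSPACE_SIZE=4096 python script.py
// requests a 4096 * 1024 = 4 MiB cuBLASLt workspace, while leaving the
// variable unset keeps the 1024 KiB (1 MiB) default.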

static size_t _getWorkspaceSize() {
  static size_t workspace_size = _parseChosenWorkspaceSize();
  return workspace_size;
}

} // anonymous namespace

namespace at::cuda::blas {

/* LEVEL 3 BLAS FUNCTIONS */

#define GEMM_CHECK_ARGVALUES(Dtype)           \
  do {                                        \
    CUDABLAS_NONNEGINT_CHECK(gemm<Dtype>, m); \
    CUDABLAS_NONNEGINT_CHECK(gemm<Dtype>, n); \
    CUDABLAS_NONNEGINT_CHECK(gemm<Dtype>, k); \
    CUDABLAS_POSINT_CHECK(gemm<Dtype>, lda);  \
    CUDABLAS_POSINT_CHECK(gemm<Dtype>, ldb);  \
    CUDABLAS_POSINT_CHECK(gemm<Dtype>, ldc);  \
  } while (0)

#define BGEMM_CHECK_ARGVALUES(Dtype)                     \
  do {                                                   \
    CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, m);           \
    CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, n);           \
    CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, k);           \
    CUDABLAS_POSINT_CHECK(bgemm<Dtype>, lda);            \
    CUDABLAS_POSINT_CHECK(bgemm<Dtype>, ldb);            \
    CUDABLAS_POSINT_CHECK(bgemm<Dtype>, ldc);            \
    CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, num_batches); \
  } while (0)


namespace {
// Following the pattern of CuSparseDescriptor
// Defined here for now because this is the only place the cublas_lt
// interface is used, but it can be moved to a header once the
// cublas_lt interface is used in multiple places.
template <typename T, cublasStatus_t (*destructor)(T*)>
struct CuBlasLtDeleter {
  void operator()(T* x) {
    if (x != nullptr) {
      TORCH_CUDABLAS_CHECK(destructor(x));
    }
  }
};

template <typename T, cublasStatus_t (*destructor)(T*)>
class CuBlasLtDescriptor {
 public:
  T* descriptor() const {
    return descriptor_.get();
  }
  T* descriptor() {
    return descriptor_.get();
  }

 protected:
  std::unique_ptr<T, CuBlasLtDeleter<T, destructor>> descriptor_;
};
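
// The classes above wrap each raw cublasLt handle in a std::unique_ptr
// whose deleter calls the matching cublasLt*Destroy function, so every
// descriptor is released automatically even if a TORCH_CUDABLAS_CHECK
// later in this file throws.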

class CuBlasLtMatmulDescriptor : public CuBlasLtDescriptor<
                                     cublasLtMatmulDescOpaque_t,
                                     &cublasLtMatmulDescDestroy> {
 public:
  CuBlasLtMatmulDescriptor(
      cublasComputeType_t compute_type,
      cudaDataType_t scale_type) {
    cublasLtMatmulDesc_t raw_descriptor = nullptr;
    TORCH_CUDABLAS_CHECK(
        cublasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
    descriptor_.reset(raw_descriptor);
  }
  template <typename T>
  inline void setAttribute(cublasLtMatmulDescAttributes_t attr, const T value) {
    TORCH_CUDABLAS_CHECK(::cublasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(T)));
  }
};

class CuBlasLtMatrixLayout : public CuBlasLtDescriptor<
                                 cublasLtMatrixLayoutOpaque_t,
                                 &cublasLtMatrixLayoutDestroy> {
 public:
  CuBlasLtMatrixLayout(
      cudaDataType_t type,
      uint64_t rows,
      uint64_t cols,
      int64_t ld,
      bool t = false) {
    cublasLtMatrixLayout_t raw_descriptor = nullptr;
    TORCH_CUDABLAS_CHECK(
        cublasLtMatrixLayoutCreate(&raw_descriptor, type, t ? cols : rows, t ? rows : cols, ld));
    descriptor_.reset(raw_descriptor);
  }
  template <typename T>
  inline void setAttribute(cublasLtMatrixLayoutAttribute_t attr, const T value) {
    TORCH_CUDABLAS_CHECK(::cublasLtMatrixLayoutSetAttribute(descriptor(), attr, &value, sizeof(T)));
  }
};
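
// Note: when t is true, CuBlasLtMatrixLayout swaps rows and cols at
// creation time. cublasLt layout descriptors describe the matrix as it
// is stored; the transpose itself is requested separately through
// CUBLASLT_MATMUL_DESC_TRANSA/TRANSB, so a transposed operand's stored
// shape is (cols, rows).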

class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
                                     cublasLtMatmulPreferenceOpaque_t,
                                     &cublasLtMatmulPreferenceDestroy> {
 public:
  CuBlasLtMatmulPreference() {
    cublasLtMatmulPreference_t raw_descriptor = nullptr;
    TORCH_CUDABLAS_CHECK(cublasLtMatmulPreferenceCreate(&raw_descriptor));
    descriptor_.reset(raw_descriptor);
  }
  template <typename T>
  inline void setAttribute(cublasLtMatmulPreferenceAttributes_t attr, const T value) {
    TORCH_CUDABLAS_CHECK(::cublasLtMatmulPreferenceSetAttribute(descriptor(), attr, &value, sizeof(T)));
  }
};
} // namespace


template <typename Dtype>
inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
  cudaDataType_t abcType = CUDA_R_32F;
  cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
  cudaDataType_t scaleType = CUDA_R_32F;
  if constexpr (std::is_same_v<Dtype, double>) {
    abcType = CUDA_R_64F;
    computeType = CUBLAS_COMPUTE_64F;
    scaleType = CUDA_R_64F;
  } else if constexpr (std::is_same_v<Dtype, float>) {
#ifndef USE_ROCM
    if (at::globalContext().allowTF32CuBLAS()) {
      computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
    }
#endif
  } else if constexpr (std::is_same_v<Dtype, c10::complex<double>>) {
    abcType = CUDA_C_64F;
    computeType = CUBLAS_COMPUTE_64F;
    scaleType = CUDA_C_64F;
  } else if constexpr (std::is_same_v<Dtype, c10::complex<float>>) {
    abcType = CUDA_C_32F;
    scaleType = CUDA_C_32F;
  } else if constexpr (std::is_same_v<Dtype, at::Half>) {
    abcType = CUDA_R_16F;
  } else if constexpr (std::is_same_v<Dtype, at::BFloat16>) {
    abcType = CUDA_R_16BF;
  } else {
    static_assert(false && sizeof(Dtype), "at::cuda::blas::bgemm_internal_cublaslt: not implemented");
  }

  globalContext().alertCuBLASConfigNotDeterministic();
  cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);

  CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, opa);
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, opb);
  CuBlasLtMatrixLayout Adesc(abcType, m, k, lda, opa == CUBLAS_OP_T);
  CuBlasLtMatrixLayout Bdesc(abcType, k, n, ldb, opb == CUBLAS_OP_T);
  CuBlasLtMatrixLayout Cdesc(abcType, m, n, ldc);

  if (num_batches > 1) {
    int num_batches_as_int = static_cast<int>(num_batches);
    Adesc.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, num_batches_as_int);
    Bdesc.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, num_batches_as_int);
    Cdesc.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, num_batches_as_int);
    Adesc.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stridea);
    Bdesc.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, strideb);
    Cdesc.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stridec);
  }
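
  // With the batch attributes set above, a single cublasLtMatmul call
  // performs num_batches independent GEMMs: batch i reads A at
  // a + i * stridea and B at b + i * strideb, and writes C at
  // c + i * stridec, matching cublasGemmStridedBatchedEx semantics.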

  CuBlasLtMatmulPreference preference;
  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
  // setting this to 1M.
  size_t workspaceSize = _getWorkspaceSize();
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);

#ifndef USE_ROCM
  uint32_t a_alignment = _getAlignment(reinterpret_cast<uintptr_t>(a));
  uint32_t b_alignment = _getAlignment(reinterpret_cast<uintptr_t>(b));
  uint32_t c_alignment = _getAlignment(reinterpret_cast<uintptr_t>(c));
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, a_alignment);
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, b_alignment);
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
#endif

  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
  auto workspace = allocator.allocate(workspaceSize);
  TORCH_CHECK(workspace.get() != nullptr, "OOM trying to allocate workspace for cublaslt");

  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  int returnedResult = 0;
  TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
      ltHandle,
      computeDesc.descriptor(),
      Adesc.descriptor(),
      Bdesc.descriptor(),
      Cdesc.descriptor(),
      Cdesc.descriptor(),
      preference.descriptor(),
      1,
      &heuristicResult,
      &returnedResult));
  if (returnedResult == 0) {
    TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
  }
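
  // Only one heuristic result is requested above, so heuristicResult
  // holds the top-ranked algorithm satisfying the preference
  // constraints; returnedResult == 0 means no algorithm fit (e.g. the
  // workspace or alignment limits), which surfaces as NOT_SUPPORTED.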

  cublasStatus_t cublasStatus = cublasLtMatmul(
      ltHandle,
      computeDesc.descriptor(),
      &alpha,
      a,
      Adesc.descriptor(),
      b,
      Bdesc.descriptor(),
      &beta,
      c,
      Cdesc.descriptor(),
      c,
      Cdesc.descriptor(),
      &heuristicResult.algo,
      workspace.mutable_get(),
      workspaceSize,
      at::cuda::getCurrentCUDAStream());
  TORCH_CHECK(
      cublasStatus == CUBLAS_STATUS_SUCCESS,
      "CUDA error: ",
      at::cuda::blas::_cublasGetErrorEnum(cublasStatus),
      " when calling cublasLtMatmul with transpose_mat1 ",
      (opa == CUBLAS_OP_T),
      " transpose_mat2 ",
      (opb == CUBLAS_OP_T),
      " m ", m,
      " n ", n,
      " k ", k,
      " lda ", lda,
      " ldb ", ldb,
      " ldc ", ldc,
      " abcType ", abcType,
      " computeType ", computeType,
      " scaleType ", scaleType);
}


template <typename Dtype>
inline void bgemm_internal_cublas(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
  static_assert(false && sizeof(Dtype), "at::cuda::blas::bgemm_internal_cublas: not implemented");
}

template <>
void bgemm_internal_cublas<double>(CUDABLAS_BGEMM_ARGTYPES(double)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  BGEMM_CHECK_ARGVALUES(double);
  TORCH_CUDABLAS_CHECK(cublasDgemmStridedBatched(
      handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches));
}

template <>
void bgemm_internal_cublas<float>(CUDABLAS_BGEMM_ARGTYPES(float)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  BGEMM_CHECK_ARGVALUES(float);
  TORCH_CUDABLAS_CHECK(cublasSgemmStridedBatched(
      handle, opa, opb, m, n, k, &alpha, a, lda, stridea, b, ldb, strideb, &beta, c, ldc, stridec, num_batches));
}

template <>
void bgemm_internal_cublas<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  BGEMM_CHECK_ARGVALUES(c10::complex<double>);
  TORCH_CUDABLAS_CHECK(cublasZgemmStridedBatched(
      handle, opa, opb, m, n, k, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
      lda, stridea, reinterpret_cast<const cuDoubleComplex*>(b), ldb, strideb, reinterpret_cast<const cuDoubleComplex*>(&beta),
      reinterpret_cast<cuDoubleComplex*>(c), ldc, stridec, num_batches));
}

template <>
void bgemm_internal_cublas<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  BGEMM_CHECK_ARGVALUES(c10::complex<float>);
  TORCH_CUDABLAS_CHECK(cublasCgemmStridedBatched(
      handle, opa, opb, m, n, k, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
      lda, stridea, reinterpret_cast<const cuComplex*>(b), ldb, strideb, reinterpret_cast<const cuComplex*>(&beta),
      reinterpret_cast<cuComplex*>(c), ldc, stridec, num_batches));
}

template <>
void bgemm_internal_cublas<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  BGEMM_CHECK_ARGVALUES(at::Half);
  float falpha = alpha;
  float fbeta = beta;
#ifdef USE_ROCM
  int flag = 0;
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
  flag = at::ROCmBackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_strided_batched_ex((rocblas_handle)handle,
                                   hipOperationToRocOperation(opa),
                                   hipOperationToRocOperation(opb), (int)m, (int)n, (int)k,
                                   (void*)&falpha, a, rocblas_datatype_f16_r, (int)lda, stridea,
                                   b, rocblas_datatype_f16_r, (int)ldb, strideb,
                                   (void*)&fbeta, c, rocblas_datatype_f16_r, (int)ldc, stridec,
                                   c, rocblas_datatype_f16_r, (int)ldc, stridec,
                                   (int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard,
                                   0, flag)));
#else
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  if (prop->major >= 5) {
    TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedEx(
      handle, opa, opb, m, n, k,
      (void*)(&falpha), a, CUDA_R_16F, lda, stridea,
      b, CUDA_R_16F, ldb, strideb, (void*)(&fbeta),
      c, CUDA_R_16F, ldc, stridec,
      num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
  } else {
    // Fall back to one gemm call per batch on devices without
    // cublasGemmStridedBatchedEx fp16 support.
    for (const auto i : c10::irange(num_batches)) {
      at::cuda::blas::gemm<at::Half>(
        transa, transb,
        m, n, k,
        alpha, (a + i * stridea), lda,
        (b + i * strideb), ldb, beta,
        (c + i * stridec), ldc);
    }
  }
#endif // USE_ROCM
}

template <>
void bgemm_internal_cublas<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  BGEMM_CHECK_ARGVALUES(at::BFloat16);
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  const float falpha = alpha;
  const float fbeta = beta;
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);

#if defined(USE_ROCM)
  auto compute_type = CUBLAS_COMPUTE_32F;
#else
  auto compute_type = CUDA_R_32F;
#endif
  TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedEx(handle,
                                  opa, opb, (int)m, (int)n, (int)k,
                                  (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea,
                                  b, CUDA_R_16BF, (int)ldb, strideb,
                                  (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
                                  (int)num_batches,
                                  compute_type,
                                  CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}

template <>
void bgemm_internal<double>(CUDABLAS_BGEMM_ARGTYPES(double))
{
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
#ifdef USE_ROCM
    // hipblaslt does not support double gemm yet
    bgemm_internal_cublas<double>(CUDABLAS_BGEMM_ARGS(double));
#else
    bgemm_internal_cublaslt<double>(CUDABLAS_BGEMM_ARGS(double));
#endif
  }
  else {
    bgemm_internal_cublas<double>(CUDABLAS_BGEMM_ARGS(double));
  }
}

template <>
void bgemm_internal<float>(CUDABLAS_BGEMM_ARGTYPES(float))
{
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
    bgemm_internal_cublaslt<float>(CUDABLAS_BGEMM_ARGS(float));
  }
  else {
    bgemm_internal_cublas<float>(CUDABLAS_BGEMM_ARGS(float));
  }
}

template <>
void bgemm_internal<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>))
{
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
#ifdef USE_ROCM
    // hipblaslt does not support complex<double> gemm yet
    bgemm_internal_cublas<c10::complex<double>>(CUDABLAS_BGEMM_ARGS(c10::complex<double>));
#else
    bgemm_internal_cublaslt<c10::complex<double>>(CUDABLAS_BGEMM_ARGS(c10::complex<double>));
#endif
  }
  else {
    bgemm_internal_cublas<c10::complex<double>>(CUDABLAS_BGEMM_ARGS(c10::complex<double>));
  }
}

template <>
void bgemm_internal<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>))
{
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
#ifdef USE_ROCM
    // hipblaslt does not support complex<float> gemm yet
    bgemm_internal_cublas<c10::complex<float>>(CUDABLAS_BGEMM_ARGS(c10::complex<float>));
#else
    bgemm_internal_cublaslt<c10::complex<float>>(CUDABLAS_BGEMM_ARGS(c10::complex<float>));
#endif
  }
  else {
    bgemm_internal_cublas<c10::complex<float>>(CUDABLAS_BGEMM_ARGS(c10::complex<float>));
  }
}

template <>
void bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half))
{
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
    bgemm_internal_cublaslt<at::Half>(CUDABLAS_BGEMM_ARGS(at::Half));
  }
  else {
    bgemm_internal_cublas<at::Half>(CUDABLAS_BGEMM_ARGS(at::Half));
  }
}

template <>
void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16))
{
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
    bgemm_internal_cublaslt<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
  }
  else {
    bgemm_internal_cublas<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
  }
}

template <typename DType>
inline void bgemm_tunable(CUDABLAS_BGEMM_ARGTYPES(DType)) {
  tunable::GemmStridedBatchedParams<DType> params;
  params.transa = transa;
  params.transb = transb;
  params.m = m;
  params.n = n;
  params.k = k;
  params.alpha = alpha;
  params.a = a;
  params.lda = lda;
  params.stride_a = stridea;
  params.b = b;
  params.ldb = ldb;
  params.stride_b = strideb;
  params.beta = beta;
  params.c = c;
  params.ldc = ldc;
  params.stride_c = stridec;
  params.batch = num_batches;

  bool transa_ = ((transa != 'n') && (transa != 'N'));
  bool transb_ = ((transb != 'n') && (transb != 'N'));

  if (transa_ && transb_) {
    static tunable::GemmStridedBatchedTunableOp<DType, tunable::BlasOp::T, tunable::BlasOp::T> bgemm{};
    bgemm(&params);
  }
  else if (transa_ && !transb_) {
    static tunable::GemmStridedBatchedTunableOp<DType, tunable::BlasOp::T, tunable::BlasOp::N> bgemm{};
    bgemm(&params);
  }
  else if (!transa_ && transb_) {
    static tunable::GemmStridedBatchedTunableOp<DType, tunable::BlasOp::N, tunable::BlasOp::T> bgemm{};
    bgemm(&params);
  }
  else if (!transa_ && !transb_) {
    static tunable::GemmStridedBatchedTunableOp<DType, tunable::BlasOp::N, tunable::BlasOp::N> bgemm{};
    bgemm(&params);
  }
  else {
    TORCH_CHECK(false, "unreachable");
  }
}
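
// Each branch above holds its own function-local static TunableOp, one
// per (DType, transa, transb) combination, so the op is constructed
// once and its tuned solution is reused across subsequent calls with
// the same layout.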

template <>
void bgemm<double>(CUDABLAS_BGEMM_ARGTYPES(double)) {
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
  if (tuning_ctx->IsTunableOpEnabled()) {
    bgemm_tunable<double>(CUDABLAS_BGEMM_ARGS(double));
  }
  else {
    bgemm_internal<double>(CUDABLAS_BGEMM_ARGS(double));
  }
}

template <>
void bgemm<float>(CUDABLAS_BGEMM_ARGTYPES(float)) {
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
  if (tuning_ctx->IsTunableOpEnabled()) {
    bgemm_tunable<float>(CUDABLAS_BGEMM_ARGS(float));
  }
  else {
    bgemm_internal<float>(CUDABLAS_BGEMM_ARGS(float));
  }
}

template <>
void bgemm<c10::complex<double>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<double>)) {
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
  if (tuning_ctx->IsTunableOpEnabled()) {
    bgemm_tunable<c10::complex<double>>(CUDABLAS_BGEMM_ARGS(c10::complex<double>));
  }
  else {
    bgemm_internal<c10::complex<double>>(CUDABLAS_BGEMM_ARGS(c10::complex<double>));
  }
}

template <>
void bgemm<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<float>)) {
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
  if (tuning_ctx->IsTunableOpEnabled()) {
    bgemm_tunable<c10::complex<float>>(CUDABLAS_BGEMM_ARGS(c10::complex<float>));
  }
  else {
    bgemm_internal<c10::complex<float>>(CUDABLAS_BGEMM_ARGS(c10::complex<float>));
  }
}

template <>
void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
  if (tuning_ctx->IsTunableOpEnabled()) {
    bgemm_tunable<at::Half>(CUDABLAS_BGEMM_ARGS(at::Half));
  }
  else {
    bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGS(at::Half));
  }
}

template <>
void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
  if (tuning_ctx->IsTunableOpEnabled()) {
    bgemm_tunable<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
  }
  else {
    bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
  }
}

template <typename Dtype>
inline void gemm_internal_cublaslt(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
  // forward to bgemm implementation but set strides and batches to 0
  bgemm_internal_cublaslt(transa, transb, m, n, k, alpha, a, lda, 0, b, ldb, 0, beta, c, ldc, 0, 0);
}
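
// Passing num_batches == 0 means the num_batches > 1 branch in
// bgemm_internal_cublaslt is skipped, so no batch attributes are set
// and the forwarded call degenerates to a single (non-batched) matmul;
// the zero strides are never consumed.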

template <typename Dtype>
inline void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
  static_assert(false && sizeof(Dtype), "at::cuda::blas::gemm_internal_cublas: not implemented");
}

template <>
void gemm_internal_cublas<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(double);
  TORCH_CUDABLAS_CHECK(cublasDgemm(
      handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
}

template <>
void gemm_internal_cublas<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(float);
  TORCH_CUDABLAS_CHECK(cublasSgemm(
      handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc));
}

template <>
void gemm_internal_cublas<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(c10::complex<double>);
  TORCH_CUDABLAS_CHECK(cublasZgemm(
      handle, opa, opb, m, n, k, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
      lda, reinterpret_cast<const cuDoubleComplex*>(b), ldb, reinterpret_cast<const cuDoubleComplex*>(&beta),
      reinterpret_cast<cuDoubleComplex*>(c), ldc));
}

template <>
void gemm_internal_cublas<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(c10::complex<float>);
  TORCH_CUDABLAS_CHECK(cublasCgemm(
      handle, opa, opb, m, n, k, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
      lda, reinterpret_cast<const cuComplex*>(b), ldb, reinterpret_cast<const cuComplex*>(&beta),
      reinterpret_cast<cuComplex*>(c), ldc));
}

template <>
void gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  float falpha = alpha;
  float fbeta = beta;
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(at::Half);
#ifdef USE_ROCM
  int flag = 0;
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
  flag = at::ROCmBackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
#endif
  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(
      (rocblas_handle)handle,
      hipOperationToRocOperation(opa),
      hipOperationToRocOperation(opb),
      m,
      n,
      k,
      &falpha,
      a,
      rocblas_datatype_f16_r,
      lda,
      b,
      rocblas_datatype_f16_r,
      ldb,
      &fbeta,
      c,
      rocblas_datatype_f16_r,
      ldc,
      c,
      rocblas_datatype_f16_r,
      ldc,
      rocblas_datatype_f32_r,
      rocblas_gemm_algo_standard,
      0,
      flag)));
#else
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  if (prop->major >= 5) {
#ifndef USE_ROCM
    cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
    if (!at::globalContext().allowFP16ReductionCuBLAS()) {
      cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
    }
#endif
    // Disallow fp16 reductions that could lead to unexpected overflow issues.
    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
    TORCH_CUDABLAS_CHECK(cublasGemmEx(
        handle,
        opa,
        opb,
        m,
        n,
        k,
        &falpha,
        a,
        CUDA_R_16F,
        lda,
        b,
        CUDA_R_16F,
        ldb,
        &fbeta,
        c,
        CUDA_R_16F,
        ldc,
        CUDA_R_32F,
        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
  } else {
    TORCH_CUDABLAS_CHECK(cublasSgemmEx(
        handle,
        opa,
        opb,
        m,
        n,
        k,
        &falpha,
        a,
        CUDA_R_16F,
        lda,
        b,
        CUDA_R_16F,
        ldb,
        &fbeta,
        c,
        CUDA_R_16F,
        ldc));
  }
#endif
}

template <>
void gemm_internal_cublas<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  float falpha = alpha;
  float fbeta = beta;
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  GEMM_CHECK_ARGVALUES(at::BFloat16);
#ifndef USE_ROCM
  cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
  if (!at::globalContext().allowBF16ReductionCuBLAS()) {
    cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
  }
#endif
#if defined(USE_ROCM)
  auto compute_type = CUBLAS_COMPUTE_32F;
#else
  auto compute_type = CUDA_R_32F;
#endif
  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle,
      opa,
      opb,
      m,
      n,
      k,
      &falpha,
      a,
      CUDA_R_16BF,
      lda,
      b,
      CUDA_R_16BF,
      ldb,
      &fbeta,
      c,
      CUDA_R_16BF,
      ldc,
      compute_type,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
}

template <>
void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
{
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
#ifdef USE_ROCM
    // hipblaslt does not support double gemm yet
    gemm_internal_cublas<double>(CUDABLAS_GEMM_ARGS(double));
#else
    gemm_internal_cublaslt<double>(CUDABLAS_GEMM_ARGS(double));
#endif
  }
  else {
    gemm_internal_cublas<double>(CUDABLAS_GEMM_ARGS(double));
  }
}

template <>
void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
{
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
    gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
  }
  else {
    gemm_internal_cublas<float>(CUDABLAS_GEMM_ARGS(float));
  }
}

template <>
void gemm_internal<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>))
{
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
#ifdef USE_ROCM
    // hipblaslt does not support complex gemm yet
    gemm_internal_cublas<c10::complex<double>>(CUDABLAS_GEMM_ARGS(c10::complex<double>));
#else
    gemm_internal_cublaslt<c10::complex<double>>(CUDABLAS_GEMM_ARGS(c10::complex<double>));
#endif
  }
  else {
    gemm_internal_cublas<c10::complex<double>>(CUDABLAS_GEMM_ARGS(c10::complex<double>));
  }
}

template <>
void gemm_internal<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>))
{
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
#ifdef USE_ROCM
    // hipblaslt does not support complex gemm yet
    gemm_internal_cublas<c10::complex<float>>(CUDABLAS_GEMM_ARGS(c10::complex<float>));
#else
    gemm_internal_cublaslt<c10::complex<float>>(CUDABLAS_GEMM_ARGS(c10::complex<float>));
#endif
  }
  else {
    gemm_internal_cublas<c10::complex<float>>(CUDABLAS_GEMM_ARGS(c10::complex<float>));
  }
}

template <>
void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
{
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
    gemm_internal_cublaslt<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
  }
  else {
    gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
  }
}

template <>
void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16))
{
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
    gemm_internal_cublaslt<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
  }
  else {
    gemm_internal_cublas<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
  }
}

template <typename DType>
inline void gemm_tunable(CUDABLAS_GEMM_ARGTYPES(DType)) {
  tunable::GemmParams<DType> params;
  params.transa = transa;
  params.transb = transb;
  params.m = m;
  params.n = n;
  params.k = k;
  params.alpha = alpha;
  params.a = a;
  params.lda = lda;
  params.b = b;
  params.ldb = ldb;
  params.beta = beta;
  params.c = c;
  params.ldc = ldc;

  bool transa_ = ((transa != 'n') && (transa != 'N'));
  bool transb_ = ((transb != 'n') && (transb != 'N'));

  if (transa_ && transb_) {
    static tunable::GemmTunableOp<DType, tunable::BlasOp::T, tunable::BlasOp::T> gemm{};
    gemm(&params);
  }
  else if (transa_ && !transb_) {
    static tunable::GemmTunableOp<DType, tunable::BlasOp::T, tunable::BlasOp::N> gemm{};
    gemm(&params);
  }
  else if (!transa_ && transb_) {
    static tunable::GemmTunableOp<DType, tunable::BlasOp::N, tunable::BlasOp::T> gemm{};
    gemm(&params);
  }
  else if (!transa_ && !transb_) {
    static tunable::GemmTunableOp<DType, tunable::BlasOp::N, tunable::BlasOp::N> gemm{};
    gemm(&params);
  }
  else {
    TORCH_CHECK(false, "unreachable");
  }
}

template <>
void gemm<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
  if (tuning_ctx->IsTunableOpEnabled()) {
    gemm_tunable<double>(CUDABLAS_GEMM_ARGS(double));
  }
  else {
    gemm_internal<double>(CUDABLAS_GEMM_ARGS(double));
  }
}

template <>
void gemm<float>(CUDABLAS_GEMM_ARGTYPES(float)) {
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
  if (tuning_ctx->IsTunableOpEnabled()) {
    gemm_tunable<float>(CUDABLAS_GEMM_ARGS(float));
  }
  else {
    gemm_internal<float>(CUDABLAS_GEMM_ARGS(float));
  }
}

template <>
void gemm<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>)) {
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
  if (tuning_ctx->IsTunableOpEnabled()) {
    gemm_tunable<c10::complex<double>>(CUDABLAS_GEMM_ARGS(c10::complex<double>));
  }
  else {
    gemm_internal<c10::complex<double>>(CUDABLAS_GEMM_ARGS(c10::complex<double>));
  }
}

template <>
void gemm<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>)) {
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
  if (tuning_ctx->IsTunableOpEnabled()) {
    gemm_tunable<c10::complex<float>>(CUDABLAS_GEMM_ARGS(c10::complex<float>));
  }
  else {
    gemm_internal<c10::complex<float>>(CUDABLAS_GEMM_ARGS(c10::complex<float>));
  }
}

template <>
void gemm<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
  if (tuning_ctx->IsTunableOpEnabled()) {
    gemm_tunable<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
  }
  else {
    gemm_internal<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
  }
}

template <>
void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
  auto tuning_ctx = at::cuda::tunable::getTuningContext();
  if (tuning_ctx->IsTunableOpEnabled()) {
    gemm_tunable<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
  }
  else {
    gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
  }
}


template <typename Dtype>
void gemm_and_bias(
    bool transpose_mat1,
    bool transpose_mat2,
    int64_t m,
    int64_t n,
    int64_t k,
    at::opmath_type<Dtype> alpha_val,
    const Dtype* mat1_ptr,
    int64_t mat1_ld,
    const Dtype* mat2_ptr,
    int64_t mat2_ld,
    const Dtype* bias,
    Dtype* result_ptr,
    int64_t result_ld,
    GEMMAndBiasActivationEpilogue activation) {
  using opmath_t = at::opmath_type<Dtype>;
  opmath_t beta_val = 0; // bias is added in epilogue

  cudaDataType_t abcType = CUDA_R_32F;
  cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
  cudaDataType_t scaleType = CUDA_R_32F;
  if constexpr (std::is_same_v<Dtype, double>) {
    abcType = CUDA_R_64F;
    computeType = CUBLAS_COMPUTE_64F;
    scaleType = CUDA_R_64F;
  } else if constexpr (std::is_same_v<Dtype, float>) {
#ifndef USE_ROCM
    if (at::globalContext().allowTF32CuBLAS()) {
      computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
    }
#endif
    abcType = CUDA_R_32F;
  } else if constexpr (std::is_same_v<Dtype, at::Half>) {
    abcType = CUDA_R_16F;
  } else if constexpr (std::is_same_v<Dtype, at::BFloat16>) {
    abcType = CUDA_R_16BF;
  }

  CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
  cublasOperation_t transa = transpose_mat1 ? CUBLAS_OP_T : CUBLAS_OP_N;
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
  cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
  if (activation == GEMMAndBiasActivationEpilogue::RELU) {
    epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
  } else if (activation == GEMMAndBiasActivationEpilogue::GELU) {
#if CUDA_VERSION >= 11040 || defined(USE_ROCM)
    epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
#endif
  }

  if (bias != nullptr) {
    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, epilogue);
    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_POINTER, bias);
  }
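
  // The epilogue fuses the bias add (and optionally ReLU/GELU) into the
  // matmul kernel itself, computing roughly
  //   result = activation(alpha_val * op(mat1) * op(mat2) + bias),
  // which is why beta_val stays 0 and no separate elementwise kernel is
  // needed.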

  CuBlasLtMatrixLayout Adesc(abcType, m, k, mat1_ld, transpose_mat1);
  CuBlasLtMatrixLayout Bdesc(abcType, k, n, mat2_ld, transpose_mat2);
  CuBlasLtMatrixLayout Cdesc(abcType, m, n, result_ld);

  CuBlasLtMatmulPreference preference;
  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
  // setting this to 1M.
  size_t workspaceSize = _getWorkspaceSize();
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);

#ifndef USE_ROCM
  uint32_t a_alignment = _getAlignment(reinterpret_cast<uintptr_t>(mat1_ptr));
  uint32_t b_alignment = _getAlignment(reinterpret_cast<uintptr_t>(mat2_ptr));
  uint32_t c_alignment = _getAlignment(reinterpret_cast<uintptr_t>(result_ptr));
  uint32_t d_alignment = _getAlignment(reinterpret_cast<uintptr_t>(bias));
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, a_alignment);
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, b_alignment);
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, d_alignment);
#endif

  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
  auto workspace = allocator.allocate(workspaceSize);
  TORCH_CHECK(workspace.get() != nullptr, "OOM trying to allocate workspace for cublaslt");

  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  int returnedResult = 0;
  cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
  TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
      ltHandle,
      computeDesc.descriptor(),
      Adesc.descriptor(),
      Bdesc.descriptor(),
      Cdesc.descriptor(),
      Cdesc.descriptor(),
      preference.descriptor(),
      1,
      &heuristicResult,
      &returnedResult));
  if (returnedResult == 0) {
    TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
  }

  cublasStatus_t cublasStatus = cublasLtMatmul(
      ltHandle,
      computeDesc.descriptor(),
      &alpha_val,
      mat1_ptr,
      Adesc.descriptor(),
      mat2_ptr,
      Bdesc.descriptor(),
      &beta_val,
      result_ptr,
      Cdesc.descriptor(),
      result_ptr,
      Cdesc.descriptor(),
      &heuristicResult.algo,
      workspace.mutable_get(),
      workspaceSize,
      at::cuda::getCurrentCUDAStream());
  TORCH_CHECK(
      cublasStatus == CUBLAS_STATUS_SUCCESS,
      "CUDA error: ",
      at::cuda::blas::_cublasGetErrorEnum(cublasStatus),
      " when calling cublasLtMatmul with transpose_mat1 ",
      transpose_mat1,
      " transpose_mat2 ",
      transpose_mat2,
      " m ", m,
      " n ", n,
      " k ", k,
      " mat1_ld ", mat1_ld,
      " mat2_ld ", mat2_ld,
      " result_ld ", result_ld,
      " abcType ", abcType,
      " computeType ", computeType,
      " scaleType ", scaleType);
}
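
// A minimal usage sketch (illustrative; d_w, d_x, d_b, d_y stand for
// suitably sized float device buffers): a fused y = relu(W * x + b)
// with column-major W (m x k) and x (k x n):
//
//   at::cuda::blas::gemm_and_bias<float>(
//       /*transpose_mat1=*/false, /*transpose_mat2=*/false,
//       m, n, k, /*alpha_val=*/1.0f,
//       d_w, /*mat1_ld=*/m, d_x, /*mat2_ld=*/k,
//       d_b, d_y, /*result_ld=*/m,
//       GEMMAndBiasActivationEpilogue::RELU);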

template void gemm_and_bias(
    bool transpose_mat1,
    bool transpose_mat2,
    int64_t m,
    int64_t n,
    int64_t k,
    at::opmath_type<double> alpha_val,
    const double* mat1_ptr,
    int64_t mat1_ld,
    const double* mat2_ptr,
    int64_t mat2_ld,
    const double* bias,
    double* result_ptr,
    int64_t result_ld,
    GEMMAndBiasActivationEpilogue activation);

template void gemm_and_bias(
    bool transpose_mat1,
    bool transpose_mat2,
    int64_t m,
    int64_t n,
    int64_t k,
    at::opmath_type<float> alpha_val,
    const float* mat1_ptr,
    int64_t mat1_ld,
    const float* mat2_ptr,
    int64_t mat2_ld,
    const float* bias,
    float* result_ptr,
    int64_t result_ld,
    GEMMAndBiasActivationEpilogue activation);

template void gemm_and_bias(
    bool transpose_mat1,
    bool transpose_mat2,
    int64_t m,
    int64_t n,
    int64_t k,
    at::opmath_type<at::Half> alpha_val,
    const at::Half* mat1_ptr,
    int64_t mat1_ld,
    const at::Half* mat2_ptr,
    int64_t mat2_ld,
    const at::Half* bias,
    at::Half* result_ptr,
    int64_t result_ld,
    GEMMAndBiasActivationEpilogue activation);

template void gemm_and_bias(
    bool transpose_mat1,
    bool transpose_mat2,
    int64_t m,
    int64_t n,
    int64_t k,
    at::opmath_type<at::BFloat16> alpha_val,
    const at::BFloat16* mat1_ptr,
    int64_t mat1_ld,
    const at::BFloat16* mat2_ptr,
    int64_t mat2_ld,
    const at::BFloat16* bias,
    at::BFloat16* result_ptr,
    int64_t result_ld,
    GEMMAndBiasActivationEpilogue activation);

void scaled_gemm(
    char transa,
    char transb,
    int64_t m,
    int64_t n,
    int64_t k,
    const void* mat1_ptr,
    const void* mat1_scale_ptr,
    int64_t mat1_ld,
    ScalarType mat1_dtype,
    const void* mat2_ptr,
    const void* mat2_scale_ptr,
    int64_t mat2_ld,
    ScalarType mat2_dtype,
    const void* bias_ptr,
    ScalarType bias_dtype,
    void* result_ptr,
    const void* result_scale_ptr,
    int64_t result_ld,
    ScalarType result_dtype,
    void* amax_ptr,
    bool use_fast_accum) {
#if CUDA_VERSION >= 11080 || defined(USE_ROCM)
  const auto computeType = CUBLAS_COMPUTE_32F;
  const auto scaleType = CUDA_R_32F;
  const int8_t fastAccuMode = use_fast_accum ? 1 : 0;
  const float alpha_val = 1.0;
  const float beta_val = 0.0;
  CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, _cublasOpFromChar(transa));
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60200)
  // Amax support in ROCm as of 6.2
  if (isFloat8Type(result_dtype)) {
    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, amax_ptr);
  }
#endif
#ifndef USE_ROCM
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, fastAccuMode);
#endif
  CuBlasLtMatrixLayout Adesc(ScalarTypeToCudaDataType(mat1_dtype), m, k, mat1_ld, transa == 't');
  CuBlasLtMatrixLayout Bdesc(ScalarTypeToCudaDataType(mat2_dtype), k, n, mat2_ld, transb == 't');
#ifdef USE_ROCM
  // Cdesc is unused, beta is 0. But hipblaslt needs this set to something reasonable.
  CuBlasLtMatrixLayout Cdesc(ScalarTypeToCudaDataType(result_dtype), m, n, result_ld);
#else
  CuBlasLtMatrixLayout Cdesc(ScalarTypeToCudaDataType(bias_dtype), m, n, result_ld);
#endif
  CuBlasLtMatrixLayout Ddesc(ScalarTypeToCudaDataType(result_dtype), m, n, result_ld);
  if (bias_ptr) {
    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr);
    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_BIAS);
    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
  }
1448   size_t workspaceSize = _getWorkspaceSize();
1449   auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
1450   auto workspace = allocator.allocate(workspaceSize);
1451   TORCH_CHECK(workspace.get() != nullptr, "OOM trying to allocate workspace for cublaslt");
1452 
1453   CuBlasLtMatmulPreference preference;
1454   preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
1455   cublasLtMatmulHeuristicResult_t heuristicResult = {};
1456   int returnedResult = 0;
1457   cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
1458   TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
1459       ltHandle,
1460       computeDesc.descriptor(),
1461       Adesc.descriptor(),
1462       Bdesc.descriptor(),
1463       Cdesc.descriptor(),
1464       Ddesc.descriptor(),
1465       preference.descriptor(),
1466       1,
1467       &heuristicResult,
1468       &returnedResult));
1469   if (returnedResult == 0) {
1470 #ifndef USE_ROCM
1471     TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
1472 #else
1473     // hipblaslt might be able to recover by returning all algos
1474     std::vector<hipblasLtMatmulHeuristicResult_t> all_algos;
1475     TORCH_CUDABLAS_CHECK(hipblaslt_ext::getAllAlgos(
1476         ltHandle,
1477         hipblaslt_ext::GemmType::HIPBLASLT_GEMM,
1478         _cublasOpFromChar(transa),
1479         _cublasOpFromChar(transb),
1480         ScalarTypeToCudaDataType(mat1_dtype),
1481         ScalarTypeToCudaDataType(mat2_dtype),
1482         // C is nullptr and beta=0, so set to something reasonable. See above.
1483         //ScalarTypeToCudaDataType(bias_dtype),
1484         ScalarTypeToCudaDataType(result_dtype),
1485         ScalarTypeToCudaDataType(result_dtype),
1486         CUBLAS_COMPUTE_32F,
1487         all_algos));
1488     if (all_algos.size() == 0) {
1489       TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
1490     }
1491     // pick first valid solution
1492     bool found = false;
1493     for (size_t i = 0; i < all_algos.size(); i++) {
1494         size_t ret_workspace_size = 0;
1495         auto is_valid_status = hipblaslt_ext::matmulIsAlgoSupported(
1496                 ltHandle,
1497                 computeDesc.descriptor(),
1498                 &alpha_val,
1499                 Adesc.descriptor(),
1500                 Bdesc.descriptor(),
1501                 &beta_val,
1502                 Cdesc.descriptor(),
1503                 Ddesc.descriptor(),
1504                 all_algos[i].algo,
1505                 ret_workspace_size);
1506         if (is_valid_status == HIPBLAS_STATUS_SUCCESS) {
1507             if (ret_workspace_size <= workspaceSize) {
1508                 heuristicResult = all_algos[i];
1509                 found = true;
1510                 break;
1511             }
1512         }
1513     }
1514     TORCH_CHECK(found, "could not find valid hipblaslt solution");
1515 #endif
1516   }
1517   cublasStatus_t cublasStatus = cublasLtMatmul(
1518       ltHandle,
1519       computeDesc.descriptor(),
1520       &alpha_val,
1521       mat1_ptr,
1522       Adesc.descriptor(),
1523       mat2_ptr,
1524       Bdesc.descriptor(),
1525       &beta_val,
1526 #ifdef USE_ROCM
1527       result_ptr, // unused, since beta_val is 0, but hipblaslt can't handle nullptr
1528 #else
1529       nullptr,
1530 #endif
1531       Cdesc.descriptor(),
1532       result_ptr,
1533       Ddesc.descriptor(),
1534       &heuristicResult.algo,
1535       workspace.mutable_get(),
1536       workspaceSize,
1537       at::cuda::getCurrentCUDAStream());
1538   TORCH_CHECK(
1539       cublasStatus == CUBLAS_STATUS_SUCCESS,
1540       "CUDA error: ",
1541       at::cuda::blas::_cublasGetErrorEnum(cublasStatus),
1542       " when calling cublasLtMatmul with transpose_mat1 ",
1543       transa,
1544       " transpose_mat2 ",
1545       transb,
1546       " m ",
1547       m,
1548       " n ",
1549       n,
1550       " k ",
1551       k,
1552       " mat1_ld ",
1553       mat1_ld,
1554       " mat2_ld ",
1555       mat2_ld,
1556       " result_ld ",
1557       result_ld,
1558       " computeType ",
1559       computeType,
1560       " scaleType ",
1561       scaleType);
1562   return;
1563 #endif // CUDA_VERSION >= 11080 || defined(USE_ROCM)
1564   TORCH_CHECK(false, "scaled_gemm is only supported for CUDA 11.8 and above");
1565 }
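// Illustrative scaled_gemm call (hypothetical sketch, not a call made by this
// file): an FP8 e4m3 matmul with per-tensor float scales resident on device
// and a BFloat16 result. All pointers are device pointers; amax_ptr is only
// consumed when the result is itself an FP8 type, so nullptr is fine here.
//
//   at::cuda::blas::scaled_gemm(
//       't', 'n', m, n, k,
//       a_fp8, a_scale, /*mat1_ld=*/k, ScalarType::Float8_e4m3fn,
//       b_fp8, b_scale, /*mat2_ld=*/k, ScalarType::Float8_e4m3fn,
//       /*bias_ptr=*/nullptr, ScalarType::BFloat16,
//       out, out_scale, /*result_ld=*/m, ScalarType::BFloat16,
//       /*amax_ptr=*/nullptr, /*use_fast_accum=*/true);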

void int8_gemm(
    bool transpose_mat1,
    bool transpose_mat2,
    int64_t m,
    int64_t n,
    int64_t k,
    const int8_t* mat1_ptr,
    int64_t mat1_ld,
    const int8_t* mat2_ptr,
    int64_t mat2_ld,
    int32_t* result_ptr,
    int64_t result_ld) {

  cublasComputeType_t computeType = CUBLAS_COMPUTE_32I;
  cudaDataType_t scaleType = CUDA_R_32I;

  cudaDataType_t abType = CUDA_R_8I;
  cudaDataType_t cType = CUDA_R_32I;

  CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
  cublasOperation_t transa = transpose_mat1 ? CUBLAS_OP_T : CUBLAS_OP_N;
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
  cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);

  CuBlasLtMatrixLayout Adesc(abType, m, k, mat1_ld, transpose_mat1);
  CuBlasLtMatrixLayout Bdesc(abType, k, n, mat2_ld, transpose_mat2);
  CuBlasLtMatrixLayout Cdesc(cType, m, n, result_ld);

  // Per the cublas team, alpha and beta must have the same dtype as scaleType.
  at::opmath_type<int32_t> alpha_val = 1;
  int32_t beta_val = 0;
  cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();

#ifdef USE_ROCM
  CuBlasLtMatmulPreference preference;
  size_t workspaceSize = _getWorkspaceSize();
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
  auto workspace = allocator.allocate(workspaceSize);
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  int returnedResult = 0;
  TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
      ltHandle,
      computeDesc.descriptor(),
      Adesc.descriptor(),
      Bdesc.descriptor(),
      Cdesc.descriptor(),
      Cdesc.descriptor(),
      preference.descriptor(),
      1,
      &heuristicResult,
      &returnedResult));
  if (returnedResult == 0) {
    TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
  }
#endif

  cublasStatus_t cublasStatus = cublasLtMatmul(
      ltHandle,
      computeDesc.descriptor(),
      &alpha_val,
      mat1_ptr,
      Adesc.descriptor(),
      mat2_ptr,
      Bdesc.descriptor(),
      &beta_val,
      result_ptr,
      Cdesc.descriptor(),
      result_ptr,
      Cdesc.descriptor(),
#ifdef USE_ROCM
      &heuristicResult.algo,
#else
      nullptr, // Heuristics don't seem to work for int8
#endif
#ifdef USE_ROCM
      workspace.mutable_get(),
#else
      nullptr, // A non-zero workspace doesn't seem to work.
#endif
#ifdef USE_ROCM
      workspaceSize,
#else
      0,
#endif
      at::cuda::getCurrentCUDAStream());
  TORCH_CHECK(
      cublasStatus == CUBLAS_STATUS_SUCCESS,
      "CUDA error: ",
      at::cuda::blas::_cublasGetErrorEnum(cublasStatus),
      " when calling cublasLtMatmul with transpose_mat1 ",
      transpose_mat1,
      " transpose_mat2 ",
      transpose_mat2,
      " m ",
      m,
      " n ",
      n,
      " k ",
      k,
      " mat1_ld ",
      mat1_ld,
      " mat2_ld ",
      mat2_ld,
      " result_ld ",
      result_ld,
      " abType ",
      abType,
      " cType ",
      cType,
      " computeType ",
      computeType,
      " scaleType ",
      scaleType);
}
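// Illustrative int8_gemm call (hypothetical sketch; cublasLt imposes its own
// layout/alignment requirements on int8 inputs, which the caller is assumed
// to satisfy): two int8 device matrices multiplied into an int32 result,
// accumulating in 32-bit integers as configured above.
//
//   at::cuda::blas::int8_gemm(
//       /*transpose_mat1=*/false, /*transpose_mat2=*/false,
//       m, n, k, a_int8, /*mat1_ld=*/m, b_int8, /*mat2_ld=*/k,
//       out_int32, /*result_ld=*/m);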

template <>
void trsm<float>(CUDABLAS_TRSM_ARGTYPES(float)) {
  TORCH_CUDABLAS_CHECK(cublasStrsm(
      handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb));
}

template <>
void trsm<double>(CUDABLAS_TRSM_ARGTYPES(double)) {
  TORCH_CUDABLAS_CHECK(cublasDtrsm(
      handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb));
}

template <>
void trsm<c10::complex<float>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<float>)) {
  TORCH_CUDABLAS_CHECK(cublasCtrsm(
      handle,
      side,
      uplo,
      trans,
      diag,
      m,
      n,
      reinterpret_cast<const cuComplex*>(alpha),
      reinterpret_cast<const cuComplex*>(A),
      lda,
      reinterpret_cast<cuComplex*>(B),
      ldb));
}

template <>
void trsm<c10::complex<double>>(CUDABLAS_TRSM_ARGTYPES(c10::complex<double>)) {
  TORCH_CUDABLAS_CHECK(cublasZtrsm(
      handle,
      side,
      uplo,
      trans,
      diag,
      m,
      n,
      reinterpret_cast<const cuDoubleComplex*>(alpha),
      reinterpret_cast<const cuDoubleComplex*>(A),
      lda,
      reinterpret_cast<cuDoubleComplex*>(B),
      ldb));
}
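// The trsm specializations above solve triangular systems of the form
// op(A) * X = alpha * B (or X * op(A) = alpha * B, depending on `side`),
// with cublas overwriting B in place with the solution X.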

template <>
void trsmBatched<float>(CUDABLAS_TRSM_BATCHED_ARGTYPES(float)) {
  TORCH_CUDABLAS_CHECK(cublasStrsmBatched(
      handle,
      side,
      uplo,
      trans,
      diag,
      m,
      n,
      alpha,
      A,
      lda,
      B,
      ldb,
      batchCount));
}

template <>
void trsmBatched<double>(CUDABLAS_TRSM_BATCHED_ARGTYPES(double)) {
  TORCH_CUDABLAS_CHECK(cublasDtrsmBatched(
      handle,
      side,
      uplo,
      trans,
      diag,
      m,
      n,
      alpha,
      A,
      lda,
      B,
      ldb,
      batchCount));
}

template <>
void trsmBatched<c10::complex<float>>(
    CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<float>)) {
  TORCH_CUDABLAS_CHECK(cublasCtrsmBatched(
      handle,
      side,
      uplo,
      trans,
      diag,
      m,
      n,
      reinterpret_cast<const cuComplex*>(alpha),
      reinterpret_cast<cuComplex**>(A),
      lda,
      reinterpret_cast<cuComplex**>(B),
      ldb,
      batchCount));
}

template <>
void trsmBatched<c10::complex<double>>(
    CUDABLAS_TRSM_BATCHED_ARGTYPES(c10::complex<double>)) {
  TORCH_CUDABLAS_CHECK(cublasZtrsmBatched(
      handle,
      side,
      uplo,
      trans,
      diag,
      m,
      n,
      reinterpret_cast<const cuDoubleComplex*>(alpha),
      reinterpret_cast<cuDoubleComplex**>(A),
      lda,
      reinterpret_cast<cuDoubleComplex**>(B),
      ldb,
      batchCount));
}
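// The batched trsm variants take arrays of device pointers (one A and one B
// per batch entry) and solve all `batchCount` triangular systems in a single
// cublas call.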

/* LEVEL 2 BLAS FUNCTIONS */

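// The argument checks below are wrapped in do { ... } while (0) so the macro
// expands to a single statement and composes safely with unbraced if/else.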
#define GEMV_CHECK_ARGVALUES(Dtype)           \
  do {                                        \
    CUDABLAS_NONNEGINT_CHECK(gemv<Dtype>, m); \
    CUDABLAS_NONNEGINT_CHECK(gemv<Dtype>, n); \
    CUDABLAS_POSINT_CHECK(gemv<Dtype>, lda);  \
    CUDABLAS_POSINT_CHECK(gemv<Dtype>, incx); \
    CUDABLAS_POSINT_CHECK(gemv<Dtype>, incy); \
  } while (0)

template <>
void gemv<c10::complex<double>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<double>)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t op = _cublasOpFromChar(trans);
  _cublasAdjustLdLevel2(m, n, &lda);
  GEMV_CHECK_ARGVALUES(c10::complex<double>);
  TORCH_CUDABLAS_CHECK(
      cublasZgemv(handle, op, m, n, reinterpret_cast<const cuDoubleComplex*>(&alpha), reinterpret_cast<const cuDoubleComplex*>(a),
      lda, reinterpret_cast<const cuDoubleComplex*>(x), incx, reinterpret_cast<const cuDoubleComplex*>(&beta),
      reinterpret_cast<cuDoubleComplex*>(y), incy));
}

template <>
void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>)) {
  // gemv is bandwidth-bound and does not benefit from TF32, but the TF32
  // precision loss would still occur, so we disable TF32 here.
  NoTF32Guard disable_tf32;
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t op = _cublasOpFromChar(trans);
  _cublasAdjustLdLevel2(m, n, &lda);
  GEMV_CHECK_ARGVALUES(c10::complex<float>);
  TORCH_CUDABLAS_CHECK(
      cublasCgemv(handle, op, m, n, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
      lda, reinterpret_cast<const cuComplex*>(x), incx, reinterpret_cast<const cuComplex*>(&beta),
      reinterpret_cast<cuComplex*>(y), incy));
}

template <>
void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t op = _cublasOpFromChar(trans);
  _cublasAdjustLdLevel2(m, n, &lda);
  GEMV_CHECK_ARGVALUES(double);
  TORCH_CUDABLAS_CHECK(
      cublasDgemv(handle, op, m, n, &alpha, a, lda, x, incx, &beta, y, incy));
}

template <>
void gemv<float>(CUDABLAS_GEMV_ARGTYPES(float)) {
  // gemv is bandwidth-bound and does not benefit from TF32, but the TF32
  // precision loss would still occur, so we disable TF32 here.
  NoTF32Guard disable_tf32;
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t op = _cublasOpFromChar(trans);
  _cublasAdjustLdLevel2(m, n, &lda);
  GEMV_CHECK_ARGVALUES(float);
  TORCH_CUDABLAS_CHECK(
      cublasSgemv(handle, op, m, n, &alpha, a, lda, x, incx, &beta, y, incy));
}

template <>
void gemv<at::Half>(CUDABLAS_GEMV_ARGTYPES(at::Half)) {
  // In general, cublas regards matrices as column-major.
  // The cublasS/Dgemv usages in cuda::blas::gemv<float>/<double> above
  // require that external blas::gemv callers obey the following convention:
  //
  // If "a" is row-major with shape (output, summed) in blas::gemv's caller,
  // the caller interprets it as column-major with shape (summed, output), passes
  // summed and output respectively to our local vars m, n, and requests that cublas
  // internally transpose ("trans") the column-major interpretation of a.
  //
  // There's no such thing as "cublasHalfgemv", so here we emulate gemv with a gemm.
  // However, we must honor the same calling convention, because the caller shouldn't
  // have to swap args based on whether it's calling blas::gemv<at::Half> or <float>.

  bool trans_bool = (_cublasOpFromChar(trans) != CUBLAS_OP_N);
  if (trans_bool) {
    std::swap(m, n);
  }
  // After the swap, local vars m, n contain the output and summed sizes respectively,
  // regardless of whether "a" was row-major or column-major in gemv<>'s caller.

  // To handle the possibility that incy > 1, we interpret vector y as a column-major
  // matrix with one row (shape (1, output)) and leading dim incy.
  // trans(a)*x would compute a matrix with one column (shape (output, 1)), which wouldn't
  // match y. So instead, we interpret x similarly to y, as a column-major matrix with one
  // row (shape (1, summed)) and leading dim incx. The gemm then carries out
  // x*transpose(trans(a)) to produce a matrix with one row (shape (1, output)), matching y.
  char trans_flipped = (trans_bool ? 'n' : 't');
  gemm<at::Half>(
      'n', trans_flipped, 1, m, n, alpha, x, incx, a, lda, beta, y, incy);
}
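// Worked instance of the convention above (illustrative numbers only): a caller
// holding row-major "a" of shape (output=4, summed=3) passes m=3, n=4, trans='t'.
// gemv<at::Half> swaps to m=4, n=3 and flips to 'n', so the gemm computes the
// 1x4 row y = x (1x3) * a (3x4), i.e. exactly trans(a)*x, strided along incy.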

template <>
void gemv<at::BFloat16>(CUDABLAS_GEMV_ARGTYPES(at::BFloat16)) {
  // Same gemv-via-gemm emulation as gemv<at::Half> above; see the comments there.
  bool trans_bool = (_cublasOpFromChar(trans) != CUBLAS_OP_N);
  if (trans_bool) {
    std::swap(m, n);
  }
  char trans_flipped = (trans_bool ? 'n' : 't');
  gemm<at::BFloat16>(
      'n', trans_flipped, 1, m, n, alpha, x, incx, a, lda, beta, y, incy);
}

/* LEVEL 1 BLAS FUNCTIONS */

template <>
void dot<double>(CUDABLAS_DOT_ARGTYPES(double)) {
  TORCH_CUDABLAS_CHECK(cublasDdot(handle, n, x, incx, y, incy, result));
}

template <>
void dot<float>(CUDABLAS_DOT_ARGTYPES(float)) {
  TORCH_CUDABLAS_CHECK(cublasSdot(handle, n, x, incx, y, incy, result));
}

template <>
void dot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>)) {
  TORCH_CUDABLAS_CHECK(cublasZdotu(handle, n, reinterpret_cast<const cuDoubleComplex*>(x),
                                   incx, reinterpret_cast<const cuDoubleComplex*>(y), incy,
                                   reinterpret_cast<cuDoubleComplex*>(result)));
}

template <>
void dot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>)) {
  TORCH_CUDABLAS_CHECK(cublasCdotu(handle, n, reinterpret_cast<const cuComplex*>(x),
                                   incx, reinterpret_cast<const cuComplex*>(y), incy,
                                   reinterpret_cast<cuComplex*>(result)));
}

template <>
void dot<at::Half>(CUDABLAS_DOT_ARGTYPES(at::Half)) {
  TORCH_CUDABLAS_CHECK(cublasDotEx(
      handle,
      n,
      x,
      CUDA_R_16F,
      incx,
      y,
      CUDA_R_16F,
      incy,
      result,
      CUDA_R_16F,
      CUDA_R_32F));
}

template <>
void dot<at::BFloat16>(CUDABLAS_DOT_ARGTYPES(at::BFloat16)) {
  TORCH_CUDABLAS_CHECK(cublasDotEx(
      handle,
      n,
      x,
      CUDA_R_16BF,
      incx,
      y,
      CUDA_R_16BF,
      incy,
      result,
      CUDA_R_16BF,
      CUDA_R_32F));
}

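// vdot is the conjugated dot product: cublas?dotc computes sum_i conj(x[i]) * y[i],
// whereas the complex dot<> overloads above use the unconjugated ?dotu variant.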
template <>
void vdot<c10::complex<float>>(CUDABLAS_DOT_ARGTYPES(c10::complex<float>)) {
  TORCH_CUDABLAS_CHECK(cublasCdotc(handle, n, reinterpret_cast<const cuComplex*>(x),
                                   incx, reinterpret_cast<const cuComplex*>(y), incy,
                                   reinterpret_cast<cuComplex*>(result)));
}

template <>
void vdot<c10::complex<double>>(CUDABLAS_DOT_ARGTYPES(c10::complex<double>)) {
  TORCH_CUDABLAS_CHECK(cublasZdotc(handle, n, reinterpret_cast<const cuDoubleComplex*>(x),
                                   incx, reinterpret_cast<const cuDoubleComplex*>(y), incy,
                                   reinterpret_cast<cuDoubleComplex*>(result)));
}

template <>
void getrsBatched<float>(CUDABLAS_GETRS_ARGTYPES(float)) {
  TORCH_CUDABLAS_CHECK(cublasSgetrsBatched(
      handle,
      trans,
      n,
      nrhs,
      dA_array,
      lda,
      ipiv_array,
      dB_array,
      ldb,
      info_array,
      batchsize));
}

template <>
void getrsBatched<double>(CUDABLAS_GETRS_ARGTYPES(double)) {
  TORCH_CUDABLAS_CHECK(cublasDgetrsBatched(
      handle,
      trans,
      n,
      nrhs,
      dA_array,
      lda,
      ipiv_array,
      dB_array,
      ldb,
      info_array,
      batchsize));
}

template <>
void getrsBatched<c10::complex<float>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<float>)) {
  TORCH_CUDABLAS_CHECK(cublasCgetrsBatched(
      handle,
      trans,
      n,
      nrhs,
      reinterpret_cast<cuComplex**>(dA_array),
      lda,
      ipiv_array,
      reinterpret_cast<cuComplex**>(dB_array),
      ldb,
      info_array,
      batchsize));
}

template <>
void getrsBatched<c10::complex<double>>(CUDABLAS_GETRS_ARGTYPES(c10::complex<double>)) {
  TORCH_CUDABLAS_CHECK(cublasZgetrsBatched(
      handle,
      trans,
      n,
      nrhs,
      reinterpret_cast<cuDoubleComplex**>(dA_array),
      lda,
      ipiv_array,
      reinterpret_cast<cuDoubleComplex**>(dB_array),
      ldb,
      info_array,
      batchsize));
}

template <>
void geqrfBatched<float>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(float)) {
  TORCH_CUDABLAS_CHECK(cublasSgeqrfBatched(
      handle, m, n, A_array, lda, tau_array, info, batchsize));
}

template <>
void geqrfBatched<double>(CUDABLAS_GEQRF_BATCHED_ARGTYPES(double)) {
  TORCH_CUDABLAS_CHECK(cublasDgeqrfBatched(
      handle, m, n, A_array, lda, tau_array, info, batchsize));
}

template <>
void geqrfBatched<c10::complex<float>>(
    CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<float>)) {
  TORCH_CUDABLAS_CHECK(cublasCgeqrfBatched(
      handle,
      m,
      n,
      reinterpret_cast<cuComplex**>(A_array),
      lda,
      reinterpret_cast<cuComplex**>(tau_array),
      info,
      batchsize));
}

template <>
void geqrfBatched<c10::complex<double>>(
    CUDABLAS_GEQRF_BATCHED_ARGTYPES(c10::complex<double>)) {
  TORCH_CUDABLAS_CHECK(cublasZgeqrfBatched(
      handle,
      m,
      n,
      reinterpret_cast<cuDoubleComplex**>(A_array),
      lda,
      reinterpret_cast<cuDoubleComplex**>(tau_array),
      info,
      batchsize));
}
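// In the batched QR factorizations above, each A_array[i] is overwritten in
// LAPACK geqrf form (R in the upper triangle, Householder vectors below it),
// and tau_array[i] receives the corresponding elementary reflector scalars.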

template <>
void getrfBatched<double>(
    int n, double** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize) {
  auto handle = at::cuda::getCurrentCUDABlasHandle();
  TORCH_CUDABLAS_CHECK(cublasDgetrfBatched(
      handle, n, dA_array, ldda, ipiv_array, info_array, batchsize));
}

template <>
void getrfBatched<float>(
    int n, float** dA_array, int ldda, int* ipiv_array, int* info_array, int batchsize) {
  auto handle = at::cuda::getCurrentCUDABlasHandle();
  TORCH_CUDABLAS_CHECK(cublasSgetrfBatched(
      handle, n, dA_array, ldda, ipiv_array, info_array, batchsize));
}

template <>
void getrfBatched<c10::complex<double>>(
    int n,
    c10::complex<double>** dA_array,
    int ldda,
    int* ipiv_array,
    int* info_array,
    int batchsize) {
  auto handle = at::cuda::getCurrentCUDABlasHandle();
  TORCH_CUDABLAS_CHECK(cublasZgetrfBatched(
      handle,
      n,
      reinterpret_cast<cuDoubleComplex**>(dA_array),
      ldda,
      ipiv_array,
      info_array,
      batchsize));
}

template <>
void getrfBatched<c10::complex<float>>(
    int n,
    c10::complex<float>** dA_array,
    int ldda,
    int* ipiv_array,
    int* info_array,
    int batchsize) {
  auto handle = at::cuda::getCurrentCUDABlasHandle();
  TORCH_CUDABLAS_CHECK(cublasCgetrfBatched(
      handle,
      n,
      reinterpret_cast<cuComplex**>(dA_array),
      ldda,
      ipiv_array,
      info_array,
      batchsize));
}
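// Illustrative batched LU solve pipeline (hypothetical caller sketch; the
// pointer arrays, pivot/info buffers, and handle are assumed to be prepared
// by the caller, and the getrs argument order follows CUDABLAS_GETRS_ARGTYPES
// in CUDABlas.h):
//
//   at::cuda::blas::getrfBatched<float>(n, dA_array, ldda, ipiv, infos, batch);
//   at::cuda::blas::getrsBatched<float>(handle, CUBLAS_OP_N, n, nrhs,
//                                       dA_array, ldda, ipiv, dB_array, lddb,
//                                       infos, batch);
//
// getrf overwrites each dA_array[i] with its LU factors (P*A = L*U); getrs then
// solves op(A_i) * X_i = B_i in place in dB_array using those factors and pivots.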
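// gelsBatched solves a batch of overdetermined least-squares problems
// min ||A_i * x - C_i|| via QR factorization; cublas overwrites the leading
// n rows of each dC_array[i] with the solution.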
template <>
void gelsBatched<double>(CUDABLAS_GELS_BATCHED_ARGTYPES(double)) {
  TORCH_CUDABLAS_CHECK(cublasDgelsBatched(
      handle, trans, m, n, nrhs, dA_array, ldda, dC_array, lddc, info, devInfoArray, batchSize));
}

template <>
void gelsBatched<float>(CUDABLAS_GELS_BATCHED_ARGTYPES(float)) {
  TORCH_CUDABLAS_CHECK(cublasSgelsBatched(
      handle, trans, m, n, nrhs, dA_array, ldda, dC_array, lddc, info, devInfoArray, batchSize));
}

template <>
void gelsBatched<c10::complex<double>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<double>)) {
  TORCH_CUDABLAS_CHECK(cublasZgelsBatched(
      handle, trans,
      m, n, nrhs,
      reinterpret_cast<cuDoubleComplex**>(dA_array),
      ldda,
      reinterpret_cast<cuDoubleComplex**>(dC_array),
      lddc,
      info,
      devInfoArray,
      batchSize));
}

template <>
void gelsBatched<c10::complex<float>>(CUDABLAS_GELS_BATCHED_ARGTYPES(c10::complex<float>)) {
  TORCH_CUDABLAS_CHECK(cublasCgelsBatched(
      handle, trans,
      m, n, nrhs,
      reinterpret_cast<cuComplex**>(dA_array),
      ldda,
      reinterpret_cast<cuComplex**>(dC_array),
      lddc,
      info,
      devInfoArray,
      batchSize));
}

} // namespace at::cuda::blas