xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Exception.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/native/sparse/cuda/SparseCUDABlas.h>
#include <c10/cuda/CUDACachingAllocator.h>

#include <cusparse.h>

#include <memory>
#include <type_traits>
9 
10 // LIMITATION (cusparseSpMM):
11 // The generic APIs are available on all platforms on CUDA 11.0
12 // For CUDA 10.1+ it is available for all platforms except Windows.
13 // Using these APIs in any other systems will result in compile-time or run-time failures.
14 // Their support will be extended in the next releases.
15 
16 #if defined(CUDART_VERSION) && (CUSPARSE_VERSION >= 11000 || (!defined(_MSC_VER) && CUSPARSE_VERSION >= 10301))
17 #define IS_SPMM_AVAILABLE() 1
18 #else
19 #define IS_SPMM_AVAILABLE() 0
20 #endif
21 
22 #if defined(USE_ROCM)
23 #define IS_SPMM_HIP_AVAILABLE() 1
24 #else
25 #define IS_SPMM_HIP_AVAILABLE() 0
26 #endif
27 
28 #if IS_SPMM_AVAILABLE() || IS_SPMM_HIP_AVAILABLE()
29 #include <library_types.h>
30 #endif
31 
32 #if !defined(CUSPARSE_VERSION) || (CUSPARSE_VERSION < 10100)
cusparseGetErrorString(cusparseStatus_t status)33 const char* cusparseGetErrorString(cusparseStatus_t status) {
34   switch(status)
35   {
36     case CUSPARSE_STATUS_SUCCESS:
37       return "success";
38 
39     case CUSPARSE_STATUS_NOT_INITIALIZED:
40       return "library not initialized";
41 
42     case CUSPARSE_STATUS_ALLOC_FAILED:
43       return "resource allocation failed";
44 
45     case CUSPARSE_STATUS_INVALID_VALUE:
46       return "an invalid numeric value was used as an argument";
47 
48     case CUSPARSE_STATUS_ARCH_MISMATCH:
49       return "an absent device architectural feature is required";
50 
51     case CUSPARSE_STATUS_MAPPING_ERROR:
52       return "an access to GPU memory space failed";
53 
54     case CUSPARSE_STATUS_EXECUTION_FAILED:
55       return "the GPU program failed to execute";
56 
57     case CUSPARSE_STATUS_INTERNAL_ERROR:
58       return "an internal operation failed";
59 
60     case CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
61       return "the matrix type is not supported by this function";
62 
63     case CUSPARSE_STATUS_ZERO_PIVOT:
64       return "an entry of the matrix is either structural zero or numerical zero (singular block)";
65 
66     default:
67       return "unknown error";
68   }
69 }
70 #endif
71 
72 namespace at::native::sparse::cuda {
73 
Xcoo2csr(const int * coorowind,int64_t nnz,int64_t m,int * csrrowptr)74 void Xcoo2csr(const int *coorowind, int64_t nnz, int64_t m, int *csrrowptr) {
75   TORCH_CHECK((m <= INT_MAX) && (nnz <= INT_MAX),
76     "cusparseXcoo2csr only supports m, nnz with the bound [val] <= ",
77     INT_MAX);
78 
79   int i_nnz = (int)nnz;
80   int i_m = (int)m;
81 
82   auto handle = at::cuda::getCurrentCUDASparseHandle();
83   TORCH_CUDASPARSE_CHECK(cusparseXcoo2csr(handle, coorowind, i_nnz, i_m, csrrowptr, CUSPARSE_INDEX_BASE_ZERO));
84 }
85 
convertTransToCusparseOperation(char trans)86 cusparseOperation_t convertTransToCusparseOperation(char trans) {
87   if (trans == 't') return CUSPARSE_OPERATION_TRANSPOSE;
88   else if (trans == 'n') return CUSPARSE_OPERATION_NON_TRANSPOSE;
89   else if (trans == 'c') return CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE;
90   else {
91     AT_ERROR("trans must be one of: t, n, c");
92   }
93 }
94 
95 #if IS_SPMM_AVAILABLE() || IS_SPMM_HIP_AVAILABLE()
96 
97 namespace {
98 template<typename T>
_csrmm2(char transa,char transb,int64_t m,int64_t n,int64_t k,int64_t nnz,T * alpha,T * csrvala,int * csrrowptra,int * csrcolinda,T * b,int64_t ldb,T * beta,T * c,int64_t ldc,cudaDataType cusparse_value_type)99 void _csrmm2(
100   char transa, char transb,
101   int64_t m, int64_t n, int64_t k, int64_t nnz,
102   T *alpha, T *csrvala, int *csrrowptra, int *csrcolinda,
103   T *b, int64_t ldb, T *beta, T *c, int64_t ldc,
104   cudaDataType cusparse_value_type)
105 {
106   if (csrvala == nullptr || b == nullptr || c == nullptr) return;
107 
108   cusparseOperation_t opa = convertTransToCusparseOperation(transa);
109   cusparseOperation_t opb = convertTransToCusparseOperation(transb);
110 
111   // cusparseSpMM actually supports int64_t.
112   // In order to support int64 here, index pointers csrrowptra, csrcolinda have to be passed as int64_t.
113   TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (nnz <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX),
114     "At the moment, cusparseSpMM only supports m, n, k, nnz, ldb, ldc with the bound [val] <= ", INT_MAX, ".",
115     "If you need this, please file an issue on GitHub."
116   );
117 
118   int64_t ma = m, ka = k;
119   if (transa != 'n') std::swap(ma, ka);
120 
121   cusparseSpMatDescr_t descA;
122   TORCH_CUDASPARSE_CHECK(cusparseCreateCsr(
123     &descA,                     /* output */
124     ma, ka, nnz,                /* rows, cols, number of non zero elements */
125     csrrowptra,                 /* row offsets of the sparse matrix, size = rows +1 */
126     csrcolinda,                 /* column indices of the sparse matrix, size = nnz */
127     csrvala,                    /* values of the sparse matrix, size = nnz */
128     CUSPARSE_INDEX_32I,         /* data type of row offsets index */
129     CUSPARSE_INDEX_32I,         /* data type of col indices */
130     CUSPARSE_INDEX_BASE_ZERO,   /* base index of row offset and col indes */
131     cusparse_value_type         /* data type of values */
132   ));
133 
134   int64_t kb = k, nb = n;
135   if (transb != 'n') std::swap(kb, nb);
136 
137   cusparseDnMatDescr_t descB;
138   TORCH_CUDASPARSE_CHECK(cusparseCreateDnMat(
139     &descB,               /* output */
140     kb, nb, ldb,          /* rows, cols, leading dimension */
141     b,                    /* values */
142     cusparse_value_type,  /* data type of values */
143     CUSPARSE_ORDER_COL    /* memory layout, ONLY column-major is supported now */
144   ));
145 
146   cusparseDnMatDescr_t descC;
147   TORCH_CUDASPARSE_CHECK(cusparseCreateDnMat(
148     &descC,               /* output */
149     m, n, ldc,            /* rows, cols, leading dimension */
150     c,                    /* values */
151     cusparse_value_type,  /* data type of values */
152     CUSPARSE_ORDER_COL    /* memory layout, ONLY column-major is supported now */
153   ));
154 
155 
156   auto handle = at::cuda::getCurrentCUDASparseHandle();
157   cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
158   // ALG1 is broken on SM89 as of CUDA 11.8+
159 #if !defined(USE_ROCM)
160   auto default_alg = prop->major == 8 && prop->minor == 9 ? CUSPARSE_SPMM_CSR_ALG2 : CUSPARSE_SPMM_CSR_ALG1;
161 #else
162   auto default_alg = CUSPARSE_SPMM_CSR_ALG1;
163 #endif
164 
165   // cusparseSpMM_bufferSize returns the bufferSize that can be used by cusparseSpMM
166   size_t bufferSize;
167   TORCH_CUDASPARSE_CHECK(cusparseSpMM_bufferSize(
168     handle, opa, opb,
169     alpha,
170     descA, descB,
171     beta,
172     descC,
173     cusparse_value_type,      /* data type in which the computation is executed */
174     default_alg,              /* default computing algorithm for CSR sparse matrix format */
175     &bufferSize               /* output */
176   ));
177 
178   auto& allocator = *c10::cuda::CUDACachingAllocator::get();
179   auto dataPtr = allocator.allocate(bufferSize);
180 
181   TORCH_CUDASPARSE_CHECK(cusparseSpMM(
182     handle, opa, opb,
183     alpha,
184     descA, descB,
185     beta,
186     descC,
187     cusparse_value_type,      /* data type in which the computation is executed */
188     default_alg,              /* default computing algorithm for CSR sparse matrix format */
189     dataPtr.get()             /* external buffer */
190   ));
191 
192   TORCH_CUDASPARSE_CHECK(cusparseDestroySpMat(descA));
193   TORCH_CUDASPARSE_CHECK(cusparseDestroyDnMat(descB));
194   TORCH_CUDASPARSE_CHECK(cusparseDestroyDnMat(descC));
195 
196   // TODO: Proper fix is to create real descriptor classes
197 }
198 } // end anonymous namespace
199 
// Primary template for unsupported element types; only the explicit
// specializations below (float, double, c10::complex<float>,
// c10::complex<double>) are valid.
template<typename T>
void csrmm2(
  char transa, char transb,
  int64_t m, int64_t n, int64_t k, int64_t nnz,
  T alpha, T *csrvala, int *csrrowptra, int *csrcolinda,
  T *b, int64_t ldb, T beta, T *c, int64_t ldc)
{
  // `false && sizeof(T)` keeps the condition dependent on T, so the
  // assertion only fires if this primary template is actually instantiated.
  static_assert(false&&sizeof(T), "cusparse csr MM only supports data type of float, double, cfloat and cdouble.");
}
209 
csrmm2(char transa,char transb,int64_t m,int64_t n,int64_t k,int64_t nnz,float alpha,float * csrvala,int * csrrowptra,int * csrcolinda,float * b,int64_t ldb,float beta,float * c,int64_t ldc)210 template<> void csrmm2<float>(
211   char transa, char transb,
212   int64_t m, int64_t n, int64_t k, int64_t nnz,
213   float alpha, float *csrvala, int *csrrowptra, int *csrcolinda,
214   float *b, int64_t ldb, float beta, float *c, int64_t ldc)
215 {
216   _csrmm2(transa, transb, m, n, k, nnz, &alpha, csrvala, csrrowptra, csrcolinda, b, ldb, &beta, c, ldc, CUDA_R_32F);
217 }
218 
csrmm2(char transa,char transb,int64_t m,int64_t n,int64_t k,int64_t nnz,double alpha,double * csrvala,int * csrrowptra,int * csrcolinda,double * b,int64_t ldb,double beta,double * c,int64_t ldc)219 template<> void csrmm2<double>(
220   char transa, char transb,
221   int64_t m, int64_t n, int64_t k, int64_t nnz,
222   double alpha, double *csrvala, int *csrrowptra, int *csrcolinda,
223   double *b, int64_t ldb, double beta, double *c, int64_t ldc)
224 {
225   _csrmm2(transa, transb, m, n, k, nnz, &alpha, csrvala, csrrowptra, csrcolinda, b, ldb, &beta, c, ldc, CUDA_R_64F);
226 }
227 
csrmm2(char transa,char transb,int64_t m,int64_t n,int64_t k,int64_t nnz,c10::complex<float> alpha,c10::complex<float> * csrvala,int * csrrowptra,int * csrcolinda,c10::complex<float> * b,int64_t ldb,c10::complex<float> beta,c10::complex<float> * c,int64_t ldc)228 template<> void csrmm2<c10::complex<float>>(
229   char transa, char transb,
230   int64_t m, int64_t n, int64_t k, int64_t nnz,
231   c10::complex<float> alpha, c10::complex<float> *csrvala, int *csrrowptra, int *csrcolinda,
232   c10::complex<float> *b, int64_t ldb, c10::complex<float> beta, c10::complex<float> *c, int64_t ldc)
233 {
234   _csrmm2(transa, transb, m, n, k, nnz,
235     reinterpret_cast<cuComplex*>(&alpha),
236     reinterpret_cast<cuComplex*>(csrvala),
237     csrrowptra,
238     csrcolinda,
239     reinterpret_cast<cuComplex*>(b),
240     ldb,
241     reinterpret_cast<cuComplex*>(&beta),
242     reinterpret_cast<cuComplex*>(c), ldc, CUDA_C_32F);
243 }
244 
csrmm2(char transa,char transb,int64_t m,int64_t n,int64_t k,int64_t nnz,c10::complex<double> alpha,c10::complex<double> * csrvala,int * csrrowptra,int * csrcolinda,c10::complex<double> * b,int64_t ldb,c10::complex<double> beta,c10::complex<double> * c,int64_t ldc)245 template<> void csrmm2<c10::complex<double>>(
246   char transa, char transb,
247   int64_t m, int64_t n, int64_t k, int64_t nnz,
248   c10::complex<double> alpha, c10::complex<double> *csrvala, int *csrrowptra, int *csrcolinda,
249   c10::complex<double> *b, int64_t ldb, c10::complex<double> beta, c10::complex<double> *c, int64_t ldc)
250 {
251   _csrmm2(transa, transb, m, n, k, nnz,
252     reinterpret_cast<cuDoubleComplex*>(&alpha),
253     reinterpret_cast<cuDoubleComplex*>(csrvala),
254     csrrowptra,
255     csrcolinda,
256     reinterpret_cast<cuDoubleComplex*>(b),
257     ldb,
258     reinterpret_cast<cuDoubleComplex*>(&beta),
259     reinterpret_cast<cuDoubleComplex*>(c), ldc, CUDA_C_64F);
260 }
261 
262 #else
263 
// Patches up the leading dimensions of B (ldb) and C (ldc) for degenerate
// size-1 dimensions before calling the legacy csrmm2 routines:
//   - if C has a single column (n == 1), ldc becomes m;
//   - if op(B) has a single column (k == 1 when B is transposed, n == 1
//     otherwise), ldb becomes B's row count after/before the transpose.
void adjustLd(char transb, int64_t m, int64_t n, int64_t k, int64_t *ldb, int64_t *ldc)
{
  const bool b_transposed = (transb == 't') || (transb == 'T');

  if (n == 1) {
    *ldc = m;
  }
  if (b_transposed ? (k == 1) : (n == 1)) {
    *ldb = b_transposed ? n : k;
  }
}
282 
Scsrmm2(char transa,char transb,int64_t m,int64_t n,int64_t k,int64_t nnz,const float * alpha,const float * csrvala,int * csrrowptra,int * csrcolinda,const float * b,int64_t ldb,const float * beta,float * c,int64_t ldc)283 void Scsrmm2(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t nnz, const float *alpha, const float *csrvala, int *csrrowptra, int *csrcolinda, const float *b, int64_t ldb, const float *beta, float *c, int64_t ldc)
284 {
285   adjustLd(transb, m, n, k, &ldb, &ldc);
286   cusparseOperation_t opa = convertTransToCusparseOperation(transa);
287   cusparseOperation_t opb = convertTransToCusparseOperation(transb);
288 
289   TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (nnz <= INT_MAX)  && (ldb <= INT_MAX) && (ldc <= INT_MAX),
290     "cusparseScsrmm2 only supports m, n, k, nnz, ldb, ldc with the bound [val] <= ", INT_MAX);
291   int i_m = (int)m;
292   int i_n = (int)n;
293   int i_k = (int)k;
294   int i_nnz = (int)nnz;
295   int i_ldb = (int)ldb;
296   int i_ldc = (int)ldc;
297 
298   auto handle = at::cuda::getCurrentCUDASparseHandle();
299   cusparseMatDescr_t desc;
300   cusparseCreateMatDescr(&desc);
301   TORCH_CUDASPARSE_CHECK(cusparseScsrmm2(handle, opa, opb, i_m, i_n, i_k, i_nnz, alpha, desc, csrvala, csrrowptra, csrcolinda, b, i_ldb, beta, c, i_ldc));
302   TORCH_CUDASPARSE_CHECK(cusparseDestroyMatDescr(desc));
303 }
304 
Dcsrmm2(char transa,char transb,int64_t m,int64_t n,int64_t k,int64_t nnz,const double * alpha,const double * csrvala,int * csrrowptra,int * csrcolinda,const double * b,int64_t ldb,const double * beta,double * c,int64_t ldc)305 void Dcsrmm2(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t nnz, const double *alpha, const double *csrvala, int *csrrowptra, int *csrcolinda, const double *b, int64_t ldb, const double *beta, double *c, int64_t ldc)
306 {
307   adjustLd(transb, m, n, k, &ldb, &ldc);
308   cusparseOperation_t opa = convertTransToCusparseOperation(transa);
309   cusparseOperation_t opb = convertTransToCusparseOperation(transb);
310 
311   TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (nnz <= INT_MAX)  && (ldb <= INT_MAX) && (ldc <= INT_MAX),
312     "cusparseDcsrmm2 only supports m, n, k, nnz, ldb, ldc with the bound [val] <= ", INT_MAX);
313   int i_m = (int)m;
314   int i_n = (int)n;
315   int i_k = (int)k;
316   int i_nnz = (int)nnz;
317   int i_ldb = (int)ldb;
318   int i_ldc = (int)ldc;
319 
320 
321   auto handle = at::cuda::getCurrentCUDASparseHandle();
322   cusparseMatDescr_t desc;
323   cusparseCreateMatDescr(&desc);
324   TORCH_CUDASPARSE_CHECK(cusparseDcsrmm2(handle, opa, opb, i_m, i_n, i_k, i_nnz, alpha, desc, csrvala, csrrowptra, csrcolinda, b, i_ldb, beta, c, i_ldc));
325   TORCH_CUDASPARSE_CHECK(cusparseDestroyMatDescr(desc));
326   // TODO: Proper fix is to create real descriptor classes
327 }
328 
329 template<class complex_target_t>
Ccsrmm2(char transa,char transb,int64_t m,int64_t n,int64_t k,int64_t nnz,const complex_target_t * alpha,const complex_target_t * csrvala,int * csrrowptra,int * csrcolinda,const complex_target_t * b,int64_t ldb,const complex_target_t * beta,complex_target_t * c,int64_t ldc)330 void Ccsrmm2(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t nnz, const complex_target_t *alpha, const complex_target_t *csrvala, int *csrrowptra, int *csrcolinda, const complex_target_t *b, int64_t ldb, const complex_target_t *beta, complex_target_t *c, int64_t ldc)
331 {
332   adjustLd(transb, m, n, k, &ldb, &ldc);
333   cusparseOperation_t opa = convertTransToCusparseOperation(transa);
334   cusparseOperation_t opb = convertTransToCusparseOperation(transb);
335 
336   TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (nnz <= INT_MAX)  && (ldb <= INT_MAX) && (ldc <= INT_MAX),
337     "cusparseCcsrmm2 only supports m, n, k, nnz, ldb, ldc with the bound [val] <= ", INT_MAX);
338   int i_m = (int)m;
339   int i_n = (int)n;
340   int i_k = (int)k;
341   int i_nnz = (int)nnz;
342   int i_ldb = (int)ldb;
343   int i_ldc = (int)ldc;
344 
345   auto handle = at::cuda::getCurrentCUDASparseHandle();
346   cusparseMatDescr_t desc;
347   cusparseCreateMatDescr(&desc);
348   TORCH_CUDASPARSE_CHECK(cusparseCcsrmm2(handle, opa, opb, i_m, i_n, i_k, i_nnz, alpha, desc, csrvala, csrrowptra, csrcolinda, b, i_ldb, beta, c, i_ldc));
349   TORCH_CUDASPARSE_CHECK(cusparseDestroyMatDescr(desc));
350 }
351 
352 template<class complex_target_t>
Zcsrmm2(char transa,char transb,int64_t m,int64_t n,int64_t k,int64_t nnz,const complex_target_t * alpha,const complex_target_t * csrvala,int * csrrowptra,int * csrcolinda,const complex_target_t * b,int64_t ldb,const complex_target_t * beta,complex_target_t * c,int64_t ldc)353 void Zcsrmm2(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t nnz, const complex_target_t *alpha, const complex_target_t *csrvala, int *csrrowptra, int *csrcolinda, const complex_target_t *b, int64_t ldb, const complex_target_t *beta, complex_target_t *c, int64_t ldc)
354 {
355   adjustLd(transb, m, n, k, &ldb, &ldc);
356   cusparseOperation_t opa = convertTransToCusparseOperation(transa);
357   cusparseOperation_t opb = convertTransToCusparseOperation(transb);
358 
359   TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (nnz <= INT_MAX)  && (ldb <= INT_MAX) && (ldc <= INT_MAX),
360     "cusparseZcsrmm2 only supports m, n, k, nnz, ldb, ldc with the bound [val] <= ", INT_MAX);
361   int i_m = (int)m;
362   int i_n = (int)n;
363   int i_k = (int)k;
364   int i_nnz = (int)nnz;
365   int i_ldb = (int)ldb;
366   int i_ldc = (int)ldc;
367 
368 
369   auto handle = at::cuda::getCurrentCUDASparseHandle();
370   cusparseMatDescr_t desc;
371   cusparseCreateMatDescr(&desc);
372   TORCH_CUDASPARSE_CHECK(cusparseZcsrmm2(handle, opa, opb, i_m, i_n, i_k, i_nnz, alpha, desc, csrvala, csrrowptra, csrcolinda, b, i_ldb, beta, c, i_ldc));
373   TORCH_CUDASPARSE_CHECK(cusparseDestroyMatDescr(desc));
374 }
375 
// Primary template for unsupported element types; only the explicit
// specializations below (float, double, c10::complex<float>,
// c10::complex<double>) are valid.
template<typename T>
void csrmm2(
  char transa, char transb,
  int64_t m, int64_t n, int64_t k, int64_t nnz,
  T alpha, T *csrvala, int *csrrowptra, int *csrcolinda,
  T *b, int64_t ldb, T beta, T *c, int64_t ldc)
{
  // `false && sizeof(T)` keeps the condition dependent on T, so the
  // assertion only fires if this primary template is actually instantiated.
  static_assert(false&&sizeof(T), "cusparse csr MM only supports data type of float, double, cfloat and cdouble.");
}
386 
csrmm2(char transa,char transb,int64_t m,int64_t n,int64_t k,int64_t nnz,float alpha,float * csrvala,int * csrrowptra,int * csrcolinda,float * b,int64_t ldb,float beta,float * c,int64_t ldc)387 template<> void csrmm2<float>(
388   char transa, char transb,
389   int64_t m, int64_t n, int64_t k, int64_t nnz,
390   float alpha, float *csrvala, int *csrrowptra, int *csrcolinda,
391   float *b, int64_t ldb, float beta, float *c, int64_t ldc)
392 {
393   Scsrmm2(transa, transb, m, n, k, nnz, &alpha, csrvala, csrrowptra, csrcolinda, b, ldb, &beta, c, ldc);
394 }
395 
csrmm2(char transa,char transb,int64_t m,int64_t n,int64_t k,int64_t nnz,double alpha,double * csrvala,int * csrrowptra,int * csrcolinda,double * b,int64_t ldb,double beta,double * c,int64_t ldc)396 template<> void csrmm2<double>(
397   char transa, char transb,
398   int64_t m, int64_t n, int64_t k, int64_t nnz,
399   double alpha, double *csrvala, int *csrrowptra, int *csrcolinda,
400   double *b, int64_t ldb, double beta, double *c, int64_t ldc)
401 {
402   Dcsrmm2(transa, transb, m, n, k, nnz, &alpha, csrvala, csrrowptra, csrcolinda, b, ldb, &beta, c, ldc);
403 }
404 
csrmm2(char transa,char transb,int64_t m,int64_t n,int64_t k,int64_t nnz,c10::complex<float> alpha,c10::complex<float> * csrvala,int * csrrowptra,int * csrcolinda,c10::complex<float> * b,int64_t ldb,c10::complex<float> beta,c10::complex<float> * c,int64_t ldc)405 template<> void csrmm2<c10::complex<float>>(
406   char transa, char transb,
407   int64_t m, int64_t n, int64_t k, int64_t nnz,
408   c10::complex<float> alpha, c10::complex<float> *csrvala, int *csrrowptra, int *csrcolinda,
409   c10::complex<float> *b, int64_t ldb, c10::complex<float> beta, c10::complex<float> *c, int64_t ldc)
410 {
411 
412   #ifdef USE_ROCM
413   Ccsrmm2(transa, transb, m, n, k, nnz,
414     reinterpret_cast<const hipComplex*>(&alpha),
415     reinterpret_cast<const hipComplex*>(csrvala),
416     csrrowptra,
417     csrcolinda,
418     reinterpret_cast<const hipComplex*>(b),
419     ldb,
420     reinterpret_cast<const hipComplex*>(&beta),
421     reinterpret_cast<hipComplex*>(c), ldc);
422   #else
423   Ccsrmm2(transa, transb, m, n, k, nnz,
424       reinterpret_cast<const cuComplex*>(&alpha),
425       reinterpret_cast<const cuComplex*>(csrvala),
426       csrrowptra,
427       csrcolinda,
428       reinterpret_cast<const cuComplex*>(b),
429       ldb,
430       reinterpret_cast<const cuComplex*>(&beta),
431       reinterpret_cast<cuComplex*>(c), ldc);
432   #endif
433 }
434 
csrmm2(char transa,char transb,int64_t m,int64_t n,int64_t k,int64_t nnz,c10::complex<double> alpha,c10::complex<double> * csrvala,int * csrrowptra,int * csrcolinda,c10::complex<double> * b,int64_t ldb,c10::complex<double> beta,c10::complex<double> * c,int64_t ldc)435 template<> void csrmm2<c10::complex<double>>(
436   char transa, char transb,
437   int64_t m, int64_t n, int64_t k, int64_t nnz,
438   c10::complex<double> alpha, c10::complex<double> *csrvala, int *csrrowptra, int *csrcolinda,
439   c10::complex<double> *b, int64_t ldb, c10::complex<double> beta, c10::complex<double> *c, int64_t ldc)
440 {
441   #ifdef USE_ROCM
442   Zcsrmm2(transa, transb, m, n, k, nnz,
443     reinterpret_cast<const hipDoubleComplex*>(&alpha),
444     reinterpret_cast<const hipDoubleComplex*>(csrvala),
445     csrrowptra,
446     csrcolinda,
447     reinterpret_cast<const hipDoubleComplex*>(b),
448     ldb,
449     reinterpret_cast<const hipDoubleComplex*>(&beta),
450     reinterpret_cast<hipDoubleComplex*>(c), ldc);
451   #else
452   Zcsrmm2(transa, transb, m, n, k, nnz,
453     reinterpret_cast<const cuDoubleComplex*>(&alpha),
454     reinterpret_cast<const cuDoubleComplex*>(csrvala),
455     csrrowptra,
456     csrcolinda,
457     reinterpret_cast<const cuDoubleComplex*>(b),
458     ldb,
459     reinterpret_cast<const cuDoubleComplex*>(&beta),
460     reinterpret_cast<cuDoubleComplex*>(c), ldc);
461   #endif
462 }
463 
464 
465 #endif
466 
467 /* format conversion */
CreateIdentityPermutation(int64_t nnz,int * P)468 void CreateIdentityPermutation(int64_t nnz, int *P) {
469   TORCH_CHECK((nnz <= INT_MAX),
470     "Xcsrsort_bufferSizeExt only supports m, n, nnz with the bound [val] <= ",
471     INT_MAX);
472   int i_nnz = (int)nnz;
473 
474   auto handle = at::cuda::getCurrentCUDASparseHandle();
475   cusparseCreateIdentityPermutation(handle, i_nnz, P);
476 }
477 
Xcsrsort_bufferSizeExt(int64_t m,int64_t n,int64_t nnz,const int * csrRowPtr,const int * csrColInd,size_t * pBufferSizeInBytes)478 void Xcsrsort_bufferSizeExt(int64_t m, int64_t n, int64_t nnz, const int *csrRowPtr, const int *csrColInd, size_t *pBufferSizeInBytes)
479 {
480   TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (nnz <= INT_MAX),
481     "Xcsrsort_bufferSizeExt only supports m, n, nnz with the bound [val] <=",
482     INT_MAX);
483   int i_m = (int)m;
484   int i_n = (int)n;
485   int i_nnz = (int)nnz;
486 
487   auto handle = at::cuda::getCurrentCUDASparseHandle();
488   TORCH_CUDASPARSE_CHECK(cusparseXcsrsort_bufferSizeExt(handle, i_m, i_n, i_nnz, csrRowPtr, csrColInd, pBufferSizeInBytes));
489 }
490 
Xcsrsort(int64_t m,int64_t n,int64_t nnz,const int * csrRowPtr,int * csrColInd,int * P,void * pBuffer)491 void Xcsrsort(int64_t m, int64_t n, int64_t nnz, const int *csrRowPtr, int *csrColInd, int *P, void *pBuffer)
492 {
493   TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (nnz <= INT_MAX),
494     "Xcsrsort only supports m, n, nnz with the bound [val] <= ",
495     INT_MAX);
496   int i_m = (int)m;
497   int i_n = (int)n;
498   int i_nnz = (int)nnz;
499 
500   auto handle = at::cuda::getCurrentCUDASparseHandle();
501   cusparseMatDescr_t desc;
502   cusparseCreateMatDescr(&desc);
503   TORCH_CUDASPARSE_CHECK(cusparseXcsrsort(handle, i_m, i_n, i_nnz, desc, csrRowPtr, csrColInd, P, pBuffer));
504   TORCH_CUDASPARSE_CHECK(cusparseDestroyMatDescr(desc));
505 }
506 
Xcoosort_bufferSizeExt(int64_t m,int64_t n,int64_t nnz,const int * cooRows,const int * cooCols,size_t * pBufferSizeInBytes)507 void Xcoosort_bufferSizeExt(int64_t m, int64_t n, int64_t nnz, const int *cooRows, const int *cooCols, size_t *pBufferSizeInBytes)
508 {
509   TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (nnz <= INT_MAX),
510     "Xcoosort_bufferSizeExt only supports m, n, nnz with the bound [val] <= ",
511     INT_MAX);
512   int i_m = (int)m;
513   int i_n = (int)n;
514   int i_nnz = (int)nnz;
515 
516   auto handle = at::cuda::getCurrentCUDASparseHandle();
517   TORCH_CUDASPARSE_CHECK(cusparseXcoosort_bufferSizeExt(handle, i_m, i_n, i_nnz, cooRows, cooCols, pBufferSizeInBytes));
518 }
519 
XcoosortByRow(int64_t m,int64_t n,int64_t nnz,int * cooRows,int * cooCols,int * P,void * pBuffer)520 void XcoosortByRow(int64_t m, int64_t n, int64_t nnz, int *cooRows, int *cooCols, int *P, void *pBuffer)
521 {
522   TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (nnz <= INT_MAX),
523     "XcoosortByRow only supports m, n, nnz with the bound [val] <= ",
524     INT_MAX);
525   int i_m = (int)m;
526   int i_n = (int)n;
527   int i_nnz = (int)nnz;
528 
529   auto handle = at::cuda::getCurrentCUDASparseHandle();
530   TORCH_CUDASPARSE_CHECK(cusparseXcoosortByRow(handle, i_m, i_n, i_nnz, cooRows, cooCols, P, pBuffer));
531 }
532 
533 
534 } // namespace at::native::sparse::cuda
535