1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/cuda/CUDAContext.h>
3 #include <c10/util/Exception.h>
4 #include <ATen/cuda/Exceptions.h>
5 #include <ATen/native/sparse/cuda/SparseCUDABlas.h>
6 #include <c10/cuda/CUDACachingAllocator.h>
7
8 #include <cusparse.h>
9
10 // LIMITATION (cusparseSpMM):
11 // The generic APIs are available on all platforms on CUDA 11.0
12 // For CUDA 10.1+ it is available for all platforms except Windows.
// Using these APIs on any other system will result in compile-time or run-time failures.
14 // Their support will be extended in the next releases.
15
16 #if defined(CUDART_VERSION) && (CUSPARSE_VERSION >= 11000 || (!defined(_MSC_VER) && CUSPARSE_VERSION >= 10301))
17 #define IS_SPMM_AVAILABLE() 1
18 #else
19 #define IS_SPMM_AVAILABLE() 0
20 #endif
21
22 #if defined(USE_ROCM)
23 #define IS_SPMM_HIP_AVAILABLE() 1
24 #else
25 #define IS_SPMM_HIP_AVAILABLE() 0
26 #endif
27
28 #if IS_SPMM_AVAILABLE() || IS_SPMM_HIP_AVAILABLE()
29 #include <library_types.h>
30 #endif
31
32 #if !defined(CUSPARSE_VERSION) || (CUSPARSE_VERSION < 10100)
cusparseGetErrorString(cusparseStatus_t status)33 const char* cusparseGetErrorString(cusparseStatus_t status) {
34 switch(status)
35 {
36 case CUSPARSE_STATUS_SUCCESS:
37 return "success";
38
39 case CUSPARSE_STATUS_NOT_INITIALIZED:
40 return "library not initialized";
41
42 case CUSPARSE_STATUS_ALLOC_FAILED:
43 return "resource allocation failed";
44
45 case CUSPARSE_STATUS_INVALID_VALUE:
46 return "an invalid numeric value was used as an argument";
47
48 case CUSPARSE_STATUS_ARCH_MISMATCH:
49 return "an absent device architectural feature is required";
50
51 case CUSPARSE_STATUS_MAPPING_ERROR:
52 return "an access to GPU memory space failed";
53
54 case CUSPARSE_STATUS_EXECUTION_FAILED:
55 return "the GPU program failed to execute";
56
57 case CUSPARSE_STATUS_INTERNAL_ERROR:
58 return "an internal operation failed";
59
60 case CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
61 return "the matrix type is not supported by this function";
62
63 case CUSPARSE_STATUS_ZERO_PIVOT:
64 return "an entry of the matrix is either structural zero or numerical zero (singular block)";
65
66 default:
67 return "unknown error";
68 }
69 }
70 #endif
71
72 namespace at::native::sparse::cuda {
73
// Converts COO row indices (device pointer, nnz entries) into a CSR
// row-pointer array (device pointer, m + 1 entries) via cusparseXcoo2csr,
// using zero-based indexing.
void Xcoo2csr(const int *coorowind, int64_t nnz, int64_t m, int *csrrowptr) {
  TORCH_CHECK((m <= INT_MAX) && (nnz <= INT_MAX),
    "cusparseXcoo2csr only supports m, nnz with the bound [val] <= ",
    INT_MAX);

  // The legacy cuSPARSE API takes 32-bit sizes; the check above guarantees
  // these narrowing casts are lossless.
  TORCH_CUDASPARSE_CHECK(cusparseXcoo2csr(
      at::cuda::getCurrentCUDASparseHandle(),
      coorowind,
      static_cast<int>(nnz),
      static_cast<int>(m),
      csrrowptr,
      CUSPARSE_INDEX_BASE_ZERO));
}
85
// Maps a BLAS-style transpose flag to the corresponding cusparseOperation_t:
//   'n' -> non-transpose, 't' -> transpose, 'c' -> conjugate transpose.
// Throws a c10::Error for any other character.
cusparseOperation_t convertTransToCusparseOperation(char trans) {
  switch (trans) {
    case 'n':
      return CUSPARSE_OPERATION_NON_TRANSPOSE;
    case 't':
      return CUSPARSE_OPERATION_TRANSPOSE;
    case 'c':
      return CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE;
    default:
      // AT_ERROR is deprecated; TORCH_CHECK(false, ...) is the idiomatic
      // replacement and throws the same c10::Error type.
      TORCH_CHECK(false, "trans must be one of: t, n, c");
  }
}
94
95 #if IS_SPMM_AVAILABLE() || IS_SPMM_HIP_AVAILABLE()
96
97 namespace {
// Computes C = alpha * op(A) * op(B) + beta * C through the generic
// cusparseSpMM API, where A is an m x k sparse matrix in CSR format with
// 32-bit, zero-based indices, and B, C are dense column-major matrices.
// - transa / transb: 'n', 't' or 'c', selecting op() for A and B.
// - alpha / beta point at the scalar values (the csrmm2 specializations below
//   pass addresses of by-value arguments); matrix data pointers are expected
//   to be device pointers.
// - cusparse_value_type must match T (e.g. CUDA_R_32F for float).
// Returns early (no-op) when any data pointer is null, i.e. an empty operand.
template<typename T>
void _csrmm2(
  char transa, char transb,
  int64_t m, int64_t n, int64_t k, int64_t nnz,
  T *alpha, T *csrvala, int *csrrowptra, int *csrcolinda,
  T *b, int64_t ldb, T *beta, T *c, int64_t ldc,
  cudaDataType cusparse_value_type)
{
  if (csrvala == nullptr || b == nullptr || c == nullptr) return;

  cusparseOperation_t opa = convertTransToCusparseOperation(transa);
  cusparseOperation_t opb = convertTransToCusparseOperation(transb);

  // cusparseSpMM actually supports int64_t.
  // In order to support int64 here, index pointers csrrowptra, csrcolinda have to be passed as int64_t.
  TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (nnz <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX),
    "At the moment, cusparseSpMM only supports m, n, k, nnz, ldb, ldc with the bound [val] <= ", INT_MAX, ".",
    "If you need this, please file an issue on GitHub."
  );

  // The descriptor for A must describe its storage shape, i.e. the shape
  // before op() is applied, so swap rows/cols when A is transposed.
  int64_t ma = m, ka = k;
  if (transa != 'n') std::swap(ma, ka);

  cusparseSpMatDescr_t descA;
  TORCH_CUDASPARSE_CHECK(cusparseCreateCsr(
    &descA, /* output */
    ma, ka, nnz, /* rows, cols, number of non zero elements */
    csrrowptra, /* row offsets of the sparse matrix, size = rows +1 */
    csrcolinda, /* column indices of the sparse matrix, size = nnz */
    csrvala, /* values of the sparse matrix, size = nnz */
    CUSPARSE_INDEX_32I, /* data type of row offsets index */
    CUSPARSE_INDEX_32I, /* data type of col indices */
    CUSPARSE_INDEX_BASE_ZERO, /* base index of row offset and col indices */
    cusparse_value_type /* data type of values */
  ));

  // Same storage-shape adjustment for B.
  int64_t kb = k, nb = n;
  if (transb != 'n') std::swap(kb, nb);

  cusparseDnMatDescr_t descB;
  TORCH_CUDASPARSE_CHECK(cusparseCreateDnMat(
    &descB, /* output */
    kb, nb, ldb, /* rows, cols, leading dimension */
    b, /* values */
    cusparse_value_type, /* data type of values */
    CUSPARSE_ORDER_COL /* memory layout, ONLY column-major is supported now */
  ));

  cusparseDnMatDescr_t descC;
  TORCH_CUDASPARSE_CHECK(cusparseCreateDnMat(
    &descC, /* output */
    m, n, ldc, /* rows, cols, leading dimension */
    c, /* values */
    cusparse_value_type, /* data type of values */
    CUSPARSE_ORDER_COL /* memory layout, ONLY column-major is supported now */
  ));


  auto handle = at::cuda::getCurrentCUDASparseHandle();
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  // ALG1 is broken on SM89 as of CUDA 11.8+
#if !defined(USE_ROCM)
  auto default_alg = prop->major == 8 && prop->minor == 9 ? CUSPARSE_SPMM_CSR_ALG2 : CUSPARSE_SPMM_CSR_ALG1;
#else
  auto default_alg = CUSPARSE_SPMM_CSR_ALG1;
#endif

  // cusparseSpMM_bufferSize returns the bufferSize that can be used by cusparseSpMM
  size_t bufferSize;
  TORCH_CUDASPARSE_CHECK(cusparseSpMM_bufferSize(
    handle, opa, opb,
    alpha,
    descA, descB,
    beta,
    descC,
    cusparse_value_type, /* data type in which the computation is executed */
    default_alg, /* default computing algorithm for CSR sparse matrix format */
    &bufferSize /* output */
  ));

  // Workspace comes from the caching allocator, so it is stream-ordered and
  // freed automatically when dataPtr goes out of scope.
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();
  auto dataPtr = allocator.allocate(bufferSize);

  TORCH_CUDASPARSE_CHECK(cusparseSpMM(
    handle, opa, opb,
    alpha,
    descA, descB,
    beta,
    descC,
    cusparse_value_type, /* data type in which the computation is executed */
    default_alg, /* default computing algorithm for CSR sparse matrix format */
    dataPtr.get() /* external buffer */
  ));

  // NOTE(review): if any TORCH_CUDASPARSE_CHECK above throws after a
  // descriptor was created, that descriptor leaks; see the TODO below.
  TORCH_CUDASPARSE_CHECK(cusparseDestroySpMat(descA));
  TORCH_CUDASPARSE_CHECK(cusparseDestroyDnMat(descB));
  TORCH_CUDASPARSE_CHECK(cusparseDestroyDnMat(descC));

  // TODO: Proper fix is to create real descriptor classes
}
198 } // end anonymous namespace
199
// Primary template: never valid. Only the explicit specializations below
// (float, double, c10::complex<float>, c10::complex<double>) may be
// instantiated; anything else trips the dependent static_assert.
template<typename T>
void csrmm2(
  char transa, char transb,
  int64_t m, int64_t n, int64_t k, int64_t nnz,
  T alpha, T *csrvala, int *csrrowptra, int *csrcolinda,
  T *b, int64_t ldb, T beta, T *c, int64_t ldc)
{
  // sizeof(T) == 0 is a type-dependent false, so the assert only fires when
  // this primary template is actually instantiated.
  static_assert(sizeof(T) == 0, "cusparse csr MM only supports data type of float, double, cfloat and cdouble.");
}
209
// float specialization: forwards to the generic SpMM path, passing the
// scalars by address and tagging values as CUDA_R_32F.
template<> void csrmm2<float>(
  char transa, char transb,
  int64_t m, int64_t n, int64_t k, int64_t nnz,
  float alpha, float *csrvala, int *csrrowptra, int *csrcolinda,
  float *b, int64_t ldb, float beta, float *c, int64_t ldc)
{
  _csrmm2(transa, transb, m, n, k, nnz,
          &alpha, csrvala, csrrowptra, csrcolinda,
          b, ldb, &beta, c, ldc,
          CUDA_R_32F);
}
218
// double specialization: forwards to the generic SpMM path, passing the
// scalars by address and tagging values as CUDA_R_64F.
template<> void csrmm2<double>(
  char transa, char transb,
  int64_t m, int64_t n, int64_t k, int64_t nnz,
  double alpha, double *csrvala, int *csrrowptra, int *csrcolinda,
  double *b, int64_t ldb, double beta, double *c, int64_t ldc)
{
  _csrmm2(transa, transb, m, n, k, nnz,
          &alpha, csrvala, csrrowptra, csrcolinda,
          b, ldb, &beta, c, ldc,
          CUDA_R_64F);
}
227
// complex<float> specialization: c10::complex<float> is layout-compatible
// with cuComplex, so the pointers are reinterpreted and forwarded with
// CUDA_C_32F.
template<> void csrmm2<c10::complex<float>>(
  char transa, char transb,
  int64_t m, int64_t n, int64_t k, int64_t nnz,
  c10::complex<float> alpha, c10::complex<float> *csrvala, int *csrrowptra, int *csrcolinda,
  c10::complex<float> *b, int64_t ldb, c10::complex<float> beta, c10::complex<float> *c, int64_t ldc)
{
  auto alpha_ptr = reinterpret_cast<cuComplex*>(&alpha);
  auto beta_ptr = reinterpret_cast<cuComplex*>(&beta);
  _csrmm2(transa, transb, m, n, k, nnz,
          alpha_ptr,
          reinterpret_cast<cuComplex*>(csrvala),
          csrrowptra,
          csrcolinda,
          reinterpret_cast<cuComplex*>(b),
          ldb,
          beta_ptr,
          reinterpret_cast<cuComplex*>(c),
          ldc,
          CUDA_C_32F);
}
244
// complex<double> specialization: c10::complex<double> is layout-compatible
// with cuDoubleComplex, so the pointers are reinterpreted and forwarded with
// CUDA_C_64F.
template<> void csrmm2<c10::complex<double>>(
  char transa, char transb,
  int64_t m, int64_t n, int64_t k, int64_t nnz,
  c10::complex<double> alpha, c10::complex<double> *csrvala, int *csrrowptra, int *csrcolinda,
  c10::complex<double> *b, int64_t ldb, c10::complex<double> beta, c10::complex<double> *c, int64_t ldc)
{
  auto alpha_ptr = reinterpret_cast<cuDoubleComplex*>(&alpha);
  auto beta_ptr = reinterpret_cast<cuDoubleComplex*>(&beta);
  _csrmm2(transa, transb, m, n, k, nnz,
          alpha_ptr,
          reinterpret_cast<cuDoubleComplex*>(csrvala),
          csrrowptra,
          csrcolinda,
          reinterpret_cast<cuDoubleComplex*>(b),
          ldb,
          beta_ptr,
          reinterpret_cast<cuDoubleComplex*>(c),
          ldc,
          CUDA_C_64F);
}
261
262 #else
263
// Normalizes the leading dimensions for the legacy csrmm2 calls when a
// dense operand degenerates to a single row/column:
// - if C has one column (n == 1), its leading dimension becomes m;
// - if B is transposed ('t'/'T') and has one row (k == 1), ldb becomes n;
// - otherwise, if B has one column (n == 1), ldb becomes k.
void adjustLd(char transb, int64_t m, int64_t n, int64_t k, int64_t *ldb, int64_t *ldc)
{
  const bool b_is_transposed = (transb == 't') || (transb == 'T');

  if (n == 1) {
    *ldc = m;
  }

  if (b_is_transposed) {
    if (k == 1) {
      *ldb = n;
    }
  } else if (n == 1) {
    *ldb = k;
  }
}
282
// Single-precision CSR matmul on the legacy cusparseScsrmm2 path:
// C = alpha * op(A) * op(B) + beta * C, with A sparse CSR and B, C dense
// column-major. alpha/beta are host pointers; matrix data are device
// pointers. All sizes must fit in int (legacy API takes 32-bit arguments).
void Scsrmm2(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t nnz, const float *alpha, const float *csrvala, int *csrrowptra, int *csrcolinda, const float *b, int64_t ldb, const float *beta, float *c, int64_t ldc)
{
  adjustLd(transb, m, n, k, &ldb, &ldc);
  cusparseOperation_t opa = convertTransToCusparseOperation(transa);
  cusparseOperation_t opb = convertTransToCusparseOperation(transb);

  TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (nnz <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX),
    "cusparseScsrmm2 only supports m, n, k, nnz, ldb, ldc with the bound [val] <= ", INT_MAX);
  int i_m = (int)m;
  int i_n = (int)n;
  int i_k = (int)k;
  int i_nnz = (int)nnz;
  int i_ldb = (int)ldb;
  int i_ldc = (int)ldc;

  auto handle = at::cuda::getCurrentCUDASparseHandle();
  cusparseMatDescr_t desc;
  // Previously unchecked: a failed descriptor creation would surface as a
  // confusing error from the csrmm2 call below.
  TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&desc));
  TORCH_CUDASPARSE_CHECK(cusparseScsrmm2(handle, opa, opb, i_m, i_n, i_k, i_nnz, alpha, desc, csrvala, csrrowptra, csrcolinda, b, i_ldb, beta, c, i_ldc));
  TORCH_CUDASPARSE_CHECK(cusparseDestroyMatDescr(desc));
}
304
// Double-precision CSR matmul on the legacy cusparseDcsrmm2 path:
// C = alpha * op(A) * op(B) + beta * C, with A sparse CSR and B, C dense
// column-major. alpha/beta are host pointers; matrix data are device
// pointers. All sizes must fit in int (legacy API takes 32-bit arguments).
void Dcsrmm2(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t nnz, const double *alpha, const double *csrvala, int *csrrowptra, int *csrcolinda, const double *b, int64_t ldb, const double *beta, double *c, int64_t ldc)
{
  adjustLd(transb, m, n, k, &ldb, &ldc);
  cusparseOperation_t opa = convertTransToCusparseOperation(transa);
  cusparseOperation_t opb = convertTransToCusparseOperation(transb);

  TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (nnz <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX),
    "cusparseDcsrmm2 only supports m, n, k, nnz, ldb, ldc with the bound [val] <= ", INT_MAX);
  int i_m = (int)m;
  int i_n = (int)n;
  int i_k = (int)k;
  int i_nnz = (int)nnz;
  int i_ldb = (int)ldb;
  int i_ldc = (int)ldc;

  auto handle = at::cuda::getCurrentCUDASparseHandle();
  cusparseMatDescr_t desc;
  // Previously unchecked: a failed descriptor creation would surface as a
  // confusing error from the csrmm2 call below.
  TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&desc));
  TORCH_CUDASPARSE_CHECK(cusparseDcsrmm2(handle, opa, opb, i_m, i_n, i_k, i_nnz, alpha, desc, csrvala, csrrowptra, csrcolinda, b, i_ldb, beta, c, i_ldc));
  TORCH_CUDASPARSE_CHECK(cusparseDestroyMatDescr(desc));
  // TODO: Proper fix is to create real descriptor classes
}
328
// Single-precision complex CSR matmul on the legacy cusparseCcsrmm2 path.
// complex_target_t is cuComplex on CUDA and hipComplex on ROCm (see the
// csrmm2<c10::complex<float>> specialization below). alpha/beta are host
// pointers; matrix data are device pointers. Sizes must fit in int.
template<class complex_target_t>
void Ccsrmm2(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t nnz, const complex_target_t *alpha, const complex_target_t *csrvala, int *csrrowptra, int *csrcolinda, const complex_target_t *b, int64_t ldb, const complex_target_t *beta, complex_target_t *c, int64_t ldc)
{
  adjustLd(transb, m, n, k, &ldb, &ldc);
  cusparseOperation_t opa = convertTransToCusparseOperation(transa);
  cusparseOperation_t opb = convertTransToCusparseOperation(transb);

  TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (nnz <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX),
    "cusparseCcsrmm2 only supports m, n, k, nnz, ldb, ldc with the bound [val] <= ", INT_MAX);
  int i_m = (int)m;
  int i_n = (int)n;
  int i_k = (int)k;
  int i_nnz = (int)nnz;
  int i_ldb = (int)ldb;
  int i_ldc = (int)ldc;

  auto handle = at::cuda::getCurrentCUDASparseHandle();
  cusparseMatDescr_t desc;
  // Previously unchecked: a failed descriptor creation would surface as a
  // confusing error from the csrmm2 call below.
  TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&desc));
  TORCH_CUDASPARSE_CHECK(cusparseCcsrmm2(handle, opa, opb, i_m, i_n, i_k, i_nnz, alpha, desc, csrvala, csrrowptra, csrcolinda, b, i_ldb, beta, c, i_ldc));
  TORCH_CUDASPARSE_CHECK(cusparseDestroyMatDescr(desc));
}
351
// Double-precision complex CSR matmul on the legacy cusparseZcsrmm2 path.
// complex_target_t is cuDoubleComplex on CUDA and hipDoubleComplex on ROCm
// (see the csrmm2<c10::complex<double>> specialization below). alpha/beta are
// host pointers; matrix data are device pointers. Sizes must fit in int.
template<class complex_target_t>
void Zcsrmm2(char transa, char transb, int64_t m, int64_t n, int64_t k, int64_t nnz, const complex_target_t *alpha, const complex_target_t *csrvala, int *csrrowptra, int *csrcolinda, const complex_target_t *b, int64_t ldb, const complex_target_t *beta, complex_target_t *c, int64_t ldc)
{
  adjustLd(transb, m, n, k, &ldb, &ldc);
  cusparseOperation_t opa = convertTransToCusparseOperation(transa);
  cusparseOperation_t opb = convertTransToCusparseOperation(transb);

  TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (nnz <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX),
    "cusparseZcsrmm2 only supports m, n, k, nnz, ldb, ldc with the bound [val] <= ", INT_MAX);
  int i_m = (int)m;
  int i_n = (int)n;
  int i_k = (int)k;
  int i_nnz = (int)nnz;
  int i_ldb = (int)ldb;
  int i_ldc = (int)ldc;

  auto handle = at::cuda::getCurrentCUDASparseHandle();
  cusparseMatDescr_t desc;
  // Previously unchecked: a failed descriptor creation would surface as a
  // confusing error from the csrmm2 call below.
  TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&desc));
  TORCH_CUDASPARSE_CHECK(cusparseZcsrmm2(handle, opa, opb, i_m, i_n, i_k, i_nnz, alpha, desc, csrvala, csrrowptra, csrcolinda, b, i_ldb, beta, c, i_ldc));
  TORCH_CUDASPARSE_CHECK(cusparseDestroyMatDescr(desc));
}
375
376 // T can only be float or double
// Primary template: never valid. Only the explicit specializations below
// (float, double, c10::complex<float>, c10::complex<double>) may be
// instantiated; anything else trips the dependent static_assert.
template<typename T>
void csrmm2(
  char transa, char transb,
  int64_t m, int64_t n, int64_t k, int64_t nnz,
  T alpha, T *csrvala, int *csrrowptra, int *csrcolinda,
  T *b, int64_t ldb, T beta, T *c, int64_t ldc)
{
  // sizeof(T) == 0 is a type-dependent false, so the assert only fires when
  // this primary template is actually instantiated.
  static_assert(sizeof(T) == 0, "cusparse csr MM only supports data type of float, double, cfloat and cdouble.");
}
386
// float specialization: forwards to the legacy Scsrmm2 wrapper, passing the
// scalars by address.
template<> void csrmm2<float>(
  char transa, char transb,
  int64_t m, int64_t n, int64_t k, int64_t nnz,
  float alpha, float *csrvala, int *csrrowptra, int *csrcolinda,
  float *b, int64_t ldb, float beta, float *c, int64_t ldc)
{
  Scsrmm2(transa, transb, m, n, k, nnz,
          &alpha, csrvala, csrrowptra, csrcolinda,
          b, ldb, &beta, c, ldc);
}
395
// double specialization: forwards to the legacy Dcsrmm2 wrapper, passing the
// scalars by address.
template<> void csrmm2<double>(
  char transa, char transb,
  int64_t m, int64_t n, int64_t k, int64_t nnz,
  double alpha, double *csrvala, int *csrrowptra, int *csrcolinda,
  double *b, int64_t ldb, double beta, double *c, int64_t ldc)
{
  Dcsrmm2(transa, transb, m, n, k, nnz,
          &alpha, csrvala, csrrowptra, csrcolinda,
          b, ldb, &beta, c, ldc);
}
404
// complex<float> specialization: c10::complex<float> is layout-compatible
// with the platform complex type (hipComplex on ROCm, cuComplex on CUDA),
// so a single reinterpreted call covers both branches.
template<> void csrmm2<c10::complex<float>>(
  char transa, char transb,
  int64_t m, int64_t n, int64_t k, int64_t nnz,
  c10::complex<float> alpha, c10::complex<float> *csrvala, int *csrrowptra, int *csrcolinda,
  c10::complex<float> *b, int64_t ldb, c10::complex<float> beta, c10::complex<float> *c, int64_t ldc)
{
#ifdef USE_ROCM
  using blas_complex_t = hipComplex;
#else
  using blas_complex_t = cuComplex;
#endif
  Ccsrmm2(transa, transb, m, n, k, nnz,
          reinterpret_cast<const blas_complex_t*>(&alpha),
          reinterpret_cast<const blas_complex_t*>(csrvala),
          csrrowptra,
          csrcolinda,
          reinterpret_cast<const blas_complex_t*>(b),
          ldb,
          reinterpret_cast<const blas_complex_t*>(&beta),
          reinterpret_cast<blas_complex_t*>(c), ldc);
}
434
// complex<double> specialization: c10::complex<double> is layout-compatible
// with the platform complex type (hipDoubleComplex on ROCm, cuDoubleComplex
// on CUDA), so a single reinterpreted call covers both branches.
template<> void csrmm2<c10::complex<double>>(
  char transa, char transb,
  int64_t m, int64_t n, int64_t k, int64_t nnz,
  c10::complex<double> alpha, c10::complex<double> *csrvala, int *csrrowptra, int *csrcolinda,
  c10::complex<double> *b, int64_t ldb, c10::complex<double> beta, c10::complex<double> *c, int64_t ldc)
{
#ifdef USE_ROCM
  using blas_complex_t = hipDoubleComplex;
#else
  using blas_complex_t = cuDoubleComplex;
#endif
  Zcsrmm2(transa, transb, m, n, k, nnz,
          reinterpret_cast<const blas_complex_t*>(&alpha),
          reinterpret_cast<const blas_complex_t*>(csrvala),
          csrrowptra,
          csrcolinda,
          reinterpret_cast<const blas_complex_t*>(b),
          ldb,
          reinterpret_cast<const blas_complex_t*>(&beta),
          reinterpret_cast<blas_complex_t*>(c), ldc);
}
463
464
465 #endif
466
467 /* format conversion */
// Fills the device array P (length nnz) with the identity permutation
// [0, 1, ..., nnz-1] via cusparseCreateIdentityPermutation.
void CreateIdentityPermutation(int64_t nnz, int *P) {
  // Fixed: the message previously blamed Xcsrsort_bufferSizeExt and listed
  // parameters (m, n) this function does not take.
  TORCH_CHECK((nnz <= INT_MAX),
    "cusparseCreateIdentityPermutation only supports nnz with the bound [val] <= ",
    INT_MAX);
  int i_nnz = (int)nnz;

  auto handle = at::cuda::getCurrentCUDASparseHandle();
  // Previously unchecked: cusparseCreateIdentityPermutation returns a status
  // like every other cuSPARSE call in this file.
  TORCH_CUDASPARSE_CHECK(cusparseCreateIdentityPermutation(handle, i_nnz, P));
}
477
// Queries the workspace size (in bytes) required by Xcsrsort for an m x n
// CSR matrix with nnz entries; result is written to *pBufferSizeInBytes.
void Xcsrsort_bufferSizeExt(int64_t m, int64_t n, int64_t nnz, const int *csrRowPtr, const int *csrColInd, size_t *pBufferSizeInBytes)
{
  // Fixed: the message lacked the space before INT_MAX that every sibling
  // check in this file has, printing e.g. "<=2147483647".
  TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (nnz <= INT_MAX),
    "Xcsrsort_bufferSizeExt only supports m, n, nnz with the bound [val] <= ",
    INT_MAX);
  int i_m = (int)m;
  int i_n = (int)n;
  int i_nnz = (int)nnz;

  auto handle = at::cuda::getCurrentCUDASparseHandle();
  TORCH_CUDASPARSE_CHECK(cusparseXcsrsort_bufferSizeExt(handle, i_m, i_n, i_nnz, csrRowPtr, csrColInd, pBufferSizeInBytes));
}
490
// Sorts the column indices of each CSR row in place, recording the applied
// permutation in P. pBuffer must have the size reported by
// Xcsrsort_bufferSizeExt for the same (m, n, nnz).
void Xcsrsort(int64_t m, int64_t n, int64_t nnz, const int *csrRowPtr, int *csrColInd, int *P, void *pBuffer)
{
  TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (nnz <= INT_MAX),
    "Xcsrsort only supports m, n, nnz with the bound [val] <= ",
    INT_MAX);
  int i_m = (int)m;
  int i_n = (int)n;
  int i_nnz = (int)nnz;

  auto handle = at::cuda::getCurrentCUDASparseHandle();
  cusparseMatDescr_t desc;
  // Previously unchecked: a failed descriptor creation would surface as a
  // confusing error from the sort call below.
  TORCH_CUDASPARSE_CHECK(cusparseCreateMatDescr(&desc));
  TORCH_CUDASPARSE_CHECK(cusparseXcsrsort(handle, i_m, i_n, i_nnz, desc, csrRowPtr, csrColInd, P, pBuffer));
  TORCH_CUDASPARSE_CHECK(cusparseDestroyMatDescr(desc));
}
506
// Queries the workspace size (in bytes) required by XcoosortByRow for an
// m x n COO matrix with nnz entries; result is written to
// *pBufferSizeInBytes.
void Xcoosort_bufferSizeExt(int64_t m, int64_t n, int64_t nnz, const int *cooRows, const int *cooCols, size_t *pBufferSizeInBytes)
{
  TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (nnz <= INT_MAX),
    "Xcoosort_bufferSizeExt only supports m, n, nnz with the bound [val] <= ",
    INT_MAX);

  // The legacy cuSPARSE API takes 32-bit sizes; the check above guarantees
  // these narrowing casts are lossless.
  TORCH_CUDASPARSE_CHECK(cusparseXcoosort_bufferSizeExt(
      at::cuda::getCurrentCUDASparseHandle(),
      static_cast<int>(m),
      static_cast<int>(n),
      static_cast<int>(nnz),
      cooRows,
      cooCols,
      pBufferSizeInBytes));
}
519
// Sorts COO (row, col) index pairs by row in place, recording the applied
// permutation in P. pBuffer must have the size reported by
// Xcoosort_bufferSizeExt for the same (m, n, nnz).
void XcoosortByRow(int64_t m, int64_t n, int64_t nnz, int *cooRows, int *cooCols, int *P, void *pBuffer)
{
  TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (nnz <= INT_MAX),
    "XcoosortByRow only supports m, n, nnz with the bound [val] <= ",
    INT_MAX);

  // The legacy cuSPARSE API takes 32-bit sizes; the check above guarantees
  // these narrowing casts are lossless.
  TORCH_CUDASPARSE_CHECK(cusparseXcoosortByRow(
      at::cuda::getCurrentCUDASparseHandle(),
      static_cast<int>(m),
      static_cast<int>(n),
      static_cast<int>(nnz),
      cooRows,
      cooCols,
      P,
      pBuffer));
}
532
533
534 } // namespace at::native::sparse::cuda
535