/*
Provides the implementations of cuSPARSE function templates.
*/
4
5 #include <ATen/cuda/CUDASparseBlas.h>
6
7 namespace at::cuda::sparse {
8
9 template <>
csrgeam2_bufferSizeExt(CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES (float))10 void csrgeam2_bufferSizeExt<float>(
11 CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(float)) {
12 TORCH_CUDASPARSE_CHECK(cusparseScsrgeam2_bufferSizeExt(
13 handle,
14 m,
15 n,
16 alpha,
17 descrA,
18 nnzA,
19 csrSortedValA,
20 csrSortedRowPtrA,
21 csrSortedColIndA,
22 beta,
23 descrB,
24 nnzB,
25 csrSortedValB,
26 csrSortedRowPtrB,
27 csrSortedColIndB,
28 descrC,
29 csrSortedValC,
30 csrSortedRowPtrC,
31 csrSortedColIndC,
32 pBufferSizeInBytes));
33 }
34
35 template <>
csrgeam2_bufferSizeExt(CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES (double))36 void csrgeam2_bufferSizeExt<double>(
37 CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(double)) {
38 TORCH_CUDASPARSE_CHECK(cusparseDcsrgeam2_bufferSizeExt(
39 handle,
40 m,
41 n,
42 alpha,
43 descrA,
44 nnzA,
45 csrSortedValA,
46 csrSortedRowPtrA,
47 csrSortedColIndA,
48 beta,
49 descrB,
50 nnzB,
51 csrSortedValB,
52 csrSortedRowPtrB,
53 csrSortedColIndB,
54 descrC,
55 csrSortedValC,
56 csrSortedRowPtrC,
57 csrSortedColIndC,
58 pBufferSizeInBytes));
59 }
60
61 template <>
csrgeam2_bufferSizeExt(CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES (c10::complex<float>))62 void csrgeam2_bufferSizeExt<c10::complex<float>>(
63 CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex<float>)) {
64 TORCH_CUDASPARSE_CHECK(cusparseCcsrgeam2_bufferSizeExt(
65 handle,
66 m,
67 n,
68 reinterpret_cast<const cuComplex*>(alpha),
69 descrA,
70 nnzA,
71 reinterpret_cast<const cuComplex*>(csrSortedValA),
72 csrSortedRowPtrA,
73 csrSortedColIndA,
74 reinterpret_cast<const cuComplex*>(beta),
75 descrB,
76 nnzB,
77 reinterpret_cast<const cuComplex*>(csrSortedValB),
78 csrSortedRowPtrB,
79 csrSortedColIndB,
80 descrC,
81 reinterpret_cast<const cuComplex*>(csrSortedValC),
82 csrSortedRowPtrC,
83 csrSortedColIndC,
84 pBufferSizeInBytes));
85 }
86
87 template <>
csrgeam2_bufferSizeExt(CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES (c10::complex<double>))88 void csrgeam2_bufferSizeExt<c10::complex<double>>(
89 CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex<double>)) {
90 TORCH_CUDASPARSE_CHECK(cusparseZcsrgeam2_bufferSizeExt(
91 handle,
92 m,
93 n,
94 reinterpret_cast<const cuDoubleComplex*>(alpha),
95 descrA,
96 nnzA,
97 reinterpret_cast<const cuDoubleComplex*>(csrSortedValA),
98 csrSortedRowPtrA,
99 csrSortedColIndA,
100 reinterpret_cast<const cuDoubleComplex*>(beta),
101 descrB,
102 nnzB,
103 reinterpret_cast<const cuDoubleComplex*>(csrSortedValB),
104 csrSortedRowPtrB,
105 csrSortedColIndB,
106 descrC,
107 reinterpret_cast<const cuDoubleComplex*>(csrSortedValC),
108 csrSortedRowPtrC,
109 csrSortedColIndC,
110 pBufferSizeInBytes));
111 }
112
113 template <>
csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES (float))114 void csrgeam2<float>(CUSPARSE_CSRGEAM2_ARGTYPES(float)) {
115 TORCH_CUDASPARSE_CHECK(cusparseScsrgeam2(
116 handle,
117 m,
118 n,
119 alpha,
120 descrA,
121 nnzA,
122 csrSortedValA,
123 csrSortedRowPtrA,
124 csrSortedColIndA,
125 beta,
126 descrB,
127 nnzB,
128 csrSortedValB,
129 csrSortedRowPtrB,
130 csrSortedColIndB,
131 descrC,
132 csrSortedValC,
133 csrSortedRowPtrC,
134 csrSortedColIndC,
135 pBuffer));
136 }
137
138 template <>
csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES (double))139 void csrgeam2<double>(CUSPARSE_CSRGEAM2_ARGTYPES(double)) {
140 TORCH_CUDASPARSE_CHECK(cusparseDcsrgeam2(
141 handle,
142 m,
143 n,
144 alpha,
145 descrA,
146 nnzA,
147 csrSortedValA,
148 csrSortedRowPtrA,
149 csrSortedColIndA,
150 beta,
151 descrB,
152 nnzB,
153 csrSortedValB,
154 csrSortedRowPtrB,
155 csrSortedColIndB,
156 descrC,
157 csrSortedValC,
158 csrSortedRowPtrC,
159 csrSortedColIndC,
160 pBuffer));
161 }
162
163 template <>
csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES (c10::complex<float>))164 void csrgeam2<c10::complex<float>>(
165 CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex<float>)) {
166 TORCH_CUDASPARSE_CHECK(cusparseCcsrgeam2(
167 handle,
168 m,
169 n,
170 reinterpret_cast<const cuComplex*>(alpha),
171 descrA,
172 nnzA,
173 reinterpret_cast<const cuComplex*>(csrSortedValA),
174 csrSortedRowPtrA,
175 csrSortedColIndA,
176 reinterpret_cast<const cuComplex*>(beta),
177 descrB,
178 nnzB,
179 reinterpret_cast<const cuComplex*>(csrSortedValB),
180 csrSortedRowPtrB,
181 csrSortedColIndB,
182 descrC,
183 reinterpret_cast<cuComplex*>(csrSortedValC),
184 csrSortedRowPtrC,
185 csrSortedColIndC,
186 pBuffer));
187 }
188
189 template <>
csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES (c10::complex<double>))190 void csrgeam2<c10::complex<double>>(
191 CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex<double>)) {
192 TORCH_CUDASPARSE_CHECK(cusparseZcsrgeam2(
193 handle,
194 m,
195 n,
196 reinterpret_cast<const cuDoubleComplex*>(alpha),
197 descrA,
198 nnzA,
199 reinterpret_cast<const cuDoubleComplex*>(csrSortedValA),
200 csrSortedRowPtrA,
201 csrSortedColIndA,
202 reinterpret_cast<const cuDoubleComplex*>(beta),
203 descrB,
204 nnzB,
205 reinterpret_cast<const cuDoubleComplex*>(csrSortedValB),
206 csrSortedRowPtrB,
207 csrSortedColIndB,
208 descrC,
209 reinterpret_cast<cuDoubleComplex*>(csrSortedValC),
210 csrSortedRowPtrC,
211 csrSortedColIndC,
212 pBuffer));
213 }
214
215 template <>
bsrmm(CUSPARSE_BSRMM_ARGTYPES (float))216 void bsrmm<float>(CUSPARSE_BSRMM_ARGTYPES(float)) {
217 TORCH_CUDASPARSE_CHECK(cusparseSbsrmm(
218 handle,
219 dirA,
220 transA,
221 transB,
222 mb,
223 n,
224 kb,
225 nnzb,
226 alpha,
227 descrA,
228 bsrValA,
229 bsrRowPtrA,
230 bsrColIndA,
231 blockDim,
232 B,
233 ldb,
234 beta,
235 C,
236 ldc));
237 }
238
239 template <>
bsrmm(CUSPARSE_BSRMM_ARGTYPES (double))240 void bsrmm<double>(CUSPARSE_BSRMM_ARGTYPES(double)) {
241 TORCH_CUDASPARSE_CHECK(cusparseDbsrmm(
242 handle,
243 dirA,
244 transA,
245 transB,
246 mb,
247 n,
248 kb,
249 nnzb,
250 alpha,
251 descrA,
252 bsrValA,
253 bsrRowPtrA,
254 bsrColIndA,
255 blockDim,
256 B,
257 ldb,
258 beta,
259 C,
260 ldc));
261 }
262
263 template <>
bsrmm(CUSPARSE_BSRMM_ARGTYPES (c10::complex<float>))264 void bsrmm<c10::complex<float>>(CUSPARSE_BSRMM_ARGTYPES(c10::complex<float>)) {
265 TORCH_CUDASPARSE_CHECK(cusparseCbsrmm(
266 handle,
267 dirA,
268 transA,
269 transB,
270 mb,
271 n,
272 kb,
273 nnzb,
274 reinterpret_cast<const cuComplex*>(alpha),
275 descrA,
276 reinterpret_cast<const cuComplex*>(bsrValA),
277 bsrRowPtrA,
278 bsrColIndA,
279 blockDim,
280 reinterpret_cast<const cuComplex*>(B),
281 ldb,
282 reinterpret_cast<const cuComplex*>(beta),
283 reinterpret_cast<cuComplex*>(C),
284 ldc));
285 }
286
287 template <>
bsrmm(CUSPARSE_BSRMM_ARGTYPES (c10::complex<double>))288 void bsrmm<c10::complex<double>>(
289 CUSPARSE_BSRMM_ARGTYPES(c10::complex<double>)) {
290 TORCH_CUDASPARSE_CHECK(cusparseZbsrmm(
291 handle,
292 dirA,
293 transA,
294 transB,
295 mb,
296 n,
297 kb,
298 nnzb,
299 reinterpret_cast<const cuDoubleComplex*>(alpha),
300 descrA,
301 reinterpret_cast<const cuDoubleComplex*>(bsrValA),
302 bsrRowPtrA,
303 bsrColIndA,
304 blockDim,
305 reinterpret_cast<const cuDoubleComplex*>(B),
306 ldb,
307 reinterpret_cast<const cuDoubleComplex*>(beta),
308 reinterpret_cast<cuDoubleComplex*>(C),
309 ldc));
310 }
311
312 template <>
bsrmv(CUSPARSE_BSRMV_ARGTYPES (float))313 void bsrmv<float>(CUSPARSE_BSRMV_ARGTYPES(float)) {
314 TORCH_CUDASPARSE_CHECK(cusparseSbsrmv(
315 handle,
316 dirA,
317 transA,
318 mb,
319 nb,
320 nnzb,
321 alpha,
322 descrA,
323 bsrValA,
324 bsrRowPtrA,
325 bsrColIndA,
326 blockDim,
327 x,
328 beta,
329 y));
330 }
331
332 template <>
bsrmv(CUSPARSE_BSRMV_ARGTYPES (double))333 void bsrmv<double>(CUSPARSE_BSRMV_ARGTYPES(double)) {
334 TORCH_CUDASPARSE_CHECK(cusparseDbsrmv(
335 handle,
336 dirA,
337 transA,
338 mb,
339 nb,
340 nnzb,
341 alpha,
342 descrA,
343 bsrValA,
344 bsrRowPtrA,
345 bsrColIndA,
346 blockDim,
347 x,
348 beta,
349 y));
350 }
351
352 template <>
bsrmv(CUSPARSE_BSRMV_ARGTYPES (c10::complex<float>))353 void bsrmv<c10::complex<float>>(CUSPARSE_BSRMV_ARGTYPES(c10::complex<float>)) {
354 TORCH_CUDASPARSE_CHECK(cusparseCbsrmv(
355 handle,
356 dirA,
357 transA,
358 mb,
359 nb,
360 nnzb,
361 reinterpret_cast<const cuComplex*>(alpha),
362 descrA,
363 reinterpret_cast<const cuComplex*>(bsrValA),
364 bsrRowPtrA,
365 bsrColIndA,
366 blockDim,
367 reinterpret_cast<const cuComplex*>(x),
368 reinterpret_cast<const cuComplex*>(beta),
369 reinterpret_cast<cuComplex*>(y)));
370 }
371
372 template <>
bsrmv(CUSPARSE_BSRMV_ARGTYPES (c10::complex<double>))373 void bsrmv<c10::complex<double>>(
374 CUSPARSE_BSRMV_ARGTYPES(c10::complex<double>)) {
375 TORCH_CUDASPARSE_CHECK(cusparseZbsrmv(
376 handle,
377 dirA,
378 transA,
379 mb,
380 nb,
381 nnzb,
382 reinterpret_cast<const cuDoubleComplex*>(alpha),
383 descrA,
384 reinterpret_cast<const cuDoubleComplex*>(bsrValA),
385 bsrRowPtrA,
386 bsrColIndA,
387 blockDim,
388 reinterpret_cast<const cuDoubleComplex*>(x),
389 reinterpret_cast<const cuDoubleComplex*>(beta),
390 reinterpret_cast<cuDoubleComplex*>(y)));
391 }
392
393 #if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
394
395 template <>
bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES (float))396 void bsrsv2_bufferSize<float>(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(float)) {
397 TORCH_CUDASPARSE_CHECK(cusparseSbsrsv2_bufferSize(
398 handle,
399 dirA,
400 transA,
401 mb,
402 nnzb,
403 descrA,
404 bsrValA,
405 bsrRowPtrA,
406 bsrColIndA,
407 blockDim,
408 info,
409 pBufferSizeInBytes));
410 }
411
412 template <>
bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES (double))413 void bsrsv2_bufferSize<double>(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(double)) {
414 TORCH_CUDASPARSE_CHECK(cusparseDbsrsv2_bufferSize(
415 handle,
416 dirA,
417 transA,
418 mb,
419 nnzb,
420 descrA,
421 bsrValA,
422 bsrRowPtrA,
423 bsrColIndA,
424 blockDim,
425 info,
426 pBufferSizeInBytes));
427 }
428
429 template <>
bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES (c10::complex<float>))430 void bsrsv2_bufferSize<c10::complex<float>>(
431 CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex<float>)) {
432 TORCH_CUDASPARSE_CHECK(cusparseCbsrsv2_bufferSize(
433 handle,
434 dirA,
435 transA,
436 mb,
437 nnzb,
438 descrA,
439 reinterpret_cast<cuComplex*>(bsrValA),
440 bsrRowPtrA,
441 bsrColIndA,
442 blockDim,
443 info,
444 pBufferSizeInBytes));
445 }
446
447 template <>
bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES (c10::complex<double>))448 void bsrsv2_bufferSize<c10::complex<double>>(
449 CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex<double>)) {
450 TORCH_CUDASPARSE_CHECK(cusparseZbsrsv2_bufferSize(
451 handle,
452 dirA,
453 transA,
454 mb,
455 nnzb,
456 descrA,
457 reinterpret_cast<cuDoubleComplex*>(bsrValA),
458 bsrRowPtrA,
459 bsrColIndA,
460 blockDim,
461 info,
462 pBufferSizeInBytes));
463 }
464
465 template <>
bsrsv2_analysis(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES (float))466 void bsrsv2_analysis<float>(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(float)) {
467 TORCH_CUDASPARSE_CHECK(cusparseSbsrsv2_analysis(
468 handle,
469 dirA,
470 transA,
471 mb,
472 nnzb,
473 descrA,
474 bsrValA,
475 bsrRowPtrA,
476 bsrColIndA,
477 blockDim,
478 info,
479 policy,
480 pBuffer));
481 }
482
483 template <>
bsrsv2_analysis(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES (double))484 void bsrsv2_analysis<double>(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(double)) {
485 TORCH_CUDASPARSE_CHECK(cusparseDbsrsv2_analysis(
486 handle,
487 dirA,
488 transA,
489 mb,
490 nnzb,
491 descrA,
492 bsrValA,
493 bsrRowPtrA,
494 bsrColIndA,
495 blockDim,
496 info,
497 policy,
498 pBuffer));
499 }
500
501 template <>
bsrsv2_analysis(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES (c10::complex<float>))502 void bsrsv2_analysis<c10::complex<float>>(
503 CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex<float>)) {
504 TORCH_CUDASPARSE_CHECK(cusparseCbsrsv2_analysis(
505 handle,
506 dirA,
507 transA,
508 mb,
509 nnzb,
510 descrA,
511 reinterpret_cast<const cuComplex*>(bsrValA),
512 bsrRowPtrA,
513 bsrColIndA,
514 blockDim,
515 info,
516 policy,
517 pBuffer));
518 }
519
520 template <>
bsrsv2_analysis(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES (c10::complex<double>))521 void bsrsv2_analysis<c10::complex<double>>(
522 CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex<double>)) {
523 TORCH_CUDASPARSE_CHECK(cusparseZbsrsv2_analysis(
524 handle,
525 dirA,
526 transA,
527 mb,
528 nnzb,
529 descrA,
530 reinterpret_cast<const cuDoubleComplex*>(bsrValA),
531 bsrRowPtrA,
532 bsrColIndA,
533 blockDim,
534 info,
535 policy,
536 pBuffer));
537 }
538
539 template <>
bsrsv2_solve(CUSPARSE_BSRSV2_SOLVE_ARGTYPES (float))540 void bsrsv2_solve<float>(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(float)) {
541 TORCH_CUDASPARSE_CHECK(cusparseSbsrsv2_solve(
542 handle,
543 dirA,
544 transA,
545 mb,
546 nnzb,
547 alpha,
548 descrA,
549 bsrValA,
550 bsrRowPtrA,
551 bsrColIndA,
552 blockDim,
553 info,
554 x,
555 y,
556 policy,
557 pBuffer));
558 }
559
560 template <>
bsrsv2_solve(CUSPARSE_BSRSV2_SOLVE_ARGTYPES (double))561 void bsrsv2_solve<double>(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(double)) {
562 TORCH_CUDASPARSE_CHECK(cusparseDbsrsv2_solve(
563 handle,
564 dirA,
565 transA,
566 mb,
567 nnzb,
568 alpha,
569 descrA,
570 bsrValA,
571 bsrRowPtrA,
572 bsrColIndA,
573 blockDim,
574 info,
575 x,
576 y,
577 policy,
578 pBuffer));
579 }
580
581 template <>
bsrsv2_solve(CUSPARSE_BSRSV2_SOLVE_ARGTYPES (c10::complex<float>))582 void bsrsv2_solve<c10::complex<float>>(
583 CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex<float>)) {
584 TORCH_CUDASPARSE_CHECK(cusparseCbsrsv2_solve(
585 handle,
586 dirA,
587 transA,
588 mb,
589 nnzb,
590 reinterpret_cast<const cuComplex*>(alpha),
591 descrA,
592 reinterpret_cast<const cuComplex*>(bsrValA),
593 bsrRowPtrA,
594 bsrColIndA,
595 blockDim,
596 info,
597 reinterpret_cast<const cuComplex*>(x),
598 reinterpret_cast<cuComplex*>(y),
599 policy,
600 pBuffer));
601 }
602
603 template <>
bsrsv2_solve(CUSPARSE_BSRSV2_SOLVE_ARGTYPES (c10::complex<double>))604 void bsrsv2_solve<c10::complex<double>>(
605 CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex<double>)) {
606 TORCH_CUDASPARSE_CHECK(cusparseZbsrsv2_solve(
607 handle,
608 dirA,
609 transA,
610 mb,
611 nnzb,
612 reinterpret_cast<const cuDoubleComplex*>(alpha),
613 descrA,
614 reinterpret_cast<const cuDoubleComplex*>(bsrValA),
615 bsrRowPtrA,
616 bsrColIndA,
617 blockDim,
618 info,
619 reinterpret_cast<const cuDoubleComplex*>(x),
620 reinterpret_cast<cuDoubleComplex*>(y),
621 policy,
622 pBuffer));
623 }
624
625 template <>
bsrsm2_bufferSize(CUSPARSE_BSRSM2_BUFFER_ARGTYPES (float))626 void bsrsm2_bufferSize<float>(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(float)) {
627 TORCH_CUDASPARSE_CHECK(cusparseSbsrsm2_bufferSize(
628 handle,
629 dirA,
630 transA,
631 transX,
632 mb,
633 n,
634 nnzb,
635 descrA,
636 bsrValA,
637 bsrRowPtrA,
638 bsrColIndA,
639 blockDim,
640 info,
641 pBufferSizeInBytes));
642 }
643
644 template <>
bsrsm2_bufferSize(CUSPARSE_BSRSM2_BUFFER_ARGTYPES (double))645 void bsrsm2_bufferSize<double>(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(double)) {
646 TORCH_CUDASPARSE_CHECK(cusparseDbsrsm2_bufferSize(
647 handle,
648 dirA,
649 transA,
650 transX,
651 mb,
652 n,
653 nnzb,
654 descrA,
655 bsrValA,
656 bsrRowPtrA,
657 bsrColIndA,
658 blockDim,
659 info,
660 pBufferSizeInBytes));
661 }
662
663 template <>
bsrsm2_bufferSize(CUSPARSE_BSRSM2_BUFFER_ARGTYPES (c10::complex<float>))664 void bsrsm2_bufferSize<c10::complex<float>>(
665 CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex<float>)) {
666 TORCH_CUDASPARSE_CHECK(cusparseCbsrsm2_bufferSize(
667 handle,
668 dirA,
669 transA,
670 transX,
671 mb,
672 n,
673 nnzb,
674 descrA,
675 reinterpret_cast<cuComplex*>(bsrValA),
676 bsrRowPtrA,
677 bsrColIndA,
678 blockDim,
679 info,
680 pBufferSizeInBytes));
681 }
682
683 template <>
bsrsm2_bufferSize(CUSPARSE_BSRSM2_BUFFER_ARGTYPES (c10::complex<double>))684 void bsrsm2_bufferSize<c10::complex<double>>(
685 CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex<double>)) {
686 TORCH_CUDASPARSE_CHECK(cusparseZbsrsm2_bufferSize(
687 handle,
688 dirA,
689 transA,
690 transX,
691 mb,
692 n,
693 nnzb,
694 descrA,
695 reinterpret_cast<cuDoubleComplex*>(bsrValA),
696 bsrRowPtrA,
697 bsrColIndA,
698 blockDim,
699 info,
700 pBufferSizeInBytes));
701 }
702
703 template <>
bsrsm2_analysis(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES (float))704 void bsrsm2_analysis<float>(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(float)) {
705 TORCH_CUDASPARSE_CHECK(cusparseSbsrsm2_analysis(
706 handle,
707 dirA,
708 transA,
709 transX,
710 mb,
711 n,
712 nnzb,
713 descrA,
714 bsrValA,
715 bsrRowPtrA,
716 bsrColIndA,
717 blockDim,
718 info,
719 policy,
720 pBuffer));
721 }
722
723 template <>
bsrsm2_analysis(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES (double))724 void bsrsm2_analysis<double>(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(double)) {
725 TORCH_CUDASPARSE_CHECK(cusparseDbsrsm2_analysis(
726 handle,
727 dirA,
728 transA,
729 transX,
730 mb,
731 n,
732 nnzb,
733 descrA,
734 bsrValA,
735 bsrRowPtrA,
736 bsrColIndA,
737 blockDim,
738 info,
739 policy,
740 pBuffer));
741 }
742
743 template <>
bsrsm2_analysis(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES (c10::complex<float>))744 void bsrsm2_analysis<c10::complex<float>>(
745 CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex<float>)) {
746 TORCH_CUDASPARSE_CHECK(cusparseCbsrsm2_analysis(
747 handle,
748 dirA,
749 transA,
750 transX,
751 mb,
752 n,
753 nnzb,
754 descrA,
755 reinterpret_cast<const cuComplex*>(bsrValA),
756 bsrRowPtrA,
757 bsrColIndA,
758 blockDim,
759 info,
760 policy,
761 pBuffer));
762 }
763
764 template <>
bsrsm2_analysis(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES (c10::complex<double>))765 void bsrsm2_analysis<c10::complex<double>>(
766 CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex<double>)) {
767 TORCH_CUDASPARSE_CHECK(cusparseZbsrsm2_analysis(
768 handle,
769 dirA,
770 transA,
771 transX,
772 mb,
773 n,
774 nnzb,
775 descrA,
776 reinterpret_cast<const cuDoubleComplex*>(bsrValA),
777 bsrRowPtrA,
778 bsrColIndA,
779 blockDim,
780 info,
781 policy,
782 pBuffer));
783 }
784
785 template <>
bsrsm2_solve(CUSPARSE_BSRSM2_SOLVE_ARGTYPES (float))786 void bsrsm2_solve<float>(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(float)) {
787 TORCH_CUDASPARSE_CHECK(cusparseSbsrsm2_solve(
788 handle,
789 dirA,
790 transA,
791 transX,
792 mb,
793 n,
794 nnzb,
795 alpha,
796 descrA,
797 bsrValA,
798 bsrRowPtrA,
799 bsrColIndA,
800 blockDim,
801 info,
802 B,
803 ldb,
804 X,
805 ldx,
806 policy,
807 pBuffer));
808 }
809
810 template <>
bsrsm2_solve(CUSPARSE_BSRSM2_SOLVE_ARGTYPES (double))811 void bsrsm2_solve<double>(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(double)) {
812 TORCH_CUDASPARSE_CHECK(cusparseDbsrsm2_solve(
813 handle,
814 dirA,
815 transA,
816 transX,
817 mb,
818 n,
819 nnzb,
820 alpha,
821 descrA,
822 bsrValA,
823 bsrRowPtrA,
824 bsrColIndA,
825 blockDim,
826 info,
827 B,
828 ldb,
829 X,
830 ldx,
831 policy,
832 pBuffer));
833 }
834
835 template <>
bsrsm2_solve(CUSPARSE_BSRSM2_SOLVE_ARGTYPES (c10::complex<float>))836 void bsrsm2_solve<c10::complex<float>>(
837 CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex<float>)) {
838 TORCH_CUDASPARSE_CHECK(cusparseCbsrsm2_solve(
839 handle,
840 dirA,
841 transA,
842 transX,
843 mb,
844 n,
845 nnzb,
846 reinterpret_cast<const cuComplex*>(alpha),
847 descrA,
848 reinterpret_cast<const cuComplex*>(bsrValA),
849 bsrRowPtrA,
850 bsrColIndA,
851 blockDim,
852 info,
853 reinterpret_cast<const cuComplex*>(B),
854 ldb,
855 reinterpret_cast<cuComplex*>(X),
856 ldx,
857 policy,
858 pBuffer));
859 }
860
861 template <>
bsrsm2_solve(CUSPARSE_BSRSM2_SOLVE_ARGTYPES (c10::complex<double>))862 void bsrsm2_solve<c10::complex<double>>(
863 CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex<double>)) {
864 TORCH_CUDASPARSE_CHECK(cusparseZbsrsm2_solve(
865 handle,
866 dirA,
867 transA,
868 transX,
869 mb,
870 n,
871 nnzb,
872 reinterpret_cast<const cuDoubleComplex*>(alpha),
873 descrA,
874 reinterpret_cast<const cuDoubleComplex*>(bsrValA),
875 bsrRowPtrA,
876 bsrColIndA,
877 blockDim,
878 info,
879 reinterpret_cast<const cuDoubleComplex*>(B),
880 ldb,
881 reinterpret_cast<cuDoubleComplex*>(X),
882 ldx,
883 policy,
884 pBuffer));
885 }
886
887 #endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE
888
889 } // namespace at::cuda::sparse
890