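// Indexed by AOSP xref as:
//   /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/CUDASparseBlas.cpp
//   (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
/*
  Provides the implementations of the cuSPARSE wrapper function templates:
  explicit specializations for float, double, c10::complex<float> and
  c10::complex<double> of the templates declared in ATen/cuda/CUDASparseBlas.h.
*/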
4 
5 #include <ATen/cuda/CUDASparseBlas.h>
6 
7 namespace at::cuda::sparse {
8 
9 template <>
csrgeam2_bufferSizeExt(CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES (float))10 void csrgeam2_bufferSizeExt<float>(
11     CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(float)) {
12   TORCH_CUDASPARSE_CHECK(cusparseScsrgeam2_bufferSizeExt(
13       handle,
14       m,
15       n,
16       alpha,
17       descrA,
18       nnzA,
19       csrSortedValA,
20       csrSortedRowPtrA,
21       csrSortedColIndA,
22       beta,
23       descrB,
24       nnzB,
25       csrSortedValB,
26       csrSortedRowPtrB,
27       csrSortedColIndB,
28       descrC,
29       csrSortedValC,
30       csrSortedRowPtrC,
31       csrSortedColIndC,
32       pBufferSizeInBytes));
33 }
34 
35 template <>
csrgeam2_bufferSizeExt(CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES (double))36 void csrgeam2_bufferSizeExt<double>(
37     CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(double)) {
38   TORCH_CUDASPARSE_CHECK(cusparseDcsrgeam2_bufferSizeExt(
39       handle,
40       m,
41       n,
42       alpha,
43       descrA,
44       nnzA,
45       csrSortedValA,
46       csrSortedRowPtrA,
47       csrSortedColIndA,
48       beta,
49       descrB,
50       nnzB,
51       csrSortedValB,
52       csrSortedRowPtrB,
53       csrSortedColIndB,
54       descrC,
55       csrSortedValC,
56       csrSortedRowPtrC,
57       csrSortedColIndC,
58       pBufferSizeInBytes));
59 }
60 
61 template <>
csrgeam2_bufferSizeExt(CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES (c10::complex<float>))62 void csrgeam2_bufferSizeExt<c10::complex<float>>(
63     CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex<float>)) {
64   TORCH_CUDASPARSE_CHECK(cusparseCcsrgeam2_bufferSizeExt(
65       handle,
66       m,
67       n,
68       reinterpret_cast<const cuComplex*>(alpha),
69       descrA,
70       nnzA,
71       reinterpret_cast<const cuComplex*>(csrSortedValA),
72       csrSortedRowPtrA,
73       csrSortedColIndA,
74       reinterpret_cast<const cuComplex*>(beta),
75       descrB,
76       nnzB,
77       reinterpret_cast<const cuComplex*>(csrSortedValB),
78       csrSortedRowPtrB,
79       csrSortedColIndB,
80       descrC,
81       reinterpret_cast<const cuComplex*>(csrSortedValC),
82       csrSortedRowPtrC,
83       csrSortedColIndC,
84       pBufferSizeInBytes));
85 }
86 
87 template <>
csrgeam2_bufferSizeExt(CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES (c10::complex<double>))88 void csrgeam2_bufferSizeExt<c10::complex<double>>(
89     CUSPARSE_CSRGEAM2_BUFFERSIZE_ARGTYPES(c10::complex<double>)) {
90   TORCH_CUDASPARSE_CHECK(cusparseZcsrgeam2_bufferSizeExt(
91       handle,
92       m,
93       n,
94       reinterpret_cast<const cuDoubleComplex*>(alpha),
95       descrA,
96       nnzA,
97       reinterpret_cast<const cuDoubleComplex*>(csrSortedValA),
98       csrSortedRowPtrA,
99       csrSortedColIndA,
100       reinterpret_cast<const cuDoubleComplex*>(beta),
101       descrB,
102       nnzB,
103       reinterpret_cast<const cuDoubleComplex*>(csrSortedValB),
104       csrSortedRowPtrB,
105       csrSortedColIndB,
106       descrC,
107       reinterpret_cast<const cuDoubleComplex*>(csrSortedValC),
108       csrSortedRowPtrC,
109       csrSortedColIndC,
110       pBufferSizeInBytes));
111 }
112 
113 template <>
csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES (float))114 void csrgeam2<float>(CUSPARSE_CSRGEAM2_ARGTYPES(float)) {
115   TORCH_CUDASPARSE_CHECK(cusparseScsrgeam2(
116       handle,
117       m,
118       n,
119       alpha,
120       descrA,
121       nnzA,
122       csrSortedValA,
123       csrSortedRowPtrA,
124       csrSortedColIndA,
125       beta,
126       descrB,
127       nnzB,
128       csrSortedValB,
129       csrSortedRowPtrB,
130       csrSortedColIndB,
131       descrC,
132       csrSortedValC,
133       csrSortedRowPtrC,
134       csrSortedColIndC,
135       pBuffer));
136 }
137 
138 template <>
csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES (double))139 void csrgeam2<double>(CUSPARSE_CSRGEAM2_ARGTYPES(double)) {
140   TORCH_CUDASPARSE_CHECK(cusparseDcsrgeam2(
141       handle,
142       m,
143       n,
144       alpha,
145       descrA,
146       nnzA,
147       csrSortedValA,
148       csrSortedRowPtrA,
149       csrSortedColIndA,
150       beta,
151       descrB,
152       nnzB,
153       csrSortedValB,
154       csrSortedRowPtrB,
155       csrSortedColIndB,
156       descrC,
157       csrSortedValC,
158       csrSortedRowPtrC,
159       csrSortedColIndC,
160       pBuffer));
161 }
162 
163 template <>
csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES (c10::complex<float>))164 void csrgeam2<c10::complex<float>>(
165     CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex<float>)) {
166   TORCH_CUDASPARSE_CHECK(cusparseCcsrgeam2(
167       handle,
168       m,
169       n,
170       reinterpret_cast<const cuComplex*>(alpha),
171       descrA,
172       nnzA,
173       reinterpret_cast<const cuComplex*>(csrSortedValA),
174       csrSortedRowPtrA,
175       csrSortedColIndA,
176       reinterpret_cast<const cuComplex*>(beta),
177       descrB,
178       nnzB,
179       reinterpret_cast<const cuComplex*>(csrSortedValB),
180       csrSortedRowPtrB,
181       csrSortedColIndB,
182       descrC,
183       reinterpret_cast<cuComplex*>(csrSortedValC),
184       csrSortedRowPtrC,
185       csrSortedColIndC,
186       pBuffer));
187 }
188 
189 template <>
csrgeam2(CUSPARSE_CSRGEAM2_ARGTYPES (c10::complex<double>))190 void csrgeam2<c10::complex<double>>(
191     CUSPARSE_CSRGEAM2_ARGTYPES(c10::complex<double>)) {
192   TORCH_CUDASPARSE_CHECK(cusparseZcsrgeam2(
193       handle,
194       m,
195       n,
196       reinterpret_cast<const cuDoubleComplex*>(alpha),
197       descrA,
198       nnzA,
199       reinterpret_cast<const cuDoubleComplex*>(csrSortedValA),
200       csrSortedRowPtrA,
201       csrSortedColIndA,
202       reinterpret_cast<const cuDoubleComplex*>(beta),
203       descrB,
204       nnzB,
205       reinterpret_cast<const cuDoubleComplex*>(csrSortedValB),
206       csrSortedRowPtrB,
207       csrSortedColIndB,
208       descrC,
209       reinterpret_cast<cuDoubleComplex*>(csrSortedValC),
210       csrSortedRowPtrC,
211       csrSortedColIndC,
212       pBuffer));
213 }
214 
215 template <>
bsrmm(CUSPARSE_BSRMM_ARGTYPES (float))216 void bsrmm<float>(CUSPARSE_BSRMM_ARGTYPES(float)) {
217   TORCH_CUDASPARSE_CHECK(cusparseSbsrmm(
218       handle,
219       dirA,
220       transA,
221       transB,
222       mb,
223       n,
224       kb,
225       nnzb,
226       alpha,
227       descrA,
228       bsrValA,
229       bsrRowPtrA,
230       bsrColIndA,
231       blockDim,
232       B,
233       ldb,
234       beta,
235       C,
236       ldc));
237 }
238 
239 template <>
bsrmm(CUSPARSE_BSRMM_ARGTYPES (double))240 void bsrmm<double>(CUSPARSE_BSRMM_ARGTYPES(double)) {
241   TORCH_CUDASPARSE_CHECK(cusparseDbsrmm(
242       handle,
243       dirA,
244       transA,
245       transB,
246       mb,
247       n,
248       kb,
249       nnzb,
250       alpha,
251       descrA,
252       bsrValA,
253       bsrRowPtrA,
254       bsrColIndA,
255       blockDim,
256       B,
257       ldb,
258       beta,
259       C,
260       ldc));
261 }
262 
263 template <>
bsrmm(CUSPARSE_BSRMM_ARGTYPES (c10::complex<float>))264 void bsrmm<c10::complex<float>>(CUSPARSE_BSRMM_ARGTYPES(c10::complex<float>)) {
265   TORCH_CUDASPARSE_CHECK(cusparseCbsrmm(
266       handle,
267       dirA,
268       transA,
269       transB,
270       mb,
271       n,
272       kb,
273       nnzb,
274       reinterpret_cast<const cuComplex*>(alpha),
275       descrA,
276       reinterpret_cast<const cuComplex*>(bsrValA),
277       bsrRowPtrA,
278       bsrColIndA,
279       blockDim,
280       reinterpret_cast<const cuComplex*>(B),
281       ldb,
282       reinterpret_cast<const cuComplex*>(beta),
283       reinterpret_cast<cuComplex*>(C),
284       ldc));
285 }
286 
287 template <>
bsrmm(CUSPARSE_BSRMM_ARGTYPES (c10::complex<double>))288 void bsrmm<c10::complex<double>>(
289     CUSPARSE_BSRMM_ARGTYPES(c10::complex<double>)) {
290   TORCH_CUDASPARSE_CHECK(cusparseZbsrmm(
291       handle,
292       dirA,
293       transA,
294       transB,
295       mb,
296       n,
297       kb,
298       nnzb,
299       reinterpret_cast<const cuDoubleComplex*>(alpha),
300       descrA,
301       reinterpret_cast<const cuDoubleComplex*>(bsrValA),
302       bsrRowPtrA,
303       bsrColIndA,
304       blockDim,
305       reinterpret_cast<const cuDoubleComplex*>(B),
306       ldb,
307       reinterpret_cast<const cuDoubleComplex*>(beta),
308       reinterpret_cast<cuDoubleComplex*>(C),
309       ldc));
310 }
311 
312 template <>
bsrmv(CUSPARSE_BSRMV_ARGTYPES (float))313 void bsrmv<float>(CUSPARSE_BSRMV_ARGTYPES(float)) {
314   TORCH_CUDASPARSE_CHECK(cusparseSbsrmv(
315       handle,
316       dirA,
317       transA,
318       mb,
319       nb,
320       nnzb,
321       alpha,
322       descrA,
323       bsrValA,
324       bsrRowPtrA,
325       bsrColIndA,
326       blockDim,
327       x,
328       beta,
329       y));
330 }
331 
332 template <>
bsrmv(CUSPARSE_BSRMV_ARGTYPES (double))333 void bsrmv<double>(CUSPARSE_BSRMV_ARGTYPES(double)) {
334   TORCH_CUDASPARSE_CHECK(cusparseDbsrmv(
335       handle,
336       dirA,
337       transA,
338       mb,
339       nb,
340       nnzb,
341       alpha,
342       descrA,
343       bsrValA,
344       bsrRowPtrA,
345       bsrColIndA,
346       blockDim,
347       x,
348       beta,
349       y));
350 }
351 
352 template <>
bsrmv(CUSPARSE_BSRMV_ARGTYPES (c10::complex<float>))353 void bsrmv<c10::complex<float>>(CUSPARSE_BSRMV_ARGTYPES(c10::complex<float>)) {
354   TORCH_CUDASPARSE_CHECK(cusparseCbsrmv(
355       handle,
356       dirA,
357       transA,
358       mb,
359       nb,
360       nnzb,
361       reinterpret_cast<const cuComplex*>(alpha),
362       descrA,
363       reinterpret_cast<const cuComplex*>(bsrValA),
364       bsrRowPtrA,
365       bsrColIndA,
366       blockDim,
367       reinterpret_cast<const cuComplex*>(x),
368       reinterpret_cast<const cuComplex*>(beta),
369       reinterpret_cast<cuComplex*>(y)));
370 }
371 
372 template <>
bsrmv(CUSPARSE_BSRMV_ARGTYPES (c10::complex<double>))373 void bsrmv<c10::complex<double>>(
374     CUSPARSE_BSRMV_ARGTYPES(c10::complex<double>)) {
375   TORCH_CUDASPARSE_CHECK(cusparseZbsrmv(
376       handle,
377       dirA,
378       transA,
379       mb,
380       nb,
381       nnzb,
382       reinterpret_cast<const cuDoubleComplex*>(alpha),
383       descrA,
384       reinterpret_cast<const cuDoubleComplex*>(bsrValA),
385       bsrRowPtrA,
386       bsrColIndA,
387       blockDim,
388       reinterpret_cast<const cuDoubleComplex*>(x),
389       reinterpret_cast<const cuDoubleComplex*>(beta),
390       reinterpret_cast<cuDoubleComplex*>(y)));
391 }
392 
393 #if AT_USE_HIPSPARSE_TRIANGULAR_SOLVE()
394 
395 template <>
bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES (float))396 void bsrsv2_bufferSize<float>(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(float)) {
397   TORCH_CUDASPARSE_CHECK(cusparseSbsrsv2_bufferSize(
398       handle,
399       dirA,
400       transA,
401       mb,
402       nnzb,
403       descrA,
404       bsrValA,
405       bsrRowPtrA,
406       bsrColIndA,
407       blockDim,
408       info,
409       pBufferSizeInBytes));
410 }
411 
412 template <>
bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES (double))413 void bsrsv2_bufferSize<double>(CUSPARSE_BSRSV2_BUFFER_ARGTYPES(double)) {
414   TORCH_CUDASPARSE_CHECK(cusparseDbsrsv2_bufferSize(
415       handle,
416       dirA,
417       transA,
418       mb,
419       nnzb,
420       descrA,
421       bsrValA,
422       bsrRowPtrA,
423       bsrColIndA,
424       blockDim,
425       info,
426       pBufferSizeInBytes));
427 }
428 
429 template <>
bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES (c10::complex<float>))430 void bsrsv2_bufferSize<c10::complex<float>>(
431     CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex<float>)) {
432   TORCH_CUDASPARSE_CHECK(cusparseCbsrsv2_bufferSize(
433       handle,
434       dirA,
435       transA,
436       mb,
437       nnzb,
438       descrA,
439       reinterpret_cast<cuComplex*>(bsrValA),
440       bsrRowPtrA,
441       bsrColIndA,
442       blockDim,
443       info,
444       pBufferSizeInBytes));
445 }
446 
447 template <>
bsrsv2_bufferSize(CUSPARSE_BSRSV2_BUFFER_ARGTYPES (c10::complex<double>))448 void bsrsv2_bufferSize<c10::complex<double>>(
449     CUSPARSE_BSRSV2_BUFFER_ARGTYPES(c10::complex<double>)) {
450   TORCH_CUDASPARSE_CHECK(cusparseZbsrsv2_bufferSize(
451       handle,
452       dirA,
453       transA,
454       mb,
455       nnzb,
456       descrA,
457       reinterpret_cast<cuDoubleComplex*>(bsrValA),
458       bsrRowPtrA,
459       bsrColIndA,
460       blockDim,
461       info,
462       pBufferSizeInBytes));
463 }
464 
465 template <>
bsrsv2_analysis(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES (float))466 void bsrsv2_analysis<float>(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(float)) {
467   TORCH_CUDASPARSE_CHECK(cusparseSbsrsv2_analysis(
468       handle,
469       dirA,
470       transA,
471       mb,
472       nnzb,
473       descrA,
474       bsrValA,
475       bsrRowPtrA,
476       bsrColIndA,
477       blockDim,
478       info,
479       policy,
480       pBuffer));
481 }
482 
483 template <>
bsrsv2_analysis(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES (double))484 void bsrsv2_analysis<double>(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(double)) {
485   TORCH_CUDASPARSE_CHECK(cusparseDbsrsv2_analysis(
486       handle,
487       dirA,
488       transA,
489       mb,
490       nnzb,
491       descrA,
492       bsrValA,
493       bsrRowPtrA,
494       bsrColIndA,
495       blockDim,
496       info,
497       policy,
498       pBuffer));
499 }
500 
501 template <>
bsrsv2_analysis(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES (c10::complex<float>))502 void bsrsv2_analysis<c10::complex<float>>(
503     CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex<float>)) {
504   TORCH_CUDASPARSE_CHECK(cusparseCbsrsv2_analysis(
505       handle,
506       dirA,
507       transA,
508       mb,
509       nnzb,
510       descrA,
511       reinterpret_cast<const cuComplex*>(bsrValA),
512       bsrRowPtrA,
513       bsrColIndA,
514       blockDim,
515       info,
516       policy,
517       pBuffer));
518 }
519 
520 template <>
bsrsv2_analysis(CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES (c10::complex<double>))521 void bsrsv2_analysis<c10::complex<double>>(
522     CUSPARSE_BSRSV2_ANALYSIS_ARGTYPES(c10::complex<double>)) {
523   TORCH_CUDASPARSE_CHECK(cusparseZbsrsv2_analysis(
524       handle,
525       dirA,
526       transA,
527       mb,
528       nnzb,
529       descrA,
530       reinterpret_cast<const cuDoubleComplex*>(bsrValA),
531       bsrRowPtrA,
532       bsrColIndA,
533       blockDim,
534       info,
535       policy,
536       pBuffer));
537 }
538 
539 template <>
bsrsv2_solve(CUSPARSE_BSRSV2_SOLVE_ARGTYPES (float))540 void bsrsv2_solve<float>(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(float)) {
541   TORCH_CUDASPARSE_CHECK(cusparseSbsrsv2_solve(
542       handle,
543       dirA,
544       transA,
545       mb,
546       nnzb,
547       alpha,
548       descrA,
549       bsrValA,
550       bsrRowPtrA,
551       bsrColIndA,
552       blockDim,
553       info,
554       x,
555       y,
556       policy,
557       pBuffer));
558 }
559 
560 template <>
bsrsv2_solve(CUSPARSE_BSRSV2_SOLVE_ARGTYPES (double))561 void bsrsv2_solve<double>(CUSPARSE_BSRSV2_SOLVE_ARGTYPES(double)) {
562   TORCH_CUDASPARSE_CHECK(cusparseDbsrsv2_solve(
563       handle,
564       dirA,
565       transA,
566       mb,
567       nnzb,
568       alpha,
569       descrA,
570       bsrValA,
571       bsrRowPtrA,
572       bsrColIndA,
573       blockDim,
574       info,
575       x,
576       y,
577       policy,
578       pBuffer));
579 }
580 
581 template <>
bsrsv2_solve(CUSPARSE_BSRSV2_SOLVE_ARGTYPES (c10::complex<float>))582 void bsrsv2_solve<c10::complex<float>>(
583     CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex<float>)) {
584   TORCH_CUDASPARSE_CHECK(cusparseCbsrsv2_solve(
585       handle,
586       dirA,
587       transA,
588       mb,
589       nnzb,
590       reinterpret_cast<const cuComplex*>(alpha),
591       descrA,
592       reinterpret_cast<const cuComplex*>(bsrValA),
593       bsrRowPtrA,
594       bsrColIndA,
595       blockDim,
596       info,
597       reinterpret_cast<const cuComplex*>(x),
598       reinterpret_cast<cuComplex*>(y),
599       policy,
600       pBuffer));
601 }
602 
603 template <>
bsrsv2_solve(CUSPARSE_BSRSV2_SOLVE_ARGTYPES (c10::complex<double>))604 void bsrsv2_solve<c10::complex<double>>(
605     CUSPARSE_BSRSV2_SOLVE_ARGTYPES(c10::complex<double>)) {
606   TORCH_CUDASPARSE_CHECK(cusparseZbsrsv2_solve(
607       handle,
608       dirA,
609       transA,
610       mb,
611       nnzb,
612       reinterpret_cast<const cuDoubleComplex*>(alpha),
613       descrA,
614       reinterpret_cast<const cuDoubleComplex*>(bsrValA),
615       bsrRowPtrA,
616       bsrColIndA,
617       blockDim,
618       info,
619       reinterpret_cast<const cuDoubleComplex*>(x),
620       reinterpret_cast<cuDoubleComplex*>(y),
621       policy,
622       pBuffer));
623 }
624 
625 template <>
bsrsm2_bufferSize(CUSPARSE_BSRSM2_BUFFER_ARGTYPES (float))626 void bsrsm2_bufferSize<float>(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(float)) {
627   TORCH_CUDASPARSE_CHECK(cusparseSbsrsm2_bufferSize(
628       handle,
629       dirA,
630       transA,
631       transX,
632       mb,
633       n,
634       nnzb,
635       descrA,
636       bsrValA,
637       bsrRowPtrA,
638       bsrColIndA,
639       blockDim,
640       info,
641       pBufferSizeInBytes));
642 }
643 
644 template <>
bsrsm2_bufferSize(CUSPARSE_BSRSM2_BUFFER_ARGTYPES (double))645 void bsrsm2_bufferSize<double>(CUSPARSE_BSRSM2_BUFFER_ARGTYPES(double)) {
646   TORCH_CUDASPARSE_CHECK(cusparseDbsrsm2_bufferSize(
647       handle,
648       dirA,
649       transA,
650       transX,
651       mb,
652       n,
653       nnzb,
654       descrA,
655       bsrValA,
656       bsrRowPtrA,
657       bsrColIndA,
658       blockDim,
659       info,
660       pBufferSizeInBytes));
661 }
662 
663 template <>
bsrsm2_bufferSize(CUSPARSE_BSRSM2_BUFFER_ARGTYPES (c10::complex<float>))664 void bsrsm2_bufferSize<c10::complex<float>>(
665     CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex<float>)) {
666   TORCH_CUDASPARSE_CHECK(cusparseCbsrsm2_bufferSize(
667       handle,
668       dirA,
669       transA,
670       transX,
671       mb,
672       n,
673       nnzb,
674       descrA,
675       reinterpret_cast<cuComplex*>(bsrValA),
676       bsrRowPtrA,
677       bsrColIndA,
678       blockDim,
679       info,
680       pBufferSizeInBytes));
681 }
682 
683 template <>
bsrsm2_bufferSize(CUSPARSE_BSRSM2_BUFFER_ARGTYPES (c10::complex<double>))684 void bsrsm2_bufferSize<c10::complex<double>>(
685     CUSPARSE_BSRSM2_BUFFER_ARGTYPES(c10::complex<double>)) {
686   TORCH_CUDASPARSE_CHECK(cusparseZbsrsm2_bufferSize(
687       handle,
688       dirA,
689       transA,
690       transX,
691       mb,
692       n,
693       nnzb,
694       descrA,
695       reinterpret_cast<cuDoubleComplex*>(bsrValA),
696       bsrRowPtrA,
697       bsrColIndA,
698       blockDim,
699       info,
700       pBufferSizeInBytes));
701 }
702 
703 template <>
bsrsm2_analysis(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES (float))704 void bsrsm2_analysis<float>(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(float)) {
705   TORCH_CUDASPARSE_CHECK(cusparseSbsrsm2_analysis(
706       handle,
707       dirA,
708       transA,
709       transX,
710       mb,
711       n,
712       nnzb,
713       descrA,
714       bsrValA,
715       bsrRowPtrA,
716       bsrColIndA,
717       blockDim,
718       info,
719       policy,
720       pBuffer));
721 }
722 
723 template <>
bsrsm2_analysis(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES (double))724 void bsrsm2_analysis<double>(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(double)) {
725   TORCH_CUDASPARSE_CHECK(cusparseDbsrsm2_analysis(
726       handle,
727       dirA,
728       transA,
729       transX,
730       mb,
731       n,
732       nnzb,
733       descrA,
734       bsrValA,
735       bsrRowPtrA,
736       bsrColIndA,
737       blockDim,
738       info,
739       policy,
740       pBuffer));
741 }
742 
743 template <>
bsrsm2_analysis(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES (c10::complex<float>))744 void bsrsm2_analysis<c10::complex<float>>(
745     CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex<float>)) {
746   TORCH_CUDASPARSE_CHECK(cusparseCbsrsm2_analysis(
747       handle,
748       dirA,
749       transA,
750       transX,
751       mb,
752       n,
753       nnzb,
754       descrA,
755       reinterpret_cast<const cuComplex*>(bsrValA),
756       bsrRowPtrA,
757       bsrColIndA,
758       blockDim,
759       info,
760       policy,
761       pBuffer));
762 }
763 
764 template <>
bsrsm2_analysis(CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES (c10::complex<double>))765 void bsrsm2_analysis<c10::complex<double>>(
766     CUSPARSE_BSRSM2_ANALYSIS_ARGTYPES(c10::complex<double>)) {
767   TORCH_CUDASPARSE_CHECK(cusparseZbsrsm2_analysis(
768       handle,
769       dirA,
770       transA,
771       transX,
772       mb,
773       n,
774       nnzb,
775       descrA,
776       reinterpret_cast<const cuDoubleComplex*>(bsrValA),
777       bsrRowPtrA,
778       bsrColIndA,
779       blockDim,
780       info,
781       policy,
782       pBuffer));
783 }
784 
785 template <>
bsrsm2_solve(CUSPARSE_BSRSM2_SOLVE_ARGTYPES (float))786 void bsrsm2_solve<float>(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(float)) {
787   TORCH_CUDASPARSE_CHECK(cusparseSbsrsm2_solve(
788       handle,
789       dirA,
790       transA,
791       transX,
792       mb,
793       n,
794       nnzb,
795       alpha,
796       descrA,
797       bsrValA,
798       bsrRowPtrA,
799       bsrColIndA,
800       blockDim,
801       info,
802       B,
803       ldb,
804       X,
805       ldx,
806       policy,
807       pBuffer));
808 }
809 
810 template <>
bsrsm2_solve(CUSPARSE_BSRSM2_SOLVE_ARGTYPES (double))811 void bsrsm2_solve<double>(CUSPARSE_BSRSM2_SOLVE_ARGTYPES(double)) {
812   TORCH_CUDASPARSE_CHECK(cusparseDbsrsm2_solve(
813       handle,
814       dirA,
815       transA,
816       transX,
817       mb,
818       n,
819       nnzb,
820       alpha,
821       descrA,
822       bsrValA,
823       bsrRowPtrA,
824       bsrColIndA,
825       blockDim,
826       info,
827       B,
828       ldb,
829       X,
830       ldx,
831       policy,
832       pBuffer));
833 }
834 
835 template <>
bsrsm2_solve(CUSPARSE_BSRSM2_SOLVE_ARGTYPES (c10::complex<float>))836 void bsrsm2_solve<c10::complex<float>>(
837     CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex<float>)) {
838   TORCH_CUDASPARSE_CHECK(cusparseCbsrsm2_solve(
839       handle,
840       dirA,
841       transA,
842       transX,
843       mb,
844       n,
845       nnzb,
846       reinterpret_cast<const cuComplex*>(alpha),
847       descrA,
848       reinterpret_cast<const cuComplex*>(bsrValA),
849       bsrRowPtrA,
850       bsrColIndA,
851       blockDim,
852       info,
853       reinterpret_cast<const cuComplex*>(B),
854       ldb,
855       reinterpret_cast<cuComplex*>(X),
856       ldx,
857       policy,
858       pBuffer));
859 }
860 
861 template <>
bsrsm2_solve(CUSPARSE_BSRSM2_SOLVE_ARGTYPES (c10::complex<double>))862 void bsrsm2_solve<c10::complex<double>>(
863     CUSPARSE_BSRSM2_SOLVE_ARGTYPES(c10::complex<double>)) {
864   TORCH_CUDASPARSE_CHECK(cusparseZbsrsm2_solve(
865       handle,
866       dirA,
867       transA,
868       transX,
869       mb,
870       n,
871       nnzb,
872       reinterpret_cast<const cuDoubleComplex*>(alpha),
873       descrA,
874       reinterpret_cast<const cuDoubleComplex*>(bsrValA),
875       bsrRowPtrA,
876       bsrColIndA,
877       blockDim,
878       info,
879       reinterpret_cast<const cuDoubleComplex*>(B),
880       ldb,
881       reinterpret_cast<cuDoubleComplex*>(X),
882       ldx,
883       policy,
884       pBuffer));
885 }
886 
887 #endif // AT_USE_HIPSPARSE_TRIANGULAR_SOLVE
888 
889 } // namespace at::cuda::sparse
890