xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/linalg/CUDASolver.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/Context.h>
2 #include <ATen/NativeFunctions.h>
3 #include <ATen/native/cuda/linalg/CUDASolver.h>
4 #include <c10/cuda/CUDACachingAllocator.h>
5 #include <c10/macros/Export.h>
6 
7 #if defined(CUDART_VERSION) || defined(USE_ROCM)
8 
9 namespace at::cuda::solver {
10 
11 template <>
getrf(cusolverDnHandle_t handle,int m,int n,double * dA,int ldda,int * ipiv,int * info)12 void getrf<double>(
13     cusolverDnHandle_t handle, int m, int n, double* dA, int ldda, int* ipiv, int* info) {
14   int lwork;
15   TORCH_CUSOLVER_CHECK(
16       cusolverDnDgetrf_bufferSize(handle, m, n, dA, ldda, &lwork));
17   auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
18   auto dataPtr = allocator.allocate(sizeof(double)*lwork);
19   TORCH_CUSOLVER_CHECK(cusolverDnDgetrf(
20       handle, m, n, dA, ldda, static_cast<double*>(dataPtr.get()), ipiv, info));
21 }
22 
23 template <>
getrf(cusolverDnHandle_t handle,int m,int n,float * dA,int ldda,int * ipiv,int * info)24 void getrf<float>(
25     cusolverDnHandle_t handle, int m, int n, float* dA, int ldda, int* ipiv, int* info) {
26   int lwork;
27   TORCH_CUSOLVER_CHECK(
28       cusolverDnSgetrf_bufferSize(handle, m, n, dA, ldda, &lwork));
29   auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
30   auto dataPtr = allocator.allocate(sizeof(float)*lwork);
31   TORCH_CUSOLVER_CHECK(cusolverDnSgetrf(
32       handle, m, n, dA, ldda, static_cast<float*>(dataPtr.get()), ipiv, info));
33 }
34 
35 template <>
getrf(cusolverDnHandle_t handle,int m,int n,c10::complex<double> * dA,int ldda,int * ipiv,int * info)36 void getrf<c10::complex<double>>(
37     cusolverDnHandle_t handle,
38     int m,
39     int n,
40     c10::complex<double>* dA,
41     int ldda,
42     int* ipiv,
43     int* info) {
44   int lwork;
45   TORCH_CUSOLVER_CHECK(cusolverDnZgetrf_bufferSize(
46       handle, m, n, reinterpret_cast<cuDoubleComplex*>(dA), ldda, &lwork));
47   auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
48   auto dataPtr = allocator.allocate(sizeof(cuDoubleComplex) * lwork);
49   TORCH_CUSOLVER_CHECK(cusolverDnZgetrf(
50       handle,
51       m,
52       n,
53       reinterpret_cast<cuDoubleComplex*>(dA),
54       ldda,
55       static_cast<cuDoubleComplex*>(dataPtr.get()),
56       ipiv,
57       info));
58 }
59 
60 template <>
getrf(cusolverDnHandle_t handle,int m,int n,c10::complex<float> * dA,int ldda,int * ipiv,int * info)61 void getrf<c10::complex<float>>(
62     cusolverDnHandle_t handle,
63     int m,
64     int n,
65     c10::complex<float>* dA,
66     int ldda,
67     int* ipiv,
68     int* info) {
69   int lwork;
70   TORCH_CUSOLVER_CHECK(cusolverDnCgetrf_bufferSize(
71       handle, m, n, reinterpret_cast<cuComplex*>(dA), ldda, &lwork));
72   auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
73   auto dataPtr = allocator.allocate(sizeof(cuComplex) * lwork);
74   TORCH_CUSOLVER_CHECK(cusolverDnCgetrf(
75       handle,
76       m,
77       n,
78       reinterpret_cast<cuComplex*>(dA),
79       ldda,
80       static_cast<cuComplex*>(dataPtr.get()),
81       ipiv,
82       info));
83 }
84 
85 template <>
getrs(cusolverDnHandle_t handle,int n,int nrhs,double * dA,int lda,int * ipiv,double * ret,int ldb,int * info,cublasOperation_t trans)86 void getrs<double>(
87     cusolverDnHandle_t handle, int n, int nrhs, double* dA, int lda, int* ipiv, double* ret, int ldb, int* info, cublasOperation_t trans) {
88   TORCH_CUSOLVER_CHECK(cusolverDnDgetrs(
89     handle, trans, n, nrhs, dA, lda, ipiv, ret, ldb, info));
90 }
91 
92 template <>
getrs(cusolverDnHandle_t handle,int n,int nrhs,float * dA,int lda,int * ipiv,float * ret,int ldb,int * info,cublasOperation_t trans)93 void getrs<float>(
94     cusolverDnHandle_t handle, int n, int nrhs, float* dA, int lda, int* ipiv, float* ret, int ldb, int* info, cublasOperation_t trans) {
95   TORCH_CUSOLVER_CHECK(cusolverDnSgetrs(
96     handle, trans, n, nrhs, dA, lda, ipiv, ret, ldb, info));
97 }
98 
99 template <>
getrs(cusolverDnHandle_t handle,int n,int nrhs,c10::complex<double> * dA,int lda,int * ipiv,c10::complex<double> * ret,int ldb,int * info,cublasOperation_t trans)100 void getrs<c10::complex<double>>(
101     cusolverDnHandle_t handle,
102     int n,
103     int nrhs,
104     c10::complex<double>* dA,
105     int lda,
106     int* ipiv,
107     c10::complex<double>* ret,
108     int ldb,
109     int* info,
110     cublasOperation_t trans) {
111   TORCH_CUSOLVER_CHECK(cusolverDnZgetrs(
112       handle,
113       trans,
114       n,
115       nrhs,
116       reinterpret_cast<cuDoubleComplex*>(dA),
117       lda,
118       ipiv,
119       reinterpret_cast<cuDoubleComplex*>(ret),
120       ldb,
121       info));
122 }
123 
124 template <>
getrs(cusolverDnHandle_t handle,int n,int nrhs,c10::complex<float> * dA,int lda,int * ipiv,c10::complex<float> * ret,int ldb,int * info,cublasOperation_t trans)125 void getrs<c10::complex<float>>(
126     cusolverDnHandle_t handle,
127     int n,
128     int nrhs,
129     c10::complex<float>* dA,
130     int lda,
131     int* ipiv,
132     c10::complex<float>* ret,
133     int ldb,
134     int* info,
135     cublasOperation_t trans) {
136   TORCH_CUSOLVER_CHECK(cusolverDnCgetrs(
137       handle,
138       trans,
139       n,
140       nrhs,
141       reinterpret_cast<cuComplex*>(dA),
142       lda,
143       ipiv,
144       reinterpret_cast<cuComplex*>(ret),
145       ldb,
146       info));
147 }
148 
149 template <>
sytrf_bufferSize(CUDASOLVER_SYTRF_BUFFER_ARGTYPES (double))150 void sytrf_bufferSize<double>(CUDASOLVER_SYTRF_BUFFER_ARGTYPES(double)) {
151   TORCH_CUSOLVER_CHECK(cusolverDnDsytrf_bufferSize(handle, n, A, lda, lwork));
152 }
153 
154 template <>
sytrf_bufferSize(CUDASOLVER_SYTRF_BUFFER_ARGTYPES (float))155 void sytrf_bufferSize<float>(CUDASOLVER_SYTRF_BUFFER_ARGTYPES(float)) {
156   TORCH_CUSOLVER_CHECK(cusolverDnSsytrf_bufferSize(handle, n, A, lda, lwork));
157 }
158 
159 template <>
sytrf_bufferSize(CUDASOLVER_SYTRF_BUFFER_ARGTYPES (c10::complex<double>))160 void sytrf_bufferSize<c10::complex<double>>(
161     CUDASOLVER_SYTRF_BUFFER_ARGTYPES(c10::complex<double>)) {
162   TORCH_CUSOLVER_CHECK(cusolverDnZsytrf_bufferSize(
163       handle, n, reinterpret_cast<cuDoubleComplex*>(A), lda, lwork));
164 }
165 
166 template <>
sytrf_bufferSize(CUDASOLVER_SYTRF_BUFFER_ARGTYPES (c10::complex<float>))167 void sytrf_bufferSize<c10::complex<float>>(
168     CUDASOLVER_SYTRF_BUFFER_ARGTYPES(c10::complex<float>)) {
169   TORCH_CUSOLVER_CHECK(cusolverDnCsytrf_bufferSize(
170       handle, n, reinterpret_cast<cuComplex*>(A), lda, lwork));
171 }
172 
173 template <>
sytrf(CUDASOLVER_SYTRF_ARGTYPES (double))174 void sytrf<double>(CUDASOLVER_SYTRF_ARGTYPES(double)) {
175   TORCH_CUSOLVER_CHECK(
176       cusolverDnDsytrf(handle, uplo, n, A, lda, ipiv, work, lwork, devInfo));
177 }
178 
179 template <>
sytrf(CUDASOLVER_SYTRF_ARGTYPES (float))180 void sytrf<float>(CUDASOLVER_SYTRF_ARGTYPES(float)) {
181   TORCH_CUSOLVER_CHECK(
182       cusolverDnSsytrf(handle, uplo, n, A, lda, ipiv, work, lwork, devInfo));
183 }
184 
185 template <>
sytrf(CUDASOLVER_SYTRF_ARGTYPES (c10::complex<double>))186 void sytrf<c10::complex<double>>(
187     CUDASOLVER_SYTRF_ARGTYPES(c10::complex<double>)) {
188   TORCH_CUSOLVER_CHECK(cusolverDnZsytrf(
189       handle,
190       uplo,
191       n,
192       reinterpret_cast<cuDoubleComplex*>(A),
193       lda,
194       ipiv,
195       reinterpret_cast<cuDoubleComplex*>(work),
196       lwork,
197       devInfo));
198 }
199 
200 template <>
sytrf(CUDASOLVER_SYTRF_ARGTYPES (c10::complex<float>))201 void sytrf<c10::complex<float>>(
202     CUDASOLVER_SYTRF_ARGTYPES(c10::complex<float>)) {
203   TORCH_CUSOLVER_CHECK(cusolverDnCsytrf(
204       handle,
205       uplo,
206       n,
207       reinterpret_cast<cuComplex*>(A),
208       lda,
209       ipiv,
210       reinterpret_cast<cuComplex*>(work),
211       lwork,
212       devInfo));
213 }
214 
215 template<>
gesvd_buffersize(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES ())216 void gesvd_buffersize<float>(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES()) {
217   TORCH_CUSOLVER_CHECK(cusolverDnSgesvd_bufferSize(handle, m, n, lwork));
218 }
219 
220 template<>
gesvd_buffersize(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES ())221 void gesvd_buffersize<double>(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES()) {
222   TORCH_CUSOLVER_CHECK(cusolverDnDgesvd_bufferSize(handle, m, n, lwork));
223 }
224 
225 template<>
gesvd_buffersize(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES ())226 void gesvd_buffersize<c10::complex<float>>(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES()) {
227   TORCH_CUSOLVER_CHECK(cusolverDnCgesvd_bufferSize(handle, m, n, lwork));
228 }
229 
230 template<>
gesvd_buffersize(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES ())231 void gesvd_buffersize<c10::complex<double>>(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES()) {
232   TORCH_CUSOLVER_CHECK(cusolverDnZgesvd_bufferSize(handle, m, n, lwork));
233 }
234 
235 
236 template<>
gesvd(CUDASOLVER_GESVD_ARGTYPES (float,float))237 void gesvd<float>(CUDASOLVER_GESVD_ARGTYPES(float, float)) {
238   TORCH_CUSOLVER_CHECK(cusolverDnSgesvd(
239       handle, jobu, jobvt, m, n, A, lda, S, U, ldu, VT, ldvt, work, lwork, rwork, info));
240 }
241 
242 template<>
gesvd(CUDASOLVER_GESVD_ARGTYPES (double,double))243 void gesvd<double>(CUDASOLVER_GESVD_ARGTYPES(double, double)) {
244   TORCH_CUSOLVER_CHECK(cusolverDnDgesvd(
245       handle, jobu, jobvt, m, n, A, lda, S, U, ldu, VT, ldvt, work, lwork, rwork, info));
246 }
247 
248 
249 template<>
gesvd(CUDASOLVER_GESVD_ARGTYPES (c10::complex<float>,float))250 void gesvd<c10::complex<float>>(CUDASOLVER_GESVD_ARGTYPES(c10::complex<float>, float)) {
251   TORCH_CUSOLVER_CHECK(cusolverDnCgesvd(
252       handle, jobu, jobvt, m, n,
253       reinterpret_cast<cuComplex*>(A),
254       lda, S,
255       reinterpret_cast<cuComplex*>(U),
256       ldu,
257       reinterpret_cast<cuComplex*>(VT),
258       ldvt,
259       reinterpret_cast<cuComplex*>(work),
260       lwork, rwork, info
261   ));
262 }
263 
264 template<>
gesvd(CUDASOLVER_GESVD_ARGTYPES (c10::complex<double>,double))265 void gesvd<c10::complex<double>>(CUDASOLVER_GESVD_ARGTYPES(c10::complex<double>, double)) {
266   TORCH_CUSOLVER_CHECK(cusolverDnZgesvd(
267       handle, jobu, jobvt, m, n,
268       reinterpret_cast<cuDoubleComplex*>(A),
269       lda, S,
270       reinterpret_cast<cuDoubleComplex*>(U),
271       ldu,
272       reinterpret_cast<cuDoubleComplex*>(VT),
273       ldvt,
274       reinterpret_cast<cuDoubleComplex*>(work),
275       lwork, rwork, info
276   ));
277 }
278 
279 
280 template<>
gesvdj_buffersize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,int econ,int m,int n,float * A,int lda,float * S,float * U,int ldu,float * V,int ldv,int * lwork,gesvdjInfo_t params)281 void gesvdj_buffersize<float>(
282     cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, float *A, int lda, float *S,
283     float *U, int ldu, float *V, int ldv, int *lwork, gesvdjInfo_t params
284 ) {
285   TORCH_CUSOLVER_CHECK(cusolverDnSgesvdj_bufferSize(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, lwork, params));
286 }
287 
288 template<>
gesvdj_buffersize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,int econ,int m,int n,double * A,int lda,double * S,double * U,int ldu,double * V,int ldv,int * lwork,gesvdjInfo_t params)289 void gesvdj_buffersize<double>(
290     cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, double *A, int lda, double *S,
291     double *U, int ldu, double *V, int ldv, int *lwork, gesvdjInfo_t params
292 ) {
293   TORCH_CUSOLVER_CHECK(cusolverDnDgesvdj_bufferSize(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, lwork, params));
294 }
295 
296 template<>
gesvdj_buffersize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,int econ,int m,int n,c10::complex<float> * A,int lda,float * S,c10::complex<float> * U,int ldu,c10::complex<float> * V,int ldv,int * lwork,gesvdjInfo_t params)297 void gesvdj_buffersize<c10::complex<float>>(
298     cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, c10::complex<float> *A, int lda, float *S,
299     c10::complex<float> *U, int ldu, c10::complex<float> *V, int ldv, int *lwork, gesvdjInfo_t params
300 ) {
301   TORCH_CUSOLVER_CHECK(cusolverDnCgesvdj_bufferSize(handle, jobz, econ, m, n,
302     reinterpret_cast<cuComplex*>(A),
303     lda,
304     S,
305     reinterpret_cast<cuComplex*>(U),
306     ldu,
307     reinterpret_cast<cuComplex*>(V),
308     ldv, lwork, params));
309 }
310 
311 template<>
gesvdj_buffersize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,int econ,int m,int n,c10::complex<double> * A,int lda,double * S,c10::complex<double> * U,int ldu,c10::complex<double> * V,int ldv,int * lwork,gesvdjInfo_t params)312 void gesvdj_buffersize<c10::complex<double>>(
313     cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, c10::complex<double> *A, int lda, double *S,
314     c10::complex<double> *U, int ldu, c10::complex<double> *V, int ldv, int *lwork, gesvdjInfo_t params
315 ) {
316   TORCH_CUSOLVER_CHECK(cusolverDnZgesvdj_bufferSize(handle, jobz, econ, m, n,
317     reinterpret_cast<cuDoubleComplex*>(A),
318     lda,
319     S,
320     reinterpret_cast<cuDoubleComplex*>(U),
321     ldu,
322     reinterpret_cast<cuDoubleComplex*>(V),
323     ldv, lwork, params));
324 }
325 
326 
327 template<>
gesvdj(cusolverDnHandle_t handle,cusolverEigMode_t jobz,int econ,int m,int n,float * A,int lda,float * S,float * U,int ldu,float * V,int ldv,float * work,int lwork,int * info,gesvdjInfo_t params)328 void gesvdj<float>(
329     cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, float* A, int lda, float* S, float* U,
330     int ldu, float *V, int ldv, float* work, int lwork, int *info, gesvdjInfo_t params
331 ) {
332   TORCH_CUSOLVER_CHECK(cusolverDnSgesvdj(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, work, lwork, info, params));
333 }
334 
335 template<>
gesvdj(cusolverDnHandle_t handle,cusolverEigMode_t jobz,int econ,int m,int n,double * A,int lda,double * S,double * U,int ldu,double * V,int ldv,double * work,int lwork,int * info,gesvdjInfo_t params)336 void gesvdj<double>(
337     cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, double* A, int lda, double* S, double* U,
338     int ldu, double *V, int ldv, double* work, int lwork, int *info, gesvdjInfo_t params
339 ) {
340   TORCH_CUSOLVER_CHECK(cusolverDnDgesvdj(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, work, lwork, info, params));
341 }
342 
343 template<>
gesvdj(cusolverDnHandle_t handle,cusolverEigMode_t jobz,int econ,int m,int n,c10::complex<float> * A,int lda,float * S,c10::complex<float> * U,int ldu,c10::complex<float> * V,int ldv,c10::complex<float> * work,int lwork,int * info,gesvdjInfo_t params)344 void gesvdj<c10::complex<float>>(
345     cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, c10::complex<float>* A, int lda, float* S, c10::complex<float>* U,
346     int ldu, c10::complex<float> *V, int ldv, c10::complex<float>* work, int lwork, int *info, gesvdjInfo_t params
347 ) {
348   TORCH_CUSOLVER_CHECK(cusolverDnCgesvdj(
349     handle, jobz, econ, m, n,
350     reinterpret_cast<cuComplex*>(A),
351     lda, S,
352     reinterpret_cast<cuComplex*>(U),
353     ldu,
354     reinterpret_cast<cuComplex*>(V),
355     ldv,
356     reinterpret_cast<cuComplex*>(work),
357     lwork, info, params));
358 }
359 
360 template<>
gesvdj(cusolverDnHandle_t handle,cusolverEigMode_t jobz,int econ,int m,int n,c10::complex<double> * A,int lda,double * S,c10::complex<double> * U,int ldu,c10::complex<double> * V,int ldv,c10::complex<double> * work,int lwork,int * info,gesvdjInfo_t params)361 void gesvdj<c10::complex<double>>(
362     cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, c10::complex<double>* A, int lda, double* S, c10::complex<double>* U,
363     int ldu, c10::complex<double> *V, int ldv, c10::complex<double>* work, int lwork, int *info, gesvdjInfo_t params
364 ) {
365   TORCH_CUSOLVER_CHECK(cusolverDnZgesvdj(
366     handle, jobz, econ, m, n,
367     reinterpret_cast<cuDoubleComplex*>(A),
368     lda, S,
369     reinterpret_cast<cuDoubleComplex*>(U),
370     ldu,
371     reinterpret_cast<cuDoubleComplex*>(V),
372     ldv,
373     reinterpret_cast<cuDoubleComplex*>(work),
374     lwork, info, params));
375 }
376 
377 
378 template<>
gesvdjBatched(cusolverDnHandle_t handle,cusolverEigMode_t jobz,int m,int n,float * A,int lda,float * S,float * U,int ldu,float * V,int ldv,int * info,gesvdjInfo_t params,int batchSize)379 void gesvdjBatched<float>(
380     cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, float* A, int lda, float* S, float* U,
381     int ldu, float *V, int ldv, int *info, gesvdjInfo_t params, int batchSize
382 ) {
383   int lwork;
384   TORCH_CUSOLVER_CHECK(cusolverDnSgesvdjBatched_bufferSize(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, &lwork, params, batchSize));
385 
386   auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
387   auto dataPtr = allocator.allocate(sizeof(float)*lwork);
388 
389   TORCH_CUSOLVER_CHECK(cusolverDnSgesvdjBatched(
390     handle, jobz, m, n, A, lda, S, U, ldu, V, ldv,
391     static_cast<float*>(dataPtr.get()),
392     lwork, info, params, batchSize));
393 }
394 
395 template<>
gesvdjBatched(cusolverDnHandle_t handle,cusolverEigMode_t jobz,int m,int n,double * A,int lda,double * S,double * U,int ldu,double * V,int ldv,int * info,gesvdjInfo_t params,int batchSize)396 void gesvdjBatched<double>(
397     cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, double* A, int lda, double* S, double* U,
398     int ldu, double *V, int ldv, int *info, gesvdjInfo_t params, int batchSize
399 ) {
400   int lwork;
401   TORCH_CUSOLVER_CHECK(cusolverDnDgesvdjBatched_bufferSize(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, &lwork, params, batchSize));
402 
403   auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
404   auto dataPtr = allocator.allocate(sizeof(double)*lwork);
405 
406   TORCH_CUSOLVER_CHECK(cusolverDnDgesvdjBatched(
407     handle, jobz, m, n, A, lda, S, U, ldu, V, ldv,
408     static_cast<double*>(dataPtr.get()),
409     lwork, info, params, batchSize));
410 }
411 
412 template<>
gesvdjBatched(cusolverDnHandle_t handle,cusolverEigMode_t jobz,int m,int n,c10::complex<float> * A,int lda,float * S,c10::complex<float> * U,int ldu,c10::complex<float> * V,int ldv,int * info,gesvdjInfo_t params,int batchSize)413 void gesvdjBatched<c10::complex<float>>(
414     cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, c10::complex<float>* A, int lda, float* S, c10::complex<float>* U,
415     int ldu, c10::complex<float> *V, int ldv, int *info, gesvdjInfo_t params, int batchSize
416 ) {
417   int lwork;
418   TORCH_CUSOLVER_CHECK(cusolverDnCgesvdjBatched_bufferSize(
419     handle, jobz, m, n,
420     reinterpret_cast<cuComplex*>(A),
421     lda, S,
422     reinterpret_cast<cuComplex*>(U),
423     ldu,
424     reinterpret_cast<cuComplex*>(V),
425     ldv, &lwork, params, batchSize));
426 
427   auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
428   auto dataPtr = allocator.allocate(sizeof(cuComplex)*lwork);
429 
430   TORCH_CUSOLVER_CHECK(cusolverDnCgesvdjBatched(
431     handle, jobz, m, n,
432     reinterpret_cast<cuComplex*>(A),
433     lda, S,
434     reinterpret_cast<cuComplex*>(U),
435     ldu,
436     reinterpret_cast<cuComplex*>(V),
437     ldv,
438     static_cast<cuComplex*>(dataPtr.get()),
439     lwork, info, params, batchSize));
440 }
441 
442 template<>
gesvdjBatched(cusolverDnHandle_t handle,cusolverEigMode_t jobz,int m,int n,c10::complex<double> * A,int lda,double * S,c10::complex<double> * U,int ldu,c10::complex<double> * V,int ldv,int * info,gesvdjInfo_t params,int batchSize)443 void gesvdjBatched<c10::complex<double>>(
444     cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, c10::complex<double>* A, int lda, double* S, c10::complex<double>* U,
445     int ldu, c10::complex<double> *V, int ldv, int *info, gesvdjInfo_t params, int batchSize
446 ) {
447   int lwork;
448   TORCH_CUSOLVER_CHECK(cusolverDnZgesvdjBatched_bufferSize(
449     handle, jobz, m, n,
450     reinterpret_cast<cuDoubleComplex*>(A),
451     lda, S,
452     reinterpret_cast<cuDoubleComplex*>(U),
453     ldu,
454     reinterpret_cast<cuDoubleComplex*>(V),
455     ldv, &lwork, params, batchSize));
456 
457   auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
458   auto dataPtr = allocator.allocate(sizeof(cuDoubleComplex)*lwork);
459 
460   TORCH_CUSOLVER_CHECK(cusolverDnZgesvdjBatched(
461     handle, jobz, m, n,
462     reinterpret_cast<cuDoubleComplex*>(A),
463     lda, S,
464     reinterpret_cast<cuDoubleComplex*>(U),
465     ldu,
466     reinterpret_cast<cuDoubleComplex*>(V),
467     ldv,
468     static_cast<cuDoubleComplex*>(dataPtr.get()),
469     lwork, info, params, batchSize));
470 }
471 
472 
// ROCm does not implement gesvda yet
474 #ifdef CUDART_VERSION
475 template<>
gesvdaStridedBatched_buffersize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,int rank,int m,int n,float * A,int lda,long long int strideA,float * S,long long int strideS,float * U,int ldu,long long int strideU,float * V,int ldv,long long int strideV,int * lwork,int batchSize)476 void gesvdaStridedBatched_buffersize<float>(
477     cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, float *A, int lda, long long int strideA,
478     float *S, long long int strideS, float *U, int ldu, long long int strideU, float *V, int ldv, long long int strideV,
479     int *lwork, int batchSize
480 ) {
481   TORCH_CUSOLVER_CHECK(cusolverDnSgesvdaStridedBatched_bufferSize(
482     handle, jobz, rank, m, n, A, lda, strideA, S, strideS, U, ldu, strideU, V, ldv, strideV, lwork, batchSize
483   ));
484 }
485 
486 template<>
gesvdaStridedBatched_buffersize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,int rank,int m,int n,double * A,int lda,long long int strideA,double * S,long long int strideS,double * U,int ldu,long long int strideU,double * V,int ldv,long long int strideV,int * lwork,int batchSize)487 void gesvdaStridedBatched_buffersize<double>(
488     cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, double *A, int lda, long long int strideA,
489     double *S, long long int strideS, double *U, int ldu, long long int strideU, double *V, int ldv, long long int strideV,
490     int *lwork, int batchSize
491 ) {
492   TORCH_CUSOLVER_CHECK(cusolverDnDgesvdaStridedBatched_bufferSize(
493     handle, jobz, rank, m, n, A, lda, strideA, S, strideS, U, ldu, strideU, V, ldv, strideV, lwork, batchSize
494   ));
495 }
496 
497 template<>
gesvdaStridedBatched_buffersize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,int rank,int m,int n,c10::complex<float> * A,int lda,long long int strideA,float * S,long long int strideS,c10::complex<float> * U,int ldu,long long int strideU,c10::complex<float> * V,int ldv,long long int strideV,int * lwork,int batchSize)498 void gesvdaStridedBatched_buffersize<c10::complex<float>>(
499     cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, c10::complex<float> *A, int lda, long long int strideA,
500     float *S, long long int strideS, c10::complex<float> *U, int ldu, long long int strideU,
501     c10::complex<float> *V, int ldv, long long int strideV,
502     int *lwork, int batchSize
503 ) {
504   TORCH_CUSOLVER_CHECK(cusolverDnCgesvdaStridedBatched_bufferSize(
505     handle, jobz, rank, m, n,
506     reinterpret_cast<cuComplex*>(A),
507     lda, strideA, S, strideS,
508     reinterpret_cast<cuComplex*>(U),
509     ldu, strideU,
510     reinterpret_cast<cuComplex*>(V),
511     ldv, strideV, lwork, batchSize
512   ));
513 }
514 
515 template<>
gesvdaStridedBatched_buffersize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,int rank,int m,int n,c10::complex<double> * A,int lda,long long int strideA,double * S,long long int strideS,c10::complex<double> * U,int ldu,long long int strideU,c10::complex<double> * V,int ldv,long long int strideV,int * lwork,int batchSize)516 void gesvdaStridedBatched_buffersize<c10::complex<double>>(
517     cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, c10::complex<double> *A, int lda, long long int strideA,
518     double *S, long long int strideS, c10::complex<double> *U, int ldu, long long int strideU,
519     c10::complex<double> *V, int ldv, long long int strideV,
520     int *lwork, int batchSize
521 ) {
522   TORCH_CUSOLVER_CHECK(cusolverDnZgesvdaStridedBatched_bufferSize(
523     handle, jobz, rank, m, n,
524     reinterpret_cast<cuDoubleComplex*>(A),
525     lda, strideA, S, strideS,
526     reinterpret_cast<cuDoubleComplex*>(U),
527     ldu, strideU,
528     reinterpret_cast<cuDoubleComplex*>(V),
529     ldv, strideV, lwork, batchSize
530   ));
531 }
532 
533 
534 template<>
gesvdaStridedBatched(cusolverDnHandle_t handle,cusolverEigMode_t jobz,int rank,int m,int n,float * A,int lda,long long int strideA,float * S,long long int strideS,float * U,int ldu,long long int strideU,float * V,int ldv,long long int strideV,float * work,int lwork,int * info,double * h_R_nrmF,int batchSize)535 void gesvdaStridedBatched<float>(
536     cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, float *A, int lda, long long int strideA,
537     float *S, long long int strideS, float *U, int ldu, long long int strideU, float *V, int ldv, long long int strideV,
538     float *work, int lwork, int *info, double *h_R_nrmF, int batchSize
539 ) {
540   TORCH_CUSOLVER_CHECK(cusolverDnSgesvdaStridedBatched(
541     handle, jobz, rank, m, n, A, lda, strideA, S, strideS, U, ldu, strideU, V, ldv, strideV, work, lwork, info, h_R_nrmF, batchSize
542   ));
543 }
544 
545 template<>
gesvdaStridedBatched(cusolverDnHandle_t handle,cusolverEigMode_t jobz,int rank,int m,int n,double * A,int lda,long long int strideA,double * S,long long int strideS,double * U,int ldu,long long int strideU,double * V,int ldv,long long int strideV,double * work,int lwork,int * info,double * h_R_nrmF,int batchSize)546 void gesvdaStridedBatched<double>(
547     cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, double *A, int lda, long long int strideA,
548     double *S, long long int strideS, double *U, int ldu, long long int strideU, double *V, int ldv, long long int strideV,
549     double *work, int lwork, int *info, double *h_R_nrmF, int batchSize
550 ) {
551   TORCH_CUSOLVER_CHECK(cusolverDnDgesvdaStridedBatched(
552     handle, jobz, rank, m, n, A, lda, strideA, S, strideS, U, ldu, strideU, V, ldv, strideV, work, lwork, info, h_R_nrmF, batchSize
553   ));
554 }
555 
556 template<>
gesvdaStridedBatched(cusolverDnHandle_t handle,cusolverEigMode_t jobz,int rank,int m,int n,c10::complex<float> * A,int lda,long long int strideA,float * S,long long int strideS,c10::complex<float> * U,int ldu,long long int strideU,c10::complex<float> * V,int ldv,long long int strideV,c10::complex<float> * work,int lwork,int * info,double * h_R_nrmF,int batchSize)557 void gesvdaStridedBatched<c10::complex<float>>(
558     cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, c10::complex<float> *A, int lda, long long int strideA,
559     float *S, long long int strideS, c10::complex<float> *U, int ldu, long long int strideU,
560     c10::complex<float> *V, int ldv, long long int strideV,
561     c10::complex<float> *work, int lwork, int *info, double *h_R_nrmF, int batchSize
562 ) {
563   TORCH_CUSOLVER_CHECK(cusolverDnCgesvdaStridedBatched(
564     handle, jobz, rank, m, n,
565     reinterpret_cast<cuComplex*>(A),
566     lda, strideA, S, strideS,
567     reinterpret_cast<cuComplex*>(U),
568     ldu, strideU,
569     reinterpret_cast<cuComplex*>(V),
570     ldv, strideV,
571     reinterpret_cast<cuComplex*>(work),
572     lwork, info, h_R_nrmF, batchSize
573   ));
574 }
575 
576 template<>
gesvdaStridedBatched(cusolverDnHandle_t handle,cusolverEigMode_t jobz,int rank,int m,int n,c10::complex<double> * A,int lda,long long int strideA,double * S,long long int strideS,c10::complex<double> * U,int ldu,long long int strideU,c10::complex<double> * V,int ldv,long long int strideV,c10::complex<double> * work,int lwork,int * info,double * h_R_nrmF,int batchSize)577 void gesvdaStridedBatched<c10::complex<double>>(
578     cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, c10::complex<double> *A, int lda, long long int strideA,
579     double *S, long long int strideS, c10::complex<double> *U, int ldu, long long int strideU,
580     c10::complex<double> *V, int ldv, long long int strideV,
581     c10::complex<double> *work, int lwork, int *info, double *h_R_nrmF, int batchSize
582 ) {
583   TORCH_CUSOLVER_CHECK(cusolverDnZgesvdaStridedBatched(
584     handle, jobz, rank, m, n,
585     reinterpret_cast<cuDoubleComplex*>(A),
586     lda, strideA, S, strideS,
587     reinterpret_cast<cuDoubleComplex*>(U),
588     ldu, strideU,
589     reinterpret_cast<cuDoubleComplex*>(V),
590     ldv, strideV,
591     reinterpret_cast<cuDoubleComplex*>(work),
592     lwork, info, h_R_nrmF, batchSize
593   ));
594 }
595 #endif
596 
597 
598 template<>
potrf(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,float * A,int lda,float * work,int lwork,int * info)599 void potrf<float>(
600   cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, float* A, int lda, float* work, int lwork, int* info
601 ) {
602   TORCH_CUSOLVER_CHECK(cusolverDnSpotrf(
603     handle, uplo, n, A, lda, work, lwork, info));
604 }
605 
606 template<>
potrf(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,double * A,int lda,double * work,int lwork,int * info)607 void potrf<double>(
608   cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, double* A, int lda, double* work, int lwork, int* info
609 ) {
610   TORCH_CUSOLVER_CHECK(cusolverDnDpotrf(
611     handle, uplo, n, A, lda, work, lwork, info));
612 }
613 
614 template<>
potrf(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,c10::complex<float> * A,int lda,c10::complex<float> * work,int lwork,int * info)615 void potrf<c10::complex<float>>(
616   cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, c10::complex<float>* A, int lda, c10::complex<float>* work, int lwork, int* info
617 ) {
618   TORCH_CUSOLVER_CHECK(cusolverDnCpotrf(
619     handle,
620     uplo,
621     n,
622     reinterpret_cast<cuComplex*>(A),
623     lda,
624     reinterpret_cast<cuComplex*>(work),
625     lwork,
626     info));
627 }
628 
629 template<>
potrf(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,c10::complex<double> * A,int lda,c10::complex<double> * work,int lwork,int * info)630 void potrf<c10::complex<double>>(
631   cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, c10::complex<double>* A, int lda, c10::complex<double>* work, int lwork, int* info
632 ) {
633   TORCH_CUSOLVER_CHECK(cusolverDnZpotrf(
634     handle,
635     uplo,
636     n,
637     reinterpret_cast<cuDoubleComplex*>(A),
638     lda,
639     reinterpret_cast<cuDoubleComplex*>(work),
640     lwork,
641     info));
642 }
643 
644 
645 template<>
potrf_buffersize(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,float * A,int lda,int * lwork)646 void potrf_buffersize<float>(
647   cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, float* A, int lda, int* lwork
648 ) {
649   TORCH_CUSOLVER_CHECK(cusolverDnSpotrf_bufferSize(handle, uplo, n, A, lda, lwork));
650 }
651 
652 template<>
potrf_buffersize(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,double * A,int lda,int * lwork)653 void potrf_buffersize<double>(
654   cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, double* A, int lda, int* lwork
655 ) {
656   TORCH_CUSOLVER_CHECK(cusolverDnDpotrf_bufferSize(handle, uplo, n, A, lda, lwork));
657 }
658 
659 template<>
potrf_buffersize(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,c10::complex<float> * A,int lda,int * lwork)660 void potrf_buffersize<c10::complex<float>>(
661   cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, c10::complex<float>* A, int lda, int* lwork
662 ) {
663   TORCH_CUSOLVER_CHECK(cusolverDnCpotrf_bufferSize(
664     handle, uplo, n,
665     reinterpret_cast<cuComplex*>(A),
666     lda, lwork));
667 }
668 
669 template<>
potrf_buffersize(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,c10::complex<double> * A,int lda,int * lwork)670 void potrf_buffersize<c10::complex<double>>(
671   cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, c10::complex<double>* A, int lda, int* lwork
672 ) {
673   TORCH_CUSOLVER_CHECK(cusolverDnZpotrf_bufferSize(
674     handle, uplo, n,
675     reinterpret_cast<cuDoubleComplex*>(A),
676     lda, lwork));
677 }
678 
679 
680 template<>
potrfBatched(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,float ** A,int lda,int * info,int batchSize)681 void potrfBatched<float>(
682   cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, float** A, int lda, int* info, int batchSize
683 ) {
684   TORCH_CUSOLVER_CHECK(cusolverDnSpotrfBatched(handle, uplo, n, A, lda, info, batchSize));
685 }
686 
687 template<>
potrfBatched(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,double ** A,int lda,int * info,int batchSize)688 void potrfBatched<double>(
689   cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, double** A, int lda, int* info, int batchSize
690 ) {
691   TORCH_CUSOLVER_CHECK(cusolverDnDpotrfBatched(handle, uplo, n, A, lda, info, batchSize));
692 }
693 
694 template<>
potrfBatched(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,c10::complex<float> ** A,int lda,int * info,int batchSize)695 void potrfBatched<c10::complex<float>>(
696   cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, c10::complex<float>** A, int lda, int* info, int batchSize
697 ) {
698   TORCH_CUSOLVER_CHECK(cusolverDnCpotrfBatched(
699     handle, uplo, n,
700     reinterpret_cast<cuComplex**>(A),
701     lda, info, batchSize));
702 }
703 
704 template<>
potrfBatched(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,c10::complex<double> ** A,int lda,int * info,int batchSize)705 void potrfBatched<c10::complex<double>>(
706   cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, c10::complex<double>** A, int lda, int* info, int batchSize
707 ) {
708   TORCH_CUSOLVER_CHECK(cusolverDnZpotrfBatched(
709     handle, uplo, n,
710     reinterpret_cast<cuDoubleComplex**>(A),
711     lda, info, batchSize));
712 }
713 
714 template <>
geqrf_bufferSize(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES (float))715 void geqrf_bufferSize<float>(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(float)) {
716   TORCH_CUSOLVER_CHECK(
717       cusolverDnSgeqrf_bufferSize(handle, m, n, A, lda, lwork));
718 }
719 
720 template <>
geqrf_bufferSize(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES (double))721 void geqrf_bufferSize<double>(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(double)) {
722   TORCH_CUSOLVER_CHECK(
723       cusolverDnDgeqrf_bufferSize(handle, m, n, A, lda, lwork));
724 }
725 
726 template <>
geqrf_bufferSize(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES (c10::complex<float>))727 void geqrf_bufferSize<c10::complex<float>>(
728     CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(c10::complex<float>)) {
729   TORCH_CUSOLVER_CHECK(cusolverDnCgeqrf_bufferSize(
730       handle, m, n, reinterpret_cast<cuComplex*>(A), lda, lwork));
731 }
732 
733 template <>
geqrf_bufferSize(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES (c10::complex<double>))734 void geqrf_bufferSize<c10::complex<double>>(
735     CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(c10::complex<double>)) {
736   TORCH_CUSOLVER_CHECK(cusolverDnZgeqrf_bufferSize(
737       handle, m, n, reinterpret_cast<cuDoubleComplex*>(A), lda, lwork));
738 }
739 
740 template <>
geqrf(CUDASOLVER_GEQRF_ARGTYPES (float))741 void geqrf<float>(CUDASOLVER_GEQRF_ARGTYPES(float)) {
742   TORCH_CUSOLVER_CHECK(
743       cusolverDnSgeqrf(handle, m, n, A, lda, tau, work, lwork, devInfo));
744 }
745 
746 template <>
geqrf(CUDASOLVER_GEQRF_ARGTYPES (double))747 void geqrf<double>(CUDASOLVER_GEQRF_ARGTYPES(double)) {
748   TORCH_CUSOLVER_CHECK(
749       cusolverDnDgeqrf(handle, m, n, A, lda, tau, work, lwork, devInfo));
750 }
751 
752 template <>
geqrf(CUDASOLVER_GEQRF_ARGTYPES (c10::complex<float>))753 void geqrf<c10::complex<float>>(
754     CUDASOLVER_GEQRF_ARGTYPES(c10::complex<float>)) {
755   TORCH_CUSOLVER_CHECK(cusolverDnCgeqrf(
756       handle,
757       m,
758       n,
759       reinterpret_cast<cuComplex*>(A),
760       lda,
761       reinterpret_cast<cuComplex*>(tau),
762       reinterpret_cast<cuComplex*>(work),
763       lwork,
764       devInfo));
765 }
766 
767 template <>
geqrf(CUDASOLVER_GEQRF_ARGTYPES (c10::complex<double>))768 void geqrf<c10::complex<double>>(
769     CUDASOLVER_GEQRF_ARGTYPES(c10::complex<double>)) {
770   TORCH_CUSOLVER_CHECK(cusolverDnZgeqrf(
771       handle,
772       m,
773       n,
774       reinterpret_cast<cuDoubleComplex*>(A),
775       lda,
776       reinterpret_cast<cuDoubleComplex*>(tau),
777       reinterpret_cast<cuDoubleComplex*>(work),
778       lwork,
779       devInfo));
780 }
781 
782 template<>
potrs(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,int nrhs,const float * A,int lda,float * B,int ldb,int * devInfo)783 void potrs<float>(
784     cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, const float *A, int lda, float *B, int ldb, int *devInfo
785 ) {
786   TORCH_CUSOLVER_CHECK(cusolverDnSpotrs(handle, uplo, n, nrhs, A, lda, B, ldb, devInfo));
787 }
788 
789 template<>
potrs(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,int nrhs,const double * A,int lda,double * B,int ldb,int * devInfo)790 void potrs<double>(
791   cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, const double *A, int lda, double *B, int ldb, int *devInfo
792 ) {
793   TORCH_CUSOLVER_CHECK(cusolverDnDpotrs(handle, uplo, n, nrhs, A, lda, B, ldb, devInfo));
794 }
795 
796 template<>
potrs(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,int nrhs,const c10::complex<float> * A,int lda,c10::complex<float> * B,int ldb,int * devInfo)797 void potrs<c10::complex<float>>(
798   cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, const c10::complex<float> *A, int lda, c10::complex<float> *B, int ldb, int *devInfo
799 ) {
800   TORCH_CUSOLVER_CHECK(cusolverDnCpotrs(
801     handle, uplo, n, nrhs,
802     reinterpret_cast<const cuComplex*>(A),
803     lda,
804     reinterpret_cast<cuComplex*>(B),
805     ldb, devInfo));
806 }
807 
808 template<>
potrs(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,int nrhs,const c10::complex<double> * A,int lda,c10::complex<double> * B,int ldb,int * devInfo)809 void potrs<c10::complex<double>>(
810   cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, const c10::complex<double> *A, int lda, c10::complex<double> *B, int ldb, int *devInfo
811 ) {
812   TORCH_CUSOLVER_CHECK(cusolverDnZpotrs(
813     handle, uplo, n, nrhs,
814     reinterpret_cast<const cuDoubleComplex*>(A),
815     lda,
816     reinterpret_cast<cuDoubleComplex*>(B),
817     ldb, devInfo));
818 }
819 
820 template<>
potrsBatched(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,int nrhs,float * Aarray[],int lda,float * Barray[],int ldb,int * info,int batchSize)821 void potrsBatched<float>(
822   cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, float *Aarray[], int lda, float *Barray[], int ldb, int *info, int batchSize
823 ) {
824   TORCH_CUSOLVER_CHECK(cusolverDnSpotrsBatched(handle, uplo, n, nrhs, Aarray, lda, Barray, ldb, info, batchSize));
825 }
826 
827 template<>
potrsBatched(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,int nrhs,double * Aarray[],int lda,double * Barray[],int ldb,int * info,int batchSize)828 void potrsBatched<double>(
829   cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, double *Aarray[], int lda, double *Barray[], int ldb, int *info, int batchSize
830 ) {
831   TORCH_CUSOLVER_CHECK(cusolverDnDpotrsBatched(handle, uplo, n, nrhs, Aarray, lda, Barray, ldb, info, batchSize));
832 }
833 
834 template<>
potrsBatched(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,int nrhs,c10::complex<float> * Aarray[],int lda,c10::complex<float> * Barray[],int ldb,int * info,int batchSize)835 void potrsBatched<c10::complex<float>>(
836   cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, c10::complex<float> *Aarray[], int lda, c10::complex<float> *Barray[], int ldb, int *info, int batchSize
837 ) {
838   TORCH_CUSOLVER_CHECK(cusolverDnCpotrsBatched(
839     handle, uplo, n, nrhs,
840     reinterpret_cast<cuComplex**>(Aarray),
841     lda,
842     reinterpret_cast<cuComplex**>(Barray),
843     ldb, info, batchSize));
844 }
845 
846 template<>
potrsBatched(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,int nrhs,c10::complex<double> * Aarray[],int lda,c10::complex<double> * Barray[],int ldb,int * info,int batchSize)847 void potrsBatched<c10::complex<double>>(
848   cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, c10::complex<double> *Aarray[], int lda, c10::complex<double> *Barray[], int ldb, int *info, int batchSize
849 ) {
850   TORCH_CUSOLVER_CHECK(cusolverDnZpotrsBatched(
851     handle, uplo, n, nrhs,
852     reinterpret_cast<cuDoubleComplex**>(Aarray),
853     lda,
854     reinterpret_cast<cuDoubleComplex**>(Barray),
855     ldb, info, batchSize));
856 }
857 
858 
859 template <>
orgqr_buffersize(cusolverDnHandle_t handle,int m,int n,int k,const float * A,int lda,const float * tau,int * lwork)860 void orgqr_buffersize<float>(
861     cusolverDnHandle_t handle,
862     int m, int n, int k,
863     const float* A, int lda,
864     const float* tau, int* lwork) {
865   TORCH_CUSOLVER_CHECK(
866       cusolverDnSorgqr_bufferSize(handle, m, n, k, A, lda, tau, lwork));
867 }
868 
869 template <>
orgqr_buffersize(cusolverDnHandle_t handle,int m,int n,int k,const double * A,int lda,const double * tau,int * lwork)870 void orgqr_buffersize<double>(
871     cusolverDnHandle_t handle,
872     int m, int n, int k,
873     const double* A, int lda,
874     const double* tau, int* lwork) {
875   TORCH_CUSOLVER_CHECK(
876       cusolverDnDorgqr_bufferSize(handle, m, n, k, A, lda, tau, lwork));
877 }
878 
879 template <>
orgqr_buffersize(cusolverDnHandle_t handle,int m,int n,int k,const c10::complex<float> * A,int lda,const c10::complex<float> * tau,int * lwork)880 void orgqr_buffersize<c10::complex<float>>(
881     cusolverDnHandle_t handle,
882     int m, int n, int k,
883     const c10::complex<float>* A, int lda,
884     const c10::complex<float>* tau, int* lwork) {
885   TORCH_CUSOLVER_CHECK(cusolverDnCungqr_bufferSize(
886       handle,
887       m, n, k,
888       reinterpret_cast<const cuComplex*>(A), lda,
889       reinterpret_cast<const cuComplex*>(tau), lwork));
890 }
891 
892 template <>
orgqr_buffersize(cusolverDnHandle_t handle,int m,int n,int k,const c10::complex<double> * A,int lda,const c10::complex<double> * tau,int * lwork)893 void orgqr_buffersize<c10::complex<double>>(
894     cusolverDnHandle_t handle,
895     int m, int n, int k,
896     const c10::complex<double>* A, int lda,
897     const c10::complex<double>* tau, int* lwork) {
898   TORCH_CUSOLVER_CHECK(cusolverDnZungqr_bufferSize(
899       handle,
900       m, n, k,
901       reinterpret_cast<const cuDoubleComplex*>(A), lda,
902       reinterpret_cast<const cuDoubleComplex*>(tau), lwork));
903 }
904 
905 template <>
orgqr(cusolverDnHandle_t handle,int m,int n,int k,float * A,int lda,const float * tau,float * work,int lwork,int * devInfo)906 void orgqr<float>(
907     cusolverDnHandle_t handle,
908     int m, int n, int k,
909     float* A, int lda,
910     const float* tau,
911     float* work, int lwork,
912     int* devInfo) {
913   TORCH_CUSOLVER_CHECK(
914       cusolverDnSorgqr(handle, m, n, k, A, lda, tau, work, lwork, devInfo));
915 }
916 
917 template <>
orgqr(cusolverDnHandle_t handle,int m,int n,int k,double * A,int lda,const double * tau,double * work,int lwork,int * devInfo)918 void orgqr<double>(
919     cusolverDnHandle_t handle,
920     int m, int n, int k,
921     double* A, int lda,
922     const double* tau,
923     double* work, int lwork,
924     int* devInfo) {
925   TORCH_CUSOLVER_CHECK(
926       cusolverDnDorgqr(handle, m, n, k, A, lda, tau, work, lwork, devInfo));
927 }
928 
929 template <>
orgqr(cusolverDnHandle_t handle,int m,int n,int k,c10::complex<float> * A,int lda,const c10::complex<float> * tau,c10::complex<float> * work,int lwork,int * devInfo)930 void orgqr<c10::complex<float>>(
931     cusolverDnHandle_t handle,
932     int m, int n, int k,
933     c10::complex<float>* A, int lda,
934     const c10::complex<float>* tau,
935     c10::complex<float>* work, int lwork,
936     int* devInfo) {
937   TORCH_CUSOLVER_CHECK(cusolverDnCungqr(
938       handle,
939       m, n, k,
940       reinterpret_cast<cuComplex*>(A), lda,
941       reinterpret_cast<const cuComplex*>(tau),
942       reinterpret_cast<cuComplex*>(work), lwork,
943       devInfo));
944 }
945 
946 template <>
orgqr(cusolverDnHandle_t handle,int m,int n,int k,c10::complex<double> * A,int lda,const c10::complex<double> * tau,c10::complex<double> * work,int lwork,int * devInfo)947 void orgqr<c10::complex<double>>(
948     cusolverDnHandle_t handle,
949     int m, int n, int k,
950     c10::complex<double>* A, int lda,
951     const c10::complex<double>* tau,
952     c10::complex<double>* work, int lwork,
953     int* devInfo) {
954   TORCH_CUSOLVER_CHECK(cusolverDnZungqr(
955       handle,
956       m, n, k,
957       reinterpret_cast<cuDoubleComplex*>(A), lda,
958       reinterpret_cast<const cuDoubleComplex*>(tau),
959       reinterpret_cast<cuDoubleComplex*>(work), lwork,
960       devInfo));
961 }
962 
963 template <>
ormqr_bufferSize(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES (float))964 void ormqr_bufferSize<float>(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(float)) {
965   TORCH_CUSOLVER_CHECK(
966       cusolverDnSormqr_bufferSize(handle, side, trans, m, n, k, A, lda, tau, C, ldc, lwork));
967 }
968 
969 template <>
ormqr_bufferSize(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES (double))970 void ormqr_bufferSize<double>(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(double)) {
971   TORCH_CUSOLVER_CHECK(
972       cusolverDnDormqr_bufferSize(handle, side, trans, m, n, k, A, lda, tau, C, ldc, lwork));
973 }
974 
975 template <>
ormqr_bufferSize(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES (c10::complex<float>))976 void ormqr_bufferSize<c10::complex<float>>(
977     CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(c10::complex<float>)) {
978   TORCH_CUSOLVER_CHECK(cusolverDnCunmqr_bufferSize(
979       handle, side, trans,
980       m, n, k,
981       reinterpret_cast<const cuComplex*>(A), lda,
982       reinterpret_cast<const cuComplex*>(tau),
983       reinterpret_cast<const cuComplex*>(C), ldc,
984       lwork));
985 }
986 
987 template <>
ormqr_bufferSize(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES (c10::complex<double>))988 void ormqr_bufferSize<c10::complex<double>>(
989     CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(c10::complex<double>)) {
990   TORCH_CUSOLVER_CHECK(cusolverDnZunmqr_bufferSize(
991       handle, side, trans,
992       m, n, k,
993       reinterpret_cast<const cuDoubleComplex*>(A), lda,
994       reinterpret_cast<const cuDoubleComplex*>(tau),
995       reinterpret_cast<const cuDoubleComplex*>(C), ldc,
996       lwork));
997 }
998 
999 template <>
ormqr(CUDASOLVER_ORMQR_ARGTYPES (float))1000 void ormqr<float>(CUDASOLVER_ORMQR_ARGTYPES(float)) {
1001   TORCH_CUSOLVER_CHECK(
1002       cusolverDnSormqr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, work, lwork, devInfo));
1003 }
1004 
1005 template <>
ormqr(CUDASOLVER_ORMQR_ARGTYPES (double))1006 void ormqr<double>(CUDASOLVER_ORMQR_ARGTYPES(double)) {
1007   TORCH_CUSOLVER_CHECK(
1008       cusolverDnDormqr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, work, lwork, devInfo));
1009 }
1010 
1011 template <>
ormqr(CUDASOLVER_ORMQR_ARGTYPES (c10::complex<float>))1012 void ormqr<c10::complex<float>>(CUDASOLVER_ORMQR_ARGTYPES(c10::complex<float>)) {
1013   TORCH_CUSOLVER_CHECK(cusolverDnCunmqr(
1014       handle, side, trans,
1015       m, n, k,
1016       reinterpret_cast<const cuComplex*>(A), lda,
1017       reinterpret_cast<const cuComplex*>(tau),
1018       reinterpret_cast<cuComplex*>(C), ldc,
1019       reinterpret_cast<cuComplex*>(work), lwork,
1020       devInfo));
1021 }
1022 
1023 template <>
ormqr(CUDASOLVER_ORMQR_ARGTYPES (c10::complex<double>))1024 void ormqr<c10::complex<double>>(CUDASOLVER_ORMQR_ARGTYPES(c10::complex<double>)) {
1025   TORCH_CUSOLVER_CHECK(cusolverDnZunmqr(
1026       handle, side, trans,
1027       m, n, k,
1028       reinterpret_cast<const cuDoubleComplex*>(A), lda,
1029       reinterpret_cast<const cuDoubleComplex*>(tau),
1030       reinterpret_cast<cuDoubleComplex*>(C), ldc,
1031       reinterpret_cast<cuDoubleComplex*>(work), lwork,
1032       devInfo));
1033 }
1034 
1035 #ifdef USE_CUSOLVER_64_BIT
1036 
get_cusolver_datatype()1037 template<> cudaDataType get_cusolver_datatype<float>() { return CUDA_R_32F; }
get_cusolver_datatype()1038 template<> cudaDataType get_cusolver_datatype<double>() { return CUDA_R_64F; }
get_cusolver_datatype()1039 template<> cudaDataType get_cusolver_datatype<c10::complex<float>>() { return CUDA_C_32F; }
get_cusolver_datatype()1040 template<> cudaDataType get_cusolver_datatype<c10::complex<double>>() { return CUDA_C_64F; }
1041 
xpotrf_buffersize(cusolverDnHandle_t handle,cusolverDnParams_t params,cublasFillMode_t uplo,int64_t n,cudaDataType dataTypeA,const void * A,int64_t lda,cudaDataType computeType,size_t * workspaceInBytesOnDevice,size_t * workspaceInBytesOnHost)1042 void xpotrf_buffersize(
1043     cusolverDnHandle_t handle, cusolverDnParams_t params, cublasFillMode_t uplo, int64_t n, cudaDataType dataTypeA, const void *A,
1044     int64_t lda, cudaDataType computeType, size_t *workspaceInBytesOnDevice, size_t *workspaceInBytesOnHost) {
1045   TORCH_CUSOLVER_CHECK(cusolverDnXpotrf_bufferSize(
1046     handle, params, uplo, n, dataTypeA, A, lda, computeType, workspaceInBytesOnDevice, workspaceInBytesOnHost
1047   ));
1048 }
1049 
xpotrf(cusolverDnHandle_t handle,cusolverDnParams_t params,cublasFillMode_t uplo,int64_t n,cudaDataType dataTypeA,void * A,int64_t lda,cudaDataType computeType,void * bufferOnDevice,size_t workspaceInBytesOnDevice,void * bufferOnHost,size_t workspaceInBytesOnHost,int * info)1050 void xpotrf(
1051     cusolverDnHandle_t handle, cusolverDnParams_t params, cublasFillMode_t uplo, int64_t n, cudaDataType dataTypeA, void *A,
1052     int64_t lda, cudaDataType computeType, void *bufferOnDevice, size_t workspaceInBytesOnDevice, void *bufferOnHost, size_t workspaceInBytesOnHost,
1053     int *info) {
1054   TORCH_CUSOLVER_CHECK(cusolverDnXpotrf(
1055     handle, params, uplo, n, dataTypeA, A, lda, computeType, bufferOnDevice, workspaceInBytesOnDevice, bufferOnHost, workspaceInBytesOnHost, info
1056   ));
1057 }
1058 #endif // USE_CUSOLVER_64_BIT
1059 
1060 template <>
syevd_bufferSize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,const float * A,int lda,const float * W,int * lwork)1061 void syevd_bufferSize<float>(
1062     cusolverDnHandle_t handle,
1063     cusolverEigMode_t jobz,
1064     cublasFillMode_t uplo,
1065     int n,
1066     const float* A,
1067     int lda,
1068     const float* W,
1069     int* lwork) {
1070   TORCH_CUSOLVER_CHECK(
1071       cusolverDnSsyevd_bufferSize(handle, jobz, uplo, n, A, lda, W, lwork));
1072 }
1073 
1074 template <>
syevd_bufferSize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,const double * A,int lda,const double * W,int * lwork)1075 void syevd_bufferSize<double>(
1076     cusolverDnHandle_t handle,
1077     cusolverEigMode_t jobz,
1078     cublasFillMode_t uplo,
1079     int n,
1080     const double* A,
1081     int lda,
1082     const double* W,
1083     int* lwork) {
1084   TORCH_CUSOLVER_CHECK(
1085       cusolverDnDsyevd_bufferSize(handle, jobz, uplo, n, A, lda, W, lwork));
1086 }
1087 
1088 template <>
syevd_bufferSize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,const c10::complex<float> * A,int lda,const float * W,int * lwork)1089 void syevd_bufferSize<c10::complex<float>, float>(
1090     cusolverDnHandle_t handle,
1091     cusolverEigMode_t jobz,
1092     cublasFillMode_t uplo,
1093     int n,
1094     const c10::complex<float>* A,
1095     int lda,
1096     const float* W,
1097     int* lwork) {
1098   TORCH_CUSOLVER_CHECK(cusolverDnCheevd_bufferSize(
1099       handle,
1100       jobz,
1101       uplo,
1102       n,
1103       reinterpret_cast<const cuComplex*>(A),
1104       lda,
1105       W,
1106       lwork));
1107 }
1108 
1109 template <>
syevd_bufferSize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,const c10::complex<double> * A,int lda,const double * W,int * lwork)1110 void syevd_bufferSize<c10::complex<double>, double>(
1111     cusolverDnHandle_t handle,
1112     cusolverEigMode_t jobz,
1113     cublasFillMode_t uplo,
1114     int n,
1115     const c10::complex<double>* A,
1116     int lda,
1117     const double* W,
1118     int* lwork) {
1119   TORCH_CUSOLVER_CHECK(cusolverDnZheevd_bufferSize(
1120       handle,
1121       jobz,
1122       uplo,
1123       n,
1124       reinterpret_cast<const cuDoubleComplex*>(A),
1125       lda,
1126       W,
1127       lwork));
1128 }
1129 
1130 template <>
syevd(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,float * A,int lda,float * W,float * work,int lwork,int * info)1131 void syevd<float>(
1132     cusolverDnHandle_t handle,
1133     cusolverEigMode_t jobz,
1134     cublasFillMode_t uplo,
1135     int n,
1136     float* A,
1137     int lda,
1138     float* W,
1139     float* work,
1140     int lwork,
1141     int* info) {
1142   TORCH_CUSOLVER_CHECK(
1143       cusolverDnSsyevd(handle, jobz, uplo, n, A, lda, W, work, lwork, info));
1144 }
1145 
1146 template <>
syevd(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,double * A,int lda,double * W,double * work,int lwork,int * info)1147 void syevd<double>(
1148     cusolverDnHandle_t handle,
1149     cusolverEigMode_t jobz,
1150     cublasFillMode_t uplo,
1151     int n,
1152     double* A,
1153     int lda,
1154     double* W,
1155     double* work,
1156     int lwork,
1157     int* info) {
1158   TORCH_CUSOLVER_CHECK(
1159       cusolverDnDsyevd(handle, jobz, uplo, n, A, lda, W, work, lwork, info));
1160 }
1161 
1162 template <>
syevd(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,c10::complex<float> * A,int lda,float * W,c10::complex<float> * work,int lwork,int * info)1163 void syevd<c10::complex<float>, float>(
1164     cusolverDnHandle_t handle,
1165     cusolverEigMode_t jobz,
1166     cublasFillMode_t uplo,
1167     int n,
1168     c10::complex<float>* A,
1169     int lda,
1170     float* W,
1171     c10::complex<float>* work,
1172     int lwork,
1173     int* info) {
1174   TORCH_CUSOLVER_CHECK(cusolverDnCheevd(
1175       handle,
1176       jobz,
1177       uplo,
1178       n,
1179       reinterpret_cast<cuComplex*>(A),
1180       lda,
1181       W,
1182       reinterpret_cast<cuComplex*>(work),
1183       lwork,
1184       info));
1185 }
1186 
1187 template <>
syevd(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,c10::complex<double> * A,int lda,double * W,c10::complex<double> * work,int lwork,int * info)1188 void syevd<c10::complex<double>, double>(
1189     cusolverDnHandle_t handle,
1190     cusolverEigMode_t jobz,
1191     cublasFillMode_t uplo,
1192     int n,
1193     c10::complex<double>* A,
1194     int lda,
1195     double* W,
1196     c10::complex<double>* work,
1197     int lwork,
1198     int* info) {
1199   TORCH_CUSOLVER_CHECK(cusolverDnZheevd(
1200       handle,
1201       jobz,
1202       uplo,
1203       n,
1204       reinterpret_cast<cuDoubleComplex*>(A),
1205       lda,
1206       W,
1207       reinterpret_cast<cuDoubleComplex*>(work),
1208       lwork,
1209       info));
1210 }
1211 
1212 template <>
syevj_bufferSize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,const float * A,int lda,const float * W,int * lwork,syevjInfo_t params)1213 void syevj_bufferSize<float>(
1214     cusolverDnHandle_t handle,
1215     cusolverEigMode_t jobz,
1216     cublasFillMode_t uplo,
1217     int n,
1218     const float* A,
1219     int lda,
1220     const float* W,
1221     int* lwork,
1222     syevjInfo_t params) {
1223   TORCH_CUSOLVER_CHECK(cusolverDnSsyevj_bufferSize(
1224       handle, jobz, uplo, n, A, lda, W, lwork, params));
1225 }
1226 
1227 template <>
syevj_bufferSize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,const double * A,int lda,const double * W,int * lwork,syevjInfo_t params)1228 void syevj_bufferSize<double>(
1229     cusolverDnHandle_t handle,
1230     cusolverEigMode_t jobz,
1231     cublasFillMode_t uplo,
1232     int n,
1233     const double* A,
1234     int lda,
1235     const double* W,
1236     int* lwork,
1237     syevjInfo_t params) {
1238   TORCH_CUSOLVER_CHECK(cusolverDnDsyevj_bufferSize(
1239       handle, jobz, uplo, n, A, lda, W, lwork, params));
1240 }
1241 
1242 template <>
syevj_bufferSize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,const c10::complex<float> * A,int lda,const float * W,int * lwork,syevjInfo_t params)1243 void syevj_bufferSize<c10::complex<float>, float>(
1244     cusolverDnHandle_t handle,
1245     cusolverEigMode_t jobz,
1246     cublasFillMode_t uplo,
1247     int n,
1248     const c10::complex<float>* A,
1249     int lda,
1250     const float* W,
1251     int* lwork,
1252     syevjInfo_t params) {
1253   TORCH_CUSOLVER_CHECK(cusolverDnCheevj_bufferSize(
1254       handle,
1255       jobz,
1256       uplo,
1257       n,
1258       reinterpret_cast<const cuComplex*>(A),
1259       lda,
1260       W,
1261       lwork,
1262       params));
1263 }
1264 
1265 template <>
syevj_bufferSize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,const c10::complex<double> * A,int lda,const double * W,int * lwork,syevjInfo_t params)1266 void syevj_bufferSize<c10::complex<double>, double>(
1267     cusolverDnHandle_t handle,
1268     cusolverEigMode_t jobz,
1269     cublasFillMode_t uplo,
1270     int n,
1271     const c10::complex<double>* A,
1272     int lda,
1273     const double* W,
1274     int* lwork,
1275     syevjInfo_t params) {
1276   TORCH_CUSOLVER_CHECK(cusolverDnZheevj_bufferSize(
1277       handle,
1278       jobz,
1279       uplo,
1280       n,
1281       reinterpret_cast<const cuDoubleComplex*>(A),
1282       lda,
1283       W,
1284       lwork,
1285       params));
1286 }
1287 
1288 template <>
syevj(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,float * A,int lda,float * W,float * work,int lwork,int * info,syevjInfo_t params)1289 void syevj<float>(
1290     cusolverDnHandle_t handle,
1291     cusolverEigMode_t jobz,
1292     cublasFillMode_t uplo,
1293     int n,
1294     float* A,
1295     int lda,
1296     float* W,
1297     float* work,
1298     int lwork,
1299     int* info,
1300     syevjInfo_t params) {
1301   TORCH_CUSOLVER_CHECK(cusolverDnSsyevj(
1302       handle, jobz, uplo, n, A, lda, W, work, lwork, info, params));
1303 }
1304 
1305 template <>
syevj(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,double * A,int lda,double * W,double * work,int lwork,int * info,syevjInfo_t params)1306 void syevj<double>(
1307     cusolverDnHandle_t handle,
1308     cusolverEigMode_t jobz,
1309     cublasFillMode_t uplo,
1310     int n,
1311     double* A,
1312     int lda,
1313     double* W,
1314     double* work,
1315     int lwork,
1316     int* info,
1317     syevjInfo_t params) {
1318   TORCH_CUSOLVER_CHECK(cusolverDnDsyevj(
1319       handle, jobz, uplo, n, A, lda, W, work, lwork, info, params));
1320 }
1321 
1322 template <>
syevj(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,c10::complex<float> * A,int lda,float * W,c10::complex<float> * work,int lwork,int * info,syevjInfo_t params)1323 void syevj<c10::complex<float>, float>(
1324     cusolverDnHandle_t handle,
1325     cusolverEigMode_t jobz,
1326     cublasFillMode_t uplo,
1327     int n,
1328     c10::complex<float>* A,
1329     int lda,
1330     float* W,
1331     c10::complex<float>* work,
1332     int lwork,
1333     int* info,
1334     syevjInfo_t params) {
1335   TORCH_CUSOLVER_CHECK(cusolverDnCheevj(
1336       handle,
1337       jobz,
1338       uplo,
1339       n,
1340       reinterpret_cast<cuComplex*>(A),
1341       lda,
1342       W,
1343       reinterpret_cast<cuComplex*>(work),
1344       lwork,
1345       info,
1346       params));
1347 }
1348 
1349 template <>
syevj(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,c10::complex<double> * A,int lda,double * W,c10::complex<double> * work,int lwork,int * info,syevjInfo_t params)1350 void syevj<c10::complex<double>, double>(
1351     cusolverDnHandle_t handle,
1352     cusolverEigMode_t jobz,
1353     cublasFillMode_t uplo,
1354     int n,
1355     c10::complex<double>* A,
1356     int lda,
1357     double* W,
1358     c10::complex<double>* work,
1359     int lwork,
1360     int* info,
1361     syevjInfo_t params) {
1362   TORCH_CUSOLVER_CHECK(cusolverDnZheevj(
1363       handle,
1364       jobz,
1365       uplo,
1366       n,
1367       reinterpret_cast<cuDoubleComplex*>(A),
1368       lda,
1369       W,
1370       reinterpret_cast<cuDoubleComplex*>(work),
1371       lwork,
1372       info,
1373       params));
1374 }
1375 
1376 template <>
syevjBatched_bufferSize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,const float * A,int lda,const float * W,int * lwork,syevjInfo_t params,int batchsize)1377 void syevjBatched_bufferSize<float>(
1378     cusolverDnHandle_t handle,
1379     cusolverEigMode_t jobz,
1380     cublasFillMode_t uplo,
1381     int n,
1382     const float* A,
1383     int lda,
1384     const float* W,
1385     int* lwork,
1386     syevjInfo_t params,
1387     int batchsize) {
1388   TORCH_CUSOLVER_CHECK(cusolverDnSsyevjBatched_bufferSize(
1389       handle, jobz, uplo, n, A, lda, W, lwork, params, batchsize));
1390 }
1391 
1392 template <>
syevjBatched_bufferSize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,const double * A,int lda,const double * W,int * lwork,syevjInfo_t params,int batchsize)1393 void syevjBatched_bufferSize<double>(
1394     cusolverDnHandle_t handle,
1395     cusolverEigMode_t jobz,
1396     cublasFillMode_t uplo,
1397     int n,
1398     const double* A,
1399     int lda,
1400     const double* W,
1401     int* lwork,
1402     syevjInfo_t params,
1403     int batchsize) {
1404   TORCH_CUSOLVER_CHECK(cusolverDnDsyevjBatched_bufferSize(
1405       handle, jobz, uplo, n, A, lda, W, lwork, params, batchsize));
1406 }
1407 
1408 template <>
syevjBatched_bufferSize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,const c10::complex<float> * A,int lda,const float * W,int * lwork,syevjInfo_t params,int batchsize)1409 void syevjBatched_bufferSize<c10::complex<float>, float>(
1410     cusolverDnHandle_t handle,
1411     cusolverEigMode_t jobz,
1412     cublasFillMode_t uplo,
1413     int n,
1414     const c10::complex<float>* A,
1415     int lda,
1416     const float* W,
1417     int* lwork,
1418     syevjInfo_t params,
1419     int batchsize) {
1420   TORCH_CUSOLVER_CHECK(cusolverDnCheevjBatched_bufferSize(
1421       handle,
1422       jobz,
1423       uplo,
1424       n,
1425       reinterpret_cast<const cuComplex*>(A),
1426       lda,
1427       W,
1428       lwork,
1429       params,
1430       batchsize));
1431 }
1432 
1433 template <>
syevjBatched_bufferSize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,const c10::complex<double> * A,int lda,const double * W,int * lwork,syevjInfo_t params,int batchsize)1434 void syevjBatched_bufferSize<c10::complex<double>, double>(
1435     cusolverDnHandle_t handle,
1436     cusolverEigMode_t jobz,
1437     cublasFillMode_t uplo,
1438     int n,
1439     const c10::complex<double>* A,
1440     int lda,
1441     const double* W,
1442     int* lwork,
1443     syevjInfo_t params,
1444     int batchsize) {
1445   TORCH_CUSOLVER_CHECK(cusolverDnZheevjBatched_bufferSize(
1446       handle,
1447       jobz,
1448       uplo,
1449       n,
1450       reinterpret_cast<const cuDoubleComplex*>(A),
1451       lda,
1452       W,
1453       lwork,
1454       params,
1455       batchsize));
1456 }
1457 
1458 template <>
syevjBatched(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,float * A,int lda,float * W,float * work,int lwork,int * info,syevjInfo_t params,int batchsize)1459 void syevjBatched<float>(
1460     cusolverDnHandle_t handle,
1461     cusolverEigMode_t jobz,
1462     cublasFillMode_t uplo,
1463     int n,
1464     float* A,
1465     int lda,
1466     float* W,
1467     float* work,
1468     int lwork,
1469     int* info,
1470     syevjInfo_t params,
1471     int batchsize) {
1472   TORCH_CUSOLVER_CHECK(cusolverDnSsyevjBatched(
1473       handle, jobz, uplo, n, A, lda, W, work, lwork, info, params, batchsize));
1474 }
1475 
1476 template <>
syevjBatched(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,double * A,int lda,double * W,double * work,int lwork,int * info,syevjInfo_t params,int batchsize)1477 void syevjBatched<double>(
1478     cusolverDnHandle_t handle,
1479     cusolverEigMode_t jobz,
1480     cublasFillMode_t uplo,
1481     int n,
1482     double* A,
1483     int lda,
1484     double* W,
1485     double* work,
1486     int lwork,
1487     int* info,
1488     syevjInfo_t params,
1489     int batchsize) {
1490   TORCH_CUSOLVER_CHECK(cusolverDnDsyevjBatched(
1491       handle, jobz, uplo, n, A, lda, W, work, lwork, info, params, batchsize));
1492 }
1493 
1494 template <>
syevjBatched(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,c10::complex<float> * A,int lda,float * W,c10::complex<float> * work,int lwork,int * info,syevjInfo_t params,int batchsize)1495 void syevjBatched<c10::complex<float>, float>(
1496     cusolverDnHandle_t handle,
1497     cusolverEigMode_t jobz,
1498     cublasFillMode_t uplo,
1499     int n,
1500     c10::complex<float>* A,
1501     int lda,
1502     float* W,
1503     c10::complex<float>* work,
1504     int lwork,
1505     int* info,
1506     syevjInfo_t params,
1507     int batchsize) {
1508   TORCH_CUSOLVER_CHECK(cusolverDnCheevjBatched(
1509       handle,
1510       jobz,
1511       uplo,
1512       n,
1513       reinterpret_cast<cuComplex*>(A),
1514       lda,
1515       W,
1516       reinterpret_cast<cuComplex*>(work),
1517       lwork,
1518       info,
1519       params,
1520       batchsize));
1521 }
1522 
1523 template <>
syevjBatched(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,c10::complex<double> * A,int lda,double * W,c10::complex<double> * work,int lwork,int * info,syevjInfo_t params,int batchsize)1524 void syevjBatched<c10::complex<double>, double>(
1525     cusolverDnHandle_t handle,
1526     cusolverEigMode_t jobz,
1527     cublasFillMode_t uplo,
1528     int n,
1529     c10::complex<double>* A,
1530     int lda,
1531     double* W,
1532     c10::complex<double>* work,
1533     int lwork,
1534     int* info,
1535     syevjInfo_t params,
1536     int batchsize) {
1537   TORCH_CUSOLVER_CHECK(cusolverDnZheevjBatched(
1538       handle,
1539       jobz,
1540       uplo,
1541       n,
1542       reinterpret_cast<cuDoubleComplex*>(A),
1543       lda,
1544       W,
1545       reinterpret_cast<cuDoubleComplex*>(work),
1546       lwork,
1547       info,
1548       params,
1549       batchsize));
1550 }
1551 
1552 #ifdef USE_CUSOLVER_64_BIT
1553 
xpotrs(cusolverDnHandle_t handle,cusolverDnParams_t params,cublasFillMode_t uplo,int64_t n,int64_t nrhs,cudaDataType dataTypeA,const void * A,int64_t lda,cudaDataType dataTypeB,void * B,int64_t ldb,int * info)1554 void xpotrs(
1555     cusolverDnHandle_t handle, cusolverDnParams_t params, cublasFillMode_t uplo, int64_t n, int64_t nrhs, cudaDataType dataTypeA, const void *A,
1556     int64_t lda, cudaDataType dataTypeB, void *B, int64_t ldb, int *info) {
1557   TORCH_CUSOLVER_CHECK(cusolverDnXpotrs(handle, params, uplo, n, nrhs, dataTypeA, A, lda, dataTypeB, B, ldb, info));
1558 }
1559 
1560 template <>
xgeqrf_bufferSize(CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES (float))1561 void xgeqrf_bufferSize<float>(CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(float)) {
1562   TORCH_CUSOLVER_CHECK(cusolverDnXgeqrf_bufferSize(
1563       handle,
1564       params,
1565       m,
1566       n,
1567       CUDA_R_32F,
1568       reinterpret_cast<const void*>(A),
1569       lda,
1570       CUDA_R_32F,
1571       reinterpret_cast<const void*>(tau),
1572       CUDA_R_32F,
1573       workspaceInBytesOnDevice,
1574       workspaceInBytesOnHost));
1575 }
1576 
1577 template <>
xgeqrf_bufferSize(CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES (double))1578 void xgeqrf_bufferSize<double>(CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(double)) {
1579   TORCH_CUSOLVER_CHECK(cusolverDnXgeqrf_bufferSize(
1580       handle,
1581       params,
1582       m,
1583       n,
1584       CUDA_R_64F,
1585       reinterpret_cast<const void*>(A),
1586       lda,
1587       CUDA_R_64F,
1588       reinterpret_cast<const void*>(tau),
1589       CUDA_R_64F,
1590       workspaceInBytesOnDevice,
1591       workspaceInBytesOnHost));
1592 }
1593 
1594 template <>
xgeqrf_bufferSize(CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES (c10::complex<float>))1595 void xgeqrf_bufferSize<c10::complex<float>>(
1596     CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(c10::complex<float>)) {
1597   TORCH_CUSOLVER_CHECK(cusolverDnXgeqrf_bufferSize(
1598       handle,
1599       params,
1600       m,
1601       n,
1602       CUDA_C_32F,
1603       reinterpret_cast<const void*>(A),
1604       lda,
1605       CUDA_C_32F,
1606       reinterpret_cast<const void*>(tau),
1607       CUDA_C_32F,
1608       workspaceInBytesOnDevice,
1609       workspaceInBytesOnHost));
1610 }
1611 
1612 template <>
xgeqrf_bufferSize(CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES (c10::complex<double>))1613 void xgeqrf_bufferSize<c10::complex<double>>(
1614     CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(c10::complex<double>)) {
1615   TORCH_CUSOLVER_CHECK(cusolverDnXgeqrf_bufferSize(
1616       handle,
1617       params,
1618       m,
1619       n,
1620       CUDA_C_64F,
1621       reinterpret_cast<const void*>(A),
1622       lda,
1623       CUDA_C_64F,
1624       reinterpret_cast<const void*>(tau),
1625       CUDA_C_64F,
1626       workspaceInBytesOnDevice,
1627       workspaceInBytesOnHost));
1628 }
1629 
1630 template <>
xgeqrf(CUDASOLVER_XGEQRF_ARGTYPES (float))1631 void xgeqrf<float>(CUDASOLVER_XGEQRF_ARGTYPES(float)) {
1632   TORCH_CUSOLVER_CHECK(cusolverDnXgeqrf(
1633       handle,
1634       params,
1635       m,
1636       n,
1637       CUDA_R_32F,
1638       reinterpret_cast<void*>(A),
1639       lda,
1640       CUDA_R_32F,
1641       reinterpret_cast<void*>(tau),
1642       CUDA_R_32F,
1643       reinterpret_cast<void*>(bufferOnDevice),
1644       workspaceInBytesOnDevice,
1645       reinterpret_cast<void*>(bufferOnHost),
1646       workspaceInBytesOnHost,
1647       info));
1648 }
1649 
1650 template <>
xgeqrf(CUDASOLVER_XGEQRF_ARGTYPES (double))1651 void xgeqrf<double>(CUDASOLVER_XGEQRF_ARGTYPES(double)) {
1652   TORCH_CUSOLVER_CHECK(cusolverDnXgeqrf(
1653       handle,
1654       params,
1655       m,
1656       n,
1657       CUDA_R_64F,
1658       reinterpret_cast<void*>(A),
1659       lda,
1660       CUDA_R_64F,
1661       reinterpret_cast<void*>(tau),
1662       CUDA_R_64F,
1663       reinterpret_cast<void*>(bufferOnDevice),
1664       workspaceInBytesOnDevice,
1665       reinterpret_cast<void*>(bufferOnHost),
1666       workspaceInBytesOnHost,
1667       info));
1668 }
1669 
1670 template <>
xgeqrf(CUDASOLVER_XGEQRF_ARGTYPES (c10::complex<float>))1671 void xgeqrf<c10::complex<float>>(CUDASOLVER_XGEQRF_ARGTYPES(c10::complex<float>)) {
1672   TORCH_CUSOLVER_CHECK(cusolverDnXgeqrf(
1673       handle,
1674       params,
1675       m,
1676       n,
1677       CUDA_C_32F,
1678       reinterpret_cast<void*>(A),
1679       lda,
1680       CUDA_C_32F,
1681       reinterpret_cast<void*>(tau),
1682       CUDA_C_32F,
1683       reinterpret_cast<void*>(bufferOnDevice),
1684       workspaceInBytesOnDevice,
1685       reinterpret_cast<void*>(bufferOnHost),
1686       workspaceInBytesOnHost,
1687       info));
1688 }
1689 
1690 template <>
xgeqrf(CUDASOLVER_XGEQRF_ARGTYPES (c10::complex<double>))1691 void xgeqrf<c10::complex<double>>(CUDASOLVER_XGEQRF_ARGTYPES(c10::complex<double>)) {
1692   TORCH_CUSOLVER_CHECK(cusolverDnXgeqrf(
1693       handle,
1694       params,
1695       m,
1696       n,
1697       CUDA_C_64F,
1698       reinterpret_cast<void*>(A),
1699       lda,
1700       CUDA_C_64F,
1701       reinterpret_cast<void*>(tau),
1702       CUDA_C_64F,
1703       reinterpret_cast<void*>(bufferOnDevice),
1704       workspaceInBytesOnDevice,
1705       reinterpret_cast<void*>(bufferOnHost),
1706       workspaceInBytesOnHost,
1707       info));
1708 }
1709 
1710 template <>
xsyevd_bufferSize(cusolverDnHandle_t handle,cusolverDnParams_t params,cusolverEigMode_t jobz,cublasFillMode_t uplo,int64_t n,const float * A,int64_t lda,const float * W,size_t * workspaceInBytesOnDevice,size_t * workspaceInBytesOnHost)1711 void xsyevd_bufferSize<float>(
1712     cusolverDnHandle_t handle,
1713     cusolverDnParams_t params,
1714     cusolverEigMode_t jobz,
1715     cublasFillMode_t uplo,
1716     int64_t n,
1717     const float* A,
1718     int64_t lda,
1719     const float* W,
1720     size_t* workspaceInBytesOnDevice,
1721     size_t* workspaceInBytesOnHost) {
1722   TORCH_CUSOLVER_CHECK(cusolverDnXsyevd_bufferSize(
1723       handle,
1724       params,
1725       jobz,
1726       uplo,
1727       n,
1728       CUDA_R_32F,
1729       reinterpret_cast<const void*>(A),
1730       lda,
1731       CUDA_R_32F,
1732       reinterpret_cast<const void*>(W),
1733       CUDA_R_32F,
1734       workspaceInBytesOnDevice,
1735       workspaceInBytesOnHost));
1736 }
1737 
1738 template <>
xsyevd_bufferSize(cusolverDnHandle_t handle,cusolverDnParams_t params,cusolverEigMode_t jobz,cublasFillMode_t uplo,int64_t n,const double * A,int64_t lda,const double * W,size_t * workspaceInBytesOnDevice,size_t * workspaceInBytesOnHost)1739 void xsyevd_bufferSize<double>(
1740     cusolverDnHandle_t handle,
1741     cusolverDnParams_t params,
1742     cusolverEigMode_t jobz,
1743     cublasFillMode_t uplo,
1744     int64_t n,
1745     const double* A,
1746     int64_t lda,
1747     const double* W,
1748     size_t* workspaceInBytesOnDevice,
1749     size_t* workspaceInBytesOnHost) {
1750   TORCH_CUSOLVER_CHECK(cusolverDnXsyevd_bufferSize(
1751       handle,
1752       params,
1753       jobz,
1754       uplo,
1755       n,
1756       CUDA_R_64F,
1757       reinterpret_cast<const void*>(A),
1758       lda,
1759       CUDA_R_64F,
1760       reinterpret_cast<const void*>(W),
1761       CUDA_R_64F,
1762       workspaceInBytesOnDevice,
1763       workspaceInBytesOnHost));
1764 }
1765 
1766 template <>
xsyevd_bufferSize(cusolverDnHandle_t handle,cusolverDnParams_t params,cusolverEigMode_t jobz,cublasFillMode_t uplo,int64_t n,const c10::complex<float> * A,int64_t lda,const float * W,size_t * workspaceInBytesOnDevice,size_t * workspaceInBytesOnHost)1767 void xsyevd_bufferSize<c10::complex<float>, float>(
1768     cusolverDnHandle_t handle,
1769     cusolverDnParams_t params,
1770     cusolverEigMode_t jobz,
1771     cublasFillMode_t uplo,
1772     int64_t n,
1773     const c10::complex<float>* A,
1774     int64_t lda,
1775     const float* W,
1776     size_t* workspaceInBytesOnDevice,
1777     size_t* workspaceInBytesOnHost) {
1778   TORCH_CUSOLVER_CHECK(cusolverDnXsyevd_bufferSize(
1779       handle,
1780       params,
1781       jobz,
1782       uplo,
1783       n,
1784       CUDA_C_32F,
1785       reinterpret_cast<const void*>(A),
1786       lda,
1787       CUDA_R_32F,
1788       reinterpret_cast<const void*>(W),
1789       CUDA_C_32F,
1790       workspaceInBytesOnDevice,
1791       workspaceInBytesOnHost));
1792 }
1793 
1794 template <>
xsyevd_bufferSize(cusolverDnHandle_t handle,cusolverDnParams_t params,cusolverEigMode_t jobz,cublasFillMode_t uplo,int64_t n,const c10::complex<double> * A,int64_t lda,const double * W,size_t * workspaceInBytesOnDevice,size_t * workspaceInBytesOnHost)1795 void xsyevd_bufferSize<c10::complex<double>, double>(
1796     cusolverDnHandle_t handle,
1797     cusolverDnParams_t params,
1798     cusolverEigMode_t jobz,
1799     cublasFillMode_t uplo,
1800     int64_t n,
1801     const c10::complex<double>* A,
1802     int64_t lda,
1803     const double* W,
1804     size_t* workspaceInBytesOnDevice,
1805     size_t* workspaceInBytesOnHost) {
1806   TORCH_CUSOLVER_CHECK(cusolverDnXsyevd_bufferSize(
1807       handle,
1808       params,
1809       jobz,
1810       uplo,
1811       n,
1812       CUDA_C_64F,
1813       reinterpret_cast<const void*>(A),
1814       lda,
1815       CUDA_R_64F,
1816       reinterpret_cast<const void*>(W),
1817       CUDA_C_64F,
1818       workspaceInBytesOnDevice,
1819       workspaceInBytesOnHost));
1820 }
1821 
1822 template <>
xsyevd(cusolverDnHandle_t handle,cusolverDnParams_t params,cusolverEigMode_t jobz,cublasFillMode_t uplo,int64_t n,float * A,int64_t lda,float * W,float * bufferOnDevice,size_t workspaceInBytesOnDevice,float * bufferOnHost,size_t workspaceInBytesOnHost,int * info)1823 void xsyevd<float>(
1824     cusolverDnHandle_t handle,
1825     cusolverDnParams_t params,
1826     cusolverEigMode_t jobz,
1827     cublasFillMode_t uplo,
1828     int64_t n,
1829     float* A,
1830     int64_t lda,
1831     float* W,
1832     float* bufferOnDevice,
1833     size_t workspaceInBytesOnDevice,
1834     float* bufferOnHost,
1835     size_t workspaceInBytesOnHost,
1836     int* info) {
1837   TORCH_CUSOLVER_CHECK(cusolverDnXsyevd(
1838       handle,
1839       params,
1840       jobz,
1841       uplo,
1842       n,
1843       CUDA_R_32F,
1844       reinterpret_cast<void*>(A),
1845       lda,
1846       CUDA_R_32F,
1847       reinterpret_cast<void*>(W),
1848       CUDA_R_32F,
1849       reinterpret_cast<void*>(bufferOnDevice),
1850       workspaceInBytesOnDevice,
1851       reinterpret_cast<void*>(bufferOnHost),
1852       workspaceInBytesOnHost,
1853       info));
1854 }
1855 
1856 template <>
xsyevd(cusolverDnHandle_t handle,cusolverDnParams_t params,cusolverEigMode_t jobz,cublasFillMode_t uplo,int64_t n,double * A,int64_t lda,double * W,double * bufferOnDevice,size_t workspaceInBytesOnDevice,double * bufferOnHost,size_t workspaceInBytesOnHost,int * info)1857 void xsyevd<double>(
1858     cusolverDnHandle_t handle,
1859     cusolverDnParams_t params,
1860     cusolverEigMode_t jobz,
1861     cublasFillMode_t uplo,
1862     int64_t n,
1863     double* A,
1864     int64_t lda,
1865     double* W,
1866     double* bufferOnDevice,
1867     size_t workspaceInBytesOnDevice,
1868     double* bufferOnHost,
1869     size_t workspaceInBytesOnHost,
1870     int* info) {
1871   TORCH_CUSOLVER_CHECK(cusolverDnXsyevd(
1872       handle,
1873       params,
1874       jobz,
1875       uplo,
1876       n,
1877       CUDA_R_64F,
1878       reinterpret_cast<void*>(A),
1879       lda,
1880       CUDA_R_64F,
1881       reinterpret_cast<void*>(W),
1882       CUDA_R_64F,
1883       reinterpret_cast<void*>(bufferOnDevice),
1884       workspaceInBytesOnDevice,
1885       reinterpret_cast<void*>(bufferOnHost),
1886       workspaceInBytesOnHost,
1887       info));
1888 }
1889 
1890 template <>
xsyevd(cusolverDnHandle_t handle,cusolverDnParams_t params,cusolverEigMode_t jobz,cublasFillMode_t uplo,int64_t n,c10::complex<float> * A,int64_t lda,float * W,c10::complex<float> * bufferOnDevice,size_t workspaceInBytesOnDevice,c10::complex<float> * bufferOnHost,size_t workspaceInBytesOnHost,int * info)1891 void xsyevd<c10::complex<float>, float>(
1892     cusolverDnHandle_t handle,
1893     cusolverDnParams_t params,
1894     cusolverEigMode_t jobz,
1895     cublasFillMode_t uplo,
1896     int64_t n,
1897     c10::complex<float>* A,
1898     int64_t lda,
1899     float* W,
1900     c10::complex<float>* bufferOnDevice,
1901     size_t workspaceInBytesOnDevice,
1902     c10::complex<float>* bufferOnHost,
1903     size_t workspaceInBytesOnHost,
1904     int* info) {
1905   TORCH_CUSOLVER_CHECK(cusolverDnXsyevd(
1906       handle,
1907       params,
1908       jobz,
1909       uplo,
1910       n,
1911       CUDA_C_32F,
1912       reinterpret_cast<void*>(A),
1913       lda,
1914       CUDA_R_32F,
1915       reinterpret_cast<void*>(W),
1916       CUDA_C_32F,
1917       reinterpret_cast<void*>(bufferOnDevice),
1918       workspaceInBytesOnDevice,
1919       reinterpret_cast<void*>(bufferOnHost),
1920       workspaceInBytesOnHost,
1921       info));
1922 }
1923 
1924 template <>
xsyevd(cusolverDnHandle_t handle,cusolverDnParams_t params,cusolverEigMode_t jobz,cublasFillMode_t uplo,int64_t n,c10::complex<double> * A,int64_t lda,double * W,c10::complex<double> * bufferOnDevice,size_t workspaceInBytesOnDevice,c10::complex<double> * bufferOnHost,size_t workspaceInBytesOnHost,int * info)1925 void xsyevd<c10::complex<double>, double>(
1926     cusolverDnHandle_t handle,
1927     cusolverDnParams_t params,
1928     cusolverEigMode_t jobz,
1929     cublasFillMode_t uplo,
1930     int64_t n,
1931     c10::complex<double>* A,
1932     int64_t lda,
1933     double* W,
1934     c10::complex<double>* bufferOnDevice,
1935     size_t workspaceInBytesOnDevice,
1936     c10::complex<double>* bufferOnHost,
1937     size_t workspaceInBytesOnHost,
1938     int* info) {
1939   TORCH_CUSOLVER_CHECK(cusolverDnXsyevd(
1940       handle,
1941       params,
1942       jobz,
1943       uplo,
1944       n,
1945       CUDA_C_64F,
1946       reinterpret_cast<void*>(A),
1947       lda,
1948       CUDA_R_64F,
1949       reinterpret_cast<void*>(W),
1950       CUDA_C_64F,
1951       reinterpret_cast<void*>(bufferOnDevice),
1952       workspaceInBytesOnDevice,
1953       reinterpret_cast<void*>(bufferOnHost),
1954       workspaceInBytesOnHost,
1955       info));
1956 }
1957 #endif // USE_CUSOLVER_64_BIT
1958 
1959 } // namespace at::cuda::solver
1960 
1961 #endif // CUDART_VERSION
1962