1 #include <ATen/Context.h>
2 #include <ATen/NativeFunctions.h>
3 #include <ATen/native/cuda/linalg/CUDASolver.h>
4 #include <c10/cuda/CUDACachingAllocator.h>
5 #include <c10/macros/Export.h>
6
7 #if defined(CUDART_VERSION) || defined(USE_ROCM)
8
9 namespace at::cuda::solver {
10
// getrf<double>: LU factorization with partial pivoting of the m x n matrix
// dA (leading dimension ldda) via cusolverDnDgetrf. The caller does not pass
// a workspace: its size is queried with cusolverDnDgetrf_bufferSize and the
// bytes are taken from the CUDA caching allocator. Pivot indices are written
// to ipiv and the factorization info code to info.
template<>
void getrf<double>(
    cusolverDnHandle_t handle,
    int m,
    int n,
    double* dA,
    int ldda,
    int* ipiv,
    int* info) {
  // Number of double elements of scratch space cuSOLVER asks for.
  int lwork = 0;
  TORCH_CUSOLVER_CHECK(cusolverDnDgetrf_bufferSize(handle, m, n, dA, ldda, &lwork));

  // Scratch block from the caching allocator; presumably handed back to the
  // cache when dataPtr goes out of scope.
  const auto workspace_bytes = sizeof(double) * lwork;
  auto dataPtr =
      ::c10::cuda::CUDACachingAllocator::get()->allocate(workspace_bytes);

  TORCH_CUSOLVER_CHECK(cusolverDnDgetrf(
      handle,
      m, n,
      dA, ldda,
      static_cast<double*>(dataPtr.get()),
      ipiv,
      info));
}
22
// getrf<float>: single-precision LU factorization with partial pivoting of
// the m x n matrix dA (leading dimension ldda) via cusolverDnSgetrf. The
// workspace is sized with cusolverDnSgetrf_bufferSize and allocated here from
// the CUDA caching allocator. Pivots go to ipiv, the info code to info.
template<>
void getrf<float>(
    cusolverDnHandle_t handle,
    int m,
    int n,
    float* dA,
    int ldda,
    int* ipiv,
    int* info) {
  // Number of float elements of scratch space cuSOLVER asks for.
  int lwork = 0;
  TORCH_CUSOLVER_CHECK(cusolverDnSgetrf_bufferSize(handle, m, n, dA, ldda, &lwork));

  // Scratch block from the caching allocator; presumably released back to
  // the cache when dataPtr goes out of scope.
  const auto workspace_bytes = sizeof(float) * lwork;
  auto dataPtr =
      ::c10::cuda::CUDACachingAllocator::get()->allocate(workspace_bytes);

  TORCH_CUSOLVER_CHECK(cusolverDnSgetrf(
      handle,
      m, n,
      dA, ldda,
      static_cast<float*>(dataPtr.get()),
      ipiv,
      info));
}
34
// getrf<c10::complex<double>>: complex-double LU factorization with partial
// pivoting via cusolverDnZgetrf. The c10::complex<double> buffer is
// reinterpreted as cuDoubleComplex for cuSOLVER. The workspace is sized with
// cusolverDnZgetrf_bufferSize and taken from the CUDA caching allocator.
template<>
void getrf<c10::complex<double>>(
    cusolverDnHandle_t handle, int m, int n, c10::complex<double>* dA,
    int ldda, int* ipiv, int* info) {
  int lwork = 0;
  TORCH_CUSOLVER_CHECK(cusolverDnZgetrf_bufferSize(
      handle, m, n, reinterpret_cast<cuDoubleComplex*>(dA), ldda, &lwork));

  // lwork counts cuDoubleComplex elements, so scale by that element size.
  const auto workspace_bytes = sizeof(cuDoubleComplex) * lwork;
  auto dataPtr =
      ::c10::cuda::CUDACachingAllocator::get()->allocate(workspace_bytes);

  TORCH_CUSOLVER_CHECK(cusolverDnZgetrf(
      handle, m, n, reinterpret_cast<cuDoubleComplex*>(dA), ldda,
      static_cast<cuDoubleComplex*>(dataPtr.get()), ipiv, info));
}
59
// getrf<c10::complex<float>>: complex-float LU factorization with partial
// pivoting via cusolverDnCgetrf. The c10::complex<float> buffer is
// reinterpreted as cuComplex for cuSOLVER. The workspace is sized with
// cusolverDnCgetrf_bufferSize and taken from the CUDA caching allocator.
template<>
void getrf<c10::complex<float>>(
    cusolverDnHandle_t handle, int m, int n, c10::complex<float>* dA,
    int ldda, int* ipiv, int* info) {
  int lwork = 0;
  TORCH_CUSOLVER_CHECK(cusolverDnCgetrf_bufferSize(
      handle, m, n, reinterpret_cast<cuComplex*>(dA), ldda, &lwork));

  // lwork counts cuComplex elements, so scale by that element size.
  const auto workspace_bytes = sizeof(cuComplex) * lwork;
  auto dataPtr =
      ::c10::cuda::CUDACachingAllocator::get()->allocate(workspace_bytes);

  TORCH_CUSOLVER_CHECK(cusolverDnCgetrf(
      handle, m, n, reinterpret_cast<cuComplex*>(dA), ldda,
      static_cast<cuComplex*>(dataPtr.get()), ipiv, info));
}
84
// getrs<double>: solves op(A) * X = B via cusolverDnDgetrs, using the LU
// factors in dA (leading dimension lda) and the pivots in ipiv. trans selects
// op(A); it is the last wrapper argument but the second cuSOLVER argument.
// The nrhs right-hand sides are passed in ret (leading dimension ldb), which
// cuSOLVER overwrites with the solution.
template<>
void getrs<double>(
    cusolverDnHandle_t handle,
    int n,
    int nrhs,
    double* dA,
    int lda,
    int* ipiv,
    double* ret,
    int ldb,
    int* info,
    cublasOperation_t trans) {
  TORCH_CUSOLVER_CHECK(cusolverDnDgetrs(
      handle, trans,
      n, nrhs,
      dA, lda, ipiv,
      ret, ldb,
      info));
}
91
// getrs<float>: single-precision solve of op(A) * X = B via cusolverDnSgetrs,
// reusing the LU factors dA (leading dimension lda) and pivots ipiv. trans is
// forwarded as cuSOLVER's second argument. ret holds the nrhs right-hand sides
// (leading dimension ldb) and is overwritten with the solution.
template<>
void getrs<float>(
    cusolverDnHandle_t handle,
    int n,
    int nrhs,
    float* dA,
    int lda,
    int* ipiv,
    float* ret,
    int ldb,
    int* info,
    cublasOperation_t trans) {
  TORCH_CUSOLVER_CHECK(cusolverDnSgetrs(
      handle, trans,
      n, nrhs,
      dA, lda, ipiv,
      ret, ldb,
      info));
}
98
// getrs<c10::complex<double>>: complex-double solve of op(A) * X = B via
// cusolverDnZgetrs, using the LU factors dA and pivots ipiv. The matrix and
// right-hand-side buffers are reinterpreted as cuDoubleComplex; trans is
// forwarded as cuSOLVER's second argument.
template<>
void getrs<c10::complex<double>>(
    cusolverDnHandle_t handle, int n, int nrhs, c10::complex<double>* dA,
    int lda, int* ipiv, c10::complex<double>* ret, int ldb, int* info,
    cublasOperation_t trans) {
  TORCH_CUSOLVER_CHECK(cusolverDnZgetrs(
      handle, trans, n, nrhs,
      reinterpret_cast<cuDoubleComplex*>(dA), lda, ipiv,
      reinterpret_cast<cuDoubleComplex*>(ret), ldb, info));
}
123
// getrs<c10::complex<float>>: complex-float solve of op(A) * X = B via
// cusolverDnCgetrs, using the LU factors dA and pivots ipiv. The matrix and
// right-hand-side buffers are reinterpreted as cuComplex; trans is forwarded
// as cuSOLVER's second argument.
template<>
void getrs<c10::complex<float>>(
    cusolverDnHandle_t handle, int n, int nrhs, c10::complex<float>* dA,
    int lda, int* ipiv, c10::complex<float>* ret, int ldb, int* info,
    cublasOperation_t trans) {
  TORCH_CUSOLVER_CHECK(cusolverDnCgetrs(
      handle, trans, n, nrhs,
      reinterpret_cast<cuComplex*>(dA), lda, ipiv,
      reinterpret_cast<cuComplex*>(ret), ldb, info));
}
148
// sytrf_bufferSize<double>: workspace query for cusolverDnDsytrf on the n x n
// matrix A (leading dimension lda); the size is written through lwork.
template<>
void sytrf_bufferSize<double>(
    CUDASOLVER_SYTRF_BUFFER_ARGTYPES(double)) {
  TORCH_CUSOLVER_CHECK(
      cusolverDnDsytrf_bufferSize(handle, n, A, lda, lwork));
}
153
// sytrf_bufferSize<float>: workspace query for cusolverDnSsytrf on the n x n
// matrix A (leading dimension lda); the size is written through lwork.
template<>
void sytrf_bufferSize<float>(
    CUDASOLVER_SYTRF_BUFFER_ARGTYPES(float)) {
  TORCH_CUSOLVER_CHECK(
      cusolverDnSsytrf_bufferSize(handle, n, A, lda, lwork));
}
158
// sytrf_bufferSize<c10::complex<double>>: workspace query for
// cusolverDnZsytrf; A is reinterpreted as cuDoubleComplex and the size is
// written through lwork.
template<>
void sytrf_bufferSize<c10::complex<double>>(CUDASOLVER_SYTRF_BUFFER_ARGTYPES(c10::complex<double>)) {
  TORCH_CUSOLVER_CHECK(cusolverDnZsytrf_bufferSize(
      handle, n,
      reinterpret_cast<cuDoubleComplex*>(A),
      lda, lwork));
}
165
// sytrf_bufferSize<c10::complex<float>>: workspace query for
// cusolverDnCsytrf; A is reinterpreted as cuComplex and the size is written
// through lwork.
template<>
void sytrf_bufferSize<c10::complex<float>>(CUDASOLVER_SYTRF_BUFFER_ARGTYPES(c10::complex<float>)) {
  TORCH_CUSOLVER_CHECK(cusolverDnCsytrf_bufferSize(
      handle, n,
      reinterpret_cast<cuComplex*>(A),
      lda, lwork));
}
172
// sytrf<double>: symmetric indefinite factorization of the n x n matrix A via
// cusolverDnDsytrf. uplo selects the referenced triangle, ipiv receives the
// pivots, work/lwork is the caller's workspace and devInfo the info code.
template<>
void sytrf<double>(CUDASOLVER_SYTRF_ARGTYPES(double)) {
  TORCH_CUSOLVER_CHECK(cusolverDnDsytrf(handle, uplo, n,
                                        A, lda, ipiv,
                                        work, lwork, devInfo));
}
178
// sytrf<float>: single-precision symmetric indefinite factorization via
// cusolverDnSsytrf. uplo selects the referenced triangle, ipiv receives the
// pivots, work/lwork is the caller's workspace and devInfo the info code.
template<>
void sytrf<float>(CUDASOLVER_SYTRF_ARGTYPES(float)) {
  TORCH_CUSOLVER_CHECK(cusolverDnSsytrf(handle, uplo, n,
                                        A, lda, ipiv,
                                        work, lwork, devInfo));
}
184
// sytrf<c10::complex<double>>: complex-double symmetric factorization via
// cusolverDnZsytrf. The matrix and workspace buffers are reinterpreted as
// cuDoubleComplex; ipiv and devInfo are forwarded unchanged.
template<>
void sytrf<c10::complex<double>>(CUDASOLVER_SYTRF_ARGTYPES(c10::complex<double>)) {
  TORCH_CUSOLVER_CHECK(cusolverDnZsytrf(
      handle, uplo, n,
      reinterpret_cast<cuDoubleComplex*>(A), lda, ipiv,
      reinterpret_cast<cuDoubleComplex*>(work), lwork,
      devInfo));
}
199
// sytrf<c10::complex<float>>: complex-float symmetric factorization via
// cusolverDnCsytrf. The matrix and workspace buffers are reinterpreted as
// cuComplex; ipiv and devInfo are forwarded unchanged.
template<>
void sytrf<c10::complex<float>>(CUDASOLVER_SYTRF_ARGTYPES(c10::complex<float>)) {
  TORCH_CUSOLVER_CHECK(cusolverDnCsytrf(
      handle, uplo, n,
      reinterpret_cast<cuComplex*>(A), lda, ipiv,
      reinterpret_cast<cuComplex*>(work), lwork,
      devInfo));
}
214
// gesvd_buffersize<float>: workspace query for cusolverDnSgesvd on an m x n
// matrix, written through lwork. The argument macro takes no element type, so
// the explicit specialization alone picks the S/D/C/Z routine.
template <>
void gesvd_buffersize<float>(
    CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES()) {
  TORCH_CUSOLVER_CHECK(
      cusolverDnSgesvd_bufferSize(handle, m, n, lwork));
}
219
// gesvd_buffersize<double>: workspace query for cusolverDnDgesvd on an m x n
// matrix, written through lwork. The argument list is type-independent; the
// specialization selects the double-precision routine.
template <>
void gesvd_buffersize<double>(
    CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES()) {
  TORCH_CUSOLVER_CHECK(
      cusolverDnDgesvd_bufferSize(handle, m, n, lwork));
}
224
// gesvd_buffersize<c10::complex<float>>: workspace query for cusolverDnCgesvd
// on an m x n matrix, written through lwork. No buffer is passed, so no
// complex reinterpretation is needed here.
template <>
void gesvd_buffersize<c10::complex<float>>(
    CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES()) {
  TORCH_CUSOLVER_CHECK(
      cusolverDnCgesvd_bufferSize(handle, m, n, lwork));
}
229
// gesvd_buffersize<c10::complex<double>>: workspace query for
// cusolverDnZgesvd on an m x n matrix, written through lwork. No buffer is
// passed, so no complex reinterpretation is needed here.
template <>
void gesvd_buffersize<c10::complex<double>>(
    CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES()) {
  TORCH_CUSOLVER_CHECK(
      cusolverDnZgesvd_bufferSize(handle, m, n, lwork));
}
234
235
// gesvd<float>: singular value decomposition of the m x n matrix A via
// cusolverDnSgesvd. jobu/jobvt select how much of U and VT is produced, S
// receives the singular values, work/lwork/rwork are caller-provided scratch,
// and info receives the info code.
template <>
void gesvd<float>(CUDASOLVER_GESVD_ARGTYPES(float, float)) {
  TORCH_CUSOLVER_CHECK(cusolverDnSgesvd(
      handle, jobu, jobvt,
      m, n, A, lda,
      S,
      U, ldu,
      VT, ldvt,
      work, lwork, rwork,
      info));
}
241
// gesvd<double>: double-precision SVD of the m x n matrix A via
// cusolverDnDgesvd. jobu/jobvt control U and VT output, S receives the
// singular values, work/lwork/rwork are caller-provided scratch, and info
// receives the info code.
template <>
void gesvd<double>(CUDASOLVER_GESVD_ARGTYPES(double, double)) {
  TORCH_CUSOLVER_CHECK(cusolverDnDgesvd(
      handle, jobu, jobvt,
      m, n, A, lda,
      S,
      U, ldu,
      VT, ldvt,
      work, lwork, rwork,
      info));
}
247
248
// gesvd<c10::complex<float>>: complex-float SVD via cusolverDnCgesvd. The
// complex buffers (A, U, VT, work) are reinterpreted as cuComplex. S and rwork
// are passed uncast, so they are real (float) buffers.
template <>
void gesvd<c10::complex<float>>(
    CUDASOLVER_GESVD_ARGTYPES(c10::complex<float>, float)) {
  TORCH_CUSOLVER_CHECK(cusolverDnCgesvd(
      handle, jobu, jobvt, m, n,
      reinterpret_cast<cuComplex*>(A), lda,
      S,
      reinterpret_cast<cuComplex*>(U), ldu,
      reinterpret_cast<cuComplex*>(VT), ldvt,
      reinterpret_cast<cuComplex*>(work), lwork,
      rwork, info
      ));
}
263
// gesvd<c10::complex<double>>: complex-double SVD via cusolverDnZgesvd. The
// complex buffers (A, U, VT, work) are reinterpreted as cuDoubleComplex. S and
// rwork are passed uncast, so they are real (double) buffers.
template <>
void gesvd<c10::complex<double>>(
    CUDASOLVER_GESVD_ARGTYPES(c10::complex<double>, double)) {
  TORCH_CUSOLVER_CHECK(cusolverDnZgesvd(
      handle, jobu, jobvt, m, n,
      reinterpret_cast<cuDoubleComplex*>(A), lda,
      S,
      reinterpret_cast<cuDoubleComplex*>(U), ldu,
      reinterpret_cast<cuDoubleComplex*>(VT), ldvt,
      reinterpret_cast<cuDoubleComplex*>(work), lwork,
      rwork, info
      ));
}
278
279
// gesvdj_buffersize<float>: workspace query for the Jacobi SVD routine
// cusolverDnSgesvdj. jobz selects whether singular vectors are computed, econ
// selects the economy-size variant, and params carries the Jacobi
// configuration. The size is written through lwork.
template <>
void gesvdj_buffersize<float>(
    cusolverDnHandle_t handle,
    cusolverEigMode_t jobz,
    int econ,
    int m,
    int n,
    float* A,
    int lda,
    float* S,
    float* U,
    int ldu,
    float* V,
    int ldv,
    int* lwork,
    gesvdjInfo_t params) {
  TORCH_CUSOLVER_CHECK(cusolverDnSgesvdj_bufferSize(handle, jobz, econ, m, n,
                                                    A, lda, S, U, ldu, V, ldv,
                                                    lwork, params));
}
287
// gesvdj_buffersize<double>: workspace query for cusolverDnDgesvdj. jobz,
// econ and params configure the Jacobi SVD whose workspace is being sized.
// The result is written through lwork.
template <>
void gesvdj_buffersize<double>(
    cusolverDnHandle_t handle,
    cusolverEigMode_t jobz,
    int econ,
    int m,
    int n,
    double* A,
    int lda,
    double* S,
    double* U,
    int ldu,
    double* V,
    int ldv,
    int* lwork,
    gesvdjInfo_t params) {
  TORCH_CUSOLVER_CHECK(cusolverDnDgesvdj_bufferSize(handle, jobz, econ, m, n,
                                                    A, lda, S, U, ldu, V, ldv,
                                                    lwork, params));
}
295
// gesvdj_buffersize<c10::complex<float>>: workspace query for
// cusolverDnCgesvdj. A, U and V are reinterpreted as cuComplex; S stays real
// (float). The size is written through lwork.
template <>
void gesvdj_buffersize<c10::complex<float>>(
    cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n,
    c10::complex<float>* A, int lda, float* S,
    c10::complex<float>* U, int ldu,
    c10::complex<float>* V, int ldv,
    int* lwork, gesvdjInfo_t params) {
  TORCH_CUSOLVER_CHECK(cusolverDnCgesvdj_bufferSize(handle, jobz, econ, m, n,
      reinterpret_cast<cuComplex*>(A), lda, S,
      reinterpret_cast<cuComplex*>(U), ldu,
      reinterpret_cast<cuComplex*>(V), ldv,
      lwork, params));
}
310
// gesvdj_buffersize<c10::complex<double>>: workspace query for
// cusolverDnZgesvdj. A, U and V are reinterpreted as cuDoubleComplex; S stays
// real (double). The size is written through lwork.
template <>
void gesvdj_buffersize<c10::complex<double>>(
    cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n,
    c10::complex<double>* A, int lda, double* S,
    c10::complex<double>* U, int ldu,
    c10::complex<double>* V, int ldv,
    int* lwork, gesvdjInfo_t params) {
  TORCH_CUSOLVER_CHECK(cusolverDnZgesvdj_bufferSize(handle, jobz, econ, m, n,
      reinterpret_cast<cuDoubleComplex*>(A), lda, S,
      reinterpret_cast<cuDoubleComplex*>(U), ldu,
      reinterpret_cast<cuDoubleComplex*>(V), ldv,
      lwork, params));
}
325
326
// gesvdj<float>: Jacobi SVD of the m x n matrix A via cusolverDnSgesvdj,
// using the caller-provided workspace work/lwork. S, U and V receive the
// decomposition, info the info code, and params holds the Jacobi settings.
template <>
void gesvdj<float>(
    cusolverDnHandle_t handle,
    cusolverEigMode_t jobz,
    int econ,
    int m,
    int n,
    float* A,
    int lda,
    float* S,
    float* U,
    int ldu,
    float* V,
    int ldv,
    float* work,
    int lwork,
    int* info,
    gesvdjInfo_t params) {
  TORCH_CUSOLVER_CHECK(cusolverDnSgesvdj(handle, jobz, econ, m, n,
                                         A, lda, S, U, ldu, V, ldv,
                                         work, lwork, info, params));
}
334
// gesvdj<double>: double-precision Jacobi SVD via cusolverDnDgesvdj, using
// the caller-provided workspace work/lwork. info receives the info code and
// params holds the Jacobi settings.
template <>
void gesvdj<double>(
    cusolverDnHandle_t handle,
    cusolverEigMode_t jobz,
    int econ,
    int m,
    int n,
    double* A,
    int lda,
    double* S,
    double* U,
    int ldu,
    double* V,
    int ldv,
    double* work,
    int lwork,
    int* info,
    gesvdjInfo_t params) {
  TORCH_CUSOLVER_CHECK(cusolverDnDgesvdj(handle, jobz, econ, m, n,
                                         A, lda, S, U, ldu, V, ldv,
                                         work, lwork, info, params));
}
342
// gesvdj<c10::complex<float>>: complex-float Jacobi SVD via cusolverDnCgesvdj.
// A, U, V and work are reinterpreted as cuComplex; S stays real (float).
template <>
void gesvdj<c10::complex<float>>(
    cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n,
    c10::complex<float>* A, int lda, float* S,
    c10::complex<float>* U, int ldu,
    c10::complex<float>* V, int ldv,
    c10::complex<float>* work, int lwork,
    int* info, gesvdjInfo_t params) {
  TORCH_CUSOLVER_CHECK(cusolverDnCgesvdj(
      handle, jobz, econ, m, n,
      reinterpret_cast<cuComplex*>(A), lda, S,
      reinterpret_cast<cuComplex*>(U), ldu,
      reinterpret_cast<cuComplex*>(V), ldv,
      reinterpret_cast<cuComplex*>(work), lwork,
      info, params));
}
359
// gesvdj<c10::complex<double>>: complex-double Jacobi SVD via
// cusolverDnZgesvdj. A, U, V and work are reinterpreted as cuDoubleComplex;
// S stays real (double).
template <>
void gesvdj<c10::complex<double>>(
    cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n,
    c10::complex<double>* A, int lda, double* S,
    c10::complex<double>* U, int ldu,
    c10::complex<double>* V, int ldv,
    c10::complex<double>* work, int lwork,
    int* info, gesvdjInfo_t params) {
  TORCH_CUSOLVER_CHECK(cusolverDnZgesvdj(
      handle, jobz, econ, m, n,
      reinterpret_cast<cuDoubleComplex*>(A), lda, S,
      reinterpret_cast<cuDoubleComplex*>(U), ldu,
      reinterpret_cast<cuDoubleComplex*>(V), ldv,
      reinterpret_cast<cuDoubleComplex*>(work), lwork,
      info, params));
}
376
377
// gesvdjBatched<float>: batched Jacobi SVD of batchSize m x n matrices via
// cusolverDnSgesvdjBatched. Unlike gesvdj, no workspace is passed in: it is
// sized with cusolverDnSgesvdjBatched_bufferSize and taken from the CUDA
// caching allocator here.
template <>
void gesvdjBatched<float>(
    cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n,
    float* A, int lda, float* S, float* U, int ldu, float* V, int ldv,
    int* info, gesvdjInfo_t params, int batchSize) {
  int lwork = 0;
  TORCH_CUSOLVER_CHECK(cusolverDnSgesvdjBatched_bufferSize(handle, jobz, m, n,
      A, lda, S, U, ldu, V, ldv, &lwork, params, batchSize));

  // Scratch block; presumably returned to the cache when dataPtr is destroyed.
  const auto workspace_bytes = sizeof(float) * lwork;
  auto dataPtr =
      ::c10::cuda::CUDACachingAllocator::get()->allocate(workspace_bytes);

  TORCH_CUSOLVER_CHECK(cusolverDnSgesvdjBatched(
      handle, jobz, m, n,
      A, lda, S,
      U, ldu,
      V, ldv,
      static_cast<float*>(dataPtr.get()), lwork,
      info, params, batchSize));
}
394
// gesvdjBatched<double>: batched Jacobi SVD of batchSize m x n matrices via
// cusolverDnDgesvdjBatched. The workspace is sized with
// cusolverDnDgesvdjBatched_bufferSize and taken from the CUDA caching
// allocator here, not supplied by the caller.
template <>
void gesvdjBatched<double>(
    cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n,
    double* A, int lda, double* S, double* U, int ldu, double* V, int ldv,
    int* info, gesvdjInfo_t params, int batchSize) {
  int lwork = 0;
  TORCH_CUSOLVER_CHECK(cusolverDnDgesvdjBatched_bufferSize(handle, jobz, m, n,
      A, lda, S, U, ldu, V, ldv, &lwork, params, batchSize));

  // Scratch block; presumably returned to the cache when dataPtr is destroyed.
  const auto workspace_bytes = sizeof(double) * lwork;
  auto dataPtr =
      ::c10::cuda::CUDACachingAllocator::get()->allocate(workspace_bytes);

  TORCH_CUSOLVER_CHECK(cusolverDnDgesvdjBatched(
      handle, jobz, m, n,
      A, lda, S,
      U, ldu,
      V, ldv,
      static_cast<double*>(dataPtr.get()), lwork,
      info, params, batchSize));
}
411
// gesvdjBatched<c10::complex<float>>: batched complex-float Jacobi SVD via
// cusolverDnCgesvdjBatched. A, U and V are reinterpreted as cuComplex; S stays
// real (float). The workspace is sized and allocated internally from the CUDA
// caching allocator.
template <>
void gesvdjBatched<c10::complex<float>>(
    cusolverDnHandle_t handle,
    cusolverEigMode_t jobz,
    int m,
    int n,
    c10::complex<float>* A,
    int lda,
    float* S,
    c10::complex<float>* U,
    int ldu,
    c10::complex<float>* V,
    int ldv,
    int* info,
    gesvdjInfo_t params,
    int batchSize) {
  int lwork = 0;
  TORCH_CUSOLVER_CHECK(cusolverDnCgesvdjBatched_bufferSize(
      handle, jobz, m, n,
      reinterpret_cast<cuComplex*>(A), lda, S,
      reinterpret_cast<cuComplex*>(U), ldu,
      reinterpret_cast<cuComplex*>(V), ldv,
      &lwork, params, batchSize));

  // lwork counts cuComplex elements.
  const auto workspace_bytes = sizeof(cuComplex) * lwork;
  auto dataPtr =
      ::c10::cuda::CUDACachingAllocator::get()->allocate(workspace_bytes);

  TORCH_CUSOLVER_CHECK(cusolverDnCgesvdjBatched(
      handle, jobz, m, n,
      reinterpret_cast<cuComplex*>(A), lda, S,
      reinterpret_cast<cuComplex*>(U), ldu,
      reinterpret_cast<cuComplex*>(V), ldv,
      static_cast<cuComplex*>(dataPtr.get()), lwork,
      info, params, batchSize));
}
441
// gesvdjBatched<c10::complex<double>>: batched complex-double Jacobi SVD via
// cusolverDnZgesvdjBatched. A, U and V are reinterpreted as cuDoubleComplex; S
// stays real (double). The workspace is sized and allocated internally from
// the CUDA caching allocator.
template <>
void gesvdjBatched<c10::complex<double>>(
    cusolverDnHandle_t handle,
    cusolverEigMode_t jobz,
    int m,
    int n,
    c10::complex<double>* A,
    int lda,
    double* S,
    c10::complex<double>* U,
    int ldu,
    c10::complex<double>* V,
    int ldv,
    int* info,
    gesvdjInfo_t params,
    int batchSize) {
  int lwork = 0;
  TORCH_CUSOLVER_CHECK(cusolverDnZgesvdjBatched_bufferSize(
      handle, jobz, m, n,
      reinterpret_cast<cuDoubleComplex*>(A), lda, S,
      reinterpret_cast<cuDoubleComplex*>(U), ldu,
      reinterpret_cast<cuDoubleComplex*>(V), ldv,
      &lwork, params, batchSize));

  // lwork counts cuDoubleComplex elements.
  const auto workspace_bytes = sizeof(cuDoubleComplex) * lwork;
  auto dataPtr =
      ::c10::cuda::CUDACachingAllocator::get()->allocate(workspace_bytes);

  TORCH_CUSOLVER_CHECK(cusolverDnZgesvdjBatched(
      handle, jobz, m, n,
      reinterpret_cast<cuDoubleComplex*>(A), lda, S,
      reinterpret_cast<cuDoubleComplex*>(U), ldu,
      reinterpret_cast<cuDoubleComplex*>(V), ldv,
      static_cast<cuDoubleComplex*>(dataPtr.get()), lwork,
      info, params, batchSize));
}
471
472
// ROCm does not implement gesvda yet
474 #ifdef CUDART_VERSION
// gesvdaStridedBatched_buffersize<float>: workspace query for
// cusolverDnSgesvdaStridedBatched (approximate SVD keeping `rank` singular
// triplets) over batchSize strided m x n matrices. Each stride* argument is
// the distance between consecutive batch entries of that buffer. The size is
// written through lwork.
template <>
void gesvdaStridedBatched_buffersize<float>(
    cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n,
    float* A, int lda, long long int strideA,
    float* S, long long int strideS,
    float* U, int ldu, long long int strideU,
    float* V, int ldv, long long int strideV,
    int* lwork, int batchSize) {
  TORCH_CUSOLVER_CHECK(cusolverDnSgesvdaStridedBatched_bufferSize(
      handle, jobz, rank, m, n,
      A, lda, strideA,
      S, strideS,
      U, ldu, strideU,
      V, ldv, strideV,
      lwork, batchSize
  ));
}
485
// gesvdaStridedBatched_buffersize<double>: workspace query for
// cusolverDnDgesvdaStridedBatched over batchSize strided m x n matrices,
// keeping `rank` singular triplets. The size is written through lwork.
template <>
void gesvdaStridedBatched_buffersize<double>(
    cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n,
    double* A, int lda, long long int strideA,
    double* S, long long int strideS,
    double* U, int ldu, long long int strideU,
    double* V, int ldv, long long int strideV,
    int* lwork, int batchSize) {
  TORCH_CUSOLVER_CHECK(cusolverDnDgesvdaStridedBatched_bufferSize(
      handle, jobz, rank, m, n,
      A, lda, strideA,
      S, strideS,
      U, ldu, strideU,
      V, ldv, strideV,
      lwork, batchSize
  ));
}
496
// gesvdaStridedBatched_buffersize<c10::complex<float>>: workspace query for
// cusolverDnCgesvdaStridedBatched. A, U and V are reinterpreted as cuComplex;
// S stays real (float). The size is written through lwork.
template <>
void gesvdaStridedBatched_buffersize<c10::complex<float>>(
    cusolverDnHandle_t handle,
    cusolverEigMode_t jobz,
    int rank,
    int m,
    int n,
    c10::complex<float>* A,
    int lda,
    long long int strideA,
    float* S,
    long long int strideS,
    c10::complex<float>* U,
    int ldu,
    long long int strideU,
    c10::complex<float>* V,
    int ldv,
    long long int strideV,
    int* lwork,
    int batchSize) {
  TORCH_CUSOLVER_CHECK(cusolverDnCgesvdaStridedBatched_bufferSize(
      handle, jobz, rank, m, n,
      reinterpret_cast<cuComplex*>(A), lda, strideA,
      S, strideS,
      reinterpret_cast<cuComplex*>(U), ldu, strideU,
      reinterpret_cast<cuComplex*>(V), ldv, strideV,
      lwork, batchSize
  ));
}
514
// gesvdaStridedBatched_buffersize<c10::complex<double>>: workspace query for
// cusolverDnZgesvdaStridedBatched. A, U and V are reinterpreted as
// cuDoubleComplex; S stays real (double). The size is written through lwork.
template <>
void gesvdaStridedBatched_buffersize<c10::complex<double>>(
    cusolverDnHandle_t handle,
    cusolverEigMode_t jobz,
    int rank,
    int m,
    int n,
    c10::complex<double>* A,
    int lda,
    long long int strideA,
    double* S,
    long long int strideS,
    c10::complex<double>* U,
    int ldu,
    long long int strideU,
    c10::complex<double>* V,
    int ldv,
    long long int strideV,
    int* lwork,
    int batchSize) {
  TORCH_CUSOLVER_CHECK(cusolverDnZgesvdaStridedBatched_bufferSize(
      handle, jobz, rank, m, n,
      reinterpret_cast<cuDoubleComplex*>(A), lda, strideA,
      S, strideS,
      reinterpret_cast<cuDoubleComplex*>(U), ldu, strideU,
      reinterpret_cast<cuDoubleComplex*>(V), ldv, strideV,
      lwork, batchSize
  ));
}
532
533
// gesvdaStridedBatched<float>: approximate SVD (rank leading singular
// triplets) of batchSize strided m x n matrices via
// cusolverDnSgesvdaStridedBatched. work/lwork is caller-provided scratch and
// info receives the info codes. h_R_nrmF receives residual norms (a host array
// per the cuSOLVER docs).
template <>
void gesvdaStridedBatched<float>(
    cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n,
    float* A, int lda, long long int strideA,
    float* S, long long int strideS,
    float* U, int ldu, long long int strideU,
    float* V, int ldv, long long int strideV,
    float* work, int lwork, int* info, double* h_R_nrmF, int batchSize) {
  TORCH_CUSOLVER_CHECK(cusolverDnSgesvdaStridedBatched(
      handle, jobz, rank, m, n,
      A, lda, strideA,
      S, strideS,
      U, ldu, strideU,
      V, ldv, strideV,
      work, lwork, info,
      h_R_nrmF, batchSize
  ));
}
544
// gesvdaStridedBatched<double>: double-precision approximate SVD of batchSize
// strided m x n matrices via cusolverDnDgesvdaStridedBatched. It uses the
// caller's work/lwork, writes info codes to info, and writes residual norms to
// h_R_nrmF.
template <>
void gesvdaStridedBatched<double>(
    cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n,
    double* A, int lda, long long int strideA,
    double* S, long long int strideS,
    double* U, int ldu, long long int strideU,
    double* V, int ldv, long long int strideV,
    double* work, int lwork, int* info, double* h_R_nrmF, int batchSize) {
  TORCH_CUSOLVER_CHECK(cusolverDnDgesvdaStridedBatched(
      handle, jobz, rank, m, n,
      A, lda, strideA,
      S, strideS,
      U, ldu, strideU,
      V, ldv, strideV,
      work, lwork, info,
      h_R_nrmF, batchSize
  ));
}
555
// gesvdaStridedBatched<c10::complex<float>>: complex-float approximate
// strided-batched SVD via cusolverDnCgesvdaStridedBatched. A, U, V and work are
// reinterpreted as cuComplex; S stays real (float) and h_R_nrmF is double.
template <>
void gesvdaStridedBatched<c10::complex<float>>(
    cusolverDnHandle_t handle,
    cusolverEigMode_t jobz,
    int rank,
    int m,
    int n,
    c10::complex<float>* A,
    int lda,
    long long int strideA,
    float* S,
    long long int strideS,
    c10::complex<float>* U,
    int ldu,
    long long int strideU,
    c10::complex<float>* V,
    int ldv,
    long long int strideV,
    c10::complex<float>* work,
    int lwork,
    int* info,
    double* h_R_nrmF,
    int batchSize) {
  TORCH_CUSOLVER_CHECK(cusolverDnCgesvdaStridedBatched(
      handle, jobz, rank, m, n,
      reinterpret_cast<cuComplex*>(A), lda, strideA,
      S, strideS,
      reinterpret_cast<cuComplex*>(U), ldu, strideU,
      reinterpret_cast<cuComplex*>(V), ldv, strideV,
      reinterpret_cast<cuComplex*>(work), lwork,
      info, h_R_nrmF, batchSize
  ));
}
575
// gesvdaStridedBatched<c10::complex<double>>: complex-double approximate
// strided-batched SVD via cusolverDnZgesvdaStridedBatched. A, U, V and work
// are reinterpreted as cuDoubleComplex; S stays real (double).
template <>
void gesvdaStridedBatched<c10::complex<double>>(
    cusolverDnHandle_t handle,
    cusolverEigMode_t jobz,
    int rank,
    int m,
    int n,
    c10::complex<double>* A,
    int lda,
    long long int strideA,
    double* S,
    long long int strideS,
    c10::complex<double>* U,
    int ldu,
    long long int strideU,
    c10::complex<double>* V,
    int ldv,
    long long int strideV,
    c10::complex<double>* work,
    int lwork,
    int* info,
    double* h_R_nrmF,
    int batchSize) {
  TORCH_CUSOLVER_CHECK(cusolverDnZgesvdaStridedBatched(
      handle, jobz, rank, m, n,
      reinterpret_cast<cuDoubleComplex*>(A), lda, strideA,
      S, strideS,
      reinterpret_cast<cuDoubleComplex*>(U), ldu, strideU,
      reinterpret_cast<cuDoubleComplex*>(V), ldv, strideV,
      reinterpret_cast<cuDoubleComplex*>(work), lwork,
      info, h_R_nrmF, batchSize
  ));
}
595 #endif
596
597
// potrf<float>: Cholesky factorization of the n x n matrix A (leading
// dimension lda) via cusolverDnSpotrf. uplo selects the referenced triangle,
// work/lwork is the caller-supplied workspace, and info receives the info code.
template <>
void potrf<float>(
    cusolverDnHandle_t handle,
    cublasFillMode_t uplo,
    int n,
    float* A,
    int lda,
    float* work,
    int lwork,
    int* info) {
  TORCH_CUSOLVER_CHECK(cusolverDnSpotrf(
      handle, uplo, n,
      A, lda,
      work, lwork,
      info));
}
605
// potrf<double>: double-precision Cholesky factorization via cusolverDnDpotrf.
// uplo selects the referenced triangle, work/lwork is the caller-supplied
// workspace, and info receives the info code.
template <>
void potrf<double>(
    cusolverDnHandle_t handle,
    cublasFillMode_t uplo,
    int n,
    double* A,
    int lda,
    double* work,
    int lwork,
    int* info) {
  TORCH_CUSOLVER_CHECK(cusolverDnDpotrf(
      handle, uplo, n,
      A, lda,
      work, lwork,
      info));
}
613
// potrf<c10::complex<float>>: complex-float Cholesky factorization via
// cusolverDnCpotrf. A and work are reinterpreted as cuComplex.
template <>
void potrf<c10::complex<float>>(
    cusolverDnHandle_t handle, cublasFillMode_t uplo, int n,
    c10::complex<float>* A, int lda,
    c10::complex<float>* work, int lwork,
    int* info) {
  TORCH_CUSOLVER_CHECK(cusolverDnCpotrf(
      handle, uplo, n,
      reinterpret_cast<cuComplex*>(A), lda,
      reinterpret_cast<cuComplex*>(work), lwork,
      info));
}
628
// potrf<c10::complex<double>>: complex-double Cholesky factorization via
// cusolverDnZpotrf. A and work are reinterpreted as cuDoubleComplex.
template <>
void potrf<c10::complex<double>>(
    cusolverDnHandle_t handle, cublasFillMode_t uplo, int n,
    c10::complex<double>* A, int lda,
    c10::complex<double>* work, int lwork,
    int* info) {
  TORCH_CUSOLVER_CHECK(cusolverDnZpotrf(
      handle, uplo, n,
      reinterpret_cast<cuDoubleComplex*>(A), lda,
      reinterpret_cast<cuDoubleComplex*>(work), lwork,
      info));
}
643
644
// potrf_buffersize<float>: workspace query for cusolverDnSpotrf on the n x n
// matrix A (leading dimension lda, triangle uplo), written through lwork.
template <>
void potrf_buffersize<float>(
    cusolverDnHandle_t handle,
    cublasFillMode_t uplo,
    int n,
    float* A,
    int lda,
    int* lwork) {
  TORCH_CUSOLVER_CHECK(
      cusolverDnSpotrf_bufferSize(handle, uplo, n, A, lda, lwork));
}
651
// potrf_buffersize<double>: workspace query for cusolverDnDpotrf on the n x n
// matrix A (leading dimension lda, triangle uplo), written through lwork.
template <>
void potrf_buffersize<double>(
    cusolverDnHandle_t handle,
    cublasFillMode_t uplo,
    int n,
    double* A,
    int lda,
    int* lwork) {
  TORCH_CUSOLVER_CHECK(
      cusolverDnDpotrf_bufferSize(handle, uplo, n, A, lda, lwork));
}
658
// potrf_buffersize<c10::complex<float>>: workspace query for cusolverDnCpotrf;
// A is reinterpreted as cuComplex and the size is written through lwork.
template <>
void potrf_buffersize<c10::complex<float>>(
    cusolverDnHandle_t handle, cublasFillMode_t uplo, int n,
    c10::complex<float>* A, int lda, int* lwork) {
  TORCH_CUSOLVER_CHECK(cusolverDnCpotrf_bufferSize(
      handle, uplo, n, reinterpret_cast<cuComplex*>(A), lda, lwork));
}
668
// potrf_buffersize<c10::complex<double>>: workspace query for
// cusolverDnZpotrf; A is reinterpreted as cuDoubleComplex and the size is
// written through lwork.
template <>
void potrf_buffersize<c10::complex<double>>(
    cusolverDnHandle_t handle, cublasFillMode_t uplo, int n,
    c10::complex<double>* A, int lda, int* lwork) {
  TORCH_CUSOLVER_CHECK(cusolverDnZpotrf_bufferSize(
      handle, uplo, n, reinterpret_cast<cuDoubleComplex*>(A), lda, lwork));
}
678
679
// potrfBatched<float>: batched Cholesky factorization via
// cusolverDnSpotrfBatched. A is an array of batchSize matrix pointers, each
// n x n with leading dimension lda; info receives the info codes.
template <>
void potrfBatched<float>(
    cusolverDnHandle_t handle,
    cublasFillMode_t uplo,
    int n,
    float** A,
    int lda,
    int* info,
    int batchSize) {
  TORCH_CUSOLVER_CHECK(
      cusolverDnSpotrfBatched(handle, uplo, n, A, lda, info, batchSize));
}
686
// potrfBatched<double>: batched double-precision Cholesky factorization via
// cusolverDnDpotrfBatched over the batchSize matrix pointers in A.
template <>
void potrfBatched<double>(
    cusolverDnHandle_t handle,
    cublasFillMode_t uplo,
    int n,
    double** A,
    int lda,
    int* info,
    int batchSize) {
  TORCH_CUSOLVER_CHECK(
      cusolverDnDpotrfBatched(handle, uplo, n, A, lda, info, batchSize));
}
693
// potrfBatched<c10::complex<float>>: batched complex-float Cholesky via
// cusolverDnCpotrfBatched. The pointer array A is reinterpreted as cuComplex**.
template <>
void potrfBatched<c10::complex<float>>(
    cusolverDnHandle_t handle, cublasFillMode_t uplo, int n,
    c10::complex<float>** A, int lda, int* info, int batchSize) {
  TORCH_CUSOLVER_CHECK(cusolverDnCpotrfBatched(
      handle, uplo, n, reinterpret_cast<cuComplex**>(A), lda, info, batchSize));
}
703
704 template<>
potrfBatched(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,c10::complex<double> ** A,int lda,int * info,int batchSize)705 void potrfBatched<c10::complex<double>>(
706 cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, c10::complex<double>** A, int lda, int* info, int batchSize
707 ) {
708 TORCH_CUSOLVER_CHECK(cusolverDnZpotrfBatched(
709 handle, uplo, n,
710 reinterpret_cast<cuDoubleComplex**>(A),
711 lda, info, batchSize));
712 }
713
714 template <>
geqrf_bufferSize(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES (float))715 void geqrf_bufferSize<float>(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(float)) {
716 TORCH_CUSOLVER_CHECK(
717 cusolverDnSgeqrf_bufferSize(handle, m, n, A, lda, lwork));
718 }
719
720 template <>
geqrf_bufferSize(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES (double))721 void geqrf_bufferSize<double>(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(double)) {
722 TORCH_CUSOLVER_CHECK(
723 cusolverDnDgeqrf_bufferSize(handle, m, n, A, lda, lwork));
724 }
725
726 template <>
geqrf_bufferSize(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES (c10::complex<float>))727 void geqrf_bufferSize<c10::complex<float>>(
728 CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(c10::complex<float>)) {
729 TORCH_CUSOLVER_CHECK(cusolverDnCgeqrf_bufferSize(
730 handle, m, n, reinterpret_cast<cuComplex*>(A), lda, lwork));
731 }
732
733 template <>
geqrf_bufferSize(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES (c10::complex<double>))734 void geqrf_bufferSize<c10::complex<double>>(
735 CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(c10::complex<double>)) {
736 TORCH_CUSOLVER_CHECK(cusolverDnZgeqrf_bufferSize(
737 handle, m, n, reinterpret_cast<cuDoubleComplex*>(A), lda, lwork));
738 }
739
740 template <>
geqrf(CUDASOLVER_GEQRF_ARGTYPES (float))741 void geqrf<float>(CUDASOLVER_GEQRF_ARGTYPES(float)) {
742 TORCH_CUSOLVER_CHECK(
743 cusolverDnSgeqrf(handle, m, n, A, lda, tau, work, lwork, devInfo));
744 }
745
746 template <>
geqrf(CUDASOLVER_GEQRF_ARGTYPES (double))747 void geqrf<double>(CUDASOLVER_GEQRF_ARGTYPES(double)) {
748 TORCH_CUSOLVER_CHECK(
749 cusolverDnDgeqrf(handle, m, n, A, lda, tau, work, lwork, devInfo));
750 }
751
752 template <>
geqrf(CUDASOLVER_GEQRF_ARGTYPES (c10::complex<float>))753 void geqrf<c10::complex<float>>(
754 CUDASOLVER_GEQRF_ARGTYPES(c10::complex<float>)) {
755 TORCH_CUSOLVER_CHECK(cusolverDnCgeqrf(
756 handle,
757 m,
758 n,
759 reinterpret_cast<cuComplex*>(A),
760 lda,
761 reinterpret_cast<cuComplex*>(tau),
762 reinterpret_cast<cuComplex*>(work),
763 lwork,
764 devInfo));
765 }
766
767 template <>
geqrf(CUDASOLVER_GEQRF_ARGTYPES (c10::complex<double>))768 void geqrf<c10::complex<double>>(
769 CUDASOLVER_GEQRF_ARGTYPES(c10::complex<double>)) {
770 TORCH_CUSOLVER_CHECK(cusolverDnZgeqrf(
771 handle,
772 m,
773 n,
774 reinterpret_cast<cuDoubleComplex*>(A),
775 lda,
776 reinterpret_cast<cuDoubleComplex*>(tau),
777 reinterpret_cast<cuDoubleComplex*>(work),
778 lwork,
779 devInfo));
780 }
781
782 template<>
potrs(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,int nrhs,const float * A,int lda,float * B,int ldb,int * devInfo)783 void potrs<float>(
784 cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, const float *A, int lda, float *B, int ldb, int *devInfo
785 ) {
786 TORCH_CUSOLVER_CHECK(cusolverDnSpotrs(handle, uplo, n, nrhs, A, lda, B, ldb, devInfo));
787 }
788
789 template<>
potrs(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,int nrhs,const double * A,int lda,double * B,int ldb,int * devInfo)790 void potrs<double>(
791 cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, const double *A, int lda, double *B, int ldb, int *devInfo
792 ) {
793 TORCH_CUSOLVER_CHECK(cusolverDnDpotrs(handle, uplo, n, nrhs, A, lda, B, ldb, devInfo));
794 }
795
796 template<>
potrs(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,int nrhs,const c10::complex<float> * A,int lda,c10::complex<float> * B,int ldb,int * devInfo)797 void potrs<c10::complex<float>>(
798 cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, const c10::complex<float> *A, int lda, c10::complex<float> *B, int ldb, int *devInfo
799 ) {
800 TORCH_CUSOLVER_CHECK(cusolverDnCpotrs(
801 handle, uplo, n, nrhs,
802 reinterpret_cast<const cuComplex*>(A),
803 lda,
804 reinterpret_cast<cuComplex*>(B),
805 ldb, devInfo));
806 }
807
808 template<>
potrs(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,int nrhs,const c10::complex<double> * A,int lda,c10::complex<double> * B,int ldb,int * devInfo)809 void potrs<c10::complex<double>>(
810 cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, const c10::complex<double> *A, int lda, c10::complex<double> *B, int ldb, int *devInfo
811 ) {
812 TORCH_CUSOLVER_CHECK(cusolverDnZpotrs(
813 handle, uplo, n, nrhs,
814 reinterpret_cast<const cuDoubleComplex*>(A),
815 lda,
816 reinterpret_cast<cuDoubleComplex*>(B),
817 ldb, devInfo));
818 }
819
820 template<>
potrsBatched(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,int nrhs,float * Aarray[],int lda,float * Barray[],int ldb,int * info,int batchSize)821 void potrsBatched<float>(
822 cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, float *Aarray[], int lda, float *Barray[], int ldb, int *info, int batchSize
823 ) {
824 TORCH_CUSOLVER_CHECK(cusolverDnSpotrsBatched(handle, uplo, n, nrhs, Aarray, lda, Barray, ldb, info, batchSize));
825 }
826
827 template<>
potrsBatched(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,int nrhs,double * Aarray[],int lda,double * Barray[],int ldb,int * info,int batchSize)828 void potrsBatched<double>(
829 cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, double *Aarray[], int lda, double *Barray[], int ldb, int *info, int batchSize
830 ) {
831 TORCH_CUSOLVER_CHECK(cusolverDnDpotrsBatched(handle, uplo, n, nrhs, Aarray, lda, Barray, ldb, info, batchSize));
832 }
833
834 template<>
potrsBatched(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,int nrhs,c10::complex<float> * Aarray[],int lda,c10::complex<float> * Barray[],int ldb,int * info,int batchSize)835 void potrsBatched<c10::complex<float>>(
836 cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, c10::complex<float> *Aarray[], int lda, c10::complex<float> *Barray[], int ldb, int *info, int batchSize
837 ) {
838 TORCH_CUSOLVER_CHECK(cusolverDnCpotrsBatched(
839 handle, uplo, n, nrhs,
840 reinterpret_cast<cuComplex**>(Aarray),
841 lda,
842 reinterpret_cast<cuComplex**>(Barray),
843 ldb, info, batchSize));
844 }
845
846 template<>
potrsBatched(cusolverDnHandle_t handle,cublasFillMode_t uplo,int n,int nrhs,c10::complex<double> * Aarray[],int lda,c10::complex<double> * Barray[],int ldb,int * info,int batchSize)847 void potrsBatched<c10::complex<double>>(
848 cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, c10::complex<double> *Aarray[], int lda, c10::complex<double> *Barray[], int ldb, int *info, int batchSize
849 ) {
850 TORCH_CUSOLVER_CHECK(cusolverDnZpotrsBatched(
851 handle, uplo, n, nrhs,
852 reinterpret_cast<cuDoubleComplex**>(Aarray),
853 lda,
854 reinterpret_cast<cuDoubleComplex**>(Barray),
855 ldb, info, batchSize));
856 }
857
858
859 template <>
orgqr_buffersize(cusolverDnHandle_t handle,int m,int n,int k,const float * A,int lda,const float * tau,int * lwork)860 void orgqr_buffersize<float>(
861 cusolverDnHandle_t handle,
862 int m, int n, int k,
863 const float* A, int lda,
864 const float* tau, int* lwork) {
865 TORCH_CUSOLVER_CHECK(
866 cusolverDnSorgqr_bufferSize(handle, m, n, k, A, lda, tau, lwork));
867 }
868
869 template <>
orgqr_buffersize(cusolverDnHandle_t handle,int m,int n,int k,const double * A,int lda,const double * tau,int * lwork)870 void orgqr_buffersize<double>(
871 cusolverDnHandle_t handle,
872 int m, int n, int k,
873 const double* A, int lda,
874 const double* tau, int* lwork) {
875 TORCH_CUSOLVER_CHECK(
876 cusolverDnDorgqr_bufferSize(handle, m, n, k, A, lda, tau, lwork));
877 }
878
879 template <>
orgqr_buffersize(cusolverDnHandle_t handle,int m,int n,int k,const c10::complex<float> * A,int lda,const c10::complex<float> * tau,int * lwork)880 void orgqr_buffersize<c10::complex<float>>(
881 cusolverDnHandle_t handle,
882 int m, int n, int k,
883 const c10::complex<float>* A, int lda,
884 const c10::complex<float>* tau, int* lwork) {
885 TORCH_CUSOLVER_CHECK(cusolverDnCungqr_bufferSize(
886 handle,
887 m, n, k,
888 reinterpret_cast<const cuComplex*>(A), lda,
889 reinterpret_cast<const cuComplex*>(tau), lwork));
890 }
891
892 template <>
orgqr_buffersize(cusolverDnHandle_t handle,int m,int n,int k,const c10::complex<double> * A,int lda,const c10::complex<double> * tau,int * lwork)893 void orgqr_buffersize<c10::complex<double>>(
894 cusolverDnHandle_t handle,
895 int m, int n, int k,
896 const c10::complex<double>* A, int lda,
897 const c10::complex<double>* tau, int* lwork) {
898 TORCH_CUSOLVER_CHECK(cusolverDnZungqr_bufferSize(
899 handle,
900 m, n, k,
901 reinterpret_cast<const cuDoubleComplex*>(A), lda,
902 reinterpret_cast<const cuDoubleComplex*>(tau), lwork));
903 }
904
905 template <>
orgqr(cusolverDnHandle_t handle,int m,int n,int k,float * A,int lda,const float * tau,float * work,int lwork,int * devInfo)906 void orgqr<float>(
907 cusolverDnHandle_t handle,
908 int m, int n, int k,
909 float* A, int lda,
910 const float* tau,
911 float* work, int lwork,
912 int* devInfo) {
913 TORCH_CUSOLVER_CHECK(
914 cusolverDnSorgqr(handle, m, n, k, A, lda, tau, work, lwork, devInfo));
915 }
916
917 template <>
orgqr(cusolverDnHandle_t handle,int m,int n,int k,double * A,int lda,const double * tau,double * work,int lwork,int * devInfo)918 void orgqr<double>(
919 cusolverDnHandle_t handle,
920 int m, int n, int k,
921 double* A, int lda,
922 const double* tau,
923 double* work, int lwork,
924 int* devInfo) {
925 TORCH_CUSOLVER_CHECK(
926 cusolverDnDorgqr(handle, m, n, k, A, lda, tau, work, lwork, devInfo));
927 }
928
929 template <>
orgqr(cusolverDnHandle_t handle,int m,int n,int k,c10::complex<float> * A,int lda,const c10::complex<float> * tau,c10::complex<float> * work,int lwork,int * devInfo)930 void orgqr<c10::complex<float>>(
931 cusolverDnHandle_t handle,
932 int m, int n, int k,
933 c10::complex<float>* A, int lda,
934 const c10::complex<float>* tau,
935 c10::complex<float>* work, int lwork,
936 int* devInfo) {
937 TORCH_CUSOLVER_CHECK(cusolverDnCungqr(
938 handle,
939 m, n, k,
940 reinterpret_cast<cuComplex*>(A), lda,
941 reinterpret_cast<const cuComplex*>(tau),
942 reinterpret_cast<cuComplex*>(work), lwork,
943 devInfo));
944 }
945
946 template <>
orgqr(cusolverDnHandle_t handle,int m,int n,int k,c10::complex<double> * A,int lda,const c10::complex<double> * tau,c10::complex<double> * work,int lwork,int * devInfo)947 void orgqr<c10::complex<double>>(
948 cusolverDnHandle_t handle,
949 int m, int n, int k,
950 c10::complex<double>* A, int lda,
951 const c10::complex<double>* tau,
952 c10::complex<double>* work, int lwork,
953 int* devInfo) {
954 TORCH_CUSOLVER_CHECK(cusolverDnZungqr(
955 handle,
956 m, n, k,
957 reinterpret_cast<cuDoubleComplex*>(A), lda,
958 reinterpret_cast<const cuDoubleComplex*>(tau),
959 reinterpret_cast<cuDoubleComplex*>(work), lwork,
960 devInfo));
961 }
962
963 template <>
ormqr_bufferSize(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES (float))964 void ormqr_bufferSize<float>(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(float)) {
965 TORCH_CUSOLVER_CHECK(
966 cusolverDnSormqr_bufferSize(handle, side, trans, m, n, k, A, lda, tau, C, ldc, lwork));
967 }
968
969 template <>
ormqr_bufferSize(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES (double))970 void ormqr_bufferSize<double>(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(double)) {
971 TORCH_CUSOLVER_CHECK(
972 cusolverDnDormqr_bufferSize(handle, side, trans, m, n, k, A, lda, tau, C, ldc, lwork));
973 }
974
975 template <>
ormqr_bufferSize(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES (c10::complex<float>))976 void ormqr_bufferSize<c10::complex<float>>(
977 CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(c10::complex<float>)) {
978 TORCH_CUSOLVER_CHECK(cusolverDnCunmqr_bufferSize(
979 handle, side, trans,
980 m, n, k,
981 reinterpret_cast<const cuComplex*>(A), lda,
982 reinterpret_cast<const cuComplex*>(tau),
983 reinterpret_cast<const cuComplex*>(C), ldc,
984 lwork));
985 }
986
987 template <>
ormqr_bufferSize(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES (c10::complex<double>))988 void ormqr_bufferSize<c10::complex<double>>(
989 CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(c10::complex<double>)) {
990 TORCH_CUSOLVER_CHECK(cusolverDnZunmqr_bufferSize(
991 handle, side, trans,
992 m, n, k,
993 reinterpret_cast<const cuDoubleComplex*>(A), lda,
994 reinterpret_cast<const cuDoubleComplex*>(tau),
995 reinterpret_cast<const cuDoubleComplex*>(C), ldc,
996 lwork));
997 }
998
999 template <>
ormqr(CUDASOLVER_ORMQR_ARGTYPES (float))1000 void ormqr<float>(CUDASOLVER_ORMQR_ARGTYPES(float)) {
1001 TORCH_CUSOLVER_CHECK(
1002 cusolverDnSormqr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, work, lwork, devInfo));
1003 }
1004
1005 template <>
ormqr(CUDASOLVER_ORMQR_ARGTYPES (double))1006 void ormqr<double>(CUDASOLVER_ORMQR_ARGTYPES(double)) {
1007 TORCH_CUSOLVER_CHECK(
1008 cusolverDnDormqr(handle, side, trans, m, n, k, A, lda, tau, C, ldc, work, lwork, devInfo));
1009 }
1010
1011 template <>
ormqr(CUDASOLVER_ORMQR_ARGTYPES (c10::complex<float>))1012 void ormqr<c10::complex<float>>(CUDASOLVER_ORMQR_ARGTYPES(c10::complex<float>)) {
1013 TORCH_CUSOLVER_CHECK(cusolverDnCunmqr(
1014 handle, side, trans,
1015 m, n, k,
1016 reinterpret_cast<const cuComplex*>(A), lda,
1017 reinterpret_cast<const cuComplex*>(tau),
1018 reinterpret_cast<cuComplex*>(C), ldc,
1019 reinterpret_cast<cuComplex*>(work), lwork,
1020 devInfo));
1021 }
1022
1023 template <>
ormqr(CUDASOLVER_ORMQR_ARGTYPES (c10::complex<double>))1024 void ormqr<c10::complex<double>>(CUDASOLVER_ORMQR_ARGTYPES(c10::complex<double>)) {
1025 TORCH_CUSOLVER_CHECK(cusolverDnZunmqr(
1026 handle, side, trans,
1027 m, n, k,
1028 reinterpret_cast<const cuDoubleComplex*>(A), lda,
1029 reinterpret_cast<const cuDoubleComplex*>(tau),
1030 reinterpret_cast<cuDoubleComplex*>(C), ldc,
1031 reinterpret_cast<cuDoubleComplex*>(work), lwork,
1032 devInfo));
1033 }
1034
1035 #ifdef USE_CUSOLVER_64_BIT
1036
get_cusolver_datatype()1037 template<> cudaDataType get_cusolver_datatype<float>() { return CUDA_R_32F; }
get_cusolver_datatype()1038 template<> cudaDataType get_cusolver_datatype<double>() { return CUDA_R_64F; }
get_cusolver_datatype()1039 template<> cudaDataType get_cusolver_datatype<c10::complex<float>>() { return CUDA_C_32F; }
get_cusolver_datatype()1040 template<> cudaDataType get_cusolver_datatype<c10::complex<double>>() { return CUDA_C_64F; }
1041
xpotrf_buffersize(cusolverDnHandle_t handle,cusolverDnParams_t params,cublasFillMode_t uplo,int64_t n,cudaDataType dataTypeA,const void * A,int64_t lda,cudaDataType computeType,size_t * workspaceInBytesOnDevice,size_t * workspaceInBytesOnHost)1042 void xpotrf_buffersize(
1043 cusolverDnHandle_t handle, cusolverDnParams_t params, cublasFillMode_t uplo, int64_t n, cudaDataType dataTypeA, const void *A,
1044 int64_t lda, cudaDataType computeType, size_t *workspaceInBytesOnDevice, size_t *workspaceInBytesOnHost) {
1045 TORCH_CUSOLVER_CHECK(cusolverDnXpotrf_bufferSize(
1046 handle, params, uplo, n, dataTypeA, A, lda, computeType, workspaceInBytesOnDevice, workspaceInBytesOnHost
1047 ));
1048 }
1049
xpotrf(cusolverDnHandle_t handle,cusolverDnParams_t params,cublasFillMode_t uplo,int64_t n,cudaDataType dataTypeA,void * A,int64_t lda,cudaDataType computeType,void * bufferOnDevice,size_t workspaceInBytesOnDevice,void * bufferOnHost,size_t workspaceInBytesOnHost,int * info)1050 void xpotrf(
1051 cusolverDnHandle_t handle, cusolverDnParams_t params, cublasFillMode_t uplo, int64_t n, cudaDataType dataTypeA, void *A,
1052 int64_t lda, cudaDataType computeType, void *bufferOnDevice, size_t workspaceInBytesOnDevice, void *bufferOnHost, size_t workspaceInBytesOnHost,
1053 int *info) {
1054 TORCH_CUSOLVER_CHECK(cusolverDnXpotrf(
1055 handle, params, uplo, n, dataTypeA, A, lda, computeType, bufferOnDevice, workspaceInBytesOnDevice, bufferOnHost, workspaceInBytesOnHost, info
1056 ));
1057 }
1058 #endif // USE_CUSOLVER_64_BIT
1059
1060 template <>
syevd_bufferSize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,const float * A,int lda,const float * W,int * lwork)1061 void syevd_bufferSize<float>(
1062 cusolverDnHandle_t handle,
1063 cusolverEigMode_t jobz,
1064 cublasFillMode_t uplo,
1065 int n,
1066 const float* A,
1067 int lda,
1068 const float* W,
1069 int* lwork) {
1070 TORCH_CUSOLVER_CHECK(
1071 cusolverDnSsyevd_bufferSize(handle, jobz, uplo, n, A, lda, W, lwork));
1072 }
1073
1074 template <>
syevd_bufferSize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,const double * A,int lda,const double * W,int * lwork)1075 void syevd_bufferSize<double>(
1076 cusolverDnHandle_t handle,
1077 cusolverEigMode_t jobz,
1078 cublasFillMode_t uplo,
1079 int n,
1080 const double* A,
1081 int lda,
1082 const double* W,
1083 int* lwork) {
1084 TORCH_CUSOLVER_CHECK(
1085 cusolverDnDsyevd_bufferSize(handle, jobz, uplo, n, A, lda, W, lwork));
1086 }
1087
1088 template <>
syevd_bufferSize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,const c10::complex<float> * A,int lda,const float * W,int * lwork)1089 void syevd_bufferSize<c10::complex<float>, float>(
1090 cusolverDnHandle_t handle,
1091 cusolverEigMode_t jobz,
1092 cublasFillMode_t uplo,
1093 int n,
1094 const c10::complex<float>* A,
1095 int lda,
1096 const float* W,
1097 int* lwork) {
1098 TORCH_CUSOLVER_CHECK(cusolverDnCheevd_bufferSize(
1099 handle,
1100 jobz,
1101 uplo,
1102 n,
1103 reinterpret_cast<const cuComplex*>(A),
1104 lda,
1105 W,
1106 lwork));
1107 }
1108
1109 template <>
syevd_bufferSize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,const c10::complex<double> * A,int lda,const double * W,int * lwork)1110 void syevd_bufferSize<c10::complex<double>, double>(
1111 cusolverDnHandle_t handle,
1112 cusolverEigMode_t jobz,
1113 cublasFillMode_t uplo,
1114 int n,
1115 const c10::complex<double>* A,
1116 int lda,
1117 const double* W,
1118 int* lwork) {
1119 TORCH_CUSOLVER_CHECK(cusolverDnZheevd_bufferSize(
1120 handle,
1121 jobz,
1122 uplo,
1123 n,
1124 reinterpret_cast<const cuDoubleComplex*>(A),
1125 lda,
1126 W,
1127 lwork));
1128 }
1129
1130 template <>
syevd(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,float * A,int lda,float * W,float * work,int lwork,int * info)1131 void syevd<float>(
1132 cusolverDnHandle_t handle,
1133 cusolverEigMode_t jobz,
1134 cublasFillMode_t uplo,
1135 int n,
1136 float* A,
1137 int lda,
1138 float* W,
1139 float* work,
1140 int lwork,
1141 int* info) {
1142 TORCH_CUSOLVER_CHECK(
1143 cusolverDnSsyevd(handle, jobz, uplo, n, A, lda, W, work, lwork, info));
1144 }
1145
1146 template <>
syevd(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,double * A,int lda,double * W,double * work,int lwork,int * info)1147 void syevd<double>(
1148 cusolverDnHandle_t handle,
1149 cusolverEigMode_t jobz,
1150 cublasFillMode_t uplo,
1151 int n,
1152 double* A,
1153 int lda,
1154 double* W,
1155 double* work,
1156 int lwork,
1157 int* info) {
1158 TORCH_CUSOLVER_CHECK(
1159 cusolverDnDsyevd(handle, jobz, uplo, n, A, lda, W, work, lwork, info));
1160 }
1161
1162 template <>
syevd(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,c10::complex<float> * A,int lda,float * W,c10::complex<float> * work,int lwork,int * info)1163 void syevd<c10::complex<float>, float>(
1164 cusolverDnHandle_t handle,
1165 cusolverEigMode_t jobz,
1166 cublasFillMode_t uplo,
1167 int n,
1168 c10::complex<float>* A,
1169 int lda,
1170 float* W,
1171 c10::complex<float>* work,
1172 int lwork,
1173 int* info) {
1174 TORCH_CUSOLVER_CHECK(cusolverDnCheevd(
1175 handle,
1176 jobz,
1177 uplo,
1178 n,
1179 reinterpret_cast<cuComplex*>(A),
1180 lda,
1181 W,
1182 reinterpret_cast<cuComplex*>(work),
1183 lwork,
1184 info));
1185 }
1186
1187 template <>
syevd(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,c10::complex<double> * A,int lda,double * W,c10::complex<double> * work,int lwork,int * info)1188 void syevd<c10::complex<double>, double>(
1189 cusolverDnHandle_t handle,
1190 cusolverEigMode_t jobz,
1191 cublasFillMode_t uplo,
1192 int n,
1193 c10::complex<double>* A,
1194 int lda,
1195 double* W,
1196 c10::complex<double>* work,
1197 int lwork,
1198 int* info) {
1199 TORCH_CUSOLVER_CHECK(cusolverDnZheevd(
1200 handle,
1201 jobz,
1202 uplo,
1203 n,
1204 reinterpret_cast<cuDoubleComplex*>(A),
1205 lda,
1206 W,
1207 reinterpret_cast<cuDoubleComplex*>(work),
1208 lwork,
1209 info));
1210 }
1211
1212 template <>
syevj_bufferSize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,const float * A,int lda,const float * W,int * lwork,syevjInfo_t params)1213 void syevj_bufferSize<float>(
1214 cusolverDnHandle_t handle,
1215 cusolverEigMode_t jobz,
1216 cublasFillMode_t uplo,
1217 int n,
1218 const float* A,
1219 int lda,
1220 const float* W,
1221 int* lwork,
1222 syevjInfo_t params) {
1223 TORCH_CUSOLVER_CHECK(cusolverDnSsyevj_bufferSize(
1224 handle, jobz, uplo, n, A, lda, W, lwork, params));
1225 }
1226
1227 template <>
syevj_bufferSize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,const double * A,int lda,const double * W,int * lwork,syevjInfo_t params)1228 void syevj_bufferSize<double>(
1229 cusolverDnHandle_t handle,
1230 cusolverEigMode_t jobz,
1231 cublasFillMode_t uplo,
1232 int n,
1233 const double* A,
1234 int lda,
1235 const double* W,
1236 int* lwork,
1237 syevjInfo_t params) {
1238 TORCH_CUSOLVER_CHECK(cusolverDnDsyevj_bufferSize(
1239 handle, jobz, uplo, n, A, lda, W, lwork, params));
1240 }
1241
1242 template <>
syevj_bufferSize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,const c10::complex<float> * A,int lda,const float * W,int * lwork,syevjInfo_t params)1243 void syevj_bufferSize<c10::complex<float>, float>(
1244 cusolverDnHandle_t handle,
1245 cusolverEigMode_t jobz,
1246 cublasFillMode_t uplo,
1247 int n,
1248 const c10::complex<float>* A,
1249 int lda,
1250 const float* W,
1251 int* lwork,
1252 syevjInfo_t params) {
1253 TORCH_CUSOLVER_CHECK(cusolverDnCheevj_bufferSize(
1254 handle,
1255 jobz,
1256 uplo,
1257 n,
1258 reinterpret_cast<const cuComplex*>(A),
1259 lda,
1260 W,
1261 lwork,
1262 params));
1263 }
1264
1265 template <>
syevj_bufferSize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,const c10::complex<double> * A,int lda,const double * W,int * lwork,syevjInfo_t params)1266 void syevj_bufferSize<c10::complex<double>, double>(
1267 cusolverDnHandle_t handle,
1268 cusolverEigMode_t jobz,
1269 cublasFillMode_t uplo,
1270 int n,
1271 const c10::complex<double>* A,
1272 int lda,
1273 const double* W,
1274 int* lwork,
1275 syevjInfo_t params) {
1276 TORCH_CUSOLVER_CHECK(cusolverDnZheevj_bufferSize(
1277 handle,
1278 jobz,
1279 uplo,
1280 n,
1281 reinterpret_cast<const cuDoubleComplex*>(A),
1282 lda,
1283 W,
1284 lwork,
1285 params));
1286 }
1287
1288 template <>
syevj(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,float * A,int lda,float * W,float * work,int lwork,int * info,syevjInfo_t params)1289 void syevj<float>(
1290 cusolverDnHandle_t handle,
1291 cusolverEigMode_t jobz,
1292 cublasFillMode_t uplo,
1293 int n,
1294 float* A,
1295 int lda,
1296 float* W,
1297 float* work,
1298 int lwork,
1299 int* info,
1300 syevjInfo_t params) {
1301 TORCH_CUSOLVER_CHECK(cusolverDnSsyevj(
1302 handle, jobz, uplo, n, A, lda, W, work, lwork, info, params));
1303 }
1304
1305 template <>
syevj(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,double * A,int lda,double * W,double * work,int lwork,int * info,syevjInfo_t params)1306 void syevj<double>(
1307 cusolverDnHandle_t handle,
1308 cusolverEigMode_t jobz,
1309 cublasFillMode_t uplo,
1310 int n,
1311 double* A,
1312 int lda,
1313 double* W,
1314 double* work,
1315 int lwork,
1316 int* info,
1317 syevjInfo_t params) {
1318 TORCH_CUSOLVER_CHECK(cusolverDnDsyevj(
1319 handle, jobz, uplo, n, A, lda, W, work, lwork, info, params));
1320 }
1321
1322 template <>
syevj(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,c10::complex<float> * A,int lda,float * W,c10::complex<float> * work,int lwork,int * info,syevjInfo_t params)1323 void syevj<c10::complex<float>, float>(
1324 cusolverDnHandle_t handle,
1325 cusolverEigMode_t jobz,
1326 cublasFillMode_t uplo,
1327 int n,
1328 c10::complex<float>* A,
1329 int lda,
1330 float* W,
1331 c10::complex<float>* work,
1332 int lwork,
1333 int* info,
1334 syevjInfo_t params) {
1335 TORCH_CUSOLVER_CHECK(cusolverDnCheevj(
1336 handle,
1337 jobz,
1338 uplo,
1339 n,
1340 reinterpret_cast<cuComplex*>(A),
1341 lda,
1342 W,
1343 reinterpret_cast<cuComplex*>(work),
1344 lwork,
1345 info,
1346 params));
1347 }
1348
1349 template <>
syevj(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,c10::complex<double> * A,int lda,double * W,c10::complex<double> * work,int lwork,int * info,syevjInfo_t params)1350 void syevj<c10::complex<double>, double>(
1351 cusolverDnHandle_t handle,
1352 cusolverEigMode_t jobz,
1353 cublasFillMode_t uplo,
1354 int n,
1355 c10::complex<double>* A,
1356 int lda,
1357 double* W,
1358 c10::complex<double>* work,
1359 int lwork,
1360 int* info,
1361 syevjInfo_t params) {
1362 TORCH_CUSOLVER_CHECK(cusolverDnZheevj(
1363 handle,
1364 jobz,
1365 uplo,
1366 n,
1367 reinterpret_cast<cuDoubleComplex*>(A),
1368 lda,
1369 W,
1370 reinterpret_cast<cuDoubleComplex*>(work),
1371 lwork,
1372 info,
1373 params));
1374 }
1375
1376 template <>
syevjBatched_bufferSize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,const float * A,int lda,const float * W,int * lwork,syevjInfo_t params,int batchsize)1377 void syevjBatched_bufferSize<float>(
1378 cusolverDnHandle_t handle,
1379 cusolverEigMode_t jobz,
1380 cublasFillMode_t uplo,
1381 int n,
1382 const float* A,
1383 int lda,
1384 const float* W,
1385 int* lwork,
1386 syevjInfo_t params,
1387 int batchsize) {
1388 TORCH_CUSOLVER_CHECK(cusolverDnSsyevjBatched_bufferSize(
1389 handle, jobz, uplo, n, A, lda, W, lwork, params, batchsize));
1390 }
1391
1392 template <>
syevjBatched_bufferSize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,const double * A,int lda,const double * W,int * lwork,syevjInfo_t params,int batchsize)1393 void syevjBatched_bufferSize<double>(
1394 cusolverDnHandle_t handle,
1395 cusolverEigMode_t jobz,
1396 cublasFillMode_t uplo,
1397 int n,
1398 const double* A,
1399 int lda,
1400 const double* W,
1401 int* lwork,
1402 syevjInfo_t params,
1403 int batchsize) {
1404 TORCH_CUSOLVER_CHECK(cusolverDnDsyevjBatched_bufferSize(
1405 handle, jobz, uplo, n, A, lda, W, lwork, params, batchsize));
1406 }
1407
1408 template <>
syevjBatched_bufferSize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,const c10::complex<float> * A,int lda,const float * W,int * lwork,syevjInfo_t params,int batchsize)1409 void syevjBatched_bufferSize<c10::complex<float>, float>(
1410 cusolverDnHandle_t handle,
1411 cusolverEigMode_t jobz,
1412 cublasFillMode_t uplo,
1413 int n,
1414 const c10::complex<float>* A,
1415 int lda,
1416 const float* W,
1417 int* lwork,
1418 syevjInfo_t params,
1419 int batchsize) {
1420 TORCH_CUSOLVER_CHECK(cusolverDnCheevjBatched_bufferSize(
1421 handle,
1422 jobz,
1423 uplo,
1424 n,
1425 reinterpret_cast<const cuComplex*>(A),
1426 lda,
1427 W,
1428 lwork,
1429 params,
1430 batchsize));
1431 }
1432
1433 template <>
syevjBatched_bufferSize(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,const c10::complex<double> * A,int lda,const double * W,int * lwork,syevjInfo_t params,int batchsize)1434 void syevjBatched_bufferSize<c10::complex<double>, double>(
1435 cusolverDnHandle_t handle,
1436 cusolverEigMode_t jobz,
1437 cublasFillMode_t uplo,
1438 int n,
1439 const c10::complex<double>* A,
1440 int lda,
1441 const double* W,
1442 int* lwork,
1443 syevjInfo_t params,
1444 int batchsize) {
1445 TORCH_CUSOLVER_CHECK(cusolverDnZheevjBatched_bufferSize(
1446 handle,
1447 jobz,
1448 uplo,
1449 n,
1450 reinterpret_cast<const cuDoubleComplex*>(A),
1451 lda,
1452 W,
1453 lwork,
1454 params,
1455 batchsize));
1456 }
1457
1458 template <>
syevjBatched(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,float * A,int lda,float * W,float * work,int lwork,int * info,syevjInfo_t params,int batchsize)1459 void syevjBatched<float>(
1460 cusolverDnHandle_t handle,
1461 cusolverEigMode_t jobz,
1462 cublasFillMode_t uplo,
1463 int n,
1464 float* A,
1465 int lda,
1466 float* W,
1467 float* work,
1468 int lwork,
1469 int* info,
1470 syevjInfo_t params,
1471 int batchsize) {
1472 TORCH_CUSOLVER_CHECK(cusolverDnSsyevjBatched(
1473 handle, jobz, uplo, n, A, lda, W, work, lwork, info, params, batchsize));
1474 }
1475
1476 template <>
syevjBatched(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,double * A,int lda,double * W,double * work,int lwork,int * info,syevjInfo_t params,int batchsize)1477 void syevjBatched<double>(
1478 cusolverDnHandle_t handle,
1479 cusolverEigMode_t jobz,
1480 cublasFillMode_t uplo,
1481 int n,
1482 double* A,
1483 int lda,
1484 double* W,
1485 double* work,
1486 int lwork,
1487 int* info,
1488 syevjInfo_t params,
1489 int batchsize) {
1490 TORCH_CUSOLVER_CHECK(cusolverDnDsyevjBatched(
1491 handle, jobz, uplo, n, A, lda, W, work, lwork, info, params, batchsize));
1492 }
1493
1494 template <>
syevjBatched(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,c10::complex<float> * A,int lda,float * W,c10::complex<float> * work,int lwork,int * info,syevjInfo_t params,int batchsize)1495 void syevjBatched<c10::complex<float>, float>(
1496 cusolverDnHandle_t handle,
1497 cusolverEigMode_t jobz,
1498 cublasFillMode_t uplo,
1499 int n,
1500 c10::complex<float>* A,
1501 int lda,
1502 float* W,
1503 c10::complex<float>* work,
1504 int lwork,
1505 int* info,
1506 syevjInfo_t params,
1507 int batchsize) {
1508 TORCH_CUSOLVER_CHECK(cusolverDnCheevjBatched(
1509 handle,
1510 jobz,
1511 uplo,
1512 n,
1513 reinterpret_cast<cuComplex*>(A),
1514 lda,
1515 W,
1516 reinterpret_cast<cuComplex*>(work),
1517 lwork,
1518 info,
1519 params,
1520 batchsize));
1521 }
1522
1523 template <>
syevjBatched(cusolverDnHandle_t handle,cusolverEigMode_t jobz,cublasFillMode_t uplo,int n,c10::complex<double> * A,int lda,double * W,c10::complex<double> * work,int lwork,int * info,syevjInfo_t params,int batchsize)1524 void syevjBatched<c10::complex<double>, double>(
1525 cusolverDnHandle_t handle,
1526 cusolverEigMode_t jobz,
1527 cublasFillMode_t uplo,
1528 int n,
1529 c10::complex<double>* A,
1530 int lda,
1531 double* W,
1532 c10::complex<double>* work,
1533 int lwork,
1534 int* info,
1535 syevjInfo_t params,
1536 int batchsize) {
1537 TORCH_CUSOLVER_CHECK(cusolverDnZheevjBatched(
1538 handle,
1539 jobz,
1540 uplo,
1541 n,
1542 reinterpret_cast<cuDoubleComplex*>(A),
1543 lda,
1544 W,
1545 reinterpret_cast<cuDoubleComplex*>(work),
1546 lwork,
1547 info,
1548 params,
1549 batchsize));
1550 }
1551
1552 #ifdef USE_CUSOLVER_64_BIT
1553
// Thin wrapper over cusolverDnXpotrs, compiled only when USE_CUSOLVER_64_BIT
// is defined. Every argument is forwarded unchanged and in the same order,
// and the returned status is checked with TORCH_CUSOLVER_CHECK.
//
// This wrapper is not templated. The caller supplies the cudaDataType tags
// for A (dataTypeA) and B (dataTypeB) explicitly. A arrives as const void*
// and B as void*.
// NOTE(review): per the cuSOLVER docs, potrs solves with a Cholesky factor
// produced by potrf and overwrites B with the solution -- confirm there.
void xpotrs(
    cusolverDnHandle_t handle, cusolverDnParams_t params, cublasFillMode_t uplo, int64_t n, int64_t nrhs, cudaDataType dataTypeA, const void *A,
    int64_t lda, cudaDataType dataTypeB, void *B, int64_t ldb, int *info) {
  TORCH_CUSOLVER_CHECK(cusolverDnXpotrs(handle, params, uplo, n, nrhs, dataTypeA, A, lda, dataTypeB, B, ldb, info));
}
1559
// xgeqrf_bufferSize specialization for float (USE_CUSOLVER_64_BIT builds).
// The parameter list comes from CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES, a
// macro defined outside this file. The names used below are declared by it.
//
// Calls cusolverDnXgeqrf_bufferSize with CUDA_R_32F as the type tag for A,
// for tau, and for the trailing type argument (cuSOLVER's computeType). The
// status is checked with TORCH_CUSOLVER_CHECK. A and tau are passed as
// const void*. workspaceInBytesOnDevice / workspaceInBytesOnHost are handed
// straight to cuSOLVER, which presumably stores the required workspace sizes
// through them -- confirm against the cuSOLVER docs.
template <>
void xgeqrf_bufferSize<float>(CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(float)) {
  TORCH_CUSOLVER_CHECK(cusolverDnXgeqrf_bufferSize(
      handle,
      params,
      m,
      n,
      CUDA_R_32F,
      reinterpret_cast<const void*>(A),
      lda,
      CUDA_R_32F,
      reinterpret_cast<const void*>(tau),
      CUDA_R_32F,
      workspaceInBytesOnDevice,
      workspaceInBytesOnHost));
}
1576
// xgeqrf_bufferSize specialization for double (USE_CUSOLVER_64_BIT builds).
// The parameter list comes from CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES, a
// macro defined outside this file. The names used below are declared by it.
//
// Calls cusolverDnXgeqrf_bufferSize with CUDA_R_64F as the type tag for A,
// for tau, and for the trailing type argument (cuSOLVER's computeType). The
// status is checked with TORCH_CUSOLVER_CHECK. A and tau are passed as
// const void*. The two workspace-size arguments are forwarded to cuSOLVER
// as-is (presumably outputs -- confirm against the cuSOLVER docs).
template <>
void xgeqrf_bufferSize<double>(CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(double)) {
  TORCH_CUSOLVER_CHECK(cusolverDnXgeqrf_bufferSize(
      handle,
      params,
      m,
      n,
      CUDA_R_64F,
      reinterpret_cast<const void*>(A),
      lda,
      CUDA_R_64F,
      reinterpret_cast<const void*>(tau),
      CUDA_R_64F,
      workspaceInBytesOnDevice,
      workspaceInBytesOnHost));
}
1593
// xgeqrf_bufferSize specialization for c10::complex<float>
// (USE_CUSOLVER_64_BIT builds). The parameter list comes from
// CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES, a macro defined outside this file.
//
// Calls cusolverDnXgeqrf_bufferSize with CUDA_C_32F as the type tag for A,
// for tau, and for the trailing type argument (cuSOLVER's computeType). tau
// is tagged complex here, like A. The status is checked with
// TORCH_CUSOLVER_CHECK. A and tau are passed as const void*. The two
// workspace-size arguments are forwarded to cuSOLVER as-is (presumably
// outputs -- confirm against the cuSOLVER docs).
template <>
void xgeqrf_bufferSize<c10::complex<float>>(
    CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(c10::complex<float>)) {
  TORCH_CUSOLVER_CHECK(cusolverDnXgeqrf_bufferSize(
      handle,
      params,
      m,
      n,
      CUDA_C_32F,
      reinterpret_cast<const void*>(A),
      lda,
      CUDA_C_32F,
      reinterpret_cast<const void*>(tau),
      CUDA_C_32F,
      workspaceInBytesOnDevice,
      workspaceInBytesOnHost));
}
1611
// xgeqrf_bufferSize specialization for c10::complex<double>
// (USE_CUSOLVER_64_BIT builds). The parameter list comes from
// CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES, a macro defined outside this file.
//
// Calls cusolverDnXgeqrf_bufferSize with CUDA_C_64F as the type tag for A,
// for tau, and for the trailing type argument (cuSOLVER's computeType). The
// status is checked with TORCH_CUSOLVER_CHECK. A and tau are passed as
// const void*. The two workspace-size arguments are forwarded to cuSOLVER
// as-is (presumably outputs -- confirm against the cuSOLVER docs).
template <>
void xgeqrf_bufferSize<c10::complex<double>>(
    CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(c10::complex<double>)) {
  TORCH_CUSOLVER_CHECK(cusolverDnXgeqrf_bufferSize(
      handle,
      params,
      m,
      n,
      CUDA_C_64F,
      reinterpret_cast<const void*>(A),
      lda,
      CUDA_C_64F,
      reinterpret_cast<const void*>(tau),
      CUDA_C_64F,
      workspaceInBytesOnDevice,
      workspaceInBytesOnHost));
}
1629
// xgeqrf specialization for float (USE_CUSOLVER_64_BIT builds). The
// parameter list comes from CUDASOLVER_XGEQRF_ARGTYPES, a macro defined
// outside this file.
//
// Forwards to cusolverDnXgeqrf with CUDA_R_32F as the type tag for A, for
// tau, and for the trailing type argument (cuSOLVER's computeType). A, tau
// and both scratch buffers are reinterpret_cast to void*. Their byte sizes
// and info are passed through unchanged. The status is checked with
// TORCH_CUSOLVER_CHECK.
// NOTE(review): the buffers are presumably sized by a prior
// xgeqrf_bufferSize<float> call -- confirm at the call sites.
template <>
void xgeqrf<float>(CUDASOLVER_XGEQRF_ARGTYPES(float)) {
  TORCH_CUSOLVER_CHECK(cusolverDnXgeqrf(
      handle,
      params,
      m,
      n,
      CUDA_R_32F,
      reinterpret_cast<void*>(A),
      lda,
      CUDA_R_32F,
      reinterpret_cast<void*>(tau),
      CUDA_R_32F,
      reinterpret_cast<void*>(bufferOnDevice),
      workspaceInBytesOnDevice,
      reinterpret_cast<void*>(bufferOnHost),
      workspaceInBytesOnHost,
      info));
}
1649
// xgeqrf specialization for double (USE_CUSOLVER_64_BIT builds). The
// parameter list comes from CUDASOLVER_XGEQRF_ARGTYPES, a macro defined
// outside this file.
//
// Forwards to cusolverDnXgeqrf with CUDA_R_64F as the type tag for A, for
// tau, and for the trailing type argument (cuSOLVER's computeType). A, tau
// and both scratch buffers are reinterpret_cast to void*. Their byte sizes
// and info are passed through unchanged. The status is checked with
// TORCH_CUSOLVER_CHECK.
// NOTE(review): the buffers are presumably sized by a prior
// xgeqrf_bufferSize<double> call -- confirm at the call sites.
template <>
void xgeqrf<double>(CUDASOLVER_XGEQRF_ARGTYPES(double)) {
  TORCH_CUSOLVER_CHECK(cusolverDnXgeqrf(
      handle,
      params,
      m,
      n,
      CUDA_R_64F,
      reinterpret_cast<void*>(A),
      lda,
      CUDA_R_64F,
      reinterpret_cast<void*>(tau),
      CUDA_R_64F,
      reinterpret_cast<void*>(bufferOnDevice),
      workspaceInBytesOnDevice,
      reinterpret_cast<void*>(bufferOnHost),
      workspaceInBytesOnHost,
      info));
}
1669
// xgeqrf specialization for c10::complex<float> (USE_CUSOLVER_64_BIT
// builds). The parameter list comes from CUDASOLVER_XGEQRF_ARGTYPES, a macro
// defined outside this file.
//
// Forwards to cusolverDnXgeqrf with CUDA_C_32F as the type tag for A, for
// tau, and for the trailing type argument (cuSOLVER's computeType). A, tau
// and both scratch buffers are reinterpret_cast to void*. Their byte sizes
// and info are passed through unchanged. The status is checked with
// TORCH_CUSOLVER_CHECK.
// NOTE(review): the buffers are presumably sized by a prior
// xgeqrf_bufferSize<c10::complex<float>> call -- confirm at the call sites.
template <>
void xgeqrf<c10::complex<float>>(CUDASOLVER_XGEQRF_ARGTYPES(c10::complex<float>)) {
  TORCH_CUSOLVER_CHECK(cusolverDnXgeqrf(
      handle,
      params,
      m,
      n,
      CUDA_C_32F,
      reinterpret_cast<void*>(A),
      lda,
      CUDA_C_32F,
      reinterpret_cast<void*>(tau),
      CUDA_C_32F,
      reinterpret_cast<void*>(bufferOnDevice),
      workspaceInBytesOnDevice,
      reinterpret_cast<void*>(bufferOnHost),
      workspaceInBytesOnHost,
      info));
}
1689
// xgeqrf specialization for c10::complex<double> (USE_CUSOLVER_64_BIT
// builds). The parameter list comes from CUDASOLVER_XGEQRF_ARGTYPES, a macro
// defined outside this file.
//
// Forwards to cusolverDnXgeqrf with CUDA_C_64F as the type tag for A, for
// tau, and for the trailing type argument (cuSOLVER's computeType). A, tau
// and both scratch buffers are reinterpret_cast to void*. Their byte sizes
// and info are passed through unchanged. The status is checked with
// TORCH_CUSOLVER_CHECK.
// NOTE(review): the buffers are presumably sized by a prior
// xgeqrf_bufferSize<c10::complex<double>> call -- confirm at the call sites.
template <>
void xgeqrf<c10::complex<double>>(CUDASOLVER_XGEQRF_ARGTYPES(c10::complex<double>)) {
  TORCH_CUSOLVER_CHECK(cusolverDnXgeqrf(
      handle,
      params,
      m,
      n,
      CUDA_C_64F,
      reinterpret_cast<void*>(A),
      lda,
      CUDA_C_64F,
      reinterpret_cast<void*>(tau),
      CUDA_C_64F,
      reinterpret_cast<void*>(bufferOnDevice),
      workspaceInBytesOnDevice,
      reinterpret_cast<void*>(bufferOnHost),
      workspaceInBytesOnHost,
      info));
}
1709
// xsyevd_bufferSize specialization for float (USE_CUSOLVER_64_BIT builds).
//
// Calls cusolverDnXsyevd_bufferSize with CUDA_R_32F as the type tag for A,
// for W, and for the trailing type argument (cuSOLVER's computeType). The
// status is checked with TORCH_CUSOLVER_CHECK. A and W are passed as
// const void*. workspaceInBytesOnDevice / workspaceInBytesOnHost are handed
// straight to cuSOLVER, which presumably stores the required workspace sizes
// through them -- confirm against the cuSOLVER docs.
// NOTE(review): keep the call's tokens as-is; TORCH_CUSOLVER_CHECK (defined
// elsewhere) may echo its argument text in its error message.
template <>
void xsyevd_bufferSize<float>(
    cusolverDnHandle_t handle,
    cusolverDnParams_t params,
    cusolverEigMode_t jobz,
    cublasFillMode_t uplo,
    int64_t n,
    const float* A,
    int64_t lda,
    const float* W,
    size_t* workspaceInBytesOnDevice,
    size_t* workspaceInBytesOnHost) {
  TORCH_CUSOLVER_CHECK(cusolverDnXsyevd_bufferSize(
      handle,
      params,
      jobz,
      uplo,
      n,
      CUDA_R_32F,
      reinterpret_cast<const void*>(A),
      lda,
      CUDA_R_32F,
      reinterpret_cast<const void*>(W),
      CUDA_R_32F,
      workspaceInBytesOnDevice,
      workspaceInBytesOnHost));
}
1737
// xsyevd_bufferSize specialization for double (USE_CUSOLVER_64_BIT builds).
//
// Calls cusolverDnXsyevd_bufferSize with CUDA_R_64F as the type tag for A,
// for W, and for the trailing type argument (cuSOLVER's computeType). The
// status is checked with TORCH_CUSOLVER_CHECK. A and W are passed as
// const void*. The two size_t* arguments are forwarded to cuSOLVER as-is
// (presumably outputs for the required workspace sizes -- confirm against
// the cuSOLVER docs).
template <>
void xsyevd_bufferSize<double>(
    cusolverDnHandle_t handle,
    cusolverDnParams_t params,
    cusolverEigMode_t jobz,
    cublasFillMode_t uplo,
    int64_t n,
    const double* A,
    int64_t lda,
    const double* W,
    size_t* workspaceInBytesOnDevice,
    size_t* workspaceInBytesOnHost) {
  TORCH_CUSOLVER_CHECK(cusolverDnXsyevd_bufferSize(
      handle,
      params,
      jobz,
      uplo,
      n,
      CUDA_R_64F,
      reinterpret_cast<const void*>(A),
      lda,
      CUDA_R_64F,
      reinterpret_cast<const void*>(W),
      CUDA_R_64F,
      workspaceInBytesOnDevice,
      workspaceInBytesOnHost));
}
1765
// xsyevd_bufferSize specialization for c10::complex<float> matrices with a
// real float W (USE_CUSOLVER_64_BIT builds).
//
// The type tags are deliberately mixed:
//   - A is tagged CUDA_C_32F.
//   - W is tagged CUDA_R_32F, matching its float* parameter type.
//   - The trailing type argument (cuSOLVER's computeType) is CUDA_C_32F,
//     i.e. it follows A's type, not W's.
// A and W are passed as const void*. The two size_t* arguments are forwarded
// to cuSOLVER as-is (presumably outputs for the required workspace sizes --
// confirm against the cuSOLVER docs). The status is checked with
// TORCH_CUSOLVER_CHECK.
template <>
void xsyevd_bufferSize<c10::complex<float>, float>(
    cusolverDnHandle_t handle,
    cusolverDnParams_t params,
    cusolverEigMode_t jobz,
    cublasFillMode_t uplo,
    int64_t n,
    const c10::complex<float>* A,
    int64_t lda,
    const float* W,
    size_t* workspaceInBytesOnDevice,
    size_t* workspaceInBytesOnHost) {
  TORCH_CUSOLVER_CHECK(cusolverDnXsyevd_bufferSize(
      handle,
      params,
      jobz,
      uplo,
      n,
      CUDA_C_32F,
      reinterpret_cast<const void*>(A),
      lda,
      CUDA_R_32F,
      reinterpret_cast<const void*>(W),
      CUDA_C_32F,
      workspaceInBytesOnDevice,
      workspaceInBytesOnHost));
}
1793
// xsyevd_bufferSize specialization for c10::complex<double> matrices with a
// real double W (USE_CUSOLVER_64_BIT builds).
//
// The type tags are deliberately mixed:
//   - A is tagged CUDA_C_64F.
//   - W is tagged CUDA_R_64F, matching its double* parameter type.
//   - The trailing type argument (cuSOLVER's computeType) is CUDA_C_64F,
//     i.e. it follows A's type, not W's.
// A and W are passed as const void*. The two size_t* arguments are forwarded
// to cuSOLVER as-is (presumably outputs for the required workspace sizes --
// confirm against the cuSOLVER docs). The status is checked with
// TORCH_CUSOLVER_CHECK.
template <>
void xsyevd_bufferSize<c10::complex<double>, double>(
    cusolverDnHandle_t handle,
    cusolverDnParams_t params,
    cusolverEigMode_t jobz,
    cublasFillMode_t uplo,
    int64_t n,
    const c10::complex<double>* A,
    int64_t lda,
    const double* W,
    size_t* workspaceInBytesOnDevice,
    size_t* workspaceInBytesOnHost) {
  TORCH_CUSOLVER_CHECK(cusolverDnXsyevd_bufferSize(
      handle,
      params,
      jobz,
      uplo,
      n,
      CUDA_C_64F,
      reinterpret_cast<const void*>(A),
      lda,
      CUDA_R_64F,
      reinterpret_cast<const void*>(W),
      CUDA_C_64F,
      workspaceInBytesOnDevice,
      workspaceInBytesOnHost));
}
1821
// xsyevd specialization for float (USE_CUSOLVER_64_BIT builds).
//
// Forwards to cusolverDnXsyevd with CUDA_R_32F as the type tag for A, for W,
// and for the trailing type argument (cuSOLVER's computeType). A, W and the
// two scratch buffers are reinterpret_cast to void*. The buffers are typed
// float* at this interface, but their sizes are passed as
// workspaceInBytesOnDevice / workspaceInBytesOnHost, which by name are byte
// counts rather than element counts. info is forwarded unchanged. The status
// is checked with TORCH_CUSOLVER_CHECK.
// NOTE(review): the buffers are presumably sized by a prior
// xsyevd_bufferSize<float> call -- confirm at the call sites.
template <>
void xsyevd<float>(
    cusolverDnHandle_t handle,
    cusolverDnParams_t params,
    cusolverEigMode_t jobz,
    cublasFillMode_t uplo,
    int64_t n,
    float* A,
    int64_t lda,
    float* W,
    float* bufferOnDevice,
    size_t workspaceInBytesOnDevice,
    float* bufferOnHost,
    size_t workspaceInBytesOnHost,
    int* info) {
  TORCH_CUSOLVER_CHECK(cusolverDnXsyevd(
      handle,
      params,
      jobz,
      uplo,
      n,
      CUDA_R_32F,
      reinterpret_cast<void*>(A),
      lda,
      CUDA_R_32F,
      reinterpret_cast<void*>(W),
      CUDA_R_32F,
      reinterpret_cast<void*>(bufferOnDevice),
      workspaceInBytesOnDevice,
      reinterpret_cast<void*>(bufferOnHost),
      workspaceInBytesOnHost,
      info));
}
1855
// xsyevd specialization for double (USE_CUSOLVER_64_BIT builds).
//
// Forwards to cusolverDnXsyevd with CUDA_R_64F as the type tag for A, for W,
// and for the trailing type argument (cuSOLVER's computeType). A, W and the
// two scratch buffers are reinterpret_cast to void*. The buffers are typed
// double* at this interface, but their sizes are passed as
// workspaceInBytesOnDevice / workspaceInBytesOnHost, which by name are byte
// counts rather than element counts. info is forwarded unchanged. The status
// is checked with TORCH_CUSOLVER_CHECK.
// NOTE(review): the buffers are presumably sized by a prior
// xsyevd_bufferSize<double> call -- confirm at the call sites.
template <>
void xsyevd<double>(
    cusolverDnHandle_t handle,
    cusolverDnParams_t params,
    cusolverEigMode_t jobz,
    cublasFillMode_t uplo,
    int64_t n,
    double* A,
    int64_t lda,
    double* W,
    double* bufferOnDevice,
    size_t workspaceInBytesOnDevice,
    double* bufferOnHost,
    size_t workspaceInBytesOnHost,
    int* info) {
  TORCH_CUSOLVER_CHECK(cusolverDnXsyevd(
      handle,
      params,
      jobz,
      uplo,
      n,
      CUDA_R_64F,
      reinterpret_cast<void*>(A),
      lda,
      CUDA_R_64F,
      reinterpret_cast<void*>(W),
      CUDA_R_64F,
      reinterpret_cast<void*>(bufferOnDevice),
      workspaceInBytesOnDevice,
      reinterpret_cast<void*>(bufferOnHost),
      workspaceInBytesOnHost,
      info));
}
1889
// xsyevd specialization for c10::complex<float> matrices with a real float W
// (USE_CUSOLVER_64_BIT builds).
//
// The type tags are deliberately mixed:
//   - A is tagged CUDA_C_32F.
//   - W is tagged CUDA_R_32F, matching its float* parameter type.
//   - The trailing type argument (cuSOLVER's computeType) is CUDA_C_32F,
//     following A's type.
// A, W and the two scratch buffers are reinterpret_cast to void*. The
// buffers are typed c10::complex<float>* here, but their sizes are passed as
// workspaceInBytesOnDevice / workspaceInBytesOnHost, which by name are byte
// counts. info is forwarded unchanged. The status is checked with
// TORCH_CUSOLVER_CHECK.
// NOTE(review): the buffers are presumably sized by a prior
// xsyevd_bufferSize<c10::complex<float>, float> call -- confirm at the call
// sites.
template <>
void xsyevd<c10::complex<float>, float>(
    cusolverDnHandle_t handle,
    cusolverDnParams_t params,
    cusolverEigMode_t jobz,
    cublasFillMode_t uplo,
    int64_t n,
    c10::complex<float>* A,
    int64_t lda,
    float* W,
    c10::complex<float>* bufferOnDevice,
    size_t workspaceInBytesOnDevice,
    c10::complex<float>* bufferOnHost,
    size_t workspaceInBytesOnHost,
    int* info) {
  TORCH_CUSOLVER_CHECK(cusolverDnXsyevd(
      handle,
      params,
      jobz,
      uplo,
      n,
      CUDA_C_32F,
      reinterpret_cast<void*>(A),
      lda,
      CUDA_R_32F,
      reinterpret_cast<void*>(W),
      CUDA_C_32F,
      reinterpret_cast<void*>(bufferOnDevice),
      workspaceInBytesOnDevice,
      reinterpret_cast<void*>(bufferOnHost),
      workspaceInBytesOnHost,
      info));
}
1923
// xsyevd specialization for c10::complex<double> matrices with a real double
// W (USE_CUSOLVER_64_BIT builds).
//
// The type tags are deliberately mixed:
//   - A is tagged CUDA_C_64F.
//   - W is tagged CUDA_R_64F, matching its double* parameter type.
//   - The trailing type argument (cuSOLVER's computeType) is CUDA_C_64F,
//     following A's type.
// A, W and the two scratch buffers are reinterpret_cast to void*. The
// buffers are typed c10::complex<double>* here, but their sizes are passed
// as workspaceInBytesOnDevice / workspaceInBytesOnHost, which by name are
// byte counts. info is forwarded unchanged. The status is checked with
// TORCH_CUSOLVER_CHECK.
// NOTE(review): keep the call's tokens as-is; TORCH_CUSOLVER_CHECK (defined
// elsewhere) may echo its argument text in its error message.
template <>
void xsyevd<c10::complex<double>, double>(
    cusolverDnHandle_t handle,
    cusolverDnParams_t params,
    cusolverEigMode_t jobz,
    cublasFillMode_t uplo,
    int64_t n,
    c10::complex<double>* A,
    int64_t lda,
    double* W,
    c10::complex<double>* bufferOnDevice,
    size_t workspaceInBytesOnDevice,
    c10::complex<double>* bufferOnHost,
    size_t workspaceInBytesOnHost,
    int* info) {
  TORCH_CUSOLVER_CHECK(cusolverDnXsyevd(
      handle,
      params,
      jobz,
      uplo,
      n,
      CUDA_C_64F,
      reinterpret_cast<void*>(A),
      lda,
      CUDA_R_64F,
      reinterpret_cast<void*>(W),
      CUDA_C_64F,
      reinterpret_cast<void*>(bufferOnDevice),
      workspaceInBytesOnDevice,
      reinterpret_cast<void*>(bufferOnHost),
      workspaceInBytesOnHost,
      info));
}
1957 #endif // USE_CUSOLVER_64_BIT
1958
1959 } // namespace at::cuda::solver
1960
1961 #endif // CUDART_VERSION
1962