1 #pragma once
2
3 #include <ATen/cuda/CUDAContext.h>
4
// Feature switch for the 64-bit cuSOLVER API, which is available starting with
// CUSOLVER_VERSION 11000. Only defined for CUDA builds that expose
// CUSOLVER_VERSION; it gates the declarations further down that rely on it.
#if defined(CUDART_VERSION) && defined(CUSOLVER_VERSION)
#if CUSOLVER_VERSION >= 11000
#define USE_CUSOLVER_64_BIT
#endif
#endif
9
10 #if defined(CUDART_VERSION) || defined(USE_ROCM)
11
12 namespace at {
13 namespace cuda {
14 namespace solver {
15
16 #define CUDASOLVER_GETRF_ARGTYPES(Dtype) \
17 cusolverDnHandle_t handle, int m, int n, Dtype* dA, int ldda, int* ipiv, int* info
18
19 template<class Dtype>
getrf(CUDASOLVER_GETRF_ARGTYPES (Dtype))20 void getrf(CUDASOLVER_GETRF_ARGTYPES(Dtype)) {
21 static_assert(false&&sizeof(Dtype), "at::cuda::solver::getrf: not implemented");
22 }
23 template<>
24 void getrf<float>(CUDASOLVER_GETRF_ARGTYPES(float));
25 template<>
26 void getrf<double>(CUDASOLVER_GETRF_ARGTYPES(double));
27 template<>
28 void getrf<c10::complex<double>>(CUDASOLVER_GETRF_ARGTYPES(c10::complex<double>));
29 template<>
30 void getrf<c10::complex<float>>(CUDASOLVER_GETRF_ARGTYPES(c10::complex<float>));
31
32
33 #define CUDASOLVER_GETRS_ARGTYPES(Dtype) \
34 cusolverDnHandle_t handle, int n, int nrhs, Dtype* dA, int lda, int* ipiv, Dtype* ret, int ldb, int* info, cublasOperation_t trans
35
36 template<class Dtype>
getrs(CUDASOLVER_GETRS_ARGTYPES (Dtype))37 void getrs(CUDASOLVER_GETRS_ARGTYPES(Dtype)) {
38 static_assert(false&&sizeof(Dtype), "at::cuda::solver::getrs: not implemented");
39 }
40 template<>
41 void getrs<float>(CUDASOLVER_GETRS_ARGTYPES(float));
42 template<>
43 void getrs<double>(CUDASOLVER_GETRS_ARGTYPES(double));
44 template<>
45 void getrs<c10::complex<double>>(CUDASOLVER_GETRS_ARGTYPES(c10::complex<double>));
46 template<>
47 void getrs<c10::complex<float>>(CUDASOLVER_GETRS_ARGTYPES(c10::complex<float>));
48
49 #define CUDASOLVER_SYTRF_BUFFER_ARGTYPES(Dtype) \
50 cusolverDnHandle_t handle, int n, Dtype *A, int lda, int *lwork
51
52 template <class Dtype>
sytrf_bufferSize(CUDASOLVER_SYTRF_BUFFER_ARGTYPES (Dtype))53 void sytrf_bufferSize(CUDASOLVER_SYTRF_BUFFER_ARGTYPES(Dtype)) {
54 static_assert(false&&sizeof(Dtype),
55 "at::cuda::solver::sytrf_bufferSize: not implemented");
56 }
57 template <>
58 void sytrf_bufferSize<float>(CUDASOLVER_SYTRF_BUFFER_ARGTYPES(float));
59 template <>
60 void sytrf_bufferSize<double>(CUDASOLVER_SYTRF_BUFFER_ARGTYPES(double));
61 template <>
62 void sytrf_bufferSize<c10::complex<double>>(
63 CUDASOLVER_SYTRF_BUFFER_ARGTYPES(c10::complex<double>));
64 template <>
65 void sytrf_bufferSize<c10::complex<float>>(
66 CUDASOLVER_SYTRF_BUFFER_ARGTYPES(c10::complex<float>));
67
68 #define CUDASOLVER_SYTRF_ARGTYPES(Dtype) \
69 cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, Dtype *A, int lda, \
70 int *ipiv, Dtype *work, int lwork, int *devInfo
71
72 template <class Dtype>
sytrf(CUDASOLVER_SYTRF_ARGTYPES (Dtype))73 void sytrf(CUDASOLVER_SYTRF_ARGTYPES(Dtype)) {
74 static_assert(false&&sizeof(Dtype),
75 "at::cuda::solver::sytrf: not implemented");
76 }
77 template <>
78 void sytrf<float>(CUDASOLVER_SYTRF_ARGTYPES(float));
79 template <>
80 void sytrf<double>(CUDASOLVER_SYTRF_ARGTYPES(double));
81 template <>
82 void sytrf<c10::complex<double>>(
83 CUDASOLVER_SYTRF_ARGTYPES(c10::complex<double>));
84 template <>
85 void sytrf<c10::complex<float>>(CUDASOLVER_SYTRF_ARGTYPES(c10::complex<float>));
86
87 #define CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES() \
88 cusolverDnHandle_t handle, int m, int n, int *lwork
89
90 template<class Dtype>
gesvd_buffersize(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES ())91 void gesvd_buffersize(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES()) {
92 static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvd_buffersize: not implemented");
93 }
94 template<>
95 void gesvd_buffersize<float>(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES());
96 template<>
97 void gesvd_buffersize<double>(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES());
98 template<>
99 void gesvd_buffersize<c10::complex<float>>(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES());
100 template<>
101 void gesvd_buffersize<c10::complex<double>>(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES());
102
103
104 #define CUDASOLVER_GESVD_ARGTYPES(Dtype, Vtype) \
105 cusolverDnHandle_t handle, signed char jobu, signed char jobvt, int m, int n, Dtype *A, int lda, \
106 Vtype *S, Dtype *U, int ldu, Dtype *VT, int ldvt, Dtype *work, int lwork, Vtype *rwork, int *info
107
108 template<class Dtype, class Vtype>
gesvd(CUDASOLVER_GESVD_ARGTYPES (Dtype,Vtype))109 void gesvd(CUDASOLVER_GESVD_ARGTYPES(Dtype, Vtype)) {
110 static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvd: not implemented");
111 }
112 template<>
113 void gesvd<float>(CUDASOLVER_GESVD_ARGTYPES(float, float));
114 template<>
115 void gesvd<double>(CUDASOLVER_GESVD_ARGTYPES(double, double));
116 template<>
117 void gesvd<c10::complex<float>>(CUDASOLVER_GESVD_ARGTYPES(c10::complex<float>, float));
118 template<>
119 void gesvd<c10::complex<double>>(CUDASOLVER_GESVD_ARGTYPES(c10::complex<double>, double));
120
121
122 #define CUDASOLVER_GESVDJ_BUFFERSIZE_ARGTYPES(Dtype, Vtype) \
123 cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, Dtype *A, int lda, Vtype *S, \
124 Dtype *U, int ldu, Dtype *V, int ldv, int *lwork, gesvdjInfo_t params
125
126 template<class Dtype, class Vtype>
gesvdj_buffersize(CUDASOLVER_GESVDJ_BUFFERSIZE_ARGTYPES (Dtype,Vtype))127 void gesvdj_buffersize(CUDASOLVER_GESVDJ_BUFFERSIZE_ARGTYPES(Dtype, Vtype)) {
128 static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvdj_buffersize: not implemented");
129 }
130 template<>
131 void gesvdj_buffersize<float>(CUDASOLVER_GESVDJ_BUFFERSIZE_ARGTYPES(float, float));
132 template<>
133 void gesvdj_buffersize<double>(CUDASOLVER_GESVDJ_BUFFERSIZE_ARGTYPES(double, double));
134 template<>
135 void gesvdj_buffersize<c10::complex<float>>(CUDASOLVER_GESVDJ_BUFFERSIZE_ARGTYPES(c10::complex<float>, float));
136 template<>
137 void gesvdj_buffersize<c10::complex<double>>(CUDASOLVER_GESVDJ_BUFFERSIZE_ARGTYPES(c10::complex<double>, double));
138
139
140 #define CUDASOLVER_GESVDJ_ARGTYPES(Dtype, Vtype) \
141 cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, Dtype* A, int lda, Vtype* S, Dtype* U, \
142 int ldu, Dtype* V, int ldv, Dtype* work, int lwork, int *info, gesvdjInfo_t params
143
144 template<class Dtype, class Vtype>
gesvdj(CUDASOLVER_GESVDJ_ARGTYPES (Dtype,Vtype))145 void gesvdj(CUDASOLVER_GESVDJ_ARGTYPES(Dtype, Vtype)) {
146 static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvdj: not implemented");
147 }
148 template<>
149 void gesvdj<float>(CUDASOLVER_GESVDJ_ARGTYPES(float, float));
150 template<>
151 void gesvdj<double>(CUDASOLVER_GESVDJ_ARGTYPES(double, double));
152 template<>
153 void gesvdj<c10::complex<float>>(CUDASOLVER_GESVDJ_ARGTYPES(c10::complex<float>, float));
154 template<>
155 void gesvdj<c10::complex<double>>(CUDASOLVER_GESVDJ_ARGTYPES(c10::complex<double>, double));
156
157
158 #define CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(Dtype, Vtype) \
159 cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, Dtype* A, int lda, Vtype* S, Dtype* U, \
160 int ldu, Dtype *V, int ldv, int *info, gesvdjInfo_t params, int batchSize
161
162 template<class Dtype, class Vtype>
gesvdjBatched(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES (Dtype,Vtype))163 void gesvdjBatched(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(Dtype, Vtype)) {
164 static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvdj: not implemented");
165 }
166 template<>
167 void gesvdjBatched<float>(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(float, float));
168 template<>
169 void gesvdjBatched<double>(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(double, double));
170 template<>
171 void gesvdjBatched<c10::complex<float>>(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(c10::complex<float>, float));
172 template<>
173 void gesvdjBatched<c10::complex<double>>(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(c10::complex<double>, double));
174
175 #define CUDASOLVER_GESVDA_STRIDED_BATCHED_BUFFERSIZE_ARGTYPES(Dtype, Vtype) \
176 cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, Dtype *A, int lda, long long int strideA, \
177 Vtype *S, long long int strideS, Dtype *U, int ldu, long long int strideU, Dtype *V, int ldv, long long int strideV, \
178 int *lwork, int batchSize
179
180 template<class Dtype, class Vtype>
gesvdaStridedBatched_buffersize(CUDASOLVER_GESVDA_STRIDED_BATCHED_BUFFERSIZE_ARGTYPES (Dtype,Vtype))181 void gesvdaStridedBatched_buffersize(CUDASOLVER_GESVDA_STRIDED_BATCHED_BUFFERSIZE_ARGTYPES(Dtype, Vtype)) {
182 static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvdaStridedBatched_buffersize: not implemented");
183 }
184 template<>
185 void gesvdaStridedBatched_buffersize<float>(CUDASOLVER_GESVDA_STRIDED_BATCHED_BUFFERSIZE_ARGTYPES(float, float));
186 template<>
187 void gesvdaStridedBatched_buffersize<double>(CUDASOLVER_GESVDA_STRIDED_BATCHED_BUFFERSIZE_ARGTYPES(double, double));
188 template<>
189 void gesvdaStridedBatched_buffersize<c10::complex<float>>(CUDASOLVER_GESVDA_STRIDED_BATCHED_BUFFERSIZE_ARGTYPES(c10::complex<float>, float));
190 template<>
191 void gesvdaStridedBatched_buffersize<c10::complex<double>>(CUDASOLVER_GESVDA_STRIDED_BATCHED_BUFFERSIZE_ARGTYPES(c10::complex<double>, double));
192
193
194 #define CUDASOLVER_GESVDA_STRIDED_BATCHED_ARGTYPES(Dtype, Vtype) \
195 cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, Dtype *A, int lda, long long int strideA, \
196 Vtype *S, long long int strideS, Dtype *U, int ldu, long long int strideU, Dtype *V, int ldv, long long int strideV, \
197 Dtype *work, int lwork, int *info, double *h_R_nrmF, int batchSize
198 // h_R_nrmF is always double, regardless of input Dtype.
199
200 template<class Dtype, class Vtype>
gesvdaStridedBatched(CUDASOLVER_GESVDA_STRIDED_BATCHED_ARGTYPES (Dtype,Vtype))201 void gesvdaStridedBatched(CUDASOLVER_GESVDA_STRIDED_BATCHED_ARGTYPES(Dtype, Vtype)) {
202 static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvdaStridedBatched: not implemented");
203 }
204 template<>
205 void gesvdaStridedBatched<float>(CUDASOLVER_GESVDA_STRIDED_BATCHED_ARGTYPES(float, float));
206 template<>
207 void gesvdaStridedBatched<double>(CUDASOLVER_GESVDA_STRIDED_BATCHED_ARGTYPES(double, double));
208 template<>
209 void gesvdaStridedBatched<c10::complex<float>>(CUDASOLVER_GESVDA_STRIDED_BATCHED_ARGTYPES(c10::complex<float>, float));
210 template<>
211 void gesvdaStridedBatched<c10::complex<double>>(CUDASOLVER_GESVDA_STRIDED_BATCHED_ARGTYPES(c10::complex<double>, double));
212
213
214 #define CUDASOLVER_POTRF_ARGTYPES(Dtype) \
215 cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, Dtype* A, int lda, Dtype* work, int lwork, int* info
216
217 template<class Dtype>
potrf(CUDASOLVER_POTRF_ARGTYPES (Dtype))218 void potrf(CUDASOLVER_POTRF_ARGTYPES(Dtype)) {
219 static_assert(false&&sizeof(Dtype), "at::cuda::solver::potrf: not implemented");
220 }
221 template<>
222 void potrf<float>(CUDASOLVER_POTRF_ARGTYPES(float));
223 template<>
224 void potrf<double>(CUDASOLVER_POTRF_ARGTYPES(double));
225 template<>
226 void potrf<c10::complex<float>>(CUDASOLVER_POTRF_ARGTYPES(c10::complex<float>));
227 template<>
228 void potrf<c10::complex<double>>(CUDASOLVER_POTRF_ARGTYPES(c10::complex<double>));
229
230
231 #define CUDASOLVER_POTRF_BUFFERSIZE_ARGTYPES(Dtype) \
232 cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, Dtype* A, int lda, int* lwork
233
234 template<class Dtype>
potrf_buffersize(CUDASOLVER_POTRF_BUFFERSIZE_ARGTYPES (Dtype))235 void potrf_buffersize(CUDASOLVER_POTRF_BUFFERSIZE_ARGTYPES(Dtype)) {
236 static_assert(false&&sizeof(Dtype), "at::cuda::solver::potrf_buffersize: not implemented");
237 }
238 template<>
239 void potrf_buffersize<float>(CUDASOLVER_POTRF_BUFFERSIZE_ARGTYPES(float));
240 template<>
241 void potrf_buffersize<double>(CUDASOLVER_POTRF_BUFFERSIZE_ARGTYPES(double));
242 template<>
243 void potrf_buffersize<c10::complex<float>>(CUDASOLVER_POTRF_BUFFERSIZE_ARGTYPES(c10::complex<float>));
244 template<>
245 void potrf_buffersize<c10::complex<double>>(CUDASOLVER_POTRF_BUFFERSIZE_ARGTYPES(c10::complex<double>));
246
247
248 #define CUDASOLVER_POTRF_BATCHED_ARGTYPES(Dtype) \
249 cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, Dtype** A, int lda, int* info, int batchSize
250
251 template<class Dtype>
potrfBatched(CUDASOLVER_POTRF_BATCHED_ARGTYPES (Dtype))252 void potrfBatched(CUDASOLVER_POTRF_BATCHED_ARGTYPES(Dtype)) {
253 static_assert(false&&sizeof(Dtype), "at::cuda::solver::potrfBatched: not implemented");
254 }
255 template<>
256 void potrfBatched<float>(CUDASOLVER_POTRF_BATCHED_ARGTYPES(float));
257 template<>
258 void potrfBatched<double>(CUDASOLVER_POTRF_BATCHED_ARGTYPES(double));
259 template<>
260 void potrfBatched<c10::complex<float>>(CUDASOLVER_POTRF_BATCHED_ARGTYPES(c10::complex<float>));
261 template<>
262 void potrfBatched<c10::complex<double>>(CUDASOLVER_POTRF_BATCHED_ARGTYPES(c10::complex<double>));
263
264 #define CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(scalar_t) \
265 cusolverDnHandle_t handle, int m, int n, scalar_t *A, int lda, int *lwork
266
267 template <class scalar_t>
geqrf_bufferSize(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES (scalar_t))268 void geqrf_bufferSize(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(scalar_t)) {
269 static_assert(false&&sizeof(scalar_t),
270 "at::cuda::solver::geqrf_bufferSize: not implemented");
271 }
272 template <>
273 void geqrf_bufferSize<float>(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(float));
274 template <>
275 void geqrf_bufferSize<double>(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(double));
276 template <>
277 void geqrf_bufferSize<c10::complex<float>>(
278 CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(c10::complex<float>));
279 template <>
280 void geqrf_bufferSize<c10::complex<double>>(
281 CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(c10::complex<double>));
282
283 #define CUDASOLVER_GEQRF_ARGTYPES(scalar_t) \
284 cusolverDnHandle_t handle, int m, int n, scalar_t *A, int lda, \
285 scalar_t *tau, scalar_t *work, int lwork, int *devInfo
286
287 template <class scalar_t>
geqrf(CUDASOLVER_GEQRF_ARGTYPES (scalar_t))288 void geqrf(CUDASOLVER_GEQRF_ARGTYPES(scalar_t)) {
289 static_assert(false&&sizeof(scalar_t),
290 "at::cuda::solver::geqrf: not implemented");
291 }
292 template <>
293 void geqrf<float>(CUDASOLVER_GEQRF_ARGTYPES(float));
294 template <>
295 void geqrf<double>(CUDASOLVER_GEQRF_ARGTYPES(double));
296 template <>
297 void geqrf<c10::complex<float>>(CUDASOLVER_GEQRF_ARGTYPES(c10::complex<float>));
298 template <>
299 void geqrf<c10::complex<double>>(
300 CUDASOLVER_GEQRF_ARGTYPES(c10::complex<double>));
301
302 #define CUDASOLVER_POTRS_ARGTYPES(Dtype) \
303 cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, const Dtype *A, int lda, Dtype *B, int ldb, int *devInfo
304
305 template<class Dtype>
potrs(CUDASOLVER_POTRS_ARGTYPES (Dtype))306 void potrs(CUDASOLVER_POTRS_ARGTYPES(Dtype)) {
307 static_assert(false&&sizeof(Dtype), "at::cuda::solver::potrs: not implemented");
308 }
309 template<>
310 void potrs<float>(CUDASOLVER_POTRS_ARGTYPES(float));
311 template<>
312 void potrs<double>(CUDASOLVER_POTRS_ARGTYPES(double));
313 template<>
314 void potrs<c10::complex<float>>(CUDASOLVER_POTRS_ARGTYPES(c10::complex<float>));
315 template<>
316 void potrs<c10::complex<double>>(CUDASOLVER_POTRS_ARGTYPES(c10::complex<double>));
317
318
319 #define CUDASOLVER_POTRS_BATCHED_ARGTYPES(Dtype) \
320 cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, Dtype *Aarray[], int lda, Dtype *Barray[], int ldb, int *info, int batchSize
321
322 template<class Dtype>
potrsBatched(CUDASOLVER_POTRS_BATCHED_ARGTYPES (Dtype))323 void potrsBatched(CUDASOLVER_POTRS_BATCHED_ARGTYPES(Dtype)) {
324 static_assert(false&&sizeof(Dtype), "at::cuda::solver::potrsBatched: not implemented");
325 }
326 template<>
327 void potrsBatched<float>(CUDASOLVER_POTRS_BATCHED_ARGTYPES(float));
328 template<>
329 void potrsBatched<double>(CUDASOLVER_POTRS_BATCHED_ARGTYPES(double));
330 template<>
331 void potrsBatched<c10::complex<float>>(CUDASOLVER_POTRS_BATCHED_ARGTYPES(c10::complex<float>));
332 template<>
333 void potrsBatched<c10::complex<double>>(CUDASOLVER_POTRS_BATCHED_ARGTYPES(c10::complex<double>));
334
335
336 #define CUDASOLVER_ORGQR_BUFFERSIZE_ARGTYPES(Dtype) \
337 cusolverDnHandle_t handle, int m, int n, int k, const Dtype *A, int lda, \
338 const Dtype *tau, int *lwork
339
340 template <class Dtype>
orgqr_buffersize(CUDASOLVER_ORGQR_BUFFERSIZE_ARGTYPES (Dtype))341 void orgqr_buffersize(CUDASOLVER_ORGQR_BUFFERSIZE_ARGTYPES(Dtype)) {
342 static_assert(false&&sizeof(Dtype), "at::cuda::solver::orgqr_buffersize: not implemented");
343 }
344 template <>
345 void orgqr_buffersize<float>(CUDASOLVER_ORGQR_BUFFERSIZE_ARGTYPES(float));
346 template <>
347 void orgqr_buffersize<double>(CUDASOLVER_ORGQR_BUFFERSIZE_ARGTYPES(double));
348 template <>
349 void orgqr_buffersize<c10::complex<float>>(CUDASOLVER_ORGQR_BUFFERSIZE_ARGTYPES(c10::complex<float>));
350 template <>
351 void orgqr_buffersize<c10::complex<double>>(CUDASOLVER_ORGQR_BUFFERSIZE_ARGTYPES(c10::complex<double>));
352
353
354 #define CUDASOLVER_ORGQR_ARGTYPES(Dtype) \
355 cusolverDnHandle_t handle, int m, int n, int k, Dtype *A, int lda, \
356 const Dtype *tau, Dtype *work, int lwork, int *devInfo
357
358 template <class Dtype>
orgqr(CUDASOLVER_ORGQR_ARGTYPES (Dtype))359 void orgqr(CUDASOLVER_ORGQR_ARGTYPES(Dtype)) {
360 static_assert(false&&sizeof(Dtype), "at::cuda::solver::orgqr: not implemented");
361 }
362 template <>
363 void orgqr<float>(CUDASOLVER_ORGQR_ARGTYPES(float));
364 template <>
365 void orgqr<double>(CUDASOLVER_ORGQR_ARGTYPES(double));
366 template <>
367 void orgqr<c10::complex<float>>(CUDASOLVER_ORGQR_ARGTYPES(c10::complex<float>));
368 template <>
369 void orgqr<c10::complex<double>>(CUDASOLVER_ORGQR_ARGTYPES(c10::complex<double>));
370
371 #define CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(Dtype) \
372 cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, \
373 int m, int n, int k, const Dtype *A, int lda, const Dtype *tau, \
374 const Dtype *C, int ldc, int *lwork
375
376 template <class Dtype>
ormqr_bufferSize(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES (Dtype))377 void ormqr_bufferSize(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(Dtype)) {
378 static_assert(false&&sizeof(Dtype),
379 "at::cuda::solver::ormqr_bufferSize: not implemented");
380 }
381 template <>
382 void ormqr_bufferSize<float>(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(float));
383 template <>
384 void ormqr_bufferSize<double>(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(double));
385 template <>
386 void ormqr_bufferSize<c10::complex<float>>(
387 CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(c10::complex<float>));
388 template <>
389 void ormqr_bufferSize<c10::complex<double>>(
390 CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(c10::complex<double>));
391
392 #define CUDASOLVER_ORMQR_ARGTYPES(Dtype) \
393 cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, \
394 int m, int n, int k, const Dtype *A, int lda, const Dtype *tau, Dtype *C, \
395 int ldc, Dtype *work, int lwork, int *devInfo
396
397 template <class Dtype>
ormqr(CUDASOLVER_ORMQR_ARGTYPES (Dtype))398 void ormqr(CUDASOLVER_ORMQR_ARGTYPES(Dtype)) {
399 static_assert(false&&sizeof(Dtype),
400 "at::cuda::solver::ormqr: not implemented");
401 }
402 template <>
403 void ormqr<float>(CUDASOLVER_ORMQR_ARGTYPES(float));
404 template <>
405 void ormqr<double>(CUDASOLVER_ORMQR_ARGTYPES(double));
406 template <>
407 void ormqr<c10::complex<float>>(CUDASOLVER_ORMQR_ARGTYPES(c10::complex<float>));
408 template <>
409 void ormqr<c10::complex<double>>(
410 CUDASOLVER_ORMQR_ARGTYPES(c10::complex<double>));
411
412 #ifdef USE_CUSOLVER_64_BIT
413
414 template<class Dtype>
get_cusolver_datatype()415 cudaDataType get_cusolver_datatype() {
416 static_assert(false&&sizeof(Dtype), "cusolver doesn't support data type");
417 return {};
418 }
419 template<> cudaDataType get_cusolver_datatype<float>();
420 template<> cudaDataType get_cusolver_datatype<double>();
421 template<> cudaDataType get_cusolver_datatype<c10::complex<float>>();
422 template<> cudaDataType get_cusolver_datatype<c10::complex<double>>();
423
424 void xpotrf_buffersize(
425 cusolverDnHandle_t handle, cusolverDnParams_t params, cublasFillMode_t uplo, int64_t n, cudaDataType dataTypeA, const void *A,
426 int64_t lda, cudaDataType computeType, size_t *workspaceInBytesOnDevice, size_t *workspaceInBytesOnHost);
427
428 void xpotrf(
429 cusolverDnHandle_t handle, cusolverDnParams_t params, cublasFillMode_t uplo, int64_t n, cudaDataType dataTypeA, void *A,
430 int64_t lda, cudaDataType computeType, void *bufferOnDevice, size_t workspaceInBytesOnDevice, void *bufferOnHost, size_t workspaceInBytesOnHost,
431 int *info);
432
433 void xpotrs(
434 cusolverDnHandle_t handle, cusolverDnParams_t params, cublasFillMode_t uplo, int64_t n, int64_t nrhs, cudaDataType dataTypeA, const void *A,
435 int64_t lda, cudaDataType dataTypeB, void *B, int64_t ldb, int *info);
436
437 #endif // USE_CUSOLVER_64_BIT
438
439 #define CUDASOLVER_SYEVD_BUFFERSIZE_ARGTYPES(scalar_t, value_t) \
440 cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, \
441 int n, const scalar_t *A, int lda, const value_t *W, int *lwork
442
443 template <class scalar_t, class value_t = scalar_t>
syevd_bufferSize(CUDASOLVER_SYEVD_BUFFERSIZE_ARGTYPES (scalar_t,value_t))444 void syevd_bufferSize(CUDASOLVER_SYEVD_BUFFERSIZE_ARGTYPES(scalar_t, value_t)) {
445 static_assert(false&&sizeof(scalar_t),
446 "at::cuda::solver::syevd_bufferSize: not implemented");
447 }
448
449 template <>
450 void syevd_bufferSize<float>(
451 CUDASOLVER_SYEVD_BUFFERSIZE_ARGTYPES(float, float));
452 template <>
453 void syevd_bufferSize<double>(
454 CUDASOLVER_SYEVD_BUFFERSIZE_ARGTYPES(double, double));
455 template <>
456 void syevd_bufferSize<c10::complex<float>, float>(
457 CUDASOLVER_SYEVD_BUFFERSIZE_ARGTYPES(c10::complex<float>, float));
458 template <>
459 void syevd_bufferSize<c10::complex<double>, double>(
460 CUDASOLVER_SYEVD_BUFFERSIZE_ARGTYPES(c10::complex<double>, double));
461
462 #define CUDASOLVER_SYEVD_ARGTYPES(scalar_t, value_t) \
463 cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, \
464 int n, scalar_t *A, int lda, value_t *W, scalar_t *work, int lwork, \
465 int *info
466
467 template <class scalar_t, class value_t = scalar_t>
syevd(CUDASOLVER_SYEVD_ARGTYPES (scalar_t,value_t))468 void syevd(CUDASOLVER_SYEVD_ARGTYPES(scalar_t, value_t)) {
469 static_assert(false&&sizeof(scalar_t),
470 "at::cuda::solver::syevd: not implemented");
471 }
472
473 template <>
474 void syevd<float>(CUDASOLVER_SYEVD_ARGTYPES(float, float));
475 template <>
476 void syevd<double>(CUDASOLVER_SYEVD_ARGTYPES(double, double));
477 template <>
478 void syevd<c10::complex<float>, float>(
479 CUDASOLVER_SYEVD_ARGTYPES(c10::complex<float>, float));
480 template <>
481 void syevd<c10::complex<double>, double>(
482 CUDASOLVER_SYEVD_ARGTYPES(c10::complex<double>, double));
483
484 #define CUDASOLVER_SYEVJ_BUFFERSIZE_ARGTYPES(scalar_t, value_t) \
485 cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, \
486 int n, const scalar_t *A, int lda, const value_t *W, int *lwork, \
487 syevjInfo_t params
488
489 template <class scalar_t, class value_t = scalar_t>
syevj_bufferSize(CUDASOLVER_SYEVJ_BUFFERSIZE_ARGTYPES (scalar_t,value_t))490 void syevj_bufferSize(CUDASOLVER_SYEVJ_BUFFERSIZE_ARGTYPES(scalar_t, value_t)) {
491 static_assert(false&&sizeof(scalar_t),
492 "at::cuda::solver::syevj_bufferSize: not implemented");
493 }
494
495 template <>
496 void syevj_bufferSize<float>(
497 CUDASOLVER_SYEVJ_BUFFERSIZE_ARGTYPES(float, float));
498 template <>
499 void syevj_bufferSize<double>(
500 CUDASOLVER_SYEVJ_BUFFERSIZE_ARGTYPES(double, double));
501 template <>
502 void syevj_bufferSize<c10::complex<float>, float>(
503 CUDASOLVER_SYEVJ_BUFFERSIZE_ARGTYPES(c10::complex<float>, float));
504 template <>
505 void syevj_bufferSize<c10::complex<double>, double>(
506 CUDASOLVER_SYEVJ_BUFFERSIZE_ARGTYPES(c10::complex<double>, double));
507
508 #define CUDASOLVER_SYEVJ_ARGTYPES(scalar_t, value_t) \
509 cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, \
510 int n, scalar_t *A, int lda, value_t *W, scalar_t *work, int lwork, \
511 int *info, syevjInfo_t params
512
513 template <class scalar_t, class value_t = scalar_t>
syevj(CUDASOLVER_SYEVJ_ARGTYPES (scalar_t,value_t))514 void syevj(CUDASOLVER_SYEVJ_ARGTYPES(scalar_t, value_t)) {
515 static_assert(false&&sizeof(scalar_t), "at::cuda::solver::syevj: not implemented");
516 }
517
518 template <>
519 void syevj<float>(CUDASOLVER_SYEVJ_ARGTYPES(float, float));
520 template <>
521 void syevj<double>(CUDASOLVER_SYEVJ_ARGTYPES(double, double));
522 template <>
523 void syevj<c10::complex<float>, float>(
524 CUDASOLVER_SYEVJ_ARGTYPES(c10::complex<float>, float));
525 template <>
526 void syevj<c10::complex<double>, double>(
527 CUDASOLVER_SYEVJ_ARGTYPES(c10::complex<double>, double));
528
529 #define CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(scalar_t, value_t) \
530 cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, \
531 int n, const scalar_t *A, int lda, const value_t *W, int *lwork, \
532 syevjInfo_t params, int batchsize
533
534 template <class scalar_t, class value_t = scalar_t>
syevjBatched_bufferSize(CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES (scalar_t,value_t))535 void syevjBatched_bufferSize(
536 CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(scalar_t, value_t)) {
537 static_assert(false&&sizeof(scalar_t),
538 "at::cuda::solver::syevjBatched_bufferSize: not implemented");
539 }
540
541 template <>
542 void syevjBatched_bufferSize<float>(
543 CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(float, float));
544 template <>
545 void syevjBatched_bufferSize<double>(
546 CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(double, double));
547 template <>
548 void syevjBatched_bufferSize<c10::complex<float>, float>(
549 CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(c10::complex<float>, float));
550 template <>
551 void syevjBatched_bufferSize<c10::complex<double>, double>(
552 CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(c10::complex<double>, double));
553
554 #define CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(scalar_t, value_t) \
555 cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, \
556 int n, scalar_t *A, int lda, value_t *W, scalar_t *work, int lwork, \
557 int *info, syevjInfo_t params, int batchsize
558
559 template <class scalar_t, class value_t = scalar_t>
syevjBatched(CUDASOLVER_SYEVJ_BATCHED_ARGTYPES (scalar_t,value_t))560 void syevjBatched(CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(scalar_t, value_t)) {
561 static_assert(false&&sizeof(scalar_t),
562 "at::cuda::solver::syevjBatched: not implemented");
563 }
564
565 template <>
566 void syevjBatched<float>(CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(float, float));
567 template <>
568 void syevjBatched<double>(CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(double, double));
569 template <>
570 void syevjBatched<c10::complex<float>, float>(
571 CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(c10::complex<float>, float));
572 template <>
573 void syevjBatched<c10::complex<double>, double>(
574 CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(c10::complex<double>, double));
575
576 #ifdef USE_CUSOLVER_64_BIT
577
578 #define CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(scalar_t) \
579 cusolverDnHandle_t handle, cusolverDnParams_t params, int64_t m, int64_t n, \
580 const scalar_t *A, int64_t lda, const scalar_t *tau, \
581 size_t *workspaceInBytesOnDevice, size_t *workspaceInBytesOnHost
582
583 template <class scalar_t>
xgeqrf_bufferSize(CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES (scalar_t))584 void xgeqrf_bufferSize(CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(scalar_t)) {
585 static_assert(false&&sizeof(scalar_t),
586 "at::cuda::solver::xgeqrf_bufferSize: not implemented");
587 }
588
589 template <>
590 void xgeqrf_bufferSize<float>(CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(float));
591 template <>
592 void xgeqrf_bufferSize<double>(CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(double));
593 template <>
594 void xgeqrf_bufferSize<c10::complex<float>>(
595 CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(c10::complex<float>));
596 template <>
597 void xgeqrf_bufferSize<c10::complex<double>>(
598 CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(c10::complex<double>));
599
600 #define CUDASOLVER_XGEQRF_ARGTYPES(scalar_t) \
601 cusolverDnHandle_t handle, cusolverDnParams_t params, int64_t m, int64_t n, \
602 scalar_t *A, int64_t lda, scalar_t *tau, scalar_t *bufferOnDevice, \
603 size_t workspaceInBytesOnDevice, scalar_t *bufferOnHost, \
604 size_t workspaceInBytesOnHost, int *info
605
606 template <class scalar_t>
xgeqrf(CUDASOLVER_XGEQRF_ARGTYPES (scalar_t))607 void xgeqrf(CUDASOLVER_XGEQRF_ARGTYPES(scalar_t)) {
608 static_assert(false&&sizeof(scalar_t), "at::cuda::solver::xgeqrf: not implemented");
609 }
610
611 template <>
612 void xgeqrf<float>(CUDASOLVER_XGEQRF_ARGTYPES(float));
613 template <>
614 void xgeqrf<double>(CUDASOLVER_XGEQRF_ARGTYPES(double));
615 template <>
616 void xgeqrf<c10::complex<float>>(
617 CUDASOLVER_XGEQRF_ARGTYPES(c10::complex<float>));
618 template <>
619 void xgeqrf<c10::complex<double>>(
620 CUDASOLVER_XGEQRF_ARGTYPES(c10::complex<double>));
621
622 #define CUDASOLVER_XSYEVD_BUFFERSIZE_ARGTYPES(scalar_t, value_t) \
623 cusolverDnHandle_t handle, cusolverDnParams_t params, \
624 cusolverEigMode_t jobz, cublasFillMode_t uplo, int64_t n, \
625 const scalar_t *A, int64_t lda, const value_t *W, \
626 size_t *workspaceInBytesOnDevice, size_t *workspaceInBytesOnHost
627
628 template <class scalar_t, class value_t = scalar_t>
xsyevd_bufferSize(CUDASOLVER_XSYEVD_BUFFERSIZE_ARGTYPES (scalar_t,value_t))629 void xsyevd_bufferSize(
630 CUDASOLVER_XSYEVD_BUFFERSIZE_ARGTYPES(scalar_t, value_t)) {
631 static_assert(false&&sizeof(scalar_t),
632 "at::cuda::solver::xsyevd_bufferSize: not implemented");
633 }
634
635 template <>
636 void xsyevd_bufferSize<float>(
637 CUDASOLVER_XSYEVD_BUFFERSIZE_ARGTYPES(float, float));
638 template <>
639 void xsyevd_bufferSize<double>(
640 CUDASOLVER_XSYEVD_BUFFERSIZE_ARGTYPES(double, double));
641 template <>
642 void xsyevd_bufferSize<c10::complex<float>, float>(
643 CUDASOLVER_XSYEVD_BUFFERSIZE_ARGTYPES(c10::complex<float>, float));
644 template <>
645 void xsyevd_bufferSize<c10::complex<double>, double>(
646 CUDASOLVER_XSYEVD_BUFFERSIZE_ARGTYPES(c10::complex<double>, double));
647
648 #define CUDASOLVER_XSYEVD_ARGTYPES(scalar_t, value_t) \
649 cusolverDnHandle_t handle, cusolverDnParams_t params, \
650 cusolverEigMode_t jobz, cublasFillMode_t uplo, int64_t n, scalar_t *A, \
651 int64_t lda, value_t *W, scalar_t *bufferOnDevice, \
652 size_t workspaceInBytesOnDevice, scalar_t *bufferOnHost, \
653 size_t workspaceInBytesOnHost, int *info
654
655 template <class scalar_t, class value_t = scalar_t>
xsyevd(CUDASOLVER_XSYEVD_ARGTYPES (scalar_t,value_t))656 void xsyevd(CUDASOLVER_XSYEVD_ARGTYPES(scalar_t, value_t)) {
657 static_assert(false&&sizeof(scalar_t),
658 "at::cuda::solver::xsyevd: not implemented");
659 }
660
661 template <>
662 void xsyevd<float>(CUDASOLVER_XSYEVD_ARGTYPES(float, float));
663 template <>
664 void xsyevd<double>(CUDASOLVER_XSYEVD_ARGTYPES(double, double));
665 template <>
666 void xsyevd<c10::complex<float>, float>(
667 CUDASOLVER_XSYEVD_ARGTYPES(c10::complex<float>, float));
668 template <>
669 void xsyevd<c10::complex<double>, double>(
670 CUDASOLVER_XSYEVD_ARGTYPES(c10::complex<double>, double));
671
672 #endif // USE_CUSOLVER_64_BIT
673
674 } // namespace solver
675 } // namespace cuda
676 } // namespace at
677
678 #endif // CUDART_VERSION
679