xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/linalg/CUDASolver.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/cuda/CUDAContext.h>
4 
5 #if defined(CUDART_VERSION) && defined(CUSOLVER_VERSION) && CUSOLVER_VERSION >= 11000
6 // cuSOLVER version >= 11000 includes 64-bit API
7 #define USE_CUSOLVER_64_BIT
8 #endif
9 
10 #if defined(CUDART_VERSION) || defined(USE_ROCM)
11 
12 namespace at {
13 namespace cuda {
14 namespace solver {
15 
16 #define CUDASOLVER_GETRF_ARGTYPES(Dtype)  \
17     cusolverDnHandle_t handle, int m, int n, Dtype* dA, int ldda, int* ipiv, int* info
18 
19 template<class Dtype>
getrf(CUDASOLVER_GETRF_ARGTYPES (Dtype))20 void getrf(CUDASOLVER_GETRF_ARGTYPES(Dtype)) {
21   static_assert(false&&sizeof(Dtype), "at::cuda::solver::getrf: not implemented");
22 }
23 template<>
24 void getrf<float>(CUDASOLVER_GETRF_ARGTYPES(float));
25 template<>
26 void getrf<double>(CUDASOLVER_GETRF_ARGTYPES(double));
27 template<>
28 void getrf<c10::complex<double>>(CUDASOLVER_GETRF_ARGTYPES(c10::complex<double>));
29 template<>
30 void getrf<c10::complex<float>>(CUDASOLVER_GETRF_ARGTYPES(c10::complex<float>));
31 
32 
33 #define CUDASOLVER_GETRS_ARGTYPES(Dtype)  \
34     cusolverDnHandle_t handle, int n, int nrhs, Dtype* dA, int lda, int* ipiv, Dtype* ret, int ldb, int* info, cublasOperation_t trans
35 
36 template<class Dtype>
getrs(CUDASOLVER_GETRS_ARGTYPES (Dtype))37 void getrs(CUDASOLVER_GETRS_ARGTYPES(Dtype)) {
38   static_assert(false&&sizeof(Dtype), "at::cuda::solver::getrs: not implemented");
39 }
40 template<>
41 void getrs<float>(CUDASOLVER_GETRS_ARGTYPES(float));
42 template<>
43 void getrs<double>(CUDASOLVER_GETRS_ARGTYPES(double));
44 template<>
45 void getrs<c10::complex<double>>(CUDASOLVER_GETRS_ARGTYPES(c10::complex<double>));
46 template<>
47 void getrs<c10::complex<float>>(CUDASOLVER_GETRS_ARGTYPES(c10::complex<float>));
48 
49 #define CUDASOLVER_SYTRF_BUFFER_ARGTYPES(Dtype) \
50   cusolverDnHandle_t handle, int n, Dtype *A, int lda, int *lwork
51 
52 template <class Dtype>
sytrf_bufferSize(CUDASOLVER_SYTRF_BUFFER_ARGTYPES (Dtype))53 void sytrf_bufferSize(CUDASOLVER_SYTRF_BUFFER_ARGTYPES(Dtype)) {
54   static_assert(false&&sizeof(Dtype),
55       "at::cuda::solver::sytrf_bufferSize: not implemented");
56 }
57 template <>
58 void sytrf_bufferSize<float>(CUDASOLVER_SYTRF_BUFFER_ARGTYPES(float));
59 template <>
60 void sytrf_bufferSize<double>(CUDASOLVER_SYTRF_BUFFER_ARGTYPES(double));
61 template <>
62 void sytrf_bufferSize<c10::complex<double>>(
63     CUDASOLVER_SYTRF_BUFFER_ARGTYPES(c10::complex<double>));
64 template <>
65 void sytrf_bufferSize<c10::complex<float>>(
66     CUDASOLVER_SYTRF_BUFFER_ARGTYPES(c10::complex<float>));
67 
68 #define CUDASOLVER_SYTRF_ARGTYPES(Dtype)                                      \
69   cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, Dtype *A, int lda, \
70       int *ipiv, Dtype *work, int lwork, int *devInfo
71 
72 template <class Dtype>
sytrf(CUDASOLVER_SYTRF_ARGTYPES (Dtype))73 void sytrf(CUDASOLVER_SYTRF_ARGTYPES(Dtype)) {
74   static_assert(false&&sizeof(Dtype),
75       "at::cuda::solver::sytrf: not implemented");
76 }
77 template <>
78 void sytrf<float>(CUDASOLVER_SYTRF_ARGTYPES(float));
79 template <>
80 void sytrf<double>(CUDASOLVER_SYTRF_ARGTYPES(double));
81 template <>
82 void sytrf<c10::complex<double>>(
83     CUDASOLVER_SYTRF_ARGTYPES(c10::complex<double>));
84 template <>
85 void sytrf<c10::complex<float>>(CUDASOLVER_SYTRF_ARGTYPES(c10::complex<float>));
86 
87 #define CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES()  \
88     cusolverDnHandle_t handle, int m, int n, int *lwork
89 
90 template<class Dtype>
gesvd_buffersize(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES ())91 void gesvd_buffersize(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES()) {
92   static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvd_buffersize: not implemented");
93 }
94 template<>
95 void gesvd_buffersize<float>(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES());
96 template<>
97 void gesvd_buffersize<double>(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES());
98 template<>
99 void gesvd_buffersize<c10::complex<float>>(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES());
100 template<>
101 void gesvd_buffersize<c10::complex<double>>(CUDASOLVER_GESVD_BUFFERSIZE_ARGTYPES());
102 
103 
104 #define CUDASOLVER_GESVD_ARGTYPES(Dtype, Vtype)  \
105     cusolverDnHandle_t handle, signed char jobu, signed char jobvt, int m, int n, Dtype *A, int lda, \
106     Vtype *S, Dtype *U, int ldu, Dtype *VT, int ldvt, Dtype *work, int lwork, Vtype *rwork, int *info
107 
108 template<class Dtype, class Vtype>
gesvd(CUDASOLVER_GESVD_ARGTYPES (Dtype,Vtype))109 void gesvd(CUDASOLVER_GESVD_ARGTYPES(Dtype, Vtype)) {
110   static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvd: not implemented");
111 }
112 template<>
113 void gesvd<float>(CUDASOLVER_GESVD_ARGTYPES(float, float));
114 template<>
115 void gesvd<double>(CUDASOLVER_GESVD_ARGTYPES(double, double));
116 template<>
117 void gesvd<c10::complex<float>>(CUDASOLVER_GESVD_ARGTYPES(c10::complex<float>, float));
118 template<>
119 void gesvd<c10::complex<double>>(CUDASOLVER_GESVD_ARGTYPES(c10::complex<double>, double));
120 
121 
122 #define CUDASOLVER_GESVDJ_BUFFERSIZE_ARGTYPES(Dtype, Vtype)  \
123     cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, Dtype *A, int lda, Vtype *S, \
124     Dtype *U, int ldu, Dtype *V, int ldv, int *lwork, gesvdjInfo_t params
125 
126 template<class Dtype, class Vtype>
gesvdj_buffersize(CUDASOLVER_GESVDJ_BUFFERSIZE_ARGTYPES (Dtype,Vtype))127 void gesvdj_buffersize(CUDASOLVER_GESVDJ_BUFFERSIZE_ARGTYPES(Dtype, Vtype)) {
128   static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvdj_buffersize: not implemented");
129 }
130 template<>
131 void gesvdj_buffersize<float>(CUDASOLVER_GESVDJ_BUFFERSIZE_ARGTYPES(float, float));
132 template<>
133 void gesvdj_buffersize<double>(CUDASOLVER_GESVDJ_BUFFERSIZE_ARGTYPES(double, double));
134 template<>
135 void gesvdj_buffersize<c10::complex<float>>(CUDASOLVER_GESVDJ_BUFFERSIZE_ARGTYPES(c10::complex<float>, float));
136 template<>
137 void gesvdj_buffersize<c10::complex<double>>(CUDASOLVER_GESVDJ_BUFFERSIZE_ARGTYPES(c10::complex<double>, double));
138 
139 
140 #define CUDASOLVER_GESVDJ_ARGTYPES(Dtype, Vtype)  \
141     cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, Dtype* A, int lda, Vtype* S, Dtype* U, \
142     int ldu, Dtype* V, int ldv, Dtype* work, int lwork, int *info, gesvdjInfo_t params
143 
144 template<class Dtype, class Vtype>
gesvdj(CUDASOLVER_GESVDJ_ARGTYPES (Dtype,Vtype))145 void gesvdj(CUDASOLVER_GESVDJ_ARGTYPES(Dtype, Vtype)) {
146   static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvdj: not implemented");
147 }
148 template<>
149 void gesvdj<float>(CUDASOLVER_GESVDJ_ARGTYPES(float, float));
150 template<>
151 void gesvdj<double>(CUDASOLVER_GESVDJ_ARGTYPES(double, double));
152 template<>
153 void gesvdj<c10::complex<float>>(CUDASOLVER_GESVDJ_ARGTYPES(c10::complex<float>, float));
154 template<>
155 void gesvdj<c10::complex<double>>(CUDASOLVER_GESVDJ_ARGTYPES(c10::complex<double>, double));
156 
157 
158 #define CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(Dtype, Vtype)  \
159     cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, Dtype* A, int lda, Vtype* S, Dtype* U, \
160     int ldu, Dtype *V, int ldv, int *info, gesvdjInfo_t params, int batchSize
161 
162 template<class Dtype, class Vtype>
gesvdjBatched(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES (Dtype,Vtype))163 void gesvdjBatched(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(Dtype, Vtype)) {
164   static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvdj: not implemented");
165 }
166 template<>
167 void gesvdjBatched<float>(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(float, float));
168 template<>
169 void gesvdjBatched<double>(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(double, double));
170 template<>
171 void gesvdjBatched<c10::complex<float>>(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(c10::complex<float>, float));
172 template<>
173 void gesvdjBatched<c10::complex<double>>(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(c10::complex<double>, double));
174 
175 #define CUDASOLVER_GESVDA_STRIDED_BATCHED_BUFFERSIZE_ARGTYPES(Dtype, Vtype)  \
176     cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, Dtype *A, int lda, long long int strideA, \
177     Vtype *S, long long int strideS, Dtype *U, int ldu, long long int strideU, Dtype *V, int ldv, long long int strideV, \
178     int *lwork, int batchSize
179 
180 template<class Dtype, class Vtype>
gesvdaStridedBatched_buffersize(CUDASOLVER_GESVDA_STRIDED_BATCHED_BUFFERSIZE_ARGTYPES (Dtype,Vtype))181 void gesvdaStridedBatched_buffersize(CUDASOLVER_GESVDA_STRIDED_BATCHED_BUFFERSIZE_ARGTYPES(Dtype, Vtype)) {
182   static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvdaStridedBatched_buffersize: not implemented");
183 }
184 template<>
185 void gesvdaStridedBatched_buffersize<float>(CUDASOLVER_GESVDA_STRIDED_BATCHED_BUFFERSIZE_ARGTYPES(float, float));
186 template<>
187 void gesvdaStridedBatched_buffersize<double>(CUDASOLVER_GESVDA_STRIDED_BATCHED_BUFFERSIZE_ARGTYPES(double, double));
188 template<>
189 void gesvdaStridedBatched_buffersize<c10::complex<float>>(CUDASOLVER_GESVDA_STRIDED_BATCHED_BUFFERSIZE_ARGTYPES(c10::complex<float>, float));
190 template<>
191 void gesvdaStridedBatched_buffersize<c10::complex<double>>(CUDASOLVER_GESVDA_STRIDED_BATCHED_BUFFERSIZE_ARGTYPES(c10::complex<double>, double));
192 
193 
194 #define CUDASOLVER_GESVDA_STRIDED_BATCHED_ARGTYPES(Dtype, Vtype)  \
195     cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, Dtype *A, int lda, long long int strideA, \
196     Vtype *S, long long int strideS, Dtype *U, int ldu, long long int strideU, Dtype *V, int ldv, long long int strideV, \
197     Dtype *work, int lwork, int *info, double *h_R_nrmF, int batchSize
198 // h_R_nrmF is always double, regardless of input Dtype.
199 
200 template<class Dtype, class Vtype>
gesvdaStridedBatched(CUDASOLVER_GESVDA_STRIDED_BATCHED_ARGTYPES (Dtype,Vtype))201 void gesvdaStridedBatched(CUDASOLVER_GESVDA_STRIDED_BATCHED_ARGTYPES(Dtype, Vtype)) {
202   static_assert(false&&sizeof(Dtype), "at::cuda::solver::gesvdaStridedBatched: not implemented");
203 }
204 template<>
205 void gesvdaStridedBatched<float>(CUDASOLVER_GESVDA_STRIDED_BATCHED_ARGTYPES(float, float));
206 template<>
207 void gesvdaStridedBatched<double>(CUDASOLVER_GESVDA_STRIDED_BATCHED_ARGTYPES(double, double));
208 template<>
209 void gesvdaStridedBatched<c10::complex<float>>(CUDASOLVER_GESVDA_STRIDED_BATCHED_ARGTYPES(c10::complex<float>, float));
210 template<>
211 void gesvdaStridedBatched<c10::complex<double>>(CUDASOLVER_GESVDA_STRIDED_BATCHED_ARGTYPES(c10::complex<double>, double));
212 
213 
214 #define CUDASOLVER_POTRF_ARGTYPES(Dtype)  \
215     cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, Dtype* A, int lda, Dtype* work, int lwork, int* info
216 
217 template<class Dtype>
potrf(CUDASOLVER_POTRF_ARGTYPES (Dtype))218 void potrf(CUDASOLVER_POTRF_ARGTYPES(Dtype)) {
219   static_assert(false&&sizeof(Dtype), "at::cuda::solver::potrf: not implemented");
220 }
221 template<>
222 void potrf<float>(CUDASOLVER_POTRF_ARGTYPES(float));
223 template<>
224 void potrf<double>(CUDASOLVER_POTRF_ARGTYPES(double));
225 template<>
226 void potrf<c10::complex<float>>(CUDASOLVER_POTRF_ARGTYPES(c10::complex<float>));
227 template<>
228 void potrf<c10::complex<double>>(CUDASOLVER_POTRF_ARGTYPES(c10::complex<double>));
229 
230 
231 #define CUDASOLVER_POTRF_BUFFERSIZE_ARGTYPES(Dtype)  \
232     cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, Dtype* A, int lda, int* lwork
233 
234 template<class Dtype>
potrf_buffersize(CUDASOLVER_POTRF_BUFFERSIZE_ARGTYPES (Dtype))235 void potrf_buffersize(CUDASOLVER_POTRF_BUFFERSIZE_ARGTYPES(Dtype)) {
236   static_assert(false&&sizeof(Dtype), "at::cuda::solver::potrf_buffersize: not implemented");
237 }
238 template<>
239 void potrf_buffersize<float>(CUDASOLVER_POTRF_BUFFERSIZE_ARGTYPES(float));
240 template<>
241 void potrf_buffersize<double>(CUDASOLVER_POTRF_BUFFERSIZE_ARGTYPES(double));
242 template<>
243 void potrf_buffersize<c10::complex<float>>(CUDASOLVER_POTRF_BUFFERSIZE_ARGTYPES(c10::complex<float>));
244 template<>
245 void potrf_buffersize<c10::complex<double>>(CUDASOLVER_POTRF_BUFFERSIZE_ARGTYPES(c10::complex<double>));
246 
247 
248 #define CUDASOLVER_POTRF_BATCHED_ARGTYPES(Dtype)  \
249     cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, Dtype** A, int lda, int* info, int batchSize
250 
251 template<class Dtype>
potrfBatched(CUDASOLVER_POTRF_BATCHED_ARGTYPES (Dtype))252 void potrfBatched(CUDASOLVER_POTRF_BATCHED_ARGTYPES(Dtype)) {
253   static_assert(false&&sizeof(Dtype), "at::cuda::solver::potrfBatched: not implemented");
254 }
255 template<>
256 void potrfBatched<float>(CUDASOLVER_POTRF_BATCHED_ARGTYPES(float));
257 template<>
258 void potrfBatched<double>(CUDASOLVER_POTRF_BATCHED_ARGTYPES(double));
259 template<>
260 void potrfBatched<c10::complex<float>>(CUDASOLVER_POTRF_BATCHED_ARGTYPES(c10::complex<float>));
261 template<>
262 void potrfBatched<c10::complex<double>>(CUDASOLVER_POTRF_BATCHED_ARGTYPES(c10::complex<double>));
263 
264 #define CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(scalar_t) \
265   cusolverDnHandle_t handle, int m, int n, scalar_t *A, int lda, int *lwork
266 
267 template <class scalar_t>
geqrf_bufferSize(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES (scalar_t))268 void geqrf_bufferSize(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(scalar_t)) {
269   static_assert(false&&sizeof(scalar_t),
270       "at::cuda::solver::geqrf_bufferSize: not implemented");
271 }
272 template <>
273 void geqrf_bufferSize<float>(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(float));
274 template <>
275 void geqrf_bufferSize<double>(CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(double));
276 template <>
277 void geqrf_bufferSize<c10::complex<float>>(
278     CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(c10::complex<float>));
279 template <>
280 void geqrf_bufferSize<c10::complex<double>>(
281     CUDASOLVER_GEQRF_BUFFERSIZE_ARGTYPES(c10::complex<double>));
282 
283 #define CUDASOLVER_GEQRF_ARGTYPES(scalar_t)                      \
284   cusolverDnHandle_t handle, int m, int n, scalar_t *A, int lda, \
285       scalar_t *tau, scalar_t *work, int lwork, int *devInfo
286 
287 template <class scalar_t>
geqrf(CUDASOLVER_GEQRF_ARGTYPES (scalar_t))288 void geqrf(CUDASOLVER_GEQRF_ARGTYPES(scalar_t)) {
289   static_assert(false&&sizeof(scalar_t),
290       "at::cuda::solver::geqrf: not implemented");
291 }
292 template <>
293 void geqrf<float>(CUDASOLVER_GEQRF_ARGTYPES(float));
294 template <>
295 void geqrf<double>(CUDASOLVER_GEQRF_ARGTYPES(double));
296 template <>
297 void geqrf<c10::complex<float>>(CUDASOLVER_GEQRF_ARGTYPES(c10::complex<float>));
298 template <>
299 void geqrf<c10::complex<double>>(
300     CUDASOLVER_GEQRF_ARGTYPES(c10::complex<double>));
301 
302 #define CUDASOLVER_POTRS_ARGTYPES(Dtype)  \
303     cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, const Dtype *A, int lda, Dtype *B, int ldb, int *devInfo
304 
305 template<class Dtype>
potrs(CUDASOLVER_POTRS_ARGTYPES (Dtype))306 void potrs(CUDASOLVER_POTRS_ARGTYPES(Dtype)) {
307   static_assert(false&&sizeof(Dtype), "at::cuda::solver::potrs: not implemented");
308 }
309 template<>
310 void potrs<float>(CUDASOLVER_POTRS_ARGTYPES(float));
311 template<>
312 void potrs<double>(CUDASOLVER_POTRS_ARGTYPES(double));
313 template<>
314 void potrs<c10::complex<float>>(CUDASOLVER_POTRS_ARGTYPES(c10::complex<float>));
315 template<>
316 void potrs<c10::complex<double>>(CUDASOLVER_POTRS_ARGTYPES(c10::complex<double>));
317 
318 
319 #define CUDASOLVER_POTRS_BATCHED_ARGTYPES(Dtype)  \
320     cusolverDnHandle_t handle, cublasFillMode_t uplo, int n, int nrhs, Dtype *Aarray[], int lda, Dtype *Barray[], int ldb, int *info, int batchSize
321 
322 template<class Dtype>
potrsBatched(CUDASOLVER_POTRS_BATCHED_ARGTYPES (Dtype))323 void potrsBatched(CUDASOLVER_POTRS_BATCHED_ARGTYPES(Dtype)) {
324   static_assert(false&&sizeof(Dtype), "at::cuda::solver::potrsBatched: not implemented");
325 }
326 template<>
327 void potrsBatched<float>(CUDASOLVER_POTRS_BATCHED_ARGTYPES(float));
328 template<>
329 void potrsBatched<double>(CUDASOLVER_POTRS_BATCHED_ARGTYPES(double));
330 template<>
331 void potrsBatched<c10::complex<float>>(CUDASOLVER_POTRS_BATCHED_ARGTYPES(c10::complex<float>));
332 template<>
333 void potrsBatched<c10::complex<double>>(CUDASOLVER_POTRS_BATCHED_ARGTYPES(c10::complex<double>));
334 
335 
336 #define CUDASOLVER_ORGQR_BUFFERSIZE_ARGTYPES(Dtype)                        \
337   cusolverDnHandle_t handle, int m, int n, int k, const Dtype *A, int lda, \
338       const Dtype *tau, int *lwork
339 
340 template <class Dtype>
orgqr_buffersize(CUDASOLVER_ORGQR_BUFFERSIZE_ARGTYPES (Dtype))341 void orgqr_buffersize(CUDASOLVER_ORGQR_BUFFERSIZE_ARGTYPES(Dtype)) {
342   static_assert(false&&sizeof(Dtype), "at::cuda::solver::orgqr_buffersize: not implemented");
343 }
344 template <>
345 void orgqr_buffersize<float>(CUDASOLVER_ORGQR_BUFFERSIZE_ARGTYPES(float));
346 template <>
347 void orgqr_buffersize<double>(CUDASOLVER_ORGQR_BUFFERSIZE_ARGTYPES(double));
348 template <>
349 void orgqr_buffersize<c10::complex<float>>(CUDASOLVER_ORGQR_BUFFERSIZE_ARGTYPES(c10::complex<float>));
350 template <>
351 void orgqr_buffersize<c10::complex<double>>(CUDASOLVER_ORGQR_BUFFERSIZE_ARGTYPES(c10::complex<double>));
352 
353 
354 #define CUDASOLVER_ORGQR_ARGTYPES(Dtype)                             \
355   cusolverDnHandle_t handle, int m, int n, int k, Dtype *A, int lda, \
356       const Dtype *tau, Dtype *work, int lwork, int *devInfo
357 
358 template <class Dtype>
orgqr(CUDASOLVER_ORGQR_ARGTYPES (Dtype))359 void orgqr(CUDASOLVER_ORGQR_ARGTYPES(Dtype)) {
360   static_assert(false&&sizeof(Dtype), "at::cuda::solver::orgqr: not implemented");
361 }
362 template <>
363 void orgqr<float>(CUDASOLVER_ORGQR_ARGTYPES(float));
364 template <>
365 void orgqr<double>(CUDASOLVER_ORGQR_ARGTYPES(double));
366 template <>
367 void orgqr<c10::complex<float>>(CUDASOLVER_ORGQR_ARGTYPES(c10::complex<float>));
368 template <>
369 void orgqr<c10::complex<double>>(CUDASOLVER_ORGQR_ARGTYPES(c10::complex<double>));
370 
371 #define CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(Dtype)                          \
372   cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, \
373       int m, int n, int k, const Dtype *A, int lda, const Dtype *tau,        \
374       const Dtype *C, int ldc, int *lwork
375 
376 template <class Dtype>
ormqr_bufferSize(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES (Dtype))377 void ormqr_bufferSize(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(Dtype)) {
378   static_assert(false&&sizeof(Dtype),
379       "at::cuda::solver::ormqr_bufferSize: not implemented");
380 }
381 template <>
382 void ormqr_bufferSize<float>(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(float));
383 template <>
384 void ormqr_bufferSize<double>(CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(double));
385 template <>
386 void ormqr_bufferSize<c10::complex<float>>(
387     CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(c10::complex<float>));
388 template <>
389 void ormqr_bufferSize<c10::complex<double>>(
390     CUDASOLVER_ORMQR_BUFFERSIZE_ARGTYPES(c10::complex<double>));
391 
392 #define CUDASOLVER_ORMQR_ARGTYPES(Dtype)                                     \
393   cusolverDnHandle_t handle, cublasSideMode_t side, cublasOperation_t trans, \
394       int m, int n, int k, const Dtype *A, int lda, const Dtype *tau, Dtype *C,    \
395       int ldc, Dtype *work, int lwork, int *devInfo
396 
397 template <class Dtype>
ormqr(CUDASOLVER_ORMQR_ARGTYPES (Dtype))398 void ormqr(CUDASOLVER_ORMQR_ARGTYPES(Dtype)) {
399   static_assert(false&&sizeof(Dtype),
400       "at::cuda::solver::ormqr: not implemented");
401 }
402 template <>
403 void ormqr<float>(CUDASOLVER_ORMQR_ARGTYPES(float));
404 template <>
405 void ormqr<double>(CUDASOLVER_ORMQR_ARGTYPES(double));
406 template <>
407 void ormqr<c10::complex<float>>(CUDASOLVER_ORMQR_ARGTYPES(c10::complex<float>));
408 template <>
409 void ormqr<c10::complex<double>>(
410     CUDASOLVER_ORMQR_ARGTYPES(c10::complex<double>));
411 
412 #ifdef USE_CUSOLVER_64_BIT
413 
414 template<class Dtype>
get_cusolver_datatype()415 cudaDataType get_cusolver_datatype() {
416   static_assert(false&&sizeof(Dtype), "cusolver doesn't support data type");
417   return {};
418 }
419 template<> cudaDataType get_cusolver_datatype<float>();
420 template<> cudaDataType get_cusolver_datatype<double>();
421 template<> cudaDataType get_cusolver_datatype<c10::complex<float>>();
422 template<> cudaDataType get_cusolver_datatype<c10::complex<double>>();
423 
424 void xpotrf_buffersize(
425     cusolverDnHandle_t handle, cusolverDnParams_t params, cublasFillMode_t uplo, int64_t n, cudaDataType dataTypeA, const void *A,
426     int64_t lda, cudaDataType computeType, size_t *workspaceInBytesOnDevice, size_t *workspaceInBytesOnHost);
427 
428 void xpotrf(
429     cusolverDnHandle_t handle, cusolverDnParams_t params, cublasFillMode_t uplo, int64_t n, cudaDataType dataTypeA, void *A,
430     int64_t lda, cudaDataType computeType, void *bufferOnDevice, size_t workspaceInBytesOnDevice, void *bufferOnHost, size_t workspaceInBytesOnHost,
431     int *info);
432 
433 void xpotrs(
434     cusolverDnHandle_t handle, cusolverDnParams_t params, cublasFillMode_t uplo, int64_t n, int64_t nrhs, cudaDataType dataTypeA, const void *A,
435     int64_t lda, cudaDataType dataTypeB, void *B, int64_t ldb, int *info);
436 
437 #endif // USE_CUSOLVER_64_BIT
438 
439 #define CUDASOLVER_SYEVD_BUFFERSIZE_ARGTYPES(scalar_t, value_t)             \
440   cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, \
441       int n, const scalar_t *A, int lda, const value_t *W, int *lwork
442 
443 template <class scalar_t, class value_t = scalar_t>
syevd_bufferSize(CUDASOLVER_SYEVD_BUFFERSIZE_ARGTYPES (scalar_t,value_t))444 void syevd_bufferSize(CUDASOLVER_SYEVD_BUFFERSIZE_ARGTYPES(scalar_t, value_t)) {
445   static_assert(false&&sizeof(scalar_t),
446       "at::cuda::solver::syevd_bufferSize: not implemented");
447 }
448 
449 template <>
450 void syevd_bufferSize<float>(
451     CUDASOLVER_SYEVD_BUFFERSIZE_ARGTYPES(float, float));
452 template <>
453 void syevd_bufferSize<double>(
454     CUDASOLVER_SYEVD_BUFFERSIZE_ARGTYPES(double, double));
455 template <>
456 void syevd_bufferSize<c10::complex<float>, float>(
457     CUDASOLVER_SYEVD_BUFFERSIZE_ARGTYPES(c10::complex<float>, float));
458 template <>
459 void syevd_bufferSize<c10::complex<double>, double>(
460     CUDASOLVER_SYEVD_BUFFERSIZE_ARGTYPES(c10::complex<double>, double));
461 
462 #define CUDASOLVER_SYEVD_ARGTYPES(scalar_t, value_t)                        \
463   cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, \
464       int n, scalar_t *A, int lda, value_t *W, scalar_t *work, int lwork,   \
465       int *info
466 
467 template <class scalar_t, class value_t = scalar_t>
syevd(CUDASOLVER_SYEVD_ARGTYPES (scalar_t,value_t))468 void syevd(CUDASOLVER_SYEVD_ARGTYPES(scalar_t, value_t)) {
469   static_assert(false&&sizeof(scalar_t),
470       "at::cuda::solver::syevd: not implemented");
471 }
472 
473 template <>
474 void syevd<float>(CUDASOLVER_SYEVD_ARGTYPES(float, float));
475 template <>
476 void syevd<double>(CUDASOLVER_SYEVD_ARGTYPES(double, double));
477 template <>
478 void syevd<c10::complex<float>, float>(
479     CUDASOLVER_SYEVD_ARGTYPES(c10::complex<float>, float));
480 template <>
481 void syevd<c10::complex<double>, double>(
482     CUDASOLVER_SYEVD_ARGTYPES(c10::complex<double>, double));
483 
484 #define CUDASOLVER_SYEVJ_BUFFERSIZE_ARGTYPES(scalar_t, value_t)             \
485   cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, \
486       int n, const scalar_t *A, int lda, const value_t *W, int *lwork,      \
487       syevjInfo_t params
488 
489 template <class scalar_t, class value_t = scalar_t>
syevj_bufferSize(CUDASOLVER_SYEVJ_BUFFERSIZE_ARGTYPES (scalar_t,value_t))490 void syevj_bufferSize(CUDASOLVER_SYEVJ_BUFFERSIZE_ARGTYPES(scalar_t, value_t)) {
491   static_assert(false&&sizeof(scalar_t),
492       "at::cuda::solver::syevj_bufferSize: not implemented");
493 }
494 
495 template <>
496 void syevj_bufferSize<float>(
497     CUDASOLVER_SYEVJ_BUFFERSIZE_ARGTYPES(float, float));
498 template <>
499 void syevj_bufferSize<double>(
500     CUDASOLVER_SYEVJ_BUFFERSIZE_ARGTYPES(double, double));
501 template <>
502 void syevj_bufferSize<c10::complex<float>, float>(
503     CUDASOLVER_SYEVJ_BUFFERSIZE_ARGTYPES(c10::complex<float>, float));
504 template <>
505 void syevj_bufferSize<c10::complex<double>, double>(
506     CUDASOLVER_SYEVJ_BUFFERSIZE_ARGTYPES(c10::complex<double>, double));
507 
508 #define CUDASOLVER_SYEVJ_ARGTYPES(scalar_t, value_t)                        \
509   cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, \
510       int n, scalar_t *A, int lda, value_t *W, scalar_t *work, int lwork,   \
511       int *info, syevjInfo_t params
512 
513 template <class scalar_t, class value_t = scalar_t>
syevj(CUDASOLVER_SYEVJ_ARGTYPES (scalar_t,value_t))514 void syevj(CUDASOLVER_SYEVJ_ARGTYPES(scalar_t, value_t)) {
515   static_assert(false&&sizeof(scalar_t), "at::cuda::solver::syevj: not implemented");
516 }
517 
518 template <>
519 void syevj<float>(CUDASOLVER_SYEVJ_ARGTYPES(float, float));
520 template <>
521 void syevj<double>(CUDASOLVER_SYEVJ_ARGTYPES(double, double));
522 template <>
523 void syevj<c10::complex<float>, float>(
524     CUDASOLVER_SYEVJ_ARGTYPES(c10::complex<float>, float));
525 template <>
526 void syevj<c10::complex<double>, double>(
527     CUDASOLVER_SYEVJ_ARGTYPES(c10::complex<double>, double));
528 
529 #define CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(scalar_t, value_t)     \
530   cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, \
531       int n, const scalar_t *A, int lda, const value_t *W, int *lwork,      \
532       syevjInfo_t params, int batchsize
533 
534 template <class scalar_t, class value_t = scalar_t>
syevjBatched_bufferSize(CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES (scalar_t,value_t))535 void syevjBatched_bufferSize(
536     CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(scalar_t, value_t)) {
537   static_assert(false&&sizeof(scalar_t),
538       "at::cuda::solver::syevjBatched_bufferSize: not implemented");
539 }
540 
541 template <>
542 void syevjBatched_bufferSize<float>(
543     CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(float, float));
544 template <>
545 void syevjBatched_bufferSize<double>(
546     CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(double, double));
547 template <>
548 void syevjBatched_bufferSize<c10::complex<float>, float>(
549     CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(c10::complex<float>, float));
550 template <>
551 void syevjBatched_bufferSize<c10::complex<double>, double>(
552     CUDASOLVER_SYEVJ_BATCHED_BUFFERSIZE_ARGTYPES(c10::complex<double>, double));
553 
554 #define CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(scalar_t, value_t)                \
555   cusolverDnHandle_t handle, cusolverEigMode_t jobz, cublasFillMode_t uplo, \
556       int n, scalar_t *A, int lda, value_t *W, scalar_t *work, int lwork,   \
557       int *info, syevjInfo_t params, int batchsize
558 
559 template <class scalar_t, class value_t = scalar_t>
syevjBatched(CUDASOLVER_SYEVJ_BATCHED_ARGTYPES (scalar_t,value_t))560 void syevjBatched(CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(scalar_t, value_t)) {
561   static_assert(false&&sizeof(scalar_t),
562       "at::cuda::solver::syevjBatched: not implemented");
563 }
564 
565 template <>
566 void syevjBatched<float>(CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(float, float));
567 template <>
568 void syevjBatched<double>(CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(double, double));
569 template <>
570 void syevjBatched<c10::complex<float>, float>(
571     CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(c10::complex<float>, float));
572 template <>
573 void syevjBatched<c10::complex<double>, double>(
574     CUDASOLVER_SYEVJ_BATCHED_ARGTYPES(c10::complex<double>, double));
575 
576 #ifdef USE_CUSOLVER_64_BIT
577 
578 #define CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(scalar_t)                       \
579   cusolverDnHandle_t handle, cusolverDnParams_t params, int64_t m, int64_t n, \
580       const scalar_t *A, int64_t lda, const scalar_t *tau,                    \
581       size_t *workspaceInBytesOnDevice, size_t *workspaceInBytesOnHost
582 
583 template <class scalar_t>
xgeqrf_bufferSize(CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES (scalar_t))584 void xgeqrf_bufferSize(CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(scalar_t)) {
585   static_assert(false&&sizeof(scalar_t),
586       "at::cuda::solver::xgeqrf_bufferSize: not implemented");
587 }
588 
589 template <>
590 void xgeqrf_bufferSize<float>(CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(float));
591 template <>
592 void xgeqrf_bufferSize<double>(CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(double));
593 template <>
594 void xgeqrf_bufferSize<c10::complex<float>>(
595     CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(c10::complex<float>));
596 template <>
597 void xgeqrf_bufferSize<c10::complex<double>>(
598     CUDASOLVER_XGEQRF_BUFFERSIZE_ARGTYPES(c10::complex<double>));
599 
600 #define CUDASOLVER_XGEQRF_ARGTYPES(scalar_t)                                  \
601   cusolverDnHandle_t handle, cusolverDnParams_t params, int64_t m, int64_t n, \
602       scalar_t *A, int64_t lda, scalar_t *tau, scalar_t *bufferOnDevice,      \
603       size_t workspaceInBytesOnDevice, scalar_t *bufferOnHost,                \
604       size_t workspaceInBytesOnHost, int *info
605 
606 template <class scalar_t>
xgeqrf(CUDASOLVER_XGEQRF_ARGTYPES (scalar_t))607 void xgeqrf(CUDASOLVER_XGEQRF_ARGTYPES(scalar_t)) {
608   static_assert(false&&sizeof(scalar_t), "at::cuda::solver::xgeqrf: not implemented");
609 }
610 
611 template <>
612 void xgeqrf<float>(CUDASOLVER_XGEQRF_ARGTYPES(float));
613 template <>
614 void xgeqrf<double>(CUDASOLVER_XGEQRF_ARGTYPES(double));
615 template <>
616 void xgeqrf<c10::complex<float>>(
617     CUDASOLVER_XGEQRF_ARGTYPES(c10::complex<float>));
618 template <>
619 void xgeqrf<c10::complex<double>>(
620     CUDASOLVER_XGEQRF_ARGTYPES(c10::complex<double>));
621 
622 #define CUDASOLVER_XSYEVD_BUFFERSIZE_ARGTYPES(scalar_t, value_t) \
623   cusolverDnHandle_t handle, cusolverDnParams_t params,          \
624       cusolverEigMode_t jobz, cublasFillMode_t uplo, int64_t n,  \
625       const scalar_t *A, int64_t lda, const value_t *W,          \
626       size_t *workspaceInBytesOnDevice, size_t *workspaceInBytesOnHost
627 
628 template <class scalar_t, class value_t = scalar_t>
xsyevd_bufferSize(CUDASOLVER_XSYEVD_BUFFERSIZE_ARGTYPES (scalar_t,value_t))629 void xsyevd_bufferSize(
630     CUDASOLVER_XSYEVD_BUFFERSIZE_ARGTYPES(scalar_t, value_t)) {
631   static_assert(false&&sizeof(scalar_t),
632       "at::cuda::solver::xsyevd_bufferSize: not implemented");
633 }
634 
635 template <>
636 void xsyevd_bufferSize<float>(
637     CUDASOLVER_XSYEVD_BUFFERSIZE_ARGTYPES(float, float));
638 template <>
639 void xsyevd_bufferSize<double>(
640     CUDASOLVER_XSYEVD_BUFFERSIZE_ARGTYPES(double, double));
641 template <>
642 void xsyevd_bufferSize<c10::complex<float>, float>(
643     CUDASOLVER_XSYEVD_BUFFERSIZE_ARGTYPES(c10::complex<float>, float));
644 template <>
645 void xsyevd_bufferSize<c10::complex<double>, double>(
646     CUDASOLVER_XSYEVD_BUFFERSIZE_ARGTYPES(c10::complex<double>, double));
647 
648 #define CUDASOLVER_XSYEVD_ARGTYPES(scalar_t, value_t)                        \
649   cusolverDnHandle_t handle, cusolverDnParams_t params,                      \
650       cusolverEigMode_t jobz, cublasFillMode_t uplo, int64_t n, scalar_t *A, \
651       int64_t lda, value_t *W, scalar_t *bufferOnDevice,                     \
652       size_t workspaceInBytesOnDevice, scalar_t *bufferOnHost,               \
653       size_t workspaceInBytesOnHost, int *info
654 
655 template <class scalar_t, class value_t = scalar_t>
xsyevd(CUDASOLVER_XSYEVD_ARGTYPES (scalar_t,value_t))656 void xsyevd(CUDASOLVER_XSYEVD_ARGTYPES(scalar_t, value_t)) {
657   static_assert(false&&sizeof(scalar_t),
658       "at::cuda::solver::xsyevd: not implemented");
659 }
660 
661 template <>
662 void xsyevd<float>(CUDASOLVER_XSYEVD_ARGTYPES(float, float));
663 template <>
664 void xsyevd<double>(CUDASOLVER_XSYEVD_ARGTYPES(double, double));
665 template <>
666 void xsyevd<c10::complex<float>, float>(
667     CUDASOLVER_XSYEVD_ARGTYPES(c10::complex<float>, float));
668 template <>
669 void xsyevd<c10::complex<double>, double>(
670     CUDASOLVER_XSYEVD_ARGTYPES(c10::complex<double>, double));
671 
672 #endif // USE_CUSOLVER_64_BIT
673 
674 } // namespace solver
675 } // namespace cuda
676 } // namespace at
677 
#endif // defined(CUDART_VERSION) || defined(USE_ROCM)
679