1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/core/grad_mode.h>
4 #include <ATen/Dispatch.h>
5 #include <ATen/Parallel.h>
6 #include <ATen/TensorMeta.h>
7 #include <ATen/TensorOperators.h>
8 #include <ATen/TensorSubclassLikeUtils.h>
9
10 #include <ATen/native/BatchLinearAlgebra.h>
11 #include <ATen/native/LinearAlgebraUtils.h>
12 #include <ATen/native/Resize.h>
13 #include <ATen/native/cpu/zmath.h>
14
15 #include <c10/util/irange.h>
16
17 #include <utility>
18 #include <vector>
19
20 #ifndef AT_PER_OPERATOR_HEADERS
21 #include <ATen/Functions.h>
22 #include <ATen/NativeFunctions.h>
23 #else
24 #include <ATen/ops/_spsolve.h>
25 #include <ATen/ops/_cholesky_solve_helper.h>
26 #include <ATen/ops/_cholesky_solve_helper_native.h>
27 #include <ATen/ops/_linalg_check_errors.h>
28 #include <ATen/ops/_linalg_check_errors_native.h>
29 #include <ATen/ops/_linalg_eigh.h>
30 #include <ATen/ops/_linalg_eigh_meta.h>
31 #include <ATen/ops/_linalg_eigh_native.h>
32 #include <ATen/ops/_linalg_eigvals.h>
33 #include <ATen/ops/_linalg_eigvals_native.h>
34 #include <ATen/ops/_linalg_solve_ex.h>
35 #include <ATen/ops/_linalg_solve_ex_meta.h>
36 #include <ATen/ops/_linalg_solve_ex_native.h>
37 #include <ATen/ops/_linalg_svd.h>
38 #include <ATen/ops/_linalg_svd_meta.h>
39 #include <ATen/ops/_linalg_svd_native.h>
40 #include <ATen/ops/_lu_with_info_native.h>
41 #include <ATen/ops/all.h>
42 #include <ATen/ops/arange.h>
43 #include <ATen/ops/cat.h>
44 #include <ATen/ops/cholesky.h>
45 #include <ATen/ops/cholesky_inverse.h>
46 #include <ATen/ops/cholesky_inverse_native.h>
47 #include <ATen/ops/cholesky_native.h>
48 #include <ATen/ops/cholesky_solve.h>
49 #include <ATen/ops/cholesky_solve_native.h>
50 #include <ATen/ops/clone.h>
51 #include <ATen/ops/complex.h>
52 #include <ATen/ops/cumprod.h>
53 #include <ATen/ops/empty.h>
54 #include <ATen/ops/empty_like.h>
55 #include <ATen/ops/geqrf.h>
56 #include <ATen/ops/geqrf_native.h>
57 #include <ATen/ops/inverse_native.h>
58 #include <ATen/ops/linalg_cholesky_ex.h>
59 #include <ATen/ops/linalg_cholesky_ex_meta.h>
60 #include <ATen/ops/linalg_cholesky_ex_native.h>
61 #include <ATen/ops/linalg_cholesky_native.h>
62 #include <ATen/ops/linalg_eig.h>
63 #include <ATen/ops/linalg_eig_native.h>
64 #include <ATen/ops/linalg_eigh_native.h>
65 #include <ATen/ops/linalg_eigvals.h>
66 #include <ATen/ops/linalg_eigvals_native.h>
67 #include <ATen/ops/linalg_eigvalsh_native.h>
68 #include <ATen/ops/linalg_householder_product.h>
69 #include <ATen/ops/linalg_householder_product_native.h>
70 #include <ATen/ops/linalg_inv.h>
71 #include <ATen/ops/linalg_inv_ex.h>
72 #include <ATen/ops/linalg_inv_ex_native.h>
73 #include <ATen/ops/linalg_inv_native.h>
74 #include <ATen/ops/linalg_ldl_factor_ex.h>
75 #include <ATen/ops/linalg_ldl_factor_ex_meta.h>
76 #include <ATen/ops/linalg_ldl_factor_ex_native.h>
77 #include <ATen/ops/linalg_ldl_factor_native.h>
78 #include <ATen/ops/linalg_ldl_solve_meta.h>
79 #include <ATen/ops/linalg_ldl_solve_native.h>
80 #include <ATen/ops/linalg_lstsq.h>
81 #include <ATen/ops/linalg_lstsq_native.h>
82 #include <ATen/ops/linalg_lu_factor_ex.h>
83 #include <ATen/ops/linalg_lu_factor_ex_meta.h>
84 #include <ATen/ops/linalg_lu_factor_ex_native.h>
85 #include <ATen/ops/linalg_lu_factor_native.h>
86 #include <ATen/ops/linalg_lu_meta.h>
87 #include <ATen/ops/linalg_lu_native.h>
88 #include <ATen/ops/linalg_lu_solve.h>
89 #include <ATen/ops/linalg_lu_solve_meta.h>
90 #include <ATen/ops/linalg_lu_solve_native.h>
91 #include <ATen/ops/linalg_qr.h>
92 #include <ATen/ops/linalg_qr_meta.h>
93 #include <ATen/ops/linalg_qr_native.h>
94 #include <ATen/ops/linalg_solve_ex.h>
95 #include <ATen/ops/linalg_solve_ex_native.h>
96 #include <ATen/ops/linalg_solve_native.h>
97 #include <ATen/ops/linalg_solve_triangular_native.h>
98 #include <ATen/ops/linalg_svd.h>
99 #include <ATen/ops/linalg_svd_native.h>
100 #include <ATen/ops/linalg_svdvals.h>
101 #include <ATen/ops/linalg_svdvals_native.h>
102 #include <ATen/ops/linalg_vander_native.h>
103 #include <ATen/ops/linalg_vecdot_native.h>
104 #include <ATen/ops/lu_solve_native.h>
105 #include <ATen/ops/lu_unpack.h>
106 #include <ATen/ops/lu_unpack_meta.h>
107 #include <ATen/ops/lu_unpack_native.h>
108 #include <ATen/ops/orgqr_native.h>
109 #include <ATen/ops/ormqr_native.h>
110 #include <ATen/ops/qr_native.h>
111 #include <ATen/ops/real.h>
112 #include <ATen/ops/resize_as_native.h>
113 #include <ATen/ops/sum.h>
114 #include <ATen/ops/svd_native.h>
115 #include <ATen/ops/triangular_solve_meta.h>
116 #include <ATen/ops/triangular_solve_native.h>
117 #include <ATen/ops/tril.h>
118 #include <ATen/ops/triu.h>
119 #include <ATen/ops/vdot.h>
120 #include <ATen/ops/zeros.h>
121 #endif
122
123 // First, the required LAPACK implementations are registered here.
124 // A comment above each registered LAPACK routine indicates which batched
125 // linear algebra function uses that routine.
126 #if AT_BUILD_WITH_LAPACK()
127
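// Note: these declarations target the Fortran LAPACK interface: every argument
// (including scalars such as m, n and lda) is passed by pointer, matrices are
// expected in column-major order, and the trailing underscore follows the usual
// Fortran-to-C name-mangling convention.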
128 // getrf
129 extern "C" void zgetrf_(int *m, int *n, std::complex<double> *a, int *lda, int *ipiv, int *info);
130 extern "C" void cgetrf_(int *m, int *n, std::complex<float> *a, int *lda, int *ipiv, int *info);
131 extern "C" void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv, int *info);
132 extern "C" void sgetrf_(int *m, int *n, float *a, int *lda, int *ipiv, int *info);
133
134 // potrs
135 extern "C" void zpotrs_(char *uplo, int *n, int *nrhs, std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb, int *info);
136 extern "C" void cpotrs_(char *uplo, int *n, int *nrhs, std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb, int *info);
137 extern "C" void dpotrs_(char *uplo, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info);
138 extern "C" void spotrs_(char *uplo, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info);
139
140 // potrf
141 extern "C" void zpotrf_(char *uplo, int *n, std::complex<double> *a, int *lda, int *info);
142 extern "C" void cpotrf_(char *uplo, int *n, std::complex<float> *a, int *lda, int *info);
143 extern "C" void dpotrf_(char *uplo, int *n, double *a, int *lda, int *info);
144 extern "C" void spotrf_(char *uplo, int *n, float *a, int *lda, int *info);
145
146 // potri
147 extern "C" void zpotri_(char *uplo, int *n, std::complex<double> *a, int *lda, int *info);
148 extern "C" void cpotri_(char *uplo, int *n, std::complex<float> *a, int *lda, int *info);
149 extern "C" void dpotri_(char *uplo, int *n, double *a, int *lda, int *info);
150 extern "C" void spotri_(char *uplo, int *n, float *a, int *lda, int *info);
151
152 // sytrf
153 extern "C" void dsytrf_(
154 char* uplo,
155 int* n,
156 double* a,
157 int* lda,
158 int* ipiv,
159 double* work,
160 int* lwork,
161 int* info);
162 extern "C" void ssytrf_(
163 char* uplo,
164 int* n,
165 float* a,
166 int* lda,
167 int* ipiv,
168 float* work,
169 int* lwork,
170 int* info);
171 extern "C" void zsytrf_(
172 char* uplo,
173 int* n,
174 std::complex<double>* a,
175 int* lda,
176 int* ipiv,
177 std::complex<double>* work,
178 int* lwork,
179 int* info);
180 extern "C" void csytrf_(
181 char* uplo,
182 int* n,
183 std::complex<float>* a,
184 int* lda,
185 int* ipiv,
186 std::complex<float>* work,
187 int* lwork,
188 int* info);
189
190 // hetrf
191 extern "C" void zhetrf_(
192 char* uplo,
193 int* n,
194 std::complex<double>* a,
195 int* lda,
196 int* ipiv,
197 std::complex<double>* work,
198 int* lwork,
199 int* info);
200 extern "C" void chetrf_(
201 char* uplo,
202 int* n,
203 std::complex<float>* a,
204 int* lda,
205 int* ipiv,
206 std::complex<float>* work,
207 int* lwork,
208 int* info);
209
210 // sytrs
211 extern "C" void dsytrs_(
212 char* uplo,
213 int* n,
214 int* nrhs,
215 double* a,
216 int* lda,
217 int* ipiv,
218 double* b,
219 int* ldb,
220 int* info);
221 extern "C" void ssytrs_(
222 char* uplo,
223 int* n,
224 int* nrhs,
225 float* a,
226 int* lda,
227 int* ipiv,
228 float* b,
229 int* ldb,
230 int* info);
231 extern "C" void zsytrs_(
232 char* uplo,
233 int* n,
234 int* nrhs,
235 std::complex<double>* a,
236 int* lda,
237 int* ipiv,
238 std::complex<double>* b,
239 int* ldb,
240 int* info);
241 extern "C" void csytrs_(
242 char* uplo,
243 int* n,
244 int* nrhs,
245 std::complex<float>* a,
246 int* lda,
247 int* ipiv,
248 std::complex<float>* b,
249 int* ldb,
250 int* info);
251
252 // hetrs
253 extern "C" void zhetrs_(
254 char* uplo,
255 int* n,
256 int* nrhs,
257 std::complex<double>* a,
258 int* lda,
259 int* ipiv,
260 std::complex<double>* b,
261 int* ldb,
262 int* info);
263 extern "C" void chetrs_(
264 char* uplo,
265 int* n,
266 int* nrhs,
267 std::complex<float>* a,
268 int* lda,
269 int* ipiv,
270 std::complex<float>* b,
271 int* ldb,
272 int* info);
273
274 // geqrf
275 extern "C" void zgeqrf_(int *m, int *n, std::complex<double> *a, int *lda, std::complex<double> *tau, std::complex<double> *work, int *lwork, int *info);
276 extern "C" void cgeqrf_(int *m, int *n, std::complex<float> *a, int *lda, std::complex<float> *tau, std::complex<float> *work, int *lwork, int *info);
277 extern "C" void dgeqrf_(int *m, int *n, double *a, int *lda, double *tau, double *work, int *lwork, int *info);
278 extern "C" void sgeqrf_(int *m, int *n, float *a, int *lda, float *tau, float *work, int *lwork, int *info);
279
280 // orgqr
281 extern "C" void zungqr_(int *m, int *n, int *k, std::complex<double> *a, int *lda, std::complex<double> *tau, std::complex<double> *work, int *lwork, int *info);
282 extern "C" void cungqr_(int *m, int *n, int *k, std::complex<float> *a, int *lda, std::complex<float> *tau, std::complex<float> *work, int *lwork, int *info);
283 extern "C" void dorgqr_(int *m, int *n, int *k, double *a, int *lda, double *tau, double *work, int *lwork, int *info);
284 extern "C" void sorgqr_(int *m, int *n, int *k, float *a, int *lda, float *tau, float *work, int *lwork, int *info);
285
286 // ormqr
287 extern "C" void zunmqr_(char *side, char *trans, int *m, int *n, int *k, std::complex<double> *a, int *lda, std::complex<double> *tau, std::complex<double> *c, int *ldc, std::complex<double> *work, int *lwork, int *info);
288 extern "C" void cunmqr_(char *side, char *trans, int *m, int *n, int *k, std::complex<float> *a, int *lda, std::complex<float> *tau, std::complex<float> *c, int *ldc, std::complex<float> *work, int *lwork, int *info);
289 extern "C" void dormqr_(char *side, char *trans, int *m, int *n, int *k, double *a, int *lda, double *tau, double *c, int *ldc, double *work, int *lwork, int *info);
290 extern "C" void sormqr_(char *side, char *trans, int *m, int *n, int *k, float *a, int *lda, float *tau, float *c, int *ldc, float *work, int *lwork, int *info);
291
292 // syevd
293 extern "C" void zheevd_(char *jobz, char *uplo, int *n, std::complex<double> *a, int *lda, double *w, std::complex<double> *work, int *lwork, double *rwork, int *lrwork, int *iwork, int *liwork, int *info);
294 extern "C" void cheevd_(char *jobz, char *uplo, int *n, std::complex<float> *a, int *lda, float *w, std::complex<float> *work, int *lwork, float *rwork, int *lrwork, int *iwork, int *liwork, int *info);
295 extern "C" void dsyevd_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *iwork, int *liwork, int *info);
296 extern "C" void ssyevd_(char *jobz, char *uplo, int *n, float *a, int *lda, float *w, float *work, int *lwork, int *iwork, int *liwork, int *info);
297
298 // geev
299 extern "C" void dgeev_(char *jobvl, char *jobvr, int *n, double *a, int *lda, double *wr, double *wi, double* vl, int *ldvl, double *vr, int *ldvr, double *work, int *lwork, int *info);
300 extern "C" void sgeev_(char *jobvl, char *jobvr, int *n, float *a, int *lda, float *wr, float *wi, float* vl, int *ldvl, float *vr, int *ldvr, float *work, int *lwork, int *info);
301 extern "C" void cgeev_(char *jobvl, char *jobvr, int *n,
302 std::complex<float> *a, int *lda,
303 std::complex<float> *w,
304 std::complex<float> *vl, int *ldvl,
305 std::complex<float> *vr, int *ldvr,
306 std::complex<float> *work, int *lwork,
307 float *rwork,
308 int *info);
309 extern "C" void zgeev_(char *jobvl, char *jobvr, int *n,
310 std::complex<double> *a, int *lda,
311 std::complex<double> *w,
312 std::complex<double> *vl, int *ldvl,
313 std::complex<double> *vr, int *ldvr,
314 std::complex<double> *work, int *lwork,
315 double *rwork,
316 int *info);
317
318 // gesdd
319 extern "C" void zgesdd_(char *jobz, int *m, int *n, std::complex<double> *a, int *lda,
320 double *s, std::complex<double> *u, int *ldu, std::complex<double> *vt, int *ldvt, std::complex<double> *work, int *lwork, double *rwork, int *iwork, int *info);
321 extern "C" void cgesdd_(char *jobz, int *m, int *n, std::complex<float> *a, int *lda,
322 float *s, std::complex<float> *u, int *ldu, std::complex<float> *vt, int *ldvt, std::complex<float> *work, int *lwork, float *rwork, int *iwork, int *info);
323 extern "C" void dgesdd_(char *jobz, int *m, int *n, double *a, int *lda,
324 double *s, double *u, int *ldu, double *vt, int *ldvt, double *work, int *lwork, int *iwork, int *info);
325 extern "C" void sgesdd_(char *jobz, int *m, int *n, float *a, int *lda,
326 float *s, float *u, int *ldu, float *vt, int *ldvt, float *work, int *lwork, int *iwork, int *info);
327
328 // getrs
329 extern "C" void zgetrs_(char *trans, int *n, int *nrhs, std::complex<double> *a, int *lda, int *ipiv, std::complex<double> *b, int *ldb, int *info);
330 extern "C" void cgetrs_(char *trans, int *n, int *nrhs, std::complex<float> *a, int *lda, int *ipiv, std::complex<float> *b, int *ldb, int *info);
331 extern "C" void dgetrs_(char *trans, int *n, int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb, int *info);
332 extern "C" void sgetrs_(char *trans, int *n, int *nrhs, float *a, int *lda, int *ipiv, float *b, int *ldb, int *info);
333
334 // gels
335 extern "C" void zgels_(char *trans, int *m, int *n, int *nrhs,
336 std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb,
337 std::complex<double> *work, int *lwork, int *info);
338 extern "C" void cgels_(char *trans, int *m, int *n, int *nrhs,
339 std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb,
340 std::complex<float> *work, int *lwork, int *info);
341 extern "C" void dgels_(char *trans, int *m, int *n, int *nrhs,
342 double *a, int *lda, double *b, int *ldb,
343 double *work, int *lwork, int *info);
344 extern "C" void sgels_(char *trans, int *m, int *n, int *nrhs,
345 float *a, int *lda, float *b, int *ldb,
346 float *work, int *lwork, int *info);
347
348 // gelsd
349 extern "C" void zgelsd_(int *m, int *n, int *nrhs,
350 std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb,
351 double *s, double *rcond, int *rank,
352 std::complex<double> *work, int *lwork, double *rwork, int *iwork, int *info);
353 extern "C" void cgelsd_(int *m, int *n, int *nrhs,
354 std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb,
355 float *s, float *rcond, int *rank,
356 std::complex<float> *work, int *lwork, float *rwork, int *iwork, int *info);
357 extern "C" void dgelsd_(int *m, int *n, int *nrhs,
358 double *a, int *lda, double *b, int *ldb,
359 double *s, double *rcond, int *rank,
360 double *work, int *lwork, int *iwork, int *info);
361 extern "C" void sgelsd_(int *m, int *n, int *nrhs,
362 float *a, int *lda, float *b, int *ldb,
363 float *s, float *rcond, int *rank,
364 float *work, int *lwork, int *iwork, int *info);
365
366 // gelsy
367 extern "C" void zgelsy_(int *m, int *n, int *nrhs,
368 std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb,
369 int *jpvt, double *rcond, int *rank,
370 std::complex<double> *work, int *lwork,
371 double *rwork, int *info);
372 extern "C" void cgelsy_(int *m, int *n, int *nrhs,
373 std::complex<float> * a, int *lda, std::complex<float> *b, int *ldb,
374 int *jpvt, float *rcond, int *rank,
375 std::complex<float> *work, int *lwork,
376 float *rwork, int *info);
377 extern "C" void dgelsy_(int *m, int *n, int *nrhs,
378 double *a, int *lda, double *b, int *ldb,
379 int *jpvt, double *rcond, int *rank,
380 double *work, int *lwork, int *info);
381 extern "C" void sgelsy_(int *m, int *n, int *nrhs,
382 float *a, int *lda, float *b, int *ldb,
383 int *jpvt, float *rcond, int *rank,
384 float *work, int *lwork, int *info);
385
386 // gelss
387 extern "C" void zgelss_(int *m, int *n, int *nrhs,
388 std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb,
389 double *s, double *rcond, int *rank,
390 std::complex<double> *work, int *lwork,
391 double *rwork, int *info);
392 extern "C" void cgelss_(int *m, int *n, int *nrhs,
393 std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb,
394 float *s, float *rcond, int *rank,
395 std::complex<float> *work, int *lwork,
396 float *rwork, int *info);
397 extern "C" void dgelss_(int *m, int *n, int *nrhs,
398 double *a, int *lda, double *b, int *ldb,
399 double *s, double *rcond, int *rank,
400 double *work, int *lwork, int *info);
401 extern "C" void sgelss_(int *m, int *n, int *nrhs,
402 float *a, int *lda, float *b, int *ldb,
403 float *s, float *rcond, int *rank,
404 float *work, int *lwork, int *info);
405 #endif
406
407 #if AT_BUILD_WITH_BLAS()
408 // trsm
409 extern "C" void ztrsm_(char *side, char *uplo, char *trans, char *diag, int *n, int *nrhs, std::complex<double> *alpha, std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb);
410 extern "C" void ctrsm_(char *side, char *uplo, char *trans, char *diag, int *n, int *nrhs, std::complex<float> *alpha, std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb);
411 extern "C" void dtrsm_(char *side, char *uplo, char *trans, char *diag, int *n, int *nrhs, double *alpha, double *a, int *lda, double *b, int *ldb);
412 extern "C" void strsm_(char *side, char *uplo, char *trans, char *diag, int *n, int *nrhs, float *alpha, float *a, int *lda, float *b, int *ldb);
413 #endif
414
415 namespace at::meta {
416
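// The TORCH_META_FUNC blocks below are the "meta" half of the structured kernels:
// they validate dtypes and shapes and declare the sizes/strides of every output
// via set_output_* calls, while the numerical work is performed by the matching
// TORCH_IMPL_FUNC implementations.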
417 TORCH_META_FUNC(linalg_ldl_factor_ex)
418 (const Tensor& self, bool hermitian, bool check_errors) {
419 at::native::squareCheckInputs(self, "torch.linalg.ldl_factor_ex");
420 at::native::checkFloatingOrComplex(self, "torch.linalg.ldl_factor_ex");
421
422 auto shape = self.sizes();
423 auto ndim = shape.size();
424
425 // prefer column major strides
426 auto ld_strides = at::native::batched_matrix_contiguous_strides(shape, /*f-contig=*/true);
427 set_output_strided(0, shape, ld_strides, self.options(), {}); // LD
428
429 set_output_contiguous(
430 1, shape.slice(0, ndim - 1), self.options().dtype(ScalarType::Int)); // pivots
431
432 set_output_contiguous(
433 2, shape.slice(0, ndim - 2), self.options().dtype(ScalarType::Int)); // info
434 }
435
436 TORCH_META_FUNC(linalg_ldl_solve)
437 (const Tensor& LD,
438 const Tensor& pivots,
439 const Tensor& B,
440 bool hermitian) {
441 at::native::squareCheckInputs(LD, "torch.linalg.ldl_solve");
442 at::native::checkFloatingOrComplex(LD, "torch.linalg.ldl_solve");
443 at::native::linearSolveCheckInputs(B, LD, "torch.linalg.ldl_solve");
444 TORCH_CHECK(
445 B.dim() >= 2,
446 "torch.linalg.ldl_solve: Expected B to have at least 2 dimensions, but it has ",
447 B.dim(),
448 " dimensions instead");
449 auto expected_pivots_shape = LD.sizes().slice(0, LD.dim() - 1);
450 TORCH_CHECK(
451 expected_pivots_shape.equals(pivots.sizes()),
452 "torch.linalg.ldl_solve: Expected LD.shape[:-1] and pivots.shape to be the same, but got pivots with shape ",
453 pivots.sizes(),
454 " instead");
455 // pivots is allowed to be any integer type
456 // The LAPACK interface we use takes 32-bit integers, while cuSOLVER uses a 64-bit integer interface
457 TORCH_CHECK(
458 at::isIntegralType(pivots.scalar_type(), /*includeBool=*/false),
459 "torch.linalg.ldl_solve: Expected pivots to be integers. Got ",
460 pivots.scalar_type());
461 TORCH_CHECK(
462 LD.scalar_type() == B.scalar_type(),
463 "torch.linalg.ldl_solve: ",
464 "LD dtype ",
465 LD.scalar_type(),
466 " does not match B dtype ",
467 B.scalar_type());
468
469 auto [B_broadcast_size, _] = at::native::_linalg_broadcast_batch_dims(B, LD);
470
471 // prefer column major strides
472 auto result_strides = at::native::batched_matrix_contiguous_strides(B_broadcast_size, /*f_contig=*/true);
473 set_output_strided(0, B_broadcast_size, result_strides, B.options(), {});
474 }
475
476 TORCH_META_FUNC(triangular_solve)(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular) {
477 TORCH_CHECK(self.dim() >= 2,
478 "torch.triangular_solve: Expected b to have at least 2 dimensions, but it has ", self.dim(), " dimensions instead");
479 TORCH_CHECK(A.dim() >= 2,
480 "torch.triangular_solve: Expected A to have at least 2 dimensions, but it has ", A.dim(), " dimensions instead");
481
482 at::native::linearSolveCheckInputs(self, A, "triangular_solve");
483
484 if (A.layout() == Layout::Strided) {
485 auto [self_broadcast_size, A_broadcast_size] = at::native::_linalg_broadcast_batch_dims(self, A);
486
487 // make column major strides for BLAS
488 const auto solution_strides = at::native::batched_matrix_contiguous_strides(self_broadcast_size, /*f-contig=*/true);
489 set_output_raw_strided(0, self_broadcast_size, solution_strides, self.options(), {});
490
491 // make column major strides for BLAS
492 auto clone_A_strides = at::native::batched_matrix_contiguous_strides(A_broadcast_size, /*f_contig=*/true);
493 set_output_raw_strided(1, A_broadcast_size, clone_A_strides, A.options(), {});
494 } else if (A.layout() == Layout::SparseCsr || A.layout() == Layout::SparseBsr) {
495 // no broadcasting for non-strided layout
496 set_output_raw_strided(0, self.sizes(), {}, self.options(), {}); // make row major strides for Sparse BLAS
497 set_output_raw_strided(1, {0}, {}, self.options(), {}); // return 0-sized tensor
498 } else {
499 TORCH_INTERNAL_ASSERT(false, "triangular_solve: Got an unexpected layout.");
500 }
501 }
502
503 TORCH_META_FUNC(_linalg_solve_ex)(const Tensor& A,
504 const Tensor& B,
505 bool left,
506 bool check_errors) {
507 // dtype
508 at::native::checkFloatingOrComplex(A, "linalg.solve");
509 TORCH_CHECK(A.scalar_type() == B.scalar_type(),
510 "linalg.solve: Expected A and B to have the same dtype, but found A of type ",
511 A.scalar_type(), " and B of type ", B.scalar_type(), " instead");
512
513 // NumPy compat: Two types of 'B' tensors are supported:
514 // - 1D tensor or batch of 1D tensors (vector case)
515 // - 2D tensor or batch of 2D tensors (matrix case)
516 const bool vector_case = at::native::linalg_solve_is_vector_rhs(A, B);
517 auto B_ = vector_case ? B.unsqueeze(-1) : B;
518
519 // matrix shapes
520 at::native::checkInputsSolver(A, B_, /*left=*/left, "linalg.solve");
521
522 // Check that B can be broadcasted to the shape of A
523 auto B_broad_shape = std::get<0>(at::native::_linalg_broadcast_batch_dims(B_, A));
524 // We disallow the broadcasting of B as a vector when left=False as, in that case, A.shape = (*, 1, 1)
525 TORCH_CHECK(left || !vector_case, "linalg.solve: Vector broadcasting of the left hand side is not supported for left=False. In this case linalg.solve is equivalent to B / A.squeeze(-1)");
526 auto result_shape = vector_case ? IntArrayRef(B_broad_shape.data(), B_broad_shape.size() - 1)
527 : B_broad_shape;
528 auto result_strides = at::native::batched_matrix_contiguous_strides(result_shape, /*f_contig=*/left);
529
530 set_output_strided(0, result_shape, result_strides, B.options(), {});
531
532 auto shape = A.sizes();
533 auto ndim = shape.size();
534
535 // LU
536 auto LU_strides = at::native::batched_matrix_contiguous_strides(shape, /*f-contig=*/true);
537 set_output_strided(1, shape, LU_strides, A.options(), {});
538
539 // pivots
540 set_output_contiguous(2, shape.slice(0, ndim - 1), A.options().dtype(kInt));
541
542 // info
543 set_output_contiguous(3, shape.slice(0, ndim - 2), A.options().dtype(kInt));
544 }
545
546 TORCH_META_FUNC(linalg_inv_ex)(const Tensor& A, bool check_errors) {
547 at::native::squareCheckInputs(A, "linalg.inv");
548 at::native::checkFloatingOrComplex(A, "linalg.inv", /*allow_low_precision_dtypes*/false);
549
550 auto shape = A.sizes();
551
552 auto result_strides = at::native::batched_matrix_contiguous_strides(shape, /*f-contig=*/true);
553 set_output_strided(0, shape, result_strides, A.options(), {});
554 set_output_contiguous(
555 1, shape.slice(0, shape.size() - 2), A.options().dtype(ScalarType::Int)); // info
556 }
557
558 TORCH_META_FUNC(linalg_lu_factor_ex)(const Tensor& A, bool pivot, bool check_errors) {
559 TORCH_CHECK(A.dim() >= 2, "torch.linalg.lu_factor: Expected tensor with 2 or more dimensions. Got size: ", A.sizes(), " instead");
560
561 auto sizes = A.sizes().vec();
562 const auto m = sizes.cend()[-2];
563 const auto n = sizes.cend()[-1];
564
565 // make column major strides for BLAS
566 auto LU_strides = at::native::batched_matrix_contiguous_strides(sizes, /*f-contig=*/true);
567 set_output_strided(0, sizes, LU_strides, A.options(), {});
568
569 // Set sizes to the size of pivots
570 sizes.pop_back();
571 sizes.back() = std::min(m, n);
572 set_output_contiguous(1, sizes, A.options().dtype(kInt), {});
573
574 // Set sizes to the size of info
575 sizes.pop_back();
576 set_output_contiguous(2, sizes, A.options().dtype(kInt), {});
577 }
578
579 TORCH_META_FUNC(linalg_lu_solve)(const Tensor& LU,
580 const Tensor& pivots,
581 const Tensor& B,
582 bool left,
583 bool adjoint) {
584 // dtype
585 at::native::checkFloatingOrComplex(LU, "torch.linalg.lu_solve");
586 TORCH_CHECK(LU.scalar_type() == B.scalar_type(),
587 "linalg.lu_solve: Expected LU and B to have the same dtype, but found LU of type ",
588 LU.scalar_type(), " and B of type ", B.scalar_type(), " instead");
589 TORCH_CHECK(pivots.dtype() == at::kInt,
590 "linalg.lu_solve: pivots should be a Tensor of scalar type torch.int32");
591
592 // matrix shapes
593 at::native::squareCheckInputs(LU, "torch.linalg.lu_solve");
594 at::native::checkInputsSolver(LU, B, left, "linalg.lu_solve");
596 TORCH_CHECK(LU.size(-1) == pivots.size(-1),
597 "linalg.lu_solve: Number of pivots per batch should be the same as the dimension of the matrix");
598
599 // batches
600 TORCH_CHECK(
601 LU.sizes().slice(0, LU.dim() - 1).equals(pivots.sizes()),
602 "linalg.lu_solve: Expected LU.shape[:-1] and pivots.shape to be the same, but got pivots with shape ",
603 pivots.sizes(), " instead");
604
605 // This one checks that B can be broadcasted to the shape of A
606 auto B_broadcast_size = std::get<0>(at::native::_linalg_broadcast_batch_dims(B, LU));
607 auto result_strides = at::native::batched_matrix_contiguous_strides(B_broadcast_size, /*f_contig=*/left);
608
609 set_output_strided(0, B_broadcast_size, result_strides, B.options(), {});
610 }
611
612 TORCH_META_FUNC(linalg_cholesky_ex)(const Tensor& A,
613 bool upper,
614 bool check_errors) {
615 at::native::squareCheckInputs(A, "linalg.cholesky");
616 at::native::checkFloatingOrComplex(A, "linalg.cholesky");
617
618 auto A_shape = A.sizes();
619 auto ndim = A_shape.size();
620
621 // L
622 auto L_strides = at::native::batched_matrix_contiguous_strides(A_shape, /*f-contig=*/true);
623 set_output_strided(0, A_shape, L_strides, A.options(), {});
624
625 // info
626 set_output_contiguous(1, A_shape.slice(0, ndim - 2), A.options().dtype(ScalarType::Int));
627 }
628
629 TORCH_META_FUNC(linalg_qr)(const Tensor& A,
630 c10::string_view mode) {
631 at::native::checkIsMatrix(A, "linalg.qr");
632 at::native::checkFloatingOrComplex(A, "linalg.qr");
633 auto [compute_q, reduced_mode] = at::native::_parse_qr_mode(mode);
634
635 auto A_shape = A.sizes().vec();
636 const auto m = A_shape.cend()[-2];
637 const auto n = A_shape.cend()[-1];
638 const auto k = std::min(m, n);
639
640 if (compute_q) {
641 auto Q_shape = A_shape;
642 Q_shape.end()[-1] = reduced_mode ? k : m;
643 auto Q_strides = at::native::batched_matrix_contiguous_strides(Q_shape, /*f-contig=*/true);
644 set_output_strided(0, Q_shape, Q_strides, A.options(), {});
645 } else {
646 set_output_raw_strided(0, {0}, {}, A.options(), {});
647 }
648
649 // For readability
650 auto R_shape = std::move(A_shape);
651 R_shape.end()[-2] = (reduced_mode || !compute_q) ? k : m;
652 auto R_strides = at::native::batched_matrix_contiguous_strides(R_shape, /*f-contig=*/true);
653 set_output_strided(1, R_shape, R_strides, A.options(), {});
654 }
655
656
657 TORCH_META_FUNC(_linalg_svd)(const Tensor& A,
658 bool full_matrices,
659 bool compute_uv,
660 std::optional<c10::string_view> driver) {
661 at::native::checkIsMatrix(A, "linalg.svd");
662 at::native::checkFloatingOrComplex(A, "linalg.svd");
663
664 auto sizes = A.sizes().vec();
665 const auto m = sizes.cend()[-2];
666 const auto n = sizes.cend()[-1];
667 const auto k = std::min(m, n);
668
669 // Prepare sizes for U
670 if (compute_uv) {
671 sizes.back() = full_matrices ? m : k;
672 auto U_strides = at::native::batched_matrix_contiguous_strides(sizes, /*f-contig=*/true);
673 set_output_strided(0, sizes, U_strides, A.options(), {});
674
675 // Prepare sizes for Vh
676 sizes.end()[-2] = full_matrices ? n : k;
677 sizes.end()[-1] = n;
678
679 // We need to distinguish the cuSOLVER case, as the cuSOLVER algorithms we use
680 // expect F-contig matrices, but they compute V rather than Vh
681 const bool use_cusolver = at::native::svd_uses_cusolver(A);
682 auto Vh_strides = at::native::batched_matrix_contiguous_strides(sizes, /*f-contig=*/!use_cusolver);
683 set_output_strided(2, sizes, Vh_strides, A.options(), {});
684 } else {
685 set_output_raw_strided(0, {0}, {}, A.options(), {});
686 set_output_raw_strided(2, {0}, {}, A.options(), {});
687 }
688
689 // Prepare sizes for S. S is always real, even when A is complex.
690 sizes.pop_back();
691 sizes.end()[-1] = k;
692 set_output_contiguous(1, sizes, A.options().dtype(c10::toRealValueType(A.scalar_type())), {});
693 }
694
695 TORCH_META_FUNC(lu_unpack)(const Tensor& LU, const Tensor& pivots, bool unpack_data, bool unpack_pivots) {
696 TORCH_CHECK(LU.dim() >= 2, "torch.lu_unpack: Expected tensor with 2 or more dimensions. Got size: ", LU.sizes(), " instead");
697 if (unpack_pivots) {
698 TORCH_CHECK(pivots.scalar_type() == at::kInt,
699 "torch.lu_unpack: LU_pivots is expected to be a contiguous tensor of torch.int32 dtype.\n"
700 "Note: this function is intended to be used with the output produced by torch.linalg.lu_factor");
701 }
702
703 auto sizes = LU.sizes().vec();
704 const auto m = sizes.cend()[-2];
705 const auto n = sizes.cend()[-1];
706 const auto k = std::min(m, n);
707
708 // P.shape[-2:] == (m, m) (or size zero if unpack_pivots == False)
709 sizes.end()[-1] = m;
710 if (unpack_pivots) {
711 set_output_raw_strided(0, sizes, {}, LU.options(), {});
712 } else {
713 set_output_raw_strided(0, {0}, {}, LU.options(), {});
714 }
715
716 if (unpack_data) {
717 // L.shape[-2:] == (m, k)
718 sizes.end()[-1] = k;
719 set_output_raw_strided(1, sizes, {}, LU.options(), {});
720
721 // U.shape[-2:] == (k, n)
722 sizes.end()[-2] = k;
723 sizes.end()[-1] = n;
724 set_output_raw_strided(2, sizes, {}, LU.options(), {});
725 } else {
726 set_output_raw_strided(1, {0}, {}, LU.options(), {});
727 set_output_raw_strided(2, {0}, {}, LU.options(), {});
728 }
729 }
730
731 TORCH_META_FUNC(_linalg_eigh)(const Tensor& A,
732 c10::string_view uplo,
733 bool compute_v) {
734 at::native::squareCheckInputs(A, "linalg.eigh");
735 at::native::checkUplo(uplo);
736
737 auto shape = A.sizes().vec();
738 if (compute_v) {
739 // eigenvectors
740 auto V_strides = at::native::batched_matrix_contiguous_strides(shape, /*f-contig=*/true);
741 set_output_strided(1, shape, V_strides, A.options(), {});
742 } else {
743 set_output_raw_strided(1, {0}, {}, A.options(), {});
744 }
745
746 // eigenvalues
747 shape.pop_back();
748 set_output_contiguous(0, shape, A.options().dtype(c10::toRealValueType(A.scalar_type())), {});
749 }
750
751 TORCH_META_FUNC(linalg_lu)(const Tensor& A, bool pivot) {
752 TORCH_CHECK(A.dim() >= 2, "linalg.lu: Expected tensor with 2 or more dimensions. Got size: ", A.sizes(), " instead");
753
754 auto sizes = A.sizes().vec();
755 const auto m = sizes.cend()[-2];
756 const auto n = sizes.cend()[-1];
757 const auto k = std::min(m, n);
758
759 // P.shape[-2:] == (m, m) (or size zero if pivot == False)
760 sizes.end()[-1] = m;
761 if (pivot) {
762 set_output_raw_strided(0, sizes, {}, A.options(), {});
763 } else {
764 set_output_raw_strided(0, {0}, {}, A.options(), {});
765 }
766
767 // L.shape[-2:] == (m, k)
768 sizes.end()[-1] = k;
769 set_output_raw_strided(1, sizes, {}, A.options(), {});
770
771 // U.shape[-2:] == (k, n)
772 sizes.end()[-2] = k;
773 sizes.end()[-1] = n;
774 set_output_raw_strided(2, sizes, {}, A.options(), {});
775 }
776
777 } // namespace at::meta
778
779 namespace at::native {
780
781 #if AT_BUILD_WITH_LAPACK()
782 // Define the per-batch functions to be used in the main implementation of the batched
783 // linear algebra operations
784
785 template<class scalar_t>
786 void lapackCholeskySolve(char uplo, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int *info);
787
788 template<class scalar_t, class value_t=scalar_t>
789 void lapackSymeig(char jobz, char uplo, int n, scalar_t *a, int lda, value_t *w, scalar_t *work, int lwork, value_t *rwork, int *info);
790
791 template<> void lapackLu<c10::complex<double>>(int m, int n, c10::complex<double> *a, int lda, int *ipiv, int *info) {
792 zgetrf_(&m, &n, reinterpret_cast<std::complex<double>*>(a), &lda, ipiv, info);
793 }
794
795 template<> void lapackLu<c10::complex<float>>(int m, int n, c10::complex<float> *a, int lda, int *ipiv, int *info) {
796 cgetrf_(&m, &n, reinterpret_cast<std::complex<float>*>(a), &lda, ipiv, info);
797 }
798
799 template<> void lapackLu<double>(int m, int n, double *a, int lda, int *ipiv, int *info) {
800 dgetrf_(&m, &n, a, &lda, ipiv, info);
801 }
802
803 template<> void lapackLu<float>(int m, int n, float *a, int lda, int *ipiv, int *info) {
804 sgetrf_(&m, &n, a, &lda, ipiv, info);
805 }
806
807 template<> void lapackCholeskySolve<c10::complex<double>>(char uplo, int n, int nrhs, c10::complex<double> *a, int lda, c10::complex<double> *b, int ldb, int *info) {
808 zpotrs_(&uplo, &n, &nrhs, reinterpret_cast<std::complex<double>*>(a), &lda, reinterpret_cast<std::complex<double>*>(b), &ldb, info);
809 }
810
811 template<> void lapackCholeskySolve<c10::complex<float>>(char uplo, int n, int nrhs, c10::complex<float> *a, int lda, c10::complex<float> *b, int ldb, int *info) {
812 cpotrs_(&uplo, &n, &nrhs, reinterpret_cast<std::complex<float>*>(a), &lda, reinterpret_cast<std::complex<float>*>(b), &ldb, info);
813 }
814
815 template<> void lapackCholeskySolve<double>(char uplo, int n, int nrhs, double *a, int lda, double *b, int ldb, int *info) {
816 dpotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
817 }
818
819 template<> void lapackCholeskySolve<float>(char uplo, int n, int nrhs, float *a, int lda, float *b, int ldb, int *info) {
820 spotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
821 }
822
823 template<> void lapackCholesky<c10::complex<double>>(char uplo, int n, c10::complex<double> *a, int lda, int *info) {
824 zpotrf_(&uplo, &n, reinterpret_cast<std::complex<double>*>(a), &lda, info);
825 }
826
827 template<> void lapackCholesky<c10::complex<float>>(char uplo, int n, c10::complex<float> *a, int lda, int *info) {
828 cpotrf_(&uplo, &n, reinterpret_cast<std::complex<float>*>(a), &lda, info);
829 }
830
831 template<> void lapackCholesky<double>(char uplo, int n, double *a, int lda, int *info) {
832 dpotrf_(&uplo, &n, a, &lda, info);
833 }
834
835 template<> void lapackCholesky<float>(char uplo, int n, float *a, int lda, int *info) {
836 spotrf_(&uplo, &n, a, &lda, info);
837 }
838
839 template<> void lapackCholeskyInverse<c10::complex<double>>(char uplo, int n, c10::complex<double> *a, int lda, int *info) {
840 zpotri_(&uplo, &n, reinterpret_cast<std::complex<double>*>(a), &lda, info);
841 }
842
843 template<> void lapackCholeskyInverse<c10::complex<float>>(char uplo, int n, c10::complex<float> *a, int lda, int *info) {
844 cpotri_(&uplo, &n, reinterpret_cast<std::complex<float>*>(a), &lda, info);
845 }
846
847 template<> void lapackCholeskyInverse<double>(char uplo, int n, double *a, int lda, int *info) {
848 dpotri_(&uplo, &n, a, &lda, info);
849 }
850
851 template<> void lapackCholeskyInverse<float>(char uplo, int n, float *a, int lda, int *info) {
852 spotri_(&uplo, &n, a, &lda, info);
853 }
854
855 template<> void lapackGeqrf<c10::complex<double>>(int m, int n, c10::complex<double> *a, int lda, c10::complex<double> *tau, c10::complex<double> *work, int lwork, int *info) {
856 zgeqrf_(&m, &n, reinterpret_cast<std::complex<double>*>(a), &lda, reinterpret_cast<std::complex<double>*>(tau), reinterpret_cast<std::complex<double>*>(work), &lwork, info);
857 }
858
859 template<> void lapackGeqrf<c10::complex<float>>(int m, int n, c10::complex<float> *a, int lda, c10::complex<float> *tau, c10::complex<float> *work, int lwork, int *info) {
860 cgeqrf_(&m, &n, reinterpret_cast<std::complex<float>*>(a), &lda, reinterpret_cast<std::complex<float>*>(tau), reinterpret_cast<std::complex<float>*>(work), &lwork, info);
861 }
862
863 template<> void lapackGeqrf<double>(int m, int n, double *a, int lda, double *tau, double *work, int lwork, int *info) {
864 dgeqrf_(&m, &n, a, &lda, tau, work, &lwork, info);
865 }
866
867 template<> void lapackGeqrf<float>(int m, int n, float *a, int lda, float *tau, float *work, int lwork, int *info) {
868 sgeqrf_(&m, &n, a, &lda, tau, work, &lwork, info);
869 }
870
871 template<> void lapackOrgqr<c10::complex<double>>(int m, int n, int k, c10::complex<double> *a, int lda, c10::complex<double> *tau, c10::complex<double> *work, int lwork, int *info) {
872 zungqr_(&m, &n, &k, reinterpret_cast<std::complex<double>*>(a), &lda, reinterpret_cast<std::complex<double>*>(tau), reinterpret_cast<std::complex<double>*>(work), &lwork, info);
873 }
874
875 template<> void lapackOrgqr<c10::complex<float>>(int m, int n, int k, c10::complex<float> *a, int lda, c10::complex<float> *tau, c10::complex<float> *work, int lwork, int *info) {
876 cungqr_(&m, &n, &k, reinterpret_cast<std::complex<float>*>(a), &lda, reinterpret_cast<std::complex<float>*>(tau), reinterpret_cast<std::complex<float>*>(work), &lwork, info);
877 }
878
879 template<> void lapackOrgqr<double>(int m, int n, int k, double *a, int lda, double *tau, double *work, int lwork, int *info) {
880 dorgqr_(&m, &n, &k, a, &lda, tau, work, &lwork, info);
881 }
882
883 template<> void lapackOrgqr<float>(int m, int n, int k, float *a, int lda, float *tau, float *work, int lwork, int *info) {
884 sorgqr_(&m, &n, &k, a, &lda, tau, work, &lwork, info);
885 }
886
887 template<> void lapackOrmqr<c10::complex<double>>(char side, char trans, int m, int n, int k, c10::complex<double> *a, int lda, c10::complex<double> *tau, c10::complex<double> *c, int ldc, c10::complex<double> *work, int lwork, int *info) {
888 zunmqr_(&side, &trans, &m, &n, &k, reinterpret_cast<std::complex<double>*>(a), &lda, reinterpret_cast<std::complex<double>*>(tau), reinterpret_cast<std::complex<double>*>(c), &ldc, reinterpret_cast<std::complex<double>*>(work), &lwork, info);
889 }
890
891 template<> void lapackOrmqr<c10::complex<float>>(char side, char trans, int m, int n, int k, c10::complex<float> *a, int lda, c10::complex<float> *tau, c10::complex<float> *c, int ldc, c10::complex<float> *work, int lwork, int *info) {
892 cunmqr_(&side, &trans, &m, &n, &k, reinterpret_cast<std::complex<float>*>(a), &lda, reinterpret_cast<std::complex<float>*>(tau), reinterpret_cast<std::complex<float>*>(c), &ldc, reinterpret_cast<std::complex<float>*>(work), &lwork, info);
893 }
894
895 template<> void lapackOrmqr<double>(char side, char trans, int m, int n, int k, double *a, int lda, double *tau, double *c, int ldc, double *work, int lwork, int *info) {
896 dormqr_(&side, &trans, &m, &n, &k, a, &lda, tau, c, &ldc, work, &lwork, info);
897 }
898
899 template<> void lapackOrmqr<float>(char side, char trans, int m, int n, int k, float *a, int lda, float *tau, float *c, int ldc, float *work, int lwork, int *info) {
900 sormqr_(&side, &trans, &m, &n, &k, a, &lda, tau, c, &ldc, work, &lwork, info);
901 }
902
903 template<> void lapackSyevd<c10::complex<double>, double>(char jobz, char uplo, int n, c10::complex<double> *a, int lda, double *w, c10::complex<double> *work, int lwork, double *rwork, int lrwork, int *iwork, int liwork, int *info) {
904 zheevd_(&jobz, &uplo, &n, reinterpret_cast<std::complex<double>*>(a), &lda, w, reinterpret_cast<std::complex<double>*>(work), &lwork, rwork, &lrwork, iwork, &liwork, info);
905 }
906
907 template<> void lapackSyevd<c10::complex<float>, float>(char jobz, char uplo, int n, c10::complex<float> *a, int lda, float *w, c10::complex<float> *work, int lwork, float *rwork, int lrwork, int *iwork, int liwork, int *info) {
908 cheevd_(&jobz, &uplo, &n, reinterpret_cast<std::complex<float>*>(a), &lda, w, reinterpret_cast<std::complex<float>*>(work), &lwork, rwork, &lrwork, iwork, &liwork, info);
909 }
910
911 template<> void lapackSyevd<double>(char jobz, char uplo, int n, double *a, int lda, double *w, double *work, int lwork, double *rwork, int lrwork, int *iwork, int liwork, int *info) {
912 (void)rwork; // unused
913 (void)lrwork; // unused
914 dsyevd_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info);
915 }
916
917 template<> void lapackSyevd<float>(char jobz, char uplo, int n, float *a, int lda, float *w, float *work, int lwork, float *rwork, int lrwork, int *iwork, int liwork, int *info) {
918 (void)rwork; // unused
919 (void)lrwork; // unused
920 ssyevd_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info);
921 }
922
923 template<> void lapackEig<double>(char jobvl, char jobvr, int n, double *a, int lda, double *w, double* vl, int ldvl, double *vr, int ldvr, double *work, int lwork, double *rwork, int *info) {
924 // LAPACK [sd]geev wants two separate output arrays: wr and wi for the real
925 // and imaginary parts
926 double *wr = w;
927 double *wi = w ? w + n : nullptr;
928 (void)rwork; // unused
929 dgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info);
930 }
931
932 template<> void lapackEig<float>(char jobvl, char jobvr, int n, float *a, int lda, float *w, float* vl, int ldvl, float *vr, int ldvr, float *work, int lwork, float *rwork, int *info) {
933 // LAPACK [sd]geev wants two separate output arrays: wr and wi for the real
934 // and imaginary parts
935 float *wr = w;
936 float *wi = w ? w + n : nullptr;
937 (void)rwork; // unused
938 sgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info);
939 }
940
941 template<> void lapackEig<c10::complex<double>, double>(char jobvl, char jobvr, int n, c10::complex<double> *a, int lda, c10::complex<double> *w, c10::complex<double> *vl, int ldvl, c10::complex<double> *vr, int ldvr, c10::complex<double> *work, int lwork, double *rwork, int *info) {
942 zgeev_(&jobvl, &jobvr, &n,
943 reinterpret_cast<std::complex<double>*>(a), &lda,
944 reinterpret_cast<std::complex<double>*>(w),
945 reinterpret_cast<std::complex<double>*>(vl), &ldvl,
946 reinterpret_cast<std::complex<double>*>(vr), &ldvr,
947 reinterpret_cast<std::complex<double>*>(work), &lwork,
948 rwork, info);
949 }
950
951 template<> void lapackEig<c10::complex<float>, float>(char jobvl, char jobvr, int n, c10::complex<float> *a, int lda, c10::complex<float> *w, c10::complex<float> *vl, int ldvl, c10::complex<float> *vr, int ldvr, c10::complex<float> *work, int lwork, float *rwork, int *info) {
952 cgeev_(&jobvl, &jobvr, &n,
953 reinterpret_cast<std::complex<float>*>(a), &lda,
954 reinterpret_cast<std::complex<float>*>(w),
955 reinterpret_cast<std::complex<float>*>(vl), &ldvl,
956 reinterpret_cast<std::complex<float>*>(vr), &ldvr,
957 reinterpret_cast<std::complex<float>*>(work), &lwork,
958 rwork, info);
959 }
960
961 template<> void lapackSvd<c10::complex<double>, double>(char jobz, int m, int n, c10::complex<double> *a, int lda,
962 double *s, c10::complex<double> *u, int ldu, c10::complex<double> *vt, int ldvt, c10::complex<double> *work, int lwork, double *rwork, int *iwork, int *info) {
963 zgesdd_(&jobz, &m, &n, reinterpret_cast<std::complex<double>*>(a), &lda, s, reinterpret_cast<std::complex<double>*>(u), &ldu,
964 reinterpret_cast<std::complex<double>*>(vt), &ldvt, reinterpret_cast<std::complex<double>*>(work), &lwork, rwork, iwork, info);
965 }
966
967 template<> void lapackSvd<c10::complex<float>, float>(char jobz, int m, int n, c10::complex<float> *a, int lda,
968 float *s, c10::complex<float> *u, int ldu, c10::complex<float> *vt, int ldvt, c10::complex<float> *work, int lwork, float *rwork, int *iwork, int *info) {
969 cgesdd_(&jobz, &m, &n, reinterpret_cast<std::complex<float>*>(a), &lda, s, reinterpret_cast<std::complex<float>*>(u), &ldu,
970 reinterpret_cast<std::complex<float>*>(vt), &ldvt, reinterpret_cast<std::complex<float>*>(work), &lwork, rwork, iwork, info);
971 }
972
973 template<> void lapackSvd<double>(char jobz, int m, int n, double *a, int lda,
974 double *s, double *u, int ldu, double *vt, int ldvt, double *work, int lwork, double *rwork, int *iwork, int *info) {
975 dgesdd_(&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, info);
976 }
977
978 template<> void lapackSvd<float>(char jobz, int m, int n, float *a, int lda,
979 float *s, float *u, int ldu, float *vt, int ldvt, float *work, int lwork, float *rwork, int *iwork, int *info) {
980 sgesdd_(&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, info);
981 }
982
983 template <>
984 void lapackLdlSymmetric<double>(
985 char uplo,
986 int n,
987 double* a,
988 int lda,
989 int* ipiv,
990 double* work,
991 int lwork,
992 int* info) {
993 dsytrf_(&uplo, &n, a, &lda, ipiv, work, &lwork, info);
994 }
995
996 template <>
997 void lapackLdlSymmetric<float>(
998 char uplo,
999 int n,
1000 float* a,
1001 int lda,
1002 int* ipiv,
1003 float* work,
1004 int lwork,
1005 int* info) {
1006 ssytrf_(&uplo, &n, a, &lda, ipiv, work, &lwork, info);
1007 }
1008
1009 template <>
1010 void lapackLdlSymmetric<c10::complex<double>>(
1011 char uplo,
1012 int n,
1013 c10::complex<double>* a,
1014 int lda,
1015 int* ipiv,
1016 c10::complex<double>* work,
1017 int lwork,
1018 int* info) {
1019 zsytrf_(
1020 &uplo,
1021 &n,
1022 reinterpret_cast<std::complex<double>*>(a),
1023 &lda,
1024 ipiv,
1025 reinterpret_cast<std::complex<double>*>(work),
1026 &lwork,
1027 info);
1028 }
1029
1030 template <>
1031 void lapackLdlSymmetric<c10::complex<float>>(
1032 char uplo,
1033 int n,
1034 c10::complex<float>* a,
1035 int lda,
1036 int* ipiv,
1037 c10::complex<float>* work,
1038 int lwork,
1039 int* info) {
1040 csytrf_(
1041 &uplo,
1042 &n,
1043 reinterpret_cast<std::complex<float>*>(a),
1044 &lda,
1045 ipiv,
1046 reinterpret_cast<std::complex<float>*>(work),
1047 &lwork,
1048 info);
1049 }
1050
1051 template <>
1052 void lapackLdlHermitian<double>(
1053 char uplo,
1054 int n,
1055 double* a,
1056 int lda,
1057 int* ipiv,
1058 double* work,
1059 int lwork,
1060 int* info) {
1061 dsytrf_(&uplo, &n, a, &lda, ipiv, work, &lwork, info);
1062 }
1063
1064 template <>
1065 void lapackLdlHermitian<float>(
1066 char uplo,
1067 int n,
1068 float* a,
1069 int lda,
1070 int* ipiv,
1071 float* work,
1072 int lwork,
1073 int* info) {
1074 ssytrf_(&uplo, &n, a, &lda, ipiv, work, &lwork, info);
1075 }
1076
1077 template <>
1078 void lapackLdlHermitian<c10::complex<double>>(
1079 char uplo,
1080 int n,
1081 c10::complex<double>* a,
1082 int lda,
1083 int* ipiv,
1084 c10::complex<double>* work,
1085 int lwork,
1086 int* info) {
1087 zhetrf_(
1088 &uplo,
1089 &n,
1090 reinterpret_cast<std::complex<double>*>(a),
1091 &lda,
1092 ipiv,
1093 reinterpret_cast<std::complex<double>*>(work),
1094 &lwork,
1095 info);
1096 }
1097
1098 template <>
1099 void lapackLdlHermitian<c10::complex<float>>(
1100 char uplo,
1101 int n,
1102 c10::complex<float>* a,
1103 int lda,
1104 int* ipiv,
1105 c10::complex<float>* work,
1106 int lwork,
1107 int* info) {
1108 chetrf_(
1109 &uplo,
1110 &n,
1111 reinterpret_cast<std::complex<float>*>(a),
1112 &lda,
1113 ipiv,
1114 reinterpret_cast<std::complex<float>*>(work),
1115 &lwork,
1116 info);
1117 }
1118
1119 template <>
1120 void lapackLdlSolveSymmetric<double>(
1121 char uplo,
1122 int n,
1123 int nrhs,
1124 double* a,
1125 int lda,
1126 int* ipiv,
1127 double* b,
1128 int ldb,
1129 int* info) {
1130 dsytrs_(&uplo, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
1131 }
1132
1133 template <>
1134 void lapackLdlSolveSymmetric<float>(
1135 char uplo,
1136 int n,
1137 int nrhs,
1138 float* a,
1139 int lda,
1140 int* ipiv,
1141 float* b,
1142 int ldb,
1143 int* info) {
1144 ssytrs_(&uplo, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
1145 }
1146
1147 template <>
1148 void lapackLdlSolveSymmetric<c10::complex<double>>(
1149 char uplo,
1150 int n,
1151 int nrhs,
1152 c10::complex<double>* a,
1153 int lda,
1154 int* ipiv,
1155 c10::complex<double>* b,
1156 int ldb,
1157 int* info) {
1158 zsytrs_(
1159 &uplo,
1160 &n,
1161 &nrhs,
1162 reinterpret_cast<std::complex<double>*>(a),
1163 &lda,
1164 ipiv,
1165 reinterpret_cast<std::complex<double>*>(b),
1166 &ldb,
1167 info);
1168 }
1169
1170 template <>
1171 void lapackLdlSolveSymmetric<c10::complex<float>>(
1172 char uplo,
1173 int n,
1174 int nrhs,
1175 c10::complex<float>* a,
1176 int lda,
1177 int* ipiv,
1178 c10::complex<float>* b,
1179 int ldb,
1180 int* info) {
1181 csytrs_(
1182 &uplo,
1183 &n,
1184 &nrhs,
1185 reinterpret_cast<std::complex<float>*>(a),
1186 &lda,
1187 ipiv,
1188 reinterpret_cast<std::complex<float>*>(b),
1189 &ldb,
1190 info);
1191 }
1192
1193 template <>
1194 void lapackLdlSolveHermitian<double>(
1195 char uplo,
1196 int n,
1197 int nrhs,
1198 double* a,
1199 int lda,
1200 int* ipiv,
1201 double* b,
1202 int ldb,
1203 int* info) {
1204 dsytrs_(&uplo, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
1205 }
1206
1207 template <>
lapackLdlSolveHermitian(char uplo,int n,int nrhs,float * a,int lda,int * ipiv,float * b,int ldb,int * info)1208 void lapackLdlSolveHermitian<float>(
1209 char uplo,
1210 int n,
1211 int nrhs,
1212 float* a,
1213 int lda,
1214 int* ipiv,
1215 float* b,
1216 int ldb,
1217 int* info) {
1218 ssytrs_(&uplo, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
1219 }
1220
1221 template <>
lapackLdlSolveHermitian(char uplo,int n,int nrhs,c10::complex<double> * a,int lda,int * ipiv,c10::complex<double> * b,int ldb,int * info)1222 void lapackLdlSolveHermitian<c10::complex<double>>(
1223 char uplo,
1224 int n,
1225 int nrhs,
1226 c10::complex<double>* a,
1227 int lda,
1228 int* ipiv,
1229 c10::complex<double>* b,
1230 int ldb,
1231 int* info) {
1232 zhetrs_(
1233 &uplo,
1234 &n,
1235 &nrhs,
1236 reinterpret_cast<std::complex<double>*>(a),
1237 &lda,
1238 ipiv,
1239 reinterpret_cast<std::complex<double>*>(b),
1240 &ldb,
1241 info);
1242 }
1243
1244 template <>
lapackLdlSolveHermitian(char uplo,int n,int nrhs,c10::complex<float> * a,int lda,int * ipiv,c10::complex<float> * b,int ldb,int * info)1245 void lapackLdlSolveHermitian<c10::complex<float>>(
1246 char uplo,
1247 int n,
1248 int nrhs,
1249 c10::complex<float>* a,
1250 int lda,
1251 int* ipiv,
1252 c10::complex<float>* b,
1253 int ldb,
1254 int* info) {
1255 chetrs_(
1256 &uplo,
1257 &n,
1258 &nrhs,
1259 reinterpret_cast<std::complex<float>*>(a),
1260 &lda,
1261 ipiv,
1262 reinterpret_cast<std::complex<float>*>(b),
1263 &ldb,
1264 info);
1265 }
1266
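
// Note: the lapackLuSolve wrappers below call LAPACK ?getrs, which solves
// op(A) X = B given the LU factorization (and pivots) previously computed by
// ?getrf; `trans` selects op(A) as A ('N'), A^T ('T') or A^H ('C').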
template<> void lapackLuSolve<c10::complex<double>>(char trans, int n, int nrhs, c10::complex<double> *a, int lda, int *ipiv, c10::complex<double> *b, int ldb, int *info) {
  zgetrs_(&trans, &n, &nrhs, reinterpret_cast<std::complex<double>*>(a), &lda, ipiv, reinterpret_cast<std::complex<double>*>(b), &ldb, info);
}

template<> void lapackLuSolve<c10::complex<float>>(char trans, int n, int nrhs, c10::complex<float> *a, int lda, int *ipiv, c10::complex<float> *b, int ldb, int *info) {
  cgetrs_(&trans, &n, &nrhs, reinterpret_cast<std::complex<float>*>(a), &lda, ipiv, reinterpret_cast<std::complex<float>*>(b), &ldb, info);
}

template<> void lapackLuSolve<double>(char trans, int n, int nrhs, double *a, int lda, int *ipiv, double *b, int ldb, int *info) {
  dgetrs_(&trans, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
}

template<> void lapackLuSolve<float>(char trans, int n, int nrhs, float *a, int lda, int *ipiv, float *b, int ldb, int *info) {
  sgetrs_(&trans, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
}
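
// Note: the least-squares wrappers below map onto four different LAPACK drivers:
//   ?gels  - assumes full rank, based on a QR or LQ factorization
//   ?gelsd - rank-revealing, SVD-based (divide and conquer)
//   ?gelsy - rank-revealing, complete orthogonal factorization with column pivoting
//   ?gelss - rank-revealing, SVD-based
// torch.linalg.lstsq selects among them through its `driver` argument.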
template<> void lapackGels<c10::complex<double>>(
    char trans, int m, int n, int nrhs,
    c10::complex<double> *a, int lda, c10::complex<double> *b, int ldb,
    c10::complex<double> *work, int lwork, int *info) {
  zgels_(&trans, &m, &n, &nrhs,
      reinterpret_cast<std::complex<double>*>(a), &lda,
      reinterpret_cast<std::complex<double>*>(b), &ldb,
      reinterpret_cast<std::complex<double>*>(work), &lwork, info);
}

template<> void lapackGels<c10::complex<float>>(
    char trans, int m, int n, int nrhs,
    c10::complex<float> *a, int lda, c10::complex<float> *b, int ldb,
    c10::complex<float> *work, int lwork, int *info) {
  cgels_(&trans, &m, &n, &nrhs,
      reinterpret_cast<std::complex<float>*>(a), &lda,
      reinterpret_cast<std::complex<float>*>(b), &ldb,
      reinterpret_cast<std::complex<float>*>(work), &lwork, info);
}

template<> void lapackGels<double>(
    char trans, int m, int n, int nrhs,
    double *a, int lda, double *b, int ldb,
    double *work, int lwork, int *info) {
  dgels_(&trans, &m, &n, &nrhs,
      a, &lda, b, &ldb, work, &lwork, info);
}

template<> void lapackGels<float>(
    char trans, int m, int n, int nrhs,
    float *a, int lda, float *b, int ldb,
    float *work, int lwork, int *info) {
  sgels_(&trans, &m, &n, &nrhs,
      a, &lda, b, &ldb, work, &lwork, info);
}

template<> void lapackGelsd<c10::complex<double>, double>(
    int m, int n, int nrhs,
    c10::complex<double> *a, int lda, c10::complex<double> *b, int ldb,
    double *s, double rcond, int *rank,
    c10::complex<double> *work, int lwork,
    double *rwork, int *iwork, int *info) {
  zgelsd_(&m, &n, &nrhs,
      reinterpret_cast<std::complex<double>*>(a), &lda,
      reinterpret_cast<std::complex<double>*>(b), &ldb,
      s, &rcond, rank,
      reinterpret_cast<std::complex<double>*>(work), &lwork,
      rwork, iwork, info);
}

template<> void lapackGelsd<c10::complex<float>, float>(
    int m, int n, int nrhs,
    c10::complex<float> *a, int lda, c10::complex<float> *b, int ldb,
    float *s, float rcond, int *rank,
    c10::complex<float> *work, int lwork,
    float *rwork, int *iwork, int *info) {
  cgelsd_(&m, &n, &nrhs,
      reinterpret_cast<std::complex<float>*>(a), &lda,
      reinterpret_cast<std::complex<float>*>(b), &ldb,
      s, &rcond, rank,
      reinterpret_cast<std::complex<float>*>(work), &lwork,
      rwork, iwork, info);
}

template<> void lapackGelsd<double>(
    int m, int n, int nrhs,
    double *a, int lda, double *b, int ldb,
    double *s, double rcond, int *rank,
    double *work, int lwork,
    double *rwork, int *iwork, int *info) {
  dgelsd_(&m, &n, &nrhs,
      a, &lda, b, &ldb,
      s, &rcond, rank,
      work, &lwork, iwork, info);
}

template<> void lapackGelsd<float>(
    int m, int n, int nrhs,
    float *a, int lda, float *b, int ldb,
    float *s, float rcond, int *rank,
    float *work, int lwork,
    float *rwork, int *iwork, int *info) {
  sgelsd_(&m, &n, &nrhs,
      a, &lda, b, &ldb,
      s, &rcond, rank,
      work, &lwork, iwork, info);
}

template<> void lapackGelsy<c10::complex<double>, double>(
    int m, int n, int nrhs,
    c10::complex<double> *a, int lda, c10::complex<double> *b, int ldb,
    int *jpvt, double rcond, int *rank,
    c10::complex<double> *work, int lwork, double *rwork, int *info) {
  zgelsy_(&m, &n, &nrhs,
      reinterpret_cast<std::complex<double>*>(a), &lda,
      reinterpret_cast<std::complex<double>*>(b), &ldb,
      jpvt, &rcond, rank,
      reinterpret_cast<std::complex<double>*>(work), &lwork,
      rwork, info);
}

template<> void lapackGelsy<c10::complex<float>, float>(
    int m, int n, int nrhs,
    c10::complex<float> *a, int lda, c10::complex<float> *b, int ldb,
    int *jpvt, float rcond, int *rank,
    c10::complex<float> *work, int lwork, float *rwork, int *info) {
  cgelsy_(&m, &n, &nrhs,
      reinterpret_cast<std::complex<float>*>(a), &lda,
      reinterpret_cast<std::complex<float>*>(b), &ldb,
      jpvt, &rcond, rank,
      reinterpret_cast<std::complex<float>*>(work), &lwork,
      rwork, info);
}

template<> void lapackGelsy<double>(
    int m, int n, int nrhs,
    double *a, int lda, double *b, int ldb,
    int *jpvt, double rcond, int *rank,
    double *work, int lwork, double *rwork, int *info) {
  dgelsy_(&m, &n, &nrhs,
      a, &lda, b, &ldb,
      jpvt, &rcond, rank,
      work, &lwork, info);
}

template<> void lapackGelsy<float>(
    int m, int n, int nrhs,
    float *a, int lda, float *b, int ldb,
    int *jpvt, float rcond, int *rank,
    float *work, int lwork, float *rwork, int *info) {
  sgelsy_(&m, &n, &nrhs,
      a, &lda, b, &ldb,
      jpvt, &rcond, rank,
      work, &lwork, info);
}

template<> void lapackGelss<c10::complex<double>, double>(
    int m, int n, int nrhs,
    c10::complex<double> *a, int lda, c10::complex<double> *b, int ldb,
    double *s, double rcond, int *rank,
    c10::complex<double> *work, int lwork,
    double *rwork, int *info
    ) {
  zgelss_(&m, &n, &nrhs,
      reinterpret_cast<std::complex<double>*>(a), &lda,
      reinterpret_cast<std::complex<double>*>(b), &ldb,
      s, &rcond, rank,
      reinterpret_cast<std::complex<double>*>(work), &lwork,
      rwork, info);
}

template<> void lapackGelss<c10::complex<float>, float>(
    int m, int n, int nrhs,
    c10::complex<float> *a, int lda, c10::complex<float> *b, int ldb,
    float *s, float rcond, int *rank,
    c10::complex<float> *work, int lwork,
    float *rwork, int *info
    ) {
  cgelss_(&m, &n, &nrhs,
      reinterpret_cast<std::complex<float>*>(a), &lda,
      reinterpret_cast<std::complex<float>*>(b), &ldb,
      s, &rcond, rank,
      reinterpret_cast<std::complex<float>*>(work), &lwork,
      rwork, info);
}

template<> void lapackGelss<double>(
    int m, int n, int nrhs,
    double *a, int lda, double *b, int ldb,
    double *s, double rcond, int *rank,
    double *work, int lwork,
    double *rwork, int *info) {
  dgelss_(&m, &n, &nrhs,
      a, &lda, b, &ldb,
      s, &rcond, rank,
      work, &lwork, info);
}

template<> void lapackGelss<float>(
    int m, int n, int nrhs,
    float *a, int lda, float *b, int ldb,
    float *s, float rcond, int *rank,
    float *work, int lwork,
    float *rwork, int *info) {
  sgelss_(&m, &n, &nrhs,
      a, &lda, b, &ldb,
      s, &rcond, rank,
      work, &lwork, info);
}
#endif

#if AT_BUILD_WITH_BLAS()
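// Note: BLAS ?trsm solves op(A) X = alpha * B (or X op(A) = alpha * B, depending
// on `side`) in-place in B for a triangular A; the wrappers below fix alpha = 1.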
template<> void blasTriangularSolve<c10::complex<double>>(char side, char uplo, char trans, char diag, int n, int nrhs, c10::complex<double> *a, int lda, c10::complex<double> *b, int ldb) {
  std::complex<double> one{1., 0.};
  ztrsm_(&side, &uplo, &trans, &diag, &n, &nrhs, &one, reinterpret_cast<std::complex<double>*>(a), &lda, reinterpret_cast<std::complex<double>*>(b), &ldb);
}

template<> void blasTriangularSolve<c10::complex<float>>(char side, char uplo, char trans, char diag, int n, int nrhs, c10::complex<float> *a, int lda, c10::complex<float> *b, int ldb) {
  std::complex<float> one{1.f, 0.f};
  ctrsm_(&side, &uplo, &trans, &diag, &n, &nrhs, &one, reinterpret_cast<std::complex<float>*>(a), &lda, reinterpret_cast<std::complex<float>*>(b), &ldb);
}

template<> void blasTriangularSolve<double>(char side, char uplo, char trans, char diag, int n, int nrhs, double *a, int lda, double *b, int ldb) {
  auto one = 1.;
  dtrsm_(&side, &uplo, &trans, &diag, &n, &nrhs, &one, a, &lda, b, &ldb);
}

template<> void blasTriangularSolve<float>(char side, char uplo, char trans, char diag, int n, int nrhs, float *a, int lda, float *b, int ldb) {
  auto one = 1.f;
  strsm_(&side, &uplo, &trans, &diag, &n, &nrhs, &one, a, &lda, b, &ldb);
}
#endif
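
// LAPACK/MAGMA report status through one integer `info` per matrix:
//   info == 0  - success
//   info <  0  - argument -info had an illegal value (an implementation bug)
//   info >  0  - an algorithm-specific failure (singular input, no convergence, ...)
// _linalg_check_errors translates these codes into user-facing errors. Typical
// usage in this file (illustrative, mirroring linalg_lu_factor below):
//   auto [LU, pivots, info] = at::linalg_lu_factor_ex(A, /*pivot=*/true, /*check_errors=*/false);
//   at::_linalg_check_errors(info, "torch.linalg.lu_factor", A.dim() == 2);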
void _linalg_check_errors(
    const Tensor& infos,
    const c10::string_view api_name,
    bool is_matrix) {
  TORCH_INTERNAL_ASSERT(infos.scalar_type() == kInt);
  TORCH_INTERNAL_ASSERT(infos.is_contiguous());
  if (infos.is_meta()) {
    return;
  }

  // If it's all zeros, we return early.
  // We optimise for the most likely case.
  if (C10_LIKELY(!infos.any().item<bool>())) {
    return;
  }

  int32_t info = 0;
  std::string batch_str;
  if (is_matrix) {
    info = infos.item<int>();
    // batch_str needn't be set for matrices
  } else {
    // Find the first non-zero info
    auto infos_cpu = infos.to(at::kCPU);
    auto ptr = infos_cpu.const_data_ptr<int32_t>();
    auto n = infos.numel();
    auto info_ptr = std::find_if(ptr, ptr + n, [](int32_t x) { return x != 0; });
    info = *info_ptr;
    batch_str = ": (Batch element " + std::to_string(std::distance(ptr, info_ptr)) + ")";
  }

  if (info < 0) {
    // Reference LAPACK 3.10+ changed the `info` behavior for inputs with non-finite values:
    // previously it would return `info` > 0, but now it returns `info` = -4.
    // OpenBLAS 0.3.15+ and MKL 2022.0+ use Reference LAPACK 3.10+.
    // Older versions of MKL and OpenBLAS follow the old behavior (return `info` > 0).
    // Here we check for the case where `info` is -4 and raise an error.
    if (api_name.find("svd") != api_name.npos) {
      TORCH_CHECK_LINALG(info != -4, api_name, batch_str,
          ": The algorithm failed to converge because the input matrix contained non-finite values.");
    }
    TORCH_INTERNAL_ASSERT(false, api_name, batch_str,
        ": Argument ", -info, " has illegal value. Most certainly there is a bug in the implementation calling the backend library.");
  } else if (info > 0) {
    if (api_name.find("inv") != api_name.npos) {
      // inv, inverse, cholesky_inverse, etc.
      TORCH_CHECK_LINALG(false, api_name, batch_str,
          ": The diagonal element ", info, " is zero, the inversion could not be completed because the input matrix is singular.");
    } else if (api_name.find("solve") != api_name.npos) {
      // solve, linalg_solve, cholesky_solve, etc.
      TORCH_CHECK_LINALG(false, api_name, batch_str,
          ": The solver failed because the input matrix is singular.");
    } else if (api_name.find("cholesky") != api_name.npos) {
      TORCH_CHECK_LINALG(false, api_name, batch_str,
          ": The factorization could not be completed because the input is not positive-definite (the leading minor of order ", info, " is not positive-definite).");
    } else if (api_name.find("svd") != api_name.npos) {
      TORCH_CHECK_LINALG(false, api_name, batch_str,
          ": The algorithm failed to converge because the input matrix is ill-conditioned or has too many repeated singular values (error code: ", info, ").");
    } else if (api_name.find("eig") != api_name.npos || api_name.find("syevd") != api_name.npos) {
      TORCH_CHECK_LINALG(false, api_name, batch_str,
          ": The algorithm failed to converge because the input matrix is ill-conditioned or has too many repeated eigenvalues (error code: ", info, ").");
    } else if (api_name.find("lstsq") != api_name.npos) {
      TORCH_CHECK_LINALG(false, api_name, batch_str,
          ": The least squares solution could not be computed because the input matrix does not have full rank (error code: ", info, ").");
    } else if (api_name.find("lu_factor") != api_name.npos) {
      TORCH_CHECK(false, api_name, batch_str,
          ": U[", info, ",", info, "] is zero and using it on lu_solve would result in a division by zero. "
          "If you still want to perform the factorization, consider calling linalg.lu(A, pivot) or "
          "linalg.lu_factor_ex(A, pivot)");
    } else {
      TORCH_INTERNAL_ASSERT(false, api_name, ": Unknown error code: ", info, ".");
    }
  }
  // We should never reach this point as info was non-zero
  TORCH_INTERNAL_ASSERT(false);
}

// If an input requires fw or bw grad then we need to go down a different
// (slower) path to ensure that the gradients are computable.
// That is what `_may_require_fw_or_bw_grad` is helpful for.
//
// Why is there an isTensorSubclassLike check here?
// Without it, this function can lead to composite compliance problems, which
// may lead to bugs in functorch, where a Tensor subclass that doesn't
// require grad may wrap a Tensor subclass that requires grad.
static bool _may_require_fw_or_bw_grad(const Tensor& input) {
  return ((at::GradMode::is_enabled() && input.requires_grad())
          || input._fw_grad(/*level */ 0).defined()
          || isTensorSubclassLike(input));
}

// NOLINTBEGIN(cppcoreguidelines-pro-type-const-cast)

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg.inv ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
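// linalg.inv(A) is implemented as linalg.solve(A, I): `result` is filled with the
// identity and the solve writes A^{-1} over it, so singularity is reported through
// the same `info` tensor as linalg.solve. Roughly (illustrative):
//   auto [Ainv, info] = at::linalg_inv_ex(A);   // no error check here
//   at::_linalg_check_errors(info, "linalg.inv", A.dim() == 2);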
TORCH_IMPL_FUNC(linalg_inv_ex_out)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) {
  // Fill result with the identity
  result.zero_();
  result.diagonal(0, -2, -1).fill_(1.);
  at::linalg_solve_ex_out(const_cast<Tensor&>(result), const_cast<Tensor&>(info), A, result, /*left*/true);
  if (check_errors) {
    at::_linalg_check_errors(info, "linalg.inv_ex", A.dim() == 2);
  }
}

Tensor& linalg_inv_out(const Tensor& A, Tensor& result) {
  auto info = at::empty({0}, A.options().dtype(kInt));
  at::linalg_inv_ex_out(result, info, A);
  at::_linalg_check_errors(info, "linalg.inv", A.dim() == 2);
  return result;
}

Tensor linalg_inv(const Tensor& A) {
  auto [result, info] = at::linalg_inv_ex(A);
  at::_linalg_check_errors(info, "linalg.inv", A.dim() == 2);
  return result;
}

Tensor& inverse_out(const Tensor& A, Tensor& result) {
  return at::linalg_inv_out(result, A);
}

Tensor inverse(const Tensor& A) {
  return at::linalg_inv(A);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ cholesky_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
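
// cholesky_solve solves A X = B given a Cholesky factor of A (as produced by
// torch.cholesky / torch.linalg.cholesky). On CPU this lowers to LAPACK ?potrs,
// one call per batch element, with the per-element status collected in `infos`.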
template<typename scalar_t>
static void apply_cholesky_solve(Tensor& b, Tensor& A, bool upper, Tensor& infos) {
#if !AT_BUILD_WITH_LAPACK()
  AT_ERROR("cholesky_solve: LAPACK library not found in compilation");
#else
  char uplo = upper ? 'U' : 'L';

  auto A_data = A.const_data_ptr<scalar_t>();
  auto b_data = b.data_ptr<scalar_t>();
  auto infos_data = infos.data_ptr<int>();
  auto A_mat_stride = matrixStride(A);
  auto b_mat_stride = matrixStride(b);
  auto batch_size = batchCount(A);
  auto n = A.size(-2);
  auto ldab = std::max<int64_t>(1, n);
  auto nrhs = b.size(-1);

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int info;
  for (const auto i : c10::irange(batch_size)) {
    const scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
    scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
    lapackCholeskySolve<scalar_t>(uplo, n, nrhs, const_cast<scalar_t*>(A_working_ptr), ldab, b_working_ptr, ldab, &info);
    infos_data[i] = info;
    if (info != 0) {
      return;
    }
  }
#endif
}

Tensor _cholesky_solve_helper_cpu(const Tensor& self, const Tensor& A, bool upper) {
  auto self_working_copy = cloneBatchedColumnMajor(self);
  auto A_working_copy = cloneBatchedColumnMajor(A);
  auto infos = at::zeros({batchCount(self)}, self.options().dtype(kInt));
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "cholesky_solve_cpu", [&]{
    apply_cholesky_solve<scalar_t>(self_working_copy, A_working_copy, upper, infos);
  });

  at::_linalg_check_errors(infos, "cholesky_solve_cpu", self.dim() == 2);
  return self_working_copy;
}

// Supports arbitrary batch dimensions for self and A
Tensor cholesky_solve(const Tensor& self, const Tensor& A, bool upper) {
  TORCH_CHECK(self.dim() >= 2,
      "b should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
  TORCH_CHECK(A.dim() >= 2,
      "u should have at least 2 dimensions, but has ", A.dim(), " dimensions instead");
  auto [self_broadcasted, A_broadcasted] = _linalg_broadcast_batch_dims(self, A, "cholesky_solve");
  return at::_cholesky_solve_helper(self_broadcasted, A_broadcasted, upper);
}

Tensor& cholesky_solve_out(const Tensor& self, const Tensor& A, bool upper, Tensor& result) {
  checkSameDevice("cholesky_solve", result, self);
  checkLinalgCompatibleDtype("cholesky_solve", result, self);
  Tensor result_tmp = at::cholesky_solve(self, A, upper);
  at::native::resize_output(result, result_tmp.sizes());
  result.copy_(result_tmp);
  return result;
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ cholesky ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

DEFINE_DISPATCH(cholesky_stub);

Tensor cholesky(const Tensor &self, bool upper) {
  TORCH_WARN_ONCE(
    "torch.cholesky is deprecated in favor of torch.linalg.cholesky and will be ",
    "removed in a future PyTorch release.\n",
    "L = torch.cholesky(A)\n",
    "should be replaced with\n",
    "L = torch.linalg.cholesky(A)\n",
    "and\n"
    "U = torch.cholesky(A, upper=True)\n",
    "should be replaced with\n",
    "U = torch.linalg.cholesky(A).mH\n"
    "This transform will produce equivalent results for all valid (symmetric positive definite) inputs."
  );
  if (self.numel() == 0) {
    return at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }
  squareCheckInputs(self, "cholesky");

  auto raw_cholesky_output = cloneBatchedColumnMajor(self);
  auto info_shape = IntArrayRef(
      self.sizes().cbegin(), self.sizes().cend() - 2); // self.shape[:-2]
  auto info = at::empty({info_shape}, self.options().dtype(kInt));

  // fill the raw_cholesky_output with the result
  cholesky_stub(self.device().type(), raw_cholesky_output, info, upper);

  at::_linalg_check_errors(info, "cholesky", self.dim() == 2);

  if (upper) {
    return raw_cholesky_output.triu_();
  } else {
    return raw_cholesky_output.tril_();
  }
}

Tensor& cholesky_out(const Tensor &self, bool upper, Tensor &result) {
  TORCH_WARN_ONCE(
    "torch.cholesky is deprecated in favor of torch.linalg.cholesky and will be ",
    "removed in a future PyTorch release.\n",
    "L = torch.cholesky(A)\n",
    "should be replaced with\n",
    "L = torch.linalg.cholesky(A)\n",
    "and\n"
    "U = torch.cholesky(A, upper=True)\n",
    "should be replaced with\n",
    "U = torch.linalg.cholesky(A).mH\n"
    "This transform will produce equivalent results for all valid (symmetric positive definite) inputs."
  );
  checkSameDevice("cholesky", result, self);
  checkLinalgCompatibleDtype("cholesky", result, self);
  Tensor result_tmp = at::cholesky(self, upper);
  at::native::resize_output(result, result_tmp.sizes());
  result.copy_(result_tmp);
  return result;
}

TORCH_IMPL_FUNC(linalg_cholesky_ex_out)(const Tensor& A,
                                        bool upper,
                                        bool check_errors,
                                        const Tensor& L,
                                        const Tensor& info) {
  // Nothing to do there
  if (L.numel() == 0) {
    info.zero_();
    return;
  }
  const auto cpu = A.device() == kCPU;

  // We can perform this optimisation just on CPU as it fails for MAGMA
  // due to some bug
  if (cpu) {
    if (upper) {
      at::triu_out(const_cast<Tensor&>(L), A);
    } else {
      at::tril_out(const_cast<Tensor&>(L), A);
    }
  } else {
    L.copy_(A);
  }

  cholesky_stub(L.device().type(), L, info, upper);

  if (!cpu) {
    if (upper) {
      L.triu_();
    } else {
      L.tril_();
    }
  }

  if (check_errors) {
    at::_linalg_check_errors(info, "linalg.cholesky_ex", A.dim() == 2);
  }
}

Tensor linalg_cholesky(const Tensor& A, bool upper) {
  auto [L, info] = at::linalg_cholesky_ex(A, upper, /*check_errors=*/false);
  at::_linalg_check_errors(info, "linalg.cholesky", A.dim() == 2);
  return L;
}

Tensor& linalg_cholesky_out(const Tensor& A, bool upper, Tensor& L) {
  auto info = at::empty({0}, A.options().dtype(kInt));
  at::linalg_cholesky_ex_out(L, info, A, upper, /*check_errors=*/false);
  at::_linalg_check_errors(info, "linalg.cholesky", A.dim() == 2);
  return L;
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ cholesky_inverse ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
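
// cholesky_inverse computes A^{-1} from a Cholesky factor of A; on CPU the
// dispatched kernel lowers to LAPACK ?potri. The note about `infos` below
// concerns the MAGMA-backed CUDA path.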
DEFINE_DISPATCH(cholesky_inverse_stub);

static Tensor& cholesky_inverse_out_info(Tensor& result, Tensor& infos, const Tensor& input, bool upper) {
  TORCH_INTERNAL_ASSERT(input.dim() >= 2);
  TORCH_INTERNAL_ASSERT(input.size(-1) == input.size(-2));

  TORCH_INTERNAL_ASSERT(result.scalar_type() == input.scalar_type());
  TORCH_INTERNAL_ASSERT(result.device() == input.device());

  TORCH_INTERNAL_ASSERT(infos.scalar_type() == at::kInt);
  TORCH_INTERNAL_ASSERT(infos.device() == at::kCPU);
  TORCH_INTERNAL_ASSERT(infos.numel() == std::max<int64_t>(1, batchCount(input)));

  // if result has no elements we can modify it
  if (result.numel() == 0) {
    at::native::resize_as_(result, input.mT(), MemoryFormat::Contiguous);
    result.transpose_(-2, -1);
  }

  // result tensor must be in batched column major order (Fortran contiguous)
  TORCH_INTERNAL_ASSERT(result.mT().is_contiguous());
  TORCH_INTERNAL_ASSERT(result.sizes().equals(input.sizes()));

  // cholesky_inverse_stub (apply_cholesky_inverse) performs calculations in-place and result must be a copy of input
  result.copy_(input);

  // infos must be contiguous
  TORCH_INTERNAL_ASSERT(infos.is_contiguous());
  infos.fill_(0);

  result = cholesky_inverse_stub(result.device().type(), result, infos, upper);
  return result;
}

Tensor& cholesky_inverse_out(const Tensor &input, bool upper, Tensor &result) {
  squareCheckInputs(input, "cholesky_inverse");
  checkSameDevice("cholesky_inverse", result, input);
  checkLinalgCompatibleDtype("cholesky_inverse", result, input);

  // MAGMA requires 'infos' to reside in CPU memory, therefore we create 'infos' only on CPU for now.
  auto infos = at::zeros({std::max<int64_t>(1, batchCount(input))}, input.options().dtype(kInt).device(kCPU));

  bool result_input_same_type = (result.scalar_type() == input.scalar_type());
  bool result_equal_expected_shape = result.sizes().equals(input.sizes());
  bool is_batched_column_major = false;
  if (result.dim() >= 2) {
    is_batched_column_major = result.mT().is_contiguous();
  }

  // if result is not empty and not in batched column major format
  bool copy_needed = (result.numel() != 0 && !is_batched_column_major);
  copy_needed |= !result_input_same_type;  // or result does not have the same dtype as input
  copy_needed |= (result.numel() != 0 && !result_equal_expected_shape); // or result does not have the expected shape
  // we have to allocate a temporary tensor
  if (copy_needed) {
    Tensor result_tmp = at::empty({0}, input.options());
    result_tmp = cholesky_inverse_out_info(result_tmp, infos, input, upper);
    at::native::resize_output(result, result_tmp.sizes());
    result.copy_(result_tmp);
  } else {
    // use result's memory directly
    result = cholesky_inverse_out_info(result, infos, input, upper);
  }

  // Now check LAPACK/MAGMA error codes
  at::_linalg_check_errors(infos, "cholesky_inverse", result.dim() == 2);
  return result;
}

Tensor cholesky_inverse(const Tensor &input, bool upper) {
  Tensor result = at::empty({0}, input.options());
  result = at::cholesky_inverse_out(result, input, upper);
  return result;
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg.solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
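
// linalg.solve is structured as a stack of thin wrappers:
//   _linalg_solve_ex  - computes (result, LU, pivots, info); the LU factorization
//                       is kept because the backward pass reuses it
//   linalg_solve_ex   - drops LU/pivots and returns (result, info)
//   linalg_solve      - additionally raises if info != 0
// e.g. (illustrative):
//   auto [X, info] = at::linalg_solve_ex(A, B, /*left=*/true);
//   at::_linalg_check_errors(info, "torch.linalg.solve", A.dim() == 2);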
// Auxiliary function that returns the LU decomposition to use it in the backward
TORCH_IMPL_FUNC(_linalg_solve_ex_out)(const Tensor& A,
                                      const Tensor& B,
                                      bool left,
                                      bool check_errors,
                                      const Tensor& result,
                                      const Tensor& LU,
                                      const Tensor& pivots,
                                      const Tensor& info) {
  // Possible optimization: Compute the LU factorization of A^T if A is contiguous
  // Then we solve A^T X = B with adjoint=True
  // This saves a copy as A doesn't need to be copied into an F-contig matrix in lu_factor
  // This optimization makes functorch's batching rule difficult. See NOTE [ solve_ex Batch Rule Contiguity ]
  const bool use_A_T = A.is_contiguous() && !A.is_complex();
  at::linalg_lu_factor_ex_out(const_cast<Tensor&>(LU),
                              const_cast<Tensor&>(pivots),
                              const_cast<Tensor&>(info),
                              use_A_T ? A.mT() : A);
  if (check_errors) {
    at::_linalg_check_errors(info, "torch.linalg.solve_ex", A.dim() == 2);
  }

  // [numpy-compat] Handle vectors on the rhs
  const bool vector_case = at::native::linalg_solve_is_vector_rhs(LU, B);
  auto result_ = vector_case ? result.unsqueeze(-1) : result;
  auto B_ = vector_case ? B.unsqueeze(-1) : B;
  at::linalg_lu_solve_out(result_, LU, pivots, B_, left, /*adjoint*/use_A_T);
}

std::tuple<Tensor&, Tensor&> linalg_solve_ex_out(const Tensor& A,
                                                 const Tensor& B,
                                                 bool left,
                                                 bool check_errors,
                                                 Tensor& result,
                                                 Tensor& info) {
  auto LU = B.new_empty({0});
  auto pivots = B.new_empty({0}, kInt);
  at::_linalg_solve_ex_out(result, LU, pivots, info, A, B, left, check_errors);
  return std::tie(result, info);
}

// We implement linalg_solve_ex as a composite function of _linalg_solve_ex
std::tuple<Tensor, Tensor> linalg_solve_ex(const Tensor& A,
                                           const Tensor& B,
                                           bool left,
                                           bool check_errors) {
  auto [result, LU, pivots, info] = at::_linalg_solve_ex(A, B, left, check_errors);
  return std::make_tuple(std::move(result), std::move(info));
}

Tensor& linalg_solve_out(const Tensor& A,
                         const Tensor& B,
                         bool left,
                         Tensor& result) {
  auto info = B.new_empty({0}, kInt);
  at::linalg_solve_ex_out(result, info, A, B, left);
  at::_linalg_check_errors(info, "torch.linalg.solve", A.dim() == 2);
  return result;
}

Tensor linalg_solve(const Tensor& A,
                    const Tensor& B,
                    bool left) {
  if (A.layout() == kSparseCsr) {
    return at::_spsolve(A, B, left);
  }
  auto [result, info] = at::linalg_solve_ex(A, B, left);
  at::_linalg_check_errors(info, "torch.linalg.solve", A.dim() == 2);
  return result;
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu_factor ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
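
// For ?getrf-style LU factorizations, info > 0 means the factorization completed
// but U[info, info] is exactly zero, i.e. the matrix is singular; lu_factor turns
// this into an error, while lu_factor_ex leaves the check to the caller.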
DEFINE_DISPATCH(lu_factor_stub);

TORCH_IMPL_FUNC(linalg_lu_factor_ex_out)(const Tensor& A,
                                         bool pivot,
                                         bool check_errors,
                                         const Tensor& LU,
                                         const Tensor& pivots,
                                         const Tensor& info) {
  if (A.numel() == 0) {
    // zero out the infos as it will have one element if the input is a matrix of size (0, 0)
    info.zero_();
    return;
  }
  if (!LU.is_same(A)) {
    LU.copy_(A);
  }

  lu_factor_stub(A.device().type(), LU, pivots, info, pivot);

  if (check_errors) {
    at::_linalg_check_errors(info, "torch.linalg.lu_factor_ex", A.dim() == 2);
  }
}

std::tuple<Tensor&, Tensor&> linalg_lu_factor_out(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) {
  auto info = at::empty({0}, A.options().dtype(kInt));
  // We pass check_errors=false and check manually below so that the error messages mention lu_factor rather than lu_factor_ex
  at::linalg_lu_factor_ex_out(LU, pivots, info, A, pivot, /*check_errors=*/false);
  at::_linalg_check_errors(info, "torch.linalg.lu_factor", A.dim() == 2);
  return std::tie(LU, pivots);
}

std::tuple<Tensor, Tensor> linalg_lu_factor(const Tensor& A, bool pivot) {
  auto [LU, pivots, info] = at::linalg_lu_factor_ex(A, pivot, /*check_errors=*/false);
  at::_linalg_check_errors(info, "torch.linalg.lu_factor", A.dim() == 2);
  return std::make_tuple(std::move(LU), std::move(pivots));
}

// TODO Deprecate this function in favour of linalg_lu_factor_ex
std::tuple<Tensor, Tensor, Tensor> _lu_with_info(const Tensor& self, bool compute_pivots, bool) {
  TORCH_WARN_ONCE(
    "torch.lu is deprecated in favor of torch.linalg.lu_factor / torch.linalg.lu_factor_ex and will be ",
    "removed in a future PyTorch release.\n",
    "LU, pivots = torch.lu(A, compute_pivots)\n",
    "should be replaced with\n",
    "LU, pivots = torch.linalg.lu_factor(A, compute_pivots)\n",
    "and\n",
    "LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)\n",
    "should be replaced with\n",
    "LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)"
  );
  return at::linalg_lu_factor_ex(self, compute_pivots, false);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

DEFINE_DISPATCH(unpack_pivots_stub);

TORCH_IMPL_FUNC(linalg_lu_out)(const Tensor& A,
                               bool pivot,
                               const Tensor& P,
                               const Tensor& L,
                               const Tensor& U) {
  const auto m = A.sizes().end()[-2];
  const auto n = A.sizes().end()[-1];

  // A.shape[-2:] == (m, n)
  // P.shape[-2:] == (m, m)
  // L.shape[-2:] == (m, k)
  // U.shape[-2:] == (k, n)
  // with k = min(m, n)

  // Use L as it has the correct size
  const bool use_L = m > n;
  auto pivots = at::empty({0}, A.options().dtype(kInt));
  auto info = at::empty({0}, A.options().dtype(kInt));
  at::linalg_lu_factor_ex_out(const_cast<Tensor&>(use_L ? L : U),
                              const_cast<Tensor&>(pivots),
                              const_cast<Tensor&>(info),
                              A,
                              pivot,
                              /*check_errors=*/false);
  at::lu_unpack_out(const_cast<Tensor&>(P),
                    const_cast<Tensor&>(L),
                    const_cast<Tensor&>(U),
                    use_L ? L : U,
                    pivots,
                    /*unpack_data=*/true,
                    /*unpack_pivots=*/pivot);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu_unpack ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

TORCH_IMPL_FUNC(lu_unpack_out)(const Tensor& LU,
                               const Tensor& pivots,
                               bool unpack_lu,
                               bool unpack_pivots,
                               const Tensor& P,
                               const Tensor& L,
                               const Tensor& U) {
  const auto m = LU.sizes().end()[-2];
  const auto n = LU.sizes().end()[-1];

  // A.shape[-2:] == (m, n)
  // P.shape[-2:] == (m, m)
  // L.shape[-2:] == (m, k)
  // U.shape[-2:] == (k, n)
  // with k = min(m, n)

  if (unpack_lu) {
    if (m > n || LU.is_same(L)) {
      // The order of triu and tril is important as we may have LU.is_same(L)
      at::triu_out(const_cast<Tensor&>(U), m == n ? LU : LU.narrow(-2, 0, n), 0);
      at::tril_out(const_cast<Tensor&>(L), LU, -1);
      L.diagonal(0, -2, -1).fill_(1.);
    } else {
      // The order of triu and tril is important as we may have LU.is_same(U)
      at::tril_out(const_cast<Tensor&>(L), m == n ? LU : LU.narrow(-1, 0, m), -1);
      L.diagonal(0, -2, -1).fill_(1.);
      at::triu_out(const_cast<Tensor&>(U), LU, 0);
    }
  }
  if (unpack_pivots) {
    // lu_factor_ex returns an int32 1-based indexing, which is what we have in `pivots`
    // We transform that to a proper permutation of the indices {0, ..., m-1}
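    // Illustrative example (assumed ?getrf pivot semantics): for m = 3 and
    // pivots = [2, 3, 3], row 0 was swapped with row 1, then row 1 with row 2.
    // Applying those swaps in order to the identity permutation [0, 1, 2]
    // gives perm = [1, 2, 0], which is what the unpack_pivots kernel computes below.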
    const auto perm_sizes = IntArrayRef(P.sizes().data(), P.dim() - 1);

    // Fill `perm` with the identity permutation (perhaps batched)
    const auto perm = at::arange(m, pivots.options().memory_format(at::MemoryFormat::Contiguous).dtype(kLong))
                          .expand(perm_sizes)
                          .contiguous();

    // Note that perm is of type kLong and pivots is a 1-indexed kInt.
    // This is taken into account in the unpack_pivots kernel
    auto iter = TensorIteratorConfig()
        .set_check_mem_overlap(false)
        .check_all_same_dtype(false)
        .resize_outputs(false)
        .declare_static_shape(pivots.sizes(), /*squash_dims=*/pivots.dim() - 1)
        .add_output(perm)
        .add_owned_const_input(pivots.contiguous())
        .build();

    unpack_pivots_stub(pivots.device().type(), iter, std::min(m, n), m);

    // Transform the permutation into a permutation matrix
    P.zero_();
    P.scatter_(-2, perm.unsqueeze(-2), 1.);
  }
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_lu_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
DEFINE_DISPATCH(lu_solve_stub);

TORCH_IMPL_FUNC(linalg_lu_solve_out)(const Tensor& LU,
                                     const Tensor& pivots,
                                     const Tensor& B,
                                     bool left,
                                     bool adjoint,
                                     const Tensor& result) {
  // Trivial case
  if (result.numel() == 0) {
    return;
  }

  // Solve A^H X = B^H. Then we return X^H
  if (!left) {
    adjoint = !adjoint;
    result.transpose_(-2, -1);
  }

  // Copy B (or B^H) into result
  if (!result.is_same(B)) {
    result.copy_(left ? B : B.mH());
  }

  // Make LU / pivots F-contiguous
  auto pivots_ = pivots.expect_contiguous();
  auto LU_ = at::native::borrow_else_clone(
      LU.mT().is_contiguous(), LU, LU, /*contig=*/false);

  const auto trans = !adjoint ? TransposeType::NoTranspose :
                     LU.is_complex() ? TransposeType::ConjTranspose
                                     : TransposeType::Transpose;

  lu_solve_stub(LU_->device().type(), *LU_, *pivots_, result, trans);

  // Conj-transpose back in-place
  if (!left) {
    result.transpose_(-2, -1);
    if (result.is_complex()) {
      result._set_conj(!result.is_conj());
    }
  }
}

Tensor lu_solve(const Tensor& self, const Tensor& LU_data, const Tensor& LU_pivots) {
  TORCH_WARN_ONCE(
    "torch.lu_solve is deprecated in favor of torch.linalg.lu_solve ",
    "and will be removed in a future PyTorch release.\n",
    "Note that torch.linalg.lu_solve has its arguments reversed.\n",
    "X = torch.lu_solve(B, LU, pivots)\n",
    "should be replaced with\n",
    "X = torch.linalg.lu_solve(LU, pivots, B)"
  );
  return at::linalg_lu_solve(LU_data, LU_pivots, self);
}

Tensor& lu_solve_out(const Tensor& self, const Tensor& LU_data, const Tensor& LU_pivots, Tensor& result) {
  TORCH_WARN_ONCE(
    "torch.lu_solve is deprecated in favor of torch.linalg.lu_solve ",
    "and will be removed in a future PyTorch release.\n",
    "Note that torch.linalg.lu_solve has its arguments reversed.\n",
    "X = torch.lu_solve(B, LU, pivots)\n",
    "should be replaced with\n",
    "X = torch.linalg.lu_solve(LU, pivots, B)"
  );
  return at::linalg_lu_solve_out(result, LU_data, LU_pivots, self);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

DEFINE_DISPATCH(triangular_solve_stub);

/*
Solves the matrix equation 'input' @ 'result' = 'other' for the 'result'.
The result of the computation is saved in-place in 'result' tensor,
'clone_input' will be a copy of 'input',
'infos' is used to store information for possible checks for error,
'upper' controls the portion of input matrix to consider in computations,
'transpose' if true then 'input.mT()' @ 'result' = 'other' is solved,
'unitriangular' if true then the diagonal elements of 'input' are assumed to be 1
and the actual diagonal values are not used.
*/
static void triangular_solve_out_impl(
    const Tensor& result,
    const Tensor& clone_input,
    const Tensor& input,
    const Tensor& other,
    bool upper, bool transpose, bool unitriangular) {
  TORCH_WARN_ONCE(
    "torch.triangular_solve is deprecated in favor of torch.linalg.solve_triangular ",
    "and will be removed in a future PyTorch release.\n",
    "torch.linalg.solve_triangular has its arguments reversed and does not return a copy of one of the inputs.\n",
    "X = torch.triangular_solve(B, A).solution\n",
    "should be replaced with\n",
    "X = torch.linalg.solve_triangular(A, B).");
  // These internal asserts make explicit the assumptions in the implementation.
  // Error checks with the actual error messages are done at a higher level in
  // the call hierarchy.
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.dim() >= 2);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.size(-2) == input.size(-1));

  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.device() == other.device());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.device() == result.device());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.device() == clone_input.device());

  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.scalar_type() == other.scalar_type());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.scalar_type() == result.scalar_type());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.scalar_type() == clone_input.scalar_type());

  // if 'result' has no elements we can modify it
  if (result.numel() == 0) {
    result.resize_(other.mT().sizes(), MemoryFormat::Contiguous);
    result.transpose_(-2, -1);  // make 'result' have a Fortran-contiguous memory layout
  }

  // if 'clone_input' has no elements we can modify it
  if (clone_input.numel() == 0) {
    clone_input.resize_(input.mT().sizes(), MemoryFormat::Contiguous);
    clone_input.transpose_(-2, -1);  // make 'clone_input' have a Fortran-contiguous memory layout
  }

  // 'result' and 'clone_input' must be in batched column major order (Fortran contiguous)
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.mT().is_contiguous());
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(clone_input.mT().is_contiguous());

  // triangular_solve_stub performs calculations in-place
  // 'result' must be a copy of 'other'
  // 'clone_input' must be a copy of 'input'
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.sizes().equals(other.sizes()));
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(clone_input.sizes().equals(input.sizes()));
  result.copy_(other);
  clone_input.copy_(input);

  triangular_solve_stub(input.device().type(), clone_input, result, /*left=*/true, upper, transpose ? TransposeType::Transpose : TransposeType::NoTranspose, unitriangular);
}

TORCH_IMPL_FUNC(triangular_solve_out)(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular, const Tensor& result, const Tensor& clone_A) {
  auto [self_broadcast, A_broadcast] = _linalg_broadcast_batch_dims(self, A, "triangular_solve");

  bool copy_needed = !result.transpose(-2, -1).is_contiguous();
  copy_needed |= !clone_A.transpose(-2, -1).is_contiguous();

  if (copy_needed) {
    Tensor result_tmp = at::empty({0}, self.options());
    Tensor clone_A_tmp = at::empty({0}, A.options());

    triangular_solve_out_impl(result_tmp, clone_A_tmp, A_broadcast, self_broadcast, upper, transpose, unitriangular);

    result.copy_(result_tmp);
    clone_A.copy_(clone_A_tmp);
  } else {
    triangular_solve_out_impl(result, clone_A, A_broadcast, self_broadcast, upper, transpose, unitriangular);
  }
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ qr ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

DEFINE_DISPATCH(geqrf_stub);

static void geqrf_out_helper(const Tensor& input, const Tensor& QR, const Tensor& tau) {
  TORCH_INTERNAL_ASSERT(input.dim() >= 2);

  TORCH_INTERNAL_ASSERT(input.scalar_type() == QR.scalar_type());
  TORCH_INTERNAL_ASSERT(input.device() == QR.device());

  TORCH_INTERNAL_ASSERT(input.scalar_type() == tau.scalar_type());
  TORCH_INTERNAL_ASSERT(input.device() == tau.device());

  // if 'QR' has no elements we can modify it
  if (QR.numel() == 0) {
    QR.resize_as_(input.mT(), MemoryFormat::Contiguous);
    QR.transpose_(-2, -1); // make Fortran-contiguous
  }

  auto expected_batch_tau_shape = IntArrayRef(input.sizes().data(), input.dim() - 2).vec(); // input.shape[:-2]
  expected_batch_tau_shape.push_back(std::min(input.size(-2), input.size(-1)));
  if (tau.numel() == 0) {
    tau.resize_(expected_batch_tau_shape);
  }

  // QR tensor must be in batched column major order (Fortran contiguous)
  TORCH_INTERNAL_ASSERT(QR.mT().is_contiguous());
  TORCH_INTERNAL_ASSERT(QR.sizes().equals(input.sizes()));

  // tau tensor must be contiguous
  TORCH_INTERNAL_ASSERT(tau.is_contiguous());
  TORCH_INTERNAL_ASSERT(tau.sizes().equals(expected_batch_tau_shape));

  // geqrf_stub (apply_geqrf) performs calculations in-place and 'QR' must be a copy of input
  QR.copy_(input);
  geqrf_stub(input.device().type(), QR, tau);
}

std::tuple<Tensor&, Tensor&> geqrf_out(const Tensor& input, Tensor& QR, Tensor& tau) {
  TORCH_CHECK(input.dim() >= 2, "torch.geqrf: input must have at least 2 dimensions.");

  checkSameDevice("torch.geqrf", QR, input, "a"); // 'a' is used in documentation and native_functions.yml
  checkSameDevice("torch.geqrf", tau, input, "tau");
  checkLinalgCompatibleDtype("torch.geqrf", QR, input, "a");
  checkLinalgCompatibleDtype("torch.geqrf", tau, input, "tau");

  bool QR_input_same_type = (QR.scalar_type() == input.scalar_type());
  bool tau_input_same_type = (tau.scalar_type() == input.scalar_type());
  bool QR_equal_expected_shape = QR.sizes().equals(input.sizes());

  auto expected_batch_tau_shape = IntArrayRef(input.sizes().data(), input.dim() - 2).vec(); // input.shape[:-2]
  expected_batch_tau_shape.push_back(std::min(input.size(-2), input.size(-1)));
  bool tau_equal_expected_shape = tau.sizes().equals(expected_batch_tau_shape);

  bool is_batched_column_major = false;
  if (QR.dim() >= 2) {
    is_batched_column_major = QR.mT().is_contiguous();
  }

  // if 'QR' is not empty and not in batched column major format
  bool copy_needed = (QR.numel() != 0 && !is_batched_column_major);
  copy_needed |= (QR.numel() != 0 && !QR_equal_expected_shape); // or 'QR' does not have the expected shape
  copy_needed |= !QR_input_same_type;  // or 'QR' does not have the same dtype as input
  // we have to allocate a temporary tensor

  copy_needed |= (tau.numel() != 0 && !tau.is_contiguous());
  copy_needed |= (tau.numel() != 0 && !tau_equal_expected_shape); // or 'tau' does not have the expected shape
  copy_needed |= !tau_input_same_type;  // or 'tau' does not have the same dtype as input

  if (copy_needed) {
    Tensor QR_tmp = at::empty({0}, input.options());
    Tensor tau_tmp = at::empty({0}, input.options());

    geqrf_out_helper(input, QR_tmp, tau_tmp);

    at::native::resize_output(QR, QR_tmp.sizes());
    QR.copy_(QR_tmp);
    at::native::resize_output(tau, tau_tmp.sizes());
    tau.copy_(tau_tmp);
  } else {
    // use "out" tensors' storage directly
    geqrf_out_helper(input, QR, tau);
  }

  return std::tuple<Tensor&, Tensor&>(QR, tau);
}

std::tuple<Tensor, Tensor> geqrf(const Tensor& input) {
  Tensor QR = at::empty({0}, input.options());
  Tensor tau = at::empty({0}, input.options());
  std::tie(QR, tau) = at::geqrf_outf(input, QR, tau);
  return std::make_tuple(std::move(QR), std::move(tau));
}

/*
Computes the QR decomposition using GEQRF and ORGQR operations.
This is an in-place function and Q, R tensors must have correct shape and be Fortran contiguous.

Args:
* `input` - [in] Input tensor for QR decomposition
* `Q` - [out] Tensor containing the Q matrices of QR decomposition
* `R` - [out] Tensor containing the R matrices of QR decomposition
* `compute_q` - controls whether the Q tensor is computed
* `reduced_mode` - controls the size of Q and R tensors

For further details, please see the LAPACK documentation for GEQRF and ORGQR.
*/
TORCH_IMPL_FUNC(linalg_qr_out)(const Tensor& A,
                               c10::string_view mode,
                               const Tensor & Q,
                               const Tensor & R) {
  auto m = A.size(-2);
  auto n = A.size(-1);
  auto k = std::min(m, n);
  auto [compute_q, reduced_mode] = at::native::_parse_qr_mode(mode);

  // We need an auxiliary tensor to call geqrf
  auto tau_shape = A.sizes().vec();
  tau_shape.pop_back();
  tau_shape.back() = k;
  auto tau = A.new_empty(tau_shape);

  // geqrf requires m x n workspace input that is modified in-place
  // We try to use Q. If it doesn't fit, we try to use R
  // If m > n and compute_q == false, it won't fit into Q or R, so we need to create an auxiliary tensor
  Tensor QR;
  if (compute_q && Q.size(-1) == n) {
    QR = Q;
    QR.copy_(A);
  } else if (R.size(-2) == m) {
    QR = R;
    QR.copy_(A);
  } else {
    QR = cloneBatchedColumnMajor(A);
  }

  geqrf_stub(A.device().type(), QR, tau);

  // Split QR into Q (unless compute_q == false) and R
  if (QR.is_alias_of(R)) {
    // Copy QR into Q
    if (compute_q) {
      // If compute_q == true and the result didn't fit in Q, it is because Q is not of size m x n (i.e. it's of size m x m)
      TORCH_INTERNAL_ASSERT(Q.size(-1) == m);
      if (m < n) {
        Q.copy_(QR.slice(-1, 0, m));
      } else {
        Q.slice(-1, 0, n).copy_(QR);
      }
    }
    R.triu_();
  } else {
    // Copy QR into R from Q or the aux tensor
    at::triu_out(const_cast<Tensor&>(R), QR.slice(-2, 0, n));
  }

  if (compute_q) {
    // Next perform ORGQR for Q using the result from GEQRF
    orgqr_stub(A.device().type(), const_cast<Tensor&>(Q), tau);
  }
}


std::tuple<Tensor,Tensor> qr(const Tensor& self, bool some) {
  TORCH_WARN_ONCE(
    "torch.qr is deprecated in favor of torch.linalg.qr and will be removed in a future PyTorch release.\n",
    "The boolean parameter 'some' has been replaced with a string parameter 'mode'.\n",
    "Q, R = torch.qr(A, some)\n",
    "should be replaced with\n",
    "Q, R = torch.linalg.qr(A, 'reduced' if some else 'complete')"
  );
  const char* mode = some ? "reduced" : "complete";
  return at::linalg_qr(self, mode);
}

std::tuple<Tensor&,Tensor&> qr_out(const Tensor& self, bool some, Tensor& Q, Tensor& R) {
  TORCH_WARN_ONCE(
    "torch.qr is deprecated in favor of torch.linalg.qr and will be removed in a future PyTorch release.\n",
    "The boolean parameter 'some' has been replaced with a string parameter 'mode'.\n",
    "Q, R = torch.qr(A, some)\n",
    "should be replaced with\n",
    "Q, R = torch.linalg.qr(A, 'reduced' if some else 'complete')"
  );
  const char* mode = some ? "reduced" : "complete";
  return at::linalg_qr_out(Q, R, self, mode);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ orgqr ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

DEFINE_DISPATCH(orgqr_stub);

/*
The householder_product (orgqr) function allows reconstruction of an orthogonal (or unitary) matrix Q
from a sequence of elementary reflectors, such as those produced by the geqrf function.
2452
2453 Args:
2454 * `input` - Tensor with the directions of the elementary reflectors below the diagonal.
2455 * `tau` - Tensor containing the magnitudes of the elementary reflectors.
2456 * `result` - result Tensor, which will contain the orthogonal (or unitary) matrix Q.
2457
2458 For further details, please see the LAPACK/MAGMA documentation.
2459 */
householder_product_out_helper(const Tensor & input,const Tensor & tau,Tensor & result)2460 static Tensor& householder_product_out_helper(const Tensor& input, const Tensor& tau, Tensor& result) {
2461 TORCH_INTERNAL_ASSERT(input.dim() >= 2);
2462 TORCH_INTERNAL_ASSERT(input.size(-2) >= input.size(-1));
2463 TORCH_INTERNAL_ASSERT(input.size(-1) >= tau.size(-1));
2464
2465 TORCH_INTERNAL_ASSERT(input.scalar_type() == tau.scalar_type());
2466 TORCH_INTERNAL_ASSERT(input.device() == tau.device());
2467
2468 TORCH_INTERNAL_ASSERT(result.scalar_type() == input.scalar_type());
2469 TORCH_INTERNAL_ASSERT(result.device() == input.device());
2470
2471 // if result has no elements we can modify it
2472 if (result.numel() == 0) {
2473 at::native::resize_as_(result, input.mT(), MemoryFormat::Contiguous);
2474 result.transpose_(-2, -1);
2475 }
2476
2477 // result tensor must be in batched column major order (Fortran contiguous)
2478 TORCH_INTERNAL_ASSERT(result.mT().is_contiguous());
2479 TORCH_INTERNAL_ASSERT(result.sizes().equals(input.sizes()));
2480
2481 // tau tensor must be contiguous
2482 Tensor tau_ = tau;
2483 if (!tau.is_contiguous()) {
2484 tau_ = at::empty(tau.sizes(), tau.options(), MemoryFormat::Contiguous);
2485 tau_.copy_(tau);
2486 }
2487
2488 // orgqr_stub (apply_orgqr) performs calculations in-place and result must be a copy of input
2489 result.copy_(input);
2490
2491 result = orgqr_stub(result.device().type(), result, tau_);
2492 return result;
2493 }
2494
2495 Tensor& linalg_householder_product_out(const Tensor& input, const Tensor& tau, Tensor& result) {
2496 TORCH_CHECK(input.dim() >= 2, "torch.linalg.householder_product: input must have at least 2 dimensions.");
2497 TORCH_CHECK(
2498 input.size(-2) >= input.size(-1),
2499 "torch.linalg.householder_product: input.shape[-2] must be greater than or equal to input.shape[-1]");
2500 TORCH_CHECK(
2501 input.size(-1) >= tau.size(-1),
2502 "torch.linalg.householder_product: input.shape[-1] must be greater than or equal to tau.shape[-1]");
2503
2504 TORCH_CHECK(
2505 input.dim() - tau.dim() == 1,
2506 "torch.linalg.householder_product: Expected tau to have one dimension less than input, but got tau.ndim equal to ",
2507 tau.dim(),
2508 " and input.ndim is equal to ",
2509 input.dim());
2510 if (input.dim() > 2) {
2511 auto expected_batch_tau_shape = IntArrayRef(input.sizes().data(), input.dim() - 2); // input.shape[:-2]
2512 auto actual_batch_tau_shape = IntArrayRef(tau.sizes().data(), tau.dim() - 1); // tau.shape[:-1]
2513 TORCH_CHECK(
2514 actual_batch_tau_shape.equals(expected_batch_tau_shape),
2515 "torch.linalg.householder_product: Expected batch dimensions of tau to be equal to input.shape[:-2], but got ",
2516 actual_batch_tau_shape);
2517 }
2518
2519 TORCH_CHECK(
2520 tau.scalar_type() == input.scalar_type(),
2521 "torch.linalg.householder_product: tau dtype ",
2522 tau.scalar_type(),
2523 " does not match input dtype ",
2524 input.scalar_type());
2525 checkSameDevice("torch.linalg.householder_product", tau, input, "tau");
2526 checkSameDevice("torch.linalg.householder_product", result, input);
2527 checkLinalgCompatibleDtype("torch.linalg.householder_product", result, input);
2528
2529 // TODO: uncomment the following when passing incorrectly sized 'result' is not allowed
2530 // if (result.numel() != 0) {
2531 // // Resize messes up the strides, so let's not use at::native::resize_output
2532 // TORCH_CHECK(result.sizes().equals(input.sizes()),
2533 // "result shape ", result.sizes(), " does not match input shape ", input.sizes());
2534 // }
2535
2536 bool result_input_same_type = (result.scalar_type() == input.scalar_type());
2537 bool result_equal_expected_shape = result.sizes().equals(input.sizes());
2538 bool is_batched_column_major = false;
2539 if (result.dim() >= 2) {
2540 is_batched_column_major = result.mT().is_contiguous();
2541 }
2542
2543 // if result is not empty and not in batched column major format
2544 bool copy_needed = (result.numel() != 0 && !is_batched_column_major);
2545 copy_needed |= !result_input_same_type; // or result does not have the same dtype as input
2546 copy_needed |= (result.numel() != 0 && !result_equal_expected_shape); // or result does not have the expected shape
2547 // we have to allocate a temporary tensor
2548 if (copy_needed) {
2549 Tensor result_tmp = at::empty({0}, input.options());
2550 result_tmp = householder_product_out_helper(input, tau, result_tmp);
2551 at::native::resize_output(result, result_tmp.sizes());
2552 result.copy_(result_tmp);
2553 } else {
2554 // use result's storage directly
2555 result = householder_product_out_helper(input, tau, result);
2556 }
2557
2558 return result;
2559 }
2560
2561 Tensor linalg_householder_product(const Tensor& input, const Tensor& tau) {
2562 Tensor result = at::empty({0}, input.options());
2563 result = at::linalg_householder_product_outf(input, tau, result);
2564 return result;
2565 }
2566
2567 // torch.orgqr is an alias of torch.linalg.householder_product
2568 // torch.linalg.householder_product is the preferred new function
2569 Tensor& orgqr_out(const Tensor& input, const Tensor& tau, Tensor& result) {
2570 return at::linalg_householder_product_outf(input, tau, result);
2571 }
2572
2573 Tensor orgqr(const Tensor& input, const Tensor& tau) {
2574 return at::linalg_householder_product(input, tau);
2575 }
2576
2577 DEFINE_DISPATCH(ormqr_stub);
2578
2579 static void ormqr_out_helper(const Tensor& input, const Tensor& tau, const Tensor& other, const Tensor& result, bool left, bool transpose) {
2580 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.dim() >= 2);
2581 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other.dim() >= 2);
2582
2583 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other.size(left ? -2 : -1) >= tau.size(-1));
2584 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other.size(left ? -2 : -1) == input.size(-2));
2585
2586 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.scalar_type() == tau.scalar_type());
2587 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.device() == tau.device());
2588
2589 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.scalar_type() == other.scalar_type());
2590 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.device() == other.device());
2591
2592 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.scalar_type() == input.scalar_type());
2593 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.device() == input.device());
2594
2595 // if 'result' has no elements we can modify it
2596 if (result.numel() == 0) {
2597 at::native::resize_as_(result, other.mT(), MemoryFormat::Contiguous);
2598 result.transpose_(-2, -1);
2599 }
2600
2601 // 'result' tensor must be in batched column major order (Fortran contiguous)
2602 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.mT().is_contiguous());
2603 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.sizes().equals(other.sizes()));
2604
2605 // 'tau' tensor must be contiguous
2606 Tensor tau_ = tau;
2607 if (!tau.is_contiguous()) {
2608 tau_ = at::empty(tau.sizes(), tau.options(), MemoryFormat::Contiguous);
2609 tau_.copy_(tau);
2610 }
2611
2612 // 'input' tensor must be Fortran contiguous
2613 Tensor input_ = input;
2614 if (!input.mT().is_contiguous()) {
2615 input_ = at::empty(input.mT().sizes(), input.options(), MemoryFormat::Contiguous);
2616 input_.transpose_(-2, -1);
2617 input_.copy_(input);
2618 }
2619
2620 // ormqr_stub (apply_ormqr) performs calculations in-place and 'result' must be a copy of 'other'
2621 result.copy_(other);
2622
2623 ormqr_stub(result.device().type(), input_, tau_, result, left, transpose);
2624 }
2625
2626 Tensor& ormqr_out(const Tensor& input, const Tensor& tau, const Tensor& other, bool left, bool transpose, Tensor& result) {
2627 TORCH_CHECK(input.dim() >= 2, "torch.ormqr: input must have at least 2 dimensions.");
2628 TORCH_CHECK(other.dim() >= 2, "torch.ormqr: other must have at least 2 dimensions.");
2629
2630 int64_t left_size_condition = left ? -2 : -1;
2631 TORCH_CHECK(
2632 other.size(left_size_condition) >= tau.size(-1),
2633 "torch.ormqr: other.shape[",
2634 left_size_condition,
2635 "] must be greater than or equal to tau.shape[-1]");
2636
2637 TORCH_CHECK(
2638 other.size(left_size_condition) == input.size(-2),
2639 "torch.ormqr: other.shape[",
2640 left_size_condition,
2641 "] must be equal to input.shape[-2]");
2642
2643 TORCH_CHECK(
2644 tau.size(-1) <= input.size(-1),
2645 "torch.ormqr: tau.shape[-1] must be less than or equal to input.shape[-1]");
2646
2647 TORCH_CHECK(
2648 input.dim() - tau.dim() == 1,
2649 "torch.ormqr: ",
2650 "Expected tau to have one dimension less than input, but got tau.ndim equal to ",
2651 tau.dim(),
2652 " and input.ndim is equal to ",
2653 input.dim());
2654 TORCH_CHECK(
2655 input.dim() == other.dim(),
2656 "torch.ormqr: ",
2657 "Expected other to have the same number of dimensions as input, but got other.ndim equal to ",
2658 other.dim(),
2659 " and input.ndim is equal to ",
2660 input.dim());
2661
2662 if (input.dim() > 2) {
2663 auto expected_batch_shape = IntArrayRef(input.sizes().data(), input.dim() - 2); // input.shape[:-2]
2664 auto actual_batch_tau_shape = IntArrayRef(tau.sizes().data(), tau.dim() - 1); // tau.shape[:-1]
2665 TORCH_CHECK(
2666 actual_batch_tau_shape.equals(expected_batch_shape),
2667 "torch.ormqr: Expected batch dimensions of tau to be equal to input.shape[:-2], but got ",
2668 actual_batch_tau_shape);
2669
2670 auto actual_batch_other_shape = IntArrayRef(other.sizes().data(), other.dim() - 2); // other.shape[:-2]
2671 TORCH_CHECK(
2672 actual_batch_other_shape.equals(expected_batch_shape),
2673 "torch.ormqr: Expected batch dimensions of other to be equal to input.shape[:-2], but got ",
2674 actual_batch_other_shape);
2675 }
2676
2677 TORCH_CHECK(
2678 tau.scalar_type() == input.scalar_type(),
2679 "torch.ormqr: Expected input and tau to have the same dtype, but input has dtype", input.scalar_type(),
2680 " and tau has dtype ", tau.scalar_type());
2681 TORCH_CHECK(
2682 other.scalar_type() == input.scalar_type(),
2683 "torch.ormqr: Expected input and other to have the same dtype, but input has dtype", input.scalar_type(),
2684 " and other has dtype ", other.scalar_type());
2685 TORCH_CHECK(
2686 result.scalar_type() == input.scalar_type(),
2687 "torch.ormqr: Expected input and result to have the same dtype, but input has dtype", input.scalar_type(),
2688 " and result has dtype ", result.scalar_type());
2689
2690 checkSameDevice("torch.ormqr", tau, input, "tau");
2691 checkSameDevice("torch.ormqr", other, input, "other");
2692 checkSameDevice("torch.ormqr", result, input);
2693
2694 bool result_equal_expected_shape = result.sizes().equals(other.sizes());
2695 bool is_batched_column_major = false;
2696 if (result.dim() >= 2) {
2697 is_batched_column_major = result.mT().is_contiguous();
2698 }
2699
2700 // if result is not empty and not in batched column major format
2701 bool copy_needed = (result.numel() != 0 && !is_batched_column_major);
2702 copy_needed |= (result.numel() != 0 && !result_equal_expected_shape); // or result does not have the expected shape
2703 // we have to allocate a temporary tensor
2704 if (copy_needed) {
2705 Tensor result_tmp = at::empty({0}, input.options());
2706 ormqr_out_helper(input, tau, other, result_tmp, left, transpose);
2707 at::native::resize_output(result, result_tmp.sizes());
2708 result.copy_(result_tmp);
2709 } else {
2710 // use result's storage directly
2711 ormqr_out_helper(input, tau, other, result, left, transpose);
2712 }
2713
2714 return result;
2715 }
2716
2717 Tensor ormqr(const Tensor& input, const Tensor& tau, const Tensor& other, bool left, bool transpose) {
2718 Tensor result = at::empty({0}, input.options());
2719 result = at::native::ormqr_out(input, tau, other, left, transpose, result);
2720 return result;
2721 }
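// Illustrative sketch (hypothetical tensors): ormqr applies the Q implicit in a geqrf
// factorization to another matrix without materializing Q, subject to the shape checks
// in ormqr_out above:
//   Tensor A = at::randn({5, 3});
//   Tensor C = at::randn({5, 2});  // C.size(-2) == A.size(-2) is required for left=true
//   auto [reflectors, tau] = at::geqrf(A);
//   Tensor QC  = at::ormqr(reflectors, tau, C, /*left=*/true, /*transpose=*/false);  // Q   @ C
//   Tensor QhC = at::ormqr(reflectors, tau, C, /*left=*/true, /*transpose=*/true);   // Q^H @ C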
2722
2723 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_eigh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2724
2725 DEFINE_DISPATCH(linalg_eigh_stub);
2726
2727 /*
2728 Computes eigenvalues and eigenvectors of the tensor 'input'.
2729
2730 Args:
2731 * 'input' - input Tensor for eigendecomposition
2732 * 'values' - Tensor to store computed eigenvalues
2733 * 'vectors' - Tensor to store computed eigenvectors
2734 * 'infos' - Tensor to store LAPACK/MAGMA/cuSOLVER error codes
2735 * 'compute_eigenvectors' - controls whether eigenvectors should be computed
2736 * 'uplo' - controls the portion of input matrix to consider in computations, allowed values are "u", "U", "l", "L"
2737 "u", "U" - upper triangular portion of the input matrix is used in computations; "l", "L" - lower.
2738 */
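// Illustrative sketch (hypothetical tensor): the public API built on top of this kernel:
//   Tensor A = at::randn({3, 3});
//   A = A + A.mT();                                  // make A symmetric
//   auto [L, V] = at::linalg_eigh(A, /*uplo=*/"L");  // L: ascending eigenvalues, V: eigenvectors
//   // V.matmul(at::diag_embed(L)).matmul(V.mH()) approximately reconstructs A.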
2739
2740 TORCH_IMPL_FUNC(_linalg_eigh_out)(const Tensor& A,
2741 c10::string_view uplo,
2742 bool compute_v,
2743 const Tensor& L,
2744 const Tensor& V) {
2745 if (A.numel() == 0) {
2746 return;
2747 }
2748
2749 auto uplo_uppercase = static_cast<char>(std::toupper(static_cast<unsigned char>(uplo[0])));
2750 bool upper = (uplo_uppercase == 'U');
2751
2752 Tensor V_ = V;
2753 if (compute_v) {
2754 V_.copy_(A);
2755 } else {
2756 // We need a tensor to hold A
2757 V_ = cloneBatchedColumnMajor(A);
2758 }
2759
2760 const auto info = at::zeros(A.sizes().slice(0, A.dim() - 2), A.options().dtype(kInt));
2761 linalg_eigh_stub(A.device().type(), L, V_, info, upper, compute_v);
2762
2763 at::_linalg_check_errors(info, "linalg.eigh", /*is_matrix*/A.dim() == 2);
2764 }
2765
2766 std::tuple<Tensor, Tensor> linalg_eigh(const Tensor& A, c10::string_view uplo) {
2767 // TODO (Good intro task) Implement linalg_eigh_ex_out
2768 return at::_linalg_eigh(A, uplo, /*compute_v*/true);
2769 }
2770
2771 std::tuple<Tensor&, Tensor&> linalg_eigh_out(const Tensor& A, c10::string_view uplo, Tensor& L, Tensor& V) {
2772 return at::_linalg_eigh_out(L, V, A, uplo, /*compute_v=*/true);
2773 }
2774
2775
2776 Tensor linalg_eigvalsh(const Tensor& A, c10::string_view uplo) {
2777 return std::get<0>(at::_linalg_eigh(A, uplo,
2778 /*compute_v=*/_may_require_fw_or_bw_grad(A)));
2779 }
2780
2781 Tensor& linalg_eigvalsh_out(const Tensor& A, c10::string_view uplo, Tensor& L) {
2782 auto V = at::empty({0}, A.options());
2783 at::_linalg_eigh_out(L, V, A, uplo, /*compute_v=*/false);
2784 return L;
2785 }
2786
2787 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_eig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2788
2789 // This function returns the complex-valued eigenvectors obtained from LAPACK GEEV's real-valued output
2790 // It is also used for the MAGMA path because MAGMA's intermediate results live on the CPU
2791 template <typename scalar_t>
2792 static void linalg_eig_make_complex_eigenvectors_impl(Tensor& result, const Tensor& complex_values, const Tensor& real_vectors) {
2793 // From GEEV documentation:
2794 // Complex conjugate pairs of eigenvalues appear consecutively with the eigenvalue having the positive imaginary part first
2795 // If the j-th eigenvalue is real, then v(j) = VR(:,j), the j-th column of VR.
2796 // If the j-th and (j+1)-st eigenvalues form a complex conjugate pair, then v(j) = VR(:,j) + i*VR(:,j+1) and v(j+1) = VR(:,j) - i*VR(:,j+1).
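// Concrete illustration (values are hypothetical and depend on GEEV's normalization):
// for A = [[0, -1], [1, 0]], GEEV may return eigenvalues (i, -i) and
//   VR = [[0.7071,  0.0000],
//         [0.0000, -0.7071]]
// so v(0) = VR(:,0) + i*VR(:,1) = [0.7071, -0.7071i] and v(1) = conj(v(0)).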
2797
2798 auto batch_size = batchCount(real_vectors);
2799 auto n = real_vectors.size(-1);
2800 auto matrix_stride = matrixStride(real_vectors);
2801
2802 auto result_data = result.data_ptr<c10::complex<scalar_t>>();
2803 auto real_vectors_data = real_vectors.const_data_ptr<scalar_t>();
2804 auto values_data = complex_values.const_data_ptr<c10::complex<scalar_t>>();
2805
2806 for (auto b = decltype(batch_size){0}; b < batch_size; b++) {
2807 const scalar_t* vecs = &real_vectors_data[b * matrix_stride];
2808 c10::complex<scalar_t>* res = &result_data[b * matrix_stride];
2809 const c10::complex<scalar_t>* vals = &values_data[b * n];
2810 for (auto j = decltype(n){0}; j < n; j++) {
2811 if (vals[j].imag() == 0.0) { // eigenvalue is real, then v(j) = VR(:,j)
2812 for (auto i = decltype(n){0}; i < n; i++) {
2813 res[j * n + i] = c10::complex<scalar_t>(vecs[j * n + i], 0);
2814 }
2815 } else {
2816 for (auto i = decltype(n){0}; i < n; i++) {
2817 res[j * n + i] = c10::complex<scalar_t>(vecs[j * n + i], vecs[(j+1) * n + i]); // v(j) = VR(:,j) + i*VR(:,j+1)
2818 res[(j+1) * n + i] = c10::complex<scalar_t>(vecs[j * n + i], -vecs[(j+1) * n + i]); // v(j+1) = VR(:,j) - i*VR(:,j+1)
2819 }
2820 j++;
2821 }
2822 }
2823 }
2824 }
2825
2826 static Tensor& linalg_eig_make_complex_eigenvectors(Tensor& complex_vectors, const Tensor& complex_values, const Tensor& real_vectors) {
2827 // These asserts make explicit the requirements on tensors for 'linalg_eig_make_complex_eigenvectors_impl'
2828 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(complex_vectors.device() == at::kCPU);
2829 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(complex_values.device() == at::kCPU);
2830 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(real_vectors.device() == at::kCPU);
2831
2832 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(complex_vectors.is_complex());
2833 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(complex_values.is_complex());
2834 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(real_vectors.is_floating_point());
2835
2836 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(complex_vectors.mT().is_contiguous());
2837 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(complex_values.is_contiguous());
2838 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(real_vectors.mT().is_contiguous());
2839
2840 AT_DISPATCH_FLOATING_TYPES(real_vectors.scalar_type(), "linalg_eig_make_complex_vector", [&]{
2841 linalg_eig_make_complex_eigenvectors_impl<scalar_t>(complex_vectors, complex_values, real_vectors);
2842 });
2843 return complex_vectors;
2844 }
2845
2846 DEFINE_DISPATCH(linalg_eig_stub);
2847
2848 static std::tuple<Tensor&, Tensor&> linalg_eig_out_info(const Tensor& input, Tensor& values, Tensor& vectors, Tensor& infos, bool compute_eigenvectors) {
2849 // MAGMA doesn't have a GPU interface for the GEEV routine; it requires inputs to be on the CPU,
2850 // therefore we create all intermediate tensors on the CPU
2851 auto options = input.options().device(at::kCPU);
2852
2853 // These internal asserts make explicit the assumptions in the implementation
2854 // Error checks with the actual error messages are done at a higher level in the hierarchy of calls
2855 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.dim() >= 2);
2856 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.size(-2) == input.size(-1));
2857
2858 // for real-valued 'input', eigenvalues can be real-valued or complex-valued
2859 TORCH_INTERNAL_ASSERT_DEBUG_ONLY((toComplexType(input.scalar_type()) == values.scalar_type()) || (input.scalar_type() == values.scalar_type()));
2860 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.device() == at::kCPU);
2861
2862 // for real-valued 'input', eigenvectors can be real-valued or complex-valued
2863 if (compute_eigenvectors) {
2864 TORCH_INTERNAL_ASSERT_DEBUG_ONLY((toComplexType(input.scalar_type()) == vectors.scalar_type()) || (input.scalar_type() == vectors.scalar_type()));
2865 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(vectors.device() == at::kCPU);
2866 }
2867
2868 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.scalar_type() == at::kInt);
2869 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.device() == at::kCPU);
2870 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.numel() == std::max<int64_t>(1, batchCount(input)));
2871 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.is_contiguous());
2872
2873 // if 'vectors' has no elements we can modify it
2874 if (vectors.numel() == 0 && compute_eigenvectors) {
2875 vectors.resize_(input.sizes(), MemoryFormat::Contiguous);
2876 vectors.transpose_(-2, -1); // make 'vectors' have a Fortran-contiguous memory layout
2877 }
2878
2879 // if 'values' has no elements we can modify it
2880 auto values_shape = IntArrayRef(input.sizes().data(), input.dim()-1); // input.shape[:-1]
2881 if (values.numel() == 0) {
2882 values.resize_(values_shape, MemoryFormat::Contiguous);
2883 }
2884
2885 // 'vectors' must be in batched column major order (Fortran contiguous)
2886 if (compute_eigenvectors) {
2887 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(vectors.mT().is_contiguous());
2888 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(vectors.sizes().equals(input.sizes()));
2889 }
2890
2891 // 'values' must be contiguous
2892 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.is_contiguous());
2893 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.sizes().equals(values_shape));
2894
2895 // if 'input' is complex then use 'values' directly else create a temporary to hold the real and imaginary parts
2896 // and then use at::complex_out
2897 Tensor real_imag_values = values;
2898
2899 // if 'input' is complex then use 'vectors' directly else maybe create a temporary to hold real vectors
2900 // and then use linalg_eig_make_complex_eigenvectors
2901 Tensor maybe_complex_vectors = vectors;
2902 if (!input.is_complex()) {
2903 // first n elements to hold the real portion of the output and the last n elements to hold the imaginary portion
2904 auto real_imag_shape = IntArrayRef(input.sizes().data(), input.dim()-2).vec(); // input.shape[:-2]
2905 real_imag_shape.push_back(input.size(-1) * 2);
2906 real_imag_values = at::empty(real_imag_shape, options, MemoryFormat::Contiguous);
2907
2908 // linalg_eig_stub expects a real-valued tensor to store the eigenvectors
2909 // the output of linalg_eig_stub needs to be post-processed later to produce complex-valued eigenvectors
2910 // we do this post-processing only if 'vectors' is complex-valued
2911 // otherwise storage of 'vectors' is used directly
2912 if (vectors.is_complex() && compute_eigenvectors) {
2913 maybe_complex_vectors = at::empty(input.sizes(), options, MemoryFormat::Contiguous);
2914 maybe_complex_vectors.transpose_(-2, -1); // make 'maybe_complex_vectors' have a Fortran-contiguous memory layout
2915 }
2916 }
2917
2918 // MAGMA uses a hybrid CPU-GPU algorithm that performs well only for large matrices
2919 // See: https://github.com/pytorch/pytorch/pull/52491#issuecomment-795685687
2920 // Here we call the CPU path for matrices smaller than 2048x2048,
2921 // which should in general be significantly faster than calling MAGMA
2922 if (input.size(-1) <= 2048) {
2923 linalg_eig_stub(at::kCPU, real_imag_values, maybe_complex_vectors, infos, input.to(kCPU), compute_eigenvectors);
2924 } else {
2925 linalg_eig_stub(input.device().type(), real_imag_values, maybe_complex_vectors, infos, input, compute_eigenvectors);
2926 }
2927
2928 // if input is not complex we need to do some post-processing
2929 if (!input.is_complex()) {
2930 // extract real and imaginary parts of the output
2931 auto real_values = real_imag_values.slice(/*dim=*/-1, /*start=*/0, /*end*/input.size(-1));
2932 auto imag_values = real_imag_values.slice(/*dim=*/-1, /*start=*/input.size(-1));
2933
2934 // if the imaginary part is zero we don't need to do anything
2935 bool is_zero_imag = at::all(imag_values == 0.0).item().toBool();
2936 if (is_zero_imag) {
2937 values.copy_(real_values);
2938 if (compute_eigenvectors) {
2939 vectors.copy_(maybe_complex_vectors); // does nothing for !vectors.is_complex() because vectors.is_same(maybe_complex_vectors) == true
2940 }
2941 return std::tuple<Tensor&, Tensor&>(values, vectors);
2942 }
2943
2944 if (values.is_complex()) {
2945 values = at::complex_out(values, real_values, imag_values);
2946 } else {
2947 TORCH_CHECK(false, "torch.linalg.eig: imaginary part of eigenvalues is non-zero, can't safely cast eigenvalues to non-complex dtype.")
2948 }
2949 if (compute_eigenvectors) {
2950 if (vectors.is_complex()) {
2951 vectors = linalg_eig_make_complex_eigenvectors(vectors, values, maybe_complex_vectors);
2952 } else {
2953 TORCH_CHECK(false, "torch.linalg.eig: imaginary part of eigenvectors is non-zero, can't safely cast eigenvectors to non-complex dtype.")
2954 }
2955 }
2956 }
2957
2958 return std::tuple<Tensor&, Tensor&>(values, vectors);
2959 }
2960
2961 std::tuple<Tensor&, Tensor&> linalg_eig_out(const Tensor& input, Tensor& values, Tensor& vectors) {
2962 TORCH_CHECK(input.isfinite().all().item<bool>(), "torch.linalg.eig: input tensor should not contain infs or NaNs.");
2963 squareCheckInputs(input, "linalg.eig");
2964
2965 // unlike NumPy, for real-valued inputs the output is always complex-valued
2966 checkLinalgCompatibleDtype("torch.linalg.eig", values.scalar_type(), toComplexType(input.scalar_type()), "eigenvalues");
2967 checkLinalgCompatibleDtype("torch.linalg.eig", vectors.scalar_type(), toComplexType(input.scalar_type()), "eigenvectors");
2968 checkSameDevice("torch.linalg.eig", values, input, "eigenvalues");
2969 checkSameDevice("torch.linalg.eig", vectors, input, "eigenvectors");
2970
2971 // MAGMA doesn't have a GPU interface for the GEEV routine; it requires inputs to be on the CPU
2972 auto options = input.options().device(at::kCPU);
2973 auto infos = at::zeros({std::max<int64_t>(1, batchCount(input))}, options.dtype(kInt));
2974
2975 // if result is not empty and not in batched column major format we have to allocate a temporary tensor
2976 bool is_batched_column_major = false;
2977 if (vectors.dim() >= 2) {
2978 is_batched_column_major = vectors.mT().is_contiguous();
2979 }
2980
2981 bool values_expected_type = (values.scalar_type() == toComplexType(input.scalar_type()));
2982 bool vectors_expected_type = (vectors.scalar_type() == toComplexType(input.scalar_type()));
2983
2984 auto expected_values_shape = IntArrayRef(input.sizes().data(), input.dim()-1); // input.shape[:-1]
2985 bool values_equal_expected_shape = values.sizes().equals(expected_values_shape);
2986 bool vectors_equal_expected_shape = vectors.sizes().equals(input.sizes());
2987
2988 // if result is not empty and not in batched column major format
2989 bool values_tmp_needed = (values.numel() != 0 && !values.is_contiguous());
2990 bool vectors_tmp_needed = (vectors.numel() != 0 && !is_batched_column_major);
2991 // or result does not have the expected shape
2992 values_tmp_needed |= (values.numel() != 0 && !values_equal_expected_shape);
2993 vectors_tmp_needed |= (vectors.numel() != 0 && !vectors_equal_expected_shape);
2994 // or result does not have the expected dtype
2995 values_tmp_needed |= !values_expected_type;
2996 vectors_tmp_needed |= !vectors_expected_type;
2997 // we will allocate a temporary tensor and do the copy
2998
2999 // because MAGMA's GEEV takes CPU inputs and returns CPU outputs
3000 // "out" tensors that are on GPU device can't be used directly
3001 values_tmp_needed |= values.is_cuda();
3002 vectors_tmp_needed |= vectors.is_cuda();
3003
3004 // determine the appropriate scalar_type for the temporary tensors
3005 ScalarType values_type = input.scalar_type();
3006 ScalarType vectors_type = input.scalar_type();
3007 if (!input.is_complex()) {
3008 // for real-valued input we can have either real- or complex-valued output
3009 ScalarType input_complex_dtype = toComplexType(input.scalar_type());
3010 values_type = values.is_complex() ? input_complex_dtype : values_type;
3011 vectors_type = vectors.is_complex() ? input_complex_dtype : vectors_type;
3012 }
3013
3014 if (values_tmp_needed && vectors_tmp_needed) {
3015 Tensor values_tmp = at::empty({0}, options.dtype(values_type));
3016 Tensor vectors_tmp = at::empty({0}, options.dtype(vectors_type));
3017 std::tie(values_tmp, vectors_tmp) = linalg_eig_out_info(input, values_tmp, vectors_tmp, infos, true);
3018 at::native::resize_output(values, values_tmp.sizes());
3019 values.copy_(values_tmp);
3020 at::native::resize_output(vectors, vectors_tmp.sizes());
3021 vectors.copy_(vectors_tmp);
3022 } else if (!values_tmp_needed && vectors_tmp_needed) {
3023 // use 'values' storage directly
3024 Tensor vectors_tmp = at::empty({0}, options.dtype(vectors_type));
3025 std::tie(values, vectors_tmp) = linalg_eig_out_info(input, values, vectors_tmp, infos, true);
3026 at::native::resize_output(vectors, vectors_tmp.sizes());
3027 vectors.copy_(vectors_tmp);
3028 } else if (values_tmp_needed && !vectors_tmp_needed) {
3029 // use 'vectors' storage directly
3030 Tensor values_tmp = at::empty({0}, options.dtype(values_type));
3031 std::tie(values_tmp, vectors) = linalg_eig_out_info(input, values_tmp, vectors, infos, true);
3032 at::native::resize_output(values, values_tmp.sizes());
3033 values.copy_(values_tmp);
3034 } else {
3035 // use 'values' and 'vectors' storage directly
3036 std::tie(values, vectors) = linalg_eig_out_info(input, values, vectors, infos, true);
3037 }
3038
3039 // Now check LAPACK/MAGMA error codes
3040 at::_linalg_check_errors(infos, "torch.linalg.eig", input.dim() == 2);
3041 return std::tuple<Tensor&, Tensor&>(values, vectors);
3042 }
3043
3044 std::tuple<Tensor, Tensor> linalg_eig(const Tensor& input) {
3045 ScalarType complex_dtype = toComplexType(input.scalar_type());
3046 Tensor values = at::empty({0}, input.options().dtype(complex_dtype));
3047 Tensor vectors = at::empty({0}, input.options().dtype(complex_dtype));
3048
3049 at::linalg_eig_outf(input, values, vectors);
3050
3051 return std::tuple<Tensor, Tensor>(values, vectors);
3052 }
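// Illustrative sketch (hypothetical tensor): eigendecomposition of a general
// (non-symmetric) matrix; the outputs are always complex, even for real input:
//   Tensor A = at::randn({3, 3});
//   auto [w, V] = at::linalg_eig(A);   // w: complex eigenvalues, V: complex eigenvectors
//   // A.to(w.scalar_type()).matmul(V) approximately equals V.matmul(at::diag_embed(w)).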
3053
3054 Tensor& linalg_eigvals_out(const Tensor& input, Tensor& values) {
3055 squareCheckInputs(input, "linalg.eigvals");
3056
3057 // unlike NumPy, for real-valued inputs the output is always complex-valued
3058 checkLinalgCompatibleDtype("torch.linalg.eigvals", values.scalar_type(), toComplexType(input.scalar_type()), "eigenvalues");
3059 checkSameDevice("torch.linalg.eigvals", values, input, "eigenvalues");
3060
3061 // MAGMA doesn't have a GPU interface for the GEEV routine; it requires inputs to be on the CPU
3062 auto options = input.options().device(at::kCPU);
3063 auto infos = at::zeros({std::max<int64_t>(1, batchCount(input))}, options.dtype(kInt));
3064
3065 bool values_expected_type = (values.scalar_type() == toComplexType(input.scalar_type()));
3066
3067 auto expected_values_shape = IntArrayRef(input.sizes().data(), input.dim()-1); // input.shape[:-1]
3068 bool values_equal_expected_shape = values.sizes().equals(expected_values_shape);
3069
3070 // if result is not empty and not in batched column major format
3071 bool values_tmp_needed = (values.numel() != 0 && !values.is_contiguous());
3072 // or result does not have the expected shape
3073 values_tmp_needed |= (values.numel() != 0 && !values_equal_expected_shape);
3074 // or result does not have the expected dtype
3075 values_tmp_needed |= !values_expected_type;
3076 // we will allocate a temporary tensor and do the copy
3077
3078 // because MAGMA's GEEV takes CPU inputs and returns CPU outputs
3079 // 'values' tensor that is on GPU device can't be used directly
3080 values_tmp_needed |= (!values.is_cpu());
3081
3082 // determine the appropriate scalar_type for the temporary tensors
3083 ScalarType values_type = input.scalar_type();
3084 if (!input.is_complex()) {
3085 // for real-valued input we can have either real- or complex-valued output
3086 ScalarType input_complex_dtype = toComplexType(input.scalar_type());
3087 values_type = values.is_complex() ? input_complex_dtype : values_type;
3088 }
3089
3090 Tensor vectors;
3091 if (values_tmp_needed) {
3092 Tensor values_tmp = at::empty({0}, options.dtype(values_type));
3093 std::tie(values_tmp, std::ignore) = linalg_eig_out_info(input, values_tmp, vectors, infos, /*compute_eigenvectors=*/false);
3094 at::native::resize_output(values, values_tmp.sizes());
3095 values.copy_(values_tmp);
3096 } else { // use 'values' storage directly
3097 std::tie(values, std::ignore) = linalg_eig_out_info(input, values, vectors, infos, /*compute_eigenvectors=*/false);
3098 }
3099
3100 // Now check LAPACK/MAGMA error codes
3101 at::_linalg_check_errors(infos, "torch.linalg.eigvals", input.dim() == 2);
3102 return values;
3103 }
3104
3105 Tensor linalg_eigvals(const Tensor& input) {
3106 // if input requires grad we must compute the eigenvectors to make this function differentiable
3107 // the eigenvectors are not exposed to the user
3108 if (_may_require_fw_or_bw_grad(input)) {
3109 return std::get<0>(at::linalg_eig(input));
3110 }
3111 return at::_linalg_eigvals(input);
3112 }
3113
3114 Tensor _linalg_eigvals(const Tensor& input) {
3115 ScalarType complex_dtype = toComplexType(input.scalar_type());
3116 Tensor values = at::empty({0}, input.options().dtype(complex_dtype));
3117 linalg_eigvals_out(input, values);
3118 return values;
3119 }
3120
3121 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
3122
3123 /* torch.svd, implemented in terms of torch.linalg.svd. There are two main
3124 differences:
3125
3126 1. the 2nd parameter is bool some=True, which is effectively the opposite
3127 of full_matrices=True
3128
3129 2. svd returns V, while linalg.svd returns Vh = V^H
3130 */
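// Illustrative sketch of the mapping (hypothetical tensor):
//   Tensor A = at::randn({4, 3});
//   auto [U1, S1, V1]  = at::svd(A, /*some=*/true);                   // returns V
//   auto [U2, S2, Vh2] = at::linalg_svd(A, /*full_matrices=*/false);  // returns Vh = V^H
//   // U1 == U2, S1 == S2, and V1 == Vh2.mH(); svd() below performs exactly this mapping.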
3131
3132 DEFINE_DISPATCH(svd_stub);
3133
3134 TORCH_IMPL_FUNC(_linalg_svd_out)(const Tensor& A,
3135 const bool full_matrices,
3136 const bool compute_uv,
3137 std::optional<c10::string_view> driver,
3138 const Tensor & U,
3139 const Tensor & S,
3140 const Tensor & Vh) {
3141 // Half optimisation, half precondition for some parts of the LAPACK / cuSOLVER code paths.
3142 // In particular, the call to lapackSvd to compute lwork fails otherwise
3143 if (A.numel() == 0) {
3144 // Needed in the case that we have e.g. A.shape == (3, 0) and full_matrices=True
3145 // We fill U or Vh with the identity matrix as it's a valid SVD for the empty matrix
3146 if (compute_uv && full_matrices) {
3147 if (U.numel() != 0) {
3148 U.zero_();
3149 U.diagonal(0, -2, -1).fill_(1.);
3150 }
3151 if (Vh.numel() != 0) {
3152 Vh.zero_();
3153 Vh.diagonal(0, -2, -1).fill_(1.);
3154 }
3155 }
3156 return;
3157 }
3158
3159 // We need to distinguish the cuSOLVER case, as cuSOLVER expects F-contig matrices, but
3160 // it computes V rather than Vh
3161 const bool use_cusolver = at::native::svd_uses_cusolver(A);
3162 TORCH_CHECK(use_cusolver || !driver.has_value(),
3163 "torch.linalg.svd: keyword argument `driver=` is only supported on CUDA inputs with cuSOLVER backend.");
3164
3165 // A always needs to be copied as its contents will be destroyed during the computation of the SVD
3166 // Now, MAGMA needs the copy to be on CPU, while cuSOLVER needs it to be on CUDA, so we'll defer
3167 // the copy as a column major matrix to the backends.
3168 const auto info = at::zeros(IntArrayRef(A.sizes().begin(), A.sizes().end() - 2), A.options().dtype(kInt));
3169
3170 svd_stub(A.device().type(),
3171 A,
3172 full_matrices,
3173 compute_uv,
3174 driver,
3175 U, S, Vh, info);
3176
3177 // TODO This should be removed, and the code checking for convergence should be lifted
3178 // from svd_cusolver to this function. We should then make sure that this function
3179 // never errors out.
3180 at::_linalg_check_errors(info, "linalg.svd", /*is_matrix*/A.dim() == 2);
3181 }
3182
3183 std::tuple<Tensor&, Tensor&, Tensor&>
3184 linalg_svd_out(const Tensor& A,
3185 bool full_matrices,
3186 std::optional<c10::string_view> driver,
3187 Tensor & U,
3188 Tensor & S,
3189 Tensor & Vh) {
3190 // This function does not have an _ex variant as we always check errors inside
3191 // to ensure the convergence of the algorithm anyway. See
3192 // https://github.com/pytorch/pytorch/issues/28293
3193 // https://github.com/pytorch/pytorch/issues/64237
3194 //
3195 // We must delegate both linalg_svd and linalg_svdvals to
3196 // _linalg_svd (rather than delegating linalg_svdvals to linalg_svd) because
3197 // 1. We don't want to expose the `compute_uv` parameter in svd
3198 // 2. We would like to make use of the `compute_uv=False` optimisation within svdvals
3199 // The only way to achieve these two things and still abide by the compositionality rules
3200 // is by dispatching to another function.
3201 return at::_linalg_svd_out(U, S, Vh, A, full_matrices, /*compute_uv=*/true, driver);
3202 }
3203
3204 std::tuple<Tensor, Tensor, Tensor> linalg_svd(const Tensor& A, bool full_matrices,
3205 std::optional<c10::string_view> driver) {
3206 return at::_linalg_svd(A, full_matrices, /*compute_uv=*/true, driver);
3207 }
3208
3209 // See note in linalg_svd for why this function does not have an _ex variant
3210 Tensor& linalg_svdvals_out(const Tensor& A, std::optional<c10::string_view> driver, Tensor & S) {
3211 // Dummies
3212 auto U = at::empty({0}, A.options());
3213 auto Vh = at::empty({0}, A.options());
3214 at::_linalg_svd_out(U, S, Vh, A, /*full_matrices=*/false, /*compute_uv=*/false, /*driver=*/driver);
3215 return S;
3216 }
3217
3218 Tensor linalg_svdvals(const Tensor& A, std::optional<c10::string_view> driver) {
3219 return std::get<1>(at::_linalg_svd(A, /*full_matrices=*/false,
3220 /*compute_uv=*/_may_require_fw_or_bw_grad(A),
3221 /*driver=*/driver));
3222 }
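// Illustrative sketch (hypothetical tensor): svdvals returns only the singular values and
// takes the compute_uv=false fast path of _linalg_svd when gradients are not required:
//   Tensor A = at::randn({4, 3});
//   Tensor S = at::linalg_svdvals(A);  // shape (3,); same values as std::get<1>(at::linalg_svd(A, /*full_matrices=*/false))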
3223
3224 std::tuple<Tensor&, Tensor&, Tensor&> svd_out(const Tensor& self, bool some, bool compute_uv,
3225 Tensor& U, Tensor& S, Tensor& V) {
3226
3227 if (compute_uv) {
3228 if (V.dim() >= 2) {
3229 V.transpose_(-2, -1);
3230 }
3231 at::linalg_svd_out(U, S, V, self, /*full_matrices=*/!some);
3232 V.transpose_(-2, -1);
3233 if (V.is_complex()) {
3234 // We cannot use `_set_conj` as it does not play well with backwards
3235 V.conj_physical_();
3236 }
3237 } else {
3238 TORCH_CHECK(self.scalar_type() == U.scalar_type(),
3239 "torch.svd: Expected out tensor to have dtype ", self.scalar_type(), " but got ", U.scalar_type(), " instead");
3240
3241 TORCH_CHECK(self.scalar_type() == V.scalar_type(),
3242 "torch.svd: Expected out tensor to have dtype ", self.scalar_type(), " but got ", V.scalar_type(), " instead");
3243
3244 at::linalg_svdvals_out(S, self);
3245 // some == false returns U, Vh of size (m, m), (n, n) full of zeros
3246 const auto m = self.size(-2);
3247 const auto n = self.size(-1);
3248 auto sizes = self.sizes().vec();
3249
3250 sizes.end()[-1] = m;
3251 at::native::resize_output(U, sizes);
3252 U.zero_();
3253
3254 sizes.end()[-2] = n;
3255 sizes.end()[-1] = n;
3256 at::native::resize_output(V, sizes);
3257 V.zero_();
3258 }
3259
3260 return std::tie(U, S, V);
3261 }
3262
3263 std::tuple<Tensor, Tensor, Tensor> svd(const Tensor& self, bool some, bool compute_uv) {
3264 // TODO: uncomment the following when svd is deprecated not only in docs
3265 // torch/xla is blocking the transition from at::svd to at::linalg_svd in at::linalg_pinv code
3266 // see https://github.com/pytorch/xla/issues/2755
3267 // TORCH_WARN_ONCE(
3268 // "torch.svd is deprecated in favor of torch.linalg.svd and will be ",
3269 // "removed in a future PyTorch release.\n",
3270 // "U, S, V = torch.svd(A, some=some, compute_uv=True) (default)\n",
3271 // "should be replaced with\n",
3272 // "U, S, Vh = torch.linalg.svd(A, full_matrices=not some)\n",
3273 // "V = Vh.mH\n",
3274 // "and\n",
3275 // "_, S, _ = torch.svd(A, some=some, compute_uv=False)\n",
3276 // "should be replaced with\n",
3277 // "S = torch.linalg.svdvals(A)");
3278 TORCH_CHECK(self.dim() >= 2, "linalg.svd: input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
3279 Tensor U, S, Vh;
3280 if (compute_uv) {
3281 std::tie(U, S, Vh) = at::linalg_svd(self, /*full_matrices=*/!some);
3282 } else {
3283 S = at::linalg_svdvals(self);
3284 // some == false returns U, Vh of size (m, m), (n, n) full of zeros
3285 const auto m = self.size(-2);
3286 const auto n = self.size(-1);
3287
3288 auto sizes = self.sizes().vec();
3289 sizes.end()[-1] = m;
3290 U = at::zeros(sizes, self.options());
3291 sizes.end()[-2] = n;
3292 sizes.end()[-1] = n;
3293 Vh = at::zeros(sizes, self.options());
3294 }
3295 return std::make_tuple(std::move(U), std::move(S), Vh.mH());
3296 }
3297
3298 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lstsq ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
3299
3300 DEFINE_DISPATCH(lstsq_stub);
3301
3302 /*
3303 Solves a least squares problem, that is, minimizes the squared Frobenius norm ||B - A X||_F^2.
3304
3305 Input args:
3306 * 'input' - Tensor containing batches of m-by-n matrix A.
3307 * 'other' - Tensor containing batches of max(m, n)-by-nrhs matrix B.
3308 * 'cond' - relative tolerance for determining rank of A.
3309 * 'driver' - the name of the LAPACK driver that is used to compute the solution.
3310 Output args (modified in-place):
3311 * 'solution' - Tensor to store the solution matrix X.
3312 * 'residuals' - Tensor to store values of the residual sum of squares for each column of the solution.
3313 * 'rank' - Tensor to store the rank of A.
3314 * 'singular_values' - Tensor to store the singular values of A.
3315 * 'infos' - Tensor to store error codes of linear algebra math library.
3316
3317 For further details, please see the LAPACK documentation for GELS/GELSY/GELSS/GELSD routines.
3318 */
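// Illustrative sketch (hypothetical tensors): the public API that drives this helper:
//   Tensor A = at::randn({5, 3});
//   Tensor B = at::randn({5, 2});
//   auto [X, residuals, rank, sv] =
//       at::linalg_lstsq(A, B, /*rcond=*/c10::nullopt, /*driver=*/"gelsd");
//   // X is 3 x 2; residuals, rank and singular values are filled in as described above,
//   // depending on the driver (e.g. "gels" does not report rank or singular values).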
3319 static void linalg_lstsq_out_info(
3320 Tensor& solution,
3321 Tensor& residuals,
3322 Tensor& rank,
3323 Tensor& singular_values,
3324 Tensor& infos,
3325 const Tensor& input,
3326 const Tensor& other,
3327 double rcond,
3328 std::string& driver) {
3329 // These internal asserts make explicit the assumptions in the implementation
3330 // Error checks with the actual error messages are done at a higher level in
3331 // the hierarchy of calls
3332 TORCH_INTERNAL_ASSERT(input.dim() >= 2);
3333 TORCH_INTERNAL_ASSERT(other.dim() >= 1);
3334
3335 auto dim_diff = input.dim() - other.dim();
3336 TORCH_INTERNAL_ASSERT(0 <= dim_diff && dim_diff <= 1);
3337
3338 TORCH_INTERNAL_ASSERT(input.scalar_type() == other.scalar_type());
3339 TORCH_INTERNAL_ASSERT(input.device() == other.device());
3340
3341 TORCH_INTERNAL_ASSERT(solution.scalar_type() == input.scalar_type());
3342 TORCH_INTERNAL_ASSERT(solution.device() == input.device());
3343
3344 TORCH_INTERNAL_ASSERT(residuals.device() == input.device());
3345
3346 TORCH_INTERNAL_ASSERT(rank.scalar_type() == at::kLong);
3347 TORCH_INTERNAL_ASSERT(rank.device() == input.device());
3348
3349 auto real_dtype = toRealValueType(input.scalar_type());
3350 TORCH_INTERNAL_ASSERT(singular_values.scalar_type() == real_dtype);
3351 TORCH_INTERNAL_ASSERT(singular_values.device() == input.device());
3352
3353 TORCH_INTERNAL_ASSERT(infos.scalar_type() == at::kInt);
3354 TORCH_INTERNAL_ASSERT(infos.device() == input.device());
3355 TORCH_INTERNAL_ASSERT(infos.numel() == std::max<int64_t>(1, batchCount(input)));
3356 TORCH_INTERNAL_ASSERT(infos.is_contiguous());
3357
3358 bool vector_case = linalg_solve_is_vector_rhs(input, other);
3359 // we need to unsqueeze 'other' because 2-dimensional tensors are expected in the implementation
3360 Tensor other_2d = vector_case ? other.unsqueeze(-1) : other;
3361
3362 TORCH_INTERNAL_ASSERT(input.size(-2) == other_2d.size(-2));
3363
3364 std::vector<int64_t> expected_solution_shape = broadcast_batch_size(input, other_2d, input.dim() - 2);
3365 // the actual shape of the solution returned is (*, n,) or (*, n, nrhs)
3366 // but LAPACK requires extra dimensions to store raw residuals
3367 // so the expected shape is (*, max(m, n),) or (*, max(m, n), nrhs)
3368 auto m = input.size(-2);
3369 auto n = input.size(-1);
3370 auto nrhs = other.size(-1);
3371 expected_solution_shape.push_back(std::max(m, n));
3372 if (!vector_case) {
3373 expected_solution_shape.push_back(nrhs);
3374 }
3375
3376 // if 'solution' has no elements we can modify it
3377 if (solution.numel() == 0) {
3378 if (vector_case) {
3379 solution.resize_(expected_solution_shape, MemoryFormat::Contiguous);
3380 } else {
3381 auto shape_transposed = expected_solution_shape;
3382 std::swap(shape_transposed.end()[-1], shape_transposed.end()[-2]);
3383 solution.resize_(shape_transposed, MemoryFormat::Contiguous);
3384 solution.transpose_(-2, -1);
3385 }
3386 }
3387
3388 // if 'solution' is non-empty it must have the expected shape
3389 TORCH_INTERNAL_ASSERT(solution.sizes().equals(expected_solution_shape));
3390
3391 // 'solution' must be in batched column major order (Fortran contiguous) for 2D inputs
3392 // or C contiguous for 1D input
3393 if (vector_case) {
3394 TORCH_INTERNAL_ASSERT(solution.is_contiguous());
3395 } else {
3396 TORCH_INTERNAL_ASSERT(solution.mT().is_contiguous());
3397 }
3398
3399 // for 1-dimensional 'other', we need to unsqueeze the 'solution' before passing it to "apply_lstsq"
3400 if (vector_case) {
3401 solution = solution.unsqueeze_(-1);
3402 }
3403
3404 // _linalg_lstsq_helper_ performs calculations in-place and 'solution' must be a copy of other_2d
3405 solution.narrow(-2, 0, other_2d.size(-2)).copy_(other_2d);
3406
3407 // if 'rank' is empty we might resize it
3408 auto input_batch_shape = IntArrayRef(input.sizes().cbegin(), input.sizes().cend() - 2);
3409 if (rank.numel() == 0 && driver != "gels") { // gels driver doesn't set 'rank'
3410 rank.resize_(input_batch_shape, MemoryFormat::Contiguous);
3411 }
3412
3413 // if 'rank' is non-empty it must have the expected shape and be contiguous
3414 if (driver != "gels") {
3415 TORCH_INTERNAL_ASSERT(rank.sizes().equals(input_batch_shape));
3416 TORCH_INTERNAL_ASSERT(rank.is_contiguous());
3417 }
3418
3419 // if 'singular_values' is empty we might resize it
3420 auto singular_values_shape = input_batch_shape.vec();
3421 singular_values_shape.push_back(std::min(m, n));
3422 if (singular_values.numel() == 0 && (driver == "gelsd" || driver == "gelss")) {
3423 singular_values.resize_(singular_values_shape, MemoryFormat::Contiguous);
3424 }
3425
3426 // if 'singular_values' is non-empty it must have the expected shape and be contiguous
3427 if (driver == "gelsd" || driver == "gelss") {
3428 TORCH_INTERNAL_ASSERT(singular_values.sizes().equals(singular_values_shape));
3429 TORCH_INTERNAL_ASSERT(singular_values.is_contiguous());
3430 }
3431
3432 // 'input' is modified in-place so we need a column-major copy
3433 auto input_working_copy = copyBatchedColumnMajor(input);
3434
3435 // now the actual call that computes the result in-place (apply_lstsq)
3436 lstsq_stub(input.device().type(), input_working_copy, solution, rank, singular_values, infos, rcond, driver);
3437
3438 // residuals are available only if m > n and a driver other than gelsy is used
3439 if (m > n && driver != "gelsy") {
3440 // if the driver is gelss or gelsd then the residuals are available only if rank == n
3441 bool compute_residuals = true;
3442 if (driver == "gelss" || driver == "gelsd") {
3443 if (input.dim() == 2) {
3444 compute_residuals = (rank.item().toInt() == n);
3445 } else {
3446 // it is not clear what to do if some matrices have rank < n in case of batched input
3447 // For now let's compute the residuals only if all matrices have rank equal to n
3448 // This behaviour may be changed in the future
3449 // See https://github.com/pytorch/pytorch/issues/56483
3450 compute_residuals = at::all(rank == n).item().toBool();
3451 }
3452 }
3453 if (compute_residuals) {
3454 // LAPACK stores the residual data for postprocessing in rows n:m (i.e. the trailing m - n rows)
3455 auto raw_residuals = solution.narrow(/*dim=*/-2, /*start=*/n, /*length*/m - n);
3456 if (raw_residuals.is_complex()) {
3457 raw_residuals.mul_(raw_residuals.conj());
3458 raw_residuals = at::real(raw_residuals);
3459 } else {
3460 raw_residuals.pow_(2);
3461 }
3462 at::sum_out(residuals, raw_residuals, /*dim=*/-2, /*keepdim=*/false, /*dtype*/real_dtype);
3463 }
3464 }
3465 auto solution_view = solution.narrow(/*dim=*/-2, /*start=*/0, /*length*/n);
3466 // manually restride original
3467 solution.set_(solution.storage(), solution_view.storage_offset(), solution_view.sizes(), solution_view.strides());
3468 if (m == 0) {
3469 solution.zero_();
3470 }
3471
3472 // for 1-dimensional 'other', we need to squeeze the solution after "apply_lstsq"
3473 if (vector_case) {
3474 solution.squeeze_(-1);
3475 }
3476 }
3477
3478 static std::string get_default_lstsq_driver(std::optional<c10::string_view> driver, const Tensor& input) {
3479 // if `driver` is empty, we set driver_str to "gels" if working with CUDA tensors,
3480 // otherwise to "gelsy" driver.
3481 std::string driver_str;
3482 // check whether the user provided name is a valid driver name
3483 if (driver.has_value()) {
3484 driver_str = std::string(driver.value());
3485 // convert `driver_str` to lower case in place.
3486 std::transform(driver_str.begin(), driver_str.end(), driver_str.begin(),
3487 [](unsigned char c) { return std::tolower(c); });
3488 static std::unordered_set<c10::string_view> allowed_drivers = {
3489 "gels", "gelsy", "gelsd", "gelss"
3490 };
3491 if (input.device() == at::kCPU) {
3492 TORCH_CHECK(
3493 allowed_drivers.find(driver_str) != allowed_drivers.end(),
3494 "torch.linalg.lstsq: parameter `driver` should be one of "
3495 "(gels, gelsy, gelsd, gelss)"
3496 );
3497 } else { // else if (input.is_cuda())
3498 TORCH_CHECK(
3499 driver_str == "gels",
3500 "torch.linalg.lstsq: `driver` other than `gels` is not supported on CUDA"
3501 );
3502 }
3503 } else {
3504 // if driver name is not provided, set to default 'gelsy' if on CPU,
3505 // or to `gels` if on CUDA.
3506 driver_str = input.is_cuda() ? "gels" : "gelsy";
3507 }
3508 return driver_str;
3509 }
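// Illustrative sketch of the defaulting behaviour implemented above (cpu_A and cpu_B are
// hypothetical CPU tensors):
//   at::linalg_lstsq(cpu_A, cpu_B);                          // driver defaults to "gelsy" on CPU
//   at::linalg_lstsq(cpu_A, cpu_B, c10::nullopt, "GELSD");   // accepted: lower-cased to "gelsd"
//   // On CUDA inputs, only "gels" is accepted and it is also the default.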
3510
3511 std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> linalg_lstsq_out(
3512 const Tensor& input,
3513 const Tensor& other,
3514 std::optional<double> rcond,
3515 std::optional<c10::string_view> driver,
3516 Tensor& solution,
3517 Tensor& residuals,
3518 Tensor& rank,
3519 Tensor& singular_values) {
3520 TORCH_CHECK(input.dim() >= 2, "torch.linalg.lstsq: input must have at least 2 dimensions.");
3521 TORCH_CHECK(other.dim() >= 1, "torch.linalg.lstsq: other must have at least 1 dimension.");
3522 TORCH_CHECK(
3523 input.scalar_type() == other.scalar_type(),
3524 "torch.linalg.lstsq: Expected input and other to have the same dtype, but got input's dtype ",
3525 input.scalar_type(),
3526 " and other's dtype ",
3527 other.scalar_type());
3528
3529 auto dim_diff = input.dim() - other.dim();
3530 TORCH_CHECK(
3531 0 <= dim_diff && dim_diff <= 1,
3532 "torch.linalg.lstsq: input.dim() must be greater or equal to other.dim() and (input.dim() - other.dim()) <= 1");
3533
3534 // now check whether the provided output tensors can be used directly
3535
3536 // Two types of 'other' tensors are supported:
3537 // - 1-dimensional (1D) tensor or batch of 1D tensors (vector case)
3538 // - 2-dimensional (2D) tensor or batch of 2D tensors (matrix case)
3539 // original torch.lstsq supported only the matrix case, while NumPy works for both cases
3540 // for the batched input we need to be able to distinguish them
3541 // auto expected_batched_rhs_shape = IntArrayRef(input.sizes().data(), input.dim() - 1); // input.shape[:-1]
3542 // bool vector_case = other.dim() == 1 || (input.dim() - 1 == other.dim() && other.sizes().equals(expected_batched_rhs_shape));
3543
3544 bool vector_case = linalg_solve_is_vector_rhs(input, other);
3545 Tensor other_2d = vector_case ? other.unsqueeze(-1) : other;
3546 TORCH_CHECK(
3547 input.size(-2) == other_2d.size(-2),
3548 vector_case ? "torch.linalg.lstsq: input.size(-2) should match other.size(-1)"
3549 : "torch.linalg.lstsq: input.size(-2) should match other.size(-2)");
3550
3551 checkSameDevice("torch.linalg.lstsq", other, input, "other");
3552 checkSameDevice("torch.linalg.lstsq", solution, input, "solution");
3553 checkSameDevice("torch.linalg.lstsq", residuals, input, "residuals");
3554 checkSameDevice("torch.linalg.lstsq", rank, input, "rank");
3555 checkSameDevice("torch.linalg.lstsq", singular_values, input, "singular_values");
3556
3557 // 'solution' is expected to have same dtype as input
3558 checkLinalgCompatibleDtype("torch.linalg.lstsq", solution, input, "solution");
3559
3560 // 'residuals' is expected to have real float dtype
3561 ScalarType real_dtype = c10::toRealValueType(input.scalar_type());
3562 checkLinalgCompatibleDtype("torch.linalg.lstsq", residuals.scalar_type(), real_dtype, "residuals");
3563
3564 // 'rank' is expected to have integer dtype
3565 // actual LAPACK calls use int32_t type for rank, but we promote it to int64_t
3566 // to be consistent with torch.linalg.matrix_rank output dtype
3567 ScalarType rank_expected_type = ScalarType::Long;
3568 checkLinalgCompatibleDtype("torch.linalg.lstsq", rank.scalar_type(), rank_expected_type, "rank");
3569
3570 // 'singular_values' is expected to have real float dtype
3571 checkLinalgCompatibleDtype("torch.linalg.lstsq", singular_values.scalar_type(), real_dtype, "singular_values");
3572
3573 std::string driver_name = get_default_lstsq_driver(driver, input);
3574
3575 // set default rcond value
3576 double rcond_value = rcond.has_value()
3577 ? rcond.value()
3578 : _get_epsilon(c10::toRealValueType(input.scalar_type())) * static_cast<double>(std::max<int64_t>(input.size(-2), input.size(-1)));
3579
3580 auto infos = at::zeros({std::max<int64_t>(1, batchCount(input))}, input.options().dtype(kInt));
3581
3582 // provided output tensor can be used directly if:
3583 // 1. the shape matches the expected shape
3584 // 2. the dtype matches the expected dtype
3585 // 3. the tensor is contiguous
3586
3587 // Checks for the 'solution' tensor
3588 std::vector<int64_t> expected_solution_shape = broadcast_batch_size(input, other_2d, input.dim() - 2);
3589 // the actual shape of the solution returned is (*, n,) or (*, n, nrhs)
3590 // but LAPACK requires extra dimensions, so the expected shape is (*, max(m, n),) or (*, max(m, n), nrhs)
3591 expected_solution_shape.push_back(std::max(input.size(-1), input.size(-2)));
3592 if (!vector_case && other.dim() > 2) {
3593 expected_solution_shape.push_back(other.size(-1));
3594 }
3595
3596 bool solution_equal_expected_shape = solution.sizes().equals(expected_solution_shape);
3597 bool solution_input_same_type = (solution.scalar_type() == input.scalar_type());
3598
3599 bool is_solution_batched_column_major = false;
3600 if (vector_case) {
3601 is_solution_batched_column_major = solution.is_contiguous();
3602 } else if (!vector_case && solution.dim() >= 2) {
3603 is_solution_batched_column_major = solution.mT().is_contiguous();
3604 }
3605
3606 // 'residuals' is not checked here because at::sum_out(residuals, ...) does that
3607
3608 auto input_batch_shape = IntArrayRef(input.sizes().cbegin(), input.sizes().cend() - 2);
3609
3610 // Checks for the 'rank' tensor
3611 // rank is a scalar value for each matrix in the batch so
3612 // rank's expected shape is equal to input.shape[0:input.ndim-2]
3613 bool rank_equal_expected_shape = true;
3614 bool rank_equal_expected_type = true;
3615 bool rank_is_contiguous = true;
3616 if (driver_name != "gels") { // gels driver doesn't set 'rank'
3617 rank_equal_expected_shape = rank.sizes().equals(input_batch_shape);
3618 rank_equal_expected_type = (rank.scalar_type() == at::kLong);
3619 rank_is_contiguous = rank.is_contiguous();
3620 }
3621
3622 // Checks for the 'singular_values' tensor
3623 // singular values are computed only with "gelsd" and "gelss" drivers currently
3624 bool singular_values_equal_expected_shape = true;
3625 bool singular_values_equal_expected_type = true;
3626 bool singular_values_is_contiguous = true;
3627 if (driver_name == "gelsd" || driver_name == "gelss") {
3628 auto singular_values_shape = input_batch_shape.vec();
3629 singular_values_shape.push_back(std::min(input.size(-1), input.size(-2)));
3630 singular_values_equal_expected_shape = singular_values.sizes().equals(singular_values_shape);
3631 singular_values_equal_expected_type = (singular_values.scalar_type() == real_dtype);
3632 singular_values_is_contiguous = singular_values.is_contiguous();
3633 }
3634
3635 // if solution is not empty and not in batched column major format
3636 bool copy_needed = (solution.numel() != 0 && !is_solution_batched_column_major);
3637 copy_needed |= !solution_input_same_type; // or solution does not have the same dtype as input
3638 copy_needed |= (solution.numel() != 0 && !solution_equal_expected_shape); // or solution does not have the expected shape
3639
3640 copy_needed |= !rank_equal_expected_type;
3641 copy_needed |= (rank.numel() != 0 && !rank_equal_expected_shape);
3642 copy_needed |= (rank.numel() != 0 && !rank_is_contiguous);
3643
3644 copy_needed |= !singular_values_equal_expected_type;
3645 copy_needed |= (singular_values.numel() != 0 && !singular_values_equal_expected_shape);
3646 copy_needed |= (singular_values.numel() != 0 && !singular_values_is_contiguous);
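  // In short: the user-provided outputs are used directly only if every one of them already
  // has the expected dtype, the expected shape and a batched column-major (resp. contiguous)
  // layout; otherwise the results are computed into temporaries and copied back below.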
3647
3648 if (copy_needed) { // we have to allocate temporary tensors
3649 Tensor solution_tmp = at::empty({0}, input.options());
3650 Tensor residuals_tmp = at::empty({0}, input.options().dtype(real_dtype));
3651 Tensor rank_tmp = at::empty({0}, input.options().dtype(at::kLong));
3652 Tensor singular_values_tmp = at::empty({0}, input.options().dtype(real_dtype));
3653
3654 linalg_lstsq_out_info(solution_tmp, residuals_tmp, rank_tmp, singular_values_tmp, infos, input, other, rcond_value, driver_name);
3655
3656 at::native::resize_output(solution, solution_tmp.sizes());
3657 solution.copy_(solution_tmp);
3658
3659 at::native::resize_output(residuals, residuals_tmp.sizes());
3660 residuals.copy_(residuals_tmp);
3661
3662 at::native::resize_output(rank, rank_tmp.sizes());
3663 rank.copy_(rank_tmp);
3664
3665 at::native::resize_output(singular_values, singular_values_tmp.sizes());
3666 singular_values.copy_(singular_values_tmp);
3667 } else {
3668 // else use the provided output storage directly
3669 linalg_lstsq_out_info(solution, residuals, rank, singular_values, infos, input, other, rcond_value, driver_name);
3670 }
3671
3672 at::_linalg_check_errors(infos, "torch.linalg.lstsq", infos.numel() <= 1);
3673 return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(solution, residuals, rank, singular_values);
3674 }
3675
3676 std::tuple<Tensor, Tensor, Tensor, Tensor> linalg_lstsq(
3677 const Tensor& input, const Tensor& other,
3678 std::optional<double> rcond,
3679 std::optional<c10::string_view> driver) {
3680 Tensor solution = at::empty({0}, input.options());
3681 Tensor residuals = at::empty({0}, input.options().dtype(toRealValueType(input.scalar_type())));
3682 Tensor rank = at::empty({0}, input.options().dtype(at::kLong));
3683 Tensor singular_values = at::empty({0}, input.options().dtype(toRealValueType(input.scalar_type())));
3684 std::tie(solution, residuals, rank, singular_values) =
3685 at::linalg_lstsq_outf(input, other, rcond, driver, solution, residuals, rank, singular_values);
3686 return std::make_tuple(std::move(solution), std::move(residuals), std::move(rank), std::move(singular_values));
3687 }
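// Illustrative usage sketch through the public C++ API (tensor names here are
// hypothetical, not part of this file):
//
//   at::Tensor A = at::randn({5, 3}, at::kDouble);   // m = 5, n = 3
//   at::Tensor b = at::randn({5, 2}, at::kDouble);   // nrhs = 2
//   auto [X, residuals, rank, sv] =
//       at::linalg_lstsq(A, b, /*rcond=*/std::nullopt, /*driver=*/"gelsd");
//   // X has shape (3, 2); rank and sv are populated because "gelsd" computes them.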
3688
3689 DEFINE_DISPATCH(ldl_factor_stub);
3690
3691 TORCH_IMPL_FUNC(linalg_ldl_factor_ex_out)
3692 (const Tensor& self,
3693 bool hermitian,
3694 bool check_errors,
3695 const Tensor& LD,
3696 const Tensor& pivots,
3697 const Tensor& info) {
3698   // The LAPACK workspace query segfaults if the input has a 0 in its batch dimensions.
3699 if (self.numel() == 0) {
3700 info.zero_();
3701 return;
3702 }
3703
3704 // We decided not to include upper flag in the API.
3705 // https://github.com/pytorch/pytorch/pull/69828#issuecomment-1015143819
3706 // We can revisit this decision later and remove upper completely
3707 // also from low level functions or add it to the public API.
3708 constexpr bool upper = false;
3709 if constexpr (upper) {
3710 at::triu_out(const_cast<Tensor&>(LD), self);
3711 } else {
3712 at::tril_out(const_cast<Tensor&>(LD), self);
3713 }
3714
3715 // call ldl_factor_stub that fills the result tensors
3716 ldl_factor_stub(
3717 self.device().type(), LD, pivots, info, upper, hermitian);
3718
3719 if (check_errors) {
3720 at::_linalg_check_errors(
3721 info, "torch.linalg.ldl_factor_ex", self.dim() == 2);
3722 }
3723 }
3724
3725 std::tuple<Tensor&, Tensor&> linalg_ldl_factor_out(
3726 const Tensor& self,
3727 bool hermitian,
3728 Tensor& LD,
3729 Tensor& pivots) {
3730 auto info = at::empty({0}, self.options().dtype(kInt));
3731   // We pass check_errors=false and check the errors here ourselves so that the
3732   // error messages mention ldl_factor rather than ldl_factor_ex
3733 at::linalg_ldl_factor_ex_outf(
3734 self, hermitian, /*check_errors=*/false, LD, pivots, info);
3735 at::_linalg_check_errors(info, "torch.linalg.ldl_factor", self.dim() == 2);
3736 return std::tie(LD, pivots);
3737 }
3738
3739 std::tuple<Tensor, Tensor> linalg_ldl_factor(
3740 const Tensor& self,
3741 bool hermitian) {
3742 auto [LD, pivots, info] =
3743 at::linalg_ldl_factor_ex(self, hermitian, /*check_errors=*/false);
3744 at::_linalg_check_errors(info, "torch.linalg.ldl_factor", self.dim() == 2);
3745 return std::make_tuple(std::move(LD), std::move(pivots));
3746 }
3747
3748 DEFINE_DISPATCH(ldl_solve_stub);
3749
3750 TORCH_IMPL_FUNC(linalg_ldl_solve_out)
3751 (const Tensor& LD,
3752 const Tensor& pivots,
3753 const Tensor& B,
3754 bool hermitian,
3755 const Tensor& result) {
3756 if (LD.numel() == 0 || pivots.numel() == 0) {
3757 return;
3758 }
3759
3760 auto pivots_ = pivots.expect_contiguous();
3761
3762 auto LD_ = at::native::borrow_else_clone(
3763 LD.mT().is_contiguous(), LD, LD, /*contig=*/false);
3764 result.copy_(B);
3765   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(batchCount(result) == batchCount(LD));
3766
3767 ldl_solve_stub(
3768 B.device().type(), *LD_, *pivots_, result, false, hermitian);
3769 }
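// Illustrative usage sketch (hypothetical tensors): factor once, then reuse the
// factorization for a right-hand side via the public C++ API:
//
//   at::Tensor A = at::randn({4, 4}, at::kDouble);
//   A = A + A.mT();                                   // make A symmetric
//   auto [LD, pivots] = at::linalg_ldl_factor(A, /*hermitian=*/false);
//   at::Tensor B = at::randn({4, 2}, at::kDouble);
//   at::Tensor X = at::linalg_ldl_solve(LD, pivots, B, /*hermitian=*/false);
//   // X solves A X = B (up to numerical error).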
3770
3771 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ solve_triangular ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
3772
3773 Tensor& linalg_vecdot_out(const Tensor& x, const Tensor& y, int64_t dim, Tensor& out) {
3774 checkFloatingOrComplex(x, "linalg.vecdot");
3775 TORCH_CHECK(x.scalar_type() == y.scalar_type(),
3776 "linalg.vecdot: Expected x and y to have the same dtype, but found x of type ",
3777 x.scalar_type(), " and y of type ", y.scalar_type(), " instead");
3778 // out checks
3779 TORCH_CHECK(out.scalar_type() == x.scalar_type(),
3780 "linalg.vecdot: Expected out of dtype", x.scalar_type(),
3781 " but found ", out.scalar_type());
3782 checkSameDevice("linalg.vecdot", x, out);
3783
3784 // Computes x^H y
3785 if (x.dim() == 1 && y.dim() == 1) {
3786 at::native::resize_output(out, {});
3787 return at::vdot_out(out, x, y);
3788 } else {
3789 return at::sum_out(out, x.conj() * y, /*dim=*/dim);
3790 }
3791 }
3792
3793 Tensor linalg_vecdot(const Tensor& x, const Tensor& y, int64_t dim) {
3794 checkFloatingOrComplex(x, "linalg.vecdot");
3795 TORCH_CHECK(x.scalar_type() == y.scalar_type(),
3796 "linalg.vecdot: Expected x and y to have the same dtype, but found x of type ",
3797 x.scalar_type(), " and y of type ", y.scalar_type(), " instead");
3798 // Computes x^H y
3799 if (x.dim() == 1 && y.dim() == 1) {
3800 return at::vdot(x, y);
3801 } else {
3802 return x.conj().mul(y).sum(/*dim=*/dim);
3803 }
3804 }
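// Worked example (illustrative): for 1-D inputs this reduces to vdot, i.e. x^H y.
//   x = [1 + i, 2], y = [3, 4i]
//   conj(x) . y = (1 - i)*3 + 2*(4i) = 3 - 3i + 8i = 3 + 5i
// For batched inputs the same reduction runs along `dim`: with x and y of shape
// (B, n) and dim = -1, the result has shape (B,).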
3805
3806 /*
3807 Solves the matrix equation AX = B for A triangular.
3808 'left' if true solves AX = B, if false solves XA = B
3809 'upper' controls which triangle (upper or lower) of the input matrix is used in the computation
3810 'unitriangular' if true, the diagonal elements of A are assumed to be 1
3811 'out' The tensor with the result. If A == out, A will be modified in place
3812 */
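// Note: when left == false, the equation XA = B is mathematically equivalent to
// A^T X^T = B^T; transposing the equation swaps the side, which is how the code
// below flips `left` whenever it has to transpose out_f.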
3813 Tensor& linalg_solve_triangular_out(
3814 const Tensor& A,
3815 const Tensor& B,
3816 bool upper,
3817 bool left,
3818 bool unitriangular,
3819 Tensor& out) {
3820 checkInputsSolver(A, B, left, "linalg.solve_triangular");
3821 auto [B_, A_] = _linalg_broadcast_batch_dims(B, A, /*don't check errors*/nullptr);
3822
3823 // We'll write F-contig / F-transpose for FORTRAN contiguous / FORTRAN transpose etc
3824 // We say that a matrix is F-ready if it's F-contig OR F-transpose
3825 // At this point, A, B have been broadcasted but may or may not be F-ready
3826
3827 // The following algorithm minimises copies and allocations. In pseudocode:
3828 // if out is wrong size:
3829 // resize_output(out)
3830 // # Invariant: out is the right size
3831 // Tensor out_f; # Tensor that we will pass to FORTRAN
3832 // if out is F-ready:
3833 // out_f = out;
3834 // else:
3835 // Allocate out_f F-ready
3836 // if B != out_f:
3837 // copy B into out_f
3838 // # Invariant: out_f F-ready and has B copied into it
3839 // if out_f is F-transposed:
3840 // transpose equation
3841 // if out_f is conj:
3842 // conjugate equation
3843 // # Invariant: out_f is not conjugated and F-contig
3844 // Tensor A_f; # Tensor that will be sent to FORTRAN
3845 // if A is F-ready:
3846 // if A is conj and A is not transposed:
3847 // # We need to clone A in this case. See [Cloning A]
3848 // clone A F-contig into A_f
3849 // else:
3850 // A_f = A;
3851 // else:
3852 // clone A F-contig into A_f
3853 // # Invariant: out_f is F-contig and A_f is F-ready
3854 // # We pass FORTRAN the flags indicating if A_f is transposed and or conjugated
3855 //
3856 // # Here we undo the conjugations / transposes on out_f if needed
3857 //
3858 // if out_f not same out:
3859 // copy out_f into out
3860 // return out
3861 //
3862 // Note: The logic for the negative bit is the same as that for the conjugate bit
3863 //
3864 // Note: [Cloning A] If we are careful when allocating B when it needs to be allocated at the
3865 // beginning of the algorithm, it is possible to always elide the copy of A here.
3866 // Via this trick, the algorithm will copy at most one of A or B (never both) whenever A
3867 // and B are F-ready and not A.is_neg() (which happens almost always in practice).
3868 // When called as f(A, B, out=B) in most practical cases it'll perform no copies.
3869
3870 const bool avoid_copy_A = A_.transpose(-2, -1).is_contiguous() && A_.is_conj();
3871 if (avoid_copy_A) {
3872 // See Note: [Cloning A]
3873 at::native::resize_output(out, B_.sizes());
3874 }
3875 else {
3876     // poor man's reimplementation of resize_output with the result F-contig
3877 if (resize_output_check(out, B_.sizes())) {
3878 out.resize_(B_.transpose(-2, -1).sizes(), MemoryFormat::Contiguous);
3879 out.transpose_(-2, -1); // make 'out' have Fortran contiguous memory layout
3880 }
3881 }
3882 // Invariant: out has the right size, so we'll be able to copy into it later on
3883
3884 Tensor out_f; // the out that will go into fortran
3885   // We use C10_LIKELY mostly for documentation, as it helps to follow the most likely path
3886 if C10_LIKELY (is_row_or_column_contiguous(out)) {
3887 out_f = out;
3888 if C10_LIKELY (!out.is_same(B_)) {
3889 out_f.copy_(B_);
3890 }
3891 } else {
3892 if (avoid_copy_A) {
3893 // See Note: [Cloning A]
3894 out_f = B_.clone(at::MemoryFormat::Contiguous);
3895 }
3896 else {
3897 out_f = cloneBatchedColumnMajor(B_);
3898 }
3899 }
3900 // Invariant: out_f F-ready and has B copied into it
3901
3902   // If out_f is F-transposed, transpose the equation so that out_f becomes F-contig
3903 bool transpose_A = false;
3904 bool transpose_out_f = false;
3905 if (out_f.stride(-1) == 1) {
3906 left = !left;
3907 transpose_A = true;
3908 transpose_out_f = true;
3909     out_f.transpose_(-2, -1);
3910 }
3911
3912 // No need to conjugate anything if out_f is conj as AX = conj(B) <=> conj(A)conj(X) = B
3913 // and X = B after the algorithm. We just annotate that A is conjugated later on
3914 // The solution will be written into out_f, so it'll be conjugated already
3915
3916 Tensor A_f = std::move(A_); // The A that will go into fortran
3917
3918 bool A_is_conj = A_f.is_conj() != out_f.is_conj();
3919 bool A_is_neg = A_f.is_neg() != out_f.is_neg();
3920 bool A_is_f_contig = (A_f.stride(-1) == 1) == transpose_A;
3921 if C10_UNLIKELY (!is_row_or_column_contiguous(A_f)) {
3922 // We first annotate with flags on A_f all the conj / transpose / neg coming from out
3923 // and then we clone the resulting tensor to resolve all of them in memory
3924 if (out_f.is_conj()) {
3925 A_f = A_f.conj();
3926 }
3927 A_is_conj = false;
3928
3929 if (out_f.is_neg()) {
3930 A_f = A_f._neg_view();
3931 }
3932 A_is_neg = false;
3933
3934     // This choice is to be consistent with how we flip `upper` later on
3935     // Note that this is the same reasoning we apply for neg and conj:
3936     // any neg / conj / transpose inherited from out is resolved in memory by this clone
3937 A_f = transpose_A ? A_f.clone(at::MemoryFormat::Contiguous)
3938 : cloneBatchedColumnMajor(A_f);
3939 A_is_f_contig = true;
3940 } else if C10_UNLIKELY (A_is_f_contig && A_is_conj) {
3941 if C10_UNLIKELY (A_f.is_neg() || out_f.is_neg()) {
3942 // Cases A_is_neg (remember that B.is_neg() iff out_f.is_same(B))
3943 // -AX = -B => A(-X) = B. Swap neg of A_f. Nothing to do on X as X.is_same(B).
3944 // -AX = B. We resolve the neg in memory
3945 // AX = -B => -A -X = B. We resolve the neg in memory for A,
3946 // Since X.is_same(B), we already have that X.is_neg() == true
3947
3948 // We do the neg with a view, as this will be resolved in the clone below
3949 if (out_f.is_neg()) {
3950 A_f = A_f._neg_view();
3951 }
3952 A_is_neg = false;
3953 }
3954 // We resolve the transpose if necessary and then leave A_f F-transposed,
3955 // as BLAS can handle the case F-transposed and conjugated
3956 A_f = at::clone(transpose_A ? A_f.mT() : A_f, at::MemoryFormat::Contiguous);
3957 A_is_f_contig = false;
3958 if (transpose_A) {
3959 upper = !upper;
3960 }
3961 // As we've already resolved the conj of A in the clone
3962 A_is_conj = out_f.is_conj();
3963 } else if C10_UNLIKELY (A_is_neg) {
3964 // We follow the same logic as above, only that in this case we need to perform the
3965 // negation in memory
3966 if (out_f.is_neg()) {
3967 A_f = -A_f;
3968 } else {
3969 A_f = A_f.resolve_neg();
3970 }
3971 A_is_neg = false;
3972     // As we've already resolved the conj of A in the negation above
3973 A_is_conj = out_f.is_conj();
3974 }
3975 // Invariant: out_f is F-contig and A_f is F-ready
3976 // neg has been resolved
3977
3978 // If we pass the matrix physically F-transposed, we need to change the parity of upper
3979 if (A_f.stride(-1) == 1) {
3980 upper = !upper;
3981 }
3982
3983 triangular_solve_stub(
3984 A_f.device().type(), A_f, out_f,
3985 /*left=*/left,
3986 /*upper=*/upper,
3987 /*transpose*/to_transpose_type(A_is_f_contig, A_is_conj),
3988 /*unitriangular=*/unitriangular);
3989
3990 if (transpose_out_f) {
3991 out_f.transpose_(-2, -1);
3992 }
3993
3994 if (!out_f.is_same(out)) {
3995 out.copy_(out_f);
3996 }
3997 return out;
3998 }
3999
4000 Tensor linalg_solve_triangular(
4001 const Tensor& A,
4002 const Tensor& B,
4003 bool upper,
4004 bool left,
4005 bool unitriangular) {
4006 Tensor out = at::empty({0}, A.options());
4007 linalg_solve_triangular_out(A, B, upper, left, unitriangular, out);
4008 return out;
4009 }
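// Illustrative usage sketch (hypothetical tensors):
//
//   at::Tensor A = at::randn({3, 3}, at::kDouble).tril();   // lower triangular
//   at::Tensor B = at::randn({3, 2}, at::kDouble);
//   at::Tensor X = at::linalg_solve_triangular(
//       A, B, /*upper=*/false, /*left=*/true, /*unitriangular=*/false);
//   // X satisfies A X = B (up to rounding), computed without factorizing A.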
4010
4011 Tensor linalg_vander_symint(
4012 const Tensor& x,
4013 std::optional<c10::SymInt> N) {
4014 auto t = x.scalar_type();
4015 TORCH_CHECK(t == ScalarType::Float ||
4016 t == ScalarType::Double ||
4017 t == ScalarType::ComplexFloat ||
4018 t == ScalarType::ComplexDouble ||
4019 c10::isIntegralType(t, false),
4020 "linalg.vander supports floating point, complex, and integer tensors, but got ", t);
4021 const auto x_ = x.dim() == 0 ? x.unsqueeze(-1) : x;
4022
4023 auto shape = x_.sym_sizes().vec();
4024 const auto n = N.value_or(shape.back());
4025 TORCH_CHECK(n > 1, "N must be greater than 1.");
4026
4027   // Append the cumprod of the remaining powers x^1 ... x^(n-1)
4028 shape.push_back(n - 1);
4029 auto result = at::cumprod(x_.unsqueeze(-1).expand_symint(shape), -1);
4030   // The ones for the zeroth power
4031 shape.back() = 1LL;
4032 auto ones = result.new_ones_symint(shape);
4033 return at::cat({std::move(ones), std::move(result)}, /*dim=*/ -1);
4034 }
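// Worked example (illustrative): x = [1, 2, 3] with the default N = len(x) = 3 gives
//   [[1, 1, 1],
//    [1, 2, 4],
//    [1, 3, 9]]
// i.e. entry j along the last dimension holds x**j for j = 0, ..., N-1.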
4035 // NOLINTEND(cppcoreguidelines-pro-type-const-cast)
4036 } // namespace at::native
4037