// aten/src/ATen/native/BatchLinearAlgebra.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/core/grad_mode.h>
4 #include <ATen/Dispatch.h>
5 #include <ATen/Parallel.h>
6 #include <ATen/TensorMeta.h>
7 #include <ATen/TensorOperators.h>
8 #include <ATen/TensorSubclassLikeUtils.h>
9 
10 #include <ATen/native/BatchLinearAlgebra.h>
11 #include <ATen/native/LinearAlgebraUtils.h>
12 #include <ATen/native/Resize.h>
13 #include <ATen/native/cpu/zmath.h>
14 
15 #include <c10/util/irange.h>
16 
17 #include <utility>
18 #include <vector>
19 
20 #ifndef AT_PER_OPERATOR_HEADERS
21 #include <ATen/Functions.h>
22 #include <ATen/NativeFunctions.h>
23 #else
24 #include <ATen/ops/_spsolve.h>
25 #include <ATen/ops/_cholesky_solve_helper.h>
26 #include <ATen/ops/_cholesky_solve_helper_native.h>
27 #include <ATen/ops/_linalg_check_errors.h>
28 #include <ATen/ops/_linalg_check_errors_native.h>
29 #include <ATen/ops/_linalg_eigh.h>
30 #include <ATen/ops/_linalg_eigh_meta.h>
31 #include <ATen/ops/_linalg_eigh_native.h>
32 #include <ATen/ops/_linalg_eigvals.h>
33 #include <ATen/ops/_linalg_eigvals_native.h>
34 #include <ATen/ops/_linalg_solve_ex.h>
35 #include <ATen/ops/_linalg_solve_ex_meta.h>
36 #include <ATen/ops/_linalg_solve_ex_native.h>
37 #include <ATen/ops/_linalg_svd.h>
38 #include <ATen/ops/_linalg_svd_meta.h>
39 #include <ATen/ops/_linalg_svd_native.h>
40 #include <ATen/ops/_lu_with_info_native.h>
41 #include <ATen/ops/all.h>
42 #include <ATen/ops/arange.h>
43 #include <ATen/ops/cat.h>
44 #include <ATen/ops/cholesky.h>
45 #include <ATen/ops/cholesky_inverse.h>
46 #include <ATen/ops/cholesky_inverse_native.h>
47 #include <ATen/ops/cholesky_native.h>
48 #include <ATen/ops/cholesky_solve.h>
49 #include <ATen/ops/cholesky_solve_native.h>
50 #include <ATen/ops/clone.h>
51 #include <ATen/ops/complex.h>
52 #include <ATen/ops/cumprod.h>
53 #include <ATen/ops/empty.h>
54 #include <ATen/ops/empty_like.h>
55 #include <ATen/ops/geqrf.h>
56 #include <ATen/ops/geqrf_native.h>
57 #include <ATen/ops/inverse_native.h>
58 #include <ATen/ops/linalg_cholesky_ex.h>
59 #include <ATen/ops/linalg_cholesky_ex_meta.h>
60 #include <ATen/ops/linalg_cholesky_ex_native.h>
61 #include <ATen/ops/linalg_cholesky_native.h>
62 #include <ATen/ops/linalg_eig.h>
63 #include <ATen/ops/linalg_eig_native.h>
64 #include <ATen/ops/linalg_eigh_native.h>
65 #include <ATen/ops/linalg_eigvals.h>
66 #include <ATen/ops/linalg_eigvals_native.h>
67 #include <ATen/ops/linalg_eigvalsh_native.h>
68 #include <ATen/ops/linalg_householder_product.h>
69 #include <ATen/ops/linalg_householder_product_native.h>
70 #include <ATen/ops/linalg_inv.h>
71 #include <ATen/ops/linalg_inv_ex.h>
72 #include <ATen/ops/linalg_inv_ex_native.h>
73 #include <ATen/ops/linalg_inv_native.h>
74 #include <ATen/ops/linalg_ldl_factor_ex.h>
75 #include <ATen/ops/linalg_ldl_factor_ex_meta.h>
76 #include <ATen/ops/linalg_ldl_factor_ex_native.h>
77 #include <ATen/ops/linalg_ldl_factor_native.h>
78 #include <ATen/ops/linalg_ldl_solve_meta.h>
79 #include <ATen/ops/linalg_ldl_solve_native.h>
80 #include <ATen/ops/linalg_lstsq.h>
81 #include <ATen/ops/linalg_lstsq_native.h>
82 #include <ATen/ops/linalg_lu_factor_ex.h>
83 #include <ATen/ops/linalg_lu_factor_ex_meta.h>
84 #include <ATen/ops/linalg_lu_factor_ex_native.h>
85 #include <ATen/ops/linalg_lu_factor_native.h>
86 #include <ATen/ops/linalg_lu_meta.h>
87 #include <ATen/ops/linalg_lu_native.h>
88 #include <ATen/ops/linalg_lu_solve.h>
89 #include <ATen/ops/linalg_lu_solve_meta.h>
90 #include <ATen/ops/linalg_lu_solve_native.h>
91 #include <ATen/ops/linalg_qr.h>
92 #include <ATen/ops/linalg_qr_meta.h>
93 #include <ATen/ops/linalg_qr_native.h>
94 #include <ATen/ops/linalg_solve_ex.h>
95 #include <ATen/ops/linalg_solve_ex_native.h>
96 #include <ATen/ops/linalg_solve_native.h>
97 #include <ATen/ops/linalg_solve_triangular_native.h>
98 #include <ATen/ops/linalg_svd.h>
99 #include <ATen/ops/linalg_svd_native.h>
100 #include <ATen/ops/linalg_svdvals.h>
101 #include <ATen/ops/linalg_svdvals_native.h>
102 #include <ATen/ops/linalg_vander_native.h>
103 #include <ATen/ops/linalg_vecdot_native.h>
104 #include <ATen/ops/lu_solve_native.h>
105 #include <ATen/ops/lu_unpack.h>
106 #include <ATen/ops/lu_unpack_meta.h>
107 #include <ATen/ops/lu_unpack_native.h>
108 #include <ATen/ops/orgqr_native.h>
109 #include <ATen/ops/ormqr_native.h>
110 #include <ATen/ops/qr_native.h>
111 #include <ATen/ops/real.h>
112 #include <ATen/ops/resize_as_native.h>
113 #include <ATen/ops/sum.h>
114 #include <ATen/ops/svd_native.h>
115 #include <ATen/ops/triangular_solve_meta.h>
116 #include <ATen/ops/triangular_solve_native.h>
117 #include <ATen/ops/tril.h>
118 #include <ATen/ops/triu.h>
119 #include <ATen/ops/vdot.h>
120 #include <ATen/ops/zeros.h>
121 #endif
122 
123 // First, the required LAPACK routines are declared here.
124 // A comment above each declaration indicates which batched
125 // linear algebra function uses that routine.
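// These are Fortran entry points: every argument is passed by pointer, matrices
// are expected in column-major (Fortran) order, and the trailing underscore
// follows the usual Fortran name-mangling convention.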
126 #if AT_BUILD_WITH_LAPACK()
127 
128 // getrf
129 extern "C" void zgetrf_(int *m, int *n, std::complex<double> *a, int *lda, int *ipiv, int *info);
130 extern "C" void cgetrf_(int *m, int *n, std::complex<float> *a, int *lda, int *ipiv, int *info);
131 extern "C" void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv, int *info);
132 extern "C" void sgetrf_(int *m, int *n, float *a, int *lda, int *ipiv, int *info);
133 
134 // potrs
135 extern "C" void zpotrs_(char *uplo, int *n, int *nrhs, std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb, int *info);
136 extern "C" void cpotrs_(char *uplo, int *n, int *nrhs, std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb, int *info);
137 extern "C" void dpotrs_(char *uplo, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info);
138 extern "C" void spotrs_(char *uplo, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info);
139 
140 // potrf
141 extern "C" void zpotrf_(char *uplo, int *n, std::complex<double> *a, int *lda, int *info);
142 extern "C" void cpotrf_(char *uplo, int *n, std::complex<float> *a, int *lda, int *info);
143 extern "C" void dpotrf_(char *uplo, int *n, double *a, int *lda, int *info);
144 extern "C" void spotrf_(char *uplo, int *n, float *a, int *lda, int *info);
145 
146 // potri
147 extern "C" void zpotri_(char *uplo, int *n, std::complex<double> *a, int *lda, int *info);
148 extern "C" void cpotri_(char *uplo, int *n, std::complex<float> *a, int *lda, int *info);
149 extern "C" void dpotri_(char *uplo, int *n, double *a, int *lda, int *info);
150 extern "C" void spotri_(char *uplo, int *n, float *a, int *lda, int *info);
151 
152 // sytrf
153 extern "C" void dsytrf_(
154     char* uplo,
155     int* n,
156     double* a,
157     int* lda,
158     int* ipiv,
159     double* work,
160     int* lwork,
161     int* info);
162 extern "C" void ssytrf_(
163     char* uplo,
164     int* n,
165     float* a,
166     int* lda,
167     int* ipiv,
168     float* work,
169     int* lwork,
170     int* info);
171 extern "C" void zsytrf_(
172     char* uplo,
173     int* n,
174     std::complex<double>* a,
175     int* lda,
176     int* ipiv,
177     std::complex<double>* work,
178     int* lwork,
179     int* info);
180 extern "C" void csytrf_(
181     char* uplo,
182     int* n,
183     std::complex<float>* a,
184     int* lda,
185     int* ipiv,
186     std::complex<float>* work,
187     int* lwork,
188     int* info);
189 
190 // hetrf
191 extern "C" void zhetrf_(
192     char* uplo,
193     int* n,
194     std::complex<double>* a,
195     int* lda,
196     int* ipiv,
197     std::complex<double>* work,
198     int* lwork,
199     int* info);
200 extern "C" void chetrf_(
201     char* uplo,
202     int* n,
203     std::complex<float>* a,
204     int* lda,
205     int* ipiv,
206     std::complex<float>* work,
207     int* lwork,
208     int* info);
209 
210 // sytrs
211 extern "C" void dsytrs_(
212     char* uplo,
213     int* n,
214     int* nrhs,
215     double* a,
216     int* lda,
217     int* ipiv,
218     double* b,
219     int* ldb,
220     int* info);
221 extern "C" void ssytrs_(
222     char* uplo,
223     int* n,
224     int* nrhs,
225     float* a,
226     int* lda,
227     int* ipiv,
228     float* b,
229     int* ldb,
230     int* info);
231 extern "C" void zsytrs_(
232     char* uplo,
233     int* n,
234     int* nrhs,
235     std::complex<double>* a,
236     int* lda,
237     int* ipiv,
238     std::complex<double>* b,
239     int* ldb,
240     int* info);
241 extern "C" void csytrs_(
242     char* uplo,
243     int* n,
244     int* nrhs,
245     std::complex<float>* a,
246     int* lda,
247     int* ipiv,
248     std::complex<float>* b,
249     int* ldb,
250     int* info);
251 
252 // hetrs
253 extern "C" void zhetrs_(
254     char* uplo,
255     int* n,
256     int* nrhs,
257     std::complex<double>* a,
258     int* lda,
259     int* ipiv,
260     std::complex<double>* b,
261     int* ldb,
262     int* info);
263 extern "C" void chetrs_(
264     char* uplo,
265     int* n,
266     int* nrhs,
267     std::complex<float>* a,
268     int* lda,
269     int* ipiv,
270     std::complex<float>* b,
271     int* ldb,
272     int* info);
273 
274 // geqrf
275 extern "C" void zgeqrf_(int *m, int *n, std::complex<double> *a, int *lda, std::complex<double> *tau, std::complex<double> *work, int *lwork, int *info);
276 extern "C" void cgeqrf_(int *m, int *n, std::complex<float> *a, int *lda, std::complex<float> *tau, std::complex<float> *work, int *lwork, int *info);
277 extern "C" void dgeqrf_(int *m, int *n, double *a, int *lda, double *tau, double *work, int *lwork, int *info);
278 extern "C" void sgeqrf_(int *m, int *n, float *a, int *lda, float *tau, float *work, int *lwork, int *info);
279 
280 // orgqr
281 extern "C" void zungqr_(int *m, int *n, int *k, std::complex<double> *a, int *lda, std::complex<double> *tau, std::complex<double> *work, int *lwork, int *info);
282 extern "C" void cungqr_(int *m, int *n, int *k, std::complex<float> *a, int *lda, std::complex<float> *tau, std::complex<float> *work, int *lwork, int *info);
283 extern "C" void dorgqr_(int *m, int *n, int *k, double *a, int *lda, double *tau, double *work, int *lwork, int *info);
284 extern "C" void sorgqr_(int *m, int *n, int *k, float *a, int *lda, float *tau, float *work, int *lwork, int *info);
285 
286 // ormqr
287 extern "C" void zunmqr_(char *side, char *trans, int *m, int *n, int *k, std::complex<double> *a, int *lda, std::complex<double> *tau, std::complex<double> *c, int *ldc, std::complex<double> *work, int *lwork, int *info);
288 extern "C" void cunmqr_(char *side, char *trans, int *m, int *n, int *k, std::complex<float> *a, int *lda, std::complex<float> *tau, std::complex<float> *c, int *ldc, std::complex<float> *work, int *lwork, int *info);
289 extern "C" void dormqr_(char *side, char *trans, int *m, int *n, int *k, double *a, int *lda, double *tau, double *c, int *ldc, double *work, int *lwork, int *info);
290 extern "C" void sormqr_(char *side, char *trans, int *m, int *n, int *k, float *a, int *lda, float *tau, float *c, int *ldc, float *work, int *lwork, int *info);
291 
292 // syevd
293 extern "C" void zheevd_(char *jobz, char *uplo, int *n, std::complex<double> *a, int *lda, double *w, std::complex<double> *work, int *lwork, double *rwork, int *lrwork, int *iwork, int *liwork, int *info);
294 extern "C" void cheevd_(char *jobz, char *uplo, int *n, std::complex<float> *a, int *lda, float *w, std::complex<float> *work, int *lwork, float *rwork, int *lrwork, int *iwork, int *liwork, int *info);
295 extern "C" void dsyevd_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *iwork, int *liwork, int *info);
296 extern "C" void ssyevd_(char *jobz, char *uplo, int *n, float *a, int *lda, float *w, float *work, int *lwork, int *iwork, int *liwork, int *info);
297 
298 // geev
299 extern "C" void dgeev_(char *jobvl, char *jobvr, int *n, double *a, int *lda, double *wr, double *wi, double* vl, int *ldvl, double *vr, int *ldvr, double *work, int *lwork, int *info);
300 extern "C" void sgeev_(char *jobvl, char *jobvr, int *n, float *a, int *lda, float *wr, float *wi, float* vl, int *ldvl, float *vr, int *ldvr, float *work, int *lwork, int *info);
301 extern "C" void cgeev_(char *jobvl, char *jobvr, int *n,
302              std::complex<float> *a, int *lda,
303              std::complex<float> *w,
304              std::complex<float> *vl, int *ldvl,
305              std::complex<float> *vr, int *ldvr,
306              std::complex<float> *work, int *lwork,
307              float *rwork,
308              int *info);
309 extern "C" void zgeev_(char *jobvl, char *jobvr, int *n,
310              std::complex<double> *a, int *lda,
311              std::complex<double> *w,
312              std::complex<double> *vl, int *ldvl,
313              std::complex<double> *vr, int *ldvr,
314              std::complex<double> *work, int *lwork,
315              double *rwork,
316              int *info);
317 
318 // gesdd
319 extern "C" void zgesdd_(char *jobz, int *m, int *n, std::complex<double> *a, int *lda,
320                         double *s, std::complex<double> *u, int *ldu, std::complex<double> *vt, int *ldvt, std::complex<double> *work, int *lwork, double *rwork, int *iwork, int *info);
321 extern "C" void cgesdd_(char *jobz, int *m, int *n, std::complex<float> *a, int *lda,
322                         float *s, std::complex<float> *u, int *ldu, std::complex<float> *vt, int *ldvt, std::complex<float> *work, int *lwork, float *rwork, int *iwork, int *info);
323 extern "C" void dgesdd_(char *jobz, int *m, int *n, double *a, int *lda,
324                         double *s, double *u, int *ldu, double *vt, int *ldvt, double *work, int *lwork, int *iwork, int *info);
325 extern "C" void sgesdd_(char *jobz, int *m, int *n, float *a, int *lda,
326                         float *s, float *u, int *ldu, float *vt, int *ldvt, float *work, int *lwork, int *iwork, int *info);
327 
328 // getrs
329 extern "C" void zgetrs_(char *trans, int *n, int *nrhs, std::complex<double> *a, int *lda, int *ipiv, std::complex<double> *b, int *ldb, int *info);
330 extern "C" void cgetrs_(char *trans, int *n, int *nrhs, std::complex<float> *a, int *lda, int *ipiv, std::complex<float> *b, int *ldb, int *info);
331 extern "C" void dgetrs_(char *trans, int *n, int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb, int *info);
332 extern "C" void sgetrs_(char *trans, int *n, int *nrhs, float *a, int *lda, int *ipiv, float *b, int *ldb, int *info);
333 
334 // gels
335 extern "C" void zgels_(char *trans, int *m, int *n, int *nrhs,
336     std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb,
337     std::complex<double> *work, int *lwork, int *info);
338 extern "C" void cgels_(char *trans, int *m, int *n, int *nrhs,
339     std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb,
340     std::complex<float> *work, int *lwork, int *info);
341 extern "C" void dgels_(char *trans, int *m, int *n, int *nrhs,
342     double *a, int *lda, double *b, int *ldb,
343     double *work, int *lwork, int *info);
344 extern "C" void sgels_(char *trans, int *m, int *n, int *nrhs,
345     float *a, int *lda, float *b, int *ldb,
346     float *work, int *lwork, int *info);
347 
348 // gelsd
349 extern "C" void zgelsd_(int *m, int *n, int *nrhs,
350     std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb,
351     double *s, double *rcond, int *rank,
352     std::complex<double> *work, int *lwork, double *rwork, int *iwork, int *info);
353 extern "C" void cgelsd_(int *m, int *n, int *nrhs,
354     std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb,
355     float *s, float *rcond, int *rank,
356     std::complex<float> *work, int *lwork, float *rwork, int *iwork, int *info);
357 extern "C" void dgelsd_(int *m, int *n, int *nrhs,
358     double *a, int *lda, double *b, int *ldb,
359     double *s, double *rcond, int *rank,
360     double *work, int *lwork, int *iwork, int *info);
361 extern "C" void sgelsd_(int *m, int *n, int *nrhs,
362     float *a, int *lda, float *b, int *ldb,
363     float *s, float *rcond, int *rank,
364     float *work, int *lwork, int *iwork, int *info);
365 
366 // gelsy
367 extern "C" void zgelsy_(int *m, int *n, int *nrhs,
368     std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb,
369     int *jpvt, double *rcond, int *rank,
370     std::complex<double> *work, int *lwork,
371     double *rwork, int *info);
372 extern "C" void cgelsy_(int *m, int *n, int *nrhs,
373     std::complex<float> * a, int *lda, std::complex<float> *b, int *ldb,
374     int *jpvt, float *rcond, int *rank,
375     std::complex<float> *work, int *lwork,
376     float *rwork, int *info);
377 extern "C" void dgelsy_(int *m, int *n, int *nrhs,
378     double *a, int *lda, double *b, int *ldb,
379     int *jpvt, double *rcond, int *rank,
380     double *work, int *lwork, int *info);
381 extern "C" void sgelsy_(int *m, int *n, int *nrhs,
382     float *a, int *lda, float *b, int *ldb,
383     int *jpvt, float *rcond, int *rank,
384     float *work, int *lwork, int *info);
385 
386 // gelss
387 extern "C" void zgelss_(int *m, int *n, int *nrhs,
388     std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb,
389     double *s, double *rcond, int *rank,
390     std::complex<double> *work, int *lwork,
391     double *rwork, int *info);
392 extern "C" void cgelss_(int *m, int *n, int *nrhs,
393     std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb,
394     float *s, float *rcond, int *rank,
395     std::complex<float> *work, int *lwork,
396     float *rwork, int *info);
397 extern "C" void dgelss_(int *m, int *n, int *nrhs,
398     double *a, int *lda, double *b, int *ldb,
399     double *s, double *rcond, int *rank,
400     double *work, int *lwork, int *info);
401 extern "C" void sgelss_(int *m, int *n, int *nrhs,
402     float *a, int *lda, float *b, int *ldb,
403     float *s, float *rcond, int *rank,
404     float *work, int *lwork, int *info);
405 #endif
406 
407 #if AT_BUILD_WITH_BLAS()
408 // trsm
409 extern "C" void ztrsm_(char *side, char *uplo, char *trans, char *diag, int *n, int *nrhs, std::complex<double> *alpha, std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb);
410 extern "C" void ctrsm_(char *side, char *uplo, char *trans, char *diag, int *n, int *nrhs, std::complex<float> *alpha, std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb);
411 extern "C" void dtrsm_(char *side, char *uplo, char *trans, char *diag, int *n, int *nrhs, double *alpha, double *a, int *lda, double *b, int *ldb);
412 extern "C" void strsm_(char *side, char *uplo, char *trans, char *diag, int *n, int *nrhs, float *alpha, float *a, int *lda, float *b, int *ldb);
413 #endif
414 
415 namespace at::meta {
416 
417 TORCH_META_FUNC(linalg_ldl_factor_ex)
418 (const Tensor& self, bool hermitian, bool check_errors) {
419   at::native::squareCheckInputs(self, "torch.linalg.ldl_factor_ex");
420   at::native::checkFloatingOrComplex(self, "torch.linalg.ldl_factor_ex");
421 
422   auto shape = self.sizes();
423   auto ndim = shape.size();
424 
425   // prefer column major strides
426   auto ld_strides = at::native::batched_matrix_contiguous_strides(shape, /*f-contig=*/true);
427   set_output_strided(0, shape, ld_strides, self.options(), {}); // LD
428 
429   set_output_contiguous(
430       1, shape.slice(0, ndim - 1), self.options().dtype(ScalarType::Int)); // pivots
431 
432   set_output_contiguous(
433       2, shape.slice(0, ndim - 2), self.options().dtype(ScalarType::Int)); // info
434 }
435 
436 TORCH_META_FUNC(linalg_ldl_solve)
437 (const Tensor& LD,
438  const Tensor& pivots,
439  const Tensor& B,
440  bool hermitian) {
441   at::native::squareCheckInputs(LD, "torch.linalg.ldl_solve");
442   at::native::checkFloatingOrComplex(LD, "torch.linalg.ldl_solve");
443   at::native::linearSolveCheckInputs(B, LD, "torch.linalg.ldl_solve");
444   TORCH_CHECK(
445       B.dim() >= 2,
446       "torch.linalg.ldl_solve: Expected B to have at least 2 dimensions, but it has ",
447       B.dim(),
448       " dimensions instead");
449   auto expected_pivots_shape = LD.sizes().slice(0, LD.dim() - 1);
450   TORCH_CHECK(
451       expected_pivots_shape.equals(pivots.sizes()),
452       "torch.linalg.ldl_solve: Expected LD.shape[:-1] and pivots.shape to be the same, but got pivots with shape ",
453       pivots.sizes(),
454       " instead");
455   // pivots is allowed to be any integer type
456   // the LAPACK interface we use is 32-bit, while cuSOLVER uses a 64-bit integer interface
457   TORCH_CHECK(
458       at::isIntegralType(pivots.scalar_type(), /*includeBool=*/false),
459       "torch.linalg.ldl_solve: Expected pivots to be integers. Got ",
460       pivots.scalar_type());
461   TORCH_CHECK(
462       LD.scalar_type() == B.scalar_type(),
463       "torch.linalg.ldl_solve: ",
464       "LD dtype",
465       LD.scalar_type(),
466       " does not match b dtype ",
467       B.scalar_type());
468 
469   auto [B_broadcast_size, _] = at::native::_linalg_broadcast_batch_dims(B, LD);
470 
471   // prefer column major strides
472   auto result_strides = at::native::batched_matrix_contiguous_strides(B_broadcast_size, /*f_contig=*/true);
473   set_output_strided(0, B_broadcast_size, result_strides, B.options(), {});
474 }
475 
476 TORCH_META_FUNC(triangular_solve)(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular) {
477   TORCH_CHECK(self.dim() >= 2,
478            "torch.triangular_solve: Expected b to have at least 2 dimensions, but it has ", self.dim(), " dimensions instead");
479   TORCH_CHECK(A.dim() >= 2,
480            "torch.triangular_solve: Expected A to have at least 2 dimensions, but it has ", A.dim(), " dimensions instead");
481 
482   at::native::linearSolveCheckInputs(self, A, "triangular_solve");
483 
484   if (A.layout() == Layout::Strided) {
485     auto [self_broadcast_size, A_broadcast_size] = at::native::_linalg_broadcast_batch_dims(self, A);
486 
487     // make column major strides for BLAS
488     const auto solution_strides = at::native::batched_matrix_contiguous_strides(self_broadcast_size, /*f-contig=*/true);
489     set_output_raw_strided(0, self_broadcast_size, solution_strides, self.options(), {});
490 
491     // make column major strides for BLAS
492     auto clone_A_strides = at::native::batched_matrix_contiguous_strides(A_broadcast_size, /*f_contig=*/true);
493     set_output_raw_strided(1, A_broadcast_size, clone_A_strides, A.options(), {});
494   } else if (A.layout() == Layout::SparseCsr || A.layout() == Layout::SparseBsr) {
495     // no broadcasting for non-strided layout
496     set_output_raw_strided(0, self.sizes(), {}, self.options(), {}); // make row major strides for Sparse BLAS
497     set_output_raw_strided(1, {0}, {}, self.options(), {}); // return 0-sized tensor
498   } else {
499     TORCH_INTERNAL_ASSERT(false, "triangular_solve: Got an unexpected layout.");
500   }
501 }
502 
503 TORCH_META_FUNC(_linalg_solve_ex)(const Tensor& A,
504                                   const Tensor& B,
505                                   bool left,
506                                   bool check_errors) {
507   // dtype
508   at::native::checkFloatingOrComplex(A, "linalg.solve");
509   TORCH_CHECK(A.scalar_type() == B.scalar_type(),
510               "linalg.solve: Expected A and B to have the same dtype, but found A of type ",
511               A.scalar_type(), " and B of type ", B.scalar_type(), " instead");
512 
513   // NumPy compat: Two types of 'B' tensors are supported:
514   // - 1D tensor or batch of 1D tensors (vector case)
515   // - 2D tensor or batch of 2D tensors (matrix case)
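  // For example (illustrative shapes): A of shape (2, 3, 3) with B of shape (2, 3)
  // is the vector case, and B is unsqueezed to (2, 3, 1) below; A of shape (2, 3, 3)
  // with B of shape (2, 3, 5) is the matrix case, and B is used as-is.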
516   const bool vector_case = at::native::linalg_solve_is_vector_rhs(A, B);
517   auto B_ = vector_case ? B.unsqueeze(-1) : B;
518 
519   // matrix shapes
520   at::native::checkInputsSolver(A, B_, /*left=*/left, "linalg.solve");
521 
522   // Check that B can be broadcasted to the shape of A
523   auto B_broad_shape = std::get<0>(at::native::_linalg_broadcast_batch_dims(B_, A));
524   // We disallow the broadcasting of B as a vector when left=False as, in that case, A.shape = (*, 1, 1)
525   TORCH_CHECK(left || !vector_case, "linalg.solve: Vector broadcasting of the left hand side is not supported for left=False. In this case linalg.solve is equivalent to B / A.squeeze(-1)");
526   auto result_shape = vector_case ? IntArrayRef(B_broad_shape.data(), B_broad_shape.size() - 1)
527                                   : B_broad_shape;
528   auto result_strides = at::native::batched_matrix_contiguous_strides(result_shape, /*f_contig=*/left);
529 
530   set_output_strided(0, result_shape, result_strides, B.options(), {});
531 
532   auto shape = A.sizes();
533   auto ndim = shape.size();
534 
535   // LU
536   auto LU_strides = at::native::batched_matrix_contiguous_strides(shape, /*f-contig=*/true);
537   set_output_strided(1, shape, LU_strides, A.options(), {});
538 
539   // pivots
540   set_output_contiguous(2, shape.slice(0, ndim - 1), A.options().dtype(kInt));
541 
542   // info
543   set_output_contiguous(3, shape.slice(0, ndim - 2), A.options().dtype(kInt));
544 }
545 
546 TORCH_META_FUNC(linalg_inv_ex)(const Tensor& A, bool check_errors) {
547   at::native::squareCheckInputs(A, "linalg.inv");
548   at::native::checkFloatingOrComplex(A, "linalg.inv", /*allow_low_precision_dtypes*/false);
549 
550   auto shape = A.sizes();
551 
552   auto result_strides = at::native::batched_matrix_contiguous_strides(shape, /*f-contig=*/true);
553   set_output_strided(0, shape, result_strides, A.options(), {});
554   set_output_contiguous(
555       1, shape.slice(0, shape.size() - 2), A.options().dtype(ScalarType::Int)); // info
556 }
557 
558 TORCH_META_FUNC(linalg_lu_factor_ex)(const Tensor& A, bool pivot, bool check_errors) {
559   TORCH_CHECK(A.dim() >= 2, "torch.lu_factor: Expected tensor with 2 or more dimensions. Got size: ", A.sizes(), " instead");
560 
561   auto sizes = A.sizes().vec();
562   const auto m = sizes.cend()[-2];
563   const auto n = sizes.cend()[-1];
564 
565   // make column major strides for BLAS
566   auto LU_strides = at::native::batched_matrix_contiguous_strides(sizes, /*f-contig=*/true);
567   set_output_strided(0, sizes, LU_strides, A.options(), {});
568 
569   // Set sizes to the size of pivots
570   sizes.pop_back();
571   sizes.back() = std::min(m, n);
572   set_output_contiguous(1, sizes, A.options().dtype(kInt), {});
573 
574   // Set sizes to the size of info
575   sizes.pop_back();
576   set_output_contiguous(2, sizes, A.options().dtype(kInt), {});
577 }
578 
579 TORCH_META_FUNC(linalg_lu_solve)(const Tensor& LU,
580                                  const Tensor& pivots,
581                                  const Tensor& B,
582                                  bool left,
583                                  bool adjoint) {
584   // dtype
585   at::native::checkFloatingOrComplex(LU, "torch.linalg.lu_solve");
586   TORCH_CHECK(LU.scalar_type() == B.scalar_type(),
587               "linalg.lu_solve: Expected LU and B to have the same dtype, but found LU of type ",
588               LU.scalar_type(), " and B of type ", B.scalar_type(), " instead");
589   TORCH_CHECK(pivots.dtype() == at::kInt,
590               "linalg.lu_solve: pivots should be a Tensor of scalar type torch.int32");
591 
592   // matrix shapes
593   at::native::squareCheckInputs(LU, "torch.linalg.lu_solve");
594   at::native::checkInputsSolver(LU, B, left, "linalg.lu_solve");
595   //
596   TORCH_CHECK(LU.size(-1) == pivots.size(-1),
597               "linalg.lu_solve: Number of pivots per batch should be same as the dimension of the matrix");
598 
599   // batches
600   TORCH_CHECK(
601       LU.sizes().slice(0, LU.dim() - 1).equals(pivots.sizes()),
602       "linalg.lu_solve: Expected LU.shape[:-1] and pivots.shape to be the same, but got pivots with shape ",
603       pivots.sizes(), " instead");
604 
605   // This one checks that B can be broadcasted to the shape of A
606   auto B_broadcast_size = std::get<0>(at::native::_linalg_broadcast_batch_dims(B, LU));
607   auto result_strides = at::native::batched_matrix_contiguous_strides(B_broadcast_size, /*f_contig=*/left);
608 
609   set_output_strided(0, B_broadcast_size, result_strides, B.options(), {});
610 }
611 
612 TORCH_META_FUNC(linalg_cholesky_ex)(const Tensor& A,
613                                     bool upper,
614                                     bool check_errors) {
615   at::native::squareCheckInputs(A, "linalg.cholesky");
616   at::native::checkFloatingOrComplex(A, "linalg.cholesky");
617 
618   auto A_shape = A.sizes();
619   auto ndim = A_shape.size();
620 
621   // L
622   auto L_strides = at::native::batched_matrix_contiguous_strides(A_shape, /*f-contig=*/true);
623   set_output_strided(0, A_shape, L_strides, A.options(), {});
624 
625   // info
626   set_output_contiguous(1, A_shape.slice(0, ndim - 2), A.options().dtype(ScalarType::Int));
627 }
628 
629 TORCH_META_FUNC(linalg_qr)(const Tensor& A,
630                            c10::string_view mode) {
631   at::native::checkIsMatrix(A, "linalg.qr");
632   at::native::checkFloatingOrComplex(A, "linalg.qr");
633   auto [compute_q, reduced_mode] = at::native::_parse_qr_mode(mode);
634 
635   auto A_shape = A.sizes().vec();
636   const auto m = A_shape.cend()[-2];
637   const auto n = A_shape.cend()[-1];
638   const auto k = std::min(m, n);
639 
640   if (compute_q) {
641     auto Q_shape = A_shape;
642     Q_shape.end()[-1] = reduced_mode ? k : m;
643     auto Q_strides = at::native::batched_matrix_contiguous_strides(Q_shape, /*f-contig=*/true);
644     set_output_strided(0, Q_shape, Q_strides, A.options(), {});
645   } else {
646     set_output_raw_strided(0, {0}, {}, A.options(), {});
647   }
648 
649   // For readability
650   auto R_shape = std::move(A_shape);
651   R_shape.end()[-2] = (reduced_mode || !compute_q) ? k : m;
652   auto R_strides = at::native::batched_matrix_contiguous_strides(R_shape, /*f-contig=*/true);
653   set_output_strided(1, R_shape, R_strides, A.options(), {});
654 }
655 
656 
657 TORCH_META_FUNC(_linalg_svd)(const Tensor& A,
658                              bool full_matrices,
659                              bool compute_uv,
660                              std::optional<c10::string_view> driver) {
661   at::native::checkIsMatrix(A, "linalg.svd");
662   at::native::checkFloatingOrComplex(A, "linalg.svd");
663 
664   auto sizes = A.sizes().vec();
665   const auto m = sizes.cend()[-2];
666   const auto n = sizes.cend()[-1];
667   const auto k = std::min(m, n);
668 
669   // Prepare sizes for U
670   if (compute_uv) {
671     sizes.back() = full_matrices ? m : k;
672     auto U_strides = at::native::batched_matrix_contiguous_strides(sizes, /*f-contig=*/true);
673     set_output_strided(0, sizes, U_strides, A.options(), {});
674 
675     // Prepare sizes for Vh
676     sizes.end()[-2] = full_matrices ? n : k;
677     sizes.end()[-1] = n;
678 
679     // We need to distinguish the cuSOLVER case, as the cuSOLVER algorithms we use
680     // expect F-contig matrices, but they compute V rather than Vh
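    // With C-contiguous batch strides for Vh, the transposed view V is F-contiguous,
    // which is the layout those cuSOLVER routines write into.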
681     const bool use_cusolver = at::native::svd_uses_cusolver(A);
682     auto Vh_strides = at::native::batched_matrix_contiguous_strides(sizes, /*f-contig=*/!use_cusolver);
683     set_output_strided(2, sizes, Vh_strides, A.options(), {});
684   } else {
685     set_output_raw_strided(0, {0}, {}, A.options(), {});
686     set_output_raw_strided(2, {0}, {}, A.options(), {});
687   }
688 
689   // Prepare sizes for S. S is always real, even when A is complex.
690   sizes.pop_back();
691   sizes.end()[-1] = k;
692   set_output_contiguous(1, sizes, A.options().dtype(c10::toRealValueType(A.scalar_type())), {});
693 }
694 
695 TORCH_META_FUNC(lu_unpack)(const Tensor& LU, const Tensor& pivots, bool unpack_data, bool unpack_pivots) {
696   TORCH_CHECK(LU.dim() >= 2, "torch.lu_unpack: Expected tensor with 2 or more dimensions. Got size: ", LU.sizes(), " instead");
697   if (unpack_pivots) {
698     TORCH_CHECK(pivots.scalar_type() == at::kInt,
699         "torch.lu_unpack: LU_pivots is expected to be a contiguous tensor of torch.int32 dtype.\n"
700         "Note: this function is intended to be used with the output produced by torch.linalg.lu_factor");
701   }
702 
703   auto sizes = LU.sizes().vec();
704   const auto m = sizes.cend()[-2];
705   const auto n = sizes.cend()[-1];
706   const auto k = std::min(m, n);
707 
708   // P.shape[-2:] == (m, m) (or size zero if unpack_pivots == False)
709   sizes.end()[-1] = m;
710   if (unpack_pivots) {
711     set_output_raw_strided(0, sizes, {}, LU.options(), {});
712   } else {
713     set_output_raw_strided(0, {0}, {}, LU.options(), {});
714   }
715 
716   if (unpack_data) {
717     // L.shape[-2:] == (m, k)
718     sizes.end()[-1] = k;
719     set_output_raw_strided(1, sizes, {}, LU.options(), {});
720 
721     // U.shape[-2:] == (k, n)
722     sizes.end()[-2] = k;
723     sizes.end()[-1] = n;
724     set_output_raw_strided(2, sizes, {}, LU.options(), {});
725   } else {
726     set_output_raw_strided(1, {0}, {}, LU.options(), {});
727     set_output_raw_strided(2, {0}, {}, LU.options(), {});
728   }
729 }
730 
731 TORCH_META_FUNC(_linalg_eigh)(const Tensor& A,
732                               c10::string_view uplo,
733                               bool compute_v) {
734   at::native::squareCheckInputs(A, "linalg.eigh");
735   at::native::checkUplo(uplo);
736 
737   auto shape = A.sizes().vec();
738   if (compute_v) {
739     // eigenvectors
740     auto V_strides = at::native::batched_matrix_contiguous_strides(shape, /*f-contig=*/true);
741     set_output_strided(1, shape, V_strides, A.options(), {});
742   } else {
743     set_output_raw_strided(1, {0}, {}, A.options(), {});
744   }
745 
746   // eigenvalues
747   shape.pop_back();
748   set_output_contiguous(0, shape, A.options().dtype(c10::toRealValueType(A.scalar_type())), {});
749 }
750 
751 TORCH_META_FUNC(linalg_lu)(const Tensor& A, bool pivot) {
752   TORCH_CHECK(A.dim() >= 2, "linalg.lu: Expected tensor with 2 or more dimensions. Got size: ", A.sizes(), " instead");
753 
754   auto sizes = A.sizes().vec();
755   const auto m = sizes.cend()[-2];
756   const auto n = sizes.cend()[-1];
757   const auto k = std::min(m, n);
758 
759   // P.shape[-2:] == (m, m) (or size zero if pivot == False)
760   sizes.end()[-1] = m;
761   if (pivot) {
762     set_output_raw_strided(0, sizes, {}, A.options(), {});
763   } else {
764     set_output_raw_strided(0, {0}, {}, A.options(), {});
765   }
766 
767   // L.shape[-2:] == (m, k)
768   sizes.end()[-1] = k;
769   set_output_raw_strided(1, sizes, {}, A.options(), {});
770 
771   // U.shape[-2:] == (k, n)
772   sizes.end()[-2] = k;
773   sizes.end()[-1] = n;
774   set_output_raw_strided(2, sizes, {}, A.options(), {});
775 }
776 
777 } // namespace at::meta
778 
779 namespace at::native {
780 
781 #if AT_BUILD_WITH_LAPACK()
782 // Define the per-batch functions to be used in the main implementation of the batched
783 // linear algebra operations
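// Each template below is specialized for float, double, and their complex
// counterparts; a specialization forwards to the matching LAPACK routine declared
// above, reinterpret_cast-ing c10::complex pointers to std::complex for the call.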
784 
785 template<class scalar_t>
786 void lapackCholeskySolve(char uplo, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int *info);
787 
788 template<class scalar_t, class value_t=scalar_t>
789 void lapackSymeig(char jobz, char uplo, int n, scalar_t *a, int lda, value_t *w, scalar_t *work, int lwork, value_t *rwork, int *info);
790 
791 template<> void lapackLu<c10::complex<double>>(int m, int n, c10::complex<double> *a, int lda, int *ipiv, int *info) {
792   zgetrf_(&m, &n, reinterpret_cast<std::complex<double>*>(a), &lda, ipiv, info);
793 }
794 
795 template<> void lapackLu<c10::complex<float>>(int m, int n, c10::complex<float> *a, int lda, int *ipiv, int *info) {
796   cgetrf_(&m, &n, reinterpret_cast<std::complex<float>*>(a), &lda, ipiv, info);
797 }
798 
799 template<> void lapackLu<double>(int m, int n, double *a, int lda, int *ipiv, int *info) {
800   dgetrf_(&m, &n, a, &lda, ipiv, info);
801 }
802 
803 template<> void lapackLu<float>(int m, int n, float *a, int lda, int *ipiv, int *info) {
804   sgetrf_(&m, &n, a, &lda, ipiv, info);
805 }
806 
807 template<> void lapackCholeskySolve<c10::complex<double>>(char uplo, int n, int nrhs, c10::complex<double> *a, int lda, c10::complex<double> *b, int ldb, int *info) {
808   zpotrs_(&uplo, &n, &nrhs, reinterpret_cast<std::complex<double>*>(a), &lda, reinterpret_cast<std::complex<double>*>(b), &ldb, info);
809 }
810 
811 template<> void lapackCholeskySolve<c10::complex<float>>(char uplo, int n, int nrhs, c10::complex<float> *a, int lda, c10::complex<float> *b, int ldb, int *info) {
812   cpotrs_(&uplo, &n, &nrhs, reinterpret_cast<std::complex<float>*>(a), &lda, reinterpret_cast<std::complex<float>*>(b), &ldb, info);
813 }
814 
815 template<> void lapackCholeskySolve<double>(char uplo, int n, int nrhs, double *a, int lda, double *b, int ldb, int *info) {
816   dpotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
817 }
818 
819 template<> void lapackCholeskySolve<float>(char uplo, int n, int nrhs, float *a, int lda, float *b, int ldb, int *info) {
820   spotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
821 }
822 
823 template<> void lapackCholesky<c10::complex<double>>(char uplo, int n, c10::complex<double> *a, int lda, int *info) {
824   zpotrf_(&uplo, &n, reinterpret_cast<std::complex<double>*>(a), &lda, info);
825 }
826 
827 template<> void lapackCholesky<c10::complex<float>>(char uplo, int n, c10::complex<float> *a, int lda, int *info) {
828   cpotrf_(&uplo, &n, reinterpret_cast<std::complex<float>*>(a), &lda, info);
829 }
830 
831 template<> void lapackCholesky<double>(char uplo, int n, double *a, int lda, int *info) {
832   dpotrf_(&uplo, &n, a, &lda, info);
833 }
834 
835 template<> void lapackCholesky<float>(char uplo, int n, float *a, int lda, int *info) {
836   spotrf_(&uplo, &n, a, &lda, info);
837 }
838 
839 template<> void lapackCholeskyInverse<c10::complex<double>>(char uplo, int n, c10::complex<double> *a, int lda, int *info) {
840   zpotri_(&uplo, &n, reinterpret_cast<std::complex<double>*>(a), &lda, info);
841 }
842 
843 template<> void lapackCholeskyInverse<c10::complex<float>>(char uplo, int n, c10::complex<float> *a, int lda, int *info) {
844   cpotri_(&uplo, &n, reinterpret_cast<std::complex<float>*>(a), &lda, info);
845 }
846 
847 template<> void lapackCholeskyInverse<double>(char uplo, int n, double *a, int lda, int *info) {
848   dpotri_(&uplo, &n, a, &lda, info);
849 }
850 
851 template<> void lapackCholeskyInverse<float>(char uplo, int n, float *a, int lda, int *info) {
852   spotri_(&uplo, &n, a, &lda, info);
853 }
854 
855 template<> void lapackGeqrf<c10::complex<double>>(int m, int n, c10::complex<double> *a, int lda, c10::complex<double> *tau, c10::complex<double> *work, int lwork, int *info) {
856   zgeqrf_(&m, &n, reinterpret_cast<std::complex<double>*>(a), &lda, reinterpret_cast<std::complex<double>*>(tau), reinterpret_cast<std::complex<double>*>(work), &lwork, info);
857 }
858 
859 template<> void lapackGeqrf<c10::complex<float>>(int m, int n, c10::complex<float> *a, int lda, c10::complex<float> *tau, c10::complex<float> *work, int lwork, int *info) {
860   cgeqrf_(&m, &n, reinterpret_cast<std::complex<float>*>(a), &lda, reinterpret_cast<std::complex<float>*>(tau), reinterpret_cast<std::complex<float>*>(work), &lwork, info);
861 }
862 
863 template<> void lapackGeqrf<double>(int m, int n, double *a, int lda, double *tau, double *work, int lwork, int *info) {
864   dgeqrf_(&m, &n, a, &lda, tau, work, &lwork, info);
865 }
866 
867 template<> void lapackGeqrf<float>(int m, int n, float *a, int lda, float *tau, float *work, int lwork, int *info) {
868   sgeqrf_(&m, &n, a, &lda, tau, work, &lwork, info);
869 }
870 
871 template<> void lapackOrgqr<c10::complex<double>>(int m, int n, int k, c10::complex<double> *a, int lda, c10::complex<double> *tau, c10::complex<double> *work, int lwork, int *info) {
872   zungqr_(&m, &n, &k, reinterpret_cast<std::complex<double>*>(a), &lda, reinterpret_cast<std::complex<double>*>(tau), reinterpret_cast<std::complex<double>*>(work), &lwork, info);
873 }
874 
875 template<> void lapackOrgqr<c10::complex<float>>(int m, int n, int k, c10::complex<float> *a, int lda, c10::complex<float> *tau, c10::complex<float> *work, int lwork, int *info) {
876   cungqr_(&m, &n, &k, reinterpret_cast<std::complex<float>*>(a), &lda, reinterpret_cast<std::complex<float>*>(tau), reinterpret_cast<std::complex<float>*>(work), &lwork, info);
877 }
878 
879 template<> void lapackOrgqr<double>(int m, int n, int k, double *a, int lda, double *tau, double *work, int lwork, int *info) {
880   dorgqr_(&m, &n, &k, a, &lda, tau, work, &lwork, info);
881 }
882 
883 template<> void lapackOrgqr<float>(int m, int n, int k, float *a, int lda, float *tau, float *work, int lwork, int *info) {
884   sorgqr_(&m, &n, &k, a, &lda, tau, work, &lwork, info);
885 }
886 
887 template<> void lapackOrmqr<c10::complex<double>>(char side, char trans, int m, int n, int k, c10::complex<double> *a, int lda, c10::complex<double> *tau, c10::complex<double> *c, int ldc, c10::complex<double> *work, int lwork, int *info) {
888   zunmqr_(&side, &trans, &m, &n, &k, reinterpret_cast<std::complex<double>*>(a), &lda, reinterpret_cast<std::complex<double>*>(tau), reinterpret_cast<std::complex<double>*>(c), &ldc, reinterpret_cast<std::complex<double>*>(work), &lwork, info);
889 }
890 
891 template<> void lapackOrmqr<c10::complex<float>>(char side, char trans, int m, int n, int k, c10::complex<float> *a, int lda, c10::complex<float> *tau, c10::complex<float> *c, int ldc, c10::complex<float> *work, int lwork, int *info) {
892   cunmqr_(&side, &trans, &m, &n, &k, reinterpret_cast<std::complex<float>*>(a), &lda, reinterpret_cast<std::complex<float>*>(tau), reinterpret_cast<std::complex<float>*>(c), &ldc, reinterpret_cast<std::complex<float>*>(work), &lwork, info);
893 }
894 
895 template<> void lapackOrmqr<double>(char side, char trans, int m, int n, int k, double *a, int lda, double *tau, double *c, int ldc, double *work, int lwork, int *info) {
896   dormqr_(&side, &trans, &m, &n, &k, a, &lda, tau, c, &ldc, work, &lwork, info);
897 }
898 
899 template<> void lapackOrmqr<float>(char side, char trans, int m, int n, int k, float *a, int lda, float *tau, float *c, int ldc, float *work, int lwork, int *info) {
900   sormqr_(&side, &trans, &m, &n, &k, a, &lda, tau, c, &ldc, work, &lwork, info);
901 }
902 
903 template<> void lapackSyevd<c10::complex<double>, double>(char jobz, char uplo, int n, c10::complex<double> *a, int lda, double *w, c10::complex<double> *work, int lwork, double *rwork, int lrwork, int *iwork, int liwork, int *info) {
904   zheevd_(&jobz, &uplo, &n, reinterpret_cast<std::complex<double>*>(a), &lda, w, reinterpret_cast<std::complex<double>*>(work), &lwork, rwork, &lrwork, iwork, &liwork, info);
905 }
906 
907 template<> void lapackSyevd<c10::complex<float>, float>(char jobz, char uplo, int n, c10::complex<float> *a, int lda, float *w, c10::complex<float> *work, int lwork, float *rwork, int lrwork, int *iwork, int liwork, int *info) {
908   cheevd_(&jobz, &uplo, &n, reinterpret_cast<std::complex<float>*>(a), &lda, w, reinterpret_cast<std::complex<float>*>(work), &lwork, rwork, &lrwork, iwork, &liwork, info);
909 }
910 
911 template<> void lapackSyevd<double>(char jobz, char uplo, int n, double *a, int lda, double *w, double *work, int lwork, double *rwork, int lrwork, int *iwork, int liwork, int *info) {
912   (void)rwork;  // unused
913   (void)lrwork;  // unused
914   dsyevd_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info);
915 }
916 
917 template<> void lapackSyevd<float>(char jobz, char uplo, int n, float *a, int lda, float *w, float *work, int lwork, float *rwork, int lrwork, int *iwork, int liwork, int *info) {
918   (void)rwork;  // unused
919   (void)lrwork;  // unused
920   ssyevd_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info);
921 }
922 
923 template<> void lapackEig<double>(char jobvl, char jobvr, int n, double *a, int lda, double *w, double* vl, int ldvl, double *vr, int ldvr, double *work, int lwork, double *rwork, int *info) {
924   // lapack [sd]geev wants two separate output arrays: wr and wi for the real
925   // and imaginary parts
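  // Both halves live in the single buffer w: wr points at w[0..n-1] and wi at
  // w[n..2n-1], so the caller must provide room for 2 * n values in w.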
926   double *wr = w;
927   double *wi = w ? w + n : nullptr;
928   (void)rwork; // unused
929   dgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info);
930 }
931 
932 template<> void lapackEig<float>(char jobvl, char jobvr, int n, float *a, int lda, float *w, float* vl, int ldvl, float *vr, int ldvr, float *work, int lwork, float *rwork, int *info) {
933   // lapack [sd]geev wants two separate output arrays: wr and wi for the real
934   // and imaginary parts
935   float *wr = w;
936   float *wi = w ? w + n : nullptr;
937   (void)rwork; // unused
938   sgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info);
939 }
940 
941 template<> void lapackEig<c10::complex<double>, double>(char jobvl, char jobvr, int n, c10::complex<double> *a, int lda, c10::complex<double> *w, c10::complex<double> *vl, int ldvl, c10::complex<double> *vr, int ldvr, c10::complex<double> *work, int lwork, double *rwork, int *info) {
942   zgeev_(&jobvl, &jobvr, &n,
943          reinterpret_cast<std::complex<double>*>(a), &lda,
944          reinterpret_cast<std::complex<double>*>(w),
945          reinterpret_cast<std::complex<double>*>(vl), &ldvl,
946          reinterpret_cast<std::complex<double>*>(vr), &ldvr,
947          reinterpret_cast<std::complex<double>*>(work), &lwork,
948          rwork, info);
949 }
950 
951 template<> void lapackEig<c10::complex<float>, float>(char jobvl, char jobvr, int n, c10::complex<float> *a, int lda, c10::complex<float> *w, c10::complex<float> *vl, int ldvl, c10::complex<float> *vr, int ldvr, c10::complex<float> *work, int lwork, float *rwork, int *info) {
952   cgeev_(&jobvl, &jobvr, &n,
953          reinterpret_cast<std::complex<float>*>(a), &lda,
954          reinterpret_cast<std::complex<float>*>(w),
955          reinterpret_cast<std::complex<float>*>(vl), &ldvl,
956          reinterpret_cast<std::complex<float>*>(vr), &ldvr,
957          reinterpret_cast<std::complex<float>*>(work), &lwork,
958          rwork, info);
959 }
960 
961 template<> void lapackSvd<c10::complex<double>, double>(char jobz, int m, int n, c10::complex<double> *a, int lda,
962                                   double *s, c10::complex<double> *u, int ldu, c10::complex<double> *vt, int ldvt, c10::complex<double> *work, int lwork, double *rwork, int *iwork, int *info) {
963   zgesdd_(&jobz, &m, &n, reinterpret_cast<std::complex<double>*>(a), &lda, s, reinterpret_cast<std::complex<double>*>(u), &ldu,
964           reinterpret_cast<std::complex<double>*>(vt), &ldvt, reinterpret_cast<std::complex<double>*>(work), &lwork, rwork, iwork, info);
965 }
966 
967 template<> void lapackSvd<c10::complex<float>, float>(char jobz, int m, int n, c10::complex<float> *a, int lda,
968                                  float *s, c10::complex<float> *u, int ldu, c10::complex<float> *vt, int ldvt, c10::complex<float> *work, int lwork, float *rwork, int *iwork, int *info) {
969   cgesdd_(&jobz, &m, &n, reinterpret_cast<std::complex<float>*>(a), &lda, s, reinterpret_cast<std::complex<float>*>(u), &ldu,
970           reinterpret_cast<std::complex<float>*>(vt), &ldvt, reinterpret_cast<std::complex<float>*>(work), &lwork, rwork, iwork, info);
971 }
972 
973 template<> void lapackSvd<double>(char jobz, int m, int n, double *a, int lda,
974                                   double *s, double *u, int ldu, double *vt, int ldvt, double *work, int lwork, double *rwork, int *iwork, int *info) {
975   dgesdd_(&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, info);
976 }
977 
978 template<> void lapackSvd<float>(char jobz, int m, int n, float *a, int lda,
979                                  float *s, float *u, int ldu, float *vt, int ldvt, float *work, int lwork, float *rwork, int *iwork, int *info) {
980   sgesdd_(&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, info);
981 }
982 
983 template <>
984 void lapackLdlSymmetric<double>(
985     char uplo,
986     int n,
987     double* a,
988     int lda,
989     int* ipiv,
990     double* work,
991     int lwork,
992     int* info) {
993   dsytrf_(&uplo, &n, a, &lda, ipiv, work, &lwork, info);
994 }
995 
996 template <>
997 void lapackLdlSymmetric<float>(
998     char uplo,
999     int n,
1000     float* a,
1001     int lda,
1002     int* ipiv,
1003     float* work,
1004     int lwork,
1005     int* info) {
1006   ssytrf_(&uplo, &n, a, &lda, ipiv, work, &lwork, info);
1007 }
1008 
1009 template <>
1010 void lapackLdlSymmetric<c10::complex<double>>(
1011     char uplo,
1012     int n,
1013     c10::complex<double>* a,
1014     int lda,
1015     int* ipiv,
1016     c10::complex<double>* work,
1017     int lwork,
1018     int* info) {
1019   zsytrf_(
1020       &uplo,
1021       &n,
1022       reinterpret_cast<std::complex<double>*>(a),
1023       &lda,
1024       ipiv,
1025       reinterpret_cast<std::complex<double>*>(work),
1026       &lwork,
1027       info);
1028 }
1029 
1030 template <>
1031 void lapackLdlSymmetric<c10::complex<float>>(
1032     char uplo,
1033     int n,
1034     c10::complex<float>* a,
1035     int lda,
1036     int* ipiv,
1037     c10::complex<float>* work,
1038     int lwork,
1039     int* info) {
1040   csytrf_(
1041       &uplo,
1042       &n,
1043       reinterpret_cast<std::complex<float>*>(a),
1044       &lda,
1045       ipiv,
1046       reinterpret_cast<std::complex<float>*>(work),
1047       &lwork,
1048       info);
1049 }
1050 
1051 template <>
1052 void lapackLdlHermitian<double>(
1053     char uplo,
1054     int n,
1055     double* a,
1056     int lda,
1057     int* ipiv,
1058     double* work,
1059     int lwork,
1060     int* info) {
1061   dsytrf_(&uplo, &n, a, &lda, ipiv, work, &lwork, info);
1062 }
1063 
1064 template <>
1065 void lapackLdlHermitian<float>(
1066     char uplo,
1067     int n,
1068     float* a,
1069     int lda,
1070     int* ipiv,
1071     float* work,
1072     int lwork,
1073     int* info) {
1074   ssytrf_(&uplo, &n, a, &lda, ipiv, work, &lwork, info);
1075 }
1076 
1077 template <>
1078 void lapackLdlHermitian<c10::complex<double>>(
1079     char uplo,
1080     int n,
1081     c10::complex<double>* a,
1082     int lda,
1083     int* ipiv,
1084     c10::complex<double>* work,
1085     int lwork,
1086     int* info) {
1087   zhetrf_(
1088       &uplo,
1089       &n,
1090       reinterpret_cast<std::complex<double>*>(a),
1091       &lda,
1092       ipiv,
1093       reinterpret_cast<std::complex<double>*>(work),
1094       &lwork,
1095       info);
1096 }
1097 
1098 template <>
1099 void lapackLdlHermitian<c10::complex<float>>(
1100     char uplo,
1101     int n,
1102     c10::complex<float>* a,
1103     int lda,
1104     int* ipiv,
1105     c10::complex<float>* work,
1106     int lwork,
1107     int* info) {
1108   chetrf_(
1109       &uplo,
1110       &n,
1111       reinterpret_cast<std::complex<float>*>(a),
1112       &lda,
1113       ipiv,
1114       reinterpret_cast<std::complex<float>*>(work),
1115       &lwork,
1116       info);
1117 }
1118 
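// The lapackLdlSolve* specializations wrap ?sytrs / ?hetrs, which solve
// A * X = B in-place in `b` using the factorization and pivots produced by
// the corresponding ?sytrf / ?hetrf call above.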
1119 template <>
1120 void lapackLdlSolveSymmetric<double>(
1121     char uplo,
1122     int n,
1123     int nrhs,
1124     double* a,
1125     int lda,
1126     int* ipiv,
1127     double* b,
1128     int ldb,
1129     int* info) {
1130   dsytrs_(&uplo, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
1131 }
1132 
1133 template <>
1134 void lapackLdlSolveSymmetric<float>(
1135     char uplo,
1136     int n,
1137     int nrhs,
1138     float* a,
1139     int lda,
1140     int* ipiv,
1141     float* b,
1142     int ldb,
1143     int* info) {
1144   ssytrs_(&uplo, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
1145 }
1146 
1147 template <>
1148 void lapackLdlSolveSymmetric<c10::complex<double>>(
1149     char uplo,
1150     int n,
1151     int nrhs,
1152     c10::complex<double>* a,
1153     int lda,
1154     int* ipiv,
1155     c10::complex<double>* b,
1156     int ldb,
1157     int* info) {
1158   zsytrs_(
1159       &uplo,
1160       &n,
1161       &nrhs,
1162       reinterpret_cast<std::complex<double>*>(a),
1163       &lda,
1164       ipiv,
1165       reinterpret_cast<std::complex<double>*>(b),
1166       &ldb,
1167       info);
1168 }
1169 
1170 template <>
1171 void lapackLdlSolveSymmetric<c10::complex<float>>(
1172     char uplo,
1173     int n,
1174     int nrhs,
1175     c10::complex<float>* a,
1176     int lda,
1177     int* ipiv,
1178     c10::complex<float>* b,
1179     int ldb,
1180     int* info) {
1181   csytrs_(
1182       &uplo,
1183       &n,
1184       &nrhs,
1185       reinterpret_cast<std::complex<float>*>(a),
1186       &lda,
1187       ipiv,
1188       reinterpret_cast<std::complex<float>*>(b),
1189       &ldb,
1190       info);
1191 }
1192 
1193 template <>
1194 void lapackLdlSolveHermitian<double>(
1195     char uplo,
1196     int n,
1197     int nrhs,
1198     double* a,
1199     int lda,
1200     int* ipiv,
1201     double* b,
1202     int ldb,
1203     int* info) {
1204   dsytrs_(&uplo, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
1205 }
1206 
1207 template <>
1208 void lapackLdlSolveHermitian<float>(
1209     char uplo,
1210     int n,
1211     int nrhs,
1212     float* a,
1213     int lda,
1214     int* ipiv,
1215     float* b,
1216     int ldb,
1217     int* info) {
1218   ssytrs_(&uplo, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
1219 }
1220 
1221 template <>
1222 void lapackLdlSolveHermitian<c10::complex<double>>(
1223     char uplo,
1224     int n,
1225     int nrhs,
1226     c10::complex<double>* a,
1227     int lda,
1228     int* ipiv,
1229     c10::complex<double>* b,
1230     int ldb,
1231     int* info) {
1232   zhetrs_(
1233       &uplo,
1234       &n,
1235       &nrhs,
1236       reinterpret_cast<std::complex<double>*>(a),
1237       &lda,
1238       ipiv,
1239       reinterpret_cast<std::complex<double>*>(b),
1240       &ldb,
1241       info);
1242 }
1243 
1244 template <>
1245 void lapackLdlSolveHermitian<c10::complex<float>>(
1246     char uplo,
1247     int n,
1248     int nrhs,
1249     c10::complex<float>* a,
1250     int lda,
1251     int* ipiv,
1252     c10::complex<float>* b,
1253     int ldb,
1254     int* info) {
1255   chetrs_(
1256       &uplo,
1257       &n,
1258       &nrhs,
1259       reinterpret_cast<std::complex<float>*>(a),
1260       &lda,
1261       ipiv,
1262       reinterpret_cast<std::complex<float>*>(b),
1263       &ldb,
1264       info);
1265 }
1266 
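// lapackLuSolve wraps ?getrs, which solves op(A) * X = B (with op chosen by
// `trans`) using the LU factorization and pivots previously computed by
// ?getrf; the solution overwrites `b`.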
1267 template<> void lapackLuSolve<c10::complex<double>>(char trans, int n, int nrhs, c10::complex<double> *a, int lda, int *ipiv, c10::complex<double> *b, int ldb, int *info) {
1268   zgetrs_(&trans, &n, &nrhs, reinterpret_cast<std::complex<double>*>(a), &lda, ipiv, reinterpret_cast<std::complex<double>*>(b), &ldb, info);
1269 }
1270 
1271 template<> void lapackLuSolve<c10::complex<float>>(char trans, int n, int nrhs, c10::complex<float> *a, int lda, int *ipiv, c10::complex<float> *b, int ldb, int *info) {
1272   cgetrs_(&trans, &n, &nrhs, reinterpret_cast<std::complex<float>*>(a), &lda, ipiv, reinterpret_cast<std::complex<float>*>(b), &ldb, info);
1273 }
1274 
1275 template<> void lapackLuSolve<double>(char trans, int n, int nrhs, double *a, int lda, int *ipiv, double *b, int ldb, int *info) {
1276   dgetrs_(&trans, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
1277 }
1278 
1279 template<> void lapackLuSolve<float>(char trans, int n, int nrhs, float *a, int lda, int *ipiv, float *b, int ldb, int *info) {
1280   sgetrs_(&trans, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
1281 }
1282 
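// Least-squares drivers wrapped below:
//   ?gels  assumes full rank and uses a plain QR/LQ factorization;
//   ?gelsd is rank-revealing and uses a divide-and-conquer SVD;
//   ?gelsy is rank-revealing and uses a complete orthogonal (column-pivoted QR) factorization;
//   ?gelss is rank-revealing and uses the standard SVD.
// The real-valued specializations do not forward `rwork`, as only the complex
// LAPACK routines take that argument.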
1283 template<> void lapackGels<c10::complex<double>>(
1284     char trans, int m, int n, int nrhs,
1285     c10::complex<double> *a, int lda, c10::complex<double> *b, int ldb,
1286     c10::complex<double> *work, int lwork, int *info) {
1287   zgels_(&trans, &m, &n, &nrhs,
1288       reinterpret_cast<std::complex<double>*>(a), &lda,
1289       reinterpret_cast<std::complex<double>*>(b), &ldb,
1290       reinterpret_cast<std::complex<double>*>(work), &lwork, info);
1291 }
1292 
1293 template<> void lapackGels<c10::complex<float>>(
1294     char trans, int m, int n, int nrhs,
1295     c10::complex<float> *a, int lda, c10::complex<float> *b, int ldb,
1296     c10::complex<float> *work, int lwork, int *info) {
1297   cgels_(&trans, &m, &n, &nrhs,
1298       reinterpret_cast<std::complex<float>*>(a), &lda,
1299       reinterpret_cast<std::complex<float>*>(b), &ldb,
1300       reinterpret_cast<std::complex<float>*>(work), &lwork, info);
1301 }
1302 
1303 template<> void lapackGels<double>(
1304     char trans, int m, int n, int nrhs,
1305     double *a, int lda, double *b, int ldb,
1306     double *work, int lwork, int *info) {
1307   dgels_(&trans, &m, &n, &nrhs,
1308       a, &lda, b, &ldb, work, &lwork, info);
1309 }
1310 
1311 template<> void lapackGels<float>(
1312     char trans, int m, int n, int nrhs,
1313     float *a, int lda, float *b, int ldb,
1314     float *work, int lwork, int *info) {
1315   sgels_(&trans, &m, &n, &nrhs,
1316       a, &lda, b, &ldb, work, &lwork, info);
1317 }
1318 
1319 template<> void lapackGelsd<c10::complex<double>, double>(
1320     int m, int n, int nrhs,
1321     c10::complex<double> *a, int lda, c10::complex<double> *b, int ldb,
1322     double *s, double rcond, int *rank,
1323     c10::complex<double> *work, int lwork,
1324     double *rwork, int *iwork, int *info) {
1325   zgelsd_(&m, &n, &nrhs,
1326       reinterpret_cast<std::complex<double>*>(a), &lda,
1327       reinterpret_cast<std::complex<double>*>(b), &ldb,
1328       s, &rcond, rank,
1329       reinterpret_cast<std::complex<double>*>(work), &lwork,
1330       rwork, iwork, info);
1331 }
1332 
1333 template<> void lapackGelsd<c10::complex<float>, float>(
1334     int m, int n, int nrhs,
1335     c10::complex<float> *a, int lda, c10::complex<float> *b, int ldb,
1336     float *s, float rcond, int *rank,
1337     c10::complex<float> *work, int lwork,
1338     float *rwork, int *iwork, int *info) {
1339   cgelsd_(&m, &n, &nrhs,
1340       reinterpret_cast<std::complex<float>*>(a), &lda,
1341       reinterpret_cast<std::complex<float>*>(b), &ldb,
1342       s, &rcond, rank,
1343       reinterpret_cast<std::complex<float>*>(work), &lwork,
1344       rwork, iwork, info);
1345 }
1346 
1347 template<> void lapackGelsd<double>(
1348     int m, int n, int nrhs,
1349     double *a, int lda, double *b, int ldb,
1350     double *s, double rcond, int *rank,
1351     double *work, int lwork,
1352     double *rwork, int *iwork, int *info) {
1353   dgelsd_(&m, &n, &nrhs,
1354       a, &lda, b, &ldb,
1355       s, &rcond, rank,
1356       work, &lwork, iwork, info);
1357 }
1358 
1359 template<> void lapackGelsd<float>(
1360     int m, int n, int nrhs,
1361     float *a, int lda, float *b, int ldb,
1362     float *s, float rcond, int *rank,
1363     float *work, int lwork,
1364     float *rwork, int *iwork, int *info) {
1365   sgelsd_(&m, &n, &nrhs,
1366       a, &lda, b, &ldb,
1367       s, &rcond, rank,
1368       work, &lwork, iwork, info);
1369 }
1370 
1371 template<> void lapackGelsy<c10::complex<double>, double>(
1372     int m, int n, int nrhs,
1373     c10::complex<double> *a, int lda, c10::complex<double> *b, int ldb,
1374     int *jpvt, double rcond, int *rank,
1375     c10::complex<double> *work, int lwork, double *rwork, int *info) {
1376   zgelsy_(&m, &n, &nrhs,
1377       reinterpret_cast<std::complex<double>*>(a), &lda,
1378       reinterpret_cast<std::complex<double>*>(b), &ldb,
1379       jpvt, &rcond, rank,
1380       reinterpret_cast<std::complex<double>*>(work), &lwork,
1381       rwork, info);
1382 }
1383 
1384 template<> void lapackGelsy<c10::complex<float>, float>(
1385     int m, int n, int nrhs,
1386     c10::complex<float> *a, int lda, c10::complex<float> *b, int ldb,
1387     int *jpvt, float rcond, int *rank,
1388     c10::complex<float> *work, int lwork, float *rwork, int *info) {
1389   cgelsy_(&m, &n, &nrhs,
1390       reinterpret_cast<std::complex<float>*>(a), &lda,
1391       reinterpret_cast<std::complex<float>*>(b), &ldb,
1392       jpvt, &rcond, rank,
1393       reinterpret_cast<std::complex<float>*>(work), &lwork,
1394       rwork, info);
1395 }
1396 
1397 template<> void lapackGelsy<double>(
1398     int m, int n, int nrhs,
1399     double *a, int lda, double *b, int ldb,
1400     int *jpvt, double rcond, int *rank,
1401     double *work, int lwork, double *rwork, int *info) {
1402   dgelsy_(&m, &n, &nrhs,
1403       a, &lda, b, &ldb,
1404       jpvt, &rcond, rank,
1405       work, &lwork, info);
1406 }
1407 
1408 template<> void lapackGelsy<float>(
1409     int m, int n, int nrhs,
1410     float *a, int lda, float *b, int ldb,
1411     int *jpvt, float rcond, int *rank,
1412     float *work, int lwork, float *rwork, int *info) {
1413   sgelsy_(&m, &n, &nrhs,
1414       a, &lda, b, &ldb,
1415       jpvt, &rcond, rank,
1416       work, &lwork, info);
1417 }
1418 
1419 template<> void lapackGelss<c10::complex<double>, double>(
1420     int m, int n, int nrhs,
1421     c10::complex<double> *a, int lda, c10::complex<double> *b, int ldb,
1422     double *s, double rcond, int *rank,
1423     c10::complex<double> *work, int lwork,
1424     double *rwork, int *info
1425     ) {
1426   zgelss_(&m, &n, &nrhs,
1427       reinterpret_cast<std::complex<double>*>(a), &lda,
1428       reinterpret_cast<std::complex<double>*>(b), &ldb,
1429       s, &rcond, rank,
1430       reinterpret_cast<std::complex<double>*>(work), &lwork,
1431       rwork, info);
1432 }
1433 
1434 template<> void lapackGelss<c10::complex<float>, float>(
1435     int m, int n, int nrhs,
1436     c10::complex<float> *a, int lda, c10::complex<float> *b, int ldb,
1437     float *s, float rcond, int *rank,
1438     c10::complex<float> *work, int lwork,
1439     float *rwork, int *info
1440     ) {
1441   cgelss_(&m, &n, &nrhs,
1442       reinterpret_cast<std::complex<float>*>(a), &lda,
1443       reinterpret_cast<std::complex<float>*>(b), &ldb,
1444       s, &rcond, rank,
1445       reinterpret_cast<std::complex<float>*>(work), &lwork,
1446       rwork, info);
1447 }
1448 
1449 template<> void lapackGelss<double>(
1450     int m, int n, int nrhs,
1451     double *a, int lda, double *b, int ldb,
1452     double *s, double rcond, int *rank,
1453     double *work, int lwork,
1454     double *rwork, int *info) {
1455   dgelss_(&m, &n, &nrhs,
1456       a, &lda, b, &ldb,
1457       s, &rcond, rank,
1458       work, &lwork, info);
1459 }
1460 
1461 template<> void lapackGelss<float>(
1462     int m, int n, int nrhs,
1463     float *a, int lda, float *b, int ldb,
1464     float *s, float rcond, int *rank,
1465     float *work, int lwork,
1466     float *rwork, int *info) {
1467   sgelss_(&m, &n, &nrhs,
1468       a, &lda, b, &ldb,
1469       s, &rcond, rank,
1470       work, &lwork, info);
1471 }
1472 #endif
1473 
1474 #if AT_BUILD_WITH_BLAS()
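// blasTriangularSolve wraps BLAS ?trsm, which solves op(A) * X = alpha * B or
// X * op(A) = alpha * B in-place in `b` for a triangular `a`; alpha is fixed
// to 1 in these wrappers.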
1475 template<> void blasTriangularSolve<c10::complex<double>>(char side, char uplo, char trans, char diag, int n, int nrhs, c10::complex<double> *a, int lda, c10::complex<double> *b, int ldb) {
1476   std::complex<double> one{1., 0.};
1477   ztrsm_(&side, &uplo, &trans, &diag, &n, &nrhs, &one, reinterpret_cast<std::complex<double>*>(a), &lda, reinterpret_cast<std::complex<double>*>(b), &ldb);
1478 }
1479 
1480 template<> void blasTriangularSolve<c10::complex<float>>(char side, char uplo, char trans, char diag, int n, int nrhs, c10::complex<float> *a, int lda, c10::complex<float> *b, int ldb) {
1481   std::complex<float> one{1.f, 0.f};
1482   ctrsm_(&side, &uplo, &trans, &diag, &n, &nrhs, &one, reinterpret_cast<std::complex<float>*>(a), &lda, reinterpret_cast<std::complex<float>*>(b), &ldb);
1483 }
1484 
1485 template<> void blasTriangularSolve<double>(char side, char uplo, char trans, char diag, int n, int nrhs, double *a, int lda, double *b, int ldb) {
1486   auto one = 1.;
1487   dtrsm_(&side, &uplo, &trans, &diag, &n, &nrhs, &one, a, &lda, b, &ldb);
1488 }
1489 
1490 template<> void blasTriangularSolve<float>(char side, char uplo, char trans, char diag, int n, int nrhs, float *a, int lda, float *b, int ldb) {
1491   auto one = 1.f;
1492   strsm_(&side, &uplo, &trans, &diag, &n, &nrhs, &one, a, &lda, b, &ldb);
1493 }
1494 #endif
1495 
1496 void _linalg_check_errors(
1497     const Tensor& infos,
1498     const c10::string_view api_name,
1499     bool is_matrix) {
1500   TORCH_INTERNAL_ASSERT(infos.scalar_type() == kInt);
1501   TORCH_INTERNAL_ASSERT(infos.is_contiguous());
1502   if (infos.is_meta()) {
1503     return;
1504   }
1505 
1506   // If it's all zeros, we return early.
1507   // We optimise for the most likely case.
1508   if (C10_LIKELY(!infos.any().item<bool>())) {
1509     return;
1510   }
1511 
1512   int32_t info = 0;
1513   std::string batch_str;
1514   if (is_matrix) {
1515     info = infos.item<int>();
1516     // batch_str needn't be set for matrices
1517   } else {
1518     // Find the first non-zero info
1519     auto infos_cpu = infos.to(at::kCPU);
1520     auto ptr = infos_cpu.const_data_ptr<int32_t>();
1521     auto n = infos.numel();
1522     auto info_ptr = std::find_if(ptr, ptr + n, [](int32_t x) { return x != 0; });
1523     info = *info_ptr;
1524     batch_str = ": (Batch element " + std::to_string(std::distance(ptr, info_ptr)) + ")";
1525   }
1526 
1527   if (info < 0) {
1528     // Reference LAPACK 3.10+ changed `info` behavior for inputs with non-finite values
1529     // Previously, it would return `info` > 0, but now it returns `info` = -4
1530     // OpenBLAS 0.3.15+ uses the Reference LAPACK 3.10+.
1531     // MKL 2022.0+ uses the Reference LAPACK 3.10+.
1532     // Older version of MKL and OpenBLAS follow the old behavior (return `info` > 0).
1533     // Here we check for the case where `info` is -4 and raise an error
1534     if (api_name.find("svd") != api_name.npos) {
1535       TORCH_CHECK_LINALG(info != -4, api_name, batch_str,
1536           ": The algorithm failed to converge because the input matrix contained non-finite values.");
1537     }
1538     TORCH_INTERNAL_ASSERT(false, api_name, batch_str,
1539         ": Argument ", -info, " has illegal value. Most certainly there is a bug in the implementation calling the backend library.");
1540   } else if (info > 0) {
1541     if (api_name.find("inv") != api_name.npos) {
1542       // inv, inverse, cholesky_inverse, etc.
1543       TORCH_CHECK_LINALG(false, api_name, batch_str,
1544           ": The diagonal element ", info, " is zero, the inversion could not be completed because the input matrix is singular.");
1545     } else if (api_name.find("solve") != api_name.npos) {
1546       // solve, linalg_solve, cholesky_solve, etc.
1547       TORCH_CHECK_LINALG(false, api_name, batch_str,
1548           ": The solver failed because the input matrix is singular.");
1549     } else if (api_name.find("cholesky") != api_name.npos) {
1550       TORCH_CHECK_LINALG(false, api_name, batch_str,
1551           ": The factorization could not be completed because the input is not positive-definite (the leading minor of order ", info, " is not positive-definite).");
1552     } else if (api_name.find("svd") != api_name.npos) {
1553       TORCH_CHECK_LINALG(false, api_name, batch_str,
1554           ": The algorithm failed to converge because the input matrix is ill-conditioned or has too many repeated singular values (error code: ", info, ").");
1555     } else if (api_name.find("eig") != api_name.npos || api_name.find("syevd") != api_name.npos) {
1556       TORCH_CHECK_LINALG(false, api_name, batch_str,
1557           ": The algorithm failed to converge because the input matrix is ill-conditioned or has too many repeated eigenvalues (error code: ", info, ").");
1558     } else if (api_name.find("lstsq") != api_name.npos) {
1559       TORCH_CHECK_LINALG(false, api_name, batch_str,
1560           ": The least squares solution could not be computed because the input matrix does not have full rank (error code: ", info, ").");
1561     } else if (api_name.find("lu_factor") != api_name.npos) {
1562       TORCH_CHECK(false, api_name, batch_str,
1563           ": U[", info, ",", info, "] is zero and using it on lu_solve would result in a division by zero. "
1564           "If you still want to perform the factorization, consider calling linalg.lu(A, pivot) or "
1565           "linalg.lu_factor_ex(A, pivot)");
1566     } else {
1567       TORCH_INTERNAL_ASSERT(false, api_name, ": Unknown error code: ", info, ".");
1568     }
1569   }
1570   // Unreachable: info was non-zero, so one of the branches above must have thrown
1571   TORCH_INTERNAL_ASSERT(false);
1572 }
1573 
1574 // If an input requires fw or bw grad then we need to go down a different
1575 // (slower) path to ensure that the gradients are computable.
1576 // That is what `_may_require_fw_or_bw_grad` is helpful for.
1577 //
1578 // Why is there a isTensorSubclassLike check here?
1579 // Without it, this function can lead to composite compliance problems, which
1580 // may lead to bugs in functorch, where a Tensor Subclass that doesn't
1581 // require grad may wrap a Tensor subclass that requires grad.
1582 static bool _may_require_fw_or_bw_grad(const Tensor& input) {
1583   return ((at::GradMode::is_enabled() && input.requires_grad())
1584           || input._fw_grad(/*level */ 0).defined()
1585           || isTensorSubclassLike(input));
1586 }
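// A minimal sketch of how a caller might use it (hypothetical helpers, not
// functions defined in this file):
//   return _may_require_fw_or_bw_grad(A) ? slow_differentiable_path(A)
//                                         : fast_library_path(A);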
1587 
1588 // NOLINTBEGIN(cppcoreguidelines-pro-type-const-cast)
1589 
1590 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg.inv ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1591 TORCH_IMPL_FUNC(linalg_inv_ex_out)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) {
1592   // Fill result with the identity
1593   result.zero_();
1594   result.diagonal(0, -2, -1).fill_(1.);
1595   at::linalg_solve_ex_out(const_cast<Tensor&>(result), const_cast<Tensor&>(info), A, result, /*left*/true);
1596   if (check_errors) {
1597     at::_linalg_check_errors(info, "linalg.inv_ex", A.dim() == 2);
1598   }
1599 }
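// In other words, linalg.inv_ex obtains the inverse by solving A * X = I.
// A rough Python sketch of the same computation (illustrative only):
//   I = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device)
//   X, info = torch.linalg.solve_ex(A, I)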
1600 
1601 Tensor& linalg_inv_out(const Tensor& A, Tensor& result) {
1602   auto info = at::empty({0}, A.options().dtype(kInt));
1603   at::linalg_inv_ex_out(result, info, A);
1604   at::_linalg_check_errors(info, "linalg.inv", A.dim() == 2);
1605   return result;
1606 }
1607 
1608 Tensor linalg_inv(const Tensor& A) {
1609   auto [result, info] = at::linalg_inv_ex(A);
1610   at::_linalg_check_errors(info, "linalg.inv", A.dim() == 2);
1611   return result;
1612 }
1613 
1614 Tensor& inverse_out(const Tensor& A, Tensor& result) {
1615   return at::linalg_inv_out(result, A);
1616 }
1617 
1618 Tensor inverse(const Tensor& A) {
1619   return at::linalg_inv(A);
1620 }
1621 
1622 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ cholesky_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1623 
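// Batched driver around lapackCholeskySolve (LAPACK ?potrs): for each matrix
// in the batch it solves A * X = B in-place in `b`, where `A` already holds
// the Cholesky factor (upper or lower, per `upper`), and records the
// per-matrix LAPACK status in `infos`.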
1624 template<typename scalar_t>
1625 static void apply_cholesky_solve(Tensor& b, Tensor& A, bool upper, Tensor& infos) {
1626 #if !AT_BUILD_WITH_LAPACK()
1627   AT_ERROR("cholesky_solve: LAPACK library not found in compilation");
1628 #else
1629   char uplo = upper ? 'U' : 'L';
1630 
1631   auto A_data = A.const_data_ptr<scalar_t>();
1632   auto b_data = b.data_ptr<scalar_t>();
1633   auto infos_data = infos.data_ptr<int>();
1634   auto A_mat_stride = matrixStride(A);
1635   auto b_mat_stride = matrixStride(b);
1636   auto batch_size = batchCount(A);
1637   auto n = A.size(-2);
1638   auto ldab = std::max<int64_t>(1, n);
1639   auto nrhs = b.size(-1);
1640 
1641   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
1642   int info;
1643   for (const auto i : c10::irange(batch_size)) {
1644     const scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
1645     scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
1646     lapackCholeskySolve<scalar_t>(uplo, n, nrhs, const_cast<scalar_t*>(A_working_ptr), ldab, b_working_ptr, ldab, &info);
1647     infos_data[i] = info;
1648     if (info != 0) {
1649       return;
1650     }
1651   }
1652 #endif
1653 }
1654 
1655 Tensor _cholesky_solve_helper_cpu(const Tensor& self, const Tensor& A, bool upper) {
1656   auto self_working_copy = cloneBatchedColumnMajor(self);
1657   auto A_working_copy = cloneBatchedColumnMajor(A);
1658   auto infos = at::zeros({batchCount(self)}, self.options().dtype(kInt));
1659   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "cholesky_solve_cpu", [&]{
1660     apply_cholesky_solve<scalar_t>(self_working_copy, A_working_copy, upper, infos);
1661   });
1662 
1663   at::_linalg_check_errors(infos, "cholesky_solve_cpu", self.dim() == 2);
1664   return self_working_copy;
1665 }
1666 
1667 // Supports arbitrary batch dimensions for self and A
1668 Tensor cholesky_solve(const Tensor& self, const Tensor& A, bool upper) {
1669   TORCH_CHECK(self.dim() >= 2,
1670            "b should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
1671   TORCH_CHECK(A.dim() >= 2,
1672            "u should have at least 2 dimensions, but has ", A.dim(), " dimensions instead");
1673   auto [self_broadcasted, A_broadcasted] = _linalg_broadcast_batch_dims(self, A, "cholesky_solve");
1674   return at::_cholesky_solve_helper(self_broadcasted, A_broadcasted, upper);
1675 }
1676 
1677 Tensor& cholesky_solve_out(const Tensor& self, const Tensor& A, bool upper, Tensor& result) {
1678   checkSameDevice("cholesky_solve", result, self);
1679   checkLinalgCompatibleDtype("cholesky_solve", result, self);
1680   Tensor result_tmp = at::cholesky_solve(self, A, upper);
1681   at::native::resize_output(result, result_tmp.sizes());
1682   result.copy_(result_tmp);
1683   return result;
1684 }
1685 
1686 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ cholesky ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1687 
1688 DEFINE_DISPATCH(cholesky_stub);
1689 
1690 Tensor cholesky(const Tensor &self, bool upper) {
1691    TORCH_WARN_ONCE(
1692     "torch.cholesky is deprecated in favor of torch.linalg.cholesky and will be ",
1693     "removed in a future PyTorch release.\n",
1694     "L = torch.cholesky(A)\n",
1695     "should be replaced with\n",
1696     "L = torch.linalg.cholesky(A)\n",
1697     "and\n"
1698     "U = torch.cholesky(A, upper=True)\n",
1699     "should be replaced with\n",
1700     "U = torch.linalg.cholesky(A).mH\n"
1701     "This transform will produce equivalent results for all valid (symmetric positive definite) inputs."
1702   );
1703   if (self.numel() == 0) {
1704     return at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
1705   }
1706   squareCheckInputs(self, "cholesky");
1707 
1708   auto raw_cholesky_output = cloneBatchedColumnMajor(self);
1709   auto info_shape = IntArrayRef(
1710       self.sizes().cbegin(), self.sizes().cend() - 2); // self.shape[:-2]
1711   auto info = at::empty({info_shape}, self.options().dtype(kInt));
1712 
1713   // fill the raw_cholesky_output with the result
1714   cholesky_stub(self.device().type(), raw_cholesky_output, info, upper);
1715 
1716   at::_linalg_check_errors(info, "cholesky", self.dim() == 2);
1717 
1718   if (upper) {
1719     return raw_cholesky_output.triu_();
1720   } else {
1721     return raw_cholesky_output.tril_();
1722   }
1723 }
1724 
1725 Tensor& cholesky_out(const Tensor &self, bool upper, Tensor &result) {
1726    TORCH_WARN_ONCE(
1727     "torch.cholesky is deprecated in favor of torch.linalg.cholesky and will be ",
1728     "removed in a future PyTorch release.\n",
1729     "L = torch.cholesky(A)\n",
1730     "should be replaced with\n",
1731     "L = torch.linalg.cholesky(A)\n",
1732     "and\n"
1733     "U = torch.cholesky(A, upper=True)\n",
1734     "should be replaced with\n",
1735     "U = torch.linalg.cholesky(A).mH\n"
1736     "This transform will produce equivalent results for all valid (symmetric positive definite) inputs."
1737   );
1738   checkSameDevice("cholesky", result, self);
1739   checkLinalgCompatibleDtype("cholesky", result, self);
1740   Tensor result_tmp = at::cholesky(self, upper);
1741   at::native::resize_output(result, result_tmp.sizes());
1742   result.copy_(result_tmp);
1743   return result;
1744 }
1745 
1746 TORCH_IMPL_FUNC(linalg_cholesky_ex_out)(const Tensor& A,
1747                                         bool upper,
1748                                         bool check_errors,
1749                                         const Tensor& L,
1750                                         const Tensor& info) {
1751   // Nothing to do here
1752   if (L.numel() == 0) {
1753     info.zero_();
1754     return;
1755   }
1756   const auto cpu = A.device() == kCPU;
1757 
1758   // We can only perform this optimisation on CPU, as it fails for MAGMA
1759   // due to a bug
1760   if (cpu) {
1761     if (upper) {
1762       at::triu_out(const_cast<Tensor&>(L), A);
1763     } else {
1764       at::tril_out(const_cast<Tensor&>(L), A);
1765     }
1766   } else {
1767     L.copy_(A);
1768   }
1769 
1770   cholesky_stub(L.device().type(), L, info, upper);
1771 
1772   if (!cpu) {
1773     if (upper) {
1774       L.triu_();
1775     } else {
1776       L.tril_();
1777     }
1778   }
1779 
1780   if (check_errors) {
1781     at::_linalg_check_errors(info, "linalg.cholesky_ex", A.dim() == 2);
1782   }
1783 }
1784 
1785 Tensor linalg_cholesky(const Tensor& A, bool upper) {
1786   auto [L, info] = at::linalg_cholesky_ex(A, upper, /*check_errors=*/false);
1787   at::_linalg_check_errors(info, "linalg.cholesky", A.dim() == 2);
1788   return L;
1789 }
1790 
1791 Tensor& linalg_cholesky_out(const Tensor& A, bool upper, Tensor& L) {
1792   auto info = at::empty({0}, A.options().dtype(kInt));
1793   at::linalg_cholesky_ex_out(L, info, A, upper, /*check_errors=*/false);
1794   at::_linalg_check_errors(info, "linalg.cholesky", A.dim() == 2);
1795   return L;
1796 }
1797 
1798 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ cholesky_inverse ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1799 
1800 DEFINE_DISPATCH(cholesky_inverse_stub);
1801 
1802 static Tensor& cholesky_inverse_out_info(Tensor& result, Tensor& infos, const Tensor& input, bool upper) {
1803   TORCH_INTERNAL_ASSERT(input.dim() >= 2);
1804   TORCH_INTERNAL_ASSERT(input.size(-1) == input.size(-2));
1805 
1806   TORCH_INTERNAL_ASSERT(result.scalar_type() == input.scalar_type());
1807   TORCH_INTERNAL_ASSERT(result.device() == input.device());
1808 
1809   TORCH_INTERNAL_ASSERT(infos.scalar_type() == at::kInt);
1810   TORCH_INTERNAL_ASSERT(infos.device() == at::kCPU);
1811   TORCH_INTERNAL_ASSERT(infos.numel() == std::max<int64_t>(1, batchCount(input)));
1812 
1813   // if result has no elements we can modify it
1814   if (result.numel() == 0) {
1815     at::native::resize_as_(result, input.mT(), MemoryFormat::Contiguous);
1816     result.transpose_(-2, -1);
1817   }
1818 
1819   // result tensor must be in batched column major order (Fortran contiguous)
1820   TORCH_INTERNAL_ASSERT(result.mT().is_contiguous());
1821   TORCH_INTERNAL_ASSERT(result.sizes().equals(input.sizes()));
1822 
1823   // cholesky_inverse_stub (apply_cholesky_inverse) performs calculations in-place and result must be a copy of input
1824   result.copy_(input);
1825 
1826   // infos must be contiguous
1827   TORCH_INTERNAL_ASSERT(infos.is_contiguous());
1828   infos.fill_(0);
1829 
1830   result = cholesky_inverse_stub(result.device().type(), result, infos, upper);
1831   return result;
1832 }
1833 
1834 Tensor& cholesky_inverse_out(const Tensor &input, bool upper, Tensor &result) {
1835   squareCheckInputs(input, "cholesky_inverse");
1836   checkSameDevice("cholesky_inverse", result, input);
1837   checkLinalgCompatibleDtype("cholesky_inverse", result, input);
1838 
1839   // MAGMA requires 'infos' to reside in CPU memory, therefore we create 'infos' only on CPU for now.
1840   auto infos = at::zeros({std::max<int64_t>(1, batchCount(input))}, input.options().dtype(kInt).device(kCPU));
1841 
1842   bool result_input_same_type = (result.scalar_type() == input.scalar_type());
1843   bool result_equal_expected_shape = result.sizes().equals(input.sizes());
1844   bool is_batched_column_major = false;
1845   if (result.dim() >= 2) {
1846     is_batched_column_major = result.mT().is_contiguous();
1847   }
1848 
1849   // if result is not empty and not in batched column major format
1850   bool copy_needed = (result.numel() != 0 && !is_batched_column_major);
1851   copy_needed |= !result_input_same_type;  // or result does not have the same dtype as input
1852   copy_needed |= (result.numel() != 0 && !result_equal_expected_shape); // or result does not have the expected shape
1853   // we have to allocate a temporary tensor
1854   if (copy_needed) {
1855     Tensor result_tmp = at::empty({0}, input.options());
1856     result_tmp = cholesky_inverse_out_info(result_tmp, infos, input, upper);
1857     at::native::resize_output(result, result_tmp.sizes());
1858     result.copy_(result_tmp);
1859   } else {
1860     // use result's memory directly
1861     result = cholesky_inverse_out_info(result, infos, input, upper);
1862   }
1863 
1864   // Now check LAPACK/MAGMA error codes
1865   at::_linalg_check_errors(infos, "cholesky_inverse", result.dim() == 2);
1866   return result;
1867 }
1868 
1869 Tensor cholesky_inverse(const Tensor &input, bool upper) {
1870   Tensor result = at::empty({0}, input.options());
1871   result = at::cholesky_inverse_out(result, input, upper);
1872   return result;
1873 }
1874 
1875 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg.solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1876 
1877 // Auxiliary function that also returns the LU decomposition so it can be reused in the backward pass
1878 TORCH_IMPL_FUNC(_linalg_solve_ex_out)(const Tensor& A,
1879                                       const Tensor& B,
1880                                       bool left,
1881                                       bool check_errors,
1882                                       const Tensor& result,
1883                                       const Tensor& LU,
1884                                       const Tensor& pivots,
1885                                       const Tensor& info) {
1886   // Possible optimization: Compute the LU factorization of A^T if A is contiguous
1887   // Then we solve A^T X = B with adjoint=True
1888   // This saves a copy as A doesn't need to be copied into an F-contig matrix in lu_factor
1889   // This optimization makes functorch's batching rule difficult. See NOTE [ solve_ex Batch Rule Contiguity ]
1890   const bool use_A_T = A.is_contiguous() && !A.is_complex();
1891   at::linalg_lu_factor_ex_out(const_cast<Tensor&>(LU),
1892                               const_cast<Tensor&>(pivots),
1893                               const_cast<Tensor&>(info),
1894                               use_A_T ? A.mT() : A);
1895   if (check_errors) {
1896     at::_linalg_check_errors(info, "torch.linalg.solve_ex", A.dim() == 2);
1897   }
1898 
1899   // [numpy-compat] Handle vectors on the rhs
1900   const bool vector_case = at::native::linalg_solve_is_vector_rhs(LU, B);
1901   auto result_ = vector_case ? result.unsqueeze(-1) : result;
1902   auto B_ = vector_case ? B.unsqueeze(-1) : B;
1903   at::linalg_lu_solve_out(result_, LU, pivots, B_, left, /*adjoint*/use_A_T);
1904 }
1905 
1906 std::tuple<Tensor&, Tensor&> linalg_solve_ex_out(const Tensor& A,
1907                                                  const Tensor& B,
1908                                                  bool left,
1909                                                  bool check_errors,
1910                                                  Tensor& result,
1911                                                  Tensor& info) {
1912   auto LU = B.new_empty({0});
1913   auto pivots = B.new_empty({0}, kInt);
1914   at::_linalg_solve_ex_out(result, LU, pivots, info, A, B, left, check_errors);
1915   return std::tie(result, info);
1916 }
1917 
1918 // We implement linalg_solve_ex as a composite function of _linalg_solve_ex
1919 std::tuple<Tensor, Tensor> linalg_solve_ex(const Tensor& A,
1920                                            const Tensor& B,
1921                                            bool left,
1922                                            bool check_errors) {
1923   auto [result, LU, pivots, info] = at::_linalg_solve_ex(A, B, left, check_errors);
1924   return std::make_tuple(std::move(result), std::move(info));
1925 }
1926 
1927 Tensor& linalg_solve_out(const Tensor& A,
1928                          const Tensor& B,
1929                          bool left,
1930                          Tensor& result) {
1931   auto info = B.new_empty({0}, kInt);
1932   at::linalg_solve_ex_out(result, info, A, B, left);
1933   at::_linalg_check_errors(info, "torch.linalg.solve", A.dim() == 2);
1934   return result;
1935 }
1936 
1937 Tensor linalg_solve(const Tensor& A,
1938                     const Tensor& B,
1939                     bool left) {
1940   if (A.layout() == kSparseCsr) {
1941     return at::_spsolve(A, B, left);
1942   }
1943   auto [result, info] = at::linalg_solve_ex(A, B, left);
1944   at::_linalg_check_errors(info, "torch.linalg.solve", A.dim() == 2);
1945   return result;
1946 }
1947 
1948 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu_factor ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1949 
1950 DEFINE_DISPATCH(lu_factor_stub);
1951 
1952 TORCH_IMPL_FUNC(linalg_lu_factor_ex_out)(const Tensor& A,
1953                                          bool pivot,
1954                                          bool check_errors,
1955                                          const Tensor& LU,
1956                                          const Tensor& pivots,
1957                                          const Tensor& info) {
1958   if (A.numel() == 0) {
1959     // zero out the infos as it will have one element if the input is a matrix of size (0, 0)
1960     info.zero_();
1961     return;
1962   }
1963   if (!LU.is_same(A)) {
1964     LU.copy_(A);
1965   }
1966 
1967   lu_factor_stub(A.device().type(), LU, pivots, info, pivot);
1968 
1969   if (check_errors) {
1970     at::_linalg_check_errors(info, "torch.linalg.lu_factor_ex", A.dim() == 2);
1971   }
1972 }
1973 
1974 std::tuple<Tensor&, Tensor&> linalg_lu_factor_out(const Tensor& A, bool pivot, Tensor& LU, Tensor& pivots) {
1975   auto info = at::empty({0}, A.options().dtype(kInt));
1976   // We pass check_errors=false so that the error check below reports lu_factor rather than lu_factor_ex
1977   at::linalg_lu_factor_ex_out(LU, pivots, info, A, pivot, /*check_errors=*/false);
1978   at::_linalg_check_errors(info, "torch.linalg.lu_factor", A.dim() == 2);
1979   return std::tie(LU, pivots);
1980 }
1981 
1982 std::tuple<Tensor, Tensor> linalg_lu_factor(const Tensor& A, bool pivot) {
1983   auto [LU, pivots, info] = at::linalg_lu_factor_ex(A, pivot, /*check_errors=*/false);
1984   at::_linalg_check_errors(info, "torch.linalg.lu_factor", A.dim() == 2);
1985   return std::make_tuple(std::move(LU), std::move(pivots));
1986 }
1987 
1988 // TODO Deprecate this function in favour of linalg_lu_factor_ex
1989 std::tuple<Tensor, Tensor, Tensor> _lu_with_info(const Tensor& self, bool compute_pivots, bool) {
1990    TORCH_WARN_ONCE(
1991     "torch.lu is deprecated in favor of torch.linalg.lu_factor / torch.linalg.lu_factor_ex and will be ",
1992     "removed in a future PyTorch release.\n",
1993     "LU, pivots = torch.lu(A, compute_pivots)\n",
1994     "should be replaced with\n",
1995     "LU, pivots = torch.linalg.lu_factor(A, compute_pivots)\n",
1996     "and\n",
1997     "LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)\n",
1998     "should be replaced with\n",
1999     "LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)"
2000   );
2001   return at::linalg_lu_factor_ex(self, compute_pivots, false);
2002 }
2003 
2004 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2005 
2006 DEFINE_DISPATCH(unpack_pivots_stub);
2007 
2008 TORCH_IMPL_FUNC(linalg_lu_out)(const Tensor& A,
2009                                bool pivot,
2010                                const Tensor& P,
2011                                const Tensor& L,
2012                                const Tensor& U) {
2013   const auto m = A.sizes().end()[-2];
2014   const auto n = A.sizes().end()[-1];
2015 
2016   // A.shape[-2:] == (m, n)
2017   // P.shape[-2:] == (m, m)
2018   // L.shape[-2:] == (m, k)
2019   // U.shape[-2:] == (k, n)
2020   // with k = min(m, n)
2021 
2022   // Factor in-place into whichever of L and U has the packed LU shape (m, n): L when m > n, U otherwise
2023   const bool use_L = m > n;
2024   auto pivots = at::empty({0}, A.options().dtype(kInt));
2025   auto info = at::empty({0}, A.options().dtype(kInt));
2026   at::linalg_lu_factor_ex_out(const_cast<Tensor&>(use_L ? L : U),
2027                               const_cast<Tensor&>(pivots),
2028                               const_cast<Tensor&>(info),
2029                               A,
2030                               pivot,
2031                               /*check_errors=*/false);
2032   at::lu_unpack_out(const_cast<Tensor&>(P),
2033                     const_cast<Tensor&>(L),
2034                     const_cast<Tensor&>(U),
2035                     use_L ? L : U,
2036                     pivots,
2037                     /*unpack_data=*/true,
2038                     /*unpack_pivots=*/pivot);
2039 }
2040 
2041 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu_unpack ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2042 
2043 TORCH_IMPL_FUNC(lu_unpack_out)(const Tensor& LU,
2044                                const Tensor& pivots,
2045                                bool unpack_lu,
2046                                bool unpack_pivots,
2047                                const Tensor& P,
2048                                const Tensor& L,
2049                                const Tensor& U) {
2050   const auto m = LU.sizes().end()[-2];
2051   const auto n = LU.sizes().end()[-1];
2052 
2053   // A.shape[-2:] == (m, n)
2054   // P.shape[-2:] == (m, m)
2055   // L.shape[-2:] == (m, k)
2056   // U.shape[-2:] == (k, n)
2057   // with k = min(m, n)
2058 
2059   if (unpack_lu) {
2060     if (m > n || LU.is_same(L)) {
2061       // The order of triu and tril is important as we may have LU.is_same(L)
2062       at::triu_out(const_cast<Tensor&>(U), m == n ? LU : LU.narrow(-2, 0, n), 0);
2063       at::tril_out(const_cast<Tensor&>(L), LU, -1);
2064       L.diagonal(0, -2, -1).fill_(1.);
2065     } else {
2066       // The order of triu and tril is important as we may have LU.is_same(U)
2067       at::tril_out(const_cast<Tensor&>(L), m == n ? LU : LU.narrow(-1, 0, m), -1);
2068       L.diagonal(0, -2, -1).fill_(1.);
2069       at::triu_out(const_cast<Tensor&>(U), LU, 0);
2070     }
2071   }
2072   if (unpack_pivots) {
2073   // lu_factor_ex returns int32 1-based pivot indices, which is what we have in `pivots`
2074     // We transform that to a proper permutation of the indices {0, ..., m-1}
2075     const auto perm_sizes = IntArrayRef(P.sizes().data(), P.dim() - 1);
2076 
2077     // Fill `perm` with the identity permutation (perhaps batched)
2078     const auto perm = at::arange(m, pivots.options().memory_format(at::MemoryFormat::Contiguous).dtype(kLong))
2079                         .expand(perm_sizes)
2080                         .contiguous();
2081 
2082     // Note that perm is of type kLong and pivots is a 1-indexed kInt.
2083     // This is taken into account in the unpack_pivots kernel
2084     auto iter = TensorIteratorConfig()
2085       .set_check_mem_overlap(false)
2086       .check_all_same_dtype(false)
2087       .resize_outputs(false)
2088       .declare_static_shape(pivots.sizes(), /*squash_dims=*/pivots.dim() - 1)
2089       .add_output(perm)
2090       .add_owned_const_input(pivots.contiguous())
2091       .build();
2092 
2093     unpack_pivots_stub(pivots.device().type(), iter, std::min(m, n), m);
2094 
2095     // Transform the permutation into a permutation matrix
2096     P.zero_();
2097     P.scatter_(-2, perm.unsqueeze(-2), 1.);
2098   }
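  // Illustrative example (hypothetical values): for m = 3 and 1-based pivots
  // {2, 3, 3}, the kernel starts from the identity permutation {0, 1, 2} and
  // applies the recorded interchanges in order:
  //   i = 0: swap positions 0 and 1 -> {1, 0, 2}
  //   i = 1: swap positions 1 and 2 -> {1, 2, 0}
  //   i = 2: swap positions 2 and 2 -> {1, 2, 0}
  // The scatter_ above then sets P[perm[j], j] = 1 for every column j.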
2099 }
2100 
2101 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_lu_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2102 DEFINE_DISPATCH(lu_solve_stub);
2103 
2104 TORCH_IMPL_FUNC(linalg_lu_solve_out)(const Tensor& LU,
2105                                      const Tensor& pivots,
2106                                      const Tensor& B,
2107                                      bool left,
2108                                      bool adjoint,
2109                                      const Tensor& result) {
2110   // Trivial case
2111   if (result.numel() == 0) {
2112     return;
2113   }
2114 
2115   // Solve A^H X = B^H. Then we return X^H
2116   if (!left) {
2117     adjoint = !adjoint;
2118     result.transpose_(-2, -1);
2119   }
2120 
2121   // Copy B (or B^H) into result
2122   if (!result.is_same(B)) {
2123     result.copy_(left ? B : B.mH());
2124   }
2125 
2126   // Make LU / pivots F-contiguous
2127   auto pivots_ = pivots.expect_contiguous();
2128   auto LU_ = at::native::borrow_else_clone(
2129       LU.mT().is_contiguous(), LU, LU, /*contig=*/false);
2130 
2131   const auto trans = !adjoint ? TransposeType::NoTranspose :
2132                      LU.is_complex() ? TransposeType::ConjTranspose
2133                                      : TransposeType::Transpose;
2134 
2135   lu_solve_stub(LU_->device().type(), *LU_, *pivots_, result, trans);
2136 
2137   // Conj-transpose back in-place
2138   if (!left) {
2139     result.transpose_(-2, -1);
2140     if (result.is_complex()) {
2141       result._set_conj(!result.is_conj());
2142     }
2143   }
2144 }
2145 
2146 Tensor lu_solve(const Tensor& self, const Tensor& LU_data, const Tensor& LU_pivots) {
2147   TORCH_WARN_ONCE(
2148     "torch.lu_solve is deprecated in favor of torch.linalg.lu_solve",
2149     "and will be removed in a future PyTorch release.\n",
2150     "Note that torch.linalg.lu_solve has its arguments reversed.\n",
2151     "X = torch.lu_solve(B, LU, pivots)\n",
2152     "should be replaced with\n",
2153     "X = torch.linalg.lu_solve(LU, pivots, B)"
2154   );
2155   return at::linalg_lu_solve(LU_data, LU_pivots, self);
2156 }
2157 
2158 Tensor& lu_solve_out(const Tensor& self, const Tensor& LU_data, const Tensor& LU_pivots, Tensor& result) {
2159   TORCH_WARN_ONCE(
2160     "torch.lu_solve is deprecated in favor of torch.linalg.lu_solve",
2161     "and will be removed in a future PyTorch release.\n",
2162     "Note that torch.linalg.lu_solve has its arguments reversed.\n",
2163     "X = torch.lu_solve(B, LU, pivots)\n",
2164     "should be replaced with\n",
2165     "X = torch.linalg.lu_solve(LU, pivots, B)"
2166   );
2167   return at::linalg_lu_solve_out(result, LU_data, LU_pivots, self);
2168 }
2169 
2170 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2171 
2172 DEFINE_DISPATCH(triangular_solve_stub);
2173 
2174 /*
2175 Solves the matrix equation 'input' @ 'result' = 'other' for the 'result'.
2176 The result of the computation is saved in-place in 'result' tensor,
2177 'clone_input' will be a copy of 'input',
2178 error checking with user-facing messages is expected to be done by the caller,
2179 'upper' controls the portion of input matrix to consider in computations,
2180 'transpose' if true then 'input.mT()' @ 'result' = 'other' is solved,
2181 'unitriangular' if true then the diagonal elements of 'input' are assumed to be 1
2182 and the actual diagonal values are not used.
2183 */
2184 static void triangular_solve_out_impl(
2185     const Tensor& result,
2186     const Tensor& clone_input,
2187     const Tensor& input,
2188     const Tensor& other,
2189     bool upper, bool transpose, bool unitriangular) {
2190   TORCH_WARN_ONCE(
2191     "torch.triangular_solve is deprecated in favor of torch.linalg.solve_triangular",
2192     "and will be removed in a future PyTorch release.\n",
2193     "torch.linalg.solve_triangular has its arguments reversed and does not return a copy of one of the inputs.\n",
2194     "X = torch.triangular_solve(B, A).solution\n",
2195     "should be replaced with\n",
2196     "X = torch.linalg.solve_triangular(A, B).");
2197   // These internal asserts make explicit the assumptions in the implementation
2198   // Error check with the actual error messages are done on the higher level of
2199   // the hierarchy of calls
2200   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.dim() >= 2);
2201   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.size(-2) == input.size(-1));
2202 
2203   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.device() == other.device());
2204   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.device() == result.device());
2205   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.device() == clone_input.device());
2206 
2207   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.scalar_type() == other.scalar_type());
2208   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.scalar_type() == result.scalar_type());
2209   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.scalar_type() == clone_input.scalar_type());
2210 
2211   // if 'result' has no elements we can modify it
2212   if (result.numel() == 0) {
2213     result.resize_(other.mT().sizes(), MemoryFormat::Contiguous);
2214     result.transpose_(-2, -1);  // make 'result' have a Fortran-contiguous memory layout
2215   }
2216 
2217   // if 'clone_input' has no elements we can modify it
2218   if (clone_input.numel() == 0) {
2219     clone_input.resize_(input.mT().sizes(), MemoryFormat::Contiguous);
2220     clone_input.transpose_(-2, -1);  // make 'clone_input' have a Fortran-contiguous memory layout
2221   }
2222 
2223   // 'result' and 'clone_input' must be in batched column major order (Fortran contiguous)
2224   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.mT().is_contiguous());
2225   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(clone_input.mT().is_contiguous());
2226 
2227   // triangular_solve_stub performs calculations in-place
2228   // 'result' must be a copy of 'other'
2229   // 'clone_input' must be a copy of 'input'
2230   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.sizes().equals(other.sizes()));
2231   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(clone_input.sizes().equals(input.sizes()));
2232   result.copy_(other);
2233   clone_input.copy_(input);
2234 
2235   triangular_solve_stub(input.device().type(), clone_input, result, /*left=*/true, upper, transpose ? TransposeType::Transpose : TransposeType::NoTranspose, unitriangular);
2236 }
2237 
2238 TORCH_IMPL_FUNC(triangular_solve_out)(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular, const Tensor& result, const Tensor& clone_A) {
2239   auto [self_broadcast, A_broadcast] = _linalg_broadcast_batch_dims(self, A, "triangular_solve");
2240 
2241   bool copy_needed = !result.transpose(-2, -1).is_contiguous();
2242   copy_needed |= !clone_A.transpose(-2, -1).is_contiguous();
2243 
2244   if (copy_needed) {
2245     Tensor result_tmp = at::empty({0}, self.options());
2246     Tensor clone_A_tmp = at::empty({0}, A.options());
2247 
2248     triangular_solve_out_impl(result_tmp, clone_A_tmp, A_broadcast, self_broadcast, upper, transpose, unitriangular);
2249 
2250     result.copy_(result_tmp);
2251     clone_A.copy_(clone_A_tmp);
2252   } else {
2253     triangular_solve_out_impl(result, clone_A, A_broadcast, self_broadcast, upper, transpose, unitriangular);
2254   }
2255 }
2256 
2257 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ qr ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2258 
2259 DEFINE_DISPATCH(geqrf_stub);
2260 
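// geqrf computes an implicit QR factorization: on return the upper triangle
// of `QR` holds R, the Householder reflectors defining Q are stored below the
// diagonal, and their scalar factors live in `tau` (Q can be materialized
// later via orgqr / linalg_householder_product).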
2261 static void geqrf_out_helper(const Tensor& input, const Tensor& QR, const Tensor& tau) {
2262   TORCH_INTERNAL_ASSERT(input.dim() >= 2);
2263 
2264   TORCH_INTERNAL_ASSERT(input.scalar_type() == QR.scalar_type());
2265   TORCH_INTERNAL_ASSERT(input.device() == QR.device());
2266 
2267   TORCH_INTERNAL_ASSERT(input.scalar_type() == tau.scalar_type());
2268   TORCH_INTERNAL_ASSERT(input.device() == tau.device());
2269 
2270   // if 'QR' has no elements we can modify it
2271   if (QR.numel() == 0) {
2272     QR.resize_as_(input.mT(), MemoryFormat::Contiguous);
2273     QR.transpose_(-2, -1); // make Fortran-contiguous
2274   }
2275 
2276   auto expected_batch_tau_shape = IntArrayRef(input.sizes().data(), input.dim() - 2).vec(); // input.shape[:-2]
2277   expected_batch_tau_shape.push_back(std::min(input.size(-2), input.size(-1)));
2278   if (tau.numel() == 0) {
2279     tau.resize_(expected_batch_tau_shape);
2280   }
2281 
2282   // QR tensor must be in batched column major order (Fortran contiguous)
2283   TORCH_INTERNAL_ASSERT(QR.mT().is_contiguous());
2284   TORCH_INTERNAL_ASSERT(QR.sizes().equals(input.sizes()));
2285 
2286   // tau tensor must be contiguous
2287   TORCH_INTERNAL_ASSERT(tau.is_contiguous());
2288   TORCH_INTERNAL_ASSERT(tau.sizes().equals(expected_batch_tau_shape));
2289 
2290   // geqrf_stub (apply_geqrf) performs calculations in-place and 'QR' must be a copy of input
2291   QR.copy_(input);
2292   geqrf_stub(input.device().type(), QR, tau);
2293 }
2294 
2295 std::tuple<Tensor&, Tensor&> geqrf_out(const Tensor& input, Tensor& QR, Tensor& tau) {
2296   TORCH_CHECK(input.dim() >= 2, "torch.geqrf: input must have at least 2 dimensions.");
2297 
2298   checkSameDevice("torch.geqrf", QR, input, "a"); // 'a' is used in documentation and native_functions.yml
2299   checkSameDevice("torch.geqrf", tau, input, "tau");
2300   checkLinalgCompatibleDtype("torch.geqrf", QR, input, "a");
2301   checkLinalgCompatibleDtype("torch.geqrf", tau, input, "tau");
2302 
2303   bool QR_input_same_type = (QR.scalar_type() == input.scalar_type());
2304   bool tau_input_same_type = (tau.scalar_type() == input.scalar_type());
2305   bool QR_equal_expected_shape = QR.sizes().equals(input.sizes());
2306 
2307   auto expected_batch_tau_shape = IntArrayRef(input.sizes().data(), input.dim() - 2).vec(); // input.shape[:-2]
2308   expected_batch_tau_shape.push_back(std::min(input.size(-2), input.size(-1)));
2309   bool tau_equal_expected_shape = tau.sizes().equals(expected_batch_tau_shape);
2310 
2311   bool is_batched_column_major = false;
2312   if (QR.dim() >= 2) {
2313     is_batched_column_major = QR.mT().is_contiguous();
2314   }
2315 
2316   // if 'QR' is not empty and not in batched column major format
2317   bool copy_needed = (QR.numel() != 0 && !is_batched_column_major);
2318   copy_needed |= (QR.numel() != 0 && !QR_equal_expected_shape); // or 'QR' does not have the expected shape
2319   copy_needed |= !QR_input_same_type;  // or 'QR' does not have the same dtype as input
2320   // we have to allocate a temporary tensor
2321 
2322   copy_needed |= (tau.numel() != 0 && !tau.is_contiguous());
2323   copy_needed |= (tau.numel() != 0 && !tau_equal_expected_shape); // or 'tau' does not have the expected shape
2324   copy_needed |= !tau_input_same_type;  // or 'tau' does not have the same dtype as input
2325 
2326   if (copy_needed) {
2327     Tensor QR_tmp = at::empty({0}, input.options());
2328     Tensor tau_tmp = at::empty({0}, input.options());
2329 
2330     geqrf_out_helper(input, QR_tmp, tau_tmp);
2331 
2332     at::native::resize_output(QR, QR_tmp.sizes());
2333     QR.copy_(QR_tmp);
2334     at::native::resize_output(tau, tau_tmp.sizes());
2335     tau.copy_(tau_tmp);
2336   } else {
2337     // use "out" tensors' storage directly
2338     geqrf_out_helper(input, QR, tau);
2339   }
2340 
2341   return std::tuple<Tensor&, Tensor&>(QR, tau);
2342 }
2343 
2344 std::tuple<Tensor, Tensor> geqrf(const Tensor& input) {
2345   Tensor QR = at::empty({0}, input.options());
2346   Tensor tau = at::empty({0}, input.options());
2347   std::tie(QR, tau) = at::geqrf_outf(input, QR, tau);
2348   return std::make_tuple(std::move(QR), std::move(tau));
2349 }
2350 
2351 /*
2352   Computes the QR decomposition using GEQRF and ORGQR operations.
2353   This is an in-place function and Q, R tensors must have correct shape and be Fortran contiguous.
2354 
2355   Args:
2356   * `input` - [in] Input tensor for QR decomposition
2357   * `Q` - [out] Tensor containing the Q matrices of QR decomposition
2358   * `R` - [out] Tensor containing the R matrices of QR decomposition
2359   * `compute_q` - controls whether the Q tensor is computed
2360   * `reduced_mode` - controls the size of Q and R tensors
2361 
2362   For further details, please see the LAPACK documentation for GEQRF and ORGQR.
2363 */
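// A minimal sketch (illustrative, not used by the kernel below): the same
// GEQRF -> ORGQR composition expressed through the public ATen API, assuming a
// tall input A (m >= n) so the reduced factors are well defined.
//
//   std::tuple<at::Tensor, at::Tensor> reduced_qr_sketch(const at::Tensor& A) {
//     // GEQRF packs R into the upper triangle and the Householder reflectors below it
//     auto [QR, tau] = at::geqrf(A);
//     // R is the upper-triangular part of the leading n rows
//     auto R = QR.slice(/*dim=*/-2, /*start=*/0, /*end=*/A.size(-1)).triu();
//     // ORGQR assembles the m x n Q with orthonormal columns from the reflectors
//     auto Q = at::linalg_householder_product(QR, tau);
//     return {Q, R};
//   }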
2364 TORCH_IMPL_FUNC(linalg_qr_out)(const Tensor& A,
2365                                c10::string_view mode,
2366                                const Tensor & Q,
2367                                const Tensor & R) {
2368   auto m = A.size(-2);
2369   auto n = A.size(-1);
2370   auto k = std::min(m, n);
2371   auto [compute_q, reduced_mode] = at::native::_parse_qr_mode(mode);
2372 
2373 
2374   // We need an auxiliary tensor to call geqrf
2375   auto tau_shape = A.sizes().vec();
2376   tau_shape.pop_back();
2377   tau_shape.back() = k;
2378   auto tau = A.new_empty(tau_shape);
2379 
2380   // geqrf requires an m x n workspace input that is modified in-place
2381   // We try to use Q. If it doesn't fit, we try to use R
2382   // If m > n and compute_q==false, it won't fit into Q or R, so we need to create an auxiliary tensor
2383   Tensor QR;
2384   if (compute_q && Q.size(-1) == n) {
2385     QR = Q;
2386     QR.copy_(A);
2387   } else if (R.size(-2) == m) {
2388     QR = R;
2389     QR.copy_(A);
2390   } else {
2391     QR = cloneBatchedColumnMajor(A);
2392   }
2393 
2394   geqrf_stub(A.device().type(), QR, tau);
2395 
2396   // Split QR into Q (unless compute_q == false) and R
2397   if (QR.is_alias_of(R)) {
2398     // Copy QR into Q
2399     if (compute_q) {
2400       // If the result didn't fit in Q and compute_q == true, it's because Q is not of size m x n (i.e. it's of size m x m)
2401       TORCH_INTERNAL_ASSERT(Q.size(-1) == m);
2402       if (m < n) {
2403         Q.copy_(QR.slice(-1, 0, m));
2404       } else {
2405         Q.slice(-1, 0, n).copy_(QR);
2406       }
2407     }
2408     R.triu_();
2409   } else {
2410     // Copy QR into R from Q or the aux tensor
2411     at::triu_out(const_cast<Tensor&>(R), QR.slice(-2, 0, n));
2412   }
2413 
2414   if (compute_q) {
2415     // Next perform ORGQR for Q using the result from GEQRF
2416     orgqr_stub(A.device().type(), const_cast<Tensor&>(Q), tau);
2417   }
2418 }
2419 
2420 
2421 std::tuple<Tensor,Tensor> qr(const Tensor& self, bool some) {
2422   TORCH_WARN_ONCE(
2423     "torch.qr is deprecated in favor of torch.linalg.qr and will be removed in a future PyTorch release.\n",
2424     "The boolean parameter 'some' has been replaced with a string parameter 'mode'.\n",
2425     "Q, R = torch.qr(A, some)\n",
2426     "should be replaced with\n",
2427     "Q, R = torch.linalg.qr(A, 'reduced' if some else 'complete')"
2428   );
2429   const char* mode = some ? "reduced" : "complete";
2430   return at::linalg_qr(self, mode);
2431 }
2432 
2433 std::tuple<Tensor&,Tensor&> qr_out(const Tensor& self, bool some, Tensor& Q, Tensor& R) {
2434   TORCH_WARN_ONCE(
2435     "torch.qr is deprecated in favor of torch.linalg.qr and will be removed in a future PyTorch release.\n",
2436     "The boolean parameter 'some' has been replaced with a string parameter 'mode'.\n",
2437     "Q, R = torch.qr(A, some)\n",
2438     "should be replaced with\n",
2439     "Q, R = torch.linalg.qr(A, 'reduced' if some else 'complete')"
2440   );
2441   const char* mode = some ? "reduced" : "complete";
2442   return at::linalg_qr_out(Q, R, self, mode);
2443 }
2444 
2445 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ orgqr ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2446 
2447 DEFINE_DISPATCH(orgqr_stub);
2448 
2449 /*
2450   The householder_product (orgqr) function allows reconstruction of an orthogonal (or unitary) matrix Q,
2451   from a sequence of elementary reflectors, such as is produced by the geqrf function.
2452 
2453   Args:
2454   * `input` - Tensor with the directions of the elementary reflectors below the diagonal.
2455   * `tau` - Tensor containing the magnitudes of the elementary reflectors.
2456   * `result` - result Tensor, which will contain the orthogonal (or unitary) matrix Q.
2457 
2458   For further details, please see the LAPACK/MAGMA documentation.
2459 */
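// A minimal usage sketch (assuming a real tall input A with m >= n): reconstruct
// Q from geqrf's packed output and check that its columns are orthonormal.
//
//   auto [QR, tau] = at::geqrf(A);
//   auto Q = at::linalg_householder_product(QR, tau);  // m x n, orthonormal columns
//   bool ok = at::allclose(Q.mT().matmul(Q), at::eye(A.size(-1), A.options()),
//                          /*rtol=*/1e-5, /*atol=*/1e-6);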
2460 static Tensor& householder_product_out_helper(const Tensor& input, const Tensor& tau, Tensor& result) {
2461   TORCH_INTERNAL_ASSERT(input.dim() >= 2);
2462   TORCH_INTERNAL_ASSERT(input.size(-2) >= input.size(-1));
2463   TORCH_INTERNAL_ASSERT(input.size(-1) >= tau.size(-1));
2464 
2465   TORCH_INTERNAL_ASSERT(input.scalar_type() == tau.scalar_type());
2466   TORCH_INTERNAL_ASSERT(input.device() == tau.device());
2467 
2468   TORCH_INTERNAL_ASSERT(result.scalar_type() == input.scalar_type());
2469   TORCH_INTERNAL_ASSERT(result.device() == input.device());
2470 
2471   // if result has no elements we can modify it
2472   if (result.numel() == 0) {
2473     at::native::resize_as_(result, input.mT(), MemoryFormat::Contiguous);
2474     result.transpose_(-2, -1);
2475   }
2476 
2477   // result tensor must be in batched column major order (Fortran contiguous)
2478   TORCH_INTERNAL_ASSERT(result.mT().is_contiguous());
2479   TORCH_INTERNAL_ASSERT(result.sizes().equals(input.sizes()));
2480 
2481   // tau tensor must be contiguous
2482   Tensor tau_ = tau;
2483   if (!tau.is_contiguous()) {
2484     tau_ = at::empty(tau.sizes(), tau.options(), MemoryFormat::Contiguous);
2485     tau_.copy_(tau);
2486   }
2487 
2488   // orgqr_stub (apply_orgqr) performs calculations in-place and result must be a copy of input
2489   result.copy_(input);
2490 
2491   result = orgqr_stub(result.device().type(), result, tau_);
2492   return result;
2493 }
2494 
2495 Tensor& linalg_householder_product_out(const Tensor& input, const Tensor& tau, Tensor& result) {
2496   TORCH_CHECK(input.dim() >= 2, "torch.linalg.householder_product: input must have at least 2 dimensions.");
2497   TORCH_CHECK(
2498       input.size(-2) >= input.size(-1),
2499       "torch.linalg.householder_product: input.shape[-2] must be greater than or equal to input.shape[-1]");
2500   TORCH_CHECK(
2501       input.size(-1) >= tau.size(-1),
2502       "torch.linalg.householder_product: input.shape[-1] must be greater than or equal to tau.shape[-1]");
2503 
2504   TORCH_CHECK(
2505       input.dim() - tau.dim() == 1,
2506       "torch.linalg.householder_product: Expected tau to have one dimension less than input, but got tau.ndim equal to ",
2507       tau.dim(),
2508       " and input.ndim is equal to ",
2509       input.dim());
2510   if (input.dim() > 2) {
2511     auto expected_batch_tau_shape = IntArrayRef(input.sizes().data(), input.dim() - 2); // input.shape[:-2]
2512     auto actual_batch_tau_shape = IntArrayRef(tau.sizes().data(), tau.dim() - 1); // tau.shape[:-1]
2513     TORCH_CHECK(
2514         actual_batch_tau_shape.equals(expected_batch_tau_shape),
2515         "torch.linalg.householder_product: Expected batch dimensions of tau to be equal to input.shape[:-2], but got ",
2516         actual_batch_tau_shape);
2517   }
2518 
2519   TORCH_CHECK(
2520       tau.scalar_type() == input.scalar_type(),
2521       "torch.linalg.householder_product: tau dtype ",
2522       tau.scalar_type(),
2523       " does not match input dtype ",
2524       input.scalar_type());
2525   checkSameDevice("torch.linalg.householder_product", tau, input, "tau");
2526   checkSameDevice("torch.linalg.householder_product", result, input);
2527   checkLinalgCompatibleDtype("torch.linalg.householder_product", result, input);
2528 
2529   // TODO: uncomment the following when passing incorrectly sized 'result' is not allowed
2530   // if (result.numel() != 0) {
2531   //   // Resize messes up the strides, so let's not use at::native::resize_output
2532   //   TORCH_CHECK(result.sizes().equals(input.sizes()),
2533   //   "result shape ", result.sizes(), " does not match input shape ", input.sizes());
2534   // }
2535 
2536   bool result_input_same_type = (result.scalar_type() == input.scalar_type());
2537   bool result_equal_expected_shape = result.sizes().equals(input.sizes());
2538   bool is_batched_column_major = false;
2539   if (result.dim() >= 2) {
2540     is_batched_column_major = result.mT().is_contiguous();
2541   }
2542 
2543   // if result is not empty and not in batched column major format
2544   bool copy_needed = (result.numel() != 0 && !is_batched_column_major);
2545   copy_needed |= !result_input_same_type;  // or result does not have the same dtype as input
2546   copy_needed |= (result.numel() != 0 && !result_equal_expected_shape); // or result does not have the expected shape
2547   // we have to allocate a temporary tensor
2548   if (copy_needed) {
2549     Tensor result_tmp = at::empty({0}, input.options());
2550     result_tmp = householder_product_out_helper(input, tau, result_tmp);
2551     at::native::resize_output(result, result_tmp.sizes());
2552     result.copy_(result_tmp);
2553   } else {
2554     // use result's storage directly
2555     result = householder_product_out_helper(input, tau, result);
2556   }
2557 
2558   return result;
2559 }
2560 
2561 Tensor linalg_householder_product(const Tensor& input, const Tensor& tau) {
2562   Tensor result = at::empty({0}, input.options());
2563   result = at::linalg_householder_product_outf(input, tau, result);
2564   return result;
2565 }
2566 
2567 // torch.orgqr is an alias of torch.linalg.householder_product
2568 // torch.linalg.householder_product is the preferred new function
2569 Tensor& orgqr_out(const Tensor& input, const Tensor& tau, Tensor& result) {
2570   return at::linalg_householder_product_outf(input, tau, result);
2571 }
2572 
2573 Tensor orgqr(const Tensor& input, const Tensor& tau) {
2574   return at::linalg_householder_product(input, tau);
2575 }
2576 
2577 DEFINE_DISPATCH(ormqr_stub);
2578 
2579 static void ormqr_out_helper(const Tensor& input, const Tensor& tau, const Tensor& other, const Tensor& result, bool left, bool transpose) {
2580   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.dim() >= 2);
2581   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other.dim() >= 2);
2582 
2583   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other.size(left ? -2 : -1) >= tau.size(-1));
2584   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other.size(left ? -2 : -1) == input.size(-2));
2585 
2586   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.scalar_type() == tau.scalar_type());
2587   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.device() == tau.device());
2588 
2589   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.scalar_type() == other.scalar_type());
2590   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.device() == other.device());
2591 
2592   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.scalar_type() == input.scalar_type());
2593   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.device() == input.device());
2594 
2595   // if 'result' has no elements we can modify it
2596   if (result.numel() == 0) {
2597     at::native::resize_as_(result, other.mT(), MemoryFormat::Contiguous);
2598     result.transpose_(-2, -1);
2599   }
2600 
2601   // 'result' tensor must be in batched column major order (Fortran contiguous)
2602   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.mT().is_contiguous());
2603   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.sizes().equals(other.sizes()));
2604 
2605   // 'tau' tensor must be contiguous
2606   Tensor tau_ = tau;
2607   if (!tau.is_contiguous()) {
2608     tau_ = at::empty(tau.sizes(), tau.options(), MemoryFormat::Contiguous);
2609     tau_.copy_(tau);
2610   }
2611 
2612   // 'input' tensor must be Fortran contiguous
2613   Tensor input_ = input;
2614   if (!input.mT().is_contiguous()) {
2615     input_ = at::empty(input.mT().sizes(), input.options(), MemoryFormat::Contiguous);
2616     input_.transpose_(-2, -1);
2617     input_.copy_(input);
2618   }
2619 
2620   // ormqr_stub (apply_ormqr) performs calculations in-place and 'result' must be a copy of 'other'
2621   result.copy_(other);
2622 
2623   ormqr_stub(result.device().type(), input_, tau_, result, left, transpose);
2624 }
2625 
2626 Tensor& ormqr_out(const Tensor& input, const Tensor& tau, const Tensor& other, bool left, bool transpose, Tensor& result) {
2627   TORCH_CHECK(input.dim() >= 2, "torch.ormqr: input must have at least 2 dimensions.");
2628   TORCH_CHECK(other.dim() >= 2, "torch.ormqr: other must have at least 2 dimensions.");
2629 
2630   int64_t left_size_condition = left ? -2 : -1;
2631   TORCH_CHECK(
2632       other.size(left_size_condition) >= tau.size(-1),
2633       "torch.ormqr: other.shape[",
2634       left_size_condition,
2635       "] must be greater than or equal to tau.shape[-1]");
2636 
2637   TORCH_CHECK(
2638       other.size(left_size_condition) == input.size(-2),
2639       "torch.ormqr: other.shape[",
2640       left_size_condition,
2641       "] must be equal to input.shape[-2]");
2642 
2643   TORCH_CHECK(
2644       tau.size(-1) <= input.size(-1),
2645       "torch.ormqr: tau.shape[-1] must be less than or equal to input.shape[-1]");
2646 
2647   TORCH_CHECK(
2648       input.dim() - tau.dim() == 1,
2649       "torch.ormqr: ",
2650       "Expected tau to have one dimension less than input, but got tau.ndim equal to ",
2651       tau.dim(),
2652       " and input.ndim is equal to ",
2653       input.dim());
2654   TORCH_CHECK(
2655       input.dim() == other.dim(),
2656       "torch.ormqr: ",
2657       "Expected other to have the same number of dimensions as input, but got other.ndim equal to ",
2658       other.dim(),
2659       " and input.ndim is equal to ",
2660       input.dim());
2661 
2662   if (input.dim() > 2) {
2663     auto expected_batch_shape = IntArrayRef(input.sizes().data(), input.dim() - 2); // input.shape[:-2]
2664     auto actual_batch_tau_shape = IntArrayRef(tau.sizes().data(), tau.dim() - 1); // tau.shape[:-1]
2665     TORCH_CHECK(
2666         actual_batch_tau_shape.equals(expected_batch_shape),
2667         "torch.ormqr: Expected batch dimensions of tau to be equal to input.shape[:-2], but got ",
2668         actual_batch_tau_shape);
2669 
2670     auto actual_batch_other_shape = IntArrayRef(other.sizes().data(), other.dim() - 2); // other.shape[:-2]
2671     TORCH_CHECK(
2672         actual_batch_other_shape.equals(expected_batch_shape),
2673         "torch.ormqr: Expected batch dimensions of other to be equal to input.shape[:-2], but got ",
2674         actual_batch_other_shape);
2675   }
2676 
2677   TORCH_CHECK(
2678       tau.scalar_type() == input.scalar_type(),
2679       "torch.ormqr: Expected input and tau to have the same dtype, but input has dtype", input.scalar_type(),
2680       " and tau has dtype ", tau.scalar_type());
2681   TORCH_CHECK(
2682       other.scalar_type() == input.scalar_type(),
2683       "torch.ormqr: Expected input and other to have the same dtype, but input has dtype", input.scalar_type(),
2684       " and other has dtype ", other.scalar_type());
2685   TORCH_CHECK(
2686       result.scalar_type() == input.scalar_type(),
2687       "torch.ormqr: Expected input and result to have the same dtype, but input has dtype", input.scalar_type(),
2688       " and result has dtype ", result.scalar_type());
2689 
2690   checkSameDevice("torch.ormqr", tau, input, "tau");
2691   checkSameDevice("torch.ormqr", other, input, "other");
2692   checkSameDevice("torch.ormqr", result, input);
2693 
2694   bool result_equal_expected_shape = result.sizes().equals(other.sizes());
2695   bool is_batched_column_major = false;
2696   if (result.dim() >= 2) {
2697     is_batched_column_major = result.mT().is_contiguous();
2698   }
2699 
2700   // if result is not empty and not in batched column major format
2701   bool copy_needed = (result.numel() != 0 && !is_batched_column_major);
2702   copy_needed |= (result.numel() != 0 && !result_equal_expected_shape); // or result does not have the expected shape
2703   // we have to allocate a temporary tensor
2704   if (copy_needed) {
2705     Tensor result_tmp = at::empty({0}, input.options());
2706     ormqr_out_helper(input, tau, other, result_tmp, left, transpose);
2707     at::native::resize_output(result, result_tmp.sizes());
2708     result.copy_(result_tmp);
2709   } else {
2710     // use result's storage directly
2711     ormqr_out_helper(input, tau, other, result, left, transpose);
2712   }
2713 
2714   return result;
2715 }
2716 
2717 Tensor ormqr(const Tensor& input, const Tensor& tau, const Tensor& other, bool left, bool transpose) {
2718   Tensor result = at::empty({0}, input.options());
2719   result = at::native::ormqr_out(input, tau, other, left, transpose, result);
2720   return result;
2721 }
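// A minimal usage sketch (assuming A is m x n with m >= n and B is m x nrhs):
// ormqr applies the implicit Q encoded by geqrf's (QR, tau) to another matrix
// without materializing Q.
//
//   auto [QR, tau] = at::geqrf(A);
//   // Q^H B, the first step of a least-squares solve via the QR factorization
//   auto QhB = at::ormqr(QR, tau, B, /*left=*/true, /*transpose=*/true);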
2722 
2723 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_eigh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2724 
2725 DEFINE_DISPATCH(linalg_eigh_stub);
2726 
2727 /*
2728   Computes eigenvalues and eigenvectors of the tensor 'input'.
2729 
2730   Args:
2731   * 'input' - input Tensor for eigendecomposition
2732   * 'values' - Tensor to store computed eigenvalues
2733   * 'vectors' - Tensor to store computed eigenvectors
2734   * 'infos' - Tensor to store LAPACK/MAGMA/cuSOLVER error codes
2735   * 'compute_eigenvectors' - controls whether eigenvectors should be computed
2736   * 'uplo' - controls the portion of input matrix to consider in computations, allowed values are "u", "U", "l", "L"
2737     "u", "U" - upper triangular portion of the input matrix is used in computations; "l", "L" - lower.
2738 */
2739 
2740 TORCH_IMPL_FUNC(_linalg_eigh_out)(const Tensor& A,
2741                                   c10::string_view uplo,
2742                                   bool compute_v,
2743                                   const Tensor& L,
2744                                   const Tensor& V) {
2745   if (A.numel() == 0) {
2746     return;
2747   }
2748 
2749   auto uplo_uppercase = static_cast<char>(std::toupper(static_cast<unsigned char>(uplo[0])));
2750   bool upper = (uplo_uppercase == 'U');
2751 
2752   Tensor V_ = V;
2753   if (compute_v) {
2754     V_.copy_(A);
2755   } else {
2756     // We need a tensor to hold A
2757     V_ = cloneBatchedColumnMajor(A);
2758   }
2759 
2760   const auto info = at::zeros(A.sizes().slice(0, A.dim() - 2), A.options().dtype(kInt));
2761   linalg_eigh_stub(A.device().type(), L, V_, info, upper, compute_v);
2762 
2763   at::_linalg_check_errors(info, "linalg.eigh", /*is_matrix*/A.dim() == 2);
2764 }
2765 
2766 std::tuple<Tensor, Tensor> linalg_eigh(const Tensor& A, c10::string_view uplo) {
2767   // TODO (Good intro task) Implement linalg_eigh_ex_out
2768   return at::_linalg_eigh(A, uplo, /*compute_v*/true);
2769 }
2770 
2771 std::tuple<Tensor&, Tensor&> linalg_eigh_out(const Tensor& A, c10::string_view uplo, Tensor& L, Tensor& V) {
2772   return at::_linalg_eigh_out(L, V, A, uplo, /*compute_v=*/true);
2773 }
2774 
2775 
2776 Tensor linalg_eigvalsh(const Tensor& A, c10::string_view uplo) {
2777   return std::get<0>(at::_linalg_eigh(A, uplo,
2778                      /*compute_v=*/_may_require_fw_or_bw_grad(A)));
2779 }
2780 
2781 Tensor& linalg_eigvalsh_out(const Tensor& A, c10::string_view uplo, Tensor& L) {
2782   auto V = at::empty({0}, A.options());
2783   at::_linalg_eigh_out(L, V, A, uplo, /*compute_v=*/false);
2784   return L;
2785 }
2786 
2787 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_eig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2788 
2789 // This function returns complex-valued eigenvectors that are obtained from LAPACK GEEV's real-valued output
2790 // This function is also used for the MAGMA path because MAGMA's intermediate results live on CPU
2791 template <typename scalar_t>
2792 static void linalg_eig_make_complex_eigenvectors_impl(Tensor& result, const Tensor& complex_values, const Tensor& real_vectors) {
2793   // From GEEV documentation:
2794   // Complex conjugate pairs of eigenvalues appear consecutively with the eigenvalue having the positive imaginary part first
2795   // If the j-th eigenvalue is real, then v(j) = VR(:,j), the j-th column of VR.
2796   // If the j-th and (j+1)-st eigenvalues form a complex conjugate pair, then v(j) = VR(:,j) + i*VR(:,j+1) and v(j+1) = VR(:,j) - i*VR(:,j+1).
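  // Worked 2x2 illustration (an assumption following the rule above): if the eigenvalues are
  // a +/- i*b with b > 0 and GEEV returns the real columns x and y, then v(1) = x + i*y and v(2) = x - i*y.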
2797 
2798   auto batch_size = batchCount(real_vectors);
2799   auto n = real_vectors.size(-1);
2800   auto matrix_stride = matrixStride(real_vectors);
2801 
2802   auto result_data = result.data_ptr<c10::complex<scalar_t>>();
2803   auto real_vectors_data = real_vectors.const_data_ptr<scalar_t>();
2804   auto values_data = complex_values.const_data_ptr<c10::complex<scalar_t>>();
2805 
2806   for (auto b = decltype(batch_size){0}; b < batch_size; b++) {
2807     const scalar_t* vecs = &real_vectors_data[b * matrix_stride];
2808     c10::complex<scalar_t>* res = &result_data[b * matrix_stride];
2809     const c10::complex<scalar_t>* vals = &values_data[b * n];
2810     for (auto j = decltype(n){0}; j < n; j++) {
2811       if (vals[j].imag() == 0.0) {  // eigenvalue is real, then v(j) = VR(:,j)
2812         for (auto i = decltype(n){0}; i < n; i++) {
2813           res[j * n + i] = c10::complex<scalar_t>(vecs[j * n + i], 0);
2814         }
2815       } else {
2816         for (auto i = decltype(n){0}; i < n; i++) {
2817           res[j * n + i] = c10::complex<scalar_t>(vecs[j * n + i],  vecs[(j+1) * n + i]);      // v(j)   = VR(:,j) + i*VR(:,j+1)
2818           res[(j+1) * n + i] = c10::complex<scalar_t>(vecs[j * n + i], -vecs[(j+1) * n + i]);  // v(j+1) = VR(:,j) - i*VR(:,j+1)
2819         }
2820         j++;
2821       }
2822     }
2823   }
2824 }
2825 
2826 static Tensor& linalg_eig_make_complex_eigenvectors(Tensor& complex_vectors, const Tensor& complex_values, const Tensor& real_vectors) {
2827   // These asserts make explicit the requirements on tensors for 'linalg_eig_make_complex_eigenvectors_impl'
2828   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(complex_vectors.device() == at::kCPU);
2829   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(complex_values.device() == at::kCPU);
2830   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(real_vectors.device() == at::kCPU);
2831 
2832   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(complex_vectors.is_complex());
2833   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(complex_values.is_complex());
2834   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(real_vectors.is_floating_point());
2835 
2836   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(complex_vectors.mT().is_contiguous());
2837   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(complex_values.is_contiguous());
2838   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(real_vectors.mT().is_contiguous());
2839 
2840   AT_DISPATCH_FLOATING_TYPES(real_vectors.scalar_type(), "linalg_eig_make_complex_vector", [&]{
2841     linalg_eig_make_complex_eigenvectors_impl<scalar_t>(complex_vectors, complex_values, real_vectors);
2842   });
2843   return complex_vectors;
2844 }
2845 
2846 DEFINE_DISPATCH(linalg_eig_stub);
2847 
2848 static std::tuple<Tensor&, Tensor&> linalg_eig_out_info(const Tensor& input, Tensor& values, Tensor& vectors, Tensor& infos, bool compute_eigenvectors) {
2849   // MAGMA doesn't have a GPU interface for the GEEV routine; it requires inputs to be on CPU,
2850   // therefore we create all intermediate tensors on CPU
2851   auto options = input.options().device(at::kCPU);
2852 
2853   // These internal asserts make explicit the assumptions in the implementation
2854   // Error checks with the actual error messages are done at a higher level of the hierarchy of calls
2855   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.dim() >= 2);
2856   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.size(-2) == input.size(-1));
2857 
2858   // for real-valued 'input', eigenvalues can be real-valued or complex-valued
2859   TORCH_INTERNAL_ASSERT_DEBUG_ONLY((toComplexType(input.scalar_type()) == values.scalar_type()) || (input.scalar_type() == values.scalar_type()));
2860   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.device() == at::kCPU);
2861 
2862   // for real-valued 'input', eigenvectors can be real-valued or complex-valued
2863   if (compute_eigenvectors) {
2864     TORCH_INTERNAL_ASSERT_DEBUG_ONLY((toComplexType(input.scalar_type()) == vectors.scalar_type()) || (input.scalar_type() == vectors.scalar_type()));
2865     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(vectors.device() == at::kCPU);
2866   }
2867 
2868   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.scalar_type() == at::kInt);
2869   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.device() == at::kCPU);
2870   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.numel() == std::max<int64_t>(1, batchCount(input)));
2871   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.is_contiguous());
2872 
2873   // if 'vectors' has no elements we can modify it
2874   if (vectors.numel() == 0 && compute_eigenvectors) {
2875     vectors.resize_(input.sizes(), MemoryFormat::Contiguous);
2876     vectors.transpose_(-2, -1);  // make 'vectors' have a Fortran-contiguous memory layout
2877   }
2878 
2879   // if 'values' has no elements we can modify it
2880   auto values_shape = IntArrayRef(input.sizes().data(), input.dim()-1);  // input.shape[:-1]
2881   if (values.numel() == 0) {
2882     values.resize_(values_shape, MemoryFormat::Contiguous);
2883   }
2884 
2885   // 'vectors' must be in batched column major order (Fortran contiguous)
2886   if (compute_eigenvectors) {
2887     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(vectors.mT().is_contiguous());
2888     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(vectors.sizes().equals(input.sizes()));
2889   }
2890 
2891   // 'values' must be contiguous
2892   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.is_contiguous());
2893   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.sizes().equals(values_shape));
2894 
2895   // if 'input' is complex then use 'values' directly else create a temporary to hold the real and imaginary parts
2896   // and then use at::complex_out
2897   Tensor real_imag_values = values;
2898 
2899   // if 'input' is complex then use 'vectors' directly else maybe create a temporary to hold real vectors
2900   // and then use linalg_eig_make_complex_eigenvectors
2901   Tensor maybe_complex_vectors = vectors;
2902   if (!input.is_complex()) {
2903     // first n elements to hold the real portion of the output and the last n elements to hold the imaginary portion
2904     auto real_imag_shape = IntArrayRef(input.sizes().data(), input.dim()-2).vec();  // input.shape[:-2]
2905     real_imag_shape.push_back(input.size(-1) * 2);
2906     real_imag_values = at::empty(real_imag_shape, options, MemoryFormat::Contiguous);
2907 
2908     // linalg_eig_stub expects a real-valued tensor to store eigenvectors
2909     // the output of linalg_eig_stub needs to be post-processed later to produce complex-valued eigenvectors
2910     // we do this post-processing only if 'vectors' is complex-valued
2911     // otherwise storage of 'vectors' is used directly
2912     if (vectors.is_complex() && compute_eigenvectors) {
2913       maybe_complex_vectors = at::empty(input.sizes(), options, MemoryFormat::Contiguous);
2914       maybe_complex_vectors.transpose_(-2, -1);  // make 'maybe_complex_vectors' have a Fortran-contiguous memory layout
2915     }
2916   }
2917 
2918   // MAGMA uses a hybrid CPU-GPU algorithm that performs well only for large matrices
2919   // See: https://github.com/pytorch/pytorch/pull/52491#issuecomment-795685687
2920   // Here we call the CPU path for matrices of size up to 2048x2048,
2921   // which should in general be significantly faster than calling MAGMA
2922   if (input.size(-1) <= 2048) {
2923     linalg_eig_stub(at::kCPU, real_imag_values, maybe_complex_vectors, infos, input.to(kCPU), compute_eigenvectors);
2924   } else {
2925     linalg_eig_stub(input.device().type(), real_imag_values, maybe_complex_vectors, infos, input, compute_eigenvectors);
2926   }
2927 
2928   // if input is not complex we need to do some post-processing
2929   if (!input.is_complex()) {
2930     // extract real and imaginary parts of the output
2931     auto real_values = real_imag_values.slice(/*dim=*/-1, /*start=*/0, /*end*/input.size(-1));
2932     auto imag_values = real_imag_values.slice(/*dim=*/-1, /*start=*/input.size(-1));
2933 
2934     // if the imaginary part is zero we don't need to do anything
2935     bool is_zero_imag = at::all(imag_values == 0.0).item().toBool();
2936     if (is_zero_imag) {
2937       values.copy_(real_values);
2938       if (compute_eigenvectors) {
2939         vectors.copy_(maybe_complex_vectors);  // does nothing for !vectors.is_complex() because vectors.is_same(maybe_complex_vectors) == true
2940       }
2941       return std::tuple<Tensor&, Tensor&>(values, vectors);
2942     }
2943 
2944     if (values.is_complex()) {
2945       values = at::complex_out(values, real_values, imag_values);
2946     } else {
2947       TORCH_CHECK(false, "torch.linalg.eig: imaginary part of eigenvalues is non-zero, can't safely cast eigenvalues to non-complex dtype.")
2948     }
2949     if (compute_eigenvectors) {
2950       if (vectors.is_complex()) {
2951           vectors = linalg_eig_make_complex_eigenvectors(vectors, values, maybe_complex_vectors);
2952       } else {
2953         TORCH_CHECK(false, "torch.linalg.eig: imaginary part of eigenvectors is non-zero, can't safely cast eigenvectors to non-complex dtype.")
2954       }
2955     }
2956   }
2957 
2958   return std::tuple<Tensor&, Tensor&>(values, vectors);
2959 }
2960 
2961 std::tuple<Tensor&, Tensor&> linalg_eig_out(const Tensor& input, Tensor& values, Tensor& vectors) {
2962   TORCH_CHECK(input.isfinite().all().item<bool>(), "torch.linalg.eig: input tensor should not contain infs or NaNs.");
2963   squareCheckInputs(input, "linalg.eig");
2964 
2965   // unlike NumPy, for real-valued inputs the output is always complex-valued
2966   checkLinalgCompatibleDtype("torch.linalg.eig", values.scalar_type(), toComplexType(input.scalar_type()), "eigenvalues");
2967   checkLinalgCompatibleDtype("torch.linalg.eig", vectors.scalar_type(), toComplexType(input.scalar_type()), "eigenvectors");
2968   checkSameDevice("torch.linalg.eig", values, input, "eigenvalues");
2969   checkSameDevice("torch.linalg.eig", vectors, input, "eigenvectors");
2970 
2971   // MAGMA doesn't have a GPU interface for the GEEV routine; it requires inputs to be on CPU
2972   auto options = input.options().device(at::kCPU);
2973   auto infos = at::zeros({std::max<int64_t>(1, batchCount(input))}, options.dtype(kInt));
2974 
2975   // if result is not empty and not in batched column major format we have to allocate a temporary tensor
2976   bool is_batched_column_major = false;
2977   if (vectors.dim() >= 2) {
2978     is_batched_column_major = vectors.mT().is_contiguous();
2979   }
2980 
2981   bool values_expected_type = (values.scalar_type() == toComplexType(input.scalar_type()));
2982   bool vectors_expected_type = (vectors.scalar_type() == toComplexType(input.scalar_type()));
2983 
2984   auto expected_values_shape = IntArrayRef(input.sizes().data(), input.dim()-1);  // input.shape[:-1]
2985   bool values_equal_expected_shape = values.sizes().equals(expected_values_shape);
2986   bool vectors_equal_expected_shape = vectors.sizes().equals(input.sizes());
2987 
2988   // if result is not empty and not in batched column major format
2989   bool values_tmp_needed = (values.numel() != 0 && !values.is_contiguous());
2990   bool vectors_tmp_needed = (vectors.numel() != 0 && !is_batched_column_major);
2991   // or result does not have the expected shape
2992   values_tmp_needed |= (values.numel() != 0 && !values_equal_expected_shape);
2993   vectors_tmp_needed |= (vectors.numel() != 0 && !vectors_equal_expected_shape);
2994   // or result does not have the expected dtype
2995   values_tmp_needed |= !values_expected_type;
2996   vectors_tmp_needed |= !vectors_expected_type;
2997   // we will allocate a temporary tensor and do the copy
2998 
2999   // because MAGMA's GEEV takes CPU inputs and returns CPU outputs
3000   // "out" tensors that are on a GPU device can't be used directly
3001   values_tmp_needed |= values.is_cuda();
3002   vectors_tmp_needed |= vectors.is_cuda();
3003 
3004   // determine the appropriate scalar_type for the temporary tensors
3005   ScalarType values_type = input.scalar_type();
3006   ScalarType vectors_type = input.scalar_type();
3007   if (!input.is_complex()) {
3008     // for real-valued input we can have either real- or complex-valued output
3009     ScalarType input_complex_dtype = toComplexType(input.scalar_type());
3010     values_type = values.is_complex() ? input_complex_dtype : values_type;
3011     vectors_type = vectors.is_complex() ? input_complex_dtype : vectors_type;
3012   }
3013 
3014   if (values_tmp_needed && vectors_tmp_needed) {
3015     Tensor values_tmp = at::empty({0}, options.dtype(values_type));
3016     Tensor vectors_tmp = at::empty({0}, options.dtype(vectors_type));
3017     std::tie(values_tmp, vectors_tmp) = linalg_eig_out_info(input, values_tmp, vectors_tmp, infos, true);
3018     at::native::resize_output(values, values_tmp.sizes());
3019     values.copy_(values_tmp);
3020     at::native::resize_output(vectors, vectors_tmp.sizes());
3021     vectors.copy_(vectors_tmp);
3022   } else if (!values_tmp_needed && vectors_tmp_needed) {
3023     // use 'values' storage directly
3024     Tensor vectors_tmp = at::empty({0}, options.dtype(vectors_type));
3025     std::tie(values, vectors_tmp) = linalg_eig_out_info(input, values, vectors_tmp, infos, true);
3026     at::native::resize_output(vectors, vectors_tmp.sizes());
3027     vectors.copy_(vectors_tmp);
3028   } else if (values_tmp_needed && !vectors_tmp_needed) {
3029     // use 'vectors' storage directly
3030     Tensor values_tmp = at::empty({0}, options.dtype(values_type));
3031     std::tie(values_tmp, vectors) = linalg_eig_out_info(input, values_tmp, vectors, infos, true);
3032     at::native::resize_output(values, values_tmp.sizes());
3033     values.copy_(values_tmp);
3034   } else {
3035     // use 'values' and 'vectors' storage directly
3036     std::tie(values, vectors) = linalg_eig_out_info(input, values, vectors, infos, true);
3037   }
3038 
3039   // Now check LAPACK/MAGMA error codes
3040   at::_linalg_check_errors(infos, "torch.linalg.eig", input.dim() == 2);
3041   return std::tuple<Tensor&, Tensor&>(values, vectors);
3042 }
3043 
3044 std::tuple<Tensor, Tensor> linalg_eig(const Tensor& input) {
3045   ScalarType complex_dtype = toComplexType(input.scalar_type());
3046   Tensor values = at::empty({0}, input.options().dtype(complex_dtype));
3047   Tensor vectors = at::empty({0}, input.options().dtype(complex_dtype));
3048 
3049   at::linalg_eig_outf(input, values, vectors);
3050 
3051   return std::tuple<Tensor, Tensor>(values, vectors);
3052 }
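// A minimal usage sketch (assuming a real, diagonalizable square input): the
// public API always returns complex eigenpairs, matching the dtype checks in
// linalg_eig_out above.
//
//   auto A = at::randn({4, 4});
//   auto [w, v] = at::linalg_eig(A);  // w and v have a complex dtype
//   // A is approximately v @ diag(w) @ inv(v)
//   auto A_rec = v.matmul(at::diag_embed(w)).matmul(at::linalg_inv(v));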
3053 
3054 Tensor& linalg_eigvals_out(const Tensor& input, Tensor& values) {
3055   squareCheckInputs(input, "linalg.eigvals");
3056 
3057   // unlike NumPy, for real-valued inputs the output is always complex-valued
3058   checkLinalgCompatibleDtype("torch.linalg.eigvals", values.scalar_type(), toComplexType(input.scalar_type()), "eigenvalues");
3059   checkSameDevice("torch.linalg.eigvals", values, input, "eigenvalues");
3060 
3061   // MAGMA doesn't have a GPU interface for the GEEV routine; it requires inputs to be on CPU
3062   auto options = input.options().device(at::kCPU);
3063   auto infos = at::zeros({std::max<int64_t>(1, batchCount(input))}, options.dtype(kInt));
3064 
3065   bool values_expected_type = (values.scalar_type() == toComplexType(input.scalar_type()));
3066 
3067   auto expected_values_shape = IntArrayRef(input.sizes().data(), input.dim()-1);  // input.shape[:-1]
3068   bool values_equal_expected_shape = values.sizes().equals(expected_values_shape);
3069 
3070   // if result is not empty and not in batched column major format
3071   bool values_tmp_needed = (values.numel() != 0 && !values.is_contiguous());
3072   // or result does not have the expected shape
3073   values_tmp_needed |= (values.numel() != 0 && !values_equal_expected_shape);
3074   // or result does not have the expected dtype
3075   values_tmp_needed |= !values_expected_type;
3076   // we will allocate a temporary tensor and do the copy
3077 
3078   // because MAGMA's GEEV takes CPU inputs and returns CPU outputs
3079   // a 'values' tensor that is on a GPU device can't be used directly
3080   values_tmp_needed |= (!values.is_cpu());
3081 
3082   // determine the appropriate scalar_type for the temporary tensors
3083   ScalarType values_type = input.scalar_type();
3084   if (!input.is_complex()) {
3085     // for real-valued input we can have either real- or complex-valued output
3086     ScalarType input_complex_dtype = toComplexType(input.scalar_type());
3087     values_type = values.is_complex() ? input_complex_dtype : values_type;
3088   }
3089 
3090   Tensor vectors;
3091   if (values_tmp_needed) {
3092     Tensor values_tmp = at::empty({0}, options.dtype(values_type));
3093     std::tie(values_tmp, std::ignore) = linalg_eig_out_info(input, values_tmp, vectors, infos, /*compute_eigenvectors=*/false);
3094     at::native::resize_output(values, values_tmp.sizes());
3095     values.copy_(values_tmp);
3096   } else { // use 'values' storage directly
3097     std::tie(values, std::ignore) = linalg_eig_out_info(input, values, vectors, infos, /*compute_eigenvectors=*/false);
3098   }
3099 
3100   // Now check LAPACK/MAGMA error codes
3101   at::_linalg_check_errors(infos, "torch.linalg.eigvals", input.dim() == 2);
3102   return values;
3103 }
3104 
3105 Tensor linalg_eigvals(const Tensor& input) {
3106   // if input requires grad we must compute the eigenvectors to make this function differentiable
3107   // the eigenvectors are not exposed to the user
3108   if (_may_require_fw_or_bw_grad(input)) {
3109     return std::get<0>(at::linalg_eig(input));
3110   }
3111   return at::_linalg_eigvals(input);
3112 }
3113 
3114 Tensor _linalg_eigvals(const Tensor& input) {
3115   ScalarType complex_dtype = toComplexType(input.scalar_type());
3116   Tensor values = at::empty({0}, input.options().dtype(complex_dtype));
3117   linalg_eigvals_out(input, values);
3118   return values;
3119 }
3120 
3121 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
3122 
3123 /* torch.svd, implemented in terms of torch.linalg.svd. There are two main
3124    differences:
3125 
3126     1. the 2nd parameter is bool some=True, which is effectively the opposite
3127        of full_matrices=True
3128 
3129     2. svd returns V, while linalg.svd returns Vh = V^H
3130 */
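// A minimal sketch of that mapping through the public ATen API (assuming
// compute_uv == true; 'A' and 'some' are placeholders):
//
//   auto [U, S, Vh] = at::linalg_svd(A, /*full_matrices=*/!some);
//   auto V = Vh.mH();  // torch.svd returns V, torch.linalg.svd returns Vh = V^H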
3131 
3132 DEFINE_DISPATCH(svd_stub);
3133 
3134 TORCH_IMPL_FUNC(_linalg_svd_out)(const Tensor& A,
3135                                  const bool full_matrices,
3136                                  const bool compute_uv,
3137                                  std::optional<c10::string_view> driver,
3138                                  const Tensor & U,
3139                                  const Tensor & S,
3140                                  const Tensor & Vh) {
3141   // Half optimisation, half precondition for some parts of the LAPACK / cuSOLVER code paths
3142   // In particular, the call to lapackSvd to compute lwork fails otherwise
3143   if (A.numel() == 0) {
3144     // Needed in the case that we have e.g. A.shape == (3, 0) and full_matrices=True
3145     // We fill U or Vh with the identity matrix as it's a valid SVD for the empty matrix
3146     if (compute_uv && full_matrices) {
3147       if (U.numel() != 0) {
3148         U.zero_();
3149         U.diagonal(0, -2, -1).fill_(1.);
3150       }
3151       if (Vh.numel() != 0) {
3152         Vh.zero_();
3153         Vh.diagonal(0, -2, -1).fill_(1.);
3154       }
3155     }
3156     return;
3157   }
3158 
3159   // We need to distinguish the cuSOLVER case, as cuSOLVER expects F-contig matrices, but
3160   // it computes V rather than Vh
3161   const bool use_cusolver = at::native::svd_uses_cusolver(A);
3162   TORCH_CHECK(use_cusolver || !driver.has_value(),
3163     "torch.linalg.svd: keyword argument `driver=` is only supported on CUDA inputs with cuSOLVER backend.");
3164 
3165   // A always needs to be copied as its contents will be destroyed during the computation of the SVD
3166   // Now, MAGMA needs the copy to be on CPU, while cuSOLVER needs it to be on CUDA, so we'll defer
3167   // the copy as a column major matrix to the backends.
3168   const auto info = at::zeros(IntArrayRef(A.sizes().begin(), A.sizes().end() - 2), A.options().dtype(kInt));
3169 
3170   svd_stub(A.device().type(),
3171            A,
3172            full_matrices,
3173            compute_uv,
3174            driver,
3175            U, S, Vh, info);
3176 
3177   // TODO This should be removed, and the code checking for convergence should be lifted
3178   // from svd_cusolver to this function. We should then make sure that this function
3179   // never errors out.
3180   at::_linalg_check_errors(info, "linalg.svd", /*is_matrix*/A.dim() == 2);
3181 }
3182 
3183 std::tuple<Tensor&, Tensor&, Tensor&>
3184 linalg_svd_out(const Tensor& A,
3185                bool full_matrices,
3186                std::optional<c10::string_view> driver,
3187                Tensor & U,
3188                Tensor & S,
3189                Tensor & Vh) {
3190   // This function does not have an _ex variant as we always check errors inside
3191   // to assure the convergence of the algorithm anyway. See
3192   // https://github.com/pytorch/pytorch/issues/28293
3193   // https://github.com/pytorch/pytorch/issues/64237
3194   //
3195   // We must delegate both linalg_svd and linalg_svdvals to
3196   // _linalg_svd (rather than delegating linalg_svdvals to linalg_svd) because
3197   //   1. We don't want to expose the `compute_uv` parameter in svd
3198   //   2. We would like to make use of the `compute_uv=False` optimisation within svdvals
3199   // The only way to achieve these two things and still abide by the compositionality rules
3200   // is by dispatching to another function.
3201   return at::_linalg_svd_out(U, S, Vh, A, full_matrices, /*compute_uv=*/true, driver);
3202 }
3203 
3204 std::tuple<Tensor, Tensor, Tensor> linalg_svd(const Tensor& A, bool full_matrices,
3205     std::optional<c10::string_view> driver) {
3206   return at::_linalg_svd(A, full_matrices, /*compute_uv=*/true, driver);
3207 }
3208 
3209 // See note in linalg_svd for why this function does not have an _ex variant
3210 Tensor& linalg_svdvals_out(const Tensor& A, std::optional<c10::string_view> driver, Tensor & S) {
3211   // Dummies
3212   auto U = at::empty({0}, A.options());
3213   auto Vh = at::empty({0}, A.options());
3214   at::_linalg_svd_out(U, S, Vh, A, /*full_matrices=*/false, /*compute_uv=*/false, /*driver=*/driver);
3215   return S;
3216 }
3217 
3218 Tensor linalg_svdvals(const Tensor& A, std::optional<c10::string_view> driver) {
3219   return std::get<1>(at::_linalg_svd(A, /*full_matrices=*/false,
3220                      /*compute_uv=*/_may_require_fw_or_bw_grad(A),
3221                      /*driver=*/driver));
3222 }
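// A minimal sketch (assumption): svdvals(A) matches the S returned by the full
// SVD while skipping the computation of U and Vh.
//
//   auto S1 = at::linalg_svdvals(A);
//   auto S2 = std::get<1>(at::linalg_svd(A, /*full_matrices=*/false));
//   bool same = at::allclose(S1, S2);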
3223 
3224 std::tuple<Tensor&, Tensor&, Tensor&> svd_out(const Tensor& self, bool some, bool compute_uv,
3225     Tensor& U, Tensor& S, Tensor& V) {
3226 
3227   if (compute_uv) {
3228     if (V.dim() >= 2) {
3229       V.transpose_(-2, -1);
3230     }
3231     at::linalg_svd_out(U, S, V, self, /*full_matrices=*/!some);
3232     V.transpose_(-2, -1);
3233     if (V.is_complex()) {
3234       // We cannot use `_set_conj` as it does not play well with backwards
3235       V.conj_physical_();
3236     }
3237   } else {
3238     TORCH_CHECK(self.scalar_type() == U.scalar_type(),
3239     "torch.svd: Expected out tensor to have dtype ", self.scalar_type(), " but got ", U.scalar_type(), " instead");
3240 
3241     TORCH_CHECK(self.scalar_type() == V.scalar_type(),
3242     "torch.svd: Expected out tensor to have dtype ", self.scalar_type(), " but got ", V.scalar_type(), " instead");
3243 
3244     at::linalg_svdvals_out(S, self);
3245     // compute_uv == false returns U, Vh of size (m, m), (n, n) full of zeros
3246     const auto m = self.size(-2);
3247     const auto n = self.size(-1);
3248     auto sizes = self.sizes().vec();
3249 
3250     sizes.end()[-1] = m;
3251     at::native::resize_output(U, sizes);
3252     U.zero_();
3253 
3254     sizes.end()[-2] = n;
3255     sizes.end()[-1] = n;
3256     at::native::resize_output(V, sizes);
3257     V.zero_();
3258   }
3259 
3260   return std::tie(U, S, V);
3261 }
3262 
3263 std::tuple<Tensor, Tensor, Tensor> svd(const Tensor& self, bool some, bool compute_uv) {
3264   // TODO: uncomment the following when svd is deprecated not only in docs
3265   // torch/xla is blocking the transition from at::svd to at::linalg_svd in at::linalg_pinv code
3266   // see https://github.com/pytorch/xla/issues/2755
3267   // TORCH_WARN_ONCE(
3268   //     "torch.svd is deprecated in favor of torch.linalg.svd and will be ",
3269   //     "removed in a future PyTorch release.\n",
3270   //     "U, S, V = torch.svd(A, some=some, compute_uv=True) (default)\n",
3271   //     "should be replaced with\n",
3272   //     "U, S, Vh = torch.linalg.svd(A, full_matrices=not some)\n",
3273   //     "V = Vh.mH\n",
3274   //     "and\n",
3275   //     "_, S, _ = torch.svd(A, some=some, compute_uv=False)\n",
3276   //     "should be replaced with\n",
3277   //     "S = torch.linalg.svdvals(A)");
3278   TORCH_CHECK(self.dim() >= 2, "linalg.svd: input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
3279   Tensor U, S, Vh;
3280   if (compute_uv) {
3281     std::tie(U, S, Vh) = at::linalg_svd(self, /*full_matrices=*/!some);
3282   } else {
3283     S = at::linalg_svdvals(self);
3284     // compute_uv == false returns U, Vh of size (m, m), (n, n) full of zeros
3285     const auto m = self.size(-2);
3286     const auto n = self.size(-1);
3287 
3288     auto sizes = self.sizes().vec();
3289     sizes.end()[-1] = m;
3290     U = at::zeros(sizes, self.options());
3291     sizes.end()[-2] = n;
3292     sizes.end()[-1] = n;
3293     Vh = at::zeros(sizes, self.options());
3294   }
3295   return std::make_tuple(std::move(U), std::move(S), Vh.mH());
3296 }
3297 
3298 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lstsq ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
3299 
3300 DEFINE_DISPATCH(lstsq_stub);
3301 
3302 /*
3303   Solves a least squares problem. That is, it minimizes the squared Frobenius norm ||B - A X||_F^2.
3304 
3305   Input args:
3306   * 'input' - Tensor containing batches of m-by-n matrix A.
3307   * 'other' - Tensor containing batches of max(m, n)-by-nrhs matrix B.
3308   * 'cond' - relative tolerance for determining rank of A.
3309   * 'driver' - the name of the LAPACK driver that is used to compute the solution.
3310   Output args (modified in-place):
3311   * 'solution' - Tensor to store the solution matrix X.
3312   * 'residuals' - Tensor to store values of the residual sum of squares for each column of the solution.
3313   * 'rank' - Tensor to store the rank of A.
3314   * 'singular_values' - Tensor to store the singular values of A.
3315   * 'infos' - Tensor to store error codes of linear algebra math library.
3316 
3317   For further details, please see the LAPACK documentation for GELS/GELSY/GELSS/GELSD routines.
3318 */
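// A minimal usage sketch of the public entry point that eventually reaches this
// helper (assuming A is m x n and B is m x nrhs):
//
//   auto [X, residuals, rank, sv] =
//       at::linalg_lstsq(A, B, /*rcond=*/c10::nullopt, /*driver=*/"gelsd");
//   // X minimizes ||B - A X||_F; rank and sv are filled by the gelsd/gelss drivers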
3319 static void linalg_lstsq_out_info(
3320     Tensor& solution,
3321     Tensor& residuals,
3322     Tensor& rank,
3323     Tensor& singular_values,
3324     Tensor& infos,
3325     const Tensor& input,
3326     const Tensor& other,
3327     double rcond,
3328     std::string& driver) {
3329   // These internal asserts make explicit the assumptions in the implementation
3330   // Error checks with the actual error messages are done at a higher level of
3331   // the hierarchy of calls
3332   TORCH_INTERNAL_ASSERT(input.dim() >= 2);
3333   TORCH_INTERNAL_ASSERT(other.dim() >= 1);
3334 
3335   auto dim_diff = input.dim() - other.dim();
3336   TORCH_INTERNAL_ASSERT(0 <= dim_diff && dim_diff <= 1);
3337 
3338   TORCH_INTERNAL_ASSERT(input.scalar_type() == other.scalar_type());
3339   TORCH_INTERNAL_ASSERT(input.device() == other.device());
3340 
3341   TORCH_INTERNAL_ASSERT(solution.scalar_type() == input.scalar_type());
3342   TORCH_INTERNAL_ASSERT(solution.device() == input.device());
3343 
3344   TORCH_INTERNAL_ASSERT(residuals.device() == input.device());
3345 
3346   TORCH_INTERNAL_ASSERT(rank.scalar_type() == at::kLong);
3347   TORCH_INTERNAL_ASSERT(rank.device() == input.device());
3348 
3349   auto real_dtype = toRealValueType(input.scalar_type());
3350   TORCH_INTERNAL_ASSERT(singular_values.scalar_type() == real_dtype);
3351   TORCH_INTERNAL_ASSERT(singular_values.device() == input.device());
3352 
3353   TORCH_INTERNAL_ASSERT(infos.scalar_type() == at::kInt);
3354   TORCH_INTERNAL_ASSERT(infos.device() == input.device());
3355   TORCH_INTERNAL_ASSERT(infos.numel() == std::max<int64_t>(1, batchCount(input)));
3356   TORCH_INTERNAL_ASSERT(infos.is_contiguous());
3357 
3358   bool vector_case = linalg_solve_is_vector_rhs(input, other);
3359   // we need to unsqueeze 'other' because 2-dimensional tensors are expected in the implementation
3360   Tensor other_2d = vector_case ? other.unsqueeze(-1) : other;
3361 
3362   TORCH_INTERNAL_ASSERT(input.size(-2) == other_2d.size(-2));
3363 
3364   std::vector<int64_t> expected_solution_shape = broadcast_batch_size(input, other_2d, input.dim() - 2);
3365   // the actual shape of the solution returned is (*, n,) or (*, n, nrhs)
3366   // but LAPACK requires extra dimensions to store raw residuals
3367   // so the expected shape is (*, max(m, n),) or (*, max(m, n), nrhs)
3368   auto m = input.size(-2);
3369   auto n = input.size(-1);
3370   auto nrhs = other.size(-1);
3371   expected_solution_shape.push_back(std::max(m, n));
3372   if (!vector_case) {
3373     expected_solution_shape.push_back(nrhs);
3374   }
3375 
3376   // if 'solution' has no elements we can modify it
3377   if (solution.numel() == 0) {
3378     if (vector_case) {
3379       solution.resize_(expected_solution_shape, MemoryFormat::Contiguous);
3380     } else {
3381       auto shape_transposed = expected_solution_shape;
3382       std::swap(shape_transposed.end()[-1], shape_transposed.end()[-2]);
3383       solution.resize_(shape_transposed, MemoryFormat::Contiguous);
3384       solution.transpose_(-2, -1);
3385     }
3386   }
3387 
3388   // if 'solution' is non-empty it must have the expected shape
3389   TORCH_INTERNAL_ASSERT(solution.sizes().equals(expected_solution_shape));
3390 
3391   // 'solution' must be in batched column major order (Fortran contiguous) for 2D inputs
3392   // or C contiguous for 1D input
3393   if (vector_case) {
3394     TORCH_INTERNAL_ASSERT(solution.is_contiguous());
3395   } else {
3396     TORCH_INTERNAL_ASSERT(solution.mT().is_contiguous());
3397   }
3398 
3399   // for 1-dimensional 'other', we need to unsqueeze the 'solution' before passing to "apply_solve"
3400   if (vector_case) {
3401     solution = solution.unsqueeze_(-1);
3402   }
3403 
3404   // _linalg_lstsq_helper_ performs calculations in-place and 'solution' must be a copy of other_2d
3405   solution.narrow(-2, 0, other_2d.size(-2)).copy_(other_2d);
3406 
3407   // if 'rank' is empty we might resize it
3408   auto input_batch_shape = IntArrayRef(input.sizes().cbegin(), input.sizes().cend() - 2);
3409   if (rank.numel() == 0 && driver != "gels") { // gels driver doesn't set 'rank'
3410     rank.resize_(input_batch_shape, MemoryFormat::Contiguous);
3411   }
3412 
3413   // if 'rank' is non-empty it must have the expected shape and be contiguous
3414   if (driver != "gels") {
3415     TORCH_INTERNAL_ASSERT(rank.sizes().equals(input_batch_shape));
3416     TORCH_INTERNAL_ASSERT(rank.is_contiguous());
3417   }
3418 
3419   // if 'singular_values' is empty we might resize it
3420   auto singular_values_shape = input_batch_shape.vec();
3421   singular_values_shape.push_back(std::min(m, n));
3422   if (singular_values.numel() == 0 && (driver == "gelsd" || driver == "gelss")) {
3423     singular_values.resize_(singular_values_shape, MemoryFormat::Contiguous);
3424   }
3425 
3426   // if 'singular_values' is non-empty it must have the expected shape and be contiguous
3427   if (driver == "gelsd" || driver == "gelss") {
3428     TORCH_INTERNAL_ASSERT(singular_values.sizes().equals(singular_values_shape));
3429     TORCH_INTERNAL_ASSERT(singular_values.is_contiguous());
3430   }
3431 
3432   // 'input' is modified in-place so we need a column-major copy
3433   auto input_working_copy = copyBatchedColumnMajor(input);
3434 
3435   // now the actual call that computes the result in-place (apply_lstsq)
3436   lstsq_stub(input.device().type(), input_working_copy, solution, rank, singular_values, infos, rcond, driver);
3437 
3438   // residuals are available only if m > n and a driver other than gelsy is used
3439   if (m > n && driver != "gelsy") {
3440     // if the driver is gelss or gelsd then the residuals are available only if rank == n
3441     bool compute_residuals = true;
3442     if (driver == "gelss" || driver == "gelsd") {
3443       if (input.dim() == 2) {
3444         compute_residuals = (rank.item().toInt() == n);
3445       } else {
3446         // it is not clear what to do if some matrices have rank < n in case of batched input
3447         // For now let's compute the residuals only if all matrices have rank equal to n
3448         // This behaviour may be changed in the future
3449         // See https://github.com/pytorch/pytorch/issues/56483
3450         compute_residuals = at::all(rank == n).item().toBool();
3451       }
3452     }
3453     if (compute_residuals) {
3454       // LAPACK stores the residual data for post-processing in rows n:m (that is, m - n rows)
3455       auto raw_residuals = solution.narrow(/*dim=*/-2, /*start=*/n, /*length*/m - n);
3456       if (raw_residuals.is_complex()) {
3457         raw_residuals.mul_(raw_residuals.conj());
3458         raw_residuals = at::real(raw_residuals);
3459       } else {
3460         raw_residuals.pow_(2);
3461       }
3462       at::sum_out(residuals, raw_residuals, /*dim=*/-2, /*keepdim=*/false, /*dtype*/real_dtype);
3463     }
3464   }
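  // For illustration: for a real, full-rank system with m = 4 and n = 2, LAPACK
  // writes the residual data into rows 2:4 of 'solution'; squaring that block and
  // summing it over dim=-2 above leaves 'residuals' holding the squared residual
  // norms |A x - b|^2, one per column of 'other'.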
3465   auto solution_view = solution.narrow(/*dim=*/-2, /*start=*/0, /*length*/n);
3466   // manually restride 'solution' so that it views only the first n rows
3467   solution.set_(solution.storage(), solution_view.storage_offset(), solution_view.sizes(), solution_view.strides());
3468   if (m == 0) {
3469     solution.zero_();
3470   }
3471 
3472   // for 1-dimensional 'other', we need to squeeze the solution after "apply_lstsq"
3473   if (vector_case) {
3474     solution.squeeze_(-1);
3475   }
3476 }
3477 
3478 static std::string get_default_lstsq_driver(std::optional<c10::string_view> driver, const Tensor& input) {
3479   // if `driver` is empty, we set driver_str to "gels" for CUDA tensors
3480   // and to "gelsy" otherwise.
3481   std::string driver_str;
3482   // check whether the user provided name is a valid driver name
3483   if (driver.has_value()) {
3484     driver_str = std::string(driver.value());
3485     // convert `driver_str` to lower case in place.
3486     std::transform(driver_str.begin(), driver_str.end(), driver_str.begin(),
3487       [](unsigned char c) { return std::tolower(c); });
3488     static std::unordered_set<c10::string_view> allowed_drivers = {
3489       "gels", "gelsy", "gelsd", "gelss"
3490     };
3491     if (input.device() == at::kCPU) {
3492       TORCH_CHECK(
3493         allowed_drivers.find(driver_str) != allowed_drivers.end(),
3494         "torch.linalg.lstsq: parameter `driver` should be one of "
3495         "(gels, gelsy, gelsd, gelss)"
3496       );
3497     } else { // else if (input.is_cuda())
3498       TORCH_CHECK(
3499         driver_str == "gels",
3500         "torch.linalg.lstsq: `driver` other than `gels` is not supported on CUDA"
3501       );
3502     }
3503   } else {
3504     // if the driver name is not provided, default to 'gelsy' on CPU
3505     // and to 'gels' on CUDA.
3506     driver_str = input.is_cuda() ? "gels" : "gelsy";
3507   }
3508   return driver_str;
3509 }
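// For illustration, the driver resolution above behaves as follows (sketch only;
// cpu_tensor / cuda_tensor stand for tensors on the respective devices):
//   get_default_lstsq_driver(std::nullopt, cpu_tensor)   -> "gelsy"
//   get_default_lstsq_driver(std::nullopt, cuda_tensor)  -> "gels"
//   get_default_lstsq_driver("GELSD", cpu_tensor)        -> "gelsd" (lower-cased)
//   get_default_lstsq_driver("gelsd", cuda_tensor)       -> error (only "gels" on CUDA)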
3510 
3511 std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> linalg_lstsq_out(
3512     const Tensor& input,
3513     const Tensor& other,
3514     std::optional<double> rcond,
3515     std::optional<c10::string_view> driver,
3516     Tensor& solution,
3517     Tensor& residuals,
3518     Tensor& rank,
3519     Tensor& singular_values) {
3520   TORCH_CHECK(input.dim() >= 2, "torch.linalg.lstsq: input must have at least 2 dimensions.");
3521   TORCH_CHECK(other.dim() >= 1, "torch.linalg.lstsq: other must have at least 1 dimension.");
3522   TORCH_CHECK(
3523       input.scalar_type() == other.scalar_type(),
3524       "torch.linalg.lstsq: Expected input and other to have the same dtype, but got input's dtype ",
3525       input.scalar_type(),
3526       " and other's dtype ",
3527       other.scalar_type());
3528 
3529   auto dim_diff = input.dim() - other.dim();
3530   TORCH_CHECK(
3531       0 <= dim_diff && dim_diff <= 1,
3532       "torch.linalg.lstsq: input.dim() must be greater or equal to other.dim() and (input.dim() - other.dim()) <= 1");
3533 
3534   // now check whether the provided output tensors can be used directly
3535 
3536   // Two types of 'other' tensors are supported:
3537   // - 1-dimensional (1D) tensor or batch of 1D tensors (vector case)
3538   // - 2-dimensional (2D) tensor or batch of 2D tensors (matrix case)
3539   // original torch.lstsq supported only the matrix case, while NumPy works for both cases
3540   // for the batched input we need to be able to distinguish them
3541   // auto expected_batched_rhs_shape = IntArrayRef(input.sizes().data(), input.dim() - 1); // input.shape[:-1]
3542   // bool vector_case = other.dim() == 1 || (input.dim() - 1 == other.dim() && other.sizes().equals(expected_batched_rhs_shape));
3543 
3544   bool vector_case = linalg_solve_is_vector_rhs(input, other);
3545   Tensor other_2d = vector_case ? other.unsqueeze(-1) : other;
3546   TORCH_CHECK(
3547       input.size(-2) == other_2d.size(-2),
3548       vector_case ? "torch.linalg.lstsq: input.size(-2) should match other.size(-1)"
3549                : "torch.linalg.lstsq: input.size(-2) should match other.size(-2)");
3550 
3551   checkSameDevice("torch.linalg.lstsq", other, input, "other");
3552   checkSameDevice("torch.linalg.lstsq", solution, input, "solution");
3553   checkSameDevice("torch.linalg.lstsq", residuals, input, "residuals");
3554   checkSameDevice("torch.linalg.lstsq", rank, input, "rank");
3555   checkSameDevice("torch.linalg.lstsq", singular_values, input, "singular_values");
3556 
3557   // 'solution' is expected to have same dtype as input
3558   checkLinalgCompatibleDtype("torch.linalg.lstsq", solution, input, "solution");
3559 
3560   // 'residuals' is expected to have real float dtype
3561   ScalarType real_dtype = c10::toRealValueType(input.scalar_type());
3562   checkLinalgCompatibleDtype("torch.linalg.lstsq", residuals.scalar_type(), real_dtype, "residuals");
3563 
3564   // 'rank' is expected to have integer dtype
3565   // actual LAPACK calls use int32_t type for rank, but we promote it to int64_t
3566   // to be consistent with torch.linalg.matrix_rank output dtype
3567   ScalarType rank_expected_type = ScalarType::Long;
3568   checkLinalgCompatibleDtype("torch.linalg.lstsq", rank.scalar_type(), rank_expected_type, "rank");
3569 
3570   // 'singular_values' is expected to have real float dtype
3571   checkLinalgCompatibleDtype("torch.linalg.lstsq", singular_values.scalar_type(), real_dtype, "singular_values");
3572 
3573   std::string driver_name = get_default_lstsq_driver(driver, input);
3574 
3575   // set default rcond value
3576   double rcond_value = rcond.has_value()
3577     ? rcond.value()
3578     : _get_epsilon(c10::toRealValueType(input.scalar_type())) * static_cast<double>(std::max<int64_t>(input.size(-2), input.size(-1)));
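  // For illustration: for a float32 input of shape (100, 20) the default is
  // rcond = eps(float32) * 100 ~= 1.19e-7 * 100 ~= 1.19e-5, i.e. the usual
  // eps * max(m, n) convention (the same default NumPy's lstsq uses for rcond=None).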
3579 
3580   auto infos = at::zeros({std::max<int64_t>(1, batchCount(input))}, input.options().dtype(kInt));
3581 
3582   // provided output tensor can be used directly if:
3583   // 1. the shape matches the expected shape
3584   // 2. the dtype matches the expected dtype
3585   // 3. the tensor is contiguous
3586 
3587   // Checks for the 'solution' tensor
3588   std::vector<int64_t> expected_solution_shape = broadcast_batch_size(input, other_2d, input.dim() - 2);
3589   // the actual shape of the solution returned is (*, n,) or (*, n, nrhs)
3590   // but LAPACK requires extra rows, so the expected shape is (*, max(m, n),) or (*, max(m, n), nrhs)
3591   expected_solution_shape.push_back(std::max(input.size(-1), input.size(-2)));
3592   if (!vector_case && other.dim() > 2) {
3593     expected_solution_shape.push_back(other.size(-1));
3594   }
3595 
3596   bool solution_equal_expected_shape = solution.sizes().equals(expected_solution_shape);
3597   bool solution_input_same_type = (solution.scalar_type() == input.scalar_type());
3598 
3599   bool is_solution_batched_column_major = false;
3600   if (vector_case) {
3601     is_solution_batched_column_major = solution.is_contiguous();
3602   } else if (!vector_case && solution.dim() >= 2) {
3603     is_solution_batched_column_major = solution.mT().is_contiguous();
3604   }
3605 
3606   // 'residuals' is not checked here because at::sum_out(residuals, ...) does that
3607 
3608   auto input_batch_shape = IntArrayRef(input.sizes().cbegin(), input.sizes().cend() - 2);
3609 
3610   // Checks for the 'rank' tensor
3611   // rank is a scalar value for each matrix in the batch so
3612   // rank's expected shape is equal to input.shape[0:input.ndim-2]
3613   bool rank_equal_expected_shape = true;
3614   bool rank_equal_expected_type = true;
3615   bool rank_is_contiguous = true;
3616   if (driver_name != "gels") { // gels driver doesn't set 'rank'
3617     rank_equal_expected_shape = rank.sizes().equals(input_batch_shape);
3618     rank_equal_expected_type = (rank.scalar_type() == at::kLong);
3619     rank_is_contiguous = rank.is_contiguous();
3620   }
3621 
3622   // Checks for the 'singular_values' tensor
3623   // singular values are computed only with "gelsd" and "gelss" drivers currently
3624   bool singular_values_equal_expected_shape = true;
3625   bool singular_values_equal_expected_type = true;
3626   bool singular_values_is_contiguous = true;
3627   if (driver_name == "gelsd" || driver_name == "gelss") {
3628     auto singular_values_shape = input_batch_shape.vec();
3629     singular_values_shape.push_back(std::min(input.size(-1), input.size(-2)));
3630     singular_values_equal_expected_shape = singular_values.sizes().equals(singular_values_shape);
3631     singular_values_equal_expected_type = (singular_values.scalar_type() == real_dtype);
3632     singular_values_is_contiguous = singular_values.is_contiguous();
3633   }
3634 
3635   // if solution is not empty and not in batched column major format
3636   bool copy_needed = (solution.numel() != 0 && !is_solution_batched_column_major);
3637   copy_needed |= !solution_input_same_type;  // or solution does not have the same dtype as input
3638   copy_needed |= (solution.numel() != 0 && !solution_equal_expected_shape); // or solution does not have the expected shape
3639 
3640   copy_needed |= !rank_equal_expected_type;
3641   copy_needed |= (rank.numel() != 0 && !rank_equal_expected_shape);
3642   copy_needed |= (rank.numel() != 0 && !rank_is_contiguous);
3643 
3644   copy_needed |= !singular_values_equal_expected_type;
3645   copy_needed |= (singular_values.numel() != 0 && !singular_values_equal_expected_shape);
3646   copy_needed |= (singular_values.numel() != 0 && !singular_values_is_contiguous);
3647 
3648   if (copy_needed) { // we have to allocate temporary tensors
3649     Tensor solution_tmp = at::empty({0}, input.options());
3650     Tensor residuals_tmp = at::empty({0}, input.options().dtype(real_dtype));
3651     Tensor rank_tmp = at::empty({0}, input.options().dtype(at::kLong));
3652     Tensor singular_values_tmp = at::empty({0}, input.options().dtype(real_dtype));
3653 
3654     linalg_lstsq_out_info(solution_tmp, residuals_tmp, rank_tmp, singular_values_tmp, infos, input, other, rcond_value, driver_name);
3655 
3656     at::native::resize_output(solution, solution_tmp.sizes());
3657     solution.copy_(solution_tmp);
3658 
3659     at::native::resize_output(residuals, residuals_tmp.sizes());
3660     residuals.copy_(residuals_tmp);
3661 
3662     at::native::resize_output(rank, rank_tmp.sizes());
3663     rank.copy_(rank_tmp);
3664 
3665     at::native::resize_output(singular_values, singular_values_tmp.sizes());
3666     singular_values.copy_(singular_values_tmp);
3667   } else {
3668     // else use the provided output storage directly
3669     linalg_lstsq_out_info(solution, residuals, rank, singular_values, infos, input, other, rcond_value, driver_name);
3670   }
3671 
3672   at::_linalg_check_errors(infos, "torch.linalg.lstsq", infos.numel() <= 1);
3673   return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>(solution, residuals, rank, singular_values);
3674 }
3675 
3676 std::tuple<Tensor, Tensor, Tensor, Tensor> linalg_lstsq(
3677     const Tensor& input, const Tensor& other,
3678     std::optional<double> rcond,
3679     std::optional<c10::string_view> driver) {
3680   Tensor solution = at::empty({0}, input.options());
3681   Tensor residuals = at::empty({0}, input.options().dtype(toRealValueType(input.scalar_type())));
3682   Tensor rank = at::empty({0}, input.options().dtype(at::kLong));
3683   Tensor singular_values = at::empty({0}, input.options().dtype(toRealValueType(input.scalar_type())));
3684   std::tie(solution, residuals, rank, singular_values) =
3685       at::linalg_lstsq_outf(input, other, rcond, driver, solution, residuals, rank, singular_values);
3686   return std::make_tuple(std::move(solution), std::move(residuals), std::move(rank), std::move(singular_values));
3687 }
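// A minimal usage sketch of the functional variant above (illustrative only; the
// tensors and the driver choice are placeholders):
//   auto A = at::randn({5, 3}, at::kDouble);   // m = 5 > n = 3
//   auto b = at::randn({5}, at::kDouble);      // 1D right-hand side (vector case)
//   auto [x, res, rank, sv] =
//       at::linalg_lstsq(A, b, /*rcond=*/std::nullopt, /*driver=*/"gelsd");
//   // x has shape {3}; with "gelsd", rank and sv are populated and, for a
//   // full-rank A, res holds the squared residual norm |A x - b|^2.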
3688 
3689 DEFINE_DISPATCH(ldl_factor_stub);
3690 
3691 TORCH_IMPL_FUNC(linalg_ldl_factor_ex_out)
3692 (const Tensor& self,
3693  bool hermitian,
3694  bool check_errors,
3695  const Tensor& LD,
3696  const Tensor& pivots,
3697  const Tensor& info) {
3698   // LAPACK workspace query segfaults if the input has 0 in batch dimensions.
3699   if (self.numel() == 0) {
3700     info.zero_();
3701     return;
3702   }
3703 
3704   // We decided not to include the upper flag in the API.
3705   // https://github.com/pytorch/pytorch/pull/69828#issuecomment-1015143819
3706   // We can revisit this decision later and remove upper completely
3707   // also from low level functions or add it to the public API.
3708   constexpr bool upper = false;
3709   if constexpr (upper) {
3710     at::triu_out(const_cast<Tensor&>(LD), self);
3711   } else {
3712     at::tril_out(const_cast<Tensor&>(LD), self);
3713   }
3714 
3715   // call ldl_factor_stub that fills the result tensors
3716   ldl_factor_stub(
3717       self.device().type(), LD, pivots, info, upper, hermitian);
3718 
3719   if (check_errors) {
3720     at::_linalg_check_errors(
3721         info, "torch.linalg.ldl_factor_ex", self.dim() == 2);
3722   }
3723 }
3724 
3725 std::tuple<Tensor&, Tensor&> linalg_ldl_factor_out(
3726     const Tensor& self,
3727     bool hermitian,
3728     Tensor& LD,
3729     Tensor& pivots) {
3730   auto info = at::empty({0}, self.options().dtype(kInt));
3731   // We pass check_errors=false and check the errors ourselves so that the error
3732   // messages refer to ldl_factor rather than ldl_factor_ex
3733   at::linalg_ldl_factor_ex_outf(
3734       self, hermitian, /*check_errors=*/false, LD, pivots, info);
3735   at::_linalg_check_errors(info, "torch.linalg.ldl_factor", self.dim() == 2);
3736   return std::tie(LD, pivots);
3737 }
3738 
3739 std::tuple<Tensor, Tensor> linalg_ldl_factor(
3740     const Tensor& self,
3741     bool hermitian) {
3742   auto [LD, pivots, info] =
3743       at::linalg_ldl_factor_ex(self, hermitian, /*check_errors=*/false);
3744   at::_linalg_check_errors(info, "torch.linalg.ldl_factor", self.dim() == 2);
3745   return std::make_tuple(std::move(LD), std::move(pivots));
3746 }
3747 
3748 DEFINE_DISPATCH(ldl_solve_stub);
3749 
3750 TORCH_IMPL_FUNC(linalg_ldl_solve_out)
3751 (const Tensor& LD,
3752  const Tensor& pivots,
3753  const Tensor& B,
3754  bool hermitian,
3755  const Tensor& result) {
3756   if (LD.numel() == 0 || pivots.numel() == 0) {
3757     return;
3758   }
3759 
3760   auto pivots_ = pivots.expect_contiguous();
3761 
3762   auto LD_ = at::native::borrow_else_clone(
3763       LD.mT().is_contiguous(), LD, LD, /*contig=*/false);
3764   result.copy_(B);
3765   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(batchCount(result) == batchCount(result));
3766 
3767   ldl_solve_stub(
3768       B.device().type(), *LD_, *pivots_, result, false, hermitian);
3769 }
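// A factor-then-solve round trip through the ops defined above, as a sketch
// (illustrative only; A, B and X are placeholders):
//   auto A = at::randn({4, 4}, at::kDouble);
//   A = A + A.mT();                            // make A symmetric
//   auto B = at::randn({4, 2}, at::kDouble);
//   auto [LD, pivots, info] =
//       at::linalg_ldl_factor_ex(A, /*hermitian=*/false, /*check_errors=*/true);
//   auto X = at::linalg_ldl_solve(LD, pivots, B, /*hermitian=*/false);
//   // X solves A X = B up to rounding error.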
3770 
3771 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ solve_triangular ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
3772 
3773 Tensor& linalg_vecdot_out(const Tensor& x, const Tensor& y, int64_t dim, Tensor& out) {
3774   checkFloatingOrComplex(x, "linalg.vecdot");
3775   TORCH_CHECK(x.scalar_type() == y.scalar_type(),
3776               "linalg.vecdot: Expected x and y to have the same dtype, but found x of type ",
3777               x.scalar_type(), " and y of type ", y.scalar_type(), " instead");
3778   // out checks
3779   TORCH_CHECK(out.scalar_type() == x.scalar_type(),
3780               "linalg.vecdot: Expected out of dtype ", x.scalar_type(),
3781               " but found ", out.scalar_type());
3782   checkSameDevice("linalg.vecdot", x, out);
3783 
3784   // Computes x^H y
3785   if (x.dim() == 1 && y.dim() == 1) {
3786     at::native::resize_output(out, {});
3787     return at::vdot_out(out, x, y);
3788   } else {
3789     return at::sum_out(out, x.conj() * y, /*dim=*/dim);
3790   }
3791 }
3792 
3793 Tensor linalg_vecdot(const Tensor& x, const Tensor& y, int64_t dim) {
3794   checkFloatingOrComplex(x, "linalg.vecdot");
3795   TORCH_CHECK(x.scalar_type() == y.scalar_type(),
3796               "linalg.vecdot: Expected x and y to have the same dtype, but found x of type ",
3797               x.scalar_type(), " and y of type ", y.scalar_type(), " instead");
3798   // Computes x^H y
3799   if (x.dim() == 1 && y.dim() == 1) {
3800     return at::vdot(x, y);
3801   } else {
3802     return x.conj().mul(y).sum(/*dim=*/dim);
3803   }
3804 }
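// For illustration, vecdot conjugates its first argument (sketch only):
//   auto x = at::full({3}, c10::complex<double>(0.0, 1.0), at::kComplexDouble);
//   auto d = at::linalg_vecdot(x, x);   // sum_k conj(i) * i = 3 (+0i), not -3
// For batched inputs the reduction runs over `dim`.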
3805 
3806 /*
3807 Solves the matrix equation AX = B for A triangular.
3808 'left' If true solves AX = B, if false solves XA = B
3809 'upper' controls which triangle of the input matrix is used in the computation,
3810 'unitriangular' if true, then diag(A) is assumed to consist of ones
3811 'out' The tensor with the result. If A == out, A will be modified in place
3812 */
3813 Tensor& linalg_solve_triangular_out(
3814     const Tensor& A,
3815     const Tensor& B,
3816     bool upper,
3817     bool left,
3818     bool unitriangular,
3819     Tensor& out) {
3820   checkInputsSolver(A, B, left, "linalg.solve_triangular");
3821   auto [B_, A_] = _linalg_broadcast_batch_dims(B, A, /*don't check errors*/nullptr);
3822 
3823   // We'll write F-contig / F-transpose for FORTRAN contiguous / FORTRAN transpose etc
3824   // We say that a matrix is F-ready if it's F-contig OR F-transpose
3825   // At this point, A, B have been broadcasted but may or may not be F-ready
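  // For illustration: a 3x5 matrix with strides (1, 3) is F-contig, a C-contiguous
  // 3x5 matrix (strides (5, 1)) is F-transpose since its .mT() is F-contig, while a
  // strided slice such as A.slice(-1, 0, 4, 2) is neither, i.e. not F-ready.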
3826 
3827   // The following algorithm minimises copies and allocations. In pseudocode:
3828   // if out is wrong size:
3829   //   resize_output(out)
3830   // # Invariant: out is the right size
3831   // Tensor out_f; # Tensor that we will pass to FORTRAN
3832   // if out is F-ready:
3833   //   out_f = out;
3834   // else:
3835   //   Allocate out_f F-ready
3836   // if B != out_f:
3837   //   copy B into out_f
3838   // # Invariant: out_f F-ready and has B copied into it
3839   // if out_f is F-transposed:
3840   //   transpose equation
3841   // if out_f is conj:
3842   //   conjugate equation
3843   // # Invariant: out_f is not conjugated and F-contig
3844   // Tensor A_f; # Tensor that will be sent to FORTRAN
3845   // if A is F-ready:
3846   //   if A is conj and A is not transposed:
3847   //     # We need to clone A in this case. See [Cloning A]
3848   //     clone A F-contig into A_f
3849   //   else:
3850   //     A_f = A;
3851   // else:
3852   //   clone A F-contig into A_f
3853   // # Invariant: out_f is F-contig and A_f is F-ready
3854   // # We pass FORTRAN the flags indicating if A_f is transposed and or conjugated
3855   //
3856   // # Here we undo the conjugations / transposes on out_f if needed
3857   //
3858   // if out_f not same out:
3859   //   copy out_f into out
3860   // return out
3861   //
3862   // Note: The logic for the negative bit is the same as that for the conjugate bit
3863   //
3864   // Note: [Cloning A] If we are careful when allocating B when it needs to be allocated at the
3865   // beginning of the algorithm, it is possible to always elide the copy of A here.
3866   // Via this trick, the algorithm will copy at most one of A or B (never both) whenever A
3867   // and B are F-ready and not A.is_neg() (which happens almost always in practice).
3868   // When called as f(A, B, out=B) in most practical cases it'll perform no copies.
3869 
3870   const bool avoid_copy_A = A_.transpose(-2, -1).is_contiguous() && A_.is_conj();
3871   if (avoid_copy_A) {
3872     // See Note: [Cloning A]
3873     at::native::resize_output(out, B_.sizes());
3874   }
3875   else {
3876     // poor man's reimplementation of resize_output with the result F-contig
3877     if (resize_output_check(out, B_.sizes())) {
3878       out.resize_(B_.transpose(-2, -1).sizes(), MemoryFormat::Contiguous);
3879       out.transpose_(-2, -1);  // make 'out' have Fortran contiguous memory layout
3880     }
3881   }
3882   // Invariant: out has the right size, so we'll be able to copy into it later on
3883 
3884   Tensor out_f; // the out that will go into fortran
3885   // We use C10_LIKELY mostly for documentation, as it helps to follow the most likely path
3886   if C10_LIKELY (is_row_or_column_contiguous(out)) {
3887     out_f = out;
3888     if C10_LIKELY (!out.is_same(B_)) {
3889       out_f.copy_(B_);
3890     }
3891   } else {
3892     if (avoid_copy_A) {
3893       // See Note: [Cloning A]
3894       out_f = B_.clone(at::MemoryFormat::Contiguous);
3895     }
3896     else {
3897       out_f = cloneBatchedColumnMajor(B_);
3898     }
3899   }
3900   // Invariant: out_f F-ready and has B copied into it
3901 
3902   // If out_f is F-transposed, we transpose the equation and undo it at the end
3903   bool transpose_A = false;
3904   bool transpose_out_f = false;
3905   if (out_f.stride(-1) == 1) {
3906     left = !left;
3907     transpose_A = true;
3908     transpose_out_f = true;
3909     out_f.transpose_(-2, -1);
3910   }
3911 
3912   // No need to conjugate anything if out_f is conj as AX = conj(B) <=> conj(A)conj(X) = B
3913   // and X = B after the algorithm. We just annotate that A is conjugated later on
3914   // The solution will be written into out_f, so it'll be conjugated already
3915 
3916   Tensor A_f = std::move(A_);  // The A that will go into fortran
3917 
3918   bool A_is_conj = A_f.is_conj() != out_f.is_conj();
3919   bool A_is_neg = A_f.is_neg() != out_f.is_neg();
3920   bool A_is_f_contig = (A_f.stride(-1) == 1) == transpose_A;
3921   if C10_UNLIKELY (!is_row_or_column_contiguous(A_f)) {
3922     // We first annotate with flags on A_f all the conj / transpose / neg coming from out
3923     // and then we clone the resulting tensor to resolve all of them in memory
3924     if (out_f.is_conj()) {
3925       A_f = A_f.conj();
3926     }
3927     A_is_conj = false;
3928 
3929     if (out_f.is_neg()) {
3930       A_f = A_f._neg_view();
3931     }
3932     A_is_neg = false;
3933 
3934     // This choice is to be consistent with how we flip `upper` later on
3935     // Note that this is the same reasoning we apply for neg and conj below
3936     // If B is negated, conjugated or transposed, then we need to resolve it in memory
3937     A_f = transpose_A ? A_f.clone(at::MemoryFormat::Contiguous)
3938                       : cloneBatchedColumnMajor(A_f);
3939     A_is_f_contig = true;
3940   } else if C10_UNLIKELY (A_is_f_contig && A_is_conj) {
3941     if C10_UNLIKELY (A_f.is_neg() || out_f.is_neg()) {
3942       // Cases A_is_neg (remember that B.is_neg() iff out_f.is_same(B))
3943       // -AX = -B => A(-X) = B. Swap neg of A_f. Nothing to do on X as X.is_same(B).
3944       // -AX = B. We resolve the neg in memory
3945       // AX = -B => -A -X = B. We resolve the neg in memory for A,
3946       //                       Since X.is_same(B), we already have that X.is_neg() == true
3947 
3948       // We do the neg with a view, as this will be resolved in the clone below
3949       if (out_f.is_neg()) {
3950         A_f = A_f._neg_view();
3951       }
3952       A_is_neg = false;
3953     }
3954     // We resolve the transpose if necessary and then leave A_f F-transposed,
3955     // as BLAS can handle the case F-transposed and conjugated
3956     A_f = at::clone(transpose_A ? A_f.mT() : A_f, at::MemoryFormat::Contiguous);
3957     A_is_f_contig = false;
3958     if (transpose_A) {
3959       upper = !upper;
3960     }
3961     // As we've already resolved the conj of A in the clone
3962     A_is_conj = out_f.is_conj();
3963   } else if C10_UNLIKELY (A_is_neg) {
3964     // We follow the same logic as above, only that in this case we need to perform the
3965     // negation in memory
3966     if (out_f.is_neg()) {
3967       A_f = -A_f;
3968     } else {
3969       A_f = A_f.resolve_neg();
3970     }
3971     A_is_neg = false;
3972     // As we've already resolved the conj of A in the negation above
3973     A_is_conj = out_f.is_conj();
3974   }
3975   // Invariant: out_f is F-contig and A_f is F-ready
3976   // neg has been resolved
3977 
3978   // If we pass the matrix physically F-transposed, we need to change the parity of upper
3979   if (A_f.stride(-1) == 1) {
3980     upper = !upper;
3981   }
3982 
3983   triangular_solve_stub(
3984     A_f.device().type(), A_f, out_f,
3985     /*left=*/left,
3986     /*upper=*/upper,
3987     /*transpose*/to_transpose_type(A_is_f_contig, A_is_conj),
3988     /*unitriangular=*/unitriangular);
3989 
3990   if (transpose_out_f) {
3991     out_f.transpose_(-2, -1);
3992   }
3993 
3994   if (!out_f.is_same(out)) {
3995     out.copy_(out_f);
3996   }
3997   return out;
3998 }
3999 
4000 Tensor linalg_solve_triangular(
4001     const Tensor& A,
4002     const Tensor& B,
4003     bool upper,
4004     bool left,
4005     bool unitriangular) {
4006   Tensor out = at::empty({0}, A.options());
4007   linalg_solve_triangular_out(A, B, upper, left, unitriangular, out);
4008   return out;
4009 }
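// A minimal usage sketch of the op above (illustrative only; A, B and X are
// placeholders):
//   auto A = at::randn({3, 3}, at::kDouble).triu();   // upper-triangular A
//   auto B = at::randn({3, 2}, at::kDouble);
//   auto X = at::linalg_solve_triangular(A, B, /*upper=*/true,
//                                        /*left=*/true, /*unitriangular=*/false);
//   // X satisfies A X = B; calling the _out variant with out=B typically
//   // performs the solve without extra copies, see Note [Cloning A] above.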
4010 
4011 Tensor linalg_vander_symint(
4012     const Tensor& x,
4013     std::optional<c10::SymInt> N) {
4014   auto t = x.scalar_type();
4015   TORCH_CHECK(t == ScalarType::Float ||
4016               t == ScalarType::Double ||
4017               t == ScalarType::ComplexFloat ||
4018               t == ScalarType::ComplexDouble ||
4019               c10::isIntegralType(t, false),
4020               "linalg.vander supports floating point, complex, and integer tensors, but got ", t);
4021   const auto x_ = x.dim() == 0 ? x.unsqueeze(-1) : x;
4022 
4023   auto shape = x_.sym_sizes().vec();
4024   const auto n = N.value_or(shape.back());
4025   TORCH_CHECK(n > 1, "N must be greater than 1.");
4026 
4027   // Append the cumprod of x giving the remaining powers 1...n-1
4028   shape.push_back(n - 1);
4029   auto result = at::cumprod(x_.unsqueeze(-1).expand_symint(shape), -1);
4030   // The row of ones
4031   shape.back() = 1LL;
4032   auto ones =  result.new_ones_symint(shape);
4033   return at::cat({std::move(ones), std::move(result)}, /*dim=*/ -1);
4034 }
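// For illustration: with x = [1, 2, 3] and N = 3 the cumprod above yields the
// powers x^1 and x^2, and prepending the column of ones gives
//   at::linalg_vander(at::tensor({1, 2, 3}), 3) == [[1, 1, 1],
//                                                   [1, 2, 4],
//                                                   [1, 3, 9]]
// i.e. row k is [1, x_k, x_k^2], with powers in increasing order.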
4035 // NOLINTEND(cppcoreguidelines-pro-type-const-cast)
4036 }  // namespace at::native
4037