xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/BatchLinearAlgebra.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <optional>
4 #include <c10/util/string_view.h>
5 #include <ATen/Config.h>
6 #include <ATen/native/DispatchStub.h>
7 
// Forward declarations only: this header mentions Tensor, TensorIterator and
// TransposeType in function-pointer signatures and does not need their full
// definitions.
namespace at {

class Tensor;
struct TensorIterator;

} // namespace at

namespace at::native {

enum class TransposeType;

} // namespace at::native
18 
19 namespace at::native {
20 
// Compile-time tag selecting the least-squares driver used by lapackLstsq.
// Each enumerator corresponds to the per-batch wrapper of the same name
// (lapackGels, lapackGelsd, lapackGelsy, lapackGelss).
enum class LapackLstsqDriverType : int64_t {
  Gels,
  Gelsd,
  Gelsy,
  Gelss,
};
22 
23 #if AT_BUILD_WITH_LAPACK()
// Declare per-batch functions to be used in the implementation of batched
// linear algebra operations
26 
// Per-batch Cholesky routine, templated on the element type.
// Declaration only; no definition is visible in this header.
// NOTE(review): the argument names look like the LAPACK ?potrf convention
// (uplo = triangle selector, a/lda = n-by-n matrix and its leading dimension,
// info = status output) — confirm against the definition.
template <typename scalar_t>
void lapackCholesky(
    char uplo,
    int n,
    scalar_t* a,
    int lda,
    int* info);
29 
// Per-batch inverse computed from a Cholesky factor, templated on the element
// type. Declaration only; no definition is visible in this header.
// NOTE(review): presumably a thin wrapper over LAPACK ?potri, with `a`
// overwritten in place and `info` receiving the status — confirm.
template <typename scalar_t>
void lapackCholeskyInverse(
    char uplo,
    int n,
    scalar_t* a,
    int lda,
    int* info);
32 
// Per-batch general eigenvalue routine.
// `scalar_t` is the element type of the matrix, eigenvalue and eigenvector
// buffers; `value_t` (defaults to scalar_t) is the type of `rwork` only.
// Declaration only; no definition is visible in this header.
// NOTE(review): argument names match LAPACK ?geev (jobvl/jobvr select which
// eigenvectors to compute, work/lwork is the workspace) — confirm.
template <typename scalar_t, typename value_t = scalar_t>
void lapackEig(
    char jobvl,
    char jobvr,
    int n,
    scalar_t* a,
    int lda,
    scalar_t* w,
    scalar_t* vl,
    int ldvl,
    scalar_t* vr,
    int ldvr,
    scalar_t* work,
    int lwork,
    value_t* rwork,
    int* info);
35 
// Per-batch QR factorization routine, templated on the element type.
// Declaration only; no definition is visible in this header.
// NOTE(review): argument names match LAPACK ?geqrf (m-by-n matrix `a`,
// reflector scalars `tau`, workspace work/lwork) — confirm.
template <typename scalar_t>
void lapackGeqrf(
    int m,
    int n,
    scalar_t* a,
    int lda,
    scalar_t* tau,
    scalar_t* work,
    int lwork,
    int* info);
38 
// Per-batch routine taking a matrix `a` and reflector scalars `tau`,
// templated on the element type.
// Declaration only; no definition is visible in this header.
// NOTE(review): presumably wraps LAPACK ?orgqr / ?ungqr, forming Q from `k`
// elementary reflectors — confirm against the definition.
template <typename scalar_t>
void lapackOrgqr(
    int m,
    int n,
    int k,
    scalar_t* a,
    int lda,
    scalar_t* tau,
    scalar_t* work,
    int lwork,
    int* info);
41 
// Per-batch routine operating on a second matrix `c` using the matrix `a` and
// reflector scalars `tau`, templated on the element type.
// Declaration only; no definition is visible in this header.
// NOTE(review): presumably wraps LAPACK ?ormqr / ?unmqr, where side/trans
// choose how Q is applied to `c` — confirm against the definition.
template <typename scalar_t>
void lapackOrmqr(
    char side,
    char trans,
    int m,
    int n,
    int k,
    scalar_t* a,
    int lda,
    scalar_t* tau,
    scalar_t* c,
    int ldc,
    scalar_t* work,
    int lwork,
    int* info);
44 
// Per-batch symmetric/Hermitian eigenvalue routine.
// `scalar_t` types the matrix and `work`; `value_t` (defaults to scalar_t)
// types the eigenvalue buffer `w` and `rwork`.
// Declaration only; no definition is visible in this header.
// NOTE(review): argument names match LAPACK ?syevd / ?heevd, with three
// workspaces (work, rwork, iwork) and their lengths — confirm.
template <typename scalar_t, typename value_t = scalar_t>
void lapackSyevd(
    char jobz,
    char uplo,
    int n,
    scalar_t* a,
    int lda,
    value_t* w,
    scalar_t* work,
    int lwork,
    value_t* rwork,
    int lrwork,
    int* iwork,
    int liwork,
    int* info);
47 
// Per-batch least-squares routine of the "Gels" flavor, templated on the
// element type. It is the only least-squares wrapper here that takes `trans`.
// Declaration only; no definition is visible in this header.
// NOTE(review): argument names match LAPACK ?gels (a is m-by-n, b holds nrhs
// right-hand sides) — confirm against the definition.
template <typename scalar_t>
void lapackGels(
    char trans,
    int m,
    int n,
    int nrhs,
    scalar_t* a,
    int lda,
    scalar_t* b,
    int ldb,
    scalar_t* work,
    int lwork,
    int* info);
52 
// Per-batch least-squares routine of the "Gelsd" flavor.
// `scalar_t` types the matrices and `work`; `value_t` (defaults to scalar_t)
// types `s`, `rcond` and `rwork`. Unlike lapackGelss it also takes `iwork`.
// Declaration only; no definition is visible in this header.
// NOTE(review): argument names match LAPACK ?gelsd (s = singular values,
// rcond = rank cutoff, rank = output rank) — confirm.
template <typename scalar_t, typename value_t = scalar_t>
void lapackGelsd(
    int m,
    int n,
    int nrhs,
    scalar_t* a,
    int lda,
    scalar_t* b,
    int ldb,
    value_t* s,
    value_t rcond,
    int* rank,
    scalar_t* work,
    int lwork,
    value_t* rwork,
    int* iwork,
    int* info);
59 
// Per-batch least-squares routine of the "Gelsy" flavor.
// `scalar_t` types the matrices and `work`; `value_t` (defaults to scalar_t)
// types `rcond` and `rwork`. This is the only flavor that takes `jpvt`.
// Declaration only; no definition is visible in this header.
// NOTE(review): argument names match LAPACK ?gelsy (jpvt = column pivots,
// rcond = rank cutoff, rank = output rank) — confirm.
template <typename scalar_t, typename value_t = scalar_t>
void lapackGelsy(
    int m,
    int n,
    int nrhs,
    scalar_t* a,
    int lda,
    scalar_t* b,
    int ldb,
    int* jpvt,
    value_t rcond,
    int* rank,
    scalar_t* work,
    int lwork,
    value_t* rwork,
    int* info);
65 
// Per-batch least-squares routine of the "Gelss" flavor.
// `scalar_t` types the matrices and `work`; `value_t` (defaults to scalar_t)
// types `s`, `rcond` and `rwork`.
// Declaration only; no definition is visible in this header.
// NOTE(review): argument names match LAPACK ?gelss (s = singular values,
// rcond = rank cutoff, rank = output rank) — confirm.
template <typename scalar_t, typename value_t = scalar_t>
void lapackGelss(
    int m,
    int n,
    int nrhs,
    scalar_t* a,
    int lda,
    scalar_t* b,
    int ldb,
    value_t* s,
    value_t rcond,
    int* rank,
    scalar_t* work,
    int lwork,
    value_t* rwork,
    int* info);
72 
73 template <LapackLstsqDriverType, class scalar_t, class value_t = scalar_t>
74 struct lapackLstsq_impl;
75 
76 template <class scalar_t, class value_t>
77 struct lapackLstsq_impl<LapackLstsqDriverType::Gels, scalar_t, value_t> {
78   static void call(
79       char trans, int m, int n, int nrhs,
80       scalar_t *a, int lda, scalar_t *b, int ldb,
81       scalar_t *work, int lwork, int *info, // Gels flavor
82       int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
83       value_t *s, // Gelss flavor
84       int *iwork // Gelsd flavor
85       ) {
86     lapackGels<scalar_t>(
87         trans, m, n, nrhs,
88         a, lda, b, ldb,
89         work, lwork, info);
90   }
91 };
92 
93 template <class scalar_t, class value_t>
94 struct lapackLstsq_impl<LapackLstsqDriverType::Gelsy, scalar_t, value_t> {
95   static void call(
96       char trans, int m, int n, int nrhs,
97       scalar_t *a, int lda, scalar_t *b, int ldb,
98       scalar_t *work, int lwork, int *info, // Gels flavor
99       int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
100       value_t *s, // Gelss flavor
101       int *iwork // Gelsd flavor
102       ) {
103     lapackGelsy<scalar_t, value_t>(
104         m, n, nrhs,
105         a, lda, b, ldb,
106         jpvt, rcond, rank,
107         work, lwork, rwork, info);
108   }
109 };
110 
111 template <class scalar_t, class value_t>
112 struct lapackLstsq_impl<LapackLstsqDriverType::Gelsd, scalar_t, value_t> {
113   static void call(
114       char trans, int m, int n, int nrhs,
115       scalar_t *a, int lda, scalar_t *b, int ldb,
116       scalar_t *work, int lwork, int *info, // Gels flavor
117       int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
118       value_t *s, // Gelss flavor
119       int *iwork // Gelsd flavor
120       ) {
121     lapackGelsd<scalar_t, value_t>(
122         m, n, nrhs,
123         a, lda, b, ldb,
124         s, rcond, rank,
125         work, lwork,
126         rwork, iwork, info);
127   }
128 };
129 
130 template <class scalar_t, class value_t>
131 struct lapackLstsq_impl<LapackLstsqDriverType::Gelss, scalar_t, value_t> {
132   static void call(
133       char trans, int m, int n, int nrhs,
134       scalar_t *a, int lda, scalar_t *b, int ldb,
135       scalar_t *work, int lwork, int *info, // Gels flavor
136       int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
137       value_t *s, // Gelss flavor
138       int *iwork // Gelsd flavor
139       ) {
140     lapackGelss<scalar_t, value_t>(
141         m, n, nrhs,
142         a, lda, b, ldb,
143         s, rcond, rank,
144         work, lwork,
145         rwork, info);
146   }
147 };
148 
149 template <LapackLstsqDriverType driver_type, class scalar_t, class value_t = scalar_t>
150 void lapackLstsq(
151     char trans, int m, int n, int nrhs,
152     scalar_t *a, int lda, scalar_t *b, int ldb,
153     scalar_t *work, int lwork, int *info, // Gels flavor
154     int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor
155     value_t *s, // Gelss flavor
156     int *iwork // Gelsd flavor
157     ) {
158   lapackLstsq_impl<driver_type, scalar_t, value_t>::call(
159       trans, m, n, nrhs,
160       a, lda, b, ldb,
161       work, lwork, info,
162       jpvt, rcond, rank, rwork,
163       s,
164       iwork);
165 }
166 
// Per-batch linear solve using pivots `ipiv`, templated on the element type.
// Declaration only; no definition is visible in this header.
// NOTE(review): argument names match LAPACK ?getrs (a = LU factors,
// b = nrhs right-hand sides overwritten with the solution) — confirm.
template <typename scalar_t>
void lapackLuSolve(
    char trans,
    int n,
    int nrhs,
    scalar_t* a,
    int lda,
    int* ipiv,
    scalar_t* b,
    int ldb,
    int* info);
169 
// Per-batch LU factorization routine, templated on the element type.
// Declaration only; no definition is visible in this header.
// NOTE(review): argument names match LAPACK ?getrf (m-by-n matrix `a`
// factored in place, pivots written to `ipiv`) — confirm.
template <typename scalar_t>
void lapackLu(
    int m,
    int n,
    scalar_t* a,
    int lda,
    int* ipiv,
    int* info);
172 
// Per-batch LDL factorization routine, Hermitian variant, templated on the
// element type. Shares its parameter list with lapackLdlSymmetric.
// Declaration only; no definition is visible in this header.
// NOTE(review): presumably wraps LAPACK ?hetrf — confirm.
template <typename scalar_t>
void lapackLdlHermitian(
    char uplo, int n,
    scalar_t* a, int lda,
    int* ipiv,
    scalar_t* work, int lwork,
    int* info);
183 
// Per-batch LDL factorization routine, symmetric variant, templated on the
// element type. Shares its parameter list with lapackLdlHermitian.
// Declaration only; no definition is visible in this header.
// NOTE(review): presumably wraps LAPACK ?sytrf — confirm.
template <typename scalar_t>
void lapackLdlSymmetric(
    char uplo, int n,
    scalar_t* a, int lda,
    int* ipiv,
    scalar_t* work, int lwork,
    int* info);
194 
// Per-batch solve using an LDL factorization, Hermitian variant, templated on
// the element type. Shares its parameter list with lapackLdlSolveSymmetric.
// Declaration only; no definition is visible in this header.
// NOTE(review): presumably wraps LAPACK ?hetrs, with `b` holding nrhs
// right-hand sides — confirm.
template <typename scalar_t>
void lapackLdlSolveHermitian(
    char uplo, int n, int nrhs,
    scalar_t* a, int lda,
    int* ipiv,
    scalar_t* b, int ldb,
    int* info);
206 
// Per-batch solve using an LDL factorization, symmetric variant, templated on
// the element type. Shares its parameter list with lapackLdlSolveHermitian.
// Declaration only; no definition is visible in this header.
// NOTE(review): presumably wraps LAPACK ?sytrs, with `b` holding nrhs
// right-hand sides — confirm.
template <typename scalar_t>
void lapackLdlSolveSymmetric(
    char uplo, int n, int nrhs,
    scalar_t* a, int lda,
    int* ipiv,
    scalar_t* b, int ldb,
    int* info);
218 
// Per-batch singular value decomposition routine.
// `scalar_t` types the matrix, `u`, `vt` and `work`; `value_t` (defaults to
// scalar_t) types the singular-value buffer `s` and `rwork`.
// Declaration only; no definition is visible in this header.
// NOTE(review): argument names match LAPACK ?gesdd (jobz selects how much of
// U / V^T is computed, iwork is an integer workspace) — confirm.
template <typename scalar_t, typename value_t = scalar_t>
void lapackSvd(
    char jobz,
    int m,
    int n,
    scalar_t* a,
    int lda,
    value_t* s,
    scalar_t* u,
    int ldu,
    scalar_t* vt,
    int ldvt,
    scalar_t* work,
    int lwork,
    value_t* rwork,
    int* iwork,
    int* info);
221 #endif
222 
223 #if AT_BUILD_WITH_BLAS()
// Per-batch triangular solve, templated on the element type. Unlike the
// LAPACK wrappers in this header it has no `info` output parameter.
// Declaration only; no definition is visible in this header.
// NOTE(review): argument names match BLAS ?trsm (side/uplo/trans/diag
// describe the triangular matrix `a`; `b` is overwritten with the solution)
// — confirm against the definition.
template <typename scalar_t>
void blasTriangularSolve(
    char side,
    char uplo,
    char trans,
    char diag,
    int n,
    int nrhs,
    scalar_t* a,
    int lda,
    scalar_t* b,
    int ldb);
226 #endif
227 
228 using cholesky_fn = void (*)(const Tensor& /*input*/, const Tensor& /*info*/, bool /*upper*/);
229 DECLARE_DISPATCH(cholesky_fn, cholesky_stub);
230 
231 using cholesky_inverse_fn = Tensor& (*)(Tensor& /*result*/, Tensor& /*infos*/, bool /*upper*/);
232 
233 DECLARE_DISPATCH(cholesky_inverse_fn, cholesky_inverse_stub);
234 
235 using linalg_eig_fn = void (*)(Tensor& /*eigenvalues*/, Tensor& /*eigenvectors*/, Tensor& /*infos*/, const Tensor& /*input*/, bool /*compute_eigenvectors*/);
236 
237 DECLARE_DISPATCH(linalg_eig_fn, linalg_eig_stub);
238 
239 using geqrf_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/);
240 DECLARE_DISPATCH(geqrf_fn, geqrf_stub);
241 
242 using orgqr_fn = Tensor& (*)(Tensor& /*result*/, const Tensor& /*tau*/);
243 DECLARE_DISPATCH(orgqr_fn, orgqr_stub);
244 
245 using ormqr_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/, const Tensor& /*other*/, bool /*left*/, bool /*transpose*/);
246 DECLARE_DISPATCH(ormqr_fn, ormqr_stub);
247 
248 using linalg_eigh_fn = void (*)(
249     const Tensor& /*eigenvalues*/,
250     const Tensor& /*eigenvectors*/,
251     const Tensor& /*infos*/,
252     bool /*upper*/,
253     bool /*compute_eigenvectors*/);
254 DECLARE_DISPATCH(linalg_eigh_fn, linalg_eigh_stub);
255 
256 using lstsq_fn = void (*)(
257     const Tensor& /*a*/,
258     Tensor& /*b*/,
259     Tensor& /*rank*/,
260     Tensor& /*singular_values*/,
261     Tensor& /*infos*/,
262     double /*rcond*/,
263     std::string /*driver_name*/);
264 DECLARE_DISPATCH(lstsq_fn, lstsq_stub);
265 
266 using triangular_solve_fn = void (*)(
267     const Tensor& /*A*/,
268     const Tensor& /*B*/,
269     bool /*left*/,
270     bool /*upper*/,
271     TransposeType /*transpose*/,
272     bool /*unitriangular*/);
273 DECLARE_DISPATCH(triangular_solve_fn, triangular_solve_stub);
274 
275 using lu_factor_fn = void (*)(
276     const Tensor& /*input*/,
277     const Tensor& /*pivots*/,
278     const Tensor& /*infos*/,
279     bool /*compute_pivots*/);
280 DECLARE_DISPATCH(lu_factor_fn, lu_factor_stub);
281 
282 using unpack_pivots_fn = void(*)(
283   TensorIterator& iter,
284   const int64_t dim_size,
285   const int64_t max_pivot);
286 DECLARE_DISPATCH(unpack_pivots_fn, unpack_pivots_stub);
287 
288 using lu_solve_fn = void (*)(
289     const Tensor& /*LU*/,
290     const Tensor& /*pivots*/,
291     const Tensor& /*B*/,
292     TransposeType /*trans*/);
293 DECLARE_DISPATCH(lu_solve_fn, lu_solve_stub);
294 
295 using ldl_factor_fn = void (*)(
296     const Tensor& /*LD*/,
297     const Tensor& /*pivots*/,
298     const Tensor& /*info*/,
299     bool /*upper*/,
300     bool /*hermitian*/);
301 DECLARE_DISPATCH(ldl_factor_fn, ldl_factor_stub);
302 
303 using svd_fn = void (*)(
304     const Tensor& /*A*/,
305     const bool /*full_matrices*/,
306     const bool /*compute_uv*/,
307     const std::optional<c10::string_view>& /*driver*/,
308     const Tensor& /*U*/,
309     const Tensor& /*S*/,
310     const Tensor& /*Vh*/,
311     const Tensor& /*info*/);
312 DECLARE_DISPATCH(svd_fn, svd_stub);
313 
314 using ldl_solve_fn = void (*)(
315     const Tensor& /*LD*/,
316     const Tensor& /*pivots*/,
317     const Tensor& /*result*/,
318     bool /*upper*/,
319     bool /*hermitian*/);
320 DECLARE_DISPATCH(ldl_solve_fn, ldl_solve_stub);
321 } // namespace at::native
322