1 #pragma once 2 3 #include <optional> 4 #include <c10/util/string_view.h> 5 #include <ATen/Config.h> 6 #include <ATen/native/DispatchStub.h> 7 8 // Forward declare TI 9 namespace at { 10 class Tensor; 11 struct TensorIterator; 12 13 namespace native { 14 enum class TransposeType; 15 } 16 17 } 18 19 namespace at::native { 20 21 enum class LapackLstsqDriverType : int64_t { Gels, Gelsd, Gelsy, Gelss}; 22 23 #if AT_BUILD_WITH_LAPACK() 24 // Define per-batch functions to be used in the implementation of batched 25 // linear algebra operations 26 27 template <class scalar_t> 28 void lapackCholesky(char uplo, int n, scalar_t *a, int lda, int *info); 29 30 template <class scalar_t> 31 void lapackCholeskyInverse(char uplo, int n, scalar_t *a, int lda, int *info); 32 33 template <class scalar_t, class value_t=scalar_t> 34 void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *w, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, value_t *rwork, int *info); 35 36 template <class scalar_t> 37 void lapackGeqrf(int m, int n, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info); 38 39 template <class scalar_t> 40 void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info); 41 42 template <class scalar_t> 43 void lapackOrmqr(char side, char trans, int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *c, int ldc, scalar_t *work, int lwork, int *info); 44 45 template <class scalar_t, class value_t = scalar_t> 46 void lapackSyevd(char jobz, char uplo, int n, scalar_t* a, int lda, value_t* w, scalar_t* work, int lwork, value_t* rwork, int lrwork, int* iwork, int liwork, int* info); 47 48 template <class scalar_t> 49 void lapackGels(char trans, int m, int n, int nrhs, 50 scalar_t *a, int lda, scalar_t *b, int ldb, 51 scalar_t *work, int lwork, int *info); 52 53 template <class scalar_t, class value_t = scalar_t> 54 void lapackGelsd(int m, int n, int nrhs, 55 scalar_t *a, int lda, scalar_t *b, int ldb, 56 value_t *s, value_t rcond, int *rank, 57 scalar_t* work, int lwork, 58 value_t *rwork, int* iwork, int *info); 59 60 template <class scalar_t, class value_t = scalar_t> 61 void lapackGelsy(int m, int n, int nrhs, 62 scalar_t *a, int lda, scalar_t *b, int ldb, 63 int *jpvt, value_t rcond, int *rank, 64 scalar_t *work, int lwork, value_t* rwork, int *info); 65 66 template <class scalar_t, class value_t = scalar_t> 67 void lapackGelss(int m, int n, int nrhs, 68 scalar_t *a, int lda, scalar_t *b, int ldb, 69 value_t *s, value_t rcond, int *rank, 70 scalar_t *work, int lwork, 71 value_t *rwork, int *info); 72 73 template <LapackLstsqDriverType, class scalar_t, class value_t = scalar_t> 74 struct lapackLstsq_impl; 75 76 template <class scalar_t, class value_t> 77 struct lapackLstsq_impl<LapackLstsqDriverType::Gels, scalar_t, value_t> { 78 static void call( 79 char trans, int m, int n, int nrhs, 80 scalar_t *a, int lda, scalar_t *b, int ldb, 81 scalar_t *work, int lwork, int *info, // Gels flavor 82 int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor 83 value_t *s, // Gelss flavor 84 int *iwork // Gelsd flavor 85 ) { 86 lapackGels<scalar_t>( 87 trans, m, n, nrhs, 88 a, lda, b, ldb, 89 work, lwork, info); 90 } 91 }; 92 93 template <class scalar_t, class value_t> 94 struct lapackLstsq_impl<LapackLstsqDriverType::Gelsy, scalar_t, value_t> { 95 static void call( 96 char trans, int m, int n, int nrhs, 97 scalar_t *a, int lda, scalar_t *b, int ldb, 98 scalar_t *work, int lwork, int *info, // Gels flavor 99 int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor 100 value_t *s, // Gelss flavor 101 int *iwork // Gelsd flavor 102 ) { 103 lapackGelsy<scalar_t, value_t>( 104 m, n, nrhs, 105 a, lda, b, ldb, 106 jpvt, rcond, rank, 107 work, lwork, rwork, info); 108 } 109 }; 110 111 template <class scalar_t, class value_t> 112 struct lapackLstsq_impl<LapackLstsqDriverType::Gelsd, scalar_t, value_t> { 113 static void call( 114 char trans, int m, int n, int nrhs, 115 scalar_t *a, int lda, scalar_t *b, int ldb, 116 scalar_t *work, int lwork, int *info, // Gels flavor 117 int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor 118 value_t *s, // Gelss flavor 119 int *iwork // Gelsd flavor 120 ) { 121 lapackGelsd<scalar_t, value_t>( 122 m, n, nrhs, 123 a, lda, b, ldb, 124 s, rcond, rank, 125 work, lwork, 126 rwork, iwork, info); 127 } 128 }; 129 130 template <class scalar_t, class value_t> 131 struct lapackLstsq_impl<LapackLstsqDriverType::Gelss, scalar_t, value_t> { 132 static void call( 133 char trans, int m, int n, int nrhs, 134 scalar_t *a, int lda, scalar_t *b, int ldb, 135 scalar_t *work, int lwork, int *info, // Gels flavor 136 int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor 137 value_t *s, // Gelss flavor 138 int *iwork // Gelsd flavor 139 ) { 140 lapackGelss<scalar_t, value_t>( 141 m, n, nrhs, 142 a, lda, b, ldb, 143 s, rcond, rank, 144 work, lwork, 145 rwork, info); 146 } 147 }; 148 149 template <LapackLstsqDriverType driver_type, class scalar_t, class value_t = scalar_t> 150 void lapackLstsq( 151 char trans, int m, int n, int nrhs, 152 scalar_t *a, int lda, scalar_t *b, int ldb, 153 scalar_t *work, int lwork, int *info, // Gels flavor 154 int *jpvt, value_t rcond, int *rank, value_t* rwork, // Gelsy flavor 155 value_t *s, // Gelss flavor 156 int *iwork // Gelsd flavor 157 ) { 158 lapackLstsq_impl<driver_type, scalar_t, value_t>::call( 159 trans, m, n, nrhs, 160 a, lda, b, ldb, 161 work, lwork, info, 162 jpvt, rcond, rank, rwork, 163 s, 164 iwork); 165 } 166 167 template <class scalar_t> 168 void lapackLuSolve(char trans, int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info); 169 170 template <class scalar_t> 171 void lapackLu(int m, int n, scalar_t *a, int lda, int *ipiv, int *info); 172 173 template <class scalar_t> 174 void lapackLdlHermitian( 175 char uplo, 176 int n, 177 scalar_t* a, 178 int lda, 179 int* ipiv, 180 scalar_t* work, 181 int lwork, 182 int* info); 183 184 template <class scalar_t> 185 void lapackLdlSymmetric( 186 char uplo, 187 int n, 188 scalar_t* a, 189 int lda, 190 int* ipiv, 191 scalar_t* work, 192 int lwork, 193 int* info); 194 195 template <class scalar_t> 196 void lapackLdlSolveHermitian( 197 char uplo, 198 int n, 199 int nrhs, 200 scalar_t* a, 201 int lda, 202 int* ipiv, 203 scalar_t* b, 204 int ldb, 205 int* info); 206 207 template <class scalar_t> 208 void lapackLdlSolveSymmetric( 209 char uplo, 210 int n, 211 int nrhs, 212 scalar_t* a, 213 int lda, 214 int* ipiv, 215 scalar_t* b, 216 int ldb, 217 int* info); 218 219 template<class scalar_t, class value_t=scalar_t> 220 void lapackSvd(char jobz, int m, int n, scalar_t *a, int lda, value_t *s, scalar_t *u, int ldu, scalar_t *vt, int ldvt, scalar_t *work, int lwork, value_t *rwork, int *iwork, int *info); 221 #endif 222 223 #if AT_BUILD_WITH_BLAS() 224 template <class scalar_t> 225 void blasTriangularSolve(char side, char uplo, char trans, char diag, int n, int nrhs, scalar_t* a, int lda, scalar_t* b, int ldb); 226 #endif 227 228 using cholesky_fn = void (*)(const Tensor& /*input*/, const Tensor& /*info*/, bool /*upper*/); 229 DECLARE_DISPATCH(cholesky_fn, cholesky_stub); 230 231 using cholesky_inverse_fn = Tensor& (*)(Tensor& /*result*/, Tensor& /*infos*/, bool /*upper*/); 232 233 DECLARE_DISPATCH(cholesky_inverse_fn, cholesky_inverse_stub); 234 235 using linalg_eig_fn = void (*)(Tensor& /*eigenvalues*/, Tensor& /*eigenvectors*/, Tensor& /*infos*/, const Tensor& /*input*/, bool /*compute_eigenvectors*/); 236 237 DECLARE_DISPATCH(linalg_eig_fn, linalg_eig_stub); 238 239 using geqrf_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/); 240 DECLARE_DISPATCH(geqrf_fn, geqrf_stub); 241 242 using orgqr_fn = Tensor& (*)(Tensor& /*result*/, const Tensor& /*tau*/); 243 DECLARE_DISPATCH(orgqr_fn, orgqr_stub); 244 245 using ormqr_fn = void (*)(const Tensor& /*input*/, const Tensor& /*tau*/, const Tensor& /*other*/, bool /*left*/, bool /*transpose*/); 246 DECLARE_DISPATCH(ormqr_fn, ormqr_stub); 247 248 using linalg_eigh_fn = void (*)( 249 const Tensor& /*eigenvalues*/, 250 const Tensor& /*eigenvectors*/, 251 const Tensor& /*infos*/, 252 bool /*upper*/, 253 bool /*compute_eigenvectors*/); 254 DECLARE_DISPATCH(linalg_eigh_fn, linalg_eigh_stub); 255 256 using lstsq_fn = void (*)( 257 const Tensor& /*a*/, 258 Tensor& /*b*/, 259 Tensor& /*rank*/, 260 Tensor& /*singular_values*/, 261 Tensor& /*infos*/, 262 double /*rcond*/, 263 std::string /*driver_name*/); 264 DECLARE_DISPATCH(lstsq_fn, lstsq_stub); 265 266 using triangular_solve_fn = void (*)( 267 const Tensor& /*A*/, 268 const Tensor& /*B*/, 269 bool /*left*/, 270 bool /*upper*/, 271 TransposeType /*transpose*/, 272 bool /*unitriangular*/); 273 DECLARE_DISPATCH(triangular_solve_fn, triangular_solve_stub); 274 275 using lu_factor_fn = void (*)( 276 const Tensor& /*input*/, 277 const Tensor& /*pivots*/, 278 const Tensor& /*infos*/, 279 bool /*compute_pivots*/); 280 DECLARE_DISPATCH(lu_factor_fn, lu_factor_stub); 281 282 using unpack_pivots_fn = void(*)( 283 TensorIterator& iter, 284 const int64_t dim_size, 285 const int64_t max_pivot); 286 DECLARE_DISPATCH(unpack_pivots_fn, unpack_pivots_stub); 287 288 using lu_solve_fn = void (*)( 289 const Tensor& /*LU*/, 290 const Tensor& /*pivots*/, 291 const Tensor& /*B*/, 292 TransposeType /*trans*/); 293 DECLARE_DISPATCH(lu_solve_fn, lu_solve_stub); 294 295 using ldl_factor_fn = void (*)( 296 const Tensor& /*LD*/, 297 const Tensor& /*pivots*/, 298 const Tensor& /*info*/, 299 bool /*upper*/, 300 bool /*hermitian*/); 301 DECLARE_DISPATCH(ldl_factor_fn, ldl_factor_stub); 302 303 using svd_fn = void (*)( 304 const Tensor& /*A*/, 305 const bool /*full_matrices*/, 306 const bool /*compute_uv*/, 307 const std::optional<c10::string_view>& /*driver*/, 308 const Tensor& /*U*/, 309 const Tensor& /*S*/, 310 const Tensor& /*Vh*/, 311 const Tensor& /*info*/); 312 DECLARE_DISPATCH(svd_fn, svd_stub); 313 314 using ldl_solve_fn = void (*)( 315 const Tensor& /*LD*/, 316 const Tensor& /*pivots*/, 317 const Tensor& /*result*/, 318 bool /*upper*/, 319 bool /*hermitian*/); 320 DECLARE_DISPATCH(ldl_solve_fn, ldl_solve_stub); 321 } // namespace at::native 322