1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Config.h>
4 #include <ATen/Dispatch.h>
5 #include <ATen/Parallel.h>
6 #include <ATen/native/BatchLinearAlgebra.h>
7 #include <ATen/native/LinearAlgebraUtils.h>
8 #include <ATen/native/cpu/zmath.h>
9
10 #include <c10/util/irange.h>
11
12 #ifndef AT_PER_OPERATOR_HEADERS
13 #include <ATen/Functions.h>
14 #include <ATen/NativeFunctions.h>
15 #else
16 #include <ATen/ops/empty.h>
17 #include <ATen/ops/empty_strided.h>
18 #endif
19 namespace at::native {
20
21 namespace {
22 /*
23 Computes the Cholesky decomposition of matrices stored in `input`.
24 This is an in-place routine and the content of 'input' is overwritten with the result.
25
26 Args:
27 * `input` - [in] Input tensor for the Cholesky decomposition
28 [out] Cholesky decomposition result
29 * `info` - [out] Tensor filled with LAPACK error codes,
30 positive values indicate that the matrix is not positive definite.
31 * `upper` - controls whether the upper (true) or lower (false) triangular portion of `input` is used
32
33 For further details, please see the LAPACK documentation for POTRF.
34 */
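// Illustrative usage sketch (not part of this kernel's contract): a caller is
// assumed to reach this routine via the public at::linalg_cholesky API, which
// hands the stub a batched, column-major, Hermitian positive-definite tensor.
//
//   at::Tensor A = at::rand({4, 3, 3}, at::kDouble);
//   A = at::matmul(A, A.mT()) + 1e-3 * at::eye(3, at::kDouble); // make each matrix SPD
//   at::Tensor L = at::linalg_cholesky(A);                      // lower-triangular factors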
35 template <typename scalar_t>
36 void apply_cholesky(const Tensor& input, const Tensor& info, bool upper) {
37 #if !AT_BUILD_WITH_LAPACK()
38 TORCH_CHECK(
39 false,
40 "Calling torch.linalg.cholesky on a CPU tensor requires compiling ",
41 "PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
42 #else
43 char uplo = upper ? 'U' : 'L';
44 auto input_data = input.data_ptr<scalar_t>();
45 auto info_data = info.data_ptr<int>();
46 auto input_matrix_stride = matrixStride(input);
47 auto batch_size = batchCount(input);
48 auto n = input.size(-2);
49 auto lda = std::max<int64_t>(1, n);
50
51 for (const auto i : c10::irange(batch_size)) {
52 scalar_t* input_working_ptr = &input_data[i * input_matrix_stride];
53 int* info_working_ptr = &info_data[i];
54 lapackCholesky<scalar_t>(uplo, n, input_working_ptr, lda, info_working_ptr);
55 }
56 #endif
57 }
58
59 // This is a type dispatching helper function for 'apply_cholesky'
60 void cholesky_kernel(const Tensor& input, const Tensor& infos, bool upper) {
61 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "cholesky_cpu", [&]{
62 apply_cholesky<scalar_t>(input, infos, upper);
63 });
64 }
65
66 /*
67 Copies the lower (or upper) triangle of the square matrix to the other half and conjugates it.
68 This operation is performed in-place.
69 */
70 template <typename scalar_t>
71 void apply_reflect_conj_tri_single(scalar_t* self, int64_t n, int64_t stride, bool upper) {
72 std::function<void(int64_t, int64_t)> loop = [](int64_t, int64_t){};
73 if (upper) {
74 loop = [&](int64_t start, int64_t end) {
75 for (const auto i : c10::irange(start, end)) {
76 for (int64_t j = i + 1; j < n; j++) {
77 self[i * stride + j] = conj_impl(self[j * stride + i]);
78 }
79 }
80 };
81 } else {
82 loop = [&](int64_t start, int64_t end) {
83 for (const auto i : c10::irange(start, end)) {
84 for (const auto j : c10::irange(i)) {
85 self[i * stride + j] = conj_impl(self[j * stride + i]);
86 }
87 }
88 };
89 }
90 // For small matrices OpenMP overhead is too large
91 if (n < 256) {
92 loop(0, n);
93 } else {
94 at::parallel_for(0, n, 0, loop);
95 }
96 }
97
98 /*
99 Computes the inverse of a symmetric (Hermitian) positive-definite n-by-n matrix 'input' using its Cholesky factorization.
100 This is an in-place routine; the content of 'input' is overwritten.
101 'infos' is an int Tensor containing error codes for each matrix in the batched input.
102 For more information see LAPACK's documentation for POTRI routine.
103 */
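// Illustrative usage sketch (assumes the public at::cholesky_inverse API):
// given a Cholesky factor of a positive-definite A, POTRI reconstructs A^{-1}.
//
//   at::Tensor B = at::rand({3, 3}, at::kDouble);
//   at::Tensor A = at::matmul(B, B.mT()) + at::eye(3, at::kDouble); // SPD matrix
//   at::Tensor L = at::linalg_cholesky(A);                          // A = L L^H
//   at::Tensor Ainv = at::cholesky_inverse(L, /*upper=*/false);
//   // Ainv should match at::linalg_inv(A) up to rounding error.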
104 template <typename scalar_t>
105 void apply_cholesky_inverse(Tensor& input, Tensor& infos, bool upper) {
106 #if !AT_BUILD_WITH_LAPACK()
107 TORCH_CHECK(false, "cholesky_inverse: LAPACK library not found in compilation");
108 #else
109 char uplo = upper ? 'U' : 'L';
110
111 auto input_data = input.data_ptr<scalar_t>();
112 auto infos_data = infos.data_ptr<int>();
113 auto input_matrix_stride = matrixStride(input);
114 auto batch_size = batchCount(input);
115 auto n = input.size(-2);
116 auto lda = std::max<int64_t>(1, n);
117
118 for (const auto i : c10::irange(batch_size)) {
119 scalar_t* input_working_ptr = &input_data[i * input_matrix_stride];
120 int* info_working_ptr = &infos_data[i];
121 lapackCholeskyInverse<scalar_t>(uplo, n, input_working_ptr, lda, info_working_ptr);
122 // LAPACK writes only to the upper/lower part of the matrix, leaving the other side unchanged
123 apply_reflect_conj_tri_single<scalar_t>(input_working_ptr, n, lda, upper);
124 }
125 #endif
126 }
127
128 // This is a type dispatching helper function for 'apply_cholesky_inverse'
129 Tensor& cholesky_inverse_kernel_impl(Tensor& result, Tensor& infos, bool upper) {
130 // This function calculates the inverse matrix in-place
131 // result should be in column major order and contain matrices to invert
132 // the content of result is overwritten by 'apply_cholesky_inverse'
133 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(result.scalar_type(), "cholesky_inverse_out_cpu", [&]{
134 apply_cholesky_inverse<scalar_t>(result, infos, upper);
135 });
136 return result;
137 }
138
139 /*
140 Computes the eigenvalues and eigenvectors of n-by-n matrix 'input'.
141 This is an in-place routine, content of 'input', 'values', 'vectors' is overwritten.
142 'infos' is an int Tensor containing error codes for each matrix in the batched input.
143 For more information see LAPACK's documentation for GEEV routine.
144 */
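// Illustrative usage sketch (assumes the public at::linalg_eig API): GEEV
// produces complex eigenvalues/eigenvectors even for real input, so the
// caller-facing outputs are complex tensors.
//
//   at::Tensor A = at::rand({3, 3}, at::kDouble);
//   auto [w, V] = at::linalg_eig(A);  // w: eigenvalues, V: right eigenvectors (both complex)
//   // Column j of V satisfies A v_j = w_j v_j up to rounding error.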
145 template <typename scalar_t>
146 void apply_linalg_eig(Tensor& values, Tensor& vectors, Tensor& input, Tensor& infos, bool compute_eigenvectors) {
147 #if !AT_BUILD_WITH_LAPACK()
148 TORCH_CHECK(false, "Calling torch.linalg.eig on a CPU tensor requires compiling ",
149 "PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
150 #else
151 using value_t = typename c10::scalar_value_type<scalar_t>::type;
152
153 char jobvr = compute_eigenvectors ? 'V' : 'N';
154 char jobvl = 'N'; // only right eigenvectors are computed
155 auto n = input.size(-1);
156 auto lda = std::max<int64_t>(1, n);
157 auto batch_size = batchCount(input);
158 auto input_matrix_stride = matrixStride(input);
159 auto values_stride = values.size(-1);
160 auto input_data = input.data_ptr<scalar_t>();
161 auto values_data = values.data_ptr<scalar_t>();
162 auto infos_data = infos.data_ptr<int>();
163 auto rvectors_data = compute_eigenvectors ? vectors.data_ptr<scalar_t>() : nullptr;
164 scalar_t* lvectors_data = nullptr; // only right eigenvectors are computed
165 int64_t ldvr = compute_eigenvectors ? lda : 1;
166 int64_t ldvl = 1;
167
168 Tensor rwork;
169 value_t* rwork_data = nullptr;
170 if (input.is_complex()) {
171 ScalarType real_dtype = toRealValueType(input.scalar_type());
172 rwork = at::empty({lda * 2}, input.options().dtype(real_dtype));
173 rwork_data = rwork.mutable_data_ptr<value_t>();
174 }
175
176 // call lapackEig once to get the optimal size for work data
177 scalar_t work_query;
178 lapackEig<scalar_t, value_t>(jobvl, jobvr, n, input_data, lda, values_data,
179 lvectors_data, ldvl, rvectors_data, ldvr, &work_query, -1, rwork_data, &infos_data[0]);
180
181 int lwork = std::max<int>(1, static_cast<int>(real_impl<scalar_t, value_t>(work_query)));
182 Tensor work = at::empty({lwork}, input.dtype());
183 auto work_data = work.mutable_data_ptr<scalar_t>();
184
185 for (const auto i : c10::irange(batch_size)) {
186 scalar_t* input_working_ptr = &input_data[i * input_matrix_stride];
187 scalar_t* values_working_ptr = &values_data[i * values_stride];
188 scalar_t* rvectors_working_ptr = compute_eigenvectors ? &rvectors_data[i * input_matrix_stride] : nullptr;
189 int* info_working_ptr = &infos_data[i];
190 lapackEig<scalar_t, value_t>(jobvl, jobvr, n, input_working_ptr, lda, values_working_ptr,
191 lvectors_data, ldvl, rvectors_working_ptr, ldvr, work_data, lwork, rwork_data, info_working_ptr);
192 }
193 #endif
194 }
195
196 // This is a type dispatching helper function for 'apply_linalg_eig'
197 void linalg_eig_kernel(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, const Tensor& input, bool compute_eigenvectors) {
198 // This function calculates the non-symmetric eigendecomposition in-place.
199 // Tensors should be in batched column-major memory format;
200 // the content of eigenvalues, eigenvectors and infos is overwritten by 'apply_linalg_eig'.
201
202 // apply_linalg_eig modifies the provided input matrix in-place, therefore we need a copy
203 Tensor input_working_copy = at::empty(input.mT().sizes(), input.options());
204 input_working_copy.transpose_(-2, -1); // make input_working_copy have a Fortran-contiguous memory layout
205 input_working_copy.copy_(input);
206
207 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "linalg_eig_out_cpu", [&]{
208 apply_linalg_eig<scalar_t>(eigenvalues, eigenvectors, input_working_copy, infos, compute_eigenvectors);
209 });
210 }
211
212 /*
213 Computes eigenvalues and eigenvectors of the input that is stored initially in 'vectors'.
214 The computation is done in-place: 'vectors' stores the input and will be overwritten,
215 'values' should be an allocated empty array.
216 'infos' is used to store information for possible error checks.
217 'upper' controls the portion of input matrix to consider in computations
218 'compute_eigenvectors' controls whether eigenvectors should be computed.
219 This function doesn't do any error checks and it's assumed that every argument is valid.
220 */
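// Illustrative usage sketch (assumes the public at::linalg_eigh API): only the
// triangle selected by 'upper' is read from the Hermitian input.
//
//   at::Tensor A = at::rand({3, 3}, at::kDouble);
//   A = A + A.mT();                         // symmetrize
//   auto [w, V] = at::linalg_eigh(A, "L");  // w: real eigenvalues in ascending order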
221 template <typename scalar_t>
222 void apply_lapack_eigh(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
223 #if !AT_BUILD_WITH_LAPACK()
224 TORCH_CHECK(
225 false,
226 "Calling torch.linalg.eigh or eigvalsh on a CPU tensor requires compiling ",
227 "PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
228 #else
229 using value_t = typename c10::scalar_value_type<scalar_t>::type;
230
231 char uplo = upper ? 'U' : 'L';
232 char jobz = compute_eigenvectors ? 'V' : 'N';
233
234 auto n = vectors.size(-1);
235 auto lda = std::max<int64_t>(1, n);
236 auto batch_size = batchCount(vectors);
237
238 auto vectors_stride = matrixStride(vectors);
239 auto values_stride = values.size(-1);
240
241 auto vectors_data = vectors.data_ptr<scalar_t>();
242 auto values_data = values.data_ptr<value_t>();
243 auto infos_data = infos.data_ptr<int>();
244
245 // Using 'int' instead of int32_t or int64_t is consistent with the current LAPACK interface
246 // It really should be changed in the future to something like lapack_int that depends on the specific LAPACK library that is linked
247 // or switch to supporting only 64-bit indexing by default.
248 int lwork = -1;
249 int lrwork = -1;
250 int liwork = -1;
251 scalar_t lwork_query;
252 value_t rwork_query;
253 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
254 int iwork_query;
255
256 // call lapackSyevd once to get the optimal size for work data
257 lapackSyevd<scalar_t, value_t>(jobz, uplo, n, vectors_data, lda, values_data,
258 &lwork_query, lwork, &rwork_query, lrwork, &iwork_query, liwork, infos_data);
259
260 lwork = std::max<int>(1, real_impl<scalar_t, value_t>(lwork_query));
261 Tensor work = at::empty({lwork}, vectors.options());
262 auto work_data = work.mutable_data_ptr<scalar_t>();
263
264 liwork = std::max<int>(1, iwork_query);
265 Tensor iwork = at::empty({liwork}, vectors.options().dtype(at::kInt));
266 auto iwork_data = iwork.mutable_data_ptr<int>();
267
268 Tensor rwork;
269 value_t* rwork_data = nullptr;
270 if (vectors.is_complex()) {
271 lrwork = std::max<int>(1, rwork_query);
272 rwork = at::empty({lrwork}, values.options());
273 rwork_data = rwork.mutable_data_ptr<value_t>();
274 }
275
276 // Now call lapackSyevd for each matrix in the batched input
277 for (const auto i : c10::irange(batch_size)) {
278 scalar_t* vectors_working_ptr = &vectors_data[i * vectors_stride];
279 value_t* values_working_ptr = &values_data[i * values_stride];
280 int* info_working_ptr = &infos_data[i];
281 lapackSyevd<scalar_t, value_t>(jobz, uplo, n, vectors_working_ptr, lda, values_working_ptr,
282 work_data, lwork, rwork_data, lrwork, iwork_data, liwork, info_working_ptr);
283 // The current behaviour of linear algebra functions is to raise an error if something goes wrong
284 // or the input doesn't satisfy some requirement,
285 // therefore we return early since further computations would be wasted anyway
286 if (*info_working_ptr != 0) {
287 return;
288 }
289 }
290 #endif
291 }
292
293 // This is a type dispatching helper function for 'apply_lapack_eigh'
294 void linalg_eigh_kernel(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
295 // This function calculates the symmetric/Hermitian eigendecomposition in-place.
296 // Tensors should be in batched column-major memory format;
297 // the content of eigenvalues, eigenvectors and infos is overwritten by
298 // 'apply_lapack_eigh'.
299 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
300 eigenvectors.scalar_type(), "linalg_eigh_cpu", [&] {
301 apply_lapack_eigh<scalar_t>(
302 eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
303 });
304 }
305
306 /*
307 The geqrf function computes the QR decomposition of matrices stored in `input`.
308 However, rather than producing a Q matrix directly, it produces a sequence of
309 elementary reflectors which may later be composed to construct Q - for example
310 with the orgqr or ormqr functions.
311
312 Args:
313 * `input` - [in] Input tensor for QR decomposition
314 [out] QR decomposition result, which contains:
315 i) The elements of R, on and above the diagonal.
316 ii) The directions of the elementary reflectors implicitly defining Q,
317 stored below the diagonal.
318 The tensor is overwritten in-place with this result.
319 * `tau` - [out] Tensor which will contain the magnitudes of the reflectors
320 implicitly defining Q.
321
322 For further details, please see the LAPACK documentation for GEQRF.
323 */
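// Illustrative usage sketch (assumes the public at::geqrf/at::orgqr APIs): the
// reflectors produced here compose into Q, and R is the upper triangle of the
// in-place result.
//
//   at::Tensor A = at::rand({5, 3}, at::kDouble);
//   auto [a, tau] = at::geqrf(A);      // a: packed R + reflectors, tau: reflector scalars
//   at::Tensor R = a.triu();           // upper-triangular factor
//   at::Tensor Q = at::orgqr(a, tau);  // thin Q with orthonormal columns
//   // at::matmul(Q, R) should reproduce A up to rounding error.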
324 template <typename scalar_t>
325 static void apply_geqrf(const Tensor& input, const Tensor& tau) {
326 #if !AT_BUILD_WITH_LAPACK()
327 TORCH_CHECK(
328 false,
329 "Calling torch.geqrf on a CPU tensor requires compiling ",
330 "PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
331 #else
332 using value_t = typename c10::scalar_value_type<scalar_t>::type;
333 auto input_data = input.data_ptr<scalar_t>();
334 auto tau_data = tau.data_ptr<scalar_t>();
335 auto input_matrix_stride = matrixStride(input);
336 auto tau_stride = tau.size(-1);
337 auto batch_size = batchCount(input);
338 auto m = input.size(-2);
339 auto n = input.size(-1);
340 auto lda = std::max<int64_t>(1, m);
341
342 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
343 int info;
344 // Run once, first to get the optimum work size.
345 // Since we deal with batches of matrices with the same dimensions, doing this outside
346 // the loop saves (batch_size - 1) workspace queries which would provide the same result
347 // and (batch_size - 1) calls to allocate and deallocate workspace using at::empty()
348 int lwork = -1;
349 scalar_t wkopt;
350 lapackGeqrf<scalar_t>(m, n, input_data, lda, tau_data, &wkopt, lwork, &info);
351 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
352
353 // if lwork is less than 'n' then a warning is printed:
354 // Intel MKL ERROR: Parameter 7 was incorrect on entry to SGEQRF.
355 lwork = std::max<int>({1, static_cast<int>(n), static_cast<int>(real_impl<scalar_t, value_t>(wkopt))});
356 Tensor work = at::empty({lwork}, input.options());
357
358 for (const auto i : c10::irange(batch_size)) {
359 scalar_t* input_working_ptr = &input_data[i * input_matrix_stride];
360 scalar_t* tau_working_ptr = &tau_data[i * tau_stride];
361
362 // now compute the actual QR and tau
363 lapackGeqrf<scalar_t>(m, n, input_working_ptr, lda, tau_working_ptr, work.data_ptr<scalar_t>(), lwork, &info);
364
365 // info from lapackGeqrf only reports if the i-th parameter is wrong
366 // so we don't need to check it all the time
367 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
368 }
369 #endif
370 }
371
372 // This is a type dispatching helper function for 'apply_geqrf'
373 void geqrf_kernel(const Tensor& input, const Tensor& tau) {
374 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "geqrf_cpu", [&]{
375 apply_geqrf<scalar_t>(input, tau);
376 });
377 }
378
379 /*
380 The orgqr function allows reconstruction of an orthogonal (or unitary) matrix Q,
381 from a sequence of elementary reflectors, such as produced by the geqrf function.
382
383 Args:
384 * `self` - Tensor with the directions of the elementary reflectors below the diagonal,
385 it will be overwritten with the result
386 * `tau` - Tensor containing the magnitudes of the elementary reflectors
387
388 For further details, please see the LAPACK documentation for ORGQR and UNGQR.
389 */
390 template <typename scalar_t>
391 inline void apply_orgqr(Tensor& self, const Tensor& tau) {
392 #if !AT_BUILD_WITH_LAPACK()
393 TORCH_CHECK(false, "Calling torch.orgqr on a CPU tensor requires compiling ",
394 "PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
395 #else
396 // Some LAPACK implementations might not work well with empty matrices:
397 // workspace query might return lwork as 0, which is not allowed (requirement is lwork >= 1)
398 // We don't need to do any calculations in this case, so let's return early
399 if (self.numel() == 0) {
400 return;
401 }
402
403 using value_t = typename c10::scalar_value_type<scalar_t>::type;
404 auto self_data = self.data_ptr<scalar_t>();
405 auto tau_data = tau.const_data_ptr<scalar_t>();
406 auto self_matrix_stride = matrixStride(self);
407 auto tau_stride = tau.size(-1);
408 auto batch_size = batchCount(self);
409 auto m = self.size(-2);
410 auto n = self.size(-1);
411 auto k = tau.size(-1);
412 auto lda = std::max<int64_t>(1, m);
413 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
414 int info;
415
416 // LAPACK's requirement
417 TORCH_INTERNAL_ASSERT(m >= n);
418 TORCH_INTERNAL_ASSERT(n >= k);
419
420 // Run once, first to get the optimum work size.
421 // Since we deal with batches of matrices with the same dimensions, doing this outside
422 // the loop saves (batch_size - 1) workspace queries which would provide the same result
423 // and (batch_size - 1) calls to allocate and deallocate workspace using at::empty()
424 int lwork = -1;
425 scalar_t wkopt;
426 lapackOrgqr<scalar_t>(m, n, k, self_data, lda, const_cast<scalar_t*>(tau_data), &wkopt, lwork, &info);
427 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
428 lwork = std::max<int>(1, real_impl<scalar_t, value_t>(wkopt));
429 Tensor work = at::empty({lwork}, self.options());
430
431 for (const auto i : c10::irange(batch_size)) {
432 scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
433 const scalar_t* tau_working_ptr = &tau_data[i * tau_stride];
434
435 // now compute the actual Q
436 lapackOrgqr<scalar_t>(m, n, k, self_working_ptr, lda, const_cast<scalar_t*>(tau_working_ptr), work.data_ptr<scalar_t>(), lwork, &info);
437
438 // info from lapackOrgqr only reports if the i-th parameter is wrong
439 // so we don't need to check it all the time
440 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
441 }
442 #endif
443 }
444
445 // This is a type dispatching helper function for 'apply_orgqr'
446 Tensor& orgqr_kernel_impl(Tensor& result, const Tensor& tau) {
447 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(result.scalar_type(), "orgqr_cpu", [&]{
448 apply_orgqr<scalar_t>(result, tau);
449 });
450 return result;
451 }
452
453 /*
454 Solves a least squares problem. That is minimizing ||B - A X||.
455
456 Input args:
457 * 'input' - Tensor containing batches of m-by-n matrix A.
458 * 'other' - Tensor containing batches of max(m, n)-by-nrhs matrix B.
459 * 'cond' - relative tolerance for determining rank of A.
460 * 'driver' - the name of the LAPACK driver that is used to compute the solution.
461 Output args (modified in-place):
462 * 'solution' - Tensor to store the solution matrix X.
463 * 'residuals' - Tensor to store values of ||B - A X||.
464 * 'rank' - Tensor to store the rank of A.
465 * 'singular_values' - Tensor to store the singular values of A.
466 * 'infos' - Tensor to store error codes of linear algebra math library.
467
468 For further details, please see the LAPACK documentation for GELS/GELSY/GELSS/GELSD routines.
469 */
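// Illustrative usage sketch (assumes the public at::linalg_lstsq API): the
// driver name selects which branch of 'apply_lstsq' below is taken.
//
//   at::Tensor A = at::rand({6, 4}, at::kDouble);
//   at::Tensor B = at::rand({6, 2}, at::kDouble);
//   auto [X, residuals, rank, sv] =
//       at::linalg_lstsq(A, B, /*rcond=*/c10::nullopt, /*driver=*/"gelsd");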
470 template <typename scalar_t>
471 void apply_lstsq(const Tensor& A, Tensor& B, Tensor& rank, Tensor& singular_values, Tensor& infos, double rcond, LapackLstsqDriverType driver_type) {
472 #if !AT_BUILD_WITH_LAPACK()
473 TORCH_CHECK(
474 false,
475 "Calling torch.linalg.lstsq on a CPU tensor requires compiling ",
476 "PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
477 #else
478 using value_t = typename c10::scalar_value_type<scalar_t>::type;
479 using driver_t = at::native::LapackLstsqDriverType;
480
481 auto lapack_func = lapackLstsq<driver_t::Gelsd, scalar_t, value_t>;
482 static auto driver_type_to_func
483 = std::unordered_map<driver_t, decltype(lapack_func)>({
484 {driver_t::Gels, lapackLstsq<driver_t::Gels, scalar_t, value_t>},
485 {driver_t::Gelsy, lapackLstsq<driver_t::Gelsy, scalar_t, value_t>},
486 {driver_t::Gelsd, lapackLstsq<driver_t::Gelsd, scalar_t, value_t>},
487 {driver_t::Gelss, lapackLstsq<driver_t::Gelss, scalar_t, value_t>}
488 });
489 lapack_func = driver_type_to_func[driver_type];
490
491 char trans = 'N';
492
493 auto A_data = A.data_ptr<scalar_t>();
494 auto B_data = B.data_ptr<scalar_t>();
495 auto m = A.size(-2);
496 auto n = A.size(-1);
497 auto nrhs = B.size(-1);
498 auto lda = std::max<int64_t>(1, m);
499 auto ldb = std::max<int64_t>(1, std::max(m, n));
500 auto infos_data = infos.data_ptr<int>();
501
502 // only 'gels' driver does not compute the rank
503 int rank_32 = 0;
504 int64_t* rank_data = nullptr;
505 int64_t* rank_working_ptr = nullptr;
506 if (driver_t::Gels != driver_type) {
507 rank_data = rank.data_ptr<int64_t>();
508 rank_working_ptr = rank_data;
509 }
510
511 // 'gelsd' and 'gelss' are SVD-based algorithms
512 // so we can get singular values
513 value_t* s_data = nullptr;
514 value_t* s_working_ptr = nullptr;
515 int64_t s_stride = 0;
516 if (driver_t::Gelsd == driver_type || driver_t::Gelss == driver_type) {
517 s_data = singular_values.data_ptr<value_t>();
518 s_working_ptr = s_data;
519 s_stride = singular_values.size(-1);
520 }
521
522 // 'jpvt' workspace array is used only for 'gelsy' which uses QR factorization with column pivoting
523 Tensor jpvt;
524 int* jpvt_data = nullptr;
525 if (driver_t::Gelsy == driver_type) {
526 jpvt = at::empty({std::max<int64_t>(1, n)}, A.options().dtype(at::kInt));
527 jpvt_data = jpvt.mutable_data_ptr<int>();
528 }
529
530 // Run the driver once first, to get the optimal workspace size
531 int lwork = -1; // default value to decide the opt size for workspace arrays
532 scalar_t work_opt;
533 value_t rwork_opt;
534 int iwork_opt = 0;
535 lapack_func(trans, m, n, nrhs,
536 A_data, lda,
537 B_data, ldb,
538 &work_opt, lwork,
539 infos_data,
540 jpvt_data,
541 static_cast<value_t>(rcond),
542 &rank_32,
543 &rwork_opt,
544 s_working_ptr,
545 &iwork_opt);
546
547 lwork = std::max<int>(1, real_impl<scalar_t, value_t>(work_opt));
548 Tensor work = at::empty({lwork}, A.options());
549 scalar_t* work_data = work.mutable_data_ptr<scalar_t>();
550
551 // 'rwork' only used for complex inputs and 'gelsy', 'gelsd' and 'gelss' drivers
552 Tensor rwork;
553 value_t* rwork_data = nullptr;
554 if (A.is_complex() && driver_t::Gels != driver_type) {
555 int64_t rwork_len = 0;
556 switch (driver_type) {
557 case driver_t::Gelsy:
558 rwork_len = std::max<int64_t>(1, 2 * n);
559 break;
560 case driver_t::Gelss:
561 rwork_len = std::max<int64_t>(1, 5 * std::min(m, n));
562 break;
563 // case driver_t::Gelsd:
564 default:
565 rwork_len = std::max<int64_t>(1, rwork_opt);
566 }
567 rwork = at::empty({rwork_len}, A.options().dtype(c10::toRealValueType(A.scalar_type())));
568 rwork_data = rwork.mutable_data_ptr<value_t>();
569 }
570
571 // 'iwork' workspace array is relevant only for 'gelsd'
572 Tensor iwork;
573 int* iwork_data = nullptr;
574 if (driver_t::Gelsd == driver_type) {
575 iwork = at::empty({std::max<int>(1, iwork_opt)}, A.options().dtype(at::kInt));
576 iwork_data = iwork.mutable_data_ptr<int>();
577 }
578
579 at::native::batch_iterator_with_broadcasting<scalar_t>(A, B,
580 [&](scalar_t* A_working_ptr, scalar_t* B_working_ptr, int64_t A_linear_batch_idx) {
581 rank_working_ptr = rank_working_ptr ? &rank_data[A_linear_batch_idx] : nullptr;
582 s_working_ptr = s_working_ptr ? &s_data[A_linear_batch_idx * s_stride] : nullptr;
583 int* infos_working_ptr = &infos_data[A_linear_batch_idx];
584
585 lapack_func(trans, m, n, nrhs,
586 A_working_ptr, lda,
587 B_working_ptr, ldb,
588 work_data, lwork,
589 infos_working_ptr,
590 jpvt_data,
591 static_cast<value_t>(rcond),
592 &rank_32,
593 rwork_data,
594 s_working_ptr,
595 iwork_data);
596
597 // we want the output `rank` Tensor to be of type int64_t,
598 // however LAPACK accepts int. That is why we use an integer
599 // variable that then gets promoted and written into `rank`.
600 // We use this approach over a tensor cast for better performance.
601 if (rank_working_ptr) {
602 *rank_working_ptr = static_cast<int64_t>(rank_32);
603 }
604 }
605 );
606 #endif
607 }
608
609 // This is a type and driver dispatching helper function for 'apply_lstsq'
610 void lstsq_kernel(const Tensor& a, Tensor& b, Tensor& rank, Tensor& singular_values, Tensor& infos, double rcond, std::string driver_name) {
611
612 static auto driver_string_to_type = std::unordered_map<c10::string_view, LapackLstsqDriverType>({
613 {"gels", at::native::LapackLstsqDriverType::Gels},
614 {"gelsy", at::native::LapackLstsqDriverType::Gelsy},
615 {"gelsd", at::native::LapackLstsqDriverType::Gelsd},
616 {"gelss", at::native::LapackLstsqDriverType::Gelss}
617 });
618 auto driver_type = driver_string_to_type[driver_name];
619
620 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(a.scalar_type(), "linalg_lstsq_cpu", [&]{
621 apply_lstsq<scalar_t>(a, b, rank, singular_values, infos, rcond, driver_type);
622 });
623 }
624
625 /*
626 The ormqr function multiplies another matrix by Q, where Q is represented by a sequence of
627 elementary reflectors, such as those produced by the geqrf function.
628
629 Args:
630 * `input` - Tensor with elementary reflectors below the diagonal,
631 encoding the matrix Q.
632 * `tau` - Tensor containing the magnitudes of the elementary
633 reflectors.
634 * `other` - [in] Tensor containing the matrix to be multiplied.
635 [out] result of the matrix multiplication with Q.
636 * `left` - bool, determining whether `other` is left- or right-multiplied with Q.
637 * `transpose` - bool, determining whether to transpose (or conjugate transpose) Q before multiplying.
638
639 For further details, please see the LAPACK documentation.
640 */
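// Illustrative usage sketch (assumes the public at::geqrf/at::ormqr APIs):
// ormqr applies Q (or its transpose) without ever forming Q explicitly.
//
//   at::Tensor A = at::rand({5, 3}, at::kDouble);
//   at::Tensor C = at::rand({5, 2}, at::kDouble);
//   auto [a, tau] = at::geqrf(A);
//   // Q^T C, with Q kept in its implicit reflector form:
//   at::Tensor QtC = at::ormqr(a, tau, C, /*left=*/true, /*transpose=*/true);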
641 template <typename scalar_t>
642 void apply_ormqr(const Tensor& input, const Tensor& tau, const Tensor& other, bool left, bool transpose) {
643 #if !AT_BUILD_WITH_LAPACK()
644 TORCH_CHECK(false, "Calling torch.ormqr on a CPU tensor requires compiling ",
645 "PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
646 #else
647 using value_t = typename c10::scalar_value_type<scalar_t>::type;
648
649 char side = left ? 'L' : 'R';
650 char trans = transpose ? (input.is_complex() ? 'C' : 'T') : 'N';
651
652 auto input_data = input.const_data_ptr<scalar_t>();
653 auto tau_data = tau.const_data_ptr<scalar_t>();
654 auto other_data = other.data_ptr<scalar_t>();
655
656 auto input_matrix_stride = matrixStride(input);
657 auto other_matrix_stride = matrixStride(other);
658 auto tau_stride = tau.size(-1);
659 auto batch_size = batchCount(input);
660 auto m = other.size(-2);
661 auto n = other.size(-1);
662 auto k = tau.size(-1);
663 auto lda = std::max<int64_t>(1, left ? m : n);
664 auto ldc = std::max<int64_t>(1, m);
665 int info = 0;
666
667 // LAPACK's requirement
668 TORCH_INTERNAL_ASSERT_DEBUG_ONLY((left ? m : n) >= k);
669
670 // Query for the optimal size of the workspace tensor
671 int lwork = -1;
672 scalar_t wkopt;
673 lapackOrmqr<scalar_t>(side, trans, m, n, k, const_cast<scalar_t*>(input_data), lda, const_cast<scalar_t*>(tau_data), other_data, ldc, &wkopt, lwork, &info);
674 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
675 lwork = std::max<int>(1, real_impl<scalar_t, value_t>(wkopt));
676 Tensor work = at::empty({lwork}, input.options());
677
678 for (const auto i : c10::irange(batch_size)) {
679 const scalar_t* input_working_ptr = &input_data[i * input_matrix_stride];
680 scalar_t* other_working_ptr = &other_data[i * other_matrix_stride];
681 const scalar_t* tau_working_ptr = &tau_data[i * tau_stride];
682
683 // now compute the actual result
684 lapackOrmqr<scalar_t>(
685 side, trans, m, n, k,
686 const_cast<scalar_t*>(input_working_ptr), lda,
687 const_cast<scalar_t*>(tau_working_ptr),
688 other_working_ptr, ldc,
689 work.data_ptr<scalar_t>(), lwork, &info);
690
691 // info from lapackOrmqr only reports if the i-th parameter is wrong
692 // so we don't need to check it all the time
693 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
694 }
695 #endif
696 }
697
698 // This is a type dispatching helper function for 'apply_ormqr'
699 void ormqr_kernel(const Tensor& input, const Tensor& tau, const Tensor& other, bool left, bool transpose) {
700 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "ormqr_cpu", [&]{
701 apply_ormqr<scalar_t>(input, tau, other, left, transpose);
702 });
703 }
704
705 /*
706 Solves the matrix equation op(A) X = B
707 X and B are n-by-nrhs matrices, A is a unit, or non-unit, upper or lower triangular matrix
708 and op(A) is one of op(A) = A or op(A) = A^T or op(A) = A^H.
709 This is an in-place routine, content of 'B' is overwritten.
710 'upper' controls the portion of input matrix to consider in computations,
711 'transpose' chooses op(A)
712 'unitriangular' if true then the diagonal elements of A are assumed to be 1
713 and the actual diagonal values are not used.
714 */
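// Illustrative usage sketch (assumes the public at::linalg_solve_triangular
// API, which is expected to route here on CPU): solve L X = B with a lower
// triangular, non-unit-diagonal L.
//
//   at::Tensor L = at::rand({3, 3}, at::kDouble).tril() + at::eye(3, at::kDouble);
//   at::Tensor B = at::rand({3, 2}, at::kDouble);
//   at::Tensor X = at::linalg_solve_triangular(L, B, /*upper=*/false,
//                                              /*left=*/true, /*unitriangular=*/false);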
715 template<typename scalar_t>
716 void apply_triangular_solve(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
717 #if !AT_BUILD_WITH_BLAS()
718 TORCH_CHECK(
719 false,
720 "Calling torch.triangular_solve on a CPU tensor requires compiling ",
721 "PyTorch with BLAS. Please use PyTorch built with BLAS support.");
722 #else
723 char uplo = upper ? 'U' : 'L';
724 char diag = unitriangular ? 'U' : 'N';
725 char side = left ? 'L' : 'R';
726 const char trans = to_blas(transpose);
727
728 auto A_data = A.const_data_ptr<scalar_t>();
729 auto B_data = B.data_ptr<scalar_t>();
730 auto A_mat_stride = matrixStride(A);
731 auto B_mat_stride = matrixStride(B);
732 auto batch_size = batchCount(A);
733 // This allows passing rectangular A and B when left = true
734 auto m = left ? A.size(-1) : B.size(-2);
735 auto n = B.size(-1);
736 auto lda = std::max<int64_t>(1, A.size(-2));
737 auto ldb = std::max<int64_t>(1, B.size(-2));
738
739 for (const auto i : c10::irange(batch_size)) {
740 const scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
741 scalar_t* B_working_ptr = &B_data[i * B_mat_stride];
742 blasTriangularSolve<scalar_t>(side, uplo, trans, diag, m, n, const_cast<scalar_t*>(A_working_ptr), lda, B_working_ptr, ldb);
743 }
744 #endif
745 }
746
747 void triangular_solve_kernel(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular) {
748 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "triangular_solve_cpu", [&]{
749 apply_triangular_solve<scalar_t>(A, B, left, upper, transpose, unitriangular);
750 });
751 }
752
753 template <typename scalar_t>
754 void apply_ldl_factor(
755 const Tensor& A,
756 const Tensor& pivots,
757 const Tensor& info,
758 bool upper,
759 bool hermitian) {
760 #if !AT_BUILD_WITH_LAPACK()
761 TORCH_CHECK(
762 false,
763 "Calling torch.linalg.ldl_factor on a CPU tensor requires compiling ",
764 "PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
765 #else
766 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(batchCount(A) > 0);
767 auto batch_size = batchCount(A);
768 auto n = A.size(-2);
769 auto leading_dim = A.stride(-1);
770 auto uplo = upper ? 'U' : 'L';
771
772 auto a_stride = A.dim() > 2 ? A.stride(-3) : 0;
773 auto pivots_stride = pivots.dim() > 1 ? pivots.stride(-2) : 0;
774
775 auto a_data = A.data_ptr<scalar_t>();
776 auto pivots_data = pivots.data_ptr<int>();
777 auto info_data = info.data_ptr<int>();
778
779 auto ldl_func =
780 hermitian ? lapackLdlHermitian<scalar_t> : lapackLdlSymmetric<scalar_t>;
781
782 scalar_t wkopt;
783 ldl_func(uplo, n, a_data, leading_dim, pivots_data, &wkopt, -1, info_data);
784 using value_t = typename c10::scalar_value_type<scalar_t>::type;
785 int lwork = std::max<int>(1, real_impl<scalar_t, value_t>(wkopt));
786 Tensor work = at::empty({lwork}, A.dtype());
787 auto work_data = work.mutable_data_ptr<scalar_t>();
788
789 for (const auto i : c10::irange(batch_size)) {
790 scalar_t* a_working_ptr = &a_data[i * a_stride];
791 auto* pivots_working_ptr = &pivots_data[i * pivots_stride];
792 auto* info_working_ptr = &info_data[i];
793 ldl_func(
794 uplo,
795 n,
796 a_working_ptr,
797 leading_dim,
798 pivots_working_ptr,
799 work_data,
800 lwork,
801 info_working_ptr);
802 }
803 #endif
804 }
805
806 void ldl_factor_kernel(
807 const Tensor& LD,
808 const Tensor& pivots,
809 const Tensor& info,
810 bool upper,
811 bool hermitian) {
812 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
813 LD.scalar_type(), "ldl_factor_kernel_cpu", [&] {
814 apply_ldl_factor<scalar_t>(LD, pivots, info, upper, hermitian);
815 });
816 }
817
818 template <typename scalar_t>
819 void apply_ldl_solve(
820 const Tensor& A,
821 const Tensor& pivots,
822 const Tensor& B,
823 bool upper,
824 bool hermitian) {
825 #if !AT_BUILD_WITH_LAPACK()
826 TORCH_CHECK(
827 false,
828 "Calling torch.linalg.ldl_factor on a CPU tensor requires compiling ",
829 "PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
830 #else
831 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(batchCount(A) > 0);
832 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(batchCount(pivots.unsqueeze(-1)) > 0);
833 auto batch_size = batchCount(B);
834 auto n = A.size(-2);
835 auto nrhs = B.size(-1);
836 auto lda = A.stride(-1);
837 auto ldb = B.stride(-1);
838 auto uplo = upper ? 'U' : 'L';
839
840 auto a_stride = A.dim() > 2 ? A.stride(-3) : 0;
841 auto b_stride = B.dim() > 2 ? B.stride(-3) : 0;
842 auto pivots_stride = pivots.dim() > 1 ? pivots.stride(-2) : 0;
843
844 auto a_data = A.const_data_ptr<scalar_t>();
845 auto b_data = B.data_ptr<scalar_t>();
846 auto pivots_ = pivots.to(kInt);
847 auto pivots_data = pivots_.const_data_ptr<int>();
848
849 auto ldl_solve_func = hermitian ? lapackLdlSolveHermitian<scalar_t>
850 : lapackLdlSolveSymmetric<scalar_t>;
851
852 int info = 0;
853 for (const auto i : c10::irange(batch_size)) {
854 const scalar_t* a_working_ptr = &a_data[i * a_stride];
855 scalar_t* b_working_ptr = &b_data[i * b_stride];
856 const auto* pivots_working_ptr = &pivots_data[i * pivots_stride];
857 ldl_solve_func(
858 uplo,
859 n,
860 nrhs,
861 const_cast<scalar_t*>(a_working_ptr),
862 lda,
863 const_cast<int*>(pivots_working_ptr),
864 b_working_ptr,
865 ldb,
866 &info);
867 }
868 TORCH_INTERNAL_ASSERT(info == 0);
869 #endif
870 }
871
872 void ldl_solve_kernel(
873 const Tensor& LD,
874 const Tensor& pivots,
875 const Tensor& result,
876 bool upper,
877 bool hermitian) {
878 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
879 LD.scalar_type(), "ldl_solve_kernel_cpu", [&] {
880 apply_ldl_solve<scalar_t>(LD, pivots, result, upper, hermitian);
881 });
882 }
883
884 /*
885 Computes the LU decomposition of an m-by-n matrix or batch of matrices in the 'input' tensor.
886 This is an in-place routine; the content of 'input', 'pivots', and 'infos' is overwritten.
887
888 Args:
889 * `input` - [in] the input matrix for LU decomposition
890 [out] the LU decomposition
891 * `pivots` - [out] the pivot indices
892 * `infos` - [out] error codes, positive values indicate singular matrices
893 * `compute_pivots` - should always be true (can be false only for CUDA)
894
895 For further details, please see the LAPACK documentation for GETRF.
896 */
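// Illustrative usage sketch (assumes the public at::linalg_lu_factor API):
// GETRF returns the packed LU factors and 1-based pivot indices that are later
// consumed by lu_solve / lu_unpack.
//
//   at::Tensor A = at::rand({4, 4}, at::kDouble);
//   auto [LU, pivots] = at::linalg_lu_factor(A);  // LU packs unit-lower L and upper U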
897 template <typename scalar_t>
898 void apply_lu_factor(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
899 #if !AT_BUILD_WITH_LAPACK()
900 TORCH_CHECK(
901 false,
902 "Calling torch.linalg.lu_factor on a CPU tensor requires compiling ",
903 "PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
904 #else
905 TORCH_CHECK(compute_pivots, "linalg.lu_factor: LU without pivoting is not implemented on the CPU");
906
907 auto input_data = input.data_ptr<scalar_t>();
908 auto pivots_data = pivots.data_ptr<int>();
909 auto infos_data = infos.data_ptr<int>();
910 auto input_matrix_stride = matrixStride(input);
911 auto pivots_stride = pivots.size(-1);
912 auto batch_size = batchCount(input);
913 auto m = input.size(-2);
914 auto n = input.size(-1);
915 auto leading_dimension = std::max<int64_t>(1, m);
916
917 const auto loop = [&](int64_t start, int64_t end) {
918 for (const auto i : c10::irange(start, end)) {
919 scalar_t* input_working_ptr = &input_data[i * input_matrix_stride];
920 int* pivots_working_ptr = &pivots_data[i * pivots_stride];
921 int* infos_working_ptr = &infos_data[i];
922 lapackLu<scalar_t>(
923 m,
924 n,
925 input_working_ptr,
926 leading_dimension,
927 pivots_working_ptr,
928 infos_working_ptr);
929 }
930 };
931 // avoid overflow
932 float matrix_rank = float(std::min(m, n));
933 // A heuristic tested on a 32 core/socket ICX system
934 // https://github.com/pytorch/pytorch/pull/93037#discussion_r1090112948
935 int64_t chunk_size_per_thread = int64_t(
936 std::min(1.0, 3200.0 / (matrix_rank * matrix_rank * matrix_rank)));
937 int64_t grain_size = chunk_size_per_thread * at::get_num_threads();
938 at::parallel_for(0, batch_size, grain_size, loop);
939 #endif
940 }
941
942 // This is a type dispatching helper function for 'apply_lu_factor'
943 void lu_factor_kernel(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) {
944 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "lu_cpu", [&]{
945 apply_lu_factor<scalar_t>(input, pivots, infos, compute_pivots);
946 });
947 }
948
949 /*
950 Solves the matrix equation A X = B
951 X and B are n-by-nrhs matrices, A is represented using the LU factorization.
952 This is an in-place routine, content of `b` is overwritten.
953
954 Args:
955 * `b` - [in] the right hand side matrix B
956 [out] the solution matrix X
957 * `lu` - [in] the LU factorization of matrix A (see at::linalg_lu_factor)
958 * `pivots` - [in] the pivot indices (see at::linalg_lu_factor)
959
960 For further details, please see the LAPACK documentation for GETRS.
961 */
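// Illustrative usage sketch (assumes the public at::linalg_lu_factor and
// at::linalg_lu_solve APIs):
//
//   at::Tensor A = at::rand({4, 4}, at::kDouble);
//   at::Tensor B = at::rand({4, 2}, at::kDouble);
//   auto [LU, pivots] = at::linalg_lu_factor(A);
//   at::Tensor X = at::linalg_lu_solve(LU, pivots, B);  // solves A X = B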
962 template <typename scalar_t>
963 void apply_lu_solve(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose) {
964 #if !AT_BUILD_WITH_LAPACK()
965 TORCH_CHECK(
966 false,
967 "Calling linalg.lu_solve on a CPU tensor requires compiling ",
968 "PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
969 #else
970 auto b_data = B.data_ptr<scalar_t>();
971 auto lu_data = LU.const_data_ptr<scalar_t>();
972 const auto trans = to_blas(transpose);
973 auto pivots_data = pivots.const_data_ptr<int>();
974 auto b_stride = matrixStride(B);
975 auto lu_stride = LU.dim() > 2 ? LU.stride(-3) : 0;
976 auto pivots_stride = pivots.dim() > 1 ? pivots.stride(-2) : 0;
977 auto batch_size = batchCount(B);
978
979 auto n = LU.size(-2);
980 auto nrhs = B.size(-1);
981 auto leading_dimension = std::max<int64_t>(1, n);
982
983 int info = 0;
984
985 // lu and pivots tensors can be broadcast to B
986 // here we construct a helper indexing tensor to linearly index into LU and pivots
987 IntArrayRef lu_batch_shape(LU.sizes().data(), LU.dim() - 2);
988 IntArrayRef b_batch_shape(B.sizes().data(), B.dim() - 2);
989 BroadcastLinearIndices lu_index(
990 batchCount(LU), lu_batch_shape, b_batch_shape);
991
992 for (const auto i : c10::irange(batch_size)) {
993 int64_t lu_index_i = lu_index(i);
994 scalar_t* b_working_ptr = &b_data[i * b_stride];
995 const scalar_t* lu_working_ptr = &lu_data[lu_index_i * lu_stride];
996 const int* pivots_working_ptr = &pivots_data[lu_index_i * pivots_stride];
997
998 lapackLuSolve<scalar_t>(trans, n, nrhs, const_cast<scalar_t*>(lu_working_ptr), leading_dimension, const_cast<int*>(pivots_working_ptr),
999 b_working_ptr, leading_dimension, &info);
1000
1001 // info from lapackLuSolve only reports if the i-th parameter is wrong
1002 // so we don't need to check it all the time
1003 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(info == 0);
1004 }
1005 #endif
1006 }
1007
1008 // This is a type dispatching helper function for 'apply_lu_solve'
1009 void lu_solve_kernel(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType trans) {
1010 // Lapack will write into unrelated memory if pivots are not in the right range so we do
1011 // some simple sanity checks here for the CPU version
1012 TORCH_CHECK(pivots.gt(0).all().item<bool>(),
1013 "Pivots given to lu_solve must all be greater or equal to 1. "
1014 "Did you properly pass the result of lu_factor?");
1015 TORCH_CHECK(pivots.le(LU.size(-2)).all().item<bool>(),
1016 "Pivots given to lu_solve must all be smaller or equal to LU.size(-2). "
1017 "Did you properly pass the result of lu_factor?");
1018
1019 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(LU.scalar_type(), "linalg.lu_solve_cpu", [&]{
1020 apply_lu_solve<scalar_t>(LU, pivots, B, trans);
1021 });
1022 }
1023
1024 template <typename scalar_t>
1025 static void apply_svd(const Tensor& A,
1026 const bool full_matrices,
1027 const bool compute_uv,
1028 const Tensor& U,
1029 const Tensor& S,
1030 const Tensor& Vh,
1031 const Tensor& info) {
1032 #if !AT_BUILD_WITH_LAPACK()
1033 TORCH_CHECK(false, "svd: LAPACK library not found in compilation");
1034 #else
1035 using value_t = typename c10::scalar_value_type<scalar_t>::type;
1036 const auto A_data = A.data_ptr<scalar_t>();
1037 const auto U_data = compute_uv ? U.data_ptr<scalar_t>() : nullptr;
1038 const auto S_data = S.data_ptr<value_t>();
1039 const auto info_data = info.data_ptr<int>();
1040 const auto Vh_data = compute_uv ? Vh.data_ptr<scalar_t>() : nullptr;
1041 const auto A_stride = matrixStride(A);
1042 const auto S_stride = S.size(-1);
1043 const auto U_stride = compute_uv ? matrixStride(U) : 1;
1044 const auto Vh_stride = compute_uv ? matrixStride(Vh) : 1;
1045 const auto batchsize = batchCount(A);
1046 const char jobz = compute_uv ? (full_matrices ? 'A' : 'S') : 'N';
1047
1048 const auto m = A.size(-2);
1049 const auto n = A.size(-1);
1050 const auto lda = A.stride(-1);
1051 const auto ldu= compute_uv ? U.stride(-1) : 1;
1052 const auto ldvh = compute_uv ? Vh.stride(-1) : 1;
1053
1054 auto iwork = std::vector<int>(8 * std::min(m, n));
1055 auto* const iwork_data = iwork.data();
1056
1057 // rwork is just used for the complex decomposition
1058 auto rwork = std::vector<value_t>{};
1059 if (A.is_complex()) {
1060 rwork.resize(std::max(computeLRWorkDim(jobz, m, n), int64_t{1}));
1061 }
1062 auto* const rwork_data = rwork.data();
1063
1064 // Query svd for the optimal lwork size
1065 int lwork = -1;
1066 {
1067 scalar_t wkopt;
1068 lapackSvd<scalar_t, value_t>(jobz, m, n, A_data, lda, S_data, U_data, ldu, Vh_data, ldvh, &wkopt, lwork, rwork_data, iwork_data, info_data);
1069 lwork = std::max<int>(1, real_impl<scalar_t, value_t>(wkopt));
1070 }
1071 auto work = std::vector<scalar_t>(lwork);
1072 auto* const work_data = work.data();
1073
1074 for (const auto i : c10::irange(batchsize)) {
1075 auto* const A_working_ptr = &A_data[i * A_stride];
1076 auto* const S_working_ptr = &S_data[i * S_stride];
1077 auto* const U_working_ptr = compute_uv ? &U_data[i * U_stride] : nullptr;
1078 auto* const Vh_working_ptr = compute_uv ? &Vh_data[i * Vh_stride] : nullptr;
1079
1080 // Compute S, U (optionally) and Vh (optionally)
1081 lapackSvd<scalar_t, value_t>(jobz, m, n, A_working_ptr, lda,
1082 S_working_ptr, U_working_ptr, ldu, Vh_working_ptr, ldvh, work_data, lwork, rwork_data, iwork_data, info_data + i);
1083 }
1084 #endif
1085 }
1086
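// Illustrative usage sketch (assumes the public at::linalg_svd API): GESDD
// returns U, S, Vh with S real and sorted in descending order.
//
//   at::Tensor A = at::rand({5, 3}, at::kDouble);
//   auto [U, S, Vh] = at::linalg_svd(A, /*full_matrices=*/false);
//   // at::matmul(U * S.unsqueeze(-2), Vh) should reproduce A up to rounding error.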
1087 void svd_kernel(const Tensor& A,
1088 const bool full_matrices,
1089 const bool compute_uv,
1090 const std::optional<c10::string_view>& driver,
1091 const Tensor& U,
1092 const Tensor& S,
1093 const Tensor& Vh,
1094 const Tensor& infos) {
1095 TORCH_INTERNAL_ASSERT(!driver.has_value(), "svd_kernel: driver shouldn't have a value here. ");
1096 // Need to copy A as column major, as its contents will be destroyed in the LAPACK call.
1097 // FIXME It'd be more efficient, rather than cloning A, to copy it into `U` or `Vh` (depending on m > n
1098 // or m < n) and call jobz='O'
1099 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(A.scalar_type(), "linalg_svd_cpu", [&]{
1100 apply_svd<scalar_t>(cloneBatchedColumnMajor(A), full_matrices, compute_uv, U, S, Vh, infos);
1101 });
1102 }
1103
1104 void unpack_pivots_cpu_kernel(TensorIterator& iter, const int64_t dim_size, const int64_t max_pivot) {
1105 if (iter.numel() == 0 || dim_size == 0) {
1106 return;
1107 }
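// Worked example (illustrative; assumes the caller pre-fills 'perm' with the
// identity permutation): with 1-based pivots [3, 3, 3] and perm = [0, 1, 2],
// the swap loop below performs
//   i = 0: swap(perm[0], perm[2]) -> [2, 1, 0]
//   i = 1: swap(perm[1], perm[2]) -> [2, 0, 1]
//   i = 2: swap(perm[2], perm[2]) -> [2, 0, 1]
// yielding the 0-based permutation [2, 0, 1].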
1108 auto loop = [&](char* const* const data, const int64_t* const strides, const int64_t nelems) {
1109 auto* perm_ptr = data[0];
1110 const auto* pivots_ptr = data[1];
1111
1112 for (C10_UNUSED const auto elem : c10::irange(nelems)) {
1113 // WARNING: linalg.lu_factor returns int32 pivots,
1114 // this behavior could change in the future.
1115 const auto perm_data = reinterpret_cast<int64_t*>(perm_ptr);
1116 const auto pivots_data = reinterpret_cast<const int32_t*>(pivots_ptr);
1117
1118 for (const auto i : c10::irange(dim_size)) {
1119 auto new_idx = pivots_data[i] - 1;
1120 TORCH_CHECK(new_idx >= 0 && new_idx < max_pivot,
1121 "pivots passed to lu_unpack must be between 1 and LU.size(-2) inclusive. "
1122 "Did you properly pass the result of lu_factor?");
1123 std::swap(
1124 perm_data[i],
1125 perm_data[new_idx]
1126 );
1127 }
1128
1129 perm_ptr += strides[0];
1130 pivots_ptr += strides[1];
1131 }
1132 };
1133
1134 iter.for_each(loop);
1135 }
1136 } // anonymous namespace
1137
1138 REGISTER_ARCH_DISPATCH(cholesky_stub, DEFAULT, &cholesky_kernel);
1139 REGISTER_AVX512_DISPATCH(cholesky_stub, &cholesky_kernel);
1140 REGISTER_AVX2_DISPATCH(cholesky_stub, &cholesky_kernel);
1141 REGISTER_VSX_DISPATCH(cholesky_stub, &cholesky_kernel);
1142 REGISTER_ZVECTOR_DISPATCH(cholesky_stub, &cholesky_kernel);
1143
1144 REGISTER_ARCH_DISPATCH(cholesky_inverse_stub, DEFAULT, &cholesky_inverse_kernel_impl);
1145 REGISTER_AVX512_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);
1146 REGISTER_AVX2_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);
1147 REGISTER_VSX_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);
1148 REGISTER_ZVECTOR_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);
1149
1150 REGISTER_ARCH_DISPATCH(linalg_eig_stub, DEFAULT, &linalg_eig_kernel);
1151 REGISTER_AVX512_DISPATCH(linalg_eig_stub, &linalg_eig_kernel);
1152 REGISTER_AVX2_DISPATCH(linalg_eig_stub, &linalg_eig_kernel);
1153 REGISTER_VSX_DISPATCH(linalg_eig_stub, &linalg_eig_kernel);
1154 REGISTER_ZVECTOR_DISPATCH(linalg_eig_stub, &linalg_eig_kernel);
1155
1156 REGISTER_ARCH_DISPATCH(linalg_eigh_stub, DEFAULT, &linalg_eigh_kernel);
1157 REGISTER_AVX512_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel);
1158 REGISTER_AVX2_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel);
1159 REGISTER_VSX_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel);
1160 REGISTER_ZVECTOR_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel);
1161
1162 REGISTER_ARCH_DISPATCH(geqrf_stub, DEFAULT, &geqrf_kernel);
1163 REGISTER_AVX512_DISPATCH(geqrf_stub, &geqrf_kernel);
1164 REGISTER_AVX2_DISPATCH(geqrf_stub, &geqrf_kernel);
1165 REGISTER_VSX_DISPATCH(geqrf_stub, &geqrf_kernel);
1166 REGISTER_ZVECTOR_DISPATCH(geqrf_stub, &geqrf_kernel);
1167
1168 REGISTER_ARCH_DISPATCH(orgqr_stub, DEFAULT, &orgqr_kernel_impl);
1169 REGISTER_AVX512_DISPATCH(orgqr_stub, &orgqr_kernel_impl);
1170 REGISTER_AVX2_DISPATCH(orgqr_stub, &orgqr_kernel_impl);
1171 REGISTER_VSX_DISPATCH(orgqr_stub, &orgqr_kernel_impl);
1172 REGISTER_ZVECTOR_DISPATCH(orgqr_stub, &orgqr_kernel_impl);
1173
1174 REGISTER_ARCH_DISPATCH(ormqr_stub, DEFAULT, &ormqr_kernel);
1175 REGISTER_AVX512_DISPATCH(ormqr_stub, &ormqr_kernel);
1176 REGISTER_AVX2_DISPATCH(ormqr_stub, &ormqr_kernel);
1177 REGISTER_VSX_DISPATCH(ormqr_stub, &ormqr_kernel);
1178 REGISTER_ZVECTOR_DISPATCH(ormqr_stub, &ormqr_kernel);
1179
1180 REGISTER_ARCH_DISPATCH(lstsq_stub, DEFAULT, &lstsq_kernel);
1181 REGISTER_AVX512_DISPATCH(lstsq_stub, &lstsq_kernel);
1182 REGISTER_AVX2_DISPATCH(lstsq_stub, &lstsq_kernel);
1183 REGISTER_VSX_DISPATCH(lstsq_stub, &lstsq_kernel);
1184 REGISTER_ZVECTOR_DISPATCH(lstsq_stub, &lstsq_kernel);
1185
1186 REGISTER_ARCH_DISPATCH(triangular_solve_stub, DEFAULT, &triangular_solve_kernel);
1187 REGISTER_AVX512_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
1188 REGISTER_AVX2_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
1189 REGISTER_VSX_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
1190 REGISTER_ZVECTOR_DISPATCH(triangular_solve_stub, &triangular_solve_kernel);
1191
1192 REGISTER_ARCH_DISPATCH(lu_factor_stub, DEFAULT, &lu_factor_kernel);
1193 REGISTER_AVX512_DISPATCH(lu_factor_stub, &lu_factor_kernel);
1194 REGISTER_AVX2_DISPATCH(lu_factor_stub, &lu_factor_kernel);
1195 REGISTER_VSX_DISPATCH(lu_factor_stub, &lu_factor_kernel);
1196 REGISTER_ZVECTOR_DISPATCH(lu_factor_stub, &lu_factor_kernel);
1197
1198 REGISTER_ARCH_DISPATCH(ldl_factor_stub, DEFAULT, &ldl_factor_kernel);
1199 REGISTER_AVX512_DISPATCH(ldl_factor_stub, &ldl_factor_kernel);
1200 REGISTER_AVX2_DISPATCH(ldl_factor_stub, &ldl_factor_kernel);
1201 REGISTER_VSX_DISPATCH(ldl_factor_stub, &ldl_factor_kernel);
1202 REGISTER_ZVECTOR_DISPATCH(ldl_factor_stub, &ldl_factor_kernel);
1203
1204 REGISTER_ARCH_DISPATCH(ldl_solve_stub, DEFAULT, &ldl_solve_kernel);
1205 REGISTER_AVX512_DISPATCH(ldl_solve_stub, &ldl_solve_kernel);
1206 REGISTER_AVX2_DISPATCH(ldl_solve_stub, &ldl_solve_kernel);
1207 REGISTER_VSX_DISPATCH(ldl_solve_stub, &ldl_solve_kernel);
1208 REGISTER_ZVECTOR_DISPATCH(ldl_solve_stub, &ldl_solve_kernel);
1209 REGISTER_ARCH_DISPATCH(lu_solve_stub, DEFAULT, &lu_solve_kernel);
1210 REGISTER_AVX512_DISPATCH(lu_solve_stub, &lu_solve_kernel);
1211 REGISTER_AVX2_DISPATCH(lu_solve_stub, &lu_solve_kernel);
1212 REGISTER_VSX_DISPATCH(lu_solve_stub, &lu_solve_kernel);
1213 REGISTER_ZVECTOR_DISPATCH(lu_solve_stub, &lu_solve_kernel);
1214
1215 REGISTER_ARCH_DISPATCH(svd_stub, DEFAULT, &svd_kernel);
1216 REGISTER_AVX512_DISPATCH(svd_stub, &svd_kernel);
1217 REGISTER_AVX2_DISPATCH(svd_stub, &svd_kernel);
1218 REGISTER_VSX_DISPATCH(svd_stub, &svd_kernel);
1219 REGISTER_ZVECTOR_DISPATCH(svd_stub, &svd_kernel);
1220
1221 REGISTER_ARCH_DISPATCH(unpack_pivots_stub, DEFAULT, &unpack_pivots_cpu_kernel);
1222 REGISTER_AVX512_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel);
1223 REGISTER_AVX2_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel);
1224 REGISTER_VSX_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel);
1225 REGISTER_ZVECTOR_DISPATCH(unpack_pivots_stub, &unpack_pivots_cpu_kernel);
1226 } // namespace at::native
1227