// aten/src/ATen/native/LinearAlgebraUtils.h
#pragma once

#include <c10/core/ScalarType.h>
#include <c10/util/irange.h>
#include <c10/util/Exception.h>
#include <c10/util/strides.h>
#include <ATen/core/Tensor.h>
#include <ATen/ExpandUtils.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TransposeType.h>
#include <limits>
#include <type_traits>
#include <sstream>
#include <cstring>
#include <cctype>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/zeros.h>
#endif

namespace at::native {

inline c10::MaybeOwned<Tensor> expect_resolved_conj(const Tensor& tensor) {
  if (tensor.is_conj()) {
    return c10::MaybeOwned<Tensor>::owned(tensor.resolve_conj());
  } else {
    return c10::MaybeOwned<Tensor>::borrowed(tensor);
  }
}

inline DimVector batched_matrix_contiguous_strides(
    const IntArrayRef sizes,
    const bool f_contig = false) {
  // f_contig chooses between the strides of a batch of Fortran (F-contiguous)
  // and C-contiguous matrices
  auto strides = c10::contiguous_strides(sizes);
  auto dim = strides.size();

  if (f_contig && dim >= 2) {
    // Fix the strides of the last two dimensions, so that we return
    // C-contiguous batches of F-contiguous matrices.
    strides[dim - 1] = std::max(sizes[dim - 2], static_cast<int64_t>(1));
    strides[dim - 2] = 1;
  }
  return strides;
}
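
// Example (editor's note, not in the original source): for sizes (B, M, N) = (2, 3, 4),
//   batched_matrix_contiguous_strides({2, 3, 4}, /*f_contig=*/false) -> {12, 4, 1}
//   batched_matrix_contiguous_strides({2, 3, 4}, /*f_contig=*/true)  -> {12, 1, 3}
// i.e. only the strides of the last two (matrix) dimensions differ; the batch
// dimension stays C-contiguous in both cases.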

/*
 * Clones a Tensor so that the following conditions hold:
 * If we think of a Tensor as having size (B, M, N), where B is any number
 * of batch dimensions, then:
 * - Each (M, N) matrix is in column major form
 * - Let Tensor P have size (B, M, N) and Q have size (B, M', N').
 *   Then when laid out in memory, the M by N matrix starting at
 *   P.data_ptr()[B * M * N] corresponds to the same batch as the M' by N'
 *   matrix starting at Q.data_ptr()[B * M' * N'].
 */
inline Tensor cloneBatchedColumnMajor(const Tensor& src) {
  // If src is already in batched column major format, then
  // this will be efficient (no reordering of the data will occur)
  // because the first transpose will make the tensor contiguous,
  // and cloning a contiguous tensor is fast.
  auto result = src.mT().clone(at::MemoryFormat::Contiguous);
  result.transpose_(-2, -1);
  return result;
}
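
// Usage sketch (editor's illustration, not in the original source):
//   Tensor A = at::rand({2, 3, 4});           // C-contiguous batch of 3x4 matrices
//   Tensor A_f = cloneBatchedColumnMajor(A);  // same sizes {2, 3, 4}, strides {12, 1, 3}
//   // A_f.data_ptr<float>() can now be handed to LAPACK-style routines that
//   // expect each matrix in column-major (Fortran) order.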

/*
 * contig chooses between C-contig (true) and F-contig (false)
 */
inline c10::MaybeOwned<Tensor> borrow_else_clone(const bool cond, const Tensor& borrow, const Tensor& clone, const bool contig) {
  return cond ? c10::MaybeOwned<Tensor>::borrowed(borrow)
              : c10::MaybeOwned<Tensor>::owned(contig ? clone.clone(MemoryFormat::Contiguous)
                                                      : cloneBatchedColumnMajor(clone));
}

/*
 * This method is designed to be a faster alternative to
 * `cloneBatchedColumnMajor` with some additional features,
 * namely:
 * 1. It uses `copy` instead of `clone`, which could be much faster.
 * 2. The `nrows` parameter is used to create copies with more rows
 *  than the original input, which is required for some LAPACK/MAGMA methods.
 * 3. `desired_batch_sizes` is used to create copies with a batch size
 *  which is either the original batch size of the input, or its larger
 *  broadcasted shape.
 */
inline Tensor copyBatchedColumnMajor(const Tensor& src, int64_t nrows = -1,
    at::OptionalIntArrayRef desired_batch_sizes = std::nullopt) {
  nrows = (nrows == -1) ? src.size(-2) : nrows;
  auto copy_sizes = desired_batch_sizes.has_value()
    ? desired_batch_sizes.value().vec()
    : IntArrayRef(src.sizes().data(), src.dim() - 2).vec();
  copy_sizes.insert(copy_sizes.end(), {nrows, src.size(-1)});
  const auto copy_strides = batched_matrix_contiguous_strides(copy_sizes, /*f-contig*/true);
  auto copy = at::empty_strided(copy_sizes, copy_strides, src.options());
  copy.narrow(-2, 0, src.size(-2)).copy_(src);
  return copy;
}
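
// Usage sketch (editor's illustration, not in the original source): padding the
// rows of the input, as some LAPACK/MAGMA drivers overwrite their input with a
// taller result:
//   Tensor src = at::rand({5, 3, 4});
//   Tensor padded = copyBatchedColumnMajor(src, /*nrows=*/6);
//   // padded has sizes {5, 6, 4}; the first 3 rows of each matrix hold a copy
//   // of src, the remaining rows are uninitialized scratch space.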

/*
 * Given batches of matrices with arbitrary batch dim,
 * computes the number of batches.
 */
inline int64_t batchCount(const Tensor& batched_matrices) {
  int64_t result = 1;
  for (int64_t i = 0; i < batched_matrices.ndimension() - 2; i++) {
    result *= batched_matrices.size(i);
  }
  return result;
}

// Computes the number of elements of a matrix in a batched matrix tensor
inline int64_t matrixStride(const Tensor& batched_matrices) {
  return batched_matrices.size(-1) * batched_matrices.size(-2);
}
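
// Example (editor's note, not in the original source): for a contiguous tensor of
// shape (2, 5, 3, 4), batchCount returns 2 * 5 = 10 and matrixStride returns
// 3 * 4 = 12, so the k-th matrix starts at element offset k * 12.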

// Validates input shapes for operations on batches of square matrices (inverse, cholesky, symeig, eig)
inline void checkIsMatrix(const Tensor& A, const char* const f_name, const char* const arg_name = "A") {
  TORCH_CHECK(A.dim() >= 2, f_name, ": The input tensor ", arg_name, " must have at least 2 dimensions.");
}
inline void squareCheckInputs(const Tensor& self, const char* const f_name, const char* const arg_name = "A") {
  checkIsMatrix(self, f_name, arg_name);
  TORCH_CHECK(self.sym_size(-1) == self.sym_size(-2),
              f_name,
              ": ", arg_name, " must be batches of square matrices, "
              "but they are ", self.sym_size(-2), " by ", self.sym_size(-1), " matrices");
}

inline void checkInputsSolver(const Tensor& A,
                              const Tensor& B,
                              const bool left,
                              const char* const f_name) {
  squareCheckInputs(A, f_name, "A");
  checkIsMatrix(B, f_name, "B");
  TORCH_CHECK(left ? A.size(-2) == B.size(-2) : A.size(-1) == B.size(-1),
              f_name, ": Incompatible shapes of A and B for the equation ",
              left ? "AX = B" : "XA = B",
              " (", A.size(-2), "x", A.size(-1), " and ", B.size(-2), "x", B.size(-1), ")");
}

inline bool is_row_or_column_contiguous(const Tensor& t) {
  // This could be made more general, similar to how it's checked in matmul, which would allow us to
  // elide the copy with strides such as (6, 12, 1, 3) or (3, 1, 9), but this is quite tricky.
  // We choose to be conservative for simplicity.
  return t.is_contiguous() || t.transpose(-2, -1).is_contiguous();
}

inline TransposeType to_transpose_type(const bool contig, const bool conj) {
  if (conj) {
    if (contig) { TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); }
    else {        return TransposeType::ConjTranspose; }
  } else {
    if (contig) { return TransposeType::NoTranspose; }
    else {        return TransposeType::Transpose; }
  }
}
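
// Summary of the mapping above (editor's note, not in the original source):
//   contig == true,  conj == false -> NoTranspose
//   contig == false, conj == false -> Transpose      (BLAS/LAPACK 'T')
//   contig == false, conj == true  -> ConjTranspose  (BLAS/LAPACK 'C')
//   contig == true,  conj == true  -> invalid: a conjugated contiguous input is
//   expected to have been materialized (e.g. via expect_resolved_conj) beforehand.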


// This function is designed to be used with linear algebra methods that minimize
// L(ax - b) = 0, where L is generally the identity map (`solve`, for example)
// or the L2 norm (`lstsq`).
// It is expected that `a` and `b` are contiguous tensors of column-major matrices
// (so that a.view({-1, a.size(-2), a.size(-1)}) succeeds, same for `b`),
// with the following additional properties:
//
// 1. a.dim() == b.dim()
// 2. a.shape[:-2] broadcasts over b.shape[:-2]
// 3. a.size(i) <= b.size(i) for i=0,..., a.dim() - 3 (only for batch dimensions)
//
// MAGMA/LAPACK modify tensor `a` in-place, and the main goal of this method
// is to be memory efficient, which means that if there exists an index i such that
// a.shape[i] < b.shape[i], 0 <= i <= a.dim() - 3,
// then instead of materializing copies of `a` in the broadcasted shape, we keep
// a buffer copy of `a` along with flags that check whether specific batch dimension
// indices for `a` were already accessed. If they were, we copy the data from the buffer
// into `a`. The number of copies does not exceed
// prod(max(a.shape[:-2], b.shape[:-2]) - a.shape[:-2] + 1)
// and this value is attained by tensors with non-empty batch dimensions.
//
// func_t `f` is a callable that is supplied with
// scalar_t* a_working_ptr, scalar_t* b_working_ptr, int64_t a_linear_batch_idx.
// a_working_ptr and b_working_ptr can directly be passed to LAPACK/MAGMA routines,
// and a_linear_batch_idx is an index in the 3d representation which corresponds to
// the memory a_working_ptr points to, in other words:
// a_working_ptr == a.view({-1, a.size(-2), a.size(-1)}).select(0, a_linear_batch_idx).data_ptr<scalar_t>();
// a_linear_batch_idx is useful for storing metadata related to `a`, such as, for example,
// its rank or singular values (see linalg_lstsq).
template<typename scalar_t, typename func_t>
void batch_iterator_with_broadcasting(const Tensor& a, const Tensor& b, const func_t& f) {
  IntArrayRef a_batch_sizes(a.sizes().data(), a.dim() - 2);
  IntArrayRef b_batch_sizes(b.sizes().data(), b.dim() - 2);

  auto a_linear_batch_idx = at::arange(batchCount(a)).view(a_batch_sizes);
  auto b_linear_batch_idx = at::arange(batchCount(b)).view(b_batch_sizes);

  TensorIterator iter = TensorIteratorConfig()
    .set_check_mem_overlap(false)
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .add_output(b_linear_batch_idx)
    .add_input(a_linear_batch_idx)
    .build();

  auto m = a.size(-2);
  auto n = a.size(-1);
  auto a_3d = a.view({batchCount(a), m, n});
  auto b_3d = b.view({batchCount(b), b.size(-2), b.size(-1)});

  auto a_broadcasts_over_b = (a_batch_sizes != b_batch_sizes);
  Tensor a_buffer, a_was_accessed, a_buffer_3d;
  std::function<void(int64_t)> check_if_copy_needed_for_a
    = [](int64_t /*a_curr_linear_batch_idx*/){};
  if (a_broadcasts_over_b) {
    a_buffer = at::empty_strided(a.sizes(), a.strides(), a.options())
      .copy_(a);
    a_was_accessed = at::zeros(batchCount(a), at::kBool);
    a_buffer_3d = a_buffer.view({batchCount(a), m, n});
    check_if_copy_needed_for_a = [&](int64_t a_curr_linear_batch_idx) {
      auto* a_was_accessed_flag = a_was_accessed
        .select(0, a_curr_linear_batch_idx)
        .data_ptr<bool>();
      if (!(*a_was_accessed_flag)) {
        *a_was_accessed_flag = true;
      }
      else {
        a_3d.select(0, a_curr_linear_batch_idx)
          .copy_(a_buffer_3d.select(0, a_curr_linear_batch_idx));
      }
    };
  }

  auto loop = [&](char** data, const int64_t* strides, int64_t nelems) {
    auto* b_batch_idx_ptr = data[0];
    auto* a_batch_idx_ptr = data[1];

    for (const auto elem C10_UNUSED : c10::irange(nelems)) {
      auto b_curr_linear_batch_idx = *reinterpret_cast<int64_t*>(b_batch_idx_ptr);
      auto a_curr_linear_batch_idx = *reinterpret_cast<int64_t*>(a_batch_idx_ptr);

      check_if_copy_needed_for_a(a_curr_linear_batch_idx);

      auto* a_working_ptr = a_3d.select(0, a_curr_linear_batch_idx)
        .data_ptr<scalar_t>();
      auto* b_working_ptr = b_3d.select(0, b_curr_linear_batch_idx)
        .data_ptr<scalar_t>();
      f(a_working_ptr, b_working_ptr, a_curr_linear_batch_idx);

      b_batch_idx_ptr += strides[0];
      a_batch_idx_ptr += strides[1];
    }
  };
  iter.serial_for_each(loop, {0, batchCount(b)});
}
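
// Usage sketch (editor's illustration; `lapackSolveInPlace` is a hypothetical
// stand-in for an actual LAPACK/MAGMA wrapper, not a function in this codebase):
//   batch_iterator_with_broadcasting<double>(a, b,
//       [&](double* a_working_ptr, double* b_working_ptr, int64_t a_linear_batch_idx) {
//         // a_working_ptr / b_working_ptr each point at one column-major matrix;
//         // a_linear_batch_idx can index per-matrix metadata (infos, ranks, ...).
//         lapackSolveInPlace(a_working_ptr, b_working_ptr, a.size(-2), b.size(-1));
//       });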

// Returns the epsilon value for floating types except half
inline double _get_epsilon(const ScalarType& sc_type) {
  switch (sc_type) {
    case at::ScalarType::Float:
      return static_cast<double>(std::numeric_limits<float>::epsilon());
    case at::ScalarType::Double:
      return std::numeric_limits<double>::epsilon();
    default:
      AT_ERROR("This function doesn't handle types other than float and double");
  }
}

// Validates input shapes and devices
// for linear solve methods (solve, cholesky_solve, lu_solve, triangular_solve)
inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A, const char* name) {
  TORCH_CHECK(self.device() == A.device(),
              "Expected b and A to be on the same device, but found b on ",
              self.device(), " and A on ", A.device(), " instead.");

  TORCH_CHECK(self.scalar_type() == A.scalar_type(),
              "Expected b and A to have the same dtype, but found b of type ",
              self.scalar_type(), " and A of type ", A.scalar_type(), " instead.");

  TORCH_CHECK(A.size(-1) == A.size(-2),
              "A must be batches of square matrices, "
              "but they are ", A.size(-2), " by ", A.size(-1), " matrices");

  TORCH_CHECK(A.size(-1) == self.size(-2),
              "Incompatible matrix sizes for ", name, ": each A "
              "matrix is ", A.size(-1), " by ", A.size(-1),
              " but each b matrix is ", self.size(-2), " by ", self.size(-1));
}

inline void checkFloatingOrComplex(const Tensor& t, const char* const f_name, const bool allow_low_precision_dtypes=true) {
  auto dtype = t.scalar_type();
  TORCH_CHECK((at::isFloatingType(dtype) || at::isComplexType(dtype)),
              f_name, ": Expected a floating point or complex tensor as input. Got ", dtype);
  if (!allow_low_precision_dtypes) {
    TORCH_CHECK(dtype == kFloat || dtype == kDouble || dtype == kComplexFloat || dtype == kComplexDouble,
                f_name, ": Low precision dtypes not supported. Got ", dtype);
  }
}


// Checks that all the Tensors in a TensorList have the same number of dimensions
inline void checkAllSameDim(TensorList tensors, int64_t dim) {
  for (auto &t : tensors) {
    TORCH_CHECK(t.dim() == dim, "Tensor dimension is ", t.dim(), ", expected ", dim, " instead.");
  }
}

inline std::tuple<std::vector<int64_t>, std::vector<int64_t>> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2) {
  // broadcast the batch dimensions of arg1 and arg2.
  IntArrayRef arg1_batch_sizes(arg1.sizes().data(), arg1.ndimension() - 2);
  IntArrayRef arg2_batch_sizes(arg2.sizes().data(), arg2.ndimension() - 2);
  std::vector<int64_t> expand_batch_portion = infer_size(arg1_batch_sizes, arg2_batch_sizes);

  std::vector<int64_t> arg1_expand_size({expand_batch_portion});
  arg1_expand_size.insert(arg1_expand_size.end(), { arg1.size(-2), arg1.size(-1) });

  std::vector<int64_t> arg2_expand_size({expand_batch_portion});
  arg2_expand_size.insert(arg2_expand_size.end(), { arg2.size(-2), arg2.size(-1) });
  return std::make_tuple(std::move(arg1_expand_size), std::move(arg2_expand_size));
}

inline std::tuple<Tensor,Tensor> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2, const char* name) {
  // If there's no name, we assume we don't want to check for errors
  if (name != nullptr) {
    linearSolveCheckInputs(arg1, arg2, name);
  }

  auto [arg1_expand_size, arg2_expand_size] = at::native::_linalg_broadcast_batch_dims(arg1, arg2);

  auto arg1_broadcasted  = arg1_expand_size == arg1.sizes() ? arg1 : arg1.expand(arg1_expand_size);
  auto arg2_broadcasted  = arg2_expand_size == arg2.sizes() ? arg2 : arg2.expand(arg2_expand_size);
  return std::make_tuple(arg1_broadcasted, arg2_broadcasted);
}
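
// Usage sketch (editor's illustration, not in the original source): broadcasting
// the RHS `b` against the square system matrix `A` before a solve,
//   Tensor A = at::rand({4, 3, 3});
//   Tensor b = at::rand({5, 1, 3, 2});
//   auto [b_bc, A_bc] = _linalg_broadcast_batch_dims(b, A, "solve");
//   // b_bc has sizes {5, 4, 3, 2} and A_bc has sizes {5, 4, 3, 3}; both are
//   // expanded views, no data is copied.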

inline std::vector<int64_t> broadcast_batch_size(const Tensor& t1, const Tensor& t2, int64_t n_batch_dims) {
  IntArrayRef t1_batch_sizes(t1.sizes().data(), n_batch_dims);
  IntArrayRef t2_batch_sizes(t2.sizes().data(), n_batch_dims);
  auto broadcasted_batch_sizes = infer_size(t1_batch_sizes, t2_batch_sizes);
  return broadcasted_batch_sizes;
}

// Return `self` with the given axes permuted to the end.
inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) {
  const std::vector<int64_t> a = axes.vec();
  const int64_t ndim = self.ndimension();
  std::vector<int64_t> perm;

  for (const auto i : c10::irange(ndim)) {
    auto it = std::find(a.begin(), a.end(), i);
    if (it == a.end()) {
       perm.push_back(i);
    }
  }
  for (auto i : a) {
    perm.push_back(i);
  }

  TORCH_CHECK((int64_t)perm.size() == ndim,
    "duplicate or invalid axis in 'dim' argument for tensor with ndim==", ndim);

  return self.permute(perm);
}

// Parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced)
inline std::tuple<bool, bool> _parse_qr_mode(c10::string_view mode) {
  bool compute_q;
  bool reduced;
  if (mode == "reduced") {
    compute_q = true;
    reduced = true;
  } else if (mode == "complete") {
    compute_q = true;
    reduced = false;
  } else if (mode == "r") {
    compute_q = false;
    reduced = true; // this is actually irrelevant in this mode
  } else {
      TORCH_CHECK(false, "qr received unrecognized mode '", mode,
                  "' but expected one of 'reduced' (default), 'r', or 'complete'");
  }
  return std::make_tuple(compute_q, reduced);
}

// Function to compute sizes, strides and the extra columns for the Q matrix in the QR Decomposition
inline std::tuple<DimVector, DimVector, int64_t> _compute_geometry_for_Q(
    const Tensor& input,
    bool reduced) {
  int64_t m = input.size(-2), n = input.size(-1);
  int64_t n_columns_q;

  // We need to compute the required size of Q based on the `reduced` option
  DimVector q_sizes(input.sizes());
  if (!reduced && m > n) {
    q_sizes[input.dim() - 1] = m;
    n_columns_q = m;
  } else {
    q_sizes[input.dim() - 1] = n;
    n_columns_q = std::min(m, n);
  }
  auto q_strides = batched_matrix_contiguous_strides(q_sizes, /*f-contig*/true);
  return std::make_tuple(q_sizes, q_strides, n_columns_q);
}

inline bool svd_uses_cusolver(const Tensor& A) {
  // if cusolver is available, it is used unconditionally
  return A.is_cuda()
         && at::globalContext().hasCuSOLVER()
         && at::globalContext().linalgPreferredBackend() != at::LinalgBackend::Magma;
}


// Function used instead of .to so that the original strides are retained
// (.to doesn't retain strides and makes the output tensor contiguous)
inline Tensor same_stride_to(const Tensor& original_tensor, const at::TensorOptions& options) {
  auto strided_to = at::empty_strided(original_tensor.sizes(),
                                      original_tensor.strides(),
                                      options);
  strided_to.copy_(original_tensor);
  return strided_to;
}

// Creates a dimension permutation array that can be given to `at::permute()`, which will shift
// the two specified dimensions to the end of a tensor, without changing the order of
// the other dimensions. `dim1` will be placed at the very end, and `dim0` will be
// placed just to the left of it.
//
// For instance, given a 4-D tensor, dimensions 1 and 3 can be shifted to the end by
// calling `create_dim_backshift_permutation(1, 3, 4)`. The resulting vector will
// be `vec(0, 2, 1, 3)`.
inline std::vector<int64_t> create_dim_backshift_permutation(int64_t dim0, int64_t dim1, int64_t ndim) {
  TORCH_CHECK(
    (dim0 != dim1) && (dim0 < ndim) && (dim0 >= 0) && (dim1 < ndim) && (dim1 >= 0),
    "duplicate or invalid dimensions");
  std::vector<int64_t> permutation(ndim);
  int64_t cur_permuted_dim = 0;
  for (const auto dim_ind : c10::irange(ndim)) {
    if ((dim_ind != dim0) && (dim_ind != dim1)) {
      permutation[cur_permuted_dim++] = dim_ind;
    }
  }
  permutation[cur_permuted_dim++] = dim0;
  permutation[cur_permuted_dim] = dim1;
  return permutation;
}

// Creates a dimension permutation array that can be given to `at::permute()`, which
// will reverse a given permutation.
// The reverse permutation array is created by swapping the indices and their
// associated values from the given permutation array.
inline std::vector<int64_t> create_reverse_permutation(std::vector<int64_t> permutation) {
  int64_t ndim = permutation.size();
  std::vector<int64_t> reverse_permutation(ndim);
  for (const auto dim_ind : c10::irange(ndim)) {
    reverse_permutation[permutation[dim_ind]] = dim_ind;
  }
  return reverse_permutation;
}
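
// Example (editor's note, not in the original source):
//   auto perm = create_dim_backshift_permutation(0, 2, 4);  // {1, 3, 0, 2}
//   auto inv  = create_reverse_permutation(perm);           // {2, 0, 3, 1}
//   // t.permute(perm).permute(inv) restores the original dimension order of t.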

// Compute R-work array size for MAGMA/LAPACK cgesdd/zgesdd
// See https://github.com/Reference-LAPACK/lapack/blob/122506cd8b6ce050a200920c3d4c0b153b150fd8/SRC/cgesdd.f#L186
inline int64_t computeLRWorkDim(const char jobz, int64_t m, int64_t n) {
  auto mn = std::min(m, n);
  auto mx = std::max(m, n);
  if (jobz == 'N') {
#ifdef __APPLE__
    // According to `vecLib.framework/Headers/clapack.h`, Accelerate.framework is based on LAPACK 3.2.1
    return 7 * mn;
#else
    // This setting is valid for LAPACK 3.6+
    return 5 * mn;
#endif
  }
  if (mx > 10 * mn) {
    return 5 * mn * mn + 5 * mn;
  }
  return std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn);
}

// This function checks whether the uplo argument input is valid
// Allowed strings are "u", "U", "l", "L"
inline void checkUplo(const c10::string_view uplo) {
  // To use std::toupper safely with plain chars (or signed chars), the argument should first be converted to unsigned char
  char uplo_uppercase = static_cast<char>(std::toupper(static_cast<unsigned char>(uplo[0])));
  TORCH_CHECK(uplo.size() == 1 && (uplo_uppercase == 'U' || uplo_uppercase == 'L'),
    "Expected UPLO argument to be 'L' or 'U', but got ", uplo);
}

inline void checkSameDevice(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
  TORCH_CHECK(
      result.device() == input.device(),
      fn_name,
      ": Expected ", result_name, " and input tensors to be on the same device, but got ",
      result_name, " on ", result.device(), " and input on ", input.device());
}

// Check the dtype of result and input tensors (for _out variants).
// Most linear algebra functions have the same dtype for input and output
// (either floating or complex type input), so we can check whether input's dtype can be cast to result's dtype.
// According to https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch
// c10::canCast is used for checking the "safe copy" dtype requirements.
inline void checkLinalgCompatibleDtype(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
  bool can_cast = c10::canCast(input.scalar_type(), result.scalar_type());
  TORCH_CHECK(
      can_cast,
      fn_name,
      ": Expected ", result_name, " to be safely castable from ", input.scalar_type(), " dtype, but got ",
      result_name, " with dtype ", result.scalar_type());
}

// Alternatively, we can check whether the specific expected output type (result_type) can be safely cast to the out tensor dtype (out_type)
inline void checkLinalgCompatibleDtype(const std::string& fn_name, ScalarType out_type, ScalarType result_type, const std::string& out_name = "result") {
  bool can_cast = c10::canCast(result_type, out_type);
  TORCH_CHECK(
      can_cast,
      fn_name,
      ": Expected ", out_name, " to be safely castable from ", result_type, " dtype, but got ",
      out_name, " with dtype ", out_type);
}

inline void checkNotComplexTolerance(const Tensor& tol, const c10::string_view f_name, const c10::string_view tol_name) {
  TORCH_CHECK(!at::isComplexType(tol.scalar_type()),
              f_name, ": ", tol_name, " tensor of complex type is not supported. Got ", tol.scalar_type());
}

/*
  Two types of 'other' tensors are supported when solving
  a system of linear equations matmul(input, x) = other:
  * 1-dimensional (1D) tensor or batch of 1D tensors (vector case)
  * 2-dimensional (2D) tensor or batch of 2D tensors (matrix case).
  The original torch.solve supported only the matrix case, while NumPy works for both cases.
  For the batched input we need to be able to distinguish them.
  Let input.shape = (batch_dimensions, m, n); then 'other' is of vector type if other.shape == (batch_dimensions, m).
  This rule is compatible with NumPy, see https://github.com/numpy/numpy/blob/v1.20.0/numpy/linalg/linalg.py#L384-L389
*/
inline bool linalg_solve_is_vector_rhs(const Tensor& input, const Tensor& other) {
  auto expected_batched_rhs_shape = SymIntArrayRef(input.sym_sizes().data(), input.dim() - 1); // input.shape[:-1]
  bool vector_case = other.dim() == 1 || (input.dim() - 1 == other.dim() && other.sym_sizes().equals(expected_batched_rhs_shape));
  return vector_case;
}
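
// Example (editor's note, not in the original source): with input of shape (2, 3, 3),
//   other of shape (3,)      -> vector case (other.dim() == 1)
//   other of shape (2, 3)    -> vector case (matches input.shape[:-1])
//   other of shape (2, 3, 1) -> matrix case (a batch of 3x1 matrices)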

/*
  Computes linear indices for a tensor with original_shape to access its elements
  as if it were a materialized broadcast tensor.
*/
inline Tensor get_linear_indices(int64_t numel, IntArrayRef original_shape, IntArrayRef broadcast_shape) {
  TensorOptions options = at::TensorOptions().dtype(at::kLong).device(at::kCPU);
  return at::arange(numel, options).view(original_shape).broadcast_to(broadcast_shape).contiguous();
}

class BroadcastLinearIndices {
 private:
  Tensor linear_indices_;
  bool is_broadcasting_;

 public:
  BroadcastLinearIndices(
      int64_t numel,
      IntArrayRef original_shape,
      IntArrayRef broadcast_shape) : is_broadcasting_(!original_shape.equals(broadcast_shape)) {
    // The assumption is that the broadcast_shape is a materialized broadcast
    // shape of the original_shape. We need to compute the linear indices
    // compatible with the original_shape to access the elements in the original
    // tensor corresponding to the broadcast tensor.
    if (is_broadcasting_) {
      linear_indices_ =
          get_linear_indices(numel, original_shape, broadcast_shape);
    }
  }
  int64_t operator()(int64_t broadcast_linear_index) {
    return is_broadcasting_
        ? linear_indices_.data_ptr<int64_t>()[broadcast_linear_index]
        : broadcast_linear_index;
  }
};
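
// Usage sketch (editor's illustration, not in the original source): mapping the
// batch index of a broadcast tensor back to the corresponding batch of the
// original tensor,
//   // original batch shape {1, 3} broadcast to {2, 3}
//   BroadcastLinearIndices idx(/*numel=*/3, /*original_shape=*/{1, 3}, /*broadcast_shape=*/{2, 3});
//   int64_t orig = idx(4);  // broadcast batch 4 maps to original batch 1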

inline bool is_blas_compatible_column_major_order(const Tensor& input) {
  IntArrayRef input_strides = input.strides();
  IntArrayRef input_sizes = input.sizes();
  auto ndim = input.dim();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim >= 2);
  if (ndim > 3) {
    return input.transpose(-2, -1).is_contiguous();
  }
  auto leading_dimension = input_strides[ndim - 1];
  auto rows = input_sizes[ndim - 2];
  bool batch_stride_compatible = true;
  if (ndim == 3) {
    auto cols = input_sizes[ndim - 1];
    batch_stride_compatible =
        input_strides[ndim - 3] >= leading_dimension * cols;
  }
  return (input_strides[ndim - 2] == 1) &&
      (leading_dimension >= std::max<int64_t>(1, rows)) &&
      batch_stride_compatible;
}
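
// Example (editor's note, not in the original source): a 3-d tensor with sizes
// {B, m, n} and strides {lda * n, 1, lda} passes this check whenever
// lda >= max(1, m): each matrix is column-major with leading dimension lda and
// the batch stride covers at least one full matrix, so the tensor can be handed
// to batched BLAS/LAPACK routines without copying.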

inline bool is_blas_compatible_row_major_order(const Tensor& input) {
  IntArrayRef input_strides = input.strides();
  IntArrayRef input_sizes = input.sizes();
  auto ndim = input.dim();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim >= 2);
  if (ndim > 3) {
    return input.is_contiguous();
  }
  auto leading_dimension = input_strides[ndim - 2];
  auto cols = input_sizes[ndim - 1];
  bool batch_stride_compatible = true;
  if (ndim == 3) {
    auto rows = input_sizes[ndim - 2];
    batch_stride_compatible =
        input_strides[ndim - 3] >= leading_dimension * rows;
  }
  return (input_strides[ndim - 1] == 1) &&
      (leading_dimension >= std::max<int64_t>(1, cols)) &&
      batch_stride_compatible;
}

}  // namespace at::native