#pragma once

#include <c10/core/ScalarType.h>
#include <c10/util/irange.h>
#include <c10/util/Exception.h>
#include <c10/util/strides.h>
#include <ATen/core/Tensor.h>
#include <ATen/ExpandUtils.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TransposeType.h>
#include <limits>
#include <type_traits>
#include <sstream>
#include <cstring>
#include <cctype>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/zeros.h>
#endif

namespace at::native {

inline c10::MaybeOwned<Tensor> expect_resolved_conj(const Tensor& tensor) {
  if (tensor.is_conj()) {
    return c10::MaybeOwned<Tensor>::owned(tensor.resolve_conj());
  } else {
    return c10::MaybeOwned<Tensor>::borrowed(tensor);
  }
}

inline DimVector batched_matrix_contiguous_strides(
    const IntArrayRef sizes,
    const bool f_contig = false) {
  // f_contig chooses between the strides of a batch of Fortran (F-contiguous)
  // and C-contiguous matrices
  auto strides = c10::contiguous_strides(sizes);
  auto dim = strides.size();

  if (f_contig && dim >= 2) {
    // Fix the strides of the last two dimensions, so that we return
    // C-contiguous batches of F-contiguous matrices.
    strides[dim - 1] = std::max(sizes[dim - 2], static_cast<int64_t>(1));
    strides[dim - 2] = 1;
  }
  return strides;
}
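
// Illustrative example (shapes chosen for this comment only): for sizes {2, 3, 4},
//   batched_matrix_contiguous_strides({2, 3, 4})                    -> {12, 4, 1}  (C-contiguous matrices)
//   batched_matrix_contiguous_strides({2, 3, 4}, /*f_contig=*/true) -> {12, 1, 3}  (F-contiguous matrices,
//                                                                                   C-contiguous batch)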

/*
 * Clones a Tensor so that the following conditions hold:
 * If we think of a Tensor as having size (B, M, N), where B is any number
 * of batch dimensions, then:
 * - Each (M, N) matrix is in column major form
 * - Let Tensor P have size (B, M, N) and Q have size (B, M', N').
 *   Then when laid out in memory, the M by N matrix starting at
 *   P.data_ptr()[B * M * N] is of the same corresponding batch as the M' by N'
 *   matrix starting at Q.data_ptr()[B * M' * N'].
 */
inline Tensor cloneBatchedColumnMajor(const Tensor& src) {
  // If src is already in batched column major format, then
  // this will be efficient (no reordering of the data will occur)
  // because the first transpose will make the tensor contiguous,
  // and cloning a contiguous tensor is fast.
  auto result = src.mT().clone(at::MemoryFormat::Contiguous);
  result.transpose_(-2, -1);
  return result;
}
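
// For illustration (the tensor `src` below is hypothetical): given `src` of
// shape (2, 3, 4), cloneBatchedColumnMajor(src) returns a tensor of the same
// shape with strides (12, 1, 3), i.e. each 3x4 matrix is column major while
// the batch dimension stays contiguous.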

/*
 * contig chooses between C-contig (true) and F-contig (false)
 */
inline c10::MaybeOwned<Tensor> borrow_else_clone(const bool cond, const Tensor& borrow, const Tensor& clone, const bool contig) {
  return cond ? c10::MaybeOwned<Tensor>::borrowed(borrow)
              : c10::MaybeOwned<Tensor>::owned(contig ? clone.clone(MemoryFormat::Contiguous)
                                                      : cloneBatchedColumnMajor(clone));
}

/*
 * This method is designed to be a faster alternative to
 * `cloneBatchedColumnMajor` with some additional features,
 * namely:
 * 1. It uses `copy` instead of `clone`, which could be much faster.
 * 2. The `nrows` parameter is used to create inputs with a number of rows larger
 *    than the original input, which is required for some LAPACK/MAGMA methods.
 * 3. `desired_batch_sizes` is used to create copies with the batch size
 *    which is either the original batch size of the input, or its larger
 *    broadcasted shape.
 */
inline Tensor copyBatchedColumnMajor(const Tensor& src, int64_t nrows = -1,
    at::OptionalIntArrayRef desired_batch_sizes = std::nullopt) {
  nrows = (nrows == -1) ? src.size(-2) : nrows;
  auto copy_sizes = desired_batch_sizes.has_value()
    ? desired_batch_sizes.value().vec()
    : IntArrayRef(src.sizes().data(), src.dim() - 2).vec();
  copy_sizes.insert(copy_sizes.end(), {nrows, src.size(-1)});
  const auto copy_strides = batched_matrix_contiguous_strides(copy_sizes, /*f-contig*/true);
  auto copy = at::empty_strided(copy_sizes, copy_strides, src.options());
  copy.narrow(-2, 0, src.size(-2)).copy_(src);
  return copy;
}
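
// Worked example under assumed shapes (illustration only): for `src` of shape
// (2, 3, 4) and nrows = 5, the result has shape (2, 5, 4) with F-contiguous
// strides (20, 1, 5); the first 3 rows of each 5x4 matrix hold a copy of the
// corresponding matrix of `src`, and the remaining rows are uninitialized.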

/*
 * Given batches of matrices with arbitrary batch dim,
 * computes the number of batches.
 */
inline int64_t batchCount(const Tensor& batched_matrices) {
  int64_t result = 1;
  for (int64_t i = 0; i < batched_matrices.ndimension() - 2; i++) {
    result *= batched_matrices.size(i);
  }
  return result;
}

// Computes the number of elements of a matrix in a batched matrix tensor
inline int64_t matrixStride(const Tensor& batched_matrices) {
  return batched_matrices.size(-1) * batched_matrices.size(-2);
}
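
// E.g. for a tensor of shape (2, 3, 4, 5): batchCount == 2 * 3 == 6 and
// matrixStride == 4 * 5 == 20.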

// Validates input shapes for operations on batches of square matrices (inverse, cholesky, symeig, eig)
inline void checkIsMatrix(const Tensor& A, const char* const f_name, const char* const arg_name = "A") {
  TORCH_CHECK(A.dim() >= 2, f_name, ": The input tensor ", arg_name, " must have at least 2 dimensions.");
}

inline void squareCheckInputs(const Tensor& self, const char* const f_name, const char* const arg_name = "A") {
  checkIsMatrix(self, f_name, arg_name);
  TORCH_CHECK(self.sym_size(-1) == self.sym_size(-2),
              f_name,
              ": ", arg_name, " must be batches of square matrices, "
              "but they are ", self.sym_size(-2), " by ", self.sym_size(-1), " matrices");
}

inline void checkInputsSolver(const Tensor& A,
                              const Tensor& B,
                              const bool left,
                              const char* const f_name) {
  squareCheckInputs(A, f_name, "A");
  checkIsMatrix(B, f_name, "B");
  TORCH_CHECK(left ? A.size(-2) == B.size(-2) : A.size(-1) == B.size(-1),
              f_name, ": Incompatible shapes of A and B for the equation ",
              left ? "AX = B" : "XA = B",
              " (", A.size(-2), "x", A.size(-1), " and ", B.size(-2), "x", B.size(-1), ")");
}

inline bool is_row_or_column_contiguous(const Tensor& t) {
  // This could be made more general, similar to how it's checked in matmul, which would allow us
  // to elide the copy with strides such as (6, 12, 1, 3) or (3, 1, 9), but this is quite tricky.
  // We choose to be conservative for simplicity
  return t.is_contiguous() || t.transpose(-2, -1).is_contiguous();
}

inline TransposeType to_transpose_type(const bool contig, const bool conj) {
  if (conj) {
    if (contig) { TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); }
    else { return TransposeType::ConjTranspose; }
  } else {
    if (contig) { return TransposeType::NoTranspose; }
    else { return TransposeType::Transpose; }
  }
}


// This function is designed to be used with linear algebra methods that minimize
// L(ax - b) = 0, where L is generally the identity map (`solve`, for example)
// or the L2 norm (`lstsq`).
// It is expected that `a` and `b` are contiguous tensors of column-major matrices
// (so that a.view({-1, a.size(-2), a.size(-1)}) succeeds, same for `b`),
// with the following additional properties:
//
// 1. a.dim() == b.dim()
// 2. a.shape[:-2] broadcasts over b.shape[:-2]
// 3. a.size(i) <= b.size(i) for i=0,..., a.dim() - 3 (only for batch dimensions)
//
// MAGMA/LAPACK modify tensor `a` in-place, and the main goal of this method
// is to be memory efficient, which means that if there exists an index i such that
// a.shape[i] < b.shape[i], 0 <= i <= a.dim() - 3,
// then instead of materializing copies of `a` in the broadcasted shape, we keep
// a buffer copy of `a` along with flags that check whether specific batch dimension
// indices for `a` were already accessed. If they were, we copy the data from the buffer
// into `a`. The number of copies does not exceed
// prod(max(a.shape[:-2], b.shape[:-2]) - a.shape[:-2] + 1)
// and this value is attained by tensors with non-empty batch dimensions.
//
// func_t `f` is a callable that is being supplied with
// scalar_t* a_working_ptr, scalar_t* b_working_ptr, int64_t a_linear_batch_idx.
// a_working_ptr and b_working_ptr can directly be passed to LAPACK/MAGMA routines,
// and a_linear_batch_idx is an index in the 3d representation which corresponds to
// the memory a_working_ptr points to, in other words:
// a_working_ptr == a.view({-1, a.size(-2), a.size(-1)}).select(0, a_linear_batch_idx).data_ptr<scalar_t>();
// a_linear_batch_idx is useful to store metadata related to `a`, such as, for example,
// its rank or singular values (see linalg_lstsq).
template<typename scalar_t, typename func_t>
void batch_iterator_with_broadcasting(const Tensor& a, const Tensor& b, const func_t& f) {
  IntArrayRef a_batch_sizes(a.sizes().data(), a.dim() - 2);
  IntArrayRef b_batch_sizes(b.sizes().data(), b.dim() - 2);

  auto a_linear_batch_idx = at::arange(batchCount(a)).view(a_batch_sizes);
  auto b_linear_batch_idx = at::arange(batchCount(b)).view(b_batch_sizes);

  TensorIterator iter = TensorIteratorConfig()
    .set_check_mem_overlap(false)
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .add_output(b_linear_batch_idx)
    .add_input(a_linear_batch_idx)
    .build();

  auto m = a.size(-2);
  auto n = a.size(-1);
  auto a_3d = a.view({batchCount(a), m, n});
  auto b_3d = b.view({batchCount(b), b.size(-2), b.size(-1)});

  auto a_broadcasts_over_b = (a_batch_sizes != b_batch_sizes);
  Tensor a_buffer, a_was_accessed, a_buffer_3d;
  std::function<void(int64_t)> check_if_copy_needed_for_a
    = [](int64_t /*a_curr_linear_batch_idx*/){};
  if (a_broadcasts_over_b) {
    a_buffer = at::empty_strided(a.sizes(), a.strides(), a.options())
      .copy_(a);
    a_was_accessed = at::zeros(batchCount(a), at::kBool);
    a_buffer_3d = a_buffer.view({batchCount(a), m, n});
    check_if_copy_needed_for_a = [&](int64_t a_curr_linear_batch_idx) {
      auto* a_was_accessed_flag = a_was_accessed
        .select(0, a_curr_linear_batch_idx)
        .data_ptr<bool>();
      if (!(*a_was_accessed_flag)) {
        *a_was_accessed_flag = true;
      }
      else {
        a_3d.select(0, a_curr_linear_batch_idx)
          .copy_(a_buffer_3d.select(0, a_curr_linear_batch_idx));
      }
    };
  }

  auto loop = [&](char** data, const int64_t* strides, int64_t nelems) {
    auto* b_batch_idx_ptr = data[0];
    auto* a_batch_idx_ptr = data[1];

    for (const auto elem C10_UNUSED : c10::irange(nelems)) {
      auto b_curr_linear_batch_idx = *reinterpret_cast<int64_t*>(b_batch_idx_ptr);
      auto a_curr_linear_batch_idx = *reinterpret_cast<int64_t*>(a_batch_idx_ptr);

      check_if_copy_needed_for_a(a_curr_linear_batch_idx);

      auto* a_working_ptr = a_3d.select(0, a_curr_linear_batch_idx)
        .data_ptr<scalar_t>();
      auto* b_working_ptr = b_3d.select(0, b_curr_linear_batch_idx)
        .data_ptr<scalar_t>();
      f(a_working_ptr, b_working_ptr, a_curr_linear_batch_idx);

      b_batch_idx_ptr += strides[0];
      a_batch_idx_ptr += strides[1];
    }
  };
  iter.serial_for_each(loop, {0, batchCount(b)});
}
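
// Hypothetical usage sketch (`per_matrix_solve` and `infos` are placeholders,
// not part of ATen): for suitably prepared column-major `a` and `b`, each
// broadcasted (A, B) pair can be handed to a per-matrix routine like so:
//
//   batch_iterator_with_broadcasting<double>(a, b,
//       [&](double* a_working_ptr, double* b_working_ptr, int64_t a_linear_batch_idx) {
//         // a_working_ptr / b_working_ptr each point to one column-major matrix;
//         // a_linear_batch_idx can index per-batch metadata, e.g. infos[a_linear_batch_idx].
//         per_matrix_solve(a_working_ptr, b_working_ptr);
//       });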

// Returns the epsilon value for floating types except half
inline double _get_epsilon(const ScalarType& sc_type) {
  switch (sc_type) {
    case at::ScalarType::Float:
      return static_cast<double>(std::numeric_limits<float>::epsilon());
    case at::ScalarType::Double:
      return std::numeric_limits<double>::epsilon();
    default:
      AT_ERROR("This function doesn't handle types other than float and double");
  }
}

// Validates input shapes and devices
// for linear solve methods (solve, cholesky_solve, lu_solve, triangular_solve)
inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A, const char* name) {
  TORCH_CHECK(self.device() == A.device(),
              "Expected b and A to be on the same device, but found b on ",
              self.device(), " and A on ", A.device(), " instead.");

  TORCH_CHECK(self.scalar_type() == A.scalar_type(),
              "Expected b and A to have the same dtype, but found b of type ",
              self.scalar_type(), " and A of type ", A.scalar_type(), " instead.");

  TORCH_CHECK(A.size(-1) == A.size(-2),
              "A must be batches of square matrices, "
              "but they are ", A.size(-2), " by ", A.size(-1), " matrices");

  TORCH_CHECK(A.size(-1) == self.size(-2),
              "Incompatible matrix sizes for ", name, ": each A "
              "matrix is ", A.size(-1), " by ", A.size(-1),
              " but each b matrix is ", self.size(-2), " by ", self.size(-1));
}

inline void checkFloatingOrComplex(const Tensor& t, const char* const f_name, const bool allow_low_precision_dtypes=true) {
  auto dtype = t.scalar_type();
  TORCH_CHECK((at::isFloatingType(dtype) || at::isComplexType(dtype)),
              f_name, ": Expected a floating point or complex tensor as input. Got ", dtype);
  if (!allow_low_precision_dtypes) {
    TORCH_CHECK(dtype == kFloat || dtype == kDouble || dtype == kComplexFloat || dtype == kComplexDouble,
                f_name, ": Low precision dtypes not supported. Got ", dtype);
  }
}


// Checks that all the Tensors in a TensorList have the same number of dimensions
inline void checkAllSameDim(TensorList tensors, int64_t dim) {
  for (auto &t : tensors) {
    TORCH_CHECK(t.dim() == dim, "Tensor dimension is ", t.dim(), ", expected ", dim, " instead.");
  }
}

inline std::tuple<std::vector<int64_t>, std::vector<int64_t>> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2) {
  // broadcast the batch dimensions of arg1 and arg2.
  IntArrayRef arg1_batch_sizes(arg1.sizes().data(), arg1.ndimension() - 2);
  IntArrayRef arg2_batch_sizes(arg2.sizes().data(), arg2.ndimension() - 2);
  std::vector<int64_t> expand_batch_portion = infer_size(arg1_batch_sizes, arg2_batch_sizes);

  std::vector<int64_t> arg1_expand_size({expand_batch_portion});
  arg1_expand_size.insert(arg1_expand_size.end(), { arg1.size(-2), arg1.size(-1) });

  std::vector<int64_t> arg2_expand_size({expand_batch_portion});
  arg2_expand_size.insert(arg2_expand_size.end(), { arg2.size(-2), arg2.size(-1) });
  return std::make_tuple(std::move(arg1_expand_size), std::move(arg2_expand_size));
}

inline std::tuple<Tensor,Tensor> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2, const char* name) {
  // If there's no name we assume we don't want to check the errors
  if (name != nullptr) {
    linearSolveCheckInputs(arg1, arg2, name);
  }

  auto [arg1_expand_size, arg2_expand_size] = at::native::_linalg_broadcast_batch_dims(arg1, arg2);

  auto arg1_broadcasted = arg1_expand_size == arg1.sizes() ? arg1 : arg1.expand(arg1_expand_size);
  auto arg2_broadcasted = arg2_expand_size == arg2.sizes() ? arg2 : arg2.expand(arg2_expand_size);
  return std::make_tuple(arg1_broadcasted, arg2_broadcasted);
}
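
// For illustration: with arg1 of shape (2, 1, 3, 3) and arg2 of shape (5, 3, 4),
// the broadcast batch shape is (2, 5), so the expanded sizes are (2, 5, 3, 3)
// and (2, 5, 3, 4) respectively.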

inline std::vector<int64_t> broadcast_batch_size(const Tensor& t1, const Tensor& t2, int64_t n_batch_dims) {
  IntArrayRef t1_batch_sizes(t1.sizes().data(), n_batch_dims);
  IntArrayRef t2_batch_sizes(t2.sizes().data(), n_batch_dims);
  auto broadcasted_batch_sizes = infer_size(t1_batch_sizes, t2_batch_sizes);
  return broadcasted_batch_sizes;
}

// Return a permutation with the given axes moved to the end.
inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) {
  const std::vector<int64_t> a = axes.vec();
  const int64_t ndim = self.ndimension();
  std::vector<int64_t> perm;

  for (const auto i : c10::irange(ndim)) {
    auto it = std::find(a.begin(), a.end(), i);
    if (it == a.end()) {
      perm.push_back(i);
    }
  }
  for (auto i : a) {
    perm.push_back(i);
  }

  TORCH_CHECK((int64_t)perm.size() == ndim,
    "duplicate or invalid axis in 'dim' argument for tensor with ndim==", ndim);

  return self.permute(perm);
}

// parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced)
inline std::tuple<bool, bool> _parse_qr_mode(c10::string_view mode) {
  bool compute_q;
  bool reduced;
  if (mode == "reduced") {
    compute_q = true;
    reduced = true;
  } else if (mode == "complete") {
    compute_q = true;
    reduced = false;
  } else if (mode == "r") {
    compute_q = false;
    reduced = true; // this is actually irrelevant in this mode
  } else {
    TORCH_CHECK(false, "qr received unrecognized mode '", mode,
                "' but expected one of 'reduced' (default), 'r', or 'complete'");
  }
  return std::make_tuple(compute_q, reduced);
}

// Function to compute sizes, strides and the extra columns for the Q matrix in the QR Decomposition
inline std::tuple<DimVector, DimVector, int64_t> _compute_geometry_for_Q(
    const Tensor& input,
    bool reduced) {
  int64_t m = input.size(-2), n = input.size(-1);
  int64_t n_columns_q;

  // We need to compute the required size of Q based on the `reduced` option
  DimVector q_sizes(input.sizes());
  if (!reduced && m > n) {
    q_sizes[input.dim() - 1] = m;
    n_columns_q = m;
  } else {
    q_sizes[input.dim() - 1] = n;
    n_columns_q = std::min(m, n);
  }
  auto q_strides = batched_matrix_contiguous_strides(q_sizes, /*f-contig*/true);
  return std::make_tuple(q_sizes, q_strides, n_columns_q);
}

inline bool svd_uses_cusolver(const Tensor& A) {
  // if cusolver is available, it is used unconditionally
  return A.is_cuda()
         && at::globalContext().hasCuSOLVER()
         && at::globalContext().linalgPreferredBackend() != at::LinalgBackend::Magma;
}


// Function used instead of .to so that the original strides are retained
// (.to doesn't retain strides and makes the output tensor contiguous)
inline Tensor same_stride_to(const Tensor& original_tensor, const at::TensorOptions& options) {
  auto strided_to = at::empty_strided(original_tensor.sizes(),
                                      original_tensor.strides(),
                                      options);
  strided_to.copy_(original_tensor);
  return strided_to;
}

// Creates a dimension permutation array that can be given to `at::permute()`, which will shift
// the two specified dimensions to the end of a tensor, without changing the order of
// the other dimensions. `dim1` will be placed at the very end, and `dim0` will be
// placed just to the left of it.
//
// For instance, given a 4-D tensor, dimensions 1 and 3 can be shifted to the end by
// calling `create_dim_backshift_permutation(1, 3, 4)`. The resulting vector will
// be `vec(0, 2, 1, 3)`.
inline std::vector<int64_t> create_dim_backshift_permutation(int64_t dim0, int64_t dim1, int64_t ndim) {
  TORCH_CHECK(
    (dim0 != dim1) && (dim0 < ndim) && (dim0 >= 0) && (dim1 < ndim) && (dim1 >= 0),
    "duplicate or invalid dimensions");
  std::vector<int64_t> permutation(ndim);
  int64_t cur_permuted_dim = 0;
  for (const auto dim_ind : c10::irange(ndim)) {
    if ((dim_ind != dim0) && (dim_ind != dim1)) {
      permutation[cur_permuted_dim++] = dim_ind;
    }
  }
  permutation[cur_permuted_dim++] = dim0;
  permutation[cur_permuted_dim] = dim1;
  return permutation;
}

// Creates a dimension permutation array that can be given to `at::permute()`, which
// will reverse a given permutation.
// The reverse permutation array is created by swapping the indices and their
// associated values from the given permutation array.
inline std::vector<int64_t> create_reverse_permutation(std::vector<int64_t> permutation) {
  int64_t ndim = permutation.size();
  std::vector<int64_t> reverse_permutation(ndim);
  for (const auto dim_ind : c10::irange(ndim)) {
    reverse_permutation[permutation[dim_ind]] = dim_ind;
  }
  return reverse_permutation;
}

// Compute R-work array size for MAGMA/LAPACK cgesdd/zgesdd
// See https://github.com/Reference-LAPACK/lapack/blob/122506cd8b6ce050a200920c3d4c0b153b150fd8/SRC/cgesdd.f#L186
inline int64_t computeLRWorkDim(const char jobz, int64_t m, int64_t n) {
  auto mn = std::min(m, n);
  auto mx = std::max(m, n);
  if (jobz == 'N') {
#ifdef __APPLE__
    // According to `vecLib.framework/Headers/clapack.h` Accelerate.framework is based on LAPACK 3.2.1
    return 7 * mn;
#else
    // This setting is valid for LAPACK 3.6+
    return 5 * mn;
#endif
  }
  if (mx > 10 * mn) {
    return 5 * mn * mn + 5 * mn;
  }
  return std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn);
}
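
// Worked examples (non-Apple builds, values for this comment only):
//   jobz == 'N', m = 8,   n = 3  ->  5 * 3 = 15
//   jobz == 'S', m = 100, n = 5  ->  5 * 5 * 5 + 5 * 5 = 150   (mx > 10 * mn)
//   jobz == 'S', m = 8,   n = 5  ->  max(150, 135) = 150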

// This function checks whether the uplo argument input is valid
// Allowed strings are "u", "U", "l", "L"
inline void checkUplo(const c10::string_view uplo) {
  // To use std::toupper safely with plain chars (or signed chars), the argument should first be converted to unsigned char
  char uplo_uppercase = static_cast<char>(std::toupper(static_cast<unsigned char>(uplo[0])));
  TORCH_CHECK(uplo.size() == 1 && (uplo_uppercase == 'U' || uplo_uppercase == 'L'),
              "Expected UPLO argument to be 'L' or 'U', but got ", uplo);
}

inline void checkSameDevice(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
  TORCH_CHECK(
      result.device() == input.device(),
      fn_name,
      ": Expected ", result_name, " and input tensors to be on the same device, but got ",
      result_name, " on ", result.device(), " and input on ", input.device());
}

// Check the dtype of result and input tensors (for _out variants).
// Most linear algebra functions have the same dtype for input and output
// (either floating or complex type input), so we can check whether input's dtype can be casted to result's dtype.
// According to https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch
// c10::canCast is used for checking the "safe copy" dtype requirements.
inline void checkLinalgCompatibleDtype(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
  bool can_cast = c10::canCast(input.scalar_type(), result.scalar_type());
  TORCH_CHECK(
      can_cast,
      fn_name,
      ": Expected ", result_name, " to be safely castable from ", input.scalar_type(), " dtype, but got ",
      result_name, " with dtype ", result.scalar_type());
}

// Alternatively, we can check whether the specific expected output type (result_type) can be safely casted to out tensor dtype (out_type)
inline void checkLinalgCompatibleDtype(const std::string& fn_name, ScalarType out_type, ScalarType result_type, const std::string& out_name = "result") {
  bool can_cast = c10::canCast(result_type, out_type);
  TORCH_CHECK(
      can_cast,
      fn_name,
      ": Expected ", out_name, " to be safely castable from ", result_type, " dtype, but got ",
      out_name, " with dtype ", out_type);
}

inline void checkNotComplexTolerance(const Tensor& tol, const c10::string_view f_name, const c10::string_view tol_name) {
  TORCH_CHECK(!at::isComplexType(tol.scalar_type()),
              f_name, ": ", tol_name, " tensor of complex type is not supported. Got ", tol.scalar_type());
}

/*
  Two types of 'other' tensors are supported when solving
  a system of linear equations matmul(input, x) = other:
  * 1-dimensional (1D) tensor or batch of 1D tensors (vector case)
  * 2-dimensional (2D) tensor or batch of 2D tensors (matrix case).
  The original torch.solve supported only the matrix case, while NumPy works for both cases.
  For the batched input we need to be able to distinguish them.
  Let input.shape = (batch_dimensions, m, n), then 'other' is of vector type if other.shape == (batch_dimensions, m).
  This rule is compatible with NumPy, see https://github.com/numpy/numpy/blob/v1.20.0/numpy/linalg/linalg.py#L384-L389
*/
inline bool linalg_solve_is_vector_rhs(const Tensor& input, const Tensor& other) {
  auto expected_batched_rhs_shape = SymIntArrayRef(input.sym_sizes().data(), input.dim() - 1); // input.shape[:-1]
  bool vector_case = other.dim() == 1 || (input.dim() - 1 == other.dim() && other.sym_sizes().equals(expected_batched_rhs_shape));
  return vector_case;
}
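
// For example, with input of shape (2, 3, 3): other of shape (2, 3) is treated
// as a batch of vectors (other.shape == input.shape[:-1]), whereas other of
// shape (2, 3, 1) is treated as a batch of 3x1 matrices.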

/*
  Computes linear indices for a tensor with original_shape to access its elements as if it were a materialized broadcast tensor.
*/
inline Tensor get_linear_indices(int64_t numel, IntArrayRef original_shape, IntArrayRef broadcast_shape) {
  TensorOptions options = at::TensorOptions().dtype(at::kLong).device(at::kCPU);
  return at::arange(numel, options).view(original_shape).broadcast_to(broadcast_shape).contiguous();
}

class BroadcastLinearIndices {
 private:
  Tensor linear_indices_;
  bool is_broadcasting_;

 public:
  BroadcastLinearIndices(
      int64_t numel,
      IntArrayRef original_shape,
      IntArrayRef broadcast_shape) : is_broadcasting_(!original_shape.equals(broadcast_shape)) {
    // The assumption is that the broadcast_shape is a materialized broadcast
    // shape of the original_shape. We need to compute the linear indices
    // compatible with the original_shape to access the elements in the original
    // tensor corresponding to the broadcast tensor.
    if (is_broadcasting_) {
      linear_indices_ =
          get_linear_indices(numel, original_shape, broadcast_shape);
    }
  }
  int64_t operator()(int64_t broadcast_linear_index) {
    return is_broadcasting_
        ? linear_indices_.data_ptr<int64_t>()[broadcast_linear_index]
        : broadcast_linear_index;
  }
};
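
// Illustrative example: for original_shape = (1, 3) broadcast to (2, 3),
// get_linear_indices(3, {1, 3}, {2, 3}) materializes [0, 1, 2, 0, 1, 2], so
// BroadcastLinearIndices maps broadcast linear index 4 back to original index 1.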

inline bool is_blas_compatible_column_major_order(const Tensor& input) {
  IntArrayRef input_strides = input.strides();
  IntArrayRef input_sizes = input.sizes();
  auto ndim = input.dim();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim >= 2);
  if (ndim > 3) {
    return input.transpose(-2, -1).is_contiguous();
  }
  auto leading_dimension = input_strides[ndim - 1];
  auto rows = input_sizes[ndim - 2];
  bool batch_stride_compatible = true;
  if (ndim == 3) {
    auto cols = input_sizes[ndim - 1];
    batch_stride_compatible =
        input_strides[ndim - 3] >= leading_dimension * cols;
  }
  return (input_strides[ndim - 2] == 1) &&
      (leading_dimension >= std::max<int64_t>(1, rows)) &&
      batch_stride_compatible;
}
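
// E.g. a 3-D tensor of shape (b, m, n) with strides (m * n, 1, m) (a batch of
// column-major matrices with leading dimension m) passes this check, while its
// C-contiguous counterpart with strides (m * n, n, 1) generally does not.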

inline bool is_blas_compatible_row_major_order(const Tensor& input) {
  IntArrayRef input_strides = input.strides();
  IntArrayRef input_sizes = input.sizes();
  auto ndim = input.dim();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim >= 2);
  if (ndim > 3) {
    return input.is_contiguous();
  }
  auto leading_dimension = input_strides[ndim - 2];
  auto cols = input_sizes[ndim - 1];
  bool batch_stride_compatible = true;
  if (ndim == 3) {
    auto rows = input_sizes[ndim - 2];
    batch_stride_compatible =
        input_strides[ndim - 3] >= leading_dimension * rows;
  }
  return (input_strides[ndim - 1] == 1) &&
      (leading_dimension >= std::max<int64_t>(1, cols)) &&
      batch_stride_compatible;
}

} // namespace at::native