#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/OpMathType.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/TensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/cpu/int_mm_kernel.h>
#include <ATen/native/LinearAlgebra.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/mkldnn/Matmul.h>
#include <ATen/native/mkldnn/Utils.h>
#include <c10/core/GradMode.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#include <variant>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_compute_linear_combination_native.h>
#include <ATen/ops/_convert_weight_to_int4pack_native.h>
#include <ATen/ops/_int_mm_native.h>
#include <ATen/ops/_linalg_check_errors.h>
#include <ATen/ops/_linalg_det.h>
#include <ATen/ops/_linalg_det_native.h>
#include <ATen/ops/_linalg_slogdet.h>
#include <ATen/ops/_linalg_slogdet_native.h>
#include <ATen/ops/_unsafe_view.h>
#include <ATen/ops/_weight_int4pack_mm_native.h>
#include <ATen/ops/_weight_int8pack_mm_native.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/addbmm_native.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addr.h>
#include <ATen/ops/addr_native.h>
#include <ATen/ops/arange.h>
#include <ATen/ops/argsort.h>
#include <ATen/ops/baddbmm_native.h>
#include <ATen/ops/bmm.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/cat.h>
#include <ATen/ops/ceil.h>
#include <ATen/ops/chain_matmul_native.h>
#include <ATen/ops/cumsum.h>
#include <ATen/ops/det_native.h>
#include <ATen/ops/diag_embed.h>
#include <ATen/ops/diff.h>
#include <ATen/ops/dot.h>
#include <ATen/ops/dot_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/eye.h>
#include <ATen/ops/floor.h>
#include <ATen/ops/frobenius_norm_native.h>
#include <ATen/ops/from_blob.h>
#include <ATen/ops/full.h>
#include <ATen/ops/full_like.h>
#include <ATen/ops/gelu.h>
#include <ATen/ops/ger_native.h>
#include <ATen/ops/index_select.h>
#include <ATen/ops/inner_native.h>
#include <ATen/ops/is_complex_native.h>
#include <ATen/ops/is_floating_point_native.h>
#include <ATen/ops/kron_native.h>
#include <ATen/ops/linalg_cond.h>
#include <ATen/ops/linalg_cond_native.h>
#include <ATen/ops/linalg_det.h>
#include <ATen/ops/linalg_det_native.h>
#include <ATen/ops/linalg_diagonal_native.h>
#include <ATen/ops/linalg_eigh.h>
#include <ATen/ops/linalg_eigvalsh.h>
#include <ATen/ops/linalg_inv.h>
#include <ATen/ops/linalg_inv_ex.h>
#include <ATen/ops/linalg_lu_factor_ex.h>
#include <ATen/ops/linalg_matmul_native.h>
#include <ATen/ops/linalg_matrix_exp.h>
#include <ATen/ops/linalg_matrix_exp_native.h>
#include <ATen/ops/linalg_matrix_norm.h>
#include <ATen/ops/linalg_matrix_norm_native.h>
#include <ATen/ops/linalg_matrix_power_native.h>
#include <ATen/ops/linalg_matrix_rank.h>
#include <ATen/ops/linalg_matrix_rank_native.h>
#include <ATen/ops/linalg_multi_dot_native.h>
#include <ATen/ops/linalg_norm.h>
#include <ATen/ops/linalg_norm_native.h>
#include <ATen/ops/linalg_pinv.h>
#include <ATen/ops/linalg_pinv_native.h>
#include <ATen/ops/linalg_slogdet.h>
#include <ATen/ops/linalg_slogdet_native.h>
#include <ATen/ops/linalg_solve.h>
#include <ATen/ops/linalg_svdvals.h>
#include <ATen/ops/linalg_tensorinv.h>
#include <ATen/ops/linalg_tensorinv_native.h>
#include <ATen/ops/linalg_tensorsolve.h>
#include <ATen/ops/linalg_tensorsolve_native.h>
#include <ATen/ops/linalg_vector_norm.h>
#include <ATen/ops/linalg_vector_norm_native.h>
#include <ATen/ops/log2.h>
#include <ATen/ops/logdet_native.h>
#include <ATen/ops/matmul.h>
#include <ATen/ops/matmul_native.h>
#include <ATen/ops/matrix_exp_backward_native.h>
#include <ATen/ops/matrix_exp_native.h>
#include <ATen/ops/matrix_power_native.h>
#include <ATen/ops/max.h>
#include <ATen/ops/mm.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/movedim.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/mv.h>
#include <ATen/ops/narrow.h>
#include <ATen/ops/ne.h>
#include <ATen/ops/norm.h>
#include <ATen/ops/nuclear_norm_native.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/outer.h>
#include <ATen/ops/outer_native.h>
#include <ATen/ops/pinverse_native.h>
#include <ATen/ops/pow.h>
#include <ATen/ops/prod.h>
#include <ATen/ops/real.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/slogdet_native.h>
#include <ATen/ops/sort.h>
#include <ATen/ops/sqrt.h>
#include <ATen/ops/sum.h>
#include <ATen/ops/tensordot.h>
#include <ATen/ops/unique_consecutive.h>
#include <ATen/ops/vdot_native.h>
#include <ATen/ops/where.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#endif

#include <limits>
#include <numeric>
#include <string>
#include <tuple>
#include <utility>
#if !defined(__s390x__) && !defined(__powerpc__)
#include <cpuinfo.h>
#endif

namespace at {

namespace detail {
  static void check_linalg_norm_dtype(std::optional<ScalarType> opt_dtype, ScalarType self_dtype, const char* const name) {
    if (opt_dtype.has_value()) {
      auto dtype = opt_dtype.value();
      TORCH_CHECK(isFloatingType(dtype) || isComplexType(dtype), name, ": dtype should"
          " be floating point or complex, but got ", dtype);
      TORCH_CHECK(isComplexType(self_dtype) == isComplexType(dtype),
          name, ": dtype should be ", isComplexType(self_dtype) ? "complex" : "real",
          " for ", isComplexType(self_dtype) ? "complex" : "real", " inputs, but got ", dtype);
      TORCH_CHECK(promoteTypes(self_dtype, dtype) == dtype,
          name, ": the dtype of the input ", "(", self_dtype, ") should be convertible ",
          "without narrowing to the specified dtype (", dtype, ")");
    }
  }
}

namespace meta {

#define ADDMM_META() \
  TORCH_CHECK(self.scalar_type() == mat2.scalar_type(), "self and mat2 must have the same dtype, but got ", self.scalar_type(), " and ", mat2.scalar_type()); \
  TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "mat1 and mat2 must have the same dtype, but got ", mat1.scalar_type(), " and ", mat2.scalar_type()); \
  TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor"); \
  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor"); \
  TORCH_CHECK( \
      mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", \
      mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); \
 \
  auto names = at::namedinference::propagate_names_for_addmm(mat1, mat2, self); \
  set_output_raw_strided(0, {mat1.sizes()[0], mat2.sizes()[1]}, {}, mat1.options(), names);

TORCH_META_FUNC(addmm)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
  ADDMM_META();
}

TORCH_META_FUNC(_addmm_activation)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, bool use_gelu) {
  ADDMM_META();
}

TORCH_META_FUNC(mm)(const Tensor & self, const Tensor & mat2) {
  TORCH_CHECK(self.dim() == 2, "self must be a matrix");
  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
  TORCH_CHECK(
      self.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
      self.sizes()[0], "x", self.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");

  auto names = at::namedinference::compute_matmul_outnames(self, mat2);
  set_output_raw_strided(0, {self.sizes()[0], mat2.sizes()[1]}, {}, self.options(), names);
}

TORCH_META_FUNC(linalg_vector_norm)(const Tensor& self, const Scalar& scalar_ord, OptionalIntArrayRef opt_dim, bool keepdim, std::optional<ScalarType> opt_dtype) {
  at::native::checkFloatingOrComplex(self, "linalg.vector_norm");

  auto dim = opt_dim.value_or(IntArrayRef{});
  // Casting a large integer to a double will just introduce an error for
  // values larger than 2^53 (same for negative numbers), so that's fine.
  auto ord = scalar_ord.toDouble();

  // For more context, see issue 52783
  // If the tensor is empty and norm < 0 || norm == infty
  //   - We cannot reduce the whole tensor
  //   - We cannot reduce over an empty dimension
  if (self.numel() == 0 && (ord < 0. || ord == INFINITY)) {
    // dim=None or dim=() reduces the whole tensor
    TORCH_CHECK(opt_dim.has_value() && !opt_dim->empty(),
      "linalg.vector_norm cannot compute the ", scalar_ord, " norm on an empty ",
      "tensor because the operation does not have an identity");
    for (auto dim_num : dim) {
      TORCH_CHECK(self.size(dim_num) != 0,
227         "linalg.vector_norm cannot compute the ", scalar_ord, " norm on the dimension ", dim_num ,
228         "because this dimension is empty and the operation does not have an identity");
    }
  }

  at::detail::check_linalg_norm_dtype(opt_dtype, self.scalar_type(), "linalg.vector_norm");

  auto mask = at::native::make_dim_mask(dim, self.dim());
  auto shape = at::native::shape_from_dim_mask(self, std::move(mask), keepdim);
  auto options = self.options()
                     .dtype(toRealValueType(opt_dtype.value_or(self.scalar_type())));

  set_output_raw_strided(0, shape, {}, options);
}

TORCH_META_FUNC(_linalg_det)(const Tensor& A) {
  at::native::squareCheckInputs(A, "linalg.det");
  at::native::checkFloatingOrComplex(A, "linalg.det");

  auto shape = A.sizes();
  auto ndim = shape.size();

  // det
  set_output_contiguous(0, shape.slice(0, ndim - 2), A.options());

  // LU
  auto LU_strides = at::native::batched_matrix_contiguous_strides(shape, /*f-contig=*/true);
  set_output_strided(1, shape, LU_strides, A.options());

  // pivots
  set_output_contiguous(2, shape.slice(0, ndim - 1), A.options().dtype(kInt));
}

TORCH_META_FUNC(_linalg_slogdet)(const Tensor& A) {
  at::native::squareCheckInputs(A, "linalg.slogdet");
  at::native::checkFloatingOrComplex(A, "linalg.slogdet", /*low_precision*/false);

  auto shape = A.sizes();
  auto ndim = shape.size();

  auto shape_outputs = shape.slice(0, ndim - 2);

  // sign
  set_output_contiguous(0, shape_outputs, A.options());

  // logabsdet
  set_output_contiguous(1, shape_outputs, A.options().dtype(toRealValueType(A.scalar_type())));

  // LU
  auto LU_strides = at::native::batched_matrix_contiguous_strides(shape, /*f-contig=*/true);
  set_output_strided(2, shape, LU_strides, A.options());

  // pivots
  set_output_contiguous(3, shape.slice(0, ndim - 1), A.options().dtype(kInt));
}

template <typename Meta>
void common_checks_baddbmm_bmm(Meta& meta, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, bool is_bmm, const std::optional<Tensor>& self_baddbmm = std::nullopt) {
  TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor");
  TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor");

  const auto batch1_sizes = batch1.sizes();
  const auto batch2_sizes = batch2.sizes();

  int64_t bs = batch1_sizes[0];
  int64_t contraction_size = batch1_sizes[2];
  int64_t res_rows = batch1_sizes[1];
  int64_t res_cols = batch2_sizes[2];
  std::vector<int64_t> output_size {bs, res_rows, res_cols};

  TORCH_CHECK(batch2_sizes[0] == bs && batch2_sizes[1] == contraction_size,
              "Expected size for first two dimensions of batch2 tensor to be: [",
              bs, ", ", contraction_size, "] but got: [", batch2_sizes[0], ", ", batch2_sizes[1], "].");

  auto& result = meta.maybe_get_output(0);
  // 'set_output' does not resize for in-place calls
  meta.set_output_raw_strided(0, output_size, {}, batch2.options());
  const auto result_sizes = result.sizes();
  // Error is raised if called from in-place overload with incorrect shape
  TORCH_CHECK(result_sizes == output_size,
              "Expected an output tensor with shape [", output_size, "] but got shape ", result_sizes);

  std::vector<Dimname> outnames = {};
  if (!is_bmm) {
    if (self_baddbmm.has_value()) {
      const auto& self = self_baddbmm.value();
      if (beta.toComplexDouble() != 0.0) result.copy_(self);
      TORCH_CHECK(self.dim() == 3, "self must be a 3D tensor");
      const auto self_sizes = self.sizes();
      TORCH_CHECK(self_sizes == output_size,
                  "Expected an input tensor shape with shape ", output_size, " but got shape: ", self_sizes);
      outnames = namedinference::compute_baddbmm_outnames(result, batch1, batch2, self);
    }
  } else {
    outnames = namedinference::compute_bmm_outnames(result, batch1, batch2);
  }

  namedinference::propagate_names_if_nonempty(
    result,
    outnames
  );
}

TORCH_META_FUNC(bmm)(const Tensor& self, const Tensor& mat2) {
    common_checks_baddbmm_bmm(*this, self, mat2, Scalar(0.0), Scalar(1.0), true);
}

TORCH_META_FUNC(baddbmm)(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
  auto self_ = expand_size(self, {batch1.size(0), batch1.size(1), batch2.size(2)}, "baddbmm");
  TORCH_CHECK(self.dtype() == batch1.dtype(), "Input dtypes must be the same, got: input ", self.dtype(), ", batch1: ", batch1.dtype(), ", batch2: ", batch2.dtype());
  common_checks_baddbmm_bmm(*this, batch1, batch2, beta, alpha, false, *self_);
}

} // namespace meta
namespace native {

DEFINE_DISPATCH(addr_stub);


// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg.det ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// As P is a permutation matrix
// det(P) = 1 if it's an even permutation and det(P) = -1 if it's an odd permutation
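// For example (1-based LAPACK pivots): pivots = [2, 2, 3] records a single row swap
// (row 1 <-> row 2), so comparing against arange(1, 4) = [1, 2, 3] yields one mismatch,
// an odd permutation, and lu_det_P returns -1.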
static Tensor lu_det_P(const Tensor& pivots) {
  return (at::arange(1, pivots.size(-1) + 1, pivots.options()) != pivots)
    .sum(-1, /*keepdim=*/false, /*dtype=*/at::kLong)
    .fmod_(2)
    // take 0 to 1 and 1 to -1
    .mul_(-2)
    .add_(1);
}

// Auxiliary function that returns the LU decomposition to use it in the backward
TORCH_IMPL_FUNC(_linalg_det_out)(const Tensor& A, const Tensor& result, const Tensor& LU, const Tensor& pivots) {
  // info is an aux tensor
  auto info = at::empty({0}, A.options().dtype(kInt));
  // Optimisation: lu_factor_ex requires the input to be F-contig, otherwise it copies
  // Use the transpose if A is contiguous, since det(A^T) = det(A)
  // We limit this to real matrices, but it could also be implemented for complex matrices
  at::linalg_lu_factor_ex_out(const_cast<Tensor&>(LU), const_cast<Tensor&>(pivots), const_cast<Tensor&>(info), A.is_contiguous() && !A.is_complex() ? A.mH() : A);

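  // With A = P L U and L unit-diagonal, det(A) = det(P) * det(U) = det(P) * prod(diag(U))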
  // det = det_P * prod(diag(LU))
  at::mul_out(const_cast<Tensor&>(result), lu_det_P(pivots), at::prod(LU.diagonal(0, -2, -1), /*dim=*/-1));
}

Tensor linalg_det(const Tensor& A) {
  return std::get<0>(at::_linalg_det(A));
}

Tensor& linalg_det_out(const Tensor& A, Tensor& result) {
  auto LU = at::empty({0}, A.options());
  auto pivots = at::empty({0}, A.options().dtype(kInt));
  at::_linalg_det_out(result, LU, pivots, A);
  return result;
}

// torch.det, alias for torch.linalg.det
Tensor det(const Tensor& self) {
  return at::linalg_det(self);
}

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg.slogdet ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// Auxiliary function that returns the LU decomposition to use it in the backward
TORCH_IMPL_FUNC(_linalg_slogdet_out)(const Tensor& A, const Tensor& sign, const Tensor& logabsdet, const Tensor& LU, const Tensor& pivots) {
  // info is an aux tensor
  auto info = at::empty({0}, A.options().dtype(kInt));
  // Optimisation: lu_factor_ex requires the input to be F-contig, otherwise it copies
  // Use the transpose if A is contiguous, since det(A^T) = det(A)
  // We limit this to real matrices, but it could also be implemented for complex matrices
  at::linalg_lu_factor_ex_out(const_cast<Tensor&>(LU), const_cast<Tensor&>(pivots), const_cast<Tensor&>(info), A.is_contiguous() && !A.is_complex() ? A.mH() : A);

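  // With A = P L U: sign(det A) = det(P) * prod(sgn(diag(U))) and log|det A| = sum(log|diag(U)|)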
  auto diag_U = LU.diagonal(0, -2, -1);
  // sign
  at::mul_out(const_cast<Tensor&>(sign), diag_U.sgn().prod(-1), lu_det_P(pivots));

  // logabsdet
  at::sum_out(const_cast<Tensor&>(logabsdet), diag_U.abs().log_(), -1);
}

std::tuple<Tensor, Tensor> linalg_slogdet(const Tensor& A) {
  auto out = at::_linalg_slogdet(A);
  return std::make_tuple(std::move(std::get<0>(out)), std::move(std::get<1>(out)));
}

std::tuple<Tensor&, Tensor&> linalg_slogdet_out(const Tensor& A, Tensor& sign, Tensor& logabsdet) {
  auto LU = at::empty({0}, A.options());
  auto pivots = at::empty({0}, A.options().dtype(kInt));
  at::_linalg_slogdet_out(sign, logabsdet, LU, pivots, A);
  return std::tie(sign, logabsdet);
}

// Alias
std::tuple<Tensor, Tensor> slogdet(const Tensor& A) {
  return at::linalg_slogdet(A);
}

std::tuple<Tensor&, Tensor&> slogdet_out(const Tensor& A, Tensor& sign, Tensor& logabsdet) {
  return at::linalg_slogdet_out(sign, logabsdet, A);
}

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ logdet ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

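// logdet(A) = log(det(A)). For real A with det(A) < 0 the result is NaN (the real log is
// undefined); for complex A it is log(sign) + log|det|, which recovers log(det(A)).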
Tensor logdet(const Tensor& A) {
  squareCheckInputs(A, "logdet");
  checkFloatingOrComplex(A, "logdet", /*low_precision*/false);
  auto [sign, logabsdet] = at::linalg_slogdet(A);

  if (A.is_complex()) {
    return sign.log() + logabsdet;
  } else {
    return at::where(sign == -1., NAN, logabsdet);
  }
}

namespace {

// This function extracts the optional Tensors for atol and rtol
// Default value for atol is zero
// Default value for rtol is eps*max(rows, cols)
// If atol is specified and rtol is not specified then default value for rtol is zero
// It is used for matrix_rank and pinv
std::tuple<Tensor, Tensor> get_atol_rtol(
    const Tensor& input,
    const std::optional<Tensor>& atol_opt,
    const std::optional<Tensor>& rtol_opt,
    const c10::string_view function_name) {
  auto options = input.options();
  if (input.device().type() == kMetal || input.device().type() == kMPS) {
    options = options.dtype(ScalarType::Float);
  } else {
    options = options.dtype(ScalarType::Double);
  }
  auto atol = atol_opt.has_value() ? atol_opt.value() : at::zeros({}, options);
  checkNotComplexTolerance(atol, function_name, "atol");
  Tensor rtol;
  if (rtol_opt.has_value()) {
    rtol = rtol_opt.value();
    checkNotComplexTolerance(rtol, function_name, "rtol");
  } else {
    ScalarType real_dtype = toRealValueType(input.scalar_type());
    auto default_rtol = at::full({}, _get_epsilon(real_dtype) * std::max(input.sym_size(-1), input.sym_size(-2)), options);
    rtol = atol_opt.has_value()
           ? at::where(atol_opt.value() > 0, at::zeros({}, options), default_rtol)
           : std::move(default_rtol);
  }
  return std::make_tuple(atol, rtol);
}

std::tuple<Tensor, Tensor> get_atol_rtol(
    const Tensor& input,
    std::optional<double> atol_opt,
    std::optional<double> rtol_opt) {
  auto atol = atol_opt.has_value() ? atol_opt.value() : 0.0;
  c10::SymFloat rtol;
  if (rtol_opt.has_value()) {
    rtol = rtol_opt.value();
  } else {
    ScalarType real_dtype = toRealValueType(input.scalar_type());
    auto default_rtol = _get_epsilon(real_dtype) * std::max(input.sym_size(-1), input.sym_size(-2));
    rtol = (atol_opt.has_value() && atol_opt.value() > 0.0)
           ? 0.0
           : default_rtol;
  }
  auto options = input.options();
  if (input.device().type() == kMetal || input.device().type() == kMPS) {
    options = options.dtype(ScalarType::Float);
  } else {
    options = options.dtype(ScalarType::Double);
  }
  auto atol_tensor = at::full({}, atol, options);
  auto rtol_tensor = at::full({}, rtol, options);
  return std::make_tuple(atol_tensor, rtol_tensor);
}

} // anonymous namespace

Tensor linalg_pinv(
    const Tensor& input,
    const std::optional<Tensor>& atol_opt,
    const std::optional<Tensor>& rtol_opt,
    bool hermitian) {
  // FIXME: Whenever we have a nice lstsq, we should dispatch this function to simply be
  // `torch.lstsq(A, torch.eye(A.shape[-1]), atol=atol, rtol=rtol)`
  // with a driver that supports singular inputs
  NoTF32Guard disable_tf32;
  ScalarType t = input.scalar_type();
  TORCH_CHECK((t == ScalarType::Double || t == ScalarType::Float || t == ScalarType::ComplexFloat || t == ScalarType::ComplexDouble)
              && input.dim() >= 2,
              "linalg.pinv(", t, "{", input.sizes(), "}): expected a tensor with 2 or more dimensions "
              "of float, double, cfloat or cdouble types");

  auto [atol, rtol] = get_atol_rtol(input, atol_opt, rtol_opt, "torch.linalg.pinv");

  if (input.sym_numel() == 0) {
    // The implementation below uses operations that do not work for zero numel tensors
    // therefore we need this early return for 'input.numel() == 0' case
    // TODO: replace input.svd with linalg_svd when torch/xla can work with at::linalg_svd
    auto [U, S, V] = input.svd();
    return at::matmul(V * S.reciprocal().unsqueeze(-2), U.mH());
  }

  // If not Hermitian use singular value decomposition, else use eigenvalue decomposition
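  // Both branches zero out reciprocals of values below tol = max(atol, rtol * max(S)):
  // the SVD path assembles V @ diag(S^+) @ U^H, the Hermitian path U @ diag(S^+) @ U^H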
  if (!hermitian) {
    // TODO: replace input.svd with linalg_svd
    // using linalg_svd breaks pytorch/xla, see https://github.com/pytorch/xla/issues/2755
    auto [U, S, V] = input.svd();
    Tensor max_val = at::narrow(S, /*dim=*/-1, /*start=*/0, /*length=*/1);  // singular values are sorted in descending order
    Tensor tol = at::max(atol.unsqueeze(-1), rtol.unsqueeze(-1) * max_val);
    Tensor S_pseudoinv = at::where(S > tol, S.reciprocal(), at::zeros({}, S.options())).to(input.dtype());
    // computes V @ diag(S_pseudoinv) @ U.conj().T
    return at::matmul(V * S_pseudoinv.unsqueeze(-2), U.mH());
  } else {
    auto [S, U] = at::linalg_eigh(input);
    // For Hermitian matrices, singular values equal to abs(eigenvalues)
    Tensor S_abs = S.abs();
    // eigenvalues are sorted in ascending order starting with negative values, we need a maximum value of abs(eigenvalues)
    Tensor max_val = S_abs.amax(/*dim=*/-1, /*keepdim=*/true);
    Tensor tol = at::max(atol.unsqueeze(-1), rtol.unsqueeze(-1) * max_val);
    Tensor S_pseudoinv = at::where(S_abs > tol, S.reciprocal(), at::zeros({}, S.options())).to(input.dtype());
    // computes U @ diag(S_pseudoinv) @ U.conj().T
    return at::matmul(U * S_pseudoinv.unsqueeze(-2), U.mH());
  }
}

Tensor linalg_pinv(const Tensor& input, std::optional<double> atol, std::optional<double> rtol, bool hermitian) {
  auto [atol_tensor, rtol_tensor] = get_atol_rtol(input, atol, rtol);
  return at::linalg_pinv(input, atol_tensor, rtol_tensor, hermitian);
}

Tensor linalg_pinv(const Tensor& input, const Tensor& rcond, bool hermitian) {
  // For NumPy compatibility the rcond argument is used as relative tolerance
  checkNotComplexTolerance(rcond, "torch.linalg.pinv", "rcond");
  auto options = input.options();
  if (input.device().type() == kMetal || input.device().type() == kMPS) {
    options = options.dtype(ScalarType::Float);
  } else {
    options = options.dtype(ScalarType::Double);
  }
  return at::linalg_pinv(input, at::zeros({}, options), rcond, hermitian);
}

Tensor linalg_pinv(const Tensor& input, double rcond, bool hermitian) {
  // For NumPy compatibility the rcond argument is used as relative tolerance
  return at::linalg_pinv(input, 0.0, rcond, hermitian);
}

// TODO: implement _out variant avoiding copy and using already allocated storage directly
Tensor& linalg_pinv_out(
    const Tensor& input,
    const std::optional<Tensor>& atol,
    const std::optional<Tensor>& rtol,
    bool hermitian,
    Tensor& result) {
  checkSameDevice("linalg.pinv", result, input);
  checkLinalgCompatibleDtype("linalg.pinv", result, input);
  Tensor result_tmp = at::linalg_pinv(input, atol, rtol, hermitian);
  at::native::resize_output(result, result_tmp.sizes());
  result.copy_(result_tmp);
  return result;
}

Tensor& linalg_pinv_out(
    const Tensor& input,
    std::optional<double> atol,
    std::optional<double> rtol,
    bool hermitian,
    Tensor& result) {
  checkSameDevice("linalg.pinv", result, input);
  checkLinalgCompatibleDtype("linalg.pinv", result, input);
  Tensor result_tmp = at::linalg_pinv(input, atol, rtol, hermitian);
  at::native::resize_output(result, result_tmp.sizes());
  result.copy_(result_tmp);
  return result;
}

Tensor& linalg_pinv_out(const Tensor& input, const Tensor& rcond, bool hermitian, Tensor& result) {
  checkSameDevice("linalg.pinv", result, input);
  checkLinalgCompatibleDtype("linalg.pinv", result, input);

  Tensor result_tmp = at::linalg_pinv(input, rcond, hermitian);
  at::native::resize_output(result, result_tmp.sizes());
  result.copy_(result_tmp);
  return result;
}

Tensor& linalg_pinv_out(const Tensor& input, double rcond, bool hermitian, Tensor& result) {
  Tensor rcond_tensor = at::full({}, rcond, input.options().dtype(ScalarType::Double));
  return at::linalg_pinv_out(result, input, rcond_tensor, hermitian);
}

Tensor pinverse(const Tensor& self, double rcond) {
  return at::linalg_pinv(self, rcond, /*hermitian=*/false);
}

// matrix_power implementation
namespace {

/**
 * @brief Raises the input matrix to the given power n
 *
 * If the exponent n is negative, the inverse of the input
 * matrix will be raised to power abs(n).
 *
 * @param self (batched) square matrix to raise to power n
 * @param n exponent to raise matrix (or matrices in batch) to
 * @param _out optional tensor to write the output to
 * @return Tensor input matrix raised to power n
 */
Tensor linalg_matrix_power_impl(
    const Tensor& self,
    int64_t n,
    std::optional<Tensor> _out) {
  NoTF32Guard disable_tf32;
  auto out = _out.value_or(Tensor());

  squareCheckInputs(self, "linalg.matrix_power");
  if (_out.has_value()) {
    checkSameDevice("matrix_power", out, self);
    checkLinalgCompatibleDtype("matrix_power", out, self);
    at::native::resize_output_symint(out, self.sym_sizes());
  }

  // For n=0 we return the identity matrix of the same shape as input.
  if (n == 0) {
    if (!_out.has_value()) {
      // Clone input to include result in the autograd graph
      out = self.clone(at::MemoryFormat::Contiguous);
    }
    return out.copy_(at::eye_symint(self.sym_size(-2), self.options()));
  }
  if (n == 1) {
    return _out.has_value() ? out.copy_(self)
                            : self.clone(at::MemoryFormat::Contiguous);
  }
  if (n == -1) {
    return _out.has_value() ? at::linalg_inv_out(out, self)
                            : at::linalg_inv(self);
  }

  // For negative n we invert the input matrix before raising it to the power abs(n)
  auto a = n < 0 ? at::linalg_inv(self) : self;
  n = std::abs(n);

  // Fast paths for small powers
  if (n == 2) {
    return _out.has_value() ? at::matmul_out(out, a, a) : at::matmul(a, a);
  }
  if (n == 3) {
    return _out.has_value() ? at::matmul_out(out, at::matmul(a, a), a)
                            : at::matmul(at::matmul(a, a), a);
  }

  // This is a binary decomposition of n.
  // Moving from the least significant bit to the most significant bit
  // This is done to reduce the number of matrix multiplications
  // by raising the input matrix in powers of 2
  // The total number of matrix multiplications are
  // number of bits + number of bits that equal 1 ~ O(log n)
  // instead of O(n)
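  // e.g. n = 6 (binary 110): z takes the values a, a^2, a^4; the set bits pick up
  // a^2 and a^4, so result = a^2 @ a^4 = a^6 after two squarings and one extra matmul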
  Tensor z, result;
  while (n > 0) {
    const auto bit = n % 2;
    n = n / 2;
    z = z.defined() ? at::matmul(z, z) : a;
    if (bit == 1) {
      if (_out.has_value() && n <= 0) {
        // Last multiplication can use the out version
        return result.defined() ? at::matmul_out(out, result, z) : out.copy_(z);
      }
      result = result.defined() ? at::matmul(result, z) : z;
    }
  }

  return result;
}

} // namespace

Tensor& linalg_matrix_power_out(const Tensor& self, int64_t n, Tensor& result) {
  linalg_matrix_power_impl(self, n, result);
  return result;
}

Tensor linalg_matrix_power(const Tensor& self, int64_t n) {
  return linalg_matrix_power_impl(self, n, std::nullopt);
}

Tensor& matrix_power_out(const Tensor& self, int64_t n, Tensor& result) {
  return at::native::linalg_matrix_power_out(self, n, result);
}

Tensor matrix_power(const Tensor& self, int64_t n) {
  return at::native::linalg_matrix_power(self, n);
}

namespace {

// Computes the rank of 'input' and saves the result in-place in 'result'.
// 'hermitian' controls whether SVD or eigendecomposition is used for computing the singular values
// 'atol' and 'rtol' are the absolute and relative tolerances, respectively.
Tensor& matrix_rank_impl(
    const Tensor& input,
    const std::optional<Tensor>& atol_opt,
    const std::optional<Tensor>& rtol_opt,
    bool hermitian,
    Tensor& result) {
  auto [atol, rtol] = get_atol_rtol(input, atol_opt, rtol_opt, "torch.linalg.matrix_rank");

  checkSameDevice("torch.linalg.matrix_rank", result, input);
  checkSameDevice("torch.linalg.matrix_rank", atol, input, "atol");
  checkSameDevice("torch.linalg.matrix_rank", rtol, input, "rtol");
  ScalarType output_type = ScalarType::Long;
  checkLinalgCompatibleDtype("torch.linalg.matrix_rank", result.scalar_type(), output_type);

  checkNotComplexTolerance(atol, "torch.linalg.matrix_rank", "atol");
  checkNotComplexTolerance(rtol, "torch.linalg.matrix_rank", "rtol");

  // NumPy doesn't handle inputs with no elements and errors out because max is not defined for this case.
  // Let's output 0 for this case, since such matrices have no non-zero rows and hence rank 0.
  if (input.sym_numel() == 0) {
    result.fill_(0);
    return result;
  }

  // We compute matrix rank as the number of singular or absolute eigen values
  // that are above max(atol, rtol * max(S)) threshold
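  // e.g. S = [5.0, 1e-7, 0.0] with atol = 0 and rtol = 1e-6 gives tol = 5e-6, so rank = 1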
  Tensor S, max_S;
  if (!hermitian) {
    S = at::linalg_svdvals(input);
    // singular values are sorted in descending order
    max_S = at::narrow(S, /*dim=*/-1, /*start=*/0, /*length=*/1);
  } else {
    S = at::linalg_eigvalsh(input);
    S = S.abs();
    // eigenvalues are sorted in ascending order starting with negative values, we need a maximum value of abs(eigenvalues)
    max_S = S.amax(/*dim=*/-1, /*keepdim=*/true);
  }

  Tensor tol = at::max(atol.unsqueeze(-1), rtol.unsqueeze(-1) * max_S);

  if (isTensorSubclassLike(input)) {
     result = at::sum(S > tol, /*dim=*/-1);
     return result;
  }

  result = at::sum_out(result, S > tol, /*dim=*/-1);
  return result;
}

Tensor get_matrix_rank_result_tensor(const Tensor& input) {
  // Matrices or batch of matrices are allowed
  checkIsMatrix(input, "torch.linalg.matrix_rank", "input");
  // For Composite Compliance, allocate `result` of correct shape to
  // avoid resizing in `out` variant.
  // See also `NOTE [matrix rank output shape]`
  auto result_shape =
      SymIntArrayRef(input.sym_sizes().cbegin(), input.sym_sizes().cend() - 2);
  Tensor result =
      at::empty_symint(result_shape, input.options().dtype(ScalarType::Long));

  return result;
}

}  // anonymous namespace

Tensor& linalg_matrix_rank_out(
    const Tensor& input,
    const std::optional<Tensor>& atol_opt,
    const std::optional<Tensor>& rtol_opt,
    bool hermitian,
    Tensor& result) {
  // Matrices or batch of matrices are allowed
  checkIsMatrix(input, "torch.linalg.matrix_rank", "input");
  auto result_shape =
    IntArrayRef(input.sizes().cbegin(), input.sizes().cend() - 2);
  at::native::resize_output(result, result_shape);
  return matrix_rank_impl(input, atol_opt, rtol_opt, hermitian, result);
}

Tensor& linalg_matrix_rank_out(const Tensor& input, std::optional<double> atol, std::optional<double> rtol, bool hermitian, Tensor& result) {
  auto [atol_tensor, rtol_tensor] = get_atol_rtol(input, atol, rtol);
  result = linalg_matrix_rank_out(input, atol_tensor, rtol_tensor, hermitian, result);
  return result;
}

Tensor linalg_matrix_rank(const Tensor& input, const std::optional<Tensor>& atol, const std::optional<Tensor>& rtol, bool hermitian) {
  auto result = get_matrix_rank_result_tensor(input);
  return matrix_rank_impl(input, atol, rtol, hermitian, result);
}

Tensor linalg_matrix_rank(const Tensor& input, std::optional<double> atol, std::optional<double> rtol, bool hermitian) {
  auto result = get_matrix_rank_result_tensor(input);

  auto [atol_tensor, rtol_tensor] = get_atol_rtol(input, atol, rtol);

  return matrix_rank_impl(input, atol_tensor, rtol_tensor, hermitian, result);
}

Tensor& linalg_matrix_rank_out(const Tensor& input, const Tensor& tol, bool hermitian, Tensor& result) {
  // For NumPy compatibility tol is not scaled with max(singular_value) if the value for tol is provided
  // It is assumed that the provided value is the absolute tolerance
  Tensor rtol = at::zeros({}, tol.options());
  result = at::linalg_matrix_rank_outf(input, tol, rtol, hermitian, result);
  return result;
}

Tensor& linalg_matrix_rank_out(const Tensor& input, double tol, bool hermitian, Tensor& result) {
  // For NumPy compatibility tol is not scaled with max(singular_value) if the value for tol is provided
  // It is assumed that the provided value is the absolute tolerance
  result = at::linalg_matrix_rank_outf(input, tol, 0.0, hermitian, result);
  return result;
}

Tensor linalg_matrix_rank(const Tensor& input, const Tensor& tol, bool hermitian) {
  auto result = get_matrix_rank_result_tensor(input);
  return matrix_rank_impl(input, tol, at::zeros({}, tol.options()), hermitian, result);
}

Tensor linalg_matrix_rank(const Tensor& input, double tol, bool hermitian) {
  auto result = get_matrix_rank_result_tensor(input);

  auto [atol_tensor, rtol_tensor] = get_atol_rtol(input, tol, 0.0);

  return matrix_rank_impl(input, atol_tensor, rtol_tensor, hermitian, result);
}

// multi_dot helper functions
namespace {

/**
 * @brief Computes the optimal matrix chain multiplication order
 *
 * Follows the dynamic programming algorithm from Cormen et al.,
 * "Introduction to Algorithms, Third Edition", Chapter 15.2,
 * p. 370-378. Note that the book uses 1-based indexing.
 *
 * The cost of multiplying two matrices with sizes p x q and q x r
 * is defined here as p * q * r. The optimal multiplication order
 * is the one that minimizes the total cost.
 *
 * @param tensors list of 2D tensors
 * @return a 2D vector s used by #matrix_chain_multiplication to construct
 *         the optimal matrix multiplication order. The optimal multiplication
 *         order for multiplying tensors i...j is to multiply tensors i...s[i, j]
 *         and tensors (s[i, j] + 1)...j first and then the result of that.
 */
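// For example, for shapes (10 x 100), (100 x 5), (5 x 50) the optimal order is
// ((T1 T2) T3): 10*100*5 + 10*5*50 = 7500 scalar multiplications, versus
// 100*5*50 + 10*100*50 = 75000 for (T1 (T2 T3)), so s[0][2] = 1.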
std::vector<std::vector<int64_t>> matrix_chain_order(TensorList tensors) {
  const size_t n = tensors.size();

  // Tensor i has dimensions p[i] x p[i + 1]
  std::vector<int64_t> p(n + 1);
  for (const auto i : c10::irange(n)) {
    p[i] = tensors[i].size(0);
  }
  p[n] = tensors[n - 1].size(1);

  // m[i, j] = k where k is the minimum cost for multiplying tensors i...j
  std::vector<std::vector<int64_t>> m(n, std::vector<int64_t>(n, 0));

  // s[i, j] = k where k is the index at which to split the list such that
  // optimally multiplying matrices i...k and k...j first and then the resulting
  // matrices is the optimal order for multiplying matrices i...j.
  std::vector<std::vector<int64_t>> s(n, std::vector<int64_t>(n));

  // Compute the optimal multiplication order
  for (const auto l : c10::irange(1, n)) {
    for (const auto i : c10::irange(n - l)) {
      const auto j = i + l;
      m[i][j] = std::numeric_limits<int64_t>::max();
      for (const auto k : c10::irange(i, j)) {
        const auto q = m[i][k] + m[k + 1][j] + p[i] * p[k + 1] * p[j + 1];
        if (q < m[i][j]) {
          m[i][j] = q;
          s[i][j] = k;
        }
      }
    }
  }

  return s;
}

/**
 * @brief Recursively multiplies the tensors i...j using the given order
 *
 * @param tensors matrices to multiply together
 * @param order optimal chain multiplication order from #matrix_chain_order
 * @param i index of first tensor to be multiplied
 * @param j index of last tensor to be multiplied
 * @return Tensor result of multiplying tensors[i...j] together.
 */
Tensor matrix_chain_multiplication(
    TensorList tensors,
    const std::vector<std::vector<int64_t>>& order,
    int64_t i,
    int64_t j) {
  if (i == j) {
    return tensors[i];
  }
  return at::mm(
      matrix_chain_multiplication(tensors, order, i, order[i][j]),
      matrix_chain_multiplication(tensors, order, order[i][j] + 1, j));
}

// Implements torch.linalg.multi_dot
Tensor multi_dot_impl(TensorList _tensors, std::optional<Tensor> _out) {
  const size_t n = _tensors.size();
  TORCH_CHECK(n >= 2, "multi_dot(): expected at least 2 tensors but got ", n);

  std::vector<int64_t> out_shape;
  std::vector<Tensor> tensors(n);

  // If the first tensor is 1D of size n view it as a row vector (1, n)
  if (_tensors[0].dim() == 1) {
    tensors[0] = _tensors[0].unsqueeze(0);
  } else if (_tensors[0].dim() == 2) {
    tensors[0] = _tensors[0];
    out_shape.emplace_back(tensors[0].size(0));
  } else {
    TORCH_CHECK(
        false,
        "multi_dot(): the first tensor must be 1D or 2D but got ",
        _tensors[0].dim(),
        "D");
  }

  // If the last tensor is 1D of size n view it as a column vector (n, 1)
  if (_tensors[n - 1].dim() == 1) {
    tensors[n - 1] = _tensors[n - 1].unsqueeze(-1);
  } else if (_tensors[n - 1].dim() == 2) {
    tensors[n - 1] = _tensors[n - 1];
    out_shape.emplace_back(tensors[n - 1].size(1));
  } else {
    TORCH_CHECK(
        false,
        "multi_dot(): the last tensor must be 1D or 2D but got ",
        _tensors[n - 1].dim(),
        "D");
  }

  // Ensure middle tensors are 2D
  for (const auto i : c10::irange(1, n - 1)) {
    TORCH_CHECK(
        _tensors[i].dim() == 2,
        "multi_dot(): tensor ",
        i,
        " must be 2D but got ",
        _tensors[i].dim(),
        "D");
    tensors[i] = _tensors[i];
  }

  // Ensure all tensors have the same device and dtype and check
  // that the shapes can be multiplied
  const auto dtype = tensors[0].dtype();
  const auto device = tensors[0].device();
  for (const auto i : c10::irange(1, n)) {
    TORCH_CHECK(
        tensors[i].dtype() == dtype,
987         "multi_dot(): all tensors must have be the same dtype but tensor 0 is ",
        dtype,
        " and tensor ",
        i,
        " ",
        tensors[i].dtype());
    TORCH_CHECK(
        tensors[i].device() == device,
        "multi_dot(): all tensors must be on the same device but tensor 0 is on ",
        device,
        " and tensor ",
        i,
        " on ",
        tensors[i].device());
    TORCH_CHECK(
        tensors[i - 1].size(-1) == tensors[i].size(0),
        "multi_dot(): tensors ",
        i - 1,
        " and ",
        i,
        " with shapes ",
        _tensors[i - 1].sizes(),
        " and ",
        _tensors[i].sizes(),
        " cannot be multiplied")
  }

  Tensor result;

  if (_out.has_value()) {
    auto out = *_out;
    TORCH_CHECK(
        dtype == out.dtype(),
        "multi_dot(): expected out tensor to have dtype ",
        dtype,
        " but got ",
        out.dtype());
    TORCH_CHECK(
        device == out.device(),
        "multi_dot(): expected out tensor to be on device ",
        device,
        " but got ",
        out.device());

    // If the first and last tensors have shapes (a, b) and (b, c) the
    // output has shape (a, c). If either the first or last tensor is 1D
    // a and/or c dimensions will be implicitly size 1 and will be omitted
    // from the output. e.g. for inputs (a, b) x (b) the output has shape (a,).
    at::native::resize_output(out, out_shape);

    // View output as 2D for simplicity of computation.
    result = out.view({tensors[0].size(0), tensors.back().size(-1)});
  }

  // The resize_ and view calls below are to ensure the
  // output shape respects the original dimensionality of
  // the first and last tensors, which are now viewed as 2D

  if (tensors.size() == 2) {
    return _out.has_value() ? at::mm_out(result, tensors[0], tensors[1])
                         : at::mm(tensors[0], tensors[1]).view(out_shape);
  }

  // Why the separate implementation for 3 matrices?
  // The logic for three matrices is much faster when done directly
  // Requires 1 comparison instead of 4 and fewer arithmetic operations
  if (tensors.size() == 3) {
    const auto a = tensors[0].size(0);
    const auto b = tensors[1].size(0);
    const auto c = tensors[2].size(0);
    const auto d = tensors[2].size(1);

    // The matrices are of size (a x b), (b x c), (c x d)
    // cost_1 is the cost of parenthesizing (a x b) and (b x c) and then
    // combining (c x d) cost_2 is the cost of parenthesizing (b x c) and (c x
    // d) and then combining (a x b)
    const auto cost_1 = (a * c) * (b + d);
    const auto cost_2 = (b * d) * (a + c);
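    // i.e. cost_1 = a*b*c + a*c*d for ((T1 T2) T3) and cost_2 = b*c*d + a*b*d for (T1 (T2 T3))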

    if (cost_1 > cost_2) {
      return _out.has_value()
          ? at::mm_out(result, tensors[0], at::mm(tensors[1], tensors[2]))
          : at::mm(tensors[0], at::mm(tensors[1], tensors[2])).view(out_shape);
    } else {
      return _out.has_value()
          ? at::mm_out(result, at::mm(tensors[0], tensors[1]), tensors[2])
          : at::mm(at::mm(tensors[0], tensors[1]), tensors[2]).view(out_shape);
    }
  }

  // Algorithm for multiplying 4 or more matrices
  const auto order = matrix_chain_order(tensors);
  const int64_t i = 0;
  const int64_t j = n - 1;

  if (_out.has_value()) {
    // We manually implement the first recursive layer here so we can use mm_out
    // for the final multiplication
    return at::mm_out(
        result,
        matrix_chain_multiplication(tensors, order, i, order[i][j]),
        matrix_chain_multiplication(tensors, order, order[i][j] + 1, j));
  }
  return matrix_chain_multiplication(tensors, order, i, j).view(out_shape);
}

} // namespace

Tensor linalg_multi_dot(TensorList tensors) {
  return multi_dot_impl(tensors, std::nullopt);
}

Tensor& linalg_multi_dot_out(TensorList tensors, Tensor& result) {
  multi_dot_impl(tensors, result);
  return result;
}

Tensor chain_matmul(TensorList matrices) {
  TORCH_WARN_ONCE(
      "torch.chain_matmul is deprecated and will be removed in a future PyTorch release. ",
      "Use torch.linalg.multi_dot instead, which accepts a list of two or more tensors rather than ",
      "multiple parameters."
  );
  checkAllSameDim(matrices, 2);

  TORCH_CHECK(
      !matrices.empty(), "chain_matmul(): Expected one or more matrices");

  if (matrices.size() == 1) {
    return matrices[0].clone();
  }

  return at::native::linalg_multi_dot(matrices);
}

Tensor& chain_matmul_out(TensorList matrices, Tensor& result) {
  TORCH_WARN_ONCE(
      "torch.chain_matmul is deprecated and will be removed in a future PyTorch release. ",
      "Use torch.linalg.multi_dot instead, which accepts a list of two or more tensors rather than ",
      "multiple parameters."
  );
  checkAllSameDim(matrices, 2);

  TORCH_CHECK(
      !matrices.empty(), "chain_matmul(): Expected one or more matrices");

  if (matrices.size() == 1) {
    at::native::resize_output(result, matrices[0].sizes());
    return result.copy_(matrices[0]);
  }

  return at::native::linalg_multi_dot_out(matrices, result);
}

static void check_1d(const Tensor& t, const char* arg, const char* fn) {
  TORCH_CHECK(t.dim() == 1, fn, ": Expected 1-D argument ", arg, ", but got ", t.dim(), "-D");
}

static void check_addr_scalar(const ScalarType dtype,
                              const Scalar& scalar,
                              const std::string& scalar_name) {
  TORCH_CHECK(
    !scalar.isBoolean() || dtype == ScalarType::Bool,
    "Boolean ", scalar_name, " only supported for Boolean results.");
  TORCH_CHECK(
    isFloatingType(dtype) || isComplexType(dtype) || scalar.isIntegral(true),
    "For integral input tensors, "
    "argument ", scalar_name, " must not be a floating point number.");
}

static TensorIterator build_addr_iter(Tensor& result,
                                      const Tensor& self,
                                      const Tensor& vec1,
                                      const Tensor& vec2) {
  check_1d(vec1, "vec1", "addr");
  check_1d(vec2, "vec2", "addr");

  const auto vec1_size0 = vec1.sizes()[0];
  const auto vec2_size0 = vec2.sizes()[0];
  auto self_ = &result == &self
    ? c10::MaybeOwned<Tensor>::borrowed(self)
    : expand_size(self, {vec1_size0, vec2_size0}, "addr");
  TORCH_CHECK(
    self_->dim() == 2,
    "2D tensor expected, got ", self_->dim(), "D tensor for input"
  );
  TORCH_CHECK(
    self_->sizes()[0] == vec1_size0 && self_->sizes()[1] == vec2_size0,
    "size mismatch, input: ", self_->sizes(),
    ", v1: ", vec1.sizes(),
    ", v2: ", vec2.sizes()
  );

  auto iter = TensorIteratorConfig()
    .set_check_mem_overlap(true)
    .add_output(result)
    .add_owned_const_input(*self_)
    .add_owned_const_input(vec1.reshape({vec1_size0, 1}))
    .add_const_input(vec2)
    .allow_cpu_scalars(true)
    .promote_inputs_to_common_dtype(true)
    .cast_common_dtype_to_outputs(true)
    .enforce_safe_casting_to_output(true)
    .build();
  return iter;
}

Tensor addr(const Tensor& self,
            const Tensor& vec1, const Tensor& vec2,
            const Scalar& beta, const Scalar& alpha) {
  Tensor result;
  auto iter = build_addr_iter(result, self, vec1, vec2);

  check_addr_scalar(iter.dtype(), beta, "beta");
  check_addr_scalar(iter.dtype(), alpha, "alpha");

  addr_stub(iter.device_type(), iter, beta, alpha);
  return iter.output();
}

Tensor& addr_(Tensor& self,
              const Tensor& vec1, const Tensor& vec2,
              const Scalar& beta, const Scalar& alpha) {
  return at::addr_out(self, self, vec1, vec2, beta, alpha);
}

Tensor& addr_out(const Tensor& self,
                 const Tensor& vec1, const Tensor& vec2,
                 const Scalar& beta, const Scalar& alpha, Tensor &result) {
  auto iter = build_addr_iter(result, self, vec1, vec2);

  check_addr_scalar(iter.dtype(), beta, "beta");
  check_addr_scalar(iter.dtype(), alpha, "alpha");

  addr_stub(iter.device_type(), iter, beta, alpha);
  return result;
}

// The math_addr and math_addr_out functions support backends
// other than CPU and CUDA, such as XLA.
// They are implemented using the composition of existing ops
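// math_addr(self, vec1, vec2, beta, alpha) = beta * self + alpha * outer(vec1, vec2)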
math_addr(const Tensor & self,const Tensor & vec1,const Tensor & vec2,const Scalar & beta,const Scalar & alpha)1228 Tensor math_addr(const Tensor& self,
1229                  const Tensor& vec1, const Tensor& vec2,
1230                  const Scalar& beta, const Scalar& alpha) {
1231   // when beta==0, values in self should be ignored,
1232   // nans and infs in self should not propagate.
1233   Tensor out;
1234   if (beta.toComplexDouble() == 0.0) {
1235     if (alpha.toComplexDouble() == 1.0) {
1236       out = at::outer(vec1, vec2);
1237     } else {
1238       out = alpha * at::outer(vec1, vec2);
1239     }
1240   } else if (beta.toComplexDouble() == 1.0) {
1241     if (alpha.toComplexDouble() == 1.0) {
1242       out = self + at::outer(vec1, vec2);
1243     } else {
1244       out = self + alpha * at::outer(vec1, vec2);
1245     }
1246   } else if (alpha.toComplexDouble() == 1.0) {
1247     out = beta * self + at::outer(vec1, vec2);
1248   } else {
1249     out = beta * self + alpha * at::outer(vec1, vec2);
1250   }
1251   auto result_type = c10::promoteTypes(c10::promoteTypes(self.scalar_type(), vec1.scalar_type()), vec2.scalar_type());
1252   return out.to(c10::TensorOptions().dtype(result_type));
1253 }
1254 
1255 Tensor& math_addr_out(const Tensor& self,
1256                       const Tensor& vec1, const Tensor& vec2,
1257                       const Scalar& beta, const Scalar& alpha, Tensor &result) {
1258   auto addr_result = at::addr(self, vec1, vec2, beta, alpha);
1259 
1260   // Validates safe casting
1261   const auto result_dtype = addr_result.scalar_type();
1262   TORCH_CHECK(canCast(result_dtype, result.scalar_type()),
1263               "result type ", result_dtype,
1264               " can't be cast to the desired output type ", result.scalar_type());
1265 
1266   at::native::resize_output(result, addr_result.sizes().vec());
1267   result.copy_(addr_result);
1268   return result;
1269 }
1270 
1271 // torch.ger, alias for torch.outer
1272 Tensor& ger_out(const Tensor& self, const Tensor& vec2, Tensor &result) {
1273   TORCH_WARN("torch.ger is deprecated and will be removed in a future PyTorch release. "
1274              "Use torch.outer instead.");
1275   return at::outer_out(result, self, vec2);
1276 }
1277 
1278 Tensor ger(const Tensor& self, const Tensor& vec2) {
1279   return self.outer(vec2);
1280 }
1281 
1282 Tensor& inner_out(const Tensor& self, const Tensor& other, Tensor& out) {
1283   checkDeviceType("inner()", {out, self, other}, self.device().type());
1284 
1285   // If either self or other is a scalar just multiply them
1286   if (self.dim() == 0 || other.dim() == 0) {
1287     at::mul_out(out, self, other);
1288     return out;
1289   }
1290 
1291   // Last dimension should match (tensordot does not enforce this)
1292   TORCH_CHECK(
1293       self.size(-1) == other.size(-1),
1294       "inner() the last dimension must match on both input tensors but got shapes ",
1295       self.sizes(),
1296       " and ",
1297       other.sizes());
1298 
1299   at::tensordot_out(out, self, other, -1, -1);
1300   return out;
1301 }
1302 
1303 Tensor inner(const Tensor& self, const Tensor& other) {
1304   checkDeviceType("inner()", {self, other}, self.device().type());
1305 
1306   // If either self or other is a scalar just multiply them
1307   if (self.dim() == 0 || other.dim() == 0) {
1308     return self * other;
1309   }
1310 
1311   // Last dimension should match (tensordot does not enforce this)
1312   TORCH_CHECK(
1313       self.sym_size(-1) == other.sym_size(-1),
1314       "inner() the last dimension must match on both input tensors but got shapes ",
1315       self.sym_sizes(),
1316       " and ",
1317       other.sym_sizes());
1318 
1319   return at::tensordot(self, other, -1, -1);
1320 }
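// A minimal usage sketch (illustrative, assuming the public at::rand and
// at::matmul entry points): inner() contracts the last dimension of both inputs.
//
//   at::Tensor A = at::rand({2, 3});
//   at::Tensor B = at::rand({4, 3});
//   at::Tensor C = at::inner(A, B);        // shape [2, 4]
//   // C.allclose(at::matmul(A, B.mT())) holds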
1321 
1322 Tensor& outer_out(const Tensor& self, const Tensor& vec2, Tensor &result) {
1323   check_1d(self, "self", "outer");
1324   check_1d(vec2, "vec2", "outer");
1325 
1326   // torch.outer is implemented as a composite op using reshape and mul
1327   at::mul_out(result, self.reshape({self.size(0), 1}), vec2);
1328   return result;
1329 }
1330 
1331 Tensor outer(const Tensor& self, const Tensor& vec2) {
1332   check_1d(self, "self", "outer");
1333   check_1d(vec2, "vec2", "outer");
1334 
1335   return self.reshape_symint({self.sym_size(0), 1}) * vec2;
1336 }
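// A small worked example of the reshape-and-broadcast composition above:
// for u = [1, 2, 3] (shape [3]) and v = [10, 20] (shape [2]),
// u.reshape({3, 1}) * v broadcasts to shape [3, 2]:
//   [[10, 20],
//    [20, 40],
//    [30, 60]]
// which is exactly outer(u, v).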
1337 
1338 
1339 #if !defined(C10_MOBILE)
1340 #define _AT_DISPATCH_ADDMM_TYPES(TYPE, NAME, ...)                                               \
1341         AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(                                                 \
1342             kBFloat16, kHalf, kFloat8_e5m2, kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, \
1343             TYPE, NAME, __VA_ARGS__)
1344 #else
1345 // Include half dtype in ADDMM. Used to build ExecuTorch in xplat.
1346 #if defined(C10_MOBILE_HALF)
1347 #define _AT_DISPATCH_ADDMM_TYPES(TYPE, NAME, ...)        \
1348         AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, \
1349             TYPE, NAME, __VA_ARGS__)
1350 #else
1351 #define _AT_DISPATCH_ADDMM_TYPES(TYPE, NAME, ...)        \
1352         AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, \
1353             TYPE, NAME, __VA_ARGS__)
1354 #endif
1355 #endif
1356 
1357 
1358 static inline int64_t get_mkldnn_matmul_min_dim() {
1359   static auto value = [&] {
1360     const int64_t default_min_dim = [&] {
1361       // Minimum dimension requirement for MKLDNN, derived from experiments;
1362       // it is enabled on all Neoverse CPUs.
1363       return is_arm_neoverse() ? 8 : 0;
1364     }();
1365     const char* ptr = std::getenv("TORCH_MKLDNN_MATMUL_MIN_DIM");
1366     return ptr != nullptr ? std::atoi(ptr) : default_min_dim;
1367   }();
1368   return value;
1369 }
1370 
1371 
1372 static inline int64_t get_mkldnn_matmul_min_size() {
1373   static auto value = [&] {
1374     const int64_t default_min_size = [&] {
1375       // Minimum size requirement for MKLDNN, derived from experiments;
1376       // it is enabled on all Neoverse CPUs.
1377       return is_arm_neoverse() ? 8 * 1024 : 0;
1378     }();
1379     const char* ptr = std::getenv("TORCH_MKLDNN_MATMUL_MIN_SIZE");
1380     return ptr != nullptr ? std::atoi(ptr) : default_min_size;
1381   }();
1382   return value;
1383 }
1384 
1385 
1386 static inline bool apply_mkldnn_matmul_heur(int64_t m, int64_t k, int64_t n) {
1387   const int64_t min_dim = get_mkldnn_matmul_min_dim();
1388   const int64_t min_size = get_mkldnn_matmul_min_size();
1389   return at::globalContext().userEnabledMkldnn() && m > min_dim && k > min_dim && n > min_dim && m * k * n > min_size;
1390 }
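// Illustrative tuning sketch (assuming a shell environment; the binary name is
// hypothetical): both thresholds can be overridden at runtime, e.g.
//
//   TORCH_MKLDNN_MATMUL_MIN_DIM=16 TORCH_MKLDNN_MATMUL_MIN_SIZE=65536 ./my_app
//
// With those values a 32 x 32 x 32 matmul (m * k * n = 32768 <= 65536) skips
// the oneDNN path, while a 64 x 64 x 64 matmul (262144 > 65536) takes it,
// provided at::globalContext().userEnabledMkldnn() is true.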
1391 
1392 
1393 static void addmm_impl_cpu_(
1394     Tensor &result, const Tensor &self, Tensor m1, Tensor m2, const Scalar& beta, const Scalar& alpha) {
1395   TORCH_INTERNAL_ASSERT(self.dim() == 2 && m1.dim() == 2 && m2.dim() == 2);
1396 
1397   TORCH_CHECK(
1398     m1.dtype() == m2.dtype(),
1399     "expected m1 and m2 to have the same dtype, but got: ", m1.dtype(), " != ", m2.dtype()
1400   )
1401   // Array access is faster than .size(n) and .stride(n)
1402   const auto self_sizes = self.sizes();
1403   auto m1_strides = m1.strides();
1404   auto m1_sizes = m1.sizes();
1405   auto m2_strides = m2.strides();
1406   auto m2_sizes = m2.sizes();
1407 
1408   TORCH_CHECK(
1409       self_sizes[0] == m1_sizes[0] && self_sizes[1] == m2_sizes[1],
1410       "input shape is incompatible with matrix multiplication (",
1411       m1_sizes[0], "x", m1_sizes[1], " @ ", m2_sizes[0], "x", m2_sizes[1], " != ",
1412       self_sizes[0], "x", self_sizes[1], ")");
1413 
1414   at::native::resize_output(result, self_sizes);
1415   const auto result_strides = result.strides();
1416   const auto result_sizes = result.sizes();
1417 
1418   if (result.numel() == 0) {
1419     return;
1420   }
1421 
1422   // Some paths in the code below do not handle multiplications of the form [a, 0] x [0, b]
1423   if (m1_sizes[1] == 0) {
1424     if (beta.toComplexDouble() == 0.0) {
1425       result.zero_();
1426     } else {
1427       if (!self.is_same(result)) {
1428         result.copy_(self);
1429       }
1430       result.mul_(beta);
1431     }
1432     return;
1433   }
1434 
1435   if (beta.toComplexDouble() != 0.0 && !self.is_same(result)) {
1436     result.copy_(self);
1437   }
1438 
1439   bool transpose_c = false;
1440   Tensor c;
1441 
1442   // Cast result as matrix c
1443   if (result_strides[0] == 1 &&
1444       (result_sizes[1] == 1 || result_strides[1] >= std::max(int64_t{1}, result_sizes[0]))) {
1445     transpose_c = false;
1446     c = result.resolve_conj();
1447   } else if (result_strides[1] == 1 &&
1448              (result_sizes[0] == 1 || result_strides[0] >= std::max(int64_t{1}, result_sizes[1]))) {
1449     std::swap(m1, m2);
1450     std::swap(m1_sizes, m2_sizes);
1451     std::swap(m1_strides, m2_strides);
1452     transpose_c = true;
1453     c = result.resolve_conj();
1454   } else {
1455     transpose_c = false;
1456     // make c FORTRAN contiguous
1457     c = result.resolve_conj().transpose(0, 1).contiguous().transpose_(0, 1);
1458   }
1459 
1460   const int64_t m = result_sizes[transpose_c ? 1 : 0];
1461   const int64_t n = result_sizes[transpose_c ? 0 : 1];
1462   const int64_t k = m1_sizes[transpose_c ? 0 : 1];
1463 
1464   // Cast m1 as matrix a
1465   bool transpose_a = false;
1466   Tensor a;
1467   /* Need lda >= max(1, (transpose_a ? k : m)) */
1468   if (m1_strides[transpose_c ? 1 : 0] == 1 &&
1469       m1_strides[transpose_c ? 0 : 1] >= std::max(int64_t{1}, m)) {
1470     transpose_a = false;
1471     a = m1.resolve_conj();
1472   } else if (m1_strides[transpose_c ? 0 : 1] == 1 &&
1473              m1_strides[transpose_c ? 1 : 0] >= std::max(int64_t{1}, k)) {
1474     transpose_a = true;
1475     a = m1;
1476   } else {
1477     transpose_a = !transpose_c;
1478     a = m1.clone(at::MemoryFormat::Contiguous);
1479   }
1480 
1481   // Cast m2 as matrix b
1482   bool transpose_b = false;
1483   Tensor b;
1484   /* Need ldb >= max(1, (transpose_b ? n : k)) */
1485   if (m2_strides[transpose_c ? 1 : 0] == 1 &&
1486       m2_strides[transpose_c ? 0 : 1] >= std::max(int64_t{1}, k)) {
1487     transpose_b = false;
1488     b = m2.resolve_conj();
1489   } else if (m2_strides[transpose_c ? 0 : 1] == 1 &&
1490              m2_strides[transpose_c ? 1 : 0] >= std::max(int64_t{1}, n)) {
1491     transpose_b = true;
1492     b = m2;
1493   } else {
1494     transpose_b = !transpose_c;
1495     b = m2.clone(at::MemoryFormat::Contiguous);
1496   }
1497 
1498   const int64_t lda = a.strides()[(transpose_a == transpose_c) ? 1 : 0];
1499   const int64_t ldb = b.strides()[(transpose_b == transpose_c) ? 1 : 0];
1500   const int64_t ldc = c.strides()[transpose_c ? 0 : 1];
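  // Worked example (illustrative): a row-major contiguous `result` of shape
  // [m, n] has strides {n, 1}, so the second branch of the result-casting above
  // fires: m1 and m2 are swapped, transpose_c is set, and the gemm call below
  // effectively computes C^T = B^T * A^T in the column-major layout BLAS
  // expects, with ldc = result.strides()[0] = n.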
1501 
1502   // Always ensure the conjugation for c is resolved since there's no way to specify c's conjugation in the gemm call
1503   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c.is_conj());
1504 
1505   bool dispatched = false;
1506 #if defined(__aarch64__) && AT_MKLDNN_ACL_ENABLED()
1507   // On AArch64, if the LHS matrix in the BLAS routine is transposed but the RHS is not,
1508   // it is faster to call the oneDNN matrix multiplication primitive with RHS*LHS;
1509   // this will then call into the Arm® Compute Library (ACL) GEMM kernel, which also
1510   // supports running the kernel with BF16 instructions.
1511   if (transpose_c) {
1512     bool apply_heur = apply_mkldnn_matmul_heur(b.sizes()[0], b.sizes()[1], a.sizes()[1]);
1513     if (apply_heur && transpose_a && !transpose_b && result.scalar_type() == at::ScalarType::Float) {
1514       try {
1515         mkldnn_matmul(b, a, c, beta.to<float>(), alpha.to<float>());
1516         // We have dispatched to ACL GEMM for single precision float
1517         // so do not need to dispatch to BLAS GEMM below
1518         dispatched = true;
1519       } catch (const std::exception& e) {
1520         TORCH_WARN("mkldnn_matmul failed, switching to BLAS gemm:", e.what());
1521         at::globalContext().setUserEnabledMkldnn(false);
1522       }
1523     }
1524   }
1525 #endif
1526 
1527   if(!dispatched) {
1528     // Apply BLAS routine
1529     _AT_DISPATCH_ADDMM_TYPES(result.scalar_type(), "addmm_impl_cpu_", [&]{
1530           using opmath_t = at::opmath_type<scalar_t>;
1531           at::native::cpublas::gemm(
1532               transpose_a ? a.is_conj() ? TransposeType::ConjTranspose : TransposeType::Transpose : TransposeType::NoTranspose,
1533               transpose_b ? b.is_conj() ? TransposeType::ConjTranspose : TransposeType::Transpose : TransposeType::NoTranspose,
1534               m, n, k,
1535               alpha.to<opmath_t>(),
1536               a.const_data_ptr<scalar_t>(), lda,
1537               b.const_data_ptr<scalar_t>(), ldb,
1538               beta.to<opmath_t>(),
1539               c.mutable_data_ptr<scalar_t>(), ldc);
1540         });
1541   }
1542 
1543   if (!c.is_same(result)) {
1544     result.copy_(c);
1545   }
1546 }
1547 
1548 static void addbmm_impl_(
1549     Tensor &result, const Tensor &self, const Tensor &batch1, const Tensor &batch2, const Scalar& beta, const Scalar& alpha) {
1550   TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor");
1551   TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor");
1552   TORCH_CHECK(batch1.size(0) == batch2.size(0),
1553       "batch1 and batch2 must have same number of batches, got ",
1554       batch1.size(0), " and ", batch2.size(0));
1555   TORCH_CHECK(batch1.size(2) == batch2.size(1),
1556       "Incompatible matrix sizes for bmm (",
1557       batch1.size(1), "x", batch1.size(2), " and ",
1558       batch2.size(1), "x", batch2.size(2), ")");
1559 
1560   const int64_t dim1 = batch1.size(1);
1561   const int64_t dim2 = batch2.size(2);
1562   TORCH_CHECK(self.size(0) == dim1 && self.size(1) == dim2,
1563       "self tensor does not match matmul output shape");
1564 
1565   result.resize_as_(self);
1566 
1567   if (beta.to<c10::complex<double>>() != 0.0 && !self.is_same(result)) {
1568     result.copy_(self);
1569   }
1570 
1571   const int64_t num_batches = batch1.size(0);
1572 
1573   if (num_batches == 0) {
1574     if (beta.to<c10::complex<double>>() != 0.0) {
1575       result.mul_(beta);
1576     } else {
1577       result.zero_();
1578     }
1579     return;
1580   }
1581 
1582   auto adjusted_beta(beta);
1583   for (const auto batch : c10::irange(num_batches)) {
1584     result.addmm_(batch1[batch], batch2[batch], adjusted_beta, alpha);
1585     adjusted_beta = 1; // accumulate output once
1586   }
1587 }
1588 
1589 Tensor& addbmm_out(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, Tensor& result) {
1590   auto b_self = expand_size(self, {batch1.size(1), batch2.size(2)}, "addbmm_out");
1591   {
1592     at::NoNamesGuard guard;
1593     addbmm_impl_(result, *b_self, batch1, batch2, beta, alpha);
1594   }
1595   auto names = at::namedinference::propagate_names_for_addmm(batch1, batch2, self);
1596   at::namedinference::propagate_names_if_nonempty(result, names);
1597   return result;
1598 }
1599 
1600 Tensor &addbmm_(Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
1601   return native::addbmm_out(self, batch1, batch2, beta, alpha, self);
1602 }
1603 
1604 Tensor addbmm(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
1605   Tensor result = at::empty({0}, self.options());
1606   return native::addbmm_out(self, batch1, batch2, beta, alpha, result);
1607 }
1608 
1609 TORCH_IMPL_FUNC(addmm_out_cpu)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, const Tensor &result) {
1610   auto b_self = expand_size(self, {mat1.sizes()[0], mat2.sizes()[1]}, "addmm_out");
1611   {
1612     at::NoNamesGuard guard;
1613     addmm_impl_cpu_(const_cast<Tensor&>(result), *b_self, mat1, mat2, beta, alpha);
1614   }
1615 }
1616 
1617 TORCH_IMPL_FUNC(addmm_activation_out_cpu)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, bool use_gelu, const Tensor &result) {
1618   auto b_self = expand_size(self, {mat1.sizes()[0], mat2.sizes()[1]}, "addmm_out");
1619   {
1620     at::NoNamesGuard guard;
1621     addmm_impl_cpu_(const_cast<Tensor&>(result), *b_self, mat1, mat2, beta, alpha);
1622     if (use_gelu) {
1623       at::gelu_(const_cast<Tensor&>(result));
1624     } else {
1625       at::relu_(const_cast<Tensor&>(result));
1626     }
1627   }
1628 }
1629 
1630 TORCH_IMPL_FUNC(mm_out_cpu)(const Tensor & self, const Tensor & mat2, const Tensor & result) {
1631   {
1632     at::NoNamesGuard guard;
1633     addmm_impl_cpu_(const_cast<Tensor&>(result), result, self, mat2, 0, 1);
1634   }
1635 }
1636 
1637 template <typename scalar_t, bool is_bmm>
1638 inline void baddbmm_cpu_kernel(const Tensor& result, const Tensor& self, const Tensor& mat2, const Scalar& beta_, const Scalar& alpha_) {
1639   int64_t bs = result.size(0);
1640   int64_t is = result.size(1);
1641   int64_t js = result.size(2);
1642   int64_t ks = self.size(2);
1643 
1644   using opmath_t = at::opmath_type<scalar_t>;
1645   opmath_t alpha = alpha_.to<opmath_t>();
1646   opmath_t beta = beta_.to<opmath_t>();
1647 
1648   auto r0 = result.accessor<scalar_t, 3>();
1649   auto s0 = self.accessor<const scalar_t, 3>();
1650   auto m0 = mat2.accessor<const scalar_t, 3>();
1651 
1652   int64_t grain_size = std::max(internal::GRAIN_SIZE / (is * js * ks), (int64_t)1);
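  // e.g. with the default internal::GRAIN_SIZE (32768 at the time of writing)
  // and 8 x 8 x 8 matrices, is * js * ks = 512, so each parallel_for chunk
  // covers at least 32768 / 512 = 64 batch entries.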
1653   using opmath_t = at::opmath_type<scalar_t>;
1654   parallel_for(0, bs, grain_size, [&](int64_t b_begin, int64_t b_end) {
1655       for (const auto b : c10::irange(b_begin, b_end)) {
1656         auto r1 = r0[b];
1657         auto s1 = s0[b];
1658         auto m1 = m0[b];
1659         for (const auto i : c10::irange(is)) {
1660           auto r2 = r1[i];
1661           auto s2 = s1[i];
1662           for (const auto j : c10::irange(js)) {
1663             opmath_t acc_value = 0;//is_bmm ? opmath_t(0) : opmath_t(r2[j]);
1664             for (const auto k : c10::irange(ks)) {
1665               acc_value += static_cast<opmath_t>(s2[k]) *
1666                   static_cast<opmath_t>(m1[k][j]);
1667             }
1668             if (is_bmm) {
1669               r2[j] = acc_value;
1670             } else {
1671               // For beta == 0, r's original value is ignored, so NaNs/Infs in it do not propagate.
1672               if (beta == opmath_t{0}) {
1673                 r2[j] = alpha * acc_value;
1674               } else {
1675                 r2[j] = static_cast<opmath_t>(r2[j]) * beta + alpha * acc_value;
1676               }
1677             }
1678           }
1679         }
1680       }
1681     });
1682 }
1683 
1684 static void baddbmm_with_gemm_(const Tensor &result, const Tensor &mat1, const Tensor &mat2, const Scalar &beta_, const Scalar &alpha_) {
1685   TORCH_INTERNAL_ASSERT(result.is_contiguous());
1686 
1687   const auto result_sizes = result.sizes();
1688   const auto result_strides = result.strides();
1689   const auto mat1_strides = mat1.strides();
1690   const auto mat2_strides = mat2.strides();
1691   const auto mat1_sizes = mat1.sizes();
1692   const auto mat2_sizes = mat2.sizes();
1693 
1694   auto is_transposed = [](const c10::IntArrayRef& strides, const c10::IntArrayRef& sizes) {
1695     return strides[1] == 1 && strides[2] >= sizes[1];
1696   };
1697 
1698   // gemm expects fortran order matrices, so we swap argument order to transpose everything
1699   const auto transpose_a = is_transposed(mat2_strides, mat2_sizes);
1700   const auto transpose_b = is_transposed(mat1_strides, mat1_sizes);
1701 
1702   const int64_t batch_size = mat1_sizes[0];
1703   const int64_t m = result_sizes[2];
1704   const int64_t n = result_sizes[1];
1705   const int64_t k = mat2_sizes[1];
1706 
1707   const int64_t lda = mat2_strides[transpose_a ? 2 : 1];
1708   const int64_t ldb = mat1_strides[transpose_b ? 2 : 1];
1709   const int64_t ldc = result_strides[1];
1710 
1711   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(result.scalar_type(), "baddbmm_with_gemm", [&] {
1712     using opmath_t = at::opmath_type<scalar_t>;
1713     const auto alpha = alpha_.to<opmath_t>();
1714     const auto beta = beta_.to<opmath_t>();
1715     at::native::cpublas::gemm_batched_with_stride(
1716         transpose_a ? TransposeType::Transpose : TransposeType::NoTranspose,
1717         transpose_b ? TransposeType::Transpose : TransposeType::NoTranspose,
1718         batch_size, m, n, k, alpha,
1719         mat2.const_data_ptr<scalar_t>(), lda, mat2_strides[0],
1720         mat1.const_data_ptr<scalar_t>(), ldb, mat1_strides[0],
1721         beta,
1722         result.data_ptr<scalar_t>(), ldc, result_strides[0]);
1723   });
1724 }
1725 
1726 // This tries to apply some optimizations to bmm/baddbmm:
1727 // - When the operand size is small, computation is parallelized over the batch
1728 //   dimension using OMP and a naive matrix multiplication is applied.
1729 // - When the operand size is larger than the threshold, if compiled with MKL, MKL's batch gemm is used.
1730 // - Otherwise, we use a series of matrix multiplications.
1731 // The threshold of 400 for the first case has not been thoroughly benchmarked yet and may have room for
1732 // further tuning; it likely depends on the characteristics of the CPU (MKL builds will differ from
1733 // non-MKL builds, etc.), but it seems to be a reasonable starting point.
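// Worked example of the threshold (illustrative, assuming the oneDNN heuristic
// inside the function below did not already dispatch): for [B, 5, 5] @ [B, 5, 5]
// the per-matrix work is 5 * 5 * 5 = 125 < 400, so the naive kernel parallelized
// over the batch is used; for [B, 8, 8] @ [B, 8, 8], 8 * 8 * 8 = 512 >= 400, so
// the MKL batched gemm (or the per-batch addmm fallback) is used instead.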
1734 
1735 static inline void bmm_out_or_baddbmm_(const Tensor& self_or_result_, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, bool is_bmm_out) {
1736   // is_bmm_out: true for bmm_out, false for baddbmm_
1737   // self_or_result is "self" for baddbmm_ and "result" for bmm_out
1738   Tensor& self_or_result = const_cast<Tensor&>(self_or_result_);
1739 
1740   const auto batch1_sizes = batch1.sizes();
1741   const auto batch2_sizes = batch2.sizes();
1742 
1743   int64_t bs = batch1_sizes[0];
1744   int64_t contraction_size = batch1_sizes[2];
1745   int64_t res_rows = batch1_sizes[1];
1746   int64_t res_cols = batch2_sizes[2];
1747 
1748   // handle pathological cases that blas may not like
1749   if (self_or_result.numel() == 0) {
1750     return;
1751   } else if (contraction_size == 0) {
1752     if (is_bmm_out || (beta.to<c10::complex<double>>() == 0.0)) {
1753       self_or_result.zero_();
1754       return;
1755     } else {
1756       self_or_result.mul_(beta);
1757       return;
1758     }
1759   }
1760 
1761   auto batch_items_contiguous_or_transposed = [&](const Tensor& t) {
1762     const auto sizes = t.sizes();
1763     const auto strides = t.strides();
1764     // we do not care about a dimension's stride if its size equals 1
1765     return (strides[2] == 1 && (sizes[1] == 1 || strides[1] >= sizes[2])) ||
1766         (strides[1] == 1 && (sizes[2] == 1 || strides[2] >= sizes[1]));
1767   };
1768 
1769   bool apply_heur = apply_mkldnn_matmul_heur(batch1.sizes()[1], batch1.sizes()[2], batch2.sizes()[2]);
1770   if (apply_heur && use_mkldnn_matmul(batch1, batch2, self_or_result)) {
1771     try {
1772       mkldnn_matmul(batch1, batch2, self_or_result, beta.to<float>(), alpha.to<float>());
1773       return;
1774     } catch (const std::exception& e) {
1775       TORCH_WARN("mkldnn_matmul failed, switching to baddbmm:", e.what());
1776       at::globalContext().setUserEnabledMkldnn(false);
1777     }
1778   }
1779 
1780   if (contraction_size * res_rows * res_cols < 400) {
1781     if (is_bmm_out) {
1782       AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, batch1.scalar_type(), "bmm", [&] {
1783           baddbmm_cpu_kernel<scalar_t, true>(self_or_result, batch1, batch2, beta, alpha);
1784         });
1785     } else {
1786       AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, batch1.scalar_type(), "baddbmm", [&] {
1787           baddbmm_cpu_kernel<scalar_t, false>(self_or_result, batch1, batch2, beta, alpha);
1788         });
1789     }
1790   } else if (at::hasMKL() && ((
1791             self_or_result.scalar_type() != kBFloat16 &&
1792             self_or_result.scalar_type() != kHalf &&
1793             at::native::is_floating_point(self_or_result)) ||
1794             at::native::is_complex(self_or_result))
1795             && batch_items_contiguous_or_transposed(batch1)
1796             && batch_items_contiguous_or_transposed(batch2)
1797             && self_or_result.is_contiguous()) {
1798     baddbmm_with_gemm_(self_or_result, batch1, batch2, beta, alpha);
1799   } else { // split along batch dimension
1800 #ifdef C10_MOBILE
1801     /*
1802      * We only do multithreading when InferenceMode is enabled because various
1803      * thread-local state is not appropriately propagated through
1804      * at::parallel_for (e.g. RecordFunction-related state, dispatch key set).
1805      * The big concern is that if we use at::parallel_for where such state is not
1806      * propagated, then the dispatch machinery may work differently on the main
1807      * thread vs. other threads, leading to undefined behavior.
1808      * Thus it is recommended not to use at::parallel_for where the lambdas run
1809      * ops that go through the dispatcher.
1810      * For now we circumvent this with an InferenceMode guard in order to unlock
1811      * performance.
1812      * Longer term we probably want a separate API that explicitly calls out
1813      * the TLS that it propagates.
1814      * Also note that this is enabled for mobile only because the BLAS
1815      * implementation for non-mobile builds is already multithreaded.
1816      */
1817     // Benchmarking was done as follows:
1818     // bmm_test: operator benchmark under
1819     // benchmarks/operator_benchmarks/pt/bmm_test.py Ran this benchmark for
1820     // various matrix sizes on Samsung S8U
1821     const bool enable_multithreaded_bmm = c10::InferenceMode::is_enabled() &&
1822         bs >= 4 && res_rows >= 4 && res_cols >= 16 && contraction_size >= 16;
1823 #else
1824     const bool enable_multithreaded_bmm{false};
1825 #endif
1826     if (is_bmm_out) {
1827       if (enable_multithreaded_bmm) {
1828         auto bmm_out_fn = [&](uint64_t start, uint64_t end) {
1829           c10::InferenceMode guard;
1830           for (const auto b : c10::irange(start, end)) {
1831             auto r = self_or_result.select(0, b);
1832             addmm_impl_cpu_(
1833                 r, r, batch1.select(0, b), batch2.select(0, b), 0, 1);
1834           }
1835         };
1836         // Materialize if COW, since we cannot do so during parallel_for
1837         self_or_result.mutable_data_ptr();
1838         at::parallel_for(0, bs, 1, bmm_out_fn);
1839       } else {
1840         for (const auto b : c10::irange(bs)) {
1841           auto r = self_or_result.select(0, b);
1842           addmm_impl_cpu_(r, r, batch1.select(0, b), batch2.select(0, b), 0, 1);
1843         }
1844       }
1845     } else {
1846       if (enable_multithreaded_bmm) {
1847         auto bmm_fn = [&](uint64_t start, uint64_t end) {
1848           c10::InferenceMode guard;
1849           for (const auto b : c10::irange(start, end)) {
1850             self_or_result.select(0, b).addmm_(
1851                 batch1.select(0, b), batch2.select(0, b), beta, alpha);
1852           }
1853         };
1854         // Materialize if COW, since we cannot do so during parallel_for
1855         self_or_result.mutable_data_ptr();
1856         at::parallel_for(0, bs, 1, bmm_fn);
1857       } else {
1858         for (const auto b : c10::irange(bs)) {
1859           self_or_result.select(0, b).addmm_(
1860               batch1.select(0, b), batch2.select(0, b), beta, alpha);
1861         }
1862       }
1863     }
1864   }
1865   return;
1866 }
1867 
1868 static void conjugate_mutable_input_if_needed(const Tensor& self, bool conjugate) {
1869   if (conjugate) {
1870     self.conj_physical_();
1871   }
1872 }
1873 
1874 TORCH_IMPL_FUNC(baddbmm_out_cpu)
1875 (const Tensor & self, const Tensor & batch1, const Tensor & batch2, const Scalar& beta, const Scalar& alpha, const Tensor& result) {
1876     bool self_is_conj = result.is_conj();
1877     conjugate_mutable_input_if_needed(result, self_is_conj);
1878     bmm_out_or_baddbmm_(result, batch1.resolve_conj(), batch2.resolve_conj(), beta, alpha, false);
1879     conjugate_mutable_input_if_needed(result, self_is_conj);
1880   }
1881 
1882 TORCH_IMPL_FUNC(bmm_out_cpu)
1883 (const Tensor & batch1, const Tensor & batch2, const Tensor & result) {
1884     {
1885     NoNamesGuard guard;
1886     bool result_is_conj = result.is_conj();
1887     conjugate_mutable_input_if_needed(result, result_is_conj);
1888     bmm_out_or_baddbmm_(result, batch1.resolve_conj(), batch2.resolve_conj(), Scalar(0.0), Scalar(1.0), true);
1889     conjugate_mutable_input_if_needed(result, result_is_conj);
1890     }
1891 }
1892 
1893 Tensor& dot_out(const Tensor& self, const Tensor& other, Tensor& result) {
1894   auto output_device = result.device();
1895   auto input1_device = self.device();
1896   auto input2_device = other.device();
1897   // check if the input & output tensors are on the same device.
1898   TORCH_CHECK(
1899     (output_device == input1_device) && (input1_device == input2_device),
1900     "dot: Expected the output and input tensors to be on the "
1901     "same device, but got the output tensor on ", output_device,
1902     ", the 'input' tensor on ", input1_device, ", and the 'other' tensor on ", input2_device);
1903   at::native::resize_output(result, {});
1904   TORCH_CHECK(result.scalar_type() == self.scalar_type(),
1905            "result dtype ", result.scalar_type(), " does not match input dtype ", self.scalar_type());
1906   return result.fill_(self.dot(other));
1907 }
1908 
1909 Tensor& vdot_out(const Tensor& self, const Tensor& other, Tensor& result) {
1910   auto output_device = result.device();
1911   auto input1_device = self.device();
1912   auto input2_device = other.device();
1913   // check if the input & output tensors are on the same device.
1914   TORCH_CHECK(
1915     (output_device == input1_device) && (input1_device == input2_device),
1916     "vdot: Expected the output and input tensors to be on the "
1917     "same device, but got the output tensor on ", output_device,
1918     ", the 'input' tensor on ", input1_device, ", and the 'other' tensor on ", input2_device);
1919   at::native::resize_output(result, {});
1920   TORCH_CHECK(result.scalar_type() == self.scalar_type(),
1921            "result dtype ", result.scalar_type(), " does not match input dtype ", self.scalar_type());
1922   return result.fill_(self.vdot(other));
1923 }
1924 
1925 static bool should_fold(const Tensor& tensor1, const Tensor& tensor2, bool has_out) {
1926   // We check that we can fold the larger tensor into a matrix and dispatch to mm or mv rather than
1927   // to bmm. We want to make sure we can do so without incurring in any extra copy
1928   // to bmm. We want to make sure we can do so without incurring any extra copy
1929 
1930   // We order the tensors. t1 will be the larger tensor
1931   // We can always transpose tensor2 as the dimensions are always >= 1 (precondition from matmul)
1932   // and !tensor1_larger iff tensor2.dim() > tensor1.dim()
1933   const auto t1 = tensor1_larger ? MaybeOwned<Tensor>::borrowed(tensor1)
1934                                  : MaybeOwned<Tensor>::owned(tensor2.mT());
1935   const int64_t dim_t1 = t1->dim();
1936   const auto dim_t2 = tensor1_larger ? tensor2.dim()
1937                                      : tensor1.dim();
1938 
1939   // Just fold for dim_t1 >= 3 and (dim_t2 == 1 || dim_t2 == 2)
1940   if (!(dim_t1 >= 3 && dim_t2 <= 2)) {
1941     return false;
1942   }
1943 
1944   // In this case we *do* incur an extra copy to avoid creating an unnecessarily large tensor in the backward
1945   // Suppose we don't fold here. Let t1.shape = [b, m, n] and t2.shape = [n, k], as in a transformer.
1946   // t2 will be expanded to a tensor of shape [b, n, k] and then we do t1.bmm(t2_expanded)
1947   // The issue appears in the backward.
1948   // The output gradient g of this operation would have shape [b, m, k]
1949   // The backward wrt. t2 of bmm would be given by t1.mH @ g, which has shape [b, n, k]
1950   // Then, the backward of expand is simply `sum(0)`. As such, we are instantiating a tensor
1951   // of shape [b, n, k] unnecessarily, which may cause a large memory footprint, and in the
1952   // worst case, an OOM
1953   bool t2_requires_grad = tensor1_larger ? tensor2.requires_grad() : tensor1.requires_grad();
1954   if (t2_requires_grad && !has_out) {
1955     // We should be checking !at::GradMode::is_enabled(), but apparently
1956     // this regresses performance in some cases:
1957     // https://github.com/pytorch/pytorch/issues/118548#issuecomment-1916022394
1958     return true;
1959   }
1960 
1961   // Don't fold in this case, as we would have to call mm on the transposed tensor, the result
1962   // would be contiguous, and then we would need to transpose it and call contiguous on it, thus
1963   // having to copy the tensor
1964   if (tensor1.dim() == 2) {
1965     return false;
1966   }
1967 
1968   // Can always fold if the tensor is empty
1969   // This serves as a precondition for the code below
1970   if (t1->numel() == 0) {
1971     return true;
1972   }
1973 
1974   // t1->view(-1, t1->size(-1)) does not copy only when the first n-1 dimensions are contiguous
1975   // in the sense that t1_stride[i] = t1_stride[i+1]*t1_shape[i+1]
1976   const auto t1_shape = t1->sizes();
1977   const auto t1_strides = t1->strides();
1978   for (auto i = int64_t{0}; i < dim_t1 - int64_t{2}; ++i) {
1979     if (t1_strides[i] != t1_strides[i+1] * t1_shape[i+1]) {
1980       return false;
1981     }
1982   }
1983   return true;
1984 }
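// Illustrative example: with a 1D or 2D second operand, t1 of shape [2, 3, 4]
// with contiguous strides {12, 4, 1} satisfies
// t1_strides[0] == t1_strides[1] * t1_shape[1] (12 == 4 * 3), so it can be
// viewed as a [6, 4] matrix without a copy and should_fold returns true
// (assuming the earlier early-returns did not trigger); the same shape with
// strides {12, 1, 3} fails the check and matmul falls through to the bmm path.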
1985 
1986 /*
1987 Matrix product of two Tensors.
1988 The behavior depends on the dimensionality of the Tensors as follows:
1989 - If both Tensors are 1-dimensional, (1d) the dot product (scalar) is returned.
1990 - If the arguments are 2D - 1D or 1D - 2D, the matrix-vector product is returned.
1991 - If both arguments are 2D, the matrix-matrix product is returned.
1992 - If one of the arguments is ND with N >= 3 and the other is 1D or 2D, and some
1993   conditions on the strides apply (see should_fold) we fold the first N-1 dimensions
1994   of the ND argument to form a matrix, call mm or mv, reshape it back to ND and return it
1995 - Otherwise, we return bmm, after broadcasting and folding the batched dimensions if
1996   there's more than one
1997 */
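// Illustrative shape examples for the dispatch below (standard matmul
// broadcasting semantics, not exhaustive):
//   [3]       @ [3]          -> []           (dot)
//   [2, 3]    @ [3]          -> [2]          (mv)
//   [3]       @ [3, 4]       -> [4]          (unsqueeze + mm + squeeze)
//   [2, 3]    @ [3, 4]       -> [2, 4]       (mm)
//   [5, 2, 3] @ [3, 4]       -> [5, 2, 4]    (folded into a [10, 3] x [3, 4] mm when possible)
//   [5, 2, 3] @ [7, 1, 3, 4] -> [7, 5, 2, 4] (broadcast + bmm)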
1998 static Tensor _matmul_impl(
1999     Tensor& out,
2000     const Tensor& tensor1,
2001     const Tensor& tensor2) {
2002   NoNamesGuard guard;
2003   const auto dim_tensor1 = tensor1.dim();
2004   const auto dim_tensor2 = tensor2.dim();
2005 
2006   // This is checked up here to simplify the logic below
2007   // Note that the strings are just evaluated on failure, so almost always we just evaluate
2008   // the condition and move on
2009   TORCH_CHECK(dim_tensor1 != 0 && dim_tensor2 != 0,
2010               "both arguments to matmul need to be at least 1D, but they are ",
2011               dim_tensor1, "D and ", dim_tensor2, "D");
2012 
2013 
2014   const bool has_out = out.defined();
2015 
2016   if (has_out) {
2017     // Usually we would rely on the out= kernels we decompose into to check this, but
2018     // for matmul there is logic at the composite level that relies on this invariant.
2019     TORCH_CHECK(!(tensor1.requires_grad() || tensor2.requires_grad() || out.requires_grad()) || !at::GradMode::is_enabled(),
2020       "matmul(): functions with out=... arguments don't support automatic differentiation, "
2021       "but one of the arguments requires grad."
2022     );
2023   }
2024 
2025   if (dim_tensor1 == 1 && dim_tensor2 == 1) {
2026     return has_out ? at::dot_out(out, tensor1, tensor2) : tensor1.dot(tensor2);
2027   } else if (dim_tensor1 == 2 && dim_tensor2 == 1) {
2028     return has_out ? at::mv_out(out, tensor1, tensor2) : tensor1.mv(tensor2);
2029   } else if (dim_tensor1 == 1 && dim_tensor2 == 2) {
2030     return has_out ? at::mm_out(out, tensor1.unsqueeze(0), tensor2).squeeze_(0)
2031                    : tensor1.unsqueeze(0).mm(tensor2).squeeze_(0);
2032   } else if (dim_tensor1 == 2 && dim_tensor2 == 2) {
2033     return has_out ? at::mm_out(out, tensor1, tensor2) : tensor1.mm(tensor2);
2034   } else if (should_fold(tensor1, tensor2, has_out)) {
2035     // dim_tensor1 >=3 && (dim_tensor2 == 1 || dim_tensor2 == 2) ||
2036     // dim_tensor2 >=3 && (dim_tensor1 == 1 || dim_tensor1 == 2)
2037     // and at least one of the following two conditions hold
2038     // - the small tensor requires grad (see should_fold for the why)
2039     // - we can fold the larger tensor t1 into a matrix as t1.view(-1, t1.size(-1)) without copying
2040 
2041     // optimization: use mm instead of bmm by folding the batch of the larger tensor
2042     // into its leading matrix dimension
2043     const auto transpose = dim_tensor2 > dim_tensor1;
2044     const auto t1 = transpose ? MaybeOwned<Tensor>::owned(tensor2.mT())
2045                               : MaybeOwned<Tensor>::borrowed(tensor1);
2046     const auto t2 = !transpose ? MaybeOwned<Tensor>::borrowed(tensor2)
2047                                : dim_tensor1 == 2
2048                                    ? MaybeOwned<Tensor>::owned(tensor1.t())
2049                                    : MaybeOwned<Tensor>::borrowed(tensor1);
2050     // Invariant: t1->dim() >= 3 && (t2->dim() == 1 || t2->dim() == 2)
2051     //            and *t1 and *t2 are matmul-compatible
2052 
2053     // Why not t1->view(-1, sizes_1.back())?
2054     // If the last dim is 0, then view(-1, 0) won't work because the -1 becomes ambiguous.
2055     // This can happen in e.g. [3, 5, 0] @ [0, 0].
2056     const auto sizes_1 = t1->sizes();
2057     auto output_shape = DimVector(sizes_1.begin(), sizes_1.end() - 1);
2058     const auto folded_dim1 = c10::multiply_integers(output_shape);
2059 
2060     // Readjust output_shape if we are multiplying by a matrix
2061     const auto t2_is_matrix = t2->dim() == 2;
2062     if (t2_is_matrix) {
2063       output_shape.push_back(t2->sizes()[1]);
2064     }
2065     // This will almost always be a view.
2066     // It may not be a view if t2->requires_grad(). See should_fold for an explanation
2067     const auto t1_folded = t1->reshape({folded_dim1, sizes_1.back()});
2068     if (!has_out) {
2069       if (t2_is_matrix) {
2070         const auto output = at::_unsafe_view(t1_folded.mm(*t2), output_shape);
2071         // This copies if we perform a 2D @ 3D and the first tensor requires_grad
2072         // See should_fold for why.
2073         // If mm_out were differentiable, we could use it here, and pass a result with the
2074         // correct strides to avoid this unnecessary copy.
2075         return transpose ? output.mT().contiguous() : output;
2076       } else {
2077         return at::_unsafe_view(t1_folded.mv(*t2), output_shape);
2078       }
2079     } else {
2080       // See the !has_out branch for an explanation
2081       TORCH_INTERNAL_ASSERT(!(transpose && t2_is_matrix));
2082 
2083       // Resize output into the correct shape
2084       at::native::resize_output(out, output_shape);
2085 
2086       // We then reshape the output to the expected shape and call mm/mv
2087       // and transpose back if necessary
2088       auto reshaped_out = t2_is_matrix ? out.reshape({folded_dim1, t2->sizes().back()})
2089                                        : out.reshape({folded_dim1});
2090       if (t2_is_matrix) {
2091         at::mm_out(reshaped_out, t1_folded, *t2);
2092       } else {
2093         at::mv_out(reshaped_out, t1_folded, *t2);
2094       }
2095       if (!reshaped_out.is_alias_of(out)) {
2096         out.copy_(reshaped_out);
2097       }
2098       return out;
2099     }
2100   } else {
2101     // dim_tensor1 >= 3 || dim_tensor2 >= 3
2102     // We track m1 vs m2 separately even though they must match for nicer error messages
2103     const int64_t n = dim_tensor1 > 1 ? tensor1.sizes().cend()[-2] : 1LL;
2104     const int64_t m1 = tensor1.sizes().back();
2105     auto batch_tensor1 = tensor1.sizes().slice(0, std::max<int64_t>(dim_tensor1 - 2, 0LL));
2106     const int64_t m2 = dim_tensor2 > 1 ? tensor2.sizes().cend()[-2] : tensor2.sizes().front();
2107     const int64_t p = dim_tensor2 > 1 ? tensor2.sizes().back() : 1LL;
2108     const IntArrayRef batch_tensor2(tensor2.sizes().data(),
2109                                     std::max<int64_t>(dim_tensor2 - 2, 0LL));
2110 
2111     // Same optimization for the gradients as that in should_fold
2112     // If we're going to broadcast we force it to go through the should_fold branch
2113     if (dim_tensor1 == 3 && dim_tensor2 == 3 && batch_tensor1[0] != batch_tensor2[0]) {
2114       if (batch_tensor1[0] == 1 && (tensor1.requires_grad() || isTensorSubclassLike(tensor1))) {
2115         return _matmul_impl(out, tensor1.squeeze(0), tensor2);
2116       }
2117       if (batch_tensor2[0] == 1 && (tensor2.requires_grad() || isTensorSubclassLike(tensor2))) {
2118         return _matmul_impl(out, tensor1, tensor2.squeeze(0));
2119       }
2120     }
2121 
2122     auto output_shape = infer_size_dimvector(batch_tensor1, batch_tensor2);
2123     const int64_t expand_batch_product = c10::multiply_integers(output_shape);
2124 
2125     // flatten expanded batches
2126     const auto tensor1_expand_size = [&output_shape, n, m1]{ DimVector ret(output_shape);
2127                                                              ret.append({n, m1});
2128                                                              return ret; }();
2129     const auto tensor1_expanded = tensor1.expand(tensor1_expand_size)
2130                                          .reshape({expand_batch_product, n, m1});
2131     // We need to treat the dim_tensor2 == 1 case separately as broadcasting would not convert
2132     // a vector of shape (n,) into a batch of matrices of shape (*, n, 1)
2133     auto vector_rhs = dim_tensor2 == 1;
2134     const auto tensor2_expand_size = [&output_shape, m2, p, vector_rhs]{
2135       DimVector ret(output_shape);
2136       if (vector_rhs) {
2137         ret.push_back(m2);
2138       } else {
2139         ret.append({m2, p});
2140       }
2141       return ret;
2142     }();
2143     auto tensor2_expanded = tensor2.expand(tensor2_expand_size);
2144     if (vector_rhs) {
2145       tensor2_expanded = tensor2_expanded.reshape({expand_batch_product, m2}).unsqueeze(2);
2146     } else {
2147       tensor2_expanded = tensor2_expanded.reshape({expand_batch_product, m2, p});
2148     }
2149 
2150     if (dim_tensor1 > 1) {
2151       output_shape.push_back(n);
2152     }
2153     if (dim_tensor2 > 1) {
2154       output_shape.push_back(p);
2155     }
2156 
2157     if (!has_out) {
2158       if (vector_rhs) {
2159         return at::_unsafe_view(tensor1_expanded.bmm(tensor2_expanded).squeeze(-1), output_shape);
2160       } else {
2161         return at::_unsafe_view(tensor1_expanded.bmm(tensor2_expanded), output_shape);
2162       }
2163     } else {
2164       at::native::resize_output(out, output_shape);
2165       auto reshaped_out = out.reshape({expand_batch_product, n, p});
2166       at::bmm_out(reshaped_out, tensor1_expanded, tensor2_expanded);
2167       if (vector_rhs) {
2168         reshaped_out = reshaped_out.squeeze(-1);
2169       }
2170       if (!reshaped_out.is_alias_of(out)) {
2171         out.copy_(reshaped_out.view_as(out));
2172       }
2173       return out;
2174     }
2175   }
2176 }
2177 
2178 Tensor matmul(const Tensor & tensor1, const Tensor & tensor2) {
2179   auto maybe_outnames = namedinference::compute_matmul_outnames(tensor1, tensor2);
2180   at::Tensor result, unused;
2181   result = at::native::_matmul_impl(unused, tensor1, tensor2);
2182   namedinference::propagate_names_if_nonempty(result, maybe_outnames);
2183   return result;
2184 }
2185 
2186 Tensor& matmul_out(const Tensor & tensor1, const Tensor & tensor2, Tensor &result) {
2187   auto maybe_outnames = namedinference::compute_matmul_outnames(tensor1, tensor2);
2188   at::native::_matmul_impl(result, tensor1, tensor2);
2189   namedinference::propagate_names_if_nonempty(result, maybe_outnames);
2190   return result;
2191 }
2192 
2193 // torch.linalg.matmul, alias for torch.matmul
2194 Tensor linalg_matmul(const Tensor & tensor1, const Tensor & tensor2) {
2195   return at::matmul(tensor1, tensor2);
2196 }
2197 
2198 Tensor& linalg_matmul_out(const Tensor & tensor1, const Tensor & tensor2, Tensor &result) {
2199   return at::matmul_out(result, tensor1, tensor2);
2200 }
2201 
2202 // torch.linalg.diagonal, alias for torch.diagonal with dim1=-2, dim2=-1 as defaults
2203 Tensor linalg_diagonal(const Tensor& A, int64_t offset, int64_t dim1, int64_t dim2) {
2204   return A.diagonal(offset, dim1, dim2);
2205 }
2206 
2207 // helper methods for matrix_exp
2208 namespace {
2209 
2210 template <typename scalar_t, int ROW, int COL>
2211 using array2d = std::array<std::array<scalar_t, COL>, ROW>;
2212 
2213 // we consider 6 Taylor expansions of degree
2214 // 1, 2, 4, 8, 12, 18
2215 constexpr int total_n_degs = 6;
2216 
2217 Tensor operator_1_norm(const Tensor& tensor) {
2218   return std::get<0>(tensor.abs().sum(-2).max(-1));
2219 }
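// Worked example (illustrative): for [[1, -2], [3, 4]] the column-wise sums of
// absolute values are |1| + |3| = 4 and |-2| + |4| = 6, so the value returned
// above (the maximum column sum, i.e. the induced 1-norm) is 6.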
2220 
2221 // Allocates a buffer of uninitialized or zero values
2222 // of shape [n_copies, a.size()]
2223 Tensor _allocate_buffer(const Tensor& a, int n_copies, bool is_zero = false) {
2224   auto res = at::empty(
2225     {n_copies, a.size(0), a.size(1), a.size(2)},
2226     a.options().memory_format(at::MemoryFormat::Contiguous)
2227   );
2228 
2229   if (is_zero) {
2230     res.zero_();
2231   }
2232 
2233   return res;
2234 }
2235 
2236 // Makes `buffer` store the `num_matrices` matrices needed to
2237 // compute the matrix exponentials of different orders, i.e. the
2238 // first `num_matrices` matrices from the list l := {I, A, A^2, A^3, A^6}
2239 // in a contiguous block of memory such that
2240 // buffer[0, ...] = l[0], // I
2241 // buffer[1, ...] = l[1], // A
2242 // ...
2243 // buffer[num_matrices - 1, ...] = l[num_matrices - 1]
2244 void _fill_matrix_powers(Tensor& buffer, const Tensor& a, int num_matrices) {
2245   auto a_sizes_minus_last = a.sizes().vec();
2246   a_sizes_minus_last.pop_back();
2247   // fill I
2248   buffer.select(0, 0).copy_(
2249     at::diag_embed(
2250       at::ones({1}, buffer.options())
2251         .expand(a_sizes_minus_last)
2252     )
2253   );
2254 
2255   // fill a
2256   buffer.select(0, 1).copy_(a);
2257 
2258   // fill a^2
2259   if (2 <= num_matrices - 1) {
2260     // out for a^2
2261     auto view_out = buffer.select(0, 2);
2262     _matmul_impl(
2263       view_out,
2264       buffer.select(0, 1),
2265       buffer.select(0, 1)
2266     );
2267   }
2268 
2269   // fill a^3
2270   if (3 <= num_matrices - 1) {
2271     // out for a^3
2272     auto view_out = buffer.select(0, 3);
2273     _matmul_impl(
2274       view_out,
2275       buffer.select(0, 1),
2276       buffer.select(0, 2)
2277     );
2278   }
2279 
2280   // fill a^6
2281   if (4 <= num_matrices - 1) {
2282     // out for a^6
2283     auto view_out = buffer.select(0, 4);
2284     _matmul_impl(
2285       view_out,
2286       buffer.select(0, 3),
2287       buffer.select(0, 3)
2288     );
2289   }
2290 }
2291 
2292 inline Tensor _move_memory_if_cuda_input(
2293   const Tensor& mem,
2294   const Tensor& in
2295 ) {
2296   return (in.device().type() == at::kCUDA)
2297     ? mem.to(at::device_of(in).value())
2298     : mem;
2299 }
2300 
2301 // convert a 1D blob to a 2D Tensor of size [1, blob.size()]
2302 // such that blob.device() == in.device()
2303 // designed to be used with _compute_linear_combination
2304 template <typename scalar_t>
2305 inline Tensor _blob_to_Tensor(
2306   std::initializer_list<scalar_t> blob,
2307   const Tensor& in
2308 ) {
2309   // we convert to void* explicitly because begin() returns
2310   // a pointer to a constant.
2311   // Blob is assumed to be a 1D array, that is why
2312   // we also insert a fake dimension so that the result could directly
2313   // be used in _compute_linear_combination
2314   auto tensor = at::from_blob((void*)blob.begin(), blob.size(),
2315     c10::toRealValueType(in.scalar_type())).unsqueeze(0);
2316   return _move_memory_if_cuda_input(tensor, in);
2317 }
2318 
2319 template <typename scalar_t>
2320 inline Tensor _linear_combination(
2321     const Tensor& t,
2322     std::initializer_list<scalar_t> blob) {
2323   // _blob_to_Tensor converts blob to a 2D tensor for _compute_linear_combination.
2324   // If this tensor is of shape (1, *), the result of _compute_linear_combination
2325   // is going to be of shape (1, *t.shape) so we squeeze(0) so that
2326   // for any t with t.dim() >= 1: t.dim() == _compute_linear_combination(t, ...).dim().
2327   return at::native::_compute_linear_combination(
2328       t, _blob_to_Tensor<scalar_t>(blob, t))
2329     .squeeze(0);
2330 }
2331 
2332 // I + A
2333 Tensor compute_T1(const Tensor& A) {
2334   // 2 for {I, A}
2335   auto As = _allocate_buffer(A, 2);
2336   _fill_matrix_powers(As, A, 2);
2337   return As.sum(0);
2338 }
2339 
2340 // I + A + A^2 / 2
2341 Tensor compute_T2(const Tensor& A) {
2342   auto As = _allocate_buffer(A, 3);
2343   // 3 for {I, A, A^2}
2344   _fill_matrix_powers(As, A, 3);
2345   As.select(0, 2).div_(2.0);
2346   return As.sum(0);
2347 }
2348 
2349 // I + A + A^2 * (I / 2 + A / 6 + A^2 / 24)
2350 template <typename scalar_t>
2351 Tensor compute_T4(const Tensor& A) {
2352   auto As = _allocate_buffer(A, 4);
2353   // 3 for {I, A, A^2}
2354   _fill_matrix_powers(As, A, 3);
2355 
2356   // output for A^2 * (I / 2 + A / 6 + A^2 / 24)
2357   auto view_out = As.select(0, 3);
2358   _matmul_impl(
2359     view_out,
2360     // contains A^2
2361     As.select(0, 2),
2362     // computes (I / 2 + A / 6 + A^2 / 24)
2363     _linear_combination<scalar_t>(
2364       As.narrow(0, 0, 3),
2365       {1 / 2.0, 1 / 6.0, 1 / 24.0}
2366     )
2367   );
2368 
2369   // I + A + A^2 * (I / 2 + A / 6 + A^2 / 24)
2370   return _linear_combination<scalar_t>(
2371     As, {1.0, 1.0, 0.0, 1.0}
2372   );
2373 }
2374 
2375 template <typename scalar_t>
2376 Tensor compute_T8(const Tensor& A) {
2377   constexpr scalar_t sqrt_177 = 0.1330413469565007072504e+2;
2378   constexpr scalar_t x3 = 2. / 3.;
2379   constexpr scalar_t x1 = x3 * ((1. + sqrt_177) / 88.);
2380   constexpr scalar_t x2 = x3 * ((1. + sqrt_177) / 352.);
2381   constexpr scalar_t x4 = (-271. + 29. * sqrt_177) / (315. * x3);
2382   constexpr scalar_t x5 = (-11. + 11. * sqrt_177) / (1260. * x3);
2383   constexpr scalar_t x6 = (-99. + 11. * sqrt_177) / (5040. * x3);
2384   constexpr scalar_t x7 = (89. - sqrt_177) / (5040. * x3);
2385   constexpr scalar_t y2 = (857. - 58. * sqrt_177) / 630.;
2386 
2387   auto As = _allocate_buffer(A, 5);
2388   // 3 for {I, A, A^2}
2389   _fill_matrix_powers(As, A, 3);
2390 
2391   // output for A4
2392   auto view_out = As.select(0, 3);
2393   // A4 =  A2 * (x1 * A + x2 * A2)
2394   _matmul_impl(
2395     view_out,
2396     // As.select(0, 2) = A^2
2397     As.select(0, 2),
2398     _linear_combination<scalar_t>(
2399       // extract {A, A^2} from As
2400       As.narrow(0, 1, 2),
2401       {x1, x2}
2402     )
2403   );
2404 
2405   // output for A8
2406   view_out = As.select(0, 4);
2407   // A8 = (x3 * A2 + A4) * (x4 * I + x5 * A + x6 * A2 + x7 * A4)
2408   _matmul_impl(
2409     view_out,
2410     // x3 * A2 + A4
2411     _linear_combination<scalar_t>(
2412       As.narrow(0, 2, 2),
2413       {x3, 1.0}
2414     ),
2415     _linear_combination<scalar_t>(
2416       As.narrow(0, 0, 4),
2417       {x4, x5, x6, x7}
2418     )
2419   );
2420 
2421   // return I + A + y2 * A2 + A8;
2422   return _linear_combination<scalar_t>(
2423     As, {1.0, 1.0, y2, 0.0, 1.0}
2424   );
2425 }
2426 
2427 template <typename scalar_t>
2428 Tensor compute_T12(const Tensor& A) {
2429   constexpr int num_prods = 4;
2430   array2d<scalar_t, num_prods, num_prods> b = {{
2431     {
2432       9.0198e-16,
2433       0.46932117595418237389,
2434       -0.20099424927047284052,
2435       -0.04623946134063071740
2436     },
2437     {
2438       5.31597895759871264183,
2439       1.19926790417132231573,
2440       0.01179296240992997031,
2441       0.01108844528519167989
2442     },
2443     {
2444       0.18188869982170434744,
2445       0.05502798439925399070,
2446       0.09351590770535414968,
2447       0.00610700528898058230
2448     },
2449     {
2450       -2.0861320e-13,
2451       -0.13181061013830184015,
2452       -0.02027855540589259079,
2453       -0.00675951846863086359
2454     }
2455   }};
2456 
2457   // gather coefficients `b` from above into a tensor,
2458   // and move them to device `device_of(A)`
2459   auto bs = at::from_blob(
2460     reinterpret_cast<void*>(&b),
2461     {num_prods, num_prods},
2462     {num_prods, 1},
2463     c10::toRealValueType(A.scalar_type())
2464   );
2465   bs = _move_memory_if_cuda_input(bs, A);
2466 
2467   auto As = _allocate_buffer(A, num_prods);
2468   _fill_matrix_powers(As, A, num_prods);
2469 
2470   auto Bs = at::native::_compute_linear_combination(As, bs);
2471 
2472   // output for A6
2473   auto view_out = As.select(0, 0);
2474   // compute A6
2475   Bs.select(0, 2).add_(_matmul_impl(
2476     view_out,
2477     Bs.select(0, 3),
2478     Bs.select(0, 3)
2479   ));
2480 
2481   return Bs.select(0, 0).add_(_matmul_impl(
2482     view_out,
2483     Bs.select(0, 1).add_(Bs.select(0, 2)),
2484     Bs.select(0, 2)
2485   ));
2486 }
2487 
2488 template <typename scalar_t>
2489 Tensor compute_T18(const Tensor& A) {
2490   constexpr int num_prods = 5;
2491   array2d<scalar_t, num_prods, num_prods> b = {{
2492     {
2493       0.,
2494       -1.00365581030144618291e-01,
2495       -8.02924648241156932449e-03,
2496       -8.92138498045729985177e-04,
2497       0.
2498     },
2499     {
2500       0.,
2501       3.97849749499645077844e-01,
2502       1.36783778460411720168e+00,
2503       4.98289622525382669416e-01,
2504       -6.37898194594723280150e-04
2505     },
2506     {
2507       -1.09676396052962061844e+01,
2508       1.68015813878906206114e+00,
2509       5.71779846478865511061e-02,
2510       -6.98210122488052056106e-03,
2511       3.34975017086070470649e-05
2512     },
2513     {
2514       -9.04316832390810593223e-02,
2515       -6.76404519071381882256e-02,
2516       6.75961301770459654925e-02,
2517       2.95552570429315521194e-02,
2518       -1.39180257516060693404e-05
2519     },
2520     {
2521       0.,
2522       0.,
2523       -9.23364619367118555360e-02,
2524       -1.69364939002081722752e-02,
2525       -1.40086798182036094347e-05
2526     }
2527   }};
2528 
2529   // gather coefficients `b` from above into a tensor,
2530   // and move them to device `device_of(A)`
2531   auto bs = at::from_blob(
2532     reinterpret_cast<void*>(&b),
2533     {num_prods, num_prods},
2534     {num_prods, 1},
2535     c10::toRealValueType(A.scalar_type())
2536   );
2537   bs = _move_memory_if_cuda_input(bs, A);
2538 
2539   auto As = _allocate_buffer(A, num_prods);
2540   _fill_matrix_powers(As, A, num_prods);
2541 
2542   auto Bs = at::native::_compute_linear_combination(As, bs);
2543 
2544   // tmp buffer for this matrix product
2545   auto view_out = As.select(0, 0);
2546   // compute A9
2547   Bs.select(0, 3).add_(_matmul_impl(
2548     view_out,
2549     Bs.select(0, 0),
2550     Bs.select(0, 4))
2551   );
2552 
2553   return Bs.select(0, 1).add_(_matmul_impl(
2554     view_out,
2555     Bs.select(0, 2).add_(Bs.select(0, 3)),
2556     Bs.select(0, 3)
2557   ));
2558 }
2559 
2560 template <typename scalar_t>
2561 Tensor compute_T18_scale_square(
2562   const Tensor& a,
2563   const Tensor& norm,
2564   scalar_t theta
2565 ) {
2566   // Scale
2567   // We will eventually need repeated matrix multiplications (squarings) to
2568   // compute the result. For example, if `norm` equals [27, 6, 6, 0.05], we end
2569   // up with `s` = [4, 1, 1, 0], so we could obtain the result by computing
2570   // matrix[0]^(2^4), matrix[1]^(2^1) and matrix[2]^(2^1) one by one, but such a
2571   // "one by one" calculation would be quite slow.
2572   const auto s = (at::ceil(at::log2(norm / theta))).clamp(/*min=*/0);
2573   const auto pow2s = at::pow(2, -s);
2574   const auto a_scaled = a * pow2s.view({-1, 1, 1});
2575   auto mexp_scaled = at::native::compute_T18<scalar_t>(a_scaled);
2576 
2577   // Sort:
2578   // The inputs are square matrices, so once the matrices with the smaller
2579   // exponents have been raised to their powers, the only remaining work is the
2580   // extra squarings of the matrix with the largest exponent (2^4), which gives
2581   // us an opportunity to perform the matrix multiplications in batches.
2582   // The first step is to sort the tensor `s`, which lets us do the matrix
2583   // multiplications range by range. With the above example, `sorted_s` is
2584   // [0, 1, 1, 4]; we also keep the index info so we can compose the result back.
2585   auto [sorted_s, sorted_s_inds] = at::sort(s, /*dim=*/0);
2586   sorted_s = sorted_s.to(at::kLong);
2587   // Then we call `unique_consecutive` and use its counts to split `sorted_s`;
2588   // with the above example, `split_counts` is [1, 2, 1].
2589   auto split_counts = std::get<2>(at::unique_consecutive(sorted_s, true, /*return_counts=*/true));
2590   // We also need the index of the last element of each split, so we know how many
2591   // additional multiplications each split of matrices requires.
2592   // Notice that we do not need to compute the actual powers, because the
2593   // squarings accumulate across ranges.
2594   // With the above example, `mul_times` will be [0, 1, 3].
2595   auto split_edges = at::cumsum(split_counts, /*dim=*/0) - 1;
2596   auto unique_s = sorted_s.index_select(0, split_edges).clamp(/*min=*/0);
2597   auto mul_times = at::diff(unique_s, 1, -1, /*prepend=*/unique_s.new_zeros({1}));
2598 
2599   // Square
2600   auto section_values = at::cat({split_counts, mul_times}, 0).to(at::kCPU);
2601 
2602   TORCH_INTERNAL_ASSERT(section_values.is_contiguous());
2603   const auto section_numel = section_values.numel() / 2;
2604   auto scs = section_values.template data_ptr<int64_t>();
2605   auto pts = &scs[section_numel];
2606 
2607   // We now do the matrix multiplications in batches; with the above example:
2608   // 1. Square all matrices 0 times (`mul_times[0]`), then `slice` off the first
2609   // `split_counts[0]` matrices, keeping acc[1:] as the remainder,
2610   // 2. Square the remaining matrices 1 time and slice, keeping acc[2:],
2611   // 3. Square the remaining matrices 3 more times and slice off the rest.
2612   // All processed matrices are stored in `output_pieces`.
2613   std::vector<Tensor> output_pieces;
2614   auto acc = mexp_scaled.index_select(0, sorted_s_inds);
2615   for (int64_t i = 0; i < section_numel; ++i) {
2616     for (int64_t j = 0; j < pts[i]; j++) {
2617       // To avoid AMP autocasting caused by at::matmul
2618       auto acc_out = at::empty_like(acc);
2619       acc = at::matmul_out(acc_out, acc, acc);
2620     }
2621     output_pieces.push_back(acc.slice(0, 0, scs[i]));
2622     acc = acc.slice(0, scs[i]);
2623   }
2624 
2625   // Compose the result back
2626   auto output = at::cat(output_pieces, 0);
2627   return output.index_select(0, at::argsort(sorted_s_inds));
2628 }
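// Recap of the bookkeeping above as a worked example (float thetas, theta_18 ~ 3.01).
// The whole step relies on the scaling-and-squaring identity
//
//   exp(A) = (exp(A * 2^-s))^(2^s),   s = max(ceil(log2(||A||_1 / theta)), 0).
//
// For norm = [27, 6, 6, 0.05]:
//   s            = [4, 1, 1, 0]
//   sorted_s     = [0, 1, 1, 4]
//   split_counts = [1, 2, 1]
//   mul_times    = [0, 1, 3]
// so the sorted batch is squared 0 times and one matrix is peeled off, the rest are
// squared once and two more are peeled off, and the last matrix is squared 3 more
// times (1 + 3 = 4 squarings in total) before the pieces are put back in the
// original order via argsort(sorted_s_inds).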
2629 
2630 template <typename scalar_t>
2631 Tensor mexp_impl(
2632   const Tensor& a,
2633   std::array<scalar_t, total_n_degs> thetas,
2634   bool compute_highest_degree_approx = false
2635 ) {
2636   const auto norm = operator_1_norm(a);
2637   const auto batch_size = a.size(0);
2638   if (batch_size > 1) {
2639     compute_highest_degree_approx = true;
2640   }
2641 
2642   if (!compute_highest_degree_approx) {
2643     // To avoid returning a seemingly "normal" result for a matrix that contains
2644     // NaN values, we pre-fill `res` with NaN; an input with NaN values matches no
2645     // norm interval below, so its slot is skipped and the NaN-filled `res` is returned.
2646     auto res = at::full_like(a, std::numeric_limits<double>::quiet_NaN(), {},
2647                              at::MemoryFormat::Contiguous);
2648     // `norm_cpu` is used to decide which Tensors require which approximation
2649     // based on their norm. This decision takes place on CPU.
2650     // It requires moving data back and forth between devices when `a` is on CUDA,
2651     // but at the cost of only a single CPU-CUDA synchronization (instead of 6),
2652     // and better performance overall (benchmarked).
2653     const auto norm_cpu = (a.device().type() == at::kCUDA)
2654       ? norm.to(at::kCPU) : norm;
2655 
2656     constexpr std::array<
2657       Tensor(*)(const Tensor&),
2658       total_n_degs - 1>
2659     compute_Ts = {
2660       compute_T1, compute_T2, compute_T4<scalar_t>,
2661       compute_T8<scalar_t>, compute_T12<scalar_t>
2662     };
2663 
2664     for (int i = 0; i < total_n_degs - 1; ++i) {
2665       auto norm_lower_bound = (i == 0) ? static_cast<scalar_t>(-1) : thetas[i - 1];
2666       auto norm_upper_bound = thetas[i];
2667       // nonzero returns a 2D tensor, hence squeeze(-1) to make it 1D
2668       auto idx_curr_norm_interval = (
2669         (norm_lower_bound < norm_cpu) * (norm_cpu <= norm_upper_bound)
2670       ).nonzero().squeeze(-1);
2671 
2672       if (idx_curr_norm_interval.numel()) {
2673         auto idx_to_device = _move_memory_if_cuda_input(
2674           idx_curr_norm_interval, a
2675         );
2676         auto sub_a = at::index_select(a, 0, idx_to_device);
2677         res.index_put_({idx_to_device}, compute_Ts[i](sub_a));
2678       }
2679     }
2680 
2681     // nonzero returns a 2D tensor, hence squeeze(-1) to make it 1D
2682     auto idx_large_norm = (norm_cpu >= thetas[total_n_degs - 2])
2683       .nonzero().squeeze(-1);
2684 
2685     if (idx_large_norm.numel()) {
2686       auto idx_to_device = _move_memory_if_cuda_input(
2687         idx_large_norm, a
2688       );
2689       auto a_large_norm = at::index_select(a, 0, idx_to_device);
2690       auto large_norm_subset = at::index_select(norm, 0, idx_to_device);
2691       auto mexp_out = compute_T18_scale_square(
2692         a_large_norm,
2693         large_norm_subset,
2694         thetas[total_n_degs - 1]
2695       );
2696       res.index_put_({idx_large_norm}, mexp_out);
2697     }
2698     return res;
2699   }
2700 
2701   return compute_T18_scale_square(
2702     a, norm,
2703     thetas[total_n_degs - 1]
2704   );
2705 }
2706 
2707 // matrix exponential
2708 Tensor mexp(const Tensor& a, bool compute_highest_degree_approx = false) {
2709   // squash batch dimensions to one dimension for simplicity
2710   const auto a_3d = a.view({-1, a.size(-2), a.size(-1)});
2711 
2712   if (a.scalar_type() == at::ScalarType::Float
2713       || a.scalar_type() == at::ScalarType::ComplexFloat) {
2714     constexpr std::array<float, total_n_degs> thetas_float = {
2715       1.192092800768788e-07, // deg 1
2716       5.978858893805233e-04, // deg 2
2717       5.116619363445086e-02, // deg 4
2718       5.800524627688768e-01, // deg 8
2719       1.461661507209034e+00, // deg 12
2720       3.010066362817634e+00  // deg 18
2721     };
2722 
2723     return mexp_impl<float>(a_3d, thetas_float, compute_highest_degree_approx)
2724       .view(a.sizes());
2725   }
2726   else { // if Double or ComplexDouble
2727     constexpr std::array<double, total_n_degs> thetas_double = {
2728       2.220446049250313e-16, // deg 1
2729       2.580956802971767e-08, // deg 2
2730       3.397168839976962e-04, // deg 4
2731       4.991228871115323e-02, // deg 8
2732       2.996158913811580e-01, // deg 12
2733       1.090863719290036e+00  // deg 18
2734     };
2735 
2736     return mexp_impl<double>(a_3d, thetas_double, compute_highest_degree_approx)
2737       .view(a.sizes());
2738   }
2739 }
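// Illustration of the threshold routing above for a single float matrix (batch of one),
// using the thetas_float values:
//   ||A||_1 = 0.3  ->  5.1166e-02 < 0.3 <= 5.8005e-01  ->  compute_T8 (degree 8)
//   ||A||_1 = 5.0  ->  above the degree-12 bound 1.4617 ->  compute_T18_scale_square,
//                      scaled against theta_18 ~ 3.0101
// Inputs with more than one matrix in the batch always take the
// compute_T18_scale_square path (see mexp_impl).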
2740 
2741 // TODO This should be deprecated in favor of linalg_matrix_exp_differential
2742 //      in FunctionsManual.cpp
2743 template <typename func_t>
2744 Tensor backward_analytic_function_of_a_matrix(
2745     const Tensor& self, const Tensor& grad,
2746     const func_t& function_of_a_matrix
2747   ) {
2748   auto self_transposed = self.mH();
2749   auto self_transposed_sizes = self_transposed.sizes().vec();
2750   self_transposed_sizes[self.dim() - 2] <<= 1;
2751   self_transposed_sizes[self.dim() - 1] <<= 1;
2752 
2753   auto n = self_transposed.size(-1);
2754   auto meta_grad = at::zeros(self_transposed_sizes, grad.options());
2755   meta_grad.narrow(-2, 0, n).narrow(-1, 0, n).copy_(self_transposed);
2756   meta_grad.narrow(-2, n, n).narrow(-1, n, n).copy_(self_transposed);
2757   meta_grad.narrow(-2, 0, n).narrow(-1, n, n).copy_(grad);
2758 
2759   auto grad_input = function_of_a_matrix(meta_grad)
2760     .narrow(-2, 0, n).narrow(-1, n, n);
2761   return grad_input;
2762 }
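// The construction above uses the standard block-triangular identity for an
// analytic function f (here f = matrix_exp):
//
//   f( [ X  E ] )  =  [ f(X)  Df(X)[E] ]
//      [ 0  X ]       [  0      f(X)  ]
//
// With X = self.mH() and E = grad, the top-right n x n block read out by
// narrow(-2, 0, n).narrow(-1, n, n) is the Frechet derivative of f at self.mH()
// in the direction grad, which is the gradient returned by matrix_exp_backward.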
2763 } // end anon namespace
2764 
2765 // Computes the matrix exponential for a given batch of square matrices.
2766 // The implementation is based on:
2767 //
2768 // Bader, P.; Blanes, S.; Casas, F.
2769 // Computing the Matrix Exponential with an Optimized Taylor Polynomial Approximation.
2770 // Mathematics 2019, 7, 1174.
2771 //
2772 Tensor linalg_matrix_exp(const Tensor& a) {
2773   squareCheckInputs(a, "linalg.matrix_exp");
2774   checkFloatingOrComplex(a, "linalg.matrix_exp");
2775 
2776   NoTF32Guard disable_tf32;
2777 
2778   // Trivial cases
2779   const auto n = a.size(-1);
2780   if (n == 0) {
2781     return a.clone();
2782   } else if (n == 1) {
2783     return a.exp();
2784   } else {
2785     return at::native::mexp(a);
2786   }
2787 }
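// A minimal usage sketch, for illustration:
//
//   auto A = at::zeros({2, 2});
//   A[0][1] = 1.0;                        // nilpotent: A*A == 0
//   auto E = at::linalg_matrix_exp(A);    // exactly [[1, 1], [0, 1]]
//   auto E3 = at::linalg_matrix_exp(at::eye(3));  // exp(1) * I ~ 2.71828 * I
//
// n == 1 inputs reduce to elementwise exp and n == 0 inputs are simply cloned,
// as handled above.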
2788 
2789 // Alias
2790 Tensor matrix_exp(const Tensor& a) {
2791   return at::linalg_matrix_exp(a);
2792 }
2793 
2794 // TODO This should be deprecated in favor of linalg_matrix_exp_differential
2795 //      in FunctionsManual.cpp
2796 Tensor matrix_exp_backward(const Tensor& self, const Tensor& grad) {
2797   NoTF32Guard disable_tf32;
2798   return backward_analytic_function_of_a_matrix(
2799     self, grad,
2800     [](const Tensor& a) {
2801       return a.matrix_exp();
2802     }
2803   );
2804 }
2805 
2806 TORCH_IMPL_FUNC(linalg_vector_norm_out)(const Tensor& self, const Scalar& scalar_ord, OptionalIntArrayRef opt_dim, bool keepdim, std::optional<ScalarType> opt_dtype, const Tensor& result) {
2807   // Casting a large integer to a double will just introduce an error for
2808   // values with magnitude larger than 2^53 (same for negative numbers), so that's fine.
2809   auto ord = scalar_ord.toDouble();
2810   auto dim = opt_dim.value_or(IntArrayRef{});
2811   auto size = self.sizes();
2812   auto ndim = self.dim();
2813 
2814   auto opt_dim_ = dim.vec();
2815   maybe_wrap_dims(opt_dim_, ndim);
2816 
2817   using Int = IntArrayRef::value_type;
2818   std::vector<Int> all_dim(ndim);
2819   std::iota(all_dim.begin(), all_dim.end(), 0);
2820 
2821   bool is_all_reduce = !opt_dim.has_value() || opt_dim.value().empty();
2822   auto reduce_dim = is_all_reduce ? all_dim : opt_dim_;
2823 
2824   bool is_reduce_over_1D_vector = true;
2825   for (auto i : reduce_dim) {
2826     if (size[i] != 1){
2827       is_reduce_over_1D_vector = false;
2828       break;
2829     }
2830   }
2831 
2832   if (is_reduce_over_1D_vector) {
2833     Tensor self_;
2834     if (opt_dtype.has_value()) {
2835       self_ = self.to(*opt_dtype);
2836     } else {
2837       self_ = self;
2838     }
2839     if (ord != 0.0) {
2840       keepdim ? at::abs_outf(self_, const_cast<Tensor&>(result)) : at::abs_outf(self_.squeeze(reduce_dim), const_cast<Tensor&>(result));
2841     } else {
2842       keepdim ? at::ne_outf(self_, 0, const_cast<Tensor&>(result)) : at::ne_outf(self_.squeeze(reduce_dim), 0, const_cast<Tensor&>(result));
2843     }
2844     return;
2845   }
2846 
2847   // No need to handle opt_dtype explicitly as it is already encoded in the dtype of result
2848 
2849   // https://github.com/pytorch/pytorch/issues/52648
2850   // Reductions always use `std::abs` to compute the absolute value. In the backward of this
2851   // function, we need to locate the index that was selected as the largest value. To do so
2852   // we do self.abs() == result to locate the index of the largest element.
2853   // Now, self.abs() may dispatch to a vectorized implementation which gives slightly different
2854   // results to the std::abs(std::complex<T>) implementation.
2855   // As such, to be able to compute the correct index in the backward, we need to use self.abs()
2856   // both in the forward and in the backward
2857   Tensor self_;
2858   if (self.is_cpu() && self.is_complex() && std::abs(ord) == INFINITY) {
2859     if (opt_dtype.has_value()) {
2860       self_ = self.to(*opt_dtype).abs();
2861     } else {
2862       self_ = self.abs();
2863     }
2864   } else {
2865     self_ = self;
2866   }
2867 
2868   auto iter = make_reduction("vector_norm", const_cast<Tensor&>(result), self_, dim, keepdim, result.scalar_type());
2869   norm_stub(iter.device_type(), iter, ord);
2870 }
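// For reference, linalg.vector_norm computes ||x||_ord = (sum_i |x_i|^ord)^(1/ord)
// over the reduced dims, with ord=inf -> max |x_i|, ord=-inf -> min |x_i| and
// ord=0 -> number of non-zero entries (hence the at::ne_outf fast path above). E.g.
//
//   auto x = at::tensor({3.0, -4.0});
//   at::linalg_vector_norm(x, 2);         // 5
//   at::linalg_vector_norm(x, INFINITY);  // 4
//   at::linalg_vector_norm(x, 0);         // 2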
2871 
2872 static void _linalg_matrix_norm_checks(const Tensor& A, std::vector<int64_t>& dim, std::optional<ScalarType> opt_dtype, bool low_precision) {
2873   // A
2874   at::native::checkIsMatrix(A, "linalg.matrix_norm");
2875   at::native::checkFloatingOrComplex(A, "linalg.matrix_norm", /*low_precision*/low_precision);
2876 
2877   // dim
2878   TORCH_CHECK(dim.size() == 2, "linalg.matrix_norm: dim must be a 2-tuple. Got ", dim);
2879   // wrap first to identify weird scenarios like A.ndim = 2, dim = (1, -1)
2880   // dim is modified in place while wrapping it
2881   maybe_wrap_dims(dim, A.dim());
2882   TORCH_CHECK(dim[0] != dim[1], "linalg.matrix_norm: dims must be different. Got (", dim[0], ", ", dim[1], ")");
2883 
2884   // dtype
2885   at::detail::check_linalg_norm_dtype(opt_dtype, A.scalar_type(), "linalg.matrix_norm");
2886 }
2887 
2888 Tensor linalg_matrix_norm(
2889     const Tensor& A,
2890     const Scalar& scalar_ord,
2891     IntArrayRef dim,
2892     bool keepdim,
2893     std::optional<ScalarType> opt_dtype) {
2894   // Check ord first as it will be used in the dtype check of A
2895   auto ord = scalar_ord.toDouble();
2896   auto abs_ord = std::abs(ord);
2897   TORCH_CHECK(abs_ord == 2. || abs_ord == 1. || abs_ord == INFINITY, "linalg.matrix_norm: Order ", ord, " not supported.");
2898 
2899   auto dim_ = dim.vec();
2900   // Check A, dim, and dtype
2901   _linalg_matrix_norm_checks(A, dim_, opt_dtype, /*low_precision*/abs_ord != 2.);
2902 
2903   auto max_min = [ord, keepdim](const Tensor& A, int64_t dim) { return ord > 0 ? A.amax(dim, keepdim) : A.amin(dim, keepdim); };
2904   if (abs_ord == 2.) {
2905     // Move dims to the end
2906     auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], A.dim());
2907 
2908     auto A_ = opt_dtype.has_value() ? A.to(*opt_dtype) : A;
2909     auto result = max_min(at::linalg_svdvals(A_.permute(permutation)), -1);
2910     if (keepdim) {
2911       auto permutation_reverse = create_reverse_permutation(std::move(permutation));
2912       result = result.unsqueeze(-1).permute(permutation_reverse);
2913     }
2914     return result;
2915   } else {  // 1, -1, inf, -inf
2916     // The infty norm is like the 1 norm on the transposed matrix
2917     if (abs_ord == INFINITY) {
2918       std::swap(dim_[0], dim_[1]);
2919     }
2920 
2921     // If the first reduction removes one dim from the front (dim_[0] < dim_[1]), after this
2922     // reduction dim_[1] will be off by one
2923     if (!keepdim && (dim_[0] < dim_[1])) {
2924       dim_[1]--;
2925     }
2926     return max_min(at::linalg_vector_norm(A, 1., {dim_[0]}, keepdim, opt_dtype), dim_[1]);
2927   }
2928 }
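// With the default dim = {-2, -1}, the supported scalar orders reduce to:
//   ord = +2 / -2    -> largest / smallest singular value (via linalg_svdvals),
//   ord = +1 / -1    -> max / min absolute column sum,
//   ord = +inf/-inf  -> max / min absolute row sum (hence the dim swap above),
// e.g. at::linalg_matrix_norm(A, 1) equals A.abs().sum(-2).amax(-1).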
2929 
2930 Tensor& linalg_matrix_norm_out(
2931     const Tensor& A,
2932     const Scalar& ord,
2933     IntArrayRef dim,
2934     bool keepdim,
2935     std::optional<ScalarType> opt_dtype,
2936     Tensor& result) {
2937   checkSameDevice("linalg.matrix_norm", A, result);
2938   auto out = at::linalg_matrix_norm(A, ord, dim, keepdim, opt_dtype);
2939   TORCH_CHECK(out.scalar_type() == result.scalar_type(),
2940               "linalg.matrix_norm expected out tensor dtype ", out.scalar_type(),
2941               " but got: ", result.scalar_type());
2942   at::native::resize_output(result, out.sizes());
2943   result.copy_(out);
2944   return result;
2945 }
2946 
2947 // fro / nuc
2948 Tensor linalg_matrix_norm(
2949     const Tensor& A,
2950     c10::string_view ord,
2951     IntArrayRef dim,
2952     bool keepdim,
2953     std::optional<ScalarType> opt_dtype) {
2954   // Check ord first as it will be used in the dtype check of A
2955   TORCH_CHECK(ord == "fro" || ord == "nuc", "linalg.matrix_norm: Order ", ord, " not supported.");
2956 
2957   auto dim_ = dim.vec();
2958   // Check A, dim, and dtype
2959   _linalg_matrix_norm_checks(A, dim_, opt_dtype, /*low_precision*/ord != "nuc");
2960 
2961   if (ord == "fro") {
2962     return at::linalg_vector_norm(A, 2, dim_, keepdim, opt_dtype);
2963   } else {  // nuc
2964     auto A_ = opt_dtype.has_value() ? A.to(*opt_dtype) : A;
2965 
2966     // Move dims to the end
2967     auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], A_.dim());
2968     auto result = at::linalg_svdvals(A_.permute(permutation)).sum(-1, keepdim);
2969     if (keepdim) {
2970       auto permutation_reverse = create_reverse_permutation(std::move(permutation));
2971       result = result.unsqueeze(-1).permute(permutation_reverse);
2972     }
2973     return result;
2974   }
2975 }
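// 'fro' is the elementwise 2-norm over the two matrix dims and 'nuc' is the sum of
// singular values. For A = diag(3, 4):
//   at::linalg_matrix_norm(A, "fro");  // sqrt(3^2 + 4^2) = 5
//   at::linalg_matrix_norm(A, "nuc");  // 3 + 4 = 7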
2976 
2977 Tensor& linalg_matrix_norm_out(
2978     const Tensor& A,
2979     c10::string_view ord,
2980     IntArrayRef dim,
2981     bool keepdim,
2982     std::optional<ScalarType> opt_dtype,
2983     Tensor& result) {
2984   checkSameDevice("linalg.matrix_norm", A, result);
2985   auto out = at::linalg_matrix_norm(A, ord, dim, keepdim, opt_dtype);
2986   TORCH_CHECK(out.scalar_type() == result.scalar_type(),
2987               "linalg.matrix_norm expected out tensor dtype ", out.scalar_type(),
2988               " but got: ", result.scalar_type());
2989   at::native::resize_output(result, out.sizes());
2990   result.copy_(out);
2991   return result;
2992 }
2993 
2994 // Numerical or None norms
2995 Tensor linalg_norm(const Tensor& X, const std::optional<Scalar>& opt_ord, OptionalIntArrayRef opt_dim, bool keepdim, std::optional<ScalarType> opt_dtype) {
2996   if (opt_dim.has_value()) {
2997     TORCH_CHECK(opt_dim->size() == 1 || opt_dim ->size() == 2, "linalg.norm: If ",
2998               "dim is specified, it must be of length 1 or 2. Got ", *opt_dim);
2999   } else {
3000     if (opt_ord.has_value()) {
3001       TORCH_CHECK(X.dim() == 1 || X.dim() == 2, "linalg.norm: If ",
3002                   "dim is not specified but ord is, the input must be 1D or 2D. Got ", X.dim(), "D.");
3003     }
3004   }
3005 
3006   // If ord=None, we'll always use the 2-norm or frob norm (which are the same) so we go through
3007   // vector_norm
3008   if (opt_ord.has_value() &&
3009        ((opt_dim.has_value() && opt_dim->size() == 2) ||
3010         (!opt_dim.has_value() && X.dim() == 2))) {
3011     using Int = IntArrayRef::value_type;
3012     auto dim = opt_dim.has_value() ? opt_dim.value().vec() : std::vector<Int>{0, 1};
3013     return at::linalg_matrix_norm(X, *opt_ord, dim, keepdim, opt_dtype);
3014   } else {
3015     auto scalar_ord = opt_ord.value_or(Scalar(2.));
3016     return at::linalg_vector_norm(X, scalar_ord, opt_dim, keepdim, opt_dtype);
3017   }
3018 }
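// Dispatch summary, for illustration:
//   at::linalg_norm(x)      // 1D x, ord=None  -> vector 2-norm
//   at::linalg_norm(A, 2)   // 2D A, ord given -> matrix 2-norm (largest singular value)
//   at::linalg_norm(A)      // 2D A, ord=None  -> vector 2-norm of all elements ('fro')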
3019 
3020 Tensor& linalg_norm_out(const Tensor& X, const std::optional<Scalar>& opt_ord, OptionalIntArrayRef opt_dim, bool keepdim, std::optional<ScalarType> opt_dtype, Tensor& result) {
3021   checkSameDevice("linalg.norm", X, result);
3022   auto out = at::linalg_norm(X, opt_ord, opt_dim, keepdim, opt_dtype);
3023   TORCH_CHECK(out.scalar_type() == result.scalar_type(),
3024               "linalg.norm expected out tensor dtype ", out.scalar_type(),
3025               " but got: ", result.scalar_type());
3026   at::native::resize_output(result, out.sizes());
3027   result.copy_(out);
3028   return result;
3029 }
3030 
3031 // Frobenius and nuclear norms
3032 Tensor linalg_norm(const Tensor& X, c10::string_view ord, OptionalIntArrayRef opt_dim, bool keepdim, std::optional<ScalarType> opt_dtype) {
3033   if (opt_dim.has_value()) {
3034     TORCH_CHECK(opt_dim->size() == 1 || opt_dim ->size() == 2, "linalg.norm: If ",
3035               "dim is specified, it mut be of length 1 or 2. Got ", *opt_dim);
3036   } else {
3037     TORCH_CHECK(X.dim() == 1 || X.dim() == 2, "linalg.norm: If ",
3038                 "dim is not specified but ord is, the input must be 1D or 2D. Got ", X.dim(), "D.");
3039   }
3040   using Int = IntArrayRef::value_type;
3041   auto dim = opt_dim.has_value() ? opt_dim.value().vec() : std::vector<Int>{0, 1};
3042   return at::linalg_matrix_norm(X, ord, dim, keepdim, opt_dtype);
3043 }
3044 
3045 Tensor& linalg_norm_out(const Tensor& X, c10::string_view ord, OptionalIntArrayRef opt_dim, bool keepdim, std::optional<ScalarType> opt_dtype, Tensor& result) {
3046   checkSameDevice("linalg.norm", X, result);
3047   auto out = at::linalg_norm(X, ord, opt_dim, keepdim, opt_dtype);
3048   TORCH_CHECK(out.scalar_type() == result.scalar_type(),
3049               "linalg.norm expected out tensor dtype ", out.scalar_type(),
3050               " but got: ", result.scalar_type());
3051   at::native::resize_output(result, out.sizes());
3052   result.copy_(out);
3053   return result;
3054 }
3055 
3056 ////////////////////////////////////////////////////////////////////////////////
3057 //                              Frobenius Norm                                //
3058 ////////////////////////////////////////////////////////////////////////////////
3059 
3060 Tensor frobenius_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
3061   auto device = self.device();
3062   if (self.layout() == Layout::Strided && (device == kCPU || device == kCUDA || device == kMeta)) {
3063     TORCH_WARN_ONCE(
3064       "at::frobenius_norm is deprecated and it is just left for JIT compatibility. ",
3065       "It will be removed in a future PyTorch release. Please use ",
3066       "`linalg.vector_norm(A, 2., dim, keepdim)` instead"
3067     );
3068   }
3069   // This frobenius norm is just wrong, but well
3070   TORCH_CHECK(dim.size() <= 2,
3071               "Expected at most 2 dimensions, but got ", dim.size(), " dimensions instead.");
3072   // Dispatch to at::norm as it is implemented for Sparse and MPS backends
3073   // TODO Make the backends implement vector_norm and matrix_norm
3074   return at::norm(self, 2., dim, keepdim);
3075 }
3076 
3077 Tensor &frobenius_norm_out(const Tensor& self,
3078     IntArrayRef dim,
3079     bool keepdim,
3080     Tensor& result) {
3081   auto device = self.device();
3082   if (self.layout() == Layout::Strided && (device == kCPU || device == kCUDA || device == kMeta)) {
3083     TORCH_WARN_ONCE(
3084       "at::frobenius_norm is deprecated and it is just left for JIT compatibility. ",
3085       "It will be removed in a future PyTorch release. Please use ",
3086       "`linalg.vector_norm(A, 2., dim, keepdim)` instead"
3087     );
3088   }
3089   TORCH_CHECK(dim.size() <= 2,
3090               "Expected at most 2 dimensions, but got ", dim.size(), " dimensions instead.");
3091   return at::norm_out(result, self, 2., dim, keepdim);
3092 }
3093 
3094 ////////////////////////////////////////////////////////////////////////////////
3095 //                                Nuclear Norm                                //
3096 ////////////////////////////////////////////////////////////////////////////////
3097 
3098 Tensor nuclear_norm(const Tensor& self, bool keepdim) {
3099   return at::native::nuclear_norm(self, IntArrayRef({-2, -1}), keepdim);
3100 }
3101 
3102 Tensor &nuclear_norm_out(const Tensor& self, bool keepdim, Tensor& result) {
3103   auto device = self.device();
3104   if (self.layout() == Layout::Strided && (device == kCPU || device == kCUDA || device == kMeta)) {
3105     TORCH_WARN_ONCE(
3106       "at::nuclear_norm is deprecated and it is just left for JIT compatibility. ",
3107       "It will be removed in a future PyTorch release. Please use ",
3108       "`linalg.matrix_norm(A, 'nuc', dim, keepdim)` instead"
3109     );
3110   }
3111   return at::linalg_matrix_norm_out(result, self, "nuc", IntArrayRef({-2, -1}), keepdim);
3112 }
3113 
3114 Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
3115   auto device = self.device();
3116   if (self.layout() == Layout::Strided && (device == kCPU || device == kCUDA || device == kMeta)) {
3117     TORCH_WARN_ONCE(
3118       "at::nuclear_norm is deprecated and it is just left for JIT compatibility. ",
3119       "It will be removed in a future PyTorch release. Please use ",
3120       "`linalg.matrix_norm(A, 'nuc', dim, keepdim)` instead"
3121     );
3122   }
3123   return at::linalg_matrix_norm(self, "nuc", dim, keepdim);
3124 }
3125 
3126 Tensor& nuclear_norm_out(const Tensor& self, IntArrayRef dim, bool keepdim, Tensor& result) {
3127   auto device = self.device();
3128   if (self.layout() == Layout::Strided && (device == kCPU || device == kCUDA || device == kMeta)) {
3129     TORCH_WARN_ONCE(
3130       "at::nuclear_norm is deprecated and it is just left for JIT compatibility. ",
3131       "It will be removed in a future PyTorch release. Please use ",
3132       "`linalg.matrix_norm(A, 'nuc', dim, keepdim)` instead"
3133     );
3134   }
3135   return at::linalg_matrix_norm_out(result, self, "nuc", dim, keepdim);
3136 }
3137 
3138 ////////////////////////////////////////////////////////////////////////////////
3139 //                              linalg.cond                                   //
3140 ////////////////////////////////////////////////////////////////////////////////
3141 
3142 
3143 // This function helps to dispatch norm computations depending on 'ord' of variant type
3144 static Tensor _linalg_cond_helper(const Tensor& self, std::variant<Scalar, c10::string_view> ord_variant) {
3145   Tensor inverse, info;
3146   std::tie(inverse, info) = at::linalg_inv_ex(self);
3147   info.unsqueeze_(-1).unsqueeze_(-1);
3148   inverse.masked_fill_(info > 0, INFINITY);
3149 
3150   return std::visit([&](auto&& ord) {
3151     Tensor norm_self = at::linalg_matrix_norm(self, ord);
3152     Tensor norm_inverse = at::linalg_matrix_norm(inverse, ord);
3153     Tensor result = norm_self * norm_inverse;
3154     // fix multiplication of zero and infinity for NumPy compatibility
3155     result.nan_to_num_(INFINITY, INFINITY, -INFINITY);
3156     return result;
3157   }, ord_variant);
3158 }
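// For the ord values handled here, this computes the condition number
// kappa_ord(A) = ||A||_ord * ||A^-1||_ord; non-invertible inputs are flagged by
// linalg_inv_ex via `info` > 0 and come out as +inf, matching NumPy, which is what
// the masked_fill_ / nan_to_num_ calls above arrange.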
3159 
3160 // Return zero for each matrix in the batch
3161 static Tensor _linalg_cond_empty_matrix(const Tensor& self, c10::ScalarType dtype) {
3162   auto result_shape = IntArrayRef(self.sizes().cbegin(), self.sizes().cend()-2);
3163   TensorOptions options = self.options().dtype(toRealValueType(self.scalar_type()));
3164   return at::zeros(result_shape, options);
3165 }
3166 
3167 static void _linalg_cond_check_ord(std::variant<Scalar, c10::string_view> ord_variant) {
3168   if (ord_variant.index() == 0) {
3169     Scalar* ord = std::get_if<Scalar>(&ord_variant);
3170     double abs_ord = std::abs(ord->toDouble());
3171     TORCH_CHECK(abs_ord == 2.0 || abs_ord == 1.0 || abs_ord == INFINITY,
3172       "linalg.cond got an invalid norm type: ", ord->toDouble());
3173   } else if (ord_variant.index() == 1) {
3174     c10::string_view* ord = std::get_if<c10::string_view>(&ord_variant);
3175     TORCH_CHECK(*ord == "fro" || *ord == "nuc",
3176       "linalg.cond got an invalid norm type: ", *ord);
3177   } else {
3178     TORCH_CHECK(false,
3179       "linalg.cond: something went wrong while checking the norm type");
3180   }
3181 }
3182 
3183 // Numerical or None norms
3184 Tensor linalg_cond(const Tensor& self, const std::optional<Scalar>& opt_ord) {
3185   TORCH_CHECK(self.dim() >= 2, "linalg.cond: The input tensor must have at least 2 dimensions.");
3186 
3187   // The default case is using 2-norm
3188   Scalar ord = opt_ord.has_value() ? opt_ord.value() : 2;
3189 
3190   std::variant<Scalar, c10::string_view> ord_variant = ord;
3191   _linalg_cond_check_ord(ord_variant);
3192 
3193   // NumPy doesn't define the condition number for 0x0 matrices, we return 0.0 for such input
3194   if (self.sym_numel() == 0) {
3195     auto real_dtype = toRealValueType(typeMetaToScalarType(self.dtype()));
3196     return _linalg_cond_empty_matrix(self, real_dtype);
3197   }
3198 
3199   // If ord == None or ord == ±2
3200   if (std::abs(ord.toDouble()) == 2.0) {
3201     auto singular_values = at::linalg_svdvals(self);
3202     // singular values are sorted in descending order
3203     auto s_max = at::narrow(singular_values, /*dim=*/-1, /*start=*/0, /*length=*/1);
3204     auto s_min = at::narrow(singular_values, /*dim=*/-1, /*start=*/-1, /*length=*/1);
3205     Tensor result;
3206     if (ord.toDouble() == -2.0) {
3207       result = s_min / s_max;
3208     } else {
3209       result = s_max / s_min;
3210     }
3211     // squeeze the result for NumPy compatibility
3212     return result.squeeze(-1);
3213   }
3214 
3215   // ord == ±1 ord == ±inf
3216   if (ord.isFloatingPoint()) { // ord == ±1
3217     squareCheckInputs(self, ("linalg.cond(ord=" + std::to_string(ord.to<double>()) + ")").c_str());
3218   } else { // ord == ±inf
3219     squareCheckInputs(self, ("linalg.cond(ord=" + std::to_string(ord.to<int64_t>()) + ")").c_str());
3220   }
3221   return _linalg_cond_helper(self, std::move(ord_variant));
3222 }
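// For ord = +/-2 the ratio of extreme singular values is used directly, without
// inverting the input. A small sketch, for illustration:
//
//   auto A = at::diag_embed(at::tensor({10.0, 0.1}));
//   at::linalg_cond(A);              // s_max / s_min = 10 / 0.1 = 100
//   at::linalg_cond(A, -2);          // s_min / s_max = 0.01
//   at::linalg_cond(at::eye(3), 1);  // kappa_1 of the identity = 1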
3223 
3224 Tensor& linalg_cond_out(const Tensor& self, const std::optional<Scalar>& opt_ord, Tensor& result) {
3225   checkSameDevice("linalg.cond", result, self);
3226   ScalarType real_dtype = toRealValueType(self.scalar_type());
3227   checkLinalgCompatibleDtype("linalg.cond", result.scalar_type(), real_dtype);
3228 
3229   Tensor result_tmp = at::linalg_cond(self, opt_ord);
3230   at::native::resize_output(result, result_tmp.sizes());
3231   result.copy_(result_tmp);
3232   return result;
3233 }
3234 
3235 // Frobenius or nuclear norms
3236 Tensor linalg_cond(const Tensor& self, c10::string_view ord) {
3237   squareCheckInputs(self, ("linalg.cond(ord=" + std::string(ord) + ")").c_str());
3238   std::variant<Scalar, c10::string_view> ord_variant = ord;
3239   _linalg_cond_check_ord(ord_variant);
3240 
3241   // NumPy doesn't define the condition number for 0x0 matrices, we return 0.0 for such input
3242   if (self.numel() == 0) {
3243     return _linalg_cond_empty_matrix(self, self.scalar_type());
3244   }
3245 
3246   if (ord == "nuc") {
3247     // calling matrix_norm with "nuc" on inputs with infinities raises an error
3248     // therefore we use the mathematical definition of nuclear norm directly
3249     // instead of going through the matrix_norm
3250     auto singular_values = at::linalg_svdvals(self);
3251     return singular_values.sum(-1) * (singular_values.reciprocal().sum(-1));
3252   }
3253 
3254   return _linalg_cond_helper(self, std::move(ord_variant));
3255 }
3256 
3257 // TODO: implement _out variant avoiding copy and using already allocated storage directly
3258 Tensor& linalg_cond_out(const Tensor& self, c10::string_view ord, Tensor& result) {
3259   checkSameDevice("linalg.cond", result, self);
3260   ScalarType real_dtype = toRealValueType(self.scalar_type());
3261   checkLinalgCompatibleDtype("linalg.cond", result.scalar_type(), real_dtype);
3262 
3263   Tensor result_tmp = at::linalg_cond(self, ord);
3264   at::native::resize_output(result, result_tmp.sizes());
3265   result.copy_(result_tmp);
3266   return result;
3267 }
3268 
3269 Tensor linalg_tensorinv(const Tensor& self, int64_t ind) {
3270   /*
3271   The idea is to reduce the problem to 2D square matrix inversion.
3272   Step 1. Calculate the shape of the result and the shape of the intermediate 2D matrix.
3273   Step 2. Reshape `self` to 2D matrix.
3274   Step 3. Invert the 2D matrix self.to_2D()
3275           There is no quick way to find out whether the matrix is invertible,
3276           so at this stage an error from at::inverse can be thrown.
3277           Note that for CUDA this causes cross-device memory synchronization that can be slow.
3278   Step 4. reshape the result.
3279   */
3280   TORCH_CHECK(ind > 0, "Expected a strictly positive integer for 'ind', but got ", ind);
3281 
3282   // self[ind:]
3283   std::vector<c10::SymInt> shape_ind_end = self.sym_sizes().slice(ind).vec();
3284   // self[:ind]
3285   std::vector<c10::SymInt> shape_start_ind = self.sym_sizes().slice(0, ind).vec();
3286 
3287   c10::SymInt prod_ind_end = c10::multiply_integers(shape_ind_end.cbegin(), shape_ind_end.cend());
3288   c10::SymInt prod_start_ind = c10::multiply_integers(shape_start_ind.cbegin(), shape_start_ind.cend());
3289 
3290   // Check whether the self tensor can be reshaped to the 2D square matrix
3291   TORCH_CHECK(prod_ind_end == prod_start_ind,
3292     "Expected self to satisfy the requirement prod(self.shape[ind:]) == prod(self.shape[:ind]), but got ",
3293     prod_ind_end, " != ", prod_start_ind);
3294 
3295   // Concatenate shape_ind_end and shape_start_ind to form the shape of the result
3296   // self[ind:] + self[:ind]
3297   shape_ind_end.insert(shape_ind_end.cend(), shape_start_ind.cbegin(), shape_start_ind.cend());
3298 
3299   // If the reshaped self is not invertible catch this error
3300   auto [result, info] = at::linalg_inv_ex(self.reshape_symint({prod_ind_end, prod_ind_end}), /*check_errors=*/false);
3301   at::_linalg_check_errors(info, "inv", /*is_matrix*/true);
3302 
3303   return result.reshape_symint(shape_ind_end);
3304 }
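// Worked shape example: for self.shape == (4, 6, 8, 3) and ind == 2,
//   shape_start_ind = (4, 6) -> prod = 24
//   shape_ind_end   = (8, 3) -> prod = 24
// so self is reshaped to a (24, 24) matrix, inverted, and the result is reshaped to
// shape_ind_end + shape_start_ind == (8, 3, 4, 6).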
3305 
3306 // TODO: implement _out variant avoiding copy and using already allocated storage directly
3307 Tensor& linalg_tensorinv_out(const Tensor& self, int64_t ind, Tensor& result) {
3308   checkSameDevice("tensorinv", result, self);
3309   checkLinalgCompatibleDtype("tensorinv", result, self);
3310 
3311   Tensor result_tmp = at::linalg_tensorinv(self, ind);
3312   at::native::resize_output(result, result_tmp.sizes());
3313   result.copy_(result_tmp);
3314   return result;
3315 }
3316 
3317 Tensor linalg_tensorsolve(const Tensor& self, const Tensor& other, OptionalIntArrayRef dims) {
3318   /*
3319   The idea is to reduce the problem to 2D matrix solve.
3320   Step 1. (optional) `self` is permuted with `dims` such that dimensions from `dims` are moved to the right.
3321   For example, if we have 4D input with the shape (1, 2, 3, 4) and dims=(0, 2),
3322   then the result of permutation would have the shape (2, 4, 1, 3).
3323   Step 2. reshape `self` to 2D matrix.
3324   Step 3. solve the matrix equation self.to_2D() @ result = other.to_1D()
3325   Step 4. reshape the result.
3326   */
3327   int64_t ndim = self.dim();
3328   Tensor self_ = self;
3329 
3330   // move dimensions of `self_` from `dims` to the end
3331   if (dims.has_value()) {
3332     DimVector dest_axes(dims.value().size());
3333     std::iota(dest_axes.begin(), dest_axes.end(), ndim - dest_axes.size());
3334     self_ = at::movedim(self_, dims.value(), dest_axes);
3335   }
3336 
3337   // result_shape is self_.sizes[-(ndim - other.dim):], i.e. self_.sizes[other.dim:]
3338   std::vector<c10::SymInt> result_shape = self_.sym_sizes().slice(other.dim(), ndim - other.dim()).vec();
3339 
3340   c10::SymInt result_product = c10::multiply_integers(result_shape.begin(), result_shape.end());
3341   c10::SymInt other_product = c10::multiply_integers(other.sym_sizes().begin(), other.sym_sizes().end());
3342 
3343   // Check whether the self tensor can be reshaped to the 2D square matrix
3344   TORCH_CHECK(result_product == other_product,
3345     "Expected self to satisfy the requirement prod(self.shape[other.ndim:]) == prod(self.shape[:other.ndim]), but got ",
3346     result_product, " != ", other_product);
3347 
3348   self_ = self_.reshape_symint({result_product, result_product});
3349 
3350   // `other` is flattened here because `self_` has been reshaped to a 2D matrix and at::linalg_solve needs a matching right-hand side
3351   Tensor result = at::linalg_solve(self_, other.flatten());
3352   return result.reshape_symint(result_shape);
3353 }
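// Worked shape example: for self.shape == (2, 3, 6), other.shape == (2, 3) and
// dims == nullopt,
//   result_shape  = self.shape[other.ndim:] = (6,) -> result_product = 6
//   other_product = 2 * 3 = 6
// so the problem is solved as a (6, 6) system against other.flatten() (length 6)
// and the solution is reshaped back to (6,).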
3354 
3355 Tensor& linalg_tensorsolve_out(const Tensor& self, const Tensor& other, OptionalIntArrayRef dims, Tensor& result) {
3356   checkSameDevice("tensorsolve", result, self);
3357   checkLinalgCompatibleDtype("tensorsolve", result, self);
3358 
3359   Tensor result_tmp = at::linalg_tensorsolve(self, other, dims);
3360   at::native::resize_output(result, result_tmp.sizes());
3361   result.copy_(result_tmp);
3362   return result;
3363 }
3364 
3365 namespace {
3366 struct KronImpl final {
3367   public:
3368     explicit KronImpl(const Tensor& self, const Tensor& other) {
3369       maxdim = std::max(self.dim(), other.dim());
3370       int64_t pad_self = maxdim - self.dim();
3371       int64_t pad_other = maxdim - other.dim();
3372       a_reshape = c10::SmallVector<int64_t, 10>(2 * maxdim);
3373       b_reshape = c10::SmallVector<int64_t, 10>(2 * maxdim);
3374       result_reshape = c10::SmallVector<int64_t, 10>(maxdim);
3375       for (const auto i : c10::irange(maxdim)) {
3376         a_reshape[2 * i] = (i >= pad_self ? self.sizes()[i - pad_self] : 1);
3377         a_reshape[2 * i + 1] = 1;
3378         b_reshape[2 * i] = 1;
3379         b_reshape[2 * i + 1] = (i >= pad_other ? other.sizes()[i - pad_other] : 1);
3380         result_reshape[i] = a_reshape[2 * i] * b_reshape[2 * i + 1];
3381       }
3382       self_view = at::_unsafe_view(self, a_reshape);
3383       other_view = at::_unsafe_view(other, b_reshape);
3384     }
3385 
3386     Tensor& kron_out(Tensor& result) const {
3387       TORCH_INTERNAL_ASSERT(result.defined(), "Cannot call kron_out with an undefined result tensor as the out argument. Please allocate a Tensor before calling kron_out with it.");
3388 
3389       c10::SmallVector<int64_t, 10> mul_shape(2 * maxdim);
3390       for (const auto i : c10::irange(maxdim)) {
3391         mul_shape[2 * i] = a_reshape[2 * i];
3392         mul_shape[2 * i + 1] = b_reshape[2 * i + 1];
3393       }
3394       at::native::resize_output(result, result_reshape);
3395       auto result_mul = at::_unsafe_view(result, mul_shape);
3396       at::mul_out(result_mul, self_view, other_view);
3397 
3398       return result;
3399     }
3400 
3401     Tensor kron() const {
3402       return at::_unsafe_view(at::mul(self_view, other_view), result_reshape);
3403     }
3404   private:
3405     int64_t maxdim;
3406     Tensor self_view;
3407     Tensor other_view;
3408     c10::SmallVector<int64_t, 10> result_reshape;
3409     c10::SmallVector<int64_t, 10> a_reshape;
3410     c10::SmallVector<int64_t, 10> b_reshape;
3411 };
3412 }
3413 
3414 /*
3415 Calculates the Kronecker product between two Tensors.
3416 */
3417 Tensor& kron_out(const Tensor& self, const Tensor& other, Tensor& result) {
3418   return KronImpl(self, other).kron_out(result);
3419 }
3420 
3421 Tensor kron(const Tensor& self, const Tensor& other) {
3422   return KronImpl(self, other).kron();
3423 }
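// For 2D inputs the KronImpl reshapes amount to
//   self  (m, n) -> (m, 1, n, 1)
//   other (p, q) -> (1, p, 1, q)
// whose broadcasted product has shape (m, p, n, q) and is viewed as (m*p, n*q),
// i.e. kron(A, B)[p*i + r, q*j + s] == A[i, j] * B[r, s]. A small sketch:
//
//   auto A = at::arange(1, 5, at::kDouble).reshape({2, 2});  // [[1, 2], [3, 4]]
//   at::kron(A, at::eye(2));
//   // [[1, 0, 2, 0],
//   //  [0, 1, 0, 2],
//   //  [3, 0, 4, 0],
//   //  [0, 3, 0, 4]]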
3424 
3425 // Weight Only Quantization Gemm
3426 DEFINE_DISPATCH(weight_to_int4pack_stub);
3427 DEFINE_DISPATCH(int4pack_mm_stub);
3428 DEFINE_DISPATCH(int8pack_mm_stub);
3429 
3430 Tensor _convert_weight_to_int4pack_cpu(
3431     const Tensor& in,
3432     int64_t innerKTiles) {
3433 
3434   TORCH_CHECK(in.dim() == 2,
3435       __func__, " : expect weight to be 2D tensor.");
3436   TORCH_CHECK(in.dtype() == at::kByte,
3437       __func__, " : expect weight to be kByte.");
3438   TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8,
3439       __func__, " : innerKTiles need to be 2, 4, or 8, got ", innerKTiles);
3440 
3441   auto weight = in.contiguous();
3442   auto N = weight.size(0);
3443   auto K = weight.size(1) * 2;
3444 
3445   // Create fake shapes for CPU. The meta registration in dynamo requires the
3446   // operator to have the same output shape for each device, so we create a fake
3447   // shape {N / 8, K / (16 * innerKTiles), 32, innerKTiles / 2}.
3448   constexpr int64_t kNTileSize = 8;
3449   constexpr int64_t kKTileSize = 16;
3450   auto nTiles = (N + kNTileSize - 1) / kNTileSize;
3451 
3452   TORCH_CHECK(N % 16 == 0,
3453       __func__, " : expect N to be divisible by 16");
3454   const int64_t kSuperKTileSize = kKTileSize * innerKTiles;
3455   TORCH_CHECK( K % kSuperKTileSize == 0,
3456       __func__, " : expect K to be divisible by ", kSuperKTileSize);
3457   auto kSuperTiles = (K + kSuperKTileSize - 1) / kSuperKTileSize;
3458 
3459   auto weight_packed = at::empty(
3460       {nTiles, kSuperTiles, 32, innerKTiles / 2},
3461       at::TensorOptions().dtype(at::kInt));
3462 
3463   weight_to_int4pack_stub(kCPU, weight_packed, weight, N, K);
3464   return weight_packed;
3465 }
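// Worked shape example: for a kByte weight of shape (N=32, K/2=128), i.e. K == 256,
// and innerKTiles == 2:
//   nTiles          = ceil(32 / 8)   = 4
//   kSuperKTileSize = 16 * 2         = 32
//   kSuperTiles     = ceil(256 / 32) = 8
// so the packed tensor returned above has shape {4, 8, 32, 1} and dtype kInt.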
3466 
3467 Tensor _weight_int4pack_mm_cpu(
3468     const Tensor& A,
3469     const Tensor& B,
3470     int64_t qGroupSize,
3471     const Tensor& qScaleAndZeros) {
3472 
3473   constexpr int64_t kNTileSize = 8;
3474 
3475   auto M = A.size(0);
3476   auto N = B.size(0) * kNTileSize;
3477   auto K = A.size(1);
3478 
3479   TORCH_CHECK(A.dtype() == kBFloat16 || A.dtype() == kHalf || A.dtype() == kFloat,
3480       __func__, " : expect A to be either 32-bit or 16-bit float tensor.");
3481   TORCH_CHECK(A.is_contiguous(),
3482       __func__, " : expect A to be contiguous.");
3483   TORCH_CHECK(A.dim() == 2,
3484       __func__, " : expect A to be 2D tensor.");
3485 
3486   TORCH_CHECK(B.dtype() == kInt,
3487       __func__, " : expect B to be int32 tensor.");
3488   TORCH_CHECK(B.is_contiguous(),
3489       __func__, " : expect B to be contiguous.");
3490   TORCH_CHECK(B.dim() == 4,
3491       __func__, " : expect B to be a 4D tensor.");
3492 
3493   TORCH_CHECK(qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128
3494       || qGroupSize == 256,
3495       __func__, ": expect qGroupSize to be 32, 64, 128 or 256, got ", qGroupSize);
3496 
3497   TORCH_CHECK(qScaleAndZeros.dim() == 3 && qScaleAndZeros.size(1) == N
3498       && qScaleAndZeros.size(2) == 2,
3499       __func__, ": expect qScaleAndZeros to be 3d tensor with sizes [:, ", N, ", 2]");
3500 
3501   auto C = at::empty({M, N}, A.options());
3502   int4pack_mm_stub(kCPU, C, A, B, qGroupSize, qScaleAndZeros, N, K);
3503 
3504   return C;
3505 }
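// Continuing the shape example above (packed B of shape {4, 8, 32, 1}, so
// N = 4 * 8 = 32 and K = 256): an activation A of shape (M, 256) in
// kBFloat16/kHalf/kFloat with qGroupSize = 32 would typically come with
// qScaleAndZeros of shape (K / qGroupSize, N, 2) = (8, 32, 2) (only the trailing
// (N, 2) part is checked above) and produces C of shape (M, 32).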
3506 
3507 Tensor _weight_int8pack_mm_cpu(
3508     const Tensor& A,
3509     const Tensor& B,
3510     const Tensor& scales) {
3511 
3512   auto M = A.size(0);
3513   auto N = B.size(0);
3514   auto K = A.size(1);
3515 
3516   TORCH_CHECK(A.dtype() == kBFloat16 || A.dtype() == kHalf || A.dtype() == kFloat,
3517       __func__, " : expect A to be either 32-bit or 16-bit float tensor.");
3518   TORCH_CHECK(A.is_contiguous(),
3519       __func__, " : expect A to be contiguous.");
3520   TORCH_CHECK(A.dim() == 2,
3521       __func__, " : expect A to be 2D tensor.");
3522 
3523   TORCH_CHECK(B.dtype() == kChar,
3524       __func__, " : expect B to be int8 tensor.");
3525   TORCH_CHECK(B.is_contiguous(),
3526       __func__, " : expect B to be contiguous.");
3527   TORCH_CHECK(B.size(1) == K,
3528       __func__, " : expect B.size(1) == ", K);
3529 
3530   TORCH_CHECK(scales.dim() == 1 && scales.size(0) == N,
3531       __func__, " : expect scales to be 1d tensor with size ", N);
3532 
3533   auto C = at::empty({M, N}, A.options());
3534   int8pack_mm_stub(kCPU, C, A, B, scales);
3535 
3536   return C;
3537 }
3538 
3539 Tensor& _int_mm_out_cpu(const Tensor& self, const Tensor& mat2, Tensor& result) {
3540 #ifndef STRIP_ERROR_MESSAGES
3541   static constexpr c10::string_view func_name = "int_mm_out_cpu";
3542 #endif
3543   TORCH_CHECK(self.dim() == 2, func_name, ": Expected self to be of dimension 2 but got ", self.dim());
3544   TORCH_CHECK(mat2.dim() == 2, func_name, ": Expected mat2 to be of dimension 2 but got ", mat2.dim());
3545   TORCH_CHECK(self.size(1) == mat2.size(0), func_name, ": self.size(1) needs to match mat2.size(0) but got ", self.size(1), " and ", mat2.size(0));
3546   TORCH_CHECK(self.dtype() == at::kChar, func_name, ": Expected self dtype to be of type int8 but got ", self.dtype());
3547   TORCH_CHECK(mat2.dtype() == at::kChar, func_name, ": Expected mat2 dtype to be of type int8 but got ", mat2.dtype());
3548   TORCH_CHECK(result.dtype() == at::kInt, func_name, ": Expected result dtype to be of type kInt but got ", result.dtype());
3549   TORCH_CHECK(result.size(0) == self.size(0), func_name, ": Expected result.size(0) to be ", self.size(0), " but got ", result.size(0));
3550   TORCH_CHECK(result.size(1) == mat2.size(1), func_name, ": Expected result.size(1) to be ", mat2.size(1), " but got ", result.size(1));
3551   TORCH_CHECK(result.dim() == 2, func_name, ": Expected result to be of dimension 2 but got ", result.dim());
3552   TORCH_CHECK(result.is_contiguous(), func_name, ": Expected result to be contiguous.");
3553 
3554   if (result.numel() == 0 || self.size(1) == 0) {
3555     return result.zero_();
3556   }
3557 
3558   bool dispatched = false;
3559   if (at::globalContext().userEnabledMkldnn()) {
3560     try {
3561       mkldnn_matmul_i8i8i32(self, mat2, result);
3562       dispatched = true;
3563     } catch (const std::exception& e) {
3564       TORCH_WARN(func_name, " failed, switching to BLAS gemm: ", e.what());
3565     }
3566   }
3567   if (!dispatched) {
3568     auto a = reinterpret_cast<int8_t*>(self.data_ptr());
3569     auto b = reinterpret_cast<int8_t*>(mat2.data_ptr());
3570     auto c = reinterpret_cast<int32_t*>(result.data_ptr());
3571     const int64_t m = result.size(0);
3572     const int64_t n = result.size(1);
3573     const int64_t k = self.size(1);
3574     const int64_t lda_0 = self.strides()[0];
3575     const int64_t lda_1 = self.strides()[1];
3576     const int64_t ldb_0 = mat2.strides()[0];
3577     const int64_t ldb_1 = mat2.strides()[1];
3578     const int64_t ldc = result.strides()[0];
3579     parallel_for(0, m * n, 1, [&](int64_t start, int64_t end) {
3580       for (const auto i : c10::irange(start, end)) {
3581         auto row = i / n;
3582         auto col = i % n;
3583         c[row * ldc + col] = 0;
3584         for (const auto kk : c10::irange(k)) {  // kk avoids shadowing k == self.size(1)
3585           c[row * ldc + col] = c[row * ldc + col] +
3586               static_cast<int32_t>(a[row * lda_0 + kk * lda_1]) *
3587                   static_cast<int32_t>(b[kk * ldb_0 + col * ldb_1]);
3588         }
3589       }
3590     });
3591   }
3592   return result;
3593 }
3594 
3595 Tensor _int_mm_cpu(const Tensor& self, const Tensor& mat2) {
3596   Tensor result = at::empty({self.size(0), mat2.size(1)}, self.options().dtype(at::kInt));
3597   return _int_mm_out_cpu(self, mat2, result);
3598 }
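// A minimal usage sketch, for illustration:
//
//   auto a = at::randint(-128, 127, {2, 3}, at::kChar);
//   auto b = at::randint(-128, 127, {3, 4}, at::kChar);
//   auto c = at::_int_mm(a, b);  // kInt result of shape (2, 4)
//
// When MKLDNN is enabled (at::globalContext().userEnabledMkldnn()) the product goes
// through mkldnn_matmul_i8i8i32; otherwise the strided reference loop in
// _int_mm_out_cpu above is used.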
3599 
3600 } // namespace native
3601 } // namespace at
3602