1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/Context.h>
3 #include <ATen/Dispatch.h>
4 #include <ATen/ExpandUtils.h>
5 #include <ATen/NamedTensorUtils.h>
6 #include <ATen/OpMathType.h>
7 #include <ATen/Parallel.h>
8 #include <ATen/TensorIndexing.h>
9 #include <ATen/TensorIterator.h>
10 #include <ATen/TensorOperators.h>
11 #include <ATen/TensorSubclassLikeUtils.h>
12 #include <ATen/TensorUtils.h>
13 #include <ATen/core/Tensor.h>
14 #include <ATen/native/CPUBlas.h>
15 #include <ATen/native/cpu/int_mm_kernel.h>
16 #include <ATen/native/LinearAlgebra.h>
17 #include <ATen/native/LinearAlgebraUtils.h>
18 #include <ATen/native/ReduceOps.h>
19 #include <ATen/native/ReduceOpsUtils.h>
20 #include <ATen/native/Resize.h>
21 #include <ATen/native/mkldnn/Matmul.h>
22 #include <ATen/native/mkldnn/Utils.h>
23 #include <c10/core/GradMode.h>
24 #include <c10/util/accumulate.h>
25 #include <c10/util/irange.h>
26 #include <variant>
27
28 #ifndef AT_PER_OPERATOR_HEADERS
29 #include <ATen/Functions.h>
30 #include <ATen/NativeFunctions.h>
31 #else
32 #include <ATen/ops/_addmm_activation_native.h>
33 #include <ATen/ops/_compute_linear_combination_native.h>
34 #include <ATen/ops/_convert_weight_to_int4pack_native.h>
35 #include <ATen/ops/_int_mm_native.h>
36 #include <ATen/ops/_linalg_check_errors.h>
37 #include <ATen/ops/_linalg_det.h>
38 #include <ATen/ops/_linalg_det_native.h>
39 #include <ATen/ops/_linalg_slogdet.h>
40 #include <ATen/ops/_linalg_slogdet_native.h>
41 #include <ATen/ops/_unsafe_view.h>
42 #include <ATen/ops/_weight_int4pack_mm_native.h>
43 #include <ATen/ops/_weight_int8pack_mm_native.h>
44 #include <ATen/ops/abs.h>
45 #include <ATen/ops/addbmm_native.h>
46 #include <ATen/ops/addmm_native.h>
47 #include <ATen/ops/addr.h>
48 #include <ATen/ops/addr_native.h>
49 #include <ATen/ops/arange.h>
50 #include <ATen/ops/argsort.h>
51 #include <ATen/ops/baddbmm_native.h>
52 #include <ATen/ops/bmm.h>
53 #include <ATen/ops/bmm_native.h>
54 #include <ATen/ops/cat.h>
55 #include <ATen/ops/ceil.h>
56 #include <ATen/ops/chain_matmul_native.h>
57 #include <ATen/ops/cumsum.h>
58 #include <ATen/ops/det_native.h>
59 #include <ATen/ops/diag_embed.h>
60 #include <ATen/ops/diff.h>
61 #include <ATen/ops/dot.h>
62 #include <ATen/ops/dot_native.h>
63 #include <ATen/ops/empty.h>
64 #include <ATen/ops/empty_like.h>
65 #include <ATen/ops/eye.h>
66 #include <ATen/ops/floor.h>
67 #include <ATen/ops/frobenius_norm_native.h>
68 #include <ATen/ops/from_blob.h>
69 #include <ATen/ops/full.h>
70 #include <ATen/ops/full_like.h>
71 #include <ATen/ops/gelu.h>
72 #include <ATen/ops/ger_native.h>
73 #include <ATen/ops/index_select.h>
74 #include <ATen/ops/inner_native.h>
75 #include <ATen/ops/is_complex_native.h>
76 #include <ATen/ops/is_floating_point_native.h>
77 #include <ATen/ops/kron_native.h>
78 #include <ATen/ops/linalg_cond.h>
79 #include <ATen/ops/linalg_cond_native.h>
80 #include <ATen/ops/linalg_det.h>
81 #include <ATen/ops/linalg_det_native.h>
82 #include <ATen/ops/linalg_diagonal_native.h>
83 #include <ATen/ops/linalg_eigh.h>
84 #include <ATen/ops/linalg_eigvalsh.h>
85 #include <ATen/ops/linalg_inv.h>
86 #include <ATen/ops/linalg_inv_ex.h>
87 #include <ATen/ops/linalg_lu_factor_ex.h>
88 #include <ATen/ops/linalg_matmul_native.h>
89 #include <ATen/ops/linalg_matrix_exp.h>
90 #include <ATen/ops/linalg_matrix_exp_native.h>
91 #include <ATen/ops/linalg_matrix_norm.h>
92 #include <ATen/ops/linalg_matrix_norm_native.h>
93 #include <ATen/ops/linalg_matrix_power_native.h>
94 #include <ATen/ops/linalg_matrix_rank.h>
95 #include <ATen/ops/linalg_matrix_rank_native.h>
96 #include <ATen/ops/linalg_multi_dot_native.h>
97 #include <ATen/ops/linalg_norm.h>
98 #include <ATen/ops/linalg_norm_native.h>
99 #include <ATen/ops/linalg_pinv.h>
100 #include <ATen/ops/linalg_pinv_native.h>
101 #include <ATen/ops/linalg_slogdet.h>
102 #include <ATen/ops/linalg_slogdet_native.h>
103 #include <ATen/ops/linalg_solve.h>
104 #include <ATen/ops/linalg_svdvals.h>
105 #include <ATen/ops/linalg_tensorinv.h>
106 #include <ATen/ops/linalg_tensorinv_native.h>
107 #include <ATen/ops/linalg_tensorsolve.h>
108 #include <ATen/ops/linalg_tensorsolve_native.h>
109 #include <ATen/ops/linalg_vector_norm.h>
110 #include <ATen/ops/linalg_vector_norm_native.h>
111 #include <ATen/ops/log2.h>
112 #include <ATen/ops/logdet_native.h>
113 #include <ATen/ops/matmul.h>
114 #include <ATen/ops/matmul_native.h>
115 #include <ATen/ops/matrix_exp_backward_native.h>
116 #include <ATen/ops/matrix_exp_native.h>
117 #include <ATen/ops/matrix_power_native.h>
118 #include <ATen/ops/max.h>
119 #include <ATen/ops/mm.h>
120 #include <ATen/ops/mm_native.h>
121 #include <ATen/ops/movedim.h>
122 #include <ATen/ops/mul.h>
123 #include <ATen/ops/mv.h>
124 #include <ATen/ops/narrow.h>
125 #include <ATen/ops/ne.h>
126 #include <ATen/ops/norm.h>
127 #include <ATen/ops/nuclear_norm_native.h>
128 #include <ATen/ops/ones.h>
129 #include <ATen/ops/outer.h>
130 #include <ATen/ops/outer_native.h>
131 #include <ATen/ops/pinverse_native.h>
132 #include <ATen/ops/pow.h>
133 #include <ATen/ops/prod.h>
134 #include <ATen/ops/real.h>
135 #include <ATen/ops/relu.h>
136 #include <ATen/ops/slogdet_native.h>
137 #include <ATen/ops/sort.h>
138 #include <ATen/ops/sqrt.h>
139 #include <ATen/ops/sum.h>
140 #include <ATen/ops/tensordot.h>
141 #include <ATen/ops/unique_consecutive.h>
142 #include <ATen/ops/vdot_native.h>
143 #include <ATen/ops/where.h>
144 #include <ATen/ops/zeros.h>
145 #include <ATen/ops/zeros_like.h>
146 #endif
147
148 #include <limits>
149 #include <numeric>
150 #include <string>
151 #include <tuple>
152 #include <utility>
153 #if !defined(__s390x__) && !defined(__powerpc__)
154 #include <cpuinfo.h>
155 #endif
156
157 namespace at {
158
159 namespace detail {
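// Rough illustration of the checks below (example dtypes are hypothetical): a
// kFloat input with dtype=kDouble passes, since promoteTypes(kFloat, kDouble)
// == kDouble, while a kDouble input with dtype=kFloat fails the last check
// because the conversion would be narrowing; mixing a real input with a
// complex dtype (or vice versa) fails the second check.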
160 static void check_linalg_norm_dtype(std::optional<ScalarType> opt_dtype, ScalarType self_dtype, const char* const name) {
161 if (opt_dtype.has_value()) {
162 auto dtype = opt_dtype.value();
163 TORCH_CHECK(isFloatingType(dtype) || isComplexType(dtype), name, ": dtype should"
164 " be floating point or complex, but got ", dtype);
165 TORCH_CHECK(isComplexType(self_dtype) == isComplexType(dtype),
166 name, ": dtype should be ", isComplexType(self_dtype) ? "complex" : "real",
167 " for ", isComplexType(self_dtype) ? "complex" : "real", " inputs, but got ", dtype);
168 TORCH_CHECK(promoteTypes(self_dtype, dtype) == dtype,
169 name, ": the dtype of the input ", "(", self_dtype, ") should be convertible ",
170 "without narrowing to the specified dtype (", dtype, ")");
171 }
172 }
173 }
174
175 namespace meta {
176
177 #define ADDMM_META() \
178 TORCH_CHECK(self.scalar_type() == mat2.scalar_type(), "self and mat2 must have the same dtype, but got ", self.scalar_type(), " and ", mat2.scalar_type()); \
179 TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "mat1 and mat2 must have the same dtype, but got ", mat1.scalar_type(), " and ", mat2.scalar_type()); \
180 TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor"); \
181 TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor"); \
182 TORCH_CHECK( \
183 mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", \
184 mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); \
185 \
186 auto names = at::namedinference::propagate_names_for_addmm(mat1, mat2, self); \
187 set_output_raw_strided(0, {mat1.sizes()[0], mat2.sizes()[1]}, {}, mat1.options(), names);
188
189 TORCH_META_FUNC(addmm)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
190 ADDMM_META();
191 }
192
193 TORCH_META_FUNC(_addmm_activation)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, bool use_gelu) {
194 ADDMM_META();
195 }
196
197 TORCH_META_FUNC(mm)(const Tensor & self, const Tensor & mat2) {
198 TORCH_CHECK(self.dim() == 2, "self must be a matrix");
199 TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
200 TORCH_CHECK(
201 self.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
202 self.sizes()[0], "x", self.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
203
204 auto names = at::namedinference::compute_matmul_outnames(self, mat2);
205 set_output_raw_strided(0, {self.sizes()[0], mat2.sizes()[1]}, {}, self.options(), names);
206 }
207
208 TORCH_META_FUNC(linalg_vector_norm)(const Tensor& self, const Scalar& scalar_ord, OptionalIntArrayRef opt_dim, bool keepdim, std::optional<ScalarType> opt_dtype) {
209 at::native::checkFloatingOrComplex(self, "linalg.vector_norm");
210
211 auto dim = opt_dim.value_or(IntArrayRef{});
212 // Casting a large integer to a double will just introduce an error for
213 // values larger than 2^53 (same for negative numbers), so that's fine.
214 auto ord = scalar_ord.toDouble();
215
216 // For more context, see issue 52783
217 // If the tensor is empty and norm < 0 || norm == infty
218 // - We cannot reduce the whole tensor
219 // - We cannot reduce over an empty dimension
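// Rough example: for an empty tensor, ord=2 reduces a sum of squares whose
// identity is 0, so it is well defined, but a negative ord or ord=inf reduces
// with a min/max-style operation that has no identity over zero elements,
// hence the checks below.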
220 if (self.numel() == 0 && (ord < 0. || ord == INFINITY)) {
221 // dim=None or dim=() reduces the whole tensor
222 TORCH_CHECK(opt_dim.has_value() && !opt_dim->empty(),
223 "linalg.vector_norm cannot compute the ", scalar_ord, " norm on an empty ",
224 "tensor because the operation does not have an identity");
225 for (auto dim_num : dim) {
226 TORCH_CHECK(self.size(dim_num) != 0,
227 "linalg.vector_norm cannot compute the ", scalar_ord, " norm on the dimension ", dim_num ,
228 "because this dimension is empty and the operation does not have an identity");
229 }
230 }
231
232 at::detail::check_linalg_norm_dtype(opt_dtype, self.scalar_type(), "linalg.vector_norm");
233
234 auto mask = at::native::make_dim_mask(dim, self.dim());
235 auto shape = at::native::shape_from_dim_mask(self, std::move(mask), keepdim);
236 auto options = self.options()
237 .dtype(toRealValueType(opt_dtype.value_or(self.scalar_type())));
238
239 set_output_raw_strided(0, shape, {}, options);
240 }
241
242 TORCH_META_FUNC(_linalg_det)(const Tensor& A) {
243 at::native::squareCheckInputs(A, "linalg.det");
244 at::native::checkFloatingOrComplex(A, "linalg.det");
245
246 auto shape = A.sizes();
247 auto ndim = shape.size();
248
249 // det
250 set_output_contiguous(0, shape.slice(0, ndim - 2), A.options());
251
252 // LU
253 auto LU_strides = at::native::batched_matrix_contiguous_strides(shape, /*f-contig=*/true);
254 set_output_strided(1, shape, LU_strides, A.options());
255
256 // pivots
257 set_output_contiguous(2, shape.slice(0, ndim - 1), A.options().dtype(kInt));
258 }
259
260 TORCH_META_FUNC(_linalg_slogdet)(const Tensor& A) {
261 at::native::squareCheckInputs(A, "linalg.slogdet");
262 at::native::checkFloatingOrComplex(A, "linalg.slogdet", /*low_precision*/false);
263
264 auto shape = A.sizes();
265 auto ndim = shape.size();
266
267 auto shape_outputs = shape.slice(0, ndim - 2);
268
269 // sign
270 set_output_contiguous(0, shape_outputs, A.options());
271
272 // logabsdet
273 set_output_contiguous(1, shape_outputs, A.options().dtype(toRealValueType(A.scalar_type())));
274
275 // LU
276 auto LU_strides = at::native::batched_matrix_contiguous_strides(shape, /*f-contig=*/true);
277 set_output_strided(2, shape, LU_strides, A.options());
278
279 // pivots
280 set_output_contiguous(3, shape.slice(0, ndim - 1), A.options().dtype(kInt));
281 }
282
283 template <typename Meta>
284 void common_checks_baddbmm_bmm(Meta& meta, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, bool is_bmm, const std::optional<Tensor>& self_baddbmm = std::nullopt) {
285 TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor");
286 TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor");
287
288 const auto batch1_sizes = batch1.sizes();
289 const auto batch2_sizes = batch2.sizes();
290
291 int64_t bs = batch1_sizes[0];
292 int64_t contraction_size = batch1_sizes[2];
293 int64_t res_rows = batch1_sizes[1];
294 int64_t res_cols = batch2_sizes[2];
295 std::vector<int64_t> output_size {bs, res_rows, res_cols};
296
297 TORCH_CHECK(batch2_sizes[0] == bs && batch2_sizes[1] == contraction_size,
298 "Expected size for first two dimensions of batch2 tensor to be: [",
299 bs, ", ", contraction_size, "] but got: [", batch2_sizes[0], ", ", batch2_sizes[1], "].");
300
301 auto& result = meta.maybe_get_output(0);
302 // 'set_output' does not resize for in-place calls
303 meta.set_output_raw_strided(0, output_size, {}, batch2.options());
304 const auto result_sizes = result.sizes();
305 // Error is raised if called from in-place overload with incorrect shape
306 TORCH_CHECK(result_sizes == output_size,
307 "Expected an output tensor with shape [", output_size, "] but got shape ", result_sizes);
308
309 std::vector<Dimname> outnames = {};
310 if (!is_bmm) {
311 if (self_baddbmm.has_value()) {
312 const auto& self = self_baddbmm.value();
313 if (beta.toComplexDouble() != 0.0) result.copy_(self);
314 TORCH_CHECK(self.dim() == 3, "self must be a 3D tensor");
315 const auto self_sizes = self.sizes();
316 TORCH_CHECK(self_sizes == output_size,
317 "Expected an input tensor shape with shape ", output_size, " but got shape: ", self_sizes);
318 outnames = namedinference::compute_baddbmm_outnames(result, batch1, batch2, self);
319 }
320 } else {
321 outnames = namedinference::compute_bmm_outnames(result, batch1, batch2);
322 }
323
324 namedinference::propagate_names_if_nonempty(
325 result,
326 outnames
327 );
328 }
329
330 TORCH_META_FUNC(bmm)(const Tensor& self, const Tensor& mat2) {
331 common_checks_baddbmm_bmm(*this, self, mat2, Scalar(0.0), Scalar(1.0), true);
332 }
333
334 TORCH_META_FUNC(baddbmm)(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
335 auto self_ = expand_size(self, {batch1.size(0), batch1.size(1), batch2.size(2)}, "baddbmm");
336 TORCH_CHECK(self.dtype() == batch1.dtype(), "Input dtypes must be the same, got: input ", self.dtype(), ", batch1: ", batch1.dtype(), ", batch2: ", batch2.dtype());
337 common_checks_baddbmm_bmm(*this, batch1, batch2, beta, alpha, false, *self_);
338 }
339
340 } // namespace meta
341 namespace native {
342
343 DEFINE_DISPATCH(addr_stub);
344
345
346 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg.det ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
347
348 // As P is a permutation matrix
349 // det(P) = 1 if it's an even permutation and det(P) = -1 if it's an odd permutation
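// Worked example (hypothetical pivots): LAPACK-style pivots {1, 3, 3} differ
// from arange(1, 4) = {1, 2, 3} in exactly one position, i.e. one row swap,
// so the chain below yields (1 % 2) * (-2) + 1 = -1 = det(P).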
350 static Tensor lu_det_P(const Tensor& pivots) {
351 return (at::arange(1, pivots.size(-1) + 1, pivots.options()) != pivots)
352 .sum(-1, /*keepdim=*/false, /*dtype=*/at::kLong)
353 .fmod_(2)
354 // take 0 to 1 and 1 to -1
355 .mul_(-2)
356 .add_(1);
357 }
358
359 // Auxiliary function that returns the LU decomposition to use it in the backward
360 TORCH_IMPL_FUNC(_linalg_det_out)(const Tensor& A, const Tensor& result, const Tensor& LU, const Tensor& pivots) {
361 // info is an aux tensor
362 auto info = at::empty({0}, A.options().dtype(kInt));
363 // Optimisation: lu_factor_ex requires the input to be F-contig, otherwise it copies
364 // Use the transpose if A is contiguous, since det(A^T) = det(A)
365 // We limit this to real matrices, but it could also be implemented for complex matrices
366 at::linalg_lu_factor_ex_out(const_cast<Tensor&>(LU), const_cast<Tensor&>(pivots), const_cast<Tensor&>(info), A.is_contiguous() && !A.is_complex() ? A.mH() : A);
367
368 // det = det_P * prod(diag(LU))
369 at::mul_out(const_cast<Tensor&>(result), lu_det_P(pivots), at::prod(LU.diagonal(0, -2, -1), /*dim=*/-1));
370 }
371
372 Tensor linalg_det(const Tensor& A) {
373 return std::get<0>(at::_linalg_det(A));
374 }
375
376 Tensor& linalg_det_out(const Tensor& A, Tensor& result) {
377 auto LU = at::empty({0}, A.options());
378 auto pivots = at::empty({0}, A.options().dtype(kInt));
379 at::_linalg_det_out(result, LU, pivots, A);
380 return result;
381 }
382
383 // torch.det, alias for torch.linalg.det
384 Tensor det(const Tensor& self) {
385 return at::linalg_det(self);
386 }
387
388 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg.slogdet ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
389
390 // Auxiliary function that returns the LU decomposition to use it in the backward
391 TORCH_IMPL_FUNC(_linalg_slogdet_out)(const Tensor& A, const Tensor& sign, const Tensor& logabsdet, const Tensor& LU, const Tensor& pivots) {
392 // info is an aux tensor
393 auto info = at::empty({0}, A.options().dtype(kInt));
394 // Optimisation: lu_factor_ex requires the input to be F-contig, otherwise it copies
395 // Use the transpose if A is contiguous, since det(A^T) = det(A)
396 // We limit this to real matrices, but it could also be implemented for complex matrices
397 at::linalg_lu_factor_ex_out(const_cast<Tensor&>(LU), const_cast<Tensor&>(pivots), const_cast<Tensor&>(info), A.is_contiguous() && !A.is_complex() ? A.mH() : A);
398
399 auto diag_U = LU.diagonal(0, -2, -1);
400 // sign
401 at::mul_out(const_cast<Tensor&>(sign), diag_U.sgn().prod(-1), lu_det_P(pivots));
402
403 // logabsdet
404 at::sum_out(const_cast<Tensor&>(logabsdet), diag_U.abs().log_(), -1);
405 }
406
407 std::tuple<Tensor, Tensor> linalg_slogdet(const Tensor& A) {
408 auto out = at::_linalg_slogdet(A);
409 return std::make_tuple(std::move(std::get<0>(out)), std::move(std::get<1>(out)));
410 }
411
412 std::tuple<Tensor&, Tensor&> linalg_slogdet_out(const Tensor& A, Tensor& sign, Tensor& logabsdet) {
413 auto LU = at::empty({0}, A.options());
414 auto pivots = at::empty({0}, A.options().dtype(kInt));
415 at::_linalg_slogdet_out(sign, logabsdet, LU, pivots, A);
416 return std::tie(sign, logabsdet);
417 }
418
419 // Alias
420 std::tuple<Tensor, Tensor> slogdet(const Tensor& A) {
421 return at::linalg_slogdet(A);
422 }
423
424 std::tuple<Tensor&, Tensor&> slogdet_out(const Tensor& A, Tensor& sign, Tensor& logabsdet) {
425 return at::linalg_slogdet_out(sign, logabsdet, A);
426 }
427
428 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ logdet ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
429
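// Rough example of the behaviour implemented below: for a real matrix with
// det = -2, linalg_slogdet returns (sign = -1, logabsdet = log 2), so logdet
// yields NaN; for a complex matrix with det = -2, sign.log() + logabsdet
// evaluates to log 2 + i*pi.
//   e.g. at::logdet(-at::eye(3))  // NaN, since det = -1 < 0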
430 Tensor logdet(const Tensor& A) {
431 squareCheckInputs(A, "logdet");
432 checkFloatingOrComplex(A, "logdet", /*low_precision*/false);
433 auto [sign, logabsdet] = at::linalg_slogdet(A);
434
435 if (A.is_complex()) {
436 return sign.log() + logabsdet;
437 } else {
438 return at::where(sign == -1., NAN, logabsdet);
439 }
440 }
441
442 namespace {
443
444 // This function extracts the optional Tensors for atol and rtol
445 // Default value for atol is zero
446 // Default value for rtol is eps*max(rows, cols)
447 // If atol is specified and rtol is not specified then default value for rtol is zero
448 // It is used for matrix_rank and pinv
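// Rough example (values are illustrative): for a float matrix with sizes
// (..., 3, 5) and neither tolerance given, atol = 0 and
// rtol = eps(float) * 5 ~= 1.19e-7 * 5; if a positive atol is given and rtol
// is not, rtol becomes 0 instead.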
449 std::tuple<Tensor, Tensor> get_atol_rtol(
450 const Tensor& input,
451 const std::optional<Tensor>& atol_opt,
452 const std::optional<Tensor>& rtol_opt,
453 const c10::string_view function_name) {
454 auto options = input.options();
455 if (input.device().type() == kMetal || input.device().type() == kMPS) {
456 options = options.dtype(ScalarType::Float);
457 } else {
458 options = options.dtype(ScalarType::Double);
459 }
460 auto atol = atol_opt.has_value() ? atol_opt.value() : at::zeros({}, options);
461 checkNotComplexTolerance(atol, function_name, "atol");
462 Tensor rtol;
463 if (rtol_opt.has_value()) {
464 rtol = rtol_opt.value();
465 checkNotComplexTolerance(rtol, function_name, "rtol");
466 } else {
467 ScalarType real_dtype = toRealValueType(input.scalar_type());
468 auto default_rtol = at::full({}, _get_epsilon(real_dtype) * std::max(input.sym_size(-1), input.sym_size(-2)), options);
469 rtol = atol_opt.has_value()
470 ? at::where(atol_opt.value() > 0, at::zeros({}, options), default_rtol)
471 : std::move(default_rtol);
472 }
473 return std::make_tuple(atol, rtol);
474 }
475
476 std::tuple<Tensor, Tensor> get_atol_rtol(
477 const Tensor& input,
478 std::optional<double> atol_opt,
479 std::optional<double> rtol_opt) {
480 auto atol = atol_opt.has_value() ? atol_opt.value() : 0.0;
481 c10::SymFloat rtol;
482 if (rtol_opt.has_value()) {
483 rtol = rtol_opt.value();
484 } else {
485 ScalarType real_dtype = toRealValueType(input.scalar_type());
486 auto default_rtol = _get_epsilon(real_dtype) * std::max(input.sym_size(-1), input.sym_size(-2));
487 rtol = (atol_opt.has_value() && atol_opt.value() > 0.0)
488 ? 0.0
489 : default_rtol;
490 }
491 auto options = input.options();
492 if (input.device().type() == kMetal || input.device().type() == kMPS) {
493 options = options.dtype(ScalarType::Float);
494 } else {
495 options = options.dtype(ScalarType::Double);
496 }
497 auto atol_tensor = at::full({}, atol, options);
498 auto rtol_tensor = at::full({}, rtol, options);
499 return std::make_tuple(atol_tensor, rtol_tensor);
500 }
501
502 } // anonymous namespace
503
504 Tensor linalg_pinv(
505 const Tensor& input,
506 const std::optional<Tensor>& atol_opt,
507 const std::optional<Tensor>& rtol_opt,
508 bool hermitian) {
509 // FIXME: Whenever we have a nice lstsq, we should dispatch this function to simply be
510 // `torch.lstsq(A, torch.eye(A.shape[-1]), atol=atol, rtol=rtol)`
511 // with a driver that supports singular inputs
512 NoTF32Guard disable_tf32;
513 ScalarType t = input.scalar_type();
514 TORCH_CHECK((t == ScalarType::Double || t == ScalarType::Float || t == ScalarType::ComplexFloat || t == ScalarType::ComplexDouble)
515 && input.dim() >= 2,
516 "linalg.pinv(", t, "{", input.sizes(), "}): expected a tensor with 2 or more dimensions "
517 "of float, double, cfloat or cdouble types");
518
519 auto [atol, rtol] = get_atol_rtol(input, atol_opt, rtol_opt, "torch.linalg.pinv");
520
521 if (input.sym_numel() == 0) {
522 // The implementation below uses operations that do not work for zero numel tensors,
523 // therefore we need this early return for the 'input.numel() == 0' case
524 // TODO: replace input.svd with linalg_svd when torch/xla can work with at::linalg_svd
525 auto [U, S, V] = input.svd();
526 return at::matmul(V * S.reciprocal().unsqueeze(-2), U.mH());
527 }
528
529 // If not Hermitian use singular value decomposition, else use eigenvalue decomposition
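// Sketch of the identity used below: with the SVD A = U diag(S) V^H,
// pinv(A) = V diag(S_pseudoinv) U^H, where S_pseudoinv inverts the singular
// values above max(atol, rtol * max(S)) and zeroes out the rest; the Hermitian
// branch applies the same thresholding to abs(eigenvalues) of A = U diag(S) U^H.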
530 if (!hermitian) {
531 // TODO: replace input.svd with linalg_svd
532 // using linalg_svd breaks pytorch/xla, see https://github.com/pytorch/xla/issues/2755
533 auto [U, S, V] = input.svd();
534 Tensor max_val = at::narrow(S, /*dim=*/-1, /*start=*/0, /*length=*/1); // singular values are sorted in descending order
535 Tensor tol = at::max(atol.unsqueeze(-1), rtol.unsqueeze(-1) * max_val);
536 Tensor S_pseudoinv = at::where(S > tol, S.reciprocal(), at::zeros({}, S.options())).to(input.dtype());
537 // computes V @ diag(S_pseudoinv) @ U.conj().T
538 return at::matmul(V * S_pseudoinv.unsqueeze(-2), U.mH());
539 } else {
540 auto [S, U] = at::linalg_eigh(input);
541 // For Hermitian matrices, the singular values are equal to abs(eigenvalues)
542 Tensor S_abs = S.abs();
543 // eigenvalues are sorted in ascending order starting with negative values, so we need the maximum of abs(eigenvalues)
544 Tensor max_val = S_abs.amax(/*dim=*/-1, /*keepdim=*/true);
545 Tensor tol = at::max(atol.unsqueeze(-1), rtol.unsqueeze(-1) * max_val);
546 Tensor S_pseudoinv = at::where(S_abs > tol, S.reciprocal(), at::zeros({}, S.options())).to(input.dtype());
547 // computes U @ diag(S_pseudoinv) @ U.conj().T
548 return at::matmul(U * S_pseudoinv.unsqueeze(-2), U.mH());
549 }
550 }
551
552 Tensor linalg_pinv(const Tensor& input, std::optional<double> atol, std::optional<double> rtol, bool hermitian) {
553 auto [atol_tensor, rtol_tensor] = get_atol_rtol(input, atol, rtol);
554 return at::linalg_pinv(input, atol_tensor, rtol_tensor, hermitian);
555 }
556
557 Tensor linalg_pinv(const Tensor& input, const Tensor& rcond, bool hermitian) {
558 // For NumPy compatibility the rcond argument is used as relative tolerance
559 checkNotComplexTolerance(rcond, "torch.linalg.pinv", "rcond");
560 auto options = input.options();
561 if (input.device().type() == kMetal || input.device().type() == kMPS) {
562 options = options.dtype(ScalarType::Float);
563 } else {
564 options = options.dtype(ScalarType::Double);
565 }
566 return at::linalg_pinv(input, at::zeros({}, options), rcond, hermitian);
567 }
568
569 Tensor linalg_pinv(const Tensor& input, double rcond, bool hermitian) {
570 // For NumPy compatibility the rcond argument is used as relative tolerance
571 return at::linalg_pinv(input, 0.0, rcond, hermitian);
572 }
573
574 // TODO: implement _out variant avoiding copy and using already allocated storage directly
575 Tensor& linalg_pinv_out(
576 const Tensor& input,
577 const std::optional<Tensor>& atol,
578 const std::optional<Tensor>& rtol,
579 bool hermitian,
580 Tensor& result) {
581 checkSameDevice("linalg.pinv", result, input);
582 checkLinalgCompatibleDtype("linalg.pinv", result, input);
583 Tensor result_tmp = at::linalg_pinv(input, atol, rtol, hermitian);
584 at::native::resize_output(result, result_tmp.sizes());
585 result.copy_(result_tmp);
586 return result;
587 }
588
589 Tensor& linalg_pinv_out(
590 const Tensor& input,
591 std::optional<double> atol,
592 std::optional<double> rtol,
593 bool hermitian,
594 Tensor& result) {
595 checkSameDevice("linalg.pinv", result, input);
596 checkLinalgCompatibleDtype("linalg.pinv", result, input);
597 Tensor result_tmp = at::linalg_pinv(input, atol, rtol, hermitian);
598 at::native::resize_output(result, result_tmp.sizes());
599 result.copy_(result_tmp);
600 return result;
601 }
602
603 Tensor& linalg_pinv_out(const Tensor& input, const Tensor& rcond, bool hermitian, Tensor& result) {
604 checkSameDevice("linalg.pinv", result, input);
605 checkLinalgCompatibleDtype("linalg.pinv", result, input);
606
607 Tensor result_tmp = at::linalg_pinv(input, rcond, hermitian);
608 at::native::resize_output(result, result_tmp.sizes());
609 result.copy_(result_tmp);
610 return result;
611 }
612
613 Tensor& linalg_pinv_out(const Tensor& input, double rcond, bool hermitian, Tensor& result) {
614 Tensor rcond_tensor = at::full({}, rcond, input.options().dtype(ScalarType::Double));
615 return at::linalg_pinv_out(result, input, rcond_tensor, hermitian);
616 }
617
618 Tensor pinverse(const Tensor& self, double rcond) {
619 return at::linalg_pinv(self, rcond, /*hermitian=*/false);
620 }
621
622 // matrix_power implementation
623 namespace {
624
625 /**
626 * @brief Raises the input matrix to the given power n
627 *
628 * If the exponent n is negative, the inverse of the input
629 * matrix will be raised to power abs(n).
630 *
631 * @param self (batched) square matrix to raise to power n
632 * @param n exponent to raise matrix (or matrices in batch) to
633 * @param _out optional tensor to write the output to
634 * @return Tensor input matrix raised to power n
635 */
636 Tensor linalg_matrix_power_impl(
637 const Tensor& self,
638 int64_t n,
639 std::optional<Tensor> _out) {
640 NoTF32Guard disable_tf32;
641 auto out = _out.value_or(Tensor());
642
643 squareCheckInputs(self, "linalg.matrix_power");
644 if (_out.has_value()) {
645 checkSameDevice("matrix_power", out, self);
646 checkLinalgCompatibleDtype("matrix_power", out, self);
647 at::native::resize_output_symint(out, self.sym_sizes());
648 }
649
650 // For n=0 we return the identity matrix of the same shape as input.
651 if (n == 0) {
652 if (!_out.has_value()) {
653 // Clone input to include result in the autograd graph
654 out = self.clone(at::MemoryFormat::Contiguous);
655 }
656 return out.copy_(at::eye_symint(self.sym_size(-2), self.options()));
657 }
658 if (n == 1) {
659 return _out.has_value() ? out.copy_(self)
660 : self.clone(at::MemoryFormat::Contiguous);
661 }
662 if (n == -1) {
663 return _out.has_value() ? at::linalg_inv_out(out, self)
664 : at::linalg_inv(self);
665 }
666
667 // For negative n we invert the input matrix before raising it to the power abs(n)
668 auto a = n < 0 ? at::linalg_inv(self) : self;
669 n = std::abs(n);
670
671 // Fast paths for small powers
672 if (n == 2) {
673 return _out.has_value() ? at::matmul_out(out, a, a) : at::matmul(a, a);
674 }
675 if (n == 3) {
676 return _out.has_value() ? at::matmul_out(out, at::matmul(a, a), a)
677 : at::matmul(at::matmul(a, a), a);
678 }
679
680 // This is a binary decomposition of n.
681 // Moving from the least significant bit to the most significant bit
682 // This is done to reduce the number of matrix multiplications
683 // by raising the input matrix in powers of 2
684 // The total number of matrix multiplications is
685 // (number of bits) + (number of set bits) ~ O(log n)
686 // instead of O(n)
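// Illustrative trace (hypothetical n): for n = 13 = 0b1101 the loop computes
// z = a, a^2, a^4, a^8 (three squarings) and folds the set bits into
// result = a * a^4 * a^8 = a^13, i.e. 5 matmuls instead of 12.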
687 Tensor z, result;
688 while (n > 0) {
689 const auto bit = n % 2;
690 n = n / 2;
691 z = z.defined() ? at::matmul(z, z) : a;
692 if (bit == 1) {
693 if (_out.has_value() && n <= 0) {
694 // Last multiplication can use the out version
695 return result.defined() ? at::matmul_out(out, result, z) : out.copy_(z);
696 }
697 result = result.defined() ? at::matmul(result, z) : z;
698 }
699 }
700
701 return result;
702 }
703
704 } // namespace
705
706 Tensor& linalg_matrix_power_out(const Tensor& self, int64_t n, Tensor& result) {
707 linalg_matrix_power_impl(self, n, result);
708 return result;
709 }
710
711 Tensor linalg_matrix_power(const Tensor& self, int64_t n) {
712 return linalg_matrix_power_impl(self, n, std::nullopt);
713 }
714
715 Tensor& matrix_power_out(const Tensor& self, int64_t n, Tensor& result) {
716 return at::native::linalg_matrix_power_out(self, n, result);
717 }
718
719 Tensor matrix_power(const Tensor& self, int64_t n) {
720 return at::native::linalg_matrix_power(self, n);
721 }
722
723 namespace {
724
725 // Computes the rank of 'input' and saves the result in-place in 'result'.
726 // 'hermitian' controls whether SVD or eigendecomposition is used for computing the singular values
727 // 'atol' and 'rtol' are the absolute and relative tolerances, respectively.
728 Tensor& matrix_rank_impl(
729 const Tensor& input,
730 const std::optional<Tensor>& atol_opt,
731 const std::optional<Tensor>& rtol_opt,
732 bool hermitian,
733 Tensor& result) {
734 auto [atol, rtol] = get_atol_rtol(input, atol_opt, rtol_opt, "torch.linalg.matrix_rank");
735
736 checkSameDevice("torch.linalg.matrix_rank", result, input);
737 checkSameDevice("torch.linalg.matrix_rank", atol, input, "atol");
738 checkSameDevice("torch.linalg.matrix_rank", rtol, input, "rtol");
739 ScalarType output_type = ScalarType::Long;
740 checkLinalgCompatibleDtype("torch.linalg.matrix_rank", result.scalar_type(), output_type);
741
742 checkNotComplexTolerance(atol, "torch.linalg.matrix_rank", "atol");
743 checkNotComplexTolerance(rtol, "torch.linalg.matrix_rank", "rtol");
744
745 // NumPy doesn't take into account inputs with no elements; it errors out because max is not defined for this case
746 // We output 0 for this case instead, since such matrices have no non-zero rows and hence rank 0.
747 if (input.sym_numel() == 0) {
748 result.fill_(0);
749 return result;
750 }
751
752 // We compute matrix rank as the number of singular or absolute eigen values
753 // that are above max(atol, rtol * max(S)) threshold
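// Rough example (hypothetical values): with singular values S = {5.0, 1e-7, 0},
// atol = 0 and rtol = 1e-6, the threshold is max(0, 1e-6 * 5.0) = 5e-6, so only
// the first singular value counts and the reported rank is 1.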
754 Tensor S, max_S;
755 if (!hermitian) {
756 S = at::linalg_svdvals(input);
757 // singular values are sorted in descending order
758 max_S = at::narrow(S, /*dim=*/-1, /*start=*/0, /*length=*/1);
759 } else {
760 S = at::linalg_eigvalsh(input);
761 S = S.abs();
762 // eigenvalues are sorted in ascending order starting with negative values, so we need the maximum of abs(eigenvalues)
763 max_S = S.amax(/*dim=*/-1, /*keepdim=*/true);
764 }
765
766 Tensor tol = at::max(atol.unsqueeze(-1), rtol.unsqueeze(-1) * max_S);
767
768 if (isTensorSubclassLike(input)) {
769 result = at::sum(S > tol, /*dim=*/-1);
770 return result;
771 }
772
773 result = at::sum_out(result, S > tol, /*dim=*/-1);
774 return result;
775 }
776
777 Tensor get_matrix_rank_result_tensor(const Tensor& input) {
778 // Matrices or batch of matrices are allowed
779 checkIsMatrix(input, "torch.linalg.matrix_rank", "input");
780 // For Composite Compliance, allocate `result` of correct shape to
781 // avoid resizing in `out` variant.
782 // See also `NOTE [matrix rank output shape]`
783 auto result_shape =
784 SymIntArrayRef(input.sym_sizes().cbegin(), input.sym_sizes().cend() - 2);
785 Tensor result =
786 at::empty_symint(result_shape, input.options().dtype(ScalarType::Long));
787
788 return result;
789 }
790
791 } // anonymous namespace
792
793 Tensor& linalg_matrix_rank_out(
794 const Tensor& input,
795 const std::optional<Tensor>& atol_opt,
796 const std::optional<Tensor>& rtol_opt,
797 bool hermitian,
798 Tensor& result) {
799 // Matrices or batch of matrices are allowed
800 checkIsMatrix(input, "torch.linalg.matrix_rank", "input");
801 auto result_shape =
802 IntArrayRef(input.sizes().cbegin(), input.sizes().cend() - 2);
803 at::native::resize_output(result, result_shape);
804 return matrix_rank_impl(input, atol_opt, rtol_opt, hermitian, result);
805 }
806
807 Tensor& linalg_matrix_rank_out(const Tensor& input, std::optional<double> atol, std::optional<double> rtol, bool hermitian, Tensor& result) {
808 auto [atol_tensor, rtol_tensor] = get_atol_rtol(input, atol, rtol);
809 result = linalg_matrix_rank_out(input, atol_tensor, rtol_tensor, hermitian, result);
810 return result;
811 }
812
813 Tensor linalg_matrix_rank(const Tensor& input, const std::optional<Tensor>& atol, const std::optional<Tensor>& rtol, bool hermitian) {
814 auto result = get_matrix_rank_result_tensor(input);
815 return matrix_rank_impl(input, atol, rtol, hermitian, result);
816 }
817
818 Tensor linalg_matrix_rank(const Tensor& input, std::optional<double> atol, std::optional<double> rtol, bool hermitian) {
819 auto result = get_matrix_rank_result_tensor(input);
820
821 auto [atol_tensor, rtol_tensor] = get_atol_rtol(input, atol, rtol);
822
823 return matrix_rank_impl(input, atol_tensor, rtol_tensor, hermitian, result);
824 }
825
826 Tensor& linalg_matrix_rank_out(const Tensor& input, const Tensor& tol, bool hermitian, Tensor& result) {
827 // For NumPy compatibility tol is not scaled with max(singular_value) if the value for tol is provided
828 // It is assumed that the provided value is the absolute tolerance
829 Tensor rtol = at::zeros({}, tol.options());
830 result = at::linalg_matrix_rank_outf(input, tol, rtol, hermitian, result);
831 return result;
832 }
833
834 Tensor& linalg_matrix_rank_out(const Tensor& input, double tol, bool hermitian, Tensor& result) {
835 // For NumPy compatibility tol is not scaled with max(singular_value) if the value for tol is provided
836 // It is assumed that the provided value is the absolute tolerance
837 result = at::linalg_matrix_rank_outf(input, tol, 0.0, hermitian, result);
838 return result;
839 }
840
841 Tensor linalg_matrix_rank(const Tensor& input, const Tensor& tol, bool hermitian) {
842 auto result = get_matrix_rank_result_tensor(input);
843 return matrix_rank_impl(input, tol, at::zeros({}, tol.options()), hermitian, result);
844 }
845
846 Tensor linalg_matrix_rank(const Tensor& input, double tol, bool hermitian) {
847 auto result = get_matrix_rank_result_tensor(input);
848
849 auto [atol_tensor, rtol_tensor] = get_atol_rtol(input, tol, 0.0);
850
851 return matrix_rank_impl(input, atol_tensor, rtol_tensor, hermitian, result);
852 }
853
854 // multi_dot helper functions
855 namespace {
856
857 /**
858 * @brief Computes the optimal matrix chain multiplication order
859 *
860 * Follows the dynamic programming algorithm from Cormen et al.,
861 * "Introduction to Algorithms, Third Edition", Chapter 15.2,
862 * p. 370-378. Note that the book uses 1-based indexing.
863 *
864 * The cost of multiplying two matrices with sizes p x q and q x r
865 * is defined here as p * q * r. The optimal multiplication order
866 * is the one that minimizes the total cost.
867 *
868 * @param tensors list of 2D tensors
869 * @return a 2D vector s used by #matrix_chain_multiplication to construct
870 * the optimal matrix multiplication order. The optimal multiplication
871 * order for multiplying tensors i...j is to multiply tensors i...s[i, j]
872 * and tensors (s[i, j] + 1)...j first and then the result of that.
873 */
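// Worked example (hypothetical sizes): for matrices of sizes (10 x 100),
// (100 x 5) and (5 x 50), i.e. p = {10, 100, 5, 50}, computing (A0 A1) first
// costs 10*100*5 + 10*5*50 = 7500 while computing (A1 A2) first costs
// 100*5*50 + 10*100*50 = 75000, so the algorithm below picks s[0][2] = 0.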
874 std::vector<std::vector<int64_t>> matrix_chain_order(TensorList tensors) {
875 const size_t n = tensors.size();
876
877 // Tensor i has dimensions p[i] x p[i + 1]
878 std::vector<int64_t> p(n + 1);
879 for (const auto i : c10::irange(n)) {
880 p[i] = tensors[i].size(0);
881 }
882 p[n] = tensors[n - 1].size(1);
883
884 // m[i, j] = k where k is the minimum cost for multiplying tensors i...j
885 std::vector<std::vector<int64_t>> m(n, std::vector<int64_t>(n, 0));
886
887 // s[i, j] = k where k is the index at which to split the list such that
888 // optimally multiplying matrices i...k and (k + 1)...j first and then the resulting
889 // matrices is the optimal order for multiplying matrices i...j.
890 std::vector<std::vector<int64_t>> s(n, std::vector<int64_t>(n));
891
892 // Compute the optimal multiplication order
893 for (const auto l : c10::irange(1, n)) {
894 for (const auto i : c10::irange(n - l)) {
895 const auto j = i + l;
896 m[i][j] = std::numeric_limits<int64_t>::max();
897 for (const auto k : c10::irange(i, j)) {
898 const auto q = m[i][k] + m[k + 1][j] + p[i] * p[k + 1] * p[j + 1];
899 if (q < m[i][j]) {
900 m[i][j] = q;
901 s[i][j] = k;
902 }
903 }
904 }
905 }
906
907 return s;
908 }
909
910 /**
911 * @brief Recursively multiplies the tensors i...j using the given order
912 *
913 * @param tensors matrices to multiply together
914 * @param order optimal chain multiplication order from #matrix_chain_order
915 * @param i index of first tensor to be multiplied
916 * @param j index of last tensor to be multiplied
917 * @return Tensor result of multiplying tensors[i...j] together.
918 */
919 Tensor matrix_chain_multiplication(
920 TensorList tensors,
921 const std::vector<std::vector<int64_t>>& order,
922 int64_t i,
923 int64_t j) {
924 if (i == j) {
925 return tensors[i];
926 }
927 return at::mm(
928 matrix_chain_multiplication(tensors, order, i, order[i][j]),
929 matrix_chain_multiplication(tensors, order, order[i][j] + 1, j));
930 }
931
932 // Implements torch.linalg.multi_dot
933 Tensor multi_dot_impl(TensorList _tensors, std::optional<Tensor> _out) {
934 const size_t n = _tensors.size();
935 TORCH_CHECK(n >= 2, "multi_dot(): expected at least 2 tensors but got ", n);
936
937 std::vector<int64_t> out_shape;
938 std::vector<Tensor> tensors(n);
939
940 // If the first tensor is 1D of size n view it as a row vector (1, n)
941 if (_tensors[0].dim() == 1) {
942 tensors[0] = _tensors[0].unsqueeze(0);
943 } else if (_tensors[0].dim() == 2) {
944 tensors[0] = _tensors[0];
945 out_shape.emplace_back(tensors[0].size(0));
946 } else {
947 TORCH_CHECK(
948 false,
949 "multi_dot(): the first tensor must be 1D or 2D but got ",
950 _tensors[0].dim(),
951 "D");
952 }
953
954 // If the last tensor is 1D of size n view it as a column vector (n, 1)
955 if (_tensors[n - 1].dim() == 1) {
956 tensors[n - 1] = _tensors[n - 1].unsqueeze(-1);
957 } else if (_tensors[n - 1].dim() == 2) {
958 tensors[n - 1] = _tensors[n - 1];
959 out_shape.emplace_back(tensors[n - 1].size(1));
960 } else {
961 TORCH_CHECK(
962 false,
963 "multi_dot(): the last tensor must be 1D or 2D but got ",
964 _tensors[n - 1].dim(),
965 "D");
966 }
967
968 // Ensure middle tensors are 2D
969 for (const auto i : c10::irange(1, n - 1)) {
970 TORCH_CHECK(
971 _tensors[i].dim() == 2,
972 "multi_dot(): tensor ",
973 i,
974 " must be 2D but got ",
975 _tensors[i].dim(),
976 "D");
977 tensors[i] = _tensors[i];
978 }
979
980 // Ensure all tensors have the same device and dtype and check
981 // that the shapes can be multiplied
982 const auto dtype = tensors[0].dtype();
983 const auto device = tensors[0].device();
984 for (const auto i : c10::irange(1, n)) {
985 TORCH_CHECK(
986 tensors[i].dtype() == dtype,
987 "multi_dot(): all tensors must have be the same dtype but tensor 0 is ",
988 dtype,
989 " and tensor ",
990 i,
991 " ",
992 tensors[i].dtype());
993 TORCH_CHECK(
994 tensors[i].device() == device,
995 "multi_dot(): all tensors must be on the same device but tensor 0 is on ",
996 device,
997 " and tensor ",
998 i,
999 " on ",
1000 tensors[i].device());
1001 TORCH_CHECK(
1002 tensors[i - 1].size(-1) == tensors[i].size(0),
1003 "multi_dot(): tensors ",
1004 i - 1,
1005 " and ",
1006 i,
1007 " with shapes ",
1008 _tensors[i - 1].sizes(),
1009 " and ",
1010 _tensors[i].sizes(),
1011 " cannot be multiplied")
1012 }
1013
1014 Tensor result;
1015
1016 if (_out.has_value()) {
1017 auto out = *_out;
1018 TORCH_CHECK(
1019 dtype == out.dtype(),
1020 "multi_dot(): expected out tensor to have dtype ",
1021 dtype,
1022 " but got ",
1023 out.dtype());
1024 TORCH_CHECK(
1025 device == out.device(),
1026 "multi_dot(): expected out tensor to be on device ",
1027 device,
1028 " but got ",
1029 out.device());
1030
1031 // If the first and last tensors have shapes (a, b) and (b, c) the
1032 // output has shape (a, c). If either the first or last tensor is 1D,
1033 // the a and/or c dimensions will be implicitly size 1 and will be omitted
1034 // from the output, e.g. for inputs (a, b) x (b) the output has shape (a,).
1035 at::native::resize_output(out, out_shape);
1036
1037 // View output as 2D for simplicity of computation.
1038 result = out.view({tensors[0].size(0), tensors.back().size(-1)});
1039 }
1040
1041 // The resize_ and view calls below are to ensure the
1042 // output shape respects the original dimensionality of
1043 // the first and last tensors, which we have now viewed as 2D
1044
1045 if (tensors.size() == 2) {
1046 return _out.has_value() ? at::mm_out(result, tensors[0], tensors[1])
1047 : at::mm(tensors[0], tensors[1]).view(out_shape);
1048 }
1049
1050 // Why the separate implementation for 3 matrices?
1051 // The logic for three matrices is much faster when done directly:
1052 // it requires 1 comparison instead of 4 and fewer arithmetic operations
1053 if (tensors.size() == 3) {
1054 const auto a = tensors[0].size(0);
1055 const auto b = tensors[1].size(0);
1056 const auto c = tensors[2].size(0);
1057 const auto d = tensors[2].size(1);
1058
1059 // The matrices are of size (a x b), (b x c), (c x d).
1060 // cost_1 is the cost of parenthesizing (a x b) and (b x c) first and then
1061 // combining with (c x d); cost_2 is the cost of parenthesizing (b x c) and
1062 // (c x d) first and then combining with (a x b)
1063 const auto cost_1 = (a * c) * (b + d);
1064 const auto cost_2 = (b * d) * (a + c);
1065
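// Rough example (hypothetical shapes): for (2 x 100), (100 x 3), (3 x 50),
// cost_1 = (2*3)*(100+50) = 900 and cost_2 = (100*50)*(2+3) = 25000, so the
// branch below computes (tensors[0] @ tensors[1]) @ tensors[2].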
1066 if (cost_1 > cost_2) {
1067 return _out.has_value()
1068 ? at::mm_out(result, tensors[0], at::mm(tensors[1], tensors[2]))
1069 : at::mm(tensors[0], at::mm(tensors[1], tensors[2])).view(out_shape);
1070 } else {
1071 return _out.has_value()
1072 ? at::mm_out(result, at::mm(tensors[0], tensors[1]), tensors[2])
1073 : at::mm(at::mm(tensors[0], tensors[1]), tensors[2]).view(out_shape);
1074 }
1075 }
1076
1077 // Algorithm for multiplying 4 or more matrices
1078 const auto order = matrix_chain_order(tensors);
1079 const int64_t i = 0;
1080 const int64_t j = n - 1;
1081
1082 if (_out.has_value()) {
1083 // We manually implement the first recursive layer here so we can use mm_out
1084 // for the final multiplication
1085 return at::mm_out(
1086 result,
1087 matrix_chain_multiplication(tensors, order, i, order[i][j]),
1088 matrix_chain_multiplication(tensors, order, order[i][j] + 1, j));
1089 }
1090 return matrix_chain_multiplication(tensors, order, i, j).view(out_shape);
1091 }
1092
1093 } // namespace
1094
1095 Tensor linalg_multi_dot(TensorList tensors) {
1096 return multi_dot_impl(tensors, std::nullopt);
1097 }
1098
1099 Tensor& linalg_multi_dot_out(TensorList tensors, Tensor& result) {
1100 multi_dot_impl(tensors, result);
1101 return result;
1102 }
1103
1104 Tensor chain_matmul(TensorList matrices) {
1105 TORCH_WARN_ONCE(
1106 "torch.chain_matmul is deprecated and will be removed in a future PyTorch release. ",
1107 "Use torch.linalg.multi_dot instead, which accepts a list of two or more tensors rather than ",
1108 "multiple parameters."
1109 );
1110 checkAllSameDim(matrices, 2);
1111
1112 TORCH_CHECK(
1113 !matrices.empty(), "chain_matmul(): Expected one or more matrices");
1114
1115 if (matrices.size() == 1) {
1116 return matrices[0].clone();
1117 }
1118
1119 return at::native::linalg_multi_dot(matrices);
1120 }
1121
1122 Tensor& chain_matmul_out(TensorList matrices, Tensor& result) {
1123 TORCH_WARN_ONCE(
1124 "torch.chain_matmul is deprecated and will be removed in a future PyTorch release. ",
1125 "Use torch.linalg.multi_dot instead, which accepts a list of two or more tensors rather than ",
1126 "multiple parameters."
1127 );
1128 checkAllSameDim(matrices, 2);
1129
1130 TORCH_CHECK(
1131 !matrices.empty(), "chain_matmul(): Expected one or more matrices");
1132
1133 if (matrices.size() == 1) {
1134 at::native::resize_output(result, matrices[0].sizes());
1135 return result.copy_(matrices[0]);
1136 }
1137
1138 return at::native::linalg_multi_dot_out(matrices, result);
1139 }
1140
1141 static void check_1d(const Tensor& t, const char* arg, const char* fn) {
1142 TORCH_CHECK(t.dim() == 1, fn, ": Expected 1-D argument ", arg, ", but got ", t.dim(), "-D");
1143 }
1144
1145 static void check_addr_scalar(const ScalarType dtype,
1146 const Scalar& scalar,
1147 const std::string& scalar_name) {
1148 TORCH_CHECK(
1149 !scalar.isBoolean() || dtype == ScalarType::Bool,
1150 "Boolean ", scalar_name, " only supported for Boolean results.");
1151 TORCH_CHECK(
1152 isFloatingType(dtype) || isComplexType(dtype) || scalar.isIntegral(true),
1153 "For integral input tensors, "
1154 "argument ", scalar_name ," must not be a floating point number.");
1155 }
1156
1157 static TensorIterator build_addr_iter(Tensor& result,
1158 const Tensor& self,
1159 const Tensor& vec1,
1160 const Tensor& vec2) {
1161 check_1d(vec1, "vec1", "addr");
1162 check_1d(vec2, "vec2", "addr");
1163
1164 const auto vec1_size0 = vec1.sizes()[0];
1165 const auto vec2_size0 = vec2.sizes()[0];
1166 auto self_ = &result == &self
1167 ? c10::MaybeOwned<Tensor>::borrowed(self)
1168 : expand_size(self, {vec1_size0, vec2_size0}, "addr");
1169 TORCH_CHECK(
1170 self_->dim() == 2,
1171 "2D tensor expected, got ", self_->dim(), "D tensor for input"
1172 );
1173 TORCH_CHECK(
1174 self_->sizes()[0] == vec1_size0 && self_->sizes()[1] == vec2_size0,
1175 "size mismatch, input: ", self_->sizes(),
1176 ", v1: ", vec1.sizes(),
1177 ", v2: ", vec2.sizes()
1178 );
1179
1180 auto iter = TensorIteratorConfig()
1181 .set_check_mem_overlap(true)
1182 .add_output(result)
1183 .add_owned_const_input(*self_)
1184 .add_owned_const_input(vec1.reshape({vec1_size0, 1}))
1185 .add_const_input(vec2)
1186 .allow_cpu_scalars(true)
1187 .promote_inputs_to_common_dtype(true)
1188 .cast_common_dtype_to_outputs(true)
1189 .enforce_safe_casting_to_output(true)
1190 .build();
1191 return iter;
1192 }
1193
1194 Tensor addr(const Tensor& self,
1195 const Tensor& vec1, const Tensor& vec2,
1196 const Scalar& beta, const Scalar& alpha) {
1197 Tensor result;
1198 auto iter = build_addr_iter(result, self, vec1, vec2);
1199
1200 check_addr_scalar(iter.dtype(), beta, "beta");
1201 check_addr_scalar(iter.dtype(), alpha, "alpha");
1202
1203 addr_stub(iter.device_type(), iter, beta, alpha);
1204 return iter.output();
1205 }
1206
1207 Tensor& addr_(Tensor& self,
1208 const Tensor& vec1, const Tensor& vec2,
1209 const Scalar& beta, const Scalar& alpha) {
1210 return at::addr_out(self, self, vec1, vec2, beta, alpha);
1211 }
1212
1213 Tensor& addr_out(const Tensor& self,
1214 const Tensor& vec1, const Tensor& vec2,
1215 const Scalar& beta, const Scalar& alpha, Tensor &result) {
1216 auto iter = build_addr_iter(result, self, vec1, vec2);
1217
1218 check_addr_scalar(iter.dtype(), beta, "beta");
1219 check_addr_scalar(iter.dtype(), alpha, "alpha");
1220
1221 addr_stub(iter.device_type(), iter, beta, alpha);
1222 return result;
1223 }
1224
1225 // The math_addr and math_addr_out functions support backends
1226 // other than CPU and CUDA, such as XLA.
1227 // They are implemented using the composition of existing ops
1228 Tensor math_addr(const Tensor& self,
1229 const Tensor& vec1, const Tensor& vec2,
1230 const Scalar& beta, const Scalar& alpha) {
1231 // When beta == 0, values in self should be ignored;
1232 // NaNs and infs in self should not propagate.
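// Rough example: with beta = 0 and self containing inf or NaN,
// math_addr(self, v1, v2, /*beta=*/0, alpha) should equal
// alpha * outer(v1, v2) exactly, which is why the beta == 0 branch below
// skips self rather than multiplying it by zero.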
1233 Tensor out;
1234 if (beta.toComplexDouble() == 0.0) {
1235 if (alpha.toComplexDouble() == 1.0) {
1236 out = at::outer(vec1, vec2);
1237 } else {
1238 out = alpha * at::outer(vec1, vec2);
1239 }
1240 } else if (beta.toComplexDouble() == 1.0) {
1241 if (alpha.toComplexDouble() == 1.0) {
1242 out = self + at::outer(vec1, vec2);
1243 } else {
1244 out = self + alpha * at::outer(vec1, vec2);
1245 }
1246 } else if (alpha.toComplexDouble() == 1.0) {
1247 out = beta * self + at::outer(vec1, vec2);
1248 } else {
1249 out = beta * self + alpha * at::outer(vec1, vec2);
1250 }
1251 auto result_type = c10::promoteTypes(c10::promoteTypes(self.scalar_type(), vec1.scalar_type()), vec2.scalar_type());
1252 return out.to(c10::TensorOptions().dtype(result_type));
1253 }
1254
1255 Tensor& math_addr_out(const Tensor& self,
1256 const Tensor& vec1, const Tensor& vec2,
1257 const Scalar& beta, const Scalar& alpha, Tensor &result) {
1258 auto addr_result = at::addr(self, vec1, vec2, beta, alpha);
1259
1260 // Validates safe casting
1261 const auto result_dtype = addr_result.scalar_type();
1262 TORCH_CHECK(canCast(result_dtype, result.scalar_type()),
1263 "result type ", result_dtype,
1264 " can't be cast to the desired output type ", result.scalar_type());
1265
1266 at::native::resize_output(result, addr_result.sizes().vec());
1267 result.copy_(addr_result);
1268 return result;
1269 }
1270
1271 // torch.ger, alias for torch.outer
1272 Tensor& ger_out(const Tensor& self, const Tensor& vec2, Tensor &result) {
1273 TORCH_WARN("torch.ger is deprecated and will be removed in a future PyTorch release. "
1274 "Use torch.outer instead.");
1275 return at::outer_out(result, self, vec2);
1276 }
1277
1278 Tensor ger(const Tensor& self, const Tensor& vec2) {
1279 return self.outer(vec2);
1280 }
1281
1282 Tensor& inner_out(const Tensor& self, const Tensor& other, Tensor& out) {
1283 checkDeviceType("inner()", {out, self, other}, self.device().type());
1284
1285 // If either self or other is a scalar just multiply them
1286 if (self.dim() == 0 || other.dim() == 0) {
1287 at::mul_out(out, self, other);
1288 return out;
1289 }
1290
1291 // Last dimension should match (tensordot does not enforce this)
1292 TORCH_CHECK(
1293 self.size(-1) == other.size(-1),
1294 "inner() the last dimension must match on both input tensors but got shapes ",
1295 self.sizes(),
1296 " and ",
1297 other.sizes());
1298
1299 at::tensordot_out(out, self, other, -1, -1);
1300 return out;
1301 }
1302
1303 Tensor inner(const Tensor& self, const Tensor& other) {
1304 checkDeviceType("inner()", {self, other}, self.device().type());
1305
1306 // If either self or other is a scalar just multiply them
1307 if (self.dim() == 0 || other.dim() == 0) {
1308 return self * other;
1309 }
1310
1311 // Last dimension should match (tensordot does not enforce this)
1312 TORCH_CHECK(
1313 self.sym_size(-1) == other.sym_size(-1),
1314 "inner() the last dimension must match on both input tensors but got shapes ",
1315 self.sym_sizes(),
1316 " and ",
1317 other.sym_sizes());
1318
1319 return at::tensordot(self, other, -1, -1);
1320 }
1321
1322 Tensor& outer_out(const Tensor& self, const Tensor& vec2, Tensor &result) {
1323 check_1d(self, "self", "outer");
1324 check_1d(vec2, "vec2", "outer");
1325
1326 // torch.outer is implemented as a composite op using reshape and mul
1327 at::mul_out(result, self.reshape({self.size(0), 1}), vec2);
1328 return result;
1329 }
1330
1331 Tensor outer(const Tensor& self, const Tensor& vec2) {
1332 check_1d(self, "self", "outer");
1333 check_1d(vec2, "vec2", "outer");
1334
1335 return self.reshape_symint({self.sym_size(0), 1}) * vec2;
1336 }
1337
1338
1339 #if !defined(C10_MOBILE)
1340 #define _AT_DISPATCH_ADDMM_TYPES(TYPE, NAME, ...) \
1341 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
1342 kBFloat16, kHalf, kFloat8_e5m2, kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, \
1343 TYPE, NAME, __VA_ARGS__)
1344 #else
1345 // Include half dtype in ADDMM. Used to build ExecuTorch in xplat.
1346 #if defined(C10_MOBILE_HALF)
1347 #define _AT_DISPATCH_ADDMM_TYPES(TYPE, NAME, ...) \
1348 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, \
1349 TYPE, NAME, __VA_ARGS__)
1350 #else
1351 #define _AT_DISPATCH_ADDMM_TYPES(TYPE, NAME, ...) \
1352 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, \
1353 TYPE, NAME, __VA_ARGS__)
1354 #endif
1355 #endif
1356
1357
static inline int64_t get_mkldnn_matmul_min_dim() {
1359 static auto value = [&] {
    const int64_t default_min_dim = [&] {
      // Minimum dimension requirement for the MKLDNN fast path, derived from experiments.
      // It is enabled on all Neoverse CPUs.
      return is_arm_neoverse() ? 8 : 0;
1364 }();
1365 const char* ptr = std::getenv("TORCH_MKLDNN_MATMUL_MIN_DIM");
1366 return ptr != nullptr ? std::atoi(ptr) : default_min_dim;
1367 }();
1368 return value;
1369 }
1370
1371
static inline int64_t get_mkldnn_matmul_min_size() {
1373 static auto value = [&] {
    const int64_t default_min_size = [&] {
      // Minimum size requirement for the MKLDNN fast path, derived from experiments.
      // It is enabled on all Neoverse CPUs.
      return is_arm_neoverse() ? 8 * 1024 : 0;
1378 }();
1379 const char* ptr = std::getenv("TORCH_MKLDNN_MATMUL_MIN_SIZE");
1380 return ptr != nullptr ? std::atoi(ptr) : default_min_size;
1381 }();
1382 return value;
1383 }
1384
1385
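// A note on the heuristic below (illustrative, not normative): on Neoverse the
// defaults above are min_dim = 8 and min_size = 8 * 1024, so e.g. a 16x16 @ 16x16
// matmul (m = k = n = 16, m * k * n = 4096) fails the size check and stays on the
// BLAS path, while a 32x32 @ 32x32 matmul (32768 > 8192) is eligible for oneDNN.
// On other CPUs both defaults are 0, so the heuristic reduces to
// at::globalContext().userEnabledMkldnn() for any non-empty problem. Both
// thresholds can be overridden via the TORCH_MKLDNN_MATMUL_MIN_DIM and
// TORCH_MKLDNN_MATMUL_MIN_SIZE environment variables.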
static inline bool apply_mkldnn_matmul_heur(int64_t m, int64_t k, int64_t n) {
1387 const int64_t min_dim = get_mkldnn_matmul_min_dim();
1388 const int64_t min_size = get_mkldnn_matmul_min_size();
1389 return at::globalContext().userEnabledMkldnn() && m > min_dim && k > min_dim && n > min_dim && m * k * n > min_size;
1390 }
1391
1392
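// Computes result = beta * self + alpha * (m1 @ m2) on CPU, dispatching to a BLAS
// gemm (or, on AArch64 builds with ACL, possibly to oneDNN). The layout handling
// below picks transposition flags and, if needed, contiguous copies so that a, b
// and c can be passed to gemm in column-major form without extra work later.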
static void addmm_impl_cpu_(
    Tensor &result, const Tensor &self, Tensor m1, Tensor m2, const Scalar& beta, const Scalar& alpha) {
1395 TORCH_INTERNAL_ASSERT(self.dim() == 2 && m1.dim() == 2 && m2.dim() == 2);
1396
1397 TORCH_CHECK(
1398 m1.dtype() == m2.dtype(),
1399 "expected m1 and m2 to have the same dtype, but got: ", m1.dtype(), " != ", m2.dtype()
1400 )
1401 // Array access is faster than .size(n) and .stride(n)
1402 const auto self_sizes = self.sizes();
1403 auto m1_strides = m1.strides();
1404 auto m1_sizes = m1.sizes();
1405 auto m2_strides = m2.strides();
1406 auto m2_sizes = m2.sizes();
1407
1408 TORCH_CHECK(
1409 self_sizes[0] == m1_sizes[0] && self_sizes[1] == m2_sizes[1],
1410 "input shape is incompatible with matrix multiplication (",
1411 m1_sizes[0], "x", m1_sizes[1], " @ ", m2_sizes[0], "x", m2_sizes[1], " != ",
1412 self_sizes[0], "x", self_sizes[1], ")");
1413
1414 at::native::resize_output(result, self_sizes);
1415 const auto result_strides = result.strides();
1416 const auto result_sizes = result.sizes();
1417
1418 if (result.numel() == 0) {
1419 return;
1420 }
1421
1422 // Some paths in the code below do not handle multiplications of the form [a, 0] x [0, b]
1423 if (m1_sizes[1] == 0) {
1424 if (beta.toComplexDouble() == 0.0) {
1425 result.zero_();
1426 } else {
1427 if (!self.is_same(result)) {
1428 result.copy_(self);
1429 }
1430 result.mul_(beta);
1431 }
1432 return;
1433 }
1434
1435 if (beta.toComplexDouble() != 0.0 && !self.is_same(result)) {
1436 result.copy_(self);
1437 }
1438
1439 bool transpose_c = false;
1440 Tensor c;
1441
  // Cast result as matrix c
1443 if (result_strides[0] == 1 &&
1444 (result_sizes[1] == 1 || result_strides[1] >= std::max(int64_t{1}, result_sizes[0]))) {
1445 transpose_c = false;
1446 c = result.resolve_conj();
1447 } else if (result_strides[1] == 1 &&
1448 (result_sizes[0] == 1 || result_strides[0] >= std::max(int64_t{1}, result_sizes[1]))) {
1449 std::swap(m1, m2);
1450 std::swap(m1_sizes, m2_sizes);
1451 std::swap(m1_strides, m2_strides);
1452 transpose_c = true;
1453 c = result.resolve_conj();
1454 } else {
1455 transpose_c = false;
1456 // make c FORTRAN contiguous
1457 c = result.resolve_conj().transpose(0, 1).contiguous().transpose_(0, 1);
1458 }
1459
1460 const int64_t m = result_sizes[transpose_c ? 1 : 0];
1461 const int64_t n = result_sizes[transpose_c ? 0 : 1];
1462 const int64_t k = m1_sizes[transpose_c ? 0 : 1];
1463
1464 // Cast m1 as matrix a
1465 bool transpose_a = false;
1466 Tensor a;
1467 /* Need lda >= max(1, (transpose_a ? k : m)) */
1468 if (m1_strides[transpose_c ? 1 : 0] == 1 &&
1469 m1_strides[transpose_c ? 0 : 1] >= std::max(int64_t{1}, m)) {
1470 transpose_a = false;
1471 a = m1.resolve_conj();
1472 } else if (m1_strides[transpose_c ? 0 : 1] == 1 &&
1473 m1_strides[transpose_c ? 1 : 0] >= std::max(int64_t{1}, k)) {
1474 transpose_a = true;
1475 a = m1;
1476 } else {
1477 transpose_a = !transpose_c;
1478 a = m1.clone(at::MemoryFormat::Contiguous);
1479 }
1480
1481 // Cast m2 as matrix b
1482 bool transpose_b = false;
1483 Tensor b;
  /* Need ldb >= max(1, (transpose_b ? n : k)) */
1485 if (m2_strides[transpose_c ? 1 : 0] == 1 &&
1486 m2_strides[transpose_c ? 0 : 1] >= std::max(int64_t{1}, k)) {
1487 transpose_b = false;
1488 b = m2.resolve_conj();
1489 } else if (m2_strides[transpose_c ? 0 : 1] == 1 &&
1490 m2_strides[transpose_c ? 1 : 0] >= std::max(int64_t{1}, n)) {
1491 transpose_b = true;
1492 b = m2;
1493 } else {
1494 transpose_b = !transpose_c;
1495 b = m2.clone(at::MemoryFormat::Contiguous);
1496 }
1497
1498 const int64_t lda = a.strides()[(transpose_a == transpose_c) ? 1 : 0];
1499 const int64_t ldb = b.strides()[(transpose_b == transpose_c) ? 1 : 0];
1500 const int64_t ldc = c.strides()[transpose_c ? 0 : 1];
1501
1502 // Always ensure the conjugation for c is resolved since there's no way to specify c's conjugation in the gemm call
1503 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c.is_conj());
1504
1505 bool dispatched = false;
1506 #if defined(__aarch64__) && AT_MKLDNN_ACL_ENABLED()
  // On AArch64, if the LHS matrix in the BLAS routine is transposed but the RHS is not,
  // it is faster to call the oneDNN matrix multiplication primitive with RHS*LHS;
  // that will then call into the Arm® Compute Library (ACL) GEMM kernel, which
  // additionally supports running the kernel with BF16 instructions.
1511 if (transpose_c) {
1512 bool apply_heur = apply_mkldnn_matmul_heur(b.sizes()[0], b.sizes()[1], a.sizes()[1]);
1513 if (apply_heur && transpose_a && !transpose_b && result.scalar_type() == at::ScalarType::Float) {
1514 try {
1515 mkldnn_matmul(b, a, c, beta.to<float>(), alpha.to<float>());
1516 // We have dispatched to ACL GEMM for single precision float
1517 // so do not need to dispatch to BLAS GEMM below
1518 dispatched = true;
1519 } catch (const std::exception& e) {
1520 TORCH_WARN("mkldnn_matmul failed, switching to BLAS gemm:", e.what());
1521 at::globalContext().setUserEnabledMkldnn(false);
1522 }
1523 }
1524 }
1525 #endif
1526
1527 if(!dispatched) {
1528 // Apply BLAS routine
1529 _AT_DISPATCH_ADDMM_TYPES(result.scalar_type(), "addmm_impl_cpu_", [&]{
1530 using opmath_t = at::opmath_type<scalar_t>;
1531 at::native::cpublas::gemm(
1532 transpose_a ? a.is_conj() ? TransposeType::ConjTranspose : TransposeType::Transpose : TransposeType::NoTranspose,
1533 transpose_b ? b.is_conj() ? TransposeType::ConjTranspose : TransposeType::Transpose : TransposeType::NoTranspose,
1534 m, n, k,
1535 alpha.to<opmath_t>(),
1536 a.const_data_ptr<scalar_t>(), lda,
1537 b.const_data_ptr<scalar_t>(), ldb,
1538 beta.to<opmath_t>(),
1539 c.mutable_data_ptr<scalar_t>(), ldc);
1540 });
1541 }
1542
1543 if (!c.is_same(result)) {
1544 result.copy_(c);
1545 }
1546 }
1547
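// addbmm: result = beta * self + alpha * sum_i(batch1[i] @ batch2[i]).
// Implemented below as a loop of addmm_ calls over the batch dimension, with beta
// applied only on the first iteration so later iterations accumulate into result.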
static void addbmm_impl_(
    Tensor &result, const Tensor &self, const Tensor &batch1, const Tensor &batch2, const Scalar& beta, const Scalar& alpha) {
1550 TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor");
1551 TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor");
1552 TORCH_CHECK(batch1.size(0) == batch2.size(0),
1553 "batch1 and batch2 must have same number of batches, got ",
1554 batch1.size(0), " and ", batch2.size(0));
1555 TORCH_CHECK(batch1.size(2) == batch2.size(1),
1556 "Incompatible matrix sizes for bmm (",
1557 batch1.size(1), "x", batch1.size(2), " and ",
1558 batch2.size(1), "x", batch2.size(2), ")");
1559
1560 const int64_t dim1 = batch1.size(1);
1561 const int64_t dim2 = batch2.size(2);
1562 TORCH_CHECK(self.size(0) == dim1 && self.size(1) == dim2,
1563 "self tensor does not match matmul output shape");
1564
1565 result.resize_as_(self);
1566
1567 if (beta.to<c10::complex<double>>() != 0.0 && !self.is_same(result)) {
1568 result.copy_(self);
1569 }
1570
1571 const int64_t num_batches = batch1.size(0);
1572
1573 if (num_batches == 0) {
1574 if (beta.to<c10::complex<double>>() != 0.0) {
1575 result.mul_(beta);
1576 } else {
1577 result.zero_();
1578 }
1579 return;
1580 }
1581
1582 auto adjusted_beta(beta);
1583 for (const auto batch : c10::irange(num_batches)) {
1584 result.addmm_(batch1[batch], batch2[batch], adjusted_beta, alpha);
1585 adjusted_beta = 1; // accumulate output once
1586 }
1587 }
1588
Tensor& addbmm_out(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, Tensor& result) {
1590 auto b_self = expand_size(self, {batch1.size(1), batch2.size(2)}, "addbmm_out");
1591 {
1592 at::NoNamesGuard guard;
1593 addbmm_impl_(result, *b_self, batch1, batch2, beta, alpha);
1594 }
1595 auto names = at::namedinference::propagate_names_for_addmm(batch1, batch2, self);
1596 at::namedinference::propagate_names_if_nonempty(result, names);
1597 return result;
1598 }
1599
Tensor &addbmm_(Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
1601 return native::addbmm_out(self, batch1, batch2, beta, alpha, self);
1602 }
1603
Tensor addbmm(const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
1605 Tensor result = at::empty({0}, self.options());
1606 return native::addbmm_out(self, batch1, batch2, beta, alpha, result);
1607 }
1608
TORCH_IMPL_FUNC(addmm_out_cpu)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, const Tensor &result) {
1610 auto b_self = expand_size(self, {mat1.sizes()[0], mat2.sizes()[1]}, "addmm_out");
1611 {
1612 at::NoNamesGuard guard;
1613 addmm_impl_cpu_(const_cast<Tensor&>(result), *b_self, mat1, mat2, beta, alpha);
1614 }
1615 }
1616
TORCH_IMPL_FUNC(addmm_activation_out_cpu)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, bool use_gelu, const Tensor &result) {
1618 auto b_self = expand_size(self, {mat1.sizes()[0], mat2.sizes()[1]}, "addmm_out");
1619 {
1620 at::NoNamesGuard guard;
1621 addmm_impl_cpu_(const_cast<Tensor&>(result), *b_self, mat1, mat2, beta, alpha);
1622 if (use_gelu) {
1623 at::gelu_(const_cast<Tensor&>(result));
1624 } else {
1625 at::relu_(const_cast<Tensor&>(result));
1626 }
1627 }
1628 }
1629
TORCH_IMPL_FUNC(mm_out_cpu)(const Tensor & self, const Tensor & mat2, const Tensor & result) {
1631 {
1632 at::NoNamesGuard guard;
1633 addmm_impl_cpu_(const_cast<Tensor&>(result), result, self, mat2, 0, 1);
1634 }
1635 }
1636
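// Naive reference kernel for bmm (is_bmm == true) and baddbmm, parallelized over
// the batch dimension. Used by bmm_out_or_baddbmm_ below when the per-batch problem
// is small enough that calling into BLAS is not worthwhile.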
1637 template <typename scalar_t, bool is_bmm>
inline void baddbmm_cpu_kernel(const Tensor& result, const Tensor& self, const Tensor& mat2, const Scalar& beta_, const Scalar& alpha_) {
1639 int64_t bs = result.size(0);
1640 int64_t is = result.size(1);
1641 int64_t js = result.size(2);
1642 int64_t ks = self.size(2);
1643
1644 using opmath_t = at::opmath_type<scalar_t>;
1645 opmath_t alpha = alpha_.to<opmath_t>();
1646 opmath_t beta = beta_.to<opmath_t>();
1647
1648 auto r0 = result.accessor<scalar_t, 3>();
1649 auto s0 = self.accessor<const scalar_t, 3>();
1650 auto m0 = mat2.accessor<const scalar_t, 3>();
1651
  int64_t grain_size = std::max(internal::GRAIN_SIZE / (is * js * ks), (int64_t)1);
1654 parallel_for(0, bs, grain_size, [&](int64_t b_begin, int64_t b_end) {
1655 for (const auto b : c10::irange(b_begin, b_end)) {
1656 auto r1 = r0[b];
1657 auto s1 = s0[b];
1658 auto m1 = m0[b];
1659 for (const auto i : c10::irange(is)) {
1660 auto r2 = r1[i];
1661 auto s2 = s1[i];
1662 for (const auto j : c10::irange(js)) {
          opmath_t acc_value = 0;  // is_bmm ? opmath_t(0) : opmath_t(r2[j]);
1664 for (const auto k : c10::irange(ks)) {
1665 acc_value += static_cast<opmath_t>(s2[k]) *
1666 static_cast<opmath_t>(m1[k][j]);
1667 }
1668 if (is_bmm) {
1669 r2[j] = acc_value;
1670 } else {
          // When beta == 0, r's existing value is ignored, so a NaN already in r does not propagate.
1672 if (beta == opmath_t{0}) {
1673 r2[j] = alpha * acc_value;
1674 } else {
1675 r2[j] = static_cast<opmath_t>(r2[j]) * beta + alpha * acc_value;
1676 }
1677 }
1678 }
1679 }
1680 }
1681 });
1682 }
1683
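// Batched gemm path for baddbmm: requires a contiguous result and batch items that
// are contiguous or transposed, and lowers the whole batch to a single strided
// batched gemm call.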
static void baddbmm_with_gemm_(const Tensor &result, const Tensor &mat1, const Tensor &mat2, const Scalar &beta_, const Scalar &alpha_) {
1685 TORCH_INTERNAL_ASSERT(result.is_contiguous());
1686
1687 const auto result_sizes = result.sizes();
1688 const auto result_strides = result.strides();
1689 const auto mat1_strides = mat1.strides();
1690 const auto mat2_strides = mat2.strides();
1691 const auto mat1_sizes = mat1.sizes();
1692 const auto mat2_sizes = mat2.sizes();
1693
1694 auto is_transposed = [](const c10::IntArrayRef& strides, const c10::IntArrayRef& sizes) {
1695 return strides[1] == 1 && strides[2] >= sizes[1];
1696 };
1697
1698 // gemm expects fortran order matrices, so we swap argument order to transpose everything
1699 const auto transpose_a = is_transposed(mat2_strides, mat2_sizes);
1700 const auto transpose_b = is_transposed(mat1_strides, mat1_sizes);
1701
1702 const int64_t batch_size = mat1_sizes[0];
1703 const int64_t m = result_sizes[2];
1704 const int64_t n = result_sizes[1];
1705 const int64_t k = mat2_sizes[1];
1706
1707 const int64_t lda = mat2_strides[transpose_a ? 2 : 1];
1708 const int64_t ldb = mat1_strides[transpose_b ? 2 : 1];
1709 const int64_t ldc = result_strides[1];
1710
1711 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(result.scalar_type(), "baddbmm_with_gemm", [&] {
1712 using opmath_t = at::opmath_type<scalar_t>;
1713 const auto alpha = alpha_.to<opmath_t>();
1714 const auto beta = beta_.to<opmath_t>();
1715 at::native::cpublas::gemm_batched_with_stride(
1716 transpose_a ? TransposeType::Transpose : TransposeType::NoTranspose,
1717 transpose_b ? TransposeType::Transpose : TransposeType::NoTranspose,
1718 batch_size, m, n, k, alpha,
1719 mat2.const_data_ptr<scalar_t>(), lda, mat2_strides[0],
1720 mat1.const_data_ptr<scalar_t>(), ldb, mat1_strides[0],
1721 beta,
1722 result.data_ptr<scalar_t>(), ldc, result_strides[0]);
1723 });
1724 }
1725
// This tries to apply some optimizations to bmm/baddbmm:
// - When the operands are small, the computation is parallelized over the batch
//   dimension using OMP and a naive matrix multiplication is applied.
// - When the operands are larger than the threshold and the build uses MKL, MKL's batched gemm is used.
// - Otherwise, we fall back to a series of matrix multiplications.
// The threshold of 400 for the first case has not been thoroughly benchmarked yet and may have room for
// further tuning; it likely depends on the characteristics of the CPU (MKL builds will differ from
// non-MKL ones, etc.), but it is a reasonable starting point.
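// For instance (illustrative numbers only), a batch of 4x4 @ 4x4 products has
// contraction_size * res_rows * res_cols = 64 < 400 and takes the naive parallel
// kernel, while 16x16 @ 16x16 gives 4096 and goes through the batched gemm or
// per-matrix BLAS path instead.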
1734
static inline void bmm_out_or_baddbmm_(const Tensor& self_or_result_, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, bool is_bmm_out) {
1736 // is_bmm_out: true for bmm_out, false for baddbmm_
1737 // self_or_result is "self" for baddbmm_ and "result" for bmm_out
1738 Tensor& self_or_result = const_cast<Tensor&>(self_or_result_);
1739
1740 const auto batch1_sizes = batch1.sizes();
1741 const auto batch2_sizes = batch2.sizes();
1742
1743 int64_t bs = batch1_sizes[0];
1744 int64_t contraction_size = batch1_sizes[2];
1745 int64_t res_rows = batch1_sizes[1];
1746 int64_t res_cols = batch2_sizes[2];
1747
1748 // handle pathological cases that blas may not like
1749 if (self_or_result.numel() == 0) {
1750 return;
1751 } else if (contraction_size == 0) {
1752 if (is_bmm_out || (beta.to<c10::complex<double>>() == 0.0)) {
1753 self_or_result.zero_();
1754 return;
1755 } else {
1756 self_or_result.mul_(beta);
1757 return;
1758 }
1759 }
1760
1761 auto batch_items_contiguous_or_transposed = [&](const Tensor& t) {
1762 const auto sizes = t.sizes();
1763 const auto strides = t.strides();
    // we do not care about a dimension's stride if its size equals 1
1765 return (strides[2] == 1 && (sizes[1] == 1 || strides[1] >= sizes[2])) ||
1766 (strides[1] == 1 && (sizes[2] == 1 || strides[2] >= sizes[1]));
1767 };
1768
1769 bool apply_heur = apply_mkldnn_matmul_heur(batch1.sizes()[1], batch1.sizes()[2], batch2.sizes()[2]);
1770 if (apply_heur && use_mkldnn_matmul(batch1, batch2, self_or_result)) {
1771 try {
1772 mkldnn_matmul(batch1, batch2, self_or_result, beta.to<float>(), alpha.to<float>());
1773 return;
1774 } catch (const std::exception& e) {
1775 TORCH_WARN("mkldnn_matmul failed, switching to baddbmm:", e.what());
1776 at::globalContext().setUserEnabledMkldnn(false);
1777 }
1778 }
1779
1780 if (contraction_size * res_rows * res_cols < 400) {
1781 if (is_bmm_out) {
1782 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, batch1.scalar_type(), "bmm", [&] {
1783 baddbmm_cpu_kernel<scalar_t, true>(self_or_result, batch1, batch2, beta, alpha);
1784 });
1785 } else {
1786 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, batch1.scalar_type(), "baddbmm", [&] {
1787 baddbmm_cpu_kernel<scalar_t, false>(self_or_result, batch1, batch2, beta, alpha);
1788 });
1789 }
1790 } else if (at::hasMKL() && ((
1791 self_or_result.scalar_type() != kBFloat16 &&
1792 self_or_result.scalar_type() != kHalf &&
1793 at::native::is_floating_point(self_or_result)) ||
1794 at::native::is_complex(self_or_result))
1795 && batch_items_contiguous_or_transposed(batch1)
1796 && batch_items_contiguous_or_transposed(batch2)
1797 && self_or_result.is_contiguous()) {
1798 baddbmm_with_gemm_(self_or_result, batch1, batch2, beta, alpha);
1799 } else { // split along batch dimension
1800 #ifdef C10_MOBILE
    /*
     * We only do multithreading when inference mode is enabled, because various
     * pieces of thread-local state (e.g. RecordFunction-related state, the
     * dispatch key set) are not appropriately propagated through
     * at::parallel_for. The big concern is that if we use at::parallel_for
     * where state is not propagated, the dispatch machinery may behave
     * differently on the main thread vs. other threads, leading to undefined
     * behavior. Thus it is recommended not to use at::parallel_for where the
     * lambdas run ops that go through the dispatcher.
     * For now we circumvent this with an InferenceMode guard in order to
     * unlock performance.
     * Longer term we probably want a separate API that explicitly calls out
     * the TLS that it propagates.
     * Also note that this is enabled for mobile only, because the BLAS
     * implementation for non-mobile builds is already multithreaded.
     */
1817 // Benchmarking was done as follows:
1818 // bmm_test: operator benchmark under
1819 // benchmarks/operator_benchmarks/pt/bmm_test.py Ran this benchmark for
1820 // various matrix sizes on Samsung S8U
1821 const bool enable_multithreaded_bmm = c10::InferenceMode::is_enabled() &&
1822 bs >= 4 && res_rows >= 4 && res_cols >= 16 && contraction_size >= 16;
1823 #else
1824 const bool enable_multithreaded_bmm{false};
1825 #endif
1826 if (is_bmm_out) {
1827 if (enable_multithreaded_bmm) {
1828 auto bmm_out_fn = [&](uint64_t start, uint64_t end) {
1829 c10::InferenceMode guard;
1830 for (const auto b : c10::irange(start, end)) {
1831 auto r = self_or_result.select(0, b);
1832 addmm_impl_cpu_(
1833 r, r, batch1.select(0, b), batch2.select(0, b), 0, 1);
1834 }
1835 };
1836 // Materialize if COW, since we cannot do so during parallel_for
1837 self_or_result.mutable_data_ptr();
1838 at::parallel_for(0, bs, 1, bmm_out_fn);
1839 } else {
1840 for (const auto b : c10::irange(bs)) {
1841 auto r = self_or_result.select(0, b);
1842 addmm_impl_cpu_(r, r, batch1.select(0, b), batch2.select(0, b), 0, 1);
1843 }
1844 }
1845 } else {
1846 if (enable_multithreaded_bmm) {
1847 auto bmm_fn = [&](uint64_t start, uint64_t end) {
1848 c10::InferenceMode guard;
1849 for (const auto b : c10::irange(start, end)) {
1850 self_or_result.select(0, b).addmm_(
1851 batch1.select(0, b), batch2.select(0, b), beta, alpha);
1852 }
1853 };
1854 // Materialize if COW, since we cannot do so during parallel_for
1855 self_or_result.mutable_data_ptr();
1856 at::parallel_for(0, bs, 1, bmm_fn);
1857 } else {
1858 for (const auto b : c10::irange(bs)) {
1859 self_or_result.select(0, b).addmm_(
1860 batch1.select(0, b), batch2.select(0, b), beta, alpha);
1861 }
1862 }
1863 }
1864 }
1865 return;
1866 }
1867
static void conjugate_mutable_input_if_needed(const Tensor& self, bool conjugate) {
1869 if (conjugate) {
1870 self.conj_physical_();
1871 }
1872 }
1873
TORCH_IMPL_FUNC(baddbmm_out_cpu)
(const Tensor & self, const Tensor & batch1, const Tensor & batch2, const Scalar& beta, const Scalar& alpha, const Tensor& result) {
1876 bool self_is_conj = result.is_conj();
1877 conjugate_mutable_input_if_needed(result, self_is_conj);
1878 bmm_out_or_baddbmm_(result, batch1.resolve_conj(), batch2.resolve_conj(), beta, alpha, false);
1879 conjugate_mutable_input_if_needed(result, self_is_conj);
1880 }
1881
TORCH_IMPL_FUNC(bmm_out_cpu)
(const Tensor & batch1, const Tensor & batch2, const Tensor & result) {
1884 {
1885 NoNamesGuard guard;
1886 bool result_is_conj = result.is_conj();
1887 conjugate_mutable_input_if_needed(result, result_is_conj);
1888 bmm_out_or_baddbmm_(result, batch1.resolve_conj(), batch2.resolve_conj(), Scalar(0.0), Scalar(1.0), true);
1889 conjugate_mutable_input_if_needed(result, result_is_conj);
1890 }
1891 }
1892
Tensor& dot_out(const Tensor& self, const Tensor& other, Tensor& result) {
1894 auto output_device = result.device();
1895 auto input1_device = self.device();
1896 auto input2_device = other.device();
1897 // check if the input & output tensors are on the same device.
1898 TORCH_CHECK(
1899 (output_device == input1_device) && (input1_device == input2_device),
1900 "dot: Expected the output and input tensors to be on the "
1901 "same device, but got the output tensor on ", output_device,
1902 ", the 'input' tensor on ", input1_device, ", and the 'other' tensor on ", input2_device);
1903 at::native::resize_output(result, {});
1904 TORCH_CHECK(result.scalar_type() == self.scalar_type(),
1905 "result dtype ", result.scalar_type(), " does not match input dtype ", self.scalar_type());
1906 return result.fill_(self.dot(other));
1907 }
1908
Tensor& vdot_out(const Tensor& self, const Tensor& other, Tensor& result) {
1910 auto output_device = result.device();
1911 auto input1_device = self.device();
1912 auto input2_device = other.device();
1913 // check if the input & output tensors are on the same device.
1914 TORCH_CHECK(
1915 (output_device == input1_device) && (input1_device == input2_device),
1916 "vdot: Expected the output and input tensors to be on the "
1917 "same device, but got the output tensor on ", output_device,
1918 ", the 'input' tensor on ", input1_device, ", and the 'other' tensor on ", input2_device);
1919 at::native::resize_output(result, {});
1920 TORCH_CHECK(result.scalar_type() == self.scalar_type(),
1921 "result dtype ", result.scalar_type(), " does not match input dtype ", self.scalar_type());
1922 return result.fill_(self.vdot(other));
1923 }
1924
static bool should_fold(const Tensor& tensor1, const Tensor& tensor2, bool has_out) {
  // We check that we can fold the larger tensor into a matrix and dispatch to mm or mv rather than
  // to bmm. We want to make sure we can do so without incurring any extra copy
  const auto tensor1_larger = tensor1.dim() >= tensor2.dim();

  // We order the tensors. t1 will be the larger tensor
  // We can always transpose tensor2 as the dimensions are always >= 1 (precondition from matmul),
  // and we transpose it precisely when !tensor1_larger, i.e. when tensor2.dim() > tensor1.dim()
1933 const auto t1 = tensor1_larger ? MaybeOwned<Tensor>::borrowed(tensor1)
1934 : MaybeOwned<Tensor>::owned(tensor2.mT());
1935 const int64_t dim_t1 = t1->dim();
1936 const auto dim_t2 = tensor1_larger ? tensor2.dim()
1937 : tensor1.dim();
1938
1939 // Just fold for dim_t1 >= 3 and (dim_t2 == 1 || dim_t2 == 2)
1940 if (!(dim_t1 >= 3 && dim_t2 <= 2)) {
1941 return false;
1942 }
1943
  // In this case we *do* incur an extra copy to avoid creating an unnecessarily large tensor in the backward
  // Suppose we don't fold here. Let t1.shape = [b, m, n] and t2.shape = [n, k], as in a transformer
  // t2 will be expanded to a tensor of shape [b, n, k] and then we do t1.bmm(t2_expanded)
  // The issue appears in the backward.
  // The output gradient g of this operation would have shape [b, m, k]
  // The backward wrt. t2 of bmm would be given by t1.mH @ g, which has shape [b, n, k]
  // Then, the backward of expand is simply `sum(0)`. As such, we would be instantiating a tensor
  // of shape [b, n, k] unnecessarily, which may cause a large memory footprint and, in the
  // worst case, an OOM
1953 bool t2_requires_grad = tensor1_larger ? tensor2.requires_grad() : tensor1.requires_grad();
1954 if (t2_requires_grad && !has_out) {
1955 // We should be checking !at::GradMode::is_enabled(), but apparently
1956 // this regresses performance in some cases:
1957 // https://github.com/pytorch/pytorch/issues/118548#issuecomment-1916022394
1958 return true;
1959 }
1960
1961 // Don't fold in this case, as we would have to call mm on the transposed tensor, the result
1962 // would be contiguous, and then we would need to transpose it and call contiguous on it, thus
1963 // having to copy the tensor
1964 if (tensor1.dim() == 2) {
1965 return false;
1966 }
1967
1968 // Can always fold if the tensor is empty
1969 // This serves as a precondition for the code below
1970 if (t1->numel() == 0) {
1971 return true;
1972 }
1973
1974 // t1->view(-1, t1->size(-1)) does not copy only when the first n-1 dimensions are contiguous
1975 // in the sense that t1_stride[i] = t1_stride[i+1]*t1_shape[i+1]
1976 const auto t1_shape = t1->sizes();
1977 const auto t1_strides = t1->strides();
1978 for (auto i = int64_t{0}; i < dim_t1 - int64_t{2}; ++i) {
1979 if (t1_strides[i] != t1_strides[i+1] * t1_shape[i+1]) {
1980 return false;
1981 }
1982 }
1983 return true;
1984 }
1985
1986 /*
1987 Matrix product of two Tensors.
1988 The behavior depends on the dimensionality of the Tensors as follows:
1989 - If both Tensors are 1-dimensional, (1d) the dot product (scalar) is returned.
1990 - If the arguments are 2D - 1D or 1D - 2D, the matrix-vector product is returned.
1991 - If both arguments are 2D, the matrix-matrix product is returned.
1992 - If one of the arguments is ND with N >= 3 and the other is 1D or 2D, and some
1993 conditions on the strides apply (see should_fold) we fold the first N-1 dimensions
1994 of the ND argument to form a matrix, call mm or mv, reshape it back to ND and return it
1995 - Otherwise, we return bmm, after broadcasting and folding the batched dimensions if
1996 there's more than one
1997 */
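// Shape examples (illustrative):
//   (3,)      @ (3,)      -> ()           (dot)
//   (2, 3)    @ (3,)      -> (2,)         (mv)
//   (3,)      @ (3, 4)    -> (4,)         (unsqueeze + mm + squeeze)
//   (2, 3)    @ (3, 4)    -> (2, 4)       (mm)
//   (5, 2, 3) @ (3,)      -> (5, 2)       (folded into a single mv when possible)
//   (5, 2, 3) @ (3, 4)    -> (5, 2, 4)    (folded into a single mm when possible)
//   (5, 2, 3) @ (1, 3, 4) -> (5, 2, 4)    (batch dims broadcast, handled by bmm)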
static Tensor _matmul_impl(
    Tensor& out,
    const Tensor& tensor1,
    const Tensor& tensor2) {
2002 NoNamesGuard guard;
2003 const auto dim_tensor1 = tensor1.dim();
2004 const auto dim_tensor2 = tensor2.dim();
2005
2006 // This is checked up here to simplify the logic below
2007 // Note that the strings are just evaluated on failure, so almost always we just evaluate
2008 // the condition and move on
2009 TORCH_CHECK(dim_tensor1 != 0 && dim_tensor2 != 0,
2010 "both arguments to matmul need to be at least 1D, but they are ",
2011 dim_tensor1, "D and ", dim_tensor2, "D");
2012
2013
2014 const bool has_out = out.defined();
2015
2016 if (has_out) {
2017 // Usually we would rely on the out= kernels we decompose into to check this, but
2018 // for matmul there is logic at the composite level that relies on this invariant.
2019 TORCH_CHECK(!(tensor1.requires_grad() || tensor2.requires_grad() || out.requires_grad()) || !at::GradMode::is_enabled(),
2020 "matmul(): functions with out=... arguments don't support automatic differentiation, "
2021 "but one of the arguments requires grad."
2022 );
2023 }
2024
2025 if (dim_tensor1 == 1 && dim_tensor2 == 1) {
2026 return has_out ? at::dot_out(out, tensor1, tensor2) : tensor1.dot(tensor2);
2027 } else if (dim_tensor1 == 2 && dim_tensor2 == 1) {
2028 return has_out ? at::mv_out(out, tensor1, tensor2) : tensor1.mv(tensor2);
2029 } else if (dim_tensor1 == 1 && dim_tensor2 == 2) {
2030 return has_out ? at::mm_out(out, tensor1.unsqueeze(0), tensor2).squeeze_(0)
2031 : tensor1.unsqueeze(0).mm(tensor2).squeeze_(0);
2032 } else if (dim_tensor1 == 2 && dim_tensor2 == 2) {
2033 return has_out ? at::mm_out(out, tensor1, tensor2) : tensor1.mm(tensor2);
2034 } else if (should_fold(tensor1, tensor2, has_out)) {
2035 // dim_tensor1 >=3 && (dim_tensor2 == 1 || dim_tensor2 == 2) ||
2036 // dim_tensor2 >=3 && (dim_tensor1 == 1 || dim_tensor1 == 2)
2037 // and at least one of the following two conditions hold
2038 // - the small tensor requires grad (see should_fold for the why)
2039 // - we can fold the larger tensor t1 into a matrix as t1.view(-1, t1.size(-1)) without copying
2040
2041 // optimization: use mm instead of bmm by folding the batch of the larger tensor
2042 // into its leading matrix dimension
2043 const auto transpose = dim_tensor2 > dim_tensor1;
2044 const auto t1 = transpose ? MaybeOwned<Tensor>::owned(tensor2.mT())
2045 : MaybeOwned<Tensor>::borrowed(tensor1);
2046 const auto t2 = !transpose ? MaybeOwned<Tensor>::borrowed(tensor2)
2047 : dim_tensor1 == 2
2048 ? MaybeOwned<Tensor>::owned(tensor1.t())
2049 : MaybeOwned<Tensor>::borrowed(tensor1);
2050 // Invariant: t1->dim() >= 3 && (t2->dim() == 1 || t2->dim() == 2)
2051 // and *t1 and *t2 are matmul-compatible
2052
2053 // Why not t1->view(-1, sizes_1.back())?
2054 // If the last dim is 0, then view(-1, 0) won't work because the -1 becomes ambiguous.
2055 // This can happen in e.g. [3, 5, 0] @ [0, 0].
2056 const auto sizes_1 = t1->sizes();
2057 auto output_shape = DimVector(sizes_1.begin(), sizes_1.end() - 1);
2058 const auto folded_dim1 = c10::multiply_integers(output_shape);
2059
2060 // Readjust output_shape if we are multiplying by a matrix
2061 const auto t2_is_matrix = t2->dim() == 2;
2062 if (t2_is_matrix) {
2063 output_shape.push_back(t2->sizes()[1]);
2064 }
2065 // This will almost always be a view.
2066 // It may not be a view if t2->requires_grad(). See should_fold for an explanation
2067 const auto t1_folded = t1->reshape({folded_dim1, sizes_1.back()});
2068 if (!has_out) {
2069 if (t2_is_matrix) {
2070 const auto output = at::_unsafe_view(t1_folded.mm(*t2), output_shape);
2071 // This copies if we perform a 2D @ 3D and the first tensor requires_grad
2072 // See should_fold for why.
2073 // If mm_out were differentiable, we could use it here, and pass a result with the
2074 // correct strides to avoid this unnecessary copy.
2075 return transpose ? output.mT().contiguous() : output;
2076 } else {
2077 return at::_unsafe_view(t1_folded.mv(*t2), output_shape);
2078 }
2079 } else {
2080 // See the !has_out branch for an explanation
2081 TORCH_INTERNAL_ASSERT(!(transpose && t2_is_matrix));
2082
2083 // Resize output into the correct shape
2084 at::native::resize_output(out, output_shape);
2085
2086 // We then reshape the output to the expected shape and call mm/mv
2087 // and transpose back if necessary
2088 auto reshaped_out = t2_is_matrix ? out.reshape({folded_dim1, t2->sizes().back()})
2089 : out.reshape({folded_dim1});
2090 if (t2_is_matrix) {
2091 at::mm_out(reshaped_out, t1_folded, *t2);
2092 } else {
2093 at::mv_out(reshaped_out, t1_folded, *t2);
2094 }
2095 if (!reshaped_out.is_alias_of(out)) {
2096 out.copy_(reshaped_out);
2097 }
2098 return out;
2099 }
2100 } else {
2101 // dim_tensor1 >= 3 || dim_tensor2 >= 3
2102 // We track m1 vs m2 separately even though they must match for nicer error messages
2103 const int64_t n = dim_tensor1 > 1 ? tensor1.sizes().cend()[-2] : 1LL;
2104 const int64_t m1 = tensor1.sizes().back();
2105 auto batch_tensor1 = tensor1.sizes().slice(0, std::max<int64_t>(dim_tensor1 - 2, 0LL));
2106 const int64_t m2 = dim_tensor2 > 1 ? tensor2.sizes().cend()[-2] : tensor2.sizes().front();
2107 const int64_t p = dim_tensor2 > 1 ? tensor2.sizes().back() : 1LL;
2108 const IntArrayRef batch_tensor2(tensor2.sizes().data(),
2109 std::max<int64_t>(dim_tensor2 - 2, 0LL));
2110
2111 // Same optimization for the gradients as that in should_fold
2112 // If we're going to broadcast we force it to go through the should_fold branch
2113 if (dim_tensor1 == 3 && dim_tensor2 == 3 && batch_tensor1[0] != batch_tensor2[0]) {
2114 if (batch_tensor1[0] == 1 && (tensor1.requires_grad() || isTensorSubclassLike(tensor1))) {
2115 return _matmul_impl(out, tensor1.squeeze(0), tensor2);
2116 }
2117 if (batch_tensor2[0] == 1 && (tensor2.requires_grad() || isTensorSubclassLike(tensor2))) {
2118 return _matmul_impl(out, tensor1, tensor2.squeeze(0));
2119 }
2120 }
2121
2122 auto output_shape = infer_size_dimvector(batch_tensor1, batch_tensor2);
2123 const int64_t expand_batch_product = c10::multiply_integers(output_shape);
2124
2125 // flatten expanded batches
2126 const auto tensor1_expand_size = [&output_shape, n, m1]{ DimVector ret(output_shape);
2127 ret.append({n, m1});
2128 return ret; }();
2129 const auto tensor1_expanded = tensor1.expand(tensor1_expand_size)
2130 .reshape({expand_batch_product, n, m1});
2131 // We need to treat the dim_tensor2 == 1 case separately as broadcasting would not convert
2132 // a vector of shape (n,) into a batch of matrices of shape (*, n, 1)
2133 auto vector_rhs = dim_tensor2 == 1;
2134 const auto tensor2_expand_size = [&output_shape, m2, p, vector_rhs]{
2135 DimVector ret(output_shape);
2136 if (vector_rhs) {
2137 ret.push_back(m2);
2138 } else {
2139 ret.append({m2, p});
2140 }
2141 return ret;
2142 }();
2143 auto tensor2_expanded = tensor2.expand(tensor2_expand_size);
2144 if (vector_rhs) {
2145 tensor2_expanded = tensor2_expanded.reshape({expand_batch_product, m2}).unsqueeze(2);
2146 } else {
2147 tensor2_expanded = tensor2_expanded.reshape({expand_batch_product, m2, p});
2148 }
2149
2150 if (dim_tensor1 > 1) {
2151 output_shape.push_back(n);
2152 }
2153 if (dim_tensor2 > 1) {
2154 output_shape.push_back(p);
2155 }
2156
2157 if (!has_out) {
2158 if (vector_rhs) {
2159 return at::_unsafe_view(tensor1_expanded.bmm(tensor2_expanded).squeeze(-1), output_shape);
2160 } else {
2161 return at::_unsafe_view(tensor1_expanded.bmm(tensor2_expanded), output_shape);
2162 }
2163 } else {
2164 at::native::resize_output(out, output_shape);
2165 auto reshaped_out = out.reshape({expand_batch_product, n, p});
2166 at::bmm_out(reshaped_out, tensor1_expanded, tensor2_expanded);
2167 if (vector_rhs) {
2168 reshaped_out = reshaped_out.squeeze(-1);
2169 }
2170 if (!reshaped_out.is_alias_of(out)) {
2171 out.copy_(reshaped_out.view_as(out));
2172 }
2173 return out;
2174 }
2175 }
2176 }
2177
Tensor matmul(const Tensor & tensor1, const Tensor & tensor2) {
2179 auto maybe_outnames = namedinference::compute_matmul_outnames(tensor1, tensor2);
2180 at::Tensor result, unused;
2181 result = at::native::_matmul_impl(unused, tensor1, tensor2);
2182 namedinference::propagate_names_if_nonempty(result, maybe_outnames);
2183 return result;
2184 }
2185
Tensor& matmul_out(const Tensor & tensor1, const Tensor & tensor2, Tensor &result) {
2187 auto maybe_outnames = namedinference::compute_matmul_outnames(tensor1, tensor2);
2188 at::native::_matmul_impl(result, tensor1, tensor2);
2189 namedinference::propagate_names_if_nonempty(result, maybe_outnames);
2190 return result;
2191 }
2192
2193 // torch.linalg.matmul, alias for torch.matmul
Tensor linalg_matmul(const Tensor & tensor1, const Tensor & tensor2) {
2195 return at::matmul(tensor1, tensor2);
2196 }
2197
Tensor& linalg_matmul_out(const Tensor & tensor1, const Tensor & tensor2, Tensor &result) {
2199 return at::matmul_out(result, tensor1, tensor2);
2200 }
2201
2202 // torch.linalg.diagonal, alias for torch.diagonal with dim1=-2, dim2=-1 as defaults
Tensor linalg_diagonal(const Tensor& A, int64_t offset, int64_t dim1, int64_t dim2) {
2204 return A.diagonal(offset, dim1, dim2);
2205 }
2206
2207 // helper methods for matrix_exp
2208 namespace {
2209
2210 template <typename scalar_t, int ROW, int COL>
2211 using array2d = std::array<std::array<scalar_t, COL>, ROW>;
2212
2213 // we consider 6 Taylor expansions of degree
2214 // 1, 2, 4, 8, 12, 18
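// The degree used for each input matrix is chosen from its operator 1-norm via the
// theta thresholds in mexp_impl below; matrices whose norm exceeds the largest
// threshold use the degree-18 approximant combined with scaling and squaring.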
2215 constexpr int total_n_degs = 6;
2216
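// Induced (operator) 1-norm of each matrix in the batch: the maximum over columns
// of the sum of absolute values taken along the rows.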
Tensor operator_1_norm(const Tensor& tensor) {
2218 return std::get<0>(tensor.abs().sum(-2).max(-1));
2219 }
2220
// Allocates a buffer of uninitialized or zero values
// of shape [n_copies, a.size()]
Tensor _allocate_buffer(const Tensor& a, int n_copies, bool is_zero = false) {
2224 auto res = at::empty(
2225 {n_copies, a.size(0), a.size(1), a.size(2)},
2226 a.options().memory_format(at::MemoryFormat::Contiguous)
2227 );
2228
2229 if (is_zero) {
2230 res.zero_();
2231 }
2232
2233 return res;
2234 }
2235
// Fills `buffer` with the `num_matrices` matrices needed to
// compute the matrix exponentials of different orders, i.e. the
// first `num_matrices` matrices from the list l := {I, A, A^2, A^3, A^6},
// stored in a contiguous block of memory such that
// buffer[0, ...] = l[0], // I
// buffer[1, ...] = l[1], // A
// ...
// buffer[num_matrices - 1, ...] = l[num_matrices - 1]
void _fill_matrix_powers(Tensor& buffer, const Tensor& a, int num_matrices) {
2245 auto a_sizes_minus_last = a.sizes().vec();
2246 a_sizes_minus_last.pop_back();
2247 // fill I
2248 buffer.select(0, 0).copy_(
2249 at::diag_embed(
2250 at::ones({1}, buffer.options())
2251 .expand(a_sizes_minus_last)
2252 )
2253 );
2254
2255 // fill a
2256 buffer.select(0, 1).copy_(a);
2257
2258 // fill a^2
2259 if (2 <= num_matrices - 1) {
2260 // out for a^2
2261 auto view_out = buffer.select(0, 2);
2262 _matmul_impl(
2263 view_out,
2264 buffer.select(0, 1),
2265 buffer.select(0, 1)
2266 );
2267 }
2268
2269 // fill a^3
2270 if (3 <= num_matrices - 1) {
2271 // out for a^3
2272 auto view_out = buffer.select(0, 3);
2273 _matmul_impl(
2274 view_out,
2275 buffer.select(0, 1),
2276 buffer.select(0, 2)
2277 );
2278 }
2279
2280 // fill a^6
2281 if (4 <= num_matrices - 1) {
2282 // out for a^6
2283 auto view_out = buffer.select(0, 4);
2284 _matmul_impl(
2285 view_out,
2286 buffer.select(0, 3),
2287 buffer.select(0, 3)
2288 );
2289 }
2290 }
2291
inline Tensor _move_memory_if_cuda_input(
    const Tensor& mem,
    const Tensor& in
) {
2296 return (in.device().type() == at::kCUDA)
2297 ? mem.to(at::device_of(in).value())
2298 : mem;
2299 }
2300
// convert a 1D blob to a 2D Tensor of size [1, blob.size()]
// such that blob.device() == in.device()
// designed to be used with _compute_linear_combination
template <typename scalar_t>
inline Tensor _blob_to_Tensor(
    std::initializer_list<scalar_t> blob,
    const Tensor& in
) {
  // we convert to void* explicitly because begin() returns
  // a pointer to a constant.
  // Blob is assumed to be a 1D array, which is why
2312 // we also insert a fake dimension so that the result could directly
2313 // be used in _compute_linear_combination
2314 auto tensor = at::from_blob((void*)blob.begin(), blob.size(),
2315 c10::toRealValueType(in.scalar_type())).unsqueeze(0);
2316 return _move_memory_if_cuda_input(tensor, in);
2317 }
2318
2319 template <typename scalar_t>
inline Tensor _linear_combination(
    const Tensor& t,
    std::initializer_list<scalar_t> blob) {
2323 // _blob_to_Tensor converts blob to a 2D tensor for _compute_linear_combination.
2324 // If this tensor is of shape (1, *), the result of _compute_linear_combination
2325 // is going to be of shape (1, *t.shape) so we squeeze(0) so that
2326 // for any t with t.dim() >= 1: t.dim() == _compute_linear_combination(t, ...).dim().
2327 return at::native::_compute_linear_combination(
2328 t, _blob_to_Tensor<scalar_t>(blob, t))
2329 .squeeze(0);
2330 }
2331
2332 // I + A
Tensor compute_T1(const Tensor& A) {
2334 // 2 for {I, A}
2335 auto As = _allocate_buffer(A, 2);
2336 _fill_matrix_powers(As, A, 2);
2337 return As.sum(0);
2338 }
2339
2340 // I + A + A^2 / 2
Tensor compute_T2(const Tensor& A) {
2342 auto As = _allocate_buffer(A, 3);
2343 // 3 for {I, A, A^2}
2344 _fill_matrix_powers(As, A, 3);
2345 As.select(0, 2).div_(2.0);
2346 return As.sum(0);
2347 }
2348
2349 // I + A + A^2 * (I / 2 + A / 6 + A^2 / 24)
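// (equivalently, the degree-4 Taylor polynomial I + A + A^2/2 + A^3/6 + A^4/24,
//  evaluated with only two matrix products: one for A^2 and one for the bracket)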
2350 template <typename scalar_t>
Tensor compute_T4(const Tensor& A) {
2352 auto As = _allocate_buffer(A, 4);
2353 // 3 for {I, A, A^2}
2354 _fill_matrix_powers(As, A, 3);
2355
2356 // output for A^2 * (I / 2 + A / 6 + A^2 / 24)
2357 auto view_out = As.select(0, 3);
2358 _matmul_impl(
2359 view_out,
2360 // contains A^2
2361 As.select(0, 2),
2362 // computes (I / 2 + A / 6 + A^2 / 24)
2363 _linear_combination<scalar_t>(
2364 As.narrow(0, 0, 3),
2365 {1 / 2.0, 1 / 6.0, 1 / 24.0}
2366 )
2367 );
2368
2369 // I + A + A^2 * (I / 2 + A / 6 + A^2 / 24)
2370 return _linear_combination<scalar_t>(
2371 As, {1.0, 1.0, 0.0, 1.0}
2372 );
2373 }
2374
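// Degree-8 Taylor approximant evaluated with a reduced number of matrix products;
// the constants x1..x7 and y2 below implement an optimized evaluation scheme
// (cf. the reference cited at linalg_matrix_exp further down in this file).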
2375 template <typename scalar_t>
Tensor compute_T8(const Tensor& A) {
2377 constexpr scalar_t sqrt_177 = 0.1330413469565007072504e+2;
2378 constexpr scalar_t x3 = 2. / 3.;
2379 constexpr scalar_t x1 = x3 * ((1. + sqrt_177) / 88.);
2380 constexpr scalar_t x2 = x3 * ((1. + sqrt_177) / 352.);
2381 constexpr scalar_t x4 = (-271. + 29. * sqrt_177) / (315. * x3);
2382 constexpr scalar_t x5 = (-11. + 11. * sqrt_177) / (1260. * x3);
2383 constexpr scalar_t x6 = (-99. + 11. * sqrt_177) / (5040. * x3);
2384 constexpr scalar_t x7 = (89. - sqrt_177) / (5040. * x3);
2385 constexpr scalar_t y2 = (857. - 58. * sqrt_177) / 630.;
2386
2387 auto As = _allocate_buffer(A, 5);
2388 // 3 for {I, A, A^2}
2389 _fill_matrix_powers(As, A, 3);
2390
2391 // output for A4
2392 auto view_out = As.select(0, 3);
2393 // A4 = A2 * (x1 * A + x2 * A2)
2394 _matmul_impl(
2395 view_out,
2396 // As.select(0, 2) = A^2
2397 As.select(0, 2),
2398 _linear_combination<scalar_t>(
2399 // extract {A, A^2} from As
2400 As.narrow(0, 1, 2),
2401 {x1, x2}
2402 )
2403 );
2404
2405 // output for A8
2406 view_out = As.select(0, 4);
2407 // A8 = (x3 * A2 + A4) * (x4 * I + x5 * A + x6 * A2 + x7 * A4)
2408 _matmul_impl(
2409 view_out,
2410 // x3 * A2 + A4
2411 _linear_combination<scalar_t>(
2412 As.narrow(0, 2, 2),
2413 {x3, 1.0}
2414 ),
2415 _linear_combination<scalar_t>(
2416 As.narrow(0, 0, 4),
2417 {x4, x5, x6, x7}
2418 )
2419 );
2420
2421 // return I + A + y2 * A2 + A8;
2422 return _linear_combination<scalar_t>(
2423 As, {1.0, 1.0, y2, 0.0, 1.0}
2424 );
2425 }
2426
2427 template <typename scalar_t>
Tensor compute_T12(const Tensor& A) {
2429 constexpr int num_prods = 4;
2430 array2d<scalar_t, num_prods, num_prods> b = {{
2431 {
2432 9.0198e-16,
2433 0.46932117595418237389,
2434 -0.20099424927047284052,
2435 -0.04623946134063071740
2436 },
2437 {
2438 5.31597895759871264183,
2439 1.19926790417132231573,
2440 0.01179296240992997031,
2441 0.01108844528519167989
2442 },
2443 {
2444 0.18188869982170434744,
2445 0.05502798439925399070,
2446 0.09351590770535414968,
2447 0.00610700528898058230
2448 },
2449 {
2450 -2.0861320e-13,
2451 -0.13181061013830184015,
2452 -0.02027855540589259079,
2453 -0.00675951846863086359
2454 }
2455 }};
2456
2457 // gather coefficients `b` from above into a tensor,
2458 // and move them to device `device_of(A)`
2459 auto bs = at::from_blob(
2460 reinterpret_cast<void*>(&b),
2461 {num_prods, num_prods},
2462 {num_prods, 1},
2463 c10::toRealValueType(A.scalar_type())
2464 );
2465 bs = _move_memory_if_cuda_input(bs, A);
2466
2467 auto As = _allocate_buffer(A, num_prods);
2468 _fill_matrix_powers(As, A, num_prods);
2469
2470 auto Bs = at::native::_compute_linear_combination(As, bs);
2471
2472 // output for A6
2473 auto view_out = As.select(0, 0);
2474 // compute A6
2475 Bs.select(0, 2).add_(_matmul_impl(
2476 view_out,
2477 Bs.select(0, 3),
2478 Bs.select(0, 3)
2479 ));
2480
2481 return Bs.select(0, 0).add_(_matmul_impl(
2482 view_out,
2483 Bs.select(0, 1).add_(Bs.select(0, 2)),
2484 Bs.select(0, 2)
2485 ));
2486 }
2487
2488 template <typename scalar_t>
Tensor compute_T18(const Tensor& A) {
2490 constexpr int num_prods = 5;
2491 array2d<scalar_t, num_prods, num_prods> b = {{
2492 {
2493 0.,
2494 -1.00365581030144618291e-01,
2495 -8.02924648241156932449e-03,
2496 -8.92138498045729985177e-04,
2497 0.
2498 },
2499 {
2500 0.,
2501 3.97849749499645077844e-01,
2502 1.36783778460411720168e+00,
2503 4.98289622525382669416e-01,
2504 -6.37898194594723280150e-04
2505 },
2506 {
2507 -1.09676396052962061844e+01,
2508 1.68015813878906206114e+00,
2509 5.71779846478865511061e-02,
2510 -6.98210122488052056106e-03,
2511 3.34975017086070470649e-05
2512 },
2513 {
2514 -9.04316832390810593223e-02,
2515 -6.76404519071381882256e-02,
2516 6.75961301770459654925e-02,
2517 2.95552570429315521194e-02,
2518 -1.39180257516060693404e-05
2519 },
2520 {
2521 0.,
2522 0.,
2523 -9.23364619367118555360e-02,
2524 -1.69364939002081722752e-02,
2525 -1.40086798182036094347e-05
2526 }
2527 }};
2528
2529 // gather coefficients `b` from above into a tensor,
2530 // and move them to device `device_of(A)`
2531 auto bs = at::from_blob(
2532 reinterpret_cast<void*>(&b),
2533 {num_prods, num_prods},
2534 {num_prods, 1},
2535 c10::toRealValueType(A.scalar_type())
2536 );
2537 bs = _move_memory_if_cuda_input(bs, A);
2538
2539 auto As = _allocate_buffer(A, num_prods);
2540 _fill_matrix_powers(As, A, num_prods);
2541
2542 auto Bs = at::native::_compute_linear_combination(As, bs);
2543
2544 // tmp buffer for this matrix product
2545 auto view_out = As.select(0, 0);
2546 // compute A9
2547 Bs.select(0, 3).add_(_matmul_impl(
2548 view_out,
2549 Bs.select(0, 0),
2550 Bs.select(0, 4))
2551 );
2552
2553 return Bs.select(0, 1).add_(_matmul_impl(
2554 view_out,
2555 Bs.select(0, 2).add_(Bs.select(0, 3)),
2556 Bs.select(0, 3)
2557 ));
2558 }
2559
2560 template <typename scalar_t>
Tensor compute_T18_scale_square(
    const Tensor& a,
    const Tensor& norm,
    scalar_t theta
) {
  // Scale
  // We eventually need repeated squarings to recover the result.
  // For example, if `norm` is [27, 6, 6, 0.05] we end up with `s` equal to
  // [4, 1, 1, 0], so we could obtain the result by computing matrix[0]^(2^4),
  // matrix[1]^(2^1) and matrix[2]^(2^1) one by one, but such a "one by one"
  // calculation would be quite slow.
2572 const auto s = (at::ceil(at::log2(norm / theta))).clamp(/*min=*/0);
2573 const auto pow2s = at::pow(2, -s);
2574 const auto a_scaled = a * pow2s.view({-1, 1, 1});
2575 auto mexp_scaled = at::native::compute_T18<scalar_t>(a_scaled);
2576
  // Sort:
  // The inputs are square matrices, so if we first raise matrices 0, 1 and 2 to
  // a power, the only remaining work is to multiply `matrix 0` a further
  // (2^4 - 1) times, which gives us an opportunity to perform the matrix
  // multiplications in batches.
  // The first thing we need to do is sort the tensor `s`, which lets us do the
  // matrix multiplications range by range.
  // With the above example, `sorted_s` is [0, 1, 1, 4]; we also keep the index
  // info so we can compose the result back at the end.
2585 auto [sorted_s, sorted_s_inds] = at::sort(s, /*dim=*/0);
2586 sorted_s = sorted_s.to(at::kLong);
2587 // Then we call `unique_consecutive` and we will use it to split `sorted_s`,
  // with the above example, `split_counts` is [1, 2, 1].
2589 auto split_counts = std::get<2>(at::unique_consecutive(sorted_s, true, /*return_counts=*/true));
2590 // We also need to know the index of the last element of each split, so we can
2591 // know how many times we need to do the multiplication for each split matrix.
  // Notice that we do not need to calculate the actual powers, because we
  // use cumulative matrix multiplication.
  // With the above example, `mul_times` will be [0, 1, 3].
2595 auto split_edges = at::cumsum(split_counts, /*dim=*/0) - 1;
2596 auto unique_s = sorted_s.index_select(0, split_edges).clamp(/*min=*/0);
2597 auto mul_times = at::diff(unique_s, 1, -1, /*prepend=*/unique_s.new_zeros({1}));
2598
2599 // Square
2600 auto section_values = at::cat({split_counts, mul_times}, 0).to(at::kCPU);
2601
2602 TORCH_INTERNAL_ASSERT(section_values.is_contiguous());
2603 const auto section_numel = section_values.numel() / 2;
2604 auto scs = section_values. template data_ptr<int64_t>();
2605 auto pts = &scs[section_numel];
2606
  // We now do the matrix multiplications in batches. With the above example:
  // 1. Multiply all matrices 0 (`mul_times[0]`) times, then `slice` off the
  //    first `split_counts[0]` matrices, keeping acc[1:],
  // 2. Multiply the remaining matrices 1 more time and slice to acc[2:],
  // 3. Multiply the remaining matrices 3 more times and slice to acc[1:].
2612 // All processed matrices will be stored in `output_pieces`.
2613 std::vector<Tensor> output_pieces;
2614 auto acc = mexp_scaled.index_select(0, sorted_s_inds);
2615 for (int64_t i = 0; i < section_numel; ++i) {
2616 for (int64_t j = 0; j < pts[i]; j++) {
2617 // To avoid AMP autocasting caused by at::matmul
2618 auto acc_out = at::empty_like(acc);
2619 acc = at::matmul_out(acc_out, acc, acc);
2620 }
2621 output_pieces.push_back(acc.slice(0, 0, scs[i]));
2622 acc = acc.slice(0, scs[i]);
2623 }
2624
2625 // Compose the result back
2626 auto output = at::cat(output_pieces, 0);
2627 return output.index_select(0, at::argsort(sorted_s_inds));
2628 }
2629
2630 template <typename scalar_t>
Tensor mexp_impl(
    const Tensor& a,
    std::array<scalar_t, total_n_degs> thetas,
    bool compute_highest_degree_approx = false
) {
2636 const auto norm = operator_1_norm(a);
2637 const auto batch_size = a.size(0);
2638 if (batch_size > 1) {
2639 compute_highest_degree_approx = true;
2640 }
2641
2642 if (!compute_highest_degree_approx) {
    // To prevent undefined behavior that could produce a "normal" result for a matrix
    // containing NaN values, we pre-fill `res` with NaNs: if the input has NaN values,
    // its computation is skipped and the NaN-filled `res` is returned directly.
2646 auto res = at::full_like(a, std::numeric_limits<double>::quiet_NaN(), {},
2647 at::MemoryFormat::Contiguous);
2648 // `norm_cpu` is used to decide which Tensors require which approximation
2649 // based on their norm. This decision takes place on CPU.
2650 // It requires moving data back and forth between devices when `a` is on CUDA,
    // but at the cost of only a single CPU-CUDA synchronization (instead of 6),
2652 // and better performance overall (benchmarked).
2653 const auto norm_cpu = (a.device().type() == at::kCUDA)
2654 ? norm.to(at::kCPU) : norm;
2655
2656 constexpr std::array<
2657 Tensor(*)(const Tensor&),
2658 total_n_degs - 1>
2659 compute_Ts = {
2660 compute_T1, compute_T2, compute_T4<scalar_t>,
2661 compute_T8<scalar_t>, compute_T12<scalar_t>
2662 };
2663
2664 for (int i = 0; i < total_n_degs - 1; ++i) {
2665 auto norm_lower_bound = (i == 0) ? static_cast<scalar_t>(-1) : thetas[i - 1];
2666 auto norm_upper_bound = thetas[i];
2667 // nonzero returns a 2D tensor, hence squeeze(-1) to make it 1D
2668 auto idx_curr_norm_interval = (
2669 (norm_lower_bound < norm_cpu) * (norm_cpu <= norm_upper_bound)
2670 ).nonzero().squeeze(-1);
2671
2672 if (idx_curr_norm_interval.numel()) {
2673 auto idx_to_device = _move_memory_if_cuda_input(
2674 idx_curr_norm_interval, a
2675 );
2676 auto sub_a = at::index_select(a, 0, idx_to_device);
2677 res.index_put_({idx_to_device}, compute_Ts[i](sub_a));
2678 }
2679 }
2680
2681 // nonzero returns a 2D tensor, hence squeeze(-1) to make it 1D
2682 auto idx_large_norm = (norm_cpu >= thetas[total_n_degs - 2])
2683 .nonzero().squeeze(-1);
2684
2685 if (idx_large_norm.numel()) {
2686 auto idx_to_device = _move_memory_if_cuda_input(
2687 idx_large_norm, a
2688 );
2689 auto a_large_norm = at::index_select(a, 0, idx_to_device);
2690 auto large_norm_subset = at::index_select(norm, 0, idx_to_device);
2691 auto mexp_out = compute_T18_scale_square(
2692 a_large_norm,
2693 large_norm_subset,
2694 thetas[total_n_degs - 1]
2695 );
2696 res.index_put_({idx_large_norm}, mexp_out);
2697 }
2698 return res;
2699 }
2700
2701 return compute_T18_scale_square(
2702 a, norm,
2703 thetas[total_n_degs - 1]
2704 );
2705 }
2706
2707 // matrix exponential
Tensor mexp(const Tensor& a, bool compute_highest_degree_approx = false) {
2709 // squash batch dimensions to one dimension for simplicity
2710 const auto a_3d = a.view({-1, a.size(-2), a.size(-1)});
2711
2712 if (a.scalar_type() == at::ScalarType::Float
2713 || a.scalar_type() == at::ScalarType::ComplexFloat) {
2714 constexpr std::array<float, total_n_degs> thetas_float = {
2715 1.192092800768788e-07, // deg 1
2716 5.978858893805233e-04, // deg 2
2717 5.116619363445086e-02, // deg 4
2718 5.800524627688768e-01, // deg 8
2719 1.461661507209034e+00, // deg 12
2720 3.010066362817634e+00 // deg 18
2721 };
2722
2723 return mexp_impl<float>(a_3d, thetas_float, compute_highest_degree_approx)
2724 .view(a.sizes());
2725 }
2726 else { // if Double or ComplexDouble
2727 constexpr std::array<double, total_n_degs> thetas_double = {
2728 2.220446049250313e-16, // deg 1
2729 2.580956802971767e-08, // deg 2
2730 3.397168839976962e-04, // deg 4
2731 4.991228871115323e-02, // deg 8
2732 2.996158913811580e-01, // deg 12
2733 1.090863719290036e+00 // deg 18
2734 };
2735
2736 return mexp_impl<double>(a_3d, thetas_double, compute_highest_degree_approx)
2737 .view(a.sizes());
2738 }
2739 }
2740
2741 // TODO This should be deprecated in favor of linalg_matrix_exp_differential
2742 // in FunctionsManual.cpp
2743 template <typename func_t>
2744 Tensor backward_analytic_function_of_a_matrix(
2745 const Tensor& self, const Tensor& grad,
2746 const func_t& function_of_a_matrix
2747 ) {
2748 auto self_transposed = self.mH();
2749 auto self_transposed_sizes = self_transposed.sizes().vec();
2750 self_transposed_sizes[self.dim() - 2] <<= 1;
2751 self_transposed_sizes[self.dim() - 1] <<= 1;
2752
2753 auto n = self_transposed.size(-1);
2754 auto meta_grad = at::zeros(self_transposed_sizes, grad.options());
2755 meta_grad.narrow(-2, 0, n).narrow(-1, 0, n).copy_(self_transposed);
2756 meta_grad.narrow(-2, n, n).narrow(-1, n, n).copy_(self_transposed);
2757 meta_grad.narrow(-2, 0, n).narrow(-1, n, n).copy_(grad);
2758
2759 auto grad_input = function_of_a_matrix(meta_grad)
2760 .narrow(-2, 0, n).narrow(-1, n, n);
2761 return grad_input;
2762 }
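// A brief sketch of why the 2n x 2n block construction above works: for an analytic matrix
// function f, the standard Frechet-derivative identity (see Higham, "Functions of Matrices") is
//
//   f( [[X, E],      [[ f(X), Df(X)[E] ],
//       [0, X]] )  =  [ 0,    f(X)     ]]
//
// Here X = A^H (the backward pass needs the adjoint of the differential) and E = grad, so the
// upper-right n x n block read off above is exactly the gradient contribution.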
2763 } // end anon namespace
2764
2765 // Computes the matrix exponential for a given batch of square matrices.
2766 // The implementation is based on:
2767 //
2768 // Bader, P.; Blanes, S.; Casas, F.
2769 // Computing the Matrix Exponential with an Optimized Taylor Polynomial Approximation.
2770 // Mathematics 2019, 7, 1174.
2771 //
2772 Tensor linalg_matrix_exp(const Tensor& a) {
2773 squareCheckInputs(a, "linalg.matrix_exp");
2774 checkFloatingOrComplex(a, "linalg.matrix_exp");
2775
2776 NoTF32Guard disable_tf32;
2777
2778 // Trivial cases
2779 const auto n = a.size(-1);
2780 if (n == 0) {
2781 return a.clone();
2782 } else if (n == 1) {
2783 return a.exp();
2784 } else {
2785 return at::native::mexp(a);
2786 }
2787 }
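// Illustrative usage from the ATen C++ frontend (a minimal sketch, not part of the build; the
// shapes and dtype are made up):
//
//   at::Tensor A = at::randn({4, 3, 3}, at::kDouble);  // batch of 3x3 matrices
//   at::Tensor E = at::linalg_matrix_exp(A);           // same shape and dtype as A
//
// The n == 1 case above reduces to an elementwise exp().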
2788
2789 // Alias
2790 Tensor matrix_exp(const Tensor& a) {
2791 return at::linalg_matrix_exp(a);
2792 }
2793
2794 // TODO This should be deprecated in favor of linalg_matrix_exp_differential
2795 // in FunctionsManual.cpp
2796 Tensor matrix_exp_backward(const Tensor& self, const Tensor& grad) {
2797 NoTF32Guard disable_tf32;
2798 return backward_analytic_function_of_a_matrix(
2799 self, grad,
2800 [](const Tensor& a) {
2801 return a.matrix_exp();
2802 }
2803 );
2804 }
2805
2806 TORCH_IMPL_FUNC(linalg_vector_norm_out)(const Tensor& self, const Scalar& scalar_ord, OptionalIntArrayRef opt_dim, bool keepdim, std::optional<ScalarType> opt_dtype, const Tensor& result) {
2807 // Casting a large integer to a double will just introduce an error for
2808 // values larger than 2^53 (same for negative numbers), so that's fine.
2809 auto ord = scalar_ord.toDouble();
2810 auto dim = opt_dim.value_or(IntArrayRef{});
2811 auto size = self.sizes();
2812 auto ndim = self.dim();
2813
2814 auto opt_dim_ = dim.vec();
2815 maybe_wrap_dims(opt_dim_, ndim);
2816
2817 using Int = IntArrayRef::value_type;
2818 std::vector<Int> all_dim(ndim);
2819 std::iota(all_dim.begin(), all_dim.end(), 0);
2820
2821 bool is_all_reduce = !opt_dim.has_value() || opt_dim.value().empty();
2822 auto reduce_dim = is_all_reduce ? all_dim : opt_dim_;
2823
2824 bool is_reduce_over_1D_vector = true;
2825 for (auto i : reduce_dim) {
2826 if (size[i] != 1){
2827 is_reduce_over_1D_vector = false;
2828 break;
2829 }
2830 }
2831
2832 if (is_reduce_over_1D_vector) {
2833 Tensor self_;
2834 if (opt_dtype.has_value()) {
2835 self_ = self.to(*opt_dtype);
2836 } else {
2837 self_ = self;
2838 }
2839 if (ord != 0.0) {
2840 keepdim ? at::abs_outf(self_, const_cast<Tensor&>(result)) : at::abs_outf(self_.squeeze(reduce_dim), const_cast<Tensor&>(result));
2841 } else {
2842 keepdim ? at::ne_outf(self_, 0, const_cast<Tensor&>(result)) : at::ne_outf(self_.squeeze(reduce_dim), 0, const_cast<Tensor&>(result));
2843 }
2844 return;
2845 }
2846
2847 // No need to handle opt_dtype explicitly as it is already encoded in the dtype of result
2848
2849 // https://github.com/pytorch/pytorch/issues/52648
2850 // Reductions always use `std::abs` to compute the absolute value. In the backward of this
2851 // function, we need to locate the index that was selected as the largest value. To do so
2852 // we do self.abs() == result to locate the index of the largest element.
2853 // Now, self.abs() may dispatch to a vectorized implementation which gives slightly different
2854 // results to the std::abs(std::complex<T>) implementation.
2855 // As such, to be able to compute the correct index in the backward, we need to use self.abs()
2856 // both in the forward and in the backward
2857 Tensor self_;
2858 if (self.is_cpu() && self.is_complex() && std::abs(ord) == INFINITY) {
2859 if (opt_dtype.has_value()) {
2860 self_ = self.to(*opt_dtype).abs();
2861 } else {
2862 self_ = self.abs();
2863 }
2864 } else {
2865 self_ = self;
2866 }
2867
2868 auto iter = make_reduction("vector_norm", const_cast<Tensor&>(result), self_, dim, keepdim, result.scalar_type());
2869 norm_stub(iter.device_type(), iter, ord);
2870 }
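// For reference, the reduction dispatched above computes the usual vector p-norm over the
// selected dimensions (a sketch, with x ranging over the reduced elements):
//
//   finite ord = p != 0 :  (sum_i |x_i|^p)^(1/p)
//   ord = 0             :  number of nonzero elements
//   ord = +inf / -inf   :  max_i |x_i|  /  min_i |x_i|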
2871
2872 static void _linalg_matrix_norm_checks(const Tensor& A, std::vector<int64_t>& dim, std::optional<ScalarType> opt_dtype, bool low_precision) {
2873 // A
2874 at::native::checkIsMatrix(A, "linalg.matrix_norm");
2875 at::native::checkFloatingOrComplex(A, "linalg.matrix_norm", /*low_precision*/low_precision);
2876
2877 // dim
2878 TORCH_CHECK(dim.size() == 2, "linalg.matrix_norm: dim must be a 2-tuple. Got ", dim);
2879 // wrap first to identify weird scenarios like A.ndim = 2, dim = (1, -1)
2880 // dim is modified in place while wrapping it
2881 maybe_wrap_dims(dim, A.dim());
2882 TORCH_CHECK(dim[0] != dim[1], "linalg.matrix_norm: dims must be different. Got (", dim[0], ", ", dim[1], ")");
2883
2884 // dtype
2885 at::detail::check_linalg_norm_dtype(opt_dtype, A.scalar_type(), "linalg.matrix_norm");
2886 }
2887
2888 Tensor linalg_matrix_norm(
2889 const Tensor& A,
2890 const Scalar& scalar_ord,
2891 IntArrayRef dim,
2892 bool keepdim,
2893 std::optional<ScalarType> opt_dtype) {
2894 // Check ord first as it will be used in the dtype check of A
2895 auto ord = scalar_ord.toDouble();
2896 auto abs_ord = std::abs(ord);
2897 TORCH_CHECK(abs_ord == 2. || abs_ord == 1. || abs_ord == INFINITY, "linalg.matrix_norm: Order ", ord, " not supported.");
2898
2899 auto dim_ = dim.vec();
2900 // Check A, dim, and dtype
2901 _linalg_matrix_norm_checks(A, dim_, opt_dtype, /*low_precision*/abs_ord != 2.);
2902
2903 auto max_min = [ord, keepdim](const Tensor& A, int64_t dim) { return ord > 0 ? A.amax(dim, keepdim) : A.amin(dim, keepdim); };
2904 if (abs_ord == 2.) {
2905 // Move dims to the end
2906 auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], A.dim());
2907
2908 auto A_ = opt_dtype.has_value() ? A.to(*opt_dtype) : A;
2909 auto result = max_min(at::linalg_svdvals(A_.permute(permutation)), -1);
2910 if (keepdim) {
2911 auto permutation_reverse = create_reverse_permutation(std::move(permutation));
2912 result = result.unsqueeze(-1).permute(permutation_reverse);
2913 }
2914 return result;
2915 } else { // 1, -1, inf, -inf
2916 // The infty norm is like the 1 norm on the transposed matrix
2917 if (abs_ord == INFINITY) {
2918 std::swap(dim_[0], dim_[1]);
2919 }
2920
2921 // If the first reduction removes one dim from the front (dim_[0] < dim_[1]), after this
2922 // reduction dim_[1] will be off by one
2923 if (!keepdim && (dim_[0] < dim_[1])) {
2924 dim_[1]--;
2925 }
2926 return max_min(at::linalg_vector_norm(A, 1., {dim_[0]}, keepdim, opt_dtype), dim_[1]);
2927 }
2928 }
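// Summary of the orders handled above (a sketch; sigma denotes singular values):
//
//   ord = +/-2   : sigma_max(A) / sigma_min(A), via linalg_svdvals
//   ord = +/-1   : max / min absolute column sum (vector 1-norm over dim_[0], then amax/amin)
//   ord = +/-inf : max / min absolute row sum (same computation with the two dims swapped)
//
// e.g. for a 2D A, at::linalg_matrix_norm(A, 1) matches A.abs().sum(-2).amax(-1).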
2929
2930 Tensor& linalg_matrix_norm_out(
2931 const Tensor& A,
2932 const Scalar& ord,
2933 IntArrayRef dim,
2934 bool keepdim,
2935 std::optional<ScalarType> opt_dtype,
2936 Tensor& result) {
2937 checkSameDevice("linalg.matrix_norm", A, result);
2938 auto out = at::linalg_matrix_norm(A, ord, dim, keepdim, opt_dtype);
2939 TORCH_CHECK(out.scalar_type() == result.scalar_type(),
2940 "linalg.matrix_norm expected out tensor dtype ", out.scalar_type(),
2941 " but got: ", result.scalar_type());
2942 at::native::resize_output(result, out.sizes());
2943 result.copy_(out);
2944 return result;
2945 }
2946
2947 // fro / nuc
2948 Tensor linalg_matrix_norm(
2949 const Tensor& A,
2950 c10::string_view ord,
2951 IntArrayRef dim,
2952 bool keepdim,
2953 std::optional<ScalarType> opt_dtype) {
2954 // Check ord first as it will be used in the dtype check of A
2955 TORCH_CHECK(ord == "fro" || ord == "nuc", "linalg.matrix_norm: Order ", ord, " not supported.");
2956
2957 auto dim_ = dim.vec();
2958 // Check A, dim, and dtype
2959 _linalg_matrix_norm_checks(A, dim_, opt_dtype, /*low_precision*/ord != "nuc");
2960
2961 if (ord == "fro") {
2962 return at::linalg_vector_norm(A, 2, dim_, keepdim, opt_dtype);
2963 } else { // nuc
2964 auto A_ = opt_dtype.has_value() ? A.to(*opt_dtype) : A;
2965
2966 // Move dims to the end
2967 auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], A_.dim());
2968 auto result = at::linalg_svdvals(A_.permute(permutation)).sum(-1, keepdim);
2969 if (keepdim) {
2970 auto permutation_reverse = create_reverse_permutation(std::move(permutation));
2971 result = result.unsqueeze(-1).permute(permutation_reverse);
2972 }
2973 return result;
2974 }
2975 }
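// For reference (a sketch; sigma_i are the singular values over the two reduced dimensions):
//
//   'fro' : sqrt(sum_ij |A_ij|^2)  -- identical to the vector 2-norm over both dims, as above
//   'nuc' : sum_i sigma_i          -- sum of singular values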
2976
2977 Tensor& linalg_matrix_norm_out(
2978 const Tensor& A,
2979 c10::string_view ord,
2980 IntArrayRef dim,
2981 bool keepdim,
2982 std::optional<ScalarType> opt_dtype,
2983 Tensor& result) {
2984 checkSameDevice("linalg.matrix_norm", A, result);
2985 auto out = at::linalg_matrix_norm(A, ord, dim, keepdim, opt_dtype);
2986 TORCH_CHECK(out.scalar_type() == result.scalar_type(),
2987 "linalg.matrix_norm expected out tensor dtype ", out.scalar_type(),
2988 " but got: ", result.scalar_type());
2989 at::native::resize_output(result, out.sizes());
2990 result.copy_(out);
2991 return result;
2992 }
2993
2994 // Numerical or None norms
2995 Tensor linalg_norm(const Tensor& X, const std::optional<Scalar>& opt_ord, OptionalIntArrayRef opt_dim, bool keepdim, std::optional<ScalarType> opt_dtype) {
2996 if (opt_dim.has_value()) {
2997 TORCH_CHECK(opt_dim->size() == 1 || opt_dim->size() == 2, "linalg.norm: If ",
2998 "dim is specified, it must be of length 1 or 2. Got ", *opt_dim);
2999 } else {
3000 if (opt_ord.has_value()) {
3001 TORCH_CHECK(X.dim() == 1 || X.dim() == 2, "linalg.norm: If ",
3002 "dim is not specified but ord is, the input must be 1D or 2D. Got ", X.dim(), "D.");
3003 }
3004 }
3005
3006 // If ord=None, we'll always use the 2-norm or frob norm (which are the same) so we go through
3007 // vector_norm
3008 if (opt_ord.has_value() &&
3009 ((opt_dim.has_value() && opt_dim->size() == 2) ||
3010 (!opt_dim.has_value() && X.dim() == 2))) {
3011 using Int = IntArrayRef::value_type;
3012 auto dim = opt_dim.has_value() ? opt_dim.value().vec() : std::vector<Int>{0, 1};
3013 return at::linalg_matrix_norm(X, *opt_ord, dim, keepdim, opt_dtype);
3014 } else {
3015 auto scalar_ord = opt_ord.value_or(Scalar(2.));
3016 return at::linalg_vector_norm(X, scalar_ord, opt_dim, keepdim, opt_dtype);
3017 }
3018 }
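// Dispatch summary for linalg_norm above (a sketch):
//
//   ord given and (dim has length 2, or dim is absent on a 2D input)  -> linalg_matrix_norm
//   every other combination (including ord=None)                      -> linalg_vector_norm,
//                                                                         defaulting to ord=2
//
// e.g. a call with ord=3 and a single dim is simply a vector 3-norm over that dim.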
3019
3020 Tensor& linalg_norm_out(const Tensor& X, const std::optional<Scalar>& opt_ord, OptionalIntArrayRef opt_dim, bool keepdim, std::optional<ScalarType> opt_dtype, Tensor& result) {
3021 checkSameDevice("linalg.norm", X, result);
3022 auto out = at::linalg_norm(X, opt_ord, opt_dim, keepdim, opt_dtype);
3023 TORCH_CHECK(out.scalar_type() == result.scalar_type(),
3024 "linalg.norm expected out tensor dtype ", out.scalar_type(),
3025 " but got: ", result.scalar_type());
3026 at::native::resize_output(result, out.sizes());
3027 result.copy_(out);
3028 return result;
3029 }
3030
3031 // Frobenius and nuclear norms
3032 Tensor linalg_norm(const Tensor& X, c10::string_view ord, OptionalIntArrayRef opt_dim, bool keepdim, std::optional<ScalarType> opt_dtype) {
3033 if (opt_dim.has_value()) {
3034 TORCH_CHECK(opt_dim->size() == 1 || opt_dim->size() == 2, "linalg.norm: If ",
3035 "dim is specified, it must be of length 1 or 2. Got ", *opt_dim);
3036 } else {
3037 TORCH_CHECK(X.dim() == 1 || X.dim() == 2, "linalg.norm: If ",
3038 "dim is not specified but ord is, the input must be 1D or 2D. Got ", X.dim(), "D.");
3039 }
3040 using Int = IntArrayRef::value_type;
3041 auto dim = opt_dim.has_value() ? opt_dim.value().vec() : std::vector<Int>{0, 1};
3042 return at::linalg_matrix_norm(X, ord, dim, keepdim, opt_dtype);
3043 }
3044
3045 Tensor& linalg_norm_out(const Tensor& X, c10::string_view ord, OptionalIntArrayRef opt_dim, bool keepdim, std::optional<ScalarType> opt_dtype, Tensor& result) {
3046 checkSameDevice("linalg.norm", X, result);
3047 auto out = at::linalg_norm(X, ord, opt_dim, keepdim, opt_dtype);
3048 TORCH_CHECK(out.scalar_type() == result.scalar_type(),
3049 "linalg.norm expected out tensor dtype ", out.scalar_type(),
3050 " but got: ", result.scalar_type());
3051 at::native::resize_output(result, out.sizes());
3052 result.copy_(out);
3053 return result;
3054 }
3055
3056 ////////////////////////////////////////////////////////////////////////////////
3057 // Frobenius Norm //
3058 ////////////////////////////////////////////////////////////////////////////////
3059
3060 Tensor frobenius_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
3061 auto device = self.device();
3062 if (self.layout() == Layout::Strided && (device == kCPU || device == kCUDA || device == kMeta)) {
3063 TORCH_WARN_ONCE(
3064 "at::frobenius_norm is deprecated and it is just left for JIT compatibility. ",
3065 "It will be removed in a future PyTorch release. Please use ",
3066 "`linalg.vector_norm(A, 2., dim, keepdim)` instead"
3067 );
3068 }
3069 // This "Frobenius norm" also accepts dim of length 0 or 1, where it is not a true Frobenius norm; the behavior is kept for backward compatibility.
3070 TORCH_CHECK(dim.size() <= 2,
3071 "Expected at most 2 dimensions, but got ", dim.size(), " dimensions instead.");
3072 // Dispatch to at::norm as it is implemented for Sparse and MPS backends
3073 // TODO Make the backends implement vector_norm and matrix_norm
3074 return at::norm(self, 2., dim, keepdim);
3075 }
3076
3077 Tensor &frobenius_norm_out(const Tensor& self,
3078 IntArrayRef dim,
3079 bool keepdim,
3080 Tensor& result) {
3081 auto device = self.device();
3082 if (self.layout() == Layout::Strided && (device == kCPU || device == kCUDA || device == kMeta)) {
3083 TORCH_WARN_ONCE(
3084 "at::frobenius_norm is deprecated and it is just left for JIT compatibility. ",
3085 "It will be removed in a future PyTorch release. Please use ",
3086 "`linalg.vector_norm(A, 2., dim, keepdim)` instead"
3087 );
3088 }
3089 TORCH_CHECK(dim.size() <= 2,
3090 "Expected at most 2 dimensions, but got ", dim.size(), " dimensions instead.");
3091 return at::norm_out(result, self, 2., dim, keepdim);
3092 }
3093
3094 ////////////////////////////////////////////////////////////////////////////////
3095 // Nuclear Norm //
3096 ////////////////////////////////////////////////////////////////////////////////
3097
3098 Tensor nuclear_norm(const Tensor& self, bool keepdim) {
3099 return at::native::nuclear_norm(self, IntArrayRef({-2, -1}), keepdim);
3100 }
3101
3102 Tensor &nuclear_norm_out(const Tensor& self, bool keepdim, Tensor& result) {
3103 auto device = self.device();
3104 if (self.layout() == Layout::Strided && (device == kCPU || device == kCUDA || device == kMeta)) {
3105 TORCH_WARN_ONCE(
3106 "at::nuclear_norm is deprecated and it is just left for JIT compatibility. ",
3107 "It will be removed in a future PyTorch release. Please use ",
3108 "`linalg.matrix_norm(A, 'nuc', dim, keepdim)` instead"
3109 );
3110 }
3111 return at::linalg_matrix_norm_out(result, self, "nuc", IntArrayRef({-2, -1}), keepdim);
3112 }
3113
3114 Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
3115 auto device = self.device();
3116 if (self.layout() == Layout::Strided && (device == kCPU || device == kCUDA || device == kMeta)) {
3117 TORCH_WARN_ONCE(
3118 "at::nuclear_norm is deprecated and it is just left for JIT compatibility. ",
3119 "It will be removed in a future PyTorch release. Please use ",
3120 "`linalg.matrix_norm(A, 'nuc', dim, keepdim)` instead"
3121 );
3122 }
3123 return at::linalg_matrix_norm(self, "nuc", dim, keepdim);
3124 }
3125
3126 Tensor& nuclear_norm_out(const Tensor& self, IntArrayRef dim, bool keepdim, Tensor& result) {
3127 auto device = self.device();
3128 if (self.layout() == Layout::Strided && (device == kCPU || device == kCUDA || device == kMeta)) {
3129 TORCH_WARN_ONCE(
3130 "at::nuclear_norm is deprecated and it is just left for JIT compatibility. ",
3131 "It will be removed in a future PyTorch release. Please use ",
3132 "`linalg.matrix_norm(A, 'nuc', dim, keepdim)` instead"
3133 );
3134 }
3135 return at::linalg_matrix_norm_out(result, self, "nuc", dim, keepdim);
3136 }
3137
3138 ////////////////////////////////////////////////////////////////////////////////
3139 // linalg.cond //
3140 ////////////////////////////////////////////////////////////////////////////////
3141
3142
3143 // This function helps to dispatch norm computations depending on 'ord' of variant type
3144 static Tensor _linalg_cond_helper(const Tensor& self, std::variant<Scalar, c10::string_view> ord_variant) {
3145 Tensor inverse, info;
3146 std::tie(inverse, info) = at::linalg_inv_ex(self);
3147 info.unsqueeze_(-1).unsqueeze_(-1);
3148 inverse.masked_fill_(info > 0, INFINITY);
3149
3150 return std::visit([&](auto&& ord) {
3151 Tensor norm_self = at::linalg_matrix_norm(self, ord);
3152 Tensor norm_inverse = at::linalg_matrix_norm(inverse, ord);
3153 Tensor result = norm_self * norm_inverse;
3154 // fix multiplication of zero and infinity for NumPy compatibility
3155 result.nan_to_num_(INFINITY, INFINITY, -INFINITY);
3156 return result;
3157 }, ord_variant);
3158 }
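// The helper above implements the textbook condition number
//
//   cond_p(A) = ||A||_p * ||A^{-1}||_p,
//
// with singular inputs (info > 0 from linalg_inv_ex) having their "inverse" filled with inf
// first, so that the nan_to_num_ fixup reproduces NumPy's inf convention.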
3159
3160 // Return zero for each matrix in the batch
3161 static Tensor _linalg_cond_empty_matrix(const Tensor& self, c10::ScalarType dtype) {
3162 auto result_shape = IntArrayRef(self.sizes().cbegin(), self.sizes().cend()-2);
3163 TensorOptions options = self.options().dtype(toRealValueType(self.scalar_type()));
3164 return at::zeros(result_shape, options);
3165 }
3166
3167 static void _linalg_cond_check_ord(std::variant<Scalar, c10::string_view> ord_variant) {
3168 if (ord_variant.index() == 0) {
3169 Scalar* ord = std::get_if<Scalar>(&ord_variant);
3170 double abs_ord = std::abs(ord->toDouble());
3171 TORCH_CHECK(abs_ord == 2.0 || abs_ord == 1.0 || abs_ord == INFINITY,
3172 "linalg.cond got an invalid norm type: ", ord->toDouble());
3173 } else if (ord_variant.index() == 1) {
3174 c10::string_view* ord = std::get_if<c10::string_view>(&ord_variant);
3175 TORCH_CHECK(*ord == "fro" || *ord == "nuc",
3176 "linalg.cond got an invalid norm type: ", *ord);
3177 } else {
3178 TORCH_CHECK(false,
3179 "linalg.cond: something went wrong while checking the norm type");
3180 }
3181 }
3182
3183 // Numerical or None norms
3184 Tensor linalg_cond(const Tensor& self, const std::optional<Scalar>& opt_ord) {
3185 TORCH_CHECK(self.dim() >= 2, "linalg.cond: The input tensor must have at least 2 dimensions.");
3186
3187 // The default case is using 2-norm
3188 Scalar ord = opt_ord.has_value() ? opt_ord.value() : 2;
3189
3190 std::variant<Scalar, c10::string_view> ord_variant = ord;
3191 _linalg_cond_check_ord(ord_variant);
3192
3193 // NumPy doesn't define the condition number for 0x0 matrices; we return 0.0 for such inputs
3194 if (self.sym_numel() == 0) {
3195 auto real_dtype = toRealValueType(typeMetaToScalarType(self.dtype()));
3196 return _linalg_cond_empty_matrix(self, real_dtype);
3197 }
3198
3199 // If ord == None or ord == ±2
3200 if (std::abs(ord.toDouble()) == 2.0) {
3201 auto singular_values = at::linalg_svdvals(self);
3202 // singular values are sorted in descending order
3203 auto s_max = at::narrow(singular_values, /*dim=*/-1, /*start=*/0, /*length=*/1);
3204 auto s_min = at::narrow(singular_values, /*dim=*/-1, /*start=*/-1, /*length=*/1);
3205 Tensor result;
3206 if (ord.toDouble() == -2.0) {
3207 result = s_min / s_max;
3208 } else {
3209 result = s_max / s_min;
3210 }
3211 // squeeze the result for NumPy compatibility
3212 return result.squeeze(-1);
3213 }
3214
3215 // ord == ±1 or ord == ±inf; only the error-message formatting differs between the branches
3216 if (ord.isFloatingPoint()) {
3217 squareCheckInputs(self, ("linalg.cond(ord=" + std::to_string(ord.to<double>()) + ")").c_str());
3218 } else {
3219 squareCheckInputs(self, ("linalg.cond(ord=" + std::to_string(ord.to<int64_t>()) + ")").c_str());
3220 }
3221 return _linalg_cond_helper(self, std::move(ord_variant));
3222 }
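// Illustrative usage (a minimal sketch, not part of the build; values made up):
//
//   at::Tensor A = at::eye(3, at::kDouble) * 2.;
//   at::linalg_cond(A);        // default ord=None -> sigma_max / sigma_min == 1.0
//   at::linalg_cond(A, 1);     // ord=1 -> ||A||_1 * ||A^-1||_1 == 2.0 * 0.5 == 1.0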
3223
3224 Tensor& linalg_cond_out(const Tensor& self, const std::optional<Scalar>& opt_ord, Tensor& result) {
3225 checkSameDevice("linalg.cond", result, self);
3226 ScalarType real_dtype = toRealValueType(self.scalar_type());
3227 checkLinalgCompatibleDtype("linalg.cond", result.scalar_type(), real_dtype);
3228
3229 Tensor result_tmp = at::linalg_cond(self, opt_ord);
3230 at::native::resize_output(result, result_tmp.sizes());
3231 result.copy_(result_tmp);
3232 return result;
3233 }
3234
3235 // Frobenius or nuclear norms
3236 Tensor linalg_cond(const Tensor& self, c10::string_view ord) {
3237 squareCheckInputs(self, ("linalg.cond(ord=" + std::string(ord) + ")").c_str());
3238 std::variant<Scalar, c10::string_view> ord_variant = ord;
3239 _linalg_cond_check_ord(ord_variant);
3240
3241 // NumPy doesn't define the condition number for 0x0 matrices; we return 0.0 for such inputs
3242 if (self.numel() == 0) {
3243 return _linalg_cond_empty_matrix(self, self.scalar_type());
3244 }
3245
3246 if (ord == "nuc") {
3247 // The nuclear norm of A^-1 is the sum of the reciprocals of the singular values of A,
3248 // so both norms are computed here from a single SVD instead of going through matrix_norm,
3249 // which would also raise an error when the inverse contains infinities.
3250 auto singular_values = at::linalg_svdvals(self);
3251 return singular_values.sum(-1) * (singular_values.reciprocal().sum(-1));
3252 }
3253
3254 return _linalg_cond_helper(self, std::move(ord_variant));
3255 }
3256
3257 // TODO: implement _out variant avoiding copy and using already allocated storage directly
3258 Tensor& linalg_cond_out(const Tensor& self, c10::string_view ord, Tensor& result) {
3259 checkSameDevice("linalg.cond", result, self);
3260 ScalarType real_dtype = toRealValueType(self.scalar_type());
3261 checkLinalgCompatibleDtype("linalg.cond", result.scalar_type(), real_dtype);
3262
3263 Tensor result_tmp = at::linalg_cond(self, ord);
3264 at::native::resize_output(result, result_tmp.sizes());
3265 result.copy_(result_tmp);
3266 return result;
3267 }
3268
3269 Tensor linalg_tensorinv(const Tensor& self, int64_t ind) {
3270 /*
3271 The idea is to reduce the problem to 2D square matrix inversion.
3272 Step 1. Calculate the shape of the result and the shape of the intermediate 2D matrix.
3273 Step 2. Reshape `self` to 2D matrix.
3274 Step 3. Invert the 2D matrix self.to_2D()
3275 There is no quick way to find out whether the matrix is invertible,
3276 so at this stage an error from at::inverse can be thrown.
3277 Note that for CUDA this causes cross-device memory synchronization that can be slow.
3278 Step 4. reshape the result.
3279 */
3280 TORCH_CHECK(ind > 0, "Expected a strictly positive integer for 'ind', but got ", ind);
3281
3282 // self[ind:]
3283 std::vector<c10::SymInt> shape_ind_end = self.sym_sizes().slice(ind).vec();
3284 // self[:ind]
3285 std::vector<c10::SymInt> shape_start_ind = self.sym_sizes().slice(0, ind).vec();
3286
3287 c10::SymInt prod_ind_end = c10::multiply_integers(shape_ind_end.cbegin(), shape_ind_end.cend());
3288 c10::SymInt prod_start_ind = c10::multiply_integers(shape_start_ind.cbegin(), shape_start_ind.cend());
3289
3290 // Check whether the self tensor can be reshaped to the 2D square matrix
3291 TORCH_CHECK(prod_ind_end == prod_start_ind,
3292 "Expected self to satisfy the requirement prod(self.shape[ind:]) == prod(self.shape[:ind]), but got ",
3293 prod_ind_end, " != ", prod_start_ind);
3294
3295 // Concatenate shape_ind_end and shape_start_ind to form the shape of the result
3296 // self[ind:] + self[:ind]
3297 shape_ind_end.insert(shape_ind_end.cend(), shape_start_ind.cbegin(), shape_start_ind.cend());
3298
3299 // If the reshaped self is not invertible catch this error
3300 auto [result, info] = at::linalg_inv_ex(self.reshape_symint({prod_ind_end, prod_ind_end}), /*check_errors=*/false);
3301 at::_linalg_check_errors(info, "inv", /*is_matrix*/true);
3302
3303 return result.reshape_symint(shape_ind_end);
3304 }
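// Shape example for the reshaping logic above (a sketch with made-up sizes):
//
//   self.shape == (4, 6, 8, 3), ind == 2
//   prod(self.shape[:2]) == 24 == prod(self.shape[2:])  -> invert one 24 x 24 matrix
//   result.shape == self.shape[2:] + self.shape[:2] == (8, 3, 4, 6)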
3305
3306 // TODO: implement _out variant avoiding copy and using already allocated storage directly
3307 Tensor& linalg_tensorinv_out(const Tensor& self, int64_t ind, Tensor& result) {
3308 checkSameDevice("tensorinv", result, self);
3309 checkLinalgCompatibleDtype("tensorinv", result, self);
3310
3311 Tensor result_tmp = at::linalg_tensorinv(self, ind);
3312 at::native::resize_output(result, result_tmp.sizes());
3313 result.copy_(result_tmp);
3314 return result;
3315 }
3316
3317 Tensor linalg_tensorsolve(const Tensor& self, const Tensor& other, OptionalIntArrayRef dims) {
3318 /*
3319 The idea is to reduce the problem to 2D matrix solve.
3320 Step 1. (optional) `self` is permuted with `dims` such that dimensions from `dims` are moved to the right.
3321 For example, if we have 4D input with the shape (1, 2, 3, 4) and dims=(0, 2),
3322 then the result of permutation would have the shape (2, 4, 1, 3).
3323 Step 2. reshape `self` to 2D matrix.
3324 Step 3. solve the matrix equation self.to_2D() @ result = other.to_1D()
3325 Step 4. reshape the result.
3326 */
3327 int64_t ndim = self.dim();
3328 Tensor self_ = self;
3329
3330 // move dimensions of `self_` from `dims` to the end
3331 if (dims.has_value()) {
3332 DimVector dest_axes(dims.value().size());
3333 std::iota(dest_axes.begin(), dest_axes.end(), ndim - dest_axes.size());
3334 self_ = at::movedim(self_, dims.value(), dest_axes);
3335 }
3336
3337 // result_shape is self_.sizes[-(ndim - other.dim()):], i.e. the trailing dims of self_ not matched by other
3338 std::vector<c10::SymInt> result_shape = self_.sym_sizes().slice(other.dim(), ndim - other.dim()).vec();
3339
3340 c10::SymInt result_product = c10::multiply_integers(result_shape.begin(), result_shape.end());
3341 c10::SymInt other_product = c10::multiply_integers(other.sym_sizes().begin(), other.sym_sizes().end());
3342
3343 // Check whether the self tensor can be reshaped to the 2D square matrix
3344 TORCH_CHECK(result_product == other_product,
3345 "Expected self to satisfy the requirement prod(self.shape[other.ndim:]) == prod(self.shape[:other.ndim]), but got ",
3346 result_product, " != ", other_product);
3347
3348 self_ = self_.reshape_symint({result_product, result_product});
3349
3350 // `other` is flattened to 1D because `self_` has been reshaped to a 2D matrix for at::linalg_solve
3351 Tensor result = at::linalg_solve(self_, other.flatten());
3352 return result.reshape_symint(result_shape);
3353 }
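// Shape example for the reduction above (a sketch with made-up sizes):
//
//   self.shape == (2, 3, 6), other.shape == (2, 3), dims not given
//   result_shape == self.shape[other.ndim:] == (6,), and prod((6,)) == prod(other.shape) == 6
//   so the problem becomes one 6 x 6 solve against other.flatten(), reshaped back to (6,).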
3354
3355 Tensor& linalg_tensorsolve_out(const Tensor& self, const Tensor& other, OptionalIntArrayRef dims, Tensor& result) {
3356 checkSameDevice("tensorsolve", result, self);
3357 checkLinalgCompatibleDtype("tensorsolve", result, self);
3358
3359 Tensor result_tmp = at::linalg_tensorsolve(self, other, dims);
3360 at::native::resize_output(result, result_tmp.sizes());
3361 result.copy_(result_tmp);
3362 return result;
3363 }
3364
3365 namespace {
3366 struct KronImpl final {
3367 public:
3368 explicit KronImpl(const Tensor& self, const Tensor& other) {
3369 maxdim = std::max(self.dim(), other.dim());
3370 int64_t pad_self = maxdim - self.dim();
3371 int64_t pad_other = maxdim - other.dim();
3372 a_reshape = c10::SmallVector<int64_t, 10>(2 * maxdim);
3373 b_reshape = c10::SmallVector<int64_t, 10>(2 * maxdim);
3374 result_reshape = c10::SmallVector<int64_t, 10>(maxdim);
3375 for (const auto i : c10::irange(maxdim)) {
3376 a_reshape[2 * i] = (i >= pad_self ? self.sizes()[i - pad_self] : 1);
3377 a_reshape[2 * i + 1] = 1;
3378 b_reshape[2 * i] = 1;
3379 b_reshape[2 * i + 1] = (i >= pad_other ? other.sizes()[i - pad_other] : 1);
3380 result_reshape[i] = a_reshape[2 * i] * b_reshape[2 * i + 1];
3381 }
3382 self_view = at::_unsafe_view(self, a_reshape);
3383 other_view = at::_unsafe_view(other, b_reshape);
3384 }
3385
3386 Tensor& kron_out(Tensor& result) const {
3387 TORCH_INTERNAL_ASSERT(result.defined(), "Cannot call kron_out with an undefined result tensor as the out argument. Please allocate a Tensor before calling kron_out with it.");
3388
3389 c10::SmallVector<int64_t, 10> mul_shape(2 * maxdim);
3390 for (const auto i : c10::irange(maxdim)) {
3391 mul_shape[2 * i] = a_reshape[2 * i];
3392 mul_shape[2 * i + 1] = b_reshape[2 * i + 1];
3393 }
3394 at::native::resize_output(result, result_reshape);
3395 auto result_mul = at::_unsafe_view(result, mul_shape);
3396 at::mul_out(result_mul, self_view, other_view);
3397
3398 return result;
3399 }
3400
3401 Tensor kron() const {
3402 return at::_unsafe_view(at::mul(self_view, other_view), result_reshape);
3403 }
3404 private:
3405 int64_t maxdim;
3406 Tensor self_view;
3407 Tensor other_view;
3408 c10::SmallVector<int64_t, 10> result_reshape;
3409 c10::SmallVector<int64_t, 10> a_reshape;
3410 c10::SmallVector<int64_t, 10> b_reshape;
3411 };
3412 }
3413
3414 /*
3415 Calculates the Kronecker product between two Tensors.
3416 */
3417 Tensor& kron_out(const Tensor& self, const Tensor& other, Tensor& result) {
3418 return KronImpl(self, other).kron_out(result);
3419 }
3420
3421 Tensor kron(const Tensor& self, const Tensor& other) {
3422 return KronImpl(self, other).kron();
3423 }
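// How the reshape trick in KronImpl works (a sketch): self is viewed with its sizes in the
// even positions and ones in the odd positions, other the other way around, so a single
// broadcasted multiply produces every pairwise product, and collapsing each (even, odd) pair
// of axes gives the (m*p) x (n*q) Kronecker product. Illustrative usage (not part of the build):
//
//   at::Tensor a = at::eye(2);        // 2 x 2
//   at::Tensor b = at::ones({3, 3});  // 3 x 3
//   at::kron(a, b);                   // 6 x 6, block-diagonal with two 3x3 blocks of ones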
3424
3425 // Weight Only Quantization Gemm
3426 DEFINE_DISPATCH(weight_to_int4pack_stub);
3427 DEFINE_DISPATCH(int4pack_mm_stub);
3428 DEFINE_DISPATCH(int8pack_mm_stub);
3429
3430 Tensor _convert_weight_to_int4pack_cpu(
3431 const Tensor& in,
3432 int64_t innerKTiles) {
3433
3434 TORCH_CHECK(in.dim() == 2,
3435 __func__, " : expect weight to be 2D tensor.");
3436 TORCH_CHECK(in.dtype() == at::kByte,
3437 __func__, " : expect weight to be kByte.");
3438 TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8,
3439 __func__, " : innerKTiles need to be 2, 4, or 8, got ", innerKTiles);
3440
3441 auto weight = in.contiguous();
3442 auto N = weight.size(0);
3443 auto K = weight.size(1) * 2;
3444
3445 // Create fake shapes for CPU. The meta registration in dynamo requires the
3446 // operator to have the same output shape for each device, so we create the fake
3447 // shape {N / 8, K / (16 * innerKTiles), 32, innerKTiles / 2}.
3448 constexpr int64_t kNTileSize = 8;
3449 constexpr int64_t kKTileSize = 16;
3450 auto nTiles = (N + kNTileSize - 1) / kNTileSize;
3451
3452 TORCH_CHECK(N % 16 == 0,
3453 __func__, " : expect N to be dividable by 16");
3454 const int64_t kSuperKTileSize = kKTileSize * innerKTiles;
3455 TORCH_CHECK(K % kSuperKTileSize == 0,
3456 __func__, " : expect K to be divisible by ", kSuperKTileSize);
3457 auto kSuperTiles = (K + kSuperKTileSize - 1) / kSuperKTileSize;
3458
3459 auto weight_packed = at::empty(
3460 {nTiles, kSuperTiles, 32, innerKTiles / 2},
3461 at::TensorOptions().dtype(at::kInt));
3462
3463 weight_to_int4pack_stub(kCPU, weight_packed, weight, N, K);
3464 return weight_packed;
3465 }
3466
3467 Tensor _weight_int4pack_mm_cpu(
3468 const Tensor& A,
3469 const Tensor& B,
3470 int64_t qGroupSize,
3471 const Tensor& qScaleAndZeros) {
3472
3473 constexpr int64_t kNTileSize = 8;
3474
3475 auto M = A.size(0);
3476 auto N = B.size(0) * kNTileSize;
3477 auto K = A.size(1);
3478
3479 TORCH_CHECK(A.dtype() == kBFloat16 || A.dtype() == kHalf || A.dtype() == kFloat,
3480 __func__, " : expect A to be either 32-bit or 16-bit float tensor.");
3481 TORCH_CHECK(A.is_contiguous(),
3482 __func__, " : expect A to be contiguous.");
3483 TORCH_CHECK(A.dim() == 2,
3484 __func__, " : expect A to be 2D tensor.");
3485
3486 TORCH_CHECK(B.dtype() == kInt,
3487 __func__, " : expect B to be int32 tensor.");
3488 TORCH_CHECK(B.is_contiguous(),
3489 __func__, " : expect B to be contiguous.");
3490 TORCH_CHECK(B.dim() == 4,
3491 __func__, " : expect B to 4d tensor.");
3492
3493 TORCH_CHECK(qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128
3494 || qGroupSize == 256,
3495 __func__, ": expect qGroupSize to be 32, 64, 128 or 256, got ", qGroupSize);
3496
3497 TORCH_CHECK(qScaleAndZeros.dim() == 3 && qScaleAndZeros.size(1) == N
3498 && qScaleAndZeros.size(2) == 2,
3499 __func__, ": expect qScaleAndZeros to be 3d tensor with sizes [:, ", N, ", 2]");
3500
3501 auto C = at::empty({M, N}, A.options());
3502 int4pack_mm_stub(kCPU, C, A, B, qGroupSize, qScaleAndZeros, N, K);
3503
3504 return C;
3505 }
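// Rough semantics of the kernel above (a sketch; the exact packed layout of B is an
// implementation detail of weight_to_int4pack_stub):
//
//   C[m][n] = sum_k A[m][k] * dequant(B)[n][k]
//
// where each 4-bit weight is dequantized with the per-group, per-output-channel scale and
// zero point stored in qScaleAndZeros (group size qGroupSize along K).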
3506
3507 Tensor _weight_int8pack_mm_cpu(
3508 const Tensor& A,
3509 const Tensor& B,
3510 const Tensor& scales) {
3511
3512 auto M = A.size(0);
3513 auto N = B.size(0);
3514 auto K = A.size(1);
3515
3516 TORCH_CHECK(A.dtype() == kBFloat16 || A.dtype() == kHalf || A.dtype() == kFloat,
3517 __func__, " : expect A to be either 32-bit or 16-bit float tensor.");
3518 TORCH_CHECK(A.is_contiguous(),
3519 __func__, " : expect A to be contiguous.");
3520 TORCH_CHECK(A.dim() == 2,
3521 __func__, " : expect A to be 2D tensor.");
3522
3523 TORCH_CHECK(B.dtype() == kChar,
3524 __func__, " : expect B to be int8 tensor.");
3525 TORCH_CHECK(B.is_contiguous(),
3526 __func__, " : expect B to be contiguous.");
3527 TORCH_CHECK(B.size(1) == K,
3528 __func__, " : expect B.size(1) == ", K);
3529
3530 TORCH_CHECK(scales.dim() == 1 && scales.size(0) == N,
3531 __func__, " : expect scales to be 1d tensor with size ", N);
3532
3533 auto C = at::empty({M, N}, A.options());
3534 int8pack_mm_stub(kCPU, C, A, B, scales);
3535
3536 return C;
3537 }
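// Rough semantics of the kernel above (a sketch):
//
//   C[m][n] = (sum_k A[m][k] * B[n][k]) * scales[n]
//
// i.e. A @ B.T with int8 weights rescaled by a per-output-channel float scale.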
3538
3539 Tensor& _int_mm_out_cpu(const Tensor& self, const Tensor& mat2, Tensor& result) {
3540 #ifndef STRIP_ERROR_MESSAGES
3541 static constexpr c10::string_view func_name = "int_mm_out_cpu";
3542 #endif
3543 TORCH_CHECK(self.dim() == 2, func_name, ": Expected self to be of dimension 2 but got ", self.dim());
3544 TORCH_CHECK(mat2.dim() == 2, func_name, ": Expected mat2 to be of dimension 2 but got ", mat2.dim());
3545 TORCH_CHECK(self.size(1) == mat2.size(0), func_name, ": self.size(1) needs to match mat2.size(0) but got ", self.size(1), " and ", mat2.size(0));
3546 TORCH_CHECK(self.dtype() == at::kChar, func_name, ": Expected self dtype to be of type int8 but got ", self.dtype());
3547 TORCH_CHECK(mat2.dtype() == at::kChar, func_name, ": Expected mat2 dtype to be of type int8 but got ", mat2.dtype());
3548 TORCH_CHECK(result.dtype() == at::kInt, func_name, ": Expected result dtype to be of type kInt but got ", result.dtype());
3549 TORCH_CHECK(result.size(0) == self.size(0), func_name, ": Expected result.size(0) to be ", self.size(0), " but got ", result.size(0));
3550 TORCH_CHECK(result.size(1) == mat2.size(1), func_name, ": Expected result.size(1) to be ", mat2.size(1), " but got ", result.size(1));
3551 TORCH_CHECK(result.dim() == 2, func_name, ": Expected result to be of dimension 2 but got ", result.dim());
3552 TORCH_CHECK(result.is_contiguous(), func_name, ": Expected result to be contiguous.");
3553
3554 if (result.numel() == 0 || self.size(1) == 0) {
3555 return result.zero_();
3556 }
3557
3558 bool dispatched = false;
3559 if (at::globalContext().userEnabledMkldnn()) {
3560 try {
3561 mkldnn_matmul_i8i8i32(self, mat2, result);
3562 dispatched = true;
3563 } catch (const std::exception& e) {
3564 TORCH_WARN(func_name, " failed, switching to BLAS gemm: ", e.what());
3565 }
3566 }
3567 if (!dispatched) {
3568 auto a = reinterpret_cast<int8_t*>(self.data_ptr());
3569 auto b = reinterpret_cast<int8_t*>(mat2.data_ptr());
3570 auto c = reinterpret_cast<int32_t*>(result.data_ptr());
3571 const int64_t m = result.size(0);
3572 const int64_t n = result.size(1);
3573 const int64_t k = self.size(1);
3574 const int64_t lda_0 = self.strides()[0];
3575 const int64_t lda_1 = self.strides()[1];
3576 const int64_t ldb_0 = mat2.strides()[0];
3577 const int64_t ldb_1 = mat2.strides()[1];
3578 const int64_t ldc = result.strides()[0];
3579 parallel_for(0, m * n, 1, [&](int64_t start, int64_t end) {
3580 for (const auto i : c10::irange(start, end)) {
3581 auto row = i / n;
3582 auto col = i % n;
3583 c[row * ldc + col] = 0;
3584 for (const auto kk : c10::irange(k)) { // kk avoids shadowing k = self.size(1)
3585 c[row * ldc + col] = c[row * ldc + col] +
3586 static_cast<int32_t>(a[row * lda_0 + kk * lda_1]) *
3587 static_cast<int32_t>(b[kk * ldb_0 + col * ldb_1]);
3588 }
3589 }
3590 });
3591 }
3592 return result;
3593 }
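// Illustrative usage (a minimal sketch, not part of the build; shapes made up):
//
//   at::Tensor a = at::randint(-128, 128, {8, 16}, at::kChar);
//   at::Tensor b = at::randint(-128, 128, {16, 4}, at::kChar);
//   at::Tensor c = at::_int_mm(a, b);   // kInt result of shape {8, 4}
//
// When MKL-DNN is enabled it is tried first; otherwise the naive parallel loop above is used.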
3594
3595 Tensor _int_mm_cpu(const Tensor& self, const Tensor& mat2) {
3596 Tensor result = at::empty({self.size(0), mat2.size(1)}, self.options().dtype(at::kInt));
3597 return _int_mm_out_cpu(self, mat2, result);
3598 }
3599
3600 } // namespace native
3601 } // namespace at
3602