// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>

namespace at::functorch {

typedef std::tuple<Tensor, std::optional<int64_t>> oneOutput;
typedef std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>> twoOutputs;
typedef std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>> threeOutputs;
typedef std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>> fourOutputs;
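// Convention: each batch rule below returns every output tensor paired with an
// optional batch dimension (the index of the vmap dim in the physical output,
// or std::nullopt if that output is not batched).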

namespace {

// Note [Batching rules for matmul-like operators]
// at::matmul doesn't "de-expand" arguments to get better performance (maybe
// it should). In the batching rules for matmul-like operators (dot, mv, mm),
// we should be careful not to expand any unnecessary dimensions. i.e., if
// only one of the two arguments is a BatchedTensor, then we should try
// not to expand batch dimensions onto the other arg.

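// Illustrative sketch (not part of the dispatch logic): per the note above,
// if only A carries a vmap dim, dot_batch_rule below lowers the call to a
// single matmul without expanding B to match A's batch size:
//
//   auto A = at::randn({5, 3});   // a batch of 5 length-3 vectors, bdim = 0
//   auto B = at::randn({3});      // plain vector, no bdim
//   // dot_batch_rule(A, /*A_bdim=*/0, B, /*B_bdim=*/std::nullopt)
//   //   -> at::matmul(A, B.t())  (t() is a no-op on a 1-D tensor),
//   //      which has shape (5,) and output bdim 0.
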
std::tuple<Tensor, std::optional<int64_t>> dot_batch_rule(const Tensor& A, std::optional<int64_t> A_bdim, const Tensor& B, std::optional<int64_t> B_bdim) {
  TORCH_CHECK(A.dim() - A_bdim.has_value() == 1 && B.dim() - B_bdim.has_value() == 1, "Got wrong shapes for dot");
  auto A_ = moveBatchDimToFront(A, A_bdim);
  auto B_ = moveBatchDimToFront(B, B_bdim);
  if (A_bdim && B_bdim) {
    return std::make_tuple(at::matmul(A_.unsqueeze(-2), B_.unsqueeze(-1)).squeeze(-1).squeeze(-1), 0);
  } else {
    return std::make_tuple(at::matmul(A_, B_.t()), 0);
  }
}
Tensor vdot_decomp(const Tensor& A, const Tensor& B) {
  return at::dot(A.is_complex() ? A.conj() : A, B);
}

// NB: I wrote this like this because we *might* want it for a future matmul
// batch rule that isn't decomposed...
// "tv" = tensor @ vector
static std::tuple<Tensor, std::optional<int64_t>> tv_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim,
    const Tensor& other, std::optional<int64_t> other_bdim) {
  if (self_bdim && other_bdim) {
    // See Note [Batching rules for matmul-like operators]
    // B...OI, BI -> ...BOI, BI1 -> ...BO1 -> ...BO
    auto self_ = at::movedim(self, *self_bdim, -3);
    auto other_ = moveBatchDimToFront(other, other_bdim);
    other_ = other_.unsqueeze(-1);
    auto result = at::matmul(self_, other_).squeeze(-1);
    auto result_bdim = result.dim() - 2;
    return std::make_tuple( std::move(result), result_bdim );
  }
  else if (self_bdim && !other_bdim) {
    // B...OI, I -> B...O
    auto self_ = moveBatchDimToFront(self, self_bdim);
    return std::make_tuple( at::matmul(self_, other), 0 );
  }
  else if (!self_bdim && other_bdim) {
    // ...OI, BI -> ...OI, IB -> OB
    auto other_ = at::movedim(other, *other_bdim, -1);
    auto result = at::matmul(self, other_);
    return std::make_tuple( std::move(result), 1 );
  }
  TORCH_INTERNAL_ASSERT(false, "can't get here");
}

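// Concrete shape walk-through of the doubly-batched case in tv_batch_rule
// above (sketch only): self (2, 4, 3) with bdim 0 ("B...OI", O=4, I=3, no
// extra "..." dims) and other (2, 3) with bdim 0 ("BI"); other is unsqueezed
// to (2, 3, 1), at::matmul gives (2, 4, 1), squeeze(-1) gives (2, 4), and the
// batch dim is reported at result.dim() - 2 == 0.
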
static std::tuple<Tensor, std::optional<int64_t>> mv_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim,
    const Tensor& other, std::optional<int64_t> other_bdim) {
  auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  auto other_logical_rank = rankWithoutBatchDim(other, other_bdim);
  TORCH_CHECK(self_logical_rank == 2 && other_logical_rank == 1,
      "Shape mismatch: ",
      "Got incorrect dims for mv(a, b). a has dim ", self_logical_rank,
      " and b has dim ", other_logical_rank,
      " but expected them to have dim 2 and dim 1");
  return tv_batch_rule(self, self_bdim, other, other_bdim);
}

static std::tuple<Tensor, std::optional<int64_t>> mm_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim,
    const Tensor& other, std::optional<int64_t> other_bdim) {
  auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  auto other_logical_rank = rankWithoutBatchDim(other, other_bdim);
  TORCH_CHECK(self_logical_rank == 2 && other_logical_rank == 2,
      "Shape mismatch: Got incorrect dims for mm(a, b). "
      "a has dim ", self_logical_rank,
      " and b has dim ", other_logical_rank,
      " but expected them to have dim 2 and dim 2");
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto other_ = moveBatchDimToFront(other, other_bdim);
  return std::make_tuple( at::matmul(self_, other_), 0 );
}

static std::tuple<Tensor, std::optional<int64_t>> bmm_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim,
    const Tensor& other, std::optional<int64_t> other_bdim) {
  auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  auto other_logical_rank = rankWithoutBatchDim(other, other_bdim);
  TORCH_CHECK(self_logical_rank == 3 && other_logical_rank == 3,
      "Shape mismatch: Got incorrect dims for bmm(a, b). "
      "a has dim ", self_logical_rank,
      " and b has dim ", other_logical_rank,
      " but expected them to have dim 3 and dim 3");
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto other_ = moveBatchDimToFront(other, other_bdim);
  return std::make_tuple( at::matmul(self_, other_), 0 );
}

// AFAICT, nothing here can be batched. So we decompose :)
Tensor addmv_decomp(
  const Tensor& input, const Tensor& mat, const Tensor& vec, const Scalar& beta, const Scalar& alpha) {
  Tensor out = at::mv(mat, vec);
  if (!alpha.equal(1)) {
    out = alpha * out;
  }
  if (!beta.equal(0)) {
    out = beta * input + out;
  }
  return out;
}

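// For reference, the decomposition above computes the same value as the addmv
// primitive, i.e. beta * input + alpha * (mat @ vec); a quick sanity check one
// could run outside of vmap (hypothetical, not part of this file's tests):
//
//   auto input = at::randn({4});
//   auto mat = at::randn({4, 3});
//   auto vec = at::randn({3});
//   // at::allclose(addmv_decomp(input, mat, vec, 1, 1),
//   //              at::addmv(input, mat, vec))  // expected to be true
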
Tensor addbmm_decomp(
  const Tensor& input, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
  Tensor out = at::bmm(batch1, batch2).sum(0);
  if (!alpha.equal(1)) {
    out = alpha * out;
  }
  if (!beta.equal(0)) {
    out = beta * input + out;
  }
  return out;
}

Tensor baddbmm_decomp(
  const Tensor& input, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
  Tensor out = at::bmm(batch1, batch2);
  if (!alpha.equal(1)) {
    out = alpha * out;
  }
  if (!beta.equal(0)) {
    out = beta * input + out;
  }
  return out;
}

Tensor addmm_decomp(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
  // Decomposition that is probably not very fast...
  return at::add(self * beta, at::mm(mat1, mat2), alpha);
}

void _linalg_check_errors_batch_rule(const Tensor& info, std::optional<int64_t> info_bdim, c10::string_view api_name, bool is_matrix) {
  auto info_ = moveBatchDimToFront(info, info_bdim);
  // Not a matrix means this is a batch of matrices
  at::_linalg_check_errors(info_, api_name, false);
}

std::tuple<Tensor, std::optional<int64_t>>
householder_product_batch_rule(const Tensor &input, std::optional<int64_t> input_bdim,
                               const Tensor &tau, std::optional<int64_t> tau_bdim)
{
  auto input_ = moveBatchDimToFront(input, input_bdim);
  auto tau_ = moveBatchDimToFront(tau, tau_bdim);

  auto batch_size = get_bdim_size2(input, input_bdim, tau, tau_bdim);

  input_ = ensure_has_bdim(input_, input_bdim.has_value(), batch_size);
  tau_ = ensure_has_bdim(tau_, tau_bdim.has_value(), batch_size);
  return std::make_tuple(at::linalg_householder_product(input_, tau_), 0);
}

template <char const *op_name, typename A, A a, typename C>
struct LinalgCheckMatrixUnaryRuleHelper;

template <char const *op_name, typename F, F Func, typename A, typename... T>
struct LinalgCheckMatrixUnaryRuleHelper<op_name, F, Func, typelist<A, T...>> {
  static inline Tensor check_and_reshape_input(const Tensor& tensor, std::optional<int64_t> batch_dim) {
    TORCH_CHECK(rankWithoutBatchDim(tensor, batch_dim) >= 2, op_name, ": The input tensor A must have at least 2 dimensions.");
    return moveBatchDimToFront(tensor, batch_dim);
  }

  static oneOutput apply_one(
      const Tensor& tensor,
      std::optional<int64_t> batch_dim,
      T... extra_args) {
    const auto tensor_ = check_and_reshape_input(tensor, batch_dim);
    return std::make_tuple(Func(tensor_, std::forward<T>(extra_args)...), 0);
  }

  static twoOutputs apply_two(
      const Tensor& tensor,
      std::optional<int64_t> batch_dim,
      T... extra_args) {
    const auto tensor_ = check_and_reshape_input(tensor, batch_dim);
    const auto res = Func(tensor_, std::forward<T>(extra_args)...);
    return std::make_tuple(std::get<0>(res), 0, std::get<1>(res), 0);
  }

  static threeOutputs apply_three(
      const Tensor& tensor,
      std::optional<int64_t> batch_dim,
      T... extra_args) {
    const auto tensor_ = check_and_reshape_input(tensor, batch_dim);
    const auto res = Func(tensor_, std::forward<T>(extra_args)...);
    return std::make_tuple(std::get<0>(res), 0, std::get<1>(res), 0, std::get<2>(res), 0);
  }

  static fourOutputs apply_four(
      const Tensor& tensor,
      std::optional<int64_t> batch_dim,
      T... extra_args) {
    const auto tensor_ = check_and_reshape_input(tensor, batch_dim);
    const auto res = Func(tensor_, std::forward<T>(extra_args)...);
    return std::make_tuple(std::get<0>(res), 0, std::get<1>(res), 0, std::get<2>(res), 0, std::get<3>(res), 0);
  }
};

template <char const *op_name, typename A, A a, typename C>
struct LinalgCheckMatrixBinaryRuleHelper;

template <char const *op_name, typename F, F Func, typename A, typename B, typename... T>
struct LinalgCheckMatrixBinaryRuleHelper<op_name, F, Func, typelist<A, B, T...>> {
  static inline std::tuple<Tensor, Tensor> check_inputs_and_reshape_inputs(
      const Tensor& first, std::optional<int64_t> first_bdim,
      const Tensor& second, std::optional<int64_t> second_bdim) {
    TORCH_CHECK(rankWithoutBatchDim(first, first_bdim) >= 2,
                op_name, ": The input tensor A must have at least 2 dimensions.");
    TORCH_CHECK(rankWithoutBatchDim(second, second_bdim) >= 2,
                op_name, ": The input tensor B must have at least 2 dimensions.");
    return _binary_pointwise_helper(first, first_bdim, second, second_bdim, false);
  }

  static oneOutput apply_one(
      const Tensor& first, std::optional<int64_t> first_bdim,
      const Tensor& second, std::optional<int64_t> second_bdim,
      T... extra_args) {
    const auto tensor_other = check_inputs_and_reshape_inputs(first, first_bdim, second, second_bdim);
    const auto tensor_ = std::get<0>(tensor_other);
    const auto other_ = std::get<1>(tensor_other);
    return std::make_tuple(Func(tensor_, other_, std::forward<T>(extra_args)...), 0);
  }

  static twoOutputs apply_two(
      const Tensor& first, std::optional<int64_t> first_bdim,
      const Tensor& second, std::optional<int64_t> second_bdim,
      T... extra_args) {
    const auto tensor_other = check_inputs_and_reshape_inputs(first, first_bdim, second, second_bdim);
    const auto tensor_ = std::get<0>(tensor_other);
    const auto other_ = std::get<1>(tensor_other);
    const auto res = Func(tensor_, other_, std::forward<T>(extra_args)...);
    return std::make_tuple(std::get<0>(res), 0, std::get<1>(res), 0);
  }
};

static void expect_at_least_rank(
    const Tensor& tensor,
    std::optional<int64_t> tensor_bdim,
    int64_t expected_rank,
    const char* name) {
  auto rank = rankWithoutBatchDim(tensor, tensor_bdim);
  TORCH_CHECK(rank >= expected_rank,
      name, " should have at least ", expected_rank, " dimensions, but has ",
      rank, " dimensions instead.");
}

threeOutputs linalg_lu_unpack_batch_rule(
    const Tensor& LU, std::optional<int64_t> LU_bdim,
    const Tensor& pivots, std::optional<int64_t> pivots_bdim,
    bool unpack_data, bool unpack_pivots) {
  auto LU_ = moveBatchDimToFront(LU, LU_bdim);
  auto pivots_ = moveBatchDimToFront(pivots, pivots_bdim);

  // LU's first {N-2} and pivots' first {N-1} dimensions must match.
  // So if only one of them is being vmapped over, we must expand out that
  // dimension.
  if (LU_bdim.has_value() != pivots_bdim.has_value()) {
    auto bdim_size = get_bdim_size2(LU, LU_bdim, pivots, pivots_bdim);
    LU_ = ensure_has_bdim(LU_, LU_bdim.has_value(), bdim_size);
    pivots_ = ensure_has_bdim(pivots_, pivots_bdim.has_value(), bdim_size);
    pivots_bdim = 0;
    LU_bdim = 0;
  }

  const auto res = at::lu_unpack(LU_, pivots_, unpack_data, unpack_pivots);
  return std::make_tuple(std::get<0>(res), 0, std::get<1>(res), 0, std::get<2>(res), 0);
}

oneOutput linalg_lu_solve_batch_rule(
    const Tensor& LU, std::optional<int64_t> LU_bdim,
    const Tensor& pivots, std::optional<int64_t> pivots_bdim,
    const Tensor& B, std::optional<int64_t> B_bdim,
    bool left, bool adjoint) {
  const auto LU_min_rank = 2;
  const auto pivots_min_rank = 1;
  const auto B_min_rank = 2;

  expect_at_least_rank(LU, LU_bdim, LU_min_rank, "LU");
  expect_at_least_rank(pivots, pivots_bdim, pivots_min_rank, "pivots");
  expect_at_least_rank(B, B_bdim, B_min_rank, "B");

  auto LU_ = moveBatchDimToFront(LU, LU_bdim);
  auto pivots_ = moveBatchDimToFront(pivots, pivots_bdim);
  auto B_ = moveBatchDimToFront(B, B_bdim);

  // LU's first {N-2} and pivots' first {N-1} dimensions must match.
  // So if only one of them is being vmapped over, we must expand out that dimension.
  if (LU_bdim.has_value() ^ pivots_bdim.has_value()) {
    auto bdim_size = get_bdim_size2(LU, LU_bdim, pivots, pivots_bdim);
    LU_ = ensure_has_bdim(LU_, LU_bdim.has_value(), bdim_size);
    pivots_ = ensure_has_bdim(pivots_, pivots_bdim.has_value(), bdim_size);
    pivots_bdim = 0;
    LU_bdim = 0;
  }

  // Now, {LU, pivots} and B's first dimensions are allowed to broadcast.
  // The rest of the logic handles that.
  const auto LU_num_batch_dims = rankWithoutBatchDim(LU_, LU_bdim) - LU_min_rank;
  const auto pivots_num_batch_dims = rankWithoutBatchDim(pivots_, pivots_bdim) - pivots_min_rank;
  const auto B_num_batch_dims = rankWithoutBatchDim(B_, B_bdim) - B_min_rank;
  const auto max_num_batch_dims = std::max(std::max(LU_num_batch_dims, pivots_num_batch_dims), B_num_batch_dims);

  LU_ = maybePadToLogicalRank(LU_, LU_bdim, max_num_batch_dims + LU_min_rank);
  pivots_ = maybePadToLogicalRank(pivots_, pivots_bdim, max_num_batch_dims + pivots_min_rank);
  B_ = maybePadToLogicalRank(B_, B_bdim, max_num_batch_dims + B_min_rank);

  const auto result = at::linalg_lu_solve(LU_, pivots_, B_, left, adjoint);
  return std::make_tuple(result, 0);
}

oneOutput cholesky_solve_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim,
    const Tensor& A, std::optional<int64_t> A_bdim,
    bool upper) {
  TORCH_CHECK(rankWithoutBatchDim(self, self_bdim) >= 2,
           "b should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
  TORCH_CHECK(rankWithoutBatchDim(A, A_bdim) >= 2,
           "u should have at least 2 dimensions, but has ", A.dim(), " dimensions instead");

  const auto tensor_other = _binary_pointwise_helper(self, self_bdim, A, A_bdim, /*do_type_promotion=*/false);
  const auto tensor_ = std::get<0>(tensor_other);
  const auto other_ = std::get<1>(tensor_other);
  return std::make_tuple(at::cholesky_solve(tensor_, other_, upper), 0);
}

threeOutputs linalg_lu_factor_ex_batch_rule(
    const Tensor& A, std::optional<int64_t> A_bdim, bool pivot, bool check_errors) {
  TORCH_CHECK(rankWithoutBatchDim(A, A_bdim) >= 2, "torch.lu_factor_ex: Expected tensor with 2 or more dimensions. Got size: ", A.sizes(), " instead");
  const auto A_ = moveBatchDimToFront(A, A_bdim);
  const auto res = at::linalg_lu_factor_ex(A_, pivot, check_errors);
  return std::make_tuple(std::get<0>(res), 0, std::get<1>(res), 0, std::get<2>(res), 0);
}

oneOutput matrix_exp_batch_rule(const Tensor& self, std::optional<int64_t> self_bdim) {
  TORCH_CHECK(rankWithoutBatchDim(self, self_bdim) >= 2, "torch.matrix_exp: The input tensor A must have at least 2 dimensions.");
  const auto self_ = moveBatchDimToFront(self, self_bdim).contiguous();  // seems to be a bug
  return std::make_tuple(at::matrix_exp(self_), 0);
}

fourOutputs solve_ex_batch_rule(
    const Tensor& A, std::optional<int64_t> A_bdim,
    const Tensor& B, std::optional<int64_t> B_bdim,
    bool left, bool check_errors) {
  auto batch_size = get_bdim_size2(A, A_bdim, B, B_bdim);
  const auto A_logical_rank = rankWithoutBatchDim(A, A_bdim);
  const auto B_logical_rank = rankWithoutBatchDim(B, B_bdim);
  const auto max_logical_rank = std::max(A_logical_rank, B_logical_rank);

  TORCH_CHECK(A_logical_rank >= 2,
            "linalg.solve: The input tensor A must have at least 2 dimensions.");

  auto b_logical_rank = max_logical_rank;
  if (A_logical_rank > B_logical_rank) {  // vector case: B was a vector or batched vector
    // not accurate but matches linalg error message
    TORCH_CHECK(B_logical_rank >= 1, "linalg.solve: The input tensor B must have at least 2 dimensions.");
    b_logical_rank = max_logical_rank - 1;
  } else {  // matrix case: A and B are both matrices or batches of matrices
    TORCH_CHECK(B_logical_rank >= 2, "linalg.solve: The input tensor B must have at least 2 dimensions.");
  }

  // basically the binary pointwise helper, but if B came in as a vector we must pad it to be 1 dim smaller than A
  auto A_ = moveBatchDimToFront(A, A_bdim);
  auto B_ = moveBatchDimToFront(B, B_bdim);
  A_ = maybePadToLogicalRank(A_, A_bdim, max_logical_rank);
  B_ = maybePadToLogicalRank(B_, B_bdim, b_logical_rank);

  A_ = ensure_has_bdim(A_, A_bdim.has_value(), batch_size);
  B_ = ensure_has_bdim(B_, B_bdim.has_value(), batch_size);

  // NOTE [ solve_ex Batch Rule Contiguity ]
  // A determines whether or not linalg_solve takes an optimized path. We need the check on A_ to match the one run on
  // A as BatchedTensor since it might have been saved by autograd (specifically by the jvp) and the autograd behavior
  // differs based on whether or not the optimized path was taken.
  const auto batched_A_was_contiguous = A_bdim.has_value() ? at::select(A, *A_bdim, 0).is_contiguous() : A.is_contiguous();
  if (batched_A_was_contiguous && !A.is_complex()) {
    A_ = A_.contiguous();
  }
  const auto res = _linalg_solve_ex(A_, B_, left, check_errors);
  return std::make_tuple(std::get<0>(res), 0, std::get<1>(res), 0, std::get<2>(res), 0, std::get<3>(res), 0);
}

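// Rough illustration of the two rank branches above (sketch only): under vmap,
// a logical call like linalg.solve(A: (3, 3), B: (3,)) hits the "vector" case,
// so B is kept one logical dim smaller than A, while
// linalg.solve(A: (3, 3), B: (3, 2)) hits the "matrix" case and A and B are
// padded to the same logical rank before calling _linalg_solve_ex.
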
oneOutput cross_batch_rule(const Tensor& self, std::optional<int64_t> self_bdim,
                           const Tensor& other, std::optional<int64_t> other_bdim, const int64_t dim) {
  // match cross dimension checks
  TORCH_CHECK(rankWithoutBatchDim(self, self_bdim) == rankWithoutBatchDim(other, other_bdim),
    "linalg.cross: inputs must have the same number of dimensions."
  );

  const auto batch_size = get_bdim_size2(self, self_bdim, other, other_bdim);
  const auto self_other_bundled = _binary_pointwise_helper(self, self_bdim, other, other_bdim, false);

  const auto self_ = ensure_has_bdim(std::get<0>(self_other_bundled), self_bdim.has_value(), batch_size);
  const auto other_ = ensure_has_bdim(std::get<1>(self_other_bundled), other_bdim.has_value(), batch_size);

  const auto dim_ = getPhysicalDim(self_, true, dim);

  return std::make_tuple(linalg_cross(self_, other_, dim_), 0);
}

std::optional<int64_t> batch_dim_if_not_empty(const Tensor& t) {
  if (t.dim() == 1 && t.size(0) == 0) {
    return std::optional<int64_t>();
  }
  return std::optional<int64_t>(0);
}

fourOutputs linalg_lstsq_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim, const Tensor& b, std::optional<int64_t> b_bdim,
    std::optional<double> rcond, std::optional<c10::string_view> driver) {
  TORCH_CHECK(rankWithoutBatchDim(self, self_bdim) >= 2, "torch.linalg.lstsq: input must have at least 2 dimensions.");
  TORCH_CHECK(rankWithoutBatchDim(b, b_bdim) >= 1, "torch.linalg.lstsq: other must have at least 1 dimension.");

  const auto batch_size = get_bdim_size2(self, self_bdim, b, b_bdim);
  const auto tensor_other = _binary_pointwise_helper(self, self_bdim, b, b_bdim, /*do_type_promotion=*/false);

  // because of the ambiguity with the vector case, lstsq can broadcast [1, 2] -> [batch_size, 2] but not [2] -> [batch_size, 2]
  // so we could unsqueeze when there's no bdim, or just use ensure_has_bdim (which is what we do here)
  const auto self_ = ensure_has_bdim(std::get<0>(tensor_other), self_bdim.has_value(), batch_size);
  const auto b_ = ensure_has_bdim(std::get<1>(tensor_other), b_bdim.has_value(), batch_size);

  auto [res, res_1, res_2, res_3] = at::linalg_lstsq(self_, b_, rcond, driver);

  // everything but the 0th output is only sometimes computed. When the others aren't computed, they're empty tensors without a bdim
  const auto res_1_bdim = batch_dim_if_not_empty(res_1);
  const auto res_2_bdim = batch_dim_if_not_empty(res_2);
  const auto res_3_bdim = batch_dim_if_not_empty(res_3);
  return std::make_tuple(res, 0, res_1, res_1_bdim, res_2, res_2_bdim, res_3, res_3_bdim);
}

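// Sketch of the broadcasting ambiguity mentioned in linalg_lstsq_batch_rule:
// lstsq itself will broadcast a b of shape [1, 2] against a batch to get
// [batch_size, 2], but a b of shape [2] is ambiguous with the vector case and
// is not broadcast, which is why the rule materializes the vmap dim with
// ensure_has_bdim instead of relying on lstsq's own broadcasting.
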
template<typename F>
std::tuple<Tensor, std::optional<int64_t>>
atol_rtol_tensor_batch_rule(
    F Func, const Tensor& input, std::optional<int64_t> input_bdim,
    const std::optional<Tensor>& atol, const std::optional<int64_t> atol_bdim,
    const std::optional<Tensor>& rtol, const std::optional<int64_t> rtol_bdim, bool hermitian, char const *op_name) {
  auto input_logical_rank = rankWithoutBatchDim(input, input_bdim);

  TORCH_CHECK(input_logical_rank >= 2,
            op_name, ": The input tensor input must have at least 2 dimensions.");

  // atol and rtol's dims must be broadcastable to the number of batch dims of input,
  // which is input's dim - 2 (input represents a batch of matrices, so 2 is for the matrix dimensions)
  const auto input_logical_num_bdims = input_logical_rank - 2;
  const int64_t atol_logical_num_bdims = atol.has_value() ? rankWithoutBatchDim(*atol, atol_bdim) : 0;
  const int64_t rtol_logical_num_bdims = rtol.has_value() ? rankWithoutBatchDim(*rtol, rtol_bdim) : 0;
  const auto max_logical_bdims = std::max({input_logical_num_bdims, atol_logical_num_bdims, rtol_logical_num_bdims});

  auto input_ = moveBatchDimToFront(input, input_bdim);
  auto atol_ = atol.has_value() ? moveBatchDimToFront(*atol, atol_bdim) : atol;
  auto rtol_ = rtol.has_value() ? moveBatchDimToFront(*rtol, rtol_bdim) : rtol;

  // pad all inputs to have the same number of (non-vmap) batch dimensions
  input_ = maybePadToLogicalRank(input_, input_bdim, max_logical_bdims + 2);
  atol_ = atol_.has_value() ? maybePadToLogicalRank(*atol_, atol_bdim, max_logical_bdims) : atol_;
  rtol_ = rtol_.has_value() ? maybePadToLogicalRank(*rtol_, rtol_bdim, max_logical_bdims) : rtol_;

  return std::make_tuple(Func(input_, atol_, rtol_, hermitian), 0);
}

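// Shape sketch for the padding in atol_rtol_tensor_batch_rule (illustration
// only): with input of logical shape (5, 3, 3) (one matrix-batch dim plus the
// 3x3 matrix dims) and atol of logical shape (5,), max_logical_bdims is 1, so
// input stays rank 1 + 2 = 3 and atol stays rank 1, while an unbatched scalar
// rtol of shape () would be padded to shape (1) so it broadcasts against the
// matrix-batch dims.
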
static std::tuple<Tensor, std::optional<int64_t>>
pinv_batch_rule(
    const Tensor& input, std::optional<int64_t> input_bdim, const std::optional<Tensor>& atol,
    const std::optional<int64_t> atol_bdim, const std::optional<Tensor>& rtol,
    const std::optional<int64_t> rtol_bdim, bool hermitian) {
  return atol_rtol_tensor_batch_rule(ATEN_FN2(linalg_pinv, atol_rtol_tensor), input, input_bdim, atol, atol_bdim, rtol, rtol_bdim, hermitian, "linalg.pinv");
}

std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, SymInt, SymInt, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
_scaled_dot_product_flash_attention_batch_rule(
  const Tensor& query, std::optional<int64_t> query_bdim,
  const Tensor& key, std::optional<int64_t> key_bdim,
  const Tensor& value, std::optional<int64_t> value_bdim,
  double dropout_p,
  bool is_causal,
  bool return_debug_mask,
  c10::optional<double> scale
) {
  if (dropout_p > 0) {
    auto maybe_layer = maybeCurrentDynamicLayer();
    RandomnessType randomness = maybe_layer->randomness();
    check_randomness(randomness, query_bdim.has_value() || key_bdim.has_value() || value_bdim.has_value());
  }
  auto batch_size = get_bdim_size3(query, query_bdim, key, key_bdim, value, value_bdim);
  auto query_ = moveBatchDimToFront(query, query_bdim);
  auto key_ = moveBatchDimToFront(key, key_bdim);
  auto value_ = moveBatchDimToFront(value, value_bdim);
  query_ = ensure_has_bdim(query_, query_bdim.has_value(), batch_size);
  key_ = ensure_has_bdim(key_, key_bdim.has_value(), batch_size);
  value_ = ensure_has_bdim(value_, value_bdim.has_value(), batch_size);
  query_ = query_.flatten(0, 1);
  key_ = key_.flatten(0, 1);
  value_ = value_.flatten(0, 1);

  const auto [res0, res1, res2, res3, res4, res5, res6, res7, res8] = at::_scaled_dot_product_flash_attention(
      query_, key_, value_, dropout_p, is_causal, return_debug_mask, scale);

  const auto res0_ = reshape_dim_outof(0, batch_size, res0);
  const auto res1_ = reshape_dim_outof(0, batch_size, res1);
  // res2 and res3 (cum_seq_q and cum_seq_k) are always [0] for dense tensors
  // res4 and res5 (max_q and max_k) are SymInts, so they don't need reshaping
  // res6 and res7 (philox seed and offset) are always non-batched
  const auto res8_ = return_debug_mask ? reshape_dim_outof(0, batch_size, res8) : res8;

  return std::make_tuple(
    res0_, 0,
    res1_, 0,
    res2, std::nullopt,
    res3, std::nullopt,
    res4,
    res5,
    res6, std::nullopt,
    res7, std::nullopt,
    res8_, return_debug_mask ? std::optional<int64_t>(0) : std::nullopt
  );
}

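// The pattern above (shared by the SDPA rules below) is: fold the vmap dim
// into the kernel's existing batch dim with flatten(0, 1), call the kernel
// once, then split the vmap dim back out with reshape_dim_outof. Rough shape
// sketch, assuming query of logical shape (N, H, S, E) and a vmap batch of
// size B:
//   (B, N, H, S, E) --flatten(0, 1)--> (B * N, H, S, E)
//   kernel output --reshape_dim_outof(0, B, out)--> (B, N, ...) with bdim 0
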
fourOutputs _scaled_dot_product_efficient_attention_batch_rule(
  const Tensor& query, optional<int64_t> query_bdim,
  const Tensor& key, optional<int64_t> key_bdim,
  const Tensor& value, optional<int64_t> value_bdim,
  const std::optional<Tensor>& attn_bias, optional<int64_t> attn_bias_bdim,
  bool compute_log_sumexp,
  double dropout_p,
  bool is_causal,
  c10::optional<double> scale
) {
  if (dropout_p > 0) {
    auto maybe_layer = maybeCurrentDynamicLayer();
    RandomnessType randomness = maybe_layer->randomness();
    check_randomness(randomness, query_bdim.has_value() || key_bdim.has_value() || value_bdim.has_value());
  }
  auto batch_size = get_bdim_size3(query, query_bdim, key, key_bdim, value, value_bdim);
  auto query_ = moveBatchDimToFront(query, query_bdim);
  auto key_ = moveBatchDimToFront(key, key_bdim);
  auto value_ = moveBatchDimToFront(value, value_bdim);
  query_ = ensure_has_bdim(query_, query_bdim.has_value(), batch_size);
  key_ = ensure_has_bdim(key_, key_bdim.has_value(), batch_size);
  value_ = ensure_has_bdim(value_, value_bdim.has_value(), batch_size);

  query_ = query_.flatten(0, 1);
  key_ = key_.flatten(0, 1);
  value_ = value_.flatten(0, 1);

  std::optional<Tensor> attn_bias_;
  if (attn_bias.has_value() && attn_bias->defined()) {
    attn_bias_ = attn_bias_bdim.has_value() ? reshape_dim_into(*attn_bias_bdim, 0, attn_bias.value()) : attn_bias.value();
  }
  const auto [res0, res1, res2, res3] = at::_scaled_dot_product_efficient_attention(
      query_, key_, value_, attn_bias_, compute_log_sumexp, dropout_p, is_causal, scale);
  const auto res0_ = reshape_dim_outof(0, batch_size, res0);
  const auto res1_ = reshape_dim_outof(0, batch_size, res1);
  // philox seed is always non-batched
  return std::make_tuple(res0_, 0, res1_, 0, res2, std::nullopt, res3, std::nullopt);
}

// Please unify SDPA APIs!!!
std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, SymInt, SymInt, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
_scaled_dot_product_cudnn_attention_batch_rule(
  const Tensor& query, std::optional<int64_t> query_bdim,
  const Tensor& key, std::optional<int64_t> key_bdim,
  const Tensor& value, std::optional<int64_t> value_bdim,
  const std::optional<Tensor>& attn_bias, std::optional<int64_t> attn_bias_bdim,
  bool compute_log_sumexp,
  double dropout_p,
  bool is_causal,
  bool return_debug_mask,
  c10::optional<double> scale
) {
  if (dropout_p > 0) {
    auto maybe_layer = maybeCurrentDynamicLayer();
    RandomnessType randomness = maybe_layer->randomness();
    check_randomness(randomness, query_bdim.has_value() || key_bdim.has_value() || value_bdim.has_value());
  }
  auto batch_size = get_bdim_size3(query, query_bdim, key, key_bdim, value, value_bdim);
  auto query_ = moveBatchDimToFront(query, query_bdim);
  auto key_ = moveBatchDimToFront(key, key_bdim);
  auto value_ = moveBatchDimToFront(value, value_bdim);
  query_ = ensure_has_bdim(query_, query_bdim.has_value(), batch_size);
  key_ = ensure_has_bdim(key_, key_bdim.has_value(), batch_size);
  value_ = ensure_has_bdim(value_, value_bdim.has_value(), batch_size);
  query_ = query_.flatten(0, 1);
  key_ = key_.flatten(0, 1);
  value_ = value_.flatten(0, 1);

  std::optional<Tensor> attn_bias_;
  if (attn_bias.has_value() && attn_bias->defined()) {
    attn_bias_ = attn_bias_bdim.has_value() ? reshape_dim_into(*attn_bias_bdim, 0, attn_bias.value()) : attn_bias.value();
  }

  const auto [res0, res1, res2, res3, res4, res5, res6, res7, res8] = at::_scaled_dot_product_cudnn_attention(
      query_, key_, value_, attn_bias_, compute_log_sumexp, dropout_p, is_causal, return_debug_mask, scale);

  const auto res0_ = reshape_dim_outof(0, batch_size, res0);
  Tensor res1_;
  std::optional<int64_t> res1_bdim;
  if (compute_log_sumexp) {
    res1_ = reshape_dim_outof(0, batch_size, res1);
    res1_bdim = 0;
  } else {
    res1_ = res1;
    res1_bdim = std::nullopt;
  }
  // res2 and res3 (cum_seq_q and cum_seq_k) are always [0] for dense tensors
  // res4 and res5 (max_q and max_k) are SymInts, so they don't need reshaping
  // res6 and res7 (philox seed and offset) are always non-batched
  const auto res8_ = return_debug_mask ? reshape_dim_outof(0, batch_size, res8) : res8;

  return std::make_tuple(
    res0_, 0,
    res1_, res1_bdim,
    res2, std::nullopt,
    res3, std::nullopt,
    res4,
    res5,
    res6, std::nullopt,
    res7, std::nullopt,
    res8_, return_debug_mask ? std::optional<int64_t>(0) : std::nullopt
  );
}

}

#define LINALG_CHECK_MATRIX_UNARY_BATCH_RULE(fn, num_out) SINGLE_ARG(\
  LinalgCheckMatrixUnaryRuleHelper<\
    func_string_##fn,\
    decltype(&ATEN_FN(fn)),\
    &ATEN_FN(fn),\
    c10::guts::function_traits<decltype(ATEN_FN(fn))>::parameter_types>::apply_##num_out)

#define LINALG_CHECK_MATRIX_UNARY_BATCH_RULE2(fn, overload, num_out) SINGLE_ARG(\
  LinalgCheckMatrixUnaryRuleHelper<\
    func_string_##fn_##overload,\
    decltype(&ATEN_FN2(fn, overload)),\
    &ATEN_FN2(fn, overload),\
    c10::guts::function_traits<decltype(ATEN_FN2(fn, overload))>::parameter_types>::apply_##num_out)

#define LINALG_CHECK_MATRIX_BINARY_BATCH_RULE(fn, num_out) SINGLE_ARG(\
  LinalgCheckMatrixBinaryRuleHelper<\
    func_string_##fn,\
    decltype(&ATEN_FN(fn)),\
    &ATEN_FN(fn),\
    c10::guts::function_traits<decltype(ATEN_FN(fn))>::parameter_types>::apply_##num_out)


// Define string constants with the function names. These will be used as template parameters.
// C++ doesn't let us use string literals as template parameters, so we have to declare them as consts first.
// What is going on with these macros?
// - clang-5 seems to require the constexpr
// - windows compiles with or without the constexpr, but the constexpr causes test problems
// - as a result we have some macro guards.
#if defined(_MSC_VER)
#define LINALG_STRING_CONST(fn, op_name) \
  const char func_string_##fn[] = #op_name;\

#define LINALG_STRING_CONST2(fn, overload, op_name) \
  const char func_string_##fn_##overload[] = #op_name;\

#else
#define LINALG_STRING_CONST(fn, op_name) \
  constexpr const char func_string_##fn[] = #op_name;\

#define LINALG_STRING_CONST2(fn, overload, op_name) \
  constexpr const char func_string_##fn_##overload[] = #op_name;\

#endif

#define LINALG_CHECK_MATRIX_UNARY_ONE_OUT(fn, op_name) \
  LINALG_STRING_CONST(fn, op_name);\
  TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {\
    VMAP_SUPPORT(fn, LINALG_CHECK_MATRIX_UNARY_BATCH_RULE(fn, one));\
  }

#define LINALG_CHECK_MATRIX_UNARY_ONE_OUT2(fn, overload, op_name) \
  LINALG_STRING_CONST2(fn, overload, op_name);\
  TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {\
    VMAP_SUPPORT2(fn, overload, LINALG_CHECK_MATRIX_UNARY_BATCH_RULE2(fn, overload, one));\
  }

#define LINALG_CHECK_MATRIX_UNARY_TWO_OUT(fn, op_name) \
  LINALG_STRING_CONST(fn, op_name);\
  TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {\
    VMAP_SUPPORT(fn, LINALG_CHECK_MATRIX_UNARY_BATCH_RULE(fn, two));\
  }

#define LINALG_CHECK_MATRIX_UNARY_THREE_OUT(fn, op_name) \
  LINALG_STRING_CONST(fn, op_name);\
  TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {\
    VMAP_SUPPORT(fn, LINALG_CHECK_MATRIX_UNARY_BATCH_RULE(fn, three));\
  }

#define LINALG_CHECK_MATRIX_UNARY_FOUR_OUT(fn, op_name) \
  LINALG_STRING_CONST(fn, op_name);\
  TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {\
    VMAP_SUPPORT(fn, LINALG_CHECK_MATRIX_UNARY_BATCH_RULE(fn, four));\
  }

#define LINALG_CHECK_MATRIX_BINARY_ONE_OUT(fn, op_name) \
  LINALG_STRING_CONST(fn, op_name);\
  TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {\
    VMAP_SUPPORT(fn, LINALG_CHECK_MATRIX_BINARY_BATCH_RULE(fn, one));\
  }

#define LINALG_CHECK_MATRIX_BINARY_TWO_OUT(fn, op_name) \
  LINALG_STRING_CONST(fn, op_name);\
  TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {\
    VMAP_SUPPORT(fn, LINALG_CHECK_MATRIX_BINARY_BATCH_RULE(fn, two));\
  }

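// For reference, a sketch of what one instantiation below expands to (modulo
// exact token spelling): LINALG_CHECK_MATRIX_UNARY_ONE_OUT(cholesky, cholesky)
// declares `func_string_cholesky[] = "cholesky"` and then registers
// LinalgCheckMatrixUnaryRuleHelper<func_string_cholesky, ...>::apply_one as
// the FuncTorchBatched batch rule for cholesky via VMAP_SUPPORT.
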
// These need to be outside. String constant must be declared outside of a macro to be used as template param
// NOLINTBEGIN(*array*)
LINALG_CHECK_MATRIX_UNARY_ONE_OUT(cholesky, cholesky);
LINALG_CHECK_MATRIX_UNARY_ONE_OUT(cholesky_inverse, cholesky_inverse);
LINALG_CHECK_MATRIX_UNARY_TWO_OUT(linalg_cholesky_ex, linalg.cholesky);
LINALG_CHECK_MATRIX_UNARY_TWO_OUT(linalg_eig, linalg.eig);
LINALG_CHECK_MATRIX_UNARY_TWO_OUT(linalg_inv_ex, linalg.inv_ex);
LINALG_CHECK_MATRIX_UNARY_THREE_OUT(linalg_ldl_factor_ex, torch.linalg.ldl_factor_ex);
LINALG_CHECK_MATRIX_UNARY_TWO_OUT(linalg_qr, linalg.qr);
LINALG_CHECK_MATRIX_UNARY_TWO_OUT(linalg_slogdet, linalg.slogdet);
LINALG_CHECK_MATRIX_BINARY_ONE_OUT(linalg_solve_triangular, linalg.solve_triangular);

LINALG_CHECK_MATRIX_UNARY_TWO_OUT(geqrf, geqrf);
LINALG_CHECK_MATRIX_BINARY_TWO_OUT(triangular_solve, triangular_solve);
LINALG_CHECK_MATRIX_UNARY_THREE_OUT(_linalg_det, linalg.det);
LINALG_CHECK_MATRIX_UNARY_TWO_OUT(_linalg_eigh, linalg.eigh);
LINALG_CHECK_MATRIX_UNARY_FOUR_OUT(_linalg_slogdet, linalg.slogdet);
LINALG_CHECK_MATRIX_UNARY_THREE_OUT(_linalg_svd, linalg.svd);
// NOLINTEND(*array*)

TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  VMAP_SUPPORT(bmm, bmm_batch_rule);
  m.impl("addmv", addmv_decomp);
  m.impl("addmm", addmm_decomp);
  m.impl("addbmm", addbmm_decomp);
  m.impl("baddbmm", baddbmm_decomp);
  VMAP_SUPPORT(dot, dot_batch_rule);
  VMAP_SUPPORT(mv, mv_batch_rule);
  VMAP_SUPPORT(mm, mm_batch_rule);
  VMAP_SUPPORT(lu_unpack, linalg_lu_unpack_batch_rule);
  VMAP_SUPPORT(linalg_lu_solve, linalg_lu_solve_batch_rule);
  VMAP_SUPPORT(linalg_householder_product, householder_product_batch_rule);
  VMAP_SUPPORT(cholesky_solve, cholesky_solve_batch_rule);  // custom dim error
  VMAP_SUPPORT(linalg_lstsq, linalg_lstsq_batch_rule);  // custom errors and sometimes empty return
  VMAP_SUPPORT(linalg_lu_factor_ex, linalg_lu_factor_ex_batch_rule);
  VMAP_SUPPORT(linalg_matrix_exp, matrix_exp_batch_rule);
  VMAP_SUPPORT(_linalg_solve_ex, solve_ex_batch_rule);
  VMAP_SUPPORT(linalg_cross, cross_batch_rule);
  VMAP_SUPPORT2(linalg_pinv, atol_rtol_tensor, pinv_batch_rule);
  VMAP_SUPPORT(_scaled_dot_product_efficient_attention, _scaled_dot_product_efficient_attention_batch_rule);

  VMAP_SUPPORT(_scaled_dot_product_flash_attention, _scaled_dot_product_flash_attention_batch_rule);
  VMAP_SUPPORT(_scaled_dot_product_cudnn_attention, _scaled_dot_product_cudnn_attention_batch_rule);

  VMAP_SUPPORT(_linalg_check_errors, _linalg_check_errors_batch_rule);

  m.impl("vdot", vdot_decomp);
}
} // namespace at::functorch