// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>

namespace at::functorch {

typedef std::tuple<Tensor, std::optional<int64_t>> oneOutput;
typedef std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>> twoOutputs;
typedef std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>> threeOutputs;
typedef std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>> fourOutputs;

namespace {

// Note [Batching rules for matmul-like operators]
// at::matmul doesn't "de-expand" arguments to get better performance (maybe
// it should). In the batching rules for matmul-like operators (dot, mv, mm),
// we should be careful not to expand any unnecessary dimensions. i.e., if
// only one of the two arguments is a BatchedTensor, then we should try
// not to expand batch dimensions onto the other arg.
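//
// For illustration (a sketch of shapes, not literal code): when vmapping
// torch.dot over only A, A_ in dot_batch_rule below is [B, N] while B_ stays
// [N]; the else-branch computes at::matmul([B, N], [N]) -> [B] directly,
// instead of first expanding B to [B, N], which would materialize an
// unnecessary copy.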

std::tuple<Tensor, std::optional<int64_t>> dot_batch_rule(const Tensor& A, std::optional<int64_t> A_bdim, const Tensor& B, std::optional<int64_t> B_bdim) {
  TORCH_CHECK(A.dim() - A_bdim.has_value() == 1 && B.dim() - B_bdim.has_value() == 1, "Got wrong shapes for dot");
  auto A_ = moveBatchDimToFront(A, A_bdim);
  auto B_ = moveBatchDimToFront(B, B_bdim);
  if (A_bdim && B_bdim) {
    return std::make_tuple(at::matmul(A_.unsqueeze(-2), B_.unsqueeze(-1)).squeeze(-1).squeeze(-1), 0);
  } else {
    return std::make_tuple(at::matmul(A_, B_.t()), 0);
  }
}

Tensor vdot_decomp(const Tensor& A, const Tensor& B) {
  return at::dot(A.is_complex() ? A.conj() : A, B);
}

// NB: I wrote this like this because we *might* want it for a future matmul
// batch rule that isn't decomposed...
// "tv" = tensor @ vector
static std::tuple<Tensor, std::optional<int64_t>> tv_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim,
    const Tensor& other, std::optional<int64_t> other_bdim) {
  if (self_bdim && other_bdim) {
    // See Note [Batching rules for matmul-like operators]
    // B...OI, BI -> ...BOI, BI1 -> ...BO1 -> ...BO
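    // Concretely (shapes for illustration only): self [B, O, I] with bdim 0 and
    // other [B, I] with bdim 0 give matmul([B, O, I], [B, I, 1]) -> [B, O, 1]
    // -> squeeze -> [B, O], so result_bdim ends up as 0.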
    auto self_ = at::movedim(self, *self_bdim, -3);
    auto other_ = moveBatchDimToFront(other, other_bdim);
    other_ = other_.unsqueeze(-1);
    auto result = at::matmul(self_, other_).squeeze(-1);
    auto result_bdim = result.dim() - 2;
    return std::make_tuple( std::move(result), result_bdim );
  }
  else if (self_bdim && !other_bdim) {
    // B...OI, I -> B...O
    auto self_ = moveBatchDimToFront(self, self_bdim);
    return std::make_tuple( at::matmul(self_, other), 0 );
  }
  else if (!self_bdim && other_bdim) {
    // ...OI, BI -> ...OI, IB -> OB
    auto other_ = at::movedim(other, *other_bdim, -1);
    auto result = at::matmul(self, other_);
    return std::make_tuple( std::move(result), 1 );
  }
  TORCH_INTERNAL_ASSERT(false, "can't get here");
}

static std::tuple<Tensor, std::optional<int64_t>> mv_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim,
    const Tensor& other, std::optional<int64_t> other_bdim) {
  auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  auto other_logical_rank = rankWithoutBatchDim(other, other_bdim);
  TORCH_CHECK(self_logical_rank == 2 && other_logical_rank == 1,
      "Shape mismatch: ",
      "Got incorrect dims for mv(a, b). a has dim ", self_logical_rank,
      " and b has dim ", other_logical_rank,
      " but expected them to have dim 2 and dim 1");
  return tv_batch_rule(self, self_bdim, other, other_bdim);
}

static std::tuple<Tensor, std::optional<int64_t>> mm_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim,
    const Tensor& other, std::optional<int64_t> other_bdim) {
  auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  auto other_logical_rank = rankWithoutBatchDim(other, other_bdim);
  TORCH_CHECK(self_logical_rank == 2 && other_logical_rank == 2,
      "Shape mismatch: Got incorrect dims for mm(a, b). "
      "a has dim ", self_logical_rank,
      " and b has dim ", other_logical_rank,
      " but expected them to have dim 2 and dim 2");
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto other_ = moveBatchDimToFront(other, other_bdim);
  return std::make_tuple( at::matmul(self_, other_), 0 );
}

static std::tuple<Tensor, std::optional<int64_t>> bmm_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim,
    const Tensor& other, std::optional<int64_t> other_bdim) {
  auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  auto other_logical_rank = rankWithoutBatchDim(other, other_bdim);
  TORCH_CHECK(self_logical_rank == 3 && other_logical_rank == 3,
      "Shape mismatch: Got incorrect dims for bmm(a, b). "
      "a has dim ", self_logical_rank,
      " and b has dim ", other_logical_rank,
      " but expected them to have dim 3 and dim 3");
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto other_ = moveBatchDimToFront(other, other_bdim);
  return std::make_tuple( at::matmul(self_, other_), 0 );
}

// AFAICT, nothing here can be batched. So we decompose :)
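// For reference, addmv(input, mat, vec, beta, alpha) computes
//   beta * input + alpha * (mat @ vec),
// and the decompositions below spell that pattern out with mv/bmm/mm plus
// scalar ops, so the existing batch rules for those primitives handle the
// vmapping.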
Tensor addmv_decomp(
    const Tensor& input, const Tensor& mat, const Tensor& vec, const Scalar& beta, const Scalar& alpha) {
  Tensor out = at::mv(mat, vec);
  if (!alpha.equal(1)) {
    out = alpha * out;
  }
  if (!beta.equal(0)) {
    out = beta * input + out;
  }
  return out;
}

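// Note the only difference between the next two: addbmm reduces over the batch
// dimension of the bmm result (hence the .sum(0)), while baddbmm keeps it.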
Tensor addbmm_decomp(
    const Tensor& input, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
  Tensor out = at::bmm(batch1, batch2).sum(0);
  if (!alpha.equal(1)) {
    out = alpha * out;
  }
  if (!beta.equal(0)) {
    out = beta * input + out;
  }
  return out;
}

Tensor baddbmm_decomp(
    const Tensor& input, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
  Tensor out = at::bmm(batch1, batch2);
  if (!alpha.equal(1)) {
    out = alpha * out;
  }
  if (!beta.equal(0)) {
    out = beta * input + out;
  }
  return out;
}

Tensor addmm_decomp(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
  // Decomposition that is probably not very fast...
  return at::add(self * beta, at::mm(mat1, mat2), alpha);
}

void _linalg_check_errors_batch_rule(const Tensor& info, std::optional<int64_t> info_bdim, c10::string_view api_name, bool is_matrix) {
  auto info_ = moveBatchDimToFront(info, info_bdim);
  // Passing is_matrix=false means "this is a batch of matrices", which is the
  // case here because info is being vmapped over.
  at::_linalg_check_errors(info_, api_name, false);
}

std::tuple<Tensor, std::optional<int64_t>>
householder_product_batch_rule(const Tensor &input, std::optional<int64_t> input_bdim,
                  const Tensor &tau, std::optional<int64_t> tau_bdim)
{
  auto input_ = moveBatchDimToFront(input, input_bdim);
  auto tau_ = moveBatchDimToFront(tau, tau_bdim);

  auto batch_size = get_bdim_size2(input, input_bdim, tau, tau_bdim);

  input_ = ensure_has_bdim(input_, input_bdim.has_value(), batch_size);
  tau_ = ensure_has_bdim(tau_, tau_bdim.has_value(), batch_size);
  return std::make_tuple(at::linalg_householder_product(input_, tau_), 0);
}

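// The two helper templates below stamp out batch rules for linalg operators
// that take one or two (batches of) matrices: they check that each matrix
// input has logical rank at least 2, move/align the vmap dims, invoke the op,
// and report batch dim 0 for every tensor output. op_name only feeds the error
// message; apply_one/apply_two/apply_three/apply_four differ solely in how
// many tensor outputs the wrapped op returns.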
template <char const *op_name, typename A, A a, typename C>
struct LinalgCheckMatrixUnaryRuleHelper;

template <char const *op_name, typename F, F Func, typename A, typename... T>
struct LinalgCheckMatrixUnaryRuleHelper<op_name, F, Func, typelist<A, T...>> {
  static inline Tensor check_and_reshape_input(const Tensor& tensor, std::optional<int64_t> batch_dim) {
    TORCH_CHECK(rankWithoutBatchDim(tensor, batch_dim) >= 2, op_name, ": The input tensor A must have at least 2 dimensions.");
    return moveBatchDimToFront(tensor, batch_dim);
  }

  static oneOutput apply_one(
      const Tensor& tensor,
      std::optional<int64_t> batch_dim,
      T... extra_args) {
    const auto tensor_ = check_and_reshape_input(tensor, batch_dim);
    return std::make_tuple(Func(tensor_, std::forward<T>(extra_args)...), 0);
  }

  static twoOutputs apply_two(
      const Tensor& tensor,
      std::optional<int64_t> batch_dim,
      T... extra_args) {
    const auto tensor_ = check_and_reshape_input(tensor, batch_dim);
    const auto res = Func(tensor_, std::forward<T>(extra_args)...);
    return std::make_tuple(std::get<0>(res), 0, std::get<1>(res), 0);
  }

  static threeOutputs apply_three(
      const Tensor& tensor,
      std::optional<int64_t> batch_dim,
      T... extra_args) {
    const auto tensor_ = check_and_reshape_input(tensor, batch_dim);
    const auto res = Func(tensor_, std::forward<T>(extra_args)...);
    return std::make_tuple(std::get<0>(res), 0, std::get<1>(res), 0, std::get<2>(res), 0);
  }

  static fourOutputs apply_four(
      const Tensor& tensor,
      std::optional<int64_t> batch_dim,
      T... extra_args) {
    const auto tensor_ = check_and_reshape_input(tensor, batch_dim);
    const auto res = Func(tensor_, std::forward<T>(extra_args)...);
    return std::make_tuple(std::get<0>(res), 0, std::get<1>(res), 0, std::get<2>(res), 0, std::get<3>(res), 0);
  }
};

template <char const *op_name, typename A, A a, typename C>
struct LinalgCheckMatrixBinaryRuleHelper;

template <char const *op_name, typename F, F Func, typename A, typename B, typename... T>
struct LinalgCheckMatrixBinaryRuleHelper<op_name, F, Func, typelist<A, B, T...>> {
  static inline std::tuple<Tensor, Tensor> check_inputs_and_reshape_inputs(
      const Tensor& first, std::optional<int64_t> first_bdim,
      const Tensor& second, std::optional<int64_t> second_bdim) {
    TORCH_CHECK(rankWithoutBatchDim(first, first_bdim) >= 2,
        op_name, ": The input tensor A must have at least 2 dimensions.");
    TORCH_CHECK(rankWithoutBatchDim(second, second_bdim) >= 2,
        op_name, ": The input tensor B must have at least 2 dimensions.");
    return _binary_pointwise_helper(first, first_bdim, second, second_bdim, false);
  }

  static oneOutput apply_one(
      const Tensor& first, std::optional<int64_t> first_bdim,
      const Tensor& second, std::optional<int64_t> second_bdim,
      T... extra_args) {
    const auto tensor_other = check_inputs_and_reshape_inputs(first, first_bdim, second, second_bdim);
    const auto tensor_ = std::get<0>(tensor_other);
    const auto other_ = std::get<1>(tensor_other);
    return std::make_tuple(Func(tensor_, other_, std::forward<T>(extra_args)...), 0);
  }

  static twoOutputs apply_two(
      const Tensor& first, std::optional<int64_t> first_bdim,
      const Tensor& second, std::optional<int64_t> second_bdim,
      T... extra_args) {
    const auto tensor_other = check_inputs_and_reshape_inputs(first, first_bdim, second, second_bdim);
    const auto tensor_ = std::get<0>(tensor_other);
    const auto other_ = std::get<1>(tensor_other);
    const auto res = Func(tensor_, other_, std::forward<T>(extra_args)...);
    return std::make_tuple(std::get<0>(res), 0, std::get<1>(res), 0);
  }
};

static void expect_at_least_rank(
    const Tensor& tensor,
    std::optional<int64_t> tensor_bdim,
    int64_t expected_rank,
    const char* name) {
  auto rank = rankWithoutBatchDim(tensor, tensor_bdim);
  TORCH_CHECK(rank >= expected_rank,
      name, " should have at least ", expected_rank, " dimensions, but has ",
      rank, " dimensions instead.");
}

threeOutputs linalg_lu_unpack_batch_rule(
    const Tensor& LU, std::optional<int64_t> LU_bdim,
    const Tensor& pivots, std::optional<int64_t> pivots_bdim,
    bool unpack_data, bool unpack_pivots) {
  auto LU_ = moveBatchDimToFront(LU, LU_bdim);
  auto pivots_ = moveBatchDimToFront(pivots, pivots_bdim);

  // The first N-2 dimensions of LU and the first N-1 dimensions of pivots must
  // match. So if only one of them is being vmapped over, we must expand out
  // that dimension.
  if (LU_bdim.has_value() != pivots_bdim.has_value()) {
    auto bdim_size = get_bdim_size2(LU, LU_bdim, pivots, pivots_bdim);
    LU_ = ensure_has_bdim(LU_, LU_bdim.has_value(), bdim_size);
    pivots_ = ensure_has_bdim(pivots_, pivots_bdim.has_value(), bdim_size);
    pivots_bdim = 0;
    LU_bdim = 0;
  }

  const auto res = at::lu_unpack(LU_, pivots_, unpack_data, unpack_pivots);
  return std::make_tuple(std::get<0>(res), 0, std::get<1>(res), 0, std::get<2>(res), 0);
}

oneOutput linalg_lu_solve_batch_rule(
    const Tensor& LU, std::optional<int64_t> LU_bdim,
    const Tensor& pivots, std::optional<int64_t> pivots_bdim,
    const Tensor& B, std::optional<int64_t> B_bdim,
    bool left, bool adjoint) {
  const auto LU_min_rank = 2;
  const auto pivots_min_rank = 1;
  const auto B_min_rank = 2;

  expect_at_least_rank(LU, LU_bdim, LU_min_rank, "LU");
  expect_at_least_rank(pivots, pivots_bdim, pivots_min_rank, "pivots");
  expect_at_least_rank(B, B_bdim, B_min_rank, "B");

  auto LU_ = moveBatchDimToFront(LU, LU_bdim);
  auto pivots_ = moveBatchDimToFront(pivots, pivots_bdim);
  auto B_ = moveBatchDimToFront(B, B_bdim);

  // The first N-2 dimensions of LU and the first N-1 dimensions of pivots must
  // match. So if only one of them is being vmapped over, we must expand out
  // that dimension.
  if (LU_bdim.has_value() ^ pivots_bdim.has_value()) {
    auto bdim_size = get_bdim_size2(LU, LU_bdim, pivots, pivots_bdim);
    LU_ = ensure_has_bdim(LU_, LU_bdim.has_value(), bdim_size);
    pivots_ = ensure_has_bdim(pivots_, pivots_bdim.has_value(), bdim_size);
    pivots_bdim = 0;
    LU_bdim = 0;
  }

  // Now, {LU, pivots} and B's first dimensions are allowed to broadcast.
  // The rest of the logic handles that.
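  // For illustration (shapes only): LU_ logical [n, n], pivots_ logical [n],
  // B_ logical [5, n, k] gives max_num_batch_dims == 1, so LU_ and pivots_ are
  // padded below to [1, n, n] and [1, n], and linalg_lu_solve broadcasts that
  // size-1 batch dim against B_'s batch dim of 5.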
  const auto LU_num_batch_dims = rankWithoutBatchDim(LU_, LU_bdim) - LU_min_rank;
  const auto pivots_num_batch_dims = rankWithoutBatchDim(pivots_, pivots_bdim) - pivots_min_rank;
  const auto B_num_batch_dims = rankWithoutBatchDim(B_, B_bdim) - B_min_rank;
  const auto max_num_batch_dims = std::max(std::max(LU_num_batch_dims, pivots_num_batch_dims), B_num_batch_dims);

  LU_ = maybePadToLogicalRank(LU_, LU_bdim, max_num_batch_dims + LU_min_rank);
  pivots_ = maybePadToLogicalRank(pivots_, pivots_bdim, max_num_batch_dims + pivots_min_rank);
  B_ = maybePadToLogicalRank(B_, B_bdim, max_num_batch_dims + B_min_rank);

  const auto result = at::linalg_lu_solve(LU_, pivots_, B_, left, adjoint);
  return std::make_tuple(result, 0);
}

oneOutput cholesky_solve_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim,
    const Tensor& A, std::optional<int64_t> A_bdim,
    bool upper) {
  TORCH_CHECK(rankWithoutBatchDim(self, self_bdim) >= 2,
      "b should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
  TORCH_CHECK(rankWithoutBatchDim(A, A_bdim) >= 2,
      "u should have at least 2 dimensions, but has ", A.dim(), " dimensions instead");

  const auto tensor_other = _binary_pointwise_helper(self, self_bdim, A, A_bdim, /*do_type_promotion=*/false);
  const auto tensor_ = std::get<0>(tensor_other);
  const auto other_ = std::get<1>(tensor_other);
  return std::make_tuple(at::cholesky_solve(tensor_, other_, upper), 0);
}

threeOutputs linalg_lu_factor_ex_batch_rule(
    const Tensor& A, std::optional<int64_t> A_bdim, bool pivot, bool check_errors) {
  TORCH_CHECK(rankWithoutBatchDim(A, A_bdim) >= 2, "torch.lu_factor_ex: Expected tensor with 2 or more dimensions. Got size: ", A.sizes(), " instead");
  const auto A_ = moveBatchDimToFront(A, A_bdim);
  const auto res = at::linalg_lu_factor_ex(A_, pivot, check_errors);
  return std::make_tuple(std::get<0>(res), 0, std::get<1>(res), 0, std::get<2>(res), 0);
}

oneOutput matrix_exp_batch_rule(const Tensor& self, std::optional<int64_t> self_bdim) {
  TORCH_CHECK(rankWithoutBatchDim(self, self_bdim) >= 2, "torch.matrix_exp: The input tensor A must have at least 2 dimensions.");
  const auto self_ = moveBatchDimToFront(self, self_bdim).contiguous();  // seems to be a bug
  return std::make_tuple(at::matrix_exp(self_), 0);
}

fourOutputs solve_ex_batch_rule(
    const Tensor& A, std::optional<int64_t> A_bdim,
    const Tensor& B, std::optional<int64_t> B_bdim,
    bool left, bool check_errors) {
  auto batch_size = get_bdim_size2(A, A_bdim, B, B_bdim);
  const auto A_logical_rank = rankWithoutBatchDim(A, A_bdim);
  const auto B_logical_rank = rankWithoutBatchDim(B, B_bdim);
  const auto max_logical_rank = std::max(A_logical_rank, B_logical_rank);

  TORCH_CHECK(A_logical_rank >= 2,
      "linalg.solve: The input tensor A must have at least 2 dimensions.");

  auto b_logical_rank = max_logical_rank;
  if (A_logical_rank > B_logical_rank) {  // vector case: B was a vector or batched vector
    // not accurate but matches linalg error message
    TORCH_CHECK(B_logical_rank >= 1, "linalg.solve: The input tensor B must have at least 2 dimensions.");
    b_logical_rank = max_logical_rank - 1;
  } else {  // matrix case: A and B are both matrices or batches of matrices
    TORCH_CHECK(B_logical_rank >= 2, "linalg.solve: The input tensor B must have at least 2 dimensions.");
  }

  // Basically the binary pointwise helper, except that if B came in as a
  // vector we must pad it to be one dim smaller than A.
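  // For example (shapes only): A logical [5, n, n] and B logical [n] give
  // max_logical_rank == 3 and b_logical_rank == 2, so B_ is padded to [1, n]
  // and linalg.solve treats it as a (broadcastable) batch of vectors.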
  auto A_ = moveBatchDimToFront(A, A_bdim);
  auto B_ = moveBatchDimToFront(B, B_bdim);
  A_ = maybePadToLogicalRank(A_, A_bdim, max_logical_rank);
  B_ = maybePadToLogicalRank(B_, B_bdim, b_logical_rank);

  A_ = ensure_has_bdim(A_, A_bdim.has_value(), batch_size);
  B_ = ensure_has_bdim(B_, B_bdim.has_value(), batch_size);

  // NOTE [ solve_ex Batch Rule Contiguity ]
  // A determines whether or not linalg_solve takes an optimized path. We need the check on A_ to match the one run on
  // A as BatchedTensor since it might have been saved by autograd (specifically by the jvp) and the autograd behavior
  // differs based on whether or not the optimized path was taken.
  const auto batched_A_was_contiguous = A_bdim.has_value() ? at::select(A, *A_bdim, 0).is_contiguous() : A.is_contiguous();
  if (batched_A_was_contiguous && !A.is_complex()) {
    A_ = A_.contiguous();
  }
  const auto res = _linalg_solve_ex(A_, B_, left, check_errors);
  return std::make_tuple(std::get<0>(res), 0, std::get<1>(res), 0, std::get<2>(res), 0, std::get<3>(res), 0);
}

oneOutput cross_batch_rule(const Tensor& self, std::optional<int64_t> self_bdim,
    const Tensor& other, std::optional<int64_t> other_bdim, const int64_t dim) {
  // match cross dimension checks
  TORCH_CHECK(rankWithoutBatchDim(self, self_bdim) == rankWithoutBatchDim(other, other_bdim),
    "linalg.cross: inputs must have the same number of dimensions."
  );

  const auto batch_size = get_bdim_size2(self, self_bdim, other, other_bdim);
  const auto self_other_bundled = _binary_pointwise_helper(self, self_bdim, other, other_bdim, false);

  const auto self_ = ensure_has_bdim(std::get<0>(self_other_bundled), self_bdim.has_value(), batch_size);
  const auto other_ = ensure_has_bdim(std::get<1>(self_other_bundled), other_bdim.has_value(), batch_size);

  const auto dim_ = getPhysicalDim(self_, true, dim);

  return std::make_tuple(linalg_cross(self_, other_, dim_), 0);
}

std::optional<int64_t> batch_dim_if_not_empty(const Tensor& t) {
  if (t.dim() == 1 && t.size(0) == 0) {
    return std::optional<int64_t>();
  }
  return std::optional<int64_t>(0);
}

fourOutputs linalg_lstsq_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim, const Tensor& b, std::optional<int64_t> b_bdim,
    std::optional<double> rcond, std::optional<c10::string_view> driver) {
  TORCH_CHECK(rankWithoutBatchDim(self, self_bdim) >= 2, "torch.linalg.lstsq: input must have at least 2 dimensions.");
  TORCH_CHECK(rankWithoutBatchDim(b, b_bdim) >= 1, "torch.linalg.lstsq: other must have at least 1 dimension.");

  const auto batch_size = get_bdim_size2(self, self_bdim, b, b_bdim);
  const auto tensor_other = _binary_pointwise_helper(self, self_bdim, b, b_bdim, /*do_type_promotion=*/false);

  // Because of the ambiguity with the vector case, lstsq can broadcast [1, 2] -> [batch_size, 2]
  // but not [2] -> [batch_size, 2], so we could either unsqueeze when there's no bdim or just ensure_has_bdim.
  const auto self_ = ensure_has_bdim(std::get<0>(tensor_other), self_bdim.has_value(), batch_size);
  const auto b_ = ensure_has_bdim(std::get<1>(tensor_other), b_bdim.has_value(), batch_size);

  auto [res, res_1, res_2, res_3] = at::linalg_lstsq(self_, b_, rcond, driver);

  // Everything but the 0th output is only sometimes computed. When an output isn't
  // computed, it is an empty tensor without a bdim.
  const auto res_1_bdim = batch_dim_if_not_empty(res_1);
  const auto res_2_bdim = batch_dim_if_not_empty(res_2);
  const auto res_3_bdim = batch_dim_if_not_empty(res_3);
  return std::make_tuple(res, 0, res_1, res_1_bdim, res_2, res_2_bdim, res_3, res_3_bdim);
}

template<typename F>
std::tuple<Tensor, std::optional<int64_t>>
atol_rtol_tensor_batch_rule(
    F Func, const Tensor& input, std::optional<int64_t> input_bdim,
    const std::optional<Tensor>& atol, const std::optional<int64_t> atol_bdim,
    const std::optional<Tensor>& rtol, const std::optional<int64_t> rtol_bdim, bool hermitian, char const *op_name) {
  auto input_logical_rank = rankWithoutBatchDim(input, input_bdim);

  TORCH_CHECK(input_logical_rank >= 2,
      op_name, ": The input tensor input must have at least 2 dimensions.");

  // atol and rtol's dims must be broadcastable to the number of batch dims of input,
  // which is input's dim - 2 (input represents a batch of matrices, so 2 is for the matrix dimensions).
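  // For illustration (shapes only): input logical [3, 5, 5] has one matrix-batch
  // dim; a scalar atol (logical rank 0) gets padded below to logical rank 1,
  // i.e. shape [1], so it broadcasts against that batch dim of 3.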
  const auto input_logical_num_bdims = input_logical_rank - 2;
  const int64_t atol_logical_num_bdims = atol.has_value() ? rankWithoutBatchDim(*atol, atol_bdim) : 0;
  const int64_t rtol_logical_num_bdims = rtol.has_value() ? rankWithoutBatchDim(*rtol, rtol_bdim) : 0;
  const auto max_logical_bdims = std::max({input_logical_num_bdims, atol_logical_num_bdims, rtol_logical_num_bdims});

  auto input_ = moveBatchDimToFront(input, input_bdim);
  auto atol_ = atol.has_value() ? moveBatchDimToFront(*atol, atol_bdim) : atol;
  auto rtol_ = rtol.has_value() ? moveBatchDimToFront(*rtol, rtol_bdim) : rtol;

  // pad all inputs to have the same number of (non-vmap) batch dimensions
  input_ = maybePadToLogicalRank(input_, input_bdim, max_logical_bdims + 2);
  atol_ = atol_.has_value() ? maybePadToLogicalRank(*atol_, atol_bdim, max_logical_bdims) : atol_;
  rtol_ = rtol_.has_value() ? maybePadToLogicalRank(*rtol_, rtol_bdim, max_logical_bdims) : rtol_;

  return std::make_tuple(Func(input_, atol_, rtol_, hermitian), 0);
}

static std::tuple<Tensor, std::optional<int64_t>>
pinv_batch_rule(
    const Tensor& input, std::optional<int64_t> input_bdim, const std::optional<Tensor>& atol,
    const std::optional<int64_t> atol_bdim, const std::optional<Tensor>& rtol,
    const std::optional<int64_t> rtol_bdim, bool hermitian) {
  return atol_rtol_tensor_batch_rule(ATEN_FN2(linalg_pinv, atol_rtol_tensor), input, input_bdim, atol, atol_bdim, rtol, rtol_bdim, hermitian, "linalg.pinv");
}

std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, SymInt, SymInt, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
_scaled_dot_product_flash_attention_batch_rule(
    const Tensor& query, std::optional<int64_t> query_bdim,
    const Tensor& key, std::optional<int64_t> key_bdim,
    const Tensor& value, std::optional<int64_t> value_bdim,
    double dropout_p,
    bool is_causal,
    bool return_debug_mask,
    c10::optional<double> scale
  ) {
  if (dropout_p > 0) {
    auto maybe_layer = maybeCurrentDynamicLayer();
    RandomnessType randomness = maybe_layer->randomness();
    check_randomness(randomness, query_bdim.has_value() || key_bdim.has_value() || value_bdim.has_value());
  }
  auto batch_size = get_bdim_size3(query, query_bdim, key, key_bdim, value, value_bdim);
  auto query_ = moveBatchDimToFront(query, query_bdim);
  auto key_ = moveBatchDimToFront(key, key_bdim);
  auto value_ = moveBatchDimToFront(value, value_bdim);
  query_ = ensure_has_bdim(query_, query_bdim.has_value(), batch_size);
  key_ = ensure_has_bdim(key_, key_bdim.has_value(), batch_size);
  value_ = ensure_has_bdim(value_, value_bdim.has_value(), batch_size);
  query_ = query_.flatten(0, 1);
  key_ = key_.flatten(0, 1);
  value_ = value_.flatten(0, 1);
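  // The vmap dim (of size batch_size) and the attention batch dim are flattened
  // into a single leading dim above, so the kernel sees regular 4-D inputs; the
  // batched outputs are split back apart below with reshape_dim_outof.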

  const auto [res0, res1, res2, res3, res4, res5, res6, res7, res8] = at::_scaled_dot_product_flash_attention(
      query_, key_, value_, dropout_p, is_causal, return_debug_mask, scale);

  const auto res0_ = reshape_dim_outof(0, batch_size, res0);
  const auto res1_ = reshape_dim_outof(0, batch_size, res1);
  // res2 and res3 (cum_seq_q and cum_seq_k) are always [0] for dense tensors
  // res4 and res5 (max_q and max_k) are SymInts, so they don't need reshaping
  // res6 and res7 (philox seed and offset) are always non-batched
  const auto res8_ = return_debug_mask ? reshape_dim_outof(0, batch_size, res8) : res8;

  return std::make_tuple(
    res0_, 0,
    res1_, 0,
    res2, std::nullopt,
    res3, std::nullopt,
    res4,
    res5,
    res6, std::nullopt,
    res7, std::nullopt,
    res8_, return_debug_mask ? std::optional<int64_t>(0) : std::nullopt
  );
}

fourOutputs _scaled_dot_product_efficient_attention_batch_rule(
    const Tensor& query, optional<int64_t> query_bdim,
    const Tensor& key, optional<int64_t> key_bdim,
    const Tensor& value, optional<int64_t> value_bdim,
    const std::optional<Tensor>& attn_bias, optional<int64_t> attn_bias_bdim,
    bool compute_log_sumexp,
    double dropout_p,
    bool is_causal,
    c10::optional<double> scale
  ) {
  if (dropout_p > 0) {
    auto maybe_layer = maybeCurrentDynamicLayer();
    RandomnessType randomness = maybe_layer->randomness();
    check_randomness(randomness, query_bdim.has_value() || key_bdim.has_value() || value_bdim.has_value());
  }
  auto batch_size = get_bdim_size3(query, query_bdim, key, key_bdim, value, value_bdim);
  auto query_ = moveBatchDimToFront(query, query_bdim);
  auto key_ = moveBatchDimToFront(key, key_bdim);
  auto value_ = moveBatchDimToFront(value, value_bdim);
  query_ = ensure_has_bdim(query_, query_bdim.has_value(), batch_size);
  key_ = ensure_has_bdim(key_, key_bdim.has_value(), batch_size);
  value_ = ensure_has_bdim(value_, value_bdim.has_value(), batch_size);

  query_ = query_.flatten(0, 1);
  key_ = key_.flatten(0, 1);
  value_ = value_.flatten(0, 1);

  std::optional<Tensor> attn_bias_;
  if (attn_bias.has_value() && attn_bias->defined()) {
    attn_bias_ = attn_bias_bdim.has_value() ? reshape_dim_into(*attn_bias_bdim, 0, attn_bias.value()) : attn_bias.value();
  }
  const auto [res0, res1, res2, res3] = at::_scaled_dot_product_efficient_attention(
      query_, key_, value_, attn_bias_, compute_log_sumexp, dropout_p, is_causal, scale);
  const auto res0_ = reshape_dim_outof(0, batch_size, res0);
  const auto res1_ = reshape_dim_outof(0, batch_size, res1);
  // res2 and res3 (philox seed and offset) are always non-batched
  return std::make_tuple(res0_, 0, res1_, 0, res2, std::nullopt, res3, std::nullopt);
}

// Please unify SDPA APIs!!!
std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, SymInt, SymInt, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
_scaled_dot_product_cudnn_attention_batch_rule(
    const Tensor& query, std::optional<int64_t> query_bdim,
    const Tensor& key, std::optional<int64_t> key_bdim,
    const Tensor& value, std::optional<int64_t> value_bdim,
    const std::optional<Tensor>& attn_bias, std::optional<int64_t> attn_bias_bdim,
    bool compute_log_sumexp,
    double dropout_p,
    bool is_causal,
    bool return_debug_mask,
    c10::optional<double> scale
  ) {
  if (dropout_p > 0) {
    auto maybe_layer = maybeCurrentDynamicLayer();
    RandomnessType randomness = maybe_layer->randomness();
    check_randomness(randomness, query_bdim.has_value() || key_bdim.has_value() || value_bdim.has_value());
  }
  auto batch_size = get_bdim_size3(query, query_bdim, key, key_bdim, value, value_bdim);
  auto query_ = moveBatchDimToFront(query, query_bdim);
  auto key_ = moveBatchDimToFront(key, key_bdim);
  auto value_ = moveBatchDimToFront(value, value_bdim);
  query_ = ensure_has_bdim(query_, query_bdim.has_value(), batch_size);
  key_ = ensure_has_bdim(key_, key_bdim.has_value(), batch_size);
  value_ = ensure_has_bdim(value_, value_bdim.has_value(), batch_size);
  query_ = query_.flatten(0, 1);
  key_ = key_.flatten(0, 1);
  value_ = value_.flatten(0, 1);

  std::optional<Tensor> attn_bias_;
  if (attn_bias.has_value() && attn_bias->defined()) {
    attn_bias_ = attn_bias_bdim.has_value() ? reshape_dim_into(*attn_bias_bdim, 0, attn_bias.value()) : attn_bias.value();
  }

  const auto [res0, res1, res2, res3, res4, res5, res6, res7, res8] = at::_scaled_dot_product_cudnn_attention(
      query_, key_, value_, attn_bias_, compute_log_sumexp, dropout_p, is_causal, return_debug_mask, scale);

  const auto res0_ = reshape_dim_outof(0, batch_size, res0);
  Tensor res1_;
  std::optional<int64_t> res1_bdim;
  if (compute_log_sumexp) {
    res1_ = reshape_dim_outof(0, batch_size, res1);
    res1_bdim = 0;
  } else {
    res1_ = res1;
    res1_bdim = std::nullopt;
  }
  // res2 and res3 (cum_seq_q and cum_seq_k) are always [0] for dense tensors
  // res4 and res5 (max_q and max_k) are SymInts, so they don't need reshaping
  // res6 and res7 (philox seed and offset) are always non-batched
  const auto res8_ = return_debug_mask ? reshape_dim_outof(0, batch_size, res8) : res8;

  return std::make_tuple(
    res0_, 0,
    res1_, res1_bdim,
    res2, std::nullopt,
    res3, std::nullopt,
    res4,
    res5,
    res6, std::nullopt,
    res7, std::nullopt,
    res8_, return_debug_mask ? std::optional<int64_t>(0) : std::nullopt
  );
}

}

#define LINALG_CHECK_MATRIX_UNARY_BATCH_RULE(fn, num_out) SINGLE_ARG(\
  LinalgCheckMatrixUnaryRuleHelper<\
    func_string_##fn,\
    decltype(&ATEN_FN(fn)),\
    &ATEN_FN(fn),\
    c10::guts::function_traits<decltype(ATEN_FN(fn))>::parameter_types>::apply_##num_out)
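
// For example (illustrative expansion), LINALG_CHECK_MATRIX_UNARY_BATCH_RULE(cholesky, one)
// roughly expands to
//   LinalgCheckMatrixUnaryRuleHelper<func_string_cholesky, decltype(&ATEN_FN(cholesky)),
//       &ATEN_FN(cholesky), ...>::apply_one
// presumably wrapped in SINGLE_ARG so the comma-laden expression can be handed
// around as a single macro argument to VMAP_SUPPORT below.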

#define LINALG_CHECK_MATRIX_UNARY_BATCH_RULE2(fn, overload, num_out) SINGLE_ARG(\
  LinalgCheckMatrixUnaryRuleHelper<\
    func_string_##fn_##overload,\
    decltype(&ATEN_FN2(fn, overload)),\
    &ATEN_FN2(fn, overload),\
    c10::guts::function_traits<decltype(ATEN_FN2(fn, overload))>::parameter_types>::apply_##num_out)

#define LINALG_CHECK_MATRIX_BINARY_BATCH_RULE(fn, num_out) SINGLE_ARG(\
  LinalgCheckMatrixBinaryRuleHelper<\
    func_string_##fn,\
    decltype(&ATEN_FN(fn)),\
    &ATEN_FN(fn),\
    c10::guts::function_traits<decltype(ATEN_FN(fn))>::parameter_types>::apply_##num_out)


// Define string constants with the function names. These will be used as template parameters.
// C++ doesn't let us use string literals as template parameters, so we have to declare them as consts first.
// What is going on with these macros?
// - clang-5 seems to require the constexpr
// - windows compiles with or without the constexpr, but the constexpr causes test problems
// - as a result we have some macro guards.
#if defined(_MSC_VER)
#define LINALG_STRING_CONST(fn, op_name) \
  const char func_string_##fn[] = #op_name;\

#define LINALG_STRING_CONST2(fn, overload, op_name) \
  const char func_string_##fn_##overload[] = #op_name;\

#else
#define LINALG_STRING_CONST(fn, op_name) \
  constexpr const char func_string_##fn[] = #op_name;\

#define LINALG_STRING_CONST2(fn, overload, op_name) \
  constexpr const char func_string_##fn_##overload[] = #op_name;\

#endif
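
// For example (illustrative expansion), LINALG_STRING_CONST(cholesky, cholesky)
// declares `constexpr const char func_string_cholesky[] = "cholesky";` on
// non-MSVC builds (without the constexpr on MSVC), which is what the helper
// templates above take as their op_name parameter.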

#define LINALG_CHECK_MATRIX_UNARY_ONE_OUT(fn, op_name) \
  LINALG_STRING_CONST(fn, op_name);\
  TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {\
    VMAP_SUPPORT(fn, LINALG_CHECK_MATRIX_UNARY_BATCH_RULE(fn, one));\
  }

#define LINALG_CHECK_MATRIX_UNARY_ONE_OUT2(fn, overload, op_name) \
  LINALG_STRING_CONST2(fn, overload, op_name);\
  TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {\
    VMAP_SUPPORT2(fn, overload, LINALG_CHECK_MATRIX_UNARY_BATCH_RULE2(fn, overload, one));\
  }

#define LINALG_CHECK_MATRIX_UNARY_TWO_OUT(fn, op_name) \
  LINALG_STRING_CONST(fn, op_name);\
  TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {\
    VMAP_SUPPORT(fn, LINALG_CHECK_MATRIX_UNARY_BATCH_RULE(fn, two));\
  }

#define LINALG_CHECK_MATRIX_UNARY_THREE_OUT(fn, op_name) \
  LINALG_STRING_CONST(fn, op_name);\
  TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {\
    VMAP_SUPPORT(fn, LINALG_CHECK_MATRIX_UNARY_BATCH_RULE(fn, three));\
  }

#define LINALG_CHECK_MATRIX_UNARY_FOUR_OUT(fn, op_name) \
  LINALG_STRING_CONST(fn, op_name);\
  TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {\
    VMAP_SUPPORT(fn, LINALG_CHECK_MATRIX_UNARY_BATCH_RULE(fn, four));\
  }

#define LINALG_CHECK_MATRIX_BINARY_ONE_OUT(fn, op_name) \
  LINALG_STRING_CONST(fn, op_name);\
  TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {\
    VMAP_SUPPORT(fn, LINALG_CHECK_MATRIX_BINARY_BATCH_RULE(fn, one));\
  }

#define LINALG_CHECK_MATRIX_BINARY_TWO_OUT(fn, op_name) \
  LINALG_STRING_CONST(fn, op_name);\
  TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {\
    VMAP_SUPPORT(fn, LINALG_CHECK_MATRIX_BINARY_BATCH_RULE(fn, two));\
  }

// These need to be outside: a string constant must be declared outside of a macro to be usable as a template param.
// NOLINTBEGIN(*array*)
LINALG_CHECK_MATRIX_UNARY_ONE_OUT(cholesky, cholesky);
LINALG_CHECK_MATRIX_UNARY_ONE_OUT(cholesky_inverse, cholesky_inverse);
LINALG_CHECK_MATRIX_UNARY_TWO_OUT(linalg_cholesky_ex, linalg.cholesky);
LINALG_CHECK_MATRIX_UNARY_TWO_OUT(linalg_eig, linalg.eig);
LINALG_CHECK_MATRIX_UNARY_TWO_OUT(linalg_inv_ex, linalg.inv_ex);
LINALG_CHECK_MATRIX_UNARY_THREE_OUT(linalg_ldl_factor_ex, torch.linalg.ldl_factor_ex);
LINALG_CHECK_MATRIX_UNARY_TWO_OUT(linalg_qr, linalg.qr);
LINALG_CHECK_MATRIX_UNARY_TWO_OUT(linalg_slogdet, linalg.slogdet);
LINALG_CHECK_MATRIX_BINARY_ONE_OUT(linalg_solve_triangular, linalg.solve_triangular);

LINALG_CHECK_MATRIX_UNARY_TWO_OUT(geqrf, geqrf);
LINALG_CHECK_MATRIX_BINARY_TWO_OUT(triangular_solve, triangular_solve);
LINALG_CHECK_MATRIX_UNARY_THREE_OUT(_linalg_det, linalg.det);
LINALG_CHECK_MATRIX_UNARY_TWO_OUT(_linalg_eigh, linalg.eigh);
LINALG_CHECK_MATRIX_UNARY_FOUR_OUT(_linalg_slogdet, linalg.slogdet);
LINALG_CHECK_MATRIX_UNARY_THREE_OUT(_linalg_svd, linalg.svd);
// NOLINTEND(*array*)

TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  VMAP_SUPPORT(bmm, bmm_batch_rule);
  m.impl("addmv", addmv_decomp);
  m.impl("addmm", addmm_decomp);
  m.impl("addbmm", addbmm_decomp);
  m.impl("baddbmm", baddbmm_decomp);
  VMAP_SUPPORT(dot, dot_batch_rule);
  VMAP_SUPPORT(mv, mv_batch_rule);
  VMAP_SUPPORT(mm, mm_batch_rule);
  VMAP_SUPPORT(lu_unpack, linalg_lu_unpack_batch_rule);
  VMAP_SUPPORT(linalg_lu_solve, linalg_lu_solve_batch_rule);
  VMAP_SUPPORT(linalg_householder_product, householder_product_batch_rule);
  VMAP_SUPPORT(cholesky_solve, cholesky_solve_batch_rule);  // custom dim error
  VMAP_SUPPORT(linalg_lstsq, linalg_lstsq_batch_rule);  // custom errors and sometimes empty return
  VMAP_SUPPORT(linalg_lu_factor_ex, linalg_lu_factor_ex_batch_rule);
  VMAP_SUPPORT(linalg_matrix_exp, matrix_exp_batch_rule);
  VMAP_SUPPORT(_linalg_solve_ex, solve_ex_batch_rule);
  VMAP_SUPPORT(linalg_cross, cross_batch_rule);
  VMAP_SUPPORT2(linalg_pinv, atol_rtol_tensor, pinv_batch_rule);
  VMAP_SUPPORT(_scaled_dot_product_efficient_attention, _scaled_dot_product_efficient_attention_batch_rule);

  VMAP_SUPPORT(_scaled_dot_product_flash_attention, _scaled_dot_product_flash_attention_batch_rule);
  VMAP_SUPPORT(_scaled_dot_product_cudnn_attention, _scaled_dot_product_cudnn_attention_batch_rule);

  VMAP_SUPPORT(_linalg_check_errors, _linalg_check_errors_batch_rule);

  m.impl("vdot", vdot_decomp);
}
} // namespace at::functorch