// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>
#include <utility>

#include <ATen/Operators.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/functorch/BatchedFallback.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/TensorBody.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/util/SmallBuffer.h>
#include <ATen/InferSize.h>

namespace at::functorch {

// Note [Adding vmap support for an operator]
// Hey there! So you have an operator and you want to get it to work with vmap.
// For example, let's say you just invented the `sum.int` operator and want to make
// it so that the following works.
// >>> tensor = torch.randn(B, 3)
// >>> vmap(torch.sum, (0, None))(tensor, 0)
// There are three main ways to do so.
//
// Note [Writing batch rule for out-of-place operators]
// If your operator is out-of-place, you can write a batch rule for it.
// The batch rule defines how to perform the operator on inputs where each
// Tensor input may have an additional dimension that is being vmapped over.
// We refer to this dimension as the *batch dimension* or bdim for short.
//
// For example, let's consider writing a batch rule for
// `Tensor sum(const Tensor& self, int64_t dim)`. The signature of the
// batch rule has an additional std::optional<int64_t> argument after each
// Tensor argument and return. So, in this case, the batch rule has signature
//   tuple<Tensor, std::optional<int64_t>> sum_batch_rule(
//       const Tensor& self, std::optional<int64_t> self_bdim, int64_t dim);
//
// The vmap call above invokes the batch rule with `self = tensor`,
// `self_bdim = 0`, and `dim = 0`. Note that there are **no BatchedTensors**
// involved in this case; there exists some plumbing that automatically unwraps
// BatchedTensors before calling the batch rule.
//
// To write the logic of the batch rule: think about the semantics of the
// `sum` operation if `self` had an additional dimension (indicated by self_bdim):
// - If `self_bdim` is null, then we just do `result = self.sum(dim)` as usual.
// - If `self_bdim` is not-null, then we need to modify `dim`. `dim` is equal
//   to whatever the user passed in (0 in this case), but we should actually
//   perform the reduction over dimension 1 and do `result = self.sum(1)`
//   because dim 0 is being vmapped over.
// Finally, we return the result as well as a new bdim:
// - If `self_bdim` is null, then there's no batch dim in the result.
// - If `self_bdim` is not-null, then we return where the bdim is.
//   Since we invoked `result = self.sum(1)`, the bdim is still at dim 0.
//
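// For illustration only, a minimal sketch of such a batch rule
// (`sum_batch_rule` is a hypothetical example, not an operator handled in this
// file), using the same helpers from BatchRulesHelper.h that the rules below use:
//   std::tuple<Tensor, std::optional<int64_t>> sum_batch_rule(
//       const Tensor& self, std::optional<int64_t> self_bdim, int64_t dim) {
//     if (!self_bdim.has_value()) {
//       return std::make_tuple(self.sum(dim), std::nullopt);
//     }
//     // Move the batch dim to the front, then shift the user's dim past it.
//     auto self_ = moveBatchDimToFront(self, self_bdim);
//     auto logical_rank = rankWithoutBatchDim(self, self_bdim);
//     dim = maybe_wrap_dim(dim, logical_rank) + 1;
//     return std::make_tuple(self_.sum(dim), 0);
//   }
//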
// Now that we have written `sum_batch_rule`, we have to register it inside a
// TORCH_LIBRARY_IMPL block:
//   TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
//     ...
//     VMAP_SUPPORT2(sum, int, sum_batch_rule);
//     ...
//   }
//
// Note [Reusing batch rules to add vmap support for a complicated operator]
// Can't figure out how to write a batch rule for a big operation? If the
// operation can be expressed as a composition of other operations that do have
// batch rules, then that is another way to add vmap support. For example,
// consider the following schema
//   func: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1)
// and assume we already have batching rules for basic arithmetic operators.
//
// To add vmap support, define a decomposition using the same signature:
//   Tensor addcmul_decomp(const Tensor& self, const Tensor& tensor1,
//                         const Tensor& tensor2, const Scalar& value) {
//     auto product = at::mul(tensor1, tensor2);
//     return at::add(self, product, value);
//   }
// And register it inside a TORCH_LIBRARY_IMPL block:
//   TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
//     ...
//     m.impl("addcmul", addcmul_decomp);
//     ...
//   }
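//
// With that registration in place, a call like
// >>> vmap(torch.addcmul)(self, tensor1, tensor2)
// should decompose into mul and add, whose batch rules then handle the
// vmapped dimension.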
//
// Note [Writing batch rule for in-place operators]
// TODO: This is kinda complicated. Saving this for a future date.

namespace {

std::tuple<Tensor, std::optional<int64_t>> unsqueeze_batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    int64_t dim) {
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto rank = rankWithoutBatchDim(self, self_bdim);
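  // Wrap `dim` against the logical rank (unsqueeze accepts rank + 1 insert
  // positions), then shift by one to skip the leading batch dimension.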
  dim = maybe_wrap_dim(dim, rank + 1) + 1;
  return std::make_tuple(self_.unsqueeze(dim), 0);
}

// NB: repeat is not actually a view, but it is in this file
std::tuple<Tensor, std::optional<int64_t>> repeat_batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    c10::SymIntArrayRef sizes) {

  SymDimVector sizes_with_bdim = { sizes.begin(), sizes.end() };
  sizes_with_bdim.insert(sizes_with_bdim.begin(), 1);
  auto self_ = moveBatchDimToFront(self, self_bdim);
  while (self_.dim() < (int64_t)sizes_with_bdim.size()) {
    self_ = self_.unsqueeze(1);
  }
  return std::make_tuple(self_.repeat_symint(sizes_with_bdim), 0);
}


std::tuple<Tensor, std::optional<int64_t>> _unsafe_view_batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    c10::SymIntArrayRef size) {
  auto self_ = moveBatchDimToFront(self, self_bdim);
  SymDimVector view_size(size);
  view_size.insert(view_size.begin(), self_.sym_size(0));

  // See if the view is valid. If it's not, then we copy.
  // It's OK to copy, because _unsafe_view(x) guarantees that x isn't used
  // anymore.
  const at::SymDimVector inferred_size = at::infer_size_dv(view_size, self_.sym_numel());
  const auto stride = at::detail::computeStride(self_.sym_sizes(),
                                                self_.sym_strides(),
                                                inferred_size);
  if (!stride.has_value()) {
    self_ = self_.contiguous();
  }
  return std::make_tuple(at::_unsafe_view_symint(self_, view_size), 0);
}

std::tuple<Tensor, std::optional<int64_t>> flip_batch_rule(const Tensor& self, std::optional<int64_t> self_bdim, IntArrayRef dims) {
  auto self_ = moveBatchDimToFront(self, self_bdim);
  VmapDimVector new_dims;
  for (auto i: dims) {
    new_dims.push_back(getPhysicalDim(self_, true, i));
  }
  return std::make_tuple(at::flip(self_, new_dims), 0);
}

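// resize_ needs custom plumbing rather than a plain batch rule: it mutates the
// underlying tensor and must refresh the BatchedTensorImpl wrapper's metadata,
// so it unwraps at the current vmap level itself (registered via m.impl below).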
const Tensor& resize__plumbing(
    const Tensor& self,
    IntArrayRef size,
    std::optional<MemoryFormat> optional_memory_format) {
  TORCH_CHECK(
      !optional_memory_format.has_value() ||
      optional_memory_format == c10::MemoryFormat::Contiguous,
      "resize_: batching rule only supports None or Contiguous MemoryFormat");
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "resize__plumbing");
  int64_t cur_level = maybe_layer->layerId();
  if (!isBatchedAtLevel(self, cur_level)) {
    c10::impl::ExcludeDispatchKeyGuard guard2(DispatchKey::FuncTorchBatched);
    return self.resize_(size, optional_memory_format);
  }

  auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level);
  TORCH_INTERNAL_ASSERT(self_bdim.has_value());

  // TODO: The following algorithm only works for batch dim == 0.
  // To get it to work for something else we need the ability to modify
  // the BatchDims attribute of BatchedTensorImpl
  TORCH_INTERNAL_ASSERT(self_bdim.value() == 0, "NYI: resize_ batch rule for batch dim != 0");

  // Resize the wrapped tensor
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  self_value = moveBatchDimToFront(self_value, self_bdim);
  VmapDimVector new_size(size);
  new_size.insert(new_size.begin(), self_value.size(*self_bdim));
  self_value.resize_(new_size);

  // Update the sizes and strides of the wrapper
  auto* batched = maybeGetBatchedImpl(self);
  TORCH_INTERNAL_ASSERT(batched);
  batched->refreshTensorMetadata();

  return self;
}

std::tuple<Tensor, std::optional<int64_t>> squeeze_batch_rule(const Tensor& self, std::optional<int64_t> bdim) {
  TORCH_INTERNAL_ASSERT(bdim.has_value());
  // Special case for scalar tensors to replicate PyTorch behavior.
  if (self.dim() == 1) {
    return std::make_tuple(self.alias(), bdim);
  }

  // Manually calculate the output shape by eliding all dimensions of
  // size 1, keeping track of where the batch index started and where it
  // ended up moving to. We also ensure we do not drop the batch index.
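  // For example, with physical sizes [1, 3, B, 1] and bdim = 2, we keep
  // sizes [3, B] and the batch index moves from 2 to 1.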
  auto shape = self.sym_sizes();
  SymDimVector squeezed_sizes;
  bool before_batch_idx = true;
  int64_t new_batch_idx = 0;
  int64_t original_idx = 0;

  for (const auto& it : shape) {
    // Keep only dimensions != 1 and the batch dimension (irrespective of size).
    if (it != 1 || original_idx == bdim) {
      squeezed_sizes.push_back(it);
      if (original_idx == bdim) {
        before_batch_idx = false;
      }
      // Only increment for the dimensions that will be kept in the output.
      if (before_batch_idx) {
        ++new_batch_idx;
      }
    }
    ++original_idx;
  }

  auto result = self.view_symint(squeezed_sizes);
  return std::make_tuple(std::move(result), std::optional<int64_t>(new_batch_idx));
}

std::tuple<Tensor, std::optional<int64_t>> squeeze_dims_batch_rule(
    const Tensor& self, std::optional<int64_t> bdim, IntArrayRef dims) {
  TORCH_INTERNAL_ASSERT(bdim.has_value());
  // Special case for scalar tensors to replicate PyTorch behavior.
  auto ndim = self.dim();
  if (ndim == 1) {
    TORCH_CHECK(
        dims.empty() || (dims.size() == 1 && dims[0] == 0),
        "Dimension is out of range (expected to be in range of [-1, 0], but got ", dims);
    return std::make_tuple(self.alias(), bdim);
  }

  // Remap each given (logical) dim to its physical dim, tracking where the
  // batch dimension ends up after the squeeze.
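  // For example, with physical sizes [1, B, 3, 1], bdim = 1, and logical
  // dims (0, 2): logical dim 0 stays physical dim 0 (and, being size 1, it
  // will be squeezed away, so the batch index shifts down to 0), while
  // logical dim 2 maps to physical dim 3.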
  DimVector adjusted_dims(dims.begin(), dims.end());
  int64_t updated_batch_idx = *bdim;
  for (auto &d : adjusted_dims) {
    auto actual_dim = c10::maybe_wrap_dim(d, ndim - 1);
    if (actual_dim < *bdim) {
      d = actual_dim;
      if (self.sym_size(actual_dim) == 1) {
        // A size-1 dimension before the batch dimension will be dropped, so the
        // batch index shifts down accordingly.
        --updated_batch_idx;
      }
    } else {
      // Since the dimension to be squeezed is at or after the batch dimension,
      // adjust it by one to account for the batch dimension. In this case the
      // batch dimension won't move.
      d = actual_dim + 1;
    }
  }
  return std::make_tuple(self.squeeze(adjusted_dims), std::optional<int64_t>(updated_batch_idx));
}

std::tuple<Tensor, std::optional<int64_t>> squeeze_dim_batch_rule(
    const Tensor& self, std::optional<int64_t> bdim, int64_t dim) {
  return squeeze_dims_batch_rule(self, bdim, {dim});
}

std::tuple<Tensor, std::optional<int64_t>> select_batching_rule(const Tensor& self, std::optional<int64_t> bdim, int64_t dim, c10::SymInt index) {
  if (!bdim) {
    return std::make_tuple(self.select_symint(dim, std::move(index)), std::nullopt);
  }

  auto _self = moveBatchDimToFront(self, bdim);
  auto dim_physical = getPhysicalDim(_self, true, dim);
  auto result = _self.select_symint(dim_physical, std::move(index));
  return std::make_tuple(std::move(result), 0);
}

std::tuple<Tensor, std::optional<int64_t>> _reshape_alias_batch_rule(const Tensor& self, std::optional<int64_t> bdim, const c10::SymIntArrayRef shape, const c10::SymIntArrayRef strides) {
  (void) strides;
  TORCH_INTERNAL_ASSERT(bdim.has_value());

  auto self_ = moveBatchDimToFront(self, bdim);
  c10::SymDimVector new_shape(shape.size() + 1);
  new_shape[0] = self_.sym_size(0);
  std::copy(shape.begin(), shape.end(), new_shape.begin() + 1);
  return std::make_tuple(at::reshape_symint(self_, new_shape), 0);
}

std::tuple<Tensor, std::optional<int64_t>> roll_batch_rule(const Tensor& self, std::optional<int64_t> bdim, SymIntArrayRef shifts, IntArrayRef dims) {
  TORCH_INTERNAL_ASSERT(bdim.has_value());

  auto self_ = moveBatchDimToFront(self, bdim);
  VmapDimVector new_dims;
  if (!dims.empty()) {
    for (auto i: dims) {
      new_dims.push_back(getPhysicalDim(self, true, i));
    }
    return std::make_tuple(at::roll_symint(self_, shifts, new_dims), 0);
  }
  // We will do something like: t.reshape(a, -1).roll(1, dims=[1, ]).reshape(old_shape)
  auto old_shape = self_.sym_sizes();
  new_dims.push_back(1);
  auto logical_rank = rankWithoutBatchDim(self, bdim);
  if (logical_rank == 0) {
    self_ = self_.unsqueeze(0);
  }

  auto output = at::roll_symint(self_.flatten(1), shifts, new_dims);
  // NOTE: For scalar tensor, we don't need to unsqueeze as reshape
  // with `old_shape` takes care of it.
  output = output.reshape_symint(old_shape);
  return std::make_tuple(output, 0);
}

std::tuple<Tensor, std::optional<int64_t>> diagonal_batching_rule(
    const Tensor &self, std::optional<int64_t> self_bdim,
    int64_t offset, int64_t dim1, int64_t dim2)
{
  auto logical_rank = rankWithoutBatchDim(self, self_bdim);
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto dim1_ = maybe_wrap_dim(dim1, logical_rank) + 1;
  auto dim2_ = maybe_wrap_dim(dim2, logical_rank) + 1;
  auto result = at::diagonal(self_, offset, dim1_, dim2_);
  return std::make_tuple(std::move(result), 0);
}

std::tuple<Tensor, std::optional<int64_t>> diagonal_backward_batch_rule(
    const Tensor& grad_input, std::optional<int64_t> grad_input_bdim,
    c10::SymIntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
  auto logical_rank = rankWithoutBatchDim(grad_input, grad_input_bdim);
  auto grad_input_ = moveBatchDimToFront(grad_input, grad_input_bdim);
  dim1 = maybe_wrap_dim(dim1, logical_rank + 1) + 1;
  dim2 = maybe_wrap_dim(dim2, logical_rank + 1) + 1;
  c10::SymDimVector input_sizes_(input_sizes.size() + 1);
  input_sizes_[0] = grad_input_.size(0);
  std::copy(input_sizes.begin(), input_sizes.end(), input_sizes_.begin() + 1);
  auto result = at::diagonal_backward_symint(grad_input_, input_sizes_, offset, dim1, dim2);
  return std::make_tuple(std::move(result), 0);
}

std::tuple<Tensor, std::optional<int64_t>> slice_batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    int64_t dim,
    std::optional<c10::SymInt> start,
    std::optional<c10::SymInt> end,
    c10::SymInt step) {
  auto self_ = moveBatchDimToFront(self, self_bdim);
  dim = getPhysicalDim(self, self_bdim.has_value(), dim);

  auto result = self_.slice_symint(dim, std::move(start), std::move(end), std::move(step));
  return std::make_tuple(std::move(result), 0);
}

static bool is_allowed_dim_on_scalar_tensor(int64_t dim) {
  return dim == 0 || dim == -1;
}

std::tuple<Tensor, std::optional<int64_t>>
transpose_int_batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    int64_t dim0,
    int64_t dim1) {
  // PyTorch has a special case where scalar_tensor.transpose(dim0, dim1) works
  // for dim0, dim1 in {0, -1} and returns the scalar tensor. If the following happens:
  // >>> x = torch.randn(B0)  # the per-examples are all scalars
  // >>> vmap(lambda x: x.transpose(0, -1))(x)
  // then we replicate this behavior.
  if (/*physical*/self.dim() == 1 && is_allowed_dim_on_scalar_tensor(dim0) &&
      is_allowed_dim_on_scalar_tensor(dim1)) {
    return std::make_tuple(self, self_bdim);
  }
  auto self_ = moveBatchDimToFront(self, self_bdim);
  dim0 = getPhysicalDim(self, self_bdim.has_value(), dim0);
  dim1 = getPhysicalDim(self, self_bdim.has_value(), dim1);
  auto result = self_.transpose(dim0, dim1);
  return std::make_tuple(std::move(result), 0);
}

std::tuple<Tensor, std::optional<int64_t>> permute_batching_rule(
    const Tensor &self, std::optional<int64_t> self_bdim, IntArrayRef dims)
{
  if (!self_bdim.has_value()) {
    return std::make_tuple(self.permute(dims), self_bdim);
  }

  auto self_ = moveBatchDimToFront(self, self_bdim);
  VmapDimVector dims_;
  dims_.reserve(dims.size() + 1);
  dims_.emplace_back(0);
  for (auto dim : dims) {
    dims_.emplace_back(getPhysicalDim(self_, self_bdim.has_value(), dim));
  }

  return std::make_tuple(self_.permute(dims_), 0);
}

std::tuple<Tensor, std::optional<int64_t>> select_backward_batch_rule(
    const Tensor& grad_input, std::optional<int64_t> grad_input_bdim,
    c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index) {
  auto logical_rank = rankWithoutBatchDim(grad_input, grad_input_bdim);
  auto grad_input_ = moveBatchDimToFront(grad_input, grad_input_bdim);
  dim = maybe_wrap_dim(dim, logical_rank + 1) + 1;
  c10::SymDimVector input_sizes_(input_sizes.size() + 1);
  input_sizes_[0] = grad_input_.sym_size(0);
  std::copy(input_sizes.begin(), input_sizes.end(), input_sizes_.begin() + 1);
  auto result = at::select_backward_symint(grad_input_, input_sizes_, dim, std::move(index));
  return std::make_tuple(std::move(result), 0);
}

std::tuple<Tensor, std::optional<int64_t>> slice_backward_batch_rule(
    const Tensor& grad_input, std::optional<int64_t> grad_input_bdim,
    SymIntArrayRef input_sizes, int64_t dim, c10::SymInt start, c10::SymInt end, c10::SymInt step) {
  auto logical_rank = rankWithoutBatchDim(grad_input, grad_input_bdim);
  auto grad_input_ = moveBatchDimToFront(grad_input, grad_input_bdim);
  dim = maybe_wrap_dim(dim, logical_rank) + 1;
  c10::SymDimVector input_sizes_(input_sizes.size() + 1);
  input_sizes_[0] = grad_input_.size(0);
  std::copy(input_sizes.begin(), input_sizes.end(), input_sizes_.begin() + 1);
  auto result = at::slice_backward_symint(grad_input_, input_sizes_, dim, std::move(start), std::move(end), std::move(step));
  return std::make_tuple(std::move(result), 0);
}

std::tuple<Tensor, std::optional<int64_t>> view_batching_rule(
    const Tensor &self, std::optional<int64_t> self_bdim, SymIntArrayRef sym_size)
{
  TORCH_INTERNAL_ASSERT(self_bdim.has_value());
  auto self_ = moveBatchDimToFront(self, self_bdim);
  c10::SmallVector<c10::SymInt> size_(sym_size.size() + 1);
  // copy batch size
  size_[0] = self_.sym_size(0);
  std::copy(sym_size.cbegin(), sym_size.cend(), size_.begin() + 1);
  return std::make_tuple(self_.view_symint(size_), 0);
}

std::tuple<Tensor, std::optional<int64_t>> view_copy_batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    c10::SymIntArrayRef size) {
  auto self_ = moveBatchDimToFront(self, self_bdim);
  SymDimVector view_size(size.size() + 1);
  view_size[0] = self_.size(0);
  std::copy(size.cbegin(), size.cend(), view_size.begin() + 1);

  return std::make_tuple(at::view_copy_symint(self_, view_size), 0);
}


template <typename F, F Func>
std::tuple<Tensor, std::optional<int64_t>> expand_batch_rule(
    const Tensor &self, std::optional<int64_t> self_bdim, SymIntArrayRef size, bool implicit)
{
  auto self_dim = self.dim();
  TORCH_CHECK(static_cast<uint64_t>(self_dim - 1) <= size.size(),
              "expand: the number of sizes provided (", size.size(), ") ",
              "must be greater or equal to the number of dimensions in the tensor (", static_cast<uint64_t>(self_dim - 1), ")");

  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto self_sizes = self_.sym_sizes();
  const auto& batch_size = self_sizes[0];

  c10::SmallVector<c10::SymInt> size_(size.size() + 1);
  size_[0] = batch_size;
  std::copy(size.cbegin(), size.cend(), size_.begin() + 1);

  // Here, we know we are expanding a (logical) tensor to a larger number
  // of dimensions. We have to be careful because we can't call expand directly
  // due to the presence of batch dimensions.
  //
  // As an example, let B0 be a batch dimension and consider expand(Tensor[B0, 3], [2, 3]).
  // The result should be a tensor of size [B0, 2, 3].
  // A physical view of size [B0, 3] can't directly be expanded to size [B0, 2, 3]
  // so the strategy here is to view it first as a tensor of size [B0, 1, 3] and
  // then expand.
  auto extra_dims = size.size() - (self_dim - 1);
  c10::SmallVector<c10::SymInt> view_shape(size_.size(), /*init_value*/1);
  view_shape[0] = batch_size;
  std::copy(self_sizes.cbegin() + 1, self_sizes.cend(),
            view_shape.begin() + 1 + extra_dims);

  return std::make_tuple(Func(self_.view_symint(view_shape), size_, implicit), 0);
}

std::tuple<Tensor, std::optional<int64_t>> unfold_batch_rule(
    const Tensor &self, std::optional<int64_t> self_bdim, int64_t dim, int64_t size, int64_t step)
{
  TORCH_INTERNAL_ASSERT(self_bdim.has_value());
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto logical_rank = rankWithoutBatchDim(self, self_bdim);
  dim = maybe_wrap_dim(dim, logical_rank) + 1;
  if (logical_rank == 0) {
    self_ = self_.unsqueeze(-1);
  }
  auto result = self_.unfold(dim, size, step);
  if (logical_rank == 0) {
    result = result.squeeze(-1);
  }
  return std::make_tuple(std::move(result), 0);
}

std::tuple<Tensor, std::optional<int64_t>> narrow_copy_batch_rule(
    const Tensor &self, std::optional<int64_t> self_bdim, int64_t dim, c10::SymInt start, c10::SymInt length)
{
  TORCH_INTERNAL_ASSERT(self_bdim.has_value());
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto logical_rank = rankWithoutBatchDim(self, self_bdim);
  dim = maybe_wrap_dim(dim, logical_rank) + 1;
  auto result = self_.narrow_copy_symint(dim, std::move(start), std::move(length));

  return std::make_tuple(std::move(result), 0);
}

std::tuple<std::vector<Tensor>, std::optional<int64_t>> unsafe_split_batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    c10::SymInt split_size,
    int64_t dim) {
  TORCH_INTERNAL_ASSERT(self_bdim.has_value());
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto logical_rank = rankWithoutBatchDim(self, self_bdim);
  dim = maybe_wrap_dim(dim, logical_rank) + 1;
  auto result = self_.unsafe_split_symint(std::move(split_size), dim);
  return std::make_tuple(std::move(result), 0);
}

std::tuple<Tensor, std::optional<int64_t>> diag_embed_batch_rule(const Tensor& self, std::optional<int64_t> self_bdim, int64_t offset, int64_t dim1, int64_t dim2) {
  auto logical_rank = rankWithoutBatchDim(self, self_bdim);
  auto self_ = moveBatchDimToFront(self, self_bdim);
  dim1 = maybe_wrap_dim(dim1, logical_rank + 1) + 1;
  dim2 = maybe_wrap_dim(dim2, logical_rank + 1) + 1;
  return std::make_tuple(at::diag_embed(self_, offset, dim1, dim2), 0);
}

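// trace gets vmap support via a decomposition into diagonal + sum rather than a
// dedicated batch rule; see Note [Reusing batch rules to add vmap support for a
// complicated operator] above. It is registered with m.impl("trace", ...) below.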
Tensor trace_decomp(const Tensor& tensor) {
  TORCH_CHECK(tensor.dim() == 2, "trace: expected a matrix, but got tensor with dim ", tensor.dim());
  return tensor.diagonal().sum();
}

std::tuple<Tensor, std::optional<int64_t>> tril_batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    int64_t diagonal = 0) {
  TORCH_CHECK(self.dim() >= 2, "tril: The input tensor must have at least 2 dimensions.");
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto result = at::tril(self_, diagonal);
  return std::make_tuple(std::move(result), 0);
}

std::tuple<Tensor, std::optional<int64_t>> triu_batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    int64_t diagonal = 0) {
  TORCH_CHECK(self.dim() >= 2, "triu: The input tensor must have at least 2 dimensions.");
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto result = at::triu(self_, diagonal);
  return std::make_tuple(std::move(result), 0);
}

} // namespace

TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  VMAP_SUPPORT(flip, flip_batch_rule);
  m.impl("trace", trace_decomp);
  VMAP_SUPPORT(tril, tril_batch_rule);
  VMAP_SUPPORT(triu, triu_batch_rule);
  VMAP_SUPPORT(repeat, repeat_batch_rule);
  VMAP_SUPPORT(_unsafe_view, _unsafe_view_batch_rule);
  VMAP_SUPPORT(unsqueeze, unsqueeze_batch_rule);
  m.impl("resize_", resize__plumbing);
  VMAP_SUPPORT2(select, int, select_batching_rule);
  VMAP_SUPPORT(squeeze, squeeze_batch_rule);
  VMAP_SUPPORT2(squeeze, dim, squeeze_dim_batch_rule);
  VMAP_SUPPORT2(squeeze, dims, squeeze_dims_batch_rule);
  VMAP_SUPPORT(_reshape_alias, _reshape_alias_batch_rule);
  VMAP_SUPPORT(roll, roll_batch_rule);
  VMAP_SUPPORT(permute, permute_batching_rule);
  VMAP_SUPPORT(diagonal, diagonal_batching_rule);
  VMAP_SUPPORT(diagonal_backward, diagonal_backward_batch_rule);
  VMAP_SUPPORT(select_backward, select_backward_batch_rule);
  VMAP_SUPPORT(slice_backward, slice_backward_batch_rule);
  VMAP_SUPPORT(view, view_batching_rule);
  VMAP_SUPPORT(view_copy, view_copy_batch_rule);
  VMAP_SUPPORT(expand, SINGLE_ARG(expand_batch_rule<decltype(&ATEN_FN(expand)), &ATEN_FN(expand)>));
  VMAP_SUPPORT(expand_copy, SINGLE_ARG(expand_batch_rule<decltype(&ATEN_FN(expand_copy)), &ATEN_FN(expand_copy)>));
  VMAP_SUPPORT(unfold, unfold_batch_rule);
  VMAP_SUPPORT2(slice, Tensor, slice_batch_rule);
  VMAP_SUPPORT2(transpose, int, transpose_int_batch_rule);
  m.impl("t", native::t);  // CompositeExplicitAutograd, should not go in BatchRulesDecompositions.cpp
  m.impl("t_", native::t_);  // CompositeExplicitAutograd, should not go in BatchRulesDecompositions.cpp
  VMAP_SUPPORT(diag_embed, diag_embed_batch_rule);
  VMAP_SUPPORT(narrow_copy, narrow_copy_batch_rule);
  VMAP_SUPPORT2(unsafe_split, Tensor, unsafe_split_batch_rule);
}

} // namespace at::functorch