// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>
#include <utility>

#include <ATen/Operators.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/functorch/BatchedFallback.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/TensorBody.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/util/SmallBuffer.h>
#include <ATen/InferSize.h>

namespace at::functorch {

// Note [Adding vmap support for an operator]
// Hey there! So you have an operator and you want to get it to work with vmap.
// For example, let's say you just invented the `sum.int` operator and want to make
// it so that the following works:
// >>> tensor = torch.randn(B, 3)
// >>> vmap(torch.sum, (0, None))(tensor, 0)
// There are three main ways to do so.
//
// Note [Writing batch rule for out-of-place operators]
// If your operator is out-of-place, you can write a batch rule for it.
// The batch rule defines how to perform the operator on inputs where each
// Tensor input may have an additional dimension that is being vmapped over.
// We refer to this dimension as the *batch dimension* or bdim for short.
//
// For example, let's consider writing a batch rule for
// `Tensor sum(const Tensor& self, int64_t dim)`. The signature of the
// batch rule has an additional std::optional<int64_t> argument after each
// Tensor argument and return. So, in this case, the batch rule has signature
//   tuple<Tensor, std::optional<int64_t>> sum_batch_rule(
//       const Tensor& self, std::optional<int64_t> self_bdim, int64_t dim);
//
// The vmap call above invokes the batch rule with `self = tensor`,
// `self_bdim = 0`, and `dim = 0`. Note that there are **no BatchedTensors**
// involved in this case; there exists some plumbing that automatically unwraps
// BatchedTensors before calling the batch rule.
//
// To write the logic of the batch rule: think about the semantics of the
// `sum` operation if `self` had an additional dimension (indicated by self_bdim):
// - If `self_bdim` is null, then we just do `result = self.sum(dim)` as usual
// - If `self_bdim` is not-null, then we need to modify `dim`. `dim` is equal
//   to whatever the user passed in (0 in this case), but we should actually
//   perform the reduction over dimension 1 and do `result = self.sum(1)`
//   because dim 0 is being vmapped over.
// Finally, we return the result as well as a new bdim
// - If `self_bdim` is null, then there's no batch dim in the result.
// - If `self_bdim` is not-null, then we return where the bdim is.
//   Since we invoked `result = self.sum(1)`, the bdim is still at dim 0.
//
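// A minimal sketch of what such a batch rule could look like, using the
// helpers from BatchRulesHelper.h (illustrative only, not a registered rule):
//   std::tuple<Tensor, std::optional<int64_t>> sum_batch_rule(
//       const Tensor& self, std::optional<int64_t> self_bdim, int64_t dim) {
//     if (!self_bdim.has_value()) {
//       return std::make_tuple(self.sum(dim), std::nullopt);
//     }
//     // Move the batch dim to the front, then shift the logical dim by one.
//     auto self_ = moveBatchDimToFront(self, self_bdim);
//     auto logical_rank = rankWithoutBatchDim(self, self_bdim);
//     dim = maybe_wrap_dim(dim, logical_rank) + 1;
//     return std::make_tuple(self_.sum(dim), 0);
//   }
//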
// Now that we have written `sum_batch_rule`, we have to register it inside a
// TORCH_LIBRARY_IMPL block:
//   TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
//     ...
//     VMAP_SUPPORT2(sum, int, sum_batch_rule);
//     ...
//   }
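//
// (VMAP_SUPPORT is used for operators without an overload name; VMAP_SUPPORT2
// additionally takes the overload name, here `int` for `sum.int`, so the rule
// is registered for the correct overload.)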
//
// Note [Reusing batch rules to add vmap support for a complicated operator]
// Can't figure out how to write a batch rule for a big operation? If the
// operation can be expressed as a composition of other operations that do have
// batch rules, then that is another way to add vmap support. For example,
// consider the following schema
//   func: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1)
// and assume we already have batching rules for basic arithmetic operators.
//
// To add vmap support, define a decomposition using the same signature:
//   Tensor addcmul_decomp(const Tensor& self, const Tensor& tensor1,
//                         const Tensor& tensor2, const Scalar& value) {
//     auto product = at::mul(tensor1, tensor2);
//     return at::add(self, product, value);
//   }
// And register it inside a TORCH_LIBRARY_IMPL block:
//   TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
//     ...
//     m.impl("addcmul", addcmul_decomp);
//     ...
//   }
//
// Note [Writing batch rule for in-place operators]
// TODO: This is kinda complicated. Saving this for a future date.

namespace {

std::tuple<Tensor, std::optional<int64_t>> unsqueeze_batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    int64_t dim) {
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto rank = rankWithoutBatchDim(self, self_bdim);
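  // Wrap `dim` against the logical rank (+1 because unsqueeze adds a dimension),
  // then shift it by one to skip the batch dimension now at the front.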
  dim = maybe_wrap_dim(dim, rank + 1) + 1;
  return std::make_tuple(self_.unsqueeze(dim), 0);
}

// NB: repeat is not actually a view, but it is in this file
std::tuple<Tensor, std::optional<int64_t>> repeat_batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    c10::SymIntArrayRef sizes) {
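  // Prepend a 1 so the batch dimension itself is never repeated, then pad
  // `self_` with singleton dims right after the batch dim so its rank matches
  // `sizes_with_bdim`.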
  SymDimVector sizes_with_bdim = { sizes.begin(), sizes.end() };
  sizes_with_bdim.insert(sizes_with_bdim.begin(), 1);
  auto self_ = moveBatchDimToFront(self, self_bdim);
  while (self_.dim() < (int64_t)sizes_with_bdim.size()) {
    self_ = self_.unsqueeze(1);
  }
  return std::make_tuple(self_.repeat_symint(sizes_with_bdim), 0);
}

std::tuple<Tensor, std::optional<int64_t>> _unsafe_view_batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    c10::SymIntArrayRef size) {
  auto self_ = moveBatchDimToFront(self, self_bdim);
  SymDimVector view_size(size);
  view_size.insert(view_size.begin(), self_.sym_size(0));

  // See if the view is valid. If it's not, then we copy.
  // It's OK to copy, because _unsafe_view(x) guarantees that x isn't used
  // anymore.
  const at::SymDimVector inferred_size = at::infer_size_dv(view_size, self_.sym_numel());
  const auto stride = at::detail::computeStride(self_.sym_sizes(),
                                                self_.sym_strides(),
                                                inferred_size);
  if (!stride.has_value()) {
    self_ = self_.contiguous();
  }
  return std::make_tuple(at::_unsafe_view_symint(self_, view_size), 0);
}

std::tuple<Tensor, std::optional<int64_t>> flip_batch_rule(const Tensor& self, std::optional<int64_t> self_bdim, IntArrayRef dims) {
  auto self_ = moveBatchDimToFront(self, self_bdim);
  VmapDimVector new_dims;
  for (auto i : dims) {
    new_dims.push_back(getPhysicalDim(self_, true, i));
  }
  return std::make_tuple(at::flip(self_, new_dims), 0);
}

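// resize_ mutates its input, so instead of a batch rule it gets a "plumbing"
// function that unwraps the BatchedTensor at the current vmap level itself.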
const Tensor& resize__plumbing(
    const Tensor& self,
    IntArrayRef size,
    std::optional<MemoryFormat> optional_memory_format) {
  TORCH_CHECK(
      !optional_memory_format.has_value() ||
      optional_memory_format == c10::MemoryFormat::Contiguous,
      "resize_: batching rule only supports None or Contiguous MemoryFormat");
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "resize__plumbing");
  int64_t cur_level = maybe_layer->layerId();
  if (!isBatchedAtLevel(self, cur_level)) {
    c10::impl::ExcludeDispatchKeyGuard guard2(DispatchKey::FuncTorchBatched);
    return self.resize_(size, optional_memory_format);
  }

  auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level);
  TORCH_INTERNAL_ASSERT(self_bdim.has_value());

  // TODO: The following algorithm only works for batch dim == 0.
  // To get it to work for something else we need the ability to modify
  // the BatchDims attribute of BatchedTensorImpl
  TORCH_INTERNAL_ASSERT(self_bdim.value() == 0, "NYI: resize_ batch rule for batch dim != 0");

  // Resize the wrapped tensor
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  self_value = moveBatchDimToFront(self_value, self_bdim);
  VmapDimVector new_size(size);
  new_size.insert(new_size.begin(), self_value.size(*self_bdim));
  self_value.resize_(new_size);

  // Update the sizes and strides of the wrapper
  auto* batched = maybeGetBatchedImpl(self);
  TORCH_INTERNAL_ASSERT(batched);
  batched->refreshTensorMetadata();

  return self;
}

std::tuple<Tensor, std::optional<int64_t>> squeeze_batch_rule(const Tensor& self, std::optional<int64_t> bdim) {
  TORCH_INTERNAL_ASSERT(bdim.has_value());
  // Special case for scalar arrays to replicate PyTorch behavior.
  if (self.dim() == 1) {
    return std::make_tuple(self.alias(), bdim);
  }

  // Manually calculate the output shape by eliding all dimensions of
  // size 1 keeping track of where the batch index started and where it
  // ended up moving to. We also ensure we do not drop the batch index.
  auto shape = self.sym_sizes();
  SymDimVector squeezed_sizes;
  bool before_batch_idx = true;
  int64_t new_batch_idx = 0;
  int64_t original_idx = 0;

  for (const auto& it : shape) {
    // Keep only dimensions != 1 and the batch dimension (irrespective of size).
    if (it != 1 || original_idx == bdim) {
      squeezed_sizes.push_back(it);
      if (original_idx == bdim) {
        before_batch_idx = false;
      }
      // Only increment for the dimensions that will be kept in the output.
      if (before_batch_idx) {
        ++new_batch_idx;
      }
    }
    ++original_idx;
  }

  auto result = self.view_symint(squeezed_sizes);
  return std::make_tuple(std::move(result), std::optional<int64_t>(new_batch_idx));
}

std::tuple<Tensor, std::optional<int64_t>> squeeze_dims_batch_rule(
    const Tensor& self, std::optional<int64_t> bdim, IntArrayRef dims) {
  TORCH_INTERNAL_ASSERT(bdim.has_value());
  // Special case for scalar arrays to replicate PyTorch behavior.
  auto ndim = self.dim();
  if (ndim == 1) {
    TORCH_CHECK(
        dims.empty() || (dims.size() == 1 && dims[0] == 0),
        "Dimension is out of range (expected to be in range of [-1, 0], but got ", dims);
    return std::make_tuple(self.alias(), bdim);
  }

  // Adjust any dimensions higher than the batch dimension
  DimVector adjusted_dims(dims.begin(), dims.end());
  int64_t updated_batch_idx = *bdim;
  for (auto& d : adjusted_dims) {
    auto actual_dim = c10::maybe_wrap_dim(d, ndim - 1);
    if (actual_dim < *bdim) {
      d = actual_dim;
      if (self.sym_size(actual_dim) == 1) {
        // A dimension before the batch dimension will be dropped, so adjust accordingly.
        --updated_batch_idx;
      }
    } else {
      // Since the dimension to be squeezed is after the batch dimension, adjust by one
      // to account for the original batch dimension. In this case the batch dimension
      // won't move.
      d = actual_dim + 1;
    }
  }
  return std::make_tuple(self.squeeze(adjusted_dims), std::optional<int64_t>(updated_batch_idx));
}

std::tuple<Tensor, std::optional<int64_t>> squeeze_dim_batch_rule(
    const Tensor& self, std::optional<int64_t> bdim, int64_t dim) {
  return squeeze_dims_batch_rule(self, bdim, {dim});
}

std::tuple<Tensor, std::optional<int64_t>> select_batching_rule(const Tensor& self, std::optional<int64_t> bdim, int64_t dim, c10::SymInt index) {
  if (!bdim) {
    return std::make_tuple(self.select_symint(dim, std::move(index)), std::nullopt);
  }

  auto _self = moveBatchDimToFront(self, bdim);
  auto dim_physical = getPhysicalDim(_self, true, dim);
  auto result = _self.select_symint(dim_physical, std::move(index));
  return std::make_tuple(std::move(result), 0);
}

std::tuple<Tensor, std::optional<int64_t>> _reshape_alias_batch_rule(const Tensor& self, std::optional<int64_t> bdim, const c10::SymIntArrayRef shape, const c10::SymIntArrayRef strides) {
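  // The provided strides describe the unbatched alias; reshaping the batched
  // tensor recomputes strides itself, so they are intentionally unused here.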
  (void) strides;
  TORCH_INTERNAL_ASSERT(bdim.has_value());

  auto self_ = moveBatchDimToFront(self, bdim);
  c10::SymDimVector new_shape(shape.size() + 1);
  new_shape[0] = self_.sym_size(0);
  std::copy(shape.begin(), shape.end(), new_shape.begin() + 1);
  return std::make_tuple(at::reshape_symint(self_, new_shape), 0);
}

std::tuple<Tensor, std::optional<int64_t>> roll_batch_rule(const Tensor& self, std::optional<int64_t> bdim, SymIntArrayRef shifts, IntArrayRef dims) {
  TORCH_INTERNAL_ASSERT(bdim.has_value());

  auto self_ = moveBatchDimToFront(self, bdim);
  VmapDimVector new_dims;
  if (!dims.empty()) {
    for (auto i : dims) {
      new_dims.push_back(getPhysicalDim(self, true, i));
    }
    return std::make_tuple(at::roll_symint(self_, shifts, new_dims), 0);
  }
  // We will do something like: t.reshape(a, -1).roll(1, dims=[1, ]).reshape(old_shape)
  auto old_shape = self_.sym_sizes();
  new_dims.push_back(1);
  auto logical_rank = rankWithoutBatchDim(self, bdim);
  if (logical_rank == 0) {
    self_ = self_.unsqueeze(0);
  }

  auto output = at::roll_symint(self_.flatten(1), shifts, new_dims);
  // NOTE: For scalar tensor, we don't need to unsqueeze as reshape
  // with `old_shape` takes care of it.
  output = output.reshape_symint(old_shape);
  return std::make_tuple(output, 0);
}

std::tuple<Tensor, std::optional<int64_t>> diagonal_batching_rule(
    const Tensor& self, std::optional<int64_t> self_bdim,
    int64_t offset, int64_t dim1, int64_t dim2)
{
  auto logical_rank = rankWithoutBatchDim(self, self_bdim);
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto dim1_ = maybe_wrap_dim(dim1, logical_rank) + 1;
  auto dim2_ = maybe_wrap_dim(dim2, logical_rank) + 1;
  auto result = at::diagonal(self_, offset, dim1_, dim2_);
  return std::make_tuple(std::move(result), 0);
}

std::tuple<Tensor, std::optional<int64_t>> diagonal_backward_batch_rule(
    const Tensor& grad_input, std::optional<int64_t> grad_input_bdim,
    c10::SymIntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
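  // input_sizes describes the unbatched forward input; prepend the batch size
  // so the gradient is scattered back into a tensor that includes the batch
  // dimension.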
  auto logical_rank = rankWithoutBatchDim(grad_input, grad_input_bdim);
  auto grad_input_ = moveBatchDimToFront(grad_input, grad_input_bdim);
  dim1 = maybe_wrap_dim(dim1, logical_rank + 1) + 1;
  dim2 = maybe_wrap_dim(dim2, logical_rank + 1) + 1;
  c10::SymDimVector input_sizes_(input_sizes.size() + 1);
  input_sizes_[0] = grad_input_.sym_size(0);
  std::copy(input_sizes.begin(), input_sizes.end(), input_sizes_.begin() + 1);
  auto result = at::diagonal_backward_symint(grad_input_, input_sizes_, offset, dim1, dim2);
  return std::make_tuple(std::move(result), 0);
}

std::tuple<Tensor, std::optional<int64_t>> slice_batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    int64_t dim,
    std::optional<c10::SymInt> start,
    std::optional<c10::SymInt> end,
    c10::SymInt step) {
  auto self_ = moveBatchDimToFront(self, self_bdim);
  dim = getPhysicalDim(self, self_bdim.has_value(), dim);

  auto result = self_.slice_symint(dim, std::move(start), std::move(end), std::move(step));
  return std::make_tuple(std::move(result), 0);
}

static bool is_allowed_dim_on_scalar_tensor(int64_t dim) {
  return dim == 0 || dim == -1;
}

std::tuple<Tensor, std::optional<int64_t>>
transpose_int_batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    int64_t dim0,
    int64_t dim1) {
  // PyTorch has a special case where scalar_tensor.transpose(dim0, dim1) works
  // for dim0, dim1 in {0, -1} and returns the scalar tensor. If the following happens:
  // >>> x = torch.randn(B0) # the per-examples are all scalars
  // >>> vmap(lambda x: x.transpose(0, -1), x)
  // then we replicate this behavior.
  if (/*physical*/self.dim() == 1 && is_allowed_dim_on_scalar_tensor(dim0) &&
      is_allowed_dim_on_scalar_tensor(dim1)) {
    return std::make_tuple(self, self_bdim);
  }
  auto self_ = moveBatchDimToFront(self, self_bdim);
  dim0 = getPhysicalDim(self, self_bdim.has_value(), dim0);
  dim1 = getPhysicalDim(self, self_bdim.has_value(), dim1);
  auto result = self_.transpose(dim0, dim1);
  return std::make_tuple(std::move(result), 0);
}

std::tuple<Tensor, std::optional<int64_t>> permute_batching_rule(
    const Tensor& self, std::optional<int64_t> self_bdim, IntArrayRef dims)
{
  if (!self_bdim.has_value()) {
    return std::make_tuple(self.permute(dims), self_bdim);
  }

  auto self_ = moveBatchDimToFront(self, self_bdim);
  VmapDimVector dims_;
  dims_.reserve(dims.size() + 1);
  dims_.emplace_back(0);
  for (auto dim : dims) {
    dims_.emplace_back(getPhysicalDim(self_, self_bdim.has_value(), dim));
  }

  return std::make_tuple(self_.permute(dims_), 0);
}

std::tuple<Tensor, std::optional<int64_t>> select_backward_batch_rule(
    const Tensor& grad_input, std::optional<int64_t> grad_input_bdim,
    c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index) {
  auto logical_rank = rankWithoutBatchDim(grad_input, grad_input_bdim);
  auto grad_input_ = moveBatchDimToFront(grad_input, grad_input_bdim);
  dim = maybe_wrap_dim(dim, logical_rank + 1) + 1;
  c10::SymDimVector input_sizes_(input_sizes.size() + 1);
  input_sizes_[0] = grad_input_.sym_size(0);
  std::copy(input_sizes.begin(), input_sizes.end(), input_sizes_.begin() + 1);
  auto result = at::select_backward_symint(grad_input_, input_sizes_, dim, std::move(index));
  return std::make_tuple(std::move(result), 0);
}

std::tuple<Tensor, std::optional<int64_t>> slice_backward_batch_rule(
    const Tensor& grad_input, std::optional<int64_t> grad_input_bdim,
    SymIntArrayRef input_sizes, int64_t dim, c10::SymInt start, c10::SymInt end, c10::SymInt step) {
  auto logical_rank = rankWithoutBatchDim(grad_input, grad_input_bdim);
  auto grad_input_ = moveBatchDimToFront(grad_input, grad_input_bdim);
  dim = maybe_wrap_dim(dim, logical_rank) + 1;
  c10::SymDimVector input_sizes_(input_sizes.size() + 1);
  input_sizes_[0] = grad_input_.sym_size(0);
  std::copy(input_sizes.begin(), input_sizes.end(), input_sizes_.begin() + 1);
  auto result = at::slice_backward_symint(grad_input_, input_sizes_, dim, std::move(start), std::move(end), std::move(step));
  return std::make_tuple(std::move(result), 0);
}

std::tuple<Tensor, std::optional<int64_t>> view_batching_rule(
    const Tensor& self, std::optional<int64_t> self_bdim, SymIntArrayRef sym_size)
{
  TORCH_INTERNAL_ASSERT(self_bdim.has_value());
  auto self_ = moveBatchDimToFront(self, self_bdim);
  c10::SmallVector<c10::SymInt> size_(sym_size.size() + 1);
  // copy batch size
  size_[0] = self_.sym_size(0);
  std::copy(sym_size.cbegin(), sym_size.cend(), size_.begin() + 1);
  return std::make_tuple(self_.view_symint(size_), 0);
}

std::tuple<Tensor, std::optional<int64_t>> view_copy_batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    c10::SymIntArrayRef size) {
  auto self_ = moveBatchDimToFront(self, self_bdim);
  SymDimVector view_size(size.size() + 1);
  view_size[0] = self_.sym_size(0);
  std::copy(size.cbegin(), size.cend(), view_size.begin() + 1);

  return std::make_tuple(at::view_copy_symint(self_, view_size), 0);
}

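// Shared between expand and expand_copy (see the registrations at the bottom
// of this file); Func selects which of the two ops to call.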
template <typename F, F Func>
std::tuple<Tensor, std::optional<int64_t>> expand_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim, SymIntArrayRef size, bool implicit)
{
  auto self_dim = self.dim();
  TORCH_CHECK(static_cast<uint64_t>(self_dim - 1) <= size.size(),
              "expand: the number of sizes provided (", size.size(), ") ",
              "must be greater or equal to the number of dimensions in the tensor (", static_cast<uint64_t>(self_dim - 1), ")");

  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto self_sizes = self_.sym_sizes();
  const auto& batch_size = self_sizes[0];

  c10::SmallVector<c10::SymInt> size_(size.size() + 1);
  size_[0] = batch_size;
  std::copy(size.cbegin(), size.cend(), size_.begin() + 1);

  // Here, we know we are expanding a (logical) tensor to a larger number
  // of dimensions. We have to be careful because we can't call expand directly
  // due to the presence of batch dimensions.
  //
  // As an example, let B0 be a batch dimension and consider expand(Tensor[B0, 3], [2, 3]).
  // The result should be a tensor of size [B0, 2, 3].
  // A physical view of size [B0, 3] can't directly be expanded to size [B0, 2, 3]
  // so the strategy here is to view it first as a tensor of size [B0, 1, 3] and
  // then expand.
  auto extra_dims = size.size() - (self_dim - 1);
  c10::SmallVector<c10::SymInt> view_shape(size_.size(), /*init_value*/1);
  view_shape[0] = batch_size;
  std::copy(self_sizes.cbegin() + 1, self_sizes.cend(),
            view_shape.begin() + 1 + extra_dims);

  return std::make_tuple(Func(self_.view_symint(view_shape), size_, implicit), 0);
}

std::tuple<Tensor, std::optional<int64_t>> unfold_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim, int64_t dim, int64_t size, int64_t step)
{
  TORCH_INTERNAL_ASSERT(self_bdim.has_value());
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto logical_rank = rankWithoutBatchDim(self, self_bdim);
  dim = maybe_wrap_dim(dim, logical_rank) + 1;
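  // Per-example scalars: unfold treats a 0-dim tensor as 1-d of size 1, so
  // temporarily add a trailing dimension and remove it again afterwards.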
  if (logical_rank == 0) {
    self_ = self_.unsqueeze(-1);
  }
  auto result = self_.unfold(dim, size, step);
  if (logical_rank == 0) {
    result = result.squeeze(-1);
  }
  return std::make_tuple(std::move(result), 0);
}

std::tuple<Tensor, std::optional<int64_t>> narrow_copy_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim, int64_t dim, c10::SymInt start, c10::SymInt length)
{
  TORCH_INTERNAL_ASSERT(self_bdim.has_value());
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto logical_rank = rankWithoutBatchDim(self, self_bdim);
  dim = maybe_wrap_dim(dim, logical_rank) + 1;
  auto result = self_.narrow_copy_symint(dim, std::move(start), std::move(length));

  return std::make_tuple(std::move(result), 0);
}

std::tuple<std::vector<Tensor>, std::optional<int64_t>> unsafe_split_batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    c10::SymInt split_size,
    int64_t dim) {
  TORCH_INTERNAL_ASSERT(self_bdim.has_value());
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto logical_rank = rankWithoutBatchDim(self, self_bdim);
  dim = maybe_wrap_dim(dim, logical_rank) + 1;
  auto result = self_.unsafe_split_symint(std::move(split_size), dim);
  return std::make_tuple(std::move(result), 0);
}

std::tuple<Tensor, std::optional<int64_t>> diag_embed_batch_rule(const Tensor& self, std::optional<int64_t> self_bdim, int64_t offset, int64_t dim1, int64_t dim2) {
  auto logical_rank = rankWithoutBatchDim(self, self_bdim);
  auto self_ = moveBatchDimToFront(self, self_bdim);
  dim1 = maybe_wrap_dim(dim1, logical_rank + 1) + 1;
  dim2 = maybe_wrap_dim(dim2, logical_rank + 1) + 1;
  return std::make_tuple(at::diag_embed(self_, offset, dim1, dim2), 0);
}

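// trace doesn't need its own batch rule: decomposing it into diagonal + sum
// lets it reuse the existing vmap support for those ops.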
Tensor trace_decomp(const Tensor& tensor) {
  TORCH_CHECK(tensor.dim() == 2, "trace: expected a matrix, but got tensor with dim ", tensor.dim());
  return tensor.diagonal().sum();
}

std::tuple<Tensor, std::optional<int64_t>> tril_batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    int64_t diagonal = 0) {
  TORCH_CHECK(self.dim() >= 2, "tril: The input tensor must have at least 2 dimensions.");
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto result = at::tril(self_, diagonal);
  return std::make_tuple(std::move(result), 0);
}

std::tuple<Tensor, std::optional<int64_t>> triu_batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    int64_t diagonal = 0) {
  TORCH_CHECK(self.dim() >= 2, "triu: The input tensor must have at least 2 dimensions.");
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto result = at::triu(self_, diagonal);
  return std::make_tuple(std::move(result), 0);
}

} // namespace

TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  VMAP_SUPPORT(flip, flip_batch_rule);
  m.impl("trace", trace_decomp);
  VMAP_SUPPORT(tril, tril_batch_rule);
  VMAP_SUPPORT(triu, triu_batch_rule);
  VMAP_SUPPORT(repeat, repeat_batch_rule);
  VMAP_SUPPORT(_unsafe_view, _unsafe_view_batch_rule);
  VMAP_SUPPORT(unsqueeze, unsqueeze_batch_rule);
  m.impl("resize_", resize__plumbing);
  VMAP_SUPPORT2(select, int, select_batching_rule);
  VMAP_SUPPORT(squeeze, squeeze_batch_rule);
  VMAP_SUPPORT2(squeeze, dim, squeeze_dim_batch_rule);
  VMAP_SUPPORT2(squeeze, dims, squeeze_dims_batch_rule);
  VMAP_SUPPORT(_reshape_alias, _reshape_alias_batch_rule);
  VMAP_SUPPORT(roll, roll_batch_rule);
  VMAP_SUPPORT(permute, permute_batching_rule);
  VMAP_SUPPORT(diagonal, diagonal_batching_rule);
  VMAP_SUPPORT(diagonal_backward, diagonal_backward_batch_rule);
  VMAP_SUPPORT(select_backward, select_backward_batch_rule);
  VMAP_SUPPORT(slice_backward, slice_backward_batch_rule);
  VMAP_SUPPORT(view, view_batching_rule);
  VMAP_SUPPORT(view_copy, view_copy_batch_rule);
  VMAP_SUPPORT(expand, SINGLE_ARG(expand_batch_rule<decltype(&ATEN_FN(expand)), &ATEN_FN(expand)>));
  VMAP_SUPPORT(expand_copy, SINGLE_ARG(expand_batch_rule<decltype(&ATEN_FN(expand_copy)), &ATEN_FN(expand_copy)>));
  VMAP_SUPPORT(unfold, unfold_batch_rule);
  VMAP_SUPPORT2(slice, Tensor, slice_batch_rule);
  VMAP_SUPPORT2(transpose, int, transpose_int_batch_rule);
  m.impl("t", native::t);   // CompositeExplicitAutograd, should not go in BatchRulesDecompositions.cpp
  m.impl("t_", native::t_); // CompositeExplicitAutograd, should not go in BatchRulesDecompositions.cpp
  VMAP_SUPPORT(diag_embed, diag_embed_batch_rule);
  VMAP_SUPPORT(narrow_copy, narrow_copy_batch_rule);
  VMAP_SUPPORT2(unsafe_split, Tensor, unsafe_split_batch_rule);
}

} // namespace at::functorch