// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <torch/library.h>
#include <ATen/native/ResizeCommon.h>
#include <ATen/ATen.h>
#include <ATen/native/TensorShape.h>

#include <ATen/NestedTensorImpl.h>
#include <ATen/functorch/DynamicLayer.h>
#include <ATen/functorch/TensorWrapper.h>
#include <ATen/functorch/BatchingMetaprogramming.h>
#include <ATen/functorch/LegacyVmapTransforms.h>
#include <ATen/functorch/BatchedFallback.h>
#include <ATen/functorch/BatchRulesHelper.h>

#include <utility>

namespace at::functorch {

// NOTE: [What is a batching rule?]
//
// NB: the following description only applies to this file and is about
// the legacy (deprecated) batching rule API. Please see writing_batch_rules.md
// for how to write new-style batching rules.
//
// This file contains batching rules written with the legacy (now-deprecated)
// batching rule API.
// Please try to use the new-style batching rule API (see writing_batch_rules.md).
//
// A *batching rule* implements the logic of how to call an operator on inputs
// that have zero or more additional batch dimensions. When one does a vmap, the
// dimension(s) being vmap'ed over get recorded as batch dimensions.
//
// For example, vmap(torch.add)(x, y)
// 1. wraps `x` into batched_x = BatchedTensor(x, bdims=[(lvl=1, dim=0)]);
// 2. wraps `y` into batched_y = BatchedTensor(y, bdims=[(lvl=1, dim=0)]);
// 3. and then runs `torch.add(batched_x, batched_y)`.
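//
// Concretely (shapes are illustrative; B denotes the size of the vmapped dim):
// >>> x = torch.randn(B, 3); y = torch.randn(B, 3)
// >>> out = vmap(torch.add)(x, y)  # out has shape (B, 3)
// Inside the vmap, torch.add sees BatchedTensors whose logical shape is (3,)
// and whose physical (underlying) shape is (B, 3); the batching rule decides
// how the logical call maps onto the physical tensors.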

// NOTE: [When should I add a batching rule?]
// When you are adding a new operator, you'll need to add a batching rule so
// that vmap can work efficiently with said operator. If you do not, we'll attempt
// to generate a slow fallback for the batching rule.

// NOTE: [How to write batching rules?]
// The signature of a batching rule should look exactly like the C++ signature
// of its operator.
//
// First, see NOTE: [Logical vs physical args] in VmapTransforms.h for terminology.
//
// At a high level, what a batching rule does is the following:
// 1. Converts (logical) BatchedTensors to views on physical tensors.
// 2. Converts logical arguments (e.g. dimension indexes, shapes) to physical
//    arguments that correspond to the physical tensors.
// 3. Calls at:: operations on the physical tensors and arguments to produce
//    some physical results.
// 4. Converts physical results back to BatchedTensors.
//
// Steps 1, 2, and 4 differ for operators with different batching behaviors. When
// writing a new batching rule, please select a VmapTransform that matches the
// batching behavior of your operation. The VmapTransform provides helper functions
// to do steps (1), (2), and (4).
// (see NOTE: [What is an VmapTransform?] in VmapTransforms.h)
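//
// A minimal sketch of that recipe (mirroring split_batching_rule below; the
// VmapTransform helpers are real, but `at::my_op` is a hypothetical operator
// used purely for illustration):
//
//   Tensor my_op_batching_rule(const Tensor& self, int64_t dim) {
//     // Step 1: logical BatchedTensor -> view on the physical tensor
//     auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
//     // Step 2: logical dim -> physical dim
//     auto dim_physical = self_physical.getPhysicalDim(dim);
//     // Step 3: call the at:: operation on physical data
//     auto result = at::my_op(self_physical.tensor(), dim_physical);
//     // Step 4: rewrap the physical result as a BatchedTensor
//     return self_physical.getPhysicalToLogicalMap().apply(result);
//   }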

namespace{
// PyTorch allows operations to specify dim 0 and dim -1 on a scalar tensor.
static bool is_allowed_dim_on_scalar_tensor(int64_t dim) {
  return dim == 0 || dim == -1;
}

static int64_t get_current_level() {
  auto maybe_level = maybeCurrentDynamicLayer();
  TORCH_INTERNAL_ASSERT(maybe_level.has_value());
  return maybe_level->layerId();
}

// This check should probably go into the dispatcher...
static bool participatesInCurrentLevel(const Tensor& self) {
  auto current_level = get_current_level();
  auto* maybe_batched_impl = maybeGetBatchedImpl(self);
  if (!maybe_batched_impl) {
    return false;
  }
  auto self_level = maybe_batched_impl->level();
  TORCH_INTERNAL_ASSERT(self_level <= current_level);
  return self_level == current_level;
}

static bool participatesInCurrentLevel(ITensorListRef self) {
  for (const Tensor& tensor : self) {
    if (participatesInCurrentLevel(tensor)) {
      return true;
    }
  }
  return false;
}

Tensor& squeeze_dims__batching_rule(Tensor& self, IntArrayRef dims) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return self.squeeze_(dims);
  }
  auto* batched = maybeGetBatchedImpl(self);
  const auto bdim = batched->bdim();
  auto logical_dim = self.dim();

  if (logical_dim == 0) {
    TORCH_CHECK(
        dims.empty() || (dims.size() == 1 && dims[0] == 0),
        "Dimension is out of range (expected to be in range of [-1, 0], but got ", dims);
    return self;
  }

  // Adjust any dimensions higher than the batch dimension
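  // Illustrative example: if the physical value has shape [1, B, 1, 4] with
  // bdim = 1 (so the logical shape is [1, 1, 4]) and we squeeze_ logical dims
  // {0, 1}, then logical dim 0 maps to physical dim 0 (and, being of size 1,
  // its removal shifts the batch dim left by one), while logical dim 1 maps to
  // physical dim 2. We end up calling squeeze_({0, 2}) on the physical tensor
  // and updating bdim from 1 to 0.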
  DimVector adjusted_dims(dims.begin(), dims.end());
  int64_t updated_batch_idx = bdim;
  for (auto &d : adjusted_dims) {
    auto actual_dim = c10::maybe_wrap_dim(d, logical_dim);
    if (actual_dim < bdim) {
      d = actual_dim;
      if (batched->value().sym_size(actual_dim) == 1) {
        // A size-1 dim before the batch dimension will be dropped, so the
        // batch index must shift down by one to compensate.
        --updated_batch_idx;
      }
    } else {
      // Since the dimension to be squeezed is after the batch dimension, adjust it
      // by one to account for the original batch dimension. In this case the batch
      // dimension won't move.
      d = actual_dim + 1;
    }
  }

  batched->value().squeeze_(adjusted_dims);
  if (updated_batch_idx != bdim) {
    batched->unsafe_set_bdim(updated_batch_idx);
  }
  batched->refreshTensorMetadata();
  return self;
}

Tensor& squeeze_dim__batching_rule(Tensor& self, int64_t dim) {
  return squeeze_dims__batching_rule(self, {dim});
}

Tensor& squeeze__batching_rule(Tensor& self) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return self.squeeze_();
  }
  auto* batched = maybeGetBatchedImpl(self);

  // Need to find out how many dimensions of size 1 are before the bdim
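  // (e.g., if the physical value has shape [1, 1, B, 5] with bdim = 2, two
  // size-1 dims precede the bdim, so after squeeze_() the bdim moves to
  // index 0, assuming B != 1.)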
  const auto bdim = batched->bdim();
  const auto physical_shape = batched->value().sizes();
  auto how_many_dims_of_size_1_before_bdim = 0;
  for (const auto i : c10::irange(0, physical_shape.size())) {
    if ((int64_t)i == bdim) {
      break;
    }
    if (physical_shape[i] == 1) {
      how_many_dims_of_size_1_before_bdim++;
    }
  }

  int64_t new_bdim = bdim - how_many_dims_of_size_1_before_bdim;
  if (physical_shape[bdim] != 1) {
    // if the bdim is not of size 1, we can just call squeeze_()
    batched->value().squeeze_();
  } else {
    // otherwise, squeeze_() is going to get rid of the bdim too.
    // We "fix it up" by calling unsqueeze_.
    batched->value().squeeze_();
    batched->value().unsqueeze_(new_bdim);
  }

  // Refresh metadata
  batched->unsafe_set_bdim(new_bdim);
  batched->refreshTensorMetadata();
  return self;
}

Tensor& unsqueeze__batching_rule(Tensor& self, int64_t dim) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return self.unsqueeze_(dim);
  }
  auto* batched = maybeGetBatchedImpl(self);
  auto logical_dim = self.dim();
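  // e.g., a physical value of shape [B, 2, 3] with bdim = 0 has logical shape
  // [2, 3]; a logical unsqueeze_(1) must become a physical unsqueeze_(2),
  // yielding physical shape [B, 2, 1, 3] (logical shape [2, 1, 3]).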
  int64_t dim_physical = maybe_wrap_dim(dim, logical_dim + 1);
  if (dim_physical >= batched->bdim()) {
    dim_physical = 1 + dim_physical;
  } else {
    batched->unsafe_set_bdim(batched->bdim() + 1);
  }
  batched->value().unsqueeze_(dim_physical);

  // Also need to change some metadata...
  batched->refreshTensorMetadata();
  return self;
}

Tensor& transpose__batching_rule(Tensor& self, int64_t dim0, int64_t dim1) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return self.transpose_(dim0, dim1);
  }
  auto* batched = maybeGetBatchedImpl(self);
  auto logical_dim = self.dim();

  // PyTorch has a special case where scalar_tensor.transpose(dim0, dim1) works
  // for dim0, dim1 in {0, -1} and returns the scalar tensor. If the following happens:
  // >>> x = torch.randn(B0)  # the per-examples are all scalars
  // >>> vmap(lambda x: x.transpose_(0, -1), x)
  // then we replicate this behavior.
  if (logical_dim == 0 &&
      is_allowed_dim_on_scalar_tensor(dim0) &&
      is_allowed_dim_on_scalar_tensor(dim1)) {
    // No transposing happened :P
    return self;
  }

  dim0 = maybe_wrap_dim(dim0, logical_dim);
  dim1 = maybe_wrap_dim(dim1, logical_dim);

  dim0 = dim0 >= batched->bdim() ? dim0 + 1 : dim0;
  dim1 = dim1 >= batched->bdim() ? dim1 + 1 : dim1;
  batched->value().transpose_(dim0, dim1);

  // Also need to change some metadata...
  batched->refreshTensorMetadata();
  return self;
}

std::vector<Tensor> split_batching_rule(const Tensor& self, int64_t split_size, int64_t dim) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::split(self, split_size, dim);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::split(self_physical.tensor(), split_size, dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

std::vector<Tensor> split_with_sizes_batching_rule(const Tensor& self, SymIntArrayRef split_sizes, int64_t dim) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return split_with_sizes_symint(self, split_sizes, dim);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = split_with_sizes_symint(self_physical.tensor(), split_sizes, dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

std::vector<Tensor> split_with_sizes_copy_batching_rule(const Tensor& self, SymIntArrayRef split_sizes, int64_t dim) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return split_with_sizes_copy_symint(self, split_sizes, dim);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = split_with_sizes_copy_symint(self_physical.tensor(), split_sizes, dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

std::vector<Tensor> unbind_batching_rule(const Tensor& self, int64_t dim) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::unbind(self, dim);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::unbind(self_physical.tensor(), dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

// given (sizes, strides, storage_offset) returns the maximum location that
// can be indexed (or nullopt if such a location doesn't exist, e.g., tensors
// with zero-size dims).
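// Concretely, when no dim has size 0, the returned value is
//   storage_offset + 1 + \sum_j (sizes[j] - 1) * strides[j],
// which is the quantity used in the inequalities of
// NOTE: [When will the as_strided batching rule fail?] below.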
static std::optional<c10::SymInt> maximum_indexable_location(
    c10::SymIntArrayRef sizes, c10::SymIntArrayRef strides, const c10::SymInt& storage_offset) {
  auto result = native::storage_size_for(sizes, strides);
  if (result == 0) {
    return std::nullopt;
  }
  return result + storage_offset;
}

// Let x be the "first slice" of physical_tensor.
// This checks that the range of possible memory locations accessible by
// x.as_strided(sizes, strides, maybe_storage_offset)
// is within the bounds of possible memory locations accessible by x.
static void checkBasicAsStridedValidForSlice(
    const Tensor& physical_tensor,
    int64_t num_batch_dims,
    c10::SymIntArrayRef sizes,
    c10::SymIntArrayRef strides,
    const std::optional<c10::SymInt>& maybe_storage_offset) {
  auto slice_sizes = physical_tensor.sym_sizes().slice(num_batch_dims);
  auto slice_strides = physical_tensor.sym_strides().slice(num_batch_dims);
  auto base_offset = physical_tensor.sym_storage_offset();

  auto storage_offset = maybe_storage_offset.value_or(base_offset);

  auto max_as_strided_loc = maximum_indexable_location(sizes, strides, storage_offset);
  auto max_slice_loc = maximum_indexable_location(slice_sizes, slice_strides, base_offset);

  if (!max_as_strided_loc.has_value()) {
    return;
  }
  if (!max_slice_loc.has_value()) {
    TORCH_CHECK(false,
        "result = tensor.as_strided(", sizes, ", ",  strides, ", ", storage_offset, ") ",
        "can access memory outside of `tensor`. `tensor` has no storage but the ",
        "passed-in (size, stride, storage_offset) imply a result with some storage. ",
        "This is not supported inside of vmap, please try to rewrite the ",
        "`as_strided` call as a sequence of PyTorch view operations");
  }

  TORCH_CHECK(
      *max_as_strided_loc <= *max_slice_loc && base_offset <= storage_offset,
      "result = tensor.as_strided(", sizes, ", ",  strides, ", ", storage_offset, ") ",
      "can access memory outside of `tensor`. `result` can access some ",
      "memory in range [", storage_offset, ", ", *max_as_strided_loc, "], but ",
      "`tensor` can only access some memory in range [", base_offset, ", ",
      *max_slice_loc, "]. This is not supported inside of vmap, please try to ",
      "rewrite the `as_strided` call as a sequence of PyTorch view operations");
}

// What are the semantics of as_strided inside of vmap?
// y = vmap(lambda x: x.as_strided(sizes, strides, offset))(xs)
// This returns a view `y` on `xs` such that each y[i] has:
// - sizes: `sizes`
// - strides: `strides`
// - storage_offset: offset + i * xs.stride(batch_dim)
//
// In other words, it is as if we had treated each x[i] as having storage
// offset equal to xs.offset() and called as_strided(sizes, strides, offset).
// (that is equivalent to x[i].as_strided(
//    sizes, strides, offset + x[i].storage_offset() - xs.offset()) for all i)
//
// Note that this *may* be different from actually running as_strided
// in a for-loop. This is due to how as_strided takes in `offset` to be
// an *absolute* offset. As an example, consider:
// >>> x = torch.tensor([0., 1., 2., 3., 4.]).as_strided([4], [1], 1)
// >>> z = [x[i].as_strided([1], [1], 1) for i in range(4)]
// Each z[i] is actually the same view on x (z[i] == torch.tensor([1.]))!
// However, we consider the above for-loop comprehension to be a user error:
// a user should have written the following if they wanted to use as_strided
// in a per-sample way:
// >>> z = [x[i].as_strided([1], [1], 1 + x[i].storage_offset() - 1) for i in range(4)]
Tensor as_strided_batching_rule(
    const Tensor& tensor,
    c10::SymIntArrayRef sizes,
    c10::SymIntArrayRef strides,
    std::optional<c10::SymInt> storage_offset) {
  if (!participatesInCurrentLevel(tensor)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::as_strided_symint(tensor, sizes, strides, std::move(storage_offset));
  }
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(tensor);
  auto num_batch_dims = physical_view.numBatchDims();
  auto physical_sizes = physical_view.getPhysicalShape(sizes);
  const auto& physical_tensor = physical_view.tensor();

  // We can't rely on the physical as_strided call to do this for us because
  // we do some sanity checks on the size/strides before calling into as_strided.
  TORCH_CHECK(sizes.size() == strides.size(),
      "Tensor.as_strided(size, stride, ...): size and stride must have the ",
      "same length! Got size ", sizes, " and stride ", strides);

  // Sanity checks:
  // 1. as_strided(sizes, strides, storage_offset + tensor[i].offset() - tensor.offset())
  // is valid for a slice of the input tensor.
  // See Note: [When will the as_strided batching rule fail?] for details.
  checkBasicAsStridedValidForSlice(
      physical_tensor, num_batch_dims, sizes, strides, storage_offset);

  // physical_strides = physical tensor's batch strides + (logical) strides
  auto batch_strides = physical_tensor.strides().slice(0, num_batch_dims);
  SymDimVector physical_strides;
  physical_strides.reserve(num_batch_dims + strides.size());
  physical_strides.insert(
      physical_strides.end(), batch_strides.begin(), batch_strides.end());
  physical_strides.insert(
      physical_strides.end(), strides.begin(), strides.end());

  // If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
  // is valid for all i, then it turns out that
  // xs.as_strided(physical_sizes, physical_strides, offset) always succeeds
  // and creates a tensor y such that each y[i] references the same memory
  // locations as zi. See NOTE: [When will the as_strided batching rule fail?]
  auto result = physical_view.tensor().as_strided_symint(
      physical_sizes, physical_strides, std::move(storage_offset));
  return physical_view.getPhysicalToLogicalMap().apply(result);
}

// NOTE: [When will the as_strided batching rule fail?]
// If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
// is valid for all i, then it turns out that
// xs.as_strided(physical_sizes, physical_strides, offset) always succeeds and
// creates a tensor y such that each y[i] refers to the same memory as zi.
//
// Let's say we have xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()).
// Furthermore, let's say that as a part of being "valid" this as_strided call
// does not return a result that can index memory not indexable by xs[i].
//
// WLOG, assume that there's only one batch dim and it is at the front of the
// `xs` tensor. Let B be the batch size and S be the stride of the batch dim.
// - If the batch dim isn't at the front of the tensor, then we can just move it
// to the front with movedim/permute. This is always valid because it just swaps
// some strides around.
// - This proof also works for tensors with multiple batch dims. We just have to
// do a little accounting:
//   - instead of [B], we'd have [B0, B1, ..., Bk].
//   - instead of [S], we'd have [S0, S1, ..., Sk].
//   - instead of i, we'd have a list of indices [I0, I1, ..., Ik]
//   - instead of S * I, we'd have \sum_{i=0}^k S_i * I_i
//
// [Equation 1]
// xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) has:
// - sizes: sizes
// - strides: strides
// - offset: offset + S * i
//
// x.as_strided itself checks that:
// - (sizes, strides, offset) are in bounds for `x`'s storage.
// - strides are positive
// - offset is positive
//
// Claim 1: if xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
// is valid, then
// ([B] + sizes, [S] + strides, offset + xs.offset()) are in bounds for `xs`'s storage.
//
// If we have the claim, then xs.as_strided([B] + sizes, [S] + strides, offset)
// won't error out. So all we need to check is that the memory locations are
// what we expected. See [Hand-wavy proof of Claim 1] for proof (it's not very important)
//
// xs.as_strided(physical_sizes, physical_strides, offset) is equivalent to
// xs.as_strided([B] + sizes, [S] + strides, offset)
//
// xs.as_strided([B] + sizes, [S] + strides, offset) has:
// - sizes: [B] + sizes
// - strides: [S] + strides
// - offset: offset
//
// xs.as_strided([B] + sizes, [S] + strides, offset)[i] has:
// - sizes: sizes
// - strides: strides
// - offset: offset + S * i
// These memory locations are exactly the same as what we got for [Equation 1],
// so the xs.as_strided([B] + sizes, [S] + strides, offset) is valid.
//
// [Hand-wavy proof of Claim 1]
// Part of our definition of being valid is that xs[i].as_strided(...)
// must return a tensor that only uses memory indexable by xs[i].
// This means that (sizes, strides, offset + xs[i].offset() - xs.offset()) satisfies:
//    offset + xs[i].offset() - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j]
//    <= xs[i].offset() + 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
// (the largest-index memory location of xs[i].as_strided(...) must be \leq
// the largest-index memory location of xs[i])
//
// Fiddling that inequality gives us:
//    offset - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j]
//    <= 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
//
//    offset - xs.offset() + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
//    <= 1 + (B-1)*S + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
//
//    offset - xs.offset() + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
//    <= 1 + \sum_j (xs.size(j) - 1) * xs.stride(j)
//
//    offset + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
//    <= xs.offset() + 1 + \sum_j (xs.size(j) - 1) * xs.stride(j)
// (the largest-index memory location of xs.as_strided(size, stride, offset)
// is \leq the largest-index memory location of xs)
// Under the assumptions we've made, the lower bound (lowest indexed memory)
// is trivially within the storage.
//
// Therefore ([B] + sizes, [S] + strides, offset) are in bounds for
// `xs`'s storage.

template <typename F, F Func, typename... ExtraArgs>
Tensor unwrap_and_call(const Tensor& input, ExtraArgs... args) {
  if (!participatesInCurrentLevel(input)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return Func(input, args...);
  }
  // guard against the user passing in a batch of scalar tensors with batch
  auto* input_batched = unsafeGetBatchedImpl(input);
  auto output_physical = Func(input_batched->value(), args...);
  return makeBatched(output_physical, input_batched->bdim(), input_batched->level());
}

template <typename F, F Func, typename... ExtraArgs>
Tensor unwrap_and_call_method(const Tensor& input, ExtraArgs... extra_args) {
  if (!participatesInCurrentLevel(input)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return (input.*Func)(extra_args...);
  }
  auto* input_batched = unsafeGetBatchedImpl(input);
  auto output_physical = (input_batched->value().*Func)(extra_args...);
  return makeBatched(output_physical, input_batched->bdim(), input_batched->level());
}

Tensor cat_batching_rule(const ITensorListRef& tensors, int64_t dim) {
  if (!participatesInCurrentLevel(tensors)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::cat(tensors, dim);
  }

  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);

  // NB: Probably bad for perf that we're allocating std::vectors for each level, but
  // what can you do.
  auto materialized = tensors.materialize();
  dim = at::legacy_cat_wrap_dim(dim, materialized);

  // Strategy:
  // we're going to unwrap tensors, move their batch dims to the front,
  // and put them into `tensors_to_cat`. Tensors that don't have a batch dim
  // will get one forced onto them.
  //
  // Then, we'll do at::cat(tensors_to_cat, ...).
  //
  // There's a special case where at::cat ignores tensors that have logical shape
  // [0]. If we see a Tensor that has logical shape [0] (but physical shape [B, 0]),
  // we'll just slice the tensor to get a Tensor of shape [0] to pass to at::cat.
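  //
  // Illustrative example (shapes made up for this comment): cat'ing a
  // BatchedTensor with physical shape [B, 2, 3] (bdim = 0) and a plain tensor
  // of shape [2, 3] along logical dim 0 expands the plain tensor to [B, 2, 3],
  // cats along physical dim 1 to get [B, 4, 3], and rewraps the result with
  // bdim = 0 (logical shape [4, 3]).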
  std::vector<Tensor> tensors_to_cat;
  tensors_to_cat.reserve(tensors.size());
  std::optional<int64_t> bdim_size = std::nullopt;

  // find the bdim size. Might not exist if all BatchedTensors should be skipped
  // by cat's special case.
  for (const auto& tensor : tensors) {
    if (!participatesInCurrentLevel(tensor)) {
      continue;
    }
    if (at::native::cat_should_skip_tensor(tensor)) {
      continue;
    }
    const auto* batched = unsafeGetBatchedImpl(tensor);
    bdim_size = batched->value().size(batched->bdim());
    break;
  }

  // unwrap batchedtensors; expand out bdims
  for (const auto& tensor : tensors) {
    if (!participatesInCurrentLevel(tensor)) {
      if (at::native::cat_should_skip_tensor(tensor) || !bdim_size.has_value()) {
        tensors_to_cat.emplace_back(tensor);
        continue;
      }
      tensors_to_cat.emplace_back(ensure_has_bdim(tensor, /*has_bdim*/false, *bdim_size));
      continue;
    }
    const auto* batched = unsafeGetBatchedImpl(tensor);
    if (at::native::cat_should_skip_tensor(tensor)) {
      // Special case: slice the tensor to get something of shape [0] to pass to cat
      // We slice instead of allocate a new tensor to propagate requires_gradness...
      tensors_to_cat.emplace_back(batched->value().select(/*dim=*/batched->bdim(), /*index=*/0));
      continue;
    }
    tensors_to_cat.emplace_back(moveBatchDimToFront(batched->value(), batched->bdim()));
  }

  auto new_dim = bdim_size.has_value() ? dim + 1 : dim;
  std::optional<int64_t> new_bdim = bdim_size.has_value() ? std::make_optional((int64_t)0) : std::nullopt;
  auto result = at::cat(tensors_to_cat, new_dim);
  return makeBatched(result, new_bdim, get_current_level());
}

Tensor block_diag_batching_rule(TensorList tensors) {
  if (!participatesInCurrentLevel(tensors)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::block_diag(tensors);
  }
  auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors);
  auto physical_tensors = fmap(
      physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); });
  TORCH_INTERNAL_ASSERT(
      !tensors.empty(), "The dispatcher should not have dispatched here otherwise.");
  // Implementing this as a dummy for loop for now, since I'm not sure how to do it any better.
  // I'm probably not accounting for potentially multiple batched dimensions?
  auto bdim = physical_tensors[0].size(0);
  std::vector<Tensor> batched_outputs;
  batched_outputs.reserve(bdim);
  for (const auto& i : c10::irange(bdim)) {
    std::vector<Tensor> inputs_for_batch;
    inputs_for_batch.reserve(physical_tensors.size());
    for (const auto& t : physical_tensors) {
      inputs_for_batch.push_back(t[i]);
    }
    auto out_for_batch = at::block_diag(inputs_for_batch);
    batched_outputs.push_back(out_for_batch.unsqueeze(0));
  }
  auto result = at::cat(batched_outputs);
  return physical_views[0].getPhysicalToLogicalMap().apply(result);
}

Tensor stack_batching_rule(TensorList tensors, int64_t dim) {
  if (!participatesInCurrentLevel(tensors)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::stack(tensors, dim);
  }
  auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors);
  auto physical_tensors = fmap(
      physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); });
  TORCH_INTERNAL_ASSERT(
      !tensors.empty(), "The dispatcher should not have dispatched here otherwise.");
  // NB: stack wraps the dimensionality to (logical dim + 1), so we have to
  // manually handle that here.
  auto dim_physical =
      physical_views[0].numBatchDims() + maybe_wrap_dim(dim, /*logical*/tensors[0].dim() + 1);
  auto result = at::stack(physical_tensors, dim_physical);
  return physical_views[0].getPhysicalToLogicalMap().apply(result);
}

Tensor new_empty_strided_batching_rule(
    const Tensor& self,
    SymIntArrayRef sym_size,
    SymIntArrayRef sym_stride,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory) {

  auto size = C10_AS_INTARRAYREF_SLOW(sym_size);
  auto stride = C10_AS_INTARRAYREF_SLOW(sym_stride);
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return self.new_empty_strided(
        size, stride, dtype, layout, device, pin_memory);
  }

  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  auto physical_size = physical_view.getPhysicalShape(size);

  // Let [B0, B1, B2] be the shape of the batch dims. We're going to create
  // the batch dimensions at the front of the tensor (in memory layout),
  // irrespective of whether or not they are actually at the front (in memory layout)
  // in the original `self` tensor. This is because when a user calls
  // `new_empty_strided` in general, the `strides` they provide are for a new
  // tensor and have no relation to the strides of the original tensor.
  //
  // So, the physical shape of the result should be ([B0, B1, B2] + size),
  // but what about the physical strides?
  //
  // We're actually free to pick whatever stride we want:
  // e.g., for size=[5, 3], stride=[0, 1], we could decide to
  // use
  // - physical size: [B0, B1, B2, 5, 3]
  // - physical stride: [9999*B1*B2, 9999*B2, 9999, 0, 1]
  //
  // Let's select some reasonable strides such that:
  // - The batch dims are "contiguous" with respect to each other
  // - if empty_strided(size, stride) would have created a contiguous Tensor,
  // then this new physical Tensor (with batch dims) is also contiguous
  //
  // Let S be the size of the storage if one were to construct a tensor
  // with `size` and `stride` via empty_strided(size, stride).
  // Then the physical sizes/strides should be:
  // - physical size: [B0, B1, B2, 5, 3]
  // - physical stride: [B1 * B2 * S, B2 * S, S, 0, 1]
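  //
  // For the running example size=[5, 3], stride=[0, 1], that storage size is
  // S = 1 + (5-1)*0 + (3-1)*1 = 3, so the physical strides work out to
  // [B1 * B2 * 3, B2 * 3, 3, 0, 1].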
  auto batch_shape = IntArrayRef(
      physical_view.tensor().sizes().begin(), physical_view.numBatchDims());

  // physical_strides = [B1 * B2 * S, B2 * S, S]
  auto physical_strides = at::detail::defaultStrides(batch_shape);
  TORCH_CHECK(size.size() == stride.size(),
        "new_empty_strided(sizes, strides): dimensionality of sizes (",
        size.size(), ") must match dimensionality of strides (",
        stride.size(), ")");
  auto storage_size = native::storage_size_for(size, stride);
  for (auto& physical_stride : physical_strides) {
    physical_stride *= storage_size;
  }

  // physical_strides = [B1 * B2 * S, B2 * S, S] + strides
  physical_strides.insert(physical_strides.end(), stride.begin(), stride.end());

  auto result = physical_view.tensor().new_empty_strided(
      physical_size, physical_strides, dtype, layout, device, pin_memory);
  return physical_view.getPhysicalToLogicalMap().apply(result);
}

Tensor nested_cat_batching_rule(const ITensorListRef& tensors, int64_t dim) {
  TORCH_CHECK(!tensors.empty(), "cat() not supported on empty tensor list");

  std::vector<std::vector<Tensor>> unbound;
  for (const auto & tensor : tensors) {
    auto* maybe_batched_impl = maybeGetBatchedImpl(tensor);
    TORCH_CHECK(maybe_batched_impl, "Tried to run batching rule for cat() on a non-batched tensor");
    auto nt = maybe_batched_impl->value();
    TORCH_CHECK(nt.is_nested(), "Tried to run batching rule for cat() on a non-nested tensor");
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::BatchedNestedTensor);
    auto this_unbound = nt.unbind();
    if (!unbound.empty()) {
      TORCH_INTERNAL_ASSERT(unbound.front().size() == this_unbound.size(),
          "cat() not supported for differently-sized nested arguments");
    }
    unbound.push_back(this_unbound);
  }

  // Do a cat for each set of zipped unbound components
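  // (e.g., with two batched nested tensors that each unbind into k components,
  // outputs[i] = at::cat({unbound[0][i], unbound[1][i]}, dim) for each i.)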
  const auto num_components = unbound.front().size();
  std::vector<Tensor> outputs;
  for (auto i : c10::irange(num_components)) {
    std::vector<Tensor> arg_list;
    for (auto j : c10::irange(unbound.size())) {
      arg_list.push_back(unbound[j][i]);
    }
    outputs.push_back(at::cat(arg_list, dim));
  }

  // NB: NTs only support batching over dim 0
  auto out_nt = at::_nested_tensor_from_tensor_list(outputs);
  return makeBatched(out_nt, 0, get_current_level());
}

} // namespace

TORCH_LIBRARY_IMPL(_, FuncTorchBatched, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&batchedTensorForLoopFallback>());
}

TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  // still legacy b/c returns multiple tensors
  m.impl("split.Tensor", split_batching_rule);
  m.impl("split_with_sizes", split_with_sizes_batching_rule);
  m.impl("split_with_sizes_copy", split_with_sizes_copy_batching_rule);
  m.impl("unbind.int", unbind_batching_rule);
  m.impl("cat", cat_batching_rule);
  m.impl("block_diag", block_diag_batching_rule);
  m.impl("stack", stack_batching_rule);

  // still legacy b/c needs special inplace rules
  m.impl("squeeze_", squeeze__batching_rule);
  m.impl("squeeze_.dim", squeeze_dim__batching_rule);
  m.impl("squeeze_.dims", squeeze_dims__batching_rule);
  m.impl("unsqueeze_", unsqueeze__batching_rule);
  m.impl("transpose_", transpose__batching_rule);

  // still legacy because these are ridiculously complicated
  m.impl("as_strided", as_strided_batching_rule);
  m.impl("new_empty_strided", new_empty_strided_batching_rule);

}

TORCH_LIBRARY_IMPL(_, BatchedNestedTensor, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&batchedNestedTensorForLoopFallback>());
}

// TODO: Move this somewhere better?
TORCH_LIBRARY_IMPL(aten, BatchedNestedTensor, m) {
  m.impl("cat", nested_cat_batching_rule);
}

} // namespace at::functorch