// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <torch/library.h>
#include <ATen/native/ResizeCommon.h>
#include <ATen/ATen.h>
#include <ATen/native/TensorShape.h>

#include <ATen/NestedTensorImpl.h>
#include <ATen/functorch/DynamicLayer.h>
#include <ATen/functorch/TensorWrapper.h>
#include <ATen/functorch/BatchingMetaprogramming.h>
#include <ATen/functorch/LegacyVmapTransforms.h>
#include <ATen/functorch/BatchedFallback.h>
#include <ATen/functorch/BatchRulesHelper.h>

#include <utility>

namespace at::functorch {


// NOTE: [What is a batching rule?]
//
// NB: the following description only applies to this file and is about
// the legacy (deprecated) batching rule API. Please see writing_batch_rules.md
// for how to write new-style batching rules.
//
// This file contains batching rules written with the legacy (now-deprecated)
// batching rule API.
// Please try to use the new-style batching rule API (see writing_batch_rules.md).
//
// A *batching rule* implements the logic of how to call an operator on inputs
// that have zero or more additional batch dimensions. When one does a vmap, the
// dimension(s) being vmap'ed over get recorded as batch dimensions.
//
// For example, vmap(torch.add)(x, y)
// 1. wraps `x` into batched_x = BatchedTensor(x, bdims=[(lvl=1, dim=0)]);
// 2. wraps `y` into batched_y = BatchedTensor(y, bdims=[(lvl=1, dim=0)]);
// 3. and then runs `torch.add(batched_x, batched_y)`.
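//
// As a concrete (illustrative) sketch of what this means for shapes: if
// `x` and `y` both have shape [B, 3], each per-sample computation sees
// logical tensors of shape [3], while the batching rule for `add` receives
// physical tensors of shape [B, 3], calls at::add on them, and re-wraps the
// result as a BatchedTensor with bdim=0 at the current vmap level.
// >>> x = torch.randn(B, 3)
// >>> y = torch.randn(B, 3)
// >>> out = vmap(torch.add)(x, y)  # out has shape [B, 3]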

// NOTE: [When should I add a batching rule?]
// When you are adding a new operator, you'll need to add a batching rule so
// that vmap can work efficiently with said operator. If you do not, we'll attempt
// to generate a slow fallback for the batching rule.

// NOTE: [How to write batching rules?]
// The signature of a batching rule should look exactly like the C++ signature
// of its operator.
//
// First, see NOTE: [Logical vs physical args] in VmapTransforms.h for terminology.
//
// At a high level, what a batching rule does is the following:
// 1. Converts (logical) BatchedTensors to views on physical tensors.
// 2. Converts logical arguments (e.g. dimension indexes, shapes) to physical
//    arguments that correspond to the physical tensors.
// 3. Calls at:: operations on the physical tensors and arguments to produce
//    some physical results.
// 4. Converts physical results back to BatchedTensors.
//
// Steps 1, 2, and 4 differ for operators with different batching behaviors. When
// writing a new batching rule, please select a VmapTransform that matches the
// batching behavior of your operation. The VmapTransform provides helper functions
// to do steps (1), (2), and (4).
// (see NOTE: [What is a VmapTransform?] in VmapTransforms.h)
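//
// As a hedged illustration only (the operator `at::my_op` below is
// hypothetical; the real rules later in this file follow the same pattern),
// a batching rule written against the legacy API roughly looks like:
//
//   Tensor my_op_batching_rule(const Tensor& self, int64_t dim) {
//     // (1) logical BatchedTensor -> view on the physical tensor
//     auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
//     // (2) logical dim -> physical dim
//     auto dim_physical = self_physical.getPhysicalDim(dim);
//     // (3) call the regular at:: operation on physical data
//     auto result = at::my_op(self_physical.tensor(), dim_physical);
//     // (4) wrap the physical result back into a BatchedTensor
//     return self_physical.getPhysicalToLogicalMap().apply(result);
//   }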

namespace {
// PyTorch allows operations to specify dim 0 and dim -1 on a scalar tensor.
static bool is_allowed_dim_on_scalar_tensor(int64_t dim) {
  return dim == 0 || dim == -1;
}

static int64_t get_current_level() {
  auto maybe_level = maybeCurrentDynamicLayer();
  TORCH_INTERNAL_ASSERT(maybe_level.has_value());
  return maybe_level->layerId();
}

// This check should probably go into the dispatcher...
static bool participatesInCurrentLevel(const Tensor& self) {
  auto current_level = get_current_level();
  auto* maybe_batched_impl = maybeGetBatchedImpl(self);
  if (!maybe_batched_impl) {
    return false;
  }
  auto self_level = maybe_batched_impl->level();
  TORCH_INTERNAL_ASSERT(self_level <= current_level);
  return self_level == current_level;
}

static bool participatesInCurrentLevel(ITensorListRef self) {
  for (const Tensor& tensor : self) {
    if (participatesInCurrentLevel(tensor)) {
      return true;
    }
  }
  return false;
}

Tensor& squeeze_dims__batching_rule(Tensor& self, IntArrayRef dims) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return self.squeeze_(dims);
  }
  auto* batched = maybeGetBatchedImpl(self);
  const auto bdim = batched->bdim();
  auto logical_dim = self.dim();

  if (logical_dim == 0) {
    TORCH_CHECK(
        dims.empty() || (dims.size() == 1 && dims[0] == 0),
        "Dimension is out of range (expected to be in range of [-1, 0], but got ", dims);
    return self;
  }

  // Adjust any dimensions higher than the batch dimension
  DimVector adjusted_dims(dims.begin(), dims.end());
  int64_t updated_batch_idx = bdim;
  for (auto& d : adjusted_dims) {
    auto actual_dim = c10::maybe_wrap_dim(d, logical_dim);
    if (actual_dim < bdim) {
      d = actual_dim;
      if (batched->value().sym_size(actual_dim) == 1) {
        // A size-1 dimension before the batch dimension will be dropped, so
        // adjust the batch index accordingly.
        --updated_batch_idx;
      }
    } else {
      // Since the dimension to be squeezed is after the batch dimension, adjust it
      // by one to account for the original batch dimension. In this case the batch
      // dimension won't move.
      d = actual_dim + 1;
    }
  }

  batched->value().squeeze_(adjusted_dims);
  if (updated_batch_idx != bdim) {
    batched->unsafe_set_bdim(updated_batch_idx);
  }
  batched->refreshTensorMetadata();
  return self;
}
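
// Worked example (illustrative shapes): if the physical tensor has shape
// [1, B, 3, 1] with bdim = 1, the logical shape is [1, 3, 1]. A call to
// squeeze_(dims=[0, 2]) maps to physical dims [0, 3] (dim 2 lies after the
// bdim, so it shifts by one). Physical dim 0 has size 1 and sits before the
// bdim, so after the in-place squeeze the bdim moves from 1 to 0 and the
// wrapper's bdim is updated to match.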

Tensor& squeeze_dim__batching_rule(Tensor& self, int64_t dim) {
  return squeeze_dims__batching_rule(self, {dim});
}

Tensor& squeeze__batching_rule(Tensor& self) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return self.squeeze_();
  }
  auto* batched = maybeGetBatchedImpl(self);

  // Need to find out how many dimensions of size 1 are before the bdim
  const auto bdim = batched->bdim();
  const auto physical_shape = batched->value().sizes();
  auto how_many_dims_of_size_1_before_bdim = 0;
  for (const auto i : c10::irange(0, physical_shape.size())) {
    if ((int64_t)i == bdim) {
      break;
    }
    if (physical_shape[i] == 1) {
      how_many_dims_of_size_1_before_bdim++;
    }
  }

  int64_t new_bdim = bdim - how_many_dims_of_size_1_before_bdim;
  if (physical_shape[bdim] != 1) {
    // if the bdim's size is not 1, we can just call squeeze_()
    batched->value().squeeze_();
  } else {
    // otherwise, squeeze_() is going to get rid of the bdim too.
    // We "fix it up" by calling unsqueeze_.
    batched->value().squeeze_();
    batched->value().unsqueeze_(new_bdim);
  }

  // Refresh metadata
  batched->unsafe_set_bdim(new_bdim);
  batched->refreshTensorMetadata();
  return self;
}

Tensor& unsqueeze__batching_rule(Tensor& self, int64_t dim) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return self.unsqueeze_(dim);
  }
  auto* batched = maybeGetBatchedImpl(self);
  auto logical_dim = self.dim();
  int64_t dim_physical = maybe_wrap_dim(dim, logical_dim + 1);
  if (dim_physical >= batched->bdim()) {
    dim_physical = 1 + dim_physical;
  } else {
    batched->unsafe_set_bdim(batched->bdim() + 1);
  }
  batched->value().unsqueeze_(dim_physical);

  // Also need to change some metadata...
  batched->refreshTensorMetadata();
  return self;
}

Tensor& transpose__batching_rule(Tensor& self, int64_t dim0, int64_t dim1) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return self.transpose_(dim0, dim1);
  }
  auto* batched = maybeGetBatchedImpl(self);
  auto logical_dim = self.dim();

  // PyTorch has a special case where scalar_tensor.transpose(dim0, dim1) works
  // for dim0, dim1 in {0, -1} and returns the scalar tensor. If the following happens:
  // >>> x = torch.randn(B0)  # the per-examples are all scalars
  // >>> vmap(lambda x: x.transpose_(0, -1))(x)
  // then we replicate this behavior.
  if (logical_dim == 0 &&
      is_allowed_dim_on_scalar_tensor(dim0) &&
      is_allowed_dim_on_scalar_tensor(dim1)) {
    // No transposing happened :P
    return self;
  }

  dim0 = maybe_wrap_dim(dim0, logical_dim);
  dim1 = maybe_wrap_dim(dim1, logical_dim);

  dim0 = dim0 >= batched->bdim() ? dim0 + 1 : dim0;
  dim1 = dim1 >= batched->bdim() ? dim1 + 1 : dim1;
  batched->value().transpose_(dim0, dim1);

  // Also need to change some metadata...
  batched->refreshTensorMetadata();
  return self;
}

std::vector<Tensor> split_batching_rule(const Tensor& self, int64_t split_size, int64_t dim) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::split(self, split_size, dim);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::split(self_physical.tensor(), split_size, dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

std::vector<Tensor> split_with_sizes_batching_rule(const Tensor& self, SymIntArrayRef split_sizes, int64_t dim) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return split_with_sizes_symint(self, split_sizes, dim);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = split_with_sizes_symint(self_physical.tensor(), split_sizes, dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

std::vector<Tensor> split_with_sizes_copy_batching_rule(const Tensor& self, SymIntArrayRef split_sizes, int64_t dim) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return split_with_sizes_copy_symint(self, split_sizes, dim);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = split_with_sizes_copy_symint(self_physical.tensor(), split_sizes, dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

std::vector<Tensor> unbind_batching_rule(const Tensor& self, int64_t dim) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::unbind(self, dim);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::unbind(self_physical.tensor(), dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

// given (sizes, strides, storage_offset) returns the maximum location that
// can be indexed (or nullopt if such a location doesn't exist, e.g., tensors
// with zero-size dims).
static std::optional<c10::SymInt> maximum_indexable_location(
    c10::SymIntArrayRef sizes, c10::SymIntArrayRef strides, const c10::SymInt& storage_offset) {
  auto result = native::storage_size_for(sizes, strides);
  if (result == 0) {
    return std::nullopt;
  }
  return result + storage_offset;
}
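
// Illustrative arithmetic (assuming storage_size_for returns
// 1 + \sum_i (sizes[i] - 1) * strides[i] for tensors without zero-size dims):
// for sizes = [2, 3], strides = [3, 1], storage_offset = 4, we get
// storage_size_for = 1 + (2-1)*3 + (3-1)*1 = 6, so the returned value is
// 4 + 6 = 10, one past the largest reachable element index (4 + 3 + 2 = 9).
// Both sides of the bounds check below are computed the same way, so the
// comparison is consistent.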

// Let x be the "first slice" of physical_tensor.
// This checks that the range of possible memory locations accessible by
// x.as_strided(sizes, strides, maybe_storage_offset)
// is within the bounds of possible memory locations accessible by x.
static void checkBasicAsStridedValidForSlice(
    const Tensor& physical_tensor,
    int64_t num_batch_dims,
    c10::SymIntArrayRef sizes,
    c10::SymIntArrayRef strides,
    const std::optional<c10::SymInt>& maybe_storage_offset) {
  auto slice_sizes = physical_tensor.sym_sizes().slice(num_batch_dims);
  auto slice_strides = physical_tensor.sym_strides().slice(num_batch_dims);
  auto base_offset = physical_tensor.sym_storage_offset();

  auto storage_offset = maybe_storage_offset.value_or(base_offset);

  auto max_as_strided_loc = maximum_indexable_location(sizes, strides, storage_offset);
  auto max_slice_loc = maximum_indexable_location(slice_sizes, slice_strides, base_offset);

  if (!max_as_strided_loc.has_value()) {
    return;
  }
  if (!max_slice_loc.has_value()) {
    TORCH_CHECK(false,
        "result = tensor.as_strided(", sizes, ", ", strides, ", ", storage_offset, ") ",
        "can access memory outside of `tensor`. `tensor` has no storage but the ",
        "passed-in (size, stride, storage_offset) imply a result with some storage. ",
        "This is not supported inside of vmap, please try to rewrite the ",
        "`as_strided` call as a sequence of PyTorch view operations");
  }

  TORCH_CHECK(
      *max_as_strided_loc <= *max_slice_loc && base_offset <= storage_offset,
      "result = tensor.as_strided(", sizes, ", ", strides, ", ", storage_offset, ") ",
      "can access memory outside of `tensor`. `result` can access some ",
      "memory in range [", storage_offset, ", ", *max_as_strided_loc, "], but ",
      "`tensor` can only access some memory in range [", base_offset, ", ",
      *max_slice_loc, "]. This is not supported inside of vmap, please try to ",
      "rewrite the `as_strided` call as a sequence of PyTorch view operations");
}

// What are the semantics of as_strided inside of vmap?
// y = vmap(lambda x: x.as_strided(sizes, strides, offset))(xs)
// This returns a view `y` on `xs` such that each y[i] has:
// - sizes: `sizes`
// - strides: `strides`
// - storage_offset: offset + i * xs.stride(batch_dim)
//
// In other words, it is as if we had treated each x[i] as having storage
// offset equal to xs.offset() and called as_strided(sizes, strides, offset).
// (that is equivalent to x[i].as_strided(
//    sizes, strides, offset + x[i].storage_offset() - xs.offset()) for all i)
//
// Note that this *may* be different from actually running as_strided
// in a for-loop. This is due to how as_strided takes in `offset` to be
// an *absolute* offset. As an example, consider:
// >>> x = torch.tensor([0., 1., 2., 3., 4.]).as_strided([4], [1], 1)
// >>> z = [x[i].as_strided([1], [1], 1) for i in range(4)]
// Each z[i] is actually the same view on x (z[i] == torch.tensor([1.]))!
// However, we consider the above for-loop comprehension to be a user error:
// a user should have written the following if they wanted to use as_strided
// in a per-sample way:
// >>> z = [x[i].as_strided([1], [1], 1 + x[i].storage_offset() - 1) for i in range(4)]
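//
// A small illustrative example of the intended per-sample semantics
// (shapes made up for exposition): if `xs` is a freshly-allocated contiguous
// tensor of shape [B, 5] (so xs.stride(0) == 5 and xs.storage_offset() == 0),
// >>> y = vmap(lambda x: x.as_strided([2], [1], 1))(xs)
// then each y[i] has sizes [2], strides [1], and storage_offset 1 + 5*i,
// i.e. it views elements 1 and 2 of the i-th row of xs, exactly as
// xs[i].as_strided([2], [1], 1 + xs[i].storage_offset() - xs.storage_offset()) would.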
Tensor as_strided_batching_rule(
    const Tensor& tensor,
    c10::SymIntArrayRef sizes,
    c10::SymIntArrayRef strides,
    std::optional<c10::SymInt> storage_offset) {
  if (!participatesInCurrentLevel(tensor)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::as_strided_symint(tensor, sizes, strides, std::move(storage_offset));
  }
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(tensor);
  auto num_batch_dims = physical_view.numBatchDims();
  auto physical_sizes = physical_view.getPhysicalShape(sizes);
  const auto& physical_tensor = physical_view.tensor();

  // We can't rely on the physical as_strided call to do this for us because
  // we do some sanity checks on the size/strides before calling into as_strided.
  TORCH_CHECK(sizes.size() == strides.size(),
      "Tensor.as_strided(size, stride, ...): size and stride must have the ",
      "same length! Got size ", sizes, " and stride ", strides);

  // Sanity checks:
  // 1. as_strided(sizes, strides, storage_offset + tensor[i].offset() - tensor.offset())
  //    is valid for a slice of the input tensor.
  //    See NOTE: [When will the as_strided batching rule fail?] for details.
  checkBasicAsStridedValidForSlice(
      physical_tensor, num_batch_dims, sizes, strides, storage_offset);

  // physical_strides = physical tensor's batch strides + (logical) strides
  auto batch_strides = physical_tensor.strides().slice(0, num_batch_dims);
  SymDimVector physical_strides;
  physical_strides.reserve(num_batch_dims + strides.size());
  physical_strides.insert(
      physical_strides.end(), batch_strides.begin(), batch_strides.end());
  physical_strides.insert(
      physical_strides.end(), strides.begin(), strides.end());

  // If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
  // is valid for all i, then it turns out that
  // xs.as_strided(physical_sizes, physical_strides, offset) always succeeds
  // and creates a tensor y such that each y[i] references the same memory
  // locations as zi. See NOTE: [When will the as_strided batching rule fail?]
  auto result = physical_view.tensor().as_strided_symint(
      physical_sizes, physical_strides, std::move(storage_offset));
  return physical_view.getPhysicalToLogicalMap().apply(result);
}

// NOTE: [When will the as_strided batching rule fail?]
// If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
// is valid for all i, then it turns out that
// xs.as_strided(physical_sizes, physical_strides, offset) always succeeds and
// creates a tensor y such that each y[i] refers to the same memory as zi.
//
// Let's say we have xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()).
// Furthermore, let's say that as a part of being "valid" this as_strided call
// does not return a result that can index memory not indexable by xs[i].
//
// WLOG, assume that there's only one batch dim and it is at the front of the
// `xs` tensor. Let B be the batch size and S be the stride of the batch dim.
// - If the batch dim isn't at the front of the tensor, then we can just move it
//   to the front with movedim/permute. This is always valid because it just swaps
//   some strides around.
// - This proof also works for tensors with multiple batch dims. We just have to
//   do a little accounting:
//   - instead of [B], we'd have [B0, B1, ..., Bk].
//   - instead of [S], we'd have [S0, S1, ..., Sk].
//   - instead of i, we'd have a list of indices [I0, I1, ..., Ik]
//   - instead of S * i, we'd have \sum_{i=0}^k S_i * I_i
//
// [Equation 1]
// xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) has:
// - sizes: sizes
// - strides: strides
// - offset: offset + S * i
//
// x.as_strided itself checks that:
// - (sizes, strides, offset) are in bounds for `x`'s storage.
// - strides are non-negative
// - offset is non-negative
//
// Claim 1: if xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
// is valid, then
// ([B] + sizes, [S] + strides, offset + xs.offset()) are in bounds for `xs`'s storage.
//
// If we have the claim, then xs.as_strided([B] + sizes, [S] + strides, offset)
// won't error out. So all we need to check is that the memory locations are
// what we expected. See [Hand-wavy proof of Claim 1] for proof (it's not very important).
//
// xs.as_strided(physical_sizes, physical_strides, offset) is equivalent to
// xs.as_strided([B] + sizes, [S] + strides, offset)
//
// xs.as_strided([B] + sizes, [S] + strides, offset) has:
// - sizes: [B] + sizes
// - strides: [S] + strides
// - offset: offset
//
// xs.as_strided([B] + sizes, [S] + strides, offset)[i] has:
// - sizes: sizes
// - strides: strides
// - offset: offset + S * i
// These memory locations are exactly the same as what we got for [Equation 1],
// so xs.as_strided([B] + sizes, [S] + strides, offset) is valid.
//
// [Hand-wavy proof of Claim 1]
// Part of our definition of being valid is that xs[i].as_strided(...)
// must return a tensor that only uses memory indexable by xs[i].
// This means that (sizes, strides, offset + xs[i].offset() - xs.offset()) satisfies:
//   offset + xs[i].offset() - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j]
//   <= xs[i].offset() + 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
// (the largest-index memory location of xs[i].as_strided(...) must be \leq
// the largest-index memory location of xs[i])
//
// Fiddling that inequality gives us:
//   offset - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j]
//   <= 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
//
//   offset - xs.offset() + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
//   <= 1 + (B-1)*S + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
//
//   offset - xs.offset() + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
//   <= 1 + \sum_j (xs.size(j) - 1) * xs.stride(j)
//
//   offset + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
//   <= xs.offset() + 1 + \sum_j (xs.size(j) - 1) * xs.stride(j)
// (the largest-index memory location of xs.as_strided(size, stride, offset)
// is \leq the largest-index memory location of xs)
// Under the assumptions we've made, the lower bound (lowest indexed memory)
// is trivially within the storage.
//
// Therefore ([B] + sizes, [S] + strides, offset) are in bounds for
// `xs`'s storage.
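//
// Concrete instance of the claim (numbers are illustrative): let xs be a
// contiguous tensor of shape [3, 4], so B = 3, S = 4, xs.offset() = 0, and
// take sizes = [2], strides = [2], offset = 1.
// - Per-sample: xs[i].as_strided([2], [2], 1 + 4*i) reaches at most index
//   1 + 4*i + (2-1)*2 = 3 + 4*i, while xs[i] itself reaches 4*i + 3, so each
//   per-sample call is valid.
// - Batched: xs.as_strided([3, 2], [4, 2], 1) reaches at most index
//   1 + (3-1)*4 + (2-1)*2 = 11, while xs reaches (3-1)*4 + (4-1)*1 = 11, so
//   the single physical as_strided call is in bounds, and its i-th slice
//   touches exactly the memory that the per-sample call would.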

template <typename F, F Func, typename... ExtraArgs>
Tensor unwrap_and_call(const Tensor& input, ExtraArgs... args) {
  if (!participatesInCurrentLevel(input)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return Func(input, args...);
  }
  // guard against the user passing in a batch of scalar tensors with batch
  auto* input_batched = unsafeGetBatchedImpl(input);
  auto output_physical = Func(input_batched->value(), args...);
  return makeBatched(output_physical, input_batched->bdim(), input_batched->level());
}

template <typename F, F Func, typename... ExtraArgs>
Tensor unwrap_and_call_method(const Tensor& input, ExtraArgs... extra_args) {
  if (!participatesInCurrentLevel(input)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return (input.*Func)(extra_args...);
  }
  auto* input_batched = unsafeGetBatchedImpl(input);
  auto output_physical = (input_batched->value().*Func)(extra_args...);
  return makeBatched(output_physical, input_batched->bdim(), input_batched->level());
}
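
// Neither helper is registered in this file; purely as a hypothetical
// illustration (the op name is made up), a free function taking one extra
// argument could be wired up like so:
//   m.impl("my_unary_op", unwrap_and_call<
//       Tensor (*)(const Tensor&, double), &at::my_unary_op, double>);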

Tensor cat_batching_rule(const ITensorListRef& tensors, int64_t dim) {
  if (!participatesInCurrentLevel(tensors)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::cat(tensors, dim);
  }

  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);

  // NB: Probably bad for perf that we're allocating std::vectors for each level, but
  // what can you do.
  auto materialized = tensors.materialize();
  dim = at::legacy_cat_wrap_dim(dim, materialized);

  // Strategy:
  // we're going to unwrap tensors, move their batch dims to the front,
  // and put them into `tensors_to_cat`. Tensors that don't have a batch dim
  // will get one forced onto them.
  //
  // Then, we'll do at::cat(tensors_to_cat, ...).
  //
  // There's a special case where at::cat ignores tensors that have logical shape
  // [0]. If we see a Tensor that has logical shape [0] (but physical shape [B, 0]),
  // we'll just slice the tensor to get a Tensor of shape [0] to pass to at::cat.
  std::vector<Tensor> tensors_to_cat;
  tensors_to_cat.reserve(tensors.size());
  std::optional<int64_t> bdim_size = std::nullopt;

  // find the bdim size. Might not exist if all BatchedTensors should be skipped
  // by cat's special case.
  for (const auto& tensor : tensors) {
    if (!participatesInCurrentLevel(tensor)) {
      continue;
    }
    if (at::native::cat_should_skip_tensor(tensor)) {
      continue;
    }
    const auto* batched = unsafeGetBatchedImpl(tensor);
    bdim_size = batched->value().size(batched->bdim());
    break;
  }

  // unwrap batchedtensors; expand out bdims
  for (const auto& tensor : tensors) {
    if (!participatesInCurrentLevel(tensor)) {
      if (at::native::cat_should_skip_tensor(tensor) || !bdim_size.has_value()) {
        tensors_to_cat.emplace_back(tensor);
        continue;
      }
      tensors_to_cat.emplace_back(ensure_has_bdim(tensor, /*has_bdim*/false, *bdim_size));
      continue;
    }
    const auto* batched = unsafeGetBatchedImpl(tensor);
    if (at::native::cat_should_skip_tensor(tensor)) {
      // Special case: slice the tensor to get something of shape [0] to pass to cat
      // We slice instead of allocate a new tensor to propagate requires_gradness...
      tensors_to_cat.emplace_back(batched->value().select(/*dim=*/batched->bdim(), /*index=*/0));
      continue;
    }
    tensors_to_cat.emplace_back(moveBatchDimToFront(batched->value(), batched->bdim()));
  }

  auto new_dim = bdim_size.has_value() ? dim + 1 : dim;
  std::optional<int64_t> new_bdim = bdim_size.has_value() ? std::make_optional((int64_t)0) : std::nullopt;
  auto result = at::cat(tensors_to_cat, new_dim);
  return makeBatched(result, new_bdim, get_current_level());
}
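
// Illustrative usage (shapes made up): with a batched `x` of shape [B, 2, 3]
// and an unbatched `y` of shape [2, 3] captured by the closure,
// >>> out = vmap(lambda x: torch.cat([x, y]))(x)
// `y` is expanded to carry a batch dim of size B, both tensors end up in
// `tensors_to_cat` with the batch dim at the front, the physical cat runs
// along dim + 1, and `out` has shape [B, 4, 3].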

Tensor block_diag_batching_rule(TensorList tensors) {
  if (!participatesInCurrentLevel(tensors)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::block_diag(tensors);
  }
  auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors);
  auto physical_tensors = fmap(
      physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); });
  TORCH_INTERNAL_ASSERT(
      !tensors.empty(), "The dispatcher should not have dispatched here otherwise.");
  // Implementing this as a dummy for loop for now, since I'm not sure how to do it any better.
  // I'm probably not accounting for potentially multiple batched dimensions?
  auto bdim = physical_tensors[0].size(0);
  std::vector<Tensor> batched_outputs;
  batched_outputs.reserve(bdim);
  for (const auto& i : c10::irange(bdim)) {
    std::vector<Tensor> inputs_for_batch;
    inputs_for_batch.reserve(physical_tensors.size());
    for (const auto& t : physical_tensors) {
      inputs_for_batch.push_back(t[i]);
    }
    auto out_for_batch = at::block_diag(inputs_for_batch);
    batched_outputs.push_back(out_for_batch.unsqueeze(0));
  }
  auto result = at::cat(batched_outputs);
  return physical_views[0].getPhysicalToLogicalMap().apply(result);
}

Tensor stack_batching_rule(TensorList tensors, int64_t dim) {
  if (!participatesInCurrentLevel(tensors)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return at::stack(tensors, dim);
  }
  auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors);
  auto physical_tensors = fmap(
      physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); });
  TORCH_INTERNAL_ASSERT(
      !tensors.empty(), "The dispatcher should not have dispatched here otherwise.");
  // NB: stack wraps the dimensionality to (logical dim + 1), so we have to
  // manually handle that here.
  auto dim_physical =
      physical_views[0].numBatchDims() + maybe_wrap_dim(dim, /*logical*/tensors[0].dim() + 1);
  auto result = at::stack(physical_tensors, dim_physical);
  return physical_views[0].getPhysicalToLogicalMap().apply(result);
}

Tensor new_empty_strided_batching_rule(
    const Tensor& self,
    SymIntArrayRef sym_size,
    SymIntArrayRef sym_stride,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory) {

  auto size = C10_AS_INTARRAYREF_SLOW(sym_size);
  auto stride = C10_AS_INTARRAYREF_SLOW(sym_stride);
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    return self.new_empty_strided(
        size, stride, dtype, layout, device, pin_memory);
  }

  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  auto physical_size = physical_view.getPhysicalShape(size);

  // Let [B0, B1, B2] be the shape of the batch dims. We're going to create
  // the batch dimensions at the front of the tensor (in memory layout),
  // irrespective of whether or not they are actually at the front (in memory layout)
  // in the original `self` tensor. This is because when a user calls
  // `new_empty_strided` in general, the `strides` they provide are for a new
  // tensor and have no relation to the strides of the original tensor.
  //
  // So, the physical shape of the result should be ([B0, B1, B2] + size),
  // but what about the physical strides?
  //
  // We're actually free to pick whatever stride we want:
  // e.g., for size=[5, 3], stride=[0, 1], we could decide to
  // use
  // - physical size: [B0, B1, B2, 5, 3]
  // - physical stride: [9999*B1*B2, 9999*B2, 9999, 0, 1]
  //
  // Let's select some reasonable strides such that:
  // - The batch dims are "contiguous" with respect to each other
  // - if empty_strided(size, stride) would have created a contiguous Tensor,
  //   then this new physical Tensor (with batch dims) is also contiguous
  //
  // Let S be the size of the storage if one were to construct a tensor
  // with `size` and `stride` via empty_strided(size, stride).
  // Then the physical sizes/strides should be:
  // - physical size: [B0, B1, B2, 5, 3]
  // - physical stride: [B1 * B2 * S, B2 * S, S, 0, 1]
  auto batch_shape = IntArrayRef(
      physical_view.tensor().sizes().begin(), physical_view.numBatchDims());

  // physical_strides = [B1 * B2 * S, B2 * S, S]
  auto physical_strides = at::detail::defaultStrides(batch_shape);
  TORCH_CHECK(size.size() == stride.size(),
      "new_empty_strided(sizes, strides): dimensionality of sizes (",
      size.size(), ") must match dimensionality of strides (",
      stride.size(), ")");
  auto storage_size = native::storage_size_for(size, stride);
  for (auto& physical_stride : physical_strides) {
    physical_stride *= storage_size;
  }

  // physical_strides = [B1 * B2 * S, B2 * S, S] + strides
  physical_strides.insert(physical_strides.end(), stride.begin(), stride.end());

  auto result = physical_view.tensor().new_empty_strided(
      physical_size, physical_strides, dtype, layout, device, pin_memory);
  return physical_view.getPhysicalToLogicalMap().apply(result);
}
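
// Worked example (illustrative numbers): suppose the batch shape is
// [B0, B1] = [2, 3] and the caller asks for size = [5, 3], stride = [3, 1].
// Then S = storage_size_for(size, stride) = 1 + (5-1)*3 + (3-1)*1 = 15,
// defaultStrides([2, 3]) = [3, 1], and after multiplying by S and appending
// the requested strides, the physical result has size [2, 3, 5, 3] and
// strides [45, 15, 3, 1].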

Tensor nested_cat_batching_rule(const ITensorListRef& tensors, int64_t dim) {
  TORCH_CHECK(!tensors.empty(), "cat() not supported on empty tensor list");

  std::vector<std::vector<Tensor>> unbound;
  for (const auto& tensor : tensors) {
    auto* maybe_batched_impl = maybeGetBatchedImpl(tensor);
    TORCH_CHECK(maybe_batched_impl, "Tried to run batching rule for cat() on a non-batched tensor");
    auto nt = maybe_batched_impl->value();
    TORCH_CHECK(nt.is_nested(), "Tried to run batching rule for cat() on a non-nested tensor");
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::BatchedNestedTensor);
    auto this_unbound = nt.unbind();
    if (!unbound.empty()) {
      TORCH_INTERNAL_ASSERT(unbound.front().size() == this_unbound.size(),
          "cat() not supported for differently-sized nested arguments");
    }
    unbound.push_back(this_unbound);
  }

  // Do a cat for each set of zipped unbound components
  const auto num_components = unbound.front().size();
  std::vector<Tensor> outputs;
  for (auto i : c10::irange(num_components)) {
    std::vector<Tensor> arg_list;
    for (auto j : c10::irange(unbound.size())) {
      arg_list.push_back(unbound[j][i]);
    }
    outputs.push_back(at::cat(arg_list, dim));
  }

  // NB: NTs only support batching over dim 0
  auto out_nt = at::_nested_tensor_from_tensor_list(outputs);
  return makeBatched(out_nt, 0, get_current_level());
}

} // namespace

TORCH_LIBRARY_IMPL(_, FuncTorchBatched, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&batchedTensorForLoopFallback>());
}

TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  // still legacy b/c returns multiple tensors
  m.impl("split.Tensor", split_batching_rule);
  m.impl("split_with_sizes", split_with_sizes_batching_rule);
  m.impl("split_with_sizes_copy", split_with_sizes_copy_batching_rule);
  m.impl("unbind.int", unbind_batching_rule);
  m.impl("cat", cat_batching_rule);
  m.impl("block_diag", block_diag_batching_rule);
  m.impl("stack", stack_batching_rule);

  // still legacy b/c needs special inplace rules
  m.impl("squeeze_", squeeze__batching_rule);
  m.impl("squeeze_.dim", squeeze_dim__batching_rule);
  m.impl("squeeze_.dims", squeeze_dims__batching_rule);
  m.impl("unsqueeze_", unsqueeze__batching_rule);
  m.impl("transpose_", transpose__batching_rule);

  // still legacy because these are ridiculously complicated
  m.impl("as_strided", as_strided_batching_rule);
  m.impl("new_empty_strided", new_empty_strided_batching_rule);
}

TORCH_LIBRARY_IMPL(_, BatchedNestedTensor, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&batchedNestedTensorForLoopFallback>());
}

// TODO: Move this somewhere better?
TORCH_LIBRARY_IMPL(aten, BatchedNestedTensor, m) {
  m.impl("cat", nested_cat_batching_rule);
}

} // namespace at::functorch