// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchedFallback.h>
#include <ATen/functorch/LegacyVmapTransforms.h>
#include <ATen/functorch/TensorWrapper.h>
#include <ATen/functorch/DynamicLayer.h>
#include <ATen/functorch/PlumbingHelper.h>

#include <ATen/Context.h>
#include <ATen/MatrixRef.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/util/accumulate.h>
#include <c10/util/llvmMathExtras.h>
#include <c10/util/irange.h>

namespace at::functorch {

bool kVmapFallbackWarningEnabled = true;

bool isVmapFallbackWarningEnabled() {
  return kVmapFallbackWarningEnabled;
}

void setVmapFallbackWarningEnabled(bool enabled) {
  kVmapFallbackWarningEnabled = enabled;
}

bool kVmapFallbackEnabled = true;

bool isVmapFallbackEnabled() {
  return kVmapFallbackEnabled;
}

void setVmapFallbackEnabled(bool enabled) {
  kVmapFallbackEnabled = enabled;
}

// Given a linear index, return the actual (multi-dimensional) index.
// Example: given linear_idx = 3, sizes = [5, 2], we would return [1, 1].
static at::SmallVector<indexing::TensorIndex,kVmapStaticDimVecSize>
computeIndex(int64_t linear_idx, IntArrayRef sizes) {
  at::SmallVector<indexing::TensorIndex,kVmapStaticDimVecSize> result;
  result.reserve(sizes.size());
  for (auto it = sizes.rbegin(); it != sizes.rend(); it++) {
    auto remainder = linear_idx % *it;
    result.push_back(remainder);
    linear_idx -= remainder;
    linear_idx /= *it;
  }
  std::reverse(std::begin(result), std::end(result));
  return result;
}
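//
// Illustrative trace (not part of the implementation): computeIndex(3, {5, 2})
// walks sizes right-to-left:
//   3 % 2 = 1 -> push 1; linear_idx becomes (3 - 1) / 2 = 1
//   1 % 5 = 1 -> push 1; linear_idx becomes (1 - 1) / 5 = 0
// then reverses the result, yielding [1, 1], i.e. the row-major unflattening
// of linear index 3 over a [5, 2] grid of batches.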

static bool areAllReturnsTensors(const at::FunctionSchema& schema) {
  return std::all_of(
      schema.returns().begin(),
      schema.returns().end(),
      [] (const Argument& arg) { return arg.type() == TensorType::get(); });
}

static bool areAnyArgumentsTensorList(const at::FunctionSchema& schema) {
  return std::any_of(
      schema.arguments().begin(),
      schema.arguments().end(),
      [] (const Argument& arg) {
        return arg.type()->isSubtypeOf(ListType::ofTensors()) ||
          arg.type()->isSubtypeOf(ListType::ofOptionalTensors());
      });
}

static void warnFallback(const c10::FunctionSchema& schema, bool is_inplace, bool is_nested=false) {
  TORCH_CHECK(isVmapFallbackEnabled(),
      schema.operator_name(), " hit the vmap fallback which is currently disabled");
  if (!isVmapFallbackWarningEnabled()) {
    return;
  }
  TORCH_WARN("There is a performance drop because we have not yet implemented ",
             "the ", (is_nested ? "nested " : "") , "batching rule for ",
             schema.operator_name(), ". Please file us an issue on GitHub so that ",
             "we can prioritize its implementation.");
}

// The general flow of the algorithm is as follows.
// - First, we figure out which arguments are BatchedTensors and save them
//   to a vector. We also store a vector of which index of the arguments list
//   each BatchedTensor appears in. This will be useful for bookkeeping later.
// - Next, we apply the MultiBatchVmapTransform to all of the BatchedTensors.
//   This returns a vector of VmapPhysicalViews whose underlying tensors have
//   all of the collective batch dimensions at the front.
// - Then, we attempt to call `op` once per slice of the inputs. To do this,
//   we repeatedly slice the input arguments (if they are BatchedTensors),
//   put the sliced (or unsliced) version of each input onto the stack, invoke
//   the operator, and then pop the results off the stack.
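//
// Illustrative sketch (not part of the implementation; names below are
// hypothetical): conceptually, for vmap over an in-place op such as
// Tensor.add_ with a single batch dim of size B, the loop in this fallback
// behaves like
//
//   for (int64_t b = 0; b < B; ++b) {
//     self_physical.select(0, b).add_(other_physical.select(0, b));
//   }
//
// where self_physical / other_physical are the physical
// (batch-dims-at-front) tensors produced by MultiBatchVmapTransform, so the
// per-slice writes land in the tensor that backs `self`.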
static void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  warnFallback(schema, /*in_place*/true);

  const auto num_arguments = schema.arguments().size();
  const auto arguments = torch::jit::last(stack, num_arguments);
  const auto arguments_begin = stack->size() - num_arguments;

  // `self` is the Tensor being modified in-place
  Tensor self = arguments[0].toTensor();
  const auto* self_impl = maybeGetBatchedImpl(self);
  std::bitset<kVmapMaxTensorDims> self_vmap_levels;
  if (self_impl) {
    self_vmap_levels = createVmapLevelsBitset(self_impl->level());
  }

  // Figure out which arguments are BatchedTensor. Save them to a vector.
  // For each BatchedTensor, also record what position of `arguments` they came from.
  at::SmallVector<Tensor,kVmapTransformStaticInputSize> batched_tensor_inputs;
  VmapDimVector batched_tensor_inputs_position;
  for (const auto idx : c10::irange(0, arguments.size())) {
    const auto& ivalue = arguments[idx];
    if (!ivalue.isTensor()) {
      continue;
    }
    const auto& tensor = ivalue.toTensor();
    if (!tensor.defined()) {
      continue;
    }
    const auto* batched = maybeGetBatchedImpl(tensor);
    if (!batched) {
      continue;
    }

    // NOTE: [vmap-incompatible in-place operations]
    // In-place operations on `self` are not possible if there exists some vmap
    // level `l` such that `self` is not being vmapped on that level but another
    // argument is. For example, let B0 be a batch dim inside vmap and consider
    // vmap(Tensor.add_, in_dims=(None, 0))(torch.ones(3), torch.ones(B0, 3))
    // - self is torch.ones(3) and does not participate in this vmap
    // - other is BatchedTensor(torch.ones(B0, 3))
    // There's no way to do self.add_(other) because `other` has more elements
    // than `self` due to being vmapped over.
    //
    // In the vmap fallback, we should error out when we detect this.
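    //
    // Illustrative example (hypothetical level numbers, assuming
    // createVmapLevelsBitset(level) sets the bit for that level): in the
    // scenario above, `self` is not batched, so self_vmap_levels stays 0b000;
    // `other` is batched at level 2, so other_vmap_levels == 0b100. Then
    //   self_vmap_levels | other_vmap_levels == 0b100 != self_vmap_levels,
    // additional_bdims == 0b100, findLastSet reports level 2, and the check
    // below raises an error.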
    auto other_vmap_levels = createVmapLevelsBitset(batched->level());
    if (self_vmap_levels != (self_vmap_levels | other_vmap_levels)) {
      // Find one vmap level to complain about
      auto additional_bdims = (self_vmap_levels | other_vmap_levels) ^ self_vmap_levels;
      [[maybe_unused]] auto offending_level = llvm::findLastSet(additional_bdims.to_ulong());
      // The following prints out "vmap: aten::add_(tensor, ...) is not possible",
      // but it would be better to print out "tensor.add_(...) is not possible".
      // As far as I can tell, there is no official way to recover the method
      // name (e.g. `add_`), and no way to tell whether an operator has method
      // or function variants.
      TORCH_CHECK(false,
        "vmap: ", schema.name(), "(self, *extra_args) is not possible because ",
        "there exists a Tensor `other` in extra_args that has more elements ",
        "than `self`. This happened due to `other` being vmapped over but ",
        "`self` not being vmapped over at level ", offending_level, ". ",
        "Please try to use out-of-place operators instead of ", schema.name(), ". ",
        "If said operator is being called inside the PyTorch framework, ",
        "please file a bug report instead.");
    }
    batched_tensor_inputs.push_back(tensor);
    batched_tensor_inputs_position.push_back(static_cast<int64_t>(idx));
  }
  TORCH_INTERNAL_ASSERT(!batched_tensor_inputs.empty());

  // MultiBatchVmapTransform the BatchedTensor arguments. This returns
  // VmapPhysicalViews that contain all of the batch dimensions.
  const auto input_physical_views = MultiBatchVmapTransform::logicalToPhysical(
      batched_tensor_inputs);

  // Compute the total number of batches
  auto num_batch_dims = input_physical_views.front().numBatchDims();
  auto first_physical_view_sizes = input_physical_views.front().tensor().sizes();
  auto batch_sizes = ArrayRef<int64_t>(
      first_physical_view_sizes.begin(), first_physical_view_sizes.begin() + num_batch_dims);
  const auto num_batches = c10::multiply_integers(batch_sizes);
  // Without a shape-checking API, we're unable to compute the correct shape of
  // the output so we just error out.
  TORCH_CHECK(num_batches > 0,
      "Batching rule not implemented for ", schema.operator_name(), ". ",
      "The fallback path does not support vmap over dims of size 0.");

  // Strategy: For each batch, we are going to push slices (where applicable)
  // of the arguments onto `stack`, and call `op`.
  for (int64_t linear_idx = 0; linear_idx < num_batches; ++linear_idx) {
    auto index = computeIndex(linear_idx, batch_sizes);
    auto batched_tensor_inputs_pos_iter = batched_tensor_inputs_position.begin();
    auto input_physical_views_iter = input_physical_views.begin();
    for (const auto arg_idx : c10::irange(0, num_arguments)) {
      // We assume that torch::jit::Stack is backed by vector<IValue> for
      // simplicity. When that is not the case, this code should be updated.
      const auto& argument = (*stack)[arguments_begin + arg_idx];
      if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
          || (int64_t)arg_idx != *batched_tensor_inputs_pos_iter) {
        // argument isn't a BatchedTensor
        torch::jit::push(stack, argument);
        continue;
      }
      // argument is a BatchedTensor
      TORCH_INTERNAL_ASSERT(input_physical_views_iter != input_physical_views.end());
      const auto& physical_view_for_argument = *input_physical_views_iter;
      auto thing = physical_view_for_argument.tensor().index(index);
      torch::jit::push(stack, thing);
      batched_tensor_inputs_pos_iter++;
      input_physical_views_iter++;
    }

    op.callBoxed(stack);
    torch::jit::drop(stack, 1);
  }

  // Return the tensor that was written to in-place
  torch::jit::drop(stack, num_arguments);
  torch::jit::push(stack, self);
}

static Tensor safeStack(TensorList tensors) {
  auto is_defined = [](const Tensor& t) { return t.defined(); };
  if (std::all_of(tensors.begin(), tensors.end(), is_defined)) {
    return at::stack(tensors);
  }
  // NOTE [vmap through backward and undefined grad]
  // While vmapping through backward functions (to compute batched grad), it
  // is possible for the backward function to return an undefined grad for some
  // grad_input for each example. In that case, we return an undefined grad.
  //
  // It is theoretically possible for *some* of the examples to produce an
  // undefined grad (a kernel could peek at the gradient values and return an
  // undefined tensor if it determines the gradient is full of zeros). We
  // could handle this by treating the undefined grad as a zero-filled tensor
  // of the correct shape while stacking the tensors together. However, I expect
  // this to happen very rarely (I have not been able to find an example in our
  // codebase), so we just error out in this case.
  if (std::none_of(tensors.begin(), tensors.end(), is_defined)) {
    return Tensor();
  }
  TORCH_CHECK(false,
      "vmap: slow fallback received a mix of undefined and defined tensors ",
      "as the result of an operation. This is not supported, please file us ",
      "an issue on github.");
}
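//
// Illustrative behavior (hypothetical tensors t0, t1 of the same shape):
//   safeStack({t0, t1})              // returns at::stack({t0, t1})
//   safeStack({Tensor(), Tensor()})  // returns an undefined Tensor()
//   safeStack({t0, Tensor()})        // errors out (mixed defined/undefined)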

// TODO: Consider rewriting the following to look like:
// https://gist.github.com/zou3519/7b7c6a4a258d580f62d1d969851be6b1

// The general flow of the algorithm is as follows.
// - First, we figure out which arguments are BatchedTensors and save them
//   to a vector. We also store a vector of which index of the arguments list
//   each BatchedTensor appears in. This will be useful for bookkeeping later.
// - Next, we apply the MultiBatchVmapTransform to all of the BatchedTensors.
//   This returns a vector of VmapPhysicalViews whose underlying tensors have
//   all of the collective batch dimensions at the front.
// - Then, we attempt to call `op` once per slice of the inputs. To do this,
//   we repeatedly slice the input arguments (if they are BatchedTensors),
//   put the sliced (or unsliced) version of each input onto the stack, invoke
//   the operator, and then pop the results off the stack.
// - Each result obtained from the previous step is a slice of the total result,
//   so we stack those tensors together to form the final result.
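//
// Illustrative sketch (not part of the implementation; names are hypothetical):
// for vmap over a unary op such as at::sin with a single batch dim of size B,
// the fallback conceptually reduces to
//
//   std::vector<Tensor> shards;
//   for (int64_t b = 0; b < B; ++b) {
//     shards.push_back(at::sin(x_physical.select(0, b)));
//   }
//   Tensor result = at::stack(shards);
//
// where x_physical is the physical (batch-dims-at-front) tensor produced by
// MultiBatchVmapTransform, and `result` is then mapped back to a logical
// BatchedTensor via the physical-to-logical map.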
void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = schema.arguments().size();
  const auto arguments = torch::jit::last(stack, num_arguments);

  TORCH_CHECK(areAllReturnsTensors(schema) && !areAnyArgumentsTensorList(schema),
              "Batching rule not implemented for ", schema.operator_name(), ". ",
              "We could not generate a fallback.");

  if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    op.callBoxed(stack);
    return;
  }

  if (isInplaceOp(schema)) {
    batchedTensorInplaceForLoopFallback(op, stack);
    return;
  }
  TORCH_CHECK(!schema.is_mutable() && !schema.hasAnyAliasInfo(),
              "Batching rule not implemented for ", schema.operator_name(), "; ",
              "the fallback path doesn't work on out= or view ops.");
  TORCH_CHECK(num_returns >= 1,
              "Batching rule not implemented for ", schema.operator_name(), ". ",
              "The fallback path does not support operations with no returns.");
  warnFallback(schema, /*in_place*/false);

  const auto arguments_begin = stack->size() - num_arguments;

  // Figure out which arguments are BatchedTensor. Save them to a vector.
  // For each BatchedTensor, also record what position of `arguments` they came from.
  at::SmallVector<Tensor,kVmapTransformStaticInputSize> batched_tensor_inputs;
  VmapDimVector batched_tensor_inputs_position;
  for (const auto idx : c10::irange(0, arguments.size())) {
    const auto& ivalue = arguments[idx];
    if (!ivalue.isTensor()) {
      continue;
    }
    const auto& tensor = ivalue.toTensor();
    if (!tensor.defined()) {
      continue;
    }
    const auto* batched = maybeGetBatchedImpl(tensor);
    if (!batched) {
      continue;
    }
    batched_tensor_inputs.push_back(tensor);
    batched_tensor_inputs_position.push_back(static_cast<int64_t>(idx));
  }
  TORCH_INTERNAL_ASSERT(!batched_tensor_inputs.empty());

  // MultiBatchVmapTransform the BatchedTensor arguments. This returns
  // VmapPhysicalViews that contain all of the batch dimensions.
  const auto input_physical_views = MultiBatchVmapTransform::logicalToPhysical(
      batched_tensor_inputs);

  // Compute the total number of batches
  auto num_batch_dims = input_physical_views.front().numBatchDims();
  auto some_sizes = input_physical_views.front().tensor().sizes();
  auto batch_sizes = ArrayRef<int64_t>(some_sizes.begin(), some_sizes.begin() + num_batch_dims);
  const auto num_batches = c10::multiply_integers(batch_sizes);
  // Without a shape-checking API, we're unable to compute the correct shape of
  // the output so we just error out.
  TORCH_CHECK(num_batches > 0,
      "Batching rule not implemented for ", schema.operator_name(), ". ",
      "The fallback path does not support vmap over dims of size 0.");

  // Strategy: For each batch, we are going to push slices (where applicable)
  // of the arguments onto `stack`, call `op`, and store the result in
  // `output_shards`.
  //
  // NOTE: [Output shards layout]
  // Assume that the operator has three outputs: a, b, c.
  // The layout of output_shards is as follows:
  //   [ a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3 ]
  // This is so that we can call at::stack([a0...a3]), at::stack([b0...b3])
  // more easily in the next step.
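  //
  // Illustrative indexing (using the example above, with num_batches = 4 and
  // num_returns = 3): shard `linear_idx` of return `return_idx` is stored at
  //   output_shards[num_batches * return_idx + linear_idx]
  // e.g. b2 lives at output_shards[4 * 1 + 2] == output_shards[6], which is
  // exactly how the loop below writes the results.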
  std::vector<Tensor> output_shards(num_batches * num_returns);

  for (int64_t linear_idx = 0; linear_idx < num_batches; ++linear_idx) {
    auto index = computeIndex(linear_idx, batch_sizes);
    auto batched_tensor_inputs_pos_iter = batched_tensor_inputs_position.begin();
    auto input_physical_views_iter = input_physical_views.begin();
    for (const auto arg_idx : c10::irange(0, num_arguments)) {
      // We assume that torch::jit::Stack is backed by vector<IValue> for
      // simplicity. When that is not the case, this code should be updated.
      const auto& argument = (*stack)[arguments_begin + arg_idx];
      if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
          || (int64_t)arg_idx != *batched_tensor_inputs_pos_iter) {
        // argument isn't a BatchedTensor
        torch::jit::push(stack, argument);
        continue;
      }
      // argument is a BatchedTensor
      TORCH_INTERNAL_ASSERT(input_physical_views_iter != input_physical_views.end());
      const auto& physical_view_for_argument = *input_physical_views_iter;
      c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
      torch::jit::push(stack, physical_view_for_argument.tensor().index(index));
      batched_tensor_inputs_pos_iter++;
      input_physical_views_iter++;
    }

    // std::cout << "[Fallback]: ";
    // at::dump_tensor((*stack)[stack->size() - 1].toTensor());
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    op.callBoxed(stack);

    // Store the result into `output_shards`. See NOTE: [Output shards layout]
    // to learn about the details of how we store the shards.
    const auto returns = torch::jit::last(stack, num_returns);
    for (const auto return_idx : c10::irange(0, returns.size())) {
      output_shards[num_batches * return_idx + linear_idx] = returns[return_idx].toTensor();
    }
    torch::jit::drop(stack, num_returns);
  }

  // For each output Tensor, stack the shards of the tensor together to form a return
  torch::jit::drop(stack, num_arguments);
  auto output_shards_chunks = MatrixRef<Tensor>(output_shards, num_batches);
  for (const auto return_idx : c10::irange(0, num_returns)) {
    auto shards = output_shards_chunks[return_idx];
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    auto flat_output = safeStack(shards);
    // See NOTE [vmap through backward and undefined grad]
    if (!flat_output.defined()) {
      torch::jit::push(stack, flat_output);
      continue;
    }
    VmapDimVector output_sizes(batch_sizes);
    output_sizes.insert(
        output_sizes.end(),
        flat_output.sizes().begin() + 1,
        flat_output.sizes().end());
    torch::jit::push(
        stack,
        input_physical_views.front().getPhysicalToLogicalMap().apply(flat_output.view(output_sizes)));
  }
}

void batchedNestedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = schema.arguments().size();
  const auto arguments = torch::jit::last(stack, num_arguments);

  TORCH_CHECK(areAllReturnsTensors(schema) && !areAnyArgumentsTensorList(schema),
              "Nested batching rule not implemented for ", schema.operator_name(), ". ",
              "We could not generate a fallback.");

  if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    c10::impl::ExcludeDispatchKeyGuard nt_guard(DispatchKey::BatchedNestedTensor);
    op.callBoxed(stack);
    return;
  }

  if (isInplaceOp(schema)) {
    TORCH_INTERNAL_ASSERT(false, "vmap fallback not supported for in-place ops on nested tensors");
    return;
  }
  TORCH_CHECK(!schema.is_mutable() && !schema.hasAnyAliasInfo(),
              "Nested batching rule not implemented for ", schema.operator_name(), "; ",
              "the fallback path doesn't work on out= or view ops.");
  TORCH_CHECK(num_returns >= 1,
              "Nested batching rule not implemented for ", schema.operator_name(), ". ",
              "The fallback path does not support operations with no returns.");
  warnFallback(schema, /*in_place*/false, /*is_nested*/true);

  const auto arguments_begin = stack->size() - num_arguments;

  // Figure out which arguments are BatchedTensor. Save them to a vector.
  // For each BatchedTensor, also record what position of `arguments` they came from.
  at::SmallVector<Tensor,kVmapTransformStaticInputSize> batched_tensor_inputs;
  VmapDimVector batched_tensor_inputs_position;
  for (const auto idx : c10::irange(0, arguments.size())) {
    const auto& ivalue = arguments[idx];
    if (!ivalue.isTensor()) {
      continue;
    }
    const auto& tensor = ivalue.toTensor();
    if (!tensor.defined()) {
      continue;
    }
    const auto* batched = maybeGetBatchedImpl(tensor);
    if (!batched) {
      continue;
    }
    batched_tensor_inputs.push_back(tensor);
    batched_tensor_inputs_position.push_back(static_cast<int64_t>(idx));
  }
  TORCH_INTERNAL_ASSERT(!batched_tensor_inputs.empty());

  std::vector<std::vector<Tensor>> unbound;
  for (auto const &batched_tensor_input: batched_tensor_inputs) {
    auto *batched_impl = maybeGetBatchedImpl(batched_tensor_input);
    TORCH_INTERNAL_ASSERT(batched_impl->value().is_nested() || batched_impl->bdim() == 0,
        "Fallback not supported for mixed nested / non-nested arguments without bdim=0");
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::BatchedNestedTensor);
    auto this_unbound = batched_impl->value().unbind();
    if (!unbound.empty()) {
      TORCH_INTERNAL_ASSERT(unbound.front().size() == this_unbound.size(),
          "Fallback not supported for differently-sized nested arguments");
    }
    unbound.push_back(this_unbound);
  }
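  //
  // Illustrative example (hypothetical inputs): if the first batched argument
  // is a nested tensor with components [t0, t1, t2] and the second is a dense
  // tensor x batched at bdim=0 with x.size(0) == 3, then after the loop above
  //   unbound == { {t0, t1, t2}, {x[0], x[1], x[2]} }
  // and the loop below invokes `op` once per component index (3 times here).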

  const auto num_components = unbound.front().size();
  std::vector<Tensor> output_shards(num_components * num_returns);
  for (const auto component_idx : c10::irange(0, num_components)) {
    auto batched_idx = 0;
    auto batched_tensor_inputs_pos_iter = batched_tensor_inputs_position.begin();
    for (const auto arg_idx : c10::irange(0, num_arguments)) {
      // We assume that torch::jit::Stack is backed by vector<IValue> for
      // simplicity. When that is not the case, this code should be updated.
      const auto& argument = (*stack)[arguments_begin + arg_idx];
      if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
          || (int64_t)arg_idx != *batched_tensor_inputs_pos_iter) {
        // argument isn't a BatchedTensor
        torch::jit::push(stack, argument);
        continue;
      }
      // argument is a BatchedTensor
      torch::jit::push(stack, unbound[batched_idx][component_idx]);
      ++batched_idx;
      ++batched_tensor_inputs_pos_iter;
    }

    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::BatchedNestedTensor);
    op.callBoxed(stack);

    // Store the result into `output_shards`. See NOTE: [Output shards layout]
    // to learn about the details of how we store the shards.
    const auto returns = torch::jit::last(stack, num_returns);
    for (const auto return_idx : c10::irange(0, returns.size())) {
      output_shards[num_components * return_idx + component_idx] = returns[return_idx].toTensor();
    }
    torch::jit::drop(stack, num_returns);
  }

  // For each output Tensor, stack the shards of the tensor together to form a nested return
  // TODO: Determine when the output needs to be nested and when it can be non-nested?
  torch::jit::drop(stack, num_arguments);
  auto output_shards_chunks = MatrixRef<Tensor>(output_shards, num_components);
  for (const auto return_idx : c10::irange(0, num_returns)) {
    auto shards = output_shards_chunks[return_idx];
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::BatchedNestedTensor);
    auto out_nt = at::_nested_tensor_from_tensor_list(shards);
    // NB: NTs only support batching over dim 0
    torch::jit::push(stack, makeBatched(out_nt, 0, maybeCurrentDynamicLayer()->layerId()));
  }
}

void vmapErrorFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  TORCH_CHECK(false, "Error: ", op.operator_name(), " requires special handling, and does not yet have a batching rule. Feel free to file a github issue!");
}

} // namespace at::functorch