// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchedFallback.h>
#include <ATen/functorch/LegacyVmapTransforms.h>
#include <ATen/functorch/TensorWrapper.h>
#include <ATen/functorch/DynamicLayer.h>
#include <ATen/functorch/PlumbingHelper.h>

#include <ATen/Context.h>
#include <ATen/MatrixRef.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/util/accumulate.h>
#include <c10/util/llvmMathExtras.h>
#include <c10/util/irange.h>

namespace at::functorch {

bool kVmapFallbackWarningEnabled = true;

bool isVmapFallbackWarningEnabled() {
  return kVmapFallbackWarningEnabled;
}

void setVmapFallbackWarningEnabled(bool enabled) {
  kVmapFallbackWarningEnabled = enabled;
}

bool kVmapFallbackEnabled = true;

bool isVmapFallbackEnabled() {
  return kVmapFallbackEnabled;
}

void setVmapFallbackEnabled(bool enabled) {
  kVmapFallbackEnabled = enabled;
}
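// Usage sketch (illustrative only, not part of this file's logic): callers can
// tweak the fallback behavior via the setters above, e.g.
//   at::functorch::setVmapFallbackWarningEnabled(false);  // silence the perf warning
//   at::functorch::setVmapFallbackEnabled(false);         // error out instead of falling back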

// Given a linear index, return the actual index.
// Example: Given linear_idx = 3, sizes = [5, 2], we would return [1, 1]
static at::SmallVector<indexing::TensorIndex,kVmapStaticDimVecSize>
computeIndex(int64_t linear_idx, IntArrayRef sizes) {
  at::SmallVector<indexing::TensorIndex,kVmapStaticDimVecSize> result;
  result.reserve(sizes.size());
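  // Peel off one coordinate per dimension, starting from the innermost (last)
  // dimension: each step is a divmod by that dimension's size. The coordinates
  // come out innermost-first, so we reverse at the end to get the usual
  // outermost-first (row-major) index.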
  for (auto it = sizes.rbegin(); it != sizes.rend(); it++) {
    auto remainder = linear_idx % *it;
    result.push_back(remainder);
    linear_idx -= remainder;
    linear_idx /= *it;
  }
  std::reverse(std::begin(result), std::end(result));
  return result;
}

static bool areAllReturnsTensors(const at::FunctionSchema& schema) {
  return std::all_of(
      schema.returns().begin(),
      schema.returns().end(),
      [] (const Argument& arg) { return arg.type() == TensorType::get(); });
}

static bool areAnyArgumentsTensorList(const at::FunctionSchema& schema) {
  return std::any_of(
      schema.arguments().begin(),
      schema.arguments().end(),
      [] (const Argument& arg) {
        return arg.type()->isSubtypeOf(ListType::ofTensors()) ||
          arg.type()->isSubtypeOf(ListType::ofOptionalTensors());
      });
}

static void warnFallback(const c10::FunctionSchema& schema, bool is_inplace, bool is_nested=false) {
  TORCH_CHECK(isVmapFallbackEnabled(),
      schema.operator_name(), " hit the vmap fallback which is currently disabled");
  if (!isVmapFallbackWarningEnabled()) {
    return;
  }
  TORCH_WARN("There is a performance drop because we have not yet implemented ",
             "the ", (is_nested ? "nested " : "") , "batching rule for ",
             schema.operator_name(), ". Please file us an issue on GitHub so that ",
             "we can prioritize its implementation.");
}

// The general flow of the algorithm is as follows.
// - First, we figure out which arguments are BatchedTensors and save them
//   to a vector. We also store a vector of which index of the arguments list
//   each BatchedTensor appears in. This will be useful for bookkeeping later.
// - Next, we apply the MultiBatchVmapTransform to all of the BatchedTensors.
//   This returns a vector of VmapPhysicalView that hold tensors that contain
//   all of the collective batch dimensions at the front of the tensors.
// - Then, we attempt to call `op` once per slice of the inputs. To do this,
//   we repeatedly slice the input arguments (if they are BatchedTensors),
//   put the sliced (or unsliced) version of the input onto the stack, invoke
//   the operator, and then pop the results off the stack (see the conceptual
//   sketch below).
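// Conceptually (a sketch, not literal code), for an in-place op f_ the
// fallback behaves like:
//   for (int64_t i = 0; i < num_batches; i++) {
//     f_(self_physical[i], other_physical[i], ...);
//   }
// Because indexing a physical view with integer indices produces a view into
// the original storage, running the op in place on each slice updates `self`
// itself; at the end we simply push `self` back onto the stack.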
static void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  warnFallback(schema, /*in_place*/true);

  const auto num_arguments = schema.arguments().size();
  const auto arguments = torch::jit::last(stack, num_arguments);
  const auto arguments_begin = stack->size() - num_arguments;

  // `self` is the Tensor being modified in-place
  Tensor self = arguments[0].toTensor();
  const auto* self_impl = maybeGetBatchedImpl(self);
  std::bitset<kVmapMaxTensorDims> self_vmap_levels;
  if (self_impl) {
    self_vmap_levels = createVmapLevelsBitset(self_impl->level());
  }

  // Figure out which arguments are BatchedTensor. Save them to a vector.
  // For each BatchedTensor, also record what position of `arguments` they came from.
  at::SmallVector<Tensor,kVmapTransformStaticInputSize> batched_tensor_inputs;
  VmapDimVector batched_tensor_inputs_position;
  for (const auto idx : c10::irange(0, arguments.size())) {
    const auto& ivalue = arguments[idx];
    if (!ivalue.isTensor()) {
      continue;
    }
    const auto& tensor = ivalue.toTensor();
    if (!tensor.defined()) {
      continue;
    }
    const auto* batched = maybeGetBatchedImpl(tensor);
    if (!batched) {
      continue;
    }

    // NOTE: [vmap-incompatible in-place operations]
    // In-place operations on `self` are not possible if there exists some vmap
    // level `l` such that `self` is not being vmapped on that level but another
    // argument is. For example, let B0 be a batch dim inside vmap and consider
    // vmap(Tensor.add_, in_dims=(None, 0))(torch.ones(3), torch.ones(B0, 3))
    // - self is torch.ones(3) and does not participate in this vmap
    // - other is BatchedTensor(torch.ones(B0, 3))
    // There's no way to do self.add_(other) because `other` has more elements
    // than `self` due to being vmapped over.
    //
    // In the vmap fallback, we should error out when we detect this.
    auto other_vmap_levels = createVmapLevelsBitset(batched->level());
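    // For example (illustrative values): if self_vmap_levels = 0b01 (vmapped at
    // level 0 only) and other_vmap_levels = 0b11 (levels 0 and 1), then
    // self_vmap_levels | other_vmap_levels = 0b11 != self_vmap_levels, so
    // `other` carries a vmap level that `self` does not, and we error below.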
    if (self_vmap_levels != (self_vmap_levels | other_vmap_levels)) {
      // Find one vmap level to complain about
      auto additional_bdims = (self_vmap_levels | other_vmap_levels) ^ self_vmap_levels;
      [[maybe_unused]] auto offending_level = llvm::findLastSet(additional_bdims.to_ulong());
      // The following prints out "vmap: aten::add_(tensor, ...) is not possible",
      // but it would be better to print out "tensor.add_(...) is not possible".
      // As far as I can tell, there's no official way to recover the method name
      // (add_), and there is no way to tell if an operator has method or function
      // variants.
      TORCH_CHECK(false,
        "vmap: ", schema.name(), "(self, *extra_args) is not possible because ",
        "there exists a Tensor `other` in extra_args that has more elements ",
        "than `self`. This happened due to `other` being vmapped over but ",
        "`self` not being vmapped over at level ", offending_level, ". ",
        "Please try to use out-of-place operators instead of ", schema.name(), ". ",
        "If said operator is being called inside the PyTorch framework, ",
        "please file a bug report instead.");
    }
    batched_tensor_inputs.push_back(tensor);
    batched_tensor_inputs_position.push_back(static_cast<int64_t>(idx));
  }
  TORCH_INTERNAL_ASSERT(!batched_tensor_inputs.empty());

  // MultiBatchVmapTransform the BatchedTensor arguments. This returns
  // VmapPhysicalViews that contain all of the batch dimensions.
  const auto input_physical_views = MultiBatchVmapTransform::logicalToPhysical(
      batched_tensor_inputs);

  // Compute the total number of batches
  auto num_batch_dims = input_physical_views.front().numBatchDims();
  auto first_physical_view_sizes = input_physical_views.front().tensor().sizes();
  auto batch_sizes = ArrayRef<int64_t>(
      first_physical_view_sizes.begin(), first_physical_view_sizes.begin() + num_batch_dims);
  const auto num_batches = c10::multiply_integers(batch_sizes);
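  // For example (illustrative shapes): if the physical view has sizes
  // [B0, B1, 3, 5] and num_batch_dims == 2, then batch_sizes == [B0, B1] and
  // num_batches == B0 * B1, i.e. one call to `op` per (b0, b1) pair.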
  // Without a shape-checking API, we're unable to compute the correct shape of
  // the output so we just error out.
  TORCH_CHECK(num_batches > 0,
      "Batching rule not implemented for ", schema.operator_name(), ". ",
      "The fallback path does not support vmap over dims of size 0.");

  // Strategy: For each batch, we are going to push slices (where applicable)
  // of the arguments onto `stack`, and call `op`.
  for (int64_t linear_idx = 0; linear_idx < num_batches; ++linear_idx) {
    auto index = computeIndex(linear_idx, batch_sizes);
    auto batched_tensor_inputs_pos_iter = batched_tensor_inputs_position.begin();
    auto input_physical_views_iter = input_physical_views.begin();
    for (const auto arg_idx : c10::irange(0, num_arguments)) {
      // We assume that torch::jit::Stack is backed by vector<IValue> for
      // simplicity. When that is not the case, this code should be updated.
      const auto& argument = (*stack)[arguments_begin + arg_idx];
      if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
          || (int64_t)arg_idx != *batched_tensor_inputs_pos_iter) {
        // argument isn't a BatchedTensor
        torch::jit::push(stack, argument);
        continue;
      }
      // argument is a BatchedTensor
      TORCH_INTERNAL_ASSERT(input_physical_views_iter != input_physical_views.end());
      const auto& physical_view_for_argument = *input_physical_views_iter;
      auto thing = physical_view_for_argument.tensor().index(index);
      torch::jit::push(stack, thing);
      batched_tensor_inputs_pos_iter++;
      input_physical_views_iter++;
    }

    op.callBoxed(stack);
    torch::jit::drop(stack, 1);
  }

  // Return the tensor that was written to in-place
  torch::jit::drop(stack, num_arguments);
  torch::jit::push(stack, self);
}

static Tensor safeStack(TensorList tensors) {
  auto is_defined = [](const Tensor& t) { return t.defined(); };
  if (std::all_of(tensors.begin(), tensors.end(), is_defined)) {
    return at::stack(tensors);
  }
  // NOTE [vmap through backward and undefined grad]
  // While vmapping through backward functions (to compute batched grad), it
  // is possible for the backward function to return an undefined grad for some
  // grad_input for each example. In that case, we return an undefined grad.
  //
  // It is theoretically possible for *some* of the examples to produce an
  // undefined grad (a kernel could peek at the gradient values and return an
  // undefined tensor if it determines the gradient is full of zeros). We
  // could handle this by treating the undefined grad as a zero-filled tensor
  // of the correct shape while stacking the tensors together. However, I expect
  // this to happen very rarely (I have not been able to find an example in our
  // codebase) so we just error out in this case.
  if (std::none_of(tensors.begin(), tensors.end(), is_defined)) {
    return Tensor();
  }
  TORCH_CHECK(false,
      "vmap: slow fallback received a mix of undefined and defined tensors ",
      "as the result of an operation. This is not supported, please file us ",
      "an issue on github.");
}

// TODO: Consider rewriting the following to look like:
// https://gist.github.com/zou3519/7b7c6a4a258d580f62d1d969851be6b1

// The general flow of the algorithm is as follows.
// - First, we figure out which arguments are BatchedTensors and save them
//   to a vector. We also store a vector of which index of the arguments list
//   each BatchedTensor appears in. This will be useful for bookkeeping later.
// - Next, we apply the MultiBatchVmapTransform to all of the BatchedTensors.
//   This returns a vector of VmapPhysicalView that hold tensors that contain
//   all of the collective batch dimensions at the front of the tensors.
// - Then, we attempt to call `op` once per slice of the inputs. To do this,
//   we repeatedly slice the input arguments (if they are BatchedTensors),
//   put the sliced (or unsliced) version of the input onto the stack, invoke
//   the operator, and then pop the results off the stack.
// - Each result obtained from the previous step is a slice of the total result,
//   so we stack those tensors together to form the final result (see the
//   sketch below).
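// Conceptually (a sketch, not literal code), for an out-of-place op f the
// fallback behaves like:
//   std::vector<Tensor> shards;
//   for (int64_t i = 0; i < num_batches; i++) {
//     shards.push_back(f(x_physical[i], y_physical[i], ...));
//   }
//   auto result = at::stack(shards).view(/*batch_sizes + per-example shape*/);
// followed by mapping the physical result back to a logical BatchedTensor.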
void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = schema.arguments().size();
  const auto arguments = torch::jit::last(stack, num_arguments);

  TORCH_CHECK(areAllReturnsTensors(schema) && !areAnyArgumentsTensorList(schema),
      "Batching rule not implemented for ", schema.operator_name(), ". ",
      "We could not generate a fallback.");

  if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    op.callBoxed(stack);
    return;
  }

  if (isInplaceOp(schema)) {
    batchedTensorInplaceForLoopFallback(op, stack);
    return;
  }
  TORCH_CHECK(!schema.is_mutable() && !schema.hasAnyAliasInfo(),
      "Batching rule not implemented for ", schema.operator_name(), "; ",
      "the fallback path doesn't work on out= or view ops.");
  TORCH_CHECK(num_returns >= 1,
      "Batching rule not implemented for ", schema.operator_name(), ". ",
      "The fallback path does not support operations with no returns.");
  warnFallback(schema, /*in_place*/false);

  const auto arguments_begin = stack->size() - num_arguments;

  // Figure out which arguments are BatchedTensor. Save them to a vector.
  // For each BatchedTensor, also record what position of `arguments` they came from.
  at::SmallVector<Tensor,kVmapTransformStaticInputSize> batched_tensor_inputs;
  VmapDimVector batched_tensor_inputs_position;
  for (const auto idx : c10::irange(0, arguments.size())) {
    const auto& ivalue = arguments[idx];
    if (!ivalue.isTensor()) {
      continue;
    }
    const auto& tensor = ivalue.toTensor();
    if (!tensor.defined()) {
      continue;
    }
    const auto* batched = maybeGetBatchedImpl(tensor);
    if (!batched) {
      continue;
    }
    batched_tensor_inputs.push_back(tensor);
    batched_tensor_inputs_position.push_back(static_cast<int64_t>(idx));
  }
  TORCH_INTERNAL_ASSERT(!batched_tensor_inputs.empty());

  // MultiBatchVmapTransform the BatchedTensor arguments. This returns
  // VmapPhysicalViews that contain all of the batch dimensions.
  const auto input_physical_views = MultiBatchVmapTransform::logicalToPhysical(
      batched_tensor_inputs);

  // Compute the total number of batches
  auto num_batch_dims = input_physical_views.front().numBatchDims();
  auto some_sizes = input_physical_views.front().tensor().sizes();
  auto batch_sizes = ArrayRef<int64_t>(some_sizes.begin(), some_sizes.begin() + num_batch_dims);
  const auto num_batches = c10::multiply_integers(batch_sizes);
  // Without a shape-checking API, we're unable to compute the correct shape of
  // the output so we just error out.
  TORCH_CHECK(num_batches > 0,
      "Batching rule not implemented for ", schema.operator_name(), ". ",
      "The fallback path does not support vmap over dims of size 0.");

  // Strategy: For each batch, we are going to push slices (where applicable)
  // of the arguments onto `stack`, call `op`, and store the result in
  // `output_shards`.
  //
  // NOTE: [Output shards layout]
  // Assume that the operator has three outputs: a, b, c.
  // The layout of output_shards is as follows:
  // [ a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3]
  // This is so that we can call at::stack([a0...a3]), at::stack([b0...b3])
  // more easily in the next step.
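  // In other words, the shard for return `r` of batch `i` lives at
  // output_shards[num_batches * r + i]; with num_batches == 4, shard b2 above
  // sits at index 4 * 1 + 2 == 6.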
  std::vector<Tensor> output_shards(num_batches * num_returns);

  for (int64_t linear_idx = 0; linear_idx < num_batches; ++linear_idx) {
    auto index = computeIndex(linear_idx, batch_sizes);
    auto batched_tensor_inputs_pos_iter = batched_tensor_inputs_position.begin();
    auto input_physical_views_iter = input_physical_views.begin();
    for (const auto arg_idx : c10::irange(0, num_arguments)) {
      // We assume that torch::jit::Stack is backed by vector<IValue> for
      // simplicity. When that is not the case, this code should be updated.
      const auto& argument = (*stack)[arguments_begin + arg_idx];
      if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
          || (int64_t)arg_idx != *batched_tensor_inputs_pos_iter) {
        // argument isn't a BatchedTensor
        torch::jit::push(stack, argument);
        continue;
      }
      // argument is a BatchedTensor
      TORCH_INTERNAL_ASSERT(input_physical_views_iter != input_physical_views.end());
      const auto& physical_view_for_argument = *input_physical_views_iter;
      c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
      torch::jit::push(stack, physical_view_for_argument.tensor().index(index));
      batched_tensor_inputs_pos_iter++;
      input_physical_views_iter++;
    }

    // std::cout << "[Fallback]: ";
    // at::dump_tensor((*stack)[stack->size() - 1].toTensor());
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    op.callBoxed(stack);

    // Store the result into `output_shards`. See NOTE: [Output shards layout]
    // to learn about the details of how we store the shards.
    const auto returns = torch::jit::last(stack, num_returns);
    for (const auto return_idx : c10::irange(0, returns.size())) {
      output_shards[num_batches * return_idx + linear_idx] = returns[return_idx].toTensor();
    }
    torch::jit::drop(stack, num_returns);
  }

  // For each output Tensor, stack the shards of the tensor together to form a return
  torch::jit::drop(stack, num_arguments);
  auto output_shards_chunks = MatrixRef<Tensor>(output_shards, num_batches);
  for (const auto return_idx : c10::irange(0, num_returns)) {
    auto shards = output_shards_chunks[return_idx];
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    auto flat_output = safeStack(shards);
    // See NOTE [vmap through backward and undefined grad]
    if (!flat_output.defined()) {
      torch::jit::push(stack, flat_output);
      continue;
    }
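    // Unflatten the stacked batch dimension back into the original batch
    // shape. For example (illustrative shapes): with batch_sizes == [B0, B1]
    // and flat_output of shape [B0 * B1, 3, 5], output_sizes becomes
    // [B0, B1, 3, 5] before the physical-to-logical mapping below.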
    VmapDimVector output_sizes(batch_sizes);
    output_sizes.insert(
        output_sizes.end(),
        flat_output.sizes().begin() + 1,
        flat_output.sizes().end());
    torch::jit::push(
        stack,
        input_physical_views.front().getPhysicalToLogicalMap().apply(flat_output.view(output_sizes)));
  }
}

void batchedNestedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = schema.arguments().size();
  const auto arguments = torch::jit::last(stack, num_arguments);

  TORCH_CHECK(areAllReturnsTensors(schema) && !areAnyArgumentsTensorList(schema),
      "Nested batching rule not implemented for ", schema.operator_name(), ". ",
      "We could not generate a fallback.");

  if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
    c10::impl::ExcludeDispatchKeyGuard nt_guard(DispatchKey::BatchedNestedTensor);
    op.callBoxed(stack);
    return;
  }

  if (isInplaceOp(schema)) {
    TORCH_INTERNAL_ASSERT(false, "vmap fallback not supported for in-place ops on nested tensors");
    return;
  }
  TORCH_CHECK(!schema.is_mutable() && !schema.hasAnyAliasInfo(),
      "Nested batching rule not implemented for ", schema.operator_name(), "; ",
      "the fallback path doesn't work on out= or view ops.");
  TORCH_CHECK(num_returns >= 1,
      "Nested batching rule not implemented for ", schema.operator_name(), ". ",
      "The fallback path does not support operations with no returns.");
  warnFallback(schema, /*in_place*/false, /*is_nested*/true);

  const auto arguments_begin = stack->size() - num_arguments;

  // Figure out which arguments are BatchedTensor. Save them to a vector.
  // For each BatchedTensor, also record what position of `arguments` they came from.
  at::SmallVector<Tensor,kVmapTransformStaticInputSize> batched_tensor_inputs;
  VmapDimVector batched_tensor_inputs_position;
  for (const auto idx : c10::irange(0, arguments.size())) {
    const auto& ivalue = arguments[idx];
    if (!ivalue.isTensor()) {
      continue;
    }
    const auto& tensor = ivalue.toTensor();
    if (!tensor.defined()) {
      continue;
    }
    const auto* batched = maybeGetBatchedImpl(tensor);
    if (!batched) {
      continue;
    }
    batched_tensor_inputs.push_back(tensor);
    batched_tensor_inputs_position.push_back(static_cast<int64_t>(idx));
  }
  TORCH_INTERNAL_ASSERT(!batched_tensor_inputs.empty());

  std::vector<std::vector<Tensor>> unbound;
  for (auto const &batched_tensor_input: batched_tensor_inputs) {
    auto *batched_impl = maybeGetBatchedImpl(batched_tensor_input);
    TORCH_INTERNAL_ASSERT(batched_impl->value().is_nested() || batched_impl->bdim() == 0,
        "Fallback not supported for mixed nested / non-nested arguments without bdim=0");
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::BatchedNestedTensor);
    auto this_unbound = batched_impl->value().unbind();
    if (!unbound.empty()) {
      TORCH_INTERNAL_ASSERT(unbound.front().size() == this_unbound.size(),
          "Fallback not supported for differently-sized nested arguments");
    }
    unbound.push_back(this_unbound);
  }
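  // At this point, unbound[i][j] is the j-th component of the i-th batched
  // argument. Below we run `op` once per component index j, substituting each
  // batched argument with its j-th component.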

  const auto num_components = unbound.front().size();
  std::vector<Tensor> output_shards(num_components * num_returns);
  for (const auto component_idx : c10::irange(0, num_components)) {
    auto batched_idx = 0;
    auto batched_tensor_inputs_pos_iter = batched_tensor_inputs_position.begin();
    for (const auto arg_idx : c10::irange(0, num_arguments)) {
      // We assume that torch::jit::Stack is backed by vector<IValue> for
      // simplicity. When that is not the case, this code should be updated.
      const auto& argument = (*stack)[arguments_begin + arg_idx];
      if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
          || (int64_t)arg_idx != *batched_tensor_inputs_pos_iter) {
        // argument isn't a BatchedTensor
        torch::jit::push(stack, argument);
        continue;
      }
      // argument is a BatchedTensor
      torch::jit::push(stack, unbound[batched_idx][component_idx]);
      ++batched_idx;
      ++batched_tensor_inputs_pos_iter;
    }

    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::BatchedNestedTensor);
    op.callBoxed(stack);

    // Store the result into `output_shards`. See NOTE: [Output shards layout]
    // to learn about the details of how we store the shards.
    const auto returns = torch::jit::last(stack, num_returns);
    for (const auto return_idx : c10::irange(0, returns.size())) {
      output_shards[num_components * return_idx + component_idx] = returns[return_idx].toTensor();
    }
    torch::jit::drop(stack, num_returns);
  }

  // For each output Tensor, stack the shards of the tensor together to form a nested return
  // TODO: Determine when the output needs to be nested and when it can be non-nested?
  torch::jit::drop(stack, num_arguments);
  auto output_shards_chunks = MatrixRef<Tensor>(output_shards, num_components);
  for (const auto return_idx : c10::irange(0, num_returns)) {
    auto shards = output_shards_chunks[return_idx];
    c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::BatchedNestedTensor);
    auto out_nt = at::_nested_tensor_from_tensor_list(shards);
    // NB: NTs only support batching over dim 0
    torch::jit::push(stack, makeBatched(out_nt, 0, maybeCurrentDynamicLayer()->layerId()));
  }
}

void vmapErrorFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  TORCH_CHECK(false, "Error: ", op.operator_name(), " requires special handling, and does not yet have a batching rule. Feel free to file a github issue!");
}

} // namespace at::functorch