#include <ATen/Context.h>
#include <ATen/LegacyBatchedFallback.h>
#include <ATen/MatrixRef.h>
#include <ATen/LegacyVmapTransforms.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/util/accumulate.h>
#include <c10/util/llvmMathExtras.h>
#include <c10/util/irange.h>

namespace at {

// Given a linear index, return the actual index.
// Example: Given linear_idx = 3, sizes = [5, 2], we would return [1, 1]
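// (As a quick sanity check, the full row-major mapping for sizes = [5, 2] is:
//  0 -> [0, 0], 1 -> [0, 1], 2 -> [1, 0], 3 -> [1, 1], ..., 9 -> [4, 1].)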
static SmallVector<indexing::TensorIndex,kVmapStaticDimVecSize>
computeIndex(int64_t linear_idx, IntArrayRef sizes) {
  SmallVector<indexing::TensorIndex,kVmapStaticDimVecSize> result;
  result.reserve(sizes.size());
  for (auto it = sizes.rbegin(); it != sizes.rend(); it++) {
    auto remainder = linear_idx % *it;
    result.push_back(remainder);
    linear_idx -= remainder;
    linear_idx /= *it;
  }
  std::reverse(std::begin(result), std::end(result));
  return result;
}

static bool areAllReturnsTensors(const FunctionSchema& schema) {
  return std::all_of(
      schema.returns().begin(),
      schema.returns().end(),
      [] (const Argument& arg) { return arg.type() == TensorType::get(); });
}

static bool areAnyArgumentsTensorList(const FunctionSchema& schema) {
  return std::any_of(
      schema.arguments().begin(),
      schema.arguments().end(),
      [] (const Argument& arg) { return arg.type()->isSubtypeOf(*ListType::ofTensors()); });
}

// Returns whether an operator is in-place. An operator is in-place if:
// 1. The first argument is a Tensor and it is being written to
// 2. The first argument is being returned
// 3. No other arguments are aliased
// Here is an example of an in-place operator:
//   add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
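// For contrast, an out= overload is not treated as in-place by this check,
// because the mutated tensor is not the first argument. For illustration, a
// schema along the lines of:
//   add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
// fails the check below: `self` carries no alias annotation, so the
// first-argument write test returns false.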
static bool isInplaceOp(const c10::FunctionSchema& schema) {
  if (!schema.is_mutable() || schema.returns().size() != 1) {
    return false;
  }
  // Check that the first argument is being written to
  const AliasInfo* first_arg_alias_info = schema.arguments().begin()->alias_info();
  if (!first_arg_alias_info || !first_arg_alias_info->isWrite()) {
    return false;
  }
  // Check that none of the other args are being aliased
  for (auto it = schema.arguments().begin() + 1; it != schema.arguments().end(); ++it) {
    const AliasInfo* alias_info = it->alias_info();
    if (alias_info) {
      return false;
    }
  }
  // Check that the first tensor is being returned (i.e., output has a (a!))
  const AliasInfo* return_alias_info = schema.returns()[0].alias_info();
  return return_alias_info && return_alias_info->isWrite();
}

static void warnFallback(const c10::FunctionSchema& schema) {
  if (!globalContext().areVmapFallbackWarningsEnabled()) {
    return;
  }
  TORCH_WARN("There is a performance drop because we have not yet implemented ",
             "the batching rule for ", schema.operator_name(), ". ",
             "You are using the legacy vmap prototype (torch._vmap_internals.vmap). ",
             "If you are using torch.autograd.functional.{jacobian, hessian} ",
             "or torch._vmap_internals.vmap: please switch to using ",
             "torch.func.{jacrev, jacfwd, hessian} and/or torch.vmap instead ",
             "for better operator coverage and performance improvements.");
}

// The general flow of the algorithm is as follows.
// - First, we figure out which arguments are BatchedTensors and save them
//   to a vector. We also store a vector of which index of the arguments list
//   each BatchedTensor appears in. This will be useful for bookkeeping later.
// - Next, we apply the MultiBatchVmapTransform to all of the BatchedTensors.
//   This returns a vector of VmapPhysicalView that hold tensors that contain
//   all of the collective batch dimensions at the front of the tensors.
// - Then, we attempt to call `op` once per slice of the inputs. To do this,
//   we repeatedly slice the input arguments (if they are BatchedTensors),
//   put the sliced (or unsliced) version of the input onto the stack, invoke
//   the operator, and then pop the results off the stack.
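//
// A minimal sketch of what this amounts to (names like self_physical and
// other_physical are illustrative, not identifiers used below): if this
// fallback were used for vmap(Tensor.add_)(self, other) with both inputs of
// shape [B0, 3], it would run `self_physical[i].add_(other_physical[i])` for
// each i in [0, B0), where the physical tensors carry the batch dim at the
// front, and finally push `self` back onto the stack as the result.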
static void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  warnFallback(schema);

  const auto num_arguments = static_cast<int64_t>(schema.arguments().size());
  const auto arguments = torch::jit::last(stack, num_arguments);
  const auto arguments_begin = stack->size() - num_arguments;

  // `self` is the Tensor being modified in-place
  Tensor self = arguments[0].toTensor();
  const auto* self_impl = maybeGetBatchedImpl(self);
  std::bitset<kVmapMaxTensorDims> self_vmap_levels;
  if (self_impl) {
    self_vmap_levels = createVmapLevelsBitset(self_impl->bdims());
  }

  // Figure out which arguments are BatchedTensor. Save them to a vector.
  // For each BatchedTensor, also record what position of `arguments` they came from.
  SmallVector<Tensor,kVmapTransformStaticInputSize> batched_tensor_inputs;
  VmapDimVector batched_tensor_inputs_position;
  for (const auto idx : c10::irange(arguments.size())) {
    const auto& ivalue = arguments[idx];
    if (!ivalue.isTensor()) {
      continue;
    }
    const auto& tensor = ivalue.toTensor();
    if (!tensor.defined()) {
      continue;
    }
    const auto* batched = maybeGetBatchedImpl(tensor);
    if (!batched) {
      continue;
    }

    // NOTE: [vmap-incompatible in-place operations]
    // In-place operations on `self` are not possible if there exists some vmap
    // level `l` such that `self` is not being vmapped on that level but another
    // argument is. For example, let B0 be a batch dim inside vmap and consider
    // vmap(Tensor.add_, in_dims=(None, 0))(torch.ones(3), torch.ones(B0, 3))
    // - self is torch.ones(3) and does not participate in this vmap
    // - other is BatchedTensor(torch.ones(B0, 3))
    // There's no way to do self.add_(other) because `other` has more elements
    // than `self` due to being vmapped over.
    //
    // In the vmap fallback, we should error out when we detect this.
    auto other_vmap_levels = createVmapLevelsBitset(batched->bdims());
    if (self_vmap_levels != (self_vmap_levels | other_vmap_levels)) {
      // Find one vmap level to complain about
      auto additional_bdims = (self_vmap_levels | other_vmap_levels) ^ self_vmap_levels;
      [[maybe_unused]] auto offending_level = llvm::findLastSet(additional_bdims.to_ulong());
      // The following prints out "vmap: aten::add_(tensor, ...) is not possible",
      // but it would be better to print out "tensor.add_(...) is not possible".
      // As far as I can tell, there's no official way to get the method name (`add_`)
      // and no way to tell whether an operator has method or function variants.
      TORCH_CHECK(false,
          "vmap: ", schema.name(), "(self, *extra_args) is not possible because ",
          "there exists a Tensor `other` in extra_args that has more elements ",
          "than `self`. This happened due to `other` being vmapped over but ",
          "`self` not being vmapped over at level ", offending_level, ". ",
          "Please try to use out-of-place operators instead of ", schema.name(), ". ",
          "If said operator is being called inside the PyTorch framework, ",
          "please file a bug report instead.");
    }
    batched_tensor_inputs.push_back(tensor);
    batched_tensor_inputs_position.push_back(idx);
  }
  TORCH_INTERNAL_ASSERT(!batched_tensor_inputs.empty());

  // MultiBatchVmapTransform the BatchedTensor arguments. This returns
  // VmapPhysicalViews that contain all of the batch dimensions.
  const auto input_physical_views = MultiBatchVmapTransform::logicalToPhysical(
      batched_tensor_inputs);

  // Compute the total number of batches
  auto num_batch_dims = input_physical_views.front().numBatchDims();
  auto first_physical_view_sizes = input_physical_views.front().tensor().sizes();
  auto batch_sizes = ArrayRef<int64_t>(
      first_physical_view_sizes.begin(), first_physical_view_sizes.begin() + num_batch_dims);
  const auto num_batches = c10::multiply_integers(batch_sizes);
  // Without a shape-checking API, we're unable to compute the correct shape of
  // the output so we just error out.
  TORCH_CHECK(num_batches > 0,
      "Batching rule not implemented for ", schema.operator_name(), ". ",
      "The fallback path does not support vmap over dims of size 0.");

  // Strategy: For each batch, we are going to push slices (where applicable)
  // of the arguments onto `stack`, and call `op`.
  for (const auto linear_idx : c10::irange(num_batches)) {
    auto index = computeIndex(linear_idx, batch_sizes);
    auto batched_tensor_inputs_pos_iter = batched_tensor_inputs_position.begin();
    auto input_physical_views_iter = input_physical_views.begin();
    for (const auto arg_idx : c10::irange(num_arguments)) {
      // We assume that torch::jit::Stack is backed by vector<IValue> for
      // simplicity. When that is not the case, this code should be updated.
      const auto& argument = (*stack)[arguments_begin + arg_idx];
      if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
          || arg_idx != *batched_tensor_inputs_pos_iter) {
        // argument isn't a BatchedTensor
        torch::jit::push(stack, argument);
        continue;
      }
      // argument is a BatchedTensor
      TORCH_INTERNAL_ASSERT(input_physical_views_iter != input_physical_views.end());
      const auto& physical_view_for_argument = *input_physical_views_iter;
      torch::jit::push(stack, physical_view_for_argument.tensor().index(index));
      batched_tensor_inputs_pos_iter++;
      input_physical_views_iter++;
    }

    op.callBoxed(stack);
    torch::jit::drop(stack, 1);
  }

  // Return the tensor that was written to in-place
  torch::jit::drop(stack, num_arguments);
  torch::jit::push(stack, self);
}

static Tensor safeStack(TensorList tensors) {
  auto is_defined = [](const Tensor& t) { return t.defined(); };
  if (std::all_of(tensors.begin(), tensors.end(), is_defined)) {
    return at::stack(tensors);
  }
  // NOTE [vmap through backward and undefined grad]
  // While vmapping through backward functions (to compute batched grad), it
  // is possible for the backward function to return an undefined grad for some
  // grad_input for each example. In that case, we return an undefined grad.
  //
  // It is theoretically possible for *some* of the examples to produce an
  // undefined grad (a kernel could peek at the gradient values and return an
  // undefined tensor if it determines the gradient is full of zeros). We
  // could handle this by treating the undefined grad as a zero-filled tensor
  // of the correct shape while stacking the tensors together. However I expect
  // this to happen very rarely (I have not been able to find an example in our
  // codebase) so we just error out in this case.
  if (std::none_of(tensors.begin(), tensors.end(), is_defined)) {
    return Tensor();
  }
  TORCH_CHECK(false,
      "vmap: slow fallback received a mix of undefined and defined tensors ",
      "as the result of an operation. This is not supported, please file us ",
      "an issue on github.");
}

// The general flow of the algorithm is as follows.
// - First, we figure out which arguments are BatchedTensors and save them
//   to a vector. We also store a vector of which index of the arguments list
//   each BatchedTensor appears in. This will be useful for bookkeeping later.
// - Next, we apply the MultiBatchVmapTransform to all of the BatchedTensors.
//   This returns a vector of VmapPhysicalView that hold tensors that contain
//   all of the collective batch dimensions at the front of the tensors.
// - Then, we attempt to call `op` once per slice of the inputs. To do this,
//   we repeatedly slice the input arguments (if they are BatchedTensors),
//   put the sliced (or unsliced) version of the input onto the stack, invoke
//   the operator, and then pop the results off the stack.
// - Each result obtained from the previous step is a slice of the total result,
//   so we stack those tensors together to form the final result.
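//
// A minimal sketch of the end-to-end effect (purely illustrative): if this
// fallback were used for vmap(torch.mul)(x, y) with x, y of shape [B0, 3], it
// would compute `x_physical[i] * y_physical[i]` for each i in [0, B0) and then
// stack the B0 per-example results into a batched output of shape [B0, 3].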
void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();

  if (isInplaceOp(schema)) {
    batchedTensorInplaceForLoopFallback(op, stack);
    return;
  }
  TORCH_CHECK(!schema.is_mutable() && !schema.hasAnyAliasInfo(),
      "Batching rule not implemented for ", schema.operator_name(), "; ",
      "the fallback path doesn't work on out= or view ops.");
  TORCH_CHECK(areAllReturnsTensors(schema) && !areAnyArgumentsTensorList(schema),
      "Batching rule not implemented for ", schema.operator_name(), ". ",
      "We could not generate a fallback.");
  TORCH_CHECK(num_returns >= 1,
      "Batching rule not implemented for ", schema.operator_name(), ". ",
      "The fallback path does not support operations with no returns.");
  warnFallback(schema);

  const auto num_arguments = static_cast<int64_t>(schema.arguments().size());
  const auto arguments = torch::jit::last(stack, num_arguments);
  const auto arguments_begin = stack->size() - num_arguments;

  // Figure out which arguments are BatchedTensor. Save them to a vector.
  // For each BatchedTensor, also record what position of `arguments` they came from.
  SmallVector<Tensor,kVmapTransformStaticInputSize> batched_tensor_inputs;
  VmapDimVector batched_tensor_inputs_position;
  for (const auto idx : c10::irange(arguments.size())) {
    const auto& ivalue = arguments[idx];
    if (!ivalue.isTensor()) {
      continue;
    }
    const auto& tensor = ivalue.toTensor();
    if (!tensor.defined()) {
      continue;
    }
    const auto* batched = maybeGetBatchedImpl(tensor);
    if (!batched) {
      continue;
    }
    batched_tensor_inputs.push_back(tensor);
    batched_tensor_inputs_position.push_back(idx);
  }
  TORCH_INTERNAL_ASSERT(!batched_tensor_inputs.empty());

  // MultiBatchVmapTransform the BatchedTensor arguments. This returns
  // VmapPhysicalViews that contain all of the batch dimensions.
  const auto input_physical_views = MultiBatchVmapTransform::logicalToPhysical(
      batched_tensor_inputs);

  // Compute the total number of batches
  auto num_batch_dims = input_physical_views.front().numBatchDims();
  auto some_sizes = input_physical_views.front().tensor().sizes();
  auto batch_sizes = ArrayRef<int64_t>(some_sizes.begin(), some_sizes.begin() + num_batch_dims);
  const auto num_batches = c10::multiply_integers(batch_sizes);
  // Without a shape-checking API, we're unable to compute the correct shape of
  // the output so we just error out.
  TORCH_CHECK(num_batches > 0,
      "Batching rule not implemented for ", schema.operator_name(), ". ",
      "The fallback path does not support vmap over dims of size 0.");

  // Strategy: For each batch, we are going to push slices (where applicable)
  // of the arguments onto `stack`, call `op`, and store the result in
  // `output_shards`.
  //
  // NOTE: [Output shards layout]
  // Assume that the operator has three outputs: a, b, c.
  // The layout of output_shards is as follows:
  //   [ a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3 ]
  // This is so that we can call at::stack([a0...a3]), at::stack([b0...b3])
  // more easily in the next step.
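  // As a concrete illustration of that layout: with num_batches = 4, the shard
  // for output b (return_idx = 1) of batch 2 lives at output_shards[4 * 1 + 2],
  // matching the `num_batches * return_idx + linear_idx` indexing used below.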
  std::vector<Tensor> output_shards(num_batches * num_returns);

  for (const auto linear_idx : c10::irange(num_batches)) {
    auto index = computeIndex(linear_idx, batch_sizes);
    auto batched_tensor_inputs_pos_iter = batched_tensor_inputs_position.begin();
    auto input_physical_views_iter = input_physical_views.begin();
    for (const auto arg_idx : c10::irange(num_arguments)) {
      // We assume that torch::jit::Stack is backed by vector<IValue> for
      // simplicity. When that is not the case, this code should be updated.
      const auto& argument = (*stack)[arguments_begin + arg_idx];
      if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
          || arg_idx != *batched_tensor_inputs_pos_iter) {
        // argument isn't a BatchedTensor
        torch::jit::push(stack, argument);
        continue;
      }
      // argument is a BatchedTensor
      TORCH_INTERNAL_ASSERT(input_physical_views_iter != input_physical_views.end());
      const auto& physical_view_for_argument = *input_physical_views_iter;
      torch::jit::push(stack, physical_view_for_argument.tensor().index(index));
      batched_tensor_inputs_pos_iter++;
      input_physical_views_iter++;
    }

    op.callBoxed(stack);

    // Store the result into `output_shards`. See NOTE: [Output shards layout]
    // to learn about the details of how we store the shards.
    const auto returns = torch::jit::last(stack, num_returns);
    for (const auto return_idx : c10::irange(returns.size())) {
      output_shards[num_batches * return_idx + linear_idx] = returns[return_idx].toTensor();
    }
    torch::jit::drop(stack, num_returns);
  }

  // For each output Tensor, stack the shards of the tensor together to form a return
  torch::jit::drop(stack, num_arguments);
  auto output_shards_chunks = MatrixRef<Tensor>(output_shards, num_batches);
  for (const auto return_idx : c10::irange(num_returns)) {
    auto shards = output_shards_chunks[return_idx];
    auto flat_output = safeStack(shards);
    // See NOTE [vmap through backward and undefined grad]
    if (!flat_output.defined()) {
      torch::jit::push(stack, flat_output);
      continue;
    }
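    // To sketch the reshape below: if batch_sizes = [2, 3] and each shard has
    // shape [4], the stacked flat_output has shape [6, 4] and is viewed back as
    // [2, 3, 4] before being mapped to the logical (BatchedTensor) output.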
    VmapDimVector output_sizes(batch_sizes);
    output_sizes.insert(
        output_sizes.end(),
        flat_output.sizes().begin() + 1,
        flat_output.sizes().end());
    torch::jit::push(
        stack,
        input_physical_views.front().getPhysicalToLogicalMap().apply(flat_output.view(output_sizes)));
  }
}

} // namespace at