// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/Operators.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/functorch/BatchedFallback.h>
#include <ATen/native/TensorAdvancedIndexing.h>
#include <ATen/native/IndexKernel.h>
#include <ATen/native/IndexingUtils.h>
#include <torch/library.h>


namespace at::functorch {

namespace {
static bool any_has_value(ArrayRef<std::optional<int64_t>> bdims) {
  for (const auto& bdim : bdims) {
    if (bdim.has_value()) {
      return true;
    }
  }
  return false;
}

static int64_t get_num_leading_nones(ArrayRef<std::optional<Tensor>> indices) {
  int64_t result = 0;
  for (const auto& idx : indices) {
    if (!idx.has_value() || !idx->defined()) {
      result++;
    } else {
      return result;
    }
  }
  return result;
}

static int64_t get_max_index_logical_dim(
    ArrayRef<std::optional<Tensor>> indices,
    ArrayRef<std::optional<int64_t>> indices_bdims) {
  int64_t max_logical_dim = -1;
  TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());
  TORCH_INTERNAL_ASSERT(!indices.empty());
  for (const auto i : c10::irange(0, indices.size())) {
    const auto& maybe_tensor = indices[i];
    if (!maybe_tensor.has_value() || !maybe_tensor->defined()) {
      continue;
    }
    auto logical_dim = rankWithoutBatchDim(maybe_tensor.value(), indices_bdims[i]);
    max_logical_dim = std::max(logical_dim, max_logical_dim);
  }
  return max_logical_dim;
}

static std::vector<std::optional<Tensor>> batchIndices(
    ArrayRef<std::optional<Tensor>> indices,
    ArrayRef<std::optional<int64_t>> indices_bdims,
    int64_t batch_size,
    std::optional<int64_t> self_bdim,
    std::optional<int64_t> values_bdim = std::nullopt) {
  // There are 3 main cases:
  // 1. self is batched, indices/values are not batched
  // In this case, we just need to augment indices with a None at the front to
  // basically broadcast the indexing across the batch dimension of self.
  //
  // 2. self is not batched, some indices are batched.
  // In this case, we don't need to do anything - indices will automatically
  // broadcast to work with the unbatched self.
  //
  // 3. self is batched, some indices are batched.
  // In this case, we simply need to add an arange that indexes along the first
  // dimension (i.e. the batch dimension). We also need to make sure this
  // broadcasts with the rest of the indices.
  //
  // In all three cases, depending on if advanced indices are adjacent we will
  // have to permute the output.
  // See NOTE: [advanced indexing (index.Tensor) batch rule] for more details
  //
  // There is one more case worth mentioning - boolean tensor indices. If we
  // have "batched" boolean tensor indices, that is unrepresentable, as each
  // batch would result in a tensor with different values.
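  //
  // Illustrative example for case 3 (shapes chosen purely for exposition):
  //   self: Tensor[B, 5, 6], indices: [Tensor[B, 2]] (the index is batched)
  //   We prepend arange(B), unsqueezed to shape [B, 1] so that it broadcasts
  //   against the batched index, giving batched_indices:
  //   [arange(B).unsqueeze(-1), Tensor[B, 2]].
  //   Then res[b, i, :] == self[b, indices[0][b, i], :].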
  std::vector<std::optional<Tensor>> indices_;

  int64_t maxLogicalRank = get_max_index_logical_dim(indices, indices_bdims);
  bool indices_batched = any_has_value(indices_bdims);

  for (size_t i = 0; i < indices.size(); i++) {
    auto index = indices[i];
    if (index.has_value() && index->numel() != 0) {
      const auto idx_bdim = indices_bdims[i];
      indices_.emplace_back(maybePadToLogicalRank(moveBatchDimToFront(index.value(), idx_bdim), idx_bdim, maxLogicalRank));
      if (index.value().dtype() == kBool && indices_bdims[i].has_value()) {
        throw std::runtime_error("vmap: We do not support batching operators that can support dynamic shape. Attempting to batch over indexing with a boolean mask.");
      }
    } else {
      indices_.push_back(index);
    }
  }

  auto maxIndexDim = maxLogicalRank;
  if (indices_batched || values_bdim.has_value()) {
    maxIndexDim += 1;
  }

  if (!indices_batched && self_bdim.has_value()) {
    indices_.insert(indices_.begin(), std::nullopt);
  } else if (indices_batched && !self_bdim.has_value()) {
    // do nothing
  } else if (indices_batched && (self_bdim.has_value() || values_bdim.has_value())) {
    auto arange_index = at::arange(0, batch_size);
    while (arange_index.dim() < maxIndexDim) {
      arange_index = arange_index.unsqueeze(-1);
    }
    // TODO: this is O(N)
    indices_.insert(indices_.begin(), arange_index);
  }
  return indices_;
}

// Define an "advanced index" to be a selection object that is
// a non-trivial Tensor (i.e. it does not represent :).
static bool is_advanced_index(const std::optional<Tensor>& idx) {
  if (!idx.has_value()) {
    return false;
  }
  if (!idx->defined()) {
    return false;
  }
  return true;
}

// See NOTE: [advanced indices adjacent] for definition
static bool are_advanced_indices_adjacent(ArrayRef<std::optional<Tensor>> indices) {
  int64_t num_advanced_indices_regions = 0;
  bool in_advanced_indices_region = false;
  for (const auto& idx : indices) {
    if (!in_advanced_indices_region && is_advanced_index(idx)) {
      num_advanced_indices_regions++;
      in_advanced_indices_region = true;
      continue;
    }
    if (in_advanced_indices_region && !is_advanced_index(idx)) {
      in_advanced_indices_region = false;
      continue;
    }
  }
  return num_advanced_indices_regions <= 1;
}

// Given a Tensor[B, <first_region>, <second_region>, ...]
// Swaps the regions to produce Tensor[B, <second_region>, <first_region>, ...]
//
// Concretely speaking, given
// - tensor: Tensor[B, 2, 3, 4, 5, 6, 7, 8]
// - first_region_size: 2
// - second_region_size: 3
// Produces:
// - result: Tensor[B, 4, 5, 6, 2, 3, 7, 8]
//              -------  ----
//              region2  region1
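//
// For the example above, the computed permutation is [0, 3, 4, 5, 1, 2, 6, 7]:
// std::rotate moves the block of second-region dims in front of the block of
// first-region dims while leaving the batch dim and trailing dims in place.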
static Tensor swap_regions(const Tensor& tensor, int64_t first_region_size, int64_t second_region_size) {
  VmapDimVector permutation(tensor.dim(), 0);
  std::iota(permutation.begin(), permutation.end(), 0);
  std::rotate(
      permutation.begin() + 1,
      permutation.begin() + 1 + first_region_size,
      permutation.begin() + 1 + first_region_size + second_region_size);
  return tensor.permute(permutation);
}

std::tuple<Tensor, std::optional<int64_t>> index_batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    ArrayRef<std::optional<Tensor>> indices,
    ArrayRef<std::optional<int64_t>> indices_bdims) {

  // NOTE: [advanced indexing (index.Tensor) batch rule]
  //
  // This is a three step procedure:
  // 1. batch `indices`. Depends on self_bdim and indices_bdims.
  // 2. call at::index
  // 3. (maybe) reorder the dimensions in the result.
  // Why is step 3 necessary? Let's take a detour first.
  //
  // NOTE: [advanced indices adjacent]
  // Definition: In a list of std::optional<Tensor> indices,
  // we say that the "advanced indices are adjacent" if no two advanced indices
  // are separated by a None (slice).
  //
  // So, for example,
  // [:, :, (0, 1), (0, 1), :] -> True
  // [:, (0, 1), :, (0, 1), :] -> False, the advanced indices are separated by a slice
  //
  // See https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
  // for more details.
  //
  // NOTE: [Why is step 3 necessary?]
  //
  // In the original self[*indices] expression,
  // depending on whether or not the "advanced indices inside `indices` are
  // adjacent", something different happens.
  //
  // For example:
  // - self: Tensor[4, 5, 6, 7]
  // - indices: [:, (0, 1), (0, 1), :] (advanced indices are adjacent)
  // - self[*indices]: Tensor[4, 2, 7]
  // If advanced indices are adjacent, you get the output you would expect.
  // (0, 1), (0, 1) says "please index these two dimensions at (0, 0) and (1, 1)
  // to produce two elements".
  //
  // If advanced indices are not adjacent, it is ambiguous where the new
  // dimension of size 2 should go. The numpy spec says it should go at the very
  // front of the Tensor.
  //
  // - self: Tensor[4, 5, 6, 7]
  // - indices: [:, (0, 1), :, (0, 1)] (advanced indices not adjacent)
  // - self[*indices]: Tensor[2, 4, 6]
  //
  // Now, this leads to some weird interactions with vmap.
  // The indices might originally have adjacent advanced indices, but after
  // batching them with "batchIndices", they may no longer be adjacent!
  // - indices: [:, (0, 1), (0, 1)]
  // - batched_indices (for example): [(0, 1), :, (0, 1), (0, 1)]
  // This leads to the dimension of size 2 appearing somewhere else.
  //
  // There are a couple of different cases that we walk through in the code below.
  //
  // Background reading for why we care about whether the advanced indices are adjacent:
  // https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
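  //
  // A minimal end-user example that exercises this rule (illustrative
  // Python-level sketch; shapes are hypothetical):
  //   x = torch.randn(B, 5, 6, 7)
  //   idx = torch.randint(5, (B, 2, 2))
  //   out = torch.vmap(lambda xi, ii: xi[:, ii, ii])(x, idx)
  //   # per sample: xi is [5, 6, 7], ii is [2, 2], xi[:, ii, ii] is [5, 2, 2],
  //   # so out is [B, 5, 2, 2]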
  auto self_ = moveBatchDimToFront(self, self_bdim);
  TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());
  bool advanced_indices_are_adjacent = are_advanced_indices_adjacent(indices);

  // Step 1
  const auto batched_indices = batchIndices(indices, indices_bdims, self_.size(0), self_bdim);
  auto num_leading_nones = get_num_leading_nones(indices);
  auto max_index_dim = get_max_index_logical_dim(indices, indices_bdims);

  // Step 2
  auto res = at::index(self_, List<std::optional<Tensor>>(batched_indices));

  // Step 3: There are three cases (these match the cases outlined in batchIndices)
  bool self_batched = self_bdim.has_value();
  bool indices_batched = any_has_value(indices_bdims);

  TORCH_INTERNAL_ASSERT(self_batched || indices_batched, "Requires at least one batched to get here");

  // Case 1
  if (self_batched && !indices_batched) {
    if (advanced_indices_are_adjacent) {
      // self: Tensor[B, 5, 6, 7, 8]
      // indices: [:, Tensor[2, 2], Tensor[2, 2], :]
      // batched_indices: [:, :, Tensor[2, 2], Tensor[2, 2], :]
      // res: Tensor[B, 5, 2, 2, 8]
      return std::make_tuple(res, 0);
    } else {
      // self: Tensor[B, 5, 6, 7]
      // indices: [Tensor[2, 2], :, Tensor[2, 2]]
      // batched_indices: [:, Tensor[2, 2], :, Tensor[2, 2]]
      // res: Tensor[2, 2, B, 6]
      return std::make_tuple(res, max_index_dim);
    }
  }

  // Case 2
  if (!self_batched && indices_batched) {
    if (advanced_indices_are_adjacent) {
      // self: Tensor[5, 6, 7, 8]
      // indices: [:, :, Tensor[B, 2, 2], Tensor[2, 2]]
      // batched_indices: indices (no change)
      // res: Tensor[5, 6, B, 2, 2]
      return std::make_tuple(res, num_leading_nones);
    } else {
      // self: Tensor[5, 6, 7, 8, 9]
      // indices: [:, :, Tensor[B, 2, 2], :, Tensor[2, 2]]
      // batched_indices: indices (no change)
      // res: Tensor[B, 2, 2, 5, 6, 8]
      return std::make_tuple(res, 0);
    }
  }

  // Case 3: self_batched and indices_batched
  TORCH_INTERNAL_ASSERT(self_batched && indices_batched);
  if (!advanced_indices_are_adjacent) {
    // self: Tensor[B, 5, 6, 7, 8]
    // indices: [:, Tensor[B, 2, 2], :, Tensor[2, 2]]
    // batched_indices: [arange(B).expand(B, 2, 2), :, Tensor[B, 2, 2], :, Tensor[2, 2]]
    // res: Tensor[B, 2, 2, 5, 7]
    return std::make_tuple(res, 0);
  }
  // In other words, in batched_indices, advanced indices are adjacent
  if (num_leading_nones == 0) {
    // self: Tensor[B, 5, 6, 7, 8]
    // indices: [Tensor[B, 2, 2], Tensor[2, 2], :, :]
    // batched_indices: [arange(B).expand(B, 2, 2), Tensor[B, 2, 2], Tensor[2, 2], :, :]
    // res: Tensor[B, 2, 2, 7, 8]
    return std::make_tuple(res, 0);
  }
  // This is the tricky case. In indices, advanced indices are adjacent.
  // In batched_indices, advanced indices are no longer adjacent
  //
  // self: Tensor[B, 5, 6, 7, 8, 9]
  // indices: [:, :, Tensor[B, 2, 3], Tensor[2, 3], :]
  // batched_indices: [arange(B).expand(B, 2, 3), :, :, Tensor[B, 2, 3], Tensor[2, 3], :]
  // res: Tensor[B, 2, 3, 5, 6, 9]
  // expected: Tensor[B, 5, 6, 2, 3, 9]
  //
  // The resolution is to move dims around until we get the right shape.
  // The result is set up as [B, <maxIndexDim>, <leading_nones>, ...]
  // we just have to move the <leading_nones> to before the <maxIndexDim> to produce
  // [B, <leading_nones>, <maxIndexDim>, ...]
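  //
  // For the example above: max_index_dim = 2 and num_leading_nones = 2, so
  // swap_regions permutes res with [0, 3, 4, 1, 2, 5], turning
  // Tensor[B, 2, 3, 5, 6, 9] into the expected Tensor[B, 5, 6, 2, 3, 9].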
  return std::make_tuple(swap_regions(res, max_index_dim, num_leading_nones), 0);
}

// plumbing done since we don't support List<std::optional<Tensor>> in codegen
Tensor index_plumbing(const Tensor & self, const List<std::optional<Tensor>> & indices) {
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "index_plumbing");
  int64_t cur_level = maybe_layer->layerId();
  if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) {
    return at::index(self, indices);
  }
  auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level);
  std::vector<std::optional<Tensor>> indices_value;
  std::vector<std::optional<int64_t>> indices_bdims;
  for (const auto&& indRef : indices) {
    std::optional<Tensor> ind = indRef;
    std::optional<Tensor> index;
    std::optional<int64_t> index_bdim;
    if (ind.has_value()) {
      std::tie(index, index_bdim) = unwrapTensorAtLevel(ind.value(), cur_level);
    }
    indices_value.push_back(index);
    indices_bdims.push_back(index_bdim);
  }
  auto results = index_batch_rule(self_value, self_bdim, indices_value, indices_bdims);
  return makeBatched(std::get<0>(results), std::get<1>(results), cur_level);
}

namespace {
// Code is mostly duplicated from
// https://github.com/pytorch/pytorch/blob/fb0e27d38a8fdab4e1c14d6378c9e41cb30fd6a3
// /aten/src/ATen/native/TensorAdvancedIndexing.cpp#L294-L312
VmapDimVector compute_indexed_shape(const Tensor &src, TensorList indices_list)
{
  int64_t dims_before = 0, dims_indexed = 0;
  IntArrayRef replacement_shape;
  for (const auto dim : c10::irange(indices_list.size())) {
    if (!indices_list[dim].defined()) {
      if (dims_indexed == 0) {
        dims_before++;
      }
    } else {
      dims_indexed++;
      replacement_shape = indices_list[dim].sizes();
    }
  }

  // Replace indexed dimensions in src with stride 0 and the size of the result tensor.
  // The offset in these dimensions is computed by the kernel using the index tensor's
  // values and the stride of src. The new shape is not meaningful. It's used to make
  // the shape compatible with the result tensor.
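  //
  // For instance (illustrative): src sizes [5, 6, 7, 8], defined index tensors
  // at dims 1 and 2 whose common (already broadcast) shape is [2, 3]:
  // dims_before = 1, dims_indexed = 2, replacement_shape = [2, 3],
  // and the returned shape is [5, 2, 3, 8].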
  auto shape = VmapDimVector(src.sizes());
  int64_t end = dims_before + dims_indexed;
  shape.erase(shape.begin() + dims_before, shape.begin() + end);
  shape.insert(shape.begin() + dims_before, replacement_shape.begin(), replacement_shape.end());
  return shape;
}

// Code is mostly duplicated from
// https://github.com/pytorch/pytorch/blob/fb0e27d38a8fdab4e1c14d6378c9e41cb30fd6a3
// /aten/src/ATen/native/TensorAdvancedIndexing.cpp#L379-L405
VmapDimVector get_indexed_shape(Tensor self, const torch::List<std::optional<at::Tensor>> &orig)
{
  at::native::checkIndexTensorTypes(orig, /*allow_int*/ true);
  // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
  auto indices = at::native::expandTensors(self, orig);
  // next broadcast all index tensors together
  try {
    indices = at::expand_outplace(indices);
  } catch (std::exception &e) {
    TORCH_CHECK_INDEX(false, "shape mismatch: indexing tensors could not be broadcast together"
                      " with shapes ");
  }
  // add missing null Tensors so that it matches self.dim()
  while (indices.size() < static_cast<size_t>(self.dim())) {
    indices.emplace_back();
  }
  // if the non-null indices are not all adjacent, transpose self and indices
  // together so that they're adjacent at the front
  if (!at::native::hasContiguousSubspace(indices)) {
    std::tie(self, indices) = at::native::transposeToFront(self, indices);
  }
  return compute_indexed_shape(self, indices);
}

std::tuple<Tensor, std::vector<std::optional<Tensor>>, Tensor>
index_put_batch_rule_helper(const Tensor &self,
                            std::optional<int64_t> self_bdim,
                            ArrayRef<std::optional<Tensor>> indices,
                            ArrayRef<std::optional<int64_t>> indices_bdims,
                            const Tensor &values,
                            std::optional<int64_t> values_bdim,
                            std::optional<int64_t> opt_batch_size = {}) {

  Tensor self_ = moveBatchDimToFront(self, self_bdim);
  Tensor values_ = moveBatchDimToFront(values, values_bdim);
  // For the in-place variants `index_put_` and `_index_put_impl_` we find the
  // batch_size here, while for `index_put` it is done outside of this function.
  const auto batch_size = opt_batch_size ? opt_batch_size.value() : self_.size(0);
  self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
  values_ = ensure_has_bdim(values_, values_bdim.has_value(), batch_size);
  TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());

  // we've already made sure that self has bdim at 0.
  const auto indices_ = batchIndices(indices, indices_bdims, batch_size, /*self_bdim=*/0, values_bdim);

  auto indexed_shape = get_indexed_shape(self_, List<std::optional<Tensor>>(indices_));

  // handle broadcasting support for values
  // Eg. Given `indexed_shape.size()` is 5 and
  // shape of `values` is (N, 2, 3), then the following block
  // will reshape `values` to (N, 1, 1, 2, 3).
  if ((int64_t) indexed_shape.size() > values_.dim()) {
    auto values_sizes = values_.sizes();

    // number of unit dims (for broadcasting value to indexed_shape)
    auto n_unit_dims = indexed_shape.size() - values_sizes.size();
    VmapDimVector new_values_shape(values_sizes.size() + n_unit_dims);

    // add the batch-dim
    new_values_shape[0] = batch_size;

    // insert the unit dims for broadcasting.
    for (const auto idx : c10::irange(n_unit_dims)) {
      // since the batch dim has already been filled.
      new_values_shape[idx + 1] = 1;
    }
    for (const auto idx : c10::irange(1, values_sizes.size())) {
      // since the batch and unit dims have already been filled.
      new_values_shape[idx + n_unit_dims] = values_sizes[idx];
    }
    values_ = values_.view(new_values_shape);
  }

  return std::make_tuple(self_, indices_, values_);
}

auto unpackSelfAndIndicesAndValuesAtCurrentLevel(const Tensor &self,
                                                 const List<std::optional<Tensor>> &indices,
                                                 const Tensor &values, int64_t cur_level)
{
  auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level);
  std::vector<std::optional<Tensor>> indices_value;
  std::vector<std::optional<int64_t>> indices_bdims;
  for (const auto &&indRef : indices) {
    std::optional<Tensor> ind = indRef;
    std::optional<Tensor> index;
    std::optional<int64_t> index_bdim;
    if (ind.has_value()) {
      std::tie(index, index_bdim) = unwrapTensorAtLevel(ind.value(), cur_level);
    }
    indices_value.push_back(index);
    indices_bdims.push_back(index_bdim);
  }
  auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level);
  return std::make_tuple(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim);
}

} // namespace

void index_put__batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    ArrayRef<std::optional<Tensor>> indices,
    ArrayRef<std::optional<int64_t>> indices_bdims,
    const Tensor& values,
    std::optional<int64_t> values_bdim,
    bool accumulate) {
  if (!self_bdim.has_value()) {
    vmapIncompatibleInplaceError("index_put_");
  }
  auto [self_, indices_, values_] = index_put_batch_rule_helper(
      self, self_bdim, indices, indices_bdims, values, values_bdim);
  at::index_put_(self_, List<std::optional<Tensor>>(indices_), values_, accumulate);
}

// plumbing done since we don't support List<std::optional<Tensor>> in codegen
Tensor& index_put__plumbing(Tensor & self, const List<std::optional<Tensor>> & indices,
                            const Tensor & values, bool accumulate) {
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "index_put__plumbing");
  int64_t cur_level = maybe_layer->layerId();

  // on device mismatch, we can move 0d tensors to self device
  auto values_ = values;
  if (values.device() != self.device() && values.numel() == 1 && values.dim() == 0) {
    values_ = values.to(self.device());
  }

  if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values_, cur_level)) {
    return self.index_put_(indices, values_, accumulate);
  }
  auto [self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim] =
      unpackSelfAndIndicesAndValuesAtCurrentLevel(self, indices, values_, cur_level);
  index_put__batch_rule(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim, accumulate);
  return self;
}

void _index_put_impl__batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    ArrayRef<std::optional<Tensor>> indices,
    ArrayRef<std::optional<int64_t>> indices_bdims,
    const Tensor& values,
    std::optional<int64_t> values_bdim,
    bool accumulate,
    bool unsafe) {
  if (!self_bdim.has_value()) {
    vmapIncompatibleInplaceError("_index_put_impl_");
  }
  auto [self_, indices_, values_] = index_put_batch_rule_helper(
      self, self_bdim, indices, indices_bdims, values, values_bdim);
  at::_index_put_impl_(self_, List<std::optional<Tensor>>(indices_), values_, accumulate, unsafe);
}

// plumbing done since we don't support List<std::optional<Tensor>> in codegen
Tensor &_index_put_impl__plumbing(Tensor &self, const List<std::optional<Tensor>> &indices,
                                  const Tensor &values, bool accumulate, bool unsafe) {
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "_index_put_impl__plumbing");
  int64_t cur_level = maybe_layer->layerId();
  if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) {
    return at::_index_put_impl_(self, indices, values, accumulate, unsafe);
  }
  auto [self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim] =
      unpackSelfAndIndicesAndValuesAtCurrentLevel(self, indices, values, cur_level);
  _index_put_impl__batch_rule(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim, accumulate, unsafe);
  return self;
}

static Tensor maybe_permute_values(
    const Tensor& values,
    ArrayRef<std::optional<Tensor>> orig_indices,
    ArrayRef<std::optional<int64_t>> orig_indices_bdims) {
  bool indices_batched = any_has_value(orig_indices_bdims);
  bool advanced_indices_are_adjacent = are_advanced_indices_adjacent(orig_indices);
  auto num_leading_nones = get_num_leading_nones(orig_indices);
  auto max_index_dim = get_max_index_logical_dim(orig_indices, orig_indices_bdims);

  // NB: values has its B dimension at the front
  if (!indices_batched) {
    if (advanced_indices_are_adjacent) {
      // self: Tensor[B, 5, 6, 7, 8]
      // indices: [:, Tensor[2, 2], Tensor[2, 2], :]
      // batched_indices: [:, :, Tensor[2, 2], Tensor[2, 2], :]
      // required values: Tensor[B, 5, 2, 2, 8]
      return values;
    }
    // self: Tensor[B, 5, 6, 7]
    // indices: [Tensor[2, 2], :, Tensor[2, 2]]
    // batched_indices: [:, Tensor[2, 2], :, Tensor[2, 2]]
    // required values: Tensor[2, 2, B, 6]
    return values.movedim(0, max_index_dim);
  }
  if (!advanced_indices_are_adjacent) {
    // self: Tensor[B, 5, 6, 7, 8]
    // indices: [:, Tensor[B, 2, 2], :, Tensor[2, 2]]
    // batched_indices: [arange(B).expand(B, 2, 2), :, Tensor[B, 2, 2], :, Tensor[2, 2]]
    // required values: Tensor[B, 2, 2, 5, 7]
    return values;
  }
  // In other words, in batched_indices, advanced indices are adjacent
  if (num_leading_nones == 0) {
    // self: Tensor[B, 5, 6, 7, 8]
    // indices: [Tensor[B, 2, 2], Tensor[2, 2], :, :]
    // batched_indices: [arange(B).expand(B, 2, 2), Tensor[B, 2, 2], Tensor[2, 2], :, :]
    // required values: Tensor[B, 2, 2, 7, 8]
    return values;
  }
  // This is the tricky case. In indices, advanced indices are adjacent.
  // In batched_indices, advanced indices are no longer adjacent
  //
  // self: Tensor[B, 5, 6, 7, 8, 9]
  // indices: [:, :, Tensor[B, 2, 3], Tensor[2, 3], :]
  // batched_indices: [arange(B).expand(B, 2, 3), :, :, Tensor[B, 2, 3], Tensor[2, 3], :]
  // required values: Tensor[B, 2, 3, 5, 6, 9]
  // actual values: Tensor[B, 5, 6, 2, 3, 9]
  //
  // The resolution is to move dims around until we get the right shape.
  // The values tensor is set up as [B, <leading_nones>, <maxIndexDim>, ...]
  // we just have to move the <maxIndexDim> to before the <leading_nones> to produce
  // [B, <maxIndexDim>, <leading_nones>, ...]
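  //
  // For the example above: num_leading_nones = 2 and max_index_dim = 2, so
  // swap_regions permutes values with [0, 3, 4, 1, 2, 5], turning
  // Tensor[B, 5, 6, 2, 3, 9] into the required Tensor[B, 2, 3, 5, 6, 9].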
  return swap_regions(values, num_leading_nones, max_index_dim);
}

std::tuple<Tensor, std::optional<int64_t>> index_put_batch_rule(
    const Tensor& self,
    std::optional<int64_t> self_bdim,
    ArrayRef<std::optional<Tensor>> indices,
    ArrayRef<std::optional<int64_t>> indices_bdims,
    const Tensor& values,
    std::optional<int64_t> values_bdim,
    bool accumulate) {
  TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());

  // find the batch_size
  int64_t batch_size = 0;
  if (self_bdim || values_bdim) {
    batch_size = get_bdim_size2(self, self_bdim, values, values_bdim);
  } else {
    // one or more of the indices is batched.
    for (size_t i = 0; i < indices.size(); i++) {
      if (indices_bdims[i] && indices[i].has_value()) {
        batch_size = indices[i].value().size(*indices_bdims[i]);
        break;
      }
    }
  }

  auto [self_, indices_, values_] = index_put_batch_rule_helper(
      self, self_bdim, indices, indices_bdims, values, values_bdim, batch_size);

  // Why do we need to permute values?
  // See NOTE: [advanced indexing (index.Tensor) batch rule] for details,
  // but the gist is that index_put effectively does the following:
  // - result = self_.clone()
  // - result[indices_] = values
  // - return result
  // Now, the problem is, result[indices_] might return a Tensor whose shape is
  // the shape of values, but permuted. This is because the shape of result[indices_]
  // depends on whether the original indices "have adjacent advanced indices",
  // and the batched `indices_` might change the "have adjacent advanced indices" property.
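  //
  // A minimal end-user example that goes through this rule (illustrative
  // Python-level sketch; shapes are hypothetical):
  //   x = torch.zeros(B, 5, 6)
  //   idx = torch.randint(5, (B, 3))
  //   v = torch.randn(B, 3, 6)
  //   torch.vmap(lambda xi, ii, vi: xi.index_put((ii,), vi))(x, idx, v)
  //   # per sample: xi[ii] = vi, i.e. three rows of each xi are replaced by vi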
  values_ = maybe_permute_values(values_, indices, indices_bdims);

  auto result = at::index_put(self_, List<std::optional<Tensor>>(indices_), values_, accumulate);
  return std::make_tuple(result, 0);
}

// plumbing done since we don't support List<std::optional<Tensor>> in codegen
Tensor index_put_plumbing(const Tensor & self, const List<std::optional<Tensor>> & indices,
                          const Tensor & values, bool accumulate) {
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "index_put_plumbing");
  int64_t cur_level = maybe_layer->layerId();

  // on device mismatch, we can move 0d tensors to self device
  auto values_ = values;
  if (values.device() != self.device() && values.numel() == 1 && values.dim() == 0) {
    values_ = values.to(self.device());
  }

  if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values_, cur_level)) {
    return self.index_put(indices, values_, accumulate);
  }
  auto [self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim] =
      unpackSelfAndIndicesAndValuesAtCurrentLevel(self, indices, values_, cur_level);
  auto results = index_put_batch_rule(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim, accumulate);
  return makeBatched(std::get<0>(results), std::get<1>(results), cur_level);
}

namespace {

template<typename Func, typename ...Args>
std::tuple<Tensor, std::optional<int64_t>> scatter_batch_rule(
    Func f,
    const Tensor& self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor& index, std::optional<int64_t> index_bdim,
    const Scalar& value, Args... args) {
  auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
  auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);

  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto index_ = moveBatchDimToFront(index, index_bdim);

  if (self_logical_rank == 0) {
    self_ = self_.unsqueeze(-1);
  }
  if (index_logical_rank == 0) {
    index_ = index_.unsqueeze(-1);
  }
  self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
  index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
  auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);

  auto result = f(self_, physical_dim, index_, value, args...);
  // result should have same shape as self
  if (self_logical_rank == 0) {
    result = result.squeeze(-1);
  }
  return std::make_tuple(result, 0);
}

template <typename Func, typename ...Args>
inline std::tuple<Tensor, std::optional<int64_t>> scatter_batch_rule(
    Func f,
    const Tensor& self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor& index, std::optional<int64_t> index_bdim,
    const Tensor& src, std::optional<int64_t> src_bdim, Args... args) {
  auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
  auto src_logical_rank = rankWithoutBatchDim(src, src_bdim);
  auto batch_size = get_bdim_size3(self, self_bdim, index, index_bdim, src, src_bdim);

  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto index_ = moveBatchDimToFront(index, index_bdim);
  auto src_ = moveBatchDimToFront(src, src_bdim);

  if (self_logical_rank == 0) {
    self_ = self_.unsqueeze(-1);
  }
  if (index_logical_rank == 0) {
    index_ = index_.unsqueeze(-1);
  }
  if (src_logical_rank == 0) {
    src_ = src_.unsqueeze(-1);
  }
  self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
  index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
  src_ = ensure_has_bdim(src_, src_bdim.has_value(), batch_size);
  auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);

  auto result = f(self_, physical_dim, index_, src_, args...);
  // result should have same shape as self
  if (self_logical_rank == 0) {
    result = result.squeeze(-1);
  }
  return std::make_tuple(result, 0);
}

} // namespace

std::tuple<Tensor, std::optional<int64_t>> scatter_value_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor& index, std::optional<int64_t> index_bdim,
    const Scalar& value) {
  return scatter_batch_rule(ATEN_FN2(scatter, value),
                            self, self_bdim, dim, index, index_bdim, value);
}

std::tuple<Tensor, std::optional<int64_t>> scatter_src_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor& index, std::optional<int64_t> index_bdim,
    const Tensor& src, std::optional<int64_t> src_bdim) {
  return scatter_batch_rule(ATEN_FN2(scatter, src),
                            self, self_bdim, dim, index, index_bdim, src, src_bdim);
}

std::tuple<Tensor, std::optional<int64_t>> scatter_add_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor& index, std::optional<int64_t> index_bdim,
    const Tensor& src, std::optional<int64_t> src_bdim) {
  return scatter_batch_rule(ATEN_FN(scatter_add),
                            self, self_bdim, dim, index, index_bdim, src, src_bdim);
}

std::tuple<Tensor, std::optional<int64_t>> scatter_reduce_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor& index, std::optional<int64_t> index_bdim,
    const Tensor& src, std::optional<int64_t> src_bdim,
    const c10::string_view reduce) {
  return scatter_batch_rule(ATEN_FN2(scatter, reduce),
                            self, self_bdim, dim, index, index_bdim, src, src_bdim, reduce);
}

std::tuple<Tensor, std::optional<int64_t>> scatter_value_reduce_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor& index, std::optional<int64_t> index_bdim,
    const Scalar& src,
    const c10::string_view reduce) {
  return scatter_batch_rule(ATEN_FN2(scatter, value_reduce),
                            self, self_bdim, dim, index, index_bdim, src, reduce);
}

std::tuple<Tensor, std::optional<int64_t>> gather_batch_rule(
    const Tensor& self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor& index, std::optional<int64_t> index_bdim,
    bool sparse_grad) {
  auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
  auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);

  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto index_ = moveBatchDimToFront(index, index_bdim);

  if (self_logical_rank == 0) {
    self_ = self_.unsqueeze(-1);
  }
  if (index_logical_rank == 0) {
    index_ = index_.unsqueeze(-1);
  }
  self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
  index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
  auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);

  auto result = at::gather(self_, physical_dim, index_, sparse_grad);
  // result should have same rank as index
  if (index_logical_rank == 0) {
    result = result.squeeze(-1);
  }
  return std::make_tuple(result, 0);
}

Tensor get_expanded_index(const Tensor& index, IntArrayRef self_size, int64_t dim) {
  if (index.dim() == 0) {
    return index.expand(self_size);
  }
  dim = maybe_wrap_dim(dim, static_cast<int64_t>(self_size.size()));

  // setup new_index_shape as [BS, 1, ..., idx_size, ..., 1]
  // to reshape index_
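  //
  // For instance (illustrative): index = tensor([2, 0]) (so idx_size = 2),
  // self_size = [5, 6, 7], dim = 1: the index is viewed as [1, 2, 1] and then
  // expanded to [5, 2, 7].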
  auto idx_size = index.size(0);  // get non-batch size of index tensor
  Tensor index_;
  {
    VmapDimVector new_index_shape(self_size.size(), 1);
    new_index_shape[dim] = idx_size;
    index_ = index.view(new_index_shape);
  }
  // Now apply expand to index_
  {
    VmapDimVector new_index_shape = {self_size.begin(), self_size.end()};
    new_index_shape[dim] = idx_size;
    index_ = index_.expand(new_index_shape);
  }
  return index_;
}

Tensor index_select_decomp(const Tensor &self, int64_t dim, const Tensor &index)
{
  Tensor index_ = index;
  if (self.dim() > index.dim()) {
    index_ = get_expanded_index(index, self.sizes(), dim);
  }

  auto result = at::gather(self, dim, index_);

  // output of gather has same dimension as `index` while
  // output of index_select has same dimension as self
  // Eg. t = torch.tensor(1)
  //     idx = torch.tensor([0])
  //     torch.index_select(t, 0, idx) # 0-D
  //     torch.gather(t, 0, idx) # 1-D
  if (self.dim() == 0 && result.dim() != 0) {
    result = result.squeeze(-1);
  }

  return result;
}

Tensor index_copy_decomp(
    const Tensor &self, int64_t dim,
    const Tensor &index, const Tensor &source)
{
  Tensor index_ = index;
  if (self.dim() > index.dim()) {
    index_ = get_expanded_index(index, self.sizes(), dim);
  }

  return at::scatter(self, dim, index_, source);
}

// Note [Fix vmap slice_scatter]
// registers a decomposition for `slice_scatter` that calls into `slice.src`
// *_scatter operators have some special semantics though, that we can't easily
// express through a decomposition: slice_scatter's output needs to have the same
// size, strides and storage_offset as the input.
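//
// Illustrative example of the decomposition below (hypothetical shapes):
//   self: Tensor[5, 6], dim = 0, start = 1, end = 5, step = 2, src: Tensor[2, 6]
//   idx = arange(1, 5, 2) = [1, 3], expanded via get_expanded_index to [2, 6];
//   at::scatter(self, 0, idx, src) then writes src row i into self row idx[i],
//   which matches self[1:5:2] = src.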
Tensor slice_scatter_decomp(const Tensor &self, const Tensor &src,
                            int64_t dim, std::optional<int64_t> start,
                            std::optional<int64_t> end, int64_t step)
{
  auto idx = at::arange(start.value_or(0), end.value_or(self.size(dim)), step, self.options().dtype(kLong));
  idx = get_expanded_index(idx, self.sizes(), dim);
  return at::scatter(self, dim, idx, src);
}

Tensor select_scatter_decomp(
    const Tensor &self, const Tensor &source,
    int64_t dim, int64_t index)
{
  // supports negative index
  index = maybe_wrap_dim(index, self.size(dim));
  auto index_ = at::scalar_tensor(index, self.options().dtype(kLong));

  return at::scatter(self, dim, index_.expand_as(self), source.unsqueeze(dim).expand_as(self));
}

std::tuple<Tensor, std::optional<int64_t>> diagonal_scatter_batch_rule(
    const Tensor &self, std::optional<int64_t> self_bdim,
    const Tensor &src, std::optional<int64_t> src_bdim,
    int64_t offset, int64_t dim1, int64_t dim2)
{
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto src_ = moveBatchDimToFront(src, src_bdim);

  auto batch_size = get_bdim_size2(self, self_bdim, src, src_bdim);

  self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
  src_ = ensure_has_bdim(src_, src_bdim.has_value(), batch_size);

  auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  dim1 = maybe_wrap_dim(dim1, self_logical_rank) + 1;
  dim2 = maybe_wrap_dim(dim2, self_logical_rank) + 1;

  return std::make_tuple(at::diagonal_scatter(self_, src_, offset, dim1, dim2), 0);
}

std::tuple<Tensor, std::optional<int64_t>> index_add_batch_rule_impl(
    Tensor& self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor& index, std::optional<int64_t> index_bdim,
    const Tensor& other, std::optional<int64_t> other_bdim,
    const Scalar& alpha,
    const bool inplace) {

  if (inplace && !self_bdim.has_value()) {
    vmapIncompatibleInplaceError("index_add_");
  }

  if (!index_bdim) {
    // Handle scalar tensors... self, other can be scalar tensors
    const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
    const auto other_logical_rank = rankWithoutBatchDim(other, other_bdim);
    auto self_ = moveBatchDimToFront(self, self_bdim);
    if (self_logical_rank == 0) {
      self_ = self_.unsqueeze(-1);
    }
    auto other_ = moveBatchDimToFront(other, other_bdim);
    if (other_logical_rank == 0) {
      other_ = other_.unsqueeze(-1);
    }
    dim = maybe_wrap_dim(dim, self_logical_rank);

    const auto batch_size = get_bdim_size2(self, self_bdim, other, other_bdim);
    self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
    other_ = ensure_has_bdim(other_, other_bdim.has_value(), batch_size);

    if (inplace) {
      self_.index_add_(dim + 1, index, other_, alpha);
      if (self_logical_rank == 0) {
        self_ = self_.squeeze(-1);
      }
      return std::make_tuple(self, 0);
    }

    auto result = self_.index_add(dim + 1, index, other_, alpha);
    if (self_logical_rank == 0) {
      result = result.squeeze(-1);
    }
    return std::make_tuple(result, 0);
  }

  // Index is batched. For-loop and stack is the best thing I can come up with
  // right now. We really want generalized index_add kernel in PyTorch
  auto batch_size = get_bdim_size3(self, self_bdim, other, other_bdim, index, index_bdim);
  std::vector<Tensor> results;
  if (!inplace) {
    results.reserve(batch_size);
  }
  for (const auto i : c10::irange(0, batch_size)) {
    const auto& self_slice = self_bdim.has_value() ?
        self.select(*self_bdim, i) : self;
    const auto& other_slice = other_bdim.has_value() ?
        other.select(*other_bdim, i) : other;
    const auto& index_slice = index_bdim.has_value() ?
        index.select(*index_bdim, i) : index;

    if (inplace) {
      self_slice.index_add_(dim, index_slice, other_slice, alpha);
    } else {
      results.push_back(at::index_add(self_slice, dim, index_slice, other_slice, alpha));
    }
  }
  if (inplace) {
    return std::make_tuple(at::stack(self), 0);
  }
  return std::make_tuple(at::stack(results), 0);
}

void index_add__batch_rule(
    Tensor& self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor& index, std::optional<int64_t> index_bdim,
    const Tensor& other, std::optional<int64_t> other_bdim,
    const Scalar& alpha) {
  index_add_batch_rule_impl(self, self_bdim, dim, index, index_bdim, other,
                            other_bdim, alpha, true);
}

std::tuple<Tensor, std::optional<int64_t>> index_add_batch_rule(
    Tensor& self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor& index, std::optional<int64_t> index_bdim,
    const Tensor& other, std::optional<int64_t> other_bdim,
    const Scalar& alpha) {
  auto self_ = self.clone(at::MemoryFormat::Preserve);
  return index_add_batch_rule_impl(self_, self_bdim, dim, index, index_bdim,
                                   other, other_bdim, alpha, false);
}

static std::tuple<Tensor, Tensor> binary_pointwise_align(
    const Tensor & self,
    std::optional<int64_t> self_bdim,
    const Tensor & mask,
    std::optional<int64_t> mask_bdim) {
  // compute max logical rank
  auto tensor_logical_rank = rankWithoutBatchDim(self, self_bdim);
  auto other_logical_rank = rankWithoutBatchDim(mask, mask_bdim);
  auto max_logical_rank = std::max(tensor_logical_rank, other_logical_rank);

  auto tensor_ = moveBatchDimToFront(self, self_bdim);
  auto other_ = moveBatchDimToFront(mask, mask_bdim);

  // If the dimensions aren't aligned, we need to line them up.
  // Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3]
  // Note that only tensors that have a batch dim need to be modified.
  // Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed
  tensor_ = maybePadToLogicalRank(tensor_, self_bdim, max_logical_rank);
  other_ = maybePadToLogicalRank(other_, mask_bdim, max_logical_rank);

  return std::make_tuple(tensor_, other_);
}

std::tuple<Tensor, std::optional<int64_t>> masked_fill_scalar_batch_rule(
    const Tensor & self,
    std::optional<int64_t> self_bdim,
    const Tensor & mask,
    std::optional<int64_t> mask_bdim,
    const Scalar& source) {
  auto tensors = binary_pointwise_align(self, self_bdim, mask, mask_bdim);
  auto result = at::masked_fill(std::get<0>(tensors), std::get<1>(tensors), source);
  return std::make_tuple(result, 0);
}

std::tuple<Tensor, std::optional<int64_t>> index_fill_batch_rule_helper(
    int64_t batch_size,
    int64_t self_logical_rank,
    int64_t index_logical_rank,
    Tensor & self_,
    int64_t dim,
    Tensor & index_,
    const Scalar & value
){
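  // The non-scalar branch below folds the batch dim into `dim` and offsets the
  // indices so that a single index_fill_ call handles all batches at once.
  // Illustrative trace (hypothetical values): batch_size = 2, self_: Tensor[2, 3, 4],
  // dim = 0, index_ = [[0, 2], [1, 1]]:
  //   index_offset = [0, 1], so index_ + 3 * index_offset.unsqueeze(-1) = [[0, 2], [4, 4]];
  //   flatten index_ to [0, 2, 4, 4] and self_ to Tensor[6, 4], index_fill_ along dim 0,
  //   then reshape back to Tensor[2, 3, 4] with the batch dim at position `dim`.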
  if (self_logical_rank != 0) {
    auto index_offset = at::arange(
        batch_size,
        at::TensorOptions().dtype(index_.scalar_type()).device(index_.device())
    );
    if (index_logical_rank == 0) {
      index_ = index_.unsqueeze(-1);
    }
    index_ = index_.add(index_offset.unsqueeze(-1), self_.size(dim + 1));
    index_ = reshape_dim_into(0, 0, index_);
    self_ = reshape_dim_into(0, dim, self_);
    self_.index_fill_(dim, index_, value);
    self_ = reshape_dim_outof(dim, batch_size, self_);
    return std::make_tuple(self_, dim);
  }

  // If self_logical_rank == 0, the batch dim is certainly 0, and we must apply batched indices to each row.
  if (index_logical_rank != 0) {
    index_ = reshape_dim_into(0, 0, index_);
  }
  self_.unsqueeze_(-1);
  self_.index_fill_(dim + 1, index_, value);
  self_.squeeze_(-1);

  return std::make_tuple(self_, 0);
}

std::tuple<Tensor, std::optional<int64_t>> index_fill_int_scalar_batch_rule_impl(
    Tensor & self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor & index, std::optional<int64_t> index_bdim,
    const Scalar & value,
    const bool inplace) {
  const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  const auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
  Tensor self_ = moveBatchDimToFront(self, self_bdim);
  Tensor index_ = moveBatchDimToFront(index, index_bdim);
  dim = maybe_wrap_dim(dim, self_logical_rank);

  if (inplace && !self_bdim.has_value()) {
    vmapIncompatibleInplaceError("index_fill_");
  }

  if (!index_bdim) {
    if (self_logical_rank == 0) {
      self_.unsqueeze_(-1);
    }
    self_.index_fill_(dim + 1, index_, value);
    if (self_logical_rank == 0) {
      self_.squeeze_(-1);
    }
    return std::make_tuple(self_, 0);
  }

  auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);
  self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
  index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);

  if (inplace) {
    // Do for-loop for in-place because we cannot reshape
    // `self_` having an incompatible stride without copying.
    for (const auto i : c10::irange(0, batch_size)) {
      const auto& self_slice = self_.select(0, i);
      const auto& index_slice = index_.select(0, i);
      self_slice.index_fill_(
        dim,
        index_slice,
        value
      );
    }
    return std::make_tuple(self_, 0);
  }

  self_ = self_bdim.has_value() ? self_ : self_.clone();

  return index_fill_batch_rule_helper(batch_size, self_logical_rank, index_logical_rank, self_, dim, index_, value);
}

std::tuple<Tensor, std::optional<int64_t>> index_fill_int_tensor_batch_rule_impl(
    Tensor & self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor & index, std::optional<int64_t> index_bdim,
    const Tensor & value, std::optional<int64_t> value_bdim,
    const bool inplace) {
  const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  const auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
  Tensor self_ = moveBatchDimToFront(self, self_bdim);
  Tensor index_ = moveBatchDimToFront(index, index_bdim);
  Tensor value_ = moveBatchDimToFront(value, value_bdim);
  dim = maybe_wrap_dim(dim, self_logical_rank);

  if (inplace && !self_bdim.has_value()) {
    vmapIncompatibleInplaceError("index_fill_");
  }

  if (!index_bdim && !value_bdim) {
    if (self_logical_rank == 0) {
      self_.unsqueeze_(-1);
    }
    self_.index_fill_(dim + 1, index_, value);
    if (self_logical_rank == 0) {
      self_.squeeze_(-1);
    }
    return std::make_tuple(self_, 0);
  }

  auto batch_size = get_bdim_size3(self, self_bdim, index, index_bdim, value, value_bdim);
  self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
  index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);

  if (inplace || value_bdim.has_value()) {
    // Do for-loop for in-place because we cannot reshape
    // `self_` having an incompatible stride without copying.
    // If value has a batch dim, we do for-loop as well because
    // index_fill_ supports 1-element tensor only.
    for (const auto i : c10::irange(0, batch_size)) {
      const auto& self_slice = self_.select(0, i);
      const auto& index_slice = index_.select(0, i);
      self_slice.index_fill_(
        dim,
        index_slice,
        value_bdim.has_value() ? value_.select(0, i) : value_
      );
    }
    return std::make_tuple(self_, 0);
  }

  self_ = self_bdim.has_value() ? self_ : self_.clone();

  // calling .item() on value is safe here because value is guaranteed to not be a batched tensor.
  return index_fill_batch_rule_helper(batch_size, self_logical_rank, index_logical_rank, self_, dim, index_, value.item());
}

void index_fill__int_scalar_batch_rule(
    Tensor & self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor & index, std::optional<int64_t> index_bdim,
    const Scalar & value) {
  index_fill_int_scalar_batch_rule_impl(self, self_bdim, dim, index, index_bdim, value, true);
}

void index_fill__int_tensor_batch_rule(
    Tensor & self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor & index, std::optional<int64_t> index_bdim,
    const Tensor & value, std::optional<int64_t> value_bdim) {
  index_fill_int_tensor_batch_rule_impl(self, self_bdim, dim, index, index_bdim, value, value_bdim, true);
}

std::tuple<Tensor, std::optional<int64_t>> index_fill_int_scalar_batch_rule(
    const Tensor & self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor & index, std::optional<int64_t> index_bdim,
    const Scalar & value) {
  auto self_ = self.clone(at::MemoryFormat::Preserve);
  return index_fill_int_scalar_batch_rule_impl(self_, self_bdim, dim, index, index_bdim, value, false);
}

std::tuple<Tensor, std::optional<int64_t>> index_fill_int_tensor_batch_rule(
    const Tensor & self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor & index, std::optional<int64_t> index_bdim,
    const Tensor & value, std::optional<int64_t> value_bdim) {
  auto self_ = self.clone(at::MemoryFormat::Preserve);
  return index_fill_int_tensor_batch_rule_impl(self_, self_bdim, dim, index, index_bdim, value, value_bdim, false);
}

}

TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  m.impl("index.Tensor", index_plumbing);
  m.impl("index_put_", index_put__plumbing);
  m.impl("index_put", index_put_plumbing);
  m.impl("_index_put_impl_", _index_put_impl__plumbing);
  m.impl("slice_scatter", slice_scatter_decomp);
  m.impl("select_scatter", select_scatter_decomp);
  m.impl("index_copy", index_copy_decomp);
  m.impl("index_select", index_select_decomp);
  VMAP_SUPPORT2(masked_fill, Scalar, masked_fill_scalar_batch_rule);
  VMAP_SUPPORT2(index_fill_, int_Tensor, index_fill__int_tensor_batch_rule);
  VMAP_SUPPORT2(index_fill_, int_Scalar, index_fill__int_scalar_batch_rule);
  VMAP_SUPPORT2(index_fill, int_Tensor, index_fill_int_tensor_batch_rule);
  VMAP_SUPPORT2(index_fill, int_Scalar, index_fill_int_scalar_batch_rule);
  VMAP_SUPPORT(index_add_, index_add__batch_rule);
  VMAP_SUPPORT(index_add, index_add_batch_rule);
  VMAP_SUPPORT(diagonal_scatter, diagonal_scatter_batch_rule);
  VMAP_SUPPORT(gather, gather_batch_rule);
  VMAP_SUPPORT2(scatter, value, scatter_value_batch_rule);
  VMAP_SUPPORT2(scatter, src, scatter_src_batch_rule);
  VMAP_SUPPORT(scatter_add, scatter_add_batch_rule);
  VMAP_SUPPORT2(scatter, reduce, scatter_reduce_batch_rule);
  VMAP_SUPPORT2(scatter, value_reduce, scatter_value_reduce_batch_rule);
  // as_strided_scatter does not work with the for-loop fallback today,
  // because as_strided_scatter will return an output that matches
  // the strides/storage_offset of its input.
  // With the for loop fallback, each input tensor is a slice into
  // the larger batched tensor.
  m.impl("as_strided_scatter", torch::CppFunction::makeFromBoxedFunction<&vmapErrorFallback>());
}

} // namespace at::functorch