1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // This source code is licensed under the BSD-style license found in the
5 // LICENSE file in the root directory of this source tree.
6 
7 #include <ATen/functorch/BatchRulesHelper.h>
8 #include <ATen/Operators.h>
9 #include <ATen/functorch/PlumbingHelper.h>
10 #include <ATen/functorch/BatchedFallback.h>
11 #include <ATen/native/TensorAdvancedIndexing.h>
12 #include <ATen/native/IndexKernel.h>
13 #include <ATen/native/IndexingUtils.h>
14 #include <torch/library.h>
15 
16 
17 namespace at::functorch {
18 
19 namespace {
20 static bool any_has_value(ArrayRef<std::optional<int64_t>> bdims) {
21   for (const auto& bdim : bdims) {
22     if (bdim.has_value()) {
23       return true;
24     }
25   }
26   return false;
27 }
28 
29 static int64_t get_num_leading_nones(ArrayRef<std::optional<Tensor>> indices) {
30   int64_t result = 0;
31   for (const auto& idx : indices) {
32     if (!idx.has_value() || !idx->defined()) {
33       result++;
34     } else {
35       return result;
36     }
37   }
38   return result;
39 }
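// Illustrative examples (hypothetical index lists, not from the original code):
//   [None, None, t0, None] -> 2
//   [t0, None]             -> 0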
40 
41 static int64_t get_max_index_logical_dim(
42     ArrayRef<std::optional<Tensor>> indices,
43     ArrayRef<std::optional<int64_t>> indices_bdims) {
44   int64_t max_logical_dim = -1;
45   TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());
46   TORCH_INTERNAL_ASSERT(!indices.empty());
47   for (const auto i : c10::irange(0, indices.size())) {
48     const auto& maybe_tensor = indices[i];
49     if (!maybe_tensor.has_value() || !maybe_tensor->defined()) {
50       continue;
51     }
52     auto logical_dim = rankWithoutBatchDim(maybe_tensor.value(), indices_bdims[i]);
53     max_logical_dim = std::max(logical_dim, max_logical_dim);
54   }
55   return max_logical_dim;
56 }
57 
58 static std::vector<std::optional<Tensor>> batchIndices(
59   ArrayRef<std::optional<Tensor>> indices,
60   ArrayRef<std::optional<int64_t>> indices_bdims,
61   int64_t batch_size,
62   std::optional<int64_t> self_bdim,
63   std::optional<int64_t> values_bdim = std::nullopt) {
64   // There are 3 main cases:
65   // 1. self is batched, indices/values are not batched
66   // In this case, we just need to augment indices with a None at the front to
67   // basically broadcast the indexing across the batch dimension of self.
68   //
69   // 2. self is not batched, some indices are batched.
70   // In this case, we don't need to do anything - indices will automatically
71   // broadcast to work with the unbatched self.
72   //
73   // 3. self is batched, some indices are batched.
74   // In this case, we simply need to add an arange that indexes along the first
75   // dimension (i.e. the batch dimension). We also need to make sure this
76   // broadcasts with the rest of the indices.
77   //
78   // In all three cases, depending on if advanced indices are adjacent we will
79   // have to permute the output.
80   // See NOTE: [advanced indexing (index.Tensor) batch rule] for more details
81   //
82   // There is one more case worth mentioning - boolean tensor indices. If we
83   // have "batched" boolean tensor indices, that is unrepresentable, as each
84 // batch could result in a tensor with a different shape.
85   std::vector<std::optional<Tensor>> indices_;
86 
87   int64_t maxLogicalRank = get_max_index_logical_dim(indices, indices_bdims);
88   bool indices_batched = any_has_value(indices_bdims);
89 
90   for (size_t i = 0; i < indices.size(); i++) {
91     auto index = indices[i];
92     if (index.has_value() && index->numel() != 0) {
93       const auto idx_bdim = indices_bdims[i];
94       indices_.emplace_back(maybePadToLogicalRank(moveBatchDimToFront(index.value(), idx_bdim), idx_bdim, maxLogicalRank));
95       if (index.value().dtype() == kBool && indices_bdims[i].has_value()) {
96         throw std::runtime_error("vmap: We do not support batching operators that produce dynamic (data-dependent) shapes. Attempting to batch over indexing with a boolean mask.");
97       }
98     } else {
99       indices_.push_back(index);
100     }
101   }
102 
103   auto maxIndexDim = maxLogicalRank;
104   if (indices_batched || values_bdim.has_value()) {
105     maxIndexDim += 1;
106   }
107 
108   if (!indices_batched && self_bdim.has_value()) {
109     indices_.insert(indices_.begin(), std::nullopt);
110   } else if (indices_batched && !self_bdim.has_value()) {
111     // do nothing
112   } else if (indices_batched && (self_bdim.has_value() || values_bdim.has_value())) {
113     auto arange_index = at::arange(0, batch_size);
114     while (arange_index.dim() < maxIndexDim) {
115       arange_index = arange_index.unsqueeze(-1);
116     }
117     // TODO: this is O(N)
118     indices_.insert(indices_.begin(), arange_index);
119   }
120   return indices_;
121 }
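// Illustrative example of case 3 (hypothetical shapes, not from the original code):
//   B = 2
//   self:    Tensor[B, 5, 6]   (self_bdim = 0)
//   indices: [Tensor[B, 3]]    (indices_bdims = [0])
// maxLogicalRank is 1, so maxIndexDim becomes 2 and we prepend arange(B)
// unsqueezed to shape [B, 1], giving
//   batched_indices: [arange(B).unsqueeze(-1), Tensor[B, 3]]
// so that at::index picks row b of the batched index for batch b:
//   result[b, j, :] == self[b, indices[0][b, j], :]   -> Tensor[B, 3, 6]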
122 
123 // Define an "advanced index" to be a selection object that is
124 // a non-trivial Tensor (i.e. it does not represent :).
125 static bool is_advanced_index(const std::optional<Tensor>& idx) {
126   if (!idx.has_value()) {
127     return false;
128   }
129   if (!idx->defined()) {
130     return false;
131   }
132   return true;
133 }
134 
135 // See NOTE: [advanced indices adjacent] for definition
136 static bool are_advanced_indices_adjacent(ArrayRef<std::optional<Tensor>> indices) {
137   int64_t num_advanced_indices_regions = 0;
138   bool in_advanced_indices_region = false;
139   for (const auto& idx : indices) {
140     if (!in_advanced_indices_region && is_advanced_index(idx)) {
141       num_advanced_indices_regions++;
142       in_advanced_indices_region = true;
143       continue;
144     }
145     if (in_advanced_indices_region && !is_advanced_index(idx)) {
146       in_advanced_indices_region = false;
147       continue;
148     }
149   }
150   return num_advanced_indices_regions <= 1;
151 }
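// Illustrative examples (hypothetical index lists):
//   [None, t0, t1, None] -> true  (one contiguous region of advanced indices)
//   [t0, None, t1]       -> false (two regions separated by a slice)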
152 
153 // Given a Tensor[B, <first_region>, <second_region>, ...]
154 // Swaps the regions to produce Tensor[B, <second_region>, <first_region>, ...]
155 //
156 // Concretely speaking, given
157 // - tensor: Tensor[B, 2, 3, 4, 5, 6, 7, 8]
158 // - first_region_size: 2
159 // - second_region_size: 3
160 // Produces:
161 // - result: Tensor[B, 4, 5, 6, 2, 3, 7, 8]
162 //                     -------  ----
163 //                     region2  region1
164 static Tensor swap_regions(const Tensor& tensor, int64_t first_region_size, int64_t second_region_size) {
165   VmapDimVector permutation(tensor.dim(), 0);
166   std::iota(permutation.begin(), permutation.end(), 0);
167   std::rotate(
168       permutation.begin() + 1,
169       permutation.begin() + 1 + first_region_size,
170       permutation.begin() + 1 + first_region_size + second_region_size);
171   return tensor.permute(permutation);
172 }
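// For the example above (first_region_size = 2, second_region_size = 3), the
// permutation computed by std::rotate is [0, 3, 4, 5, 1, 2, 6, 7]: dims {3, 4, 5}
// (the second region) move in front of dims {1, 2} (the first region), while the
// batch dim and the trailing dims stay in place.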
173 
174 std::tuple<Tensor, std::optional<int64_t>> index_batch_rule(
175     const Tensor& self,
176     std::optional<int64_t> self_bdim,
177     ArrayRef<std::optional<Tensor>> indices,
178     ArrayRef<std::optional<int64_t>> indices_bdims) {
179 
180   // NOTE: [advanced indexing (index.Tensor) batch rule]
181   //
182   // This is a three step procedure:
183   // 1. batch `indices`. Depends on self_bdim and indices_bdim.
184   // 2. call at::index
185   // 3. (maybe) reorder the dimensions in the result.
186   // Why is step 3 necessary? Let's take a detour first.
187   //
188   // NOTE: [advanced indices adjacent]
189   // Definition: In a list of std::optional<Tensor> indices,
190   // we say that "advanced indices are adjacent" if ALL advanced indices are
191   // not separated by a None (slice).
192   //
193   // So, for example,
194   // [:, :, (0, 1), (0, 1), :] -> True
195   // [:, (0, 1), :, (0, 1), :] -> False, the advanced indices are separated by a slice
196   //
197   // See https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
198   // for more details.
199   //
200   // NOTE: [Why is step 3 necessary?]
201   //
202   // In the original self[*indices] expression,
203   // depending on whether or not the "advanced indices inside `indices` are
204   // adjacent", something different happens.
205   //
206   // For example:
207   // - self: Tensor[4, 5, 6, 7]
208   // - indices: [:, (0, 1), (0, 1), :] (advanced indices are adjacent)
209   // - self[*indices]: Tensor[4, 2, 7]
210   // If advanced indices are adjacent, you get the output you would expect.
211   // (0, 1), (0, 1) says "please index these two dimensions at (0, 0) and (1, 1)
212   // to produce two elements".
213   //
214   // If advanced indices are not adjacent, it is ambiguous to where the new
215   // dimension of size 2 should go. The numpy spec says it should go at the very
216   // front of the Tensor.
217   //
218   // - self: Tensor[4, 5, 6, 7]
219   // - indices: [:, (0, 1), :, (0, 1)] (advanced indices not adjacent)
220   // - self[*indices]: Tensor[2, 4, 6]
221   //
222   // Now, this leads to some weird interactions with vmap.
223   // The indices might originally have adjacent advanced indices, but after
224   // batching them with "batchIndices", they may no longer be adjacent!
225   // - indices: [:, (0, 1), (0, 1)]
226   // - batched_indices (for example): [(0, 1), :, (0, 1), (0, 1)]
227   // This leads to the dimension of size 2 appearing somewhere else.
228   //
229   // There are a couple of different cases that we walk through in the code below.
230   //
231   // Background reading for why we care about if the advanced indices are adjacent:
232   // https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
233   auto self_ = moveBatchDimToFront(self, self_bdim);
234   TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());
235   bool advanced_indices_are_adjacent = are_advanced_indices_adjacent(indices);
236 
237   // Step 1
238   const auto batched_indices = batchIndices(indices, indices_bdims, self_.size(0), self_bdim);
239   auto num_leading_nones = get_num_leading_nones(indices);
240   auto max_index_dim = get_max_index_logical_dim(indices, indices_bdims);
241 
242   // Step 2
243   auto res = at::index(self_, List<std::optional<Tensor>>(batched_indices));
244 
245   // Step 3: There are three cases (these match the cases outlined in batchIndices)
246   bool self_batched = self_bdim.has_value();
247   bool indices_batched = any_has_value(indices_bdims);
248 
249   TORCH_INTERNAL_ASSERT(self_batched || indices_batched, "Requires at least one batched to get here");
250 
251   // Case 1
252   if (self_batched && !indices_batched) {
253     if (advanced_indices_are_adjacent) {
254       // self: Tensor[B, 5, 6, 7, 8]
255       // indices: [:, Tensor[2, 2], Tensor[2, 2], :]
256       // batched_indices: [:, :, Tensor[2, 2], Tensor[2, 2], :]
257       // res: Tensor[B, 5, 2, 2, 8]
258       return std::make_tuple(res, 0);
259     } else {
260       // self: Tensor[B, 5, 6, 7]
261       // indices: [Tensor[2, 2], :, Tensor[2, 2]]
262       // batched_indices: [:, Tensor[2, 2], :, Tensor[2, 2]]
263       // res: Tensor[2, 2, B, 6]
264       return std::make_tuple(res, max_index_dim);
265     }
266   }
267 
268   // Case 2
269   if (!self_batched && indices_batched) {
270     if (advanced_indices_are_adjacent) {
271       // self: Tensor[5, 6, 7, 8]
272       // indices: [:, :, Tensor[B, 2, 2], Tensor[2, 2]]
273       // batched_indices: indices (no change)
274       // res: Tensor[5, 6, B, 2, 2]
275       return std::make_tuple(res, num_leading_nones);
276     } else {
277       // self: Tensor[5, 6, 7, 8, 9]
278       // indices: [:, :, Tensor[B, 2, 2], :, Tensor[2, 2]]
279       // batched_indices: indices (no change)
280       // res: Tensor[B, 2, 2, 5, 6, 8]
281       return std::make_tuple(res, 0);
282     }
283   }
284 
285   // Case 3: self_batched and indices_batched
286   TORCH_INTERNAL_ASSERT(self_batched && indices_batched);
287   if (!advanced_indices_are_adjacent) {
288     // self: Tensor[B, 5, 6, 7, 8]
289     // indices: [:, Tensor[B, 2, 2], :, Tensor[2, 2]]
290     // batched_indices: [arange(B).expand(B, 2, 2), :, Tensor[B, 2, 2], :, Tensor[2, 2]]
291     // res: Tensor[B, 2, 2, 5, 7]
292     return std::make_tuple(res, 0);
293   }
294   // In other words, in batched_indices, advanced indices are adjacent
295   if (num_leading_nones == 0) {
296     // self: Tensor[B, 5, 6, 7, 8]
297     // indices: [Tensor[B, 2, 2], Tensor[2, 2], :, :]
298     // batched_indices: [arange(B).expand(B, 2, 2), Tensor[B, 2, 2], Tensor[2, 2], :, :]
299     // res: Tensor[B, 2, 2, 7, 8]
300     return std::make_tuple(res, 0);
301   }
302   // This is the tricky case. In indices, advanced indices are adjacent.
303   // In batched_indices, advanced indices are no longer adjacent
304   //
305   // self: Tensor[B, 5, 6, 7, 8, 9]
306   // indices: [:, :, Tensor[B, 2, 3], Tensor[2, 3], :]
307   // batched_indices: [arange(B).expand(B, 2, 3), :, :, Tensor[B, 2, 3], Tensor[2, 3], :]
308   // res: Tensor[B, 2, 3, 5, 6, 9]
309   // expected: Tensor[B, 5, 6, 2, 3, 9]
310   //
311   // The resolution is to move dims around until we get the right shape.
312   // The result is set up as [B, <maxIndexDim>, <leading_nones>, ...]
313   // we just have to move the <leading_nones> to before the <maxIndexDim> to produce
314   // [B, <leading_nones>, <maxIndexDim>, ...]
315   return std::make_tuple(swap_regions(res, max_index_dim, num_leading_nones), 0);
316 }
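// A user-level sketch of a case this rule handles (hypothetical example, assuming
// the usual torch.vmap API):
//   x = torch.randn(B, 5, 6, 7)
//   idx = torch.randint(5, (B, 2, 2))
//   out = torch.vmap(lambda xi, ii: xi[ii])(x, idx)   # out: Tensor[B, 2, 2, 6, 7]
// Here self_bdim = 0 and indices_bdims = [0], i.e. Case 3 with
// num_leading_nones == 0: batchIndices prepends arange(B) and at::index already
// produces the result with the batch dim at the front.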
317 
318 // plumbing done since we don't support List<std::optional<Tensor>> in codegen
319 Tensor index_plumbing(const Tensor & self, const List<std::optional<Tensor>> & indices
320 ) {
321   c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
322   auto maybe_layer = maybeCurrentDynamicLayer();
323   vmap_check_escaped(maybe_layer, "index_plumbing");
324   int64_t cur_level = maybe_layer->layerId();
325   if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) {
326     return at::index(self, indices);
327   }
328   auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level);
329   std::vector<std::optional<Tensor>> indices_value;
330   std::vector<std::optional<int64_t>> indices_bdims;
331   for (const auto&& indRef : indices) {
332       std::optional<Tensor> ind = indRef;
333       std::optional<Tensor> index;
334       std::optional<int64_t> index_bdim;
335       if (ind.has_value()) {
336         std::tie(index, index_bdim) = unwrapTensorAtLevel(ind.value(), cur_level);
337       }
338     indices_value.push_back(index);
339     indices_bdims.push_back(index_bdim);
340   }
341   auto results = index_batch_rule(self_value, self_bdim, indices_value, indices_bdims);
342   return makeBatched(std::get<0>(results), std::get<1>(results), cur_level);
343 }
344 
345 namespace {
346   // Code is mostly duplicated from
347   // https://github.com/pytorch/pytorch/blob/fb0e27d38a8fdab4e1c14d6378c9e41cb30fd6a3
348   // /aten/src/ATen/native/TensorAdvancedIndexing.cpp#L294-L312
349   VmapDimVector compute_indexed_shape(const Tensor &src, TensorList indices_list)
350   {
351     int64_t dims_before = 0, dims_indexed = 0;
352     IntArrayRef replacement_shape;
353     for (const auto dim : c10::irange(indices_list.size())) {
354       if (!indices_list[dim].defined()) {
355         if (dims_indexed == 0) {
356           dims_before++;
357         }
358       } else {
359         dims_indexed++;
360         replacement_shape = indices_list[dim].sizes();
361       }
362     }
363 
364     // Replace indexed dimensions in src with stride 0 and the size of the result tensor.
365     // The offset in these dimensions is computed by the kernel using the index tensor's
366     // values and the stride of src. The new shape is not meaningful. It's used to make
367     // the shape compatible with the result tensor.
368     auto shape = VmapDimVector(src.sizes());
369     int64_t end = dims_before + dims_indexed;
370     shape.erase(shape.begin() + dims_before, shape.begin() + end);
371     shape.insert(shape.begin() + dims_before, replacement_shape.begin(), replacement_shape.end());
372     return shape;
373   }
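  // Hypothetical example (shapes not from the original code):
  //   src: Tensor[B, 5, 6], indices_list = [undefined, Tensor[2, 3], undefined]
  //   dims_before = 1, dims_indexed = 1, replacement_shape = [2, 3]
  //   -> returns [B, 2, 3, 6]: dim 1 of src is replaced by the index's shape.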
374 
375   // Code is mostly duplicated from
376   // https://github.com/pytorch/pytorch/blob/fb0e27d38a8fdab4e1c14d6378c9e41cb30fd6a3
377   // /aten/src/ATen/native/TensorAdvancedIndexing.cpp#L379-L405
378   VmapDimVector get_indexed_shape(Tensor self, const torch::List<std::optional<at::Tensor>> &orig)
379   {
380     at::native::checkIndexTensorTypes(orig, /*allow_int*/ true);
381     // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
382     auto indices = at::native::expandTensors(self, orig);
383     // next broadcast all index tensors together
384     try {
385       indices = at::expand_outplace(indices);
386     } catch (std::exception &e) {
387       TORCH_CHECK_INDEX(false, "shape mismatch: indexing tensors could not be broadcast together"
388                                " with shapes ");
389     }
390     // add missing null Tensors so that it matches self.dim()
391     while (indices.size() < static_cast<size_t>(self.dim())) {
392       indices.emplace_back();
393     }
394     // if the non-null indices are not all adjacent, transpose self and indices
395     // together so that they're adjacent at the front
396     if (!at::native::hasContiguousSubspace(indices)) {
397       std::tie(self, indices) = at::native::transposeToFront(self, indices);
398     }
399     return compute_indexed_shape(self, indices);
400   }
401 
402   std::tuple<Tensor, std::vector<std::optional<Tensor>>, Tensor>
403   index_put_batch_rule_helper(const Tensor &self,
404                               std::optional<int64_t> self_bdim,
405                               ArrayRef<std::optional<Tensor>> indices,
406                               ArrayRef<std::optional<int64_t>> indices_bdims,
407                               const Tensor &values,
408                               std::optional<int64_t> values_bdim,
409                               std::optional<int64_t> opt_batch_size = {}) {
410 
411     Tensor self_ = moveBatchDimToFront(self, self_bdim);
412     Tensor values_ = moveBatchDimToFront(values, values_bdim);
413     // for the in-place variants `index_put_` and `_index_put_impl_` we find the batch_size
414     // here, while for `index_put` it is computed outside of this function.
415     const auto batch_size = opt_batch_size ? opt_batch_size.value() : self_.size(0);
416     self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
417     values_ = ensure_has_bdim(values_, values_bdim.has_value(), batch_size);
418     TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());
419 
420     // we've already made sure that self has bdim at 0.
421     const auto indices_ = batchIndices(indices, indices_bdims, batch_size, /*self_bdim=*/0, values_bdim);
422 
423     auto indexed_shape = get_indexed_shape(self_, List<std::optional<Tensor>>(indices_));
424 
425     // handle broadcasting support for values
426     // Eg. Given `indexed_shape.size()` is 5 and
427     // shape of `values` is (N, 2, 3), then following block
428     // will reshape `values` to (N, 1, 1, 2, 3).
429     if ( (int64_t) indexed_shape.size() > values_.dim()) {
430       auto values_sizes = values_.sizes();
431 
432       // number of unit dims (for broadcasting value to indexed_shape)
433       auto n_unit_dims = indexed_shape.size() - values_sizes.size();
434       VmapDimVector new_values_shape(values_sizes.size() + n_unit_dims);
435 
436       // add the batch-dim
437       new_values_shape[0] = batch_size;
438 
439       // insert the unit dims for broadcasting.
440       for (const auto idx : c10::irange(n_unit_dims)) {
441         // since the batch-dim is already filled.
442         new_values_shape[idx + 1] = 1;
443       }
444       for (const auto idx: c10::irange(1, values_sizes.size())) {
445         // since the batch and unit dims are already filled.
446         new_values_shape[idx + n_unit_dims] = values_sizes[idx];
447       }
448       values_ = values_.view(new_values_shape);
449     }
450 
451     return std::make_tuple(self_, indices_, values_);
452   }
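  // Hypothetical example of what this helper returns:
  //   self:    Tensor[B, 5, 6]   (self_bdim = 0)
  //   indices: [Tensor[3]]       (indices_bdims = [nullopt])
  //   values:  Tensor[3, 6]      (values_bdim = nullopt)
  // batchIndices prepends a None, so indices_ = [None, Tensor[3]];
  // get_indexed_shape gives [B, 3, 6] and values_ is expanded to Tensor[B, 3, 6].
  // No extra unit dims are inserted because the ranks already match.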
453 
454   auto unpackSelfAndIndicesAndValuesAtCurrentLevel(const Tensor &self,
455                                                    const List<std::optional<Tensor>> &indices,
456                                                    const Tensor &values, int64_t cur_level)
457   {
458     auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level);
459     std::vector<std::optional<Tensor>> indices_value;
460     std::vector<std::optional<int64_t>> indices_bdims;
461     for (const auto &&indRef : indices)
462     {
463       std::optional<Tensor> ind = indRef;
464       std::optional<Tensor> index;
465       std::optional<int64_t> index_bdim;
466       if (ind.has_value()) {
467         std::tie(index, index_bdim) = unwrapTensorAtLevel(ind.value(), cur_level);
468       }
469       indices_value.push_back(index);
470       indices_bdims.push_back(index_bdim);
471     }
472     auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level);
473     return std::make_tuple(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim);
474   }
475 
476 }  // namespace
477 
478 void index_put__batch_rule(
479     const Tensor& self,
480     std::optional<int64_t> self_bdim,
481     ArrayRef<std::optional<Tensor>> indices,
482     ArrayRef<std::optional<int64_t>> indices_bdims,
483     const Tensor& values,
484     std::optional<int64_t> values_bdim,
485     bool accumulate) {
486   if (!self_bdim.has_value()) {
487     vmapIncompatibleInplaceError("index_put_");
488   }
489   auto [self_, indices_, values_] = index_put_batch_rule_helper(
490       self, self_bdim, indices, indices_bdims, values, values_bdim);
491   at::index_put_(self_, List<std::optional<Tensor>>(indices_), values_, accumulate);
492 }
493 
494 // plumbing done since we don't support List<std::optional<Tensor>> in codegen
495 Tensor& index_put__plumbing(Tensor & self, const List<std::optional<Tensor>> & indices
496 , const Tensor & values, bool accumulate) {
497   c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
498   auto maybe_layer = maybeCurrentDynamicLayer();
499   vmap_check_escaped(maybe_layer, "index_put__plumbing");
500   int64_t cur_level = maybe_layer->layerId();
501 
502   // on device mismatch, we can move 0d tensors to self device
503   auto values_ = values;
504   if (values.device() != self.device() && values.numel() == 1 && values.dim() == 0) {
505     values_ = values.to(self.device());
506   }
507 
508   if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values_, cur_level)) {
509     return self.index_put_(indices, values_, accumulate);
510   }
511   auto [self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim] =
512       unpackSelfAndIndicesAndValuesAtCurrentLevel(self, indices, values_, cur_level);
513   index_put__batch_rule(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim, accumulate);
514   return self;
515 }
516 
517 void _index_put_impl__batch_rule(
518     const Tensor& self,
519     std::optional<int64_t> self_bdim,
520     ArrayRef<std::optional<Tensor>> indices,
521     ArrayRef<std::optional<int64_t>> indices_bdims,
522     const Tensor& values,
523     std::optional<int64_t> values_bdim,
524     bool accumulate,
525     bool unsafe) {
526   if (!self_bdim.has_value()) {
527     vmapIncompatibleInplaceError("_index_put_impl_");
528   }
529   auto [self_, indices_, values_] = index_put_batch_rule_helper(
530       self, self_bdim, indices, indices_bdims, values, values_bdim);
531   at::_index_put_impl_(self_, List<std::optional<Tensor>>(indices_), values_, accumulate, unsafe);
532 }
533 
534 // plumbing done since we don't support List<std::optional<Tensor>> in codegen
535 Tensor &_index_put_impl__plumbing(Tensor &self, const List<std::optional<Tensor>> &indices,
536                                   const Tensor &values, bool accumulate, bool unsafe) {
537   c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
538   auto maybe_layer = maybeCurrentDynamicLayer();
539   vmap_check_escaped(maybe_layer, "_index_put_impl__plumbing");
540   int64_t cur_level = maybe_layer->layerId();
541   if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) {
542     return at::_index_put_impl_(self, indices, values, accumulate, unsafe);
543   }
544   auto [self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim] =
545       unpackSelfAndIndicesAndValuesAtCurrentLevel(self, indices, values, cur_level);
546   _index_put_impl__batch_rule(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim, accumulate, unsafe);
547   return self;
548 }
549 
550 static Tensor maybe_permute_values(
551     const Tensor& values,
552     ArrayRef<std::optional<Tensor>> orig_indices,
553     ArrayRef<std::optional<int64_t>> orig_indices_bdims) {
554   bool indices_batched = any_has_value(orig_indices_bdims);
555   bool advanced_indices_are_adjacent = are_advanced_indices_adjacent(orig_indices);
556   auto num_leading_nones = get_num_leading_nones(orig_indices);
557   auto max_index_dim = get_max_index_logical_dim(orig_indices, orig_indices_bdims);
558 
559   // NB: values has its B dimension at the front
560   if (!indices_batched) {
561     if (advanced_indices_are_adjacent) {
562       // self: Tensor[B, 5, 6, 7, 8]
563       // indices: [:, Tensor[2, 2], Tensor[2, 2], :]
564       // batched_indices: [:, :, Tensor[2, 2], Tensor[2, 2], :]
565       // required values: Tensor[B, 5, 2, 2, 8]
566       return values;
567     }
568     // self: Tensor[B, 5, 6, 7]
569     // indices: [Tensor[2, 2], :, Tensor[2, 2]]
570     // batched_indices: [:, Tensor[2, 2], :, Tensor[2, 2]]
571     // required values: Tensor[2, 2, B, 6]
572     return values.movedim(0, max_index_dim);
573   }
574   if (!advanced_indices_are_adjacent) {
575     // self: Tensor[B, 5, 6, 7, 8]
576     // indices: [:, Tensor[B, 2, 2], :, Tensor[2, 2]]
577     // batched_indices: [arange(B).expand(B, 2, 2), :, Tensor[B, 2, 2], :, Tensor[2, 2]]
578     // required values: Tensor[B, 2, 2, 5, 7]
579     return values;
580   }
581   // In other words, in batched_indices, advanced indices are adjacent
582   if (num_leading_nones == 0) {
583     // self: Tensor[B, 5, 6, 7, 8]
584     // indices: [Tensor[B, 2, 2], Tensor[2, 2], :, :]
585     // batched_indices: [arange(B).expand(B, 2, 2), Tensor[B, 2, 2], Tensor[2, 2], :, :]
586     // required values: Tensor[B, 2, 2, 7, 8]
587     return values;
588   }
589   // This is the tricky case. In indices, advanced indices are adjacent.
590   // In batched_indices, advanced indices are no longer adjacent
591   //
592   // self: Tensor[B, 5, 6, 7, 8, 9]
593   // indices: [:, :, Tensor[B, 2, 3], Tensor[2, 3], :]
594   // batched_indices: [arange(B).expand(B, 2, 3), :, :, Tensor[B, 2, 3], Tensor[2, 3], :]
595   // required values: Tensor[B, 2, 3, 5, 6, 9]
596   // actual values: Tensor[B, 5, 6, 2, 3, 9]
597   //
598   // The resolution is to move dims around until we get the right shape.
599   // The values is set up as [B, <leading_nones>, <maxIndexDim>, ...]
600   // we just have to move the <maxIndexDim> to before the <leading_nones> to produce
601   // [B, <maxIndexDim>, <leading_nones>, ...]
602   return swap_regions(values, num_leading_nones, max_index_dim);
603 }
604 
605 std::tuple<Tensor, std::optional<int64_t>> index_put_batch_rule(
606     const Tensor& self,
607     std::optional<int64_t> self_bdim,
608     ArrayRef<std::optional<Tensor>> indices,
609     ArrayRef<std::optional<int64_t>> indices_bdims,
610     const Tensor& values,
611     std::optional<int64_t> values_bdim,
612     bool accumulate) {
613   TORCH_INTERNAL_ASSERT(indices.size() == indices_bdims.size());
614 
615   // find the batch_size
616   int64_t batch_size = 0;
617   if (self_bdim || values_bdim) {
618     batch_size = get_bdim_size2(self, self_bdim, values, values_bdim);
619   } else {
620     // one or more of the indices is batched.
621     for (size_t i = 0; i < indices.size(); i++) {
622       if (indices_bdims[i] && indices[i].has_value()) {
623         batch_size = indices[i].value().size(*indices_bdims[i]);
624         break;
625       }
626     }
627   }
628 
629   auto [self_, indices_, values_] = index_put_batch_rule_helper(
630       self, self_bdim, indices, indices_bdims, values, values_bdim, batch_size);
631 
632   // Why do we need to permute values?
633   // See NOTE [Advanced indexing (index.Tensor) batch rule] for details,
634   // but the gist is that index_put effectively does the following:
635   // - result = self_.clone()
636   // - result[indices_] = values
637   // - return result
638   // Now, the problem is, result[indices_] might return a Tensor whose shape is
639   // the shape of values, but permuted. This is because the shape of result[indices_]
640   // depends on if the original indices "have adjacent advanced indices"
641   // and the batched `indices_` might change the "have adjacent advanced indices" property
642   values_ = maybe_permute_values(values_, indices, indices_bdims);
643 
644   auto result = at::index_put(self_, List<std::optional<Tensor>>(indices_), values_, accumulate);
645   return std::make_tuple(result, 0);
646 }
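// A user-level sketch of what this rule computes (hypothetical example):
//   x = torch.randn(B, 5, 6)
//   idx = torch.randint(5, (B, 3))
//   v = torch.randn(B, 3, 6)
//   out = torch.vmap(lambda xi, ii, vi: xi.index_put((ii,), vi))(x, idx, v)
// Each out[b] equals x[b] with rows idx[b] replaced by v[b]; the batch rule
// realizes this with a single at::index_put after prepending arange(B) to the
// indices and (in this case) leaving values_ unpermuted.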
647 
648 // plumbing done since we don't support List<std::optional<Tensor>> in codegen
649 Tensor index_put_plumbing(const Tensor & self, const List<std::optional<Tensor>> & indices,
650                           const Tensor & values, bool accumulate) {
651   c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
652   auto maybe_layer = maybeCurrentDynamicLayer();
653   vmap_check_escaped(maybe_layer, "index_put_plumbing");
654   int64_t cur_level = maybe_layer->layerId();
655 
656   // on device mismatch, we can move 0d tensors to self device
657   auto values_ = values;
658   if (values.device() != self.device() && values.numel() == 1 && values.dim() == 0) {
659     values_ = values.to(self.device());
660   }
661 
662   if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values_, cur_level)) {
663     return self.index_put(indices, values_, accumulate);
664   }
665   auto [self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim] =
666       unpackSelfAndIndicesAndValuesAtCurrentLevel(self, indices, values_, cur_level);
667   auto results = index_put_batch_rule(self_value, self_bdim, indices_value, indices_bdims, values_value, values_bdim, accumulate);
668   return makeBatched(std::get<0>(results), std::get<1>(results), cur_level);
669 }
670 
671 namespace {
672 
673 template<typename Func, typename ...Args>
674 std::tuple<Tensor, std::optional<int64_t>> scatter_batch_rule(
675     Func f,
676     const Tensor& self, std::optional<int64_t> self_bdim,
677     int64_t dim,
678     const Tensor& index, std::optional<int64_t> index_bdim,
679     const Scalar& value, Args... args) {
680   auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
681   auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
682   auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);
683 
684   auto self_ = moveBatchDimToFront(self, self_bdim);
685   auto index_ = moveBatchDimToFront(index, index_bdim);
686 
687   if (self_logical_rank == 0) {
688     self_ = self_.unsqueeze(-1);
689   }
690   if (index_logical_rank == 0) {
691     index_ = index_.unsqueeze(-1);
692   }
693   self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
694   index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
695   auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);
696 
697   auto result = f(self_, physical_dim, index_, value, args...);
698   // result should have same shape as self
699   if (self_logical_rank == 0) {
700     result = result.squeeze(-1);
701   }
702   return std::make_tuple(result, 0);
703 }
704 
705 template <typename Func, typename ...Args>
706 inline std::tuple<Tensor, std::optional<int64_t>> scatter_batch_rule(
707     Func f,
708     const Tensor& self, std::optional<int64_t> self_bdim,
709     int64_t dim,
710     const Tensor& index, std::optional<int64_t> index_bdim,
711     const Tensor& src, std::optional<int64_t> src_bdim, Args... args) {
712   auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
713   auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
714   auto src_logical_rank = rankWithoutBatchDim(src, src_bdim);
715   auto batch_size = get_bdim_size3(self, self_bdim, index, index_bdim, src, src_bdim);
716 
717   auto self_ = moveBatchDimToFront(self, self_bdim);
718   auto index_ = moveBatchDimToFront(index, index_bdim);
719   auto src_ = moveBatchDimToFront(src, src_bdim);
720 
721   if (self_logical_rank == 0) {
722     self_ = self_.unsqueeze(-1);
723   }
724   if (index_logical_rank == 0) {
725     index_ = index_.unsqueeze(-1);
726   }
727   if (src_logical_rank == 0) {
728     src_ = src_.unsqueeze(-1);
729   }
730   self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
731   index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
732   src_ = ensure_has_bdim(src_, src_bdim.has_value(), batch_size);
733   auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);
734 
735   auto result = f(self_, physical_dim, index_, src_, args...);
736   // result should have same shape as self
737   if (self_logical_rank == 0) {
738     result = result.squeeze(-1);
739   }
740   return std::make_tuple(result, 0);
741 }
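// Hypothetical shapes for the Tensor-src overload above:
//   self: Tensor[B, 5, 6] (self_bdim = 0), index, src: Tensor[B, 3, 6] (bdims = 0), dim = 0
// After the bdims are moved to the front, physical_dim becomes 1 and the call is
//   f(self_, /*dim=*/1, index_, src_, args...)
// which scatters along the original dim 0 of every batch element in one shot.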
742 
743 } // namespace
744 
745 std::tuple<Tensor, std::optional<int64_t>> scatter_value_batch_rule(
746     const Tensor& self, std::optional<int64_t> self_bdim,
747     int64_t dim,
748     const Tensor& index, std::optional<int64_t> index_bdim,
749     const Scalar& value) {
750   return scatter_batch_rule(ATEN_FN2(scatter, value),
751                             self, self_bdim, dim, index, index_bdim, value);
752 }
753 
754 std::tuple<Tensor, std::optional<int64_t>> scatter_src_batch_rule(
755     const Tensor& self, std::optional<int64_t> self_bdim,
756     int64_t dim,
757     const Tensor& index, std::optional<int64_t> index_bdim,
758     const Tensor& src, std::optional<int64_t> src_bdim) {
759   return scatter_batch_rule(ATEN_FN2(scatter, src),
760                             self, self_bdim, dim, index, index_bdim, src, src_bdim);
761 }
762 
763 std::tuple<Tensor, std::optional<int64_t>> scatter_add_batch_rule(
764     const Tensor& self, std::optional<int64_t> self_bdim,
765     int64_t dim,
766     const Tensor& index, std::optional<int64_t> index_bdim,
767     const Tensor& src, std::optional<int64_t> src_bdim) {
768   return scatter_batch_rule(ATEN_FN(scatter_add),
769                             self, self_bdim, dim, index, index_bdim, src, src_bdim);
770 }
771 
772 std::tuple<Tensor, std::optional<int64_t>> scatter_reduce_batch_rule(
773     const Tensor& self, std::optional<int64_t> self_bdim,
774     int64_t dim,
775     const Tensor& index, std::optional<int64_t> index_bdim,
776     const Tensor& src, std::optional<int64_t> src_bdim,
777     const c10::string_view reduce) {
778   return scatter_batch_rule(ATEN_FN2(scatter, reduce),
779                             self, self_bdim, dim, index, index_bdim, src, src_bdim, reduce);
780 }
781 
782 std::tuple<Tensor, std::optional<int64_t>> scatter_value_reduce_batch_rule(
783     const Tensor& self, std::optional<int64_t> self_bdim,
784     int64_t dim,
785     const Tensor& index, std::optional<int64_t> index_bdim,
786     const Scalar& src,
787     const c10::string_view reduce) {
788   return scatter_batch_rule(ATEN_FN2(scatter, value_reduce),
789                             self, self_bdim, dim, index, index_bdim, src, reduce);
790 }
791 
792 std::tuple<Tensor, std::optional<int64_t>> gather_batch_rule(
793     const Tensor& self, std::optional<int64_t> self_bdim,
794     int64_t dim,
795     const Tensor& index, std::optional<int64_t> index_bdim,
796     bool sparse_grad) {
797   auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
798   auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
799   auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);
800 
801   auto self_ = moveBatchDimToFront(self, self_bdim);
802   auto index_ = moveBatchDimToFront(index, index_bdim);
803 
804   if (self_logical_rank == 0) {
805     self_ = self_.unsqueeze(-1);
806   }
807   if (index_logical_rank == 0) {
808     index_ = index_.unsqueeze(-1);
809   }
810   self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
811   index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
812   auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);
813 
814   auto result = at::gather(self_, physical_dim, index_, sparse_grad);
815   // result should have same rank as index
816   if (index_logical_rank == 0) {
817     result = result.squeeze(-1);
818   }
819   return std::make_tuple(result, 0);
820 }
821 
822 Tensor get_expanded_index(const Tensor& index, IntArrayRef self_size, int64_t dim) {
823   if (index.dim() == 0) {
824     return index.expand(self_size);
825   }
826   dim = maybe_wrap_dim(dim, static_cast<int64_t>(self_size.size()));
827 
828   // setup new_index_shape as [BS, 1, ..., idx_size, ..., 1]
829   // to reshape index_
830   auto idx_size = index.size(0);  // get non-batch size of index tensor
831   Tensor index_;
832   {
833     VmapDimVector new_index_shape(self_size.size(), 1);
834     new_index_shape[dim] = idx_size;
835     index_ = index.view(new_index_shape);
836   }
837   // Now apply expand to index_
838   {
839     VmapDimVector new_index_shape = {self_size.begin(), self_size.end()};
840     new_index_shape[dim] = idx_size;
841     index_ = index_.expand(new_index_shape);
842   }
843   return index_;
844 }
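// Hypothetical example: index = tensor([2, 0]), self_size = [4, 3], dim = 0
//   -> index.view([2, 1]).expand([2, 3]) == tensor([[2, 2, 2], [0, 0, 0]]),
// i.e. every column carries the same row indices, which is what gather needs
// in the decompositions below to emulate index_select / index_copy along dim 0.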
845 
846 Tensor index_select_decomp(const Tensor &self, int64_t dim, const Tensor &index)
847 {
848   Tensor index_ = index;
849   if (self.dim() > index.dim()) {
850     index_ = get_expanded_index(index, self.sizes(), dim);
851   }
852 
853   auto result = at::gather(self, dim, index_);
854 
855   // output of gather has same dimension as `index` while
856   // output of index_select has same dimension as self
857   // Eg. t = torch.tensor(1)
858   //     idx = torch.tensor([0])
859   //     torch.index_select(t, 0, idx) # 0-D
860   //     torch.gather(t, 0, idx) # 1-D
861   if (self.dim() == 0 && result.dim() != 0) {
862     result = result.squeeze(-1);
863   }
864 
865   return result;
866 }
867 
868 Tensor index_copy_decomp(
869     const Tensor &self, int64_t dim,
870     const Tensor &index, const Tensor &source)
871 {
872   Tensor index_ = index;
873   if (self.dim() > index.dim()) {
874     index_ = get_expanded_index(index, self.sizes(), dim);
875   }
876 
877   return at::scatter(self, dim, index_, source);
878 }
879 
880 // Note [Fix vmap slice_scatter]
881 // registers a decomposition for `slice_scatter` that calls into `slice.src`
882 // *_scatter operators have some special semantics, though, that we can't easily
883 // express through a decomposition: slice_scatter's output needs to have the same
884 // size, strides and storage_offset as the input.
885 Tensor slice_scatter_decomp(const Tensor &self, const Tensor &src,
886                             int64_t dim, std::optional<int64_t> start,
887                             std::optional<int64_t> end, int64_t step)
888 {
889   auto idx = at::arange(start.value_or(0), end.value_or(self.size(dim)), step, self.options().dtype(kLong));
890   idx = get_expanded_index(idx, self.sizes(), dim);
891   return at::scatter(self, dim, idx, src);
892 }
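// Hypothetical example: self: Tensor[6, 4], src: Tensor[2, 4], dim = 0,
// start = 1, end = 5, step = 2:
//   idx = arange(1, 5, 2) = tensor([1, 3])
//   get_expanded_index(idx, [6, 4], 0) -> tensor([[1, 1, 1, 1], [3, 3, 3, 3]])
//   at::scatter(self, 0, idx, src) then writes src into rows 1 and 3 of self.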
893 
894 Tensor select_scatter_decomp(
895     const Tensor &self, const Tensor &source,
896     int64_t dim, int64_t index)
897 {
898   // supports negative index
899   index = maybe_wrap_dim(index, self.size(dim));
900   auto index_ = at::scalar_tensor(index, self.options().dtype(kLong));
901 
902   return at::scatter(self, dim, index_.expand_as(self), source.unsqueeze(dim).expand_as(self));
903 }
904 
905 std::tuple<Tensor, std::optional<int64_t>> diagonal_scatter_batch_rule(
906     const Tensor &self, std::optional<int64_t> self_bdim,
907     const Tensor &src, std::optional<int64_t> src_bdim,
908     int64_t offset, int64_t dim1, int64_t dim2)
909 {
910   auto self_ = moveBatchDimToFront(self, self_bdim);
911   auto src_ = moveBatchDimToFront(src, src_bdim);
912 
913   auto batch_size = get_bdim_size2(self, self_bdim, src, src_bdim);
914 
915   self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
916   src_ = ensure_has_bdim(src_, src_bdim.has_value(), batch_size);
917 
918   auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
919   dim1 = maybe_wrap_dim(dim1, self_logical_rank) + 1;
920   dim2 = maybe_wrap_dim(dim2, self_logical_rank) + 1;
921 
922   return std::make_tuple(at::diagonal_scatter(self_, src_, offset, dim1, dim2), 0);
923 }
924 
925 std::tuple<Tensor, std::optional<int64_t>> index_add_batch_rule_impl(
926     Tensor& self, std::optional<int64_t> self_bdim,
927     int64_t dim,
928     const Tensor& index, std::optional<int64_t> index_bdim,
929     const Tensor& other, std::optional<int64_t> other_bdim,
930     const Scalar& alpha,
931     const bool inplace) {
932 
933   if (inplace && !self_bdim.has_value()){
934     vmapIncompatibleInplaceError("index_add_");
935   }
936 
937   if (!index_bdim) {
938     // Handle scalar tensors... self, other can be scalar tensors
939     const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
940     const auto other_logical_rank = rankWithoutBatchDim(other, other_bdim);
941     auto self_ = moveBatchDimToFront(self, self_bdim);
942     if (self_logical_rank == 0) {
943       self_ = self_.unsqueeze(-1);
944     }
945     auto other_ = moveBatchDimToFront(other, other_bdim);
946     if (other_logical_rank == 0) {
947       other_ = other_.unsqueeze(-1);
948     }
949     dim = maybe_wrap_dim(dim, self_logical_rank);
950 
951     const auto batch_size = get_bdim_size2(self, self_bdim, other, other_bdim);
952     self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
953     other_ = ensure_has_bdim(other_, other_bdim.has_value(), batch_size);
954 
955     if (inplace) {
956       self_.index_add_(dim + 1, index, other_, alpha);
957       if (self_logical_rank == 0) {
958         self_ = self_.squeeze(-1);
959       }
960       return std::make_tuple(self, 0);
961     }
962 
963     auto result = self_.index_add(dim + 1, index, other_, alpha);
964     if (self_logical_rank == 0) {
965       result = result.squeeze(-1);
966     }
967     return std::make_tuple(result, 0);
968   }
969 
970   // Index is batched. A for-loop plus stack is the best thing I can come up with
971   // right now. We really want a generalized index_add kernel in PyTorch.
972   auto batch_size = get_bdim_size3(self, self_bdim, other, other_bdim, index, index_bdim);
973   std::vector<Tensor> results;
974   if (!inplace) {
975     results.reserve(batch_size);
976   }
977   for (const auto i : c10::irange(0, batch_size)) {
978     const auto& self_slice = self_bdim.has_value() ?
979       self.select(*self_bdim, i) : self;
980     const auto& other_slice = other_bdim.has_value() ?
981       other.select(*other_bdim, i) : other;
982     const auto& index_slice = index_bdim.has_value() ?
983       index.select(*index_bdim, i) : index;
984 
985     if (inplace) {
986       self_slice.index_add_(dim, index_slice, other_slice, alpha);
987     } else {
988       results.push_back(at::index_add(self_slice, dim, index_slice, other_slice, alpha));
989     }
990   }
991   if (inplace) {
992     return std::make_tuple(at::stack(self), 0);
993   }
994   return std::make_tuple(at::stack(results), 0);
995 }
996 
997 void index_add__batch_rule(
998     Tensor& self, std::optional<int64_t> self_bdim,
999     int64_t dim,
1000     const Tensor& index, std::optional<int64_t> index_bdim,
1001     const Tensor& other, std::optional<int64_t> other_bdim,
1002     const Scalar& alpha) {
1003   index_add_batch_rule_impl(self, self_bdim, dim, index, index_bdim, other,
1004                             other_bdim, alpha, true);
1005 }
1006 
1007 std::tuple<Tensor, std::optional<int64_t>> index_add_batch_rule(
1008     Tensor& self, std::optional<int64_t> self_bdim,
1009     int64_t dim,
1010     const Tensor& index, std::optional<int64_t> index_bdim,
1011     const Tensor& other, std::optional<int64_t> other_bdim,
1012     const Scalar& alpha) {
1013   auto self_ = self.clone(at::MemoryFormat::Preserve);
1014   return index_add_batch_rule_impl(self_, self_bdim, dim, index, index_bdim,
1015                                    other, other_bdim, alpha, false);
1016 }
1017 
1018 static std::tuple<Tensor,Tensor> binary_pointwise_align(
1019     const Tensor & self,
1020     std::optional<int64_t> self_bdim,
1021     const Tensor & mask,
1022     std::optional<int64_t> mask_bdim) {
1023   // compute max logical rank
1024   auto tensor_logical_rank = rankWithoutBatchDim(self, self_bdim);
1025   auto other_logical_rank = rankWithoutBatchDim(mask, mask_bdim);
1026   auto max_logical_rank = std::max(tensor_logical_rank, other_logical_rank);
1027 
1028   auto tensor_ = moveBatchDimToFront(self, self_bdim);
1029   auto other_ = moveBatchDimToFront(mask, mask_bdim);
1030 
1031   // If the dimensions aren't aligned, we need to line them up.
1032   // Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3]
1033   // Note that only tensors that have a batch dim need to be modified.
1034   // Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed
1035   tensor_ = maybePadToLogicalRank(tensor_, self_bdim, max_logical_rank);
1036   other_ = maybePadToLogicalRank(other_, mask_bdim, max_logical_rank);
1037 
1038   return std::make_tuple(tensor_, other_);
1039 }
1040 
1041 std::tuple<Tensor, std::optional<int64_t>> masked_fill_scalar_batch_rule(
1042     const Tensor & self,
1043     std::optional<int64_t> self_bdim,
1044     const Tensor & mask,
1045     std::optional<int64_t> mask_bdim,
1046     const Scalar& source) {
1047   auto tensors = binary_pointwise_align(self, self_bdim, mask, mask_bdim);
1048   auto result = at::masked_fill(std::get<0>(tensors), std::get<1>(tensors), source);
1049   return std::make_tuple(result, 0);
1050 }
1051 
1052 std::tuple<Tensor, std::optional<int64_t>> index_fill_batch_rule_helper(
1053   int64_t batch_size,
1054   int64_t self_logical_rank,
1055   int64_t index_logical_rank,
1056   Tensor & self_,
1057   int64_t dim,
1058   Tensor & index_,
1059   const Scalar & value
1060   ){
1061   if (self_logical_rank != 0){
1062     auto index_offset = at::arange(
1063       batch_size,
1064       at::TensorOptions().dtype(index_.scalar_type()).device(index_.device())
1065     );
1066     if (index_logical_rank == 0){
1067       index_ = index_.unsqueeze(-1);
1068     }
1069     index_ = index_.add(index_offset.unsqueeze(-1), self_.size(dim + 1));
1070     index_ = reshape_dim_into(0, 0, index_);
1071     self_ = reshape_dim_into(0, dim, self_);
1072     self_.index_fill_(dim, index_, value);
1073     self_ = reshape_dim_outof(dim, batch_size, self_);
1074     return std::make_tuple(self_, dim);
1075   }
1076 
1077   // If self_logical_rank == 0, the batch dim is certainly 0, and we must apply batched indices to each row.
1078   if (index_logical_rank != 0){
1079     index_ = reshape_dim_into(0, 0, index_);
1080   }
1081   self_.unsqueeze_(-1);
1082   self_.index_fill_(dim + 1, index_, value);
1083   self_.squeeze_(-1);
1084 
1085   return std::make_tuple(self_, 0);
1086 }
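// Hypothetical example for the rank > 0 path above:
//   B = 2, self_: Tensor[B, 4, 6], dim = 0 (logical), index_: Tensor[B, 2]
//   index_offset = arange(B) = [0, 1] and self_.size(dim + 1) = 4, so
//   index_.add(index_offset.unsqueeze(-1), 4) shifts batch b's indices by 4 * b.
//   After reshape_dim_into, a single index_fill_ on the flattened Tensor[B * 4, 6]
//   then only touches rows belonging to the corresponding batch element.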
1087 
1088 std::tuple<Tensor, std::optional<int64_t>> index_fill_int_scalar_batch_rule_impl(
1089     Tensor & self, std::optional<int64_t> self_bdim,
1090     int64_t dim,
1091     const Tensor & index, std::optional<int64_t> index_bdim,
1092     const Scalar & value,
1093     const bool inplace) {
1094   const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
1095   const auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
1096   Tensor self_ = moveBatchDimToFront(self, self_bdim);
1097   Tensor index_ = moveBatchDimToFront(index, index_bdim);
1098   dim = maybe_wrap_dim(dim, self_logical_rank);
1099 
1100   if (inplace && !self_bdim.has_value()) {
1101     vmapIncompatibleInplaceError("index_fill_");
1102   }
1103 
1104   if (!index_bdim) {
1105     if (self_logical_rank == 0){
1106       self_.unsqueeze_(-1);
1107     }
1108     self_.index_fill_(dim + 1, index_, value);
1109     if (self_logical_rank == 0) {
1110       self_.squeeze_(-1);
1111     }
1112     return std::make_tuple(self_, 0);
1113   }
1114 
1115   auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);
1116   self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
1117   index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
1118 
1119   if (inplace) {
1120     // Use a for-loop for the in-place case because we cannot reshape
1121     // `self_` when it has incompatible strides without copying.
1122     for (const auto i : c10::irange(0, batch_size)) {
1123       const auto& self_slice = self_.select(0, i);
1124       const auto& index_slice = index_.select(0, i);
1125       self_slice.index_fill_(
1126         dim,
1127         index_slice,
1128         value
1129       );
1130     }
1131     return std::make_tuple(self_, 0);
1132   }
1133 
1134   self_ = self_bdim.has_value() ? self_ : self_.clone();
1135 
1136   return index_fill_batch_rule_helper(batch_size, self_logical_rank, index_logical_rank, self_, dim, index_, value);
1137 }
1138 
1139 std::tuple<Tensor, std::optional<int64_t>> index_fill_int_tensor_batch_rule_impl(
1140     Tensor & self, std::optional<int64_t> self_bdim,
1141     int64_t dim,
1142     const Tensor & index, std::optional<int64_t> index_bdim,
1143     const Tensor & value, std::optional<int64_t> value_bdim,
1144     const bool inplace) {
1145   const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
1146   const auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
1147   Tensor self_ = moveBatchDimToFront(self, self_bdim);
1148   Tensor index_ = moveBatchDimToFront(index, index_bdim);
1149   Tensor value_ = moveBatchDimToFront(value, value_bdim);
1150   dim = maybe_wrap_dim(dim, self_logical_rank);
1151 
1152   if (inplace && !self_bdim.has_value()) {
1153     vmapIncompatibleInplaceError("index_fill_");
1154   }
1155 
1156   if (!index_bdim && !value_bdim) {
1157     if (self_logical_rank == 0){
1158       self_.unsqueeze_(-1);
1159     }
1160     self_.index_fill_(dim + 1, index_, value);
1161     if (self_logical_rank == 0) {
1162       self_.squeeze_(-1);
1163     }
1164     return std::make_tuple(self_, 0);
1165   }
1166 
1167   auto batch_size = get_bdim_size3(self, self_bdim, index, index_bdim, value, value_bdim);
1168   self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
1169   index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
1170 
1171   if (inplace || value_bdim.has_value()) {
1172     // Use a for-loop for the in-place case because we cannot reshape
1173     // `self_` when it has incompatible strides without copying.
1174     // If value has a batch dim, we use a for-loop as well because
1175     // index_fill_ only supports a 1-element value tensor.
1176     for (const auto i : c10::irange(0, batch_size)) {
1177       const auto& self_slice = self_.select(0, i);
1178       const auto& index_slice = index_.select(0, i);
1179       self_slice.index_fill_(
1180         dim,
1181         index_slice,
1182         value_bdim.has_value() ? value_.select(0, i) : value_
1183       );
1184     }
1185     return std::make_tuple(self_, 0);
1186   }
1187 
1188   self_ = self_bdim.has_value() ? self_ : self_.clone();
1189 
1190   // calling .item() on value is safe here because value is guaranteed to not be a batched tensor.
1191   return index_fill_batch_rule_helper(batch_size, self_logical_rank, index_logical_rank, self_, dim, index_, value.item());
1192 }
1193 
1194 void index_fill__int_scalar_batch_rule(
1195     Tensor & self, std::optional<int64_t> self_bdim,
1196     int64_t dim,
1197     const Tensor & index, std::optional<int64_t> index_bdim,
1198     const Scalar & value) {
1199   index_fill_int_scalar_batch_rule_impl(self, self_bdim, dim, index, index_bdim, value, true);
1200 }
1201 
1202 void index_fill__int_tensor_batch_rule(
1203     Tensor & self, std::optional<int64_t> self_bdim,
1204     int64_t dim,
1205     const Tensor & index, std::optional<int64_t> index_bdim,
1206     const Tensor & value, std::optional<int64_t> value_bdim) {
1207   index_fill_int_tensor_batch_rule_impl(self, self_bdim, dim, index, index_bdim, value, value_bdim, true);
1208 }
1209 
1210 std::tuple<Tensor, std::optional<int64_t>> index_fill_int_scalar_batch_rule(
1211     const Tensor & self, std::optional<int64_t> self_bdim,
1212     int64_t dim,
1213     const Tensor & index, std::optional<int64_t> index_bdim,
1214     const Scalar & value) {
1215   auto self_ = self.clone(at::MemoryFormat::Preserve);
1216   return index_fill_int_scalar_batch_rule_impl(self_, self_bdim, dim, index, index_bdim, value, false);
1217 }
1218 
1219 std::tuple<Tensor, std::optional<int64_t>> index_fill_int_tensor_batch_rule(
1220     const Tensor & self, std::optional<int64_t> self_bdim,
1221     int64_t dim,
1222     const Tensor & index, std::optional<int64_t> index_bdim,
1223     const Tensor & value, std::optional<int64_t> value_bdim) {
1224   auto self_ = self.clone(at::MemoryFormat::Preserve);
1225   return index_fill_int_tensor_batch_rule_impl(self_, self_bdim, dim, index, index_bdim, value, value_bdim, false);
1226 }
1227 
1228 }
1229 
1230 TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
1231   m.impl("index.Tensor", index_plumbing);
1232   m.impl("index_put_", index_put__plumbing);
1233   m.impl("index_put", index_put_plumbing);
1234   m.impl("_index_put_impl_", _index_put_impl__plumbing);
1235   m.impl("slice_scatter", slice_scatter_decomp);
1236   m.impl("select_scatter", select_scatter_decomp);
1237   m.impl("index_copy", index_copy_decomp);
1238   m.impl("index_select", index_select_decomp);
1239   VMAP_SUPPORT2(masked_fill, Scalar, masked_fill_scalar_batch_rule);
1240   VMAP_SUPPORT2(index_fill_, int_Tensor, index_fill__int_tensor_batch_rule);
1241   VMAP_SUPPORT2(index_fill_, int_Scalar, index_fill__int_scalar_batch_rule);
1242   VMAP_SUPPORT2(index_fill, int_Tensor, index_fill_int_tensor_batch_rule);
1243   VMAP_SUPPORT2(index_fill, int_Scalar, index_fill_int_scalar_batch_rule);
1244   VMAP_SUPPORT(index_add_, index_add__batch_rule);
1245   VMAP_SUPPORT(index_add, index_add_batch_rule);
1246   VMAP_SUPPORT(diagonal_scatter, diagonal_scatter_batch_rule);
1247   VMAP_SUPPORT(gather, gather_batch_rule);
1248   VMAP_SUPPORT2(scatter, value, scatter_value_batch_rule);
1249   VMAP_SUPPORT2(scatter, src, scatter_src_batch_rule);
1250   VMAP_SUPPORT(scatter_add, scatter_add_batch_rule);
1251   VMAP_SUPPORT2(scatter, reduce, scatter_reduce_batch_rule);
1252   VMAP_SUPPORT2(scatter, value_reduce, scatter_value_reduce_batch_rule);
1253   // as_strided_scatter does not work with the for-loop fallback today,
1254   // because as_strided_scatter will return an output that matches
1255   // the strides/storage_offset of its input.
1256   // With the for loop fallback, each input tensor is a slice into
1257   // the larger batched tensor.
1258   m.impl("as_strided_scatter", torch::CppFunction::makeFromBoxedFunction<&vmapErrorFallback>());
1259 }
1260 
1261 } // namespace at::functorch
1262