xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/TensorAdvancedIndexing.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Indexing tensors by tensors
2 //
3 // This corresponds to "advanced indexing" in NumPy. The two operations are:
4 //
5 //  index(Tensor self, indices) -> Tensor
6 //  index_put_(Tensor self, indices, value, accumulate=false)
7 //
8 // The index is a TensorList containing kLong, kBool or kByte tensors or nulls. Byte
9 // tensors (boolean masks) are expanded to long tensors via nonzero(). Null
10 // tensors signify that the dimension is not indexed.
11 //
12 // All indexes are broadcast together and iterated as *one*. From NumPy:
13 //
14 // result[i_1, ..., i_M] == x[ind_1[i_1, ..., i_M], ind_2[i_1, ..., i_M],
15 //                           ..., ind_N[i_1, ..., i_M]]
16 //
17 // Note 1: ByteTensors expand to index as many dimensions as there are in the
18 // mask.
19 //
20 // Note 2: The behavior is more complicated when the index tensors are not all
21 // adjacent (e.g. x[[0, 1], :, [2, 3]]). In this case, self and the index
22 // tensors are transposed to the front: x.transpose(1, 2)[[0, 1], [2, 3]]
23 //
24 // The code contains two implementations of indexing. The more efficient
25 // implementation treats indexing like an elementwise operation over the
26 // tensors `result`, `x`, `ind_1`, `ind_2`, etc. This implementation does
27 // not work for index_put_ with accumulate=True. The other implementation
28 // combines the indexed tensors into a single linear index that is used
29 // with Tensor.put_. This is used for index_put_ with accumulate=True.
30 //
31 // The more efficient implementation takes the following steps for the
32 // above operation:
33 //
34 // 1) Broadcast ind_1, ind_2, ind_3 together to a common shape
35 // 2) Record x.stride(i) for each indexed dimension `i`
36 // 3) Replace the indexed subspace of `x` with the shape of the corresponding
37 //    subspace of `result` but with stride 0
38 // 4) Add dimensions of size 1 to the index tensors (ind_1, ind_2, etc.) so
39 //    that their shape is compatible with the result shape
40 //
41 // The CPU or CUDA kernel then computes element-wise over the broadcasted
42 // and restrided result, x, ind_1,  ind_2, etc.:
43 //
44 //   result[...] = *(&x[...] +
45 //                   ind_1[...] * x.stride(1) +
46 //                   ind_2[...] * x.stride(2) +
47 //                   ...)
48 //
49 // where & and * represent the C-style address-of and indirection operations.
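//
// Worked example (illustrative, not taken from the original source): let `x` be
// a contiguous tensor of shape (5, 6, 7) indexed as x[:, ind_1, ind_2], where
// ind_1 and ind_2 broadcast to shape (2, 3). Following the steps above:
//
//   1) ind_1 and ind_2 are broadcast to the common shape (2, 3)
//   2) the indexed strides x.stride(1) = 7 and x.stride(2) = 1 are recorded
//   3) x is restrided to shape (5, 2, 3) with strides (42, 0, 0)
//   4) ind_1 and ind_2 are reshaped to (1, 2, 3)
//
// so the kernel computes, with strides shown in elements for readability,
//   result[i, a, b] = *(&x[i, a, b] + ind_1[a, b] * 7 + ind_2[a, b] * 1)
//                   = x[i, ind_1[a, b], ind_2[a, b]]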
50 // #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
51 #include <ATen/ATen.h>
52 
53 #include <ATen/native/TensorAdvancedIndexing.h>
54 #include <ATen/native/IndexKernel.h>
55 #include <ATen/native/IndexingUtils.h>
56 
57 #include <ATen/core/Tensor.h>
58 #include <ATen/core/IListRef.h>
59 #include <ATen/Context.h>
60 #include <ATen/Dispatch.h>
61 #include <ATen/ExpandUtils.h>
62 #include <ATen/MemoryOverlap.h>
63 #include <ATen/NamedTensorUtils.h>
64 #include <ATen/Parallel.h>
65 #include <ATen/TensorIterator.h>
66 #include <ATen/TensorMeta.h>
67 #include <ATen/TensorOperators.h>
68 #include <ATen/TensorUtils.h>
69 #include <ATen/WrapDimUtils.h>
70 #include <ATen/native/BinaryOps.h>
71 #include <ATen/native/Copy.h>
72 #include <ATen/native/Resize.h>
73 #include <ATen/native/ScatterGatherChecks.h>
74 #include <ATen/native/TensorAdvancedIndexingUtils.h>
75 #include <ATen/Parallel.h>
76 #include <ATen/NumericUtils.h>
77 #include <ATen/TensorSubclassLikeUtils.h>
78 
79 #ifndef AT_PER_OPERATOR_HEADERS
80 #include <ATen/Functions.h>
81 #include <ATen/NativeFunctions.h>
82 #else
83 #include <ATen/ops/_gather_sparse_backward.h>
84 #include <ATen/ops/_gather_sparse_backward_native.h>
85 #include <ATen/ops/_index_put_impl.h>
86 #include <ATen/ops/_index_put_impl_native.h>
87 #include <ATen/ops/_sparse_coo_tensor_unsafe.h>
88 #include <ATen/ops/_unsafe_index_native.h>
89 #include <ATen/ops/_unsafe_index_put_native.h>
90 #include <ATen/ops/arange.h>
91 #include <ATen/ops/argwhere_native.h>
92 #include <ATen/ops/as_strided.h>
93 #include <ATen/ops/broadcast_to.h>
94 #include <ATen/ops/count_nonzero.h>
95 #include <ATen/ops/count_nonzero_native.h>
96 #include <ATen/ops/empty.h>
97 #include <ATen/ops/empty_quantized.h>
98 #include <ATen/ops/gather.h>
99 #include <ATen/ops/gather_backward_native.h>
100 #include <ATen/ops/gather_meta.h>
101 #include <ATen/ops/gather_native.h>
102 #include <ATen/ops/index.h>
103 #include <ATen/ops/index_add_meta.h>
104 #include <ATen/ops/index_add_native.h>
105 #include <ATen/ops/index_copy_meta.h>
106 #include <ATen/ops/index_copy_native.h>
107 #include <ATen/ops/index_fill_native.h>
108 #include <ATen/ops/index_meta.h>
109 #include <ATen/ops/index_native.h>
110 #include <ATen/ops/index_put_native.h>
111 #include <ATen/ops/index_reduce_meta.h>
112 #include <ATen/ops/index_reduce_native.h>
113 #include <ATen/ops/index_select_backward_native.h>
114 #include <ATen/ops/index_select_native.h>
115 #include <ATen/ops/masked_fill_native.h>
116 #include <ATen/ops/masked_scatter_native.h>
117 #include <ATen/ops/masked_select_backward_native.h>
118 #include <ATen/ops/masked_select_native.h>
119 #include <ATen/ops/nested_to_padded_tensor_native.h>
120 #include <ATen/ops/nonzero_native.h>
121 #include <ATen/ops/nonzero_numpy_native.h>
122 #include <ATen/ops/nonzero_static_native.h>
123 #include <ATen/ops/ones_like.h>
124 #include <ATen/ops/put_native.h>
125 #include <ATen/ops/quantize_per_tensor.h>
126 #include <ATen/ops/scatter_add_meta.h>
127 #include <ATen/ops/scatter_add_native.h>
128 #include <ATen/ops/scatter_meta.h>
129 #include <ATen/ops/scatter_native.h>
130 #include <ATen/ops/scatter_reduce_meta.h>
131 #include <ATen/ops/scatter_reduce_native.h>
132 #include <ATen/ops/take_along_dim_native.h>
133 #include <ATen/ops/take_native.h>
134 #include <ATen/ops/zeros_like.h>
135 #endif
136 
137 #ifdef USE_FBGEMM
138 #include <fbgemm/Utils.h>
139 #endif
140 
141 #include <c10/util/irange.h>
142 #include <c10/util/Unroll.h>
143 
144 #include <algorithm>
145 #include <numeric>
146 #include <utility>
147 #include <vector>
148 
149 namespace at::native {
150 
151 std::string shapes_as_str(TensorList tensors);
152 AdvancedIndex make_info(Tensor self, IOptTensorListRef orig);
153 
154 } // namespace at::native
155 
156 namespace at::meta {
157 
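// Reminder of gather semantics (illustrative, 3-D case with dim == 0):
//   result[i][j][k] = self[index[i][j][k]][j][k]
// so the output always has the same shape as `index`, which is why the meta
// function below sets the output to index.sizes().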
158 TORCH_META_FUNC(gather)
159 (const Tensor & self, int64_t dim, const Tensor & index, bool sparse_grad) {
160   const Tensor& result = maybe_get_output(0);
161   int64_t wrapped_dim = at::maybe_wrap_dim(dim, self.dim());
162 
163   // Memory overlap checks need to be done after resizing (if required) is done.
164   // But it only makes sense to do these checks when result was defined, hence
165   // the boolean variable `check_result` here.
166   // For more details, see: https://github.com/pytorch/pytorch/pull/63312#discussion_r694794832
167   // and https://github.com/pytorch/pytorch/issues/63837
168   bool check_result = result.defined();
169   set_output_raw_strided(0, index.sizes(), {}, self.options());
170   if (check_result) {
171     at::assert_no_internal_overlap(result);
172     at::assert_no_overlap(result, self);
173     at::assert_no_partial_overlap(result, index);
174   }
175 
176   auto is_index_empty = index.numel() == 0;
177   if (!is_index_empty) {
178     TORCH_CHECK(
179       index.scalar_type() == at::ScalarType::Long,
180       "gather", "(): Expected dtype int64 for index"
181     );
182   }
183   if (is_index_empty) return;
184   at::native::gather_shape_check(self, wrapped_dim, index);
185 }
186 
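// Reminder of scatter semantics (illustrative, 3-D case with dim == 0):
//   self[index[i][j][k]][j][k] = src[i][j][k]
// The shared meta function below only validates dtypes, shapes, overlap and the
// optional reduce argument; the output keeps self's shape.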
187 template <bool use_new_options = false, typename Meta>
188 void scatter_meta_impl(
189     Meta& meta,
190     const Tensor& self,
191     int64_t dim,
192     const Tensor& index,
193     const std::optional<Tensor>& src = std::nullopt,
194     const std::optional<c10::string_view> reduce = std::nullopt) {
195   int64_t wrapped_dim = at::maybe_wrap_dim(dim, self.dim());
196   at::native::scatter_gather_dtype_check("scatter", self, index, src);
197   at::native::scatter_shape_check(self, wrapped_dim, index, src);
198   auto output = meta.maybe_get_output(0);
199 
200   if (output.defined()) {
201     at::assert_no_internal_overlap(output);
202     at::assert_no_overlap(output, index);
203     if (src.has_value()) {
204       at::assert_no_overlap(output, src.value());
205     }
206   }
207 
208   meta.set_output_raw_strided(0, self.sizes(), {}, self.options());
209   if (reduce.has_value()) {
210     // Check if we have a valid reduce operator.
211     at::native::get_operator_enum(reduce.value(), use_new_options);
212   }
213 }
214 
215 TORCH_META_FUNC2(scatter, src)
216 (const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
217   scatter_meta_impl(*this, self, dim, index, src);
218 }
219 
220 TORCH_META_FUNC2(scatter, value)
221 (const Tensor& self, int64_t dim, const Tensor& index, const Scalar& value) {
222   scatter_meta_impl(*this, self, dim, index);
223 }
224 
225 TORCH_META_FUNC2(scatter, reduce)
226 (const Tensor& self,
227  int64_t dim,
228  const Tensor& index,
229  const Tensor& src,
230  const c10::string_view reduce) {
231   TORCH_WARN_ONCE(
232       "The reduce argument of torch.scatter with Tensor src is deprecated and will be removed ",
233       "in a future PyTorch release. Use torch.scatter_reduce instead for more reduction options."
234   );
235   scatter_meta_impl(*this, self, dim, index, src, reduce);
236 }
237 
238 TORCH_META_FUNC2(scatter, value_reduce)
239 (const Tensor& self,
240  int64_t dim,
241  const Tensor& index,
242  const Scalar& src,
243  const c10::string_view reduce) {
244   scatter_meta_impl(*this, self, dim, index, std::nullopt, reduce);
245 }
246 
247 TORCH_META_FUNC(scatter_add)
248 (const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
249   scatter_meta_impl(*this, self, dim, index, src, "add");
250 }
251 
252 TORCH_META_FUNC2(scatter_reduce, two)
253 (const Tensor& self,
254  int64_t dim,
255  const Tensor& index,
256  const Tensor& src,
257  const c10::string_view reduce,
258  bool include_self) {
259   (void) include_self;
260   scatter_meta_impl</*use_new_options=*/true>(*this, self, dim, index, src, reduce);
261 }
262 
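// Reminder of index_copy_ semantics (illustrative, dim == 0):
//   self[index[i]] = source[i]
// i.e. the i-th slice of `source` is copied into the index[i]-th slice of `self`.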
263 TORCH_PRECOMPUTE_META_FUNC(index_copy)
264 (const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source) {
265   dim = maybe_wrap_dim(dim, self.dim());
266 
267   const Tensor& result = maybe_get_output(0);
268 
269   // Memory overlap checks need to be done after resizing (if required) is done.
270   // But it only makes sense to do these checks when result was defined, hence
271   // the boolean variable `check_result` here.
272   // For more details, see: https://github.com/pytorch/pytorch/pull/63312#discussion_r694794832
273   // and https://github.com/pytorch/pytorch/issues/63837
274   bool check_result = result.defined();
275   set_output_raw_strided(0, self.sizes(), {}, self.options());
276   if (check_result) {
277     at::assert_no_internal_overlap(result);
278     at::assert_no_overlap(result, index);
279     at::assert_no_overlap(result, source);
280   }
281 
282   TORCH_CHECK_INDEX(index.dim() < 2, "index_copy_(): Index should have dimension 1 or 0 (got ", index.dim(), ")");
283 
284   int64_t numIndices = index.numel();
285   if (source.dim() == 0 && numIndices != 1) {
286     TORCH_CHECK_INDEX(false, "index_copy_(): When source is scalar, index should have one element (got ", numIndices, ")");
287   } else if ((source.dim() != self.dim()) && (source.dim() != 0 && self.dim() != 0)) {
288     TORCH_CHECK_INDEX(false, "index_copy_(): When source and destination are not scalars, their dimensionality must match. Source dimensionality (",
289                    source.dim(), "), destination dimensionality (", self.dim(), ")");
290   }
291 
292   TORCH_CHECK(index.scalar_type() == ScalarType::Long, "index_copy_(): Expected a long tensor for index, but got ", index.scalar_type());
293   TORCH_CHECK(self.scalar_type() == source.scalar_type(), "index_copy_(): self and source expected to have the same dtype, but got (self) ", self.scalar_type(), " and (source) ", source.scalar_type());
294   TORCH_CHECK(self.device() == source.device() && self.device() == index.device(),
295       "index_copy_(): self, index and source expected to be in the same device, but got (self) ",
296       self.device(), ", (index) ", index.device(), ", and (source) ", source.device());
297 
298   // Check that source and destination slices have the same size
299   auto selfSlicedSizes = self.sizes().vec();
300   if (!selfSlicedSizes.empty()) {
301     selfSlicedSizes.erase(selfSlicedSizes.begin() + dim);
302   }
303   auto sourceSlicedSizes = source.sizes().vec();
304   if (!sourceSlicedSizes.empty()) {
305     sourceSlicedSizes.erase(sourceSlicedSizes.begin() + dim);
306   }
307   if (selfSlicedSizes.size() != sourceSlicedSizes.size() ||
308       !std::equal(selfSlicedSizes.begin(), selfSlicedSizes.end(),
309                   sourceSlicedSizes.begin())) {
310     std::stringstream ss;
311     ss << "index_copy_(): Source/destination tensor must have same slice shapes. ";
312     ss << "Destination slice shape: " << selfSlicedSizes << " at dimension " << dim;
313     ss << " and source slice shape: " << sourceSlicedSizes << " at dimension 0.";
314     TORCH_CHECK(false, ss.str());
315   }
316   TORCH_CHECK_INDEX(source.dim() == 0 || numIndices == source.size(dim),
317           "index_copy_(): Number of indices (", numIndices, ") should be equal to source.size(dim) (", source.size(dim), ")");
318 
319   return TORCH_PRECOMPUTE_STRUCT(index_copy)().set_dim(dim);
320 }
321 
322 template <typename Meta>
323 void index_func_meta_impl(
324   Meta& meta,
325   const Tensor& self,
326   int64_t dim,
327   const Tensor& index,
328   const Tensor& source,
329   c10::string_view func) {
330   auto numel = index.numel();
331 
332   TORCH_CHECK_INDEX(index.dim() <= 1, func, "_(): Index is supposed to be a vector, but got dim: ",
333                     index.dim(), " with type: ", index.scalar_type(), " and size: ", index.sizes());
334   TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int,
335               func, "_(): Expected dtype int32/int64 for index but got: ", index.scalar_type());
336   TORCH_CHECK(self.scalar_type() == source.scalar_type(),
337               func, "_(): self (", self.scalar_type(), ") and source (", source.scalar_type(),
338               ") must have the same scalar type");
339   TORCH_CHECK(dim == 0 || dim < source.dim(),
340               func, "_(): Indexing dim ", dim, " is out of bounds of the source tensor with dim ",
341               source.dim());
342   TORCH_CHECK(numel == (source.dim() == 0 ? 1 : source.size(dim)),
343               func, "_(): Number of indices (", numel, ") should be equal to source.size(dim): (",
344               source.size(dim), "), for dim: ", dim);
345 
346   auto self_sizes = self.sizes().vec();
347   auto source_sizes = source.sizes().vec();
348   if (source.dim() != 0 && self.dim() != 0) {
349     self_sizes.erase(self_sizes.begin() + dim);
350     source_sizes.erase(source_sizes.begin() + dim);
351   }
352   TORCH_CHECK(
353       self_sizes == source_sizes,
354       "source tensor shape must match self tensor shape, excluding the specified dimension. Got self.shape = ",
355       self.sizes(),
356       " source.shape = ",
357       source.sizes());
358 
359   auto& result = meta.maybe_get_output(0);
360   bool is_defined = result.defined();
361   meta.set_output_raw_strided(0, self.sizes(), {}, self.options());
362   if (is_defined) {
363     at::assert_no_internal_overlap(result);
364     at::assert_no_overlap(result, index);
365     at::assert_no_overlap(result, source);
366   }
367 
368   // A hack to run TensorIterator checks in the meta function.
369   // See comment: https://github.com/pytorch/pytorch/pull/65993#discussion_r760307417
370   // TODO: (@krshrimali) Try inheriting from TensorIteratorBase instead.
371   if (result.device() == kMeta && result.dim() > 0) {
372     auto selfSlice = result.select(dim, 0);
373     auto sourceSlice = source.select(dim, 0);
374     auto iter = TensorIterator::borrowing_binary_op(selfSlice, selfSlice, sourceSlice);
375   }
376 }
377 
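// Reminder of index_add_ semantics (illustrative, dim == 0):
//   self[index[i]] += alpha * source[i]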
378 TORCH_PRECOMPUTE_META_FUNC(index_add)
379 (const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source, const Scalar& alpha) {
380   dim = maybe_wrap_dim(dim, self.dim());
381   index_func_meta_impl(*this, self, dim, index, source, "index_add");
382   return TORCH_PRECOMPUTE_STRUCT(index_add)().set_dim(dim);
383 }
384 
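// Reminder of index_reduce_ semantics (illustrative, dim == 0, reduce == "prod"):
//   self[index[i]] *= source[i]
// with "mean", "amax" and "amin" applying the corresponding reduction instead.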
385 TORCH_PRECOMPUTE_META_FUNC(index_reduce)
386 (const Tensor& self,
387  int64_t dim,
388  const Tensor& index,
389  const Tensor& source,
390  const c10::string_view reduce,
391  bool include_self) {
392   (void)include_self;
393   TORCH_CHECK(reduce == "prod" || reduce == "mean" || reduce == "amax" || reduce == "amin",
394               "index_reduce(): Expected reduce to be one of prod, mean, amax or amin but got ", reduce, ".");
395   dim = maybe_wrap_dim(dim, self.dim());
396   index_func_meta_impl(*this, self, dim, index, source, "index_reduce");
397   return TORCH_PRECOMPUTE_STRUCT(index_reduce)().set_dim(dim);
398 }
399 
400 static void build_index_op(
401     TensorIteratorBase& iter,
402     const at::native::AdvancedIndex& info,
403     const Tensor& result) {
404   // 'TensorIterator' needs to own the things coming from 'info', since
405   // 'info' will be destroyed after the META function.
406   TensorIteratorConfig config;
407   // info.src is a restrided view of result
408   config.set_check_mem_overlap(false)
409       .check_all_same_dtype(false)
410       .add_output(result)
411       .add_owned_const_input(info.src);
412   for (auto& index : info.indices) {
413     config.add_owned_const_input(index);
414   }
415   if (!result.defined()) {
416     config.declare_static_dtype_and_device(info.src.scalar_type(), info.src.device());
417   }
418   iter.build(config);
419 }
420 
421 static void check_indices_on_cpu_or_selfdevice(
422     const Tensor& self,
423     const at::MaterializedIOptTensorListRef& indices) {
424   auto dev = self.device();
425   bool indices_on_cpu_or_dev = std::all_of(
426       indices.begin(), indices.end(), [=](const at::OptionalTensorRef& opt) {
427         return opt.has_value() ? (opt->is_cpu() || opt->device() == dev) : true;
428       });
429   TORCH_CHECK(
430       indices_on_cpu_or_dev,
431       "indices should be either on ", kCPU,
432       " or on the same device as the indexed tensor (", dev, ")");
433 }
434 
435 TORCH_PRECOMPUTE_META_FUNC2(index, Tensor)
436 (const Tensor& self, at::IOptTensorListRef indices) {
437   auto materialized = indices.materialize();
438 
439   TORCH_CHECK_INDEX(
440       materialized.size() <= (size_t)self.dim(),
441       "too many indices for tensor of dimension ",
442       self.dim(), " (got ", materialized.size(), ")");
443 
444   // Only allow: `dev_tensor[{cpu,dev}_tensor]`.
445   // See: https://github.com/pytorch/pytorch/pull/69607
446   check_indices_on_cpu_or_selfdevice(self, materialized);
447 
448   const auto& result = maybe_get_output();
449 
450   if (result.defined()) {
451     TORCH_CHECK(self.scalar_type() == result.scalar_type(),
452                 "index_out: self (", self.scalar_type(), ") and result (", result.scalar_type(),
453                 ") must have the same scalar type");
454     at::assert_no_internal_overlap(result);
455     at::assert_no_overlap(result, self);
456     for (const at::OptionalTensorRef& index : materialized) {
457       if (index.has_value()) {
458         at::assert_no_overlap(result, *index);
459       }
460     }
461   }
462 
463   auto info = at::native::make_info(self, std::move(indices));
464   build_index_op(*this, info, result);
465   return TORCH_PRECOMPUTE_STRUCT2(index, Tensor)()
466       .set_sizes(std::move(info.indexed_sizes))
467       .set_strides(std::move(info.indexed_strides));
468 }
469 
470 } // namespace at::meta
471 
472 namespace at::native {
473 
474 DEFINE_DISPATCH(index_stub);
475 DEFINE_DISPATCH(index_fill_stub);
476 DEFINE_DISPATCH(index_copy_stub);
477 DEFINE_DISPATCH(index_put_stub);
478 DEFINE_DISPATCH(index_put_with_sort_stub);
479 DEFINE_DISPATCH(put_stub);
480 DEFINE_DISPATCH(take_stub);
481 DEFINE_DISPATCH(masked_fill_stub);
482 REGISTER_NO_CPU_DISPATCH(index_put_with_sort_stub);
483 REGISTER_NO_CPU_DISPATCH(index_put_with_sort_quantized_stub);
484 DEFINE_DISPATCH(masked_select_serial_stub);
485 DEFINE_DISPATCH(masked_select_stub);
486 DEFINE_DISPATCH(masked_scatter_stub);
487 
488 DEFINE_DISPATCH(gather_stub);
489 DEFINE_DISPATCH(scatter_stub);
490 DEFINE_DISPATCH(scatter_fill_stub);
491 DEFINE_DISPATCH(scatter_add_stub);
492 DEFINE_DISPATCH(scatter_reduce_stub);
493 DEFINE_DISPATCH(scatter_scalar_reduce_stub);
494 DEFINE_DISPATCH(scatter_reduce_two_stub);
495 
496 DEFINE_DISPATCH(scatter_add_expanded_index_stub);
497 DEFINE_DISPATCH(scatter_reduce_expanded_index_stub);
498 DEFINE_DISPATCH(gather_expanded_index_stub);
499 
500 static bool all_strides_match(TensorList tensors) {
501   TORCH_CHECK(!tensors.empty());
502   auto strides = tensors[0].strides();
503   for (auto& tensor : tensors.slice(1)) {
504     if (!strides.equals(tensor.strides())) {
505       return false;
506     }
507   }
508   return true;
509 }
510 
511 inline std::string shapes_as_str(TensorList tensors) {
512   std::ostringstream os;
513   bool first = true;
514   for (auto& tensor : tensors) {
515     if (tensor.defined()) {
516       if (!first) {
517         os << ", ";
518       }
519       os << tensor.sizes();
520       first = false;
521     }
522   }
523   return os.str();
524 }
525 
526 // Replace indexed dimensions in src with stride 0 and the size of the result tensor.
527 // The offset in these dimensions is computed by the kernel using the index tensor's
528 // values and the stride of src. The new shape is not meaningful. It's used to make
529 // the shape compatible with the result tensor.
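// For example (illustrative): src of shape (5, 6, 7, 8) with strides
// (336, 56, 8, 1), dims_before = 1, dims_indexed = 2 and replacement_shape
// (2, 3) yields shape (5, 2, 3, 8) with strides (336, 0, 0, 1).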
530 static Tensor restride_src(const Tensor& src, int64_t dims_before, int64_t dims_indexed,
531                            IntArrayRef replacement_shape) {
532   auto shape = DimVector(src.sizes());
533   auto strides = DimVector(src.strides());
534   int64_t end = dims_before + dims_indexed;
535   shape.erase(shape.begin() + dims_before, shape.begin() + end);
536   strides.erase(strides.begin() + dims_before, strides.begin() + end);
537   shape.insert(shape.begin() + dims_before, replacement_shape.begin(), replacement_shape.end());
538   strides.insert(strides.begin() + dims_before, replacement_shape.size(), 0);
539   return src.as_strided(shape, strides);
540 }
541 
542 // Add dimensions of size 1 to an index tensor so that it can be broadcast to the result
543 // shape and iterated over element-wise like the result tensor and the restrided src.
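// For example (illustrative): an index of shape (2, 3) with dims_before = 1 and
// dims_after = 1 is reshaped to (1, 2, 3, 1).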
544 static Tensor reshape_indexer(const Tensor& index, int64_t dims_before, int64_t dims_after) {
545   auto orig_shape = index.sizes();
546   auto shape = DimVector();
547   shape.append(dims_before, 1);
548   shape.append(orig_shape.begin(), orig_shape.end());
549   shape.append(dims_after, 1);
550   return index.reshape(shape);
551 }
552 
553 AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list)
554 {
555   int64_t element_size_bytes = src.element_size();
556   int64_t dims_before = 0, dims_after = 0, dims_indexed = 0;
557   IntArrayRef replacement_shape;
558   for (const auto dim : c10::irange(indices_list.size())) {
559     if (!indices_list[dim].defined()) {
560       if (dims_indexed == 0) {
561         dims_before++;
562       } else {
563         dims_after++;
564       }
565     } else {
566       dims_indexed++;
567       replacement_shape = indices_list[dim].sizes();
568       indexed_sizes.push_back(src.size(dim));
569       indexed_strides.push_back(src.stride(dim) * element_size_bytes);
570     }
571   }
572 
573   // Check if the indexed subspace contains a dim of size 0, but the replacement
574   // shape does not. This implies that an index is out of bounds, because there
575   // is no number that's a valid index for an empty tensor. Normally, out of
576   // bounds is handled in the indexing kernel, but this case fails earlier in
577   // restride_src with an unhelpful error message.
578   if (std::find(indexed_sizes.begin(), indexed_sizes.end(), 0) != indexed_sizes.end() &&
579       std::find(replacement_shape.begin(), replacement_shape.end(), 0) == replacement_shape.end()) {
580     TORCH_CHECK_INDEX(false, "index is out of bounds for dimension with size 0");
581   }
582 
583   this->dims_before = dims_before;
584   this->dims_after = dims_after;
585   this->src = restride_src(src, dims_before, dims_indexed, replacement_shape);
586 
587   for (auto& index : indices_list) {
588     if (index.defined()) {
589       indices.push_back(reshape_indexer(index, dims_before, dims_after));
590     }
591   }
592 
593   // For CUDA/MPS/XPU tensors, force all index tensors to have the same striding to
594   // simplify the CUDA/MPS/XPU kernel.
595   if (indices.size() >= 2 && (this->src.device().type() == kCUDA || this->src.device().type() == kMPS || this->src.device().type() == kXPU)) {
596     if (!all_strides_match(indices)) {
597       for (auto & indice : indices) {
598         indice = indice.contiguous();
599       }
600     }
601   }
602 }
603 
604 static TensorIterator make_index_put_iterator(const AdvancedIndex& info, const Tensor& value) {
605   TORCH_CHECK(is_expandable_to(value.sizes(), info.src.sizes()), "shape mismatch: value tensor of shape ", value.sizes(),
606              " cannot be broadcast to indexing result of shape ", info.src.sizes());
607   TORCH_CHECK(value.scalar_type() == info.src.scalar_type(),
608               "Index put requires the source and destination dtypes match, "
609               "got ", info.src.scalar_type(), " for the destination "
610               "and ", value.scalar_type(), " for the source.");
611   TensorIteratorConfig config;
612   // info.src is restrided by restride_src with 0 strided dimensions
613   config.set_check_mem_overlap(false);
614   config.resize_outputs(false);
615   config.check_all_same_dtype(false);
616   config.add_output(info.src);
617   config.add_const_input(value);
618   for (auto& index : info.indices) {
619     config.add_const_input(index);
620   }
621   return config.build();
622 }
623 
624 TORCH_IMPL_FUNC(index_out)
625 (const Tensor& self,
626  DimVector sizes,
627  DimVector strides,
628  const Tensor& result) {
629   index_stub(device_type(), *this, sizes, strides);
630 }
631 
632 Tensor quantized_index(const Tensor & self, const torch::List<std::optional<Tensor>>& indices) {
633   TORCH_INTERNAL_ASSERT(
634       self.qscheme() == c10::kPerTensorAffine ||
635       self.qscheme() == c10::kPerTensorSymmetric,
636       "Indexing is only supported for per-Tensor quantized Tensors.");
637 
638   // For now, this is a naive implementation which does dq -> index -> q.
639   // TODO(future PR): improve performance by removing the copies.
640   const auto& self_dq = self.dequantize();
641   auto result = at::index(self_dq, indices);
642   return at::quantize_per_tensor(
643       result, self.q_scale(), self.q_zero_point(), self.scalar_type());
644 }
645 
646 Tensor _unsafe_index(const Tensor& self, const torch::List<std::optional<Tensor>>& indices) {
647   // Disallow boolean indexing since it leads to dynamic output shapes
648   for (auto i : c10::irange(indices.size())) {
649     auto index = indices.get(i);
650     if (index.has_value()) {
651       auto dtype = index->scalar_type();
652       TORCH_CHECK(dtype == kLong || dtype == kInt,
653                   "_unsafe_index found unexpected index type ", dtype);
654     }
655   }
656   return at::index(self, indices);
657 }
658 
659 Tensor _unsafe_masked_index(const Tensor& self, const Tensor& mask, const torch::List<std::optional<Tensor>>& indices, const Scalar& fill) {
660   // Unsafe masked index is equivalent to
661   //   where(mask, self[indices], fill)
662   // with the main difference being that when `mask` is false, the tensor
663   // `self` is not indexed using `indices`. This allows `indices` to be out of bounds
664   // when `mask` is false. When `mask` is true, the `indices` are expected to be
665   // in bounds and are not checked.
666   //
667   // This function is not meant to be executed in eager mode. An unoptimized version
668   // is provided here.
669   //
670   // Compiler backends should implement this op such that `self[indices]` is only
671   // loaded where `mask` is true. See inductor for a reference.
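  //
  // Illustrative 1-D case (assumed shapes, a single index tensor):
  //   out[i] = mask[i] ? self[index[i]] : fill
  // where index[i] may be arbitrary (even out of bounds) whenever mask[i] is false.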
672   auto clamp = [](const std::optional<Tensor>& index, auto size) -> std::optional<Tensor> {
673     if (!index) {
674       return index;
675     }
676     // Disallow bool
677     auto dtype = index->scalar_type();
678     TORCH_CHECK(dtype == kLong || dtype == kInt,
679                 "_unsafe_masked_index found unexpected index type ", dtype);
680     return at::clamp(*index, -size, size - 1);
681   };
682 
683   torch::List<std::optional<Tensor>> clamped_indices(indices);
684   std::transform(indices.begin(), indices.end(), self.sizes().begin(), clamped_indices.begin(), clamp);
685 
686   if (self.numel() == 0) {
687       // Returns a tensor filled with `fill` value
688       // We use a hack here since we do not have a method to get the
689       // correct size of the tensor. (except with meta impl which is
690       // not available on mobile builds)
691       std::vector<int64_t> new_sizes(self.dim());
692       auto compute_new_size = [](const std::optional<Tensor>& index, auto size) -> int64_t {
693           if (index && size == 0) {
694               return 1;
695           } else {
696               return size;
697           }
698       };
699       std::transform(indices.begin(), indices.end(), self.sizes().begin(), new_sizes.begin(), compute_new_size);
700       auto result = self.new_full(new_sizes, fill);
701       return at::_unsafe_index(result, clamped_indices);
702   }
703 
704   auto result = at::_unsafe_index(self, clamped_indices);
705   return result.masked_fill(at::logical_not(mask), fill);
706 }
707 
708 Tensor _unsafe_masked_index_put_accumulate(const Tensor& self, const Tensor& mask, const torch::List<std::optional<Tensor>>& indices, const Tensor& values) {
709   // This is the backward of _unsafe_masked_index.
710   // This function is not meant to be executed in eager mode.
711 
712   if (self.numel() == 0) {
713     return self.clone();
714   }
715 
716   // We recompute the clamped indices and rely on inductor to CSE the computation
717   auto clamp = [](const std::optional<Tensor>& index, auto size) -> std::optional<Tensor> {
718     if (!index) {
719       return index;
720     }
721     // Disallow bool
722     auto dtype = index->scalar_type();
723     TORCH_CHECK(dtype == kLong || dtype == kInt,
724                 "_unsafe_masked_index found unexpected index type ", dtype);
725     return at::clamp(*index, -size, size - 1);
726   };
727 
728   torch::List<std::optional<Tensor>> clamped_indices(indices);
729   std::transform(indices.begin(), indices.end(), self.sizes().begin(), clamped_indices.begin(), clamp);
730 
731   auto masked_value = values.masked_fill(at::logical_not(mask), 0);
732   return at::_unsafe_index_put(self, clamped_indices, masked_value, true);
733 }
734 
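// put_ treats `self` as if it were a 1-D tensor; illustratively, iterating
// elementwise over `index` and `source`:
//   self.view(-1)[index[i]] = source[i]   // or += source[i] when accumulate == true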
735 Tensor & put_(Tensor & self, const Tensor& index, const Tensor & source, const bool accumulate) {
736   // See note [Writing Nondeterministic Operations]
737   // Nondeterministic when index contains duplicate entries and we do not accumulate
738   // If we accumulate on GPU, we use atomicGPUAdd, which is non-deterministic
739   if (!accumulate || (accumulate && self.device().type() == DeviceType::CUDA)) {
740     at::globalContext().alertNotDeterministic("put_");
741   }
742 
743   // Type and device checks
744   TORCH_CHECK(index.scalar_type() == ScalarType::Long, "put_(): Expected a long tensor for index, but got ", index.scalar_type())
745   TORCH_CHECK(self.scalar_type() == source.scalar_type(), "put_(): self and source expected to have the same dtype, but got self.dtype = ", self.scalar_type(), " and source.dtype = ", source.scalar_type());
746   TORCH_CHECK(self.device() == source.device() && self.device() == index.device(),
747       "put_(): self, index and source expected to be in the same device, but got self.device = ",
748       self.device(), ", index.device = ", index.device(), ", and source.device = ", source.device());
749 
750   // index checks
751   TORCH_CHECK_INDEX(source.numel() == index.numel(), "put_(): Expected source and index to have the same number of elements, but got source.numel() = ", source.numel(), ", index.numel() = ", index.numel());
752   TORCH_CHECK_INDEX(!(self.numel() == 0 && index.numel() != 0), "put_(): Tried to put elements into an empty tensor");
753 
754   at::assert_no_internal_overlap(self);
755   at::assert_no_overlap(self, index);
756   at::assert_no_overlap(self, source);
757 
758   // Early return
759   if (index.numel() == 0) {
760     return self;
761   }
762 
763   auto index_reshaped = index.reshape(source.sizes());
764   // Do not iterate over self, we will compute the offsets manually
765   auto iter = TensorIteratorConfig()
766     .set_check_mem_overlap(false)
767     .check_all_same_dtype(false)
768     .add_const_input(source)
769     .add_const_input(index_reshaped)
770     .build();
771 
772   put_stub(iter.device_type(), iter, self, accumulate);
773 
774   return self;
775 }
776 
777 Tensor put(const Tensor & self, const Tensor& index, const Tensor & source, const bool accumulate) {
778   return self.clone(at::MemoryFormat::Preserve).put_(index, source, accumulate);
779 }
780 
781 Tensor index_put(const Tensor & self, const torch::List<std::optional<Tensor>>& indices, const Tensor & value, bool accumulate) {
782   return self.clone(at::MemoryFormat::Preserve).index_put_(indices, value, accumulate);
783 }
784 
785 Tensor _unsafe_index_put(const Tensor& self, const torch::List<std::optional<Tensor>>& indices, const Tensor& value, bool accumulate) {
786   return at::index_put(self, indices, value, accumulate);
787 }
788 
789 Tensor & _index_put_impl_(Tensor & self, const torch::List<std::optional<Tensor>>& indices, const Tensor & value, const bool accumulate, const bool unsafe) {
790   TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
791   if (at::has_internal_overlap(self) == MemOverlap::Yes) {
792     TORCH_WARN(
793       "Use of index_put_ on expanded tensors is deprecated. "
794       "Please clone() the tensor before performing this operation. "
795       "This also applies to advanced indexing e.g. tensor[indices] = tensor");
796   }
797   if (!accumulate) {
798     auto masked_fill_dispatch = canDispatchToMaskedFill(self, indices, value);
799     if (std::get<0>(masked_fill_dispatch)) {
800       return self.masked_fill_(std::get<1>(masked_fill_dispatch), value.item());
801     }
802   }
803   auto value_ = value;
804   if (value.device() != self.device() && value.numel() == 1 && value.dim() == 0) {
805     value_ = value.to(self.device());
806   }
807   at::assert_no_overlap(self, value);
808   // NOLINTNEXTLINE(performance-implicit-conversion-in-loop)
809   for (const std::optional<Tensor>& index: indices) {
810     if (index.has_value()) {
811       at::assert_no_overlap(self, *index);
812     }
813   }
814   if ((self.device().type() == DeviceType::CUDA || self.device().type() == DeviceType::XPU) && (accumulate || globalContext().deterministicAlgorithms())) {
815       TORCH_CHECK(value_.device() == self.device(), "expected device ", self.device(), " but got device ",
816       value_.device(), " for value tensor");
817       index_put_with_sort_stub(self.device().type(), self, indices, value_, accumulate, unsafe);
818       return self;
819   }
820 
821   auto info = make_info(self, indices);
822   auto iter = make_index_put_iterator(info, value_);
823   index_put_stub(iter.device_type(), iter, info.indexed_sizes, info.indexed_strides, accumulate);
824   return self;
825 }
826 
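// take treats `self` as if it were a 1-D tensor (illustrative):
//   out[i] = self.view(-1)[index[i]]
// and `out` always has the same shape as `index`.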
827 Tensor& take_out(const Tensor& self, const Tensor& index, Tensor& out) {
828   // Type and device checks
829   TORCH_CHECK(index.scalar_type() == ScalarType::Long, "take(): Expected a long tensor for index, but got ", index.scalar_type())
830   TORCH_CHECK(self.scalar_type() == out.scalar_type(), "take(): self and out expected to have the same dtype, but got self.dtype = ", self.scalar_type(), " and out.dtype = ", out.scalar_type());
831   TORCH_CHECK(self.device() == out.device() && self.device() == index.device(),
832       "take(): self, index and out expected to be in the same device, but got self.device = ",
833       self.device(), ", index.device = ", index.device(), ", and out.device = ", out.device());
834 
835   // index checks
836   TORCH_CHECK_INDEX(!(self.numel() == 0 && index.numel() != 0), "take(): tried to take from an empty tensor");
837 
838   at::assert_no_internal_overlap(out);
839   at::assert_no_overlap(out, index);
840   at::assert_no_overlap(out, self);
841 
842   // Do not iterate over self, we will compute the offsets manually
843   // out is resized inside tensor_iterator
844   auto iter = TensorIteratorConfig()
845     .set_check_mem_overlap(false)
846     .check_all_same_dtype(false)
847     .add_output(out)
848     .add_const_input(index)
849     .build();
850 
851   // Early return after out has been resized
852   if (index.numel() == 0) {
853     return out;
854   }
855 
856   take_stub(iter.device_type(), iter, self);
857 
858   return out;
859 }
860 
861 Tensor take(const Tensor& self, const Tensor& index) {
862     auto out = at::empty(index.sizes(), self.options());
863     at::native::take_out(self, index, out);
864     return out;
865 }
866 
867 Tensor & index_put_(Tensor & self, const torch::List<std::optional<Tensor>>& indices, const Tensor & value, const bool accumulate) {
868   return at::_index_put_impl_(self, indices, value, accumulate, /*unsafe=*/false);
869 }
870 
871 TORCH_IMPL_FUNC(index_copy_out)
872 (const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source, const Tensor& result) {
873     if (!result.is_same(self)) result.copy_(self);
874 
875     // See Note [Enabling Deterministic Operations]
876     if (result.is_cuda() && globalContext().deterministicAlgorithms()){
877         torch::List<std::optional<Tensor>> indices;
878         indices.reserve(dim + 1);
879         for (const auto i: c10::irange(dim)) {
880           (void)i;
881           indices.emplace_back();
882         }
883         indices.emplace_back(index);
884         result.index_put_(indices, source, false);
885         return;
886     }
887 
888     // Handle the case when self / source is 0-dim
889     Tensor result_nonzero = result.dim() == 0 ? result.unsqueeze(0) : result;
890     Tensor source_nonzero = source.dim() == 0 ? source.unsqueeze(0) : source;
891 
892     // The only difference between the following tensor iterator and that of index_fill_ is that
893     // this one also has source as an input. We should refactor it once `if constexpr` (C++17) is available.
894 
895     // Prepare `index` for TensorIterator.
896     // It is restrided to be broadcastable over `self` in TensorIterator.
897     auto index_sizes = std::vector<int64_t>(result_nonzero.dim(), 1);
898     auto index_strides = std::vector<int64_t>(result_nonzero.dim(), 0);
899     index_sizes[dim] = index.numel();
900     index_strides[dim] = (index.dim() > 0) ? index.stride(0) : 1; // `index` is 1d or scalar
901     auto index_restrided = index.as_strided(
902       index_sizes, index_strides);
903 
904     // Prepare `result` for TensorIterator.
905     // Restride `result` to not advance in dimension `dim`.
906     // We do not use squash_dim here because `index` will
907     // need to advance in this dimension.
908     // Note that result_sizes[dim] is set to index.numel().
909     // This is done so that result_sizes[dim] and index_sizes[dim]
910     // match as required by TensorIterator (input shape should
911     // strictly broadcast over output shape, i.e.
912     // output.shape[i] >= input.shape[i] for i in range(dims)).
913     auto result_sizes = result_nonzero.sizes().vec();
914     auto result_strides = result_nonzero.strides().vec();
915     result_sizes[dim] = index.numel();
916     result_strides[dim] = 0;
917     auto result_restrided = result_nonzero.as_strided(result_sizes, result_strides);
918 
919     auto iter = TensorIteratorConfig()
920       // We do not check for overlap because `result` is restrided
921       // with zero stride. Zero strides trigger memory overlap assert
922       // within TensorIterator.
923       .set_check_mem_overlap(false)
924       .check_all_same_dtype(false)
925       .resize_outputs(false)
926       .add_output(result_restrided)
927       .add_const_input(index_restrided)
928       .add_const_input(source_nonzero)
929       .build();
930 
931     auto result_dim_size = result_nonzero.size(dim);
932     auto result_dim_stride = result_nonzero.stride(dim);
933     index_copy_stub(
934       iter.device_type(),
935       iter,
936       dim,
937       result_dim_size,
938       result_dim_stride);
939 }
940 
941 // Not calling into index_reduce_func_impl because of a different dtype dispatch
942 TORCH_IMPL_FUNC(index_add_cpu_out)
943 (const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source, const Scalar& alpha, const Tensor& result) {
944   if (!result.is_same(self)) {
945      result.copy_(self);
946   }
947   auto numel = index.numel();
948 
949   auto index_contig = index.contiguous();
950 
951   if (result.dim() > 1) {
952     // Equivalent to:
953     //   for (const auto i : c10::irange(numel)) {
954     //     auto selfSlice = self.select(dim, index_data[i]);
955     //     auto sourceSlice = source.select(dim, i);
956     //     selfSlice.add_(sourceSlice);
957     //   }
958     // But much faster as this reuses the iterator from add_
959     if (numel == 0 || self.numel() == 0) {
960       return;
961     }
962 
963     dim = maybe_wrap_dim(dim, self.dim());
964 
965     // When the slice of source or result is noncontiguous,
966     // the original index_add path is slow because it calls add on each sliced tensor,
967     // serially over the index and in parallel over the sliced tensor to avoid write conflicts.
968     // Parallelizing over the sliced tensor is not optimal: the sliced tensor may be
969     // too small to parallelize well, and a separate parallel region is launched per index.
970     // scatter_add is used to speed up this case, as it parallelizes over
971     // the outer dimension of the input and is serial over the inner dimension to
972     // avoid write conflicts. scatter_add thus needs only one parallel region, and the
973     // outer extent it parallelizes over is larger.
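    //
    // Illustrative shapes for this fast path (assumed, dim == 0): with result of
    // shape (4, 5), source of shape (3, 5) and index.numel() == 3, check_sizes
    // below fills ep_sizes = {3, 5} and ep_strides = {1, 0}; the resulting
    // ep_index repeats index[i] across the non-indexed dimension, so
    // result.scatter_add_(0, ep_index, source) matches the loop shown above.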
974 
975     if ((dim == 0 || dim == self.dim() - 1) &&
976         // Data type of index should be long and alpha should be 1 to use scatter_add.
977         alpha.equal(1.0) && index_contig.scalar_type() == ScalarType::Long &&
978         // scatter_add does not support ComplexHalf
979         source.scalar_type() != ScalarType::ComplexHalf &&
980         result.scalar_type() != ScalarType::ComplexHalf) {
981       std::vector<int64_t> ep_sizes(result.sizes().size());
982       std::vector<int64_t> ep_strides(source.sizes().size());
983 
984       // Check whether result and source are matched apart from the dimension dim.
985       // Note that the broadcast case:
986       // source.select(dim, i) is broadcast for result.select(dim, index_data[i])
987       // The broadcast case is not applicable for scatter_add
988       auto check_sizes = [&ep_sizes, &ep_strides, &numel](IntArrayRef a, IntArrayRef b, int64_t dim) -> bool {
989 
990         ep_sizes[dim] = numel;
991         ep_strides[dim] = 1;
992         for (const int64_t i : c10::irange(a.size())) {
993           if (i == dim) {
994             continue;
995           }
996 
997           if (a[i] != b[i]) {
998             return false;
999           }
1000           ep_sizes[i] = a[i];
1001           ep_strides[i] = 0;
1002 
1003         }
1004         return true;
1005       };
1006 
1007       if (check_sizes(result.sizes(), source.sizes(), dim)) {
1008         auto ep_index = index_contig.as_strided(ep_sizes, ep_strides);
1009         result.scatter_add_(dim, ep_index, source);
1010         return;
1011       }
1012     }
1013 
1014     auto selfSlice = result.select(dim, 0);
1015     auto sourceSlice = source.select(dim, 0);
1016     auto self_stride_bytes = result.stride(dim) * elementSize(result.scalar_type());
1017     auto source_stride_bytes = source.stride(dim) * elementSize(source.scalar_type());
1018     auto self_dim_size = result.size(dim);
1019     auto iter = TensorIterator::borrowing_binary_op(selfSlice, selfSlice, sourceSlice);
1020 
1021     AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cpu_", [&] () {
1022       auto index_data = index_contig.const_data_ptr<index_t>();
1023       for (const auto i : c10::irange(numel)) {
1024           auto self_i = index_data[i];
1025           TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
1026           auto self_data = static_cast<char*>(selfSlice.data_ptr()) + self_i * self_stride_bytes;
1027           auto source_data = static_cast<const char*>(sourceSlice.const_data_ptr()) + i * source_stride_bytes;
1028           iter.unsafe_replace_operand(0, self_data);
1029           iter.unsafe_replace_operand(1, self_data);
1030           iter.unsafe_replace_operand(2, const_cast<char*>(source_data));
1031           add_stub(iter.device_type(), iter, alpha);
1032       }
1033     });
1034   } else {
1035     TORCH_CHECK(source.dim() <= 1, "source.dim() (", source.dim(), ") must be one or zero for given self.dim() (", self.dim(), ")");
1036 
1037     // explicitly capture all required variables to work around windows build
1038     // TODO: fix this when windows can correctly capture variables in nested lambda
1039     AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, ScalarType::ComplexHalf,
1040       result.scalar_type(), "index_add_", [&result, &source, &dim, &index_contig, &numel, &alpha] {
1041       auto alpha_value = alpha.to<scalar_t>();
1042       auto result_stride = result.dim() == 0 ? 1 : result.stride(dim);
1043       auto source_stride = source.dim() == 0 ? 1 : source.stride(dim);
1044       // TODO: Maybe TensorAccessor can be used here?
1045       auto* result_ptr = result.data_ptr<scalar_t>();
1046       auto* source_ptr = source.const_data_ptr<scalar_t>();
1047       AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_add_cpu_",
1048         [&index_contig, &numel, &result, &result_ptr, &result_stride, &source_ptr, &source_stride, &alpha_value] {
1049         auto index_data = index_contig.const_data_ptr<index_t>();
1050         for (const auto i : c10::irange(numel)) {
1051             auto self_i = index_data[i];
1052             TORCH_CHECK_INDEX((self_i >= 0) && (self_i < result.numel()), "index out of range in self");
1053             scalar_t *self_ip = result_ptr + self_i * result_stride;
1054             *self_ip += *(source_ptr + i * source_stride) * alpha_value;
1055         }
1056       });
1057     });
1058   }
1059 }
1060 
1061 static void index_reduce_func_impl(
1062   const Tensor& self,
1063   int64_t dim,
1064   const Tensor& index,
1065   const Tensor& source,
1066   bool include_self,
1067   const Tensor& result,
1068   const ReductionType& op) {
1069   if (!result.is_same(self)) result.copy_(self);
1070   if (!include_self) {
1071     AT_DISPATCH_ALL_TYPES_AND2(
1072       at::ScalarType::Half, at::ScalarType::BFloat16,
1073       self.scalar_type(), "index_reduce_func_exclude_input_init", [&] {
1074       scalar_t init_val;
1075       switch (op) {
1076         case ReductionType::PROD:
1077           init_val = (scalar_t)1;
1078           break;
1079         case ReductionType::MAX:
1080           init_val = std::numeric_limits<scalar_t>::has_infinity ? -std::numeric_limits<scalar_t>::infinity()
1081                      : std::numeric_limits<scalar_t>::lowest();
1082           break;
1083         case ReductionType::MIN:
1084           init_val = std::numeric_limits<scalar_t>::has_infinity ? std::numeric_limits<scalar_t>::infinity()
1085                      : std::numeric_limits<scalar_t>::max();
1086           break;
1087         default:
1088           init_val = (scalar_t)0;
1089           break;
1090       }
1091       // index_fill_ requires index to be a LongTensor
1092       result.index_fill_(dim, index.to(at::ScalarType::Long), init_val);
1093     });
1094   }
1095 
1096   auto numel = index.numel();
1097 
1098   auto index_contig = index.contiguous();
1099 
1100   if (result.dim() > 1) {
1101     // Equivalent to:
1102     //   for (const auto i : c10::irange(numel)) {
1103     //     auto selfSlice = self.select(dim, index_data[i]);
1104     //     auto sourceSlice = source.select(dim, i);
1105     //     selfSlice.op_(sourceSlice);
1106     //   }
1107     // But much faster as this reuses the iterator from the binary op
1108     if (numel == 0) {
1109       return;
1110     }
1111     auto selfSlice = result.select(dim, 0);
1112     auto sourceSlice = source.select(dim, 0);
1113     auto self_stride_bytes = result.stride(dim) * elementSize(result.scalar_type());
1114     auto source_stride_bytes = source.stride(dim) * elementSize(source.scalar_type());
1115     auto self_dim_size = result.size(dim);
1116     auto iter = TensorIterator::borrowing_binary_op(selfSlice, selfSlice, sourceSlice);
1117 
1118     AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_func_cpu_", [&] () {
1119       auto index_data = index_contig.const_data_ptr<index_t>();
1120       for (const auto i : c10::irange(numel)) {
1121         auto self_i = index_data[i];
1122         TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
1123         auto self_data = static_cast<char*>(selfSlice.data_ptr()) + self_i * self_stride_bytes;
1124         auto source_data = static_cast<const char*>(sourceSlice.const_data_ptr()) + i * source_stride_bytes;
1125         iter.unsafe_replace_operand(0, self_data);
1126         iter.unsafe_replace_operand(1, self_data);
1127         iter.unsafe_replace_operand(2, const_cast<char*>(source_data));
1128 
1129         switch (op) {
1130           case ReductionType::PROD :
1131             mul_stub(iter.device_type(), iter);
1132             break;
1133           case ReductionType::MIN :
1134             minimum_stub(iter.device_type(), iter);
1135             break;
1136           case ReductionType::MAX :
1137             maximum_stub(iter.device_type(), iter);
1138             break;
1139           default :
1140             add_stub(iter.device_type(), iter, 1);
1141             break;
1142         }
1143       }
1144     });
1145 
1146     if (op == ReductionType::MEAN) {
1147       auto counts = include_self ? at::ones_like(result) : at::zeros_like(result);
1148       counts.index_add_(dim, index, at::ones_like(source));
1149       counts.masked_fill_(counts == 0, 1);
1150       if (result.is_floating_point() || result.is_complex()) {
1151         result.div_(counts);
1152       } else {
1153         result.div_(counts, "floor");
1154       }
1155     }
1156   }
1157   else {
1158     TORCH_CHECK(source.dim() <= 1, "source.dim() (", source.dim(), ") must be one or zero for given self.dim() (", self.dim(), ")");
1159     auto counts = include_self ? at::ones_like(result) : at::zeros_like(result);
1160     // explicitly capture all required variables to work around windows build
1161     // TODO: fix this when windows can correctly capture variables in nested lambda
1162     AT_DISPATCH_ALL_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16,
1163       result.scalar_type(), "index_func_", [&result, &source, &dim, &index_contig, &numel, &op, &counts] {
1164       auto result_stride = result.dim() == 0 ? 1 : result.stride(dim);
1165       auto source_stride = source.dim() == 0 ? 1 : source.stride(dim);
1166       auto counts_stride = counts.dim() == 0 ? 1 : counts.stride(dim);
1167       // TODO: Maybe TensorAccessor can be used here?
1168       auto* result_ptr = result.data_ptr<scalar_t>();
1169       auto* source_ptr = source.const_data_ptr<scalar_t>();
1170       auto counts_ptr = counts.data_ptr<scalar_t>();
1171       AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_func_cpu_",
1172         [&index_contig, &numel, &result, &result_ptr, &result_stride, &source_ptr, &source_stride, &op, &counts_ptr, &counts_stride] {
1173         auto index_data = index_contig.const_data_ptr<index_t>();
1174         for (const auto i : c10::irange(numel)) {
1175             auto self_i = index_data[i];
1176             TORCH_CHECK_INDEX((self_i >= 0) && (self_i < result.numel()), "index out of range in self");
1177             scalar_t *self_ip = result_ptr + self_i * result_stride;
1178             scalar_t *count_ip;
1179             scalar_t val;
1180             switch (op) {
1181               case ReductionType::MEAN :
1182                 *self_ip += *(source_ptr + i * source_stride);
1183                 count_ip = counts_ptr + self_i * counts_stride;
1184                 *count_ip += 1;
1185                 break;
1186               case ReductionType::PROD :
1187                 *self_ip *= *(source_ptr + i * source_stride);
1188                 break;
1189               case ReductionType::MIN :
1190                 val = *(source_ptr + i * source_stride);
1191                 *self_ip = at::_isnan<scalar_t>(val) ? val : std::min(*self_ip, val);
1192                 break;
1193               case ReductionType::MAX :
1194                 val = *(source_ptr + i * source_stride);
1195                 *self_ip = at::_isnan<scalar_t>(val) ? val : std::max(*self_ip, val);
1196                 break;
1197               default:
1198                 break;
1199             }
1200         }
1201       });
1202     });
1203     if (op == ReductionType::MEAN) {
1204       counts.masked_fill_(counts == 0, 1);
1205       if (result.is_floating_point() || result.is_complex()) {
1206         result.div_(counts);
1207       } else {
1208         result.div_(counts, "floor");
1209       }
1210     }
1211   }
1212 }
1213 
1214 TORCH_IMPL_FUNC(index_reduce_cpu_out)
1215 (const Tensor& self,
1216  int64_t dim,
1217  const Tensor& index,
1218  const Tensor& source,
1219  const c10::string_view reduce,
1220  bool include_input,
1221  const Tensor& result) {
1222   TORCH_WARN_ONCE("index_reduce() is in beta and the API may change at any time.");
1223   auto op = get_operator_enum(reduce, true);
1224   index_reduce_func_impl(self, dim, index, source, include_input, result, op);
1225 }
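// A minimal usage sketch of the reduction dispatched above (values and the
// at::index_reduce entry point are illustrative assumptions):
//
//   auto self   = at::ones({3, 2});
//   auto index  = at::tensor({0, 2, 0}, at::kLong);
//   auto source = at::arange(6, at::kFloat).view({3, 2});
//   // rows 0 and 2 of `source` reduce into result row 0, row 1 into row 2:
//   auto out = at::index_reduce(self, /*dim=*/0, index, source,
//                               /*reduce=*/"prod", /*include_self=*/true);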
1226 
1227 // Check that indices fall within dimension array size
1228 // Avoid redispatch call to min/max
1229 template <typename IndexType>
1230 static void check_indexarray_range(
1231     const IndexType* indices,
1232     int64_t n,
1233     IndexType indexing_axis_dim) {
1234   for (const auto i : c10::irange(n)) {
1235     auto idx = indices[i];
1236     TORCH_CHECK(
1237         0 <= idx && idx < indexing_axis_dim,
1238         "INDICES element is out of DATA bounds, id=",
1239         idx,
1240         " axis_dim=",
1241         indexing_axis_dim);
1242   }
1243 }
1244 
1245 static Tensor & index_select_out_cpu_dim1_(
1246     Tensor & result_contig, const Tensor & self, const Tensor & index_contig) {
1247 
1248   auto self_contig = self.contiguous();
1249   const caffe2::TypeMeta dataType = self_contig.dtype();
1250   size_t item_bytesize = dataType.itemsize();
1251 
1252   auto out = static_cast<char*>(result_contig.data_ptr());
1253 
1254   auto src_base = static_cast<const char*>(self_contig.const_data_ptr());
1255 
1256   auto self_sizes = self_contig.sizes();
1257   auto outer_dims_product = c10::size_to_dim_(1, self_sizes);
1258   auto block_size = c10::size_from_dim_(2, self_sizes);
1259   auto block_bytesize = block_size * item_bytesize;
1260 
1261   auto src_indexing_axis_dim = self_sizes[1];
1262   auto src_batch_bytesize = self_sizes[1] * block_bytesize;
1263   auto N = index_contig.numel();
1264 
1265   auto gathered_batch_bytesize = N * block_bytesize;
1266 
1267   AT_DISPATCH_INDEX_TYPES(
1268     index_contig.scalar_type(), "batch_index_select_compute", [&]() {
1269 
1270       const auto* idxs = index_contig.const_data_ptr<index_t>();
1271       check_indexarray_range<index_t>(idxs, N, src_indexing_axis_dim);
1272 
1273       // Special-case single-float copy for efficiency
1274       if (self.scalar_type() == ScalarType::Float && block_size == 1) {
1275         for (const auto batch : c10::irange(outer_dims_product)) {
1276           const float* src_floats =
1277               (const float*)(src_base + batch * src_batch_bytesize);
1278           float* dst_floats = (float*)(out + batch * gathered_batch_bytesize);
1279 
1280           for (const auto i : c10::irange(N)) {
1281             auto idx = idxs[i];
1282             dst_floats[i] = src_floats[idx];
1283           }
1284         }
1285       } else {
1286         // outer_dims_product specifies how many times we repeat inner dimensions,
1287         // so we just iterate over it to cover all outer dimensions.
1288         for (const auto batch : c10::irange(outer_dims_product)) {
1289           for (const auto i : c10::irange(N)) {
1290             auto idx = idxs[i];
1291             auto src = src_base + batch * src_batch_bytesize + idx * block_bytesize;
1292             auto dst = out + batch * gathered_batch_bytesize + i * block_bytesize;
1293             memcpy(dst, src, block_bytesize);
1294           }
1295         }
1296       }
1297   });
1298   return result_contig;
1299 }
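// In effect, the dim-1 fast path above performs (pseudo-code, shapes assumed
// for illustration):
//
//   for (batch : [0, outer_dims_product))
//     for (i : [0, N))
//       result[batch][i][...] = self[batch][idxs[i]][...]   // one block_bytesize copy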
1300 
1301 Tensor & index_select_out_cpu_(const Tensor & self, int64_t dim, const Tensor & index, Tensor & result) {
1302   if (self.is_quantized()) {
1303     TORCH_CHECK(
1304         self.qscheme() == kPerTensorAffine,
1305         "Only per_tensor quantized quantized tensors are supported by index_select.")
1306   }
1307   dim = maybe_wrap_dim(dim, self.dim());
1308   auto numel = index.numel();
1309   TORCH_CHECK_INDEX(index.dim() <= 1, "index_select(): Index is supposed to be a vector");
1310   TORCH_CHECK(!(self.dim() == 0 && numel != 1), "index_select(): Index to scalar can have only 1 value, got ", numel, " value(s)");
1311   TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, "index_select(): Expected dtype int32 or int64 for index");
1312   TORCH_CHECK(self.scalar_type() == result.scalar_type(),
1313               "index_select(): self and result must have the same scalar type");
1314   at::assert_no_internal_overlap(result);
1315   at::assert_no_overlap(result, self);
1316   at::assert_no_overlap(result, index);
1317   auto result_size = self.sizes().vec();
1318   if (self.dim() > 0) {
1319     result_size[dim] = numel;
1320   }
1321   at::native::resize_output(result, result_size);
1322 
1323   auto index_contig = index.contiguous();
1324 
1325   if (self.dim() > 1) {
1326     if (numel == 0) {
1327       return result;
1328     }
1329     if (self.numel() == 0) {
1330       auto src_indexing_axis_dim = self.size(dim);
1331       TORCH_CHECK(src_indexing_axis_dim > 0,
1332                   "index_select(): self indexing axis dim should be positive");
1333       AT_DISPATCH_INDEX_TYPES(
1334       index_contig.scalar_type(), "index_select_empty_self_bound_check", [&]() {
1335         const auto* idxs = index_contig.const_data_ptr<index_t>();
1336         check_indexarray_range<index_t>(idxs, numel, src_indexing_axis_dim);
1337       });
1338       return result;
1339     }
1340 
1341     if (dim == 1 && result.is_contiguous()) {
1342       // fast path
1343       return index_select_out_cpu_dim1_(result, self, index_contig);
1344     }
1345 
1346     auto selfSlice = self.select(dim, 0);
1347     auto resultSlice = result.select(dim, 0);
1348     auto selfSlice_data = selfSlice.const_data_ptr();
1349     auto resultSlice_data = resultSlice.data_ptr();
1350     auto self_stride_bytes = self.stride(dim) * elementSize(self.scalar_type());
1351     auto result_stride_bytes = result.stride(dim) * elementSize(result.scalar_type());
1352     auto self_dim_size = self.size(dim);
1353     auto slice_size = selfSlice.numel();
1354 
1355     auto iter = TensorIteratorConfig()
1356       .check_all_same_dtype(false)
1357       .resize_outputs(false)
1358       .add_output(resultSlice)
1359       .add_const_input(selfSlice)
1360       .build();
1361 
1362     auto grain_size = at::internal::GRAIN_SIZE;
1363     auto outer_loop =
1364       // explicitly capture all required variables to work around a Windows build issue
1365       // TODO: fix this when Windows can correctly capture variables in nested lambdas
1366       [&index_contig, &iter, &self_dim_size, &selfSlice_data, &self_stride_bytes, &resultSlice_data,
1367         &result_stride_bytes](int64_t start, int64_t end) {
1368       auto sub_iter = TensorIterator(iter);
1369       AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_",
1370         [&index_contig, &start, &end, &sub_iter, &self_dim_size, &selfSlice_data, &self_stride_bytes,
1371           &resultSlice_data, &result_stride_bytes] () {
1372         auto index_data = index_contig.const_data_ptr<index_t>();
1373         for (const auto i : c10::irange(start, end)) {
1374           auto self_i = index_data[i];
1375           TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
1376           auto self_data = static_cast<const char*>(selfSlice_data) + self_i * self_stride_bytes;
1377           auto result_data = static_cast<char*>(resultSlice_data) + i * result_stride_bytes;
1378           sub_iter.unsafe_replace_operand(0, result_data);
1379           sub_iter.unsafe_replace_operand(1, const_cast<char*>(self_data));
1380           copy_stub(sub_iter.device_type(), sub_iter, false);
1381         };
1382       });
1383     };
1384 
1385     // parallelize over the inner loop (the slice copy) when the slice is large enough;
1386     // otherwise parallelize over the outer loop
1387     if (slice_size >= grain_size) {
1388       outer_loop(0, numel);
1389     } else {
1390       // use a fast loop when self and result are contiguous and of the same data type
1391       if (iter.is_contiguous() && self.scalar_type() == result.scalar_type()) {
1392         auto slice_size_bytes = slice_size * elementSize(self.scalar_type());
1393         // explicitly capture all required variables to work around a Windows build issue
1394         // TODO: fix this when Windows can correctly capture variables in nested lambdas
1395         at::parallel_for(0, numel, grain_size / slice_size,
1396           [&index_contig, &slice_size_bytes, &self_dim_size, &selfSlice_data,
1397             &self_stride_bytes, &resultSlice_data, &result_stride_bytes](int64_t start, int64_t end) {
1398           AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_",
1399             [&index_contig, &slice_size_bytes, &self_dim_size, &selfSlice_data,
1400               &self_stride_bytes, &resultSlice_data, &result_stride_bytes, &start, &end] () {
1401             auto index_data = index_contig.const_data_ptr<index_t>();
1402             for (const auto i : c10::irange(start, end)) {
1403               auto self_i = index_data[i];
1404               TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
1405               auto self_data = static_cast<const char*>(selfSlice_data) + self_i * self_stride_bytes;
1406               auto result_data = static_cast<char*>(resultSlice_data) + i * result_stride_bytes;
1407               memcpy(result_data, self_data, slice_size_bytes);
1408             }
1409           });
1410         });
1411       } else {
1412         at::parallel_for(0, numel, grain_size / slice_size, outer_loop);
1413       }
1414     }
1415   } else {
1416     TORCH_CHECK(result.dim() <= 1, "result.dim() (", result.dim(), ") must be one or zero for given self.dim() (", self.dim(), ")");
1417     // explicitly capture all required variables to work around a Windows build issue
1418     // TODO: fix this when Windows can correctly capture variables in nested lambdas
1419     if(self.is_quantized()){
1420       AT_DISPATCH_QINT_TYPES(self.scalar_type(), "index_select_quant", [&index_contig, &self, &result, &dim, &numel] {
1421         auto self_stride = self.dim() == 0 ? 1 : self.stride(dim);
1422         auto result_stride = result.dim() == 0 ? 1 : result.stride(dim);
1423         auto self_data_ptr = self.const_data_ptr<scalar_t>();
1424         auto result_data_ptr = result.data_ptr<scalar_t>();
1425         auto self_numel = self.numel();
1426         AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_quant_",
1427           [&index_contig, &numel, &self_numel, &self_data_ptr, &self_stride, &result_data_ptr, &result_stride] {
1428           auto index_data = index_contig.const_data_ptr<index_t>();
1429           for (const auto i : c10::irange(numel)) {
1430             auto self_i = index_data[i];
1431             TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_numel), "index out of range in self");
1432             const scalar_t *self_ip = self_data_ptr + self_i * self_stride;
1433             *(result_data_ptr + i * result_stride) = *self_ip;
1434           }
1435         });
1436       });
1437     } else {
1438       AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(ScalarType::ComplexHalf, ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16,
1439         self.scalar_type(), "index_select", [&index_contig, &self, &result, &dim, &numel] {
1440         auto self_stride = self.dim() == 0 ? 1 : self.stride(dim);
1441         auto result_stride = result.dim() == 0 ? 1 : result.stride(dim);
1442 
1443         auto self_data_ptr = self.const_data_ptr<scalar_t>();
1444         auto result_data_ptr = result.data_ptr<scalar_t>();
1445         auto self_numel = self.numel();
1446         AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_",
1447           [&index_contig, &numel, &self_numel, &self_data_ptr, &self_stride, &result_data_ptr, &result_stride] {
1448           auto index_data = index_contig.const_data_ptr<index_t>();
1449           for (const auto i : c10::irange(numel)) {
1450             auto self_i = index_data[i];
1451             TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_numel), "index out of range in self");
1452             const scalar_t *self_ip = self_data_ptr + self_i * self_stride;
1453             *(result_data_ptr + i * result_stride) = *self_ip;
1454           }
1455         });
1456       });
1457     }
1458   }
1459 
1460   return result;
1461 }
1462 
1463 Tensor index_select_cpu_(const Tensor & self, int64_t dim, const Tensor & index) {
1464   Tensor result = at::empty({0}, self.options());
1465   return at::native::index_select_out_cpu_(self, dim, index, result);
1466 }
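// Minimal sketch of the semantics implemented above (hypothetical values):
//
//   auto self  = at::arange(12, at::kFloat).view({4, 3});
//   auto index = at::tensor({3, 0}, at::kLong);
//   auto rows  = at::index_select(self, /*dim=*/0, index);  // shape [2, 3]
//   // rows[0] == self[3], rows[1] == self[0]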
1467 
1468 Tensor index_select_quantized_cpu_(const Tensor & self, int64_t dim, const Tensor & index) {
1469   TORCH_CHECK(self.qscheme() == kPerTensorAffine,
1470               "Only per_tensor quantized quantized tensors are supported by index_select.")
1471   Tensor result = at::empty_quantized({0}, self);
1472   return at::native::index_select_out_cpu_(self, dim, index, result);
1473 }
1474 
1475 Tensor index_select_backward_symint(const Tensor& grad, c10::SymIntArrayRef self_sizes, int64_t dim, const Tensor& index) {
1476   // for composite compliance, use out-of-place variant of
1477   // `index_add` if index tensor is a Tensor Subclass.
1478   if (isTensorSubclassLike(index)) {
1479     return grad.new_zeros_symint(self_sizes, grad.options()).index_add(dim, index, grad);
1480   }
1481   return grad.new_zeros_symint(self_sizes, grad.options()).index_add_(dim, index, grad);
1482 }
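// The backward above is the scatter dual of index_select: rows selected more
// than once accumulate their gradients. Sketch (shapes assumed):
//
//   // forward:  out       = self.index_select(dim, index)
//   // backward: grad_self = zeros(self_sizes).index_add_(dim, index, grad)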
1483 
1484 Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, const Scalar& source) {
1485   at::NoNamesGuard guard;
1486 
1487   TORCH_CHECK_INDEX(
1488     index.scalar_type() == ScalarType::Long,
1489     "index_fill_(): Expected dtype int64 for index.");
1490 
1491   at::assert_no_overlap(self, index);
1492   if (at::has_internal_overlap(self) == at::MemOverlap::Yes) {
1493     TORCH_WARN(
1494       "Use of index_fill_ on expanded tensors is deprecated. "
1495       "Please clone() the tensor before performing this operation. "
1496       "This also applies to advanced indexing e.g. tensor[mask] = scalar");
1497   }
1498 
1499   if (!self.is_complex() && source.isComplex()) {
1500     TORCH_CHECK(false, "index_fill_(): Converting complex Scalar to non-complex type is not supported");
1501   }
1502 
1503   // Handle the case when `self` is 0-dim
1504   Tensor self_nonzero_dim = (self.dim() == 0) ? self.unsqueeze(-1) : self;
1505 
1506   dim = at::maybe_wrap_dim(dim, self_nonzero_dim);
1507   TORCH_CHECK(index.dim() <= 1, "Index has to be a vector/scalar");
1508 
1509   // Prepare `index` for TensorIterator.
1510   // It is restrided to be broadcastable over `self` in TensorIterator.
1511   auto index_sizes = std::vector<int64_t>(self_nonzero_dim.dim(), 1);
1512   auto index_strides = std::vector<int64_t>(self_nonzero_dim.dim(), 0);
1513   index_sizes[dim] = index.numel();
1514   index_strides[dim] = (index.dim() > 0) ? index.stride(0) : 1; // `index` is 1d or scalar
1515   auto index_restrided = index.as_strided(
1516     index_sizes, index_strides);
1517 
1518   // Prepare `self` for TensorIterator.
1519   // Restride `self` to not advance in dimension `dim`.
1520   // We do not use squash_dim here because `index` will
1521   // need to advance in this dimension.
1522   // Note that self_sizes[dim] is set to index.numel().
1523   // This is done so that self_sizes[dim] and index_sizes[dim]
1524   // match as required by TensorIterator (input shape should
1525   // strictly broadcast over output shape, i.e.
1526   // output.shape[i] >= input.shape[i] for i in range(dims)).
1527   auto self_sizes = self_nonzero_dim.sizes().vec();
1528   auto self_strides = self_nonzero_dim.strides().vec();
1529   self_sizes[dim] = index.numel();
1530   self_strides[dim] = 0;
1531   auto self_restrided = self_nonzero_dim.as_strided(self_sizes, self_strides);
1532 
1533   auto iter = TensorIteratorConfig()
1534     // We do not check for overlap because `self` is restrided
1535     // with zero stride. Zero strides trigger memory overlap assert
1536     // within TensorIterator.
1537     .set_check_mem_overlap(false)
1538     .check_all_same_dtype(false)
1539     .resize_outputs(false)
1540     .add_output(self_restrided)
1541     .add_const_input(index_restrided)
1542     .build();
1543 
1544   auto self_dim_size = (self_nonzero_dim.sizes())[dim];
1545   auto self_dim_stride = (self_nonzero_dim.strides())[dim];
1546   index_fill_stub(
1547     iter.device_type(),
1548     iter,
1549     dim,
1550     self_dim_size,
1551     self_dim_stride,
1552     source);
1553 
1554   return self;
1555 }
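// Worked example of the restriding above, assuming self of shape [4, 5],
// dim == 1 and index.numel() == 3: the iterator sees
//
//   index restrided to sizes [1, 3], strides [0, index.stride(0)]
//   self  restrided to sizes [4, 3], strides [self.stride(0), 0]
//
// so each iterator element pairs one index value with one row of `self`;
// index_fill_stub then advances along `dim` by self_dim_stride itself.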
1556 
1557 Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
1558   TORCH_CHECK(source.dim() == 0, "index_fill_ only supports a 0-dimensional value tensor, but got tensor "
1559       "with ", source.dim(), " dimension(s).");
1560   return self.index_fill_(dim, index, source.item());
1561 }
1562 
1563 Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, const Scalar& source) {
1564   return self.clone(at::MemoryFormat::Preserve).index_fill_(dim, index, source);
1565 }
1566 
1567 Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
1568   return self.clone(at::MemoryFormat::Preserve).index_fill_(dim, index, source);
1569 }
1570 
1571 // fast paths for GNN usage
1572 static bool can_use_expanded_index_path(
1573     const Tensor& self,
1574     int64_t dim,
1575     const Tensor& index,
1576     const Tensor& src,
1577     bool is_scatter_like) {
1578 #ifdef USE_FBGEMM
1579   if (!fbgemm::is_radix_sort_accelerated_with_openmp()) {
1580     return false;
1581   }
1582 #else
1583   return false;
1584 #endif
1585 
1586   if (!self.device().is_cpu()) {
1587     return false;
1588   }
1589 
1590   const auto st = self.scalar_type();
1591   if (!(c10::isFloatingType(st))) {
1592     return false;
1593   }
1594 
1595   // skip when any of the tensors is empty
1596   if (self.numel() == 0 || index.numel() == 0 || src.numel() == 0) {
1597     return false;
1598   }
1599 
1600   // skip when any of the tensors is a scalar (zero-dimensional)
1601   if (self.ndimension() == 0 || index.ndimension() == 0 || src.ndimension() == 0) {
1602     return false;
1603   }
1604 
1605   // src and index may differ in size only along dim 0
1606   // https://github.com/pytorch/pytorch/issues/99595
1607   for (const auto dim : c10::irange(1, index.dim())) {
1608     if (src.size(dim) != index.size(dim)) {
1609       return false;
1610     }
1611   }
1612 
1613   if (is_scatter_like) {
1614     // using `spmm` for scatter would require sorting on index;
1615     // this is only beneficial for performance when the inner dimension,
1616     // aka `channels`, is large enough.
1617     constexpr int64_t threshold = 16;
1618     if (index.numel() / index.size(0) < threshold) {
1619       return false;
1620     }
1621   }
1622 
1623   // usually the expanded index has stride on the first dimension to be 1,
1624   // and strides on other dims to be 0 or 1, e.g.
1625   //   shape [108365, 16]; strides [1, 0]
1626   //   shape [13264, 1, 7]; strides [1, 1, 0]
1627   auto index_strides = index.strides().vec();
1628   bool is_index_expanded = index_strides[0] == 1;
1629   for (const auto dim : c10::irange(1, index_strides.size())) {
1630     if (index_strides[dim] > 1) { is_index_expanded = false; }
1631   }
1632 
1633   // index is expanded
1634   return dim == 0 && is_index_expanded && src.is_contiguous() && self.is_contiguous();
1635 }
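// Example of an "expanded" index that satisfies the checks above (shapes are
// illustrative of typical GNN workloads, not taken from this file):
//
//   auto idx1d = at::randint(/*high=*/108365, {108365});       // kLong by default
//   auto index = idx1d.view({108365, 1}).expand({108365, 16}); // strides {1, 0}
//   // with dim == 0 and contiguous self/src, the fast path is taken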
1636 
1637 // gather_out_cpu_cuda
1638 TORCH_IMPL_FUNC(gather_out)
1639 (const Tensor& self, int64_t dim, const Tensor& index, bool sparse_grad, const Tensor& result) {
1640   if (index.numel() == 0) return;
1641   dim = at::maybe_wrap_dim(dim, self.dim());
1642   if (can_use_expanded_index_path(result, dim, index, self, /*is_scatter_like=*/false)) {
1643     gather_expanded_index_stub(result.device().type(), result, self, index);
1644   } else {
1645     gather_stub(result.device().type(), result, self, dim, index);
1646   }
1647 }
1648 
1649 Tensor gather_backward(const Tensor& grad, const Tensor& self, int64_t dim, const Tensor& index, bool sparse_grad) {
1650   if (sparse_grad) {
1651     return at::_gather_sparse_backward(self, dim, index, grad);
1652   }
1653   auto result = grad.new_zeros_symint(self.sym_sizes());
1654   // for composite, vmap and inductor compliance, use out-of-place variant of
1655   // `scatter_add` if index or grad tensors is a Tensor Subclass.
1656   if (areAnyTensorSubclassLike({index, grad})) {
1657     return result.scatter_add(dim, index, grad);
1658   }
1659   result.scatter_add_(dim, index, grad);
1660   return result;
1661 }
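// The dense branch above relies on gather and scatter_add being duals
// (sketch, dim == 1 assumed for illustration):
//
//   // forward:  out[i][j] = self[i][index[i][j]]
//   // backward: grad_self.scatter_add_(1, index, grad)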
1662 
1663 static void scatter_reduce_exclude_self_helper(
1664   const Tensor& self,
1665   int64_t dim,
1666   const Tensor& index,
1667   const ReductionType& op) {
1668   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
1669     at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool,
1670     self.scalar_type(), "scatter_reduce_exclude_input_init", [&] {
1671     scalar_t init_val;
1672     switch (op) {
1673       case ReductionType::SUM:
1674         init_val = (scalar_t)0;
1675         break;
1676       case ReductionType::PROD:
1677         init_val = (scalar_t)1;
1678         break;
1679       case ReductionType::MAX:
1680         init_val = std::numeric_limits<scalar_t>::has_infinity ? -std::numeric_limits<scalar_t>::infinity()
1681                    : std::numeric_limits<scalar_t>::lowest();
1682         break;
1683       case ReductionType::MIN:
1684         init_val = std::numeric_limits<scalar_t>::has_infinity ? std::numeric_limits<scalar_t>::infinity()
1685                    : std::numeric_limits<scalar_t>::max();
1686         break;
1687       case ReductionType::MEAN:
1688         init_val = (scalar_t)0;
1689         break;
1690     }
1691     self.scatter_(dim, index, init_val);
1692   });
1693 }
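// Example of the initialization above: for reduce == "amax" on a float dtype,
// every destination slot referenced by `index` is first set to -inf, so when
// include_self is false the reduction result depends only on `src`.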
1694 
1695 static void _scatter_via_index_put(
1696   const Tensor& self,
1697   int64_t dim,
1698   const Tensor& index,
1699   const Tensor& src,
1700   const Tensor& mut_out,
1701   bool accumulate) {
1702   if (self.dim() == 1) {
1703     torch::List<std::optional<Tensor>> indices;
1704     indices.reserve(1);
1705     indices.push_back(index);
1706     mut_out.index_put_(indices, src, accumulate);
1707   } else {
1708     Tensor mut_out_contig = mut_out.contiguous();
1709 
1710     auto index_coords_sizes = index.sizes().vec();
1711     index_coords_sizes.push_back(self.dim());
1712     auto index_coords = at::empty(
1713       index_coords_sizes,
1714       at::TensorOptions().dtype(at::ScalarType::Long).device(self.device()));
1715 
1716     for (int64_t dim_other = 0; dim_other < self.dim(); dim_other++) {
1717       if (dim_other == dim) {
1718         continue;
1719       }
1720       auto dim_coord_vals = at::arange(
1721         index.size(dim_other),
1722         at::TensorOptions().device(self.device()));
1723 
1724       for (int64_t dim_unsqueeze = 0; dim_unsqueeze < self.dim() - 1; dim_unsqueeze++) {
1725         dim_coord_vals = dim_coord_vals.unsqueeze((dim_unsqueeze >= dim_other) ? -1 : 0);
1726       }
1727 
1728       auto view_sizes = index.sizes().vec();
1729       view_sizes.push_back(1);
1730       auto view_strides = index_coords.strides().vec();
1731       view_strides[self.dim()] = self.dim();
1732 
1733       at::as_strided(
1734         index_coords,
1735         view_sizes,
1736         view_strides,
1737         dim_other
1738       ).copy_(dim_coord_vals.unsqueeze(-1));
1739     }
1740 
1741     auto view_sizes = index.sizes().vec();
1742     view_sizes.push_back(1);
1743     auto view_strides = index_coords.strides().vec();
1744     view_strides[self.dim()] = self.dim();
1745 
1746     at::as_strided(
1747       index_coords,
1748       view_sizes,
1749       view_strides,
1750       dim
1751     ).copy_(index.unsqueeze(-1));
1752 
1753     Tensor index_coords_flat = index_coords.flatten(0, -2);
1754 
1755     // Copy mut_out_contig's strides into a tensor
1756     // TODO: Is there a utility function that already does this?
1757     IntArrayRef mut_out_contig_strides = mut_out_contig.strides();
1758     Tensor coord_strides = at::empty(
1759       {mut_out_contig.dim()},
1760       TensorOptions().dtype(at::ScalarType::Long).device(at::kCPU));
1761     std::memcpy(
1762       coord_strides.mutable_data_ptr(),
1763       mut_out_contig_strides.data(),
1764       coord_strides.nbytes());
1765     coord_strides = coord_strides.to(mut_out_contig.device());
1766 
1767     // `index_flat` contains the 1-D indices corresponding with the
1768     // flattened `mut_out`
1769     Tensor index_flat = (index_coords_flat * coord_strides).sum({-1});
1770     Tensor mut_out_flat = mut_out_contig.flatten();
1771     Tensor src_flat = at::as_strided(
1772       src,
1773       index.sizes(),
1774       src.strides()
1775     ).flatten();
1776 
1777     torch::List<std::optional<Tensor>> indices;
1778     indices.reserve(1);
1779     indices.push_back(index_flat);
1780 
1781     mut_out_flat.index_put_(indices, src_flat, accumulate);
1782 
1783     if (!mut_out.is_contiguous()) {
1784       mut_out.copy_(mut_out_flat.reshape(mut_out.sizes()));
1785     }
1786   }
1787 }
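// The flattening above turns an N-d scatter into a 1-d index_put_: for every
// element, the linear offset into the contiguous output is
//
//   index_flat[e] = sum_k index_coords[e][k] * mut_out_contig.stride(k)
//
// where the coordinate along `dim` comes from `index` and every other
// coordinate is an arange over that dimension. index_put_ with
// accumulate == true then provides the deterministic scatter path.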
1788 
1789 template <bool use_new_options = false, typename T, typename ReduceStub, typename FillStub>
1790 void scatter_impl(
1791     const Tensor& self,
1792     int64_t dim,
1793     const Tensor& index,
1794     const T& src,
1795     const Tensor& out,
1796     ReduceStub& reduce_stub,
1797     FillStub& fill_stub,
1798     const std::optional<c10::string_view> reduce = std::nullopt,
1799     bool reduce_includes_self = true) {
1800 
1801   dim = at::maybe_wrap_dim(dim, self.dim());
1802   auto mut_out = const_cast<Tensor&>(out);
1803 
1804   if (!self.is_same(mut_out)) {
1805     mut_out.copy_(self);
1806   }
1807 
1808   if (index.numel() == 0) return;
1809 
1810   auto op = ReductionType::SUM;
1811   bool deterministic = globalContext().deterministicAlgorithms() && (self.device().type() == DeviceType::CUDA || self.device().type() == DeviceType::XPU);
1812 
1813   if (reduce.has_value()) {
1814     op = get_operator_enum(reduce.value(), use_new_options);
1815     if (!reduce_includes_self) {
1816       // scatter the reduction's initial value to the appropriate indices (used by scatter_reduce.two)
1817       scatter_reduce_exclude_self_helper(mut_out, dim, index, op);
1818     }
1819     // _scatter_via_index_put can only handle the sum and mean reduction types
1820     deterministic = deterministic && (op == ReductionType::SUM || op == ReductionType::MEAN);
1821   }
1822 
1823   // Scalar src should already be deterministic
1824   if (deterministic && std::is_same_v<T, Tensor>) {
1825     // both the runtime and the compile-time check are required
1826     if constexpr (std::is_same_v<T, Tensor>) {
1827       bool accumulate = reduce.has_value();
1828       _scatter_via_index_put(self, dim, index, src, mut_out, accumulate);
1829       return;
1830     }
1831   }
1832 
1833   if (reduce.has_value()) {
1834     reduce_stub(self.device().type(), mut_out, dim, index, src, op);
1835   } else {
1836     fill_stub(self.device().type(), mut_out, dim, index, src);
1837   }
1838 }
1839 
1840 TORCH_IMPL_FUNC(scatter_src_out)
1841 (const Tensor& self,
1842  int64_t dim,
1843  const Tensor& index,
1844  const Tensor& src,
1845  const Tensor& out) {
1846   scatter_impl(self, dim, index, src, out,
1847                scatter_reduce_stub,
1848                scatter_stub);
1849 }
1850 
1851 TORCH_IMPL_FUNC(scatter_value_out)
1852 (const Tensor& self,
1853  int64_t dim,
1854  const Tensor& index,
1855  const Scalar& value,
1856  const Tensor& out) {
1857   scatter_impl(self, dim, index, value, out,
1858                scatter_scalar_reduce_stub,
1859                scatter_fill_stub);
1860 }
1861 
1862 TORCH_IMPL_FUNC(scatter_reduce_out)
1863 (const Tensor& self,
1864  int64_t dim,
1865  const Tensor& index,
1866  const Tensor& src,
1867  const c10::string_view reduce,
1868  const Tensor& out) {
1869   scatter_impl(self, dim, index, src, out,
1870                scatter_reduce_stub,
1871                scatter_stub,
1872                reduce);
1873 }
1874 
1875 TORCH_IMPL_FUNC(scatter_value_reduce_out)
1876 (const Tensor& self,
1877  int64_t dim,
1878  const Tensor& index,
1879  const Scalar& value,
1880  const c10::string_view reduce,
1881  const Tensor& out) {
1882   scatter_impl(self, dim, index, value, out,
1883                scatter_scalar_reduce_stub,
1884                scatter_fill_stub,
1885                reduce);
1886 }
1887 
1888 TORCH_IMPL_FUNC(scatter_add)
1889 (const Tensor& self,
1890  int64_t dim,
1891  const Tensor& index,
1892  const Tensor& src,
1893  const Tensor& out) {
1894   auto mut_out = const_cast<Tensor&>(out);
1895   dim = maybe_wrap_dim(dim, self.dim());
1896 
1897   if (!self.is_same(mut_out)) {
1898     mut_out.copy_(self);
1899   }
1900 
1901   if (index.numel() == 0) return;
1902 
1903   // See Note [Enabling Deterministic Operations]
1904   // Avoid gpuAtomicAdd for CUDA if deterministic mode is turned on
1905   if (globalContext().deterministicAlgorithms() && self.device().type() == DeviceType::CUDA) {
1906     _scatter_via_index_put(self, dim, index, src, mut_out, /*accumulate*/true);
1907   } else {
1908     if (can_use_expanded_index_path(mut_out, dim, index, src, /*is_scatter_like*/true)) {
1909       scatter_add_expanded_index_stub(self.device().type(), mut_out, index, src);
1910     } else {
1911       scatter_add_stub(self.device().type(), mut_out, dim, index, src);
1912     }
1913   }
1914 }
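// Minimal sketch of the operation dispatched above (hypothetical values):
//
//   auto self  = at::zeros({3});
//   auto index = at::tensor({0, 0, 2}, at::kLong);
//   auto src   = at::tensor({1.f, 2.f, 3.f});
//   auto out   = self.scatter_add(0, index, src);   // {3., 0., 3.}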
1915 
1916 TORCH_IMPL_FUNC(scatter_reduce_two)
1917 (const Tensor& self,
1918  int64_t dim,
1919  const Tensor& index,
1920  const Tensor& src,
1921  const c10::string_view reduce,
1922  bool include_self,
1923  const Tensor& out) {
1924 
1925   dim = at::maybe_wrap_dim(dim, self.dim());
1926 
1927   if (!self.is_same(out)) {
1928     out.copy_(self);
1929   }
1930 
1931   const auto op = get_operator_enum(reduce, true);
1932 
1933   if (can_use_expanded_index_path(out, dim, index, src, /*is_scatter_like*/true)) {
1934     scatter_reduce_expanded_index_stub(self.device().type(), out, index, src, op, include_self);
1935     return;
1936   }
1937 
1938   scatter_impl</*use_new_options=*/true>(self, dim, index, src, out,
1939                                          scatter_reduce_two_stub,
1940                                          scatter_stub,
1941                                          reduce,
1942                                          include_self);
1943 
1944   if (op == ReductionType::MEAN) {
1945     auto ones = at::ones_like(src);
1946     auto count = include_self ? at::ones_like(out) : at::zeros_like(out);
1947     count.scatter_add_(dim, index, ones);
1948     count.masked_fill_(count == 0, 1);
1949 
1950     if (out.is_floating_point() || out.is_complex()) {
1951       out.div_(count);
1952     } else {
1953       out.div_(count, "floor");
1954     }
1955   }
1956 }
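// For reduce == "mean" the post-processing above amounts to (sketch):
//
//   count = (include_self ? ones_like(out) : zeros_like(out))
//             .scatter_add_(dim, index, ones_like(src));
//   count.masked_fill_(count == 0, 1);
//   out /= count;   // floor division for integral dtypes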
1957 
1958 Tensor masked_scatter(const Tensor & self, const Tensor & mask, const Tensor & source) {
1959   auto [_mask, _self] = expand_outplace(mask, self);
1960   return _self->clone(at::MemoryFormat::Contiguous).masked_scatter_(*_mask, source);
1961 }
1962 
1963 Tensor masked_scatter_backward_symint(
1964     const Tensor& grad,
1965     const Tensor& mask,
1966     c10::SymIntArrayRef sizes) {
1967   c10::SymInt numel = 1;
1968   for (const auto& size : sizes) {
1969     numel *= size;
1970   }
1971   auto mask_selected = grad.masked_select(mask);
1972   auto diff_nelem = numel - mask_selected.sym_numel();
1973   if (diff_nelem > 0) {
1974     // because masked_select returns a 1-d tensor containing only the elements
1975     // where the mask is true, we need to pad the rest with zeros and then
1976     // reshape back to the original size given by `sizes`.
1977     auto zeros_fillin =
1978         at::zeros_symint({std::move(diff_nelem)}, grad.options());
1979     mask_selected = at::cat({mask_selected, std::move(zeros_fillin)}, 0);
1980   }
1981   return mask_selected.view_symint(sizes);
1982 }
1983 
1984 static Tensor & masked_fill_impl_cpu(Tensor & self, const Tensor & mask, const Scalar& value) {
1985   NoNamesGuard guard;
1986   TORCH_CHECK(mask.dtype() == ScalarType::Bool, "masked_fill_ only supports boolean masks, but got mask "
1987       "with dtype ", mask.dtype());
1988 
1989   if (at::has_internal_overlap(self) == MemOverlap::Yes) {
1990     TORCH_WARN(
1991       "Use of masked_fill_ on expanded tensors is deprecated. "
1992       "Please clone() the tensor before performing this operation. "
1993       "This also applies to advanced indexing e.g. tensor[mask] = scalar");
1994   }
1995   at::assert_no_partial_overlap(self, mask);
1996 
1997   auto iter = TensorIteratorConfig()
1998     .set_check_mem_overlap(false)  // deprecated, but not a hard error
1999     .check_all_same_dtype(false)
2000     .resize_outputs(false)
2001     .add_output(self)
2002     .add_const_input(mask)
2003     .build();
2004 
2005   masked_fill_stub(iter.device_type(), iter, value);
2006   return self;
2007 }
2008 
2009 Tensor & masked_fill__cpu(Tensor& self, const Tensor & mask, const Scalar& value) {
2010   auto maybe_outnames = namedinference::broadcast_to_outnames(self, mask, "masked_fill_");
2011 
2012   masked_fill_impl_cpu(self, mask, value);
2013   namedinference::propagate_names_if_nonempty(self, maybe_outnames);
2014   return self;
2015 }
2016 
2017 Tensor & masked_fill__cpu(Tensor& self, const Tensor & mask, const Tensor & value) {
2018   auto maybe_outnames = namedinference::broadcast_to_outnames(self, mask, "masked_fill_");
2019   TORCH_CHECK(value.dim() == 0, "masked_fill_ only supports a 0-dimensional value tensor, but got tensor "
2020       "with ", value.dim(), " dimension(s).");
2021 
2022   masked_fill_impl_cpu(self, mask, value.item());
2023   namedinference::propagate_names_if_nonempty(self, maybe_outnames);
2024   return self;
2025 }
2026 
2027 Tensor masked_fill(const Tensor & self, const Tensor & mask, const Scalar& source) {
2028   Tensor result;
2029   auto maybe_outnames = namedinference::broadcast_to_outnames(mask, self, "masked_fill");
2030   {
2031     NoNamesGuard guard;
2032     auto [_mask, _self] = expand_outplace(mask, self);
2033     result = _self->clone(at::MemoryFormat::Contiguous);
2034     result.masked_fill_(mask, source);
2035   }
2036   namedinference::propagate_names_if_nonempty(result, maybe_outnames);
2037   return result;
2038 }
2039 
2040 Tensor masked_fill(const Tensor & self, const Tensor & mask, const Tensor & source) {
2041   Tensor result;
2042   auto maybe_outnames = namedinference::broadcast_to_outnames(mask, self, "masked_fill");
2043   {
2044     NoNamesGuard guard;
2045     auto [_mask, _self] = expand_outplace(mask, self);
2046     result = _self->clone(at::MemoryFormat::Contiguous);
2047     result.masked_fill_(mask, source);
2048   }
2049   namedinference::propagate_names_if_nonempty(result, maybe_outnames);
2050   return result;
2051 }
2052 
2053 static Tensor & masked_select_out_impl_cpu(Tensor & result, const Tensor & self, const Tensor & mask) {
2054   NoNamesGuard guard;
2055 
2056   TORCH_CHECK(mask.scalar_type() == ScalarType::Bool,
2057               "masked_select: expected BoolTensor for mask");
2058   TORCH_CHECK(self.scalar_type() == result.scalar_type(),
2059               "masked_select(): self and result must have the same scalar type");
2060 
2061   at::assert_no_internal_overlap(result);
2062   at::assert_no_overlap(result, self);
2063   at::assert_no_overlap(result, mask);
2064 
2065   auto [_mask, _self] = expand_outplace(mask, self);
2066 
2067   auto shape = _self->sizes();
2068   int64_t numel = _mask->sum().item().toLong();
2069   at::native::resize_output(result, {numel});
2070   if (numel == 0) {
2071     return result;
2072   }
2073 
2074   // Create strided view of result before feeding into TensorIterator
2075   auto strides = DimVector(shape.size(), 0);
2076   auto orig_stride = result.strides()[0];
2077   auto result_strided = result.as_strided(shape, strides);
2078 
2079   // serial kernel
2080   // serial kernel requires that src is traversed in its logical order. However, TensorIterator might
2081   // have reordered dimensions so that src would be traversed in its physical order, producing wrong
2082   // answers. A sufficient condition that no reorder happened is that both _self and _mask are contiguous.
2083   // If that is not satisfied, use the parallel kernel, which handles permutations correctly.
2084   bool use_serial_kernel = (self.numel() < at::internal::GRAIN_SIZE || at::get_num_threads() == 1 ) &&
2085   _self->is_contiguous() && _mask->is_contiguous();
2086   if (use_serial_kernel) {
2087     auto iter = TensorIteratorConfig()
2088       .set_check_mem_overlap(false)  // result is intentionally zero-strided above
2089       .check_all_same_dtype(false)
2090       .resize_outputs(false)
2091       .add_output(result_strided)
2092       .add_const_input(*_self)
2093       .add_const_input(*_mask)
2094       .build();
2095 
2096     masked_select_serial_stub(iter.device_type(), iter, orig_stride);
2097     return result;
2098   }
2099 
2100   // Use a prefix sum to record the output locations of the masked elements,
2101   // so that the writes can be parallelized with TensorIterator.
2102   auto mask_long = at::empty(shape, self.options().dtype(at::kLong)).copy_(*_mask);
2103   auto mask_prefix_sum = at::empty(shape, self.options().dtype(at::kLong));
2104   auto mask_long_data = mask_long.data_ptr<int64_t>();
2105   auto mask_prefix_sum_data = mask_prefix_sum.data_ptr<int64_t>();
2106   // TODO: std::partial_sum is all we can use under C++14;
2107   // switch to std::exclusive_scan, which has better performance, once PyTorch upgrades to C++17.
2108   // std::exclusive_scan(mask_long_data, mask_long_data + mask_long.numel(), mask_prefix_sum_data, 0);
2109   std::partial_sum(mask_long_data, mask_long_data + mask_long.numel(), mask_prefix_sum_data);
2110 
2111   auto iter = TensorIteratorConfig()
2112     .set_check_mem_overlap(false)  // result is intentionally zero-strided above
2113     .check_all_same_dtype(false)
2114     .resize_outputs(false)
2115     .add_output(result_strided)
2116     .add_const_input(*_self)
2117     .add_const_input(*_mask)
2118     .add_const_input(mask_prefix_sum)
2119     .build();
2120 
2121   masked_select_stub(iter.device_type(), iter, orig_stride);
2122   return result;
2123 }
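// Sketch of how the prefix sum drives the parallel kernel above: for a mask
// {1, 0, 1, 1} the inclusive prefix sum is {1, 1, 2, 3}, so the masked element
// at position i writes to output slot mask_prefix_sum[i] - 1, independently of
// every other element.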
2124 
2125 Tensor & masked_select_out_cpu(const Tensor & self, const Tensor & mask, Tensor & result) {
2126   namedinference::compute_broadcast_outnames(self, mask);
2127   return masked_select_out_impl_cpu(result, self, mask);
2128 }
2129 
2130 Tensor masked_select_cpu(const Tensor & self, const Tensor & mask) {
2131   Tensor result = at::empty({0}, self.options());
2132   return at::native::masked_select_out_cpu(self, mask, result);
2133 }
2134 
masked_select_backward(const Tensor & grad,const Tensor & input,const Tensor & mask)2135 Tensor masked_select_backward(const Tensor& grad, const Tensor& input, const Tensor& mask) {
2136   // The following could just be written as `zeros_like(input).masked_scatter(mask, grad)`.
2137   // However, as an optimization, we call the in-place variant of masked_scatter.
2138   // Unfortunately, that doesn't allow for the broadcasting of the LHS, so we need
2139   // to explicitly broadcast here (the out-of-place variant of masked_scatter
2140   // implicitly handles broadcasting).
2141   auto result = at::zeros_like(
2142       input.expand(at::infer_size(input.sizes(), mask.sizes())), at::MemoryFormat::Preserve);
2143 
2144   // for composite compliance, use out-of-place variant
2145   // of `masked_scatter`.
2146   if (areAnyTensorSubclassLike({grad, mask})) {
2147     return result.masked_scatter(mask, grad);
2148   }
2149   result.masked_scatter_(mask, grad);
2150   return result;
2151 }
2152 
2153 namespace {
2154 
2155 inline std::tuple<Tensor, Tensor, int64_t> _take_along_dim_helper(
2156     const Tensor& self,
2157     const Tensor& indices,
2158     int64_t dim) {
2159   TORCH_CHECK(
2160       self.dim() == indices.dim(),
2161       "torch.take_along_dim(): input and indices should have the same number of dimensions, ",
2162       "but got ", self.dim(), " dimensions for input, and ", indices.dim(), " dimensions for indices")
2163   TORCH_CHECK(
2164       indices.scalar_type() == ScalarType::Long,
2165       "torch.take_along_dim(): dtype of indices should be Long but got ", indices.scalar_type())
2166 
2167   dim = at::maybe_wrap_dim(dim, self.dim());
2168 
2169   SymDimVector self_sizes{self.sym_sizes()};
2170   // update number of elements at dim as per indices
2171   self_sizes[dim] = indices.sym_size(dim);
2172   auto broadcast_shape = infer_size_symint(self_sizes, indices.sym_sizes());
2173   auto indices_broadcasted = at::broadcast_to_symint(indices, broadcast_shape);
2174 
2175   SymDimVector indices_sizes{indices.sym_sizes()};
2176   // update number of elements at dim as per self
2177   indices_sizes[dim] = self.sym_size(dim);
2178   broadcast_shape = infer_size_symint(indices_sizes, self.sym_sizes());
2179   auto self_broadcasted = at::broadcast_to_symint(self, broadcast_shape);
2180 
2181   return std::make_tuple(std::move(self_broadcasted),
2182                          std::move(indices_broadcasted),
2183                          std::move(dim));
2184 }
2185 
2186 static inline void checkDevice(CheckedFrom c, const Tensor& t, Device device) {
2187   TORCH_CHECK(
2188       !t.defined() || t.device() == device,
2189       "Expected tensor to have ", device,
2190       " Device, but got tensor with ", t.device(), " Device ",
2191       "(while checking arguments for ", c, ")");
2192 }
2193 
2194 static inline void checkDevice(CheckedFrom c, at::ArrayRef<Tensor> tensors, Device device) {
2195   for (auto &t : tensors) {
2196     checkDevice(c, t, device);
2197   }
2198 }
2199 
2200 } // anonymous namespace
2201 
take_along_dim(const Tensor & self,const Tensor & indices,std::optional<int64_t> opt_dim)2202 Tensor take_along_dim(const Tensor& self, const Tensor& indices, std::optional<int64_t> opt_dim) {
2203   checkDevice("torch.take_along_dim():", {self, indices}, self.device());
2204   if (opt_dim.has_value()) {
2205     auto [self_broadcasted, indices_broadcasted, dim] =
2206         _take_along_dim_helper(self, indices, opt_dim.value());
2207     return self_broadcasted.gather(dim, indices_broadcasted);
2208   }
2209 
2210   // similar to `take`, but `take` doesn't support the same dtypes as `gather`.
2211   return self.view(-1).gather(0, indices.view(-1));
2212 }
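// Illustrative use of the dim path above (hypothetical shapes):
//
//   auto x   = at::rand({4, 3});
//   auto idx = x.argsort(/*dim=*/1);                    // shape [4, 3], kLong
//   auto y   = at::take_along_dim(x, idx, /*dim=*/1);   // same as x.gather(1, idx)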
2213 
take_along_dim_out(const Tensor & self,const Tensor & indices,std::optional<int64_t> opt_dim,Tensor & result)2214 Tensor& take_along_dim_out(const Tensor& self, const Tensor& indices, std::optional<int64_t> opt_dim, Tensor& result) {
2215   checkDevice("torch.take_along_dim():", {self, indices, result}, self.device());
2216   if (opt_dim.has_value()) {
2217     auto [self_broadcasted, indices_broadcasted, dim] =
2218         _take_along_dim_helper(self, indices, opt_dim.value());
2219     return at::gather_out(result, self_broadcasted, dim, indices_broadcasted);
2220   }
2221 
2222   // similar to `take`, but `take` doesn't support the same dtypes as `gather`.
2223   return at::gather_out(result, self.view(-1), 0, indices.view(-1));
2224 }
2225 
_gather_sparse_backward(const Tensor & self,int64_t dim,const Tensor & index,const Tensor & grad)2226 Tensor _gather_sparse_backward(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& grad){
2227 // special case scalar input and/or index
2228     if (self.ndimension() == 0) return at::_sparse_coo_tensor_unsafe_symint(at::empty_symint({0,grad.sym_numel()}, index.options()), grad, self.sym_sizes());
2229     if (grad.ndimension() == 0) return at::_sparse_coo_tensor_unsafe_symint(index.view({1,1}), grad, self.sym_sizes());
2230     Tensor sparse_ind = at::empty_symint({self.ndimension(), grad.sym_numel()}, self.options().dtype(at::kLong));
2231     SymInt grad_numel = grad.sym_numel();
2232     if (grad_numel > 0) {
2233       SymInt n_above = grad_numel;
2234       SymInt n_below = 1;
2235       if (dim < 0) dim += self.ndimension();
2236       for (const auto i : c10::irange(self.ndimension())) {
2237           n_above /= grad.sym_size(i);
2238           if (i == dim) {
2239               sparse_ind[i] = index.reshape(-1);
2240           } else {
2241               sparse_ind[i] = at::arange(grad.sym_size(i),self.options().dtype(at::kLong)).unsqueeze(1).expand_symint({grad.sym_size(i), n_above}).reshape(-1).repeat_symint(n_below);
2242           }
2243           n_below *= grad.sym_size(i);
2244       }
2245     }
2246     return at::_sparse_coo_tensor_unsafe_symint(sparse_ind, grad.reshape(-1), self.sym_sizes());
2247 }
2248 
2249 template <typename scalar_t>
count_nonzero_impl(TensorIteratorBase & iter,Range range)2250 int64_t count_nonzero_impl(TensorIteratorBase& iter, Range range) {
2251   int64_t num_nonzero = 0;
2252 
2253   auto loop = [&](char** data, const int64_t* strides, int64_t n) {
2254     constexpr int ilp_factor = 4;
2255     const char* ptr = data[0];
2256     const auto stride = strides[0];
2257     int64_t nonzero[ilp_factor] = {0};
2258 
2259     int64_t i = 0;
2260     for (; i + (ilp_factor - 1) < n; i += ilp_factor) {
2261       c10::ForcedUnroll<ilp_factor>{}([&](int k) {
2262         const auto& val = c10::load<scalar_t>(ptr + k * stride);
2263         if (val != scalar_t(0)) {
2264           ++nonzero[k];
2265         }
2266       });
2267       ptr += ilp_factor * stride;
2268     }
2269     for (; i < n; ++i) {
2270       const auto& val = c10::load<scalar_t>(ptr);
2271       if (val != scalar_t(0)) {
2272         ++nonzero[0];
2273       }
2274       ptr += stride;
2275     }
2276     for (const auto k : c10::irange(1, ilp_factor)) {
2277       nonzero[0] += nonzero[k];
2278     }
2279     num_nonzero += nonzero[0];
2280   };
2281   iter.serial_for_each(loop, range);
2282 
2283   return num_nonzero;
2284 }
2285 
count_nonzero_cuda(const Tensor & self,IntArrayRef dims)2286 Tensor count_nonzero_cuda(const Tensor& self, IntArrayRef dims){
2287   auto reduce = self;
2288   if (reduce.scalar_type() != kBool) {
2289     reduce = reduce != 0;
2290   }
2291   return reduce.sum(dims);
2292 }
2293 
count_nonzero_cpu(const Tensor & self,IntArrayRef dims)2294 Tensor count_nonzero_cpu(const Tensor& self, IntArrayRef dims){
2295   if (!dims.empty()) {
2296     auto reduce = self;
2297     if (reduce.scalar_type() != kBool) {
2298       reduce = reduce != 0;
2299     }
2300     return reduce.sum(dims);
2301   }
2302 
2303   // Optimized all-reduce
2304   auto iter = TensorIteratorConfig()
2305       .add_const_input(self)
2306       .build();
2307 
2308   const auto num_threads = at::get_num_threads();
2309   DimVector thread_count_nonzero(num_threads);
2310 
2311   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
2312       kComplexHalf, kHalf, kBFloat16, kBool, self.scalar_type(), "nonzero_count_cpu", [&] {
2313     at::parallel_for(0, iter.numel(), internal::GRAIN_SIZE, [&] (int64_t begin, int64_t end) {
2314       const auto tid = at::get_thread_num();
2315       thread_count_nonzero[tid] = count_nonzero_impl<scalar_t>(iter, {begin, end});
2316     });
2317   });
2318 
2319   for (const auto i : c10::irange(1, num_threads)) {
2320     thread_count_nonzero[0] += thread_count_nonzero[i];
2321   }
2322   auto out = at::empty({}, self.options().dtype(kLong));
2323   *out.mutable_data_ptr<int64_t>() = thread_count_nonzero[0];
2324   return out;
2325 }
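// Minimal sketch of the all-reduce path above (hypothetical input):
//
//   auto t = at::tensor({0.f, 2.f, 0.f, 5.f});
//   auto n = at::count_nonzero(t);   // 0-dim kLong tensor holding 2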
2326 
2327 
count_nonzero(const Tensor & self,std::optional<int64_t> dim)2328 Tensor count_nonzero(const Tensor& self, std::optional<int64_t> dim) {
2329   if (dim) {
2330     return at::count_nonzero(self, IntArrayRef{*dim});
2331   }
2332   return at::count_nonzero(self, IntArrayRef{});
2333 }
2334 
2335 
nonzero_out_cpu(const Tensor & self,Tensor & result)2336 Tensor& nonzero_out_cpu(const Tensor& self, Tensor& result) {
2337   TORCH_CHECK(result.scalar_type() == kLong,
2338               "nonzero: Expected out tensor to have scalar type Long "
2339               "but got scalar type", result.scalar_type());
2340   at::assert_no_internal_overlap(result);
2341   at::assert_no_overlap(result, self);
2342 
2343   auto iter = TensorIteratorConfig()
2344     .add_const_input(self)
2345     .enforce_linear_iteration()
2346     .build();
2347 
2348   const auto numel = iter.numel();
2349   const auto num_threads = at::get_num_threads();
2350   DimVector thread_begin(num_threads, -1);
2351   DimVector thread_count_nonzero(num_threads + 1);
2352 
2353   // Pass 1: Count nonzero elements per thread
2354   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
2355       kComplexHalf, kHalf, kBFloat16, kBool, self.scalar_type(), "nonzero_count_cpu", [&] {
2356     at::parallel_for(0, numel, internal::GRAIN_SIZE, [&] (int64_t begin, int64_t end) {
2357       const auto tid = at::get_thread_num();
2358       thread_begin[tid] = begin;
2359       thread_count_nonzero[tid + 1] = count_nonzero_impl<scalar_t>(iter, {begin, end});
2360     });
2361   });
2362 
2363   // Convert thread-local counts to cumulative sum
2364   for (const auto i : c10::irange(1, thread_count_nonzero.size())) {
2365     thread_count_nonzero[i] += thread_count_nonzero[i - 1];
2366   }
2367 
2368   const auto self_sizes = self.sizes();
2369   const auto total_nonzero = thread_count_nonzero.back();
2370   const int64_t ndim = self_sizes.size();
2371   if (resize_output(result, {total_nonzero, ndim})) {
2372     // Default to fortran-contiguous output (see gh-46224)
2373     result.as_strided_({total_nonzero, ndim}, {1, total_nonzero});
2374   }
2375 
2376   if (result.numel() == 0) {
2377     return result;
2378   }
2379 
2380   auto out_accessor = result.accessor<int64_t, 2>();
2381 
2382   // Pass 2: Write indexes
2383   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
2384       kComplexHalf, kHalf, kBFloat16, kBool, self.scalar_type(), "nonzero_cpu", [&] {
2385     at::parallel_for(0, numel, internal::GRAIN_SIZE, [&] (int64_t begin, int64_t end) {
2386       auto tid = at::get_thread_num();
2387       // Work needs to be distributed the same way on both passes
2388       TORCH_INTERNAL_ASSERT_DEBUG_ONLY(begin == thread_begin[tid]);
2389 
2390       // The extra leading element (+1) acts as a sentinel; it is faster than an additional condition check inside the loop
2391       c10::SmallVector<int64_t, 33> sizes(ndim + 1, -1);
2392       std::copy(self_sizes.begin(), self_sizes.end(), sizes.begin() + 1);
2393       c10::SmallVector<int64_t, 33> current_idx(ndim + 1);
2394       if (begin > 0) {
2395         auto idx = begin;
2396         for (int64_t k = ndim; idx > 0 && k > 0; --k) {
2397           current_idx[k] = idx % sizes[k];
2398           idx /= sizes[k];
2399         }
2400       }
2401 
2402       auto out_ptr = out_accessor[thread_count_nonzero[tid]].data();
2403 
2404       auto loop = [&](char** data, const int64_t* strides, int64_t n1, int64_t n2) {
2405         // Copy into local variables to improve compiler alias analysis
2406         int64_t* C10_RESTRICT local_idx = current_idx.data() + 1;
2407         const int64_t* C10_RESTRICT local_sizes = sizes.data() + 1;
2408         const auto in_stride = strides[0];
2409         const auto out_stride1 = out_accessor.stride(1);
2410         const auto out_stride0 = out_accessor.stride(0) - ndim * out_stride1;
2411         const auto ndim = out_accessor.size(1);
2412         int64_t* out = out_ptr;
2413 
2414         for (const auto i : c10::irange(n2)) {
2415           const char* ptr = data[0] + i * strides[1];
2416           for (C10_UNUSED const auto j : c10::irange(n1)) {
2417             const auto& val = c10::load<scalar_t>(ptr);
2418             // If nonzero, write index
2419             if (val != scalar_t(0)) {
2420               for (const auto k : c10::irange(ndim)) {
2421                 *out = local_idx[k];
2422                 out += out_stride1;
2423               }
2424               out += out_stride0;
2425             }
2426             ptr += in_stride;
2427 
2428             // Advance current index
2429             int64_t k = ndim - 1;
2430             ++local_idx[k];
2431             while (C10_UNLIKELY(local_idx[k] == local_sizes[k])) {
2432               local_idx[k] = 0;
2433               --k;
2434               ++local_idx[k];
2435             }
2436           }
2437         }
2438         out_ptr = out;
2439       };
2440       iter.serial_for_each(loop, {begin, end});
2441       TORCH_INTERNAL_ASSERT(out_ptr == out_accessor[thread_count_nonzero[tid + 1]].data());
2442     });
2443   });
2444   return result;
2445 }
2446 
nonzero_cpu(const Tensor & self)2447 Tensor nonzero_cpu(const Tensor& self) {
2448   auto result = at::empty({0}, self.options().dtype(kLong));
2449   nonzero_out_cpu(self, result);
2450   return result;
2451 }
2452 
nonzero_static_out_cpu(const Tensor & self,int64_t size,int64_t fill_value,Tensor & result)2453 Tensor& nonzero_static_out_cpu(
2454     const Tensor& self,
2455     int64_t size,
2456     int64_t fill_value,
2457     Tensor& result) {
2458   // `size` must be non-negative
2459   TORCH_CHECK(
2460       size >= 0, "nonzero_static: 'size' must be a non-negative integer");
2461   TORCH_CHECK(
2462       result.scalar_type() == kLong,
2463       "nonzero_static: Expected out tensor to have scalar type Long "
2464       "but got scalar type ",
2465       result.scalar_type());
2466 
2467   int64_t ndim = self.dim();
2468   if (result.dim() != 2 || result.size(0) != size || result.size(1) != ndim) {
2469     at::native::resize_output(result, {size, ndim});
2470   }
2471   // Verify that the output tensor has been resized to the expected shape (size, ndim)
2472   TORCH_CHECK(
2473       result.dim() == 2,
2474       "nonzero_static: Expected out tensor to be a 2D tensor but got a ",
2475       result.dim(),
2476       "D tensor");
2477   TORCH_CHECK(
2478       result.size(0) == size && result.size(1) == ndim,
2479       "nonzero_static: Expected out tensor to have Size([",
2480       size,
2481       ", ",
2482       ndim,
2483       "]) but got Size([",
2484       result.size(0),
2485       ", ",
2486       result.size(1),
2487       "])");
2488   at::assert_no_internal_overlap(result);
2489   at::assert_no_overlap(result, self);
2490 
2491   // Return early if either dim is 0
2492   if (result.size(0) == 0 || result.size(1) == 0) {
2493     return result;
2494   }
2495 
2496   // Delegate call to regular nonzero to get a data-dependent output
2497   auto dyn_result = nonzero_cpu(self);
2498   int64_t num_nonzeros = dyn_result.size(0);
2499   int64_t copy_len = std::min(size, num_nonzeros);
2500   // Copy the dynamic result to the fixed-size tensor
2501   result.narrow(0, 0, copy_len).copy_(dyn_result.narrow(0, 0, copy_len));
2502   if (size > copy_len) {
2503     // Pad result with `fill_value`
2504     result.narrow(0, copy_len, size - copy_len).fill_(fill_value);
2505   }
2506   return result;
2507 }
2508 
2509 Tensor nonzero_static_cpu(
2510     const Tensor& self,
2511     int64_t size,
2512     int64_t fill_value) {
2513   // `size` must be non-negative
2514   TORCH_CHECK(
2515       size >= 0, "nonzero_static: 'size' must be a non-negative integer");
2516   // Allocate fixed-size out tensor
2517   int64_t ndim = self.dim();
2518   auto result = at::empty(
2519       {size, ndim},
2520       at::TensorOptions().dtype(at::ScalarType::Long).device(at::kCPU));
2521   nonzero_static_out_cpu(self, size, fill_value, result);
2522   return result;
2523 }
2524 
2525 std::vector<Tensor> nonzero_numpy(const Tensor& self) {
2526   // special case scalar for compatibility with numpy:
2527   //
2528   // >>> np.array(5).nonzero()
2529   // (array([0]),)
2530   // >>> np.array(0).nonzero()
2531   // (array([], dtype=int64),)
2532 
2533   if (self.dim() == 0) {
2534     return self.unsqueeze(0).nonzero().unbind(1);
2535   }
2536 
2537   return self.nonzero().unbind(1);
2538 }
2539 
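// argwhere returns an (n, ndim) tensor holding the index of every nonzero
// element; it is a thin wrapper around nonzero (cf. numpy.argwhere).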
2540 Tensor argwhere(const Tensor& self) {
2541   return self.nonzero();
2542 }
2543 
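// masked_scatter_ copies elements of `source` into `self` at the positions
// where `mask` is true, consuming `source` in linear (row-major) order. A
// sketch of the Python-level behavior (illustrative values):
//
//   >>> x = torch.zeros(5)
//   >>> x.masked_scatter_(torch.tensor([True, False, True, False, True]),
//   ...                   torch.tensor([1., 2., 3.]))
//   tensor([1., 0., 2., 0., 3.])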
2544 Tensor & masked_scatter__cpu(Tensor& self, const Tensor & mask, const Tensor & source) {
2545   at::assert_no_internal_overlap(self);
2546   TORCH_CHECK(
2547       self.scalar_type() == source.scalar_type(),
2548       "masked_scatter: expected self and source to have the same dtype but got ",
2549       self.scalar_type(),
2550       " and ",
2551       source.scalar_type());
2552 
2553   TORCH_CHECK(self.device().type() == at::kCPU, "device type of self (", self.device().type(), ") is not CPU");
2554   TORCH_CHECK(mask.device().type() == at::kCPU, "device type of mask (", mask.device().type(), ") is not CPU");
2555   TORCH_CHECK(source.device().type() == at::kCPU, "device type of source (", source.device().type(), ") is not CPU");
2556 
2557   c10::MaybeOwned<Tensor> b_mask = expand_inplace(self, mask, "masked_scatter_");
2558 
2559   if (b_mask->dtype() == ScalarType::Byte) {
2560     TORCH_WARN("masked_scatter_ received a mask with dtype torch.uint8; this behavior is now deprecated, " \
2561             "please use a mask with dtype torch.bool instead.");
2562   }
2563 
2564   auto src_cont = source.contiguous();
2565 
2566   auto iter = TensorIteratorConfig()
2567       .set_check_mem_overlap(false)
2568       .check_all_same_dtype(false)
2569       .resize_outputs(false)
2570       // order of indexing matters: source elements are consumed sequentially, so iteration must be linear
2571       .enforce_linear_iteration()
2572       .add_output(self)
2573       .add_const_input(*b_mask)
2574       .build();
2575 
2576   masked_scatter_stub(iter.device_type(), iter, src_cont);
2577   return self;
2578 }
2579 
2580 } // namespace at::native
2581