// Indexing tensors by tensors
//
// This corresponds to "advanced indexing" in NumPy. The two operations are:
//
// index(Tensor self, indices) -> Tensor
// index_put_(Tensor self, indices, value, accumulate=false)
//
// The index is a TensorList containing kLong, kBool or kByte tensors or
// nulls. Bool and byte tensors (boolean masks) are expanded to long tensors
// via nonzero(). Null tensors signify that the dimension is not indexed.
//
// All indexes are broadcast together and iterated as *one*. From NumPy:
//
// result[i_1, ..., i_M] == x[ind_1[i_1, ..., i_M], ind_2[i_1, ..., i_M],
//                           ..., ind_N[i_1, ..., i_M]]
//
// Note 1: ByteTensors expand to index as many dimensions as there are in the
// mask.
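//
// For example (an illustrative sketch, not taken from a particular caller):
// if x has shape (2, 3, 4) and mask is a kBool tensor of shape (2, 3), then
// x[mask] expands the mask via nonzero() into index tensors for the first two
// dimensions, giving a result of shape (k, 4), where k is the number of true
// elements in the mask.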
//
// Note 2: The behavior is more complicated when the index tensors are not all
// adjacent (e.g. x[[0, 1], :, [2, 3]]). In this case, self and the index
// tensors are transposed to the front: x.transpose(1, 2)[[0, 1], [2, 3]]
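//
// For example (illustrative shapes): with x of shape (4, 5, 6),
// x[[0, 1], :, [2, 3]] has shape (2, 5). The indexed dimensions 0 and 2 are
// moved to the front and collapsed into the broadcast index shape (2,),
// followed by the untouched dimension of size 5.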
//
// The code contains two implementations of indexing. The more efficient
// implementation treats indexing like an elementwise operation over the
// tensors `result`, `x`, `ind_1`, `ind_2`, etc. This implementation does
// not work for index_put_ with accumulate=True. The other implementation
// combines the indexed tensors into a single linear index that is used
// with Tensor.put_. This is used for index_put_ with accumulate=True.
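//
// As a rough sketch of the linear-index path (assuming a contiguous 3-D x
// indexed in all three dimensions; the real code handles arbitrary strides),
// the broadcast index tensors are combined into a single flat offset such as
//   linear_index = ind_1 * x.size(1) * x.size(2) + ind_2 * x.size(2) + ind_3
// and the update is applied with x.put_(linear_index, value, /*accumulate=*/true),
// which handles repeated indices by accumulating into the same location.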
//
// The more efficient implementation takes the following steps for the
// above operation:
//
// 1) Broadcast ind_1, ind_2, ind_3 together to a common shape
// 2) Record x.stride(i) for each indexed dimension `i`
// 3) Replace the indexed subspace of `x` with the shape of the corresponding
//    subspace of `result` but with stride 0
// 4) Add dimensions of size 1 to the index tensors (ind_1, ind_2, etc.) so
//    that their shape is compatible with the result shape
//
// The CPU or CUDA kernel then computes element-wise over the broadcasted
// and restrided result, x, ind_1, ind_2, etc.:
//
//   result[...] = *(&x[...] +
//                   ind_1[...] * x.stride(1) +
//                   ind_2[...] * x.stride(2) +
//                   ...)
//
// where & and * represent the C-style address-of and indirection operations.
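//
// A concrete walk-through of these steps (shapes are illustrative only):
// suppose x has shape (5, 6, 7) and we evaluate x[:, ind_1, ind_2], where
// ind_1 and ind_2 broadcast to shape (2, 3). Then:
//   1) the common index shape is (2, 3)
//   2) the strides of dimensions 1 and 2 of x are recorded
//   3) x is restrided to shape (5, 2, 3) with strides (x.stride(0), 0, 0)
//   4) ind_1 and ind_2 are reshaped to (1, 2, 3) so they broadcast against
//      the result, which has shape (5, 2, 3)
// The kernel then walks result, the restrided x and the reshaped index
// tensors in lockstep, applying the formula above at each element.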
// #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
51 #include <ATen/ATen.h>
52
53 #include <ATen/native/TensorAdvancedIndexing.h>
54 #include <ATen/native/IndexKernel.h>
55 #include <ATen/native/IndexingUtils.h>
56
57 #include <ATen/core/Tensor.h>
58 #include <ATen/core/IListRef.h>
59 #include <ATen/Context.h>
60 #include <ATen/Dispatch.h>
61 #include <ATen/ExpandUtils.h>
62 #include <ATen/MemoryOverlap.h>
63 #include <ATen/NamedTensorUtils.h>
64 #include <ATen/Parallel.h>
65 #include <ATen/TensorIterator.h>
66 #include <ATen/TensorMeta.h>
67 #include <ATen/TensorOperators.h>
68 #include <ATen/TensorUtils.h>
69 #include <ATen/WrapDimUtils.h>
70 #include <ATen/native/BinaryOps.h>
71 #include <ATen/native/Copy.h>
72 #include <ATen/native/Resize.h>
73 #include <ATen/native/ScatterGatherChecks.h>
74 #include <ATen/native/TensorAdvancedIndexingUtils.h>
75 #include <ATen/Parallel.h>
76 #include <ATen/NumericUtils.h>
77 #include <ATen/TensorSubclassLikeUtils.h>
78
79 #ifndef AT_PER_OPERATOR_HEADERS
80 #include <ATen/Functions.h>
81 #include <ATen/NativeFunctions.h>
82 #else
83 #include <ATen/ops/_gather_sparse_backward.h>
84 #include <ATen/ops/_gather_sparse_backward_native.h>
85 #include <ATen/ops/_index_put_impl.h>
86 #include <ATen/ops/_index_put_impl_native.h>
87 #include <ATen/ops/_sparse_coo_tensor_unsafe.h>
88 #include <ATen/ops/_unsafe_index_native.h>
89 #include <ATen/ops/_unsafe_index_put_native.h>
90 #include <ATen/ops/arange.h>
91 #include <ATen/ops/argwhere_native.h>
92 #include <ATen/ops/as_strided.h>
93 #include <ATen/ops/broadcast_to.h>
94 #include <ATen/ops/count_nonzero.h>
95 #include <ATen/ops/count_nonzero_native.h>
96 #include <ATen/ops/empty.h>
97 #include <ATen/ops/empty_quantized.h>
98 #include <ATen/ops/gather.h>
99 #include <ATen/ops/gather_backward_native.h>
100 #include <ATen/ops/gather_meta.h>
101 #include <ATen/ops/gather_native.h>
102 #include <ATen/ops/index.h>
103 #include <ATen/ops/index_add_meta.h>
104 #include <ATen/ops/index_add_native.h>
105 #include <ATen/ops/index_copy_meta.h>
106 #include <ATen/ops/index_copy_native.h>
107 #include <ATen/ops/index_fill_native.h>
108 #include <ATen/ops/index_meta.h>
109 #include <ATen/ops/index_native.h>
110 #include <ATen/ops/index_put_native.h>
111 #include <ATen/ops/index_reduce_meta.h>
112 #include <ATen/ops/index_reduce_native.h>
113 #include <ATen/ops/index_select_backward_native.h>
114 #include <ATen/ops/index_select_native.h>
115 #include <ATen/ops/masked_fill_native.h>
116 #include <ATen/ops/masked_scatter_native.h>
117 #include <ATen/ops/masked_select_backward_native.h>
118 #include <ATen/ops/masked_select_native.h>
119 #include <ATen/ops/nested_to_padded_tensor_native.h>
120 #include <ATen/ops/nonzero_native.h>
121 #include <ATen/ops/nonzero_numpy_native.h>
122 #include <ATen/ops/nonzero_static_native.h>
123 #include <ATen/ops/ones_like.h>
124 #include <ATen/ops/put_native.h>
125 #include <ATen/ops/quantize_per_tensor.h>
126 #include <ATen/ops/scatter_add_meta.h>
127 #include <ATen/ops/scatter_add_native.h>
128 #include <ATen/ops/scatter_meta.h>
129 #include <ATen/ops/scatter_native.h>
130 #include <ATen/ops/scatter_reduce_meta.h>
131 #include <ATen/ops/scatter_reduce_native.h>
132 #include <ATen/ops/take_along_dim_native.h>
133 #include <ATen/ops/take_native.h>
134 #include <ATen/ops/zeros_like.h>
135 #endif
136
137 #ifdef USE_FBGEMM
138 #include <fbgemm/Utils.h>
139 #endif
140
141 #include <c10/util/irange.h>
142 #include <c10/util/Unroll.h>
143
144 #include <algorithm>
145 #include <numeric>
146 #include <utility>
147 #include <vector>
148
149 namespace at::native {
150
151 std::string shapes_as_str(TensorList tensors);
152 AdvancedIndex make_info(Tensor self, IOptTensorListRef orig);
153
154 } // namespace at::native
155
156 namespace at::meta {
157
TORCH_META_FUNC(gather)
159 (const Tensor & self, int64_t dim, const Tensor & index, bool sparse_grad) {
160 const Tensor& result = maybe_get_output(0);
161 int64_t wrapped_dim = at::maybe_wrap_dim(dim, self.dim());
162
163 // Memory overlap checks need to be done after resizing (if required) is done.
164 // But it only makes sense to do these checks when result was defined, hence
165 // the boolean variable `check_result` here.
166 // For more details, see: https://github.com/pytorch/pytorch/pull/63312#discussion_r694794832
167 // and https://github.com/pytorch/pytorch/issues/63837
168 bool check_result = result.defined();
169 set_output_raw_strided(0, index.sizes(), {}, self.options());
170 if (check_result) {
171 at::assert_no_internal_overlap(result);
172 at::assert_no_overlap(result, self);
173 at::assert_no_partial_overlap(result, index);
174 }
175
176 auto is_index_empty = index.numel() == 0;
177 if (!is_index_empty) {
178 TORCH_CHECK(
179 index.scalar_type() == at::ScalarType::Long,
180 "gather", "(): Expected dtype int64 for index"
181 );
182 }
183 if (is_index_empty) return;
184 at::native::gather_shape_check(self, wrapped_dim, index);
185 }
186
187 template <bool use_new_options = false, typename Meta>
void scatter_meta_impl(
189 Meta& meta,
190 const Tensor& self,
191 int64_t dim,
192 const Tensor& index,
193 const std::optional<Tensor>& src = std::nullopt,
194 const std::optional<c10::string_view> reduce = std::nullopt) {
195 int64_t wrapped_dim = at::maybe_wrap_dim(dim, self.dim());
196 at::native::scatter_gather_dtype_check("scatter", self, index, src);
197 at::native::scatter_shape_check(self, wrapped_dim, index, src);
198 auto output = meta.maybe_get_output(0);
199
200 if (output.defined()) {
201 at::assert_no_internal_overlap(output);
202 at::assert_no_overlap(output, index);
203 if (src.has_value()) {
204 at::assert_no_overlap(output, src.value());
205 }
206 }
207
208 meta.set_output_raw_strided(0, self.sizes(), {}, self.options());
209 if (reduce.has_value()) {
210 // Check if we have a valid reduce operator.
211 at::native::get_operator_enum(reduce.value(), use_new_options);
212 }
213 }
214
TORCH_META_FUNC2(scatter, src)
216 (const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
217 scatter_meta_impl(*this, self, dim, index, src);
218 }
219
TORCH_META_FUNC2(scatter, value)
221 (const Tensor& self, int64_t dim, const Tensor& index, const Scalar& value) {
222 scatter_meta_impl(*this, self, dim, index);
223 }
224
TORCH_META_FUNC2(scatter, reduce)
226 (const Tensor& self,
227 int64_t dim,
228 const Tensor& index,
229 const Tensor& src,
230 const c10::string_view reduce) {
231 TORCH_WARN_ONCE(
232 "The reduce argument of torch.scatter with Tensor src is deprecated and will be removed ",
233 "in a future PyTorch release. Use torch.scatter_reduce instead for more reduction options."
234 );
235 scatter_meta_impl(*this, self, dim, index, src, reduce);
236 }
237
TORCH_META_FUNC2(scatter, value_reduce)
239 (const Tensor& self,
240 int64_t dim,
241 const Tensor& index,
242 const Scalar& src,
243 const c10::string_view reduce) {
244 scatter_meta_impl(*this, self, dim, index, std::nullopt, reduce);
245 }
246
TORCH_META_FUNC(scatter_add)
248 (const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
249 scatter_meta_impl(*this, self, dim, index, src, "add");
250 }
251
TORCH_META_FUNC2(scatter_reduce, two)
253 (const Tensor& self,
254 int64_t dim,
255 const Tensor& index,
256 const Tensor& src,
257 const c10::string_view reduce,
258 bool include_self) {
259 (void) include_self;
260 scatter_meta_impl</*use_new_options=*/true>(*this, self, dim, index, src, reduce);
261 }
262
TORCH_PRECOMPUTE_META_FUNC(index_copy)
264 (const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source) {
265 dim = maybe_wrap_dim(dim, self.dim());
266
267 const Tensor& result = maybe_get_output(0);
268
269 // Memory overlap checks need to be done after resizing (if required) is done.
270 // But it only makes sense to do these checks when result was defined, hence
271 // the boolean variable `check_result` here.
272 // For more details, see: https://github.com/pytorch/pytorch/pull/63312#discussion_r694794832
273 // and https://github.com/pytorch/pytorch/issues/63837
274 bool check_result = result.defined();
275 set_output_raw_strided(0, self.sizes(), {}, self.options());
276 if (check_result) {
277 at::assert_no_internal_overlap(result);
278 at::assert_no_overlap(result, index);
279 at::assert_no_overlap(result, source);
280 }
281
282 TORCH_CHECK_INDEX(index.dim() < 2, "index_copy_(): Index should have dimension 1 or 0 (got ", index.dim(), ")");
283
284 int64_t numIndices = index.numel();
285 if (source.dim() == 0 && numIndices != 1) {
286 TORCH_CHECK_INDEX(false, "index_copy_(): When source is scalar, index should have one element (got ", numIndices, ")");
287 } else if ((source.dim() != self.dim()) && (source.dim() != 0 && self.dim() != 0)) {
288 TORCH_CHECK_INDEX(false, "index_copy_(): When source and destination are not scalars, their dimensionality must match. Source dimensionality (",
289 source.dim(), "), destination dimensionality (", self.dim(), ")");
290 }
291
292 TORCH_CHECK(index.scalar_type() == ScalarType::Long, "index_copy_(): Expected a long tensor for index, but got ", index.scalar_type());
293 TORCH_CHECK(self.scalar_type() == source.scalar_type(), "index_copy_(): self and source expected to have the same dtype, but got (self) ", self.scalar_type(), " and (source) ", source.scalar_type());
294 TORCH_CHECK(self.device() == source.device() && self.device() == index.device(),
295 "index_copy_(): self, index and source expected to be in the same device, but got (self) ",
296 self.device(), ", (index) ", index.device(), ", and (source) ", source.device());
297
298 // Check that source and destination slices have the same size
299 auto selfSlicedSizes = self.sizes().vec();
300 if (!selfSlicedSizes.empty()) {
301 selfSlicedSizes.erase(selfSlicedSizes.begin() + dim);
302 }
303 auto sourceSlicedSizes = source.sizes().vec();
304 if (!sourceSlicedSizes.empty()) {
305 sourceSlicedSizes.erase(sourceSlicedSizes.begin() + dim);
306 }
307 if (selfSlicedSizes.size() != sourceSlicedSizes.size() ||
308 !std::equal(selfSlicedSizes.begin(), selfSlicedSizes.end(),
309 sourceSlicedSizes.begin())) {
310 std::stringstream ss;
311 ss << "index_copy_(): Source/destination tensor must have same slice shapes. ";
312 ss << "Destination slice shape: " << selfSlicedSizes << " at dimension " << dim;
313 ss << " and source slice shape: " << sourceSlicedSizes << " at dimension 0.";
314 TORCH_CHECK(false, ss.str());
315 }
316 TORCH_CHECK_INDEX(source.dim() == 0 || numIndices == source.size(dim),
317 "index_copy_(): Number of indices (", numIndices, ") should be equal to source.size(dim) (", source.size(dim), ")");
318
319 return TORCH_PRECOMPUTE_STRUCT(index_copy)().set_dim(dim);
320 }
321
322 template <typename Meta>
void index_func_meta_impl(
324 Meta& meta,
325 const Tensor& self,
326 int64_t dim,
327 const Tensor& index,
328 const Tensor& source,
329 c10::string_view func) {
330 auto numel = index.numel();
331
332 TORCH_CHECK_INDEX(index.dim() <= 1, func, "_(): Index is supposed to be a vector, but got dim: ",
333 index.dim(), " with type: ", index.scalar_type(), " and size: ", index.sizes());
334 TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int,
335 func, "_(): Expected dtype int32/int64 for index but got: ", index.scalar_type());
336 TORCH_CHECK(self.scalar_type() == source.scalar_type(),
337 func, "_(): self (", self.scalar_type(), ") and source (", source.scalar_type(),
338 ") must have the same scalar type");
339 TORCH_CHECK(dim == 0 || dim < source.dim(),
340 func, "_(): Indexing dim ", dim, " is out of bounds of the source tensor with dim ",
341 source.dim());
342 TORCH_CHECK(numel == (source.dim() == 0 ? 1 : source.size(dim)),
343 func, "_(): Number of indices (", numel, ") should be equal to source.size(dim): (",
344 source.size(dim), "), for dim: ", dim);
345
346 auto self_sizes = self.sizes().vec();
347 auto source_sizes = source.sizes().vec();
348 if (source.dim() != 0 && self.dim() != 0) {
349 self_sizes.erase(self_sizes.begin() + dim);
350 source_sizes.erase(source_sizes.begin() + dim);
351 }
352 TORCH_CHECK(
353 self_sizes == source_sizes,
354 "source tensor shape must match self tensor shape, excluding the specified dimension. Got self.shape = ",
355 self.sizes(),
356 " source.shape = ",
357 source.sizes());
358
359 auto& result = meta.maybe_get_output(0);
360 bool is_defined = result.defined();
361 meta.set_output_raw_strided(0, self.sizes(), {}, self.options());
362 if (is_defined) {
363 at::assert_no_internal_overlap(result);
364 at::assert_no_overlap(result, index);
365 at::assert_no_overlap(result, source);
366 }
367
368 // A hack to run TensorIterator checks in the meta function.
369 // See comment: https://github.com/pytorch/pytorch/pull/65993#discussion_r760307417
370 // TODO: (@krshrimali) Try inheriting from TensorIteratorBase instead.
371 if (result.device() == kMeta && result.dim() > 0) {
372 auto selfSlice = result.select(dim, 0);
373 auto sourceSlice = source.select(dim, 0);
374 auto iter = TensorIterator::borrowing_binary_op(selfSlice, selfSlice, sourceSlice);
375 }
376 }
377
TORCH_PRECOMPUTE_META_FUNC(index_add)
379 (const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source, const Scalar& alpha) {
380 dim = maybe_wrap_dim(dim, self.dim());
381 index_func_meta_impl(*this, self, dim, index, source, "index_add");
382 return TORCH_PRECOMPUTE_STRUCT(index_add)().set_dim(dim);
383 }
384
TORCH_PRECOMPUTE_META_FUNC(index_reduce)
386 (const Tensor& self,
387 int64_t dim,
388 const Tensor& index,
389 const Tensor& source,
390 const c10::string_view reduce,
391 bool include_self) {
392 (void)include_self;
393 TORCH_CHECK(reduce == "prod" || reduce == "mean" || reduce == "amax" || reduce == "amin",
394 "index_reduce(): Expected reduce to be one of prod, mean, amax or amin but got ", reduce, ".");
395 dim = maybe_wrap_dim(dim, self.dim());
396 index_func_meta_impl(*this, self, dim, index, source, "index_reduce");
397 return TORCH_PRECOMPUTE_STRUCT(index_reduce)().set_dim(dim);
398 }
399
static void build_index_op(
    TensorIteratorBase& iter,
    const at::native::AdvancedIndex& info,
    const Tensor& result) {
  // 'TensorIterator' needs to own the things coming from 'info', since
  // 'info' will be destroyed after the META function.
406 TensorIteratorConfig config;
407 // info.src is a restrided view of result
408 config.set_check_mem_overlap(false)
409 .check_all_same_dtype(false)
410 .add_output(result)
411 .add_owned_const_input(info.src);
412 for (auto& index : info.indices) {
413 config.add_owned_const_input(index);
414 }
415 if (!result.defined()) {
416 config.declare_static_dtype_and_device(info.src.scalar_type(), info.src.device());
417 }
418 iter.build(config);
419 }
420
static void check_indices_on_cpu_or_selfdevice(
422 const Tensor& self,
423 const at::MaterializedIOptTensorListRef& indices) {
424 auto dev = self.device();
425 bool indices_on_cpu_or_dev = std::all_of(
426 indices.begin(), indices.end(), [=](const at::OptionalTensorRef& opt) {
427 return opt.has_value() ? (opt->is_cpu() || opt->device() == dev) : true;
428 });
429 TORCH_CHECK(
430 indices_on_cpu_or_dev,
431 "indices should be either on ", kCPU,
432 " or on the same device as the indexed tensor (", dev, ")");
433 }
434
TORCH_PRECOMPUTE_META_FUNC2(index, Tensor)
436 (const Tensor& self, at::IOptTensorListRef indices) {
437 auto materialized = indices.materialize();
438
439 TORCH_CHECK_INDEX(
440 materialized.size() <= (size_t)self.dim(),
441 "too many indices for tensor of dimension ",
442 self.dim(), " (got ", materialized.size(), ")");
443
444 // Only allow: `dev_tensor[{cpu,dev}_tensor]`.
445 // See: https://github.com/pytorch/pytorch/pull/69607
446 check_indices_on_cpu_or_selfdevice(self, materialized);
447
448 const auto& result = maybe_get_output();
449
450 if (result.defined()) {
451 TORCH_CHECK(self.scalar_type() == result.scalar_type(),
452 "index_out: self (", self.scalar_type(), ") and result (", result.scalar_type(),
453 ") must have the same scalar type");
454 at::assert_no_internal_overlap(result);
455 at::assert_no_overlap(result, self);
456 for (const at::OptionalTensorRef& index : materialized) {
457 if (index.has_value()) {
458 at::assert_no_overlap(result, *index);
459 }
460 }
461 }
462
463 auto info = at::native::make_info(self, std::move(indices));
464 build_index_op(*this, info, result);
465 return TORCH_PRECOMPUTE_STRUCT2(index, Tensor)()
466 .set_sizes(std::move(info.indexed_sizes))
467 .set_strides(std::move(info.indexed_strides));
468 }
469
470 } // namespace at::meta
471
472 namespace at::native {
473
474 DEFINE_DISPATCH(index_stub);
475 DEFINE_DISPATCH(index_fill_stub);
476 DEFINE_DISPATCH(index_copy_stub);
477 DEFINE_DISPATCH(index_put_stub);
478 DEFINE_DISPATCH(index_put_with_sort_stub);
479 DEFINE_DISPATCH(put_stub);
480 DEFINE_DISPATCH(take_stub);
481 DEFINE_DISPATCH(masked_fill_stub);
482 REGISTER_NO_CPU_DISPATCH(index_put_with_sort_stub);
483 REGISTER_NO_CPU_DISPATCH(index_put_with_sort_quantized_stub);
484 DEFINE_DISPATCH(masked_select_serial_stub);
485 DEFINE_DISPATCH(masked_select_stub);
486 DEFINE_DISPATCH(masked_scatter_stub);
487
488 DEFINE_DISPATCH(gather_stub);
489 DEFINE_DISPATCH(scatter_stub);
490 DEFINE_DISPATCH(scatter_fill_stub);
491 DEFINE_DISPATCH(scatter_add_stub);
492 DEFINE_DISPATCH(scatter_reduce_stub);
493 DEFINE_DISPATCH(scatter_scalar_reduce_stub);
494 DEFINE_DISPATCH(scatter_reduce_two_stub);
495
496 DEFINE_DISPATCH(scatter_add_expanded_index_stub);
497 DEFINE_DISPATCH(scatter_reduce_expanded_index_stub);
498 DEFINE_DISPATCH(gather_expanded_index_stub);
499
static bool all_strides_match(TensorList tensors) {
501 TORCH_CHECK(!tensors.empty());
502 auto strides = tensors[0].strides();
503 for (auto& tensor : tensors.slice(1)) {
504 if (!strides.equals(tensor.strides())) {
505 return false;
506 }
507 }
508 return true;
509 }
510
inline std::string shapes_as_str(TensorList tensors) {
512 std::ostringstream os;
513 bool first = true;
514 for (auto& tensor : tensors) {
515 if (tensor.defined()) {
516 if (!first) {
517 os << ", ";
518 }
519 os << tensor.sizes();
520 first = false;
521 }
522 }
523 return os.str();
524 }
525
526 // Replace indexed dimensions in src with stride 0 and the size of the result tensor.
527 // The offset in these dimensions is computed by the kernel using the index tensor's
528 // values and the stride of src. The new shape is not meaningful. It's used to make
529 // the shape compatible with the result tensor.
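// As an example (reusing the illustrative x[:, ind_1, ind_2] case from the
// header comment): for src of shape (5, 6, 7), dims_before = 1,
// dims_indexed = 2 and replacement_shape = (2, 3), the indexed sizes (6, 7)
// and their strides are dropped, and the returned view has shape (5, 2, 3)
// with strides (src.stride(0), 0, 0).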
static Tensor restride_src(const Tensor& src, int64_t dims_before, int64_t dims_indexed,
531 IntArrayRef replacement_shape) {
532 auto shape = DimVector(src.sizes());
533 auto strides = DimVector(src.strides());
534 int64_t end = dims_before + dims_indexed;
535 shape.erase(shape.begin() + dims_before, shape.begin() + end);
536 strides.erase(strides.begin() + dims_before, strides.begin() + end);
537 shape.insert(shape.begin() + dims_before, replacement_shape.begin(), replacement_shape.end());
538 strides.insert(strides.begin() + dims_before, replacement_shape.size(), 0);
539 return src.as_strided(shape, strides);
540 }
541
542 // Add dimensions of size 1 to an index tensor so that it can be broadcast to the result
543 // shape and iterated over element-wise like the result tensor and the restrided src.
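// For example (illustrative values): an index of shape (2, 3) with
// dims_before = 1 and dims_after = 2 is reshaped to (1, 2, 3, 1, 1), so it
// lines up with a result laid out as (before dims..., index dims...,
// after dims...).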
static Tensor reshape_indexer(const Tensor& index, int64_t dims_before, int64_t dims_after) {
545 auto orig_shape = index.sizes();
546 auto shape = DimVector();
547 shape.append(dims_before, 1);
548 shape.append(orig_shape.begin(), orig_shape.end());
549 shape.append(dims_after, 1);
550 return index.reshape(shape);
551 }
552
AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list)
554 {
555 int64_t element_size_bytes = src.element_size();
556 int64_t dims_before = 0, dims_after = 0, dims_indexed = 0;
557 IntArrayRef replacement_shape;
558 for (const auto dim : c10::irange(indices_list.size())) {
559 if (!indices_list[dim].defined()) {
560 if (dims_indexed == 0) {
561 dims_before++;
562 } else {
563 dims_after++;
564 }
565 } else {
566 dims_indexed++;
567 replacement_shape = indices_list[dim].sizes();
568 indexed_sizes.push_back(src.size(dim));
569 indexed_strides.push_back(src.stride(dim) * element_size_bytes);
570 }
571 }
572
573 // Check if the indexed subspace contains a dim of size 0, but the replacement
574 // shape does not. This implies that an index is out of bounds, because there
575 // is no number that's a valid index for an empty tensor. Normally, out of
576 // bounds is handled in the indexing kernel, but this case fails earlier in
577 // restride_src with an unhelpful error message.
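  // For example (an illustrative case): indexing a tensor of shape (0, 4)
  // with ind = tensor([0]) takes this path, since no value is a valid index
  // along a dimension of size 0 while the replacement shape (1,) is non-empty.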
578 if (std::find(indexed_sizes.begin(), indexed_sizes.end(), 0) != indexed_sizes.end() &&
579 std::find(replacement_shape.begin(), replacement_shape.end(), 0) == replacement_shape.end()) {
580 TORCH_CHECK_INDEX(false, "index is out of bounds for dimension with size 0");
581 }
582
583 this->dims_before = dims_before;
584 this->dims_after = dims_after;
585 this->src = restride_src(src, dims_before, dims_indexed, replacement_shape);
586
587 for (auto& index : indices_list) {
588 if (index.defined()) {
589 indices.push_back(reshape_indexer(index, dims_before, dims_after));
590 }
591 }
592
593 // For CUDA/MPS/XPU tensors, force all index tensors to have the same striding to
594 // simplify the CUDA/MPS/XPU kernel.
595 if (indices.size() >= 2 && (this->src.device().type() == kCUDA || this->src.device().type() == kMPS || this->src.device().type() == kXPU)) {
596 if (!all_strides_match(indices)) {
597 for (auto & indice : indices) {
598 indice = indice.contiguous();
599 }
600 }
601 }
602 }
603
static TensorIterator make_index_put_iterator(const AdvancedIndex& info, const Tensor& value) {
605 TORCH_CHECK(is_expandable_to(value.sizes(), info.src.sizes()), "shape mismatch: value tensor of shape ", value.sizes(),
606 " cannot be broadcast to indexing result of shape ", info.src.sizes());
607 TORCH_CHECK(value.scalar_type() == info.src.scalar_type(),
608 "Index put requires the source and destination dtypes match, "
609 "got ", info.src.scalar_type(), " for the destination "
610 "and ", value.scalar_type(), " for the source.");
611 TensorIteratorConfig config;
612 // info.src is restrided by restride_src with 0 strided dimensions
613 config.set_check_mem_overlap(false);
614 config.resize_outputs(false);
615 config.check_all_same_dtype(false);
616 config.add_output(info.src);
617 config.add_const_input(value);
618 for (auto& index : info.indices) {
619 config.add_const_input(index);
620 }
621 return config.build();
622 }
623
TORCH_IMPL_FUNC(index_out)
625 (const Tensor& self,
626 DimVector sizes,
627 DimVector strides,
628 const Tensor& result) {
629 index_stub(device_type(), *this, sizes, strides);
630 }
631
Tensor quantized_index(const Tensor & self, const torch::List<std::optional<Tensor>>& indices) {
633 TORCH_INTERNAL_ASSERT(
634 self.qscheme() == c10::kPerTensorAffine ||
635 self.qscheme() == c10::kPerTensorSymmetric,
636 "Indexing is only supported for per-Tensor quantized Tensors.");
637
638 // For now, this is a naive implementation which does dq -> index -> q.
639 // TODO(future PR): improve performance by removing the copies.
640 const auto& self_dq = self.dequantize();
641 auto result = at::index(self_dq, indices);
642 return at::quantize_per_tensor(
643 result, self.q_scale(), self.q_zero_point(), self.scalar_type());
644 }
645
Tensor _unsafe_index(const Tensor& self, const torch::List<std::optional<Tensor>>& indices) {
647 // Disallow boolean indexing since it leads to dynamic output shapes
648 for (auto i : c10::irange(indices.size())) {
649 auto index = indices.get(i);
650 if (index.has_value()) {
651 auto dtype = index->scalar_type();
652 TORCH_CHECK(dtype == kLong || dtype == kInt,
653 "_unsafe_index found unexpected index type ", dtype);
654 }
655 }
656 return at::index(self, indices);
657 }
658
Tensor _unsafe_masked_index(const Tensor& self, const Tensor& mask, const torch::List<std::optional<Tensor>>& indices, const Scalar& fill) {
  // Unsafe masked index is equivalent to
  //   where(mask, self[indices], fill)
  // with the main difference being that when `mask` is false, `self` is not
  // indexed using `indices`. This allows `indices` to be out of bounds when
  // `mask` is false. When `mask` is true, the `indices` are expected to be in
  // bounds and are not checked.
  //
  // This function is not meant to be executed in eager mode. An unoptimized
  // version is provided here.
  //
  // Compiler backends should implement this op such that `self[indices]` is
  // not loaded when `mask` is false. See inductor for a reference.
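  // For example (an illustrative sketch): with self of shape (2,),
  // indices = {tensor({0, 99})} and mask = tensor({true, false}), the result
  // is {self[0], fill}; the out-of-bounds index 99 must never be
  // dereferenced (the eager fallback below clamps it instead of checking).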
672 auto clamp = [](const std::optional<Tensor>& index, auto size) -> std::optional<Tensor> {
673 if (!index) {
674 return index;
675 }
676 // Disallow bool
677 auto dtype = index->scalar_type();
678 TORCH_CHECK(dtype == kLong || dtype == kInt,
679 "_unsafe_masked_index found unexpected index type ", dtype);
680 return at::clamp(*index, -size, size - 1);
681 };
682
683 torch::List<std::optional<Tensor>> clamped_indices(indices);
684 std::transform(indices.begin(), indices.end(), self.sizes().begin(), clamped_indices.begin(), clamp);
685
686 if (self.numel() == 0) {
    // Returns a tensor filled with the `fill` value.
    // We use a hack here since we do not have a way to compute the correct
    // output size (except via the meta implementation, which is not
    // available on mobile builds).
691 std::vector<int64_t> new_sizes(self.dim());
692 auto compute_new_size = [](const std::optional<Tensor>& index, auto size) -> int64_t {
693 if (index && size == 0) {
694 return 1;
695 } else {
696 return size;
697 }
698 };
699 std::transform(indices.begin(), indices.end(), self.sizes().begin(), new_sizes.begin(), compute_new_size);
700 auto result = self.new_full(new_sizes, fill);
701 return at::_unsafe_index(result, clamped_indices);
702 }
703
704 auto result = at::_unsafe_index(self, clamped_indices);
705 return result.masked_fill(at::logical_not(mask), fill);
706 }
707
Tensor _unsafe_masked_index_put_accumulate(const Tensor& self, const Tensor& mask, const torch::List<std::optional<Tensor>>& indices, const Tensor& values) {
  // This is the backward of _unsafe_masked_index.
  // This function is not meant to be executed in eager mode.
711
712 if (self.numel() == 0) {
713 return self.clone();
714 }
715
716 // We recompute the clamped indices and rely on inductor to CSE the computation
717 auto clamp = [](const std::optional<Tensor>& index, auto size) -> std::optional<Tensor> {
718 if (!index) {
719 return index;
720 }
721 // Disallow bool
722 auto dtype = index->scalar_type();
723 TORCH_CHECK(dtype == kLong || dtype == kInt,
724 "_unsafe_masked_index found unexpected index type ", dtype);
725 return at::clamp(*index, -size, size - 1);
726 };
727
728 torch::List<std::optional<Tensor>> clamped_indices(indices);
729 std::transform(indices.begin(), indices.end(), self.sizes().begin(), clamped_indices.begin(), clamp);
730
731 auto masked_value = values.masked_fill(at::logical_not(mask), 0);
732 return at::_unsafe_index_put(self, clamped_indices, masked_value, true);
733 }
734
Tensor & put_(Tensor & self, const Tensor& index, const Tensor & source, const bool accumulate) {
736 // See note [Writing Nondeterministic Operations]
737 // Nondeterministic when index contains duplicate entries and we do not accumulate
738 // If we accumulate on GPU, we use atomicGPUAdd, which is non-deterministic
739 if (!accumulate || (accumulate && self.device().type() == DeviceType::CUDA)) {
740 at::globalContext().alertNotDeterministic("put_");
741 }
742
743 // Type and device checks
744 TORCH_CHECK(index.scalar_type() == ScalarType::Long, "put_(): Expected a long tensor for index, but got ", index.scalar_type())
745 TORCH_CHECK(self.scalar_type() == source.scalar_type(), "put_(): self and source expected to have the same dtype, but got self.dtype = ", self.scalar_type(), " and source.dtype = ", source.scalar_type());
746 TORCH_CHECK(self.device() == source.device() && self.device() == index.device(),
747 "put_(): self, index and source expected to be in the same device, but got self.device = ",
748 self.device(), ", index.device = ", index.device(), ", and source.device = ", source.device());
749
750 // index checks
751 TORCH_CHECK_INDEX(source.numel() == index.numel(), "put_(): Expected source and index to have the same number of elements, but got source.numel() = ", source.numel(), ", index.numel() = ", index.numel());
752 TORCH_CHECK_INDEX(!(self.numel() == 0 && index.numel() != 0), "put_(): Tried to put elements into an empty tensor");
753
754 at::assert_no_internal_overlap(self);
755 at::assert_no_overlap(self, index);
756 at::assert_no_overlap(self, source);
757
758 // Early return
759 if (index.numel() == 0) {
760 return self;
761 }
762
763 auto index_reshaped = index.reshape(source.sizes());
764 // Do not iterate over self, we will compute the offsets manually
765 auto iter = TensorIteratorConfig()
766 .set_check_mem_overlap(false)
767 .check_all_same_dtype(false)
768 .add_const_input(source)
769 .add_const_input(index_reshaped)
770 .build();
771
772 put_stub(iter.device_type(), iter, self, accumulate);
773
774 return self;
775 }
776
Tensor put(const Tensor & self, const Tensor& index, const Tensor & source, const bool accumulate) {
778 return self.clone(at::MemoryFormat::Preserve).put_(index, source, accumulate);
779 }
780
Tensor index_put(const Tensor & self, const torch::List<std::optional<Tensor>>& indices, const Tensor & value, bool accumulate) {
782 return self.clone(at::MemoryFormat::Preserve).index_put_(indices, value, accumulate);
783 }
784
Tensor _unsafe_index_put(const Tensor& self, const torch::List<std::optional<Tensor>>& indices, const Tensor& value, bool accumulate) {
786 return at::index_put(self, indices, value, accumulate);
787 }
788
Tensor & _index_put_impl_(Tensor & self, const torch::List<std::optional<Tensor>>& indices, const Tensor & value, const bool accumulate, const bool unsafe) {
790 TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
791 if (at::has_internal_overlap(self) == MemOverlap::Yes) {
792 TORCH_WARN(
793 "Use of index_put_ on expanded tensors is deprecated. "
794 "Please clone() the tensor before performing this operation. "
795 "This also applies to advanced indexing e.g. tensor[indices] = tensor");
796 }
797 if (!accumulate) {
798 auto masked_fill_dispatch = canDispatchToMaskedFill(self, indices, value);
799 if (std::get<0>(masked_fill_dispatch)) {
800 return self.masked_fill_(std::get<1>(masked_fill_dispatch), value.item());
801 }
802 }
803 auto value_ = value;
804 if (value.device() != self.device() && value.numel() == 1 && value.dim() == 0) {
805 value_ = value.to(self.device());
806 }
807 at::assert_no_overlap(self, value);
808 // NOLINTNEXTLINE(performance-implicit-conversion-in-loop)
809 for (const std::optional<Tensor>& index: indices) {
810 if (index.has_value()) {
811 at::assert_no_overlap(self, *index);
812 }
813 }
814 if ((self.device().type() == DeviceType::CUDA || self.device().type() == DeviceType::XPU) && (accumulate || globalContext().deterministicAlgorithms())) {
815 TORCH_CHECK(value_.device() == self.device(), "expected device ", self.device(), " but got device ",
816 value_.device(), " for value tensor");
817 index_put_with_sort_stub(self.device().type(), self, indices, value_, accumulate, unsafe);
818 return self;
819 }
820
821 auto info = make_info(self, indices);
822 auto iter = make_index_put_iterator(info, value_);
823 index_put_stub(iter.device_type(), iter, info.indexed_sizes, info.indexed_strides, accumulate);
824 return self;
825 }
826
Tensor& take_out(const Tensor& self, const Tensor& index, Tensor& out) {
828 // Type and device checks
829 TORCH_CHECK(index.scalar_type() == ScalarType::Long, "take(): Expected a long tensor for index, but got ", index.scalar_type())
830 TORCH_CHECK(self.scalar_type() == out.scalar_type(), "take(): self and out expected to have the same dtype, but got self.dtype = ", self.scalar_type(), " and out.dtype = ", out.scalar_type());
831 TORCH_CHECK(self.device() == out.device() && self.device() == index.device(),
832 "take(): self, index and out expected to be in the same device, but got self.device = ",
833 self.device(), ", index.device = ", index.device(), ", and out.device = ", out.device());
834
835 // index checks
836 TORCH_CHECK_INDEX(!(self.numel() == 0 && index.numel() != 0), "take(): tried to take from an empty tensor");
837
838 at::assert_no_internal_overlap(out);
839 at::assert_no_overlap(out, index);
840 at::assert_no_overlap(out, self);
841
842 // Do not iterate over self, we will compute the offsets manually
843 // out is resized inside tensor_iterator
844 auto iter = TensorIteratorConfig()
845 .set_check_mem_overlap(false)
846 .check_all_same_dtype(false)
847 .add_output(out)
848 .add_const_input(index)
849 .build();
850
851 // Early return after out has been resized
852 if (index.numel() == 0) {
853 return out;
854 }
855
856 take_stub(iter.device_type(), iter, self);
857
858 return out;
859 }
860
Tensor take(const Tensor& self, const Tensor& index) {
862 auto out = at::empty(index.sizes(), self.options());
863 at::native::take_out(self, index, out);
864 return out;
865 }
866
Tensor & index_put_(Tensor & self, const torch::List<std::optional<Tensor>>& indices, const Tensor & value, const bool accumulate) {
868 return at::_index_put_impl_(self, indices, value, accumulate, /*unsafe=*/false);
869 }
870
TORCH_IMPL_FUNC(index_copy_out)
872 (const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source, const Tensor& result) {
873 if (!result.is_same(self)) result.copy_(self);
874
875 // See Note [Enabling Deterministic Operations]
876 if (result.is_cuda() && globalContext().deterministicAlgorithms()){
877 torch::List<std::optional<Tensor>> indices;
878 indices.reserve(dim + 1);
879 for (const auto i: c10::irange(dim)) {
880 (void)i;
881 indices.emplace_back();
882 }
883 indices.emplace_back(index);
884 result.index_put_(indices, source, false);
885 return;
886 }
887
888 // Handle the case when self / source is 0-dim
889 Tensor result_nonzero = result.dim() == 0 ? result.unsqueeze(0) : result;
890 Tensor source_nonzero = source.dim() == 0 ? source.unsqueeze(0) : source;
891
  // The only difference between the following tensor iterator and that of index_fill_ is that
  // this one also has source as an input. We should refactor it once `if constexpr` (C++17) can be used here.
894
895 // Prepare `index` for TensorIterator.
896 // It is restrided to be broadcastable over `self` in TensorIterator.
897 auto index_sizes = std::vector<int64_t>(result_nonzero.dim(), 1);
898 auto index_strides = std::vector<int64_t>(result_nonzero.dim(), 0);
899 index_sizes[dim] = index.numel();
900 index_strides[dim] = (index.dim() > 0) ? index.stride(0) : 1; // `index` is 1d or scalar
901 auto index_restrided = index.as_strided(
902 index_sizes, index_strides);
903
904 // Prepare `result` for TensorIterator.
905 // Restride `result` to not advance in dimension `dim`.
906 // We do not use squash_dim here because `index` will
907 // need to advance in this dimension.
908 // Note that self_sizes[dim] is set to index.numel().
909 // This is done so that self_sizes[dim] and index_sizes[dim]
910 // match as required by TensorIterator (input shape should
911 // strictly broadcast over output shape, i.e.
912 // output.shape[i] >= input.shape[i] for i in range(dims)).
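  // For example (illustrative shapes): for result_nonzero of shape (4, 5) and
  // dim = 1 with index.numel() = 3, `index` is viewed with sizes (1, 3) and
  // strides (0, index.stride(0)), while `result` is viewed with sizes (4, 3)
  // and stride 0 along dim; iteration advances through the index values and
  // the kernel computes the destination offset along dim from each value.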
913 auto result_sizes = result_nonzero.sizes().vec();
914 auto result_strides = result_nonzero.strides().vec();
915 result_sizes[dim] = index.numel();
916 result_strides[dim] = 0;
917 auto result_restrided = result_nonzero.as_strided(result_sizes, result_strides);
918
919 auto iter = TensorIteratorConfig()
920 // We do not check for overlap because `result` is restrided
921 // with zero stride. Zero strides trigger memory overlap assert
922 // within TensorIterator.
923 .set_check_mem_overlap(false)
924 .check_all_same_dtype(false)
925 .resize_outputs(false)
926 .add_output(result_restrided)
927 .add_const_input(index_restrided)
928 .add_const_input(source_nonzero)
929 .build();
930
931 auto result_dim_size = result_nonzero.size(dim);
932 auto result_dim_stride = result_nonzero.stride(dim);
933 index_copy_stub(
934 iter.device_type(),
935 iter,
936 dim,
937 result_dim_size,
938 result_dim_stride);
939 }
940
941 // Not calling into index_reduce_func_impl because of a different dtype dispatch
TORCH_IMPL_FUNC(index_add_cpu_out)
943 (const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source, const Scalar& alpha, const Tensor& result) {
944 if (!result.is_same(self)) {
945 result.copy_(self);
946 }
947 auto numel = index.numel();
948
949 auto index_contig = index.contiguous();
950
951 if (result.dim() > 1) {
952 // Equivalent to:
953 // for (const auto i : c10::irange(numel)) {
954 // auto selfSlice = self.select(dim, index_data[i]);
955 // auto sourceSlice = source.select(dim, i);
956 // selfSlice.add_(sourceSlice);
957 // }
958 // But much faster as this reuses the iterator from add_
959 if (numel == 0 || self.numel() == 0) {
960 return;
961 }
962
963 dim = maybe_wrap_dim(dim, self.dim());
964
    // When the slice of source or result is noncontiguous, the original
    // index_add path is slow: it calls add for each sliced tensor, which is
    // serial over the index and parallel over the slice to avoid write
    // conflicts. Parallelizing over the slice is suboptimal because the slice
    // may be too small to parallelize well, and it launches one parallel
    // region per index. scatter_add is used to speed up this case: it
    // parallelizes over the outer dimension of the input and is serial over
    // the inner dimension to avoid write conflicts, so it needs only a single
    // parallel region over a larger outer extent.
974
975 if ((dim == 0 || dim == self.dim() - 1) &&
976 // Data type of index should be long and alpha should be 1 to use scatter_add.
977 alpha.equal(1.0) && index_contig.scalar_type() == ScalarType::Long &&
978 // scatter_add does not support ComplexHalf
979 source.scalar_type() != ScalarType::ComplexHalf &&
980 result.scalar_type() != ScalarType::ComplexHalf) {
981 std::vector<int64_t> ep_sizes(result.sizes().size());
982 std::vector<int64_t> ep_strides(source.sizes().size());
983
      // Check that result and source match in every dimension except `dim`.
      // Note the broadcast case, where source.select(dim, i) is broadcast to
      // result.select(dim, index_data[i]): broadcasting is not supported by
      // scatter_add, so that case cannot take this fast path.
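      // For example (illustrative shapes): with result of shape (4, 5),
      // source of shape (3, 5), dim = 0 and numel = 3, check_sizes fills
      // ep_sizes = (3, 5) and ep_strides = (1, 0), so ep_index below is the
      // index broadcast across the non-indexed dimension for scatter_add_.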
988 auto check_sizes = [&ep_sizes, &ep_strides, &numel](IntArrayRef a, IntArrayRef b, int64_t dim) -> bool {
989
990 ep_sizes[dim] = numel;
991 ep_strides[dim] = 1;
992 for (const int64_t i : c10::irange(a.size())) {
993 if (i == dim) {
994 continue;
995 }
996
997 if (a[i] != b[i]) {
998 return false;
999 }
1000 ep_sizes[i] = a[i];
1001 ep_strides[i] = 0;
1002
1003 }
1004 return true;
1005 };
1006
1007 if (check_sizes(result.sizes(), source.sizes(), dim)) {
1008 auto ep_index = index_contig.as_strided(ep_sizes, ep_strides);
1009 result.scatter_add_(dim, ep_index, source);
1010 return;
1011 }
1012 }
1013
1014 auto selfSlice = result.select(dim, 0);
1015 auto sourceSlice = source.select(dim, 0);
1016 auto self_stride_bytes = result.stride(dim) * elementSize(result.scalar_type());
1017 auto source_stride_bytes = source.stride(dim) * elementSize(source.scalar_type());
1018 auto self_dim_size = result.size(dim);
1019 auto iter = TensorIterator::borrowing_binary_op(selfSlice, selfSlice, sourceSlice);
1020
1021 AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cpu_", [&] () {
1022 auto index_data = index_contig.const_data_ptr<index_t>();
1023 for (const auto i : c10::irange(numel)) {
1024 auto self_i = index_data[i];
1025 TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
1026 auto self_data = static_cast<char*>(selfSlice.data_ptr()) + self_i * self_stride_bytes;
1027 auto source_data = static_cast<const char*>(sourceSlice.const_data_ptr()) + i * source_stride_bytes;
1028 iter.unsafe_replace_operand(0, self_data);
1029 iter.unsafe_replace_operand(1, self_data);
1030 iter.unsafe_replace_operand(2, const_cast<char*>(source_data));
1031 add_stub(iter.device_type(), iter, alpha);
1032 }
1033 });
1034 } else {
    TORCH_CHECK(source.dim() <= 1, "source.dim() (", source.dim(), ") must be one or zero for given self.dim() (", self.dim(), ")");
1036
1037 // explicitly capture all required variables to work around windows build
1038 // TODO: fix this when windows can correctly capture variables in nested lambda
1039 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, ScalarType::ComplexHalf,
1040 result.scalar_type(), "index_add_", [&result, &source, &dim, &index_contig, &numel, &alpha] {
1041 auto alpha_value = alpha.to<scalar_t>();
1042 auto result_stride = result.dim() == 0 ? 1 : result.stride(dim);
1043 auto source_stride = source.dim() == 0 ? 1 : source.stride(dim);
1044 // TODO: Maybe TensorAccessor can be used here?
1045 auto* result_ptr = result.data_ptr<scalar_t>();
1046 auto* source_ptr = source.const_data_ptr<scalar_t>();
1047 AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_add_cpu_",
1048 [&index_contig, &numel, &result, &result_ptr, &result_stride, &source_ptr, &source_stride, &alpha_value] {
1049 auto index_data = index_contig.const_data_ptr<index_t>();
1050 for (const auto i : c10::irange(numel)) {
1051 auto self_i = index_data[i];
1052 TORCH_CHECK_INDEX((self_i >= 0) && (self_i < result.numel()), "index out of range in self");
1053 scalar_t *self_ip = result_ptr + self_i * result_stride;
1054 *self_ip += *(source_ptr + i * source_stride) * alpha_value;
1055 }
1056 });
1057 });
1058 }
1059 }
1060
static void index_reduce_func_impl(
1062 const Tensor& self,
1063 int64_t dim,
1064 const Tensor& index,
1065 const Tensor& source,
1066 bool include_self,
1067 const Tensor& result,
1068 const ReductionType& op) {
1069 if (!result.is_same(self)) result.copy_(self);
1070 if (!include_self) {
1071 AT_DISPATCH_ALL_TYPES_AND2(
1072 at::ScalarType::Half, at::ScalarType::BFloat16,
1073 self.scalar_type(), "index_reduce_func_exclude_input_init", [&] {
1074 scalar_t init_val;
1075 switch (op) {
1076 case ReductionType::PROD:
1077 init_val = (scalar_t)1;
1078 break;
1079 case ReductionType::MAX:
1080 init_val = std::numeric_limits<scalar_t>::has_infinity ? -std::numeric_limits<scalar_t>::infinity()
1081 : std::numeric_limits<scalar_t>::lowest();
1082 break;
1083 case ReductionType::MIN:
1084 init_val = std::numeric_limits<scalar_t>::has_infinity ? std::numeric_limits<scalar_t>::infinity()
1085 : std::numeric_limits<scalar_t>::max();
1086 break;
1087 default:
1088 init_val = (scalar_t)0;
1089 break;
1090 }
1091 // index_fill_ requires index to be a LongTensor
1092 result.index_fill_(dim, index.to(at::ScalarType::Long), init_val);
1093 });
1094 }
1095
1096 auto numel = index.numel();
1097
1098 auto index_contig = index.contiguous();
1099
1100 if (result.dim() > 1) {
1101 // Equivalent to:
1102 // for (const auto i : c10::irange(numel)) {
1103 // auto selfSlice = self.select(dim, index_data[i]);
1104 // auto sourceSlice = source.select(dim, i);
1105 // selfSlice.op_(sourceSlice);
1106 // }
1107 // But much faster as this reuses the iterator from the binary op
1108 if (numel == 0) {
1109 return;
1110 }
1111 auto selfSlice = result.select(dim, 0);
1112 auto sourceSlice = source.select(dim, 0);
1113 auto self_stride_bytes = result.stride(dim) * elementSize(result.scalar_type());
1114 auto source_stride_bytes = source.stride(dim) * elementSize(source.scalar_type());
1115 auto self_dim_size = result.size(dim);
1116 auto iter = TensorIterator::borrowing_binary_op(selfSlice, selfSlice, sourceSlice);
1117
1118 AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_func_cpu_", [&] () {
1119 auto index_data = index_contig.const_data_ptr<index_t>();
1120 for (const auto i : c10::irange(numel)) {
1121 auto self_i = index_data[i];
1122 TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
1123 auto self_data = static_cast<char*>(selfSlice.data_ptr()) + self_i * self_stride_bytes;
1124 auto source_data = static_cast<const char*>(sourceSlice.const_data_ptr()) + i * source_stride_bytes;
1125 iter.unsafe_replace_operand(0, self_data);
1126 iter.unsafe_replace_operand(1, self_data);
1127 iter.unsafe_replace_operand(2, const_cast<char*>(source_data));
1128
1129 switch (op) {
1130 case ReductionType::PROD :
1131 mul_stub(iter.device_type(), iter);
1132 break;
1133 case ReductionType::MIN :
1134 minimum_stub(iter.device_type(), iter);
1135 break;
1136 case ReductionType::MAX :
1137 maximum_stub(iter.device_type(), iter);
1138 break;
1139 default :
1140 add_stub(iter.device_type(), iter, 1);
1141 break;
1142 }
1143 }
1144 });
1145
1146 if (op == ReductionType::MEAN) {
1147 auto counts = include_self ? at::ones_like(result) : at::zeros_like(result);
1148 counts.index_add_(dim, index, at::ones_like(source));
1149 counts.masked_fill_(counts == 0, 1);
1150 if (result.is_floating_point() || result.is_complex()) {
1151 result.div_(counts);
1152 } else {
1153 result.div_(counts, "floor");
1154 }
1155 }
1156 }
1157 else {
    TORCH_CHECK(source.dim() <= 1, "source.dim() (", source.dim(), ") must be one or zero for given self.dim() (", self.dim(), ")");
1159 auto counts = include_self ? at::ones_like(result) : at::zeros_like(result);
1160 // explicitly capture all required variables to work around windows build
1161 // TODO: fix this when windows can correctly capture variables in nested lambda
1162 AT_DISPATCH_ALL_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16,
1163 result.scalar_type(), "index_func_", [&result, &source, &dim, &index_contig, &numel, &op, &counts] {
1164 auto result_stride = result.dim() == 0 ? 1 : result.stride(dim);
1165 auto source_stride = source.dim() == 0 ? 1 : source.stride(dim);
1166 auto counts_stride = counts.dim() == 0 ? 1 : counts.stride(dim);
1167 // TODO: Maybe TensorAccessor can be used here?
1168 auto* result_ptr = result.data_ptr<scalar_t>();
1169 auto* source_ptr = source.const_data_ptr<scalar_t>();
1170 auto counts_ptr = counts.data_ptr<scalar_t>();
1171 AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_func_cpu_",
1172 [&index_contig, &numel, &result, &result_ptr, &result_stride, &source_ptr, &source_stride, &op, &counts_ptr, &counts_stride] {
1173 auto index_data = index_contig.const_data_ptr<index_t>();
1174 for (const auto i : c10::irange(numel)) {
1175 auto self_i = index_data[i];
1176 TORCH_CHECK_INDEX((self_i >= 0) && (self_i < result.numel()), "index out of range in self");
1177 scalar_t *self_ip = result_ptr + self_i * result_stride;
1178 scalar_t *count_ip;
1179 scalar_t val;
1180 switch (op) {
1181 case ReductionType::MEAN :
1182 *self_ip += *(source_ptr + i * source_stride);
1183 count_ip = counts_ptr + self_i * counts_stride;
1184 *count_ip += 1;
1185 break;
1186 case ReductionType::PROD :
1187 *self_ip *= *(source_ptr + i * source_stride);
1188 break;
1189 case ReductionType::MIN :
1190 val = *(source_ptr + i * source_stride);
1191 *self_ip = at::_isnan<scalar_t>(val) ? val : std::min(*self_ip, val);
1192 break;
1193 case ReductionType::MAX :
1194 val = *(source_ptr + i * source_stride);
1195 *self_ip = at::_isnan<scalar_t>(val) ? val : std::max(*self_ip, val);
1196 break;
1197 default:
1198 break;
1199 }
1200 }
1201 });
1202 });
1203 if (op == ReductionType::MEAN) {
1204 counts.masked_fill_(counts == 0, 1);
1205 if (result.is_floating_point() || result.is_complex()) {
1206 result.div_(counts);
1207 } else {
1208 result.div_(counts, "floor");
1209 }
1210 }
1211 }
1212 }
1213
TORCH_IMPL_FUNC(index_reduce_cpu_out)
1215 (const Tensor& self,
1216 int64_t dim,
1217 const Tensor& index,
1218 const Tensor& source,
1219 const c10::string_view reduce,
1220 bool include_input,
1221 const Tensor& result) {
1222 TORCH_WARN_ONCE("index_reduce() is in beta and the API may change at any time.");
1223 auto op = get_operator_enum(reduce, true);
1224 index_reduce_func_impl(self, dim, index, source, include_input, result, op);
1225 }
1226
1227 // Check that indices fall within dimension array size
1228 // Avoid redispatch call to min/max
1229 template <typename IndexType>
static void check_indexarray_range(
1231 const IndexType* indices,
1232 int64_t n,
1233 IndexType indexing_axis_dim) {
1234 for (const auto i : c10::irange(n)) {
1235 auto idx = indices[i];
1236 TORCH_CHECK(
1237 0 <= idx && idx < indexing_axis_dim,
1238 "INDICES element is out of DATA bounds, id=",
1239 idx,
1240 " axis_dim=",
1241 indexing_axis_dim);
1242 }
1243 }
1244
static Tensor & index_select_out_cpu_dim1_(
1246 Tensor & result_contig, const Tensor & self, const Tensor & index_contig) {
1247
1248 auto self_contig = self.contiguous();
1249 const caffe2::TypeMeta dataType = self_contig.dtype();
1250 size_t item_bytesize = dataType.itemsize();
1251
1252 auto out = static_cast<char*>(result_contig.data_ptr());
1253
1254 auto src_base = static_cast<const char*>(self_contig.const_data_ptr());
1255
1256 auto self_sizes = self_contig.sizes();
1257 auto outer_dims_product = c10::size_to_dim_(1, self_sizes);
1258 auto block_size = c10::size_from_dim_(2, self_sizes);
1259 auto block_bytesize = block_size * item_bytesize;
1260
1261 auto src_indexing_axis_dim = self_sizes[1];
1262 auto src_batch_bytesize = self_sizes[1] * block_bytesize;
1263 auto N = index_contig.numel();
1264
1265 auto gathered_batch_bytesize = N * block_bytesize;
1266
1267 AT_DISPATCH_INDEX_TYPES(
1268 index_contig.scalar_type(), "batch_index_select_compute", [&]() {
1269
1270 const auto* idxs = index_contig.const_data_ptr<index_t>();
1271 check_indexarray_range<index_t>(idxs, N, src_indexing_axis_dim);
1272
1273 // Special-case single-float copy for efficiency
1274 if (self.scalar_type() == ScalarType::Float && block_size == 1) {
1275 for (const auto batch : c10::irange(outer_dims_product)) {
1276 const float* src_floats =
1277 (const float*)(src_base + batch * src_batch_bytesize);
1278 float* dst_floats = (float*)(out + batch * gathered_batch_bytesize);
1279
1280 for (const auto i : c10::irange(N)) {
1281 auto idx = idxs[i];
1282 dst_floats[i] = src_floats[idx];
1283 }
1284 }
1285 } else {
1286 // outer_dims_product specifies how many times we repeat inner dimensions,
1287 // so we just iterate over it to cover all outer dimensions.
1288 for (const auto batch : c10::irange(outer_dims_product)) {
1289 for (const auto i : c10::irange(N)) {
1290 auto idx = idxs[i];
1291 auto src = src_base + batch * src_batch_bytesize + idx * block_bytesize;
1292 auto dst = out + batch * gathered_batch_bytesize + i * block_bytesize;
1293 memcpy(dst, src, block_bytesize);
1294 }
1295 }
1296 }
1297 });
1298 return result_contig;
1299 }
1300
Tensor & index_select_out_cpu_(const Tensor & self, int64_t dim, const Tensor & index, Tensor & result) {
1302 if (self.is_quantized()) {
1303 TORCH_CHECK(
1304 self.qscheme() == kPerTensorAffine,
        "Only per_tensor quantized tensors are supported by index_select.")
1306 }
1307 dim = maybe_wrap_dim(dim, self.dim());
1308 auto numel = index.numel();
1309 TORCH_CHECK_INDEX(index.dim() <= 1, "index_select(): Index is supposed to be a vector");
1310 TORCH_CHECK(!(self.dim() == 0 && numel != 1), "index_select(): Index to scalar can have only 1 value, got ", numel, " value(s)");
1311 TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int, "index_select(): Expected dtype int32 or int64 for index");
1312 TORCH_CHECK(self.scalar_type() == result.scalar_type(),
1313 "index_select(): self and result must have the same scalar type");
1314 at::assert_no_internal_overlap(result);
1315 at::assert_no_overlap(result, self);
1316 at::assert_no_overlap(result, index);
1317 auto result_size = self.sizes().vec();
1318 if (self.dim() > 0) {
1319 result_size[dim] = numel;
1320 }
1321 at::native::resize_output(result, result_size);
1322
1323 auto index_contig = index.contiguous();
1324
1325 if (self.dim() > 1) {
1326 if (numel == 0) {
1327 return result;
1328 }
1329 if (self.numel() == 0) {
1330 auto src_indexing_axis_dim = self.size(dim);
1331 TORCH_CHECK(src_indexing_axis_dim > 0,
1332 "index_select(): self indexing axis dim should be positive");
1333 AT_DISPATCH_INDEX_TYPES(
1334 index_contig.scalar_type(), "index_select_empty_self_bound_check", [&]() {
1335 const auto* idxs = index_contig.const_data_ptr<index_t>();
1336 check_indexarray_range<index_t>(idxs, numel, src_indexing_axis_dim);
1337 });
1338 return result;
1339 }
1340
1341 if (dim == 1 && result.is_contiguous()) {
1342 // fast path
1343 return index_select_out_cpu_dim1_(result, self, index_contig);
1344 }
1345
1346 auto selfSlice = self.select(dim, 0);
1347 auto resultSlice = result.select(dim, 0);
1348 auto selfSlice_data = selfSlice.const_data_ptr();
1349 auto resultSlice_data = resultSlice.data_ptr();
1350 auto self_stride_bytes = self.stride(dim) * elementSize(self.scalar_type());
1351 auto result_stride_bytes = result.stride(dim) * elementSize(result.scalar_type());
1352 auto self_dim_size = self.size(dim);
1353 auto slice_size = selfSlice.numel();
1354
1355 auto iter = TensorIteratorConfig()
1356 .check_all_same_dtype(false)
1357 .resize_outputs(false)
1358 .add_output(resultSlice)
1359 .add_const_input(selfSlice)
1360 .build();
1361
1362 auto grain_size = at::internal::GRAIN_SIZE;
1363 auto outer_loop =
1364 // explicitly capture all required variables to work around windows build
1365 // TODO: fix this when windows can correctly capture variables in nested lambda
1366 [&index_contig, &iter, &self_dim_size, &selfSlice_data, &self_stride_bytes, &resultSlice_data,
1367 &result_stride_bytes](int64_t start, int64_t end) {
1368 auto sub_iter = TensorIterator(iter);
1369 AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_",
1370 [&index_contig, &start, &end, &sub_iter, &self_dim_size, &selfSlice_data, &self_stride_bytes,
1371 &resultSlice_data, &result_stride_bytes] () {
1372 auto index_data = index_contig.const_data_ptr<index_t>();
1373 for (const auto i : c10::irange(start, end)) {
1374 auto self_i = index_data[i];
1375 TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
1376 auto self_data = static_cast<const char*>(selfSlice_data) + self_i * self_stride_bytes;
1377 auto result_data = static_cast<char*>(resultSlice_data) + i * result_stride_bytes;
1378 sub_iter.unsafe_replace_operand(0, result_data);
1379 sub_iter.unsafe_replace_operand(1, const_cast<char*>(self_data));
1380 copy_stub(sub_iter.device_type(), sub_iter, false);
1381 };
1382 });
1383 };
1384
1385 // parallelize the inner loop (the per-slice copy) when the slice is large enough;
1386 // otherwise parallelize the outer loop over indices
1387 if (slice_size >= grain_size) {
1388 outer_loop(0, numel);
1389 } else {
1390 // use a fast loop when self and result are contiguous and of the same data type
1391 if (iter.is_contiguous() && self.scalar_type() == result.scalar_type()) {
1392 auto slice_size_bytes = slice_size * elementSize(self.scalar_type());
1393 // explicitly capture all required variables to work around windows build
1394 // TODO: fix this when windows can correctly capture variables in nested lambda
1395 at::parallel_for(0, numel, grain_size / slice_size,
1396 [&index_contig, &slice_size_bytes, &self_dim_size, &selfSlice_data,
1397 &self_stride_bytes, &resultSlice_data, &result_stride_bytes](int64_t start, int64_t end) {
1398 AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_",
1399 [&index_contig, &slice_size_bytes, &self_dim_size, &selfSlice_data,
1400 &self_stride_bytes, &resultSlice_data, &result_stride_bytes, &start, &end] () {
1401 auto index_data = index_contig.const_data_ptr<index_t>();
1402 for (const auto i : c10::irange(start, end)) {
1403 auto self_i = index_data[i];
1404 TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
1405 auto self_data = static_cast<const char*>(selfSlice_data) + self_i * self_stride_bytes;
1406 auto result_data = static_cast<char*>(resultSlice_data) + i * result_stride_bytes;
1407 memcpy(result_data, self_data, slice_size_bytes);
1408 }
1409 });
1410 });
1411 } else {
1412 at::parallel_for(0, numel, grain_size / slice_size, outer_loop);
1413 }
1414 }
1415 } else {
1416 TORCH_CHECK(result.dim() <= 1, "result.dim() (", result.dim(), ") must be one or zero for given self.dim() (", self.dim(), ")");
1417 // explicitly capture all required variables to work around windows build
1418 // TODO: fix this when windows can correctly capture variables in nested lambda
1419 if(self.is_quantized()){
1420 AT_DISPATCH_QINT_TYPES(self.scalar_type(), "index_select_quant", [&index_contig, &self, &result, &dim, &numel] {
1421 auto self_stride = self.dim() == 0 ? 1 : self.stride(dim);
1422 auto result_stride = result.dim() == 0 ? 1 : result.stride(dim);
1423 auto self_data_ptr = self.const_data_ptr<scalar_t>();
1424 auto result_data_ptr = result.data_ptr<scalar_t>();
1425 auto self_numel = self.numel();
1426 AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_quant_",
1427 [&index_contig, &numel, &self_numel, &self_data_ptr, &self_stride, &result_data_ptr, &result_stride] {
1428 auto index_data = index_contig.const_data_ptr<index_t>();
1429 for (const auto i : c10::irange(numel)) {
1430 auto self_i = index_data[i];
1431 TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_numel), "index out of range in self");
1432 const scalar_t *self_ip = self_data_ptr + self_i * self_stride;
1433 *(result_data_ptr + i * result_stride) = *self_ip;
1434 }
1435 });
1436 });
1437 } else {
1438 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(ScalarType::ComplexHalf, ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16,
1439 self.scalar_type(), "index_select", [&index_contig, &self, &result, &dim, &numel] {
1440 auto self_stride = self.dim() == 0 ? 1 : self.stride(dim);
1441 auto result_stride = result.dim() == 0 ? 1 : result.stride(dim);
1442
1443 auto self_data_ptr = self.const_data_ptr<scalar_t>();
1444 auto result_data_ptr = result.data_ptr<scalar_t>();
1445 auto self_numel = self.numel();
1446 AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_",
1447 [&index_contig, &numel, &self_numel, &self_data_ptr, &self_stride, &result_data_ptr, &result_stride] {
1448 auto index_data = index_contig.const_data_ptr<index_t>();
1449 for (const auto i : c10::irange(numel)) {
1450 auto self_i = index_data[i];
1451 TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_numel), "index out of range in self");
1452 const scalar_t *self_ip = self_data_ptr + self_i * self_stride;
1453 *(result_data_ptr + i * result_stride) = *self_ip;
1454 }
1455 });
1456 });
1457 }
1458 }
1459
1460 return result;
1461 }
1462
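// Example usage (an illustrative sketch; values are hypothetical):
//
//   auto x   = at::arange(12, at::kFloat).reshape({4, 3});
//   auto idx = at::tensor({2, 0}, at::kLong);
//   auto y   = at::index_select(x, /*dim=*/0, idx);  // shape [2, 3]: rows 2 and 0 of x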
1463 Tensor index_select_cpu_(const Tensor & self, int64_t dim, const Tensor & index) {
1464 Tensor result = at::empty({0}, self.options());
1465 return at::native::index_select_out_cpu_(self, dim, index, result);
1466 }
1467
1468 Tensor index_select_quantized_cpu_(const Tensor & self, int64_t dim, const Tensor & index) {
1469 TORCH_CHECK(self.qscheme() == kPerTensorAffine,
1470 "Only per_tensor quantized quantized tensors are supported by index_select.")
1471 Tensor result = at::empty_quantized({0}, self);
1472 return at::native::index_select_out_cpu_(self, dim, index, result);
1473 }
1474
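// Sketch of the relation implemented below: index_select reads rows of `self`,
// so its backward scatters `grad` back into the positions it read from,
// accumulating wherever the same index appears more than once:
//
//   grad_self = zeros(self_sizes).index_add(dim, index, grad)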
1475 Tensor index_select_backward_symint(const Tensor& grad, c10::SymIntArrayRef self_sizes, int64_t dim, const Tensor& index) {
1476 // for composite compliance, use out-of-place variant of
1477 // `index_add` if index tensor is a Tensor Subclass.
1478 if (isTensorSubclassLike(index)) {
1479 return grad.new_zeros_symint(self_sizes, grad.options()).index_add(dim, index, grad);
1480 }
1481 return grad.new_zeros_symint(self_sizes, grad.options()).index_add_(dim, index, grad);
1482 }
1483
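// Example usage (an illustrative sketch; values are hypothetical):
//
//   auto x = at::zeros({3, 4});
//   x.index_fill_(/*dim=*/0, at::tensor({0, 2}, at::kLong), -1.0);
//   // rows 0 and 2 of x are now -1; row 1 is unchanged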
1484 Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, const Scalar& source) {
1485 at::NoNamesGuard guard;
1486
1487 TORCH_CHECK_INDEX(
1488 index.scalar_type() == ScalarType::Long,
1489 "index_fill_(): Expected dtype int64 for index.");
1490
1491 at::assert_no_overlap(self, index);
1492 if (at::has_internal_overlap(self) == at::MemOverlap::Yes) {
1493 TORCH_WARN(
1494 "Use of index_fill_ on expanded tensors is deprecated. "
1495 "Please clone() the tensor before performing this operation. "
1496 "This also applies to advanced indexing e.g. tensor[mask] = scalar");
1497 }
1498
1499 if (!self.is_complex() && source.isComplex()) {
1500 TORCH_CHECK(false, "index_fill_(): Converting complex Scalar to non-complex type is not supported");
1501 }
1502
1503 // Handle the case when `self` is 0-dim
1504 Tensor self_nonzero_dim = (self.dim() == 0) ? self.unsqueeze(-1) : self;
1505
1506 dim = at::maybe_wrap_dim(dim, self_nonzero_dim);
1507 TORCH_CHECK(index.dim() <= 1, "Index has to be a vector/scalar");
1508
1509 // Prepare `index` for TensorIterator.
1510 // It is restrided to be broadcastable over `self` in TensorIterator.
1511 auto index_sizes = std::vector<int64_t>(self_nonzero_dim.dim(), 1);
1512 auto index_strides = std::vector<int64_t>(self_nonzero_dim.dim(), 0);
1513 index_sizes[dim] = index.numel();
1514 index_strides[dim] = (index.dim() > 0) ? index.stride(0) : 1; // `index` is 1d or scalar
1515 auto index_restrided = index.as_strided(
1516 index_sizes, index_strides);
1517
1518 // Prepare `self` for TensorIterator.
1519 // Restride `self` to not advance in dimension `dim`.
1520 // We do not use squash_dim here because `index` will
1521 // need to advance in this dimension.
1522 // Note that self_sizes[dim] is set to index.numel().
1523 // This is done so that self_sizes[dim] and index_sizes[dim]
1524 // match as required by TensorIterator (input shape should
1525 // strictly broadcast over output shape, i.e.
1526 // output.shape[i] >= input.shape[i] for i in range(dims)).
1527 auto self_sizes = self_nonzero_dim.sizes().vec();
1528 auto self_strides = self_nonzero_dim.strides().vec();
1529 self_sizes[dim] = index.numel();
1530 self_strides[dim] = 0;
1531 auto self_restrided = self_nonzero_dim.as_strided(self_sizes, self_strides);
1532
1533 auto iter = TensorIteratorConfig()
1534 // We do not check for overlap because `self` is restrided
1535 // with zero stride. Zero strides trigger memory overlap assert
1536 // within TensorIterator.
1537 .set_check_mem_overlap(false)
1538 .check_all_same_dtype(false)
1539 .resize_outputs(false)
1540 .add_output(self_restrided)
1541 .add_const_input(index_restrided)
1542 .build();
1543
1544 auto self_dim_size = (self_nonzero_dim.sizes())[dim];
1545 auto self_dim_stride = (self_nonzero_dim.strides())[dim];
1546 index_fill_stub(
1547 iter.device_type(),
1548 iter,
1549 dim,
1550 self_dim_size,
1551 self_dim_stride,
1552 source);
1553
1554 return self;
1555 }
1556
1557 Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
1558 TORCH_CHECK(source.dim() == 0, "index_fill_ only supports a 0-dimensional value tensor, but got tensor "
1559 "with ", source.dim(), " dimension(s).");
1560 return self.index_fill_(dim, index, source.item());
1561 }
1562
1563 Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, const Scalar& source) {
1564 return self.clone(at::MemoryFormat::Preserve).index_fill_(dim, index, source);
1565 }
1566
1567 Tensor index_fill(const Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
1568 return self.clone(at::MemoryFormat::Preserve).index_fill_(dim, index, source);
1569 }
1570
1571 // fast paths for GNN usage
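// An index qualifies for this path when it looks like a 1-D index expanded
// over the feature dimension(s), e.g. (a sketch; E, C, N are placeholder sizes):
//
//   auto idx = at::randint(0, N, {E}, at::kLong)
//                  .view({E, 1}).expand({E, C});   // sizes [E, C], strides [1, 0]
//
// which matches the stride pattern checked below (stride 1 on dim 0, strides
// 0 or 1 elsewhere), with dim == 0 and contiguous self/src.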
1572 static bool can_use_expanded_index_path(
1573 const Tensor& self,
1574 int64_t dim,
1575 const Tensor& index,
1576 const Tensor& src,
1577 bool is_scatter_like) {
1578 #ifdef USE_FBGEMM
1579 if (!fbgemm::is_radix_sort_accelerated_with_openmp()) {
1580 return false;
1581 }
1582 #else
1583 return false;
1584 #endif
1585
1586 if (!self.device().is_cpu()) {
1587 return false;
1588 }
1589
1590 const auto st = self.scalar_type();
1591 if (!(c10::isFloatingType(st))) {
1592 return false;
1593 }
1594
1595 // skip when any tensor is empty
1596 if (self.numel() == 0 || index.numel() == 0 || src.numel() == 0) {
1597 return false;
1598 }
1599
1600 // skip when any tensor is a scalar (0-dim)
1601 if (self.ndimension() == 0 || index.ndimension() == 0 || src.ndimension() == 0) {
1602 return false;
1603 }
1604
1605 // src and index may differ in size only along dim 0
1606 // https://github.com/pytorch/pytorch/issues/99595
1607 for (const auto dim : c10::irange(1, index.dim())) {
1608 if (src.size(dim) != index.size(dim)) {
1609 return false;
1610 }
1611 }
1612
1613 if (is_scatter_like) {
1614 // using `spmm` for scatter would require sorting on index;
1615 // this is only beneficial for performance when the inner dimension,
1616 // aka `channels`, is big enough.
1617 constexpr int64_t threshold = 16;
1618 if (index.numel() / index.size(0) < threshold) {
1619 return false;
1620 }
1621 }
1622
1623 // usually the expanded index has stride on the first dimension to be 1,
1624 // and strides on other dims to be 0 or 1, e.g.
1625 // shape [108365, 16]; strides [1, 0]
1626 // shape [13264, 1, 7]; strides [1, 1, 0]
1627 auto index_strides = index.strides().vec();
1628 bool is_index_expanded = index_strides[0] == 1;
1629 for (const auto dim : c10::irange(1, index_strides.size())) {
1630 if (index_strides[dim] > 1) { is_index_expanded = false; }
1631 }
1632
1633 // index is expanded
1634 return dim == 0 && is_index_expanded && src.is_contiguous() && self.is_contiguous();
1635 }
1636
1637 // gather_out_cpu_cuda
1638 TORCH_IMPL_FUNC(gather_out)
1639 (const Tensor& self, int64_t dim, const Tensor& index, bool sparse_grad, const Tensor& result) {
1640 if (index.numel() == 0) return;
1641 dim = at::maybe_wrap_dim(dim, self.dim());
1642 if (can_use_expanded_index_path(result, dim, index, self, /*is_scatter_like=*/false)) {
1643 gather_expanded_index_stub(result.device().type(), result, self, index);
1644 } else {
1645 gather_stub(result.device().type(), result, self, dim, index);
1646 }
1647 }
1648
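// Sketch of the dense backward computed below (gather along dim 0):
//
//   out[i][j] = self[index[i][j]][j]
//   => grad_self = zeros_like(self).scatter_add_(0, index, grad)
//
// i.e. every gathered gradient element is accumulated back into the slot it
// was read from.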
1649 Tensor gather_backward(const Tensor& grad, const Tensor& self, int64_t dim, const Tensor& index, bool sparse_grad) {
1650 if (sparse_grad) {
1651 return at::_gather_sparse_backward(self, dim, index, grad);
1652 }
1653 auto result = grad.new_zeros_symint(self.sym_sizes());
1654 // for composite, vmap and inductor compliance, use out-of-place variant of
1655 // `scatter_add` if index or grad tensors is a Tensor Subclass.
1656 if (areAnyTensorSubclassLike({index, grad})) {
1657 return result.scatter_add(dim, index, grad);
1658 }
1659 result.scatter_add_(dim, index, grad);
1660 return result;
1661 }
1662
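// When include_self is false, the destination slots that will receive scatter
// contributions are first overwritten with the reduction's identity element
// (a sketch of the mapping used below):
//
//   SUM/MEAN -> 0, PROD -> 1, MAX -> -inf (or lowest()), MIN -> +inf (or max())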
1663 static void scatter_reduce_exclude_self_helper(
1664 const Tensor& self,
1665 int64_t dim,
1666 const Tensor& index,
1667 const ReductionType& op) {
1668 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
1669 at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool,
1670 self.scalar_type(), "scatter_reduce_exclude_input_init", [&] {
1671 scalar_t init_val;
1672 switch (op) {
1673 case ReductionType::SUM:
1674 init_val = (scalar_t)0;
1675 break;
1676 case ReductionType::PROD:
1677 init_val = (scalar_t)1;
1678 break;
1679 case ReductionType::MAX:
1680 init_val = std::numeric_limits<scalar_t>::has_infinity ? -std::numeric_limits<scalar_t>::infinity()
1681 : std::numeric_limits<scalar_t>::lowest();
1682 break;
1683 case ReductionType::MIN:
1684 init_val = std::numeric_limits<scalar_t>::has_infinity ? std::numeric_limits<scalar_t>::infinity()
1685 : std::numeric_limits<scalar_t>::max();
1686 break;
1687 case ReductionType::MEAN:
1688 init_val = (scalar_t)0;
1689 break;
1690 }
1691 self.scatter_(dim, index, init_val);
1692 });
1693 }
1694
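// Deterministic scatter fallback: rather than relying on atomic adds, the
// scatter is rewritten as an index_put_ on a flattened output. A sketch for a
// 2-D case with dim == 1 (and a contiguous output):
//
//   index_coords[i][j]  = (i, index[i][j])
//   index_flat[i*J + j] = i * out.stride(0) + index[i][j] * out.stride(1)
//   out_flat.index_put_({index_flat}, src_flat, /*accumulate=*/accumulate)
//
// index_put_ with accumulate=true has a deterministic implementation, which is
// what makes this path usable under deterministic mode.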
1695 static void _scatter_via_index_put(
1696 const Tensor& self,
1697 int64_t dim,
1698 const Tensor& index,
1699 const Tensor& src,
1700 const Tensor& mut_out,
1701 bool accumulate) {
1702 if (self.dim() == 1) {
1703 torch::List<std::optional<Tensor>> indices;
1704 indices.reserve(1);
1705 indices.push_back(index);
1706 mut_out.index_put_(indices, src, accumulate);
1707 } else {
1708 Tensor mut_out_contig = mut_out.contiguous();
1709
1710 auto index_coords_sizes = index.sizes().vec();
1711 index_coords_sizes.push_back(self.dim());
1712 auto index_coords = at::empty(
1713 index_coords_sizes,
1714 at::TensorOptions().dtype(at::ScalarType::Long).device(self.device()));
1715
1716 for (int64_t dim_other = 0; dim_other < self.dim(); dim_other++) {
1717 if (dim_other == dim) {
1718 continue;
1719 }
1720 auto dim_coord_vals = at::arange(
1721 index.size(dim_other),
1722 at::TensorOptions().device(self.device()));
1723
1724 for (int64_t dim_unsqueeze = 0; dim_unsqueeze < self.dim() - 1; dim_unsqueeze++) {
1725 dim_coord_vals = dim_coord_vals.unsqueeze((dim_unsqueeze >= dim_other) ? -1 : 0);
1726 }
1727
1728 auto view_sizes = index.sizes().vec();
1729 view_sizes.push_back(1);
1730 auto view_strides = index_coords.strides().vec();
1731 view_strides[self.dim()] = self.dim();
1732
1733 at::as_strided(
1734 index_coords,
1735 view_sizes,
1736 view_strides,
1737 dim_other
1738 ).copy_(dim_coord_vals.unsqueeze(-1));
1739 }
1740
1741 auto view_sizes = index.sizes().vec();
1742 view_sizes.push_back(1);
1743 auto view_strides = index_coords.strides().vec();
1744 view_strides[self.dim()] = self.dim();
1745
1746 at::as_strided(
1747 index_coords,
1748 view_sizes,
1749 view_strides,
1750 dim
1751 ).copy_(index.unsqueeze(-1));
1752
1753 Tensor index_coords_flat = index_coords.flatten(0, -2);
1754
1755 // Copy mut_out_contig's strides into a tensor
1756 // TODO: Is there a utility function that already does this?
1757 IntArrayRef mut_out_contig_strides = mut_out_contig.strides();
1758 Tensor coord_strides = at::empty(
1759 {mut_out_contig.dim()},
1760 TensorOptions().dtype(at::ScalarType::Long).device(at::kCPU));
1761 std::memcpy(
1762 coord_strides.mutable_data_ptr(),
1763 mut_out_contig_strides.data(),
1764 coord_strides.nbytes());
1765 coord_strides = coord_strides.to(mut_out_contig.device());
1766
1767 // `index_flat` contains the 1-D indices corresponding to the
1768 // flattened `mut_out`
1769 Tensor index_flat = (index_coords_flat * coord_strides).sum({-1});
1770 Tensor mut_out_flat = mut_out_contig.flatten();
1771 Tensor src_flat = at::as_strided(
1772 src,
1773 index.sizes(),
1774 src.strides()
1775 ).flatten();
1776
1777 torch::List<std::optional<Tensor>> indices;
1778 indices.reserve(1);
1779 indices.push_back(index_flat);
1780
1781 mut_out_flat.index_put_(indices, src_flat, accumulate);
1782
1783 if (!mut_out.is_contiguous()) {
1784 mut_out.copy_(mut_out_flat.reshape(mut_out.sizes()));
1785 }
1786 }
1787 }
1788
1789 template <bool use_new_options = false, typename T, typename ReduceStub, typename FillStub>
1790 void scatter_impl(
1791 const Tensor& self,
1792 int64_t dim,
1793 const Tensor& index,
1794 const T& src,
1795 const Tensor& out,
1796 ReduceStub& reduce_stub,
1797 FillStub& fill_stub,
1798 const std::optional<c10::string_view> reduce = std::nullopt,
1799 bool reduce_includes_self = true) {
1800
1801 dim = at::maybe_wrap_dim(dim, self.dim());
1802 auto mut_out = const_cast<Tensor&>(out);
1803
1804 if (!self.is_same(mut_out)) {
1805 mut_out.copy_(self);
1806 }
1807
1808 if (index.numel() == 0) return;
1809
1810 auto op = ReductionType::SUM;
1811 bool deterministic = globalContext().deterministicAlgorithms() && (self.device().type() == DeviceType::CUDA || self.device().type() == DeviceType::XPU);
1812
1813 if (reduce.has_value()) {
1814 op = get_operator_enum(reduce.value(), use_new_options);
1815 if (!reduce_includes_self) {
1816 // scatter the reduction's identity value into the destination indices (used by scatter_reduce.two)
1817 scatter_reduce_exclude_self_helper(mut_out, dim, index, op);
1818 }
1819 // _scatter_via_index_put can only handle the sum and mean reduction types
1820 deterministic = deterministic && (op == ReductionType::SUM || op == ReductionType::MEAN);
1821 }
1822
1823 // Scalar src should already be deterministic
1824 if (deterministic && std::is_same_v<T, Tensor>) {
1825 // both runtime and compile check are required
1826 if constexpr (std::is_same_v<T, Tensor>) {
1827 bool accumulate = reduce.has_value();
1828 _scatter_via_index_put(self, dim, index, src, mut_out, accumulate);
1829 return;
1830 }
1831 }
1832
1833 if (reduce.has_value()) {
1834 reduce_stub(self.device().type(), mut_out, dim, index, src, op);
1835 } else {
1836 fill_stub(self.device().type(), mut_out, dim, index, src);
1837 }
1838 }
1839
1840 TORCH_IMPL_FUNC(scatter_src_out)
1841 (const Tensor& self,
1842 int64_t dim,
1843 const Tensor& index,
1844 const Tensor& src,
1845 const Tensor& out) {
1846 scatter_impl(self, dim, index, src, out,
1847 scatter_reduce_stub,
1848 scatter_stub);
1849 }
1850
1851 TORCH_IMPL_FUNC(scatter_value_out)
1852 (const Tensor& self,
1853 int64_t dim,
1854 const Tensor& index,
1855 const Scalar& value,
1856 const Tensor& out) {
1857 scatter_impl(self, dim, index, value, out,
1858 scatter_scalar_reduce_stub,
1859 scatter_fill_stub);
1860 }
1861
1862 TORCH_IMPL_FUNC(scatter_reduce_out)
1863 (const Tensor& self,
1864 int64_t dim,
1865 const Tensor& index,
1866 const Tensor& src,
1867 const c10::string_view reduce,
1868 const Tensor& out) {
1869 scatter_impl(self, dim, index, src, out,
1870 scatter_reduce_stub,
1871 scatter_stub,
1872 reduce);
1873 }
1874
1875 TORCH_IMPL_FUNC(scatter_value_reduce_out)
1876 (const Tensor& self,
1877 int64_t dim,
1878 const Tensor& index,
1879 const Scalar& value,
1880 const c10::string_view reduce,
1881 const Tensor& out) {
1882 scatter_impl(self, dim, index, value, out,
1883 scatter_scalar_reduce_stub,
1884 scatter_fill_stub,
1885 reduce);
1886 }
1887
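// Sketch of the semantics implemented below (dim == 0 case):
//
//   out[index[i][j]][j] += src[i][j]
//
// Duplicate indices accumulate; on CUDA with deterministic algorithms enabled
// the index_put_-based fallback above is used instead of atomic adds.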
1888 TORCH_IMPL_FUNC(scatter_add)
1889 (const Tensor& self,
1890 int64_t dim,
1891 const Tensor& index,
1892 const Tensor& src,
1893 const Tensor& out) {
1894 auto mut_out = const_cast<Tensor&>(out);
1895 dim = maybe_wrap_dim(dim, self.dim());
1896
1897 if (!self.is_same(mut_out)) {
1898 mut_out.copy_(self);
1899 }
1900
1901 if (index.numel() == 0) return;
1902
1903 // See Note [Enabling Deterministic Operations]
1904 // Avoid gpuAtomicAdd for CUDA if deterministic mode is turned on
1905 if (globalContext().deterministicAlgorithms() && self.device().type() == DeviceType::CUDA) {
1906 _scatter_via_index_put(self, dim, index, src, mut_out, /*accumulate*/true);
1907 } else {
1908 if (can_use_expanded_index_path(mut_out, dim, index, src, /*is_scatter_like*/true)) {
1909 scatter_add_expanded_index_stub(self.device().type(), mut_out, index, src);
1910 } else {
1911 scatter_add_stub(self.device().type(), mut_out, dim, index, src);
1912 }
1913 }
1914 }
1915
1916 TORCH_IMPL_FUNC(scatter_reduce_two)
1917 (const Tensor& self,
1918 int64_t dim,
1919 const Tensor& index,
1920 const Tensor& src,
1921 const c10::string_view reduce,
1922 bool include_self,
1923 const Tensor& out) {
1924
1925 dim = at::maybe_wrap_dim(dim, self.dim());
1926
1927 if (!self.is_same(out)) {
1928 out.copy_(self);
1929 }
1930
1931 const auto op = get_operator_enum(reduce, true);
1932
1933 if (can_use_expanded_index_path(out, dim, index, src, /*is_scatter_like*/true)) {
1934 scatter_reduce_expanded_index_stub(self.device().type(), out, index, src, op, include_self);
1935 return;
1936 }
1937
1938 scatter_impl</*use_new_options=*/true>(self, dim, index, src, out,
1939 scatter_reduce_two_stub,
1940 scatter_stub,
1941 reduce,
1942 include_self);
1943
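// Sketch of the mean finalization below: `count[k]` holds the number of src
// elements scattered into out[k] (plus one when include_self is true), with
// empty slots clamped to 1, so out becomes scattered_sum / count. Integer
// outputs use floor division to stay in the same dtype.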
1944 if (op == ReductionType::MEAN) {
1945 auto ones = at::ones_like(src);
1946 auto count = include_self ? at::ones_like(out) : at::zeros_like(out);
1947 count.scatter_add_(dim, index, ones);
1948 count.masked_fill_(count == 0, 1);
1949
1950 if (out.is_floating_point() || out.is_complex()) {
1951 out.div_(count);
1952 } else {
1953 out.div_(count, "floor");
1954 }
1955 }
1956 }
1957
1958 Tensor masked_scatter(const Tensor & self, const Tensor & mask, const Tensor & source) {
1959 auto [_mask, _self] = expand_outplace(mask, self);
1960 return _self->clone(at::MemoryFormat::Contiguous).masked_scatter_(*_mask, source);
1961 }
1962
1963 Tensor masked_scatter_backward_symint(
1964 const Tensor& grad,
1965 const Tensor& mask,
1966 c10::SymIntArrayRef sizes) {
1967 c10::SymInt numel = 1;
1968 for (const auto& size : sizes) {
1969 numel *= size;
1970 }
1971 auto mask_selected = grad.masked_select(mask);
1972 auto diff_nelem = numel - mask_selected.sym_numel();
1973 if (diff_nelem > 0) {
1974 // because mask_selected is a 1-d tensor holding only the elements where the
1975 // mask is true, we need to pad the rest with zeros and then reshape back to
1976 // the requested `sizes`.
1977 auto zeros_fillin =
1978 at::zeros_symint({std::move(diff_nelem)}, grad.options());
1979 mask_selected = at::cat({mask_selected, std::move(zeros_fillin)}, 0);
1980 }
1981 return mask_selected.view_symint(sizes);
1982 }
1983
1984 static Tensor & masked_fill_impl_cpu(Tensor & self, const Tensor & mask, const Scalar& value) {
1985 NoNamesGuard guard;
1986 TORCH_CHECK(mask.dtype() == ScalarType::Bool, "masked_fill_ only supports boolean masks, but got mask "
1987 "with dtype ", mask.dtype());
1988
1989 if (at::has_internal_overlap(self) == MemOverlap::Yes) {
1990 TORCH_WARN(
1991 "Use of masked_fill_ on expanded tensors is deprecated. "
1992 "Please clone() the tensor before performing this operation. "
1993 "This also applies to advanced indexing e.g. tensor[mask] = scalar");
1994 }
1995 at::assert_no_partial_overlap(self, mask);
1996
1997 auto iter = TensorIteratorConfig()
1998 .set_check_mem_overlap(false) // deprecated, but not a hard error
1999 .check_all_same_dtype(false)
2000 .resize_outputs(false)
2001 .add_output(self)
2002 .add_const_input(mask)
2003 .build();
2004
2005 masked_fill_stub(iter.device_type(), iter, value);
2006 return self;
2007 }
2008
2009 Tensor & masked_fill__cpu(Tensor& self, const Tensor & mask, const Scalar& value) {
2010 auto maybe_outnames = namedinference::broadcast_to_outnames(self, mask, "masked_fill_");
2011
2012 masked_fill_impl_cpu(self, mask, value);
2013 namedinference::propagate_names_if_nonempty(self, maybe_outnames);
2014 return self;
2015 }
2016
2017 Tensor & masked_fill__cpu(Tensor& self, const Tensor & mask, const Tensor & value) {
2018 auto maybe_outnames = namedinference::broadcast_to_outnames(self, mask, "masked_fill_");
2019 TORCH_CHECK(value.dim() == 0, "masked_fill_ only supports a 0-dimensional value tensor, but got tensor "
2020 "with ", value.dim(), " dimension(s).");
2021
2022 masked_fill_impl_cpu(self, mask, value.item());
2023 namedinference::propagate_names_if_nonempty(self, maybe_outnames);
2024 return self;
2025 }
2026
2027 Tensor masked_fill(const Tensor & self, const Tensor & mask, const Scalar& source) {
2028 Tensor result;
2029 auto maybe_outnames = namedinference::broadcast_to_outnames(mask, self, "masked_fill");
2030 {
2031 NoNamesGuard guard;
2032 auto [_mask, _self] = expand_outplace(mask, self);
2033 result = _self->clone(at::MemoryFormat::Contiguous);
2034 result.masked_fill_(mask, source);
2035 }
2036 namedinference::propagate_names_if_nonempty(result, maybe_outnames);
2037 return result;
2038 }
2039
2040 Tensor masked_fill(const Tensor & self, const Tensor & mask, const Tensor & source) {
2041 Tensor result;
2042 auto maybe_outnames = namedinference::broadcast_to_outnames(mask, self, "masked_fill");
2043 {
2044 NoNamesGuard guard;
2045 auto [_mask, _self] = expand_outplace(mask, self);
2046 result = _self->clone(at::MemoryFormat::Contiguous);
2047 result.masked_fill_(mask, source);
2048 }
2049 namedinference::propagate_names_if_nonempty(result, maybe_outnames);
2050 return result;
2051 }
2052
2053 static Tensor & masked_select_out_impl_cpu(Tensor & result, const Tensor & self, const Tensor & mask) {
2054 NoNamesGuard guard;
2055
2056 TORCH_CHECK(mask.scalar_type() == ScalarType::Bool,
2057 "masked_select: expected BoolTensor for mask");
2058 TORCH_CHECK(self.scalar_type() == result.scalar_type(),
2059 "masked_select(): self and result must have the same scalar type");
2060
2061 at::assert_no_internal_overlap(result);
2062 at::assert_no_overlap(result, self);
2063 at::assert_no_overlap(result, mask);
2064
2065 auto [_mask, _self] = expand_outplace(mask, self);
2066
2067 auto shape = _self->sizes();
2068 int64_t numel = _mask->sum().item().toLong();
2069 at::native::resize_output(result, {numel});
2070 if (numel == 0) {
2071 return result;
2072 }
2073
2074 // Create strided view of result before feeding into TensorIterator
2075 auto strides = DimVector(shape.size(), 0);
2076 auto orig_stride = result.strides()[0];
2077 auto result_strided = result.as_strided(shape, strides);
2078
2079 // Serial kernel:
2080 // the serial kernel requires that src is traversed in its logical order. However, TensorIterator might
2081 // have reordered dimensions so that src would be traversed in its physical order, producing wrong
2082 // answers. A sufficient condition for no reordering is that both _self and _mask are contiguous.
2083 // If that is not satisfied, use the parallel kernel, which handles permutations correctly.
2084 bool use_serial_kernel = (self.numel() < at::internal::GRAIN_SIZE || at::get_num_threads() == 1 ) &&
2085 _self->is_contiguous() && _mask->is_contiguous();
2086 if (use_serial_kernel) {
2087 auto iter = TensorIteratorConfig()
2088 .set_check_mem_overlap(false) // result is intentionally zero-strided above
2089 .check_all_same_dtype(false)
2090 .resize_outputs(false)
2091 .add_output(result_strided)
2092 .add_const_input(*_self)
2093 .add_const_input(*_mask)
2094 .build();
2095
2096 masked_select_serial_stub(iter.device_type(), iter, orig_stride);
2097 return result;
2098 }
2099
2100 // Use a prefix sum to record the output locations of the masked elements,
2101 // so that the write can be parallelized with TensorIterator.
2102 auto mask_long = at::empty(shape, self.options().dtype(at::kLong)).copy_(*_mask);
2103 auto mask_prefix_sum = at::empty(shape, self.options().dtype(at::kLong));
2104 auto mask_long_data = mask_long.data_ptr<int64_t>();
2105 auto mask_prefix_sum_data = mask_prefix_sum.data_ptr<int64_t>();
2106 // TODO: only std::partial_sum is available under C++14;
2107 // switch to std::exclusive_scan, which performs better, once PyTorch requires C++17:
2108 // std::exclusive_scan(mask_long_data, mask_long_data + mask_long.numel(), mask_prefix_sum_data, 0);
2109 std::partial_sum(mask_long_data, mask_long_data + mask_long.numel(), mask_prefix_sum_data);
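// Sketch of the bookkeeping above: after the inclusive prefix sum,
// mask_prefix_sum[i] - 1 is the output slot for element i whenever the mask
// is true there, so each thread can write its hits independently.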
2110
2111 auto iter = TensorIteratorConfig()
2112 .set_check_mem_overlap(false) // result is intentionally zero-strided above
2113 .check_all_same_dtype(false)
2114 .resize_outputs(false)
2115 .add_output(result_strided)
2116 .add_const_input(*_self)
2117 .add_const_input(*_mask)
2118 .add_const_input(mask_prefix_sum)
2119 .build();
2120
2121 masked_select_stub(iter.device_type(), iter, orig_stride);
2122 return result;
2123 }
2124
2125 Tensor & masked_select_out_cpu(const Tensor & self, const Tensor & mask, Tensor & result) {
2126 namedinference::compute_broadcast_outnames(self, mask);
2127 return masked_select_out_impl_cpu(result, self, mask);
2128 }
2129
2130 Tensor masked_select_cpu(const Tensor & self, const Tensor & mask) {
2131 Tensor result = at::empty({0}, self.options());
2132 return at::native::masked_select_out_cpu(self, mask, result);
2133 }
2134
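// Sketch of the relation implemented below: masked_select reads input[mask]
// in row-major order, so the backward writes grad back into those positions:
//
//   grad_input = zeros_like(broadcast(input)).masked_scatter(mask, grad)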
2135 Tensor masked_select_backward(const Tensor& grad, const Tensor& input, const Tensor& mask) {
2136 // The following could just be written as `zeros_like(input).masked_scatter(mask, grad)`.
2137 // However, as an optimization, we call the in-place variant of masked_scatter.
2138 // Unfortunately, that doesn't allow for the broadcasting of the LHS, so we need
2139 // to explicitly broadcast here (the out-of-place variant of masked_scatter
2140 // implicitly handles broadcasting).
2141 auto result = at::zeros_like(
2142 input.expand(at::infer_size(input.sizes(), mask.sizes())), at::MemoryFormat::Preserve);
2143
2144 // for composite compliance, use out-of-place variant
2145 // of `masked_scatter`.
2146 if (areAnyTensorSubclassLike({grad, mask})) {
2147 return result.masked_scatter(mask, grad);
2148 }
2149 result.masked_scatter_(mask, grad);
2150 return result;
2151 }
2152
2153 namespace {
2154
2155 inline std::tuple<Tensor, Tensor, int64_t> _take_along_dim_helper(
2156 const Tensor& self,
2157 const Tensor& indices,
2158 int64_t dim) {
2159 TORCH_CHECK(
2160 self.dim() == indices.dim(),
2161 "torch.take_along_dim(): input and indices should have the same number of dimensions, ",
2162 "but got ", self.dim(), " dimensions for input, and ", indices.dim(), " dimensions for indices")
2163 TORCH_CHECK(
2164 indices.scalar_type() == ScalarType::Long,
2165 "torch.take_along_dim(): dtype of indices should be Long but got ", indices.scalar_type())
2166
2167 dim = at::maybe_wrap_dim(dim, self.dim());
2168
2169 SymDimVector self_sizes{self.sym_sizes()};
2170 // update number of elements at dim as per indices
2171 self_sizes[dim] = indices.sym_size(dim);
2172 auto broadcast_shape = infer_size_symint(self_sizes, indices.sym_sizes());
2173 auto indices_broadcasted = at::broadcast_to_symint(indices, broadcast_shape);
2174
2175 SymDimVector indices_sizes{indices.sym_sizes()};
2176 // update number of elements at dim as per self
2177 indices_sizes[dim] = self.sym_size(dim);
2178 broadcast_shape = infer_size_symint(indices_sizes, self.sym_sizes());
2179 auto self_broadcasted = at::broadcast_to_symint(self, broadcast_shape);
2180
2181 return std::make_tuple(std::move(self_broadcasted),
2182 std::move(indices_broadcasted),
2183 std::move(dim));
2184 }
2185
2186 static inline void checkDevice(CheckedFrom c, const Tensor& t, Device device) {
2187 TORCH_CHECK(
2188 !t.defined() || t.device() == device,
2189 "Expected tensor to have ", device,
2190 " Device, but got tensor with ", t.device(), " Device ",
2191 "(while checking arguments for ", c, ")");
2192 }
2193
2194 static inline void checkDevice(CheckedFrom c, at::ArrayRef<Tensor> tensors, Device device) {
2195 for (auto &t : tensors) {
2196 checkDevice(c, t, device);
2197 }
2198 }
2199
2200 } // anonymous namespace
2201
2202 Tensor take_along_dim(const Tensor& self, const Tensor& indices, std::optional<int64_t> opt_dim) {
2203 checkDevice("torch.take_along_dim():", {self, indices}, self.device());
2204 if (opt_dim.has_value()) {
2205 auto [self_broadcasted, indices_broadcasted, dim] =
2206 _take_along_dim_helper(self, indices, opt_dim.value());
2207 return self_broadcasted.gather(dim, indices_broadcasted);
2208 }
2209
2210 // similar to `take`, but `take` doesn't support the same dtypes as `gather`.
2211 return self.view(-1).gather(0, indices.view(-1));
2212 }
2213
2214 Tensor& take_along_dim_out(const Tensor& self, const Tensor& indices, std::optional<int64_t> opt_dim, Tensor& result) {
2215 checkDevice("torch.take_along_dim():", {self, indices, result}, self.device());
2216 if (opt_dim.has_value()) {
2217 auto [self_broadcasted, indices_broadcasted, dim] =
2218 _take_along_dim_helper(self, indices, opt_dim.value());
2219 return at::gather_out(result, self_broadcasted, dim, indices_broadcasted);
2220 }
2221
2222 // similar to `take`, but `take` doesn't support the same dtypes as `gather`.
2223 return at::gather_out(result, self.view(-1), 0, indices.view(-1));
2224 }
2225
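// Sketch of the COO indices built below for a gather along `dim`:
//
//   sparse_ind[dim]   = index.reshape(-1)                  // gathered positions
//   sparse_ind[other] = arange(grad.size(other)) tiled out to grad.numel()
//
// so each value of `grad` lands at the coordinate of `self` it was gathered
// from, and the result is returned as an (uncoalesced) sparse tensor.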
2226 Tensor _gather_sparse_backward(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& grad){
2227 // special case scalar input and/or index
2228 if (self.ndimension() == 0) return at::_sparse_coo_tensor_unsafe_symint(at::empty_symint({0,grad.sym_numel()}, index.options()), grad, self.sym_sizes());
2229 if (grad.ndimension() == 0) return at::_sparse_coo_tensor_unsafe_symint(index.view({1,1}), grad, self.sym_sizes());
2230 Tensor sparse_ind = at::empty_symint({self.ndimension(), grad.sym_numel()}, self.options().dtype(at::kLong));
2231 SymInt grad_numel = grad.sym_numel();
2232 if (grad_numel > 0) {
2233 SymInt n_above = grad_numel;
2234 SymInt n_below = 1;
2235 if (dim < 0) dim += self.ndimension();
2236 for (const auto i : c10::irange(self.ndimension())) {
2237 n_above /= grad.sym_size(i);
2238 if (i == dim) {
2239 sparse_ind[i] = index.reshape(-1);
2240 } else {
2241 sparse_ind[i] = at::arange(grad.sym_size(i),self.options().dtype(at::kLong)).unsqueeze(1).expand_symint({grad.sym_size(i), n_above}).reshape(-1).repeat_symint(n_below);
2242 }
2243 n_below *= grad.sym_size(i);
2244 }
2245 }
2246 return at::_sparse_coo_tensor_unsafe_symint(sparse_ind, grad.reshape(-1), self.sym_sizes());
2247 }
2248
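// Sketch of the unrolling used below: four independent counters break the
// serial dependency on a single accumulator; e.g. for n == 10 the main loop
// covers i = 0..7 in groups of 4 and the tail loop folds i = 8, 9 into
// nonzero[0] before the counters are summed.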
2249 template <typename scalar_t>
2250 int64_t count_nonzero_impl(TensorIteratorBase& iter, Range range) {
2251 int64_t num_nonzero = 0;
2252
2253 auto loop = [&](char** data, const int64_t* strides, int64_t n) {
2254 constexpr int ilp_factor = 4;
2255 const char* ptr = data[0];
2256 const auto stride = strides[0];
2257 int64_t nonzero[ilp_factor] = {0};
2258
2259 int64_t i = 0;
2260 for (; i + (ilp_factor - 1) < n; i += ilp_factor) {
2261 c10::ForcedUnroll<ilp_factor>{}([&](int k) {
2262 const auto& val = c10::load<scalar_t>(ptr + k * stride);
2263 if (val != scalar_t(0)) {
2264 ++nonzero[k];
2265 }
2266 });
2267 ptr += ilp_factor * stride;
2268 }
2269 for (; i < n; ++i) {
2270 const auto& val = c10::load<scalar_t>(ptr);
2271 if (val != scalar_t(0)) {
2272 ++nonzero[0];
2273 }
2274 ptr += stride;
2275 }
2276 for (const auto k : c10::irange(1, ilp_factor)) {
2277 nonzero[0] += nonzero[k];
2278 }
2279 num_nonzero += nonzero[0];
2280 };
2281 iter.serial_for_each(loop, range);
2282
2283 return num_nonzero;
2284 }
2285
2286 Tensor count_nonzero_cuda(const Tensor& self, IntArrayRef dims){
2287 auto reduce = self;
2288 if (reduce.scalar_type() != kBool) {
2289 reduce = reduce != 0;
2290 }
2291 return reduce.sum(dims);
2292 }
2293
2294 Tensor count_nonzero_cpu(const Tensor& self, IntArrayRef dims){
2295 if (!dims.empty()) {
2296 auto reduce = self;
2297 if (reduce.scalar_type() != kBool) {
2298 reduce = reduce != 0;
2299 }
2300 return reduce.sum(dims);
2301 }
2302
2303 // Optimized all-reduce
2304 auto iter = TensorIteratorConfig()
2305 .add_const_input(self)
2306 .build();
2307
2308 const auto num_threads = at::get_num_threads();
2309 DimVector thread_count_nonzero(num_threads);
2310
2311 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
2312 kComplexHalf, kHalf, kBFloat16, kBool, self.scalar_type(), "nonzero_count_cpu", [&] {
2313 at::parallel_for(0, iter.numel(), internal::GRAIN_SIZE, [&] (int64_t begin, int64_t end) {
2314 const auto tid = at::get_thread_num();
2315 thread_count_nonzero[tid] = count_nonzero_impl<scalar_t>(iter, {begin, end});
2316 });
2317 });
2318
2319 for (const auto i : c10::irange(1, num_threads)) {
2320 thread_count_nonzero[0] += thread_count_nonzero[i];
2321 }
2322 auto out = at::empty({}, self.options().dtype(kLong));
2323 *out.mutable_data_ptr<int64_t>() = thread_count_nonzero[0];
2324 return out;
2325 }
2326
2327
2328 Tensor count_nonzero(const Tensor& self, std::optional<int64_t> dim) {
2329 if (dim) {
2330 return at::count_nonzero(self, IntArrayRef{*dim});
2331 }
2332 return at::count_nonzero(self, IntArrayRef{});
2333 }
2334
2335
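// Sketch of the two-pass scheme used below: pass 1 counts nonzero hits per
// thread chunk, the cumulative sum of those counts gives each thread its
// first output row, and pass 2 re-walks the same chunks writing indices from
// that row onward, so the output is ordered without inter-thread locking.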
2336 Tensor& nonzero_out_cpu(const Tensor& self, Tensor& result) {
2337 TORCH_CHECK(result.scalar_type() == kLong,
2338 "nonzero: Expected out tensor to have scalar type Long "
2339 "but got scalar type", result.scalar_type());
2340 at::assert_no_internal_overlap(result);
2341 at::assert_no_overlap(result, self);
2342
2343 auto iter = TensorIteratorConfig()
2344 .add_const_input(self)
2345 .enforce_linear_iteration()
2346 .build();
2347
2348 const auto numel = iter.numel();
2349 const auto num_threads = at::get_num_threads();
2350 DimVector thread_begin(num_threads, -1);
2351 DimVector thread_count_nonzero(num_threads + 1);
2352
2353 // Pass 1: Count nonzero elements per thread
2354 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
2355 kComplexHalf, kHalf, kBFloat16, kBool, self.scalar_type(), "nonzero_count_cpu", [&] {
2356 at::parallel_for(0, numel, internal::GRAIN_SIZE, [&] (int64_t begin, int64_t end) {
2357 const auto tid = at::get_thread_num();
2358 thread_begin[tid] = begin;
2359 thread_count_nonzero[tid + 1] = count_nonzero_impl<scalar_t>(iter, {begin, end});
2360 });
2361 });
2362
2363 // Convert thread-local counts to cumulative sum
2364 for (const auto i : c10::irange(1, thread_count_nonzero.size())) {
2365 thread_count_nonzero[i] += thread_count_nonzero[i - 1];
2366 }
2367
2368 const auto self_sizes = self.sizes();
2369 const auto total_nonzero = thread_count_nonzero.back();
2370 const int64_t ndim = self_sizes.size();
2371 if (resize_output(result, {total_nonzero, ndim})) {
2372 // Default to fortran-contiguous output (see gh-46224)
2373 result.as_strided_({total_nonzero, ndim}, {1, total_nonzero});
2374 }
2375
2376 if (result.numel() == 0) {
2377 return result;
2378 }
2379
2380 auto out_accessor = result.accessor<int64_t, 2>();
2381
2382 // Pass 2: Write indices
2383 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
2384 kComplexHalf, kHalf, kBFloat16, kBool, self.scalar_type(), "nonzero_cpu", [&] {
2385 at::parallel_for(0, numel, internal::GRAIN_SIZE, [&] (int64_t begin, int64_t end) {
2386 auto tid = at::get_thread_num();
2387 // Work needs to be distributed the same on both passes
2388 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(begin == thread_begin[tid]);
2389
2390 // The extra leading slot (+1) is faster than an additional bounds check inside the loop
2391 c10::SmallVector<int64_t, 33> sizes(ndim + 1, -1);
2392 std::copy(self_sizes.begin(), self_sizes.end(), sizes.begin() + 1);
2393 c10::SmallVector<int64_t, 33> current_idx(ndim + 1);
2394 if (begin > 0) {
2395 auto idx = begin;
2396 for (int64_t k = ndim; idx > 0 && k > 0; --k) {
2397 current_idx[k] = idx % sizes[k];
2398 idx /= sizes[k];
2399 }
2400 }
2401
2402 auto out_ptr = out_accessor[thread_count_nonzero[tid]].data();
2403
2404 auto loop = [&](char** data, const int64_t* strides, int64_t n1, int64_t n2) {
2405 // Copy into local variables to improve compiler alias analysis
2406 int64_t* C10_RESTRICT local_idx = current_idx.data() + 1;
2407 const int64_t* C10_RESTRICT local_sizes = sizes.data() + 1;
2408 const auto in_stride = strides[0];
2409 const auto out_stride1 = out_accessor.stride(1);
2410 const auto out_stride0 = out_accessor.stride(0) - ndim * out_stride1;
2411 const auto ndim = out_accessor.size(1);
2412 int64_t* out = out_ptr;
2413
2414 for (const auto i : c10::irange(n2)) {
2415 const char* ptr = data[0] + i * strides[1];
2416 for (C10_UNUSED const auto j : c10::irange(n1)) {
2417 const auto& val = c10::load<scalar_t>(ptr);
2418 // If nonzero, write index
2419 if (val != scalar_t(0)) {
2420 for (const auto k : c10::irange(ndim)) {
2421 *out = local_idx[k];
2422 out += out_stride1;
2423 }
2424 out += out_stride0;
2425 }
2426 ptr += in_stride;
2427
2428 // Advance current index
2429 int64_t k = ndim - 1;
2430 ++local_idx[k];
2431 while (C10_UNLIKELY(local_idx[k] == local_sizes[k])) {
2432 local_idx[k] = 0;
2433 --k;
2434 ++local_idx[k];
2435 }
2436 }
2437 }
2438 out_ptr = out;
2439 };
2440 iter.serial_for_each(loop, {begin, end});
2441 TORCH_INTERNAL_ASSERT(out_ptr == out_accessor[thread_count_nonzero[tid + 1]].data());
2442 });
2443 });
2444 return result;
2445 }
2446
2447 Tensor nonzero_cpu(const Tensor& self) {
2448 auto result = at::empty({0}, self.options().dtype(kLong));
2449 nonzero_out_cpu(self, result);
2450 return result;
2451 }
2452
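// Example (an illustrative sketch; values are hypothetical): nonzero_static
// pads or truncates the result to a fixed number of rows.
//
//   auto x = at::tensor({0, 3, 0, 5});
//   at::nonzero_static(x, /*size=*/3, /*fill_value=*/-1);
//   // -> [[1], [3], [-1]]   (two real hits, one padded row)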
2453 Tensor& nonzero_static_out_cpu(
2454 const Tensor& self,
2455 int64_t size,
2456 int64_t fill_value,
2457 Tensor& result) {
2458 // Check that `size` is non-negative
2459 TORCH_CHECK(
2460 size >= 0, "nonzero_static: 'size' must be a non-negative integer");
2461 TORCH_CHECK(
2462 result.scalar_type() == kLong,
2463 "nonzero_static: Expected out tensor to have scalar type Long "
2464 "but got scalar type",
2465 result.scalar_type());
2466
2467 int64_t ndim = self.dim();
2468 if (result.dim() != 2 || result.size(0) != size || result.size(1) != ndim) {
2469 at::native::resize_output(result, {size, ndim});
2470 }
2471 // Verify that the output tensor is resized to expected size=(size, ndim)
2472 TORCH_CHECK(
2473 result.dim() == 2,
2474 "nonzero_static: Expected out tensor to be a 2D tensor but got a ",
2475 result.dim(),
2476 "D tensor");
2477 TORCH_CHECK(
2478 result.size(0) == size && result.size(1) == ndim,
2479 "nonzero_static: Expected out tensor to have Size([",
2480 size,
2481 ", ",
2482 ndim,
2483 "]) but got Size([",
2484 result.size(0),
2485 ", ",
2486 result.size(1),
2487 "]) ");
2488 at::assert_no_internal_overlap(result);
2489 at::assert_no_overlap(result, self);
2490
2491 // Return early if either output dim is 0
2492 if (result.size(0) == 0 || result.size(1) == 0) {
2493 return result;
2494 }
2495
2496 // Delegate call to regular nonzero to get a data-dependent output
2497 auto dyn_result = nonzero_cpu(self);
2498 int64_t num_nonzeros = dyn_result.size(0);
2499 int64_t copy_len = std::min(size, num_nonzeros);
2500 // Copy the dynamic result to the fixed-size tensor
2501 result.narrow(0, 0, copy_len).copy_(dyn_result.narrow(0, 0, copy_len));
2502 if (size > copy_len) {
2503 // Pad result with `fill_value`
2504 result.narrow(0, copy_len, size - copy_len).fill_(fill_value);
2505 }
2506 return result;
2507 }
2508
2509 Tensor nonzero_static_cpu(
2510 const Tensor& self,
2511 int64_t size,
2512 int64_t fill_value) {
2513 // Check that `size` is non-negative
2514 TORCH_CHECK(
2515 size >= 0, "nonzero_static: 'size' must be a non-negative integer");
2516 // Allocate fixed-size out tensor
2517 int64_t ndim = self.dim();
2518 auto result = at::empty(
2519 {size, ndim},
2520 at::TensorOptions().dtype(at::ScalarType::Long).device(at::kCPU));
2521 nonzero_static_out_cpu(self, size, fill_value, result);
2522 return result;
2523 }
2524
2525 std::vector<Tensor> nonzero_numpy(const Tensor& self) {
2526 // special case scalar for compatibility with numpy:
2527 //
2528 // >>> np.array(5).nonzero()
2529 // (array([0]),)
2530 // >>> np.array(0).nonzero()
2531 // (array([], dtype=int64),)
2532
2533 if (self.dim() == 0) {
2534 return self.unsqueeze(0).nonzero().unbind(1);
2535 }
2536
2537 return self.nonzero().unbind(1);
2538 }
2539
2540 Tensor argwhere(const Tensor& self) {
2541 return self.nonzero();
2542 }
2543
2544 Tensor & masked_scatter__cpu(Tensor& self, const Tensor & mask, const Tensor & source) {
2545 at::assert_no_internal_overlap(self);
2546 TORCH_CHECK(
2547 self.scalar_type() == source.scalar_type(),
2548 "masked_scatter: expected self and source to have same dtypes but got",
2549 self.scalar_type(),
2550 " and ",
2551 source.scalar_type());
2552
2553 TORCH_CHECK(self.device().type() == at::kCPU, "device type of self (", self.device().type(), ") is not CPU");
2554 TORCH_CHECK(mask.device().type() == at::kCPU, "device type of mask (", mask.device().type(), ") is not CPU");
2555 TORCH_CHECK(source.device().type() == at::kCPU, "device type of source (", source.device().type(), ") is not CPU");
2556
2557 c10::MaybeOwned<Tensor> b_mask = expand_inplace(self, mask, "masked_scatter_");
2558
2559 if (b_mask->dtype() == ScalarType::Byte) {
2560 TORCH_WARN("masked_scatter_ received a mask with dtype torch.uint8, this behavior is now deprecated," \
2561 "please use a mask with dtype torch.bool instead.");
2562 }
2563
2564 auto src_cont = source.contiguous();
2565
2566 auto iter = TensorIteratorConfig()
2567 .set_check_mem_overlap(false)
2568 .check_all_same_dtype(false)
2569 .resize_outputs(false)
2570 // order of indexing matters
2571 .enforce_linear_iteration()
2572 .add_output(self)
2573 .add_const_input(*b_mask)
2574 .build();
2575
2576 masked_scatter_stub(iter.device_type(), iter, src_cont);
2577 return self;
2578 }
2579
2580 } // namespace at::native
2581