#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/core/DimVector.h>
#include <ATen/core/functional.h>
#include <ATen/core/IListRef.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/InferSize.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/TensorOperators.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/native/Copy.h>
#include <ATen/native/NonSymbolicBC.h>
#include <ATen/native/Resize.h>
#include <ATen/native/SparseTensorUtils.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TensorShape.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/native/cpu/CatKernel.h>
#include <ATen/native/cpu/SerialStackImpl.h>
#include <ATen/native/cpu/StackKernel.h>
#include <ATen/quantized/QTensorImpl.h>
#include <c10/util/Exception.h>
#include <optional>
#include <c10/util/SmallVector.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_chunk_cat_native.h>
#include <ATen/ops/_conj_copy_native.h>
#include <ATen/ops/_convert_indices_from_coo_to_csr.h>
#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
#include <ATen/ops/_foreach_copy.h>
#include <ATen/ops/_fw_primal_copy_native.h>
#include <ATen/ops/_indices_copy_native.h>
#include <ATen/ops/_make_dual.h>
#include <ATen/ops/_make_dual_copy_native.h>
#include <ATen/ops/_mkldnn_reshape.h>
#include <ATen/ops/_mkldnn_transpose.h>
#include <ATen/ops/_neg_view_copy_native.h>
#include <ATen/ops/_reshape_alias_copy_native.h>
#include <ATen/ops/_reshape_alias_native.h>
#include <ATen/ops/_reshape_copy_native.h>
#include <ATen/ops/_reshape_from_tensor_native.h>
#include <ATen/ops/_shape_as_tensor_native.h>
#include <ATen/ops/_sparse_broadcast_to.h>
#include <ATen/ops/_sparse_broadcast_to_copy_native.h>
#include <ATen/ops/_sparse_broadcast_to_native.h>
#include <ATen/ops/_sparse_compressed_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors.h>
#include <ATen/ops/_sparse_csc_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_csr_tensor_unsafe.h>
#include <ATen/ops/_sparse_csr_tensor_unsafe_native.h>
#include <ATen/ops/_stack_native.h>
#include <ATen/ops/_unsafe_view.h>
#include <ATen/ops/_unsafe_view_native.h>
#include <ATen/ops/_values_copy_native.h>
#include <ATen/ops/adjoint_native.h>
#include <ATen/ops/alias.h>
#include <ATen/ops/alias_copy_native.h>
#include <ATen/ops/alias_native.h>
#include <ATen/ops/arange.h>
#include <ATen/ops/arange_native.h>
#include <ATen/ops/as_strided_copy_native.h>
#include <ATen/ops/as_strided_native.h>
#include <ATen/ops/as_strided_scatter_native.h>
#include <ATen/ops/atleast_1d.h>
#include <ATen/ops/atleast_2d.h>
#include <ATen/ops/atleast_3d.h>
#include <ATen/ops/block_diag_native.h>
#include <ATen/ops/broadcast_tensors_native.h>
#include <ATen/ops/broadcast_to_native.h>
#include <ATen/ops/cat.h>
#include <ATen/ops/cat_meta.h>
#include <ATen/ops/cat_native.h>
#include <ATen/ops/chunk_native.h>
#include <ATen/ops/col_indices_copy_native.h>
#include <ATen/ops/column_stack_native.h>
#include <ATen/ops/concat_native.h>
#include <ATen/ops/concatenate_native.h>
#include <ATen/ops/crow_indices_copy_native.h>
#include <ATen/ops/dense_dim_native.h>
#include <ATen/ops/detach_copy_native.h>
#include <ATen/ops/detach_native.h>
#include <ATen/ops/diag.h>
#include <ATen/ops/diag_embed.h>
#include <ATen/ops/diag_embed_native.h>
#include <ATen/ops/diag_native.h>
#include <ATen/ops/diagflat_native.h>
#include <ATen/ops/diagonal.h>
#include <ATen/ops/diagonal_backward.h>
#include <ATen/ops/diagonal_backward_native.h>
#include <ATen/ops/diagonal_copy.h>
#include <ATen/ops/diagonal_copy_native.h>
#include <ATen/ops/diagonal_native.h>
#include <ATen/ops/diagonal_scatter_native.h>
#include <ATen/ops/dsplit_native.h>
#include <ATen/ops/dstack_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/empty_quantized.h>
#include <ATen/ops/expand_as_native.h>
#include <ATen/ops/expand_copy_native.h>
#include <ATen/ops/expand_native.h>
#include <ATen/ops/flatten_dense_tensors_native.h>
#include <ATen/ops/flatten_native.h>
#include <ATen/ops/from_blob.h>
#include <ATen/ops/hsplit_native.h>
#include <ATen/ops/hstack.h>
#include <ATen/ops/hstack_native.h>
#include <ATen/ops/index_select_native.h>
#include <ATen/ops/indices_copy_native.h>
#include <ATen/ops/lift_fresh_native.h>
#include <ATen/ops/lift_native.h>
#include <ATen/ops/mH_native.h>
#include <ATen/ops/mT_native.h>
#include <ATen/ops/matrix_H_native.h>
#include <ATen/ops/meshgrid_native.h>
#include <ATen/ops/moveaxis_native.h>
#include <ATen/ops/movedim.h>
#include <ATen/ops/movedim_native.h>
#include <ATen/ops/narrow.h>
#include <ATen/ops/narrow_copy.h>
#include <ATen/ops/narrow_copy_native.h>
#include <ATen/ops/narrow_native.h>
#include <ATen/ops/new_empty_native.h>
#include <ATen/ops/new_ones_native.h>
#include <ATen/ops/numpy_T_native.h>
#include <ATen/ops/permute_copy_native.h>
#include <ATen/ops/permute_native.h>
#include <ATen/ops/ravel_native.h>
#include <ATen/ops/repeat_native.h>
#include <ATen/ops/reshape_as_native.h>
#include <ATen/ops/reshape_native.h>
#include <ATen/ops/resize_native.h>
#include <ATen/ops/row_stack_native.h>
#include <ATen/ops/select.h>
#include <ATen/ops/select_backward_native.h>
#include <ATen/ops/select_copy_native.h>
#include <ATen/ops/select_native.h>
#include <ATen/ops/select_scatter_native.h>
#include <ATen/ops/set_native.h>
#include <ATen/ops/slice.h>
#include <ATen/ops/slice_backward_native.h>
#include <ATen/ops/slice_copy_native.h>
#include <ATen/ops/slice_inverse_native.h>
#include <ATen/ops/slice_native.h>
#include <ATen/ops/slice_scatter_native.h>
#include <ATen/ops/sparse_coo_tensor.h>
#include <ATen/ops/sparse_coo_tensor_native.h>
#include <ATen/ops/sparse_dim_native.h>
#include <ATen/ops/split_copy_native.h>
#include <ATen/ops/split_native.h>
#include <ATen/ops/split_with_sizes.h>
#include <ATen/ops/split_with_sizes_copy_native.h>
#include <ATen/ops/split_with_sizes_native.h>
#include <ATen/ops/squeeze_copy_native.h>
#include <ATen/ops/squeeze_native.h>
#include <ATen/ops/squeeze.h>
#include <ATen/ops/stack_native.h>
#include <ATen/ops/sub.h>
#include <ATen/ops/sum.h>
#include <ATen/ops/sum_to_size_native.h>
#include <ATen/ops/swapaxes_native.h>
#include <ATen/ops/swapdims_native.h>
#include <ATen/ops/t_copy_native.h>
#include <ATen/ops/t_native.h>
#include <ATen/ops/tensor.h>
#include <ATen/ops/tensor_split.h>
#include <ATen/ops/tensor_split_native.h>
#include <ATen/ops/tile_native.h>
#include <ATen/ops/transpose.h>
#include <ATen/ops/transpose_copy_native.h>
#include <ATen/ops/transpose_native.h>
#include <ATen/ops/unbind.h>
#include <ATen/ops/unbind_copy_native.h>
#include <ATen/ops/unbind_native.h>
#include <ATen/ops/unflatten_dense_tensors_native.h>
#include <ATen/ops/unflatten_native.h>
#include <ATen/ops/unfold_copy_native.h>
#include <ATen/ops/unfold_native.h>
#include <ATen/ops/unsafe_chunk_native.h>
#include <ATen/ops/unsafe_split_native.h>
#include <ATen/ops/unsafe_split_with_sizes_native.h>
#include <ATen/ops/unsqueeze_copy_native.h>
#include <ATen/ops/unsqueeze_native.h>
#include <ATen/ops/values_copy_native.h>
#include <ATen/ops/view_as_complex.h>
#include <ATen/ops/view_as_complex_copy_native.h>
#include <ATen/ops/view_as_native.h>
#include <ATen/ops/view_as_real.h>
#include <ATen/ops/view_as_real_copy_native.h>
#include <ATen/ops/view_copy_native.h>
#include <ATen/ops/view_native.h>
#include <ATen/ops/vsplit_native.h>
#include <ATen/ops/vstack.h>
#include <ATen/ops/vstack_native.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#include <ATen/ops/zeros_native.h>
#endif

#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

namespace at::meta {
inline void cat_check_no_zero_dim(const MaterializedITensorListRef& tensors) {
  size_t i = 0;
  for (const Tensor& t : tensors) {
    TORCH_CHECK(
        t.dim() > 0,
        "zero-dimensional tensor (at position ", i, ") cannot be concatenated");
    i++;
  }
}

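// For illustration: if any input suggests Contiguous, or two inputs suggest
// different formats (e.g. {ChannelsLast, Contiguous}), the helper below picks
// Contiguous; only when every input suggests the same non-contiguous format
// (e.g. all ChannelsLast) is that format propagated to the output of cat.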
inline c10::MemoryFormat cat_compute_output_memory_format(const MaterializedITensorListRef& inputs) {
  std::optional<c10::MemoryFormat> format = std::nullopt;
  for (const Tensor& t : inputs) {
    auto f = t.suggest_memory_format();
    if (f == c10::MemoryFormat::Contiguous) {
      return f;
    }
    if (format.has_value() && format.value() != f) {
      return c10::MemoryFormat::Contiguous;
    }
    format = f;
  }
  return format.value();
}

TORCH_PRECOMPUTE_META_FUNC(cat)(const ITensorListRef& tensors, int64_t dim) {
  // previously, size [0] tensors were the only possible empty tensors; thus, it wasn't possible
  // to cat empty tensors unless all the other tensors were 1-dimensional, so we allowed these tensors
  // to be "skipped". We maintain this behavior for backwards compatibility, but only for this specific
  // size (i.e. other empty sizes are not skipped).
  auto materialized = tensors.materialize();

  cat_check_no_zero_dim(materialized);
  dim = at::legacy_cat_wrap_dim(dim, materialized);

  // Checking names before the actual dimensions.
  auto maybe_outnames = namedinference::compute_cat_outnames(materialized);

  TORCH_CHECK(
      !materialized.empty(), "torch.cat(): expected a non-empty list of Tensors");

  // Look for the first valid tensor.
  size_t valid = materialized.size();
  for (const auto i : c10::irange(materialized.size())) {
    if (!at::native::cat_should_skip_tensor(materialized[i].get())) {
      valid = i;
      break;
    }
  }

  bool all_contiguous = true;
  bool all_same_dtype = true;
  bool all_same_sizes_and_stride = true;
  auto memory_format = cat_compute_output_memory_format(materialized);

  // Compute what the output dtype should be:
  const auto& result = maybe_get_output();
  auto is_out_defined = result.defined();
  auto out_dtype = at::native::result_type(tensors);

  // If the output tensor is defined, we need to take it into account
  // when computing the actual output dtype and the flags.
  if (is_out_defined) {
    // Check for type promotion, if the output tensor is defined.
    TORCH_CHECK(
        canCast(out_dtype, result.scalar_type()),
        "torch.cat(): input types can't be cast to the desired output type ",
        result.scalar_type());
    out_dtype = result.scalar_type();
    all_contiguous = result.is_contiguous(memory_format);
  }

  // Fallback 'set_output' parameters.
  // (in case we don't find a valid tensor)
  DimVector sizes {0};
  TensorOptions options = materialized[0].get().options()
      .dtype(out_dtype)
      .memory_format(memory_format);

  // If we found a valid tensor, check whether the input tensors
  // are compatible, i.e. we can execute `cat` on them.
  bool found_valid_tensor = valid < materialized.size();
  if (found_valid_tensor) {
    TORCH_CHECK(
        dim <= materialized[valid].get().dim(), "torch.cat(): dimension ", dim, " out of range");

    // Compute the output tensor size.
    // It should have the same shape as any other valid tensor,
    // except in the dimension 'dim'.
    size_t size_at_dim = 0;
    for (const auto i : c10::irange(materialized.size())) {
      const Tensor& t = materialized[i];
      all_same_dtype = all_same_dtype && out_dtype == t.scalar_type();
      if (!at::native::cat_should_skip_tensor(t)) {
        at::native::check_cat_shape_except_dim(materialized[valid], t, dim, i);
        size_at_dim += t.size(dim);
        all_contiguous = all_contiguous && t.is_contiguous(memory_format);
        all_same_sizes_and_stride = all_same_sizes_and_stride &&
            t.sizes() == materialized[valid].get().sizes() &&
            t.strides() == materialized[valid].get().strides();
      } else {
        all_contiguous = false;
      }
    }

    // Actually set the output.
    sizes = materialized[valid].get().sizes().vec();
    sizes[dim] = size_at_dim;
    options = materialized[valid].get().options()
        .dtype(out_dtype)
        .memory_format(memory_format);
  }

  set_output_raw_strided(0, sizes, {}, options, maybe_outnames);
  // Checks for overlaps between the inputs and the output tensor.
  if (is_out_defined && found_valid_tensor) {
    at::assert_no_internal_overlap(result);
    for (const Tensor& t : materialized) {
      at::assert_no_overlap(result, t);
    }
  }

  return TORCH_PRECOMPUTE_STRUCT(cat)()
      .set_dim(dim)
      .set_valid(valid)
      .set_all_contiguous(all_contiguous)
      .set_all_same_dtype(all_same_dtype)
      .set_all_same_sizes_and_stride(all_same_sizes_and_stride)
      .set_memory_format(memory_format);
}
} // namespace at::meta

namespace at::native {

DEFINE_DISPATCH(cat_serial_stub);
DEFINE_DISPATCH(stack_serial_stub);

Tensor _reshape_from_tensor(const Tensor& self, const Tensor& shape_tensor) {
  TORCH_CHECK(shape_tensor.dim() == 1);
  std::vector<int64_t> shape;
  auto accessor = shape_tensor.accessor<int64_t, 1>();
  for (const auto i : c10::irange(shape_tensor.numel())) {
    shape.push_back(accessor[i]);
  }
  return self.reshape(IntArrayRef(shape));
}

Tensor _shape_as_tensor(const Tensor& self) {
  auto options = TensorOptions(at::kLong);
  return at::tensor(self.sizes(), options);
}

Tensor& set_(Tensor& result, Storage source) {
  int64_t new_size =
      static_cast<int64_t>(source.nbytes() / result.dtype().itemsize());
  return result.set_(std::move(source), 0, new_size, {});
}


// unify with cuda implementation? This is not done to avoid a dispatch in resize_impl_cpu_
Tensor& set_storage_cpu_(Tensor& result, Storage storage, int64_t storage_offset, IntArrayRef size, IntArrayRef stride) {
  checkSetStorage(result, std::move(storage), storage_offset, size, stride);

  result.unsafeGetTensorImpl()->set_storage_offset(storage_offset);
  at::OptionalIntArrayRef stride_opt = stride.data() != nullptr ?
      at::OptionalIntArrayRef(stride) : std::nullopt;
  // We can re-use this kernel for the meta device.
  // We just need to make sure we don't actually try to resize the (null) storage.
  at::native::resize_impl_cpu_(result.unsafeGetTensorImpl(), size, stride_opt, /*resize_storage=*/!result.is_meta());
  return result;
}

Tensor& set_storage_meta__symint(Tensor& result, Storage storage, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) {
  checkSetStorage(result, storage, storage_offset, size, stride);

  c10::SymDimVector contiguous_strides;
  if (stride.data() == nullptr) {
    // TODO: dedupe this with empty() symbolic logic
    int64_t dim = size.size();
    contiguous_strides.resize(dim);
    if (dim > 0) {
      const auto last_idx = dim - 1;
      contiguous_strides.at(last_idx) = 1;
      for (auto i = last_idx - 1; i >= 0; --i) {
        // TODO: max with 1
        contiguous_strides.at(i) = contiguous_strides.at(i+1) * size.at(i+1);
      }
    }
    stride = contiguous_strides;
  }
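  // For illustration: with size = [2, 3, 4] the loop above fills in the
  // default row-major strides [12, 4, 1] (each stride is the product of the
  // sizes to its right).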

  // Run this before storage setting so we can access numel
  result.unsafeGetTensorImpl()->set_sizes_and_strides(size, stride, storage_offset);

  // Matches maybe_resize_storage_cpu no-numel behavior
  if (TORCH_GUARD_SIZE_OBLIVIOUS(result.sym_numel().sym_ne(0))) {
    // maybe_resize_storage_cpu can handle no storage existing at all, but
    // that should never be the case here
    TORCH_INTERNAL_ASSERT(storage);
    TORCH_CHECK(storage.resizable(), "Trying to resize storage that is not resizable");
    // All meta data pointers are the same, so we don't have to "re" allocate
    // it. TODO: Actually this might not quite be correct if we use special
    // pointers to track whether or not fake cuda tensors are pinned or not
    const auto itemsize = result.dtype().itemsize();
    c10::SymInt new_size_bytes = result.is_contiguous()
        ? at::detail::computeStorageNbytesContiguous(size, itemsize, std::move(storage_offset))
        : at::detail::computeStorageNbytes(size, stride, itemsize, std::move(storage_offset));
    // TODO: When there are unbacked SymInts, we unconditionally skip the
    // setter. This is technically wrong, but we cannot conveniently test
    // the real condition in many cases, because a lot of people are using
    // set_ just to swizzle metadata on a tensor; they didn't actually want
    // to see if they need to resize the storage.
    //
    // The old behavior was to unconditionally set_nbytes, but I think not
    // setting it is more safe.
    if (new_size_bytes.has_hint() && storage.sym_nbytes().has_hint() && TORCH_GUARD_SIZE_OBLIVIOUS(new_size_bytes.sym_gt(storage.sym_nbytes()))) {
      storage.set_nbytes(std::move(new_size_bytes));
    }
  }
  return result;
}

Tensor& set__symint(Tensor& result, const Tensor& storage, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) {
  TORCH_CHECK(storage.is_contiguous(), "passed in tensor to be used as storage must be contiguous");
  return result.set__symint(storage.storage(), storage_offset + storage.sym_storage_offset(), size, stride);
}

Tensor& set_tensor_(Tensor& result, const Tensor& source) {
  if (result.unsafeGetTensorImpl() != source.unsafeGetTensorImpl()) {
    return result.set__symint(source.storage(), source.sym_storage_offset(), source.sym_sizes(), source.sym_strides());
  }
  return result;
}

// this needs to be split along CPU/CUDA lines because we don't have a consistent
// way of getting the allocator to use for a device (c10::GetAllocator is not
// the same as at::cuda::getCUDADeviceAllocator()).
Tensor& set_cpu_(Tensor& result) {
  caffe2::TypeMeta dtype = result.dtype();
  Storage storage(
      Storage::use_byte_size_t(),
      0,
      c10::GetAllocator(kCPU),
      true);
  result.set_(std::move(storage), 0, {0}, {});
  TORCH_INTERNAL_ASSERT(dtype == result.dtype());
  return result;
}

// We can't re-use the cpu kernel here because we don't want to use the cpu allocator.
Tensor& set_meta_(Tensor& result) {
  caffe2::TypeMeta dtype = result.dtype();
  Storage storage(
      Storage::use_byte_size_t(),
      0,
      c10::GetAllocator(kMeta),
      true);
  result.set_(std::move(storage), 0, {0}, {});
  TORCH_INTERNAL_ASSERT(dtype == result.dtype());
  return result;
}

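// For illustration: broadcasting a sparse tensor of shape [1, 3] (sparse_dim == 2,
// dense_dim == 0) to size [2, 2, 3] adds one leading sparse dimension of size 2
// and expands the size-1 dimension to 2, so nnz_factor == 4 and every original
// nonzero is replicated four times with its indices adjusted accordingly.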
Tensor sparse_broadcast_to(const Tensor& self, IntArrayRef size) {
  TORCH_CHECK(self.is_sparse(), "input must be sparse tensor");
  int64_t sparse_extra_ndim = size.size() - self.dim();
  int64_t sparse_ndim = size.size() - self.dense_dim();
  TORCH_CHECK(sparse_extra_ndim >= 0, "input not broadcastable to size with smaller dimensionality");
  Tensor indices = self._indices();
  Tensor values = self._values();
  auto nnz = values.size(0);

  std::vector<int64_t> broadcast_sizes;
  std::vector<int64_t> broadcast_dense_sizes;
  std::vector<int64_t> broadcast_dims;
  std::vector<int64_t> unchanged_dims;
  broadcast_sizes.reserve(sparse_ndim);
  broadcast_dense_sizes.reserve(self.dense_dim() + 1);
  broadcast_dims.reserve(self.sparse_dim());
  unchanged_dims.reserve(self.sparse_dim());
  int64_t nnz_factor = 1;
  int64_t min_broadcast_dim = (sparse_extra_ndim > 0 ? 0 : -1);
  int64_t max_unchanged_dim = -1;
  for (int64_t i = 0; i < sparse_extra_ndim; i++) {
    auto d = size[i];
    nnz_factor *= d;
    broadcast_sizes.emplace_back(d);
  }
  for (int64_t i = 0; i < self.sparse_dim(); i++) {
    auto d = size[sparse_extra_ndim + i];
    if (self.size(i) != d) {
      TORCH_CHECK(self.size(i) == 1,
          "The expanded size of the tensor (", size[sparse_extra_ndim + i], ") ",
          "must match the existing size (", self.size(i), ")");
      nnz_factor *= d;
      broadcast_sizes.emplace_back(d);
      if (min_broadcast_dim == -1) {
        min_broadcast_dim = sparse_extra_ndim + i;
      }
      broadcast_dims.emplace_back(i);
    } else {
      unchanged_dims.emplace_back(i);
      max_unchanged_dim = sparse_extra_ndim + i;
    }
  }
  // Broadcasting conserves the is_coalesced property iff only the last
  // sparse dimensions are expanded. Possible expansion of dense
  // dimensions can be discarded as it does not affect the is_coalesced
  // property.
  bool is_coalesced = self.dim() == 0 || (self.is_coalesced() && (max_unchanged_dim < min_broadcast_dim || min_broadcast_dim == -1));

  broadcast_dense_sizes.emplace_back(nnz);
  for (int64_t i = 0; i < self.dense_dim(); i++) {
    broadcast_dense_sizes.emplace_back(size[sparse_extra_ndim + self.sparse_dim() + i]);
  }

  std::vector<int64_t> new_indices_size{sparse_ndim, nnz * nnz_factor};
  std::vector<int64_t> new_values_size(values.sizes().vec());
  new_values_size[0] = new_indices_size[1];

  Tensor new_values = values.expand(broadcast_dense_sizes).repeat_interleave(nnz_factor, 0);
  Tensor new_indices = indices.new_empty(new_indices_size);
  if (!broadcast_sizes.empty()) {
    Tensor broadcast_indices = at::sparse::full_coo_indices(broadcast_sizes, indices.options()).tile(nnz);
    new_indices.narrow(0, 0, sparse_extra_ndim).copy_(broadcast_indices.narrow(0, 0, sparse_extra_ndim));
    for (size_t i = 0; i < broadcast_dims.size(); i++) {
      int64_t j = broadcast_dims[i];
      new_indices.select(0, sparse_extra_ndim + j).copy_(broadcast_indices.select(0, sparse_extra_ndim + i));
    }
  }
  for (int64_t j : unchanged_dims) {
    new_indices.select(0, sparse_extra_ndim + j).copy_(indices.select(0, j).repeat_interleave(nnz_factor));
  }
  return at::sparse_coo_tensor(new_indices, new_values, size, self.options(), is_coalesced);
}

Tensor broadcast_to_symint(const Tensor& self, SymIntArrayRef size) {
  return self.expand_symint(size);
}

std::vector<Tensor> broadcast_tensors(TensorList tensors) {
  return expand_outplace(tensors);
}

static void fastCatOutDim0(const Tensor& out, const MaterializedITensorListRef& inputs) {
  auto outBytes = out.nbytes();
  char* dataPtr = reinterpret_cast<char*>(out.data_ptr());
  size_t totalBytes = 0;
  for (const Tensor& input : inputs) {
    TORCH_CHECK(outBytes >= totalBytes);
    if (input.nbytes() > 0) {
      std::memcpy(dataPtr + totalBytes, input.const_data_ptr(), input.nbytes());
    }
    totalBytes += input.nbytes();
  }
  TORCH_CHECK(outBytes == totalBytes);
}


TORCH_IMPL_FUNC(cat_out_cpu)
(const ITensorListRef& tensors,
 int64_t dim,
 int64_t valid,
 bool all_contiguous,
 bool all_same_dtype,
 bool all_same_sizes_and_stride,
 MemoryFormat memory_format,
 const Tensor& result) {
  if (result.numel() == 0) {
    return;
  }

  auto materialized = tensors.materialize();

  bool use_serial_kernel = result.numel() < at::internal::GRAIN_SIZE || at::get_num_threads() == 1;
  ScalarType dtype = materialized[valid].get().scalar_type();
  bool serial_dtype = at::isFloatingType(dtype);
  // fast path for single thread when both inputs and result are contiguous and
  // not empty, and the concat dim is 0
  if (use_serial_kernel && all_contiguous && all_same_dtype && (MemoryFormat::Contiguous == memory_format)) {
    if (dim == 0) {
      fastCatOutDim0(result, materialized);
      return;
    }
    // TODO: Add fast cat for higher dimensions and support multi-threaded fast cat
  }

  // fast path for single thread when both inputs and result are contiguous and not empty
  if (use_serial_kernel && all_contiguous && all_same_dtype && serial_dtype) {
    cat_serial_stub(kCPU, result, materialized, dim);
    return;
  }

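  // When all inputs share the same sizes, strides and dtype and the result is
  // contiguous, build a single TensorIterator over the first output slice and
  // reuse it for every input by swapping the data pointers; otherwise build a
  // fresh iterator per input so dtype promotion and casting are handled.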
  int64_t offset = 0;
  if (all_same_sizes_and_stride && result.is_contiguous(memory_format) &&
      all_same_dtype) {
    const Tensor& source_slice = materialized[valid];
    auto slice_dim_size = source_slice.sizes()[dim];
    auto result_slice = result.narrow(dim, 0, slice_dim_size);
    auto result_slice_data = result_slice.data_ptr();
    auto result_stride_bytes = result.stride(dim) * elementSize(result.scalar_type());

    auto iter = TensorIteratorConfig()
        .set_check_mem_overlap(false)
        .resize_outputs(false)
        .add_output(result_slice)
        .add_const_input(source_slice)
        .enforce_safe_casting_to_output(true)
        .build();

    for (const Tensor& tensor : materialized) {
      if (cat_should_skip_tensor(tensor)) {
        continue;
      }
      auto source_data = static_cast<const char*>(tensor.const_data_ptr());
      auto result_data = static_cast<char*>(result_slice_data) + offset * result_stride_bytes;
      iter.unsafe_replace_operand(0, result_data);
      iter.unsafe_replace_operand(1, const_cast<char*>(source_data));
      copy_stub(iter.device_type(), iter, false);
      offset += slice_dim_size;
    }
  } else {
    for (const Tensor& tensor : materialized) {
      if (cat_should_skip_tensor(tensor)) {
        continue;
      }
      auto slice_dim_size = tensor.sizes()[dim];
      auto result_slice = result.narrow(dim, offset, slice_dim_size);

      auto iter = TensorIteratorConfig()
          .set_check_mem_overlap(false)  // Already checked above
          .resize_outputs(false)
          .add_output(result_slice)
          .add_const_input(tensor)
          .promote_inputs_to_common_dtype(true)
          .cast_common_dtype_to_outputs(true)
          .enforce_safe_casting_to_output(true)
          .build();
      copy_stub(iter.device_type(), iter, false);
      offset += slice_dim_size;
    }
  }
}

Tensor& cat_out(TensorList tensors, Dimname dim, Tensor& result) {
  TORCH_CHECK(!tensors.empty(), "expected a non-empty list of Tensors");
  return at::cat_out(result, tensors, dimname_to_position(tensors[0], dim));
}

Tensor cat(TensorList tensors, Dimname dim) {
  TORCH_CHECK(!tensors.empty(), "expected a non-empty list of Tensors");
  return at::cat(tensors, dimname_to_position(tensors[0], dim));
}

// torch.concat, alias for torch.cat
Tensor& concat_out(TensorList tensors, Dimname dim, Tensor& result) {
  return at::cat_out(result, tensors, dimname_to_position(tensors[0], dim));
}

Tensor concat(TensorList tensors, Dimname dim) {
  return at::cat(tensors, dimname_to_position(tensors[0], dim));
}

Tensor& concat_out(TensorList tensors, int64_t dim, Tensor& result) {
  return at::cat_out(result, tensors, dim);
}

Tensor concat(TensorList tensors, int64_t dim) {
  return at::cat(tensors, dim);
}

// torch.concatenate, alias for torch.cat
Tensor& concatenate_out(TensorList tensors, Dimname dim, Tensor& result) {
  return at::cat_out(result, tensors, dimname_to_position(tensors[0], dim));
}

Tensor concatenate(TensorList tensors, Dimname dim) {
  return at::cat(tensors, dimname_to_position(tensors[0], dim));
}

Tensor& concatenate_out(TensorList tensors, int64_t dim, Tensor& result) {
  return at::cat_out(result, tensors, dim);
}

Tensor concatenate(TensorList tensors, int64_t dim) {
  return at::cat(tensors, dim);
}

static bool sizes_match_except(IntArrayRef s1, IntArrayRef s2, int64_t dim_except /* should already be wrapped */) {
  if (s1.size() != s2.size()) {
    return false;
  }
  for (const auto i : c10::irange(static_cast<int64_t>(s1.size()))) {
    if (i != dim_except && s1[i] != s2[i]) {
      return false;
    }
  }
  return true;
}

// Check to see if the shape of tensors is compatible
// for being concatenated along a given dimension.
static void check_cat_sparse_dims(Tensor const& t,
                                  int64_t pos /* used only for debug messages */,
                                  IntArrayRef sizes,
                                  int64_t wrapped,
                                  int64_t sparse_dim,
                                  int64_t dense_dim) {
  TORCH_CHECK(t.is_sparse(),
      "Concatenating sparse tensors, but a dense tensor was found at position ", pos, ".");
  TORCH_CHECK(sizes_match_except(sizes, t.sizes(), wrapped),
      "All tensors must have the same shape: ", sizes, " (except in the concatenating dimension),"
      " but found shape: ", t.sizes(), " at position ", pos, ".");
  TORCH_CHECK(t.sparse_dim() == sparse_dim && t.dense_dim() == dense_dim,
      "All tensors must have the same sparse_dim and dense_dim: ", sparse_dim, ", ", dense_dim,
      ", but tensor at position ", pos, " has ", t.sparse_dim(), ", ", t.dense_dim(), ".");
}

static Tensor cat_sparse_impl(const MaterializedITensorListRef& tensors, int64_t dim) {
  std::vector<Tensor> indices;
  std::vector<Tensor> values;
  int64_t wrapped = maybe_wrap_dim(dim, tensors[0].get().dim());
  int64_t sparse_dim = tensors[0].get().sparse_dim();
  int64_t dense_dim = tensors[0].get().dense_dim();
  IntArrayRef sizes = tensors[0].get().sizes();
  if (wrapped < sparse_dim) {
    for (const auto i : c10::irange(tensors.size())) {
      const Tensor& t = tensors[i];
      check_cat_sparse_dims(t, i, sizes, wrapped, sparse_dim, dense_dim);
      indices.push_back(t._indices());
      values.push_back(t._values());
    }
    Tensor idxs = at::cat(indices, 1);
    Tensor vals = at::cat(values, 0);

    // We now need to move the indices of each
    // input tensor up along `dim` by an appropriate amount.
    // E.g., if t1 has indices [[2,3,4],[5,6,7]],
    // and sizes [10, 7]
    // then torch.cat((t1,t1,t1),1) should have indices
    // [[2,3,4,2,3,4,2,3,4],[5,6,7,12,13,14,19,20,21]],
    // so we need to increase idxs[1][3:6] by 7
    // and idxs[1][6:9] by 14.
    int64_t col = 0;
    int64_t cumulative_offset = 0;
    for (const auto i : c10::irange(tensors.size())) {
      const Tensor& t = tensors[i];
      int64_t this_piece_size = t._nnz();
      // cumulative_offset is zero for the first piece, so
      // don't waste time doing this operation unless i > 0.
      if (i > 0) {
        idxs[wrapped].narrow(0, col, this_piece_size) += cumulative_offset;
      }
      cumulative_offset += t.size(wrapped);
      col += this_piece_size;
    }
    auto sizes_copy = sizes.vec();
    sizes_copy[wrapped] = cumulative_offset;
    return native::sparse_coo_tensor(
        idxs,
        vals,
        sizes_copy,
        optTypeMetaToScalarType(tensors[0].get().options().dtype_opt()),
        tensors[0].get().options().layout_opt(),
        tensors[0].get().options().device_opt(),
        tensors[0].get().options().pinned_memory_opt());
  }
  else {
    // Catting along a dense dimension requires us to create new values.
    // For illustration, consider the sparse 3d tensors t1 and t2,
    // given by t1 = [[[1,2],[3,4]], ... (zeros) ..., [[5,6],[7,8]]]
    // and t2 = [... (zeros) ..., [[9, 10], [11,12]], ... (zeros) ...].
    // Their concatenation along dimension 2 is:
    // [[[1,2,0,0],[3,4,0,0]], ... (zeros) ..., [[0,0,9,10],[0,0,11,12]], ... (zeros) ..., [[5,6,0,0],[7,8,0,0]]]
    //
    // Their values tensors are, respectively,
    // [[[1,2],[3,4]],[[5,6],[7,8]]] and [[[9,10],[11,12]]].
    //
    // and so the values tensor of their concatenation along dim 2 will be:
    // [[[1,2,0,0],[3,4,0,0]],[[5,6,0,0],[7,8,0,0]],[[0,0,9,10],[0,0,11,12]]]
    //
    // which we can get by taking the values tensor of each tensor, catting it with zeros of the appropriate size on the left and right,
    // and then catting all those results together.

    // The dimension in each tensor's values object that corresponds to the overall dimension along which we're catting.
    int64_t values_dim = wrapped - sparse_dim + 1;
    // The final size along the catted dimension.
    const int64_t total_size = std::accumulate(
        tensors.begin(),
        tensors.end(),
        static_cast<int64_t>(0),
        [values_dim](int64_t l, const Tensor& r) {
          return l + r._values().size(values_dim);
        });
    auto zeros_sizes = tensors[0].get()._values().sizes().vec();
    int64_t cumulative_size = 0;
    std::vector<Tensor> vals_pieces;
    std::vector<Tensor> idxs_pieces;
    for (const auto i : c10::irange(tensors.size())) {
      const Tensor& t = tensors[i];
      check_cat_sparse_dims(t, i, sizes, wrapped, sparse_dim, dense_dim);
      // dimension 0 of values corresponds to the number of values,
      // rather than to any logical dimension of the sparse tensor.
      zeros_sizes[0] = t._values().size(0);
      zeros_sizes[values_dim] = cumulative_size;
      cumulative_size += t._values().size(values_dim);
      auto z1 = at::zeros(
          zeros_sizes,
          optTypeMetaToScalarType(t._values().options().dtype_opt()),
          t._values().options().layout_opt(),
          t._values().options().device_opt(),
          t._values().options().pinned_memory_opt());
      zeros_sizes[values_dim] = total_size - cumulative_size;
      auto z2 = at::zeros(
          zeros_sizes,
          optTypeMetaToScalarType(t._values().options().dtype_opt()),
          t._values().options().layout_opt(),
          t._values().options().device_opt(),
          t._values().options().pinned_memory_opt());
      vals_pieces.push_back(at::cat({z1, t._values(), z2}, values_dim));
      idxs_pieces.push_back(t._indices());
    }
    auto sizes_copy = sizes.vec();
    sizes_copy[wrapped] = total_size;
    // This can create an uncoalesced tensor
    return native::sparse_coo_tensor(
        at::cat(idxs_pieces, 1),
        at::cat(vals_pieces),
        sizes_copy,
        optTypeMetaToScalarType(tensors[0].get().options().dtype_opt()),
        tensors[0].get().options().layout_opt(),
        tensors[0].get().options().device_opt(),
        tensors[0].get().options().pinned_memory_opt());
  }
}

Tensor cat_sparse(const ITensorListRef& tensors, int64_t dim) {
  auto materialized = tensors.materialize();
  auto maybe_outnames = namedinference::compute_cat_outnames(materialized);
  auto result = cat_sparse_impl(materialized, at::legacy_cat_wrap_dim(dim, materialized));
  namedinference::propagate_names_if_nonempty(result, maybe_outnames);
  return result;
}

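// For illustration: block_diag(A, B) with A of shape [2, 3] and B of shape [1, 1]
// returns a [3, 4] tensor with A in the top-left block, B in the bottom-right
// block, and zeros elsewhere; 0-D and 1-D inputs are first promoted to 2-D.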
Tensor block_diag(TensorList tensors) {
  Tensor result;
  if (tensors.empty()) {
    result = at::empty({1, 0});
    return result;
  }

  const Device& device = tensors[0].device();
  for (const auto tensor_idx : c10::irange(tensors.size())) {
    const Tensor& tensor = tensors[tensor_idx];

    TORCH_CHECK(
        tensor.device() == device,
        "torch.block_diag: input tensors must all be on the same device.",
        " Input 0 is on device ", device,
        " and input ", tensor_idx, " is on device ", tensor.device()
    );
  }

  ScalarType output_scalar_type = native::result_type(tensors);
  int64_t result_dim0 = 0;
  int64_t result_dim1 = 0;
  std::vector<Tensor> tensors_2D(tensors.size());

  // Sum the dimensions of the tensors, check tensor sizes,
  // and expand all 0-D and 1-D tensors so that everything
  // is 2-D
  for (const auto tensor_idx : c10::irange(tensors.size())) {
    const Tensor& tensor = tensors[tensor_idx];
    int64_t ndims = tensor.dim();
    TORCH_CHECK(
        ndims <= 2,
        "torch.block_diag: Input tensors must have 2 or fewer dimensions. Input ",
        tensor_idx, " has ", ndims, " dimensions"
    );

    int64_t dim0 = 1;
    int64_t dim1 = 1;

    if (ndims == 2) {
      dim0 = tensor.size(0);
      dim1 = tensor.size(1);
      tensors_2D[tensor_idx] = tensor;
    } else if (ndims == 1) {
      // Switching dim 0 to dim 1 is intentional
      dim1 = tensor.size(0);
      tensors_2D[tensor_idx] = tensor.expand({dim0, dim1});
    } else {
      tensors_2D[tensor_idx] = tensor.expand({dim0, dim1});
    }
    result_dim0 += dim0;
    result_dim1 += dim1;
  }

  result = at::zeros(
      {result_dim0, result_dim1},
      tensors[0].options().dtype(output_scalar_type)
  );

  int64_t cur_dim0 = 0;
  int64_t cur_dim1 = 0;

  // Copy each tensor into the appropriate location in the result matrix
  for (const auto& tensor : tensors_2D) {
    int64_t dim0 = tensor.size(0);
    int64_t dim1 = tensor.size(1);
    result.slice(0, cur_dim0, cur_dim0 + dim0).slice(1, cur_dim1, cur_dim1 + dim1).copy_(tensor);

    cur_dim0 += dim0;
    cur_dim1 += dim1;
  }

  return result;
}

std::vector<Tensor> chunk(const Tensor& self, int64_t chunks, int64_t dim) {
  TORCH_CHECK(self.dim() > 0,
              "chunk expects at least a 1-dimensional tensor");
  TORCH_CHECK(chunks > 0,
              "chunk expects `chunks` to be greater than 0, got: ", chunks);

  const auto dim_size = self.sym_size(dim);
  auto split_size = (dim_size + chunks - 1) / chunks;
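  // For illustration: dim_size = 10 and chunks = 3 gives split_size = 4,
  // producing chunks of size 4, 4 and 2.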

  // We need to call split_with_sizes in the case where split_size and dimension size are 0, because
  // a call to split would discard the number of chunks (because we can have an arbitrary number of
  // 0-sized chunks adding up to 0). So, call split_with_sizes with the correct number of chunks;
  // eventually we will do this for all cases.
  if (split_size == 0 && dim_size == 0) {
    std::vector<c10::SymInt> split_sizes(chunks, split_size);
    split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size);
    return self.split_with_sizes_symint(split_sizes, dim);
  } else {
    return self.split_symint(std::move(split_size), dim);
  }
}

std::vector<Tensor> tensor_split_sections_symint(const Tensor& self, c10::SymInt sym_sections, int64_t dim) {
  TORCH_CHECK(self.dim() > 0, "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ", self.dim(), " dims");
  int64_t dim_ = maybe_wrap_dim(dim, self.dim());
  // NB: intentional, sections specifies the number of output tensors, which
  // cannot be polymorphic
  int64_t sections = sym_sections.guard_int(__FILE__, __LINE__);
  TORCH_CHECK(sections > 0, "number of sections must be larger than 0, got ", sections);
  const auto dim_size = self.sym_size(dim_);
  std::vector<Tensor> splits(sections);
  auto min_split_size = dim_size / sections;
  auto num_splits_one_extra = dim_size % sections;
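  // For illustration: dim_size = 7 and sections = 3 gives min_split_size = 2
  // and num_splits_one_extra = 1, so the split sizes are 3, 2 and 2.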
  c10::SymInt start_idx = 0;
  for (const auto split_idx : c10::irange(sections)) {
    auto split_size = (num_splits_one_extra > split_idx) ? (min_split_size + 1) : min_split_size;
    splits[split_idx] = at::slice_symint(self, dim_, start_idx, start_idx + split_size);
    start_idx += split_size;
  }
  return splits;
}

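// For illustration: indices = [2, 5] on a dimension of size 8 produces three
// slices covering [0, 2), [2, 5) and [5, 8).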
template <typename T>
std::vector<Tensor> _tensor_split_indices(const Tensor& self, ArrayRef<T> indices, int64_t dim) {
  TORCH_CHECK(self.dim() > 0, "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ", self.dim(), " dims");
  int64_t dim_ = maybe_wrap_dim(dim, self.dim());
  int64_t num_indices = indices.size();
  std::vector<Tensor> splits(num_indices + 1);
  T start_idx(0);
  for (const auto split_idx : c10::irange(num_indices)) {
    auto end_idx = indices[split_idx];
    splits[split_idx] = at::symint::slice<T>(self, dim_, start_idx, end_idx);
    start_idx = end_idx;
  }
  splits[num_indices] = at::symint::slice<T>(self, dim_, start_idx, at::symint::size<T>(self, dim_));
  return splits;
}

std::vector<Tensor> tensor_split(const Tensor& self, IntArrayRef indices, int64_t dim) {
  return _tensor_split_indices(self, indices, dim);
}

std::vector<Tensor> tensor_split_indices_symint(const Tensor& self, SymIntArrayRef indices, int64_t dim) {
  return _tensor_split_indices(self, indices, dim);
}

std::vector<Tensor> tensor_split(const Tensor& self, const Tensor& tensor_indices_or_sections, int64_t dim) {
  TORCH_CHECK(self.dim() > 0, "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ", self.dim(), " dims");
  auto split_device = tensor_indices_or_sections.device();
  TORCH_CHECK(split_device == kCPU,
      "tensor_split expected tensor_indices_or_sections to be on cpu, but it's on ", split_device);
  auto split_dtype = tensor_indices_or_sections.scalar_type();
  TORCH_CHECK(split_dtype == at::kLong,
      "tensor_split expected tensor_indices_or_sections to have dtype of long, but got ", split_dtype);
  auto split_dim = tensor_indices_or_sections.dim();
  TORCH_CHECK(split_dim == 1 || split_dim == 0,
      "tensor_split expected tensor_indices_or_sections to be a zero-dimensional or one-dimensional tensor, but got a tensor with ", split_dim, " dims");

  if (split_dim == 0) {
    int64_t sections = tensor_indices_or_sections.item<int64_t>();
    return self.tensor_split(sections, dim);
  } else {
    auto indices_data = tensor_indices_or_sections.const_data_ptr<int64_t>();
    auto stride = tensor_indices_or_sections.stride(0);
    auto numel = tensor_indices_or_sections.numel();
    std::vector<int64_t> indices(numel);
    for (const auto offset : c10::irange(numel)) {
      // indices tensor could be non-contiguous
      indices[offset] = *(indices_data + offset * stride);
    }
    return self.tensor_split(indices, dim);
  }
}

std::vector<Tensor> unsafe_chunk(const Tensor& self, int64_t chunks, int64_t dim) {
  TORCH_CHECK(self.dim() > 0,
              "chunk expects at least a 1-dimensional tensor");
  TORCH_CHECK(chunks > 0,
              "chunk expects `chunks` to be greater than 0, got: ", chunks);

  const auto dim_size = self.size(dim);
  int64_t split_size = (dim_size + chunks - 1) / chunks;

  // See the comment above in chunk(...)
  if (split_size == 0 && dim_size == 0) {
    std::vector<int64_t> split_sizes(chunks, split_size);
    split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size);
    return self.unsafe_split_with_sizes(split_sizes, dim);
  } else {
    return self.unsafe_split(split_size, dim);
  }
}

Tensor diagflat(const Tensor& self, int64_t offset) {
  return self.contiguous().view(-1).diag(offset);
}

Tensor diagonal(const Tensor& self, int64_t offset, int64_t dim1_, int64_t dim2_) {
  int64_t nDims = self.dim();
  int64_t dim1 = maybe_wrap_dim(dim1_, nDims);
  int64_t dim2 = maybe_wrap_dim(dim2_, nDims);
  TORCH_CHECK(dim1 != dim2, "diagonal dimensions cannot be identical ", dim1_, ", ", dim2_);
  auto outnames = namedinference::compute_diagonal_outnames(self, dim1, dim2);
  NoNamesGuard no_names_guard;

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t diag_size;
  int64_t storage_offset = self.storage_offset();
  // compute storage offset and size for the diagonal
  // for positive values of offset (above the main diagonal)
  // "leftmost columns" (along dim2) are dropped
  // for negative values of offset (below the main diagonal)
  // "topmost rows" (along dim1) are dropped.
  // Note that we invert +/- in the second to absorb the negative
  // sign in the offset.
  if (offset >= 0) {
    diag_size = std::max<int64_t>(std::min(self.size(dim1), self.size(dim2) - offset), 0);
  } else {
    diag_size = std::max<int64_t>(std::min(self.size(dim1) + offset, self.size(dim2)), 0);
  }
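  // For illustration: a 3x4 matrix with offset = 2 has
  // diag_size = min(3, 4 - 2) = 2, and the view starts 2 * stride(dim2)
  // elements into the storage.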

  // NumPy allows you to specify offsets "off the end"; let's just be careful not to
  // set a ridiculous storage_offset in that case (technically it shouldn't matter
  // because there are no elements in the tensor, but let's be kosher).
  if (diag_size == 0) {
    // skip
  } else if (offset >= 0) {
    storage_offset += offset * self.stride(dim2);
  } else {
    storage_offset -= offset * self.stride(dim1);
  }

  // construct new size and stride: we drop dim1 and dim2 (erasing the larger index
  // first so the smaller index stays valid); the new ("joint") dimension is appended
  // to the end of the shape / stride to match numpy semantics
  DimVector sizes(self.sizes().begin(), self.sizes().end());
  DimVector strides(self.strides().begin(), self.strides().end());
  sizes.erase(sizes.begin() + std::max(dim1, dim2));
  strides.erase(strides.begin() + std::max(dim1, dim2));
  sizes.erase(sizes.begin() + std::min(dim1, dim2));
  strides.erase(strides.begin() + std::min(dim1, dim2));
  sizes.push_back(diag_size);
  strides.push_back(self.stride(dim1) + self.stride(dim2));

  // return view with new parameters
  auto result = self.as_strided(sizes, strides, storage_offset);

  no_names_guard.reset();
  namedinference::propagate_names_if_nonempty(result, outnames);
  return result;
}

Tensor diagonal(const Tensor& self, Dimname outdim, Dimname dim1, Dimname dim2, int64_t offset) {
  auto result = at::diagonal(
      self,
      offset,
      dimname_to_position(self, dim1),
      dimname_to_position(self, dim2));
  // This is slower than it needs to be because there is no way to modify
  // the names of a tensor in-place right now. In the future we should consider
  // offering that functionality.
  std::vector<Dimname> new_names = result.names().vec();
  new_names[new_names.size() - 1] = outdim;
  return result.refine_names(new_names);
}

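// For illustration: diag_embed on an input of shape [2, 3] with offset = 1,
// dim1 = -2 and dim2 = -1 produces a [2, 4, 4] result whose last two
// dimensions hold the input values on the first super-diagonal.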
Tensor diag_embed(const Tensor& self, int64_t offset, int64_t dim1_, int64_t dim2_) {
  int64_t nDims = self.dim() + 1;
  int64_t dim1 = maybe_wrap_dim(dim1_, nDims);
  int64_t dim2 = maybe_wrap_dim(dim2_, nDims);
  TORCH_CHECK(dim1 != dim2, "diagonal dimensions cannot be identical ", dim1_, ", ", dim2_);
  int64_t new_dim_len = std::abs(offset) + self.size(-1);
  auto sizes = self.sizes().vec();
  sizes.pop_back();
  sizes.insert(sizes.begin() + std::min(dim1, dim2), new_dim_len);
  sizes.insert(sizes.begin() + std::max(dim1, dim2), new_dim_len);
  auto result = at::zeros(sizes, self.options());
  auto diag = result.diagonal(offset, dim1, dim2);
  diag.copy_(self);
  return result;
}

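// For illustration: expanding a tensor of shape [3, 1] to size [2, 3, 4]
// returns a view with sizes [2, 3, 4] whose strides are 0 in the broadcast
// dimensions, so no data is copied.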
Tensor expand(const Tensor& self, c10::IntArrayRef size, bool /*unused*/) {
  TORCH_CHECK(size.size() >= (size_t)self.dim(),
      "expand(", self.toString(), "{", self.sizes(), "}, size=", size,
      "): the number of sizes provided (", size.size(), ") ",
      "must be greater or equal to the number of dimensions in the tensor (",
      self.dim(), ")");
  TORCH_CHECK(!self.is_sparse() && !at::sparse_csr::is_sparse_compressed(self),
      "expand is unsupported for ", self.layout(), " tensors");

  auto expandedSizesAndStrides = inferExpandGeometry_dimvector(self.sizes(), self.strides(), size);

  auto result = self.as_strided(
      expandedSizesAndStrides.sizes, expandedSizesAndStrides.strides);
  namedinference::propagate_names_for_expand(result, self);
  return result;
}

Tensor expand_as(const Tensor& self, const Tensor& other) {
  return self.expand_symint(other.sym_sizes());
}

Tensor sum_to_size_symint(const Tensor& self, SymIntArrayRef size) {
  TORCH_CHECK(is_expandable_to(size, self.sym_sizes()),
      "size {", size, "} is not expandable to size {", self.sizes(), "}.");

  return sum_to(self, size);
}

// We currently do not support per-channel quant for unfold, diagonal, expand, permute.
// TODO: Make this an aten function and replace as_strided_qtensorimpl once that is done.
static Tensor make_qtensor(const Tensor& self, IntArrayRef size, IntArrayRef stride, QuantizerPtr quantizer) {
  auto result = at::detail::make_tensor<QTensorImpl>(
      c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype(), quantizer);
  setStrided(result, size, stride, self.storage_offset());
  return result;
}

Tensor as_strided_tensorimpl(const Tensor& self, IntArrayRef size, IntArrayRef stride, std::optional<int64_t> storage_offset_) {
  TORCH_INTERNAL_ASSERT(!self.is_mps(), "as_strided_tensorimpl does not work with MPS; call self.as_strided(...) instead");
  auto storage_offset = storage_offset_.value_or(self.storage_offset());
  auto result = at::detail::make_tensor<TensorImpl>(
      c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype());
  setStrided(result, size, stride, storage_offset);
  return result;
}

template <typename T>
inline void setStridedUnchecked(
    const Tensor& self,
    ArrayRef<T> size,
    ArrayRef<T> stride,
    T&& storage_offset) {
  auto* self_ = self.unsafeGetTensorImpl();
  self_->set_sizes_and_strides(size, stride, std::make_optional(std::forward<T>(storage_offset)));
}

Tensor as_strided_tensorimpl_meta_symint(const Tensor& self, SymIntArrayRef sym_size, SymIntArrayRef sym_stride, std::optional<c10::SymInt> sym_storage_offset_) {
  auto sym_storage_offset = sym_storage_offset_.value_or(self.sym_storage_offset());
  auto result = at::detail::make_tensor<TensorImpl>(
      c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype());
  // NB: The reason this is unchecked is to ensure we don't generate
  // guards on the base storage itself when performing as_strided calls.
  // Although technically these guards are necessary, in practice they
  // cause a lot of guards that falsely refer to base symbols. We will instead
  // rely on AOTAutograd to sort out if we actually have dependence on view
  // bases / storage size.
  setStridedUnchecked(result, sym_size, sym_stride, std::move(sym_storage_offset));
  return result;
}

Tensor as_strided_qtensorimpl(const Tensor& self, IntArrayRef size, IntArrayRef stride, std::optional<int64_t> storage_offset_) {
  auto storage_offset = storage_offset_.value_or(self.storage_offset());
  auto quantizer = get_qtensorimpl(self)->quantizer();
  TORCH_CHECK(
      quantizer->qscheme() == QScheme::PER_TENSOR_AFFINE,
      "Setting strides is possible only on uniformly quantized tensor");
  auto result = at::detail::make_tensor<QTensorImpl>(
      c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype(), quantizer);
  setStrided(result, size, stride, storage_offset);
  return result;
}

// This is an overload of
// Tensor as_strided_qtensorimpl(const Tensor& self, IntArrayRef size, IntArrayRef stride, std::optional<int64_t> storage_offset_)
// that is currently not available through the dispatcher. It takes an
// additional quantizer argument and is used by the select & slice methods.
// TODO: Make this function compatible with the dispatcher
static Tensor as_strided_qtensorimpl(const Tensor& self, IntArrayRef size, IntArrayRef stride, std::optional<int64_t> storage_offset_,
                                     QuantizerPtr quantizer) {
  auto storage_offset = storage_offset_.value_or(self.storage_offset());
  TORCH_CHECK(
      (quantizer->qscheme() == QScheme::PER_TENSOR_AFFINE) ||
      (quantizer->qscheme() == QScheme::PER_CHANNEL_AFFINE),
      "Setting strides is possible only on uniformly or per channel quantized tensors");
  auto result = at::detail::make_tensor<QTensorImpl>(
      c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype(), quantizer);
  setStrided(result, size, stride, storage_offset);
  return result;
}
1237
as_strided__symint(const Tensor & self,SymIntArrayRef size,SymIntArrayRef stride,std::optional<c10::SymInt> storage_offset_)1238 const Tensor &as_strided__symint(const Tensor& self, SymIntArrayRef size, SymIntArrayRef stride, std::optional<c10::SymInt> storage_offset_) {
1239 auto storage_offset = storage_offset_.value_or(self.sym_storage_offset());
1240 setStrided(self, size, stride, std::move(storage_offset));
1241 return self;
1242 }
1243
1244 // Should just use narrow_copy_out, but this API is used internally at Meta:
1245 // https://github.com/pytorch/pytorch/pull/87045#issuecomment-1309353561
narrow_copy_dense_cpu(const Tensor & self,int64_t dim,int64_t start,int64_t length)1246 Tensor narrow_copy_dense_cpu(const Tensor& self, int64_t dim, int64_t start, int64_t length){
1247 // narrow_copy_dense_cpu_out always resize output's size, so there only create
1248 // a zero size tensor.
1249 auto output = at::empty({0}, self.options());
1250 return narrow_copy_dense_cpu_out(self, dim, start, length, output);
1251 }
1252
narrow_copy_sparse(const Tensor & self,int64_t dim,int64_t start,int64_t length)1253 Tensor narrow_copy_sparse(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
1254 int64_t allDim = self.dim();
1255 int64_t end = start+length;
1256 TORCH_CHECK(allDim > 0, "narrow() cannot be applied to a 0-dim tensor.");
1257 TORCH_CHECK(length >= 0, "narrow(): length must be non-negative.");
1258 TORCH_CHECK(dim >= 0 && dim < allDim,
1259 "Dimension ", dim, " out of range. Expecting 0 <= dim < ", allDim, ".");
1260 TORCH_CHECK(start >= 0 && end <= self.size(dim),
1261 "Invalid range to narrow. range(start, start+length) must be a subset of range(0, ", self.size(dim), ").")
1262 Tensor indices = self._indices();
1263 int64_t sparse_dim = self.sparse_dim();
1264
1265 std::vector<int64_t> new_sizes = self.sizes().vec();
1266 new_sizes[dim] = length;
1267
1268 Tensor new_values;
1269 Tensor new_indices;
1270 if (dim < sparse_dim) {
1271 Tensor mask = (indices[dim] >= start).__and__((indices[dim] < end));
1272 new_indices = indices.masked_select(mask).view({sparse_dim, -1});
1273 new_indices[dim].sub_(start);
1274 Tensor nzIndices = mask.nonzero().view(-1);
1275 new_values = self._values().index_select(0, nzIndices);
1276 } else {
1277 /* This means we are narrowing on a dense dim, which is in effect just a
1278 regular narrow on _values() */
1279 new_indices = indices;
1280 int64_t dense_dim = dim - sparse_dim + 1;
1281 new_values = self._values().narrow_copy(dense_dim, start, length);
1282 }
1283
1284 return at::sparse_coo_tensor(new_indices, new_values, new_sizes, self.options(), self.is_coalesced());
1285 }
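// Illustrative sketch (values are hypothetical, not from this file): narrowing a
// sparse COO tensor along a sparse dim keeps only the nonzeros whose index in
// that dim lies in [start, start + length) and shifts those indices down by
// `start`. E.g. with indices[dim] = [0, 2, 3, 5], start = 2, length = 3, only
// the entries with indices 2 and 3 survive and are stored as 0 and 1.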
1286
1287 // Should just use narrow_copy_out, but this API is used internally at Meta:
1288 // https://github.com/pytorch/pytorch/pull/87045#issuecomment-1309353561
1289 Tensor& narrow_copy_dense_cpu_out(
1290 const Tensor& self, int64_t dim, int64_t start, int64_t length, Tensor& output
1291 ) {
1292
1293 TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
1294 TORCH_CHECK(self.dtype() == output.dtype());
1295
1296 auto self_contig = self.expect_contiguous();
1297 const auto self_sizes = self_contig->sizes();
1298
1299 // wrap dim if negative and do bound check
1300 if (dim < 0) {
1301 dim = at::maybe_wrap_dim(dim, self_sizes.size());
1302 } else {
1303 TORCH_CHECK(dim < static_cast<int64_t>(self_sizes.size()));
1304 }
1305
1306 // wrap start and do bound check
1307 const auto cur_size = self_sizes[dim];
1308 TORCH_CHECK_INDEX(
1309 -cur_size <= start && start <= cur_size,
1310 "start out of range (expected to be in range of [", -cur_size, ", ", cur_size, "], but got ", start, ")"
1311 )
1312 if (start < 0) {
1313 start = start + cur_size;
1314 }
1315 TORCH_CHECK(
1316 length >= 0 && start <= cur_size - length,
1317 "start (",
1318 start,
1319 ") + length (",
1320 length,
1321 ") exceeds dimension size (",
1322 cur_size,
1323 ").");
1324
1325 // resize output
1326 auto output_sizes = self_sizes.vec();
1327 output_sizes[dim] = length;
1328 at::native::resize_(output, output_sizes);
1329
1330 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1331 const int64_t unit = c10::size_from_dim_(dim + 1, self_sizes);
1332 const int64_t num_blocks = c10::size_to_dim_(dim, self_sizes);
1333
1334 const auto itemsize = self_contig->dtype().itemsize();
1335 // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
1336 size_t src_nbytes = itemsize * self_contig->numel();
1337 // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
1338 size_t dst_nbytes = itemsize * output.numel();
1339
1340 size_t src_block_size = unit * self_sizes[dim];
1341 size_t dst_block_size = unit * length;
1342
1343 if (num_blocks == 0 || dst_block_size == 0) {
1344 return output;
1345 }
1346
1347 const char* src_bytes = static_cast<const char*>(self_contig->const_data_ptr());
1348 char* dst_bytes = static_cast<char*>(output.data_ptr());
1349
1350 size_t src_block_size_bytes = itemsize * src_block_size;
1351 size_t dst_block_size_bytes = itemsize * dst_block_size;
1352 size_t src_offset = unit * start;
1353
1354 const char* src_offset_bytes = src_bytes + itemsize * src_offset;
1355 char* dst_offset_bytes = dst_bytes;
1356
1357 for (const auto i : c10::irange(num_blocks)) {
1358 const char* local_src_offset_bytes = src_offset_bytes + i * src_block_size_bytes;
1359 char* local_dst_offset_bytes = dst_offset_bytes + i * dst_block_size_bytes;
1360 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
1361 static_cast<const void*>(local_src_offset_bytes + dst_block_size_bytes) <=
1362 static_cast<const void*>(src_bytes + src_nbytes));
1363 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
1364 static_cast<void*>(local_dst_offset_bytes + dst_block_size_bytes) <=
1365 static_cast<void*>(dst_bytes + dst_nbytes));
1366
1367 memcpy(
1368 local_dst_offset_bytes, local_src_offset_bytes, dst_block_size_bytes);
1369 }
1370 return output;
1371 }
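// Worked example of the block copy above (hypothetical shapes, for illustration
// only): for a contiguous `self` of shape (2, 5, 3) with dim = 1, start = 1,
// length = 2:
//   unit = 3 (elements per slice beyond dim), num_blocks = 2 (product of dims before dim),
//   src_block_size = 5 * 3 = 15, dst_block_size = 2 * 3 = 6,
// so each of the 2 blocks copies 6 contiguous elements, starting unit * start = 3
// elements into the corresponding source block.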
1372
1373 Tensor narrow(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
1374 TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
1375 TORCH_CHECK(length >= 0, "narrow(): length must be non-negative.");
1376 auto cur_size = self.size(dim);
1377 TORCH_CHECK_INDEX(
1378 -cur_size <= start && start <= cur_size,
1379 "start out of range (expected to be in range of [", -cur_size, ", ", cur_size, "], but got ", start, ")"
1380 )
1381 if (start < 0) {
1382 start = start + cur_size;
1383 }
1384 TORCH_CHECK(start <= cur_size - length,
1385 "start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ").");
1386 return at::slice(self, dim, start, start + length, 1);
1387 }
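// Usage sketch (hypothetical tensor, for illustration only): for `t` of shape
// (4, 6), t.narrow(1, -4, 3) wraps start to 2 and returns a (4, 3) view of
// columns 2..4, equivalent to t.slice(1, 2, 5, 1); no data is copied.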
1388
1389 Tensor narrow_symint(const Tensor& self, int64_t dim, SymInt start, SymInt length) {
1390 TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
1391 TORCH_SYM_CHECK(length.sym_ge(0), "narrow(): length must be non-negative.");
1392 auto cur_size = self.sym_size(dim);
1393 TORCH_CHECK_INDEX(
1394 ((-cur_size).sym_le(start).sym_and(start.sym_le(cur_size))).expect_true(__FILE__, __LINE__),
1395 "start out of range (expected to be in range of [", -cur_size, ", ", cur_size, "], but got ", start, ")"
1396 )
1397 if (start < 0) {
1398 start = start + cur_size;
1399 }
1400 TORCH_SYM_CHECK(start.sym_le(cur_size - length),
1401 "start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ").");
1402 return at::slice_symint(self, dim, start, start + length, 1);
1403 }
1404
1405 // This overload exists purely for XLA, because they wanted to pass in "symbolic"
1406 // start via Tensor.
1407 Tensor narrow_tensor_symint(const Tensor& self, int64_t dim, const Tensor& start, SymInt length) {
1408 TORCH_CHECK(start.dim() == 0 && isIntegralType(start.scalar_type(), /*includeBool=*/false),
1409 "start must be an 0-dim integral Tensor.");
1410 int64_t st = start.item<int64_t>();
1411 return at::narrow_symint(self, dim, c10::SymInt(st), std::move(length));
1412 }
1413
1414 std::tuple<DimVector, DimVector, std::vector<int64_t>>
1415 static _permute_size_stride_estimation(const Tensor& self, IntArrayRef dims) {
1416 const auto ndim = self.dim();
1417 TORCH_CHECK(ndim == static_cast<int64_t>(dims.size()),
1418 "permute(sparse_coo): number of dimensions in the tensor input ",
1419 "does not match the length of the desired ordering of dimensions ",
1420 "i.e. input.dim() = ", ndim, " is not equal to len(dims) = ", dims.size());
1421
1422 const auto is_strided_layout = self.options().layout() == at::kStrided;
1423 const auto old_sizes = self.sizes();
1424 const auto old_strides = is_strided_layout ? self.strides() : IntArrayRef{};
1425
1426 auto new_sizes = DimVector(ndim);
1427 auto new_strides = DimVector(is_strided_layout ? ndim : 0);
1428 auto wrapped_dims = std::vector<int64_t>(ndim);
1429 std::vector<bool> seen_dims(ndim);
1430
1431 for (const auto i : c10::irange(ndim)) {
1432 const auto d = maybe_wrap_dim(dims[i], ndim);
1433 TORCH_CHECK(!seen_dims[d],
1434 "permute(): duplicate dims are not allowed.");
1435 seen_dims[d] = true;
1436 wrapped_dims[i] = d;
1437 new_sizes[i] = old_sizes[d];
1438 if (is_strided_layout) {
1439 new_strides[i] = old_strides[d];
1440 }
1441 }
1442
1443 return std::make_tuple(new_sizes, new_strides, wrapped_dims);
1444 }
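// Illustrative example (hypothetical values): for a strided tensor with sizes
// (2, 3, 4) and strides (12, 4, 1), dims = (2, 0, 1) yields new_sizes (4, 2, 3),
// new_strides (1, 12, 4) and wrapped_dims (2, 0, 1); a negative entry such as
// dims = (-1, 0, 1) wraps to the same result.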
1445
1446 Tensor permute(const Tensor& self, IntArrayRef dims) {
1447 auto [new_sizes, new_strides, _] = _permute_size_stride_estimation(self, dims);
1448 return self.as_strided(new_sizes, new_strides);
1449 }
1450
1451 Tensor permute_sparse_coo(const Tensor& self, IntArrayRef dims) {
1452 auto [new_sizes, _, wrapped_dims] = _permute_size_stride_estimation(self, dims);
1453
1454 const auto ndim = self.dim();
1455 const auto sparse_ndim = self.sparse_dim();
1456 const auto dense_ndim = self.dense_dim();
1457
1458 auto dims_id_perm = std::vector<int64_t>(ndim);
1459 auto dims_sparse_dense_id_perm = std::vector<int64_t>(ndim);
1460 for (const auto i : c10::irange(ndim)) {
1461 dims_id_perm[i] = i;
1462 dims_sparse_dense_id_perm[i] = wrapped_dims[i];
1463 }
1464 std::sort(dims_sparse_dense_id_perm.begin(), dims_sparse_dense_id_perm.begin() + sparse_ndim);
1465 std::sort(dims_sparse_dense_id_perm.begin() + sparse_ndim, dims_sparse_dense_id_perm.end());
1466 TORCH_CHECK(dims_sparse_dense_id_perm == dims_id_perm,
1467 "permute(sparse_coo): transpositions between sparse and dense dimensions are not allowed.",
1468 "Only transpositions within sparse and dense dimensions are supported.");
1469
1470 const auto slice = [](std::vector<int64_t> v, size_t begin, size_t len) -> decltype(v) {
1471 return std::vector<int64_t>{v.begin() + begin, v.begin() + begin + len};
1472 };
1473
1474 auto old_sparse_dims = slice(dims_id_perm, 0, sparse_ndim);
1475 auto old_dense_dims = slice(std::move(dims_id_perm), sparse_ndim, ndim - sparse_ndim);
1476 auto new_sparse_dims = slice(wrapped_dims, 0, sparse_ndim);
1477 auto new_dense_dims = slice(std::move(wrapped_dims), sparse_ndim, ndim - sparse_ndim);
1478
1479 auto old_indices = self._indices();
1480 auto old_values = self._values();
1481
1482 const auto new_indices = (new_sparse_dims == old_sparse_dims)
1483 ? std::move(old_indices)
1484 : [&]() -> Tensor {
1485 auto sparse_perm_tensor = at::from_blob(reinterpret_cast<void*>(new_sparse_dims.data()),
1486 {sparse_ndim}, old_indices.options().device(at::kCPU));
1487 // creates new indices. It is possible to avoid that if COO
1488 // is allowed to store a permutation vector.
1489 return old_indices.index_select(0, sparse_perm_tensor.to(self.device().type()));
1490 }();
1491 const auto new_values = (new_dense_dims == old_dense_dims)
1492 ? std::move(old_values)
1493 : [&]() -> Tensor {
1494 auto values_perm = std::vector<int64_t>(dense_ndim + 1);
1495 for (const auto i : c10::irange(dense_ndim)) {
1496 values_perm[i + 1] = new_dense_dims[i] - sparse_ndim + 1;
1497 }
1498 return old_values.permute(values_perm);
1499 }();
1500 const auto is_coalesced = self.is_coalesced() && (dims.empty() || dims[0] == 0);
1501 // TODO: apply `is_coalesced ||= new_values.size(0) < 2`.
1502 return _sparse_coo_tensor_with_dims_and_tensors(
1503 sparse_ndim, dense_ndim, new_sizes, new_indices, new_values, self.options(), is_coalesced);
1504 }
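// Illustrative example (hypothetical values): for a COO tensor with sparse_dim = 2
// and dense_dim = 1, dims = (1, 0, 2) only swaps the two sparse dims and is allowed,
// whereas dims = (2, 0, 1) would move a dense dim across the sparse/dense boundary
// and is rejected by the check above.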
1505
1506 Tensor repeat(const Tensor& self, IntArrayRef repeats) {
1507 TORCH_CHECK(repeats.size() >= (size_t)self.dim(),
1508 "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor");
1509
1510 // Add new leading dimensions to the tensor if the
1511 // number of target dimensions is larger than the
1512 // number of source dimensions.
1513 int64_t num_new_dimensions = repeats.size() - self.dim();
1514 DimVector padded_size(num_new_dimensions, 1);
1515 padded_size.insert(padded_size.end(), self.sizes().begin(), self.sizes().end());
1516 DimVector target_size(repeats.size());
1517 bool zero_tensor = false;
1518 for(const auto idx : c10::irange(repeats.size())) {
1519 if (repeats[idx] == 0) {
1520 zero_tensor = true;
1521 }
1522 target_size[idx] = padded_size[idx] * repeats[idx];
1523 }
1524
1525 Tensor xtensor = self.expand(padded_size);
1526
1527 Tensor result;
1528 if (self.is_quantized()) {
1529 result = at::empty_quantized(target_size, self);
1530 } else {
1531 result = at::empty(target_size, self.options());
1532 }
1533
1534 // return an empty tensor if one of the repeat dimensions is zero
1535 if (zero_tensor) {
1536 return result;
1537 }
1538
1539 Tensor urtensor = at::alias(result);
1540 for (const auto i : c10::irange(xtensor.dim())) {
1541 // can't unfold with step 0, so make sure step is at least 1
1542 // (it doesn't matter what it is in that case, because the size is 0).
1543 auto size_i = xtensor.sizes()[i];
1544 urtensor = urtensor.unfold(i, size_i, std::max<int64_t>(size_i, 1));
1545 }
1546
1547 urtensor.copy_(xtensor.expand_as(urtensor));
1548
1549 return result;
1550 }
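// Sketch of the unfold trick above (hypothetical shapes): repeating a (2, 3)
// tensor with repeats = (2, 2) allocates a (4, 6) result; unfolding dim 0 with
// size/step 2 and dim 1 with size/step 3 turns the alias into a (2, 2, 2, 3)
// strided view, so a single copy_ of xtensor broadcast to that view writes every
// repeat in one pass.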
1551
1552 Tensor tile_symint(const Tensor& self, SymIntArrayRef reps){
1553 // If self.size() > len(reps), reps is promoted to self.size() by pre-pending
1554 // 1’s to it to keep the same behaviour as `numpy.tile`.
1555 // Thus for a tensor of shape (2, 3, 4, 5), a dims of (2, 2) is treated
1556 // as (1, 1, 2, 2).
1557 const int64_t size_diff = self.dim() - static_cast<int64_t>(reps.size());
1558 if (size_diff > 0){
1559 std::vector<c10::SymInt> new_reps(size_diff, 1);
1560 for (const auto i : c10::irange(reps.size())) {
1561 new_reps.emplace_back(reps[i]);
1562 }
1563 return self.repeat_symint(SymIntArrayRef(new_reps));
1564 }
1565 // `torch.tile` is equivalent to the already implemented `torch.Tensor.repeat`
1566 return self.repeat_symint(reps);
1567 }
1568
1569 //
1570 // templated for ArrayRef<int64_t> and SmallVector<int64_t> use cases
1571 //
1572 template <typename Vec>
1573 Tensor alias_with_sizes_and_strides(
1574 const Tensor& self,
1575 const Vec& sizes,
1576 const Vec& strides) {
1577 //caller should make sure that sizes and strides are valid for self
1578 //(storage is sufficient, strides are non-negative, strides and sizes array size is the same)
1579 Tensor self_;
1580 if (self.is_quantized()) {
1581 self_ = at::detail::make_tensor<QTensorImpl>(
1582 c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype(), get_qtensorimpl(self)->quantizer());
1583 auto* self_tmp_ = self_.unsafeGetTensorImpl();
1584 self_tmp_->set_storage_offset(self.storage_offset());
1585 self_tmp_->set_sizes_and_strides(sizes, strides);
1586 } else {
1587 self_ = at::detail::make_tensor<TensorImpl>(
1588 c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype());
1589 auto* self_tmp_ = self_.unsafeGetTensorImpl();
1590 self_tmp_->set_storage_offset(self.storage_offset());
1591 self_tmp_->set_sizes_and_strides(sizes, strides);
1592 }
1593 namedinference::propagate_names(self_, self);
1594 return self_;
1595 }
1596
1597 // specialization for symbolic shapes and strides.
1598 // SymIntArrayRef/ArrayRef<c10::SymInt> and SmallVector<c10::SymInt>/SymDimVector
1599 template <template <typename...> typename Container>
1600 Tensor alias_with_sizes_and_strides(
1601 const Tensor& self,
1602 const Container<c10::SymInt>& sizes,
1603 const Container<c10::SymInt>& strides) {
1604 //caller should make sure that sizes and strides are valid for self
1605 //(storage is sufficient, strides are non-negative, strides and sizes array size is the same)
1606 Tensor self_;
1607 if (self.is_quantized()) {
1608 self_ = at::detail::make_tensor<QTensorImpl>(
1609 c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype(), get_qtensorimpl(self)->quantizer());
1610 self_.unsafeGetTensorImpl()->set_sizes_and_strides(sizes, strides, self.sym_storage_offset());
1611 } else {
1612 self_ = at::detail::make_tensor<TensorImpl>(
1613 c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype());
1614 self_.unsafeGetTensorImpl()->set_sizes_and_strides(sizes, strides, self.sym_storage_offset());
1615 }
1616 namedinference::propagate_names(self_, self);
1617 return self_;
1618 }
1619
1620 Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) {
1621 if (self.is_sparse()) {
1622 TORCH_CHECK(false, "reshape is not implemented for sparse tensors");
1623 }
1624
1625 if (self.is_contiguous() && !self.is_mkldnn()) {
1626 return self.view_symint(proposed_shape);
1627 }
1628
1629 c10::SymDimVector shape = infer_size_dv(proposed_shape, self.sym_numel());
1630
1631 if (self.is_mkldnn()) {
1632 return at::_mkldnn_reshape(self, C10_AS_INTARRAYREF_SLOW(shape));
1633 }
1634
1635 // `computeStride` returns the proper strides to use if this
1636 // `reshape` can be just a view.
1637 auto stride = at::detail::computeStride(self.sym_sizes(), self.sym_strides(), shape);
1638
1639 // NB: Even though we have viewable geometry and the target strides here,
1640 // we do not just call `as_strided` on `self` because the backward
1641 // for `as_strided` is not as efficient as that of `view` (since the
1642 // former is meant to handle general cases).
1643 //
1644 // Similarly we don't call `view` because it duplicates some of the work
1645 // we've already done, and instead call our internal/private operator
1646 // `_reshape_alias` that essentially does the same thing as `view` and
1647 // `as_strided` without any of the extra overhead.
1648 if (stride.has_value()) {
1649 // Temporary check to revert to the old behavior/view in cases where the
1650 // device is not supported (e.g. for XLA the operation is not supported
1651 // so we use `view` instead).
1652 //
1653 // We need to do the checks here instead of in `native_functions.yaml`
1654 // to preserve backwards compatibility.
1655 if (!self.is_xla() && !self.is_lazy() && !self.is_ipu() && !at::isTensorSubclassLike(self)) {
1656 return self._reshape_alias_symint(shape, stride.value());
1657 } else {
1658 return self.view_symint(shape);
1659 }
1660 }
1661 return at::_unsafe_view_symint(self.clone(at::MemoryFormat::Contiguous), shape);
1662 }
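// Illustrative example (hypothetical shapes): reshaping a contiguous (2, 3, 4)
// tensor to (6, 4) has viewable geometry, so computeStride returns strides (4, 1)
// and the call dispatches to _reshape_alias; reshaping the transpose of a
// contiguous (3, 4) tensor to (12,) is not viewable, so the data is cloned first
// and _unsafe_view is applied to the contiguous copy.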
1663
1664 Tensor _reshape_copy_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) {
1665 if (self.is_sparse()) {
1666 TORCH_CHECK(0, "_reshape_copy is not implemented for sparse tensors");
1667 }
1668 c10::SymDimVector shape = infer_size_dv(proposed_shape, self.sym_numel());
1669
1670 if (self.is_mkldnn()) {
1671 TORCH_CHECK(0, "_reshape_copy not implemented for mkldnn tensors");
1672 }
1673
1674 if (self.is_contiguous()) {
1675 return self.view_symint(shape).clone(at::MemoryFormat::Contiguous);
1676 } else {
1677 return at::_unsafe_view_symint(self.clone(at::MemoryFormat::Contiguous), shape);
1678 }
1679 }
1680
1681 // Duplicate of above code for non-symbolic ints. Kept for BC purposes and to
1682 // minimize breakages.
1683 Tensor reshape(const Tensor& self, IntArrayRef proposed_shape) {
1684 if (self.is_sparse()) {
1685 TORCH_CHECK(false, "reshape is not implemented for sparse tensors");
1686 }
1687 DimVector shape = infer_size_dv(proposed_shape, self.numel());
1688
1689 if (self.is_mkldnn()) {
1690 return at::_mkldnn_reshape(self, shape);
1691 }
1692
1693 // `computeStride` returns the proper strides to use if this
1694 // `reshape` can be just a view.
1695 auto stride = at::detail::computeStride(self.sizes(), self.strides(), shape);
1696
1697 // NB: Even though we have viewable geometry and the target strides here,
1698 // we do not just call `as_strided` on `self` because the backward
1699 // for `as_strided` is not as efficient as that of `view` (since the
1700 // former is meant to handle general cases).
1701 //
1702 // Similarly we don't call `view` because it duplicates some of the work
1703 // we've already done, and instead call our internal/private operator
1704 // `_reshape_alias` that essentially does the same thing as `view` and
1705 // `as_strided` without any of the extra overhead.
1706 if (stride.has_value()) {
1707 // Temporary check to revert to the old behavior/view in cases where the
1708 // device is not supported (e.g. for XLA the operation is not supported
1709 // so we use `view` instead).
1710 //
1711 // We need to do the checks here instead of in `native_functions.yaml`
1712 // to preserve backwards compatibility.
1713 if (!self.is_xla() && !self.is_lazy() && !self.is_ipu()) {
1714 return self._reshape_alias(shape, stride.value());
1715 } else {
1716 return self.view(shape);
1717 }
1718 }
1719 return at::_unsafe_view(self.clone(at::MemoryFormat::Contiguous), shape);
1720 }
1721
1722 Tensor _reshape_alias(const Tensor& self, IntArrayRef sizes, IntArrayRef strides) {
1723 // This is only used by `reshape` in cases where it would otherwise have dispatched
1724 // to `view`. This removes the overhead of calling `view` which duplicates some of
1725 // the work that's already been done (`infer_size_dv` and `computeStride`).
1726
1727 return alias_with_sizes_and_strides(self, sizes, strides);
1728 }
1729
1730 Tensor reshape_as(const Tensor& self, const Tensor& other) {
1731 return self.reshape_symint(other.sym_sizes());
1732 }
1733
1734 static Tensor select_sparse(const Tensor& self, int64_t dim, int64_t index) {
1735 int64_t sparse_dim = self.sparse_dim();
1736 int64_t dense_dim = self.dense_dim();
1737 TORCH_INTERNAL_ASSERT(dim >= 0 && dim < sparse_dim + dense_dim);
1738
1739 auto indices = self._indices();
1740 auto values = self._values();
1741 auto new_sizes = self.sizes().vec();
1742 new_sizes.erase(new_sizes.begin() + dim);
1743
1744 if (dim < sparse_dim) {
1745 auto nzIndices = (indices[dim] == index).nonzero().view(-1);
1746 auto new_values = values.index_select(0, nzIndices);
1747 if (sparse_dim == 1) {
1748 // return dense part:
1749 if (new_values.size(0) == 1) {
1750 return new_values[0];
1751 } else {
1752 // sum promotes integral type to int64 when dtype is not specified.
1753 return at::sum(new_values, 0, false, new_values.scalar_type());
1754 }
1755 } else {
1756 auto dimIndices = (arange(
1757 0,
1758 sparse_dim,
1759 std::nullopt /* dtype */,
1760 std::nullopt /* layout */,
1761 self.device(),
1762 std::nullopt /* pin_memory */) != dim)
1763 .nonzero()
1764 .view(-1);
1765 auto new_indices = indices.index_select(1, nzIndices).index_select(0, dimIndices);
1766 return _sparse_coo_tensor_with_dims_and_tensors(
1767 sparse_dim - 1, dense_dim, new_sizes, new_indices, new_values, self.options());
1768 }
1769 } else {
1770 auto new_values = values.select(dim - sparse_dim + 1, index);
1771 return _sparse_coo_tensor_with_dims_and_tensors(
1772 sparse_dim, dense_dim - 1, new_sizes, indices, new_values, self.options());
1773 }
1774 }
1775
1776 // this is an auxiliary function, called by the select&slice methods, that
1777 // creates a new quantizer from the given input
1778 // is_select is true if the calling function is select()
1779 static QuantizerPtr create_subtensor_quantizer(const Tensor& self, bool is_select, int64_t start,
1780 int64_t end, int64_t dim, int64_t step) {
1781 auto quantizer_prev = get_qtensorimpl(self)->quantizer();
1782 if (quantizer_prev->qscheme() == QScheme::PER_TENSOR_AFFINE) {
1783 return quantizer_prev;
1784 }
1785 QuantizerPtr quantizer;
1786 auto temp = static_cast<PerChannelAffineQuantizer*>(quantizer_prev.get());
1787 auto axis = temp->axis();
1788 auto scales = temp->scales();
1789 auto zero_points = temp->zero_points();
1790 if (dim == axis) {
1791 // Compute scales&zps for sub-tensor
1792 // *.select(0, start) could alternatively be replaced with *.slice(0, start, end, step), but
1793 // select has less overhead
1794 scales = is_select ? scales.select(0, start) : scales.slice(0, start, end, step);
1795 zero_points = is_select ? zero_points.select(0, start) : zero_points.slice(0, start, end, step);
1796 }
1797 if (scales.numel() > 1) {
1798     // The axis only needs to be adjusted if the calling function is select(), since select() reduces
1799     // the number of dimensions of the tensor by 1; it remains unchanged if the calling function is slice().
1800 quantizer = make_per_channel_affine_quantizer(scales, zero_points, (is_select ? axis - 1 : axis),
1801 quantizer_prev->scalar_type());
1802 } else {
1803 quantizer = make_per_tensor_affine_quantizer(scales.item().to<double>(), zero_points.item().to<int64_t>(),
1804 quantizer_prev->scalar_type());
1805 }
1806 return quantizer;
1807 }
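// Illustrative sketch (hypothetical values): when select() picks index 2 along the
// quantization axis of a per-channel quantizer, only scales[2] / zero_points[2]
// survive; since a single scale remains, the result collapses to a per-tensor
// affine quantizer. A slice along the axis that keeps several channels instead
// returns a per-channel quantizer restricted to those channels.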
1808
1809 Tensor select(const Tensor& self, int64_t dim, int64_t index) {
1810 return at::select_symint(self, dim, c10::SymInt{index});
1811 }
1812
1813 Tensor select(const Tensor& self, Dimname dim, int64_t index) {
1814 return at::select_symint(self, dimname_to_position(self, dim), c10::SymInt{index});
1815 }
1816
1817 Tensor select_symint(const Tensor& self, int64_t dim, c10::SymInt index) {
1818 int64_t ndim = self.dim();
1819 if (ndim == 0) {
1820 TORCH_CHECK_INDEX(false, "select() cannot be applied to a 0-dim tensor.");
1821 }
1822 dim = maybe_wrap_dim(dim, ndim);
1823 auto size = self.sym_sizes()[dim];
1824   // Note: `size < -index` is not equivalent to `size <= -1 - index` if index is INT64_MIN.
1825   // For std::numeric_limits<int64_t>::min() the result of unary minus is undefined by the standard,
1826   // but in practice it is equal to the value itself. On the other hand, index wrapping is valid for all
1827   // negative int64_t values, as x[INT64_MIN] is the same as x[INT64_MAX]
1828 if (size <= -1 - index || size <= index) {
1829 if (self.has_names() && self.names()[dim] != Dimname::wildcard()) {
1830 TORCH_CHECK_INDEX(false, "select(): index ", index, " out of range for tensor of size ",
1831 self.sizes(), " at dimension ", self.names()[dim]);
1832 }
1833 TORCH_CHECK_INDEX(false, "select(): index ", index, " out of range for tensor of size ",
1834 self.sizes(), " at dimension ", dim);
1835 }
1836 if (index < 0) {
1837 index += size;
1838 }
1839 if (self.is_sparse()) {
1840 return select_sparse(self, dim, index.guard_int(__FILE__, __LINE__));
1841 }
1842
1843 Tensor result;
1844 if (self.is_quantized()) {
1845 auto local_index = index.guard_int(__FILE__, __LINE__);
1846
1847 DimVector sizes(self.sizes().begin(), self.sizes().end());
1848 DimVector strides(self.strides().begin(), self.strides().end());
1849 auto storage_offset = self.storage_offset() + local_index * strides[dim];
1850 sizes.erase(sizes.begin() + dim);
1851 strides.erase(strides.begin() + dim);
1852
1853 auto quantizer = create_subtensor_quantizer(self, true, local_index, local_index + 1, dim, 1);
1854 result = as_strided_qtensorimpl(self, sizes, strides, storage_offset, std::move(quantizer));
1855 } else {
1856 std::vector<c10::SymInt> sizes(self.sym_sizes().begin(), self.sym_sizes().end());
1857 std::vector<c10::SymInt> strides(self.sym_strides().begin(), self.sym_strides().end());
1858 auto storage_offset = self.sym_storage_offset() + index * strides[dim];
1859 sizes.erase(sizes.begin() + dim);
1860 strides.erase(strides.begin() + dim);
1861
1862 result = self.as_strided_symint(sizes, strides, storage_offset);
1863 }
1864 namedinference::propagate_names_except(result, self, {dim});
1865 return result;
1866 }
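// Worked example (hypothetical tensor): for a contiguous (3, 4) tensor with
// strides (4, 1), select_symint(self, 0, 1) erases dim 0, leaving sizes (4,) and
// strides (1,), and sets storage_offset to 0 + 1 * 4 = 4, i.e. a view of row 1.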
1867
1868 Tensor select_backward_symint(const Tensor& grad, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index) {
1869 auto grad_input = at::zeros_symint(input_sizes, grad.options());
1870 grad_input.select_symint(dim, std::move(index)).copy_(grad);
1871 return grad_input;
1872 }
1873
1874 Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& index) {
1875 /*
1876 Algorithm:
1877 index - a 1-D tensor of indices with shape (n,)
1878 self - sparse tensor, its shape is sizes = sparse_shape + dense_shape
1879 indices - 2-D tensor of indices, shape is (sparse_dims, nnz)
1880 values - (1+len(dense_shape))-D tensor of values, shape is (nnz,) + dense_shape
1881 index_select(dim, index) returns a sparse tensor with the following data
1882 new_sizes = sizes[:dim] + (n,) + sizes[dim+1:]
1883 new_indices - shape is (sparse_dims, new_nnz)
1884 new_values - shape is (new_nnz,) + dense_shape
1885
1886 if dim < len(sparse_shape):
1887 # Find new_indices[dim] of the output sparse tensor and
1888 # indices at which to select values/indices.
1889         # The C++ code uses a search (either binary search or a lookup in a count table)
1890         # to find matches and may swap the loop order for better algorithmic complexity.
1891 new_dim_indices = []
1892 selected_dim_indices = []
1893         # This is a brute-force algorithm to convey the main idea.
1894         # The C++ code below is more efficient but more complicated.
1895 for i, i_idx in enumerate(indices[dim]):
1896 for j, j_idx in enumerate(index):
1897 if i_idx == j_idx:
1898 new_dim_indices.append(j)
1899 selected_dim_indices.append(i)
1900 new_indices = indices.index_select(1, selected_dim_indices)
1901 new_values = values.index_select(0, selected_dim_indices)
1902 new_indices[dim] = new_dim_indices
1903 else:
1904 new_indices = indices
1905 new_values = values.index_select(dim - sparse_dim + 1, index);
1906 */
1907 const auto ndim = self.dim();
1908 TORCH_CHECK_INDEX(ndim, "index_select() cannot be applied to a 0-dim tensor.");
1909 TORCH_CHECK_INDEX(
1910 index.dim() == 1 && index.dtype() == at::kLong && index.options().layout() == at::kStrided,
1911 "index_select() argument index must be 1-D strided (non-sparse) long-tensor.");
1912 dim = maybe_wrap_dim(dim, ndim);
1913 const auto size = self.size(dim);
1914 const auto sparse_dim = self.sparse_dim();
1915 const auto dense_dim = self.dense_dim();
1916 const auto indices = self._indices();
1917 const auto values = self._values();
1918 const auto nnz = values.size(0);
1919 const auto index_len = index.size(0);
1920 auto res_sizes = self.sizes().vec();
1921 res_sizes[dim] = index_len;
1922
1923 // Equivalent to t.index_select(dim, idx), but vanilla index_select is not parallel,
1924 // so we use gather instead.
1925 // We use this method to select relevant indices/values
1926 // from the intersection between indices[dim] and the index.
1927 const auto index_select = [](const Tensor& t, int64_t dim, const Tensor& idx) -> Tensor {
1928 const auto idx_len = idx.numel();
1929 auto out_shape = t.sizes().vec();
1930 out_shape[dim] = idx_len;
1931 auto idx_shape = std::vector<int64_t>(t.dim(), 1);
1932 idx_shape[dim] = idx_len;
1933 return t.gather(dim, idx.view(idx_shape).expand(out_shape));
1934 };
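  // Illustrative example of the gather trick above (hypothetical values): for t of
  // shape (3, 5), dim = 1 and idx = [4, 0], idx is viewed as (1, 2) and expanded to
  // (3, 2), so gather returns out[r][c] = t[r][idx[c]], i.e. columns 4 and 0 of
  // every row, the same result as t.index_select(1, idx).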
1935
1936 // If indexing into sparse dimensions
1937 if (dim < sparse_dim) {
1938 // short-circuit if index is empty
1939 if (!index_len) {
1940 auto res_indices = index_select(indices, 1, index);
1941 res_indices[dim] = index;
1942 const auto res_values = index_select(values, 0, index);
1943
1944 return _sparse_coo_tensor_with_dims_and_tensors(
1945 sparse_dim, dense_dim, res_sizes, res_indices, res_values, self.options());
1946 }
1947
1948 const auto nneg_index = [&index, index_len, &self, size, dim]() -> Tensor {
1949 const auto index_contiguous = index.contiguous();
1950 auto nneg_index = at::empty_like(index_contiguous);
1951 // nneg_index = (index < 0) * (index + size) + (index >= 0) * index
1952 auto* ptr_index = index_contiguous.data_ptr<int64_t>();
1953 auto* ptr_nneg_index = nneg_index.data_ptr<int64_t>();
1954 at::parallel_for(0, index_len, at::internal::GRAIN_SIZE, [&](int64_t start, int64_t end) {
1955 const auto* src = ptr_index + start;
1956 auto* dst = ptr_nneg_index + start;
1957 for (C10_UNUSED const auto _ : c10::irange(start, end)) {
1958 auto idx = *src++;
1959 if (idx < -size || idx >= size) {
1960 // Mark self and dim as used if code is compiled with STRIP_ERROR_MESSAGES
1961 (void)dim;
1962 (void)self;
1963 TORCH_CHECK_INDEX(false,
1964 "index_select(): index contains ", idx, " that is out of range for tensor of size ",
1965 self.sizes(), " at dimension ", dim
1966 );
1967 }
1968 if (idx < 0) {
1969 idx += size;
1970 }
1971 *dst++ = idx;
1972 }
1973 });
1974
1975 return nneg_index;
1976 }();
1977
1978 const auto dim_indices = indices[dim].contiguous();
1979
1980     // If nnz is smaller than size, then either indices[dim] or index gets sorted,
1981     // which is then followed by a binary search to find intersections.
1982 const auto get_selected_indices_small_nnz_large_size = [&]() -> std::tuple<Tensor, Tensor> {
1983 const auto grain_size = at::internal::GRAIN_SIZE;
1984 const auto n_threads_nnz = std::max<int64_t>(
1985 1, std::min<int64_t>((nnz + grain_size - 1) / grain_size, at::get_num_threads())
1986 );
1987 const auto n_threads_index = std::max<int64_t>(
1988 1, std::min<int64_t>((index_len + grain_size - 1) / grain_size, at::get_num_threads())
1989 );
1990 const auto search_in_dim_indices
1991 // if either dim_indices or index requires sorting, we compare
1992 // the cost of sort + binary search, which is comparing
1993 // (len(dim_indices) + len(index)) * log(len(index)) to
1994 // (len(dim_indices) + len(index)) * log(len(dim_indices)).
1995 // That simplifies to comparing len(dim_indices) to len(index).
1996 // Additionally, we take into consideration potential parallel
1997 // speedup.
1998 = (nnz / n_threads_nnz <= index_len / n_threads_index)
1999 // if self is coalesced and dim is 0, then we compare
2000 // index_len * log(len(dim_indices)), which is binary search into dim_indices,
2001       // to (index_len + len(dim_indices)) * log(index_len).
2002 // Additionally, we take into consideration potential parallel
2003 // speedup.
2004 || (self.is_coalesced() && dim == 0
2005 && (index_len * std::log2(nnz) / n_threads_index
2006 <= (nnz / n_threads_nnz + index_len) * std::log2(index_len)))
2007 ? true : false;
2008
2009 // src is a source of indices to binary search in sorted
2010 Tensor sorted, sorted_idx, src;
2011 std::tie(sorted, sorted_idx, src) = [
2012 &dim_indices, &nneg_index, &self,
2013 search_in_dim_indices, dim, nnz
2014 ](void) -> std::tuple<Tensor, Tensor, Tensor> {
2015 // sort dim_indices to binary search into it
2016 if (search_in_dim_indices) {
2017 // dim_indices is already sorted if self is coalesced and dim == 0
2018 if (self.is_coalesced() && dim == 0) {
2019 return std::make_tuple(dim_indices, at::arange(nnz, dim_indices.options()), nneg_index);
2020 }
2021 else {
2022 auto [sorted_dim_indices, sorted_dim_indices_idx] = dim_indices.sort();
2023 return std::make_tuple(sorted_dim_indices, sorted_dim_indices_idx, nneg_index);
2024 }
2025 }
2026 // sort nneg_index to binary search into it
2027 else {
2028 auto [sorted_nneg_index, sorted_nneg_index_idx] = nneg_index.sort();
2029 return std::make_tuple(sorted_nneg_index, sorted_nneg_index_idx, dim_indices);
2030 }
2031 }();
2032
2033 const auto src_grain_size = at::internal::GRAIN_SIZE;
2034 const auto src_len = src.numel();
2035 const auto n_threads_src = std::max<int64_t>(
2036 // 1 <= n_threads_src <= std::min(ceil(src.numel() / src_grain_size), max_threads)
2037 1, std::min<int64_t>((src_len + src_grain_size - 1) / src_grain_size, at::get_num_threads())
2038 );
2039 const auto chunk_size_src = (src_len + n_threads_src - 1) / n_threads_src;
2040
2041 const std::vector<int64_t> src_n_threads_shape = {
2042 n_threads_src, (src_len + n_threads_src - 1) / n_threads_src
2043 };
2044
2045 // src_int_idx and sorted_int_idx store "i" and "j" indices indicating
2046 // intersections such that src_int_idx[i] == sorted_int_idx[j].
2047 // These intersections are found with binary search and in parallel.
2048 auto src_int_idx = at::empty(src_n_threads_shape, src.options());
2049 auto sorted_int_idx = at::empty_like(src_int_idx);
2050 // For each element "i" from src, int_counts define how many
2051 // elements there are in sorted, i.e. "j" indices, corresponding
2052 // to "i", i.e.:
2053 // |{j : src_int_idx[i] == sorted_int_idx[j]}| for each i in src_int_idx.
2054 auto int_counts = at::zeros_like(src_int_idx);
2055
2056 // fill in src_int_idx, sorted_int_idx, int_counts
2057 {
2058 const auto sorted_len = sorted.numel();
2059 const auto* ptr_sorted = sorted.const_data_ptr<int64_t>();
2060 const auto* ptr_sorted_start = ptr_sorted;
2061 const auto* ptr_sorted_end = ptr_sorted + sorted_len;
2062
2063 at::parallel_for(0, n_threads_src, 1, [&](int64_t tid, C10_UNUSED int64_t _) {
2064 const auto start = tid * chunk_size_src;
2065 const auto end = std::min(start + chunk_size_src, src_len);
2066 auto* ptr_tid_src_int_idx = src_int_idx.select(0, tid).data_ptr<int64_t>();
2067 auto* ptr_tid_sorted_int_idx = sorted_int_idx.select(0, tid).data_ptr<int64_t>();
2068 auto* ptr_tid_int_counts = int_counts.select(0, tid).data_ptr<int64_t>();
2069 const auto* ptr_src = src.const_data_ptr<int64_t>() + start;
2070
2071 for (const auto i : c10::irange(start, end)) {
2072 const auto src_val = *ptr_src++;
2073 const auto src_val_lb = std::lower_bound(ptr_sorted_start, ptr_sorted_end, src_val);
2074 // We cannot just use *src_val_lb != src_val because when
2075 // src_val_lb == ptr_sorted_end, dereferencing past-the-end value
2076 // is not well-defined.
2077 if (src_val_lb == ptr_sorted_end || *src_val_lb != src_val) {
2078 ++ptr_tid_src_int_idx;
2079 ++ptr_tid_sorted_int_idx;
2080 ++ptr_tid_int_counts;
2081 continue;
2082 }
2083 const auto src_val_ub = std::upper_bound(ptr_sorted_start, ptr_sorted_end, src_val);
2084
2085 const int64_t count = src_val_ub - src_val_lb;
2086 const int64_t j = src_val_lb - ptr_sorted_start;
2087
2088 *ptr_tid_src_int_idx++ = i;
2089 *ptr_tid_sorted_int_idx++ = j;
2090 *ptr_tid_int_counts++ = count;
2091 }
2092 });
2093 }
2094
2095 const auto compressed_int_counts = int_counts.sum(-1);
2096 const auto res_len = compressed_int_counts.sum().item<int64_t>();
2097
2098 // Short-circuit if empty intersection
2099 if (!res_len) {
2100 auto empty_idx = at::empty({0}, src.options());
2101 return std::make_tuple(empty_idx, empty_idx);
2102 }
2103
2104 // Now that we know "i", "j" and the counts, we "unflatten"
2105 // them into two arrays of intersection indices such that
2106 // selected_src = repeat_interleave(src_int_idx, int_counts),
2107 // and selected_sorted is obtained as follows:
2108 // offsets = int_counts.cumsum(0).sub_(int_counts)
2109 // for ii, (j, c) in enumerate(zip(sorted_int_idx, int_counts)):
2110 // out_slice = slice(offsets[ii], offsets[ii] + c)
2111 // src_slice = slice(j, j + c)
2112 // selected_sorted[out_slice] = sorted_int_idx[src_slice]
2113 auto selected_sorted = at::empty({res_len}, sorted.options());
2114 auto selected_src = at::empty({res_len}, src.options());
2115
2116 // fill in selected_sorted, selected_src
2117 {
2118 auto* ptr_selected_sorted = selected_sorted.data_ptr<int64_t>();
2119 auto* ptr_selected_src = selected_src.data_ptr<int64_t>();
2120
2121 const auto thread_offsets = compressed_int_counts.cumsum(0).sub_(compressed_int_counts);
2122 const auto* ptr_sorted_idx = sorted_idx.const_data_ptr<int64_t>();
2123 at::parallel_for(0, n_threads_src, 1, [&](int64_t tid, C10_UNUSED int64_t _) {
2124 const auto start = tid * chunk_size_src;
2125 const auto end = std::min(start + chunk_size_src, src_len);
2126 const auto tid_offset = thread_offsets.const_data_ptr<int64_t>()[tid];
2127 const auto* ptr_tid_src_int_idx = src_int_idx.select(0, tid).const_data_ptr<int64_t>();
2128 const auto* ptr_tid_sorted_int_idx = sorted_int_idx.select(0, tid).const_data_ptr<int64_t>();
2129 const auto* ptr_tid_int_counts = int_counts.select(0, tid).const_data_ptr<int64_t>();
2130 auto* ptr_tid_selected_sorted = ptr_selected_sorted + tid_offset;
2131 auto* ptr_tid_selected_src = ptr_selected_src + tid_offset;
2132
2133 for (C10_UNUSED const auto _ : c10::irange(start, end)) {
2134 const auto count = *ptr_tid_int_counts++;
2135 const auto i = *ptr_tid_src_int_idx++;
2136 const auto j = *ptr_tid_sorted_int_idx++;
2137 if (!count) continue;
2138
2139 std::fill_n(ptr_tid_selected_src, count, i);
2140 std::copy_n(ptr_sorted_idx + j, count, ptr_tid_selected_sorted);
2141
2142 ptr_tid_selected_sorted += count;
2143 ptr_tid_selected_src += count;
2144 }
2145 });
2146 }
2147
2148 return search_in_dim_indices
2149 ? std::make_tuple(selected_sorted, selected_src)
2150 : std::make_tuple(selected_src, selected_sorted);
2151 };
2152
2153     // Converts a sorted 1-D idx into a compressed 1-D idx,
2154 // aka crow in the CSR format. Useful to get a count table in
2155 // a parallelized and no-sync manner.
2156 // TODO: this function is equivalent to _convert_indices_from_coo_to_csr.
2157 // The mentioned function is not public yet.
2158 const auto sorted_idx_to_cidx = [](
2159 const Tensor& idx,
2160 int64_t len,
2161 bool run_in_parallel = true) -> Tensor {
2162 auto cidx = at::empty({len + 1}, idx.options());
2163
2164 const auto* ptr_idx = idx.const_data_ptr<int64_t>();
2165 auto* ptr_cidx = cidx.data_ptr<int64_t>();
2166
2167 const auto idx_len = idx.numel();
2168
2169 std::fill_n(ptr_cidx, ptr_idx[0] + 1, 0);
2170 std::fill_n(ptr_cidx + ptr_idx[idx_len - 1] + 1, len - ptr_idx[idx_len - 1], idx_len);
2171
2172 const auto grain_size = run_in_parallel ? at::internal::GRAIN_SIZE : idx_len;
2173 at::parallel_for(0, idx_len, grain_size, [&](int64_t start, int64_t end) {
2174 auto* ptr_curr_cidx = ptr_cidx + ptr_idx[start] + 1;
2175 for (int64_t i = start; i < std::min(end, idx_len - 1); ++i) {
2176 const auto diff = ptr_idx[i + 1] - ptr_idx[i];
2177 std::fill_n(ptr_curr_cidx, diff, i + 1);
2178 ptr_curr_cidx += diff;
2179 }
2180 });
2181
2182 return cidx;
2183 };
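    // Illustrative example (hypothetical values): for a sorted idx = [0, 0, 2, 2, 3]
    // and len = 5, the lambda above produces cidx = [0, 2, 2, 4, 5, 5]; adjacent
    // differences of cidx give the per-bin counts [2, 0, 2, 1, 0] without any
    // synchronization between threads.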
2184
2185 // If nnz is (much) larger than size, then both indices[dim] and index get sorted
2186 // with a count sort (faster, and no huge nnz-sized chunk memory allocations).
2187 // The element-wise product between the count tables gives us all the intersections.
2188 const auto get_selected_indices_large_nnz_small_size = [&]() -> std::tuple<Tensor, Tensor> {
2189 const auto get_counts = [&sorted_idx_to_cidx](
2190 // Writes into counts (must be preallocated and zero)
2191           // and allows the use of external buffers.
2192 Tensor& counts,
2193 const Tensor& t,
2194 int64_t bins,
2195 bool is_sorted = false,
2196 bool run_in_parallel = true) -> void {
2197 if (is_sorted) {
2198 const auto cidx = sorted_idx_to_cidx(t, bins, run_in_parallel);
2199 at::sub_out(counts, cidx.slice(0, 1, bins + 1), cidx.slice(0, 0, bins));
2200 }
2201 else {
2202 auto* ptr_counts = counts.data_ptr<int64_t>();
2203 const auto* ptr_vals = t.const_data_ptr<int64_t>();
2204 for (C10_UNUSED const auto _ : c10::irange(t.numel())) {
2205 ++ptr_counts[*ptr_vals++];
2206 }
2207 }
2208 };
2209
2210 const auto counts_per_thread = [&get_counts, size](
2211 const Tensor& idx,
2212 bool is_sorted = false,
2213 int64_t grain_size = at::internal::GRAIN_SIZE
2214 ) -> Tensor {
2215 const auto idx_len = idx.numel();
2216 // 1 <= n_threads <= min(ceil(len / grain_size), max_threads)
2217 const auto n_threads = std::max<int64_t>(
2218 1, std::min<int64_t>((idx_len + grain_size - 1) / grain_size, at::get_num_threads())
2219 );
2220 const auto chunk_size = (idx_len + n_threads - 1) / n_threads;
2221 const auto run_in_parallel = (n_threads == 1);
2222
2223 auto counts_per_thread = at::zeros({n_threads, size}, idx.options());
2224 at::parallel_for(0, n_threads, 1, [&](int64_t tid, C10_UNUSED int64_t _) {
2225 const auto start = tid * chunk_size;
2226 const auto end = std::min(start + chunk_size, idx_len);
2227 const auto tid_idx = idx.slice(0, start, end);
2228 auto tid_counts = counts_per_thread.select(0, tid);
2229 get_counts(tid_counts, tid_idx, /*bins=*/size,
2230 /*is_sorted=*/is_sorted, /*run_in_parallel=*/run_in_parallel);
2231 });
2232
2233 return counts_per_thread;
2234 };
2235
2236 auto dim_indices_counts_per_thread = counts_per_thread(
2237 dim_indices,
2238 /*is_sorted=*/self.is_coalesced() && dim == 0
2239 /*grain_size = at::internal::GRAIN_SIZE*/
2240 );
2241 auto dim_indices_offset_counts_per_thread = dim_indices_counts_per_thread.cumsum(0);
2242
2243 auto index_counts_per_thread = counts_per_thread(
2244 nneg_index,
2245 /*is_sorted=*/false
2246 /*grain_size = at::internal::GRAIN_SIZE*/
2247 );
2248 auto index_offset_counts_per_thread = index_counts_per_thread.cumsum(0);
2249
2250 const auto index_counts = index_offset_counts_per_thread.select(0, -1);
2251 const auto dim_indices_counts = dim_indices_offset_counts_per_thread.select(0, -1);
2252 const auto intersection_counts = index_counts.mul(dim_indices_counts);
2253 const auto res_len = intersection_counts.sum().item<int64_t>();
2254 // Short-circuit if empty intersection
2255 if (!res_len) {
2256 auto empty_idx = at::empty({0}, index.options());
2257 return std::make_tuple(empty_idx, empty_idx);
2258 }
2259 const auto intersection_offsets = intersection_counts.cumsum(0);
2260
2261 const auto search_in_dim_indices = [&]() -> bool {
2262 const auto grain_size = at::internal::GRAIN_SIZE;
2263 const auto n_threads_index = std::max<int64_t>(
2264 1, std::min<int64_t>((index_len + grain_size - 1) / grain_size, at::get_num_threads())
2265 );
2266 const auto n_threads_dim_indices = std::max<int64_t>(
2267 1, std::min<int64_t>((nnz + grain_size - 1) / grain_size, at::get_num_threads())
2268 );
2269
2270 const auto index_max_copy_work_per_thread =
2271 index_counts_per_thread.mul(dim_indices_counts).sum(-1).max().item<int64_t>();
2272 const auto dim_indices_max_copy_work_per_thread
2273 = dim_indices_counts_per_thread.mul(index_counts).sum(-1).max().item<int64_t>();
2274
2275 const auto index_max_work_per_thread = index_max_copy_work_per_thread * index_len / n_threads_index;
2276 const auto dim_indices_max_work_per_thread = dim_indices_max_copy_work_per_thread * nnz / n_threads_dim_indices;
2277 return index_max_work_per_thread <= dim_indices_max_work_per_thread
2278 ? true
2279 : false;
2280 }();
2281
2282 Tensor idx, idx_counts_per_thread, idx_offset_counts_per_thread;
2283 Tensor src, src_counts_per_thread, src_offset_counts_per_thread;
2284 std::tie(
2285 idx, idx_counts_per_thread, idx_offset_counts_per_thread,
2286 src, src_counts_per_thread, src_offset_counts_per_thread
2287 ) = [&]() {
2288 return search_in_dim_indices
2289 ? std::make_tuple(
2290 nneg_index, index_counts_per_thread, index_offset_counts_per_thread,
2291 dim_indices, dim_indices_counts_per_thread, dim_indices_offset_counts_per_thread
2292 )
2293 : std::make_tuple(
2294 dim_indices, dim_indices_counts_per_thread, dim_indices_counts_per_thread.cumsum(0),
2295 nneg_index, index_counts_per_thread, index_counts_per_thread.cumsum(0)
2296 );
2297 }();
2298
2299 const auto idx_counts = idx_offset_counts_per_thread.select(0, -1);
2300 const auto src_counts = src_offset_counts_per_thread.select(0, -1);
2301
2302 Tensor src_idx, src_idx_offsets;
2303 std::tie(src_idx, src_idx_offsets) = [&](
2304 int64_t grain_size = at::internal::GRAIN_SIZE
2305 ) -> std::tuple<Tensor, Tensor> {
2306 const auto src_intersection_counts = src_counts.mul(idx_counts > 0);
2307 const auto src_intersection_offsets = src_intersection_counts.cumsum(0);
2308 const auto src_idx_len = src_intersection_offsets.const_data_ptr<int64_t>()[size - 1];
2309 auto src_idx = at::empty({src_idx_len}, src.options());
2310
2311 const auto* ptr_src = src.const_data_ptr<int64_t>();
2312 const auto* ptr_intersection_counts = intersection_counts.const_data_ptr<int64_t>();
2313 const auto* ptr_src_intersection_counts = src_intersection_counts.const_data_ptr<int64_t>();
2314 const auto* ptr_src_intersection_offsets = src_intersection_offsets.const_data_ptr<int64_t>();
2315 auto* ptr_src_idx = src_idx.data_ptr<int64_t>();
2316
2317 const auto src_len = src.numel();
2318 const auto n_threads_src = std::max<int64_t>(
2319 1, std::min<int64_t>((src_len + grain_size - 1) / grain_size, at::get_num_threads())
2320 );
2321 const auto chunk_size = (src_len + n_threads_src - 1) / n_threads_src;
2322 at::parallel_for(0, n_threads_src, 1, [&](int64_t tid, C10_UNUSED int64_t _) {
2323 const auto start = tid * chunk_size;
2324 const auto end = std::min(start + chunk_size, src_len);
2325 auto* ptr_src_tid = ptr_src + start;
2326 const auto* ptr_src_counts_per_thread
2327 = src_counts_per_thread.select(0, tid).const_data_ptr<int64_t>();
2328 const auto* ptr_src_offset_counts_per_thread
2329 = src_offset_counts_per_thread.select(0, tid).const_data_ptr<int64_t>();
2330 auto tid_counts = at::zeros({size}, src.options());
2331 auto* ptr_tid_counts = tid_counts.data_ptr<int64_t>();
2332
2333 for (const auto i : c10::irange(start, end)) {
2334 const auto idx_val = *ptr_src_tid++;
2335 // skip idx value if not in the intersection
2336 if (!ptr_intersection_counts[idx_val]) continue;
2337 const auto idx_val_offset
2338 = ptr_src_intersection_offsets[idx_val]
2339 - ptr_src_intersection_counts[idx_val];
2340 const auto idx_val_tid_offset
2341 = ptr_src_offset_counts_per_thread[idx_val]
2342 - ptr_src_counts_per_thread[idx_val];
2343 auto& idx_val_local_tid_count = ptr_tid_counts[idx_val];
2344 ptr_src_idx[idx_val_offset + idx_val_tid_offset + idx_val_local_tid_count] = i;
2345 ++idx_val_local_tid_count;
2346 }
2347 });
2348
2349 const auto src_idx_offsets = src_intersection_offsets.sub_(src_intersection_counts);
2350
2351 return std::make_tuple(src_idx, src_idx_offsets);
2352 }();
2353
2354 auto [idx_selected, src_selected] = [&](
2355 int64_t grain_size = at::internal::GRAIN_SIZE
2356 ) -> std::tuple<Tensor, Tensor> {
2357 const auto thread_offset = [&]() {
2358 // we do not need idx_counts_per_thread anymore,
2359 // so it is safe to do in-place intersection.
2360 auto counts_per_thread = idx_counts_per_thread.mul_(src_counts).sum(-1);
2361 return counts_per_thread.cumsum(0).sub_(counts_per_thread);
2362 }();
2363 const auto* ptr_thread_offset = thread_offset.const_data_ptr<int64_t>();
2364
2365 auto idx_selected = at::empty({res_len}, idx.options());
2366 auto src_selected = at::empty({res_len}, src.options());
2367
2368 const auto* ptr_idx = idx.const_data_ptr<int64_t>();
2369 const auto* ptr_src_counts = src_counts.const_data_ptr<int64_t>();
2370 const auto* ptr_intersection_counts = intersection_counts.const_data_ptr<int64_t>();
2371 const auto* ptr_src_idx = src_idx.const_data_ptr<int64_t>();
2372 const auto* ptr_src_idx_offsets = src_idx_offsets.const_data_ptr<int64_t>();
2373 auto* ptr_idx_selected = idx_selected.data_ptr<int64_t>();
2374 auto* ptr_src_selected = src_selected.data_ptr<int64_t>();
2375
2376 const auto idx_len = idx.numel();
2377 const auto n_threads_idx = std::max<int64_t>(
2378 1, std::min<int64_t>((idx_len + grain_size - 1) / grain_size, at::get_num_threads())
2379 );
2380 const auto chunk_size = (idx_len + n_threads_idx - 1) / n_threads_idx;
2381 at::parallel_for(0, n_threads_idx, 1, [&](int64_t tid, C10_UNUSED int64_t _) {
2382 const auto start = tid * chunk_size;
2383 const auto end = std::min(start + chunk_size, idx_len);
2384 const auto tid_offset = ptr_thread_offset[tid];
2385 const auto* ptr_idx_tid = ptr_idx + start;
2386 auto* ptr_idx_selected_tid = ptr_idx_selected + tid_offset;
2387 auto* ptr_src_selected_tid = ptr_src_selected + tid_offset;
2388
2389 for (const auto i : c10::irange(start, end)) {
2390 const auto idx_val = *ptr_idx_tid++;
2391 // skip if idx_val is not in the intersection
2392 if (!ptr_intersection_counts[idx_val]) continue;
2393 const auto count = ptr_src_counts[idx_val];
2394 const auto j = ptr_src_idx_offsets[idx_val];
2395 std::fill_n(ptr_idx_selected_tid, count, i);
2396 std::copy_n(ptr_src_idx + j, count, ptr_src_selected_tid);
2397 ptr_idx_selected_tid += count;
2398 ptr_src_selected_tid += count;
2399 }
2400 });
2401
2402 return std::make_tuple(idx_selected, src_selected);
2403 }();
2404
2405 return search_in_dim_indices
2406 ? std::make_tuple(src_selected, idx_selected)
2407 : std::make_tuple(idx_selected, src_selected);
2408 };
2409
2410 const auto make_output = [&](
2411 const Tensor& selected_dim_indices,
2412 const Tensor& res_dim_indices) -> Tensor {
2413 auto res_indices = index_select(indices, 1, selected_dim_indices);
2414 res_indices[dim] = res_dim_indices;
2415 const auto res_values = index_select(values, 0, selected_dim_indices);
2416
2417 return _sparse_coo_tensor_with_dims_and_tensors(
2418 sparse_dim, dense_dim, res_sizes, res_indices, res_values, self.options());
2419 };
2420
2421 // Brute-force solution for small values of nnz and index_len
2422 const auto get_result_small_nnz_small_index = [&]()
2423 -> Tensor {
2424 const auto dim_indices_in_inner_loop = nnz >= index_len;
2425 auto [outer, inner] = [&]() -> std::tuple<Tensor, Tensor> {
2426 if (dim_indices_in_inner_loop) {
2427 return std::make_tuple(nneg_index, dim_indices);
2428 }
2429 else {
2430 return std::make_tuple(dim_indices, nneg_index);
2431 }
2432 }();
2433
2434 const auto* ptr_outer = outer.const_data_ptr<int64_t>();
2435 const auto* ptr_inner = inner.const_data_ptr<int64_t>();
2436       // NOTE: if performance-critical, replace std::vector with
2437       // a data structure that operates on the stack up to some limit.
2438 auto outer_selected_idx = std::vector<int64_t>();
2439 auto inner_selected_idx = std::vector<int64_t>();
2440 int64_t res_len = 0;
2441 for (const auto i : c10::irange(outer.numel())) {
2442 for (const auto j : c10::irange(inner.numel())) {
2443 if (ptr_outer[i] == ptr_inner[j]) {
2444 ++res_len;
2445 outer_selected_idx.push_back(i);
2446 inner_selected_idx.push_back(j);
2447 }
2448 }
2449 }
2450
2451 const auto outer_selected_idx_tensor = at::from_blob(
2452 outer_selected_idx.data(), {res_len}, at::kLong
2453 );
2454 const auto inner_selected_idx_tensor = at::from_blob(
2455 inner_selected_idx.data(), {res_len}, at::kLong
2456 );
2457
2458 return dim_indices_in_inner_loop
2459 ? make_output(inner_selected_idx_tensor, outer_selected_idx_tensor)
2460 : make_output(outer_selected_idx_tensor, inner_selected_idx_tensor);
2461 };
2462
2463 constexpr int64_t BRUTE_FORCE_SIZE_LIMIT = 2 << 14; // 16384
2464     // NOTE: this condition is chosen to avoid overflow in (nnz * index_len).
2465 if (nnz <= BRUTE_FORCE_SIZE_LIMIT && index_len <= BRUTE_FORCE_SIZE_LIMIT
2466 && (nnz * index_len) <= BRUTE_FORCE_SIZE_LIMIT) {
2467 return get_result_small_nnz_small_index();
2468 }
2469 else {
2470 Tensor selected_dim_indices;
2471 Tensor res_dim_indices;
2472
2473 // A more precise decision could be of the form:
2474 // `nnz < C(nnz, size) * size`, but it requires heavy benchmarking.
2475 // We choose `nnz < size`, which measures theoretical complexity
2476 // and does not rely on runtime performance.
2477 // TODO: perform this analysis and find better C(nnz, size).
2478 if (nnz <= size) {
2479 std::tie(selected_dim_indices, res_dim_indices) = get_selected_indices_small_nnz_large_size();
2480 }
2481 else {
2482 std::tie(selected_dim_indices, res_dim_indices) = get_selected_indices_large_nnz_small_size();
2483 }
2484
2485 return make_output(selected_dim_indices, res_dim_indices);
2486 }
2487 }
2488 // If indexing into dense dimensions
2489 else {
2490 // It is sufficient to just perform `index_select` on values
2491 // if `dim` refers to dense dimensions.
2492 const auto res_values = index_select(values, dim - sparse_dim + 1, index);
2493
2494 return _sparse_coo_tensor_with_dims_and_tensors(
2495 sparse_dim, dense_dim, res_sizes, indices, res_values, self.options());
2496 }
2497 }
2498
2499 Tensor slice(
2500 const Tensor& self,
2501 int64_t dim,
2502 std::optional<int64_t> start,
2503 std::optional<int64_t> end,
2504 int64_t step) {
2505 int64_t ndim = self.dim();
2506 if (ndim == 0) {
2507 TORCH_CHECK_INDEX(false, "slice() cannot be applied to a 0-dim tensor.");
2508 }
2509 dim = maybe_wrap_dim(dim, ndim);
2510 DimVector sizes(self.sizes().begin(), self.sizes().end());
2511 DimVector strides(self.strides().begin(), self.strides().end());
2512 // handle optional parameters
2513 int64_t start_val = start.has_value() ? start.value() : 0;
2514 int64_t end_val = end.has_value() ? end.value() : INT64_MAX;
2515
2516 // TODO: support negative strides
2517 TORCH_CHECK(step > 0, "slice step must be positive");
2518
2519 if (start_val < 0) {
2520 start_val += sizes[dim];
2521 }
2522 if (end_val < 0) {
2523 end_val += sizes[dim];
2524 }
2525 if (start_val < 0) {
2526 start_val = 0;
2527 } else if (start_val >= sizes[dim]) {
2528 start_val = sizes[dim];
2529 }
2530 if (end_val < start_val) {
2531 end_val = start_val;
2532 } else if (end_val >= sizes[dim]) {
2533 end_val = sizes[dim];
2534 }
2535 auto storage_offset = self.storage_offset() + start_val * strides[dim];
2536 auto len = end_val - start_val;
2537 sizes[dim] = (len + step - 1) / step; // round-up
2538 strides[dim] *= step;
2539
2540 Tensor result;
2541 if (self.is_quantized()) {
2542 auto quantizer = create_subtensor_quantizer(self, false, start_val, end_val, dim, step);
2543 result = as_strided_qtensorimpl(self, sizes, strides, storage_offset, std::move(quantizer));
2544 } else {
2545 // NB: it is extremely important to perform a redispatch here for
2546 // the MPS backend; if you call directly into as_strided_tensorimpl,
2547 // the necessary metadata for MPS will not get set up and you will
2548 // get silently wrong results
2549 result = self.as_strided(sizes, strides, storage_offset);
2550 }
2551 namedinference::propagate_names(result, self);
2552 return result;
2553 }
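// Illustrative sketch of the clamping semantics above, for a 1-D tensor
// (assuming the usual ATen API; tensor names are just examples):
//   auto t = at::arange(10);          // [0, 1, ..., 9]
//   auto a = t.slice(0, 2, 8, 2);     // start=2, end=8, step=2 -> {2, 4, 6}
//   auto b = t.slice(0, -3);          // negative start wraps to 7 -> {7, 8, 9}
//   auto c = t.slice(0, 5, 2);        // end < start clamps to start -> empty
// All results share storage with `t`; only sizes/strides/storage_offset change.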
2554
2555 Tensor slice_inverse_symint(
2556 const Tensor& self,
2557 const Tensor& base,
2558 int64_t /* dim */,
2559 std::optional<SymInt> /* start */,
2560 std::optional<SymInt> /* end */,
2561 SymInt /* step */) {
2562 // assume self has enough storage to be viewed with base's metadata
2563 return self.as_strided_symint(base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
2564 }
2565
2566 Tensor slice_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
2567 auto grad_input = at::zeros(input_sizes, grad.options());
2568 grad_input.slice(dim, start, end, step).copy_(grad);
2569 return grad_input;
2570 }
2571
2572 std::vector<Tensor> split(const Tensor& self, int64_t split_size, int64_t dim) {
2573 const auto num_splits = get_num_splits(self, split_size, dim);
2574 std::vector<Tensor> splits(num_splits);
2575 int64_t last_split_size = split_size - (split_size * num_splits - self.size(dim));
2576
2577 for (const auto i : c10::irange(num_splits)) {
2578 auto length = i < num_splits - 1 ? split_size : last_split_size;
2579 splits[i] = self.narrow(dim, i * split_size, length);
2580 }
2581 return splits;
2582 }
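// Illustrative sketch (assuming the usual ATen API): splitting 10 elements
// with split_size=3 produces chunks of 3, 3, 3 and a final chunk of 1, which
// is the `last_split_size` computed above.
//   auto t = at::arange(10);
//   auto parts = at::split(t, /*split_size=*/3, /*dim=*/0);
//   // parts.size() == 4; parts[3].numel() == 1
// Each chunk is a narrow() view into `t`, so no data is copied.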
2583
2584 std::vector<Tensor> split_symint(const Tensor& self, c10::SymIntArrayRef sizes, int64_t dim) {
2585 return at::split_with_sizes_symint(self, sizes, dim);
2586 }
2587
2588 std::vector<Tensor> unsafe_split(const Tensor& self, int64_t split_size, int64_t dim) {
2589 auto result = at::native::split(self, split_size, dim);
2590 for (auto& t : result) {
2591 // TODO(Ailing): do we need to set version_counter here?
2592 if (!t.is_inference()) {
2593 t.unsafeGetTensorImpl()->set_version_counter(c10::VariableVersion(/*version=*/0));
2594 }
2595 }
2596 return result;
2597 }
2598
2599 std::vector<Tensor> hsplit(const Tensor& self, int64_t split_size) {
2600 TORCH_CHECK(self.dim() >= 1, "torch.hsplit requires a tensor with at least 1 dimension, but got a tensor with ", self.dim(), " dimensions!")
2601 int64_t dim = (self.dim() == 1) ? 0 : 1;
2602 TORCH_CHECK(split_size != 0 && self.sym_sizes()[dim] % split_size == 0,
2603 "torch.hsplit attempted to split along dimension ", dim,", but the size of the dimension ", self.sizes()[dim], " is not divisible by the split_size ", split_size, "!");
2604 return at::tensor_split(self, split_size, dim);
2605 }
2606
2607 std::vector<Tensor> vsplit(const Tensor& self, int64_t split_size) {
2608 TORCH_CHECK(self.dim() >= 2, "torch.vsplit requires a tensor with at least 2 dimensions, but got a tensor with ", self.dim(), " dimensions!")
2609 TORCH_CHECK(split_size != 0 && self.sym_sizes()[0] % split_size == 0,
2610 "torch.vsplit attempted to split along dimension ", 0,", but the size of the dimension ", self.sizes()[0], " is not divisible by the split_size ", split_size, "!");
2611 return at::tensor_split(self, split_size, 0);
2612 }
2613
2614 std::vector<Tensor> dsplit(const Tensor& self, int64_t split_size) {
2615 TORCH_CHECK(self.dim() >= 3, "torch.dsplit requires a tensor with at least 3 dimensions, but got a tensor with ", self.dim(), " dimensions!")
2616 TORCH_CHECK(split_size != 0 && self.sym_sizes()[2] % split_size == 0,
2617 "torch.dsplit attempted to split along dimension ", 2,", but the size of the dimension ", self.sizes()[2], " is not divisible by the split_size ", split_size, "!");
2618 return at::tensor_split(self, split_size, 2);
2619 }
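// Illustrative sketch (assuming the usual ATen API): the h/v/dsplit overloads
// above just dispatch to tensor_split along dim 0/1/2 after the divisibility
// check.
//   auto m = at::arange(12).reshape({3, 4});
//   auto cols = at::hsplit(m, 2);   // two (3, 2) views along dim 1
//   auto rows = at::vsplit(m, 3);   // three (1, 4) views along dim 0
//   // at::hsplit(m, 3) would throw: 4 is not divisible by 3.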
2620
2621 std::vector<Tensor> split_with_sizes(const Tensor& self, IntArrayRef split_sizes, int64_t dim) {
2622 TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor");
2623 const int64_t dim_size = self.size(dim);
2624 const int64_t num_splits = split_sizes.size();
2625 int64_t start_idx = 0;
2626
2627 std::vector<Tensor> splits;
2628 splits.reserve(num_splits);
2629 for (const auto i : c10::irange(num_splits)) {
2630 auto length = split_sizes[i];
2631 TORCH_CHECK(length >= 0,
2632 "split_with_sizes expects split_sizes to have only non-negative ",
2633 "entries, but got split_sizes=", split_sizes);
2634 splits.push_back(at::native::slice(self, dim, start_idx, start_idx + length, 1));
2635 start_idx += length;
2636 }
2637 TORCH_CHECK(start_idx == dim_size,
2638 "split_with_sizes expects split_sizes to sum exactly to ", dim_size,
2639 " (input tensor's size at dimension ", dim, "), ", "but got split_sizes=", split_sizes);
2640 return splits;
2641 }
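// Illustrative sketch (assuming the usual ATen API): the sections must sum to
// the size of `dim`, otherwise the check above fires.
//   auto t = at::arange(10);
//   auto parts = at::split_with_sizes(t, {2, 3, 5}, 0);
//   // parts[0] = {0, 1}, parts[1] = {2, 3, 4}, parts[2] = {5, ..., 9}
//   // at::split_with_sizes(t, {2, 3}, 0) would throw: 2 + 3 != 10.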
2642
2643 std::vector<Tensor> unsafe_split_with_sizes(const Tensor& self, IntArrayRef split_sizes, int64_t dim) {
2644 auto result = at::native::split_with_sizes(self, split_sizes, dim);
2645 for (auto& t : result) {
2646 // TODO(Ailing): do we need to set version_counter here?
2647 if (!t.is_inference()) {
2648 t.unsafeGetTensorImpl()->set_version_counter(c10::VariableVersion(/*version=*/0));
2649 }
2650 }
2651 return result;
2652 }
2653
2654 std::vector<Tensor> hsplit(const Tensor& self, IntArrayRef split_sizes) {
2655 TORCH_CHECK(self.dim() >= 1, "torch.hsplit requires a tensor with at least 1 dimension, but got a tensor with ", self.dim(), " dimensions!")
2656 return at::tensor_split(self, split_sizes, (self.dim() == 1) ? 0 : 1);
2657 }
2658
2659 std::vector<Tensor> vsplit(const Tensor& self, IntArrayRef split_sizes) {
2660 TORCH_CHECK(self.dim() >= 2, "torch.vsplit requires a tensor with at least 2 dimensions, but got a tensor with ", self.dim(), " dimensions!")
2661 return at::tensor_split(self, split_sizes, 0);
2662 }
2663
2664 std::vector<Tensor> dsplit(const Tensor& self, IntArrayRef split_sizes) {
2665 TORCH_CHECK(self.dim() >= 3, "torch.dsplit requires a tensor with at least 3 dimensions, but got a tensor with ", self.dim(), " dimensions!")
2666 return at::tensor_split(self, split_sizes, 2);
2667 }
2668
2669 // Precondition: tensors is non-empty
2670 static inline std::vector<Tensor> get_stack_inputs(TensorList tensors, int64_t dim) {
2671 std::vector<Tensor> inputs(tensors.size());
2672 at::IntArrayRef entry_shape = tensors[0].sizes();
2673 inputs[0] = tensors[0].unsqueeze(dim);
2674 for (const auto i : c10::irange(1, tensors.size())) {
2675 TORCH_CHECK(tensors[i].sizes() == entry_shape,
2676 "stack expects each tensor to be of equal size, but got ", entry_shape,
2677 " at entry 0 and ", tensors[i].sizes(), " at entry ", i);
2678 inputs[i] = tensors[i].unsqueeze(dim);
2679 }
2680 return inputs;
2681 }
2682
2683 bool inline maybe_native_stack(Tensor& result, TensorList tensors, int64_t dim) {
2684 dim = maybe_wrap_dim(dim, tensors[0].dim() + 1);
2685 if (detail::CanUseNativeSerialStack<TensorList, /*skip_overlap_check*/ false>::call(result, tensors, dim)) {
2686 // compute the size of the result
2687 auto result_sizes = tensors[0].sizes().vec();
2688 result_sizes.insert(result_sizes.begin() + dim, tensors.size());
2689
2690 // skip resizing if size of result is same as expected
2691 // raise a warning while resizing if output has one or more elements
2692 // at::native::resize_output(result, result_sizes);
2693 // TODO: restore the above, see https://github.com/pytorch/pytorch/issues/64709
2694
2695 if (result.sizes() != result_sizes) {
2696 result.resize_(result_sizes);
2697 }
2698
2699 stack_serial_stub(kCPU, result, tensors, dim);
2700 return true;
2701 }
2702 return false;
2703 }
2704
2705 Tensor _stack(TensorList tensors, int64_t dim) {
2706 ScalarType high_type = result_type(tensors);
2707 Tensor result = at::empty({0}, tensors[0].options().dtype(high_type));
2708 return at::native::_stack_out(get_stack_inputs(tensors, dim), dim, result);
2709 }
2710
2711 Tensor _stack_cpu(TensorList tensors, int64_t dim) {
2712 ScalarType high_type = result_type(tensors);
2713 Tensor result = at::empty({0}, tensors[0].options().dtype(high_type));
2714 return at::native::_stack_out_cpu(tensors, dim, result);
2715 }
2716
2717 static void check_stack_inputs(TensorList tensors, int64_t dim) {
2718 at::IntArrayRef entry_shape = tensors[0].sizes();
2719 for (const auto i : c10::irange(1, tensors.size())) {
2720 TORCH_CHECK(tensors[i].sizes() == entry_shape,
2721 "stack expects each tensor to be of equal size, but got ", entry_shape,
2722 " at entry 0 and ", tensors[i].sizes(), " at entry ", i);
2723 }
2724 }
2725
2726 // Pads each tensor on `dim`-th dimension such that padded_dim % num_chunks == 0.
2727 static std::vector<Tensor> _pad_chunk(TensorList tensors, int64_t dim, int64_t num_chunks) {
2728 auto num_tensors = tensors.size();
2729 std::vector<Tensor> padded_tensors;
2730 padded_tensors.reserve(num_tensors);
2731 for (const auto & tensor : tensors) {
2732 auto tensor_size = tensor.sizes();
2733 std::vector<int64_t> padded_size(tensor_size.vec());
2734 padded_size[dim] = (tensor_size[dim] + num_chunks - 1) / num_chunks * num_chunks;
2735 Tensor padded_tensor = tensor;
2736 if (padded_size != tensor_size) {
2737 padded_tensor = tensor.new_zeros(padded_size);
2738 padded_tensor.narrow(dim, 0, tensor_size[dim]).copy_(tensor);
2739 }
2740 std::vector<int64_t> view_sizes(tensor_size.begin(), tensor_size.begin()+dim);
2741 view_sizes.insert(view_sizes.end(), {num_chunks, -1});
2742 padded_tensors.push_back(padded_tensor.view(view_sizes));
2743 }
2744 return padded_tensors;
2745 }
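// Illustrative sketch of what _pad_chunk does for dim=0, num_chunks=2 on a
// length-5 tensor (values shown for a hypothetical at::arange(5) input):
//   size {5} is padded to {6} with a trailing zero, then viewed as {2, 3}:
//     [[0, 1, 2],
//      [3, 4, 0]]
// _chunk_cat then concatenates these padded views along dim+1, so every input
// contributes num_chunks equal-length rows to the result.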
2746
2747 Tensor _chunk_cat(TensorList tensors, int64_t dim, int64_t num_chunks) {
2748 auto wrapped_dim = at::native::preprocess_chunk_cat_inputs(tensors, dim, num_chunks);
2749 return at::cat(_pad_chunk(tensors, wrapped_dim, num_chunks), wrapped_dim+1);
2750 }
2751
2752 Tensor& _chunk_cat_out(TensorList tensors, int64_t dim, int64_t num_chunks, Tensor& out) {
2753 auto wrapped_dim = at::native::preprocess_chunk_cat_inputs(tensors, dim, num_chunks);
2754 at::cat_out(out, _pad_chunk(tensors, wrapped_dim, num_chunks), wrapped_dim+1);
2755 return out;
2756 }
2757
2758 // TODO(msubkhankulov): refactor to use _stack
2759 Tensor stack(TensorList tensors, int64_t dim) {
2760 TORCH_CHECK(!tensors.empty(),
2761 "stack expects a non-empty TensorList");
2762 auto wrapped_dim = maybe_wrap_dim(dim, tensors[0].ndimension()+1);
2763 if (wrapped_dim < tensors[0].ndimension() && !tensors[0].is_sparse()) {
2764 check_stack_inputs(tensors, wrapped_dim);
2765 auto result_sizes = tensors[0].sizes().vec();
2766 result_sizes.insert(result_sizes.begin() + wrapped_dim, tensors.size());
2767 auto out = at::cat(tensors, wrapped_dim);
2768 return out.view(result_sizes); // one can always split a dimension with view
2769 } else { //dim = tensors[0].ndimension() cannot be efficiently handled by view
2770 return at::cat(get_stack_inputs(tensors, dim), dim);
2771 }
2772 }
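// Illustrative sketch of the cat + view trick above (assuming the usual ATen
// API): stacking two (2, 3) tensors along dim=1 is implemented as cat along
// dim=1 (shape (2, 6)) followed by a view to the stacked shape (2, 2, 3).
//   auto a = at::rand({2, 3});
//   auto b = at::rand({2, 3});
//   auto s = at::stack({a, b}, /*dim=*/1);           // shape (2, 2, 3)
//   // equivalent to: at::cat({a, b}, 1).view({2, 2, 3})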
2773
2774 // CPU specific implementation
2775 Tensor& _stack_out_cpu(TensorList tensors, int64_t dim, Tensor& result) {
2776 if (maybe_native_stack(result, tensors, dim)) {
2777 return result;
2778 } else {
2779 return at::cat_out(result, get_stack_inputs(tensors, dim), dim);
2780 }
2781 }
2782
2783 // default backend
2784 Tensor& _stack_out(TensorList tensors, int64_t dim, Tensor& result) {
2785 return at::cat_out(result, tensors, dim);
2786 }
2787
2788 // TODO(msubkhankulov): refactor to use _stack_out
2789 Tensor& stack_out(TensorList tensors, int64_t dim, Tensor& result) {
2790 TORCH_CHECK(!tensors.empty(),
2791 "stack expects a non-empty TensorList");
2792 auto wrapped_dim = maybe_wrap_dim(dim, tensors[0].ndimension()+1);
2793 if (wrapped_dim < tensors[0].ndimension() && !tensors[0].is_sparse()) {
2794 check_stack_inputs(tensors, wrapped_dim);
2795 auto result_sizes = tensors[0].sizes().vec();
2796 result_sizes.insert(result_sizes.begin() + wrapped_dim, tensors.size());
2797 at::native::resize_output(result, result_sizes);
2798 auto cat_sizes = tensors[0].sizes().vec();
2799 cat_sizes[wrapped_dim] *= tensors.size();
2800 auto strides = at::detail::computeStride(result.sizes(), result.strides(), cat_sizes);
2801 if (strides.has_value()) {
2802 //can take fast cat path
2803 auto result_view = result.view(cat_sizes);
2804 at::cat_out(result_view, tensors, wrapped_dim);
2805 return result;
2806 }
2807 }
2808 return at::cat_out(result, get_stack_inputs(tensors, dim), dim);
2809
2810 }
2811
2812 Tensor hstack(TensorList tensors) {
2813 TORCH_CHECK(!tensors.empty(),
2814 "hstack expects a non-empty TensorList");
2815 auto rep = at::atleast_1d(tensors);
2816 if (rep[0].dim() == 1) {
2817 return at::cat(rep, 0);
2818 }
2819 return at::cat(rep, 1);
2820 }
2821
2822 Tensor& hstack_out(TensorList tensors, Tensor& result) {
2823 TORCH_CHECK(!tensors.empty(),
2824 "hstack expects a non-empty TensorList");
2825 auto rep = at::atleast_1d(tensors);
2826 if (rep[0].dim() == 1) {
2827 return at::cat_out(result, rep, 0);
2828 }
2829 return at::cat_out(result, rep, 1);
2830 }
2831
2832 Tensor vstack(TensorList tensors) {
2833 TORCH_CHECK(!tensors.empty(),
2834 "vstack expects a non-empty TensorList");
2835 auto rep = at::atleast_2d(tensors);
2836 return at::cat(rep, 0);
2837 }
2838
2839 Tensor& vstack_out(TensorList tensors, Tensor& result) {
2840 TORCH_CHECK(!tensors.empty(),
2841 "vstack expects a non-empty TensorList");
2842 auto rep = at::atleast_2d(tensors);
2843 return at::cat_out(result, rep, 0);
2844 }
2845
2846 Tensor dstack(TensorList tensors) {
2847 TORCH_CHECK(!tensors.empty(),
2848 "dstack expects a non-empty TensorList");
2849 auto rep = at::atleast_3d(tensors);
2850 return at::cat(rep, 2);
2851 }
2852 Tensor& dstack_out(TensorList tensors, Tensor& result) {
2853 TORCH_CHECK(!tensors.empty(),
2854 "dstack expects a non-empty TensorList");
2855 auto rep = at::atleast_3d(tensors);
2856 return at::cat_out(result, rep, 2);
2857 }
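// Illustrative sketch (assuming the usual ATen API): for 1-D inputs the three
// helpers differ only in how they promote dimensions before the cat.
//   auto x = at::arange(3);
//   auto y = at::arange(3) + 3;
//   at::hstack({x, y});   // shape (6)       -- cat along dim 0
//   at::vstack({x, y});   // shape (2, 3)    -- atleast_2d, then cat along dim 0
//   at::dstack({x, y});   // shape (1, 3, 2) -- atleast_3d, then cat along dim 2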
2858
2859 static inline Tensor & sparse_transpose_(Tensor & self, int64_t dim0, int64_t dim1) {
2860 int64_t nsparse_dim = self.sparse_dim();
2861 TORCH_CHECK(dim0 < nsparse_dim && dim1 < nsparse_dim,
2862 "sparse transpose: transposed dimensions must be sparse ",
2863 "Got sparse_dim: ", nsparse_dim, ", d0: ", dim0, ", d1: ", dim1);
2864
2865 if (self._indices().numel() == 0 && self._values().numel() == 0) {
2866 auto sizes = self.sizes().vec();
2867 std::swap(sizes[dim0], sizes[dim1]);
2868
2869 at::sparse::get_sparse_impl(self)->raw_resize_(self.sparse_dim(), self.dense_dim(), sizes);
2870 } else {
2871 auto indices = self._indices();
2872 auto row0 = indices.select(0, dim0);
2873 auto row1 = indices.select(0, dim1);
2874
2875 // swap row0 and row1
2876 auto tmp = at::zeros_like(row0, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
2877 tmp.copy_(row0);
2878 row0.copy_(row1);
2879 row1.copy_(tmp);
2880
2881 self._coalesced_(false);
2882
2883 auto sizes = self.sizes().vec();
2884 std::swap(sizes[dim0], sizes[dim1]);
2885
2886 at::sparse::get_sparse_impl(self)->raw_resize_(self._indices().size(0), self._values().dim() - 1, sizes);
2887 }
2888 return self;
2889 }
2890
2891 // torch.row_stack, alias for torch.vstack
2892 Tensor& row_stack_out(TensorList tensors, Tensor& result) {
2893 return at::vstack_out(result, tensors);
2894 }
2895
2896 Tensor row_stack(TensorList tensors) {
2897 return at::vstack(tensors);
2898 }
2899
2900 static std::vector<Tensor> reshape_input_for_column_stack(TensorList tensors) {
2901 std::vector<Tensor> result(tensors.size());
2902 auto transform_lambda = [](const Tensor& input) -> Tensor {
2903 // reshape 0D or 1D tensor t into (t.numel(), 1)
2904 if (input.dim() <= 1) {
2905 return input.reshape_symint({input.sym_numel(), 1});
2906 }
2907 return input;
2908 };
2909 std::transform(tensors.cbegin(),
2910 tensors.cend(),
2911 result.begin(),
2912 transform_lambda);
2913 return result;
2914 }
2915
2916 Tensor& column_stack_out(TensorList tensors, Tensor& result) {
2917 TORCH_CHECK(!tensors.empty(),
2918 "column_stack expects a non-empty TensorList");
2919
2920 auto reshaped_tensors = reshape_input_for_column_stack(tensors);
2921 return at::hstack_out(result, reshaped_tensors);
2922 }
2923
2924 Tensor column_stack(TensorList tensors) {
2925 TORCH_CHECK(!tensors.empty(),
2926 "column_stack expects a non-empty TensorList");
2927
2928 auto reshaped_tensors = reshape_input_for_column_stack(tensors);
2929 return at::hstack(reshaped_tensors);
2930 }
2931
2932 static Tensor& propagate_transposed_names(
2933 Tensor& result,
2934 const Tensor& other,
2935 int64_t dim0,
2936 int64_t dim1) {
2937 if (other.has_names()) {
2938 auto names = other.names().vec();
2939 std::swap(names[dim0], names[dim1]);
2940 namedinference::propagate_names_if_nonempty(result, names);
2941 }
2942 return result;
2943 }
2944
2945 Tensor transpose(const Tensor& self, Dimname dim0, Dimname dim1) {
2946 return at::transpose(
2947 self, dimname_to_position(self, dim0), dimname_to_position(self, dim1));
2948 }
2949
2950
2951 Tensor & transpose_(Tensor & self, int64_t dim0, int64_t dim1) {
2952 TORCH_CHECK(
2953 !(self.layout() == kSparseCsr || self.layout() == kSparseCsc ||
2954 self.layout() == kSparseBsr || self.layout() == kSparseBsc),
2955 "torch.transpose_: in-place transposition is not supported for ",
2956 self.layout(),
2957 " layout");
2958
2959 auto ndims = self.dim();
2960 dim0 = maybe_wrap_dim(dim0, ndims);
2961 dim1 = maybe_wrap_dim(dim1, ndims);
2962 if (dim0 == dim1) {
2963 return self;
2964 }
2965
2966 // Sparse COO is an exceptional sparse format as it allows transpose
2967 // to be a view operation which is a convenient property for
2968 // in-place operations. For other sparse formats, the in-place
2969 // transpose would not be possible without shuffling the specified
2970 // values. So we don't support this, as it would defeat the purpose
2971 // of in-place operations, which is to be memory-efficient.
2972 if (self.is_sparse()) {
2973 return sparse_transpose_(self, dim0, dim1);
2974 }
2975
2976 if (self.is_mkldnn()) {
2977 return at::_mkldnn_transpose_(self, dim0, dim1);
2978 }
2979
2980 DimVector sizes(self.sizes().begin(), self.sizes().end());
2981 DimVector strides(self.strides().begin(), self.strides().end());
2982 std::swap(strides[dim0], strides[dim1]);
2983 std::swap(sizes[dim0], sizes[dim1]);
2984 self.as_strided_(sizes, strides);
2985 return self;
2986 }
2987
2988 namespace {
2989 // Transpose implementation for sparse compressed layouts
2990 // NB: We assume that dim1,dim0 have already been wrapped
2991 static inline Tensor sparse_compressed_transpose(
2992 const Tensor& self,
2993 int64_t dim0,
2994 int64_t dim1) {
2995 auto compressed_inds = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
2996 self.layout(),
2997 "compressed_inds",
2998 [&self]() { return self.crow_indices(); },
2999 [&self]() { return self.ccol_indices(); });
3000
3001 auto plain_inds = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
3002 self.layout(),
3003 "plain_inds",
3004 [&self]() { return self.col_indices(); },
3005 [&self]() { return self.row_indices(); });
3006
3007 const auto n_batch_dim = compressed_inds.dim() - 1;
3008 const auto dense_dim = self.dim() - n_batch_dim - 2;
3009
3010 // In theory this works, but to_dense coverage is missing to test it
3011 TORCH_CHECK(
3012 dense_dim == 0,
3013 "transpose(): hybrid sparse compressed tensors with dense dimensions are not supported");
3014
3015 // Classify transpose "type"
3016 enum class TransposeDim : uint8_t { Batch, Sparse, Dense };
3017 auto classify_dim = [&n_batch_dim](const int64_t dim) {
3018 if (dim < n_batch_dim) {
3019 return TransposeDim::Batch;
3020 } else if (dim > n_batch_dim + 1) {
3021 return TransposeDim::Dense;
3022 } else {
3023 return TransposeDim::Sparse;
3024 }
3025 };
3026
3027 const auto transpose_type = classify_dim(dim0);
3028 {
3029 #ifndef STRIP_ERROR_MESSAGES
3030 auto dim_type_name = [](const TransposeDim dim) {
3031 switch (dim) {
3032 case TransposeDim::Batch:
3033 return "Batch";
3034 case TransposeDim::Dense:
3035 return "Dense";
3036 case TransposeDim::Sparse:
3037 return "Sparse";
3038 default:
3039 TORCH_INTERNAL_ASSERT(
3040 false,
3041 "Impossible TransposeDim value: ",
3042 static_cast<std::underlying_type_t<TransposeDim>>(dim));
3043 }
3044 };
3045 #endif
3046 const auto dim1_type = classify_dim(dim1);
3047 TORCH_CHECK(
3048 dim1_type == transpose_type,
3049 "transpose(): can only transpose dimensions of the same type (Batch, Sparse, Dense), got ",
3050 dim0,
3051 "(",
3052 dim_type_name(transpose_type),
3053 ")",
3054 " and ",
3055 dim1,
3056 "(",
3057 dim_type_name(dim1_type),
3058 ")");
3059 }
3060
3061 // We have validated everything, early exit for equal dims (no effect)
3062 if (dim0 == dim1) {
3063 return self.clone();
3064 }
3065
3066 auto result_sizes = DimVector(self.sizes());
3067 std::swap(result_sizes[dim0], result_sizes[dim1]);
3068 Tensor result_vals;
3069 auto result_layout = self.layout();
3070
3071 if (transpose_type == TransposeDim::Batch) {
3072 compressed_inds = compressed_inds.transpose(dim0, dim1).contiguous();
3073 plain_inds = plain_inds.transpose(dim0, dim1).contiguous();
3074 result_vals = self.values().transpose(dim0, dim1).contiguous();
3075
3076 } else if (transpose_type == TransposeDim::Dense) {
3077 // NB: This code should work, but is untestable due to lack of support for
3078 // dense dimensions in to_dense. The debug assert is present to emphasize
3079 // that it should not be possible to reach this code block.
3080 TORCH_INTERNAL_ASSERT(
3081 false, "transpose(): Shouldn't have reached this point");
3082 result_vals = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(
3083 self.layout(),
3084 "sparse_transpose",
3085 // un-blocked: 2 sparse dims map to single nnz dim, so dense dim0/1 are
3086 // one position left
3087 [&]() { return self.values().transpose(dim0 - 1, dim1 - 1); },
3088 // blocked: 2 sparse dims map to 3 (nnz, ) + blocksize dims, so dense
3089 // dim0/1 are one position right
3090 [&]() { return self.values().transpose(dim0 + 1, dim1 + 1); });
3091 } else /*if (transpose_type == TransposeDim::Sparse) */ {
3092 // Flip the layout
3093 result_layout = sparse_csr::flip_compressed_layout(self.layout());
3094 result_vals = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(
3095 self.layout(),
3096 "sparse_transpose",
3097 // un-blocked: no change to values, layout is flipped.
3098 [&]() { return self.values(); },
3099 // blocked: the blocks are nested under the sparse dims so they must be
3100 // transposed as well.
3101 [&]() {
3102 return self.values().transpose(-2 - dense_dim, -1 - dense_dim);
3103 });
3104 }
3105 return at::_sparse_compressed_tensor_unsafe(
3106 compressed_inds,
3107 plain_inds,
3108 result_vals,
3109 result_sizes,
3110 self.options().layout(result_layout));
3111 }
3112 } // namespace
3113
3114 Tensor transpose(const Tensor & self, int64_t dim0, int64_t dim1) {
3115 auto ndims = self.dim();
3116 dim0 = maybe_wrap_dim(dim0, ndims);
3117 dim1 = maybe_wrap_dim(dim1, ndims);
3118
3119 if (self.is_sparse()) {
3120 if (dim0 == dim1) {
3121 return self.clone();
3122 }
3123 Tensor self_clone = self.clone();
3124 return sparse_transpose_(self_clone, dim0, dim1);
3125 }
3126 if (self.layout() == kSparseBsr || self.layout() == kSparseCsr ||
3127 self.layout() == kSparseBsc || self.layout() == kSparseCsc) {
3128 return sparse_compressed_transpose(self, dim0, dim1);
3129 }
3130
3131 if (self.is_mkldnn()) {
3132 return at::_mkldnn_transpose(self, dim0, dim1);
3133 }
3134
3135 // Transpose of a tensor is a view operation.
3136 if (dim0 == dim1) {
3137 return self.alias();
3138 }
3139
3140 SymDimVector sizes(self.sym_sizes().begin(), self.sym_sizes().end());
3141 std::swap(sizes[dim0], sizes[dim1]);
3142 SymDimVector strides(self.sym_strides().begin(), self.sym_strides().end());
3143 std::swap(strides[dim0], strides[dim1]);
3144 auto result = self.as_strided_symint(sizes, strides);
3145 propagate_transposed_names(result, self, dim0, dim1);
3146 return result;
3147 }
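// Illustrative sketch (assuming the usual ATen API): for strided tensors the
// code above only swaps sizes/strides, so the result aliases the input.
//   auto a = at::arange(6).reshape({2, 3});   // sizes (2, 3), strides (3, 1)
//   auto t = a.transpose(0, 1);               // sizes (3, 2), strides (1, 3)
//   // t.is_contiguous() == false; writes through t are visible in a.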
3148
3149 static void check_t(const Tensor& self, const char *fn) {
3150 if (self.is_sparse()) {
3151 int64_t sparse_dim = self.sparse_dim();
3152 int64_t dense_dim = self.dense_dim();
3153 TORCH_CHECK(sparse_dim <= 2 && dense_dim == 0,
3154 fn, " expects a tensor with <= 2 sparse and 0 dense dimensions, but got ",
3155 sparse_dim, " sparse and ", dense_dim, " dense dimensions");
3156 } else {
3157 TORCH_CHECK(self.dim() <= 2,
3158 fn, " expects a tensor with <= 2 dimensions, but self is ", self.dim(), "D");
3159 }
3160 }
3161
3162 Tensor t(const Tensor & self) {
3163 check_t(self, "t()");
3164 return self.transpose(0, self.dim() < 2 ? 0 : 1);
3165 }
3166
3167 Tensor & t_(Tensor & self) {
3168 check_t(self, "t_()");
3169 return self.transpose_(0, self.dim() < 2 ? 0 : 1);
3170 }
3171
3172 std::tuple<SymDimVector, SymDimVector>
3173 static inferSqueezeGeometry(const Tensor &tensor) {
3174 SymDimVector sizes;
3175 SymDimVector strides;
3176
3177 for(const auto d : c10::irange(tensor.dim())) {
3178 if(tensor.sym_sizes()[d] != 1) {
3179 sizes.push_back(tensor.sym_sizes()[d]);
3180 strides.push_back(tensor.sym_strides()[d]);
3181 }
3182 }
3183
3184 return std::make_tuple(std::move(sizes), std::move(strides));
3185 }
3186
3187 std::tuple<SymDimVector, SymDimVector>
3188 static inferSqueezeGeometry(const Tensor& tensor, int64_t dim) {
3189 SymDimVector sizes;
3190 SymDimVector strides;
3191
3192 for(const auto d : c10::irange(tensor.dim())) {
3193 if(d != dim || tensor.sym_sizes()[dim] != 1) {
3194 sizes.push_back(tensor.sym_sizes()[d]);
3195 strides.push_back(tensor.sym_strides()[d]);
3196 }
3197 }
3198 return std::make_tuple(std::move(sizes), std::move(strides));
3199 }
3200
3201 std::tuple<SymDimVector, SymDimVector>
3202 static inferSqueezeGeometry(const Tensor &tensor, std::bitset<dim_bitset_size> dim_mask) {
3203 const auto ndim = tensor.dim();
3204 const auto sym_sizes = tensor.sym_sizes();
3205 const auto sym_strides = tensor.sym_strides();
3206
3207 SymDimVector out_sizes, out_strides;
3208 for (const auto d: c10::irange(ndim)) {
3209 if (!dim_mask.test(d) || sym_sizes[d] != 1) {
3210 out_sizes.push_back(sym_sizes[d]);
3211 out_strides.push_back(sym_strides[d]);
3212 }
3213 }
3214 return std::make_tuple(std::move(out_sizes), std::move(out_strides));
3215 }
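// Illustrative sketch of the mask overload above: only masked dims that
// actually have size 1 are dropped. For a (1, 3, 1, 2) tensor with a dim_mask
// covering dims {0, 1, 2}, dim 1 is kept (its size is 3) while dims 0 and 2
// are removed, giving sizes (3, 2) and the corresponding strides.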
3216
3217 namespace {
3218 // Named type instead of a pair/tuple so that we can be sure to
3219 // construct the vectors in place and get NRVO.
3220 template <typename T>
3221 struct InferUnsqueezeGeometryResult {
3222 SmallVector<T, kDimVectorStaticSize> sizes;
3223 SmallVector<T, kDimVectorStaticSize> strides;
3224 InferUnsqueezeGeometryResult(ArrayRef<T> tensor_sizes, ArrayRef<T> tensor_strides)
3225 : sizes(tensor_sizes.begin(), tensor_sizes.end())
3226 , strides(tensor_strides.begin(), tensor_strides.end()) {}
3227 };
3228
3229 InferUnsqueezeGeometryResult<c10::SymInt>
3230 inferUnsqueezeGeometry_symint(const Tensor& tensor, int64_t dim) {
3231 InferUnsqueezeGeometryResult<c10::SymInt> result(tensor.sym_sizes(), tensor.sym_strides());
3232 c10::SymInt new_stride = dim >= tensor.dim() ? 1 : result.sizes[dim] * result.strides[dim];
3233 result.sizes.insert(result.sizes.begin() + dim, 1);
3234 result.strides.insert(result.strides.begin() + dim, new_stride);
3235
3236 return result;
3237 }
3238
3239 InferUnsqueezeGeometryResult<int64_t>
3240 inferUnsqueezeGeometry(const Tensor& tensor, int64_t dim) {
3241 InferUnsqueezeGeometryResult<int64_t> result(tensor.sizes(), tensor.strides());
3242 int64_t new_stride = dim >= tensor.dim() ? 1 : result.sizes[dim] * result.strides[dim];
3243 result.sizes.insert(result.sizes.begin() + dim, 1);
3244 result.strides.insert(result.strides.begin() + dim, new_stride);
3245
3246 return result;
3247 }
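// Illustrative sketch of the new-stride rule above: unsqueeze inserts a size-1
// dim whose stride is sizes[dim] * strides[dim], so the view stays consistent.
//   For a contiguous (2, 3) tensor with strides (3, 1):
//     unsqueeze(0) -> sizes (1, 2, 3), strides (6, 3, 1)
//     unsqueeze(2) -> sizes (2, 3, 1), strides (3, 1, 1)  // dim == tensor.dim(), stride = 1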
3248
3249 // dim is present if squeezing a single dimension and absent if squeezing all dimensions
3250 Tensor squeeze_qtensor(const Tensor& self, c10::OptionalIntArrayRef dims) {
3251 auto quantizer = get_qtensorimpl(self)->quantizer();
3252 const auto ndim = self.dim();
3253 auto mask = dims.has_value()
3254 ? dim_list_to_bitset(dims, self.dim())
3255 : std::bitset<dim_bitset_size>((1ull << self.dim()) - 1);
3256 auto [sizes, strides] = inferSqueezeGeometry(self, mask);
3257 if (quantizer->qscheme() == QScheme::PER_CHANNEL_AFFINE) {
3258 const auto* per_channel_quantizer = static_cast<at::PerChannelAffineQuantizer*>(quantizer.get());
3259 auto axis = per_channel_quantizer->axis();
3260 int64_t shift = 0;
3261 for (const auto d : c10::irange(ndim)) {
3262 if (mask.test(d) && self.sizes()[d] == 1) {
3263 TORCH_CHECK(axis != d, "Squeeze is only possible on non-axis dimension for Per-Channel Quantized Tensors.");
3264 if (d < axis) {
3265 ++shift;
3266 }
3267 }
3268 }
3269 axis -= shift;
3270 quantizer = make_per_channel_affine_quantizer(per_channel_quantizer->scales(),
3271 per_channel_quantizer->zero_points(),
3272 axis,
3273 quantizer->scalar_type());
3274 }
3275 // TODO: quantized Tensor support for SymInt needs to be added but basic building blocks
3276 // are missing for now.
3277 auto result = make_qtensor(self, C10_AS_INTARRAYREF_SLOW(sizes), C10_AS_INTARRAYREF_SLOW(strides), std::move(quantizer));
3278 auto maybe_outnames = namedinference::compute_squeeze_outnames(self, mask);
3279 namedinference::propagate_names_if_nonempty(result, maybe_outnames);
3280 return result;
3281 }
3282 } // namespace
3283
3284 Tensor squeeze(const Tensor& self) {
3285 auto g = inferSqueezeGeometry(self);
3286 at::Tensor result = self.as_strided_symint(std::get<0>(g), std::get<1>(g));
3287 auto maybe_outnames = namedinference::compute_squeeze_outnames(self);
3288 namedinference::propagate_names_if_nonempty(result, maybe_outnames);
3289 return result;
3290 }
3291
3292 Tensor squeeze_quantized(const Tensor& self) {
3293 return squeeze_qtensor(self, std::nullopt);
3294 }
3295
3296 Tensor squeeze(const Tensor& self, int64_t dim) {
3297 int64_t dims = self.dim();
3298 dim = maybe_wrap_dim(dim, dims);
3299 if (dims == 0 || self.sym_sizes()[dim] != 1) {
3300 return self.as_strided_symint(self.sym_sizes(), self.sym_strides());
3301 }
3302 auto g = inferSqueezeGeometry(self, dim);
3303 auto result = self.as_strided_symint(std::get<0>(g), std::get<1>(g));
3304 namedinference::propagate_names_except(result, self, {dim});
3305 return result;
3306 }
3307
3308 Tensor squeeze_quantized(const Tensor& self, int64_t dim) {
3309 return squeeze_qtensor(self, dim);
3310 }
3311
3312 Tensor squeeze(const Tensor& self, IntArrayRef dims) {
3313 auto mask = dim_list_to_bitset(dims, self.dim());
3314 auto g = inferSqueezeGeometry(self, mask);
3315 at::Tensor result = self.as_strided_symint(std::get<0>(g), std::get<1>(g));
3316 auto maybe_outnames = namedinference::compute_squeeze_outnames(self, mask);
3317 namedinference::propagate_names_if_nonempty(result, maybe_outnames);
3318 return result;
3319 }
3320
3321 Tensor squeeze_quantized(const Tensor& self, IntArrayRef dim) {
3322 return squeeze_qtensor(self, dim);
3323 }
3324
3325 Tensor & squeeze_(Tensor& self) {
3326 auto g = inferSqueezeGeometry(self);
3327 self.as_strided__symint(std::get<0>(g), std::get<1>(g));
3328 return self;
3329 }
3330
3331 Tensor & squeeze_(Tensor& self, int64_t dim) {
3332 int64_t dims = self.dim();
3333 dim = maybe_wrap_dim(dim, self.dim());
3334
3335 if (dims == 0 || self.sym_sizes()[dim] != 1) {
3336 self.as_strided__symint(self.sym_sizes(), self.sym_strides());
3337 return self;
3338 }
3339 auto g = inferSqueezeGeometry(self, dim);
3340 self.as_strided__symint(std::get<0>(g), std::get<1>(g));
3341 return self;
3342 }
3343
3344 Tensor & squeeze_(Tensor &self, IntArrayRef dims) {
3345 auto mask = dim_list_to_bitset(dims, self.dim());
3346 auto g = inferSqueezeGeometry(self, mask);
3347 self.as_strided__symint(std::get<0>(g), std::get<1>(g));
3348 return self;
3349 }
3350
3351 // NOTE [ Unsafe View ]
3352 // _unsafe_view() differs from view() in that the returned tensor isn't treated
3353 // as a view for the purposes of automatic differentiation. (It's not listed in
3354 // VIEW_FUNCTIONS in gen_inplace_or_view_type.py). It's only safe to use if the `self` tensor
3355 // is temporary. For example, the viewed tensor here (a + b) is discarded immediately
3356 // after viewing:
3357 //
3358 // res = at::_unsafe_view(a + b, size);
3359 //
3360 // This is a hack because in-place operations on tensors treated like views
3361 // can be much more expensive than the same operations on non-view tensors.
3362
3363 inline Tensor view_impl(const Tensor& self, IntArrayRef size) {
3364
3365 at::DimVector inferred_size = at::infer_size_dv(size, self.numel());
3366 auto stride = at::detail::computeStride(self.sizes(),
3367 self.strides(),
3368 inferred_size);
3369 TORCH_CHECK(stride.has_value(), "view size is "
3370 "not compatible with input tensor's size and stride (at least one dimension"
3371 " spans across two contiguous subspaces). Use .reshape(...) instead.");
3372 return alias_with_sizes_and_strides(self, inferred_size, *stride);
3373
3374 }
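// Illustrative sketch (assuming the usual ATen API) of when computeStride has
// no solution and the check above fires:
//   auto a = at::rand({2, 3}).transpose(0, 1);  // sizes (3, 2), strides (1, 3)
//   // a.view({6}) throws: the flattened dim would span two non-contiguous
//   // subspaces, so no strides can satisfy the request.
//   auto b = a.reshape({6});                    // reshape copies instead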
3375
3376 Tensor _unsafe_view(const Tensor& self, IntArrayRef size) {
3377 return view_impl(self, size);
3378 }
3379
3380 Tensor unsqueeze(const Tensor& self, int64_t dim) {
3381 dim = maybe_wrap_dim(dim, self.dim() + 1);
3382 auto g = inferUnsqueezeGeometry_symint(self, dim);
3383 return self.as_strided_symint(g.sizes, g.strides);
3384 }
3385
3386 Tensor unsqueeze_sparse(Tensor const &self, int64_t dim) {
3387 dim = maybe_wrap_dim(dim, self.dim() + 1);
3388 int64_t sparse_dim = self.sparse_dim();
3389 int64_t dense_dim = self.dense_dim();
3390 auto indices = self._indices();
3391 auto sizes = self.sizes().vec();
3392 sizes.insert(sizes.begin() + dim, 1);
3393 if (dim <= sparse_dim) {
3394 auto new_indices = at::cat(
3395 {indices.narrow(0, 0, dim),
3396 at::zeros(
3397 {1, indices.size(1)},
3398 kLong,
3399 indices.options().layout_opt(),
3400 indices.options().device_opt(),
3401 indices.options().pinned_memory_opt()),
3402 indices.narrow(0, dim, indices.size(0) - dim)});
3403 return _sparse_coo_tensor_with_dims_and_tensors(
3404 sparse_dim + 1, dense_dim, sizes, new_indices, self._values(), self.options());
3405 } else {
3406 return _sparse_coo_tensor_with_dims_and_tensors(
3407 sparse_dim, dense_dim + 1, sizes, indices, self._values().unsqueeze(dim - sparse_dim + 1), self.options());
3408 }
3409 }
3410
3411 Tensor unsqueeze_quantized(const Tensor& self, int64_t dim) {
3412 dim = maybe_wrap_dim(dim, self.dim() + 1);
3413 auto g = inferUnsqueezeGeometry(self, dim);
3414 auto quantizer = get_qtensorimpl(self)->quantizer();
3415 if (quantizer->qscheme() == QScheme::PER_CHANNEL_AFFINE) {
3416 const auto* per_channel_quantizer = static_cast<at::PerChannelAffineQuantizer*>(quantizer.get());
3417 auto axis = per_channel_quantizer->axis();
3418 if (axis >= dim) {
3419 axis += 1;
3420 }
3421 quantizer = make_per_channel_affine_quantizer(per_channel_quantizer->scales(),
3422 per_channel_quantizer->zero_points(),
3423 axis,
3424 quantizer->scalar_type());
3425 }
3426 return make_qtensor(self, g.sizes, g.strides, std::move(quantizer));
3427 }
3428
3429 Tensor & unsqueeze_(Tensor& self, int64_t dim) {
3430 dim = maybe_wrap_dim(dim, self.dim() + 1);
3431
3432 auto g = inferUnsqueezeGeometry_symint(self, dim);
3433 self.as_strided__symint(g.sizes, g.strides);
3434 return self;
3435 }
3436
3437 Tensor flatten(const Tensor& self, int64_t start_dim, int64_t end_dim) {
3438 start_dim = maybe_wrap_dim(start_dim, self.dim());
3439 end_dim = maybe_wrap_dim(end_dim, self.dim());
3440 TORCH_CHECK(start_dim <= end_dim, "flatten() has invalid args: start_dim cannot come after end_dim");
3441
3442 if (self.dim() == 0) {
3443 return self.reshape({1});
3444 }
3445 if (start_dim == end_dim) {
3446 return self;
3447 }
3448
3449 // We don't want to infer_size on the entire shape, because that can give us an extra degree
3450 // of freedom we don't want; for example, consider shape [0, 1, 3, 0], with start_dim=1, end_dim=2.
3451 // It's clear we want result shape [0, 3, 0] but passing [0, -1, 0] to infer_size means the -1
3452 // can take on any value and satisfy the constraints.
3453 auto slice_numel = c10::multiply_integers(self.sym_sizes().slice(start_dim, end_dim - start_dim + 1));
3454 std::vector<c10::SymInt> shape;
3455 shape.reserve(self.dim() - end_dim + start_dim);
3456 for (const auto i : c10::irange(start_dim)) {
3457 shape.push_back(self.sym_sizes()[i]);
3458 }
3459 shape.push_back(slice_numel);
3460 for (const auto i : c10::irange(end_dim + 1, self.dim())) {
3461 shape.push_back(self.sym_sizes()[i]);
3462 }
3463
3464 return native::reshape_symint(self, shape);
3465 }
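// Illustrative sketch of why the flattened size is computed directly instead
// of passing -1 to infer_size (see the comment above):
//   auto t = at::zeros({0, 1, 3, 0});
//   auto f = t.flatten(1, 2);   // sizes (0, 3, 0)
//   // with a -1 placeholder, any value would satisfy numel() == 0.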
3466
3467 Tensor flatten(const Tensor& self, int64_t start_dim, int64_t end_dim, Dimname out_dim) {
3468 start_dim = maybe_wrap_dim(start_dim, self.dim());
3469 end_dim = maybe_wrap_dim(end_dim, self.dim());
3470 TORCH_CHECK(start_dim <= end_dim, "flatten() has invalid args: start_dim cannot come after end_dim");
3471
3472 auto outnames = self.names().vec();
3473 outnames.erase(outnames.begin() + start_dim, outnames.begin() + end_dim + 1);
3474 outnames.insert(outnames.begin() + start_dim, out_dim);
3475
3476 Tensor result;
3477 {
3478 NoNamesGuard guard;
3479 result = native::flatten(self, start_dim, end_dim);
3480 }
3481 internal_set_names_inplace(result, outnames);
3482 return result;
3483 }
3484
3485 Tensor flatten(const Tensor& self, Dimname start_dim, Dimname end_dim, Dimname out_dim) {
3486 auto start_pos = dimname_to_position(self, start_dim);
3487 auto end_pos = dimname_to_position(self, end_dim);
3488 return native::flatten(self, start_pos, end_pos, out_dim);
3489 }
3490
3491 Tensor flatten(const Tensor& self, DimnameList dims, Dimname out_dim) {
3492 auto positions = dimnames_to_positions(self, dims);
3493 TORCH_CHECK(!positions.empty(),
3494 "flatten(tensor, dims, out_dim): dims cannot be empty");
3495 for (const auto i : c10::irange(positions.size() - 1)) {
3496 if (positions[i] + 1 == positions[i + 1]) continue;
3497 TORCH_CHECK(positions[i] + 1 == positions[i + 1],
3498 "flatten(tensor, dims, out_dim): dims ", dims, " must be consecutive ",
3499 "in Tensor", self.names());
3500 }
3501 return native::flatten(self, *dims.begin(), *(dims.end() - 1), out_dim);
3502 }
3503
3504 Tensor ravel(const Tensor& self) {
3505 return self.contiguous().view(-1);
3506 }
3507
3508 static inline void handle_unflatten_exception(const std::runtime_error &e,
3509 const Tensor &self,
3510 int64_t dim,
3511 SymIntArrayRef sizes,
3512 std::optional <DimnameList> names) {
3513 if (!strstr(e.what(), "is invalid for input of size")) {
3514 TORCH_CHECK(false, "unflatten got an unexpected error:\n", e.what());
3515 }
3516
3517 if (self.has_names()) {
3518 TORCH_CHECK(false,
3519 "unflatten: Provided sizes ", sizes, " don't multiply up to the size of dim ",
3520 dim, " (", self.names()[dim], ": ", self.sym_size(dim), ") in Tensor", self.names());
3521
3522 } else {
3523 TORCH_CHECK(false,
3524 "unflatten: Provided sizes ", sizes, " don't multiply up to the size of dim ",
3525 dim, " (", self.sym_size(dim), ") in the input tensor");
3526 }
3527 }
3528
3529 static Tensor unflatten_impl(const Tensor& self, int64_t dim, SymIntArrayRef sizes, std::optional<DimnameList> names) {
3530 dim = maybe_wrap_dim(dim, self.dim());
3531
3532 TORCH_CHECK(!sizes.empty(), "unflatten: sizes must be non-empty");
3533 TORCH_INTERNAL_ASSERT(!names || names->size() == sizes.size());
3534 if (self.has_names()) {
3535 TORCH_CHECK(names, "unflatten: input is a named tensor but no names were given for unflattened sizes");
3536 }
3537
3538 SymDimVector inferred_size;
3539 try {
3540 inferred_size = at::infer_size_dv(sizes, self.sym_size(dim));
3541 } catch (const std::runtime_error& e) {
3542 // at::infer_size would throw std::runtime_error for invalid size,
3543 // catch the runtime_error and display the error message in a more user-friendly way
3544 // for both tensors and named tensors
3545 handle_unflatten_exception(e, self, dim, sizes, names);
3546 }
3547
3548 SymDimVector shape(self.sym_sizes().begin(), self.sym_sizes().end());
3549 shape.erase(shape.begin() + dim);
3550 shape.insert(shape.begin() + dim, inferred_size.begin(), inferred_size.end());
3551
3552 Tensor result;
3553 {
3554 NoNamesGuard guard;
3555 result = self.view_symint(shape);
3556 }
3557
3558 if (names) {
3559 auto outnames = self.names().vec();
3560 outnames.erase(outnames.begin() + dim);
3561 outnames.insert(outnames.begin() + dim, names->begin(), names->end());
3562 at::internal_set_names_inplace(result, outnames);
3563 }
3564
3565 return result;
3566 }
3567
3568 Tensor unflatten_symint(const Tensor& self, int64_t dim, SymIntArrayRef sizes) {
3569 return native::unflatten_impl(self, dim, sizes, std::nullopt);
3570 }
3571
3572 Tensor unflatten_dimname_symint(const Tensor& self, Dimname dim, SymIntArrayRef sizes, DimnameList names) {
3573 return native::unflatten_impl(self, dimname_to_position(self, dim), sizes, names);
3574 }
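// Illustrative sketch (assuming the usual ATen API): unflatten expands one dim
// into several, optionally inferring a single -1 entry via infer_size above.
//   auto t = at::arange(12).reshape({3, 4});
//   auto u = t.unflatten(1, {2, 2});    // sizes (3, 2, 2)
//   auto v = t.unflatten(1, {2, -1});   // the -1 is inferred as 2
//   // t.unflatten(1, {5, -1}) throws: 4 is not divisible by 5.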
3575
3576 Tensor view_as(const Tensor& self, const Tensor& other) {
3577 return self.view_symint(other.sym_sizes());
3578 }
3579
3580 std::vector<Tensor> unbind(const Tensor &self, int64_t dim) {
3581 dim = maybe_wrap_dim(dim, self.dim());
3582 int64_t size = self.size(dim);
3583 std::vector<Tensor> tensors(size);
3584 for (const auto i : c10::irange(size)) {
3585 tensors[i] = self.select(dim, i);
3586 }
3587 return tensors;
3588 }
3589
3590 std::vector<Tensor> unbind(const Tensor& self, Dimname dim) {
3591 return at::unbind(self, dimname_to_position(self, dim));
3592 }
3593
3594 std::vector<Tensor> meshgrid(TensorList tensors) {
3595 TORCH_WARN_ONCE("torch.meshgrid: in an upcoming release, it will be required to pass the "
3596 "indexing argument.");
3597 return native::meshgrid(tensors, /*indexing=*/"ij");
3598 }
3599
3600 std::vector<Tensor> meshgrid(TensorList tensors,
3601 c10::string_view indexing) {
3602 int64_t size = tensors.size();
3603 TORCH_CHECK(size > 0, "meshgrid expects a non-empty TensorList");
3604
3605 for(const auto i: c10::irange(size - 1)){
3606 TORCH_CHECK(tensors[i].dtype() == tensors[i+1].dtype(), "meshgrid expects all tensors to have the same dtype");
3607 TORCH_CHECK(tensors[i].device() == tensors[i+1].device(), "meshgrid expects all tensors to have the same device");
3608 }
3609
3610 // The input `tensors` is of type TensorList, which is an alias to a
3611 // constant array slice, which doesn't allow for mutations. We may
3612 // need to swap our first two elements if indexing is "ij", so we
3613 // unconditionally create a vector that we can reorder to keep the
3614 // implementation simple.
3615 //
3616 // We are not concerned with the performance of this relative to
3617 // constructing a grid for each input.
3618 std::vector<std::reference_wrapper<const Tensor>> tensor_refs(tensors.begin(),
3619 tensors.end());
3620
3621 // Whether or not to swap the first two tensors.
3622 //
3623 // We only swap if there are at least two* input tensors (obviously)
3624 // and if indexing is "xy".
3625 //
3626 // A reminder about "xy" semantics: "xy" semantics implies that the
3627 // output grids are in the cartesian coordinate system. Thus the
3628 // first dimension is the "x" axis (corresponding to column) and the
3629 // second dimension is the "y" axis (corresponding to row). Tensors,
3630 // however, generally consider the first axis to be the row and the
3631 // second axis to be the columns. Thus we flip the two dimensions in
3632 // contrast to "ij" indexing.
3633 //
3634 // It turns out that it's easiest to implement this by just swapping
3635 // the first two inputs. However, the order of the outputs still
3636 // must correspond to the order of the inputs. Thus we also must
3637 // swap the outputs if we swapped the inputs.
3638 //
3639 // * Why do we even support this function for exactly one input?
3640 bool swap_first_and_second_tensors = false;
3641
3642 if (indexing == "xy") {
3643 // We can only swap if there are multiple tensors.
3644 swap_first_and_second_tensors = size >= 2;
3645 if (swap_first_and_second_tensors) {
3646 std::swap(tensor_refs[0], tensor_refs[1]);
3647 }
3648 } else {
3649 // Only "xy" and "ij" are supported, and we already checked for
3650 // "xy" above. Only "ij" remains as a valid mode.
3651 TORCH_CHECK(indexing == "ij",
3652 "torch.meshgrid: indexing must be one of \"xy\" or \"ij\", "
3653 "but received: ", indexing);
3654 }
3655
3656 std::vector<c10::SymInt> shape(size);
3657 for(const auto i: c10::irange(size)){
3658 TORCH_CHECK(tensor_refs[i].get().dim() <= 1,
3659 "torch.meshgrid: Expected 0D or 1D tensor in the tensor list but got: ", tensor_refs[i]);
3660 shape[i] = tensor_refs[i].get().sym_numel(); // treat 0D tensors as if they were a 1D tensor
3661 }
3662 std::vector<Tensor> grids;
3663 grids.reserve(size);
3664 std::vector<c10::SymInt> view_shape(size, 1);
3665 for(const auto i: c10::irange(size)){
3666 view_shape[i] = -1; // select this dimension to infer
3667 grids.push_back(tensor_refs[i].get().view_symint(view_shape).expand_symint(shape));
3668 view_shape[i] = 1; // restore to previous value
3669 }
3670
3671 // Remember we need to also swap the outputs if we swapped the inputs.
3672 if (swap_first_and_second_tensors) {
3673 std::swap(grids[0], grids[1]);
3674 }
3675 return grids;
3676 }
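// Illustrative sketch (assuming the usual ATen API) of "ij" vs "xy" indexing:
//   auto x = at::arange(2);
//   auto y = at::arange(3);
//   auto ij = at::meshgrid({x, y}, "ij");   // both grids have shape (2, 3)
//   auto xy = at::meshgrid({x, y}, "xy");   // both grids have shape (3, 2)
// In both cases the outputs are expanded views, so no data is copied.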
3677
3678 // Numpy-style `a.T`: returns the tensor
3679 // with dims reversed
3680 Tensor numpy_T(const Tensor &self) {
3681 const auto n = self.dim();
3682 if (n != 2 && n != 0) {
3683 TORCH_WARN_ONCE(
3684 "The use of `x.T` on tensors of dimension other than 2 to reverse their shape is deprecated ",
3685 "and it will throw an error in a future release. Consider `x.mT` to transpose batches of matrices ",
3686 "or `x.permute(*torch.arange(x.ndim - 1, -1, -1))` to reverse the dimensions of a tensor."
3687 );
3688 }
3689 if (n == 0) {
3690 // Added in PyTorch 2.0
3691 TORCH_WARN_ONCE("Tensor.T is deprecated on 0-D tensors. This function is the identity in these cases.");
3692 }
3693 DimVector transpose_dims;
3694 for (int64_t i = n - 1; i >= 0; --i) {
3695 transpose_dims.push_back(i);
3696 }
3697 return self.permute(transpose_dims);
3698 }
3699
3700 Tensor matrix_H(const Tensor &self) {
3701 const auto ndim = self.dim();
3702 if (ndim == 0) {
3703 // Added in PyTorch 2.0
3704 TORCH_WARN_ONCE("Tensor.H is deprecated on 0-D tensors. Consider using x.conj().");
3705 }
3706 TORCH_CHECK(ndim == 2 || ndim == 0,
3707 "tensor.H is only supported on matrices (2-D tensors). Got ", ndim, "-D tensor.",
3708 ndim > 2 ? " For batches of matrices, consider using tensor.mH" : "");
3709 if (self.is_complex()) {
3710 return ndim == 0 ? self.conj() : self.transpose(-2, -1).conj();
3711 } else {
3712 return ndim == 0 ? self : self.transpose(-2, -1);
3713 }
3714 }
3715
3716 namespace {
3717 Tensor _adjoint(const Tensor &self, const bool transpose, const char* const name) {
3718 const auto ndim = self.dim();
3719 TORCH_CHECK(ndim != 1,
3720 "tensor.", name, " is only supported on matrices or batches of matrices. Got 1-D tensor.");
3721 if (transpose || !self.is_complex()) {
3722 return ndim == 0 ? self : self.transpose(-2, -1);
3723 } else {
3724 return ndim == 0 ? self.conj() : self.transpose(-2, -1).conj();
3725 }
3726 }
3727 } // anonymous namespace
3728
3729 Tensor mT(const Tensor &self) {
3730 if (self.dim() == 0) {
3731 // Added in PyTorch 2.0
3732 TORCH_WARN_ONCE("Tensor.mT is deprecated on 0-D tensors. This function is the identity in these cases.");
3733 }
3734 return _adjoint(self, /*transpose=*/true, "mT");
3735 }
3736
3737 Tensor mH(const Tensor &self) {
3738 if (self.dim() == 0) {
3739 // Added in PyTorch 2.0
3740 TORCH_WARN_ONCE("Tensor.mH is deprecated on 0-D tensors. Consider using x.conj().");
3741 }
3742 return _adjoint(self, /*transpose=*/false, "mH");
3743 }
3744
3745 Tensor adjoint(const Tensor &self) {
3746 if (self.dim() == 0) {
3747 TORCH_WARN_ONCE("adjoint() is deprecated on 0-D tensors. Consider using x.conj().");
3748 }
3749 return _adjoint(self, /*transpose=*/false, "adjoint()");
3750 }
3751
3752 Tensor view(const Tensor& self,
3753 at::IntArrayRef size) {
3754 return view_impl(self, size);
3755 }
3756
3757 Tensor alias(const Tensor& self) {
3758 return alias_with_sizes_and_strides(self, self.sym_sizes(), self.sym_strides());
3759 }
3760
3761 Tensor detach(const Tensor& self) {
3762 // NB: detach() is not the same thing as alias()! The main difference is that
3763 // detach does not allow metadata change while alias does.
3764 return Tensor(self.getIntrusivePtr()->shallow_copy_and_detach(
3765 // NB: The ADInplaceOrView logic will overwrite these with the
3766 // appropriate values if it runs; otherwise these are the values.
3767 /*version_counter=*/0,
3768 /*allow_tensor_metadata_change=*/false));
3769 }
3770
3771 Tensor unfold(const Tensor& self, int64_t d, int64_t size, int64_t step) {
3772 // some special handling to allow d == 0 when self.dim() == 0
3773 auto ndim = self.dim();
3774 d = at::maybe_wrap_dim(d, ndim, /*wrap_scalar=*/true);
3775
3776 auto sizes = self.sizes().vec();
3777 auto strides = self.strides().vec();
3778 int64_t max_size = self.dim() == 0 ? 1 : sizes[d];
3779 TORCH_CHECK(size <= max_size, "maximum size for tensor at dimension ", d,
3780 " is ", max_size, " but size is ", size);
3781 TORCH_CHECK(step > 0, "step is ", step, " but must be > 0");
3782 sizes.push_back(size);
3783 strides.push_back(self.dim() == 0 ? 1 : strides[d]);
3784 // The if handles the self.dim() == 0 case
3785 if (d < ndim) {
3786 sizes[d] = (sizes[d] - size) / step + 1;
3787 strides[d] *= step;
3788 }
3789 return self.as_strided(sizes, strides);
3790 }
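// Illustrative sketch of the size/stride bookkeeping above: unfold creates
// overlapping windows as a view by appending a (size, stride) pair.
//   auto t = at::arange(7);                   // sizes (7), strides (1)
//   auto w = t.unfold(0, /*size=*/3, /*step=*/2);
//   // sizes (3, 3), strides (2, 1):
//   //   [[0, 1, 2],
//   //    [2, 3, 4],
//   //    [4, 5, 6]]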
3791
3792 Tensor diag(const Tensor& self, int64_t offset) {
3793 auto ndim = self.dim();
3794 TORCH_CHECK(ndim == 1 || ndim == 2, "diag(): Supports 1D or 2D tensors. Got ", self.dim(), "D");
3795 if (ndim == 1) {
3796 return at::diag_embed(self, offset);
3797 } else {
3798 // We return a copy of the diagonal
3799 return at::diagonal_copy(self, offset);
3800 }
3801 }
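// Illustrative sketch (assuming the usual ATen API): diag is two different
// operations depending on the input rank.
//   auto v = at::arange(3);
//   at::diag(v);             // 1-D input -> (3, 3) matrix with v on the diagonal
//   auto m = at::rand({3, 3});
//   at::diag(m);             // 2-D input -> copy of the main diagonal, shape (3)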
3802
3803 Tensor& diag_out(const Tensor& self, int64_t offset, Tensor& out) {
3804 auto ndim = self.dim();
3805 TORCH_CHECK(ndim == 1 || ndim == 2, "Supports 1D or 2D tensors. Got ", self.dim(), "D");
3806 if (ndim == 1) {
3807 TORCH_CHECK(
3808 canCast(self.scalar_type(), out.scalar_type()),
3809 "diag: result type ", self.scalar_type(), " can't be cast to the desired out= type ",
3810 out.scalar_type());
3811 return at::diag_embed_out(out, self, offset);
3812 } else {
3813 return at::diagonal_copy_out(out, self, offset);
3814 }
3815 }
3816
3817 Tensor diagonal_backward_symint(const Tensor & grad, SymIntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
3818 auto grad_input = at::zeros_symint(input_sizes, grad.options());
3819 auto diag = grad_input.diagonal(offset, dim1, dim2);
3820 diag.copy_(grad);
3821 return grad_input;
3822 }
3823
3824 Tensor movedim(const Tensor& self, IntArrayRef src, IntArrayRef dst) {
3825 TORCH_CHECK(src.size() == dst.size(), "movedim: Invalid source or destination dims: source (",
3826 src, " dims) should contain the same number of dims as destination (", dst, " dims)");
3827
3828 size_t self_dim = self.dim();
3829 DimVector normalized_src(src.size());
3830 DimVector normalized_dst(dst.size());
3831
3832 auto wrap_dims = [&self_dim](const IntArrayRef& vec, DimVector& normalized_vec) {
3833 for (const auto i : c10::irange(vec.size())) {
3834 normalized_vec[i] = maybe_wrap_dim(vec[i], self_dim);
3835 }
3836 };
3837
3838 wrap_dims(src, normalized_src);
3839 wrap_dims(dst, normalized_dst);
3840
3841 auto all_unique = [](const DimVector& dims) {
3842 DimVector copy = dims;
3843 std::sort(copy.begin(), copy.end());
3844 auto duplicate = std::adjacent_find(copy.begin(), copy.end());
3845 return duplicate == copy.end();
3846 };
3847 TORCH_CHECK(all_unique(normalized_src), "movedim: repeated dim in `source` (", src, ")");
3848 TORCH_CHECK(all_unique(normalized_dst), "movedim: repeated dim in `destination` (", dst, ")");
3849
3850 // handle the case of scalar tensor as a no-op
3851 if (self_dim == 0)
3852 return self.alias();
3853
3854 // TODO: The algorithm below can probably be optimized.
3855 // Reference: https://github.com/pytorch/pytorch/pull/41480#discussion_r456100505
3856
3857 // Algorithm Walkthrough
3858 // Example Input
3859 // Variable State:
3860 // normalized_src = 0, 1
3861 // normalized_dst = 2, 4
3862 // self_dim = 5
3863 DimVector order(self_dim);
3864 DimVector source_dims(self_dim);
3865 DimVector destination_dims(self_dim);
3866
3867 // We initialize two vectors to track update to the dims
3868 // `order` contains the final order of the dim positions.
3869 // Variable State:
3870 // order = NA, NA, NA, NA, NA
3871 // source_dims = 0, 1, 2, 3, 4
3872 // destination_dims = 0, 1, 2, 3, 4
3873 std::iota(source_dims.begin(), source_dims.end(), 0);
3874 std::iota(destination_dims.begin(), destination_dims.end(), 0);
3875
3876 // We mark and update position for the dim provided by user
3877 // i.e. `normalized_src` and `normalized_dims`
3878 // Variable State:
3879 // order = NA, NA, 0, NA, 1
3880 // source_dims = -1, -1, 2, 3, 4
3881 // destination_dims = 0, 1, -1, 3, -1
3882 for (const auto i : c10::irange(src.size())) {
3883 order[normalized_dst[i]] = normalized_src[i];
3884 source_dims[normalized_src[i]] = -1;
3885 destination_dims[normalized_dst[i]] = -1;
3886 }
3887
3888 // Remove the dims whose position we already know,
3889 // the ones marked with -1 in previous step
3890 // Variable State:
3891 // source_dims = 2, 3, 4
3892 // destination_dims = 0, 1, 3
3893 auto source_iter = std::remove(source_dims.begin(), source_dims.end(), -1);
3894 auto destination_iter = std::remove(destination_dims.begin(), destination_dims.end(), -1);
3895
3896 int64_t rest_dim = self.dim() - src.size();
3897 TORCH_INTERNAL_ASSERT(std::distance(source_dims.begin(), source_iter) == rest_dim);
3898 TORCH_INTERNAL_ASSERT(std::distance(destination_dims.begin(), destination_iter) == rest_dim);
3899
3900 // Update the position of the remaining dimensions.
3901 // `source_dims` now contains the original position
3902 // `destination_dims` contains the new position it will shifted to
3903 // after considering the user inputs.
3904 // Variable State:
3905 // order = 2, 3, 0, 4, 1
3906 for (const auto i : c10::irange(rest_dim)) {
3907 order[destination_dims[i]] = source_dims[i];
3908 }
3909
3910 return self.permute(order);
3911 }
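
// Usage sketch (illustrative shapes): for a tensor t of shape (2, 3, 4, 5, 6),
// movedim(t, /*src=*/{0, 1}, /*dst=*/{2, 4}) matches the walkthrough above:
// order = {2, 3, 0, 4, 1}, so the result has shape (4, 5, 2, 6, 3).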

Tensor movedim(const Tensor& self, int64_t src, int64_t dst) {
  return at::movedim(self, IntArrayRef{src}, IntArrayRef{dst});
}

Tensor moveaxis(const Tensor& self, IntArrayRef src, IntArrayRef dst) {
  return at::movedim(self, src, dst);
}

Tensor moveaxis(const Tensor& self, int64_t src, int64_t dst) {
  return at::movedim(self, IntArrayRef{src}, IntArrayRef{dst});
}

Tensor swapaxes(const Tensor& self, int64_t axis0, int64_t axis1) {
  return self.transpose(axis0, axis1);
}

Tensor& swapaxes_(Tensor& self, int64_t axis0, int64_t axis1) {
  return self.transpose_(axis0, axis1);
}

Tensor swapdims(const Tensor& self, int64_t dim0, int64_t dim1) {
  return self.transpose(dim0, dim1);
}

Tensor& swapdims_(Tensor& self, int64_t dim0, int64_t dim1) {
  return self.transpose_(dim0, dim1);
}

Tensor flatten_dense_tensors(TensorList tensors) {
  static auto flatten = [](const Tensor& t) { return t.contiguous().view({-1}); };
  if (tensors.size() == 1)
    return flatten(tensors[0]);
  return at::cat(fmap(tensors, flatten));
}

std::vector<Tensor> unflatten_dense_tensors(const Tensor& flat, TensorList tensors) {
  std::vector<Tensor> outputs;
  outputs.reserve(tensors.size());
  size_t offset = 0;
  for (const auto& tensor : tensors) {
    auto numel = tensor.numel();
    // If unflattening an empty tensor, create a new empty tensor using
    // the flat tensor's options. This avoids the unflattened empty tensor
    // sharing storage with the other unflattened tensors.
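    // (e.g., otherwise a zero-numel output produced via flat.narrow(0, offset, 0).view(...)
    // would still alias flat's storage object.)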
    if (numel == 0) {
      outputs.push_back(at::empty({0}, flat.options()));
    } else {
      outputs.push_back(flat.narrow(0, offset, numel).view(tensor.sizes()));
      offset += numel;
    }
  }
  return outputs;
}
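
// Round-trip sketch (illustrative shapes): flattening tensors of shapes {2, 3} and {4}
// yields a 1-D tensor of 10 elements; unflatten_dense_tensors with the same reference
// list then views it back into outputs of shapes {2, 3} and {4}.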

// Clones a tensor by cloning the underlying storage that it came from,
// which allows us to replicate the exact strides/storage_offset in the cloned tensor.
// Note [*_scatter ops preserve strides]
// In order for functionalization to preserve stride correctness, the *_scatter
// operators that it calls must preserve the striding behavior of their inputs.
// Specifically, the output of *_scatter(base, mutated_view, ...)
// should have identical size/stride/storage_offset to "base".
at::Tensor clone_preserve_strides(const at::Tensor& self) {
  TORCH_INTERNAL_ASSERT(self.has_storage());
  // In cases where the input tensor has internal memory overlap, we cannot actually
  // preserve the strides/storage_offset of the input tensor, because
  // *_scatter ops will try to copy_() into the cloned tensor.
  // However, this should **never** show up in functionalized user code;
  // most aten ops that try to mutate a tensor with internal memory overlap would error anyway.
  //
  // The one place that this does come up is in autograd - if there's a select_scatter
  // in the forward, then autograd will generate one for the backward.
  // If the input to the select_scatter is grad_output, then this could be an expanded tensor
  // with internal overlap.
  if (at::has_internal_overlap(self) == at::MemOverlap::Yes) {
    return self.clone();
  }
  auto dtype_size = self.dtype().itemsize();
  auto nbytes = self.storage().sym_nbytes();
  TORCH_INTERNAL_ASSERT(nbytes % dtype_size == 0);
  auto numel = nbytes / dtype_size;
  auto self_full_size = self.as_strided_symint({std::move(numel)}, {1}, 0);
  auto clone = self_full_size.clone();
  auto out = clone.as_strided_symint(self.sym_sizes(), self.sym_strides(), self.sym_storage_offset());
  return out;
}
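
// Illustrative note (sketch): if self is the transpose of a contiguous 4x3 tensor
// (sizes {3, 4}, strides {1, 3}), the output above keeps sizes {3, 4} and strides {1, 3}
// while owning its own fresh 12-element storage, as the note above requires.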

at::Tensor slice_scatter(const at::Tensor& self, const at::Tensor& src, int64_t dim, std::optional<int64_t> start, std::optional<int64_t> end, int64_t step) {
  // See Note [*_scatter ops preserve strides]
  auto output = clone_preserve_strides(self);
  auto slice = output.slice(dim, start, end, step);
  TORCH_CHECK(slice.sizes() == src.sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sizes(), ", slice size = ", slice.sizes());
  slice.copy_(src);
  return output;
}
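
// Usage sketch (illustrative arguments): slice_scatter(self, src, /*dim=*/0, /*start=*/1,
// /*end=*/3, /*step=*/1) returns a copy of self whose rows 1 and 2 are replaced by src;
// self itself is left untouched, and the output preserves self's strides per the note above.
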
at::Tensor select_scatter_symint(const at::Tensor& self, const at::Tensor& src, int64_t dim, c10::SymInt index) {
  auto output = clone_preserve_strides(self);
  auto slice = output.select_symint(dim, std::move(index));
  TORCH_CHECK(slice.sizes() == src.sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sizes(), ", slice size = ", slice.sizes());
  slice.copy_(src);
  return output;
}

at::Tensor diagonal_scatter(const at::Tensor& self, const at::Tensor& src, int64_t offset, int64_t dim1, int64_t dim2) {
  // See Note [*_scatter ops preserve strides]
  auto output = clone_preserve_strides(self);
  auto slice = output.diagonal(offset, dim1, dim2);
  TORCH_CHECK(slice.sizes() == src.sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sizes(), ", slice size = ", slice.sizes());
  slice.copy_(src);
  return output;
}

at::Tensor as_strided_scatter_symint(const at::Tensor& self, const at::Tensor& src, at::SymIntArrayRef size, at::SymIntArrayRef stride, std::optional<c10::SymInt> storage_offset) {
  // See Note [as_strided_scatter backward support]
  TORCH_INTERNAL_ASSERT(!self.requires_grad() || self.is_contiguous(), "as_strided_scatter is currently only supported for contiguous inputs");
  // See Note [*_scatter ops preserve strides]
  auto output = clone_preserve_strides(self);
  auto slice = output.as_strided_symint(size, stride, std::move(storage_offset));
  TORCH_CHECK(slice.sym_sizes() == src.sym_sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sym_sizes(), ", slice size = ", slice.sym_sizes());
  slice.copy_(src);
  return output;
}

// The default implementation of lift is a no-op.
// If TLS is set appropriately (for wrapper-tensor keys like Functionalize or functorch transforms),
// then we'll dispatch to one of their implementations, which will properly lift the tensor into a wrapper.
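// (Sketch: under the Functionalize key, for instance, these are expected to return the
// tensor wrapped in a FunctionalTensorWrapper rather than the plain tensor.)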
at::Tensor lift(const at::Tensor& self) {
  return self;
}

// See notes in native_functions.yaml
at::Tensor lift_fresh(const at::Tensor& self) {
  return self;
}

// Autogen kernels for tensor list ops don't work on XLA. TODO(jakeszwe)
void split_copy_Tensor_out(const at::Tensor& self, int64_t split_size, int64_t dim, at::TensorList out) {
  auto tmp = self.split(split_size, dim);

  TORCH_CHECK(out.size() == tmp.size(), "split_copy_Tensor_out() expected an out= argument of size ", tmp.size(), ", got size ", out.size());
  for (const auto i : c10::irange(out.size())) {
    out[i].copy_(tmp[i]);
  }
}
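
// Illustrative call (sketch): for a tensor of shape {5, 2} with split_size = 2 along dim 0,
// out must contain three tensors of shapes {2, 2}, {2, 2}, and {1, 2}.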

void split_with_sizes_copy_out(const at::Tensor& self, at::IntArrayRef split_sizes, int64_t dim, at::TensorList out) {
  auto tmp = self.split_with_sizes(split_sizes, dim);

  TORCH_CHECK(out.size() == tmp.size(), "split_with_sizes_copy_out() expected an out= argument of size ", tmp.size(), ", got size ", out.size());
  for (const auto i : c10::irange(out.size())) {
    if (resize_output_check(out[i], tmp[i].sizes())) {
      out[i].resize_(tmp[i].sizes());
    }
    TORCH_CHECK(out[i].dtype() == tmp[i].dtype(),
        "Expected out tensor to have dtype ", tmp[i].dtype(), ", but got ", out[i].dtype(), " instead");
    TORCH_CHECK(out[i].device() == tmp[i].device(),
        "Expected out tensor to have device ", tmp[i].device(), ", but got ", out[i].device(), " instead");
    out[i].copy_(tmp[i]);
  }
}

void unbind_copy_int_out(const at::Tensor& self, int64_t dim, at::TensorList out) {
  auto tmp = self.unbind(dim);

  TORCH_CHECK(out.size() == tmp.size(), "unbind_copy_int_out() expected an out= argument of size ", tmp.size(), ", got size ", out.size());
  for (const auto i : c10::irange(out.size())) {
    out[i].copy_(tmp[i]);
  }
}

int64_t sparse_dim_default(const Tensor& self) {
  TORCH_CHECK(self.layout() == kStrided, "sparse_dim expected sparse or strided tensor layout but got ", self.layout());
  return 0;
}

int64_t dense_dim_default(const Tensor& self) {
  TORCH_CHECK(self.layout() == kStrided, "dense_dim expected sparse or strided tensor layout but got ", self.layout());
  return self.dim();
}

} // namespace at::native