xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/TensorShape.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/core/DimVector.h>
4 #include <ATen/core/functional.h>
5 #include <ATen/core/IListRef.h>
6 #include <ATen/TensorSubclassLikeUtils.h>
7 #include <ATen/AccumulateType.h>
8 #include <ATen/Dispatch.h>
9 #include <ATen/ExpandUtils.h>
10 #include <ATen/InferSize.h>
11 #include <ATen/MemoryOverlap.h>
12 #include <ATen/NamedTensorUtils.h>
13 #include <ATen/SparseCsrTensorUtils.h>
14 #include <ATen/TensorOperators.h>
15 #include <ATen/WrapDimUtils.h>
16 #include <ATen/core/DimVector.h>
17 #include <ATen/core/IListRef.h>
18 #include <ATen/native/Copy.h>
19 #include <ATen/native/NonSymbolicBC.h>
20 #include <ATen/native/Resize.h>
21 #include <ATen/native/SparseTensorUtils.h>
22 #include <ATen/native/TensorIterator.h>
23 #include <ATen/native/TensorShape.h>
24 #include <ATen/native/TypeProperties.h>
25 #include <ATen/native/cpu/CatKernel.h>
26 #include <ATen/native/cpu/SerialStackImpl.h>
27 #include <ATen/native/cpu/StackKernel.h>
28 #include <ATen/quantized/QTensorImpl.h>
29 #include <c10/util/Exception.h>
30 #include <optional>
31 #include <c10/util/SmallVector.h>
32 #include <c10/util/accumulate.h>
33 #include <c10/util/irange.h>
34 
35 #ifndef AT_PER_OPERATOR_HEADERS
36 #include <ATen/Functions.h>
37 #include <ATen/NativeFunctions.h>
38 #else
39 #include <ATen/ops/_chunk_cat_native.h>
40 #include <ATen/ops/_conj_copy_native.h>
41 #include <ATen/ops/_convert_indices_from_coo_to_csr.h>
42 #include <ATen/ops/_convert_indices_from_csr_to_coo.h>
43 #include <ATen/ops/_foreach_copy.h>
44 #include <ATen/ops/_fw_primal_copy_native.h>
45 #include <ATen/ops/_indices_copy_native.h>
46 #include <ATen/ops/_make_dual.h>
47 #include <ATen/ops/_make_dual_copy_native.h>
48 #include <ATen/ops/_mkldnn_reshape.h>
49 #include <ATen/ops/_mkldnn_transpose.h>
50 #include <ATen/ops/_neg_view_copy_native.h>
51 #include <ATen/ops/_reshape_alias_copy_native.h>
52 #include <ATen/ops/_reshape_alias_native.h>
53 #include <ATen/ops/_reshape_copy_native.h>
54 #include <ATen/ops/_reshape_from_tensor_native.h>
55 #include <ATen/ops/_shape_as_tensor_native.h>
56 #include <ATen/ops/_sparse_broadcast_to.h>
57 #include <ATen/ops/_sparse_broadcast_to_copy_native.h>
58 #include <ATen/ops/_sparse_broadcast_to_native.h>
59 #include <ATen/ops/_sparse_compressed_tensor_unsafe_native.h>
60 #include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors.h>
61 #include <ATen/ops/_sparse_csc_tensor_unsafe_native.h>
62 #include <ATen/ops/_sparse_csr_tensor_unsafe.h>
63 #include <ATen/ops/_sparse_csr_tensor_unsafe_native.h>
64 #include <ATen/ops/_stack_native.h>
65 #include <ATen/ops/_unsafe_view.h>
66 #include <ATen/ops/_unsafe_view_native.h>
67 #include <ATen/ops/_values_copy_native.h>
68 #include <ATen/ops/adjoint_native.h>
69 #include <ATen/ops/alias.h>
70 #include <ATen/ops/alias_copy_native.h>
71 #include <ATen/ops/alias_native.h>
72 #include <ATen/ops/arange.h>
73 #include <ATen/ops/arange_native.h>
74 #include <ATen/ops/as_strided_copy_native.h>
75 #include <ATen/ops/as_strided_native.h>
76 #include <ATen/ops/as_strided_scatter_native.h>
77 #include <ATen/ops/atleast_1d.h>
78 #include <ATen/ops/atleast_2d.h>
79 #include <ATen/ops/atleast_3d.h>
80 #include <ATen/ops/block_diag_native.h>
81 #include <ATen/ops/broadcast_tensors_native.h>
82 #include <ATen/ops/broadcast_to_native.h>
83 #include <ATen/ops/cat.h>
84 #include <ATen/ops/cat_meta.h>
85 #include <ATen/ops/cat_native.h>
86 #include <ATen/ops/chunk_native.h>
87 #include <ATen/ops/col_indices_copy_native.h>
88 #include <ATen/ops/column_stack_native.h>
89 #include <ATen/ops/concat_native.h>
90 #include <ATen/ops/concatenate_native.h>
91 #include <ATen/ops/crow_indices_copy_native.h>
92 #include <ATen/ops/dense_dim_native.h>
93 #include <ATen/ops/detach_copy_native.h>
94 #include <ATen/ops/detach_native.h>
95 #include <ATen/ops/diag.h>
96 #include <ATen/ops/diag_embed.h>
97 #include <ATen/ops/diag_embed_native.h>
98 #include <ATen/ops/diag_native.h>
99 #include <ATen/ops/diagflat_native.h>
100 #include <ATen/ops/diagonal.h>
101 #include <ATen/ops/diagonal_backward.h>
102 #include <ATen/ops/diagonal_backward_native.h>
103 #include <ATen/ops/diagonal_copy.h>
104 #include <ATen/ops/diagonal_copy_native.h>
105 #include <ATen/ops/diagonal_native.h>
106 #include <ATen/ops/diagonal_scatter_native.h>
107 #include <ATen/ops/dsplit_native.h>
108 #include <ATen/ops/dstack_native.h>
109 #include <ATen/ops/empty.h>
110 #include <ATen/ops/empty_like.h>
111 #include <ATen/ops/empty_quantized.h>
112 #include <ATen/ops/expand_as_native.h>
113 #include <ATen/ops/expand_copy_native.h>
114 #include <ATen/ops/expand_native.h>
115 #include <ATen/ops/flatten_dense_tensors_native.h>
116 #include <ATen/ops/flatten_native.h>
117 #include <ATen/ops/from_blob.h>
118 #include <ATen/ops/hsplit_native.h>
119 #include <ATen/ops/hstack.h>
120 #include <ATen/ops/hstack_native.h>
121 #include <ATen/ops/index_select_native.h>
122 #include <ATen/ops/indices_copy_native.h>
123 #include <ATen/ops/lift_fresh_native.h>
124 #include <ATen/ops/lift_native.h>
125 #include <ATen/ops/mH_native.h>
126 #include <ATen/ops/mT_native.h>
127 #include <ATen/ops/matrix_H_native.h>
128 #include <ATen/ops/meshgrid_native.h>
129 #include <ATen/ops/moveaxis_native.h>
130 #include <ATen/ops/movedim.h>
131 #include <ATen/ops/movedim_native.h>
132 #include <ATen/ops/narrow.h>
133 #include <ATen/ops/narrow_copy.h>
134 #include <ATen/ops/narrow_copy_native.h>
135 #include <ATen/ops/narrow_native.h>
136 #include <ATen/ops/new_empty_native.h>
137 #include <ATen/ops/new_ones_native.h>
138 #include <ATen/ops/numpy_T_native.h>
139 #include <ATen/ops/permute_copy_native.h>
140 #include <ATen/ops/permute_native.h>
141 #include <ATen/ops/ravel_native.h>
142 #include <ATen/ops/repeat_native.h>
143 #include <ATen/ops/reshape_as_native.h>
144 #include <ATen/ops/reshape_native.h>
145 #include <ATen/ops/resize_native.h>
146 #include <ATen/ops/row_stack_native.h>
147 #include <ATen/ops/select.h>
148 #include <ATen/ops/select_backward_native.h>
149 #include <ATen/ops/select_copy_native.h>
150 #include <ATen/ops/select_native.h>
151 #include <ATen/ops/select_scatter_native.h>
152 #include <ATen/ops/set_native.h>
153 #include <ATen/ops/slice.h>
154 #include <ATen/ops/slice_backward_native.h>
155 #include <ATen/ops/slice_copy_native.h>
156 #include <ATen/ops/slice_inverse_native.h>
157 #include <ATen/ops/slice_native.h>
158 #include <ATen/ops/slice_scatter_native.h>
159 #include <ATen/ops/sparse_coo_tensor.h>
160 #include <ATen/ops/sparse_coo_tensor_native.h>
161 #include <ATen/ops/sparse_dim_native.h>
162 #include <ATen/ops/split_copy_native.h>
163 #include <ATen/ops/split_native.h>
164 #include <ATen/ops/split_with_sizes.h>
165 #include <ATen/ops/split_with_sizes_copy_native.h>
166 #include <ATen/ops/split_with_sizes_native.h>
167 #include <ATen/ops/squeeze_copy_native.h>
168 #include <ATen/ops/squeeze_native.h>
169 #include <ATen/ops/squeeze.h>
170 #include <ATen/ops/stack_native.h>
171 #include <ATen/ops/sub.h>
172 #include <ATen/ops/sum.h>
173 #include <ATen/ops/sum_to_size_native.h>
174 #include <ATen/ops/swapaxes_native.h>
175 #include <ATen/ops/swapdims_native.h>
176 #include <ATen/ops/t_copy_native.h>
177 #include <ATen/ops/t_native.h>
178 #include <ATen/ops/tensor.h>
179 #include <ATen/ops/tensor_split.h>
180 #include <ATen/ops/tensor_split_native.h>
181 #include <ATen/ops/tile_native.h>
182 #include <ATen/ops/transpose.h>
183 #include <ATen/ops/transpose_copy_native.h>
184 #include <ATen/ops/transpose_native.h>
185 #include <ATen/ops/unbind.h>
186 #include <ATen/ops/unbind_copy_native.h>
187 #include <ATen/ops/unbind_native.h>
188 #include <ATen/ops/unflatten_dense_tensors_native.h>
189 #include <ATen/ops/unflatten_native.h>
190 #include <ATen/ops/unfold_copy_native.h>
191 #include <ATen/ops/unfold_native.h>
192 #include <ATen/ops/unsafe_chunk_native.h>
193 #include <ATen/ops/unsafe_split_native.h>
194 #include <ATen/ops/unsafe_split_with_sizes_native.h>
195 #include <ATen/ops/unsqueeze_copy_native.h>
196 #include <ATen/ops/unsqueeze_native.h>
197 #include <ATen/ops/values_copy_native.h>
198 #include <ATen/ops/view_as_complex.h>
199 #include <ATen/ops/view_as_complex_copy_native.h>
200 #include <ATen/ops/view_as_native.h>
201 #include <ATen/ops/view_as_real.h>
202 #include <ATen/ops/view_as_real_copy_native.h>
203 #include <ATen/ops/view_copy_native.h>
204 #include <ATen/ops/view_native.h>
205 #include <ATen/ops/vsplit_native.h>
206 #include <ATen/ops/vstack.h>
207 #include <ATen/ops/vstack_native.h>
208 #include <ATen/ops/zeros.h>
209 #include <ATen/ops/zeros_like.h>
210 #include <ATen/ops/zeros_native.h>
211 #endif
212 
213 #include <algorithm>
214 #include <cstdint>
215 #include <utility>
216 #include <vector>
217 
218 namespace at::meta {
cat_check_no_zero_dim(const MaterializedITensorListRef & tensors)219 inline void cat_check_no_zero_dim(const MaterializedITensorListRef& tensors) {
220   size_t i = 0;
221   for (const Tensor& t : tensors) {
222     TORCH_CHECK(
223         t.dim() > 0,
224         "zero-dimensional tensor (at position ", i, ") cannot be concatenated");
225     i++;
226   }
227 }
228 
cat_compute_output_memory_format(const MaterializedITensorListRef & inputs)229 inline c10::MemoryFormat cat_compute_output_memory_format(const MaterializedITensorListRef& inputs) {
230   std::optional<c10::MemoryFormat> format = std::nullopt;
231   for (const Tensor& t : inputs) {
232     auto f = t.suggest_memory_format();
233     if (f == c10::MemoryFormat::Contiguous) {
234         return f;
235     }
236     if (format.has_value() && format.value() != f) {
237         return c10::MemoryFormat::Contiguous;
238     }
239     format = f;
240   }
241   return format.value();
242 }
243 
TORCH_PRECOMPUTE_META_FUNC(cat)244 TORCH_PRECOMPUTE_META_FUNC(cat)(const ITensorListRef& tensors, int64_t dim) {
245   // previously, size [0] tensors were the only possible empty tensors; thus, it wasn't possible
246   // to cat empty tensors unless all the other tensors were 1-dimensional, so we allowed these tensors
247   // to be "skipped".  We maintain this behavior for backwards compatibility, but only for this specific
248   // size (i.e. other empty sizes are not skipped).
249   auto materialized = tensors.materialize();
250 
251   cat_check_no_zero_dim(materialized);
252   dim = at::legacy_cat_wrap_dim(dim, materialized);
253 
254   // Checking names before the actual dimensions.
255   auto maybe_outnames = namedinference::compute_cat_outnames(materialized);
256 
257   TORCH_CHECK(
258       !materialized.empty(), "torch.cat(): expected a non-empty list of Tensors");
259 
260   // Look for the first valid tensor.
261   size_t valid = materialized.size();
262   for (const auto i : c10::irange(materialized.size())) {
263     if (!at::native::cat_should_skip_tensor(materialized[i].get())) {
264       valid = i;
265       break;
266     }
267   }
268 
269   bool all_contiguous = true;
270   bool all_same_dtype = true;
271   bool all_same_sizes_and_stride = true;
272   auto memory_format = cat_compute_output_memory_format(materialized);
273 
274   // Compute what the output dtype should be:
275   const auto& result = maybe_get_output();
276   auto is_out_defined = result.defined();
277   auto out_dtype = at::native::result_type(tensors);
278 
279   // If the output tensor is defined, we need to take it into account
280   // when computing the actual output dtype and the flags.
281   if (is_out_defined) {
282     // Check for type promotion, if the output tensor is defined.
283     TORCH_CHECK(
284         canCast(out_dtype, result.scalar_type()),
285         "torch.cat(): input types can't be cast to the desired output type ",
286         result.scalar_type());
287     out_dtype = result.scalar_type();
288     all_contiguous = result.is_contiguous(memory_format);
289   }
290 
291   // Fallback 'set_output' parameters.
292   // (in case we don't find a valid tensor)
293   DimVector sizes {0};
294   TensorOptions options = materialized[0].get().options()
295       .dtype(out_dtype)
296       .memory_format(memory_format);
297 
298   // If we found a valid tensor, check whether the input tensors
299   // are compatible, i.e. we can execute `cat` on them.
300   bool found_valid_tensor = valid < materialized.size();
301   if (found_valid_tensor) {
302     TORCH_CHECK(
303         dim <= materialized[valid].get().dim(), "torch.cat(): dimension ", dim, "out of range");
304 
305     // Compute the output tensor size.
306     // It should have the same shape as any other valid tensor,
307     // except in the dimension 'dim'.
308     size_t size_at_dim = 0;
309     for (const auto i : c10::irange(materialized.size())) {
310       const Tensor& t = materialized[i];
311       all_same_dtype = all_same_dtype && out_dtype == t.scalar_type();
312       if (!at::native::cat_should_skip_tensor(t)) {
313         at::native::check_cat_shape_except_dim(materialized[valid], t, dim, i);
314         size_at_dim += t.size(dim);
315         all_contiguous = all_contiguous && t.is_contiguous(memory_format);
316         all_same_sizes_and_stride = all_same_sizes_and_stride &&
317             t.sizes() == materialized[valid].get().sizes() &&
318             t.strides() == materialized[valid].get().strides();
319       } else {
320         all_contiguous = false;
321       }
322     }
323 
324     // Actually set the output.
325     sizes = materialized[valid].get().sizes().vec();
326     sizes[dim] = size_at_dim;
327     options = materialized[valid].get().options()
328         .dtype(out_dtype)
329         .memory_format(memory_format);
330   }
331 
332   set_output_raw_strided(0, sizes, {}, options, maybe_outnames);
333   // Checks for overlaps between the inputs and the output tensor.
334   if (is_out_defined && found_valid_tensor) {
335     at::assert_no_internal_overlap(result);
336     for (const Tensor& t : materialized) {
337       at::assert_no_overlap(result, t);
338     }
339   }
340 
341   return TORCH_PRECOMPUTE_STRUCT(cat)()
342       .set_dim(dim)
343       .set_valid(valid)
344       .set_all_contiguous(all_contiguous)
345       .set_all_same_dtype(all_same_dtype)
346       .set_all_same_sizes_and_stride(all_same_sizes_and_stride)
347       .set_memory_format(memory_format);
348 }
349 } // namespace at::meta
350 
351 namespace at::native {
352 
353 DEFINE_DISPATCH(cat_serial_stub);
354 DEFINE_DISPATCH(stack_serial_stub);
355 
_reshape_from_tensor(const Tensor & self,const Tensor & shape_tensor)356 Tensor _reshape_from_tensor(const Tensor& self, const Tensor& shape_tensor) {
357   TORCH_CHECK(shape_tensor.dim() == 1);
358   std::vector<int64_t> shape;
359   auto accessor = shape_tensor.accessor<int64_t, 1>();
360   for (const auto i : c10::irange(shape_tensor.numel())) {
361     shape.push_back(accessor[i]);
362   }
363   return self.reshape(IntArrayRef(shape));
364 }
365 
_shape_as_tensor(const Tensor & self)366 Tensor _shape_as_tensor(const Tensor& self) {
367   auto options = TensorOptions(at::kLong);
368   return at::tensor(self.sizes(), options);
369 }
370 
set_(Tensor & result,Storage source)371 Tensor& set_(Tensor& result, Storage source) {
372   int64_t new_size =
373       static_cast<int64_t>(source.nbytes() / result.dtype().itemsize());
374   return result.set_(std::move(source), 0, new_size, {});
375 }
376 
377 
378 // unify with cuda implementation?  This is not done to avoid a dispatch in resize_impl_cpu_
set_storage_cpu_(Tensor & result,Storage storage,int64_t storage_offset,IntArrayRef size,IntArrayRef stride)379 Tensor& set_storage_cpu_(Tensor& result, Storage storage, int64_t storage_offset, IntArrayRef size, IntArrayRef stride) {
380   checkSetStorage(result, std::move(storage), storage_offset, size, stride);
381 
382   result.unsafeGetTensorImpl()->set_storage_offset(storage_offset);
383   at::OptionalIntArrayRef stride_opt = stride.data() != nullptr ?
384                                           at::OptionalIntArrayRef(stride) : std::nullopt;
385   // We can re-use this kernel for the meta device.
386   // We just need to make sure we don't actually try to resize the (null) storage.
387   at::native::resize_impl_cpu_(result.unsafeGetTensorImpl(), size, stride_opt, /*resize_storage=*/!result.is_meta());
388   return result;
389 }
390 
set_storage_meta__symint(Tensor & result,Storage storage,c10::SymInt storage_offset,c10::SymIntArrayRef size,c10::SymIntArrayRef stride)391 Tensor& set_storage_meta__symint(Tensor& result, Storage storage, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) {
392   checkSetStorage(result, storage, storage_offset, size, stride);
393 
394   c10::SymDimVector contiguous_strides;
395   if (stride.data() == nullptr) {
396     // TODO: dedupe this with empty() symbolic logic
397     int64_t dim = size.size();
398     contiguous_strides.resize(dim);
399     if (dim > 0) {
400       const auto last_idx = dim - 1;
401       contiguous_strides.at(last_idx) = 1;
402       for (auto i = last_idx - 1; i >= 0; --i) {
403         // TODO: max with 1
404         contiguous_strides.at(i) = contiguous_strides.at(i+1) * size.at(i+1);
405       }
406     }
407     stride = contiguous_strides;
408   }
409 
410   // Run this before storage setting so we can access numel
411   result.unsafeGetTensorImpl()->set_sizes_and_strides(size, stride, storage_offset);
412 
413   // Matches maybe_resize_storage_cpu no-numel behavior
414   if (TORCH_GUARD_SIZE_OBLIVIOUS(result.sym_numel().sym_ne(0))) {
415     // maybe_resize_storage_cpu can handle no storage exists at all but
416     // that should never be the case here
417     TORCH_INTERNAL_ASSERT(storage);
418     TORCH_CHECK(storage.resizable(), "Trying to resize storage that is not resizable");
419     // All meta data pointers are the same, so we don't have to "re" allocate
420     // it.  TODO: Actually this might not quite be correct if we use special
421     // pointers to track whether or not fake cuda tensors are pinned or not
422     const auto itemsize = result.dtype().itemsize();
423     c10::SymInt new_size_bytes = result.is_contiguous()
424       ? at::detail::computeStorageNbytesContiguous(size, itemsize, std::move(storage_offset))
425       : at::detail::computeStorageNbytes(size, stride, itemsize, std::move(storage_offset));
426     // TODO: When there are unbacked SymInts, we unconditionally skip the
427     // setter.  This is technically wrong, but we cannot conveniently test
428     // the real condition in many cases, because a lot of people are using
429     // set_ just to swizzle metadata on a tensor, they didn't actually want
430     // to see if they need to resize the storage.
431     //
432     // The old behavior was to unconditionally set_nbytes, but I think not
433     // setting it is more safe.
434     if (new_size_bytes.has_hint() && storage.sym_nbytes().has_hint() && TORCH_GUARD_SIZE_OBLIVIOUS(new_size_bytes.sym_gt(storage.sym_nbytes()))) {
435       storage.set_nbytes(std::move(new_size_bytes));
436     }
437   }
438   return result;
439 }
440 
set__symint(Tensor & result,const Tensor & storage,c10::SymInt storage_offset,c10::SymIntArrayRef size,c10::SymIntArrayRef stride)441 Tensor& set__symint(Tensor& result, const Tensor& storage, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) {
442   TORCH_CHECK(storage.is_contiguous(), "passed in tensor to be used as storage must be contiguous");
443   return result.set__symint(storage.storage(), storage_offset + storage.sym_storage_offset(), size, stride);
444 }
445 
set_tensor_(Tensor & result,const Tensor & source)446 Tensor& set_tensor_(Tensor& result, const Tensor& source) {
447   if (result.unsafeGetTensorImpl() != source.unsafeGetTensorImpl()) {
448     return result.set__symint(source.storage(), source.sym_storage_offset(), source.sym_sizes(), source.sym_strides());
449   }
450   return result;
451 }
452 
453 // this needs to be split along CPU/CUDA lines because we don't have a consistent
454 // way of getting the allocator to use for a device (c10::GetAllocator is not
455 // the same as at::cuda::getCUDADeviceAllocator().
set_cpu_(Tensor & result)456 Tensor& set_cpu_(Tensor& result) {
457   caffe2::TypeMeta dtype = result.dtype();
458   Storage storage(
459       Storage::use_byte_size_t(),
460       0,
461       c10::GetAllocator(kCPU),
462       true);
463   result.set_(std::move(storage), 0, {0}, {});
464   TORCH_INTERNAL_ASSERT(dtype == result.dtype());
465   return result;
466 }
467 
468 // We can't re-use the cpu kernel here because we don't want to use the cpu allocator.
set_meta_(Tensor & result)469 Tensor& set_meta_(Tensor& result) {
470   caffe2::TypeMeta dtype = result.dtype();
471   Storage storage(
472       Storage::use_byte_size_t(),
473       0,
474       c10::GetAllocator(kMeta),
475       true);
476   result.set_(std::move(storage), 0, {0}, {});
477   TORCH_INTERNAL_ASSERT(dtype == result.dtype());
478   return result;
479 }
480 
sparse_broadcast_to(const Tensor & self,IntArrayRef size)481 Tensor sparse_broadcast_to(const Tensor& self, IntArrayRef size) {
482   TORCH_CHECK(self.is_sparse(), "input must be sparse tensor");
483   int64_t sparse_extra_ndim = size.size() - self.dim();
484   int64_t sparse_ndim = size.size() - self.dense_dim();
485   TORCH_CHECK(sparse_extra_ndim >= 0, "input not broadcastable to size with smaller dimensionality");
486   Tensor indices = self._indices();
487   Tensor values = self._values();
488   auto nnz = values.size(0);
489 
490   std::vector<int64_t> broadcast_sizes;
491   std::vector<int64_t> broadcast_dense_sizes;
492   std::vector<int64_t> broadcast_dims;
493   std::vector<int64_t> unchanged_dims;
494   broadcast_sizes.reserve(sparse_ndim);
495   broadcast_dense_sizes.reserve(self.dense_dim() + 1);
496   broadcast_dims.reserve(self.sparse_dim());
497   unchanged_dims.reserve(self.sparse_dim());
498   int64_t nnz_factor = 1;
499   int64_t min_broadcast_dim = (sparse_extra_ndim > 0 ? 0: -1);
500   int64_t max_unchanged_dim = -1;
501   for (int64_t i=0; i<sparse_extra_ndim; i++) {
502     auto d = size[i];
503     nnz_factor *= d;
504     broadcast_sizes.emplace_back(d);
505   }
506   for (int64_t i=0; i<self.sparse_dim(); i++) {
507     auto d = size[sparse_extra_ndim + i];
508     if (self.size(i) != d) {
509       TORCH_CHECK(self.size(i) == 1,
510                   "The expanded size of the tensor (",size[sparse_extra_ndim + i],") ",
511                   "must match the existing size (",self.size(i),")");
512       nnz_factor *= d;
513       broadcast_sizes.emplace_back(d);
514       if (min_broadcast_dim == -1) {
515         min_broadcast_dim = sparse_extra_ndim + i;
516       }
517       broadcast_dims.emplace_back(i);
518     } else {
519       unchanged_dims.emplace_back(i);
520       max_unchanged_dim = sparse_extra_ndim + i;
521     }
522   }
523   // to_broadcast conserves is_coalesced property iff only the last
524   // sparse dimensions are expanded. Possible expansion of dense
525   // dimensions can be discarded as it does not affect the is_coalesce
526   // property.
527   bool is_coalesced = self.dim()==0 || (self.is_coalesced() && (max_unchanged_dim < min_broadcast_dim || min_broadcast_dim == -1));
528 
529   broadcast_dense_sizes.emplace_back(nnz);
530   for (int64_t i=0; i<self.dense_dim(); i++) {
531     broadcast_dense_sizes.emplace_back(size[sparse_extra_ndim + self.sparse_dim() + i]);
532   }
533 
534   std::vector<int64_t> new_indices_size{sparse_ndim, nnz * nnz_factor};
535   std::vector<int64_t> new_values_size(values.sizes().vec());
536   new_values_size[0] = new_indices_size[1];
537 
538   Tensor new_values = values.expand(broadcast_dense_sizes).repeat_interleave(nnz_factor, 0);
539   Tensor new_indices = indices.new_empty(new_indices_size);
540   if (!broadcast_sizes.empty()) {
541     Tensor broadcast_indices = at::sparse::full_coo_indices(broadcast_sizes, indices.options()).tile(nnz);
542     new_indices.narrow(0, 0, sparse_extra_ndim).copy_(broadcast_indices.narrow(0, 0, sparse_extra_ndim));
543     for (size_t i=0; i<broadcast_dims.size(); i++) {
544       int64_t j=broadcast_dims[i];
545       new_indices.select(0, sparse_extra_ndim + j).copy_(broadcast_indices.select(0, sparse_extra_ndim + i));
546     }
547   }
548   for (int64_t j:unchanged_dims) {
549     new_indices.select(0, sparse_extra_ndim + j).copy_(indices.select(0, j).repeat_interleave(nnz_factor));
550   }
551   return at::sparse_coo_tensor(new_indices, new_values, size, self.options(), is_coalesced);
552 }
553 
broadcast_to_symint(const Tensor & self,SymIntArrayRef size)554 Tensor broadcast_to_symint(const Tensor& self, SymIntArrayRef size) {
555   return self.expand_symint(size);
556 }
557 
broadcast_tensors(TensorList tensors)558 std::vector<Tensor> broadcast_tensors(TensorList tensors) {
559   return expand_outplace(tensors);
560 }
561 
fastCatOutDim0(const Tensor & out,const MaterializedITensorListRef & inputs)562 static void fastCatOutDim0(const Tensor& out, const MaterializedITensorListRef& inputs) {
563   auto outBytes = out.nbytes();
564   char* dataPtr = reinterpret_cast<char*>(out.data_ptr());
565   size_t totalBytes = 0;
566   for (const Tensor& input : inputs) {
567     TORCH_CHECK(outBytes >= totalBytes);
568     if (input.nbytes() > 0) {
569       std::memcpy(dataPtr + totalBytes, input.const_data_ptr(), input.nbytes());
570     }
571     totalBytes += input.nbytes();
572   }
573   TORCH_CHECK(outBytes == totalBytes);
574 }
575 
576 
TORCH_IMPL_FUNC(cat_out_cpu)577 TORCH_IMPL_FUNC(cat_out_cpu)
578 (const ITensorListRef& tensors,
579  int64_t dim,
580  int64_t valid,
581  bool all_contiguous,
582  bool all_same_dtype,
583  bool all_same_sizes_and_stride,
584  MemoryFormat memory_format,
585  const Tensor& result) {
586   if (result.numel() == 0) {
587     return;
588   }
589 
590   auto materialized = tensors.materialize();
591 
592   bool use_serial_kernel = result.numel() < at::internal::GRAIN_SIZE || at::get_num_threads() == 1;
593   ScalarType dtype = materialized[valid].get().scalar_type();
594   bool serial_dtype = at::isFloatingType(dtype);
595   // fast path for single thread when both inputs and result are contiguous and
596   // not empty, and concat dim is 0
597   if (use_serial_kernel && all_contiguous && all_same_dtype && (MemoryFormat::Contiguous == memory_format)) {
598     if (dim == 0) {
599       fastCatOutDim0(result, materialized);
600       return;
601     }
602     // TODO: Add fast cat for higher dimensions and support multi-threaded fast cat
603   }
604 
605   // fast path for single thread when both inputs and result are contiguous and not empty
606   if (use_serial_kernel && all_contiguous && all_same_dtype && serial_dtype) {
607     cat_serial_stub(kCPU, result, materialized, dim);
608     return;
609   }
610 
611   int64_t offset = 0;
612   if (all_same_sizes_and_stride && result.is_contiguous(memory_format) &&
613       all_same_dtype) {
614     const Tensor& source_slice = materialized[valid];
615     auto slice_dim_size = source_slice.sizes()[dim];
616     auto result_slice = result.narrow(dim, 0, slice_dim_size);
617     auto result_slice_data = result_slice.data_ptr();
618     auto result_stride_bytes = result.stride(dim) * elementSize(result.scalar_type());
619 
620     auto iter = TensorIteratorConfig()
621       .set_check_mem_overlap(false)
622       .resize_outputs(false)
623       .add_output(result_slice)
624       .add_const_input(source_slice)
625       .enforce_safe_casting_to_output(true)
626       .build();
627 
628     for (const Tensor& tensor : materialized) {
629       if (cat_should_skip_tensor(tensor)) {
630         continue;
631       }
632       auto source_data = static_cast<const char*>(tensor.const_data_ptr());
633       auto result_data = static_cast<char*>(result_slice_data) + offset * result_stride_bytes;
634       iter.unsafe_replace_operand(0, result_data);
635       iter.unsafe_replace_operand(1, const_cast<char*>(source_data));
636       copy_stub(iter.device_type(), iter, false);
637       offset += slice_dim_size;
638     }
639   } else {
640     for (const Tensor& tensor: materialized) {
641       if (cat_should_skip_tensor(tensor)) {
642         continue;
643       }
644       auto slice_dim_size = tensor.sizes()[dim];
645       auto result_slice = result.narrow(dim, offset, slice_dim_size);
646 
647       auto iter = TensorIteratorConfig()
648         .set_check_mem_overlap(false)  // Already checked above
649         .resize_outputs(false)
650         .add_output(result_slice)
651         .add_const_input(tensor)
652         .promote_inputs_to_common_dtype(true)
653         .cast_common_dtype_to_outputs(true)
654         .enforce_safe_casting_to_output(true)
655         .build();
656       copy_stub(iter.device_type(), iter, false);
657       offset += slice_dim_size;
658     }
659   }
660 }
661 
cat_out(TensorList tensors,Dimname dim,Tensor & result)662 Tensor& cat_out(TensorList tensors, Dimname dim, Tensor& result) {
663   TORCH_CHECK(!tensors.empty(), "expected a non-empty list of Tensors");
664   return at::cat_out(result, tensors, dimname_to_position(tensors[0], dim));
665 }
666 
cat(TensorList tensors,Dimname dim)667 Tensor cat(TensorList tensors, Dimname dim) {
668   TORCH_CHECK(!tensors.empty(), "expected a non-empty list of Tensors");
669   return at::cat(tensors, dimname_to_position(tensors[0], dim));
670 }
671 
672 // torch.concat, alias for torch.cat
concat_out(TensorList tensors,Dimname dim,Tensor & result)673 Tensor& concat_out(TensorList tensors, Dimname dim, Tensor& result) {
674   return at::cat_out(result, tensors, dimname_to_position(tensors[0], dim));
675 }
676 
concat(TensorList tensors,Dimname dim)677 Tensor concat(TensorList tensors, Dimname dim) {
678   return at::cat(tensors, dimname_to_position(tensors[0], dim));
679 }
680 
concat_out(TensorList tensors,int64_t dim,Tensor & result)681 Tensor & concat_out(TensorList tensors, int64_t dim, Tensor & result) {
682   return at::cat_out(result, tensors, dim);
683 }
684 
concat(TensorList tensors,int64_t dim)685 Tensor concat(TensorList tensors, int64_t dim) {
686   return at::cat(tensors, dim);
687 }
688 
689 // torch.concatenate, alias for torch.cat
concatenate_out(TensorList tensors,Dimname dim,Tensor & result)690 Tensor& concatenate_out(TensorList tensors, Dimname dim, Tensor& result) {
691   return at::cat_out(result, tensors, dimname_to_position(tensors[0], dim));
692 }
693 
concatenate(TensorList tensors,Dimname dim)694 Tensor concatenate(TensorList tensors, Dimname dim) {
695   return at::cat(tensors, dimname_to_position(tensors[0], dim));
696 }
697 
concatenate_out(TensorList tensors,int64_t dim,Tensor & result)698 Tensor& concatenate_out(TensorList tensors, int64_t dim, Tensor & result) {
699   return at::cat_out(result, tensors, dim);
700 }
701 
concatenate(TensorList tensors,int64_t dim)702 Tensor concatenate(TensorList tensors, int64_t dim) {
703   return at::cat(tensors, dim);
704 }
705 
sizes_match_except(IntArrayRef s1,IntArrayRef s2,int64_t dim_except)706 static bool sizes_match_except(IntArrayRef s1, IntArrayRef s2, int64_t dim_except /* should already be wrapped */) {
707   if (s1.size() != s2.size()) {
708     return false;
709   }
710   for (const auto i : c10::irange(static_cast<int64_t>(s1.size()))) {
711     if (i != dim_except && s1[i] != s2[i]) {
712       return false;
713     }
714   }
715   return true;
716 }
717 
718 // Check to see if the shape of tensors is compatible
719 // for being concatenated along a given dimension.
check_cat_sparse_dims(Tensor const & t,int64_t pos,IntArrayRef sizes,int64_t wrapped,int64_t sparse_dim,int64_t dense_dim)720 static void check_cat_sparse_dims(Tensor const &t,
721   int64_t pos /* used only for debug messages */,
722   IntArrayRef sizes,
723   int64_t wrapped,
724   int64_t sparse_dim,
725   int64_t dense_dim) {
726     TORCH_CHECK(t.is_sparse(),
727             "Concatenating sparse tensors, but a dense tensor was found at position ", pos, ".");
728     TORCH_CHECK(sizes_match_except(sizes, t.sizes(), wrapped),
729             "All tensors must have the same shape: ", sizes, " (except in the concatenating dimension),"
730             " but found shape: ", t.sizes(), " at position ", pos, ".");
731     TORCH_CHECK(t.sparse_dim() == sparse_dim && t.dense_dim() == dense_dim,
732             "All tensors must have the same sparse_dim and dense_dim: ", sparse_dim, ", ", dense_dim,
733             ", but tensor at position ", pos, " has ", t.sparse_dim(), ", ", t.dense_dim(), ".");
734 }
735 
cat_sparse_impl(const MaterializedITensorListRef & tensors,int64_t dim)736 static Tensor cat_sparse_impl(const MaterializedITensorListRef& tensors, int64_t dim) {
737   std::vector<Tensor> indices;
738   std::vector<Tensor> values;
739   int64_t wrapped = maybe_wrap_dim(dim, tensors[0].get().dim());
740   int64_t sparse_dim = tensors[0].get().sparse_dim();
741   int64_t dense_dim = tensors[0].get().dense_dim();
742   IntArrayRef sizes = tensors[0].get().sizes();
743   if (wrapped < sparse_dim) {
744     for (const auto i : c10::irange(tensors.size())) {
745       const Tensor& t = tensors[i];
746       check_cat_sparse_dims(t, i, sizes, wrapped, sparse_dim, dense_dim);
747       indices.push_back(t._indices());
748       values.push_back(t._values());
749     }
750     Tensor idxs = at::cat(indices, 1);
751     Tensor vals = at::cat(values, 0);
752 
753     // We now need to move the indices of each
754     // input tensor up along `dim` by an appropriate amount.
755     // E.g., if t1 has indices [[2,3,4],[5,6,7]],
756     // and sizes [10, 7]
757     // then torch.cat((t1,t1,t1),1) should have indices
758     // [[2,3,4,2,3,4,2,3,4],[5,6,7,12,13,14,19,20,21]],
759     // so we need to increase idxs[1][3:6] by 7
760     // and idxs[1][6:9] by 14.
761     int64_t col = 0;
762     int64_t cumulative_offset = 0;
763     for (const auto i : c10::irange(tensors.size())) {
764       const Tensor& t = tensors[i];
765       int64_t this_piece_size = t._nnz();
766       // cumulative_offset is zero for the first piece, so
767       // don't waste time doing this operation unless i > 0.
768       if (i > 0) {
769         idxs[wrapped].narrow(0, col, this_piece_size) += cumulative_offset;
770       }
771       cumulative_offset += t.size(wrapped);
772       col += this_piece_size;
773     }
774     auto sizes_copy = sizes.vec();
775     sizes_copy[wrapped] = cumulative_offset;
776     return native::sparse_coo_tensor(
777         idxs,
778         vals,
779         sizes_copy,
780         optTypeMetaToScalarType(tensors[0].get().options().dtype_opt()),
781         tensors[0].get().options().layout_opt(),
782         tensors[0].get().options().device_opt(),
783         tensors[0].get().options().pinned_memory_opt());
784   }
785   else {
786     // Catting along a dense dimension requires us to create new values.
787     // For illustration, consider the sparse 3d tensors t1 and t2,
788     // given by t1 = [[[1,2],[3,4]], ... (zeros) ..., [[5,6],[7,8]]]
789     // and t2 = [... (zeros) ..., [[9, 10], [11,12]], ... (zeros) ...],
790     // Their concatenation along dimension 2 is:
791     // [[[1,2,0,0],[3,4,0,0]], ... (zeros) ..., [[0,0,9,10],[0,0,11,12]], ... (zeros) ..., [[5,6,0,0],[7,8,0,0]]]
792     //
793     // Their values tensors are, respectively,
794     // [[[1,2],[3,4]],[[5,6],[7,8]]] and [[[9,10],[11,12]]].
795     //
796     // and so the values tensor of their concatenation along dim 2 will be:
797     // [[[1,2,0,0],[3,4,0,0]],[[5,6,0,0],[7,8,0,0]],[[0,0,9,10],[0,0,11,12]]]
798     //
799     // which we can get by taking the values tensor of each tensor, catting it with zeros of the appropriate size on the left and right,
800     // and then catting all those results together.
801 
802     // The dimension in each tensor's values object that corresponds to the overall dimension along which we're catting.
803     int64_t values_dim = wrapped - sparse_dim + 1;
804     // The final size along the catted dimension.
805     const int64_t total_size = std::accumulate(
806         tensors.begin(),
807         tensors.end(),
808         static_cast<int64_t>(0),
809         [values_dim](int64_t l, const Tensor& r) {
810           return l + r._values().size(values_dim);
811         });
812     auto zeros_sizes = tensors[0].get()._values().sizes().vec();
813     int64_t cumulative_size = 0;
814     std::vector<Tensor> vals_pieces;
815     std::vector<Tensor> idxs_pieces;
816     for (const auto i : c10::irange(tensors.size())) {
817       const Tensor& t = tensors[i];
818       check_cat_sparse_dims(t, i, sizes, wrapped, sparse_dim, dense_dim);
819       // dimension 0 of values corresponds to the number of values,
820       // rather than to any logical dimension of the sparse tensor.
821       zeros_sizes[0] = t._values().size(0);
822       zeros_sizes[values_dim] = cumulative_size;
823       cumulative_size += t._values().size(values_dim);
824       auto z1 = at::zeros(
825           zeros_sizes,
826           optTypeMetaToScalarType(t._values().options().dtype_opt()),
827           t._values().options().layout_opt(),
828           t._values().options().device_opt(),
829           t._values().options().pinned_memory_opt());
830       zeros_sizes[values_dim] = total_size - cumulative_size;
831       auto z2 = at::zeros(
832           zeros_sizes,
833           optTypeMetaToScalarType(t._values().options().dtype_opt()),
834           t._values().options().layout_opt(),
835           t._values().options().device_opt(),
836           t._values().options().pinned_memory_opt());
837       vals_pieces.push_back(at::cat({z1, t._values(), z2}, values_dim));
838       idxs_pieces.push_back(t._indices());
839     }
840     auto sizes_copy = sizes.vec();
841     sizes_copy[wrapped] = total_size;
842     // This can create an uncoalesced tensor
843     return native::sparse_coo_tensor(
844         at::cat(idxs_pieces, 1),
845         at::cat(vals_pieces),
846         sizes_copy,
847         optTypeMetaToScalarType(tensors[0].get().options().dtype_opt()),
848         tensors[0].get().options().layout_opt(),
849         tensors[0].get().options().device_opt(),
850         tensors[0].get().options().pinned_memory_opt());
851   }
852 }
853 
cat_sparse(const ITensorListRef & tensors,int64_t dim)854 Tensor cat_sparse(const ITensorListRef& tensors, int64_t dim) {
855   auto materialized = tensors.materialize();
856   auto maybe_outnames = namedinference::compute_cat_outnames(materialized);
857   auto result = cat_sparse_impl(materialized, at::legacy_cat_wrap_dim(dim, materialized));
858   namedinference::propagate_names_if_nonempty(result, maybe_outnames);
859   return result;
860 }
861 
block_diag(TensorList tensors)862 Tensor block_diag(TensorList tensors) {
863   Tensor result;
864   if (tensors.empty()) {
865     result = at::empty({1, 0});
866     return result;
867   }
868 
869   const Device& device = tensors[0].device();
870   for (const auto tensor_idx : c10::irange(tensors.size())) {
871     const Tensor& tensor = tensors[tensor_idx];
872 
873     TORCH_CHECK(
874       tensor.device() == device,
875       "torch.block_diag: input tensors must all be on the same device.",
876       " Input 0 is on device ", device,
877       " and input ", tensor_idx, " is on device ", tensor.device()
878     );
879   }
880 
881   ScalarType output_scalar_type = native::result_type(tensors);
882   int64_t result_dim0 = 0;
883   int64_t result_dim1 = 0;
884   std::vector<Tensor> tensors_2D(tensors.size());
885 
886   // Sum the dimensions of the tensors, check tensor sizes,
887   // and expand all 0-D and 1-D tensors so that everything
888   // is 2-D
889   for (const auto tensor_idx : c10::irange(tensors.size())) {
890     const Tensor& tensor = tensors[tensor_idx];
891     int64_t ndims = tensor.dim();
892     TORCH_CHECK(
893       ndims <= 2,
894       "torch.block_diag: Input tensors must have 2 or fewer dimensions. Input ",
895       tensor_idx, " has ", ndims, " dimensions"
896     );
897 
898     int64_t dim0 = 1;
899     int64_t dim1 = 1;
900 
901     if (ndims == 2) {
902       dim0 = tensor.size(0);
903       dim1 = tensor.size(1);
904       tensors_2D[tensor_idx] = tensor;
905     } else if (ndims == 1) {
906       // Switching dim 0 to dim 1 is intentional
907       dim1 = tensor.size(0);
908       tensors_2D[tensor_idx] = tensor.expand({dim0, dim1});
909     } else {
910       tensors_2D[tensor_idx] = tensor.expand({dim0, dim1});
911     }
912     result_dim0 += dim0;
913     result_dim1 += dim1;
914   }
915 
916   result = at::zeros(
917     {result_dim0, result_dim1},
918     tensors[0].options().dtype(output_scalar_type)
919   );
920 
921   int64_t cur_dim0 = 0;
922   int64_t cur_dim1 = 0;
923 
924   // Copy each tensor into the appropriate location in the result matrix
925   for (const auto& tensor : tensors_2D) {
926     int64_t dim0 = tensor.size(0);
927     int64_t dim1 = tensor.size(1);
928     result.slice(0, cur_dim0, cur_dim0+dim0).slice(1, cur_dim1, cur_dim1+dim1).copy_(tensor);
929 
930     cur_dim0 += dim0;
931     cur_dim1 += dim1;
932   }
933 
934   return result;
935 }
936 
chunk(const Tensor & self,int64_t chunks,int64_t dim)937 std::vector<Tensor> chunk(const Tensor& self, int64_t chunks, int64_t dim) {
938   TORCH_CHECK(self.dim() > 0,
939            "chunk expects at least a 1-dimensional tensor");
940   TORCH_CHECK(chunks > 0,
941            "chunk expects `chunks` to be greater than 0, got: ", chunks);
942 
943   const auto dim_size = self.sym_size(dim);
944   auto split_size = (dim_size + chunks - 1) / chunks;
945 
946   // We need to call split_with_sizes in the case where split_size and dimension size are 0, because
947   // a call to split would discard the number of chunks (because we can have an arbitrary number of
948   // 0-sized chunks adding up to 0).  So, call split_with_sizes with the correct number of chunks,
949   // eventually we will do this for all cases.
950   if (split_size == 0 && dim_size == 0) {
951     std::vector<c10::SymInt> split_sizes(chunks, split_size);
952     split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size);
953     return self.split_with_sizes_symint(split_sizes, dim);
954   } else {
955     return self.split_symint(std::move(split_size), dim);
956   }
957 }
958 
tensor_split_sections_symint(const Tensor & self,c10::SymInt sym_sections,int64_t dim)959 std::vector<Tensor> tensor_split_sections_symint(const Tensor& self, c10::SymInt sym_sections, int64_t dim) {
960   TORCH_CHECK(self.dim() > 0, "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ", self.dim()," dims");
961   int64_t dim_ = maybe_wrap_dim(dim, self.dim());
962   // NB: intentional, sections specifies number of output tensors, which
963   // cannot be polymorphic
964   int64_t sections = sym_sections.guard_int(__FILE__, __LINE__);
965   TORCH_CHECK(sections > 0, "number of sections must be larger than 0, got ", sections);
966   const auto dim_size = self.sym_size(dim_);
967   std::vector<Tensor> splits(sections);
968   auto min_split_size = dim_size / sections;
969   auto num_splits_one_extra = dim_size % sections;
970   c10::SymInt start_idx = 0;
971   for (const auto split_idx : c10::irange(sections)) {
972     auto split_size = (num_splits_one_extra > split_idx) ? (min_split_size + 1) : min_split_size;
973     splits[split_idx] = at::slice_symint(self, dim_, start_idx, start_idx + split_size);
974     start_idx += split_size;
975   }
976   return splits;
977 }
978 
979 template <typename T>
_tensor_split_indices(const Tensor & self,ArrayRef<T> indices,int64_t dim)980 std::vector<Tensor> _tensor_split_indices(const Tensor& self, ArrayRef<T> indices, int64_t dim) {
981   TORCH_CHECK(self.dim() > 0, "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ", self.dim()," dims");
982   int64_t dim_ = maybe_wrap_dim(dim, self.dim());
983   int64_t num_indices = indices.size();
984   std::vector<Tensor> splits(num_indices + 1);
985   T start_idx(0);
986   for (const auto split_idx : c10::irange(num_indices)) {
987     auto end_idx = indices[split_idx];
988     splits[split_idx] = at::symint::slice<T>(self, dim_, start_idx, end_idx);
989     start_idx = end_idx;
990   }
991   splits[num_indices] = at::symint::slice<T>(self, dim_, start_idx, at::symint::size<T>(self, dim_));
992   return splits;
993 }
994 
tensor_split(const Tensor & self,IntArrayRef indices,int64_t dim)995 std::vector<Tensor> tensor_split(const Tensor& self, IntArrayRef indices, int64_t dim) {
996   return _tensor_split_indices(self, indices, dim);
997 }
998 
tensor_split_indices_symint(const Tensor & self,SymIntArrayRef indices,int64_t dim)999 std::vector<Tensor> tensor_split_indices_symint(const Tensor& self, SymIntArrayRef indices, int64_t dim) {
1000   return _tensor_split_indices(self, indices, dim);
1001 }
1002 
tensor_split(const Tensor & self,const Tensor & tensor_indices_or_sections,int64_t dim)1003 std::vector<Tensor> tensor_split(const Tensor& self, const Tensor& tensor_indices_or_sections, int64_t dim) {
1004   TORCH_CHECK(self.dim() > 0, "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ", self.dim()," dims");
1005   auto split_device = tensor_indices_or_sections.device();
1006   TORCH_CHECK(split_device == kCPU,
1007     "tensor_split expected tensor_indices_or_sections to be on cpu, but it's on ", split_device);
1008   auto split_dtype = tensor_indices_or_sections.scalar_type();
1009   TORCH_CHECK(split_dtype == at::kLong,
1010     "tensor_split expected tensor_indices_or_sections to have dtype of long, but got ", split_dtype);
1011   auto split_dim = tensor_indices_or_sections.dim();
1012   TORCH_CHECK(split_dim == 1 || split_dim == 0,
1013     "tensor_split expected tensor_indices_or_sections to be a zero-dimensional or one-dimensional tensor, but got a tensor with ", split_dim, " dims");
1014 
1015   if (split_dim == 0) {
1016     int64_t sections = tensor_indices_or_sections.item<int64_t>();
1017     return self.tensor_split(sections, dim);
1018   } else {
1019     auto indices_data = tensor_indices_or_sections.const_data_ptr<int64_t>();
1020     auto stride = tensor_indices_or_sections.stride(0);
1021     auto numel = tensor_indices_or_sections.numel();
1022     std::vector<int64_t> indices(numel);
1023     for (const auto offset : c10::irange(numel)) {
1024       // indices tensor could be non-contiguous
1025       indices[offset] = *(indices_data + offset * stride);
1026     }
1027     return self.tensor_split(indices, dim);
1028   }
1029 }
1030 
unsafe_chunk(const Tensor & self,int64_t chunks,int64_t dim)1031 std::vector<Tensor> unsafe_chunk(const Tensor& self, int64_t chunks, int64_t dim) {
1032   TORCH_CHECK(self.dim() > 0,
1033            "chunk expects at least a 1-dimensional tensor");
1034   TORCH_CHECK(chunks > 0,
1035            "chunk expects `chunks` to be greater than 0, got: ", chunks);
1036 
1037   const auto dim_size = self.size(dim);
1038   int64_t split_size = (dim_size + chunks - 1) / chunks;
1039 
1040   // See the comment above in chunk(...)
1041   if (split_size == 0 && dim_size == 0) {
1042     std::vector<int64_t> split_sizes(chunks, split_size);
1043     split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size);
1044     return self.unsafe_split_with_sizes(split_sizes, dim);
1045   } else {
1046     return self.unsafe_split(split_size, dim);
1047   }
1048 }
1049 
diagflat(const Tensor & self,int64_t offset)1050 Tensor diagflat(const Tensor& self, int64_t offset) {
1051   return self.contiguous().view(-1).diag(offset);
1052 }
1053 
diagonal(const Tensor & self,int64_t offset,int64_t dim1_,int64_t dim2_)1054 Tensor diagonal(const Tensor& self, int64_t offset, int64_t dim1_, int64_t dim2_) {
1055   int64_t nDims = self.dim();
1056   int64_t dim1 = maybe_wrap_dim(dim1_, nDims);
1057   int64_t dim2 = maybe_wrap_dim(dim2_, nDims);
1058   TORCH_CHECK(dim1 != dim2, "diagonal dimensions cannot be identical ", dim1_, ", ", dim2_);
1059   auto outnames = namedinference::compute_diagonal_outnames(self, dim1, dim2);
1060   NoNamesGuard no_names_guard;
1061 
1062   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
1063   int64_t diag_size;
1064   int64_t storage_offset = self.storage_offset();
1065   // compute storage offset and size for the diagonal
1066   // for positive values of offset (above the main diagonal)
1067   // "leftmost columns" (along dim2) are dropped
1068   // for negative values of offset (below the main diagonal)
1069   // "topmost rows" (along dim1) are dropped.
1070   // Note that we invert +/- in the second to absorb the negative
1071   // sign in the offset.
1072   if (offset >= 0) {
1073     diag_size = std::max<int64_t>(std::min(self.size(dim1), self.size(dim2)-offset), 0);
1074   } else {
1075     diag_size = std::max<int64_t>(std::min(self.size(dim1)+offset, self.size(dim2)), 0);
1076   }
1077 
1078   // NumPy allows you to specify offsets "off the end"; let's just be careful not to
1079   // set a ridiculous storage_offset in that case (technically it shouldn't matter
1080   // because there are no elements in the tensor, but let's be kosher).
1081   if (diag_size == 0) {
1082     // skip
1083   } else if (offset >= 0) {
1084     storage_offset += offset * self.stride(dim2);
1085   } else {
1086     storage_offset -= offset * self.stride(dim1);
1087   }
1088 
1089   // construct new size and stride: we drop dim1 and dim2 (maximum first for not changing the index of the minimum)
1090   // the new ("joint") dimension is appended to the end of the shape / stride to match numpy semantics
1091   DimVector sizes(self.sizes().begin(), self.sizes().end());
1092   DimVector strides(self.strides().begin(), self.strides().end());
1093   sizes.erase(sizes.begin() + std::max(dim1, dim2));
1094   strides.erase(strides.begin() + std::max(dim1, dim2));
1095   sizes.erase(sizes.begin() + std::min(dim1, dim2));
1096   strides.erase(strides.begin() + std::min(dim1, dim2));
1097   sizes.push_back(diag_size);
1098   strides.push_back(self.stride(dim1)+self.stride(dim2));
1099 
1100   // return view with new parameters
1101   auto result = self.as_strided(sizes, strides, storage_offset);
1102 
1103   no_names_guard.reset();
1104   namedinference::propagate_names_if_nonempty(result, outnames);
1105   return result;
1106 }
1107 
diagonal(const Tensor & self,Dimname outdim,Dimname dim1,Dimname dim2,int64_t offset)1108 Tensor diagonal(const Tensor& self, Dimname outdim, Dimname dim1, Dimname dim2, int64_t offset) {
1109   auto result = at::diagonal(
1110       self,
1111       offset,
1112       dimname_to_position(self, dim1),
1113       dimname_to_position(self, dim2));
1114   // This is slower than it needs to be because there is no way to modify
1115   // the names of a tensor in-place right now. In the future we should consider
1116   // offering that functionality.
1117   std::vector<Dimname> new_names = result.names().vec();
1118   new_names[new_names.size() - 1] = outdim;
1119   return result.refine_names(new_names);
1120 }
1121 
diag_embed(const Tensor & self,int64_t offset,int64_t dim1_,int64_t dim2_)1122 Tensor diag_embed(const Tensor& self, int64_t offset, int64_t dim1_, int64_t dim2_) {
1123   int64_t nDims = self.dim() + 1;
1124   int64_t dim1 = maybe_wrap_dim(dim1_, nDims);
1125   int64_t dim2 = maybe_wrap_dim(dim2_, nDims);
1126   TORCH_CHECK(dim1 != dim2, "diagonal dimensions cannot be identical ", dim1_, ", ", dim2_);
1127   int64_t new_dim_len = std::abs(offset) + self.size(-1);
1128   auto sizes = self.sizes().vec();
1129   sizes.pop_back();
1130   sizes.insert(sizes.begin() + std::min(dim1, dim2), new_dim_len);
1131   sizes.insert(sizes.begin() + std::max(dim1, dim2), new_dim_len);
1132   auto result = at::zeros(sizes, self.options());
1133   auto diag = result.diagonal(offset, dim1, dim2);
1134   diag.copy_(self);
1135   return result;
1136 }
1137 
expand(const Tensor & self,c10::IntArrayRef size,bool)1138 Tensor expand(const Tensor& self, c10::IntArrayRef size, bool /*unused*/) {
1139   TORCH_CHECK(size.size() >= (size_t)self.dim(),
1140            "expand(", self.toString(), "{", self.sizes(), "}, size=", size,
1141            "): the number of sizes provided (", size.size(), ") ",
1142            "must be greater or equal to the number of dimensions in the tensor (",
1143            self.dim(), ")");
1144   TORCH_CHECK(!self.is_sparse() && !at::sparse_csr::is_sparse_compressed(self),
1145             "expand is unsupported for ", self.layout(), " tensors");
1146 
1147   auto expandedSizesAndStrides = inferExpandGeometry_dimvector(self.sizes(), self.strides(), size);
1148 
1149   auto result = self.as_strided(
1150       expandedSizesAndStrides.sizes, expandedSizesAndStrides.strides);
1151   namedinference::propagate_names_for_expand(result, self);
1152   return result;
1153 }
1154 
expand_as(const Tensor & self,const Tensor & other)1155 Tensor expand_as(const Tensor& self, const Tensor& other) {
1156   return self.expand_symint(other.sym_sizes());
1157 }
1158 
sum_to_size_symint(const Tensor & self,SymIntArrayRef size)1159 Tensor sum_to_size_symint(const Tensor& self, SymIntArrayRef size) {
1160   TORCH_CHECK(is_expandable_to(size, self.sym_sizes()),
1161            "size {", size, "} is not expandable to size {", self.sizes(), "}.");
1162 
1163   return sum_to(self, size);
1164 }
1165 
1166 // We currently do not support per-channel quant for unfold, diagonal, expand, permute.
1167 // TODO: Make this an aten function and replace as_strided_qtensorimpl once that is done.
make_qtensor(const Tensor & self,IntArrayRef size,IntArrayRef stride,QuantizerPtr quantizer)1168 static Tensor make_qtensor(const Tensor& self, IntArrayRef size, IntArrayRef stride, QuantizerPtr quantizer) {
1169   auto result = at::detail::make_tensor<QTensorImpl>(
1170       c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype(), quantizer);
1171   setStrided(result, size, stride, self.storage_offset());
1172   return result;
1173 }
1174 
as_strided_tensorimpl(const Tensor & self,IntArrayRef size,IntArrayRef stride,std::optional<int64_t> storage_offset_)1175 Tensor as_strided_tensorimpl(const Tensor& self, IntArrayRef size, IntArrayRef stride, std::optional<int64_t> storage_offset_) {
1176   TORCH_INTERNAL_ASSERT(!self.is_mps(), "as_strided_tensorimpl does not work with MPS; call self.as_strided(...) instead");
1177   auto storage_offset = storage_offset_.value_or(self.storage_offset());
1178   auto result = at::detail::make_tensor<TensorImpl>(
1179       c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype());
1180   setStrided(result, size, stride, storage_offset);
1181   return result;
1182 }
1183 
1184 template <typename T>
setStridedUnchecked(const Tensor & self,ArrayRef<T> size,ArrayRef<T> stride,T && storage_offset)1185 inline void setStridedUnchecked(
1186     const Tensor& self,
1187     ArrayRef<T> size,
1188     ArrayRef<T> stride,
1189     T&& storage_offset) {
1190   auto* self_ = self.unsafeGetTensorImpl();
1191   self_->set_sizes_and_strides(size, stride, std::make_optional(std::forward<T>(storage_offset)));
1192 }
1193 
as_strided_tensorimpl_meta_symint(const Tensor & self,SymIntArrayRef sym_size,SymIntArrayRef sym_stride,std::optional<c10::SymInt> sym_storage_offset_)1194 Tensor as_strided_tensorimpl_meta_symint(const Tensor& self, SymIntArrayRef sym_size, SymIntArrayRef sym_stride, std::optional<c10::SymInt> sym_storage_offset_) {
1195   auto sym_storage_offset = sym_storage_offset_.value_or(self.sym_storage_offset());
1196   auto result = at::detail::make_tensor<TensorImpl>(
1197       c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype());
1198   // NB: The reason this is unchecked is to ensure we don't generate
1199   // guards on the base storage itself when performing as_strided calls.
1200   // Although technically these guards are necessary, in practice they
1201   // cause a lot of guards that falsely refer to base symbols.  We will instead
1202   // rely on AOTAutograd to sort out if we actually have dependence on view
1203   // bases / storage size.
1204   setStridedUnchecked(result, sym_size, sym_stride, std::move(sym_storage_offset));
1205   return result;
1206 }
1207 
as_strided_qtensorimpl(const Tensor & self,IntArrayRef size,IntArrayRef stride,std::optional<int64_t> storage_offset_)1208 Tensor as_strided_qtensorimpl(const Tensor& self, IntArrayRef size, IntArrayRef stride, std::optional<int64_t> storage_offset_) {
1209   auto storage_offset = storage_offset_.value_or(self.storage_offset());
1210   auto quantizer = get_qtensorimpl(self)->quantizer();
1211   TORCH_CHECK(
1212       quantizer->qscheme() == QScheme::PER_TENSOR_AFFINE,
1213       "Setting strides is possible only on uniformly quantized tensor");
1214   auto result = at::detail::make_tensor<QTensorImpl>(
1215       c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype(), quantizer);
1216   setStrided(result, size, stride, storage_offset);
1217   return result;
1218 }
1219 
1220 // This is an overloaded function similar to
1221 // Tensor as_strided_qtensorimpl(const Tensor& self, IntArrayRef size, IntArrayRef stride, std::optional<int64_t> storage_offset_)
1222 // and is currently not available through the dispatcher. The additional
1223 // input, quantizer, is called by the select & slice methods.
1224 // TODO: Make this function compatible with the dispatcher
as_strided_qtensorimpl(const Tensor & self,IntArrayRef size,IntArrayRef stride,std::optional<int64_t> storage_offset_,QuantizerPtr quantizer)1225 static Tensor as_strided_qtensorimpl(const Tensor& self, IntArrayRef size, IntArrayRef stride, std::optional<int64_t> storage_offset_,
1226   QuantizerPtr quantizer) {
1227   auto storage_offset = storage_offset_.value_or(self.storage_offset());
1228   TORCH_CHECK(
1229       (quantizer->qscheme() == QScheme::PER_TENSOR_AFFINE) ||
1230       (quantizer->qscheme() == QScheme::PER_CHANNEL_AFFINE),
1231       "Setting strides is possible only on uniformly or per channel quantized tensors");
1232   auto result = at::detail::make_tensor<QTensorImpl>(
1233       c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype(), quantizer);
1234   setStrided(result, size, stride, storage_offset);
1235   return result;
1236 }
1237 
as_strided__symint(const Tensor & self,SymIntArrayRef size,SymIntArrayRef stride,std::optional<c10::SymInt> storage_offset_)1238 const Tensor &as_strided__symint(const Tensor& self, SymIntArrayRef size, SymIntArrayRef stride, std::optional<c10::SymInt> storage_offset_) {
1239   auto storage_offset = storage_offset_.value_or(self.sym_storage_offset());
1240   setStrided(self, size, stride, std::move(storage_offset));
1241   return self;
1242 }
1243 
1244 // Should just use narrow_copy_out, but this API is used internally at Meta:
1245 // https://github.com/pytorch/pytorch/pull/87045#issuecomment-1309353561
narrow_copy_dense_cpu(const Tensor & self,int64_t dim,int64_t start,int64_t length)1246 Tensor narrow_copy_dense_cpu(const Tensor& self, int64_t dim, int64_t start, int64_t length){
1247   // narrow_copy_dense_cpu_out always resize output's size, so there only create
1248   // a zero size tensor.
1249   auto output = at::empty({0}, self.options());
1250   return narrow_copy_dense_cpu_out(self, dim, start, length, output);
1251 }
1252 
narrow_copy_sparse(const Tensor & self,int64_t dim,int64_t start,int64_t length)1253 Tensor narrow_copy_sparse(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
1254   int64_t allDim = self.dim();
1255   int64_t end = start+length;
1256   TORCH_CHECK(allDim > 0, "narrow() cannot be applied to a 0-dim tensor.");
1257   TORCH_CHECK(length >= 0, "narrow(): length must be non-negative.");
1258   TORCH_CHECK(dim >= 0 && dim < allDim,
1259     "Dimension ", dim, " out of range. Expecting 0 <= dim < ", allDim, ".");
1260   TORCH_CHECK(start >= 0 && end <= self.size(dim),
1261     "Invalid range to narrow. range(start, start+length) must be a subset of range(0, ", self.size(dim), ").")
1262   Tensor indices = self._indices();
1263   int64_t sparse_dim = self.sparse_dim();
1264 
1265   std::vector<int64_t> new_sizes = self.sizes().vec();
1266   new_sizes[dim] = length;
1267 
1268   Tensor new_values;
1269   Tensor new_indices;
1270   if (dim < sparse_dim) {
1271     Tensor mask = (indices[dim] >= start).__and__((indices[dim] < end));
1272     new_indices = indices.masked_select(mask).view({sparse_dim, -1});
1273     new_indices[dim].sub_(start);
1274     Tensor nzIndices = mask.nonzero().view(-1);
1275     new_values = self._values().index_select(0, nzIndices);
1276   } else {
1277     /* This means we are narrowing on a dense dim, which is in effect just a
1278         regular narrow on _values() */
1279     new_indices = indices;
1280     int64_t dense_dim = dim - sparse_dim + 1;
1281     new_values = self._values().narrow_copy(dense_dim, start, length);
1282   }
1283 
1284   return at::sparse_coo_tensor(new_indices, new_values, new_sizes, self.options(), self.is_coalesced());
1285 }
1286 
1287 // Should just use narrow_copy_out, but this API is used internally at Meta:
1288 // https://github.com/pytorch/pytorch/pull/87045#issuecomment-1309353561
narrow_copy_dense_cpu_out(const Tensor & self,int64_t dim,int64_t start,int64_t length,Tensor & output)1289 Tensor& narrow_copy_dense_cpu_out(
1290   const Tensor& self, int64_t dim, int64_t start, int64_t length, Tensor& output
1291 ) {
1292 
1293   TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
1294   TORCH_CHECK(self.dtype() == output.dtype());
1295 
1296   auto self_contig = self.expect_contiguous();
1297   const auto self_sizes = self_contig->sizes();
1298 
1299   // wrap dim if negative and do bound check
1300   if (dim < 0) {
1301     dim = at::maybe_wrap_dim(dim, self_sizes.size());
1302   } else {
1303     TORCH_CHECK(dim < static_cast<int64_t>(self_sizes.size()));
1304   }
1305 
1306   // wrap start and do bound check
1307   const auto cur_size = self_sizes[dim];
1308   TORCH_CHECK_INDEX(
1309     -cur_size <= start && start <= cur_size,
1310     "start out of range (expected to be in range of [", -cur_size, ", ", cur_size, "], but got ", start, ")"
1311   )
1312   if (start < 0) {
1313     start = start + cur_size;
1314   }
1315   TORCH_CHECK(
1316       length >= 0 && start <= cur_size - length,
1317       "start (",
1318       start,
1319       ") + length (",
1320       length,
1321       ") exceeds dimension size (",
1322       cur_size,
1323       ").");
1324 
1325   // resize output
1326   auto output_sizes = self_sizes.vec();
1327   output_sizes[dim] = length;
1328   at::native::resize_(output, output_sizes);
1329 
1330   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1331   const int64_t unit = c10::size_from_dim_(dim + 1, self_sizes);
1332   const int64_t num_blocks = c10::size_to_dim_(dim, self_sizes);
1333 
1334   const auto itemsize = self_contig->dtype().itemsize();
1335   // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
1336   size_t src_nbytes = itemsize * self_contig->numel();
1337   // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
1338   size_t dst_nbytes = itemsize * output.numel();
1339 
1340   size_t src_block_size = unit * self_sizes[dim];
1341   size_t dst_block_size = unit * length;
1342 
1343   if (num_blocks == 0 || dst_block_size == 0) {
1344     return output;
1345   }
1346 
1347   const char* src_bytes = static_cast<const char*>(self_contig->const_data_ptr());
1348   char* dst_bytes = static_cast<char*>(output.data_ptr());
1349 
1350   size_t src_block_size_bytes = itemsize * src_block_size;
1351   size_t dst_block_size_bytes = itemsize * dst_block_size;
1352   size_t src_offset = unit * start;
1353 
1354   const char* src_offset_bytes = src_bytes + itemsize * src_offset;
1355   char* dst_offset_bytes = dst_bytes;
1356 
1357   for (const auto i : c10::irange(num_blocks)) {
1358     const char* local_src_offset_bytes = src_offset_bytes + i * src_block_size_bytes;
1359     char* local_dst_offset_bytes = dst_offset_bytes + i * dst_block_size_bytes;
1360     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
1361         static_cast<const void*>(local_src_offset_bytes + dst_block_size_bytes) <=
1362         static_cast<const void*>(src_bytes + src_nbytes));
1363     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
1364         static_cast<void*>(local_dst_offset_bytes + dst_block_size_bytes) <=
1365         static_cast<void*>(dst_bytes + dst_nbytes));
1366 
1367     memcpy(
1368         local_dst_offset_bytes, local_src_offset_bytes, dst_block_size_bytes);
1369   }
1370   return output;
1371 }
1372 
narrow(const Tensor & self,int64_t dim,int64_t start,int64_t length)1373 Tensor narrow(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
1374   TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
1375   TORCH_CHECK(length >= 0, "narrow(): length must be non-negative.");
1376   auto cur_size = self.size(dim);
1377   TORCH_CHECK_INDEX(
1378     -cur_size <= start && start <= cur_size,
1379     "start out of range (expected to be in range of [", -cur_size, ", ", cur_size, "], but got ", start, ")"
1380   )
1381   if (start < 0) {
1382     start = start + cur_size;
1383   }
1384   TORCH_CHECK(start <= cur_size - length,
1385            "start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ").");
1386   return at::slice(self, dim, start, start + length, 1);
1387 }
1388 
narrow_symint(const Tensor & self,int64_t dim,SymInt start,SymInt length)1389 Tensor narrow_symint(const Tensor& self, int64_t dim, SymInt start, SymInt length) {
1390   TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
1391   TORCH_SYM_CHECK(length.sym_ge(0), "narrow(): length must be non-negative.");
1392   auto cur_size = self.sym_size(dim);
1393   TORCH_CHECK_INDEX(
1394     ((-cur_size).sym_le(start).sym_and(start.sym_le(cur_size))).expect_true(__FILE__, __LINE__),
1395     "start out of range (expected to be in range of [", -cur_size, ", ", cur_size, "], but got ", start, ")"
1396   )
1397   if (start < 0) {
1398     start = start + cur_size;
1399   }
1400   TORCH_SYM_CHECK(start.sym_le(cur_size - length),
1401            "start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ").");
1402   return at::slice_symint(self, dim, start, start + length, 1);
1403 }
1404 
1405 // This overload exists purely for XLA, because they wanted to pass in "symbolic"
1406 // start via Tensor.
narrow_tensor_symint(const Tensor & self,int64_t dim,const Tensor & start,SymInt length)1407 Tensor narrow_tensor_symint(const Tensor& self, int64_t dim, const Tensor& start, SymInt length) {
1408   TORCH_CHECK(start.dim() == 0 && isIntegralType(start.scalar_type(), /*includeBool=*/false),
1409               "start must be an 0-dim integral Tensor.");
1410   int64_t st = start.item<int64_t>();
1411   return at::narrow_symint(self, dim, c10::SymInt(st), std::move(length));
1412 }
1413 
1414 std::tuple<DimVector, DimVector, std::vector<int64_t>>
_permute_size_stride_estimation(const Tensor & self,IntArrayRef dims)1415 static _permute_size_stride_estimation(const Tensor& self, IntArrayRef dims) {
1416   const auto ndim = self.dim();
1417   TORCH_CHECK(ndim == static_cast<int64_t>(dims.size()),
1418       "permute(sparse_coo): number of dimensions in the tensor input ",
1419       "does not match the length of the desired ordering of dimensions ",
1420       "i.e. input.dim() = ", ndim, " is not equal to len(dims) = ", dims.size());
1421 
1422   const auto is_strided_layout = self.options().layout() == at::kStrided;
1423   const auto old_sizes = self.sizes();
1424   const auto old_strides = is_strided_layout ? self.strides() : IntArrayRef{};
1425 
1426   auto new_sizes = DimVector(ndim);
1427   auto new_strides = DimVector(is_strided_layout ? ndim : 0);
1428   auto wrapped_dims = std::vector<int64_t>(ndim);
1429   std::vector<bool> seen_dims(ndim);
1430 
1431   for (const auto i : c10::irange(ndim)) {
1432     const auto d = maybe_wrap_dim(dims[i], ndim);
1433     TORCH_CHECK(!seen_dims[d],
1434         "permute(): duplicate dims are not allowed.");
1435     seen_dims[d] = true;
1436     wrapped_dims[i] = d;
1437     new_sizes[i] = old_sizes[d];
1438     if (is_strided_layout) {
1439       new_strides[i] = old_strides[d];
1440     }
1441   }
1442 
1443   return std::make_tuple(new_sizes, new_strides, wrapped_dims);
1444 }
1445 
permute(const Tensor & self,IntArrayRef dims)1446 Tensor permute(const Tensor& self, IntArrayRef dims) {
1447   auto [new_sizes, new_strides, _] = _permute_size_stride_estimation(self, dims);
1448   return self.as_strided(new_sizes, new_strides);
1449 }
1450 
permute_sparse_coo(const Tensor & self,IntArrayRef dims)1451 Tensor permute_sparse_coo(const Tensor& self, IntArrayRef dims) {
1452   auto [new_sizes, _, wrapped_dims] = _permute_size_stride_estimation(self, dims);
1453 
1454   const auto ndim = self.dim();
1455   const auto sparse_ndim = self.sparse_dim();
1456   const auto dense_ndim = self.dense_dim();
1457 
1458   auto dims_id_perm = std::vector<int64_t>(ndim);
1459   auto dims_sparse_dense_id_perm = std::vector<int64_t>(ndim);
1460   for (const auto i : c10::irange(ndim)) {
1461     dims_id_perm[i] = i;
1462     dims_sparse_dense_id_perm[i] = wrapped_dims[i];
1463   }
1464   std::sort(dims_sparse_dense_id_perm.begin(), dims_sparse_dense_id_perm.begin() + sparse_ndim);
1465   std::sort(dims_sparse_dense_id_perm.begin() + sparse_ndim, dims_sparse_dense_id_perm.end());
1466   TORCH_CHECK(dims_sparse_dense_id_perm == dims_id_perm,
1467       "permute(sparse_coo): transpositions between sparse and dense dimensions are not allowed.",
1468       "Only transpositions within sparse and dense dimensions are supported.");
1469 
1470   const auto slice = [](std::vector<int64_t> v, size_t begin, size_t len) -> decltype(v) {
1471     return std::vector<int64_t>{v.begin() + begin, v.begin() + begin + len};
1472   };
1473 
1474   auto old_sparse_dims = slice(dims_id_perm, 0, sparse_ndim);
1475   auto old_dense_dims = slice(std::move(dims_id_perm), sparse_ndim, ndim - sparse_ndim);
1476   auto new_sparse_dims = slice(wrapped_dims, 0, sparse_ndim);
1477   auto new_dense_dims = slice(std::move(wrapped_dims), sparse_ndim, ndim - sparse_ndim);
1478 
1479   auto old_indices = self._indices();
1480   auto old_values = self._values();
1481 
1482   const auto new_indices = (new_sparse_dims == old_sparse_dims)
1483     ? std::move(old_indices)
1484     : [&]() -> Tensor {
1485       auto sparse_perm_tensor = at::from_blob(reinterpret_cast<void*>(new_sparse_dims.data()),
1486           {sparse_ndim}, old_indices.options().device(at::kCPU));
1487       // creates new indices. It is possible to avoid that if COO
1488       // is allowed to store a permutation vector.
1489       return old_indices.index_select(0, sparse_perm_tensor.to(self.device().type()));
1490     }();
1491   const auto new_values = (new_dense_dims == old_dense_dims)
1492     ? std::move(old_values)
1493     : [&]() -> Tensor {
1494       auto values_perm = std::vector<int64_t>(dense_ndim + 1);
1495       for (const auto i : c10::irange(dense_ndim)) {
1496         values_perm[i + 1] = new_dense_dims[i] - sparse_ndim + 1;
1497       }
1498       return old_values.permute(values_perm);
1499     }();
1500   const auto is_coalesced = self.is_coalesced() && (dims.empty() || dims[0] == 0);
1501   // TODO: apply `is_coalesced ||= new_values.size(0) < 2`.
1502   return _sparse_coo_tensor_with_dims_and_tensors(
1503        sparse_ndim, dense_ndim, new_sizes, new_indices, new_values, self.options(), is_coalesced);
1504 }
1505 
repeat(const Tensor & self,IntArrayRef repeats)1506 Tensor repeat(const Tensor& self, IntArrayRef repeats) {
1507   TORCH_CHECK(repeats.size() >= (size_t)self.dim(),
1508            "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor");
1509 
1510   // Add new leading dimensions to the tensor if the
1511   // number of target dimensions is larger than the
1512   // number of source dimensions.
1513   int64_t num_new_dimensions = repeats.size() - self.dim();
1514   DimVector padded_size(num_new_dimensions, 1);
1515   padded_size.insert(padded_size.end(), self.sizes().begin(), self.sizes().end());
1516   DimVector target_size(repeats.size());
1517   bool zero_tensor = false;
1518   for(const auto idx : c10::irange(repeats.size())) {
1519     if (repeats[idx] == 0) {
1520       zero_tensor = true;
1521     }
1522     target_size[idx] = padded_size[idx] * repeats[idx];
1523   }
1524 
1525   Tensor xtensor = self.expand(padded_size);
1526 
1527   Tensor result;
1528   if (self.is_quantized()) {
1529     result = at::empty_quantized(target_size, self);
1530   } else {
1531     result = at::empty(target_size, self.options());
1532   }
1533 
1534   // return an empty tensor if one of the repeat dimensions is zero
1535   if (zero_tensor) {
1536     return result;
1537   }
1538 
1539   Tensor urtensor = at::alias(result);
1540   for (const auto i : c10::irange(xtensor.dim())) {
1541     // can't unfold with step 0, so make sure step is at least 1
1542     // (it doesn't matter what it is in that case, because the size is 0).
1543     auto size_i = xtensor.sizes()[i];
1544     urtensor = urtensor.unfold(i, size_i, std::max<int64_t>(size_i, 1));
1545   }
1546 
1547   urtensor.copy_(xtensor.expand_as(urtensor));
1548 
1549   return result;
1550 }
1551 
tile_symint(const Tensor & self,SymIntArrayRef reps)1552 Tensor tile_symint(const Tensor& self, SymIntArrayRef reps){
1553   // If self.size() > len(reps), reps is promoted to self.size() by pre-pending
1554   // 1’s to it to keep the same behaviour as `numpy.tile`.
1555   // Thus for a tensor of shape (2, 3, 4, 5), a dims of (2, 2) is treated
1556   // as (1, 1, 2, 2).
1557   const int64_t size_diff = self.dim() - static_cast<int64_t>(reps.size());
1558   if (size_diff > 0){
1559     std::vector<c10::SymInt> new_reps(size_diff, 1);
1560     for (const auto i : c10::irange(reps.size())) {
1561       new_reps.emplace_back(reps[i]);
1562     }
1563     return self.repeat_symint(SymIntArrayRef(new_reps));
1564   }
1565   // `torch.tile` is equivalent to the already implemented `torch.Tensor.repeat`
1566   return self.repeat_symint(reps);
1567 }
1568 
1569 //
1570 // templated for ArrayRef<int64_t> and SmallVector<int64_t> use cases
1571 //
1572 template <typename Vec>
alias_with_sizes_and_strides(const Tensor & self,const Vec & sizes,const Vec & strides)1573 Tensor alias_with_sizes_and_strides(
1574     const Tensor& self,
1575     const Vec& sizes,
1576     const Vec& strides) {
1577   //caller should make sure that sizes and strides are valid for self
1578   //(storage is sufficient, strides are non-negative, strides and sizes array size is the same)
1579   Tensor self_;
1580   if (self.is_quantized()) {
1581     self_ = at::detail::make_tensor<QTensorImpl>(
1582       c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype(), get_qtensorimpl(self)->quantizer());
1583     auto* self_tmp_ = self_.unsafeGetTensorImpl();
1584     self_tmp_->set_storage_offset(self.storage_offset());
1585     self_tmp_->set_sizes_and_strides(sizes, strides);
1586   } else {
1587     self_ = at::detail::make_tensor<TensorImpl>(
1588       c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype());
1589     auto* self_tmp_ = self_.unsafeGetTensorImpl();
1590     self_tmp_->set_storage_offset(self.storage_offset());
1591     self_tmp_->set_sizes_and_strides(sizes, strides);
1592   }
1593   namedinference::propagate_names(self_, self);
1594   return self_;
1595 }
1596 
1597 // specialization for symbolic shapes and strides.
1598 // SymIntArrayRef/ArrayRef<c10::SymInt> and SmallVector<c10::SymInt>/SymDimVector
1599 template <template <typename...> typename Container>
alias_with_sizes_and_strides(const Tensor & self,const Container<c10::SymInt> & sizes,const Container<c10::SymInt> & strides)1600 Tensor alias_with_sizes_and_strides(
1601     const Tensor& self,
1602     const Container<c10::SymInt>& sizes,
1603     const Container<c10::SymInt>& strides) {
1604   //caller should make sure that sizes and strides are valid for self
1605   //(storage is sufficient, strides are non-negative, strides and sizes array size is the same)
1606   Tensor self_;
1607   if (self.is_quantized()) {
1608     self_ = at::detail::make_tensor<QTensorImpl>(
1609       c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype(), get_qtensorimpl(self)->quantizer());
1610     self_.unsafeGetTensorImpl()->set_sizes_and_strides(sizes, strides, self.sym_storage_offset());
1611   } else {
1612     self_ = at::detail::make_tensor<TensorImpl>(
1613     c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype());
1614     self_.unsafeGetTensorImpl()->set_sizes_and_strides(sizes, strides, self.sym_storage_offset());
1615   }
1616   namedinference::propagate_names(self_, self);
1617   return self_;
1618 }
1619 
reshape_symint(const Tensor & self,c10::SymIntArrayRef proposed_shape)1620 Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) {
1621   if (self.is_sparse()) {
1622     TORCH_CHECK(false, "reshape is not implemented for sparse tensors");
1623   }
1624 
1625   if (self.is_contiguous() && !self.is_mkldnn()) {
1626     return self.view_symint(proposed_shape);
1627   }
1628 
1629   c10::SymDimVector shape = infer_size_dv(proposed_shape, self.sym_numel());
1630 
1631   if (self.is_mkldnn()) {
1632     return at::_mkldnn_reshape(self, C10_AS_INTARRAYREF_SLOW(shape));
1633   }
1634 
1635   // `computeStride` returns the proper strides to use if this
1636   // `reshape` can be just a view.
1637   auto stride = at::detail::computeStride(self.sym_sizes(), self.sym_strides(), shape);
1638 
1639   // NB: Even though we have viewable geometry and the target strides here,
1640   //     we do not just call `as_strided` on `self` because the backward
1641   //     for `as_strided` is not as efficient as that of `view` (since the
1642   //     former is meant to handle general cases).
1643   //
1644   //     Similarly we don't call `view` because it duplicates some of the work
1645   //     we've already done, and instead call our internal/private operator
1646   //     `_reshape_alias` that essentially does the same thing as `view` and
1647   //     `as_strided` without any of the extra overhead.
1648   if (stride.has_value()) {
1649     // Temporary check to revert to the old behavior/view in cases where the
1650     // device is not supported (e.g. for XLA the operation is not supported
1651     // so we use `view` instead).
1652     //
1653     // We need to do the checks here instead of in `native_functions.yaml`
1654     // to preserve backwards compatibility.
1655     if (!self.is_xla() && !self.is_lazy() && !self.is_ipu() && !at::isTensorSubclassLike(self)) {
1656       return self._reshape_alias_symint(shape, stride.value());
1657     } else {
1658       return self.view_symint(shape);
1659     }
1660   }
1661   return at::_unsafe_view_symint(self.clone(at::MemoryFormat::Contiguous), shape);
1662 }
1663 
_reshape_copy_symint(const Tensor & self,c10::SymIntArrayRef proposed_shape)1664 Tensor _reshape_copy_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) {
1665   if (self.is_sparse()) {
1666     TORCH_CHECK(0, "_reshape_copy is not implemented for sparse tensors");
1667   }
1668   c10::SymDimVector shape = infer_size_dv(proposed_shape, self.sym_numel());
1669 
1670   if (self.is_mkldnn()) {
1671     TORCH_CHECK(0, "_reshape_copy not implemented for mkldnn tensors");
1672   }
1673 
1674   if (self.is_contiguous()) {
1675     return self.view_symint(shape).clone(at::MemoryFormat::Contiguous);
1676   } else {
1677     return at::_unsafe_view_symint(self.clone(at::MemoryFormat::Contiguous), shape);
1678   }
1679 }
1680 
1681 // Duplicate of above code for non-symbolic ints. Kept for BC purposes and to
1682 // minimize breakages.
reshape(const Tensor & self,IntArrayRef proposed_shape)1683 Tensor reshape(const Tensor& self, IntArrayRef proposed_shape) {
1684   if (self.is_sparse()) {
1685     TORCH_CHECK(false, "reshape is not implemented for sparse tensors");
1686   }
1687   DimVector shape = infer_size_dv(proposed_shape, self.numel());
1688 
1689   if (self.is_mkldnn()) {
1690     return at::_mkldnn_reshape(self, shape);
1691   }
1692 
1693   // `computeStride` returns the proper strides to use if this
1694   // `reshape` can be just a view.
1695   auto stride = at::detail::computeStride(self.sizes(), self.strides(), shape);
1696 
1697   // NB: Even though we have viewable geometry and the target strides here,
1698   //     we do not just call `as_strided` on `self` because the backward
1699   //     for `as_strided` is not as efficient as that of `view` (since the
1700   //     former is meant to handle general cases).
1701   //
1702   //     Similarly we don't call `view` because it duplicates some of the work
1703   //     we've already done, and instead call our internal/private operator
1704   //     `_reshape_alias` that essentially does the same thing as `view` and
1705   //     `as_strided` without any of the extra overhead.
1706   if (stride.has_value()) {
1707     // Temporary check to revert to the old behavior/view in cases where the
1708     // device is not supported (e.g. for XLA the operation is not supported
1709     // so we use `view` instead).
1710     //
1711     // We need to do the checks here instead of in `native_functions.yaml`
1712     // to preserve backwards compatibility.
1713     if (!self.is_xla() && !self.is_lazy() && !self.is_ipu()) {
1714       return self._reshape_alias(shape, stride.value());
1715     } else {
1716       return self.view(shape);
1717     }
1718   }
1719   return at::_unsafe_view(self.clone(at::MemoryFormat::Contiguous), shape);
1720 }
1721 
_reshape_alias(const Tensor & self,IntArrayRef sizes,IntArrayRef strides)1722 Tensor _reshape_alias(const Tensor& self, IntArrayRef sizes, IntArrayRef strides) {
1723   // This is only used by `reshape` in cases where it would otherwise have dispatched
1724   // to `view`. This removes the overhead of calling `view` which duplicates some of
1725   // the work that's already been done (`infer_size_dv` and `computeStride`).
1726 
1727   return alias_with_sizes_and_strides(self, sizes, strides);
1728 }
1729 
reshape_as(const Tensor & self,const Tensor & other)1730 Tensor reshape_as(const Tensor& self, const Tensor& other) {
1731   return self.reshape_symint(other.sym_sizes());
1732 }
1733 
select_sparse(const Tensor & self,int64_t dim,int64_t index)1734 static Tensor select_sparse(const Tensor& self, int64_t dim, int64_t index) {
1735   int64_t sparse_dim = self.sparse_dim();
1736   int64_t dense_dim = self.dense_dim();
1737   TORCH_INTERNAL_ASSERT(dim >= 0 && dim < sparse_dim + dense_dim);
1738 
1739   auto indices = self._indices();
1740   auto values = self._values();
1741   auto new_sizes = self.sizes().vec();
1742   new_sizes.erase(new_sizes.begin() + dim);
1743 
1744   if (dim < sparse_dim) {
1745     auto nzIndices = (indices[dim] == index).nonzero().view(-1);
1746     auto new_values = values.index_select(0, nzIndices);
1747     if (sparse_dim == 1) {
1748       // return dense part:
1749       if (new_values.size(0) == 1) {
1750         return new_values[0];
1751       } else {
1752         // sum promotes integral type to int64 when dtype is not specified.
1753         return at::sum(new_values, 0, false, new_values.scalar_type());
1754       }
1755     } else {
1756       auto dimIndices = (arange(
1757                              0,
1758                              sparse_dim,
1759                              std::nullopt /* dtype */,
1760                              std::nullopt /* layout */,
1761                              self.device(),
1762                              std::nullopt /* pin_memory */) != dim)
1763                             .nonzero()
1764                             .view(-1);
1765       auto new_indices = indices.index_select(1, nzIndices).index_select(0, dimIndices);
1766       return _sparse_coo_tensor_with_dims_and_tensors(
1767             sparse_dim - 1, dense_dim, new_sizes, new_indices, new_values, self.options());
1768     }
1769   } else {
1770     auto new_values = values.select(dim - sparse_dim + 1, index);
1771     return _sparse_coo_tensor_with_dims_and_tensors(
1772          sparse_dim, dense_dim - 1, new_sizes, indices, new_values, self.options());
1773   }
1774 }
1775 
1776 // this is an auxiliary function, called by the select&slice methods, that
1777 // creates a new quantizer from the given input
1778 // is_select is true if calling function is select()
create_subtensor_quantizer(const Tensor & self,bool is_select,int64_t start,int64_t end,int64_t dim,int64_t step)1779 static QuantizerPtr create_subtensor_quantizer(const Tensor& self, bool is_select, int64_t start,
1780   int64_t end, int64_t dim, int64_t step) {
1781   auto quantizer_prev = get_qtensorimpl(self)->quantizer();
1782   if (quantizer_prev->qscheme() == QScheme::PER_TENSOR_AFFINE) {
1783     return quantizer_prev;
1784   }
1785   QuantizerPtr quantizer;
1786   auto temp = static_cast<PerChannelAffineQuantizer*>(quantizer_prev.get());
1787   auto axis = temp->axis();
1788   auto scales = temp->scales();
1789   auto zero_points = temp->zero_points();
1790   if (dim == axis) {
1791     // Compute scales&zps for sub-tensor
1792     // *.select(0, start) could alternatively be replaced with *.slice(0, start, end, step), but
1793     // select has less overhead
1794     scales = is_select ? scales.select(0, start) : scales.slice(0, start, end, step);
1795     zero_points = is_select ? zero_points.select(0, start) : zero_points.slice(0, start, end, step);
1796   }
1797   if (scales.numel() > 1) {
1798     // Axis only needs to be adjusted if the calling function is select(), since select() reduces
1799     // the number of dimensions of the tensor by 1, and remains unchanged if calling function is slice()
1800     quantizer = make_per_channel_affine_quantizer(scales, zero_points, (is_select ? axis - 1 : axis),
1801                                                   quantizer_prev->scalar_type());
1802   } else {
1803     quantizer = make_per_tensor_affine_quantizer(scales.item().to<double>(), zero_points.item().to<int64_t>(),
1804                                                  quantizer_prev->scalar_type());
1805   }
1806   return quantizer;
1807 }
1808 
select(const Tensor & self,int64_t dim,int64_t index)1809 Tensor select(const Tensor& self, int64_t dim, int64_t index) {
1810   return at::select_symint(self, dim, c10::SymInt{index});
1811 }
1812 
select(const Tensor & self,Dimname dim,int64_t index)1813 Tensor select(const Tensor& self, Dimname dim, int64_t index) {
1814   return at::select_symint(self, dimname_to_position(self, dim), c10::SymInt{index});
1815 }
1816 
select_symint(const Tensor & self,int64_t dim,c10::SymInt index)1817 Tensor select_symint(const Tensor& self, int64_t dim, c10::SymInt index) {
1818   int64_t ndim = self.dim();
1819   if (ndim == 0) {
1820     TORCH_CHECK_INDEX(false, "select() cannot be applied to a 0-dim tensor.");
1821   }
1822   dim = maybe_wrap_dim(dim, ndim);
1823   auto size = self.sym_sizes()[dim];
1824   // Note: `size < -index` is not equivalent to `size <= -1 - index` if index is INT64_MIN
1825   // For std::numeric_limits<int64_t>::min() result of unary minus is undefined by the standard
1826   // but in practice is equal to self. On the other hand, indexing wrapping is valid for all
1827   // negative int64_t values, as x[INT64_MIN] is the same as x[INT64_MAX]
1828   if (size <= -1 - index || size <= index) {
1829     if (self.has_names() && self.names()[dim] != Dimname::wildcard()) {
1830       TORCH_CHECK_INDEX(false, "select(): index ", index, " out of range for tensor of size ",
1831                      self.sizes(), " at dimension ", self.names()[dim]);
1832     }
1833     TORCH_CHECK_INDEX(false, "select(): index ", index, " out of range for tensor of size ",
1834                    self.sizes(), " at dimension ", dim);
1835   }
1836   if (index < 0) {
1837     index += size;
1838   }
1839   if (self.is_sparse()) {
1840     return select_sparse(self, dim, index.guard_int(__FILE__, __LINE__));
1841   }
1842 
1843   Tensor result;
1844   if (self.is_quantized()) {
1845     auto local_index = index.guard_int(__FILE__, __LINE__);
1846 
1847     DimVector sizes(self.sizes().begin(), self.sizes().end());
1848     DimVector strides(self.strides().begin(), self.strides().end());
1849     auto storage_offset = self.storage_offset() + local_index * strides[dim];
1850     sizes.erase(sizes.begin() + dim);
1851     strides.erase(strides.begin() + dim);
1852 
1853     auto quantizer = create_subtensor_quantizer(self, true, local_index, local_index + 1, dim, 1);
1854     result = as_strided_qtensorimpl(self, sizes, strides, storage_offset, std::move(quantizer));
1855   } else {
1856     std::vector<c10::SymInt> sizes(self.sym_sizes().begin(), self.sym_sizes().end());
1857     std::vector<c10::SymInt> strides(self.sym_strides().begin(), self.sym_strides().end());
1858     auto storage_offset = self.sym_storage_offset() + index * strides[dim];
1859     sizes.erase(sizes.begin() + dim);
1860     strides.erase(strides.begin() + dim);
1861 
1862     result = self.as_strided_symint(sizes, strides, storage_offset);
1863   }
1864   namedinference::propagate_names_except(result, self, {dim});
1865   return result;
1866 }
1867 
select_backward_symint(const Tensor & grad,c10::SymIntArrayRef input_sizes,int64_t dim,c10::SymInt index)1868 Tensor select_backward_symint(const Tensor& grad, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index) {
1869   auto grad_input = at::zeros_symint(input_sizes, grad.options());
1870   grad_input.select_symint(dim, std::move(index)).copy_(grad);
1871   return grad_input;
1872 }
1873 
index_select_sparse_cpu(const Tensor & self,int64_t dim,const Tensor & index)1874 Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& index) {
1875   /*
1876     Algorithm:
1877     index - a 1-D tensor of indices with shape (n,)
1878     self - sparse tensor, its shape is sizes = sparse_shape + dense_shape
1879       indices - 2-D tensor of indices, shape is (sparse_dims, nnz)
1880       values - (1+len(dense_shape))-D tensor of values, shape is (nnz,) + dense_shape
1881     index_select(dim, index) returns a sparse tensor with the following data
1882       new_sizes = sizes[:dim] + (n,) + sizes[dim+1:]
1883       new_indices - shape is (sparse_dims, new_nnz)
1884       new_values - shape is (new_nnz,) + dense_shape
1885 
1886       if dim < len(sparse_shape):
1887           # Find new_indices[dim] of the output sparse tensor and
1888           # indices at which to select values/indices.
1889           # The CPP code uses (binary/in a count table) search to find matches and may
1890           # swap the loop order for better algorithmic complexity.
1891           new_dim_indices = []
1892           selected_dim_indices = []
1893           # This is a brute-force algorithms to convey the main idea.
1894           # The CPP code below is more efficient but more complicated.
1895           for i, i_idx in enumerate(indices[dim]):
1896               for j, j_idx in enumerate(index):
1897                   if i_idx == j_idx:
1898                       new_dim_indices.append(j)
1899                       selected_dim_indices.append(i)
1900           new_indices = indices.index_select(1, selected_dim_indices)
1901           new_values = values.index_select(0, selected_dim_indices)
1902           new_indices[dim] = new_dim_indices
1903       else:
1904           new_indices = indices
1905           new_values = values.index_select(dim - sparse_dim + 1, index);
1906     */
1907   const auto ndim = self.dim();
1908   TORCH_CHECK_INDEX(ndim, "index_select() cannot be applied to a 0-dim tensor.");
1909   TORCH_CHECK_INDEX(
1910       index.dim() == 1 && index.dtype() == at::kLong && index.options().layout() == at::kStrided,
1911       "index_select() argument index must be 1-D strided (non-sparse) long-tensor.");
1912   dim = maybe_wrap_dim(dim, ndim);
1913   const auto size = self.size(dim);
1914   const auto sparse_dim = self.sparse_dim();
1915   const auto dense_dim = self.dense_dim();
1916   const auto indices = self._indices();
1917   const auto values = self._values();
1918   const auto nnz = values.size(0);
1919   const auto index_len = index.size(0);
1920   auto res_sizes = self.sizes().vec();
1921   res_sizes[dim] = index_len;
1922 
1923   // Equivalent to t.index_select(dim, idx), but vanilla index_select is not parallel,
1924   // so we use gather instead.
1925   // We use this method to select relevant indices/values
1926   // from the intersection between indices[dim] and the index.
1927   const auto index_select = [](const Tensor& t, int64_t dim, const Tensor& idx) -> Tensor {
1928     const auto idx_len = idx.numel();
1929     auto out_shape = t.sizes().vec();
1930     out_shape[dim] = idx_len;
1931     auto idx_shape = std::vector<int64_t>(t.dim(), 1);
1932     idx_shape[dim] = idx_len;
1933     return t.gather(dim, idx.view(idx_shape).expand(out_shape));
1934   };
1935 
1936   // If indexing into sparse dimensions
1937   if (dim < sparse_dim) {
1938     // short-circuit if index is empty
1939     if (!index_len) {
1940       auto res_indices = index_select(indices, 1, index);
1941       res_indices[dim] = index;
1942       const auto res_values = index_select(values, 0, index);
1943 
1944       return _sparse_coo_tensor_with_dims_and_tensors(
1945           sparse_dim, dense_dim, res_sizes, res_indices, res_values, self.options());
1946     }
1947 
1948     const auto nneg_index = [&index, index_len, &self, size, dim]() -> Tensor {
1949       const auto index_contiguous = index.contiguous();
1950       auto nneg_index = at::empty_like(index_contiguous);
1951       // nneg_index = (index < 0) * (index + size) + (index >= 0) * index
1952       auto* ptr_index = index_contiguous.data_ptr<int64_t>();
1953       auto* ptr_nneg_index = nneg_index.data_ptr<int64_t>();
1954       at::parallel_for(0, index_len, at::internal::GRAIN_SIZE, [&](int64_t start, int64_t end) {
1955           const auto* src = ptr_index + start;
1956           auto* dst = ptr_nneg_index + start;
1957           for (C10_UNUSED const auto _ : c10::irange(start, end)) {
1958             auto idx = *src++;
1959             if (idx < -size || idx >= size) {
1960                // Mark self and dim as used if code is compiled with STRIP_ERROR_MESSAGES
1961               (void)dim;
1962               (void)self;
1963               TORCH_CHECK_INDEX(false,
1964                   "index_select(): index contains ", idx, " that is out of range for tensor of size ",
1965                   self.sizes(), " at dimension ", dim
1966               );
1967             }
1968             if (idx < 0) {
1969               idx += size;
1970             }
1971             *dst++ = idx;
1972           }
1973       });
1974 
1975       return nneg_index;
1976     }();
1977 
1978     const auto dim_indices = indices[dim].contiguous();
1979 
1980     // If nnz is smaller than size, then either indices[dim] or index gets sorted,
1981     // then this is followed by a binary search to find interesections.
1982     const auto get_selected_indices_small_nnz_large_size = [&]() -> std::tuple<Tensor, Tensor> {
1983       const auto grain_size = at::internal::GRAIN_SIZE;
1984       const auto n_threads_nnz = std::max<int64_t>(
1985           1, std::min<int64_t>((nnz + grain_size - 1) / grain_size, at::get_num_threads())
1986       );
1987       const auto n_threads_index = std::max<int64_t>(
1988           1, std::min<int64_t>((index_len + grain_size - 1) / grain_size, at::get_num_threads())
1989       );
1990       const auto search_in_dim_indices
1991         // if either dim_indices or index requires sorting, we compare
1992         // the cost of sort + binary search, which is comparing
1993         // (len(dim_indices) + len(index)) * log(len(index)) to
1994         // (len(dim_indices) + len(index)) * log(len(dim_indices)).
1995         // That simplifies to comparing len(dim_indices) to len(index).
1996         // Additionally, we take into consideration potential parallel
1997         // speedup.
1998         = (nnz / n_threads_nnz <= index_len / n_threads_index)
1999         // if self is coalesced and dim is 0, then we compare
2000         // index_len * log(len(dim_indices)), which is binary search into dim_indices,
2001         // to (len(index_len) + len(dim_indices)) * log(index_len).
2002         // Additionally, we take into consideration potential parallel
2003         // speedup.
2004           || (self.is_coalesced() && dim == 0
2005           && (index_len * std::log2(nnz) / n_threads_index
2006             <= (nnz / n_threads_nnz + index_len) * std::log2(index_len)))
2007         ? true : false;
2008 
2009       // src is a source of indices to binary search in sorted
2010       Tensor sorted, sorted_idx, src;
2011       std::tie(sorted, sorted_idx, src) = [
2012         &dim_indices, &nneg_index, &self,
2013         search_in_dim_indices, dim, nnz
2014       ](void) -> std::tuple<Tensor, Tensor, Tensor> {
2015         // sort dim_indices to binary search into it
2016         if (search_in_dim_indices) {
2017           // dim_indices is already sorted if self is coalesced and dim == 0
2018           if (self.is_coalesced() && dim == 0) {
2019             return std::make_tuple(dim_indices, at::arange(nnz, dim_indices.options()), nneg_index);
2020           }
2021           else {
2022             auto [sorted_dim_indices, sorted_dim_indices_idx] = dim_indices.sort();
2023             return std::make_tuple(sorted_dim_indices, sorted_dim_indices_idx, nneg_index);
2024           }
2025         }
2026         // sort nneg_index to binary search into it
2027         else {
2028           auto [sorted_nneg_index, sorted_nneg_index_idx] = nneg_index.sort();
2029           return std::make_tuple(sorted_nneg_index, sorted_nneg_index_idx, dim_indices);
2030         }
2031       }();
2032 
2033       const auto src_grain_size = at::internal::GRAIN_SIZE;
2034       const auto src_len = src.numel();
2035       const auto n_threads_src = std::max<int64_t>(
2036           // 1 <= n_threads_src <= std::min(ceil(src.numel() / src_grain_size), max_threads)
2037           1, std::min<int64_t>((src_len + src_grain_size - 1) / src_grain_size, at::get_num_threads())
2038       );
2039       const auto chunk_size_src = (src_len + n_threads_src - 1) / n_threads_src;
2040 
2041       const std::vector<int64_t> src_n_threads_shape = {
2042         n_threads_src, (src_len + n_threads_src - 1) / n_threads_src
2043       };
2044 
2045       // src_int_idx and sorted_int_idx store "i" and "j" indices indicating
2046       // intersections such that src_int_idx[i] == sorted_int_idx[j].
2047       // These intersections are found with binary search and in parallel.
2048       auto src_int_idx = at::empty(src_n_threads_shape, src.options());
2049       auto sorted_int_idx = at::empty_like(src_int_idx);
2050       // For each element "i" from src, int_counts define how many
2051       // elements there are in sorted, i.e. "j" indices, corresponding
2052       // to "i", i.e.:
2053       // |{j : src_int_idx[i] == sorted_int_idx[j]}| for each i in src_int_idx.
2054       auto int_counts = at::zeros_like(src_int_idx);
2055 
2056       // fill in src_int_idx, sorted_int_idx, int_counts
2057       {
2058         const auto sorted_len = sorted.numel();
2059         const auto* ptr_sorted = sorted.const_data_ptr<int64_t>();
2060         const auto* ptr_sorted_start = ptr_sorted;
2061         const auto* ptr_sorted_end = ptr_sorted + sorted_len;
2062 
2063         at::parallel_for(0, n_threads_src, 1, [&](int64_t tid, C10_UNUSED int64_t _) {
2064             const auto start = tid * chunk_size_src;
2065             const auto end = std::min(start + chunk_size_src, src_len);
2066             auto* ptr_tid_src_int_idx = src_int_idx.select(0, tid).data_ptr<int64_t>();
2067             auto* ptr_tid_sorted_int_idx = sorted_int_idx.select(0, tid).data_ptr<int64_t>();
2068             auto* ptr_tid_int_counts = int_counts.select(0, tid).data_ptr<int64_t>();
2069             const auto* ptr_src = src.const_data_ptr<int64_t>() + start;
2070 
2071             for (const auto i : c10::irange(start, end)) {
2072               const auto src_val = *ptr_src++;
2073               const auto src_val_lb = std::lower_bound(ptr_sorted_start, ptr_sorted_end, src_val);
2074               // We cannot just use *src_val_lb != src_val because when
2075               // src_val_lb == ptr_sorted_end, dereferencing past-the-end value
2076               // is not well-defined.
2077               if (src_val_lb == ptr_sorted_end || *src_val_lb != src_val) {
2078                 ++ptr_tid_src_int_idx;
2079                 ++ptr_tid_sorted_int_idx;
2080                 ++ptr_tid_int_counts;
2081                 continue;
2082               }
2083               const auto src_val_ub = std::upper_bound(ptr_sorted_start, ptr_sorted_end, src_val);
2084 
2085               const int64_t count = src_val_ub - src_val_lb;
2086               const int64_t j = src_val_lb - ptr_sorted_start;
2087 
2088               *ptr_tid_src_int_idx++ = i;
2089               *ptr_tid_sorted_int_idx++ = j;
2090               *ptr_tid_int_counts++ = count;
2091             }
2092         });
2093       }
2094 
2095       const auto compressed_int_counts = int_counts.sum(-1);
2096       const auto res_len = compressed_int_counts.sum().item<int64_t>();
2097 
2098       // Short-circuit if empty intersection
2099       if (!res_len) {
2100         auto empty_idx = at::empty({0}, src.options());
2101         return std::make_tuple(empty_idx, empty_idx);
2102       }
2103 
2104       // Now that we know "i", "j" and the counts, we "unflatten"
2105       // them into two arrays of intersection indices such that
2106       // selected_src = repeat_interleave(src_int_idx, int_counts),
2107       // and selected_sorted is obtained as follows:
2108       // offsets = int_counts.cumsum(0).sub_(int_counts)
2109       // for ii, (j, c) in enumerate(zip(sorted_int_idx, int_counts)):
2110       //     out_slice = slice(offsets[ii], offsets[ii] + c)
2111       //     src_slice = slice(j, j + c)
2112       //     selected_sorted[out_slice] = sorted_int_idx[src_slice]
2113       auto selected_sorted = at::empty({res_len}, sorted.options());
2114       auto selected_src = at::empty({res_len}, src.options());
2115 
2116       // fill in selected_sorted, selected_src
2117       {
2118         auto* ptr_selected_sorted = selected_sorted.data_ptr<int64_t>();
2119         auto* ptr_selected_src = selected_src.data_ptr<int64_t>();
2120 
2121         const auto thread_offsets = compressed_int_counts.cumsum(0).sub_(compressed_int_counts);
2122         const auto* ptr_sorted_idx = sorted_idx.const_data_ptr<int64_t>();
2123         at::parallel_for(0, n_threads_src, 1, [&](int64_t tid, C10_UNUSED int64_t _) {
2124             const auto start = tid * chunk_size_src;
2125             const auto end = std::min(start + chunk_size_src, src_len);
2126             const auto tid_offset = thread_offsets.const_data_ptr<int64_t>()[tid];
2127             const auto* ptr_tid_src_int_idx = src_int_idx.select(0, tid).const_data_ptr<int64_t>();
2128             const auto* ptr_tid_sorted_int_idx = sorted_int_idx.select(0, tid).const_data_ptr<int64_t>();
2129             const auto* ptr_tid_int_counts = int_counts.select(0, tid).const_data_ptr<int64_t>();
2130             auto* ptr_tid_selected_sorted = ptr_selected_sorted + tid_offset;
2131             auto* ptr_tid_selected_src = ptr_selected_src + tid_offset;
2132 
2133             for (C10_UNUSED const auto _ : c10::irange(start, end)) {
2134               const auto count = *ptr_tid_int_counts++;
2135               const auto i = *ptr_tid_src_int_idx++;
2136               const auto j = *ptr_tid_sorted_int_idx++;
2137               if (!count) continue;
2138 
2139               std::fill_n(ptr_tid_selected_src, count, i);
2140               std::copy_n(ptr_sorted_idx + j, count, ptr_tid_selected_sorted);
2141 
2142               ptr_tid_selected_sorted += count;
2143               ptr_tid_selected_src += count;
2144             }
2145         });
2146       }
2147 
2148       return search_in_dim_indices
2149         ? std::make_tuple(selected_sorted, selected_src)
2150         : std::make_tuple(selected_src, selected_sorted);
2151     };
2152 
2153     // Converts a 1d sorted idx to a compressed 1d compressed idx,
2154     // aka crow in the CSR format. Useful to get a count table in
2155     // a parallelized and no-sync manner.
2156     // TODO: this function is equivalent to _convert_indices_from_coo_to_csr.
2157     // The mentioned function is not public yet.
2158     const auto sorted_idx_to_cidx = [](
2159         const Tensor& idx,
2160         int64_t len,
2161         bool run_in_parallel = true) -> Tensor {
2162       auto cidx = at::empty({len + 1}, idx.options());
2163 
2164       const auto* ptr_idx = idx.const_data_ptr<int64_t>();
2165       auto* ptr_cidx = cidx.data_ptr<int64_t>();
2166 
2167       const auto idx_len = idx.numel();
2168 
2169       std::fill_n(ptr_cidx, ptr_idx[0] + 1, 0);
2170       std::fill_n(ptr_cidx + ptr_idx[idx_len - 1] + 1, len - ptr_idx[idx_len - 1], idx_len);
2171 
2172       const auto grain_size = run_in_parallel ? at::internal::GRAIN_SIZE : idx_len;
2173       at::parallel_for(0, idx_len, grain_size, [&](int64_t start, int64_t end) {
2174           auto* ptr_curr_cidx = ptr_cidx + ptr_idx[start] + 1;
2175           for (int64_t i = start; i < std::min(end, idx_len - 1); ++i) {
2176             const auto diff = ptr_idx[i + 1] - ptr_idx[i];
2177             std::fill_n(ptr_curr_cidx, diff, i + 1);
2178             ptr_curr_cidx += diff;
2179           }
2180       });
2181 
2182       return cidx;
2183     };
2184 
2185     // If nnz is (much) larger than size, then both indices[dim] and index get sorted
2186     // with a count sort (faster, and no huge nnz-sized chunk memory allocations).
2187     // The element-wise product between the count tables gives us all the intersections.
2188     const auto get_selected_indices_large_nnz_small_size = [&]() -> std::tuple<Tensor, Tensor> {
2189       const auto get_counts = [&sorted_idx_to_cidx](
2190           // Writes into counts (must be preallocated and zero)
2191           // and allows to use external buffers.
2192           Tensor& counts,
2193           const Tensor& t,
2194           int64_t bins,
2195           bool is_sorted = false,
2196           bool run_in_parallel = true) -> void {
2197         if (is_sorted) {
2198           const auto cidx = sorted_idx_to_cidx(t, bins, run_in_parallel);
2199           at::sub_out(counts, cidx.slice(0, 1, bins + 1), cidx.slice(0, 0, bins));
2200         }
2201         else {
2202           auto* ptr_counts = counts.data_ptr<int64_t>();
2203           const auto* ptr_vals = t.const_data_ptr<int64_t>();
2204           for (C10_UNUSED const auto _ : c10::irange(t.numel())) {
2205             ++ptr_counts[*ptr_vals++];
2206           }
2207         }
2208       };
2209 
2210       const auto counts_per_thread = [&get_counts, size](
2211           const Tensor& idx,
2212           bool is_sorted = false,
2213           int64_t grain_size = at::internal::GRAIN_SIZE
2214       ) -> Tensor {
2215         const auto idx_len = idx.numel();
2216         // 1 <= n_threads <= min(ceil(len / grain_size), max_threads)
2217         const auto n_threads = std::max<int64_t>(
2218             1, std::min<int64_t>((idx_len + grain_size - 1) / grain_size, at::get_num_threads())
2219         );
2220         const auto chunk_size = (idx_len + n_threads - 1) / n_threads;
2221         const auto run_in_parallel = (n_threads == 1);
2222 
2223         auto counts_per_thread = at::zeros({n_threads, size}, idx.options());
2224         at::parallel_for(0, n_threads, 1, [&](int64_t tid, C10_UNUSED int64_t _) {
2225           const auto start = tid * chunk_size;
2226           const auto end = std::min(start + chunk_size, idx_len);
2227           const auto tid_idx = idx.slice(0, start, end);
2228           auto tid_counts = counts_per_thread.select(0, tid);
2229           get_counts(tid_counts, tid_idx, /*bins=*/size,
2230               /*is_sorted=*/is_sorted, /*run_in_parallel=*/run_in_parallel);
2231         });
2232 
2233         return counts_per_thread;
2234       };
2235 
2236       auto dim_indices_counts_per_thread = counts_per_thread(
2237           dim_indices,
2238           /*is_sorted=*/self.is_coalesced() && dim == 0
2239           /*grain_size = at::internal::GRAIN_SIZE*/
2240       );
2241       auto dim_indices_offset_counts_per_thread = dim_indices_counts_per_thread.cumsum(0);
2242 
2243       auto index_counts_per_thread = counts_per_thread(
2244           nneg_index,
2245           /*is_sorted=*/false
2246           /*grain_size = at::internal::GRAIN_SIZE*/
2247       );
2248       auto index_offset_counts_per_thread = index_counts_per_thread.cumsum(0);
2249 
2250       const auto index_counts = index_offset_counts_per_thread.select(0, -1);
2251       const auto dim_indices_counts = dim_indices_offset_counts_per_thread.select(0, -1);
2252       const auto intersection_counts = index_counts.mul(dim_indices_counts);
2253       const auto res_len = intersection_counts.sum().item<int64_t>();
2254       // Short-circuit if empty intersection
2255       if (!res_len) {
2256         auto empty_idx = at::empty({0}, index.options());
2257         return std::make_tuple(empty_idx, empty_idx);
2258       }
2259       const auto intersection_offsets = intersection_counts.cumsum(0);
2260 
2261       const auto search_in_dim_indices = [&]() -> bool {
2262         const auto grain_size = at::internal::GRAIN_SIZE;
2263         const auto n_threads_index = std::max<int64_t>(
2264             1, std::min<int64_t>((index_len + grain_size - 1) / grain_size, at::get_num_threads())
2265         );
2266         const auto n_threads_dim_indices = std::max<int64_t>(
2267             1, std::min<int64_t>((nnz + grain_size - 1) / grain_size, at::get_num_threads())
2268         );
2269 
2270         const auto index_max_copy_work_per_thread =
2271           index_counts_per_thread.mul(dim_indices_counts).sum(-1).max().item<int64_t>();
2272         const auto dim_indices_max_copy_work_per_thread
2273           = dim_indices_counts_per_thread.mul(index_counts).sum(-1).max().item<int64_t>();
2274 
2275         const auto index_max_work_per_thread = index_max_copy_work_per_thread * index_len / n_threads_index;
2276         const auto dim_indices_max_work_per_thread = dim_indices_max_copy_work_per_thread * nnz / n_threads_dim_indices;
2277         return index_max_work_per_thread <= dim_indices_max_work_per_thread
2278           ? true
2279           : false;
2280       }();
2281 
2282       Tensor idx, idx_counts_per_thread, idx_offset_counts_per_thread;
2283       Tensor src, src_counts_per_thread, src_offset_counts_per_thread;
2284       std::tie(
2285           idx, idx_counts_per_thread, idx_offset_counts_per_thread,
2286           src, src_counts_per_thread, src_offset_counts_per_thread
2287       ) = [&]() {
2288         return search_in_dim_indices
2289           ? std::make_tuple(
2290               nneg_index, index_counts_per_thread, index_offset_counts_per_thread,
2291               dim_indices, dim_indices_counts_per_thread, dim_indices_offset_counts_per_thread
2292             )
2293           : std::make_tuple(
2294               dim_indices, dim_indices_counts_per_thread, dim_indices_counts_per_thread.cumsum(0),
2295               nneg_index, index_counts_per_thread, index_counts_per_thread.cumsum(0)
2296             );
2297       }();
2298 
2299       const auto idx_counts = idx_offset_counts_per_thread.select(0, -1);
2300       const auto src_counts = src_offset_counts_per_thread.select(0, -1);
2301 
2302       Tensor src_idx, src_idx_offsets;
2303       std::tie(src_idx, src_idx_offsets) = [&](
2304           int64_t grain_size = at::internal::GRAIN_SIZE
2305       ) -> std::tuple<Tensor, Tensor> {
2306         const auto src_intersection_counts = src_counts.mul(idx_counts > 0);
2307         const auto src_intersection_offsets = src_intersection_counts.cumsum(0);
2308         const auto src_idx_len = src_intersection_offsets.const_data_ptr<int64_t>()[size - 1];
2309         auto src_idx = at::empty({src_idx_len}, src.options());
2310 
2311         const auto* ptr_src = src.const_data_ptr<int64_t>();
2312         const auto* ptr_intersection_counts = intersection_counts.const_data_ptr<int64_t>();
2313         const auto* ptr_src_intersection_counts = src_intersection_counts.const_data_ptr<int64_t>();
2314         const auto* ptr_src_intersection_offsets = src_intersection_offsets.const_data_ptr<int64_t>();
2315         auto* ptr_src_idx = src_idx.data_ptr<int64_t>();
2316 
2317         const auto src_len = src.numel();
2318         const auto n_threads_src = std::max<int64_t>(
2319             1, std::min<int64_t>((src_len + grain_size - 1) / grain_size, at::get_num_threads())
2320         );
2321         const auto chunk_size = (src_len + n_threads_src - 1) / n_threads_src;
2322         at::parallel_for(0, n_threads_src, 1, [&](int64_t tid, C10_UNUSED int64_t _) {
2323             const auto start = tid * chunk_size;
2324             const auto end = std::min(start + chunk_size, src_len);
2325             auto* ptr_src_tid = ptr_src + start;
2326             const auto* ptr_src_counts_per_thread
2327               = src_counts_per_thread.select(0, tid).const_data_ptr<int64_t>();
2328             const auto* ptr_src_offset_counts_per_thread
2329               = src_offset_counts_per_thread.select(0, tid).const_data_ptr<int64_t>();
2330             auto tid_counts = at::zeros({size}, src.options());
2331             auto* ptr_tid_counts = tid_counts.data_ptr<int64_t>();
2332 
2333             for (const auto i : c10::irange(start, end)) {
2334               const auto idx_val = *ptr_src_tid++;
2335               // skip idx value if not in the intersection
2336               if (!ptr_intersection_counts[idx_val]) continue;
2337               const auto idx_val_offset
2338                 = ptr_src_intersection_offsets[idx_val]
2339                 - ptr_src_intersection_counts[idx_val];
2340               const auto idx_val_tid_offset
2341                 = ptr_src_offset_counts_per_thread[idx_val]
2342                 - ptr_src_counts_per_thread[idx_val];
2343               auto& idx_val_local_tid_count = ptr_tid_counts[idx_val];
2344               ptr_src_idx[idx_val_offset + idx_val_tid_offset + idx_val_local_tid_count] = i;
2345               ++idx_val_local_tid_count;
2346             }
2347         });
2348 
2349         const auto src_idx_offsets = src_intersection_offsets.sub_(src_intersection_counts);
2350 
2351         return std::make_tuple(src_idx, src_idx_offsets);
2352       }();
2353 
2354       auto [idx_selected, src_selected] = [&](
2355           int64_t grain_size = at::internal::GRAIN_SIZE
2356       ) -> std::tuple<Tensor, Tensor> {
2357         const auto thread_offset = [&]() {
2358           // we do not need idx_counts_per_thread anymore,
2359           // so it is safe to do in-place intersection.
2360           auto counts_per_thread = idx_counts_per_thread.mul_(src_counts).sum(-1);
2361           return counts_per_thread.cumsum(0).sub_(counts_per_thread);
2362         }();
2363         const auto* ptr_thread_offset = thread_offset.const_data_ptr<int64_t>();
2364 
2365         auto idx_selected = at::empty({res_len}, idx.options());
2366         auto src_selected = at::empty({res_len}, src.options());
2367 
2368         const auto* ptr_idx = idx.const_data_ptr<int64_t>();
2369         const auto* ptr_src_counts = src_counts.const_data_ptr<int64_t>();
2370         const auto* ptr_intersection_counts = intersection_counts.const_data_ptr<int64_t>();
2371         const auto* ptr_src_idx = src_idx.const_data_ptr<int64_t>();
2372         const auto* ptr_src_idx_offsets = src_idx_offsets.const_data_ptr<int64_t>();
2373         auto* ptr_idx_selected = idx_selected.data_ptr<int64_t>();
2374         auto* ptr_src_selected = src_selected.data_ptr<int64_t>();
2375 
2376         const auto idx_len = idx.numel();
2377         const auto n_threads_idx = std::max<int64_t>(
2378             1, std::min<int64_t>((idx_len + grain_size - 1) / grain_size, at::get_num_threads())
2379         );
2380         const auto chunk_size = (idx_len + n_threads_idx - 1) / n_threads_idx;
2381         at::parallel_for(0, n_threads_idx, 1, [&](int64_t tid, C10_UNUSED int64_t _) {
2382             const auto start = tid * chunk_size;
2383             const auto end = std::min(start + chunk_size, idx_len);
2384             const auto tid_offset = ptr_thread_offset[tid];
2385             const auto* ptr_idx_tid = ptr_idx + start;
2386             auto* ptr_idx_selected_tid = ptr_idx_selected + tid_offset;
2387             auto* ptr_src_selected_tid = ptr_src_selected + tid_offset;
2388 
2389             for (const auto i : c10::irange(start, end)) {
2390               const auto idx_val = *ptr_idx_tid++;
2391               // skip if idx_val is not in the intersection
2392               if (!ptr_intersection_counts[idx_val]) continue;
2393               const auto count = ptr_src_counts[idx_val];
2394               const auto j = ptr_src_idx_offsets[idx_val];
2395               std::fill_n(ptr_idx_selected_tid, count, i);
2396               std::copy_n(ptr_src_idx + j, count, ptr_src_selected_tid);
2397               ptr_idx_selected_tid += count;
2398               ptr_src_selected_tid += count;
2399             }
2400         });
2401 
2402         return std::make_tuple(idx_selected, src_selected);
2403       }();
2404 
2405       return search_in_dim_indices
2406         ? std::make_tuple(src_selected, idx_selected)
2407         : std::make_tuple(idx_selected, src_selected);
2408     };
2409 
2410     const auto make_output = [&](
2411         const Tensor& selected_dim_indices,
2412         const Tensor& res_dim_indices) -> Tensor {
2413       auto res_indices = index_select(indices, 1, selected_dim_indices);
2414       res_indices[dim] = res_dim_indices;
2415       const auto res_values = index_select(values, 0, selected_dim_indices);
2416 
2417       return _sparse_coo_tensor_with_dims_and_tensors(
2418           sparse_dim, dense_dim, res_sizes, res_indices, res_values, self.options());
2419     };
2420 
2421     // Brute-force solution for small values of nnz and index_len
2422     const auto get_result_small_nnz_small_index = [&]()
2423       -> Tensor {
2424       const auto dim_indices_in_inner_loop = nnz >= index_len;
2425       auto [outer, inner] = [&]() -> std::tuple<Tensor, Tensor> {
2426         if (dim_indices_in_inner_loop) {
2427           return std::make_tuple(nneg_index, dim_indices);
2428         }
2429         else {
2430           return std::make_tuple(dim_indices, nneg_index);
2431         }
2432       }();
2433 
2434       const auto* ptr_outer = outer.const_data_ptr<int64_t>();
2435       const auto* ptr_inner = inner.const_data_ptr<int64_t>();
2436       // NOTE: if very critical, replace std::vector with
2437       // a data structure that operates on stack up to some limit.
2438       auto outer_selected_idx = std::vector<int64_t>();
2439       auto inner_selected_idx = std::vector<int64_t>();
2440       int64_t res_len = 0;
2441       for (const auto i : c10::irange(outer.numel())) {
2442         for (const auto j : c10::irange(inner.numel())) {
2443           if (ptr_outer[i] == ptr_inner[j]) {
2444             ++res_len;
2445             outer_selected_idx.push_back(i);
2446             inner_selected_idx.push_back(j);
2447           }
2448         }
2449       }
2450 
2451       const auto outer_selected_idx_tensor = at::from_blob(
2452           outer_selected_idx.data(), {res_len}, at::kLong
2453       );
2454       const auto inner_selected_idx_tensor = at::from_blob(
2455           inner_selected_idx.data(), {res_len}, at::kLong
2456       );
2457 
2458       return dim_indices_in_inner_loop
2459         ? make_output(inner_selected_idx_tensor, outer_selected_idx_tensor)
2460         : make_output(outer_selected_idx_tensor, inner_selected_idx_tensor);
2461     };
2462 
2463     constexpr int64_t BRUTE_FORCE_SIZE_LIMIT = 2 << 14; // 16384
2464     // NOTE: such a condition to avoid overflows in (nnz * index_len)
2465     if (nnz <= BRUTE_FORCE_SIZE_LIMIT && index_len <= BRUTE_FORCE_SIZE_LIMIT
2466         && (nnz * index_len) <= BRUTE_FORCE_SIZE_LIMIT) {
2467       return get_result_small_nnz_small_index();
2468     }
2469     else {
2470       Tensor selected_dim_indices;
2471       Tensor res_dim_indices;
2472 
2473       // A more precise decision could be of the form:
2474       // `nnz < C(nnz, size) * size`, but it requires heavy benchmarking.
2475       // We choose `nnz < size`, which measures theoretical complexity
2476       // and does not rely on runtime performance.
2477       // TODO: perform this analysis and find better C(nnz, size).
2478       if (nnz <= size) {
2479         std::tie(selected_dim_indices, res_dim_indices) = get_selected_indices_small_nnz_large_size();
2480       }
2481       else {
2482         std::tie(selected_dim_indices, res_dim_indices) = get_selected_indices_large_nnz_small_size();
2483       }
2484 
2485       return make_output(selected_dim_indices, res_dim_indices);
2486     }
2487   }
2488   // If indexing into dense dimensions
2489   else {
2490     // It is sufficient to just perform `index_select` on values
2491     // if `dim` refers to dense dimensions.
2492     const auto res_values = index_select(values, dim - sparse_dim + 1, index);
2493 
2494     return _sparse_coo_tensor_with_dims_and_tensors(
2495         sparse_dim, dense_dim, res_sizes, indices, res_values, self.options());
2496   }
2497 }
2498 
slice(const Tensor & self,int64_t dim,std::optional<int64_t> start,std::optional<int64_t> end,int64_t step)2499 Tensor slice(
2500     const Tensor& self,
2501     int64_t dim,
2502     std::optional<int64_t> start,
2503     std::optional<int64_t> end,
2504     int64_t step) {
2505   int64_t ndim = self.dim();
2506   if (ndim == 0) {
2507     TORCH_CHECK_INDEX(false, "slice() cannot be applied to a 0-dim tensor.");
2508   }
2509   dim = maybe_wrap_dim(dim, ndim);
2510   DimVector sizes(self.sizes().begin(), self.sizes().end());
2511   DimVector strides(self.strides().begin(), self.strides().end());
2512   // handle optional parameters
2513   int64_t start_val = start.has_value() ? start.value() : 0;
2514   int64_t end_val = end.has_value() ? end.value() : INT64_MAX;
2515 
2516   // TODO: support negative strides
2517   TORCH_CHECK(step > 0, "slice step must be positive");
2518 
2519   if (start_val < 0) {
2520     start_val += sizes[dim];
2521   }
2522   if (end_val < 0) {
2523     end_val += sizes[dim];
2524   }
2525   if (start_val < 0) {
2526     start_val = 0;
2527   } else if (start_val >= sizes[dim]) {
2528     start_val = sizes[dim];
2529   }
2530   if (end_val < start_val) {
2531     end_val = start_val;
2532   } else if (end_val >= sizes[dim]) {
2533     end_val = sizes[dim];
2534   }
2535   auto storage_offset = self.storage_offset() + start_val * strides[dim];
2536   auto len = end_val - start_val;
2537   sizes[dim] = (len + step - 1) / step; // round-up
2538   strides[dim] *= step;
2539 
2540   Tensor result;
2541   if (self.is_quantized()) {
2542     auto quantizer = create_subtensor_quantizer(self, false, start_val, end_val, dim, step);
2543     result = as_strided_qtensorimpl(self, sizes, strides, storage_offset, std::move(quantizer));
2544   } else {
2545     // NB: it is extremely important to perform a redispatch here for
2546     // the MPS backend; if you call directly to as_strided_tensorimpl,
2547     // the necessary metadata for MPS will not get setup and you will
2548     // get silently wrong results
2549     result = self.as_strided(sizes, strides, storage_offset);
2550   }
2551   namedinference::propagate_names(result, self);
2552   return result;
2553 }
2554 
slice_inverse_symint(const Tensor & self,const Tensor & base,int64_t,std::optional<SymInt>,std::optional<SymInt>,SymInt)2555 Tensor slice_inverse_symint(
2556     const Tensor& self,
2557     const Tensor& base,
2558     int64_t /* dim */,
2559     std::optional<SymInt> /* start */,
2560     std::optional<SymInt> /* end */,
2561     SymInt /* step */) {
2562   // assume self has enough to storage to be viewed with base's metadata
2563   return self.as_strided_symint(base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
2564 }
2565 
slice_backward(const Tensor & grad,IntArrayRef input_sizes,int64_t dim,int64_t start,int64_t end,int64_t step)2566 Tensor slice_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
2567   auto grad_input = at::zeros(input_sizes, grad.options());
2568   grad_input.slice(dim, start, end, step).copy_(grad);
2569   return grad_input;
2570 }
2571 
split(const Tensor & self,int64_t split_size,int64_t dim)2572 std::vector<Tensor> split(const Tensor& self, int64_t split_size, int64_t dim) {
2573   const auto num_splits = get_num_splits(self, split_size, dim);
2574   std::vector<Tensor> splits(num_splits);
2575   int64_t last_split_size = split_size - (split_size * num_splits - self.size(dim));
2576 
2577   for (const auto i : c10::irange(num_splits)) {
2578     auto length = i < num_splits - 1 ? split_size : last_split_size;
2579     splits[i] = self.narrow(dim, i * split_size, length);
2580   }
2581   return splits;
2582 }
2583 
split_symint(const Tensor & self,c10::SymIntArrayRef sizes,int64_t dim)2584 std::vector<Tensor> split_symint(const Tensor& self, c10::SymIntArrayRef sizes, int64_t dim) {
2585   return at::split_with_sizes_symint(self, sizes, dim);
2586 }
2587 
unsafe_split(const Tensor & self,int64_t split_size,int64_t dim)2588 std::vector<Tensor> unsafe_split(const Tensor& self, int64_t split_size, int64_t dim) {
2589   auto result = at::native::split(self, split_size, dim);
2590   for (auto& t : result) {
2591     // TODO(Ailing): do we need to set version_counter here?
2592     if (!t.is_inference()) {
2593       t.unsafeGetTensorImpl()->set_version_counter(c10::VariableVersion(/*version=*/0));
2594     }
2595   }
2596   return result;
2597 }
2598 
hsplit(const Tensor & self,int64_t split_size)2599 std::vector<Tensor> hsplit(const Tensor& self, int64_t split_size) {
2600   TORCH_CHECK(self.dim() >= 1, "torch.hsplit requires a tensor with at least 1 dimension, but got a tensor with ", self.dim(), " dimensions!")
2601   int64_t dim = (self.dim() == 1) ? 0 : 1;
2602   TORCH_CHECK(split_size != 0 && self.sym_sizes()[dim] % split_size == 0,
2603     "torch.hsplit attempted to split along dimension ", dim,", but the size of the dimension ", self.sizes()[dim], " is not divisible by the split_size ", split_size, "!");
2604   return at::tensor_split(self, split_size, dim);
2605 }
2606 
vsplit(const Tensor & self,int64_t split_size)2607 std::vector<Tensor> vsplit(const Tensor& self, int64_t split_size) {
2608   TORCH_CHECK(self.dim() >= 2, "torch.vsplit requires a tensor with at least 2 dimension, but got a tensor with ", self.dim(), " dimensions!")
2609   TORCH_CHECK(split_size != 0 && self.sym_sizes()[0] % split_size == 0,
2610     "torch.vsplit attempted to split along dimension ", 0,", but the size of the dimension ", self.sizes()[0], " is not divisible by the split_size ", split_size, "!");
2611   return at::tensor_split(self, split_size, 0);
2612 }
2613 
dsplit(const Tensor & self,int64_t split_size)2614 std::vector<Tensor> dsplit(const Tensor& self, int64_t split_size) {
2615   TORCH_CHECK(self.dim() >= 3, "torch.dsplit requires a tensor with at least 3 dimension, but got a tensor with ", self.dim(), " dimensions!")
2616   TORCH_CHECK(split_size != 0 && self.sym_sizes()[2] % split_size == 0,
2617     "torch.dsplit attempted to split along dimension ", 2,", but the size of the dimension ", self.sizes()[2], " is not divisible by the split_size ", split_size, "!");
2618   return at::tensor_split(self, split_size, 2);
2619 }
2620 
split_with_sizes(const Tensor & self,IntArrayRef split_sizes,int64_t dim)2621 std::vector<Tensor> split_with_sizes(const Tensor& self, IntArrayRef split_sizes, int64_t dim) {
2622   TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor");
2623   const int64_t dim_size = self.size(dim);
2624   const int64_t num_splits = split_sizes.size();
2625   int64_t start_idx = 0;
2626 
2627   std::vector<Tensor> splits;
2628   splits.reserve(num_splits);
2629   for (const auto i : c10::irange(num_splits)) {
2630     auto length = split_sizes[i];
2631     TORCH_CHECK(length >= 0,
2632              "split_with_sizes expects split_sizes have only non-negative ",
2633              "entries, but got split_sizes=", split_sizes);
2634     splits.push_back(at::native::slice(self, dim, start_idx, start_idx + length, 1));
2635     start_idx += length;
2636   }
2637   TORCH_CHECK(start_idx == dim_size,
2638            "split_with_sizes expects split_sizes to sum exactly to ", dim_size,
2639            " (input tensor's size at dimension ", dim, "), ", "but got split_sizes=", split_sizes);
2640   return splits;
2641 }
2642 
unsafe_split_with_sizes(const Tensor & self,IntArrayRef split_sizes,int64_t dim)2643 std::vector<Tensor> unsafe_split_with_sizes(const Tensor& self, IntArrayRef split_sizes, int64_t dim) {
2644   auto result = at::native::split_with_sizes(self, split_sizes, dim);
2645   for (auto& t : result) {
2646     // TODO(Ailing): do we need to set version_counter here?
2647     if (!t.is_inference()) {
2648       t.unsafeGetTensorImpl()->set_version_counter(c10::VariableVersion(/*version=*/0));
2649     }
2650   }
2651   return result;
2652 }
2653 
hsplit(const Tensor & self,IntArrayRef split_sizes)2654 std::vector<Tensor> hsplit(const Tensor& self, IntArrayRef split_sizes) {
2655   TORCH_CHECK(self.dim() >= 1, "torch.hsplit requires a tensor with at least 1 dimension, but got a tensor with ", self.dim(), " dimensions!")
2656   return at::tensor_split(self, split_sizes, (self.dim() == 1) ? 0 : 1);
2657 }
2658 
vsplit(const Tensor & self,IntArrayRef split_sizes)2659 std::vector<Tensor> vsplit(const Tensor& self, IntArrayRef split_sizes) {
2660   TORCH_CHECK(self.dim() >= 2, "torch.vsplit requires a tensor with at least 2 dimension, but got a tensor with ", self.dim(), " dimensions!")
2661   return at::tensor_split(self, split_sizes, 0);
2662 }
2663 
dsplit(const Tensor & self,IntArrayRef split_sizes)2664 std::vector<Tensor> dsplit(const Tensor& self, IntArrayRef split_sizes) {
2665   TORCH_CHECK(self.dim() >= 3, "torch.dsplit requires a tensor with at least 3 dimension, but got a tensor with ", self.dim(), " dimensions!")
2666   return at::tensor_split(self, split_sizes, 2);
2667 }
2668 
2669 // Precondition: tensors is non-empty
get_stack_inputs(TensorList tensors,int64_t dim)2670 static inline std::vector<Tensor> get_stack_inputs(TensorList tensors, int64_t dim) {
2671   std::vector<Tensor> inputs(tensors.size());
2672   at::IntArrayRef entry_shape = tensors[0].sizes();
2673   inputs[0] = tensors[0].unsqueeze(dim);
2674   for (const auto i : c10::irange(1, tensors.size())) {
2675     TORCH_CHECK(tensors[i].sizes() == entry_shape,
2676       "stack expects each tensor to be equal size, but got ", entry_shape,
2677       " at entry 0 and ", tensors[i].sizes(), " at entry ", i);
2678     inputs[i] = tensors[i].unsqueeze(dim);
2679   }
2680   return inputs;
2681 }
2682 
maybe_native_stack(Tensor & result,TensorList tensors,int64_t dim)2683 bool inline maybe_native_stack(Tensor& result, TensorList tensors, int64_t dim) {
2684   dim = maybe_wrap_dim(dim, tensors[0].dim() + 1);
2685   if (detail::CanUseNativeSerialStack<TensorList, /*skip_overlap_check*/ false>::call(result, tensors, dim)) {
2686     // compute the size of the result
2687     auto result_sizes = tensors[0].sizes().vec();
2688     result_sizes.insert(result_sizes.begin() + dim, tensors.size());
2689 
2690     // skip resizing if size of result is same as expected
2691     // raise a warning while resizing if output has one or more elements
2692     // at::native::resize_output(result, result_sizes);
2693     // TODO: restore the above, see https://github.com/pytorch/pytorch/issues/64709
2694 
2695     if (result.sizes() != result_sizes) {
2696       result.resize_(result_sizes);
2697     }
2698 
2699     stack_serial_stub(kCPU, result, tensors, dim);
2700     return true;
2701   }
2702   return false;
2703 }
2704 
_stack(TensorList tensors,int64_t dim)2705 Tensor _stack(TensorList tensors, int64_t dim) {
2706   ScalarType high_type = result_type(tensors);
2707   Tensor result = at::empty({0}, tensors[0].options().dtype(high_type));
2708   return at::native::_stack_out(get_stack_inputs(tensors, dim), dim, result);
2709 }
2710 
_stack_cpu(TensorList tensors,int64_t dim)2711 Tensor _stack_cpu(TensorList tensors, int64_t dim) {
2712   ScalarType high_type = result_type(tensors);
2713   Tensor result = at::empty({0}, tensors[0].options().dtype(high_type));
2714   return at::native::_stack_out_cpu(tensors, dim, result);
2715 }
2716 
check_stack_inputs(TensorList tensors,int64_t dim)2717 static void check_stack_inputs(TensorList tensors, int64_t dim) {
2718   at::IntArrayRef entry_shape = tensors[0].sizes();
2719   for (const auto i : c10::irange(1, tensors.size())) {
2720     TORCH_CHECK(tensors[i].sizes() == entry_shape,
2721       "stack expects each tensor to be equal size, but got ", entry_shape,
2722       " at entry 0 and ", tensors[i].sizes(), " at entry ", i);
2723   }
2724 }
2725 
2726 // Pads each tensor on `dim`-th dimension such that padded_dim % num_chunks == 0.
_pad_chunk(TensorList tensors,int64_t dim,int64_t num_chunks)2727 static std::vector<Tensor> _pad_chunk(TensorList tensors, int64_t dim, int64_t num_chunks) {
2728   auto num_tensors = tensors.size();
2729   std::vector<Tensor> padded_tensors;
2730   padded_tensors.reserve(num_tensors);
2731   for (const auto & tensor : tensors) {
2732     auto tensor_size = tensor.sizes();
2733     std::vector<int64_t> padded_size(tensor_size.vec());
2734     padded_size[dim] = (tensor_size[dim] + num_chunks - 1) / num_chunks * num_chunks;
2735     Tensor padded_tensor = tensor;
2736     if (padded_size != tensor_size) {
2737       padded_tensor = tensor.new_zeros(padded_size);
2738       padded_tensor.narrow(dim, 0, tensor_size[dim]).copy_(tensor);
2739     }
2740     std::vector<int64_t> view_sizes(tensor_size.begin(), tensor_size.begin()+dim);
2741     view_sizes.insert(view_sizes.end(), {num_chunks, -1});
2742     padded_tensors.push_back(padded_tensor.view(view_sizes));
2743   }
2744   return padded_tensors;
2745 }
2746 
_chunk_cat(TensorList tensors,int64_t dim,int64_t num_chunks)2747 Tensor _chunk_cat(TensorList tensors, int64_t dim, int64_t num_chunks) {
2748   auto wrapped_dim = at::native::preprocess_chunk_cat_inputs(tensors, dim, num_chunks);
2749   return at::cat(_pad_chunk(tensors, wrapped_dim, num_chunks), wrapped_dim+1);
2750 }
2751 
_chunk_cat_out(TensorList tensors,int64_t dim,int64_t num_chunks,Tensor & out)2752 Tensor& _chunk_cat_out(TensorList tensors, int64_t dim, int64_t num_chunks, Tensor& out) {
2753   auto wrapped_dim = at::native::preprocess_chunk_cat_inputs(tensors, dim, num_chunks);
2754   at::cat_out(out, _pad_chunk(tensors, wrapped_dim, num_chunks), wrapped_dim+1);
2755   return out;
2756 }
2757 
2758 // TODO(msubkhankulov): refactor to use _stack
stack(TensorList tensors,int64_t dim)2759 Tensor stack(TensorList tensors, int64_t dim) {
2760   TORCH_CHECK(!tensors.empty(),
2761            "stack expects a non-empty TensorList");
2762   auto wrapped_dim = maybe_wrap_dim(dim, tensors[0].ndimension()+1);
2763   if (wrapped_dim < tensors[0].ndimension() && !tensors[0].is_sparse()) {
2764     check_stack_inputs(tensors, wrapped_dim);
2765     auto result_sizes = tensors[0].sizes().vec();
2766     result_sizes.insert(result_sizes.begin() + wrapped_dim, tensors.size());
2767     auto out = at::cat(tensors, wrapped_dim);
2768     return out.view(result_sizes); // one can always split a dimension with view
2769   } else { //dim = tensors[0].ndimension() cannot be efficiently handled by view
2770     return at::cat(get_stack_inputs(tensors, dim), dim);
2771   }
2772 }
2773 
2774 // CPU specific implementation
_stack_out_cpu(TensorList tensors,int64_t dim,Tensor & result)2775 Tensor& _stack_out_cpu(TensorList tensors, int64_t dim, Tensor& result) {
2776   if (maybe_native_stack(result, tensors, dim)) {
2777     return result;
2778   } else {
2779     return at::cat_out(result, get_stack_inputs(tensors, dim), dim);
2780   }
2781 }
2782 
2783 // default backend
_stack_out(TensorList tensors,int64_t dim,Tensor & result)2784 Tensor& _stack_out(TensorList tensors, int64_t dim, Tensor& result) {
2785   return at::cat_out(result, tensors, dim);
2786 }
2787 
2788 // TODO(msubkhankulov): refactor to use _stack_out
stack_out(TensorList tensors,int64_t dim,Tensor & result)2789 Tensor& stack_out(TensorList tensors, int64_t dim, Tensor& result) {
2790   TORCH_CHECK(!tensors.empty(),
2791            "stack expects a non-empty TensorList");
2792   auto wrapped_dim = maybe_wrap_dim(dim, tensors[0].ndimension()+1);
2793   if (wrapped_dim < tensors[0].ndimension() && !tensors[0].is_sparse()) {
2794     check_stack_inputs(tensors, wrapped_dim);
2795     auto result_sizes = tensors[0].sizes().vec();
2796     result_sizes.insert(result_sizes.begin() + wrapped_dim, tensors.size());
2797     at::native::resize_output(result, result_sizes);
2798     auto cat_sizes = tensors[0].sizes().vec();
2799     cat_sizes[wrapped_dim] *= tensors.size();
2800     auto strides = at::detail::computeStride(result.sizes(), result.strides(), cat_sizes);
2801     if (strides.has_value()) {
2802       //can take fast cat path
2803       auto result_view = result.view(cat_sizes);
2804       at::cat_out(result_view, tensors, wrapped_dim);
2805       return result;
2806     }
2807   }
2808   return at::cat_out(result, get_stack_inputs(tensors, dim), dim);
2809 
2810 }
2811 
hstack(TensorList tensors)2812 Tensor hstack(TensorList tensors) {
2813   TORCH_CHECK(!tensors.empty(),
2814            "hstack expects a non-empty TensorList");
2815   auto rep = at::atleast_1d(tensors);
2816   if (rep[0].dim() == 1) {
2817     return at::cat(rep, 0);
2818   }
2819   return at::cat(rep, 1);
2820 }
2821 
hstack_out(TensorList tensors,Tensor & result)2822 Tensor& hstack_out(TensorList tensors, Tensor& result) {
2823   TORCH_CHECK(!tensors.empty(),
2824            "hstack expects a non-empty TensorList");
2825   auto rep = at::atleast_1d(tensors);
2826   if (rep[0].dim() == 1) {
2827     return at::cat_out(result, rep, 0);
2828   }
2829   return at::cat_out(result, rep, 1);
2830 }
2831 
vstack(TensorList tensors)2832 Tensor vstack(TensorList tensors) {
2833   TORCH_CHECK(!tensors.empty(),
2834            "vstack expects a non-empty TensorList");
2835   auto rep = at::atleast_2d(tensors);
2836   return at::cat(rep, 0);
2837 }
2838 
vstack_out(TensorList tensors,Tensor & result)2839 Tensor& vstack_out(TensorList tensors, Tensor& result) {
2840   TORCH_CHECK(!tensors.empty(),
2841            "vstack expects a non-empty TensorList");
2842   auto rep = at::atleast_2d(tensors);
2843   return at::cat_out(result, rep, 0);
2844 }
2845 
dstack(TensorList tensors)2846 Tensor dstack(TensorList tensors) {
2847   TORCH_CHECK(!tensors.empty(),
2848            "dstack expects a non-empty TensorList");
2849   auto rep = at::atleast_3d(tensors);
2850   return at::cat(rep, 2);
2851 }
dstack_out(TensorList tensors,Tensor & result)2852 Tensor& dstack_out(TensorList tensors, Tensor& result) {
2853   TORCH_CHECK(!tensors.empty(),
2854            "dstack expects a non-empty TensorList");
2855   auto rep = at::atleast_3d(tensors);
2856   return at::cat_out(result, rep, 2);
2857 }
2858 
sparse_transpose_(Tensor & self,int64_t dim0,int64_t dim1)2859 static inline Tensor & sparse_transpose_(Tensor & self, int64_t dim0, int64_t dim1) {
2860   int64_t nsparse_dim = self.sparse_dim();
2861   TORCH_CHECK(dim0 < nsparse_dim && dim1 < nsparse_dim,
2862            "sparse transpose: transposed dimensions must be sparse ",
2863            "Got sparse_dim: ", nsparse_dim, ", d0: ", dim0, ", d1: ", dim1);
2864 
2865   if (self._indices().numel() == 0 && self._values().numel() == 0) {
2866     auto sizes = self.sizes().vec();
2867     std::swap(sizes[dim0], sizes[dim1]);
2868 
2869     at::sparse::get_sparse_impl(self)->raw_resize_(self.sparse_dim(), self.dense_dim(), sizes);
2870   } else {
2871     auto indices = self._indices();
2872     auto row0 = indices.select(0, dim0);
2873     auto row1 = indices.select(0, dim1);
2874 
2875     // swap row0 and row1
2876     auto tmp = at::zeros_like(row0, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
2877     tmp.copy_(row0);
2878     row0.copy_(row1);
2879     row1.copy_(tmp);
2880 
2881     self._coalesced_(false);
2882 
2883     auto sizes = self.sizes().vec();
2884     std::swap(sizes[dim0], sizes[dim1]);
2885 
2886     at::sparse::get_sparse_impl(self)->raw_resize_(self._indices().size(0), self._values().dim() - 1, sizes);
2887   }
2888   return self;
2889 }
2890 
2891 // torch.row_stack, alias for torch.vstack
row_stack_out(TensorList tensors,Tensor & result)2892 Tensor& row_stack_out(TensorList tensors, Tensor& result) {
2893   return at::vstack_out(result, tensors);
2894 }
2895 
row_stack(TensorList tensors)2896 Tensor row_stack(TensorList tensors) {
2897   return at::vstack(tensors);
2898 }
2899 
reshape_input_for_column_stack(TensorList tensors)2900 static std::vector<Tensor> reshape_input_for_column_stack(TensorList tensors) {
2901   std::vector<Tensor> result(tensors.size());
2902   auto transform_lambda = [](const Tensor& input) -> Tensor {
2903     // reshape 0D or 1D tensor t into (t.numel(), 1)
2904     if (input.dim() <= 1) {
2905       return input.reshape_symint({input.sym_numel(), 1});
2906     }
2907     return input;
2908   };
2909   std::transform(tensors.cbegin(),
2910                  tensors.cend(),
2911                  result.begin(),
2912                  transform_lambda);
2913   return result;
2914 }
2915 
column_stack_out(TensorList tensors,Tensor & result)2916 Tensor& column_stack_out(TensorList tensors, Tensor& result) {
2917   TORCH_CHECK(!tensors.empty(),
2918               "column_stack expects a non-empty TensorList");
2919 
2920   auto reshaped_tensors = reshape_input_for_column_stack(tensors);
2921   return at::hstack_out(result, reshaped_tensors);
2922 }
2923 
column_stack(TensorList tensors)2924 Tensor column_stack(TensorList tensors) {
2925   TORCH_CHECK(!tensors.empty(),
2926               "column_stack expects a non-empty TensorList");
2927 
2928   auto reshaped_tensors = reshape_input_for_column_stack(tensors);
2929   return at::hstack(reshaped_tensors);
2930 }
2931 
propagate_transposed_names(Tensor & result,const Tensor & other,int64_t dim0,int64_t dim1)2932 static Tensor& propagate_transposed_names(
2933     Tensor& result,
2934     const Tensor& other,
2935     int64_t dim0,
2936     int64_t dim1) {
2937   if (other.has_names()) {
2938     auto names = other.names().vec();
2939     std::swap(names[dim0], names[dim1]);
2940     namedinference::propagate_names_if_nonempty(result, names);
2941   }
2942   return result;
2943 }
2944 
transpose(const Tensor & self,Dimname dim0,Dimname dim1)2945 Tensor transpose(const Tensor& self, Dimname dim0, Dimname dim1) {
2946   return at::transpose(
2947       self, dimname_to_position(self, dim0), dimname_to_position(self, dim1));
2948 }
2949 
2950 
transpose_(Tensor & self,int64_t dim0,int64_t dim1)2951 Tensor & transpose_(Tensor & self, int64_t dim0, int64_t dim1) {
2952   TORCH_CHECK(
2953       !(self.layout() == kSparseCsr || self.layout() == kSparseCsc ||
2954         self.layout() == kSparseBsr || self.layout() == kSparseBsc),
2955       "torch.transpose_: in-place transposition is not supported for ",
2956       self.layout(),
2957       " layout");
2958 
2959   auto ndims = self.dim();
2960   dim0 = maybe_wrap_dim(dim0, ndims);
2961   dim1 = maybe_wrap_dim(dim1, ndims);
2962   if (dim0 == dim1) {
2963     return self;
2964   }
2965 
2966   // Sparse COO is an exceptional sparse format as it allows transpose
2967   // to be a view operation which is a convenient property for
2968   // in-place operations. For other sparse formats, the in-place
2969   // transpose would not be possible without shuffling the specified
2970   // values. So we don't support this as it would defeat the purpose
2971   // of in-place opreations of being memory-efficient.
2972   if (self.is_sparse()) {
2973     return sparse_transpose_(self, dim0, dim1);
2974   }
2975 
2976   if (self.is_mkldnn()) {
2977     return at::_mkldnn_transpose_(self, dim0, dim1);
2978   }
2979 
2980   DimVector sizes(self.sizes().begin(), self.sizes().end());
2981   DimVector strides(self.strides().begin(), self.strides().end());
2982   std::swap(strides[dim0], strides[dim1]);
2983   std::swap(sizes[dim0], sizes[dim1]);
2984   self.as_strided_(sizes, strides);
2985   return self;
2986 }
2987 
2988 namespace {
2989 // Transpose implementation for sparse compressed layouts
2990 // NB: We assume that dim1,dim0 have already been wrapped
sparse_compressed_transpose(const Tensor & self,int64_t dim0,int64_t dim1)2991 static inline Tensor sparse_compressed_transpose(
2992     const Tensor& self,
2993     int64_t dim0,
2994     int64_t dim1) {
2995   auto compressed_inds = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
2996       self.layout(),
2997       "compressed_inds",
2998       [&self]() { return self.crow_indices(); },
2999       [&self]() { return self.ccol_indices(); });
3000 
3001   auto plain_inds = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
3002       self.layout(),
3003       "plain_inds",
3004       [&self]() { return self.col_indices(); },
3005       [&self]() { return self.row_indices(); });
3006 
3007   const auto n_batch_dim = compressed_inds.dim() - 1;
3008   const auto dense_dim = self.dim() - n_batch_dim - 2;
3009 
3010   // In theory it works, but missing to_dense coverage to test
3011   TORCH_CHECK(
3012       dense_dim == 0,
3013       "transpose(): hybrid sparse compressed tensors with dense dimensions are not supported");
3014 
3015   // Classify transpose "type"
3016   enum class TransposeDim : uint8_t { Batch, Sparse, Dense };
3017   auto classify_dim = [&n_batch_dim](const int64_t dim) {
3018     if (dim < n_batch_dim) {
3019       return TransposeDim::Batch;
3020     } else if (dim > n_batch_dim + 1) {
3021       return TransposeDim::Dense;
3022     } else {
3023       return TransposeDim::Sparse;
3024     }
3025   };
3026 
3027   const auto transpose_type = classify_dim(dim0);
3028   {
3029 #ifndef STRIP_ERROR_MESSAGES
3030     auto dim_type_name = [](const TransposeDim dim) {
3031       switch (dim) {
3032         case TransposeDim::Batch:
3033           return "Batch";
3034         case TransposeDim::Dense:
3035           return "Dense";
3036         case TransposeDim::Sparse:
3037           return "Sparse";
3038         default:
3039           TORCH_INTERNAL_ASSERT(
3040               false,
3041               "Impossible TransposeDim value: ",
3042               static_cast<std::underlying_type_t<TransposeDim>>(dim));
3043       }
3044     };
3045 #endif
3046     const auto dim1_type = classify_dim(dim1);
3047     TORCH_CHECK(
3048         dim1_type == transpose_type,
3049         "transpose(): can only transpose dimensions of the same type (Batch, Sparse, Dense), got ",
3050         dim0,
3051         "(",
3052         dim_type_name(transpose_type),
3053         ")",
3054         " and ",
3055         dim1,
3056         "(",
3057         dim_type_name(dim1_type),
3058         ")");
3059   }
3060 
3061   // We have validated everything, early exit for equal dims (no effect)
3062   if (dim0 == dim1) {
3063     return self.clone();
3064   }
3065 
3066   auto result_sizes = DimVector(self.sizes());
3067   std::swap(result_sizes[dim0], result_sizes[dim1]);
3068   Tensor result_vals;
3069   auto result_layout = self.layout();
3070 
3071   if (transpose_type == TransposeDim::Batch) {
3072     compressed_inds = compressed_inds.transpose(dim0, dim1).contiguous();
3073     plain_inds = plain_inds.transpose(dim0, dim1).contiguous();
3074     result_vals = self.values().transpose(dim0, dim1).contiguous();
3075 
3076   } else if (transpose_type == TransposeDim::Dense) {
3077     // NB: This code should work, but is untestable due to lack of support for
3078     // dense dimensions in to_dense. The Debug assert is present to emphasize
3079     // the fact that the block should not be possible to hit this code block
3080     TORCH_INTERNAL_ASSERT(
3081         false, "transpose(): Shouldn't have reached this point");
3082     result_vals = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(
3083         self.layout(),
3084         "sparse_transpose",
3085         // un-blocked: 2 sparse dims map to single nnz dim, so dense dim0/1 are
3086         // one position left
3087         [&]() { return self.values().transpose(dim0 - 1, dim1 - 1); },
3088         // blocked: 2 sparse dims map to 3 (nnz, ) + blocksize dims, so dense
3089         // dim0/1 are one position right
3090         [&]() { return self.values().transpose(dim0 + 1, dim1 + 1); });
3091   } else /*if (transpose_type == TransposeDim::Sparse) */ {
3092     // Flip the layout
3093     result_layout = sparse_csr::flip_compressed_layout(self.layout());
3094     result_vals = AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(
3095         self.layout(),
3096         "sparse_transpose",
3097         // un-blocked: no change to values, layout is flipped.
3098         [&]() { return self.values(); },
3099         // blocked: the blocks are nested under the sparse dims so they must be
3100         // transposed as well.
3101         [&]() {
3102           return self.values().transpose(-2 - dense_dim, -1 - dense_dim);
3103         });
3104   }
3105   return at::_sparse_compressed_tensor_unsafe(
3106       compressed_inds,
3107       plain_inds,
3108       result_vals,
3109       result_sizes,
3110       self.options().layout(result_layout));
3111 }
3112 } // namespace
3113 
transpose(const Tensor & self,int64_t dim0,int64_t dim1)3114 Tensor transpose(const Tensor & self, int64_t dim0, int64_t dim1) {
3115   auto ndims = self.dim();
3116   dim0 = maybe_wrap_dim(dim0, ndims);
3117   dim1 = maybe_wrap_dim(dim1, ndims);
3118 
3119   if (self.is_sparse()) {
3120     if (dim0 == dim1) {
3121       return self.clone();
3122     }
3123     Tensor self_clone = self.clone();
3124     return sparse_transpose_(self_clone, dim0, dim1);
3125   }
3126   if (self.layout() == kSparseBsr || self.layout() == kSparseCsr ||
3127       self.layout() == kSparseBsc || self.layout() == kSparseCsc) {
3128     return sparse_compressed_transpose(self, dim0, dim1);
3129   }
3130 
3131   if (self.is_mkldnn()) {
3132     return at::_mkldnn_transpose(self, dim0, dim1);
3133   }
3134 
3135   // Transpose of a tensor is a view operation.
3136   if (dim0 == dim1) {
3137     return self.alias();
3138   }
3139 
3140   SymDimVector sizes(self.sym_sizes().begin(), self.sym_sizes().end());
3141   std::swap(sizes[dim0], sizes[dim1]);
3142   SymDimVector strides(self.sym_strides().begin(), self.sym_strides().end());
3143   std::swap(strides[dim0], strides[dim1]);
3144   auto result = self.as_strided_symint(sizes, strides);
3145   propagate_transposed_names(result, self, dim0, dim1);
3146   return result;
3147 }
3148 
check_t(const Tensor & self,const char * fn)3149 static void check_t(const Tensor& self, const char *fn) {
3150   if (self.is_sparse()) {
3151     int64_t sparse_dim = self.sparse_dim();
3152     int64_t dense_dim = self.dense_dim();
3153     TORCH_CHECK(sparse_dim <= 2 && dense_dim == 0,
3154              fn, " expects a tensor with <= 2 sparse and 0 dense dimensions, but got ",
3155              sparse_dim, " sparse and ", dense_dim, " dense dimensions");
3156   } else {
3157     TORCH_CHECK(self.dim() <= 2,
3158              fn, " expects a tensor with <= 2 dimensions, but self is ", self.dim(), "D");
3159   }
3160 }
3161 
t(const Tensor & self)3162 Tensor t(const Tensor & self) {
3163   check_t(self, "t()");
3164   return self.transpose(0, self.dim() < 2 ? 0 : 1);
3165 }
3166 
t_(Tensor & self)3167 Tensor & t_(Tensor & self) {
3168   check_t(self, "t_()");
3169   return self.transpose_(0, self.dim() < 2 ? 0 : 1);
3170 }
3171 
3172 std::tuple<SymDimVector, SymDimVector>
inferSqueezeGeometry(const Tensor & tensor)3173 static inferSqueezeGeometry(const Tensor &tensor) {
3174   SymDimVector sizes;
3175   SymDimVector strides;
3176 
3177   for(const auto d : c10::irange(tensor.dim())) {
3178     if(tensor.sym_sizes()[d] != 1) {
3179       sizes.push_back(tensor.sym_sizes()[d]);
3180       strides.push_back(tensor.sym_strides()[d]);
3181     }
3182   }
3183 
3184   return std::make_tuple(std::move(sizes), std::move(strides));
3185 }
3186 
3187 std::tuple<SymDimVector, SymDimVector>
inferSqueezeGeometry(const Tensor & tensor,int64_t dim)3188 static inferSqueezeGeometry(const Tensor& tensor, int64_t dim) {
3189   SymDimVector sizes;
3190   SymDimVector strides;
3191 
3192   for(const auto d : c10::irange(tensor.dim())) {
3193     if(d != dim || tensor.sym_sizes()[dim] != 1) {
3194       sizes.push_back(tensor.sym_sizes()[d]);
3195       strides.push_back(tensor.sym_strides()[d]);
3196     }
3197   }
3198   return std::make_tuple(std::move(sizes), std::move(strides));
3199 }
3200 
3201 std::tuple<SymDimVector, SymDimVector>
inferSqueezeGeometry(const Tensor & tensor,std::bitset<dim_bitset_size> dim_mask)3202 static inferSqueezeGeometry(const Tensor &tensor, std::bitset<dim_bitset_size> dim_mask) {
3203   const auto ndim = tensor.dim();
3204   const auto sym_sizes = tensor.sym_sizes();
3205   const auto sym_strides = tensor.sym_strides();
3206 
3207   SymDimVector out_sizes, out_strides;
3208   for (const auto d: c10::irange(ndim)) {
3209     if (!dim_mask.test(d) || sym_sizes[d] != 1) {
3210       out_sizes.push_back(sym_sizes[d]);
3211       out_strides.push_back(sym_strides[d]);
3212     }
3213   }
3214   return std::make_tuple(std::move(out_sizes), std::move(out_strides));
3215 }
3216 
3217 namespace {
3218 // Named type instead of a pair/tuple so that we can be sure to
3219 // construct the vectors in place and get NRVO.
3220 template <typename T>
3221 struct InferUnsqueezeGeometryResult {
3222   SmallVector<T, kDimVectorStaticSize>sizes;
3223   SmallVector<T, kDimVectorStaticSize> strides;
InferUnsqueezeGeometryResultat::native::__anon87b55d972711::InferUnsqueezeGeometryResult3224   InferUnsqueezeGeometryResult(ArrayRef<T> tensor_sizes, ArrayRef<T> tensor_strides)
3225       : sizes(tensor_sizes.begin(), tensor_sizes.end())
3226       , strides(tensor_strides.begin(), tensor_strides.end()) {}
3227 };
3228 
3229 InferUnsqueezeGeometryResult<c10::SymInt>
inferUnsqueezeGeometry_symint(const Tensor & tensor,int64_t dim)3230 inferUnsqueezeGeometry_symint(const Tensor& tensor, int64_t dim) {
3231   InferUnsqueezeGeometryResult<c10::SymInt> result(tensor.sym_sizes(), tensor.sym_strides());
3232   c10::SymInt new_stride = dim >= tensor.dim() ? 1 : result.sizes[dim] * result.strides[dim];
3233   result.sizes.insert(result.sizes.begin() + dim, 1);
3234   result.strides.insert(result.strides.begin() + dim, new_stride);
3235 
3236   return result;
3237 }
3238 
3239 InferUnsqueezeGeometryResult<int64_t>
inferUnsqueezeGeometry(const Tensor & tensor,int64_t dim)3240 inferUnsqueezeGeometry(const Tensor& tensor, int64_t dim) {
3241   InferUnsqueezeGeometryResult<int64_t> result(tensor.sizes(), tensor.strides());
3242   int64_t new_stride = dim >= tensor.dim() ? 1 : result.sizes[dim] * result.strides[dim];
3243   result.sizes.insert(result.sizes.begin() + dim, 1);
3244   result.strides.insert(result.strides.begin() + dim, new_stride);
3245 
3246   return result;
3247 }
3248 
3249 // dim is present if squeezing a single dimension and absent if squeezing all dimensions
squeeze_qtensor(const Tensor & self,c10::OptionalIntArrayRef dims)3250 Tensor squeeze_qtensor(const Tensor& self, c10::OptionalIntArrayRef dims) {
3251   auto quantizer = get_qtensorimpl(self)->quantizer();
3252   const auto ndim = self.dim();
3253   auto mask = dims.has_value()
3254       ? dim_list_to_bitset(dims, self.dim())
3255       : std::bitset<dim_bitset_size>((1ull << self.dim()) - 1);
3256   auto [sizes, strides] = inferSqueezeGeometry(self, mask);
3257   if (quantizer->qscheme() == QScheme::PER_CHANNEL_AFFINE) {
3258     const auto* per_channel_quantizer = static_cast<at::PerChannelAffineQuantizer*>(quantizer.get());
3259     auto axis = per_channel_quantizer->axis();
3260     int64_t shift = 0;
3261     for (const auto d : c10::irange(ndim)) {
3262       if (mask.test(d) && self.sizes()[d] == 1) {
3263         TORCH_CHECK(axis != d, "Squeeze is only possible on non-axis dimension for Per-Channel Quantized Tensors.");
3264         if (d < axis) {
3265           ++shift;
3266         }
3267       }
3268     }
3269     axis -= shift;
3270     quantizer = make_per_channel_affine_quantizer(per_channel_quantizer->scales(),
3271                                                   per_channel_quantizer->zero_points(),
3272                                                   axis,
3273                                                   quantizer->scalar_type());
3274   }
3275   // TODO: quantized Tensor support for SymInt needs to be added but basic building blocs
3276   // are missing for now.
3277   auto result = make_qtensor(self, C10_AS_INTARRAYREF_SLOW(sizes), C10_AS_INTARRAYREF_SLOW(strides), std::move(quantizer));
3278   auto maybe_outnames = namedinference::compute_squeeze_outnames(self, mask);
3279   namedinference::propagate_names_if_nonempty(result, maybe_outnames);
3280   return result;
3281 }
3282 }
3283 
squeeze(const Tensor & self)3284 Tensor squeeze(const Tensor& self) {
3285   auto g = inferSqueezeGeometry(self);
3286   at::Tensor result = self.as_strided_symint(std::get<0>(g), std::get<1>(g));
3287   auto maybe_outnames = namedinference::compute_squeeze_outnames(self);
3288   namedinference::propagate_names_if_nonempty(result, maybe_outnames);
3289   return result;
3290 }
3291 
squeeze_quantized(const Tensor & self)3292 Tensor squeeze_quantized(const Tensor& self) {
3293   return squeeze_qtensor(self, std::nullopt);
3294 }
3295 
squeeze(const Tensor & self,int64_t dim)3296 Tensor squeeze(const Tensor& self, int64_t dim) {
3297   int64_t dims = self.dim();
3298   dim = maybe_wrap_dim(dim, dims);
3299   if (dims == 0 || self.sym_sizes()[dim] != 1) {
3300     return self.as_strided_symint(self.sym_sizes(), self.sym_strides());
3301   }
3302   auto g = inferSqueezeGeometry(self, dim);
3303   auto result = self.as_strided_symint(std::get<0>(g), std::get<1>(g));
3304   namedinference::propagate_names_except(result, self, {dim});
3305   return result;
3306 }
3307 
squeeze_quantized(const Tensor & self,int64_t dim)3308 Tensor squeeze_quantized(const Tensor& self, int64_t dim) {
3309   return squeeze_qtensor(self, dim);
3310 }
3311 
squeeze(const Tensor & self,IntArrayRef dims)3312 Tensor squeeze(const Tensor& self, IntArrayRef dims) {
3313   auto mask = dim_list_to_bitset(dims, self.dim());
3314   auto g = inferSqueezeGeometry(self, mask);
3315   at::Tensor result = self.as_strided_symint(std::get<0>(g), std::get<1>(g));
3316   auto maybe_outnames = namedinference::compute_squeeze_outnames(self, mask);
3317   namedinference::propagate_names_if_nonempty(result, maybe_outnames);
3318   return result;
3319 }
3320 
squeeze_quantized(const Tensor & self,IntArrayRef dim)3321 Tensor squeeze_quantized(const Tensor& self, IntArrayRef dim) {
3322   return squeeze_qtensor(self, dim);
3323 }
3324 
squeeze_(Tensor & self)3325 Tensor & squeeze_(Tensor& self) {
3326   auto g = inferSqueezeGeometry(self);
3327   self.as_strided__symint(std::get<0>(g), std::get<1>(g));
3328   return self;
3329 }
3330 
squeeze_(Tensor & self,int64_t dim)3331 Tensor & squeeze_(Tensor& self, int64_t dim) {
3332   int64_t dims = self.dim();
3333   dim = maybe_wrap_dim(dim, self.dim());
3334 
3335   if (dims == 0 || self.sym_sizes()[dim] != 1) {
3336     self.as_strided__symint(self.sym_sizes(), self.sym_strides());
3337     return self;
3338   }
3339   auto g = inferSqueezeGeometry(self, dim);
3340   self.as_strided__symint(std::get<0>(g), std::get<1>(g));
3341   return self;
3342 }
3343 
squeeze_(Tensor & self,IntArrayRef dims)3344 Tensor & squeeze_(Tensor &self, IntArrayRef dims) {
3345   auto mask = dim_list_to_bitset(dims, self.dim());
3346   auto g = inferSqueezeGeometry(self, mask);
3347   self.as_strided__symint(std::get<0>(g), std::get<1>(g));
3348   return self;
3349 }
3350 
3351 // NOTE [ Unsafe View ]
3352 // _unsafe_view() differs from view() in that the returned tensor isn't treated
3353 // as a view for the purposes of automatic differentiation. (It's not listed in
3354 // VIEW_FUNCTIONS in gen_inplace_or_view_type.py).  It's only safe to use if the `self` tensor
3355 // is temporary. For example, the viewed tensor here (a + b) is discarded immediately
3356 // after viewing:
3357 //
3358 //  res = at::_unsafe_view(a + b, size);
3359 //
3360 // This is a hack because in-place operations on tensors treated like views
3361 // can be much more expensive than the same operations on non-view tensors.
3362 
view_impl(const Tensor & self,IntArrayRef size)3363 inline Tensor view_impl(const Tensor& self, IntArrayRef size) {
3364 
3365   at::DimVector inferred_size = at::infer_size_dv(size, self.numel());
3366   auto stride = at::detail::computeStride(self.sizes(),
3367                                           self.strides(),
3368                                           inferred_size);
3369   TORCH_CHECK(stride.has_value(), "view size is "
3370     "not compatible with input tensor's size and stride (at least one dimension"
3371     " spans across two contiguous subspaces). Use .reshape(...) instead.");
3372   return alias_with_sizes_and_strides(self, inferred_size, *stride);
3373 
3374 }
3375 
_unsafe_view(const Tensor & self,IntArrayRef size)3376 Tensor _unsafe_view(const Tensor& self, IntArrayRef size) {
3377   return view_impl(self, size);
3378 }
3379 
unsqueeze(const Tensor & self,int64_t dim)3380 Tensor unsqueeze(const Tensor& self, int64_t dim) {
3381   dim = maybe_wrap_dim(dim, self.dim() + 1);
3382   auto g = inferUnsqueezeGeometry_symint(self, dim);
3383   return self.as_strided_symint(g.sizes, g.strides);
3384 }
3385 
unsqueeze_sparse(Tensor const & self,int64_t dim)3386 Tensor unsqueeze_sparse(Tensor const &self, int64_t dim) {
3387   dim = maybe_wrap_dim(dim, self.dim() + 1);
3388   int64_t sparse_dim = self.sparse_dim();
3389   int64_t dense_dim = self.dense_dim();
3390   auto indices = self._indices();
3391   auto sizes = self.sizes().vec();
3392   sizes.insert(sizes.begin() + dim, 1);
3393   if (dim <= sparse_dim) {
3394     auto new_indices = at::cat(
3395         {indices.narrow(0, 0, dim),
3396          at::zeros(
3397              {1, indices.size(1)},
3398              kLong,
3399              indices.options().layout_opt(),
3400              indices.options().device_opt(),
3401              indices.options().pinned_memory_opt()),
3402          indices.narrow(0, dim, indices.size(0) - dim)});
3403     return _sparse_coo_tensor_with_dims_and_tensors(
3404         sparse_dim + 1, dense_dim, sizes, new_indices, self._values(), self.options());
3405   } else {
3406     return _sparse_coo_tensor_with_dims_and_tensors(
3407         sparse_dim, dense_dim + 1, sizes, indices, self._values().unsqueeze(dim - sparse_dim + 1), self.options());
3408   }
3409 }
3410 
unsqueeze_quantized(const Tensor & self,int64_t dim)3411 Tensor unsqueeze_quantized(const Tensor& self, int64_t dim) {
3412   dim = maybe_wrap_dim(dim, self.dim() + 1);
3413   auto g = inferUnsqueezeGeometry(self, dim);
3414   auto quantizer = get_qtensorimpl(self)->quantizer();
3415   if (quantizer->qscheme() == QScheme::PER_CHANNEL_AFFINE) {
3416     const auto* per_channel_quantizer = static_cast<at::PerChannelAffineQuantizer*>(quantizer.get());
3417     auto axis = per_channel_quantizer->axis();
3418     if (axis >= dim) {
3419       axis += 1;
3420     }
3421     quantizer = make_per_channel_affine_quantizer(per_channel_quantizer->scales(),
3422                                                   per_channel_quantizer->zero_points(),
3423                                                   axis,
3424                                                   quantizer->scalar_type());
3425   }
3426   return make_qtensor(self, g.sizes, g.strides, std::move(quantizer));
3427 }
3428 
unsqueeze_(Tensor & self,int64_t dim)3429 Tensor & unsqueeze_(Tensor& self, int64_t dim) {
3430   dim = maybe_wrap_dim(dim, self.dim() + 1);
3431 
3432   auto g = inferUnsqueezeGeometry_symint(self, dim);
3433   self.as_strided__symint(g.sizes, g.strides);
3434   return self;
3435 }
3436 
flatten(const Tensor & self,int64_t start_dim,int64_t end_dim)3437 Tensor flatten(const Tensor& self, int64_t start_dim, int64_t end_dim) {
3438   start_dim = maybe_wrap_dim(start_dim, self.dim());
3439   end_dim = maybe_wrap_dim(end_dim, self.dim());
3440   TORCH_CHECK(start_dim <= end_dim, "flatten() has invalid args: start_dim cannot come after end_dim");
3441 
3442   if (self.dim() == 0) {
3443     return self.reshape({1});
3444   }
3445   if (start_dim == end_dim) {
3446     return self;
3447   }
3448 
3449   // We don't want to infer_size on the entire shape, because that can give us an extra degree
3450   // of freedom we don't want; for example, consider shape [0, 1, 3, 0], with start_dim=1, end_dim=2.
3451   // It's clear we want result shape [0, 3, 0] but passing [0, -1, 0] to infer_size means the -1
3452   // can take on any value and satisfy the constraints.
3453   auto slice_numel = c10::multiply_integers(self.sym_sizes().slice(start_dim, end_dim - start_dim + 1));
3454   std::vector<c10::SymInt> shape;
3455   shape.reserve(self.dim() - end_dim + start_dim);
3456   for (const auto i : c10::irange(start_dim)) {
3457     shape.push_back(self.sym_sizes()[i]);
3458   }
3459   shape.push_back(slice_numel);
3460   for (const auto i : c10::irange(end_dim + 1, self.dim())) {
3461     shape.push_back(self.sym_sizes()[i]);
3462   }
3463 
3464   return native::reshape_symint(self, shape);
3465 }
3466 
flatten(const Tensor & self,int64_t start_dim,int64_t end_dim,Dimname out_dim)3467 Tensor flatten(const Tensor& self, int64_t start_dim, int64_t end_dim, Dimname out_dim) {
3468   start_dim = maybe_wrap_dim(start_dim, self.dim());
3469   end_dim = maybe_wrap_dim(end_dim, self.dim());
3470   TORCH_CHECK(start_dim <= end_dim, "flatten() has invalid args: start_dim cannot come after end_dim");
3471 
3472   auto outnames = self.names().vec();
3473   outnames.erase(outnames.begin() + start_dim, outnames.begin() + end_dim + 1);
3474   outnames.insert(outnames.begin() + start_dim, out_dim);
3475 
3476   Tensor result;
3477   {
3478     NoNamesGuard guard;
3479     result = native::flatten(self, start_dim, end_dim);
3480   }
3481   internal_set_names_inplace(result, outnames);
3482   return result;
3483 }
3484 
flatten(const Tensor & self,Dimname start_dim,Dimname end_dim,Dimname out_dim)3485 Tensor flatten(const Tensor& self, Dimname start_dim, Dimname end_dim, Dimname out_dim) {
3486   auto start_pos = dimname_to_position(self, start_dim);
3487   auto end_pos  = dimname_to_position(self, end_dim);
3488   return native::flatten(self, start_pos, end_pos, out_dim);
3489 }
3490 
flatten(const Tensor & self,DimnameList dims,Dimname out_dim)3491 Tensor flatten(const Tensor& self, DimnameList dims, Dimname out_dim) {
3492   auto positions = dimnames_to_positions(self, dims);
3493   TORCH_CHECK(!positions.empty(),
3494       "flatten(tensor, dims, out_dim): dims cannot be empty");
3495   for (const auto i : c10::irange(positions.size() - 1)) {
3496     if (positions[i] + 1 == positions[i + 1]) continue;
3497     TORCH_CHECK(positions[i] + 1 == positions[i + 1],
3498         "flatten(tensor, dims, out_dim): dims ", dims, " must be consecutive ",
3499         "in Tensor", self.names());
3500   }
3501   return native::flatten(self, *dims.begin(), *(dims.end() - 1), out_dim);
3502 }
3503 
ravel(const Tensor & self)3504 Tensor ravel(const Tensor& self) {
3505   return self.contiguous().view(-1);
3506 }
3507 
handle_unflatten_exception(const std::runtime_error & e,const Tensor & self,int64_t dim,SymIntArrayRef sizes,std::optional<DimnameList> names)3508 static inline void handle_unflatten_exception(const std::runtime_error &e,
3509                                               const Tensor &self,
3510                                               int64_t dim,
3511                                               SymIntArrayRef sizes,
3512                                               std::optional <DimnameList> names) {
3513   if (!strstr(e.what(), "is invalid for input of size")) {
3514     TORCH_CHECK(false, "unflatten got an unexpected error:\n", e.what());
3515   }
3516 
3517   if (self.has_names()) {
3518     TORCH_CHECK(false,
3519                 "unflatten: Provided sizes ", sizes, " don't multiply up to the size of dim ",
3520                 dim, " (", self.names()[dim], ": ", self.sym_size(dim), ") in Tensor", self.names());
3521 
3522   } else {
3523     TORCH_CHECK(false,
3524                 "unflatten: Provided sizes ", sizes, " don't multiply up to the size of dim ",
3525                 dim, " (", self.sym_size(dim), ") in the input tensor");
3526   }
3527 }
3528 
unflatten_impl(const Tensor & self,int64_t dim,SymIntArrayRef sizes,std::optional<DimnameList> names)3529 static Tensor unflatten_impl(const Tensor& self, int64_t dim, SymIntArrayRef sizes, std::optional<DimnameList> names) {
3530   dim = maybe_wrap_dim(dim, self.dim());
3531 
3532   TORCH_CHECK(!sizes.empty(), "unflatten: sizes must be non-empty");
3533   TORCH_INTERNAL_ASSERT(!names || names->size() == sizes.size());
3534   if (self.has_names()) {
3535     TORCH_CHECK(names, "unflatten: input is a named tensor but no names were given for unflattened sizes");
3536   }
3537 
3538   SymDimVector inferred_size;
3539   try {
3540     inferred_size = at::infer_size_dv(sizes, self.sym_size(dim));
3541   } catch (const std::runtime_error& e) {
3542     // at::infer_size would throw std::runtime_error for invalid size,
3543     // catch the runtime_error and display the error message in a more user-friendly way
3544     // for both tensors and named tensors
3545     handle_unflatten_exception(e, self, dim, sizes, names);
3546   }
3547 
3548   SymDimVector shape(self.sym_sizes().begin(), self.sym_sizes().end());
3549   shape.erase(shape.begin() + dim);
3550   shape.insert(shape.begin() + dim, inferred_size.begin(), inferred_size.end());
3551 
3552   Tensor result;
3553   {
3554     NoNamesGuard guard;
3555     result = self.view_symint(shape);
3556   }
3557 
3558   if (names) {
3559     auto outnames = self.names().vec();
3560     outnames.erase(outnames.begin() + dim);
3561     outnames.insert(outnames.begin() + dim, names->begin(), names->end());
3562     at::internal_set_names_inplace(result, outnames);
3563   }
3564 
3565   return result;
3566 }
3567 
unflatten_symint(const Tensor & self,int64_t dim,SymIntArrayRef sizes)3568 Tensor unflatten_symint(const Tensor& self, int64_t dim, SymIntArrayRef sizes) {
3569   return native::unflatten_impl(self, dim, sizes, std::nullopt);
3570 }
3571 
unflatten_dimname_symint(const Tensor & self,Dimname dim,SymIntArrayRef sizes,DimnameList names)3572 Tensor unflatten_dimname_symint(const Tensor& self, Dimname dim, SymIntArrayRef sizes, DimnameList names) {
3573   return native::unflatten_impl(self, dimname_to_position(self, dim), sizes, names);
3574 }
3575 
view_as(const Tensor & self,const Tensor & other)3576 Tensor view_as(const Tensor& self, const Tensor& other) {
3577   return self.view_symint(other.sym_sizes());
3578 }
3579 
unbind(const Tensor & self,int64_t dim)3580 std::vector<Tensor> unbind(const Tensor &self, int64_t dim) {
3581   dim = maybe_wrap_dim(dim, self.dim());
3582   int64_t size = self.size(dim);
3583   std::vector<Tensor> tensors(size);
3584   for (const auto i : c10::irange(size)) {
3585     tensors[i] = self.select(dim, i);
3586   }
3587   return tensors;
3588 }
3589 
unbind(const Tensor & self,Dimname dim)3590 std::vector<Tensor> unbind(const Tensor& self, Dimname dim) {
3591   return at::unbind(self, dimname_to_position(self, dim));
3592 }
3593 
meshgrid(TensorList tensors)3594 std::vector<Tensor> meshgrid(TensorList tensors) {
3595   TORCH_WARN_ONCE("torch.meshgrid: in an upcoming release, it will be required to pass the "
3596                   "indexing argument.");
3597   return native::meshgrid(tensors, /*indexing=*/"ij");
3598 }
3599 
meshgrid(TensorList tensors,c10::string_view indexing)3600 std::vector<Tensor> meshgrid(TensorList tensors,
3601                              c10::string_view indexing) {
3602   int64_t size = tensors.size();
3603   TORCH_CHECK(size > 0, "meshgrid expects a non-empty TensorList");
3604 
3605   for(const auto i: c10::irange(size - 1)){
3606     TORCH_CHECK(tensors[i].dtype() == tensors[i+1].dtype(), "meshgrid expects all tensors to have the same dtype");
3607     TORCH_CHECK(tensors[i].device() == tensors[i+1].device(), "meshgrid expects all tensors to have the same device");
3608   }
3609 
3610   // Input tensors is of type TensorList, which is an alias to a
3611   // constant array slice, which doesn't allow for mutations. We may
3612   // need to swap our first two elements if indexing is "ij", so we
3613   // unconditionally create a vector that we can reorder to keep the
3614   // implementation simple.
3615   //
3616   // We are not concerned with the performance of this relative to
3617   // constructor a grid for each input.
3618   std::vector<std::reference_wrapper<const Tensor>> tensor_refs(tensors.begin(),
3619                                                                 tensors.end());
3620 
3621   // Whether or not to swap the first two tensors.
3622   //
3623   // We only swap if there are at least two* input tensors (obviously)
3624   // and if indexing is "xy".
3625   //
3626   // A reminder about "xy" semantics: "xy" semantics implies that the
3627   // output grids are in the cartesian coordinate system. Thus the
3628   // first dimension is the "x" axis (corresponding to column) and the
3629   // second dimension is the "y" axis (corresponding to row). Tensors,
3630   // however, generally consider the first axis to be the row and the
3631   // second axis to be the columns. Thus we flip the two dimensions in
3632   // contrast to "ij" indexing.
3633   //
3634   // It turns out that it's easiest to implement this by just swapping
3635   // the first two inputs. However, the order of the outputs still
3636   // must correspond to the order of the inputs. Thus we also must
3637   // swap the outputs if we swapped the inputs.
3638   //
3639   // * Why do we even support this function for exactly one input?
3640   bool swap_first_and_second_tensors = false;
3641 
3642   if (indexing == "xy") {
3643     // We can only swap if there are multiple tensors.
3644     swap_first_and_second_tensors = size >= 2;
3645     if (swap_first_and_second_tensors) {
3646       std::swap(tensor_refs[0], tensor_refs[1]);
3647     }
3648   } else {
3649     // Only "xy" and "ij" are supported, and we already checked for
3650     // "xy" above. Only "ij" remains as a valid mode.
3651     TORCH_CHECK(indexing == "ij",
3652                 "torch.meshgrid: indexing must be one of \"xy\" or \"ij\", "
3653                 "but received: ", indexing);
3654   }
3655 
3656   std::vector<c10::SymInt> shape(size);
3657   for(const auto i: c10::irange(size)){
3658     TORCH_CHECK(tensor_refs[i].get().dim() <= 1,
3659                 "torch.meshgrid: Expected 0D or 1D tensor in the tensor list but got: ", tensor_refs[i]);
3660     shape[i] = tensor_refs[i].get().sym_numel();  // treat 0D tensors as if they were a 1D tensor
3661   }
3662   std::vector<Tensor> grids;
3663   grids.reserve(size);
3664   std::vector<c10::SymInt> view_shape(size, 1);
3665   for(const auto i: c10::irange(size)){
3666     view_shape[i] = -1;  // select this dimension to infer
3667     grids.push_back(tensor_refs[i].get().view_symint(view_shape).expand_symint(shape));
3668     view_shape[i] = 1;  // restore to previous value
3669   }
3670 
3671   // Remember we need to also swap the outputs if we swapped the inputs.
3672   if (swap_first_and_second_tensors) {
3673     std::swap(grids[0], grids[1]);
3674   }
3675   return grids;
3676 }
3677 
3678 // Numpy-style `a.T`: returns the tensor
3679 // with dims reversed
numpy_T(const Tensor & self)3680 Tensor numpy_T(const Tensor &self) {
3681   const auto n = self.dim();
3682   if (n != 2 && n != 0) {
3683     TORCH_WARN_ONCE(
3684         "The use of `x.T` on tensors of dimension other than 2 to reverse their shape is deprecated ",
3685         "and it will throw an error in a future release. Consider `x.mT` to transpose batches of matrices ",
3686         "or `x.permute(*torch.arange(x.ndim - 1, -1, -1))` to reverse the dimensions of a tensor."
3687     );
3688   }
3689   if (n == 0) {
3690    // Added in PyTorch 2.0
3691    TORCH_WARN_ONCE("Tensor.T is deprecated on 0-D tensors. This function is the identity in these cases.");
3692   }
3693   DimVector transpose_dims;
3694   for (int64_t i = n - 1; i >= 0; --i) {
3695     transpose_dims.push_back(i);
3696   }
3697   return self.permute(transpose_dims);
3698 }
3699 
matrix_H(const Tensor & self)3700 Tensor matrix_H(const Tensor &self) {
3701   const auto ndim = self.dim();
3702   if (ndim == 0) {
3703    // Added in PyTorch 2.0
3704    TORCH_WARN_ONCE("Tensor.H is deprecated on 0-D tensors. Consider using x.conj().");
3705   }
3706   TORCH_CHECK(ndim == 2 || ndim == 0,
3707       "tensor.H is only supported on matrices (2-D tensors). Got ", ndim, "-D tensor.",
3708       ndim > 2 ? " For batches of matrices, consider using tensor.mH" : "");
3709   if (self.is_complex()) {
3710     return ndim == 0 ? self.conj() : self.transpose(-2, -1).conj();
3711   } else {
3712     return ndim == 0 ? self : self.transpose(-2, -1);
3713   }
3714 }
3715 
3716 namespace {
_adjoint(const Tensor & self,const bool transpose,const char * const name)3717 Tensor _adjoint(const Tensor &self, const bool transpose, const char* const name) {
3718   const auto ndim = self.dim();
3719   TORCH_CHECK(ndim != 1,
3720       "tensor.", name, " is only supported on matrices or batches of matrices. Got 1-D tensor.");
3721   if (transpose || !self.is_complex()) {
3722     return ndim == 0 ? self : self.transpose(-2, -1);
3723   } else {
3724     return ndim == 0 ? self.conj() : self.transpose(-2, -1).conj();
3725   }
3726 }
3727 } // anonymous namespace
3728 
mT(const Tensor & self)3729 Tensor mT(const Tensor &self) {
3730   if (self.dim() == 0) {
3731    // Added in PyTorch 2.0
3732    TORCH_WARN_ONCE("Tensor.mT is deprecated on 0-D tensors. This function is the identity in these cases.");
3733   }
3734   return _adjoint(self, /*transpose=*/true, "mT");
3735 }
3736 
mH(const Tensor & self)3737 Tensor mH(const Tensor &self) {
3738   if (self.dim() == 0) {
3739     // Added in PyTorch 2.0
3740    TORCH_WARN_ONCE("Tensor.mH is deprecated on 0-D tensors. Consider using x.conj().");
3741   }
3742   return _adjoint(self, /*transpose=*/false, "mH");
3743 }
3744 
adjoint(const Tensor & self)3745 Tensor adjoint(const Tensor &self) {
3746   if (self.dim() == 0) {
3747    TORCH_WARN_ONCE("adjoint() is deprecated on 0-D tensors. Consider using x.conj().");
3748   }
3749   return _adjoint(self, /*transpose=*/false, "adjoint()");
3750 }
3751 
view(const Tensor & self,at::IntArrayRef size)3752 Tensor view(const Tensor& self,
3753             at::IntArrayRef size) {
3754   return view_impl(self, size);
3755 }
3756 
alias(const Tensor & self)3757 Tensor alias(const Tensor& self) {
3758   return alias_with_sizes_and_strides(self, self.sym_sizes(), self.sym_strides());
3759 }
3760 
detach(const Tensor & self)3761 Tensor detach(const Tensor& self) {
3762   // NB: detach() is not the same thing as alias()! The main difference is that
3763   // detach does not allow metadata change while alias does.
3764   return Tensor(self.getIntrusivePtr()->shallow_copy_and_detach(
3765     // NB: The ADInplaceOrView logic will overwrite these with the
3766     // appropriate values if it runs; otherwise these are the values.
3767     /*version_counter=*/0,
3768     /*allow_tensor_metadata_change=*/false));
3769 }
3770 
unfold(const Tensor & self,int64_t d,int64_t size,int64_t step)3771 Tensor unfold(const Tensor& self, int64_t d, int64_t size, int64_t step) {
3772   // some special handling to deal with allow d == 0 when self.dim() == 0
3773   auto ndim = self.dim();
3774   d = at::maybe_wrap_dim(d, ndim, /*wrap_scalar=*/true);
3775 
3776   auto sizes = self.sizes().vec();
3777   auto strides = self.strides().vec();
3778   int64_t max_size = self.dim() == 0 ? 1 : sizes[d];
3779   TORCH_CHECK(size <= max_size, "maximum size for tensor at dimension ", d,
3780                                 " is ", max_size, " but size is ", size);
3781   TORCH_CHECK(step > 0, "step is ", step, " but must be > 0");
3782   sizes.push_back(size);
3783   strides.push_back(self.dim() == 0 ? 1 : strides[d]);
3784   // The if handles the self.dim() == 0 case
3785   if (d < ndim) {
3786     sizes[d] = (sizes[d] - size) / step + 1;
3787     strides[d] *= step;
3788   }
3789   return self.as_strided(sizes, strides);
3790 }
3791 
diag(const Tensor & self,int64_t offset)3792 Tensor diag(const Tensor& self, int64_t offset) {
3793   auto ndim = self.dim();
3794   TORCH_CHECK(ndim == 1 || ndim == 2, "diag(): Supports 1D or 2D tensors. Got ", self.dim(), "D");
3795   if (ndim == 1) {
3796     return at::diag_embed(self, offset);
3797   } else {
3798     // We return a copy of the diagonal
3799     return at::diagonal_copy(self, offset);
3800   }
3801 }
3802 
diag_out(const Tensor & self,int64_t offset,Tensor & out)3803 Tensor& diag_out(const Tensor& self, int64_t offset, Tensor& out) {
3804   auto ndim = self.dim();
3805   TORCH_CHECK(ndim == 1 || ndim == 2, "Supports 1D or 2D tensors. Got ", self.dim(), "D");
3806   if (ndim == 1) {
3807     TORCH_CHECK(
3808         canCast(self.scalar_type(), out.scalar_type()),
3809         "diag: result type ", self.scalar_type(), " can't be cast to the desired out= type ",
3810         out.scalar_type());
3811     return at::diag_embed_out(out, self, offset);
3812   } else {
3813     return at::diagonal_copy_out(out, self, offset);
3814   }
3815 }
3816 
diagonal_backward_symint(const Tensor & grad,SymIntArrayRef input_sizes,int64_t offset,int64_t dim1,int64_t dim2)3817 Tensor diagonal_backward_symint(const Tensor & grad, SymIntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
3818   auto grad_input = at::zeros_symint(input_sizes, grad.options());
3819   auto diag = grad_input.diagonal(offset, dim1, dim2);
3820   diag.copy_(grad);
3821   return grad_input;
3822 }
3823 
movedim(const Tensor & self,IntArrayRef src,IntArrayRef dst)3824 Tensor movedim(const Tensor& self, IntArrayRef src, IntArrayRef dst) {
3825   TORCH_CHECK(src.size() == dst.size(), "movedim: Invalid source or destination dims: source (",
3826               src, " dims) should contain the same number of dims as destination (", dst, " dims)");
3827 
3828   size_t self_dim = self.dim();
3829   DimVector normalized_src(src.size());
3830   DimVector normalized_dst(dst.size());
3831 
3832   auto wrap_dims = [&self_dim](const IntArrayRef& vec, DimVector& normalized_vec) {
3833     for (const auto i : c10::irange(vec.size())) {
3834       normalized_vec[i] = maybe_wrap_dim(vec[i], self_dim);
3835     }
3836   };
3837 
3838   wrap_dims(src, normalized_src);
3839   wrap_dims(dst, normalized_dst);
3840 
3841   auto all_unique = [](const DimVector& dims) {
3842     DimVector copy = dims;
3843     std::sort(copy.begin(), copy.end());
3844     auto duplicate = std::adjacent_find(copy.begin(), copy.end());
3845     return duplicate == copy.end();
3846   };
3847   TORCH_CHECK(all_unique(normalized_src), "movedim: repeated dim in `source` (", src, ")");
3848   TORCH_CHECK(all_unique(normalized_dst), "movedim: repeated dim in `destination` (", dst, ")");
3849 
3850   // handle the case of scalar tensor as a no-op
3851   if (self_dim == 0)
3852     return self.alias();
3853 
3854   // TODO: The algorithm below can probably be optimized.
3855   // Reference: https://github.com/pytorch/pytorch/pull/41480#discussion_r456100505
3856 
3857   // Algorithm Walkthrough
3858   // Example Input
3859   // Variable State:
3860   //     normalized_src = 0, 1
3861   //     normalized_dst = 2, 4
3862   //     self_dim = 5
3863   DimVector order(self_dim);
3864   DimVector source_dims(self_dim);
3865   DimVector destination_dims(self_dim);
3866 
3867   // We initialize two vectors to track update to the dims
3868   // `order` contains the final order of the dim positions.
3869   // Variable State:
3870   //     order = NA, NA, NA, NA, NA
3871   //     source_dims = 0, 1, 2, 3, 4
3872   //     destination_dims = 0, 1, 2, 3, 4
3873   std::iota(source_dims.begin(), source_dims.end(), 0);
3874   std::iota(destination_dims.begin(), destination_dims.end(), 0);
3875 
3876   // We mark and update position for the dim provided by user
3877   // i.e. `normalized_src` and `normalized_dims`
3878   // Variable State:
3879   //     order = NA, NA, 0, NA, 1
3880   //     source_dims = -1, -1, 2, 3, 4
3881   //     destination_dims = 0, 1, -1, 3, -1
3882   for (const auto i : c10::irange(src.size())) {
3883       order[normalized_dst[i]] = normalized_src[i];
3884       source_dims[normalized_src[i]] = -1;
3885       destination_dims[normalized_dst[i]] = -1;
3886   }
3887 
3888   // Remove the dims whose position we already know,
3889   // the ones marked with -1 in previous step
3890   // Variable State:
3891   //     source_dims = 2, 3, 4
3892   //     destination_dims = 0, 1, 3
3893   auto source_iter = std::remove(source_dims.begin(), source_dims.end(), -1);
3894   auto destination_iter = std::remove(destination_dims.begin(), destination_dims.end(), -1);
3895 
3896   int64_t rest_dim = self.dim() - src.size();
3897   TORCH_INTERNAL_ASSERT(std::distance(source_dims.begin(), source_iter)  == rest_dim);
3898   TORCH_INTERNAL_ASSERT(std::distance(destination_dims.begin(), destination_iter)  == rest_dim);
3899 
3900   // Update the position of the remaining dimensions.
3901   // `source_dims` now contains the original position
3902   // `destination_dims` contains the new position it will shifted to
3903   // after considering the user inputs.
3904   // Variable State:
3905   //     order = 2, 3, 0, 4, 1
3906   for (const auto i : c10::irange(rest_dim)) {
3907       order[destination_dims[i]] = source_dims[i];
3908   }
3909 
3910   return self.permute(order);
3911 }
3912 
movedim(const Tensor & self,int64_t src,int64_t dst)3913 Tensor movedim(const Tensor& self, int64_t src, int64_t dst) {
3914   return at::movedim(self, IntArrayRef{src}, IntArrayRef{dst});
3915 }
3916 
moveaxis(const Tensor & self,IntArrayRef src,IntArrayRef dst)3917 Tensor moveaxis(const Tensor& self, IntArrayRef src, IntArrayRef dst) {
3918   return at::movedim(self, src, dst);
3919 }
3920 
moveaxis(const Tensor & self,int64_t src,int64_t dst)3921 Tensor moveaxis(const Tensor& self, int64_t src, int64_t dst) {
3922   return at::movedim(self, IntArrayRef{src}, IntArrayRef{dst});
3923 }
3924 
swapaxes(const Tensor & self,int64_t axis0,int64_t axis1)3925 Tensor swapaxes(const Tensor& self, int64_t axis0, int64_t axis1) {
3926   return self.transpose(axis0, axis1);
3927 }
3928 
swapaxes_(Tensor & self,int64_t axis0,int64_t axis1)3929 Tensor& swapaxes_(Tensor& self, int64_t axis0, int64_t axis1) {
3930   return self.transpose_(axis0, axis1);
3931 }
3932 
swapdims(const Tensor & self,int64_t dim0,int64_t dim1)3933 Tensor swapdims(const Tensor& self, int64_t dim0, int64_t dim1) {
3934   return self.transpose(dim0, dim1);
3935 }
3936 
swapdims_(Tensor & self,int64_t dim0,int64_t dim1)3937 Tensor& swapdims_(Tensor& self, int64_t dim0, int64_t dim1) {
3938   return self.transpose_(dim0, dim1);
3939 }
3940 
flatten_dense_tensors(TensorList tensors)3941 Tensor flatten_dense_tensors(TensorList tensors) {
3942   static auto flatten = [](const Tensor &t) { return t.contiguous().view({-1}); };
3943   if (tensors.size() == 1)
3944     return flatten(tensors[0]);
3945   return at::cat(fmap(tensors, flatten));
3946 }
3947 
unflatten_dense_tensors(const Tensor & flat,TensorList tensors)3948 std::vector<Tensor> unflatten_dense_tensors(const Tensor& flat, TensorList tensors) {
3949   std::vector<Tensor> outputs;
3950   outputs.reserve(tensors.size());
3951   size_t offset = 0;
3952   for (const auto & tensor : tensors) {
3953     auto numel = tensor.numel();
3954     // If unflatten an empty tensor, create a new empty tensor using
3955     // flat tensor Options.
3956     // This can avoid the unflattened empty tensor to share the same storage
3957     // with other unflatten tensors.
3958     if (numel == 0) {
3959       outputs.push_back(at::empty({0}, flat.options()));
3960     } else {
3961       outputs.push_back(flat.narrow(0, offset, numel).view(tensor.sizes()));
3962       offset += numel;
3963     }
3964   }
3965   return outputs;
3966 }
3967 
3968 
3969 // Clones a tensor by cloning the underlying storage that it came from,
3970 // which allows us to replicate the exact strides/storage_offset in the cloned tensor.
3971 // Note [*_scatter ops preserve strides]
3972 // In order for functionalization to preserve stride correctness, the *_scatter
3973 // operators that it calls must preserve the striding behavior of their inputs.
3974 // Specifically, the output of *_scatter(base, mutated_view, ...)
3975 // should have identical size/stride/storage_offset to "base".
clone_preserve_strides(const at::Tensor & self)3976 at::Tensor clone_preserve_strides(const at::Tensor& self) {
3977   TORCH_INTERNAL_ASSERT(self.has_storage());
3978   // In cases where the input tensor has internal memory overlap, we cannot actually
3979   // preserve the strides/storage_offset of the input tensor, because
3980   // *_scatter ops will try to copy_() into the cloned tensor.
3981   // However, this should **never** show up in functionalized user code;
3982   // most aten ops that try to mutate a tensor with internal memory overlap would error anyway.
3983   //
3984   // The one place that this does come up is in autograd - if there's a select_scatter
3985   // in the forward, then autograd will generate one for the backward.
3986   // If the input to the select_scatter is grad_output, then this could be an expanded tensor
3987   // with internal overlap.
3988   if (at::has_internal_overlap(self) == at::MemOverlap::Yes) {
3989     return self.clone();
3990   }
3991   auto dtype_size = self.dtype().itemsize();
3992   auto nbytes = self.storage().sym_nbytes();
3993   TORCH_INTERNAL_ASSERT(nbytes % dtype_size == 0);
3994   auto numel = nbytes / dtype_size;
3995   auto self_full_size = self.as_strided_symint({std::move(numel)}, {1}, 0);
3996   auto clone = self_full_size.clone();
3997   auto out = clone.as_strided_symint(self.sym_sizes(), self.sym_strides(), self.sym_storage_offset());
3998   return out;
3999 }
4000 
4001 
slice_scatter(const at::Tensor & self,const at::Tensor & src,int64_t dim,std::optional<int64_t> start,std::optional<int64_t> end,int64_t step)4002 at::Tensor slice_scatter(const at::Tensor& self, const at::Tensor& src, int64_t dim, std::optional<int64_t> start, std::optional<int64_t> end, int64_t step) {
4003     // See Note [*_scatter ops preserve strides]
4004     auto output = clone_preserve_strides(self);
4005     auto slice = output.slice(dim, start, end, step);
4006     TORCH_CHECK(slice.sizes() == src.sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sizes(), ", slice size = ", slice.sizes());
4007     slice.copy_(src);
4008     return output;
4009 }
select_scatter_symint(const at::Tensor & self,const at::Tensor & src,int64_t dim,c10::SymInt index)4010 at::Tensor select_scatter_symint(const at::Tensor& self, const at::Tensor& src, int64_t dim, c10::SymInt index) {
4011     auto output = clone_preserve_strides(self);
4012     auto slice = output.select_symint(dim, std::move(index));
4013     TORCH_CHECK(slice.sizes() == src.sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sizes(), ", slice size = ", slice.sizes());
4014     slice.copy_(src);
4015     return output;
4016 }
diagonal_scatter(const at::Tensor & self,const at::Tensor & src,int64_t offset,int64_t dim1,int64_t dim2)4017 at::Tensor diagonal_scatter(const at::Tensor& self, const at::Tensor& src, int64_t offset, int64_t dim1, int64_t dim2) {
4018     // See Note [*_scatter ops preserve strides]
4019     auto output = clone_preserve_strides(self);
4020     auto slice = output.diagonal(offset, dim1, dim2);
4021     TORCH_CHECK(slice.sizes() == src.sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sizes(), ", slice size = ", slice.sizes());
4022     slice.copy_(src);
4023     return output;
4024 }
as_strided_scatter_symint(const at::Tensor & self,const at::Tensor & src,at::SymIntArrayRef size,at::SymIntArrayRef stride,std::optional<c10::SymInt> storage_offset)4025 at::Tensor as_strided_scatter_symint(const at::Tensor& self, const at::Tensor& src, at::SymIntArrayRef size, at::SymIntArrayRef stride, std::optional<c10::SymInt> storage_offset) {
4026     // See Note [as_strided_scatter backward support]
4027     TORCH_INTERNAL_ASSERT(!self.requires_grad() || self.is_contiguous(), "as_strided_scatter is currently only supported for contiguous inputs");
4028     // See Note [*_scatter ops preserve strides]
4029     auto output = clone_preserve_strides(self);
4030     auto slice = output.as_strided_symint(size, stride, std::move(storage_offset));
4031     TORCH_CHECK(slice.sym_sizes() == src.sym_sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sym_sizes(), ", slice size = ", slice.sym_sizes());
4032     slice.copy_(src);
4033     return output;
4034 }
4035 
4036 // The default implementation of lift is a no-op.
4037 // If TLS is set appropriately (for wrapper-tensor keys like Functionalize or functorch transforms),
4038 // then we'll dispatch to one of their implementations, which will properly lift the tensor into a wrapper.
lift(const at::Tensor & self)4039 at::Tensor lift(const at::Tensor& self) {
4040     return self;
4041 }
4042 
4043 // See notes in native_functions.yaml
lift_fresh(const at::Tensor & self)4044 at::Tensor lift_fresh(const at::Tensor& self) {
4045     return self;
4046 }
4047 
4048 // Autogen kernels for tensor list ops dont work on XLA. TODO(jakeszwe)
split_copy_Tensor_out(const at::Tensor & self,int64_t split_size,int64_t dim,at::TensorList out)4049 void split_copy_Tensor_out(const at::Tensor & self, int64_t split_size, int64_t dim, at::TensorList  out) {
4050   auto tmp = self.split(split_size, dim);
4051 
4052   TORCH_CHECK(out.size() == tmp.size(), "split_copy_Tensor_out() expected an out= argument of size ", tmp.size(), ", got size ", out.size());
4053   for (const auto i : c10::irange(out.size())) {
4054     out[i].copy_(tmp[i]);
4055   }
4056 }
4057 
split_with_sizes_copy_out(const at::Tensor & self,at::IntArrayRef split_sizes,int64_t dim,at::TensorList out)4058 void split_with_sizes_copy_out(const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim, at::TensorList  out) {
4059   auto tmp = self.split_with_sizes(split_sizes, dim);
4060 
4061   TORCH_CHECK(out.size() == tmp.size(), "split_with_sizes_copy_out() expected an out= argument of size ", tmp.size(), ", got size ", out.size());
4062   for (const auto i : c10::irange(out.size())) {
4063     if (resize_output_check(out[i], tmp[i].sizes())) {
4064       out[i].resize_(tmp[i].sizes());
4065     }
4066     TORCH_CHECK(out[i].dtype() == tmp[i].dtype(),
4067         "Expected out tensor to have dtype ", tmp[i].dtype(), ", but got ", out[i].dtype(), " instead");
4068     TORCH_CHECK(out[i].device() == tmp[i].device(),
4069         "Expected out tensor to have device ", tmp[i].device(), ", but got ", out[i].device(), " instead");
4070     out[i].copy_(tmp[i]);
4071   }
4072 }
4073 
unbind_copy_int_out(const at::Tensor & self,int64_t dim,at::TensorList out)4074 void unbind_copy_int_out(const at::Tensor & self, int64_t dim, at::TensorList  out) {
4075   auto tmp = self.unbind(dim);
4076 
4077   TORCH_CHECK(out.size() == tmp.size(), "unbind_copy_int_out() expected an out= argument of size ", tmp.size(), ", got size ", out.size());
4078   for (const auto i : c10::irange(out.size())) {
4079     out[i].copy_(tmp[i]);
4080   }
4081 }
4082 
sparse_dim_default(const Tensor & self)4083 int64_t sparse_dim_default(const Tensor& self) {
4084   TORCH_CHECK(self.layout() == kStrided, "sparse_dim expected sparse or strided tensor layout but got ", self.layout());
4085   return 0;
4086 }
4087 
dense_dim_default(const Tensor & self)4088 int64_t dense_dim_default(const Tensor& self) {
4089   TORCH_CHECK(self.layout() == kStrided, "dense_dim expected sparse or strided tensor layout but got ", self.layout());
4090   return self.dim();
4091 }
4092 
4093 } // namespace at::native
4094