1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Dispatch.h>
4 #include <ATen/Parallel.h>
5 #include <ATen/native/Repeat.h>
6 #include <c10/util/irange.h>
7
8 #ifndef AT_PER_OPERATOR_HEADERS
9 #include <ATen/Functions.h>
10 #include <ATen/NativeFunctions.h>
11 #else
12 #include <ATen/ops/empty.h>
13 #include <ATen/ops/repeat_interleave.h>
14 #include <ATen/ops/repeat_interleave_native.h>
15 #endif
16
17 template <typename index_t>
compute_cpu(const index_t * repeat_ptr,const int64_t * cumsum_ptr,index_t * result_ptr,int64_t size,int64_t result_size)18 static void compute_cpu(
19 const index_t* repeat_ptr,
20 const int64_t* cumsum_ptr,
21 index_t* result_ptr,
22 int64_t size,
23 int64_t result_size) {
24 TORCH_CHECK(
25 (result_size == cumsum_ptr[size - 1]),
26 "allocated size does not match required size");
27 at::parallel_for(0, size, 1, [&](int64_t i_begin, int64_t i_end) {
28 for (const auto i : c10::irange(i_begin, i_end)) {
29 int64_t end = cumsum_ptr[i];
30 index_t size = repeat_ptr[i];
31 TORCH_CHECK((size >= 0), "repeats can not be negative");
32 int64_t start = end - size;
33 for (const auto j : c10::irange(start, end)) {
34 result_ptr[j] = i;
35 }
36 }
37 });
38 }
39
40 namespace at::native {
41
// CPU entry point for the single-tensor form of repeat_interleave.
//
// Dispatches on repeat.scalar_type() through AT_DISPATCH_INDEX_TYPES and
// forwards to repeat_interleave_common, instantiated with compute_cpu for the
// selected index type.
//
// Parameters:
//   repeat       - tensor of repeat counts, passed through unchanged.
//   output_size  - optional size, passed through unchanged.
//
// Returns the tensor produced by repeat_interleave_common.
Tensor repeat_interleave_cpu(
    const Tensor& repeat,
    std::optional<int64_t> output_size) {
  // The dispatch macro evaluates to the value returned by the lambda, so the
  // result can be returned directly.
  return AT_DISPATCH_INDEX_TYPES(
      repeat.scalar_type(), "repeat_interleave_cpu", [&]() {
        return repeat_interleave_common<index_t, compute_cpu<index_t>>(
            repeat, output_size);
      });
}
53
// repeat_interleave with a tensor of repeat counts.
//
// Parameters:
//   self         - tensor whose entries are repeated.
//   repeats      - 0-dim or 1-dim tensor of repeat counts. A 0-dim tensor or
//                  a 1-dim tensor of length 1 is expanded to the length of
//                  `self` along `dim`; any other 1-dim tensor must already
//                  have that length.
//   dim          - dimension to repeat along. When absent, `self` is
//                  flattened and dimension 0 is used.
//   output_size  - optional size forwarded to the single-tensor
//                  at::repeat_interleave_symint call.
//
// Returns self.index_select(dim, <expanded indices>), with the conj / neg bits
// of `self` re-applied.
//
// Fails when `repeats` has more than one dimension or when its length does
// not match `self` along `dim`.
Tensor repeat_interleave_symint(
    const Tensor& self,
    const Tensor& repeats,
    std::optional<int64_t> dim,
    std::optional<SymInt> output_size) {
  Tensor source = self;

  // Remember the conj and neg bits and strip them for the duration of the
  // computation; they are put back on the result below.
  const bool had_conj = source.is_conj();
  if (had_conj) {
    source = source.conj();
  }
  const bool had_neg = source.is_neg();
  if (had_neg) {
    source = source._neg_view();
  }

  // No dim given: work on the flattened tensor along its only dimension.
  if (!dim.has_value()) {
    source = source.flatten();
    dim = 0;
  }
  const int64_t axis = dim.value();

  // Reject unsupported ranks up front, before any size comparison is made.
  const auto repeats_rank = repeats.dim();
  if (repeats_rank != 0 && repeats_rank != 1) {
    AT_ERROR("repeats must be 0-dim or 1-dim tensor");
  }

  Tensor per_entry_repeats = repeats;
  if (repeats_rank == 0 || repeats.sym_size(0) == 1) {
    // A single repeat count applies to every entry along `axis`.
    per_entry_repeats =
        repeats.reshape({1}).expand_symint({source.sym_size(axis)});
  } else {
    TORCH_CHECK(
        repeats.sym_size(0) == source.sym_size(axis),
        "repeats must have the same size as input along dim, but got repeats.size(0) = ",
        repeats.sym_size(0), " and input.size(", axis, ") = ", source.sym_size(axis)
    );
  }

  // The single-tensor overload yields the source index for every output
  // position; gathering with it performs the actual repetition.
  const Tensor gather_index =
      at::repeat_interleave_symint(per_entry_repeats, std::move(output_size));
  Tensor result = source.index_select(axis, gather_index);

  // Put the conj and neg bits back.
  if (had_conj) {
    result = result.conj();
  }
  if (had_neg) {
    result = result._neg_view();
  }
  return result;
}
100
// repeat_interleave with one repeat count shared by every entry.
//
// Parameters:
//   self         - tensor whose entries are repeated.
//   repeats      - number of copies of each entry; must be non-negative.
//   dim_opt      - dimension to repeat along. When absent, `self` is
//                  flattened and dimension 0 is used.
//   output_size  - optional; when given it must equal
//                  repeats * (size of the input along the chosen dimension).
//
// Returns a contiguous tensor in which the chosen dimension is `repeats`
// times as long as in the (possibly flattened) input.
//
// Fails when `repeats` is negative or when `output_size` is given and does
// not match the computed size.
Tensor repeat_interleave_symint(
    const Tensor& self,
    c10::SymInt repeats,
    std::optional<int64_t> dim_opt,
    std::optional<SymInt> output_size) {
  const bool has_dim = dim_opt.has_value();
  Tensor expanded = has_dim ? self : self.flatten();
  const int64_t repeat_dim =
      c10::maybe_wrap_dim(has_dim ? *dim_opt : 0, self.dim());
  TORCH_CHECK(repeats >= 0, "Repeats must be non-negative");

  // Insert a size-1 dimension right after `repeat_dim` and expand it to
  // `repeats`, so every entry is followed by its copies.
  const int64_t copies_dim = repeat_dim + 1;
  expanded = expanded.unsqueeze(copies_dim);
  auto target_shape = expanded.sym_sizes().vec();
  target_shape[copies_dim] = repeats;
  expanded = expanded.expand_symint(target_shape);

  // This argument doesn't really make sense for the scalar overload, but exists
  // for consistency with the tensor overload
  if (output_size.has_value()) {
    const auto expected_size =
        (repeats * target_shape[repeat_dim]).guard_int(__FILE__, __LINE__);
    TORCH_CHECK(*output_size == expected_size, "repeat_interleave: Invalid output_size, expected ",
                expected_size, " but got ", *output_size);
  }

  // Materialize the expanded view, then merge the two adjacent dimensions
  // back into one.
  Tensor materialized = expanded.clone(at::MemoryFormat::Contiguous);
  return materialized.flatten(repeat_dim, copies_dim);
}
125
126 } // namespace at::native
127