// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/Repeat.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/native/Repeat.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/repeat_interleave.h>
#include <ATen/ops/repeat_interleave_native.h>
#endif

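// Fills `result_ptr` so that index i appears repeat_ptr[i] times. `cumsum_ptr`
// holds the inclusive prefix sum of `repeat_ptr`, so the output slots for
// element i are [cumsum_ptr[i] - repeat_ptr[i], cumsum_ptr[i]). The loop over
// input elements is parallelized with at::parallel_for.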
template <typename index_t>
static void compute_cpu(
    const index_t* repeat_ptr,
    const int64_t* cumsum_ptr,
    index_t* result_ptr,
    int64_t size,
    int64_t result_size) {
  TORCH_CHECK(
      (result_size == cumsum_ptr[size - 1]),
      "allocated size does not match required size");
  at::parallel_for(0, size, 1, [&](int64_t i_begin, int64_t i_end) {
    for (const auto i : c10::irange(i_begin, i_end)) {
      int64_t end = cumsum_ptr[i];
      index_t size = repeat_ptr[i];
      TORCH_CHECK((size >= 0), "repeats can not be negative");
      int64_t start = end - size;
      for (const auto j : c10::irange(start, end)) {
        result_ptr[j] = i;
      }
    }
  });
}

namespace at::native {

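// CPU entry point for repeat_interleave over a 1-D `repeat` tensor: dispatches
// on the index type and delegates to the shared repeat_interleave_common
// helper from ATen/native/Repeat.h, which prepares the cumulative sum and
// output buffer before invoking compute_cpu.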
Tensor repeat_interleave_cpu(
    const Tensor& repeat,
    std::optional<int64_t> output_size) {
  Tensor output;
  AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_cpu", [&]() {
    output = repeat_interleave_common<index_t, compute_cpu<index_t>>(
        repeat, output_size);
  });

  return output;
}

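// repeat_interleave(self, repeats, dim, output_size) where `repeats` is a
// Tensor. Temporarily strips conjugate/negation views, flattens the input when
// no dim is given, broadcasts a scalar or single-element `repeats` to the size
// of the selected dimension, and gathers the result with index_select.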
Tensor repeat_interleave_symint(
    const Tensor& self,
    const Tensor& repeats,
    std::optional<int64_t> dim,
    std::optional<SymInt> output_size) {
  Tensor input = self;

  // Store conj and neg bits
  const auto conj = input.is_conj();
  if (conj) {
    input = input.conj();
  }
  const auto neg = input.is_neg();
  if (neg) {
    input = input._neg_view();
  }

  if (!dim) {
    input = input.flatten();
    dim = 0;
  }

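  // A 0-dim or single-element `repeats` applies the same count to every
  // element, so broadcast it to the length of the selected dimension.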
  Tensor repeats_ = repeats;
  if (repeats.dim() == 0 || (repeats.dim() == 1 && repeats.sym_size(0) == 1)) {
    repeats_ = repeats.reshape({1}).expand_symint({input.sym_size(dim.value())});
  } else if (repeats.dim() == 1) {
    TORCH_CHECK(
        repeats.sym_size(0) == input.sym_size(dim.value()),
        "repeats must have the same size as input along dim, but got repeats.size(0) = ",
        repeats.sym_size(0), " and input.size(", dim.value(), ") = ", input.sym_size(dim.value())
    );
  } else {
    AT_ERROR("repeats must be 0-dim or 1-dim tensor");
  }

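  // repeat_interleave on `repeats_` itself produces the gather indices
  // (element i repeated repeats_[i] times); index_select then materializes
  // the interleaved result along `dim`.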
  auto ret = input.index_select(
      dim.value(), at::repeat_interleave_symint(repeats_, std::move(output_size)));
  // Restore conj and neg bits
  if (conj) {
    ret = ret.conj();
  }
  if (neg) {
    ret = ret._neg_view();
  }
  return ret;
}

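// repeat_interleave(self, repeats, dim, output_size) where `repeats` is a
// scalar (SymInt): every element along `dim` is repeated the same number of
// times, implemented with an unsqueeze/expand/flatten instead of index_select.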
Tensor repeat_interleave_symint(
    const Tensor& self,
    c10::SymInt repeats,
    std::optional<int64_t> dim_opt,
    std::optional<SymInt> output_size) {
  Tensor input = dim_opt ? self : self.flatten();
  int64_t dim = c10::maybe_wrap_dim(dim_opt.value_or(0), self.dim());
  TORCH_CHECK(repeats >= 0, "Repeats must be non-negative");

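  // Insert a new dimension of size `repeats` right after `dim`; flattening
  // `dim` and `dim + 1` afterwards yields each element repeated `repeats`
  // times in order.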
  input = input.unsqueeze(dim + 1);
  auto expand_shape = input.sym_sizes().vec();
  expand_shape[dim + 1] = repeats;
  input = input.expand_symint(expand_shape);

  // This argument doesn't really make sense for the scalar overload, but exists
  // for consistency with the tensor overload
  if (output_size) {
    auto calculated_size = (repeats * expand_shape[dim]).guard_int(__FILE__, __LINE__);
    TORCH_CHECK(*output_size == calculated_size, "repeat_interleave: Invalid output_size, expected ",
                calculated_size, " but got ", *output_size);
  }

  return input.clone(at::MemoryFormat::Contiguous).flatten(dim, dim + 1);
}

} // namespace at::native