xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/TensorTransformations.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/Tensor.h>
2 
3 #ifndef AT_PER_OPERATOR_HEADERS
4 #include <ATen/Functions.h>
5 #else
6 #include <ATen/ops/roll.h>
7 #endif
8 
9 #include <c10/util/Exception.h>
10 
11 namespace at::native {
12 
roll_common(const Tensor & self,IntArrayRef shifts,IntArrayRef dims)13 static inline Tensor roll_common(const Tensor& self, IntArrayRef shifts, IntArrayRef dims) {
14   TORCH_CHECK(!shifts.empty(), "`shifts` required");
15   if (dims.empty() && shifts.size() == 1) {
16     auto flattened = self.contiguous().view(self.numel());
17     return roll(flattened, shifts[0], 0).view(self.sizes());
18   }
19   TORCH_CHECK(
20     shifts.size() == dims.size(),
21     "shifts and dimensions must align. shifts: ", shifts.size(), ", dims:", dims.size()
22   );
23   AT_ASSERT(dims.size() > 1);
24   auto tail_shifts = shifts.slice(1);
25   auto tail_dims = dims.slice(1);
26   auto first_dim_rolled = roll(self, shifts[0], dims[0]);
27   return at::roll(first_dim_rolled, tail_shifts, tail_dims);
28 }
29 
30 }  // namespace at::native
31