xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/PixelShuffle.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/TensorTransformations.h>
3 #include <ATen/native/cpu/PixelShuffleKernel.h>
4 #include <ATen/native/PixelShuffle.h>
5 
6 #include <c10/util/Exception.h>
7 
8 #ifndef AT_PER_OPERATOR_HEADERS
9 #include <ATen/Functions.h>
10 #include <ATen/NativeFunctions.h>
11 #else
12 #include <ATen/ops/empty.h>
13 #include <ATen/ops/pixel_shuffle_native.h>
14 #include <ATen/ops/pixel_unshuffle_native.h>
15 #endif
16 
17 #include <algorithm>
18 #include <numeric>
19 #include <vector>
20 
21 namespace at::native {
22 
// CPU entry point for pixel_shuffle. Allocates the result tensor and hands
// the actual data movement to the dispatched pixel_shuffle_kernel.
//
// self:           input laid out as (B1, ..., Bn), C, H, W.
// upscale_factor: factor r by which H and W grow; argument/shape checking is
//                 delegated to check_pixel_shuffle_shapes.
//
// Returns a freshly allocated tensor of shape
// (B1, ..., Bn), C / r / r, H * r, W * r.
Tensor pixel_shuffle_cpu(const Tensor& self, int64_t upscale_factor) {
  check_pixel_shuffle_shapes(self, upscale_factor);

  // Shuffled channel / height / width extents.
  const int64_t out_channels = self.size(-3) / upscale_factor / upscale_factor;
  const int64_t out_height = self.size(-2) * upscale_factor;
  const int64_t out_width = self.size(-1) * upscale_factor;

  // Leading (batch) dims are carried over untouched; only the trailing three
  // dims are replaced.
  const auto in_sizes = self.sizes();
  std::vector<int64_t> result_sizes(in_sizes.begin(), in_sizes.end() - 3);
  result_sizes.push_back(out_channels);
  result_sizes.push_back(out_height);
  result_sizes.push_back(out_width);

  // Allocate empty, then resize so the result follows the memory format
  // suggested by the input.
  const auto layout = self.suggest_memory_format();
  auto result = at::empty({0}, self.options());
  result.resize_(result_sizes, layout);

  // Only invoke the kernel when there is something to copy.
  if (result.numel() != 0) {
    auto dense_input = self.contiguous(layout);
    pixel_shuffle_kernel(kCPU, result, dense_input, upscale_factor);
  }
  return result;
}
46 
// CPU entry point for pixel_unshuffle. Allocates the result tensor and hands
// the actual data movement to the dispatched pixel_unshuffle_kernel.
//
// self:             input laid out as (B1, ..., Bn), C, H, W.
// downscale_factor: factor r by which H and W shrink; argument/shape checking
//                   is delegated to check_pixel_unshuffle_shapes.
//
// Returns a freshly allocated tensor of shape
// (B1, ..., Bn), C * r * r, H / r, W / r.
//
// Empty inputs take the same path as non-empty ones up to the kernel call, so
// the result always carries the unshuffled shape (matching pixel_shuffle_cpu
// and math_pixel_unshuffle) instead of a copy of the input shape.
Tensor pixel_unshuffle_cpu(const Tensor& self, int64_t downscale_factor) {
  check_pixel_unshuffle_shapes(self, downscale_factor);

  // Format: (B1, ..., Bn), C, H, W
  std::vector<int64_t> output_sizes(self.sizes().begin(), self.sizes().end() - 3);
  output_sizes.insert(output_sizes.end(),
      {self.size(-3) * downscale_factor * downscale_factor,
       self.size(-2) / downscale_factor,
       self.size(-1) / downscale_factor});

  // Allocate empty, then resize so the result follows the memory format
  // suggested by the input.
  auto output = at::empty({0}, self.options());
  auto memory_format = self.suggest_memory_format();
  output.resize_(output_sizes, memory_format);

  // Any zero-sized input dim yields a zero-sized output dim (batch dims are
  // copied, 0 * r * r == 0 and 0 / r == 0), so this single guard also covers
  // empty inputs and keeps the kernel from being called with nothing to do.
  if (output.numel() == 0) {
    return output;
  }

  auto input = self.contiguous(memory_format);

  pixel_unshuffle_kernel(kCPU, output, input, downscale_factor);
  return output;
}
74 
// pixel_shuffle expressed purely through reshape / permute / clone / view.
//
// self:           input laid out as (B1, ..., Bn), C, H, W.
// upscale_factor: factor r by which H and W grow; argument/shape checking is
//                 delegated to check_pixel_shuffle_shapes.
//
// Returns a tensor of shape (B1, ..., Bn), C / r^2, H * r, W * r that does
// not alias the input.
Tensor math_pixel_shuffle(const Tensor& self, int64_t upscale_factor) {
  check_pixel_shuffle_shapes(self, upscale_factor);

  // Everything in front of the trailing (C, H, W) triple is a batch dim.
  constexpr int64_t kNonBatchDims = 3;
  const auto in_sizes = self.sizes();
  const std::vector<int64_t> batch_sizes(
      in_sizes.begin(), in_sizes.end() - kNonBatchDims);

  const int64_t in_channels = self.size(-3);
  const int64_t in_height = self.size(-2);
  const int64_t in_width = self.size(-1);

  const int64_t factor_sq = upscale_factor * upscale_factor;
  const int64_t out_channels = in_channels / factor_sq;
  const int64_t out_height = in_height * upscale_factor;
  const int64_t out_width = in_width * upscale_factor;

  // Step 1: split the channel dim C into (out_channels, r, r) so that the
  // shuffle itself becomes a plain permutation of dims.
  std::vector<int64_t> split_shape(batch_sizes);
  split_shape.insert(
      split_shape.end(),
      {out_channels, upscale_factor, upscale_factor, in_height, in_width});
  const auto split = self.reshape(split_shape);

  // Step 2: interleave each r dim with its spatial dim. Batch dims keep
  // their positions (0, 1, ...); the five trailing dims are addressed with
  // negative indices: (oc, r1, r2, h, w) -> (oc, h, r1, w, r2).
  std::vector<int64_t> dim_order(batch_sizes.size());
  std::iota(dim_order.begin(), dim_order.end(), 0);
  dim_order.insert(
      dim_order.end(),
      {-5 /* oc */,
       -2 /* h */,
       -4 /* 1st upscale_factor */,
       -1 /* w */,
       -3 /* 2nd upscale_factor */});
  const auto shuffled = split.permute(dim_order);

  // Step 3: merge (h, r) into out_height and (w, r) into out_width.
  std::vector<int64_t> merged_shape(batch_sizes);
  merged_shape.insert(
      merged_shape.end(), {out_channels, out_height, out_width});

  // The clone guarantees the result is never an alias of the input.
  const auto materialized = shuffled.clone(at::MemoryFormat::Contiguous);
  return materialized.view(merged_shape);
}
115 
// pixel_unshuffle expressed purely through reshape / permute / clone / view.
//
// self:             input laid out as (B1, ..., Bn), C, H, W.
// downscale_factor: factor r by which H and W shrink; argument/shape checking
//                   is delegated to check_pixel_unshuffle_shapes.
//
// Returns a tensor of shape (B1, ..., Bn), C * r^2, H / r, W / r that does
// not alias the input.
Tensor math_pixel_unshuffle(const Tensor& self, int64_t downscale_factor) {
  check_pixel_unshuffle_shapes(self, downscale_factor);

  // Everything in front of the trailing (C, H, W) triple is a batch dim.
  constexpr int64_t kNonBatchDims = 3;
  const auto in_sizes = self.sizes();
  const std::vector<int64_t> batch_sizes(
      in_sizes.begin(), in_sizes.end() - kNonBatchDims);

  const int64_t in_channels = self.size(-3);
  const int64_t in_height = self.size(-2);
  const int64_t in_width = self.size(-1);

  const int64_t factor_sq = downscale_factor * downscale_factor;
  const int64_t out_channels = in_channels * factor_sq;
  const int64_t out_height = in_height / downscale_factor;
  const int64_t out_width = in_width / downscale_factor;

  // Step 1: split H into (out_height, r) and W into (out_width, r) so that
  // the unshuffle itself becomes a plain permutation of dims.
  std::vector<int64_t> split_shape(batch_sizes);
  split_shape.insert(
      split_shape.end(),
      {in_channels, out_height, downscale_factor, out_width, downscale_factor});
  const auto split = self.reshape(split_shape);

  // Step 2: move both r dims next to the channel dim. Batch dims keep their
  // positions (0, 1, ...); the five trailing dims are addressed with
  // negative indices: (c, oh, r1, ow, r2) -> (c, r1, r2, oh, ow).
  std::vector<int64_t> dim_order(batch_sizes.size());
  std::iota(dim_order.begin(), dim_order.end(), 0);
  dim_order.insert(
      dim_order.end(),
      {-5 /* c */,
       -3 /* 1st downscale_factor */,
       -1 /* 2nd downscale_factor */,
       -4 /* oh */,
       -2 /* ow */});
  const auto unshuffled = split.permute(dim_order);

  // Step 3: merge (c, r, r) into out_channels; height and width are already
  // out_height and out_width.
  std::vector<int64_t> merged_shape(batch_sizes);
  merged_shape.insert(
      merged_shape.end(), {out_channels, out_height, out_width});

  // The clone guarantees the result is never an alias of the input.
  const auto materialized = unshuffled.clone(at::MemoryFormat::Contiguous);
  return materialized.view(merged_shape);
}
156 
// Dispatch stubs invoked (with kCPU) by pixel_shuffle_cpu and
// pixel_unshuffle_cpu in this file. NOTE(review): DEFINE_DISPATCH presumably
// provides the stub definitions, with the concrete kernels registered in a
// separate translation unit -- confirm against ATen's DispatchStub header.
157 DEFINE_DISPATCH(pixel_shuffle_kernel);
158 DEFINE_DISPATCH(pixel_unshuffle_kernel);
159 
160 } // namespace at::native
161