#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorMeta.h>
#include <ATen/native/Padding.h>
#include <c10/util/irange.h>
#include <algorithm>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/replication_pad1d_backward_native.h>
#include <ATen/ops/replication_pad1d_native.h>
#include <ATen/ops/replication_pad2d_backward_native.h>
#include <ATen/ops/replication_pad2d_native.h>
#include <ATen/ops/replication_pad3d_backward_native.h>
#include <ATen/ops/replication_pad3d_native.h>
#include <ATen/ops/zeros_like.h>
#endif

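// Replication ("edge") padding repeats a tensor's border values outward.
// Illustrative example (not taken from this file): padding the 1-D signal
// [1, 2, 3] with paddingSize = {2, 1} (two on the left, one on the right)
// yields [1, 1, 1, 2, 3, 3].
//
// The meta functions below validate arguments and compute output shapes;
// the element-wise work happens in the dispatched kernels defined at the
// bottom of the file.
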
namespace at::meta {

TORCH_META_FUNC(replication_pad1d) (
  const Tensor& input, IntArrayRef paddingSize  // no out argument!
) {
  TORCH_CHECK(paddingSize.size() == 2, "padding size is expected to be 2");

  int64_t dimw = 1;
  int64_t dimslices = 0;
  int64_t nbatch = 1;

  int64_t pad_l = paddingSize[0];
  int64_t pad_r = paddingSize[1];

  at::native::padding::check_valid_input<1>(input, paddingSize);

  if (input.ndimension() == 3) {
    nbatch = input.size(0);
    dimw++;
    dimslices++;
  }

  /* sizes */
  int64_t nslices = input.size(dimslices);
  int64_t iwidth = input.size(dimw);
  int64_t owidth = iwidth + pad_l + pad_r;

  TORCH_CHECK(owidth >= 1,
      "input (W: ", iwidth, ") is too small."
      " Calculated output W: ", owidth);

  if (input.ndimension() == 2) {
    set_output_raw_strided(0, {nslices, owidth}, {}, input.options());
  } else {
    set_output_raw_strided(0, {nbatch, nslices, owidth}, {}, input.options());
  }
}
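
// Shape sketch (illustrative): a batched input of shape (N, C, W) with
// paddingSize = {pad_l, pad_r} produces an output of shape
// (N, C, W + pad_l + pad_r); a 2-D input (C, W) stays unbatched.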

TORCH_META_FUNC(replication_pad1d_backward) (
  const Tensor& gradOutput,
  const Tensor& input,
  IntArrayRef paddingSize
) {
  int64_t dimw = 1;
  TORCH_CHECK(paddingSize.size() == 2, "padding size is expected to be 2");
  int64_t pad_l = paddingSize[0];
  int64_t pad_r = paddingSize[1];

  if (input.ndimension() == 3) {
    dimw++;
  }

  /* sizes */
  int64_t iwidth = input.size(dimw);
  int64_t owidth  = iwidth + pad_l + pad_r;

  TORCH_CHECK(owidth == gradOutput.size(dimw),
      "gradOutput width unexpected. Expected: ", owidth,
      " Got: ", gradOutput.size(dimw));

  set_output_raw_strided(0, input.sizes(), {}, input.options());
}
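
// The gradient w.r.t. the input always has the input's own shape, while
// gradOutput must match the padded (forward output) width checked above.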

TORCH_META_FUNC(replication_pad2d) (
  const Tensor& input, IntArrayRef paddingSize
) {
  TORCH_CHECK(paddingSize.size() == 4, "padding size is expected to be 4");
  int64_t pad_l = paddingSize[0];
  int64_t pad_r = paddingSize[1];
  int64_t pad_t = paddingSize[2];
  int64_t pad_b = paddingSize[3];
  int64_t dimw = 2;
  int64_t dimh = 1;
  int64_t dimslices = 0;
  int64_t nbatch = 1;

  at::native::padding::check_valid_input<2>(input, paddingSize);

  if (input.dim() == 4) {
    nbatch = input.size(0);
    dimw++;
    dimh++;
    dimslices++;
  }

  /* sizes */
  int64_t nslices = input.size(dimslices);
  int64_t iheight = input.size(dimh);
  int64_t iwidth = input.size(dimw);
  int64_t oheight = iheight + pad_t + pad_b;
  int64_t owidth  = iwidth + pad_l + pad_r;

  TORCH_CHECK(owidth >= 1 || oheight >= 1,
      "input (H: ", iheight, ", W: ", iwidth, " ) is too small."
      " Calculated output H: ", oheight, " W: ", owidth);

  if (input.dim() == 3) {
    set_output_raw_strided(0, {nslices, oheight, owidth}, {}, input.options());
  } else {
    set_output_raw_strided(0, {nbatch, nslices, oheight, owidth}, {}, input.options());
  }
}
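
// The 2-D padding order is {left, right, top, bottom}. Illustrative shape:
// an input (N, C, H, W) with paddingSize = {1, 1, 2, 2} yields an output
// of shape (N, C, H + 4, W + 2).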

TORCH_META_FUNC(replication_pad3d) (
  const Tensor& input, IntArrayRef paddingSize
) {
  TORCH_CHECK(paddingSize.size() == 6, "padding size is expected to be 6");
  int64_t pleft = paddingSize[0];
  int64_t pright = paddingSize[1];
  int64_t ptop = paddingSize[2];
  int64_t pbottom = paddingSize[3];
  int64_t pfront = paddingSize[4];
  int64_t pback = paddingSize[5];
  int64_t dimw = 3;
  int64_t dimh = 2;
  int64_t dimd = 1;
  int64_t dimslices = 0;
  int64_t nbatch = 1;

  at::native::padding::check_valid_input<3>(input, paddingSize);

  if (input.dim() == 5) {
    nbatch = input.size(0);
    dimw++;
    dimh++;
    dimd++;
    dimslices++;
  }

  /* sizes */
  int64_t nslices = input.size(dimslices);
  int64_t idepth = input.size(dimd);
  int64_t iheight = input.size(dimh);
  int64_t iwidth = input.size(dimw);
  int64_t odepth = idepth + pfront + pback;
  int64_t oheight = iheight + ptop + pbottom;
  int64_t owidth  = iwidth + pleft + pright;

  TORCH_CHECK(owidth >= 1 || oheight >= 1 || odepth >= 1,
      "input (D: ", idepth, " H: ", iheight, ", W: ", iwidth,
      ") is too small."
      " Calculated output D: ", odepth, " H: ", oheight, " W: ", owidth);

  /* resize output */
  if (input.dim() == 4) {
    set_output_raw_strided(0, {nslices, odepth, oheight, owidth}, {}, input.options());
  } else {
    set_output_raw_strided(0, {nbatch, nslices, odepth, oheight, owidth}, {}, input.options());
  }
}
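
// The 3-D padding order is {left, right, top, bottom, front, back}: the
// last (width) dimension is listed first, mirroring the 1-D and 2-D variants.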

} // namespace at::meta

namespace at::native {

namespace {

void replication_pad2d_backward_out_cpu_template(
    Tensor& gradInput,
    const Tensor& gradOutput,
    const Tensor& input,
    IntArrayRef paddingSize)
{
  TORCH_CHECK(paddingSize.size() == 4, "padding size is expected to be 4");
  int pad_l = paddingSize[0];
  int pad_r = paddingSize[1];
  int pad_t = paddingSize[2];
  int pad_b = paddingSize[3];
  int dimw = 2;
  int dimh = 1;

  if (input.dim() == 4) {
    dimw++;
    dimh++;
  }

  /* sizes */
  int64_t iheight = input.size(dimh);
  int64_t iwidth = input.size(dimw);
  int64_t oheight = iheight + pad_t + pad_b;
  int64_t owidth  = iwidth + pad_l + pad_r;

  TORCH_CHECK(owidth == gradOutput.size(dimw),
      "gradOutput width unexpected. Expected: ", owidth, ", Got: ",
      gradOutput.size(dimw));
  TORCH_CHECK(oheight == gradOutput.size(dimh),
      "gradOutput height unexpected. Expected: ", oheight, ", Got: ",
      gradOutput.size(dimh));

  if (gradInput.numel() == 0) {
    return;
  }

  replication_pad2d_backward_kernel(kCPU, gradInput, gradOutput, paddingSize);
}

void replication_pad3d_backward_out_cpu_template(
    Tensor& gradInput,
    const Tensor& gradOutput,
    const Tensor& input,
    IntArrayRef paddingSize)
{
  TORCH_CHECK(paddingSize.size() == 6, "padding size is expected to be 6");
  int pleft = paddingSize[0];
  int pright = paddingSize[1];
  int ptop = paddingSize[2];
  int pbottom = paddingSize[3];
  int pfront = paddingSize[4];
  int pback = paddingSize[5];
  int dimw = 3;
  int dimh = 2;
  int dimd = 1;

  if (input.dim() == 5) {
    dimw++;
    dimh++;
    dimd++;
  }

  /* sizes */
  int64_t idepth = input.size(dimd);
  int64_t iheight = input.size(dimh);
  int64_t iwidth = input.size(dimw);
  int64_t odepth = idepth + pfront + pback;
  int64_t oheight = iheight + ptop + pbottom;
  int64_t owidth  = iwidth + pleft + pright;

  at::native::padding::check_valid_input<3>(input, paddingSize);

  TORCH_CHECK(owidth == gradOutput.size(dimw),
      "gradOutput width unexpected. Expected: ", owidth, ", Got: ",
      gradOutput.size(dimw));
  TORCH_CHECK(oheight == gradOutput.size(dimh),
      "gradOutput height unexpected. Expected: ", oheight, ", Got: ",
      gradOutput.size(dimh));
  TORCH_CHECK(odepth == gradOutput.size(dimd),
      "gradOutput depth unexpected. Expected: ", odepth, ", Got: ",
      gradOutput.size(dimd));

  if (gradInput.numel() == 0) {
    return;
  }

  replication_pad3d_backward_kernel(kCPU, gradInput, gradOutput, paddingSize);
}

} // anonymous namespace
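
// The TORCH_IMPL_FUNC bodies below are structured-kernel implementations:
// by the time they run, the corresponding TORCH_META_FUNC above has already
// validated the arguments and allocated the output tensor with the correct
// shape, so these bodies only need to fill in values.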

TORCH_IMPL_FUNC(replication_pad1d_out_cpu) (
  const Tensor& input, IntArrayRef paddingSize, const Tensor& output
) {
  replication_pad1d_kernel(kCPU, output, input, paddingSize);
}

TORCH_IMPL_FUNC(replication_pad1d_backward_out_cpu) (
  const Tensor& gradOutput, const Tensor& input, IntArrayRef paddingSize, const Tensor& gradInput
) {
  if (gradInput.numel() == 0) {
    return;
  }
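  // The backward kernel accumulates: every replicated output position adds
  // its gradient into the same border input element, so clear gradInput first.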
  gradInput.zero_();

  replication_pad1d_backward_kernel(kCPU, gradInput, gradOutput, paddingSize);
}

TORCH_IMPL_FUNC(replication_pad2d_out_cpu) (
  const Tensor& input, IntArrayRef paddingSize, const Tensor& output
) {
  // TODO: move this to TORCH_META_FUNC when CUDA has channels last support
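  // The sizes are unchanged from what the meta function set, so this resize
  // only restrides `output` to the input's suggested memory format
  // (e.g. channels-last).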
  output.resize_(output.sizes(), input.suggest_memory_format());

  replication_pad2d_kernel(kCPU, output, input, paddingSize);
}

Tensor& replication_pad2d_backward_out_cpu(const Tensor& gradOutput,
    const Tensor& input,
    IntArrayRef paddingSize,
    Tensor& gradInput)
{
  gradInput.resize_as_(input, input.suggest_memory_format());
  gradInput.zero_();
  replication_pad2d_backward_out_cpu_template(
      gradInput, gradOutput, input, paddingSize);
  return gradInput;
}

Tensor replication_pad2d_backward_cpu(
    const Tensor& gradOutput,
    const Tensor& input,
    IntArrayRef paddingSize)
{
  auto gradInput = at::zeros_like(input, input.suggest_memory_format());
  replication_pad2d_backward_out_cpu_template(
      gradInput, gradOutput, input, paddingSize);
  return gradInput;
}

TORCH_IMPL_FUNC(replication_pad3d_out_cpu) (
  const Tensor& input, IntArrayRef paddingSize, const Tensor& output
) {
  // TODO: move this to TORCH_META_FUNC when CUDA has channels last support
  output.resize_(output.sizes(), input.suggest_memory_format());

  replication_pad3d_kernel(kCPU, output, input, paddingSize);
}

Tensor& replication_pad3d_backward_out_cpu(const Tensor& gradOutput,
    const Tensor& input,
    IntArrayRef paddingSize,
    Tensor& gradInput)
{
  gradInput.resize_as_(input, input.suggest_memory_format());
  gradInput.zero_();
  replication_pad3d_backward_out_cpu_template(
      gradInput, gradOutput, input, paddingSize);
  return gradInput;
}

Tensor replication_pad3d_backward_cpu(
    const Tensor& gradOutput,
    const Tensor& input,
    IntArrayRef paddingSize)
{
  auto gradInput = at::zeros_like(input, input.suggest_memory_format());
  replication_pad3d_backward_out_cpu_template(
      gradInput, gradOutput, input, paddingSize);
  return gradInput;
}

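// Dispatch stub definitions. These only define the stub symbols; the actual
// CPU implementations are registered against them elsewhere via
// REGISTER_DISPATCH (in ATen's native/cpu kernel sources).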
DEFINE_DISPATCH(replication_pad1d_kernel);
DEFINE_DISPATCH(replication_pad1d_backward_kernel);
DEFINE_DISPATCH(replication_pad2d_kernel);
DEFINE_DISPATCH(replication_pad2d_backward_kernel);
DEFINE_DISPATCH(replication_pad3d_kernel);
DEFINE_DISPATCH(replication_pad3d_backward_kernel);

} // namespace at::native