#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorMeta.h>
#include <ATen/native/Padding.h>
#include <c10/util/irange.h>
#include <algorithm>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/replication_pad1d_backward_native.h>
#include <ATen/ops/replication_pad1d_native.h>
#include <ATen/ops/replication_pad2d_backward_native.h>
#include <ATen/ops/replication_pad2d_native.h>
#include <ATen/ops/replication_pad3d_backward_native.h>
#include <ATen/ops/replication_pad3d_native.h>
#include <ATen/ops/zeros_like.h>
#endif

namespace at::meta {

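// Shape inference for replication_pad1d: validates the {pad_left, pad_right}
// padding and computes the output width as iwidth + pad_l + pad_r.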
TORCH_META_FUNC(replication_pad1d) (
  const Tensor& input, IntArrayRef paddingSize  // no out argument!
) {
  TORCH_CHECK(paddingSize.size() == 2, "padding size is expected to be 2");

  int64_t dimw = 1;
  int64_t dimslices = 0;
  int64_t nbatch = 1;

  int64_t pad_l = paddingSize[0];
  int64_t pad_r = paddingSize[1];

  at::native::padding::check_valid_input<1>(input, paddingSize);

  if (input.ndimension() == 3) {
    nbatch = input.size(0);
    dimw++;
    dimslices++;
  }

  /* sizes */
  int64_t nslices = input.size(dimslices);
  int64_t iwidth = input.size(dimw);
  int64_t owidth = iwidth + pad_l + pad_r;

  TORCH_CHECK(owidth >= 1,
      "input (W: ", iwidth, ") is too small."
      " Calculated output W: ", owidth);

  if (input.ndimension() == 2) {
    set_output_raw_strided(0, {nslices, owidth}, {}, input.options());
  } else {
    set_output_raw_strided(0, {nbatch, nslices, owidth}, {}, input.options());
  }
}

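// Shape inference for replication_pad1d_backward: checks that gradOutput has
// the padded width and sets gradInput to the same shape as input.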
TORCH_META_FUNC(replication_pad1d_backward) (
  const Tensor& gradOutput,
  const Tensor& input,
  IntArrayRef paddingSize
) {
  int64_t dimw = 1;
  TORCH_CHECK(paddingSize.size() == 2, "padding size is expected to be 2");
  int64_t pad_l = paddingSize[0];
  int64_t pad_r = paddingSize[1];

  if (input.ndimension() == 3) {
    dimw++;
  }

  /* sizes */
  int64_t iwidth = input.size(dimw);
  int64_t owidth = iwidth + pad_l + pad_r;

  TORCH_CHECK(owidth == gradOutput.size(dimw),
      "gradOutput width unexpected. Expected: ", owidth,
      " Got: ", gradOutput.size(dimw));

  set_output_raw_strided(0, input.sizes(), {}, input.options());
}

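// Shape inference for replication_pad2d: {pad_l, pad_r, pad_t, pad_b} padding
// of a 3D (C, H, W) or batched 4D (N, C, H, W) input.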
TORCH_META_FUNC(replication_pad2d) (
  const Tensor& input, IntArrayRef paddingSize
) {
  TORCH_CHECK(paddingSize.size() == 4, "padding size is expected to be 4");
  int64_t pad_l = paddingSize[0];
  int64_t pad_r = paddingSize[1];
  int64_t pad_t = paddingSize[2];
  int64_t pad_b = paddingSize[3];
  int64_t dimw = 2;
  int64_t dimh = 1;
  int64_t dimslices = 0;
  int64_t nbatch = 1;

  at::native::padding::check_valid_input<2>(input, paddingSize);

  if (input.dim() == 4) {
    nbatch = input.size(0);
    dimw++;
    dimh++;
    dimslices++;
  }

  /* sizes */
  int64_t nslices = input.size(dimslices);
  int64_t iheight = input.size(dimh);
  int64_t iwidth = input.size(dimw);
  int64_t oheight = iheight + pad_t + pad_b;
  int64_t owidth = iwidth + pad_l + pad_r;

  TORCH_CHECK(owidth >= 1 || oheight >= 1,
      "input (H: ", iheight, ", W: ", iwidth, " ) is too small."
      " Calculated output H: ", oheight, " W: ", owidth);

  if (input.dim() == 3) {
    set_output_raw_strided(0, {nslices, oheight, owidth}, {}, input.options());
  } else {
    set_output_raw_strided(0, {nbatch, nslices, oheight, owidth}, {}, input.options());
  }
}

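// Shape inference for replication_pad3d: {left, right, top, bottom, front, back}
// padding of a 4D (C, D, H, W) or batched 5D (N, C, D, H, W) input.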
TORCH_META_FUNC(replication_pad3d) (
  const Tensor& input, IntArrayRef paddingSize
) {
  TORCH_CHECK(paddingSize.size() == 6, "padding size is expected to be 6");
  int64_t pleft = paddingSize[0];
  int64_t pright = paddingSize[1];
  int64_t ptop = paddingSize[2];
  int64_t pbottom = paddingSize[3];
  int64_t pfront = paddingSize[4];
  int64_t pback = paddingSize[5];
  int64_t dimw = 3;
  int64_t dimh = 2;
  int64_t dimd = 1;
  int64_t dimslices = 0;
  int64_t nbatch = 1;

  at::native::padding::check_valid_input<3>(input, paddingSize);

  if (input.dim() == 5) {
    nbatch = input.size(0);
    dimw++;
    dimh++;
    dimd++;
    dimslices++;
  }

  /* sizes */
  int64_t nslices = input.size(dimslices);
  int64_t idepth = input.size(dimd);
  int64_t iheight = input.size(dimh);
  int64_t iwidth = input.size(dimw);
  int64_t odepth = idepth + pfront + pback;
  int64_t oheight = iheight + ptop + pbottom;
  int64_t owidth = iwidth + pleft + pright;

  TORCH_CHECK(owidth >= 1 || oheight >= 1 || odepth >= 1,
      "input (D: ", idepth, " H: ", iheight, ", W: ", iwidth,
      ") is too small."
      " Calculated output D: ", odepth, " H: ", oheight, " W: ", owidth);

  /* resize output */
  if (input.dim() == 4) {
    set_output_raw_strided(0, {nslices, odepth, oheight, owidth}, {}, input.options());
  } else {
    set_output_raw_strided(0, {nbatch, nslices, odepth, oheight, owidth}, {}, input.options());
  }
}

} // namespace at::meta

namespace at::native {

namespace {

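// Validates gradOutput against the padded 2D input sizes and dispatches to the
// CPU backward kernel. gradInput is expected to be pre-sized and zeroed by the
// caller.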
void replication_pad2d_backward_out_cpu_template(
    Tensor& gradInput,
    const Tensor& gradOutput,
    const Tensor& input,
    IntArrayRef paddingSize)
{
  TORCH_CHECK(paddingSize.size() == 4, "padding size is expected to be 4");
  int pad_l = paddingSize[0];
  int pad_r = paddingSize[1];
  int pad_t = paddingSize[2];
  int pad_b = paddingSize[3];
  int dimw = 2;
  int dimh = 1;

  if (input.dim() == 4) {
    dimw++;
    dimh++;
  }

  /* sizes */
  int64_t iheight = input.size(dimh);
  int64_t iwidth = input.size(dimw);
  int64_t oheight = iheight + pad_t + pad_b;
  int64_t owidth = iwidth + pad_l + pad_r;

  TORCH_CHECK(owidth == gradOutput.size(dimw),
      "gradOutput width unexpected. Expected: ", owidth, ", Got: ",
      gradOutput.size(dimw));
  TORCH_CHECK(oheight == gradOutput.size(dimh),
      "gradOutput height unexpected. Expected: ", oheight, ", Got: ",
      gradOutput.size(dimh));

  if (gradInput.numel() == 0) {
    return;
  }

  replication_pad2d_backward_kernel(kCPU, gradInput, gradOutput, paddingSize);
}

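// Same as above for the 3D case: checks depth, height, and width of gradOutput
// before dispatching to the CPU backward kernel.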
void replication_pad3d_backward_out_cpu_template(
    Tensor& gradInput,
    const Tensor& gradOutput,
    const Tensor& input,
    IntArrayRef paddingSize)
{
  TORCH_CHECK(paddingSize.size() == 6, "padding size is expected to be 6");
  int pleft = paddingSize[0];
  int pright = paddingSize[1];
  int ptop = paddingSize[2];
  int pbottom = paddingSize[3];
  int pfront = paddingSize[4];
  int pback = paddingSize[5];
  int dimw = 3;
  int dimh = 2;
  int dimd = 1;

  if (input.dim() == 5) {
    dimw++;
    dimh++;
    dimd++;
  }

  /* sizes */
  int64_t idepth = input.size(dimd);
  int64_t iheight = input.size(dimh);
  int64_t iwidth = input.size(dimw);
  int64_t odepth = idepth + pfront + pback;
  int64_t oheight = iheight + ptop + pbottom;
  int64_t owidth = iwidth + pleft + pright;

  at::native::padding::check_valid_input<3>(input, paddingSize);

  TORCH_CHECK(owidth == gradOutput.size(dimw),
      "gradOutput width unexpected. Expected: ", owidth, ", Got: ",
      gradOutput.size(dimw));
  TORCH_CHECK(oheight == gradOutput.size(dimh),
      "gradOutput height unexpected. Expected: ", oheight, ", Got: ",
      gradOutput.size(dimh));
  TORCH_CHECK(odepth == gradOutput.size(dimd),
      "gradOutput depth unexpected. Expected: ", odepth, ", Got: ",
      gradOutput.size(dimd));

  if (gradInput.numel() == 0) {
    return;
  }

  replication_pad3d_backward_kernel(kCPU, gradInput, gradOutput, paddingSize);
}

} // anonymous namespace

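// Structured-kernel implementations: output/gradInput shapes were already set
// by the meta functions above, so these mostly forward to the CPU dispatch stubs.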
TORCH_IMPL_FUNC(replication_pad1d_out_cpu) (
  const Tensor& input, IntArrayRef paddingSize, const Tensor& output
) {
  replication_pad1d_kernel(kCPU, output, input, paddingSize);
}

TORCH_IMPL_FUNC(replication_pad1d_backward_out_cpu) (
  const Tensor& gradOutput, const Tensor& input, IntArrayRef paddingSize, const Tensor& gradInput
) {
  if (gradInput.numel() == 0) {
    return;
  }
  gradInput.zero_();

  replication_pad1d_backward_kernel(kCPU, gradInput, gradOutput, paddingSize);
}

TORCH_IMPL_FUNC(replication_pad2d_out_cpu) (
  const Tensor& input, IntArrayRef paddingSize, const Tensor& output
) {
  // TODO: move this to TORCH_META_FUNC when CUDA has channels last support
  output.resize_(output.sizes(), input.suggest_memory_format());

  replication_pad2d_kernel(kCPU, output, input, paddingSize);
}

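// Non-structured 2D/3D backward entry points: resize and zero gradInput, then
// delegate validation and computation to the templates above.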
Tensor& replication_pad2d_backward_out_cpu(const Tensor& gradOutput,
    const Tensor& input,
    IntArrayRef paddingSize,
    Tensor& gradInput)
{
  gradInput.resize_as_(input, input.suggest_memory_format());
  gradInput.zero_();
  replication_pad2d_backward_out_cpu_template(
      gradInput, gradOutput, input, paddingSize);
  return gradInput;
}

Tensor replication_pad2d_backward_cpu(
    const Tensor& gradOutput,
    const Tensor& input,
    IntArrayRef paddingSize)
{
  auto gradInput = at::zeros_like(input, input.suggest_memory_format());
  replication_pad2d_backward_out_cpu_template(
      gradInput, gradOutput, input, paddingSize);
  return gradInput;
}

TORCH_IMPL_FUNC(replication_pad3d_out_cpu) (
  const Tensor& input, IntArrayRef paddingSize, const Tensor& output
) {
  // TODO: move this to TORCH_META_FUNC when CUDA has channels last support
  output.resize_(output.sizes(), input.suggest_memory_format());

  replication_pad3d_kernel(kCPU, output, input, paddingSize);
}

Tensor& replication_pad3d_backward_out_cpu(const Tensor& gradOutput,
    const Tensor& input,
    IntArrayRef paddingSize,
    Tensor& gradInput)
{
  gradInput.resize_as_(input, input.suggest_memory_format());
  gradInput.zero_();
  replication_pad3d_backward_out_cpu_template(
      gradInput, gradOutput, input, paddingSize);
  return gradInput;
}

Tensor replication_pad3d_backward_cpu(
    const Tensor& gradOutput,
    const Tensor& input,
    IntArrayRef paddingSize)
{
  auto gradInput = at::zeros_like(input, input.suggest_memory_format());
  replication_pad3d_backward_out_cpu_template(
      gradInput, gradOutput, input, paddingSize);
  return gradInput;
}

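// Dispatch stub definitions for the replication padding kernels; the CPU
// kernels themselves are registered separately via REGISTER_DISPATCH.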
DEFINE_DISPATCH(replication_pad1d_kernel);
DEFINE_DISPATCH(replication_pad1d_backward_kernel);
DEFINE_DISPATCH(replication_pad2d_kernel);
DEFINE_DISPATCH(replication_pad2d_backward_kernel);
DEFINE_DISPATCH(replication_pad3d_kernel);
DEFINE_DISPATCH(replication_pad3d_backward_kernel);

} // namespace at::native