#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorMeta.h>
#include <ATen/quantized/Quantizer.h>
#include <ATen/native/Padding.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/reflection_pad1d_backward_native.h>
#include <ATen/ops/reflection_pad1d_native.h>
#include <ATen/ops/reflection_pad2d_backward_native.h>
#include <ATen/ops/reflection_pad2d_native.h>
#include <ATen/ops/reflection_pad3d_backward_native.h>
#include <ATen/ops/reflection_pad3d_native.h>
#include <ATen/ops/zeros_like.h>
#endif

namespace at::meta {

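// Shape note: with padding = {pad_l, pad_r}, the padded width is
// input_w + pad_l + pad_r, and each pad must be strictly smaller than
// input_w (reflection cannot reuse the border element itself). For example,
// a (2, 3, 10) input padded with {2, 3} produces a (2, 3, 15) output.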
TORCH_META_FUNC(reflection_pad1d)(const Tensor& input, IntArrayRef padding) {
  int64_t dim_plane = 0;
  int64_t dim_w = 1;
  int64_t nbatch = 1;

  if (input.ndimension() == 3) {
    nbatch = input.size(0);
    dim_w++;
    dim_plane++;
  }

  at::native::padding::check_valid_input<1>(input, padding);

  /* sizes */
  auto pad_l = padding[0];
  auto pad_r = padding[1];

  int64_t nplane = input.size(dim_plane);
  int64_t input_w = input.size(dim_w);
  int64_t output_w = input_w + pad_l + pad_r;

  TORCH_CHECK(
      pad_l < input_w && pad_r < input_w,
      "Argument #4: Padding size "
      "should be less than the corresponding input dimension, but got: padding (",
      pad_l,
      ", ",
      pad_r,
      ") at dimension ",
      dim_w,
      " of input ",
      input.sizes());

  TORCH_CHECK(
      output_w >= 1,
      "input (W: ",
      input_w,
      ") is too small. Calculated output W: ",
      output_w);

  if (input.ndimension() == 2) {
    set_output_raw_strided(0, {nplane, output_w}, {}, input.options());
  } else {
    set_output_raw_strided(0, {nbatch, nplane, output_w}, {}, input.options());
  }
}

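// Backward meta: grad_output must already have the padded width computed from
// `input` and `padding`; grad_input is allocated with input's shape.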
TORCH_META_FUNC(reflection_pad1d_backward)(const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding) {
  int64_t dim_w = 1;
  if (input.ndimension() == 3) {
    dim_w++;
  }

  /* sizes */
  auto pad_l = padding[0];
  auto pad_r = padding[1];
  int64_t input_w = input.size(dim_w);
  int64_t output_w = input_w + pad_l + pad_r;

  TORCH_CHECK(
      pad_l < input_w && pad_r < input_w,
      "Argument #4: Padding size "
      "should be less than the corresponding input dimension, but got: padding (",
      pad_l,
      ", ",
      pad_r,
      ") at dimension ",
      dim_w,
      " of input ",
      input.sizes());

  TORCH_CHECK(output_w == grad_output.size(dim_w), "grad_output width unexpected."
      " Expected: ", output_w, ", Got: ", grad_output.size(dim_w));

  set_output_raw_strided(0, input.sizes(), {}, input.options());
}

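// Shape note: padding is {left, right, top, bottom, front, back}; the output
// spans (D + front + back, H + top + bottom, W + left + right) over the last
// three dimensions. For example, a (1, 2, 4, 5, 6) input padded with
// {1, 1, 2, 2, 3, 3} produces a (1, 2, 10, 9, 8) output.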
TORCH_META_FUNC(reflection_pad3d)(const Tensor& input, IntArrayRef padding) {
  int64_t pad_left = padding[0];
  int64_t pad_right = padding[1];
  int64_t pad_top = padding[2];
  int64_t pad_bottom = padding[3];
  int64_t pad_front = padding[4];
  int64_t pad_back = padding[5];
  int64_t dim_w = 3;
  int64_t dim_h = 2;
  int64_t dim_d = 1;
  int64_t dim_plane = 0;

  at::native::padding::check_valid_input<3>(input, padding);

  bool batch_mode = (input.dim() == 5);
  if (batch_mode) {
    dim_w++;
    dim_h++;
    dim_d++;
    dim_plane++;
  }

  int64_t nplane = input.size(dim_plane);
  int64_t input_d = input.size(dim_d);
  int64_t input_h = input.size(dim_h);
  int64_t input_w = input.size(dim_w);
  int64_t output_d = input_d + pad_front + pad_back;
  int64_t output_h = input_h + pad_top + pad_bottom;
  int64_t output_w = input_w + pad_left + pad_right;

  TORCH_CHECK(
      pad_left < input_w && pad_right < input_w,
      "Argument #4: Padding size "
      "should be less than the corresponding input dimension, but got: padding (",
      pad_left, ", ", pad_right, ") at dimension ", dim_w, " of input ", input.sizes());
  TORCH_CHECK(
      pad_top < input_h && pad_bottom < input_h,
      "Argument #6: Padding size "
      "should be less than the corresponding input dimension, but got: padding (",
      pad_top, ", ", pad_bottom, ") at dimension ", dim_h, " of input ", input.sizes());
  TORCH_CHECK(
      pad_front < input_d && pad_back < input_d,
      "Argument #8: Padding size "
      "should be less than the corresponding input dimension, but got: padding (",
      pad_front, ", ", pad_back, ") at dimension ", dim_d, " of input ", input.sizes());

  TORCH_CHECK(output_w >= 1 || output_h >= 1 || output_d >= 1,
      "input (D: ", input_d, " H: ", input_h, ", W: ", input_w,
      ") is too small."
      " Calculated output D: ", output_d, " H: ", output_h, " W: ", output_w);

  if (batch_mode) {
    set_output_raw_strided(0, {input.size(0), nplane, output_d, output_h, output_w}, {}, input.options());
  } else {
    set_output_raw_strided(0, {nplane, output_d, output_h, output_w}, {}, input.options());
  }
}

TORCH_META_FUNC(reflection_pad3d_backward)(
    const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding
) {
  TORCH_CHECK(padding.size() == 6, "padding size is expected to be 6");
  TORCH_CHECK(input.dim() > 3);
  TORCH_CHECK(grad_output.dim() == input.dim());

  int64_t pad_left = padding[0];
  int64_t pad_right = padding[1];
  int64_t pad_top = padding[2];
  int64_t pad_bottom = padding[3];
  int64_t pad_front = padding[4];
  int64_t pad_back = padding[5];
  int64_t dim_w = 3;
  int64_t dim_h = 2;
  int64_t dim_d = 1;

  if (input.dim() == 5) {
    // batch mode
    dim_w++;
    dim_h++;
    dim_d++;
  }

  int64_t input_d = input.size(dim_d);
  int64_t input_h = input.size(dim_h);
  int64_t input_w = input.size(dim_w);
  int64_t output_d = input_d + pad_front + pad_back;
  int64_t output_h = input_h + pad_top + pad_bottom;
  int64_t output_w = input_w + pad_left + pad_right;

  TORCH_CHECK(output_w == grad_output.size(dim_w), "grad_output width unexpected."
      " Expected: ", output_w, ", Got: ", grad_output.size(dim_w));
  TORCH_CHECK(output_h == grad_output.size(dim_h), "grad_output height unexpected."
      " Expected: ", output_h, ", Got: ", grad_output.size(dim_h));
  TORCH_CHECK(output_d == grad_output.size(dim_d), "grad_output depth unexpected."
      " Expected: ", output_d, ", Got: ", grad_output.size(dim_d));

  set_output_raw_strided(0, input.sizes(), {}, input.options());
}
} // namespace at::meta

namespace at::native {

namespace {

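// Helper shared by the float and quantized 2d entry points: validates the
// padding, resizes `output` to (N, C, H + pad_t + pad_b, W + pad_l + pad_r)
// (or the unbatched 3d equivalent), then dispatches to the CPU kernel. For
// example, a (3, 8, 8) input with padding {1, 2, 3, 4} yields a (3, 15, 11)
// output.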
void reflection_pad2d_out_template(
    Tensor& output, const Tensor& input, IntArrayRef padding) {
  int dim_w = 2;
  int dim_h = 1;
  int dim_slices = 0;
  int64_t nbatch = 1;

  at::native::padding::check_valid_input<2>(input, padding);

  int ndim = input.dim();
  if (ndim == 4) {
    nbatch = input.size(0);
    dim_w++;
    dim_h++;
    dim_slices++;
  }

  /* sizes */
  int64_t pad_l = padding[0];
  int64_t pad_r = padding[1];
  int64_t pad_t = padding[2];
  int64_t pad_b = padding[3];

  int64_t nplane = input.size(dim_slices);
  int64_t input_h = input.size(dim_h);
  int64_t input_w = input.size(dim_w);
  int64_t output_h = input_h + pad_t + pad_b;
  int64_t output_w = input_w + pad_l + pad_r;

  TORCH_CHECK(pad_l < input_w && pad_r < input_w,
      "Argument #4: Padding size should be less than the corresponding "
      "input dimension, but got: padding (", pad_l, ", ", pad_r,
      ") at dimension ", dim_w, " of input ", input.sizes());

  TORCH_CHECK(pad_t < input_h && pad_b < input_h,
      "Argument #6: Padding size should be less than the corresponding "
      "input dimension, but got: padding (", pad_t, ", ", pad_b,
      ") at dimension ", dim_h, " of input ", input.sizes());

  TORCH_CHECK(output_w >= 1 || output_h >= 1,
      "input (H: ", input_h, ", W: ", input_w, ") is too small. Calculated "
      "output H: ", output_h, " W: ", output_w);

  /* resize output */
  if (ndim == 3) {
    output.resize_({nplane, output_h, output_w});
  } else {
    if (input.is_quantized()) {
      // a quantized tensor cannot be resized with the `memory_format` argument
      output.resize_({nbatch, nplane, output_h, output_w});
    } else {
      output.resize_({nbatch, nplane, output_h, output_w}, input.suggest_memory_format());
    }
  }
  reflection_pad2d_kernel(kCPU, output, input, padding);
}

void reflection_pad2d_backward_out_template(
    Tensor& grad_input, const Tensor& grad_output,
    const Tensor& input, IntArrayRef padding) {
  int dim_w = 2;
  int dim_h = 1;

  if (input.ndimension() == 4) {
    dim_w++;
    dim_h++;
  }

  /* sizes */
  int64_t pad_l = padding[0];
  int64_t pad_r = padding[1];
  int64_t pad_t = padding[2];
  int64_t pad_b = padding[3];

  int64_t input_h = input.size(dim_h);
  int64_t input_w = input.size(dim_w);
  int64_t output_h = input_h + pad_t + pad_b;
  int64_t output_w = input_w + pad_l + pad_r;

  TORCH_CHECK(output_w == grad_output.size(dim_w),
      "gradOutput width unexpected. Expected: ", output_w, ", Got: ",
      grad_output.size(dim_w));

  TORCH_CHECK(output_h == grad_output.size(dim_h),
      "gradOutput height unexpected. Expected: ", output_h, ", Got: ",
      grad_output.size(dim_h));

  reflection_pad2d_backward_kernel(kCPU, grad_input, grad_output, padding);
}

} // namespace

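// Quantized 1d entry point: only per-tensor affine quantization is supported.
// Reflection padding only copies existing values, so the output simply reuses
// the input's scale and zero point; no requantization is needed.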
Tensor& reflection_pad1d_out_quantized_cpu(const Tensor& input, IntArrayRef padding,
    Tensor& output) {
  TORCH_CHECK(input.qscheme() == kPerTensorAffine, "Only per tensor quantization is supported");
  set_quantizer_(output, make_per_tensor_affine_quantizer(input.q_scale(), input.q_zero_point(), input.scalar_type()));
  reflection_pad1d_kernel(kCPU, output, input, padding);
  return output;
}

TORCH_IMPL_FUNC(reflection_pad1d_out_cpu)
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
  reflection_pad1d_kernel(kCPU, output, input, padding);
}

TORCH_IMPL_FUNC(reflection_pad1d_backward_out_cpu)(const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding,
    const Tensor& grad_input) {
  if (grad_output.numel() == 0) {
    return;
  }

  grad_input.zero_();
  reflection_pad1d_backward_kernel(kCPU, grad_input, grad_output, padding);
}

Tensor& reflection_pad2d_out_cpu(const Tensor& input, IntArrayRef padding,
    Tensor& output) {
  reflection_pad2d_out_template(output, input, padding);
  return output;
}

Tensor reflection_pad2d_cpu(const Tensor& input, IntArrayRef padding) {
  Tensor output = at::empty({0}, input.options());
  reflection_pad2d_out_template(output, input, padding);
  return output;
}

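// As with the quantized 1d path above, the output is allocated as a per-tensor
// affine quantized tensor carrying the input's scale and zero point.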
Tensor reflection_pad2d_quantized_cpu(const Tensor& input, IntArrayRef padding) {
  TORCH_CHECK(input.qscheme() == kPerTensorAffine, "Only per tensor quantization is supported");
  Tensor output = at::_empty_affine_quantized({0}, input.options(),
                                              input.q_scale(),
                                              input.q_zero_point());
  reflection_pad2d_out_template(output, input, padding);
  return output;
}

Tensor& reflection_pad2d_backward_out_cpu(const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding,
    Tensor& grad_input) {
  grad_input.resize_as_(input, input.suggest_memory_format());
  grad_input.zero_();
  reflection_pad2d_backward_out_template(
      grad_input, grad_output, input, padding);
  return grad_input;
}

Tensor reflection_pad2d_backward_cpu(
    const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding) {
  auto grad_input = at::zeros_like(input, input.suggest_memory_format());
  reflection_pad2d_backward_out_template(
      grad_input, grad_output, input, padding);
  return grad_input;
}

TORCH_IMPL_FUNC(reflection_pad3d_out_cpu)
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
  // TODO: move this to TORCH_META_FUNC when CUDA has channels last support
  output.resize_(output.sizes(), input.suggest_memory_format());

  reflection_pad3d_kernel(kCPU, output, input, padding);
}

TORCH_IMPL_FUNC(reflection_pad3d_backward_out_cpu)(const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding,
    const Tensor& grad_input) {
  if (grad_output.numel() == 0) {
    return;
  }

  // TODO: move this to TORCH_META_FUNC when CUDA has channels last support
  grad_input.resize_(input.sizes(), input.suggest_memory_format());

  grad_input.zero_();
  reflection_pad3d_backward_kernel(kCPU, grad_input, grad_output, padding);
}

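// DEFINE_DISPATCH instantiates the dispatch stubs declared in
// ATen/native/Padding.h; the CPU kernels used above are registered against
// these stubs in the corresponding native/cpu padding kernel implementation.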
DEFINE_DISPATCH(reflection_pad1d_kernel);
DEFINE_DISPATCH(reflection_pad1d_backward_kernel);
DEFINE_DISPATCH(reflection_pad2d_kernel);
DEFINE_DISPATCH(reflection_pad2d_backward_kernel);
DEFINE_DISPATCH(reflection_pad3d_kernel);
DEFINE_DISPATCH(reflection_pad3d_backward_kernel);

} // namespace at::native