// aten/src/ATen/native/ReflectionPad.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorMeta.h>
#include <ATen/quantized/Quantizer.h>
#include <ATen/native/Padding.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/reflection_pad1d_backward_native.h>
#include <ATen/ops/reflection_pad1d_native.h>
#include <ATen/ops/reflection_pad2d_backward_native.h>
#include <ATen/ops/reflection_pad2d_native.h>
#include <ATen/ops/reflection_pad3d_backward_native.h>
#include <ATen/ops/reflection_pad3d_native.h>
#include <ATen/ops/zeros_like.h>
#endif

namespace at::meta {

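// Shape inference for reflection_pad1d. The input is either 2-D (C, W) or
// batched 3-D (N, C, W); padding is {pad_l, pad_r}. The output keeps every
// leading dimension and widens the last one to
//   output_w = input_w + pad_l + pad_r,
// where each pad must be strictly smaller than input_w, since a reflection
// can only mirror existing elements. For example, an input of shape (3, 5)
// with padding {2, 2} produces an output of shape (3, 9).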
TORCH_META_FUNC(reflection_pad1d)(const Tensor& input, IntArrayRef padding) {
  int64_t dim_plane = 0;
  int64_t dim_w = 1;
  int64_t nbatch = 1;

  if (input.ndimension() == 3) {
    nbatch = input.size(0);
    dim_w++;
    dim_plane++;
  }

  at::native::padding::check_valid_input<1>(input, padding);

  /* sizes */
  auto pad_l = padding[0];
  auto pad_r = padding[1];

  int64_t nplane = input.size(dim_plane);
  int64_t input_w = input.size(dim_w);
  int64_t output_w = input_w + pad_l + pad_r;

  TORCH_CHECK(
      pad_l < input_w && pad_r < input_w,
      "Argument #4: Padding size "
      "should be less than the corresponding input dimension, but got: padding (",
      pad_l,
      ", ",
      pad_r,
      ") at dimension ",
      dim_w,
      " of input ",
      input.sizes());

  TORCH_CHECK(
      output_w >= 1,
      "input (W: ",
      input_w,
      ") is too small. Calculated output W: ",
      output_w);

  if (input.ndimension() == 2) {
    set_output_raw_strided(0, {nplane, output_w}, {}, input.options());
  } else {
    set_output_raw_strided(0, {nbatch, nplane, output_w}, {}, input.options());
  }
}

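// Shape checks for reflection_pad1d_backward: grad_output must already have
// the padded width (input_w + pad_l + pad_r); grad_input is allocated with
// the same sizes as the forward input.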
TORCH_META_FUNC(reflection_pad1d_backward)(const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding) {
  int64_t dim_w = 1;
  if (input.ndimension() == 3) {
    dim_w++;
  }

  /* sizes */
  auto pad_l = padding[0];
  auto pad_r = padding[1];
  int64_t input_w = input.size(dim_w);
  int64_t output_w = input_w + pad_l + pad_r;

  TORCH_CHECK(
      pad_l < input_w && pad_r < input_w,
      "Argument #4: Padding size "
      "should be less than the corresponding input dimension, but got: padding (",
      pad_l,
      ", ",
      pad_r,
      ") at dimension ",
      dim_w,
      " of input ",
      input.sizes());

  TORCH_CHECK(output_w == grad_output.size(dim_w), "grad_output width unexpected."
    " Expected: ", output_w, ", Got: ", grad_output.size(dim_w));

  set_output_raw_strided(0, input.sizes(), {}, input.options());
}

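// Shape inference for reflection_pad3d. padding is ordered
// {left, right, top, bottom, front, back} and applies to the last three
// dimensions of a 4-D (C, D, H, W) or 5-D (N, C, D, H, W) input:
//   output_d = input_d + pad_front + pad_back
//   output_h = input_h + pad_top + pad_bottom
//   output_w = input_w + pad_left + pad_right
// Each pad must be strictly smaller than the corresponding input dimension.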
TORCH_META_FUNC(reflection_pad3d)(const Tensor& input, IntArrayRef padding) {
  int64_t pad_left = padding[0];
  int64_t pad_right = padding[1];
  int64_t pad_top = padding[2];
  int64_t pad_bottom = padding[3];
  int64_t pad_front = padding[4];
  int64_t pad_back = padding[5];
  int64_t dim_w = 3;
  int64_t dim_h = 2;
  int64_t dim_d = 1;
  int64_t dim_plane = 0;

  at::native::padding::check_valid_input<3>(input, padding);

  bool batch_mode = (input.dim() == 5);
  if (batch_mode) {
    dim_w++;
    dim_h++;
    dim_d++;
    dim_plane++;
  }

  int64_t nplane = input.size(dim_plane);
  int64_t input_d = input.size(dim_d);
  int64_t input_h = input.size(dim_h);
  int64_t input_w = input.size(dim_w);
  int64_t output_d = input_d + pad_front + pad_back;
  int64_t output_h = input_h + pad_top + pad_bottom;
  int64_t output_w = input_w + pad_left + pad_right;

  TORCH_CHECK(
      pad_left < input_w && pad_right < input_w,
      "Argument #4: Padding size "
      "should be less than the corresponding input dimension, but got: padding (",
      pad_left, ", ", pad_right, ") at dimension ", dim_w, " of input ", input.sizes());
  TORCH_CHECK(
      pad_top < input_h && pad_bottom < input_h,
      "Argument #6: Padding size "
      "should be less than the corresponding input dimension, but got: padding (",
      pad_top, ", ", pad_bottom, ") at dimension ", dim_h, " of input ", input.sizes());
  TORCH_CHECK(
      pad_front < input_d && pad_back < input_d,
      "Argument #8: Padding size "
      "should be less than the corresponding input dimension, but got: padding (",
      pad_front, ", ", pad_back, ") at dimension ", dim_d, " of input ", input.sizes());

  TORCH_CHECK(output_w >= 1 || output_h >= 1 || output_d >= 1,
      "input (D: ", input_d, " H: ", input_h, ", W: ", input_w,
      ") is too small."
      " Calculated output D: ", output_d, " H: ", output_h, " W: ", output_w);

  if (batch_mode) {
    set_output_raw_strided(0, {input.size(0), nplane, output_d, output_h, output_w}, {}, input.options());
  } else {
    set_output_raw_strided(0, {nplane, output_d, output_h, output_w}, {}, input.options());
  }
}

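// Shape checks for reflection_pad3d_backward: grad_output must match the
// padded output sizes along D, H and W; grad_input is allocated with the
// same sizes as the forward input.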
TORCH_META_FUNC(reflection_pad3d_backward)(
    const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding
) {
  TORCH_CHECK(padding.size() == 6, "padding size is expected to be 6");
  TORCH_CHECK(input.dim() > 3);
  TORCH_CHECK(grad_output.dim() == input.dim());

  int64_t pad_left = padding[0];
  int64_t pad_right = padding[1];
  int64_t pad_top = padding[2];
  int64_t pad_bottom = padding[3];
  int64_t pad_front = padding[4];
  int64_t pad_back = padding[5];
  int64_t dim_w = 3;
  int64_t dim_h = 2;
  int64_t dim_d = 1;

  if (input.dim() == 5) {
    // batch mode
    dim_w++;
    dim_h++;
    dim_d++;
  }

  int64_t input_d = input.size(dim_d);
  int64_t input_h = input.size(dim_h);
  int64_t input_w = input.size(dim_w);
  int64_t output_d = input_d + pad_front + pad_back;
  int64_t output_h = input_h + pad_top + pad_bottom;
  int64_t output_w = input_w + pad_left + pad_right;

  TORCH_CHECK(output_w == grad_output.size(dim_w), "grad_output width unexpected."
    " Expected: ", output_w, ", Got: ", grad_output.size(dim_w));
  TORCH_CHECK(output_h == grad_output.size(dim_h), "grad_output height unexpected."
    " Expected: ", output_h, ", Got: ", grad_output.size(dim_h));
  TORCH_CHECK(output_d == grad_output.size(dim_d), "grad_output depth unexpected."
    " Expected: ", output_d, ", Got: ", grad_output.size(dim_d));

  set_output_raw_strided(0, input.sizes(), {}, input.options());
}
} // namespace at::meta

namespace at::native {

namespace {

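// Shared forward helper for the 2-D case: validates the {left, right, top,
// bottom} padding against the last two input dimensions, resizes `output` to
// (..., C, input_h + pad_t + pad_b, input_w + pad_l + pad_r), and then
// dispatches to the CPU kernel. Quantized outputs are resized without a
// memory_format argument, which quantized resize_ does not accept.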
void reflection_pad2d_out_template(
    Tensor &output, const Tensor &input, IntArrayRef padding) {
  int dim_w = 2;
  int dim_h = 1;
  int dim_slices = 0;
  int64_t nbatch = 1;

  at::native::padding::check_valid_input<2>(input, padding);

  int ndim = input.dim();
  if (ndim == 4) {
    nbatch = input.size(0);
    dim_w++;
    dim_h++;
    dim_slices++;
  }

  /* sizes */
  int64_t pad_l = padding[0];
  int64_t pad_r = padding[1];
  int64_t pad_t = padding[2];
  int64_t pad_b = padding[3];

  int64_t nplane = input.size(dim_slices);
  int64_t input_h = input.size(dim_h);
  int64_t input_w = input.size(dim_w);
  int64_t output_h = input_h + pad_t + pad_b;
  int64_t output_w = input_w + pad_l + pad_r;

  TORCH_CHECK(pad_l < input_w && pad_r < input_w,
    "Argument #4: Padding size should be less than the corresponding "
    "input dimension, but got: padding (", pad_l, ", ", pad_r,
    ") at dimension ", dim_w, " of input ", input.sizes());

  TORCH_CHECK(pad_t < input_h && pad_b < input_h,
    "Argument #6: Padding size should be less than the corresponding "
    "input dimension, but got: padding (", pad_t, ", ", pad_b,
    ") at dimension ", dim_h, " of input ", input.sizes());

  TORCH_CHECK(output_w >= 1 || output_h >= 1,
    "input (H: ", input_h, ", W: ", input_w, ") is too small. Calculated "
    "output H: ", output_h, " W: ", output_w);

  /* resize output */
  if (ndim == 3) {
    output.resize_({nplane, output_h, output_w});
  } else {
    if (input.is_quantized()) {
      // quantized tensor can not be resized with argument `memory_format`
      output.resize_({nbatch, nplane, output_h, output_w});
    } else {
      output.resize_({nbatch, nplane, output_h, output_w}, input.suggest_memory_format());
    }
  }
  reflection_pad2d_kernel(kCPU, output, input, padding);
}

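// Shared backward helper for the 2-D case: verifies that grad_output has the
// padded height and width implied by `input` and `padding`, then dispatches
// to the CPU backward kernel. Callers are expected to resize and zero
// grad_input beforehand.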
void reflection_pad2d_backward_out_template(
    Tensor &grad_input, const Tensor &grad_output,
    const Tensor &input, IntArrayRef padding) {
  int dim_w = 2;
  int dim_h = 1;

  if (input.ndimension() == 4) {
    dim_w++;
    dim_h++;
  }

  /* sizes */
  int64_t pad_l = padding[0];
  int64_t pad_r = padding[1];
  int64_t pad_t = padding[2];
  int64_t pad_b = padding[3];

  int64_t input_h = input.size(dim_h);
  int64_t input_w = input.size(dim_w);
  int64_t output_h = input_h + pad_t + pad_b;
  int64_t output_w = input_w + pad_l + pad_r;

  TORCH_CHECK(output_w == grad_output.size(dim_w),
    "gradOutput width unexpected. Expected: ", output_w, ", Got: ",
    grad_output.size(dim_w));

  TORCH_CHECK(output_h == grad_output.size(dim_h),
    "gradOutput height unexpected. Expected: ", output_h, ", Got: ",
    grad_output.size(dim_h));

  reflection_pad2d_backward_kernel(kCPU, grad_input, grad_output, padding);
}

} // namespace

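// Quantized out-variant of reflection_pad1d. Only per-tensor affine
// quantization is supported; the output reuses the input's scale and
// zero point before the kernel fills it.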
Tensor& reflection_pad1d_out_quantized_cpu(const Tensor& input, IntArrayRef padding,
    Tensor& output) {
  TORCH_CHECK(input.qscheme() == kPerTensorAffine, "Only per tensor quantization is supported");
  set_quantizer_(output, make_per_tensor_affine_quantizer(input.q_scale(), input.q_zero_point(), input.scalar_type()));
  reflection_pad1d_kernel(kCPU, output, input, padding);
  return output;
}

TORCH_IMPL_FUNC(reflection_pad1d_out_cpu)
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
  reflection_pad1d_kernel(kCPU, output, input, padding);
}

TORCH_IMPL_FUNC(reflection_pad1d_backward_out_cpu)(const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding,
    const Tensor& grad_input) {
  if (grad_output.numel() == 0) {
    return;
  }

  grad_input.zero_();
  reflection_pad1d_backward_kernel(kCPU, grad_input, grad_output, padding);
}

Tensor& reflection_pad2d_out_cpu(const Tensor& input, IntArrayRef padding,
    Tensor& output) {
  reflection_pad2d_out_template(output, input, padding);
  return output;
}

Tensor reflection_pad2d_cpu(const Tensor& input, IntArrayRef padding) {
  Tensor output = at::empty({0}, input.options());
  reflection_pad2d_out_template(output, input, padding);
  return output;
}

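// Quantized variant of reflection_pad2d. As above, only per-tensor affine
// quantization is supported; the output starts as an empty quantized tensor
// carrying the input's scale and zero point, and the shared template
// resizes and fills it.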
Tensor reflection_pad2d_quantized_cpu(const Tensor& input, IntArrayRef padding) {
  TORCH_CHECK(input.qscheme() == kPerTensorAffine, "Only per tensor quantization is supported");
  Tensor output = at::_empty_affine_quantized({0}, input.options(),
                                           input.q_scale(),
                                           input.q_zero_point());
  reflection_pad2d_out_template(output, input, padding);
  return output;
}

Tensor& reflection_pad2d_backward_out_cpu(const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding,
    Tensor& grad_input) {
  grad_input.resize_as_(input, input.suggest_memory_format());
  grad_input.zero_();
  reflection_pad2d_backward_out_template(
    grad_input, grad_output, input, padding);
  return grad_input;
}

Tensor reflection_pad2d_backward_cpu(
    const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding) {
  auto grad_input = at::zeros_like(input, input.suggest_memory_format());
  reflection_pad2d_backward_out_template(
    grad_input, grad_output, input, padding);
  return grad_input;
}

TORCH_IMPL_FUNC(reflection_pad3d_out_cpu)
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
  // TODO: move this to TORCH_META_FUNC when CUDA has channels last support
  output.resize_(output.sizes(), input.suggest_memory_format());

  reflection_pad3d_kernel(kCPU, output, input, padding);
}

TORCH_IMPL_FUNC(reflection_pad3d_backward_out_cpu)(const Tensor& grad_output,
    const Tensor& input,
    IntArrayRef padding,
    const Tensor& grad_input) {
  if (grad_output.numel() == 0) {
    return;
  }

  // TODO: move this to TORCH_META_FUNC when CUDA has channels last support
  grad_input.resize_(input.sizes(), input.suggest_memory_format());

  grad_input.zero_();
  reflection_pad3d_backward_kernel(kCPU, grad_input, grad_output, padding);
}

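// Define the dispatch stubs declared in ATen/native/Padding.h. The CPU
// implementations themselves are registered separately via REGISTER_DISPATCH
// (see the padding kernel sources under aten/src/ATen/native/cpu/).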
DEFINE_DISPATCH(reflection_pad1d_kernel);
DEFINE_DISPATCH(reflection_pad1d_backward_kernel);
DEFINE_DISPATCH(reflection_pad2d_kernel);
DEFINE_DISPATCH(reflection_pad2d_backward_kernel);
DEFINE_DISPATCH(reflection_pad3d_kernel);
DEFINE_DISPATCH(reflection_pad3d_backward_kernel);

} // namespace at::native