xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/Padding.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/Tensor.h>
4 #include <ATen/native/DispatchStub.h>
5 
6 namespace at::native {
7 
8 using padding_fn = void (*)(const Tensor&, const Tensor&, IntArrayRef);
9 
10 // reflection padding
11 DECLARE_DISPATCH(padding_fn, reflection_pad1d_kernel);
12 DECLARE_DISPATCH(padding_fn, reflection_pad1d_backward_kernel);
13 DECLARE_DISPATCH(padding_fn, reflection_pad2d_kernel);
14 DECLARE_DISPATCH(padding_fn, reflection_pad2d_backward_kernel);
15 DECLARE_DISPATCH(padding_fn, reflection_pad3d_kernel);
16 DECLARE_DISPATCH(padding_fn, reflection_pad3d_backward_kernel);
17 
18 // replication padding
19 DECLARE_DISPATCH(padding_fn, replication_pad1d_kernel);
20 DECLARE_DISPATCH(padding_fn, replication_pad1d_backward_kernel);
21 DECLARE_DISPATCH(padding_fn, replication_pad2d_kernel);
22 DECLARE_DISPATCH(padding_fn, replication_pad2d_backward_kernel);
23 DECLARE_DISPATCH(padding_fn, replication_pad3d_kernel);
24 DECLARE_DISPATCH(padding_fn, replication_pad3d_backward_kernel);
25 
26 namespace padding {
27 
28 template <int dim>
check_valid_input(const Tensor & input,IntArrayRef padding)29 inline void check_valid_input(const Tensor& input, IntArrayRef padding) {
30 
31   TORCH_CHECK(padding.size() == 2 * dim,
32       "padding size is expected to be ", 2 * dim,
33       ", but got: ", padding.size());
34 
35   int input_dim = input.dim();
36 
37   bool is_batch_mode = input_dim == (dim + 2);
38 
39   bool valid_batch_mode = is_batch_mode;
40   bool valid_non_batch_mode = !is_batch_mode;
41 
42   if (is_batch_mode) {
43     // allow batch size of 0-dim.
44     for (const auto d : c10::irange(1, input_dim)) {
45       valid_batch_mode = valid_batch_mode && input.size(d) != 0;
46     }
47   } else {
48     for (const auto d : c10::irange(0, input_dim)) {
49       valid_non_batch_mode = valid_non_batch_mode && input.size(d) != 0;
50     }
51   }
52 
53   // allow empty batch size but not other dimensions.
54   TORCH_CHECK(valid_batch_mode || valid_non_batch_mode,
55       "Expected ", dim + 1, "D or ", dim + 2,
56       "D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ",
57       input.sizes());
58 }
59 
60 } // namespace padding
61 
62 } // at::native
63