1 #pragma once
2
3 #include <ATen/core/Tensor.h>
4 #include <ATen/native/DispatchStub.h>
5
6 namespace at::native {
7
// Common signature shared by every padding kernel stub declared below:
// (const Tensor&, const Tensor&, IntArrayRef).
// NOTE(review): presumably (output, input, padding) for the forward kernels and
// (grad_input, grad_output, padding) for the backward kernels — TODO confirm
// against the kernel implementations, which are not visible in this header.
using padding_fn = void (*)(const Tensor&, const Tensor&, IntArrayRef);

// Each DECLARE_DISPATCH (from ATen/native/DispatchStub.h) only declares a stub;
// the matching definitions/registrations live in other translation units, so
// these names must stay in sync with them.

// reflection padding
DECLARE_DISPATCH(padding_fn, reflection_pad1d_kernel);
DECLARE_DISPATCH(padding_fn, reflection_pad1d_backward_kernel);
DECLARE_DISPATCH(padding_fn, reflection_pad2d_kernel);
DECLARE_DISPATCH(padding_fn, reflection_pad2d_backward_kernel);
DECLARE_DISPATCH(padding_fn, reflection_pad3d_kernel);
DECLARE_DISPATCH(padding_fn, reflection_pad3d_backward_kernel);

// replication padding
DECLARE_DISPATCH(padding_fn, replication_pad1d_kernel);
DECLARE_DISPATCH(padding_fn, replication_pad1d_backward_kernel);
DECLARE_DISPATCH(padding_fn, replication_pad2d_kernel);
DECLARE_DISPATCH(padding_fn, replication_pad2d_backward_kernel);
DECLARE_DISPATCH(padding_fn, replication_pad3d_kernel);
DECLARE_DISPATCH(padding_fn, replication_pad3d_backward_kernel);
25
26 namespace padding {
27
28 template <int dim>
check_valid_input(const Tensor & input,IntArrayRef padding)29 inline void check_valid_input(const Tensor& input, IntArrayRef padding) {
30
31 TORCH_CHECK(padding.size() == 2 * dim,
32 "padding size is expected to be ", 2 * dim,
33 ", but got: ", padding.size());
34
35 int input_dim = input.dim();
36
37 bool is_batch_mode = input_dim == (dim + 2);
38
39 bool valid_batch_mode = is_batch_mode;
40 bool valid_non_batch_mode = !is_batch_mode;
41
42 if (is_batch_mode) {
43 // allow batch size of 0-dim.
44 for (const auto d : c10::irange(1, input_dim)) {
45 valid_batch_mode = valid_batch_mode && input.size(d) != 0;
46 }
47 } else {
48 for (const auto d : c10::irange(0, input_dim)) {
49 valid_non_batch_mode = valid_non_batch_mode && input.size(d) != 0;
50 }
51 }
52
53 // allow empty batch size but not other dimensions.
54 TORCH_CHECK(valid_batch_mode || valid_non_batch_mode,
55 "Expected ", dim + 1, "D or ", dim + 2,
56 "D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ",
57 input.sizes());
58 }
59
60 } // namespace padding
61
62 } // at::native
63