1 #pragma once
2
3 #include <ATen/PadNd.h>
4 #include <torch/nn/options/padding.h>
5
6 namespace torch {
7 namespace nn {
8 namespace functional {
9
10 #ifndef DOXYGEN_SHOULD_SKIP_THIS
11 namespace detail {
/// Backend for `torch::nn::functional::pad`.
///
/// Resolves which alternative of the `mode` variant is active into the
/// corresponding `at::padding_mode` value, then delegates to
/// `at::_pad_enum`. Raises (via `TORCH_CHECK`) with
/// "Unrecognised padding mode" if none of the four known alternatives is held.
///
/// A `value` that compares equal to 0.0 (which includes -0.0) is forwarded as
/// an empty optional rather than as an explicit fill value; any other value
/// (including NaN) is forwarded as-is.
inline Tensor pad(
    const Tensor& input,
    IntArrayRef pad,
    PadFuncOptions::mode_t mode,
    double value) {
  // Pre-initialised only to give the local a defined value; every path below
  // either overwrites it or throws before it is read.
  at::padding_mode resolved_mode{};
  if (std::holds_alternative<enumtype::kConstant>(mode)) {
    resolved_mode = at::padding_mode::constant;
  } else if (std::holds_alternative<enumtype::kReflect>(mode)) {
    resolved_mode = at::padding_mode::reflect;
  } else if (std::holds_alternative<enumtype::kReplicate>(mode)) {
    resolved_mode = at::padding_mode::replicate;
  } else if (std::holds_alternative<enumtype::kCircular>(mode)) {
    resolved_mode = at::padding_mode::circular;
  } else {
    TORCH_CHECK(false, "Unrecognised padding mode");
  }

  // Only a non-zero value is passed through explicitly.
  // NOTE(review): presumably this lets the non-constant modes (which receive
  // the default 0.0) avoid supplying a fill value to `at::_pad_enum` —
  // confirm against the ATen implementation.
  const std::optional<double> fill_value =
      (value == 0.0) ? std::nullopt : std::optional<double>(value);

  return at::_pad_enum(
      input, pad, static_cast<int64_t>(resolved_mode), fill_value);
}
36 } // namespace detail
37 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
38
39 /// See
40 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.pad
41 /// about the exact behavior of this functional.
42 ///
43 /// See the documentation for `torch::nn::functional::PadFuncOptions` class to
44 /// learn what optional arguments are supported for this functional.
45 ///
46 /// Example:
47 /// ```
48 /// namespace F = torch::nn::functional;
/// F::pad(input, F::PadFuncOptions({1, 2, 2, 1, 1, 2})
///                   .mode(torch::kReplicate));
51 /// ```
pad(const Tensor & input,const PadFuncOptions & options)52 inline Tensor pad(const Tensor& input, const PadFuncOptions& options) {
53 return detail::pad(input, options.pad(), options.mode(), options.value());
54 }
55
56 } // namespace functional
57 } // namespace nn
58 } // namespace torch
59