xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/functional/padding.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/PadNd.h>
4 #include <torch/nn/options/padding.h>
5 
6 namespace torch {
7 namespace nn {
8 namespace functional {
9 
10 #ifndef DOXYGEN_SHOULD_SKIP_THIS
11 namespace detail {
/// Implementation helper behind `torch::nn::functional::pad`.
///
/// Maps the `PadFuncOptions::mode_t` variant onto the matching
/// `at::padding_mode` enumerator and forwards everything to `at::_pad_enum`.
///
/// \param input Tensor to pad.
/// \param pad   Padding sizes, forwarded to `at::_pad_enum` unchanged.
/// \param mode  Variant holding one of `kConstant`, `kReflect`, `kReplicate`
///              or `kCircular`; any other alternative fails a `TORCH_CHECK`.
/// \param value Fill value; forwarded only when it is non-zero.
/// \return The result of `at::_pad_enum`.
inline Tensor pad(
    const Tensor& input,
    IntArrayRef pad,
    PadFuncOptions::mode_t mode,
    double value) {
  // Resolve the variant alternative to the ATen enum. Every accepted
  // alternative overwrites this initial value; the final branch raises.
  at::padding_mode resolved_mode{};
  if (std::holds_alternative<enumtype::kConstant>(mode)) {
    resolved_mode = at::padding_mode::constant;
  } else if (std::holds_alternative<enumtype::kReflect>(mode)) {
    resolved_mode = at::padding_mode::reflect;
  } else if (std::holds_alternative<enumtype::kReplicate>(mode)) {
    resolved_mode = at::padding_mode::replicate;
  } else if (std::holds_alternative<enumtype::kCircular>(mode)) {
    resolved_mode = at::padding_mode::circular;
  } else {
    TORCH_CHECK(false, "Unrecognised padding mode");
  }

  // A zero fill value is passed on as "no value" rather than an explicit 0.
  // NOTE(review): presumably this keeps the default `value` from being handed
  // to modes that take no fill value -- confirm against at::_pad_enum.
  const std::optional<double> fill_value =
      (value != 0.0) ? std::optional<double>(value) : std::nullopt;

  return at::_pad_enum(
      input, pad, static_cast<int64_t>(resolved_mode), fill_value);
}
36 } // namespace detail
37 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
38 
39 /// See
40 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.pad
41 /// about the exact behavior of this functional.
42 ///
43 /// See the documentation for `torch::nn::functional::PadFuncOptions` class to
44 /// learn what optional arguments are supported for this functional.
45 ///
46 /// Example:
47 /// ```
48 /// namespace F = torch::nn::functional;
49 /// F::pad(input, F::PadFuncOptions({1, 2, 2, 1, 1,
50 /// 2}).mode(torch::kReplicate));
51 /// ```
pad(const Tensor & input,const PadFuncOptions & options)52 inline Tensor pad(const Tensor& input, const PadFuncOptions& options) {
53   return detail::pad(input, options.pad(), options.mode(), options.value());
54 }
55 
56 } // namespace functional
57 } // namespace nn
58 } // namespace torch
59