xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/options/padding.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/arg.h>
4 #include <torch/csrc/Export.h>
5 #include <torch/enum.h>
6 #include <torch/expanding_array.h>
7 #include <torch/types.h>
8 
9 namespace torch {
10 namespace nn {
11 
12 /// Options for a `D`-dimensional ReflectionPad module.
13 template <size_t D>
14 struct TORCH_API ReflectionPadOptions {
ReflectionPadOptionsReflectionPadOptions15   ReflectionPadOptions(ExpandingArray<D * 2> padding) : padding_(padding) {}
16 
17   /// The size of the padding.
18   /// If it is `int`, uses the same padding in all boundaries.
19   /// If it is a 2-`tuple` (for ReflectionPad1d), uses (padding_left,
20   /// padding_right). If it is a 4-`tuple` (for ReflectionPad2d), uses
21   /// (padding_left, padding_right, padding_top, padding_bottom). If it is a
22   /// 6-`tuple` (for ReflectionPad3d), uses (padding_left, padding_right,
23   /// padding_top, padding_bottom, padding_front, padding_back).
24 
25   TORCH_ARG(ExpandingArray<D * 2>, padding);
26 };
27 
28 /// `ReflectionPadOptions` specialized for the `ReflectionPad1d` module.
29 ///
30 /// Example:
31 /// ```
32 /// ReflectionPad1d model(ReflectionPad1dOptions({3, 1}));
33 /// ```
34 using ReflectionPad1dOptions = ReflectionPadOptions<1>;
35 
36 /// `ReflectionPadOptions` specialized for the `ReflectionPad2d` module.
37 ///
38 /// Example:
39 /// ```
40 /// ReflectionPad2d model(ReflectionPad2dOptions({1, 1, 2, 0}));
41 /// ```
42 using ReflectionPad2dOptions = ReflectionPadOptions<2>;
43 
44 /// `ReflectionPadOptions` specialized for the `ReflectionPad3d` module.
45 ///
46 /// Example:
47 /// ```
48 /// ReflectionPad3d model(ReflectionPad3dOptions({1, 1, 2, 0, 1, 1}));
49 /// ```
50 using ReflectionPad3dOptions = ReflectionPadOptions<3>;
51 
52 // ============================================================================
53 
54 /// Options for a `D`-dimensional ReplicationPad module.
55 template <size_t D>
56 struct TORCH_API ReplicationPadOptions {
ReplicationPadOptionsReplicationPadOptions57   ReplicationPadOptions(ExpandingArray<D * 2> padding) : padding_(padding) {}
58 
59   /// The size of the padding.
60   /// - If it is `int`, uses the same padding in all boundaries.
61   /// - If it is a 2-`tuple` (for ReplicationPad1d), uses (padding_left,
62   /// padding_right).
63   /// - If it is a 4-`tuple` (for ReplicationPad2d), uses (padding_left,
64   /// padding_right, padding_top, padding_bottom).
65   /// - If it is a 6-`tuple` (for ReplicationPad3d), uses
66   ///   (padding_left, padding_right, padding_top, padding_bottom,
67   ///   padding_front, padding_back).
68   TORCH_ARG(ExpandingArray<D * 2>, padding);
69 };
70 
71 /// `ReplicationPadOptions` specialized for the `ReplicationPad1d` module.
72 ///
73 /// Example:
74 /// ```
75 /// ReplicationPad1d model(ReplicationPad1dOptions({3, 1}));
76 /// ```
77 using ReplicationPad1dOptions = ReplicationPadOptions<1>;
78 
79 /// `ReplicationPadOptions` specialized for the `ReplicationPad2d` module.
80 ///
81 /// Example:
82 /// ```
83 /// ReplicationPad2d model(ReplicationPad2dOptions({1, 1, 2, 0}));
84 /// ```
85 using ReplicationPad2dOptions = ReplicationPadOptions<2>;
86 
87 /// `ReplicationPadOptions` specialized for the `ReplicationPad3d` module.
88 ///
89 /// Example:
90 /// ```
91 /// ReplicationPad3d model(ReplicationPad3dOptions({1, 2, 1, 2, 1, 2}));
92 /// ```
93 using ReplicationPad3dOptions = ReplicationPadOptions<3>;
94 
95 // ============================================================================
96 
97 template <size_t D>
98 struct TORCH_API ZeroPadOptions {
ZeroPadOptionsZeroPadOptions99   ZeroPadOptions(ExpandingArray<D * 2> padding) : padding_(padding) {}
100 
101   /// The size of the padding.
102   /// - If it is `int`, uses the same padding in all boundaries.
103   /// - If it is a 2-`tuple` (for ZeroPad1d), uses (padding_left,
104   /// padding_right).
105   /// - If it is a 4-`tuple` (for ZeroPad2d), uses (padding_left, padding_right,
106   /// padding_top, padding_bottom).
107   /// - If it is a 6-`tuple` (for ZeroPad3d), uses
108   ///   (padding_left, padding_right, padding_top, padding_bottom,
109   ///   padding_front, padding_back).
110   TORCH_ARG(ExpandingArray<D * 2>, padding);
111 };
112 
113 /// `ZeroPadOptions` specialized for the `ZeroPad1d` module.
114 ///
115 /// Example:
116 /// ```
117 /// ConstantPad1d model(ConstantPad1dOptions({3, 1});
118 /// ```
119 using ZeroPad1dOptions = ZeroPadOptions<1>;
120 
121 /// `ZeroPadOptions` specialized for the `ZeroPad2d` module.
122 ///
123 /// Example:
124 /// ```
125 /// ConstantPad2d model(ConstantPad2dOptions({1, 1, 2, 0});
126 /// ```
127 using ZeroPad2dOptions = ZeroPadOptions<2>;
128 
129 /// `ZeroPadOptions` specialized for the `ZeroPad3d` module.
130 ///
131 /// Example:
132 /// ```
133 /// ConstantPad3d model(ConstantPad3dOptions({1, 2, 1, 2, 1, 2});
134 /// ```
135 using ZeroPad3dOptions = ZeroPadOptions<3>;
136 
137 // ============================================================================
138 
139 /// Options for a `D`-dimensional ConstantPad module.
140 template <size_t D>
141 struct TORCH_API ConstantPadOptions {
ConstantPadOptionsConstantPadOptions142   ConstantPadOptions(ExpandingArray<D * 2> padding, double value)
143       : padding_(padding), value_(value) {}
144 
145   /// The size of the padding.
146   /// - If it is `int`, uses the same padding in all boundaries.
147   /// - If it is a 2-`tuple` (for ConstantPad1d), uses (padding_left,
148   /// padding_right).
149   /// - If it is a 4-`tuple` (for ConstantPad2d), uses (padding_left,
150   /// padding_right, padding_top, padding_bottom).
151   /// - If it is a 6-`tuple` (for ConstantPad3d), uses
152   ///   (padding_left, padding_right, padding_top, padding_bottom,
153   ///   padding_front, padding_back).
154   TORCH_ARG(ExpandingArray<D * 2>, padding);
155 
156   /// Fill value for constant padding.
157   TORCH_ARG(double, value);
158 };
159 
160 /// `ConstantPadOptions` specialized for the `ConstantPad1d` module.
161 ///
162 /// Example:
163 /// ```
164 /// ConstantPad1d model(ConstantPad1dOptions({3, 1}, 3.5));
165 /// ```
166 using ConstantPad1dOptions = ConstantPadOptions<1>;
167 
168 /// `ConstantPadOptions` specialized for the `ConstantPad2d` module.
169 ///
170 /// Example:
171 /// ```
172 /// ConstantPad2d model(ConstantPad2dOptions({3, 0, 2, 1}, 3.5));
173 /// ```
174 using ConstantPad2dOptions = ConstantPadOptions<2>;
175 
176 /// `ConstantPadOptions` specialized for the `ConstantPad3d` module.
177 ///
178 /// Example:
179 /// ```
180 /// ConstantPad3d model(ConstantPad3dOptions({1, 2, 1, 2, 1, 2}, 3.5));
181 /// ```
182 using ConstantPad3dOptions = ConstantPadOptions<3>;
183 
184 // ============================================================================
185 
186 namespace functional {
187 
188 /// Options for `torch::nn::functional::pad`.
189 ///
190 /// Example:
191 /// ```
192 /// namespace F = torch::nn::functional;
193 /// F::pad(input, F::PadFuncOptions({1, 2, 2, 1, 1,
194 /// 2}).mode(torch::kReplicate));
195 /// ```
196 struct TORCH_API PadFuncOptions {
197   typedef std::variant<
198       enumtype::kConstant,
199       enumtype::kReflect,
200       enumtype::kReplicate,
201       enumtype::kCircular>
202       mode_t;
203 
204   PadFuncOptions(std::vector<int64_t> pad);
205 
206   /// m-elements tuple, where m/2 <= input dimensions and m is even.
207   TORCH_ARG(std::vector<int64_t>, pad);
208 
209   /// "constant", "reflect", "replicate" or "circular". Default: "constant"
210   TORCH_ARG(mode_t, mode) = torch::kConstant;
211 
212   /// fill value for "constant" padding. Default: 0
213   TORCH_ARG(double, value) = 0;
214 };
215 
216 } // namespace functional
217 
218 } // namespace nn
219 } // namespace torch
220