xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/options/conv.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/arg.h>
4 #include <torch/csrc/Export.h>
5 #include <torch/enum.h>
6 #include <torch/expanding_array.h>
7 #include <torch/types.h>
8 
9 namespace torch {
10 namespace nn {
11 
12 namespace detail {
13 
14 typedef std::variant<
15     enumtype::kZeros,
16     enumtype::kReflect,
17     enumtype::kReplicate,
18     enumtype::kCircular>
19     conv_padding_mode_t;
20 
21 template <size_t D>
22 using conv_padding_t =
23     std::variant<ExpandingArray<D>, enumtype::kValid, enumtype::kSame>;
24 
25 /// Options for a `D`-dimensional convolution or convolution transpose module.
26 template <size_t D>
27 struct ConvNdOptions {
28   using padding_t = conv_padding_t<D>;
ConvNdOptionsConvNdOptions29   ConvNdOptions(
30       int64_t in_channels,
31       int64_t out_channels,
32       ExpandingArray<D> kernel_size)
33       : in_channels_(in_channels),
34         out_channels_(out_channels),
35         kernel_size_(std::move(kernel_size)) {}
36 
37   /// The number of channels the input volumes will have.
38   /// Changing this parameter after construction __has no effect__.
39   TORCH_ARG(int64_t, in_channels);
40 
41   /// The number of output channels the convolution should produce.
42   /// Changing this parameter after construction __has no effect__.
43   TORCH_ARG(int64_t, out_channels);
44 
45   /// The kernel size to use.
46   /// For a `D`-dim convolution, must be a single number or a list of `D`
47   /// numbers.
48   /// This parameter __can__ be changed after construction.
49   TORCH_ARG(ExpandingArray<D>, kernel_size);
50 
51   /// The stride of the convolution.
52   /// For a `D`-dim convolution, must be a single number or a list of `D`
53   /// numbers.
54   /// This parameter __can__ be changed after construction.
55   TORCH_ARG(ExpandingArray<D>, stride) = 1;
56 
57   /// The padding to add to the input volumes.
58   /// For a `D`-dim convolution, must be a single number or a list of `D`
59   /// numbers.
60   /// This parameter __can__ be changed after construction.
61   TORCH_ARG(padding_t, padding) = 0;
62 
63  public:
decltypeConvNdOptions64   decltype(auto) padding(std::initializer_list<int64_t> il) {
65     return padding(IntArrayRef{il});
66   }
67 
68   /// The kernel dilation.
69   /// For a `D`-dim convolution, must be a single number or a list of `D`
70   /// numbers.
71   /// This parameter __can__ be changed after construction.
72   TORCH_ARG(ExpandingArray<D>, dilation) = 1;
73 
74   /// If true, convolutions will be transpose convolutions (a.k.a.
75   /// deconvolutions).
76   /// Changing this parameter after construction __has no effect__.
77   TORCH_ARG(bool, transposed) = false;
78 
79   /// For transpose convolutions, the padding to add to output volumes.
80   /// For a `D`-dim convolution, must be a single number or a list of `D`
81   /// numbers.
82   /// This parameter __can__ be changed after construction.
83   TORCH_ARG(ExpandingArray<D>, output_padding) = 0;
84 
85   /// The number of convolution groups.
86   /// This parameter __can__ be changed after construction.
87   TORCH_ARG(int64_t, groups) = 1;
88 
89   /// Whether to add a bias after individual applications of the kernel.
90   /// Changing this parameter after construction __has no effect__.
91   TORCH_ARG(bool, bias) = true;
92 
93   /// Accepted values `torch::kZeros`, `torch::kReflect`, `torch::kReplicate` or
94   /// `torch::kCircular`. Default: `torch::kZeros`
95   TORCH_ARG(conv_padding_mode_t, padding_mode) = torch::kZeros;
96 };
97 
98 } // namespace detail
99 
100 // ============================================================================
101 
102 /// Options for a `D`-dimensional convolution module.
103 template <size_t D>
104 struct ConvOptions {
105   using padding_mode_t = detail::conv_padding_mode_t;
106   using padding_t = detail::conv_padding_t<D>;
107 
ConvOptionsConvOptions108   ConvOptions(
109       int64_t in_channels,
110       int64_t out_channels,
111       ExpandingArray<D> kernel_size)
112       : in_channels_(in_channels),
113         out_channels_(out_channels),
114         kernel_size_(std::move(kernel_size)) {}
115 
116   /// The number of channels the input volumes will have.
117   /// Changing this parameter after construction __has no effect__.
118   TORCH_ARG(int64_t, in_channels);
119 
120   /// The number of output channels the convolution should produce.
121   /// Changing this parameter after construction __has no effect__.
122   TORCH_ARG(int64_t, out_channels);
123 
124   /// The kernel size to use.
125   /// For a `D`-dim convolution, must be a single number or a list of `D`
126   /// numbers.
127   /// This parameter __can__ be changed after construction.
128   TORCH_ARG(ExpandingArray<D>, kernel_size);
129 
130   /// The stride of the convolution.
131   /// For a `D`-dim convolution, must be a single number or a list of `D`
132   /// numbers.
133   /// This parameter __can__ be changed after construction.
134   TORCH_ARG(ExpandingArray<D>, stride) = 1;
135 
136   /// The padding to add to the input volumes.
137   /// For a `D`-dim convolution, must be a single number or a list of `D`
138   /// numbers.
139   /// This parameter __can__ be changed after construction.
140   TORCH_ARG(padding_t, padding) = 0;
141 
142  public:
decltypeConvOptions143   decltype(auto) padding(std::initializer_list<int64_t> il) {
144     return padding(IntArrayRef{il});
145   }
146 
147   /// The kernel dilation.
148   /// For a `D`-dim convolution, must be a single number or a list of `D`
149   /// numbers.
150   /// This parameter __can__ be changed after construction.
151   TORCH_ARG(ExpandingArray<D>, dilation) = 1;
152 
153   /// The number of convolution groups.
154   /// This parameter __can__ be changed after construction.
155   TORCH_ARG(int64_t, groups) = 1;
156 
157   /// Whether to add a bias after individual applications of the kernel.
158   /// Changing this parameter after construction __has no effect__.
159   TORCH_ARG(bool, bias) = true;
160 
161   /// Accepted values `torch::kZeros`, `torch::kReflect`, `torch::kReplicate` or
162   /// `torch::kCircular`. Default: `torch::kZeros`
163   TORCH_ARG(padding_mode_t, padding_mode) = torch::kZeros;
164 };
165 
166 /// `ConvOptions` specialized for the `Conv1d` module.
167 ///
168 /// Example:
169 /// ```
170 /// Conv1d model(Conv1dOptions(3, 2, 3).stride(1).bias(false));
171 /// ```
172 using Conv1dOptions = ConvOptions<1>;
173 
174 /// `ConvOptions` specialized for the `Conv2d` module.
175 ///
176 /// Example:
177 /// ```
178 /// Conv2d model(Conv2dOptions(3, 2, 3).stride(1).bias(false));
179 /// ```
180 using Conv2dOptions = ConvOptions<2>;
181 
182 /// `ConvOptions` specialized for the `Conv3d` module.
183 ///
184 /// Example:
185 /// ```
186 /// Conv3d model(Conv3dOptions(3, 2, 3).stride(1).bias(false));
187 /// ```
188 using Conv3dOptions = ConvOptions<3>;
189 
190 // ============================================================================
191 
192 namespace functional {
193 
194 /// Options for a `D`-dimensional convolution functional.
195 template <size_t D>
196 struct ConvFuncOptions {
197   using padding_t = torch::nn::detail::conv_padding_t<D>;
198 
199   /// optional bias of shape `(out_channels)`. Default: ``None``
200   TORCH_ARG(torch::Tensor, bias) = Tensor();
201 
202   /// The stride of the convolving kernel.
203   /// For a `D`-dim convolution, must be a single number or a list of `D`
204   /// numbers.
205   TORCH_ARG(ExpandingArray<D>, stride) = 1;
206 
207   /// Implicit paddings on both sides of the input.
208   /// For a `D`-dim convolution, must be a single number or a list of `D`
209   /// numbers.
210   TORCH_ARG(padding_t, padding) = 0;
211 
212  public:
decltypeConvFuncOptions213   decltype(auto) padding(std::initializer_list<int64_t> il) {
214     return padding(IntArrayRef{il});
215   }
216 
217   /// The spacing between kernel elements.
218   /// For a `D`-dim convolution, must be a single number or a list of `D`
219   /// numbers.
220   TORCH_ARG(ExpandingArray<D>, dilation) = 1;
221 
222   /// Split input into groups, `in_channels` should be divisible by
223   /// the number of groups.
224   TORCH_ARG(int64_t, groups) = 1;
225 };
226 
227 /// `ConvFuncOptions` specialized for `torch::nn::functional::conv1d`.
228 ///
229 /// Example:
230 /// ```
231 /// namespace F = torch::nn::functional;
232 /// F::conv1d(x, weight, F::Conv1dFuncOptions().stride(1));
233 /// ```
234 using Conv1dFuncOptions = ConvFuncOptions<1>;
235 
236 /// `ConvFuncOptions` specialized for `torch::nn::functional::conv2d`.
237 ///
238 /// Example:
239 /// ```
240 /// namespace F = torch::nn::functional;
241 /// F::conv2d(x, weight, F::Conv2dFuncOptions().stride(1));
242 /// ```
243 using Conv2dFuncOptions = ConvFuncOptions<2>;
244 
245 /// `ConvFuncOptions` specialized for `torch::nn::functional::conv3d`.
246 ///
247 /// Example:
248 /// ```
249 /// namespace F = torch::nn::functional;
250 /// F::conv3d(x, weight, F::Conv3dFuncOptions().stride(1));
251 /// ```
252 using Conv3dFuncOptions = ConvFuncOptions<3>;
253 
254 } // namespace functional
255 
256 // ============================================================================
257 
258 template <size_t D>
259 struct ConvTransposeOptions {
260   using padding_mode_t = detail::conv_padding_mode_t;
261 
ConvTransposeOptionsConvTransposeOptions262   ConvTransposeOptions(
263       int64_t in_channels,
264       int64_t out_channels,
265       ExpandingArray<D> kernel_size)
266       : in_channels_(in_channels),
267         out_channels_(out_channels),
268         kernel_size_(std::move(kernel_size)) {}
269 
270   /// The number of channels the input volumes will have.
271   /// Changing this parameter after construction __has no effect__.
272   TORCH_ARG(int64_t, in_channels);
273 
274   /// The number of output channels the convolution should produce.
275   /// Changing this parameter after construction __has no effect__.
276   TORCH_ARG(int64_t, out_channels);
277 
278   /// The kernel size to use.
279   /// For a `D`-dim convolution, must be a single number or a list of `D`
280   /// numbers.
281   /// This parameter __can__ be changed after construction.
282   TORCH_ARG(ExpandingArray<D>, kernel_size);
283 
284   /// The stride of the convolution.
285   /// For a `D`-dim convolution, must be a single number or a list of `D`
286   /// numbers.
287   /// This parameter __can__ be changed after construction.
288   TORCH_ARG(ExpandingArray<D>, stride) = 1;
289 
290   /// The padding to add to the input volumes.
291   /// For a `D`-dim convolution, must be a single number or a list of `D`
292   /// numbers.
293   /// This parameter __can__ be changed after construction.
294   TORCH_ARG(ExpandingArray<D>, padding) = 0;
295 
296   /// For transpose convolutions, the padding to add to output volumes.
297   /// For a `D`-dim convolution, must be a single number or a list of `D`
298   /// numbers.
299   /// This parameter __can__ be changed after construction.
300   TORCH_ARG(ExpandingArray<D>, output_padding) = 0;
301 
302   /// The number of convolution groups.
303   /// This parameter __can__ be changed after construction.
304   TORCH_ARG(int64_t, groups) = 1;
305 
306   /// Whether to add a bias after individual applications of the kernel.
307   /// Changing this parameter after construction __has no effect__.
308   TORCH_ARG(bool, bias) = true;
309 
310   /// The kernel dilation.
311   /// For a `D`-dim convolution, must be a single number or a list of `D`
312   /// numbers.
313   /// This parameter __can__ be changed after construction.
314   TORCH_ARG(ExpandingArray<D>, dilation) = 1;
315 
316   /// Accepted values `torch::kZeros`, `torch::kReflect`, `torch::kReplicate` or
317   /// `torch::kCircular`. Default: `torch::kZeros`
318   TORCH_ARG(padding_mode_t, padding_mode) = torch::kZeros;
319 };
320 
321 /// `ConvTransposeOptions` specialized for the `ConvTranspose1d` module.
322 ///
323 /// Example:
324 /// ```
325 /// ConvTranspose1d model(ConvTranspose1dOptions(3, 2,
326 /// 3).stride(1).bias(false));
327 /// ```
328 using ConvTranspose1dOptions = ConvTransposeOptions<1>;
329 
330 /// `ConvTransposeOptions` specialized for the `ConvTranspose2d` module.
331 ///
332 /// Example:
333 /// ```
334 /// ConvTranspose2d model(ConvTranspose2dOptions(3, 2,
335 /// 3).stride(1).bias(false));
336 /// ```
337 using ConvTranspose2dOptions = ConvTransposeOptions<2>;
338 
339 /// `ConvTransposeOptions` specialized for the `ConvTranspose3d` module.
340 ///
341 /// Example:
342 /// ```
343 /// ConvTranspose3d model(ConvTranspose3dOptions(2, 2,
344 /// 2).stride(1).bias(false));
345 /// ```
346 using ConvTranspose3dOptions = ConvTransposeOptions<3>;
347 
348 // ============================================================================
349 
350 namespace functional {
351 
352 /// Options for a `D`-dimensional convolution functional.
353 template <size_t D>
354 struct ConvTransposeFuncOptions {
355   /// optional bias of shape `(out_channels)`. Default: ``None``
356   TORCH_ARG(torch::Tensor, bias) = Tensor();
357 
358   /// The stride of the convolving kernel.
359   /// For a `D`-dim convolution, must be a single number or a list of `D`
360   /// numbers.
361   TORCH_ARG(ExpandingArray<D>, stride) = 1;
362 
363   /// Implicit paddings on both sides of the input.
364   /// For a `D`-dim convolution, must be a single number or a list of `D`
365   /// numbers.
366   TORCH_ARG(ExpandingArray<D>, padding) = 0;
367 
368   /// Additional size added to one side of each dimension in the output shape.
369   /// Default: 0
370   TORCH_ARG(ExpandingArray<D>, output_padding) = 0;
371 
372   /// Split input into groups, `in_channels` should be divisible by
373   /// the number of groups.
374   TORCH_ARG(int64_t, groups) = 1;
375 
376   /// The spacing between kernel elements.
377   /// For a `D`-dim convolution, must be a single number or a list of `D`
378   /// numbers.
379   TORCH_ARG(ExpandingArray<D>, dilation) = 1;
380 };
381 
382 /// `ConvTransposeFuncOptions` specialized for
383 /// `torch::nn::functional::conv_transpose1d`.
384 ///
385 /// Example:
386 /// ```
387 /// namespace F = torch::nn::functional;
388 /// F::conv_transpose1d(x, weight, F::ConvTranspose1dFuncOptions().stride(1));
389 /// ```
390 using ConvTranspose1dFuncOptions = ConvTransposeFuncOptions<1>;
391 
392 /// `ConvTransposeFuncOptions` specialized for
393 /// `torch::nn::functional::conv_transpose2d`.
394 ///
395 /// Example:
396 /// ```
397 /// namespace F = torch::nn::functional;
398 /// F::conv_transpose2d(x, weight, F::ConvTranspose2dFuncOptions().stride(1));
399 /// ```
400 using ConvTranspose2dFuncOptions = ConvTransposeFuncOptions<2>;
401 
402 /// `ConvTransposeFuncOptions` specialized for
403 /// `torch::nn::functional::conv_transpose3d`.
404 ///
405 /// Example:
406 /// ```
407 /// namespace F = torch::nn::functional;
408 /// F::conv_transpose3d(x, weight, F::ConvTranspose3dFuncOptions().stride(1));
409 /// ```
410 using ConvTranspose3dFuncOptions = ConvTransposeFuncOptions<3>;
411 
412 } // namespace functional
413 
414 } // namespace nn
415 } // namespace torch
416