xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/options/fold.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/arg.h>
4 #include <torch/csrc/Export.h>
5 #include <torch/expanding_array.h>
6 #include <torch/types.h>
7 
8 namespace torch {
9 namespace nn {
10 
11 /// Options for the `Fold` module.
12 ///
13 /// Example:
14 /// ```
15 /// Fold model(FoldOptions({8, 8}, {3, 3}).dilation(2).padding({2,
16 /// 1}).stride(2));
17 /// ```
18 struct TORCH_API FoldOptions {
FoldOptionsFoldOptions19   FoldOptions(ExpandingArray<2> output_size, ExpandingArray<2> kernel_size)
20       : output_size_(std::move(output_size)),
21         kernel_size_(std::move(kernel_size)) {}
22 
23   /// describes the spatial shape of the large containing tensor of the sliding
24   /// local blocks. It is useful to resolve the ambiguity when multiple input
25   /// shapes map to same number of sliding blocks, e.g., with stride > 0.
26   TORCH_ARG(ExpandingArray<2>, output_size);
27 
28   /// the size of the sliding blocks
29   TORCH_ARG(ExpandingArray<2>, kernel_size);
30 
31   /// controls the spacing between the kernel points; also known as the à trous
32   /// algorithm.
33   TORCH_ARG(ExpandingArray<2>, dilation) = 1;
34 
35   /// controls the amount of implicit zero-paddings on both sides for padding
36   /// number of points for each dimension before reshaping.
37   TORCH_ARG(ExpandingArray<2>, padding) = 0;
38 
39   /// controls the stride for the sliding blocks.
40   TORCH_ARG(ExpandingArray<2>, stride) = 1;
41 };
42 
43 namespace functional {
44 /// Options for `torch::nn::functional::fold`.
45 ///
46 /// See the documentation for `torch::nn::FoldOptions` class to learn what
47 /// arguments are supported.
48 ///
49 /// Example:
50 /// ```
51 /// namespace F = torch::nn::functional;
52 /// F::fold(input, F::FoldFuncOptions({3, 2}, {2, 2}));
53 /// ```
54 using FoldFuncOptions = FoldOptions;
55 } // namespace functional
56 
57 // ============================================================================
58 
59 /// Options for the `Unfold` module.
60 ///
61 /// Example:
62 /// ```
63 /// Unfold model(UnfoldOptions({2, 4}).dilation(2).padding({2, 1}).stride(2));
64 /// ```
65 struct TORCH_API UnfoldOptions {
UnfoldOptionsUnfoldOptions66   UnfoldOptions(ExpandingArray<2> kernel_size)
67       : kernel_size_(std::move(kernel_size)) {}
68 
69   /// the size of the sliding blocks
70   TORCH_ARG(ExpandingArray<2>, kernel_size);
71 
72   /// controls the spacing between the kernel points; also known as the à trous
73   /// algorithm.
74   TORCH_ARG(ExpandingArray<2>, dilation) = 1;
75 
76   /// controls the amount of implicit zero-paddings on both sides for padding
77   /// number of points for each dimension before reshaping.
78   TORCH_ARG(ExpandingArray<2>, padding) = 0;
79 
80   /// controls the stride for the sliding blocks.
81   TORCH_ARG(ExpandingArray<2>, stride) = 1;
82 };
83 
84 namespace functional {
85 /// Options for `torch::nn::functional::unfold`.
86 ///
87 /// See the documentation for `torch::nn::UnfoldOptions` class to learn what
88 /// arguments are supported.
89 ///
90 /// Example:
91 /// ```
92 /// namespace F = torch::nn::functional;
93 /// F::unfold(input, F::UnfoldFuncOptions({2, 2}).padding(1).stride(2));
94 /// ```
95 using UnfoldFuncOptions = UnfoldOptions;
96 } // namespace functional
97 
98 } // namespace nn
99 } // namespace torch
100