1 #pragma once 2 3 #include <torch/arg.h> 4 #include <torch/csrc/Export.h> 5 #include <torch/expanding_array.h> 6 #include <torch/types.h> 7 8 namespace torch { 9 namespace nn { 10 11 /// Options for the `Fold` module. 12 /// 13 /// Example: 14 /// ``` 15 /// Fold model(FoldOptions({8, 8}, {3, 3}).dilation(2).padding({2, 16 /// 1}).stride(2)); 17 /// ``` 18 struct TORCH_API FoldOptions { FoldOptionsFoldOptions19 FoldOptions(ExpandingArray<2> output_size, ExpandingArray<2> kernel_size) 20 : output_size_(std::move(output_size)), 21 kernel_size_(std::move(kernel_size)) {} 22 23 /// describes the spatial shape of the large containing tensor of the sliding 24 /// local blocks. It is useful to resolve the ambiguity when multiple input 25 /// shapes map to same number of sliding blocks, e.g., with stride > 0. 26 TORCH_ARG(ExpandingArray<2>, output_size); 27 28 /// the size of the sliding blocks 29 TORCH_ARG(ExpandingArray<2>, kernel_size); 30 31 /// controls the spacing between the kernel points; also known as the à trous 32 /// algorithm. 33 TORCH_ARG(ExpandingArray<2>, dilation) = 1; 34 35 /// controls the amount of implicit zero-paddings on both sides for padding 36 /// number of points for each dimension before reshaping. 37 TORCH_ARG(ExpandingArray<2>, padding) = 0; 38 39 /// controls the stride for the sliding blocks. 40 TORCH_ARG(ExpandingArray<2>, stride) = 1; 41 }; 42 43 namespace functional { 44 /// Options for `torch::nn::functional::fold`. 45 /// 46 /// See the documentation for `torch::nn::FoldOptions` class to learn what 47 /// arguments are supported. 48 /// 49 /// Example: 50 /// ``` 51 /// namespace F = torch::nn::functional; 52 /// F::fold(input, F::FoldFuncOptions({3, 2}, {2, 2})); 53 /// ``` 54 using FoldFuncOptions = FoldOptions; 55 } // namespace functional 56 57 // ============================================================================ 58 59 /// Options for the `Unfold` module. 60 /// 61 /// Example: 62 /// ``` 63 /// Unfold model(UnfoldOptions({2, 4}).dilation(2).padding({2, 1}).stride(2)); 64 /// ``` 65 struct TORCH_API UnfoldOptions { UnfoldOptionsUnfoldOptions66 UnfoldOptions(ExpandingArray<2> kernel_size) 67 : kernel_size_(std::move(kernel_size)) {} 68 69 /// the size of the sliding blocks 70 TORCH_ARG(ExpandingArray<2>, kernel_size); 71 72 /// controls the spacing between the kernel points; also known as the à trous 73 /// algorithm. 74 TORCH_ARG(ExpandingArray<2>, dilation) = 1; 75 76 /// controls the amount of implicit zero-paddings on both sides for padding 77 /// number of points for each dimension before reshaping. 78 TORCH_ARG(ExpandingArray<2>, padding) = 0; 79 80 /// controls the stride for the sliding blocks. 81 TORCH_ARG(ExpandingArray<2>, stride) = 1; 82 }; 83 84 namespace functional { 85 /// Options for `torch::nn::functional::unfold`. 86 /// 87 /// See the documentation for `torch::nn::UnfoldOptions` class to learn what 88 /// arguments are supported. 89 /// 90 /// Example: 91 /// ``` 92 /// namespace F = torch::nn::functional; 93 /// F::unfold(input, F::UnfoldFuncOptions({2, 2}).padding(1).stride(2)); 94 /// ``` 95 using UnfoldFuncOptions = UnfoldOptions; 96 } // namespace functional 97 98 } // namespace nn 99 } // namespace torch 100