1 #pragma once 2 3 #include <torch/expanding_array.h> 4 #include <torch/nn/cloneable.h> 5 #include <torch/nn/functional/fold.h> 6 #include <torch/nn/options/fold.h> 7 #include <torch/nn/pimpl.h> 8 #include <torch/types.h> 9 10 namespace torch { 11 namespace nn { 12 13 /// Applies fold over a 3-D input. 14 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Fold to learn about 15 /// the exact behavior of this module. 16 /// 17 /// See the documentation for `torch::nn::FoldOptions` class to learn what 18 /// constructor arguments are supported for this module. 19 /// 20 /// Example: 21 /// ``` 22 /// Fold model(FoldOptions({8, 8}, {3, 3}).dilation(2).padding({2, 23 /// 1}).stride(2)); 24 /// ``` 25 class TORCH_API FoldImpl : public torch::nn::Cloneable<FoldImpl> { 26 public: FoldImpl(ExpandingArray<2> output_size,ExpandingArray<2> kernel_size)27 FoldImpl(ExpandingArray<2> output_size, ExpandingArray<2> kernel_size) 28 : FoldImpl(FoldOptions(output_size, kernel_size)) {} 29 explicit FoldImpl(const FoldOptions& options_); 30 31 void reset() override; 32 33 /// Pretty prints the `Fold` module into the given `stream`. 34 void pretty_print(std::ostream& stream) const override; 35 36 Tensor forward(const Tensor& input); 37 38 /// The options with which this `Module` was constructed. 39 FoldOptions options; 40 }; 41 42 /// A `ModuleHolder` subclass for `FoldImpl`. 43 /// See the documentation for `FoldImpl` class to learn what methods it 44 /// provides, and examples of how to use `Fold` with `torch::nn::FoldOptions`. 45 /// See the documentation for `ModuleHolder` to learn about PyTorch's 46 /// module storage semantics. 47 TORCH_MODULE(Fold); 48 49 // ============================================================================ 50 51 /// Applies unfold over a 4-D input. 52 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Unfold to learn about 53 /// the exact behavior of this module. 54 /// 55 /// See the documentation for `torch::nn::UnfoldOptions` class to learn what 56 /// constructor arguments are supported for this module. 57 /// 58 /// Example: 59 /// ``` 60 /// Unfold model(UnfoldOptions({2, 4}).dilation(2).padding({2, 1}).stride(2)); 61 /// ``` 62 class TORCH_API UnfoldImpl : public Cloneable<UnfoldImpl> { 63 public: UnfoldImpl(ExpandingArray<2> kernel_size)64 UnfoldImpl(ExpandingArray<2> kernel_size) 65 : UnfoldImpl(UnfoldOptions(kernel_size)) {} 66 explicit UnfoldImpl(const UnfoldOptions& options_); 67 68 void reset() override; 69 70 /// Pretty prints the `Unfold` module into the given `stream`. 71 void pretty_print(std::ostream& stream) const override; 72 73 Tensor forward(const Tensor& input); 74 75 /// The options with which this `Module` was constructed. 76 UnfoldOptions options; 77 }; 78 79 /// A `ModuleHolder` subclass for `UnfoldImpl`. 80 /// See the documentation for `UnfoldImpl` class to learn what methods it 81 /// provides, and examples of how to use `Unfold` with 82 /// `torch::nn::UnfoldOptions`. See the documentation for `ModuleHolder` to 83 /// learn about PyTorch's module storage semantics. 84 TORCH_MODULE(Unfold); 85 86 } // namespace nn 87 } // namespace torch 88