// Source: external/pytorch/torch/csrc/api/include/torch/nn/modules/fold.h (AOSP xref aosp_15_r20, revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/expanding_array.h>
4 #include <torch/nn/cloneable.h>
5 #include <torch/nn/functional/fold.h>
6 #include <torch/nn/options/fold.h>
7 #include <torch/nn/pimpl.h>
8 #include <torch/types.h>
9 
10 namespace torch {
11 namespace nn {
12 
13 /// Applies fold over a 3-D input.
14 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Fold to learn about
15 /// the exact behavior of this module.
16 ///
17 /// See the documentation for `torch::nn::FoldOptions` class to learn what
18 /// constructor arguments are supported for this module.
19 ///
20 /// Example:
21 /// ```
22 /// Fold model(FoldOptions({8, 8}, {3, 3}).dilation(2).padding({2,
23 /// 1}).stride(2));
24 /// ```
25 class TORCH_API FoldImpl : public torch::nn::Cloneable<FoldImpl> {
26  public:
FoldImpl(ExpandingArray<2> output_size,ExpandingArray<2> kernel_size)27   FoldImpl(ExpandingArray<2> output_size, ExpandingArray<2> kernel_size)
28       : FoldImpl(FoldOptions(output_size, kernel_size)) {}
29   explicit FoldImpl(const FoldOptions& options_);
30 
31   void reset() override;
32 
33   /// Pretty prints the `Fold` module into the given `stream`.
34   void pretty_print(std::ostream& stream) const override;
35 
36   Tensor forward(const Tensor& input);
37 
38   /// The options with which this `Module` was constructed.
39   FoldOptions options;
40 };
41 
42 /// A `ModuleHolder` subclass for `FoldImpl`.
43 /// See the documentation for `FoldImpl` class to learn what methods it
44 /// provides, and examples of how to use `Fold` with `torch::nn::FoldOptions`.
45 /// See the documentation for `ModuleHolder` to learn about PyTorch's
46 /// module storage semantics.
47 TORCH_MODULE(Fold);
48 
49 // ============================================================================
50 
51 /// Applies unfold over a 4-D input.
52 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Unfold to learn about
53 /// the exact behavior of this module.
54 ///
55 /// See the documentation for `torch::nn::UnfoldOptions` class to learn what
56 /// constructor arguments are supported for this module.
57 ///
58 /// Example:
59 /// ```
60 /// Unfold model(UnfoldOptions({2, 4}).dilation(2).padding({2, 1}).stride(2));
61 /// ```
62 class TORCH_API UnfoldImpl : public Cloneable<UnfoldImpl> {
63  public:
UnfoldImpl(ExpandingArray<2> kernel_size)64   UnfoldImpl(ExpandingArray<2> kernel_size)
65       : UnfoldImpl(UnfoldOptions(kernel_size)) {}
66   explicit UnfoldImpl(const UnfoldOptions& options_);
67 
68   void reset() override;
69 
70   /// Pretty prints the `Unfold` module into the given `stream`.
71   void pretty_print(std::ostream& stream) const override;
72 
73   Tensor forward(const Tensor& input);
74 
75   /// The options with which this `Module` was constructed.
76   UnfoldOptions options;
77 };
78 
79 /// A `ModuleHolder` subclass for `UnfoldImpl`.
80 /// See the documentation for `UnfoldImpl` class to learn what methods it
81 /// provides, and examples of how to use `Unfold` with
82 /// `torch::nn::UnfoldOptions`. See the documentation for `ModuleHolder` to
83 /// learn about PyTorch's module storage semantics.
84 TORCH_MODULE(Unfold);
85 
86 } // namespace nn
87 } // namespace torch
88