xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/modules/dropout.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/nn/cloneable.h>
4 #include <torch/nn/options/dropout.h>
5 #include <torch/nn/pimpl.h>
6 #include <torch/types.h>
7 
8 #include <torch/csrc/Export.h>
9 
10 #include <cstddef>
11 #include <vector>
12 
13 namespace torch {
14 namespace nn {
15 
16 namespace detail {
17 
18 template <typename Derived>
19 class _DropoutNd : public torch::nn::Cloneable<Derived> {
20  public:
_DropoutNd(double p)21   _DropoutNd(double p) : _DropoutNd(DropoutOptions().p(p)){};
22 
options(options_)23   explicit _DropoutNd(const DropoutOptions& options_ = {}) : options(options_) {
24     // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
25     reset();
26   }
27 
reset()28   void reset() override {
29     TORCH_CHECK(
30         options.p() >= 0. && options.p() <= 1.,
31         "dropout probability has to be between 0 and 1, but got ",
32         options.p());
33   }
34 
35   /// The options with which this `Module` was constructed.
36   DropoutOptions options;
37 };
38 
39 } // namespace detail
40 
41 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
42 
43 /// Applies dropout over a 1-D input.
44 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Dropout to learn
45 /// about the exact behavior of this module.
46 ///
47 /// See the documentation for `torch::nn::DropoutOptions` class to learn what
48 /// constructor arguments are supported for this module.
49 ///
50 /// Example:
51 /// ```
52 /// Dropout model(DropoutOptions().p(0.42).inplace(true));
53 /// ```
54 class TORCH_API DropoutImpl : public detail::_DropoutNd<DropoutImpl> {
55  public:
56   using detail::_DropoutNd<DropoutImpl>::_DropoutNd;
57 
58   Tensor forward(Tensor input);
59 
60   /// Pretty prints the `Dropout` module into the given `stream`.
61   void pretty_print(std::ostream& stream) const override;
62 };
63 
64 /// A `ModuleHolder` subclass for `DropoutImpl`.
65 /// See the documentation for `DropoutImpl` class to learn what methods it
66 /// provides, and examples of how to use `Dropout` with
67 /// `torch::nn::DropoutOptions`. See the documentation for `ModuleHolder` to
68 /// learn about PyTorch's module storage semantics.
69 TORCH_MODULE(Dropout);
70 
71 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
72 
73 /// Applies dropout over a 2-D input.
74 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Dropout2d to learn
75 /// about the exact behavior of this module.
76 ///
77 /// See the documentation for `torch::nn::Dropout2dOptions` class to learn what
78 /// constructor arguments are supported for this module.
79 ///
80 /// Example:
81 /// ```
82 /// Dropout2d model(Dropout2dOptions().p(0.42).inplace(true));
83 /// ```
84 class TORCH_API Dropout2dImpl : public detail::_DropoutNd<Dropout2dImpl> {
85  public:
86   using detail::_DropoutNd<Dropout2dImpl>::_DropoutNd;
87 
88   Tensor forward(Tensor input);
89 
90   /// Pretty prints the `Dropout2d` module into the given `stream`.
91   void pretty_print(std::ostream& stream) const override;
92 };
93 
94 /// A `ModuleHolder` subclass for `Dropout2dImpl`.
95 /// See the documentation for `Dropout2dImpl` class to learn what methods it
96 /// provides, and examples of how to use `Dropout2d` with
97 /// `torch::nn::Dropout2dOptions`. See the documentation for `ModuleHolder` to
98 /// learn about PyTorch's module storage semantics.
99 TORCH_MODULE(Dropout2d);
100 
101 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
102 
103 /// Applies dropout over a 3-D input.
104 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Dropout3d to learn
105 /// about the exact behavior of this module.
106 ///
107 /// See the documentation for `torch::nn::Dropout3dOptions` class to learn what
108 /// constructor arguments are supported for this module.
109 ///
110 /// Example:
111 /// ```
112 /// Dropout3d model(Dropout3dOptions().p(0.42).inplace(true));
113 /// ```
114 class TORCH_API Dropout3dImpl : public detail::_DropoutNd<Dropout3dImpl> {
115  public:
116   using detail::_DropoutNd<Dropout3dImpl>::_DropoutNd;
117 
118   Tensor forward(Tensor input);
119 
120   /// Pretty prints the `Dropout3d` module into the given `stream`.
121   void pretty_print(std::ostream& stream) const override;
122 };
123 
124 /// A `ModuleHolder` subclass for `Dropout3dImpl`.
125 /// See the documentation for `Dropout3dImpl` class to learn what methods it
126 /// provides, and examples of how to use `Dropout3d` with
127 /// `torch::nn::Dropout3dOptions`. See the documentation for `ModuleHolder` to
128 /// learn about PyTorch's module storage semantics.
129 TORCH_MODULE(Dropout3d);
130 
131 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AlphaDropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
132 
133 /// Applies Alpha Dropout over the input.
134 /// See https://pytorch.org/docs/main/nn.html#torch.nn.AlphaDropout to learn
135 /// about the exact behavior of this module.
136 ///
137 /// See the documentation for `torch::nn::AlphaDropoutOptions` class to learn
138 /// what constructor arguments are supported for this module.
139 ///
140 /// Example:
141 /// ```
142 /// AlphaDropout model(AlphaDropoutOptions(0.2).inplace(true));
143 /// ```
144 class TORCH_API AlphaDropoutImpl : public detail::_DropoutNd<AlphaDropoutImpl> {
145  public:
146   using detail::_DropoutNd<AlphaDropoutImpl>::_DropoutNd;
147 
148   Tensor forward(const Tensor& input);
149 
150   /// Pretty prints the `AlphaDropout` module into the given `stream`.
151   void pretty_print(std::ostream& stream) const override;
152 };
153 
154 /// A `ModuleHolder` subclass for `AlphaDropoutImpl`.
155 /// See the documentation for `AlphaDropoutImpl` class to learn what methods it
156 /// provides, and examples of how to use `AlphaDropout` with
157 /// `torch::nn::AlphaDropoutOptions`. See the documentation for `ModuleHolder`
158 /// to learn about PyTorch's module storage semantics.
159 TORCH_MODULE(AlphaDropout);
160 
161 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FeatureAlphaDropout
162 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
163 
164 /// See the documentation for `torch::nn::FeatureAlphaDropoutOptions` class to
165 /// learn what constructor arguments are supported for this module.
166 ///
167 /// Example:
168 /// ```
169 /// FeatureAlphaDropout model(FeatureAlphaDropoutOptions(0.2).inplace(true));
170 /// ```
171 class TORCH_API FeatureAlphaDropoutImpl
172     : public detail::_DropoutNd<FeatureAlphaDropoutImpl> {
173  public:
174   using detail::_DropoutNd<FeatureAlphaDropoutImpl>::_DropoutNd;
175 
176   Tensor forward(const Tensor& input);
177 
178   /// Pretty prints the `FeatureAlphaDropout` module into the given `stream`.
179   void pretty_print(std::ostream& stream) const override;
180 };
181 
182 /// A `ModuleHolder` subclass for `FeatureAlphaDropoutImpl`.
183 /// See the documentation for `FeatureAlphaDropoutImpl` class to learn what
184 /// methods it provides, and examples of how to use `FeatureAlphaDropout` with
185 /// `torch::nn::FeatureAlphaDropoutOptions`. See the documentation for
186 /// `ModuleHolder` to learn about PyTorch's module storage semantics.
187 TORCH_MODULE(FeatureAlphaDropout);
188 
189 } // namespace nn
190 } // namespace torch
191