1 #pragma once 2 3 #include <torch/nn/cloneable.h> 4 #include <torch/nn/options/dropout.h> 5 #include <torch/nn/pimpl.h> 6 #include <torch/types.h> 7 8 #include <torch/csrc/Export.h> 9 10 #include <cstddef> 11 #include <vector> 12 13 namespace torch { 14 namespace nn { 15 16 namespace detail { 17 18 template <typename Derived> 19 class _DropoutNd : public torch::nn::Cloneable<Derived> { 20 public: _DropoutNd(double p)21 _DropoutNd(double p) : _DropoutNd(DropoutOptions().p(p)){}; 22 options(options_)23 explicit _DropoutNd(const DropoutOptions& options_ = {}) : options(options_) { 24 // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) 25 reset(); 26 } 27 reset()28 void reset() override { 29 TORCH_CHECK( 30 options.p() >= 0. && options.p() <= 1., 31 "dropout probability has to be between 0 and 1, but got ", 32 options.p()); 33 } 34 35 /// The options with which this `Module` was constructed. 36 DropoutOptions options; 37 }; 38 39 } // namespace detail 40 41 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 42 43 /// Applies dropout over a 1-D input. 44 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Dropout to learn 45 /// about the exact behavior of this module. 46 /// 47 /// See the documentation for `torch::nn::DropoutOptions` class to learn what 48 /// constructor arguments are supported for this module. 49 /// 50 /// Example: 51 /// ``` 52 /// Dropout model(DropoutOptions().p(0.42).inplace(true)); 53 /// ``` 54 class TORCH_API DropoutImpl : public detail::_DropoutNd<DropoutImpl> { 55 public: 56 using detail::_DropoutNd<DropoutImpl>::_DropoutNd; 57 58 Tensor forward(Tensor input); 59 60 /// Pretty prints the `Dropout` module into the given `stream`. 61 void pretty_print(std::ostream& stream) const override; 62 }; 63 64 /// A `ModuleHolder` subclass for `DropoutImpl`. 65 /// See the documentation for `DropoutImpl` class to learn what methods it 66 /// provides, and examples of how to use `Dropout` with 67 /// `torch::nn::DropoutOptions`. See the documentation for `ModuleHolder` to 68 /// learn about PyTorch's module storage semantics. 69 TORCH_MODULE(Dropout); 70 71 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 72 73 /// Applies dropout over a 2-D input. 74 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Dropout2d to learn 75 /// about the exact behavior of this module. 76 /// 77 /// See the documentation for `torch::nn::Dropout2dOptions` class to learn what 78 /// constructor arguments are supported for this module. 79 /// 80 /// Example: 81 /// ``` 82 /// Dropout2d model(Dropout2dOptions().p(0.42).inplace(true)); 83 /// ``` 84 class TORCH_API Dropout2dImpl : public detail::_DropoutNd<Dropout2dImpl> { 85 public: 86 using detail::_DropoutNd<Dropout2dImpl>::_DropoutNd; 87 88 Tensor forward(Tensor input); 89 90 /// Pretty prints the `Dropout2d` module into the given `stream`. 91 void pretty_print(std::ostream& stream) const override; 92 }; 93 94 /// A `ModuleHolder` subclass for `Dropout2dImpl`. 95 /// See the documentation for `Dropout2dImpl` class to learn what methods it 96 /// provides, and examples of how to use `Dropout2d` with 97 /// `torch::nn::Dropout2dOptions`. See the documentation for `ModuleHolder` to 98 /// learn about PyTorch's module storage semantics. 99 TORCH_MODULE(Dropout2d); 100 101 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Dropout3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 102 103 /// Applies dropout over a 3-D input. 104 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Dropout3d to learn 105 /// about the exact behavior of this module. 106 /// 107 /// See the documentation for `torch::nn::Dropout3dOptions` class to learn what 108 /// constructor arguments are supported for this module. 109 /// 110 /// Example: 111 /// ``` 112 /// Dropout3d model(Dropout3dOptions().p(0.42).inplace(true)); 113 /// ``` 114 class TORCH_API Dropout3dImpl : public detail::_DropoutNd<Dropout3dImpl> { 115 public: 116 using detail::_DropoutNd<Dropout3dImpl>::_DropoutNd; 117 118 Tensor forward(Tensor input); 119 120 /// Pretty prints the `Dropout3d` module into the given `stream`. 121 void pretty_print(std::ostream& stream) const override; 122 }; 123 124 /// A `ModuleHolder` subclass for `Dropout3dImpl`. 125 /// See the documentation for `Dropout3dImpl` class to learn what methods it 126 /// provides, and examples of how to use `Dropout3d` with 127 /// `torch::nn::Dropout3dOptions`. See the documentation for `ModuleHolder` to 128 /// learn about PyTorch's module storage semantics. 129 TORCH_MODULE(Dropout3d); 130 131 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AlphaDropout ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 132 133 /// Applies Alpha Dropout over the input. 134 /// See https://pytorch.org/docs/main/nn.html#torch.nn.AlphaDropout to learn 135 /// about the exact behavior of this module. 136 /// 137 /// See the documentation for `torch::nn::AlphaDropoutOptions` class to learn 138 /// what constructor arguments are supported for this module. 139 /// 140 /// Example: 141 /// ``` 142 /// AlphaDropout model(AlphaDropoutOptions(0.2).inplace(true)); 143 /// ``` 144 class TORCH_API AlphaDropoutImpl : public detail::_DropoutNd<AlphaDropoutImpl> { 145 public: 146 using detail::_DropoutNd<AlphaDropoutImpl>::_DropoutNd; 147 148 Tensor forward(const Tensor& input); 149 150 /// Pretty prints the `AlphaDropout` module into the given `stream`. 151 void pretty_print(std::ostream& stream) const override; 152 }; 153 154 /// A `ModuleHolder` subclass for `AlphaDropoutImpl`. 155 /// See the documentation for `AlphaDropoutImpl` class to learn what methods it 156 /// provides, and examples of how to use `AlphaDropout` with 157 /// `torch::nn::AlphaDropoutOptions`. See the documentation for `ModuleHolder` 158 /// to learn about PyTorch's module storage semantics. 159 TORCH_MODULE(AlphaDropout); 160 161 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FeatureAlphaDropout 162 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 163 164 /// See the documentation for `torch::nn::FeatureAlphaDropoutOptions` class to 165 /// learn what constructor arguments are supported for this module. 166 /// 167 /// Example: 168 /// ``` 169 /// FeatureAlphaDropout model(FeatureAlphaDropoutOptions(0.2).inplace(true)); 170 /// ``` 171 class TORCH_API FeatureAlphaDropoutImpl 172 : public detail::_DropoutNd<FeatureAlphaDropoutImpl> { 173 public: 174 using detail::_DropoutNd<FeatureAlphaDropoutImpl>::_DropoutNd; 175 176 Tensor forward(const Tensor& input); 177 178 /// Pretty prints the `FeatureAlphaDropout` module into the given `stream`. 179 void pretty_print(std::ostream& stream) const override; 180 }; 181 182 /// A `ModuleHolder` subclass for `FeatureAlphaDropoutImpl`. 183 /// See the documentation for `FeatureAlphaDropoutImpl` class to learn what 184 /// methods it provides, and examples of how to use `FeatureAlphaDropout` with 185 /// `torch::nn::FeatureAlphaDropoutOptions`. See the documentation for 186 /// `ModuleHolder` to learn about PyTorch's module storage semantics. 187 TORCH_MODULE(FeatureAlphaDropout); 188 189 } // namespace nn 190 } // namespace torch 191