#include <torch/nn/functional/dropout.h>
#include <torch/nn/modules/dropout.h>

#include <torch/types.h>

#include <c10/util/Exception.h>

#include <cstddef>
#include <ostream>
#include <utility>
#include <vector>
11
12 namespace F = torch::nn::functional;
13
14 namespace torch {
15 namespace nn {
16
forward(Tensor input)17 Tensor DropoutImpl::forward(Tensor input) {
18 return F::detail::dropout(
19 input, options.p(), is_training(), options.inplace());
20 }
21
pretty_print(std::ostream & stream) const22 void DropoutImpl::pretty_print(std::ostream& stream) const {
23 stream << std::boolalpha << "torch::nn::Dropout(p=" << options.p()
24 << ", inplace=" << options.inplace() << ")";
25 }
26
27 // ============================================================================
28
forward(Tensor input)29 Tensor Dropout2dImpl::forward(Tensor input) {
30 return F::detail::dropout2d(
31 input, options.p(), is_training(), options.inplace());
32 }
33
pretty_print(std::ostream & stream) const34 void Dropout2dImpl::pretty_print(std::ostream& stream) const {
35 stream << std::boolalpha << "torch::nn::Dropout2d(p=" << options.p()
36 << ", inplace=" << options.inplace() << ")";
37 }
38
39 // ============================================================================
40
forward(Tensor input)41 Tensor Dropout3dImpl::forward(Tensor input) {
42 return F::detail::dropout3d(
43 input, options.p(), is_training(), options.inplace());
44 }
45
pretty_print(std::ostream & stream) const46 void Dropout3dImpl::pretty_print(std::ostream& stream) const {
47 stream << std::boolalpha << "torch::nn::Dropout3d(p=" << options.p()
48 << ", inplace=" << options.inplace() << ")";
49 }
50
51 // ============================================================================
52
/// Runs alpha dropout on `input` by forwarding to `F::detail::alpha_dropout`
/// with the configured probability `options.p()` and the module's current
/// training mode (`is_training()`).
///
/// NOTE(review): `options.inplace()` is never consulted here. The functional
/// call is always made with `inplace == false`, even if the option is set.
/// Confirm this is intended before relying on the `inplace` option for this
/// module.
Tensor AlphaDropoutImpl::forward(const Tensor& input) {
  const auto drop_probability = options.p();
  const bool training_mode = is_training();
  constexpr bool kRunInplace = false;
  return F::detail::alpha_dropout(
      input, drop_probability, training_mode, kRunInplace);
}
57
pretty_print(std::ostream & stream) const58 void AlphaDropoutImpl::pretty_print(std::ostream& stream) const {
59 stream << std::boolalpha << "torch::nn::AlphaDropout(p=" << options.p()
60 << ", inplace=" << options.inplace() << ")";
61 }
62
63 // ============================================================================
64
/// Runs feature alpha dropout on `input` by forwarding to
/// `F::detail::feature_alpha_dropout` with the configured probability
/// `options.p()` and the module's current training mode (`is_training()`).
///
/// NOTE(review): `options.inplace()` is never consulted here. The functional
/// call is always made with `inplace == false`, even if the option is set.
/// Confirm this is intended before relying on the `inplace` option for this
/// module.
Tensor FeatureAlphaDropoutImpl::forward(const Tensor& input) {
  const auto probability = options.p();
  const bool in_training = is_training();
  constexpr bool kInplace = false;
  return F::detail::feature_alpha_dropout(
      input, probability, in_training, kInplace);
}
69
pretty_print(std::ostream & stream) const70 void FeatureAlphaDropoutImpl::pretty_print(std::ostream& stream) const {
71 stream << std::boolalpha << "torch::nn::FeatureAlphaDropout(p=" << options.p()
72 << ", inplace=" << options.inplace() << ")";
73 }
74
75 } // namespace nn
76 } // namespace torch
77