xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/modules/dropout.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/nn/functional/dropout.h>
2 #include <torch/nn/modules/dropout.h>
3 
4 #include <torch/types.h>
5 
6 #include <c10/util/Exception.h>
7 
8 #include <cstddef>
9 #include <ostream>
10 #include <vector>
11 
12 namespace F = torch::nn::functional;
13 
14 namespace torch {
15 namespace nn {
16 
forward(Tensor input)17 Tensor DropoutImpl::forward(Tensor input) {
18   return F::detail::dropout(
19       input, options.p(), is_training(), options.inplace());
20 }
21 
pretty_print(std::ostream & stream) const22 void DropoutImpl::pretty_print(std::ostream& stream) const {
23   stream << std::boolalpha << "torch::nn::Dropout(p=" << options.p()
24          << ", inplace=" << options.inplace() << ")";
25 }
26 
27 // ============================================================================
28 
forward(Tensor input)29 Tensor Dropout2dImpl::forward(Tensor input) {
30   return F::detail::dropout2d(
31       input, options.p(), is_training(), options.inplace());
32 }
33 
pretty_print(std::ostream & stream) const34 void Dropout2dImpl::pretty_print(std::ostream& stream) const {
35   stream << std::boolalpha << "torch::nn::Dropout2d(p=" << options.p()
36          << ", inplace=" << options.inplace() << ")";
37 }
38 
39 // ============================================================================
40 
forward(Tensor input)41 Tensor Dropout3dImpl::forward(Tensor input) {
42   return F::detail::dropout3d(
43       input, options.p(), is_training(), options.inplace());
44 }
45 
pretty_print(std::ostream & stream) const46 void Dropout3dImpl::pretty_print(std::ostream& stream) const {
47   stream << std::boolalpha << "torch::nn::Dropout3d(p=" << options.p()
48          << ", inplace=" << options.inplace() << ")";
49 }
50 
51 // ============================================================================
52 
/// Applies alpha dropout to `input` using this module's configured
/// probability and the module's current training state.
///
/// NOTE(review): the functional is always called with inplace=false, so
/// `options.inplace()` has no effect on this module even though it is part of
/// its options and shows up in its printed summary. Presumably this is for
/// parity with the Python `nn.AlphaDropout` module — confirm before changing.
Tensor AlphaDropoutImpl::forward(const Tensor& input) {
  const auto drop_prob = options.p();
  const auto training_mode = is_training();
  constexpr bool kInplace = false;
  return F::detail::alpha_dropout(input, drop_prob, training_mode, kInplace);
}
57 
pretty_print(std::ostream & stream) const58 void AlphaDropoutImpl::pretty_print(std::ostream& stream) const {
59   stream << std::boolalpha << "torch::nn::AlphaDropout(p=" << options.p()
60          << ", inplace=" << options.inplace() << ")";
61 }
62 
63 // ============================================================================
64 
/// Applies feature alpha dropout to `input` using this module's configured
/// probability and the module's current training state.
///
/// NOTE(review): the functional is always called with inplace=false, so
/// `options.inplace()` has no effect on this module even though it is part of
/// its options and shows up in its printed summary. Presumably this is for
/// parity with the Python `nn.FeatureAlphaDropout` module — confirm before
/// changing.
Tensor FeatureAlphaDropoutImpl::forward(const Tensor& input) {
  const auto drop_prob = options.p();
  const auto training_mode = is_training();
  constexpr bool kInplace = false;
  return F::detail::feature_alpha_dropout(
      input, drop_prob, training_mode, kInplace);
}
69 
pretty_print(std::ostream & stream) const70 void FeatureAlphaDropoutImpl::pretty_print(std::ostream& stream) const {
71   stream << std::boolalpha << "torch::nn::FeatureAlphaDropout(p=" << options.p()
72          << ", inplace=" << options.inplace() << ")";
73 }
74 
75 } // namespace nn
76 } // namespace torch
77