xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/options/dropout.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/arg.h>
4 #include <torch/csrc/Export.h>
5 #include <torch/types.h>
6 
7 namespace torch {
8 namespace nn {
9 
10 /// Options for the `Dropout` module.
11 ///
12 /// Example:
13 /// ```
14 /// Dropout model(DropoutOptions().p(0.42).inplace(true));
15 /// ```
16 struct TORCH_API DropoutOptions {
17   /* implicit */ DropoutOptions(double p = 0.5);
18 
19   /// The probability of an element to be zeroed. Default: 0.5
20   TORCH_ARG(double, p) = 0.5;
21 
22   /// can optionally do the operation in-place. Default: False
23   TORCH_ARG(bool, inplace) = false;
24 };
25 
26 /// Options for the `Dropout2d` module.
27 ///
28 /// Example:
29 /// ```
30 /// Dropout2d model(Dropout2dOptions().p(0.42).inplace(true));
31 /// ```
32 using Dropout2dOptions = DropoutOptions;
33 
34 /// Options for the `Dropout3d` module.
35 ///
36 /// Example:
37 /// ```
38 /// Dropout3d model(Dropout3dOptions().p(0.42).inplace(true));
39 /// ```
40 using Dropout3dOptions = DropoutOptions;
41 
42 /// Options for the `AlphaDropout` module.
43 ///
44 /// Example:
45 /// ```
46 /// AlphaDropout model(AlphaDropoutOptions(0.2).inplace(true));
47 /// ```
48 using AlphaDropoutOptions = DropoutOptions;
49 
50 /// Options for the `FeatureAlphaDropout` module.
51 ///
52 /// Example:
53 /// ```
54 /// FeatureAlphaDropout model(FeatureAlphaDropoutOptions(0.2).inplace(true));
55 /// ```
56 using FeatureAlphaDropoutOptions = DropoutOptions;
57 
58 namespace functional {
59 
60 /// Options for `torch::nn::functional::dropout`.
61 ///
62 /// Example:
63 /// ```
64 /// namespace F = torch::nn::functional;
65 /// F::dropout(input, F::DropoutFuncOptions().p(0.5));
66 /// ```
67 struct TORCH_API DropoutFuncOptions {
68   /// The probability of an element to be zeroed. Default: 0.5
69   TORCH_ARG(double, p) = 0.5;
70 
71   TORCH_ARG(bool, training) = true;
72 
73   /// can optionally do the operation in-place. Default: False
74   TORCH_ARG(bool, inplace) = false;
75 };
76 
77 /// Options for `torch::nn::functional::dropout2d`.
78 ///
79 /// Example:
80 /// ```
81 /// namespace F = torch::nn::functional;
82 /// F::dropout2d(input, F::Dropout2dFuncOptions().p(0.5));
83 /// ```
84 using Dropout2dFuncOptions = DropoutFuncOptions;
85 
86 /// Options for `torch::nn::functional::dropout3d`.
87 ///
88 /// Example:
89 /// ```
90 /// namespace F = torch::nn::functional;
91 /// F::dropout3d(input, F::Dropout3dFuncOptions().p(0.5));
92 /// ```
93 using Dropout3dFuncOptions = DropoutFuncOptions;
94 
95 /// Options for `torch::nn::functional::alpha_dropout`.
96 ///
97 /// Example:
98 /// ```
99 /// namespace F = torch::nn::functional;
100 /// F::alpha_dropout(input,
101 /// F::AlphaDropoutFuncOptions().p(0.5).training(false));
102 /// ```
103 struct TORCH_API AlphaDropoutFuncOptions {
104   TORCH_ARG(double, p) = 0.5;
105 
106   TORCH_ARG(bool, training) = false;
107 
108   TORCH_ARG(bool, inplace) = false;
109 };
110 
111 /// Options for `torch::nn::functional::feature_alpha_dropout`.
112 ///
113 /// Example:
114 /// ```
115 /// namespace F = torch::nn::functional;
116 /// F::feature_alpha_dropout(input,
117 /// F::FeatureAlphaDropoutFuncOptions().p(0.5).training(false));
118 /// ```
119 struct TORCH_API FeatureAlphaDropoutFuncOptions {
120   TORCH_ARG(double, p) = 0.5;
121 
122   TORCH_ARG(bool, training) = false;
123 
124   TORCH_ARG(bool, inplace) = false;
125 };
126 
127 } // namespace functional
128 
129 } // namespace nn
130 } // namespace torch
131