// xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/options/normalization.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/arg.h>
4 #include <torch/csrc/Export.h>
5 #include <torch/types.h>
6 #include <vector>
7 
8 namespace torch {
9 namespace nn {
10 
/// Options for the `LayerNorm` module.
///
/// Example:
/// ```
/// LayerNorm model(LayerNormOptions({2, 2}).elementwise_affine(false).eps(2e-5));
/// ```
struct TORCH_API LayerNormOptions {
  /// Constructs options from the shape to normalize over. Deliberately not
  /// `explicit`, so a `std::vector<int64_t>` converts implicitly to
  /// `LayerNormOptions`. Defined out of line.
  /* implicit */ LayerNormOptions(std::vector<int64_t> normalized_shape);
  /// Input shape from an expected input. Per PyTorch's `LayerNorm` docs this
  /// is the trailing part of the input's shape that gets normalized
  /// (assumption from upstream docs — confirm against the implementation).
  TORCH_ARG(std::vector<int64_t>, normalized_shape);
  /// A value added to the denominator for numerical stability.
  /// ``Default: 1e-5``.
  TORCH_ARG(double, eps) = 1e-5;
  /// A boolean value that, when set to ``true``, gives this module
  /// learnable per-element affine parameters initialized to ones (for
  /// weights) and zeros (for biases). ``Default: true``.
  TORCH_ARG(bool, elementwise_affine) = true;
};
30 
31 // ============================================================================
32 
namespace functional {

/// Options for `torch::nn::functional::layer_norm`.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::layer_norm(input, F::LayerNormFuncOptions({2, 2}).eps(2e-5));
/// ```
struct TORCH_API LayerNormFuncOptions {
  /// Constructs options from the shape to normalize over. Deliberately not
  /// `explicit`, so a `std::vector<int64_t>` converts implicitly. Defined
  /// out of line.
  /* implicit */ LayerNormFuncOptions(std::vector<int64_t> normalized_shape);
  /// Input shape from an expected input (presumably the trailing dimensions
  /// to normalize over, as in PyTorch's `layer_norm` — confirm).
  TORCH_ARG(std::vector<int64_t>, normalized_shape);

  /// Optional weight tensor. Defaults to `{}` (a default-constructed
  /// `Tensor`); presumably that means no element-wise scale is applied —
  /// confirm against `layer_norm`.
  TORCH_ARG(Tensor, weight) = {};

  /// Optional bias tensor. Defaults to `{}` (a default-constructed
  /// `Tensor`); presumably that means no element-wise shift is applied —
  /// confirm against `layer_norm`.
  TORCH_ARG(Tensor, bias) = {};

  /// A value added to the denominator for numerical stability.
  /// ``Default: 1e-5``.
  TORCH_ARG(double, eps) = 1e-5;
};

} // namespace functional
57 
58 // ============================================================================
59 
60 /// Options for the `LocalResponseNorm` module.
61 ///
62 /// Example:
63 /// ```
64 /// LocalResponseNorm
65 /// model(LocalResponseNormOptions(2).alpha(0.0002).beta(0.85).k(2.));
66 /// ```
67 struct TORCH_API LocalResponseNormOptions {
LocalResponseNormOptionsLocalResponseNormOptions68   /* implicit */ LocalResponseNormOptions(int64_t size) : size_(size) {}
69   /// amount of neighbouring channels used for normalization
70   TORCH_ARG(int64_t, size);
71 
72   /// multiplicative factor. Default: 1e-4
73   TORCH_ARG(double, alpha) = 1e-4;
74 
75   /// exponent. Default: 0.75
76   TORCH_ARG(double, beta) = 0.75;
77 
78   /// additive factor. Default: 1
79   TORCH_ARG(double, k) = 1.;
80 };
81 
namespace functional {
/// Options for `torch::nn::functional::local_response_norm`.
///
/// This is a type alias of `torch::nn::LocalResponseNormOptions`, so it
/// accepts exactly the same arguments; see that class for details.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::local_response_norm(x, F::LocalResponseNormFuncOptions(2));
/// ```
using LocalResponseNormFuncOptions = LocalResponseNormOptions;
} // namespace functional
95 
96 // ============================================================================
97 
/// Options for the `CrossMapLRN2d` module.
///
/// Example:
/// ```
/// CrossMapLRN2d model(CrossMapLRN2dOptions(3).alpha(1e-5).beta(0.1).k(10));
/// ```
struct TORCH_API CrossMapLRN2dOptions {
  /// Constructs options from `size`. Not marked `explicit`, so an `int64_t`
  /// converts implicitly. Defined out of line.
  CrossMapLRN2dOptions(int64_t size);

  /// Presumably the number of neighbouring channels used for normalization,
  /// as for `LocalResponseNormOptions::size` — confirm.
  TORCH_ARG(int64_t, size);

  /// Presumably the multiplicative factor, as for
  /// `LocalResponseNormOptions::alpha`. Default: `1e-4`.
  TORCH_ARG(double, alpha) = 1e-4;

  /// Presumably the exponent, as for `LocalResponseNormOptions::beta`.
  /// Default: `0.75`.
  TORCH_ARG(double, beta) = 0.75;

  /// Presumably the additive factor, as for `LocalResponseNormOptions::k`.
  /// Default: `1`.
  /// NOTE(review): this is `int64_t`, whereas `LocalResponseNormOptions::k`
  /// is `double`. Confirm whether non-integer values should be supported
  /// before changing the type; code elsewhere may rely on it being integral.
  TORCH_ARG(int64_t, k) = 1;
};
115 
116 // ============================================================================
117 
namespace functional {

/// Options for `torch::nn::functional::normalize`.
///
/// Every field has a default, so `NormalizeFuncOptions()` is usable as-is.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::normalize(input, F::NormalizeFuncOptions().p(1).dim(-1));
/// ```
struct TORCH_API NormalizeFuncOptions {
  /// The exponent value in the norm formulation. Default: `2.0`.
  TORCH_ARG(double, p) = 2.0;
  /// The dimension to reduce. Default: `1`.
  TORCH_ARG(int64_t, dim) = 1;
  /// Small value to avoid division by zero. Default: `1e-12`.
  TORCH_ARG(double, eps) = 1e-12;
  /// Optional output tensor. Default: `std::nullopt` (none supplied).
  /// If `out` is used, this operation won't be differentiable.
  TORCH_ARG(std::optional<Tensor>, out) = std::nullopt;
};

} // namespace functional
140 
141 // ============================================================================
142 
/// Options for the `GroupNorm` module.
///
/// Example:
/// ```
/// GroupNorm model(GroupNormOptions(2, 2).eps(2e-5).affine(false));
/// ```
struct TORCH_API GroupNormOptions {
  /// Constructs options from the group and channel counts. Not `explicit`,
  /// which also permits copy-list-initialization such as
  /// `GroupNormOptions opts = {2, 2};`. Defined out of line.
  /* implicit */ GroupNormOptions(int64_t num_groups, int64_t num_channels);

  /// Number of groups to separate the channels into.
  TORCH_ARG(int64_t, num_groups);
  /// Number of channels expected in the input.
  TORCH_ARG(int64_t, num_channels);
  /// A value added to the denominator for numerical stability.
  /// Default: `1e-5`.
  TORCH_ARG(double, eps) = 1e-5;
  /// A boolean value that, when set to ``true``, gives this module
  /// learnable per-channel affine parameters initialized to ones (for
  /// weights) and zeros (for biases). Default: ``true``.
  TORCH_ARG(bool, affine) = true;
};
163 
164 // ============================================================================
165 
namespace functional {

/// Options for `torch::nn::functional::group_norm`.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::group_norm(input, F::GroupNormFuncOptions(2).eps(2e-5));
/// ```
struct TORCH_API GroupNormFuncOptions {
  /// Constructs options from the number of groups. Deliberately not
  /// `explicit`, so an `int64_t` converts implicitly. Defined out of line.
  /* implicit */ GroupNormFuncOptions(int64_t num_groups);

  /// Number of groups to separate the channels into.
  TORCH_ARG(int64_t, num_groups);

  /// Optional weight tensor. Defaults to `{}` (a default-constructed
  /// `Tensor`); presumably that means no per-channel scale is applied —
  /// confirm against `group_norm`.
  TORCH_ARG(Tensor, weight) = {};

  /// Optional bias tensor. Defaults to `{}` (a default-constructed
  /// `Tensor`); presumably that means no per-channel shift is applied —
  /// confirm against `group_norm`.
  TORCH_ARG(Tensor, bias) = {};

  /// A value added to the denominator for numerical stability.
  /// Default: `1e-5`.
  TORCH_ARG(double, eps) = 1e-5;
};

} // namespace functional
190 
191 } // namespace nn
192 } // namespace torch
193