xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/modules/normalization.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <torch/nn/cloneable.h>
#include <torch/nn/functional/normalization.h>
#include <torch/nn/modules/_functions.h>
#include <torch/nn/options/normalization.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <cstddef>
#include <utility>
#include <vector>
12 
13 namespace torch {
14 namespace nn {
15 
16 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LayerNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
17 
18 /// Applies Layer Normalization over a mini-batch of inputs as described in
19 /// the paper `Layer Normalization`_ .
20 /// See https://pytorch.org/docs/main/nn.html#torch.nn.LayerNorm to learn
21 /// about the exact behavior of this module.
22 ///
23 /// See the documentation for `torch::nn::LayerNormOptions` class to learn what
24 /// constructor arguments are supported for this module.
25 ///
26 /// Example:
27 /// ```
28 /// LayerNorm model(LayerNormOptions({2,
29 /// 2}).elementwise_affine(false).eps(2e-5));
30 /// ```
31 class TORCH_API LayerNormImpl : public torch::nn::Cloneable<LayerNormImpl> {
32  public:
LayerNormImpl(std::vector<int64_t> normalized_shape)33   LayerNormImpl(std::vector<int64_t> normalized_shape)
34       : LayerNormImpl(LayerNormOptions(normalized_shape)) {}
35   explicit LayerNormImpl(LayerNormOptions options_);
36 
37   void reset() override;
38 
39   void reset_parameters();
40 
41   /// Pretty prints the `LayerNorm` module into the given `stream`.
42   void pretty_print(std::ostream& stream) const override;
43 
44   /// Applies layer normalization over a mini-batch of inputs as described in
45   /// the paper `Layer Normalization`_ .
46   ///
47   /// The mean and standard-deviation are calculated separately over the last
48   /// certain number dimensions which have to be of the shape specified by
49   /// input `normalized_shape`.
50   ///
51   /// `Layer Normalization`: https://arxiv.org/abs/1607.06450
52   Tensor forward(const Tensor& input);
53 
54   /// The options with which this module was constructed.
55   LayerNormOptions options;
56 
57   /// The learned weight.
58   /// Initialized to ones if the `elementwise_affine` option is set to `true`
59   /// upon construction.
60   Tensor weight;
61 
62   /// The learned bias.
63   /// Initialized to zeros `elementwise_affine` option is set to `true` upon
64   /// construction.
65   Tensor bias;
66 };
67 
68 /// A `ModuleHolder` subclass for `LayerNormImpl`.
69 /// See the documentation for `LayerNormImpl` class to learn what methods it
70 /// provides, and examples of how to use `LayerNorm` with
71 /// `torch::nn::LayerNormOptions`. See the documentation for `ModuleHolder` to
72 /// learn about PyTorch's module storage semantics.
73 TORCH_MODULE(LayerNorm);
74 
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LocalResponseNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
77 
78 /// Applies local response normalization over an input signal composed
79 /// of several input planes, where channels occupy the second dimension.
80 /// Applies normalization across channels.
81 /// See https://pytorch.org/docs/main/nn.html#torch.nn.LocalResponseNorm to
82 /// learn about the exact behavior of this module.
83 ///
84 /// See the documentation for `torch::nn::LocalResponseNormOptions` class to
85 /// learn what constructor arguments are supported for this module.
86 ///
87 /// Example:
88 /// ```
89 /// LocalResponseNorm
90 /// model(LocalResponseNormOptions(2).alpha(0.0002).beta(0.85).k(2.));
91 /// ```
92 class TORCH_API LocalResponseNormImpl
93     : public Cloneable<LocalResponseNormImpl> {
94  public:
LocalResponseNormImpl(int64_t size)95   LocalResponseNormImpl(int64_t size)
96       : LocalResponseNormImpl(LocalResponseNormOptions(size)) {}
97   explicit LocalResponseNormImpl(const LocalResponseNormOptions& options_);
98 
99   Tensor forward(const Tensor& input);
100 
101   void reset() override;
102 
103   /// Pretty prints the `LocalResponseNormImpl` module into the given `stream`.
104   void pretty_print(std::ostream& stream) const override;
105 
106   /// The options with which this `Module` was constructed.
107   LocalResponseNormOptions options;
108 };
109 
110 /// A `ModuleHolder` subclass for `LocalResponseNormImpl`.
111 /// See the documentation for `LocalResponseNormImpl` class to learn what
112 /// methods it provides, and examples of how to use `LocalResponseNorm` with
113 /// `torch::nn::LocalResponseNormOptions`. See the documentation for
114 /// `ModuleHolder` to learn about PyTorch's module storage semantics.
115 TORCH_MODULE(LocalResponseNorm);
116 
117 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CrossMapLRN2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
118 
119 /// See the documentation for `torch::nn::CrossMapLRN2dOptions` class to learn
120 /// what constructor arguments are supported for this module.
121 ///
122 /// Example:
123 /// ```
124 /// CrossMapLRN2d model(CrossMapLRN2dOptions(3).alpha(1e-5).beta(0.1).k(10));
125 /// ```
126 class TORCH_API CrossMapLRN2dImpl
127     : public torch::nn::Cloneable<CrossMapLRN2dImpl> {
128  public:
CrossMapLRN2dImpl(int64_t size)129   CrossMapLRN2dImpl(int64_t size)
130       : CrossMapLRN2dImpl(CrossMapLRN2dOptions(size)) {}
CrossMapLRN2dImpl(const CrossMapLRN2dOptions & options_)131   explicit CrossMapLRN2dImpl(const CrossMapLRN2dOptions& options_)
132       : options(options_) {}
133 
134   void reset() override;
135 
136   /// Pretty prints the `CrossMapLRN2d` module into the given `stream`.
137   void pretty_print(std::ostream& stream) const override;
138 
139   torch::Tensor forward(const torch::Tensor& input);
140 
141   CrossMapLRN2dOptions options;
142 };
143 
144 /// A `ModuleHolder` subclass for `CrossMapLRN2dImpl`.
145 /// See the documentation for `CrossMapLRN2dImpl` class to learn what methods it
146 /// provides, and examples of how to use `CrossMapLRN2d` with
147 /// `torch::nn::CrossMapLRN2dOptions`. See the documentation for `ModuleHolder`
148 /// to learn about PyTorch's module storage semantics.
149 TORCH_MODULE(CrossMapLRN2d);
150 
151 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GroupNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
152 
153 /// Applies Group Normalization over a mini-batch of inputs as described in
154 /// the paper `Group Normalization`_ .
155 /// See https://pytorch.org/docs/main/nn.html#torch.nn.GroupNorm to learn
156 /// about the exact behavior of this module.
157 ///
158 /// See the documentation for `torch::nn::GroupNormOptions` class to learn what
159 /// constructor arguments are supported for this module.
160 ///
161 /// Example:
162 /// ```
163 /// GroupNorm model(GroupNormOptions(2, 2).eps(2e-5).affine(false));
164 /// ```
165 class TORCH_API GroupNormImpl : public torch::nn::Cloneable<GroupNormImpl> {
166  public:
GroupNormImpl(int64_t num_groups,int64_t num_channels)167   GroupNormImpl(int64_t num_groups, int64_t num_channels)
168       : GroupNormImpl(GroupNormOptions(num_groups, num_channels)) {}
169   explicit GroupNormImpl(const GroupNormOptions& options_);
170 
171   void reset() override;
172 
173   void reset_parameters();
174 
175   /// Pretty prints the `GroupNorm` module into the given `stream`.
176   void pretty_print(std::ostream& stream) const override;
177 
178   Tensor forward(const Tensor& input);
179 
180   /// The options with which this module was constructed.
181   GroupNormOptions options;
182 
183   /// The learned weight.
184   Tensor weight;
185 
186   /// The learned bias.
187   Tensor bias;
188 };
189 
190 /// A `ModuleHolder` subclass for `GroupNormImpl`.
191 /// See the documentation for `GroupNormImpl` class to learn what methods it
192 /// provides, and examples of how to use `GroupNorm` with
193 /// `torch::nn::GroupNormOptions`. See the documentation for `ModuleHolder` to
194 /// learn about PyTorch's module storage semantics.
195 TORCH_MODULE(GroupNorm);
196 
197 } // namespace nn
198 } // namespace torch
199