xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/modules/batchnorm.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/nn/cloneable.h>
4 #include <torch/nn/functional/batchnorm.h>
5 #include <torch/nn/init.h>
6 #include <torch/nn/options/batchnorm.h>
7 #include <torch/nn/pimpl.h>
8 #include <torch/types.h>
9 
10 #include <cstdint>
11 
12 namespace torch {
13 namespace nn {
14 
15 /// Base class for all (dimension-specialized) batchnorm and instancenorm
16 /// modules.
17 template <size_t D, typename Derived, typename DerivedOptions>
18 class NormImplBase : public torch::nn::Cloneable<Derived> {
19  protected:
20   virtual void _check_input_dim(const Tensor& input) = 0;
21 
22  public:
NormImplBase(const DerivedOptions & options_)23   NormImplBase(const DerivedOptions& options_) : options(options_) {
24     // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
25     reset();
26   }
27 
reset()28   void reset() override {
29     if (options.affine()) {
30       weight = this->register_parameter(
31           "weight", torch::empty({options.num_features()}));
32       bias = this->register_parameter(
33           "bias", torch::empty({options.num_features()}));
34     } else {
35       weight =
36           this->register_parameter("weight", Tensor(), /*requires_grad=*/false);
37       bias =
38           this->register_parameter("bias", Tensor(), /*requires_grad=*/false);
39     }
40     if (options.track_running_stats()) {
41       running_mean = this->register_buffer(
42           "running_mean", torch::zeros({options.num_features()}));
43       running_var = this->register_buffer(
44           "running_var", torch::ones({options.num_features()}));
45       num_batches_tracked = this->register_buffer(
46           "num_batches_tracked", torch::tensor(0, torch::dtype(torch::kLong)));
47     } else {
48       running_mean = this->register_buffer("running_mean", Tensor());
49       running_var = this->register_buffer("running_var", Tensor());
50       num_batches_tracked =
51           this->register_buffer("num_batches_tracked", Tensor());
52     }
53     reset_parameters();
54   }
55 
reset_running_stats()56   void reset_running_stats() {
57     if (options.track_running_stats()) {
58       running_mean.zero_();
59       running_var.fill_(1);
60       num_batches_tracked.zero_();
61     }
62   }
63 
reset_parameters()64   void reset_parameters() {
65     reset_running_stats();
66     if (options.affine()) {
67       torch::nn::init::ones_(weight);
68       torch::nn::init::zeros_(bias);
69     }
70   }
71 
72   /// The options with which this module was constructed.
73   DerivedOptions options;
74 
75   /// The learned weight.
76   /// Only defined if the `affine` option was `true` upon construction.
77   Tensor weight;
78 
79   /// The learned bias.
80   /// Only defined if the `affine` option was `true` upon construction.
81   Tensor bias;
82 
83   /// The running mean.
84   /// Only defined if the `track_running_stats` option was `true` upon
85   /// construction.
86   Tensor running_mean;
87 
88   /// The running variance.
89   /// Only defined if the `track_running_stats` option was `true` upon
90   /// construction.
91   Tensor running_var;
92 
93   /// The number of the forward call.
94   /// Only defined if the `track_running_stats` option was `true` upon
95   /// construction.
96   Tensor num_batches_tracked;
97 };
98 
99 /// Base class for all (dimension-specialized) batchnorm modules.
100 template <size_t D, typename Derived>
101 class BatchNormImplBase : public NormImplBase<D, Derived, BatchNormOptions> {
102  public:
103   using NormImplBase<D, Derived, BatchNormOptions>::NormImplBase;
104 
forward(const Tensor & input)105   Tensor forward(const Tensor& input) {
106     this->_check_input_dim(input);
107     // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
108     double exponential_average_factor;
109     if (this->options.momentum() == std::nullopt) {
110       exponential_average_factor = 0.0;
111     } else {
112       exponential_average_factor = this->options.momentum().value();
113     }
114 
115     if (this->is_training() && this->options.track_running_stats()) {
116       if (this->num_batches_tracked.defined()) {
117         this->num_batches_tracked += 1;
118         if (this->options.momentum() ==
119             std::nullopt) { // use cumulative moving average
120           exponential_average_factor =
121               1.0 / this->num_batches_tracked.template item<double>();
122         } else { // use exponential moving average
123           exponential_average_factor = this->options.momentum().value();
124         }
125       }
126     }
127 
128     return torch::nn::functional::detail::batch_norm(
129         input,
130         this->running_mean,
131         this->running_var,
132         this->weight,
133         this->bias,
134         this->is_training() || !this->options.track_running_stats(),
135         /*momentum=*/exponential_average_factor,
136         this->options.eps());
137   }
138 
139   /// Pretty prints the `BatchNorm{1,2,3}d` module into the given `stream`.
pretty_print(std::ostream & stream)140   void pretty_print(std::ostream& stream) const override {
141     stream << std::boolalpha << "torch::nn::BatchNorm" << D << "d("
142            << this->options.num_features() << ", "
143            << "eps=" << this->options.eps() << ", "
144            << "momentum=";
145 
146     if (this->options.momentum().has_value()) {
147       stream << this->options.momentum().value();
148     } else {
149       stream << "None";
150     }
151 
152     stream << ", "
153            << "affine=" << this->options.affine() << ", "
154            << "track_running_stats=" << this->options.track_running_stats()
155            << ")";
156   }
157 };
158 
159 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm1d
160 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
161 
162 /// Applies the BatchNorm1d function.
163 /// See https://pytorch.org/docs/main/nn.html#torch.nn.BatchNorm1d to learn
164 /// about the exact behavior of this module.
165 ///
166 /// See the documentation for `torch::nn::BatchNorm1dOptions` class to learn
167 /// what constructor arguments are supported for this module.
168 ///
169 /// Example:
170 /// ```
171 /// BatchNorm1d
172 /// model(BatchNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true));
173 /// ```
174 class TORCH_API BatchNorm1dImpl : public BatchNormImplBase<1, BatchNorm1dImpl> {
175  protected:
176   void _check_input_dim(const Tensor& input) override;
177 
178  public:
179   using BatchNormImplBase<1, BatchNorm1dImpl>::BatchNormImplBase;
180 };
181 
182 /// A `ModuleHolder` subclass for `BatchNorm1dImpl`.
183 /// See the documentation for `BatchNorm1dImpl` class to learn what methods it
184 /// provides, and examples of how to use `BatchNorm1d` with
185 /// `torch::nn::BatchNorm1dOptions`. See the documentation for `ModuleHolder` to
186 /// learn about PyTorch's module storage semantics.
187 TORCH_MODULE(BatchNorm1d);
188 
189 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm2d
190 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
191 
192 /// Applies the BatchNorm2d function.
193 /// See https://pytorch.org/docs/main/nn.html#torch.nn.BatchNorm2d to learn
194 /// about the exact behavior of this module.
195 ///
196 /// See the documentation for `torch::nn::BatchNorm2dOptions` class to learn
197 /// what constructor arguments are supported for this module.
198 ///
199 /// Example:
200 /// ```
201 /// BatchNorm2d
202 /// model(BatchNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true));
203 /// ```
204 class TORCH_API BatchNorm2dImpl : public BatchNormImplBase<2, BatchNorm2dImpl> {
205  protected:
206   void _check_input_dim(const Tensor& input) override;
207 
208  public:
209   using BatchNormImplBase<2, BatchNorm2dImpl>::BatchNormImplBase;
210 };
211 
212 /// A `ModuleHolder` subclass for `BatchNorm2dImpl`.
213 /// See the documentation for `BatchNorm2dImpl` class to learn what methods it
214 /// provides, and examples of how to use `BatchNorm2d` with
215 /// `torch::nn::BatchNorm2dOptions`. See the documentation for `ModuleHolder` to
216 /// learn about PyTorch's module storage semantics.
217 TORCH_MODULE(BatchNorm2d);
218 
219 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm3d
220 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
221 
222 /// Applies the BatchNorm3d function.
223 /// See https://pytorch.org/docs/main/nn.html#torch.nn.BatchNorm3d to learn
224 /// about the exact behavior of this module.
225 ///
226 /// See the documentation for `torch::nn::BatchNorm3dOptions` class to learn
227 /// what constructor arguments are supported for this module.
228 ///
229 /// Example:
230 /// ```
231 /// BatchNorm3d
232 /// model(BatchNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true));
233 /// ```
234 class TORCH_API BatchNorm3dImpl : public BatchNormImplBase<3, BatchNorm3dImpl> {
235  protected:
236   void _check_input_dim(const Tensor& input) override;
237 
238  public:
239   using BatchNormImplBase<3, BatchNorm3dImpl>::BatchNormImplBase;
240 };
241 
242 /// A `ModuleHolder` subclass for `BatchNorm3dImpl`.
243 /// See the documentation for `BatchNorm3dImpl` class to learn what methods it
244 /// provides, and examples of how to use `BatchNorm3d` with
245 /// `torch::nn::BatchNorm3dOptions`. See the documentation for `ModuleHolder` to
246 /// learn about PyTorch's module storage semantics.
247 TORCH_MODULE(BatchNorm3d);
248 
249 } // namespace nn
250 } // namespace torch
251