1 #pragma once 2 3 #include <torch/nn/cloneable.h> 4 #include <torch/nn/functional/batchnorm.h> 5 #include <torch/nn/init.h> 6 #include <torch/nn/options/batchnorm.h> 7 #include <torch/nn/pimpl.h> 8 #include <torch/types.h> 9 10 #include <cstdint> 11 12 namespace torch { 13 namespace nn { 14 15 /// Base class for all (dimension-specialized) batchnorm and instancenorm 16 /// modules. 17 template <size_t D, typename Derived, typename DerivedOptions> 18 class NormImplBase : public torch::nn::Cloneable<Derived> { 19 protected: 20 virtual void _check_input_dim(const Tensor& input) = 0; 21 22 public: NormImplBase(const DerivedOptions & options_)23 NormImplBase(const DerivedOptions& options_) : options(options_) { 24 // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) 25 reset(); 26 } 27 reset()28 void reset() override { 29 if (options.affine()) { 30 weight = this->register_parameter( 31 "weight", torch::empty({options.num_features()})); 32 bias = this->register_parameter( 33 "bias", torch::empty({options.num_features()})); 34 } else { 35 weight = 36 this->register_parameter("weight", Tensor(), /*requires_grad=*/false); 37 bias = 38 this->register_parameter("bias", Tensor(), /*requires_grad=*/false); 39 } 40 if (options.track_running_stats()) { 41 running_mean = this->register_buffer( 42 "running_mean", torch::zeros({options.num_features()})); 43 running_var = this->register_buffer( 44 "running_var", torch::ones({options.num_features()})); 45 num_batches_tracked = this->register_buffer( 46 "num_batches_tracked", torch::tensor(0, torch::dtype(torch::kLong))); 47 } else { 48 running_mean = this->register_buffer("running_mean", Tensor()); 49 running_var = this->register_buffer("running_var", Tensor()); 50 num_batches_tracked = 51 this->register_buffer("num_batches_tracked", Tensor()); 52 } 53 reset_parameters(); 54 } 55 reset_running_stats()56 void reset_running_stats() { 57 if (options.track_running_stats()) { 58 running_mean.zero_(); 59 running_var.fill_(1); 60 
num_batches_tracked.zero_(); 61 } 62 } 63 reset_parameters()64 void reset_parameters() { 65 reset_running_stats(); 66 if (options.affine()) { 67 torch::nn::init::ones_(weight); 68 torch::nn::init::zeros_(bias); 69 } 70 } 71 72 /// The options with which this module was constructed. 73 DerivedOptions options; 74 75 /// The learned weight. 76 /// Only defined if the `affine` option was `true` upon construction. 77 Tensor weight; 78 79 /// The learned bias. 80 /// Only defined if the `affine` option was `true` upon construction. 81 Tensor bias; 82 83 /// The running mean. 84 /// Only defined if the `track_running_stats` option was `true` upon 85 /// construction. 86 Tensor running_mean; 87 88 /// The running variance. 89 /// Only defined if the `track_running_stats` option was `true` upon 90 /// construction. 91 Tensor running_var; 92 93 /// The number of the forward call. 94 /// Only defined if the `track_running_stats` option was `true` upon 95 /// construction. 96 Tensor num_batches_tracked; 97 }; 98 99 /// Base class for all (dimension-specialized) batchnorm modules. 
100 template <size_t D, typename Derived> 101 class BatchNormImplBase : public NormImplBase<D, Derived, BatchNormOptions> { 102 public: 103 using NormImplBase<D, Derived, BatchNormOptions>::NormImplBase; 104 forward(const Tensor & input)105 Tensor forward(const Tensor& input) { 106 this->_check_input_dim(input); 107 // NOLINTNEXTLINE(cppcoreguidelines-init-variables) 108 double exponential_average_factor; 109 if (this->options.momentum() == std::nullopt) { 110 exponential_average_factor = 0.0; 111 } else { 112 exponential_average_factor = this->options.momentum().value(); 113 } 114 115 if (this->is_training() && this->options.track_running_stats()) { 116 if (this->num_batches_tracked.defined()) { 117 this->num_batches_tracked += 1; 118 if (this->options.momentum() == 119 std::nullopt) { // use cumulative moving average 120 exponential_average_factor = 121 1.0 / this->num_batches_tracked.template item<double>(); 122 } else { // use exponential moving average 123 exponential_average_factor = this->options.momentum().value(); 124 } 125 } 126 } 127 128 return torch::nn::functional::detail::batch_norm( 129 input, 130 this->running_mean, 131 this->running_var, 132 this->weight, 133 this->bias, 134 this->is_training() || !this->options.track_running_stats(), 135 /*momentum=*/exponential_average_factor, 136 this->options.eps()); 137 } 138 139 /// Pretty prints the `BatchNorm{1,2,3}d` module into the given `stream`. 
pretty_print(std::ostream & stream)140 void pretty_print(std::ostream& stream) const override { 141 stream << std::boolalpha << "torch::nn::BatchNorm" << D << "d(" 142 << this->options.num_features() << ", " 143 << "eps=" << this->options.eps() << ", " 144 << "momentum="; 145 146 if (this->options.momentum().has_value()) { 147 stream << this->options.momentum().value(); 148 } else { 149 stream << "None"; 150 } 151 152 stream << ", " 153 << "affine=" << this->options.affine() << ", " 154 << "track_running_stats=" << this->options.track_running_stats() 155 << ")"; 156 } 157 }; 158 159 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm1d 160 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 161 162 /// Applies the BatchNorm1d function. 163 /// See https://pytorch.org/docs/main/nn.html#torch.nn.BatchNorm1d to learn 164 /// about the exact behavior of this module. 165 /// 166 /// See the documentation for `torch::nn::BatchNorm1dOptions` class to learn 167 /// what constructor arguments are supported for this module. 168 /// 169 /// Example: 170 /// ``` 171 /// BatchNorm1d 172 /// model(BatchNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); 173 /// ``` 174 class TORCH_API BatchNorm1dImpl : public BatchNormImplBase<1, BatchNorm1dImpl> { 175 protected: 176 void _check_input_dim(const Tensor& input) override; 177 178 public: 179 using BatchNormImplBase<1, BatchNorm1dImpl>::BatchNormImplBase; 180 }; 181 182 /// A `ModuleHolder` subclass for `BatchNorm1dImpl`. 183 /// See the documentation for `BatchNorm1dImpl` class to learn what methods it 184 /// provides, and examples of how to use `BatchNorm1d` with 185 /// `torch::nn::BatchNorm1dOptions`. See the documentation for `ModuleHolder` to 186 /// learn about PyTorch's module storage semantics. 187 TORCH_MODULE(BatchNorm1d); 188 189 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm2d 190 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 191 192 /// Applies the BatchNorm2d function. 
/// See https://pytorch.org/docs/main/nn.html#torch.nn.BatchNorm2d to learn
/// about the exact behavior of this module.
///
/// See the documentation for `torch::nn::BatchNorm2dOptions` class to learn
/// what constructor arguments are supported for this module.
///
/// Example:
/// ```
/// BatchNorm2d
/// model(BatchNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true));
/// ```
class TORCH_API BatchNorm2dImpl : public BatchNormImplBase<2, BatchNorm2dImpl> {
 protected:
  /// Dimension check for the 2-D case; implementation lives in the .cpp file.
  void _check_input_dim(const Tensor& input) override;

 public:
  using BatchNormImplBase<2, BatchNorm2dImpl>::BatchNormImplBase;
};

/// A `ModuleHolder` subclass for `BatchNorm2dImpl`.
/// See the documentation for `BatchNorm2dImpl` class to learn what methods it
/// provides, and examples of how to use `BatchNorm2d` with
/// `torch::nn::BatchNorm2dOptions`. See the documentation for `ModuleHolder` to
/// learn about PyTorch's module storage semantics.
TORCH_MODULE(BatchNorm2d);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm3d
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Applies the BatchNorm3d function.
/// See https://pytorch.org/docs/main/nn.html#torch.nn.BatchNorm3d to learn
/// about the exact behavior of this module.
///
/// See the documentation for `torch::nn::BatchNorm3dOptions` class to learn
/// what constructor arguments are supported for this module.
///
/// Example:
/// ```
/// BatchNorm3d
/// model(BatchNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true));
/// ```
class TORCH_API BatchNorm3dImpl : public BatchNormImplBase<3, BatchNorm3dImpl> {
 protected:
  /// Dimension check for the 3-D case; implementation lives in the .cpp file.
  void _check_input_dim(const Tensor& input) override;

 public:
  using BatchNormImplBase<3, BatchNorm3dImpl>::BatchNormImplBase;
};

/// A `ModuleHolder` subclass for `BatchNorm3dImpl`.
/// See the documentation for `BatchNorm3dImpl` class to learn what methods it
/// provides, and examples of how to use `BatchNorm3d` with
/// `torch::nn::BatchNorm3dOptions`. See the documentation for `ModuleHolder` to
/// learn about PyTorch's module storage semantics.
TORCH_MODULE(BatchNorm3d);

} // namespace nn
} // namespace torch