1 #pragma once 2 3 #include <torch/nn/cloneable.h> 4 #include <torch/nn/functional/normalization.h> 5 #include <torch/nn/modules/_functions.h> 6 #include <torch/nn/options/normalization.h> 7 #include <torch/nn/pimpl.h> 8 #include <torch/types.h> 9 10 #include <cstddef> 11 #include <vector> 12 13 namespace torch { 14 namespace nn { 15 16 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LayerNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 17 18 /// Applies Layer Normalization over a mini-batch of inputs as described in 19 /// the paper `Layer Normalization`_ . 20 /// See https://pytorch.org/docs/main/nn.html#torch.nn.LayerNorm to learn 21 /// about the exact behavior of this module. 22 /// 23 /// See the documentation for `torch::nn::LayerNormOptions` class to learn what 24 /// constructor arguments are supported for this module. 25 /// 26 /// Example: 27 /// ``` 28 /// LayerNorm model(LayerNormOptions({2, 29 /// 2}).elementwise_affine(false).eps(2e-5)); 30 /// ``` 31 class TORCH_API LayerNormImpl : public torch::nn::Cloneable<LayerNormImpl> { 32 public: LayerNormImpl(std::vector<int64_t> normalized_shape)33 LayerNormImpl(std::vector<int64_t> normalized_shape) 34 : LayerNormImpl(LayerNormOptions(normalized_shape)) {} 35 explicit LayerNormImpl(LayerNormOptions options_); 36 37 void reset() override; 38 39 void reset_parameters(); 40 41 /// Pretty prints the `LayerNorm` module into the given `stream`. 42 void pretty_print(std::ostream& stream) const override; 43 44 /// Applies layer normalization over a mini-batch of inputs as described in 45 /// the paper `Layer Normalization`_ . 46 /// 47 /// The mean and standard-deviation are calculated separately over the last 48 /// certain number dimensions which have to be of the shape specified by 49 /// input `normalized_shape`. 50 /// 51 /// `Layer Normalization`: https://arxiv.org/abs/1607.06450 52 Tensor forward(const Tensor& input); 53 54 /// The options with which this module was constructed. 55 LayerNormOptions options; 56 57 /// The learned weight. 58 /// Initialized to ones if the `elementwise_affine` option is set to `true` 59 /// upon construction. 60 Tensor weight; 61 62 /// The learned bias. 63 /// Initialized to zeros `elementwise_affine` option is set to `true` upon 64 /// construction. 65 Tensor bias; 66 }; 67 68 /// A `ModuleHolder` subclass for `LayerNormImpl`. 69 /// See the documentation for `LayerNormImpl` class to learn what methods it 70 /// provides, and examples of how to use `LayerNorm` with 71 /// `torch::nn::LayerNormOptions`. See the documentation for `ModuleHolder` to 72 /// learn about PyTorch's module storage semantics. 73 TORCH_MODULE(LayerNorm); 74 75 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LocalResponseNorm 76 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 77 78 /// Applies local response normalization over an input signal composed 79 /// of several input planes, where channels occupy the second dimension. 80 /// Applies normalization across channels. 81 /// See https://pytorch.org/docs/main/nn.html#torch.nn.LocalResponseNorm to 82 /// learn about the exact behavior of this module. 83 /// 84 /// See the documentation for `torch::nn::LocalResponseNormOptions` class to 85 /// learn what constructor arguments are supported for this module. 86 /// 87 /// Example: 88 /// ``` 89 /// LocalResponseNorm 90 /// model(LocalResponseNormOptions(2).alpha(0.0002).beta(0.85).k(2.)); 91 /// ``` 92 class TORCH_API LocalResponseNormImpl 93 : public Cloneable<LocalResponseNormImpl> { 94 public: LocalResponseNormImpl(int64_t size)95 LocalResponseNormImpl(int64_t size) 96 : LocalResponseNormImpl(LocalResponseNormOptions(size)) {} 97 explicit LocalResponseNormImpl(const LocalResponseNormOptions& options_); 98 99 Tensor forward(const Tensor& input); 100 101 void reset() override; 102 103 /// Pretty prints the `LocalResponseNormImpl` module into the given `stream`. 104 void pretty_print(std::ostream& stream) const override; 105 106 /// The options with which this `Module` was constructed. 107 LocalResponseNormOptions options; 108 }; 109 110 /// A `ModuleHolder` subclass for `LocalResponseNormImpl`. 111 /// See the documentation for `LocalResponseNormImpl` class to learn what 112 /// methods it provides, and examples of how to use `LocalResponseNorm` with 113 /// `torch::nn::LocalResponseNormOptions`. See the documentation for 114 /// `ModuleHolder` to learn about PyTorch's module storage semantics. 115 TORCH_MODULE(LocalResponseNorm); 116 117 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CrossMapLRN2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 118 119 /// See the documentation for `torch::nn::CrossMapLRN2dOptions` class to learn 120 /// what constructor arguments are supported for this module. 121 /// 122 /// Example: 123 /// ``` 124 /// CrossMapLRN2d model(CrossMapLRN2dOptions(3).alpha(1e-5).beta(0.1).k(10)); 125 /// ``` 126 class TORCH_API CrossMapLRN2dImpl 127 : public torch::nn::Cloneable<CrossMapLRN2dImpl> { 128 public: CrossMapLRN2dImpl(int64_t size)129 CrossMapLRN2dImpl(int64_t size) 130 : CrossMapLRN2dImpl(CrossMapLRN2dOptions(size)) {} CrossMapLRN2dImpl(const CrossMapLRN2dOptions & options_)131 explicit CrossMapLRN2dImpl(const CrossMapLRN2dOptions& options_) 132 : options(options_) {} 133 134 void reset() override; 135 136 /// Pretty prints the `CrossMapLRN2d` module into the given `stream`. 137 void pretty_print(std::ostream& stream) const override; 138 139 torch::Tensor forward(const torch::Tensor& input); 140 141 CrossMapLRN2dOptions options; 142 }; 143 144 /// A `ModuleHolder` subclass for `CrossMapLRN2dImpl`. 145 /// See the documentation for `CrossMapLRN2dImpl` class to learn what methods it 146 /// provides, and examples of how to use `CrossMapLRN2d` with 147 /// `torch::nn::CrossMapLRN2dOptions`. See the documentation for `ModuleHolder` 148 /// to learn about PyTorch's module storage semantics. 149 TORCH_MODULE(CrossMapLRN2d); 150 151 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GroupNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 152 153 /// Applies Group Normalization over a mini-batch of inputs as described in 154 /// the paper `Group Normalization`_ . 155 /// See https://pytorch.org/docs/main/nn.html#torch.nn.GroupNorm to learn 156 /// about the exact behavior of this module. 157 /// 158 /// See the documentation for `torch::nn::GroupNormOptions` class to learn what 159 /// constructor arguments are supported for this module. 160 /// 161 /// Example: 162 /// ``` 163 /// GroupNorm model(GroupNormOptions(2, 2).eps(2e-5).affine(false)); 164 /// ``` 165 class TORCH_API GroupNormImpl : public torch::nn::Cloneable<GroupNormImpl> { 166 public: GroupNormImpl(int64_t num_groups,int64_t num_channels)167 GroupNormImpl(int64_t num_groups, int64_t num_channels) 168 : GroupNormImpl(GroupNormOptions(num_groups, num_channels)) {} 169 explicit GroupNormImpl(const GroupNormOptions& options_); 170 171 void reset() override; 172 173 void reset_parameters(); 174 175 /// Pretty prints the `GroupNorm` module into the given `stream`. 176 void pretty_print(std::ostream& stream) const override; 177 178 Tensor forward(const Tensor& input); 179 180 /// The options with which this module was constructed. 181 GroupNormOptions options; 182 183 /// The learned weight. 184 Tensor weight; 185 186 /// The learned bias. 187 Tensor bias; 188 }; 189 190 /// A `ModuleHolder` subclass for `GroupNormImpl`. 191 /// See the documentation for `GroupNormImpl` class to learn what methods it 192 /// provides, and examples of how to use `GroupNorm` with 193 /// `torch::nn::GroupNormOptions`. See the documentation for `ModuleHolder` to 194 /// learn about PyTorch's module storage semantics. 195 TORCH_MODULE(GroupNorm); 196 197 } // namespace nn 198 } // namespace torch 199