1 #pragma once 2 3 #include <torch/arg.h> 4 #include <torch/csrc/Export.h> 5 #include <torch/types.h> 6 #include <vector> 7 8 namespace torch { 9 namespace nn { 10 11 /// Options for the `LayerNorm` module. 12 /// 13 /// Example: 14 /// ``` 15 /// LayerNorm model(LayerNormOptions({2, 16 /// 2}).elementwise_affine(false).eps(2e-5)); 17 /// ``` 18 struct TORCH_API LayerNormOptions { 19 /* implicit */ LayerNormOptions(std::vector<int64_t> normalized_shape); 20 /// input shape from an expected input. 21 TORCH_ARG(std::vector<int64_t>, normalized_shape); 22 /// a value added to the denominator for numerical stability. ``Default: 23 /// 1e-5``. 24 TORCH_ARG(double, eps) = 1e-5; 25 /// a boolean value that when set to ``true``, this module 26 /// has learnable per-element affine parameters initialized to ones (for 27 /// weights) and zeros (for biases). ``Default: true``. 28 TORCH_ARG(bool, elementwise_affine) = true; 29 }; 30 31 // ============================================================================ 32 33 namespace functional { 34 35 /// Options for `torch::nn::functional::layer_norm`. 36 /// 37 /// Example: 38 /// ``` 39 /// namespace F = torch::nn::functional; 40 /// F::layer_norm(input, F::LayerNormFuncOptions({2, 2}).eps(2e-5)); 41 /// ``` 42 struct TORCH_API LayerNormFuncOptions { 43 /* implicit */ LayerNormFuncOptions(std::vector<int64_t> normalized_shape); 44 /// input shape from an expected input. 45 TORCH_ARG(std::vector<int64_t>, normalized_shape); 46 47 TORCH_ARG(Tensor, weight) = {}; 48 49 TORCH_ARG(Tensor, bias) = {}; 50 51 /// a value added to the denominator for numerical stability. ``Default: 52 /// 1e-5``. 53 TORCH_ARG(double, eps) = 1e-5; 54 }; 55 56 } // namespace functional 57 58 // ============================================================================ 59 60 /// Options for the `LocalResponseNorm` module. 61 /// 62 /// Example: 63 /// ``` 64 /// LocalResponseNorm 65 /// model(LocalResponseNormOptions(2).alpha(0.0002).beta(0.85).k(2.)); 66 /// ``` 67 struct TORCH_API LocalResponseNormOptions { LocalResponseNormOptionsLocalResponseNormOptions68 /* implicit */ LocalResponseNormOptions(int64_t size) : size_(size) {} 69 /// amount of neighbouring channels used for normalization 70 TORCH_ARG(int64_t, size); 71 72 /// multiplicative factor. Default: 1e-4 73 TORCH_ARG(double, alpha) = 1e-4; 74 75 /// exponent. Default: 0.75 76 TORCH_ARG(double, beta) = 0.75; 77 78 /// additive factor. Default: 1 79 TORCH_ARG(double, k) = 1.; 80 }; 81 82 namespace functional { 83 /// Options for `torch::nn::functional::local_response_norm`. 84 /// 85 /// See the documentation for `torch::nn::LocalResponseNormOptions` class to 86 /// learn what arguments are supported. 87 /// 88 /// Example: 89 /// ``` 90 /// namespace F = torch::nn::functional; 91 /// F::local_response_norm(x, F::LocalResponseNormFuncOptions(2)); 92 /// ``` 93 using LocalResponseNormFuncOptions = LocalResponseNormOptions; 94 } // namespace functional 95 96 // ============================================================================ 97 98 /// Options for the `CrossMapLRN2d` module. 99 /// 100 /// Example: 101 /// ``` 102 /// CrossMapLRN2d model(CrossMapLRN2dOptions(3).alpha(1e-5).beta(0.1).k(10)); 103 /// ``` 104 struct TORCH_API CrossMapLRN2dOptions { 105 CrossMapLRN2dOptions(int64_t size); 106 107 TORCH_ARG(int64_t, size); 108 109 TORCH_ARG(double, alpha) = 1e-4; 110 111 TORCH_ARG(double, beta) = 0.75; 112 113 TORCH_ARG(int64_t, k) = 1; 114 }; 115 116 // ============================================================================ 117 118 namespace functional { 119 120 /// Options for `torch::nn::functional::normalize`. 121 /// 122 /// Example: 123 /// ``` 124 /// namespace F = torch::nn::functional; 125 /// F::normalize(input, F::NormalizeFuncOptions().p(1).dim(-1)); 126 /// ``` 127 struct TORCH_API NormalizeFuncOptions { 128 /// The exponent value in the norm formulation. Default: 2.0 129 TORCH_ARG(double, p) = 2.0; 130 /// The dimension to reduce. Default: 1 131 TORCH_ARG(int64_t, dim) = 1; 132 /// Small value to avoid division by zero. Default: 1e-12 133 TORCH_ARG(double, eps) = 1e-12; 134 /// the output tensor. If `out` is used, this 135 /// operation won't be differentiable. 136 TORCH_ARG(std::optional<Tensor>, out) = std::nullopt; 137 }; 138 139 } // namespace functional 140 141 // ============================================================================ 142 143 /// Options for the `GroupNorm` module. 144 /// 145 /// Example: 146 /// ``` 147 /// GroupNorm model(GroupNormOptions(2, 2).eps(2e-5).affine(false)); 148 /// ``` 149 struct TORCH_API GroupNormOptions { 150 /* implicit */ GroupNormOptions(int64_t num_groups, int64_t num_channels); 151 152 /// number of groups to separate the channels into 153 TORCH_ARG(int64_t, num_groups); 154 /// number of channels expected in input 155 TORCH_ARG(int64_t, num_channels); 156 /// a value added to the denominator for numerical stability. Default: 1e-5 157 TORCH_ARG(double, eps) = 1e-5; 158 /// a boolean value that when set to ``true``, this module 159 /// has learnable per-channel affine parameters initialized to ones (for 160 /// weights) and zeros (for biases). Default: ``true``. 161 TORCH_ARG(bool, affine) = true; 162 }; 163 164 // ============================================================================ 165 166 namespace functional { 167 168 /// Options for `torch::nn::functional::group_norm`. 169 /// 170 /// Example: 171 /// ``` 172 /// namespace F = torch::nn::functional; 173 /// F::group_norm(input, F::GroupNormFuncOptions(2).eps(2e-5)); 174 /// ``` 175 struct TORCH_API GroupNormFuncOptions { 176 /* implicit */ GroupNormFuncOptions(int64_t num_groups); 177 178 /// number of groups to separate the channels into 179 TORCH_ARG(int64_t, num_groups); 180 181 TORCH_ARG(Tensor, weight) = {}; 182 183 TORCH_ARG(Tensor, bias) = {}; 184 185 /// a value added to the denominator for numerical stability. Default: 1e-5 186 TORCH_ARG(double, eps) = 1e-5; 187 }; 188 189 } // namespace functional 190 191 } // namespace nn 192 } // namespace torch 193