1 #pragma once 2 3 #include <torch/arg.h> 4 #include <torch/csrc/Export.h> 5 #include <torch/nn/options/batchnorm.h> 6 #include <torch/types.h> 7 8 namespace torch { 9 namespace nn { 10 11 /// Options for the `InstanceNorm` module. 12 struct TORCH_API InstanceNormOptions { 13 /* implicit */ InstanceNormOptions(int64_t num_features); 14 15 /// The number of features of the input tensor. 16 TORCH_ARG(int64_t, num_features); 17 18 /// The epsilon value added for numerical stability. 19 TORCH_ARG(double, eps) = 1e-5; 20 21 /// A momentum multiplier for the mean and variance. 22 TORCH_ARG(double, momentum) = 0.1; 23 24 /// Whether to learn a scale and bias that are applied in an affine 25 /// transformation on the input. 26 TORCH_ARG(bool, affine) = false; 27 28 /// Whether to store and update batch statistics (mean and variance) in the 29 /// module. 30 TORCH_ARG(bool, track_running_stats) = false; 31 }; 32 33 /// Options for the `InstanceNorm1d` module. 34 /// 35 /// Example: 36 /// ``` 37 /// InstanceNorm1d 38 /// model(InstanceNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); 39 /// ``` 40 using InstanceNorm1dOptions = InstanceNormOptions; 41 42 /// Options for the `InstanceNorm2d` module. 43 /// 44 /// Example: 45 /// ``` 46 /// InstanceNorm2d 47 /// model(InstanceNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); 48 /// ``` 49 using InstanceNorm2dOptions = InstanceNormOptions; 50 51 /// Options for the `InstanceNorm3d` module. 52 /// 53 /// Example: 54 /// ``` 55 /// InstanceNorm3d 56 /// model(InstanceNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); 57 /// ``` 58 using InstanceNorm3dOptions = InstanceNormOptions; 59 60 namespace functional { 61 62 /// Options for `torch::nn::functional::instance_norm`. 63 /// 64 /// Example: 65 /// ``` 66 /// namespace F = torch::nn::functional; 67 /// F::instance_norm(input, 68 /// F::InstanceNormFuncOptions().running_mean(mean).running_var(variance).weight(weight).bias(bias).momentum(0.1).eps(1e-5)); 69 /// ``` 70 struct TORCH_API InstanceNormFuncOptions { 71 TORCH_ARG(Tensor, running_mean) = Tensor(); 72 73 TORCH_ARG(Tensor, running_var) = Tensor(); 74 75 TORCH_ARG(Tensor, weight) = Tensor(); 76 77 TORCH_ARG(Tensor, bias) = Tensor(); 78 79 TORCH_ARG(bool, use_input_stats) = true; 80 81 TORCH_ARG(double, momentum) = 0.1; 82 83 TORCH_ARG(double, eps) = 1e-5; 84 }; 85 86 } // namespace functional 87 88 } // namespace nn 89 } // namespace torch 90