xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/options/instancenorm.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/arg.h>
4 #include <torch/csrc/Export.h>
5 #include <torch/nn/options/batchnorm.h>
6 #include <torch/types.h>
7 
8 namespace torch {
9 namespace nn {
10 
11 /// Options for the `InstanceNorm` module; the 1d/2d/3d option types in this header alias it.
12 struct TORCH_API InstanceNormOptions {
13   /* implicit */ InstanceNormOptions(int64_t num_features);
14 
15   /// The number of features of the input tensor. It has no default initializer here.
16   TORCH_ARG(int64_t, num_features);
17 
18   /// The epsilon value added for numerical stability. Defaults to `1e-5`.
19   TORCH_ARG(double, eps) = 1e-5;
20 
21   /// A momentum multiplier for the mean and variance. Defaults to `0.1`.
22   TORCH_ARG(double, momentum) = 0.1;
23 
24   /// Whether to learn a scale and bias that are applied in an affine
25   /// transformation on the input. Defaults to `false`.
26   TORCH_ARG(bool, affine) = false;
27 
28   /// Whether to store and update batch statistics (mean and variance) in the
29   /// module. Defaults to `false`.
30   TORCH_ARG(bool, track_running_stats) = false;
31 };
32 
33 /// Options for the `InstanceNorm1d` module; an alias of `InstanceNormOptions`.
34 ///
35 /// Example (4 features, with every other option set through chained calls):
36 /// ```
37 /// InstanceNorm1d
38 /// model(InstanceNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true));
39 /// ```
40 using InstanceNorm1dOptions = InstanceNormOptions;
41 
42 /// Options for the `InstanceNorm2d` module; an alias of `InstanceNormOptions`.
43 ///
44 /// Example (4 features, with every other option set through chained calls):
45 /// ```
46 /// InstanceNorm2d
47 /// model(InstanceNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true));
48 /// ```
49 using InstanceNorm2dOptions = InstanceNormOptions;
50 
51 /// Options for the `InstanceNorm3d` module; an alias of `InstanceNormOptions`.
52 ///
53 /// Example (4 features, with every other option set through chained calls):
54 /// ```
55 /// InstanceNorm3d
56 /// model(InstanceNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true));
57 /// ```
58 using InstanceNorm3dOptions = InstanceNormOptions;
59 
60 namespace functional {
61 
62 /// Options for `torch::nn::functional::instance_norm`; every field below has a default value.
63 ///
64 /// Example:
65 /// ```
66 /// namespace F = torch::nn::functional;
67 /// F::instance_norm(input,
68 /// F::InstanceNormFuncOptions().running_mean(mean).running_var(variance).weight(weight).bias(bias).momentum(0.1).eps(1e-5));
69 /// ```
70 struct TORCH_API InstanceNormFuncOptions {
71   TORCH_ARG(Tensor, running_mean) = Tensor(); ///< Running-mean tensor; defaults to a default-constructed `Tensor`.
72 
73   TORCH_ARG(Tensor, running_var) = Tensor(); ///< Running-variance tensor; defaults to a default-constructed `Tensor`.
74 
75   TORCH_ARG(Tensor, weight) = Tensor(); ///< Weight tensor; defaults to a default-constructed `Tensor`.
76 
77   TORCH_ARG(Tensor, bias) = Tensor(); ///< Bias tensor; defaults to a default-constructed `Tensor`.
78 
79   TORCH_ARG(bool, use_input_stats) = true; ///< Defaults to `true`. Presumably: normalize with the input's own statistics -- TODO confirm.
80 
81   TORCH_ARG(double, momentum) = 0.1; ///< Defaults to `0.1`, the same default as `InstanceNormOptions::momentum`.
82 
83   TORCH_ARG(double, eps) = 1e-5; ///< Defaults to `1e-5`, the same default as `InstanceNormOptions::eps`.
84 };
85 
86 } // namespace functional
87 
88 } // namespace nn
89 } // namespace torch
90