xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/functional/instancenorm.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/nn/options/instancenorm.h>
4 
5 namespace torch {
6 namespace nn {
7 namespace functional {
8 
9 #ifndef DOXYGEN_SHOULD_SKIP_THIS
10 namespace detail {
/// Adapter between the functional-API argument order and the order used by
/// `torch::instance_norm`.
///
/// Callers supply the running statistics before the affine tensors; the call
/// below passes `weight` and `bias` first, followed by `running_mean`,
/// `running_var`, `use_input_stats`, `momentum` and `eps`, all forwarded
/// unchanged. The trailing boolean is not chosen by the caller: it is whatever
/// `at::globalContext().userEnabledCuDNN()` reports at call time.
inline Tensor instance_norm(
    const Tensor& input,
    const Tensor& running_mean,
    const Tensor& running_var,
    const Tensor& weight,
    const Tensor& bias,
    bool use_input_stats,
    double momentum,
    double eps) {
  // Read the global context's cuDNN setting once and hand it to the op.
  const bool cudnn_enabled = at::globalContext().userEnabledCuDNN();

  return torch::instance_norm(
      input,
      weight,
      bias,
      running_mean,
      running_var,
      use_input_stats,
      momentum,
      eps,
      cudnn_enabled);
}
31 } // namespace detail
32 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
33 
34 /// See
35 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.instance_norm
36 /// about the exact behavior of this functional.
37 ///
38 /// See the documentation for `torch::nn::functional::InstanceNormFuncOptions`
39 /// class to learn what optional arguments are supported for this functional.
40 ///
41 /// Example:
42 /// ```
43 /// namespace F = torch::nn::functional;
44 /// F::instance_norm(input,
45 /// F::InstanceNormFuncOptions().running_mean(mean).running_var(variance).weight(weight).bias(bias).momentum(0.1).eps(1e-5));
46 /// ```
47 inline Tensor instance_norm(
48     const Tensor& input,
49     const InstanceNormFuncOptions& options = {}) {
50   return detail::instance_norm(
51       input,
52       options.running_mean(),
53       options.running_var(),
54       options.weight(),
55       options.bias(),
56       options.use_input_stats(),
57       options.momentum(),
58       options.eps());
59 }
60 
61 } // namespace functional
62 } // namespace nn
63 } // namespace torch
64