xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/functional/batchnorm.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/util/irange.h>
4 #include <torch/nn/options/batchnorm.h>
5 #include <torch/types.h>
6 
7 namespace torch {
8 namespace nn {
9 namespace functional {
10 
11 #ifndef DOXYGEN_SHOULD_SKIP_THIS
12 namespace detail {
/// Unpacked-argument form of `batch_norm`; the public overload in this file
/// forwards the fields of `BatchNormFuncOptions` here.
///
/// @param input        Input tensor; must have at least 2 dimensions.
/// @param running_mean Running mean forwarded to `torch::batch_norm`.
/// @param running_var  Running variance forwarded to `torch::batch_norm`.
/// @param weight       Weight tensor forwarded to `torch::batch_norm`.
/// @param bias         Bias tensor forwarded to `torch::batch_norm`.
/// @param training     Whether to run in training mode.
/// @param momentum     Momentum forwarded to `torch::batch_norm`; must hold a
///                     value (an empty optional is rejected with an error).
/// @param eps          Epsilon forwarded to `torch::batch_norm`.
/// @return The result of `torch::batch_norm`.
/// @throws c10::Error if `input` has fewer than 2 dimensions, if training with
///         only one value per channel, or if `momentum` is empty.
inline Tensor batch_norm(
    const Tensor& input,
    const Tensor& running_mean,
    const Tensor& running_var,
    Tensor weight,
    Tensor bias,
    bool training,
    std::optional<double> momentum,
    double eps) {
  TORCH_CHECK(
      input.dim() >= 2,
      "Expected at least 2 input dimensions, but got ",
      input.dim());
  if (training) {
    // Count the values each channel sees: the product of every dimension
    // except dim 1 (the channel dim, per the error message below).
    const auto size = input.sizes();
    int64_t size_prods = size[0];
    for (const auto i : c10::irange(size.size() - 2)) {
      size_prods *= size[i + 2];
    }
    TORCH_CHECK(
        size_prods != 1,
        "Expected more than 1 value per channel when training, got input size ",
        size);
  }
  // torch::batch_norm takes a plain double. Reject an empty optional with a
  // descriptive error instead of letting momentum.value() throw an opaque
  // std::bad_optional_access.
  TORCH_CHECK(
      momentum.has_value(),
      "Expected momentum to have a value in batch_norm, but got nullopt");

  // Note the argument order of the underlying op: weight/bias come before the
  // running statistics, unlike this function's own parameter order.
  return torch::batch_norm(
      input,
      weight,
      bias,
      running_mean,
      running_var,
      training,
      *momentum,
      eps,
      at::globalContext().userEnabledCuDNN());
}
49 } // namespace detail
50 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
51 
52 /// See
53 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.batch_norm
54 /// about the exact behavior of this functional.
55 ///
56 /// See the documentation for `torch::nn::functional::BatchNormFuncOptions`
57 /// class to learn what optional arguments are supported for this functional.
58 ///
59 /// Example:
60 /// ```
61 /// namespace F = torch::nn::functional;
62 /// F::batch_norm(input, mean, variance,
63 /// F::BatchNormFuncOptions().weight(weight).bias(bias).momentum(0.1).eps(1e-05).training(false));
64 /// ```
65 inline Tensor batch_norm(
66     const Tensor& input,
67     const Tensor& running_mean,
68     const Tensor& running_var,
69     const BatchNormFuncOptions& options = {}) {
70   return detail::batch_norm(
71       input,
72       running_mean,
73       running_var,
74       options.weight(),
75       options.bias(),
76       options.training(),
77       options.momentum(),
78       options.eps());
79 }
80 
81 } // namespace functional
82 } // namespace nn
83 } // namespace torch
84