1 #pragma once
2
3 #include <c10/util/irange.h>
4 #include <torch/nn/options/batchnorm.h>
5 #include <torch/types.h>
6
7 namespace torch {
8 namespace nn {
9 namespace functional {
10
11 #ifndef DOXYGEN_SHOULD_SKIP_THIS
12 namespace detail {
// Validates the arguments and forwards them to `torch::batch_norm`.
//
// Preconditions, checked in this order via TORCH_CHECK:
//  1. `input` has at least 2 dimensions.
//  2. If `training` is true, the product of dimension 0 and every dimension
//     from index 2 onward (all dimensions except dimension 1) is not 1. The
//     error message describes this as needing more than one value per
//     channel.
//  3. `momentum` holds a value.
//
// `weight`, `bias`, `running_mean`, `running_var` and `eps` are passed
// through unchanged. The last argument given to `torch::batch_norm` is
// `at::globalContext().userEnabledCuDNN()`.
inline Tensor batch_norm(
    const Tensor& input,
    const Tensor& running_mean,
    const Tensor& running_var,
    Tensor weight,
    Tensor bias,
    bool training,
    std::optional<double> momentum,
    double eps) {
  TORCH_CHECK(
      input.dim() >= 2,
      "Expected at least 2 input dimensions, but got ",
      input.dim());
  if (training) {
    auto size = input.sizes();
    int64_t size_prods = size[0];
    // `size.size() - 2` cannot wrap around: the check above guarantees at
    // least two dimensions.
    for (const auto i : c10::irange(size.size() - 2)) {
      size_prods *= size[i + 2];
    }
    TORCH_CHECK(
        size_prods != 1,
        "Expected more than 1 value per channel when training, got input size ",
        size);
  }

  // Fail with a descriptive message instead of letting an unchecked
  // `momentum.value()` throw std::bad_optional_access on an empty optional.
  TORCH_CHECK(
      momentum.has_value(),
      "Expected momentum to have a value, but got an empty optional");

  return torch::batch_norm(
      input,
      weight,
      bias,
      running_mean,
      running_var,
      training,
      *momentum,
      eps,
      at::globalContext().userEnabledCuDNN());
}
49 } // namespace detail
50 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
51
52 /// See
53 /// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.batch_norm
54 /// about the exact behavior of this functional.
55 ///
56 /// See the documentation for `torch::nn::functional::BatchNormFuncOptions`
57 /// class to learn what optional arguments are supported for this functional.
58 ///
59 /// Example:
60 /// ```
61 /// namespace F = torch::nn::functional;
62 /// F::batch_norm(input, mean, variance,
63 /// F::BatchNormFuncOptions().weight(weight).bias(bias).momentum(0.1).eps(1e-05).training(false));
64 /// ```
65 inline Tensor batch_norm(
66 const Tensor& input,
67 const Tensor& running_mean,
68 const Tensor& running_var,
69 const BatchNormFuncOptions& options = {}) {
70 return detail::batch_norm(
71 input,
72 running_mean,
73 running_var,
74 options.weight(),
75 options.bias(),
76 options.training(),
77 options.momentum(),
78 options.eps());
79 }
80
81 } // namespace functional
82 } // namespace nn
83 } // namespace torch
84