xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/modules/normalization.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/nn/modules/normalization.h>
2 
3 #include <torch/cuda.h>
4 #include <torch/nn/init.h>
5 #include <torch/utils.h>
6 
7 #include <ostream>
8 #include <utility>
9 
10 namespace F = torch::nn::functional;
11 
12 namespace torch {
13 namespace nn {
14 
LayerNormImpl(LayerNormOptions options_)15 LayerNormImpl::LayerNormImpl(LayerNormOptions options_)
16     : options(std::move(options_)) {
17   // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
18   reset();
19 }
20 
reset()21 void LayerNormImpl::reset() {
22   if (options.elementwise_affine()) {
23     weight =
24         register_parameter("weight", torch::empty(options.normalized_shape()));
25     bias = register_parameter("bias", torch::empty(options.normalized_shape()));
26   } else {
27     weight =
28         register_parameter("weight", torch::Tensor(), /*requires_grad=*/false);
29     bias = register_parameter("bias", torch::Tensor(), /*requires_grad=*/false);
30   }
31   reset_parameters();
32 }
33 
reset_parameters()34 void LayerNormImpl::reset_parameters() {
35   if (options.elementwise_affine()) {
36     torch::nn::init::ones_(weight);
37     torch::nn::init::zeros_(bias);
38   }
39 }
40 
pretty_print(std::ostream & stream) const41 void LayerNormImpl::pretty_print(std::ostream& stream) const {
42   stream << std::boolalpha << "torch::nn::LayerNorm("
43          << torch::IntArrayRef(options.normalized_shape())
44          << ", eps=" << options.eps()
45          << ", elementwise_affine=" << options.elementwise_affine() << ")";
46 }
47 
forward(const Tensor & input)48 torch::Tensor LayerNormImpl::forward(const Tensor& input) {
49   return F::detail::layer_norm(
50       input, options.normalized_shape(), weight, bias, options.eps());
51 }
52 
53 // ============================================================================
54 
LocalResponseNormImpl(const LocalResponseNormOptions & options_)55 LocalResponseNormImpl::LocalResponseNormImpl(
56     const LocalResponseNormOptions& options_)
57     : options(options_) {}
58 
// Forwards `input` and the size/alpha/beta/k hyper-parameters from `options`
// to the functional local_response_norm implementation.
Tensor LocalResponseNormImpl::forward(const Tensor& input) {
  const auto& opts = options;
  return F::detail::local_response_norm(
      input, opts.size(), opts.alpha(), opts.beta(), opts.k());
}
63 
reset()64 void LocalResponseNormImpl::reset() {}
65 
pretty_print(std::ostream & stream) const66 void LocalResponseNormImpl::pretty_print(std::ostream& stream) const {
67   stream << std::boolalpha << "torch::nn::LocalResponseNorm(" << options.size()
68          << ", alpha=" << options.alpha() << ", beta=" << options.beta()
69          << ", k=" << options.k() << ")";
70 }
71 
72 // ============================================================================
73 
reset()74 void CrossMapLRN2dImpl::reset() {}
75 
pretty_print(std::ostream & stream) const76 void CrossMapLRN2dImpl::pretty_print(std::ostream& stream) const {
77   stream << std::boolalpha << "torch::nn::CrossMapLRN2d(" << options.size()
78          << ", alpha=" << options.alpha() << ", beta=" << options.beta()
79          << ", k=" << options.k() << ")";
80 }
81 
// Delegates to functions::CrossMapLRN2d::apply, handing it the whole options
// object (presumably a custom autograd Function — confirm in its header).
torch::Tensor CrossMapLRN2dImpl::forward(const torch::Tensor& input) {
  torch::Tensor output = functions::CrossMapLRN2d::apply(input, options);
  return output;
}
85 
86 // ============================================================================
87 
GroupNormImpl(const GroupNormOptions & options_)88 GroupNormImpl::GroupNormImpl(const GroupNormOptions& options_)
89     : options(options_) { // NOLINT(modernize-pass-by-value)
90   // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
91   reset();
92 }
93 
reset()94 void GroupNormImpl::reset() {
95   if (options.affine()) {
96     weight = register_parameter("weight", torch::empty(options.num_channels()));
97     bias = register_parameter("bias", torch::empty(options.num_channels()));
98   } else {
99     weight =
100         register_parameter("weight", torch::Tensor(), /*requires_grad=*/false);
101     bias = register_parameter("bias", torch::Tensor(), /*requires_grad=*/false);
102   }
103   reset_parameters();
104 }
105 
reset_parameters()106 void GroupNormImpl::reset_parameters() {
107   if (options.affine()) {
108     torch::nn::init::ones_(weight);
109     torch::nn::init::zeros_(bias);
110   }
111 }
112 
forward(const Tensor & input)113 torch::Tensor GroupNormImpl::forward(const Tensor& input) {
114   return F::detail::group_norm(
115       input, options.num_groups(), weight, bias, options.eps());
116 }
117 
pretty_print(std::ostream & stream) const118 void GroupNormImpl::pretty_print(std::ostream& stream) const {
119   stream << std::boolalpha << "torch::nn::GroupNorm(" << options.num_groups()
120          << ", " << options.num_channels() << ", eps=" << options.eps()
121          << ", affine=" << options.affine() << ")";
122 }
123 
124 } // namespace nn
125 } // namespace torch
126