#include <torch/nn/modules/normalization.h>

#include <torch/cuda.h>
#include <torch/nn/init.h>
#include <torch/utils.h>

#include <ostream>
#include <utility>

namespace F = torch::nn::functional;

namespace torch {
namespace nn {

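// LayerNorm normalizes over the trailing dimensions given by
// `normalized_shape`, optionally applying learnable per-element affine
// parameters when `elementwise_affine` is set. Illustrative usage sketch
// (shapes and option values below are arbitrary examples, not part of this
// file):
//
//   LayerNorm norm(LayerNormOptions({3, 32, 32}).eps(1e-5));
//   auto y = norm(torch::randn({8, 3, 32, 32}));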
LayerNormImpl::LayerNormImpl(LayerNormOptions options_)
    : options(std::move(options_)) {
  // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
  reset();
}

void LayerNormImpl::reset() {
  if (options.elementwise_affine()) {
    weight =
        register_parameter("weight", torch::empty(options.normalized_shape()));
    bias = register_parameter("bias", torch::empty(options.normalized_shape()));
  } else {
    weight =
        register_parameter("weight", torch::Tensor(), /*requires_grad=*/false);
    bias = register_parameter("bias", torch::Tensor(), /*requires_grad=*/false);
  }
  reset_parameters();
}

void LayerNormImpl::reset_parameters() {
  if (options.elementwise_affine()) {
    torch::nn::init::ones_(weight);
    torch::nn::init::zeros_(bias);
  }
}

void LayerNormImpl::pretty_print(std::ostream& stream) const {
  stream << std::boolalpha << "torch::nn::LayerNorm("
         << torch::IntArrayRef(options.normalized_shape())
         << ", eps=" << options.eps()
         << ", elementwise_affine=" << options.elementwise_affine() << ")";
}

torch::Tensor LayerNormImpl::forward(const Tensor& input) {
  return F::detail::layer_norm(
      input, options.normalized_shape(), weight, bias, options.eps());
}

// ============================================================================

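// LocalResponseNorm applies local response normalization across nearby
// channels of the input. Illustrative usage sketch (parameter values below
// are arbitrary examples):
//
//   LocalResponseNorm lrn(
//       LocalResponseNormOptions(2).alpha(1e-4).beta(0.75).k(1.));
//   auto y = lrn(torch::randn({4, 5, 8, 8}));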
LocalResponseNormImpl::LocalResponseNormImpl(
    const LocalResponseNormOptions& options_)
    : options(options_) {}

Tensor LocalResponseNormImpl::forward(const Tensor& input) {
  return F::detail::local_response_norm(
      input, options.size(), options.alpha(), options.beta(), options.k());
}

void LocalResponseNormImpl::reset() {}

void LocalResponseNormImpl::pretty_print(std::ostream& stream) const {
  stream << std::boolalpha << "torch::nn::LocalResponseNorm(" << options.size()
         << ", alpha=" << options.alpha() << ", beta=" << options.beta()
         << ", k=" << options.k() << ")";
}

// ============================================================================

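// CrossMapLRN2d performs cross-channel LRN via the custom autograd function
// functions::CrossMapLRN2d (see forward() below). Illustrative usage sketch
// (parameter values below are arbitrary examples):
//
//   CrossMapLRN2d lrn2d(CrossMapLRN2dOptions(3).alpha(1e-4).beta(0.75));
//   auto y = lrn2d(torch::randn({4, 5, 8, 8}));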
void CrossMapLRN2dImpl::reset() {}

void CrossMapLRN2dImpl::pretty_print(std::ostream& stream) const {
  stream << std::boolalpha << "torch::nn::CrossMapLRN2d(" << options.size()
         << ", alpha=" << options.alpha() << ", beta=" << options.beta()
         << ", k=" << options.k() << ")";
}

torch::Tensor CrossMapLRN2dImpl::forward(const torch::Tensor& input) {
  return functions::CrossMapLRN2d::apply(input, options);
}

// ============================================================================

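// GroupNorm splits `num_channels` channels into `num_groups` groups and
// normalizes within each group, with optional affine parameters.
// Illustrative usage sketch (values below are arbitrary examples):
//
//   GroupNorm gn(GroupNormOptions(/*num_groups=*/3, /*num_channels=*/6));
//   auto y = gn(torch::randn({4, 6, 8, 8}));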
GroupNormImpl::GroupNormImpl(const GroupNormOptions& options_)
    : options(options_) { // NOLINT(modernize-pass-by-value)
  // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
  reset();
}

void GroupNormImpl::reset() {
  if (options.affine()) {
    weight = register_parameter("weight", torch::empty(options.num_channels()));
    bias = register_parameter("bias", torch::empty(options.num_channels()));
  } else {
    weight =
        register_parameter("weight", torch::Tensor(), /*requires_grad=*/false);
    bias = register_parameter("bias", torch::Tensor(), /*requires_grad=*/false);
  }
  reset_parameters();
}

void GroupNormImpl::reset_parameters() {
  if (options.affine()) {
    torch::nn::init::ones_(weight);
    torch::nn::init::zeros_(bias);
  }
}

torch::Tensor GroupNormImpl::forward(const Tensor& input) {
  return F::detail::group_norm(
      input, options.num_groups(), weight, bias, options.eps());
}

void GroupNormImpl::pretty_print(std::ostream& stream) const {
  stream << std::boolalpha << "torch::nn::GroupNorm(" << options.num_groups()
         << ", " << options.num_channels() << ", eps=" << options.eps()
         << ", affine=" << options.affine() << ")";
}

} // namespace nn
} // namespace torch