// xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/options/normalization.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/nn/options/normalization.h>
2 
3 namespace torch {
4 namespace nn {
5 
LayerNormOptions(std::vector<int64_t> normalized_shape)6 LayerNormOptions::LayerNormOptions(std::vector<int64_t> normalized_shape)
7     : normalized_shape_(std::move(normalized_shape)) {}
8 
CrossMapLRN2dOptions(int64_t size)9 CrossMapLRN2dOptions::CrossMapLRN2dOptions(int64_t size) : size_(size) {}
10 
GroupNormOptions(int64_t num_groups,int64_t num_channels)11 GroupNormOptions::GroupNormOptions(int64_t num_groups, int64_t num_channels)
12     : num_groups_(num_groups), num_channels_(num_channels) {}
13 
14 namespace functional {
15 
LayerNormFuncOptions(std::vector<int64_t> normalized_shape)16 LayerNormFuncOptions::LayerNormFuncOptions(
17     std::vector<int64_t> normalized_shape)
18     : normalized_shape_(std::move(normalized_shape)) {}
19 
GroupNormFuncOptions(int64_t num_groups)20 GroupNormFuncOptions::GroupNormFuncOptions(int64_t num_groups)
21     : num_groups_(num_groups) {}
22 
23 } // namespace functional
24 
25 } // namespace nn
26 } // namespace torch
27