xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/modules/embedding.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/nn/modules/embedding.h>
2 
3 #include <torch/nn/init.h>
4 #include <torch/types.h>
5 #include <torch/utils.h>
6 
7 #include <cstddef>
8 #include <ostream>
9 #include <utility>
10 #include <vector>
11 
12 namespace F = torch::nn::functional;
13 
14 namespace torch {
15 namespace nn {
EmbeddingImpl(EmbeddingOptions options_)16 EmbeddingImpl::EmbeddingImpl(EmbeddingOptions options_)
17     : options(std::move(options_)) {
18   // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
19   reset();
20 }
21 
reset()22 void EmbeddingImpl::reset() {
23   if (options.padding_idx() != std::nullopt) {
24     if (*options.padding_idx() > 0) {
25       TORCH_CHECK(
26           *options.padding_idx() < options.num_embeddings(),
27           "Padding_idx must be within num_embeddings");
28     } else if (*options.padding_idx() < 0) {
29       TORCH_CHECK(
30           *options.padding_idx() >= -options.num_embeddings(),
31           "Padding_idx must be within num_embedding");
32       options.padding_idx(options.num_embeddings() + *options.padding_idx());
33     }
34   }
35 
36   if (!options._weight().defined()) {
37     weight = register_parameter(
38         "weight",
39         torch::empty({options.num_embeddings(), options.embedding_dim()}));
40     reset_parameters();
41   } else {
42     TORCH_CHECK(
43         options._weight().sizes() ==
44             torch::IntArrayRef(
45                 {options.num_embeddings(), options.embedding_dim()}),
46         "Shape of _weight does not match num_embeddings and embedding_dim");
47     weight = register_parameter("weight", options._weight());
48   }
49 }
50 
reset_parameters()51 void EmbeddingImpl::reset_parameters() {
52   torch::nn::init::normal_(weight);
53   if (options.padding_idx() != std::nullopt) {
54     torch::NoGradGuard no_grad;
55     weight[*options.padding_idx()].fill_(0);
56   }
57 }
58 
pretty_print(std::ostream & stream) const59 void EmbeddingImpl::pretty_print(std::ostream& stream) const {
60   stream << "torch::nn::Embedding(num_embeddings=" << options.num_embeddings()
61          << ", embedding_dim=" << options.embedding_dim();
62   if (options.padding_idx() != std::nullopt) {
63     stream << ", padding_idx=" << *options.padding_idx();
64   }
65   if (options.max_norm() != std::nullopt) {
66     stream << ", max_norm=" << *options.max_norm();
67   }
68   if (options.norm_type() != 2) {
69     stream << ", norm_type=" << options.norm_type();
70   }
71   if (options.scale_grad_by_freq()) {
72     stream << ", scale_grad_by_freq=" << std::boolalpha
73            << options.scale_grad_by_freq();
74   }
75   if (options.sparse()) {
76     stream << ", sparse=" << std::boolalpha << options.sparse();
77   }
78   stream << ")";
79 }
80 
forward(const Tensor & input)81 torch::Tensor EmbeddingImpl::forward(const Tensor& input) {
82   return F::detail::embedding(
83       input,
84       weight,
85       options.padding_idx(),
86       options.max_norm(),
87       options.norm_type(),
88       options.scale_grad_by_freq(),
89       options.sparse());
90 }
91 
EmbeddingBagImpl(EmbeddingBagOptions options_)92 EmbeddingBagImpl::EmbeddingBagImpl(EmbeddingBagOptions options_)
93     : options(std::move(options_)) {
94   // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
95   reset();
96 }
97 
reset()98 void EmbeddingBagImpl::reset() {
99   if (options.padding_idx().has_value()) {
100     auto padding_idx = options.padding_idx().value();
101     if (padding_idx > 0) {
102       TORCH_CHECK(
103           padding_idx < options.num_embeddings(),
104           "Padding_idx must be within num_embeddings");
105     } else if (padding_idx < 0) {
106       TORCH_CHECK(
107           padding_idx >= -options.num_embeddings(),
108           "Padding_idx must be within num_embedding");
109       options.padding_idx(options.num_embeddings() + padding_idx);
110     }
111   }
112   if (!options._weight().defined()) {
113     weight = register_parameter(
114         "weight",
115         torch::empty({options.num_embeddings(), options.embedding_dim()}));
116     reset_parameters();
117   } else {
118     TORCH_CHECK(
119         options._weight().sizes() ==
120             torch::IntArrayRef(
121                 {options.num_embeddings(), options.embedding_dim()}),
122         "Shape of weight does not match num_embeddings and embedding_dim");
123     weight = register_parameter("weight", options._weight());
124   }
125 }
126 
reset_parameters()127 void EmbeddingBagImpl::reset_parameters() {
128   if (options.padding_idx().has_value()) {
129     torch::NoGradGuard no_grad;
130     weight[options.padding_idx().value()].fill_(0);
131   }
132   torch::nn::init::normal_(weight);
133 }
134 
forward(const Tensor & input,const Tensor & offsets,const Tensor & per_sample_weights)135 torch::Tensor EmbeddingBagImpl::forward(
136     const Tensor& input,
137     const Tensor& offsets,
138     const Tensor& per_sample_weights) {
139   return F::detail::embedding_bag(
140       input,
141       weight,
142       offsets,
143       options.max_norm(),
144       options.norm_type(),
145       options.scale_grad_by_freq(),
146       options.mode(),
147       options.sparse(),
148       per_sample_weights,
149       options.include_last_offset(),
150       options.padding_idx());
151 }
152 
pretty_print(std::ostream & stream) const153 void EmbeddingBagImpl::pretty_print(std::ostream& stream) const {
154   stream << "torch::nn::EmbeddingBag(num_embeddings="
155          << options.num_embeddings()
156          << ", embedding_dim=" << options.embedding_dim();
157   if (options.max_norm() != std::nullopt) {
158     stream << ", max_norm=" << *options.max_norm();
159   }
160   if (options.norm_type() != 2) {
161     stream << ", norm_type=" << options.norm_type();
162   }
163   if (options.scale_grad_by_freq()) {
164     stream << ", scale_grad_by_freq=" << std::boolalpha
165            << options.scale_grad_by_freq();
166   }
167   if (options.sparse()) {
168     stream << ", sparse=" << std::boolalpha << options.sparse();
169   }
170   if (!std::get_if<enumtype::kMean>(&options.mode())) {
171     stream << ", mode=" << torch::enumtype::get_enum_name(options.mode());
172   }
173   if (options.include_last_offset()) {
174     stream << ", include_last_offset=" << std::boolalpha
175            << options.include_last_offset();
176   }
177   if (options.padding_idx().has_value()) {
178     stream << ", padding_idx=" << options.padding_idx().value();
179   }
180   stream << ")";
181 }
182 } // namespace nn
183 } // namespace torch
184