#include <torch/nn/modules/embedding.h>

#include <torch/nn/init.h>
#include <torch/types.h>
#include <torch/utils.h>

#include <cstddef>
#include <ostream>
#include <utility>
#include <vector>

namespace F = torch::nn::functional;

namespace torch {
namespace nn {
/// Constructs the module from its options and immediately materializes the
/// weight parameter via reset().
EmbeddingImpl::EmbeddingImpl(EmbeddingOptions options_)
    : options(std::move(options_)) {
  // reset() is virtual, but calling it from the most-derived ctor here is the
  // established pattern for torch::nn modules; silence the analyzer.
  // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
  reset();
}
21
reset()22 void EmbeddingImpl::reset() {
23 if (options.padding_idx() != std::nullopt) {
24 if (*options.padding_idx() > 0) {
25 TORCH_CHECK(
26 *options.padding_idx() < options.num_embeddings(),
27 "Padding_idx must be within num_embeddings");
28 } else if (*options.padding_idx() < 0) {
29 TORCH_CHECK(
30 *options.padding_idx() >= -options.num_embeddings(),
31 "Padding_idx must be within num_embedding");
32 options.padding_idx(options.num_embeddings() + *options.padding_idx());
33 }
34 }
35
36 if (!options._weight().defined()) {
37 weight = register_parameter(
38 "weight",
39 torch::empty({options.num_embeddings(), options.embedding_dim()}));
40 reset_parameters();
41 } else {
42 TORCH_CHECK(
43 options._weight().sizes() ==
44 torch::IntArrayRef(
45 {options.num_embeddings(), options.embedding_dim()}),
46 "Shape of _weight does not match num_embeddings and embedding_dim");
47 weight = register_parameter("weight", options._weight());
48 }
49 }
50
reset_parameters()51 void EmbeddingImpl::reset_parameters() {
52 torch::nn::init::normal_(weight);
53 if (options.padding_idx() != std::nullopt) {
54 torch::NoGradGuard no_grad;
55 weight[*options.padding_idx()].fill_(0);
56 }
57 }
58
pretty_print(std::ostream & stream) const59 void EmbeddingImpl::pretty_print(std::ostream& stream) const {
60 stream << "torch::nn::Embedding(num_embeddings=" << options.num_embeddings()
61 << ", embedding_dim=" << options.embedding_dim();
62 if (options.padding_idx() != std::nullopt) {
63 stream << ", padding_idx=" << *options.padding_idx();
64 }
65 if (options.max_norm() != std::nullopt) {
66 stream << ", max_norm=" << *options.max_norm();
67 }
68 if (options.norm_type() != 2) {
69 stream << ", norm_type=" << options.norm_type();
70 }
71 if (options.scale_grad_by_freq()) {
72 stream << ", scale_grad_by_freq=" << std::boolalpha
73 << options.scale_grad_by_freq();
74 }
75 if (options.sparse()) {
76 stream << ", sparse=" << std::boolalpha << options.sparse();
77 }
78 stream << ")";
79 }
80
/// Looks up embedding vectors for the index tensor `input`; all option
/// handling (padding, max-norm renormalization, sparse grads) is delegated
/// to the functional implementation.
torch::Tensor EmbeddingImpl::forward(const Tensor& input) {
  return F::detail::embedding(
      input,
      weight,
      options.padding_idx(),
      options.max_norm(),
      options.norm_type(),
      options.scale_grad_by_freq(),
      options.sparse());
}
91
/// Constructs the module from its options and immediately materializes the
/// weight parameter via reset().
EmbeddingBagImpl::EmbeddingBagImpl(EmbeddingBagOptions options_)
    : options(std::move(options_)) {
  // reset() is virtual, but calling it from the most-derived ctor here is the
  // established pattern for torch::nn modules; silence the analyzer.
  // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
  reset();
}
97
reset()98 void EmbeddingBagImpl::reset() {
99 if (options.padding_idx().has_value()) {
100 auto padding_idx = options.padding_idx().value();
101 if (padding_idx > 0) {
102 TORCH_CHECK(
103 padding_idx < options.num_embeddings(),
104 "Padding_idx must be within num_embeddings");
105 } else if (padding_idx < 0) {
106 TORCH_CHECK(
107 padding_idx >= -options.num_embeddings(),
108 "Padding_idx must be within num_embedding");
109 options.padding_idx(options.num_embeddings() + padding_idx);
110 }
111 }
112 if (!options._weight().defined()) {
113 weight = register_parameter(
114 "weight",
115 torch::empty({options.num_embeddings(), options.embedding_dim()}));
116 reset_parameters();
117 } else {
118 TORCH_CHECK(
119 options._weight().sizes() ==
120 torch::IntArrayRef(
121 {options.num_embeddings(), options.embedding_dim()}),
122 "Shape of weight does not match num_embeddings and embedding_dim");
123 weight = register_parameter("weight", options._weight());
124 }
125 }
126
reset_parameters()127 void EmbeddingBagImpl::reset_parameters() {
128 if (options.padding_idx().has_value()) {
129 torch::NoGradGuard no_grad;
130 weight[options.padding_idx().value()].fill_(0);
131 }
132 torch::nn::init::normal_(weight);
133 }
134
/// Computes pooled ("bagged") embeddings for `input` indices.
///
/// `offsets` marks bag boundaries when `input` is 1-D (may be undefined for
/// 2-D input); `per_sample_weights` optionally rescales each looked-up row
/// before reduction (may be undefined). All option handling is delegated to
/// the functional implementation.
torch::Tensor EmbeddingBagImpl::forward(
    const Tensor& input,
    const Tensor& offsets,
    const Tensor& per_sample_weights) {
  return F::detail::embedding_bag(
      input,
      weight,
      offsets,
      options.max_norm(),
      options.norm_type(),
      options.scale_grad_by_freq(),
      options.mode(),
      options.sparse(),
      per_sample_weights,
      options.include_last_offset(),
      options.padding_idx());
}
152
pretty_print(std::ostream & stream) const153 void EmbeddingBagImpl::pretty_print(std::ostream& stream) const {
154 stream << "torch::nn::EmbeddingBag(num_embeddings="
155 << options.num_embeddings()
156 << ", embedding_dim=" << options.embedding_dim();
157 if (options.max_norm() != std::nullopt) {
158 stream << ", max_norm=" << *options.max_norm();
159 }
160 if (options.norm_type() != 2) {
161 stream << ", norm_type=" << options.norm_type();
162 }
163 if (options.scale_grad_by_freq()) {
164 stream << ", scale_grad_by_freq=" << std::boolalpha
165 << options.scale_grad_by_freq();
166 }
167 if (options.sparse()) {
168 stream << ", sparse=" << std::boolalpha << options.sparse();
169 }
170 if (!std::get_if<enumtype::kMean>(&options.mode())) {
171 stream << ", mode=" << torch::enumtype::get_enum_name(options.mode());
172 }
173 if (options.include_last_offset()) {
174 stream << ", include_last_offset=" << std::boolalpha
175 << options.include_last_offset();
176 }
177 if (options.padding_idx().has_value()) {
178 stream << ", padding_idx=" << options.padding_idx().value();
179 }
180 stream << ")";
181 }
} // namespace nn
} // namespace torch