#pragma once

#include <torch/nn/options/embedding.h>

namespace torch {
namespace nn {
namespace functional {

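/// Returns a one-hot encoding of the given tensor of class indices. See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.one_hot
/// about the exact behavior of this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::one_hot(torch::tensor({0, 3, 1}, torch::kLong), /*num_classes=*/4);
/// ```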
inline Tensor one_hot(const Tensor& tensor, int64_t num_classes = -1) {
  return torch::one_hot(tensor, num_classes);
}

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
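// Renormalizes, in place and without recording the operation for autograd,
// the rows of `weight` whose indices appear in `input` so that each row's
// norm (computed with the p-norm given by `norm_type`) is at most `max_norm`.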
inline void _no_grad_embedding_renorm_(
    Tensor weight,
    const Tensor& input,
    float max_norm,
    float norm_type) {
  torch::NoGradGuard no_grad;
  torch::embedding_renorm_(weight, input, max_norm, norm_type);
}

inline Tensor embedding(
    const Tensor& input,
    const Tensor& weight,
    std::optional<int64_t> padding_idx,
    std::optional<double> max_norm,
    double norm_type,
    bool scale_grad_by_freq,
    bool sparse) {
  auto input_ = input;

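  // Normalize padding_idx: a negative value counts back from the end of the
  // weight matrix, and std::nullopt becomes the sentinel -1 expected by
  // torch::embedding.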
  if (padding_idx != std::nullopt) {
    if (*padding_idx > 0) {
      TORCH_CHECK(
          *padding_idx < weight.size(0),
          "Padding_idx must be within num_embeddings");
    } else if (*padding_idx < 0) {
      TORCH_CHECK(
          *padding_idx >= -weight.size(0),
          "Padding_idx must be within num_embeddings");
      padding_idx = weight.size(0) + *padding_idx;
    }
  } else {
    padding_idx = -1;
  }

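  // If max_norm is given, renormalize the rows of `weight` selected by
  // `input` in place before the lookup; this runs with autograd disabled.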
  if (max_norm != std::nullopt) {
    input_ = input_.contiguous();
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    _no_grad_embedding_renorm_(weight, input_, *max_norm, norm_type);
  }
  return torch::embedding(
      weight, input_, *padding_idx, scale_grad_by_freq, sparse);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.embedding
/// about the exact behavior of this functional.
///
/// See the documentation for the `torch::nn::functional::EmbeddingFuncOptions`
/// class to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::embedding(input, weight,
/// F::EmbeddingFuncOptions().norm_type(2.5).scale_grad_by_freq(true).sparse(true));
/// ```
inline Tensor embedding(
    const Tensor& input,
    const Tensor& weight,
    const EmbeddingFuncOptions& options = {}) {
  return detail::embedding(
      input,
      weight,
      options.padding_idx(),
      options.max_norm(),
      options.norm_type(),
      options.scale_grad_by_freq(),
      options.sparse());
}

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor embedding_bag(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& offsets,
    std::optional<double> max_norm,
    double norm_type,
    bool scale_grad_by_freq,
    EmbeddingBagMode mode,
    bool sparse,
    const Tensor& per_sample_weights,
    bool include_last_offset,
    std::optional<int64_t> padding_idx) {
  auto input_ = input;
  auto offsets_ = offsets;
  auto per_sample_weights_ = per_sample_weights;
  TORCH_CHECK(
      !per_sample_weights_.defined() ||
          input_.sizes() == per_sample_weights_.sizes(),
      "embedding_bag: If per_sample_weights (",
      per_sample_weights_.sizes(),
      ") is not null, then it must have the same shape as the input (",
      input_.sizes(),
      ")");
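  // A 2D input is treated as a batch of fixed-length bags: flatten it to 1D
  // and synthesize offsets at multiples of the sequence length. A 1D input
  // must come with an explicit 1D offsets tensor.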
  if (input_.dim() == 2) {
    TORCH_CHECK(
        !offsets_.defined(),
        "If input is 2D, then offsets has to be null, as input is treated as a mini-batch of fixed length sequences. However, found offsets of type Tensor");
    offsets_ = torch::arange(
        0,
        input_.numel(),
        input_.size(1),
        torch::TensorOptions().dtype(torch::kLong).device(input_.device()));
    input_ = input_.reshape(-1);
    if (per_sample_weights_.defined()) {
      per_sample_weights_ = per_sample_weights_.reshape(-1);
    }
  } else if (input_.dim() == 1) {
    TORCH_CHECK(
        offsets_.defined(), "offsets has to be a 1D Tensor but got null");
    TORCH_CHECK(offsets_.dim() == 1, "offsets has to be a 1D Tensor");
  } else {
    TORCH_CHECK(
        false,
        "input has to be 1D or 2D Tensor, but got Tensor of dimension ",
        input_.dim());
  }

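  // Map the EmbeddingBagMode variant onto the integer mode expected by
  // torch::embedding_bag: 0 = sum, 1 = mean, 2 = max.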
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int mode_enum;
  if (std::holds_alternative<enumtype::kSum>(mode)) {
    mode_enum = 0;
  } else if (std::holds_alternative<enumtype::kMean>(mode)) {
    mode_enum = 1;
  } else if (std::holds_alternative<enumtype::kMax>(mode)) {
    mode_enum = 2;
    TORCH_CHECK(
        !scale_grad_by_freq,
        "max mode does not support scaling the gradient by the frequency");
    TORCH_CHECK(!sparse, "max mode does not support sparse weights");
  } else {
    TORCH_CHECK(false, "mode has to be one of sum, mean or max");
  }

  if (max_norm != std::nullopt) {
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    _no_grad_embedding_renorm_(weight, input_, *max_norm, norm_type);
  }

  TORCH_CHECK(
      !per_sample_weights_.defined() || std::get_if<enumtype::kSum>(&mode),
      "embedding_bag: per_sample_weights was not null. ",
      "per_sample_weights is only supported for mode='kSum' (got mode='",
      torch::enumtype::get_enum_name(mode),
      "'). Please open a feature request on GitHub.");

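  // torch::embedding_bag returns (output, offset2bag, bag_size, max_indices);
  // only the pooled output is exposed here.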
  return std::get<0>(torch::embedding_bag(
      weight,
      input_,
      offsets_,
      scale_grad_by_freq,
      mode_enum,
      sparse,
      per_sample_weights_,
      include_last_offset,
      padding_idx));
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.embedding_bag
/// about the exact behavior of this functional.
///
/// See the documentation for the
/// `torch::nn::functional::EmbeddingBagFuncOptions` class to learn what
/// optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::embedding_bag(input, weight,
/// F::EmbeddingBagFuncOptions().mode(torch::kSum).offsets(offsets));
/// ```
inline Tensor embedding_bag(
    const Tensor& input,
    const Tensor& weight,
    const EmbeddingBagFuncOptions& options = {}) {
  return detail::embedding_bag(
      input,
      weight,
      options.offsets(),
      options.max_norm(),
      options.norm_type(),
      options.scale_grad_by_freq(),
      options.mode(),
      options.sparse(),
      options.per_sample_weights(),
      options.include_last_offset(),
      options.padding_idx());
}

} // namespace functional
} // namespace nn
} // namespace torch