#pragma once

#include <torch/arg.h>
#include <torch/csrc/Export.h>
#include <torch/enum.h>
#include <torch/types.h>

namespace torch {
namespace nn {

/// Options for the `Embedding` module.
///
/// Example:
/// ```
/// Embedding model(EmbeddingOptions(10, 2)
///     .padding_idx(3).max_norm(2).norm_type(2.5)
///     .scale_grad_by_freq(true).sparse(true));
/// ```
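///
/// A brief usage sketch (the tensors here are hypothetical, and `model` is
/// the instance constructed above; indices must lie in
/// `[0, num_embeddings)`):
/// ```
/// auto indices = torch::tensor({{1, 2, 4, 5}, {4, 3, 2, 9}}, torch::kLong);
/// auto output = model(indices); // shape: {2, 4, 2}
/// ```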
struct TORCH_API EmbeddingOptions {
  EmbeddingOptions(int64_t num_embeddings, int64_t embedding_dim);

  /// The size of the dictionary of embeddings.
  TORCH_ARG(int64_t, num_embeddings);
  /// The size of each embedding vector.
  TORCH_ARG(int64_t, embedding_dim);
  /// If specified, the entries at `padding_idx` do not contribute to the
  /// gradient; therefore, the embedding vector at `padding_idx` is not updated
  /// during training, i.e. it remains as a fixed "pad". For a newly constructed
  /// Embedding, the embedding vector at `padding_idx` will default to all
  /// zeros, but can be updated to another value to be used as the padding
  /// vector.
  TORCH_ARG(std::optional<int64_t>, padding_idx) = std::nullopt;
  /// If given, each embedding vector with norm larger than `max_norm` is
  /// renormalized to have norm `max_norm`.
  TORCH_ARG(std::optional<double>, max_norm) = std::nullopt;
  /// The p of the p-norm to compute for the `max_norm` option. Default ``2``.
  TORCH_ARG(double, norm_type) = 2.;
  /// If given, this will scale gradients by the inverse of frequency of the
  /// words in the mini-batch. Default ``false``.
  TORCH_ARG(bool, scale_grad_by_freq) = false;
  /// If ``true``, the gradient w.r.t. the `weight` matrix will be a sparse
  /// tensor.
  TORCH_ARG(bool, sparse) = false;
  /// The learnable weights of the module, of shape (num_embeddings,
  /// embedding_dim).
  TORCH_ARG(torch::Tensor, _weight) = Tensor();
};

// ============================================================================

/// Options for the `Embedding::from_pretrained` function.
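///
/// Example (a minimal sketch: `Embedding::from_pretrained` takes a weight
/// tensor plus these options; the weight values below are hypothetical):
/// ```
/// auto weight = torch::tensor({{1.0, 2.3, 3.0}, {4.0, 5.1, 6.3}});
/// Embedding model = Embedding::from_pretrained(
///     weight, EmbeddingFromPretrainedOptions().freeze(true));
/// ```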
struct TORCH_API EmbeddingFromPretrainedOptions {
  /// If ``true``, the tensor does not get updated in the learning process.
  /// Equivalent to ``embedding.weight.requires_grad_(false)``. Default:
  /// ``true``
  TORCH_ARG(bool, freeze) = true;
  /// If specified, the entries at `padding_idx` do not contribute to the
  /// gradient; therefore, the embedding vector at `padding_idx` is not updated
  /// during training, i.e. it remains as a fixed "pad".
  TORCH_ARG(std::optional<int64_t>, padding_idx) = std::nullopt;
  /// If given, each embedding vector with norm larger than `max_norm` is
  /// renormalized to have norm `max_norm`.
  TORCH_ARG(std::optional<double>, max_norm) = std::nullopt;
  /// The p of the p-norm to compute for the `max_norm` option. Default ``2``.
  TORCH_ARG(double, norm_type) = 2.;
  /// If given, this will scale gradients by the inverse of frequency of the
  /// words in the mini-batch. Default ``false``.
  TORCH_ARG(bool, scale_grad_by_freq) = false;
  /// If ``true``, the gradient w.r.t. the `weight` matrix will be a sparse
  /// tensor.
  TORCH_ARG(bool, sparse) = false;
};

// ============================================================================

namespace functional {

/// Options for `torch::nn::functional::embedding`.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::embedding(input, weight,
///     F::EmbeddingFuncOptions().norm_type(2.5).scale_grad_by_freq(true)
///         .sparse(true));
/// ```
struct TORCH_API EmbeddingFuncOptions {
  /// If specified, the entries at `padding_idx` do not contribute to the
  /// gradient; therefore, the embedding vector at `padding_idx` is not updated
  /// during training, i.e. it remains as a fixed "pad".
  TORCH_ARG(std::optional<int64_t>, padding_idx) = std::nullopt;
  /// If given, each embedding vector with norm larger than `max_norm` is
  /// renormalized to have norm `max_norm`.
  TORCH_ARG(std::optional<double>, max_norm) = std::nullopt;
  /// The p of the p-norm to compute for the `max_norm` option. Default ``2``.
  TORCH_ARG(double, norm_type) = 2.;
  /// If given, this will scale gradients by the inverse of frequency of the
  /// words in the mini-batch. Default ``false``.
  TORCH_ARG(bool, scale_grad_by_freq) = false;
  /// If ``true``, the gradient w.r.t. the `weight` matrix will be a sparse
  /// tensor.
  TORCH_ARG(bool, sparse) = false;
};

} // namespace functional

// ============================================================================

typedef std::variant<enumtype::kSum, enumtype::kMean, enumtype::kMax>
    EmbeddingBagMode;

/// Options for the `EmbeddingBag` module.
///
/// Example:
/// ```
/// EmbeddingBag model(EmbeddingBagOptions(10, 2)
///     .max_norm(2).norm_type(2.5).scale_grad_by_freq(true)
///     .sparse(true).mode(torch::kSum));
/// ```
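///
/// A brief usage sketch (the tensors here are hypothetical, and `model` is
/// the instance constructed above; with a 1-D `input`, `offsets` marks the
/// start of each bag):
/// ```
/// auto input = torch::tensor({1, 2, 4, 5, 4, 3, 2, 9}, torch::kLong);
/// auto offsets = torch::tensor({0, 4}, torch::kLong);
/// auto output = model(input, offsets); // shape: {2, 2}
/// ```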
struct TORCH_API EmbeddingBagOptions {
  EmbeddingBagOptions(int64_t num_embeddings, int64_t embedding_dim);

  /// The size of the dictionary of embeddings.
  TORCH_ARG(int64_t, num_embeddings);
  /// The size of each embedding vector.
  TORCH_ARG(int64_t, embedding_dim);
  /// If given, each embedding vector with norm larger than `max_norm` is
  /// renormalized to have norm `max_norm`.
  TORCH_ARG(std::optional<double>, max_norm) = std::nullopt;
  /// The p of the p-norm to compute for the `max_norm` option. Default ``2``.
  TORCH_ARG(double, norm_type) = 2.;
  /// If given, this will scale gradients by the inverse of frequency of the
  /// words in the mini-batch. Default ``false``. Note: this option is not
  /// supported when ``mode="kMax"``.
  TORCH_ARG(bool, scale_grad_by_freq) = false;
  /// ``"kSum"``, ``"kMean"`` or ``"kMax"``. Specifies the way to reduce the
  /// bag. ``"kSum"`` computes the weighted sum, taking `per_sample_weights`
  /// into consideration. ``"kMean"`` computes the average of the values in the
  /// bag, and ``"kMax"`` computes the max value over each bag.
  TORCH_ARG(EmbeddingBagMode, mode) = torch::kMean;
  /// If ``true``, the gradient w.r.t. the `weight` matrix will be a sparse
  /// tensor. Note: this option is not supported when ``mode="kMax"``.
  TORCH_ARG(bool, sparse) = false;
  /// The learnable weights of the module, of shape (num_embeddings,
  /// embedding_dim).
  TORCH_ARG(torch::Tensor, _weight) = Tensor();
  /// If ``true``, `offsets` has one additional element, where the last element
  /// is equivalent to the size of `indices`. This matches the CSR format.
  TORCH_ARG(bool, include_last_offset) = false;
  /// If specified, the entries at `padding_idx` do not contribute to the
  /// gradient; therefore, the embedding vector at `padding_idx` is not updated
  /// during training, i.e. it remains as a fixed "pad". For a newly constructed
  /// EmbeddingBag, the embedding vector at `padding_idx` will default to all
  /// zeros, but can be updated to another value to be used as the padding
  /// vector. Note that the embedding vector at `padding_idx` is excluded from
  /// the reduction.
  TORCH_ARG(std::optional<int64_t>, padding_idx) = std::nullopt;
};

// ============================================================================

/// Options for the `EmbeddingBag::from_pretrained` function.
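///
/// Example (a minimal sketch: `EmbeddingBag::from_pretrained` takes a weight
/// tensor plus these options; the weight values below are hypothetical):
/// ```
/// auto weight = torch::tensor({{1.0, 2.3, 3.0}, {4.0, 5.1, 6.3}});
/// EmbeddingBag model = EmbeddingBag::from_pretrained(
///     weight, EmbeddingBagFromPretrainedOptions().freeze(true));
/// ```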
struct TORCH_API EmbeddingBagFromPretrainedOptions {
  /// If ``true``, the tensor does not get updated in the learning process.
  /// Equivalent to ``embeddingbag.weight.requires_grad_(false)``. Default:
  /// ``true``
  TORCH_ARG(bool, freeze) = true;
  /// If given, each embedding vector with norm larger than `max_norm` is
  /// renormalized to have norm `max_norm`.
  TORCH_ARG(std::optional<double>, max_norm) = std::nullopt;
  /// The p of the p-norm to compute for the `max_norm` option. Default ``2``.
  TORCH_ARG(double, norm_type) = 2.;
  /// If given, this will scale gradients by the inverse of frequency of the
  /// words in the mini-batch. Default ``false``. Note: this option is not
  /// supported when ``mode="kMax"``.
  TORCH_ARG(bool, scale_grad_by_freq) = false;
  /// ``"kSum"``, ``"kMean"`` or ``"kMax"``. Specifies the way to reduce the
  /// bag. ``"kSum"`` computes the weighted sum, taking `per_sample_weights`
  /// into consideration. ``"kMean"`` computes the average of the values in the
  /// bag, and ``"kMax"`` computes the max value over each bag.
  TORCH_ARG(EmbeddingBagMode, mode) = torch::kMean;
  /// If ``true``, the gradient w.r.t. the `weight` matrix will be a sparse
  /// tensor. Note: this option is not supported when ``mode="kMax"``.
  TORCH_ARG(bool, sparse) = false;
  /// If ``true``, `offsets` has one additional element, where the last element
  /// is equivalent to the size of `indices`. This matches the CSR format. Note:
  /// this option is currently only supported when ``mode="kSum"``.
  TORCH_ARG(bool, include_last_offset) = false;
  /// If specified, the entries at `padding_idx` do not contribute to the
  /// gradient; therefore, the embedding vector at `padding_idx` is not updated
  /// during training, i.e. it remains as a fixed "pad". Note that the embedding
  /// vector at `padding_idx` is excluded from the reduction.
  TORCH_ARG(std::optional<int64_t>, padding_idx) = std::nullopt;
};

// ============================================================================

namespace functional {

/// Options for `torch::nn::functional::embedding_bag`.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::embedding_bag(input, weight,
///     F::EmbeddingBagFuncOptions().mode(torch::kSum).offsets(offsets));
/// ```
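///
/// A sketch with `per_sample_weights` (hypothetical tensors; `weight` is the
/// embedding matrix as above, weighted reduction requires
/// ``mode(torch::kSum)``, and the weights must match the shape of `input`):
/// ```
/// auto input = torch::tensor({1, 2, 4, 5, 4, 3, 2, 9}, torch::kLong);
/// auto offsets = torch::tensor({0, 4}, torch::kLong);
/// auto psw = torch::rand({8});
/// F::embedding_bag(input, weight,
///     F::EmbeddingBagFuncOptions().mode(torch::kSum).offsets(offsets)
///         .per_sample_weights(psw));
/// ```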
struct TORCH_API EmbeddingBagFuncOptions {
  /// Only used when `input` is 1-D. `offsets` determines
  /// the starting index position of each bag (sequence) in `input`.
  TORCH_ARG(torch::Tensor, offsets) = Tensor();
  /// If given, each embedding vector with norm larger than `max_norm` is
  /// renormalized to have norm `max_norm`.
  TORCH_ARG(std::optional<double>, max_norm) = std::nullopt;
  /// The p of the p-norm to compute for the `max_norm` option. Default ``2``.
  TORCH_ARG(double, norm_type) = 2.;
  /// If given, this will scale gradients by the inverse of frequency of the
  /// words in the mini-batch. Default ``false``. Note: this option is not
  /// supported when ``mode="kMax"``.
  TORCH_ARG(bool, scale_grad_by_freq) = false;
  /// ``"kSum"``, ``"kMean"`` or ``"kMax"``. Specifies the way to reduce the
  /// bag. ``"kSum"`` computes the weighted sum, taking `per_sample_weights`
  /// into consideration. ``"kMean"`` computes the average of the values in the
  /// bag, and ``"kMax"`` computes the max value over each bag.
  TORCH_ARG(EmbeddingBagMode, mode) = torch::kMean;
  /// If ``true``, the gradient w.r.t. the `weight` matrix will be a sparse
  /// tensor. Note: this option is not supported when ``mode="kMax"``.
  TORCH_ARG(bool, sparse) = false;
  /// A tensor of float / double weights, or an undefined tensor to indicate
  /// that all weights should be taken to be ``1``. If specified,
  /// `per_sample_weights` must have exactly the same shape as `input` and is
  /// treated as having the same `offsets`, if `offsets` is given.
  TORCH_ARG(torch::Tensor, per_sample_weights) = Tensor();
  /// If ``true``, `offsets` has one additional element, where the last element
  /// is equivalent to the size of `indices`. This matches the CSR format. Note:
  /// this option is currently only supported when ``mode="kSum"``.
  TORCH_ARG(bool, include_last_offset) = false;
  /// If specified, the entries at `padding_idx` do not contribute to the
  /// gradient; therefore, the embedding vector at `padding_idx` is not updated
  /// during training, i.e. it remains as a fixed "pad". Note that the embedding
  /// vector at `padding_idx` is excluded from the reduction.
  TORCH_ARG(std::optional<int64_t>, padding_idx) = std::nullopt;
};

} // namespace functional

} // namespace nn
} // namespace torch