xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/modules/embedding.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/nn/cloneable.h>
4 #include <torch/nn/functional/embedding.h>
5 #include <torch/nn/modules/common.h>
6 #include <torch/nn/options/embedding.h>
7 #include <torch/nn/pimpl.h>
8 #include <torch/types.h>
9 
10 #include <cstddef>
11 
12 namespace torch {
13 namespace nn {
14 
15 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Embedding
16 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
17 
18 /// Performs a lookup in a fixed size embedding table.
19 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Embedding to learn
20 /// about the exact behavior of this module.
21 ///
22 /// See the documentation for `torch::nn::EmbeddingOptions` class to learn what
23 /// constructor arguments are supported for this module.
24 ///
25 /// Example:
26 /// ```
27 /// Embedding model(EmbeddingOptions(10,
28 /// 2).padding_idx(3).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true));
29 /// ```
30 class TORCH_API EmbeddingImpl : public torch::nn::Cloneable<EmbeddingImpl> {
31  public:
EmbeddingImpl(int64_t num_embeddings,int64_t embedding_dim)32   EmbeddingImpl(int64_t num_embeddings, int64_t embedding_dim)
33       : EmbeddingImpl(EmbeddingOptions(num_embeddings, embedding_dim)) {}
34   explicit EmbeddingImpl(EmbeddingOptions options_);
35 
36   void reset() override;
37 
38   void reset_parameters();
39 
40   /// Pretty prints the `Embedding` module into the given `stream`.
41   void pretty_print(std::ostream& stream) const override;
42 
43   /// Performs a lookup on the embedding table stored in `weight` using the
44   /// `indices` supplied and returns the result.
45   Tensor forward(const Tensor& indices);
46 
47   /// The `Options` used to configure this `Embedding` module.
48   /// Changes to `EmbeddingOptions` *after construction* have no effect.
49   EmbeddingOptions options;
50 
51   /// The embedding table.
52   Tensor weight;
53 };
54 
55 /// A `ModuleHolder` subclass for `EmbeddingImpl`.
56 /// See the documentation for `EmbeddingImpl` class to learn what methods it
57 /// provides, and examples of how to use `Embedding` with
58 /// `torch::nn::EmbeddingOptions`. See the documentation for `ModuleHolder` to
59 /// learn about PyTorch's module storage semantics.
60 class Embedding : public torch::nn::ModuleHolder<EmbeddingImpl> {
61  public:
62   using torch::nn::ModuleHolder<EmbeddingImpl>::ModuleHolder;
63 
64   /// See the documentation for `torch::nn::EmbeddingFromPretrainedOptions`
65   /// class to learn what optional arguments are supported for this function.
66   static Embedding from_pretrained(
67       const torch::Tensor& embeddings,
68       const EmbeddingFromPretrainedOptions& options = {}) {
69     TORCH_CHECK(
70         embeddings.dim() == 2,
71         "Embeddings parameter is expected to be 2-dimensional");
72 
73     // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
74     int64_t rows, cols;
75     rows = embeddings.size(0);
76     cols = embeddings.size(1);
77 
78     Embedding embedding(EmbeddingOptions(rows, cols)
79                             ._weight(embeddings)
80                             .padding_idx(options.padding_idx())
81                             .max_norm(options.max_norm())
82                             .norm_type(options.norm_type())
83                             .scale_grad_by_freq(options.scale_grad_by_freq())
84                             .sparse(options.sparse()));
85     embedding->weight.set_requires_grad(!options.freeze());
86     return embedding;
87   }
88 };
89 
90 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ EmbeddingBag
91 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
92 
93 /// Computes sums or means of 'bags' of embeddings, without instantiating the
94 /// intermediate embeddings.
95 /// See https://pytorch.org/docs/main/nn.html#torch.nn.EmbeddingBag to learn
96 /// about the exact behavior of this module.
97 ///
98 /// See the documentation for `torch::nn::EmbeddingBagOptions` class to learn
99 /// what constructor arguments are supported for this module.
100 ///
101 /// Example:
102 /// ```
103 /// EmbeddingBag model(EmbeddingBagOptions(10,
104 /// 2).max_norm(2).norm_type(2.5).scale_grad_by_freq(true).sparse(true).mode(torch::kSum).padding_idx(1));
105 /// ```
106 class TORCH_API EmbeddingBagImpl
107     : public torch::nn::Cloneable<EmbeddingBagImpl> {
108  public:
EmbeddingBagImpl(int64_t num_embeddings,int64_t embedding_dim)109   EmbeddingBagImpl(int64_t num_embeddings, int64_t embedding_dim)
110       : EmbeddingBagImpl(EmbeddingBagOptions(num_embeddings, embedding_dim)) {}
111   explicit EmbeddingBagImpl(EmbeddingBagOptions options_);
112 
113   void reset() override;
114 
115   void reset_parameters();
116 
117   /// Pretty prints the `EmbeddingBag` module into the given `stream`.
118   void pretty_print(std::ostream& stream) const override;
119 
120   /// The `Options` used to configure this `EmbeddingBag` module.
121   EmbeddingBagOptions options;
122   /// The embedding table.
123   Tensor weight;
124 
125   Tensor forward(
126       const Tensor& input,
127       const Tensor& offsets = {},
128       const Tensor& per_sample_weights = {});
129 
130  protected:
131   FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())}, {2, AnyValue(Tensor())})
132 };
133 
134 /// A `ModuleHolder` subclass for `EmbeddingBagImpl`.
135 /// See the documentation for `EmbeddingBagImpl` class to learn what methods it
136 /// provides, and examples of how to use `EmbeddingBag` with
137 /// `torch::nn::EmbeddingBagOptions`. See the documentation for `ModuleHolder`
138 /// to learn about PyTorch's module storage semantics.
139 class EmbeddingBag : public torch::nn::ModuleHolder<EmbeddingBagImpl> {
140  public:
141   using torch::nn::ModuleHolder<EmbeddingBagImpl>::ModuleHolder;
142 
143   /// See the documentation for `torch::nn::EmbeddingBagFromPretrainedOptions`
144   /// class to learn what optional arguments are supported for this function.
145   static EmbeddingBag from_pretrained(
146       const torch::Tensor& embeddings,
147       const EmbeddingBagFromPretrainedOptions& options = {}) {
148     TORCH_CHECK(
149         embeddings.dim() == 2,
150         "Embeddings parameter is expected to be 2-dimensional");
151 
152     // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
153     int64_t rows, cols;
154     rows = embeddings.size(0);
155     cols = embeddings.size(1);
156 
157     EmbeddingBag embeddingbag(
158         EmbeddingBagOptions(rows, cols)
159             ._weight(embeddings)
160             .max_norm(options.max_norm())
161             .norm_type(options.norm_type())
162             .scale_grad_by_freq(options.scale_grad_by_freq())
163             .mode(options.mode())
164             .sparse(options.sparse())
165             .padding_idx(options.padding_idx()));
166     embeddingbag->weight.set_requires_grad(!options.freeze());
167     return embeddingbag;
168   }
169 };
170 } // namespace nn
171 } // namespace torch
172