#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/functional/embedding.h>
#include <torch/nn/modules/common.h>
#include <torch/nn/options/embedding.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <cstddef>

namespace torch {
namespace nn {

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Embedding ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Performs a lookup in a fixed size embedding table.
/// See https://pytorch.org/docs/main/nn.html#torch.nn.Embedding to learn
/// about the exact behavior of this module.
///
/// See the documentation for `torch::nn::EmbeddingOptions` class to learn what
/// constructor arguments are supported for this module.
///
/// Example:
/// ```
/// Embedding model(EmbeddingOptions(10, 2)
///                     .padding_idx(3)
///                     .max_norm(2)
///                     .norm_type(2.5)
///                     .scale_grad_by_freq(true)
///                     .sparse(true));
/// ```
class TORCH_API EmbeddingImpl : public torch::nn::Cloneable<EmbeddingImpl> {
 public:
  EmbeddingImpl(int64_t num_embeddings, int64_t embedding_dim)
      : EmbeddingImpl(EmbeddingOptions(num_embeddings, embedding_dim)) {}
  explicit EmbeddingImpl(EmbeddingOptions options_);

  void reset() override;

  void reset_parameters();

  /// Pretty prints the `Embedding` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Performs a lookup on the embedding table stored in `weight` using the
  /// `indices` supplied and returns the result.
  Tensor forward(const Tensor& indices);

  /// The `Options` used to configure this `Embedding` module.
  /// Changes to `EmbeddingOptions` *after construction* have no effect.
  EmbeddingOptions options;

  /// The embedding table.
  Tensor weight;
};

/// A `ModuleHolder` subclass for `EmbeddingImpl`.
/// See the documentation for `EmbeddingImpl` class to learn what methods it
/// provides, and examples of how to use `Embedding` with
/// `torch::nn::EmbeddingOptions`. See the documentation for `ModuleHolder` to
/// learn about PyTorch's module storage semantics.
class Embedding : public torch::nn::ModuleHolder<EmbeddingImpl> {
 public:
  using torch::nn::ModuleHolder<EmbeddingImpl>::ModuleHolder;

  /// See the documentation for `torch::nn::EmbeddingFromPretrainedOptions`
  /// class to learn what optional arguments are supported for this function.
  static Embedding from_pretrained(
      const torch::Tensor& embeddings,
      const EmbeddingFromPretrainedOptions& options = {}) {
    TORCH_CHECK(
        embeddings.dim() == 2,
        "Embeddings parameter is expected to be 2-dimensional");

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    int64_t rows, cols;
    rows = embeddings.size(0);
    cols = embeddings.size(1);

    Embedding embedding(EmbeddingOptions(rows, cols)
                            ._weight(embeddings)
                            .padding_idx(options.padding_idx())
                            .max_norm(options.max_norm())
                            .norm_type(options.norm_type())
                            .scale_grad_by_freq(options.scale_grad_by_freq())
                            .sparse(options.sparse()));
    embedding->weight.set_requires_grad(!options.freeze());
    return embedding;
  }
};
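
// A minimal usage sketch (illustrative only, not part of this header; it
// assumes <torch/torch.h> is available in the including translation unit):
// build an `Embedding` from a pretrained 2-D weight tensor, freeze it, and
// look up a few rows.
//
//   #include <torch/torch.h>
//
//   int main() {
//     auto pretrained = torch::rand({10, 2});
//     auto embedding = torch::nn::Embedding::from_pretrained(
//         pretrained, torch::nn::EmbeddingFromPretrainedOptions().freeze(true));
//     // Look up rows 1, 4 and 7; the result has shape [3, 2].
//     auto ids = torch::tensor({1, 4, 7}, torch::kLong);
//     auto vectors = embedding->forward(ids);
//     return 0;
//   }
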
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ EmbeddingBag ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Computes sums or means of 'bags' of embeddings, without instantiating the
/// intermediate embeddings.
/// See https://pytorch.org/docs/main/nn.html#torch.nn.EmbeddingBag to learn
/// about the exact behavior of this module.
///
/// See the documentation for `torch::nn::EmbeddingBagOptions` class to learn
/// what constructor arguments are supported for this module.
///
/// Example:
/// ```
/// EmbeddingBag model(EmbeddingBagOptions(10, 2)
///                        .max_norm(2)
///                        .norm_type(2.5)
///                        .scale_grad_by_freq(true)
///                        .sparse(true)
///                        .mode(torch::kSum)
///                        .padding_idx(1));
/// ```
class TORCH_API EmbeddingBagImpl
    : public torch::nn::Cloneable<EmbeddingBagImpl> {
 public:
  EmbeddingBagImpl(int64_t num_embeddings, int64_t embedding_dim)
      : EmbeddingBagImpl(EmbeddingBagOptions(num_embeddings, embedding_dim)) {}
  explicit EmbeddingBagImpl(EmbeddingBagOptions options_);

  void reset() override;

  void reset_parameters();

  /// Pretty prints the `EmbeddingBag` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// The `Options` used to configure this `EmbeddingBag` module.
  EmbeddingBagOptions options;
  /// The embedding table.
  Tensor weight;

  Tensor forward(
      const Tensor& input,
      const Tensor& offsets = {},
      const Tensor& per_sample_weights = {});

 protected:
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())}, {2, AnyValue(Tensor())})
};

/// A `ModuleHolder` subclass for `EmbeddingBagImpl`.
/// See the documentation for `EmbeddingBagImpl` class to learn what methods it
/// provides, and examples of how to use `EmbeddingBag` with
/// `torch::nn::EmbeddingBagOptions`. See the documentation for `ModuleHolder`
/// to learn about PyTorch's module storage semantics.
class EmbeddingBag : public torch::nn::ModuleHolder<EmbeddingBagImpl> {
 public:
  using torch::nn::ModuleHolder<EmbeddingBagImpl>::ModuleHolder;

  /// See the documentation for `torch::nn::EmbeddingBagFromPretrainedOptions`
  /// class to learn what optional arguments are supported for this function.
  static EmbeddingBag from_pretrained(
      const torch::Tensor& embeddings,
      const EmbeddingBagFromPretrainedOptions& options = {}) {
    TORCH_CHECK(
        embeddings.dim() == 2,
        "Embeddings parameter is expected to be 2-dimensional");

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    int64_t rows, cols;
    rows = embeddings.size(0);
    cols = embeddings.size(1);

    EmbeddingBag embeddingbag(
        EmbeddingBagOptions(rows, cols)
            ._weight(embeddings)
            .max_norm(options.max_norm())
            .norm_type(options.norm_type())
            .scale_grad_by_freq(options.scale_grad_by_freq())
            .mode(options.mode())
            .sparse(options.sparse())
            .padding_idx(options.padding_idx()));
    embeddingbag->weight.set_requires_grad(!options.freeze());
    return embeddingbag;
  }
};

} // namespace nn
} // namespace torch
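
// A minimal usage sketch for `EmbeddingBag` (illustrative only, not part of
// this header; it assumes <torch/torch.h> is available in the including
// translation unit): two bags are formed from a flat index tensor via
// `offsets`, and each bag is reduced with a sum.
//
//   #include <torch/torch.h>
//
//   int main() {
//     auto pretrained = torch::rand({10, 3});
//     auto bag = torch::nn::EmbeddingBag::from_pretrained(
//         pretrained,
//         torch::nn::EmbeddingBagFromPretrainedOptions().mode(torch::kSum));
//     // Indices 1, 2, 4, 5 form the first bag and 4, 3, 2, 9 the second.
//     auto input = torch::tensor({1, 2, 4, 5, 4, 3, 2, 9}, torch::kLong);
//     auto offsets = torch::tensor({0, 4}, torch::kLong);
//     // One reduced embedding per bag: the result has shape [2, 3].
//     auto output = bag->forward(input, offsets);
//     return 0;
//   }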