#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/modules/common.h>
#include <torch/nn/modules/dropout.h>
#include <torch/nn/options/rnn.h>
#include <torch/nn/pimpl.h>
#include <torch/nn/utils/rnn.h>
#include <torch/types.h>

#include <ATen/ATen.h>
#include <c10/util/Exception.h>

#include <cstddef>
#include <functional>
#include <memory>
#include <vector>

namespace torch {
namespace nn {

namespace detail {
/// Base class for all RNN implementations (intended for code sharing).
template <typename Derived>
class TORCH_API RNNImplBase : public torch::nn::Cloneable<Derived> {
 public:
  explicit RNNImplBase(const RNNOptionsBase& options_);

  /// Initializes the parameters of the RNN module.
  void reset() override;

  /// Re-initializes the learnable parameters of the module.
  void reset_parameters();

  /// Overrides `nn::Module::to()` to call `flatten_parameters()` after the
  /// original operation.
  void to(torch::Device device, torch::Dtype dtype, bool non_blocking = false)
      override;
  void to(torch::Dtype dtype, bool non_blocking = false) override;
  void to(torch::Device device, bool non_blocking = false) override;

  /// Pretty prints the RNN module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Modifies the internal storage of weights for optimization purposes.
  ///
  /// On CPU, this method should be called if any of the weight or bias vectors
  /// are changed (i.e. weights are added or removed). On GPU, it should be
  /// called __any time the storage of any parameter is modified__, e.g. any
  /// time a parameter is assigned a new value. This allows using the fast path
  /// in cuDNN implementations of respective RNN `forward()` methods. It is
  /// called once upon construction, inside `reset()`.
  void flatten_parameters();

  /// Returns all learnable weight/bias tensors of the module as a single
  /// flat vector.
  std::vector<Tensor> all_weights() const;

  /// The RNN's options.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  RNNOptionsBase options_base;

 protected:
  // Resets flat_weights_
  // Note: be v. careful before removing this, as 3rd party device types
  // likely rely on this behavior to properly .to() modules like LSTM.
  void reset_flat_weights();

  /// Validates `input` (and `batch_sizes`, which is defined for packed
  /// sequences), throwing on mismatch.
  void check_input(const Tensor& input, const Tensor& batch_sizes) const;

  /// Computes the hidden-state size expected for `input`; the triple is
  /// (num_layers * num_directions, batch_size, hidden_size) — TODO confirm
  /// against the definition in rnn.cpp.
  std::tuple<int64_t, int64_t, int64_t> get_expected_hidden_size(
      const Tensor& input,
      const Tensor& batch_sizes) const;

  /// Throws unless `hx` matches `expected_hidden_size`; in `msg`, `{1}` is the
  /// expected size and `{2}` the actual one.
  void check_hidden_size(
      const Tensor& hx,
      std::tuple<int64_t, int64_t, int64_t> expected_hidden_size,
      std::string msg = "Expected hidden size {1}, got {2}") const;

  /// Validates all arguments passed to a `forward()` call.
  void check_forward_args(Tensor input, Tensor hidden, Tensor batch_sizes)
      const;

  /// Applies `permutation` to the hidden state `hx` (used when inputs are
  /// sorted for packed sequences).
  Tensor permute_hidden(Tensor hx, const Tensor& permutation) const;

  // Names of the flattened weight tensors, parallel to flat_weights_.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<std::string> flat_weights_names_;
  // Per-layer grouping of the weight names.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<std::vector<std::string>> all_weights_;
  // The flattened weight tensors themselves (see flatten_parameters()).
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<Tensor> flat_weights_;
};
} // namespace detail

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// A multi-layer Elman RNN module with Tanh or ReLU activation.
/// See https://pytorch.org/docs/main/generated/torch.nn.RNN.html to learn
/// about the exact behavior of this module.
///
/// See the documentation for `torch::nn::RNNOptions` class to learn what
/// constructor arguments are supported for this module.
///
/// Example:
/// ```
/// RNN model(RNNOptions(128,
/// 64).num_layers(3).dropout(0.2).nonlinearity(torch::kTanh));
/// ```
class TORCH_API RNNImpl : public detail::RNNImplBase<RNNImpl> {
 public:
  RNNImpl(int64_t input_size, int64_t hidden_size)
      : RNNImpl(RNNOptions(input_size, hidden_size)) {}
  explicit RNNImpl(const RNNOptions& options_);

  /// Runs the RNN over `input` with optional initial hidden state `hx`;
  /// returns a (output, hidden state) pair.
  std::tuple<Tensor, Tensor> forward(const Tensor& input, Tensor hx = {});

 protected:
  // `hx` (argument index 1) defaults to an undefined Tensor when called
  // through the type-erased AnyModule interface.
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())})

 public:
  /// Variant of `forward()` that accepts a `PackedSequence` input.
  std::tuple<torch::nn::utils::rnn::PackedSequence, Tensor>
  forward_with_packed_input(
      const torch::nn::utils::rnn::PackedSequence& packed_input,
      Tensor hx = {});

  /// The options with which this module was constructed.
  RNNOptions options;

 protected:
  /// Shared implementation behind both `forward()` overloads.
  std::tuple<Tensor, Tensor> forward_helper(
      const Tensor& input,
      const Tensor& batch_sizes,
      const Tensor& sorted_indices,
      int64_t max_batch_size,
      Tensor hx);
};

/// A `ModuleHolder` subclass for `RNNImpl`.
/// See the documentation for `RNNImpl` class to learn what methods it
/// provides, and examples of how to use `RNN` with `torch::nn::RNNOptions`.
/// See the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(RNN);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// A multi-layer long-short-term-memory (LSTM) module.
/// See https://pytorch.org/docs/main/generated/torch.nn.LSTM.html to learn
/// about the exact behavior of this module.
///
/// See the documentation for `torch::nn::LSTMOptions` class to learn what
/// constructor arguments are supported for this module.
///
/// Example:
/// ```
/// LSTM model(LSTMOptions(2,
/// 4).num_layers(3).batch_first(false).bidirectional(true));
/// ```
class TORCH_API LSTMImpl : public detail::RNNImplBase<LSTMImpl> {
 public:
  LSTMImpl(int64_t input_size, int64_t hidden_size)
      : LSTMImpl(LSTMOptions(input_size, hidden_size)) {}
  explicit LSTMImpl(const LSTMOptions& options_);

  /// Runs the LSTM over `input` with optional initial (hidden, cell) state
  /// pair `hx_opt`; returns (output, (hidden state, cell state)).
  std::tuple<Tensor, std::tuple<Tensor, Tensor>> forward(
      const Tensor& input,
      torch::optional<std::tuple<Tensor, Tensor>> hx_opt = {});

 protected:
  // `hx_opt` (argument index 1) defaults to an empty optional when called
  // through the type-erased AnyModule interface.
  FORWARD_HAS_DEFAULT_ARGS(
      {1, AnyValue(torch::optional<std::tuple<Tensor, Tensor>>())})

 public:
  /// Variant of `forward()` that accepts a `PackedSequence` input.
  std::tuple<torch::nn::utils::rnn::PackedSequence, std::tuple<Tensor, Tensor>>
  forward_with_packed_input(
      const torch::nn::utils::rnn::PackedSequence& packed_input,
      torch::optional<std::tuple<Tensor, Tensor>> hx_opt = {});

  /// The options with which this module was constructed.
  LSTMOptions options;

 protected:
  /// Validates all arguments passed to a `forward()` call (hides the
  /// single-Tensor-hidden version in `RNNImplBase`, since LSTM carries a
  /// (hidden, cell) pair).
  void check_forward_args(
      const Tensor& input,
      std::tuple<Tensor, Tensor> hidden,
      const Tensor& batch_sizes) const;

  /// Computes the cell-state size expected for `input`.
  std::tuple<int64_t, int64_t, int64_t> get_expected_cell_size(
      const Tensor& input,
      const Tensor& batch_sizes) const;

  /// Applies `permutation` to both elements of the (hidden, cell) state pair.
  std::tuple<Tensor, Tensor> permute_hidden(
      std::tuple<Tensor, Tensor> hx,
      const Tensor& permutation) const;

  /// Shared implementation behind both `forward()` overloads.
  std::tuple<Tensor, std::tuple<Tensor, Tensor>> forward_helper(
      const Tensor& input,
      const Tensor& batch_sizes,
      const Tensor& sorted_indices,
      int64_t max_batch_size,
      torch::optional<std::tuple<Tensor, Tensor>> hx_opt);
};

/// A `ModuleHolder` subclass for `LSTMImpl`.
/// See the documentation for `LSTMImpl` class to learn what methods it
/// provides, and examples of how to use `LSTM` with `torch::nn::LSTMOptions`.
/// See the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
203 TORCH_MODULE(LSTM); 204 205 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 206 207 /// A multi-layer gated recurrent unit (GRU) module. 208 /// See https://pytorch.org/docs/main/generated/torch.nn.GRU.html to learn 209 /// about the exact behavior of this module. 210 /// 211 /// See the documentation for `torch::nn::GRUOptions` class to learn what 212 /// constructor arguments are supported for this module. 213 /// 214 /// Example: 215 /// ``` 216 /// GRU model(GRUOptions(2, 217 /// 4).num_layers(3).batch_first(false).bidirectional(true)); 218 /// ``` 219 class TORCH_API GRUImpl : public detail::RNNImplBase<GRUImpl> { 220 public: GRUImpl(int64_t input_size,int64_t hidden_size)221 GRUImpl(int64_t input_size, int64_t hidden_size) 222 : GRUImpl(GRUOptions(input_size, hidden_size)) {} 223 explicit GRUImpl(const GRUOptions& options_); 224 225 std::tuple<Tensor, Tensor> forward(const Tensor& input, Tensor hx = {}); 226 227 protected: 228 FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(torch::Tensor())}) 229 230 public: 231 std::tuple<torch::nn::utils::rnn::PackedSequence, Tensor> 232 forward_with_packed_input( 233 const torch::nn::utils::rnn::PackedSequence& packed_input, 234 Tensor hx = {}); 235 236 GRUOptions options; 237 238 protected: 239 std::tuple<Tensor, Tensor> forward_helper( 240 const Tensor& input, 241 const Tensor& batch_sizes, 242 const Tensor& sorted_indices, 243 int64_t max_batch_size, 244 Tensor hx); 245 }; 246 247 /// A `ModuleHolder` subclass for `GRUImpl`. 248 /// See the documentation for `GRUImpl` class to learn what methods it 249 /// provides, and examples of how to use `GRU` with `torch::nn::GRUOptions`. 250 /// See the documentation for `ModuleHolder` to learn about PyTorch's 251 /// module storage semantics. 
TORCH_MODULE(GRU);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCellImplBase
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

namespace detail {
/// Base class for all RNNCell implementations (intended for code sharing).
template <typename Derived>
class TORCH_API RNNCellImplBase : public torch::nn::Cloneable<Derived> {
 public:
  explicit RNNCellImplBase(const RNNCellOptionsBase& options_);

  /// Initializes the parameters of the RNNCell module.
  void reset() override;

  /// Re-initializes the learnable parameters of the module.
  void reset_parameters();

  /// Pretty prints the RNN module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// The cell's options.
  RNNCellOptionsBase options_base;

  /// Learnable input-hidden weights.
  Tensor weight_ih;
  /// Learnable hidden-hidden weights.
  Tensor weight_hh;
  /// Learnable input-hidden bias.
  Tensor bias_ih;
  /// Learnable hidden-hidden bias.
  Tensor bias_hh;

 protected:
  /// Validates the `forward()` input tensor; `name` identifies the argument
  /// in error messages.
  void check_forward_input(const Tensor& input, const std::string& name) const;
  /// Returns the nonlinearity as a string for printing (overridden by cells
  /// that have a configurable nonlinearity, e.g. RNNCell).
  virtual std::string get_nonlinearity_str() const;
};
} // namespace detail

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCell
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// An Elman RNN cell with tanh or ReLU non-linearity.
/// See https://pytorch.org/docs/main/nn.html#torch.nn.RNNCell to learn
/// about the exact behavior of this module.
///
/// See the documentation for `torch::nn::RNNCellOptions` class to learn what
/// constructor arguments are supported for this module.
///
/// Example:
/// ```
/// RNNCell model(RNNCellOptions(20,
/// 10).bias(false).nonlinearity(torch::kReLU));
/// ```
class TORCH_API RNNCellImpl : public detail::RNNCellImplBase<RNNCellImpl> {
 public:
  RNNCellImpl(int64_t input_size, int64_t hidden_size)
      : RNNCellImpl(RNNCellOptions(input_size, hidden_size)) {}
  explicit RNNCellImpl(const RNNCellOptions& options_);

  /// Performs a single step: computes the next hidden state from `input` and
  /// the optional previous hidden state `hx`.
  Tensor forward(const Tensor& input, Tensor hx = {});

 protected:
  // `hx` (argument index 1) defaults to an undefined Tensor when called
  // through the type-erased AnyModule interface.
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())})

 public:
  /// The options with which this module was constructed.
  RNNCellOptions options;

 protected:
  /// Returns the configured nonlinearity ("tanh" or "relu" style) for
  /// pretty-printing — TODO confirm exact strings against rnn.cpp.
  std::string get_nonlinearity_str() const override;
};

/// A `ModuleHolder` subclass for `RNNCellImpl`.
/// See the documentation for `RNNCellImpl` class to learn what methods it
/// provides, and examples of how to use `RNNCell` with
/// `torch::nn::RNNCellOptions`. See the documentation for `ModuleHolder` to
/// learn about PyTorch's module storage semantics.
TORCH_MODULE(RNNCell);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTMCell
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// A long short-term memory (LSTM) cell.
/// See https://pytorch.org/docs/main/nn.html#torch.nn.LSTMCell to learn
/// about the exact behavior of this module.
///
/// See the documentation for `torch::nn::LSTMCellOptions` class to learn what
/// constructor arguments are supported for this module.
///
/// Example:
/// ```
/// LSTMCell model(LSTMCellOptions(20, 10).bias(false));
/// ```
class TORCH_API LSTMCellImpl : public detail::RNNCellImplBase<LSTMCellImpl> {
 public:
  LSTMCellImpl(int64_t input_size, int64_t hidden_size)
      : LSTMCellImpl(LSTMCellOptions(input_size, hidden_size)) {}
  explicit LSTMCellImpl(const LSTMCellOptions& options_);

  /// Performs a single step: computes the next (hidden, cell) state pair from
  /// `input` and the optional previous state pair `hx_opt`.
  std::tuple<Tensor, Tensor> forward(
      const Tensor& input,
      torch::optional<std::tuple<Tensor, Tensor>> hx_opt = {});

 protected:
  // `hx_opt` (argument index 1) defaults to an empty optional when called
  // through the type-erased AnyModule interface.
  FORWARD_HAS_DEFAULT_ARGS(
      {1, AnyValue(torch::optional<std::tuple<Tensor, Tensor>>())})

 public:
  /// The options with which this module was constructed.
  LSTMCellOptions options;
};

/// A `ModuleHolder` subclass for `LSTMCellImpl`.
/// See the documentation for `LSTMCellImpl` class to learn what methods it
/// provides, and examples of how to use `LSTMCell` with
/// `torch::nn::LSTMCellOptions`. See the documentation for `ModuleHolder` to
/// learn about PyTorch's module storage semantics.
TORCH_MODULE(LSTMCell);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRUCell
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// A gated recurrent unit (GRU) cell.
/// See https://pytorch.org/docs/main/nn.html#torch.nn.GRUCell to learn
/// about the exact behavior of this module.
///
/// See the documentation for `torch::nn::GRUCellOptions` class to learn what
/// constructor arguments are supported for this module.
///
/// Example:
/// ```
/// GRUCell model(GRUCellOptions(20, 10).bias(false));
/// ```
class TORCH_API GRUCellImpl : public detail::RNNCellImplBase<GRUCellImpl> {
 public:
  GRUCellImpl(int64_t input_size, int64_t hidden_size)
      : GRUCellImpl(GRUCellOptions(input_size, hidden_size)) {}
  explicit GRUCellImpl(const GRUCellOptions& options_);

  /// Performs a single step: computes the next hidden state from `input` and
  /// the optional previous hidden state `hx`.
  Tensor forward(const Tensor& input, Tensor hx = {});

 protected:
  // `hx` (argument index 1) defaults to an undefined Tensor when called
  // through the type-erased AnyModule interface.
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())})

 public:
  /// The options with which this module was constructed.
  GRUCellOptions options;
};

/// A `ModuleHolder` subclass for `GRUCellImpl`.
/// See the documentation for `GRUCellImpl` class to learn what methods it
/// provides, and examples of how to use `GRUCell` with
/// `torch::nn::GRUCellOptions`. See the documentation for `ModuleHolder` to
/// learn about PyTorch's module storage semantics.
TORCH_MODULE(GRUCell);

} // namespace nn
} // namespace torch