#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/modules/common.h>
#include <torch/nn/modules/dropout.h>
#include <torch/nn/options/rnn.h>
#include <torch/nn/pimpl.h>
#include <torch/nn/utils/rnn.h>
#include <torch/types.h>

#include <ATen/ATen.h>
#include <c10/util/Exception.h>

#include <cstddef>
#include <functional>
#include <memory>
#include <vector>

namespace torch {
namespace nn {

namespace detail {
/// Base class for all RNN implementations (intended for code sharing).
template <typename Derived>
class TORCH_API RNNImplBase : public torch::nn::Cloneable<Derived> {
 public:
  explicit RNNImplBase(const RNNOptionsBase& options_);

  /// Initializes the parameters of the RNN module.
  void reset() override;

  void reset_parameters();

  /// Overrides `nn::Module::to()` to call `flatten_parameters()` after the
  /// original operation.
  void to(torch::Device device, torch::Dtype dtype, bool non_blocking = false)
      override;
  void to(torch::Dtype dtype, bool non_blocking = false) override;
  void to(torch::Device device, bool non_blocking = false) override;

  /// Pretty prints the RNN module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Modifies the internal storage of weights for optimization purposes.
  ///
  /// On CPU, this method should be called if any of the weight or bias vectors
  /// are changed (i.e. weights are added or removed). On GPU, it should be
  /// called __any time the storage of any parameter is modified__, e.g. any
  /// time a parameter is assigned a new value. This allows using the fast path
  /// in cuDNN implementations of the respective RNN `forward()` methods. It
  /// is called once upon construction, inside `reset()`.
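  ///
  /// A usage sketch (an editorial addition, not upstream documentation;
  /// assumes a CUDA build so the cuDNN fast path applies):
  /// ```
  /// torch::nn::LSTM lstm(torch::nn::LSTMOptions(10, 20));
  /// lstm->to(torch::kCUDA); // `to()` re-flattens automatically (see above).
  /// // After replacing a parameter's storage by hand, re-flatten explicitly
  /// // so cuDNN can again use its fused kernels:
  /// lstm->named_parameters()["weight_ih_l0"].set_data(
  ///     torch::randn({80, 10}, torch::kCUDA)); // (4 * hidden, input)
  /// lstm->flatten_parameters();
  /// ```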
  void flatten_parameters();

  std::vector<Tensor> all_weights() const;

  /// The RNN's options.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  RNNOptionsBase options_base;

 protected:
  // Resets flat_weights_
  // Note: be very careful before removing this, as third-party device types
  // likely rely on this behavior to properly `.to()` modules like LSTM.
  void reset_flat_weights();

  void check_input(const Tensor& input, const Tensor& batch_sizes) const;

  std::tuple<int64_t, int64_t, int64_t> get_expected_hidden_size(
      const Tensor& input,
      const Tensor& batch_sizes) const;

  void check_hidden_size(
      const Tensor& hx,
      std::tuple<int64_t, int64_t, int64_t> expected_hidden_size,
      std::string msg = "Expected hidden size {1}, got {2}") const;

  void check_forward_args(Tensor input, Tensor hidden, Tensor batch_sizes)
      const;

  Tensor permute_hidden(Tensor hx, const Tensor& permutation) const;

  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<std::string> flat_weights_names_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<std::vector<std::string>> all_weights_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<Tensor> flat_weights_;
};
} // namespace detail

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// A multi-layer Elman RNN module with Tanh or ReLU activation.
/// See https://pytorch.org/docs/main/generated/torch.nn.RNN.html to learn
/// about the exact behavior of this module.
///
/// See the documentation for `torch::nn::RNNOptions` class to learn what
/// constructor arguments are supported for this module.
///
/// Example:
/// ```
/// RNN model(RNNOptions(128, 64)
///     .num_layers(3).dropout(0.2).nonlinearity(torch::kTanh));
/// ```
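///
/// A forward-pass sketch (an editorial addition, not upstream documentation;
/// shapes assume the default (seq_len, batch, input_size) layout):
/// ```
/// RNN model(RNNOptions(128, 64).num_layers(3));
/// auto input = torch::randn({5, 3, 128}); // (seq_len, batch, input_size)
/// auto [output, h_n] = model->forward(input);
/// // output: (5, 3, 64); h_n: (3, 3, 64) = (num_layers, batch, hidden_size)
/// ```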
class TORCH_API RNNImpl : public detail::RNNImplBase<RNNImpl> {
 public:
  RNNImpl(int64_t input_size, int64_t hidden_size)
      : RNNImpl(RNNOptions(input_size, hidden_size)) {}
  explicit RNNImpl(const RNNOptions& options_);

  std::tuple<Tensor, Tensor> forward(const Tensor& input, Tensor hx = {});

 protected:
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())})

 public:
  std::tuple<torch::nn::utils::rnn::PackedSequence, Tensor>
  forward_with_packed_input(
      const torch::nn::utils::rnn::PackedSequence& packed_input,
      Tensor hx = {});

  RNNOptions options;

 protected:
  std::tuple<Tensor, Tensor> forward_helper(
      const Tensor& input,
      const Tensor& batch_sizes,
      const Tensor& sorted_indices,
      int64_t max_batch_size,
      Tensor hx);
};

/// A `ModuleHolder` subclass for `RNNImpl`.
/// See the documentation for `RNNImpl` class to learn what methods it
/// provides, and examples of how to use `RNN` with `torch::nn::RNNOptions`.
/// See the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(RNN);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// A multi-layer long short-term memory (LSTM) module.
/// See https://pytorch.org/docs/main/generated/torch.nn.LSTM.html to learn
/// about the exact behavior of this module.
///
/// See the documentation for `torch::nn::LSTMOptions` class to learn what
/// constructor arguments are supported for this module.
///
/// Example:
/// ```
/// LSTM model(LSTMOptions(2, 4)
///     .num_layers(3).batch_first(false).bidirectional(true));
/// ```
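///
/// A forward-pass sketch (an editorial addition, not upstream documentation;
/// note that a bidirectional model doubles the output feature dimension):
/// ```
/// LSTM model(LSTMOptions(2, 4).num_layers(3).bidirectional(true));
/// auto input = torch::randn({10, 8, 2}); // (seq_len, batch, input_size)
/// auto [output, state] = model->forward(input);
/// auto [h_n, c_n] = state;
/// // output: (10, 8, 8); h_n and c_n: (6, 8, 4)
/// //   = (num_layers * num_directions, batch, hidden_size)
/// ```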
class TORCH_API LSTMImpl : public detail::RNNImplBase<LSTMImpl> {
 public:
  LSTMImpl(int64_t input_size, int64_t hidden_size)
      : LSTMImpl(LSTMOptions(input_size, hidden_size)) {}
  explicit LSTMImpl(const LSTMOptions& options_);

  std::tuple<Tensor, std::tuple<Tensor, Tensor>> forward(
      const Tensor& input,
      torch::optional<std::tuple<Tensor, Tensor>> hx_opt = {});

 protected:
  FORWARD_HAS_DEFAULT_ARGS(
      {1, AnyValue(torch::optional<std::tuple<Tensor, Tensor>>())})

 public:
  std::tuple<torch::nn::utils::rnn::PackedSequence, std::tuple<Tensor, Tensor>>
  forward_with_packed_input(
      const torch::nn::utils::rnn::PackedSequence& packed_input,
      torch::optional<std::tuple<Tensor, Tensor>> hx_opt = {});

  LSTMOptions options;

 protected:
  void check_forward_args(
      const Tensor& input,
      std::tuple<Tensor, Tensor> hidden,
      const Tensor& batch_sizes) const;

  std::tuple<int64_t, int64_t, int64_t> get_expected_cell_size(
      const Tensor& input,
      const Tensor& batch_sizes) const;

  std::tuple<Tensor, Tensor> permute_hidden(
      std::tuple<Tensor, Tensor> hx,
      const Tensor& permutation) const;

  std::tuple<Tensor, std::tuple<Tensor, Tensor>> forward_helper(
      const Tensor& input,
      const Tensor& batch_sizes,
      const Tensor& sorted_indices,
      int64_t max_batch_size,
      torch::optional<std::tuple<Tensor, Tensor>> hx_opt);
};

/// A `ModuleHolder` subclass for `LSTMImpl`.
/// See the documentation for `LSTMImpl` class to learn what methods it
/// provides, and examples of how to use `LSTM` with `torch::nn::LSTMOptions`.
/// See the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(LSTM);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// A multi-layer gated recurrent unit (GRU) module.
/// See https://pytorch.org/docs/main/generated/torch.nn.GRU.html to learn
/// about the exact behavior of this module.
///
/// See the documentation for `torch::nn::GRUOptions` class to learn what
/// constructor arguments are supported for this module.
///
/// Example:
/// ```
/// GRU model(GRUOptions(2, 4)
///     .num_layers(3).batch_first(false).bidirectional(true));
/// ```
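///
/// A sketch of the packed-sequence path (an editorial addition; assumes the
/// `torch::nn::utils::rnn` helpers declared in <torch/nn/utils/rnn.h> with
/// their default `batch_first=false` layout):
/// ```
/// GRU model(GRUOptions(2, 4).num_layers(3));
/// auto input = torch::randn({10, 8, 2}); // (seq_len, batch, input_size)
/// auto lengths = torch::full({8}, 10, torch::kInt64); // per-batch lengths
/// auto packed = torch::nn::utils::rnn::pack_padded_sequence(input, lengths);
/// auto [packed_output, h_n] = model->forward_with_packed_input(packed);
/// ```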
class TORCH_API GRUImpl : public detail::RNNImplBase<GRUImpl> {
 public:
  GRUImpl(int64_t input_size, int64_t hidden_size)
      : GRUImpl(GRUOptions(input_size, hidden_size)) {}
  explicit GRUImpl(const GRUOptions& options_);

  std::tuple<Tensor, Tensor> forward(const Tensor& input, Tensor hx = {});

 protected:
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(torch::Tensor())})

 public:
  std::tuple<torch::nn::utils::rnn::PackedSequence, Tensor>
  forward_with_packed_input(
      const torch::nn::utils::rnn::PackedSequence& packed_input,
      Tensor hx = {});

  GRUOptions options;

 protected:
  std::tuple<Tensor, Tensor> forward_helper(
      const Tensor& input,
      const Tensor& batch_sizes,
      const Tensor& sorted_indices,
      int64_t max_batch_size,
      Tensor hx);
};

/// A `ModuleHolder` subclass for `GRUImpl`.
/// See the documentation for `GRUImpl` class to learn what methods it
/// provides, and examples of how to use `GRU` with `torch::nn::GRUOptions`.
/// See the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(GRU);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCellImplBase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

namespace detail {
/// Base class for all RNNCell implementations (intended for code sharing).
template <typename Derived>
class TORCH_API RNNCellImplBase : public torch::nn::Cloneable<Derived> {
 public:
  explicit RNNCellImplBase(const RNNCellOptionsBase& options_);

  /// Initializes the parameters of the RNNCell module.
  void reset() override;

  void reset_parameters();

  /// Pretty prints the RNNCell module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;

  RNNCellOptionsBase options_base;

  Tensor weight_ih;
  Tensor weight_hh;
  Tensor bias_ih;
  Tensor bias_hh;

 protected:
  void check_forward_input(const Tensor& input, const std::string& name) const;
  virtual std::string get_nonlinearity_str() const;
};
} // namespace detail

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCell ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// An Elman RNN cell with tanh or ReLU non-linearity.
/// See https://pytorch.org/docs/main/nn.html#torch.nn.RNNCell to learn
/// about the exact behavior of this module.
///
/// See the documentation for `torch::nn::RNNCellOptions` class to learn what
/// constructor arguments are supported for this module.
///
/// Example:
/// ```
/// RNNCell model(RNNCellOptions(20, 10)
///     .bias(false).nonlinearity(torch::kReLU));
/// ```
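///
/// A step-by-step sketch (an editorial addition; an undefined `hx` is
/// treated as an all-zeros hidden state, matching the Python module):
/// ```
/// RNNCell model(RNNCellOptions(20, 10));
/// auto input = torch::randn({6, 3, 20}); // (time_steps, batch, input_size)
/// torch::Tensor hx;
/// for (int64_t t = 0; t < input.size(0); ++t) {
///   hx = model->forward(input[t], hx); // hx: (batch, hidden_size) = (3, 10)
/// }
/// ```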
class TORCH_API RNNCellImpl : public detail::RNNCellImplBase<RNNCellImpl> {
 public:
  RNNCellImpl(int64_t input_size, int64_t hidden_size)
      : RNNCellImpl(RNNCellOptions(input_size, hidden_size)) {}
  explicit RNNCellImpl(const RNNCellOptions& options_);

  Tensor forward(const Tensor& input, Tensor hx = {});

 protected:
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())})

 public:
  RNNCellOptions options;

 protected:
  std::string get_nonlinearity_str() const override;
};

/// A `ModuleHolder` subclass for `RNNCellImpl`.
/// See the documentation for `RNNCellImpl` class to learn what methods it
/// provides, and examples of how to use `RNNCell` with
/// `torch::nn::RNNCellOptions`. See the documentation for `ModuleHolder` to
/// learn about PyTorch's module storage semantics.
TORCH_MODULE(RNNCell);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTMCell ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// A long short-term memory (LSTM) cell.
/// See https://pytorch.org/docs/main/nn.html#torch.nn.LSTMCell to learn
/// about the exact behavior of this module.
///
/// See the documentation for `torch::nn::LSTMCellOptions` class to learn what
/// constructor arguments are supported for this module.
///
/// Example:
/// ```
/// LSTMCell model(LSTMCellOptions(20, 10).bias(false));
/// ```
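///
/// A step-by-step sketch (an editorial addition; the optional state is a
/// (hidden, cell) pair and defaults to zeros when left empty):
/// ```
/// LSTMCell model(LSTMCellOptions(20, 10));
/// auto input = torch::randn({3, 20});  // (batch, input_size)
/// auto state = model->forward(input);  // zero-initialized (h, c) state
/// state = model->forward(input, state);
/// // std::get<0>(state) and std::get<1>(state): (batch, hidden) = (3, 10)
/// ```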
class TORCH_API LSTMCellImpl : public detail::RNNCellImplBase<LSTMCellImpl> {
 public:
  LSTMCellImpl(int64_t input_size, int64_t hidden_size)
      : LSTMCellImpl(LSTMCellOptions(input_size, hidden_size)) {}
  explicit LSTMCellImpl(const LSTMCellOptions& options_);

  std::tuple<Tensor, Tensor> forward(
      const Tensor& input,
      torch::optional<std::tuple<Tensor, Tensor>> hx_opt = {});

 protected:
  FORWARD_HAS_DEFAULT_ARGS(
      {1, AnyValue(torch::optional<std::tuple<Tensor, Tensor>>())})

 public:
  LSTMCellOptions options;
};

/// A `ModuleHolder` subclass for `LSTMCellImpl`.
/// See the documentation for `LSTMCellImpl` class to learn what methods it
/// provides, and examples of how to use `LSTMCell` with
/// `torch::nn::LSTMCellOptions`. See the documentation for `ModuleHolder` to
/// learn about PyTorch's module storage semantics.
TORCH_MODULE(LSTMCell);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRUCell ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// A gated recurrent unit (GRU) cell.
/// See https://pytorch.org/docs/main/nn.html#torch.nn.GRUCell to learn
/// about the exact behavior of this module.
///
/// See the documentation for `torch::nn::GRUCellOptions` class to learn what
/// constructor arguments are supported for this module.
///
/// Example:
/// ```
/// GRUCell model(GRUCellOptions(20, 10).bias(false));
/// ```
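///
/// A step-by-step sketch (an editorial addition; as with `RNNCell`, an
/// undefined `hx` is treated as an all-zeros hidden state):
/// ```
/// GRUCell model(GRUCellOptions(20, 10));
/// auto input = torch::randn({3, 20}); // (batch, input_size)
/// auto hx = model->forward(input);    // hx: (batch, hidden_size) = (3, 10)
/// hx = model->forward(input, hx);     // feed the state back in
/// ```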
class TORCH_API GRUCellImpl : public detail::RNNCellImplBase<GRUCellImpl> {
 public:
  GRUCellImpl(int64_t input_size, int64_t hidden_size)
      : GRUCellImpl(GRUCellOptions(input_size, hidden_size)) {}
  explicit GRUCellImpl(const GRUCellOptions& options_);

  Tensor forward(const Tensor& input, Tensor hx = {});

 protected:
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())})

 public:
  GRUCellOptions options;
};

/// A `ModuleHolder` subclass for `GRUCellImpl`.
/// See the documentation for `GRUCellImpl` class to learn what methods it
/// provides, and examples of how to use `GRUCell` with
/// `torch::nn::GRUCellOptions`. See the documentation for `ModuleHolder` to
/// learn about PyTorch's module storage semantics.
TORCH_MODULE(GRUCell);

} // namespace nn
} // namespace torch