xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/options/rnn.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/nn/options/rnn.h>
2 
3 namespace torch {
4 namespace nn {
5 
6 namespace detail {
7 
RNNOptionsBase(rnn_options_base_mode_t mode,int64_t input_size,int64_t hidden_size)8 RNNOptionsBase::RNNOptionsBase(
9     rnn_options_base_mode_t mode,
10     int64_t input_size,
11     int64_t hidden_size)
12     : mode_(mode), input_size_(input_size), hidden_size_(hidden_size) {}
13 
14 } // namespace detail
15 
RNNOptions(int64_t input_size,int64_t hidden_size)16 RNNOptions::RNNOptions(int64_t input_size, int64_t hidden_size)
17     : input_size_(input_size), hidden_size_(hidden_size) {}
18 
LSTMOptions(int64_t input_size,int64_t hidden_size)19 LSTMOptions::LSTMOptions(int64_t input_size, int64_t hidden_size)
20     : input_size_(input_size), hidden_size_(hidden_size) {}
21 
GRUOptions(int64_t input_size,int64_t hidden_size)22 GRUOptions::GRUOptions(int64_t input_size, int64_t hidden_size)
23     : input_size_(input_size), hidden_size_(hidden_size) {}
24 
25 namespace detail {
26 
RNNCellOptionsBase(int64_t input_size,int64_t hidden_size,bool bias,int64_t num_chunks)27 RNNCellOptionsBase::RNNCellOptionsBase(
28     int64_t input_size,
29     int64_t hidden_size,
30     bool bias,
31     int64_t num_chunks)
32     : input_size_(input_size),
33       hidden_size_(hidden_size),
34       bias_(bias),
35       num_chunks_(num_chunks) {}
36 
37 } // namespace detail
38 
RNNCellOptions(int64_t input_size,int64_t hidden_size)39 RNNCellOptions::RNNCellOptions(int64_t input_size, int64_t hidden_size)
40     : input_size_(input_size), hidden_size_(hidden_size) {}
41 
LSTMCellOptions(int64_t input_size,int64_t hidden_size)42 LSTMCellOptions::LSTMCellOptions(int64_t input_size, int64_t hidden_size)
43     : input_size_(input_size), hidden_size_(hidden_size) {}
44 
GRUCellOptions(int64_t input_size,int64_t hidden_size)45 GRUCellOptions::GRUCellOptions(int64_t input_size, int64_t hidden_size)
46     : input_size_(input_size), hidden_size_(hidden_size) {}
47 
48 } // namespace nn
49 } // namespace torch
50