#include <torch/nn/options/transformer.h>
#include <torch/nn/options/transformercoder.h>
#include <torch/nn/options/transformerlayer.h>

#include <utility>
4
5 namespace torch {
6 namespace nn {
7
TransformerEncoderLayerOptions(int64_t d_model,int64_t nhead)8 TransformerEncoderLayerOptions::TransformerEncoderLayerOptions(
9 int64_t d_model,
10 int64_t nhead)
11 : d_model_(d_model), nhead_(nhead) {}
12
TransformerDecoderLayerOptions(int64_t d_model,int64_t nhead)13 TransformerDecoderLayerOptions::TransformerDecoderLayerOptions(
14 int64_t d_model,
15 int64_t nhead)
16 : d_model_(d_model), nhead_(nhead) {}
17
TransformerEncoderOptions(TransformerEncoderLayer encoder_layer,int64_t num_layers)18 TransformerEncoderOptions::TransformerEncoderOptions(
19 TransformerEncoderLayer encoder_layer,
20 int64_t num_layers)
21 : encoder_layer_(std::move(encoder_layer)), num_layers_(num_layers) {}
22
TransformerEncoderOptions(const TransformerEncoderLayerOptions & encoder_layer_options,int64_t num_layers)23 TransformerEncoderOptions::TransformerEncoderOptions(
24 const TransformerEncoderLayerOptions& encoder_layer_options,
25 int64_t num_layers)
26 : encoder_layer_(encoder_layer_options), num_layers_(num_layers) {}
27
TransformerDecoderOptions(TransformerDecoderLayer decoder_layer,int64_t num_layers)28 TransformerDecoderOptions::TransformerDecoderOptions(
29 TransformerDecoderLayer decoder_layer,
30 int64_t num_layers)
31 : decoder_layer_(std::move(decoder_layer)), num_layers_(num_layers) {}
32
TransformerDecoderOptions(const TransformerDecoderLayerOptions & decoder_layer_options,int64_t num_layers)33 TransformerDecoderOptions::TransformerDecoderOptions(
34 const TransformerDecoderLayerOptions& decoder_layer_options,
35 int64_t num_layers)
36 : decoder_layer_(decoder_layer_options), num_layers_(num_layers) {}
37
TransformerOptions(int64_t d_model,int64_t nhead)38 TransformerOptions::TransformerOptions(int64_t d_model, int64_t nhead)
39 : d_model_(d_model), nhead_(nhead) {}
40
TransformerOptions(int64_t d_model,int64_t nhead,int64_t num_encoder_layers,int64_t num_decoder_layers)41 TransformerOptions::TransformerOptions(
42 int64_t d_model,
43 int64_t nhead,
44 int64_t num_encoder_layers,
45 int64_t num_decoder_layers)
46 : d_model_(d_model),
47 nhead_(nhead),
48 num_encoder_layers_(num_encoder_layers),
49 num_decoder_layers_(num_decoder_layers) {}
50
51 } // namespace nn
52 } // namespace torch
53