// xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/options/transformer.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/nn/options/transformer.h>
#include <torch/nn/options/transformercoder.h>
#include <torch/nn/options/transformerlayer.h>

#include <cstdint>
#include <utility>

5 namespace torch {
6 namespace nn {
7 
TransformerEncoderLayerOptions(int64_t d_model,int64_t nhead)8 TransformerEncoderLayerOptions::TransformerEncoderLayerOptions(
9     int64_t d_model,
10     int64_t nhead)
11     : d_model_(d_model), nhead_(nhead) {}
12 
TransformerDecoderLayerOptions(int64_t d_model,int64_t nhead)13 TransformerDecoderLayerOptions::TransformerDecoderLayerOptions(
14     int64_t d_model,
15     int64_t nhead)
16     : d_model_(d_model), nhead_(nhead) {}
17 
TransformerEncoderOptions(TransformerEncoderLayer encoder_layer,int64_t num_layers)18 TransformerEncoderOptions::TransformerEncoderOptions(
19     TransformerEncoderLayer encoder_layer,
20     int64_t num_layers)
21     : encoder_layer_(std::move(encoder_layer)), num_layers_(num_layers) {}
22 
TransformerEncoderOptions(const TransformerEncoderLayerOptions & encoder_layer_options,int64_t num_layers)23 TransformerEncoderOptions::TransformerEncoderOptions(
24     const TransformerEncoderLayerOptions& encoder_layer_options,
25     int64_t num_layers)
26     : encoder_layer_(encoder_layer_options), num_layers_(num_layers) {}
27 
TransformerDecoderOptions(TransformerDecoderLayer decoder_layer,int64_t num_layers)28 TransformerDecoderOptions::TransformerDecoderOptions(
29     TransformerDecoderLayer decoder_layer,
30     int64_t num_layers)
31     : decoder_layer_(std::move(decoder_layer)), num_layers_(num_layers) {}
32 
TransformerDecoderOptions(const TransformerDecoderLayerOptions & decoder_layer_options,int64_t num_layers)33 TransformerDecoderOptions::TransformerDecoderOptions(
34     const TransformerDecoderLayerOptions& decoder_layer_options,
35     int64_t num_layers)
36     : decoder_layer_(decoder_layer_options), num_layers_(num_layers) {}
37 
TransformerOptions(int64_t d_model,int64_t nhead)38 TransformerOptions::TransformerOptions(int64_t d_model, int64_t nhead)
39     : d_model_(d_model), nhead_(nhead) {}
40 
TransformerOptions(int64_t d_model,int64_t nhead,int64_t num_encoder_layers,int64_t num_decoder_layers)41 TransformerOptions::TransformerOptions(
42     int64_t d_model,
43     int64_t nhead,
44     int64_t num_encoder_layers,
45     int64_t num_decoder_layers)
46     : d_model_(d_model),
47       nhead_(nhead),
48       num_encoder_layers_(num_encoder_layers),
49       num_decoder_layers_(num_decoder_layers) {}
50 
51 } // namespace nn
52 } // namespace torch
53