1 #pragma once 2 3 #include <torch/arg.h> 4 #include <torch/csrc/Export.h> 5 #include <torch/enum.h> 6 #include <torch/types.h> 7 8 #include <torch/nn/modules/container/any.h> 9 #include <torch/nn/options/transformerlayer.h> 10 11 namespace torch { 12 namespace nn { 13 14 /// Options for the `Transformer` module 15 /// 16 /// Example: 17 /// ``` 18 /// TransformerOptions options; 19 /// TransformerOptions options(16, 4); 20 /// auto options = TransformerOptions().d_model(4).nhead(2).dropout(0.0); 21 /// ``` 22 struct TORCH_API TransformerOptions { 23 // The following constructors are commonly used 24 // Please don't add more unless it is proved as a common usage 25 TransformerOptions() = default; 26 TransformerOptions(int64_t d_model, int64_t nhead); 27 TransformerOptions( 28 int64_t d_model, 29 int64_t nhead, 30 int64_t num_encoder_layers, 31 int64_t num_decoder_layers); 32 33 /// the number of expected features in the encoder/decoder inputs 34 /// (default=512) 35 TORCH_ARG(int64_t, d_model) = 512; 36 37 /// the number of heads in the multiheadattention models (default=8) 38 TORCH_ARG(int64_t, nhead) = 8; 39 40 /// the number of sub-encoder-layers in the encoder (default=6) 41 TORCH_ARG(int64_t, num_encoder_layers) = 6; 42 43 /// the number of sub-decoder-layers in the decoder (default=6) 44 TORCH_ARG(int64_t, num_decoder_layers) = 6; 45 46 /// the dimension of the feedforward network model (default=2048) 47 TORCH_ARG(int64_t, dim_feedforward) = 2048; 48 49 /// the dropout value (default=0.1) 50 TORCH_ARG(double, dropout) = 0.1; 51 52 /// the activation function of encoder/decoder intermediate layer 53 /// (default=``torch::kReLU``) 54 TORCH_ARG(activation_t, activation) = torch::kReLU; 55 56 /// custom encoder (default=None) 57 TORCH_ARG(AnyModule, custom_encoder); 58 59 /// custom decoder (default=None) 60 TORCH_ARG(AnyModule, custom_decoder); 61 }; 62 63 } // namespace nn 64 } // namespace torch 65