1 #pragma once 2 3 #include <torch/arg.h> 4 #include <torch/csrc/Export.h> 5 #include <torch/enum.h> 6 #include <torch/types.h> 7 8 namespace torch { 9 namespace nn { 10 11 using activation_t = std::variant< 12 enumtype::kReLU, 13 enumtype::kGELU, 14 std::function<Tensor(const Tensor&)>>; 15 16 /// Options for the `TransformerEncoderLayer` 17 /// 18 /// Example: 19 /// ``` 20 /// auto options = TransformerEncoderLayer(512, 8).dropout(0.2); 21 /// ``` 22 struct TORCH_API TransformerEncoderLayerOptions { 23 /* implicit */ TransformerEncoderLayerOptions(int64_t d_model, int64_t nhead); 24 25 /// the number of expected features in the input 26 TORCH_ARG(int64_t, d_model); 27 28 /// the number of heads in the multiheadattention models 29 TORCH_ARG(int64_t, nhead); 30 31 /// the dimension of the feedforward network model, default is 2048 32 TORCH_ARG(int64_t, dim_feedforward) = 2048; 33 34 /// the dropout value, default is 0.1 35 TORCH_ARG(double, dropout) = 0.1; 36 37 /// the activation function of intermediate layer, can be ``torch::kReLU``, 38 /// ``torch::GELU``, or a unary callable. Default: ``torch::kReLU`` 39 TORCH_ARG(activation_t, activation) = torch::kReLU; 40 }; 41 42 // ============================================================================ 43 44 /// Options for the `TransformerDecoderLayer` module. 45 /// 46 /// Example: 47 /// ``` 48 /// TransformerDecoderLayer model(TransformerDecoderLayerOptions(512, 49 /// 8).dropout(0.2)); 50 /// ``` 51 struct TORCH_API TransformerDecoderLayerOptions { 52 TransformerDecoderLayerOptions(int64_t d_model, int64_t nhead); 53 54 /// number of expected features in the input 55 TORCH_ARG(int64_t, d_model); 56 57 /// number of heads in the multiheadattention models 58 TORCH_ARG(int64_t, nhead); 59 60 /// dimension of the feedforward network model. Default: 2048 61 TORCH_ARG(int64_t, dim_feedforward) = 2048; 62 63 /// dropout value. Default: 1 64 TORCH_ARG(double, dropout) = 0.1; 65 66 /// activation function of intermediate layer, can be ``torch::kGELU``, 67 /// ``torch::kReLU``, or a unary callable. Default: ``torch::kReLU`` 68 TORCH_ARG(activation_t, activation) = torch::kReLU; 69 }; 70 71 } // namespace nn 72 } // namespace torch 73