1 #pragma once 2 3 #include <torch/arg.h> 4 #include <torch/csrc/Export.h> 5 #include <torch/enum.h> 6 #include <torch/types.h> 7 8 #include <torch/nn/modules/container/any.h> 9 #include <torch/nn/modules/transformerlayer.h> 10 11 namespace torch { 12 namespace nn { 13 14 /// Options for the `TransformerEncoder` 15 /// 16 /// Example: 17 /// ``` 18 /// TransformerEncoderLayer encoderLayer(TransformerEncoderLayerOptions(512, 19 /// 8).dropout(0.1)); auto options = TransformerEncoderOptions(encoderLayer, 20 /// 6).norm(LayerNorm(LayerNormOptions({2}))); 21 /// ``` 22 struct TORCH_API TransformerEncoderOptions { 23 // This constructor will keep a shallow copy of encoder_layer, so it keeps all 24 // the data in encoder_layer. 25 TransformerEncoderOptions( 26 TransformerEncoderLayer encoder_layer, 27 int64_t num_layers); 28 // This constructor will create a new TransformerEncoderLayer obj based on 29 // passed in encoder_layer_options. 30 TransformerEncoderOptions( 31 const TransformerEncoderLayerOptions& encoder_layer_options, 32 int64_t num_layers); 33 34 /// transformer Encoder Layer 35 TORCH_ARG(TransformerEncoderLayer, encoder_layer) = nullptr; 36 37 /// number of encoder layers 38 TORCH_ARG(int64_t, num_layers); 39 40 /// normalization module 41 TORCH_ARG(AnyModule, norm); 42 }; 43 44 /// Options for the `TransformerDecoder` module. 45 /// 46 /// Example: 47 /// ``` 48 /// TransformerDecoderLayer decoder_layer(TransformerDecoderLayerOptions(512, 49 /// 8).dropout(0.1)); auto options = TransformerDecoderOptions(decoder_layer, 50 /// 6)norm(LayerNorm(LayerNormOptions({2}))); TransformerDecoder 51 /// transformer_decoder(options); 52 /// ``` 53 struct TORCH_API TransformerDecoderOptions { 54 // This constructor will keep the a ref of passed in decoder_layer, 55 // so it keeps all the data in decoder_layer. 56 TransformerDecoderOptions( 57 TransformerDecoderLayer decoder_layer, 58 int64_t num_layers); 59 // This constructor will create a new TransformerDecoderLayer obj, 60 // based on passed in decoder_layer_options. 61 TransformerDecoderOptions( 62 const TransformerDecoderLayerOptions& decoder_layer_options, 63 int64_t num_layers); 64 65 /// decoder layer to be cloned 66 TORCH_ARG(TransformerDecoderLayer, decoder_layer) = nullptr; 67 68 /// number of decoder layers 69 TORCH_ARG(int64_t, num_layers); 70 71 /// normalization module 72 TORCH_ARG(AnyModule, norm); 73 }; 74 75 } // namespace nn 76 } // namespace torch 77