#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/common.h>
#include <torch/nn/modules/container/any.h>
#include <torch/nn/modules/container/modulelist.h>
#include <torch/nn/options/transformercoder.h>
#include <torch/nn/pimpl.h>

#include <torch/types.h>

#include <ostream>

namespace torch {
namespace nn {

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerEncoder
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// TransformerEncoder module.
/// See
/// https://pytorch.org/docs/main/generated/torch.nn.TransformerEncoder.html
/// to learn about the exact behavior of this encoder module.
///
/// See the documentation for the `torch::nn::TransformerEncoderOptions` class
/// to learn what constructor arguments are supported for this encoder module.
///
/// Example:
/// ```
/// TransformerEncoderLayer encoderLayer(
///     TransformerEncoderLayerOptions(512, 8).dropout(0.1));
/// TransformerEncoder encoder(TransformerEncoderOptions(encoderLayer, 6)
///                                .norm(LayerNorm(LayerNormOptions({2}))));
/// ```
class TORCH_API TransformerEncoderImpl
    : public Cloneable<TransformerEncoderImpl> {
 public:
  TransformerEncoderImpl(
      TransformerEncoderLayer encoder_layer,
      int64_t num_layers)
      : TransformerEncoderImpl(
            TransformerEncoderOptions(encoder_layer, num_layers)) {}
  explicit TransformerEncoderImpl(TransformerEncoderOptions options_);

  Tensor forward(
      const Tensor& src,
      const Tensor& src_mask = {},
      const Tensor& src_key_padding_mask = {});

  void reset() override;

  void reset_parameters();

 protected:
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())}, {2, AnyValue(Tensor())})

 public:
  /// options with which this `TransformerEncoder` was constructed
  TransformerEncoderOptions options;

  /// module list that contains all the encoder layers
  ModuleList layers = nullptr;

  /// optional normalization module
  AnyModule norm;
};

/// A `ModuleHolder` subclass for `TransformerEncoderImpl`.
/// See the documentation for the `TransformerEncoderImpl` class to learn what
/// methods it provides, and examples of how to use `TransformerEncoder` with
/// `torch::nn::TransformerEncoderOptions`.
/// See the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(TransformerEncoder);
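
// A minimal end-to-end usage sketch for `TransformerEncoder` (the tensor
// shapes below are illustrative only; inputs follow the
// (seq_len, batch, embed_dim) layout that `forward` expects):
//
// ```
// TransformerEncoderLayer encoder_layer(
//     TransformerEncoderLayerOptions(512, 8).dropout(0.1));
// TransformerEncoder encoder(TransformerEncoderOptions(encoder_layer, 6));
// const auto src = torch::rand({10, 32, 512}); // (seq_len, batch, embed_dim)
// auto out = encoder(src); // output has the same shape as `src`
// ```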

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerDecoder
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// TransformerDecoder is a stack of N decoder layers.
/// See
/// https://pytorch.org/docs/main/generated/torch.nn.TransformerDecoder.html
/// to learn about the exact behavior of this decoder module.
///
/// See the documentation for the `torch::nn::TransformerDecoderOptions` class
/// to learn what constructor arguments are supported for this decoder module.
///
/// Example:
/// ```
/// TransformerDecoderLayer decoder_layer(
///     TransformerDecoderLayerOptions(512, 8).dropout(0.1));
/// TransformerDecoder transformer_decoder(
///     TransformerDecoderOptions(decoder_layer, 6)
///         .norm(LayerNorm(LayerNormOptions({2}))));
/// const auto memory = torch::rand({10, 32, 512});
/// const auto tgt = torch::rand({20, 32, 512});
/// auto out = transformer_decoder(tgt, memory);
/// ```
class TORCH_API TransformerDecoderImpl
    : public Cloneable<TransformerDecoderImpl> {
 public:
  TransformerDecoderImpl(
      TransformerDecoderLayer decoder_layer,
      int64_t num_layers)
      : TransformerDecoderImpl(
            TransformerDecoderOptions(decoder_layer, num_layers)) {}
  explicit TransformerDecoderImpl(TransformerDecoderOptions options_);

  void reset() override;

  void reset_parameters();

  /// Pass the inputs (and masks) through the decoder layers in turn.
  /// Args:
  ///   tgt: the sequence to the decoder layer (required).
  ///   memory: the sequence from the last layer of the encoder (required).
  ///   tgt_mask: the mask for the tgt sequence (optional).
  ///   memory_mask: the mask for the memory sequence (optional).
  ///   tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
  ///   memory_key_padding_mask: the mask for the memory keys per batch
  ///     (optional).
  Tensor forward(
      const Tensor& tgt,
      const Tensor& memory,
      const Tensor& tgt_mask = {},
      const Tensor& memory_mask = {},
      const Tensor& tgt_key_padding_mask = {},
      const Tensor& memory_key_padding_mask = {});

  /// The options used to configure this module.
  TransformerDecoderOptions options;

  /// Cloned copies of the provided decoder layer
  ModuleList layers{nullptr};

  /// optional layer normalization module
  AnyModule norm;

 protected:
  FORWARD_HAS_DEFAULT_ARGS(
      {2, AnyValue(Tensor())},
      {3, AnyValue(Tensor())},
      {4, AnyValue(Tensor())},
      {5, AnyValue(Tensor())})
};

/// A `ModuleHolder` subclass for `TransformerDecoderImpl`.
/// See the documentation for the `TransformerDecoderImpl` class to learn what
/// methods it provides, and examples of how to use `TransformerDecoder` with
/// `torch::nn::TransformerDecoderOptions`.
/// See the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(TransformerDecoder);

} // namespace nn
} // namespace torch
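
// A sketch of running the decoder with a causal target mask, so each target
// position can only attend to earlier positions. This assumes
// `torch/nn/modules/transformer.h` is also included, since
// `TransformerImpl::generate_square_subsequent_mask` is declared there; the
// tensor shapes are illustrative only:
//
// ```
// TransformerDecoderLayer decoder_layer(
//     TransformerDecoderLayerOptions(512, 8));
// TransformerDecoder decoder(TransformerDecoderOptions(decoder_layer, 6));
// const auto memory = torch::rand({10, 32, 512});
// const auto tgt = torch::rand({20, 32, 512});
// // (tgt_len, tgt_len) additive mask: 0 below the diagonal, -inf above
// const auto tgt_mask =
//     TransformerImpl::generate_square_subsequent_mask(tgt.size(0));
// auto out = decoder(tgt, memory, tgt_mask);
// ```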