#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/activation.h>
#include <torch/nn/modules/common.h>
#include <torch/nn/modules/dropout.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/modules/normalization.h>
#include <torch/nn/options/transformerlayer.h>
#include <torch/nn/pimpl.h>

#include <torch/types.h>

#include <ostream>

namespace torch {
namespace nn {

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerEncoderLayer
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// TransformerEncoderLayer module.
/// See
/// https://pytorch.org/docs/main/generated/torch.nn.TransformerEncoderLayer.html
/// to learn about the exact behavior of this encoder layer model.
///
/// See the documentation for the `torch::nn::TransformerEncoderLayerOptions`
/// class to learn what constructor arguments are supported for this encoder
/// layer model.
///
/// Example:
/// ```
/// TransformerEncoderLayer encoderLayer(
///     TransformerEncoderLayerOptions(512, 8).dropout(0.1));
/// ```
class TORCH_API TransformerEncoderLayerImpl
    : public Cloneable<TransformerEncoderLayerImpl> {
 public:
  TransformerEncoderLayerImpl(int64_t d_model, int64_t nhead)
      : TransformerEncoderLayerImpl(
            TransformerEncoderLayerOptions(d_model, nhead)) {}
  explicit TransformerEncoderLayerImpl(TransformerEncoderLayerOptions options_);

  Tensor forward(
      const Tensor& src,
      const Tensor& src_mask = {},
      const Tensor& src_key_padding_mask = {});

  void reset() override;

  void reset_parameters();

 protected:
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())}, {2, AnyValue(Tensor())})

 public:
  /// options with which this `TransformerEncoderLayer` was constructed
  TransformerEncoderLayerOptions options;

  /// self attention
  MultiheadAttention self_attn = nullptr;

  /// feedforward first linear layer
  Linear linear1 = nullptr;

  /// feedforward dropout layer
  Dropout dropout = nullptr;

  /// feedforward second linear layer
  Linear linear2 = nullptr;

  /// pre feedforward, normalization layer
  LayerNorm norm1 = nullptr;
  /// post feedforward, normalization layer
  LayerNorm norm2 = nullptr;

  /// pre feedforward, dropout layer
  Dropout dropout1 = nullptr;
  /// post feedforward, dropout layer
  Dropout dropout2 = nullptr;
};

/// A `ModuleHolder` subclass for `TransformerEncoderLayerImpl`.
/// See the documentation for `TransformerEncoderLayerImpl` class to learn what
/// methods it provides, and examples of how to use `TransformerEncoderLayer`
/// with `torch::nn::TransformerEncoderLayerOptions`. See the documentation for
/// `ModuleHolder` to learn about PyTorch's module storage semantics.
TORCH_MODULE(TransformerEncoderLayer);
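// A minimal end-to-end usage sketch (illustrative only, not part of the
// original header): assuming a program that includes <torch/torch.h> and links
// against libtorch, an encoder layer built from
// `TransformerEncoderLayerOptions` can be applied to an input of shape
// (source sequence length, batch size, d_model); the mask arguments of
// `forward` are optional and default to empty tensors.
//
// ```
// #include <torch/torch.h>
// #include <iostream>
//
// int main() {
//   torch::nn::TransformerEncoderLayer encoder_layer(
//       torch::nn::TransformerEncoderLayerOptions(/*d_model=*/512,
//                                                 /*nhead=*/8)
//           .dropout(0.1));
//   // src has shape (source sequence length, batch size, d_model).
//   auto src = torch::rand({10, 32, 512});
//   // The output has the same shape as src.
//   auto out = encoder_layer->forward(src);
//   std::cout << out.sizes() << std::endl;
//   return 0;
// }
// ```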
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerDecoderLayer
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// TransformerDecoderLayer is made up of self-attention, multi-head attention
/// and a feedforward network. This standard decoder layer is based on the
/// paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer, Niki
/// Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
/// Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural
/// Information Processing Systems, pages 6000-6010. Users may modify or
/// implement it in a different way during application.
/// See https://pytorch.org/docs/main/nn.html#transformer-layers to learn
/// about the exact behavior of this module.
///
/// See the documentation for the `torch::nn::TransformerDecoderLayerOptions`
/// class to learn what constructor arguments are supported for this module.
///
/// Example:
/// ```
/// TransformerDecoderLayer model(
///     TransformerDecoderLayerOptions(512, 8).dropout(0.2));
/// ```
class TORCH_API TransformerDecoderLayerImpl
    : public Cloneable<TransformerDecoderLayerImpl> {
 public:
  TransformerDecoderLayerImpl(int64_t d_model, int64_t nhead)
      : TransformerDecoderLayerImpl(
            TransformerDecoderLayerOptions(d_model, nhead)) {}
  explicit TransformerDecoderLayerImpl(TransformerDecoderLayerOptions options_);

  void reset() override;

  void reset_parameters();

  /// Pass the inputs (and masks) through the decoder layer.
  /// Args:
  ///   tgt: the sequence to the decoder layer (required).
  ///   memory: the sequence from the last layer of the encoder (required).
  ///   tgt_mask: the mask for the tgt sequence (optional).
  ///   memory_mask: the mask for the memory sequence (optional).
  ///   tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
  ///   memory_key_padding_mask: the mask for the memory keys per batch
  ///     (optional).
  Tensor forward(
      Tensor tgt,
      const Tensor& memory,
      const Tensor& tgt_mask = {},
      const Tensor& memory_mask = {},
      const Tensor& tgt_key_padding_mask = {},
      const Tensor& memory_key_padding_mask = {});

  /// The options used to configure this module.
  TransformerDecoderLayerOptions options;

  /// self attention
  MultiheadAttention self_attn{nullptr};

  /// Dropout, post self attention
  Dropout dropout1{nullptr};

  /// Normalization, post self attention
  LayerNorm norm1{nullptr};

  /// Multi-headed attention
  MultiheadAttention multihead_attn{nullptr};

  /// Dropout, post multi-headed attention
  Dropout dropout2{nullptr};

  /// Normalization, post multi-headed attention
  LayerNorm norm2{nullptr};

  /// Feed forward first linear layer
  Linear linear1{nullptr};

  /// Feed forward dropout layer
  Dropout dropout{nullptr};

  /// Feed forward second linear layer
  Linear linear2{nullptr};

  /// Dropout, post feed forward
  Dropout dropout3{nullptr};

  /// Normalization, post feed forward
  LayerNorm norm3{nullptr};

 protected:
  FORWARD_HAS_DEFAULT_ARGS(
      {2, AnyValue(Tensor())},
      {3, AnyValue(Tensor())},
      {4, AnyValue(Tensor())},
      {5, AnyValue(Tensor())})

  /// Apply activation based on configuration
  Tensor activation(const Tensor& input);
};

/// A `ModuleHolder` subclass for `TransformerDecoderLayerImpl`.
/// See the documentation for `TransformerDecoderLayerImpl` class to learn what
/// methods it provides, and examples of how to use `TransformerDecoderLayer`
/// with `torch::nn::TransformerDecoderLayerOptions`. See the documentation for
/// `ModuleHolder` to learn about PyTorch's module storage semantics.
TORCH_MODULE(TransformerDecoderLayer);

} // namespace nn
} // namespace torch
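// A companion usage sketch for the decoder layer (illustrative only, not part
// of the original header): assuming <torch/torch.h> is included, `forward`
// takes the target sequence and the encoder output ("memory"), both shaped
// (sequence length, batch size, d_model); the four mask arguments are
// optional and default to empty tensors.
//
// ```
// #include <torch/torch.h>
//
// int main() {
//   torch::nn::TransformerDecoderLayer decoder_layer(
//       torch::nn::TransformerDecoderLayerOptions(/*d_model=*/512,
//                                                 /*nhead=*/8)
//           .dropout(0.2));
//   auto memory = torch::rand({10, 32, 512});  // encoder output (S, N, E)
//   auto tgt = torch::rand({20, 32, 512});     // target sequence (T, N, E)
//   // The output has the same shape as tgt: (T, N, E).
//   auto out = decoder_layer->forward(tgt, memory);
//   return 0;
// }
// ```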