#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/common.h>
#include <torch/nn/options/transformer.h>
#include <torch/nn/pimpl.h>

#include <torch/types.h>

#include <ostream>

namespace torch {
namespace nn {

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Transformer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// A transformer model. The user is able to modify the attributes as needed.
/// The architecture is based on the paper "Attention Is All You Need". Ashish
/// Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N.
/// Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you
/// need. In Advances in Neural Information Processing Systems, pages
/// 6000-6010.
///
/// See https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html to
/// learn about the exact behavior of this transformer model.
///
/// See the documentation for the `torch::nn::TransformerOptions` class to
/// learn what constructor arguments are supported for this transformer model.
///
/// Example:
/// ```
/// Transformer trans(TransformerOptions(512, 8));
/// ```
class TORCH_API TransformerImpl : public Cloneable<TransformerImpl> {
 public:
  explicit TransformerImpl(TransformerOptions options_);

  /// forward function for Transformer Module
  /// Args:
  ///   src: the sequence to the encoder (required).
  ///   tgt: the sequence to the decoder (required).
  ///   src_mask: the additive mask for the src sequence (optional).
  ///   tgt_mask: the additive mask for the tgt sequence (optional).
  ///   memory_mask: the additive mask for the encoder output (optional).
  ///   src_key_padding_mask: the ByteTensor mask for src keys per batch
  ///     (optional).
  ///   tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch
  ///     (optional).
  ///   memory_key_padding_mask: the ByteTensor mask for memory keys per batch
  ///     (optional).
  ///
  /// Shape:
  ///   src: `(S, N, E)`
  ///   tgt: `(T, N, E)`
  ///   src_mask: `(S, S)`
  ///   tgt_mask: `(T, T)`
  ///   memory_mask: `(T, S)`
  ///   src_key_padding_mask: `(N, S)`
  ///   tgt_key_padding_mask: `(N, T)`
  ///   memory_key_padding_mask: `(N, S)`
  ///
  ///   Note:
  ///     [src/tgt/memory]_mask ensures that position i is allowed to attend
  ///     only the unmasked positions. If a ByteTensor is provided, the
  ///     non-zero positions are not allowed to attend while the zero
  ///     positions will be unchanged. If a BoolTensor is provided, positions
  ///     with `True` are not allowed to attend while `False` values will be
  ///     unchanged. If a FloatTensor is provided, it will be added to the
  ///     attention weight.
  ///
  ///     [src/tgt/memory]_key_padding_mask marks specified elements in the
  ///     key to be ignored by the attention. If a ByteTensor is provided, the
  ///     non-zero positions will be ignored while the zero positions will be
  ///     unchanged. If a BoolTensor is provided, the positions with the value
  ///     of `True` will be ignored while the positions with the value of
  ///     `False` will be unchanged.
  ///
  ///   output: `(T, N, E)`
  ///
  ///   Note:
  ///     Due to the multi-head attention architecture in the transformer
  ///     model, the output sequence length of a transformer is the same as
  ///     the input sequence (i.e. target) length of the decoder.
  ///
  ///   where
  ///     S is the source sequence length,
  ///     T is the target sequence length,
  ///     N is the batch size,
  ///     E is the feature number.
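  ///
  /// Example (an illustrative sketch only; the sizes S=10, T=20, N=32, E=512
  /// are arbitrary choices that match the shapes documented above):
  /// ```
  /// Transformer trans(TransformerOptions(512, 8));
  /// auto src = torch::rand({10, 32, 512});
  /// auto tgt = torch::rand({20, 32, 512});
  /// // Causal mask so each target position attends only to earlier positions.
  /// auto tgt_mask = TransformerImpl::generate_square_subsequent_mask(20);
  /// auto output =
  ///     trans->forward(src, tgt, /*src_mask=*/torch::Tensor(), tgt_mask);
  /// // output has shape (T, N, E) == (20, 32, 512)
  /// ```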
  Tensor forward(
      const Tensor& src,
      const Tensor& tgt,
      const Tensor& src_mask = {},
      const Tensor& tgt_mask = {},
      const Tensor& memory_mask = {},
      const Tensor& src_key_padding_mask = {},
      const Tensor& tgt_key_padding_mask = {},
      const Tensor& memory_key_padding_mask = {});

  void reset() override;

  void reset_parameters();

  /// Generate a square mask for the sequence.
  /// The masked positions are filled with `-inf` in float type.
  /// Unmasked positions are filled with `0.0` in float type.
  /// Note:
  ///   1. This function will always return a CPU tensor.
  ///   2. This function requires the platform to support IEEE 754, since
  ///      `-inf` is guaranteed to be valid only when IEEE 754 is supported.
  ///      If the platform doesn't support IEEE 754, this function will fill
  ///      the mask with the smallest float number instead of `-inf`, and a
  ///      one-time warning will be emitted as well.
  static Tensor generate_square_subsequent_mask(int64_t sz);

 protected:
  FORWARD_HAS_DEFAULT_ARGS(
      {2, AnyValue(Tensor())},
      {3, AnyValue(Tensor())},
      {4, AnyValue(Tensor())},
      {5, AnyValue(Tensor())},
      {6, AnyValue(Tensor())},
      {7, AnyValue(Tensor())})

 public:
  /// options with which this `Transformer` was constructed
  TransformerOptions options;

  /// encoder module
  AnyModule encoder;

  /// decoder module
  AnyModule decoder;
};

/// A `ModuleHolder` subclass for `TransformerImpl`.
/// See the documentation for `TransformerImpl` class to learn what
/// methods it provides, and examples of how to use `Transformer` with
/// `torch::nn::TransformerOptions`.
/// See the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(Transformer);

} // namespace nn
} // namespace torch