xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/modules/transformerlayer.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/nn/cloneable.h>
4 #include <torch/nn/module.h>
5 #include <torch/nn/modules/activation.h>
6 #include <torch/nn/modules/common.h>
7 #include <torch/nn/modules/dropout.h>
8 #include <torch/nn/modules/linear.h>
9 #include <torch/nn/modules/normalization.h>
10 #include <torch/nn/options/transformerlayer.h>
11 #include <torch/nn/pimpl.h>
12 
13 #include <torch/types.h>
14 
15 #include <ostream>
16 
17 namespace torch {
18 namespace nn {
19 
20 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerEncoderLayer
21 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
22 
23 /// TransformerEncoderLayer module.
24 /// See
25 /// https://pytorch.org/docs/main/generated/torch.nn.TransformerEncoderLayer.html
26 /// to learn abouut the exact behavior of this encoder layer model
27 ///
28 /// See the documentation for `torch::nn::TransformerEncoderLayer` class to
29 /// learn what constructor arguments are supported for this encoder layer model
30 ///
31 /// Example:
32 /// ```
33 /// TransformerEncoderLayer encoderLayer(TransformerEncoderLayerOptions(512,
34 /// 8).dropout(0.1));
35 /// ```
36 class TORCH_API TransformerEncoderLayerImpl
37     : public Cloneable<TransformerEncoderLayerImpl> {
38  public:
TransformerEncoderLayerImpl(int64_t d_model,int64_t nhead)39   TransformerEncoderLayerImpl(int64_t d_model, int64_t nhead)
40       : TransformerEncoderLayerImpl(
41             TransformerEncoderLayerOptions(d_model, nhead)) {}
42   explicit TransformerEncoderLayerImpl(TransformerEncoderLayerOptions options_);
43 
44   Tensor forward(
45       const Tensor& src,
46       const Tensor& src_mask = {},
47       const Tensor& src_key_padding_mask = {});
48 
49   void reset() override;
50 
51   void reset_parameters();
52 
53  protected:
54   FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())}, {2, AnyValue(Tensor())})
55 
56  public:
57   /// options with which this `TransformerEncoderLayer` was constructed
58   TransformerEncoderLayerOptions options;
59 
60   /// self attention
61   MultiheadAttention self_attn = nullptr;
62 
63   /// feedforward first linear layer
64   Linear linear1 = nullptr;
65 
66   /// feedforward dropout layer
67   Dropout dropout = nullptr;
68 
69   /// feedforward second linear layer
70   Linear linear2 = nullptr;
71 
72   /// pre feedforward, normalization layer
73   LayerNorm norm1 = nullptr;
74   /// post feedfastward, normalization layer
75   LayerNorm norm2 = nullptr;
76 
77   /// pre feedfastward, dropout layer
78   Dropout dropout1 = nullptr;
79   /// post feedfastward, dropout layer
80   Dropout dropout2 = nullptr;
81 };
82 
/// A `ModuleHolder` subclass for `TransformerEncoderLayerImpl`.
/// See the documentation for `TransformerEncoderLayerImpl` class to learn what
/// methods it provides, and examples of how to use `TransformerEncoderLayer`
/// with `torch::nn::TransformerEncoderLayerOptions`. See the documentation for
/// `ModuleHolder` to learn about PyTorch's module storage semantics.
TORCH_MODULE(TransformerEncoderLayer);
89 
90 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerDecoderLayer
91 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
92 
93 /// TransformerDecoderLayer is made up of self-attn, multi-head-attn and
94 /// feedforward network. This standard decoder layer is based on the paper
95 /// "Attention Is All You Need". Ashish Vaswani, Noam Shazeer, Niki Parmar,
96 /// Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia
97 /// Polosukhin. 2017. Attention is all you need. In Advances in Neural
98 /// Information Processing Systems, pages 6000-6010. Users may modify or
99 /// implement in a different way during application. See
100 /// https://pytorch.org/docs/main/nn.html#transformer-layers to learn about
101 /// the exact behavior of this module.
102 ///
103 /// See the documentation for `torch::nn::TransformerDecoderLayerOptions` class
104 /// to learn what constructor arguments are supported for this module.
105 ///
106 /// Example:
107 /// ```
108 /// TransformerDecoderLayer model(TransformerDecoderLayerOptions(512,
109 /// 8).dropout(0.2));
110 /// ```
111 class TORCH_API TransformerDecoderLayerImpl
112     : public Cloneable<TransformerDecoderLayerImpl> {
113  public:
TransformerDecoderLayerImpl(int64_t d_model,int64_t nhead)114   TransformerDecoderLayerImpl(int64_t d_model, int64_t nhead)
115       : TransformerDecoderLayerImpl(
116             TransformerDecoderLayerOptions(d_model, nhead)) {}
117   explicit TransformerDecoderLayerImpl(TransformerDecoderLayerOptions options_);
118 
119   void reset() override;
120 
121   void reset_parameters();
122 
123   /// Pass the inputs (and mask) through the decoder layer.
124   /// Args:
125   ///       tgt: the sequence to the decoder layer (required).
126   ///       memory: the sequence from the last layer of the encoder (required).
127   ///       tgt_mask: the mask for the tgt sequence (optional).
128   ///       memory_mask: the mask for the memory sequence (optional).
129   ///       tgt_key_padding_mask: the mask for the tgt keys per batch
130   ///       (optional). memory_key_padding_mask: the mask for the memory keys
131   ///       per batch (optional).
132   Tensor forward(
133       Tensor tgt,
134       const Tensor& memory,
135       const Tensor& tgt_mask = {},
136       const Tensor& memory_mask = {},
137       const Tensor& tgt_key_padding_mask = {},
138       const Tensor& memory_key_padding_mask = {});
139 
140   /// The options used to configure this module.
141   TransformerDecoderLayerOptions options;
142 
143   /// self attention
144   MultiheadAttention self_attn{nullptr};
145 
146   /// Dropout, post self attention
147   Dropout dropout1{nullptr};
148 
149   /// Normalization, post self attention
150   LayerNorm norm1{nullptr};
151 
152   /// Multi-headed attention
153   MultiheadAttention multihead_attn{nullptr};
154 
155   /// Dropout, post multi-headed attention
156   Dropout dropout2{nullptr};
157 
158   /// Normalization, post multi-headed attention
159   LayerNorm norm2{nullptr};
160 
161   /// Feed forward first linear layer
162   Linear linear1{nullptr};
163 
164   /// Feed forward dropout layer
165   Dropout dropout{nullptr};
166 
167   /// Feed forward second linear layer
168   Linear linear2{nullptr};
169 
170   /// Dropout, post feed forward
171   Dropout dropout3{nullptr};
172 
173   /// Normalization, post feed forward
174   LayerNorm norm3{nullptr};
175 
176  protected:
177   FORWARD_HAS_DEFAULT_ARGS(
178       {2, AnyValue(Tensor())},
179       {3, AnyValue(Tensor())},
180       {4, AnyValue(Tensor())},
181       {5, AnyValue(Tensor())})
182 
183   /// Apply activation based on configuration
184   Tensor activation(const Tensor& input);
185 };
186 
/// A `ModuleHolder` subclass for `TransformerDecoderLayerImpl`.
/// See the documentation for `TransformerDecoderLayerImpl` class to learn what
/// methods it provides, and examples of how to use `TransformerDecoderLayer`
/// with `torch::nn::TransformerDecoderLayerOptions`. See the documentation for
/// `ModuleHolder` to learn about PyTorch's module storage semantics.
TORCH_MODULE(TransformerDecoderLayer);
193 
194 } // namespace nn
195 } // namespace torch
196