xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/modules/transformercoder.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/nn/cloneable.h>
4 #include <torch/nn/module.h>
5 #include <torch/nn/modules/common.h>
6 #include <torch/nn/modules/container/any.h>
7 #include <torch/nn/modules/container/modulelist.h>
8 #include <torch/nn/options/transformercoder.h>
9 #include <torch/nn/pimpl.h>
10 
11 #include <torch/types.h>
12 
13 #include <ostream>
14 
15 namespace torch {
16 namespace nn {
17 
18 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerEncoder
19 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
20 
21 /// TransformerEncoder module.
22 /// See
23 /// https://pytorch.org/docs/main/generated/torch.nn.TransformerEncoder.html
24 /// to learn abouut the exact behavior of this encoder layer module.
25 ///
26 /// See the documentation for `torch::nn::TransformerEncoder` class to learn
27 /// what constructor arguments are supported for this encoder module.
28 ///
29 /// Example:
30 /// ```
31 /// TransformerEncoderLayer encoderLayer(TransformerEncoderLayerOptions(512,
32 /// 8).dropout(0.1)); TransformerEncoder
33 /// encoder(TransformerEncoderOptions(encoderLayer,
34 /// 6).norm(LayerNorm(LayerNormOptions({2}))));
35 /// ```
36 class TORCH_API TransformerEncoderImpl
37     : public Cloneable<TransformerEncoderImpl> {
38  public:
TransformerEncoderImpl(TransformerEncoderLayer encoder_layer,int64_t num_layers)39   TransformerEncoderImpl(
40       TransformerEncoderLayer encoder_layer,
41       int64_t num_layers)
42       : TransformerEncoderImpl(
43             TransformerEncoderOptions(encoder_layer, num_layers)) {}
44   explicit TransformerEncoderImpl(TransformerEncoderOptions options_);
45 
46   Tensor forward(
47       const Tensor& src,
48       const Tensor& src_mask = {},
49       const Tensor& src_key_padding_mask = {});
50 
51   void reset() override;
52 
53   void reset_parameters();
54 
55  protected:
56   FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())}, {2, AnyValue(Tensor())})
57 
58  public:
59   /// options with which this `TransformerEncoder` was constructed
60   TransformerEncoderOptions options;
61 
62   /// module list that contains all the encoder layers
63   ModuleList layers = nullptr;
64 
65   /// optional normalization module
66   AnyModule norm;
67 };
68 
69 /// A `ModuleHolder` subclass for `TransformerEncoderImpl`.
70 /// See the documentation for `TransformerEncoderImpl` class to learn what
71 /// methods it provides, and examples of how to use `TransformerEncoder` with
72 /// `torch::nn::TransformerEncoderOptions`.
73 /// See the documentation for `ModuleHolder` to learn about PyTorch's
74 /// module storage semantics.
75 TORCH_MODULE(TransformerEncoder);
76 
77 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerDecoder
78 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
79 
80 /// TransformerDecoder is a stack of N decoder layers.
81 /// See
82 /// https://pytorch.org/docs/main/generated/torch.nn.TransformerDecoder.html
83 /// to learn abouut the exact behavior of this decoder module
84 ///
85 /// See the documentation for `torch::nn::TransformerDecoderOptions` class to
86 /// learn what constructor arguments are supported for this decoder module
87 ///
88 /// Example:
89 /// ```
90 /// TransformerDecoderLayer decoder_layer(TransformerDecoderLayerOptions(512,
91 /// 8).dropout(0.1)); TransformerDecoder
92 /// transformer_decoder(TransformerDecoderOptions(decoder_layer,
93 /// 6).norm(LayerNorm(LayerNormOptions({2})))); const auto memory =
94 /// torch::rand({10, 32, 512}); const auto tgt = torch::rand({20, 32, 512});
95 /// auto out = transformer_decoder(tgt, memory);
96 /// ```
97 class TORCH_API TransformerDecoderImpl
98     : public Cloneable<TransformerDecoderImpl> {
99  public:
TransformerDecoderImpl(TransformerDecoderLayer decoder_layer,int64_t num_layers)100   TransformerDecoderImpl(
101       TransformerDecoderLayer decoder_layer,
102       int64_t num_layers)
103       : TransformerDecoderImpl(
104             TransformerDecoderOptions(decoder_layer, num_layers)) {}
105   explicit TransformerDecoderImpl(TransformerDecoderOptions options_);
106 
107   void reset() override;
108 
109   void reset_parameters();
110 
111   /// Pass the inputs (and mask) through the decoder layer in turn.
112   /// Args:
113   ///       tgt: the sequence to the decoder layer (required).
114   ///       memory: the sequence from the last layer of the encoder (required).
115   ///       tgt_mask: the mask for the tgt sequence (optional).
116   ///       memory_mask: the mask for the memory sequence (optional).
117   ///       tgt_key_padding_mask: the mask for the tgt keys per batch
118   ///       (optional). memory_key_padding_mask: the mask for the memory keys
119   ///       per batch (optional).
120   Tensor forward(
121       const Tensor& tgt,
122       const Tensor& memory,
123       const Tensor& tgt_mask = {},
124       const Tensor& memory_mask = {},
125       const Tensor& tgt_key_padding_mask = {},
126       const Tensor& memory_key_padding_mask = {});
127 
128   /// The options used to configure this module.
129   TransformerDecoderOptions options;
130 
131   /// Cloned layers of decoder layers
132   ModuleList layers{nullptr};
133 
134   /// optional layer normalization module
135   AnyModule norm;
136 
137  protected:
138   FORWARD_HAS_DEFAULT_ARGS(
139       {2, AnyValue(Tensor())},
140       {3, AnyValue(Tensor())},
141       {4, AnyValue(Tensor())},
142       {5, AnyValue(Tensor())})
143 };
144 
145 /// A `ModuleHolder` subclass for `TransformerDecoderImpl`.
146 /// See the documentation for `TransformerDecoderImpl` class to learn what
147 /// methods it provides, and examples of how to use `TransformerDecoder` with
148 /// `torch::nn::TransformerDecoderOptions`.
149 /// See the documentation for `ModuleHolder` to learn about PyTorch's
150 /// module storage semantics.
151 TORCH_MODULE(TransformerDecoder);
152 
153 } // namespace nn
154 } // namespace torch
155