#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/common.h>
#include <torch/nn/options/transformer.h>
#include <torch/nn/pimpl.h>

#include <torch/types.h>

#include <ostream>

namespace torch {
namespace nn {

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Transformer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// A transformer model. The user is able to modify the attributes as needed.
/// The architecture is based on the paper "Attention Is All You Need". Ashish
/// Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N.
/// Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you
/// need. In Advances in Neural Information Processing Systems, pages
/// 6000-6010.
///
/// See https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html to
/// learn about the exact behavior of this transformer model.
///
/// See the documentation for the `torch::nn::TransformerOptions` class to
/// learn what constructor arguments are supported for this transformer model.
///
/// Example:
/// ```
/// Transformer trans(TransformerOptions(512, 8));
/// ```
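///
/// A fuller construction sketch (the chained setters are assumed to be the
/// ones defined by `torch::nn::TransformerOptions`, and the values shown are
/// illustrative only):
/// ```
/// Transformer trans(TransformerOptions(/*d_model=*/512, /*nhead=*/8)
///                       .num_encoder_layers(6)
///                       .num_decoder_layers(6)
///                       .dim_feedforward(2048)
///                       .dropout(0.1));
/// ```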
class TORCH_API TransformerImpl : public Cloneable<TransformerImpl> {
 public:
  explicit TransformerImpl(TransformerOptions options_);

  /// forward function for Transformer Module
  /// Args:
  ///   src: the sequence to the encoder (required).
  ///   tgt: the sequence to the decoder (required).
  ///   src_mask: the additive mask for the src sequence (optional).
  ///   tgt_mask: the additive mask for the tgt sequence (optional).
  ///   memory_mask: the additive mask for the encoder output (optional).
  ///   src_key_padding_mask: the ByteTensor mask for src keys per batch
  ///     (optional).
  ///   tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch
  ///     (optional).
  ///   memory_key_padding_mask: the ByteTensor mask for memory keys per batch
  ///     (optional).
  ///
  /// Shape:
  ///   src: `(S, N, E)`
  ///   tgt: `(T, N, E)`
  ///   src_mask: `(S, S)`
  ///   tgt_mask: `(T, T)`
  ///   memory_mask: `(T, S)`
  ///   src_key_padding_mask: `(N, S)`
  ///   tgt_key_padding_mask: `(N, T)`
  ///   memory_key_padding_mask: `(N, S)`
  ///
  ///   Note:
  ///     [src/tgt/memory]_mask ensures that position i is allowed to attend
  ///     only the unmasked positions. If a ByteTensor is provided, the
  ///     non-zero positions are not allowed to attend while the zero positions
  ///     will be unchanged. If a BoolTensor is provided, positions with `True`
  ///     are not allowed to attend while `False` values will be unchanged. If
  ///     a FloatTensor is provided, it will be added to the attention weight.
  ///
  ///     [src/tgt/memory]_key_padding_mask marks specified elements in the
  ///     key to be ignored by the attention. If a ByteTensor is provided, the
  ///     non-zero positions will be ignored while the zero positions will be
  ///     unchanged. If a BoolTensor is provided, the positions with the value
  ///     of `True` will be ignored while the positions with the value of
  ///     `False` will be unchanged.
  ///
  ///   output: `(T, N, E)`
  ///
  ///   Note:
  ///     Due to the multi-head attention architecture in the transformer
  ///     model, the output sequence length of a transformer is the same as the
  ///     input sequence (i.e. target) length of the decoder.
  ///
  ///   where
  ///   S is the source sequence length,
  ///   T is the target sequence length,
  ///   N is the batch size,
  ///   E is the feature number.
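  ///
  /// A usage sketch (shapes follow the table above; the random inputs and the
  /// sizes S=10, T=20, N=32, E=512 are illustrative only):
  /// ```
  /// Transformer trans(TransformerOptions(512, 8));
  /// auto src = torch::rand({10, 32, 512});  // (S, N, E)
  /// auto tgt = torch::rand({20, 32, 512});  // (T, N, E)
  /// // causal mask over the target sequence, shape (T, T)
  /// auto tgt_mask = TransformerImpl::generate_square_subsequent_mask(20);
  /// auto output = trans->forward(src, tgt, /*src_mask=*/{}, tgt_mask);
  /// // output has shape (T, N, E) = (20, 32, 512)
  /// ```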
  Tensor forward(
      const Tensor& src,
      const Tensor& tgt,
      const Tensor& src_mask = {},
      const Tensor& tgt_mask = {},
      const Tensor& memory_mask = {},
      const Tensor& src_key_padding_mask = {},
      const Tensor& tgt_key_padding_mask = {},
      const Tensor& memory_key_padding_mask = {});

  void reset() override;

  void reset_parameters();

  /// Generate a square mask for the sequence.
  /// The masked positions are filled with `-inf` in float type.
  /// Unmasked positions are filled with `0.0` in float type.
  /// Note:
  ///   1. This function will always return a CPU tensor.
  ///   2. This function requires the platform to support IEEE754, since
  ///      `-inf` is guaranteed to be valid only when IEEE754 is supported. If
  ///      the platform doesn't support IEEE754, this function will fill the
  ///      mask with the smallest float number instead of `-inf`, and a
  ///      one-time warning will be emitted as well.
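  ///
  /// Example (a small sketch: for `sz = 3` each position may attend only to
  /// itself and to earlier positions):
  /// ```
  /// auto mask = TransformerImpl::generate_square_subsequent_mask(3);
  /// // mask is a 3x3 float tensor:
  /// //    0 -inf -inf
  /// //    0    0 -inf
  /// //    0    0    0
  /// ```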
  static Tensor generate_square_subsequent_mask(int64_t sz);

 protected:
  FORWARD_HAS_DEFAULT_ARGS(
      {2, AnyValue(Tensor())},
      {3, AnyValue(Tensor())},
      {4, AnyValue(Tensor())},
      {5, AnyValue(Tensor())},
      {6, AnyValue(Tensor())},
      {7, AnyValue(Tensor())})

 public:
  /// options with which this `Transformer` was constructed
  TransformerOptions options;

  /// encoder module
  AnyModule encoder;

  /// decoder module
  AnyModule decoder;
};

/// A `ModuleHolder` subclass for `TransformerImpl`.
/// See the documentation for `TransformerImpl` class to learn what
/// methods it provides, and examples of how to use `Transformer` with
/// `torch::nn::TransformerOptions`.
/// See the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
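///
/// A short holder-usage sketch (the option values are illustrative): the
/// `Transformer` holder forwards member access to the underlying
/// `TransformerImpl` through `operator->`.
/// ```
/// Transformer trans(TransformerOptions(512, 8));
/// int64_t heads = trans->options.nhead();  // reads the stored options
/// ```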
TORCH_MODULE(Transformer);

} // namespace nn
} // namespace torch