// xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/options/transformer.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/arg.h>
4 #include <torch/csrc/Export.h>
5 #include <torch/enum.h>
6 #include <torch/types.h>
7 
8 #include <torch/nn/modules/container/any.h>
9 #include <torch/nn/options/transformerlayer.h>
10 
11 namespace torch {
12 namespace nn {
13 
14 /// Options for the `Transformer` module
15 ///
16 /// Example:
17 /// ```
18 /// TransformerOptions options;
19 /// TransformerOptions options(16, 4);
20 /// auto options = TransformerOptions().d_model(4).nhead(2).dropout(0.0);
21 /// ```
22 struct TORCH_API TransformerOptions {
23   // The following constructors are commonly used
24   // Please don't add more unless it is proved as a common usage
25   TransformerOptions() = default;
26   TransformerOptions(int64_t d_model, int64_t nhead);
27   TransformerOptions(
28       int64_t d_model,
29       int64_t nhead,
30       int64_t num_encoder_layers,
31       int64_t num_decoder_layers);
32 
33   /// the number of expected features in the encoder/decoder inputs
34   /// (default=512)
35   TORCH_ARG(int64_t, d_model) = 512;
36 
37   /// the number of heads in the multiheadattention models (default=8)
38   TORCH_ARG(int64_t, nhead) = 8;
39 
40   /// the number of sub-encoder-layers in the encoder (default=6)
41   TORCH_ARG(int64_t, num_encoder_layers) = 6;
42 
43   /// the number of sub-decoder-layers in the decoder (default=6)
44   TORCH_ARG(int64_t, num_decoder_layers) = 6;
45 
46   /// the dimension of the feedforward network model (default=2048)
47   TORCH_ARG(int64_t, dim_feedforward) = 2048;
48 
49   /// the dropout value (default=0.1)
50   TORCH_ARG(double, dropout) = 0.1;
51 
52   /// the activation function of encoder/decoder intermediate layer
53   /// (default=``torch::kReLU``)
54   TORCH_ARG(activation_t, activation) = torch::kReLU;
55 
56   /// custom encoder (default=None)
57   TORCH_ARG(AnyModule, custom_encoder);
58 
59   /// custom decoder (default=None)
60   TORCH_ARG(AnyModule, custom_decoder);
61 };
62 
63 } // namespace nn
64 } // namespace torch
65