#pragma once

#include <torch/arg.h>
#include <torch/csrc/Export.h>
#include <torch/enum.h>
#include <torch/types.h>

#include <functional>
#include <variant>

namespace torch {
namespace nn {

using activation_t = std::variant<
    enumtype::kReLU,
    enumtype::kGELU,
    std::function<Tensor(const Tensor&)>>;
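
// A minimal sketch of the three alternatives an `activation_t` can hold; the
// variable names below are illustrative only, and `torch::gelu` is assumed to
// be the usual free-function activation:
//
//   activation_t a1 = torch::kReLU;  // enum tag, dispatched internally
//   activation_t a2 = torch::kGELU;  // enum tag, dispatched internally
//   activation_t a3 =                // arbitrary unary callable
//       [](const Tensor& t) { return torch::gelu(t); };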

/// Options for the `TransformerEncoderLayer` module.
///
/// Example:
/// ```
/// auto options = TransformerEncoderLayerOptions(512, 8).dropout(0.2);
/// ```
struct TORCH_API TransformerEncoderLayerOptions {
  /* implicit */ TransformerEncoderLayerOptions(int64_t d_model, int64_t nhead);

  /// the number of expected features in the input
  TORCH_ARG(int64_t, d_model);

  /// the number of heads in the multiheadattention models
  TORCH_ARG(int64_t, nhead);

  /// the dimension of the feedforward network model. Default: 2048
  TORCH_ARG(int64_t, dim_feedforward) = 2048;

  /// the dropout value. Default: 0.1
  TORCH_ARG(double, dropout) = 0.1;
  /// the activation function of the intermediate layer; can be
  /// ``torch::kReLU``, ``torch::kGELU``, or a unary callable.
  /// Default: ``torch::kReLU``
  TORCH_ARG(activation_t, activation) = torch::kReLU;
};
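
// A minimal usage sketch, assuming the standard `TransformerEncoderLayer`
// module and its (seq_len, batch, d_model) input convention; the variable
// names are illustrative only:
//
//   TransformerEncoderLayer layer(TransformerEncoderLayerOptions(512, 8)
//                                     .dim_feedforward(1024)
//                                     .dropout(0.2)
//                                     .activation(torch::kGELU));
//   auto src = torch::rand({10, 32, 512});  // (seq_len, batch, d_model)
//   auto out = layer(src);                  // same shape as src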

// ============================================================================

/// Options for the `TransformerDecoderLayer` module.
///
/// Example:
/// ```
/// TransformerDecoderLayer model(
///     TransformerDecoderLayerOptions(512, 8).dropout(0.2));
/// ```
struct TORCH_API TransformerDecoderLayerOptions {
  TransformerDecoderLayerOptions(int64_t d_model, int64_t nhead);

  /// the number of expected features in the input
  TORCH_ARG(int64_t, d_model);

  /// the number of heads in the multiheadattention models
  TORCH_ARG(int64_t, nhead);

  /// the dimension of the feedforward network model. Default: 2048
  TORCH_ARG(int64_t, dim_feedforward) = 2048;

  /// the dropout value. Default: 0.1
  TORCH_ARG(double, dropout) = 0.1;

  /// the activation function of the intermediate layer; can be
  /// ``torch::kReLU``, ``torch::kGELU``, or a unary callable.
  /// Default: ``torch::kReLU``
  TORCH_ARG(activation_t, activation) = torch::kReLU;
};
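
// A minimal usage sketch, assuming the standard `TransformerDecoderLayer`
// module; `tgt` and `memory` follow the (seq_len, batch, d_model) convention,
// and the variable names are illustrative only:
//
//   TransformerDecoderLayer layer(
//       TransformerDecoderLayerOptions(512, 8).dropout(0.2));
//   auto tgt = torch::rand({20, 32, 512});     // (tgt_len, batch, d_model)
//   auto memory = torch::rand({10, 32, 512});  // (src_len, batch, d_model)
//   auto out = layer(tgt, memory);             // same shape as tgt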

} // namespace nn
} // namespace torch