xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/options/transformercoder.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/arg.h>
4 #include <torch/csrc/Export.h>
5 #include <torch/enum.h>
6 #include <torch/types.h>
7 
8 #include <torch/nn/modules/container/any.h>
9 #include <torch/nn/modules/transformerlayer.h>
10 
11 namespace torch {
12 namespace nn {
13 
14 /// Options for the `TransformerEncoder` module.
15 ///
16 /// Example:
17 /// ```
18 /// TransformerEncoderLayer encoderLayer(TransformerEncoderLayerOptions(512, 8).dropout(0.1));
19 /// auto options = TransformerEncoderOptions(encoderLayer, 6)
20 ///                    .norm(LayerNorm(LayerNormOptions({2})));
21 /// ```
22 struct TORCH_API TransformerEncoderOptions {
23   // This constructor keeps a shallow copy of the passed-in encoder_layer, so
24   // the options retain all the data in encoder_layer.
25   TransformerEncoderOptions(
26       TransformerEncoderLayer encoder_layer,
27       int64_t num_layers);
28   // This constructor creates a new TransformerEncoderLayer object based on
29   // the passed-in encoder_layer_options.
30   TransformerEncoderOptions(
31       const TransformerEncoderLayerOptions& encoder_layer_options,
32       int64_t num_layers);
33 
34   /// Transformer encoder layer; the in-class initializer is nullptr.
35   TORCH_ARG(TransformerEncoderLayer, encoder_layer) = nullptr;
36 
37   /// Number of encoder layers; no in-class default, both constructors take it.
38   TORCH_ARG(int64_t, num_layers);
39 
40   /// Normalization module; set with a chained `.norm(...)` call as in the example.
41   TORCH_ARG(AnyModule, norm);
42 };
43 
44 /// Options for the `TransformerDecoder` module.
45 ///
46 /// Example:
47 /// ```
48 /// TransformerDecoderLayer decoder_layer(TransformerDecoderLayerOptions(512, 8).dropout(0.1));
49 /// auto options = TransformerDecoderOptions(decoder_layer, 6)
50 ///                    .norm(LayerNorm(LayerNormOptions({2})));
51 /// TransformerDecoder transformer_decoder(options);
52 /// ```
53 struct TORCH_API TransformerDecoderOptions {
54   // This constructor keeps a reference to the passed-in decoder_layer,
55   // so the options retain all the data in decoder_layer.
56   TransformerDecoderOptions(
57       TransformerDecoderLayer decoder_layer,
58       int64_t num_layers);
59   // This constructor creates a new TransformerDecoderLayer object
60   // based on the passed-in decoder_layer_options.
61   TransformerDecoderOptions(
62       const TransformerDecoderLayerOptions& decoder_layer_options,
63       int64_t num_layers);
64 
65   /// Decoder layer to be cloned; the in-class initializer is nullptr.
66   TORCH_ARG(TransformerDecoderLayer, decoder_layer) = nullptr;
67 
68   /// Number of decoder layers; no in-class default, both constructors take it.
69   TORCH_ARG(int64_t, num_layers);
70 
71   /// Normalization module; set with a chained `.norm(...)` call as in the example.
72   TORCH_ARG(AnyModule, norm);
73 };
74 
75 } // namespace nn
76 } // namespace torch
77