// xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/options/activation.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/nn/options/activation.h>

namespace torch {
namespace nn {

SELUOptions::SELUOptions(bool inplace) : inplace_(inplace) {}

GLUOptions::GLUOptions(int64_t dim) : dim_(dim) {}

HardshrinkOptions::HardshrinkOptions(double lambda) : lambda_(lambda) {}

SoftmaxOptions::SoftmaxOptions(int64_t dim) : dim_(dim) {}

SoftminOptions::SoftminOptions(int64_t dim) : dim_(dim) {}

LogSoftmaxOptions::LogSoftmaxOptions(int64_t dim) : dim_(dim) {}

ReLUOptions::ReLUOptions(bool inplace) : inplace_(inplace) {}

ReLU6Options::ReLU6Options(bool inplace) : inplace_(inplace) {}

SoftshrinkOptions::SoftshrinkOptions(double lambda) : lambda_(lambda) {}

MultiheadAttentionOptions::MultiheadAttentionOptions(
    int64_t embed_dim,
    int64_t num_heads)
    : embed_dim_(embed_dim),
      num_heads_(num_heads),
      kdim_(embed_dim),
      vdim_(embed_dim) {}

namespace functional {

SoftmaxFuncOptions::SoftmaxFuncOptions(int64_t dim) : dim_(dim) {}

SoftminFuncOptions::SoftminFuncOptions(int64_t dim) : dim_(dim) {}

LogSoftmaxFuncOptions::LogSoftmaxFuncOptions(int64_t dim) : dim_(dim) {}

MultiheadAttentionForwardFuncOptions::MultiheadAttentionForwardFuncOptions(
    int64_t embed_dim_to_check,
    int64_t num_heads,
    Tensor in_proj_weight,
    Tensor in_proj_bias,
    Tensor bias_k,
    Tensor bias_v,
    bool add_zero_attn,
    double dropout_p,
    Tensor out_proj_weight,
    Tensor out_proj_bias)
    : embed_dim_to_check_(embed_dim_to_check),
      num_heads_(num_heads),
      in_proj_weight_(std::move(in_proj_weight)),
      in_proj_bias_(std::move(in_proj_bias)),
      bias_k_(std::move(bias_k)),
      bias_v_(std::move(bias_v)),
      add_zero_attn_(add_zero_attn),
      dropout_p_(dropout_p),
      out_proj_weight_(std::move(out_proj_weight)),
      out_proj_bias_(std::move(out_proj_bias)) {}

} // namespace functional
} // namespace nn
} // namespace torch
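
// Usage sketch (illustrative only, not part of the upstream file): these options
// objects are consumed by the corresponding torch::nn modules and
// torch::nn::functional helpers, e.g.:
//
//   auto relu = torch::nn::ReLU(torch::nn::ReLUOptions(/*inplace=*/true));
//   auto attn = torch::nn::MultiheadAttention(
//       torch::nn::MultiheadAttentionOptions(/*embed_dim=*/512, /*num_heads=*/8));
//   auto probs = torch::nn::functional::softmax(
//       torch::rand({2, 3}),
//       torch::nn::functional::SoftmaxFuncOptions(/*dim=*/1));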