1 #include <torch/nn/options/activation.h>
2
3 namespace torch {
4 namespace nn {
5
SELUOptions(bool inplace)6 SELUOptions::SELUOptions(bool inplace) : inplace_(inplace) {}
7
GLUOptions(int64_t dim)8 GLUOptions::GLUOptions(int64_t dim) : dim_(dim) {}
9
HardshrinkOptions(double lambda)10 HardshrinkOptions::HardshrinkOptions(double lambda) : lambda_(lambda) {}
11
SoftmaxOptions(int64_t dim)12 SoftmaxOptions::SoftmaxOptions(int64_t dim) : dim_(dim) {}
13
SoftminOptions(int64_t dim)14 SoftminOptions::SoftminOptions(int64_t dim) : dim_(dim) {}
15
LogSoftmaxOptions(int64_t dim)16 LogSoftmaxOptions::LogSoftmaxOptions(int64_t dim) : dim_(dim) {}
17
ReLUOptions(bool inplace)18 ReLUOptions::ReLUOptions(bool inplace) : inplace_(inplace) {}
19
ReLU6Options(bool inplace)20 ReLU6Options::ReLU6Options(bool inplace) : inplace_(inplace) {}
21
SoftshrinkOptions(double lambda)22 SoftshrinkOptions::SoftshrinkOptions(double lambda) : lambda_(lambda) {}
23
MultiheadAttentionOptions(int64_t embed_dim,int64_t num_heads)24 MultiheadAttentionOptions::MultiheadAttentionOptions(
25 int64_t embed_dim,
26 int64_t num_heads)
27 : embed_dim_(embed_dim),
28 num_heads_(num_heads),
29 kdim_(embed_dim),
30 vdim_(embed_dim) {}
31
32 namespace functional {
33
SoftmaxFuncOptions(int64_t dim)34 SoftmaxFuncOptions::SoftmaxFuncOptions(int64_t dim) : dim_(dim) {}
35
SoftminFuncOptions(int64_t dim)36 SoftminFuncOptions::SoftminFuncOptions(int64_t dim) : dim_(dim) {}
37
LogSoftmaxFuncOptions(int64_t dim)38 LogSoftmaxFuncOptions::LogSoftmaxFuncOptions(int64_t dim) : dim_(dim) {}
39
MultiheadAttentionForwardFuncOptions(int64_t embed_dim_to_check,int64_t num_heads,Tensor in_proj_weight,Tensor in_proj_bias,Tensor bias_k,Tensor bias_v,bool add_zero_attn,double dropout_p,Tensor out_proj_weight,Tensor out_proj_bias)40 MultiheadAttentionForwardFuncOptions::MultiheadAttentionForwardFuncOptions(
41 int64_t embed_dim_to_check,
42 int64_t num_heads,
43 Tensor in_proj_weight,
44 Tensor in_proj_bias,
45 Tensor bias_k,
46 Tensor bias_v,
47 bool add_zero_attn,
48 double dropout_p,
49 Tensor out_proj_weight,
50 Tensor out_proj_bias)
51 : embed_dim_to_check_(embed_dim_to_check),
52 num_heads_(num_heads),
53 in_proj_weight_(std::move(in_proj_weight)),
54 in_proj_bias_(std::move(in_proj_bias)),
55 bias_k_(std::move(bias_k)),
56 bias_v_(std::move(bias_v)),
57 add_zero_attn_(add_zero_attn),
58 dropout_p_(dropout_p),
59 out_proj_weight_(std::move(out_proj_weight)),
60 out_proj_bias_(std::move(out_proj_bias)) {}
61
62 } // namespace functional
63 } // namespace nn
64 } // namespace torch
65