#include <torch/nn/functional/activation.h>
#include <torch/nn/init.h>
#include <torch/nn/modules/activation.h>

namespace F = torch::nn::functional;

namespace torch {
namespace nn {

ELUImpl::ELUImpl(const ELUOptions& options_) : options(options_) {}

Tensor ELUImpl::forward(Tensor input) {
  return F::detail::elu(input, options.alpha(), options.inplace());
}

void ELUImpl::reset() {}

void ELUImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::ELU(alpha=" << options.alpha();
  if (options.inplace()) {
    stream << std::boolalpha << ", inplace=" << options.inplace();
  }
  stream << ")";
}
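
// Example usage (sketch): construct the module holder with options and call it
// like a function, e.g.
//   ELU elu(ELUOptions().alpha(1.0));
//   auto y = elu(torch::randn({2, 3}));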

// ============================================================================

SELUImpl::SELUImpl(const SELUOptions& options_) : options(options_) {}

Tensor SELUImpl::forward(Tensor input) {
  return F::detail::selu(input, options.inplace());
}

void SELUImpl::reset() {}

void SELUImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::SELU(";
  if (options.inplace()) {
    stream << std::boolalpha << "inplace=" << options.inplace();
  }
  stream << ")";
}

// ============================================================================

HardshrinkImpl::HardshrinkImpl(const HardshrinkOptions& options_)
    : options(options_) {}

Tensor HardshrinkImpl::forward(const Tensor& input) {
  return F::detail::hardshrink(input, options.lambda());
}

void HardshrinkImpl::reset() {}

void HardshrinkImpl::pretty_print(std::ostream& stream) const {
  stream << std::boolalpha << "torch::nn::Hardshrink(" << options.lambda()
         << ")";
}

// ============================================================================

HardtanhImpl::HardtanhImpl(const HardtanhOptions& options_)
    : options(options_) {
  // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
  reset();
}

Tensor HardtanhImpl::forward(Tensor input) {
  return F::detail::hardtanh(
      input, options.min_val(), options.max_val(), options.inplace());
}

void HardtanhImpl::reset() {
  TORCH_CHECK(
      options.max_val() > options.min_val(),
      "max_val must be greater than min_val");
}

void HardtanhImpl::pretty_print(std::ostream& stream) const {
  stream << std::boolalpha
         << "torch::nn::Hardtanh(min_val=" << options.min_val()
         << ", max_val=" << options.max_val();
  if (options.inplace()) {
    stream << std::boolalpha << ", inplace=" << options.inplace();
  }
  stream << ")";
}

// ============================================================================

LeakyReLUImpl::LeakyReLUImpl(const LeakyReLUOptions& options_)
    : options(options_) {}

Tensor LeakyReLUImpl::forward(Tensor input) {
  return F::detail::leaky_relu(
      input, options.negative_slope(), options.inplace());
}

void LeakyReLUImpl::reset() {}

void LeakyReLUImpl::pretty_print(std::ostream& stream) const {
  stream << std::boolalpha
         << "torch::nn::LeakyReLU(negative_slope=" << options.negative_slope();
  if (options.inplace()) {
    stream << std::boolalpha << ", inplace=" << options.inplace();
  }
  stream << ")";
}

// ============================================================================

Tensor LogSigmoidImpl::forward(const Tensor& input) {
  return F::logsigmoid(input);
}

void LogSigmoidImpl::reset() {}

void LogSigmoidImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::LogSigmoid()";
}

// ============================================================================

SoftmaxImpl::SoftmaxImpl(const SoftmaxOptions& options_) : options(options_) {}

void SoftmaxImpl::reset() {}

void SoftmaxImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::Softmax(dim=" << options.dim() << ")";
}

Tensor SoftmaxImpl::forward(const Tensor& input) {
  return F::detail::softmax(input, options.dim(), std::nullopt);
}

// ============================================================================

SoftminImpl::SoftminImpl(const SoftminOptions& options_) : options(options_) {}

void SoftminImpl::reset() {}

void SoftminImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::Softmin(dim=" << options.dim() << ")";
}

Tensor SoftminImpl::forward(const Tensor& input) {
  return F::detail::softmin(input, options.dim(), std::nullopt);
}

// ============================================================================

LogSoftmaxImpl::LogSoftmaxImpl(const LogSoftmaxOptions& options_)
    : options(options_) {}

void LogSoftmaxImpl::reset() {}

void LogSoftmaxImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::LogSoftmax(dim=" << options.dim() << ")";
}

Tensor LogSoftmaxImpl::forward(const Tensor& input) {
  return F::detail::log_softmax(input, options.dim(), std::nullopt);
}

// ============================================================================

void Softmax2dImpl::reset() {}

void Softmax2dImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::Softmax2d()";
}

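// Softmax2d applies softmax over the channel dimension (dim = -3), i.e.
// independently at every spatial location of a (N, C, H, W) or (C, H, W) input.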
Tensor Softmax2dImpl::forward(const Tensor& input) {
  TORCH_CHECK(
      input.dim() == 4 || input.dim() == 3,
      "Softmax2d requires a 3D or 4D tensor as input");
  return F::detail::softmax(input, /*dim=*/-3, std::nullopt);
}

// ============================================================================

PReLUImpl::PReLUImpl(const PReLUOptions& options_) : options(options_) {
  // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
  reset();
}

Tensor PReLUImpl::forward(const Tensor& input) {
  return F::prelu(input, weight);
}

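// The only learnable state is `weight`: a 1-D tensor with `num_parameters`
// entries, each filled with the configured `init` value.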
void PReLUImpl::reset() {
  weight = register_parameter(
      "weight", torch::full(options.num_parameters(), options.init()));
}

void PReLUImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::PReLU(num_parameters=" << options.num_parameters()
         << ")";
}

// ============================================================================

ReLUImpl::ReLUImpl(const ReLUOptions& options_) : options(options_) {}

Tensor ReLUImpl::forward(Tensor input) {
  return F::detail::relu(input, options.inplace());
}

void ReLUImpl::reset() {}

void ReLUImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::ReLU(";
  if (options.inplace()) {
    stream << std::boolalpha << "inplace=" << options.inplace();
  }
  stream << ")";
}

// ============================================================================

ReLU6Impl::ReLU6Impl(const ReLU6Options& options_) : options(options_) {}

Tensor ReLU6Impl::forward(Tensor input) {
  return F::detail::relu6(input, options.inplace());
}

void ReLU6Impl::reset() {}

void ReLU6Impl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::ReLU6(";
  if (options.inplace()) {
    stream << std::boolalpha << "inplace=" << options.inplace();
  }
  stream << ")";
}

// ============================================================================

RReLUImpl::RReLUImpl(const RReLUOptions& options_) : options(options_) {}

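// RReLU samples the negative slope uniformly from [lower, upper] while the
// module is in training mode; in evaluation mode the functional uses a fixed
// slope of (lower + upper) / 2.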
Tensor RReLUImpl::forward(Tensor input) {
  return F::detail::rrelu(
      input,
      options.lower(),
      options.upper(),
      is_training(),
      options.inplace());
}

void RReLUImpl::reset() {}

void RReLUImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::RReLU(lower=" << options.lower()
         << ", upper=" << options.upper();
  if (options.inplace()) {
    stream << std::boolalpha << ", inplace=" << options.inplace();
  }
  stream << ")";
}

// ============================================================================

CELUImpl::CELUImpl(const CELUOptions& options_) : options(options_) {}

Tensor CELUImpl::forward(Tensor input) {
  return F::detail::celu(input, options.alpha(), options.inplace());
}

void CELUImpl::reset() {}

void CELUImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::CELU(alpha=" << options.alpha();
  if (options.inplace()) {
    stream << std::boolalpha << ", inplace=" << options.inplace();
  }
  stream << ")";
}

// ============================================================================

GLUImpl::GLUImpl(const GLUOptions& options_) : options(options_) {}

Tensor GLUImpl::forward(const Tensor& input) {
  return F::detail::glu(input, options.dim());
}

void GLUImpl::reset() {}

void GLUImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::GLU(dim=" << options.dim() << ")";
}

// ============================================================================

GELUImpl::GELUImpl(GELUOptions options_) : options(std::move(options_)) {}

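// `approximate` selects the GELU formulation: "none" uses the exact erf-based
// definition, while "tanh" uses the tanh approximation.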
Tensor GELUImpl::forward(const Tensor& input) {
  return F::detail::gelu(input, options.approximate());
}

void GELUImpl::reset() {}

void GELUImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::GELU()";
}

// ============================================================================

Tensor SiLUImpl::forward(const Tensor& input) {
  return F::silu(input);
}

void SiLUImpl::reset() {}

void SiLUImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::SiLU()";
}

// ============================================================================

Tensor MishImpl::forward(const Tensor& input) {
  return F::mish(input);
}

void MishImpl::reset() {}

void MishImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::Mish()";
}

// ============================================================================

Tensor SigmoidImpl::forward(const Tensor& input) {
  return torch::sigmoid(input);
}

void SigmoidImpl::reset() {}

void SigmoidImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::Sigmoid()";
}

// ============================================================================

SoftplusImpl::SoftplusImpl(const SoftplusOptions& options_)
    : options(options_) {}

Tensor SoftplusImpl::forward(const Tensor& input) {
  return F::detail::softplus(input, options.beta(), options.threshold());
}

void SoftplusImpl::reset() {}

void SoftplusImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::Softplus(beta=" << options.beta()
         << ", threshold=" << options.threshold() << ")";
}

// ============================================================================

SoftshrinkImpl::SoftshrinkImpl(const SoftshrinkOptions& options_)
    : options(options_) {}

Tensor SoftshrinkImpl::forward(const Tensor& input) {
  return F::detail::softshrink(input, options.lambda());
}

void SoftshrinkImpl::reset() {}

void SoftshrinkImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::Softshrink(" << options.lambda() << ")";
}

// ============================================================================

Tensor SoftsignImpl::forward(const Tensor& input) {
  return F::softsign(input);
}

void SoftsignImpl::reset() {}

void SoftsignImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::Softsign()";
}

// ============================================================================

Tensor TanhImpl::forward(const Tensor& input) {
  return torch::tanh(input);
}

void TanhImpl::reset() {}

void TanhImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::Tanh()";
}

// ============================================================================

Tensor TanhshrinkImpl::forward(const Tensor& input) {
  return F::tanhshrink(input);
}

void TanhshrinkImpl::reset() {}

void TanhshrinkImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::Tanhshrink()";
}

// ============================================================================

ThresholdImpl::ThresholdImpl(const ThresholdOptions& options_)
    : options(options_) {}

Tensor ThresholdImpl::forward(Tensor input) {
  return F::detail::threshold(
      input, options.threshold(), options.value(), options.inplace());
}

void ThresholdImpl::reset() {}

void ThresholdImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::Threshold(threshold=" << options.threshold()
         << ", value=" << options.value();
  if (options.inplace()) {
    stream << std::boolalpha << ", inplace=" << options.inplace();
  }
  stream << ")";
}

// ============================================================================

MultiheadAttentionImpl::MultiheadAttentionImpl(
    const MultiheadAttentionOptions& options_)
    : Cloneable("torch::nn::MultiheadAttention"), options(options_) {
  // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
  reset();
}

std::tuple<Tensor, Tensor> MultiheadAttentionImpl::forward(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const Tensor& key_padding_mask,
    bool need_weights,
    const Tensor& attn_mask,
    bool average_attn_weights) {
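  // When kdim or vdim differ from embed_dim, query/key/value are projected
  // with separate weight matrices (q/k/v_proj_weight); otherwise a single
  // packed in_proj_weight covers all three projections.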
  if (!_qkv_same_embed_dim) {
    return F::multi_head_attention_forward(
        query,
        key,
        value,
        F::MultiheadAttentionForwardFuncOptions(
            /*embed_dim_to_check=*/options.embed_dim(),
            /*num_heads=*/options.num_heads(),
            /*in_proj_weight=*/in_proj_weight,
            /*in_proj_bias=*/in_proj_bias,
            /*bias_k=*/bias_k,
            /*bias_v=*/bias_v,
            /*add_zero_attn=*/options.add_zero_attn(),
            /*dropout_p=*/options.dropout(),
            /*out_proj_weight=*/out_proj->weight,
            /*out_proj_bias=*/out_proj->bias)
            .training(is_training())
            .key_padding_mask(key_padding_mask)
            .need_weights(need_weights)
            .attn_mask(attn_mask)
            .use_separate_proj_weight(true)
            .q_proj_weight(q_proj_weight)
            .k_proj_weight(k_proj_weight)
            .v_proj_weight(v_proj_weight)
            .average_attn_weights(average_attn_weights));
  } else {
    return F::multi_head_attention_forward(
        query,
        key,
        value,
        F::MultiheadAttentionForwardFuncOptions(
            /*embed_dim_to_check=*/options.embed_dim(),
            /*num_heads=*/options.num_heads(),
            /*in_proj_weight=*/in_proj_weight,
            /*in_proj_bias=*/in_proj_bias,
            /*bias_k=*/bias_k,
            /*bias_v=*/bias_v,
            /*add_zero_attn=*/options.add_zero_attn(),
            /*dropout_p=*/options.dropout(),
            /*out_proj_weight=*/out_proj->weight,
            /*out_proj_bias=*/out_proj->bias)
            .training(is_training())
            .key_padding_mask(key_padding_mask)
            .need_weights(need_weights)
            .attn_mask(attn_mask)
            .average_attn_weights(average_attn_weights));
  }
}

void MultiheadAttentionImpl::reset() {
  _qkv_same_embed_dim = options.kdim() == options.embed_dim() &&
      options.vdim() == options.embed_dim();
  head_dim = options.embed_dim() / options.num_heads();
  TORCH_CHECK(
      head_dim * options.num_heads() == options.embed_dim(),
      "embed_dim must be divisible by num_heads");
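  // Register either separate per-projection weights (when kdim/vdim differ
  // from embed_dim) or a single packed in_proj_weight of shape
  // {3 * embed_dim, embed_dim}; the unused alternative is registered as an
  // undefined, non-trainable placeholder.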
  if (!_qkv_same_embed_dim) {
    q_proj_weight = register_parameter(
        "q_proj_weight",
        torch::empty({options.embed_dim(), options.embed_dim()}));
    k_proj_weight = register_parameter(
        "k_proj_weight", torch::empty({options.embed_dim(), options.kdim()}));
    v_proj_weight = register_parameter(
        "v_proj_weight", torch::empty({options.embed_dim(), options.vdim()}));
    register_parameter("in_proj_weight", {}, /*requires_grad=*/false);
  } else {
    in_proj_weight = register_parameter(
        "in_proj_weight",
        torch::empty({3 * options.embed_dim(), options.embed_dim()}));
    register_parameter("q_proj_weight", {}, /*requires_grad=*/false);
    register_parameter("k_proj_weight", {}, /*requires_grad=*/false);
    register_parameter("v_proj_weight", {}, /*requires_grad=*/false);
  }
  if (options.bias()) {
    in_proj_bias = register_parameter(
        "in_proj_bias", torch::empty(3 * options.embed_dim()));
  } else {
    register_parameter("in_proj_bias", {}, /*requires_grad=*/false);
  }
  out_proj = register_module(
      "out_proj",
      Linear(LinearOptions(options.embed_dim(), options.embed_dim())
                 .bias(options.bias())));
  if (options.add_bias_kv()) {
    bias_k =
        register_parameter("bias_k", torch::empty({1, 1, options.embed_dim()}));
    bias_v =
        register_parameter("bias_v", torch::empty({1, 1, options.embed_dim()}));
  } else {
    bias_k.reset();
    bias_v.reset();
  }
  _reset_parameters();
}

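// Parameter initialization: Xavier-uniform for the projection weights, zeros
// for the input/output projection biases, and Xavier-normal for the optional
// bias_k / bias_v.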
void MultiheadAttentionImpl::_reset_parameters() {
  using namespace torch::nn::init;
  if (_qkv_same_embed_dim) {
    xavier_uniform_(in_proj_weight);
  } else {
    xavier_uniform_(q_proj_weight);
    xavier_uniform_(k_proj_weight);
    xavier_uniform_(v_proj_weight);
  }
  if (in_proj_bias.defined()) {
    constant_(in_proj_bias, 0.);
    constant_(out_proj->bias, 0.);
  }
  if (bias_k.defined()) {
    xavier_normal_(bias_k);
  }
  if (bias_v.defined()) {
    xavier_normal_(bias_v);
  }
}

} // namespace nn
} // namespace torch