xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/modules/activation.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/nn/cloneable.h>
4 #include <torch/nn/functional/activation.h>
5 #include <torch/nn/modules/common.h>
6 #include <torch/nn/modules/linear.h>
7 #include <torch/nn/options/activation.h>
8 
9 #include <torch/csrc/Export.h>
10 
11 namespace torch {
12 namespace nn {
13 
14 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ELU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
15 
16 /// Applies elu over a given input.
17 /// See https://pytorch.org/docs/main/nn.html#torch.nn.ELU to learn
18 /// about the exact behavior of this module.
19 ///
20 /// See the documentation for `torch::nn::ELUOptions` class to learn what
21 /// constructor arguments are supported for this module.
22 ///
23 /// Example:
24 /// ```
25 /// ELU model(ELUOptions().alpha(42.42).inplace(true));
26 /// ```
class TORCH_API ELUImpl : public torch::nn::Cloneable<ELUImpl> {
 public:
  explicit ELUImpl(const ELUOptions& options_ = {});

  /// Applies ELU to `input`. Taken by value because the operation may be
  /// performed in-place when `options.inplace()` is set — confirm against
  /// the definition in activation.cpp.
  Tensor forward(Tensor input);

  /// Required by `Cloneable`; this module declares no parameters or buffers.
  void reset() override;

  /// Pretty prints the `ELU` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// The options with which this `Module` was constructed.
  ELUOptions options;
};

/// A `ModuleHolder` subclass for `ELUImpl`.
/// See the documentation for `ELUImpl` class to learn what methods it
/// provides, and examples of how to use `ELU` with `torch::nn::ELUOptions`.
/// See the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(ELU);
48 
49 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SELU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
50 
51 /// Applies the selu function element-wise.
52 /// See https://pytorch.org/docs/main/nn.html#torch.nn.SELU to learn
53 /// about the exact behavior of this module.
54 ///
55 /// See the documentation for `torch::nn::SELUOptions` class to learn what
56 /// constructor arguments are supported for this module.
57 ///
58 /// Example:
59 /// ```
60 /// SELU model(SELUOptions().inplace(true));
61 /// ```
class TORCH_API SELUImpl : public torch::nn::Cloneable<SELUImpl> {
 public:
  explicit SELUImpl(const SELUOptions& options_ = {});

  /// Applies SELU to `input`. Taken by value because the operation may be
  /// performed in-place when `options.inplace()` is set — confirm against
  /// the definition in activation.cpp.
  Tensor forward(Tensor input);

  /// Required by `Cloneable`; this module declares no parameters or buffers.
  void reset() override;

  /// Pretty prints the `SELU` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// The options with which this `Module` was constructed.
  SELUOptions options;
};

/// A `ModuleHolder` subclass for `SELUImpl`.
/// See the documentation for `SELUImpl` class to learn what methods it
/// provides, and examples of how to use `SELU` with `torch::nn::SELUOptions`.
/// See the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(SELU);
83 
84 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Hardshrink ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
85 
86 /// Applies the hard shrinkage function element-wise.
87 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Hardshrink to learn
88 /// about the exact behavior of this module.
89 ///
90 /// See the documentation for `torch::nn::HardshrinkOptions` class to learn what
91 /// constructor arguments are supported for this module.
92 ///
93 /// Example:
94 /// ```
95 /// Hardshrink model(HardshrinkOptions().lambda(42.42));
96 /// ```
class TORCH_API HardshrinkImpl : public torch::nn::Cloneable<HardshrinkImpl> {
 public:
  explicit HardshrinkImpl(const HardshrinkOptions& options_ = {});

  /// Applies hard shrinkage to `input` and returns the result as a new
  /// tensor (no in-place option exists for this module).
  Tensor forward(const Tensor& input);

  /// Required by `Cloneable`; this module declares no parameters or buffers.
  void reset() override;

  /// Pretty prints the `Hardshrink` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// The options with which this `Module` was constructed.
  HardshrinkOptions options;
};

/// A `ModuleHolder` subclass for `HardshrinkImpl`.
/// See the documentation for `HardshrinkImpl` class to learn what methods it
/// provides, and examples of how to use `Hardshrink` with
/// `torch::nn::HardshrinkOptions`. See the documentation for `ModuleHolder` to
/// learn about PyTorch's module storage semantics.
TORCH_MODULE(Hardshrink);
118 
119 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Hardtanh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
120 
121 /// Applies the HardTanh function element-wise.
122 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Hardtanh to learn
123 /// about the exact behavior of this module.
124 ///
125 /// See the documentation for `torch::nn::HardtanhOptions` class to learn what
126 /// constructor arguments are supported for this module.
127 ///
128 /// Example:
129 /// ```
130 /// Hardtanh
131 /// model(HardtanhOptions().min_val(-42.42).max_val(0.42).inplace(true));
132 /// ```
class TORCH_API HardtanhImpl : public torch::nn::Cloneable<HardtanhImpl> {
 public:
  explicit HardtanhImpl(const HardtanhOptions& options_ = {});

  /// Applies HardTanh to `input`, clamping values to
  /// [`options.min_val()`, `options.max_val()`]. Taken by value because the
  /// operation may be performed in-place when `options.inplace()` is set —
  /// confirm against the definition in activation.cpp.
  Tensor forward(Tensor input);

  /// Required by `Cloneable`; this module declares no parameters or buffers.
  void reset() override;

  /// Pretty prints the `Hardtanh` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// The options with which this `Module` was constructed.
  HardtanhOptions options;
};

/// A `ModuleHolder` subclass for `HardtanhImpl`.
/// See the documentation for `HardtanhImpl` class to learn what methods it
/// provides, and examples of how to use `Hardtanh` with
/// `torch::nn::HardtanhOptions`. See the documentation for `ModuleHolder` to
/// learn about PyTorch's module storage semantics.
TORCH_MODULE(Hardtanh);
154 
155 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LeakyReLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
156 
157 /// Applies the LeakyReLU function element-wise.
158 /// See https://pytorch.org/docs/main/nn.html#torch.nn.LeakyReLU to learn
159 /// about the exact behavior of this module.
160 ///
161 /// See the documentation for `torch::nn::LeakyReLUOptions` class to learn what
162 /// constructor arguments are supported for this module.
163 ///
164 /// Example:
165 /// ```
166 /// LeakyReLU model(LeakyReLUOptions().negative_slope(0.42).inplace(true));
167 /// ```
class TORCH_API LeakyReLUImpl : public torch::nn::Cloneable<LeakyReLUImpl> {
 public:
  explicit LeakyReLUImpl(const LeakyReLUOptions& options_ = {});

  /// Applies LeakyReLU to `input` using `options.negative_slope()`. Taken by
  /// value because the operation may be performed in-place when
  /// `options.inplace()` is set — confirm against the definition in
  /// activation.cpp.
  Tensor forward(Tensor input);

  /// Required by `Cloneable`; this module declares no parameters or buffers.
  void reset() override;

  /// Pretty prints the `LeakyReLU` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// The options with which this `Module` was constructed.
  LeakyReLUOptions options;
};

/// A `ModuleHolder` subclass for `LeakyReLUImpl`.
/// See the documentation for `LeakyReLUImpl` class to learn what methods it
/// provides, and examples of how to use `LeakyReLU` with
/// `torch::nn::LeakyReLUOptions`. See the documentation for `ModuleHolder` to
/// learn about PyTorch's module storage semantics.
TORCH_MODULE(LeakyReLU);
189 
190 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LogSigmoid ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
191 
192 /// Applies the LogSigmoid function element-wise.
193 /// See https://pytorch.org/docs/main/nn.html#torch.nn.LogSigmoid to learn
194 /// about the exact behavior of this module.
class TORCH_API LogSigmoidImpl : public torch::nn::Cloneable<LogSigmoidImpl> {
 public:
  /// Applies log-sigmoid to `input` and returns the result as a new tensor.
  Tensor forward(const Tensor& input);

  /// Required by `Cloneable`; this module declares no parameters, buffers,
  /// or options.
  void reset() override;

  /// Pretty prints the `LogSigmoid` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;
};

/// A `ModuleHolder` subclass for `LogSigmoidImpl`.
/// See the documentation for `LogSigmoidImpl` class to learn what methods it
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(LogSigmoid);
210 
211 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softmax ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
212 
213 /// Applies the Softmax function.
214 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Softmax to learn
215 /// about the exact behavior of this module.
216 ///
217 /// See the documentation for `torch::nn::SoftmaxOptions` class to learn what
218 /// constructor arguments are supported for this module.
219 ///
220 /// Example:
221 /// ```
222 /// Softmax model(SoftmaxOptions(1));
223 /// ```
224 class TORCH_API SoftmaxImpl : public torch::nn::Cloneable<SoftmaxImpl> {
225  public:
SoftmaxImpl(int64_t dim)226   explicit SoftmaxImpl(int64_t dim) : SoftmaxImpl(SoftmaxOptions(dim)) {}
227   explicit SoftmaxImpl(const SoftmaxOptions& options_);
228 
229   Tensor forward(const Tensor& input);
230 
231   void reset() override;
232 
233   /// Pretty prints the `Softmax` module into the given `stream`.
234   void pretty_print(std::ostream& stream) const override;
235 
236   SoftmaxOptions options;
237 };
238 
239 /// A `ModuleHolder` subclass for `SoftmaxImpl`.
240 /// See the documentation for `SoftmaxImpl` class to learn what methods it
241 /// provides, and examples of how to use `Softmax` with
242 /// `torch::nn::SoftmaxOptions`. See the documentation for `ModuleHolder` to
243 /// learn about PyTorch's module storage semantics.
244 TORCH_MODULE(Softmax);
245 
246 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softmin ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
247 
248 /// Applies the Softmin function element-wise.
249 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Softmin to learn
250 /// about the exact behavior of this module.
251 ///
252 /// See the documentation for `torch::nn::SoftminOptions` class to learn what
253 /// constructor arguments are supported for this module.
254 ///
255 /// Example:
256 /// ```
257 /// Softmin model(SoftminOptions(1));
258 /// ```
259 class TORCH_API SoftminImpl : public torch::nn::Cloneable<SoftminImpl> {
260  public:
SoftminImpl(int64_t dim)261   explicit SoftminImpl(int64_t dim) : SoftminImpl(SoftminOptions(dim)) {}
262   explicit SoftminImpl(const SoftminOptions& options_);
263 
264   Tensor forward(const Tensor& input);
265 
266   void reset() override;
267 
268   /// Pretty prints the `Softmin` module into the given `stream`.
269   void pretty_print(std::ostream& stream) const override;
270 
271   SoftminOptions options;
272 };
273 
274 /// A `ModuleHolder` subclass for `SoftminImpl`.
275 /// See the documentation for `SoftminImpl` class to learn what methods it
276 /// provides, and examples of how to use `Softmin` with
277 /// `torch::nn::SoftminOptions`. See the documentation for `ModuleHolder` to
278 /// learn about PyTorch's module storage semantics.
279 TORCH_MODULE(Softmin);
280 
281 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LogSoftmax ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
282 
283 /// Applies the LogSoftmax function element-wise.
284 /// See https://pytorch.org/docs/main/nn.html#torch.nn.LogSoftmax to learn
285 /// about the exact behavior of this module.
286 ///
287 /// See the documentation for `torch::nn::LogSoftmaxOptions` class to learn what
288 /// constructor arguments are supported for this module.
289 ///
290 /// Example:
291 /// ```
292 /// LogSoftmax model(LogSoftmaxOptions(1));
293 /// ```
294 class TORCH_API LogSoftmaxImpl : public torch::nn::Cloneable<LogSoftmaxImpl> {
295  public:
LogSoftmaxImpl(int64_t dim)296   explicit LogSoftmaxImpl(int64_t dim)
297       : LogSoftmaxImpl(LogSoftmaxOptions(dim)) {}
298   explicit LogSoftmaxImpl(const LogSoftmaxOptions& options_);
299 
300   Tensor forward(const Tensor& input);
301 
302   void reset() override;
303 
304   /// Pretty prints the `LogSoftmax` module into the given `stream`.
305   void pretty_print(std::ostream& stream) const override;
306 
307   LogSoftmaxOptions options;
308 };
309 
310 /// A `ModuleHolder` subclass for `LogSoftmaxImpl`.
311 /// See the documentation for `LogSoftmaxImpl` class to learn what methods it
312 /// provides, and examples of how to use `LogSoftmax` with
313 /// `torch::nn::LogSoftmaxOptions`. See the documentation for `ModuleHolder` to
314 /// learn about PyTorch's module storage semantics.
315 TORCH_MODULE(LogSoftmax);
316 
317 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softmax2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
318 
319 /// Applies the Softmax2d function element-wise.
320 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Softmax2d to learn
321 /// about the exact behavior of this module.
class TORCH_API Softmax2dImpl : public torch::nn::Cloneable<Softmax2dImpl> {
 public:
  /// Applies Softmax2d to `input` and returns the result as a new tensor.
  Tensor forward(const Tensor& input);

  /// Required by `Cloneable`; this module declares no parameters, buffers,
  /// or options.
  void reset() override;

  /// Pretty prints the `Softmax2d` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;
};

/// A `ModuleHolder` subclass for `Softmax2dImpl`.
/// See the documentation for `Softmax2dImpl` class to learn what methods it
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(Softmax2d);
337 
338 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PReLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
339 
340 /// Applies the PReLU function element-wise.
341 /// See https://pytorch.org/docs/main/nn.html#torch.nn.PReLU to learn
342 /// about the exact behavior of this module.
343 ///
344 /// See the documentation for `torch::nn::PReLUOptions` class to learn what
345 /// constructor arguments are supported for this module.
346 ///
347 /// Example:
348 /// ```
349 /// PReLU model(PReLUOptions().num_parameters(42));
350 /// ```
class TORCH_API PReLUImpl : public torch::nn::Cloneable<PReLUImpl> {
 public:
  explicit PReLUImpl(const PReLUOptions& options_ = {});

  /// Applies PReLU to `input` using the learned `weight` and returns the
  /// result as a new tensor.
  Tensor forward(const Tensor& input);

  /// Re-initializes `weight` — presumably sized by
  /// `options.num_parameters()`; confirm against the definition in
  /// activation.cpp.
  void reset() override;

  /// Pretty prints the `PReLU` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// The options with which this `Module` was constructed.
  PReLUOptions options;

  /// The learned weight.
  Tensor weight;
};

/// A `ModuleHolder` subclass for `PReLUImpl`.
/// See the documentation for `PReLUImpl` class to learn what methods it
/// provides, and examples of how to use `PReLU` with `torch::nn::PReLUOptions`.
/// See the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(PReLU);
375 
376 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
377 
378 /// Applies the ReLU function element-wise.
379 /// See https://pytorch.org/docs/main/nn.html#torch.nn.ReLU to learn
380 /// about the exact behavior of this module.
381 ///
382 /// See the documentation for `torch::nn::ReLUOptions` class to learn what
383 /// constructor arguments are supported for this module.
384 ///
385 /// Example:
386 /// ```
387 /// ReLU model(ReLUOptions().inplace(true));
388 /// ```
class TORCH_API ReLUImpl : public torch::nn::Cloneable<ReLUImpl> {
 public:
  explicit ReLUImpl(const ReLUOptions& options_ = {});

  /// Applies ReLU to `input`. Taken by value because the operation may be
  /// performed in-place when `options.inplace()` is set — confirm against
  /// the definition in activation.cpp.
  Tensor forward(Tensor input);

  /// Required by `Cloneable`; this module declares no parameters or buffers.
  void reset() override;

  /// Pretty prints the `ReLU` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// The options with which this `Module` was constructed.
  ReLUOptions options;
};

/// A `ModuleHolder` subclass for `ReLUImpl`.
/// See the documentation for `ReLUImpl` class to learn what methods it
/// provides, and examples of how to use `ReLU` with `torch::nn::ReLUOptions`.
/// See the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(ReLU);
410 
411 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReLU6 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
412 
413 /// Applies the ReLU6 function element-wise.
414 /// See https://pytorch.org/docs/main/nn.html#torch.nn.ReLU6 to learn
415 /// about the exact behavior of this module.
416 ///
417 /// See the documentation for `torch::nn::ReLU6Options` class to learn what
418 /// constructor arguments are supported for this module.
419 ///
420 /// Example:
421 /// ```
422 /// ReLU6 model(ReLU6Options().inplace(true));
423 /// ```
class TORCH_API ReLU6Impl : public torch::nn::Cloneable<ReLU6Impl> {
 public:
  explicit ReLU6Impl(const ReLU6Options& options_ = {});

  /// Applies ReLU6 to `input`. Taken by value because the operation may be
  /// performed in-place when `options.inplace()` is set — confirm against
  /// the definition in activation.cpp.
  Tensor forward(Tensor input);

  /// Required by `Cloneable`; this module declares no parameters or buffers.
  void reset() override;

  /// Pretty prints the `ReLU6` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// The options with which this `Module` was constructed.
  ReLU6Options options;
};

/// A `ModuleHolder` subclass for `ReLU6Impl`.
/// See the documentation for `ReLU6Impl` class to learn what methods it
/// provides, and examples of how to use `ReLU6` with `torch::nn::ReLU6Options`.
/// See the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(ReLU6);
445 
446 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RReLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
447 
448 /// Applies the RReLU function element-wise.
449 /// See https://pytorch.org/docs/main/nn.html#torch.nn.RReLU to learn
450 /// about the exact behavior of this module.
451 ///
452 /// See the documentation for `torch::nn::RReLUOptions` class to learn what
453 /// constructor arguments are supported for this module.
454 ///
455 /// Example:
456 /// ```
457 /// RReLU model(RReLUOptions().lower(0.24).upper(0.42).inplace(true));
458 /// ```
class TORCH_API RReLUImpl : public torch::nn::Cloneable<RReLUImpl> {
 public:
  explicit RReLUImpl(const RReLUOptions& options_ = {});

  /// Applies RReLU to `input` using the [`options.lower()`,
  /// `options.upper()`] range. Taken by value because the operation may be
  /// performed in-place when `options.inplace()` is set — confirm against
  /// the definition in activation.cpp.
  Tensor forward(Tensor input);

  /// Required by `Cloneable`; this module declares no parameters or buffers.
  void reset() override;

  /// Pretty prints the `RReLU` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// The options with which this `Module` was constructed.
  RReLUOptions options;
};

/// A `ModuleHolder` subclass for `RReLUImpl`.
/// See the documentation for `RReLUImpl` class to learn what methods it
/// provides, and examples of how to use `RReLU` with `torch::nn::RReLUOptions`.
/// See the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(RReLU);
480 
481 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CELU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
482 
483 /// Applies celu over a given input.
484 /// See https://pytorch.org/docs/main/nn.html#torch.nn.CELU to learn
485 /// about the exact behavior of this module.
486 ///
487 /// See the documentation for `torch::nn::CELUOptions` class to learn what
488 /// constructor arguments are supported for this module.
489 ///
490 /// Example:
491 /// ```
492 /// CELU model(CELUOptions().alpha(42.42).inplace(true));
493 /// ```
class TORCH_API CELUImpl : public torch::nn::Cloneable<CELUImpl> {
 public:
  explicit CELUImpl(const CELUOptions& options_ = {});

  /// Applies CELU to `input` using `options.alpha()`. Taken by value because
  /// the operation may be performed in-place when `options.inplace()` is
  /// set — confirm against the definition in activation.cpp.
  Tensor forward(Tensor input);

  /// Required by `Cloneable`; this module declares no parameters or buffers.
  void reset() override;

  /// Pretty prints the `CELU` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// The options with which this `Module` was constructed.
  CELUOptions options;
};

/// A `ModuleHolder` subclass for `CELUImpl`.
/// See the documentation for `CELUImpl` class to learn what methods it
/// provides, and examples of how to use `CELU` with `torch::nn::CELUOptions`.
/// See the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(CELU);
515 
516 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
517 
518 /// Applies glu over a given input.
519 /// See https://pytorch.org/docs/main/nn.html#torch.nn.GLU to learn
520 /// about the exact behavior of this module.
521 ///
522 /// See the documentation for `torch::nn::GLUOptions` class to learn what
523 /// constructor arguments are supported for this module.
524 ///
525 /// Example:
526 /// ```
527 /// GLU model(GLUOptions(1));
528 /// ```
class TORCH_API GLUImpl : public torch::nn::Cloneable<GLUImpl> {
 public:
  explicit GLUImpl(const GLUOptions& options_ = {});

  /// Applies GLU to `input`, splitting it along `options.dim()`, and returns
  /// the result as a new tensor.
  Tensor forward(const Tensor& input);

  /// Required by `Cloneable`; this module declares no parameters or buffers.
  void reset() override;

  /// Pretty prints the `GLU` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// The options with which this `Module` was constructed.
  GLUOptions options;
};

/// A `ModuleHolder` subclass for `GLUImpl`.
/// See the documentation for `GLUImpl` class to learn what methods it
/// provides, and examples of how to use `GLU` with `torch::nn::GLUOptions`.
/// See the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(GLU);
550 
551 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GELU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
552 
553 /// Applies gelu over a given input.
554 /// See https://pytorch.org/docs/main/nn.html#torch.nn.GELU to learn
555 /// about the exact behavior of this module.
class TORCH_API GELUImpl : public torch::nn::Cloneable<GELUImpl> {
 public:
  /// Note: options are taken by value here, unlike the const-reference style
  /// used by the other modules in this file.
  explicit GELUImpl(GELUOptions options_ = {});

  /// Applies GELU to `input` and returns the result as a new tensor.
  Tensor forward(const Tensor& input);

  /// Required by `Cloneable`; this module declares no parameters or buffers.
  void reset() override;

  /// Pretty prints the `GELU` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// The options with which this `Module` was constructed.
  GELUOptions options;
};

/// A `ModuleHolder` subclass for `GELUImpl`.
/// See the documentation for `GELUImpl` class to learn what methods it
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(GELU);
576 
577 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SiLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
578 
579 /// Applies silu over a given input.
580 /// See https://pytorch.org/docs/main/nn.html#torch.nn.SiLU to learn
581 /// about the exact behavior of this module.
class TORCH_API SiLUImpl : public torch::nn::Cloneable<SiLUImpl> {
 public:
  /// Applies SiLU to `input` and returns the result as a new tensor.
  Tensor forward(const Tensor& input);

  /// Required by `Cloneable`; this module declares no parameters, buffers,
  /// or options.
  void reset() override;

  /// Pretty prints the `SiLU` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;
};

/// A `ModuleHolder` subclass for `SiLUImpl`.
/// See the documentation for `SiLUImpl` class to learn what methods it
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(SiLU);
597 
598 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Mish ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
599 
600 /// Applies mish over a given input.
601 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Mish to learn
602 /// about the exact behavior of this module.
class TORCH_API MishImpl : public torch::nn::Cloneable<MishImpl> {
 public:
  /// Applies Mish to `input` and returns the result as a new tensor.
  Tensor forward(const Tensor& input);

  /// Required by `Cloneable`; this module declares no parameters, buffers,
  /// or options.
  void reset() override;

  /// Pretty prints the `Mish` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;
};

/// A `ModuleHolder` subclass for `MishImpl`.
/// See the documentation for `MishImpl` class to learn what methods it
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(Mish);
618 
619 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Sigmoid ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
620 
621 /// Applies sigmoid over a given input.
622 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Sigmoid to learn
623 /// about the exact behavior of this module.
class TORCH_API SigmoidImpl : public torch::nn::Cloneable<SigmoidImpl> {
 public:
  /// Applies sigmoid to `input` and returns the result as a new tensor.
  Tensor forward(const Tensor& input);

  /// Required by `Cloneable`; this module declares no parameters, buffers,
  /// or options.
  void reset() override;

  /// Pretty prints the `Sigmoid` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;
};

/// A `ModuleHolder` subclass for `SigmoidImpl`.
/// See the documentation for `SigmoidImpl` class to learn what methods it
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(Sigmoid);
639 
640 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softplus ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
641 
642 /// Applies softplus over a given input.
643 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Softplus to learn
644 /// about the exact behavior of this module.
645 ///
646 /// See the documentation for `torch::nn::SoftplusOptions` class to learn what
647 /// constructor arguments are supported for this module.
648 ///
649 /// Example:
650 /// ```
651 /// Softplus model(SoftplusOptions().beta(0.24).threshold(42.42));
652 /// ```
class TORCH_API SoftplusImpl : public torch::nn::Cloneable<SoftplusImpl> {
 public:
  explicit SoftplusImpl(const SoftplusOptions& options_ = {});

  /// Applies softplus to `input` using `options.beta()` and
  /// `options.threshold()`, returning the result as a new tensor.
  Tensor forward(const Tensor& input);

  /// Required by `Cloneable`; this module declares no parameters or buffers.
  void reset() override;

  /// Pretty prints the `Softplus` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// The options with which this `Module` was constructed.
  SoftplusOptions options;
};

/// A `ModuleHolder` subclass for `SoftplusImpl`.
/// See the documentation for `SoftplusImpl` class to learn what methods it
/// provides, and examples of how to use `Softplus` with
/// `torch::nn::SoftplusOptions`. See the documentation for `ModuleHolder` to
/// learn about PyTorch's module storage semantics.
TORCH_MODULE(Softplus);
674 
675 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softshrink ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
676 
677 /// Applies the soft shrinkage function element-wise.
678 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Softshrink to learn
679 /// about the exact behavior of this module.
680 ///
681 /// See the documentation for `torch::nn::SoftshrinkOptions` class to learn what
682 /// constructor arguments are supported for this module.
683 ///
684 /// Example:
685 /// ```
686 /// Softshrink model(SoftshrinkOptions(42.42));
687 /// ```
class TORCH_API SoftshrinkImpl : public torch::nn::Cloneable<SoftshrinkImpl> {
 public:
  explicit SoftshrinkImpl(const SoftshrinkOptions& options_ = {});

  /// Applies soft shrinkage to `input` and returns the result as a new
  /// tensor.
  Tensor forward(const Tensor& input);

  /// Required by `Cloneable`; this module declares no parameters or buffers.
  void reset() override;

  /// Pretty prints the `Softshrink` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// The options with which this `Module` was constructed.
  SoftshrinkOptions options;
};

/// A `ModuleHolder` subclass for `SoftshrinkImpl`.
/// See the documentation for `SoftshrinkImpl` class to learn what methods it
/// provides, and examples of how to use `Softshrink` with
/// `torch::nn::SoftshrinkOptions`. See the documentation for `ModuleHolder` to
/// learn about PyTorch's module storage semantics.
TORCH_MODULE(Softshrink);
709 
710 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softsign ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
711 
712 /// Applies Softsign over a given input.
713 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Softsign to learn
714 /// about the exact behavior of this module.
class TORCH_API SoftsignImpl : public torch::nn::Cloneable<SoftsignImpl> {
 public:
  /// Applies Softsign to `input` and returns the result as a new tensor.
  Tensor forward(const Tensor& input);

  /// Required by `Cloneable`; this module declares no parameters, buffers,
  /// or options.
  void reset() override;

  /// Pretty prints the `Softsign` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;
};

/// A `ModuleHolder` subclass for `SoftsignImpl`.
/// See the documentation for `SoftsignImpl` class to learn what methods it
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(Softsign);
730 
731 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Tanh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
732 
733 /// Applies Tanh over a given input.
734 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Tanh to learn
735 /// about the exact behavior of this module.
class TORCH_API TanhImpl : public torch::nn::Cloneable<TanhImpl> {
 public:
  /// Applies tanh to `input` and returns the result as a new tensor.
  Tensor forward(const Tensor& input);

  /// Required by `Cloneable`; this module declares no parameters, buffers,
  /// or options.
  void reset() override;

  /// Pretty prints the `Tanh` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;
};

/// A `ModuleHolder` subclass for `TanhImpl`.
/// See the documentation for `TanhImpl` class to learn what methods it
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(Tanh);
751 
752 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Tanhshrink ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
753 
754 /// Applies Tanhshrink over a given input.
755 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Tanhshrink to learn
756 /// about the exact behavior of this module.
class TORCH_API TanhshrinkImpl : public torch::nn::Cloneable<TanhshrinkImpl> {
 public:
  /// Applies Tanhshrink to `input` and returns the result as a new tensor.
  Tensor forward(const Tensor& input);

  /// Required by `Cloneable`; this module declares no parameters, buffers,
  /// or options.
  void reset() override;

  /// Pretty prints the `Tanhshrink` module into the given `stream`.
  void pretty_print(std::ostream& stream) const override;
};

/// A `ModuleHolder` subclass for `TanhshrinkImpl`.
/// See the documentation for `TanhshrinkImpl` class to learn what methods it
/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(Tanhshrink);
772 
773 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Threshold ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
774 
775 /// Applies the Threshold function element-wise.
776 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Threshold to learn
777 /// about the exact behavior of this module.
778 ///
779 /// See the documentation for `torch::nn::ThresholdOptions` class to learn what
780 /// constructor arguments are supported for this module.
781 ///
782 /// Example:
783 /// ```
784 /// Threshold model(ThresholdOptions(42.42, 24.24).inplace(true));
785 /// ```
786 class TORCH_API ThresholdImpl : public torch::nn::Cloneable<ThresholdImpl> {
787  public:
ThresholdImpl(double threshold,double value)788   ThresholdImpl(double threshold, double value)
789       : ThresholdImpl(ThresholdOptions(threshold, value)) {}
790   explicit ThresholdImpl(const ThresholdOptions& options_);
791 
792   Tensor forward(Tensor input);
793 
794   void reset() override;
795 
796   /// Pretty prints the `Threshold` module into the given `stream`.
797   void pretty_print(std::ostream& stream) const override;
798 
799   /// The options with which this `Module` was constructed.
800   ThresholdOptions options;
801 };
802 
803 /// A `ModuleHolder` subclass for `ThresholdImpl`.
804 /// See the documentation for `ThresholdImpl` class to learn what methods it
805 /// provides, and examples of how to use `Threshold` with
806 /// `torch::nn::ThresholdOptions`. See the documentation for `ModuleHolder` to
807 /// learn about PyTorch's module storage semantics.
808 TORCH_MODULE(Threshold);
809 
810 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiheadAttention ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
811 
812 /// Applies the MultiheadAttention function element-wise.
813 /// See https://pytorch.org/docs/main/nn.html#torch.nn.MultiheadAttention
814 /// to learn about the exact behavior of this module.
815 ///
816 /// See the documentation for `torch::nn::MultiheadAttentionOptions` class to
817 /// learn what constructor arguments are supported for this module.
818 ///
819 /// Example:
820 /// ```
821 /// MultiheadAttention model(MultiheadAttentionOptions(20, 10).bias(false));
822 /// ```
823 class TORCH_API MultiheadAttentionImpl
824     : public torch::nn::Cloneable<MultiheadAttentionImpl> {
825  public:
MultiheadAttentionImpl(int64_t embed_dim,int64_t num_heads)826   MultiheadAttentionImpl(int64_t embed_dim, int64_t num_heads)
827       : MultiheadAttentionImpl(
828             MultiheadAttentionOptions(embed_dim, num_heads)) {}
829   explicit MultiheadAttentionImpl(const MultiheadAttentionOptions& options_);
830 
831   std::tuple<Tensor, Tensor> forward(
832       const Tensor& query,
833       const Tensor& key,
834       const Tensor& value,
835       const Tensor& key_padding_mask = {},
836       bool need_weights = true,
837       const Tensor& attn_mask = {},
838       bool average_attn_weights = true);
839 
840  protected:
841   FORWARD_HAS_DEFAULT_ARGS(
842       {3, AnyValue(Tensor())},
843       {4, AnyValue(true)},
844       {5, AnyValue(Tensor())},
845       {6, AnyValue(true)})
846 
847  public:
848   void reset() override;
849 
850   void _reset_parameters();
851 
852   /// The options with which this `Module` was constructed.
853   MultiheadAttentionOptions options;
854 
855   bool _qkv_same_embed_dim;
856   Tensor in_proj_weight;
857   Tensor in_proj_bias;
858   Tensor bias_k;
859   Tensor bias_v;
860   Linear out_proj = nullptr;
861   Tensor q_proj_weight;
862   Tensor k_proj_weight;
863   Tensor v_proj_weight;
864   int64_t head_dim;
865 };
866 
867 /// A `ModuleHolder` subclass for `MultiheadAttentionImpl`.
868 /// See the documentation for `MultiheadAttentionImpl` class to learn what
869 /// methods it provides, and examples of how to use `MultiheadAttention` with
870 /// `torch::nn::MultiheadAttentionOptions`. See the documentation for
871 /// `ModuleHolder` to learn about PyTorch's module storage semantics.
872 TORCH_MODULE(MultiheadAttention);
873 
874 } // namespace nn
875 } // namespace torch
876