1 #pragma once 2 3 #include <torch/nn/cloneable.h> 4 #include <torch/nn/functional/activation.h> 5 #include <torch/nn/modules/common.h> 6 #include <torch/nn/modules/linear.h> 7 #include <torch/nn/options/activation.h> 8 9 #include <torch/csrc/Export.h> 10 11 namespace torch { 12 namespace nn { 13 14 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ELU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 15 16 /// Applies elu over a given input. 17 /// See https://pytorch.org/docs/main/nn.html#torch.nn.ELU to learn 18 /// about the exact behavior of this module. 19 /// 20 /// See the documentation for `torch::nn::ELUOptions` class to learn what 21 /// constructor arguments are supported for this module. 22 /// 23 /// Example: 24 /// ``` 25 /// ELU model(ELUOptions().alpha(42.42).inplace(true)); 26 /// ``` 27 class TORCH_API ELUImpl : public torch::nn::Cloneable<ELUImpl> { 28 public: 29 explicit ELUImpl(const ELUOptions& options_ = {}); 30 31 Tensor forward(Tensor input); 32 33 void reset() override; 34 35 /// Pretty prints the `ELU` module into the given `stream`. 36 void pretty_print(std::ostream& stream) const override; 37 38 /// The options with which this `Module` was constructed. 39 ELUOptions options; 40 }; 41 42 /// A `ModuleHolder` subclass for `ELUImpl`. 43 /// See the documentation for `ELUImpl` class to learn what methods it 44 /// provides, and examples of how to use `ELU` with `torch::nn::ELUOptions`. 45 /// See the documentation for `ModuleHolder` to learn about PyTorch's 46 /// module storage semantics. 47 TORCH_MODULE(ELU); 48 49 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SELU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 50 51 /// Applies the selu function element-wise. 52 /// See https://pytorch.org/docs/main/nn.html#torch.nn.SELU to learn 53 /// about the exact behavior of this module. 54 /// 55 /// See the documentation for `torch::nn::SELUOptions` class to learn what 56 /// constructor arguments are supported for this module. 57 /// 58 /// Example: 59 /// ``` 60 /// SELU model(SELUOptions().inplace(true)); 61 /// ``` 62 class TORCH_API SELUImpl : public torch::nn::Cloneable<SELUImpl> { 63 public: 64 explicit SELUImpl(const SELUOptions& options_ = {}); 65 66 Tensor forward(Tensor input); 67 68 void reset() override; 69 70 /// Pretty prints the `SELU` module into the given `stream`. 71 void pretty_print(std::ostream& stream) const override; 72 73 /// The options with which this `Module` was constructed. 74 SELUOptions options; 75 }; 76 77 /// A `ModuleHolder` subclass for `SELUImpl`. 78 /// See the documentation for `SELUImpl` class to learn what methods it 79 /// provides, and examples of how to use `SELU` with `torch::nn::SELUOptions`. 80 /// See the documentation for `ModuleHolder` to learn about PyTorch's 81 /// module storage semantics. 82 TORCH_MODULE(SELU); 83 84 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Hardshrink ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 85 86 /// Applies the hard shrinkage function element-wise. 87 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Hardshrink to learn 88 /// about the exact behavior of this module. 89 /// 90 /// See the documentation for `torch::nn::HardshrinkOptions` class to learn what 91 /// constructor arguments are supported for this module. 92 /// 93 /// Example: 94 /// ``` 95 /// Hardshrink model(HardshrinkOptions().lambda(42.42)); 96 /// ``` 97 class TORCH_API HardshrinkImpl : public torch::nn::Cloneable<HardshrinkImpl> { 98 public: 99 explicit HardshrinkImpl(const HardshrinkOptions& options_ = {}); 100 101 Tensor forward(const Tensor& input); 102 103 void reset() override; 104 105 /// Pretty prints the `Hardshrink` module into the given `stream`. 106 void pretty_print(std::ostream& stream) const override; 107 108 /// The options with which this `Module` was constructed. 109 HardshrinkOptions options; 110 }; 111 112 /// A `ModuleHolder` subclass for `HardshrinkImpl`. 113 /// See the documentation for `HardshrinkImpl` class to learn what methods it 114 /// provides, and examples of how to use `Hardshrink` with 115 /// `torch::nn::HardshrinkOptions`. See the documentation for `ModuleHolder` to 116 /// learn about PyTorch's module storage semantics. 117 TORCH_MODULE(Hardshrink); 118 119 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Hardtanh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 120 121 /// Applies the HardTanh function element-wise. 122 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Hardtanh to learn 123 /// about the exact behavior of this module. 124 /// 125 /// See the documentation for `torch::nn::HardtanhOptions` class to learn what 126 /// constructor arguments are supported for this module. 127 /// 128 /// Example: 129 /// ``` 130 /// Hardtanh 131 /// model(HardtanhOptions().min_val(-42.42).max_val(0.42).inplace(true)); 132 /// ``` 133 class TORCH_API HardtanhImpl : public torch::nn::Cloneable<HardtanhImpl> { 134 public: 135 explicit HardtanhImpl(const HardtanhOptions& options_ = {}); 136 137 Tensor forward(Tensor input); 138 139 void reset() override; 140 141 /// Pretty prints the `Hardtanh` module into the given `stream`. 142 void pretty_print(std::ostream& stream) const override; 143 144 /// The options with which this `Module` was constructed. 145 HardtanhOptions options; 146 }; 147 148 /// A `ModuleHolder` subclass for `HardtanhImpl`. 149 /// See the documentation for `HardtanhImpl` class to learn what methods it 150 /// provides, and examples of how to use `Hardtanh` with 151 /// `torch::nn::HardtanhOptions`. See the documentation for `ModuleHolder` to 152 /// learn about PyTorch's module storage semantics. 153 TORCH_MODULE(Hardtanh); 154 155 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LeakyReLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 156 157 /// Applies the LeakyReLU function element-wise. 158 /// See https://pytorch.org/docs/main/nn.html#torch.nn.LeakyReLU to learn 159 /// about the exact behavior of this module. 160 /// 161 /// See the documentation for `torch::nn::LeakyReLUOptions` class to learn what 162 /// constructor arguments are supported for this module. 163 /// 164 /// Example: 165 /// ``` 166 /// LeakyReLU model(LeakyReLUOptions().negative_slope(0.42).inplace(true)); 167 /// ``` 168 class TORCH_API LeakyReLUImpl : public torch::nn::Cloneable<LeakyReLUImpl> { 169 public: 170 explicit LeakyReLUImpl(const LeakyReLUOptions& options_ = {}); 171 172 Tensor forward(Tensor input); 173 174 void reset() override; 175 176 /// Pretty prints the `LeakyReLU` module into the given `stream`. 177 void pretty_print(std::ostream& stream) const override; 178 179 /// The options with which this `Module` was constructed. 180 LeakyReLUOptions options; 181 }; 182 183 /// A `ModuleHolder` subclass for `LeakyReLUImpl`. 184 /// See the documentation for `LeakyReLUImpl` class to learn what methods it 185 /// provides, and examples of how to use `LeakyReLU` with 186 /// `torch::nn::LeakyReLUOptions`. See the documentation for `ModuleHolder` to 187 /// learn about PyTorch's module storage semantics. 188 TORCH_MODULE(LeakyReLU); 189 190 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LogSigmoid ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 191 192 /// Applies the LogSigmoid function element-wise. 193 /// See https://pytorch.org/docs/main/nn.html#torch.nn.LogSigmoid to learn 194 /// about the exact behavior of this module. 195 class TORCH_API LogSigmoidImpl : public torch::nn::Cloneable<LogSigmoidImpl> { 196 public: 197 Tensor forward(const Tensor& input); 198 199 void reset() override; 200 201 /// Pretty prints the `LogSigmoid` module into the given `stream`. 202 void pretty_print(std::ostream& stream) const override; 203 }; 204 205 /// A `ModuleHolder` subclass for `LogSigmoidImpl`. 206 /// See the documentation for `LogSigmoidImpl` class to learn what methods it 207 /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's 208 /// module storage semantics. 209 TORCH_MODULE(LogSigmoid); 210 211 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softmax ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 212 213 /// Applies the Softmax function. 214 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Softmax to learn 215 /// about the exact behavior of this module. 216 /// 217 /// See the documentation for `torch::nn::SoftmaxOptions` class to learn what 218 /// constructor arguments are supported for this module. 219 /// 220 /// Example: 221 /// ``` 222 /// Softmax model(SoftmaxOptions(1)); 223 /// ``` 224 class TORCH_API SoftmaxImpl : public torch::nn::Cloneable<SoftmaxImpl> { 225 public: SoftmaxImpl(int64_t dim)226 explicit SoftmaxImpl(int64_t dim) : SoftmaxImpl(SoftmaxOptions(dim)) {} 227 explicit SoftmaxImpl(const SoftmaxOptions& options_); 228 229 Tensor forward(const Tensor& input); 230 231 void reset() override; 232 233 /// Pretty prints the `Softmax` module into the given `stream`. 234 void pretty_print(std::ostream& stream) const override; 235 236 SoftmaxOptions options; 237 }; 238 239 /// A `ModuleHolder` subclass for `SoftmaxImpl`. 240 /// See the documentation for `SoftmaxImpl` class to learn what methods it 241 /// provides, and examples of how to use `Softmax` with 242 /// `torch::nn::SoftmaxOptions`. See the documentation for `ModuleHolder` to 243 /// learn about PyTorch's module storage semantics. 244 TORCH_MODULE(Softmax); 245 246 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softmin ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 247 248 /// Applies the Softmin function element-wise. 249 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Softmin to learn 250 /// about the exact behavior of this module. 251 /// 252 /// See the documentation for `torch::nn::SoftminOptions` class to learn what 253 /// constructor arguments are supported for this module. 254 /// 255 /// Example: 256 /// ``` 257 /// Softmin model(SoftminOptions(1)); 258 /// ``` 259 class TORCH_API SoftminImpl : public torch::nn::Cloneable<SoftminImpl> { 260 public: SoftminImpl(int64_t dim)261 explicit SoftminImpl(int64_t dim) : SoftminImpl(SoftminOptions(dim)) {} 262 explicit SoftminImpl(const SoftminOptions& options_); 263 264 Tensor forward(const Tensor& input); 265 266 void reset() override; 267 268 /// Pretty prints the `Softmin` module into the given `stream`. 269 void pretty_print(std::ostream& stream) const override; 270 271 SoftminOptions options; 272 }; 273 274 /// A `ModuleHolder` subclass for `SoftminImpl`. 275 /// See the documentation for `SoftminImpl` class to learn what methods it 276 /// provides, and examples of how to use `Softmin` with 277 /// `torch::nn::SoftminOptions`. See the documentation for `ModuleHolder` to 278 /// learn about PyTorch's module storage semantics. 279 TORCH_MODULE(Softmin); 280 281 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LogSoftmax ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 282 283 /// Applies the LogSoftmax function element-wise. 284 /// See https://pytorch.org/docs/main/nn.html#torch.nn.LogSoftmax to learn 285 /// about the exact behavior of this module. 286 /// 287 /// See the documentation for `torch::nn::LogSoftmaxOptions` class to learn what 288 /// constructor arguments are supported for this module. 289 /// 290 /// Example: 291 /// ``` 292 /// LogSoftmax model(LogSoftmaxOptions(1)); 293 /// ``` 294 class TORCH_API LogSoftmaxImpl : public torch::nn::Cloneable<LogSoftmaxImpl> { 295 public: LogSoftmaxImpl(int64_t dim)296 explicit LogSoftmaxImpl(int64_t dim) 297 : LogSoftmaxImpl(LogSoftmaxOptions(dim)) {} 298 explicit LogSoftmaxImpl(const LogSoftmaxOptions& options_); 299 300 Tensor forward(const Tensor& input); 301 302 void reset() override; 303 304 /// Pretty prints the `LogSoftmax` module into the given `stream`. 305 void pretty_print(std::ostream& stream) const override; 306 307 LogSoftmaxOptions options; 308 }; 309 310 /// A `ModuleHolder` subclass for `LogSoftmaxImpl`. 311 /// See the documentation for `LogSoftmaxImpl` class to learn what methods it 312 /// provides, and examples of how to use `LogSoftmax` with 313 /// `torch::nn::LogSoftmaxOptions`. See the documentation for `ModuleHolder` to 314 /// learn about PyTorch's module storage semantics. 315 TORCH_MODULE(LogSoftmax); 316 317 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softmax2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 318 319 /// Applies the Softmax2d function element-wise. 320 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Softmax2d to learn 321 /// about the exact behavior of this module. 322 class TORCH_API Softmax2dImpl : public torch::nn::Cloneable<Softmax2dImpl> { 323 public: 324 Tensor forward(const Tensor& input); 325 326 void reset() override; 327 328 /// Pretty prints the `Softmax2d` module into the given `stream`. 329 void pretty_print(std::ostream& stream) const override; 330 }; 331 332 /// A `ModuleHolder` subclass for `Softmax2dImpl`. 333 /// See the documentation for `Softmax2dImpl` class to learn what methods it 334 /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's 335 /// module storage semantics. 336 TORCH_MODULE(Softmax2d); 337 338 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PReLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 339 340 /// Applies the PReLU function element-wise. 341 /// See https://pytorch.org/docs/main/nn.html#torch.nn.PReLU to learn 342 /// about the exact behavior of this module. 343 /// 344 /// See the documentation for `torch::nn::PReLUOptions` class to learn what 345 /// constructor arguments are supported for this module. 346 /// 347 /// Example: 348 /// ``` 349 /// PReLU model(PReLUOptions().num_parameters(42)); 350 /// ``` 351 class TORCH_API PReLUImpl : public torch::nn::Cloneable<PReLUImpl> { 352 public: 353 explicit PReLUImpl(const PReLUOptions& options_ = {}); 354 355 Tensor forward(const Tensor& input); 356 357 void reset() override; 358 359 /// Pretty prints the `PReLU` module into the given `stream`. 360 void pretty_print(std::ostream& stream) const override; 361 362 /// The options with which this `Module` was constructed. 363 PReLUOptions options; 364 365 /// The learned weight. 366 Tensor weight; 367 }; 368 369 /// A `ModuleHolder` subclass for `PReLUImpl`. 370 /// See the documentation for `PReLUImpl` class to learn what methods it 371 /// provides, and examples of how to use `PReLU` with `torch::nn::PReLUOptions`. 372 /// See the documentation for `ModuleHolder` to learn about PyTorch's 373 /// module storage semantics. 374 TORCH_MODULE(PReLU); 375 376 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 377 378 /// Applies the ReLU function element-wise. 379 /// See https://pytorch.org/docs/main/nn.html#torch.nn.ReLU to learn 380 /// about the exact behavior of this module. 381 /// 382 /// See the documentation for `torch::nn::ReLUOptions` class to learn what 383 /// constructor arguments are supported for this module. 384 /// 385 /// Example: 386 /// ``` 387 /// ReLU model(ReLUOptions().inplace(true)); 388 /// ``` 389 class TORCH_API ReLUImpl : public torch::nn::Cloneable<ReLUImpl> { 390 public: 391 explicit ReLUImpl(const ReLUOptions& options_ = {}); 392 393 Tensor forward(Tensor input); 394 395 void reset() override; 396 397 /// Pretty prints the `ReLU` module into the given `stream`. 398 void pretty_print(std::ostream& stream) const override; 399 400 /// The options with which this `Module` was constructed. 401 ReLUOptions options; 402 }; 403 404 /// A `ModuleHolder` subclass for `ReLUImpl`. 405 /// See the documentation for `ReLUImpl` class to learn what methods it 406 /// provides, and examples of how to use `ReLU` with `torch::nn::ReLUOptions`. 407 /// See the documentation for `ModuleHolder` to learn about PyTorch's 408 /// module storage semantics. 409 TORCH_MODULE(ReLU); 410 411 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReLU6 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 412 413 /// Applies the ReLU6 function element-wise. 414 /// See https://pytorch.org/docs/main/nn.html#torch.nn.ReLU6 to learn 415 /// about the exact behavior of this module. 416 /// 417 /// See the documentation for `torch::nn::ReLU6Options` class to learn what 418 /// constructor arguments are supported for this module. 419 /// 420 /// Example: 421 /// ``` 422 /// ReLU6 model(ReLU6Options().inplace(true)); 423 /// ``` 424 class TORCH_API ReLU6Impl : public torch::nn::Cloneable<ReLU6Impl> { 425 public: 426 explicit ReLU6Impl(const ReLU6Options& options_ = {}); 427 428 Tensor forward(Tensor input); 429 430 void reset() override; 431 432 /// Pretty prints the `ReLU6` module into the given `stream`. 433 void pretty_print(std::ostream& stream) const override; 434 435 /// The options with which this `Module` was constructed. 436 ReLU6Options options; 437 }; 438 439 /// A `ModuleHolder` subclass for `ReLU6Impl`. 440 /// See the documentation for `ReLU6Impl` class to learn what methods it 441 /// provides, and examples of how to use `ReLU6` with `torch::nn::ReLU6Options`. 442 /// See the documentation for `ModuleHolder` to learn about PyTorch's 443 /// module storage semantics. 444 TORCH_MODULE(ReLU6); 445 446 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RReLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 447 448 /// Applies the RReLU function element-wise. 449 /// See https://pytorch.org/docs/main/nn.html#torch.nn.RReLU to learn 450 /// about the exact behavior of this module. 451 /// 452 /// See the documentation for `torch::nn::RReLUOptions` class to learn what 453 /// constructor arguments are supported for this module. 454 /// 455 /// Example: 456 /// ``` 457 /// RReLU model(RReLUOptions().lower(0.24).upper(0.42).inplace(true)); 458 /// ``` 459 class TORCH_API RReLUImpl : public torch::nn::Cloneable<RReLUImpl> { 460 public: 461 explicit RReLUImpl(const RReLUOptions& options_ = {}); 462 463 Tensor forward(Tensor input); 464 465 void reset() override; 466 467 /// Pretty prints the `RReLU` module into the given `stream`. 468 void pretty_print(std::ostream& stream) const override; 469 470 /// The options with which this `Module` was constructed. 471 RReLUOptions options; 472 }; 473 474 /// A `ModuleHolder` subclass for `RReLUImpl`. 475 /// See the documentation for `RReLUImpl` class to learn what methods it 476 /// provides, and examples of how to use `RReLU` with `torch::nn::RReLUOptions`. 477 /// See the documentation for `ModuleHolder` to learn about PyTorch's 478 /// module storage semantics. 479 TORCH_MODULE(RReLU); 480 481 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CELU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 482 483 /// Applies celu over a given input. 484 /// See https://pytorch.org/docs/main/nn.html#torch.nn.CELU to learn 485 /// about the exact behavior of this module. 486 /// 487 /// See the documentation for `torch::nn::CELUOptions` class to learn what 488 /// constructor arguments are supported for this module. 489 /// 490 /// Example: 491 /// ``` 492 /// CELU model(CELUOptions().alpha(42.42).inplace(true)); 493 /// ``` 494 class TORCH_API CELUImpl : public torch::nn::Cloneable<CELUImpl> { 495 public: 496 explicit CELUImpl(const CELUOptions& options_ = {}); 497 498 Tensor forward(Tensor input); 499 500 void reset() override; 501 502 /// Pretty prints the `CELU` module into the given `stream`. 503 void pretty_print(std::ostream& stream) const override; 504 505 /// The options with which this `Module` was constructed. 506 CELUOptions options; 507 }; 508 509 /// A `ModuleHolder` subclass for `CELUImpl`. 510 /// See the documentation for `CELUImpl` class to learn what methods it 511 /// provides, and examples of how to use `CELU` with `torch::nn::CELUOptions`. 512 /// See the documentation for `ModuleHolder` to learn about PyTorch's 513 /// module storage semantics. 514 TORCH_MODULE(CELU); 515 516 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 517 518 /// Applies glu over a given input. 519 /// See https://pytorch.org/docs/main/nn.html#torch.nn.GLU to learn 520 /// about the exact behavior of this module. 521 /// 522 /// See the documentation for `torch::nn::GLUOptions` class to learn what 523 /// constructor arguments are supported for this module. 524 /// 525 /// Example: 526 /// ``` 527 /// GLU model(GLUOptions(1)); 528 /// ``` 529 class TORCH_API GLUImpl : public torch::nn::Cloneable<GLUImpl> { 530 public: 531 explicit GLUImpl(const GLUOptions& options_ = {}); 532 533 Tensor forward(const Tensor& input); 534 535 void reset() override; 536 537 /// Pretty prints the `GLU` module into the given `stream`. 538 void pretty_print(std::ostream& stream) const override; 539 540 /// The options with which this `Module` was constructed. 541 GLUOptions options; 542 }; 543 544 /// A `ModuleHolder` subclass for `GLUImpl`. 545 /// See the documentation for `GLUImpl` class to learn what methods it 546 /// provides, and examples of how to use `GLU` with `torch::nn::GLUOptions`. 547 /// See the documentation for `ModuleHolder` to learn about PyTorch's 548 /// module storage semantics. 549 TORCH_MODULE(GLU); 550 551 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GELU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 552 553 /// Applies gelu over a given input. 554 /// See https://pytorch.org/docs/main/nn.html#torch.nn.GELU to learn 555 /// about the exact behavior of this module. 556 class TORCH_API GELUImpl : public torch::nn::Cloneable<GELUImpl> { 557 public: 558 explicit GELUImpl(GELUOptions options_ = {}); 559 560 Tensor forward(const Tensor& input); 561 562 void reset() override; 563 564 /// Pretty prints the `GELU` module into the given `stream`. 565 void pretty_print(std::ostream& stream) const override; 566 567 /// The options with which this `Module` was constructed. 568 GELUOptions options; 569 }; 570 571 /// A `ModuleHolder` subclass for `GELUImpl`. 572 /// See the documentation for `GELUImpl` class to learn what methods it 573 /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's 574 /// module storage semantics. 575 TORCH_MODULE(GELU); 576 577 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SiLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 578 579 /// Applies silu over a given input. 580 /// See https://pytorch.org/docs/main/nn.html#torch.nn.SiLU to learn 581 /// about the exact behavior of this module. 582 class TORCH_API SiLUImpl : public torch::nn::Cloneable<SiLUImpl> { 583 public: 584 Tensor forward(const Tensor& input); 585 586 void reset() override; 587 588 /// Pretty prints the `SiLU` module into the given `stream`. 589 void pretty_print(std::ostream& stream) const override; 590 }; 591 592 /// A `ModuleHolder` subclass for `SiLUImpl`. 593 /// See the documentation for `SiLUImpl` class to learn what methods it 594 /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's 595 /// module storage semantics. 596 TORCH_MODULE(SiLU); 597 598 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Mish ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 599 600 /// Applies mish over a given input. 601 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Mish to learn 602 /// about the exact behavior of this module. 603 class TORCH_API MishImpl : public torch::nn::Cloneable<MishImpl> { 604 public: 605 Tensor forward(const Tensor& input); 606 607 void reset() override; 608 609 /// Pretty prints the `Mish` module into the given `stream`. 610 void pretty_print(std::ostream& stream) const override; 611 }; 612 613 /// A `ModuleHolder` subclass for `MishImpl`. 614 /// See the documentation for `MishImpl` class to learn what methods it 615 /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's 616 /// module storage semantics. 617 TORCH_MODULE(Mish); 618 619 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Sigmoid ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 620 621 /// Applies sigmoid over a given input. 622 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Sigmoid to learn 623 /// about the exact behavior of this module. 624 class TORCH_API SigmoidImpl : public torch::nn::Cloneable<SigmoidImpl> { 625 public: 626 Tensor forward(const Tensor& input); 627 628 void reset() override; 629 630 /// Pretty prints the `Sigmoid` module into the given `stream`. 631 void pretty_print(std::ostream& stream) const override; 632 }; 633 634 /// A `ModuleHolder` subclass for `SigmoidImpl`. 635 /// See the documentation for `SigmoidImpl` class to learn what methods it 636 /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's 637 /// module storage semantics. 638 TORCH_MODULE(Sigmoid); 639 640 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softplus ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 641 642 /// Applies softplus over a given input. 643 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Softplus to learn 644 /// about the exact behavior of this module. 645 /// 646 /// See the documentation for `torch::nn::SoftplusOptions` class to learn what 647 /// constructor arguments are supported for this module. 648 /// 649 /// Example: 650 /// ``` 651 /// Softplus model(SoftplusOptions().beta(0.24).threshold(42.42)); 652 /// ``` 653 class TORCH_API SoftplusImpl : public torch::nn::Cloneable<SoftplusImpl> { 654 public: 655 explicit SoftplusImpl(const SoftplusOptions& options_ = {}); 656 657 Tensor forward(const Tensor& input); 658 659 void reset() override; 660 661 /// Pretty prints the `Softplus` module into the given `stream`. 662 void pretty_print(std::ostream& stream) const override; 663 664 /// The options with which this `Module` was constructed. 665 SoftplusOptions options; 666 }; 667 668 /// A `ModuleHolder` subclass for `SoftplusImpl`. 669 /// See the documentation for `SoftplusImpl` class to learn what methods it 670 /// provides, and examples of how to use `Softplus` with 671 /// `torch::nn::SoftplusOptions`. See the documentation for `ModuleHolder` to 672 /// learn about PyTorch's module storage semantics. 673 TORCH_MODULE(Softplus); 674 675 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softshrink ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 676 677 /// Applies the soft shrinkage function element-wise. 678 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Softshrink to learn 679 /// about the exact behavior of this module. 680 /// 681 /// See the documentation for `torch::nn::SoftshrinkOptions` class to learn what 682 /// constructor arguments are supported for this module. 683 /// 684 /// Example: 685 /// ``` 686 /// Softshrink model(SoftshrinkOptions(42.42)); 687 /// ``` 688 class TORCH_API SoftshrinkImpl : public torch::nn::Cloneable<SoftshrinkImpl> { 689 public: 690 explicit SoftshrinkImpl(const SoftshrinkOptions& options_ = {}); 691 692 Tensor forward(const Tensor& input); 693 694 void reset() override; 695 696 /// Pretty prints the `Softshrink` module into the given `stream`. 697 void pretty_print(std::ostream& stream) const override; 698 699 /// The options with which this `Module` was constructed. 700 SoftshrinkOptions options; 701 }; 702 703 /// A `ModuleHolder` subclass for `SoftshrinkImpl`. 704 /// See the documentation for `SoftshrinkImpl` class to learn what methods it 705 /// provides, and examples of how to use `Softshrink` with 706 /// `torch::nn::SoftshrinkOptions`. See the documentation for `ModuleHolder` to 707 /// learn about PyTorch's module storage semantics. 708 TORCH_MODULE(Softshrink); 709 710 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softsign ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 711 712 /// Applies Softsign over a given input. 713 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Softsign to learn 714 /// about the exact behavior of this module. 715 class TORCH_API SoftsignImpl : public torch::nn::Cloneable<SoftsignImpl> { 716 public: 717 Tensor forward(const Tensor& input); 718 719 void reset() override; 720 721 /// Pretty prints the `Softsign` module into the given `stream`. 722 void pretty_print(std::ostream& stream) const override; 723 }; 724 725 /// A `ModuleHolder` subclass for `SoftsignImpl`. 726 /// See the documentation for `SoftsignImpl` class to learn what methods it 727 /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's 728 /// module storage semantics. 729 TORCH_MODULE(Softsign); 730 731 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Tanh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 732 733 /// Applies Tanh over a given input. 734 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Tanh to learn 735 /// about the exact behavior of this module. 736 class TORCH_API TanhImpl : public torch::nn::Cloneable<TanhImpl> { 737 public: 738 Tensor forward(const Tensor& input); 739 740 void reset() override; 741 742 /// Pretty prints the `Tanh` module into the given `stream`. 743 void pretty_print(std::ostream& stream) const override; 744 }; 745 746 /// A `ModuleHolder` subclass for `TanhImpl`. 747 /// See the documentation for `TanhImpl` class to learn what methods it 748 /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's 749 /// module storage semantics. 750 TORCH_MODULE(Tanh); 751 752 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Tanhshrink ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 753 754 /// Applies Tanhshrink over a given input. 755 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Tanhshrink to learn 756 /// about the exact behavior of this module. 757 class TORCH_API TanhshrinkImpl : public torch::nn::Cloneable<TanhshrinkImpl> { 758 public: 759 Tensor forward(const Tensor& input); 760 761 void reset() override; 762 763 /// Pretty prints the `Tanhshrink` module into the given `stream`. 764 void pretty_print(std::ostream& stream) const override; 765 }; 766 767 /// A `ModuleHolder` subclass for `TanhshrinkImpl`. 768 /// See the documentation for `TanhshrinkImpl` class to learn what methods it 769 /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's 770 /// module storage semantics. 771 TORCH_MODULE(Tanhshrink); 772 773 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Threshold ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 774 775 /// Applies the Threshold function element-wise. 776 /// See https://pytorch.org/docs/main/nn.html#torch.nn.Threshold to learn 777 /// about the exact behavior of this module. 778 /// 779 /// See the documentation for `torch::nn::ThresholdOptions` class to learn what 780 /// constructor arguments are supported for this module. 781 /// 782 /// Example: 783 /// ``` 784 /// Threshold model(ThresholdOptions(42.42, 24.24).inplace(true)); 785 /// ``` 786 class TORCH_API ThresholdImpl : public torch::nn::Cloneable<ThresholdImpl> { 787 public: ThresholdImpl(double threshold,double value)788 ThresholdImpl(double threshold, double value) 789 : ThresholdImpl(ThresholdOptions(threshold, value)) {} 790 explicit ThresholdImpl(const ThresholdOptions& options_); 791 792 Tensor forward(Tensor input); 793 794 void reset() override; 795 796 /// Pretty prints the `Threshold` module into the given `stream`. 797 void pretty_print(std::ostream& stream) const override; 798 799 /// The options with which this `Module` was constructed. 800 ThresholdOptions options; 801 }; 802 803 /// A `ModuleHolder` subclass for `ThresholdImpl`. 804 /// See the documentation for `ThresholdImpl` class to learn what methods it 805 /// provides, and examples of how to use `Threshold` with 806 /// `torch::nn::ThresholdOptions`. See the documentation for `ModuleHolder` to 807 /// learn about PyTorch's module storage semantics. 808 TORCH_MODULE(Threshold); 809 810 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiheadAttention ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 811 812 /// Applies the MultiheadAttention function element-wise. 813 /// See https://pytorch.org/docs/main/nn.html#torch.nn.MultiheadAttention 814 /// to learn about the exact behavior of this module. 815 /// 816 /// See the documentation for `torch::nn::MultiheadAttentionOptions` class to 817 /// learn what constructor arguments are supported for this module. 818 /// 819 /// Example: 820 /// ``` 821 /// MultiheadAttention model(MultiheadAttentionOptions(20, 10).bias(false)); 822 /// ``` 823 class TORCH_API MultiheadAttentionImpl 824 : public torch::nn::Cloneable<MultiheadAttentionImpl> { 825 public: MultiheadAttentionImpl(int64_t embed_dim,int64_t num_heads)826 MultiheadAttentionImpl(int64_t embed_dim, int64_t num_heads) 827 : MultiheadAttentionImpl( 828 MultiheadAttentionOptions(embed_dim, num_heads)) {} 829 explicit MultiheadAttentionImpl(const MultiheadAttentionOptions& options_); 830 831 std::tuple<Tensor, Tensor> forward( 832 const Tensor& query, 833 const Tensor& key, 834 const Tensor& value, 835 const Tensor& key_padding_mask = {}, 836 bool need_weights = true, 837 const Tensor& attn_mask = {}, 838 bool average_attn_weights = true); 839 840 protected: 841 FORWARD_HAS_DEFAULT_ARGS( 842 {3, AnyValue(Tensor())}, 843 {4, AnyValue(true)}, 844 {5, AnyValue(Tensor())}, 845 {6, AnyValue(true)}) 846 847 public: 848 void reset() override; 849 850 void _reset_parameters(); 851 852 /// The options with which this `Module` was constructed. 853 MultiheadAttentionOptions options; 854 855 bool _qkv_same_embed_dim; 856 Tensor in_proj_weight; 857 Tensor in_proj_bias; 858 Tensor bias_k; 859 Tensor bias_v; 860 Linear out_proj = nullptr; 861 Tensor q_proj_weight; 862 Tensor k_proj_weight; 863 Tensor v_proj_weight; 864 int64_t head_dim; 865 }; 866 867 /// A `ModuleHolder` subclass for `MultiheadAttentionImpl`. 868 /// See the documentation for `MultiheadAttentionImpl` class to learn what 869 /// methods it provides, and examples of how to use `MultiheadAttention` with 870 /// `torch::nn::MultiheadAttentionOptions`. See the documentation for 871 /// `ModuleHolder` to learn about PyTorch's module storage semantics. 872 TORCH_MODULE(MultiheadAttention); 873 874 } // namespace nn 875 } // namespace torch 876