#pragma once

#include <torch/arg.h>
#include <torch/csrc/Export.h>
#include <torch/enum.h>
#include <torch/types.h>

namespace torch {
namespace nn {

/// Options for the `ELU` module.
///
/// Example:
/// ```
/// ELU model(ELUOptions().alpha(42.42).inplace(true));
/// ```
struct TORCH_API ELUOptions {
  /// The `alpha` value for the ELU formulation. Default: 1.0
  TORCH_ARG(double, alpha) = 1.0;

  /// can optionally do the operation in-place. Default: False
  TORCH_ARG(bool, inplace) = false;
};

namespace functional {
/// Options for `torch::nn::functional::elu`.
///
/// See the documentation for `torch::nn::ELUOptions` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::elu(x, F::ELUFuncOptions().alpha(0.42).inplace(true));
/// ```
using ELUFuncOptions = ELUOptions;
} // namespace functional

// ============================================================================

/// Options for the `SELU` module.
///
/// Example:
/// ```
/// SELU model(SELUOptions().inplace(true));
/// ```
struct TORCH_API SELUOptions {
  /* implicit */ SELUOptions(bool inplace = false);

  /// can optionally do the operation in-place. Default: False
  TORCH_ARG(bool, inplace);
};

namespace functional {
/// Options for `torch::nn::functional::selu`.
///
/// See the documentation for `torch::nn::SELUOptions` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::selu(input, F::SELUFuncOptions(false));
/// ```
using SELUFuncOptions = SELUOptions;
} // namespace functional

// ============================================================================

/// Options for the `GLU` module.
///
/// Example:
/// ```
/// GLU model(GLUOptions(1));
/// ```
struct TORCH_API GLUOptions {
  /* implicit */ GLUOptions(int64_t dim = -1);

  /// the dimension on which to split the input. Default: -1
  TORCH_ARG(int64_t, dim);
};

namespace functional {
/// Options for `torch::nn::functional::glu`.
///
/// See the documentation for `torch::nn::GLUOptions` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::glu(input, F::GLUFuncOptions(1));
/// ```
using GLUFuncOptions = GLUOptions;
} // namespace functional

// ============================================================================

/// Options for the `GELU` module.
///
/// Example:
/// ```
/// GELU model(GELUOptions().approximate("none"));
/// ```
struct TORCH_API GELUOptions {
  /// Specifies the approximation to apply to the output.
  /// Valid values are "none" and "tanh". Default: "none"
  TORCH_ARG(std::string, approximate) = "none";
};

namespace functional {
/// Options for `torch::nn::functional::gelu`.
///
/// See the documentation for `torch::nn::GELUOptions` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::gelu(input, F::GELUFuncOptions().approximate("none"));
/// ```
using GELUFuncOptions = GELUOptions;
} // namespace functional
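
// Usage sketch (illustrative only, not part of the original header): options
// structs like the ones above can drive either the corresponding module or the
// stateless functional form, shown here with GELU. Assumes `x` is a defined
// tensor.
//
//   namespace F = torch::nn::functional;
//   torch::Tensor x = torch::randn({2, 3});
//   torch::nn::GELU gelu(torch::nn::GELUOptions().approximate("tanh"));
//   torch::Tensor y_module = gelu(x);
//   torch::Tensor y_func = F::gelu(x, F::GELUFuncOptions().approximate("tanh"));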

// ============================================================================

/// Options for the `Hardshrink` module.
///
/// Example:
/// ```
/// Hardshrink model(HardshrinkOptions().lambda(42.42));
/// ```
struct TORCH_API HardshrinkOptions {
  /* implicit */ HardshrinkOptions(double lambda = 0.5);

  /// the `lambda` value for the Hardshrink formulation. Default: 0.5
  TORCH_ARG(double, lambda);
};

namespace functional {
/// Options for `torch::nn::functional::hardshrink`.
///
/// See the documentation for `torch::nn::HardshrinkOptions` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::hardshrink(x, F::HardshrinkFuncOptions().lambda(0.42));
/// ```
using HardshrinkFuncOptions = HardshrinkOptions;
} // namespace functional

// ============================================================================

/// Options for the `Hardtanh` module.
///
/// Example:
/// ```
/// Hardtanh
/// model(HardtanhOptions().min_val(-42.42).max_val(0.42).inplace(true));
/// ```
struct TORCH_API HardtanhOptions {
  /// minimum value of the linear region range. Default: -1
  TORCH_ARG(double, min_val) = -1.0;

  /// maximum value of the linear region range. Default: 1
  TORCH_ARG(double, max_val) = 1.0;

  /// can optionally do the operation in-place. Default: False
  TORCH_ARG(bool, inplace) = false;
};

namespace functional {
/// Options for `torch::nn::functional::hardtanh`.
///
/// See the documentation for `torch::nn::HardtanhOptions` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::hardtanh(x,
/// F::HardtanhFuncOptions().min_val(-1.0).max_val(1.0).inplace(true));
/// ```
using HardtanhFuncOptions = HardtanhOptions;
} // namespace functional

// ============================================================================

/// Options for the `LeakyReLU` module.
///
/// Example:
/// ```
/// LeakyReLU model(LeakyReLUOptions().negative_slope(0.42).inplace(true));
/// ```
struct TORCH_API LeakyReLUOptions {
  /// Controls the angle of the negative slope. Default: 1e-2
  TORCH_ARG(double, negative_slope) = 1e-2;

  /// can optionally do the operation in-place. Default: False
  TORCH_ARG(bool, inplace) = false;
};

namespace functional {
/// Options for `torch::nn::functional::leaky_relu`.
///
/// See the documentation for `torch::nn::LeakyReLUOptions` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::leaky_relu(x,
/// F::LeakyReLUFuncOptions().negative_slope(0.42).inplace(true));
/// ```
using LeakyReLUFuncOptions = LeakyReLUOptions;
} // namespace functional
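
// Usage sketch (illustrative only): with `inplace(true)` the operation
// overwrites its input tensor and returns it, so the original values of `x`
// are lost; the default `inplace(false)` allocates a new output tensor
// instead. Assumes `x` is a defined tensor.
//
//   namespace F = torch::nn::functional;
//   torch::Tensor x = torch::randn({4});
//   torch::Tensor y = F::leaky_relu(
//       x, F::LeakyReLUFuncOptions().negative_slope(0.1).inplace(true));
//   // `y` and `x` now refer to the same storage.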

// ============================================================================

/// Options for the `Softmax` module.
///
/// Example:
/// ```
/// Softmax model(SoftmaxOptions(1));
/// ```
struct TORCH_API SoftmaxOptions {
  SoftmaxOptions(int64_t dim);

  /// Dimension along which Softmax will be computed.
  TORCH_ARG(int64_t, dim);
};

// ============================================================================

namespace functional {

/// Options for `torch::nn::functional::softmax`.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::softmax(input, F::SoftmaxFuncOptions(1));
/// ```
struct TORCH_API SoftmaxFuncOptions {
  SoftmaxFuncOptions(int64_t dim);

  /// Dimension along which Softmax will be computed.
  TORCH_ARG(int64_t, dim);

  /// the desired data type of the returned tensor.
  /// If specified, the input tensor is cast to `dtype` before the operation
  /// is performed. This is useful for preventing data type overflows.
  /// Default: None.
  TORCH_ARG(std::optional<torch::Dtype>, dtype) = std::nullopt;
};

} // namespace functional

// ============================================================================

/// Options for the `Softmin` module.
///
/// Example:
/// ```
/// Softmin model(SoftminOptions(1));
/// ```
struct TORCH_API SoftminOptions {
  SoftminOptions(int64_t dim);

  /// Dimension along which Softmin will be computed.
  TORCH_ARG(int64_t, dim);
};

// ============================================================================

namespace functional {

/// Options for `torch::nn::functional::softmin`.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::softmin(input, F::SoftminFuncOptions(1));
/// ```
struct TORCH_API SoftminFuncOptions {
  SoftminFuncOptions(int64_t dim);

  /// Dimension along which Softmin will be computed.
  TORCH_ARG(int64_t, dim);

  /// the desired data type of the returned tensor.
  /// If specified, the input tensor is cast to `dtype` before the operation
  /// is performed. This is useful for preventing data type overflows.
  /// Default: None.
  TORCH_ARG(std::optional<torch::Dtype>, dtype) = std::nullopt;
};

} // namespace functional

// ============================================================================

/// Options for the `LogSoftmax` module.
///
/// Example:
/// ```
/// LogSoftmax model(LogSoftmaxOptions(1));
/// ```
struct TORCH_API LogSoftmaxOptions {
  LogSoftmaxOptions(int64_t dim);

  /// Dimension along which LogSoftmax will be computed.
  TORCH_ARG(int64_t, dim);
};

// ============================================================================

namespace functional {

/// Options for `torch::nn::functional::log_softmax`.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::log_softmax(input, F::LogSoftmaxFuncOptions(1));
/// ```
struct TORCH_API LogSoftmaxFuncOptions {
  LogSoftmaxFuncOptions(int64_t dim);

  /// Dimension along which LogSoftmax will be computed.
  TORCH_ARG(int64_t, dim);

  /// the desired data type of the returned tensor.
  /// If specified, the input tensor is cast to `dtype` before the operation
  /// is performed. This is useful for preventing data type overflows.
  /// Default: None.
  TORCH_ARG(std::optional<torch::Dtype>, dtype) = std::nullopt;
};

} // namespace functional
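
// Usage sketch (illustrative only): the `dtype` option on the softmax-family
// functional options casts the input before the operation is performed, which
// can help avoid overflow or precision issues. Assumes `input` is a defined
// tensor.
//
//   namespace F = torch::nn::functional;
//   torch::Tensor input = torch::randn({2, 5});
//   torch::Tensor out = F::log_softmax(
//       input, F::LogSoftmaxFuncOptions(/*dim=*/1).dtype(torch::kFloat64));
//   // `out` has dtype kFloat64; exp(out) sums to 1 along dim 1.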

// ============================================================================

/// Options for the `PReLU` module.
///
/// Example:
/// ```
/// PReLU model(PReLUOptions().num_parameters(42));
/// ```
struct TORCH_API PReLUOptions {
  /// number of `a` to learn. Although it takes an int as input, only two
  /// values are legitimate: 1, or the number of channels of the input.
  /// Default: 1
  TORCH_ARG(int64_t, num_parameters) = 1;

  /// the initial value of `a`. Default: 0.25
  TORCH_ARG(double, init) = 0.25;
};

// ============================================================================

/// Options for the `ReLU` module.
///
/// Example:
/// ```
/// ReLU model(ReLUOptions().inplace(true));
/// ```
struct TORCH_API ReLUOptions {
  /* implicit */ ReLUOptions(bool inplace = false);

  /// can optionally do the operation in-place. Default: False
  TORCH_ARG(bool, inplace);
};

namespace functional {
/// Options for `torch::nn::functional::relu`.
///
/// See the documentation for `torch::nn::ReLUOptions` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::relu(x, F::ReLUFuncOptions().inplace(true));
/// ```
using ReLUFuncOptions = ReLUOptions;
} // namespace functional

// ============================================================================

/// Options for the `ReLU6` module.
///
/// Example:
/// ```
/// ReLU6 model(ReLU6Options().inplace(true));
/// ```
struct TORCH_API ReLU6Options {
  /* implicit */ ReLU6Options(bool inplace = false);

  /// can optionally do the operation in-place. Default: False
  TORCH_ARG(bool, inplace);
};

namespace functional {
/// Options for `torch::nn::functional::relu6`.
///
/// See the documentation for `torch::nn::ReLU6Options` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::relu6(x, F::ReLU6FuncOptions().inplace(true));
/// ```
using ReLU6FuncOptions = ReLU6Options;
} // namespace functional

// ============================================================================

/// Options for the `RReLU` module.
///
/// Example:
/// ```
/// RReLU model(RReLUOptions().lower(0.24).upper(0.42).inplace(true));
/// ```
struct TORCH_API RReLUOptions {
  /// lower bound of the uniform distribution. Default: 1/8
  TORCH_ARG(double, lower) = 1.0 / 8.0;

  /// upper bound of the uniform distribution. Default: 1/3
  TORCH_ARG(double, upper) = 1.0 / 3.0;

  /// can optionally do the operation in-place. Default: False
  TORCH_ARG(bool, inplace) = false;
};

// ============================================================================

namespace functional {

/// Options for `torch::nn::functional::rrelu`.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::rrelu(x, F::RReLUFuncOptions().lower(0.1).upper(0.4).inplace(true));
/// ```
struct TORCH_API RReLUFuncOptions {
  /// lower bound of the uniform distribution. Default: 1/8
  TORCH_ARG(double, lower) = 1.0 / 8.0;

  /// upper bound of the uniform distribution. Default: 1/3
  TORCH_ARG(double, upper) = 1.0 / 3.0;

  /// whether to sample the negative slope from the uniform distribution
  /// (training mode) or use its fixed average (evaluation mode).
  /// Default: False
  TORCH_ARG(bool, training) = false;

  /// can optionally do the operation in-place. Default: False
  TORCH_ARG(bool, inplace) = false;
};

} // namespace functional
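
// Usage sketch (illustrative only): when `training(true)` is set, the negative
// slope is sampled uniformly from [lower, upper] for each element; with the
// default `training(false)` the fixed average (lower + upper) / 2 is used.
// Assumes `x` is a defined tensor.
//
//   namespace F = torch::nn::functional;
//   torch::Tensor x = torch::randn({4});
//   torch::Tensor y_train =
//       F::rrelu(x, F::RReLUFuncOptions().lower(0.1).upper(0.3).training(true));
//   torch::Tensor y_eval =
//       F::rrelu(x, F::RReLUFuncOptions().lower(0.1).upper(0.3));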

// ============================================================================

/// Options for the `CELU` module.
///
/// Example:
/// ```
/// CELU model(CELUOptions().alpha(42.42).inplace(true));
/// ```
struct TORCH_API CELUOptions {
  /// The `alpha` value for the CELU formulation. Default: 1.0
  TORCH_ARG(double, alpha) = 1.0;

  /// can optionally do the operation in-place. Default: False
  TORCH_ARG(bool, inplace) = false;
};

namespace functional {
/// Options for `torch::nn::functional::celu`.
///
/// See the documentation for `torch::nn::CELUOptions` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::celu(x, F::CELUFuncOptions().alpha(0.42).inplace(true));
/// ```
using CELUFuncOptions = CELUOptions;
} // namespace functional

// ============================================================================

/// Options for the `Softplus` module.
///
/// Example:
/// ```
/// Softplus model(SoftplusOptions().beta(0.24).threshold(42.42));
/// ```
struct TORCH_API SoftplusOptions {
  /// the `beta` value for the Softplus formulation. Default: 1
  TORCH_ARG(double, beta) = 1.0;

  /// values above this revert to a linear function. Default: 20
  TORCH_ARG(double, threshold) = 20.0;
};

namespace functional {
/// Options for `torch::nn::functional::softplus`.
///
/// See the documentation for `torch::nn::SoftplusOptions` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::softplus(x, F::SoftplusFuncOptions().beta(0.5).threshold(3.0));
/// ```
using SoftplusFuncOptions = SoftplusOptions;
} // namespace functional

// ============================================================================

/// Options for the `Softshrink` module.
///
/// Example:
/// ```
/// Softshrink model(SoftshrinkOptions(42.42));
/// ```
struct TORCH_API SoftshrinkOptions {
  /* implicit */ SoftshrinkOptions(double lambda = 0.5);

  /// the `lambda` value for the Softshrink formulation. Default: 0.5
  TORCH_ARG(double, lambda);
};

namespace functional {
/// Options for `torch::nn::functional::softshrink`.
///
/// See the documentation for `torch::nn::SoftshrinkOptions` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::softshrink(x, F::SoftshrinkFuncOptions(0.42));
/// ```
using SoftshrinkFuncOptions = SoftshrinkOptions;
} // namespace functional
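
// Usage sketch (illustrative only): softshrink zeroes values inside
// [-lambda, lambda] and shifts the remaining values toward zero by lambda,
// whereas hardshrink (defined earlier in this header) zeroes the same band but
// leaves the other values unchanged.
//
//   namespace F = torch::nn::functional;
//   torch::Tensor x = torch::tensor({-1.0, -0.2, 0.0, 0.3, 1.0});
//   torch::Tensor soft = F::softshrink(x, F::SoftshrinkFuncOptions(0.5));
//   // soft: {-0.5, 0.0, 0.0, 0.0, 0.5}
//   torch::Tensor hard = F::hardshrink(x, F::HardshrinkFuncOptions(0.5));
//   // hard: {-1.0, 0.0, 0.0, 0.0, 1.0}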

// ============================================================================

/// Options for the `Threshold` module.
///
/// Example:
/// ```
/// Threshold model(ThresholdOptions(42.42, 24.24).inplace(true));
/// ```
struct TORCH_API ThresholdOptions {
  ThresholdOptions(double threshold, double value)
      : threshold_(threshold), value_(value) {}

  /// The value to threshold at
  TORCH_ARG(double, threshold);

  /// The value to replace with
  TORCH_ARG(double, value);

  /// can optionally do the operation in-place. Default: False
  TORCH_ARG(bool, inplace) = false;
};

namespace functional {
/// Options for `torch::nn::functional::threshold`.
///
/// See the documentation for `torch::nn::ThresholdOptions` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::threshold(x, F::ThresholdFuncOptions(0.5, 0.5).inplace(true));
/// ```
using ThresholdFuncOptions = ThresholdOptions;
} // namespace functional

// ============================================================================

namespace functional {

/// Options for `torch::nn::functional::gumbel_softmax`.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::gumbel_softmax(logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(-1));
/// ```
struct TORCH_API GumbelSoftmaxFuncOptions {
  /// non-negative scalar temperature
  TORCH_ARG(double, tau) = 1.0;

  /// returned samples will be discretized as one-hot vectors,
  /// but will be differentiated as if it is the soft sample in autograd.
  /// Default: False
  TORCH_ARG(bool, hard) = false;

  /// dimension along which softmax will be computed. Default: -1
  TORCH_ARG(int, dim) = -1;
};

} // namespace functional
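
// Usage sketch (illustrative only): with `hard(true)` the returned samples are
// one-hot along `dim` in the forward pass while gradients flow through the
// underlying soft sample (straight-through estimator); a smaller `tau`
// sharpens the soft distribution. Assumes `logits` is a defined tensor.
//
//   namespace F = torch::nn::functional;
//   torch::Tensor logits = torch::randn({8, 10}, torch::requires_grad());
//   torch::Tensor one_hot = F::gumbel_softmax(
//       logits, F::GumbelSoftmaxFuncOptions().tau(0.5).hard(true).dim(-1));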

// ============================================================================

/// Options for the `MultiheadAttention` module.
///
/// Example:
/// ```
/// MultiheadAttention model(MultiheadAttentionOptions(20, 10).bias(false));
/// ```
struct TORCH_API MultiheadAttentionOptions {
  MultiheadAttentionOptions(int64_t embed_dim, int64_t num_heads);

  /// total dimension of the model.
  TORCH_ARG(int64_t, embed_dim);

  /// parallel attention heads.
  TORCH_ARG(int64_t, num_heads);

  /// a Dropout layer on attn_output_weights. Default: 0.0.
  TORCH_ARG(double, dropout) = 0.0;

  /// add bias as module parameter. Default: true.
  TORCH_ARG(bool, bias) = true;

  /// add bias to the key and value sequences at dim=0.
  TORCH_ARG(bool, add_bias_kv) = false;

  /// add a new batch of zeros to the key and value sequences at dim=1.
  TORCH_ARG(bool, add_zero_attn) = false;

  /// total number of features in key. Default: `embed_dim`.
  TORCH_ARG(int64_t, kdim);

  /// total number of features in value. Default: `embed_dim`.
  TORCH_ARG(int64_t, vdim);
};

// ============================================================================

namespace functional {

/// Options for `torch::nn::functional::multi_head_attention_forward`
struct TORCH_API MultiheadAttentionForwardFuncOptions {
  MultiheadAttentionForwardFuncOptions(
      int64_t embed_dim_to_check,
      int64_t num_heads,
      Tensor in_proj_weight,
      Tensor in_proj_bias,
      Tensor bias_k,
      Tensor bias_v,
      bool add_zero_attn,
      double dropout_p,
      Tensor out_proj_weight,
      Tensor out_proj_bias);

  /// expected embedding dimension of the query, used for validation.
  TORCH_ARG(int64_t, embed_dim_to_check);

  /// number of parallel attention heads.
  TORCH_ARG(int64_t, num_heads);

  /// combined input projection weight for query, key and value.
  TORCH_ARG(Tensor, in_proj_weight);

  /// combined input projection bias for query, key and value.
  TORCH_ARG(Tensor, in_proj_bias);

  /// bias added to the key sequence at dim=0.
  TORCH_ARG(Tensor, bias_k);

  /// bias added to the value sequence at dim=0.
  TORCH_ARG(Tensor, bias_v);

  /// add a new batch of zeros to the key and value sequences at dim=1.
  TORCH_ARG(bool, add_zero_attn);

  /// dropout probability applied to the attention weights.
  TORCH_ARG(double, dropout_p);

  /// output projection weight.
  TORCH_ARG(Tensor, out_proj_weight);

  /// output projection bias.
  TORCH_ARG(Tensor, out_proj_bias);

  /// apply dropout if true. Default: true
  TORCH_ARG(bool, training) = true;

  /// if provided, padding positions in the key indicated by the mask are
  /// ignored by the attention. Default: empty tensor
  TORCH_ARG(Tensor, key_padding_mask) = {};

  /// whether to also return the attention weights. Default: true
  TORCH_ARG(bool, need_weights) = true;

  /// mask that prevents attention to certain positions. Default: empty tensor
  TORCH_ARG(Tensor, attn_mask) = {};

  /// use the separate `q_proj_weight`, `k_proj_weight` and `v_proj_weight`
  /// instead of `in_proj_weight`. Default: false
  TORCH_ARG(bool, use_separate_proj_weight) = false;

  TORCH_ARG(Tensor, q_proj_weight) = {};

  TORCH_ARG(Tensor, k_proj_weight) = {};

  TORCH_ARG(Tensor, v_proj_weight) = {};

  /// static key used by the attention in place of the projected key.
  TORCH_ARG(Tensor, static_k) = {};

  /// static value used by the attention in place of the projected value.
  TORCH_ARG(Tensor, static_v) = {};

  /// average the returned attention weights over heads. Default: true
  TORCH_ARG(bool, average_attn_weights) = true;
};

} // namespace functional

} // namespace nn
} // namespace torch
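
// Usage sketch (illustrative only): typical construction of the
// `MultiheadAttention` module from its options and a forward call. The
// functional `multi_head_attention_forward` consumes the same quantities
// through `MultiheadAttentionForwardFuncOptions`, with the projection weights
// and biases supplied explicitly. Shapes below follow the
// (seq_len, batch, embed_dim) convention.
//
//   torch::nn::MultiheadAttention mha(
//       torch::nn::MultiheadAttentionOptions(/*embed_dim=*/16, /*num_heads=*/4)
//           .dropout(0.1));
//   torch::Tensor query = torch::randn({5, 2, 16});
//   torch::Tensor key = torch::randn({7, 2, 16});
//   torch::Tensor value = torch::randn({7, 2, 16});
//   auto [attn_output, attn_weights] = mha(query, key, value);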