#pragma once

#include <torch/arg.h>
#include <torch/csrc/Export.h>
#include <torch/enum.h>
#include <torch/types.h>

namespace torch {
namespace nn {

/// Options for the `ELU` module.
///
/// Example:
/// ```
/// ELU model(ELUOptions().alpha(42.42).inplace(true));
/// ```
struct TORCH_API ELUOptions {
  /// The `alpha` value for the ELU formulation. Default: 1.0
  TORCH_ARG(double, alpha) = 1.0;

  /// can optionally do the operation in-place. Default: False
  TORCH_ARG(bool, inplace) = false;
};

namespace functional {
/// Options for `torch::nn::functional::elu`.
///
/// See the documentation for `torch::nn::ELUOptions` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::elu(x, F::ELUFuncOptions().alpha(0.42).inplace(true));
/// ```
using ELUFuncOptions = ELUOptions;
} // namespace functional

// ============================================================================

/// Options for the `SELU` module.
///
/// Example:
/// ```
/// SELU model(SELUOptions().inplace(true));
/// ```
struct TORCH_API SELUOptions {
  /* implicit */ SELUOptions(bool inplace = false);

  /// can optionally do the operation in-place. Default: False
  TORCH_ARG(bool, inplace);
};

namespace functional {
/// Options for `torch::nn::functional::selu`.
///
/// See the documentation for `torch::nn::SELUOptions` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::selu(input, F::SELUFuncOptions(false));
/// ```
using SELUFuncOptions = SELUOptions;
} // namespace functional

// ============================================================================

/// Options for the `GLU` module.
///
/// Example:
/// ```
/// GLU model(GLUOptions(1));
/// ```
struct TORCH_API GLUOptions {
  /* implicit */ GLUOptions(int64_t dim = -1);

  /// the dimension on which to split the input. Default: -1
  TORCH_ARG(int64_t, dim);
};

namespace functional {
/// Options for `torch::nn::functional::glu`.
///
/// See the documentation for `torch::nn::GLUOptions` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::glu(input, F::GLUFuncOptions(1));
/// ```
using GLUFuncOptions = GLUOptions;
} // namespace functional

// ============================================================================

/// Options for the `GELU` module.
///
/// Example:
/// ```
/// GELU model(GELUOptions().approximate("none"));
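/// // "tanh" is the only other accepted value for `approximate`; a second,
/// // illustrative instance is used here so it does not clash with `model`:
/// GELU model2(GELUOptions().approximate("tanh"));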
/// ```
struct TORCH_API GELUOptions {
  /// Specifies the gelu approximation algorithm to use: `"none"` or `"tanh"`.
  /// Default: "none"
  TORCH_ARG(std::string, approximate) = "none";
};

namespace functional {
/// Options for `torch::nn::functional::gelu`.
///
/// See the documentation for `torch::nn::GELUOptions` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::gelu(input, F::GELUFuncOptions().approximate("none"));
/// ```
using GELUFuncOptions = GELUOptions;
} // namespace functional

// ============================================================================

/// Options for the `Hardshrink` module.
///
/// Example:
/// ```
/// Hardshrink model(HardshrinkOptions().lambda(42.42));
/// ```
struct TORCH_API HardshrinkOptions {
  /* implicit */ HardshrinkOptions(double lambda = 0.5);

  /// the `lambda` value for the Hardshrink formulation. Default: 0.5
  TORCH_ARG(double, lambda);
};

namespace functional {
/// Options for `torch::nn::functional::hardshrink`.
///
/// See the documentation for `torch::nn::HardshrinkOptions` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::hardshrink(x, F::HardshrinkFuncOptions().lambda(0.42));
/// ```
using HardshrinkFuncOptions = HardshrinkOptions;
} // namespace functional

// ============================================================================

/// Options for the `Hardtanh` module.
///
/// Example:
/// ```
/// Hardtanh
/// model(HardtanhOptions().min_val(-42.42).max_val(0.42).inplace(true));
/// ```
struct TORCH_API HardtanhOptions {
  /// minimum value of the linear region range. Default: -1
  TORCH_ARG(double, min_val) = -1.0;

  /// maximum value of the linear region range. Default: 1
  TORCH_ARG(double, max_val) = 1.0;

  /// can optionally do the operation in-place. Default: False
  TORCH_ARG(bool, inplace) = false;
};

namespace functional {
/// Options for `torch::nn::functional::hardtanh`.
///
/// See the documentation for `torch::nn::HardtanhOptions` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::hardtanh(x,
/// F::HardtanhFuncOptions().min_val(-1.0).max_val(1.0).inplace(true));
/// ```
using HardtanhFuncOptions = HardtanhOptions;
} // namespace functional

// ============================================================================

/// Options for the `LeakyReLU` module.
///
/// Example:
/// ```
/// LeakyReLU model(LeakyReLUOptions().negative_slope(0.42).inplace(true));
/// ```
struct TORCH_API LeakyReLUOptions {
  /// Controls the angle of the negative slope. Default: 1e-2
  TORCH_ARG(double, negative_slope) = 1e-2;

  /// can optionally do the operation in-place. Default: False
  TORCH_ARG(bool, inplace) = false;
};

namespace functional {
/// Options for `torch::nn::functional::leaky_relu`.
///
/// See the documentation for `torch::nn::LeakyReLUOptions` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::leaky_relu(x,
/// F::LeakyReLUFuncOptions().negative_slope(0.42).inplace(true));
/// ```
using LeakyReLUFuncOptions = LeakyReLUOptions;
} // namespace functional

// ============================================================================

/// Options for the `Softmax` module.
///
/// Example:
/// ```
/// Softmax model(SoftmaxOptions(1));
/// ```
struct TORCH_API SoftmaxOptions {
  SoftmaxOptions(int64_t dim);

  /// Dimension along which Softmax will be computed.
  TORCH_ARG(int64_t, dim);
};

// ============================================================================

namespace functional {

/// Options for `torch::nn::functional::softmax`.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::softmax(input, F::SoftmaxFuncOptions(1));
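/// // A variant that also sets the optional `dtype` argument; torch::kFloat64
/// // is chosen here purely for illustration:
/// F::softmax(input, F::SoftmaxFuncOptions(1).dtype(torch::kFloat64));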
/// ```
struct TORCH_API SoftmaxFuncOptions {
  SoftmaxFuncOptions(int64_t dim);

  /// Dimension along which Softmax will be computed.
  TORCH_ARG(int64_t, dim);

  /// the desired data type of the returned tensor.
  /// If specified, the input tensor is cast to `dtype` before the operation
  /// is performed. This is useful for preventing data type overflows.
  /// Default: None.
  TORCH_ARG(std::optional<torch::Dtype>, dtype) = std::nullopt;
};

} // namespace functional

// ============================================================================

/// Options for the `Softmin` module.
///
/// Example:
/// ```
/// Softmin model(SoftminOptions(1));
/// ```
struct TORCH_API SoftminOptions {
  SoftminOptions(int64_t dim);

  /// Dimension along which Softmin will be computed.
  TORCH_ARG(int64_t, dim);
};

// ============================================================================

namespace functional {

/// Options for `torch::nn::functional::softmin`.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::softmin(input, F::SoftminFuncOptions(1));
/// ```
struct TORCH_API SoftminFuncOptions {
  SoftminFuncOptions(int64_t dim);

  /// Dimension along which Softmin will be computed.
  TORCH_ARG(int64_t, dim);

  /// the desired data type of the returned tensor.
  /// If specified, the input tensor is cast to `dtype` before the operation
  /// is performed. This is useful for preventing data type overflows.
  /// Default: None.
  TORCH_ARG(std::optional<torch::Dtype>, dtype) = std::nullopt;
};

} // namespace functional

// ============================================================================

/// Options for the `LogSoftmax` module.
///
/// Example:
/// ```
/// LogSoftmax model(LogSoftmaxOptions(1));
/// ```
struct TORCH_API LogSoftmaxOptions {
  LogSoftmaxOptions(int64_t dim);

  /// Dimension along which LogSoftmax will be computed.
  TORCH_ARG(int64_t, dim);
};

// ============================================================================

namespace functional {

/// Options for `torch::nn::functional::log_softmax`.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::log_softmax(input, F::LogSoftmaxFuncOptions(1));
/// ```
struct TORCH_API LogSoftmaxFuncOptions {
  LogSoftmaxFuncOptions(int64_t dim);

  /// Dimension along which LogSoftmax will be computed.
  TORCH_ARG(int64_t, dim);

  /// the desired data type of the returned tensor.
  /// If specified, the input tensor is cast to `dtype` before the operation
  /// is performed. This is useful for preventing data type overflows.
  /// Default: None.
  TORCH_ARG(std::optional<torch::Dtype>, dtype) = std::nullopt;
};

} // namespace functional

// ============================================================================

/// Options for the `PReLU` module.
///
/// Example:
/// ```
/// PReLU model(PReLUOptions().num_parameters(42));
/// ```
struct TORCH_API PReLUOptions {
  /// number of `a` to learn. Although it takes an int as input, only two
  /// values are legitimate: 1, or the number of channels of the input.
  /// Default: 1
  TORCH_ARG(int64_t, num_parameters) = 1;

  /// the initial value of `a`. Default: 0.25
  TORCH_ARG(double, init) = 0.25;
};

// ============================================================================

/// Options for the `ReLU` module.
///
/// Example:
/// ```
/// ReLU model(ReLUOptions().inplace(true));
/// ```
struct TORCH_API ReLUOptions {
  /* implicit */ ReLUOptions(bool inplace = false);

  /// can optionally do the operation in-place. Default: False
  TORCH_ARG(bool, inplace);
};

namespace functional {
/// Options for `torch::nn::functional::relu`.
///
/// See the documentation for `torch::nn::ReLUOptions` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::relu(x, F::ReLUFuncOptions().inplace(true));
/// ```
using ReLUFuncOptions = ReLUOptions;
} // namespace functional

// ============================================================================

/// Options for the `ReLU6` module.
///
/// Example:
/// ```
/// ReLU6 model(ReLU6Options().inplace(true));
/// ```
struct TORCH_API ReLU6Options {
  /* implicit */ ReLU6Options(bool inplace = false);

  /// can optionally do the operation in-place. Default: False
  TORCH_ARG(bool, inplace);
};

namespace functional {
/// Options for `torch::nn::functional::relu6`.
///
/// See the documentation for `torch::nn::ReLU6Options` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::relu6(x, F::ReLU6FuncOptions().inplace(true));
/// ```
using ReLU6FuncOptions = ReLU6Options;
} // namespace functional

// ============================================================================

/// Options for the `RReLU` module.
///
/// Example:
/// ```
/// RReLU model(RReLUOptions().lower(0.24).upper(0.42).inplace(true));
/// ```
struct TORCH_API RReLUOptions {
  /// lower bound of the uniform distribution. Default: 1/8
  TORCH_ARG(double, lower) = 1.0 / 8.0;

  /// upper bound of the uniform distribution. Default: 1/3
  TORCH_ARG(double, upper) = 1.0 / 3.0;

  /// can optionally do the operation in-place. Default: False
  TORCH_ARG(bool, inplace) = false;
};

// ============================================================================

namespace functional {

/// Options for `torch::nn::functional::rrelu`.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::rrelu(x, F::RReLUFuncOptions().lower(0.1).upper(0.4).inplace(true));
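/// // A variant exercising `training`: when true, the negative slope is
/// // randomly sampled from [lower, upper] rather than fixed at their average:
/// F::rrelu(x, F::RReLUFuncOptions().lower(0.1).upper(0.4).training(true));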
/// ```
struct TORCH_API RReLUFuncOptions {
  /// lower bound of the uniform distribution. Default: 1/8
  TORCH_ARG(double, lower) = 1.0 / 8.0;

  /// upper bound of the uniform distribution. Default: 1/3
  TORCH_ARG(double, upper) = 1.0 / 3.0;

  /// whether to apply training-mode behaviour, i.e. randomly sample the
  /// negative slope from `[lower, upper]` rather than use their average.
  /// Default: False
  TORCH_ARG(bool, training) = false;

  /// can optionally do the operation in-place. Default: False
  TORCH_ARG(bool, inplace) = false;
};

} // namespace functional

// ============================================================================

/// Options for the `CELU` module.
///
/// Example:
/// ```
/// CELU model(CELUOptions().alpha(42.42).inplace(true));
/// ```
struct TORCH_API CELUOptions {
  /// The `alpha` value for the CELU formulation. Default: 1.0
  TORCH_ARG(double, alpha) = 1.0;

  /// can optionally do the operation in-place. Default: False
  TORCH_ARG(bool, inplace) = false;
};

namespace functional {
/// Options for `torch::nn::functional::celu`.
///
/// See the documentation for `torch::nn::CELUOptions` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::celu(x, F::CELUFuncOptions().alpha(0.42).inplace(true));
/// ```
using CELUFuncOptions = CELUOptions;
} // namespace functional

// ============================================================================

/// Options for the `Softplus` module.
///
/// Example:
/// ```
/// Softplus model(SoftplusOptions().beta(0.24).threshold(42.42));
/// ```
struct TORCH_API SoftplusOptions {
  /// the `beta` value for the Softplus formulation. Default: 1
  TORCH_ARG(double, beta) = 1.0;

  /// values above this revert to a linear function. Default: 20
  TORCH_ARG(double, threshold) = 20.0;
};

namespace functional {
/// Options for `torch::nn::functional::softplus`.
///
/// See the documentation for `torch::nn::SoftplusOptions` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::softplus(x, F::SoftplusFuncOptions().beta(0.5).threshold(3.0));
/// ```
using SoftplusFuncOptions = SoftplusOptions;
} // namespace functional

// ============================================================================

/// Options for the `Softshrink` module.
///
/// Example:
/// ```
/// Softshrink model(SoftshrinkOptions(42.42));
/// ```
struct TORCH_API SoftshrinkOptions {
  /* implicit */ SoftshrinkOptions(double lambda = 0.5);

  /// the `lambda` value for the Softshrink formulation. Default: 0.5
  TORCH_ARG(double, lambda);
};

namespace functional {
/// Options for `torch::nn::functional::softshrink`.
///
/// See the documentation for `torch::nn::SoftshrinkOptions` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::softshrink(x, F::SoftshrinkFuncOptions(0.42));
/// ```
using SoftshrinkFuncOptions = SoftshrinkOptions;
} // namespace functional

// ============================================================================

/// Options for the `Threshold` module.
///
/// Example:
/// ```
/// Threshold model(ThresholdOptions(42.42, 24.24).inplace(true));
/// ```
struct TORCH_API ThresholdOptions {
  ThresholdOptions(double threshold, double value)
      : threshold_(threshold), value_(value) {}

  /// The value to threshold at
  TORCH_ARG(double, threshold);

  /// The value to replace with
  TORCH_ARG(double, value);

  /// can optionally do the operation in-place. Default: False
  TORCH_ARG(bool, inplace) = false;
};

namespace functional {
/// Options for `torch::nn::functional::threshold`.
///
/// See the documentation for `torch::nn::ThresholdOptions` class to learn what
/// arguments are supported.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::threshold(x, F::ThresholdFuncOptions(0.5, 0.5).inplace(true));
/// ```
using ThresholdFuncOptions = ThresholdOptions;
} // namespace functional

// ============================================================================

namespace functional {

/// Options for `torch::nn::functional::gumbel_softmax`.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::gumbel_softmax(logits, F::GumbelSoftmaxFuncOptions().hard(true).dim(-1));
/// ```
struct TORCH_API GumbelSoftmaxFuncOptions {
  /// non-negative scalar temperature
  TORCH_ARG(double, tau) = 1.0;

  /// returned samples will be discretized as one-hot vectors,
  /// but will be differentiated as if it is the soft sample in autograd.
  /// Default: False
  TORCH_ARG(bool, hard) = false;

  /// dimension along which softmax will be computed. Default: -1
  TORCH_ARG(int, dim) = -1;
};

} // namespace functional

// ============================================================================

/// Options for the `MultiheadAttention` module.
///
/// Example:
/// ```
/// MultiheadAttention model(MultiheadAttentionOptions(20, 10).bias(false));
/// ```
struct TORCH_API MultiheadAttentionOptions {
  MultiheadAttentionOptions(int64_t embed_dim, int64_t num_heads);

  /// total dimension of the model.
  TORCH_ARG(int64_t, embed_dim);

  /// parallel attention heads.
  TORCH_ARG(int64_t, num_heads);

  /// a Dropout layer on attn_output_weights. Default: 0.0.
  TORCH_ARG(double, dropout) = 0.0;

  /// add bias as module parameter. Default: true.
  TORCH_ARG(bool, bias) = true;

  /// add bias to the key and value sequences at dim=0.
  TORCH_ARG(bool, add_bias_kv) = false;

  /// add a new batch of zeros to the key and value sequences at dim=1.
  TORCH_ARG(bool, add_zero_attn) = false;

  /// total number of features in key. Default: same as `embed_dim`.
  TORCH_ARG(int64_t, kdim);

  /// total number of features in value. Default: same as `embed_dim`.
  TORCH_ARG(int64_t, vdim);
};

// ============================================================================

namespace functional {

/// Options for `torch::nn::functional::multi_head_attention_forward`
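///
/// Example (a usage sketch only; the `(query, key, value, options)` call
/// signature is assumed from the functional API, and the input tensors and
/// projection weights/biases are assumed to be defined elsewhere):
/// ```
/// namespace F = torch::nn::functional;
/// auto [attn_output, attn_weights] = F::multi_head_attention_forward(
///     query, key, value,
///     F::MultiheadAttentionForwardFuncOptions(
///         /*embed_dim_to_check=*/32, /*num_heads=*/4,
///         in_proj_weight, in_proj_bias, bias_k, bias_v,
///         /*add_zero_attn=*/false, /*dropout_p=*/0.1,
///         out_proj_weight, out_proj_bias));
/// ```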
struct TORCH_API MultiheadAttentionForwardFuncOptions {
  MultiheadAttentionForwardFuncOptions(
      int64_t embed_dim_to_check,
      int64_t num_heads,
      Tensor in_proj_weight,
      Tensor in_proj_bias,
      Tensor bias_k,
      Tensor bias_v,
      bool add_zero_attn,
      double dropout_p,
      Tensor out_proj_weight,
      Tensor out_proj_bias);

  TORCH_ARG(int64_t, embed_dim_to_check);

  TORCH_ARG(int64_t, num_heads);

  TORCH_ARG(Tensor, in_proj_weight);

  TORCH_ARG(Tensor, in_proj_bias);

  TORCH_ARG(Tensor, bias_k);

  TORCH_ARG(Tensor, bias_v);

  TORCH_ARG(bool, add_zero_attn);

  TORCH_ARG(double, dropout_p);

  TORCH_ARG(Tensor, out_proj_weight);

  TORCH_ARG(Tensor, out_proj_bias);

  TORCH_ARG(bool, training) = true;

  TORCH_ARG(Tensor, key_padding_mask) = {};

  TORCH_ARG(bool, need_weights) = true;

  TORCH_ARG(Tensor, attn_mask) = {};

  TORCH_ARG(bool, use_separate_proj_weight) = false;

  TORCH_ARG(Tensor, q_proj_weight) = {};

  TORCH_ARG(Tensor, k_proj_weight) = {};

  TORCH_ARG(Tensor, v_proj_weight) = {};

  TORCH_ARG(Tensor, static_k) = {};

  TORCH_ARG(Tensor, static_v) = {};

  TORCH_ARG(bool, average_attn_weights) = true;
};

} // namespace functional

} // namespace nn
} // namespace torch