xref: /aosp_15_r20/external/pytorch/torch/nn/functional.pyi.in (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# ${generated_comment}
2# mypy: allow-untyped-defs
3
4from typing import (
5    Any,
6    Callable,
7    Dict,
8    List,
9    Literal,
10    Optional,
11    overload,
12    Sequence,
13    Tuple,
14    Union,
15)
16
17from torch import Tensor
18from torch.types import _dtype, _int, _size
19
20from .common_types import (
21    _ratio_any_t,
22    _size_1_t,
23    _size_2_opt_t,
24    _size_2_t,
25    _size_3_opt_t,
26    _size_3_t,
27    _size_any_t,
28)
29
30# 'TypedDict' is a new accepted type that represents a dictionary with a fixed set of allowed keys.
# It is standards-track but not in `typing` yet. We leave this here to be uncommented once the feature
# is widespread.
33
34# from mypy_extensions import TypedDict
35
36# GRID_SAMPLE_INTERPOLATION_MODES = TypedDict('GRID_SAMPLE_INTERPOLATION_MODES', {'bilinear': int, 'nearest': int})
37# GRID_SAMPLE_PADDING_MODES = TypedDict('GRID_SAMPLE_PADDING_MODES', {'zeros': int, 'border': int, 'reflection': int})
38
# Loose stand-ins for the commented-out TypedDict definitions above: any `str -> int` mapping is accepted,
# so the allowed key sets ('bilinear'/'nearest', 'zeros'/'border'/'reflection') are not enforced by the type.
GRID_SAMPLE_INTERPOLATION_MODES = Dict[str, int]
GRID_SAMPLE_PADDING_MODES = Dict[str, int]
41
42# These stubs were generated by running stubgen (`stubgen --parse-only functional.py`), followed by manual cleaning.
43#
44# The 'BroadcastingList{1,2,3}' types were replaced by `_size` or _output_ratio, as appropriate.
45# This was necessary since the JIT uses BroadcastingList* types but static checking with mypy etc requires a `Sequence`
46# type. There is no way to express the expected lengths of these lists in the current Python typing system.
47#
# Functions created via `_add_docstr` in `functional.py` were merely typed as `Any` by `stubgen`, so those were
49# deleted from the stub and replaced by generated declarations. See `gen_pyi` for the implementation of the code
50# generation logic for those functions. In the future, it might be worth looking into using the mypy plugin system
51# to encode the type semantics of `_add_docstr`, should that system ever become widespread.
# Stub for `fractional_max_pool2d_with_indices`; always typed as returning a pair of tensors.
# `output_size` and `output_ratio` are both optional and may be None; defaults are elided with `...`.
# NOTE(review): the pair is presumably (pooled output, indices) — confirm against functional.py.
def fractional_max_pool2d_with_indices(
    input: Tensor,
    kernel_size: _size,
    output_size: Optional[_size] = ...,
    output_ratio: Optional[_ratio_any_t] = ...,
    return_indices: bool = ...,
    _random_samples: Optional[Tensor] = ...,
) -> Tuple[Tensor, Tensor]: ...
# Stub for `fractional_max_pool3d_with_indices`; same parameter list as the 2d variant and the same
# two-tensor return type. `output_size` and `output_ratio` may each be None.
def fractional_max_pool3d_with_indices(
    input: Tensor,
    kernel_size: _size,
    output_size: Optional[_size] = ...,
    output_ratio: Optional[_ratio_any_t] = ...,
    return_indices: bool = ...,
    _random_samples: Optional[Tensor] = ...,
) -> Tuple[Tensor, Tensor]: ...
# Stub for `max_pool1d_with_indices`; typed as returning a pair of tensors.
# Only `input` and `kernel_size` are required; `stride` additionally accepts None.
def max_pool1d_with_indices(
    input: Tensor,
    kernel_size: _size,
    stride: Optional[_size] = ...,
    padding: _size = ...,
    dilation: _size = ...,
    ceil_mode: bool = ...,
    return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
# Stub for `max_pool2d_with_indices`; signature mirrors the 1d variant and returns a pair of tensors.
# Only `input` and `kernel_size` are required; `stride` additionally accepts None.
def max_pool2d_with_indices(
    input: Tensor,
    kernel_size: _size,
    stride: Optional[_size] = ...,
    padding: _size = ...,
    dilation: _size = ...,
    ceil_mode: bool = ...,
    return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
# Stub for `max_pool3d_with_indices`; signature mirrors the 1d/2d variants and returns a pair of tensors.
# Only `input` and `kernel_size` are required; `stride` additionally accepts None.
def max_pool3d_with_indices(
    input: Tensor,
    kernel_size: _size,
    stride: Optional[_size] = ...,
    padding: _size = ...,
    dilation: _size = ...,
    ceil_mode: bool = ...,
    return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
# Stub for `max_unpool1d`; takes an input tensor plus a separate `indices` tensor and returns one tensor.
# `stride` and `output_size` may be None.
def max_unpool1d(
    input: Tensor,
    indices: Tensor,
    kernel_size: _size,
    stride: Optional[_size] = ...,
    padding: _size = ...,
    output_size: Optional[_size] = ...,
) -> Tensor: ...
# Stub for `max_unpool2d`; same parameter list and return type as `max_unpool1d`.
# `stride` and `output_size` may be None.
def max_unpool2d(
    input: Tensor,
    indices: Tensor,
    kernel_size: _size,
    stride: Optional[_size] = ...,
    padding: _size = ...,
    output_size: Optional[_size] = ...,
) -> Tensor: ...
# Stub for `max_unpool3d`; same parameter list and return type as `max_unpool1d`/`max_unpool2d`.
# `stride` and `output_size` may be None.
def max_unpool3d(
    input: Tensor,
    indices: Tensor,
    kernel_size: _size,
    stride: Optional[_size] = ...,
    padding: _size = ...,
    output_size: Optional[_size] = ...,
) -> Tensor: ...
# Stub for `lp_pool1d`; `norm_type` and `kernel_size` are required, the rest have elided defaults.
# `stride` accepts a size sequence, a single int, or None. It is spelled as one Optional around the
# union rather than the equivalent but redundant `Union[Optional[_size], Optional[int]]`.
def lp_pool1d(
    input: Tensor,
    norm_type: float,
    kernel_size: _size_1_t,
    stride: Optional[Union[_size, int]] = ...,
    ceil_mode: bool = ...,
) -> Tensor: ...
# Stub for `lp_pool2d`; `norm_type` and `kernel_size` are required, the rest have elided defaults.
# `stride` accepts a size sequence, a single int, or None. It is spelled as one Optional around the
# union rather than the equivalent but redundant `Union[Optional[_size], Optional[int]]`.
def lp_pool2d(
    input: Tensor,
    norm_type: float,
    kernel_size: _size_2_t,
    stride: Optional[Union[_size, int]] = ...,
    ceil_mode: bool = ...,
) -> Tensor: ...
# Stub for `lp_pool3d`; `norm_type` and `kernel_size` are required, the rest have elided defaults.
# `stride` accepts a size sequence, a single int, or None. It is spelled as one Optional around the
# union rather than the equivalent but redundant `Union[Optional[_size], Optional[int]]`.
def lp_pool3d(
    input: Tensor,
    norm_type: float,
    kernel_size: _size_3_t,
    stride: Optional[Union[_size, int]] = ...,
    ceil_mode: bool = ...,
) -> Tensor: ...
# Stub for `adaptive_max_pool1d_with_indices`; `output_size` is required and the result is typed as a
# pair of tensors regardless of `return_indices`.
def adaptive_max_pool1d_with_indices(
    input: Tensor,
    output_size: _size,
    return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
# Stub for `adaptive_max_pool2d_with_indices`; `output_size` uses the `_size_2_opt_t` alias from
# `.common_types` (unlike the 1d variant, which uses `_size`). Returns a pair of tensors.
def adaptive_max_pool2d_with_indices(
    input: Tensor,
    output_size: _size_2_opt_t,
    return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
# Stub for `adaptive_max_pool3d_with_indices`; `output_size` uses the `_size_3_opt_t` alias from
# `.common_types`. Returns a pair of tensors.
def adaptive_max_pool3d_with_indices(
    input: Tensor,
    output_size: _size_3_opt_t,
    return_indices: bool = ...,
) -> Tuple[Tensor, Tensor]: ...
# Stub for `adaptive_avg_pool2d`; both arguments are required and a single tensor is returned.
def adaptive_avg_pool2d(input: Tensor, output_size: _size_2_opt_t) -> Tensor: ...
# Stub for `adaptive_avg_pool3d`; both arguments are required and a single tensor is returned.
def adaptive_avg_pool3d(input: Tensor, output_size: _size_3_opt_t) -> Tensor: ...
# Stub for `dropout`; only `input` is required. `p`, `training` and `inplace` have defaults that are
# elided (`...`) here and defined by the implementation.
def dropout(
    input: Tensor,
    p: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
# Stub for `alpha_dropout`; same parameter list and return type as `dropout`.
def alpha_dropout(
    input: Tensor,
    p: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
# Stub for `dropout1d`; same parameter list and return type as `dropout`.
def dropout1d(
    input: Tensor,
    p: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
# Stub for `dropout2d`; same parameter list and return type as `dropout`.
def dropout2d(
    input: Tensor,
    p: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
# Stub for `dropout3d`; same parameter list and return type as `dropout`.
def dropout3d(
    input: Tensor,
    p: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
# Stub for `feature_alpha_dropout`; same parameter list and return type as `dropout`.
def feature_alpha_dropout(
    input: Tensor,
    p: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
# Stub for `threshold`; unlike most activations here, `threshold` and `value` have no defaults and
# must be supplied by the caller. Only `inplace` is optional.
def threshold(
    input: Tensor,
    threshold: float,
    value: float,
    inplace: bool = ...,
) -> Tensor: ...
# Stub for `relu`; takes a tensor and an optional `inplace` flag, returns a tensor.
def relu(input: Tensor, inplace: bool = ...) -> Tensor: ...
# Stub for `glu`; `dim` is an optional int (default elided). There is no `inplace` parameter.
def glu(input: Tensor, dim: int = ...) -> Tensor: ...
# Stub for `hardtanh`; `min_val` and `max_val` are optional floats whose defaults are elided here.
def hardtanh(
    input: Tensor,
    min_val: float = ...,
    max_val: float = ...,
    inplace: bool = ...,
) -> Tensor: ...
# Stub for `relu6`; same shape of signature as `relu`.
def relu6(input: Tensor, inplace: bool = ...) -> Tensor: ...
# Stub for `elu`; `alpha` is an optional float and `inplace` an optional flag.
def elu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...
# Stub for `selu`; no tunable parameters beyond the optional `inplace` flag.
def selu(input: Tensor, inplace: bool = ...) -> Tensor: ...
# Stub for `celu`; same parameter list as `elu`.
def celu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...
# Stub for `leaky_relu`; `negative_slope` is an optional float (default elided).
def leaky_relu(
    input: Tensor,
    negative_slope: float = ...,
    inplace: bool = ...,
) -> Tensor: ...
# Stub for `rrelu`; `lower`/`upper` are optional floats and `training`/`inplace` optional flags.
def rrelu(
    input: Tensor,
    lower: float = ...,
    upper: float = ...,
    training: bool = ...,
    inplace: bool = ...,
) -> Tensor: ...
# Stub for `tanhshrink`. The return type is declared explicitly as Tensor; without it the result is
# implicitly Any and type checking stops at the call site. `input` stays `Any`, as in `sigmoid`.
def tanhshrink(input: Any) -> Tensor: ...
# Stub for `softsign`. The return type is declared explicitly as Tensor; without it the result is
# implicitly Any and type checking stops at the call site. `input` stays `Any`, as in `sigmoid`.
def softsign(input: Any) -> Tensor: ...
# Stub for `softmin`; `dim` and `dtype` may be None. `_stacklevel` is an int with an elided default.
def softmin(
    input: Tensor,
    dim: Optional[int] = ...,
    _stacklevel: int = ...,
    dtype: Optional[_dtype] = ...,
) -> Tensor: ...
# Stub for `softmax`; same parameter list as `softmin`. `dim` and `dtype` may be None.
def softmax(
    input: Tensor,
    dim: Optional[int] = ...,
    _stacklevel: int = ...,
    dtype: Optional[_dtype] = ...,
) -> Tensor: ...
# Stub for `gumbel_softmax`; the first parameter is named `logits` (not `input`), and `dim` is a
# plain int here rather than Optional as in `softmax`.
def gumbel_softmax(
    logits: Tensor,
    tau: float = ...,
    hard: bool = ...,
    eps: float = ...,
    dim: int = ...,
) -> Tensor: ...
# Stub for `log_softmax`; same parameter list as `softmax`. `dim` and `dtype` may be None.
def log_softmax(
    input: Tensor,
    dim: Optional[int] = ...,
    _stacklevel: int = ...,
    dtype: Optional[_dtype] = ...,
) -> Tensor: ...
# Stub for `tanh`. The return type is declared explicitly as Tensor; without it the result is
# implicitly Any and type checking stops at the call site. `input` stays `Any`, as in `sigmoid`.
def tanh(input: Any) -> Tensor: ...
# Stub for `sigmoid`; returns a Tensor.
# NOTE(review): `input` is typed `Any` while most activations in this file take `Tensor` — confirm
# whether the looser type is intentional.
def sigmoid(input: Any) -> Tensor: ...
# Stub for `hardsigmoid`; here the `inplace` default is spelled out as False instead of elided.
def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor: ...
# Stub for `silu`; `inplace` defaults to False (written explicitly in this stub).
def silu(input: Tensor, inplace: bool = False) -> Tensor: ...
# Stub for `mish`; `inplace` defaults to False (written explicitly in this stub).
def mish(input: Tensor, inplace: bool = False) -> Tensor: ...
# Stub for `hardswish`; `inplace` defaults to False (written explicitly in this stub).
def hardswish(input: Tensor, inplace: bool = False) -> Tensor: ...
# Stub for `embedding`; `input` and `weight` are required tensors. `padding_idx` and `max_norm`
# may be None; the remaining options have elided defaults.
def embedding(
    input: Tensor,
    weight: Tensor,
    padding_idx: Optional[int] = ...,
    max_norm: Optional[float] = ...,
    norm_type: float = ...,
    scale_grad_by_freq: bool = ...,
    sparse: bool = ...,
) -> Tensor: ...
# Stub for `embedding_bag`; `input` and `weight` are required tensors. `offsets`, `max_norm`,
# `per_sample_weights` and `padding_idx` may be None. `mode` is an unconstrained `str` in this stub.
# Note that `padding_idx` comes last here, unlike in `embedding` where it is the third parameter.
def embedding_bag(
    input: Tensor,
    weight: Tensor,
    offsets: Optional[Tensor] = ...,
    max_norm: Optional[float] = ...,
    norm_type: float = ...,
    scale_grad_by_freq: bool = ...,
    mode: str = ...,
    sparse: bool = ...,
    per_sample_weights: Optional[Tensor] = ...,
    include_last_offset: bool = ...,
    padding_idx: Optional[int] = ...,
) -> Tensor: ...
# Stub for `batch_norm`; `running_mean` and `running_var` have no default, so they must be passed
# explicitly, but each may be None. `weight` and `bias` are optional and may also be None.
def batch_norm(
    input: Tensor,
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    weight: Optional[Tensor] = ...,
    bias: Optional[Tensor] = ...,
    training: bool = ...,
    momentum: float = ...,
    eps: float = ...,
) -> Tensor: ...
# Stub for `instance_norm`; unlike `batch_norm`, `running_mean` and `running_var` have defaults
# here, so only `input` is required. The flag is named `use_input_stats` rather than `training`.
def instance_norm(
    input: Tensor,
    running_mean: Optional[Tensor] = ...,
    running_var: Optional[Tensor] = ...,
    weight: Optional[Tensor] = ...,
    bias: Optional[Tensor] = ...,
    use_input_stats: bool = ...,
    momentum: float = ...,
    eps: float = ...,
) -> Tensor: ...
# Stub for `layer_norm`; `normalized_shape` is a required sequence of ints. `weight` and `bias`
# may be None.
def layer_norm(
    input: Tensor,
    normalized_shape: Sequence[int],
    weight: Optional[Tensor] = ...,
    bias: Optional[Tensor] = ...,
    eps: float = ...,
) -> Tensor: ...
# Stub for `rms_norm`; differs from `layer_norm` in having no `bias` parameter and in `eps` being
# Optional[float], i.e. None is an accepted value.
def rms_norm(
    input: Tensor,
    normalized_shape: Sequence[int],
    weight: Optional[Tensor] = ...,
    eps: Optional[float] = ...,
) -> Tensor: ...
# Stub for `group_norm`; `num_groups` is a required int. `weight` and `bias` may be None.
def group_norm(
    input: Tensor,
    num_groups: int,
    weight: Optional[Tensor] = ...,
    bias: Optional[Tensor] = ...,
    eps: float = ...,
) -> Tensor: ...
# Stub for `local_response_norm`; `size` is a required int. `alpha`, `beta` and `k` are optional floats.
def local_response_norm(
    input: Tensor,
    size: int,
    alpha: float = ...,
    beta: float = ...,
    k: float = ...,
) -> Tensor: ...
# Stub for `ctc_loss`; the first four parameters are required tensors (both length arguments are
# typed as Tensor in this stub). `reduction` is an unconstrained `str`.
def ctc_loss(
    log_probs: Tensor,
    targets: Tensor,
    input_lengths: Tensor,
    target_lengths: Tensor,
    blank: int = ...,
    reduction: str = ...,
    zero_infinity: bool = ...,
) -> Tensor: ...
# Stub for `nll_loss`; `weight`, `size_average` and `reduce` may be None. `reduction` is an
# unconstrained `str` and `ignore_index` an int with an elided default.
def nll_loss(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = ...,
    size_average: Optional[bool] = ...,
    ignore_index: int = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
# Stub for `poisson_nll_loss`; `log_input` and `full` are optional flags. `size_average` and
# `reduce` may be None.
def poisson_nll_loss(
    input: Tensor,
    target: Tensor,
    log_input: bool = ...,
    full: bool = ...,
    size_average: Optional[bool] = ...,
    eps: float = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
# Stub for `gaussian_nll_loss`; takes a third required tensor, `var`.
# NOTE(review): `full`, `eps` and `reduction` are typed Optional here, whereas the other loss stubs
# type `reduction` as plain `str` — confirm that None is really accepted by the implementation.
def gaussian_nll_loss(
    input: Tensor,
    target: Tensor,
    var: Tensor,
    full: Optional[bool] = ...,
    eps: Optional[float] = ...,
    reduction: Optional[str] = ...,
) -> Tensor: ...
# Stub for `kl_div`; besides the usual `size_average`/`reduce`/`reduction` trio it takes an
# optional `log_target` flag.
def kl_div(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
    log_target: bool = ...,
) -> Tensor: ...
# Stub for `cross_entropy`; same parameters as `nll_loss` plus a trailing optional float
# `label_smoothing`.
def cross_entropy(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = ...,
    size_average: Optional[bool] = ...,
    ignore_index: int = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
    label_smoothing: float = ...,
) -> Tensor: ...
# Stub for `binary_cross_entropy`; `weight`, `size_average` and `reduce` may be None.
def binary_cross_entropy(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
# Stub for `binary_cross_entropy_with_logits`; same parameters as `binary_cross_entropy` plus a
# trailing optional tensor `pos_weight`, which may be None.
def binary_cross_entropy_with_logits(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
    pos_weight: Optional[Tensor] = ...,
) -> Tensor: ...
# Stub for `smooth_l1_loss`; takes the `size_average`/`reduce`/`reduction` trio plus an optional
# float `beta`.
def smooth_l1_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
    beta: float = ...,
) -> Tensor: ...
# Stub for `huber_loss`; unlike most losses here it has no `size_average`/`reduce` parameters,
# only `reduction` and an optional float `delta`.
def huber_loss(
    input: Tensor,
    target: Tensor,
    reduction: str = ...,
    delta: float = ...,
) -> Tensor: ...
# Stub for `l1_loss`; `size_average` and `reduce` may be None, `reduction` is an unconstrained `str`.
def l1_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
# Stub for `mse_loss`; same parameter list and return type as `l1_loss`.
def mse_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
# Stub for `margin_ranking_loss`; takes two input tensors (`input1`, `input2`) and a `target`
# tensor, all required. `margin` is an optional float.
def margin_ranking_loss(
    input1: Tensor,
    input2: Tensor,
    target: Tensor,
    margin: float = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
# Stub for `hinge_embedding_loss`; `margin` is an optional float, followed by the usual
# `size_average`/`reduce`/`reduction` trio.
def hinge_embedding_loss(
    input: Tensor,
    target: Tensor,
    margin: float = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
# Stub for `multilabel_margin_loss`; same parameter list and return type as `l1_loss`.
def multilabel_margin_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
# Stub for `soft_margin_loss`; same parameter list and return type as `l1_loss`.
def soft_margin_loss(
    input: Tensor,
    target: Tensor,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
# Stub for `multilabel_soft_margin_loss`; like `soft_margin_loss` but with an additional optional
# `weight` tensor, which may be None.
def multilabel_soft_margin_loss(
    input: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
# Stub for `cosine_embedding_loss`; same parameter list as `margin_ranking_loss` (two inputs, a
# target, and an optional float `margin`).
def cosine_embedding_loss(
    input1: Tensor,
    input2: Tensor,
    target: Tensor,
    margin: float = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
# Stub for `multi_margin_loss`; `p` is typed as int (not float) and `margin` as float. `weight`,
# `size_average` and `reduce` may be None.
def multi_margin_loss(
    input: Tensor,
    target: Tensor,
    p: int = ...,
    margin: float = ...,
    weight: Optional[Tensor] = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
# Stub for `upsample`. The return type is declared explicitly as Tensor; without it the result is
# implicitly Any and type checking stops at the call site.
# NOTE(review): `input`, `size`, `scale_factor` and `align_corners` are still `Any`-based; they are
# left loose so existing callers keep type-checking — tighten once the accepted types are confirmed.
def upsample(
    input: Any,
    size: Optional[Any] = ...,
    scale_factor: Optional[Any] = ...,
    mode: str = ...,
    align_corners: Optional[Any] = ...,
) -> Tensor: ...
# Stub for `interpolate`. The return type is declared explicitly as Tensor; without it the result is
# implicitly Any and type checking stops at the call site.
# NOTE(review): `input`, `size`, `scale_factor`, `align_corners` and `recompute_scale_factor` are
# still `Any`-based; they are left loose so existing callers keep type-checking — tighten once the
# accepted types are confirmed.
def interpolate(
    input: Any,
    size: Optional[Any] = ...,
    scale_factor: Optional[Any] = ...,
    mode: str = ...,
    align_corners: Optional[Any] = ...,
    recompute_scale_factor: Optional[Any] = ...,
    antialias: bool = ...,
) -> Tensor: ...
# Stub for `upsample_nearest`. The return type is declared explicitly as Tensor; without it the
# result is implicitly Any. Parameter types are left as-is (`Any`-based) for backward compatibility.
def upsample_nearest(
    input: Any,
    size: Optional[Any] = ...,
    scale_factor: Optional[Any] = ...,
) -> Tensor: ...
# Stub for `upsample_bilinear`. The return type is declared explicitly as Tensor; without it the
# result is implicitly Any. Parameter types are left as-is (`Any`-based) for backward compatibility.
def upsample_bilinear(
    input: Any,
    size: Optional[Any] = ...,
    scale_factor: Optional[Any] = ...,
) -> Tensor: ...
# Stub for `grid_sample`; `input` and `grid` are required tensors. `mode` and `padding_mode` are
# plain `str` here; the GRID_SAMPLE_* aliases defined near the top of this file are not used in
# this signature.
# NOTE(review): `align_corners` is `Optional[Any]` — presumably a bool; confirm and tighten.
def grid_sample(
    input: Tensor,
    grid: Tensor,
    mode: str = ...,
    padding_mode: str = ...,
    align_corners: Optional[Any] = ...,
) -> Tensor: ...
# Stub for `affine_grid`; `theta` is a required tensor and `size` a required `List[int]` (a concrete
# list, not the broader `_size`/`Sequence` types used elsewhere in this file).
# NOTE(review): `align_corners` is `Optional[Any]` — presumably a bool; confirm and tighten.
def affine_grid(
    theta: Tensor,
    size: List[int],
    align_corners: Optional[Any] = ...,
) -> Tensor: ...
# Stub for `triplet_margin_loss`; `anchor`, `positive` and `negative` are required tensors.
# `margin`, `p` and `eps` are optional floats, `swap` an optional flag; then the usual
# `size_average`/`reduce`/`reduction` trio.
def triplet_margin_loss(
    anchor: Tensor,
    positive: Tensor,
    negative: Tensor,
    margin: float = ...,
    p: float = ...,
    eps: float = ...,
    swap: bool = ...,
    size_average: Optional[bool] = ...,
    reduce: Optional[bool] = ...,
    reduction: str = ...,
) -> Tensor: ...
# Stub for `triplet_margin_with_distance_loss`; everything after the bare `*` is keyword-only.
# `distance_function`, if given, is a callable taking two tensors and returning a tensor; it may
# be None.
def triplet_margin_with_distance_loss(
    anchor: Tensor,
    positive: Tensor,
    negative: Tensor,
    *,
    distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = ...,
    margin: float = ...,
    swap: bool = ...,
    reduction: str = ...,
) -> Tensor: ...
# Stub for `normalize`; `p` and `eps` are optional floats, `dim` an optional int, and `out` an
# optional tensor that may be None.
def normalize(
    input: Tensor,
    p: float = ...,
    dim: int = ...,
    eps: float = ...,
    out: Optional[Tensor] = ...,
) -> Tensor: ...
# Stub for `assert_int_or_pair`; returns None, so it is called for its effect rather than a value.
# NOTE(review): all three parameters are `Any`; `arg_name` and `message` look like strings —
# confirm against functional.py before tightening.
def assert_int_or_pair(
    arg: Any,
    arg_name: Any,
    message: Any,
) -> None: ...
# Stub for `unfold`; `kernel_size` is required, while `dilation`, `padding` and `stride` have
# elided defaults. All four use the `_size_any_t` alias from `.common_types`.
def unfold(
    input: Tensor,
    kernel_size: _size_any_t,
    dilation: _size_any_t = ...,
    padding: _size_any_t = ...,
    stride: _size_any_t = ...,
) -> Tensor: ...
# Stub for `fold`; like `unfold` but with an additional required `output_size` placed before
# `kernel_size`.
def fold(
    input: Tensor,
    output_size: _size_any_t,
    kernel_size: _size_any_t,
    dilation: _size_any_t = ...,
    padding: _size_any_t = ...,
    stride: _size_any_t = ...,
) -> Tensor: ...
# Stub for the private helper `_canonical_mask`; `mask` may be None and the result may be None.
# `other_type` may be None, whereas `target_type` is a required dtype. `check_other` defaults to
# True (spelled out in this stub).
def _canonical_mask(
    mask: Optional[Tensor],
    mask_name: str,
    other_type: Optional[_dtype],
    other_name: str,
    target_type: _dtype,
    check_other: bool = True,
) -> Optional[Tensor]: ...
# Stub for the private helper `_none_or_dtype`; takes an optional tensor and returns an optional dtype.
def _none_or_dtype(input: Optional[Tensor]) -> Optional[_dtype]: ...
# Stub for `multi_head_attention_forward`.
# The first thirteen parameters (through `out_proj_bias`) have no defaults and must all be passed;
# several of them (`in_proj_weight`, `in_proj_bias`, `bias_k`, `bias_v`, `out_proj_bias`) may be None.
# The remaining parameters have their defaults written out explicitly rather than elided.
# Returns a pair whose second element may be None.
# NOTE(review): the second element is presumably the attention weights, tied to `need_weights` —
# confirm against functional.py.
def multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Optional[Tensor],
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
    average_attn_weights: bool = True,
    is_causal: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]: ...
605
606${imported_hints}
607
608${dispatched_hints}
609