xref: /aosp_15_r20/external/pytorch/torch/_C/_nn.pyi.in (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# ${generated_comment}
2# mypy: disable-error-code="type-arg"
3
4from typing import List, Literal, Optional, overload, Sequence, Tuple, Union
5
6from torch import memory_format, Tensor
7from torch.types import _bool, _device, _dtype, _int, _size
8
9# Defined in tools/autograd/templates/python_nn_functions.cpp
10
11${c_nn_function_hints}
12
13# Defined in aten/src/ATen/native/mkldnn/Linear.cpp
def mkldnn_linear(
    input: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
) -> Tensor:
    # Stub only; the op itself is implemented natively.
    # `bias` is Optional, so callers may pass None to skip it.
    ...
15
# Defined in aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
def mkldnn_reorder_conv2d_weight(
    self: Tensor,
    padding: List,
    stride: List,
    # NOTE(review): keyword is spelled "dilatation"; the usual conv term is
    # "dilation" -- confirm against the native schema before using it as a
    # keyword argument.
    dilatation: List,
    groups: _int,
) -> Tensor:
    # Presumably reorders a 2-D conv weight into an MKLDNN-preferred layout
    # (inferred from the name -- confirm in MKLDNNConversions.cpp).
    # The list parameters are deliberately left unparameterized (bare `List`).
    ...
def mkldnn_reorder_conv3d_weight(
    self: Tensor,
    padding: List,
    stride: List,
    # NOTE(review): keyword is spelled "dilatation"; the usual conv term is
    # "dilation" -- confirm against the native schema before using it as a
    # keyword argument.
    dilatation: List,
    groups: _int,
) -> Tensor:
    # 3-D counterpart of mkldnn_reorder_conv2d_weight, with an identical
    # parameter list. Presumably reorders a conv weight into an
    # MKLDNN-preferred layout (inferred from the name -- confirm natively).
    ...
31
32# Defined in aten/src/ATen/native/mkldnn/Prelu.cpp
def mkldnn_prelu(
    input: Tensor,
    weight: Tensor,
) -> Tensor:
    # Stub only; both operands are required and the op is implemented natively.
    ...
34
# Defined in tools/autograd/templates/python_nn_functions.cpp
# Three accepted call shapes for _parse_to. All of them produce the same
# 4-tuple: (device, dtype, a bool flag, memory_format). `memory_format` is
# keyword-only in every form.
@overload
def _parse_to(
    device: _device,
    dtype: _dtype,
    non_blocking: _bool,
    copy: _bool,
    *,
    memory_format: memory_format,
) -> Tuple[_device, _dtype, _bool, memory_format]:
    # Form 1: an explicit device is given first, followed by a dtype.
    ...
@overload
def _parse_to(
    dtype: _dtype,
    non_blocking: _bool,
    copy: _bool,
    *,
    memory_format: memory_format,
) -> Tuple[_device, _dtype, _bool, memory_format]:
    # Form 2: only a dtype is given (no device argument).
    ...
@overload
def _parse_to(
    tensor: Tensor,
    non_blocking: _bool,
    copy: _bool,
    *,
    memory_format: memory_format,
) -> Tuple[_device, _dtype, _bool, memory_format]:
    # Form 3: a reference tensor is given in place of device/dtype.
    ...
61
62# Defined in aten/src/ATen/native/PackedSequence.cpp
def pad_sequence(
    sequences: Union[List[Tensor], Tuple[Tensor, ...]],
    batch_first: _bool = False,
    padding_value: float = 0.0,
    padding_side: Union[Literal["left", "right"], str] = "right",
) -> Tensor:
    # `sequences` may be either a list or a tuple of tensors.
    # `padding_side` is effectively typed as `str`; the Literal arm only
    # advertises the two expected values, "left" and "right" (default).
    ...
def flatten_dense_tensors(
    tensors: Union[List[Tensor], Tuple[Tensor, ...]],
) -> Tensor:
    # Accepts a list or a tuple of tensors, matching how `pad_sequence` in
    # this stub annotates its tensor-list parameter. Widening from
    # `List[Tensor]` keeps every existing call type-correct.
    # NOTE(review): tuple acceptance mirrors pad_sequence's binding; confirm
    # against the native tensor-list argument parser.
    ...
def unflatten_dense_tensors(
    flat: Tensor,
    tensors: Union[List[Tensor], Tuple[Tensor, ...]],
) -> List[Tensor]:
    # `tensors` accepts a list or a tuple, matching how `pad_sequence` in
    # this stub annotates its tensor-list parameter. The result is always
    # typed as a list. Widening from `List[Tensor]` is backward-compatible.
    # NOTE(review): tuple acceptance mirrors pad_sequence's binding; confirm
    # against the native tensor-list argument parser.
    ...
71