# xref: /aosp_15_r20/external/pytorch/torch/utils/data/datapipes/datapipe.pyi.in (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
# This base template ("datapipe.pyi.in") is generated from mypy stubgen with minimal editing for code injection.
# The output file will be "datapipe.pyi". This is executed as part of torch/CMakeLists.txt.
# Note that, for mypy, a .pyi file takes precedence over the corresponding .py file, so we must define the interface
# for the other classes/objects here as well, even though we are not injecting extra code into them at the moment.
6
7from typing import (
8    Any,
9    Callable,
10    Dict,
11    Iterable,
12    Iterator,
13    List,
14    Literal,
15    Optional,
16    Type,
17    TypeVar,
18    Union,
19)
20
21from torch.utils.data import Dataset, default_collate, IterableDataset
22from torch.utils.data.datapipes._hook_iterator import _SnapshotState
23from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta
24
# Annotation-only declaration: the checker sees the name and its type,
# but no value is bound here.
UNTRACABLE_DATAFRAME_PIPES: Any

# Invariant element type (used by DataChunk).
_T = TypeVar("_T")
# Covariant element type used to parameterize the DataPipe base classes.
_T_co = TypeVar("_T_co", covariant=True)
28
class DataChunk(List[_T]):
    # A List[_T] subclass that also exposes its elements via ``items``.
    items: List[_T]

    def __init__(self, items: Iterable[_T]) -> None: ...

    # Two iteration entry points, both typed as yielding _T elements.
    def __iter__(self) -> Iterator[_T]: ...
    def raw_iterator(self) -> Iterator[_T]: ...

    # String rendering with an optional indentation prefix.
    def as_str(self, indent: str = "") -> str: ...
35
36class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta):
37    functions: Dict[str, Callable] = ...
38    reduce_ex_hook: Optional[Callable] = ...
39    getstate_hook: Optional[Callable] = ...
40    str_hook: Optional[Callable] = ...
41    repr_hook: Optional[Callable] = ...
42    def __getattr__(self, attribute_name: Any): ...
43    @classmethod
44    def register_function(cls, function_name: Any, function: Any) -> None: ...
45    @classmethod
46    def register_datapipe_as_function(
47        cls,
48        function_name: Any,
49        cls_to_register: Any,
50    ): ...
51    def __getstate__(self): ...
52    def __reduce_ex__(self, *args: Any, **kwargs: Any): ...
53    @classmethod
54    def set_getstate_hook(cls, hook_fn: Any) -> None: ...
55    @classmethod
56    def set_reduce_ex_hook(cls, hook_fn: Any) -> None: ...
57    ${MapDataPipeMethods}
58
59class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
60    functions: Dict[str, Callable] = ...
61    reduce_ex_hook: Optional[Callable] = ...
62    getstate_hook: Optional[Callable] = ...
63    str_hook: Optional[Callable] = ...
64    repr_hook: Optional[Callable] = ...
65    _number_of_samples_yielded: int = ...
66    _snapshot_state: _SnapshotState = _SnapshotState.Iterating  # noqa: PYI015
67    _fast_forward_iterator: Optional[Iterator] = ...
68    def __getattr__(self, attribute_name: Any): ...
69    @classmethod
70    def register_function(cls, function_name: Any, function: Any) -> None: ...
71    @classmethod
72    def register_datapipe_as_function(
73        cls,
74        function_name: Any,
75        cls_to_register: Any,
76        enable_df_api_tracing: bool = ...,
77    ): ...
78    def __getstate__(self): ...
79    def __reduce_ex__(self, *args: Any, **kwargs: Any): ...
80    @classmethod
81    def set_getstate_hook(cls, hook_fn: Any) -> None: ...
82    @classmethod
83    def set_reduce_ex_hook(cls, hook_fn: Any) -> None: ...
84    ${IterDataPipeMethods}
85
class DFIterDataPipe(IterDataPipe):
    # IterDataPipe variant; _is_dfpipe presumably flags DataFrame pipes.
    # Both members are intentionally left untyped in this stub.
    def __iter__(self): ...
    def _is_dfpipe(self): ...
89
class _DataPipeSerializationWrapper:
    # Constructed around a single datapipe argument; all members untyped.
    def __init__(self, datapipe): ...
    def __len__(self): ...

    # Pickling protocol hooks.
    def __getstate__(self): ...
    def __setstate__(self, state): ...
95
class _IterDataPipeSerializationWrapper(
    _DataPipeSerializationWrapper,
    IterDataPipe,
):
    # Iterable-style flavour of the serialization wrapper.
    def __iter__(self): ...
98
class _MapDataPipeSerializationWrapper(
    _DataPipeSerializationWrapper,
    MapDataPipe,
):
    # Map-style (indexable) flavour of the serialization wrapper.
    def __getitem__(self, idx): ...
101