xref: /aosp_15_r20/external/pytorch/torch/utils/data/graph.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import io
import pickle
import warnings
from collections.abc import Collection
from typing import Dict, List, Optional, Set, Tuple, Type, Union

from torch.utils._import_utils import dill_available
from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe


__all__ = ["traverse", "traverse_dps"]

DataPipe = Union[IterDataPipe, MapDataPipe]
DataPipeGraph = Dict[int, Tuple[DataPipe, "DataPipeGraph"]]  # type: ignore[misc]
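# Illustrative sketch of the graph shape (hypothetical names, not used below):
# a two-node chain ``source_dp -> sink_dp`` is represented as
# ``{id(sink_dp): (sink_dp, {id(source_dp): (source_dp, {})})}``.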


def _stub_unpickler():
    return "STUB"


# TODO(VitalyFedyunin): Make sure it works without dill module installed
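# Illustrative sketch (hypothetical names, not part of this module): for
# ``combined_dp = source_dp_a.zip(source_dp_b)``, calling
# ``_list_connected_datapipes(combined_dp, only_datapipe=True, cache=set())``
# is expected to return ``[source_dp_a, source_dp_b]``, i.e. the DataPipes
# reachable directly from ``combined_dp``'s state, without recursing into them.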
def _list_connected_datapipes(
    scan_obj: DataPipe, only_datapipe: bool, cache: Set[int]
) -> List[DataPipe]:
    f = io.BytesIO()
    p = pickle.Pickler(
        f
    )  # Will not work for lambdas, but dill hits an infinite loop on `typing` and cannot be used as-is
    if dill_available():
        from dill import Pickler as dill_Pickler

        d = dill_Pickler(f)
    else:
        d = None

    captured_connections = []

    # __getstate__ hook (installed only when ``only_datapipe`` is True): keep just the
    # attributes that can lead to other DataPipes (DataPipes themselves and Python
    # collections), so pickling only walks graph edges.
    def getstate_hook(ori_state):
        state = None
        if isinstance(ori_state, dict):
            state = {}  # type: ignore[assignment]
            for k, v in ori_state.items():
                if isinstance(v, (IterDataPipe, MapDataPipe, Collection)):
                    state[k] = v  # type: ignore[attr-defined]
        elif isinstance(ori_state, (tuple, list)):
            state = []  # type: ignore[assignment]
            for v in ori_state:
                if isinstance(v, (IterDataPipe, MapDataPipe, Collection)):
                    state.append(v)  # type: ignore[attr-defined]
        elif isinstance(ori_state, (IterDataPipe, MapDataPipe, Collection)):
            state = ori_state  # type: ignore[assignment]
        return state

    # __reduce_ex__ hook: raising NotImplementedError falls back to default pickling
    # (used for ``scan_obj`` itself and already-seen DataPipes); any other DataPipe is
    # recorded as a connection and replaced by a stub, so pickling stops there.
    def reduce_hook(obj):
        if obj == scan_obj or id(obj) in cache:
            raise NotImplementedError
        else:
            captured_connections.append(obj)
            # Add the id to deduplicate DataPipes serialized at the same level
            cache.add(id(obj))
            return _stub_unpickler, ()

    datapipe_classes: Tuple[Type[DataPipe]] = (IterDataPipe, MapDataPipe)  # type: ignore[assignment]

    try:
        for cls in datapipe_classes:
            cls.set_reduce_ex_hook(reduce_hook)
            if only_datapipe:
                cls.set_getstate_hook(getstate_hook)
        try:
            p.dump(scan_obj)
        except (pickle.PickleError, AttributeError, TypeError):
            if dill_available():
                d.dump(scan_obj)
            else:
                raise
    finally:
        for cls in datapipe_classes:
            cls.set_reduce_ex_hook(None)
            if only_datapipe:
                cls.set_getstate_hook(None)
        if dill_available():
            from dill import extend as dill_extend

            dill_extend(False)  # Undo change to dispatch table
    return captured_connections


def traverse_dps(datapipe: DataPipe) -> DataPipeGraph:
    r"""
    Traverse the DataPipes and their attributes to extract the DataPipe graph.

    This only looks into attributes of each DataPipe that are either a
    DataPipe or a Python collection object such as ``list``, ``tuple``,
    ``set`` and ``dict``.

    Args:
        datapipe: the end DataPipe of the graph
    Returns:
        A graph represented as a nested dictionary, where keys are ids of DataPipe instances
        and values are tuples of the DataPipe instance and its sub-graph
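
    Example:
        A minimal sketch of the expected shape, assuming ``IterableWrapper`` and the
        functional ``shuffle`` DataPipe from ``torch.utils.data.datapipes.iter``::

            >>> from torch.utils.data.datapipes.iter import IterableWrapper
            >>> source_dp = IterableWrapper(range(10))
            >>> shuffled_dp = source_dp.shuffle()
            >>> graph = traverse_dps(shuffled_dp)
            >>> # graph has the form
            >>> # {id(shuffled_dp): (shuffled_dp, {id(source_dp): (source_dp, {})})}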
    """
    cache: Set[int] = set()
    return _traverse_helper(datapipe, only_datapipe=True, cache=cache)


def traverse(datapipe: DataPipe, only_datapipe: Optional[bool] = None) -> DataPipeGraph:
    r"""
    Traverse the DataPipes and their attributes to extract the DataPipe graph.

    [Deprecated]
    When ``only_datapipe`` is specified as ``True``, it only looks into attributes
    of each DataPipe that are either a DataPipe or a Python collection object
    such as ``list``, ``tuple``, ``set`` and ``dict``.

    Note:
        This function is deprecated. Please use `traverse_dps` instead.

    Args:
        datapipe: the end DataPipe of the graph
        only_datapipe: If ``False`` (default), all attributes of each DataPipe are traversed.
          This argument is deprecated and will be removed after the next release.
    Returns:
        A graph represented as a nested dictionary, where keys are ids of DataPipe instances
        and values are tuples of the DataPipe instance and its sub-graph
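
    Example:
        Migration sketch (``dp`` is a hypothetical DataPipe): apart from the deprecation
        warning, ``traverse(dp, only_datapipe=True)`` is equivalent to ``traverse_dps(dp)``,
        which is the recommended replacement.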
    """
    msg = (
        "`traverse` function is deprecated and will be removed after 1.13. "
        "Please use `traverse_dps` instead."
    )
    if not only_datapipe:
        msg += " In addition, the behavior will be changed to the equivalent of `only_datapipe=True`."
    warnings.warn(msg, FutureWarning)
    if only_datapipe is None:
        only_datapipe = False
    cache: Set[int] = set()
    return _traverse_helper(datapipe, only_datapipe, cache)


# The cache here prevents infinite recursion over the DataPipe graph
def _traverse_helper(
    datapipe: DataPipe, only_datapipe: bool, cache: Set[int]
) -> DataPipeGraph:
    if not isinstance(datapipe, (IterDataPipe, MapDataPipe)):
        raise RuntimeError(
            f"Expected `IterDataPipe` or `MapDataPipe`, but {type(datapipe)} was found"
        )

    dp_id = id(datapipe)
    if dp_id in cache:
        return {}
    cache.add(dp_id)
    # Using cache.copy() here prevents the same DataPipe from polluting the cache on different paths
    items = _list_connected_datapipes(datapipe, only_datapipe, cache.copy())
    d: DataPipeGraph = {dp_id: (datapipe, {})}
    for item in items:
        # Using cache.copy() here limits recursion prevention to a single path rather than the global graph,
        # because a single DataPipe can be present multiple times on different paths in the graph
        d[dp_id][1].update(_traverse_helper(item, only_datapipe, cache.copy()))
    return d
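

# Illustrative sketch (hypothetical helper, not part of this module's API): one way
# to consume the ``DataPipeGraph`` returned by ``traverse_dps`` is a depth-first walk
# over its nested dictionaries.
def _example_iter_graph(graph: DataPipeGraph):
    # Yield every DataPipe in the graph; a DataPipe reachable through several
    # paths is yielded once per path.
    for _dp_id, (dp, sub_graph) in graph.items():
        yield dp
        yield from _example_iter_graph(sub_graph)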