# mypy: allow-untyped-defs
import io
import pickle
import warnings
from collections.abc import Collection
from typing import Dict, List, Optional, Set, Tuple, Type, Union

from torch.utils._import_utils import dill_available
from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe


__all__ = ["traverse", "traverse_dps"]

DataPipe = Union[IterDataPipe, MapDataPipe]
DataPipeGraph = Dict[int, Tuple[DataPipe, "DataPipeGraph"]]  # type: ignore[misc]


def _stub_unpickler():
    return "STUB"


# TODO(VitalyFedyunin): Make sure it works without dill module installed
def _list_connected_datapipes(
    scan_obj: DataPipe, only_datapipe: bool, cache: Set[int]
) -> List[DataPipe]:
    # Serialize `scan_obj` with temporary hooks installed on the DataPipe classes so that
    # every DataPipe it references is recorded in `captured_connections` and replaced by a
    # stub instead of being serialized in full.
    f = io.BytesIO()
    p = pickle.Pickler(
        f
    )  # Not going to work for lambdas, but dill infinite loops on typing and can't be used as is
    if dill_available():
        from dill import Pickler as dill_Pickler

        d = dill_Pickler(f)
    else:
        d = None

    captured_connections = []

    def getstate_hook(ori_state):
        # Keep only attributes that are DataPipes or Python collections, so that
        # unrelated (and possibly unpicklable) state is dropped during the scan.
        state = None
        if isinstance(ori_state, dict):
            state = {}  # type: ignore[assignment]
            for k, v in ori_state.items():
                if isinstance(v, (IterDataPipe, MapDataPipe, Collection)):
                    state[k] = v  # type: ignore[attr-defined]
        elif isinstance(ori_state, (tuple, list)):
            state = []  # type: ignore[assignment]
            for v in ori_state:
                if isinstance(v, (IterDataPipe, MapDataPipe, Collection)):
                    state.append(v)  # type: ignore[attr-defined]
        elif isinstance(ori_state, (IterDataPipe, MapDataPipe, Collection)):
            state = ori_state  # type: ignore[assignment]
        return state

    def reduce_hook(obj):
        if obj == scan_obj or id(obj) in cache:
            # Raising NotImplementedError makes the hook fall back to the default
            # `__reduce_ex__` for the scanned object itself and for cached DataPipes
            raise NotImplementedError
        else:
            captured_connections.append(obj)
            # Adding the id removes duplicate DataPipes serialized at the same level
            cache.add(id(obj))
            return _stub_unpickler, ()

    datapipe_classes: Tuple[Type[DataPipe]] = (IterDataPipe, MapDataPipe)  # type: ignore[assignment]

    try:
        for cls in datapipe_classes:
            cls.set_reduce_ex_hook(reduce_hook)
            if only_datapipe:
                cls.set_getstate_hook(getstate_hook)
        try:
            p.dump(scan_obj)
        except (pickle.PickleError, AttributeError, TypeError):
            if dill_available():
                d.dump(scan_obj)
            else:
                raise
    finally:
        for cls in datapipe_classes:
            cls.set_reduce_ex_hook(None)
            if only_datapipe:
                cls.set_getstate_hook(None)
        if dill_available():
            from dill import extend as dill_extend

            dill_extend(False)  # Undo change to dispatch table
    return captured_connections


def traverse_dps(datapipe: DataPipe) -> DataPipeGraph:
    r"""
    Traverse the DataPipes and their attributes to extract the DataPipe graph.

    This only looks into the attributes of each DataPipe that are either a
    DataPipe or a Python collection object such as ``list``, ``tuple``,
    ``set``, and ``dict``.

    Args:
        datapipe: the end DataPipe of the graph
    Returns:
        A graph represented as a nested dictionary, where keys are ids of DataPipe instances
        and values are tuples of the DataPipe instance and its sub-graph
    """
    cache: Set[int] = set()
    return _traverse_helper(datapipe, only_datapipe=True, cache=cache)


def traverse(datapipe: DataPipe, only_datapipe: Optional[bool] = None) -> DataPipeGraph:
    r"""
    Traverse the DataPipes and their attributes to extract the DataPipe graph.

    [Deprecated]
    When ``only_datapipe`` is specified as ``True``, it only looks into the
    attributes of each DataPipe that are either a DataPipe or a Python collection
    object such as ``list``, ``tuple``, ``set``, and ``dict``.

    Note:
        This function is deprecated. Please use `traverse_dps` instead.

    Args:
        datapipe: the end DataPipe of the graph
        only_datapipe: If ``False`` (default), all attributes of each DataPipe are traversed.
          This argument is deprecated and will be removed after the next release.
    Returns:
        A graph represented as a nested dictionary, where keys are ids of DataPipe instances
        and values are tuples of the DataPipe instance and its sub-graph
    """
    msg = (
        "`traverse` function is deprecated and will be removed after 1.13. "
        "Please use `traverse_dps` instead."
    )
    if not only_datapipe:
        msg += " In addition, the behavior will change to the equivalent of `only_datapipe=True`."
    warnings.warn(msg, FutureWarning)
    if only_datapipe is None:
        only_datapipe = False
    cache: Set[int] = set()
    return _traverse_helper(datapipe, only_datapipe, cache)


# The cache prevents infinite recursion when the DataPipe graph contains a cycle
def _traverse_helper(
    datapipe: DataPipe, only_datapipe: bool, cache: Set[int]
) -> DataPipeGraph:
    if not isinstance(datapipe, (IterDataPipe, MapDataPipe)):
        raise RuntimeError(
            f"Expected `IterDataPipe` or `MapDataPipe`, but {type(datapipe)} was found"
        )

    dp_id = id(datapipe)
    if dp_id in cache:
        return {}
    cache.add(dp_id)
    # Pass a copy of the cache so that a DataPipe seen on one path does not pollute the cache on other paths
    items = _list_connected_datapipes(datapipe, only_datapipe, cache.copy())
    d: DataPipeGraph = {dp_id: (datapipe, {})}
    for item in items:
        # Pass a copy of the cache to only guard against recursion along a single path rather than
        # across the whole graph; a single DataPipe can appear multiple times on different paths
        d[dp_id][1].update(_traverse_helper(item, only_datapipe, cache.copy()))
    return d
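

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the module API): it builds a tiny
# two-node pipeline and prints the nested-dict graph returned by
# `traverse_dps`. `IterableWrapper` and the functional `.batch()` form come
# from `torch.utils.data.datapipes.iter`; the `_print_graph` helper below is
# hypothetical and exists only for this demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data.datapipes.iter import IterableWrapper

    source_dp = IterableWrapper(range(10))
    batched_dp = source_dp.batch(2)

    # Keys are id(datapipe); values are (datapipe, sub_graph) tuples.
    graph = traverse_dps(batched_dp)

    def _print_graph(g: DataPipeGraph, indent: int = 0) -> None:
        for dp_id, (dp, sub_graph) in g.items():
            print(" " * indent + f"{type(dp).__name__} (id={dp_id})")
            _print_graph(sub_graph, indent + 2)

    # Expected shape: the batching DataPipe at the root, with the
    # IterableWrapper instance nested one level beneath it.
    _print_graph(graph)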