# mypy: allow-untyped-defs
import functools
import inspect
from enum import Enum

import torch


class _SnapshotState(Enum):
    r"""
    These are the snapshotting-related states that IterDataPipes can be in.

    `NotStarted` - allows you to restore a snapshot and create an iterator with reset
    `Restored` - cannot restore again, allows you to create an iterator without resetting the DataPipe
    `Iterating` - can restore, will reset if you create a new iterator
    """

    NotStarted = 0
    Restored = 1
    Iterating = 2


def _simplify_obj_name(obj) -> str:
    """Simplify the display strings of objects for the purpose of rendering within DataPipe error messages."""
    if inspect.isfunction(obj):
        return obj.__name__
    else:
        return repr(obj)


def _strip_datapipe_from_name(name: str) -> str:
    """Remove the `IterDataPipe`/`MapDataPipe` suffix from a display name to shorten it."""
    return name.replace("IterDataPipe", "").replace("MapDataPipe", "")


def _generate_input_args_string(obj):
    """Generate a string for the input arguments of an object.

    Only the attributes of ``obj`` whose names match parameters of its class's
    constructor are rendered, so the result resembles the original call site.
    """
    signature = inspect.signature(obj.__class__)
    input_param_names = set(signature.parameters.keys())
    result = []
    for name, value in inspect.getmembers(obj):
        if name in input_param_names:
            result.append((name, _simplify_obj_name(value)))
    return ", ".join([f"{name}={value}" for name, value in result])


def _generate_iterdatapipe_msg(datapipe, simplify_dp_name: bool = False):
    """Render a DataPipe as ``ClassName(arg1=..., ...)`` for error and profiler messages."""
    output_string = (
        f"{datapipe.__class__.__name__}({_generate_input_args_string(datapipe)})"
    )
    if simplify_dp_name:
        output_string = _strip_datapipe_from_name(output_string)
    return output_string


def _gen_invalid_iterdatapipe_msg(datapipe):
    """Build the error message raised when an iterator is invalidated by a newer one."""
    return (
        "This iterator has been invalidated because another iterator has been created "
        f"from the same IterDataPipe: {_generate_iterdatapipe_msg(datapipe)}\n"
        "This may be caused by multiple references to the same IterDataPipe. We recommend "
        "using `.fork()` if that is necessary."
    )


_feedback_msg = (
    "\nFor feedback regarding this single iterator per IterDataPipe constraint, feel free "
    "to comment on this issue: https://github.com/pytorch/data/issues/45."
)


def _check_iterator_valid(datapipe, iterator_id, next_method_exists=False) -> None:
    r"""
    Given an instance of a DataPipe and an iterator ID, check if the IDs match, and if not, raises an exception.

    In the case of ChildDataPipe, the ID gets compared to the one stored in `main_datapipe` as well.
    """
    if next_method_exists:
        # This is the case where `IterDataPipe` has both `__iter__` and `__next__`.
        # The `_valid_iterator_id` should either be never set (`None`), or set by at most one
        # iterator (`0`). Otherwise, it means there are multiple iterators.
        if datapipe._valid_iterator_id is not None and datapipe._valid_iterator_id != 0:
            extra_msg = "\nNote that this exception is raised inside your IterDataPipe's `__next__` method"
            raise RuntimeError(
                _gen_invalid_iterdatapipe_msg(datapipe) + extra_msg + _feedback_msg
            )
    elif (
        hasattr(datapipe, "_is_child_datapipe") and datapipe._is_child_datapipe is True
    ):
        # ChildDataPipes delegate validity tracking to their `main_datapipe`.
        if hasattr(datapipe, "_check_valid_iterator_id"):
            if not datapipe._check_valid_iterator_id(iterator_id):
                raise RuntimeError(
                    "This iterator has been invalidated, because a new iterator has been created "
                    f"from one of the ChildDataPipes of "
                    f"{_generate_iterdatapipe_msg(datapipe.main_datapipe)}."
                    + _feedback_msg
                )
        else:
            raise RuntimeError(
                "ChildDataPipe must have method `_check_valid_iterator_id`."
            )
    elif datapipe._valid_iterator_id != iterator_id:
        raise RuntimeError(_gen_invalid_iterdatapipe_msg(datapipe) + _feedback_msg)


def _set_datapipe_valid_iterator_id(datapipe):
    """Given a DataPipe, updates its valid iterator ID and reset the DataPipe.

    Returns the new valid iterator ID. For ChildDataPipes, the ID is managed by
    the shared `main_datapipe`; otherwise the counter starts at 0 and increments
    each time a new iterator is created, resetting the DataPipe in the process.
    """
    if hasattr(datapipe, "_is_child_datapipe") and datapipe._is_child_datapipe is True:
        if hasattr(datapipe, "_set_main_datapipe_valid_iterator_id"):
            datapipe._set_main_datapipe_valid_iterator_id()  # reset() is called within this method when appropriate
        else:
            raise RuntimeError(
                "ChildDataPipe must have method `_set_main_datapipe_valid_iterator_id`."
            )
    else:
        if datapipe._valid_iterator_id is None:
            datapipe._valid_iterator_id = 0
        else:
            datapipe._valid_iterator_id += 1
        datapipe.reset()
    return datapipe._valid_iterator_id


def hook_iterator(namespace):
    r"""
    Define a hook that is applied to all `__iter__` of metaclass `_DataPipeMeta`.

    This is done for the purpose of profiling and checking if an iterator is still valid.
    """

    def profiler_record_fn_context(datapipe):
        # Lazily cache the profiler label on the DataPipe since building it
        # involves inspecting the constructor signature.
        if not hasattr(datapipe, "_profile_name"):
            datapipe._profile_name = _generate_iterdatapipe_msg(
                datapipe, simplify_dp_name=True
            )
        return torch.autograd.profiler.record_function(datapipe._profile_name)

    class IteratorDecorator:
        r"""
        Wrap the iterator and modify its `__next__` method.

        This decorator is applied to DataPipes of which `__iter__` method is NOT a generator function.
        Those `__iter__` method commonly returns `self` but not necessarily.
        """

        def __init__(self, iterator, datapipe, iterator_id, has_next_method):
            self.iterator = iterator
            self.datapipe = datapipe
            self.iterator_id = iterator_id
            self._profiler_enabled = torch.autograd._profiler_enabled()
            # Check if `__iter__` returns `self` and `DataPipe` has `__next__`
            self.self_and_has_next_method = (
                self.iterator is self.datapipe and has_next_method
            )

        def __iter__(self):
            return self

        def _get_next(self):
            """Return next with logic related to iterator validity, profiler, and incrementation of samples yielded."""
            _check_iterator_valid(self.datapipe, self.iterator_id)
            result = next(self.iterator)
            # When `__iter__` returns `self` and a wrapped `__next__` exists,
            # that wrapper already counts the sample — avoid double-counting.
            if not self.self_and_has_next_method:
                self.datapipe._number_of_samples_yielded += 1
            return result

        def __next__(self):
            # TODO: Add try-except to in-place reduce traceback from the Exception
            # See: https://github.com/pytorch/data/issues/284
            if self._profiler_enabled:
                with profiler_record_fn_context(self.datapipe):
                    return self._get_next()
            else:  # Decided against using `contextlib.nullcontext` for performance reasons
                return self._get_next()

        def __getattr__(self, name):
            # Delegate any other attribute access to the wrapped iterator.
            return getattr(self.iterator, name)

    func = namespace["__iter__"]

    # ``__iter__`` of IterDataPipe is a generator function
    if inspect.isgeneratorfunction(func):

        @functools.wraps(func)
        def wrap_generator(*args, **kwargs):
            gen = func(*args, **kwargs)
            datapipe = args[0]
            if datapipe._fast_forward_iterator:
                # A snapshot restore is in progress: drain the fast-forward
                # iterator instead of the freshly-created generator.
                it = datapipe._fast_forward_iterator
                datapipe._fast_forward_iterator = None
                datapipe._snapshot_state = _SnapshotState.Iterating
                while True:
                    try:
                        yield next(it)
                    except StopIteration:
                        return
            iterator_id = _set_datapipe_valid_iterator_id(
                datapipe
            )  # This ID is tied to each created iterator
            _profiler_enabled = torch.autograd._profiler_enabled()
            try:
                if _profiler_enabled:
                    with profiler_record_fn_context(datapipe):
                        response = gen.send(None)
                else:
                    response = gen.send(None)

                while True:
                    datapipe._number_of_samples_yielded += 1
                    request = yield response
                    # Pass through here every time `__next__` is called
                    if _profiler_enabled:
                        with profiler_record_fn_context(datapipe):
                            _check_iterator_valid(datapipe, iterator_id)
                            response = gen.send(request)
                    else:  # Decided against using `contextlib.nullcontext` for performance reasons
                        _check_iterator_valid(datapipe, iterator_id)
                        response = gen.send(request)
            except StopIteration:
                return
            except Exception as e:
                # TODO: Simplify the traceback message to skip over `response = gen.send(None)`
                # Part of https://github.com/pytorch/data/issues/284
                msg = "thrown by __iter__ of"
                single_iterator_msg = "single iterator per IterDataPipe constraint"
                # Append the originating DataPipe to the exception message once,
                # unless it (or the single-iterator notice) is already present.
                if hasattr(e.args, "__len__"):
                    full_msg = f"{msg} {datapipe.__class__.__name__}({_generate_input_args_string(datapipe)})"
                    if len(e.args) == 0 or not isinstance(
                        e.args[0], str
                    ):  # If an exception message doesn't exist
                        e.args = (f"\nThis exception is {full_msg}",)
                    elif msg not in e.args[0] and single_iterator_msg not in e.args[0]:
                        e.args = (
                            e.args[0] + f"\nThis exception is {full_msg}",
                        ) + e.args[1:]
                raise

        namespace["__iter__"] = wrap_generator
    else:  # ``__iter__`` of IterDataPipe is NOT a generator function
        # IterDataPipe is an iterator with both ``__iter__`` and ``__next__``
        # And ``__iter__`` may or may not return `self`
        if "__next__" in namespace:  # If `__next__` exists, put a wrapper around it
            next_func = namespace["__next__"]

            @functools.wraps(next_func)
            def wrap_next(*args, **kwargs):
                datapipe = args[0]
                if torch.autograd._profiler_enabled():
                    with profiler_record_fn_context(datapipe):
                        result = next_func(*args, **kwargs)
                else:
                    result = next_func(*args, **kwargs)
                datapipe._number_of_samples_yielded += 1
                return result

            namespace["__next__"] = wrap_next

            # Note that if the `__next__` and `__iter__` do something completely unrelated. It may cause issue but
            # the user will be violating the iterator protocol. Potential issue:
            # 1. Valid iterator ID may not update or checked properly
            # 2. The number of samples yielded will be miscounted

        # Regardless if `__next__` exists or not, `__iter__` needs a wrapper to track the number of valid iterators
        @functools.wraps(func)
        def wrap_iter(*args, **kwargs):
            iter_ret = func(*args, **kwargs)
            datapipe = args[0]
            datapipe._snapshot_state = _SnapshotState.Iterating
            if datapipe._fast_forward_iterator:
                # Hand back the fast-forward iterator directly when restoring.
                iter_ret = datapipe._fast_forward_iterator
                datapipe._fast_forward_iterator = None
                return iter_ret
            iterator_id = _set_datapipe_valid_iterator_id(
                datapipe
            )  # This ID is tied to each created iterator
            return IteratorDecorator(
                iter_ret, datapipe, iterator_id, "__next__" in namespace
            )

        namespace["__iter__"] = wrap_iter