import functools
import pickle
from typing import Callable, Dict, Iterable, Iterator, List, Optional, TypeVar

from torch.utils._import_utils import import_dill
from torch.utils.data.datapipes._hook_iterator import _SnapshotState
from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta
from torch.utils.data.datapipes.utils.common import (
    _deprecation_warning,
    _iter_deprecated_functional_names,
    _map_deprecated_functional_names,
)
from torch.utils.data.dataset import Dataset, IterableDataset


dill = import_dill()
HAS_DILL = dill is not None

__all__ = [
    "DataChunk",
    "DFIterDataPipe",
    "IterDataPipe",
    "MapDataPipe",
]


_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)

UNTRACABLE_DATAFRAME_PIPES = [
    "batch",  # As it returns DataChunks
    "groupby",  # As it returns DataChunks
    "_dataframes_as_tuples",  # As it unpacks DF
    "trace_as_dataframe",  # As it is used to mark DF for tracing
]


class DataChunk(List[_T]):
    def __init__(self, items: Iterable[_T]) -> None:
        items = list(items)
        super().__init__(items)
        self.items = items

    def as_str(self, indent: str = "") -> str:
        return indent + "[" + ", ".join(str(i) for i in iter(self)) + "]"

    def __iter__(self) -> Iterator[_T]:
        yield from super().__iter__()

    def raw_iterator(self) -> Iterator[_T]:
        yield from self.items


class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
    r"""
    Iterable-style DataPipe.

    All DataPipes that represent an iterable of data samples should subclass this.
    This style of DataPipes is particularly useful when data come from a stream, or
    when the number of samples is too large to fit them all in memory. ``IterDataPipe`` is lazily initialized and its
    elements are computed only when ``next()`` is called on the iterator of an ``IterDataPipe``.

    All subclasses should overwrite :meth:`__iter__`, which would return an
    iterator of samples in this DataPipe. Calling ``__iter__`` of an ``IterDataPipe`` automatically invokes its
    method ``reset()``, which by default performs no operation. When writing a custom ``IterDataPipe``, users should
    override ``reset()`` if necessary. The common usages include resetting buffers, pointers,
    and various state variables within the custom ``IterDataPipe``.

    Note:
        Only `one` iterator can be valid for each ``IterDataPipe`` at a time,
        and the creation of a second iterator will invalidate the first one. This constraint is necessary because
        some ``IterDataPipe`` have internal buffers, whose states can become invalid if there are multiple iterators.
        The code example below presents details on how this constraint looks in practice.
        If you have any feedback related to this constraint, please see `GitHub IterDataPipe Single Iterator Issue`_.

    These DataPipes can be invoked in two ways, using the class constructor or applying their
    functional form onto an existing ``IterDataPipe`` (recommended, available to most but not all DataPipes).
    You can chain multiple `IterDataPipe` together to form a pipeline that will perform multiple
    operations in succession.

    .. _GitHub IterDataPipe Single Iterator Issue:
        https://github.com/pytorch/data/issues/45

    Note:
        When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
        item in the DataPipe will be yielded from the :class:`~torch.utils.data.DataLoader`
        iterator.
        When :attr:`num_workers > 0`, each worker process will have a
        different copy of the DataPipe object, so it is often desired to configure
        each copy independently to avoid having duplicate data returned from the
        workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
        process, returns information about the worker. It can be used in either the
        dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
        :attr:`worker_init_fn` option to modify each copy's behavior.

    Examples:
        General Usage:
            >>> # xdoctest: +SKIP
            >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
            >>> dp = IterableWrapper(range(10))
            >>> map_dp_1 = Mapper(dp, lambda x: x + 1)  # Using class constructor
            >>> map_dp_2 = dp.map(lambda x: x + 1)  # Using functional form (recommended)
            >>> list(map_dp_1)
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
            >>> list(map_dp_2)
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
            >>> filter_dp = map_dp_1.filter(lambda x: x % 2 == 0)
            >>> list(filter_dp)
            [2, 4, 6, 8, 10]
        Single Iterator Constraint Example:
            >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
            >>> source_dp = IterableWrapper(range(10))
            >>> it1 = iter(source_dp)
            >>> list(it1)
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
            >>> it1 = iter(source_dp)
            >>> it2 = iter(source_dp)  # The creation of a new iterator invalidates `it1`
            >>> next(it2)
            0
            >>> next(it1)  # Further usage of `it1` will raise a `RuntimeError`
    """

    functions: Dict[str, Callable] = {}
    reduce_ex_hook: Optional[Callable] = None
    getstate_hook: Optional[Callable] = None
    str_hook: Optional[Callable] = None
    repr_hook: Optional[Callable] = None
    _valid_iterator_id: Optional[int] = None
    _number_of_samples_yielded: int = 0
    _snapshot_state: _SnapshotState = _SnapshotState.NotStarted
    _fast_forward_iterator: Optional[Iterator] = None

    def __iter__(self) -> Iterator[_T_co]:
        return self

    def __getattr__(self, attribute_name):
        if attribute_name in IterDataPipe.functions:
            if attribute_name in _iter_deprecated_functional_names:
                kwargs = _iter_deprecated_functional_names[attribute_name]
                _deprecation_warning(**kwargs)
            f = IterDataPipe.functions[attribute_name]
            function = functools.partial(f, self)
            functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",))
            return function
        else:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{attribute_name}'"
            )

    @classmethod
    def register_function(cls, function_name, function):
        cls.functions[function_name] = function

    @classmethod
    def register_datapipe_as_function(
        cls, function_name, cls_to_register, enable_df_api_tracing=False
    ):
        if function_name in cls.functions:
            raise Exception(  # noqa: TRY002
                f"Unable to add DataPipe function name {function_name} as it is already taken"
            )

        def class_function(cls, enable_df_api_tracing, source_dp, *args, **kwargs):
            result_pipe = cls(source_dp, *args, **kwargs)
            if isinstance(result_pipe, IterDataPipe):
                if enable_df_api_tracing or isinstance(source_dp, DFIterDataPipe):
                    if function_name not in UNTRACABLE_DATAFRAME_PIPES:
                        result_pipe = result_pipe.trace_as_dataframe()

            return result_pipe

        function = functools.partial(
            class_function, cls_to_register, enable_df_api_tracing
        )
        functools.update_wrapper(
            wrapper=function, wrapped=cls_to_register,
            assigned=("__doc__",)
        )
        cls.functions[function_name] = function

    def __getstate__(self):
        """
        Serialize `lambda` functions when `dill` is available.

        If this doesn't cover your custom DataPipe's use case, consider writing custom methods for
        `__getstate__` and `__setstate__`, or use `pickle.dumps` for serialization.
        """
        state = self.__dict__
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(state)
        return state

    def __reduce_ex__(self, *args, **kwargs):
        if IterDataPipe.reduce_ex_hook is not None:
            try:
                return IterDataPipe.reduce_ex_hook(self)
            except NotImplementedError:
                pass
        return super().__reduce_ex__(*args, **kwargs)

    @classmethod
    def set_getstate_hook(cls, hook_fn):
        if IterDataPipe.getstate_hook is not None and hook_fn is not None:
            raise RuntimeError("Attempt to override existing getstate_hook")
        IterDataPipe.getstate_hook = hook_fn

    @classmethod
    def set_reduce_ex_hook(cls, hook_fn):
        if IterDataPipe.reduce_ex_hook is not None and hook_fn is not None:
            raise RuntimeError("Attempt to override existing reduce_ex_hook")
        IterDataPipe.reduce_ex_hook = hook_fn

    def __repr__(self):
        if self.repr_hook is not None:
            return self.repr_hook(self)
        # Instead of showing <torch. ... .MapperIterDataPipe object at 0x.....>, return the class name
        return str(self.__class__.__qualname__)

    def __str__(self):
        if self.str_hook is not None:
            return self.str_hook(self)
        # Instead of showing <torch. ... .MapperIterDataPipe object at 0x.....>, return the class name
        return str(self.__class__.__qualname__)

    def __dir__(self):
        # for auto-completion in a REPL (e.g. Jupyter notebook)
        return list(super().__dir__()) + list(self.functions.keys())

    def reset(self) -> None:
        r"""
        Reset the `IterDataPipe` to the initial state.

        By default, this is a no-op. Subclasses of `IterDataPipe`, depending on their
        functionality, may want to override this method with an implementation that
        clears the buffers and resets pointers of the DataPipe.
        The `reset` method is always called when `__iter__` is called as part of `hook_iterator`.
        """


class DFIterDataPipe(IterDataPipe):
    def _is_dfpipe(self):
        return True


class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta):
    r"""
    Map-style DataPipe.

    All datasets that represent a map from keys to data samples should subclass this.
    Subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given, unique key. Subclasses can also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    These DataPipes can be invoked in two ways, using the class constructor or applying their
    functional form onto an existing `MapDataPipe` (recommended, available to most but not all DataPipes).

    Note:
        :class:`~torch.utils.data.DataLoader` by default constructs an index
        sampler that yields integral indices. To make it work with a map-style
        DataPipe with non-integral indices/keys, a custom sampler must be provided.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper, Mapper
        >>> dp = SequenceWrapper(range(10))
        >>> map_dp_1 = dp.map(lambda x: x + 1)  # Using functional form (recommended)
        >>> list(map_dp_1)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        >>> map_dp_2 = Mapper(dp, lambda x: x + 1)  # Using class constructor
        >>> list(map_dp_2)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        >>> batch_dp = map_dp_1.batch(batch_size=2)
        >>> list(batch_dp)
        [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
    """

    functions: Dict[str, Callable] = {}
    reduce_ex_hook: Optional[Callable] = None
    getstate_hook: Optional[Callable] = None
    str_hook: Optional[Callable] = None
    repr_hook: Optional[Callable] = None

    def __getattr__(self, attribute_name):
        if attribute_name in MapDataPipe.functions:
            if attribute_name in _map_deprecated_functional_names:
                kwargs = _map_deprecated_functional_names[attribute_name]
                _deprecation_warning(**kwargs)
            f = MapDataPipe.functions[attribute_name]
            function = functools.partial(f, self)
            functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",))
            return function
        else:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{attribute_name}'"
            )

    @classmethod
    def register_function(cls, function_name, function):
        cls.functions[function_name] = function

    @classmethod
    def register_datapipe_as_function(cls, function_name, cls_to_register):
        if function_name in cls.functions:
            raise Exception(  # noqa: TRY002
                f"Unable to add DataPipe function name {function_name} as it is already taken"
            )

        def class_function(cls, source_dp, *args, **kwargs):
            result_pipe = cls(source_dp, *args, **kwargs)
            return result_pipe

        function = functools.partial(class_function, cls_to_register)
        functools.update_wrapper(
            wrapper=function, wrapped=cls_to_register, assigned=("__doc__",)
        )
        cls.functions[function_name] = function

    def __getstate__(self):
        """
        Serialize `lambda` functions when `dill` is available.

        If this doesn't cover your custom DataPipe's use case, consider writing custom methods for
        `__getstate__` and `__setstate__`, or use `pickle.dumps` for serialization.
        """
        state = self.__dict__
        if MapDataPipe.getstate_hook is not None:
            return MapDataPipe.getstate_hook(state)
        return state

    def __reduce_ex__(self, *args, **kwargs):
        if MapDataPipe.reduce_ex_hook is not None:
            try:
                return MapDataPipe.reduce_ex_hook(self)
            except NotImplementedError:
                pass
        return super().__reduce_ex__(*args, **kwargs)

    @classmethod
    def set_getstate_hook(cls, hook_fn):
        if MapDataPipe.getstate_hook is not None and hook_fn is not None:
            raise RuntimeError("Attempt to override existing getstate_hook")
        MapDataPipe.getstate_hook = hook_fn

    @classmethod
    def set_reduce_ex_hook(cls, hook_fn):
        if MapDataPipe.reduce_ex_hook is not None and hook_fn is not None:
            raise RuntimeError("Attempt to override existing reduce_ex_hook")
        MapDataPipe.reduce_ex_hook = hook_fn

    def __repr__(self):
        if self.repr_hook is not None:
            return self.repr_hook(self)
        # Instead of showing <torch. ... .MapperMapDataPipe object at 0x.....>, return the class name
        return str(self.__class__.__qualname__)

    def __str__(self):
        if self.str_hook is not None:
            return self.str_hook(self)
        # Instead of showing <torch. ... .MapperMapDataPipe object at 0x.....>, return the class name
        return str(self.__class__.__qualname__)

    def __dir__(self):
        # for auto-completion in a REPL (e.g. Jupyter notebook)
        return list(super().__dir__()) + list(self.functions.keys())


class _DataPipeSerializationWrapper:
    def __init__(self, datapipe):
        self._datapipe = datapipe

    def __getstate__(self):
        use_dill = False
        try:
            value = pickle.dumps(self._datapipe)
        except Exception:
            if HAS_DILL:
                value = dill.dumps(self._datapipe)
                use_dill = True
            else:
                raise
        return (value, use_dill)

    def __setstate__(self, state):
        value, use_dill = state
        if use_dill:
            self._datapipe = dill.loads(value)
        else:
            self._datapipe = pickle.loads(value)

    def __len__(self):
        try:
            return len(self._datapipe)
        except Exception as e:
            raise TypeError(
                f"{type(self).__name__} instance doesn't have valid length"
            ) from e


class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe):
    def __init__(self, datapipe: IterDataPipe[_T_co]):
        super().__init__(datapipe)
        self._datapipe_iter: Optional[Iterator[_T_co]] = None

    def __iter__(self) -> "_IterDataPipeSerializationWrapper":
        self._datapipe_iter = iter(self._datapipe)
        return self

    def __next__(self) -> _T_co:  # type: ignore[type-var]
        assert self._datapipe_iter is not None
        return next(self._datapipe_iter)


class _MapDataPipeSerializationWrapper(_DataPipeSerializationWrapper, MapDataPipe):
    def __getitem__(self, idx):
        return self._datapipe[idx]
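

# ---------------------------------------------------------------------------
# Illustrative sketch only, not part of the module API: a minimal demo of the
# functional-form dispatch implemented above. A DataPipe class registered via
# ``IterDataPipe.register_datapipe_as_function`` becomes reachable as a method
# on every ``IterDataPipe`` instance through ``__getattr__``, just like the
# built-in ``.map``/``.filter`` forms. The names ``_ExampleRange``,
# ``_ExampleRepeat``, and ``"example_repeat"`` are hypothetical and exist only
# for this demo, which runs only when the file is executed directly.
if __name__ == "__main__":

    class _ExampleRange(IterDataPipe):
        """Trivial source DataPipe yielding 0..4."""

        def __iter__(self):
            yield from range(5)

    class _ExampleRepeat(IterDataPipe):
        """Yields every element of ``source_dp`` ``times`` times."""

        def __init__(self, source_dp, times: int = 2) -> None:
            self.source_dp = source_dp
            self.times = times

        def __iter__(self):
            for item in self.source_dp:
                for _ in range(self.times):
                    yield item

    # Registration makes ``.example_repeat(...)`` available on all IterDataPipes.
    IterDataPipe.register_datapipe_as_function("example_repeat", _ExampleRepeat)

    dp = _ExampleRange().example_repeat(times=2)  # functional form, via __getattr__
    print(list(dp))  # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]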