import functools
import pickle
from typing import Callable, Dict, Iterable, Iterator, List, Optional, TypeVar

from torch.utils._import_utils import import_dill
from torch.utils.data.datapipes._hook_iterator import _SnapshotState
from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta
from torch.utils.data.datapipes.utils.common import (
    _deprecation_warning,
    _iter_deprecated_functional_names,
    _map_deprecated_functional_names,
)
from torch.utils.data.dataset import Dataset, IterableDataset


dill = import_dill()
HAS_DILL = dill is not None

__all__ = [
    "DataChunk",
    "DFIterDataPipe",
    "IterDataPipe",
    "MapDataPipe",
]


_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)

UNTRACABLE_DATAFRAME_PIPES = [
    "batch",  # As it returns DataChunks
    "groupby",  # As it returns DataChunks
    "_dataframes_as_tuples",  # As it unpacks DF
34    "trace_as_dataframe",  # As it used to mark DF for tracing
]


class DataChunk(List[_T]):
    def __init__(self, items: Iterable[_T]) -> None:
        items = list(items)
        super().__init__(items)
        self.items = items

    def as_str(self, indent: str = "") -> str:
        return indent + "[" + ", ".join(str(i) for i in iter(self)) + "]"

    def __iter__(self) -> Iterator[_T]:
        yield from super().__iter__()

    def raw_iterator(self) -> Iterator[_T]:
        yield from self.items

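# Example (editor's sketch, not part of the original module): ``DataChunk``
# above acts like a plain list of its items, but also keeps the raw ``items``
# around so ``raw_iterator`` can yield them without the list wrapper:
#
#     chunk = DataChunk([1, 2, 3])
#     chunk.as_str(indent="  ")       # '  [1, 2, 3]'
#     list(chunk.raw_iterator())      # [1, 2, 3]
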
class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
    r"""
    Iterable-style DataPipe.

    All DataPipes that represent an iterable of data samples should subclass this.
    This style of DataPipe is particularly useful when data come from a stream, or
    when the number of samples is too large to fit them all in memory. An ``IterDataPipe``
    is lazily initialized and its elements are computed only when ``next()`` is called
    on the iterator of an ``IterDataPipe``.

    All subclasses should overwrite :meth:`__iter__`, which should return an
    iterator of samples in this DataPipe. Calling ``__iter__`` of an ``IterDataPipe``
    automatically invokes its method ``reset()``, which by default performs no operation.
    When writing a custom ``IterDataPipe``, users should override ``reset()`` if necessary.
    Common usages include resetting buffers, pointers, and various state variables within
    the custom ``IterDataPipe``.

    Note:
        Only `one` iterator can be valid for each ``IterDataPipe`` at a time,
        and the creation of a second iterator will invalidate the first one. This
        constraint is necessary because some ``IterDataPipe`` implementations have
        internal buffers whose state can become invalid if there are multiple iterators.
        The code example below shows how this constraint looks in practice.
        If you have any feedback related to this constraint, please see
        `GitHub IterDataPipe Single Iterator Issue`_.

    These DataPipes can be invoked in two ways: using the class constructor, or by
    applying their functional form onto an existing ``IterDataPipe`` (recommended,
    available to most but not all DataPipes). You can chain multiple ``IterDataPipe``
    together to form a pipeline that performs multiple operations in succession.

    .. _GitHub IterDataPipe Single Iterator Issue:
        https://github.com/pytorch/data/issues/45

    Note:
        When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
        item in the DataPipe will be yielded from the :class:`~torch.utils.data.DataLoader`
        iterator. When :attr:`num_workers > 0`, each worker process will have a
        different copy of the DataPipe object, so it is often desired to configure
        each copy independently to avoid having duplicate data returned from the
        workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
        process, returns information about the worker. It can be used in either the
        dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
        :attr:`worker_init_fn` option to modify each copy's behavior.

    Examples:
        General Usage:
            >>> # xdoctest: +SKIP
            >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
            >>> dp = IterableWrapper(range(10))
            >>> map_dp_1 = Mapper(dp, lambda x: x + 1)  # Using class constructor
            >>> map_dp_2 = dp.map(lambda x: x + 1)  # Using functional form (recommended)
            >>> list(map_dp_1)
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
            >>> list(map_dp_2)
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
            >>> filter_dp = map_dp_1.filter(lambda x: x % 2 == 0)
            >>> list(filter_dp)
            [2, 4, 6, 8, 10]
        Single Iterator Constraint Example:
            >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
            >>> source_dp = IterableWrapper(range(10))
            >>> it1 = iter(source_dp)
            >>> list(it1)
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
            >>> it1 = iter(source_dp)
            >>> it2 = iter(source_dp)  # The creation of a new iterator invalidates `it1`
            >>> next(it2)
            0
            >>> next(it1)  # Further usage of `it1` will raise a `RuntimeError`
120    """
121
122    functions: Dict[str, Callable] = {}
123    reduce_ex_hook: Optional[Callable] = None
124    getstate_hook: Optional[Callable] = None
125    str_hook: Optional[Callable] = None
126    repr_hook: Optional[Callable] = None
127    _valid_iterator_id: Optional[int] = None
128    _number_of_samples_yielded: int = 0
129    _snapshot_state: _SnapshotState = _SnapshotState.NotStarted
130    _fast_forward_iterator: Optional[Iterator] = None
131
132    def __iter__(self) -> Iterator[_T_co]:
133        return self
134
    def __getattr__(self, attribute_name):
        if attribute_name in IterDataPipe.functions:
            if attribute_name in _iter_deprecated_functional_names:
                kwargs = _iter_deprecated_functional_names[attribute_name]
                _deprecation_warning(**kwargs)
            f = IterDataPipe.functions[attribute_name]
            function = functools.partial(f, self)
            functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",))
            return function
        else:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{attribute_name}'"
            )

    @classmethod
    def register_function(cls, function_name, function):
        cls.functions[function_name] = function

    @classmethod
    def register_datapipe_as_function(
        cls, function_name, cls_to_register, enable_df_api_tracing=False
    ):
        if function_name in cls.functions:
            raise Exception(  # noqa: TRY002
                f"Unable to add DataPipe function name {function_name} as it is already taken"
            )

        def class_function(cls, enable_df_api_tracing, source_dp, *args, **kwargs):
            result_pipe = cls(source_dp, *args, **kwargs)
            if isinstance(result_pipe, IterDataPipe):
                if enable_df_api_tracing or isinstance(source_dp, DFIterDataPipe):
                    if function_name not in UNTRACABLE_DATAFRAME_PIPES:
                        result_pipe = result_pipe.trace_as_dataframe()

            return result_pipe

        function = functools.partial(
            class_function, cls_to_register, enable_df_api_tracing
        )
        functools.update_wrapper(
            wrapper=function, wrapped=cls_to_register, assigned=("__doc__",)
        )
        cls.functions[function_name] = function

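    # Example (editor's sketch): ``register_datapipe_as_function`` is normally
    # reached through the ``functional_datapipe`` decorator exported by
    # ``torch.utils.data``. Registering a DataPipe class under a functional
    # name makes it chainable on any ``IterDataPipe`` via ``__getattr__`` above.
    # ``my_filter`` is a hypothetical name used only for illustration:
    #
    #     from torch.utils.data import functional_datapipe
    #
    #     @functional_datapipe("my_filter")
    #     class MyFilterIterDataPipe(IterDataPipe):
    #         def __init__(self, source_dp, predicate):
    #             self.source_dp = source_dp
    #             self.predicate = predicate
    #
    #         def __iter__(self):
    #             for x in self.source_dp:
    #                 if self.predicate(x):
    #                     yield x
    #
    #     # `dp.my_filter(pred)` now resolves to `MyFilterIterDataPipe(dp, pred)`.
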
    def __getstate__(self):
        """
        Serialize `lambda` functions when `dill` is available.

        If this doesn't cover your custom DataPipe's use case, consider writing custom methods for
        `__getstate__` and `__setstate__`, or use `pickle.dumps` for serialization.
        """
        state = self.__dict__
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(state)
        return state

    def __reduce_ex__(self, *args, **kwargs):
        if IterDataPipe.reduce_ex_hook is not None:
            try:
                return IterDataPipe.reduce_ex_hook(self)
            except NotImplementedError:
                pass
        return super().__reduce_ex__(*args, **kwargs)

    @classmethod
    def set_getstate_hook(cls, hook_fn):
        if IterDataPipe.getstate_hook is not None and hook_fn is not None:
            raise RuntimeError("Attempt to override existing getstate_hook")
        IterDataPipe.getstate_hook = hook_fn

    @classmethod
    def set_reduce_ex_hook(cls, hook_fn):
        if IterDataPipe.reduce_ex_hook is not None and hook_fn is not None:
            raise RuntimeError("Attempt to override existing reduce_ex_hook")
        IterDataPipe.reduce_ex_hook = hook_fn

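    # Example (editor's sketch): the class-level hooks above let an external
    # framework swap in its own serialization. `_reduce_with_dill` below is a
    # hypothetical helper; raising NotImplementedError makes `__reduce_ex__`
    # fall back to the default pickling protocol:
    #
    #     def _reduce_with_dill(obj):
    #         if not HAS_DILL:
    #             raise NotImplementedError
    #         return dill.loads, (dill.dumps(obj),)
    #
    #     IterDataPipe.set_reduce_ex_hook(_reduce_with_dill)
    #     IterDataPipe.set_reduce_ex_hook(None)  # passing None clears the hook
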
    def __repr__(self):
        if self.repr_hook is not None:
            return self.repr_hook(self)
        # Instead of showing <torch. ... .MapperIterDataPipe object at 0x.....>, return the class name
        return str(self.__class__.__qualname__)

    def __str__(self):
        if self.str_hook is not None:
            return self.str_hook(self)
        # Instead of showing <torch. ... .MapperIterDataPipe object at 0x.....>, return the class name
        return str(self.__class__.__qualname__)

    def __dir__(self):
        # for auto-completion in a REPL (e.g. Jupyter notebook)
        return list(super().__dir__()) + list(self.functions.keys())

    def reset(self) -> None:
        r"""
        Reset the `IterDataPipe` to the initial state.

        By default, this is a no-op. Subclasses of `IterDataPipe` may, depending on
        their functionality, override this method with an implementation that clears
        buffers and resets pointers of the DataPipe.
        The `reset` method is always called when `__iter__` is called, as part of `hook_iterator`.
        """
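
# Example (editor's sketch): a custom ``IterDataPipe`` with internal state that
# overrides ``reset`` so each fresh iterator starts from a clean slate.
# ``RunningSumIterDataPipe`` is hypothetical:
#
#     class RunningSumIterDataPipe(IterDataPipe):
#         def __init__(self, source_dp):
#             self.source_dp = source_dp
#             self._total = 0
#
#         def reset(self):
#             self._total = 0  # invoked automatically when __iter__ is called
#
#         def __iter__(self):
#             for x in self.source_dp:
#                 self._total += x
#                 yield self._total
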
class DFIterDataPipe(IterDataPipe):
    def _is_dfpipe(self):
        return True


class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta):
    r"""
    Map-style DataPipe.

    All datasets that represent a map from keys to data samples should subclass this.
    Subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given, unique key. Subclasses can also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    These DataPipes can be invoked in two ways: using the class constructor, or by
    applying their functional form onto an existing `MapDataPipe` (recommended,
    available to most but not all DataPipes).

    Note:
        :class:`~torch.utils.data.DataLoader` by default constructs an index
        sampler that yields integral indices. To make it work with a map-style
        DataPipe with non-integral indices/keys, a custom sampler must be provided.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper, Mapper
        >>> dp = SequenceWrapper(range(10))
        >>> map_dp_1 = dp.map(lambda x: x + 1)  # Using functional form (recommended)
        >>> list(map_dp_1)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        >>> map_dp_2 = Mapper(dp, lambda x: x + 1)  # Using class constructor
        >>> list(map_dp_2)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        >>> batch_dp = map_dp_1.batch(batch_size=2)
        >>> list(batch_dp)
        [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
    """

    functions: Dict[str, Callable] = {}
    reduce_ex_hook: Optional[Callable] = None
    getstate_hook: Optional[Callable] = None
    str_hook: Optional[Callable] = None
    repr_hook: Optional[Callable] = None

    def __getattr__(self, attribute_name):
        if attribute_name in MapDataPipe.functions:
            if attribute_name in _map_deprecated_functional_names:
                kwargs = _map_deprecated_functional_names[attribute_name]
                _deprecation_warning(**kwargs)
            f = MapDataPipe.functions[attribute_name]
            function = functools.partial(f, self)
            functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",))
            return function
        else:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{attribute_name}'"
            )

    @classmethod
    def register_function(cls, function_name, function):
        cls.functions[function_name] = function

    @classmethod
    def register_datapipe_as_function(cls, function_name, cls_to_register):
        if function_name in cls.functions:
            raise Exception(  # noqa: TRY002
                f"Unable to add DataPipe function name {function_name} as it is already taken"
            )

        def class_function(cls, source_dp, *args, **kwargs):
            result_pipe = cls(source_dp, *args, **kwargs)
            return result_pipe

        function = functools.partial(class_function, cls_to_register)
        functools.update_wrapper(
            wrapper=function, wrapped=cls_to_register, assigned=("__doc__",)
        )
        cls.functions[function_name] = function

    def __getstate__(self):
        """
        Serialize `lambda` functions when `dill` is available.

        If this doesn't cover your custom DataPipe's use case, consider writing custom methods for
        `__getstate__` and `__setstate__`, or use `pickle.dumps` for serialization.
        """
        state = self.__dict__
        if MapDataPipe.getstate_hook is not None:
            return MapDataPipe.getstate_hook(state)
        return state

    def __reduce_ex__(self, *args, **kwargs):
        if MapDataPipe.reduce_ex_hook is not None:
            try:
                return MapDataPipe.reduce_ex_hook(self)
            except NotImplementedError:
                pass
        return super().__reduce_ex__(*args, **kwargs)

    @classmethod
    def set_getstate_hook(cls, hook_fn):
        if MapDataPipe.getstate_hook is not None and hook_fn is not None:
            raise RuntimeError("Attempt to override existing getstate_hook")
        MapDataPipe.getstate_hook = hook_fn

    @classmethod
    def set_reduce_ex_hook(cls, hook_fn):
        if MapDataPipe.reduce_ex_hook is not None and hook_fn is not None:
            raise RuntimeError("Attempt to override existing reduce_ex_hook")
        MapDataPipe.reduce_ex_hook = hook_fn

    def __repr__(self):
        if self.repr_hook is not None:
            return self.repr_hook(self)
        # Instead of showing <torch. ... .MapperMapDataPipe object at 0x.....>, return the class name
        return str(self.__class__.__qualname__)

    def __str__(self):
        if self.str_hook is not None:
            return self.str_hook(self)
        # Instead of showing <torch. ... .MapperMapDataPipe object at 0x.....>, return the class name
        return str(self.__class__.__qualname__)

    def __dir__(self):
        # for auto-completion in a REPL (e.g. Jupyter notebook)
        return list(super().__dir__()) + list(self.functions.keys())
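
# Example (editor's sketch): a minimal custom ``MapDataPipe`` only needs
# ``__getitem__`` and, for most samplers, ``__len__``. ``SquaresMapDataPipe``
# is hypothetical:
#
#     class SquaresMapDataPipe(MapDataPipe):
#         def __init__(self, size):
#             self.size = size
#
#         def __getitem__(self, idx):
#             if not 0 <= idx < self.size:
#                 raise IndexError(idx)
#             return idx * idx
#
#         def __len__(self):
#             return self.size
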
class _DataPipeSerializationWrapper:
    def __init__(self, datapipe):
        self._datapipe = datapipe

    def __getstate__(self):
        use_dill = False
        try:
            value = pickle.dumps(self._datapipe)
        except Exception:
            if HAS_DILL:
                value = dill.dumps(self._datapipe)
                use_dill = True
            else:
                raise
        return (value, use_dill)

    def __setstate__(self, state):
        value, use_dill = state
        if use_dill:
            self._datapipe = dill.loads(value)
        else:
            self._datapipe = pickle.loads(value)

    def __len__(self):
        try:
            return len(self._datapipe)
        except Exception as e:
            raise TypeError(
                f"{type(self).__name__} instance doesn't have valid length"
            ) from e
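
# Example (editor's sketch): the ``__getstate__`` above exists because DataPipe
# graphs often capture lambdas, which the stdlib ``pickle`` rejects; it then
# falls back to ``dill`` when available. With a hypothetical ``source_dp``:
#
#     wrapped = _DataPipeSerializationWrapper(source_dp.map(lambda x: x + 1))
#     value, used_dill = wrapped.__getstate__()
#     # `used_dill` is True only when plain pickling raised and dill stepped in
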
class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe):
    def __init__(self, datapipe: IterDataPipe[_T_co]):
        super().__init__(datapipe)
        self._datapipe_iter: Optional[Iterator[_T_co]] = None

    def __iter__(self) -> "_IterDataPipeSerializationWrapper":
        self._datapipe_iter = iter(self._datapipe)
        return self

    def __next__(self) -> _T_co:  # type: ignore[type-var]
        assert self._datapipe_iter is not None
        return next(self._datapipe_iter)


class _MapDataPipeSerializationWrapper(_DataPipeSerializationWrapper, MapDataPipe):
    def __getitem__(self, idx):
        return self._datapipe[idx]
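
# Example (editor's sketch): round-tripping a wrapped pipe through ``pickle``.
# ``IterableWrapper`` comes from ``torchdata.datapipes.iter``; serializing the
# lambda here relies on the ``dill`` fallback in ``__getstate__`` above, so it
# assumes ``dill`` is installed:
#
#     import pickle
#     from torchdata.datapipes.iter import IterableWrapper
#
#     dp = IterableWrapper(range(3)).map(lambda x: x * 10)
#     wrapped = _IterDataPipeSerializationWrapper(dp)
#     restored = pickle.loads(pickle.dumps(wrapped))
#     assert list(restored) == [0, 10, 20]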