xref: /aosp_15_r20/external/pytorch/torch/utils/data/datapipes/iter/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import copy
3import warnings
4
5from torch.utils.data.datapipes.datapipe import IterDataPipe
6
7
# Public names exported by ``from <module> import *``: only the wrapper class.
__all__ = ["IterableWrapperIterDataPipe"]
9
10
class IterableWrapperIterDataPipe(IterDataPipe):
    r"""
    Expose an arbitrary iterable object as an IterDataPipe.

    Args:
        iterable: The iterable object whose items this DataPipe yields
        deepcopy: Whether every iterator works on its own deep copy of the
            input iterable. The copy is taken lazily, when the first element
            is requested from the iterator returned by ``iter()``.

    .. note::
        When ``deepcopy`` is explicitly set to ``False``, the wrapped object
        itself is iterated, so users should ensure that the data pipeline
        doesn't perform in-place operations on it; otherwise data may be
        inconsistent across iterations.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp = IterableWrapper(range(10))
        >>> list(dp)
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    """

    def __init__(self, iterable, deepcopy=True):
        # Only references are stored here; any copying is deferred to iteration.
        self.deepcopy = deepcopy
        self.iterable = iterable

    def __iter__(self):
        wrapped = self.iterable

        # Copying disabled: iterate the wrapped object directly.
        if not self.deepcopy:
            yield from wrapped
            return

        try:
            snapshot = copy.deepcopy(wrapped)
        except TypeError:
            # The object refused to be deep-copied, so fall back to the
            # original. In-place operations downstream will then touch the
            # wrapped iterable itself, and a later pass over this DataPipe
            # would yield the modified items.
            warnings.warn(
                "The input iterable can not be deepcopied, "
                "please be aware of in-place modification would affect source data."
            )
            snapshot = wrapped

        yield from snapshot

    def __len__(self):
        # Delegates to the wrapped object; the length is whatever ``len()``
        # reports for it (and ``len()`` fails for objects without a length).
        wrapped = self.iterable
        return len(wrapped)
55