xref: /aosp_15_r20/external/pytorch/torch/utils/data/datapipes/iter/combinatorics.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import random
3from typing import Dict, Iterator, List, Optional, Sized, Tuple, Type, TypeVar
4
5import torch
6from torch.utils.data.datapipes._decorator import functional_datapipe
7from torch.utils.data.datapipes.datapipe import IterDataPipe
8from torch.utils.data.sampler import Sampler, SequentialSampler
9
10
11__all__ = [
12    "SamplerIterDataPipe",
13    "ShufflerIterDataPipe",
14]
15
16
17_T_co = TypeVar("_T_co", covariant=True)
18
19
class SamplerIterDataPipe(IterDataPipe[_T_co]):
    r"""
    Yield elements produced by a ``Sampler`` built over the input DataPipe.

    The default sampler is :class:`SequentialSampler`, in which case the
    yielded values are the indices ``0 .. len(datapipe) - 1``.

    Args:
        datapipe: IterDataPipe to sample from; it must be ``Sized``
        sampler: Sampler class (not an instance) to construct; the input
            datapipe is handed to it as ``data_source``.
            Default is :class:`SequentialSampler` for IterDataPipe
        sampler_args: extra positional arguments forwarded to ``sampler``
        sampler_kwargs: extra keyword arguments forwarded to ``sampler``
    """

    datapipe: IterDataPipe
    sampler: Sampler

    def __init__(
        self,
        datapipe: IterDataPipe,
        sampler: Type[Sampler] = SequentialSampler,
        sampler_args: Optional[Tuple] = None,
        sampler_kwargs: Optional[Dict] = None,
    ) -> None:
        # Samplers need `len(data_source)` to know how many samples to draw.
        assert isinstance(
            datapipe, Sized
        ), "Sampler class requires input datapipe implemented `__len__`"
        super().__init__()
        self.datapipe = datapipe
        self.sampler_args = sampler_args if sampler_args is not None else ()
        self.sampler_kwargs = sampler_kwargs if sampler_kwargs is not None else {}
        # https://github.com/python/mypy/pull/9629 will solve
        self.sampler = sampler(*self.sampler_args, data_source=self.datapipe, **self.sampler_kwargs)  # type: ignore[misc]

    def __iter__(self) -> Iterator[_T_co]:
        # Iteration is delegated entirely to the wrapped sampler.
        yield from self.sampler

    def __len__(self) -> int:
        # The input datapipe was verified `Sized`, but the sampler itself may not be.
        if not isinstance(self.sampler, Sized):
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
        return len(self.sampler)
58
59
@functional_datapipe("shuffle")
class ShufflerIterDataPipe(IterDataPipe[_T_co]):
    r"""
    Shuffle the input DataPipe with a buffer (functional name: ``shuffle``).

    Elements are first accumulated into a buffer of at most ``buffer_size``
    items; once the buffer is full, every incoming element evicts (and yields)
    a uniformly chosen buffered element, i.e. reservoir-style sampling via the
    iterator. When the source is exhausted, the remaining buffered elements
    are yielded in random order.

    ``buffer_size`` must be larger than ``0``. With ``buffer_size == 1`` the
    datapipe is not shuffled at all; a full shuffle of the whole datapipe
    requires ``buffer_size`` to be at least the size of the datapipe.

    When it is used with :class:`torch.utils.data.DataLoader`, the way to set
    up the random seed depends on :attr:`num_workers`: for single-process
    loading (:attr:`num_workers == 0`) seed before creating the
    :class:`~torch.utils.data.DataLoader` in the main process; for
    multi-process loading (:attr:`num_worker > 0`) seed each worker via
    `worker_init_fn`.

    Args:
        datapipe: The IterDataPipe being shuffled
        buffer_size: The buffer size for shuffling (default to ``10000``)
        unbatch_level: Specifies if it is necessary to unbatch source data before
            applying the shuffle

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp = IterableWrapper(range(10))
        >>> shuffle_dp = dp.shuffle()
        >>> list(shuffle_dp)
        [0, 4, 1, 6, 3, 2, 9, 5, 7, 8]
    """

    datapipe: IterDataPipe[_T_co]
    buffer_size: int
    _buffer: List[_T_co]
    _enabled: bool
    _seed: Optional[int]
    _rng: random.Random

    def __init__(
        self,
        datapipe: IterDataPipe[_T_co],
        *,
        buffer_size: int = 10000,
        unbatch_level: int = 0,
    ) -> None:
        super().__init__()
        # TODO: Performance optimization
        #       buffer can be a fixed size and remove expensive `append()` and `len()` operations
        self._buffer: List[_T_co] = []
        assert buffer_size > 0, "buffer_size should be larger than 0"
        self.datapipe = (
            datapipe
            if unbatch_level == 0
            else datapipe.unbatch(unbatch_level=unbatch_level)
        )
        self.buffer_size = buffer_size
        self._enabled = True
        self._seed = None
        self._rng = random.Random()

    def set_shuffle(self, shuffle=True):
        # Toggle shuffling on/off; returns self for fluent chaining.
        self._enabled = shuffle
        return self

    def set_seed(self, seed: int):
        # The seed takes effect on the next `reset()` (start of the next epoch).
        self._seed = seed
        return self

    def __iter__(self) -> Iterator[_T_co]:
        if not self._enabled:
            # Shuffling disabled: pass elements straight through.
            yield from self.datapipe
            return
        for element in self.datapipe:
            if len(self._buffer) < self.buffer_size:
                # Warm-up: fill the reservoir before yielding anything.
                self._buffer.append(element)
            else:
                # Full buffer: swap the new element in for a random one and
                # yield the evicted element.
                slot = self._rng.randint(0, len(self._buffer) - 1)
                evicted = self._buffer[slot]
                self._buffer[slot] = element
                yield evicted
        # Source exhausted: drain the buffer in random order.
        while self._buffer:
            slot = self._rng.randint(0, len(self._buffer) - 1)
            yield self._buffer.pop(slot)

    def __len__(self) -> int:
        # Shuffling does not change the number of elements.
        if not isinstance(self.datapipe, Sized):
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
        return len(self.datapipe)

    def reset(self) -> None:
        self._buffer = []
        if not self._enabled:
            return
        if self._seed is None:
            # No explicit seed: derive a fresh one from torch's global RNG so
            # each epoch shuffles differently but stays reproducible via torch.
            self._seed = int(torch.empty((), dtype=torch.int64).random_().item())
        self._rng.seed(self._seed)
        # Consume the seed; the next reset draws a new one unless set_seed is called.
        self._seed = None

    def __getstate__(self):
        # NOTE: the tuple layout must stay in sync with __setstate__.
        state = (
            self.datapipe,
            self.buffer_size,
            self._enabled,
            self._seed,
            self._buffer,
            self._rng.getstate(),
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        )
        hook = IterDataPipe.getstate_hook
        return state if hook is None else hook(state)

    def __setstate__(self, state):
        # Mirror of __getstate__: restore plain fields, then rebuild the RNG
        # from its pickled internal state.
        (
            self.datapipe,
            self.buffer_size,
            self._enabled,
            self._seed,
            self._buffer,
            rng_state,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        ) = state
        self._rng = random.Random()
        self._rng.setstate(rng_state)

    def __del__(self):
        # Drop buffered references eagerly when the datapipe is collected.
        self._buffer.clear()
190