# mypy: allow-untyped-defs
import random
from typing import Dict, Iterator, List, Optional, Sized, Tuple, Type, TypeVar

import torch
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.sampler import Sampler, SequentialSampler


__all__ = [
    "SamplerIterDataPipe",
    "ShufflerIterDataPipe",
]


_T_co = TypeVar("_T_co", covariant=True)


class SamplerIterDataPipe(IterDataPipe[_T_co]):
    r"""
    Generate sample elements using the provided ``Sampler`` (defaults to :class:`SequentialSampler`).

    Args:
        datapipe: IterDataPipe to sample from
        sampler: Sampler class to generate sample elements from the input DataPipe.
            Default is :class:`SequentialSampler` for IterDataPipe
    """

    datapipe: IterDataPipe
    sampler: Sampler

    def __init__(
        self,
        datapipe: IterDataPipe,
        sampler: Type[Sampler] = SequentialSampler,
        sampler_args: Optional[Tuple] = None,
        sampler_kwargs: Optional[Dict] = None,
    ) -> None:
        assert isinstance(
            datapipe, Sized
        ), "Sampler class requires the input datapipe to implement `__len__`"
        super().__init__()
        self.datapipe = datapipe
        self.sampler_args = () if sampler_args is None else sampler_args
        self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs
        # https://github.com/python/mypy/pull/9629 will solve this
        self.sampler = sampler(*self.sampler_args, data_source=self.datapipe, **self.sampler_kwargs)  # type: ignore[misc]

    def __iter__(self) -> Iterator[_T_co]:
        return iter(self.sampler)

    def __len__(self) -> int:
        # The input datapipe was verified `Sized` in `__init__`, but the
        # sampler itself may not define `__len__`.
        if isinstance(self.sampler, Sized):
            return len(self.sampler)
        raise TypeError(f"{type(self).__name__} instance doesn't have valid length")


@functional_datapipe("shuffle")
class ShufflerIterDataPipe(IterDataPipe[_T_co]):
    r"""
    Shuffle the input DataPipe with a buffer (functional name: ``shuffle``).

    A buffer of ``buffer_size`` elements is filled from the datapipe first. Then,
    each item is yielded from the buffer by reservoir sampling via the iterator.

    ``buffer_size`` is required to be larger than ``0``. For ``buffer_size == 1``, the
    datapipe is not shuffled. In order to fully shuffle all elements from the datapipe,
    ``buffer_size`` is required to be greater than or equal to the size of the datapipe.

    When it is used with :class:`torch.utils.data.DataLoader`, the way to set up
    the random seed differs based on :attr:`num_workers`.

    For single-process mode (:attr:`num_workers == 0`), the random seed is set before
    the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process
    mode (:attr:`num_workers > 0`), ``worker_init_fn`` is used to set up a random seed
    for each worker process.

    Args:
        datapipe: The IterDataPipe being shuffled
        buffer_size: The buffer size for shuffling (defaults to ``10000``)
        unbatch_level: Specifies if it is necessary to unbatch source data before
            applying the shuffle

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp = IterableWrapper(range(10))
        >>> shuffle_dp = dp.shuffle()
        >>> list(shuffle_dp)
        [0, 4, 1, 6, 3, 2, 9, 5, 7, 8]
    """

    datapipe: IterDataPipe[_T_co]
    buffer_size: int
    _buffer: List[_T_co]
    _enabled: bool
    _seed: Optional[int]
    _rng: random.Random

    def __init__(
        self,
        datapipe: IterDataPipe[_T_co],
        *,
        buffer_size: int = 10000,
        unbatch_level: int = 0,
    ) -> None:
        super().__init__()
        # TODO: Performance optimization
        #   buffer can be a fixed size, removing expensive `append()` and `len()` operations
        self._buffer: List[_T_co] = []
        assert buffer_size > 0, "buffer_size should be larger than 0"
        if unbatch_level == 0:
            self.datapipe = datapipe
        else:
            self.datapipe = datapipe.unbatch(unbatch_level=unbatch_level)
        self.buffer_size = buffer_size
        self._enabled = True
        self._seed = None
        self._rng = random.Random()

    def set_shuffle(self, shuffle=True):
        self._enabled = shuffle
        return self

    def set_seed(self, seed: int):
        self._seed = seed
        return self

    def __iter__(self) -> Iterator[_T_co]:
        if not self._enabled:
            yield from self.datapipe
        else:
            for x in self.datapipe:
                if len(self._buffer) == self.buffer_size:
                    # Buffer is full: evict a random element and replace it
                    # with the incoming one (reservoir-style sampling).
                    idx = self._rng.randint(0, len(self._buffer) - 1)
                    val, self._buffer[idx] = self._buffer[idx], x
                    yield val
                else:
                    self._buffer.append(x)
            # Source exhausted: drain the remaining buffer in random order.
            while self._buffer:
                idx = self._rng.randint(0, len(self._buffer) - 1)
                yield self._buffer.pop(idx)

    def __len__(self) -> int:
        if isinstance(self.datapipe, Sized):
            return len(self.datapipe)
        raise TypeError(f"{type(self).__name__} instance doesn't have valid length")

    def reset(self) -> None:
        self._buffer = []
        if self._enabled:
            if self._seed is None:
                # No seed was provided via `set_seed`; draw a fresh one.
                self._seed = int(torch.empty((), dtype=torch.int64).random_().item())
            self._rng.seed(self._seed)
            self._seed = None

    def __getstate__(self):
        state = (
            self.datapipe,
            self.buffer_size,
            self._enabled,
            self._seed,
            self._buffer,
            self._rng.getstate(),
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        )
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(state)
        return state

    def __setstate__(self, state):
        (
            self.datapipe,
            self.buffer_size,
            self._enabled,
            self._seed,
            self._buffer,
            rng_state,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        ) = state
        self._rng = random.Random()
        self._rng.setstate(rng_state)

    def __del__(self):
        self._buffer.clear()
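
# ---------------------------------------------------------------------------
# A minimal usage sketch, not part of the library API. It assumes
# `IterableWrapper` (available from `torch.utils.data.datapipes.iter`, and
# used via `torchdata` in the docstring example above) to wrap a plain
# iterable into a `Sized` IterDataPipe; everything else is from this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data.datapipes.iter import IterableWrapper

    # The default SequentialSampler yields indices 0..len-1, which for
    # `range(10)` coincide with the wrapped values.
    sampled = SamplerIterDataPipe(IterableWrapper(range(10)))
    print(list(sampled))  # [0, 1, 2, ..., 9]

    # Buffered shuffle via the functional form; `set_seed` makes the
    # permutation reproducible for the next iteration over the pipe.
    shuffled = IterableWrapper(range(10)).shuffle(buffer_size=5).set_seed(42)
    print(list(shuffled))  # a fixed permutation of 0..9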