# mypy: allow-untyped-defs
import warnings
from collections import defaultdict
from typing import (
    Any,
    Callable,
    DefaultDict,
    Iterator,
    List,
    Optional,
    Sized,
    Type,
    TypeVar,
)

import torch.utils.data.datapipes.iter.sharding
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import DataChunk, IterDataPipe
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn


__all__ = [
    "BatcherIterDataPipe",
    "GrouperIterDataPipe",
    "UnBatcherIterDataPipe",
]


_T_co = TypeVar("_T_co", covariant=True)


def __getattr__(name: str):
    if name in ["SHARDING_PRIORITIES", "ShardingFilterIterDataPipe"]:
        warnings.warn(
            f"`{name}` from `torch.utils.data.datapipes.iter.grouping` is going to be removed in PyTorch 2.1. "
            f"Please use `{name}` from `torch.utils.data.datapipes.iter.sharding` instead.",
            category=FutureWarning,
            stacklevel=2,
        )

        return getattr(torch.utils.data.datapipes.iter.sharding, name)

    raise AttributeError(f"module {__name__} has no attribute {name}")


@functional_datapipe("batch")
class BatcherIterDataPipe(IterDataPipe[DataChunk]):
    r"""
    Creates mini-batches of data (functional name: ``batch``).

    An outer dimension is added: each batch contains exactly ``batch_size`` elements
    when ``drop_last`` is set to ``True``, while the last batch contains
    ``length % batch_size`` elements when ``drop_last`` is set to ``False``.

    Args:
        datapipe: Iterable DataPipe being batched
        batch_size: The size of each batch
        drop_last: Option to drop the last batch if it's not full
        wrapper_class: wrapper to apply onto each batch (type ``List``) before yielding,
            defaults to ``DataChunk``

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp = IterableWrapper(range(10))
        >>> dp = dp.batch(batch_size=3, drop_last=True)
        >>> list(dp)
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    datapipe: IterDataPipe
    batch_size: int
    drop_last: bool

    def __init__(
        self,
        datapipe: IterDataPipe,
        batch_size: int,
        drop_last: bool = False,
        wrapper_class: Type[DataChunk] = DataChunk,
    ) -> None:
        assert batch_size > 0, "Batch size is required to be larger than 0!"
        super().__init__()
        self.datapipe = datapipe
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.wrapper_class = wrapper_class

    def __iter__(self) -> Iterator[DataChunk]:
        batch: List = []
        for x in self.datapipe:
            batch.append(x)
            if len(batch) == self.batch_size:
                yield self.wrapper_class(batch)
                batch = []
        if len(batch) > 0:
            if not self.drop_last:
                yield self.wrapper_class(batch)

    def __len__(self) -> int:
        if isinstance(self.datapipe, Sized):
            if self.drop_last:
                return len(self.datapipe) // self.batch_size
            else:
                return (len(self.datapipe) + self.batch_size - 1) // self.batch_size
        else:
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")


@functional_datapipe("unbatch")
class UnBatcherIterDataPipe(IterDataPipe):
    r"""
    Undoes batching of data (functional name: ``unbatch``).

    In other words, it flattens the data up to the specified level within a batched DataPipe.

    Args:
        datapipe: Iterable DataPipe being un-batched
        unbatch_level: Defaults to ``1`` (only flattening the top level). If set to ``2``,
            it will flatten the top two levels, and ``-1`` will flatten the entire DataPipe.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> source_dp = IterableWrapper([[[0, 1], [2]], [[3, 4], [5]], [[6]]])
        >>> dp1 = source_dp.unbatch()
        >>> list(dp1)
        [[0, 1], [2], [3, 4], [5], [6]]
        >>> dp2 = source_dp.unbatch(unbatch_level=2)
        >>> list(dp2)
        [0, 1, 2, 3, 4, 5, 6]
    """

    def __init__(self, datapipe: IterDataPipe, unbatch_level: int = 1):
        self.datapipe = datapipe
        self.unbatch_level = unbatch_level

    def __iter__(self):
        for element in self.datapipe:
            yield from self._dive(element, unbatch_level=self.unbatch_level)

    def _dive(self, element, unbatch_level):
        if unbatch_level < -1:
            raise ValueError("unbatch_level must be -1 or >= 0")
        if unbatch_level == -1:
            if isinstance(element, (list, DataChunk)):
                for item in element:
                    yield from self._dive(item, unbatch_level=-1)
            else:
                yield element
        elif unbatch_level == 0:
            yield element
        else:
            if isinstance(element, (list, DataChunk)):
                for item in element:
                    yield from self._dive(item, unbatch_level=unbatch_level - 1)
            else:
                raise IndexError(
                    f"unbatch_level {self.unbatch_level} exceeds the depth of the DataPipe"
                )
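

# A quick sketch of full flattening with ``unbatch_level=-1``, which the docstring
# example above does not cover (illustrative only; ``IterableWrapper`` comes from
# ``torchdata.datapipes.iter`` as in the docstring). Nesting of any depth made of
# ``list``/``DataChunk`` is flattened recursively:
#
#     >>> source_dp = IterableWrapper([[[0, 1]], [2, [3]]])
#     >>> list(source_dp.unbatch(unbatch_level=-1))
#     [0, 1, 2, 3]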


@functional_datapipe("groupby")
class GrouperIterDataPipe(IterDataPipe[DataChunk]):
    r"""
    Groups data from the input IterDataPipe by keys from ``group_key_fn``, yielding a
    ``DataChunk`` with batch size up to ``group_size`` (functional name: ``groupby``).

    The samples are read sequentially from the source ``datapipe``, and a batch of samples belonging to the same group
    will be yielded as soon as the size of the batch reaches ``group_size``. When the buffer is full,
    the DataPipe will yield the largest batch with the same key, provided that its size is at least
    ``guaranteed_group_size``. If its size is smaller, it will be dropped if ``drop_remaining=True``,
    otherwise a ``RuntimeError`` will be raised.

    After iterating through the entirety of source ``datapipe``, everything not dropped due to the buffer capacity
    will be yielded from the buffer, even if the group sizes are smaller than ``guaranteed_group_size``.

    Args:
        datapipe: Iterable datapipe to be grouped
        group_key_fn: Function used to generate the group key from the data of the source datapipe
        keep_key: Option to yield the matching key along with the items in a tuple,
            resulting in ``(key, [items])``, otherwise returning ``[items]``
        buffer_size: The size of the buffer for ungrouped data
        group_size: The max size of each group; a batch is yielded as soon as it reaches this size
        guaranteed_group_size: The guaranteed minimum group size to be yielded in case the buffer is full
        drop_remaining: Specifies whether a group smaller than ``guaranteed_group_size`` will be dropped from the buffer
            when the buffer is full

    Example:
        >>> import os
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> def group_fn(file):
        ...     return os.path.basename(file).split(".")[0]
        >>> source_dp = IterableWrapper(["a.png", "b.png", "a.json", "b.json", "a.jpg", "c.json"])
        >>> dp0 = source_dp.groupby(group_key_fn=group_fn)
        >>> list(dp0)
        [['a.png', 'a.json', 'a.jpg'], ['b.png', 'b.json'], ['c.json']]
        >>> # A group is yielded as soon as its size equals `group_size`
        >>> dp1 = source_dp.groupby(group_key_fn=group_fn, group_size=2)
        >>> list(dp1)
        [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
        >>> # Scenario where the buffer is full, and group 'a' is yielded early since its size >= `guaranteed_group_size`
        >>> dp2 = source_dp.groupby(group_key_fn=group_fn, buffer_size=3, group_size=3, guaranteed_group_size=2)
        >>> list(dp2)
        [['a.png', 'a.json'], ['b.png', 'b.json'], ['a.jpg'], ['c.json']]
    """

    def __init__(
        self,
        datapipe: IterDataPipe[_T_co],
        group_key_fn: Callable[[_T_co], Any],
        *,
        keep_key: bool = False,
        buffer_size: int = 10000,
        group_size: Optional[int] = None,
        guaranteed_group_size: Optional[int] = None,
        drop_remaining: bool = False,
    ):
        _check_unpickable_fn(group_key_fn)
        self.datapipe = datapipe
        self.group_key_fn = group_key_fn

        self.keep_key = keep_key
        self.max_buffer_size = buffer_size
        self.buffer_elements: DefaultDict[Any, List] = defaultdict(list)
        self.curr_buffer_size = 0
        self.group_size = group_size
        self.guaranteed_group_size = None
        if group_size is not None and buffer_size is not None:
            assert 0 < group_size <= buffer_size
            self.guaranteed_group_size = group_size
        if guaranteed_group_size is not None:
            assert group_size is not None and 0 < guaranteed_group_size <= group_size
            self.guaranteed_group_size = guaranteed_group_size
        self.drop_remaining = drop_remaining
        self.wrapper_class = DataChunk
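
    # Eviction sketch (illustrative, using the constructor arguments above): with
    # ``buffer_size=3``, ``group_size=3``, ``guaranteed_group_size=2`` and the
    # stream 'a', 'a', 'b', inserting 'b' fills the buffer, so the biggest group
    # ['a', 'a'] (size 2 >= 2) is evicted and yielded early. A biggest group
    # smaller than ``guaranteed_group_size`` is instead dropped silently when
    # ``drop_remaining=True``, or raises ``RuntimeError`` otherwise.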

    def _remove_biggest_key(self):
        # Find the key currently holding the most buffered items.
        biggest_key = None
        biggest_size = 0
        result_to_yield = None
        for findkey in self.buffer_elements.keys():
            if len(self.buffer_elements[findkey]) > biggest_size:
                biggest_size = len(self.buffer_elements[findkey])
                biggest_key = findkey

        if (
            self.guaranteed_group_size is not None
            and biggest_size < self.guaranteed_group_size
            and not self.drop_remaining
        ):
            raise RuntimeError(
                "Failed to group items", str(self.buffer_elements[biggest_key])
            )

        if (
            self.guaranteed_group_size is None
            or biggest_size >= self.guaranteed_group_size
        ):
            result_to_yield = self.buffer_elements[biggest_key]

        self.curr_buffer_size -= biggest_size
        del self.buffer_elements[biggest_key]

        # Return the evicted group's key together with its items (``None`` items
        # means the group was dropped), so the caller can pair the correct key
        # with the yielded group when ``keep_key`` is set.
        return biggest_key, result_to_yield

    def __iter__(self):
        for x in self.datapipe:
            key = self.group_key_fn(x)

            self.buffer_elements[key].append(x)
            self.curr_buffer_size += 1

            if self.group_size is not None and self.group_size == len(
                self.buffer_elements[key]
            ):
                result: DataChunk[Any] = self.wrapper_class(self.buffer_elements[key])
                yield (key, result) if self.keep_key else result
                self.curr_buffer_size -= len(self.buffer_elements[key])
                del self.buffer_elements[key]

            if self.curr_buffer_size == self.max_buffer_size:
                evicted_key, result_to_yield = self._remove_biggest_key()
                if result_to_yield is not None:
                    result = self.wrapper_class(result_to_yield)
                    # Yield the evicted group with its own key rather than the key
                    # of the element that triggered the eviction.
                    yield (evicted_key, result) if self.keep_key else result

        # Flush whatever remains in the buffer once the source is exhausted.
        for key in tuple(self.buffer_elements.keys()):
            result = self.wrapper_class(self.buffer_elements.pop(key))
            self.curr_buffer_size -= len(result)
            yield (key, result) if self.keep_key else result
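
    # Sketch of ``keep_key=True`` (illustrative; reuses ``group_fn`` and
    # ``source_dp`` from the docstring example above). With the default buffer
    # settings no eviction occurs, so groups flush in insertion order:
    #
    #     >>> dp = source_dp.groupby(group_key_fn=group_fn, keep_key=True)
    #     >>> list(dp)
    #     [('a', ['a.png', 'a.json', 'a.jpg']), ('b', ['b.png', 'b.json']), ('c', ['c.json'])]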

    def reset(self) -> None:
        self.curr_buffer_size = 0
        self.buffer_elements = defaultdict(list)

    def __getstate__(self):
        state = (
            self.datapipe,
            self.group_key_fn,
            self.keep_key,
            self.max_buffer_size,
            self.group_size,
            self.guaranteed_group_size,
            self.drop_remaining,
            self.wrapper_class,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        )
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(state)
        return state

    def __setstate__(self, state):
        (
            self.datapipe,
            self.group_key_fn,
            self.keep_key,
            self.max_buffer_size,
            self.group_size,
            self.guaranteed_group_size,
            self.drop_remaining,
            self.wrapper_class,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        ) = state
        self.curr_buffer_size = 0
        self.buffer_elements = defaultdict(list)

    def __del__(self):
        self.buffer_elements.clear()
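

# ---------------------------------------------------------------------------
# End-to-end sketch (illustrative only, mirroring the docstring examples above;
# ``IterableWrapper`` comes from ``torchdata.datapipes.iter``). ``batch`` with the
# default ``drop_last=False`` keeps the short final batch, and ``unbatch`` undoes
# the batching:
#
#     >>> dp = IterableWrapper(range(7)).batch(batch_size=3)
#     >>> list(dp)
#     [[0, 1, 2], [3, 4, 5], [6]]
#     >>> list(dp.unbatch())
#     [0, 1, 2, 3, 4, 5, 6]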