# mypy: allow-untyped-defs
import copy as copymodule
import warnings
from abc import ABC, abstractmethod
from collections import deque
from typing import (
    Any,
    Callable,
    Deque,
    Iterator,
    List,
    Literal,
    Optional,
    Sized,
    Tuple,
    TypeVar,
)

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes._hook_iterator import _SnapshotState
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, StreamWrapper


__all__ = [
    "ConcaterIterDataPipe",
    "DemultiplexerIterDataPipe",
    "ForkerIterDataPipe",
    "MultiplexerIterDataPipe",
    "ZipperIterDataPipe",
]


_T_co = TypeVar("_T_co", covariant=True)


@functional_datapipe("concat")
class ConcaterIterDataPipe(IterDataPipe):
    r"""
    Concatenates multiple Iterable DataPipes (functional name: ``concat``).

    The resulting DataPipe will yield all the elements from the first input DataPipe before yielding from the subsequent ones.

    Args:
        datapipes: Iterable DataPipes being concatenated

    Example:
        >>> # xdoctest: +REQUIRES(module:torchdata)
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp1 = IterableWrapper(range(3))
        >>> dp2 = IterableWrapper(range(5))
        >>> list(dp1.concat(dp2))
        [0, 1, 2, 0, 1, 2, 3, 4]
    """

    datapipes: Tuple[IterDataPipe]

    def __init__(self, *datapipes: IterDataPipe):
        if len(datapipes) == 0:
            raise ValueError("Expected at least one DataPipe, but got nothing")
        if not all(isinstance(dp, IterDataPipe) for dp in datapipes):
            raise TypeError("Expected all inputs to be `IterDataPipe`")
        self.datapipes = datapipes  # type: ignore[assignment]

    def __iter__(self) -> Iterator:
        for dp in self.datapipes:
            yield from dp

    def __len__(self) -> int:
        if all(isinstance(dp, Sized) for dp in self.datapipes):
            return sum(len(dp) for dp in self.datapipes)
        else:
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")


@functional_datapipe("fork")
class ForkerIterDataPipe(IterDataPipe):
    r"""
    Creates multiple instances of the same Iterable DataPipe (functional name: ``fork``).

    Args:
        datapipe: Iterable DataPipe being copied
        num_instances: number of instances of the datapipe to create
        buffer_size: this restricts how far ahead the leading child DataPipe
            can read relative to the slowest child DataPipe.
            Defaults to ``1000``. Use ``-1`` for the unlimited buffer.
        copy: copy strategy to use for items yielded by each branch. Supported
            options are ``None`` for no copying, ``"shallow"`` for shallow object
            copies, and ``"deep"`` for deep object copies. Defaults to ``None``.

    Note:
        All branches of the forked pipeline return the identical object unless
        the copy parameter is supplied. If the object is mutable or contains
        mutable objects, changing them in one branch will affect all others.
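
        A minimal sketch of that sharing (assumptions: ``torchdata``'s
        ``IterableWrapper`` is available, as in the example below; illustrative
        only, hence skipped):

        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> source_dp = IterableWrapper([[0], [1]])
        >>> dp1, dp2 = source_dp.fork(num_instances=2)
        >>> items = list(dp1)
        >>> items[0].append(100)  # mutate an element already yielded by dp1
        >>> list(dp2)  # dp2 yields the very same objects, mutation included
        [[0, 100], [1]]
        >>> # With copy="shallow" or copy="deep", each branch instead receives
        >>> # its own (top-level or full) copies.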

    Example:
        >>> # xdoctest: +REQUIRES(module:torchdata)
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> source_dp = IterableWrapper(range(5))
        >>> dp1, dp2 = source_dp.fork(num_instances=2)
        >>> list(dp1)
        [0, 1, 2, 3, 4]
        >>> list(dp2)
        [0, 1, 2, 3, 4]
    """

    def __new__(
        cls,
        datapipe: IterDataPipe,
        num_instances: int,
        buffer_size: int = 1000,
        copy: Optional[Literal["shallow", "deep"]] = None,
    ):
        if num_instances < 1:
            raise ValueError(
                f"Expected `num_instances` larger than 0, but {num_instances} was found"
            )
        if num_instances == 1:
            # A single instance requires no fork; return the source unchanged.
            return datapipe
        container = _ForkerIterDataPipe(datapipe, num_instances, buffer_size, copy)  # type: ignore[abstract]
        return [_ChildDataPipe(container, i) for i in range(num_instances)]


class _ContainerTemplate(ABC):
    r"""Abstract class for container ``DataPipes``. The following four methods are required."""

    @abstractmethod
    def get_next_element_by_instance(self, instance_id: int):
        ...

    @abstractmethod
    def is_every_instance_exhausted(self) -> bool:
        ...

    @abstractmethod
    def reset(self) -> None:
        ...

    @abstractmethod
    def get_length_by_instance(self, instance_id: int):
        r"""Raise ``TypeError`` when no valid length is available; ``list(datapipe)`` tolerates this and falls back to plain iteration."""


def _no_op(x):
    return x
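

# A minimal sketch (illustrative, not used by the implementation) of how the
# three `copy` strategies dispatched in `_ForkerIterDataPipe.__init__` below
# differ for a nested mutable value:
#
#   >>> import copy
#   >>> item = [[1, 2]]
#   >>> copy.copy(item) is item            # shallow: a new outer container...
#   False
#   >>> copy.copy(item)[0] is item[0]      # ...but inner objects are shared
#   True
#   >>> copy.deepcopy(item)[0] is item[0]  # deep: inner objects duplicated too
#   False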


class _ForkerIterDataPipe(IterDataPipe, _ContainerTemplate):
    r"""
    Container to hold instance-specific information on behalf of ForkerIterDataPipe.

    It tracks the state of its child DataPipes, maintains the buffer, and yields the next value
    as requested by the child DataPipes.
    """

    def __init__(
        self,
        datapipe: IterDataPipe,
        num_instances: int,
        buffer_size: int = 1000,
        copy: Optional[Literal["shallow", "deep"]] = None,
    ):
        self.main_datapipe = datapipe
        self._datapipe_iterator: Optional[Iterator[Any]] = None
        self.num_instances = num_instances
        self.buffer: Deque = deque()
        self.buffer_size = buffer_size
        if self.buffer_size < 0:
            warnings.warn(
                "Unlimited buffer size is set for `fork`, "
                "please be aware of OOM at random places",
                UserWarning,
            )
        # Map the `copy` argument onto a callable applied to every yielded item.
        if copy is None:
            self.copy_fn = _no_op
        elif copy == "shallow":
            self.copy_fn = copymodule.copy
        elif copy == "deep":
            self.copy_fn = copymodule.deepcopy
        else:
            raise ValueError(
                f"Unknown copy method `{copy}` requested, choose one of None, `shallow` or `deep`."
            )

        # Indices of the next element each child will read (1-based)
        self.child_pointers: List[int] = [0] * num_instances
        self.slowest_ptr = 0  # 1-based index of the last element read by the slowest child
        self.leading_ptr = 0  # 1-based index of the last element read by the fastest child
        self.end_ptr: Optional[int] = None  # 1-based index at which children stop (one past the last element)
        self._child_stop: List[bool] = [True for _ in range(num_instances)]

    def __len__(self):
        return len(self.main_datapipe)
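
    # Illustrative pointer trace (assumptions: two children, source yields
    # 10, 11, ...): after child 0 reads twice and child 1 reads once,
    #   child_pointers == [2, 1], slowest_ptr == 1, leading_ptr == 2,
    # and `buffer` holds just the element 11 still owed to child 1.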

    def get_next_element_by_instance(self, instance_id: int):
        if self._datapipe_iterator is None and self._child_stop[instance_id]:
            self._datapipe_iterator = iter(self.main_datapipe)
            self._snapshot_state = _SnapshotState.Iterating
            for i in range(self.num_instances):
                self._child_stop[i] = False
        try:
            while not self._child_stop[instance_id]:
                self.child_pointers[instance_id] += 1
                if (
                    self.end_ptr is not None
                    and self.child_pointers[instance_id] == self.end_ptr
                ):
                    self._child_stop[instance_id] = True
                    break
                # Serve from the buffer whenever this child is behind the leader
                if self.buffer and self.child_pointers[instance_id] <= self.leading_ptr:
                    idx = self.child_pointers[instance_id] - self.slowest_ptr - 1
                    return_val = self.buffer[idx]
                else:  # Retrieve one element from the main datapipe
                    self.leading_ptr = self.child_pointers[instance_id]
                    try:
                        return_val = next(self._datapipe_iterator)  # type: ignore[arg-type]
                        self.buffer.append(return_val)
                    except StopIteration:
                        self._child_stop[instance_id] = True
                        self._datapipe_iterator = None
                        self.end_ptr = self.leading_ptr
                        continue
                if self.child_pointers[instance_id] == self.slowest_ptr + 1:
                    new_min = min(
                        self.child_pointers
                    )  # Can optimize by avoiding the call to min()
                    if self.slowest_ptr < new_min:
                        self.slowest_ptr = new_min
                        self.buffer.popleft()
                if (
                    self.buffer_size >= 0
                    and self.leading_ptr > self.buffer_size + self.slowest_ptr
                ):
                    raise BufferError(
                        f"ForkerIterDataPipe buffer overflow, buffer size {self.buffer_size} is insufficient."
                    )

                yield self.copy_fn(return_val)  # type: ignore[possibly-undefined]
        finally:
            self._child_stop[instance_id] = True
            # Clean up `_datapipe_iterator` in case `fork` exits early
            if all(self._child_stop):
                self._datapipe_iterator = None
                self._cleanup()

    def is_every_instance_exhausted(self) -> bool:
        return self.end_ptr is not None and all(self._child_stop)

    def get_length_by_instance(self, instance_id: int) -> int:
        return len(self.main_datapipe)

    def reset(self) -> None:
        self._datapipe_iterator = None
        self.buffer = deque()
        self.child_pointers = [0] * self.num_instances
        self.slowest_ptr = 0
        self.leading_ptr = 0
        self.end_ptr = None
        self._child_stop = [True for _ in range(self.num_instances)]

    def __getstate__(self):
        state = (
            self.main_datapipe,
            self.num_instances,
            self.buffer_size,
            self.copy_fn,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        )
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(state)
        return state

    def __setstate__(self, state):
        (
            self.main_datapipe,
            self.num_instances,
            self.buffer_size,
            self.copy_fn,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        ) = state
        self._datapipe_iterator = None
        self.buffer = deque()
        self.child_pointers = [0] * self.num_instances
        self.slowest_ptr = 0
        self.leading_ptr = 0
        self.end_ptr = None
        self._child_stop = [True for _ in range(self.num_instances)]

    def _cleanup(self):
        while self.buffer:
            d = self.buffer.popleft()
            StreamWrapper.close_streams(d)

    def __del__(self):
        self._cleanup()
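
# `ForkerIterDataPipe.__new__` returns `_ChildDataPipe` handles around a single
# shared `_ForkerIterDataPipe` container. A rough sketch of the wiring
# (illustrative, mirroring `__new__` above; `source_dp` is any IterDataPipe):
#
#   container = _ForkerIterDataPipe(source_dp, num_instances=2)
#   child_0, child_1 = (_ChildDataPipe(container, i) for i in range(2))
#   # Each `iter(child_i)` delegates to container.get_next_element_by_instance(i).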


class _ChildDataPipe(IterDataPipe):
    r"""
    Iterable DataPipe that is a child of a main DataPipe.

    The instance of this class will pass its instance_id to get the next value from its main DataPipe.

    Note:
        ChildDataPipe, like all other IterDataPipe, follows the single iterator per IterDataPipe constraint.
        Since ChildDataPipes share a common buffer, when an iterator is created for one of the ChildDataPipes,
        the previous iterators for all ChildDataPipes must be invalidated, except when a ChildDataPipe hasn't
        had an iterator created from it since the last invalidation. See the example below.

    Example:
        >>> # xdoctest: +REQUIRES(module:torchdata)
        >>> # Single Iterator per IterDataPipe Invalidation
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> source_dp = IterableWrapper(range(10))
        >>> cdp1, cdp2 = source_dp.fork(num_instances=2)
        >>> it1, it2 = iter(cdp1), iter(cdp2)
        >>> it3 = iter(cdp1)
        >>> # The line above invalidates `it1` and `it2`, and resets `ForkerIterDataPipe`.
        >>> it4 = iter(cdp2)
        >>> # The line above doesn't invalidate `it3`, because an iterator for `cdp2` hasn't been created since
        >>> # the last invalidation.

    Args:
        main_datapipe: Main DataPipe with a method ``get_next_element_by_instance(instance_id)``
        instance_id: integer identifier of this instance
    """

    _is_child_datapipe: bool = True

    def __init__(self, main_datapipe: IterDataPipe, instance_id: int):
        assert isinstance(main_datapipe, _ContainerTemplate)

        self.main_datapipe: IterDataPipe = main_datapipe
        self.instance_id = instance_id

    def __iter__(self):
        # Note that the logic behind setting the iterator ID and `reset` is handled within `hook_iterator`.
        # We want to separate the code for reset and yield, so that 'reset' executes before __next__ is called.
        return self.main_datapipe.get_next_element_by_instance(self.instance_id)

    def __len__(self):
        return self.main_datapipe.get_length_by_instance(self.instance_id)

    # This method is called by `hook_iterator` in `_typing.py`.
    def _set_main_datapipe_valid_iterator_id(self) -> int:
        r"""
        Update the valid iterator ID for both this DataPipe object and `main_datapipe`.

        `main_datapipe.reset()` is called when the ID is incremented to a new generation.
        """
        # 1. First time any child iterator is created
        if self.main_datapipe._valid_iterator_id is None:
            self.main_datapipe._valid_iterator_id = 0  # type: ignore[attr-defined]
        # 2. This instance was already in the same generation as `main_datapipe`,
        #    so we need to increment the ID further by 1
        elif self.main_datapipe._valid_iterator_id == self._valid_iterator_id:  # type: ignore[has-type]
            self.main_datapipe._valid_iterator_id += 1  # type: ignore[attr-defined]
            # Whenever a new generation of iterator is created, the `main_datapipe` must reset
            if not self.main_datapipe.is_every_instance_exhausted():
                warnings.warn(
                    "Some child DataPipes are not exhausted when __iter__ is called. We are resetting "
                    "the buffer and each child DataPipe will read from the start again.",
                    UserWarning,
                )
            self.main_datapipe.reset()
        # 3. Otherwise, the iterator is behind the others, so it will just need to catch up by setting
        #    the instance's iterator to match that of `main_datapipe`
        self._valid_iterator_id = self.main_datapipe._valid_iterator_id
        return self._valid_iterator_id

    # This method is called by `hook_iterator` in `_typing.py`.
    def _check_valid_iterator_id(self, iterator_id) -> bool:
        r"""Check the valid iterator ID against that of the DataPipe object and that of `main_datapipe`."""
        return (
            iterator_id == self._valid_iterator_id
            and iterator_id == self.main_datapipe._valid_iterator_id
        )


@functional_datapipe("demux")
class DemultiplexerIterDataPipe(IterDataPipe):
    r"""
    Splits the input DataPipe into multiple child DataPipes, using the given classification function (functional name: ``demux``).

    A list of the child DataPipes is returned from this operation.

    Args:
        datapipe: Iterable DataPipe being filtered
        num_instances: number of instances of the DataPipe to create
        classifier_fn: a function that maps values to an integer within the range ``[0, num_instances - 1]`` or ``None``
        drop_none: defaults to ``False``; if ``True``, the function will skip over elements classified as ``None``
        buffer_size: this defines the maximum number of inputs that the buffer can hold across all child
            DataPipes while waiting for their values to be yielded.
            Defaults to ``1000``. Use ``-1`` for the unlimited buffer.
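
    Note:
        Every element routed to one child while a different child is being read
        has to sit in the shared buffer. A minimal sketch of the failure mode
        (assumes ``IterableWrapper`` as imported in the examples below;
        ``to_second`` is a hypothetical classifier routing everything to the
        second child):

        >>> # xdoctest: +SKIP
        >>> def to_second(n):
        ...     return 1
        >>> dp1, dp2 = IterableWrapper(range(10)).demux(2, to_second, buffer_size=2)
        >>> list(dp1)  # every element is parked for dp2; the 3rd pending one overflows
        Traceback (most recent call last):
            ...
        BufferError: DemultiplexerIterDataPipe buffer overflow, buffer size 2 is insufficient.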

    Examples:
        >>> # xdoctest: +REQUIRES(module:torchdata)
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> def odd_or_even(n):
        ...     return n % 2
        >>> source_dp = IterableWrapper(range(5))
        >>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even)
        >>> list(dp1)
        [0, 2, 4]
        >>> list(dp2)
        [1, 3]
        >>> # It can also filter out any element that gets `None` from the `classifier_fn`
        >>> def odd_or_even_no_zero(n):
        ...     return n % 2 if n != 0 else None
        >>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True)
        >>> list(dp1)
        [2, 4]
        >>> list(dp2)
        [1, 3]
    """

    def __new__(
        cls,
        datapipe: IterDataPipe,
        num_instances: int,
        classifier_fn: Callable[[_T_co], Optional[int]],
        drop_none: bool = False,
        buffer_size: int = 1000,
    ):
        if num_instances < 1:
            raise ValueError(
                f"Expected `num_instances` larger than 0, but {num_instances} was found"
            )

        _check_unpickable_fn(classifier_fn)

        # When num_instances == 1, demux could be replaced by filter, but it is
        # kept as a Demultiplexer for consistency, e.g., it still raises an
        # error when the classification result is out of range.
        container = _DemultiplexerIterDataPipe(datapipe, num_instances, classifier_fn, drop_none, buffer_size)  # type: ignore[abstract]
        return [_ChildDataPipe(container, i) for i in range(num_instances)]


class _DemultiplexerIterDataPipe(IterDataPipe, _ContainerTemplate):
    r"""
    Container to hold instance-specific information on behalf of DemultiplexerIterDataPipe.

    It tracks the state of its child DataPipes, maintains the buffer, and classifies and yields the next correct value
    as requested by the child DataPipes.
    """

    def __init__(
        self,
        datapipe: IterDataPipe[_T_co],
        num_instances: int,
        classifier_fn: Callable[[_T_co], Optional[int]],
        drop_none: bool,
        buffer_size: int,
    ):
        self.main_datapipe = datapipe
        self._datapipe_iterator: Optional[Iterator[Any]] = None
        self.num_instances = num_instances
        self.buffer_size = buffer_size
        if self.buffer_size < 0:
            warnings.warn(
                "Unlimited buffer size is set for `demux`, "
                "please be aware of OOM at random places",
                UserWarning,
            )
        self.current_buffer_usage = 0
        self.child_buffers: List[Deque[_T_co]] = [deque() for _ in range(num_instances)]
        self.classifier_fn = classifier_fn
        self.drop_none = drop_none
        self.main_datapipe_exhausted = False
        self._child_stop: List[bool] = [True for _ in range(num_instances)]
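
    # Flow sketch for `_find_next` below (illustrative, two children): a
    # request from child 0 keeps pulling from the source; each value classified
    # as 1 is parked in child_buffers[1] (counted by current_buffer_usage)
    # until a value classified as 0 arrives and is returned directly.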

    def _find_next(self, instance_id: int) -> _T_co:  # type: ignore[type-var]
        while True:
            if self.main_datapipe_exhausted or self._child_stop[instance_id]:
                raise StopIteration
            if self._datapipe_iterator is None:
                raise ValueError(
                    "_datapipe_iterator has not been set, likely because this private method is called directly "
                    "without invoking get_next_element_by_instance() first."
                )
            value = next(self._datapipe_iterator)
            classification = self.classifier_fn(value)
            if classification is None and self.drop_none:
                StreamWrapper.close_streams(value)
                continue
            if (
                classification is None
                or classification >= self.num_instances
                or classification < 0
            ):
                raise ValueError(
                    f"Output of the classification fn should be between 0 and {self.num_instances - 1}, "
                    f"but {classification} was returned."
                )
            if classification == instance_id:
                return value
            self.child_buffers[classification].append(value)
            self.current_buffer_usage += 1
            if self.buffer_size >= 0 and self.current_buffer_usage > self.buffer_size:
                raise BufferError(
                    f"DemultiplexerIterDataPipe buffer overflow, buffer size {self.buffer_size} is insufficient."
                )

    def get_next_element_by_instance(self, instance_id: int):
        if self._datapipe_iterator is None and self._child_stop[instance_id]:
            self._datapipe_iterator = iter(self.main_datapipe)
            self._snapshot_state = (
                _SnapshotState.Iterating
            )  # This is necessary for the DataPipe to reset properly.
            self.main_datapipe_exhausted = False
            for i in range(self.num_instances):
                self._child_stop[i] = False

        try:
            while not self._child_stop[instance_id]:
                if self.child_buffers[instance_id]:
                    self.current_buffer_usage -= 1
                    yield self.child_buffers[instance_id].popleft()
                else:
                    try:
                        yield self._find_next(instance_id)
                    except StopIteration:
                        self._child_stop[instance_id] = True
                        self.main_datapipe_exhausted = True
                        self._datapipe_iterator = None
        finally:
            self._child_stop[instance_id] = True
            # Clean up `_datapipe_iterator` in case `demux` exits early
            if all(self._child_stop):
                self._datapipe_iterator = None
            if self.child_buffers[instance_id]:
                self._cleanup(instance_id)

    def is_every_instance_exhausted(self) -> bool:
        return self.main_datapipe_exhausted and all(self._child_stop)

    def get_length_by_instance(self, instance_id: int) -> int:
        # Length is unknown until the classifier runs; TypeError keeps
        # `list(datapipe)` working via plain iteration.
        raise TypeError

    def reset(self) -> None:
        self._datapipe_iterator = None
        self.current_buffer_usage = 0
        self.child_buffers = [deque() for _ in range(self.num_instances)]
        self._child_stop = [True for _ in range(self.num_instances)]
        self.main_datapipe_exhausted = False

    def __getstate__(self):
        state = (
            self.main_datapipe,
            self.num_instances,
            self.buffer_size,
            self.classifier_fn,
            self.drop_none,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        )
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(state)
        return state

    def __setstate__(self, state):
        (
            self.main_datapipe,
            self.num_instances,
            self.buffer_size,
            self.classifier_fn,
            self.drop_none,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        ) = state
        self._datapipe_iterator = None
        self.current_buffer_usage = 0
        self.child_buffers = [deque() for _ in range(self.num_instances)]
        self._child_stop = [True for _ in range(self.num_instances)]
        self.main_datapipe_exhausted = False

    def _cleanup(self, instance_id: Optional[int] = None):
        ids = (
            range(self.num_instances)
            if instance_id is None
            else [instance_id]
        )
        for i in ids:
            q = self.child_buffers[i]
            while q:
                d = q.popleft()
                StreamWrapper.close_streams(d)

    def __del__(self):
        self._cleanup()
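
# The round-based loop in `MultiplexerIterDataPipe.__iter__` below is roughly
# equivalent to this sketch (illustrative, not used by the implementation):
#
#   def _mux(*iterables):
#       iterators = [iter(it) for it in iterables]
#       while True:
#           round_values = []
#           for it in iterators:
#               try:
#                   round_values.append(next(it))
#               except StopIteration:
#                   return  # drop the partial round once any input runs dry
#           yield from round_values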


@functional_datapipe("mux")
class MultiplexerIterDataPipe(IterDataPipe):
    r"""
    Yields one element at a time from each of the input Iterable DataPipes (functional name: ``mux``).

    That is, it yields one element from the first input DataPipe, then one element from the second DataPipe
    in the next iteration, and so on. It ends when the shortest input DataPipe is exhausted.

    Args:
        datapipes: Iterable DataPipes that will take turns yielding their elements, until the shortest DataPipe is exhausted

    Example:
        >>> # xdoctest: +REQUIRES(module:torchdata)
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp1, dp2, dp3 = IterableWrapper(range(3)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
        >>> list(dp1.mux(dp2, dp3))
        [0, 10, 20, 1, 11, 21, 2, 12, 22]
    """

    def __init__(self, *datapipes):
        self.datapipes = datapipes
        # Store values to be yielded only when every iterator provides one
        self.buffer: List = []

    def __iter__(self):
        iterators = [iter(x) for x in self.datapipes]
        while len(iterators):
            for it in iterators:
                try:
                    value = next(it)
                    self.buffer.append(value)
                except StopIteration:
                    self.buffer.clear()
                    return
            yield from self.buffer
            self.buffer.clear()

    def __len__(self):
        if all(isinstance(dp, Sized) for dp in self.datapipes):
            return min(len(dp) for dp in self.datapipes) * len(self.datapipes)
        else:
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")

    def reset(self) -> None:
        self.buffer = []

    def __getstate__(self):
        state = (
            self.datapipes,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        )
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(state)
        return state

    def __setstate__(self, state):
        (
            self.datapipes,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        ) = state
        self.buffer = []

    def __del__(self):
        self.buffer.clear()


@functional_datapipe("zip")
class ZipperIterDataPipe(IterDataPipe[Tuple[_T_co]]):
    r"""
    Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``).

    The output is stopped as soon as the shortest input DataPipe is exhausted.

    Args:
        *datapipes: Iterable DataPipes being aggregated

    Example:
        >>> # xdoctest: +REQUIRES(module:torchdata)
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp1, dp2, dp3 = IterableWrapper(range(5)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
        >>> list(dp1.zip(dp2, dp3))
        [(0, 10, 20), (1, 11, 21), (2, 12, 22), (3, 13, 23), (4, 14, 24)]
    """

    datapipes: Tuple[IterDataPipe]

    def __init__(self, *datapipes: IterDataPipe):
        if not all(isinstance(dp, IterDataPipe) for dp in datapipes):
            raise TypeError(
                "All inputs are required to be `IterDataPipe` for `ZipperIterDataPipe`."
            )
        super().__init__()
        self.datapipes = datapipes  # type: ignore[assignment]

    def __iter__(self) -> Iterator[Tuple[_T_co]]:
        iterators = [iter(datapipe) for datapipe in self.datapipes]
        yield from zip(*iterators)

    def __len__(self) -> int:
        if all(isinstance(dp, Sized) for dp in self.datapipes):
            return min(len(dp) for dp in self.datapipes)
        else:
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
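
# End-to-end sketch combining the pipes in this module (illustrative; assumes
# `torchdata`'s IterableWrapper as in the docstring examples, plus the standard
# `.map` functional): fork a source, transform one branch, then zip the
# branches back together.
#
#   >>> source = IterableWrapper(range(3))
#   >>> branch_a, branch_b = source.fork(num_instances=2)
#   >>> list(branch_a.map(lambda x: x * 10).zip(branch_b))
#   [(0, 0), (10, 1), (20, 2)]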