# mypy: allow-untyped-defs
import copy as copymodule
import warnings
from abc import ABC, abstractmethod
from collections import deque
from typing import (
    Any,
    Callable,
    Deque,
    Iterator,
    List,
    Literal,
    Optional,
    Sized,
    Tuple,
    TypeVar,
)

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes._hook_iterator import _SnapshotState
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, StreamWrapper


__all__ = [
    "ConcaterIterDataPipe",
    "DemultiplexerIterDataPipe",
    "ForkerIterDataPipe",
    "MultiplexerIterDataPipe",
    "ZipperIterDataPipe",
]


_T_co = TypeVar("_T_co", covariant=True)


@functional_datapipe("concat")
class ConcaterIterDataPipe(IterDataPipe):
    r"""
    Concatenates multiple Iterable DataPipes (functional name: ``concat``).

    The resulting DataPipe will yield all the elements from the first input DataPipe, before yielding from the subsequent ones.

    Args:
        datapipes: Iterable DataPipes being concatenated

    Example:
        >>> # xdoctest: +REQUIRES(module:torchdata)
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp1 = IterableWrapper(range(3))
        >>> dp2 = IterableWrapper(range(5))
        >>> list(dp1.concat(dp2))
        [0, 1, 2, 0, 1, 2, 3, 4]
    """

    datapipes: Tuple[IterDataPipe]

    def __init__(self, *datapipes: IterDataPipe):
        if len(datapipes) == 0:
            raise ValueError("Expected at least one DataPipe, but got nothing")
        if not all(isinstance(dp, IterDataPipe) for dp in datapipes):
            raise TypeError("Expected all inputs to be `IterDataPipe`")
        self.datapipes = datapipes  # type: ignore[assignment]

    def __iter__(self) -> Iterator:
        for dp in self.datapipes:
            yield from dp

    def __len__(self) -> int:
        if all(isinstance(dp, Sized) for dp in self.datapipes):
            return sum(len(dp) for dp in self.datapipes)
        else:
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")

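# A brief length sketch for ``concat`` (a hedged example, assuming torchdata's
# ``IterableWrapper``, which is ``Sized``): ``__len__`` above sums the input
# lengths, so the docstring's two pipes report 3 + 5 = 8.
#
#     >>> from torchdata.datapipes.iter import IterableWrapper
#     >>> dp = IterableWrapper(range(3)).concat(IterableWrapper(range(5)))
#     >>> len(dp)
#     8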

@functional_datapipe("fork")
class ForkerIterDataPipe(IterDataPipe):
    r"""
    Creates multiple instances of the same Iterable DataPipe (functional name: ``fork``).

    Args:
        datapipe: Iterable DataPipe being copied
        num_instances: number of instances of the datapipe to create
        buffer_size: this restricts how far ahead the leading child DataPipe
           can read relative to the slowest child DataPipe.
           Defaults to ``1000``. Use ``-1`` for the unlimited buffer.
        copy: copy strategy to use for items yielded by each branch. Supported
            options are ``None`` for no copying, ``"shallow"`` for shallow object
            copies, and ``"deep"`` for deep object copies. Defaults to ``None``.

    Note:
        All branches of the forked pipeline return the identical object unless
        the copy parameter is supplied. If the object is mutable or contains
        mutable objects, changing them in one branch will affect all others.

    Example:
        >>> # xdoctest: +REQUIRES(module:torchdata)
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> source_dp = IterableWrapper(range(5))
        >>> dp1, dp2 = source_dp.fork(num_instances=2)
        >>> list(dp1)
        [0, 1, 2, 3, 4]
        >>> list(dp2)
        [0, 1, 2, 3, 4]
    """

    def __new__(
        cls,
        datapipe: IterDataPipe,
        num_instances: int,
        buffer_size: int = 1000,
        copy: Optional[Literal["shallow", "deep"]] = None,
    ):
        if num_instances < 1:
            raise ValueError(
                f"Expected `num_instances` larger than 0, but found {num_instances}"
            )
        if num_instances == 1:
            return datapipe
        container = _ForkerIterDataPipe(datapipe, num_instances, buffer_size, copy)  # type: ignore[abstract]
        return [_ChildDataPipe(container, i) for i in range(num_instances)]

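# A hedged sketch of the ``copy`` parameter above, assuming torchdata's
# ``IterableWrapper``: with ``copy=None`` both children yield the very same
# objects, so in-place mutation in one branch is visible in the other, while
# ``copy="deep"`` hands each branch its own deep copy.
#
#     >>> from torchdata.datapipes.iter import IterableWrapper
#     >>> source_dp = IterableWrapper([["a"], ["b"]])
#     >>> dp1, dp2 = source_dp.fork(num_instances=2)               # shared objects
#     >>> dp1, dp2 = source_dp.fork(num_instances=2, copy="deep")  # isolated copies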

class _ContainerTemplate(ABC):
    r"""Abstract class for container ``DataPipes``. The following methods are required."""

    @abstractmethod
    def get_next_element_by_instance(self, instance_id: int):
        ...

    @abstractmethod
    def is_every_instance_exhausted(self) -> bool:
        ...

    @abstractmethod
    def reset(self) -> None:
        ...

    @abstractmethod
    def get_length_by_instance(self, instance_id: int):
        r"""Raise ``TypeError`` if length is not supported, so that ``list(datapipe)`` still works."""


def _no_op(x):
    return x

class _ForkerIterDataPipe(IterDataPipe, _ContainerTemplate):
    r"""
    Container to hold instance-specific information on behalf of ForkerIterDataPipe.

    It tracks the state of its child DataPipes, maintains the buffer, and yields the next value
    as requested by the child DataPipes.
    """

    def __init__(
        self,
        datapipe: IterDataPipe,
        num_instances: int,
        buffer_size: int = 1000,
        copy: Optional[Literal["shallow", "deep"]] = None,
    ):
        self.main_datapipe = datapipe
        self._datapipe_iterator: Optional[Iterator[Any]] = None
        self.num_instances = num_instances
        self.buffer: Deque = deque()
        self.buffer_size = buffer_size
        if self.buffer_size < 0:
            warnings.warn(
                "Unlimited buffer size is set for `fork`, "
                "please be aware of potential out-of-memory (OOM) errors",
                UserWarning,
            )
        if copy is None:
            self.copy_fn = _no_op
        elif copy == "shallow":
            self.copy_fn = copymodule.copy
        elif copy == "deep":
            self.copy_fn = copymodule.deepcopy
        else:
            raise ValueError(
                f"Unknown copy method `{copy}` requested, choose one of None, `shallow` or `deep`."
            )

        self.child_pointers: List[int] = [
            0
        ] * num_instances  # Indicate the indices of the next element to get
        self.slowest_ptr = 0  # The index to read by the slowest child
        self.leading_ptr = 0  # The index to read by the fastest child
        self.end_ptr: Optional[int] = None  # The index to stop child
        self._child_stop: List[bool] = [True for _ in range(num_instances)]

    def __len__(self):
        return len(self.main_datapipe)

    def get_next_element_by_instance(self, instance_id: int):
        if self._datapipe_iterator is None and self._child_stop[instance_id]:
            self._datapipe_iterator = iter(self.main_datapipe)
            # This is necessary for the DataPipe to reset properly.
            self._snapshot_state = _SnapshotState.Iterating
            for i in range(self.num_instances):
                self._child_stop[i] = False
        try:
            while not self._child_stop[instance_id]:
                self.child_pointers[instance_id] += 1
                if (
                    self.end_ptr is not None
                    and self.child_pointers[instance_id] == self.end_ptr
                ):
                    self._child_stop[instance_id] = True
                    break
                # Use buffer (`buffer[0]` holds the element at position `slowest_ptr + 1`)
                if self.buffer and self.child_pointers[instance_id] <= self.leading_ptr:
                    idx = self.child_pointers[instance_id] - self.slowest_ptr - 1
                    return_val = self.buffer[idx]
                else:  # Retrieve one element from main datapipe
                    self.leading_ptr = self.child_pointers[instance_id]
                    try:
                        return_val = next(self._datapipe_iterator)  # type: ignore[arg-type]
                        self.buffer.append(return_val)
                    except StopIteration:
                        self._child_stop[instance_id] = True
                        self._datapipe_iterator = None
                        self.end_ptr = self.leading_ptr
                        continue
                if self.child_pointers[instance_id] == self.slowest_ptr + 1:
                    new_min = min(
                        self.child_pointers
                    )  # Can optimize by avoiding the call to min()
                    if self.slowest_ptr < new_min:
                        self.slowest_ptr = new_min
                        self.buffer.popleft()
                if (
                    self.buffer_size >= 0
                    and self.leading_ptr > self.buffer_size + self.slowest_ptr
                ):
                    raise BufferError(
                        "ForkerIterDataPipe buffer overflow, "
                        + f"buffer size {self.buffer_size} is insufficient."
                    )

                yield self.copy_fn(return_val)  # type: ignore[possibly-undefined]
        finally:
            self._child_stop[instance_id] = True
            # Cleanup _datapipe_iterator for the case that fork exits earlier
            if all(self._child_stop):
                self._datapipe_iterator = None
                self._cleanup()

    def is_every_instance_exhausted(self) -> bool:
        return self.end_ptr is not None and all(self._child_stop)

    def get_length_by_instance(self, instance_id: int) -> int:
        return len(self.main_datapipe)

    def reset(self) -> None:
        self._datapipe_iterator = None
        self.buffer = deque()
        self.child_pointers = [0] * self.num_instances
        self.slowest_ptr = 0
        self.leading_ptr = 0
        self.end_ptr = None
        self._child_stop = [True for _ in range(self.num_instances)]

    def __getstate__(self):
        state = (
            self.main_datapipe,
            self.num_instances,
            self.buffer_size,
            self.copy_fn,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        )
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(state)
        return state

    def __setstate__(self, state):
        (
            self.main_datapipe,
            self.num_instances,
            self.buffer_size,
            self.copy_fn,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        ) = state
        self._datapipe_iterator = None
        self.buffer = deque()
        self.child_pointers = [0] * self.num_instances
        self.slowest_ptr = 0
        self.leading_ptr = 0
        self.end_ptr = None
        self._child_stop = [True for _ in range(self.num_instances)]

    def _cleanup(self):
        while self.buffer:
            d = self.buffer.popleft()
            StreamWrapper.close_streams(d)

    def __del__(self):
        self._cleanup()

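# A sketch of the buffer bound enforced in `get_next_element_by_instance`
# (assuming torchdata's ``IterableWrapper``): the leading child may run at most
# ``buffer_size`` elements ahead of the slowest child before a ``BufferError``
# is raised.
#
#     >>> from torchdata.datapipes.iter import IterableWrapper
#     >>> fast, slow = IterableWrapper(range(10)).fork(num_instances=2, buffer_size=2)
#     >>> it_fast, it_slow = iter(fast), iter(slow)
#     >>> next(it_fast), next(it_fast)  # the buffer now holds 2 unread elements
#     (0, 1)
#     >>> next(it_fast)  # a third step ahead of `slow` overflows the buffer
#     Traceback (most recent call last):
#         ...
#     BufferError: ForkerIterDataPipe buffer overflow, buffer size 2 is insufficient.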

class _ChildDataPipe(IterDataPipe):
    r"""
    Iterable Datapipe that is a child of a main DataPipe.

    The instance of this class will pass its instance_id to get the next value from its main DataPipe.

    Note:
        ChildDataPipe, like all other IterDataPipe, follows the single iterator per IterDataPipe constraint.
        Since ChildDataPipes share a common buffer, when an iterator is created for one of the ChildDataPipes,
        the previous iterators for all ChildDataPipes must be invalidated, with the exception when a ChildDataPipe
        hasn't had an iterator created from it since the last invalidation. See the example below.

    Example:
        >>> # xdoctest: +REQUIRES(module:torchdata)
        >>> # Single Iterator per IterDataPipe Invalidation
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> source_dp = IterableWrapper(range(10))
        >>> cdp1, cdp2 = source_dp.fork(num_instances=2)
        >>> it1, it2 = iter(cdp1), iter(cdp2)
        >>> it3 = iter(cdp1)
        >>> # The line above invalidates `it1` and `it2`, and resets `ForkerIterDataPipe`.
        >>> it4 = iter(cdp2)
        >>> # The line above doesn't invalidate `it3`, because an iterator for `cdp2` hasn't been created since
        >>> # the last invalidation.

    Args:
        main_datapipe: Main DataPipe with a method 'get_next_element_by_instance(instance_id)'
        instance_id: integer identifier of this instance
    """

    _is_child_datapipe: bool = True

    def __init__(self, main_datapipe: IterDataPipe, instance_id: int):
        assert isinstance(main_datapipe, _ContainerTemplate)

        self.main_datapipe: IterDataPipe = main_datapipe
        self.instance_id = instance_id

    def __iter__(self):
        # Note that the logic behind setting iterator ID and `reset` are handled within `hook_iterator`
        # We want to separate the code for reset and yield, so that 'reset' executes before __next__ is called
        return self.main_datapipe.get_next_element_by_instance(self.instance_id)

    def __len__(self):
        return self.main_datapipe.get_length_by_instance(self.instance_id)

    # This method is called by `hook_iterator` in `_typing.py`.
    def _set_main_datapipe_valid_iterator_id(self) -> int:
        r"""
        Update the valid iterator ID for both this DataPipe object and `main_datapipe`.

        `main_datapipe.reset()` is called when the ID is incremented to a new generation.
        """
        # 1. First time any child iterator is created
        if self.main_datapipe._valid_iterator_id is None:
            self.main_datapipe._valid_iterator_id = 0  # type: ignore[attr-defined]
        # 2. This instance was already in the same generation as `main_datapipe`,
        #    we need to increment the ID further by 1
        elif self.main_datapipe._valid_iterator_id == self._valid_iterator_id:  # type: ignore[has-type]
            self.main_datapipe._valid_iterator_id += 1  # type: ignore[attr-defined]
            # Whenever a new generation of iterator is created, the `main_datapipe` must reset
            if not self.main_datapipe.is_every_instance_exhausted():
                warnings.warn(
                    "Some child DataPipes are not exhausted when __iter__ is called. We are resetting "
                    "the buffer and each child DataPipe will read from the start again.",
                    UserWarning,
                )
            self.main_datapipe.reset()
        # 3. Otherwise, the iterator is behind the others, so it will just need to catch up by setting
        #    the instance's iterator to match that of `main_datapipe`
        self._valid_iterator_id = self.main_datapipe._valid_iterator_id
        return self._valid_iterator_id

    # This method is called by `hook_iterator` in `_typing.py`.
    def _check_valid_iterator_id(self, iterator_id) -> bool:
        r"""Check the valid iterator ID against that of DataPipe object and that of `main_datapipe`."""
        return (
            iterator_id == self._valid_iterator_id
            and iterator_id == self.main_datapipe._valid_iterator_id
        )

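# A hedged illustration of the generation check above, assuming torchdata:
# creating a second-generation iterator resets the container (with the
# UserWarning emitted in `_set_main_datapipe_valid_iterator_id`), and advancing
# a stale iterator is expected to raise a RuntimeError from `hook_iterator`.
#
#     >>> from torchdata.datapipes.iter import IterableWrapper
#     >>> cdp1, cdp2 = IterableWrapper(range(10)).fork(num_instances=2)
#     >>> it1 = iter(cdp1)
#     >>> next(it1)
#     0
#     >>> it3 = iter(cdp1)  # new generation: warns and resets the buffer
#     >>> next(it1)  # stale iterator; raises RuntimeError  # doctest: +SKIP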

@functional_datapipe("demux")
class DemultiplexerIterDataPipe(IterDataPipe):
    r"""
    Splits the input DataPipe into multiple child DataPipes, using the given classification function (functional name: ``demux``).

    A list of the child DataPipes is returned from this operation.

    Args:
        datapipe: Iterable DataPipe being filtered
        num_instances: number of instances of the DataPipe to create
        classifier_fn: a function that maps values to an integer within the range ``[0, num_instances - 1]`` or ``None``
        drop_none: defaults to ``False``, if ``True``, the function will skip over elements classified as ``None``
        buffer_size: this defines the maximum number of inputs that the buffer can hold across all child
            DataPipes while waiting for their values to be yielded.
            Defaults to ``1000``. Use ``-1`` for the unlimited buffer.

    Examples:
        >>> # xdoctest: +REQUIRES(module:torchdata)
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> def odd_or_even(n):
        ...     return n % 2
        >>> source_dp = IterableWrapper(range(5))
        >>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even)
        >>> list(dp1)
        [0, 2, 4]
        >>> list(dp2)
        [1, 3]
        >>> # It can also filter out any element for which `classifier_fn` returns `None`
        >>> def odd_or_even_no_zero(n):
        ...     return n % 2 if n != 0 else None
        >>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True)
        >>> list(dp1)
        [2, 4]
        >>> list(dp2)
        [1, 3]
    """

    def __new__(
        cls,
        datapipe: IterDataPipe,
        num_instances: int,
        classifier_fn: Callable[[_T_co], Optional[int]],
        drop_none: bool = False,
        buffer_size: int = 1000,
    ):
        if num_instances < 1:
            raise ValueError(
                f"Expected `num_instances` larger than 0, but found {num_instances}"
            )

        _check_unpickable_fn(classifier_fn)

        # When num_instances == 1, demux can be replaced by filter,
        # but keep it as a Demultiplexer for the sake of consistency,
        # e.g. raising an error when the classification result is out of range
        container = _DemultiplexerIterDataPipe(datapipe, num_instances, classifier_fn, drop_none, buffer_size)  # type: ignore[abstract]
        return [_ChildDataPipe(container, i) for i in range(num_instances)]

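# A sketch of the shared buffer noted above (assuming torchdata's
# ``IterableWrapper``): elements classified for one child are buffered while
# another child is drained, so fully draining one branch first can require up
# to ``buffer_size`` held items.
#
#     >>> from torchdata.datapipes.iter import IterableWrapper
#     >>> def odd_or_even(n):
#     ...     return n % 2
#     >>> evens, odds = IterableWrapper(range(6)).demux(num_instances=2, classifier_fn=odd_or_even)
#     >>> list(odds)   # meanwhile 0, 2, 4 are buffered for `evens`
#     [1, 3, 5]
#     >>> list(evens)  # served straight from the buffer
#     [0, 2, 4]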

class _DemultiplexerIterDataPipe(IterDataPipe, _ContainerTemplate):
    r"""
    Container to hold instance-specific information on behalf of DemultiplexerIterDataPipe.

    It tracks the state of its child DataPipes, maintains the buffer, classifies and yields the next correct value
    as requested by the child DataPipes.
    """

    def __init__(
        self,
        datapipe: IterDataPipe[_T_co],
        num_instances: int,
        classifier_fn: Callable[[_T_co], Optional[int]],
        drop_none: bool,
        buffer_size: int,
    ):
        self.main_datapipe = datapipe
        self._datapipe_iterator: Optional[Iterator[Any]] = None
        self.num_instances = num_instances
        self.buffer_size = buffer_size
        if self.buffer_size < 0:
            warnings.warn(
                "Unlimited buffer size is set for `demux`, "
                "please be aware of potential out-of-memory (OOM) errors",
                UserWarning,
            )
        self.current_buffer_usage = 0
        self.child_buffers: List[Deque[_T_co]] = [deque() for _ in range(num_instances)]
        self.classifier_fn = classifier_fn
        self.drop_none = drop_none
        self.main_datapipe_exhausted = False
        self._child_stop: List[bool] = [True for _ in range(num_instances)]

    def _find_next(self, instance_id: int) -> _T_co:  # type: ignore[type-var]
        while True:
            if self.main_datapipe_exhausted or self._child_stop[instance_id]:
                raise StopIteration
            if self._datapipe_iterator is None:
                raise ValueError(
                    "_datapipe_iterator has not been set, likely because this private method is called directly "
                    "without invoking get_next_element_by_instance() first."
                )
            value = next(self._datapipe_iterator)
            classification = self.classifier_fn(value)
            if classification is None and self.drop_none:
                StreamWrapper.close_streams(value)
                continue
            if (
                classification is None
                or classification >= self.num_instances
                or classification < 0
            ):
                raise ValueError(
                    f"Output of the classification fn should be between 0 and {self.num_instances - 1}, "
                    + f"but {classification} was returned."
                )
            if classification == instance_id:
                return value
            self.child_buffers[classification].append(value)
            self.current_buffer_usage += 1
            if self.buffer_size >= 0 and self.current_buffer_usage > self.buffer_size:
                raise BufferError(
                    f"DemultiplexerIterDataPipe buffer overflow, buffer size {self.buffer_size} is insufficient."
                )

    def get_next_element_by_instance(self, instance_id: int):
        if self._datapipe_iterator is None and self._child_stop[instance_id]:
            self._datapipe_iterator = iter(self.main_datapipe)
            self._snapshot_state = (
                _SnapshotState.Iterating
            )  # This is necessary for the DataPipe to reset properly.
            self.main_datapipe_exhausted = False
            for i in range(self.num_instances):
                self._child_stop[i] = False

        try:
            while not self._child_stop[instance_id]:
                if self.child_buffers[instance_id]:
                    self.current_buffer_usage -= 1
                    yield self.child_buffers[instance_id].popleft()
                else:
                    try:
                        yield self._find_next(instance_id)
                    except StopIteration:
                        self._child_stop[instance_id] = True
                        self.main_datapipe_exhausted = True
                        self._datapipe_iterator = None
        finally:
            self._child_stop[instance_id] = True
            # Cleanup _datapipe_iterator for the case that demux exits earlier
            if all(self._child_stop):
                self._datapipe_iterator = None
            if self.child_buffers[instance_id]:
                self._cleanup(instance_id)

    def is_every_instance_exhausted(self) -> bool:
        return self.main_datapipe_exhausted and all(self._child_stop)

    def get_length_by_instance(self, instance_id: int) -> int:
        raise TypeError

    def reset(self) -> None:
        self._datapipe_iterator = None
        self.current_buffer_usage = 0
        self.child_buffers = [deque() for _ in range(self.num_instances)]
        self._child_stop = [True for _ in range(self.num_instances)]
        self.main_datapipe_exhausted = False

    def __getstate__(self):
        state = (
            self.main_datapipe,
            self.num_instances,
            self.buffer_size,
            self.classifier_fn,
            self.drop_none,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        )
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(state)
        return state

    def __setstate__(self, state):
        (
            self.main_datapipe,
            self.num_instances,
            self.buffer_size,
            self.classifier_fn,
            self.drop_none,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        ) = state
        self._datapipe_iterator = None
        self.current_buffer_usage = 0
        self.child_buffers = [deque() for _ in range(self.num_instances)]
        self._child_stop = [True for _ in range(self.num_instances)]
        self.main_datapipe_exhausted = False

    def _cleanup(self, instance_id: Optional[int] = None):
        ids = (
            range(self.num_instances)
            if instance_id is None
            else [
                instance_id,
            ]
        )
        for i in ids:
            q = self.child_buffers[i]
            while q:
                d = q.popleft()
                StreamWrapper.close_streams(d)

    def __del__(self):
        self._cleanup()


@functional_datapipe("mux")
class MultiplexerIterDataPipe(IterDataPipe):
    r"""
    Yields one element at a time from each of the input Iterable DataPipes (functional name: ``mux``).

    That is, one element from the 1st input DataPipe, then one element from the 2nd DataPipe in the next iteration,
    and so on. It ends when the shortest input DataPipe is exhausted.

    Args:
        datapipes: Iterable DataPipes that will take turns yielding their elements, until the shortest DataPipe is exhausted

    Example:
        >>> # xdoctest: +REQUIRES(module:torchdata)
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp1, dp2, dp3 = IterableWrapper(range(3)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
        >>> list(dp1.mux(dp2, dp3))
        [0, 10, 20, 1, 11, 21, 2, 12, 22]
    """

    def __init__(self, *datapipes):
        self.datapipes = datapipes
        self.buffer: List = (
            []
        )  # Store values to be yielded only when every iterator provides one

    def __iter__(self):
        iterators = [iter(x) for x in self.datapipes]
        while len(iterators):
            for it in iterators:
                try:
                    value = next(it)
                    self.buffer.append(value)
                except StopIteration:
                    self.buffer.clear()
                    return
            yield from self.buffer
            self.buffer.clear()

    def __len__(self):
        if all(isinstance(dp, Sized) for dp in self.datapipes):
            return min(len(dp) for dp in self.datapipes) * len(self.datapipes)
        else:
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")

    def reset(self) -> None:
        self.buffer = []

    def __getstate__(self):
        state = (
            self.datapipes,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        )
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(state)
        return state

    def __setstate__(self, state):
        (
            self.datapipes,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        ) = state
        self.buffer = []

    def __del__(self):
        self.buffer.clear()

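# Length sketch for ``mux`` above: each round yields one element per input and
# iteration stops once the shortest input is exhausted, so ``__len__`` is
# ``min(len(dp) for dp in datapipes) * len(datapipes)``. With inputs of lengths
# 3, 5, and 5 (assuming torchdata's ``IterableWrapper``):
#
#     >>> from torchdata.datapipes.iter import IterableWrapper
#     >>> dp1, dp2, dp3 = IterableWrapper(range(3)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
#     >>> len(dp1.mux(dp2, dp3))
#     9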

@functional_datapipe("zip")
class ZipperIterDataPipe(IterDataPipe[Tuple[_T_co]]):
    r"""
    Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip``).

    The output is stopped as soon as the shortest input DataPipe is exhausted.

    Args:
        *datapipes: Iterable DataPipes being aggregated

    Example:
        >>> # xdoctest: +REQUIRES(module:torchdata)
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp1, dp2, dp3 = IterableWrapper(range(5)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25))
        >>> list(dp1.zip(dp2, dp3))
        [(0, 10, 20), (1, 11, 21), (2, 12, 22), (3, 13, 23), (4, 14, 24)]
    """

    datapipes: Tuple[IterDataPipe]

    def __init__(self, *datapipes: IterDataPipe):
        if not all(isinstance(dp, IterDataPipe) for dp in datapipes):
            raise TypeError(
                "All inputs are required to be `IterDataPipe` for `ZipperIterDataPipe`."
            )
        super().__init__()
        self.datapipes = datapipes  # type: ignore[assignment]

    def __iter__(self) -> Iterator[Tuple[_T_co]]:
        iterators = [iter(datapipe) for datapipe in self.datapipes]
        yield from zip(*iterators)

    def __len__(self) -> int:
        if all(isinstance(dp, Sized) for dp in self.datapipes):
            return min(len(dp) for dp in self.datapipes)
        else:
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
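
# Like builtin ``zip``, the Zipper stops at the shortest input and ``__len__``
# reports that minimum; a hedged sketch assuming torchdata's ``IterableWrapper``:
#
#     >>> from torchdata.datapipes.iter import IterableWrapper
#     >>> dp_short, dp_long = IterableWrapper(range(2)), IterableWrapper(range(10, 15))
#     >>> list(dp_short.zip(dp_long))
#     [(0, 10), (1, 11)]
#     >>> len(dp_short.zip(dp_long))
#     2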