xref: /aosp_15_r20/external/pytorch/torch/utils/data/datapipes/iter/streamreader.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import Tuple
3
4from torch.utils.data.datapipes._decorator import functional_datapipe
5from torch.utils.data.datapipes.datapipe import IterDataPipe
6
7
8__all__ = ["StreamReaderIterDataPipe"]
9
10
@functional_datapipe("read_from_stream")
class StreamReaderIterDataPipe(IterDataPipe[Tuple[str, bytes]]):
    r"""
    Given IO streams and their label names, yield the data read from each stream with its label name as tuple.

    (functional name: ``read_from_stream``).

    Every stream is closed once this DataPipe is done with it: after it has
    been read to EOF, if reading from it raises, or if the iterator is closed
    or garbage-collected before the stream has been fully consumed.

    Args:
        datapipe: Iterable DataPipe that provides ``(label/URL, stream)`` pairs.
            Each stream must support ``read(size)`` and ``close()``.
        chunk: Number of bytes to be read from stream per iteration.
            If ``None``, all bytes will be read until the EOF.

    Yields:
        ``(label, data)`` tuples, where ``data`` is whatever
        ``stream.read(chunk)`` returned (``bytes`` for binary streams, ``str``
        for text streams such as the ``StringIO`` in the example below).
        An empty read marks the end of that stream and is not yielded.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper, StreamReader
        >>> from io import StringIO
        >>> dp = IterableWrapper([("alphabet", StringIO("abcde"))])
        >>> list(StreamReader(dp, chunk=1))
        [('alphabet', 'a'), ('alphabet', 'b'), ('alphabet', 'c'), ('alphabet', 'd'), ('alphabet', 'e')]
    """

    def __init__(self, datapipe, chunk=None):
        # Source of (label, stream) pairs; iterated lazily in ``__iter__``.
        self.datapipe = datapipe
        # Passed unchanged to ``stream.read``; ``None`` means "read to EOF".
        self.chunk = chunk

    def __iter__(self):
        for furl, stream in self.datapipe:
            # ``finally`` guarantees the stream is released on every exit path:
            # normal EOF, an exception raised by ``read``, or the generator
            # being closed while suspended at ``yield`` (consumer stopped early).
            try:
                while True:
                    data = stream.read(self.chunk)
                    if not data:
                        # Empty read (b"" / "") signals EOF for this stream.
                        break
                    yield (furl, data)
            finally:
                stream.close()
44