xref: /aosp_15_r20/external/pytorch/torch/utils/data/datapipes/iter/fileopener.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from io import IOBase
3from typing import Iterable, Optional, Tuple
4
5from torch.utils.data.datapipes._decorator import functional_datapipe
6from torch.utils.data.datapipes.datapipe import IterDataPipe
7from torch.utils.data.datapipes.utils.common import get_file_binaries_from_pathnames
8
9
# Public API of this module: only the file-opening DataPipe is exported.
__all__ = ["FileOpenerIterDataPipe"]
13
14
@functional_datapipe("open_files")
class FileOpenerIterDataPipe(IterDataPipe[Tuple[str, IOBase]]):
    r"""
    Given pathnames, opens files and yields pathname and file stream in a tuple (functional name: ``open_files``).

    Args:
        datapipe: Iterable datapipe that provides pathnames
        mode: An optional string that specifies the mode in which
            the file is opened by ``open()``. It defaults to ``r``; the other accepted
            values are ``b`` or ``rb`` for reading in binary mode and ``t`` or ``rt``
            for reading in text mode.
        encoding: An optional string that specifies the encoding of the
            underlying file. It defaults to ``None`` to match the default encoding of ``open``.
            It must stay ``None`` when a binary mode is selected.
        length: Nominal length of the datapipe. Any negative value (the default
            is ``-1``) means that no valid length is available.

    Raises:
        ValueError: If ``mode`` is not one of the accepted values, or if
            ``encoding`` is given together with a binary mode.

    Note:
        The opened file handles will be closed by Python's GC periodically. Users can choose
        to close them explicitly.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import FileLister, FileOpener, StreamReader
        >>> dp = FileLister(root=".").filter(lambda fname: fname.endswith('.txt'))
        >>> dp = FileOpener(dp)
        >>> dp = StreamReader(dp)
        >>> list(dp)
        [('./abc.txt', 'abc')]
    """

    def __init__(
        self,
        datapipe: Iterable[str],
        mode: str = "r",
        encoding: Optional[str] = None,
        length: int = -1,
    ):
        super().__init__()
        self.datapipe: Iterable = datapipe
        self.mode: str = mode
        self.encoding: Optional[str] = encoding

        # Only read modes are supported; reject anything else up front.
        if self.mode not in ("b", "t", "rb", "rt", "r"):
            raise ValueError(f"Invalid mode {mode}")
        # TODO: enforce typing for each instance based on mode, otherwise
        #       `argument_validation` with this DataPipe may be potentially broken

        # Binary streams are not decoded, so an encoding makes no sense there.
        if "b" in mode and encoding is not None:
            raise ValueError("binary mode doesn't take an encoding argument")

        self.length: int = length

    # Remove annotation due to 'IOBase' is a general type and true type
    # is determined at runtime based on mode. Some `DataPipe` requiring
    # a subtype would cause mypy error.
    def __iter__(self):
        # Delegate the actual opening to the shared helper, forwarding the
        # validated mode and encoding, and re-yield whatever it produces.
        yield from get_file_binaries_from_pathnames(
            self.datapipe, self.mode, self.encoding
        )

    def __len__(self):
        # Every negative value (not just the default -1) marks the length as
        # unknown. Returning a negative number from ``__len__`` would make
        # ``len()`` fail with ValueError; raising TypeError instead keeps the
        # "no length available" signal uniform for all invalid values.
        if self.length < 0:
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
        return self.length
77