1# mypy: allow-untyped-defs 2from io import IOBase 3from typing import Iterable, Optional, Tuple 4 5from torch.utils.data.datapipes._decorator import functional_datapipe 6from torch.utils.data.datapipes.datapipe import IterDataPipe 7from torch.utils.data.datapipes.utils.common import get_file_binaries_from_pathnames 8 9 10__all__ = [ 11 "FileOpenerIterDataPipe", 12] 13 14 15@functional_datapipe("open_files") 16class FileOpenerIterDataPipe(IterDataPipe[Tuple[str, IOBase]]): 17 r""" 18 Given pathnames, opens files and yield pathname and file stream in a tuple (functional name: ``open_files``). 19 20 Args: 21 datapipe: Iterable datapipe that provides pathnames 22 mode: An optional string that specifies the mode in which 23 the file is opened by ``open()``. It defaults to ``r``, other options are 24 ``b`` for reading in binary mode and ``t`` for text mode. 25 encoding: An optional string that specifies the encoding of the 26 underlying file. It defaults to ``None`` to match the default encoding of ``open``. 27 length: Nominal length of the datapipe 28 29 Note: 30 The opened file handles will be closed by Python's GC periodically. Users can choose 31 to close them explicitly. 32 33 Example: 34 >>> # xdoctest: +SKIP 35 >>> from torchdata.datapipes.iter import FileLister, FileOpener, StreamReader 36 >>> dp = FileLister(root=".").filter(lambda fname: fname.endswith('.txt')) 37 >>> dp = FileOpener(dp) 38 >>> dp = StreamReader(dp) 39 >>> list(dp) 40 [('./abc.txt', 'abc')] 41 """ 42 43 def __init__( 44 self, 45 datapipe: Iterable[str], 46 mode: str = "r", 47 encoding: Optional[str] = None, 48 length: int = -1, 49 ): 50 super().__init__() 51 self.datapipe: Iterable = datapipe 52 self.mode: str = mode 53 self.encoding: Optional[str] = encoding 54 55 if self.mode not in ("b", "t", "rb", "rt", "r"): 56 raise ValueError(f"Invalid mode {mode}") 57 # TODO: enforce typing for each instance based on mode, otherwise 58 # `argument_validation` with this DataPipe may be potentially broken 59 60 if "b" in mode and encoding is not None: 61 raise ValueError("binary mode doesn't take an encoding argument") 62 63 self.length: int = length 64 65 # Remove annotation due to 'IOBase' is a general type and true type 66 # is determined at runtime based on mode. Some `DataPipe` requiring 67 # a subtype would cause mypy error. 68 def __iter__(self): 69 yield from get_file_binaries_from_pathnames( 70 self.datapipe, self.mode, self.encoding 71 ) 72 73 def __len__(self): 74 if self.length == -1: 75 raise TypeError(f"{type(self).__name__} instance doesn't have valid length") 76 return self.length 77