# mypy: allow-untyped-defs
import os
from typing import Iterator, List, Sequence, Union

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.iter.utils import IterableWrapperIterDataPipe
from torch.utils.data.datapipes.utils.common import get_file_pathnames_from_root


__all__ = ["FileListerIterDataPipe"]


@functional_datapipe("list_files")
class FileListerIterDataPipe(IterDataPipe[str]):
    r"""
    Given path(s) to the root directory, yields file pathname(s) (path + filename) of files within the root directory.

    Multiple root directories can be provided (functional name: ``list_files``).

    Args:
        root: Root directory, a sequence of root directories, or an ``IterDataPipe``
            that yields root directories. A single root may be given as a ``str``
            or as a path-like object (e.g. ``pathlib.Path``); path-like roots are
            converted to ``str`` before use.
        masks: Unix style filter string or string list for filtering file name(s)
        recursive: Whether to return pathname from nested directories or not
        abspath: Whether to return relative pathname or absolute pathname
        non_deterministic: Whether to return pathname in sorted order or not.
            If ``False``, the results yielded from each root directory will be sorted
        length: Nominal length of the datapipe. The default ``-1`` means that no
            length is available, in which case ``len()`` raises ``TypeError``.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import FileLister
        >>> dp = FileLister(root=".", recursive=True)
        >>> list(dp)
        ['example.py', './data/data.tar']
    """

    def __init__(
        self,
        root: Union[str, "os.PathLike[str]", Sequence[str], IterDataPipe] = ".",
        masks: Union[str, List[str]] = "",
        *,
        recursive: bool = False,
        abspath: bool = False,
        non_deterministic: bool = False,
        length: int = -1,
    ) -> None:
        super().__init__()
        # A lone path (``str`` or path-like) is one root, not a sequence of roots.
        # Without this check a ``pathlib.Path`` would be handed to the iterable
        # wrapper as-is and fail on iteration, because a ``Path`` is not iterable.
        # ``os.fsdecode`` leaves ``str`` untouched and turns path-like objects
        # into ``str``.
        if isinstance(root, (str, os.PathLike)):
            root = [os.fsdecode(root)]
        # Anything that is not already a datapipe is wrapped so that ``__iter__``
        # can treat every kind of input uniformly.
        if not isinstance(root, IterDataPipe):
            root = IterableWrapperIterDataPipe(root)
        self.datapipe: IterDataPipe = root
        self.masks: Union[str, List[str]] = masks
        self.recursive: bool = recursive
        self.abspath: bool = abspath
        self.non_deterministic: bool = non_deterministic
        self.length: int = length

    def __iter__(self) -> Iterator[str]:
        """Yield the matching file pathnames of each root directory in turn."""
        for path in self.datapipe:
            yield from get_file_pathnames_from_root(
                path, self.masks, self.recursive, self.abspath, self.non_deterministic
            )

    def __len__(self):
        """Return the nominal length passed at construction.

        Raises:
            TypeError: If no length was provided (``length`` is ``-1``).
        """
        if self.length == -1:
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
        return self.length