xref: /aosp_15_r20/external/pytorch/torch/utils/data/datapipes/iter/filelister.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import Iterator, List, Sequence, Union
3
4from torch.utils.data.datapipes._decorator import functional_datapipe
5from torch.utils.data.datapipes.datapipe import IterDataPipe
6from torch.utils.data.datapipes.iter.utils import IterableWrapperIterDataPipe
7from torch.utils.data.datapipes.utils.common import get_file_pathnames_from_root
8
9
10__all__ = ["FileListerIterDataPipe"]
11
12
@functional_datapipe("list_files")
class FileListerIterDataPipe(IterDataPipe[str]):
    r"""
    Given path(s) to the root directory, yields file pathname(s) (path + filename) of files within the root directory.

    Multiple root directories can be provided (functional name: ``list_files``).

    Args:
        root: Root directory, a sequence of root directories, or an ``IterDataPipe``
            that yields root directories
        masks: Unix style filter string or string list for filtering file name(s)
        recursive: Whether to return pathname from nested directories or not
        abspath: Whether to return relative pathname or absolute pathname
        non_deterministic: Whether to return pathname in sorted order or not.
            If ``False``, the results yielded from each root directory will be sorted
        length: Nominal length of the datapipe. A negative value (the default is ``-1``)
            means the length is unknown, in which case ``len()`` raises ``TypeError``

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import FileLister
        >>> dp = FileLister(root=".", recursive=True)
        >>> list(dp)
        ['example.py', './data/data.tar']
    """

    def __init__(
        self,
        root: Union[str, Sequence[str], IterDataPipe] = ".",
        masks: Union[str, List[str]] = "",
        *,
        recursive: bool = False,
        abspath: bool = False,
        non_deterministic: bool = False,
        length: int = -1,
    ) -> None:
        super().__init__()
        # A bare string is a single root, not a sequence of one-character roots.
        if isinstance(root, str):
            root = [root]
        # Normalize plain sequences to a datapipe so ``__iter__`` only ever
        # has to deal with one kind of source.
        if not isinstance(root, IterDataPipe):
            root = IterableWrapperIterDataPipe(root)
        self.datapipe: IterDataPipe = root
        self.masks: Union[str, List[str]] = masks
        self.recursive: bool = recursive
        self.abspath: bool = abspath
        self.non_deterministic: bool = non_deterministic
        self.length: int = length

    def __iter__(self) -> Iterator[str]:
        """Yield the pathnames found under each root, one root at a time."""
        for path in self.datapipe:
            yield from get_file_pathnames_from_root(
                path, self.masks, self.recursive, self.abspath, self.non_deterministic
            )

    def __len__(self) -> int:
        """Return the nominal length supplied at construction time.

        Raises:
            TypeError: If no valid (non-negative) length was supplied.
        """
        # Every negative value means "unknown", not only the default -1.
        # Returning a negative number from ``__len__`` would make ``len()``
        # fail with an unrelated ValueError; TypeError is the signal that
        # the object simply has no length.
        if self.length < 0:
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
        return self.length
69