xref: /aosp_15_r20/external/pytorch/test/package/test_glob_group.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: package/deploy"]
2
3from typing import Iterable
4
5from torch.package import GlobGroup
6from torch.testing._internal.common_utils import run_tests
7
8
9try:
10    from .common import PackageTestCase
11except ImportError:
12    # Support the case where we run this file directly.
13    from common import PackageTestCase
14
15
class TestGlobGroup(PackageTestCase):
    """Tests for ``GlobGroup`` pattern matching.

    As exercised below: ``*`` matches zero or more characters inside a single
    separator-delimited segment, ``**`` matches zero or more whole segments
    (and must itself be an entire segment), and a candidate matched by an
    ``exclude`` pattern is rejected even when an include pattern matches it.
    """

    def assertMatchesGlob(self, glob: GlobGroup, candidates: Iterable[str]) -> None:
        """Assert that ``glob`` matches every string in ``candidates``.

        The failure message names the glob and the offending candidate, so a
        failure inside a multi-candidate list is identifiable from the log.
        """
        for candidate in candidates:
            self.assertTrue(
                glob.matches(candidate),
                msg=f"expected {glob} to match {candidate!r}",
            )

    def assertNotMatchesGlob(self, glob: GlobGroup, candidates: Iterable[str]) -> None:
        """Assert that ``glob`` matches none of the strings in ``candidates``.

        Like ``assertMatchesGlob``, the failure message names the glob and the
        offending candidate.
        """
        for candidate in candidates:
            self.assertFalse(
                glob.matches(candidate),
                msg=f"expected {glob} not to match {candidate!r}",
            )

    def test_one_star(self):
        # The pattern requires exactly one segment after "torch": neither the
        # bare "torch" nor the two-segment "torch.foo.bar" qualifies.
        glob_group = GlobGroup("torch.*")
        self.assertMatchesGlob(glob_group, ["torch.foo", "torch.bar"])
        self.assertNotMatchesGlob(glob_group, ["tor.foo", "torch.foo.bar", "torch"])

    def test_one_star_middle(self):
        glob_group = GlobGroup("foo.*.bar")
        self.assertMatchesGlob(glob_group, ["foo.q.bar", "foo.foo.bar"])
        self.assertNotMatchesGlob(
            glob_group,
            [
                "foo.bar",
                "foo.foo",
                "outer.foo.inner.bar",
                "foo.q.bar.more",
                "foo.one.two.bar",
            ],
        )

    def test_one_star_partial(self):
        # "*" may match the empty string, so "fo.bar" matches "fo*.bar".
        glob_group = GlobGroup("fo*.bar")
        self.assertMatchesGlob(glob_group, ["fo.bar", "foo.bar", "foobar.bar"])
        self.assertNotMatchesGlob(glob_group, ["oij.bar", "f.bar", "foo"])

    def test_one_star_multiple_in_component(self):
        glob_group = GlobGroup("foo/a*.htm*", separator="/")
        self.assertMatchesGlob(glob_group, ["foo/a.html", "foo/a.htm", "foo/abc.html"])
        # Negative cases guard against a glob that accepts everything: wrong
        # prefix, wrong extension, an extra nested segment, a different root.
        self.assertNotMatchesGlob(
            glob_group, ["foo/b.html", "foo/a.txt", "foo/sub/a.html", "bar/a.html"]
        )

    def test_one_star_partial_extension(self):
        # "foo/.txt" matches because "*" may match zero characters.
        glob_group = GlobGroup("foo/*.txt", separator="/")
        self.assertMatchesGlob(
            glob_group, ["foo/hello.txt", "foo/goodbye.txt", "foo/.txt"]
        )
        self.assertNotMatchesGlob(
            glob_group, ["foo/bar/hello.txt", "bar/foo/hello.txt"]
        )

    def test_two_star(self):
        # "**" may match zero segments ("torch") but never a partial segment
        # ("torchvision").
        glob_group = GlobGroup("torch.**")
        self.assertMatchesGlob(
            glob_group, ["torch.foo", "torch.bar", "torch.foo.bar", "torch"]
        )
        self.assertNotMatchesGlob(glob_group, ["what.torch", "torchvision"])

    def test_two_star_end(self):
        glob_group = GlobGroup("**.torch")
        self.assertMatchesGlob(glob_group, ["torch", "bar.torch"])
        self.assertNotMatchesGlob(glob_group, ["visiontorch"])

    def test_two_star_middle(self):
        glob_group = GlobGroup("foo.**.baz")
        self.assertMatchesGlob(
            glob_group, ["foo.baz", "foo.bar.baz", "foo.bar1.bar2.baz"]
        )
        self.assertNotMatchesGlob(glob_group, ["foobaz", "foo.bar.baz.z"])

    def test_two_star_multiple(self):
        glob_group = GlobGroup("**/bar/**/*.txt", separator="/")
        self.assertMatchesGlob(
            glob_group, ["bar/baz.txt", "a/bar/b.txt", "bar/foo/c.txt"]
        )
        self.assertNotMatchesGlob(glob_group, ["baz.txt", "a/b.txt"])

    def test_raw_two_star(self):
        # Even a bare "**" does not match the empty string.
        glob_group = GlobGroup("**")
        self.assertMatchesGlob(glob_group, ["bar", "foo.bar", "ab.c.d.e"])
        self.assertNotMatchesGlob(glob_group, [""])

    def test_invalid_raw(self):
        # "**" must make up an entire segment; "**b" is rejected.
        with self.assertRaises(ValueError):
            GlobGroup("a.**b")

    def test_exclude(self):
        # Exclusion takes precedence over inclusion. "torch.barfoo" survives
        # because "foo" in the exclude pattern must be a whole segment.
        glob_group = GlobGroup("torch.**", exclude=["torch.**.foo"])
        self.assertMatchesGlob(
            glob_group,
            ["torch", "torch.bar", "torch.barfoo"],
        )
        self.assertNotMatchesGlob(
            glob_group,
            ["torch.foo", "torch.some.foo"],
        )

    def test_exclude_from_all(self):
        glob_group = GlobGroup("**", exclude=["foo.**", "bar.**"])
        self.assertMatchesGlob(glob_group, ["a", "hello", "anything.really"])
        self.assertNotMatchesGlob(glob_group, ["foo.bar", "foo.bar.baz"])

    def test_list_include_exclude(self):
        # Literal exclude patterns only reject exact matches, so "bar.bazother"
        # is still included.
        glob_group = GlobGroup(["foo", "bar.**"], exclude=["bar.baz", "bar.qux"])
        self.assertMatchesGlob(glob_group, ["foo", "bar.other", "bar.bazother"])
        self.assertNotMatchesGlob(glob_group, ["bar.baz", "bar.qux"])
117
118
if __name__ == "__main__":
    # Executed only when this file is run directly (not when collected by a
    # test runner); delegates to the imported ``run_tests`` driver.
    run_tests()
121