# xref: /aosp_15_r20/external/pytorch/test/package/test_resources.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["oncall: package/deploy"]
2
3from io import BytesIO
4from sys import version_info
5from textwrap import dedent
6from unittest import skipIf
7
8from torch.package import PackageExporter, PackageImporter
9from torch.testing._internal.common_utils import run_tests
10
11
12try:
13    from .common import PackageTestCase
14except ImportError:
15    # Support the case where we run this file directly.
16    from common import PackageTestCase
17
18
@skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
class TestResources(PackageTestCase):
    """Tests for access APIs for packaged resources."""

    def test_resource_reader(self):
        """Test compliance with the get_resource_reader importlib API.

        Builds a small tree of text resources in an in-memory package and
        checks ``resource_path``, ``is_resource``, ``open_resource`` and
        ``contents`` on the readers for ``one``, ``two`` and the nested
        ``one.three``.
        """
        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
            # Layout looks like:
            #    package
            #    |-- one/
            #    |   |-- a.txt
            #    |   |-- b.txt
            #    |   |-- c.txt
            #    |   +-- three/
            #    |       |-- d.txt
            #    |       +-- e.txt
            #    +-- two/
            #       |-- f.txt
            #       +-- g.txt
            pe.save_text("one", "a.txt", "hello, a!")
            pe.save_text("one", "b.txt", "hello, b!")
            pe.save_text("one", "c.txt", "hello, c!")

            pe.save_text("one.three", "d.txt", "hello, d!")
            pe.save_text("one.three", "e.txt", "hello, e!")

            pe.save_text("two", "f.txt", "hello, f!")
            pe.save_text("two", "g.txt", "hello, g!")

        buffer.seek(0)
        importer = PackageImporter(buffer)

        reader_one = importer.get_resource_reader("one")
        # The package was written to an in-memory buffer, so there is no real
        # filesystem path to hand back for a resource.
        with self.assertRaises(FileNotFoundError):
            reader_one.resource_path("a.txt")

        self.assertTrue(reader_one.is_resource("a.txt"))
        # open_resource() hands back a binary file-like object: read it through
        # the generic file API (not a BytesIO-specific method) and close it.
        with reader_one.open_resource("a.txt") as resource:
            self.assertEqual(resource.read(), b"hello, a!")
        # "three" is a subpackage: contents() lists it, but it is not a resource.
        self.assertFalse(reader_one.is_resource("three"))
        reader_one_contents = list(reader_one.contents())
        self.assertSequenceEqual(
            reader_one_contents, ["a.txt", "b.txt", "c.txt", "three"]
        )

        reader_two = importer.get_resource_reader("two")
        self.assertTrue(reader_two.is_resource("f.txt"))
        with reader_two.open_resource("f.txt") as resource:
            self.assertEqual(resource.read(), b"hello, f!")
        reader_two_contents = list(reader_two.contents())
        self.assertSequenceEqual(reader_two_contents, ["f.txt", "g.txt"])

        reader_one_three = importer.get_resource_reader("one.three")
        self.assertTrue(reader_one_three.is_resource("d.txt"))
        with reader_one_three.open_resource("d.txt") as resource:
            self.assertEqual(resource.read(), b"hello, d!")
        reader_one_three_contents = list(reader_one_three.contents())
        self.assertSequenceEqual(reader_one_three_contents, ["d.txt", "e.txt"])

        # Unknown packages yield no reader rather than raising.
        self.assertIsNone(importer.get_resource_reader("nonexistent_package"))

    def test_package_resource_access(self):
        """Packaged modules should be able to use the importlib.resources API to access
        resources saved in the package.
        """
        mod_src = dedent(
            """\
            import importlib.resources
            import my_cool_resources

            def secret_message():
                return importlib.resources.read_text(my_cool_resources, 'sekrit.txt')
            """
        )
        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
            pe.save_source_string("foo.bar", mod_src)
            pe.save_text("my_cool_resources", "sekrit.txt", "my sekrit plays")

        buffer.seek(0)
        importer = PackageImporter(buffer)
        self.assertEqual(
            importer.import_module("foo.bar").secret_message(), "my sekrit plays"
        )

    def test_importer_access(self):
        """Packaged code can import ``torch_package_importer`` and use its
        ``load_text``/``load_binary`` to read resources saved in the package.
        """
        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
            pe.save_text("main", "main", "my string")
            pe.save_binary("main", "main_binary", b"my string")
            src = dedent(
                """\
                import importlib
                import torch_package_importer as resources

                t = resources.load_text('main', 'main')
                b = resources.load_binary('main', 'main_binary')
                """
            )
            pe.save_source_string("main", src, is_package=True)
        buffer.seek(0)
        importer = PackageImporter(buffer)
        module = importer.import_module("main")
        # Text resources come back as str, binary resources as bytes.
        self.assertEqual(module.t, "my string")
        self.assertEqual(module.b, b"my string")

    def test_resource_access_by_path(self):
        """
        Tests that packaged code can use importlib.resources.path to obtain a
        path to a packaged resource that it can open() and read.
        """
        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
            pe.save_binary("string_module", "my_string", b"my string")
            src = dedent(
                """\
                import importlib.resources
                import string_module

                with importlib.resources.path(string_module, 'my_string') as path:
                    with open(path, mode='r', encoding='utf-8') as f:
                        s = f.read()
                """
            )
            pe.save_source_string("main", src, is_package=True)
        buffer.seek(0)
        importer = PackageImporter(buffer)
        module = importer.import_module("main")
        self.assertEqual(module.s, "my string")
147
148
if __name__ == "__main__":
    # Run this module through PyTorch's test harness when executed directly.
    run_tests()
151