xref: /aosp_15_r20/external/pytorch/test/package/test_directory_reader.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: package/deploy"]
2
3import os
4import zipfile
5from sys import version_info
6from tempfile import TemporaryDirectory
7from textwrap import dedent
8from unittest import skipIf
9
10import torch
11from torch.package import PackageExporter, PackageImporter
12from torch.testing._internal.common_utils import (
13    IS_FBCODE,
14    IS_SANDCASTLE,
15    IS_WINDOWS,
16    run_tests,
17)
18
19
# torchvision is an optional dependency: record whether it could be imported
# so that tests needing it are skipped rather than failing at import time.
try:
    from torchvision.models import resnet18
except ImportError:
    HAS_TORCHVISION = False
else:
    HAS_TORCHVISION = True

# Decorator that skips a test when torchvision is not installed.
skipIfNoTorchVision = skipIf(not HAS_TORCHVISION, "no torchvision")
27
28
try:
    # First try the relative import, which works when this file is loaded as
    # a module of its enclosing package.
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase
34
35from pathlib import Path
36
37
# Directory that contains this test file.
packaging_directory = Path(__file__).parent
39
40
@skipIf(
    IS_FBCODE or IS_SANDCASTLE or IS_WINDOWS,
    "Tests that use temporary files are disabled in fbcode",
)
class DirectoryReaderTest(PackageTestCase):
    """Tests use of DirectoryReader as accessor for opened packages.

    Each test writes a package archive with ``PackageExporter``, unzips it
    into a temporary directory and then opens the unzipped directory with
    ``PackageImporter``.
    """

    def _extract_package(self, filename, temp_dir):
        """Unzip the package archive ``filename`` into ``temp_dir``.

        The archive is opened in a ``with`` block so that its file handle is
        closed as soon as extraction finishes.

        Args:
            filename: Path of the package archive written by PackageExporter.
            temp_dir: Existing directory to extract the archive into.

        Returns:
            ``Path(temp_dir) / Path(filename).name``, the location the tests
            pass to ``PackageImporter``.
            NOTE(review): this relies on the archive unpacking into a
            top-level directory named after the archive file.
        """
        with zipfile.ZipFile(filename, "r") as zip_file:
            zip_file.extractall(path=temp_dir)
        return Path(temp_dir) / Path(filename).name

    @skipIfNoTorchVision
    @skipIf(
        True,
        "Does not work with latest TorchVision, see https://github.com/pytorch/pytorch/issues/81115",
    )
    def test_loading_pickle(self):
        """
        Test basic saving and loading of modules and pickles from a DirectoryReader.
        """
        resnet = resnet18()

        filename = self.temp()
        with PackageExporter(filename) as e:
            e.intern("**")
            e.save_pickle("model", "model.pkl", resnet)

        with TemporaryDirectory() as temp_dir:
            importer = PackageImporter(self._extract_package(filename, temp_dir))
            dir_mod = importer.load_pickle("model", "model.pkl")
            # Named ``model_input`` so the ``input`` builtin is not shadowed.
            model_input = torch.rand(1, 3, 224, 224)
            self.assertEqual(dir_mod(model_input), resnet(model_input))

    def test_loading_module(self):
        """
        Test basic saving and loading of a package from a DirectoryReader.
        """
        import package_a

        filename = self.temp()
        with PackageExporter(filename) as e:
            e.save_module("package_a")

        with TemporaryDirectory() as temp_dir:
            dir_importer = PackageImporter(self._extract_package(filename, temp_dir))
            dir_mod = dir_importer.import_module("package_a")
            self.assertEqual(dir_mod.result, package_a.result)

    def test_loading_has_record(self):
        """
        Test DirectoryReader's has_record().
        """
        import package_a  # noqa: F401

        filename = self.temp()
        with PackageExporter(filename) as e:
            e.save_module("package_a")

        with TemporaryDirectory() as temp_dir:
            dir_importer = PackageImporter(self._extract_package(filename, temp_dir))
            # A file inside the package is a record; the package directory
            # itself is not.
            self.assertTrue(dir_importer.zip_reader.has_record("package_a/__init__.py"))
            self.assertFalse(dir_importer.zip_reader.has_record("package_a"))

    @skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
    def test_resource_reader(self):
        """Tests DirectoryReader as the base for get_resource_reader."""
        filename = self.temp()
        with PackageExporter(filename) as pe:
            # Layout looks like:
            #    package
            #    |-- one/
            #    |   |-- a.txt
            #    |   |-- b.txt
            #    |   |-- c.txt
            #    |   +-- three/
            #    |       |-- d.txt
            #    |       +-- e.txt
            #    +-- two/
            #       |-- f.txt
            #       +-- g.txt
            pe.save_text("one", "a.txt", "hello, a!")
            pe.save_text("one", "b.txt", "hello, b!")
            pe.save_text("one", "c.txt", "hello, c!")

            pe.save_text("one.three", "d.txt", "hello, d!")
            pe.save_text("one.three", "e.txt", "hello, e!")

            pe.save_text("two", "f.txt", "hello, f!")
            pe.save_text("two", "g.txt", "hello, g!")

        with TemporaryDirectory() as temp_dir:
            package_dir = self._extract_package(filename, temp_dir)
            importer = PackageImporter(package_dir)
            reader_one = importer.get_resource_reader("one")

            # Different behavior from still zipped archives
            resource_path = os.path.join(package_dir, "one", "a.txt")
            self.assertEqual(reader_one.resource_path("a.txt"), resource_path)

            self.assertTrue(reader_one.is_resource("a.txt"))
            self.assertEqual(
                reader_one.open_resource("a.txt").getbuffer(), b"hello, a!"
            )
            # "three" is a sub-package directory, so it is listed in the
            # contents but is not reported as a resource.
            self.assertFalse(reader_one.is_resource("three"))
            self.assertSequenceEqual(
                sorted(reader_one.contents()), ["a.txt", "b.txt", "c.txt", "three"]
            )

            reader_two = importer.get_resource_reader("two")
            self.assertTrue(reader_two.is_resource("f.txt"))
            self.assertEqual(
                reader_two.open_resource("f.txt").getbuffer(), b"hello, f!"
            )
            self.assertSequenceEqual(sorted(reader_two.contents()), ["f.txt", "g.txt"])

            reader_one_three = importer.get_resource_reader("one.three")
            self.assertTrue(reader_one_three.is_resource("d.txt"))
            self.assertEqual(
                reader_one_three.open_resource("d.txt").getbuffer(), b"hello, d!"
            )
            self.assertSequenceEqual(
                sorted(reader_one_three.contents()), ["d.txt", "e.txt"]
            )

            self.assertIsNone(importer.get_resource_reader("nonexistent_package"))

    @skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
    def test_package_resource_access(self):
        """Packaged modules should be able to use the importlib.resources API to access
        resources saved in the package.
        """
        mod_src = dedent(
            """\
            import importlib.resources
            import my_cool_resources

            def secret_message():
                return importlib.resources.read_text(my_cool_resources, 'sekrit.txt')
            """
        )
        filename = self.temp()
        with PackageExporter(filename) as pe:
            pe.save_source_string("foo.bar", mod_src)
            pe.save_text("my_cool_resources", "sekrit.txt", "my sekrit plays")

        with TemporaryDirectory() as temp_dir:
            dir_importer = PackageImporter(self._extract_package(filename, temp_dir))
            self.assertEqual(
                dir_importer.import_module("foo.bar").secret_message(),
                "my sekrit plays",
            )

    @skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
    def test_importer_access(self):
        """Packaged code should be able to load text and binary resources
        through the ``torch_package_importer`` module.
        """
        filename = self.temp()
        with PackageExporter(filename) as he:
            he.save_text("main", "main", "my string")
            he.save_binary("main", "main_binary", b"my string")
            src = dedent(
                """\
                import importlib
                import torch_package_importer as resources

                t = resources.load_text('main', 'main')
                b = resources.load_binary('main', 'main_binary')
                """
            )
            he.save_source_string("main", src, is_package=True)

        with TemporaryDirectory() as temp_dir:
            dir_importer = PackageImporter(self._extract_package(filename, temp_dir))
            m = dir_importer.import_module("main")
            self.assertEqual(m.t, "my string")
            self.assertEqual(m.b, b"my string")

    @skipIf(version_info < (3, 7), "ResourceReader API introduced in Python 3.7")
    def test_resource_access_by_path(self):
        """
        Tests that packaged code can use importlib.resources.path.
        """
        filename = self.temp()
        with PackageExporter(filename) as e:
            e.save_binary("string_module", "my_string", b"my string")
            src = dedent(
                """\
                import importlib.resources
                import string_module

                with importlib.resources.path(string_module, 'my_string') as path:
                    with open(path, mode='r', encoding='utf-8') as f:
                        s = f.read()
                """
            )
            e.save_source_string("main", src, is_package=True)

        with TemporaryDirectory() as temp_dir:
            dir_importer = PackageImporter(self._extract_package(filename, temp_dir))
            m = dir_importer.import_module("main")
            self.assertEqual(m.s, "my string")

    def test_scriptobject_failure_message(self):
        """
        Test basic saving and loading of a ScriptModule in a directory.
        Currently not supported.
        """
        from package_a.test_module import ModWithTensor

        scripted_mod = torch.jit.script(ModWithTensor(torch.rand(1, 2, 3)))

        filename = self.temp()
        with PackageExporter(filename) as e:
            e.save_pickle("res", "mod.pkl", scripted_mod)

        with TemporaryDirectory() as temp_dir:
            package_dir = self._extract_package(filename, temp_dir)
            # Only opening the directory and loading the pickle are expected
            # to raise; extraction is kept outside the assertion so that an
            # unrelated failure there is not mistaken for the expected error.
            with self.assertRaisesRegex(
                RuntimeError,
                "Loading ScriptObjects from a PackageImporter created from a "
                "directory is not supported. Use a package archive file instead.",
            ):
                dir_importer = PackageImporter(package_dir)
                dir_importer.load_pickle("res", "mod.pkl")
287
288
# Invoke the test runner when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()
291