xref: /aosp_15_r20/external/pytorch/test/package/test_misc.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: package/deploy"]
2
3import inspect
4import os
5import platform
6import sys
7from io import BytesIO
8from pathlib import Path
9from textwrap import dedent
10from unittest import skipIf
11
12from torch.package import is_from_package, PackageExporter, PackageImporter
13from torch.package.package_exporter import PackagingError
14from torch.testing._internal.common_utils import (
15    IS_FBCODE,
16    IS_SANDCASTLE,
17    run_tests,
18    skipIfTorchDynamo,
19)
20
21
22try:
23    from .common import PackageTestCase
24except ImportError:
25    # Support the case where we run this file directly.
26    from common import PackageTestCase
27
28
class TestMisc(PackageTestCase):
    """Tests for one-off or random functionality. Try not to add to this!"""

    def test_file_structure(self):
        """
        Tests package's Directory structure representation of a zip file. Ensures
        that the returned Directory prints what is expected and filters
        inputs/outputs correctly.
        """
        buffer = BytesIO()

        export_plain = dedent(
            """\
                \u251c\u2500\u2500 .data
                \u2502   \u251c\u2500\u2500 extern_modules
                \u2502   \u251c\u2500\u2500 python_version
                \u2502   \u251c\u2500\u2500 serialization_id
                \u2502   \u2514\u2500\u2500 version
                \u251c\u2500\u2500 main
                \u2502   \u2514\u2500\u2500 main
                \u251c\u2500\u2500 obj
                \u2502   \u2514\u2500\u2500 obj.pkl
                \u251c\u2500\u2500 package_a
                \u2502   \u251c\u2500\u2500 __init__.py
                \u2502   \u2514\u2500\u2500 subpackage.py
                \u251c\u2500\u2500 byteorder
                \u2514\u2500\u2500 module_a.py
            """
        )
        export_include = dedent(
            """\
                \u251c\u2500\u2500 obj
                \u2502   \u2514\u2500\u2500 obj.pkl
                \u2514\u2500\u2500 package_a
                    \u2514\u2500\u2500 subpackage.py
            """
        )
        import_exclude = dedent(
            """\
                \u251c\u2500\u2500 .data
                \u2502   \u251c\u2500\u2500 extern_modules
                \u2502   \u251c\u2500\u2500 python_version
                \u2502   \u251c\u2500\u2500 serialization_id
                \u2502   \u2514\u2500\u2500 version
                \u251c\u2500\u2500 main
                \u2502   \u2514\u2500\u2500 main
                \u251c\u2500\u2500 obj
                \u2502   \u2514\u2500\u2500 obj.pkl
                \u251c\u2500\u2500 package_a
                \u2502   \u251c\u2500\u2500 __init__.py
                \u2502   \u2514\u2500\u2500 subpackage.py
                \u251c\u2500\u2500 byteorder
                \u2514\u2500\u2500 module_a.py
            """
        )

        def without_root_line(structure):
            # Drop the first line before comparing: Windows/iOS/Unix render it
            # differently for the same buffer, so only the tree below is stable.
            return dedent("\n".join(str(structure).split("\n")[1:]))

        with PackageExporter(buffer) as he:
            import module_a
            import package_a
            import package_a.subpackage

            obj = package_a.subpackage.PackageASubpackageObject()
            he.intern("**")
            he.save_module(module_a.__name__)
            he.save_module(package_a.__name__)
            he.save_pickle("obj", "obj.pkl", obj)
            he.save_text("main", "main", "my string")

        buffer.seek(0)
        hi = PackageImporter(buffer)

        file_structure = hi.file_structure()
        self.assertEqual(without_root_line(file_structure), export_plain)

        file_structure = hi.file_structure(include=["**/subpackage.py", "**/*.pkl"])
        self.assertEqual(without_root_line(file_structure), export_include)

        file_structure = hi.file_structure(exclude="**/*.storage")
        self.assertEqual(without_root_line(file_structure), import_exclude)

    def test_loaders_that_remap_files_work_ok(self):
        """
        When a custom loader reports a different filename for ``module_a``,
        the source saved into the package must come from that remapped file.
        """
        from importlib.abc import MetaPathFinder
        from importlib.machinery import SourceFileLoader
        from importlib.util import spec_from_loader

        class LoaderThatRemapsModuleA(SourceFileLoader):
            def get_filename(self, name):
                result = super().get_filename(name)
                if name == "module_a":
                    # Point at a sibling file so the packaged source can be
                    # told apart from the regular module_a source.
                    return os.path.join(
                        os.path.dirname(result), "module_a_remapped_path.py"
                    )
                else:
                    return result

        class FinderThatRemapsModuleA(MetaPathFinder):
            def find_spec(self, fullname, path, target=None):
                """Try to find the original spec for module_a using all the
                remaining meta_path finders."""
                if fullname != "module_a":
                    return None
                spec = None
                for finder in sys.meta_path:
                    if finder is self:
                        continue
                    if hasattr(finder, "find_spec"):
                        spec = finder.find_spec(fullname, path, target=target)
                    elif hasattr(finder, "load_module"):
                        # Legacy finder without find_spec: wrap it in a spec.
                        spec = spec_from_loader(fullname, finder)
                    if spec is not None:
                        break
                assert spec is not None and isinstance(spec.loader, SourceFileLoader)
                spec.loader = LoaderThatRemapsModuleA(
                    spec.loader.name, spec.loader.path
                )
                return spec

        remapping_finder = FinderThatRemapsModuleA()
        sys.meta_path.insert(0, remapping_finder)
        # clear it from sys.modules so that we use the custom finder next time
        # it gets imported
        sys.modules.pop("module_a", None)
        try:
            buffer = BytesIO()
            with PackageExporter(buffer) as he:
                import module_a

                he.intern("**")
                he.save_module(module_a.__name__)

            buffer.seek(0)
            hi = PackageImporter(buffer)
            self.assertIn("remapped_path", hi.get_source("module_a"))
        finally:
            # pop it again to ensure it does not mess up other tests
            sys.modules.pop("module_a", None)
            # Remove our finder by identity rather than by position, in case
            # something else was inserted at the front of sys.meta_path meanwhile.
            sys.meta_path.remove(remapping_finder)

    def test_python_version(self):
        """
        Tests that the current python version is stored in the package and is available
        via PackageImporter's python_version() method.
        """
        buffer = BytesIO()

        with PackageExporter(buffer) as he:
            from package_a.test_module import SimpleTest

            he.intern("**")
            obj = SimpleTest()
            he.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        hi = PackageImporter(buffer)

        self.assertEqual(hi.python_version(), platform.python_version())

    @skipIf(
        IS_FBCODE or IS_SANDCASTLE,
        "Tests that use temporary files are disabled in fbcode",
    )
    def test_load_python_version_from_package(self):
        """Tests loading a package with a python version embedded"""
        importer1 = PackageImporter(
            f"{Path(__file__).parent}/package_e/test_nn_module.pt"
        )
        # The checked-in package was produced under Python 3.9.7.
        self.assertEqual(importer1.python_version(), "3.9.7")

    def test_file_structure_has_file(self):
        """
        Test Directory's has_file() method.
        """
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            import package_a.subpackage

            he.intern("**")
            obj = package_a.subpackage.PackageASubpackageObject()
            he.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)

        importer = PackageImporter(buffer)
        file_structure = importer.file_structure()
        # has_file() matches the full file path; the extension-less module
        # path is not reported as a file.
        self.assertTrue(file_structure.has_file("package_a/subpackage.py"))
        self.assertFalse(file_structure.has_file("package_a/subpackage"))

    def test_exporter_content_lists(self):
        """
        Test content list API for PackageExporter's contained modules.
        """

        with PackageExporter(BytesIO()) as he:
            import package_b

            he.extern("package_b.subpackage_1")
            he.mock("package_b.subpackage_2")
            he.intern("**")
            he.save_pickle("obj", "obj.pkl", package_b.PackageBObject(["a"]))
            self.assertEqual(he.externed_modules(), ["package_b.subpackage_1"])
            self.assertEqual(he.mocked_modules(), ["package_b.subpackage_2"])
            self.assertEqual(
                he.interned_modules(),
                ["package_b", "package_b.subpackage_0.subsubpackage_0"],
            )
            self.assertEqual(he.get_rdeps("package_b.subpackage_2"), ["package_b"])

        with self.assertRaises(PackagingError):
            with PackageExporter(BytesIO()) as he:
                import package_b

                he.deny("package_b")
                he.save_pickle("obj", "obj.pkl", package_b.PackageBObject(["a"]))
                # NOTE(review): presumably the PackagingError is raised when the
                # exporter closes, so this assertion still runs first — confirm.
                self.assertEqual(he.denied_modules(), ["package_b"])

    def test_is_from_package(self):
        """is_from_package should work for objects and modules"""
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        mod = pi.import_module("package_a.subpackage")
        loaded_obj = pi.load_pickle("obj", "obj.pkl")

        self.assertFalse(is_from_package(package_a.subpackage))
        self.assertTrue(is_from_package(mod))

        self.assertFalse(is_from_package(obj))
        self.assertTrue(is_from_package(loaded_obj))

    def test_inspect_class(self):
        """Should be able to retrieve source for a packaged class."""
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        packaged_class = pi.import_module(
            "package_a.subpackage"
        ).PackageASubpackageObject
        regular_class = package_a.subpackage.PackageASubpackageObject

        packaged_src = inspect.getsourcelines(packaged_class)
        regular_src = inspect.getsourcelines(regular_class)
        self.assertEqual(packaged_src, regular_src)

    def test_dunder_package_present(self):
        """
        The attribute '__torch_package__' should be populated on imported modules.
        """
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        mod = pi.import_module("package_a.subpackage")
        self.assertTrue(hasattr(mod, "__torch_package__"))

    def test_dunder_package_works_from_package(self):
        """
        The attribute '__torch_package__' should be accessible from within
        the module itself, so that packaged code can detect whether it's
        being used in a packaged context or not.
        """
        import package_a.use_dunder_package as mod

        buffer = BytesIO()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_module(mod.__name__)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        imported_mod = pi.import_module(mod.__name__)
        self.assertTrue(imported_mod.is_from_package())
        self.assertFalse(mod.is_from_package())

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_std_lib_sys_hackery_checks(self):
        """
        The standard library performs sys.module assignment hackery which
        causes modules who do this hackery to fail on import. See
        https://github.com/pytorch/pytorch/issues/57490 for more information.
        """
        import package_a.std_sys_module_hacks

        buffer = BytesIO()
        mod = package_a.std_sys_module_hacks.Module()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", mod)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        mod = pi.load_pickle("obj", "obj.pkl")
        mod()
354
355
if __name__ == "__main__":
    # Allow running this file directly; run_tests comes from
    # torch.testing._internal.common_utils.
    run_tests()
358