xref: /aosp_15_r20/external/pytorch/test/package/test_misc.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: package/deploy"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport inspect
4*da0073e9SAndroid Build Coastguard Workerimport os
5*da0073e9SAndroid Build Coastguard Workerimport platform
6*da0073e9SAndroid Build Coastguard Workerimport sys
7*da0073e9SAndroid Build Coastguard Workerfrom io import BytesIO
8*da0073e9SAndroid Build Coastguard Workerfrom pathlib import Path
9*da0073e9SAndroid Build Coastguard Workerfrom textwrap import dedent
10*da0073e9SAndroid Build Coastguard Workerfrom unittest import skipIf
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Workerfrom torch.package import is_from_package, PackageExporter, PackageImporter
13*da0073e9SAndroid Build Coastguard Workerfrom torch.package.package_exporter import PackagingError
14*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
15*da0073e9SAndroid Build Coastguard Worker    IS_FBCODE,
16*da0073e9SAndroid Build Coastguard Worker    IS_SANDCASTLE,
17*da0073e9SAndroid Build Coastguard Worker    run_tests,
18*da0073e9SAndroid Build Coastguard Worker    skipIfTorchDynamo,
19*da0073e9SAndroid Build Coastguard Worker)
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Workertry:
23*da0073e9SAndroid Build Coastguard Worker    from .common import PackageTestCase
24*da0073e9SAndroid Build Coastguard Workerexcept ImportError:
25*da0073e9SAndroid Build Coastguard Worker    # Support the case where we run this file directly.
26*da0073e9SAndroid Build Coastguard Worker    from common import PackageTestCase
27*da0073e9SAndroid Build Coastguard Worker
28*da0073e9SAndroid Build Coastguard Worker
class TestMisc(PackageTestCase):
    """Tests for one-off or random functionality. Try not to add to this!"""

    @staticmethod
    def _tree_without_root(file_structure):
        """Return ``str(file_structure)`` without its first line, dedented.

        The first line is dropped because Windows/iOS/Unix render the root
        entry for the in-memory buffer differently, so it cannot be compared
        against a fixed expectation.
        """
        return dedent("\n".join(str(file_structure).split("\n")[1:]))

    def _roundtrip_pickled_obj(self, obj):
        """Package ``obj`` with every dependency interned, then re-open it.

        ``obj`` is saved under package ``"obj"`` as resource ``"obj.pkl"``.

        Returns:
            A PackageImporter reading the freshly written package.
        """
        buffer = BytesIO()
        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)
        buffer.seek(0)
        return PackageImporter(buffer)

    def test_file_structure(self):
        """
        Tests package's Directory structure representation of a zip file. Ensures
        that the returned Directory prints what is expected and filters
        inputs/outputs correctly.
        """
        buffer = BytesIO()

        export_plain = dedent(
            """\
                \u251c\u2500\u2500 .data
                \u2502   \u251c\u2500\u2500 extern_modules
                \u2502   \u251c\u2500\u2500 python_version
                \u2502   \u251c\u2500\u2500 serialization_id
                \u2502   \u2514\u2500\u2500 version
                \u251c\u2500\u2500 main
                \u2502   \u2514\u2500\u2500 main
                \u251c\u2500\u2500 obj
                \u2502   \u2514\u2500\u2500 obj.pkl
                \u251c\u2500\u2500 package_a
                \u2502   \u251c\u2500\u2500 __init__.py
                \u2502   \u2514\u2500\u2500 subpackage.py
                \u251c\u2500\u2500 byteorder
                \u2514\u2500\u2500 module_a.py
            """
        )
        export_include = dedent(
            """\
                \u251c\u2500\u2500 obj
                \u2502   \u2514\u2500\u2500 obj.pkl
                \u2514\u2500\u2500 package_a
                    \u2514\u2500\u2500 subpackage.py
            """
        )
        # Nothing in this package matches "**/*.storage", so excluding that
        # pattern is expected to leave the tree exactly as in export_plain.
        import_exclude = export_plain

        with PackageExporter(buffer) as he:
            import module_a
            import package_a
            import package_a.subpackage

            obj = package_a.subpackage.PackageASubpackageObject()
            he.intern("**")
            he.save_module(module_a.__name__)
            he.save_module(package_a.__name__)
            he.save_pickle("obj", "obj.pkl", obj)
            he.save_text("main", "main", "my string")

        buffer.seek(0)
        hi = PackageImporter(buffer)

        self.assertEqual(
            self._tree_without_root(hi.file_structure()),
            export_plain,
        )
        self.assertEqual(
            self._tree_without_root(
                hi.file_structure(include=["**/subpackage.py", "**/*.pkl"])
            ),
            export_include,
        )
        self.assertEqual(
            self._tree_without_root(hi.file_structure(exclude="**/*.storage")),
            import_exclude,
        )

    def test_loaders_that_remap_files_work_ok(self):
        """
        A loader whose get_filename() reports a different path than the one
        the module was found at should be honored when packaging: the source
        saved for module_a is expected to come from the remapped file (it must
        contain "remapped_path").
        """
        from importlib.abc import MetaPathFinder
        from importlib.machinery import SourceFileLoader
        from importlib.util import spec_from_loader

        class LoaderThatRemapsModuleA(SourceFileLoader):
            def get_filename(self, name):
                result = super().get_filename(name)
                if name == "module_a":
                    return os.path.join(
                        os.path.dirname(result), "module_a_remapped_path.py"
                    )
                return result

        class FinderThatRemapsModuleA(MetaPathFinder):
            def find_spec(self, fullname, path, target=None):
                """Try to find the original spec for module_a using all the
                remaining meta_path finders, then swap in the remapping loader."""
                if fullname != "module_a":
                    return None
                spec = None
                for finder in sys.meta_path:
                    if finder is self:
                        continue
                    if hasattr(finder, "find_spec"):
                        spec = finder.find_spec(fullname, path, target=target)
                    elif hasattr(finder, "load_module"):
                        # Legacy (pre-PEP 451) finder that doubles as a loader.
                        spec = spec_from_loader(fullname, finder)
                    if spec is not None:
                        break
                assert spec is not None and isinstance(spec.loader, SourceFileLoader)
                spec.loader = LoaderThatRemapsModuleA(
                    spec.loader.name, spec.loader.path
                )
                return spec

        remapping_finder = FinderThatRemapsModuleA()
        sys.meta_path.insert(0, remapping_finder)
        # Clear module_a from sys.modules so the next import goes through the
        # custom finder.
        sys.modules.pop("module_a", None)
        try:
            buffer = BytesIO()
            with PackageExporter(buffer) as he:
                import module_a

                he.intern("**")
                he.save_module(module_a.__name__)

            buffer.seek(0)
            hi = PackageImporter(buffer)
            self.assertIn("remapped_path", hi.get_source("module_a"))
        finally:
            # Pop module_a again so the remapped version does not leak into
            # other tests. Remove exactly the finder we inserted, even if
            # something else has been pushed onto sys.meta_path since.
            sys.modules.pop("module_a", None)
            if remapping_finder in sys.meta_path:
                sys.meta_path.remove(remapping_finder)

    def test_python_version(self):
        """
        Tests that the current python version is stored in the package and is available
        via PackageImporter's python_version() method.
        """
        from package_a.test_module import SimpleTest

        hi = self._roundtrip_pickled_obj(SimpleTest())
        self.assertEqual(hi.python_version(), platform.python_version())

    @skipIf(
        IS_FBCODE or IS_SANDCASTLE,
        "Tests that use temporary files are disabled in fbcode",
    )
    def test_load_python_version_from_package(self):
        """Tests loading a checked-in package that has a python version embedded."""
        package_path = Path(__file__).parent / "package_e" / "test_nn_module.pt"
        importer1 = PackageImporter(str(package_path))
        self.assertEqual(importer1.python_version(), "3.9.7")

    def test_file_structure_has_file(self):
        """
        Test Directory's has_file() method.
        """
        import package_a.subpackage

        importer = self._roundtrip_pickled_obj(
            package_a.subpackage.PackageASubpackageObject()
        )
        file_structure = importer.file_structure()
        self.assertTrue(file_structure.has_file("package_a/subpackage.py"))
        self.assertFalse(file_structure.has_file("package_a/subpackage"))

    def test_exporter_content_lists(self):
        """
        Test content list API for PackageExporter's contained modules.

        Also checks that denying a module the pickle depends on makes the
        export fail with PackagingError.
        """

        with PackageExporter(BytesIO()) as he:
            import package_b

            he.extern("package_b.subpackage_1")
            he.mock("package_b.subpackage_2")
            he.intern("**")
            he.save_pickle("obj", "obj.pkl", package_b.PackageBObject(["a"]))
            self.assertEqual(he.externed_modules(), ["package_b.subpackage_1"])
            self.assertEqual(he.mocked_modules(), ["package_b.subpackage_2"])
            self.assertEqual(
                he.interned_modules(),
                ["package_b", "package_b.subpackage_0.subsubpackage_0"],
            )
            self.assertEqual(he.get_rdeps("package_b.subpackage_2"), ["package_b"])

        with self.assertRaises(PackagingError):
            with PackageExporter(BytesIO()) as he:
                import package_b

                he.deny("package_b")
                he.save_pickle("obj", "obj.pkl", package_b.PackageBObject(["a"]))
                self.assertEqual(he.denied_modules(), ["package_b"])

    def test_is_from_package(self):
        """is_from_package should work for objects and modules"""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        pi = self._roundtrip_pickled_obj(obj)
        mod = pi.import_module("package_a.subpackage")
        loaded_obj = pi.load_pickle("obj", "obj.pkl")

        self.assertFalse(is_from_package(package_a.subpackage))
        self.assertTrue(is_from_package(mod))

        self.assertFalse(is_from_package(obj))
        self.assertTrue(is_from_package(loaded_obj))

    def test_inspect_class(self):
        """Should be able to retrieve source for a packaged class."""
        import package_a.subpackage

        pi = self._roundtrip_pickled_obj(
            package_a.subpackage.PackageASubpackageObject()
        )
        packaged_class = pi.import_module(
            "package_a.subpackage"
        ).PackageASubpackageObject
        regular_class = package_a.subpackage.PackageASubpackageObject

        packaged_src = inspect.getsourcelines(packaged_class)
        regular_src = inspect.getsourcelines(regular_class)
        self.assertEqual(packaged_src, regular_src)

    def test_dunder_package_present(self):
        """
        The attribute '__torch_package__' should be populated on imported modules.
        """
        import package_a.subpackage

        pi = self._roundtrip_pickled_obj(
            package_a.subpackage.PackageASubpackageObject()
        )
        mod = pi.import_module("package_a.subpackage")
        self.assertTrue(hasattr(mod, "__torch_package__"))

    def test_dunder_package_works_from_package(self):
        """
        The attribute '__torch_package__' should be accessible from within
        the module itself, so that packaged code can detect whether it's
        being used in a packaged context or not.
        """
        import package_a.use_dunder_package as mod

        buffer = BytesIO()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_module(mod.__name__)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        imported_mod = pi.import_module(mod.__name__)
        self.assertTrue(imported_mod.is_from_package())
        self.assertFalse(mod.is_from_package())

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_std_lib_sys_hackery_checks(self):
        """
        The standard library performs sys.module assignment hackery which
        causes modules who do this hackery to fail on import. See
        https://github.com/pytorch/pytorch/issues/57490 for more information.
        """
        import package_a.std_sys_module_hacks

        pi = self._roundtrip_pickled_obj(package_a.std_sys_module_hacks.Module())
        loaded_mod = pi.load_pickle("obj", "obj.pkl")
        loaded_mod()
354*da0073e9SAndroid Build Coastguard Worker
355*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Entry point when this test file is executed directly rather than
    # collected by a test runner.
    run_tests()
358