xref: /aosp_15_r20/external/pytorch/test/package/test_mangling.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: package/deploy"]
2
3from io import BytesIO
4
5from torch.package import PackageExporter, PackageImporter
6from torch.package._mangling import (
7    demangle,
8    get_mangle_prefix,
9    is_mangled,
10    PackageMangler,
11)
12from torch.testing._internal.common_utils import run_tests
13
14
15try:
16    from .common import PackageTestCase
17except ImportError:
18    # Support the case where we run this file directly.
19    from common import PackageTestCase
20
21
class TestMangling(PackageTestCase):
    """
    Tests for torch.package name mangling.

    Covers ``PackageMangler`` instances (``mangle`` / ``demangle``), the
    module-level helpers ``demangle``, ``is_mangled`` and ``get_mangle_prefix``,
    and an end-to-end check that modules loaded through ``PackageImporter``
    receive module names distinct from the originals and from each other.
    """

    def test_unique_manglers(self):
        """
        Each mangler instance should generate a unique mangled name for a given input.
        """
        a = PackageMangler()
        b = PackageMangler()
        self.assertNotEqual(a.mangle("foo.bar"), b.mangle("foo.bar"))

    def test_mangler_is_consistent(self):
        """
        Mangling the same name twice should produce the same result.
        """
        a = PackageMangler()
        self.assertEqual(a.mangle("abc.def"), a.mangle("abc.def"))

    def test_roundtrip_mangling(self):
        """
        Demangling a freshly mangled name should give back the original name.
        """
        a = PackageMangler()
        self.assertEqual("foo", demangle(a.mangle("foo")))

    def test_is_mangled(self):
        """
        ``is_mangled`` should accept names produced by any mangler instance and
        reject both plain names and names that have been demangled again.
        """
        a = PackageMangler()
        b = PackageMangler()
        self.assertTrue(is_mangled(a.mangle("foo.bar")))
        self.assertTrue(is_mangled(b.mangle("foo.bar")))

        self.assertFalse(is_mangled("foo.bar"))
        self.assertFalse(is_mangled(demangle(a.mangle("foo.bar"))))

    def test_demangler_multiple_manglers(self):
        """
        The module-level ``demangle`` function should be able to demangle names
        generated by any PackageMangler instance, not just a particular one.
        """
        a = PackageMangler()
        b = PackageMangler()

        self.assertEqual("foo.bar", demangle(a.mangle("foo.bar")))
        self.assertEqual("bar.foo", demangle(b.mangle("bar.foo")))

    def test_mangle_empty_errors(self):
        """
        Mangling an empty name is rejected; the mangler is expected to raise
        AssertionError for it.
        """
        a = PackageMangler()
        with self.assertRaises(AssertionError):
            a.mangle("")

    def test_demangle_base(self):
        """
        Demangling a mangle parent directly should currently return an empty string.
        """
        a = PackageMangler()
        mangled = a.mangle("foo")
        # The mangle parent is everything before the first "." of the mangled name.
        mangle_parent = mangled.partition(".")[0]
        self.assertEqual("", demangle(mangle_parent))

    def test_mangle_prefix(self):
        """
        ``get_mangle_prefix`` should return exactly the part of a mangled name
        that precedes the "." separator and the original name.
        """
        a = PackageMangler()
        mangled = a.mangle("foo.bar")
        mangle_prefix = get_mangle_prefix(mangled)
        self.assertEqual(mangle_prefix + "." + "foo.bar", mangled)

    def test_unique_module_names(self):
        """
        Loading the same package through two separate importers must produce
        classes whose ``__module__`` differs from the original module and
        from each other, so packaged modules never shadow one another.
        """
        # ``package_a`` is a test fixture package; importing the subpackage also
        # binds the top-level ``package_a`` name used below.
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)
        f1 = BytesIO()
        with PackageExporter(f1) as pe:
            # "**" matches every module, so all dependencies are interned and
            # later loaded from the package rather than the host environment.
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj2)
        # Rewind before each importer so both read the package from the start.
        f1.seek(0)
        importer1 = PackageImporter(f1)
        loaded1 = importer1.load_pickle("obj", "obj.pkl")
        f1.seek(0)
        importer2 = PackageImporter(f1)
        loaded2 = importer2.load_pickle("obj", "obj.pkl")

        # Modules from loaded packages should not shadow the names of modules.
        # See mangling.md for more info.
        self.assertNotEqual(type(obj2).__module__, type(loaded1).__module__)
        self.assertNotEqual(type(loaded1).__module__, type(loaded2).__module__)

    def test_package_mangler(self):
        """
        ``PackageMangler.demangle`` should only demangle names that the same
        instance mangled; names from another mangler are returned unchanged.
        """
        a = PackageMangler()
        b = PackageMangler()
        a_mangled = a.mangle("foo.bar")
        # Since `a` mangled this string, it should demangle properly.
        self.assertEqual(a.demangle(a_mangled), "foo.bar")
        # Since `b` did not mangle this string, demangling should leave it alone.
        self.assertEqual(b.demangle(a_mangled), a_mangled)
111
if __name__ == "__main__":
    # When this file is executed directly as a script, hand off to the test
    # runner imported from torch.testing._internal.common_utils.
    run_tests()
114