# Owner(s): ["oncall: package/deploy"]

from io import BytesIO

from torch.package import PackageExporter, PackageImporter
from torch.package._mangling import (
    demangle,
    get_mangle_prefix,
    is_mangled,
    PackageMangler,
)
from torch.testing._internal.common_utils import run_tests


try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase


class TestMangling(PackageTestCase):
    """Tests for the module-name mangling helpers in ``torch.package._mangling``."""

    def test_unique_manglers(self):
        """
        Each mangler instance should generate a unique mangled name for a given input.
        """
        a = PackageMangler()
        b = PackageMangler()
        self.assertNotEqual(a.mangle("foo.bar"), b.mangle("foo.bar"))

    def test_mangler_is_consistent(self):
        """
        Mangling the same name twice should produce the same result.
        """
        a = PackageMangler()
        self.assertEqual(a.mangle("abc.def"), a.mangle("abc.def"))

    def test_roundtrip_mangling(self):
        """
        Demangling the output of ``mangle`` should give back the original name.
        """
        a = PackageMangler()
        self.assertEqual("foo", demangle(a.mangle("foo")))

    def test_is_mangled(self):
        """
        ``is_mangled`` should be true for names produced by any mangler, and
        false for plain names and for names that have been demangled.
        """
        a = PackageMangler()
        b = PackageMangler()
        self.assertTrue(is_mangled(a.mangle("foo.bar")))
        self.assertTrue(is_mangled(b.mangle("foo.bar")))

        self.assertFalse(is_mangled("foo.bar"))
        self.assertFalse(is_mangled(demangle(a.mangle("foo.bar"))))

    def test_demangler_multiple_manglers(self):
        """
        ``demangle`` should be able to demangle a name generated by any PackageMangler.
        """
        a = PackageMangler()
        b = PackageMangler()

        self.assertEqual("foo.bar", demangle(a.mangle("foo.bar")))
        self.assertEqual("bar.foo", demangle(b.mangle("bar.foo")))

    def test_mangle_empty_errors(self):
        """
        Mangling an empty name should raise an AssertionError.
        """
        # NOTE(review): expecting AssertionError suggests the mangler validates
        # its input with an `assert` statement; if so, this test would fail
        # under `python -O`, which strips asserts. Confirm against _mangling.py.
        a = PackageMangler()
        with self.assertRaises(AssertionError):
            a.mangle("")

    def test_demangle_base(self):
        """
        Demangling a mangle parent directly should currently return an empty string.
        """
        a = PackageMangler()
        mangled = a.mangle("foo")
        # The mangle parent is everything before the first "." of the mangled name.
        mangle_parent = mangled.partition(".")[0]
        self.assertEqual("", demangle(mangle_parent))

    def test_mangle_prefix(self):
        """
        A mangled name should be its mangle prefix, a ".", then the original name.
        """
        a = PackageMangler()
        mangled = a.mangle("foo.bar")
        mangle_prefix = get_mangle_prefix(mangled)
        self.assertEqual(mangle_prefix + "." + "foo.bar", mangled)

    def test_unique_module_names(self):
        """
        Loading the same package through two separate importers should yield
        classes whose ``__module__`` differs from the original module's and
        from each other's.
        """
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)
        f1 = BytesIO()
        with PackageExporter(f1) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj2)
        # Rewind before each read so both importers see the full archive.
        f1.seek(0)
        importer1 = PackageImporter(f1)
        loaded1 = importer1.load_pickle("obj", "obj.pkl")
        f1.seek(0)
        importer2 = PackageImporter(f1)
        loaded2 = importer2.load_pickle("obj", "obj.pkl")

        # Modules from loaded packages should not shadow the names of modules.
        # See mangling.md for more info.
        self.assertNotEqual(type(obj2).__module__, type(loaded1).__module__)
        self.assertNotEqual(type(loaded1).__module__, type(loaded2).__module__)

    def test_package_mangler(self):
        """
        ``PackageMangler.demangle`` should only strip names mangled by that same
        instance, and return other names unchanged.
        """
        a = PackageMangler()
        b = PackageMangler()
        a_mangled = a.mangle("foo.bar")
        # Since `a` mangled this string, it should demangle properly.
        self.assertEqual(a.demangle(a_mangled), "foo.bar")
        # Since `b` did not mangle this string, demangling should leave it alone.
        self.assertEqual(b.demangle(a_mangled), a_mangled)


if __name__ == "__main__":
    run_tests()