# Owner(s): ["oncall: package/deploy"]

import inspect
import os
import platform
import sys
from io import BytesIO
from pathlib import Path
from textwrap import dedent
from unittest import skipIf

from torch.package import is_from_package, PackageExporter, PackageImporter
from torch.package.package_exporter import PackagingError
from torch.testing._internal.common_utils import (
    IS_FBCODE,
    IS_SANDCASTLE,
    run_tests,
    skipIfTorchDynamo,
)


try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase


def _tree_without_root(file_structure):
    """Return ``str(file_structure)`` minus its first line, dedented.

    The first line of the rendered tree is the package's root entry, which is
    rendered differently across platforms (Windows/macOS/Linux) for an
    in-memory buffer, so tests compare only the lines below it.
    """
    return dedent("\n".join(str(file_structure).split("\n")[1:]))


class TestMisc(PackageTestCase):
    """Tests for one-off or random functionality. Try not to add to this!"""

    def test_file_structure(self):
        """
        Tests package's Directory structure representation of a zip file. Ensures
        that the returned Directory prints what is expected and filters
        inputs/outputs correctly.
        """
        buffer = BytesIO()

        export_plain = dedent(
            """\
                \u251c\u2500\u2500 .data
                \u2502   \u251c\u2500\u2500 extern_modules
                \u2502   \u251c\u2500\u2500 python_version
                \u2502   \u251c\u2500\u2500 serialization_id
                \u2502   \u2514\u2500\u2500 version
                \u251c\u2500\u2500 main
                \u2502   \u2514\u2500\u2500 main
                \u251c\u2500\u2500 obj
                \u2502   \u2514\u2500\u2500 obj.pkl
                \u251c\u2500\u2500 package_a
                \u2502   \u251c\u2500\u2500 __init__.py
                \u2502   \u2514\u2500\u2500 subpackage.py
                \u251c\u2500\u2500 byteorder
                \u2514\u2500\u2500 module_a.py
            """
        )
        export_include = dedent(
            """\
                \u251c\u2500\u2500 obj
                \u2502   \u2514\u2500\u2500 obj.pkl
                \u2514\u2500\u2500 package_a
                    \u2514\u2500\u2500 subpackage.py
            """
        )
        # Nothing in this package matches "**/*.storage", so excluding that
        # pattern must leave the tree identical to the unfiltered one.
        import_exclude = export_plain

        with PackageExporter(buffer) as he:
            import module_a
            import package_a
            import package_a.subpackage

            obj = package_a.subpackage.PackageASubpackageObject()
            he.intern("**")
            he.save_module(module_a.__name__)
            he.save_module(package_a.__name__)
            he.save_pickle("obj", "obj.pkl", obj)
            he.save_text("main", "main", "my string")

        buffer.seek(0)
        hi = PackageImporter(buffer)

        self.assertEqual(_tree_without_root(hi.file_structure()), export_plain)

        file_structure = hi.file_structure(include=["**/subpackage.py", "**/*.pkl"])
        self.assertEqual(_tree_without_root(file_structure), export_include)

        file_structure = hi.file_structure(exclude="**/*.storage")
        self.assertEqual(_tree_without_root(file_structure), import_exclude)

    def test_loaders_that_remap_files_work_ok(self):
        from importlib.abc import MetaPathFinder
        from importlib.machinery import SourceFileLoader
        from importlib.util import spec_from_loader

        class LoaderThatRemapsModuleA(SourceFileLoader):
            """Source loader that reports a different filename for module_a."""

            def get_filename(self, name):
                result = super().get_filename(name)
                if name == "module_a":
                    return os.path.join(
                        os.path.dirname(result), "module_a_remapped_path.py"
                    )
                else:
                    return result

        class FinderThatRemapsModuleA(MetaPathFinder):
            def find_spec(self, fullname, path, target=None):
                """Try to find the original spec for module_a using all the
                remaining meta_path finders, then swap in the remapping loader."""
                if fullname != "module_a":
                    return None
                spec = None
                for finder in sys.meta_path:
                    if finder is self:
                        continue
                    if hasattr(finder, "find_spec"):
                        spec = finder.find_spec(fullname, path, target=target)
                    elif hasattr(finder, "load_module"):
                        spec = spec_from_loader(fullname, finder)
                    if spec is not None:
                        break
                assert spec is not None and isinstance(spec.loader, SourceFileLoader)
                spec.loader = LoaderThatRemapsModuleA(
                    spec.loader.name, spec.loader.path
                )
                return spec

        remapping_finder = FinderThatRemapsModuleA()
        sys.meta_path.insert(0, remapping_finder)
        # clear it from sys.modules so that we use the custom finder next time
        # it gets imported
        sys.modules.pop("module_a", None)
        try:
            buffer = BytesIO()
            with PackageExporter(buffer) as he:
                import module_a

                he.intern("**")
                he.save_module(module_a.__name__)

            buffer.seek(0)
            hi = PackageImporter(buffer)
            self.assertTrue("remapped_path" in hi.get_source("module_a"))
        finally:
            # pop it again to ensure it does not mess up other tests
            sys.modules.pop("module_a", None)
            # Remove our finder by identity: something else may have been
            # prepended to sys.meta_path meanwhile, so index 0 is not reliable.
            if remapping_finder in sys.meta_path:
                sys.meta_path.remove(remapping_finder)

    def test_python_version(self):
        """
        Tests that the current python version is stored in the package and is available
        via PackageImporter's python_version() method.
        """
        buffer = BytesIO()

        with PackageExporter(buffer) as he:
            from package_a.test_module import SimpleTest

            he.intern("**")
            obj = SimpleTest()
            he.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        hi = PackageImporter(buffer)

        self.assertEqual(hi.python_version(), platform.python_version())

    @skipIf(
        IS_FBCODE or IS_SANDCASTLE,
        "Tests that use temporary files are disabled in fbcode",
    )
    def test_load_python_version_from_package(self):
        """Tests loading a package with a python version embedded"""
        package_path = Path(__file__).parent / "package_e" / "test_nn_module.pt"
        importer1 = PackageImporter(str(package_path))
        self.assertEqual(importer1.python_version(), "3.9.7")

    def test_file_structure_has_file(self):
        """
        Test Directory's has_file() method.
        """
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            import package_a.subpackage

            he.intern("**")
            obj = package_a.subpackage.PackageASubpackageObject()
            he.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)

        importer = PackageImporter(buffer)
        file_structure = importer.file_structure()
        self.assertTrue(file_structure.has_file("package_a/subpackage.py"))
        self.assertFalse(file_structure.has_file("package_a/subpackage"))

    def test_exporter_content_lists(self):
        """
        Test content list API for PackageExporter's contained modules.
        """

        with PackageExporter(BytesIO()) as he:
            import package_b

            he.extern("package_b.subpackage_1")
            he.mock("package_b.subpackage_2")
            he.intern("**")
            he.save_pickle("obj", "obj.pkl", package_b.PackageBObject(["a"]))
            self.assertEqual(he.externed_modules(), ["package_b.subpackage_1"])
            self.assertEqual(he.mocked_modules(), ["package_b.subpackage_2"])
            self.assertEqual(
                he.interned_modules(),
                ["package_b", "package_b.subpackage_0.subsubpackage_0"],
            )
            self.assertEqual(he.get_rdeps("package_b.subpackage_2"), ["package_b"])

        # Denying a required module makes the exporter raise when it closes;
        # denied_modules() is still queryable before that happens.
        with self.assertRaises(PackagingError):
            with PackageExporter(BytesIO()) as he:
                import package_b

                he.deny("package_b")
                he.save_pickle("obj", "obj.pkl", package_b.PackageBObject(["a"]))
                self.assertEqual(he.denied_modules(), ["package_b"])

    def test_is_from_package(self):
        """is_from_package should work for objects and modules"""
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        mod = pi.import_module("package_a.subpackage")
        loaded_obj = pi.load_pickle("obj", "obj.pkl")

        self.assertFalse(is_from_package(package_a.subpackage))
        self.assertTrue(is_from_package(mod))

        self.assertFalse(is_from_package(obj))
        self.assertTrue(is_from_package(loaded_obj))

    def test_inspect_class(self):
        """Should be able to retrieve source for a packaged class."""
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        packaged_class = pi.import_module(
            "package_a.subpackage"
        ).PackageASubpackageObject
        regular_class = package_a.subpackage.PackageASubpackageObject

        packaged_src = inspect.getsourcelines(packaged_class)
        regular_src = inspect.getsourcelines(regular_class)
        self.assertEqual(packaged_src, regular_src)

    def test_dunder_package_present(self):
        """
        The attribute '__torch_package__' should be populated on imported modules.
        """
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        mod = pi.import_module("package_a.subpackage")
        self.assertTrue(hasattr(mod, "__torch_package__"))

    def test_dunder_package_works_from_package(self):
        """
        The attribute '__torch_package__' should be accessible from within
        the module itself, so that packaged code can detect whether it's
        being used in a packaged context or not.
        """
        import package_a.use_dunder_package as mod

        buffer = BytesIO()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_module(mod.__name__)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        imported_mod = pi.import_module(mod.__name__)
        self.assertTrue(imported_mod.is_from_package())
        self.assertFalse(mod.is_from_package())

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_std_lib_sys_hackery_checks(self):
        """
        The standard library performs sys.module assignment hackery which
        causes modules who do this hackery to fail on import. See
        https://github.com/pytorch/pytorch/issues/57490 for more information.
        """
        import package_a.std_sys_module_hacks

        buffer = BytesIO()
        mod = package_a.std_sys_module_hacks.Module()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", mod)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        mod = pi.load_pickle("obj", "obj.pkl")
        mod()


if __name__ == "__main__":
    run_tests()