# Owner(s): ["oncall: package/deploy"]

import inspect
import os
import platform
import sys
from io import BytesIO
from pathlib import Path
from textwrap import dedent
from unittest import skipIf

from torch.package import is_from_package, PackageExporter, PackageImporter
from torch.package.package_exporter import PackagingError
from torch.testing._internal.common_utils import (
    IS_FBCODE,
    IS_SANDCASTLE,
    run_tests,
    skipIfTorchDynamo,
)


try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase


class TestMisc(PackageTestCase):
    """Tests for one-off or random functionality. Try not to add to this!"""

    def test_file_structure(self):
        """
        Tests package's Directory structure representation of a zip file. Ensures
        that the returned Directory prints what is expected and filters
        inputs/outputs correctly.
        """
        buffer = BytesIO()

        export_plain = dedent(
            """\
                \u251c\u2500\u2500 .data
                \u2502   \u251c\u2500\u2500 extern_modules
                \u2502   \u251c\u2500\u2500 python_version
                \u2502   \u251c\u2500\u2500 serialization_id
                \u2502   \u2514\u2500\u2500 version
                \u251c\u2500\u2500 main
                \u2502   \u2514\u2500\u2500 main
                \u251c\u2500\u2500 obj
                \u2502   \u2514\u2500\u2500 obj.pkl
                \u251c\u2500\u2500 package_a
                \u2502   \u251c\u2500\u2500 __init__.py
                \u2502   \u2514\u2500\u2500 subpackage.py
                \u251c\u2500\u2500 byteorder
                \u2514\u2500\u2500 module_a.py
            """
        )
        export_include = dedent(
            """\
                \u251c\u2500\u2500 obj
                \u2502   \u2514\u2500\u2500 obj.pkl
                \u2514\u2500\u2500 package_a
                    \u2514\u2500\u2500 subpackage.py
            """
        )
        import_exclude = dedent(
            """\
                \u251c\u2500\u2500 .data
                \u2502   \u251c\u2500\u2500 extern_modules
                \u2502   \u251c\u2500\u2500 python_version
                \u2502   \u251c\u2500\u2500 serialization_id
                \u2502   \u2514\u2500\u2500 version
                \u251c\u2500\u2500 main
                \u2502   \u2514\u2500\u2500 main
                \u251c\u2500\u2500 obj
                \u2502   \u2514\u2500\u2500 obj.pkl
                \u251c\u2500\u2500 package_a
                \u2502   \u251c\u2500\u2500 __init__.py
                \u2502   \u2514\u2500\u2500 subpackage.py
                \u251c\u2500\u2500 byteorder
                \u2514\u2500\u2500 module_a.py
            """
        )

        def without_first_line(structure):
            # Drop the first rendered line before comparing, because
            # Windows/iOS/Unix treat the buffer differently.
            return dedent("\n".join(str(structure).split("\n")[1:]))

        with PackageExporter(buffer) as he:
            import module_a
            import package_a
            import package_a.subpackage

            obj = package_a.subpackage.PackageASubpackageObject()
            he.intern("**")
            he.save_module(module_a.__name__)
            he.save_module(package_a.__name__)
            he.save_pickle("obj", "obj.pkl", obj)
            he.save_text("main", "main", "my string")

        buffer.seek(0)
        hi = PackageImporter(buffer)

        # Unfiltered view of the archive.
        file_structure = hi.file_structure()
        self.assertEqual(without_first_line(file_structure), export_plain)

        # Only entries matching one of the include globs are kept.
        file_structure = hi.file_structure(include=["**/subpackage.py", "**/*.pkl"])
        self.assertEqual(without_first_line(file_structure), export_include)

        # Excluding a glob that matches nothing here leaves the tree unchanged.
        file_structure = hi.file_structure(exclude="**/*.storage")
        self.assertEqual(without_first_line(file_structure), import_exclude)

    def test_loaders_that_remap_files_work_ok(self):
        """
        Packaging should use the source file reported by a module's loader,
        even when a custom loader remaps ``module_a`` to a different path.
        """
        from importlib.abc import MetaPathFinder
        from importlib.machinery import SourceFileLoader
        from importlib.util import spec_from_loader

        class LoaderThatRemapsModuleA(SourceFileLoader):
            def get_filename(self, name):
                result = super().get_filename(name)
                if name == "module_a":
                    return os.path.join(
                        os.path.dirname(result), "module_a_remapped_path.py"
                    )
                else:
                    return result

        class FinderThatRemapsModuleA(MetaPathFinder):
            def find_spec(self, fullname, path, target):
                """Try to find the original spec for module_a using all the
                remaining meta_path finders."""
                if fullname != "module_a":
                    return None
                spec = None
                for finder in sys.meta_path:
                    if finder is self:
                        continue
                    if hasattr(finder, "find_spec"):
                        spec = finder.find_spec(fullname, path, target=target)
                    elif hasattr(finder, "load_module"):
                        spec = spec_from_loader(fullname, finder)
                    if spec is not None:
                        break
                assert spec is not None and isinstance(spec.loader, SourceFileLoader)
                spec.loader = LoaderThatRemapsModuleA(
                    spec.loader.name, spec.loader.path
                )
                return spec

        remapping_finder = FinderThatRemapsModuleA()
        sys.meta_path.insert(0, remapping_finder)
        # clear it from sys.modules so that we use the custom finder next time
        # it gets imported
        sys.modules.pop("module_a", None)
        try:
            buffer = BytesIO()
            with PackageExporter(buffer) as he:
                import module_a

                he.intern("**")
                he.save_module(module_a.__name__)

            buffer.seek(0)
            hi = PackageImporter(buffer)
            self.assertIn("remapped_path", hi.get_source("module_a"))
        finally:
            # pop it again to ensure it does not mess up other tests
            sys.modules.pop("module_a", None)
            # Remove exactly the finder installed above rather than whatever
            # happens to sit at index 0; the membership guard keeps cleanup
            # from raising and masking a failure from the try body.
            if remapping_finder in sys.meta_path:
                sys.meta_path.remove(remapping_finder)

    def test_python_version(self):
        """
        Tests that the current python version is stored in the package and is available
        via PackageImporter's python_version() method.
        """
        buffer = BytesIO()

        with PackageExporter(buffer) as he:
            from package_a.test_module import SimpleTest

            he.intern("**")
            obj = SimpleTest()
            he.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        hi = PackageImporter(buffer)

        self.assertEqual(hi.python_version(), platform.python_version())

    @skipIf(
        IS_FBCODE or IS_SANDCASTLE,
        "Tests that use temporary files are disabled in fbcode",
    )
    def test_load_python_version_from_package(self):
        """Tests loading a package with a python version embedded"""
        importer1 = PackageImporter(
            f"{Path(__file__).parent}/package_e/test_nn_module.pt"
        )
        self.assertEqual(importer1.python_version(), "3.9.7")

    def test_file_structure_has_file(self):
        """
        Test Directory's has_file() method.
        """
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            import package_a.subpackage

            he.intern("**")
            obj = package_a.subpackage.PackageASubpackageObject()
            he.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)

        importer = PackageImporter(buffer)
        file_structure = importer.file_structure()
        self.assertTrue(file_structure.has_file("package_a/subpackage.py"))
        self.assertFalse(file_structure.has_file("package_a/subpackage"))

    def test_exporter_content_lists(self):
        """
        Test content list API for PackageExporter's contained modules.
        """

        with PackageExporter(BytesIO()) as he:
            import package_b

            he.extern("package_b.subpackage_1")
            he.mock("package_b.subpackage_2")
            he.intern("**")
            he.save_pickle("obj", "obj.pkl", package_b.PackageBObject(["a"]))
            self.assertEqual(he.externed_modules(), ["package_b.subpackage_1"])
            self.assertEqual(he.mocked_modules(), ["package_b.subpackage_2"])
            self.assertEqual(
                he.interned_modules(),
                ["package_b", "package_b.subpackage_0.subsubpackage_0"],
            )
            self.assertEqual(he.get_rdeps("package_b.subpackage_2"), ["package_b"])

        with self.assertRaises(PackagingError):
            with PackageExporter(BytesIO()) as he:
                import package_b

                he.deny("package_b")
                he.save_pickle("obj", "obj.pkl", package_b.PackageBObject(["a"]))
                # NOTE(review): presumably the PackagingError surfaces when the
                # exporter context exits, so this check is kept inside the
                # context -- confirm against PackageExporter.
                self.assertEqual(he.denied_modules(), ["package_b"])

    def test_is_from_package(self):
        """is_from_package should work for objects and modules"""
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        mod = pi.import_module("package_a.subpackage")
        loaded_obj = pi.load_pickle("obj", "obj.pkl")

        # Modules: the regularly imported one vs. the one loaded from the package.
        self.assertFalse(is_from_package(package_a.subpackage))
        self.assertTrue(is_from_package(mod))

        # Objects: the locally constructed one vs. the unpickled one.
        self.assertFalse(is_from_package(obj))
        self.assertTrue(is_from_package(loaded_obj))

    def test_inspect_class(self):
        """Should be able to retrieve source for a packaged class."""
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        packaged_class = pi.import_module(
            "package_a.subpackage"
        ).PackageASubpackageObject
        regular_class = package_a.subpackage.PackageASubpackageObject

        packaged_src = inspect.getsourcelines(packaged_class)
        regular_src = inspect.getsourcelines(regular_class)
        self.assertEqual(packaged_src, regular_src)

    def test_dunder_package_present(self):
        """
        The attribute '__torch_package__' should be populated on imported modules.
        """
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        mod = pi.import_module("package_a.subpackage")
        self.assertTrue(hasattr(mod, "__torch_package__"))

    def test_dunder_package_works_from_package(self):
        """
        The attribute '__torch_package__' should be accessible from within
        the module itself, so that packaged code can detect whether it's
        being used in a packaged context or not.
        """
        import package_a.use_dunder_package as mod

        buffer = BytesIO()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_module(mod.__name__)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        imported_mod = pi.import_module(mod.__name__)
        self.assertTrue(imported_mod.is_from_package())
        self.assertFalse(mod.is_from_package())

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_std_lib_sys_hackery_checks(self):
        """
        The standard library performs sys.module assignment hackery which
        causes modules who do this hackery to fail on import. See
        https://github.com/pytorch/pytorch/issues/57490 for more information.
        """
        import package_a.std_sys_module_hacks

        buffer = BytesIO()
        mod = package_a.std_sys_module_hacks.Module()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", mod)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        mod = pi.load_pickle("obj", "obj.pkl")
        # Calling the reloaded object is the check: it must not raise.
        mod()


if __name__ == "__main__":
    run_tests()