# Owner(s): ["oncall: package/deploy"]

import importlib
import importlib.machinery
import importlib.util
from io import BytesIO
from sys import version_info
from textwrap import dedent
from unittest import skipIf

import torch.nn
from torch.package import EmptyMatchError, Importer, PackageExporter, PackageImporter
from torch.package.package_exporter import PackagingError
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests


try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase


class TestDependencyAPI(PackageTestCase):
    """Dependency management API tests.

    Exercises the PackageExporter dependency actions:
    - mock()
    - extern()
    - deny()
    - intern()
    """

    def test_extern(self):
        """Externed modules resolve to the host interpreter's module objects.

        ``package_a.subpackage`` and ``module_a`` are externed, so importing them
        through the PackageImporter yields the very same objects as a normal
        import. ``package_a`` itself is not externed, so the importer hands back
        a module object distinct from the host's ``package_a``.
        """
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.extern(["package_a.subpackage", "module_a"])
            he.save_source_string("foo", "import package_a.subpackage; import module_a")
        buffer.seek(0)
        hi = PackageImporter(buffer)
        import module_a
        import package_a.subpackage

        module_a_im = hi.import_module("module_a")
        hi.import_module("package_a.subpackage")
        package_a_im = hi.import_module("package_a")

        self.assertIs(module_a, module_a_im)
        self.assertIsNot(package_a, package_a_im)
        self.assertIs(package_a.subpackage, package_a_im.subpackage)

    def test_extern_glob(self):
        """Same checks as test_extern, but the externs are selected via glob patterns."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.extern(["package_a.*", "module_*"])
            he.save_module("package_a")
            he.save_source_string(
                "test_module",
                dedent(
                    """\
                    import package_a.subpackage
                    import module_a
                    """
                ),
            )
        buffer.seek(0)
        hi = PackageImporter(buffer)
        import module_a
        import package_a.subpackage

        module_a_im = hi.import_module("module_a")
        hi.import_module("package_a.subpackage")
        package_a_im = hi.import_module("package_a")

        self.assertIs(module_a, module_a_im)
        self.assertIsNot(package_a, package_a_im)
        self.assertIs(package_a.subpackage, package_a_im.subpackage)

    def test_extern_glob_allow_empty(self):
        """
        Test that an error is thrown when an extern glob is specified with allow_empty=False
        and no matching module is required during packaging.
        """
        import package_a.subpackage  # noqa: F401

        buffer = BytesIO()
        with self.assertRaisesRegex(EmptyMatchError, r"did not match any modules"):
            with PackageExporter(buffer) as exporter:
                exporter.extern(include=["package_b.*"], allow_empty=False)
                exporter.save_module("package_a.subpackage")

    def test_deny(self):
        """
        Test marking packages as "deny" during export.
        """
        buffer = BytesIO()

        with self.assertRaisesRegex(PackagingError, "denied"):
            with PackageExporter(buffer) as exporter:
                exporter.deny(["package_a.subpackage", "module_a"])
                exporter.save_source_string("foo", "import package_a.subpackage")

    def test_deny_glob(self):
        """
        Test marking packages as "deny" using globs instead of package names.
        """
        buffer = BytesIO()
        with self.assertRaises(PackagingError):
            with PackageExporter(buffer) as exporter:
                exporter.deny(["package_a.*", "module_*"])
                exporter.save_source_string(
                    "test_module",
                    dedent(
                        """\
                        import package_a.subpackage
                        import module_a
                        """
                    ),
                )

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_mock(self):
        """Calling an attribute of a mocked module raises NotImplementedError."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.mock(["package_a.subpackage", "module_a"])
            # Import something that depends on package_a.subpackage
            he.save_source_string("foo", "import package_a.subpackage")
        buffer.seek(0)
        hi = PackageImporter(buffer)
        import package_a.subpackage

        _ = package_a.subpackage
        import module_a

        _ = module_a

        m = hi.import_module("package_a.subpackage")
        r = m.result
        with self.assertRaisesRegex(NotImplementedError, "was mocked out"):
            r()

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_mock_glob(self):
        """Same checks as test_mock, but the mocks are selected via glob patterns."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.mock(["package_a.*", "module*"])
            he.save_module("package_a")
            he.save_source_string(
                "test_module",
                dedent(
                    """\
                    import package_a.subpackage
                    import module_a
                    """
                ),
            )
        buffer.seek(0)
        hi = PackageImporter(buffer)
        import package_a.subpackage

        _ = package_a.subpackage
        import module_a

        _ = module_a

        m = hi.import_module("package_a.subpackage")
        r = m.result
        with self.assertRaisesRegex(NotImplementedError, "was mocked out"):
            r()

    def test_mock_glob_allow_empty(self):
        """
        Test that an error is thrown when a mock glob is specified with allow_empty=False
        and no matching module is required during packaging.
        """
        import package_a.subpackage  # noqa: F401

        buffer = BytesIO()
        with self.assertRaisesRegex(EmptyMatchError, r"did not match any modules"):
            with PackageExporter(buffer) as exporter:
                exporter.mock(include=["package_b.*"], allow_empty=False)
                exporter.save_module("package_a.subpackage")

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_pickle_mocked(self):
        """Pickling an object whose class lives in a mocked module raises PackagingError."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        buffer = BytesIO()
        with self.assertRaises(PackagingError):
            with PackageExporter(buffer) as he:
                he.mock(include="package_a.subpackage")
                he.intern("**")
                he.save_pickle("obj", "obj.pkl", obj2)

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_pickle_mocked_all(self):
        """Pickling succeeds when package_a is interned and every other dependency is mocked."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.intern(include="package_a.**")
            he.mock("**")
            he.save_pickle("obj", "obj.pkl", obj2)

    def test_allow_empty_with_error(self):
        """If an error occurs during packaging, it should not be shadowed by the allow_empty error."""
        buffer = BytesIO()
        with self.assertRaises(ModuleNotFoundError):
            with PackageExporter(buffer) as pe:
                # Even though we did not extern a module that matches this
                # pattern, we want to show the save_module error, not the allow_empty error.

                pe.extern("foo", allow_empty=False)
                pe.save_module("aodoifjodisfj")  # will error

                # we never get here, so technically the allow_empty check
                # should raise an error. However, the error above is more
                # informative to what's actually going wrong with packaging.
                pe.save_source_string("bar", "import foo\n")

    def test_implicit_intern(self):
        """The save_module APIs should implicitly intern the module being saved."""
        import package_a  # noqa: F401

        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.save_module("package_a")

    def test_intern_error(self):
        """Failure to handle all dependencies should lead to an error."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        buffer = BytesIO()

        with self.assertRaises(PackagingError) as e:
            with PackageExporter(buffer) as he:
                he.save_pickle("obj", "obj.pkl", obj2)

        self.assertEqual(
            str(e.exception),
            dedent(
                """
                * Module did not match against any action pattern. Extern, mock, or intern it.
                    package_a
                    package_a.subpackage

                Set debug=True when invoking PackageExporter for a visualization of where broken modules are coming from!
                """
            ),
        )

        # Interning all dependencies should work. Use a fresh buffer rather than
        # reusing the one handed to the failed export above.
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.intern(["package_a", "package_a.subpackage"])
            he.save_pickle("obj", "obj.pkl", obj2)

    @skipIf(IS_WINDOWS, "extension modules have a different file extension on windows")
    def test_broken_dependency(self):
        """An unpackageable dependency should raise a PackagingError."""

        def create_module(name):
            # Build a module whose __file__ ends in ".so" so the exporter
            # treats it as a C extension module.
            spec = importlib.machinery.ModuleSpec(name, self, is_package=False)  # type: ignore[arg-type]
            module = importlib.util.module_from_spec(spec)
            ns = module.__dict__
            ns["__spec__"] = spec
            ns["__loader__"] = self
            ns["__file__"] = f"{name}.so"
            ns["__cached__"] = None
            return module

        class BrokenImporter(Importer):
            """Importer that only serves the fake extension modules built above."""

            def __init__(self) -> None:
                self.modules = {
                    "foo": create_module("foo"),
                    "bar": create_module("bar"),
                }

            def import_module(self, module_name):
                return self.modules[module_name]

        buffer = BytesIO()

        with self.assertRaises(PackagingError) as e:
            with PackageExporter(buffer, importer=BrokenImporter()) as exporter:
                exporter.intern(["foo", "bar"])
                exporter.save_source_string("my_module", "import foo; import bar")

        self.assertEqual(
            str(e.exception),
            dedent(
                """
                * Module is a C extension module. torch.package supports Python modules only.
                    foo
                    bar

                Set debug=True when invoking PackageExporter for a visualization of where broken modules are coming from!
                """
            ),
        )

    def test_invalid_import(self):
        """An incorrectly-formed import should raise a PackagingError."""
        buffer = BytesIO()
        with self.assertRaises(PackagingError) as e:
            with PackageExporter(buffer) as exporter:
                # This import will fail to load.
                exporter.save_source_string("foo", "from ........ import lol")

        self.assertEqual(
            str(e.exception),
            dedent(
                """
                * Dependency resolution failed.
                    foo
                      Context: attempted relative import beyond top-level package

                Set debug=True when invoking PackageExporter for a visualization of where broken modules are coming from!
                """
            ),
        )

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_repackage_mocked_module(self):
        """Re-packaging a package that contains a mocked module should work correctly."""
        buffer = BytesIO()
        with PackageExporter(buffer) as exporter:
            exporter.mock("package_a")
            exporter.save_source_string("foo", "import package_a")

        buffer.seek(0)
        importer = PackageImporter(buffer)
        foo = importer.import_module("foo")

        # "package_a" should be mocked out.
        with self.assertRaises(NotImplementedError):
            foo.package_a.get_something()

        # Re-package the model, but intern the previously-mocked module and mock
        # everything else.
        buffer2 = BytesIO()
        with PackageExporter(buffer2, importer=importer) as exporter:
            exporter.intern("package_a")
            exporter.mock("**")
            exporter.save_source_string("foo", "import package_a")

        buffer2.seek(0)
        importer2 = PackageImporter(buffer2)
        foo2 = importer2.import_module("foo")

        # "package_a" should still be mocked out.
        with self.assertRaises(NotImplementedError):
            foo2.package_a.get_something()

    def test_externing_c_extension(self):
        """Externing c extensions modules should allow us to still access them especially those found in torch._C."""

        buffer = BytesIO()
        # The C extension module in question is F.gelu which comes from torch._C._nn
        model = torch.nn.TransformerEncoderLayer(
            d_model=64,
            nhead=2,
            dim_feedforward=64,
            dropout=1.0,
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        with PackageExporter(buffer) as e:
            e.extern("torch.**")
            e.intern("**")

            e.save_pickle("model", "model.pkl", model)
        buffer.seek(0)
        imp = PackageImporter(buffer)
        imp.load_pickle("model", "model.pkl")


if __name__ == "__main__":
    run_tests()