# mypy: allow-untyped-defs
import collections
import importlib.machinery
import io
import linecache
import pickletools
import platform
import types
from collections import defaultdict, OrderedDict
from dataclasses import dataclass
from enum import Enum
from importlib.machinery import SourceFileLoader
from pathlib import Path
from typing import (
    Any,
    BinaryIO,
    Callable,
    cast,
    DefaultDict,
    Dict,
    List,
    Optional,
    Sequence,
    Set,
    Union,
)

import torch
from torch.serialization import location_tag, normalize_storage_type
from torch.types import Storage
from torch.utils.hooks import RemovableHandle

from ._digraph import DiGraph
from ._importlib import _normalize_path
from ._mangling import demangle, is_mangled
from ._package_pickler import create_pickler
from ._stdlib import is_stdlib_module
from .find_file_dependencies import find_files_source_depends_on
from .glob_group import GlobGroup, GlobPattern
from .importer import Importer, OrderedImporter, sys_importer


__all__ = [
    "PackagingErrorReason",
    "EmptyMatchError",
    "PackagingError",
    "PackageExporter",
]

_gate_torchscript_serialization = True

ActionHook = Callable[["PackageExporter", str], None]


class _ModuleProviderAction(Enum):
    """Represents one of the actions that :class:`PackageExporter` can take on a module.

    See :meth:`PackageExporter.extern` and friends for a description of what the actions do.
    """

    INTERN = 1
    EXTERN = 2
    MOCK = 3
    DENY = 4
    # Special case: when a module is mocked, PackageExporter writes out a
    # `_mock` module that implements our mocking stubs. If we re-package code,
    # we may encounter a `_mock` module from the original package. If we do,
    # just ignore it and write a `_mock` module once.
    REPACKAGED_MOCK_MODULE = 5
    # Special case: PackageImporter adds a fake module
    # (`torch_package_importer`) that allows packaged code to access it. Don't
    # re-export this.
    SKIP = 6


class PackagingErrorReason(Enum):
    """Listing of the different reasons a dependency may fail to package.

    This enum is used to provide good error messages when
    :class:`PackagingError` is raised.
    """

    def __repr__(self):
        return f"<{self.__class__.__name__}.{self.name}>"

    IS_EXTENSION_MODULE = (
        "Module is a C extension module. torch.package supports Python modules only."
    )
    NO_DUNDER_FILE = "Module had no __file__ defined."
    SOURCE_FILE_NOT_FOUND = (
        "Module had a __file__, but we could not find it in your filesystem."
    )
    DEPENDENCY_RESOLUTION_FAILED = "Dependency resolution failed."
    NO_ACTION = (
        "Module did not match against any action pattern. Extern, mock, or intern it."
    )
    DENIED = "Module was denied by a pattern."
    MOCKED_BUT_STILL_USED = (
        "Module was mocked out, but is still being used in the package. "
        "Please intern or extern the mocked modules if objects are supposed to be in "
        "the package."
    )


@dataclass
class _PatternInfo:
    """Holds :class:`PackageExporter`-specific info about how to execute matches
    against a module-name glob pattern.
    """

    # What action to take on a module that matches this pattern.
    action: _ModuleProviderAction
    # The value of `allow_empty` the user gave when specifying the pattern.
    allow_empty: bool
    # Whether this pattern has been matched during packaging.
    was_matched: bool

    def __init__(self, action, allow_empty):
        self.action = action
        self.allow_empty = allow_empty
        self.was_matched = False
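

# How the pattern table is consumed (a hedged sketch; the module names below
# are illustrative): `PackageExporter` keeps an insertion-ordered dict mapping
# `GlobGroup` -> `_PatternInfo`, and `add_dependency` applies the *first*
# pattern that matches a module name. More specific patterns should therefore
# be registered before broader ones, e.g.:
#
#     exporter.mock("my_project.tests.**")   # registered first, wins for tests
#     exporter.intern("my_project.**")       # everything else in my_project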


class EmptyMatchError(Exception):
    """This is an exception that is thrown when a mock or extern is marked as
    ``allow_empty=False``, and is not matched with any module during packaging.
    """


class PackagingError(Exception):
    """This exception is raised when there is an issue with exporting a package.
    ``PackageExporter`` will attempt to gather up all the errors and present
    them to you at once.
    """

    def __init__(self, dependency_graph: DiGraph, debug=False):
        # Group errors by reason.
        broken: Dict[PackagingErrorReason, List[str]] = defaultdict(list)
        for module_name, attrs in dependency_graph.nodes.items():
            error = attrs.get("error")
            if error is None:
                continue
            if error == PackagingErrorReason.NO_ACTION:
                assert "action" not in attrs
            broken[error].append(module_name)

        message = io.StringIO()
        message.write("\n")

        for reason, module_names in broken.items():
            message.write(f"* {reason.value}\n")
            for module_name in module_names:
                message.write(f"    {module_name}\n")

                # Print additional context if it's provided.
                error_context = dependency_graph.nodes[module_name].get("error_context")
                if error_context is not None:
                    message.write(f"      Context: {error_context}\n")
                if module_name in _DISALLOWED_MODULES:
                    message.write(
                        "      Note: While we usually use modules in the python standard library "
                        f"from the local environment, `{module_name}` has a lot of system "
                        "level access and therefore can pose a security risk. We strongly "
                        f"recommend removing `{module_name}` from your packaged code. However, if that "
                        "is not possible, add it to the extern list by calling "
                        f'PackageExporter.extern("{module_name}")\n'
                    )
                if debug:
                    module_path = dependency_graph.first_path(module_name)
                    message.write(
                        f"      A path to {module_name}: {' -> '.join(module_path)}\n"
                    )
        if not debug:
            message.write("\n")
            message.write(
                "Set debug=True when invoking PackageExporter for a visualization of where "
                "broken modules are coming from!\n"
            )
        # Save the dependency graph so that tooling can get at it.
        self.dependency_graph = dependency_graph
        super().__init__(message.getvalue())
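

# A hedged sketch of how callers typically consume a PackagingError (the file
# and object names here are illustrative):
#
#     try:
#         with PackageExporter("out.pt") as exporter:
#             exporter.save_pickle("model", "model.pkl", my_model)
#     except PackagingError as e:
#         print(e)                     # human-readable, grouped by reason
#         graph = e.dependency_graph   # the raw graph, for tooling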


class PackageExporter:
    """Exporters allow you to write packages of code, pickled Python data, and
    arbitrary binary and text resources into a self-contained package.

    Importers can load this code in a hermetic way, such that code is loaded
    from the package rather than from the normal Python import system. This allows
    for the packaging of PyTorch model code and data so that it can be run
    on a server or used in the future for transfer learning.

    The code contained in packages is copied file-by-file from the original
    source when it is created, and the file format is a specially organized
    zip file. Future users of the package can unzip the package and edit the code
    in order to perform custom modifications to it.

    The importer for packages ensures that code in the module can only be loaded from
    within the package, except for modules explicitly listed as external using :meth:`extern`.
    The file ``extern_modules`` in the zip archive lists all the modules that a package externally depends on.
    This prevents "implicit" dependencies, where the package runs locally because it is importing
    a locally-installed package but then fails when the package is copied to another machine.

    When source code is added to the package, the exporter can optionally scan it
    for further code dependencies (``dependencies=True``). It looks for import statements,
    resolves relative references to qualified module names, and performs an action specified by the user
    (see :meth:`extern`, :meth:`mock`, and :meth:`intern`).
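
    Example (an illustrative sketch; the file, module, and object names here
    are hypothetical)::

        with PackageExporter("package.pt") as exporter:
            exporter.intern("my_project.**")
            exporter.extern("numpy.**")
            exporter.save_pickle("model", "model.pkl", my_model)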
    """

    """An importer that will be searched in order to find the modules referenced by other modules or by
    pickled objects. The default module environment just uses sys_importer, which searches the Python environment.
    """
    importer: Importer

    def __init__(
        self,
        f: Union[str, Path, BinaryIO],
        importer: Union[Importer, Sequence[Importer]] = sys_importer,
        debug: bool = False,
    ):
        """
        Create an exporter.

        Args:
            f: The location to export to. Can be a ``string``/``Path`` object containing a filename
                or a binary I/O object.
            importer: If a single Importer is passed, use that to search for modules.
                If a sequence of importers is passed, an ``OrderedImporter`` will be constructed out of them.
            debug: If set to True, add the path of broken modules to PackagingErrors.
        """
        torch._C._log_api_usage_once("torch.package.PackageExporter")
        self.debug = debug
        if isinstance(f, (Path, str)):
            f = str(f)
            self.buffer: Optional[BinaryIO] = None
        else:  # is a byte buffer
            self.buffer = f

        self.zip_file = torch._C.PyTorchFileWriter(f)
        self.zip_file.set_min_version(6)
        self._written_files: Set[str] = set()

        self.serialized_reduces: Dict[int, Any] = {}

        # A graph tracking all the modules and pickle objects added to this
        # package and the dependencies between them.
        # - Each node is a module name (or a pickle name that looks like '<foo.obj.pkl>')
        # - Each directed edge (u, v) means u depends on v.
        # - Nodes may contain metadata that describes how to write the thing to the zipfile.
        self.dependency_graph = DiGraph()
        self.script_module_serializer = torch._C.ScriptModuleSerializer(self.zip_file)
        self.storage_context = self.script_module_serializer.storage_context()

        # These are OrderedDicts for compatibility with RemovableHandle.
        # Generic OrderedDict type annotations are not present until 3.7.
        # The real type signature is OrderedDict[int, Callable[[PackageExporter, str], None]]
        self._extern_hooks: OrderedDict = OrderedDict()
        self._mock_hooks: OrderedDict = OrderedDict()
        self._intern_hooks: OrderedDict = OrderedDict()

        if isinstance(importer, Importer):
            self.importer = importer
        else:
            if not isinstance(importer, collections.abc.Sequence):
                raise TypeError(
                    "importer arg should be an Importer or a sequence of Importers, "
                    f"got {type(importer)} instead."
                )
            self.importer = OrderedImporter(*importer)

        self.patterns: Dict[GlobGroup, _PatternInfo] = {}
        self._unique_id = 0

    def save_source_file(
        self, module_name: str, file_or_directory: str, dependencies=True
    ):
        """Adds the local file system ``file_or_directory`` to the source package to provide the code
        for ``module_name``.

        Args:
            module_name (str): e.g. ``"my_package.my_subpackage"``; code will be saved to provide code for this package.
            file_or_directory (str): the path to a file or directory of code. When a directory, all Python files in the directory
                are recursively copied using :meth:`save_source_file`. If a file is named ``__init__.py``, the code is treated
                as a package.
            dependencies (bool, optional): If ``True``, we scan the source for dependencies.
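
        Example (an illustrative sketch; the paths and module names here are
        hypothetical)::

            # Save a single module.
            exporter.save_source_file("my_module", "my_project/my_module.py")
            # Save a directory as a package.
            exporter.save_source_file("my_package", "my_project/my_package")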
        """
        path = Path(file_or_directory)
        if path.is_dir():
            to_save = []  # list of tuples with arguments to save_source_string
            module_path = module_name.replace(".", "/")
            for filename in path.glob("**/*.py"):
                relative_path = filename.relative_to(path).as_posix()
                archivename = module_path + "/" + relative_path
                submodule_name = None
                if filename.name == "__init__.py":
                    submodule_name = archivename[: -len("/__init__.py")].replace(
                        "/", "."
                    )
                    is_package = True
                else:
                    submodule_name = archivename[: -len(".py")].replace("/", ".")
                    is_package = False

                # We delay the call to save_source_string so that we record all the source files
                # being provided by this directory structure _before_ attempting to resolve the
                # dependencies on the source. This makes sure we don't try to copy over modules
                # that will just get overwritten by this directory blob.
                to_save.append(
                    (
                        submodule_name,
                        _read_file(str(filename)),
                        is_package,
                        dependencies,
                    )
                )

            for item in to_save:
                self.save_source_string(*item)
        else:
            is_package = path.name == "__init__.py"
            self.save_source_string(
                module_name,
                _read_file(file_or_directory),
                is_package,
                dependencies,
            )

    def get_unique_id(self) -> str:
        """Get an id. This id is guaranteed to only be handed out once for this package."""
        ret = str(self._unique_id)
        self._unique_id += 1
        return ret

    def _get_dependencies(
        self, src: str, module_name: str, is_package: bool
    ) -> List[str]:
        """Return all modules that this source code depends on.

        Dependencies are found by scanning the source code for import-like statements.

        Arguments:
            src: The Python source code to analyze for dependencies.
            module_name: The name of the module that ``src`` corresponds to.
            is_package: Whether this module should be treated as a package.
                See :py:meth:`save_source_string` for more info.

        Returns:
            A list containing modules detected as direct dependencies in
            ``src``. The items in the list are guaranteed to be unique.
        """
        package_name = (
            module_name if is_package else module_name.rsplit(".", maxsplit=1)[0]
        )
        try:
            dep_pairs = find_files_source_depends_on(src, package_name)
        except Exception as e:
            self.dependency_graph.add_node(
                module_name,
                error=PackagingErrorReason.DEPENDENCY_RESOLUTION_FAILED,
                error_context=str(e),
            )
            return []

        # Use a dict to get uniquing but also deterministic order.
        dependencies = {}
        for dep_module_name, dep_module_obj in dep_pairs:
            # Handle the case where someone did something like `from pack import sub`,
            # where `sub` is a submodule. In this case we don't have to save pack, just sub.
            # This ensures we don't pick up additional dependencies on pack.
            # However, in the case where `sub` is not a submodule but an object, then we do
            # have to save pack.
            if dep_module_obj is not None:
                possible_submodule = f"{dep_module_name}.{dep_module_obj}"
                if self._module_exists(possible_submodule):
                    dependencies[possible_submodule] = True
                    # We don't need to save `pack`.
                    continue
            if self._module_exists(dep_module_name):
                dependencies[dep_module_name] = True

        return list(dependencies.keys())

    def save_source_string(
        self,
        module_name: str,
        src: str,
        is_package: bool = False,
        dependencies: bool = True,
    ):
        """Adds ``src`` as the source code for ``module_name`` in the exported package.

        Args:
            module_name (str): e.g. ``my_package.my_subpackage``; code will be saved to provide code for this package.
            src (str): The Python source code to save for this package.
            is_package (bool, optional): If ``True``, this module is treated as a package. Packages are allowed to have submodules
                (e.g. ``my_package.my_subpackage.my_subsubpackage``), and resources can be saved inside them. Defaults to ``False``.
            dependencies (bool, optional): If ``True``, we scan the source for dependencies.
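
        Example (an illustrative sketch; the module name and source text are
        hypothetical)::

            exporter.save_source_string(
                "my_module",
                "import math\nTWO_PI = 2 * math.pi\n",
            )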
        """
        self.dependency_graph.add_node(
            module_name,
            source=src,
            is_package=is_package,
            provided=True,
            action=_ModuleProviderAction.INTERN,
        )

        if dependencies:
            deps = self._get_dependencies(src, module_name, is_package)

            for dep in deps:
                self.dependency_graph.add_edge(module_name, dep)
                self.add_dependency(dep)

    def _write_source_string(
        self,
        module_name: str,
        src: str,
        is_package: bool = False,
    ):
        """Write ``src`` as the source code for ``module_name`` in the zip archive.

        Arguments are otherwise the same as for :meth:`save_source_string`.
        """
        extension = "/__init__.py" if is_package else ".py"
        filename = module_name.replace(".", "/") + extension

        self._write(filename, src)

    def _import_module(self, module_name: str):
        try:
            return self.importer.import_module(module_name)
        except ModuleNotFoundError:
            if not is_mangled(module_name):
                raise
            msg = (
                f"Module not found: '{module_name}'. Make sure the PackageImporter that "
                "created this module is present in `self.importer`"
            )
            raise ModuleNotFoundError(msg) from None

    def _module_exists(self, module_name: str) -> bool:
        try:
            self._import_module(module_name)
            return True
        except Exception:
            return False

    def _get_source_of_module(self, module: types.ModuleType) -> Optional[str]:
        filename = None
        spec = getattr(module, "__spec__", None)
        if spec is not None:
            loader = getattr(spec, "loader", None)
            if loader is not None and isinstance(loader, SourceFileLoader):
                try:
                    filename = loader.get_filename(module.__name__)
                except ImportError:
                    pass
        if filename is None:
            filename = getattr(module, "__file__", None)
        if isinstance(filename, str) and filename.endswith(".py"):
            return "".join(linecache.getlines(filename, module.__dict__))
        return None

    def add_dependency(self, module_name: str, dependencies=True):
        """Given a module, add it to the dependency graph according to patterns
        specified by the user.
        """
        if (
            module_name in self.dependency_graph
            and self.dependency_graph.nodes[module_name].get("provided") is True
        ):
            return

        # Special case: PackageImporter provides a special module called
        # `torch_package_importer` that allows packaged modules to reference
        # their PackageImporter. We don't want to re-export this.
        if module_name == "torch_package_importer":
            self.dependency_graph.add_node(
                module_name,
                action=_ModuleProviderAction.SKIP,
                provided=True,
            )
            return

        if module_name == "_mock":
            self.dependency_graph.add_node(
                module_name,
                action=_ModuleProviderAction.REPACKAGED_MOCK_MODULE,
                provided=True,
            )
            return

        if self._can_implicitly_extern(module_name):
            self.dependency_graph.add_node(
                module_name, action=_ModuleProviderAction.EXTERN, provided=True
            )
            return

        for pattern, pattern_info in self.patterns.items():
            if pattern.matches(module_name):
                pattern_info.was_matched = True
                self.dependency_graph.add_node(
                    module_name, action=pattern_info.action, provided=True
                )

                if pattern_info.action == _ModuleProviderAction.DENY:
                    # Requiring a denied module just adds an error to the graph.
                    self.dependency_graph.add_node(
                        module_name, error=PackagingErrorReason.DENIED
                    )

                # If we are interning this module, we need to retrieve its
                # dependencies and package those as well.
                if pattern_info.action == _ModuleProviderAction.INTERN:
                    self._intern_module(module_name, dependencies)
                return

        # No patterns have matched. Explicitly add this as an error.
        self.dependency_graph.add_node(
            module_name, error=PackagingErrorReason.NO_ACTION
        )

    def save_module(self, module_name: str, dependencies=True):
        """Save the code for ``module_name`` into the package. Code for the module is resolved using the ``importer``
        path to find the module object, and then using its ``__file__`` attribute to find the source code.

        Args:
            module_name (str): e.g. ``my_package.my_subpackage``; code will be saved to provide code
                for this package.
            dependencies (bool, optional): If ``True``, we scan the source for dependencies.
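
        Example (an illustrative sketch; the module name is hypothetical; note
        that this method takes the module's *name*, not the module object)::

            exporter.save_module("my_package.utils")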
        """
        if not isinstance(module_name, str):
            raise TypeError(
                "save_module() expects a string input, did you perhaps mean to pass `__name__`?"
            )

        self._intern_module(module_name, dependencies)

    def _intern_module(
        self,
        module_name: str,
        dependencies: bool,
    ):
        """Adds the module to the dependency graph as an interned module,
        along with any metadata needed to write it out to the zipfile at serialization time.
        """
        module_obj = self._import_module(module_name)
        # Subtle: if the import above succeeded, either:
        #   1. The module name is not mangled, and this was just a regular import, or
        #   2. The module name is mangled, but one of the importers was able to
        #      recognize the mangling and import it.
        # Either way, it is now safe to demangle this name so that we don't
        # serialize the mangled version to the package.
        module_name = demangle(module_name)

        # Find dependencies of this module and require them as well.
        is_package = hasattr(module_obj, "__path__")
        source = self._get_source_of_module(module_obj)
        if source is None:
            # Couldn't find a source! Add it to our dependency graph as broken
            # and continue.
            filename = getattr(module_obj, "__file__", None)
            error_context = None
            if filename is None:
                packaging_error = PackagingErrorReason.NO_DUNDER_FILE
            elif filename.endswith(tuple(importlib.machinery.EXTENSION_SUFFIXES)):
                packaging_error = PackagingErrorReason.IS_EXTENSION_MODULE
            else:
                packaging_error = PackagingErrorReason.SOURCE_FILE_NOT_FOUND
                error_context = f"filename: {filename}"
            self.dependency_graph.add_node(
                module_name,
                action=_ModuleProviderAction.INTERN,
                is_package=is_package,
                error=packaging_error,
                error_context=error_context,
                provided=True,
            )
            return

        self.dependency_graph.add_node(
            module_name,
            action=_ModuleProviderAction.INTERN,
            is_package=is_package,
            source=source,
            provided=True,
        )

        if dependencies:
            deps = self._get_dependencies(source, module_name, is_package)
            for dep in deps:
                self.dependency_graph.add_edge(module_name, dep)
                self.add_dependency(dep)

    def save_pickle(
        self,
        package: str,
        resource: str,
        obj: Any,
        dependencies: bool = True,
        pickle_protocol: int = 3,
    ):
        """Save a Python object to the archive using pickle. Equivalent to :func:`torch.save` but saving into
        the archive rather than a stand-alone file. Standard pickle does not save the code, only the objects.
        If ``dependencies`` is true, this method will also scan the pickled objects for which modules are required
        to reconstruct them and save the relevant code.

        To be able to save an object where ``type(obj).__name__`` is ``my_module.MyObject``,
        ``my_module.MyObject`` must resolve to the class of the object according to the ``importer`` order. When saving objects that
        have previously been packaged, the importer's ``import_module`` method will need to be present in the ``importer`` list
        for this to work.

        Args:
            package (str): The name of the module package this resource should go in (e.g. ``"my_package.my_subpackage"``).
            resource (str): A unique name for the resource, used to identify it to load.
            obj (Any): The object to save; must be picklable.
            dependencies (bool, optional): If ``True``, we scan the source for dependencies.
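
        Example (an illustrative sketch; the names here are hypothetical)::

            exporter.intern("my_project.**")
            exporter.save_pickle("model", "model.pkl", my_model)

            # Later, a PackageImporter can recover the object:
            #   importer.load_pickle("model", "model.pkl")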
        """

        assert (pickle_protocol == 4) or (
            pickle_protocol == 3
        ), "torch.package only supports pickle protocols 3 and 4"

        filename = self._filename(package, resource)
        # Write the pickle data for `obj`
        data_buf = io.BytesIO()
        pickler = create_pickler(data_buf, self.importer, protocol=pickle_protocol)
        pickler.persistent_id = self._persistent_id
        pickler.dump(obj)
        data_value = data_buf.getvalue()
        mocked_modules = defaultdict(list)
        name_in_dependency_graph = f"<{package}.{resource}>"
        self.dependency_graph.add_node(
            name_in_dependency_graph,
            action=_ModuleProviderAction.INTERN,
            provided=True,
            is_pickle=True,
        )

        def _check_mocked_error(module: Optional[str], field: Optional[str]):
            """Check whether an object (``field``) comes from a mocked module; if so, add
            the pair to ``mocked_modules``, which maps each mocked module to the
            list of its mocked objects present in the pickle.

            We also hold the invariant that the first user-defined rule that applies
            to the module is the one we use.
            """

            assert isinstance(module, str)
            assert isinstance(field, str)
            if self._can_implicitly_extern(module):
                return
            for pattern, pattern_info in self.patterns.items():
                if pattern.matches(module):
                    if pattern_info.action == _ModuleProviderAction.MOCK:
                        mocked_modules[module].append(field)
                    return

        if dependencies:
            all_dependencies = []
            module = None
            field = None
            memo: DefaultDict[int, str] = defaultdict(None)
            memo_count = 0
            # pickletools.dis(data_value)
            for opcode, arg, pos in pickletools.genops(data_value):
                if pickle_protocol == 4:
                    if (
                        opcode.name == "SHORT_BINUNICODE"
                        or opcode.name == "BINUNICODE"
                        or opcode.name == "BINUNICODE8"
                    ):
                        assert isinstance(arg, str)
                        module = field
                        field = arg
                        memo[memo_count] = arg
                    elif (
                        opcode.name == "LONG_BINGET"
                        or opcode.name == "BINGET"
                        or opcode.name == "GET"
                    ):
                        assert isinstance(arg, int)
                        module = field
                        field = memo.get(arg, None)
                    elif opcode.name == "MEMOIZE":
                        memo_count += 1
                    elif opcode.name == "STACK_GLOBAL":
                        if module is None:
                            # If no module was passed in the entries preceding this one, continue.
                            continue
                        assert isinstance(module, str)
                        if module not in all_dependencies:
                            all_dependencies.append(module)
                        _check_mocked_error(module, field)
                elif (
                    pickle_protocol == 3 and opcode.name == "GLOBAL"
                ):  # a global reference
                    assert isinstance(arg, str)
                    module, field = arg.split(" ")
                    if module not in all_dependencies:
                        all_dependencies.append(module)
                    _check_mocked_error(module, field)
            for module_name in all_dependencies:
                self.dependency_graph.add_edge(name_in_dependency_graph, module_name)

                # If an object happens to come from a mocked module, then we collect
                # these errors and spit them out with the other errors found by the
                # package exporter.
                if module_name in mocked_modules:
                    assert isinstance(module_name, str)
                    fields = mocked_modules[module_name]
                    self.dependency_graph.add_node(
                        module_name,
                        action=_ModuleProviderAction.MOCK,
                        error=PackagingErrorReason.MOCKED_BUT_STILL_USED,
                        error_context=f"Object(s) '{fields}' from module `{module_name}` was mocked out during packaging "
                        f"but is being used in resource - `{resource}` in package `{package}`. ",
                        provided=True,
                    )
                else:
                    self.add_dependency(module_name)

        self._write(filename, data_value)

    def save_text(self, package: str, resource: str, text: str):
        """Save text data to the package.

        Args:
            package (str): The name of the module package this resource should go in (e.g. ``"my_package.my_subpackage"``).
            resource (str): A unique name for the resource, used to identify it to load.
            text (str): The contents to save.
        """
        return self.save_binary(package, resource, text.encode("utf-8"))

    def save_binary(self, package, resource, binary: bytes):
        """Save raw bytes to the package.

        Args:
            package (str): The name of the module package this resource should go in (e.g. ``"my_package.my_subpackage"``).
            resource (str): A unique name for the resource, used to identify it to load.
            binary (bytes): The data to save.
        """
        filename = self._filename(package, resource)
        self._write(filename, binary)

    def register_extern_hook(self, hook: ActionHook) -> RemovableHandle:
        """Registers an extern hook on the exporter.

        The hook will be called each time a module matches against an :meth:`extern` pattern.
        It should have the following signature::

            hook(exporter: PackageExporter, module_name: str) -> None

        Hooks will be called in order of registration.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                A handle that can be used to remove the added hook by calling
                ``handle.remove()``.
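
        Example (an illustrative sketch; the hook body is hypothetical)::

            def print_externs(exporter: PackageExporter, module_name: str) -> None:
                print(f"externing {module_name}")

            handle = exporter.register_extern_hook(print_externs)
            # ... package things ...
            handle.remove()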
        """
        handle = RemovableHandle(self._extern_hooks)
        self._extern_hooks[handle.id] = hook
        return handle

    def register_mock_hook(self, hook: ActionHook) -> RemovableHandle:
        """Registers a mock hook on the exporter.

        The hook will be called each time a module matches against a :meth:`mock` pattern.
        It should have the following signature::

            hook(exporter: PackageExporter, module_name: str) -> None

        Hooks will be called in order of registration.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                A handle that can be used to remove the added hook by calling
                ``handle.remove()``.
        """
        handle = RemovableHandle(self._mock_hooks)
        self._mock_hooks[handle.id] = hook
        return handle

    def register_intern_hook(self, hook: ActionHook) -> RemovableHandle:
        """Registers an intern hook on the exporter.

        The hook will be called each time a module matches against an :meth:`intern` pattern.
        It should have the following signature::

            hook(exporter: PackageExporter, module_name: str) -> None

        Hooks will be called in order of registration.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                A handle that can be used to remove the added hook by calling
                ``handle.remove()``.
        """
        handle = RemovableHandle(self._intern_hooks)
        self._intern_hooks[handle.id] = hook
        return handle

    def intern(
        self,
        include: "GlobPattern",
        *,
        exclude: "GlobPattern" = (),
        allow_empty: bool = True,
    ):
        """Specify modules that should be packaged. A module must match some ``intern`` pattern in order to be
        included in the package and have its dependencies processed recursively.

        Args:
            include (Union[List[str], str]): A string e.g. ``"my_package.my_subpackage"``, or list of strings
                for the names of the modules to be interned. This can also be a glob-style pattern, as described in :meth:`mock`.

            exclude (Union[List[str], str]): An optional pattern that excludes some of the names that match the include string.

            allow_empty (bool): An optional flag that specifies whether the intern modules specified by this call
                to the ``intern`` method must be matched to some module during packaging. If an ``intern`` module glob
                pattern is added with ``allow_empty=False``, and :meth:`close` is called (either explicitly or via ``__exit__``)
                before any modules match that pattern, an exception is thrown. If ``allow_empty=True``, no such exception is thrown.
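
        Example (an illustrative sketch; the package names are hypothetical)::

            exporter.intern("my_project.**", exclude=["my_project.tests.**"])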
        """
        self.patterns[GlobGroup(include, exclude=exclude)] = _PatternInfo(
            _ModuleProviderAction.INTERN, allow_empty
        )

    def mock(
        self,
        include: "GlobPattern",
        *,
        exclude: "GlobPattern" = (),
        allow_empty: bool = True,
    ):
        """Replace some required modules with a mock implementation. Mocked modules will return a fake
        object for any attribute accessed from them. Because we copy file-by-file, dependency resolution will sometimes
        find files that are imported by model files but whose functionality is never used
        (e.g. custom serialization code or training helpers).
        Use this function to mock this functionality out without having to modify the original code.

        Args:
            include (Union[List[str], str]): A string e.g. ``"my_package.my_subpackage"``, or list of strings
                for the names of the modules to be mocked out. Strings can also be a glob-style pattern
                string that may match multiple modules. Any required dependencies that match this pattern
                string will be mocked out automatically.

                Examples:
                    ``'torch.**'`` -- matches ``torch`` and all submodules of torch, e.g. ``'torch.nn'``
                    and ``'torch.nn.functional'``

                    ``'torch.*'`` -- matches ``'torch.nn'`` or ``'torch.functional'``, but not
                    ``'torch.nn.functional'``

            exclude (Union[List[str], str]): An optional pattern that excludes some of the names that match the include string.
                e.g. ``include='torch.**', exclude='torch.foo'`` will mock all torch packages except ``'torch.foo'``.
                Defaults to ``[]``.

            allow_empty (bool): An optional flag that specifies whether the mock implementation(s) specified by this call
                to the :meth:`mock` method must be matched to some module during packaging. If a mock is added with
                ``allow_empty=False``, and :meth:`close` is called (either explicitly or via ``__exit__``) and the mock has
                not been matched to a module used by the package being exported, an exception is thrown.
                If ``allow_empty=True``, no such exception is thrown.
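
        Example (an illustrative sketch; the module names are hypothetical)::

            # Mock out plotting helpers that the model code imports but never
            # uses at inference time.
            exporter.mock(include=["my_project.plotting.**"])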
        """
        self.patterns[GlobGroup(include, exclude=exclude)] = _PatternInfo(
            _ModuleProviderAction.MOCK, allow_empty
        )

    def extern(
        self,
        include: "GlobPattern",
        *,
        exclude: "GlobPattern" = (),
        allow_empty: bool = True,
    ):
        """Include ``module`` in the list of external modules the package can import.
        This will prevent dependency discovery from saving
        it in the package. The importer will load an external module directly from the standard import system.
        Code for extern modules must also exist in the process loading the package.

        Args:
            include (Union[List[str], str]): A string e.g. ``"my_package.my_subpackage"``, or list of strings
                for the names of the modules to be externed. This can also be a glob-style pattern, as
                described in :meth:`mock`.

            exclude (Union[List[str], str]): An optional pattern that excludes some of the names that match the
                include string.

            allow_empty (bool): An optional flag that specifies whether the extern modules specified by this call
                to the ``extern`` method must be matched to some module during packaging. If an extern module glob
                pattern is added with ``allow_empty=False``, and :meth:`close` is called (either explicitly or via
                ``__exit__``) before any modules match that pattern, an exception is thrown. If ``allow_empty=True``,
                no such exception is thrown.
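
        Example (an illustrative sketch; the library chosen here is arbitrary)::

            # Load numpy from the loading process's environment instead of
            # packaging its source.
            exporter.extern("numpy.**")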
        """
        self.patterns[GlobGroup(include, exclude=exclude)] = _PatternInfo(
            _ModuleProviderAction.EXTERN, allow_empty
        )

    def deny(self, include: "GlobPattern", *, exclude: "GlobPattern" = ()):
        """Blocklist modules whose names match the given glob patterns from the list of modules the package can import.
        If a dependency on any matching packages is found, a :class:`PackagingError` is raised.

        Args:
            include (Union[List[str], str]): A string e.g. ``"my_package.my_subpackage"``, or list of strings
                for the names of the modules to be denied. This can also be a glob-style pattern, as described in :meth:`mock`.

            exclude (Union[List[str], str]): An optional pattern that excludes some of the names that match the include string.
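
        Example (an illustrative sketch; the module name is hypothetical)::

            # Fail packaging outright if anything pulls in an internal-only module.
            exporter.deny("my_company.internal.**")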
        """
        self.patterns[GlobGroup(include, exclude=exclude)] = _PatternInfo(
            _ModuleProviderAction.DENY, allow_empty=True
        )

    def _persistent_id(self, obj):
        if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
            storage: Storage
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, we can
                # remove this case
                untyped_storage = obj._untyped_storage
                storage_type_str = obj.pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage = cast(Storage, untyped_storage)
                storage_numel = obj.size()

            elif isinstance(obj, torch.UntypedStorage):
                untyped_storage = obj
                storage = cast(Storage, untyped_storage)
                storage_type = normalize_storage_type(type(storage))
                storage_numel = storage.nbytes()
            else:
                raise RuntimeError(f"storage type not recognized: {type(obj)}")

            location = location_tag(storage)

            # Serialize the storage if it has not already been written.
            storage_present = self.storage_context.has_storage(storage)
            storage_id = self.storage_context.get_or_add_storage(storage)
            if not storage_present:
                if storage.device.type != "cpu":
                    storage = storage.cpu()
                num_bytes = storage.nbytes()
                self.zip_file.write_record(
                    f".data/{storage_id}.storage", storage, num_bytes
                )
            return ("storage", storage_type, storage_id, location, storage_numel)

        if hasattr(obj, "__reduce_package__"):
            if _gate_torchscript_serialization and isinstance(
                obj, torch.jit.RecursiveScriptModule
            ):
                raise Exception(  # noqa: TRY002
                    "Serializing ScriptModules directly into a package is a beta feature. "
                    "To use, set global "
                    "`torch.package.package_exporter._gate_torchscript_serialization` to `False`."
                )
            if self.serialized_reduces.get(id(obj)) is None:
                self.serialized_reduces[id(obj)] = (
                    "reduce_package",
                    id(obj),
                    *obj.__reduce_package__(self),
                )

            return self.serialized_reduces[id(obj)]

        return None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # If __exit__ was called because an exception was raised, we do not
        # attempt to finalize the package. Instead, control is returned to the
        # caller to continue raising the exception.
        if exc_type is not None:
            # Do the bare minimum to leave the open buffer in a valid state.
            self._finalize_zip()
            return

        self.close()

    def _write(self, filename, str_or_bytes):
        if filename in self._written_files:
            raise AssertionError(
                f"Tried to write file '{filename}', but it already exists in this archive. "
                "Please file a bug."
            )
        self._written_files.add(filename)

        if is_mangled(filename):
            raise AssertionError(
                f"Tried to save a torch.package'd module as '{filename}'. "
                "Directly saving torch.package'd modules is not allowed."
            )
        if isinstance(str_or_bytes, str):
            str_or_bytes = str_or_bytes.encode("utf-8")
        self.zip_file.write_record(filename, str_or_bytes, len(str_or_bytes))

    def _validate_dependency_graph(self):
        # 1. Check the graph for any errors inserted during dependency analysis.
        for attrs in self.dependency_graph.nodes.values():
            if "error" in attrs:
                raise PackagingError(self.dependency_graph, debug=self.debug)

        # 2. Check that all patterns for which allow_empty=False have been matched at least once.
        for pattern, pattern_info in self.patterns.items():
            if not pattern_info.allow_empty and not pattern_info.was_matched:
                raise EmptyMatchError(
                    f"Exporter did not match any modules to {pattern}, which was marked as allow_empty=False"
                )

    def _write_mock_file(self):
        if "_mock.py" not in self._written_files:
            mock_file = str(Path(__file__).parent / "_mock.py")
            self._write_source_string("_mock", _read_file(mock_file), is_package=False)

    def _execute_dependency_graph(self):
        """Takes a finalized dependency graph describing how to package all
        modules and executes it, writing to the ZIP archive.
        """
        self._validate_dependency_graph()

        extern_modules = []
        for module_name, attrs in self.dependency_graph.nodes.items():
            action = attrs["action"]

            if action == _ModuleProviderAction.EXTERN:
                for hook in self._extern_hooks.values():
                    hook(self, module_name)

                extern_modules.append(module_name)

            elif action == _ModuleProviderAction.MOCK:
                for hook in self._mock_hooks.values():
                    hook(self, module_name)

                self._write_mock_file()

                is_package = hasattr(self._import_module(module_name), "__path__")
                self._write_source_string(module_name, _MOCK_IMPL, is_package)

            elif action == _ModuleProviderAction.INTERN:
                for hook in self._intern_hooks.values():
                    hook(self, module_name)

                # The node in the dependency graph contains metadata that tells us
                # how to intern the module.
                if "provided" not in attrs:
                    raise AssertionError(
                        f"Module was marked `intern` but not provided: {module_name}"
                    )

                if attrs.get("is_pickle") is True:
                    # This node came from save_pickle; we don't need to write any source for it.
                    continue

                is_package = attrs["is_package"]
                source = attrs["source"]
                self._write_source_string(module_name, source, is_package)

            elif action == _ModuleProviderAction.REPACKAGED_MOCK_MODULE:
                self._write_mock_file()
            elif action == _ModuleProviderAction.SKIP:
                continue
            else:
                raise AssertionError(
                    f"Invalid action: {module_name}, {action}. Please report a bug to PyTorch."
                )

        extern_file_contents = "\n".join(extern_modules) + "\n"
        self._write(".data/extern_modules", extern_file_contents)

    def _write_python_version(self):
        """Writes the Python version that the package was created with to .data/python_version"""
        self._write(".data/python_version", platform.python_version())

    def close(self):
        """Write the package to the filesystem. Any calls after :meth:`close` are now invalid.
        It is preferable to use resource guard syntax instead::

            with PackageExporter("file.zip") as e:
                ...
        """
        self._execute_dependency_graph()
        self._write_python_version()

        self.script_module_serializer.write_files()
        self._finalize_zip()

    def _finalize_zip(self):
        """Called at the very end of packaging to leave the zipfile in a closed but valid state."""
        del self.zip_file
        if self.buffer:
            self.buffer.flush()

    def _filename(self, package, resource):
        package_path = package.replace(".", "/")
        resource = _normalize_path(resource)
        return f"{package_path}/{resource}"

    def _can_implicitly_extern(self, module_name: str):
        top_level_package_name = module_name.partition(".")[0]
        return top_level_package_name == "torch" or (
            top_level_package_name not in _DISALLOWED_MODULES
            and is_stdlib_module(top_level_package_name)
        )

    def dependency_graph_string(self) -> str:
        """Returns a digraph string representation of the dependencies in the package.

        Returns:
            A string representation of the dependencies in the package.
        """
        return self.dependency_graph.to_dot()

    def _nodes_with_action_type(
        self, action: Optional[_ModuleProviderAction]
    ) -> List[str]:
        result = []
        for name, node_dict in self.dependency_graph.nodes.items():
            node_action = node_dict.get("action", None)
            if node_action == action and "is_pickle" not in node_dict:
                result.append(name)
        result.sort()
        return result

    def externed_modules(self) -> List[str]:
        """Return all modules that are currently externed.

        Returns:
            A list containing the names of modules which will be
            externed in this package.
        """
        return self._nodes_with_action_type(_ModuleProviderAction.EXTERN)

    def interned_modules(self) -> List[str]:
        """Return all modules that are currently interned.

        Returns:
            A list containing the names of modules which will be
            interned in this package.
        """
        return self._nodes_with_action_type(_ModuleProviderAction.INTERN)

    def mocked_modules(self) -> List[str]:
        """Return all modules that are currently mocked.

        Returns:
            A list containing the names of modules which will be
            mocked in this package.
        """
        return self._nodes_with_action_type(_ModuleProviderAction.MOCK)

    def denied_modules(self) -> List[str]:
        """Return all modules that are currently denied.

        Returns:
            A list containing the names of modules which will be
            denied in this package.
        """
        return self._nodes_with_action_type(_ModuleProviderAction.DENY)
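
    # A hedged sketch of how the introspection methods above are typically used
    # while an exporter is still open (module and file names are illustrative):
    #
    #     with PackageExporter("out.pt") as exporter:
    #         exporter.intern("my_project.**")
    #         exporter.save_pickle("model", "model.pkl", my_model)
    #         print(exporter.interned_modules())
    #         print(exporter.externed_modules())
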
    def get_rdeps(self, module_name: str) -> List[str]:
        """Return a list of all modules which depend on the module ``module_name``.

        Returns:
            A list containing the names of modules which depend on ``module_name``.
        """
        if module_name in self.dependency_graph._pred.keys():
            return list(self.dependency_graph._pred[module_name].keys())
        else:
            return []

    def all_paths(self, src: str, dst: str) -> str:
        """Return a dot representation of the subgraph
        that has all paths from src to dst.

        Returns:
            A dot representation containing all paths from src to dst.
            (https://graphviz.org/doc/info/lang.html)
        """
        return self.dependency_graph.all_paths(src, dst)


# even though these are in the standard library, we do not allow them to be
# automatically externed since they offer a lot of system level access
_DISALLOWED_MODULES = ["sys", "io"]

_MOCK_IMPL = """\
from _mock import MockedObject
def __getattr__(attr: str):
    return MockedObject(__name__ + '.' + attr, _suppress_err=True)
"""


def _read_file(filename: str) -> str:
    with open(filename, "rb") as f:
        b = f.read()
        return b.decode("utf-8")