xref: /aosp_15_r20/external/pytorch/torch/package/package_exporter.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import collections
3import importlib.machinery
4import io
5import linecache
6import pickletools
7import platform
8import types
9from collections import defaultdict, OrderedDict
10from dataclasses import dataclass
11from enum import Enum
12from importlib.machinery import SourceFileLoader
13from pathlib import Path
14from typing import (
15    Any,
16    BinaryIO,
17    Callable,
18    cast,
19    DefaultDict,
20    Dict,
21    List,
22    Optional,
23    Sequence,
24    Set,
25    Union,
26)
27
28import torch
29from torch.serialization import location_tag, normalize_storage_type
30from torch.types import Storage
31from torch.utils.hooks import RemovableHandle
32
33from ._digraph import DiGraph
34from ._importlib import _normalize_path
35from ._mangling import demangle, is_mangled
36from ._package_pickler import create_pickler
37from ._stdlib import is_stdlib_module
38from .find_file_dependencies import find_files_source_depends_on
39from .glob_group import GlobGroup, GlobPattern
40from .importer import Importer, OrderedImporter, sys_importer
41
42
# Names that make up this module's public API.
__all__ = [
    "PackagingErrorReason",
    "EmptyMatchError",
    "PackagingError",
    "PackageExporter",
]

# NOTE(review): module-level switch that nothing in the visible part of this
# file reads -- presumably it gates TorchScript serialization elsewhere; confirm.
_gate_torchscript_serialization = True

# Signature shared by exporter hooks: a hook is called with the exporter and
# the name of a module, and returns nothing.
ActionHook = Callable[["PackageExporter", str], None]
53
54
class _ModuleProviderAction(Enum):
    """Represents one of the actions that :class:`PackageExporter` can take on a module.

    The action chosen for a module is stored on that module's node in the
    exporter's dependency graph.

    See :meth:`PackageExporter.extern` and friends for a description of what the actions do.
    """

    # The module's source is retrieved and packaged, along with its dependencies.
    INTERN = 1
    # The module is listed as external rather than being copied into the package.
    EXTERN = 2
    # The module is replaced by mocking stubs (see the `_mock` note below).
    MOCK = 3
    # Requiring the module records an error in the dependency graph.
    DENY = 4
    # Special case: when a module is mocked, PackageExporter writes out a
    # `_mock` module that implements our mocking stubs. If we re-package code,
    # we may encounter a `_mock` module from the original package. If we do,
    # just ignore it and write a `_mock` module once.
    REPACKAGED_MOCK_MODULE = 5
    # Special case: PackageImporter adds a fake module
    # (`torch_package_importer`) that allows packaged code to access it. Don't
    # re-export this.
    SKIP = 6
74
75
class PackagingErrorReason(Enum):
    """Reasons why a dependency can fail to be packaged.

    Each member's value is the human-readable explanation that
    :class:`PackagingError` prints above the modules grouped under that reason.
    """

    IS_EXTENSION_MODULE = (
        "Module is a C extension module. torch.package supports Python modules only."
    )
    NO_DUNDER_FILE = "Module had no __file__ defined."
    SOURCE_FILE_NOT_FOUND = (
        "Module had a __file__, but we could not find it in your filesystem."
    )
    DEPENDENCY_RESOLUTION_FAILED = "Dependency resolution failed."
    NO_ACTION = (
        "Module did not match against any action pattern. Extern, mock, or intern it."
    )
    DENIED = "Module was denied by a pattern."
    MOCKED_BUT_STILL_USED = (
        "Module was mocked out, but is still being used in the package. "
        "Please intern or extern the mocked modules if objects are supposed to be in "
        "the package."
    )

    def __repr__(self):
        # Short form that omits the long message value,
        # e.g. ``<PackagingErrorReason.DENIED>``.
        return f"<{type(self).__name__}.{self.name}>"
103
104
@dataclass
class _PatternInfo:
    """Holds :class:`PackageExporter`-specific info about how to execute matches
    against a user-supplied pattern.

    Instances are the values of ``PackageExporter.patterns``, keyed by the
    pattern they describe.
    """

    # Action applied to any module that matches the pattern.
    action: _ModuleProviderAction
    # The ``allow_empty`` value the user supplied together with the pattern.
    allow_empty: bool
    # Set to True once some module has matched the pattern during packaging.
    was_matched: bool

    def __init__(self, action, allow_empty):
        # A freshly registered pattern has not matched anything yet.
        self.was_matched = False
        self.action, self.allow_empty = action, allow_empty
120
121
class EmptyMatchError(Exception):
    """Raised when a mock or extern pattern is marked ``allow_empty=False``
    and is not matched with any module during packaging.
    """
126
127
class PackagingError(Exception):
    """This exception is raised when there is an issue with exporting a package.
    ``PackageExporter`` will attempt to gather up all the errors and present
    them to you at once.

    The exception message lists every broken module, grouped by
    :class:`PackagingErrorReason`.

    Args:
        dependency_graph: Graph whose nodes may carry an ``"error"`` attribute
            (and optionally ``"error_context"``) describing why they are broken.
        debug: If true, a path through the graph to each broken module is
            included in the message.

    Attributes:
        dependency_graph: The graph that was passed in, kept so that tooling
            can inspect it.
    """

    def __init__(self, dependency_graph: DiGraph, debug=False):
        # Group errors by reason.
        broken: Dict[PackagingErrorReason, List[str]] = defaultdict(list)
        for module_name, attrs in dependency_graph.nodes.items():
            error = attrs.get("error")
            if error is None:
                continue
            if error == PackagingErrorReason.NO_ACTION:
                # A module with no matching pattern must not have an action.
                assert "action" not in attrs
            broken[error].append(module_name)

        message = io.StringIO()
        message.write("\n")

        for reason, module_names in broken.items():
            message.write(f"* {reason.value}\n")
            for module_name in module_names:
                message.write(f"    {module_name}\n")

                # Print additional context if it's provided.
                error_context = dependency_graph.nodes[module_name].get("error_context")
                if error_context is not None:
                    message.write(f"      Context: {error_context}\n")
                if module_name in _DISALLOWED_MODULES:
                    # The suggested call quotes the bare module name (no
                    # backticks) so that it can be copied verbatim.
                    message.write(
                        "      Note: While we usually use modules in the python standard library "
                        f"from the local environment, `{module_name}` has a lot of system "
                        "level access and therefore can pose a security risk. We heavily "
                        f"recommend removing `{module_name}` from your packaged code. However, if that "
                        "is not possible, add it to the extern list by calling "
                        f'PackageExporter.extern("{module_name}")\n'
                    )
                if debug:
                    module_path = dependency_graph.first_path(module_name)
                    message.write(
                        f"      A path to {module_name}: {' -> '.join(module_path)}\n"
                    )
        if not debug:
            message.write("\n")
            message.write(
                "Set debug=True when invoking PackageExporter for a visualization of where "
                "broken modules are coming from!\n"
            )
        # Save the dependency graph so that tooling can get at it.
        self.dependency_graph = dependency_graph
        super().__init__(message.getvalue())
180
181
182class PackageExporter:
183    """Exporters allow you to write packages of code, pickled Python data, and
184    arbitrary binary and text resources into a self-contained package.
185
    Importers can load this code in a hermetic way, such that code is loaded
187    from the package rather than the normal Python import system. This allows
188    for the packaging of PyTorch model code and data so that it can be run
189    on a server or used in the future for transfer learning.
190
191    The code contained in packages is copied file-by-file from the original
192    source when it is created, and the file format is a specially organized
193    zip file. Future users of the package can unzip the package, and edit the code
194    in order to perform custom modifications to it.
195
196    The importer for packages ensures that code in the module can only be loaded from
197    within the package, except for modules explicitly listed as external using :meth:`extern`.
198    The file ``extern_modules`` in the zip archive lists all the modules that a package externally depends on.
199    This prevents "implicit" dependencies where the package runs locally because it is importing
200    a locally-installed package, but then fails when the package is copied to another machine.
201
202    When source code is added to the package, the exporter can optionally scan it
203    for further code dependencies (``dependencies=True``). It looks for import statements,
204    resolves relative references to qualified module names, and performs an action specified by the user
205    (See: :meth:`extern`, :meth:`mock`, and :meth:`intern`).
206    """
207
    """An importer that will be searched in order to find the modules referenced by other modules or by
    pickled objects. The default module environment just uses sys_importer, which searches the Python environment.
    """
211    importer: Importer
212
213    def __init__(
214        self,
215        f: Union[str, Path, BinaryIO],
216        importer: Union[Importer, Sequence[Importer]] = sys_importer,
217        debug: bool = False,
218    ):
219        """
220        Create an exporter.
221
222        Args:
223            f: The location to export to. Can be a  ``string``/``Path`` object containing a filename
224                or a binary I/O object.
225            importer: If a single Importer is passed, use that to search for modules.
226                If a sequence of importers are passed, an ``OrderedImporter`` will be constructed out of them.
227            debug: If set to True, add path of broken modules to PackagingErrors.
228        """
229        torch._C._log_api_usage_once("torch.package.PackageExporter")
230        self.debug = debug
231        if isinstance(f, (Path, str)):
232            f = str(f)
233            self.buffer: Optional[BinaryIO] = None
234        else:  # is a byte buffer
235            self.buffer = f
236
237        self.zip_file = torch._C.PyTorchFileWriter(f)
238        self.zip_file.set_min_version(6)
239        self._written_files: Set[str] = set()
240
241        self.serialized_reduces: Dict[int, Any] = {}
242
243        # A graph tracking all the modules and pickle objects added to this
244        # package and the dependencies between them.
245        # - Each node is a module name (or a pickle name that looks like '<foo.obj.pkl>')
246        # - Each directed edge (u, v) means u depends on v.
247        # - Nodes may contain metadata that describe how to write the thing to the zipfile.
248        self.dependency_graph = DiGraph()
249        self.script_module_serializer = torch._C.ScriptModuleSerializer(self.zip_file)
250        self.storage_context = self.script_module_serializer.storage_context()
251
252        # These are OrderedDicts for compatibility with RemovableHandle.
253        # Generic OrderedDict type annotations are not present until 3.7.
254        # The real type signature is OrderedDict[int, Callable[[PackageExporter, str], None]]
255        self._extern_hooks: OrderedDict = OrderedDict()
256        self._mock_hooks: OrderedDict = OrderedDict()
257        self._intern_hooks: OrderedDict = OrderedDict()
258
259        if isinstance(importer, Importer):
260            self.importer = importer
261        else:
262            if not isinstance(importer, collections.abc.Sequence):
263                raise TypeError(
264                    "importer arg should be an Importer or a sequence of Importers, "
265                    f"got {type(importer)} instead."
266                )
267            self.importer = OrderedImporter(*importer)
268
269        self.patterns: Dict[GlobGroup, _PatternInfo] = {}
270        self._unique_id = 0
271
272    def save_source_file(
273        self, module_name: str, file_or_directory: str, dependencies=True
274    ):
275        """Adds the local file system ``file_or_directory`` to the source package to provide the code
276        for ``module_name``.
277
278        Args:
279            module_name (str): e.g. ``"my_package.my_subpackage"``, code will be saved to provide code for this package.
280            file_or_directory (str): the path to a file or directory of code. When a directory, all python files in the directory
281                are recursively copied using :meth:`save_source_file`. If a file is named ``"/__init__.py"`` the code is treated
282                as a package.
283            dependencies (bool, optional): If ``True``, we scan the source for dependencies.
284        """
285        path = Path(file_or_directory)
286        if path.is_dir():
287            to_save = []  # list of tuples with arguments to save_source_string
288            module_path = module_name.replace(".", "/")
289            for filename in path.glob("**/*.py"):
290                relative_path = filename.relative_to(path).as_posix()
291                archivename = module_path + "/" + relative_path
292                submodule_name = None
293                if filename.name == "__init__.py":
294                    submodule_name = archivename[: -len("/__init__.py")].replace(
295                        "/", "."
296                    )
297                    is_package = True
298                else:
299                    submodule_name = archivename[: -len(".py")].replace("/", ".")
300                    is_package = False
301
302                # we delay the call to save_source_string so that we record all the source files
303                # being provided by this directory structure _before_ attempting to resolve the dependencies
304                # on the source. This makes sure we don't try to copy over modules that will just get
305                # overwritten by this directory blob
306                to_save.append(
307                    (
308                        submodule_name,
309                        _read_file(str(filename)),
310                        is_package,
311                        dependencies,
312                    )
313                )
314
315            for item in to_save:
316                self.save_source_string(*item)
317        else:
318            is_package = path.name == "__init__.py"
319            self.save_source_string(
320                module_name,
321                _read_file(file_or_directory),
322                is_package,
323                dependencies,
324            )
325
326    def get_unique_id(self) -> str:
327        """Get an id. This id is guaranteed to only be handed out once for this package."""
328        ret = str(self._unique_id)
329        self._unique_id += 1
330        return ret
331
332    def _get_dependencies(
333        self, src: str, module_name: str, is_package: bool
334    ) -> List[str]:
335        """Return all modules that this source code depends on.
336
337        Dependencies are found by scanning the source code for import-like statements.
338
339        Arguments:
340            src: The Python source code to analyze for dependencies.
341            module_name: The name of the module that ``src`` corresponds to.
342            is_package: Whether this module should be treated as a package.
343                See :py:meth:`save_source_string` for more info.
344
345        Returns:
346            A list containing modules detected as direct dependencies in
347            ``src``.  The items in the list are guaranteed to be unique.
348        """
349        package_name = (
350            module_name if is_package else module_name.rsplit(".", maxsplit=1)[0]
351        )
352        try:
353            dep_pairs = find_files_source_depends_on(src, package_name)
354        except Exception as e:
355            self.dependency_graph.add_node(
356                module_name,
357                error=PackagingErrorReason.DEPENDENCY_RESOLUTION_FAILED,
358                error_context=str(e),
359            )
360            return []
361
362        # Use a dict to get uniquing but also deterministic order
363        dependencies = {}
364        for dep_module_name, dep_module_obj in dep_pairs:
365            # handle the case where someone did something like `from pack import sub`
366            # where `sub` is a submodule. In this case we don't have to save pack, just sub.
367            # this ensures we don't pick up additional dependencies on pack.
368            # However, in the case where `sub` is not a submodule but an object, then we do have
369            # to save pack.
370            if dep_module_obj is not None:
371                possible_submodule = f"{dep_module_name}.{dep_module_obj}"
372                if self._module_exists(possible_submodule):
373                    dependencies[possible_submodule] = True
374                    # we don't need to save `pack`
375                    continue
376            if self._module_exists(dep_module_name):
377                dependencies[dep_module_name] = True
378
379        return list(dependencies.keys())
380
381    def save_source_string(
382        self,
383        module_name: str,
384        src: str,
385        is_package: bool = False,
386        dependencies: bool = True,
387    ):
388        """Adds ``src`` as the source code for ``module_name`` in the exported package.
389
390        Args:
391            module_name (str): e.g. ``my_package.my_subpackage``, code will be saved to provide code for this package.
392            src (str): The Python source code to save for this package.
393            is_package (bool, optional): If ``True``, this module is treated as a package. Packages are allowed to have submodules
394                (e.g. ``my_package.my_subpackage.my_subsubpackage``), and resources can be saved inside them. Defaults to ``False``.
395            dependencies (bool, optional): If ``True``, we scan the source for dependencies.
396        """
397        self.dependency_graph.add_node(
398            module_name,
399            source=src,
400            is_package=is_package,
401            provided=True,
402            action=_ModuleProviderAction.INTERN,
403        )
404
405        if dependencies:
406            deps = self._get_dependencies(src, module_name, is_package)
407
408            for dep in deps:
409                self.dependency_graph.add_edge(module_name, dep)
410                self.add_dependency(dep)
411
412    def _write_source_string(
413        self,
414        module_name: str,
415        src: str,
416        is_package: bool = False,
417    ):
418        """Write ``src`` as the source code for ``module_name`` in the zip archive.
419
420        Arguments are otherwise the same as for :meth:`save_source_string`.
421        """
422        extension = "/__init__.py" if is_package else ".py"
423        filename = module_name.replace(".", "/") + extension
424
425        self._write(filename, src)
426
427    def _import_module(self, module_name: str):
428        try:
429            return self.importer.import_module(module_name)
430        except ModuleNotFoundError as e:
431            if not is_mangled(module_name):
432                raise
433            msg = (
434                f"Module not found: '{module_name}'. Make sure the PackageImporter that "
435                "created this module is present in `self.importer`"
436            )
437            raise ModuleNotFoundError(msg) from None
438
439    def _module_exists(self, module_name: str) -> bool:
440        try:
441            self._import_module(module_name)
442            return True
443        except Exception:
444            return False
445
446    def _get_source_of_module(self, module: types.ModuleType) -> Optional[str]:
447        filename = None
448        spec = getattr(module, "__spec__", None)
449        if spec is not None:
450            loader = getattr(spec, "loader", None)
451            if loader is not None and isinstance(loader, SourceFileLoader):
452                try:
453                    filename = loader.get_filename(module.__name__)
454                except ImportError:
455                    pass
456        if filename is None:
457            filename = getattr(module, "__file__", None)
458        if isinstance(filename, str) and filename.endswith(".py"):
459            return "".join(linecache.getlines(filename, module.__dict__))
460        return None
461
462    def add_dependency(self, module_name: str, dependencies=True):
463        """Given a module, add it to the dependency graph according to patterns
464        specified by the user.
465        """
466        if (
467            module_name in self.dependency_graph
468            and self.dependency_graph.nodes[module_name].get("provided") is True
469        ):
470            return
471
472        # Special case: PackageImporter provides a special module called
473        # `torch_package_importer` that allows packaged modules to reference
474        # their PackageImporter. We don't want to re-export this.
475        if module_name == "torch_package_importer":
476            self.dependency_graph.add_node(
477                module_name,
478                action=_ModuleProviderAction.SKIP,
479                provided=True,
480            )
481            return
482
483        if module_name == "_mock":
484            self.dependency_graph.add_node(
485                module_name,
486                action=_ModuleProviderAction.REPACKAGED_MOCK_MODULE,
487                provided=True,
488            )
489            return
490
491        if self._can_implicitly_extern(module_name):
492            self.dependency_graph.add_node(
493                module_name, action=_ModuleProviderAction.EXTERN, provided=True
494            )
495            return
496
497        for pattern, pattern_info in self.patterns.items():
498            if pattern.matches(module_name):
499                pattern_info.was_matched = True
500                self.dependency_graph.add_node(
501                    module_name, action=pattern_info.action, provided=True
502                )
503
504                if pattern_info.action == _ModuleProviderAction.DENY:
505                    # Requiring a denied module just adds an error to the graph.
506                    self.dependency_graph.add_node(
507                        module_name, error=PackagingErrorReason.DENIED
508                    )
509
510                # If we are interning this module, we need to retrieve its
511                # dependencies and package those as well.
512                if pattern_info.action == _ModuleProviderAction.INTERN:
513                    self._intern_module(module_name, dependencies)
514                return
515
516        # No patterns have matched. Explicitly add this as an error.
517        self.dependency_graph.add_node(
518            module_name, error=PackagingErrorReason.NO_ACTION
519        )
520
521    def save_module(self, module_name: str, dependencies=True):
522        """Save the code for ``module`` into the package. Code for the module is resolved using the ``importers`` path to find the
523        module object, and then using its ``__file__`` attribute to find the source code.
524
525        Args:
526            module_name (str): e.g. ``my_package.my_subpackage``, code will be saved to provide code
527                for this package.
528            dependencies (bool, optional): If ``True``, we scan the source for dependencies.
529        """
530        if not isinstance(module_name, str):
531            raise TypeError(
532                "save_module() expects a string input, did you perhaps mean to pass `__name__`?"
533            )
534
535        self._intern_module(module_name, dependencies)
536
537    def _intern_module(
538        self,
539        module_name: str,
540        dependencies: bool,
541    ):
542        """Adds the module to the dependency graph as an interned module,
543        along with any metadata needed to write it out to the zipfile at serialization time.
544        """
545        module_obj = self._import_module(module_name)
546        # Subtle: if the import above succeeded, either:
547        #   1. The module name is not mangled, and this was just a regular import, or
548        #   2. The module name is mangled, but one of the importers was able to
549        #      recognize the mangling and import it.
550        # Either way, it is now safe to demangle this name so that we don't
551        # serialize the mangled version to the package.
552        module_name = demangle(module_name)
553
554        # Find dependencies of this module and require them as well.
555        is_package = hasattr(module_obj, "__path__")
556        source = self._get_source_of_module(module_obj)
557        if source is None:
558            # Couldn't find a source!  Add it to our dependency graph as broken
559            # and continue.
560            filename = getattr(module_obj, "__file__", None)
561            error_context = None
562            if filename is None:
563                packaging_error = PackagingErrorReason.NO_DUNDER_FILE
564            elif filename.endswith(tuple(importlib.machinery.EXTENSION_SUFFIXES)):
565                packaging_error = PackagingErrorReason.IS_EXTENSION_MODULE
566            else:
567                packaging_error = PackagingErrorReason.SOURCE_FILE_NOT_FOUND
568                error_context = f"filename: {filename}"
569            self.dependency_graph.add_node(
570                module_name,
571                action=_ModuleProviderAction.INTERN,
572                is_package=is_package,
573                error=packaging_error,
574                error_context=error_context,
575                provided=True,
576            )
577            return
578
579        self.dependency_graph.add_node(
580            module_name,
581            action=_ModuleProviderAction.INTERN,
582            is_package=is_package,
583            source=source,
584            provided=True,
585        )
586
587        if dependencies:
588            deps = self._get_dependencies(source, module_name, is_package)
589            for dep in deps:
590                self.dependency_graph.add_edge(module_name, dep)
591                self.add_dependency(dep)
592
    def save_pickle(
        self,
        package: str,
        resource: str,
        obj: Any,
        dependencies: bool = True,
        pickle_protocol: int = 3,
    ):
        """Save a python object to the archive using pickle. Equivalent to :func:`torch.save` but saving into
        the archive rather than a stand-alone file. Standard pickle does not save the code, only the objects.
        If ``dependencies`` is true, this method will also scan the pickled objects for which modules are required
        to reconstruct them and save the relevant code.

        To be able to save an object where ``type(obj).__name__`` is ``my_module.MyObject``,
        ``my_module.MyObject`` must resolve to the class of the object according to the ``importer`` order. When saving objects that
        have previously been packaged, the importer's ``import_module`` method will need to be present in the ``importer`` list
        for this to work.

        Args:
            package (str): The name of module package this resource should go in (e.g. ``"my_package.my_subpackage"``).
            resource (str): A unique name for the resource, used to identify it to load.
            obj (Any): The object to save, must be picklable.
            dependencies (bool, optional): If ``True``, we scan the source for dependencies.
            pickle_protocol (int, optional): Pickle protocol to use; only ``3`` and ``4`` are accepted.
                Defaults to ``3``.

        Raises:
            AssertionError: If ``pickle_protocol`` is neither 3 nor 4.
        """

        assert (pickle_protocol == 4) or (
            pickle_protocol == 3
        ), "torch.package only supports pickle protocols 3 and 4"

        filename = self._filename(package, resource)
        # Write the pickle data for `obj`
        data_buf = io.BytesIO()
        pickler = create_pickler(data_buf, self.importer, protocol=pickle_protocol)
        pickler.persistent_id = self._persistent_id
        pickler.dump(obj)
        data_value = data_buf.getvalue()
        # Maps a mocked module name to the fields of it that the pickle references.
        mocked_modules = defaultdict(list)
        # The pickle gets its own node in the dependency graph, named '<package.resource>'.
        name_in_dependency_graph = f"<{package}.{resource}>"
        self.dependency_graph.add_node(
            name_in_dependency_graph,
            action=_ModuleProviderAction.INTERN,
            provided=True,
            is_pickle=True,
        )

        def _check_mocked_error(module: Optional[str], field: Optional[str]):
            """
            Checks if an object (field) comes from a mocked module and, if so, adds
            the pair to mocked_modules, which maps mocked modules to the list of
            their objects present in the pickle.

            Implicitly externed modules are never treated as mocked. Otherwise we
            hold the invariant that the first user defined rule that applies
            to the module is the one we use.
            """

            assert isinstance(module, str)
            assert isinstance(field, str)
            if self._can_implicitly_extern(module):
                return
            for pattern, pattern_info in self.patterns.items():
                if pattern.matches(module):
                    if pattern_info.action == _ModuleProviderAction.MOCK:
                        mocked_modules[module].append(field)
                    # Only the first matching pattern is consulted.
                    return

        if dependencies:
            # Module names referenced by the pickle, unique, in order of first appearance.
            all_dependencies = []
            # For protocol 4: the two most recently seen string values
            # (`module` is the older one, `field` the newer one).
            module = None
            field = None
            # `defaultdict(None)` has no default factory; lookups below go through `.get`.
            memo: DefaultDict[int, str] = defaultdict(None)
            memo_count = 0
            # pickletools.dis(data_value)
            for opcode, arg, pos in pickletools.genops(data_value):
                if pickle_protocol == 4:
                    if (
                        opcode.name == "SHORT_BINUNICODE"
                        or opcode.name == "BINUNICODE"
                        or opcode.name == "BINUNICODE8"
                    ):
                        # A string value: shift it into the (module, field) window and
                        # record it under the current memo index.
                        assert isinstance(arg, str)
                        module = field
                        field = arg
                        memo[memo_count] = arg
                    elif (
                        opcode.name == "LONG_BINGET"
                        or opcode.name == "BINGET"
                        or opcode.name == "GET"
                    ):
                        # A memo read: shift the recorded string (None if nothing was
                        # recorded for that index) into the window.
                        assert isinstance(arg, int)
                        module = field
                        field = memo.get(arg, None)
                    elif opcode.name == "MEMOIZE":
                        memo_count += 1
                    elif opcode.name == "STACK_GLOBAL":
                        # A global reference: the window is taken as (module, field).
                        if module is None:
                            # If no module was passed on in the entries preceding this one, continue.
                            continue
                        assert isinstance(module, str)
                        if module not in all_dependencies:
                            all_dependencies.append(module)
                        _check_mocked_error(module, field)
                elif (
                    pickle_protocol == 3 and opcode.name == "GLOBAL"
                ):  # a global reference
                    # The argument is split on a space into module and field.
                    assert isinstance(arg, str)
                    module, field = arg.split(" ")
                    if module not in all_dependencies:
                        all_dependencies.append(module)
                    _check_mocked_error(module, field)
            for module_name in all_dependencies:
                self.dependency_graph.add_edge(name_in_dependency_graph, module_name)

                """ If an object happens to come from a mocked module, then we collect these errors and spit them
                    out with the other errors found by package exporter.
                """
                if module_name in mocked_modules:
                    assert isinstance(module_name, str)
                    fields = mocked_modules[module_name]
                    self.dependency_graph.add_node(
                        module_name,
                        action=_ModuleProviderAction.MOCK,
                        error=PackagingErrorReason.MOCKED_BUT_STILL_USED,
                        error_context=f"Object(s) '{fields}' from module `{module_name}` was mocked out during packaging "
                        f"but is being used in resource - `{resource}` in package `{package}`. ",
                        provided=True,
                    )
                else:
                    self.add_dependency(module_name)

        # The pickle bytes are written whether or not dependencies were scanned.
        self._write(filename, data_value)
723
724    def save_text(self, package: str, resource: str, text: str):
725        """Save text data to the package.
726
727        Args:
728            package (str): The name of module package this resource should go it (e.g. ``"my_package.my_subpackage"``).
729            resource (str): A unique name for the resource, used to identify it to load.
730            text (str): The contents to save.
731        """
732        return self.save_binary(package, resource, text.encode("utf-8"))
733
734    def save_binary(self, package, resource, binary: bytes):
735        """Save raw bytes to the package.
736
737        Args:
738            package (str): The name of module package this resource should go it (e.g. ``"my_package.my_subpackage"``).
739            resource (str): A unique name for the resource, used to identify it to load.
740            binary (str): The data to save.
741        """
742        filename = self._filename(package, resource)
743        self._write(filename, binary)
744
745    def register_extern_hook(self, hook: ActionHook) -> RemovableHandle:
746        """Registers an extern hook on the exporter.
747
748        The hook will be called each time a module matches against an :meth:`extern` pattern.
749        It should have the following signature::
750
751            hook(exporter: PackageExporter, module_name: str) -> None
752
753        Hooks will be called in order of registration.
754
755        Returns:
756            :class:`torch.utils.hooks.RemovableHandle`:
757                A handle that can be used to remove the added hook by calling
758                ``handle.remove()``.
759        """
760        handle = RemovableHandle(self._extern_hooks)
761        self._extern_hooks[handle.id] = hook
762        return handle
763
764    def register_mock_hook(self, hook: ActionHook) -> RemovableHandle:
765        """Registers a mock hook on the exporter.
766
767        The hook will be called each time a module matches against a :meth:`mock` pattern.
768        It should have the following signature::
769
770            hook(exporter: PackageExporter, module_name: str) -> None
771
772        Hooks will be called in order of registration.
773
774        Returns:
775            :class:`torch.utils.hooks.RemovableHandle`:
776                A handle that can be used to remove the added hook by calling
777                ``handle.remove()``.
778        """
779        handle = RemovableHandle(self._mock_hooks)
780        self._mock_hooks[handle.id] = hook
781        return handle
782
783    def register_intern_hook(self, hook: ActionHook) -> RemovableHandle:
784        """Registers an intern hook on the exporter.
785
786        The hook will be called each time a module matches against an :meth:`intern` pattern.
787        It should have the following signature::
788
789            hook(exporter: PackageExporter, module_name: str) -> None
790
791        Hooks will be called in order of registration.
792
793        Returns:
794            :class:`torch.utils.hooks.RemovableHandle`:
795                A handle that can be used to remove the added hook by calling
796                ``handle.remove()``.
797        """
798        handle = RemovableHandle(self._intern_hooks)
799        self._intern_hooks[handle.id] = hook
800        return handle
801
802    def intern(
803        self,
804        include: "GlobPattern",
805        *,
806        exclude: "GlobPattern" = (),
807        allow_empty: bool = True,
808    ):
809        """Specify modules that should be packaged. A module must match some ``intern`` pattern in order to be
810        included in the package and have its dependencies processed recursively.
811
812        Args:
813            include (Union[List[str], str]): A string e.g. "my_package.my_subpackage", or list of strings
814                for the names of the modules to be externed. This can also be a glob-style pattern, as described in :meth:`mock`.
815
816            exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the include string.
817
818            allow_empty (bool): An optional flag that specifies whether the intern modules specified by this call
819                to the ``intern`` method must be matched to some module during packaging. If an ``intern`` module glob
820                pattern is added with ``allow_empty=False``, and :meth:`close` is called (either explicitly or via ``__exit__``)
821                before any modules match that pattern, an exception is thrown. If ``allow_empty=True``, no such exception is thrown.
822
823        """
824        self.patterns[GlobGroup(include, exclude=exclude)] = _PatternInfo(
825            _ModuleProviderAction.INTERN, allow_empty
826        )
827
828    def mock(
829        self,
830        include: "GlobPattern",
831        *,
832        exclude: "GlobPattern" = (),
833        allow_empty: bool = True,
834    ):
835        """Replace some required modules with a mock implementation.  Mocked modules will return a fake
836        object for any attribute accessed from it. Because we copy file-by-file, the dependency resolution will sometimes
837        find files that are imported by model files but whose functionality is never used
838        (e.g. custom serialization code or training helpers).
839        Use this function to mock this functionality out without having to modify the original code.
840
841        Args:
842            include (Union[List[str], str]): A string e.g. ``"my_package.my_subpackage"``, or list of strings
843                for the names of the modules to be mocked out. Strings can also be a glob-style pattern
844                string that may match multiple modules. Any required dependencies that match this pattern
845                string will be mocked out automatically.
846
847                Examples :
848                    ``'torch.**'`` -- matches ``torch`` and all submodules of torch, e.g. ``'torch.nn'``
849                    and ``'torch.nn.functional'``
850
851                    ``'torch.*'`` -- matches ``'torch.nn'`` or ``'torch.functional'``, but not
852                    ``'torch.nn.functional'``
853
854            exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the include string.
855                e.g. ``include='torch.**', exclude='torch.foo'`` will mock all torch packages except ``'torch.foo'``,
856                Default: is ``[]``.
857
858            allow_empty (bool): An optional flag that specifies whether the mock implementation(s) specified by this call
859                to the :meth:`mock` method must be matched to some module during packaging. If a mock is added with
860                ``allow_empty=False``, and :meth:`close` is called (either explicitly or via ``__exit__``) and the mock has
861                not been matched to a module used by the package being exported, an exception is thrown.
862                If ``allow_empty=True``, no such exception is thrown.
863
864        """
865        self.patterns[GlobGroup(include, exclude=exclude)] = _PatternInfo(
866            _ModuleProviderAction.MOCK, allow_empty
867        )
868
869    def extern(
870        self,
871        include: "GlobPattern",
872        *,
873        exclude: "GlobPattern" = (),
874        allow_empty: bool = True,
875    ):
876        """Include ``module`` in the list of external modules the package can import.
877        This will prevent dependency discovery from saving
878        it in the package. The importer will load an external module directly from the standard import system.
879        Code for extern modules must also exist in the process loading the package.
880
881        Args:
882            include (Union[List[str], str]): A string e.g. ``"my_package.my_subpackage"``, or list of strings
883                for the names of the modules to be externed. This can also be a glob-style pattern, as
884                described in :meth:`mock`.
885
886            exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the
887                include string.
888
889            allow_empty (bool): An optional flag that specifies whether the extern modules specified by this call
890                to the ``extern`` method must be matched to some module during packaging. If an extern module glob
891                pattern is added with ``allow_empty=False``, and :meth:`close` is called (either explicitly or via
892                ``__exit__``) before any modules match that pattern, an exception is thrown. If ``allow_empty=True``,
893                no such exception is thrown.
894
895        """
896        self.patterns[GlobGroup(include, exclude=exclude)] = _PatternInfo(
897            _ModuleProviderAction.EXTERN, allow_empty
898        )
899
900    def deny(self, include: "GlobPattern", *, exclude: "GlobPattern" = ()):
901        """Blocklist modules who names match the given glob patterns from the list of modules the package can import.
902        If a dependency on any matching packages is found, a :class:`PackagingError` is raised.
903
904        Args:
905            include (Union[List[str], str]): A string e.g. ``"my_package.my_subpackage"``, or list of strings
906                for the names of the modules to be externed. This can also be a glob-style pattern, as described in :meth:`mock`.
907
908            exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the include string.
909        """
910        self.patterns[GlobGroup(include, exclude=exclude)] = _PatternInfo(
911            _ModuleProviderAction.DENY, allow_empty=True
912        )
913
    def _persistent_id(self, obj):
        """Return a persistent-id tuple for ``obj``, or ``None`` if it needs no special handling.

        NOTE(review): presumably installed as the pickler's ``persistent_id``
        hook by the code that creates the pickler -- confirm at the call site.

        Three cases are handled:

        * Storages (``torch.is_storage(obj)`` or ``TypedStorage``): the storage
          is written to ``.data/<storage_id>.storage`` the first time it is
          seen, and ``("storage", storage_type, storage_id, location,
          storage_numel)`` is returned.
        * Objects exposing ``__reduce_package__``: the tuple
          ``("reduce_package", id(obj), *obj.__reduce_package__(self))`` is
          computed once, cached in ``self.serialized_reduces`` under
          ``id(obj)``, and returned.
        * Anything else: ``None``.

        Raises:
            RuntimeError: if ``obj`` passes the storage check but is neither a
                ``TypedStorage`` nor an ``UntypedStorage``.
            Exception: if ``obj`` is a ``torch.jit.RecursiveScriptModule`` with
                ``__reduce_package__`` while the module-level
                ``_gate_torchscript_serialization`` flag is truthy.
        """
        if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
            storage: Storage
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, we can
                # remove this case
                untyped_storage = obj._untyped_storage
                storage_type_str = obj.pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage = cast(Storage, untyped_storage)
                # For typed storages the recorded count comes from obj.size().
                storage_numel = obj.size()

            elif isinstance(obj, torch.UntypedStorage):
                untyped_storage = obj
                storage = cast(Storage, untyped_storage)
                storage_type = normalize_storage_type(type(storage))
                # For untyped storages the recorded count is the byte size.
                storage_numel = storage.nbytes()
            else:
                raise RuntimeError(f"storage type not recognized: {type(obj)}")

            # The location tag is taken before any move to CPU below.
            location = location_tag(storage)

            # serialize storage if not already written
            # has_storage() must be queried before get_or_add_storage(), since
            # the latter may register the storage.
            storage_present = self.storage_context.has_storage(storage)
            storage_id = self.storage_context.get_or_add_storage(storage)
            if not storage_present:
                # Non-CPU storages are copied to CPU before being written.
                if storage.device.type != "cpu":
                    storage = storage.cpu()
                num_bytes = storage.nbytes()
                self.zip_file.write_record(
                    f".data/{storage_id}.storage", storage, num_bytes
                )
            return ("storage", storage_type, storage_id, location, storage_numel)

        if hasattr(obj, "__reduce_package__"):
            if _gate_torchscript_serialization and isinstance(
                obj, torch.jit.RecursiveScriptModule
            ):
                raise Exception(  # noqa: TRY002
                    "Serializing ScriptModules directly into a package is a beta feature. "
                    "To use, set global "
                    "`torch.package.package_exporter._gate_torchscript_serialization` to `False`."
                )
            # Call __reduce_package__ at most once per object identity; later
            # lookups reuse the cached tuple.
            if self.serialized_reduces.get(id(obj)) is None:
                self.serialized_reduces[id(obj)] = (
                    "reduce_package",
                    id(obj),
                    *obj.__reduce_package__(self),
                )

            return self.serialized_reduces[id(obj)]

        return None
967
968    def __enter__(self):
969        return self
970
971    def __exit__(self, exc_type, exc_value, traceback):
972        # If __exit__ was called because an exception was raised, we do not
973        # attempt to finalize the package. Instead, control is returned to the
974        # caller to continue raising the exception.
975        if exc_type is not None:
976            # Do the bare minimum to leave the open buffer in a valid state.
977            self._finalize_zip()
978            return
979
980        self.close()
981
982    def _write(self, filename, str_or_bytes):
983        if filename in self._written_files:
984            raise AssertionError(
985                f"Tried to write file '{filename}', but it already exists in this archive. "
986                "Please file a bug."
987            )
988        self._written_files.add(filename)
989
990        if is_mangled(filename):
991            raise AssertionError(
992                f"Tried to save a torch.package'd module as '{filename}'. "
993                "Directly saving torch.package'd modules is not allowed."
994            )
995        if isinstance(str_or_bytes, str):
996            str_or_bytes = str_or_bytes.encode("utf-8")
997        self.zip_file.write_record(filename, str_or_bytes, len(str_or_bytes))
998
999    def _validate_dependency_graph(self):
1000        # 1. Check the graph for any errors inserted during dependency analysis.
1001        for attrs in self.dependency_graph.nodes.values():
1002            if "error" in attrs:
1003                raise PackagingError(self.dependency_graph, debug=self.debug)
1004
1005        # 2. Check that all patterns for which allow_empty=False have been matched at least once.
1006        for pattern, pattern_info in self.patterns.items():
1007            if not pattern_info.allow_empty and not pattern_info.was_matched:
1008                raise EmptyMatchError(
1009                    f"Exporter did not match any modules to {pattern}, which was marked as allow_empty=False"
1010                )
1011
1012    def _write_mock_file(self):
1013        if "_mock.py" not in self._written_files:
1014            mock_file = str(Path(__file__).parent / "_mock.py")
1015            self._write_source_string("_mock", _read_file(mock_file), is_package=False)
1016
1017    def _execute_dependency_graph(self):
1018        """Takes a finalized dependency graph describing how to package all
1019        modules and executes it, writing to the ZIP archive.
1020        """
1021        self._validate_dependency_graph()
1022
1023        extern_modules = []
1024        for module_name, attrs in self.dependency_graph.nodes.items():
1025            action = attrs["action"]
1026
1027            if action == _ModuleProviderAction.EXTERN:
1028                for hook in self._extern_hooks.values():
1029                    hook(self, module_name)
1030
1031                extern_modules.append(module_name)
1032
1033            elif action == _ModuleProviderAction.MOCK:
1034                for hook in self._mock_hooks.values():
1035                    hook(self, module_name)
1036
1037                self._write_mock_file()
1038
1039                is_package = hasattr(self._import_module(module_name), "__path__")
1040                self._write_source_string(module_name, _MOCK_IMPL, is_package)
1041
1042            elif action == _ModuleProviderAction.INTERN:
1043                for hook in self._intern_hooks.values():
1044                    hook(self, module_name)
1045
1046                # The node in the dependency graph contains metadata that tells us
1047                # how to intern the module.
1048                if "provided" not in attrs:
1049                    raise AssertionError(
1050                        f"Module was marked `intern` but not provided: {module_name}"
1051                    )
1052
1053                if attrs.get("is_pickle") is True:
1054                    # This node came from save_pickle, we don't need to write any source for it.
1055                    continue
1056
1057                is_package = attrs["is_package"]
1058                source = attrs["source"]
1059                self._write_source_string(module_name, source, is_package)
1060
1061            elif action == _ModuleProviderAction.REPACKAGED_MOCK_MODULE:
1062                self._write_mock_file()
1063            elif action == _ModuleProviderAction.SKIP:
1064                continue
1065            else:
1066                raise AssertionError(
1067                    f"Invalid action: {module_name}, {action}. Please report a bug to PyTorch."
1068                )
1069
1070        extern_file_contents = "\n".join(extern_modules) + "\n"
1071        self._write(".data/extern_modules", extern_file_contents)
1072
1073    def _write_python_version(self):
1074        """Writes the python version that the package was created with to .data/python_version"""
1075        self._write(".data/python_version", platform.python_version())
1076
1077    def close(self):
1078        """Write the package to the filesystem. Any calls after :meth:`close` are now invalid.
1079        It is preferable to use resource guard syntax instead::
1080
1081            with PackageExporter("file.zip") as e:
1082                ...
1083        """
1084        self._execute_dependency_graph()
1085        self._write_python_version()
1086
1087        self.script_module_serializer.write_files()
1088        self._finalize_zip()
1089
1090    def _finalize_zip(self):
1091        """Called at the very end of packaging to leave the zipfile in a closed but valid state."""
1092        del self.zip_file
1093        if self.buffer:
1094            self.buffer.flush()
1095
1096    def _filename(self, package, resource):
1097        package_path = package.replace(".", "/")
1098        resource = _normalize_path(resource)
1099        return f"{package_path}/{resource}"
1100
1101    def _can_implicitly_extern(self, module_name: str):
1102        top_level_package_name = module_name.partition(".")[0]
1103        return top_level_package_name == "torch" or (
1104            top_level_package_name not in _DISALLOWED_MODULES
1105            and is_stdlib_module(top_level_package_name)
1106        )
1107
1108    def dependency_graph_string(self) -> str:
1109        """Returns digraph string representation of dependencies in package.
1110
1111        Returns:
1112            A string representation of dependencies in package.
1113        """
1114        return self.dependency_graph.to_dot()
1115
1116    def _nodes_with_action_type(
1117        self, action: Optional[_ModuleProviderAction]
1118    ) -> List[str]:
1119        result = []
1120        for name, node_dict in self.dependency_graph.nodes.items():
1121            node_action = node_dict.get("action", None)
1122            if node_action == action and "is_pickle" not in node_dict:
1123                result.append(name)
1124        result.sort()
1125        return result
1126
1127    def externed_modules(self) -> List[str]:
1128        """Return all modules that are currently externed.
1129
1130        Returns:
1131            A list containing the names of modules which will be
1132            externed in this package.
1133        """
1134        return self._nodes_with_action_type(_ModuleProviderAction.EXTERN)
1135
1136    def interned_modules(self) -> List[str]:
1137        """Return all modules that are currently interned.
1138
1139        Returns:
1140            A list containing the names of modules which will be
1141            interned in this package.
1142        """
1143        return self._nodes_with_action_type(_ModuleProviderAction.INTERN)
1144
1145    def mocked_modules(self) -> List[str]:
1146        """Return all modules that are currently mocked.
1147
1148        Returns:
1149            A list containing the names of modules which will be
1150            mocked in this package.
1151        """
1152        return self._nodes_with_action_type(_ModuleProviderAction.MOCK)
1153
1154    def denied_modules(self) -> List[str]:
1155        """Return all modules that are currently denied.
1156
1157        Returns:
1158            A list containing the names of modules which will be
1159            denied in this package.
1160        """
1161        return self._nodes_with_action_type(_ModuleProviderAction.DENY)
1162
1163    def get_rdeps(self, module_name: str) -> List[str]:
1164        """Return a list of all modules which depend on the module ``module_name``.
1165
1166        Returns:
1167            A list containing the names of modules which depend on ``module_name``.
1168        """
1169        if module_name in self.dependency_graph._pred.keys():
1170            return list(self.dependency_graph._pred[module_name].keys())
1171        else:
1172            return []
1173
1174    def all_paths(self, src: str, dst: str) -> str:
1175        """Return a dot representation of the subgraph
1176           that has all paths from src to dst.
1177
1178        Returns:
1179            A dot representation containing all paths from src to dst.
1180            (https://graphviz.org/doc/info/lang.html)
1181        """
1182        return self.dependency_graph.all_paths(src, dst)
1183
1184
# Even though these modules are in the standard library, we do not allow them
# to be automatically externed since they offer a lot of system-level access.
# `_can_implicitly_extern` checks only the top-level package name against this list.
_DISALLOWED_MODULES = ["sys", "io"]

# Source written in place of every mocked module: any attribute access on the
# module returns a `MockedObject` named "<module name>.<attribute>".
_MOCK_IMPL = """\
from _mock import MockedObject
def __getattr__(attr: str):
    return MockedObject(__name__ + '.' + attr, _suppress_err=True)
"""
1194
1195
1196def _read_file(filename: str) -> str:
1197    with open(filename, "rb") as f:
1198        b = f.read()
1199        return b.decode("utf-8")
1200