xref: /aosp_15_r20/external/pytorch/torch/package/_package_unpickler.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import _compat_pickle
3import pickle
4
5from .importer import Importer
6
7
class PackageUnpickler(pickle._Unpickler):  # type: ignore[name-defined]
    """Package-aware unpickler.

    This behaves the same as a normal unpickler, except it uses ``importer`` to
    find any global names that it encounters while unpickling.

    Args:
        importer: Used to import every module named in the pickle stream
            (via ``importer.import_module``).
        *args, **kwargs: Forwarded unchanged to ``pickle._Unpickler``.
    """

    def __init__(self, importer: Importer, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._importer = importer

    def find_class(self, module, name):
        """Resolve the global ``module``/``name`` through ``self._importer``.

        Mirrors ``pickle._Unpickler.find_class``, except that the module is
        obtained from ``self._importer.import_module`` rather than the normal
        import system.

        Args:
            module: Module name recorded in the pickle.
            name: Attribute name recorded in the pickle. For protocol >= 4 this
                may be a dotted qualified name such as ``"Outer.Inner"``.

        Returns:
            The object found at ``name`` inside the imported module.

        Raises:
            AttributeError: If ``name`` cannot be resolved on the module, or
                (protocol >= 4) if its path goes through a ``<locals>`` scope.
            Anything raised by ``self._importer.import_module(module)``.
        """
        # Pickles written with protocol < 3 may carry legacy (Python 2) names;
        # remap them the same way the stdlib unpickler does when fix_imports
        # is enabled.
        if self.proto < 3 and self.fix_imports:  # type: ignore[attr-defined]
            if (module, name) in _compat_pickle.NAME_MAPPING:
                module, name = _compat_pickle.NAME_MAPPING[(module, name)]
            elif module in _compat_pickle.IMPORT_MAPPING:
                module = _compat_pickle.IMPORT_MAPPING[module]
        mod = self._importer.import_module(module)
        if self.proto < 4:  # type: ignore[attr-defined]
            return getattr(mod, name)
        # Protocol >= 4 records __qualname__, so nested classes arrive as dotted
        # paths ("Outer.Inner"). A single getattr would fail on those; walk the
        # path one attribute at a time, as pickle._Unpickler.find_class does.
        obj = mod
        for part in name.split("."):
            if part == "<locals>":
                # Function-local objects are not reachable by attribute lookup.
                raise AttributeError(
                    f"Can't get local attribute {name!r} on {mod!r}"
                )
            obj = getattr(obj, part)
        return obj
28