# xref: /aosp_15_r20/external/pytorch/torch/serialization.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import difflib
import functools
import os
import io
import re
import shutil
import struct
import sys
import torch
import tarfile
import tempfile
import warnings
from contextlib import closing, contextmanager
from enum import Enum
from ._utils import _import_dotted_name
from torch._sources import get_source_lines_and_file
from torch.types import Storage
from torch.storage import _get_dtype_from_pickle_storage_type
from typing import Any, BinaryIO, Callable, cast, Dict, Optional, Type, Tuple, Union, IO, List
from typing_extensions import TypeAlias, TypeGuard  # Python 3.10+
import copyreg
import pickle
import torch._weights_only_unpickler as _weights_only_unpickler

DEFAULT_PROTOCOL = 2

LONG_SIZE = struct.Struct('=l').size
INT_SIZE = struct.Struct('=i').size
SHORT_SIZE = struct.Struct('=h').size

MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
PROTOCOL_VERSION = 1001
STORAGE_KEY_SEPARATOR = ','

FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]]
MAP_LOCATION: TypeAlias = Optional[Union[Callable[[Storage, str], Storage], torch.device, str, Dict[str, str]]]
STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]

IS_WINDOWS = sys.platform == "win32"

if not IS_WINDOWS:
    from mmap import MAP_SHARED, MAP_PRIVATE
else:
    MAP_SHARED, MAP_PRIVATE = None, None  # type: ignore[assignment]

__all__ = [
    'SourceChangeWarning',
    'mkdtemp',
    'register_package',
    'check_module_version_greater_or_equal',
    'validate_cuda_device',
    'validate_hpu_device',
    'location_tag',
    'default_restore_location',
    'normalize_storage_type',
    'storage_to_tensor_type',
    'save',
    'load',
    'StorageType',
    'LoadEndianness',
    'get_default_load_endianness',
    'set_default_load_endianness',
    'clear_safe_globals',
    'get_safe_globals',
    'add_safe_globals',
]


class SourceChangeWarning(Warning):
    pass


@contextmanager
def mkdtemp():
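    '''Context manager that creates a temporary directory and removes it on exit.

    A minimal usage sketch:

    Example:
        >>> # xdoctest: +SKIP("creates and removes a real directory")
        >>> import os
        >>> with torch.serialization.mkdtemp() as path:
        ...     os.path.isdir(path)
        True
    '''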
    path = tempfile.mkdtemp()
    try:
        yield path
    finally:
        shutil.rmtree(path)


_package_registry: List[Tuple[int, Callable[[STORAGE], Optional[str]], Callable[[STORAGE, str], Optional[STORAGE]]]] = []

class LoadEndianness(Enum):
    NATIVE = 1
    LITTLE = 2
    BIG = 3

_default_load_endian: Optional[LoadEndianness] = None

def get_default_load_endianness() -> Optional[LoadEndianness]:
    '''
    Get the fallback byte order for loading files.

    If the byteorder mark is not present in the saved checkpoint,
    this byte order is used as the fallback.
    By default, it's "native" byte order.

    Returns:
        default_load_endian: Optional[LoadEndianness]
    '''
    return _default_load_endian

def set_default_load_endianness(endianness):
    '''
    Set the fallback byte order for loading files.

    If the byteorder mark is not present in the saved checkpoint,
    this byte order is used as the fallback.
    By default, it's "native" byte order.

    Args:
        endianness: the new fallback byte order
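
    Example:
        >>> # A sketch: force big-endian interpretation for a checkpoint that
        >>> # carries no byteorder record. 'legacy_be.pt' is a hypothetical file.
        >>> # xdoctest: +SKIP("undefined filepaths")
        >>> torch.serialization.set_default_load_endianness(
        ...     torch.serialization.LoadEndianness.BIG)
        >>> t = torch.load('legacy_be.pt')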
    '''
    global _default_load_endian
    if not isinstance(endianness, LoadEndianness) and endianness is not None:
        raise TypeError("Invalid argument type in function set_default_load_endianness")
    _default_load_endian = endianness

_default_mmap_options: int = MAP_PRIVATE

def get_default_mmap_options() -> int:
    '''
    Get default mmap options for :func:`torch.load` with ``mmap=True``.

    Defaults to ``mmap.MAP_PRIVATE``.

    Returns:
        default_mmap_options: int
    '''
    return _default_mmap_options

def set_default_mmap_options(flags: int):
    '''
    Set default mmap options for :func:`torch.load` with ``mmap=True`` to flags.

    For now, only ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED`` is supported.
    Please open an issue if you need any other option to be added here.

    .. note::
        This feature is currently not supported for Windows.

    Args:
        flags: ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED``
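
    Example:
        >>> # A sketch: opt in to MAP_SHARED before an mmaped load.
        >>> # 'checkpoint.pt' is a hypothetical zipfile-format checkpoint.
        >>> # xdoctest: +SKIP("undefined filepaths")
        >>> import mmap
        >>> torch.serialization.set_default_mmap_options(mmap.MAP_SHARED)
        >>> sd = torch.load('checkpoint.pt', mmap=True)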
    '''
    global _default_mmap_options
    if IS_WINDOWS:
        raise RuntimeError("Changing the default mmap options is currently not supported for Windows")
    if (flags != MAP_PRIVATE and flags != MAP_SHARED):
        raise ValueError("Invalid argument in function set_default_mmap_options, "
                         f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}")
    _default_mmap_options = flags

def clear_safe_globals() -> None:
    '''
    Clears the list of globals that are safe for ``weights_only`` load.
    '''
    _weights_only_unpickler._clear_safe_globals()

def get_safe_globals() -> List[Any]:
    '''
    Returns the list of user-added globals that are safe for ``weights_only`` load.
    '''
    return _weights_only_unpickler._get_safe_globals()

def add_safe_globals(safe_globals: List[Any]) -> None:
    '''
    Marks the given globals as safe for ``weights_only`` load. For example, functions
    added to this list can be called during unpickling, and classes can be
    instantiated and have their state set.

    Args:
        safe_globals (List[Any]): list of globals to mark as safe

    Example:
        >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
        >>> import tempfile
        >>> class MyTensor(torch.Tensor):
        ...     pass
        >>> t = MyTensor(torch.randn(2, 3))
        >>> with tempfile.NamedTemporaryFile() as f:
        ...     torch.save(t, f.name)
        # Running `torch.load(f.name, weights_only=True)` will fail with
        # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
        # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
        ...     torch.serialization.add_safe_globals([MyTensor])
        ...     torch.load(f.name, weights_only=True)
        # MyTensor([[-0.5024, -1.8152, -0.5455],
        #          [-0.8234,  2.0500, -0.3657]])
    '''
    _weights_only_unpickler._add_safe_globals(safe_globals)

def _is_zipfile(f) -> bool:
    # This is a stricter implementation than zipfile.is_zipfile().
    # zipfile.is_zipfile() is True if the magic number appears anywhere in the
    # binary. Since we expect the files here to be generated by torch.save or
    # torch.jit.save, it's safe to only check the start bytes and avoid
    # collisions and assume the zip has only 1 file.
    # See bugs.python.org/issue28494.

    start = f.tell()
    # Read the first few bytes and match against the ZIP file signature
    local_header_magic_number = b'PK\x03\x04'
    read_bytes = f.read(len(local_header_magic_number))
    f.seek(start)
    return read_bytes == local_header_magic_number


def register_package(
    priority: int,
    tagger: Callable[[STORAGE], Optional[str]],
    deserializer: Callable[[STORAGE, str], Optional[STORAGE]]
):
    '''
    Registers callables for tagging and deserializing storage objects with an associated priority.
    Tagging associates a device with a storage object at save time while deserializing moves a
    storage object to an appropriate device at load time. :attr:`tagger` and :attr:`deserializer`
    are run in the order given by their :attr:`priority` until a tagger/deserializer returns a
    value that is not `None`.

    To override the deserialization behavior for a device in the global registry, one can register a
    tagger with a higher priority than the existing tagger.

    This function can also be used to register a tagger and deserializer for new devices.

    Args:
        priority: Indicates the priority associated with the tagger and deserializer, where a lower
            value indicates higher priority.
        tagger: Callable that takes in a storage object and returns its tagged device as a string
            or None.
        deserializer: Callable that takes in a storage object and a device string and returns a storage
            object on the appropriate device or None.

    Returns:
        `None`

    Example:
        >>> def ipu_tag(obj):
        >>>     if obj.device.type == 'ipu':
        >>>         return 'ipu'
        >>> def ipu_deserialize(obj, location):
        >>>     if location.startswith('ipu'):
        >>>         ipu = getattr(torch, "ipu", None)
        >>>         assert ipu is not None, "IPU device module is not loaded"
        >>>         assert torch.ipu.is_available(), "ipu is not available"
        >>>         return obj.ipu(location)
        >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
    '''
    queue_elem = (priority, tagger, deserializer)
    _package_registry.append(queue_elem)
    _package_registry.sort()


def check_module_version_greater_or_equal(module, req_version_tuple, error_if_malformed=True):
    '''
    Check if a module's version satisfies requirements.

    Usually, a module's version string will be like 'x.y.z', which would be represented
    as a tuple (x, y, z), but sometimes it could be in an unexpected format. If the version
    string does not match the given tuple's format up to the length of the tuple, then
    raise an error or emit a warning.

    Args:
        module: the module to check the version of
        req_version_tuple: tuple (usually of ints) representing the required version
        error_if_malformed: whether to raise an error if the module version string is malformed

    Returns:
        requirement_is_met: bool
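
    Example:
        >>> # A sketch: torch's own version string is well formed, so the
        >>> # comparison succeeds and a plain bool is returned.
        >>> torch.serialization.check_module_version_greater_or_equal(torch, (1, 0))
        True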
    '''
    try:
        version_strs = module.__version__.split('.')
        # Cast module version fields to match the types of the required version
        module_version = tuple(
            type(req_field)(version_strs[idx]) for idx, req_field in enumerate(req_version_tuple)
        )
        requirement_is_met = module_version >= req_version_tuple

    except Exception as e:
        message = (
            f"'{module.__name__}' module version string is malformed '{module.__version__}' and cannot be compared"
            f" with tuple {str(req_version_tuple)}"
        )
        if error_if_malformed:
            raise RuntimeError(message) from e
        else:
            warnings.warn(message + ', but continuing assuming that requirement is met')
            requirement_is_met = True

    return requirement_is_met


def _cpu_tag(obj):
    if obj.device.type == 'cpu':
        return 'cpu'


def _mps_tag(obj):
    if obj.device.type == 'mps':
        return 'mps'


def _meta_tag(obj):
    if obj.device.type == 'meta':
        return 'meta'


def _backend_tag(backend_name, obj):
    if backend_name == 'privateuse1':
        backend_name = torch._C._get_privateuse1_backend_name()
    if obj.device.type == backend_name:
        if obj.device.index is None:
            return backend_name
        else:
            return backend_name + ':' + str(obj.device.index)


def _cpu_deserialize(obj, location):
    if location == 'cpu':
        return obj


def _mps_deserialize(obj, location):
    if location.startswith('mps'):
        return obj.mps()


def _meta_deserialize(obj, location):
    if location == 'meta':
        return torch.UntypedStorage(obj.nbytes(), device='meta')


def _validate_device(location, backend_name):
    '''
    Check whether the device index of the specified backend is valid.

    For the privateuse1 backend, you must first register a device_module for
    privateuse1 using torch._register_device_module. The device_module must
    implement the following methods, as cuda does:
    device_module._utils._get_device_index(location, True) and
    device_module.device_count().

    Args:
        location: string of device
        backend_name: the backend name, or the name of the privateuse1 backend (which can be renamed)

    Returns:
        device: torch.device
    '''
    if not hasattr(torch, backend_name):
        raise RuntimeError(f'The {backend_name.upper()} device module is not registered. '
                           'If you are running on a CPU-only machine, '
                           'please use torch.load with map_location=torch.device(\'cpu\') '
                           'to map your storages to the CPU.')
    device_module = getattr(torch, backend_name)
    if hasattr(device_module, '_utils') and hasattr(device_module._utils, '_get_device_index'):
        device_index = device_module._utils._get_device_index(location, True)
        device = torch.device(backend_name, device_index)
    else:
        device = torch.device(location)
        device_index = device.index if device.index else 0
    if hasattr(device_module, 'is_available') and not device_module.is_available():
        raise RuntimeError(f'Attempting to deserialize object on a {backend_name.upper()} '
                           f'device but torch.{backend_name}.is_available() is False. '
                           'If you are running on a CPU-only machine, '
                           'please use torch.load with map_location=torch.device(\'cpu\') '
                           'to map your storages to the CPU.')
    if hasattr(device_module, 'device_count'):
        device_count = device_module.device_count()
        if device_index >= device_count:
            raise RuntimeError(f'Attempting to deserialize object on {backend_name.upper()} device '
                               f'{device_index} but torch.{backend_name}.device_count() is {device_count}. '
                               'Please use torch.load with map_location to map your storages '
                               'to an existing device.')
    return device


def validate_cuda_device(location):
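    '''Validate a CUDA location tag (e.g. ``'cuda'`` or ``'cuda:0'``) and
    return its device index.'''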
    return _validate_device(location, 'cuda').index


def validate_hpu_device(location):
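    '''Validate an HPU location tag (e.g. ``'hpu'`` or ``'hpu:0'``) and
    return its device index.'''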
    return _validate_device(location, 'hpu').index


def _deserialize(backend_name, obj, location):
    if backend_name == 'privateuse1':
        backend_name = torch._C._get_privateuse1_backend_name()
    if location.startswith(backend_name):
        device = _validate_device(location, backend_name)
        return obj.to(device=device)


register_package(10, _cpu_tag, _cpu_deserialize)
register_package(20, functools.partial(_backend_tag, 'cuda'), functools.partial(_deserialize, 'cuda'))
register_package(21, _mps_tag, _mps_deserialize)
register_package(22, _meta_tag, _meta_deserialize)
register_package(23, functools.partial(_backend_tag, 'privateuse1'), functools.partial(_deserialize, 'privateuse1'))
register_package(24, functools.partial(_backend_tag, 'hpu'), functools.partial(_deserialize, 'hpu'))
register_package(25, functools.partial(_backend_tag, 'xpu'), functools.partial(_deserialize, 'xpu'))

def location_tag(storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]):
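    '''Return the location tag (e.g. ``'cpu'``, ``'cuda:0'``) for ``storage``.

    Registered taggers are tried in priority order; the first non-empty tag wins.
    '''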
    for _, tagger, _ in _package_registry:
        location = tagger(storage)
        if location:
            return location
    raise RuntimeError("don't know how to determine data location of "
                       + torch.typename(storage))


def default_restore_location(storage, location):
    """
    Restores `storage` using a deserializer function registered for the `location`.

    This function looks in the registry for deserializer functions that match the `location`.
    If found, it attempts to use them, in priority order, to restore `storage` until one
    returns a non-`None` result. If no deserializer can be found in the registry, or all of
    the ones found fail to produce a result, it raises a `RuntimeError`.

    Args:
        storage (STORAGE): the storage object to restore
        location (str): the location tag associated with the storage object

    Returns:
        storage: Optional[STORAGE]

    Raises:
        RuntimeError: If no deserializer matching `location` is found in the registry or if
           all matching ones return `None`.
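
    Example:
        >>> # A sketch, assuming a CUDA build with at least one device: a
        >>> # freshly deserialized CPU storage is moved to 'cuda:0'.
        >>> # xdoctest: +SKIP("requires CUDA")
        >>> s = torch.UntypedStorage(16)
        >>> s_cuda = torch.serialization.default_restore_location(s, 'cuda:0')
        >>> s_cuda.device
        device(type='cuda', index=0)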
431    """
432    for _, _, fn in _package_registry:
433        result = fn(storage, location)
434        if result is not None:
435            return result
436    raise RuntimeError("don't know how to restore data location of "
437                       + torch.typename(storage) + " (tagged with "
438                       + location + ")")
439
440
441def normalize_storage_type(storage_type):
442    return getattr(torch, storage_type.__name__)
443
444
445def storage_to_tensor_type(storage):
446    storage_type = type(storage)
447    module = _import_dotted_name(storage_type.__module__)
448    return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))
449
450
def _is_path(name_or_buffer) -> TypeGuard[Union[str, os.PathLike]]:
    return isinstance(name_or_buffer, (str, os.PathLike))


class _opener:
    def __init__(self, file_like):
        self.file_like = file_like

    def __enter__(self):
        return self.file_like

    def __exit__(self, *args):
        pass


class _open_file(_opener):
    def __init__(self, name, mode):
        super().__init__(open(name, mode))

    def __exit__(self, *args):
        self.file_like.close()


class _open_buffer_reader(_opener):
    def __init__(self, buffer):
        super().__init__(buffer)
        _check_seekable(buffer)


class _open_buffer_writer(_opener):
    def __exit__(self, *args):
        self.file_like.flush()


def _open_file_like(name_or_buffer, mode):
    if _is_path(name_or_buffer):
        return _open_file(name_or_buffer, mode)
    else:
        if 'w' in mode:
            return _open_buffer_writer(name_or_buffer)
        elif 'r' in mode:
            return _open_buffer_reader(name_or_buffer)
        else:
            raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")


class _open_zipfile_reader(_opener):
    def __init__(self, name_or_buffer) -> None:
        super().__init__(torch._C.PyTorchFileReader(name_or_buffer))


class _open_zipfile_writer_file(_opener):
    def __init__(self, name) -> None:
        self.file_stream = None
        self.name = str(name)
        try:
            self.name.encode('ascii')
        except UnicodeEncodeError:
            # PyTorchFileWriter only supports ascii filename.
            # For filenames with non-ascii characters, we rely on Python
            # for writing out the file.
            self.file_stream = io.FileIO(self.name, mode='w')
            super().__init__(torch._C.PyTorchFileWriter(self.file_stream))
        else:
            super().__init__(torch._C.PyTorchFileWriter(self.name))

    def __exit__(self, *args) -> None:
        self.file_like.write_end_of_file()
        if self.file_stream is not None:
            self.file_stream.close()


class _open_zipfile_writer_buffer(_opener):
    def __init__(self, buffer) -> None:
        if not callable(getattr(buffer, "write", None)):
            msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'"
            if not hasattr(buffer, "write"):
                raise AttributeError(msg)
            raise TypeError(msg)
        self.buffer = buffer
        super().__init__(torch._C.PyTorchFileWriter(buffer))

    def __exit__(self, *args) -> None:
        self.file_like.write_end_of_file()
        self.buffer.flush()


def _open_zipfile_writer(name_or_buffer):
    container: Type[_opener]
    if _is_path(name_or_buffer):
        container = _open_zipfile_writer_file
    else:
        container = _open_zipfile_writer_buffer
    return container(name_or_buffer)


def _is_compressed_file(f) -> bool:
    compress_modules = ['gzip']
    try:
        return f.__module__ in compress_modules
    except AttributeError:
        return False


def _should_read_directly(f):
    """
    Checks if f is a file that should be read directly. It should be read
    directly if it is backed by a real file (has a fileno) and is not a
    compressed file (e.g. gzip)
    """
    if _is_compressed_file(f):
        return False
    try:
        return f.fileno() >= 0
    except io.UnsupportedOperation:
        return False
    except AttributeError:
        return False


def _check_seekable(f) -> bool:

    def raise_err_msg(patterns, e):
        for p in patterns:
            if p in str(e):
                msg = (str(e) + ". You can only torch.load from a file that is seekable."
                                + " Please pre-load the data into a buffer like io.BytesIO and"
                                + " try to load from it instead.")
                raise type(e)(msg)
        raise e

    try:
        f.seek(f.tell())
        return True
    except (io.UnsupportedOperation, AttributeError) as e:
        raise_err_msg(["seek", "tell"], e)
    return False


def _check_dill_version(pickle_module) -> None:
    '''Checks if using dill as the pickle module, and if so, checks if it is the correct version.
    If dill version is lower than 0.3.1, a ValueError is raised.

    Args:
        pickle_module: module used for pickling metadata and objects

    '''
    if pickle_module is not None and pickle_module.__name__ == 'dill':
        required_dill_version = (0, 3, 1)
        if not check_module_version_greater_or_equal(pickle_module, required_dill_version, False):
            raise ValueError((
                "'torch' supports dill >= {}, but you have dill {}."
                " Please upgrade dill or switch to 'pickle'"
            ).format(
                '.'.join([str(num) for num in required_dill_version]),
                pickle_module.__version__
            ))


def _check_save_filelike(f):
    if not _is_path(f) and not hasattr(f, 'write'):
        raise AttributeError(
            "expected 'f' to be string, path, or a file-like object with "
            "a 'write' attribute")


def save(
    obj: object,
    f: FILE_LIKE,
    pickle_module: Any = pickle,
    pickle_protocol: int = DEFAULT_PROTOCOL,
    _use_new_zipfile_serialization: bool = True,
    _disable_byteorder_record: bool = False
) -> None:
    # Reference: https://github.com/pytorch/pytorch/issues/54354
    # The first line of this docstring overrides the one Sphinx generates for the
    # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
    # the build environment (e.g. `<module 'pickle' from '/leaked/path'>`).

630    """save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)
631
632    Saves an object to a disk file.
633
634    See also: :ref:`saving-loading-tensors`
635
636    Args:
637        obj: saved object
638        f: a file-like object (has to implement write and flush) or a string or
639           os.PathLike object containing a file name
640        pickle_module: module used for pickling metadata and objects
641        pickle_protocol: can be specified to override the default protocol
642
643    .. note::
644        A common PyTorch convention is to save tensors using .pt file extension.
645
646    .. note::
647        PyTorch preserves storage sharing across serialization. See
648        :ref:`preserve-storage-sharing` for more details.
649
650    .. note::
651        The 1.6 release of PyTorch switched ``torch.save`` to use a new
652        zipfile-based file format. ``torch.load`` still retains the ability to
653        load files in the old format. If for any reason you want ``torch.save``
654        to use the old format, pass the kwarg ``_use_new_zipfile_serialization=False``.
655
656    Example:
657        >>> # xdoctest: +SKIP("makes cwd dirty")
658        >>> # Save to file
659        >>> x = torch.tensor([0, 1, 2, 3, 4])
660        >>> torch.save(x, 'tensor.pt')
661        >>> # Save to io.BytesIO buffer
662        >>> buffer = io.BytesIO()
663        >>> torch.save(x, buffer)
664    """
665    torch._C._log_api_usage_once("torch.save")
666    _check_dill_version(pickle_module)
667    _check_save_filelike(f)
668
669    if _use_new_zipfile_serialization:
670        with _open_zipfile_writer(f) as opened_zipfile:
671            _save(obj, opened_zipfile, pickle_module, pickle_protocol, _disable_byteorder_record)
672            return
673    else:
674        with _open_file_like(f, 'wb') as opened_file:
675            _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
676
677
def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
    import torch.nn as nn
    serialized_container_types = {}
    serialized_storages = {}

    # Since loading storages that view the same data with different dtypes is
    # not supported, we need to keep track of the dtype associated with each
    # storage data_ptr and throw an error if the dtype is ever different.
    # TODO: This feature could be added in the future
    storage_dtypes: Dict[int, torch.dtype] = {}

    def persistent_id(obj: Any) -> Optional[Tuple]:
        # FIXME: the docs say that persistent_id should only return a string,
        # but torch returns tuples here. This works only in the binary protocol;
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            if obj in serialized_container_types:
                return None
            serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_lines, _, source_file = get_source_lines_and_file(obj)
                source = ''.join(source_lines)
            except Exception:  # saving the source is optional, so we can ignore any errors
                warnings.warn("Couldn't retrieve source code for container of "
                              "type " + obj.__name__ + ". It won't be checked "
                              "for correctness upon loading.")
            return ('module', obj, source_file, source)

        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
            storage: torch.UntypedStorage

            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type_str = obj._pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                dtype = obj.dtype
                storage_numel = obj._size()

            elif isinstance(obj, torch.UntypedStorage):
                storage = obj
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                dtype = torch.uint8
                storage_numel = storage.nbytes()
            else:
                raise TypeError(f'type not recognized: {type(obj)}')

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if storage.data_ptr() != 0:
                if storage.data_ptr() in storage_dtypes:
                    if storage_dtype != storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            'Cannot save multiple tensors or storages that '
                            'view the same data as different types')
                else:
                    storage_dtypes[storage.data_ptr()] = storage_dtype

            view_metadata: Optional[Tuple[str, int, int]]

            # Offset is always 0, but we keep it for backwards compatibility
            # with the old serialization format (which supported storage views)
            offset = 0
            storage_key = str(storage._cdata)
            location = location_tag(storage)

            # TODO: There's an issue here with FC. It might be impossible to
            # solve, but it's worth noting. Imagine we save a list `[storage,
            # tensor]`, where `tensor.storage()` is the same as `storage`, and
            # `tensor.element_size() > 1`. Let's say that `tensor.dtype ==
            # torch.float`.  The storage will be serialized with element size
            # of 1, since we're choosing to serialize the first occurrence of
            # a duplicate storage. Since this legacy serialization format saves
            # the numel of the storage, rather than nbytes directly, we'll be
            # effectively saving nbytes in this case.  We'll be able to load it
            # and the tensor back up with no problems in _this_ and future
            # versions of pytorch, but in older versions, here's the problem:
            # the storage will be loaded up as a UntypedStorage, and then the
            # FloatTensor will be loaded and the UntypedStorage will be assigned
            # to it. Since the storage dtype does not match the tensor dtype,
            # this will cause an error.  If we reverse the list, like `[tensor,
            # storage]`, then we will save the `tensor.storage()` as a faked
            # `FloatStorage`, and the saved size will be the correct
            # dtype-specific numel count that old versions expect. `tensor`
            # will be able to load up properly in old versions, pointing to
            # a FloatStorage. However, `storage` is still being translated to
            # a UntypedStorage, and it will try to resolve to the same
            # FloatStorage that `tensor` contains. This will also cause an
            # error. It doesn't seem like there's any way around this.
            # Probably, we just cannot maintain FC for the legacy format if the
            # saved list contains both a tensor and a storage that point to the
            # same data.  We should still be able to maintain FC for lists of
            # just tensors, as long as all views share the same dtype as the
            # tensor they are viewing.

            if storage_key not in serialized_storages:
                serialized_storages[storage_key] = (storage, dtype)
            # Storage views are no longer created on save, so this comparison
            # is intentionally always False and view_metadata is always None.
            is_view = storage._cdata != storage._cdata
            if is_view:
                view_metadata = (str(storage._cdata), offset, storage.nbytes())
            else:
                view_metadata = None

            res = ('storage',
                   storage_type,
                   storage_key,
                   location,
                   storage_numel,
                   view_metadata)
            return res
        return None

    sys_info = dict(
        protocol_version=PROTOCOL_VERSION,
        little_endian=sys.byteorder == 'little',
        type_sizes=dict(
            short=SHORT_SIZE,
            int=INT_SIZE,
            long=LONG_SIZE,
        ),
    )

    pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
    pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
    pickle_module.dump(sys_info, f, protocol=pickle_protocol)
    pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)

    serialized_storage_keys = sorted(serialized_storages.keys())
    pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
    f.flush()
    for key in serialized_storage_keys:
        storage, dtype = serialized_storages[key]
        storage._write_file(f, _should_read_directly(f), True, torch._utils._element_size(dtype))


def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record):
    serialized_storages = {}
    id_map: Dict[int, str] = {}

    # Since loading storages that view the same data with different dtypes is
    # not supported, we need to keep track of the dtype associated with each
    # storage data_ptr and throw an error if the dtype is ever different.
    # TODO: This feature could be added in the future
    storage_dtypes: Dict[int, torch.dtype] = {}

    def persistent_id(obj):
        # FIXME: the docs say that persistent_id should only return a string,
        # but torch returns tuples here. This works only in the binary protocol;
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):

            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type_str = obj._pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage_numel = obj._size()

            else:
                storage = obj
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                storage_numel = storage.nbytes()

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if storage.data_ptr() != 0:
                if storage.data_ptr() in storage_dtypes:
                    if storage_dtype != storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            'Cannot save multiple tensors or storages that '
                            'view the same data as different types')
                else:
                    storage_dtypes[storage.data_ptr()] = storage_dtype

            storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
            location = location_tag(storage)
            serialized_storages[storage_key] = storage

            return ('storage',
                    storage_type,
                    storage_key,
                    location,
                    storage_numel)

        return None

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()
    zip_file.write_record('data.pkl', data_value, len(data_value))

    # Write byte order marker
    if not _disable_byteorder_record:
        if sys.byteorder not in ['little', 'big']:
            raise ValueError('Unknown endianness type: ' + sys.byteorder)

        zip_file.write_record('byteorder', sys.byteorder, len(sys.byteorder))

    # Write each tensor to a file named tensor/the_tensor_key in the zip archive
    for key in sorted(serialized_storages.keys()):
        name = f'data/{key}'
        storage = serialized_storages[key]
        # Given that we copy things around anyway, we might as well use
        # storage.cpu(); this means that to get tensors serialized, you need
        # to implement .cpu() on the underlying Storage
        if storage.device.type != 'cpu':
            storage = storage.cpu()
        # Now that it is on the CPU we can directly copy it into the zip file
        num_bytes = storage.nbytes()
        zip_file.write_record(name, storage, num_bytes)


def load(
    f: FILE_LIKE,
    map_location: MAP_LOCATION = None,
    pickle_module: Any = None,
    *,
    weights_only: Optional[bool] = None,
    mmap: Optional[bool] = None,
    **pickle_load_args: Any
) -> Any:
    # Reference: https://github.com/pytorch/pytorch/issues/54354
    # The first line of this docstring overrides the one Sphinx generates for the
    # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
    # the build environment (e.g. `<module 'pickle' from '/leaked/path'>`).

922    """load(f, map_location=None, pickle_module=pickle, *, weights_only=False, mmap=None, **pickle_load_args)
923
924    Loads an object saved with :func:`torch.save` from a file.
925
926    :func:`torch.load` uses Python's unpickling facilities but treats storages,
927    which underlie tensors, specially. They are first deserialized on the
928    CPU and are then moved to the device they were saved from. If this fails
929    (e.g. because the run time system doesn't have certain devices), an exception
930    is raised. However, storages can be dynamically remapped to an alternative
931    set of devices using the :attr:`map_location` argument.
932
933    If :attr:`map_location` is a callable, it will be called once for each serialized
934    storage with two arguments: storage and location. The storage argument
935    will be the initial deserialization of the storage, residing on the CPU.
936    Each serialized storage has a location tag associated with it which
937    identifies the device it was saved from, and this tag is the second
938    argument passed to :attr:`map_location`. The builtin location tags are ``'cpu'``
939    for CPU tensors and ``'cuda:device_id'`` (e.g. ``'cuda:2'``) for CUDA tensors.
940    :attr:`map_location` should return either ``None`` or a storage. If
941    :attr:`map_location` returns a storage, it will be used as the final deserialized
942    object, already moved to the right device. Otherwise, :func:`torch.load` will
943    fall back to the default behavior, as if :attr:`map_location` wasn't specified.
944
945    If :attr:`map_location` is a :class:`torch.device` object or a string containing
946    a device tag, it indicates the location where all tensors should be loaded.
947
948    Otherwise, if :attr:`map_location` is a dict, it will be used to remap location tags
949    appearing in the file (keys), to ones that specify where to put the
950    storages (values).
951
952    User extensions can register their own location tags and tagging and
953    deserialization methods using :func:`torch.serialization.register_package`.
954
955    Args:
956        f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
957            or a string or os.PathLike object containing a file name
958        map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
959            locations
960        pickle_module: module used for unpickling metadata and objects (has to
961            match the :attr:`pickle_module` used to serialize file)
962        weights_only: Indicates whether unpickler should be restricted to
963            loading only tensors, primitive types, dictionaries
964            and any types added via :func:`torch.serialization.add_safe_globals`.
965        mmap: Indicates whether the file should be mmaped rather than loading all the storages into memory.
966            Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they
967            are moved to the location that they were tagged with when saving, or specified by ``map_location``. This
968            second step is a no-op if the final location is CPU. When the ``mmap`` flag is set, instead of copying the
969            tensor storages from disk to CPU memory in the first step, ``f`` is mmaped.
970        pickle_load_args: (Python 3 only) optional keyword arguments passed over to
971            :func:`pickle_module.load` and :func:`pickle_module.Unpickler`, e.g.,
972            :attr:`errors=...`.
973
    .. warning::
        Unless the :attr:`weights_only` parameter is set to `True`,
        :func:`torch.load()` implicitly uses the ``pickle`` module, which is known
        to be insecure. It is possible to construct malicious pickle data which will
        execute arbitrary code during unpickling. Never load data that could have come
        from an untrusted source in an unsafe mode, or that could have been tampered with.
        **Only load data you trust**.

    .. note::
        When you call :func:`torch.load()` on a file which contains GPU tensors, those tensors
        will be loaded to GPU by default. You can call ``torch.load(.., map_location='cpu')``
        and then :meth:`load_state_dict` to avoid GPU RAM surge when loading a model checkpoint.

    .. note::
        By default, we decode byte strings as ``utf-8``.  This is to avoid a common error
        case ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...``
        when loading files saved by Python 2 in Python 3.  If this default
        is incorrect, you may use an extra :attr:`encoding` keyword argument to specify how
        these objects should be loaded, e.g., :attr:`encoding='latin1'` decodes them
        to strings using ``latin1`` encoding, and :attr:`encoding='bytes'` keeps them
        as byte arrays which can be decoded later with ``byte_array.decode(...)``.

    Example:
        >>> # xdoctest: +SKIP("undefined filepaths")
        >>> torch.load('tensors.pt', weights_only=True)
        # Load all tensors onto the CPU
        >>> torch.load('tensors.pt', map_location=torch.device('cpu'), weights_only=True)
        # Load all tensors onto the CPU, using a function
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage, weights_only=True)
        # Load all tensors onto GPU 1
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1), weights_only=True)
        # Map tensors from GPU 1 to GPU 0
        >>> torch.load('tensors.pt', map_location={'cuda:1': 'cuda:0'}, weights_only=True)
        # Load tensor from io.BytesIO object
        # Loading from a buffer with weights_only=False; warning: this can be unsafe
        >>> with open('tensor.pt', 'rb') as f:
        ...     buffer = io.BytesIO(f.read())
        >>> torch.load(buffer, weights_only=False)
        # Load a module with 'ascii' encoding for unpickling
        # Loading a module with weights_only=False; warning: this can be unsafe
        >>> torch.load('module.pt', encoding='ascii', weights_only=False)
    """
    torch._C._log_api_usage_once("torch.load")
    UNSAFE_MESSAGE = (
        "Re-running `torch.load` with `weights_only` set to `False` will likely succeed, "
        "but it can result in arbitrary code execution. Do it only if you got the file from a "
        "trusted source."
    )
    DOCS_MESSAGE = (
        "\n\nCheck the documentation of torch.load to learn more about types accepted by default with "
        "weights_only https://pytorch.org/docs/stable/generated/torch.load.html."
    )

    def _get_wo_message(message: str) -> str:
        pattern = r"GLOBAL (\S+) was not an allowed global by default."
        has_unsafe_global = re.search(pattern, message) is not None
        if has_unsafe_global:
            updated_message = (
                "Weights only load failed. This file can still be loaded, to do so you have two options "
                f"\n\t(1) {UNSAFE_MESSAGE}\n\t(2) Alternatively, to load with `weights_only=True` please check "
                "the recommended steps in the following error message.\n\tWeightsUnpickler error: "
                + message
            )
        else:
            updated_message = (
                f"Weights only load failed. {UNSAFE_MESSAGE}\n Please file an issue with the following "
                "so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler "
                "error: " + message
            )
        return updated_message + DOCS_MESSAGE

    if weights_only is None:
        weights_only, warn_weights_only = False, True
    else:
        warn_weights_only = False

    # Add ability to force safe only weight loads via environment variable
    if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']:
        weights_only = True

    if weights_only:
        if pickle_module is not None:
            raise RuntimeError("Can not safely load weights when explicit pickle_module is specified")
    else:
        if pickle_module is None:
            if warn_weights_only:
                warnings.warn(
                    "You are using `torch.load` with `weights_only=False` (the current default value), which uses "
                    "the default pickle module implicitly. It is possible to construct malicious pickle data "
                    "which will execute arbitrary code during unpickling (See "
                    "https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). "
                    "In a future release, the default value for `weights_only` will be flipped to `True`. This "
                    "limits the functions that could be executed during unpickling. Arbitrary objects will no "
                    "longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the "
                    "user via `torch.serialization.add_safe_globals`. We recommend you start setting "
                    "`weights_only=True` for any use case where you don't have full control of the loaded file. "
                    "Please open an issue on GitHub for any issues related to this experimental feature.",
                    FutureWarning,
                    stacklevel=2,
                )
            pickle_module = pickle

    # make flipping default BC-compatible
    if mmap is None:
        mmap = False

    _check_dill_version(pickle_module)

    if 'encoding' not in pickle_load_args.keys():
        pickle_load_args['encoding'] = 'utf-8'

    with _open_file_like(f, 'rb') as opened_file:
        if _is_zipfile(opened_file):
            # The zipfile reader is going to advance the current file position.
            # If we want to actually tail call to torch.jit.load, we need to
            # reset back to the original position.
            orig_position = opened_file.tell()
            overall_storage = None
            with _open_zipfile_reader(opened_file) as opened_zipfile:
                if _is_torchscript_zip(opened_zipfile):
                    warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"
                                  " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
                                  " silence this warning)", UserWarning)
                    opened_file.seek(orig_position)
                    return torch.jit.load(opened_file, map_location=map_location)
                if mmap:
                    if not _is_path(f):
                        raise ValueError("f must be a file path in order to use the mmap argument")
                    size = os.path.getsize(f)
                    if not IS_WINDOWS:
                        shared = get_default_mmap_options() == MAP_SHARED
                    else:
                        shared = False
                    overall_storage = torch.UntypedStorage.from_file(os.fspath(f), shared, size)
                if weights_only:
                    try:
                        return _load(opened_zipfile,
                                     map_location,
                                     _weights_only_unpickler,
                                     overall_storage=overall_storage,
                                     **pickle_load_args)
                    except RuntimeError as e:
                        raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
                return _load(
                    opened_zipfile,
                    map_location,
                    pickle_module,
                    overall_storage=overall_storage,
                    **pickle_load_args,
                )
        if mmap:
            f_name = "" if not isinstance(f, str) else f"{f}, "
            raise RuntimeError("mmap can only be used with files saved with "
                               f"`torch.save({f_name}_use_new_zipfile_serialization=True)`; "
                               "please torch.save your checkpoint with this option in order to use mmap.")
        if weights_only:
            try:
                return _legacy_load(opened_file, map_location, _weights_only_unpickler, **pickle_load_args)
            except RuntimeError as e:
                raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
        return _legacy_load(
            opened_file, map_location, pickle_module, **pickle_load_args
        )


# Register pickling support for layout instances such as
# torch.sparse_coo, etc
def _get_layout(name):
    """Get layout extension object from its string representation.
1142    """
1143    cache = _get_layout.cache   # type: ignore[attr-defined]
1144    if not cache:
1145        for v in torch.__dict__.values():
1146            if isinstance(v, torch.layout):
1147                cache[str(v)] = v
1148    return cache[name]
1149
1150# There are yet not good way to type annotate function attributes https://github.com/python/mypy/issues/2087
1151_get_layout.cache = {}   # type: ignore[attr-defined]
1152copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
1153
1154
1155def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
1156    deserialized_objects: Dict[int, Any] = {}
1157
1158    restore_location = _get_restore_location(map_location)
1159
1160    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]
1161
1162        def find_class(self, mod_name, name):
1163            if type(name) is str and 'Storage' in name:
1164                try:
1165                    return StorageType(name)
1166                except KeyError:
1167                    pass
1168            return super().find_class(mod_name, name)
1169
1170    def _check_container_source(container_type, source_file, original_source):
1171        try:
1172            current_source = ''.join(get_source_lines_and_file(container_type)[0])
1173        except Exception:  # saving the source is optional, so we can ignore any errors
1174            warnings.warn("Couldn't retrieve source code for container of "
1175                          "type " + container_type.__name__ + ". It won't be checked "
1176                          "for correctness upon loading.")
1177            return
1178        if original_source != current_source:
1179            if container_type.dump_patches:
1180                file_name = container_type.__name__ + '.patch'
1181                diff = difflib.unified_diff(current_source.split('\n'),
1182                                            original_source.split('\n'),
1183                                            source_file,
1184                                            source_file, lineterm="")
1185                lines = '\n'.join(diff)
1186                try:
1187                    with open(file_name, 'a+') as f:
1188                        file_size = f.seek(0, 2)
1189                        f.seek(0)
1190                        if file_size == 0:
1191                            f.write(lines)
1192                        elif file_size != len(lines) or f.read() != lines:
1193                            raise OSError
1194                    msg = ("Saved a reverse patch to " + file_name + ". "
1195                           "Run `patch -p0 < " + file_name + "` to revert your "
1196                           "changes.")
1197                except OSError:
1198                    msg = ("Tried to save a patch, but couldn't create a "
1199                           "writable file " + file_name + ". Make sure it "
1200                           "doesn't exist and your working directory is "
1201                           "writable.")
1202            else:
1203                msg = ("you can retrieve the original source code by "
1204                       "accessing the object's source attribute or set "
1205                       "`torch.nn.Module.dump_patches = True` and use the "
1206                       "patch tool to revert the changes.")
1207            msg = f"source code of class '{torch.typename(container_type)}' has changed. {msg}"
1208            warnings.warn(msg, SourceChangeWarning)
1209
1210    def legacy_load(f):
1211        deserialized_objects: Dict[int, Any] = {}
1212
1213        def persistent_load(saved_id):
1214            if isinstance(saved_id, tuple):
1215                # Ignore containers that don't have any sources saved
1216                if all(saved_id[1:]):
1217                    _check_container_source(*saved_id)
1218                return saved_id[0]
1219            return deserialized_objects[int(saved_id)]
1220
1221        with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
1222                mkdtemp() as tmpdir:
1223
1224            tar.extract('storages', path=tmpdir)
1225            with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
1226                num_storages = pickle_module.load(f, **pickle_load_args)
1227                for i in range(num_storages):
1228                    args = pickle_module.load(f, **pickle_load_args)
1229                    key, location, storage_type = args
1230                    dtype = storage_type._dtype
1231                    obj = cast(Storage, torch.UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype))
1232                    obj = restore_location(obj, location)
1233                    # TODO: Once we decide to break serialization FC, we can
1234                    # stop wrapping with TypedStorage
1235                    deserialized_objects[key] = torch.storage.TypedStorage(
1236                        wrap_storage=obj,
1237                        dtype=dtype,
1238                        _internal=True)
1239
1240                storage_views = pickle_module.load(f, **pickle_load_args)
1241                for target_cdata, root_cdata, offset, numel in storage_views:
1242                    root = deserialized_objects[root_cdata]
1243                    element_size = torch._utils._element_size(root.dtype)
1244                    offset_bytes = offset * element_size
1245                    # TODO: Once we decide to break serialization FC, we can
1246                    # stop wrapping with TypedStorage
1247                    deserialized_objects[target_cdata] = torch.storage.TypedStorage(
1248                        wrap_storage=root._untyped_storage[offset_bytes:offset_bytes + numel * element_size],
1249                        dtype=root.dtype,
1250                        _internal=True)
1251
1252            tar.extract('tensors', path=tmpdir)
1253            with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
1254                num_tensors = pickle_module.load(f, **pickle_load_args)
1255                for _ in range(num_tensors):
1256                    args = pickle_module.load(f, **pickle_load_args)
1257                    key, storage_id, original_tensor_type = args
1258                    storage = deserialized_objects[storage_id]
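                    # Each pickled header is followed by a packed binary
                    # record: an int32 ndim, ndim little-endian int64 sizes,
                    # ndim int64 strides, and an int64 storage offset.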
1259                    ndim, = struct.unpack('<i', f.read(4))
1260                    # skip next 4 bytes; legacy encoding treated ndim as 8 bytes
1261                    f.read(4)
                    size = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
                    stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
                    storage_offset, = struct.unpack('<q', f.read(8))
                    tensor = torch.empty((0,), dtype=storage.dtype).set_(
                        storage._untyped_storage, storage_offset, size, stride)
1267                    deserialized_objects[key] = tensor
1268
1269            pickle_file = tar.extractfile('pickle')
1270            unpickler = UnpicklerWrapper(pickle_file, **pickle_load_args)
1271            unpickler.persistent_load = persistent_load
1272            result = unpickler.load()
1273            return result
1274
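    # Fall through to the newer non-tar legacy format: a stream of pickled
    # records (magic number, protocol version, sys info, the object itself,
    # a list of storage keys) followed by the raw storage bytes.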
1275    deserialized_objects = {}
1276
1277    def persistent_load(saved_id):
1278        assert isinstance(saved_id, tuple)
1279        typename = _maybe_decode_ascii(saved_id[0])
1280        data = saved_id[1:]
1281
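        # Two persistent-id kinds appear in this format: 'module' records,
        # which only trigger a source check, and 'storage' records, which
        # carry storage metadata and optional view information.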
1282        if typename == 'module':
1283            # Ignore containers that don't have any sources saved
1284            if all(data[1:]):
1285                _check_container_source(*data)
1286            return data[0]
1287        elif typename == 'storage':
1288            storage_type, root_key, location, numel, view_metadata = data
1289            location = _maybe_decode_ascii(location)
1290            dtype = storage_type.dtype
1291
1292            nbytes = numel * torch._utils._element_size(dtype)
1293
1294            if root_key not in deserialized_objects:
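                # Under an active fake mode, materialize the storage on the
                # meta device so no real memory is allocated.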
1295                if torch._guards.active_fake_mode() is not None:
1296                    obj = cast(Storage, torch.UntypedStorage(nbytes, device='meta'))
1297                else:
1298                    obj = cast(Storage, torch.UntypedStorage(nbytes))
1299                    obj._torch_load_uninitialized = True
1300                    obj = restore_location(obj, location)
1301                # TODO: Once we decide to break serialization FC, we can
1302                # stop wrapping with TypedStorage
1303                typed_storage = torch.storage.TypedStorage(
1304                    wrap_storage=obj,
1305                    dtype=dtype,
1306                    _internal=True)
1307                deserialized_objects[root_key] = typed_storage
1308            else:
1309                typed_storage = deserialized_objects[root_key]
1310                if typed_storage._data_ptr() == 0:
1311                    typed_storage = torch.storage.TypedStorage(
1312                        device=typed_storage._untyped_storage.device,
1313                        dtype=dtype,
1314                        _internal=True)
1315
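            # A view record aliases a byte range of its root storage instead
            # of allocating new memory.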
1316            if view_metadata is not None:
1317                view_key, offset, view_size = view_metadata
1318                offset_bytes = offset * torch._utils._element_size(dtype)
1319                view_size_bytes = view_size * torch._utils._element_size(dtype)
1320                if view_key not in deserialized_objects:
1321                    # TODO: Once we decide to break serialization FC, we can
1322                    # stop wrapping with TypedStorage
1323                    deserialized_objects[view_key] = torch.storage.TypedStorage(
1324                        wrap_storage=typed_storage._untyped_storage[offset_bytes:offset_bytes + view_size_bytes],
1325                        dtype=dtype,
1326                        _internal=True)
1327                res = deserialized_objects[view_key]
1328
1329            else:
1330                res = typed_storage
1331            return res
1332        else:
1333            raise RuntimeError(f"Unknown saved id type: {saved_id[0]}")
1334
1335    _check_seekable(f)
1336    f_should_read_directly = _should_read_directly(f)
1337
1338    if f_should_read_directly and f.tell() == 0:
        # legacy_load requires that f has fileno();
        # the legacy tar-file loader can only be attempted when the offset is zero
1341        try:
1342            return legacy_load(f)
1343        except tarfile.TarError:
1344            if _is_zipfile(f):
                # .zip is used for torch.jit.save and will raise an unpickling error here
1346                raise RuntimeError(
1347                    f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)") from None
1348            # if not a tarfile, reset file offset and proceed
1349            f.seek(0)
1350
1351    if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2):
1352        raise RuntimeError(
1353            "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
1354            f'Received object of type "{type(f)}". Please update to Python 3.8.2 or newer to restore this '
1355            "functionality.")
1356
1357    magic_number = pickle_module.load(f, **pickle_load_args)
1358    if magic_number != MAGIC_NUMBER:
1359        raise RuntimeError("Invalid magic number; corrupt file?")
1360    protocol_version = pickle_module.load(f, **pickle_load_args)
1361    if protocol_version != PROTOCOL_VERSION:
1362        raise RuntimeError(f"Invalid protocol version: {protocol_version}")
1363
1364    _sys_info = pickle_module.load(f, **pickle_load_args)
1365    unpickler = UnpicklerWrapper(f, **pickle_load_args)
1366    unpickler.persistent_load = persistent_load
1367    result = unpickler.load()
1368
1369    deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)
1370
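    # Under fake mode the storages were created on the meta device, so there
    # are no real bytes to populate; skip reading storage data from the file.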
1371    if torch._guards.active_fake_mode() is None:
1372        offset = f.tell() if f_should_read_directly else None
1373        for key in deserialized_storage_keys:
1374            assert key in deserialized_objects
1375            typed_storage = deserialized_objects[key]
1376            typed_storage._untyped_storage._set_from_file(
1377                f, offset, f_should_read_directly,
1378                torch._utils._element_size(typed_storage.dtype))
1379            if offset is not None:
1380                offset = f.tell()
1381
1382    torch._utils._validate_loaded_sparse_tensors()
1383
1384    return result
1385
1386
1387def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
    # When using encoding='bytes' in Py3, some **internal** keys stored as
    # strings in Py2 are loaded as bytes. This function decodes them with
    # the ascii codec, which is safe since these internal keys only contain
    # ASCII characters.
    #
    # NOTE: This should only be used on internal keys (e.g., `typename` and
    #       `location` in `persistent_load` below)!
1394    if isinstance(bytes_str, bytes):
1395        return bytes_str.decode('ascii')
1396    return bytes_str
1397
1398
1399def _get_restore_location(map_location):
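    """Normalize ``map_location`` into a callable ``restore_location(storage, location)``.

    ``map_location`` may be ``None``, a dict remapping location tags, a device
    string or bytes, a ``torch.device``, or a callable; callables may return
    ``None`` to fall back to the default restore behavior.

    Illustrative (hypothetical) usage::

        restore = _get_restore_location({'cuda:0': 'cpu'})
        # storages tagged 'cuda:0' will now be restored to CPU
    """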
1400    if map_location is None:
1401        restore_location = default_restore_location
1402    elif isinstance(map_location, dict):
1403        def restore_location(storage, location):
1404            location = map_location.get(location, location)
1405            return default_restore_location(storage, location)
1406    elif isinstance(map_location, (str, bytes)):
1407        def restore_location(storage, location):
1408            return default_restore_location(storage, map_location)
1409    elif isinstance(map_location, torch.device):
1410        def restore_location(storage, location):
1411            return default_restore_location(storage, str(map_location))
1412    else:
1413        def restore_location(storage, location):
1414            result = map_location(storage, location)
1415            if result is None:
1416                result = default_restore_location(storage, location)
1417            return result
1418    return restore_location
1419
1420
1421class StorageType:
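    """Stand-in resolved instead of legacy ``torch.*Storage`` classes during
    unpickling; it only records the dtype encoded in the pickled type name."""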
1422    def __init__(self, name):
1423        self._dtype = _get_dtype_from_pickle_storage_type(name)
1424
1425    @property
1426    def dtype(self):
1427        return self._dtype
1428
1429    def __str__(self):
1430        return f'StorageType(dtype={self.dtype})'
1431
1432
1433def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', overall_storage=None, **pickle_load_args):
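    """Load a checkpoint stored in the zipfile-based serialization format.

    ``persistent_load`` resolves each pickled storage reference, deduplicating
    through ``loaded_storages``, and ``map_location`` controls where the
    restored storages are placed.
    """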
1434    restore_location = _get_restore_location(map_location)
1435
1436    loaded_storages = {}
1437
1438    # check if byteswapping is needed
1439    byteordername = 'byteorder'
1440    byteorderdata = None
1441    if zip_file.has_record(byteordername):
1442        byteorderdata = zip_file.get_record(byteordername)
1443        if byteorderdata not in [b'little', b'big']:
            raise ValueError(f'Unknown endianness type: {byteorderdata.decode()}')
1445    elif get_default_load_endianness() == LoadEndianness.LITTLE or \
1446            get_default_load_endianness() is None:
1447        byteorderdata = b'little'
1448    elif get_default_load_endianness() == LoadEndianness.BIG:
1449        byteorderdata = b'big'
1450    elif get_default_load_endianness() == LoadEndianness.NATIVE:
1451        pass
1452    else:
1453        raise ValueError('Invalid load endianness type')
1454
1455    if not zip_file.has_record(byteordername) and \
1456            get_default_load_endianness() is None and \
1457            sys.byteorder == 'big':
1458        # Default behaviour was changed
1459        # See https://github.com/pytorch/pytorch/issues/101688
        warnings.warn("The default load endianness for checkpoints without a byteorder mark "
                      "on big endian machines was changed from 'native' to 'little' endian. "
                      "To avoid this behavior, please use "
                      "torch.serialization.set_default_load_endianness to set "
                      "the desired default load endianness.",
                      UserWarning)
1466
    def load_tensor(dtype, nbytes, key, location):
        # NOTE: the value passed here by `persistent_load` below is a byte
        # count (numel * element_size), not a number of elements.
        name = f'data/{key}'
        if torch._guards.detect_fake_mode(None) is not None:
            # `nbytes` is already a byte count, so don't multiply by the
            # element size again when sizing the meta-device storage.
            storage = torch.UntypedStorage(nbytes, device='meta')
        elif overall_storage is not None:
            storage_offset = zip_file.get_record_offset(name)
            storage = overall_storage[storage_offset:storage_offset + nbytes]
        else:
            storage = zip_file.get_storage_from_record(
                name, nbytes, torch.UntypedStorage)._typed_storage()._untyped_storage
1477        # swap here if byteswapping is needed
1478        if byteorderdata is not None:
1479            if byteorderdata.decode() != sys.byteorder:
1480                storage.byteswap(dtype)
1481
1482        # TODO: Once we decide to break serialization FC, we can
1483        # stop wrapping with TypedStorage
1484        typed_storage = torch.storage.TypedStorage(
1485            wrap_storage=restore_location(storage, location),
1486            dtype=dtype,
1487            _internal=True)
1488
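        # Storages with a null data pointer (e.g. meta-device storages
        # created under fake mode) are not cached for reuse across references.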
1489        if typed_storage._data_ptr() != 0:
1490            loaded_storages[key] = typed_storage
1491
1492        return typed_storage
1493
1494    def persistent_load(saved_id):
1495        assert isinstance(saved_id, tuple)
1496        typename = _maybe_decode_ascii(saved_id[0])
1497        data = saved_id[1:]
1498
1499        assert typename == 'storage', \
1500            f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
1501        storage_type, key, location, numel = data
1502        if storage_type is torch.UntypedStorage:
1503            dtype = torch.uint8
1504        else:
1505            dtype = storage_type.dtype
1506
1507        if key in loaded_storages:
1508            typed_storage = loaded_storages[key]
1509        else:
1510            nbytes = numel * torch._utils._element_size(dtype)
1511            typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
1512
1513        return typed_storage
1514
1515    load_module_mapping: Dict[str, str] = {
1516        # See https://github.com/pytorch/pytorch/pull/51633
1517        'torch.tensor': 'torch._tensor'
1518    }
1519
1520    # Need to subclass Unpickler instead of directly monkey-patching the find_class method
1521    # because it's marked readonly in pickle.
1522    # The type: ignore is because mypy can't statically determine the type of this class.
1523    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]
1524        # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732
1525        # Lets us override the imports that pickle uses when unpickling an object.
1526        # This is useful for maintaining BC if we change a module path that tensor instantiation relies on.
1527        def find_class(self, mod_name, name):
1528            if type(name) is str and 'Storage' in name:
1529                try:
1530                    return StorageType(name)
1531                except KeyError:
1532                    pass
1533            mod_name = load_module_mapping.get(mod_name, mod_name)
1534            return super().find_class(mod_name, name)
1535
1536    # Load the data (which may in turn use `persistent_load` to load tensors)
1537    data_file = io.BytesIO(zip_file.get_record(pickle_file))
1538
1539    unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
1540    unpickler.persistent_load = persistent_load
    # Needed for tensors whose storage device and rebuilt-tensor device are
    # not connected (wrapper subclasses and tensors rebuilt using numpy)
1543    torch._utils._thread_local_state.map_location = map_location
1544    result = unpickler.load()
1545    del torch._utils._thread_local_state.map_location
1546
1547    torch._utils._validate_loaded_sparse_tensors()
1548    torch._C._log_api_usage_metadata(
1549        "torch.load.metadata", {"serialization_id": zip_file.serialization_id()}
1550    )
1551    return result
1552
1553
1554def _is_torchscript_zip(zip_file):
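    """Archives written by ``torch.jit.save`` contain a ``constants.pkl`` record."""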
1555    return 'constants.pkl' in zip_file.get_all_records()
1556