# mypy: allow-untyped-defs
import difflib
import functools
import os
import io
import re
import shutil
import struct
import sys
import torch
import tarfile
import tempfile
import warnings
from contextlib import closing, contextmanager
from enum import Enum
from ._utils import _import_dotted_name
from torch._sources import get_source_lines_and_file
from torch.types import Storage
from torch.storage import _get_dtype_from_pickle_storage_type
from typing import Any, BinaryIO, Callable, cast, Dict, Optional, Type, Tuple, Union, IO, List
from typing_extensions import TypeAlias, TypeGuard  # Python 3.10+
import copyreg
import pickle
import torch._weights_only_unpickler as _weights_only_unpickler

DEFAULT_PROTOCOL = 2

LONG_SIZE = struct.Struct('=l').size
INT_SIZE = struct.Struct('=i').size
SHORT_SIZE = struct.Struct('=h').size

MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
PROTOCOL_VERSION = 1001
STORAGE_KEY_SEPARATOR = ','

FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]]
MAP_LOCATION: TypeAlias = Optional[Union[Callable[[Storage, str], Storage], torch.device, str, Dict[str, str]]]
STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]

IS_WINDOWS = sys.platform == "win32"

if not IS_WINDOWS:
    from mmap import MAP_SHARED, MAP_PRIVATE
else:
    MAP_SHARED, MAP_PRIVATE = None, None  # type: ignore[assignment]

__all__ = [
    'SourceChangeWarning',
    'mkdtemp',
    'register_package',
    'check_module_version_greater_or_equal',
    'validate_cuda_device',
    'validate_hpu_device',
    'location_tag',
    'default_restore_location',
    'normalize_storage_type',
    'storage_to_tensor_type',
    'save',
    'load',
    'StorageType',
    'LoadEndianness',
    'get_default_load_endianness',
    'set_default_load_endianness',
    'clear_safe_globals',
    'get_safe_globals',
    'add_safe_globals',
]


class SourceChangeWarning(Warning):
    pass


@contextmanager
def mkdtemp():
    path = tempfile.mkdtemp()
    try:
        yield path
    finally:
        shutil.rmtree(path)


_package_registry: List[Tuple[int, Callable[[STORAGE], Optional[str]], Callable[[STORAGE, str], Optional[STORAGE]]]] = []

class LoadEndianness(Enum):
    NATIVE = 1
    LITTLE = 2
    BIG = 3

_default_load_endian: Optional[LoadEndianness] = None

def get_default_load_endianness() -> Optional[LoadEndianness]:
    '''
    Get the fallback byte order for loading files.

    If the byteorder mark is not present in the saved checkpoint,
    this byte order is used as the fallback.
    By default, it's "native" byte order.

    Returns:
        default_load_endian: Optional[LoadEndianness]
    '''
    return _default_load_endian

def set_default_load_endianness(endianness):
    '''
    Set the fallback byte order for loading files.

    If the byteorder mark is not present in the saved checkpoint,
    this byte order is used as the fallback.
    By default, it's "native" byte order.

    Args:
        endianness: the new fallback byte order
    '''
    global _default_load_endian
    if not isinstance(endianness, LoadEndianness) and endianness is not None:
        raise TypeError("Invalid argument type in function set_default_load_endianness")
    _default_load_endian = endianness
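
# Usage sketch (illustrative; the checkpoint path is hypothetical): when
# loading an old file that was saved without a byteorder record on a machine
# of the opposite endianness, the fallback can be set around the load call:
#
#     torch.serialization.set_default_load_endianness(LoadEndianness.BIG)
#     try:
#         obj = torch.load('checkpoint_from_big_endian_host.pt')
#     finally:
#         torch.serialization.set_default_load_endianness(None)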

_default_mmap_options: int = MAP_PRIVATE

def get_default_mmap_options() -> int:
    '''
    Get default mmap options for :func:`torch.load` with ``mmap=True``.

    Defaults to ``mmap.MAP_PRIVATE``.

    Returns:
        default_mmap_options: int
    '''
    return _default_mmap_options

def set_default_mmap_options(flags: int):
    '''
    Set default mmap options for :func:`torch.load` with ``mmap=True`` to flags.

    For now, only ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED`` are supported.
    Please open an issue if you need any other option to be added here.

    .. note::
        This feature is currently not supported for Windows.

    Args:
        flags: ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED``
    '''
    global _default_mmap_options
    if IS_WINDOWS:
        raise RuntimeError("Changing the default mmap options is currently not supported for Windows")
    if (flags != MAP_PRIVATE and flags != MAP_SHARED):
        raise ValueError("Invalid argument in function set_default_mmap_options, "
                         f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}")
    _default_mmap_options = flags
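
# Usage sketch (illustrative; 'checkpoint.pt' is a hypothetical path): with
# MAP_SHARED, modifications to the mapped storages are shared with the
# underlying file, while the MAP_PRIVATE default keeps the mapping
# copy-on-write:
#
#     import mmap
#     torch.serialization.set_default_mmap_options(mmap.MAP_SHARED)
#     state = torch.load('checkpoint.pt', mmap=True, weights_only=True)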

def clear_safe_globals() -> None:
    '''
    Clears the list of globals that are safe for ``weights_only`` load.
    '''
    _weights_only_unpickler._clear_safe_globals()

def get_safe_globals() -> List[Any]:
    '''
    Returns the list of user-added globals that are safe for ``weights_only`` load.
    '''
    return _weights_only_unpickler._get_safe_globals()

def add_safe_globals(safe_globals: List[Any]) -> None:
    '''
    Marks the given globals as safe for ``weights_only`` load. For example, functions
    added to this list can be called during unpickling, and classes can be instantiated
    and have state set.

    Args:
        safe_globals (List[Any]): list of globals to mark as safe

    Example:
        >>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
        >>> import tempfile
        >>> class MyTensor(torch.Tensor):
        ...     pass
        >>> t = MyTensor(torch.randn(2, 3))
        >>> with tempfile.NamedTemporaryFile() as f:
        ...     torch.save(t, f.name)
        # Running `torch.load(f.name, weights_only=True)` will fail with
        # Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
        # Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
        ...     torch.serialization.add_safe_globals([MyTensor])
        ...     torch.load(f.name, weights_only=True)
        # MyTensor([[-0.5024, -1.8152, -0.5455],
        #           [-0.8234,  2.0500, -0.3657]])
    '''
    _weights_only_unpickler._add_safe_globals(safe_globals)

def _is_zipfile(f) -> bool:
    # This is a stricter implementation than zipfile.is_zipfile().
    # zipfile.is_zipfile() is True if the magic number appears anywhere in the
    # binary. Since we expect the files here to be generated by torch.save or
    # torch.jit.save, it's safe to only check the start bytes and avoid
    # collisions and assume the zip has only 1 file.
    # See bugs.python.org/issue28494.

    start = f.tell()
    # Read the first few bytes and match against the ZIP file signature
    local_header_magic_number = b'PK\x03\x04'
    read_bytes = f.read(len(local_header_magic_number))
    f.seek(start)
    return read_bytes == local_header_magic_number


def register_package(
    priority: int,
    tagger: Callable[[STORAGE], Optional[str]],
    deserializer: Callable[[STORAGE, str], Optional[STORAGE]]
):
    '''
    Registers callables for tagging and deserializing storage objects with an associated priority.
    Tagging associates a device with a storage object at save time while deserializing moves a
    storage object to an appropriate device at load time. :attr:`tagger` and :attr:`deserializer`
    are run in the order given by their :attr:`priority` until a tagger/deserializer returns a
    value that is not `None`.

    To override the deserialization behavior for a device in the global registry, one can register a
    tagger with a higher priority than the existing tagger.

    This function can also be used to register a tagger and deserializer for new devices.

    Args:
        priority: Indicates the priority associated with the tagger and deserializer, where a lower
            value indicates higher priority.
        tagger: Callable that takes in a storage object and returns its tagged device as a string
            or None.
        deserializer: Callable that takes in a storage object and a device string, and returns a storage
            object on the appropriate device or None.

    Returns:
        `None`

    Example:
        >>> def ipu_tag(obj):
        >>>     if obj.device.type == 'ipu':
        >>>         return 'ipu'
        >>> def ipu_deserialize(obj, location):
        >>>     if location.startswith('ipu'):
        >>>         ipu = getattr(torch, "ipu", None)
        >>>         assert ipu is not None, "IPU device module is not loaded"
        >>>         assert torch.ipu.is_available(), "ipu is not available"
        >>>         return obj.ipu(location)
        >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
    '''
    queue_elem = (priority, tagger, deserializer)
    _package_registry.append(queue_elem)
    _package_registry.sort()


def check_module_version_greater_or_equal(module, req_version_tuple, error_if_malformed=True):
    '''
    Check if a module's version satisfies requirements.

    Usually, a module's version string will be like 'x.y.z', which would be represented
    as a tuple (x, y, z), but sometimes it could be in an unexpected format. If the version
    string does not match the given tuple's format up to the length of the tuple, then
    either an error is raised or a warning is emitted, depending on :attr:`error_if_malformed`.

    Args:
        module: the module to check the version of
        req_version_tuple: tuple (usually of ints) representing the required version
        error_if_malformed: whether to raise an error if the module version string is malformed

    Returns:
        requirement_is_met: bool
    '''
    try:
        version_strs = module.__version__.split('.')
        # Cast module version fields to match the types of the required version
        module_version = tuple(
            type(req_field)(version_strs[idx]) for idx, req_field in enumerate(req_version_tuple)
        )
        requirement_is_met = module_version >= req_version_tuple

    except Exception as e:
        message = (
            f"'{module.__name__}' module version string is malformed '{module.__version__}' and cannot be compared"
            f" with tuple {str(req_version_tuple)}"
        )
        if error_if_malformed:
            raise RuntimeError(message) from e
        else:
            warnings.warn(message + ', but continuing assuming that requirement is met')
            requirement_is_met = True

    return requirement_is_met
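
# Usage sketch (illustrative; numpy and the version floor are arbitrary here):
# gating an optional code path on a dependency's version:
#
#     import numpy
#     if check_module_version_greater_or_equal(numpy, (1, 24), error_if_malformed=False):
#         pass  # safe to rely on behavior introduced in numpy 1.24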


def _cpu_tag(obj):
    if obj.device.type == 'cpu':
        return 'cpu'


def _mps_tag(obj):
    if obj.device.type == 'mps':
        return 'mps'


def _meta_tag(obj):
    if obj.device.type == 'meta':
        return 'meta'


def _backend_tag(backend_name, obj):
    if backend_name == 'privateuse1':
        backend_name = torch._C._get_privateuse1_backend_name()
    if obj.device.type == backend_name:
        if obj.device.index is None:
            return backend_name
        else:
            return backend_name + ':' + str(obj.device.index)


def _cpu_deserialize(obj, location):
    if location == 'cpu':
        return obj


def _mps_deserialize(obj, location):
    if location.startswith('mps'):
        return obj.mps()


def _meta_deserialize(obj, location):
    if location == 'meta':
        return torch.UntypedStorage(obj.nbytes(), device='meta')


def _validate_device(location, backend_name):
    '''
    Check whether the device index of the specified backend is valid.

    In case of the privateuse1 backend, you must first register a device_module for
    privateuse1 using torch._register_device_module. The device_module should implement
    the following methods, as the cuda module does:
    device_module._utils._get_device_index(location, True) and
    device_module.device_count().

    Args:
        location: string of device
        backend_name: the backend name or the name of privateuse1, which can be renamed

    Returns:
        device_index: int
    '''
    if not hasattr(torch, backend_name):
        raise RuntimeError(f'The {backend_name.upper()} device module is not registered. '
                           'If you are running on a CPU-only machine, '
                           'please use torch.load with map_location=torch.device(\'cpu\') '
                           'to map your storages to the CPU.')
    device_module = getattr(torch, backend_name)
    if hasattr(device_module, '_utils') and hasattr(device_module._utils, '_get_device_index'):
        device_index = device_module._utils._get_device_index(location, True)
        device = torch.device(backend_name, device_index)
    else:
        device = torch.device(location)
        device_index = device.index if device.index else 0
    if hasattr(device_module, 'is_available') and not device_module.is_available():
        raise RuntimeError(f'Attempting to deserialize object on a {backend_name.upper()} '
                           f'device but torch.{backend_name}.is_available() is False. '
                           'If you are running on a CPU-only machine, '
                           'please use torch.load with map_location=torch.device(\'cpu\') '
                           'to map your storages to the CPU.')
    if hasattr(device_module, 'device_count'):
        device_count = device_module.device_count()
        if device_index >= device_count:
            raise RuntimeError(f'Attempting to deserialize object on {backend_name.upper()} device '
                               f'{device_index} but torch.{backend_name}.device_count() is {device_count}. '
                               'Please use torch.load with map_location to map your storages '
                               'to an existing device.')
    return device


def validate_cuda_device(location):
    return _validate_device(location, 'cuda').index


def validate_hpu_device(location):
    return _validate_device(location, 'hpu').index


def _deserialize(backend_name, obj, location):
    if backend_name == 'privateuse1':
        backend_name = torch._C._get_privateuse1_backend_name()
    if location.startswith(backend_name):
        device = _validate_device(location, backend_name)
        return obj.to(device=device)


register_package(10, _cpu_tag, _cpu_deserialize)
register_package(20, functools.partial(_backend_tag, 'cuda'), functools.partial(_deserialize, 'cuda'))
register_package(21, _mps_tag, _mps_deserialize)
register_package(22, _meta_tag, _meta_deserialize)
register_package(23, functools.partial(_backend_tag, 'privateuse1'), functools.partial(_deserialize, 'privateuse1'))
register_package(24, functools.partial(_backend_tag, 'hpu'), functools.partial(_deserialize, 'hpu'))
register_package(25, functools.partial(_backend_tag, 'xpu'), functools.partial(_deserialize, 'xpu'))

def location_tag(storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]):
    for _, tagger, _ in _package_registry:
        location = tagger(storage)
        if location:
            return location
    raise RuntimeError("don't know how to determine data location of "
                       + torch.typename(storage))


def default_restore_location(storage, location):
    """
    Restores `storage` using a deserializer function registered for the `location`.

    This function looks in the registry for deserializer functions that match the `location`.
    If found, it attempts to use them, in priority order, to restore `storage` until one
    returns a result that is not `None`. If no deserializer can be found in the registry, or
    all found deserializers fail to produce a result, it raises a `RuntimeError`.

    Args:
        storage (STORAGE): the storage object to restore
        location (str): the location tag associated with the storage object

    Returns:
        storage: Optional[STORAGE]

    Raises:
        RuntimeError: If no deserializer matching `location` is found in the registry or if
            all matching ones return `None`.
    """
    for _, _, fn in _package_registry:
        result = fn(storage, location)
        if result is not None:
            return result
    raise RuntimeError("don't know how to restore data location of "
                       + torch.typename(storage) + " (tagged with "
                       + location + ")")
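
# Usage sketch (illustrative): the two registry lookups are inverses of one
# another; a storage is tagged with a location at save time and restored from
# that tag at load time:
#
#     s = torch.randn(4).untyped_storage()
#     tag = location_tag(s)                    # 'cpu' for a CPU storage
#     s2 = default_restore_location(s, tag)    # returns s unchanged for 'cpu'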


def normalize_storage_type(storage_type):
    return getattr(torch, storage_type.__name__)


def storage_to_tensor_type(storage):
    storage_type = type(storage)
    module = _import_dotted_name(storage_type.__module__)
    return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))


def _is_path(name_or_buffer) -> TypeGuard[Union[str, os.PathLike]]:
    return isinstance(name_or_buffer, (str, os.PathLike))


class _opener:
    def __init__(self, file_like):
        self.file_like = file_like

    def __enter__(self):
        return self.file_like

    def __exit__(self, *args):
        pass


class _open_file(_opener):
    def __init__(self, name, mode):
        super().__init__(open(name, mode))

    def __exit__(self, *args):
        self.file_like.close()


class _open_buffer_reader(_opener):
    def __init__(self, buffer):
        super().__init__(buffer)
        _check_seekable(buffer)


class _open_buffer_writer(_opener):
    def __exit__(self, *args):
        self.file_like.flush()


def _open_file_like(name_or_buffer, mode):
    if _is_path(name_or_buffer):
        return _open_file(name_or_buffer, mode)
    else:
        if 'w' in mode:
            return _open_buffer_writer(name_or_buffer)
        elif 'r' in mode:
            return _open_buffer_reader(name_or_buffer)
        else:
            raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")


class _open_zipfile_reader(_opener):
    def __init__(self, name_or_buffer) -> None:
        super().__init__(torch._C.PyTorchFileReader(name_or_buffer))


class _open_zipfile_writer_file(_opener):
    def __init__(self, name) -> None:
        self.file_stream = None
        self.name = str(name)
        try:
            self.name.encode('ascii')
        except UnicodeEncodeError:
            # PyTorchFileWriter only supports ascii filenames.
            # For filenames with non-ascii characters, we rely on Python
            # for writing out the file.
            self.file_stream = io.FileIO(self.name, mode='w')
            super().__init__(torch._C.PyTorchFileWriter(self.file_stream))
        else:
            super().__init__(torch._C.PyTorchFileWriter(self.name))

    def __exit__(self, *args) -> None:
        self.file_like.write_end_of_file()
        if self.file_stream is not None:
            self.file_stream.close()


class _open_zipfile_writer_buffer(_opener):
    def __init__(self, buffer) -> None:
        if not callable(getattr(buffer, "write", None)):
            msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'"
            if not hasattr(buffer, "write"):
                raise AttributeError(msg)
            raise TypeError(msg)
        self.buffer = buffer
        super().__init__(torch._C.PyTorchFileWriter(buffer))

    def __exit__(self, *args) -> None:
        self.file_like.write_end_of_file()
        self.buffer.flush()


def _open_zipfile_writer(name_or_buffer):
    container: Type[_opener]
    if _is_path(name_or_buffer):
        container = _open_zipfile_writer_file
    else:
        container = _open_zipfile_writer_buffer
    return container(name_or_buffer)
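
# Usage sketch (illustrative; these are private helpers, not public API): the
# _opener hierarchy gives paths and buffers the same context-manager shape, so
# save/load can treat both uniformly:
#
#     with _open_file_like(io.BytesIO(), 'wb') as f:
#         f.write(b'...')   # the buffer is flushed, but not closed, on exit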
431 """ 432 for _, _, fn in _package_registry: 433 result = fn(storage, location) 434 if result is not None: 435 return result 436 raise RuntimeError("don't know how to restore data location of " 437 + torch.typename(storage) + " (tagged with " 438 + location + ")") 439 440 441def normalize_storage_type(storage_type): 442 return getattr(torch, storage_type.__name__) 443 444 445def storage_to_tensor_type(storage): 446 storage_type = type(storage) 447 module = _import_dotted_name(storage_type.__module__) 448 return getattr(module, storage_type.__name__.replace('Storage', 'Tensor')) 449 450 451def _is_path(name_or_buffer) -> TypeGuard[Union[str, os.PathLike]]: 452 return isinstance(name_or_buffer, (str, os.PathLike)) 453 454 455class _opener: 456 def __init__(self, file_like): 457 self.file_like = file_like 458 459 def __enter__(self): 460 return self.file_like 461 462 def __exit__(self, *args): 463 pass 464 465 466class _open_file(_opener): 467 def __init__(self, name, mode): 468 super().__init__(open(name, mode)) 469 470 def __exit__(self, *args): 471 self.file_like.close() 472 473 474class _open_buffer_reader(_opener): 475 def __init__(self, buffer): 476 super().__init__(buffer) 477 _check_seekable(buffer) 478 479 480class _open_buffer_writer(_opener): 481 def __exit__(self, *args): 482 self.file_like.flush() 483 484 485def _open_file_like(name_or_buffer, mode): 486 if _is_path(name_or_buffer): 487 return _open_file(name_or_buffer, mode) 488 else: 489 if 'w' in mode: 490 return _open_buffer_writer(name_or_buffer) 491 elif 'r' in mode: 492 return _open_buffer_reader(name_or_buffer) 493 else: 494 raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}") 495 496 497class _open_zipfile_reader(_opener): 498 def __init__(self, name_or_buffer) -> None: 499 super().__init__(torch._C.PyTorchFileReader(name_or_buffer)) 500 501 502class _open_zipfile_writer_file(_opener): 503 def __init__(self, name) -> None: 504 self.file_stream = None 505 self.name = str(name) 506 try: 507 self.name.encode('ascii') 508 except UnicodeEncodeError: 509 # PyTorchFileWriter only supports ascii filename. 510 # For filenames with non-ascii characters, we rely on Python 511 # for writing out the file. 


def save(
    obj: object,
    f: FILE_LIKE,
    pickle_module: Any = pickle,
    pickle_protocol: int = DEFAULT_PROTOCOL,
    _use_new_zipfile_serialization: bool = True,
    _disable_byteorder_record: bool = False
) -> None:
    # Reference: https://github.com/pytorch/pytorch/issues/54354
    # The first line of this docstring overrides the one Sphinx generates for the
    # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
    # the build environment (e.g. `<module 'pickle' from '/leaked/path'`).

    """save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)

    Saves an object to a disk file.

    See also: :ref:`saving-loading-tensors`

    Args:
        obj: saved object
        f: a file-like object (has to implement write and flush) or a string or
           os.PathLike object containing a file name
        pickle_module: module used for pickling metadata and objects
        pickle_protocol: can be specified to override the default protocol

    .. note::
        A common PyTorch convention is to save tensors using the .pt file extension.

    .. note::
        PyTorch preserves storage sharing across serialization. See
        :ref:`preserve-storage-sharing` for more details.

    .. note::
        The 1.6 release of PyTorch switched ``torch.save`` to use a new
        zipfile-based file format. ``torch.load`` still retains the ability to
        load files in the old format. If for any reason you want ``torch.save``
        to use the old format, pass the kwarg ``_use_new_zipfile_serialization=False``.

    Example:
        >>> # xdoctest: +SKIP("makes cwd dirty")
        >>> # Save to file
        >>> x = torch.tensor([0, 1, 2, 3, 4])
        >>> torch.save(x, 'tensor.pt')
        >>> # Save to io.BytesIO buffer
        >>> buffer = io.BytesIO()
        >>> torch.save(x, buffer)
    """
    torch._C._log_api_usage_once("torch.save")
    _check_dill_version(pickle_module)
    _check_save_filelike(f)

    if _use_new_zipfile_serialization:
        with _open_zipfile_writer(f) as opened_zipfile:
            _save(obj, opened_zipfile, pickle_module, pickle_protocol, _disable_byteorder_record)
            return
    else:
        with _open_file_like(f, 'wb') as opened_file:
            _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
603 " Please upgrade dill or switch to 'pickle'" 604 ).format( 605 '.'.join([str(num) for num in required_dill_version]), 606 pickle_module.__version__ 607 )) 608 609 610def _check_save_filelike(f): 611 if not _is_path(f) and not hasattr(f, 'write'): 612 raise AttributeError( 613 "expected 'f' to be string, path, or a file-like object with " 614 "a 'write' attribute") 615 616 617def save( 618 obj: object, 619 f: FILE_LIKE, 620 pickle_module: Any = pickle, 621 pickle_protocol: int = DEFAULT_PROTOCOL, 622 _use_new_zipfile_serialization: bool = True, 623 _disable_byteorder_record: bool = False 624) -> None: 625 # Reference: https://github.com/pytorch/pytorch/issues/54354 626 # The first line of this docstring overrides the one Sphinx generates for the 627 # documentation. We need it so that Sphinx doesn't leak `pickle`s path from 628 # the build environment (e.g. `<module 'pickle' from '/leaked/path'). 629 630 """save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) 631 632 Saves an object to a disk file. 633 634 See also: :ref:`saving-loading-tensors` 635 636 Args: 637 obj: saved object 638 f: a file-like object (has to implement write and flush) or a string or 639 os.PathLike object containing a file name 640 pickle_module: module used for pickling metadata and objects 641 pickle_protocol: can be specified to override the default protocol 642 643 .. note:: 644 A common PyTorch convention is to save tensors using .pt file extension. 645 646 .. note:: 647 PyTorch preserves storage sharing across serialization. See 648 :ref:`preserve-storage-sharing` for more details. 649 650 .. note:: 651 The 1.6 release of PyTorch switched ``torch.save`` to use a new 652 zipfile-based file format. ``torch.load`` still retains the ability to 653 load files in the old format. If for any reason you want ``torch.save`` 654 to use the old format, pass the kwarg ``_use_new_zipfile_serialization=False``. 655 656 Example: 657 >>> # xdoctest: +SKIP("makes cwd dirty") 658 >>> # Save to file 659 >>> x = torch.tensor([0, 1, 2, 3, 4]) 660 >>> torch.save(x, 'tensor.pt') 661 >>> # Save to io.BytesIO buffer 662 >>> buffer = io.BytesIO() 663 >>> torch.save(x, buffer) 664 """ 665 torch._C._log_api_usage_once("torch.save") 666 _check_dill_version(pickle_module) 667 _check_save_filelike(f) 668 669 if _use_new_zipfile_serialization: 670 with _open_zipfile_writer(f) as opened_zipfile: 671 _save(obj, opened_zipfile, pickle_module, pickle_protocol, _disable_byteorder_record) 672 return 673 else: 674 with _open_file_like(f, 'wb') as opened_file: 675 _legacy_save(obj, opened_file, pickle_module, pickle_protocol) 676 677 678def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None: 679 import torch.nn as nn 680 serialized_container_types = {} 681 serialized_storages = {} 682 683 # Since loading storages that view the same data with different dtypes is 684 # not supported, we need to keep track of the dtype associated with each 685 # storage data_ptr and throw an error if the dtype is ever different. 686 # TODO: This feature could be added in the future 687 storage_dtypes: Dict[int, torch.dtype] = {} 688 689 def persistent_id(obj: Any) -> Optional[Tuple]: 690 # FIXME: the docs say that persistent_id should only return a string 691 # but torch store returns tuples. 


def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record):
    serialized_storages = {}
    id_map: Dict[int, str] = {}

    # Since loading storages that view the same data with different dtypes is
    # not supported, we need to keep track of the dtype associated with each
    # storage data_ptr and throw an error if the dtype is ever different.
    # TODO: This feature could be added in the future
    storage_dtypes: Dict[int, torch.dtype] = {}

    def persistent_id(obj):
        # FIXME: the docs say that persistent_id should only return a string,
        # but torch's implementation returns tuples. This works only in the
        # binary protocol; see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):

            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type_str = obj._pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage_numel = obj._size()

            else:
                storage = obj
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                storage_numel = storage.nbytes()

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if storage.data_ptr() != 0:
                if storage.data_ptr() in storage_dtypes:
                    if storage_dtype != storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            'Cannot save multiple tensors or storages that '
                            'view the same data as different types')
                else:
                    storage_dtypes[storage.data_ptr()] = storage_dtype

            storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
            location = location_tag(storage)
            serialized_storages[storage_key] = storage

            return ('storage',
                    storage_type,
                    storage_key,
                    location,
                    storage_numel)

        return None

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()
    zip_file.write_record('data.pkl', data_value, len(data_value))

    # Write byte order marker
    if not _disable_byteorder_record:
        if sys.byteorder not in ['little', 'big']:
            raise ValueError('Unknown endianness type: ' + sys.byteorder)

        zip_file.write_record('byteorder', sys.byteorder, len(sys.byteorder))

    # Write each tensor to a file named data/the_tensor_key in the zip archive
    for key in sorted(serialized_storages.keys()):
        name = f'data/{key}'
        storage = serialized_storages[key]
        # given that we copy things around anyway, we might use storage.cpu();
        # this means that to get tensors serialized, you need to implement
        # .cpu() on the underlying Storage
        if storage.device.type != 'cpu':
            storage = storage.cpu()
        # Now that it is on the CPU we can directly copy it into the zip file
        num_bytes = storage.nbytes()
        zip_file.write_record(name, storage, num_bytes)
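
# Layout sketch of the zipfile-based format produced above (informal summary
# of _save, not a spec); the archive contains the records:
#
#     data.pkl      pickled `obj`, with each storage replaced by a
#                   ('storage', type, key, location, numel) persistent id
#     byteorder     b'little' or b'big', unless the record is disabled
#     data/{key}    raw bytes of each referenced storage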


def load(
    f: FILE_LIKE,
    map_location: MAP_LOCATION = None,
    pickle_module: Any = None,
    *,
    weights_only: Optional[bool] = None,
    mmap: Optional[bool] = None,
    **pickle_load_args: Any
) -> Any:
    # Reference: https://github.com/pytorch/pytorch/issues/54354
    # The first line of this docstring overrides the one Sphinx generates for the
    # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
    # the build environment (e.g. `<module 'pickle' from '/leaked/path'`).

    """load(f, map_location=None, pickle_module=pickle, *, weights_only=False, mmap=None, **pickle_load_args)

    Loads an object saved with :func:`torch.save` from a file.

    :func:`torch.load` uses Python's unpickling facilities but treats storages,
    which underlie tensors, specially. They are first deserialized on the
    CPU and are then moved to the device they were saved from. If this fails
    (e.g. because the runtime system doesn't have certain devices), an exception
    is raised. However, storages can be dynamically remapped to an alternative
    set of devices using the :attr:`map_location` argument.

    If :attr:`map_location` is a callable, it will be called once for each serialized
    storage with two arguments: storage and location. The storage argument
    will be the initial deserialization of the storage, residing on the CPU.
    Each serialized storage has a location tag associated with it which
    identifies the device it was saved from, and this tag is the second
    argument passed to :attr:`map_location`. The builtin location tags are ``'cpu'``
    for CPU tensors and ``'cuda:device_id'`` (e.g. ``'cuda:2'``) for CUDA tensors.
    :attr:`map_location` should return either ``None`` or a storage. If
    :attr:`map_location` returns a storage, it will be used as the final deserialized
    object, already moved to the right device. Otherwise, :func:`torch.load` will
    fall back to the default behavior, as if :attr:`map_location` wasn't specified.

    If :attr:`map_location` is a :class:`torch.device` object or a string containing
    a device tag, it indicates the location where all tensors should be loaded.

    Otherwise, if :attr:`map_location` is a dict, it will be used to remap location tags
    appearing in the file (keys), to ones that specify where to put the
    storages (values).

    User extensions can register their own location tags and tagging and
    deserialization methods using :func:`torch.serialization.register_package`.

    Args:
        f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
            or a string or os.PathLike object containing a file name
        map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
            locations
        pickle_module: module used for unpickling metadata and objects (has to
            match the :attr:`pickle_module` used to serialize the file)
        weights_only: Indicates whether the unpickler should be restricted to
            loading only tensors, primitive types, dictionaries
            and any types added via :func:`torch.serialization.add_safe_globals`.
        mmap: Indicates whether the file should be mmaped rather than loading all the storages into memory.
            Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they
            are moved to the location that they were tagged with when saving, or specified by ``map_location``. This
            second step is a no-op if the final location is CPU. When the ``mmap`` flag is set, instead of copying the
            tensor storages from disk to CPU memory in the first step, ``f`` is mmaped.
        pickle_load_args: (Python 3 only) optional keyword arguments passed over to
            :func:`pickle_module.load` and :func:`pickle_module.Unpickler`, e.g.,
            :attr:`errors=...`.

    .. warning::
        :func:`torch.load()`, unless the `weights_only` parameter is set to `True`,
        uses the ``pickle`` module implicitly, which is known to be insecure.
        It is possible to construct malicious pickle data which will execute arbitrary code
        during unpickling. Never load data that could have come from an untrusted
        source in an unsafe mode, or that could have been tampered with. **Only load data you trust**.

    .. note::
        When you call :func:`torch.load()` on a file which contains GPU tensors, those tensors
        will be loaded to GPU by default. You can call ``torch.load(.., map_location='cpu')``
        and then :meth:`load_state_dict` to avoid GPU RAM surge when loading a model checkpoint.

    .. note::
        By default, we decode byte strings as ``utf-8``. This is to avoid a common error
        case ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...``
        when loading files saved by Python 2 in Python 3. If this default
        is incorrect, you may use an extra :attr:`encoding` keyword argument to specify how
        these objects should be loaded, e.g., :attr:`encoding='latin1'` decodes them
        to strings using ``latin1`` encoding, and :attr:`encoding='bytes'` keeps them
        as byte arrays which can be decoded later with ``byte_array.decode(...)``.

    Example:
        >>> # xdoctest: +SKIP("undefined filepaths")
        >>> torch.load('tensors.pt', weights_only=True)
        # Load all tensors onto the CPU
        >>> torch.load('tensors.pt', map_location=torch.device('cpu'), weights_only=True)
        # Load all tensors onto the CPU, using a function
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage, weights_only=True)
        # Load all tensors onto GPU 1
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1), weights_only=True)
        # Map tensors from GPU 1 to GPU 0
        >>> torch.load('tensors.pt', map_location={'cuda:1': 'cuda:0'}, weights_only=True)
        # Load tensor from io.BytesIO object
        # Loading from a buffer setting weights_only=False, warning this can be unsafe
        >>> with open('tensor.pt', 'rb') as f:
        ...     buffer = io.BytesIO(f.read())
        >>> torch.load(buffer, weights_only=False)
        # Load a module with 'ascii' encoding for unpickling
        # Loading from a module setting weights_only=False, warning this can be unsafe
        >>> torch.load('module.pt', encoding='ascii', weights_only=False)
    """
    torch._C._log_api_usage_once("torch.load")
    UNSAFE_MESSAGE = (
        "Re-running `torch.load` with `weights_only` set to `False` will likely succeed, "
        "but it can result in arbitrary code execution. Do it only if you got the file from a "
        "trusted source."
    )
    DOCS_MESSAGE = (
        "\n\nCheck the documentation of torch.load to learn more about types accepted by default with "
        "weights_only https://pytorch.org/docs/stable/generated/torch.load.html."
    )

    def _get_wo_message(message: str) -> str:
        pattern = r"GLOBAL (\S+) was not an allowed global by default."
        has_unsafe_global = re.search(pattern, message) is not None
        if has_unsafe_global:
            updated_message = (
                "Weights only load failed. This file can still be loaded, to do so you have two options "
                f"\n\t(1) {UNSAFE_MESSAGE}\n\t(2) Alternatively, to load with `weights_only=True` please check "
                "the recommended steps in the following error message.\n\tWeightsUnpickler error: "
                + message
            )
        else:
            updated_message = (
                f"Weights only load failed. {UNSAFE_MESSAGE}\n Please file an issue with the following "
                "so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler "
                "error: " + message
            )
        return updated_message + DOCS_MESSAGE

    if weights_only is None:
        weights_only, warn_weights_only = False, True
    else:
        warn_weights_only = False

    # Add ability to force safe-only weight loads via environment variable
    if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']:
        weights_only = True

    if weights_only:
        if pickle_module is not None:
            raise RuntimeError("Can not safely load weights when explicit pickle_module is specified")
    else:
        if pickle_module is None:
            if warn_weights_only:
                warnings.warn(
                    "You are using `torch.load` with `weights_only=False` (the current default value), which uses "
                    "the default pickle module implicitly. It is possible to construct malicious pickle data "
                    "which will execute arbitrary code during unpickling (See "
                    "https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). "
                    "In a future release, the default value for `weights_only` will be flipped to `True`. This "
                    "limits the functions that could be executed during unpickling. Arbitrary objects will no "
                    "longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the "
                    "user via `torch.serialization.add_safe_globals`. We recommend you start setting "
                    "`weights_only=True` for any use case where you don't have full control of the loaded file. "
                    "Please open an issue on GitHub for any issues related to this experimental feature.",
                    FutureWarning,
                    stacklevel=2,
                )
            pickle_module = pickle

    # make flipping the default BC-compatible
    if mmap is None:
        mmap = False

    _check_dill_version(pickle_module)

    if 'encoding' not in pickle_load_args.keys():
        pickle_load_args['encoding'] = 'utf-8'

    with _open_file_like(f, 'rb') as opened_file:
        if _is_zipfile(opened_file):
            # The zipfile reader is going to advance the current file position.
            # If we want to actually tail call to torch.jit.load, we need to
            # reset back to the original position.
            orig_position = opened_file.tell()
            overall_storage = None
            with _open_zipfile_reader(opened_file) as opened_zipfile:
                if _is_torchscript_zip(opened_zipfile):
                    warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"
                                  " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
                                  " silence this warning)", UserWarning)
                    opened_file.seek(orig_position)
                    return torch.jit.load(opened_file, map_location=map_location)
                if mmap:
                    if not _is_path(f):
                        raise ValueError("f must be a file path in order to use the mmap argument")
                    size = os.path.getsize(f)
                    if not IS_WINDOWS:
                        shared = get_default_mmap_options() == MAP_SHARED
                    else:
                        shared = False
                    overall_storage = torch.UntypedStorage.from_file(os.fspath(f), shared, size)
                if weights_only:
                    try:
                        return _load(opened_zipfile,
                                     map_location,
                                     _weights_only_unpickler,
                                     overall_storage=overall_storage,
                                     **pickle_load_args)
                    except RuntimeError as e:
                        raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
                return _load(
                    opened_zipfile,
                    map_location,
                    pickle_module,
                    overall_storage=overall_storage,
                    **pickle_load_args,
                )
        if mmap:
            f_name = "" if not isinstance(f, str) else f"{f}, "
            raise RuntimeError("mmap can only be used with files saved with "
                               f"`torch.save({f_name}_use_new_zipfile_serialization=True)`, "
                               "please torch.save your checkpoint with this option in order to use mmap.")
        if weights_only:
            try:
                return _legacy_load(opened_file, map_location, _weights_only_unpickler, **pickle_load_args)
            except RuntimeError as e:
                raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
        return _legacy_load(
            opened_file, map_location, pickle_module, **pickle_load_args
        )
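
# Usage sketch (illustrative): the environment variable consulted above can
# harden a whole process without touching call sites; every torch.load in the
# process is then forced to behave as if weights_only=True:
#
#     TORCH_FORCE_WEIGHTS_ONLY_LOAD=1 python train.py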


# Register pickling support for layout instances such as
# torch.sparse_coo, etc
def _get_layout(name):
    """Get layout extension object from its string representation.
    """
    cache = _get_layout.cache  # type: ignore[attr-defined]
    if not cache:
        for v in torch.__dict__.values():
            if isinstance(v, torch.layout):
                cache[str(v)] = v
    return cache[name]

# There is not yet a good way to type-annotate function attributes; see https://github.com/python/mypy/issues/2087
_get_layout.cache = {}  # type: ignore[attr-defined]
copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
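
# Usage sketch (illustrative): with the copyreg hook above, a layout pickles
# as a call to _get_layout on its string form, so it round-trips to the same
# singleton:
#
#     assert pickle.loads(pickle.dumps(torch.sparse_coo)) is torch.sparse_coo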


def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
    deserialized_objects: Dict[int, Any] = {}

    restore_location = _get_restore_location(map_location)

    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]

        def find_class(self, mod_name, name):
            if type(name) is str and 'Storage' in name:
                try:
                    return StorageType(name)
                except KeyError:
                    pass
            return super().find_class(mod_name, name)

    def _check_container_source(container_type, source_file, original_source):
        try:
            current_source = ''.join(get_source_lines_and_file(container_type)[0])
        except Exception:  # saving the source is optional, so we can ignore any errors
            warnings.warn("Couldn't retrieve source code for container of "
                          "type " + container_type.__name__ + ". It won't be checked "
                          "for correctness upon loading.")
            return
        if original_source != current_source:
            if container_type.dump_patches:
                file_name = container_type.__name__ + '.patch'
                diff = difflib.unified_diff(current_source.split('\n'),
                                            original_source.split('\n'),
                                            source_file,
                                            source_file, lineterm="")
                lines = '\n'.join(diff)
                try:
                    with open(file_name, 'a+') as f:
                        file_size = f.seek(0, 2)
                        f.seek(0)
                        if file_size == 0:
                            f.write(lines)
                        elif file_size != len(lines) or f.read() != lines:
                            raise OSError
                    msg = ("Saved a reverse patch to " + file_name + ". "
                           "Run `patch -p0 < " + file_name + "` to revert your "
                           "changes.")
                except OSError:
                    msg = ("Tried to save a patch, but couldn't create a "
                           "writable file " + file_name + ". Make sure it "
                           "doesn't exist and your working directory is "
                           "writable.")
            else:
                msg = ("you can retrieve the original source code by "
                       "accessing the object's source attribute or set "
                       "`torch.nn.Module.dump_patches = True` and use the "
                       "patch tool to revert the changes.")
            msg = f"source code of class '{torch.typename(container_type)}' has changed. {msg}"
            warnings.warn(msg, SourceChangeWarning)

    def legacy_load(f):
        deserialized_objects: Dict[int, Any] = {}

        def persistent_load(saved_id):
            if isinstance(saved_id, tuple):
                # Ignore containers that don't have any sources saved
                if all(saved_id[1:]):
                    _check_container_source(*saved_id)
                return saved_id[0]
            return deserialized_objects[int(saved_id)]

        with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
                mkdtemp() as tmpdir:

            tar.extract('storages', path=tmpdir)
            with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
                num_storages = pickle_module.load(f, **pickle_load_args)
                for i in range(num_storages):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, location, storage_type = args
                    dtype = storage_type._dtype
                    obj = cast(Storage, torch.UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype))
                    obj = restore_location(obj, location)
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with TypedStorage
                    deserialized_objects[key] = torch.storage.TypedStorage(
                        wrap_storage=obj,
                        dtype=dtype,
                        _internal=True)

                storage_views = pickle_module.load(f, **pickle_load_args)
                for target_cdata, root_cdata, offset, numel in storage_views:
                    root = deserialized_objects[root_cdata]
                    element_size = torch._utils._element_size(root.dtype)
                    offset_bytes = offset * element_size
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with TypedStorage
                    deserialized_objects[target_cdata] = torch.storage.TypedStorage(
                        wrap_storage=root._untyped_storage[offset_bytes:offset_bytes + numel * element_size],
                        dtype=root.dtype,
                        _internal=True)

            tar.extract('tensors', path=tmpdir)
            with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
                num_tensors = pickle_module.load(f, **pickle_load_args)
                for _ in range(num_tensors):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, storage_id, original_tensor_type = args
                    storage = deserialized_objects[storage_id]
                    ndim, = struct.unpack('<i', f.read(4))
                    # skip next 4 bytes; legacy encoding treated ndim as 8 bytes
                    f.read(4)
                    numel = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
                    stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
                    storage_offset, = struct.unpack('<q', f.read(8))
                    tensor = torch.empty((0,), dtype=storage.dtype).set_(
                        storage._untyped_storage, storage_offset, numel, stride)
                    deserialized_objects[key] = tensor

            pickle_file = tar.extractfile('pickle')
            unpickler = UnpicklerWrapper(pickle_file, **pickle_load_args)
            unpickler.persistent_load = persistent_load
            result = unpickler.load()
            return result

    deserialized_objects = {}

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == 'module':
            # Ignore containers that don't have any sources saved
            if all(data[1:]):
                _check_container_source(*data)
            return data[0]
        elif typename == 'storage':
            storage_type, root_key, location, numel, view_metadata = data
            location = _maybe_decode_ascii(location)
            dtype = storage_type.dtype

            nbytes = numel * torch._utils._element_size(dtype)

            if root_key not in deserialized_objects:
                if torch._guards.active_fake_mode() is not None:
                    obj = cast(Storage, torch.UntypedStorage(nbytes, device='meta'))
                else:
                    obj = cast(Storage, torch.UntypedStorage(nbytes))
                    obj._torch_load_uninitialized = True
                    obj = restore_location(obj, location)
                # TODO: Once we decide to break serialization FC, we can
                # stop wrapping with TypedStorage
                typed_storage = torch.storage.TypedStorage(
                    wrap_storage=obj,
                    dtype=dtype,
                    _internal=True)
                deserialized_objects[root_key] = typed_storage
            else:
                typed_storage = deserialized_objects[root_key]
                if typed_storage._data_ptr() == 0:
                    typed_storage = torch.storage.TypedStorage(
                        device=typed_storage._untyped_storage.device,
                        dtype=dtype,
                        _internal=True)

            if view_metadata is not None:
                view_key, offset, view_size = view_metadata
                offset_bytes = offset * torch._utils._element_size(dtype)
                view_size_bytes = view_size * torch._utils._element_size(dtype)
                if view_key not in deserialized_objects:
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with TypedStorage
                    deserialized_objects[view_key] = torch.storage.TypedStorage(
                        wrap_storage=typed_storage._untyped_storage[offset_bytes:offset_bytes + view_size_bytes],
                        dtype=dtype,
                        _internal=True)
                res = deserialized_objects[view_key]

            else:
                res = typed_storage
            return res
        else:
            raise RuntimeError(f"Unknown saved id type: {saved_id[0]}")

    _check_seekable(f)
    f_should_read_directly = _should_read_directly(f)

    if f_should_read_directly and f.tell() == 0:
        # legacy_load requires that f has fileno()
        # only if offset is zero we can attempt the legacy tar file loader
        try:
            return legacy_load(f)
        except tarfile.TarError:
            if _is_zipfile(f):
                # .zip is used for torch.jit.save and will throw an un-pickling error here
                raise RuntimeError(
                    f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)") from None
            # if not a tarfile, reset file offset and proceed
            f.seek(0)

    if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2):
        raise RuntimeError(
            "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
            f'Received object of type "{type(f)}". Please update to Python 3.8.2 or newer to restore this '
            "functionality.")

    magic_number = pickle_module.load(f, **pickle_load_args)
    if magic_number != MAGIC_NUMBER:
        raise RuntimeError("Invalid magic number; corrupt file?")
    protocol_version = pickle_module.load(f, **pickle_load_args)
    if protocol_version != PROTOCOL_VERSION:
        raise RuntimeError(f"Invalid protocol version: {protocol_version}")

    _sys_info = pickle_module.load(f, **pickle_load_args)
    unpickler = UnpicklerWrapper(f, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)

    if torch._guards.active_fake_mode() is None:
        offset = f.tell() if f_should_read_directly else None
        for key in deserialized_storage_keys:
            assert key in deserialized_objects
            typed_storage = deserialized_objects[key]
            typed_storage._untyped_storage._set_from_file(
                f, offset, f_should_read_directly,
                torch._utils._element_size(typed_storage.dtype))
            if offset is not None:
                offset = f.tell()

    torch._utils._validate_loaded_sparse_tensors()

    return result


def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
    # When using encoding='bytes' in Py3, some **internal** keys stored as
    # strings in Py2 are loaded as bytes. This function decodes them with
    # ascii encoding, the one that Py3 uses by default.
    #
    # NOTE: This should only be used on internal keys (e.g., `typename` and
    # `location` in `persistent_load` below!)
    if isinstance(bytes_str, bytes):
        return bytes_str.decode('ascii')
    return bytes_str
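
# Behavior sketch (illustrative): bytes are decoded, str passes through:
#
#     _maybe_decode_ascii(b'cuda:0')   # -> 'cuda:0'
#     _maybe_decode_ascii('cpu')       # -> 'cpu'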


def _get_restore_location(map_location):
    if map_location is None:
        restore_location = default_restore_location
    elif isinstance(map_location, dict):
        def restore_location(storage, location):
            location = map_location.get(location, location)
            return default_restore_location(storage, location)
    elif isinstance(map_location, (str, bytes)):
        def restore_location(storage, location):
            return default_restore_location(storage, map_location)
    elif isinstance(map_location, torch.device):
        def restore_location(storage, location):
            return default_restore_location(storage, str(map_location))
    else:
        def restore_location(storage, location):
            result = map_location(storage, location)
            if result is None:
                result = default_restore_location(storage, location)
            return result
    return restore_location
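
# Usage sketch (illustrative): a dict map_location rewrites only the tags it
# mentions; any other tag falls through to the default registry lookup:
#
#     restore = _get_restore_location({'cuda:1': 'cuda:0'})
#     # restore(storage, 'cuda:1') deserializes onto 'cuda:0', while
#     # restore(storage, 'cpu') behaves as if no mapping were given.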


class StorageType:
    def __init__(self, name):
        self._dtype = _get_dtype_from_pickle_storage_type(name)

    @property
    def dtype(self):
        return self._dtype

    def __str__(self):
        return f'StorageType(dtype={self.dtype})'
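
# Behavior sketch (illustrative): legacy pickled storage-type names resolve to
# dtypes here, e.g. StorageType('FloatStorage').dtype is torch.float32, which
# lets the unpickler wrappers in this module substitute StorageType for the
# removed typed storage classes.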


def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', overall_storage=None, **pickle_load_args):
    restore_location = _get_restore_location(map_location)

    loaded_storages = {}

    # check if byteswapping is needed
    byteordername = 'byteorder'
    byteorderdata = None
    if zip_file.has_record(byteordername):
        byteorderdata = zip_file.get_record(byteordername)
        if byteorderdata not in [b'little', b'big']:
            raise ValueError('Unknown endianness type: ' + byteorderdata.decode())
    elif get_default_load_endianness() == LoadEndianness.LITTLE or \
            get_default_load_endianness() is None:
        byteorderdata = b'little'
    elif get_default_load_endianness() == LoadEndianness.BIG:
        byteorderdata = b'big'
    elif get_default_load_endianness() == LoadEndianness.NATIVE:
        pass
    else:
        raise ValueError('Invalid load endianness type')

    if not zip_file.has_record(byteordername) and \
            get_default_load_endianness() is None and \
            sys.byteorder == 'big':
        # The default behavior was changed
        # See https://github.com/pytorch/pytorch/issues/101688
        warnings.warn("The default load endianness for checkpoints without a byteorder mark "
                      "on big endian machines was changed from 'native' to 'little' endian. "
                      "To avoid this behavior please use "
                      "torch.serialization.set_default_load_endianness to set "
                      "the desired default load endianness.",
                      UserWarning)

    def load_tensor(dtype, nbytes, key, location):
        # `nbytes` is the byte count of the untyped storage; the caller has
        # already multiplied numel by the element size.
        name = f'data/{key}'
        if torch._guards.detect_fake_mode(None) is not None:
            storage = torch.UntypedStorage(nbytes, device='meta')
        elif overall_storage is not None:
            storage_offset = zip_file.get_record_offset(name)
            storage = overall_storage[storage_offset:storage_offset + nbytes]
        else:
            storage = zip_file.get_storage_from_record(name, nbytes, torch.UntypedStorage)._typed_storage()._untyped_storage
        # swap here if byteswapping is needed
        if byteorderdata is not None:
            if byteorderdata.decode() != sys.byteorder:
                storage.byteswap(dtype)

        # TODO: Once we decide to break serialization FC, we can
        # stop wrapping with TypedStorage
        typed_storage = torch.storage.TypedStorage(
            wrap_storage=restore_location(storage, location),
            dtype=dtype,
            _internal=True)

        if typed_storage._data_ptr() != 0:
            loaded_storages[key] = typed_storage

        return typed_storage
" 1354 f'Received object of type "{type(f)}". Please update to Python 3.8.2 or newer to restore this ' 1355 "functionality.") 1356 1357 magic_number = pickle_module.load(f, **pickle_load_args) 1358 if magic_number != MAGIC_NUMBER: 1359 raise RuntimeError("Invalid magic number; corrupt file?") 1360 protocol_version = pickle_module.load(f, **pickle_load_args) 1361 if protocol_version != PROTOCOL_VERSION: 1362 raise RuntimeError(f"Invalid protocol version: {protocol_version}") 1363 1364 _sys_info = pickle_module.load(f, **pickle_load_args) 1365 unpickler = UnpicklerWrapper(f, **pickle_load_args) 1366 unpickler.persistent_load = persistent_load 1367 result = unpickler.load() 1368 1369 deserialized_storage_keys = pickle_module.load(f, **pickle_load_args) 1370 1371 if torch._guards.active_fake_mode() is None: 1372 offset = f.tell() if f_should_read_directly else None 1373 for key in deserialized_storage_keys: 1374 assert key in deserialized_objects 1375 typed_storage = deserialized_objects[key] 1376 typed_storage._untyped_storage._set_from_file( 1377 f, offset, f_should_read_directly, 1378 torch._utils._element_size(typed_storage.dtype)) 1379 if offset is not None: 1380 offset = f.tell() 1381 1382 torch._utils._validate_loaded_sparse_tensors() 1383 1384 return result 1385 1386 1387def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str: 1388 # When using encoding='bytes' in Py3, some **internal** keys stored as 1389 # strings in Py2 are loaded as bytes. This function decodes them with 1390 # ascii encoding, one that Py3 uses by default. 1391 # 1392 # NOTE: This should only be used on internal keys (e.g., `typename` and 1393 # `location` in `persistent_load` below! 1394 if isinstance(bytes_str, bytes): 1395 return bytes_str.decode('ascii') 1396 return bytes_str 1397 1398 1399def _get_restore_location(map_location): 1400 if map_location is None: 1401 restore_location = default_restore_location 1402 elif isinstance(map_location, dict): 1403 def restore_location(storage, location): 1404 location = map_location.get(location, location) 1405 return default_restore_location(storage, location) 1406 elif isinstance(map_location, (str, bytes)): 1407 def restore_location(storage, location): 1408 return default_restore_location(storage, map_location) 1409 elif isinstance(map_location, torch.device): 1410 def restore_location(storage, location): 1411 return default_restore_location(storage, str(map_location)) 1412 else: 1413 def restore_location(storage, location): 1414 result = map_location(storage, location) 1415 if result is None: 1416 result = default_restore_location(storage, location) 1417 return result 1418 return restore_location 1419 1420 1421class StorageType: 1422 def __init__(self, name): 1423 self._dtype = _get_dtype_from_pickle_storage_type(name) 1424 1425 @property 1426 def dtype(self): 1427 return self._dtype 1428 1429 def __str__(self): 1430 return f'StorageType(dtype={self.dtype})' 1431 1432 1433def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', overall_storage=None, **pickle_load_args): 1434 restore_location = _get_restore_location(map_location) 1435 1436 loaded_storages = {} 1437 1438 # check if byteswapping is needed 1439 byteordername = 'byteorder' 1440 byteorderdata = None 1441 if zip_file.has_record(byteordername): 1442 byteorderdata = zip_file.get_record(byteordername) 1443 if byteorderdata not in [b'little', b'big']: 1444 raise ValueError('Unknown endianness type: ' + byteorderdata.decode()) 1445 elif get_default_load_endianness() == LoadEndianness.LITTLE 

    load_module_mapping: Dict[str, str] = {
        # See https://github.com/pytorch/pytorch/pull/51633
        'torch.tensor': 'torch._tensor'
    }

    # Need to subclass Unpickler instead of directly monkey-patching the find_class method
    # because it's marked readonly in pickle.
    # The type: ignore is because mypy can't statically determine the type of this class.
    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]
        # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732
        # Lets us override the imports that pickle uses when unpickling an object.
        # This is useful for maintaining BC if we change a module path that tensor instantiation relies on.
        def find_class(self, mod_name, name):
            if type(name) is str and 'Storage' in name:
                try:
                    return StorageType(name)
                except KeyError:
                    pass
            mod_name = load_module_mapping.get(mod_name, mod_name)
            return super().find_class(mod_name, name)

    # Load the data (which may in turn use `persistent_load` to load tensors)
    data_file = io.BytesIO(zip_file.get_record(pickle_file))

    unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    # Needed for tensors where storage device and rebuild tensor device are
    # not connected (wrapper subclasses and tensors rebuilt using numpy)
    torch._utils._thread_local_state.map_location = map_location
    result = unpickler.load()
    del torch._utils._thread_local_state.map_location

    torch._utils._validate_loaded_sparse_tensors()
    torch._C._log_api_usage_metadata(
        "torch.load.metadata", {"serialization_id": zip_file.serialization_id()}
    )
    return result


def _is_torchscript_zip(zip_file):
    return 'constants.pkl' in zip_file.get_all_records()