1# mypy: allow-untyped-defs 2 3from __future__ import annotations 4 5import collections 6import copy 7import functools 8import io 9import threading 10import warnings 11from typing import Any, cast, Dict as _Dict, Optional as _Optional, Type, TypeVar, Union 12from typing_extensions import Self 13 14import torch 15from torch._utils import _to, _type 16from torch.types import _bool, _int, Storage 17 18 19__all__ = ["TypedStorage", "UntypedStorage"] 20 21 22try: 23 import numpy as np 24 25 HAS_NUMPY = True 26except ModuleNotFoundError: 27 HAS_NUMPY = False 28 np = None # type: ignore[assignment] 29 30 31_share_memory_lock = threading.Lock() 32_share_memory_map: _Dict[int, threading.RLock] = {} 33 34T = TypeVar("T", bound="Union[_StorageBase, TypedStorage]") 35 36 37class _StorageBase: 38 _cdata: Any 39 is_sparse: _bool = False 40 is_sparse_csr: _bool = False 41 device: torch.device 42 # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True) 43 _fake_device: _Optional[torch.device] = None 44 45 def __init__(self, *args, **kwargs): 46 pass 47 48 def __len__(self) -> _int: 49 raise NotImplementedError 50 51 def __getitem__(self, idx): 52 raise NotImplementedError 53 54 def __setitem__(self, *args, **kwargs): 55 raise NotImplementedError 56 57 def copy_(self, source: T, non_blocking: _Optional[_bool] = None) -> T: 58 raise NotImplementedError 59 60 def new(self) -> Union[_StorageBase, TypedStorage]: 61 raise NotImplementedError 62 63 def nbytes(self) -> _int: 64 raise NotImplementedError 65 66 def size(self) -> _int: 67 return self.nbytes() 68 69 def type( 70 self, dtype: _Optional[str] = None, non_blocking: _bool = False 71 ) -> Union[_StorageBase, TypedStorage]: 72 return _type(self, dtype, non_blocking) 73 74 def cuda( 75 self, device=None, non_blocking=False 76 ) -> Union[_StorageBase, TypedStorage]: 77 """Returns a copy of this object in CUDA memory. 78 79 If this object is already in CUDA memory and on the correct device, then 80 no copy is performed and the original object is returned. 81 82 Args: 83 device (int): The destination GPU id. Defaults to the current device. 84 non_blocking (bool): If ``True`` and the source is in pinned memory, 85 the copy will be asynchronous with respect to the host. Otherwise, 86 the argument has no effect. 87 """ 88 device2 = torch.device("cuda", device) if device else torch.device("cuda") 89 return self.to(device=device2, non_blocking=non_blocking) 90 91 def hpu(self, device=None, non_blocking=False) -> Union[_StorageBase, TypedStorage]: 92 """Returns a copy of this object in HPU memory. 93 94 If this object is already in HPU memory and on the correct device, then 95 no copy is performed and the original object is returned. 96 97 Args: 98 device (int): The destination HPU id. Defaults to the current device. 99 non_blocking (bool): If ``True`` and the source is in pinned memory, 100 the copy will be asynchronous with respect to the host. Otherwise, 101 the argument has no effect. 102 """ 103 device2 = torch.device("hpu", device) if device else torch.device("hpu") 104 return self.to(device=device2, non_blocking=non_blocking) 105 106 def element_size(self) -> _int: 107 raise NotImplementedError 108 109 def get_device(self) -> _int: 110 return self.device.index 111 112 def data_ptr(self) -> _int: 113 raise NotImplementedError 114 115 def resizable(self) -> _bool: 116 raise NotImplementedError 117 118 # Defined in torch/csrc/generic/StorageSharing.cpp 119 def _share_filename_cpu_(self, *args, **kwargs): 120 raise NotImplementedError 121 122 def _share_fd_cpu_(self, *args, **kwargs): 123 raise NotImplementedError 124 125 @classmethod 126 def _new_using_filename_cpu(cls: Type[T], size: _int) -> T: 127 raise NotImplementedError 128 129 @classmethod 130 def _new_using_fd_cpu(cls: Type[T], size: _int) -> T: 131 raise NotImplementedError 132 133 @classmethod 134 def from_buffer(cls: Type[T], *args, **kwargs) -> T: 135 raise NotImplementedError 136 137 @classmethod 138 def _new_shared_filename_cpu( 139 cls: Type[T], 140 manager, 141 obj, 142 size, 143 *, 144 device=None, 145 dtype=None, 146 ) -> T: 147 raise NotImplementedError 148 149 @classmethod 150 def _release_ipc_counter_cuda(cls: Type[T], *args, **kwargs) -> T: 151 raise NotImplementedError 152 153 @classmethod 154 def _new_with_weak_ptr(cls: Type[T], *args, **kwargs) -> T: 155 raise NotImplementedError 156 157 def _shared_decref(self) -> Union[_StorageBase, TypedStorage]: 158 raise NotImplementedError 159 160 def _write_file(self, *args, **kwargs): 161 raise NotImplementedError 162 163 def resize_(self, size: _int): 164 raise NotImplementedError 165 166 def _weak_ref(self, *args, **kwargs) -> Union[_StorageBase, TypedStorage]: 167 raise NotImplementedError 168 169 def _set_from_file(self, *args, **kwargs): 170 raise NotImplementedError 171 172 def _set_cdata(self, *args, **kwargs): 173 raise NotImplementedError 174 175 def _share_cuda_(self, *args, **kwargs): 176 raise NotImplementedError 177 178 def is_shared(self) -> _bool: 179 raise NotImplementedError 180 181 @classmethod 182 def _new_shared_cuda(cls: Type[T], *args, **kwargs) -> T: 183 raise NotImplementedError 184 185 def _shared_incref(self, *args, **kwargs): 186 raise NotImplementedError 187 188 @classmethod 189 def _free_weak_ref(cls, *args, **kwargs): 190 raise NotImplementedError 191 192 @property 193 def is_cuda(self): 194 raise NotImplementedError 195 196 @property 197 def is_hpu(self): 198 raise NotImplementedError 199 200 @classmethod 201 def from_file(cls, filename, shared, nbytes) -> Union[_StorageBase, TypedStorage]: 202 raise NotImplementedError 203 204 @classmethod 205 def _expired(cls, *args, **kwargs) -> Union[_StorageBase, TypedStorage]: 206 raise NotImplementedError 207 208 def _byteswap(self, *args, **kwargs): 209 raise NotImplementedError 210 211 def _get_filename(self, *args, **kwargs) -> _Optional[str]: 212 raise NotImplementedError 213 214 def __repr__(self): 215 info_str = f"[{torch.typename(self)}(device={self.device}) of size {len(self)}]" 216 if self.device.type == "meta": 217 return "...\n" + info_str 218 data_str = " " + "\n ".join(str(self[i]) for i in range(self.size())) 219 return data_str + "\n" + info_str 220 221 def __iter__(self): 222 return iter(self[i] for i in range(self.size())) 223 224 def __copy__(self): 225 return self.clone() 226 227 def __deepcopy__(self, memo): 228 memo = memo.setdefault("torch", {}) 229 if self._cdata in memo: 230 return memo[self._cdata] 231 new_storage = self.clone() 232 memo[self._cdata] = new_storage 233 return new_storage 234 235 def __reduce__(self): 236 b = io.BytesIO() 237 torch.save(self, b, _use_new_zipfile_serialization=False) 238 return (_load_from_bytes, (b.getvalue(),)) 239 240 def __sizeof__(self): 241 return super().__sizeof__() + self.size() 242 243 def clone(self): 244 """Return a copy of this storage.""" 245 return type(self)(self.nbytes(), device=self.device).copy_(self) 246 247 def tolist(self): 248 """Return a list containing the elements of this storage.""" 249 return list(self) 250 251 def cpu(self): 252 """Return a CPU copy of this storage if it's not already on the CPU.""" 253 if self.device.type != "cpu": 254 return torch.UntypedStorage(self.size()).copy_(self, False) 255 return self 256 257 def mps(self): 258 """Return a MPS copy of this storage if it's not already on the MPS.""" 259 if self.device.type != "mps": 260 return torch.UntypedStorage(self.size(), device="mps").copy_(self, False) 261 return self 262 263 def _to(self, dtype): 264 if not isinstance(dtype, torch.dtype): 265 raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}") 266 storage = ( 267 torch.tensor([], dtype=torch.uint8, device=self.device) 268 .set_(cast(Storage, self)) 269 .to(dtype) 270 ._typed_storage() 271 ) 272 if storage.data_ptr() == self.data_ptr(): 273 storage = storage.clone() 274 return storage 275 276 def to( 277 self, *, device: torch.device, non_blocking: _bool = False 278 ) -> Union[_StorageBase, TypedStorage]: 279 return _to(self, device, non_blocking) 280 281 def double(self): 282 """Casts this storage to double type.""" 283 return self._to(torch.double) 284 285 def float(self): 286 """Casts this storage to float type.""" 287 return self._to(torch.float) 288 289 def half(self): 290 """Casts this storage to half type.""" 291 return self._to(torch.half) 292 293 def long(self): 294 """Casts this storage to long type.""" 295 return self._to(torch.long) 296 297 def int(self): 298 """Casts this storage to int type.""" 299 return self._to(torch.int) 300 301 def short(self): 302 """Casts this storage to short type.""" 303 return self._to(torch.short) 304 305 def char(self): 306 """Casts this storage to char type.""" 307 return self._to(torch.int8) 308 309 def byte(self): 310 """Casts this storage to byte type.""" 311 return self._to(torch.uint8) 312 313 def bool(self): 314 """Casts this storage to bool type.""" 315 return self._to(torch.bool) 316 317 def bfloat16(self): 318 """Casts this storage to bfloat16 type.""" 319 return self._to(torch.bfloat16) 320 321 def complex_double(self): 322 """Casts this storage to complex double type.""" 323 return self._to(torch.cdouble) 324 325 def complex_float(self): 326 """Casts this storage to complex float type.""" 327 return self._to(torch.cfloat) 328 329 def float8_e5m2(self): 330 """Casts this storage to float8_e5m2 type""" 331 return self._to(torch.float8_e5m2) 332 333 def float8_e4m3fn(self): 334 """Casts this storage to float8_e4m3fn type""" 335 return self._to(torch.float8_e4m3fn) 336 337 def float8_e5m2fnuz(self): 338 """Casts this storage to float8_e5m2fnuz type""" 339 return self._to(torch.float8_e5m2fnuz) 340 341 def float8_e4m3fnuz(self): 342 """Casts this storage to float8_e4m3fnuz type""" 343 return self._to(torch.float8_e4m3fnuz) 344 345 def is_pinned(self, device: Union[str, torch.device] = "cuda"): 346 r"""Determine whether the CPU storage is already pinned on device. 347 348 Args: 349 device (str or torch.device): The device to pin memory on. Default: ``'cuda'``. 350 351 Returns: 352 A boolean variable. 353 """ 354 return ( 355 torch.tensor([], dtype=torch.uint8, device=self.device) 356 .set_(cast(Storage, self)) 357 .is_pinned(device) 358 ) 359 360 def pin_memory(self, device: Union[str, torch.device] = "cuda"): 361 r"""Copy the CPU storage to pinned memory, if it's not already pinned. 362 363 Args: 364 device (str or torch.device): The device to pin memory on. Default: ``'cuda'``. 365 366 Returns: 367 A pinned CPU storage. 368 """ 369 if self.device.type != "cpu": 370 raise TypeError(f"cannot pin '{self.type()}' only CPU memory can be pinned") 371 372 pinned_tensor = ( 373 torch.tensor([], dtype=torch.uint8, device=self.device) 374 .set_(cast(Storage, self)) 375 .pin_memory(device) 376 ) 377 return pinned_tensor.untyped_storage() 378 379 def share_memory_(self): 380 """See :meth:`torch.UntypedStorage.share_memory_`""" 381 from torch.multiprocessing import get_sharing_strategy 382 383 if self.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]: 384 pass # CUDA or PrivateUse1 doesn't use POSIX shared memory 385 elif get_sharing_strategy() == "file_system": 386 self._share_filename_cpu_() 387 else: 388 self._share_fd_cpu_() 389 return self 390 391 @classmethod 392 def _new_shared(cls, size, *, device="cpu"): 393 """Create a new storage in shared memory with the same data type.""" 394 from torch.multiprocessing import get_sharing_strategy 395 396 device = torch.device(device) 397 if device.type in ["cuda", torch._C._get_privateuse1_backend_name(), "hpu"]: 398 return cls(size, device=device) 399 elif get_sharing_strategy() == "file_system": 400 return cls._new_using_filename_cpu(size) 401 else: 402 return cls._new_using_fd_cpu(size) 403 404 def untyped(self): 405 return self 406 407 def byteswap(self, dtype): 408 """Swap bytes in underlying data.""" 409 elem_size = torch._utils._element_size(dtype) 410 # for complex types, don't swap first and second numbers 411 if dtype.is_complex: 412 elem_size = max(int(elem_size / 2), 1) 413 self._byteswap(elem_size) 414 415 416def _share_memory_lock_protected(fn): 417 @functools.wraps(fn) 418 def wrapper(self, *args, **kwargs): 419 to_free = None 420 to_wait = None 421 with _share_memory_lock: 422 key = self._cdata 423 if key in _share_memory_map: 424 to_wait = _share_memory_map[key] 425 else: 426 _share_memory_map[key] = threading.RLock() 427 _share_memory_map[key].acquire() 428 to_free = key 429 430 # If we're already in the process of sharing the storage, wait 431 # for it to be done. 432 if to_wait is not None: 433 with to_wait: 434 pass 435 436 try: 437 return fn(self, *args, **kwargs) 438 finally: 439 # If we acquired the storage lock here and we're done working on it 440 # we can now release it and free the entry. 441 if to_free is not None: 442 # Ensure that the cdata from the storage didn't change and only 443 # the data_ptr did. 444 assert self._cdata == to_free 445 with _share_memory_lock: 446 _share_memory_map[to_free].release() 447 del _share_memory_map[to_free] 448 449 return wrapper 450 451 452class UntypedStorage(torch._C.StorageBase, _StorageBase): 453 def __getitem__(self, *args, **kwargs): 454 if self.device.type == "meta": 455 raise NotImplementedError("Not available for 'meta' device type") 456 return super().__getitem__(*args, **kwargs) 457 458 @property 459 def is_cuda(self): 460 return self.device.type == "cuda" 461 462 @property 463 def is_hpu(self): 464 return self.device.type == "hpu" 465 466 @property 467 def filename(self) -> _Optional[str]: 468 """Returns the file name associated with this storage. 469 470 The file name will be a string if the storage is on CPU and was created via 471 :meth:`~torch.from_file()` with ``shared`` as ``True``. This attribute is ``None`` otherwise. 472 """ 473 return self._get_filename() 474 475 @_share_memory_lock_protected 476 def share_memory_(self, *args, **kwargs): 477 """ 478 Moves the storage to shared memory. 479 480 This is a no-op for storages already in shared memory and for CUDA 481 storages, which do not need to be moved for sharing across processes. 482 Storages in shared memory cannot be resized. 483 484 Note that to mitigate issues like `this <https://github.com/pytorch/pytorch/issues/95606>`_ 485 it is thread safe to call this function from multiple threads on the same object. 486 It is NOT thread safe though to call any other function on self without proper 487 synchronization. Please see :doc:`/notes/multiprocessing` for more details. 488 489 .. note:: 490 When all references to a storage in shared memory are deleted, the associated shared memory 491 object will also be deleted. PyTorch has a special cleanup process to ensure that this happens 492 even if the current process exits unexpectedly. 493 494 It is worth noting the difference between :meth:`share_memory_` and :meth:`from_file` with ``shared = True`` 495 496 #. ``share_memory_`` uses `shm_open(3) <https://man7.org/linux/man-pages/man3/shm_open.3.html>`_ to create a 497 POSIX shared memory object while :meth:`from_file` uses 498 `open(2) <https://man7.org/linux/man-pages/man2/open.2.html>`_ to open the filename passed by the user. 499 #. Both use an `mmap(2) call <https://man7.org/linux/man-pages/man2/mmap.2.html>`_ with ``MAP_SHARED`` 500 to map the file/object into the current virtual address space 501 #. ``share_memory_`` will call ``shm_unlink(3)`` on the object after mapping it to make sure the shared memory 502 object is freed when no process has the object open. ``torch.from_file(shared=True)`` does not unlink the 503 file. This file is persistent and will remain until it is deleted by the user. 504 505 Returns: 506 ``self`` 507 """ 508 return super().share_memory_(*args, **kwargs) 509 510 @_share_memory_lock_protected 511 def _share_fd_cpu_(self, *args, **kwargs): 512 return super()._share_fd_cpu_(*args, **kwargs) 513 514 @_share_memory_lock_protected 515 def _share_filename_cpu_(self, *args, **kwargs): 516 return super()._share_filename_cpu_(*args, **kwargs) 517 518 519def _load_from_bytes(b): 520 return torch.load(io.BytesIO(b), weights_only=False) 521 522 523@functools.lru_cache(maxsize=None) 524def _new_dtypes(): 525 # These are dtypes serialized as UntypedStorage unlike those in 526 # _dtype_to_storage_type_map 527 return { 528 torch.float8_e5m2, 529 torch.float8_e4m3fn, 530 torch.float8_e5m2fnuz, 531 torch.float8_e4m3fnuz, 532 torch.bits8, 533 torch.bits16, 534 torch.bits1x8, 535 torch.bits2x4, 536 torch.bits4x2, 537 torch.complex32, 538 } 539 540 541@functools.lru_cache(maxsize=None) 542def _dtype_to_storage_type_map(): 543 # NOTE: We should no longer add dtypes to this map. This map 544 # is only used for BC/FC with older PyTorch versions. Going forward, 545 # new dtypes of TypedStorage should not translate to a legacy 546 # <type>Storage class. Instead, new dtypes of TypedStorage should 547 # be serialized as an UntypedStorage paired with a torch.dtype 548 return { 549 torch.double: "DoubleStorage", 550 torch.float: "FloatStorage", 551 torch.half: "HalfStorage", 552 torch.long: "LongStorage", 553 torch.int: "IntStorage", 554 torch.int16: "ShortStorage", 555 torch.int8: "CharStorage", 556 torch.uint8: "ByteStorage", 557 torch.bool: "BoolStorage", 558 torch.bfloat16: "BFloat16Storage", 559 torch.cdouble: "ComplexDoubleStorage", 560 torch.cfloat: "ComplexFloatStorage", 561 torch.qint8: "QInt8Storage", 562 torch.qint32: "QInt32Storage", 563 torch.quint8: "QUInt8Storage", 564 torch.quint4x2: "QUInt4x2Storage", 565 torch.quint2x4: "QUInt2x4Storage", 566 } 567 568 569@functools.lru_cache(maxsize=None) 570def _storage_type_to_dtype_map(): 571 dtype_map = {val: key for key, val in _dtype_to_storage_type_map().items()} 572 return dtype_map 573 574 575def _get_storage_from_sequence(sequence, dtype, device): 576 if dtype in [ 577 torch.quint8, 578 torch.quint4x2, 579 torch.quint2x4, 580 torch.qint32, 581 torch.qint8, 582 ]: 583 interpret_dtypes = { 584 torch.quint8: torch.uint8, 585 torch.quint4x2: torch.uint8, 586 torch.quint2x4: torch.uint8, 587 torch.qint32: torch.int32, 588 torch.qint8: torch.int8, 589 } 590 tmp_tensor = torch.tensor( 591 sequence, dtype=interpret_dtypes[dtype], device=device 592 ) 593 594 else: 595 tmp_tensor = torch.tensor(sequence, dtype=dtype, device=device) 596 597 return tmp_tensor._typed_storage()._untyped_storage 598 599 600def _isint(x): 601 if HAS_NUMPY: 602 return isinstance(x, (int, np.integer)) 603 else: 604 return isinstance(x, int) 605 606 607_always_warn_typed_storage_removal = False 608 609 610def _get_always_warn_typed_storage_removal(): 611 return _always_warn_typed_storage_removal 612 613 614def _set_always_warn_typed_storage_removal(always_warn): 615 global _always_warn_typed_storage_removal 616 assert isinstance(always_warn, bool) 617 _always_warn_typed_storage_removal = always_warn 618 619 620def _warn_typed_storage_removal(stacklevel=2): 621 global _always_warn_typed_storage_removal 622 623 def is_first_time(): 624 if not hasattr(_warn_typed_storage_removal, "has_warned"): 625 return True 626 else: 627 return not _warn_typed_storage_removal.__dict__["has_warned"] 628 629 if _get_always_warn_typed_storage_removal() or is_first_time(): 630 message = ( 631 "TypedStorage is deprecated. It will be removed in the future and " 632 "UntypedStorage will be the only storage class. This should only matter " 633 "to you if you are using storages directly. To access UntypedStorage " 634 "directly, use tensor.untyped_storage() instead of tensor.storage()" 635 ) 636 warnings.warn(message, UserWarning, stacklevel=stacklevel + 1) 637 _warn_typed_storage_removal.__dict__["has_warned"] = True 638 639 640def _reset_warn_typed_storage_removal(): 641 _warn_typed_storage_removal.__dict__["has_warned"] = False 642 643 644def _get_device_from_module(module: str): 645 last_part = module.rsplit(".", 1)[-1] 646 if last_part in ["cuda", torch._C._get_privateuse1_backend_name(), "hpu"]: 647 return last_part 648 else: 649 return "cpu" 650 651 652class TypedStorage: 653 is_sparse: _bool = False 654 # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True) 655 _fake_device: _Optional[torch.device] = None 656 657 dtype: torch.dtype 658 659 @property 660 def _dtype(self): 661 return self.dtype 662 663 @property 664 def filename(self) -> _Optional[str]: 665 """Returns the file name associated with this storage if the storage was memory mapped from a file. 666 or ``None`` if the storage was not created by memory mapping a file.""" 667 return self._untyped_storage.filename 668 669 def fill_(self, value): 670 _warn_typed_storage_removal() 671 self._setitem(slice(0, self._size()), value) 672 return self 673 674 def __new__( 675 cls, 676 *args, 677 wrap_storage=None, 678 dtype=None, 679 device=None, 680 _internal=False, 681 ): 682 if not _internal: 683 _warn_typed_storage_removal() 684 685 if cls == torch.storage._LegacyStorage: 686 raise RuntimeError( 687 "Only child classes of _LegacyStorage can be instantiated" 688 ) 689 690 if cls == TypedStorage: 691 return super().__new__(cls) 692 693 else: 694 arg_error_msg = ( 695 f"{cls}.__new__ received an invalid combination " 696 f"of arguments. Expected one of:\n" 697 " * no arguments\n" 698 " * (int size)\n" 699 " * (Sequence data)\n" 700 " * (*, UntypedStorage wrap_storage)" 701 ) 702 703 if device is not None: 704 raise RuntimeError( 705 arg_error_msg + "\nKeyword argument 'device' cannot be specified" 706 ) 707 708 if dtype is not None: 709 raise RuntimeError( 710 arg_error_msg + "\nKeyword argument 'dtype' cannot be specified" 711 ) 712 713 if wrap_storage is None: 714 if len(args) > 1: 715 raise RuntimeError( 716 arg_error_msg + "\nToo many positional arguments" 717 ) 718 719 if ( 720 len(args) == 1 721 and not _isint(args[0]) 722 and not isinstance(args[0], collections.abc.Sequence) 723 ): 724 raise TypeError( 725 arg_error_msg 726 + f"\nArgument type not recognized: {type(args[0])}" 727 ) 728 729 return TypedStorage( 730 *args, 731 dtype=cls._dtype, 732 device=_get_device_from_module(cls.__module__), 733 _internal=True, 734 ) 735 736 else: 737 if len(args) != 0: 738 raise RuntimeError( 739 arg_error_msg 740 + "\nNo positional arguments should be given when using " 741 "'wrap_storage'" 742 ) 743 744 if not isinstance(wrap_storage, torch.UntypedStorage): 745 raise TypeError( 746 arg_error_msg 747 + f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}" 748 ) 749 750 cls_device = _get_device_from_module(cls.__module__) 751 752 if wrap_storage.device.type != cls_device: 753 raise RuntimeError( 754 arg_error_msg 755 + f"\nDevice of 'wrap_storage' must be {cls_device}" 756 f", but got {wrap_storage.device.type}" 757 ) 758 759 return TypedStorage( 760 *args, 761 wrap_storage=wrap_storage, 762 dtype=cls.dtype, 763 _internal=True, 764 ) 765 766 def __init__( 767 self, 768 *args, 769 device=None, 770 dtype=None, 771 wrap_storage=None, 772 _internal=False, 773 ): 774 if not _internal: 775 _warn_typed_storage_removal() 776 arg_error_msg = ( 777 "TypedStorage.__init__ received an invalid combination " 778 "of arguments. Expected one of:\n" 779 " * (*, torch.device device, torch.dtype dtype)\n" 780 " * (int size, *, torch.device device, torch.dtype dtype)\n" 781 " * (Sequence data, *, torch.device device, torch.dtype dtype)\n" 782 " * (*, UntypedStorage wrap_storage, torch.dtype dtype)" 783 ) 784 785 if wrap_storage is not None: 786 if len(args) != 0: 787 raise RuntimeError( 788 arg_error_msg 789 + "\nNo positional arguments should be given when using " 790 "'wrap_storage'" 791 ) 792 793 if dtype is None: 794 raise RuntimeError( 795 arg_error_msg + "\nArgument 'dtype' must be specified" 796 ) 797 798 if not isinstance(dtype, torch.dtype): 799 raise TypeError( 800 arg_error_msg 801 + f"\nArgument 'dtype' must be torch.dtype, not {type(dtype)}" 802 ) 803 804 if device is not None: 805 raise RuntimeError( 806 arg_error_msg 807 + "\nArgument 'device' should not be specified when 'wrap_storage' is given" 808 ) 809 810 self.dtype = dtype 811 812 if not isinstance(wrap_storage, torch.UntypedStorage): 813 raise TypeError( 814 arg_error_msg 815 + f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}" 816 ) 817 818 self._untyped_storage = wrap_storage 819 820 else: 821 self.dtype = torch.get_default_dtype() if dtype is None else dtype 822 device = torch.device("cpu" if device is None else device) 823 824 if self.dtype in [ 825 torch.quint8, 826 torch.quint4x2, 827 torch.quint2x4, 828 torch.qint32, 829 torch.qint8, 830 ]: 831 if device.type == "cuda": 832 raise RuntimeError( 833 "Cannot create CUDA storage with quantized dtype" 834 ) 835 836 if len(args) == 0: 837 self._untyped_storage = torch.UntypedStorage(device=device) 838 839 elif len(args) == 1: 840 if _isint(args[0]): 841 self._untyped_storage = torch.UntypedStorage( 842 int(args[0]) * self._element_size(), device=device 843 ) 844 elif isinstance(args[0], collections.abc.Sequence): 845 self._untyped_storage = _get_storage_from_sequence( 846 args[0], self.dtype, device 847 ) 848 else: 849 raise TypeError( 850 arg_error_msg 851 + f"\nArgument type not recognized: {type(args[0])}" 852 ) 853 854 else: 855 raise RuntimeError(arg_error_msg + "\nToo many positional arguments") 856 857 @property 858 def is_cuda(self): 859 _warn_typed_storage_removal() 860 return self._untyped_storage.device.type == "cuda" 861 862 @property 863 def is_hpu(self): 864 _warn_typed_storage_removal() 865 return self._untyped_storage.device.type == "hpu" 866 867 def untyped(self): 868 """Return the internal :class:`torch.UntypedStorage`.""" 869 _warn_typed_storage_removal() 870 return self._untyped_storage 871 872 def _new_wrapped_storage(self, untyped_storage) -> Self: 873 assert type(untyped_storage) == torch.UntypedStorage 874 875 if type(self) == TypedStorage: 876 return cast( 877 Self, 878 TypedStorage( 879 wrap_storage=untyped_storage, dtype=self.dtype, _internal=True 880 ), 881 ) 882 else: 883 return type(self)(wrap_storage=untyped_storage) 884 885 def __len__(self): 886 _warn_typed_storage_removal() 887 return self._size() 888 889 def _maybe_wrap_index(self, idx, is_stop=False): 890 if idx is None: 891 if is_stop: 892 return self._size() 893 else: 894 return 0 895 896 else: 897 if type(idx) != int: 898 raise TypeError(f"can't index a {type(self)} with {type(idx)}") 899 if is_stop: 900 if (idx > self._size()) or (idx < -self._size()): 901 raise IndexError( 902 f"index {idx} out of range for storage of size {self.size()}" 903 ) 904 if idx > 0: 905 return idx 906 else: 907 return idx % self._size() 908 else: 909 if (idx >= self._size()) or (idx < -self._size()): 910 raise IndexError( 911 f"index {idx} out of range for storage of size {self.size()}" 912 ) 913 return idx % self._size() 914 915 def __setitem__(self, idx, value): 916 _warn_typed_storage_removal() 917 return self._setitem(idx, value) 918 919 def _setitem(self, idx, value): 920 if not isinstance(idx, (int, slice)): 921 raise RuntimeError(f"can't index a {type(self)} with {type(idx)}") 922 if torch.is_storage(value): 923 raise RuntimeError(f"cannot set item with value type {type(value)}") 924 if self.dtype in [ 925 torch.quint8, 926 torch.quint4x2, 927 torch.quint2x4, 928 torch.qint32, 929 torch.qint8, 930 ]: 931 interpret_dtypes = { 932 torch.quint8: torch.uint8, 933 torch.quint4x2: torch.uint8, 934 torch.quint2x4: torch.uint8, 935 torch.qint32: torch.int32, 936 torch.qint8: torch.int8, 937 } 938 tmp_dtype = interpret_dtypes[self.dtype] 939 tmp_tensor = torch.tensor( 940 [], dtype=tmp_dtype, device=self._untyped_storage.device 941 ) 942 tmp_tensor.set_( 943 TypedStorage( 944 wrap_storage=self._untyped_storage, dtype=tmp_dtype, _internal=True 945 ) 946 ) 947 else: 948 tmp_tensor = torch.tensor( 949 [], dtype=self.dtype, device=self._untyped_storage.device 950 ).set_(self) 951 952 tmp_tensor[idx] = value 953 954 def __getitem__(self, idx): 955 _warn_typed_storage_removal() 956 return self._getitem(idx) 957 958 def _getitem(self, idx): 959 if self._untyped_storage.device.type == "meta": 960 raise NotImplementedError("Not available for 'meta' device type") 961 962 # NOTE: Before TypedStorage existed, indexing with a slice used to be 963 # possible for <type>Storage objects. However, it would return 964 # a storage view, which would be a hassle to implement in TypedStorage, 965 # so it was disabled 966 if isinstance(idx, slice): 967 raise RuntimeError( 968 "slices are only supported in UntypedStorage.__getitem__" 969 ) 970 elif not isinstance(idx, int): 971 raise RuntimeError(f"can't index a {type(self)} with {type(idx)}") 972 973 if self.dtype in [ 974 torch.quint8, 975 torch.quint4x2, 976 torch.quint2x4, 977 torch.qint32, 978 torch.qint8, 979 ]: 980 interpret_dtypes = { 981 torch.quint8: torch.uint8, 982 torch.quint4x2: torch.uint8, 983 torch.quint2x4: torch.uint8, 984 torch.qint32: torch.int32, 985 torch.qint8: torch.int8, 986 } 987 return TypedStorage( 988 wrap_storage=self._untyped_storage, 989 dtype=interpret_dtypes[self.dtype], 990 _internal=True, 991 )._getitem(idx) 992 993 idx_wrapped = self._maybe_wrap_index(idx) 994 from torch._subclasses.fake_tensor import unset_fake_temporarily 995 996 with unset_fake_temporarily(): 997 tmp_tensor = torch.tensor( 998 [], dtype=self.dtype, device=self._untyped_storage.device 999 ).set_(self) 1000 return tmp_tensor[idx_wrapped].item() 1001 1002 def copy_(self, source: T, non_blocking: _Optional[bool] = None): 1003 _warn_typed_storage_removal() 1004 if isinstance(source, TypedStorage): 1005 self._untyped_storage.copy_(source._untyped_storage, non_blocking) 1006 else: 1007 self._untyped_storage.copy_(source, non_blocking) 1008 return self 1009 1010 def nbytes(self): 1011 _warn_typed_storage_removal() 1012 return self._nbytes() 1013 1014 # For internal use only, to avoid deprecation warning 1015 def _nbytes(self): 1016 return self._untyped_storage.nbytes() 1017 1018 def type( 1019 self, 1020 dtype: _Optional[str] = None, 1021 non_blocking: bool = False, 1022 ) -> Union[_StorageBase, TypedStorage, str]: 1023 _warn_typed_storage_removal() 1024 if dtype is None: 1025 legacy_class = self._get_legacy_storage_class() 1026 1027 if legacy_class is not None: 1028 return legacy_class.__module__ + "." + legacy_class.__name__ 1029 1030 return ".".join([self.__module__, type(self).__name__]) 1031 1032 else: 1033 return self._untyped_storage.type(dtype, non_blocking) 1034 1035 def cuda(self, device=None, non_blocking=False) -> Self: 1036 _warn_typed_storage_removal() 1037 if self.dtype in [ 1038 torch.quint8, 1039 torch.quint4x2, 1040 torch.quint2x4, 1041 torch.qint32, 1042 torch.qint8, 1043 ]: 1044 raise RuntimeError("Cannot create CUDA storage with quantized dtype") 1045 cuda_storage = self._untyped_storage.cuda(device, non_blocking) 1046 return self._new_wrapped_storage(cuda_storage) 1047 1048 def hpu(self, device=None, non_blocking=False) -> Self: 1049 _warn_typed_storage_removal() 1050 if self.dtype in [ 1051 torch.quint8, 1052 torch.quint4x2, 1053 torch.quint2x4, 1054 torch.qint32, 1055 torch.qint8, 1056 ]: 1057 raise RuntimeError("Cannot create HPU storage with quantized dtype") 1058 hpu_storage = self._untyped_storage.hpu(device, non_blocking) 1059 return self._new_wrapped_storage(hpu_storage) 1060 1061 def to(self, *, device: torch.device, non_blocking: bool = False) -> Self: 1062 _warn_typed_storage_removal() 1063 if self.dtype in [ 1064 torch.quint8, 1065 torch.quint4x2, 1066 torch.quint2x4, 1067 torch.qint32, 1068 torch.qint8, 1069 ]: 1070 raise RuntimeError( 1071 f"Cannot create {device.type.upper()} storage with quantized dtype" 1072 ) 1073 to_storage = self._untyped_storage.to(device=device, non_blocking=non_blocking) 1074 return self._new_wrapped_storage(to_storage) 1075 1076 def element_size(self): 1077 _warn_typed_storage_removal() 1078 return self._element_size() 1079 1080 # For internal use only, to avoid deprecation warning 1081 def _element_size(self): 1082 return torch._utils._element_size(self.dtype) 1083 1084 def get_device(self) -> _int: 1085 _warn_typed_storage_removal() 1086 return self._untyped_storage.get_device() 1087 1088 def __str__(self): 1089 _warn_typed_storage_removal() 1090 info_str = ( 1091 f"[{torch.typename(self)}(dtype={self.dtype}, " 1092 f"device={self.device}) of size {len(self)}]" 1093 ) 1094 if self.device.type == "meta": 1095 return "...\n" + info_str 1096 else: 1097 data_str = " " + "\n ".join(str(self[i]) for i in range(self.size())) 1098 return data_str + "\n" + info_str 1099 1100 def __repr__(self): 1101 _warn_typed_storage_removal() 1102 return str(self) 1103 1104 def __iter__(self): 1105 _warn_typed_storage_removal() 1106 return iter(self[i] for i in range(self.size())) 1107 1108 def __copy__(self): 1109 _warn_typed_storage_removal() 1110 return self._new_wrapped_storage(copy.copy(self._untyped_storage)) 1111 1112 def __deepcopy__(self, memo): 1113 _warn_typed_storage_removal() 1114 return self._deepcopy(memo) 1115 1116 # For internal use only, to avoid deprecation warning 1117 def _deepcopy(self, memo): 1118 return self._new_wrapped_storage(copy.deepcopy(self._untyped_storage, memo)) 1119 1120 def __sizeof__(self): 1121 _warn_typed_storage_removal() 1122 return super().__sizeof__() + self.nbytes() 1123 1124 def clone(self): 1125 """Return a copy of this storage.""" 1126 _warn_typed_storage_removal() 1127 return self._new_wrapped_storage(self._untyped_storage.clone()) 1128 1129 def tolist(self): 1130 """Return a list containing the elements of this storage.""" 1131 _warn_typed_storage_removal() 1132 return list(self) 1133 1134 def cpu(self): 1135 """Return a CPU copy of this storage if it's not already on the CPU.""" 1136 _warn_typed_storage_removal() 1137 return self._new_wrapped_storage(self._untyped_storage.cpu()) 1138 1139 def is_pinned(self, device: Union[str, torch.device] = "cuda"): 1140 r"""Determine whether the CPU TypedStorage is already pinned on device. 1141 1142 Args: 1143 device (str or torch.device): The device to pin memory on. Default: ``'cuda'`` 1144 1145 Returns: 1146 A boolean variable. 1147 """ 1148 _warn_typed_storage_removal() 1149 return self._untyped_storage.is_pinned(device) 1150 1151 def pin_memory(self, device: Union[str, torch.device] = "cuda"): 1152 r"""Copy the CPU TypedStorage to pinned memory, if it's not already pinned. 1153 1154 Args: 1155 device (str or torch.device): The device to pin memory on. Default: ``'cuda'``. 1156 1157 Returns: 1158 A pinned CPU storage. 1159 """ 1160 _warn_typed_storage_removal() 1161 return self._new_wrapped_storage( 1162 self._untyped_storage.pin_memory(device=device) 1163 ) 1164 1165 def share_memory_(self): 1166 """See :meth:`torch.UntypedStorage.share_memory_`""" 1167 _warn_typed_storage_removal() 1168 return self._share_memory_() 1169 1170 # For internal use only, to avoid deprecation warning 1171 def _share_memory_(self): 1172 self._untyped_storage.share_memory_() 1173 return self 1174 1175 def _new_shared(self, size, *, device=None): 1176 """Create a new storage in shared memory with the same data type.""" 1177 if device is None: 1178 device = "cpu" 1179 device = torch.device(device) 1180 untyped_storage = torch.UntypedStorage._new_shared( 1181 size * self._element_size(), device=device 1182 ) 1183 return TypedStorage( 1184 wrap_storage=untyped_storage, dtype=self.dtype, _internal=True 1185 ) 1186 1187 @property 1188 def _cdata(self): 1189 return self._untyped_storage._cdata 1190 1191 @property 1192 def device(self): 1193 _warn_typed_storage_removal() 1194 return self._untyped_storage.device 1195 1196 def size(self): 1197 _warn_typed_storage_removal() 1198 return self._size() 1199 1200 # For internal use only, to avoid deprecation warning 1201 def _size(self): 1202 # NB: don't indirect through __len__, as that requires 1203 # an int to be returned 1204 return self._untyped_storage.nbytes() // self._element_size() 1205 1206 def pickle_storage_type(self): 1207 _warn_typed_storage_removal() 1208 return self._pickle_storage_type() 1209 1210 # For internal use only, to avoid deprecation warning 1211 def _pickle_storage_type(self): 1212 try: 1213 return _dtype_to_storage_type_map()[self.dtype] 1214 except KeyError as e: 1215 raise KeyError(f"dtype {self.dtype} is not recognized") from e 1216 1217 def __reduce__(self): 1218 b = io.BytesIO() 1219 torch.save(self, b, _use_new_zipfile_serialization=False) 1220 return (_load_from_bytes, (b.getvalue(),)) 1221 1222 def data_ptr(self): 1223 _warn_typed_storage_removal() 1224 return self._data_ptr() 1225 1226 # For internal use only, to avoid deprecation warning 1227 def _data_ptr(self): 1228 return self._untyped_storage.data_ptr() 1229 1230 def resizable(self): 1231 _warn_typed_storage_removal() 1232 return self._untyped_storage.resizable() 1233 1234 def resize_(self, size): 1235 _warn_typed_storage_removal() 1236 self._resize_(size) 1237 1238 # For internal use only, to avoid deprecation warning 1239 def _resize_(self, size): 1240 self._untyped_storage.resize_(size * self._element_size()) 1241 1242 @classmethod 1243 def _free_weak_ref(cls, *args, **kwargs): 1244 return UntypedStorage._free_weak_ref(*args, **kwargs) 1245 1246 def _weak_ref(self, *args, **kwargs): 1247 return self._untyped_storage._weak_ref(*args, **kwargs) 1248 1249 @classmethod 1250 def from_buffer(cls, *args, **kwargs): 1251 _warn_typed_storage_removal() 1252 return cls._from_buffer(*args, **kwargs) 1253 1254 @classmethod 1255 def _from_buffer(cls, *args, dtype=None, device=None, **kwargs): 1256 if cls == TypedStorage: 1257 dtype = torch.get_default_dtype() if dtype is None else dtype 1258 device = torch.device("cpu" if device is None else device) 1259 if device.type != "cpu": 1260 raise RuntimeError( 1261 f"TypedStorage.from_buffer: Not available for device {device.type}" 1262 ) 1263 untyped_storage: torch.UntypedStorage = torch.UntypedStorage.from_buffer( 1264 *args, dtype=dtype, **kwargs 1265 ) 1266 1267 else: 1268 if dtype is not None or len(args) == 5: 1269 raise RuntimeError( 1270 "from_buffer: 'dtype' can only be specified in " 1271 "UntypedStorage.from_buffer and TypedStorage.from_buffer" 1272 ) 1273 if device is not None: 1274 raise RuntimeError( 1275 "from_buffer: 'device' can only be specified in " 1276 "UntypedStorage.from_buffer and TypedStorage.from_buffer" 1277 ) 1278 1279 dtype = cls._dtype 1280 untyped_storage = torch.UntypedStorage.from_buffer( 1281 *args, dtype=dtype, **kwargs 1282 ) 1283 1284 return TypedStorage(wrap_storage=untyped_storage, dtype=dtype, _internal=True) 1285 1286 def _to(self, dtype): 1287 if not isinstance(dtype, torch.dtype): 1288 raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}") 1289 storage = ( 1290 torch.tensor([], dtype=self.dtype, device=self.device) 1291 .set_(self) 1292 .to(dtype) 1293 ._typed_storage() 1294 ) 1295 if storage.data_ptr() == self.data_ptr(): 1296 storage = storage.clone() 1297 return storage 1298 1299 def double(self): 1300 """Casts this storage to double type.""" 1301 _warn_typed_storage_removal() 1302 return self._to(torch.double) 1303 1304 def float(self): 1305 """Casts this storage to float type.""" 1306 _warn_typed_storage_removal() 1307 return self._to(torch.float) 1308 1309 def half(self): 1310 """Casts this storage to half type.""" 1311 _warn_typed_storage_removal() 1312 return self._to(torch.half) 1313 1314 def long(self): 1315 """Casts this storage to long type.""" 1316 _warn_typed_storage_removal() 1317 return self._to(torch.long) 1318 1319 def int(self): 1320 """Casts this storage to int type.""" 1321 _warn_typed_storage_removal() 1322 return self._to(torch.int) 1323 1324 def short(self): 1325 """Casts this storage to short type.""" 1326 _warn_typed_storage_removal() 1327 return self._to(torch.short) 1328 1329 def char(self): 1330 """Casts this storage to char type.""" 1331 _warn_typed_storage_removal() 1332 return self._to(torch.int8) 1333 1334 def byte(self): 1335 """Casts this storage to byte type.""" 1336 _warn_typed_storage_removal() 1337 return self._to(torch.uint8) 1338 1339 def bool(self): 1340 """Casts this storage to bool type.""" 1341 _warn_typed_storage_removal() 1342 return self._to(torch.bool) 1343 1344 def bfloat16(self): 1345 """Casts this storage to bfloat16 type.""" 1346 _warn_typed_storage_removal() 1347 return self._to(torch.bfloat16) 1348 1349 def complex_double(self): 1350 """Casts this storage to complex double type.""" 1351 _warn_typed_storage_removal() 1352 return self._to(torch.cdouble) 1353 1354 def complex_float(self): 1355 """Casts this storage to complex float type.""" 1356 _warn_typed_storage_removal() 1357 return self._to(torch.cfloat) 1358 1359 def float8_e5m2(self): 1360 """Casts this storage to float8_e5m2 type""" 1361 _warn_typed_storage_removal() 1362 return self._to(torch.float8_e5m2) 1363 1364 def float8_e4m3fn(self): 1365 """Casts this storage to float8_e4m3fn type""" 1366 _warn_typed_storage_removal() 1367 return self._to(torch.float8_e4m3fn) 1368 1369 def float8_e5m2fnuz(self): 1370 """Casts this storage to float8_e5m2fnuz type""" 1371 _warn_typed_storage_removal() 1372 return self._to(torch.float8_e5m2fnuz) 1373 1374 def float8_e4m3fnuz(self): 1375 """Casts this storage to float8_e4m3fnuz type""" 1376 _warn_typed_storage_removal() 1377 return self._to(torch.float8_e4m3fnuz) 1378 1379 @classmethod 1380 def from_file(cls, filename, shared, size): 1381 """from_file(filename, shared=False, size=0) -> Storage 1382 1383 Creates a CPU storage backed by a memory-mapped file. 1384 1385 If ``shared`` is ``True``, then memory is shared between all processes. 1386 All changes are written to the file. If ``shared`` is ``False``, then the changes on 1387 the storage do not affect the file. 1388 1389 ``size`` is the number of elements in the storage. If ``shared`` is ``False``, 1390 then the file must contain at least ``size * sizeof(Type)`` bytes 1391 (``Type`` is the type of storage). If ``shared`` is ``True`` the file will be created if needed. 1392 1393 Args: 1394 filename (str): file name to map 1395 shared (bool): whether to share memory (whether ``MAP_SHARED`` or ``MAP_PRIVATE`` is passed to the 1396 underlying `mmap(2) call <https://man7.org/linux/man-pages/man2/mmap.2.html>`_) 1397 size (int): number of elements in the storage 1398 """ 1399 _warn_typed_storage_removal() 1400 if cls == TypedStorage: 1401 raise RuntimeError("from_file can only be called on derived classes") 1402 untyped_storage = UntypedStorage.from_file( 1403 filename, shared, size * torch._utils._element_size(cls.dtype) 1404 ) 1405 storage = cls(wrap_storage=untyped_storage) 1406 return storage 1407 1408 @classmethod 1409 def _expired(cls, *args, **kwargs): 1410 return UntypedStorage._expired(*args, **kwargs) 1411 1412 def _write_file(self, *args, **kwargs): 1413 return self._untyped_storage._write_file(*args, **kwargs) 1414 1415 def _set_from_file(self, *args, **kwargs): 1416 return self._untyped_storage._set_from_file(*args, **kwargs) 1417 1418 def _set_cdata(self, *args, **kwargs): 1419 return self._untyped_storage._set_cdata(*args, **kwargs) 1420 1421 def _share_cuda_(self, *args, **kwargs): 1422 return self._untyped_storage._share_cuda_(*args, **kwargs) 1423 1424 def is_shared(self): 1425 _warn_typed_storage_removal() 1426 return self._is_shared() 1427 1428 # For internal use only, to avoid deprecation warning 1429 def _is_shared(self): 1430 return self._untyped_storage.is_shared() 1431 1432 @classmethod 1433 def _new_shared_cuda(cls, *args, **kwargs): 1434 return torch.UntypedStorage._new_shared_cuda(*args, **kwargs) 1435 1436 def _share_filename_cpu_(self, *args, **kwargs): 1437 ( 1438 manager_handle, 1439 storage_handle, 1440 size, 1441 ) = self._untyped_storage._share_filename_cpu_(*args, **kwargs) 1442 return manager_handle, storage_handle, size // self._element_size() 1443 1444 def _shared_decref(self): 1445 self._untyped_storage._shared_decref() 1446 return self 1447 1448 @classmethod 1449 def _release_ipc_counter(cls, *args, device=None, **kwargs): 1450 return torch.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs) 1451 1452 def _shared_incref(self, *args, **kwargs): 1453 return self._untyped_storage._shared_incref(*args, **kwargs) 1454 1455 def _share_fd_cpu_(self, *args, **kwargs): 1456 fd, size = self._untyped_storage._share_fd_cpu_(*args, **kwargs) 1457 return fd, size // self._element_size() 1458 1459 def _get_legacy_storage_class(self): 1460 if self.dtype not in _dtype_to_storage_type_map(): 1461 return None 1462 1463 storage_name = _dtype_to_storage_type_map()[self.dtype] 1464 1465 if self.device.type not in [ 1466 "cpu", 1467 "cuda", 1468 "hpu", 1469 torch._C._get_privateuse1_backend_name(), 1470 ]: 1471 return None 1472 1473 module = ( 1474 torch if self.device.type == "cpu" else getattr(torch, self.device.type) 1475 ) 1476 1477 try: 1478 return getattr(module, storage_name) 1479 except AttributeError: 1480 return None 1481 1482 1483TypedStorage.type.__doc__ = _type.__doc__ 1484TypedStorage.cuda.__doc__ = _StorageBase.cuda.__doc__ 1485TypedStorage.hpu.__doc__ = _StorageBase.hpu.__doc__ 1486TypedStorage.to.__doc__ = _to.__doc__ 1487 1488 1489class _LegacyStorageMeta(type): 1490 dtype: torch.dtype 1491 1492 def __instancecheck__(cls, instance): 1493 if type(instance) == TypedStorage: 1494 cls_device = _get_device_from_module(cls.__module__) 1495 return (cls_device == instance.device.type) and ( 1496 cls.dtype == instance.dtype 1497 ) 1498 return False 1499 1500 1501class _LegacyStorage(TypedStorage, metaclass=_LegacyStorageMeta): 1502 @classmethod 1503 def _new_shared(cls, size): 1504 """Create a new storage in shared memory with the same data type.""" 1505 untyped_storage = torch.UntypedStorage._new_shared(size * cls()._element_size()) 1506 return cls(wrap_storage=untyped_storage) 1507 1508 @classmethod 1509 def _release_ipc_counter(cls, *args, **kwargs): 1510 return torch.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs) 1511 1512 @classmethod 1513 def _new_shared_filename(cls, manager, obj, size): 1514 bytes_size = size * torch._utils._element_size(cls.dtype) 1515 return cls( 1516 wrap_storage=torch.UntypedStorage._new_shared_filename_cpu( 1517 manager, obj, bytes_size 1518 ) 1519 ) 1520 1521 1522def _get_dtype_from_pickle_storage_type(pickle_storage_type: str): 1523 try: 1524 return _storage_type_to_dtype_map()[pickle_storage_type] 1525 except KeyError as e: 1526 raise KeyError( 1527 f'pickle storage type "{pickle_storage_type}" is not recognized' 1528 ) from e 1529