1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Workerimport collections 6*da0073e9SAndroid Build Coastguard Workerimport copy 7*da0073e9SAndroid Build Coastguard Workerimport functools 8*da0073e9SAndroid Build Coastguard Workerimport io 9*da0073e9SAndroid Build Coastguard Workerimport threading 10*da0073e9SAndroid Build Coastguard Workerimport warnings 11*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, cast, Dict as _Dict, Optional as _Optional, Type, TypeVar, Union 12*da0073e9SAndroid Build Coastguard Workerfrom typing_extensions import Self 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Workerimport torch 15*da0073e9SAndroid Build Coastguard Workerfrom torch._utils import _to, _type 16*da0073e9SAndroid Build Coastguard Workerfrom torch.types import _bool, _int, Storage 17*da0073e9SAndroid Build Coastguard Worker 18*da0073e9SAndroid Build Coastguard Worker 19*da0073e9SAndroid Build Coastguard Worker__all__ = ["TypedStorage", "UntypedStorage"] 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Worker 22*da0073e9SAndroid Build Coastguard Workertry: 23*da0073e9SAndroid Build Coastguard Worker import numpy as np 24*da0073e9SAndroid Build Coastguard Worker 25*da0073e9SAndroid Build Coastguard Worker HAS_NUMPY = True 26*da0073e9SAndroid Build Coastguard Workerexcept ModuleNotFoundError: 27*da0073e9SAndroid Build Coastguard Worker HAS_NUMPY = False 28*da0073e9SAndroid Build Coastguard Worker np = None # type: ignore[assignment] 29*da0073e9SAndroid Build Coastguard Worker 30*da0073e9SAndroid Build Coastguard Worker 31*da0073e9SAndroid Build Coastguard Worker_share_memory_lock = threading.Lock() 32*da0073e9SAndroid Build Coastguard Worker_share_memory_map: _Dict[int, 
threading.RLock] = {} 33*da0073e9SAndroid Build Coastguard Worker 34*da0073e9SAndroid Build Coastguard WorkerT = TypeVar("T", bound="Union[_StorageBase, TypedStorage]") 35*da0073e9SAndroid Build Coastguard Worker 36*da0073e9SAndroid Build Coastguard Worker 37*da0073e9SAndroid Build Coastguard Workerclass _StorageBase: 38*da0073e9SAndroid Build Coastguard Worker _cdata: Any 39*da0073e9SAndroid Build Coastguard Worker is_sparse: _bool = False 40*da0073e9SAndroid Build Coastguard Worker is_sparse_csr: _bool = False 41*da0073e9SAndroid Build Coastguard Worker device: torch.device 42*da0073e9SAndroid Build Coastguard Worker # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True) 43*da0073e9SAndroid Build Coastguard Worker _fake_device: _Optional[torch.device] = None 44*da0073e9SAndroid Build Coastguard Worker 45*da0073e9SAndroid Build Coastguard Worker def __init__(self, *args, **kwargs): 46*da0073e9SAndroid Build Coastguard Worker pass 47*da0073e9SAndroid Build Coastguard Worker 48*da0073e9SAndroid Build Coastguard Worker def __len__(self) -> _int: 49*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 50*da0073e9SAndroid Build Coastguard Worker 51*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, idx): 52*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 53*da0073e9SAndroid Build Coastguard Worker 54*da0073e9SAndroid Build Coastguard Worker def __setitem__(self, *args, **kwargs): 55*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard Worker def copy_(self, source: T, non_blocking: _Optional[_bool] = None) -> T: 58*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 59*da0073e9SAndroid Build Coastguard Worker 60*da0073e9SAndroid Build Coastguard Worker def new(self) -> Union[_StorageBase, TypedStorage]: 61*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 
62*da0073e9SAndroid Build Coastguard Worker 63*da0073e9SAndroid Build Coastguard Worker def nbytes(self) -> _int: 64*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 65*da0073e9SAndroid Build Coastguard Worker 66*da0073e9SAndroid Build Coastguard Worker def size(self) -> _int: 67*da0073e9SAndroid Build Coastguard Worker return self.nbytes() 68*da0073e9SAndroid Build Coastguard Worker 69*da0073e9SAndroid Build Coastguard Worker def type( 70*da0073e9SAndroid Build Coastguard Worker self, dtype: _Optional[str] = None, non_blocking: _bool = False 71*da0073e9SAndroid Build Coastguard Worker ) -> Union[_StorageBase, TypedStorage]: 72*da0073e9SAndroid Build Coastguard Worker return _type(self, dtype, non_blocking) 73*da0073e9SAndroid Build Coastguard Worker 74*da0073e9SAndroid Build Coastguard Worker def cuda( 75*da0073e9SAndroid Build Coastguard Worker self, device=None, non_blocking=False 76*da0073e9SAndroid Build Coastguard Worker ) -> Union[_StorageBase, TypedStorage]: 77*da0073e9SAndroid Build Coastguard Worker """Returns a copy of this object in CUDA memory. 78*da0073e9SAndroid Build Coastguard Worker 79*da0073e9SAndroid Build Coastguard Worker If this object is already in CUDA memory and on the correct device, then 80*da0073e9SAndroid Build Coastguard Worker no copy is performed and the original object is returned. 81*da0073e9SAndroid Build Coastguard Worker 82*da0073e9SAndroid Build Coastguard Worker Args: 83*da0073e9SAndroid Build Coastguard Worker device (int): The destination GPU id. Defaults to the current device. 84*da0073e9SAndroid Build Coastguard Worker non_blocking (bool): If ``True`` and the source is in pinned memory, 85*da0073e9SAndroid Build Coastguard Worker the copy will be asynchronous with respect to the host. Otherwise, 86*da0073e9SAndroid Build Coastguard Worker the argument has no effect. 
87*da0073e9SAndroid Build Coastguard Worker """ 88*da0073e9SAndroid Build Coastguard Worker device2 = torch.device("cuda", device) if device else torch.device("cuda") 89*da0073e9SAndroid Build Coastguard Worker return self.to(device=device2, non_blocking=non_blocking) 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Worker def hpu(self, device=None, non_blocking=False) -> Union[_StorageBase, TypedStorage]: 92*da0073e9SAndroid Build Coastguard Worker """Returns a copy of this object in HPU memory. 93*da0073e9SAndroid Build Coastguard Worker 94*da0073e9SAndroid Build Coastguard Worker If this object is already in HPU memory and on the correct device, then 95*da0073e9SAndroid Build Coastguard Worker no copy is performed and the original object is returned. 96*da0073e9SAndroid Build Coastguard Worker 97*da0073e9SAndroid Build Coastguard Worker Args: 98*da0073e9SAndroid Build Coastguard Worker device (int): The destination HPU id. Defaults to the current device. 99*da0073e9SAndroid Build Coastguard Worker non_blocking (bool): If ``True`` and the source is in pinned memory, 100*da0073e9SAndroid Build Coastguard Worker the copy will be asynchronous with respect to the host. Otherwise, 101*da0073e9SAndroid Build Coastguard Worker the argument has no effect. 
102*da0073e9SAndroid Build Coastguard Worker """ 103*da0073e9SAndroid Build Coastguard Worker device2 = torch.device("hpu", device) if device else torch.device("hpu") 104*da0073e9SAndroid Build Coastguard Worker return self.to(device=device2, non_blocking=non_blocking) 105*da0073e9SAndroid Build Coastguard Worker 106*da0073e9SAndroid Build Coastguard Worker def element_size(self) -> _int: 107*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 108*da0073e9SAndroid Build Coastguard Worker 109*da0073e9SAndroid Build Coastguard Worker def get_device(self) -> _int: 110*da0073e9SAndroid Build Coastguard Worker return self.device.index 111*da0073e9SAndroid Build Coastguard Worker 112*da0073e9SAndroid Build Coastguard Worker def data_ptr(self) -> _int: 113*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 114*da0073e9SAndroid Build Coastguard Worker 115*da0073e9SAndroid Build Coastguard Worker def resizable(self) -> _bool: 116*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 117*da0073e9SAndroid Build Coastguard Worker 118*da0073e9SAndroid Build Coastguard Worker # Defined in torch/csrc/generic/StorageSharing.cpp 119*da0073e9SAndroid Build Coastguard Worker def _share_filename_cpu_(self, *args, **kwargs): 120*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 121*da0073e9SAndroid Build Coastguard Worker 122*da0073e9SAndroid Build Coastguard Worker def _share_fd_cpu_(self, *args, **kwargs): 123*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 124*da0073e9SAndroid Build Coastguard Worker 125*da0073e9SAndroid Build Coastguard Worker @classmethod 126*da0073e9SAndroid Build Coastguard Worker def _new_using_filename_cpu(cls: Type[T], size: _int) -> T: 127*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 128*da0073e9SAndroid Build Coastguard Worker 129*da0073e9SAndroid Build Coastguard Worker @classmethod 130*da0073e9SAndroid Build Coastguard Worker def _new_using_fd_cpu(cls: 
Type[T], size: _int) -> T: 131*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 132*da0073e9SAndroid Build Coastguard Worker 133*da0073e9SAndroid Build Coastguard Worker @classmethod 134*da0073e9SAndroid Build Coastguard Worker def from_buffer(cls: Type[T], *args, **kwargs) -> T: 135*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 136*da0073e9SAndroid Build Coastguard Worker 137*da0073e9SAndroid Build Coastguard Worker @classmethod 138*da0073e9SAndroid Build Coastguard Worker def _new_shared_filename_cpu( 139*da0073e9SAndroid Build Coastguard Worker cls: Type[T], 140*da0073e9SAndroid Build Coastguard Worker manager, 141*da0073e9SAndroid Build Coastguard Worker obj, 142*da0073e9SAndroid Build Coastguard Worker size, 143*da0073e9SAndroid Build Coastguard Worker *, 144*da0073e9SAndroid Build Coastguard Worker device=None, 145*da0073e9SAndroid Build Coastguard Worker dtype=None, 146*da0073e9SAndroid Build Coastguard Worker ) -> T: 147*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 148*da0073e9SAndroid Build Coastguard Worker 149*da0073e9SAndroid Build Coastguard Worker @classmethod 150*da0073e9SAndroid Build Coastguard Worker def _release_ipc_counter_cuda(cls: Type[T], *args, **kwargs) -> T: 151*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 152*da0073e9SAndroid Build Coastguard Worker 153*da0073e9SAndroid Build Coastguard Worker @classmethod 154*da0073e9SAndroid Build Coastguard Worker def _new_with_weak_ptr(cls: Type[T], *args, **kwargs) -> T: 155*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 156*da0073e9SAndroid Build Coastguard Worker 157*da0073e9SAndroid Build Coastguard Worker def _shared_decref(self) -> Union[_StorageBase, TypedStorage]: 158*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 159*da0073e9SAndroid Build Coastguard Worker 160*da0073e9SAndroid Build Coastguard Worker def _write_file(self, *args, **kwargs): 161*da0073e9SAndroid Build 
Coastguard Worker raise NotImplementedError 162*da0073e9SAndroid Build Coastguard Worker 163*da0073e9SAndroid Build Coastguard Worker def resize_(self, size: _int): 164*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 165*da0073e9SAndroid Build Coastguard Worker 166*da0073e9SAndroid Build Coastguard Worker def _weak_ref(self, *args, **kwargs) -> Union[_StorageBase, TypedStorage]: 167*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 168*da0073e9SAndroid Build Coastguard Worker 169*da0073e9SAndroid Build Coastguard Worker def _set_from_file(self, *args, **kwargs): 170*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 171*da0073e9SAndroid Build Coastguard Worker 172*da0073e9SAndroid Build Coastguard Worker def _set_cdata(self, *args, **kwargs): 173*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 174*da0073e9SAndroid Build Coastguard Worker 175*da0073e9SAndroid Build Coastguard Worker def _share_cuda_(self, *args, **kwargs): 176*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 177*da0073e9SAndroid Build Coastguard Worker 178*da0073e9SAndroid Build Coastguard Worker def is_shared(self) -> _bool: 179*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 180*da0073e9SAndroid Build Coastguard Worker 181*da0073e9SAndroid Build Coastguard Worker @classmethod 182*da0073e9SAndroid Build Coastguard Worker def _new_shared_cuda(cls: Type[T], *args, **kwargs) -> T: 183*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 184*da0073e9SAndroid Build Coastguard Worker 185*da0073e9SAndroid Build Coastguard Worker def _shared_incref(self, *args, **kwargs): 186*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 187*da0073e9SAndroid Build Coastguard Worker 188*da0073e9SAndroid Build Coastguard Worker @classmethod 189*da0073e9SAndroid Build Coastguard Worker def _free_weak_ref(cls, *args, **kwargs): 190*da0073e9SAndroid Build Coastguard Worker raise 
NotImplementedError 191*da0073e9SAndroid Build Coastguard Worker 192*da0073e9SAndroid Build Coastguard Worker @property 193*da0073e9SAndroid Build Coastguard Worker def is_cuda(self): 194*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 195*da0073e9SAndroid Build Coastguard Worker 196*da0073e9SAndroid Build Coastguard Worker @property 197*da0073e9SAndroid Build Coastguard Worker def is_hpu(self): 198*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 199*da0073e9SAndroid Build Coastguard Worker 200*da0073e9SAndroid Build Coastguard Worker @classmethod 201*da0073e9SAndroid Build Coastguard Worker def from_file(cls, filename, shared, nbytes) -> Union[_StorageBase, TypedStorage]: 202*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 203*da0073e9SAndroid Build Coastguard Worker 204*da0073e9SAndroid Build Coastguard Worker @classmethod 205*da0073e9SAndroid Build Coastguard Worker def _expired(cls, *args, **kwargs) -> Union[_StorageBase, TypedStorage]: 206*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 207*da0073e9SAndroid Build Coastguard Worker 208*da0073e9SAndroid Build Coastguard Worker def _byteswap(self, *args, **kwargs): 209*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 210*da0073e9SAndroid Build Coastguard Worker 211*da0073e9SAndroid Build Coastguard Worker def _get_filename(self, *args, **kwargs) -> _Optional[str]: 212*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError 213*da0073e9SAndroid Build Coastguard Worker 214*da0073e9SAndroid Build Coastguard Worker def __repr__(self): 215*da0073e9SAndroid Build Coastguard Worker info_str = f"[{torch.typename(self)}(device={self.device}) of size {len(self)}]" 216*da0073e9SAndroid Build Coastguard Worker if self.device.type == "meta": 217*da0073e9SAndroid Build Coastguard Worker return "...\n" + info_str 218*da0073e9SAndroid Build Coastguard Worker data_str = " " + "\n ".join(str(self[i]) for i in 
range(self.size())) 219*da0073e9SAndroid Build Coastguard Worker return data_str + "\n" + info_str 220*da0073e9SAndroid Build Coastguard Worker 221*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 222*da0073e9SAndroid Build Coastguard Worker return iter(self[i] for i in range(self.size())) 223*da0073e9SAndroid Build Coastguard Worker 224*da0073e9SAndroid Build Coastguard Worker def __copy__(self): 225*da0073e9SAndroid Build Coastguard Worker return self.clone() 226*da0073e9SAndroid Build Coastguard Worker 227*da0073e9SAndroid Build Coastguard Worker def __deepcopy__(self, memo): 228*da0073e9SAndroid Build Coastguard Worker memo = memo.setdefault("torch", {}) 229*da0073e9SAndroid Build Coastguard Worker if self._cdata in memo: 230*da0073e9SAndroid Build Coastguard Worker return memo[self._cdata] 231*da0073e9SAndroid Build Coastguard Worker new_storage = self.clone() 232*da0073e9SAndroid Build Coastguard Worker memo[self._cdata] = new_storage 233*da0073e9SAndroid Build Coastguard Worker return new_storage 234*da0073e9SAndroid Build Coastguard Worker 235*da0073e9SAndroid Build Coastguard Worker def __reduce__(self): 236*da0073e9SAndroid Build Coastguard Worker b = io.BytesIO() 237*da0073e9SAndroid Build Coastguard Worker torch.save(self, b, _use_new_zipfile_serialization=False) 238*da0073e9SAndroid Build Coastguard Worker return (_load_from_bytes, (b.getvalue(),)) 239*da0073e9SAndroid Build Coastguard Worker 240*da0073e9SAndroid Build Coastguard Worker def __sizeof__(self): 241*da0073e9SAndroid Build Coastguard Worker return super().__sizeof__() + self.size() 242*da0073e9SAndroid Build Coastguard Worker 243*da0073e9SAndroid Build Coastguard Worker def clone(self): 244*da0073e9SAndroid Build Coastguard Worker """Return a copy of this storage.""" 245*da0073e9SAndroid Build Coastguard Worker return type(self)(self.nbytes(), device=self.device).copy_(self) 246*da0073e9SAndroid Build Coastguard Worker 247*da0073e9SAndroid Build Coastguard Worker def 
tolist(self): 248*da0073e9SAndroid Build Coastguard Worker """Return a list containing the elements of this storage.""" 249*da0073e9SAndroid Build Coastguard Worker return list(self) 250*da0073e9SAndroid Build Coastguard Worker 251*da0073e9SAndroid Build Coastguard Worker def cpu(self): 252*da0073e9SAndroid Build Coastguard Worker """Return a CPU copy of this storage if it's not already on the CPU.""" 253*da0073e9SAndroid Build Coastguard Worker if self.device.type != "cpu": 254*da0073e9SAndroid Build Coastguard Worker return torch.UntypedStorage(self.size()).copy_(self, False) 255*da0073e9SAndroid Build Coastguard Worker return self 256*da0073e9SAndroid Build Coastguard Worker 257*da0073e9SAndroid Build Coastguard Worker def mps(self): 258*da0073e9SAndroid Build Coastguard Worker """Return a MPS copy of this storage if it's not already on the MPS.""" 259*da0073e9SAndroid Build Coastguard Worker if self.device.type != "mps": 260*da0073e9SAndroid Build Coastguard Worker return torch.UntypedStorage(self.size(), device="mps").copy_(self, False) 261*da0073e9SAndroid Build Coastguard Worker return self 262*da0073e9SAndroid Build Coastguard Worker 263*da0073e9SAndroid Build Coastguard Worker def _to(self, dtype): 264*da0073e9SAndroid Build Coastguard Worker if not isinstance(dtype, torch.dtype): 265*da0073e9SAndroid Build Coastguard Worker raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}") 266*da0073e9SAndroid Build Coastguard Worker storage = ( 267*da0073e9SAndroid Build Coastguard Worker torch.tensor([], dtype=torch.uint8, device=self.device) 268*da0073e9SAndroid Build Coastguard Worker .set_(cast(Storage, self)) 269*da0073e9SAndroid Build Coastguard Worker .to(dtype) 270*da0073e9SAndroid Build Coastguard Worker ._typed_storage() 271*da0073e9SAndroid Build Coastguard Worker ) 272*da0073e9SAndroid Build Coastguard Worker if storage.data_ptr() == self.data_ptr(): 273*da0073e9SAndroid Build Coastguard Worker storage = storage.clone() 
274*da0073e9SAndroid Build Coastguard Worker return storage 275*da0073e9SAndroid Build Coastguard Worker 276*da0073e9SAndroid Build Coastguard Worker def to( 277*da0073e9SAndroid Build Coastguard Worker self, *, device: torch.device, non_blocking: _bool = False 278*da0073e9SAndroid Build Coastguard Worker ) -> Union[_StorageBase, TypedStorage]: 279*da0073e9SAndroid Build Coastguard Worker return _to(self, device, non_blocking) 280*da0073e9SAndroid Build Coastguard Worker 281*da0073e9SAndroid Build Coastguard Worker def double(self): 282*da0073e9SAndroid Build Coastguard Worker """Casts this storage to double type.""" 283*da0073e9SAndroid Build Coastguard Worker return self._to(torch.double) 284*da0073e9SAndroid Build Coastguard Worker 285*da0073e9SAndroid Build Coastguard Worker def float(self): 286*da0073e9SAndroid Build Coastguard Worker """Casts this storage to float type.""" 287*da0073e9SAndroid Build Coastguard Worker return self._to(torch.float) 288*da0073e9SAndroid Build Coastguard Worker 289*da0073e9SAndroid Build Coastguard Worker def half(self): 290*da0073e9SAndroid Build Coastguard Worker """Casts this storage to half type.""" 291*da0073e9SAndroid Build Coastguard Worker return self._to(torch.half) 292*da0073e9SAndroid Build Coastguard Worker 293*da0073e9SAndroid Build Coastguard Worker def long(self): 294*da0073e9SAndroid Build Coastguard Worker """Casts this storage to long type.""" 295*da0073e9SAndroid Build Coastguard Worker return self._to(torch.long) 296*da0073e9SAndroid Build Coastguard Worker 297*da0073e9SAndroid Build Coastguard Worker def int(self): 298*da0073e9SAndroid Build Coastguard Worker """Casts this storage to int type.""" 299*da0073e9SAndroid Build Coastguard Worker return self._to(torch.int) 300*da0073e9SAndroid Build Coastguard Worker 301*da0073e9SAndroid Build Coastguard Worker def short(self): 302*da0073e9SAndroid Build Coastguard Worker """Casts this storage to short type.""" 303*da0073e9SAndroid Build Coastguard Worker return 
self._to(torch.short) 304*da0073e9SAndroid Build Coastguard Worker 305*da0073e9SAndroid Build Coastguard Worker def char(self): 306*da0073e9SAndroid Build Coastguard Worker """Casts this storage to char type.""" 307*da0073e9SAndroid Build Coastguard Worker return self._to(torch.int8) 308*da0073e9SAndroid Build Coastguard Worker 309*da0073e9SAndroid Build Coastguard Worker def byte(self): 310*da0073e9SAndroid Build Coastguard Worker """Casts this storage to byte type.""" 311*da0073e9SAndroid Build Coastguard Worker return self._to(torch.uint8) 312*da0073e9SAndroid Build Coastguard Worker 313*da0073e9SAndroid Build Coastguard Worker def bool(self): 314*da0073e9SAndroid Build Coastguard Worker """Casts this storage to bool type.""" 315*da0073e9SAndroid Build Coastguard Worker return self._to(torch.bool) 316*da0073e9SAndroid Build Coastguard Worker 317*da0073e9SAndroid Build Coastguard Worker def bfloat16(self): 318*da0073e9SAndroid Build Coastguard Worker """Casts this storage to bfloat16 type.""" 319*da0073e9SAndroid Build Coastguard Worker return self._to(torch.bfloat16) 320*da0073e9SAndroid Build Coastguard Worker 321*da0073e9SAndroid Build Coastguard Worker def complex_double(self): 322*da0073e9SAndroid Build Coastguard Worker """Casts this storage to complex double type.""" 323*da0073e9SAndroid Build Coastguard Worker return self._to(torch.cdouble) 324*da0073e9SAndroid Build Coastguard Worker 325*da0073e9SAndroid Build Coastguard Worker def complex_float(self): 326*da0073e9SAndroid Build Coastguard Worker """Casts this storage to complex float type.""" 327*da0073e9SAndroid Build Coastguard Worker return self._to(torch.cfloat) 328*da0073e9SAndroid Build Coastguard Worker 329*da0073e9SAndroid Build Coastguard Worker def float8_e5m2(self): 330*da0073e9SAndroid Build Coastguard Worker """Casts this storage to float8_e5m2 type""" 331*da0073e9SAndroid Build Coastguard Worker return self._to(torch.float8_e5m2) 332*da0073e9SAndroid Build Coastguard Worker 
333*da0073e9SAndroid Build Coastguard Worker def float8_e4m3fn(self): 334*da0073e9SAndroid Build Coastguard Worker """Casts this storage to float8_e4m3fn type""" 335*da0073e9SAndroid Build Coastguard Worker return self._to(torch.float8_e4m3fn) 336*da0073e9SAndroid Build Coastguard Worker 337*da0073e9SAndroid Build Coastguard Worker def float8_e5m2fnuz(self): 338*da0073e9SAndroid Build Coastguard Worker """Casts this storage to float8_e5m2fnuz type""" 339*da0073e9SAndroid Build Coastguard Worker return self._to(torch.float8_e5m2fnuz) 340*da0073e9SAndroid Build Coastguard Worker 341*da0073e9SAndroid Build Coastguard Worker def float8_e4m3fnuz(self): 342*da0073e9SAndroid Build Coastguard Worker """Casts this storage to float8_e4m3fnuz type""" 343*da0073e9SAndroid Build Coastguard Worker return self._to(torch.float8_e4m3fnuz) 344*da0073e9SAndroid Build Coastguard Worker 345*da0073e9SAndroid Build Coastguard Worker def is_pinned(self, device: Union[str, torch.device] = "cuda"): 346*da0073e9SAndroid Build Coastguard Worker r"""Determine whether the CPU storage is already pinned on device. 347*da0073e9SAndroid Build Coastguard Worker 348*da0073e9SAndroid Build Coastguard Worker Args: 349*da0073e9SAndroid Build Coastguard Worker device (str or torch.device): The device to pin memory on. Default: ``'cuda'``. 350*da0073e9SAndroid Build Coastguard Worker 351*da0073e9SAndroid Build Coastguard Worker Returns: 352*da0073e9SAndroid Build Coastguard Worker A boolean variable. 
353*da0073e9SAndroid Build Coastguard Worker """ 354*da0073e9SAndroid Build Coastguard Worker return ( 355*da0073e9SAndroid Build Coastguard Worker torch.tensor([], dtype=torch.uint8, device=self.device) 356*da0073e9SAndroid Build Coastguard Worker .set_(cast(Storage, self)) 357*da0073e9SAndroid Build Coastguard Worker .is_pinned(device) 358*da0073e9SAndroid Build Coastguard Worker ) 359*da0073e9SAndroid Build Coastguard Worker 360*da0073e9SAndroid Build Coastguard Worker def pin_memory(self, device: Union[str, torch.device] = "cuda"): 361*da0073e9SAndroid Build Coastguard Worker r"""Copy the CPU storage to pinned memory, if it's not already pinned. 362*da0073e9SAndroid Build Coastguard Worker 363*da0073e9SAndroid Build Coastguard Worker Args: 364*da0073e9SAndroid Build Coastguard Worker device (str or torch.device): The device to pin memory on. Default: ``'cuda'``. 365*da0073e9SAndroid Build Coastguard Worker 366*da0073e9SAndroid Build Coastguard Worker Returns: 367*da0073e9SAndroid Build Coastguard Worker A pinned CPU storage. 
368*da0073e9SAndroid Build Coastguard Worker """ 369*da0073e9SAndroid Build Coastguard Worker if self.device.type != "cpu": 370*da0073e9SAndroid Build Coastguard Worker raise TypeError(f"cannot pin '{self.type()}' only CPU memory can be pinned") 371*da0073e9SAndroid Build Coastguard Worker 372*da0073e9SAndroid Build Coastguard Worker pinned_tensor = ( 373*da0073e9SAndroid Build Coastguard Worker torch.tensor([], dtype=torch.uint8, device=self.device) 374*da0073e9SAndroid Build Coastguard Worker .set_(cast(Storage, self)) 375*da0073e9SAndroid Build Coastguard Worker .pin_memory(device) 376*da0073e9SAndroid Build Coastguard Worker ) 377*da0073e9SAndroid Build Coastguard Worker return pinned_tensor.untyped_storage() 378*da0073e9SAndroid Build Coastguard Worker 379*da0073e9SAndroid Build Coastguard Worker def share_memory_(self): 380*da0073e9SAndroid Build Coastguard Worker """See :meth:`torch.UntypedStorage.share_memory_`""" 381*da0073e9SAndroid Build Coastguard Worker from torch.multiprocessing import get_sharing_strategy 382*da0073e9SAndroid Build Coastguard Worker 383*da0073e9SAndroid Build Coastguard Worker if self.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]: 384*da0073e9SAndroid Build Coastguard Worker pass # CUDA or PrivateUse1 doesn't use POSIX shared memory 385*da0073e9SAndroid Build Coastguard Worker elif get_sharing_strategy() == "file_system": 386*da0073e9SAndroid Build Coastguard Worker self._share_filename_cpu_() 387*da0073e9SAndroid Build Coastguard Worker else: 388*da0073e9SAndroid Build Coastguard Worker self._share_fd_cpu_() 389*da0073e9SAndroid Build Coastguard Worker return self 390*da0073e9SAndroid Build Coastguard Worker 391*da0073e9SAndroid Build Coastguard Worker @classmethod 392*da0073e9SAndroid Build Coastguard Worker def _new_shared(cls, size, *, device="cpu"): 393*da0073e9SAndroid Build Coastguard Worker """Create a new storage in shared memory with the same data type.""" 394*da0073e9SAndroid Build Coastguard Worker 
from torch.multiprocessing import get_sharing_strategy 395*da0073e9SAndroid Build Coastguard Worker 396*da0073e9SAndroid Build Coastguard Worker device = torch.device(device) 397*da0073e9SAndroid Build Coastguard Worker if device.type in ["cuda", torch._C._get_privateuse1_backend_name(), "hpu"]: 398*da0073e9SAndroid Build Coastguard Worker return cls(size, device=device) 399*da0073e9SAndroid Build Coastguard Worker elif get_sharing_strategy() == "file_system": 400*da0073e9SAndroid Build Coastguard Worker return cls._new_using_filename_cpu(size) 401*da0073e9SAndroid Build Coastguard Worker else: 402*da0073e9SAndroid Build Coastguard Worker return cls._new_using_fd_cpu(size) 403*da0073e9SAndroid Build Coastguard Worker 404*da0073e9SAndroid Build Coastguard Worker def untyped(self): 405*da0073e9SAndroid Build Coastguard Worker return self 406*da0073e9SAndroid Build Coastguard Worker 407*da0073e9SAndroid Build Coastguard Worker def byteswap(self, dtype): 408*da0073e9SAndroid Build Coastguard Worker """Swap bytes in underlying data.""" 409*da0073e9SAndroid Build Coastguard Worker elem_size = torch._utils._element_size(dtype) 410*da0073e9SAndroid Build Coastguard Worker # for complex types, don't swap first and second numbers 411*da0073e9SAndroid Build Coastguard Worker if dtype.is_complex: 412*da0073e9SAndroid Build Coastguard Worker elem_size = max(int(elem_size / 2), 1) 413*da0073e9SAndroid Build Coastguard Worker self._byteswap(elem_size) 414*da0073e9SAndroid Build Coastguard Worker 415*da0073e9SAndroid Build Coastguard Worker 416*da0073e9SAndroid Build Coastguard Workerdef _share_memory_lock_protected(fn): 417*da0073e9SAndroid Build Coastguard Worker @functools.wraps(fn) 418*da0073e9SAndroid Build Coastguard Worker def wrapper(self, *args, **kwargs): 419*da0073e9SAndroid Build Coastguard Worker to_free = None 420*da0073e9SAndroid Build Coastguard Worker to_wait = None 421*da0073e9SAndroid Build Coastguard Worker with _share_memory_lock: 422*da0073e9SAndroid Build 
Coastguard Worker key = self._cdata 423*da0073e9SAndroid Build Coastguard Worker if key in _share_memory_map: 424*da0073e9SAndroid Build Coastguard Worker to_wait = _share_memory_map[key] 425*da0073e9SAndroid Build Coastguard Worker else: 426*da0073e9SAndroid Build Coastguard Worker _share_memory_map[key] = threading.RLock() 427*da0073e9SAndroid Build Coastguard Worker _share_memory_map[key].acquire() 428*da0073e9SAndroid Build Coastguard Worker to_free = key 429*da0073e9SAndroid Build Coastguard Worker 430*da0073e9SAndroid Build Coastguard Worker # If we're already in the process of sharing the storage, wait 431*da0073e9SAndroid Build Coastguard Worker # for it to be done. 432*da0073e9SAndroid Build Coastguard Worker if to_wait is not None: 433*da0073e9SAndroid Build Coastguard Worker with to_wait: 434*da0073e9SAndroid Build Coastguard Worker pass 435*da0073e9SAndroid Build Coastguard Worker 436*da0073e9SAndroid Build Coastguard Worker try: 437*da0073e9SAndroid Build Coastguard Worker return fn(self, *args, **kwargs) 438*da0073e9SAndroid Build Coastguard Worker finally: 439*da0073e9SAndroid Build Coastguard Worker # If we acquired the storage lock here and we're done working on it 440*da0073e9SAndroid Build Coastguard Worker # we can now release it and free the entry. 441*da0073e9SAndroid Build Coastguard Worker if to_free is not None: 442*da0073e9SAndroid Build Coastguard Worker # Ensure that the cdata from the storage didn't change and only 443*da0073e9SAndroid Build Coastguard Worker # the data_ptr did. 
444*da0073e9SAndroid Build Coastguard Worker assert self._cdata == to_free 445*da0073e9SAndroid Build Coastguard Worker with _share_memory_lock: 446*da0073e9SAndroid Build Coastguard Worker _share_memory_map[to_free].release() 447*da0073e9SAndroid Build Coastguard Worker del _share_memory_map[to_free] 448*da0073e9SAndroid Build Coastguard Worker 449*da0073e9SAndroid Build Coastguard Worker return wrapper 450*da0073e9SAndroid Build Coastguard Worker 451*da0073e9SAndroid Build Coastguard Worker 452*da0073e9SAndroid Build Coastguard Workerclass UntypedStorage(torch._C.StorageBase, _StorageBase): 453*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, *args, **kwargs): 454*da0073e9SAndroid Build Coastguard Worker if self.device.type == "meta": 455*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError("Not available for 'meta' device type") 456*da0073e9SAndroid Build Coastguard Worker return super().__getitem__(*args, **kwargs) 457*da0073e9SAndroid Build Coastguard Worker 458*da0073e9SAndroid Build Coastguard Worker @property 459*da0073e9SAndroid Build Coastguard Worker def is_cuda(self): 460*da0073e9SAndroid Build Coastguard Worker return self.device.type == "cuda" 461*da0073e9SAndroid Build Coastguard Worker 462*da0073e9SAndroid Build Coastguard Worker @property 463*da0073e9SAndroid Build Coastguard Worker def is_hpu(self): 464*da0073e9SAndroid Build Coastguard Worker return self.device.type == "hpu" 465*da0073e9SAndroid Build Coastguard Worker 466*da0073e9SAndroid Build Coastguard Worker @property 467*da0073e9SAndroid Build Coastguard Worker def filename(self) -> _Optional[str]: 468*da0073e9SAndroid Build Coastguard Worker """Returns the file name associated with this storage. 469*da0073e9SAndroid Build Coastguard Worker 470*da0073e9SAndroid Build Coastguard Worker The file name will be a string if the storage is on CPU and was created via 471*da0073e9SAndroid Build Coastguard Worker :meth:`~torch.from_file()` with ``shared`` as ``True``. 
This attribute is ``None`` otherwise. 472*da0073e9SAndroid Build Coastguard Worker """ 473*da0073e9SAndroid Build Coastguard Worker return self._get_filename() 474*da0073e9SAndroid Build Coastguard Worker 475*da0073e9SAndroid Build Coastguard Worker @_share_memory_lock_protected 476*da0073e9SAndroid Build Coastguard Worker def share_memory_(self, *args, **kwargs): 477*da0073e9SAndroid Build Coastguard Worker """ 478*da0073e9SAndroid Build Coastguard Worker Moves the storage to shared memory. 479*da0073e9SAndroid Build Coastguard Worker 480*da0073e9SAndroid Build Coastguard Worker This is a no-op for storages already in shared memory and for CUDA 481*da0073e9SAndroid Build Coastguard Worker storages, which do not need to be moved for sharing across processes. 482*da0073e9SAndroid Build Coastguard Worker Storages in shared memory cannot be resized. 483*da0073e9SAndroid Build Coastguard Worker 484*da0073e9SAndroid Build Coastguard Worker Note that to mitigate issues like `this <https://github.com/pytorch/pytorch/issues/95606>`_ 485*da0073e9SAndroid Build Coastguard Worker it is thread safe to call this function from multiple threads on the same object. 486*da0073e9SAndroid Build Coastguard Worker It is NOT thread safe though to call any other function on self without proper 487*da0073e9SAndroid Build Coastguard Worker synchronization. Please see :doc:`/notes/multiprocessing` for more details. 488*da0073e9SAndroid Build Coastguard Worker 489*da0073e9SAndroid Build Coastguard Worker .. note:: 490*da0073e9SAndroid Build Coastguard Worker When all references to a storage in shared memory are deleted, the associated shared memory 491*da0073e9SAndroid Build Coastguard Worker object will also be deleted. PyTorch has a special cleanup process to ensure that this happens 492*da0073e9SAndroid Build Coastguard Worker even if the current process exits unexpectedly. 
493*da0073e9SAndroid Build Coastguard Worker 494*da0073e9SAndroid Build Coastguard Worker It is worth noting the difference between :meth:`share_memory_` and :meth:`from_file` with ``shared = True`` 495*da0073e9SAndroid Build Coastguard Worker 496*da0073e9SAndroid Build Coastguard Worker #. ``share_memory_`` uses `shm_open(3) <https://man7.org/linux/man-pages/man3/shm_open.3.html>`_ to create a 497*da0073e9SAndroid Build Coastguard Worker POSIX shared memory object while :meth:`from_file` uses 498*da0073e9SAndroid Build Coastguard Worker `open(2) <https://man7.org/linux/man-pages/man2/open.2.html>`_ to open the filename passed by the user. 499*da0073e9SAndroid Build Coastguard Worker #. Both use an `mmap(2) call <https://man7.org/linux/man-pages/man2/mmap.2.html>`_ with ``MAP_SHARED`` 500*da0073e9SAndroid Build Coastguard Worker to map the file/object into the current virtual address space 501*da0073e9SAndroid Build Coastguard Worker #. ``share_memory_`` will call ``shm_unlink(3)`` on the object after mapping it to make sure the shared memory 502*da0073e9SAndroid Build Coastguard Worker object is freed when no process has the object open. ``torch.from_file(shared=True)`` does not unlink the 503*da0073e9SAndroid Build Coastguard Worker file. This file is persistent and will remain until it is deleted by the user. 
504*da0073e9SAndroid Build Coastguard Worker 505*da0073e9SAndroid Build Coastguard Worker Returns: 506*da0073e9SAndroid Build Coastguard Worker ``self`` 507*da0073e9SAndroid Build Coastguard Worker """ 508*da0073e9SAndroid Build Coastguard Worker return super().share_memory_(*args, **kwargs) 509*da0073e9SAndroid Build Coastguard Worker 510*da0073e9SAndroid Build Coastguard Worker @_share_memory_lock_protected 511*da0073e9SAndroid Build Coastguard Worker def _share_fd_cpu_(self, *args, **kwargs): 512*da0073e9SAndroid Build Coastguard Worker return super()._share_fd_cpu_(*args, **kwargs) 513*da0073e9SAndroid Build Coastguard Worker 514*da0073e9SAndroid Build Coastguard Worker @_share_memory_lock_protected 515*da0073e9SAndroid Build Coastguard Worker def _share_filename_cpu_(self, *args, **kwargs): 516*da0073e9SAndroid Build Coastguard Worker return super()._share_filename_cpu_(*args, **kwargs) 517*da0073e9SAndroid Build Coastguard Worker 518*da0073e9SAndroid Build Coastguard Worker 519*da0073e9SAndroid Build Coastguard Workerdef _load_from_bytes(b): 520*da0073e9SAndroid Build Coastguard Worker return torch.load(io.BytesIO(b), weights_only=False) 521*da0073e9SAndroid Build Coastguard Worker 522*da0073e9SAndroid Build Coastguard Worker 523*da0073e9SAndroid Build Coastguard Worker@functools.lru_cache(maxsize=None) 524*da0073e9SAndroid Build Coastguard Workerdef _new_dtypes(): 525*da0073e9SAndroid Build Coastguard Worker # These are dtypes serialized as UntypedStorage unlike those in 526*da0073e9SAndroid Build Coastguard Worker # _dtype_to_storage_type_map 527*da0073e9SAndroid Build Coastguard Worker return { 528*da0073e9SAndroid Build Coastguard Worker torch.float8_e5m2, 529*da0073e9SAndroid Build Coastguard Worker torch.float8_e4m3fn, 530*da0073e9SAndroid Build Coastguard Worker torch.float8_e5m2fnuz, 531*da0073e9SAndroid Build Coastguard Worker torch.float8_e4m3fnuz, 532*da0073e9SAndroid Build Coastguard Worker torch.bits8, 533*da0073e9SAndroid Build Coastguard Worker 
torch.bits16, 534*da0073e9SAndroid Build Coastguard Worker torch.bits1x8, 535*da0073e9SAndroid Build Coastguard Worker torch.bits2x4, 536*da0073e9SAndroid Build Coastguard Worker torch.bits4x2, 537*da0073e9SAndroid Build Coastguard Worker torch.complex32, 538*da0073e9SAndroid Build Coastguard Worker } 539*da0073e9SAndroid Build Coastguard Worker 540*da0073e9SAndroid Build Coastguard Worker 541*da0073e9SAndroid Build Coastguard Worker@functools.lru_cache(maxsize=None) 542*da0073e9SAndroid Build Coastguard Workerdef _dtype_to_storage_type_map(): 543*da0073e9SAndroid Build Coastguard Worker # NOTE: We should no longer add dtypes to this map. This map 544*da0073e9SAndroid Build Coastguard Worker # is only used for BC/FC with older PyTorch versions. Going forward, 545*da0073e9SAndroid Build Coastguard Worker # new dtypes of TypedStorage should not translate to a legacy 546*da0073e9SAndroid Build Coastguard Worker # <type>Storage class. Instead, new dtypes of TypedStorage should 547*da0073e9SAndroid Build Coastguard Worker # be serialized as an UntypedStorage paired with a torch.dtype 548*da0073e9SAndroid Build Coastguard Worker return { 549*da0073e9SAndroid Build Coastguard Worker torch.double: "DoubleStorage", 550*da0073e9SAndroid Build Coastguard Worker torch.float: "FloatStorage", 551*da0073e9SAndroid Build Coastguard Worker torch.half: "HalfStorage", 552*da0073e9SAndroid Build Coastguard Worker torch.long: "LongStorage", 553*da0073e9SAndroid Build Coastguard Worker torch.int: "IntStorage", 554*da0073e9SAndroid Build Coastguard Worker torch.int16: "ShortStorage", 555*da0073e9SAndroid Build Coastguard Worker torch.int8: "CharStorage", 556*da0073e9SAndroid Build Coastguard Worker torch.uint8: "ByteStorage", 557*da0073e9SAndroid Build Coastguard Worker torch.bool: "BoolStorage", 558*da0073e9SAndroid Build Coastguard Worker torch.bfloat16: "BFloat16Storage", 559*da0073e9SAndroid Build Coastguard Worker torch.cdouble: "ComplexDoubleStorage", 560*da0073e9SAndroid Build 
Coastguard Worker torch.cfloat: "ComplexFloatStorage", 561*da0073e9SAndroid Build Coastguard Worker torch.qint8: "QInt8Storage", 562*da0073e9SAndroid Build Coastguard Worker torch.qint32: "QInt32Storage", 563*da0073e9SAndroid Build Coastguard Worker torch.quint8: "QUInt8Storage", 564*da0073e9SAndroid Build Coastguard Worker torch.quint4x2: "QUInt4x2Storage", 565*da0073e9SAndroid Build Coastguard Worker torch.quint2x4: "QUInt2x4Storage", 566*da0073e9SAndroid Build Coastguard Worker } 567*da0073e9SAndroid Build Coastguard Worker 568*da0073e9SAndroid Build Coastguard Worker 569*da0073e9SAndroid Build Coastguard Worker@functools.lru_cache(maxsize=None) 570*da0073e9SAndroid Build Coastguard Workerdef _storage_type_to_dtype_map(): 571*da0073e9SAndroid Build Coastguard Worker dtype_map = {val: key for key, val in _dtype_to_storage_type_map().items()} 572*da0073e9SAndroid Build Coastguard Worker return dtype_map 573*da0073e9SAndroid Build Coastguard Worker 574*da0073e9SAndroid Build Coastguard Worker 575*da0073e9SAndroid Build Coastguard Workerdef _get_storage_from_sequence(sequence, dtype, device): 576*da0073e9SAndroid Build Coastguard Worker if dtype in [ 577*da0073e9SAndroid Build Coastguard Worker torch.quint8, 578*da0073e9SAndroid Build Coastguard Worker torch.quint4x2, 579*da0073e9SAndroid Build Coastguard Worker torch.quint2x4, 580*da0073e9SAndroid Build Coastguard Worker torch.qint32, 581*da0073e9SAndroid Build Coastguard Worker torch.qint8, 582*da0073e9SAndroid Build Coastguard Worker ]: 583*da0073e9SAndroid Build Coastguard Worker interpret_dtypes = { 584*da0073e9SAndroid Build Coastguard Worker torch.quint8: torch.uint8, 585*da0073e9SAndroid Build Coastguard Worker torch.quint4x2: torch.uint8, 586*da0073e9SAndroid Build Coastguard Worker torch.quint2x4: torch.uint8, 587*da0073e9SAndroid Build Coastguard Worker torch.qint32: torch.int32, 588*da0073e9SAndroid Build Coastguard Worker torch.qint8: torch.int8, 589*da0073e9SAndroid Build Coastguard Worker } 
590*da0073e9SAndroid Build Coastguard Worker tmp_tensor = torch.tensor( 591*da0073e9SAndroid Build Coastguard Worker sequence, dtype=interpret_dtypes[dtype], device=device 592*da0073e9SAndroid Build Coastguard Worker ) 593*da0073e9SAndroid Build Coastguard Worker 594*da0073e9SAndroid Build Coastguard Worker else: 595*da0073e9SAndroid Build Coastguard Worker tmp_tensor = torch.tensor(sequence, dtype=dtype, device=device) 596*da0073e9SAndroid Build Coastguard Worker 597*da0073e9SAndroid Build Coastguard Worker return tmp_tensor._typed_storage()._untyped_storage 598*da0073e9SAndroid Build Coastguard Worker 599*da0073e9SAndroid Build Coastguard Worker 600*da0073e9SAndroid Build Coastguard Workerdef _isint(x): 601*da0073e9SAndroid Build Coastguard Worker if HAS_NUMPY: 602*da0073e9SAndroid Build Coastguard Worker return isinstance(x, (int, np.integer)) 603*da0073e9SAndroid Build Coastguard Worker else: 604*da0073e9SAndroid Build Coastguard Worker return isinstance(x, int) 605*da0073e9SAndroid Build Coastguard Worker 606*da0073e9SAndroid Build Coastguard Worker 607*da0073e9SAndroid Build Coastguard Worker_always_warn_typed_storage_removal = False 608*da0073e9SAndroid Build Coastguard Worker 609*da0073e9SAndroid Build Coastguard Worker 610*da0073e9SAndroid Build Coastguard Workerdef _get_always_warn_typed_storage_removal(): 611*da0073e9SAndroid Build Coastguard Worker return _always_warn_typed_storage_removal 612*da0073e9SAndroid Build Coastguard Worker 613*da0073e9SAndroid Build Coastguard Worker 614*da0073e9SAndroid Build Coastguard Workerdef _set_always_warn_typed_storage_removal(always_warn): 615*da0073e9SAndroid Build Coastguard Worker global _always_warn_typed_storage_removal 616*da0073e9SAndroid Build Coastguard Worker assert isinstance(always_warn, bool) 617*da0073e9SAndroid Build Coastguard Worker _always_warn_typed_storage_removal = always_warn 618*da0073e9SAndroid Build Coastguard Worker 619*da0073e9SAndroid Build Coastguard Worker 620*da0073e9SAndroid Build 
Coastguard Workerdef _warn_typed_storage_removal(stacklevel=2): 621*da0073e9SAndroid Build Coastguard Worker global _always_warn_typed_storage_removal 622*da0073e9SAndroid Build Coastguard Worker 623*da0073e9SAndroid Build Coastguard Worker def is_first_time(): 624*da0073e9SAndroid Build Coastguard Worker if not hasattr(_warn_typed_storage_removal, "has_warned"): 625*da0073e9SAndroid Build Coastguard Worker return True 626*da0073e9SAndroid Build Coastguard Worker else: 627*da0073e9SAndroid Build Coastguard Worker return not _warn_typed_storage_removal.__dict__["has_warned"] 628*da0073e9SAndroid Build Coastguard Worker 629*da0073e9SAndroid Build Coastguard Worker if _get_always_warn_typed_storage_removal() or is_first_time(): 630*da0073e9SAndroid Build Coastguard Worker message = ( 631*da0073e9SAndroid Build Coastguard Worker "TypedStorage is deprecated. It will be removed in the future and " 632*da0073e9SAndroid Build Coastguard Worker "UntypedStorage will be the only storage class. This should only matter " 633*da0073e9SAndroid Build Coastguard Worker "to you if you are using storages directly. 
To access UntypedStorage " 634*da0073e9SAndroid Build Coastguard Worker "directly, use tensor.untyped_storage() instead of tensor.storage()" 635*da0073e9SAndroid Build Coastguard Worker ) 636*da0073e9SAndroid Build Coastguard Worker warnings.warn(message, UserWarning, stacklevel=stacklevel + 1) 637*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal.__dict__["has_warned"] = True 638*da0073e9SAndroid Build Coastguard Worker 639*da0073e9SAndroid Build Coastguard Worker 640*da0073e9SAndroid Build Coastguard Workerdef _reset_warn_typed_storage_removal(): 641*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal.__dict__["has_warned"] = False 642*da0073e9SAndroid Build Coastguard Worker 643*da0073e9SAndroid Build Coastguard Worker 644*da0073e9SAndroid Build Coastguard Workerdef _get_device_from_module(module: str): 645*da0073e9SAndroid Build Coastguard Worker last_part = module.rsplit(".", 1)[-1] 646*da0073e9SAndroid Build Coastguard Worker if last_part in ["cuda", torch._C._get_privateuse1_backend_name(), "hpu"]: 647*da0073e9SAndroid Build Coastguard Worker return last_part 648*da0073e9SAndroid Build Coastguard Worker else: 649*da0073e9SAndroid Build Coastguard Worker return "cpu" 650*da0073e9SAndroid Build Coastguard Worker 651*da0073e9SAndroid Build Coastguard Worker 652*da0073e9SAndroid Build Coastguard Workerclass TypedStorage: 653*da0073e9SAndroid Build Coastguard Worker is_sparse: _bool = False 654*da0073e9SAndroid Build Coastguard Worker # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True) 655*da0073e9SAndroid Build Coastguard Worker _fake_device: _Optional[torch.device] = None 656*da0073e9SAndroid Build Coastguard Worker 657*da0073e9SAndroid Build Coastguard Worker dtype: torch.dtype 658*da0073e9SAndroid Build Coastguard Worker 659*da0073e9SAndroid Build Coastguard Worker @property 660*da0073e9SAndroid Build Coastguard Worker def _dtype(self): 661*da0073e9SAndroid Build Coastguard Worker return 
self.dtype 662*da0073e9SAndroid Build Coastguard Worker 663*da0073e9SAndroid Build Coastguard Worker @property 664*da0073e9SAndroid Build Coastguard Worker def filename(self) -> _Optional[str]: 665*da0073e9SAndroid Build Coastguard Worker """Returns the file name associated with this storage if the storage was memory mapped from a file. 666*da0073e9SAndroid Build Coastguard Worker or ``None`` if the storage was not created by memory mapping a file.""" 667*da0073e9SAndroid Build Coastguard Worker return self._untyped_storage.filename 668*da0073e9SAndroid Build Coastguard Worker 669*da0073e9SAndroid Build Coastguard Worker def fill_(self, value): 670*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 671*da0073e9SAndroid Build Coastguard Worker self._setitem(slice(0, self._size()), value) 672*da0073e9SAndroid Build Coastguard Worker return self 673*da0073e9SAndroid Build Coastguard Worker 674*da0073e9SAndroid Build Coastguard Worker def __new__( 675*da0073e9SAndroid Build Coastguard Worker cls, 676*da0073e9SAndroid Build Coastguard Worker *args, 677*da0073e9SAndroid Build Coastguard Worker wrap_storage=None, 678*da0073e9SAndroid Build Coastguard Worker dtype=None, 679*da0073e9SAndroid Build Coastguard Worker device=None, 680*da0073e9SAndroid Build Coastguard Worker _internal=False, 681*da0073e9SAndroid Build Coastguard Worker ): 682*da0073e9SAndroid Build Coastguard Worker if not _internal: 683*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 684*da0073e9SAndroid Build Coastguard Worker 685*da0073e9SAndroid Build Coastguard Worker if cls == torch.storage._LegacyStorage: 686*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 687*da0073e9SAndroid Build Coastguard Worker "Only child classes of _LegacyStorage can be instantiated" 688*da0073e9SAndroid Build Coastguard Worker ) 689*da0073e9SAndroid Build Coastguard Worker 690*da0073e9SAndroid Build Coastguard Worker if cls == TypedStorage: 691*da0073e9SAndroid Build 
Coastguard Worker return super().__new__(cls) 692*da0073e9SAndroid Build Coastguard Worker 693*da0073e9SAndroid Build Coastguard Worker else: 694*da0073e9SAndroid Build Coastguard Worker arg_error_msg = ( 695*da0073e9SAndroid Build Coastguard Worker f"{cls}.__new__ received an invalid combination " 696*da0073e9SAndroid Build Coastguard Worker f"of arguments. Expected one of:\n" 697*da0073e9SAndroid Build Coastguard Worker " * no arguments\n" 698*da0073e9SAndroid Build Coastguard Worker " * (int size)\n" 699*da0073e9SAndroid Build Coastguard Worker " * (Sequence data)\n" 700*da0073e9SAndroid Build Coastguard Worker " * (*, UntypedStorage wrap_storage)" 701*da0073e9SAndroid Build Coastguard Worker ) 702*da0073e9SAndroid Build Coastguard Worker 703*da0073e9SAndroid Build Coastguard Worker if device is not None: 704*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 705*da0073e9SAndroid Build Coastguard Worker arg_error_msg + "\nKeyword argument 'device' cannot be specified" 706*da0073e9SAndroid Build Coastguard Worker ) 707*da0073e9SAndroid Build Coastguard Worker 708*da0073e9SAndroid Build Coastguard Worker if dtype is not None: 709*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 710*da0073e9SAndroid Build Coastguard Worker arg_error_msg + "\nKeyword argument 'dtype' cannot be specified" 711*da0073e9SAndroid Build Coastguard Worker ) 712*da0073e9SAndroid Build Coastguard Worker 713*da0073e9SAndroid Build Coastguard Worker if wrap_storage is None: 714*da0073e9SAndroid Build Coastguard Worker if len(args) > 1: 715*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 716*da0073e9SAndroid Build Coastguard Worker arg_error_msg + "\nToo many positional arguments" 717*da0073e9SAndroid Build Coastguard Worker ) 718*da0073e9SAndroid Build Coastguard Worker 719*da0073e9SAndroid Build Coastguard Worker if ( 720*da0073e9SAndroid Build Coastguard Worker len(args) == 1 721*da0073e9SAndroid Build Coastguard Worker and not _isint(args[0]) 
722*da0073e9SAndroid Build Coastguard Worker and not isinstance(args[0], collections.abc.Sequence) 723*da0073e9SAndroid Build Coastguard Worker ): 724*da0073e9SAndroid Build Coastguard Worker raise TypeError( 725*da0073e9SAndroid Build Coastguard Worker arg_error_msg 726*da0073e9SAndroid Build Coastguard Worker + f"\nArgument type not recognized: {type(args[0])}" 727*da0073e9SAndroid Build Coastguard Worker ) 728*da0073e9SAndroid Build Coastguard Worker 729*da0073e9SAndroid Build Coastguard Worker return TypedStorage( 730*da0073e9SAndroid Build Coastguard Worker *args, 731*da0073e9SAndroid Build Coastguard Worker dtype=cls._dtype, 732*da0073e9SAndroid Build Coastguard Worker device=_get_device_from_module(cls.__module__), 733*da0073e9SAndroid Build Coastguard Worker _internal=True, 734*da0073e9SAndroid Build Coastguard Worker ) 735*da0073e9SAndroid Build Coastguard Worker 736*da0073e9SAndroid Build Coastguard Worker else: 737*da0073e9SAndroid Build Coastguard Worker if len(args) != 0: 738*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 739*da0073e9SAndroid Build Coastguard Worker arg_error_msg 740*da0073e9SAndroid Build Coastguard Worker + "\nNo positional arguments should be given when using " 741*da0073e9SAndroid Build Coastguard Worker "'wrap_storage'" 742*da0073e9SAndroid Build Coastguard Worker ) 743*da0073e9SAndroid Build Coastguard Worker 744*da0073e9SAndroid Build Coastguard Worker if not isinstance(wrap_storage, torch.UntypedStorage): 745*da0073e9SAndroid Build Coastguard Worker raise TypeError( 746*da0073e9SAndroid Build Coastguard Worker arg_error_msg 747*da0073e9SAndroid Build Coastguard Worker + f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}" 748*da0073e9SAndroid Build Coastguard Worker ) 749*da0073e9SAndroid Build Coastguard Worker 750*da0073e9SAndroid Build Coastguard Worker cls_device = _get_device_from_module(cls.__module__) 751*da0073e9SAndroid Build Coastguard Worker 752*da0073e9SAndroid Build 
Coastguard Worker if wrap_storage.device.type != cls_device: 753*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 754*da0073e9SAndroid Build Coastguard Worker arg_error_msg 755*da0073e9SAndroid Build Coastguard Worker + f"\nDevice of 'wrap_storage' must be {cls_device}" 756*da0073e9SAndroid Build Coastguard Worker f", but got {wrap_storage.device.type}" 757*da0073e9SAndroid Build Coastguard Worker ) 758*da0073e9SAndroid Build Coastguard Worker 759*da0073e9SAndroid Build Coastguard Worker return TypedStorage( 760*da0073e9SAndroid Build Coastguard Worker *args, 761*da0073e9SAndroid Build Coastguard Worker wrap_storage=wrap_storage, 762*da0073e9SAndroid Build Coastguard Worker dtype=cls.dtype, 763*da0073e9SAndroid Build Coastguard Worker _internal=True, 764*da0073e9SAndroid Build Coastguard Worker ) 765*da0073e9SAndroid Build Coastguard Worker 766*da0073e9SAndroid Build Coastguard Worker def __init__( 767*da0073e9SAndroid Build Coastguard Worker self, 768*da0073e9SAndroid Build Coastguard Worker *args, 769*da0073e9SAndroid Build Coastguard Worker device=None, 770*da0073e9SAndroid Build Coastguard Worker dtype=None, 771*da0073e9SAndroid Build Coastguard Worker wrap_storage=None, 772*da0073e9SAndroid Build Coastguard Worker _internal=False, 773*da0073e9SAndroid Build Coastguard Worker ): 774*da0073e9SAndroid Build Coastguard Worker if not _internal: 775*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 776*da0073e9SAndroid Build Coastguard Worker arg_error_msg = ( 777*da0073e9SAndroid Build Coastguard Worker "TypedStorage.__init__ received an invalid combination " 778*da0073e9SAndroid Build Coastguard Worker "of arguments. 
Expected one of:\n" 779*da0073e9SAndroid Build Coastguard Worker " * (*, torch.device device, torch.dtype dtype)\n" 780*da0073e9SAndroid Build Coastguard Worker " * (int size, *, torch.device device, torch.dtype dtype)\n" 781*da0073e9SAndroid Build Coastguard Worker " * (Sequence data, *, torch.device device, torch.dtype dtype)\n" 782*da0073e9SAndroid Build Coastguard Worker " * (*, UntypedStorage wrap_storage, torch.dtype dtype)" 783*da0073e9SAndroid Build Coastguard Worker ) 784*da0073e9SAndroid Build Coastguard Worker 785*da0073e9SAndroid Build Coastguard Worker if wrap_storage is not None: 786*da0073e9SAndroid Build Coastguard Worker if len(args) != 0: 787*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 788*da0073e9SAndroid Build Coastguard Worker arg_error_msg 789*da0073e9SAndroid Build Coastguard Worker + "\nNo positional arguments should be given when using " 790*da0073e9SAndroid Build Coastguard Worker "'wrap_storage'" 791*da0073e9SAndroid Build Coastguard Worker ) 792*da0073e9SAndroid Build Coastguard Worker 793*da0073e9SAndroid Build Coastguard Worker if dtype is None: 794*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 795*da0073e9SAndroid Build Coastguard Worker arg_error_msg + "\nArgument 'dtype' must be specified" 796*da0073e9SAndroid Build Coastguard Worker ) 797*da0073e9SAndroid Build Coastguard Worker 798*da0073e9SAndroid Build Coastguard Worker if not isinstance(dtype, torch.dtype): 799*da0073e9SAndroid Build Coastguard Worker raise TypeError( 800*da0073e9SAndroid Build Coastguard Worker arg_error_msg 801*da0073e9SAndroid Build Coastguard Worker + f"\nArgument 'dtype' must be torch.dtype, not {type(dtype)}" 802*da0073e9SAndroid Build Coastguard Worker ) 803*da0073e9SAndroid Build Coastguard Worker 804*da0073e9SAndroid Build Coastguard Worker if device is not None: 805*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 806*da0073e9SAndroid Build Coastguard Worker arg_error_msg 807*da0073e9SAndroid Build 
Coastguard Worker + "\nArgument 'device' should not be specified when 'wrap_storage' is given" 808*da0073e9SAndroid Build Coastguard Worker ) 809*da0073e9SAndroid Build Coastguard Worker 810*da0073e9SAndroid Build Coastguard Worker self.dtype = dtype 811*da0073e9SAndroid Build Coastguard Worker 812*da0073e9SAndroid Build Coastguard Worker if not isinstance(wrap_storage, torch.UntypedStorage): 813*da0073e9SAndroid Build Coastguard Worker raise TypeError( 814*da0073e9SAndroid Build Coastguard Worker arg_error_msg 815*da0073e9SAndroid Build Coastguard Worker + f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}" 816*da0073e9SAndroid Build Coastguard Worker ) 817*da0073e9SAndroid Build Coastguard Worker 818*da0073e9SAndroid Build Coastguard Worker self._untyped_storage = wrap_storage 819*da0073e9SAndroid Build Coastguard Worker 820*da0073e9SAndroid Build Coastguard Worker else: 821*da0073e9SAndroid Build Coastguard Worker self.dtype = torch.get_default_dtype() if dtype is None else dtype 822*da0073e9SAndroid Build Coastguard Worker device = torch.device("cpu" if device is None else device) 823*da0073e9SAndroid Build Coastguard Worker 824*da0073e9SAndroid Build Coastguard Worker if self.dtype in [ 825*da0073e9SAndroid Build Coastguard Worker torch.quint8, 826*da0073e9SAndroid Build Coastguard Worker torch.quint4x2, 827*da0073e9SAndroid Build Coastguard Worker torch.quint2x4, 828*da0073e9SAndroid Build Coastguard Worker torch.qint32, 829*da0073e9SAndroid Build Coastguard Worker torch.qint8, 830*da0073e9SAndroid Build Coastguard Worker ]: 831*da0073e9SAndroid Build Coastguard Worker if device.type == "cuda": 832*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 833*da0073e9SAndroid Build Coastguard Worker "Cannot create CUDA storage with quantized dtype" 834*da0073e9SAndroid Build Coastguard Worker ) 835*da0073e9SAndroid Build Coastguard Worker 836*da0073e9SAndroid Build Coastguard Worker if len(args) == 0: 837*da0073e9SAndroid 
Build Coastguard Worker self._untyped_storage = torch.UntypedStorage(device=device) 838*da0073e9SAndroid Build Coastguard Worker 839*da0073e9SAndroid Build Coastguard Worker elif len(args) == 1: 840*da0073e9SAndroid Build Coastguard Worker if _isint(args[0]): 841*da0073e9SAndroid Build Coastguard Worker self._untyped_storage = torch.UntypedStorage( 842*da0073e9SAndroid Build Coastguard Worker int(args[0]) * self._element_size(), device=device 843*da0073e9SAndroid Build Coastguard Worker ) 844*da0073e9SAndroid Build Coastguard Worker elif isinstance(args[0], collections.abc.Sequence): 845*da0073e9SAndroid Build Coastguard Worker self._untyped_storage = _get_storage_from_sequence( 846*da0073e9SAndroid Build Coastguard Worker args[0], self.dtype, device 847*da0073e9SAndroid Build Coastguard Worker ) 848*da0073e9SAndroid Build Coastguard Worker else: 849*da0073e9SAndroid Build Coastguard Worker raise TypeError( 850*da0073e9SAndroid Build Coastguard Worker arg_error_msg 851*da0073e9SAndroid Build Coastguard Worker + f"\nArgument type not recognized: {type(args[0])}" 852*da0073e9SAndroid Build Coastguard Worker ) 853*da0073e9SAndroid Build Coastguard Worker 854*da0073e9SAndroid Build Coastguard Worker else: 855*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(arg_error_msg + "\nToo many positional arguments") 856*da0073e9SAndroid Build Coastguard Worker 857*da0073e9SAndroid Build Coastguard Worker @property 858*da0073e9SAndroid Build Coastguard Worker def is_cuda(self): 859*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 860*da0073e9SAndroid Build Coastguard Worker return self._untyped_storage.device.type == "cuda" 861*da0073e9SAndroid Build Coastguard Worker 862*da0073e9SAndroid Build Coastguard Worker @property 863*da0073e9SAndroid Build Coastguard Worker def is_hpu(self): 864*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 865*da0073e9SAndroid Build Coastguard Worker return self._untyped_storage.device.type == 
"hpu" 866*da0073e9SAndroid Build Coastguard Worker 867*da0073e9SAndroid Build Coastguard Worker def untyped(self): 868*da0073e9SAndroid Build Coastguard Worker """Return the internal :class:`torch.UntypedStorage`.""" 869*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 870*da0073e9SAndroid Build Coastguard Worker return self._untyped_storage 871*da0073e9SAndroid Build Coastguard Worker 872*da0073e9SAndroid Build Coastguard Worker def _new_wrapped_storage(self, untyped_storage) -> Self: 873*da0073e9SAndroid Build Coastguard Worker assert type(untyped_storage) == torch.UntypedStorage 874*da0073e9SAndroid Build Coastguard Worker 875*da0073e9SAndroid Build Coastguard Worker if type(self) == TypedStorage: 876*da0073e9SAndroid Build Coastguard Worker return cast( 877*da0073e9SAndroid Build Coastguard Worker Self, 878*da0073e9SAndroid Build Coastguard Worker TypedStorage( 879*da0073e9SAndroid Build Coastguard Worker wrap_storage=untyped_storage, dtype=self.dtype, _internal=True 880*da0073e9SAndroid Build Coastguard Worker ), 881*da0073e9SAndroid Build Coastguard Worker ) 882*da0073e9SAndroid Build Coastguard Worker else: 883*da0073e9SAndroid Build Coastguard Worker return type(self)(wrap_storage=untyped_storage) 884*da0073e9SAndroid Build Coastguard Worker 885*da0073e9SAndroid Build Coastguard Worker def __len__(self): 886*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 887*da0073e9SAndroid Build Coastguard Worker return self._size() 888*da0073e9SAndroid Build Coastguard Worker 889*da0073e9SAndroid Build Coastguard Worker def _maybe_wrap_index(self, idx, is_stop=False): 890*da0073e9SAndroid Build Coastguard Worker if idx is None: 891*da0073e9SAndroid Build Coastguard Worker if is_stop: 892*da0073e9SAndroid Build Coastguard Worker return self._size() 893*da0073e9SAndroid Build Coastguard Worker else: 894*da0073e9SAndroid Build Coastguard Worker return 0 895*da0073e9SAndroid Build Coastguard Worker 896*da0073e9SAndroid Build 
Coastguard Worker else: 897*da0073e9SAndroid Build Coastguard Worker if type(idx) != int: 898*da0073e9SAndroid Build Coastguard Worker raise TypeError(f"can't index a {type(self)} with {type(idx)}") 899*da0073e9SAndroid Build Coastguard Worker if is_stop: 900*da0073e9SAndroid Build Coastguard Worker if (idx > self._size()) or (idx < -self._size()): 901*da0073e9SAndroid Build Coastguard Worker raise IndexError( 902*da0073e9SAndroid Build Coastguard Worker f"index {idx} out of range for storage of size {self.size()}" 903*da0073e9SAndroid Build Coastguard Worker ) 904*da0073e9SAndroid Build Coastguard Worker if idx > 0: 905*da0073e9SAndroid Build Coastguard Worker return idx 906*da0073e9SAndroid Build Coastguard Worker else: 907*da0073e9SAndroid Build Coastguard Worker return idx % self._size() 908*da0073e9SAndroid Build Coastguard Worker else: 909*da0073e9SAndroid Build Coastguard Worker if (idx >= self._size()) or (idx < -self._size()): 910*da0073e9SAndroid Build Coastguard Worker raise IndexError( 911*da0073e9SAndroid Build Coastguard Worker f"index {idx} out of range for storage of size {self.size()}" 912*da0073e9SAndroid Build Coastguard Worker ) 913*da0073e9SAndroid Build Coastguard Worker return idx % self._size() 914*da0073e9SAndroid Build Coastguard Worker 915*da0073e9SAndroid Build Coastguard Worker def __setitem__(self, idx, value): 916*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 917*da0073e9SAndroid Build Coastguard Worker return self._setitem(idx, value) 918*da0073e9SAndroid Build Coastguard Worker 919*da0073e9SAndroid Build Coastguard Worker def _setitem(self, idx, value): 920*da0073e9SAndroid Build Coastguard Worker if not isinstance(idx, (int, slice)): 921*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f"can't index a {type(self)} with {type(idx)}") 922*da0073e9SAndroid Build Coastguard Worker if torch.is_storage(value): 923*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f"cannot set item with value 
type {type(value)}") 924*da0073e9SAndroid Build Coastguard Worker if self.dtype in [ 925*da0073e9SAndroid Build Coastguard Worker torch.quint8, 926*da0073e9SAndroid Build Coastguard Worker torch.quint4x2, 927*da0073e9SAndroid Build Coastguard Worker torch.quint2x4, 928*da0073e9SAndroid Build Coastguard Worker torch.qint32, 929*da0073e9SAndroid Build Coastguard Worker torch.qint8, 930*da0073e9SAndroid Build Coastguard Worker ]: 931*da0073e9SAndroid Build Coastguard Worker interpret_dtypes = { 932*da0073e9SAndroid Build Coastguard Worker torch.quint8: torch.uint8, 933*da0073e9SAndroid Build Coastguard Worker torch.quint4x2: torch.uint8, 934*da0073e9SAndroid Build Coastguard Worker torch.quint2x4: torch.uint8, 935*da0073e9SAndroid Build Coastguard Worker torch.qint32: torch.int32, 936*da0073e9SAndroid Build Coastguard Worker torch.qint8: torch.int8, 937*da0073e9SAndroid Build Coastguard Worker } 938*da0073e9SAndroid Build Coastguard Worker tmp_dtype = interpret_dtypes[self.dtype] 939*da0073e9SAndroid Build Coastguard Worker tmp_tensor = torch.tensor( 940*da0073e9SAndroid Build Coastguard Worker [], dtype=tmp_dtype, device=self._untyped_storage.device 941*da0073e9SAndroid Build Coastguard Worker ) 942*da0073e9SAndroid Build Coastguard Worker tmp_tensor.set_( 943*da0073e9SAndroid Build Coastguard Worker TypedStorage( 944*da0073e9SAndroid Build Coastguard Worker wrap_storage=self._untyped_storage, dtype=tmp_dtype, _internal=True 945*da0073e9SAndroid Build Coastguard Worker ) 946*da0073e9SAndroid Build Coastguard Worker ) 947*da0073e9SAndroid Build Coastguard Worker else: 948*da0073e9SAndroid Build Coastguard Worker tmp_tensor = torch.tensor( 949*da0073e9SAndroid Build Coastguard Worker [], dtype=self.dtype, device=self._untyped_storage.device 950*da0073e9SAndroid Build Coastguard Worker ).set_(self) 951*da0073e9SAndroid Build Coastguard Worker 952*da0073e9SAndroid Build Coastguard Worker tmp_tensor[idx] = value 953*da0073e9SAndroid Build Coastguard Worker 
954*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, idx): 955*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 956*da0073e9SAndroid Build Coastguard Worker return self._getitem(idx) 957*da0073e9SAndroid Build Coastguard Worker 958*da0073e9SAndroid Build Coastguard Worker def _getitem(self, idx): 959*da0073e9SAndroid Build Coastguard Worker if self._untyped_storage.device.type == "meta": 960*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError("Not available for 'meta' device type") 961*da0073e9SAndroid Build Coastguard Worker 962*da0073e9SAndroid Build Coastguard Worker # NOTE: Before TypedStorage existed, indexing with a slice used to be 963*da0073e9SAndroid Build Coastguard Worker # possible for <type>Storage objects. However, it would return 964*da0073e9SAndroid Build Coastguard Worker # a storage view, which would be a hassle to implement in TypedStorage, 965*da0073e9SAndroid Build Coastguard Worker # so it was disabled 966*da0073e9SAndroid Build Coastguard Worker if isinstance(idx, slice): 967*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 968*da0073e9SAndroid Build Coastguard Worker "slices are only supported in UntypedStorage.__getitem__" 969*da0073e9SAndroid Build Coastguard Worker ) 970*da0073e9SAndroid Build Coastguard Worker elif not isinstance(idx, int): 971*da0073e9SAndroid Build Coastguard Worker raise RuntimeError(f"can't index a {type(self)} with {type(idx)}") 972*da0073e9SAndroid Build Coastguard Worker 973*da0073e9SAndroid Build Coastguard Worker if self.dtype in [ 974*da0073e9SAndroid Build Coastguard Worker torch.quint8, 975*da0073e9SAndroid Build Coastguard Worker torch.quint4x2, 976*da0073e9SAndroid Build Coastguard Worker torch.quint2x4, 977*da0073e9SAndroid Build Coastguard Worker torch.qint32, 978*da0073e9SAndroid Build Coastguard Worker torch.qint8, 979*da0073e9SAndroid Build Coastguard Worker ]: 980*da0073e9SAndroid Build Coastguard Worker interpret_dtypes = { 
981*da0073e9SAndroid Build Coastguard Worker torch.quint8: torch.uint8, 982*da0073e9SAndroid Build Coastguard Worker torch.quint4x2: torch.uint8, 983*da0073e9SAndroid Build Coastguard Worker torch.quint2x4: torch.uint8, 984*da0073e9SAndroid Build Coastguard Worker torch.qint32: torch.int32, 985*da0073e9SAndroid Build Coastguard Worker torch.qint8: torch.int8, 986*da0073e9SAndroid Build Coastguard Worker } 987*da0073e9SAndroid Build Coastguard Worker return TypedStorage( 988*da0073e9SAndroid Build Coastguard Worker wrap_storage=self._untyped_storage, 989*da0073e9SAndroid Build Coastguard Worker dtype=interpret_dtypes[self.dtype], 990*da0073e9SAndroid Build Coastguard Worker _internal=True, 991*da0073e9SAndroid Build Coastguard Worker )._getitem(idx) 992*da0073e9SAndroid Build Coastguard Worker 993*da0073e9SAndroid Build Coastguard Worker idx_wrapped = self._maybe_wrap_index(idx) 994*da0073e9SAndroid Build Coastguard Worker from torch._subclasses.fake_tensor import unset_fake_temporarily 995*da0073e9SAndroid Build Coastguard Worker 996*da0073e9SAndroid Build Coastguard Worker with unset_fake_temporarily(): 997*da0073e9SAndroid Build Coastguard Worker tmp_tensor = torch.tensor( 998*da0073e9SAndroid Build Coastguard Worker [], dtype=self.dtype, device=self._untyped_storage.device 999*da0073e9SAndroid Build Coastguard Worker ).set_(self) 1000*da0073e9SAndroid Build Coastguard Worker return tmp_tensor[idx_wrapped].item() 1001*da0073e9SAndroid Build Coastguard Worker 1002*da0073e9SAndroid Build Coastguard Worker def copy_(self, source: T, non_blocking: _Optional[bool] = None): 1003*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1004*da0073e9SAndroid Build Coastguard Worker if isinstance(source, TypedStorage): 1005*da0073e9SAndroid Build Coastguard Worker self._untyped_storage.copy_(source._untyped_storage, non_blocking) 1006*da0073e9SAndroid Build Coastguard Worker else: 1007*da0073e9SAndroid Build Coastguard Worker 
self._untyped_storage.copy_(source, non_blocking) 1008*da0073e9SAndroid Build Coastguard Worker return self 1009*da0073e9SAndroid Build Coastguard Worker 1010*da0073e9SAndroid Build Coastguard Worker def nbytes(self): 1011*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1012*da0073e9SAndroid Build Coastguard Worker return self._nbytes() 1013*da0073e9SAndroid Build Coastguard Worker 1014*da0073e9SAndroid Build Coastguard Worker # For internal use only, to avoid deprecation warning 1015*da0073e9SAndroid Build Coastguard Worker def _nbytes(self): 1016*da0073e9SAndroid Build Coastguard Worker return self._untyped_storage.nbytes() 1017*da0073e9SAndroid Build Coastguard Worker 1018*da0073e9SAndroid Build Coastguard Worker def type( 1019*da0073e9SAndroid Build Coastguard Worker self, 1020*da0073e9SAndroid Build Coastguard Worker dtype: _Optional[str] = None, 1021*da0073e9SAndroid Build Coastguard Worker non_blocking: bool = False, 1022*da0073e9SAndroid Build Coastguard Worker ) -> Union[_StorageBase, TypedStorage, str]: 1023*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1024*da0073e9SAndroid Build Coastguard Worker if dtype is None: 1025*da0073e9SAndroid Build Coastguard Worker legacy_class = self._get_legacy_storage_class() 1026*da0073e9SAndroid Build Coastguard Worker 1027*da0073e9SAndroid Build Coastguard Worker if legacy_class is not None: 1028*da0073e9SAndroid Build Coastguard Worker return legacy_class.__module__ + "." 
+ legacy_class.__name__ 1029*da0073e9SAndroid Build Coastguard Worker 1030*da0073e9SAndroid Build Coastguard Worker return ".".join([self.__module__, type(self).__name__]) 1031*da0073e9SAndroid Build Coastguard Worker 1032*da0073e9SAndroid Build Coastguard Worker else: 1033*da0073e9SAndroid Build Coastguard Worker return self._untyped_storage.type(dtype, non_blocking) 1034*da0073e9SAndroid Build Coastguard Worker 1035*da0073e9SAndroid Build Coastguard Worker def cuda(self, device=None, non_blocking=False) -> Self: 1036*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1037*da0073e9SAndroid Build Coastguard Worker if self.dtype in [ 1038*da0073e9SAndroid Build Coastguard Worker torch.quint8, 1039*da0073e9SAndroid Build Coastguard Worker torch.quint4x2, 1040*da0073e9SAndroid Build Coastguard Worker torch.quint2x4, 1041*da0073e9SAndroid Build Coastguard Worker torch.qint32, 1042*da0073e9SAndroid Build Coastguard Worker torch.qint8, 1043*da0073e9SAndroid Build Coastguard Worker ]: 1044*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Cannot create CUDA storage with quantized dtype") 1045*da0073e9SAndroid Build Coastguard Worker cuda_storage = self._untyped_storage.cuda(device, non_blocking) 1046*da0073e9SAndroid Build Coastguard Worker return self._new_wrapped_storage(cuda_storage) 1047*da0073e9SAndroid Build Coastguard Worker 1048*da0073e9SAndroid Build Coastguard Worker def hpu(self, device=None, non_blocking=False) -> Self: 1049*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1050*da0073e9SAndroid Build Coastguard Worker if self.dtype in [ 1051*da0073e9SAndroid Build Coastguard Worker torch.quint8, 1052*da0073e9SAndroid Build Coastguard Worker torch.quint4x2, 1053*da0073e9SAndroid Build Coastguard Worker torch.quint2x4, 1054*da0073e9SAndroid Build Coastguard Worker torch.qint32, 1055*da0073e9SAndroid Build Coastguard Worker torch.qint8, 1056*da0073e9SAndroid Build Coastguard Worker ]: 1057*da0073e9SAndroid 
Build Coastguard Worker raise RuntimeError("Cannot create HPU storage with quantized dtype") 1058*da0073e9SAndroid Build Coastguard Worker hpu_storage = self._untyped_storage.hpu(device, non_blocking) 1059*da0073e9SAndroid Build Coastguard Worker return self._new_wrapped_storage(hpu_storage) 1060*da0073e9SAndroid Build Coastguard Worker 1061*da0073e9SAndroid Build Coastguard Worker def to(self, *, device: torch.device, non_blocking: bool = False) -> Self: 1062*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1063*da0073e9SAndroid Build Coastguard Worker if self.dtype in [ 1064*da0073e9SAndroid Build Coastguard Worker torch.quint8, 1065*da0073e9SAndroid Build Coastguard Worker torch.quint4x2, 1066*da0073e9SAndroid Build Coastguard Worker torch.quint2x4, 1067*da0073e9SAndroid Build Coastguard Worker torch.qint32, 1068*da0073e9SAndroid Build Coastguard Worker torch.qint8, 1069*da0073e9SAndroid Build Coastguard Worker ]: 1070*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 1071*da0073e9SAndroid Build Coastguard Worker f"Cannot create {device.type.upper()} storage with quantized dtype" 1072*da0073e9SAndroid Build Coastguard Worker ) 1073*da0073e9SAndroid Build Coastguard Worker to_storage = self._untyped_storage.to(device=device, non_blocking=non_blocking) 1074*da0073e9SAndroid Build Coastguard Worker return self._new_wrapped_storage(to_storage) 1075*da0073e9SAndroid Build Coastguard Worker 1076*da0073e9SAndroid Build Coastguard Worker def element_size(self): 1077*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1078*da0073e9SAndroid Build Coastguard Worker return self._element_size() 1079*da0073e9SAndroid Build Coastguard Worker 1080*da0073e9SAndroid Build Coastguard Worker # For internal use only, to avoid deprecation warning 1081*da0073e9SAndroid Build Coastguard Worker def _element_size(self): 1082*da0073e9SAndroid Build Coastguard Worker return torch._utils._element_size(self.dtype) 1083*da0073e9SAndroid 
Build Coastguard Worker 1084*da0073e9SAndroid Build Coastguard Worker def get_device(self) -> _int: 1085*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1086*da0073e9SAndroid Build Coastguard Worker return self._untyped_storage.get_device() 1087*da0073e9SAndroid Build Coastguard Worker 1088*da0073e9SAndroid Build Coastguard Worker def __str__(self): 1089*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1090*da0073e9SAndroid Build Coastguard Worker info_str = ( 1091*da0073e9SAndroid Build Coastguard Worker f"[{torch.typename(self)}(dtype={self.dtype}, " 1092*da0073e9SAndroid Build Coastguard Worker f"device={self.device}) of size {len(self)}]" 1093*da0073e9SAndroid Build Coastguard Worker ) 1094*da0073e9SAndroid Build Coastguard Worker if self.device.type == "meta": 1095*da0073e9SAndroid Build Coastguard Worker return "...\n" + info_str 1096*da0073e9SAndroid Build Coastguard Worker else: 1097*da0073e9SAndroid Build Coastguard Worker data_str = " " + "\n ".join(str(self[i]) for i in range(self.size())) 1098*da0073e9SAndroid Build Coastguard Worker return data_str + "\n" + info_str 1099*da0073e9SAndroid Build Coastguard Worker 1100*da0073e9SAndroid Build Coastguard Worker def __repr__(self): 1101*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1102*da0073e9SAndroid Build Coastguard Worker return str(self) 1103*da0073e9SAndroid Build Coastguard Worker 1104*da0073e9SAndroid Build Coastguard Worker def __iter__(self): 1105*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1106*da0073e9SAndroid Build Coastguard Worker return iter(self[i] for i in range(self.size())) 1107*da0073e9SAndroid Build Coastguard Worker 1108*da0073e9SAndroid Build Coastguard Worker def __copy__(self): 1109*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1110*da0073e9SAndroid Build Coastguard Worker return self._new_wrapped_storage(copy.copy(self._untyped_storage)) 1111*da0073e9SAndroid 
Build Coastguard Worker 1112*da0073e9SAndroid Build Coastguard Worker def __deepcopy__(self, memo): 1113*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1114*da0073e9SAndroid Build Coastguard Worker return self._deepcopy(memo) 1115*da0073e9SAndroid Build Coastguard Worker 1116*da0073e9SAndroid Build Coastguard Worker # For internal use only, to avoid deprecation warning 1117*da0073e9SAndroid Build Coastguard Worker def _deepcopy(self, memo): 1118*da0073e9SAndroid Build Coastguard Worker return self._new_wrapped_storage(copy.deepcopy(self._untyped_storage, memo)) 1119*da0073e9SAndroid Build Coastguard Worker 1120*da0073e9SAndroid Build Coastguard Worker def __sizeof__(self): 1121*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1122*da0073e9SAndroid Build Coastguard Worker return super().__sizeof__() + self.nbytes() 1123*da0073e9SAndroid Build Coastguard Worker 1124*da0073e9SAndroid Build Coastguard Worker def clone(self): 1125*da0073e9SAndroid Build Coastguard Worker """Return a copy of this storage.""" 1126*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1127*da0073e9SAndroid Build Coastguard Worker return self._new_wrapped_storage(self._untyped_storage.clone()) 1128*da0073e9SAndroid Build Coastguard Worker 1129*da0073e9SAndroid Build Coastguard Worker def tolist(self): 1130*da0073e9SAndroid Build Coastguard Worker """Return a list containing the elements of this storage.""" 1131*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1132*da0073e9SAndroid Build Coastguard Worker return list(self) 1133*da0073e9SAndroid Build Coastguard Worker 1134*da0073e9SAndroid Build Coastguard Worker def cpu(self): 1135*da0073e9SAndroid Build Coastguard Worker """Return a CPU copy of this storage if it's not already on the CPU.""" 1136*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1137*da0073e9SAndroid Build Coastguard Worker return 
self._new_wrapped_storage(self._untyped_storage.cpu()) 1138*da0073e9SAndroid Build Coastguard Worker 1139*da0073e9SAndroid Build Coastguard Worker def is_pinned(self, device: Union[str, torch.device] = "cuda"): 1140*da0073e9SAndroid Build Coastguard Worker r"""Determine whether the CPU TypedStorage is already pinned on device. 1141*da0073e9SAndroid Build Coastguard Worker 1142*da0073e9SAndroid Build Coastguard Worker Args: 1143*da0073e9SAndroid Build Coastguard Worker device (str or torch.device): The device to pin memory on. Default: ``'cuda'`` 1144*da0073e9SAndroid Build Coastguard Worker 1145*da0073e9SAndroid Build Coastguard Worker Returns: 1146*da0073e9SAndroid Build Coastguard Worker A boolean variable. 1147*da0073e9SAndroid Build Coastguard Worker """ 1148*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1149*da0073e9SAndroid Build Coastguard Worker return self._untyped_storage.is_pinned(device) 1150*da0073e9SAndroid Build Coastguard Worker 1151*da0073e9SAndroid Build Coastguard Worker def pin_memory(self, device: Union[str, torch.device] = "cuda"): 1152*da0073e9SAndroid Build Coastguard Worker r"""Copy the CPU TypedStorage to pinned memory, if it's not already pinned. 1153*da0073e9SAndroid Build Coastguard Worker 1154*da0073e9SAndroid Build Coastguard Worker Args: 1155*da0073e9SAndroid Build Coastguard Worker device (str or torch.device): The device to pin memory on. Default: ``'cuda'``. 1156*da0073e9SAndroid Build Coastguard Worker 1157*da0073e9SAndroid Build Coastguard Worker Returns: 1158*da0073e9SAndroid Build Coastguard Worker A pinned CPU storage. 
1159*da0073e9SAndroid Build Coastguard Worker """ 1160*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1161*da0073e9SAndroid Build Coastguard Worker return self._new_wrapped_storage( 1162*da0073e9SAndroid Build Coastguard Worker self._untyped_storage.pin_memory(device=device) 1163*da0073e9SAndroid Build Coastguard Worker ) 1164*da0073e9SAndroid Build Coastguard Worker 1165*da0073e9SAndroid Build Coastguard Worker def share_memory_(self): 1166*da0073e9SAndroid Build Coastguard Worker """See :meth:`torch.UntypedStorage.share_memory_`""" 1167*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1168*da0073e9SAndroid Build Coastguard Worker return self._share_memory_() 1169*da0073e9SAndroid Build Coastguard Worker 1170*da0073e9SAndroid Build Coastguard Worker # For internal use only, to avoid deprecation warning 1171*da0073e9SAndroid Build Coastguard Worker def _share_memory_(self): 1172*da0073e9SAndroid Build Coastguard Worker self._untyped_storage.share_memory_() 1173*da0073e9SAndroid Build Coastguard Worker return self 1174*da0073e9SAndroid Build Coastguard Worker 1175*da0073e9SAndroid Build Coastguard Worker def _new_shared(self, size, *, device=None): 1176*da0073e9SAndroid Build Coastguard Worker """Create a new storage in shared memory with the same data type.""" 1177*da0073e9SAndroid Build Coastguard Worker if device is None: 1178*da0073e9SAndroid Build Coastguard Worker device = "cpu" 1179*da0073e9SAndroid Build Coastguard Worker device = torch.device(device) 1180*da0073e9SAndroid Build Coastguard Worker untyped_storage = torch.UntypedStorage._new_shared( 1181*da0073e9SAndroid Build Coastguard Worker size * self._element_size(), device=device 1182*da0073e9SAndroid Build Coastguard Worker ) 1183*da0073e9SAndroid Build Coastguard Worker return TypedStorage( 1184*da0073e9SAndroid Build Coastguard Worker wrap_storage=untyped_storage, dtype=self.dtype, _internal=True 1185*da0073e9SAndroid Build Coastguard Worker ) 
1186*da0073e9SAndroid Build Coastguard Worker 1187*da0073e9SAndroid Build Coastguard Worker @property 1188*da0073e9SAndroid Build Coastguard Worker def _cdata(self): 1189*da0073e9SAndroid Build Coastguard Worker return self._untyped_storage._cdata 1190*da0073e9SAndroid Build Coastguard Worker 1191*da0073e9SAndroid Build Coastguard Worker @property 1192*da0073e9SAndroid Build Coastguard Worker def device(self): 1193*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1194*da0073e9SAndroid Build Coastguard Worker return self._untyped_storage.device 1195*da0073e9SAndroid Build Coastguard Worker 1196*da0073e9SAndroid Build Coastguard Worker def size(self): 1197*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1198*da0073e9SAndroid Build Coastguard Worker return self._size() 1199*da0073e9SAndroid Build Coastguard Worker 1200*da0073e9SAndroid Build Coastguard Worker # For internal use only, to avoid deprecation warning 1201*da0073e9SAndroid Build Coastguard Worker def _size(self): 1202*da0073e9SAndroid Build Coastguard Worker # NB: don't indirect through __len__, as that requires 1203*da0073e9SAndroid Build Coastguard Worker # an int to be returned 1204*da0073e9SAndroid Build Coastguard Worker return self._untyped_storage.nbytes() // self._element_size() 1205*da0073e9SAndroid Build Coastguard Worker 1206*da0073e9SAndroid Build Coastguard Worker def pickle_storage_type(self): 1207*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1208*da0073e9SAndroid Build Coastguard Worker return self._pickle_storage_type() 1209*da0073e9SAndroid Build Coastguard Worker 1210*da0073e9SAndroid Build Coastguard Worker # For internal use only, to avoid deprecation warning 1211*da0073e9SAndroid Build Coastguard Worker def _pickle_storage_type(self): 1212*da0073e9SAndroid Build Coastguard Worker try: 1213*da0073e9SAndroid Build Coastguard Worker return _dtype_to_storage_type_map()[self.dtype] 1214*da0073e9SAndroid Build Coastguard 
Worker except KeyError as e: 1215*da0073e9SAndroid Build Coastguard Worker raise KeyError(f"dtype {self.dtype} is not recognized") from e 1216*da0073e9SAndroid Build Coastguard Worker 1217*da0073e9SAndroid Build Coastguard Worker def __reduce__(self): 1218*da0073e9SAndroid Build Coastguard Worker b = io.BytesIO() 1219*da0073e9SAndroid Build Coastguard Worker torch.save(self, b, _use_new_zipfile_serialization=False) 1220*da0073e9SAndroid Build Coastguard Worker return (_load_from_bytes, (b.getvalue(),)) 1221*da0073e9SAndroid Build Coastguard Worker 1222*da0073e9SAndroid Build Coastguard Worker def data_ptr(self): 1223*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1224*da0073e9SAndroid Build Coastguard Worker return self._data_ptr() 1225*da0073e9SAndroid Build Coastguard Worker 1226*da0073e9SAndroid Build Coastguard Worker # For internal use only, to avoid deprecation warning 1227*da0073e9SAndroid Build Coastguard Worker def _data_ptr(self): 1228*da0073e9SAndroid Build Coastguard Worker return self._untyped_storage.data_ptr() 1229*da0073e9SAndroid Build Coastguard Worker 1230*da0073e9SAndroid Build Coastguard Worker def resizable(self): 1231*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1232*da0073e9SAndroid Build Coastguard Worker return self._untyped_storage.resizable() 1233*da0073e9SAndroid Build Coastguard Worker 1234*da0073e9SAndroid Build Coastguard Worker def resize_(self, size): 1235*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1236*da0073e9SAndroid Build Coastguard Worker self._resize_(size) 1237*da0073e9SAndroid Build Coastguard Worker 1238*da0073e9SAndroid Build Coastguard Worker # For internal use only, to avoid deprecation warning 1239*da0073e9SAndroid Build Coastguard Worker def _resize_(self, size): 1240*da0073e9SAndroid Build Coastguard Worker self._untyped_storage.resize_(size * self._element_size()) 1241*da0073e9SAndroid Build Coastguard Worker 1242*da0073e9SAndroid Build 
Coastguard Worker @classmethod 1243*da0073e9SAndroid Build Coastguard Worker def _free_weak_ref(cls, *args, **kwargs): 1244*da0073e9SAndroid Build Coastguard Worker return UntypedStorage._free_weak_ref(*args, **kwargs) 1245*da0073e9SAndroid Build Coastguard Worker 1246*da0073e9SAndroid Build Coastguard Worker def _weak_ref(self, *args, **kwargs): 1247*da0073e9SAndroid Build Coastguard Worker return self._untyped_storage._weak_ref(*args, **kwargs) 1248*da0073e9SAndroid Build Coastguard Worker 1249*da0073e9SAndroid Build Coastguard Worker @classmethod 1250*da0073e9SAndroid Build Coastguard Worker def from_buffer(cls, *args, **kwargs): 1251*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1252*da0073e9SAndroid Build Coastguard Worker return cls._from_buffer(*args, **kwargs) 1253*da0073e9SAndroid Build Coastguard Worker 1254*da0073e9SAndroid Build Coastguard Worker @classmethod 1255*da0073e9SAndroid Build Coastguard Worker def _from_buffer(cls, *args, dtype=None, device=None, **kwargs): 1256*da0073e9SAndroid Build Coastguard Worker if cls == TypedStorage: 1257*da0073e9SAndroid Build Coastguard Worker dtype = torch.get_default_dtype() if dtype is None else dtype 1258*da0073e9SAndroid Build Coastguard Worker device = torch.device("cpu" if device is None else device) 1259*da0073e9SAndroid Build Coastguard Worker if device.type != "cpu": 1260*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 1261*da0073e9SAndroid Build Coastguard Worker f"TypedStorage.from_buffer: Not available for device {device.type}" 1262*da0073e9SAndroid Build Coastguard Worker ) 1263*da0073e9SAndroid Build Coastguard Worker untyped_storage: torch.UntypedStorage = torch.UntypedStorage.from_buffer( 1264*da0073e9SAndroid Build Coastguard Worker *args, dtype=dtype, **kwargs 1265*da0073e9SAndroid Build Coastguard Worker ) 1266*da0073e9SAndroid Build Coastguard Worker 1267*da0073e9SAndroid Build Coastguard Worker else: 1268*da0073e9SAndroid Build Coastguard Worker if dtype 
is not None or len(args) == 5: 1269*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 1270*da0073e9SAndroid Build Coastguard Worker "from_buffer: 'dtype' can only be specified in " 1271*da0073e9SAndroid Build Coastguard Worker "UntypedStorage.from_buffer and TypedStorage.from_buffer" 1272*da0073e9SAndroid Build Coastguard Worker ) 1273*da0073e9SAndroid Build Coastguard Worker if device is not None: 1274*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 1275*da0073e9SAndroid Build Coastguard Worker "from_buffer: 'device' can only be specified in " 1276*da0073e9SAndroid Build Coastguard Worker "UntypedStorage.from_buffer and TypedStorage.from_buffer" 1277*da0073e9SAndroid Build Coastguard Worker ) 1278*da0073e9SAndroid Build Coastguard Worker 1279*da0073e9SAndroid Build Coastguard Worker dtype = cls._dtype 1280*da0073e9SAndroid Build Coastguard Worker untyped_storage = torch.UntypedStorage.from_buffer( 1281*da0073e9SAndroid Build Coastguard Worker *args, dtype=dtype, **kwargs 1282*da0073e9SAndroid Build Coastguard Worker ) 1283*da0073e9SAndroid Build Coastguard Worker 1284*da0073e9SAndroid Build Coastguard Worker return TypedStorage(wrap_storage=untyped_storage, dtype=dtype, _internal=True) 1285*da0073e9SAndroid Build Coastguard Worker 1286*da0073e9SAndroid Build Coastguard Worker def _to(self, dtype): 1287*da0073e9SAndroid Build Coastguard Worker if not isinstance(dtype, torch.dtype): 1288*da0073e9SAndroid Build Coastguard Worker raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}") 1289*da0073e9SAndroid Build Coastguard Worker storage = ( 1290*da0073e9SAndroid Build Coastguard Worker torch.tensor([], dtype=self.dtype, device=self.device) 1291*da0073e9SAndroid Build Coastguard Worker .set_(self) 1292*da0073e9SAndroid Build Coastguard Worker .to(dtype) 1293*da0073e9SAndroid Build Coastguard Worker ._typed_storage() 1294*da0073e9SAndroid Build Coastguard Worker ) 1295*da0073e9SAndroid Build Coastguard Worker if 
storage.data_ptr() == self.data_ptr(): 1296*da0073e9SAndroid Build Coastguard Worker storage = storage.clone() 1297*da0073e9SAndroid Build Coastguard Worker return storage 1298*da0073e9SAndroid Build Coastguard Worker 1299*da0073e9SAndroid Build Coastguard Worker def double(self): 1300*da0073e9SAndroid Build Coastguard Worker """Casts this storage to double type.""" 1301*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1302*da0073e9SAndroid Build Coastguard Worker return self._to(torch.double) 1303*da0073e9SAndroid Build Coastguard Worker 1304*da0073e9SAndroid Build Coastguard Worker def float(self): 1305*da0073e9SAndroid Build Coastguard Worker """Casts this storage to float type.""" 1306*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1307*da0073e9SAndroid Build Coastguard Worker return self._to(torch.float) 1308*da0073e9SAndroid Build Coastguard Worker 1309*da0073e9SAndroid Build Coastguard Worker def half(self): 1310*da0073e9SAndroid Build Coastguard Worker """Casts this storage to half type.""" 1311*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1312*da0073e9SAndroid Build Coastguard Worker return self._to(torch.half) 1313*da0073e9SAndroid Build Coastguard Worker 1314*da0073e9SAndroid Build Coastguard Worker def long(self): 1315*da0073e9SAndroid Build Coastguard Worker """Casts this storage to long type.""" 1316*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1317*da0073e9SAndroid Build Coastguard Worker return self._to(torch.long) 1318*da0073e9SAndroid Build Coastguard Worker 1319*da0073e9SAndroid Build Coastguard Worker def int(self): 1320*da0073e9SAndroid Build Coastguard Worker """Casts this storage to int type.""" 1321*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1322*da0073e9SAndroid Build Coastguard Worker return self._to(torch.int) 1323*da0073e9SAndroid Build Coastguard Worker 1324*da0073e9SAndroid Build Coastguard Worker def 
short(self): 1325*da0073e9SAndroid Build Coastguard Worker """Casts this storage to short type.""" 1326*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1327*da0073e9SAndroid Build Coastguard Worker return self._to(torch.short) 1328*da0073e9SAndroid Build Coastguard Worker 1329*da0073e9SAndroid Build Coastguard Worker def char(self): 1330*da0073e9SAndroid Build Coastguard Worker """Casts this storage to char type.""" 1331*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1332*da0073e9SAndroid Build Coastguard Worker return self._to(torch.int8) 1333*da0073e9SAndroid Build Coastguard Worker 1334*da0073e9SAndroid Build Coastguard Worker def byte(self): 1335*da0073e9SAndroid Build Coastguard Worker """Casts this storage to byte type.""" 1336*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1337*da0073e9SAndroid Build Coastguard Worker return self._to(torch.uint8) 1338*da0073e9SAndroid Build Coastguard Worker 1339*da0073e9SAndroid Build Coastguard Worker def bool(self): 1340*da0073e9SAndroid Build Coastguard Worker """Casts this storage to bool type.""" 1341*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1342*da0073e9SAndroid Build Coastguard Worker return self._to(torch.bool) 1343*da0073e9SAndroid Build Coastguard Worker 1344*da0073e9SAndroid Build Coastguard Worker def bfloat16(self): 1345*da0073e9SAndroid Build Coastguard Worker """Casts this storage to bfloat16 type.""" 1346*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1347*da0073e9SAndroid Build Coastguard Worker return self._to(torch.bfloat16) 1348*da0073e9SAndroid Build Coastguard Worker 1349*da0073e9SAndroid Build Coastguard Worker def complex_double(self): 1350*da0073e9SAndroid Build Coastguard Worker """Casts this storage to complex double type.""" 1351*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1352*da0073e9SAndroid Build Coastguard Worker return 
self._to(torch.cdouble) 1353*da0073e9SAndroid Build Coastguard Worker 1354*da0073e9SAndroid Build Coastguard Worker def complex_float(self): 1355*da0073e9SAndroid Build Coastguard Worker """Casts this storage to complex float type.""" 1356*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1357*da0073e9SAndroid Build Coastguard Worker return self._to(torch.cfloat) 1358*da0073e9SAndroid Build Coastguard Worker 1359*da0073e9SAndroid Build Coastguard Worker def float8_e5m2(self): 1360*da0073e9SAndroid Build Coastguard Worker """Casts this storage to float8_e5m2 type""" 1361*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1362*da0073e9SAndroid Build Coastguard Worker return self._to(torch.float8_e5m2) 1363*da0073e9SAndroid Build Coastguard Worker 1364*da0073e9SAndroid Build Coastguard Worker def float8_e4m3fn(self): 1365*da0073e9SAndroid Build Coastguard Worker """Casts this storage to float8_e4m3fn type""" 1366*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1367*da0073e9SAndroid Build Coastguard Worker return self._to(torch.float8_e4m3fn) 1368*da0073e9SAndroid Build Coastguard Worker 1369*da0073e9SAndroid Build Coastguard Worker def float8_e5m2fnuz(self): 1370*da0073e9SAndroid Build Coastguard Worker """Casts this storage to float8_e5m2fnuz type""" 1371*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1372*da0073e9SAndroid Build Coastguard Worker return self._to(torch.float8_e5m2fnuz) 1373*da0073e9SAndroid Build Coastguard Worker 1374*da0073e9SAndroid Build Coastguard Worker def float8_e4m3fnuz(self): 1375*da0073e9SAndroid Build Coastguard Worker """Casts this storage to float8_e4m3fnuz type""" 1376*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1377*da0073e9SAndroid Build Coastguard Worker return self._to(torch.float8_e4m3fnuz) 1378*da0073e9SAndroid Build Coastguard Worker 1379*da0073e9SAndroid Build Coastguard Worker @classmethod 1380*da0073e9SAndroid 
Build Coastguard Worker def from_file(cls, filename, shared, size): 1381*da0073e9SAndroid Build Coastguard Worker """from_file(filename, shared=False, size=0) -> Storage 1382*da0073e9SAndroid Build Coastguard Worker 1383*da0073e9SAndroid Build Coastguard Worker Creates a CPU storage backed by a memory-mapped file. 1384*da0073e9SAndroid Build Coastguard Worker 1385*da0073e9SAndroid Build Coastguard Worker If ``shared`` is ``True``, then memory is shared between all processes. 1386*da0073e9SAndroid Build Coastguard Worker All changes are written to the file. If ``shared`` is ``False``, then the changes on 1387*da0073e9SAndroid Build Coastguard Worker the storage do not affect the file. 1388*da0073e9SAndroid Build Coastguard Worker 1389*da0073e9SAndroid Build Coastguard Worker ``size`` is the number of elements in the storage. If ``shared`` is ``False``, 1390*da0073e9SAndroid Build Coastguard Worker then the file must contain at least ``size * sizeof(Type)`` bytes 1391*da0073e9SAndroid Build Coastguard Worker (``Type`` is the type of storage). If ``shared`` is ``True`` the file will be created if needed. 
1392*da0073e9SAndroid Build Coastguard Worker 1393*da0073e9SAndroid Build Coastguard Worker Args: 1394*da0073e9SAndroid Build Coastguard Worker filename (str): file name to map 1395*da0073e9SAndroid Build Coastguard Worker shared (bool): whether to share memory (whether ``MAP_SHARED`` or ``MAP_PRIVATE`` is passed to the 1396*da0073e9SAndroid Build Coastguard Worker underlying `mmap(2) call <https://man7.org/linux/man-pages/man2/mmap.2.html>`_) 1397*da0073e9SAndroid Build Coastguard Worker size (int): number of elements in the storage 1398*da0073e9SAndroid Build Coastguard Worker """ 1399*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1400*da0073e9SAndroid Build Coastguard Worker if cls == TypedStorage: 1401*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("from_file can only be called on derived classes") 1402*da0073e9SAndroid Build Coastguard Worker untyped_storage = UntypedStorage.from_file( 1403*da0073e9SAndroid Build Coastguard Worker filename, shared, size * torch._utils._element_size(cls.dtype) 1404*da0073e9SAndroid Build Coastguard Worker ) 1405*da0073e9SAndroid Build Coastguard Worker storage = cls(wrap_storage=untyped_storage) 1406*da0073e9SAndroid Build Coastguard Worker return storage 1407*da0073e9SAndroid Build Coastguard Worker 1408*da0073e9SAndroid Build Coastguard Worker @classmethod 1409*da0073e9SAndroid Build Coastguard Worker def _expired(cls, *args, **kwargs): 1410*da0073e9SAndroid Build Coastguard Worker return UntypedStorage._expired(*args, **kwargs) 1411*da0073e9SAndroid Build Coastguard Worker 1412*da0073e9SAndroid Build Coastguard Worker def _write_file(self, *args, **kwargs): 1413*da0073e9SAndroid Build Coastguard Worker return self._untyped_storage._write_file(*args, **kwargs) 1414*da0073e9SAndroid Build Coastguard Worker 1415*da0073e9SAndroid Build Coastguard Worker def _set_from_file(self, *args, **kwargs): 1416*da0073e9SAndroid Build Coastguard Worker return self._untyped_storage._set_from_file(*args, 
**kwargs) 1417*da0073e9SAndroid Build Coastguard Worker 1418*da0073e9SAndroid Build Coastguard Worker def _set_cdata(self, *args, **kwargs): 1419*da0073e9SAndroid Build Coastguard Worker return self._untyped_storage._set_cdata(*args, **kwargs) 1420*da0073e9SAndroid Build Coastguard Worker 1421*da0073e9SAndroid Build Coastguard Worker def _share_cuda_(self, *args, **kwargs): 1422*da0073e9SAndroid Build Coastguard Worker return self._untyped_storage._share_cuda_(*args, **kwargs) 1423*da0073e9SAndroid Build Coastguard Worker 1424*da0073e9SAndroid Build Coastguard Worker def is_shared(self): 1425*da0073e9SAndroid Build Coastguard Worker _warn_typed_storage_removal() 1426*da0073e9SAndroid Build Coastguard Worker return self._is_shared() 1427*da0073e9SAndroid Build Coastguard Worker 1428*da0073e9SAndroid Build Coastguard Worker # For internal use only, to avoid deprecation warning 1429*da0073e9SAndroid Build Coastguard Worker def _is_shared(self): 1430*da0073e9SAndroid Build Coastguard Worker return self._untyped_storage.is_shared() 1431*da0073e9SAndroid Build Coastguard Worker 1432*da0073e9SAndroid Build Coastguard Worker @classmethod 1433*da0073e9SAndroid Build Coastguard Worker def _new_shared_cuda(cls, *args, **kwargs): 1434*da0073e9SAndroid Build Coastguard Worker return torch.UntypedStorage._new_shared_cuda(*args, **kwargs) 1435*da0073e9SAndroid Build Coastguard Worker 1436*da0073e9SAndroid Build Coastguard Worker def _share_filename_cpu_(self, *args, **kwargs): 1437*da0073e9SAndroid Build Coastguard Worker ( 1438*da0073e9SAndroid Build Coastguard Worker manager_handle, 1439*da0073e9SAndroid Build Coastguard Worker storage_handle, 1440*da0073e9SAndroid Build Coastguard Worker size, 1441*da0073e9SAndroid Build Coastguard Worker ) = self._untyped_storage._share_filename_cpu_(*args, **kwargs) 1442*da0073e9SAndroid Build Coastguard Worker return manager_handle, storage_handle, size // self._element_size() 1443*da0073e9SAndroid Build Coastguard Worker 
1444*da0073e9SAndroid Build Coastguard Worker def _shared_decref(self): 1445*da0073e9SAndroid Build Coastguard Worker self._untyped_storage._shared_decref() 1446*da0073e9SAndroid Build Coastguard Worker return self 1447*da0073e9SAndroid Build Coastguard Worker 1448*da0073e9SAndroid Build Coastguard Worker @classmethod 1449*da0073e9SAndroid Build Coastguard Worker def _release_ipc_counter(cls, *args, device=None, **kwargs): 1450*da0073e9SAndroid Build Coastguard Worker return torch.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs) 1451*da0073e9SAndroid Build Coastguard Worker 1452*da0073e9SAndroid Build Coastguard Worker def _shared_incref(self, *args, **kwargs): 1453*da0073e9SAndroid Build Coastguard Worker return self._untyped_storage._shared_incref(*args, **kwargs) 1454*da0073e9SAndroid Build Coastguard Worker 1455*da0073e9SAndroid Build Coastguard Worker def _share_fd_cpu_(self, *args, **kwargs): 1456*da0073e9SAndroid Build Coastguard Worker fd, size = self._untyped_storage._share_fd_cpu_(*args, **kwargs) 1457*da0073e9SAndroid Build Coastguard Worker return fd, size // self._element_size() 1458*da0073e9SAndroid Build Coastguard Worker 1459*da0073e9SAndroid Build Coastguard Worker def _get_legacy_storage_class(self): 1460*da0073e9SAndroid Build Coastguard Worker if self.dtype not in _dtype_to_storage_type_map(): 1461*da0073e9SAndroid Build Coastguard Worker return None 1462*da0073e9SAndroid Build Coastguard Worker 1463*da0073e9SAndroid Build Coastguard Worker storage_name = _dtype_to_storage_type_map()[self.dtype] 1464*da0073e9SAndroid Build Coastguard Worker 1465*da0073e9SAndroid Build Coastguard Worker if self.device.type not in [ 1466*da0073e9SAndroid Build Coastguard Worker "cpu", 1467*da0073e9SAndroid Build Coastguard Worker "cuda", 1468*da0073e9SAndroid Build Coastguard Worker "hpu", 1469*da0073e9SAndroid Build Coastguard Worker torch._C._get_privateuse1_backend_name(), 1470*da0073e9SAndroid Build Coastguard Worker ]: 1471*da0073e9SAndroid Build 
Coastguard Worker return None 1472*da0073e9SAndroid Build Coastguard Worker 1473*da0073e9SAndroid Build Coastguard Worker module = ( 1474*da0073e9SAndroid Build Coastguard Worker torch if self.device.type == "cpu" else getattr(torch, self.device.type) 1475*da0073e9SAndroid Build Coastguard Worker ) 1476*da0073e9SAndroid Build Coastguard Worker 1477*da0073e9SAndroid Build Coastguard Worker try: 1478*da0073e9SAndroid Build Coastguard Worker return getattr(module, storage_name) 1479*da0073e9SAndroid Build Coastguard Worker except AttributeError: 1480*da0073e9SAndroid Build Coastguard Worker return None 1481*da0073e9SAndroid Build Coastguard Worker 1482*da0073e9SAndroid Build Coastguard Worker 1483*da0073e9SAndroid Build Coastguard WorkerTypedStorage.type.__doc__ = _type.__doc__ 1484*da0073e9SAndroid Build Coastguard WorkerTypedStorage.cuda.__doc__ = _StorageBase.cuda.__doc__ 1485*da0073e9SAndroid Build Coastguard WorkerTypedStorage.hpu.__doc__ = _StorageBase.hpu.__doc__ 1486*da0073e9SAndroid Build Coastguard WorkerTypedStorage.to.__doc__ = _to.__doc__ 1487*da0073e9SAndroid Build Coastguard Worker 1488*da0073e9SAndroid Build Coastguard Worker 1489*da0073e9SAndroid Build Coastguard Workerclass _LegacyStorageMeta(type): 1490*da0073e9SAndroid Build Coastguard Worker dtype: torch.dtype 1491*da0073e9SAndroid Build Coastguard Worker 1492*da0073e9SAndroid Build Coastguard Worker def __instancecheck__(cls, instance): 1493*da0073e9SAndroid Build Coastguard Worker if type(instance) == TypedStorage: 1494*da0073e9SAndroid Build Coastguard Worker cls_device = _get_device_from_module(cls.__module__) 1495*da0073e9SAndroid Build Coastguard Worker return (cls_device == instance.device.type) and ( 1496*da0073e9SAndroid Build Coastguard Worker cls.dtype == instance.dtype 1497*da0073e9SAndroid Build Coastguard Worker ) 1498*da0073e9SAndroid Build Coastguard Worker return False 1499*da0073e9SAndroid Build Coastguard Worker 1500*da0073e9SAndroid Build Coastguard Worker 1501*da0073e9SAndroid 
class _LegacyStorage(TypedStorage, metaclass=_LegacyStorageMeta):
    """Base class for the legacy per-dtype storage classes."""

    @classmethod
    def _new_shared(cls, size):
        """Create a new storage in shared memory with the same data type."""
        # Derive the byte size from the class dtype, as _new_shared_filename
        # does, rather than constructing a throwaway ``cls()`` instance just to
        # query its element size.
        nbytes = size * torch._utils._element_size(cls.dtype)
        untyped_storage = torch.UntypedStorage._new_shared(nbytes)
        return cls(wrap_storage=untyped_storage)

    @classmethod
    def _release_ipc_counter(cls, *args, **kwargs):
        return torch.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)

    @classmethod
    def _new_shared_filename(cls, manager, obj, size):
        # ``size`` is an element count; the untyped call takes bytes.
        bytes_size = size * torch._utils._element_size(cls.dtype)
        return cls(
            wrap_storage=torch.UntypedStorage._new_shared_filename_cpu(
                manager, obj, bytes_size
            )
        )


def _get_dtype_from_pickle_storage_type(pickle_storage_type: str):
    """Look up the dtype for a pickled storage type name.

    The lookup goes through ``_storage_type_to_dtype_map()``.

    Raises:
        KeyError: if ``pickle_storage_type`` is not a key of that map.
    """
    try:
        return _storage_type_to_dtype_map()[pickle_storage_type]
    except KeyError as e:
        raise KeyError(
            f'pickle storage type "{pickle_storage_type}" is not recognized'
        ) from e