xref: /aosp_15_r20/external/pytorch/torch/storage.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerimport collections
6*da0073e9SAndroid Build Coastguard Workerimport copy
7*da0073e9SAndroid Build Coastguard Workerimport functools
8*da0073e9SAndroid Build Coastguard Workerimport io
9*da0073e9SAndroid Build Coastguard Workerimport threading
10*da0073e9SAndroid Build Coastguard Workerimport warnings
11*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, cast, Dict as _Dict, Optional as _Optional, Type, TypeVar, Union
12*da0073e9SAndroid Build Coastguard Workerfrom typing_extensions import Self
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Workerimport torch
15*da0073e9SAndroid Build Coastguard Workerfrom torch._utils import _to, _type
16*da0073e9SAndroid Build Coastguard Workerfrom torch.types import _bool, _int, Storage
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Worker
__all__ = ["TypedStorage", "UntypedStorage"]


# numpy is optional: HAS_NUMPY records whether it imported, and `np` is bound
# to None when it did not, so the name always exists at module level.
try:
    import numpy as np
except ModuleNotFoundError:
    np = None  # type: ignore[assignment]
    HAS_NUMPY = False
else:
    HAS_NUMPY = True


# `_share_memory_lock` guards `_share_memory_map`, which maps a storage's
# `_cdata` key to the re-entrant lock held while that storage is being
# processed by a method wrapped with `_share_memory_lock_protected`.
_share_memory_lock = threading.Lock()
_share_memory_map: _Dict[int, threading.RLock] = {}

# Type variable for "any storage object", used by classmethod constructors.
T = TypeVar("T", bound="Union[_StorageBase, TypedStorage]")
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard Worker
37*da0073e9SAndroid Build Coastguard Workerclass _StorageBase:
38*da0073e9SAndroid Build Coastguard Worker    _cdata: Any
39*da0073e9SAndroid Build Coastguard Worker    is_sparse: _bool = False
40*da0073e9SAndroid Build Coastguard Worker    is_sparse_csr: _bool = False
41*da0073e9SAndroid Build Coastguard Worker    device: torch.device
42*da0073e9SAndroid Build Coastguard Worker    # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True)
43*da0073e9SAndroid Build Coastguard Worker    _fake_device: _Optional[torch.device] = None
44*da0073e9SAndroid Build Coastguard Worker
45*da0073e9SAndroid Build Coastguard Worker    def __init__(self, *args, **kwargs):
46*da0073e9SAndroid Build Coastguard Worker        pass
47*da0073e9SAndroid Build Coastguard Worker
48*da0073e9SAndroid Build Coastguard Worker    def __len__(self) -> _int:
49*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
50*da0073e9SAndroid Build Coastguard Worker
51*da0073e9SAndroid Build Coastguard Worker    def __getitem__(self, idx):
52*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
53*da0073e9SAndroid Build Coastguard Worker
54*da0073e9SAndroid Build Coastguard Worker    def __setitem__(self, *args, **kwargs):
55*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
56*da0073e9SAndroid Build Coastguard Worker
57*da0073e9SAndroid Build Coastguard Worker    def copy_(self, source: T, non_blocking: _Optional[_bool] = None) -> T:
58*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
59*da0073e9SAndroid Build Coastguard Worker
60*da0073e9SAndroid Build Coastguard Worker    def new(self) -> Union[_StorageBase, TypedStorage]:
61*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
62*da0073e9SAndroid Build Coastguard Worker
63*da0073e9SAndroid Build Coastguard Worker    def nbytes(self) -> _int:
64*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
65*da0073e9SAndroid Build Coastguard Worker
66*da0073e9SAndroid Build Coastguard Worker    def size(self) -> _int:
67*da0073e9SAndroid Build Coastguard Worker        return self.nbytes()
68*da0073e9SAndroid Build Coastguard Worker
69*da0073e9SAndroid Build Coastguard Worker    def type(
70*da0073e9SAndroid Build Coastguard Worker        self, dtype: _Optional[str] = None, non_blocking: _bool = False
71*da0073e9SAndroid Build Coastguard Worker    ) -> Union[_StorageBase, TypedStorage]:
72*da0073e9SAndroid Build Coastguard Worker        return _type(self, dtype, non_blocking)
73*da0073e9SAndroid Build Coastguard Worker
74*da0073e9SAndroid Build Coastguard Worker    def cuda(
75*da0073e9SAndroid Build Coastguard Worker        self, device=None, non_blocking=False
76*da0073e9SAndroid Build Coastguard Worker    ) -> Union[_StorageBase, TypedStorage]:
77*da0073e9SAndroid Build Coastguard Worker        """Returns a copy of this object in CUDA memory.
78*da0073e9SAndroid Build Coastguard Worker
79*da0073e9SAndroid Build Coastguard Worker        If this object is already in CUDA memory and on the correct device, then
80*da0073e9SAndroid Build Coastguard Worker        no copy is performed and the original object is returned.
81*da0073e9SAndroid Build Coastguard Worker
82*da0073e9SAndroid Build Coastguard Worker        Args:
83*da0073e9SAndroid Build Coastguard Worker            device (int): The destination GPU id. Defaults to the current device.
84*da0073e9SAndroid Build Coastguard Worker            non_blocking (bool): If ``True`` and the source is in pinned memory,
85*da0073e9SAndroid Build Coastguard Worker                the copy will be asynchronous with respect to the host. Otherwise,
86*da0073e9SAndroid Build Coastguard Worker                the argument has no effect.
87*da0073e9SAndroid Build Coastguard Worker        """
88*da0073e9SAndroid Build Coastguard Worker        device2 = torch.device("cuda", device) if device else torch.device("cuda")
89*da0073e9SAndroid Build Coastguard Worker        return self.to(device=device2, non_blocking=non_blocking)
90*da0073e9SAndroid Build Coastguard Worker
91*da0073e9SAndroid Build Coastguard Worker    def hpu(self, device=None, non_blocking=False) -> Union[_StorageBase, TypedStorage]:
92*da0073e9SAndroid Build Coastguard Worker        """Returns a copy of this object in HPU memory.
93*da0073e9SAndroid Build Coastguard Worker
94*da0073e9SAndroid Build Coastguard Worker        If this object is already in HPU memory and on the correct device, then
95*da0073e9SAndroid Build Coastguard Worker        no copy is performed and the original object is returned.
96*da0073e9SAndroid Build Coastguard Worker
97*da0073e9SAndroid Build Coastguard Worker        Args:
98*da0073e9SAndroid Build Coastguard Worker            device (int): The destination HPU id. Defaults to the current device.
99*da0073e9SAndroid Build Coastguard Worker            non_blocking (bool): If ``True`` and the source is in pinned memory,
100*da0073e9SAndroid Build Coastguard Worker                the copy will be asynchronous with respect to the host. Otherwise,
101*da0073e9SAndroid Build Coastguard Worker                the argument has no effect.
102*da0073e9SAndroid Build Coastguard Worker        """
103*da0073e9SAndroid Build Coastguard Worker        device2 = torch.device("hpu", device) if device else torch.device("hpu")
104*da0073e9SAndroid Build Coastguard Worker        return self.to(device=device2, non_blocking=non_blocking)
105*da0073e9SAndroid Build Coastguard Worker
106*da0073e9SAndroid Build Coastguard Worker    def element_size(self) -> _int:
107*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
108*da0073e9SAndroid Build Coastguard Worker
109*da0073e9SAndroid Build Coastguard Worker    def get_device(self) -> _int:
110*da0073e9SAndroid Build Coastguard Worker        return self.device.index
111*da0073e9SAndroid Build Coastguard Worker
112*da0073e9SAndroid Build Coastguard Worker    def data_ptr(self) -> _int:
113*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
114*da0073e9SAndroid Build Coastguard Worker
115*da0073e9SAndroid Build Coastguard Worker    def resizable(self) -> _bool:
116*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
117*da0073e9SAndroid Build Coastguard Worker
118*da0073e9SAndroid Build Coastguard Worker    # Defined in torch/csrc/generic/StorageSharing.cpp
119*da0073e9SAndroid Build Coastguard Worker    def _share_filename_cpu_(self, *args, **kwargs):
120*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
121*da0073e9SAndroid Build Coastguard Worker
122*da0073e9SAndroid Build Coastguard Worker    def _share_fd_cpu_(self, *args, **kwargs):
123*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
124*da0073e9SAndroid Build Coastguard Worker
125*da0073e9SAndroid Build Coastguard Worker    @classmethod
126*da0073e9SAndroid Build Coastguard Worker    def _new_using_filename_cpu(cls: Type[T], size: _int) -> T:
127*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
128*da0073e9SAndroid Build Coastguard Worker
129*da0073e9SAndroid Build Coastguard Worker    @classmethod
130*da0073e9SAndroid Build Coastguard Worker    def _new_using_fd_cpu(cls: Type[T], size: _int) -> T:
131*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
132*da0073e9SAndroid Build Coastguard Worker
133*da0073e9SAndroid Build Coastguard Worker    @classmethod
134*da0073e9SAndroid Build Coastguard Worker    def from_buffer(cls: Type[T], *args, **kwargs) -> T:
135*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
136*da0073e9SAndroid Build Coastguard Worker
137*da0073e9SAndroid Build Coastguard Worker    @classmethod
138*da0073e9SAndroid Build Coastguard Worker    def _new_shared_filename_cpu(
139*da0073e9SAndroid Build Coastguard Worker        cls: Type[T],
140*da0073e9SAndroid Build Coastguard Worker        manager,
141*da0073e9SAndroid Build Coastguard Worker        obj,
142*da0073e9SAndroid Build Coastguard Worker        size,
143*da0073e9SAndroid Build Coastguard Worker        *,
144*da0073e9SAndroid Build Coastguard Worker        device=None,
145*da0073e9SAndroid Build Coastguard Worker        dtype=None,
146*da0073e9SAndroid Build Coastguard Worker    ) -> T:
147*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
148*da0073e9SAndroid Build Coastguard Worker
149*da0073e9SAndroid Build Coastguard Worker    @classmethod
150*da0073e9SAndroid Build Coastguard Worker    def _release_ipc_counter_cuda(cls: Type[T], *args, **kwargs) -> T:
151*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
152*da0073e9SAndroid Build Coastguard Worker
153*da0073e9SAndroid Build Coastguard Worker    @classmethod
154*da0073e9SAndroid Build Coastguard Worker    def _new_with_weak_ptr(cls: Type[T], *args, **kwargs) -> T:
155*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
156*da0073e9SAndroid Build Coastguard Worker
157*da0073e9SAndroid Build Coastguard Worker    def _shared_decref(self) -> Union[_StorageBase, TypedStorage]:
158*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
159*da0073e9SAndroid Build Coastguard Worker
160*da0073e9SAndroid Build Coastguard Worker    def _write_file(self, *args, **kwargs):
161*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
162*da0073e9SAndroid Build Coastguard Worker
163*da0073e9SAndroid Build Coastguard Worker    def resize_(self, size: _int):
164*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
165*da0073e9SAndroid Build Coastguard Worker
166*da0073e9SAndroid Build Coastguard Worker    def _weak_ref(self, *args, **kwargs) -> Union[_StorageBase, TypedStorage]:
167*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
168*da0073e9SAndroid Build Coastguard Worker
169*da0073e9SAndroid Build Coastguard Worker    def _set_from_file(self, *args, **kwargs):
170*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
171*da0073e9SAndroid Build Coastguard Worker
172*da0073e9SAndroid Build Coastguard Worker    def _set_cdata(self, *args, **kwargs):
173*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
174*da0073e9SAndroid Build Coastguard Worker
175*da0073e9SAndroid Build Coastguard Worker    def _share_cuda_(self, *args, **kwargs):
176*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
177*da0073e9SAndroid Build Coastguard Worker
178*da0073e9SAndroid Build Coastguard Worker    def is_shared(self) -> _bool:
179*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
180*da0073e9SAndroid Build Coastguard Worker
181*da0073e9SAndroid Build Coastguard Worker    @classmethod
182*da0073e9SAndroid Build Coastguard Worker    def _new_shared_cuda(cls: Type[T], *args, **kwargs) -> T:
183*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
184*da0073e9SAndroid Build Coastguard Worker
185*da0073e9SAndroid Build Coastguard Worker    def _shared_incref(self, *args, **kwargs):
186*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
187*da0073e9SAndroid Build Coastguard Worker
188*da0073e9SAndroid Build Coastguard Worker    @classmethod
189*da0073e9SAndroid Build Coastguard Worker    def _free_weak_ref(cls, *args, **kwargs):
190*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
191*da0073e9SAndroid Build Coastguard Worker
192*da0073e9SAndroid Build Coastguard Worker    @property
193*da0073e9SAndroid Build Coastguard Worker    def is_cuda(self):
194*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
195*da0073e9SAndroid Build Coastguard Worker
196*da0073e9SAndroid Build Coastguard Worker    @property
197*da0073e9SAndroid Build Coastguard Worker    def is_hpu(self):
198*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
199*da0073e9SAndroid Build Coastguard Worker
200*da0073e9SAndroid Build Coastguard Worker    @classmethod
201*da0073e9SAndroid Build Coastguard Worker    def from_file(cls, filename, shared, nbytes) -> Union[_StorageBase, TypedStorage]:
202*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
203*da0073e9SAndroid Build Coastguard Worker
204*da0073e9SAndroid Build Coastguard Worker    @classmethod
205*da0073e9SAndroid Build Coastguard Worker    def _expired(cls, *args, **kwargs) -> Union[_StorageBase, TypedStorage]:
206*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
207*da0073e9SAndroid Build Coastguard Worker
208*da0073e9SAndroid Build Coastguard Worker    def _byteswap(self, *args, **kwargs):
209*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
210*da0073e9SAndroid Build Coastguard Worker
211*da0073e9SAndroid Build Coastguard Worker    def _get_filename(self, *args, **kwargs) -> _Optional[str]:
212*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError
213*da0073e9SAndroid Build Coastguard Worker
214*da0073e9SAndroid Build Coastguard Worker    def __repr__(self):
215*da0073e9SAndroid Build Coastguard Worker        info_str = f"[{torch.typename(self)}(device={self.device}) of size {len(self)}]"
216*da0073e9SAndroid Build Coastguard Worker        if self.device.type == "meta":
217*da0073e9SAndroid Build Coastguard Worker            return "...\n" + info_str
218*da0073e9SAndroid Build Coastguard Worker        data_str = " " + "\n ".join(str(self[i]) for i in range(self.size()))
219*da0073e9SAndroid Build Coastguard Worker        return data_str + "\n" + info_str
220*da0073e9SAndroid Build Coastguard Worker
221*da0073e9SAndroid Build Coastguard Worker    def __iter__(self):
222*da0073e9SAndroid Build Coastguard Worker        return iter(self[i] for i in range(self.size()))
223*da0073e9SAndroid Build Coastguard Worker
224*da0073e9SAndroid Build Coastguard Worker    def __copy__(self):
225*da0073e9SAndroid Build Coastguard Worker        return self.clone()
226*da0073e9SAndroid Build Coastguard Worker
227*da0073e9SAndroid Build Coastguard Worker    def __deepcopy__(self, memo):
228*da0073e9SAndroid Build Coastguard Worker        memo = memo.setdefault("torch", {})
229*da0073e9SAndroid Build Coastguard Worker        if self._cdata in memo:
230*da0073e9SAndroid Build Coastguard Worker            return memo[self._cdata]
231*da0073e9SAndroid Build Coastguard Worker        new_storage = self.clone()
232*da0073e9SAndroid Build Coastguard Worker        memo[self._cdata] = new_storage
233*da0073e9SAndroid Build Coastguard Worker        return new_storage
234*da0073e9SAndroid Build Coastguard Worker
235*da0073e9SAndroid Build Coastguard Worker    def __reduce__(self):
236*da0073e9SAndroid Build Coastguard Worker        b = io.BytesIO()
237*da0073e9SAndroid Build Coastguard Worker        torch.save(self, b, _use_new_zipfile_serialization=False)
238*da0073e9SAndroid Build Coastguard Worker        return (_load_from_bytes, (b.getvalue(),))
239*da0073e9SAndroid Build Coastguard Worker
240*da0073e9SAndroid Build Coastguard Worker    def __sizeof__(self):
241*da0073e9SAndroid Build Coastguard Worker        return super().__sizeof__() + self.size()
242*da0073e9SAndroid Build Coastguard Worker
243*da0073e9SAndroid Build Coastguard Worker    def clone(self):
244*da0073e9SAndroid Build Coastguard Worker        """Return a copy of this storage."""
245*da0073e9SAndroid Build Coastguard Worker        return type(self)(self.nbytes(), device=self.device).copy_(self)
246*da0073e9SAndroid Build Coastguard Worker
247*da0073e9SAndroid Build Coastguard Worker    def tolist(self):
248*da0073e9SAndroid Build Coastguard Worker        """Return a list containing the elements of this storage."""
249*da0073e9SAndroid Build Coastguard Worker        return list(self)
250*da0073e9SAndroid Build Coastguard Worker
251*da0073e9SAndroid Build Coastguard Worker    def cpu(self):
252*da0073e9SAndroid Build Coastguard Worker        """Return a CPU copy of this storage if it's not already on the CPU."""
253*da0073e9SAndroid Build Coastguard Worker        if self.device.type != "cpu":
254*da0073e9SAndroid Build Coastguard Worker            return torch.UntypedStorage(self.size()).copy_(self, False)
255*da0073e9SAndroid Build Coastguard Worker        return self
256*da0073e9SAndroid Build Coastguard Worker
257*da0073e9SAndroid Build Coastguard Worker    def mps(self):
258*da0073e9SAndroid Build Coastguard Worker        """Return a MPS copy of this storage if it's not already on the MPS."""
259*da0073e9SAndroid Build Coastguard Worker        if self.device.type != "mps":
260*da0073e9SAndroid Build Coastguard Worker            return torch.UntypedStorage(self.size(), device="mps").copy_(self, False)
261*da0073e9SAndroid Build Coastguard Worker        return self
262*da0073e9SAndroid Build Coastguard Worker
263*da0073e9SAndroid Build Coastguard Worker    def _to(self, dtype):
264*da0073e9SAndroid Build Coastguard Worker        if not isinstance(dtype, torch.dtype):
265*da0073e9SAndroid Build Coastguard Worker            raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
266*da0073e9SAndroid Build Coastguard Worker        storage = (
267*da0073e9SAndroid Build Coastguard Worker            torch.tensor([], dtype=torch.uint8, device=self.device)
268*da0073e9SAndroid Build Coastguard Worker            .set_(cast(Storage, self))
269*da0073e9SAndroid Build Coastguard Worker            .to(dtype)
270*da0073e9SAndroid Build Coastguard Worker            ._typed_storage()
271*da0073e9SAndroid Build Coastguard Worker        )
272*da0073e9SAndroid Build Coastguard Worker        if storage.data_ptr() == self.data_ptr():
273*da0073e9SAndroid Build Coastguard Worker            storage = storage.clone()
274*da0073e9SAndroid Build Coastguard Worker        return storage
275*da0073e9SAndroid Build Coastguard Worker
276*da0073e9SAndroid Build Coastguard Worker    def to(
277*da0073e9SAndroid Build Coastguard Worker        self, *, device: torch.device, non_blocking: _bool = False
278*da0073e9SAndroid Build Coastguard Worker    ) -> Union[_StorageBase, TypedStorage]:
279*da0073e9SAndroid Build Coastguard Worker        return _to(self, device, non_blocking)
280*da0073e9SAndroid Build Coastguard Worker
281*da0073e9SAndroid Build Coastguard Worker    def double(self):
282*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to double type."""
283*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.double)
284*da0073e9SAndroid Build Coastguard Worker
285*da0073e9SAndroid Build Coastguard Worker    def float(self):
286*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to float type."""
287*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.float)
288*da0073e9SAndroid Build Coastguard Worker
289*da0073e9SAndroid Build Coastguard Worker    def half(self):
290*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to half type."""
291*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.half)
292*da0073e9SAndroid Build Coastguard Worker
293*da0073e9SAndroid Build Coastguard Worker    def long(self):
294*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to long type."""
295*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.long)
296*da0073e9SAndroid Build Coastguard Worker
297*da0073e9SAndroid Build Coastguard Worker    def int(self):
298*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to int type."""
299*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.int)
300*da0073e9SAndroid Build Coastguard Worker
301*da0073e9SAndroid Build Coastguard Worker    def short(self):
302*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to short type."""
303*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.short)
304*da0073e9SAndroid Build Coastguard Worker
305*da0073e9SAndroid Build Coastguard Worker    def char(self):
306*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to char type."""
307*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.int8)
308*da0073e9SAndroid Build Coastguard Worker
309*da0073e9SAndroid Build Coastguard Worker    def byte(self):
310*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to byte type."""
311*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.uint8)
312*da0073e9SAndroid Build Coastguard Worker
313*da0073e9SAndroid Build Coastguard Worker    def bool(self):
314*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to bool type."""
315*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.bool)
316*da0073e9SAndroid Build Coastguard Worker
317*da0073e9SAndroid Build Coastguard Worker    def bfloat16(self):
318*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to bfloat16 type."""
319*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.bfloat16)
320*da0073e9SAndroid Build Coastguard Worker
321*da0073e9SAndroid Build Coastguard Worker    def complex_double(self):
322*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to complex double type."""
323*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.cdouble)
324*da0073e9SAndroid Build Coastguard Worker
325*da0073e9SAndroid Build Coastguard Worker    def complex_float(self):
326*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to complex float type."""
327*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.cfloat)
328*da0073e9SAndroid Build Coastguard Worker
329*da0073e9SAndroid Build Coastguard Worker    def float8_e5m2(self):
330*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to float8_e5m2 type"""
331*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.float8_e5m2)
332*da0073e9SAndroid Build Coastguard Worker
333*da0073e9SAndroid Build Coastguard Worker    def float8_e4m3fn(self):
334*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to float8_e4m3fn type"""
335*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.float8_e4m3fn)
336*da0073e9SAndroid Build Coastguard Worker
337*da0073e9SAndroid Build Coastguard Worker    def float8_e5m2fnuz(self):
338*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to float8_e5m2fnuz type"""
339*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.float8_e5m2fnuz)
340*da0073e9SAndroid Build Coastguard Worker
341*da0073e9SAndroid Build Coastguard Worker    def float8_e4m3fnuz(self):
342*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to float8_e4m3fnuz type"""
343*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.float8_e4m3fnuz)
344*da0073e9SAndroid Build Coastguard Worker
345*da0073e9SAndroid Build Coastguard Worker    def is_pinned(self, device: Union[str, torch.device] = "cuda"):
346*da0073e9SAndroid Build Coastguard Worker        r"""Determine whether the CPU storage is already pinned on device.
347*da0073e9SAndroid Build Coastguard Worker
348*da0073e9SAndroid Build Coastguard Worker        Args:
349*da0073e9SAndroid Build Coastguard Worker            device (str or torch.device): The device to pin memory on. Default: ``'cuda'``.
350*da0073e9SAndroid Build Coastguard Worker
351*da0073e9SAndroid Build Coastguard Worker        Returns:
352*da0073e9SAndroid Build Coastguard Worker            A boolean variable.
353*da0073e9SAndroid Build Coastguard Worker        """
354*da0073e9SAndroid Build Coastguard Worker        return (
355*da0073e9SAndroid Build Coastguard Worker            torch.tensor([], dtype=torch.uint8, device=self.device)
356*da0073e9SAndroid Build Coastguard Worker            .set_(cast(Storage, self))
357*da0073e9SAndroid Build Coastguard Worker            .is_pinned(device)
358*da0073e9SAndroid Build Coastguard Worker        )
359*da0073e9SAndroid Build Coastguard Worker
360*da0073e9SAndroid Build Coastguard Worker    def pin_memory(self, device: Union[str, torch.device] = "cuda"):
361*da0073e9SAndroid Build Coastguard Worker        r"""Copy the CPU storage to pinned memory, if it's not already pinned.
362*da0073e9SAndroid Build Coastguard Worker
363*da0073e9SAndroid Build Coastguard Worker        Args:
364*da0073e9SAndroid Build Coastguard Worker            device (str or torch.device): The device to pin memory on. Default: ``'cuda'``.
365*da0073e9SAndroid Build Coastguard Worker
366*da0073e9SAndroid Build Coastguard Worker        Returns:
367*da0073e9SAndroid Build Coastguard Worker            A pinned CPU storage.
368*da0073e9SAndroid Build Coastguard Worker        """
369*da0073e9SAndroid Build Coastguard Worker        if self.device.type != "cpu":
370*da0073e9SAndroid Build Coastguard Worker            raise TypeError(f"cannot pin '{self.type()}' only CPU memory can be pinned")
371*da0073e9SAndroid Build Coastguard Worker
372*da0073e9SAndroid Build Coastguard Worker        pinned_tensor = (
373*da0073e9SAndroid Build Coastguard Worker            torch.tensor([], dtype=torch.uint8, device=self.device)
374*da0073e9SAndroid Build Coastguard Worker            .set_(cast(Storage, self))
375*da0073e9SAndroid Build Coastguard Worker            .pin_memory(device)
376*da0073e9SAndroid Build Coastguard Worker        )
377*da0073e9SAndroid Build Coastguard Worker        return pinned_tensor.untyped_storage()
378*da0073e9SAndroid Build Coastguard Worker
379*da0073e9SAndroid Build Coastguard Worker    def share_memory_(self):
380*da0073e9SAndroid Build Coastguard Worker        """See :meth:`torch.UntypedStorage.share_memory_`"""
381*da0073e9SAndroid Build Coastguard Worker        from torch.multiprocessing import get_sharing_strategy
382*da0073e9SAndroid Build Coastguard Worker
383*da0073e9SAndroid Build Coastguard Worker        if self.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
384*da0073e9SAndroid Build Coastguard Worker            pass  # CUDA or PrivateUse1 doesn't use POSIX shared memory
385*da0073e9SAndroid Build Coastguard Worker        elif get_sharing_strategy() == "file_system":
386*da0073e9SAndroid Build Coastguard Worker            self._share_filename_cpu_()
387*da0073e9SAndroid Build Coastguard Worker        else:
388*da0073e9SAndroid Build Coastguard Worker            self._share_fd_cpu_()
389*da0073e9SAndroid Build Coastguard Worker        return self
390*da0073e9SAndroid Build Coastguard Worker
391*da0073e9SAndroid Build Coastguard Worker    @classmethod
392*da0073e9SAndroid Build Coastguard Worker    def _new_shared(cls, size, *, device="cpu"):
393*da0073e9SAndroid Build Coastguard Worker        """Create a new storage in shared memory with the same data type."""
394*da0073e9SAndroid Build Coastguard Worker        from torch.multiprocessing import get_sharing_strategy
395*da0073e9SAndroid Build Coastguard Worker
396*da0073e9SAndroid Build Coastguard Worker        device = torch.device(device)
397*da0073e9SAndroid Build Coastguard Worker        if device.type in ["cuda", torch._C._get_privateuse1_backend_name(), "hpu"]:
398*da0073e9SAndroid Build Coastguard Worker            return cls(size, device=device)
399*da0073e9SAndroid Build Coastguard Worker        elif get_sharing_strategy() == "file_system":
400*da0073e9SAndroid Build Coastguard Worker            return cls._new_using_filename_cpu(size)
401*da0073e9SAndroid Build Coastguard Worker        else:
402*da0073e9SAndroid Build Coastguard Worker            return cls._new_using_fd_cpu(size)
403*da0073e9SAndroid Build Coastguard Worker
404*da0073e9SAndroid Build Coastguard Worker    def untyped(self):
405*da0073e9SAndroid Build Coastguard Worker        return self
406*da0073e9SAndroid Build Coastguard Worker
407*da0073e9SAndroid Build Coastguard Worker    def byteswap(self, dtype):
408*da0073e9SAndroid Build Coastguard Worker        """Swap bytes in underlying data."""
409*da0073e9SAndroid Build Coastguard Worker        elem_size = torch._utils._element_size(dtype)
410*da0073e9SAndroid Build Coastguard Worker        # for complex types, don't swap first and second numbers
411*da0073e9SAndroid Build Coastguard Worker        if dtype.is_complex:
412*da0073e9SAndroid Build Coastguard Worker            elem_size = max(int(elem_size / 2), 1)
413*da0073e9SAndroid Build Coastguard Worker        self._byteswap(elem_size)
414*da0073e9SAndroid Build Coastguard Worker
415*da0073e9SAndroid Build Coastguard Worker
def _share_memory_lock_protected(fn):
    """Wrap a storage method so concurrent calls on the same storage are serialized.

    Calls are keyed by ``self._cdata``. The first caller for a key registers a
    new ``threading.RLock`` in ``_share_memory_map``, holds it while ``fn``
    runs, and then releases it and removes the map entry. A caller that finds
    an existing entry first waits until that lock is released, then calls
    ``fn`` itself.

    Args:
        fn: The method to wrap; invoked as ``fn(self, *args, **kwargs)``.

    Returns:
        A wrapper (metadata copied via ``functools.wraps``) that returns
        whatever ``fn`` returns.
    """
    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        # to_free: key whose lock this call created and must release.
        # to_wait: another call's lock to wait on before running fn.
        to_free = None
        to_wait = None
        # The global lock only protects the map; it is not held while fn runs.
        with _share_memory_lock:
            key = self._cdata
            if key in _share_memory_map:
                to_wait = _share_memory_map[key]
            else:
                # Register and immediately take a per-storage lock so later
                # callers for the same key block until this call finishes.
                _share_memory_map[key] = threading.RLock()
                _share_memory_map[key].acquire()
                to_free = key

        # If we're already in the process of sharing the storage, wait
        # for it to be done.
        # Because the per-storage lock is an RLock, a re-entrant call from the
        # thread that already owns it acquires it again instead of deadlocking.
        # NOTE(review): the waiting caller then runs fn without holding the
        # per-storage lock; presumably fn is expected to be cheap/idempotent
        # once the first caller has finished — confirm.
        if to_wait is not None:
            with to_wait:
                pass

        try:
            return fn(self, *args, **kwargs)
        finally:
            # If we acquired the storage lock here and we're done working on it
            # we can now release it and free the entry.
            if to_free is not None:
                # Ensure that the cdata from the storage didn't change and only
                # the data_ptr did.
                # NOTE(review): if this assert fails, the lock is never released
                # and the map entry is never removed; under `python -O` the
                # check is skipped entirely.
                assert self._cdata == to_free
                with _share_memory_lock:
                    _share_memory_map[to_free].release()
                    del _share_memory_map[to_free]

    return wrapper
450*da0073e9SAndroid Build Coastguard Worker
451*da0073e9SAndroid Build Coastguard Worker
class UntypedStorage(torch._C.StorageBase, _StorageBase):
    # Python-level additions on top of the C++ StorageBase:
    # - a guard rejecting element access on meta storages;
    # - device predicates and the ``filename`` property;
    # - lock-protected wrappers around the shared-memory entry points.

    def __getitem__(self, *args, **kwargs):
        # Meta storages have no backing data, so indexing is refused.
        if self.device.type != "meta":
            return super().__getitem__(*args, **kwargs)
        raise NotImplementedError("Not available for 'meta' device type")

    @property
    def is_cuda(self):
        device_type = self.device.type
        return device_type == "cuda"

    @property
    def is_hpu(self):
        device_type = self.device.type
        return device_type == "hpu"

    @property
    def filename(self) -> _Optional[str]:
        """Returns the name of the file backing this storage, if any.

        This is a string only for CPU storages created via :meth:`~torch.from_file()`
        with ``shared`` set to ``True``; in every other case it is ``None``.
        """
        return self._get_filename()

    @_share_memory_lock_protected
    def share_memory_(self, *args, **kwargs):
        """
        Moves the storage to shared memory.

        This is a no-op for storages already in shared memory and for CUDA
        storages, which do not need to be moved for sharing across processes.
        Storages in shared memory cannot be resized.

        To mitigate issues like `this <https://github.com/pytorch/pytorch/issues/95606>`_,
        it is thread safe to call this function from multiple threads on the same object.
        It is NOT thread safe though to call any other function on self without proper
        synchronization. Please see :doc:`/notes/multiprocessing` for more details.

        .. note::
            When all references to a storage in shared memory are deleted, the associated shared memory
            object will also be deleted. PyTorch has a special cleanup process to ensure that this happens
            even if the current process exits unexpectedly.

            It is worth noting the difference between :meth:`share_memory_` and :meth:`from_file` with ``shared = True``

            #. ``share_memory_`` uses `shm_open(3) <https://man7.org/linux/man-pages/man3/shm_open.3.html>`_ to create a
               POSIX shared memory object while :meth:`from_file` uses
               `open(2) <https://man7.org/linux/man-pages/man2/open.2.html>`_ to open the filename passed by the user.
            #. Both use an `mmap(2) call <https://man7.org/linux/man-pages/man2/mmap.2.html>`_ with ``MAP_SHARED``
               to map the file/object into the current virtual address space
            #. ``share_memory_`` will call ``shm_unlink(3)`` on the object after mapping it to make sure the shared memory
               object is freed when no process has the object open. ``torch.from_file(shared=True)`` does not unlink the
               file. This file is persistent and will remain until it is deleted by the user.

        Returns:
            ``self``
        """
        return super().share_memory_(*args, **kwargs)

    @_share_memory_lock_protected
    def _share_fd_cpu_(self, *args, **kwargs):
        # Same per-storage locking as share_memory_.
        return super()._share_fd_cpu_(*args, **kwargs)

    @_share_memory_lock_protected
    def _share_filename_cpu_(self, *args, **kwargs):
        # Same per-storage locking as share_memory_.
        return super()._share_filename_cpu_(*args, **kwargs)
517*da0073e9SAndroid Build Coastguard Worker
518*da0073e9SAndroid Build Coastguard Worker
519*da0073e9SAndroid Build Coastguard Workerdef _load_from_bytes(b):
520*da0073e9SAndroid Build Coastguard Worker    return torch.load(io.BytesIO(b), weights_only=False)
521*da0073e9SAndroid Build Coastguard Worker
522*da0073e9SAndroid Build Coastguard Worker
523*da0073e9SAndroid Build Coastguard Worker@functools.lru_cache(maxsize=None)
524*da0073e9SAndroid Build Coastguard Workerdef _new_dtypes():
525*da0073e9SAndroid Build Coastguard Worker    # These are dtypes serialized as UntypedStorage unlike those in
526*da0073e9SAndroid Build Coastguard Worker    # _dtype_to_storage_type_map
527*da0073e9SAndroid Build Coastguard Worker    return {
528*da0073e9SAndroid Build Coastguard Worker        torch.float8_e5m2,
529*da0073e9SAndroid Build Coastguard Worker        torch.float8_e4m3fn,
530*da0073e9SAndroid Build Coastguard Worker        torch.float8_e5m2fnuz,
531*da0073e9SAndroid Build Coastguard Worker        torch.float8_e4m3fnuz,
532*da0073e9SAndroid Build Coastguard Worker        torch.bits8,
533*da0073e9SAndroid Build Coastguard Worker        torch.bits16,
534*da0073e9SAndroid Build Coastguard Worker        torch.bits1x8,
535*da0073e9SAndroid Build Coastguard Worker        torch.bits2x4,
536*da0073e9SAndroid Build Coastguard Worker        torch.bits4x2,
537*da0073e9SAndroid Build Coastguard Worker        torch.complex32,
538*da0073e9SAndroid Build Coastguard Worker    }
539*da0073e9SAndroid Build Coastguard Worker
540*da0073e9SAndroid Build Coastguard Worker
541*da0073e9SAndroid Build Coastguard Worker@functools.lru_cache(maxsize=None)
542*da0073e9SAndroid Build Coastguard Workerdef _dtype_to_storage_type_map():
543*da0073e9SAndroid Build Coastguard Worker    # NOTE: We should no longer add dtypes to this map. This map
544*da0073e9SAndroid Build Coastguard Worker    # is only used for BC/FC with older PyTorch versions. Going forward,
545*da0073e9SAndroid Build Coastguard Worker    # new dtypes of TypedStorage should not translate to a legacy
546*da0073e9SAndroid Build Coastguard Worker    # <type>Storage class. Instead, new dtypes of TypedStorage should
547*da0073e9SAndroid Build Coastguard Worker    # be serialized as an UntypedStorage paired with a torch.dtype
548*da0073e9SAndroid Build Coastguard Worker    return {
549*da0073e9SAndroid Build Coastguard Worker        torch.double: "DoubleStorage",
550*da0073e9SAndroid Build Coastguard Worker        torch.float: "FloatStorage",
551*da0073e9SAndroid Build Coastguard Worker        torch.half: "HalfStorage",
552*da0073e9SAndroid Build Coastguard Worker        torch.long: "LongStorage",
553*da0073e9SAndroid Build Coastguard Worker        torch.int: "IntStorage",
554*da0073e9SAndroid Build Coastguard Worker        torch.int16: "ShortStorage",
555*da0073e9SAndroid Build Coastguard Worker        torch.int8: "CharStorage",
556*da0073e9SAndroid Build Coastguard Worker        torch.uint8: "ByteStorage",
557*da0073e9SAndroid Build Coastguard Worker        torch.bool: "BoolStorage",
558*da0073e9SAndroid Build Coastguard Worker        torch.bfloat16: "BFloat16Storage",
559*da0073e9SAndroid Build Coastguard Worker        torch.cdouble: "ComplexDoubleStorage",
560*da0073e9SAndroid Build Coastguard Worker        torch.cfloat: "ComplexFloatStorage",
561*da0073e9SAndroid Build Coastguard Worker        torch.qint8: "QInt8Storage",
562*da0073e9SAndroid Build Coastguard Worker        torch.qint32: "QInt32Storage",
563*da0073e9SAndroid Build Coastguard Worker        torch.quint8: "QUInt8Storage",
564*da0073e9SAndroid Build Coastguard Worker        torch.quint4x2: "QUInt4x2Storage",
565*da0073e9SAndroid Build Coastguard Worker        torch.quint2x4: "QUInt2x4Storage",
566*da0073e9SAndroid Build Coastguard Worker    }
567*da0073e9SAndroid Build Coastguard Worker
568*da0073e9SAndroid Build Coastguard Worker
@functools.lru_cache(maxsize=None)
def _storage_type_to_dtype_map():
    """Inverse of ``_dtype_to_storage_type_map()``: legacy storage class name -> dtype."""
    return {
        storage_name: dtype
        for dtype, storage_name in _dtype_to_storage_type_map().items()
    }
573*da0073e9SAndroid Build Coastguard Worker
574*da0073e9SAndroid Build Coastguard Worker
575*da0073e9SAndroid Build Coastguard Workerdef _get_storage_from_sequence(sequence, dtype, device):
576*da0073e9SAndroid Build Coastguard Worker    if dtype in [
577*da0073e9SAndroid Build Coastguard Worker        torch.quint8,
578*da0073e9SAndroid Build Coastguard Worker        torch.quint4x2,
579*da0073e9SAndroid Build Coastguard Worker        torch.quint2x4,
580*da0073e9SAndroid Build Coastguard Worker        torch.qint32,
581*da0073e9SAndroid Build Coastguard Worker        torch.qint8,
582*da0073e9SAndroid Build Coastguard Worker    ]:
583*da0073e9SAndroid Build Coastguard Worker        interpret_dtypes = {
584*da0073e9SAndroid Build Coastguard Worker            torch.quint8: torch.uint8,
585*da0073e9SAndroid Build Coastguard Worker            torch.quint4x2: torch.uint8,
586*da0073e9SAndroid Build Coastguard Worker            torch.quint2x4: torch.uint8,
587*da0073e9SAndroid Build Coastguard Worker            torch.qint32: torch.int32,
588*da0073e9SAndroid Build Coastguard Worker            torch.qint8: torch.int8,
589*da0073e9SAndroid Build Coastguard Worker        }
590*da0073e9SAndroid Build Coastguard Worker        tmp_tensor = torch.tensor(
591*da0073e9SAndroid Build Coastguard Worker            sequence, dtype=interpret_dtypes[dtype], device=device
592*da0073e9SAndroid Build Coastguard Worker        )
593*da0073e9SAndroid Build Coastguard Worker
594*da0073e9SAndroid Build Coastguard Worker    else:
595*da0073e9SAndroid Build Coastguard Worker        tmp_tensor = torch.tensor(sequence, dtype=dtype, device=device)
596*da0073e9SAndroid Build Coastguard Worker
597*da0073e9SAndroid Build Coastguard Worker    return tmp_tensor._typed_storage()._untyped_storage
598*da0073e9SAndroid Build Coastguard Worker
599*da0073e9SAndroid Build Coastguard Worker
600*da0073e9SAndroid Build Coastguard Workerdef _isint(x):
601*da0073e9SAndroid Build Coastguard Worker    if HAS_NUMPY:
602*da0073e9SAndroid Build Coastguard Worker        return isinstance(x, (int, np.integer))
603*da0073e9SAndroid Build Coastguard Worker    else:
604*da0073e9SAndroid Build Coastguard Worker        return isinstance(x, int)
605*da0073e9SAndroid Build Coastguard Worker
606*da0073e9SAndroid Build Coastguard Worker
# When True, _warn_typed_storage_removal emits the TypedStorage deprecation
# warning on every call instead of only on the first one. Read via
# _get_always_warn_typed_storage_removal and set via
# _set_always_warn_typed_storage_removal.
_always_warn_typed_storage_removal = False
608*da0073e9SAndroid Build Coastguard Worker
609*da0073e9SAndroid Build Coastguard Worker
610*da0073e9SAndroid Build Coastguard Workerdef _get_always_warn_typed_storage_removal():
611*da0073e9SAndroid Build Coastguard Worker    return _always_warn_typed_storage_removal
612*da0073e9SAndroid Build Coastguard Worker
613*da0073e9SAndroid Build Coastguard Worker
614*da0073e9SAndroid Build Coastguard Workerdef _set_always_warn_typed_storage_removal(always_warn):
615*da0073e9SAndroid Build Coastguard Worker    global _always_warn_typed_storage_removal
616*da0073e9SAndroid Build Coastguard Worker    assert isinstance(always_warn, bool)
617*da0073e9SAndroid Build Coastguard Worker    _always_warn_typed_storage_removal = always_warn
618*da0073e9SAndroid Build Coastguard Worker
619*da0073e9SAndroid Build Coastguard Worker
def _warn_typed_storage_removal(stacklevel=2):
    """Emit the ``TypedStorage`` deprecation ``UserWarning``.

    The first call records ``has_warned`` on this function object, and later
    calls stay silent. The exception is when
    ``_get_always_warn_typed_storage_removal()`` is True, in which case every
    call warns.

    Args:
        stacklevel (int): passed to :func:`warnings.warn` plus one, so it is
            counted from the caller of this function rather than this frame.
    """
    always = _get_always_warn_typed_storage_removal()
    if not always and getattr(_warn_typed_storage_removal, "has_warned", False):
        # Already warned once and repeated warnings are disabled.
        return

    warnings.warn(
        "TypedStorage is deprecated. It will be removed in the future and "
        "UntypedStorage will be the only storage class. This should only matter "
        "to you if you are using storages directly.  To access UntypedStorage "
        "directly, use tensor.untyped_storage() instead of tensor.storage()",
        UserWarning,
        stacklevel=stacklevel + 1,
    )
    _warn_typed_storage_removal.__dict__["has_warned"] = True
638*da0073e9SAndroid Build Coastguard Worker
639*da0073e9SAndroid Build Coastguard Worker
def _reset_warn_typed_storage_removal():
    """Clear the ``has_warned`` flag so the next deprecation check warns again."""
    vars(_warn_typed_storage_removal)["has_warned"] = False
642*da0073e9SAndroid Build Coastguard Worker
643*da0073e9SAndroid Build Coastguard Worker
644*da0073e9SAndroid Build Coastguard Workerdef _get_device_from_module(module: str):
645*da0073e9SAndroid Build Coastguard Worker    last_part = module.rsplit(".", 1)[-1]
646*da0073e9SAndroid Build Coastguard Worker    if last_part in ["cuda", torch._C._get_privateuse1_backend_name(), "hpu"]:
647*da0073e9SAndroid Build Coastguard Worker        return last_part
648*da0073e9SAndroid Build Coastguard Worker    else:
649*da0073e9SAndroid Build Coastguard Worker        return "cpu"
650*da0073e9SAndroid Build Coastguard Worker
651*da0073e9SAndroid Build Coastguard Worker
652*da0073e9SAndroid Build Coastguard Workerclass TypedStorage:
    # Class-level default: TypedStorage instances are not sparse.
    is_sparse: _bool = False
    # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True)
    _fake_device: _Optional[torch.device] = None

    # Element dtype. Only annotated here; __init__ assigns it per instance.
    dtype: torch.dtype
658*da0073e9SAndroid Build Coastguard Worker
659*da0073e9SAndroid Build Coastguard Worker    @property
660*da0073e9SAndroid Build Coastguard Worker    def _dtype(self):
661*da0073e9SAndroid Build Coastguard Worker        return self.dtype
662*da0073e9SAndroid Build Coastguard Worker
663*da0073e9SAndroid Build Coastguard Worker    @property
664*da0073e9SAndroid Build Coastguard Worker    def filename(self) -> _Optional[str]:
665*da0073e9SAndroid Build Coastguard Worker        """Returns the file name associated with this storage if the storage was memory mapped from a file.
666*da0073e9SAndroid Build Coastguard Worker        or ``None`` if the storage was not created by memory mapping a file."""
667*da0073e9SAndroid Build Coastguard Worker        return self._untyped_storage.filename
668*da0073e9SAndroid Build Coastguard Worker
669*da0073e9SAndroid Build Coastguard Worker    def fill_(self, value):
670*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
671*da0073e9SAndroid Build Coastguard Worker        self._setitem(slice(0, self._size()), value)
672*da0073e9SAndroid Build Coastguard Worker        return self
673*da0073e9SAndroid Build Coastguard Worker
674*da0073e9SAndroid Build Coastguard Worker    def __new__(
675*da0073e9SAndroid Build Coastguard Worker        cls,
676*da0073e9SAndroid Build Coastguard Worker        *args,
677*da0073e9SAndroid Build Coastguard Worker        wrap_storage=None,
678*da0073e9SAndroid Build Coastguard Worker        dtype=None,
679*da0073e9SAndroid Build Coastguard Worker        device=None,
680*da0073e9SAndroid Build Coastguard Worker        _internal=False,
681*da0073e9SAndroid Build Coastguard Worker    ):
682*da0073e9SAndroid Build Coastguard Worker        if not _internal:
683*da0073e9SAndroid Build Coastguard Worker            _warn_typed_storage_removal()
684*da0073e9SAndroid Build Coastguard Worker
685*da0073e9SAndroid Build Coastguard Worker        if cls == torch.storage._LegacyStorage:
686*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
687*da0073e9SAndroid Build Coastguard Worker                "Only child classes of _LegacyStorage can be instantiated"
688*da0073e9SAndroid Build Coastguard Worker            )
689*da0073e9SAndroid Build Coastguard Worker
690*da0073e9SAndroid Build Coastguard Worker        if cls == TypedStorage:
691*da0073e9SAndroid Build Coastguard Worker            return super().__new__(cls)
692*da0073e9SAndroid Build Coastguard Worker
693*da0073e9SAndroid Build Coastguard Worker        else:
694*da0073e9SAndroid Build Coastguard Worker            arg_error_msg = (
695*da0073e9SAndroid Build Coastguard Worker                f"{cls}.__new__ received an invalid combination "
696*da0073e9SAndroid Build Coastguard Worker                f"of arguments. Expected one of:\n"
697*da0073e9SAndroid Build Coastguard Worker                " * no arguments\n"
698*da0073e9SAndroid Build Coastguard Worker                " * (int size)\n"
699*da0073e9SAndroid Build Coastguard Worker                " * (Sequence data)\n"
700*da0073e9SAndroid Build Coastguard Worker                " * (*, UntypedStorage wrap_storage)"
701*da0073e9SAndroid Build Coastguard Worker            )
702*da0073e9SAndroid Build Coastguard Worker
703*da0073e9SAndroid Build Coastguard Worker            if device is not None:
704*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
705*da0073e9SAndroid Build Coastguard Worker                    arg_error_msg + "\nKeyword argument 'device' cannot be specified"
706*da0073e9SAndroid Build Coastguard Worker                )
707*da0073e9SAndroid Build Coastguard Worker
708*da0073e9SAndroid Build Coastguard Worker            if dtype is not None:
709*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
710*da0073e9SAndroid Build Coastguard Worker                    arg_error_msg + "\nKeyword argument 'dtype' cannot be specified"
711*da0073e9SAndroid Build Coastguard Worker                )
712*da0073e9SAndroid Build Coastguard Worker
713*da0073e9SAndroid Build Coastguard Worker            if wrap_storage is None:
714*da0073e9SAndroid Build Coastguard Worker                if len(args) > 1:
715*da0073e9SAndroid Build Coastguard Worker                    raise RuntimeError(
716*da0073e9SAndroid Build Coastguard Worker                        arg_error_msg + "\nToo many positional arguments"
717*da0073e9SAndroid Build Coastguard Worker                    )
718*da0073e9SAndroid Build Coastguard Worker
719*da0073e9SAndroid Build Coastguard Worker                if (
720*da0073e9SAndroid Build Coastguard Worker                    len(args) == 1
721*da0073e9SAndroid Build Coastguard Worker                    and not _isint(args[0])
722*da0073e9SAndroid Build Coastguard Worker                    and not isinstance(args[0], collections.abc.Sequence)
723*da0073e9SAndroid Build Coastguard Worker                ):
724*da0073e9SAndroid Build Coastguard Worker                    raise TypeError(
725*da0073e9SAndroid Build Coastguard Worker                        arg_error_msg
726*da0073e9SAndroid Build Coastguard Worker                        + f"\nArgument type not recognized: {type(args[0])}"
727*da0073e9SAndroid Build Coastguard Worker                    )
728*da0073e9SAndroid Build Coastguard Worker
729*da0073e9SAndroid Build Coastguard Worker                return TypedStorage(
730*da0073e9SAndroid Build Coastguard Worker                    *args,
731*da0073e9SAndroid Build Coastguard Worker                    dtype=cls._dtype,
732*da0073e9SAndroid Build Coastguard Worker                    device=_get_device_from_module(cls.__module__),
733*da0073e9SAndroid Build Coastguard Worker                    _internal=True,
734*da0073e9SAndroid Build Coastguard Worker                )
735*da0073e9SAndroid Build Coastguard Worker
736*da0073e9SAndroid Build Coastguard Worker            else:
737*da0073e9SAndroid Build Coastguard Worker                if len(args) != 0:
738*da0073e9SAndroid Build Coastguard Worker                    raise RuntimeError(
739*da0073e9SAndroid Build Coastguard Worker                        arg_error_msg
740*da0073e9SAndroid Build Coastguard Worker                        + "\nNo positional arguments should be given when using "
741*da0073e9SAndroid Build Coastguard Worker                        "'wrap_storage'"
742*da0073e9SAndroid Build Coastguard Worker                    )
743*da0073e9SAndroid Build Coastguard Worker
744*da0073e9SAndroid Build Coastguard Worker                if not isinstance(wrap_storage, torch.UntypedStorage):
745*da0073e9SAndroid Build Coastguard Worker                    raise TypeError(
746*da0073e9SAndroid Build Coastguard Worker                        arg_error_msg
747*da0073e9SAndroid Build Coastguard Worker                        + f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}"
748*da0073e9SAndroid Build Coastguard Worker                    )
749*da0073e9SAndroid Build Coastguard Worker
750*da0073e9SAndroid Build Coastguard Worker                cls_device = _get_device_from_module(cls.__module__)
751*da0073e9SAndroid Build Coastguard Worker
752*da0073e9SAndroid Build Coastguard Worker                if wrap_storage.device.type != cls_device:
753*da0073e9SAndroid Build Coastguard Worker                    raise RuntimeError(
754*da0073e9SAndroid Build Coastguard Worker                        arg_error_msg
755*da0073e9SAndroid Build Coastguard Worker                        + f"\nDevice of 'wrap_storage' must be {cls_device}"
756*da0073e9SAndroid Build Coastguard Worker                        f", but got {wrap_storage.device.type}"
757*da0073e9SAndroid Build Coastguard Worker                    )
758*da0073e9SAndroid Build Coastguard Worker
759*da0073e9SAndroid Build Coastguard Worker                return TypedStorage(
760*da0073e9SAndroid Build Coastguard Worker                    *args,
761*da0073e9SAndroid Build Coastguard Worker                    wrap_storage=wrap_storage,
762*da0073e9SAndroid Build Coastguard Worker                    dtype=cls.dtype,
763*da0073e9SAndroid Build Coastguard Worker                    _internal=True,
764*da0073e9SAndroid Build Coastguard Worker                )
765*da0073e9SAndroid Build Coastguard Worker
    def __init__(
        self,
        *args,
        device=None,
        dtype=None,
        wrap_storage=None,
        _internal=False,
    ):
        """Initialize from nothing, a size, a sequence, or an existing UntypedStorage.

        Accepted forms (mirrored in ``arg_error_msg`` below):

        * no positional args: empty storage on ``device``;
        * ``(int size)``: room for ``size`` elements of ``dtype``;
        * ``(Sequence data)``: storage holding ``data`` converted to ``dtype``;
        * ``wrap_storage=UntypedStorage`` plus ``dtype``: keep a reference to
          that storage (no copy) and interpret it as ``dtype``. No positional
          args and no ``device`` are allowed in this form.

        When not wrapping, ``dtype`` defaults to ``torch.get_default_dtype()`` and
        ``device`` defaults to CPU.

        Args:
            _internal (bool): if True, skip the TypedStorage deprecation warning.

        Raises:
            RuntimeError: on an invalid argument combination, or when a quantized
                dtype is requested on a CUDA device.
            TypeError: when ``dtype``, ``wrap_storage`` or the positional argument
                has an unsupported type.
        """
        if not _internal:
            _warn_typed_storage_removal()
        arg_error_msg = (
            "TypedStorage.__init__ received an invalid combination "
            "of arguments. Expected one of:\n"
            " * (*, torch.device device, torch.dtype dtype)\n"
            " * (int size, *, torch.device device, torch.dtype dtype)\n"
            " * (Sequence data, *, torch.device device, torch.dtype dtype)\n"
            " * (*, UntypedStorage wrap_storage, torch.dtype dtype)"
        )

        if wrap_storage is not None:
            # Wrapping form: validate the arguments in this fixed order; the first
            # failing check determines the reported error.
            if len(args) != 0:
                raise RuntimeError(
                    arg_error_msg
                    + "\nNo positional arguments should be given when using "
                    "'wrap_storage'"
                )

            if dtype is None:
                raise RuntimeError(
                    arg_error_msg + "\nArgument 'dtype' must be specified"
                )

            if not isinstance(dtype, torch.dtype):
                raise TypeError(
                    arg_error_msg
                    + f"\nArgument 'dtype' must be torch.dtype, not {type(dtype)}"
                )

            if device is not None:
                raise RuntimeError(
                    arg_error_msg
                    + "\nArgument 'device' should not be specified when 'wrap_storage' is given"
                )

            # NOTE(review): dtype is assigned before wrap_storage's type is checked,
            # so the TypeError below leaves a partially initialized object.
            self.dtype = dtype

            if not isinstance(wrap_storage, torch.UntypedStorage):
                raise TypeError(
                    arg_error_msg
                    + f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}"
                )

            # Reference the given storage directly; no copy is made.
            self._untyped_storage = wrap_storage

        else:
            self.dtype = torch.get_default_dtype() if dtype is None else dtype
            device = torch.device("cpu" if device is None else device)

            # Quantized dtypes are rejected on CUDA devices.
            if self.dtype in [
                torch.quint8,
                torch.quint4x2,
                torch.quint2x4,
                torch.qint32,
                torch.qint8,
            ]:
                if device.type == "cuda":
                    raise RuntimeError(
                        "Cannot create CUDA storage with quantized dtype"
                    )

            if len(args) == 0:
                self._untyped_storage = torch.UntypedStorage(device=device)

            elif len(args) == 1:
                if _isint(args[0]):
                    # The size is an element count; scale it by _element_size()
                    # for the untyped allocation.
                    self._untyped_storage = torch.UntypedStorage(
                        int(args[0]) * self._element_size(), device=device
                    )
                elif isinstance(args[0], collections.abc.Sequence):
                    self._untyped_storage = _get_storage_from_sequence(
                        args[0], self.dtype, device
                    )
                else:
                    raise TypeError(
                        arg_error_msg
                        + f"\nArgument type not recognized: {type(args[0])}"
                    )

            else:
                raise RuntimeError(arg_error_msg + "\nToo many positional arguments")
856*da0073e9SAndroid Build Coastguard Worker
857*da0073e9SAndroid Build Coastguard Worker    @property
858*da0073e9SAndroid Build Coastguard Worker    def is_cuda(self):
859*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
860*da0073e9SAndroid Build Coastguard Worker        return self._untyped_storage.device.type == "cuda"
861*da0073e9SAndroid Build Coastguard Worker
862*da0073e9SAndroid Build Coastguard Worker    @property
863*da0073e9SAndroid Build Coastguard Worker    def is_hpu(self):
864*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
865*da0073e9SAndroid Build Coastguard Worker        return self._untyped_storage.device.type == "hpu"
866*da0073e9SAndroid Build Coastguard Worker
867*da0073e9SAndroid Build Coastguard Worker    def untyped(self):
868*da0073e9SAndroid Build Coastguard Worker        """Return the internal :class:`torch.UntypedStorage`."""
869*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
870*da0073e9SAndroid Build Coastguard Worker        return self._untyped_storage
871*da0073e9SAndroid Build Coastguard Worker
872*da0073e9SAndroid Build Coastguard Worker    def _new_wrapped_storage(self, untyped_storage) -> Self:
873*da0073e9SAndroid Build Coastguard Worker        assert type(untyped_storage) == torch.UntypedStorage
874*da0073e9SAndroid Build Coastguard Worker
875*da0073e9SAndroid Build Coastguard Worker        if type(self) == TypedStorage:
876*da0073e9SAndroid Build Coastguard Worker            return cast(
877*da0073e9SAndroid Build Coastguard Worker                Self,
878*da0073e9SAndroid Build Coastguard Worker                TypedStorage(
879*da0073e9SAndroid Build Coastguard Worker                    wrap_storage=untyped_storage, dtype=self.dtype, _internal=True
880*da0073e9SAndroid Build Coastguard Worker                ),
881*da0073e9SAndroid Build Coastguard Worker            )
882*da0073e9SAndroid Build Coastguard Worker        else:
883*da0073e9SAndroid Build Coastguard Worker            return type(self)(wrap_storage=untyped_storage)
884*da0073e9SAndroid Build Coastguard Worker
885*da0073e9SAndroid Build Coastguard Worker    def __len__(self):
886*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
887*da0073e9SAndroid Build Coastguard Worker        return self._size()
888*da0073e9SAndroid Build Coastguard Worker
889*da0073e9SAndroid Build Coastguard Worker    def _maybe_wrap_index(self, idx, is_stop=False):
890*da0073e9SAndroid Build Coastguard Worker        if idx is None:
891*da0073e9SAndroid Build Coastguard Worker            if is_stop:
892*da0073e9SAndroid Build Coastguard Worker                return self._size()
893*da0073e9SAndroid Build Coastguard Worker            else:
894*da0073e9SAndroid Build Coastguard Worker                return 0
895*da0073e9SAndroid Build Coastguard Worker
896*da0073e9SAndroid Build Coastguard Worker        else:
897*da0073e9SAndroid Build Coastguard Worker            if type(idx) != int:
898*da0073e9SAndroid Build Coastguard Worker                raise TypeError(f"can't index a {type(self)} with {type(idx)}")
899*da0073e9SAndroid Build Coastguard Worker            if is_stop:
900*da0073e9SAndroid Build Coastguard Worker                if (idx > self._size()) or (idx < -self._size()):
901*da0073e9SAndroid Build Coastguard Worker                    raise IndexError(
902*da0073e9SAndroid Build Coastguard Worker                        f"index {idx} out of range for storage of size {self.size()}"
903*da0073e9SAndroid Build Coastguard Worker                    )
904*da0073e9SAndroid Build Coastguard Worker                if idx > 0:
905*da0073e9SAndroid Build Coastguard Worker                    return idx
906*da0073e9SAndroid Build Coastguard Worker                else:
907*da0073e9SAndroid Build Coastguard Worker                    return idx % self._size()
908*da0073e9SAndroid Build Coastguard Worker            else:
909*da0073e9SAndroid Build Coastguard Worker                if (idx >= self._size()) or (idx < -self._size()):
910*da0073e9SAndroid Build Coastguard Worker                    raise IndexError(
911*da0073e9SAndroid Build Coastguard Worker                        f"index {idx} out of range for storage of size {self.size()}"
912*da0073e9SAndroid Build Coastguard Worker                    )
913*da0073e9SAndroid Build Coastguard Worker                return idx % self._size()
914*da0073e9SAndroid Build Coastguard Worker
915*da0073e9SAndroid Build Coastguard Worker    def __setitem__(self, idx, value):
916*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
917*da0073e9SAndroid Build Coastguard Worker        return self._setitem(idx, value)
918*da0073e9SAndroid Build Coastguard Worker
919*da0073e9SAndroid Build Coastguard Worker    def _setitem(self, idx, value):
920*da0073e9SAndroid Build Coastguard Worker        if not isinstance(idx, (int, slice)):
921*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(f"can't index a {type(self)} with {type(idx)}")
922*da0073e9SAndroid Build Coastguard Worker        if torch.is_storage(value):
923*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(f"cannot set item with value type {type(value)}")
924*da0073e9SAndroid Build Coastguard Worker        if self.dtype in [
925*da0073e9SAndroid Build Coastguard Worker            torch.quint8,
926*da0073e9SAndroid Build Coastguard Worker            torch.quint4x2,
927*da0073e9SAndroid Build Coastguard Worker            torch.quint2x4,
928*da0073e9SAndroid Build Coastguard Worker            torch.qint32,
929*da0073e9SAndroid Build Coastguard Worker            torch.qint8,
930*da0073e9SAndroid Build Coastguard Worker        ]:
931*da0073e9SAndroid Build Coastguard Worker            interpret_dtypes = {
932*da0073e9SAndroid Build Coastguard Worker                torch.quint8: torch.uint8,
933*da0073e9SAndroid Build Coastguard Worker                torch.quint4x2: torch.uint8,
934*da0073e9SAndroid Build Coastguard Worker                torch.quint2x4: torch.uint8,
935*da0073e9SAndroid Build Coastguard Worker                torch.qint32: torch.int32,
936*da0073e9SAndroid Build Coastguard Worker                torch.qint8: torch.int8,
937*da0073e9SAndroid Build Coastguard Worker            }
938*da0073e9SAndroid Build Coastguard Worker            tmp_dtype = interpret_dtypes[self.dtype]
939*da0073e9SAndroid Build Coastguard Worker            tmp_tensor = torch.tensor(
940*da0073e9SAndroid Build Coastguard Worker                [], dtype=tmp_dtype, device=self._untyped_storage.device
941*da0073e9SAndroid Build Coastguard Worker            )
942*da0073e9SAndroid Build Coastguard Worker            tmp_tensor.set_(
943*da0073e9SAndroid Build Coastguard Worker                TypedStorage(
944*da0073e9SAndroid Build Coastguard Worker                    wrap_storage=self._untyped_storage, dtype=tmp_dtype, _internal=True
945*da0073e9SAndroid Build Coastguard Worker                )
946*da0073e9SAndroid Build Coastguard Worker            )
947*da0073e9SAndroid Build Coastguard Worker        else:
948*da0073e9SAndroid Build Coastguard Worker            tmp_tensor = torch.tensor(
949*da0073e9SAndroid Build Coastguard Worker                [], dtype=self.dtype, device=self._untyped_storage.device
950*da0073e9SAndroid Build Coastguard Worker            ).set_(self)
951*da0073e9SAndroid Build Coastguard Worker
952*da0073e9SAndroid Build Coastguard Worker        tmp_tensor[idx] = value
953*da0073e9SAndroid Build Coastguard Worker
954*da0073e9SAndroid Build Coastguard Worker    def __getitem__(self, idx):
955*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
956*da0073e9SAndroid Build Coastguard Worker        return self._getitem(idx)
957*da0073e9SAndroid Build Coastguard Worker
958*da0073e9SAndroid Build Coastguard Worker    def _getitem(self, idx):
959*da0073e9SAndroid Build Coastguard Worker        if self._untyped_storage.device.type == "meta":
960*da0073e9SAndroid Build Coastguard Worker            raise NotImplementedError("Not available for 'meta' device type")
961*da0073e9SAndroid Build Coastguard Worker
962*da0073e9SAndroid Build Coastguard Worker        # NOTE: Before TypedStorage existed, indexing with a slice used to be
963*da0073e9SAndroid Build Coastguard Worker        # possible for <type>Storage objects. However, it would return
964*da0073e9SAndroid Build Coastguard Worker        # a storage view, which would be a hassle to implement in TypedStorage,
965*da0073e9SAndroid Build Coastguard Worker        # so it was disabled
966*da0073e9SAndroid Build Coastguard Worker        if isinstance(idx, slice):
967*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
968*da0073e9SAndroid Build Coastguard Worker                "slices are only supported in UntypedStorage.__getitem__"
969*da0073e9SAndroid Build Coastguard Worker            )
970*da0073e9SAndroid Build Coastguard Worker        elif not isinstance(idx, int):
971*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(f"can't index a {type(self)} with {type(idx)}")
972*da0073e9SAndroid Build Coastguard Worker
973*da0073e9SAndroid Build Coastguard Worker        if self.dtype in [
974*da0073e9SAndroid Build Coastguard Worker            torch.quint8,
975*da0073e9SAndroid Build Coastguard Worker            torch.quint4x2,
976*da0073e9SAndroid Build Coastguard Worker            torch.quint2x4,
977*da0073e9SAndroid Build Coastguard Worker            torch.qint32,
978*da0073e9SAndroid Build Coastguard Worker            torch.qint8,
979*da0073e9SAndroid Build Coastguard Worker        ]:
980*da0073e9SAndroid Build Coastguard Worker            interpret_dtypes = {
981*da0073e9SAndroid Build Coastguard Worker                torch.quint8: torch.uint8,
982*da0073e9SAndroid Build Coastguard Worker                torch.quint4x2: torch.uint8,
983*da0073e9SAndroid Build Coastguard Worker                torch.quint2x4: torch.uint8,
984*da0073e9SAndroid Build Coastguard Worker                torch.qint32: torch.int32,
985*da0073e9SAndroid Build Coastguard Worker                torch.qint8: torch.int8,
986*da0073e9SAndroid Build Coastguard Worker            }
987*da0073e9SAndroid Build Coastguard Worker            return TypedStorage(
988*da0073e9SAndroid Build Coastguard Worker                wrap_storage=self._untyped_storage,
989*da0073e9SAndroid Build Coastguard Worker                dtype=interpret_dtypes[self.dtype],
990*da0073e9SAndroid Build Coastguard Worker                _internal=True,
991*da0073e9SAndroid Build Coastguard Worker            )._getitem(idx)
992*da0073e9SAndroid Build Coastguard Worker
993*da0073e9SAndroid Build Coastguard Worker        idx_wrapped = self._maybe_wrap_index(idx)
994*da0073e9SAndroid Build Coastguard Worker        from torch._subclasses.fake_tensor import unset_fake_temporarily
995*da0073e9SAndroid Build Coastguard Worker
996*da0073e9SAndroid Build Coastguard Worker        with unset_fake_temporarily():
997*da0073e9SAndroid Build Coastguard Worker            tmp_tensor = torch.tensor(
998*da0073e9SAndroid Build Coastguard Worker                [], dtype=self.dtype, device=self._untyped_storage.device
999*da0073e9SAndroid Build Coastguard Worker            ).set_(self)
1000*da0073e9SAndroid Build Coastguard Worker            return tmp_tensor[idx_wrapped].item()
1001*da0073e9SAndroid Build Coastguard Worker
1002*da0073e9SAndroid Build Coastguard Worker    def copy_(self, source: T, non_blocking: _Optional[bool] = None):
1003*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1004*da0073e9SAndroid Build Coastguard Worker        if isinstance(source, TypedStorage):
1005*da0073e9SAndroid Build Coastguard Worker            self._untyped_storage.copy_(source._untyped_storage, non_blocking)
1006*da0073e9SAndroid Build Coastguard Worker        else:
1007*da0073e9SAndroid Build Coastguard Worker            self._untyped_storage.copy_(source, non_blocking)
1008*da0073e9SAndroid Build Coastguard Worker        return self
1009*da0073e9SAndroid Build Coastguard Worker
1010*da0073e9SAndroid Build Coastguard Worker    def nbytes(self):
1011*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1012*da0073e9SAndroid Build Coastguard Worker        return self._nbytes()
1013*da0073e9SAndroid Build Coastguard Worker
1014*da0073e9SAndroid Build Coastguard Worker    # For internal use only, to avoid deprecation warning
1015*da0073e9SAndroid Build Coastguard Worker    def _nbytes(self):
1016*da0073e9SAndroid Build Coastguard Worker        return self._untyped_storage.nbytes()
1017*da0073e9SAndroid Build Coastguard Worker
1018*da0073e9SAndroid Build Coastguard Worker    def type(
1019*da0073e9SAndroid Build Coastguard Worker        self,
1020*da0073e9SAndroid Build Coastguard Worker        dtype: _Optional[str] = None,
1021*da0073e9SAndroid Build Coastguard Worker        non_blocking: bool = False,
1022*da0073e9SAndroid Build Coastguard Worker    ) -> Union[_StorageBase, TypedStorage, str]:
1023*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1024*da0073e9SAndroid Build Coastguard Worker        if dtype is None:
1025*da0073e9SAndroid Build Coastguard Worker            legacy_class = self._get_legacy_storage_class()
1026*da0073e9SAndroid Build Coastguard Worker
1027*da0073e9SAndroid Build Coastguard Worker            if legacy_class is not None:
1028*da0073e9SAndroid Build Coastguard Worker                return legacy_class.__module__ + "." + legacy_class.__name__
1029*da0073e9SAndroid Build Coastguard Worker
1030*da0073e9SAndroid Build Coastguard Worker            return ".".join([self.__module__, type(self).__name__])
1031*da0073e9SAndroid Build Coastguard Worker
1032*da0073e9SAndroid Build Coastguard Worker        else:
1033*da0073e9SAndroid Build Coastguard Worker            return self._untyped_storage.type(dtype, non_blocking)
1034*da0073e9SAndroid Build Coastguard Worker
1035*da0073e9SAndroid Build Coastguard Worker    def cuda(self, device=None, non_blocking=False) -> Self:
1036*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1037*da0073e9SAndroid Build Coastguard Worker        if self.dtype in [
1038*da0073e9SAndroid Build Coastguard Worker            torch.quint8,
1039*da0073e9SAndroid Build Coastguard Worker            torch.quint4x2,
1040*da0073e9SAndroid Build Coastguard Worker            torch.quint2x4,
1041*da0073e9SAndroid Build Coastguard Worker            torch.qint32,
1042*da0073e9SAndroid Build Coastguard Worker            torch.qint8,
1043*da0073e9SAndroid Build Coastguard Worker        ]:
1044*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError("Cannot create CUDA storage with quantized dtype")
1045*da0073e9SAndroid Build Coastguard Worker        cuda_storage = self._untyped_storage.cuda(device, non_blocking)
1046*da0073e9SAndroid Build Coastguard Worker        return self._new_wrapped_storage(cuda_storage)
1047*da0073e9SAndroid Build Coastguard Worker
1048*da0073e9SAndroid Build Coastguard Worker    def hpu(self, device=None, non_blocking=False) -> Self:
1049*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1050*da0073e9SAndroid Build Coastguard Worker        if self.dtype in [
1051*da0073e9SAndroid Build Coastguard Worker            torch.quint8,
1052*da0073e9SAndroid Build Coastguard Worker            torch.quint4x2,
1053*da0073e9SAndroid Build Coastguard Worker            torch.quint2x4,
1054*da0073e9SAndroid Build Coastguard Worker            torch.qint32,
1055*da0073e9SAndroid Build Coastguard Worker            torch.qint8,
1056*da0073e9SAndroid Build Coastguard Worker        ]:
1057*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError("Cannot create HPU storage with quantized dtype")
1058*da0073e9SAndroid Build Coastguard Worker        hpu_storage = self._untyped_storage.hpu(device, non_blocking)
1059*da0073e9SAndroid Build Coastguard Worker        return self._new_wrapped_storage(hpu_storage)
1060*da0073e9SAndroid Build Coastguard Worker
1061*da0073e9SAndroid Build Coastguard Worker    def to(self, *, device: torch.device, non_blocking: bool = False) -> Self:
1062*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1063*da0073e9SAndroid Build Coastguard Worker        if self.dtype in [
1064*da0073e9SAndroid Build Coastguard Worker            torch.quint8,
1065*da0073e9SAndroid Build Coastguard Worker            torch.quint4x2,
1066*da0073e9SAndroid Build Coastguard Worker            torch.quint2x4,
1067*da0073e9SAndroid Build Coastguard Worker            torch.qint32,
1068*da0073e9SAndroid Build Coastguard Worker            torch.qint8,
1069*da0073e9SAndroid Build Coastguard Worker        ]:
1070*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
1071*da0073e9SAndroid Build Coastguard Worker                f"Cannot create {device.type.upper()} storage with quantized dtype"
1072*da0073e9SAndroid Build Coastguard Worker            )
1073*da0073e9SAndroid Build Coastguard Worker        to_storage = self._untyped_storage.to(device=device, non_blocking=non_blocking)
1074*da0073e9SAndroid Build Coastguard Worker        return self._new_wrapped_storage(to_storage)
1075*da0073e9SAndroid Build Coastguard Worker
1076*da0073e9SAndroid Build Coastguard Worker    def element_size(self):
1077*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1078*da0073e9SAndroid Build Coastguard Worker        return self._element_size()
1079*da0073e9SAndroid Build Coastguard Worker
1080*da0073e9SAndroid Build Coastguard Worker    # For internal use only, to avoid deprecation warning
1081*da0073e9SAndroid Build Coastguard Worker    def _element_size(self):
1082*da0073e9SAndroid Build Coastguard Worker        return torch._utils._element_size(self.dtype)
1083*da0073e9SAndroid Build Coastguard Worker
1084*da0073e9SAndroid Build Coastguard Worker    def get_device(self) -> _int:
1085*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1086*da0073e9SAndroid Build Coastguard Worker        return self._untyped_storage.get_device()
1087*da0073e9SAndroid Build Coastguard Worker
1088*da0073e9SAndroid Build Coastguard Worker    def __str__(self):
1089*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1090*da0073e9SAndroid Build Coastguard Worker        info_str = (
1091*da0073e9SAndroid Build Coastguard Worker            f"[{torch.typename(self)}(dtype={self.dtype}, "
1092*da0073e9SAndroid Build Coastguard Worker            f"device={self.device}) of size {len(self)}]"
1093*da0073e9SAndroid Build Coastguard Worker        )
1094*da0073e9SAndroid Build Coastguard Worker        if self.device.type == "meta":
1095*da0073e9SAndroid Build Coastguard Worker            return "...\n" + info_str
1096*da0073e9SAndroid Build Coastguard Worker        else:
1097*da0073e9SAndroid Build Coastguard Worker            data_str = " " + "\n ".join(str(self[i]) for i in range(self.size()))
1098*da0073e9SAndroid Build Coastguard Worker            return data_str + "\n" + info_str
1099*da0073e9SAndroid Build Coastguard Worker
1100*da0073e9SAndroid Build Coastguard Worker    def __repr__(self):
1101*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1102*da0073e9SAndroid Build Coastguard Worker        return str(self)
1103*da0073e9SAndroid Build Coastguard Worker
1104*da0073e9SAndroid Build Coastguard Worker    def __iter__(self):
1105*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1106*da0073e9SAndroid Build Coastguard Worker        return iter(self[i] for i in range(self.size()))
1107*da0073e9SAndroid Build Coastguard Worker
1108*da0073e9SAndroid Build Coastguard Worker    def __copy__(self):
1109*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1110*da0073e9SAndroid Build Coastguard Worker        return self._new_wrapped_storage(copy.copy(self._untyped_storage))
1111*da0073e9SAndroid Build Coastguard Worker
1112*da0073e9SAndroid Build Coastguard Worker    def __deepcopy__(self, memo):
1113*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1114*da0073e9SAndroid Build Coastguard Worker        return self._deepcopy(memo)
1115*da0073e9SAndroid Build Coastguard Worker
1116*da0073e9SAndroid Build Coastguard Worker    # For internal use only, to avoid deprecation warning
1117*da0073e9SAndroid Build Coastguard Worker    def _deepcopy(self, memo):
1118*da0073e9SAndroid Build Coastguard Worker        return self._new_wrapped_storage(copy.deepcopy(self._untyped_storage, memo))
1119*da0073e9SAndroid Build Coastguard Worker
1120*da0073e9SAndroid Build Coastguard Worker    def __sizeof__(self):
1121*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1122*da0073e9SAndroid Build Coastguard Worker        return super().__sizeof__() + self.nbytes()
1123*da0073e9SAndroid Build Coastguard Worker
1124*da0073e9SAndroid Build Coastguard Worker    def clone(self):
1125*da0073e9SAndroid Build Coastguard Worker        """Return a copy of this storage."""
1126*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1127*da0073e9SAndroid Build Coastguard Worker        return self._new_wrapped_storage(self._untyped_storage.clone())
1128*da0073e9SAndroid Build Coastguard Worker
1129*da0073e9SAndroid Build Coastguard Worker    def tolist(self):
1130*da0073e9SAndroid Build Coastguard Worker        """Return a list containing the elements of this storage."""
1131*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1132*da0073e9SAndroid Build Coastguard Worker        return list(self)
1133*da0073e9SAndroid Build Coastguard Worker
1134*da0073e9SAndroid Build Coastguard Worker    def cpu(self):
1135*da0073e9SAndroid Build Coastguard Worker        """Return a CPU copy of this storage if it's not already on the CPU."""
1136*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1137*da0073e9SAndroid Build Coastguard Worker        return self._new_wrapped_storage(self._untyped_storage.cpu())
1138*da0073e9SAndroid Build Coastguard Worker
1139*da0073e9SAndroid Build Coastguard Worker    def is_pinned(self, device: Union[str, torch.device] = "cuda"):
1140*da0073e9SAndroid Build Coastguard Worker        r"""Determine whether the CPU TypedStorage is already pinned on device.
1141*da0073e9SAndroid Build Coastguard Worker
1142*da0073e9SAndroid Build Coastguard Worker        Args:
1143*da0073e9SAndroid Build Coastguard Worker            device (str or torch.device): The device to pin memory on. Default: ``'cuda'``
1144*da0073e9SAndroid Build Coastguard Worker
1145*da0073e9SAndroid Build Coastguard Worker        Returns:
1146*da0073e9SAndroid Build Coastguard Worker            A boolean variable.
1147*da0073e9SAndroid Build Coastguard Worker        """
1148*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1149*da0073e9SAndroid Build Coastguard Worker        return self._untyped_storage.is_pinned(device)
1150*da0073e9SAndroid Build Coastguard Worker
1151*da0073e9SAndroid Build Coastguard Worker    def pin_memory(self, device: Union[str, torch.device] = "cuda"):
1152*da0073e9SAndroid Build Coastguard Worker        r"""Copy the CPU TypedStorage to pinned memory, if it's not already pinned.
1153*da0073e9SAndroid Build Coastguard Worker
1154*da0073e9SAndroid Build Coastguard Worker        Args:
1155*da0073e9SAndroid Build Coastguard Worker            device (str or torch.device): The device to pin memory on. Default: ``'cuda'``.
1156*da0073e9SAndroid Build Coastguard Worker
1157*da0073e9SAndroid Build Coastguard Worker        Returns:
1158*da0073e9SAndroid Build Coastguard Worker            A pinned CPU storage.
1159*da0073e9SAndroid Build Coastguard Worker        """
1160*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1161*da0073e9SAndroid Build Coastguard Worker        return self._new_wrapped_storage(
1162*da0073e9SAndroid Build Coastguard Worker            self._untyped_storage.pin_memory(device=device)
1163*da0073e9SAndroid Build Coastguard Worker        )
1164*da0073e9SAndroid Build Coastguard Worker
1165*da0073e9SAndroid Build Coastguard Worker    def share_memory_(self):
1166*da0073e9SAndroid Build Coastguard Worker        """See :meth:`torch.UntypedStorage.share_memory_`"""
1167*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1168*da0073e9SAndroid Build Coastguard Worker        return self._share_memory_()
1169*da0073e9SAndroid Build Coastguard Worker
1170*da0073e9SAndroid Build Coastguard Worker    # For internal use only, to avoid deprecation warning
1171*da0073e9SAndroid Build Coastguard Worker    def _share_memory_(self):
1172*da0073e9SAndroid Build Coastguard Worker        self._untyped_storage.share_memory_()
1173*da0073e9SAndroid Build Coastguard Worker        return self
1174*da0073e9SAndroid Build Coastguard Worker
1175*da0073e9SAndroid Build Coastguard Worker    def _new_shared(self, size, *, device=None):
1176*da0073e9SAndroid Build Coastguard Worker        """Create a new storage in shared memory with the same data type."""
1177*da0073e9SAndroid Build Coastguard Worker        if device is None:
1178*da0073e9SAndroid Build Coastguard Worker            device = "cpu"
1179*da0073e9SAndroid Build Coastguard Worker        device = torch.device(device)
1180*da0073e9SAndroid Build Coastguard Worker        untyped_storage = torch.UntypedStorage._new_shared(
1181*da0073e9SAndroid Build Coastguard Worker            size * self._element_size(), device=device
1182*da0073e9SAndroid Build Coastguard Worker        )
1183*da0073e9SAndroid Build Coastguard Worker        return TypedStorage(
1184*da0073e9SAndroid Build Coastguard Worker            wrap_storage=untyped_storage, dtype=self.dtype, _internal=True
1185*da0073e9SAndroid Build Coastguard Worker        )
1186*da0073e9SAndroid Build Coastguard Worker
1187*da0073e9SAndroid Build Coastguard Worker    @property
1188*da0073e9SAndroid Build Coastguard Worker    def _cdata(self):
1189*da0073e9SAndroid Build Coastguard Worker        return self._untyped_storage._cdata
1190*da0073e9SAndroid Build Coastguard Worker
1191*da0073e9SAndroid Build Coastguard Worker    @property
1192*da0073e9SAndroid Build Coastguard Worker    def device(self):
1193*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1194*da0073e9SAndroid Build Coastguard Worker        return self._untyped_storage.device
1195*da0073e9SAndroid Build Coastguard Worker
1196*da0073e9SAndroid Build Coastguard Worker    def size(self):
1197*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1198*da0073e9SAndroid Build Coastguard Worker        return self._size()
1199*da0073e9SAndroid Build Coastguard Worker
1200*da0073e9SAndroid Build Coastguard Worker    # For internal use only, to avoid deprecation warning
1201*da0073e9SAndroid Build Coastguard Worker    def _size(self):
1202*da0073e9SAndroid Build Coastguard Worker        # NB: don't indirect through __len__, as that requires
1203*da0073e9SAndroid Build Coastguard Worker        # an int to be returned
1204*da0073e9SAndroid Build Coastguard Worker        return self._untyped_storage.nbytes() // self._element_size()
1205*da0073e9SAndroid Build Coastguard Worker
1206*da0073e9SAndroid Build Coastguard Worker    def pickle_storage_type(self):
1207*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1208*da0073e9SAndroid Build Coastguard Worker        return self._pickle_storage_type()
1209*da0073e9SAndroid Build Coastguard Worker
1210*da0073e9SAndroid Build Coastguard Worker    # For internal use only, to avoid deprecation warning
1211*da0073e9SAndroid Build Coastguard Worker    def _pickle_storage_type(self):
1212*da0073e9SAndroid Build Coastguard Worker        try:
1213*da0073e9SAndroid Build Coastguard Worker            return _dtype_to_storage_type_map()[self.dtype]
1214*da0073e9SAndroid Build Coastguard Worker        except KeyError as e:
1215*da0073e9SAndroid Build Coastguard Worker            raise KeyError(f"dtype {self.dtype} is not recognized") from e
1216*da0073e9SAndroid Build Coastguard Worker
1217*da0073e9SAndroid Build Coastguard Worker    def __reduce__(self):
1218*da0073e9SAndroid Build Coastguard Worker        b = io.BytesIO()
1219*da0073e9SAndroid Build Coastguard Worker        torch.save(self, b, _use_new_zipfile_serialization=False)
1220*da0073e9SAndroid Build Coastguard Worker        return (_load_from_bytes, (b.getvalue(),))
1221*da0073e9SAndroid Build Coastguard Worker
1222*da0073e9SAndroid Build Coastguard Worker    def data_ptr(self):
1223*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1224*da0073e9SAndroid Build Coastguard Worker        return self._data_ptr()
1225*da0073e9SAndroid Build Coastguard Worker
1226*da0073e9SAndroid Build Coastguard Worker    # For internal use only, to avoid deprecation warning
1227*da0073e9SAndroid Build Coastguard Worker    def _data_ptr(self):
1228*da0073e9SAndroid Build Coastguard Worker        return self._untyped_storage.data_ptr()
1229*da0073e9SAndroid Build Coastguard Worker
1230*da0073e9SAndroid Build Coastguard Worker    def resizable(self):
1231*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1232*da0073e9SAndroid Build Coastguard Worker        return self._untyped_storage.resizable()
1233*da0073e9SAndroid Build Coastguard Worker
1234*da0073e9SAndroid Build Coastguard Worker    def resize_(self, size):
1235*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1236*da0073e9SAndroid Build Coastguard Worker        self._resize_(size)
1237*da0073e9SAndroid Build Coastguard Worker
1238*da0073e9SAndroid Build Coastguard Worker    # For internal use only, to avoid deprecation warning
1239*da0073e9SAndroid Build Coastguard Worker    def _resize_(self, size):
1240*da0073e9SAndroid Build Coastguard Worker        self._untyped_storage.resize_(size * self._element_size())
1241*da0073e9SAndroid Build Coastguard Worker
1242*da0073e9SAndroid Build Coastguard Worker    @classmethod
1243*da0073e9SAndroid Build Coastguard Worker    def _free_weak_ref(cls, *args, **kwargs):
1244*da0073e9SAndroid Build Coastguard Worker        return UntypedStorage._free_weak_ref(*args, **kwargs)
1245*da0073e9SAndroid Build Coastguard Worker
1246*da0073e9SAndroid Build Coastguard Worker    def _weak_ref(self, *args, **kwargs):
1247*da0073e9SAndroid Build Coastguard Worker        return self._untyped_storage._weak_ref(*args, **kwargs)
1248*da0073e9SAndroid Build Coastguard Worker
1249*da0073e9SAndroid Build Coastguard Worker    @classmethod
1250*da0073e9SAndroid Build Coastguard Worker    def from_buffer(cls, *args, **kwargs):
1251*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1252*da0073e9SAndroid Build Coastguard Worker        return cls._from_buffer(*args, **kwargs)
1253*da0073e9SAndroid Build Coastguard Worker
1254*da0073e9SAndroid Build Coastguard Worker    @classmethod
1255*da0073e9SAndroid Build Coastguard Worker    def _from_buffer(cls, *args, dtype=None, device=None, **kwargs):
1256*da0073e9SAndroid Build Coastguard Worker        if cls == TypedStorage:
1257*da0073e9SAndroid Build Coastguard Worker            dtype = torch.get_default_dtype() if dtype is None else dtype
1258*da0073e9SAndroid Build Coastguard Worker            device = torch.device("cpu" if device is None else device)
1259*da0073e9SAndroid Build Coastguard Worker            if device.type != "cpu":
1260*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
1261*da0073e9SAndroid Build Coastguard Worker                    f"TypedStorage.from_buffer: Not available for device {device.type}"
1262*da0073e9SAndroid Build Coastguard Worker                )
1263*da0073e9SAndroid Build Coastguard Worker            untyped_storage: torch.UntypedStorage = torch.UntypedStorage.from_buffer(
1264*da0073e9SAndroid Build Coastguard Worker                *args, dtype=dtype, **kwargs
1265*da0073e9SAndroid Build Coastguard Worker            )
1266*da0073e9SAndroid Build Coastguard Worker
1267*da0073e9SAndroid Build Coastguard Worker        else:
1268*da0073e9SAndroid Build Coastguard Worker            if dtype is not None or len(args) == 5:
1269*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
1270*da0073e9SAndroid Build Coastguard Worker                    "from_buffer: 'dtype' can only be specified in "
1271*da0073e9SAndroid Build Coastguard Worker                    "UntypedStorage.from_buffer and TypedStorage.from_buffer"
1272*da0073e9SAndroid Build Coastguard Worker                )
1273*da0073e9SAndroid Build Coastguard Worker            if device is not None:
1274*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
1275*da0073e9SAndroid Build Coastguard Worker                    "from_buffer: 'device' can only be specified in "
1276*da0073e9SAndroid Build Coastguard Worker                    "UntypedStorage.from_buffer and TypedStorage.from_buffer"
1277*da0073e9SAndroid Build Coastguard Worker                )
1278*da0073e9SAndroid Build Coastguard Worker
1279*da0073e9SAndroid Build Coastguard Worker            dtype = cls._dtype
1280*da0073e9SAndroid Build Coastguard Worker            untyped_storage = torch.UntypedStorage.from_buffer(
1281*da0073e9SAndroid Build Coastguard Worker                *args, dtype=dtype, **kwargs
1282*da0073e9SAndroid Build Coastguard Worker            )
1283*da0073e9SAndroid Build Coastguard Worker
1284*da0073e9SAndroid Build Coastguard Worker        return TypedStorage(wrap_storage=untyped_storage, dtype=dtype, _internal=True)
1285*da0073e9SAndroid Build Coastguard Worker
1286*da0073e9SAndroid Build Coastguard Worker    def _to(self, dtype):
1287*da0073e9SAndroid Build Coastguard Worker        if not isinstance(dtype, torch.dtype):
1288*da0073e9SAndroid Build Coastguard Worker            raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
1289*da0073e9SAndroid Build Coastguard Worker        storage = (
1290*da0073e9SAndroid Build Coastguard Worker            torch.tensor([], dtype=self.dtype, device=self.device)
1291*da0073e9SAndroid Build Coastguard Worker            .set_(self)
1292*da0073e9SAndroid Build Coastguard Worker            .to(dtype)
1293*da0073e9SAndroid Build Coastguard Worker            ._typed_storage()
1294*da0073e9SAndroid Build Coastguard Worker        )
1295*da0073e9SAndroid Build Coastguard Worker        if storage.data_ptr() == self.data_ptr():
1296*da0073e9SAndroid Build Coastguard Worker            storage = storage.clone()
1297*da0073e9SAndroid Build Coastguard Worker        return storage
1298*da0073e9SAndroid Build Coastguard Worker
1299*da0073e9SAndroid Build Coastguard Worker    def double(self):
1300*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to double type."""
1301*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1302*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.double)
1303*da0073e9SAndroid Build Coastguard Worker
1304*da0073e9SAndroid Build Coastguard Worker    def float(self):
1305*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to float type."""
1306*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1307*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.float)
1308*da0073e9SAndroid Build Coastguard Worker
1309*da0073e9SAndroid Build Coastguard Worker    def half(self):
1310*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to half type."""
1311*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1312*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.half)
1313*da0073e9SAndroid Build Coastguard Worker
1314*da0073e9SAndroid Build Coastguard Worker    def long(self):
1315*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to long type."""
1316*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1317*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.long)
1318*da0073e9SAndroid Build Coastguard Worker
1319*da0073e9SAndroid Build Coastguard Worker    def int(self):
1320*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to int type."""
1321*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1322*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.int)
1323*da0073e9SAndroid Build Coastguard Worker
1324*da0073e9SAndroid Build Coastguard Worker    def short(self):
1325*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to short type."""
1326*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1327*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.short)
1328*da0073e9SAndroid Build Coastguard Worker
1329*da0073e9SAndroid Build Coastguard Worker    def char(self):
1330*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to char type."""
1331*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1332*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.int8)
1333*da0073e9SAndroid Build Coastguard Worker
1334*da0073e9SAndroid Build Coastguard Worker    def byte(self):
1335*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to byte type."""
1336*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1337*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.uint8)
1338*da0073e9SAndroid Build Coastguard Worker
1339*da0073e9SAndroid Build Coastguard Worker    def bool(self):
1340*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to bool type."""
1341*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1342*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.bool)
1343*da0073e9SAndroid Build Coastguard Worker
1344*da0073e9SAndroid Build Coastguard Worker    def bfloat16(self):
1345*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to bfloat16 type."""
1346*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1347*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.bfloat16)
1348*da0073e9SAndroid Build Coastguard Worker
1349*da0073e9SAndroid Build Coastguard Worker    def complex_double(self):
1350*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to complex double type."""
1351*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1352*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.cdouble)
1353*da0073e9SAndroid Build Coastguard Worker
1354*da0073e9SAndroid Build Coastguard Worker    def complex_float(self):
1355*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to complex float type."""
1356*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1357*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.cfloat)
1358*da0073e9SAndroid Build Coastguard Worker
1359*da0073e9SAndroid Build Coastguard Worker    def float8_e5m2(self):
1360*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to float8_e5m2 type"""
1361*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1362*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.float8_e5m2)
1363*da0073e9SAndroid Build Coastguard Worker
1364*da0073e9SAndroid Build Coastguard Worker    def float8_e4m3fn(self):
1365*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to float8_e4m3fn type"""
1366*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1367*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.float8_e4m3fn)
1368*da0073e9SAndroid Build Coastguard Worker
1369*da0073e9SAndroid Build Coastguard Worker    def float8_e5m2fnuz(self):
1370*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to float8_e5m2fnuz type"""
1371*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1372*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.float8_e5m2fnuz)
1373*da0073e9SAndroid Build Coastguard Worker
1374*da0073e9SAndroid Build Coastguard Worker    def float8_e4m3fnuz(self):
1375*da0073e9SAndroid Build Coastguard Worker        """Casts this storage to float8_e4m3fnuz type"""
1376*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1377*da0073e9SAndroid Build Coastguard Worker        return self._to(torch.float8_e4m3fnuz)
1378*da0073e9SAndroid Build Coastguard Worker
1379*da0073e9SAndroid Build Coastguard Worker    @classmethod
1380*da0073e9SAndroid Build Coastguard Worker    def from_file(cls, filename, shared, size):
1381*da0073e9SAndroid Build Coastguard Worker        """from_file(filename, shared=False, size=0) -> Storage
1382*da0073e9SAndroid Build Coastguard Worker
1383*da0073e9SAndroid Build Coastguard Worker        Creates a CPU storage backed by a memory-mapped file.
1384*da0073e9SAndroid Build Coastguard Worker
1385*da0073e9SAndroid Build Coastguard Worker        If ``shared`` is ``True``, then memory is shared between all processes.
1386*da0073e9SAndroid Build Coastguard Worker        All changes are written to the file. If ``shared`` is ``False``, then the changes on
1387*da0073e9SAndroid Build Coastguard Worker        the storage do not affect the file.
1388*da0073e9SAndroid Build Coastguard Worker
1389*da0073e9SAndroid Build Coastguard Worker        ``size`` is the number of elements in the storage. If ``shared`` is ``False``,
1390*da0073e9SAndroid Build Coastguard Worker        then the file must contain at least ``size * sizeof(Type)`` bytes
1391*da0073e9SAndroid Build Coastguard Worker        (``Type`` is the type of storage). If ``shared`` is ``True`` the file will be created if needed.
1392*da0073e9SAndroid Build Coastguard Worker
1393*da0073e9SAndroid Build Coastguard Worker        Args:
1394*da0073e9SAndroid Build Coastguard Worker            filename (str): file name to map
1395*da0073e9SAndroid Build Coastguard Worker            shared (bool): whether to share memory (whether ``MAP_SHARED`` or ``MAP_PRIVATE`` is passed to the
1396*da0073e9SAndroid Build Coastguard Worker                            underlying `mmap(2) call <https://man7.org/linux/man-pages/man2/mmap.2.html>`_)
1397*da0073e9SAndroid Build Coastguard Worker            size (int): number of elements in the storage
1398*da0073e9SAndroid Build Coastguard Worker        """
1399*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1400*da0073e9SAndroid Build Coastguard Worker        if cls == TypedStorage:
1401*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError("from_file can only be called on derived classes")
1402*da0073e9SAndroid Build Coastguard Worker        untyped_storage = UntypedStorage.from_file(
1403*da0073e9SAndroid Build Coastguard Worker            filename, shared, size * torch._utils._element_size(cls.dtype)
1404*da0073e9SAndroid Build Coastguard Worker        )
1405*da0073e9SAndroid Build Coastguard Worker        storage = cls(wrap_storage=untyped_storage)
1406*da0073e9SAndroid Build Coastguard Worker        return storage
1407*da0073e9SAndroid Build Coastguard Worker
1408*da0073e9SAndroid Build Coastguard Worker    @classmethod
1409*da0073e9SAndroid Build Coastguard Worker    def _expired(cls, *args, **kwargs):
1410*da0073e9SAndroid Build Coastguard Worker        return UntypedStorage._expired(*args, **kwargs)
1411*da0073e9SAndroid Build Coastguard Worker
1412*da0073e9SAndroid Build Coastguard Worker    def _write_file(self, *args, **kwargs):
1413*da0073e9SAndroid Build Coastguard Worker        return self._untyped_storage._write_file(*args, **kwargs)
1414*da0073e9SAndroid Build Coastguard Worker
1415*da0073e9SAndroid Build Coastguard Worker    def _set_from_file(self, *args, **kwargs):
1416*da0073e9SAndroid Build Coastguard Worker        return self._untyped_storage._set_from_file(*args, **kwargs)
1417*da0073e9SAndroid Build Coastguard Worker
1418*da0073e9SAndroid Build Coastguard Worker    def _set_cdata(self, *args, **kwargs):
1419*da0073e9SAndroid Build Coastguard Worker        return self._untyped_storage._set_cdata(*args, **kwargs)
1420*da0073e9SAndroid Build Coastguard Worker
1421*da0073e9SAndroid Build Coastguard Worker    def _share_cuda_(self, *args, **kwargs):
1422*da0073e9SAndroid Build Coastguard Worker        return self._untyped_storage._share_cuda_(*args, **kwargs)
1423*da0073e9SAndroid Build Coastguard Worker
1424*da0073e9SAndroid Build Coastguard Worker    def is_shared(self):
1425*da0073e9SAndroid Build Coastguard Worker        _warn_typed_storage_removal()
1426*da0073e9SAndroid Build Coastguard Worker        return self._is_shared()
1427*da0073e9SAndroid Build Coastguard Worker
1428*da0073e9SAndroid Build Coastguard Worker    # For internal use only, to avoid deprecation warning
1429*da0073e9SAndroid Build Coastguard Worker    def _is_shared(self):
1430*da0073e9SAndroid Build Coastguard Worker        return self._untyped_storage.is_shared()
1431*da0073e9SAndroid Build Coastguard Worker
1432*da0073e9SAndroid Build Coastguard Worker    @classmethod
1433*da0073e9SAndroid Build Coastguard Worker    def _new_shared_cuda(cls, *args, **kwargs):
1434*da0073e9SAndroid Build Coastguard Worker        return torch.UntypedStorage._new_shared_cuda(*args, **kwargs)
1435*da0073e9SAndroid Build Coastguard Worker
1436*da0073e9SAndroid Build Coastguard Worker    def _share_filename_cpu_(self, *args, **kwargs):
1437*da0073e9SAndroid Build Coastguard Worker        (
1438*da0073e9SAndroid Build Coastguard Worker            manager_handle,
1439*da0073e9SAndroid Build Coastguard Worker            storage_handle,
1440*da0073e9SAndroid Build Coastguard Worker            size,
1441*da0073e9SAndroid Build Coastguard Worker        ) = self._untyped_storage._share_filename_cpu_(*args, **kwargs)
1442*da0073e9SAndroid Build Coastguard Worker        return manager_handle, storage_handle, size // self._element_size()
1443*da0073e9SAndroid Build Coastguard Worker
1444*da0073e9SAndroid Build Coastguard Worker    def _shared_decref(self):
1445*da0073e9SAndroid Build Coastguard Worker        self._untyped_storage._shared_decref()
1446*da0073e9SAndroid Build Coastguard Worker        return self
1447*da0073e9SAndroid Build Coastguard Worker
1448*da0073e9SAndroid Build Coastguard Worker    @classmethod
1449*da0073e9SAndroid Build Coastguard Worker    def _release_ipc_counter(cls, *args, device=None, **kwargs):
1450*da0073e9SAndroid Build Coastguard Worker        return torch.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)
1451*da0073e9SAndroid Build Coastguard Worker
1452*da0073e9SAndroid Build Coastguard Worker    def _shared_incref(self, *args, **kwargs):
1453*da0073e9SAndroid Build Coastguard Worker        return self._untyped_storage._shared_incref(*args, **kwargs)
1454*da0073e9SAndroid Build Coastguard Worker
1455*da0073e9SAndroid Build Coastguard Worker    def _share_fd_cpu_(self, *args, **kwargs):
1456*da0073e9SAndroid Build Coastguard Worker        fd, size = self._untyped_storage._share_fd_cpu_(*args, **kwargs)
1457*da0073e9SAndroid Build Coastguard Worker        return fd, size // self._element_size()
1458*da0073e9SAndroid Build Coastguard Worker
1459*da0073e9SAndroid Build Coastguard Worker    def _get_legacy_storage_class(self):
1460*da0073e9SAndroid Build Coastguard Worker        if self.dtype not in _dtype_to_storage_type_map():
1461*da0073e9SAndroid Build Coastguard Worker            return None
1462*da0073e9SAndroid Build Coastguard Worker
1463*da0073e9SAndroid Build Coastguard Worker        storage_name = _dtype_to_storage_type_map()[self.dtype]
1464*da0073e9SAndroid Build Coastguard Worker
1465*da0073e9SAndroid Build Coastguard Worker        if self.device.type not in [
1466*da0073e9SAndroid Build Coastguard Worker            "cpu",
1467*da0073e9SAndroid Build Coastguard Worker            "cuda",
1468*da0073e9SAndroid Build Coastguard Worker            "hpu",
1469*da0073e9SAndroid Build Coastguard Worker            torch._C._get_privateuse1_backend_name(),
1470*da0073e9SAndroid Build Coastguard Worker        ]:
1471*da0073e9SAndroid Build Coastguard Worker            return None
1472*da0073e9SAndroid Build Coastguard Worker
1473*da0073e9SAndroid Build Coastguard Worker        module = (
1474*da0073e9SAndroid Build Coastguard Worker            torch if self.device.type == "cpu" else getattr(torch, self.device.type)
1475*da0073e9SAndroid Build Coastguard Worker        )
1476*da0073e9SAndroid Build Coastguard Worker
1477*da0073e9SAndroid Build Coastguard Worker        try:
1478*da0073e9SAndroid Build Coastguard Worker            return getattr(module, storage_name)
1479*da0073e9SAndroid Build Coastguard Worker        except AttributeError:
1480*da0073e9SAndroid Build Coastguard Worker            return None
1481*da0073e9SAndroid Build Coastguard Worker
1482*da0073e9SAndroid Build Coastguard Worker
# Copy docstrings onto these TypedStorage methods: ``type`` and ``to`` take
# theirs from the ``_type`` / ``_to`` helpers imported from ``torch._utils``,
# while ``cuda`` and ``hpu`` take theirs from the ``_StorageBase`` methods
# of the same name.
TypedStorage.type.__doc__ = _type.__doc__
TypedStorage.cuda.__doc__ = _StorageBase.cuda.__doc__
TypedStorage.hpu.__doc__ = _StorageBase.hpu.__doc__
TypedStorage.to.__doc__ = _to.__doc__
1487*da0073e9SAndroid Build Coastguard Worker
1488*da0073e9SAndroid Build Coastguard Worker
class _LegacyStorageMeta(type):
    """Metaclass customizing ``isinstance`` checks against legacy storage classes.

    An object passes the check only if its exact type is ``TypedStorage`` and
    both of the following match the class:
      * its device type, compared with the device type that
        ``_get_device_from_module`` reports for the class's module;
      * its dtype.
    """

    dtype: torch.dtype

    def __instancecheck__(cls, instance):
        # Only objects whose exact type is TypedStorage are considered;
        # everything else (including TypedStorage subclasses) yields False here.
        exact_typed_storage = type(instance) == TypedStorage
        if not exact_typed_storage:
            return False
        module_device = _get_device_from_module(cls.__module__)
        device_matches = module_device == instance.device.type
        return device_matches and (cls.dtype == instance.dtype)
1499*da0073e9SAndroid Build Coastguard Worker
1500*da0073e9SAndroid Build Coastguard Worker
class _LegacyStorage(TypedStorage, metaclass=_LegacyStorageMeta):
    """Base for the legacy typed storage classes checked by ``_LegacyStorageMeta``.

    Subclasses are expected to expose a class-level ``dtype``. The helpers
    below use it to convert element counts into byte counts.
    """

    @classmethod
    def _new_shared(cls, size):
        """Create a new storage in shared memory with the same data type.

        Args:
            size (int): number of elements.
        """
        # Compute the byte size from the class dtype, as ``_new_shared_filename``
        # does. This avoids constructing a throwaway ``cls()`` instance only to
        # query its element size.
        nbytes = size * torch._utils._element_size(cls.dtype)
        untyped_storage = torch.UntypedStorage._new_shared(nbytes)
        return cls(wrap_storage=untyped_storage)

    @classmethod
    def _release_ipc_counter(cls, *args, **kwargs):
        """Forward to ``torch.UntypedStorage._release_ipc_counter_cuda``."""
        return torch.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)

    @classmethod
    def _new_shared_filename(cls, manager, obj, size):
        """Wrap ``torch.UntypedStorage._new_shared_filename_cpu`` in this class.

        Args:
            manager: forwarded unchanged.
            obj: forwarded unchanged.
            size (int): number of elements; converted to bytes using ``cls.dtype``.
        """
        bytes_size = size * torch._utils._element_size(cls.dtype)
        return cls(
            wrap_storage=torch.UntypedStorage._new_shared_filename_cpu(
                manager, obj, bytes_size
            )
        )
1520*da0073e9SAndroid Build Coastguard Worker
1521*da0073e9SAndroid Build Coastguard Worker
def _get_dtype_from_pickle_storage_type(pickle_storage_type: str):
    """Look up the ``torch.dtype`` for a pickled storage type name.

    Args:
        pickle_storage_type: key into ``_storage_type_to_dtype_map()``.

    Raises:
        KeyError: if the name is not in the map. The original KeyError is
            chained as the cause.
    """
    try:
        dtype = _storage_type_to_dtype_map()[pickle_storage_type]
    except KeyError as err:
        raise KeyError(
            f'pickle storage type "{pickle_storage_type}" is not recognized'
        ) from err
    return dtype
1529