xref: /aosp_15_r20/external/pytorch/torch/storage.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2
3from __future__ import annotations
4
5import collections
6import copy
7import functools
8import io
9import threading
10import warnings
11from typing import Any, cast, Dict as _Dict, Optional as _Optional, Type, TypeVar, Union
12from typing_extensions import Self
13
14import torch
15from torch._utils import _to, _type
16from torch.types import _bool, _int, Storage
17
18
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = ["TypedStorage", "UntypedStorage"]
20
21
# numpy is an optional dependency; ``np`` is ``None`` when it is not installed.
try:
    import numpy as np
except ModuleNotFoundError:
    np = None  # type: ignore[assignment]
HAS_NUMPY = np is not None


# Registry used by ``_share_memory_lock_protected``: ``_share_memory_lock``
# guards the map itself, and the map holds one per-storage RLock keyed by the
# storage's ``_cdata`` while that storage is being moved to shared memory.
_share_memory_lock = threading.Lock()
_share_memory_map: _Dict[int, threading.RLock] = {}

# Type variable for storage-returning methods (forward-referenced bound).
T = TypeVar("T", bound="Union[_StorageBase, TypedStorage]")
35
36
37class _StorageBase:
38    _cdata: Any
39    is_sparse: _bool = False
40    is_sparse_csr: _bool = False
41    device: torch.device
42    # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True)
43    _fake_device: _Optional[torch.device] = None
44
45    def __init__(self, *args, **kwargs):
46        pass
47
48    def __len__(self) -> _int:
49        raise NotImplementedError
50
51    def __getitem__(self, idx):
52        raise NotImplementedError
53
54    def __setitem__(self, *args, **kwargs):
55        raise NotImplementedError
56
57    def copy_(self, source: T, non_blocking: _Optional[_bool] = None) -> T:
58        raise NotImplementedError
59
60    def new(self) -> Union[_StorageBase, TypedStorage]:
61        raise NotImplementedError
62
63    def nbytes(self) -> _int:
64        raise NotImplementedError
65
66    def size(self) -> _int:
67        return self.nbytes()
68
69    def type(
70        self, dtype: _Optional[str] = None, non_blocking: _bool = False
71    ) -> Union[_StorageBase, TypedStorage]:
72        return _type(self, dtype, non_blocking)
73
74    def cuda(
75        self, device=None, non_blocking=False
76    ) -> Union[_StorageBase, TypedStorage]:
77        """Returns a copy of this object in CUDA memory.
78
79        If this object is already in CUDA memory and on the correct device, then
80        no copy is performed and the original object is returned.
81
82        Args:
83            device (int): The destination GPU id. Defaults to the current device.
84            non_blocking (bool): If ``True`` and the source is in pinned memory,
85                the copy will be asynchronous with respect to the host. Otherwise,
86                the argument has no effect.
87        """
88        device2 = torch.device("cuda", device) if device else torch.device("cuda")
89        return self.to(device=device2, non_blocking=non_blocking)
90
91    def hpu(self, device=None, non_blocking=False) -> Union[_StorageBase, TypedStorage]:
92        """Returns a copy of this object in HPU memory.
93
94        If this object is already in HPU memory and on the correct device, then
95        no copy is performed and the original object is returned.
96
97        Args:
98            device (int): The destination HPU id. Defaults to the current device.
99            non_blocking (bool): If ``True`` and the source is in pinned memory,
100                the copy will be asynchronous with respect to the host. Otherwise,
101                the argument has no effect.
102        """
103        device2 = torch.device("hpu", device) if device else torch.device("hpu")
104        return self.to(device=device2, non_blocking=non_blocking)
105
106    def element_size(self) -> _int:
107        raise NotImplementedError
108
109    def get_device(self) -> _int:
110        return self.device.index
111
112    def data_ptr(self) -> _int:
113        raise NotImplementedError
114
115    def resizable(self) -> _bool:
116        raise NotImplementedError
117
118    # Defined in torch/csrc/generic/StorageSharing.cpp
119    def _share_filename_cpu_(self, *args, **kwargs):
120        raise NotImplementedError
121
122    def _share_fd_cpu_(self, *args, **kwargs):
123        raise NotImplementedError
124
125    @classmethod
126    def _new_using_filename_cpu(cls: Type[T], size: _int) -> T:
127        raise NotImplementedError
128
129    @classmethod
130    def _new_using_fd_cpu(cls: Type[T], size: _int) -> T:
131        raise NotImplementedError
132
133    @classmethod
134    def from_buffer(cls: Type[T], *args, **kwargs) -> T:
135        raise NotImplementedError
136
137    @classmethod
138    def _new_shared_filename_cpu(
139        cls: Type[T],
140        manager,
141        obj,
142        size,
143        *,
144        device=None,
145        dtype=None,
146    ) -> T:
147        raise NotImplementedError
148
149    @classmethod
150    def _release_ipc_counter_cuda(cls: Type[T], *args, **kwargs) -> T:
151        raise NotImplementedError
152
153    @classmethod
154    def _new_with_weak_ptr(cls: Type[T], *args, **kwargs) -> T:
155        raise NotImplementedError
156
157    def _shared_decref(self) -> Union[_StorageBase, TypedStorage]:
158        raise NotImplementedError
159
160    def _write_file(self, *args, **kwargs):
161        raise NotImplementedError
162
163    def resize_(self, size: _int):
164        raise NotImplementedError
165
166    def _weak_ref(self, *args, **kwargs) -> Union[_StorageBase, TypedStorage]:
167        raise NotImplementedError
168
169    def _set_from_file(self, *args, **kwargs):
170        raise NotImplementedError
171
172    def _set_cdata(self, *args, **kwargs):
173        raise NotImplementedError
174
175    def _share_cuda_(self, *args, **kwargs):
176        raise NotImplementedError
177
178    def is_shared(self) -> _bool:
179        raise NotImplementedError
180
181    @classmethod
182    def _new_shared_cuda(cls: Type[T], *args, **kwargs) -> T:
183        raise NotImplementedError
184
185    def _shared_incref(self, *args, **kwargs):
186        raise NotImplementedError
187
188    @classmethod
189    def _free_weak_ref(cls, *args, **kwargs):
190        raise NotImplementedError
191
192    @property
193    def is_cuda(self):
194        raise NotImplementedError
195
196    @property
197    def is_hpu(self):
198        raise NotImplementedError
199
200    @classmethod
201    def from_file(cls, filename, shared, nbytes) -> Union[_StorageBase, TypedStorage]:
202        raise NotImplementedError
203
204    @classmethod
205    def _expired(cls, *args, **kwargs) -> Union[_StorageBase, TypedStorage]:
206        raise NotImplementedError
207
208    def _byteswap(self, *args, **kwargs):
209        raise NotImplementedError
210
211    def _get_filename(self, *args, **kwargs) -> _Optional[str]:
212        raise NotImplementedError
213
214    def __repr__(self):
215        info_str = f"[{torch.typename(self)}(device={self.device}) of size {len(self)}]"
216        if self.device.type == "meta":
217            return "...\n" + info_str
218        data_str = " " + "\n ".join(str(self[i]) for i in range(self.size()))
219        return data_str + "\n" + info_str
220
221    def __iter__(self):
222        return iter(self[i] for i in range(self.size()))
223
224    def __copy__(self):
225        return self.clone()
226
227    def __deepcopy__(self, memo):
228        memo = memo.setdefault("torch", {})
229        if self._cdata in memo:
230            return memo[self._cdata]
231        new_storage = self.clone()
232        memo[self._cdata] = new_storage
233        return new_storage
234
235    def __reduce__(self):
236        b = io.BytesIO()
237        torch.save(self, b, _use_new_zipfile_serialization=False)
238        return (_load_from_bytes, (b.getvalue(),))
239
240    def __sizeof__(self):
241        return super().__sizeof__() + self.size()
242
243    def clone(self):
244        """Return a copy of this storage."""
245        return type(self)(self.nbytes(), device=self.device).copy_(self)
246
247    def tolist(self):
248        """Return a list containing the elements of this storage."""
249        return list(self)
250
251    def cpu(self):
252        """Return a CPU copy of this storage if it's not already on the CPU."""
253        if self.device.type != "cpu":
254            return torch.UntypedStorage(self.size()).copy_(self, False)
255        return self
256
257    def mps(self):
258        """Return a MPS copy of this storage if it's not already on the MPS."""
259        if self.device.type != "mps":
260            return torch.UntypedStorage(self.size(), device="mps").copy_(self, False)
261        return self
262
263    def _to(self, dtype):
264        if not isinstance(dtype, torch.dtype):
265            raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
266        storage = (
267            torch.tensor([], dtype=torch.uint8, device=self.device)
268            .set_(cast(Storage, self))
269            .to(dtype)
270            ._typed_storage()
271        )
272        if storage.data_ptr() == self.data_ptr():
273            storage = storage.clone()
274        return storage
275
276    def to(
277        self, *, device: torch.device, non_blocking: _bool = False
278    ) -> Union[_StorageBase, TypedStorage]:
279        return _to(self, device, non_blocking)
280
281    def double(self):
282        """Casts this storage to double type."""
283        return self._to(torch.double)
284
285    def float(self):
286        """Casts this storage to float type."""
287        return self._to(torch.float)
288
289    def half(self):
290        """Casts this storage to half type."""
291        return self._to(torch.half)
292
293    def long(self):
294        """Casts this storage to long type."""
295        return self._to(torch.long)
296
297    def int(self):
298        """Casts this storage to int type."""
299        return self._to(torch.int)
300
301    def short(self):
302        """Casts this storage to short type."""
303        return self._to(torch.short)
304
305    def char(self):
306        """Casts this storage to char type."""
307        return self._to(torch.int8)
308
309    def byte(self):
310        """Casts this storage to byte type."""
311        return self._to(torch.uint8)
312
313    def bool(self):
314        """Casts this storage to bool type."""
315        return self._to(torch.bool)
316
317    def bfloat16(self):
318        """Casts this storage to bfloat16 type."""
319        return self._to(torch.bfloat16)
320
321    def complex_double(self):
322        """Casts this storage to complex double type."""
323        return self._to(torch.cdouble)
324
325    def complex_float(self):
326        """Casts this storage to complex float type."""
327        return self._to(torch.cfloat)
328
329    def float8_e5m2(self):
330        """Casts this storage to float8_e5m2 type"""
331        return self._to(torch.float8_e5m2)
332
333    def float8_e4m3fn(self):
334        """Casts this storage to float8_e4m3fn type"""
335        return self._to(torch.float8_e4m3fn)
336
337    def float8_e5m2fnuz(self):
338        """Casts this storage to float8_e5m2fnuz type"""
339        return self._to(torch.float8_e5m2fnuz)
340
341    def float8_e4m3fnuz(self):
342        """Casts this storage to float8_e4m3fnuz type"""
343        return self._to(torch.float8_e4m3fnuz)
344
345    def is_pinned(self, device: Union[str, torch.device] = "cuda"):
346        r"""Determine whether the CPU storage is already pinned on device.
347
348        Args:
349            device (str or torch.device): The device to pin memory on. Default: ``'cuda'``.
350
351        Returns:
352            A boolean variable.
353        """
354        return (
355            torch.tensor([], dtype=torch.uint8, device=self.device)
356            .set_(cast(Storage, self))
357            .is_pinned(device)
358        )
359
360    def pin_memory(self, device: Union[str, torch.device] = "cuda"):
361        r"""Copy the CPU storage to pinned memory, if it's not already pinned.
362
363        Args:
364            device (str or torch.device): The device to pin memory on. Default: ``'cuda'``.
365
366        Returns:
367            A pinned CPU storage.
368        """
369        if self.device.type != "cpu":
370            raise TypeError(f"cannot pin '{self.type()}' only CPU memory can be pinned")
371
372        pinned_tensor = (
373            torch.tensor([], dtype=torch.uint8, device=self.device)
374            .set_(cast(Storage, self))
375            .pin_memory(device)
376        )
377        return pinned_tensor.untyped_storage()
378
379    def share_memory_(self):
380        """See :meth:`torch.UntypedStorage.share_memory_`"""
381        from torch.multiprocessing import get_sharing_strategy
382
383        if self.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
384            pass  # CUDA or PrivateUse1 doesn't use POSIX shared memory
385        elif get_sharing_strategy() == "file_system":
386            self._share_filename_cpu_()
387        else:
388            self._share_fd_cpu_()
389        return self
390
391    @classmethod
392    def _new_shared(cls, size, *, device="cpu"):
393        """Create a new storage in shared memory with the same data type."""
394        from torch.multiprocessing import get_sharing_strategy
395
396        device = torch.device(device)
397        if device.type in ["cuda", torch._C._get_privateuse1_backend_name(), "hpu"]:
398            return cls(size, device=device)
399        elif get_sharing_strategy() == "file_system":
400            return cls._new_using_filename_cpu(size)
401        else:
402            return cls._new_using_fd_cpu(size)
403
404    def untyped(self):
405        return self
406
407    def byteswap(self, dtype):
408        """Swap bytes in underlying data."""
409        elem_size = torch._utils._element_size(dtype)
410        # for complex types, don't swap first and second numbers
411        if dtype.is_complex:
412            elem_size = max(int(elem_size / 2), 1)
413        self._byteswap(elem_size)
414
415
def _share_memory_lock_protected(fn):
    """Decorate a storage-sharing method so concurrent calls on one storage serialize.

    The first caller for a given ``self._cdata`` registers and holds a
    per-storage ``RLock`` in ``_share_memory_map``; any other caller for the
    same key waits on that lock before running ``fn``. Because the lock is
    reentrant, the owning thread can presumably re-enter another decorated
    method on the same storage (``_StorageBase.share_memory_`` calls
    ``_share_fd_cpu_`` / ``_share_filename_cpu_``) without deadlocking.
    """

    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        to_free = None
        to_wait = None
        with _share_memory_lock:
            key = self._cdata
            if key in _share_memory_map:
                to_wait = _share_memory_map[key]
            else:
                lock = threading.RLock()
                lock.acquire()
                _share_memory_map[key] = lock
                to_free = key

        # If we're already in the process of sharing the storage, wait
        # for it to be done.
        if to_wait is not None:
            with to_wait:
                pass

        try:
            return fn(self, *args, **kwargs)
        finally:
            # If we acquired the storage lock here and we're done working on it
            # we can now release it and free the entry. The cleanup runs before
            # the invariant check so a failed check cannot leave a held lock in
            # the map, which would block later callers with that key forever.
            if to_free is not None:
                with _share_memory_lock:
                    _share_memory_map.pop(to_free).release()
                # Ensure that the cdata from the storage didn't change and only
                # the data_ptr did.
                assert self._cdata == to_free

    return wrapper
450
451
class UntypedStorage(torch._C.StorageBase, _StorageBase):
    # Native ``torch._C.StorageBase`` is listed first so its methods win in the
    # MRO; ``_StorageBase`` contributes the Python-level helpers.

    def __getitem__(self, *args, **kwargs):
        # Meta storages carry no data, so element access is refused up front.
        if self.device.type != "meta":
            return super().__getitem__(*args, **kwargs)
        raise NotImplementedError("Not available for 'meta' device type")

    @property
    def is_cuda(self):
        device_type = self.device.type
        return device_type == "cuda"

    @property
    def is_hpu(self):
        device_type = self.device.type
        return device_type == "hpu"

    @property
    def filename(self) -> _Optional[str]:
        """Returns the file name associated with this storage.

        The file name will be a string if the storage is on CPU and was created via
        :meth:`~torch.from_file()` with ``shared`` as ``True``. This attribute is ``None`` otherwise.
        """
        name = self._get_filename()
        return name

    @_share_memory_lock_protected
    def share_memory_(self, *args, **kwargs):
        """
        Moves the storage to shared memory.

        This is a no-op for storages already in shared memory and for CUDA
        storages, which do not need to be moved for sharing across processes.
        Storages in shared memory cannot be resized.

        To mitigate issues like `this <https://github.com/pytorch/pytorch/issues/95606>`_,
        it is thread safe to call this function from multiple threads on the same object.
        It is NOT thread safe, though, to call any other function on self without proper
        synchronization. Please see :doc:`/notes/multiprocessing` for more details.

        .. note::
            When all references to a storage in shared memory are deleted, the associated shared memory
            object will also be deleted. PyTorch has a special cleanup process to ensure that this happens
            even if the current process exits unexpectedly.

            It is worth noting the difference between :meth:`share_memory_` and :meth:`from_file` with ``shared = True``

            #. ``share_memory_`` uses `shm_open(3) <https://man7.org/linux/man-pages/man3/shm_open.3.html>`_ to create a
               POSIX shared memory object while :meth:`from_file` uses
               `open(2) <https://man7.org/linux/man-pages/man2/open.2.html>`_ to open the filename passed by the user.
            #. Both use an `mmap(2) call <https://man7.org/linux/man-pages/man2/mmap.2.html>`_ with ``MAP_SHARED``
               to map the file/object into the current virtual address space
            #. ``share_memory_`` will call ``shm_unlink(3)`` on the object after mapping it to make sure the shared memory
               object is freed when no process has the object open. ``torch.from_file(shared=True)`` does not unlink the
               file. This file is persistent and will remain until it is deleted by the user.

        Returns:
            ``self``
        """
        return super().share_memory_(*args, **kwargs)

    # The low-level CPU sharing entry points are guarded by the same
    # per-storage lock as ``share_memory_``.
    @_share_memory_lock_protected
    def _share_fd_cpu_(self, *args, **kwargs):
        return super()._share_fd_cpu_(*args, **kwargs)

    @_share_memory_lock_protected
    def _share_filename_cpu_(self, *args, **kwargs):
        return super()._share_filename_cpu_(*args, **kwargs)
517
518
def _load_from_bytes(b):
    """Deserialize bytes produced by ``torch.save`` (the unpickling half of
    ``_StorageBase.__reduce__``).

    ``weights_only=False`` permits arbitrary pickled objects, so only trusted
    data may be passed here.
    """
    stream = io.BytesIO(b)
    return torch.load(stream, weights_only=False)
521
522
523@functools.lru_cache(maxsize=None)
524def _new_dtypes():
525    # These are dtypes serialized as UntypedStorage unlike those in
526    # _dtype_to_storage_type_map
527    return {
528        torch.float8_e5m2,
529        torch.float8_e4m3fn,
530        torch.float8_e5m2fnuz,
531        torch.float8_e4m3fnuz,
532        torch.bits8,
533        torch.bits16,
534        torch.bits1x8,
535        torch.bits2x4,
536        torch.bits4x2,
537        torch.complex32,
538    }
539
540
@functools.lru_cache(maxsize=None)
def _dtype_to_storage_type_map():
    """Map each legacy-serializable dtype to its ``<type>Storage`` class name."""
    # NOTE: We should no longer add dtypes to this map. This map
    # is only used for BC/FC with older PyTorch versions. Going forward,
    # new dtypes of TypedStorage should not translate to a legacy
    # <type>Storage class. Instead, new dtypes of TypedStorage should
    # be serialized as an UntypedStorage paired with a torch.dtype
    legacy_pairs = (
        (torch.double, "DoubleStorage"),
        (torch.float, "FloatStorage"),
        (torch.half, "HalfStorage"),
        (torch.long, "LongStorage"),
        (torch.int, "IntStorage"),
        (torch.int16, "ShortStorage"),
        (torch.int8, "CharStorage"),
        (torch.uint8, "ByteStorage"),
        (torch.bool, "BoolStorage"),
        (torch.bfloat16, "BFloat16Storage"),
        (torch.cdouble, "ComplexDoubleStorage"),
        (torch.cfloat, "ComplexFloatStorage"),
        (torch.qint8, "QInt8Storage"),
        (torch.qint32, "QInt32Storage"),
        (torch.quint8, "QUInt8Storage"),
        (torch.quint4x2, "QUInt4x2Storage"),
        (torch.quint2x4, "QUInt2x4Storage"),
    )
    return dict(legacy_pairs)
567
568
@functools.lru_cache(maxsize=None)
def _storage_type_to_dtype_map():
    """Inverse of ``_dtype_to_storage_type_map``: legacy class name -> dtype."""
    forward = _dtype_to_storage_type_map()
    return dict(zip(forward.values(), forward.keys()))
573
574
def _get_storage_from_sequence(sequence, dtype, device):
    """Build an ``UntypedStorage`` holding ``sequence`` encoded as ``dtype`` on ``device``.

    Quantized dtypes are encoded through their plain integer counterpart
    (``quint*`` -> ``uint8``, ``qint32`` -> ``int32``, ``qint8`` -> ``int8``);
    every other dtype is passed to ``torch.tensor`` unchanged.
    """
    quantized_to_int = {
        torch.quint8: torch.uint8,
        torch.quint4x2: torch.uint8,
        torch.quint2x4: torch.uint8,
        torch.qint32: torch.int32,
        torch.qint8: torch.int8,
    }
    tensor_dtype = quantized_to_int.get(dtype, dtype)
    staging = torch.tensor(sequence, dtype=tensor_dtype, device=device)
    return staging._typed_storage()._untyped_storage
598
599
def _isint(x):
    """Return whether ``x`` is a Python ``int`` or, if numpy is available, a numpy integer."""
    integer_types = (int, np.integer) if HAS_NUMPY else (int,)
    return isinstance(x, integer_types)
605
606
# When True, the TypedStorage deprecation warning is emitted on every use
# instead of only the first time.
_always_warn_typed_storage_removal = False


def _get_always_warn_typed_storage_removal():
    """Return the current value of the always-warn flag."""
    return _always_warn_typed_storage_removal


def _set_always_warn_typed_storage_removal(always_warn):
    """Set the always-warn flag; ``always_warn`` must be a ``bool``."""
    global _always_warn_typed_storage_removal
    assert isinstance(always_warn, bool)
    _always_warn_typed_storage_removal = always_warn


def _warn_typed_storage_removal(stacklevel=2):
    """Emit the ``TypedStorage`` deprecation ``UserWarning``.

    The warning fires only on the first call (tracked by this function's
    ``has_warned`` attribute, cleared by ``_reset_warn_typed_storage_removal``)
    unless the always-warn flag is set, in which case it fires every time.

    Args:
        stacklevel: Stack level relative to this helper's caller; it is passed
            to :func:`warnings.warn` plus one to skip this helper's own frame.
    """
    already_warned = _warn_typed_storage_removal.__dict__.get("has_warned", False)
    if already_warned and not _get_always_warn_typed_storage_removal():
        return

    message = (
        "TypedStorage is deprecated. It will be removed in the future and "
        "UntypedStorage will be the only storage class. This should only matter "
        "to you if you are using storages directly.  To access UntypedStorage "
        "directly, use tensor.untyped_storage() instead of tensor.storage()"
    )
    warnings.warn(message, UserWarning, stacklevel=stacklevel + 1)
    _warn_typed_storage_removal.__dict__["has_warned"] = True


def _reset_warn_typed_storage_removal():
    """Forget that the warning was emitted, so the next call warns again."""
    _warn_typed_storage_removal.has_warned = False
642
643
644def _get_device_from_module(module: str):
645    last_part = module.rsplit(".", 1)[-1]
646    if last_part in ["cuda", torch._C._get_privateuse1_backend_name(), "hpu"]:
647        return last_part
648    else:
649        return "cpu"
650
651
652class TypedStorage:
653    is_sparse: _bool = False
654    # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True)
655    _fake_device: _Optional[torch.device] = None
656
657    dtype: torch.dtype
658
659    @property
660    def _dtype(self):
661        return self.dtype
662
663    @property
664    def filename(self) -> _Optional[str]:
665        """Returns the file name associated with this storage if the storage was memory mapped from a file.
666        or ``None`` if the storage was not created by memory mapping a file."""
667        return self._untyped_storage.filename
668
669    def fill_(self, value):
670        _warn_typed_storage_removal()
671        self._setitem(slice(0, self._size()), value)
672        return self
673
    def __new__(
        cls,
        *args,
        wrap_storage=None,
        dtype=None,
        device=None,
        _internal=False,
    ):
        """Construct a ``TypedStorage``, also serving legacy ``<type>Storage`` subclasses.

        For ``TypedStorage`` itself this just allocates the instance and leaves
        argument handling to ``__init__``. For a subclass, ``device`` and
        ``dtype`` may not be passed; the dtype comes from the class and the
        device from the last component of the class's ``__module__``. A plain
        ``TypedStorage`` built from those values is returned.

        Raises:
            RuntimeError: for ``_LegacyStorage`` itself, for a forbidden
                ``device``/``dtype`` keyword, too many positional arguments, or
                a ``wrap_storage`` on the wrong device.
            TypeError: for an unrecognized positional argument type or a
                ``wrap_storage`` that is not a ``torch.UntypedStorage``.
        """
        # Internal callers pass ``_internal=True`` to suppress the deprecation warning.
        if not _internal:
            _warn_typed_storage_removal()

        if cls == torch.storage._LegacyStorage:
            raise RuntimeError(
                "Only child classes of _LegacyStorage can be instantiated"
            )

        if cls == TypedStorage:
            return super().__new__(cls)

        else:
            arg_error_msg = (
                f"{cls}.__new__ received an invalid combination "
                f"of arguments. Expected one of:\n"
                " * no arguments\n"
                " * (int size)\n"
                " * (Sequence data)\n"
                " * (*, UntypedStorage wrap_storage)"
            )

            if device is not None:
                raise RuntimeError(
                    arg_error_msg + "\nKeyword argument 'device' cannot be specified"
                )

            if dtype is not None:
                raise RuntimeError(
                    arg_error_msg + "\nKeyword argument 'dtype' cannot be specified"
                )

            if wrap_storage is None:
                if len(args) > 1:
                    raise RuntimeError(
                        arg_error_msg + "\nToo many positional arguments"
                    )

                if (
                    len(args) == 1
                    and not _isint(args[0])
                    and not isinstance(args[0], collections.abc.Sequence)
                ):
                    raise TypeError(
                        arg_error_msg
                        + f"\nArgument type not recognized: {type(args[0])}"
                    )

                # The returned object is a plain TypedStorage, not an instance of
                # ``cls``.
                return TypedStorage(
                    *args,
                    dtype=cls._dtype,
                    device=_get_device_from_module(cls.__module__),
                    _internal=True,
                )

            else:
                if len(args) != 0:
                    raise RuntimeError(
                        arg_error_msg
                        + "\nNo positional arguments should be given when using "
                        "'wrap_storage'"
                    )

                if not isinstance(wrap_storage, torch.UntypedStorage):
                    raise TypeError(
                        arg_error_msg
                        + f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}"
                    )

                cls_device = _get_device_from_module(cls.__module__)

                if wrap_storage.device.type != cls_device:
                    raise RuntimeError(
                        arg_error_msg
                        + f"\nDevice of 'wrap_storage' must be {cls_device}"
                        f", but got {wrap_storage.device.type}"
                    )

                # NOTE(review): this branch reads ``cls.dtype`` while the branch
                # above reads ``cls._dtype``; if a subclass's ``dtype`` accessor
                # warns, this path may warn even with ``_internal=True`` — confirm.
                return TypedStorage(
                    *args,
                    wrap_storage=wrap_storage,
                    dtype=cls.dtype,
                    _internal=True,
                )
765
    def __init__(
        self,
        *args,
        device=None,
        dtype=None,
        wrap_storage=None,
        _internal=False,
    ):
        """Initialize the storage, either wrapping an existing untyped storage or allocating.

        Accepted forms (see ``arg_error_msg``): no positional argument (empty
        storage), an ``int`` element count, a ``Sequence`` of values, or
        ``wrap_storage=`` plus a required ``dtype``. Without ``wrap_storage``,
        ``dtype`` defaults to ``torch.get_default_dtype()`` and ``device`` to CPU.

        Raises:
            RuntimeError: for invalid argument combinations, or a quantized
                dtype requested on a CUDA device.
            TypeError: for a non-``torch.dtype`` ``dtype`` with ``wrap_storage``,
                a ``wrap_storage`` that is not a ``torch.UntypedStorage``, or an
                unrecognized positional argument type.
        """
        if not _internal:
            _warn_typed_storage_removal()
        arg_error_msg = (
            "TypedStorage.__init__ received an invalid combination "
            "of arguments. Expected one of:\n"
            " * (*, torch.device device, torch.dtype dtype)\n"
            " * (int size, *, torch.device device, torch.dtype dtype)\n"
            " * (Sequence data, *, torch.device device, torch.dtype dtype)\n"
            " * (*, UntypedStorage wrap_storage, torch.dtype dtype)"
        )

        if wrap_storage is not None:
            # Wrapping path: reuse the given bytes; the device is implied by them.
            if len(args) != 0:
                raise RuntimeError(
                    arg_error_msg
                    + "\nNo positional arguments should be given when using "
                    "'wrap_storage'"
                )

            if dtype is None:
                raise RuntimeError(
                    arg_error_msg + "\nArgument 'dtype' must be specified"
                )

            if not isinstance(dtype, torch.dtype):
                raise TypeError(
                    arg_error_msg
                    + f"\nArgument 'dtype' must be torch.dtype, not {type(dtype)}"
                )

            if device is not None:
                raise RuntimeError(
                    arg_error_msg
                    + "\nArgument 'device' should not be specified when 'wrap_storage' is given"
                )

            # Note: ``dtype`` is recorded before ``wrap_storage`` is type-checked.
            self.dtype = dtype

            if not isinstance(wrap_storage, torch.UntypedStorage):
                raise TypeError(
                    arg_error_msg
                    + f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}"
                )

            self._untyped_storage = wrap_storage

        else:
            # Allocating path: resolve defaults, then build new untyped storage.
            self.dtype = torch.get_default_dtype() if dtype is None else dtype
            device = torch.device("cpu" if device is None else device)

            if self.dtype in [
                torch.quint8,
                torch.quint4x2,
                torch.quint2x4,
                torch.qint32,
                torch.qint8,
            ]:
                if device.type == "cuda":
                    raise RuntimeError(
                        "Cannot create CUDA storage with quantized dtype"
                    )

            if len(args) == 0:
                self._untyped_storage = torch.UntypedStorage(device=device)

            elif len(args) == 1:
                if _isint(args[0]):
                    # Untyped storage is sized in bytes: element count times
                    # ``_element_size()`` (presumably bytes per ``self.dtype``
                    # element — defined outside this view).
                    self._untyped_storage = torch.UntypedStorage(
                        int(args[0]) * self._element_size(), device=device
                    )
                elif isinstance(args[0], collections.abc.Sequence):
                    self._untyped_storage = _get_storage_from_sequence(
                        args[0], self.dtype, device
                    )
                else:
                    raise TypeError(
                        arg_error_msg
                        + f"\nArgument type not recognized: {type(args[0])}"
                    )

            else:
                raise RuntimeError(arg_error_msg + "\nToo many positional arguments")
856
857    @property
858    def is_cuda(self):
859        _warn_typed_storage_removal()
860        return self._untyped_storage.device.type == "cuda"
861
862    @property
863    def is_hpu(self):
864        _warn_typed_storage_removal()
865        return self._untyped_storage.device.type == "hpu"
866
867    def untyped(self):
868        """Return the internal :class:`torch.UntypedStorage`."""
869        _warn_typed_storage_removal()
870        return self._untyped_storage
871
872    def _new_wrapped_storage(self, untyped_storage) -> Self:
873        assert type(untyped_storage) == torch.UntypedStorage
874
875        if type(self) == TypedStorage:
876            return cast(
877                Self,
878                TypedStorage(
879                    wrap_storage=untyped_storage, dtype=self.dtype, _internal=True
880                ),
881            )
882        else:
883            return type(self)(wrap_storage=untyped_storage)
884
885    def __len__(self):
886        _warn_typed_storage_removal()
887        return self._size()
888
889    def _maybe_wrap_index(self, idx, is_stop=False):
890        if idx is None:
891            if is_stop:
892                return self._size()
893            else:
894                return 0
895
896        else:
897            if type(idx) != int:
898                raise TypeError(f"can't index a {type(self)} with {type(idx)}")
899            if is_stop:
900                if (idx > self._size()) or (idx < -self._size()):
901                    raise IndexError(
902                        f"index {idx} out of range for storage of size {self.size()}"
903                    )
904                if idx > 0:
905                    return idx
906                else:
907                    return idx % self._size()
908            else:
909                if (idx >= self._size()) or (idx < -self._size()):
910                    raise IndexError(
911                        f"index {idx} out of range for storage of size {self.size()}"
912                    )
913                return idx % self._size()
914
915    def __setitem__(self, idx, value):
916        _warn_typed_storage_removal()
917        return self._setitem(idx, value)
918
919    def _setitem(self, idx, value):
920        if not isinstance(idx, (int, slice)):
921            raise RuntimeError(f"can't index a {type(self)} with {type(idx)}")
922        if torch.is_storage(value):
923            raise RuntimeError(f"cannot set item with value type {type(value)}")
924        if self.dtype in [
925            torch.quint8,
926            torch.quint4x2,
927            torch.quint2x4,
928            torch.qint32,
929            torch.qint8,
930        ]:
931            interpret_dtypes = {
932                torch.quint8: torch.uint8,
933                torch.quint4x2: torch.uint8,
934                torch.quint2x4: torch.uint8,
935                torch.qint32: torch.int32,
936                torch.qint8: torch.int8,
937            }
938            tmp_dtype = interpret_dtypes[self.dtype]
939            tmp_tensor = torch.tensor(
940                [], dtype=tmp_dtype, device=self._untyped_storage.device
941            )
942            tmp_tensor.set_(
943                TypedStorage(
944                    wrap_storage=self._untyped_storage, dtype=tmp_dtype, _internal=True
945                )
946            )
947        else:
948            tmp_tensor = torch.tensor(
949                [], dtype=self.dtype, device=self._untyped_storage.device
950            ).set_(self)
951
952        tmp_tensor[idx] = value
953
954    def __getitem__(self, idx):
955        _warn_typed_storage_removal()
956        return self._getitem(idx)
957
958    def _getitem(self, idx):
959        if self._untyped_storage.device.type == "meta":
960            raise NotImplementedError("Not available for 'meta' device type")
961
962        # NOTE: Before TypedStorage existed, indexing with a slice used to be
963        # possible for <type>Storage objects. However, it would return
964        # a storage view, which would be a hassle to implement in TypedStorage,
965        # so it was disabled
966        if isinstance(idx, slice):
967            raise RuntimeError(
968                "slices are only supported in UntypedStorage.__getitem__"
969            )
970        elif not isinstance(idx, int):
971            raise RuntimeError(f"can't index a {type(self)} with {type(idx)}")
972
973        if self.dtype in [
974            torch.quint8,
975            torch.quint4x2,
976            torch.quint2x4,
977            torch.qint32,
978            torch.qint8,
979        ]:
980            interpret_dtypes = {
981                torch.quint8: torch.uint8,
982                torch.quint4x2: torch.uint8,
983                torch.quint2x4: torch.uint8,
984                torch.qint32: torch.int32,
985                torch.qint8: torch.int8,
986            }
987            return TypedStorage(
988                wrap_storage=self._untyped_storage,
989                dtype=interpret_dtypes[self.dtype],
990                _internal=True,
991            )._getitem(idx)
992
993        idx_wrapped = self._maybe_wrap_index(idx)
994        from torch._subclasses.fake_tensor import unset_fake_temporarily
995
996        with unset_fake_temporarily():
997            tmp_tensor = torch.tensor(
998                [], dtype=self.dtype, device=self._untyped_storage.device
999            ).set_(self)
1000            return tmp_tensor[idx_wrapped].item()
1001
1002    def copy_(self, source: T, non_blocking: _Optional[bool] = None):
1003        _warn_typed_storage_removal()
1004        if isinstance(source, TypedStorage):
1005            self._untyped_storage.copy_(source._untyped_storage, non_blocking)
1006        else:
1007            self._untyped_storage.copy_(source, non_blocking)
1008        return self
1009
1010    def nbytes(self):
1011        _warn_typed_storage_removal()
1012        return self._nbytes()
1013
1014    # For internal use only, to avoid deprecation warning
1015    def _nbytes(self):
1016        return self._untyped_storage.nbytes()
1017
1018    def type(
1019        self,
1020        dtype: _Optional[str] = None,
1021        non_blocking: bool = False,
1022    ) -> Union[_StorageBase, TypedStorage, str]:
1023        _warn_typed_storage_removal()
1024        if dtype is None:
1025            legacy_class = self._get_legacy_storage_class()
1026
1027            if legacy_class is not None:
1028                return legacy_class.__module__ + "." + legacy_class.__name__
1029
1030            return ".".join([self.__module__, type(self).__name__])
1031
1032        else:
1033            return self._untyped_storage.type(dtype, non_blocking)
1034
1035    def cuda(self, device=None, non_blocking=False) -> Self:
1036        _warn_typed_storage_removal()
1037        if self.dtype in [
1038            torch.quint8,
1039            torch.quint4x2,
1040            torch.quint2x4,
1041            torch.qint32,
1042            torch.qint8,
1043        ]:
1044            raise RuntimeError("Cannot create CUDA storage with quantized dtype")
1045        cuda_storage = self._untyped_storage.cuda(device, non_blocking)
1046        return self._new_wrapped_storage(cuda_storage)
1047
1048    def hpu(self, device=None, non_blocking=False) -> Self:
1049        _warn_typed_storage_removal()
1050        if self.dtype in [
1051            torch.quint8,
1052            torch.quint4x2,
1053            torch.quint2x4,
1054            torch.qint32,
1055            torch.qint8,
1056        ]:
1057            raise RuntimeError("Cannot create HPU storage with quantized dtype")
1058        hpu_storage = self._untyped_storage.hpu(device, non_blocking)
1059        return self._new_wrapped_storage(hpu_storage)
1060
1061    def to(self, *, device: torch.device, non_blocking: bool = False) -> Self:
1062        _warn_typed_storage_removal()
1063        if self.dtype in [
1064            torch.quint8,
1065            torch.quint4x2,
1066            torch.quint2x4,
1067            torch.qint32,
1068            torch.qint8,
1069        ]:
1070            raise RuntimeError(
1071                f"Cannot create {device.type.upper()} storage with quantized dtype"
1072            )
1073        to_storage = self._untyped_storage.to(device=device, non_blocking=non_blocking)
1074        return self._new_wrapped_storage(to_storage)
1075
1076    def element_size(self):
1077        _warn_typed_storage_removal()
1078        return self._element_size()
1079
1080    # For internal use only, to avoid deprecation warning
1081    def _element_size(self):
1082        return torch._utils._element_size(self.dtype)
1083
1084    def get_device(self) -> _int:
1085        _warn_typed_storage_removal()
1086        return self._untyped_storage.get_device()
1087
1088    def __str__(self):
1089        _warn_typed_storage_removal()
1090        info_str = (
1091            f"[{torch.typename(self)}(dtype={self.dtype}, "
1092            f"device={self.device}) of size {len(self)}]"
1093        )
1094        if self.device.type == "meta":
1095            return "...\n" + info_str
1096        else:
1097            data_str = " " + "\n ".join(str(self[i]) for i in range(self.size()))
1098            return data_str + "\n" + info_str
1099
1100    def __repr__(self):
1101        _warn_typed_storage_removal()
1102        return str(self)
1103
1104    def __iter__(self):
1105        _warn_typed_storage_removal()
1106        return iter(self[i] for i in range(self.size()))
1107
1108    def __copy__(self):
1109        _warn_typed_storage_removal()
1110        return self._new_wrapped_storage(copy.copy(self._untyped_storage))
1111
1112    def __deepcopy__(self, memo):
1113        _warn_typed_storage_removal()
1114        return self._deepcopy(memo)
1115
1116    # For internal use only, to avoid deprecation warning
1117    def _deepcopy(self, memo):
1118        return self._new_wrapped_storage(copy.deepcopy(self._untyped_storage, memo))
1119
1120    def __sizeof__(self):
1121        _warn_typed_storage_removal()
1122        return super().__sizeof__() + self.nbytes()
1123
1124    def clone(self):
1125        """Return a copy of this storage."""
1126        _warn_typed_storage_removal()
1127        return self._new_wrapped_storage(self._untyped_storage.clone())
1128
1129    def tolist(self):
1130        """Return a list containing the elements of this storage."""
1131        _warn_typed_storage_removal()
1132        return list(self)
1133
1134    def cpu(self):
1135        """Return a CPU copy of this storage if it's not already on the CPU."""
1136        _warn_typed_storage_removal()
1137        return self._new_wrapped_storage(self._untyped_storage.cpu())
1138
1139    def is_pinned(self, device: Union[str, torch.device] = "cuda"):
1140        r"""Determine whether the CPU TypedStorage is already pinned on device.
1141
1142        Args:
1143            device (str or torch.device): The device to pin memory on. Default: ``'cuda'``
1144
1145        Returns:
1146            A boolean variable.
1147        """
1148        _warn_typed_storage_removal()
1149        return self._untyped_storage.is_pinned(device)
1150
1151    def pin_memory(self, device: Union[str, torch.device] = "cuda"):
1152        r"""Copy the CPU TypedStorage to pinned memory, if it's not already pinned.
1153
1154        Args:
1155            device (str or torch.device): The device to pin memory on. Default: ``'cuda'``.
1156
1157        Returns:
1158            A pinned CPU storage.
1159        """
1160        _warn_typed_storage_removal()
1161        return self._new_wrapped_storage(
1162            self._untyped_storage.pin_memory(device=device)
1163        )
1164
1165    def share_memory_(self):
1166        """See :meth:`torch.UntypedStorage.share_memory_`"""
1167        _warn_typed_storage_removal()
1168        return self._share_memory_()
1169
1170    # For internal use only, to avoid deprecation warning
1171    def _share_memory_(self):
1172        self._untyped_storage.share_memory_()
1173        return self
1174
1175    def _new_shared(self, size, *, device=None):
1176        """Create a new storage in shared memory with the same data type."""
1177        if device is None:
1178            device = "cpu"
1179        device = torch.device(device)
1180        untyped_storage = torch.UntypedStorage._new_shared(
1181            size * self._element_size(), device=device
1182        )
1183        return TypedStorage(
1184            wrap_storage=untyped_storage, dtype=self.dtype, _internal=True
1185        )
1186
1187    @property
1188    def _cdata(self):
1189        return self._untyped_storage._cdata
1190
1191    @property
1192    def device(self):
1193        _warn_typed_storage_removal()
1194        return self._untyped_storage.device
1195
1196    def size(self):
1197        _warn_typed_storage_removal()
1198        return self._size()
1199
1200    # For internal use only, to avoid deprecation warning
1201    def _size(self):
1202        # NB: don't indirect through __len__, as that requires
1203        # an int to be returned
1204        return self._untyped_storage.nbytes() // self._element_size()
1205
1206    def pickle_storage_type(self):
1207        _warn_typed_storage_removal()
1208        return self._pickle_storage_type()
1209
1210    # For internal use only, to avoid deprecation warning
1211    def _pickle_storage_type(self):
1212        try:
1213            return _dtype_to_storage_type_map()[self.dtype]
1214        except KeyError as e:
1215            raise KeyError(f"dtype {self.dtype} is not recognized") from e
1216
1217    def __reduce__(self):
1218        b = io.BytesIO()
1219        torch.save(self, b, _use_new_zipfile_serialization=False)
1220        return (_load_from_bytes, (b.getvalue(),))
1221
1222    def data_ptr(self):
1223        _warn_typed_storage_removal()
1224        return self._data_ptr()
1225
1226    # For internal use only, to avoid deprecation warning
1227    def _data_ptr(self):
1228        return self._untyped_storage.data_ptr()
1229
1230    def resizable(self):
1231        _warn_typed_storage_removal()
1232        return self._untyped_storage.resizable()
1233
1234    def resize_(self, size):
1235        _warn_typed_storage_removal()
1236        self._resize_(size)
1237
1238    # For internal use only, to avoid deprecation warning
1239    def _resize_(self, size):
1240        self._untyped_storage.resize_(size * self._element_size())
1241
1242    @classmethod
1243    def _free_weak_ref(cls, *args, **kwargs):
1244        return UntypedStorage._free_weak_ref(*args, **kwargs)
1245
1246    def _weak_ref(self, *args, **kwargs):
1247        return self._untyped_storage._weak_ref(*args, **kwargs)
1248
1249    @classmethod
1250    def from_buffer(cls, *args, **kwargs):
1251        _warn_typed_storage_removal()
1252        return cls._from_buffer(*args, **kwargs)
1253
1254    @classmethod
1255    def _from_buffer(cls, *args, dtype=None, device=None, **kwargs):
1256        if cls == TypedStorage:
1257            dtype = torch.get_default_dtype() if dtype is None else dtype
1258            device = torch.device("cpu" if device is None else device)
1259            if device.type != "cpu":
1260                raise RuntimeError(
1261                    f"TypedStorage.from_buffer: Not available for device {device.type}"
1262                )
1263            untyped_storage: torch.UntypedStorage = torch.UntypedStorage.from_buffer(
1264                *args, dtype=dtype, **kwargs
1265            )
1266
1267        else:
1268            if dtype is not None or len(args) == 5:
1269                raise RuntimeError(
1270                    "from_buffer: 'dtype' can only be specified in "
1271                    "UntypedStorage.from_buffer and TypedStorage.from_buffer"
1272                )
1273            if device is not None:
1274                raise RuntimeError(
1275                    "from_buffer: 'device' can only be specified in "
1276                    "UntypedStorage.from_buffer and TypedStorage.from_buffer"
1277                )
1278
1279            dtype = cls._dtype
1280            untyped_storage = torch.UntypedStorage.from_buffer(
1281                *args, dtype=dtype, **kwargs
1282            )
1283
1284        return TypedStorage(wrap_storage=untyped_storage, dtype=dtype, _internal=True)
1285
1286    def _to(self, dtype):
1287        if not isinstance(dtype, torch.dtype):
1288            raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
1289        storage = (
1290            torch.tensor([], dtype=self.dtype, device=self.device)
1291            .set_(self)
1292            .to(dtype)
1293            ._typed_storage()
1294        )
1295        if storage.data_ptr() == self.data_ptr():
1296            storage = storage.clone()
1297        return storage
1298
1299    def double(self):
1300        """Casts this storage to double type."""
1301        _warn_typed_storage_removal()
1302        return self._to(torch.double)
1303
1304    def float(self):
1305        """Casts this storage to float type."""
1306        _warn_typed_storage_removal()
1307        return self._to(torch.float)
1308
1309    def half(self):
1310        """Casts this storage to half type."""
1311        _warn_typed_storage_removal()
1312        return self._to(torch.half)
1313
1314    def long(self):
1315        """Casts this storage to long type."""
1316        _warn_typed_storage_removal()
1317        return self._to(torch.long)
1318
1319    def int(self):
1320        """Casts this storage to int type."""
1321        _warn_typed_storage_removal()
1322        return self._to(torch.int)
1323
1324    def short(self):
1325        """Casts this storage to short type."""
1326        _warn_typed_storage_removal()
1327        return self._to(torch.short)
1328
1329    def char(self):
1330        """Casts this storage to char type."""
1331        _warn_typed_storage_removal()
1332        return self._to(torch.int8)
1333
1334    def byte(self):
1335        """Casts this storage to byte type."""
1336        _warn_typed_storage_removal()
1337        return self._to(torch.uint8)
1338
1339    def bool(self):
1340        """Casts this storage to bool type."""
1341        _warn_typed_storage_removal()
1342        return self._to(torch.bool)
1343
1344    def bfloat16(self):
1345        """Casts this storage to bfloat16 type."""
1346        _warn_typed_storage_removal()
1347        return self._to(torch.bfloat16)
1348
1349    def complex_double(self):
1350        """Casts this storage to complex double type."""
1351        _warn_typed_storage_removal()
1352        return self._to(torch.cdouble)
1353
1354    def complex_float(self):
1355        """Casts this storage to complex float type."""
1356        _warn_typed_storage_removal()
1357        return self._to(torch.cfloat)
1358
1359    def float8_e5m2(self):
1360        """Casts this storage to float8_e5m2 type"""
1361        _warn_typed_storage_removal()
1362        return self._to(torch.float8_e5m2)
1363
1364    def float8_e4m3fn(self):
1365        """Casts this storage to float8_e4m3fn type"""
1366        _warn_typed_storage_removal()
1367        return self._to(torch.float8_e4m3fn)
1368
1369    def float8_e5m2fnuz(self):
1370        """Casts this storage to float8_e5m2fnuz type"""
1371        _warn_typed_storage_removal()
1372        return self._to(torch.float8_e5m2fnuz)
1373
1374    def float8_e4m3fnuz(self):
1375        """Casts this storage to float8_e4m3fnuz type"""
1376        _warn_typed_storage_removal()
1377        return self._to(torch.float8_e4m3fnuz)
1378
1379    @classmethod
1380    def from_file(cls, filename, shared, size):
1381        """from_file(filename, shared=False, size=0) -> Storage
1382
1383        Creates a CPU storage backed by a memory-mapped file.
1384
1385        If ``shared`` is ``True``, then memory is shared between all processes.
1386        All changes are written to the file. If ``shared`` is ``False``, then the changes on
1387        the storage do not affect the file.
1388
1389        ``size`` is the number of elements in the storage. If ``shared`` is ``False``,
1390        then the file must contain at least ``size * sizeof(Type)`` bytes
1391        (``Type`` is the type of storage). If ``shared`` is ``True`` the file will be created if needed.
1392
1393        Args:
1394            filename (str): file name to map
1395            shared (bool): whether to share memory (whether ``MAP_SHARED`` or ``MAP_PRIVATE`` is passed to the
1396                            underlying `mmap(2) call <https://man7.org/linux/man-pages/man2/mmap.2.html>`_)
1397            size (int): number of elements in the storage
1398        """
1399        _warn_typed_storage_removal()
1400        if cls == TypedStorage:
1401            raise RuntimeError("from_file can only be called on derived classes")
1402        untyped_storage = UntypedStorage.from_file(
1403            filename, shared, size * torch._utils._element_size(cls.dtype)
1404        )
1405        storage = cls(wrap_storage=untyped_storage)
1406        return storage
1407
1408    @classmethod
1409    def _expired(cls, *args, **kwargs):
1410        return UntypedStorage._expired(*args, **kwargs)
1411
1412    def _write_file(self, *args, **kwargs):
1413        return self._untyped_storage._write_file(*args, **kwargs)
1414
1415    def _set_from_file(self, *args, **kwargs):
1416        return self._untyped_storage._set_from_file(*args, **kwargs)
1417
1418    def _set_cdata(self, *args, **kwargs):
1419        return self._untyped_storage._set_cdata(*args, **kwargs)
1420
1421    def _share_cuda_(self, *args, **kwargs):
1422        return self._untyped_storage._share_cuda_(*args, **kwargs)
1423
1424    def is_shared(self):
1425        _warn_typed_storage_removal()
1426        return self._is_shared()
1427
1428    # For internal use only, to avoid deprecation warning
1429    def _is_shared(self):
1430        return self._untyped_storage.is_shared()
1431
1432    @classmethod
1433    def _new_shared_cuda(cls, *args, **kwargs):
1434        return torch.UntypedStorage._new_shared_cuda(*args, **kwargs)
1435
1436    def _share_filename_cpu_(self, *args, **kwargs):
1437        (
1438            manager_handle,
1439            storage_handle,
1440            size,
1441        ) = self._untyped_storage._share_filename_cpu_(*args, **kwargs)
1442        return manager_handle, storage_handle, size // self._element_size()
1443
1444    def _shared_decref(self):
1445        self._untyped_storage._shared_decref()
1446        return self
1447
1448    @classmethod
1449    def _release_ipc_counter(cls, *args, device=None, **kwargs):
1450        return torch.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)
1451
1452    def _shared_incref(self, *args, **kwargs):
1453        return self._untyped_storage._shared_incref(*args, **kwargs)
1454
1455    def _share_fd_cpu_(self, *args, **kwargs):
1456        fd, size = self._untyped_storage._share_fd_cpu_(*args, **kwargs)
1457        return fd, size // self._element_size()
1458
1459    def _get_legacy_storage_class(self):
1460        if self.dtype not in _dtype_to_storage_type_map():
1461            return None
1462
1463        storage_name = _dtype_to_storage_type_map()[self.dtype]
1464
1465        if self.device.type not in [
1466            "cpu",
1467            "cuda",
1468            "hpu",
1469            torch._C._get_privateuse1_backend_name(),
1470        ]:
1471            return None
1472
1473        module = (
1474            torch if self.device.type == "cpu" else getattr(torch, self.device.type)
1475        )
1476
1477        try:
1478            return getattr(module, storage_name)
1479        except AttributeError:
1480            return None
1481
1482
# These TypedStorage overrides share their user-facing docs with the
# implementations they mirror: ``_type``/``_to`` (imported from torch._utils)
# and the corresponding ``_StorageBase`` methods. The overrides themselves
# carry only comments, so the docstrings are attached here after the class
# is defined.
TypedStorage.type.__doc__ = _type.__doc__
TypedStorage.cuda.__doc__ = _StorageBase.cuda.__doc__
TypedStorage.hpu.__doc__ = _StorageBase.hpu.__doc__
TypedStorage.to.__doc__ = _to.__doc__
1487
1488
class _LegacyStorageMeta(type):
    """Metaclass for the legacy storage classes.

    It customizes ``isinstance`` so that a plain :class:`TypedStorage` object
    counts as an instance of a legacy class when its device type matches the
    class's module device and its dtype matches the class dtype.
    """

    dtype: torch.dtype

    def __instancecheck__(cls, instance):
        # Only exact TypedStorage objects are matched by this rule.
        if type(instance) != TypedStorage:
            return False
        same_device = _get_device_from_module(cls.__module__) == instance.device.type
        return same_device and (cls.dtype == instance.dtype)
1499
1500
class _LegacyStorage(TypedStorage, metaclass=_LegacyStorageMeta):
    """Common base of the legacy per-dtype storage classes.

    The element dtype is read from ``cls.dtype``.
    """

    @classmethod
    def _new_shared(cls, size):
        """Create a new storage in shared memory with the same data type.

        Args:
            size: number of elements (converted to bytes via the dtype size).
        """
        # Derive the element size from the class dtype, as
        # ``_new_shared_filename`` does, instead of instantiating a throwaway
        # ``cls()`` (an extra allocation that goes through the constructor)
        # just to query it.
        element_size = torch._utils._element_size(cls.dtype)
        untyped_storage = torch.UntypedStorage._new_shared(size * element_size)
        return cls(wrap_storage=untyped_storage)

    @classmethod
    def _release_ipc_counter(cls, *args, **kwargs):
        return torch.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)

    @classmethod
    def _new_shared_filename(cls, manager, obj, size):
        """Wrap a filename-shared CPU untyped storage of ``size`` elements."""
        bytes_size = size * torch._utils._element_size(cls.dtype)
        return cls(
            wrap_storage=torch.UntypedStorage._new_shared_filename_cpu(
                manager, obj, bytes_size
            )
        )
1520
1521
def _get_dtype_from_pickle_storage_type(pickle_storage_type: str):
    """Map a pickled legacy storage type name back to its :class:`torch.dtype`.

    Raises:
        KeyError: if the name is not in the storage-type-to-dtype map.
    """
    try:
        resolved = _storage_type_to_dtype_map()[pickle_storage_type]
    except KeyError as err:
        raise KeyError(
            f'pickle storage type "{pickle_storage_type}" is not recognized'
        ) from err
    return resolved
1529