xref: /aosp_15_r20/external/pytorch/torch/_tensor.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import copyreg
3import enum
4import functools
5import warnings
6from collections import OrderedDict
7from copy import deepcopy
8from numbers import Number
9from typing import Any, Dict, Optional, Tuple, Union
10
11import torch
12import torch._C as _C
13from torch._namedtensor_internals import (
14    check_serializing_named_tensor,
15    is_ellipsis,
16    resolve_ellipsis,
17    single_ellipsis_index,
18    unzip_namedshape,
19    update_names,
20)
21from torch.overrides import (
22    get_default_nowrap_functions,
23    handle_torch_function,
24    has_torch_function,
25    has_torch_function_unary,
26    has_torch_function_variadic,
27)
28
29
30def _handle_torch_function_and_wrap_type_error_to_not_implemented(f):
31    assigned = functools.WRAPPER_ASSIGNMENTS
32
33    @functools.wraps(f, assigned=assigned)
34    def wrapped(*args, **kwargs):
35        try:
36            # See https://github.com/pytorch/pytorch/issues/75462
37            if has_torch_function(args):
38                return handle_torch_function(wrapped, args, *args, **kwargs)
39            return f(*args, **kwargs)
40        except TypeError:
41            return NotImplemented
42
43    return wrapped
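
# A minimal sketch of the wrapper's effect (illustrative): a TypeError raised
# inside the wrapped op is turned into NotImplemented, so Python's binary-operator
# protocol can fall back to the other operand's reflected method:
#
#   >>> t = torch.tensor([1.0, 2.0])
#   >>> t.__rsub__(object())  # the underlying rsub raises TypeError for this operand
#   NotImplemented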
44
45
46# Should not be used; this is kept only for BC when loading old serialized Tensor subclasses
47def _rebuild_from_type(func, type, args, dict):
48    if type is Tensor:
49        return func(*args)
50
51    ret = func(*args).as_subclass(type)
52    ret.__dict__ = dict
53    return ret
54
55
56def _rebuild_from_type_v2(func, new_type, args, state):
57    ret = func(*args)
58    if type(ret) is not new_type:
59        ret = ret.as_subclass(new_type)
60    # Tensor does define __setstate__ even though it doesn't define
61    # __getstate__. So only use __setstate__ if it is NOT the one defined
62    # on Tensor
63    if (
64        getattr(ret.__class__, "__setstate__", Tensor.__setstate__)
65        is not Tensor.__setstate__
66    ):
67        ret.__setstate__(state)
68    else:
69        ret = torch._utils._set_obj_state(ret, state)
70    return ret
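
# Rough round-trip sketch (illustrative; assumes a CPU tensor and a subclass that
# pickle can resolve by name): __reduce_ex__ below hands this function the rebuild
# callable, the subclass type, its args, and the Python state:
#
#   >>> import pickle
#   >>> class MyTensor(torch.Tensor): pass
#   >>> t = torch.randn(2).as_subclass(MyTensor)
#   >>> t.note = "metadata"
#   >>> u = pickle.loads(pickle.dumps(t))
#   >>> type(u) is MyTensor, u.note
#   (True, 'metadata')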
71
72
73# NB: If you subclass Tensor, and want to share the subclassed class
74# across processes, you must also update torch/multiprocessing/reductions.py
75# to define a ForkingPickler serialization mode for the class.
76#
77# NB: If you add a new method to Tensor, you must update
78# torch/_C/__init__.pyi.in to add a type annotation for your method;
79# otherwise, it will not show up in autocomplete.
80class Tensor(torch._C.TensorBase):
81    def __deepcopy__(self, memo):
82        if has_torch_function_unary(self):
83            return handle_torch_function(Tensor.__deepcopy__, (self,), self, memo)
84        if not self.is_leaf:
85            raise RuntimeError(
86                "Only Tensors created explicitly by the user "
87                "(graph leaves) support the deepcopy protocol at the moment.  "
88                "If you were attempting to deepcopy a module, this may be because "
89                "of a torch.nn.utils.weight_norm usage, "
90                "see https://github.com/pytorch/pytorch/pull/103001"
91            )
92        if id(self) in memo:
93            return memo[id(self)]
94        with torch.no_grad():
95            # TODO: skipping storage copy is wrong for meta, as meta
96            # does accurate alias tracking; however, the code below
97            # doesn't work because of
98            # https://github.com/pytorch/pytorch/issues/47442
99            # Update the test in test_serialization if you remove 'meta' from here
100            if (
101                self.is_sparse
102                or self.device.type
103                in ["lazy", "xla", "mtia", "mps", "maia", "meta", "ipu"]
104                or (
105                    not torch._C._has_storage(self)
106                    and self.device.type == torch._C._get_privateuse1_backend_name()
107                )
108                or (type(self) is not Tensor and self.data_ptr() == 0)
109            ):
110                new_tensor = self.clone()
111                if type(new_tensor) is not type(self):
112                    raise RuntimeError(
113                        "The default implementation of __deepcopy__() for wrapper subclasses "
114                        "only works for subclass types that implement clone() and for which "
115                        "cloning returns another instance of the same subclass. You should either "
116                        "properly implement clone() for your subclass or override __deepcopy__() "
117                        "if it is intended behavior for clone() to return an instance of a "
118                        "different type."
119                    )
120            else:
121                new_storage = self._typed_storage()._deepcopy(memo)
122                if self.is_quantized:
123                    # quantizer_params can be of different types depending on the qscheme
124                    quantizer_params: Union[
125                        Tuple[torch.qscheme, float, int],
126                        Tuple[torch.qscheme, Tensor, Tensor, int],
127                    ]
128                    if self.qscheme() == torch.per_tensor_affine:
129                        quantizer_params = (
130                            self.qscheme(),
131                            self.q_scale(),
132                            self.q_zero_point(),
133                        )
134                    elif self.qscheme() in (
135                        torch.per_channel_affine,
136                        torch.per_channel_affine_float_qparams,
137                    ):
138                        quantizer_params = (
139                            self.qscheme(),
140                            self.q_per_channel_scales(),
141                            self.q_per_channel_zero_points(),
142                            self.q_per_channel_axis(),
143                        )
144                    else:
145                        raise RuntimeError(
146                            f"Unsupported qscheme {self.qscheme()} in deepcopy"
147                        )
148                    # TODO: Once we decide to break serialization FC, no longer
149                    # need to wrap with TypedStorage
150                    new_tensor = torch._utils._rebuild_qtensor(
151                        torch.storage.TypedStorage(
152                            wrap_storage=new_storage._untyped_storage,
153                            dtype=self.dtype,
154                            _internal=True,
155                        ),
156                        self.storage_offset(),
157                        self.size(),
158                        self.stride(),
159                        quantizer_params,
160                        self.requires_grad,
161                        self._backward_hooks,
162                    )
163                    if type(new_tensor) is not type(self):
164                        raise RuntimeError(
165                            "The default implementation of __deepcopy__() for quantized tensors "
166                            "expects the tensor returned by torch._utils._rebuild_qtensor() to "
167                            "match the type of the instance being copied. If you encounter this, "
168                            "please open an issue on PyTorch's GitHub."
169                        )
170                else:
171                    new_tensor = self.new_empty([])
172                    if type(new_tensor) is not type(self):
173                        raise RuntimeError(
174                            "The default implementation of __deepcopy__() for non-wrapper subclasses "
175                            "only works for subclass types that implement new_empty() and for which "
176                            "that function returns another instance of the same subclass. You should "
177                            "either properly implement new_empty() for your subclass or override "
178                            "__deepcopy__() if it is intended behavior for new_empty() to return "
179                            "an instance of a different type."
180                        )
181                    new_tensor.set_(
182                        new_storage, self.storage_offset(), self.size(), self.stride()
183                    )
184                    if self.is_conj():
185                        new_tensor = new_tensor.conj_physical()
186                    if self.is_neg():
187                        new_tensor = new_tensor.neg()
188            if self.requires_grad:
189                new_tensor.requires_grad_()
190            if self.grad is not None:
191                new_tensor.grad = self.grad.__deepcopy__(memo)
192
193            if type(self) is not Tensor:
194                if type(new_tensor) is not type(self):
195                    raise RuntimeError(
196                        "Type of deepcopy result does not match the type of the source tensor. "
197                        "If you encounter this, please open an issue on PyTorch's GitHub."
198                    )
199
200                # Plain Tensors don't have slots
201                slots_to_save = copyreg._slotnames(self.__class__)  # type: ignore[attr-defined]
202                for slot in slots_to_save:
203                    if hasattr(self, slot):
204                        setattr(new_tensor, slot, deepcopy(getattr(self, slot), memo))
205
206            new_tensor.__dict__ = deepcopy(self.__dict__, memo)
207
208            memo[id(self)] = new_tensor
209            return new_tensor
210
211    def __reduce_ex__(self, proto):
212        materialize_fake_tensors = (
213            torch.serialization._serialization_tls.materialize_fake_tensors
214        )
215        state = torch._utils._get_obj_state(self)
216        # Ignore all state when using FakeTensor with skip_data(materialize_fake_tensors) because FakeTensor has
217        # some state that cannot be pickled
218        if (
219            # TODO: remove hasattr, it's a hack to support versions of torch that
220            # don't have _subclasses
221            hasattr(torch, "_subclasses")
222            and type(self) is torch._subclasses.fake_tensor.FakeTensor
223            and materialize_fake_tensors
224        ) or (type(self) is Tensor and not state):
225            # Fast path for regular tensor without Python state.
226            return self._reduce_ex_internal(proto)
227        if has_torch_function_unary(self):
228            return handle_torch_function(Tensor.__reduce_ex__, (self,), self, proto)
229        func, args = self._reduce_ex_internal(proto)
230        return (_rebuild_from_type_v2, (func, type(self), args, state))
231
232    def storage(self):
233        r"""
234        storage() -> torch.TypedStorage
235
236        Returns the underlying :class:`TypedStorage`.
237
238        .. warning::
239
240            :class:`TypedStorage` is deprecated. It will be removed in the future, and
241            :class:`UntypedStorage` will be the only storage class. To access the
242            :class:`UntypedStorage` directly, use :attr:`Tensor.untyped_storage()`.
243        """
244        if has_torch_function_unary(self):
245            return handle_torch_function(Tensor.storage, (self,), self)
246
247        torch.storage._warn_typed_storage_removal(stacklevel=2)
248        return self._typed_storage()
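
    # A small usage sketch (illustrative): ``untyped_storage()`` is the
    # non-deprecated way to reach the underlying bytes, while ``storage()``
    # above warns and wraps them in a TypedStorage:
    #
    #   >>> t = torch.arange(4, dtype=torch.float32)
    #   >>> t.untyped_storage().nbytes()
    #   16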
249
250    # For internal use only, to avoid raising deprecation warning
251    def _typed_storage(self):
252        untyped_storage = self.untyped_storage()
253        return torch.TypedStorage(
254            wrap_storage=untyped_storage, dtype=self.dtype, _internal=True
255        )
256
257    def _reduce_ex_internal(self, proto):
258        check_serializing_named_tensor(self)
259
260        from torch.utils.hooks import warn_if_has_hooks
261
262        # See Note [Don't serialize hooks]
263        warn_if_has_hooks(self)
264        backward_hooks: Dict[Any, Any] = OrderedDict()
265
266        skip_data = torch.serialization._serialization_tls.skip_data
267        materialize_fake_tensors = (
268            torch.serialization._serialization_tls.materialize_fake_tensors
269        )
270
271        # Note: a numpy array is chosen as the rebuild component for XLA, MTIA, and MAIA tensors.
272        # We considered a few options:
273        # 1. A CPU tensor can't be used here.
274        #    Otherwise, in torch.load the CPU storage is reconstructed with randomly
275        #    initialized data, moved onto the backend device, and then the storage is updated
276        #    with the serialized content. This works perfectly for CPU/CUDA but not for these backends;
277        #    their tensors are disconnected from their storage, so they don't get the update.
278        # 2. A Python list is not a good fit for performance reasons:
279        #    `tolist()` converts every single element in the tensor into a Python object
280        #    and serializes them one by one.
281        if self.device.type in ["xla", "mtia", "maia"] or (
282            not torch._C._has_storage(self)
283            and self.device.type == torch._C._get_privateuse1_backend_name()
284        ):
285            # Convert BFloat16 tensors to Float32 before conversion to numpy, as numpy doesn't
286            # support BFloat16. The rebuild-from-numpy path takes in the original self.dtype,
287            # so the BFloat16 tensor is reconstructed correctly from the numpy array.
288            if skip_data:
289                raise RuntimeError(
290                    "Cannot serialize tensors on backends with no storage under skip_data context manager"
291                )
292            numpy_tensor = (
293                self.cpu().numpy()
294                if self.dtype != torch.bfloat16
295                else self.cpu().to(torch.float32).numpy()
296            )
297            return (
298                torch._utils._rebuild_device_tensor_from_numpy,
299                (numpy_tensor, self.dtype, str(self.device), self.requires_grad),
300            )
301        if self.device.type == "meta":
302            # NB: This implementation BREAKS storage sharing.  Current
303            # hypothesis is that no one cares for meta tensors.
304            if skip_data:
305                warnings.warn(
306                    "Serializing tensors on the meta device under skip_data context manager is a no-op"
307                )
308            arg_meta = (
309                self.dtype,
310                tuple(self.size()),
311                self.stride(),
312                self.requires_grad,
313            )
314            return (torch._utils._rebuild_meta_tensor_no_storage, arg_meta)
315        if self.is_quantized:
316            if skip_data:
317                raise RuntimeError(
318                    "Cannot serialize qtensor under skip_data context manager, file an issue if you need this feature"
319                )
320            # quantizer_params can be of different types depending on the qscheme
321            quantizer_params: Union[
322                Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int]
323            ]
324            if self.qscheme() == torch.per_tensor_affine:
325                quantizer_params = (
326                    torch.per_tensor_affine,
327                    self.q_scale(),
328                    self.q_zero_point(),
329                )
330            elif self.qscheme() in (
331                torch.per_channel_affine,
332                torch.per_channel_affine_float_qparams,
333            ):
334                # convert scales and zero points to tuples to avoid recursive calls.
335                # When/if we get multi-axis quantized tensors in the future, the shape
336                # is recoverable from the main tensor shape.
337                quantizer_params = (
338                    torch.per_channel_affine,
339                    self.q_per_channel_scales(),
340                    self.q_per_channel_zero_points(),
341                    self.q_per_channel_axis(),
342                )
343            else:
344                raise RuntimeError(
345                    f"Serialization is not supported for tensors of type {self.qscheme()}"
346                )
347            # TODO: Once we decide to break serialization FC, no longer
348            # need to wrap with TypedStorage
349            args_qtensor = (
350                torch.storage.TypedStorage(
351                    wrap_storage=self._typed_storage()._untyped_storage,
352                    dtype=self.dtype,
353                    _internal=True,
354                ),
355                self.storage_offset(),
356                tuple(self.size()),
357                self.stride(),
358                quantizer_params,
359                self.requires_grad,
360                backward_hooks,
361            )
362            return (torch._utils._rebuild_qtensor, args_qtensor)
363        elif self.is_sparse:
364            if self.layout == torch.sparse_coo:
365                args_sparse = (
366                    self.layout,
367                    (self._indices(), self._values(), self.size(), self.is_coalesced()),
368                )
369            else:
370                raise NotImplementedError(
371                    f"sparse tensor __reduce_ex__ for layout `{self.layout}`"
372                )
373            return (torch._utils._rebuild_sparse_tensor, args_sparse)
374        elif self.layout in {
375            torch.sparse_csr,
376            torch.sparse_csc,
377            torch.sparse_bsr,
378            torch.sparse_bsc,
379        }:
380            if self.layout in {torch.sparse_csr, torch.sparse_bsr}:
381                compressed_indices, plain_indices = (
382                    self.crow_indices(),
383                    self.col_indices(),
384                )
385            else:
386                compressed_indices, plain_indices = (
387                    self.ccol_indices(),
388                    self.row_indices(),
389                )
390            args_sparse_compressed = (
391                self.layout,
392                (
393                    compressed_indices,
394                    plain_indices,
395                    self.values(),
396                    self.size(),
397                ),
398            )
399            return (torch._utils._rebuild_sparse_tensor, args_sparse_compressed)
400        elif self.is_nested:
401            if skip_data:
402                raise RuntimeError(
403                    "Cannot serialize nested tensor under skip_data context manager, file an issue if you need this feature"
404                )
405            args_nested = (
406                # NB: values() currently returns the storage as a buffer in an unsafe way.
407                # Ideally, we'd use a private API for this instead. TODO: Switch to this if
408                # we ever get around to adding it.
409                self.values(),
410                self._nested_tensor_size(),
411                self._nested_tensor_strides(),
412                self._nested_tensor_storage_offsets(),
413            )
414            return (torch._utils._rebuild_nested_tensor, args_nested)
415        elif (
416            type(self) is not torch.Tensor
417            and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__
418            and (
419                isinstance(self, torch._subclasses.functional_tensor.FunctionalTensor)
420                or (
421                    not isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
422                    and self.data_ptr() == 0
423                )
424            )
425        ):
426            arg_wrapper_subclass = (
427                type(self),
428                self.dtype,
429                tuple(self.size()),
430                self.stride(),
431                self.storage_offset(),
432                self.layout,
433                self.device,
434                self.requires_grad,
435            )
436            return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass)
437        elif (
438            type(self) is not torch.Tensor
439            and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__
440            and (
441                isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
442                and not (skip_data and materialize_fake_tensors)
443            )
444        ):
445            arg_wrapper_subclass = (
446                type(self),
447                self.dtype,
448                tuple(self.size()),
449                self.stride(),
450                self.storage_offset(),
451                self.layout,
452                self.device,
453                self.requires_grad,
454            )
455            return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass)
456        else:
457            v3_dtypes = torch.storage._new_dtypes()
458            if self.dtype in v3_dtypes:
459                rebuild_func = torch._utils._rebuild_tensor_v3
460                storage = self.untyped_storage()
461            else:
462                # TODO: Once we decide to break serialization FC, no longer
463                # need to wrap with TypedStorage
464                rebuild_func = torch._utils._rebuild_tensor_v2  # type: ignore[assignment]
465                storage = torch.storage.TypedStorage(
466                    wrap_storage=self._typed_storage()._untyped_storage,
467                    dtype=self.dtype,
468                    _internal=True,
469                )  # type: ignore[assignment]
470
471            # TODO: remove hasattr, it's a hack to support versions of torch that
472            # don't have _subclasses
473            if (
474                hasattr(torch, "_subclasses")
475                and isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
476                and skip_data
477            ):
478                storage._fake_device = self.device
479
480            args = (
481                storage,
482                self.storage_offset(),
483                tuple(self.size()),
484                self.stride(),
485                self.requires_grad,
486                backward_hooks,
487            )  # previously was self._backward_hooks
488
489            if isinstance(storage, torch.storage.UntypedStorage):
490                args = args + (self.dtype,)  # type: ignore[assignment]
491
492            metadata = torch._utils.get_tensor_metadata(self)
493            if metadata:
494                args = args + (metadata,)  # type: ignore[assignment]
495
496            return (rebuild_func, args)
497
498    def __setstate__(self, state):
499        if has_torch_function_unary(self):
500            return handle_torch_function(Tensor.__setstate__, (self,), self, state)
501        # Warning: this method is NOT called when you torch.load() a tensor;
502        # that is managed by _rebuild_tensor_v2
503        if not self.is_leaf:
504            raise RuntimeError("__setstate__ can be only called on leaf Tensors")
505        if len(state) == 4:
506            # legacy serialization of Tensor
507            self.set_(*state)
508            return
509        elif len(state) == 5:
510            # legacy serialization of Variable
511            self.data = state[0]
512            state = (state[3], state[4], state[2])
513        # The setting of _backward_hooks is expected to be a no-op.
514        # See Note [Don't serialize hooks]
515        self.requires_grad, _, self._backward_hooks = state
516
517    def __repr__(self, *, tensor_contents=None):
518        if has_torch_function_unary(self):
519            return handle_torch_function(
520                Tensor.__repr__, (self,), self, tensor_contents=tensor_contents
521            )
522        # All strings are unicode in Python 3.
523        return torch._tensor_str._str(self, tensor_contents=tensor_contents)
524
525    def backward(
526        self, gradient=None, retain_graph=None, create_graph=False, inputs=None
527    ):
528        r"""Computes the gradient of current tensor wrt graph leaves.
529
530        The graph is differentiated using the chain rule. If the tensor is
531        non-scalar (i.e. its data has more than one element) and requires
532        gradient, the function additionally requires specifying a ``gradient``.
533        It should be a tensor of matching type and shape, that represents
534        the gradient of the differentiated function w.r.t. ``self``.
535
536        This function accumulates gradients in the leaves - you might need to zero
537        ``.grad`` attributes or set them to ``None`` before calling it.
538        See :ref:`Default gradient layouts<default-grad-layouts>`
539        for details on the memory layout of accumulated gradients.
540
541        .. note::
542
543            If you run any forward ops, create ``gradient``, and/or call ``backward``
544            in a user-specified CUDA stream context, see
545            :ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`.
546
547        .. note::
548
549            When ``inputs`` are provided and a given input is not a leaf,
550            the current implementation will call its grad_fn (though it is not strictly needed to get these gradients).
551            It is an implementation detail on which the user should not rely.
552            See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.
553
554        Args:
555            gradient (Tensor, optional): The gradient of the function
556                being differentiated w.r.t. ``self``.
557                This argument can be omitted if ``self`` is a scalar.
558            retain_graph (bool, optional): If ``False``, the graph used to compute
559                the grads will be freed. Note that in nearly all cases setting
560                this option to True is not needed and often can be worked around
561                in a much more efficient way. Defaults to the value of
562                ``create_graph``.
563            create_graph (bool, optional): If ``True``, the graph of the derivative will
564                be constructed, allowing higher order derivative
565                products to be computed. Defaults to ``False``.
566            inputs (sequence of Tensor, optional): Inputs w.r.t. which the gradient will be
567                accumulated into ``.grad``. All other tensors will be ignored. If not
568                provided, the gradient is accumulated into all the leaf Tensors that were
569                used to compute the current tensor.
570        """
571        if has_torch_function_unary(self):
572            return handle_torch_function(
573                Tensor.backward,
574                (self,),
575                self,
576                gradient=gradient,
577                retain_graph=retain_graph,
578                create_graph=create_graph,
579                inputs=inputs,
580            )
581        torch.autograd.backward(
582            self, gradient, retain_graph, create_graph, inputs=inputs
583        )
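
    # A minimal usage sketch (illustrative): a scalar output needs no explicit
    # ``gradient`` argument, and the result accumulates into the leaf's ``.grad``:
    #
    #   >>> x = torch.tensor([1.0, 2.0], requires_grad=True)
    #   >>> (x * x).sum().backward()
    #   >>> x.grad
    #   tensor([2., 4.])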
584
585    def register_hook(self, hook):
586        r"""Registers a backward hook.
587
588        The hook will be called every time a gradient with respect to the
589        Tensor is computed. The hook should have the following signature::
590
591            hook(grad) -> Tensor or None
592
593
594        The hook should not modify its argument, but it can optionally return
595        a new gradient which will be used in place of :attr:`grad`.
596
597        This function returns a handle with a method ``handle.remove()``
598        that removes the hook from the tensor.
599
600        .. note::
601            See :ref:`backward-hooks-execution` for more information on how and when this hook
602            is executed, and how its execution is ordered relative to other hooks.
603
604        Example::
605
606            >>> v = torch.tensor([0., 0., 0.], requires_grad=True)
607            >>> h = v.register_hook(lambda grad: grad * 2)  # double the gradient
608            >>> v.backward(torch.tensor([1., 2., 3.]))
609            >>> v.grad
610            tensor([2., 4., 6.])
615
616            >>> h.remove()  # removes the hook
617        """
618        if has_torch_function_unary(self):
619            return handle_torch_function(Tensor.register_hook, (self,), self, hook)
620        if not self.requires_grad:
621            raise RuntimeError(
622                "cannot register a hook on a tensor that doesn't require gradient"
623            )
624        if self._backward_hooks is None:
625            self._backward_hooks = OrderedDict()
626            if self.grad_fn is not None:
627                self.grad_fn._register_hook_dict(self)
628
629        from torch.utils.hooks import RemovableHandle
630
631        handle = RemovableHandle(self._backward_hooks)
632        self._backward_hooks[handle.id] = hook
633        return handle
634
635    def register_post_accumulate_grad_hook(self, hook):
636        r"""Registers a backward hook that runs after grad accumulation.
637
638        The hook will be called after all gradients for a tensor have been accumulated,
639        meaning that the .grad field has been updated on that tensor. The post
640        accumulate grad hook is ONLY applicable for leaf tensors (tensors without a
641        .grad_fn field). Registering this hook on a non-leaf tensor will error!
642
643        The hook should have the following signature::
644
645            hook(param: Tensor) -> None
646
647        Note that, unlike other autograd hooks, this hook operates on the tensor
648        that requires grad and not the grad itself. The hook can access and modify
649        its Tensor argument in place, including its .grad field.
650
651        This function returns a handle with a method ``handle.remove()``
652        that removes the hook from the tensor.
653
654        .. note::
655            See :ref:`backward-hooks-execution` for more information on how and when this hook
656            is executed, and how its execution is ordered relative to other hooks. Since
657            this hook runs during the backward pass, it will run in no_grad mode (unless
658            create_graph is True). You can use torch.enable_grad() to re-enable autograd
659            within the hook if you need it.
660
661        Example::
662
663            >>> v = torch.tensor([0., 0., 0.], requires_grad=True)
664            >>> lr = 0.01
665            >>> # simulate a simple SGD update
666            >>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr))
667            >>> v.backward(torch.tensor([1., 2., 3.]))
668            >>> v
669            tensor([-0.0100, -0.0200, -0.0300], requires_grad=True)
670
671            >>> h.remove()  # removes the hook
672        """
673        if has_torch_function_unary(self):
674            return handle_torch_function(
675                Tensor.register_post_accumulate_grad_hook, (self,), self, hook
676            )
677        if not self.requires_grad:
678            raise RuntimeError(
679                "cannot register a hook on a tensor that doesn't require gradient"
680            )
681        if self.grad_fn is not None:
682            raise RuntimeError(
683                "post accumulate grad hooks cannot be registered on non-leaf tensors"
684            )
685        if self._post_accumulate_grad_hooks is None:
686            self._post_accumulate_grad_hooks: Dict[Any, Any] = OrderedDict()
687
688        from torch.utils.hooks import RemovableHandle
689
690        handle = RemovableHandle(self._post_accumulate_grad_hooks)
691        self._post_accumulate_grad_hooks[handle.id] = hook
692        return handle
693
694    def reinforce(self, reward):
695        def trim(str):
696            return "\n".join([line.strip() for line in str.split("\n")])
697
698        raise RuntimeError(
699            trim(
700                r"""reinforce() was removed.
701            Use torch.distributions instead.
702            See https://pytorch.org/docs/main/distributions.html
703
704            Instead of:
705
706            probs = policy_network(state)
707            action = probs.multinomial()
708            next_state, reward = env.step(action)
709            action.reinforce(reward)
710            action.backward()
711
712            Use:
713
714            probs = policy_network(state)
715            # NOTE: categorical is equivalent to what used to be called multinomial
716            m = torch.distributions.Categorical(probs)
717            action = m.sample()
718            next_state, reward = env.step(action)
719            loss = -m.log_prob(action) * reward
720            loss.backward()
721        """
722            )
723        )
724
725    detach = _C._add_docstr(
726        _C.TensorBase.detach,
727        r"""
728    Returns a new Tensor, detached from the current graph.
729
730    The result will never require gradient.
731
732    This method also affects forward mode AD gradients and the result will never
733    have forward mode AD gradients.
734
735    .. note::
736
737      Returned Tensor shares the same storage with the original one.
738      In-place modifications on either of them will be seen, and may trigger
739      errors in correctness checks.
740    """,
741    )
742
743    detach_ = _C._add_docstr(
744        _C.TensorBase.detach_,
745        r"""
746    Detaches the Tensor from the graph that created it, making it a leaf.
747    Views cannot be detached in-place.
748
749    This method also affects forward mode AD gradients and the result will never
750    have forward mode AD gradients.
751    """,
752    )
753
754    def is_shared(self):
755        r"""Checks if tensor is in shared memory.
756
757        This is always ``True`` for CUDA tensors.
758        """
759        if has_torch_function_unary(self):
760            return handle_torch_function(Tensor.is_shared, (self,), self)
761        return self._typed_storage()._is_shared()
762
763    def share_memory_(self):
764        r"""Moves the underlying storage to shared memory.
765
766        This is a no-op if the underlying storage is already in shared memory
767        and for CUDA tensors. Tensors in shared memory cannot be resized.
768
769        See :meth:`torch.UntypedStorage.share_memory_` for more details.
770        """
771        if has_torch_function_unary(self):
772            return handle_torch_function(Tensor.share_memory_, (self,), self)
773        self._typed_storage()._share_memory_()
774        return self
775
776    def module_load(self, other, assign=False):
777        r"""Defines how to transform ``other`` when loading it into ``self`` in :meth:`~nn.Module.load_state_dict`.
778
779        Used when :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.
780
781        It is expected that ``self`` is a parameter or buffer in an ``nn.Module`` and ``other`` is the
782        value in the state dictionary with the corresponding key; this method defines
783        how ``other`` is remapped before being swapped with ``self`` via
784        :func:`~torch.utils.swap_tensors` in :meth:`~nn.Module.load_state_dict`.
785
786        .. note::
787            This method should always return a new object that is not ``self`` or ``other``.
788            For example, the default implementation returns ``self.copy_(other).detach()``
789            if ``assign`` is ``False`` or ``other.detach()`` if ``assign`` is ``True``.
790
791        Args:
792            other (Tensor): value in state dict with key corresponding to ``self``
793            assign (bool): the assign argument passed to :meth:`nn.Module.load_state_dict`
794
795        """
796        if has_torch_function_variadic(self, other):
797            return handle_torch_function(
798                Tensor.module_load, (self, other), self, other, assign=assign
799            )
800
801        if assign:
802            return other.detach()
803        else:
804            return self.copy_(other).detach()
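
    # A small sketch of the default behavior (illustrative; this method is normally
    # invoked by ``load_state_dict`` rather than called directly):
    #
    #   >>> param, loaded = torch.zeros(2), torch.ones(2)
    #   >>> param.module_load(loaded)  # assign=False: copy_ then detach
    #   tensor([1., 1.])
    #   >>> param.module_load(loaded, assign=True) is loaded  # a detached tensor, not ``loaded`` itself
    #   False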
805
806    def __reversed__(self):
807        r"""Reverses the tensor along dimension 0."""
808        if has_torch_function_unary(self):
809            return handle_torch_function(Tensor.__reversed__, (self,), self)
810        if self.dim() == 0:
811            return self
812        else:
813            return self.flip(0)
814
815    def norm(
816        self,
817        p: Optional[Union[float, str]] = "fro",
818        dim=None,
819        keepdim=False,
820        dtype=None,
821    ):
822        r"""See :func:`torch.norm`"""
823        if has_torch_function_unary(self):
824            return handle_torch_function(
825                Tensor.norm, (self,), self, p=p, dim=dim, keepdim=keepdim, dtype=dtype
826            )
827        return torch.norm(self, p, dim, keepdim, dtype=dtype)
828
829    def solve(self, other):
830        from torch._linalg_utils import solve
831
832        return solve(self, other)
833
834    def lstsq(self, other):
835        from torch._linalg_utils import lstsq
836
837        return lstsq(self, other)
838
839    def eig(self, eigenvectors=False):
840        from torch._linalg_utils import eig
841
842        return eig(self, eigenvectors=eigenvectors)
843
844    def symeig(self, eigenvectors=False):
845        from torch._linalg_utils import _symeig
846
847        return _symeig(self, eigenvectors=eigenvectors)
848
849    def lu(self, pivot=True, get_infos=False):
850        r"""See :func:`torch.lu`"""
851        # If get_infos is True, then we don't need to check for errors and vice versa
852        if has_torch_function_unary(self):
853            return handle_torch_function(
854                Tensor.lu, (self,), self, pivot=pivot, get_infos=get_infos
855            )
856
857        LU, pivots, infos = torch._lu_with_info(
858            self, pivot=pivot, check_errors=(not get_infos)
859        )
860        if get_infos:
861            return LU, pivots, infos
862        else:
863            return LU, pivots
864
865    def stft(
866        self,
867        n_fft: int,
868        hop_length: Optional[int] = None,
869        win_length: Optional[int] = None,
870        window: "Optional[Tensor]" = None,
871        center: bool = True,
872        pad_mode: str = "reflect",
873        normalized: bool = False,
874        onesided: Optional[bool] = None,
875        return_complex: Optional[bool] = None,
876    ):
877        r"""See :func:`torch.stft`
878
879        .. warning::
880          This function changed its signature at version 0.4.1. Calling it with
881          the previous signature may cause an error or return an incorrect result.
882        """
883        if has_torch_function_unary(self):
884            return handle_torch_function(
885                Tensor.stft,
886                (self,),
887                self,
888                n_fft,
889                hop_length=hop_length,
890                win_length=win_length,
891                window=window,
892                center=center,
893                pad_mode=pad_mode,
894                normalized=normalized,
895                onesided=onesided,
896                return_complex=return_complex,
897            )
898        return torch.stft(
899            self,
900            n_fft,
901            hop_length,
902            win_length,
903            window,
904            center,
905            pad_mode,
906            normalized,
907            onesided,
908            return_complex=return_complex,
909        )
910
911    def istft(
912        self,
913        n_fft: int,
914        hop_length: Optional[int] = None,
915        win_length: Optional[int] = None,
916        window: "Optional[Tensor]" = None,
917        center: bool = True,
918        normalized: bool = False,
919        onesided: Optional[bool] = None,
920        length: Optional[int] = None,
921        return_complex: bool = False,
922    ):
923        r"""See :func:`torch.istft`"""
924        if has_torch_function_unary(self):
925            return handle_torch_function(
926                Tensor.istft,
927                (self,),
928                self,
929                n_fft,
930                hop_length=hop_length,
931                win_length=win_length,
932                window=window,
933                center=center,
934                normalized=normalized,
935                onesided=onesided,
936                length=length,
937                return_complex=return_complex,
938            )
939        return torch.istft(
940            self,
941            n_fft,
942            hop_length,
943            win_length,
944            window,
945            center,
946            normalized,
947            onesided,
948            length,
949            return_complex=return_complex,
950        )
951
952    def resize(self, *sizes):
953        if has_torch_function_unary(self):
954            return handle_torch_function(Tensor.resize, (self,), self, *sizes)
955        warnings.warn("non-inplace resize is deprecated")
956        from torch.autograd._functions import Resize
957
958        return Resize.apply(self, sizes)
959
960    def resize_as(self, tensor):
961        if has_torch_function_variadic(self, tensor):
962            return handle_torch_function(Tensor.resize_as, (self, tensor), self, tensor)
963        warnings.warn("non-inplace resize_as is deprecated")
964        from torch.autograd._functions import Resize
965
966        return Resize.apply(self, tensor.size())
967
968    def split(self, split_size, dim=0):
969        r"""See :func:`torch.split`"""
970        if has_torch_function_unary(self):
971            return handle_torch_function(
972                Tensor.split, (self,), self, split_size, dim=dim
973            )
974        if isinstance(split_size, Tensor):
975            try:
976                split_size = int(split_size)
977            except ValueError:
978                pass
979
980        if isinstance(split_size, (int, torch.SymInt)):
981            return torch._VF.split(self, split_size, dim)  # type: ignore[attr-defined]
982        else:
983            return torch._VF.split_with_sizes(self, split_size, dim)
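
    # A short sketch (illustrative): an int splits into equal chunks, a list gives
    # explicit chunk sizes:
    #
    #   >>> t = torch.arange(6)
    #   >>> [c.tolist() for c in t.split(2)]
    #   [[0, 1], [2, 3], [4, 5]]
    #   >>> [c.tolist() for c in t.split([1, 5])]
    #   [[0], [1, 2, 3, 4, 5]]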
984
985    def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=None):
986        r"""Returns the unique elements of the input tensor.
987
988        See :func:`torch.unique`
989        """
990        if has_torch_function_unary(self):
991            return handle_torch_function(
992                Tensor.unique,
993                (self,),
994                self,
995                sorted=sorted,
996                return_inverse=return_inverse,
997                return_counts=return_counts,
998                dim=dim,
999            )
1000        return torch.unique(
1001            self,
1002            sorted=sorted,
1003            return_inverse=return_inverse,
1004            return_counts=return_counts,
1005            dim=dim,
1006        )
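
    # A short sketch (illustrative): by default the unique values come back sorted:
    #
    #   >>> torch.tensor([1, 3, 2, 3]).unique()
    #   tensor([1, 2, 3])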
1007
1008    def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None):
1009        r"""Eliminates all but the first element from every consecutive group of equivalent elements.
1010
1011        See :func:`torch.unique_consecutive`
1012        """
1013        if has_torch_function_unary(self):
1014            return handle_torch_function(
1015                Tensor.unique_consecutive,
1016                (self,),
1017                self,
1018                return_inverse=return_inverse,
1019                return_counts=return_counts,
1020                dim=dim,
1021            )
1022        return torch.unique_consecutive(
1023            self, return_inverse=return_inverse, return_counts=return_counts, dim=dim
1024        )
1025
1026    @_handle_torch_function_and_wrap_type_error_to_not_implemented
1027    def __rsub__(self, other):
1028        return _C._VariableFunctions.rsub(self, other)
1029
1030    @_handle_torch_function_and_wrap_type_error_to_not_implemented
1031    def __rdiv__(self, other):
1032        return self.reciprocal() * other
1033
1034    __rtruediv__ = __rdiv__
1035    __itruediv__ = _C.TensorBase.__idiv__
1036
1037    __pow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(
1038        _C.TensorBase.pow
1039    )
1040    __ipow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented(
1041        _C.TensorBase.pow_
1042    )
1043
1044    @_handle_torch_function_and_wrap_type_error_to_not_implemented
1045    def __rmod__(self, other):
1046        return torch.remainder(other, self)
1047
1048    def __format__(self, format_spec):
1049        if has_torch_function_unary(self):
1050            return handle_torch_function(Tensor.__format__, (self,), self, format_spec)
1051        if self.dim() == 0 and not self.is_meta and type(self) is Tensor:
1052            return self.item().__format__(format_spec)
1053        return object.__format__(self, format_spec)
1054
1055    @_handle_torch_function_and_wrap_type_error_to_not_implemented
1056    def __rpow__(self, other):
1057        return torch.pow(other, self)
1058
1059    @_handle_torch_function_and_wrap_type_error_to_not_implemented
1060    def __floordiv__(self, other):
1061        return torch.floor_divide(self, other)
1062
1063    @_handle_torch_function_and_wrap_type_error_to_not_implemented
1064    def __rfloordiv__(self, other):
1065        return torch.floor_divide(other, self)
1066
1067    @_handle_torch_function_and_wrap_type_error_to_not_implemented
1068    def __rlshift__(self, other):
1069        return torch.bitwise_left_shift(other, self)
1070
1071    @_handle_torch_function_and_wrap_type_error_to_not_implemented
1072    def __rrshift__(self, other):
1073        return torch.bitwise_right_shift(other, self)
1074
1075    @_handle_torch_function_and_wrap_type_error_to_not_implemented
1076    def __rmatmul__(self, other):
1077        return torch.matmul(other, self)
1078
1079    __pos__ = _C.TensorBase.positive
1080    __neg__ = _C.TensorBase.neg
1081    __abs__ = _C.TensorBase.abs
1082
1083    def __len__(self):
1084        if has_torch_function_unary(self):
1085            return handle_torch_function(Tensor.__len__, (self,), self)
1086        if self.dim() == 0:
1087            raise TypeError("len() of a 0-d tensor")
1088        if torch._C._get_tracing_state():
1089            warnings.warn(
1090                "Using len to get tensor shape might cause the trace to be incorrect. "
1091                "Recommended usage would be tensor.shape[0]. "
1092                "Passing a tensor of different shape might lead to errors or silently give "
1093                "incorrect results.",
1094                category=torch.jit.TracerWarning,
1095                stacklevel=2,
1096            )
1097        return self.shape[0]
1098
1099    def __iter__(self):
1100        # NB: we return a lazy iterator here rather than eagerly indexing every
1101        # element.  This could save us work, and also helps keep trace ordering
1102        # deterministic (e.g., if you zip(*hiddens), eager indexing would force
1103        # all the indexes of hiddens[0] before hiddens[1], while the lazy
1104        # iterator will interleave them.)
1106        # NB: We have intentionally skipped __torch_function__ dispatch here.
1107        # See gh-54457
1108        if self.dim() == 0:
1109            raise TypeError("iteration over a 0-d tensor")
1110        if torch._C._get_tracing_state():
1111            warnings.warn(
1112                "Iterating over a tensor might cause the trace to be incorrect. "
1113                "Passing a tensor of different shape won't change the number of "
1114                "iterations executed (and might lead to errors or silently give "
1115                "incorrect results).",
1116                category=torch.jit.TracerWarning,
1117                stacklevel=2,
1118            )
1119        return iter(self.unbind(0))
1120
1121    def __hash__(self):
1122        # Do NOT handle __torch_function__ here, as a user's default
1123        # implementation that handles most functions will most likely do it wrong.
1124        # It can be easily overridden by defining this method on the user
1125        # subclass if needed.
1126        return id(self)
1127
1128    def __dir__(self):
1129        if has_torch_function_unary(self):
1130            return handle_torch_function(Tensor.__dir__, (self,), self)
1131        tensor_methods = dir(self.__class__)
1132        tensor_methods.remove("volatile")  # deprecated
1133        attrs = list(self.__dict__.keys())
1134        keys = tensor_methods + attrs
1135
1136        # this property is only available for dense CUDA tensors
1137        if (not self.is_cuda) or self.is_sparse:
1138            keys.remove("__cuda_array_interface__")
1139
1140        return sorted(keys)
1141
1142    # Numpy array interface, to support `numpy.asarray(tensor) -> ndarray`
1143    __array_priority__ = 1000  # prefer Tensor ops over numpy ones
1144
1145    def __array__(self, dtype=None):
1146        if has_torch_function_unary(self):
1147            return handle_torch_function(Tensor.__array__, (self,), self, dtype=dtype)
1148        if dtype is None:
1149            return self.numpy()
1150        else:
1151            return self.numpy().astype(dtype, copy=False)
1152
1153    # Wrap Numpy array again in a suitable tensor when done, to support e.g.
1154    # `numpy.sin(tensor) -> tensor` or `numpy.greater(tensor, 0) -> ByteTensor`
1155    def __array_wrap__(self, array):
1156        if has_torch_function_unary(self):
1157            return handle_torch_function(
1158                Tensor.__array_wrap__, (self,), self, array=array
1159            )
1160        if array.dtype == bool:
1161            # Workaround, torch has no built-in bool tensor
1162            array = array.astype("uint8")
1163        return torch.from_numpy(array)
1164
1165    def __contains__(self, element: Any, /) -> bool:
1166        r"""Check if `element` is present in the tensor.
1167
1168        Args:
1169            element (Tensor or scalar): element to be checked
1170                for presence in the current tensor
1171        """
1172        if has_torch_function_unary(self):
1173            return handle_torch_function(Tensor.__contains__, (self,), self, element)
1174        if isinstance(
1175            element, (torch.Tensor, Number, torch.SymInt, torch.SymFloat, torch.SymBool)
1176        ):
1177            # type hint doesn't understand the __contains__ result array
1178            return bool((element == self).any().item())  # type: ignore[union-attr]
1179
1180        raise RuntimeError(
1181            f"Tensor.__contains__ only supports Tensor or scalar, but you passed in a {type(element)}."
1182        )
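
    # A short sketch (illustrative):
    #
    #   >>> 3 in torch.tensor([1, 2, 3])
    #   True
    #   >>> 5 in torch.tensor([1, 2, 3])
    #   False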
1183
1184    @property
1185    def __cuda_array_interface__(self):
1186        """Array view description for cuda tensors.
1187
1188        See:
1189        https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html
1190        """
1191        if has_torch_function_unary(self):
1192            # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
1193            return handle_torch_function(
1194                Tensor.__cuda_array_interface__.__get__,  # type: ignore[attr-defined]
1195                (self,),
1196                self,
1197            )
1198
1199        # raise AttributeError for unsupported tensors, so that
1200        # hasattr(cpu_tensor, "__cuda_array_interface__") is False.
1201        if not self.is_cuda:
1202            raise AttributeError(
1203                f"Can't get __cuda_array_interface__ on non-CUDA tensor type: {self.type()} "
1204                "If CUDA data is required use tensor.cuda() to copy tensor to device memory."
1205            )
1206
1207        if self.is_sparse:
1208            raise AttributeError(
1209                f"Can't get __cuda_array_interface__ on sparse type: {self.type()} "
1210                "Use Tensor.to_dense() to convert to a dense tensor first."
1211            )
1212
1213        # RuntimeError, matching tensor.__array__() behavior.
1214        if self.requires_grad:
1215            raise RuntimeError(
1216                "Can't get __cuda_array_interface__ on Variable that requires grad. "
1217                "If gradients aren't required, use var.detach() to get Variable that doesn't require grad."
1218            )
1219
1220        # CUDA devices are little-endian and tensors are stored in native byte
1221        # order. 1-byte entries are endian-agnostic.
1222        typestr = {
1223            torch.complex64: "<c8",
1224            torch.complex128: "<c16",
1225            torch.bfloat16: "<f2",
1226            torch.float16: "<f2",
1227            torch.float32: "<f4",
1228            torch.float64: "<f8",
1229            torch.uint8: "|u1",
1230            torch.int8: "|i1",
1231            torch.uint16: "<u2",
1232            torch.int16: "<i2",
1233            torch.uint32: "<u4",
1234            torch.int32: "<i4",
1235            torch.uint64: "<u8",
1236            torch.int64: "<i8",
1237            torch.bool: "|b1",
1238        }[self.dtype]
1239
1240        itemsize = self.element_size()
1241
1242        shape = tuple(self.shape)
1243        if self.is_contiguous():
1244            # __cuda_array_interface__ v2 requires the strides to be omitted
1245            # (either not set or set to None) for C-contiguous arrays.
1246            strides = None
1247        else:
1248            strides = tuple(s * itemsize for s in self.stride())
1249        data_ptr = self.data_ptr() if self.numel() > 0 else 0
1250        data = (data_ptr, False)  # read-only is false
1251
1252        return dict(typestr=typestr, shape=shape, strides=strides, data=data, version=2)
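
    # A small sketch (illustrative; assumes a CUDA build and an available device):
    #
    #   >>> t = torch.arange(4, device="cuda")
    #   >>> sorted(t.__cuda_array_interface__)
    #   ['data', 'shape', 'strides', 'typestr', 'version']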
1253
1254    def storage_type(self):
1255        r"""storage_type() -> type
1256
1257        Returns the type of the underlying storage.
1258
1259        """
1260        if has_torch_function_unary(self):
1261            return handle_torch_function(Tensor.storage_type, (self,), self)
1262
1263        torch.storage._warn_typed_storage_removal()
1264
1265        return self._typed_storage()._get_legacy_storage_class()
1266
1267    def refine_names(self, *names):
1268        r"""Refines the dimension names of :attr:`self` according to :attr:`names`.
1269
1270        Refining is a special case of renaming that "lifts" unnamed dimensions.
1271        A ``None`` dim can be refined to have any name; a named dim can only be
1272        refined to have the same name.
1273
1274        Because named tensors can coexist with unnamed tensors, refining names
1275        gives a nice way to write named-tensor-aware code that works with both
1276        named and unnamed tensors.
1277
1278        :attr:`names` may contain up to one Ellipsis (``...``).
1279        The Ellipsis is expanded greedily; it is expanded in-place to fill
1280        :attr:`names` to the same length as ``self.dim()`` using names from the
1281        corresponding indices of ``self.names``.
1282
1283        Python 2 does not support Ellipsis but one may use a string literal
1284        instead (``'...'``).
1285
1286        Args:
1287            names (iterable of str): The desired names of the output tensor. May
1288                contain up to one Ellipsis.
1289
1290        Examples::
1291
1292            >>> imgs = torch.randn(32, 3, 128, 128)
1293            >>> named_imgs = imgs.refine_names('N', 'C', 'H', 'W')
1294            >>> named_imgs.names
1295            ('N', 'C', 'H', 'W')
1296
1297            >>> tensor = torch.randn(2, 3, 5, 7, 11)
1298            >>> tensor = tensor.refine_names('A', ..., 'B', 'C')
1299            >>> tensor.names
1300            ('A', None, None, 'B', 'C')
1301
1302        .. warning::
1303            The named tensor API is experimental and subject to change.
1304
1305        """
1306        if has_torch_function_unary(self):
1307            return handle_torch_function(Tensor.refine_names, (self,), self, *names)
1308        names = resolve_ellipsis(names, self.names, "refine_names")
1309        return super().refine_names(names)
1310
1311    def align_to(self, *names):
1312        r"""Permutes the dimensions of the :attr:`self` tensor to match the order
1313        specified in :attr:`names`, adding size-one dims for any new names.
1314
1315        All of the dims of :attr:`self` must be named in order to use this method.
1316        The resulting tensor is a view on the original tensor.
1317
1318        All dimension names of :attr:`self` must be present in :attr:`names`.
1319        :attr:`names` may contain additional names that are not in ``self.names``;
1320        the output tensor has a size-one dimension for each of those new names.
1321
1322        :attr:`names` may contain up to one Ellipsis (``...``).
1323        The Ellipsis is expanded to be equal to all dimension names of :attr:`self`
1324        that are not mentioned in :attr:`names`, in the order that they appear
1325        in :attr:`self`.
1326
1327        Python 2 does not support Ellipsis but one may use a string literal
1328        instead (``'...'``).
1329
1330        Args:
1331            names (iterable of str): The desired dimension ordering of the
1332                output tensor. May contain up to one Ellipsis that is expanded
1333                to all unmentioned dim names of :attr:`self`.
1334
1335        Examples::
1336
1337            >>> tensor = torch.randn(2, 2, 2, 2, 2, 2)
1338            >>> named_tensor = tensor.refine_names('A', 'B', 'C', 'D', 'E', 'F')
1339
1340            # Move the F and E dims to the front while keeping the rest in order
1341            >>> named_tensor.align_to('F', 'E', ...)
1342
1343        .. warning::
1344            The named tensor API is experimental and subject to change.
1345
1346        """
1347        if has_torch_function_unary(self):
1348            return handle_torch_function(Tensor.align_to, (self,), self, *names)
1349        ellipsis_idx = single_ellipsis_index(names, "align_to")
1350        if ellipsis_idx is None:
1351            return super().align_to(names)
1352        return super().align_to(
1353            [name for name in names if not is_ellipsis(name)], ellipsis_idx
1354        )
1355
1356    def unflatten(self, dim, sizes):
1357        r"""
1358        unflatten(dim, sizes) -> Tensor
1359
1360        See :func:`torch.unflatten`.
1361
1362        """
1363        if has_torch_function_unary(self):
1364            return handle_torch_function(Tensor.unflatten, (self,), self, dim, sizes)
1365
1366        if not sizes:
1367            raise RuntimeError("unflatten: sizes must be non-empty")
1368
1369        names = None
1370        if isinstance(sizes, OrderedDict) or (
1371            isinstance(sizes, (tuple, list)) and isinstance(sizes[0], (tuple, list))
1372        ):
1373            names, sizes = unzip_namedshape(sizes)
1374            return super().unflatten(dim, sizes, names)
1375        else:
1376            return super().unflatten(dim, sizes)
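
    # Illustrative usage (a sketch, not part of the original source): ``sizes``
    # may be a plain tuple of ints, or, for named tensors, a sequence of
    # ``(name, size)`` pairs.
    #
    #   >>> torch.arange(6).unflatten(0, (2, 3)).shape
    #   torch.Size([2, 3])
    #   >>> t = torch.randn(2, 6, names=('A', 'B'))
    #   >>> t.unflatten('B', (('B1', 2), ('B2', 3))).names
    #   ('A', 'B1', 'B2')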
1377
1378    def rename_(self, *names, **rename_map):
1379        """In-place version of :meth:`~Tensor.rename`."""
1380
1381        if has_torch_function_unary(self):
1382            return handle_torch_function(
1383                Tensor.rename_, (self,), self, *names, **rename_map
1384            )
1385
1386        # Note [rename_ / rename API]
1387        # The Python API for these is different from the C++ API. In Python:
1388        # 1) tensor.rename(*names) takes a vararglist of names
        # 2) tensor.rename(**rename_map) takes a mapping from old dim names to new ones.
        # C++ is statically typed, making it difficult to implement similar behavior.
1391        return update_names(self, names, rename_map, inplace=True)
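
    # Illustrative usage (a sketch, not part of the original source): rename_
    # mutates ``self`` in place and returns it.
    #
    #   >>> t = torch.empty(2, 3)
    #   >>> _ = t.rename_('N', 'C')
    #   >>> t.names
    #   ('N', 'C')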
1392
1393    def rename(self, *names, **rename_map):
1394        """Renames dimension names of :attr:`self`.
1395
1396        There are two main usages:
1397
        ``self.rename(**rename_map)`` returns a view on the tensor with dims
        renamed as specified in the mapping :attr:`rename_map`.

        ``self.rename(*names)`` returns a view on the tensor, renaming all
        dimensions positionally using :attr:`names`.
1403        Use ``self.rename(None)`` to drop names on a tensor.
1404
1405        One cannot specify both positional args :attr:`names` and keyword args
1406        :attr:`rename_map`.
1407
1408        Examples::
1409
1410            >>> imgs = torch.rand(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
1411            >>> renamed_imgs = imgs.rename(N='batch', C='channels')
1412            >>> renamed_imgs.names
1413            ('batch', 'channels', 'H', 'W')
1414
1415            >>> renamed_imgs = imgs.rename(None)
1416            >>> renamed_imgs.names
1417            (None, None, None, None)
1418
1419            >>> renamed_imgs = imgs.rename('batch', 'channel', 'height', 'width')
1420            >>> renamed_imgs.names
1421            ('batch', 'channel', 'height', 'width')
1422
1423        .. warning::
1424            The named tensor API is experimental and subject to change.
1425
1426        """
1427        if has_torch_function_unary(self):
1428            return handle_torch_function(
1429                Tensor.rename, (self,), self, *names, **rename_map
1430            )
1431
1432        # See Note [rename_ / rename API]
1433        return update_names(self, names, rename_map, inplace=False)
1434
1435    def to_sparse_coo(self):
1436        """Convert a tensor to :ref:`coordinate format <sparse-coo-docs>`.
1437
1438        Examples::
1439
            >>> dense = torch.randn(5, 5)
            >>> sparse = dense.to_sparse_coo()
            >>> sparse._nnz()
            25
1444
1445        """
1446        return self.to_sparse()
1447
1448    def dim_order(self):
1449        """
1450
1451        dim_order() -> tuple
1452
1453        Returns a tuple of int describing the dim order or physical layout of :attr:`self`.
1454
1455        Args:
1456            None
1457
1458        Dim order represents how dimensions are laid out in memory,
1459        starting from the outermost to the innermost dimension.
1460
1461        Example::
1462            >>> torch.empty((2, 3, 5, 7)).dim_order()
1463            (0, 1, 2, 3)
1464            >>> torch.empty((2, 3, 5, 7), memory_format=torch.channels_last).dim_order()
1465            (0, 2, 3, 1)
1466
1467        .. warning::
1468            The dim_order tensor API is experimental and subject to change.
1469
1470        """
1471        if has_torch_function_unary(self):
1472            return handle_torch_function(Tensor.dim_order, (self,), self)
1473
1474        import torch._prims_common as utils
1475
1476        return tuple(utils.compute_elementwise_output_logical_to_physical_perm(self))
1477
1478    def _update_names(self, names, inplace):
1479        if has_torch_function_unary(self):
1480            return handle_torch_function(
1481                Tensor._update_names, (self,), self, names, inplace
1482            )
1483
1484        # See Note [rename_ / rename API]
1485        if inplace:
1486            return super().rename_(names)
1487        else:
1488            return super().rename(names)
1489
1490    @classmethod
1491    def __torch_function__(cls, func, types, args=(), kwargs=None):
1492        """
1493        This __torch_function__ implementation wraps subclasses such that
1494        methods called on subclasses return a subclass instance instead of
1495        a ``torch.Tensor`` instance.
1496
        One corollary to this is that you need coverage for ``torch.Tensor``
        methods if implementing ``__torch_function__`` for subclasses.

        We recommend always calling ``super().__torch_function__`` as the base
        case when doing the above.

        While not mandatory, we recommend making ``__torch_function__`` a classmethod.
1504        """
1505        if kwargs is None:
1506            kwargs = {}
1507
1508        if not all(issubclass(cls, t) for t in types):
1509            return NotImplemented
1510
1511        with _C.DisableTorchFunctionSubclass():
1512            ret = func(*args, **kwargs)
1513            if func in get_default_nowrap_functions():
1514                return ret
1515            else:
1516                return _convert(ret, cls)
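
    # Illustrative subclass pattern (a sketch, not part of the original source;
    # ``LoggingTensor`` is a hypothetical name):
    #
    #   class LoggingTensor(torch.Tensor):
    #       @classmethod
    #       def __torch_function__(cls, func, types, args=(), kwargs=None):
    #           print(f"called {getattr(func, '__name__', func)}")
    #           # Defer to the base implementation so that results are
    #           # re-wrapped as LoggingTensor instances.
    #           return super().__torch_function__(func, types, args, kwargs)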
1517
1518    __torch_dispatch__ = _C._disabled_torch_dispatch_impl
1519
1520    def __dlpack__(self, stream=None):
1521        """
        Creates a DLPack `capsule <https://data-apis.org/array-api/latest/design_topics/data_interchange.html#data-interchange>`_
1523        of the current tensor to be exported to other libraries.
1524
        This function will be called from the ``from_dlpack`` method
        of the library that will consume the capsule. ``from_dlpack`` passes the current
        stream to this method as part of the specification.
1528
        Args:
            stream (integer or None): An optional Python integer representing a
                pointer to a CUDA stream. The current stream is synchronized with
                this stream before the capsule is created, and since the capsule
                shares its storage with the tensor, this makes it safe to access
                from both streams. If None or -1 is passed then no synchronization
                is performed. If 1 (on CUDA) or 0 (on ROCm) then the default
                stream is used for synchronization.
1537        """
1538        if has_torch_function_unary(self):
1539            return handle_torch_function(Tensor.__dlpack__, (self,), self, stream)
1540
        # DLPack capsules can't capture all of PyTorch's semantics,
        # so we prohibit exporting tensors that would lose properties such as
        # requires_grad or the conjugate bit.
1544        if self.requires_grad:
1545            raise RuntimeError(
1546                "Can't export tensors that require gradient, use tensor.detach()"
1547            )
1548        if self.is_conj():
1549            raise RuntimeError("Can't export tensors with the conjugate bit set")
1550        if self.layout != torch.strided:
1551            raise RuntimeError(
1552                "Can't export tensors with layout other than torch.strided"
1553            )
1554
1555        if stream is not None and type(stream) is not int:
1556            # Stream pointers in CUDA/ROCm are uniquely numbered and can
1557            # be retrieved from their integer value.
1558            raise TypeError("stream must be ``int`` or ``none``")
1559        elif stream is not None and stream != -1:
1560            if self.device.type == "cuda":
1561                # NB: This logic handles the special case values for default
1562                # streams and must be kept in sync with from_dlpack in
1563                # torch/utils/dlpack.py
1564                if stream == 1 and torch.version.hip is None:
1565                    stream = torch.cuda.default_stream()
1566                elif stream == 0 and torch.version.hip is not None:
1567                    stream = torch.cuda.default_stream()
1568                else:
1569                    stream = torch.cuda.ExternalStream(stream)
1570                # Only synchronize on different streams
1571                sync_stream = torch.cuda.current_stream()
1572                if stream != sync_stream:
1573                    event = torch.cuda.Event()
1574                    event.record(sync_stream)
1575                    stream.wait_event(event)
1576        return torch.to_dlpack(self)
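
    # Illustrative round trip (a sketch, not part of the original source):
    # consumers normally call ``torch.from_dlpack(obj)``, which invokes
    # ``obj.__dlpack__()`` under the hood; a given capsule can only be
    # consumed once.
    #
    #   >>> t = torch.arange(4)
    #   >>> torch.from_dlpack(t)
    #   tensor([0, 1, 2, 3])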
1577
1578    def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]:
1579        if has_torch_function_unary(self):
1580            return handle_torch_function(Tensor.__dlpack_device__, (self,), self)
1581
1582        from torch.utils.dlpack import DLDeviceType
1583
1584        device = self.device
1585        idx = device.index if device.index is not None else 0
1586        torch_device_type = device.type
1587        if torch_device_type == "cuda" and torch.version.hip is not None:
1588            device_type = DLDeviceType.kDLROCM
1589        elif torch_device_type == "cpu" and self.is_pinned():
1590            device_type = DLDeviceType.kDLCPUPinned
1591        elif torch_device_type == "cuda":
1592            device_type = DLDeviceType.kDLGPU
1593        elif torch_device_type == "cpu":
1594            device_type = DLDeviceType.kDLCPU
        elif torch_device_type == "xpu":
            device_type = DLDeviceType.kDLOneAPI
        else:
            raise ValueError(f"Unknown device type {torch_device_type} for DLPack")
1599        return (device_type, idx)
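
    # Illustrative usage (a sketch, not part of the original source; the exact
    # enum repr may vary across versions):
    #
    #   >>> torch.empty(1).__dlpack_device__()
    #   (<DLDeviceType.kDLCPU: 1>, 0)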
1600
1601    __module__ = "torch"
1602
1603
1604def _convert(ret, cls):
1605    if cls is Tensor:
1606        return ret
1607
1608    if isinstance(ret, Tensor) and not isinstance(ret, cls):
1609        ret = ret.as_subclass(cls)
1610
1611    if isinstance(ret, (tuple, list)):
1612        # Also handles things like namedtuples
1613        ret = type(ret)(_convert(r, cls) for r in ret)
1614
1615    return ret
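
# Illustrative behavior of _convert (a sketch, not part of the original source;
# ``MyTensor`` is a hypothetical subclass): plain Tensor results are re-wrapped
# as the subclass, and tuple/list containers are rebuilt element-wise.
#
#   class MyTensor(torch.Tensor):
#       pass
#
#   type(_convert(torch.ones(2), MyTensor))        # -> MyTensor
#   type(_convert((torch.ones(2),), MyTensor)[0])  # -> MyTensor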
1616