1# mypy: allow-untyped-defs 2import copyreg 3import enum 4import functools 5import warnings 6from collections import OrderedDict 7from copy import deepcopy 8from numbers import Number 9from typing import Any, Dict, Optional, Tuple, Union 10 11import torch 12import torch._C as _C 13from torch._namedtensor_internals import ( 14 check_serializing_named_tensor, 15 is_ellipsis, 16 resolve_ellipsis, 17 single_ellipsis_index, 18 unzip_namedshape, 19 update_names, 20) 21from torch.overrides import ( 22 get_default_nowrap_functions, 23 handle_torch_function, 24 has_torch_function, 25 has_torch_function_unary, 26 has_torch_function_variadic, 27) 28 29 30def _handle_torch_function_and_wrap_type_error_to_not_implemented(f): 31 assigned = functools.WRAPPER_ASSIGNMENTS 32 33 @functools.wraps(f, assigned=assigned) 34 def wrapped(*args, **kwargs): 35 try: 36 # See https://github.com/pytorch/pytorch/issues/75462 37 if has_torch_function(args): 38 return handle_torch_function(wrapped, args, *args, **kwargs) 39 return f(*args, **kwargs) 40 except TypeError: 41 return NotImplemented 42 43 return wrapped 44 45 46# Should not be used, this is kept only for BC of loading old serialized Tensor subclasses 47def _rebuild_from_type(func, type, args, dict): 48 if type is Tensor: 49 return func(*args) 50 51 ret = func(*args).as_subclass(type) 52 ret.__dict__ = dict 53 return ret 54 55 56def _rebuild_from_type_v2(func, new_type, args, state): 57 ret = func(*args) 58 if type(ret) is not new_type: 59 ret = ret.as_subclass(new_type) 60 # Tensor does define __setstate__ even though it doesn't define 61 # __getstate__. So only use __setstate__ if it is NOT the one defined 62 # on Tensor 63 if ( 64 getattr(ret.__class__, "__setstate__", Tensor.__setstate__) 65 is not Tensor.__setstate__ 66 ): 67 ret.__setstate__(state) 68 else: 69 ret = torch._utils._set_obj_state(ret, state) 70 return ret 71 72 73# NB: If you subclass Tensor, and want to share the subclassed class 74# across processes, you must also update torch/multiprocessing/reductions.py 75# to define a ForkingPickler serialization mode for the class. 76# 77# NB: If you add a new method to Tensor, you must update 78# torch/_C/__init__.pyi.in to add a type annotation for your method; 79# otherwise, it will not show up in autocomplete. 80class Tensor(torch._C.TensorBase): 81 def __deepcopy__(self, memo): 82 if has_torch_function_unary(self): 83 return handle_torch_function(Tensor.__deepcopy__, (self,), self, memo) 84 if not self.is_leaf: 85 raise RuntimeError( 86 "Only Tensors created explicitly by the user " 87 "(graph leaves) support the deepcopy protocol at the moment. " 88 "If you were attempting to deepcopy a module, this may be because " 89 "of a torch.nn.utils.weight_norm usage, " 90 "see https://github.com/pytorch/pytorch/pull/103001" 91 ) 92 if id(self) in memo: 93 return memo[id(self)] 94 with torch.no_grad(): 95 # TODO: skipping storage copy is wrong for meta, as meta 96 # does accurate alias tracking; however, the code below 97 # doesn't work because of 98 # https://github.com/pytorch/pytorch/issues/47442 99 # Update the test in test_serialization if you remove 'meta' from here 100 if ( 101 self.is_sparse 102 or self.device.type 103 in ["lazy", "xla", "mtia", "mps", "maia", "meta", "ipu"] 104 or ( 105 not torch._C._has_storage(self) 106 and self.device.type == torch._C._get_privateuse1_backend_name() 107 ) 108 or (type(self) is not Tensor and self.data_ptr() == 0) 109 ): 110 new_tensor = self.clone() 111 if type(new_tensor) is not type(self): 112 raise RuntimeError( 113 "The default implementation of __deepcopy__() for wrapper subclasses " 114 "only works for subclass types that implement clone() and for which " 115 "cloning returns another instance of the same subclass. You should either " 116 "properly implement clone() for your subclass or override __deepcopy__() " 117 "if it is intended behavior for clone() to return an instance of a " 118 "different type." 119 ) 120 else: 121 new_storage = self._typed_storage()._deepcopy(memo) 122 if self.is_quantized: 123 # quantizer_params can be different type based on torch attribute 124 quantizer_params: Union[ 125 Tuple[torch.qscheme, float, int], 126 Tuple[torch.qscheme, Tensor, Tensor, int], 127 ] 128 if self.qscheme() == torch.per_tensor_affine: 129 quantizer_params = ( 130 self.qscheme(), 131 self.q_scale(), 132 self.q_zero_point(), 133 ) 134 elif self.qscheme() in ( 135 torch.per_channel_affine, 136 torch.per_channel_affine_float_qparams, 137 ): 138 quantizer_params = ( 139 self.qscheme(), 140 self.q_per_channel_scales(), 141 self.q_per_channel_zero_points(), 142 self.q_per_channel_axis(), 143 ) 144 else: 145 raise RuntimeError( 146 f"Unsupported qscheme {self.qscheme()} in deepcopy" 147 ) 148 # TODO: Once we decide to break serialization FC, no longer 149 # need to wrap with TypedStorage 150 new_tensor = torch._utils._rebuild_qtensor( 151 torch.storage.TypedStorage( 152 wrap_storage=new_storage._untyped_storage, 153 dtype=self.dtype, 154 _internal=True, 155 ), 156 self.storage_offset(), 157 self.size(), 158 self.stride(), 159 quantizer_params, 160 self.requires_grad, 161 self._backward_hooks, 162 ) 163 if type(new_tensor) is not type(self): 164 raise RuntimeError( 165 "The default implementation of __deepcopy__() for quantized tensors " 166 "expects the tensor returned by torch._utils._rebuild_qtensor() to " 167 "match the type of the instance being copied. If you encounter this, " 168 "please open an issue on PyTorch's GitHub." 169 ) 170 else: 171 new_tensor = self.new_empty([]) 172 if type(new_tensor) is not type(self): 173 raise RuntimeError( 174 "The default implementation of __deepcopy__() for non-wrapper subclasses " 175 "only works for subclass types that implement new_empty() and for which " 176 "that function returns another instance of the same subclass. You should " 177 "either properly implement new_empty() for your subclass or override " 178 "__deepcopy__() if it is intended behavior for new_empty() to return " 179 "an instance of a different type." 180 ) 181 new_tensor.set_( 182 new_storage, self.storage_offset(), self.size(), self.stride() 183 ) 184 if self.is_conj(): 185 new_tensor = new_tensor.conj_physical() 186 if self.is_neg(): 187 new_tensor = new_tensor.neg() 188 if self.requires_grad: 189 new_tensor.requires_grad_() 190 if self.grad is not None: 191 new_tensor.grad = self.grad.__deepcopy__(memo) 192 193 if type(self) is not Tensor: 194 if type(new_tensor) is not type(self): 195 raise RuntimeError( 196 "Type of deepcopy result does not match the type of the source tensor. " 197 "If you encounter this, please open an issue on PyTorch's GitHub." 198 ) 199 200 # Plain Tensors don't have slots 201 slots_to_save = copyreg._slotnames(self.__class__) # type: ignore[attr-defined] 202 for slot in slots_to_save: 203 if hasattr(self, slot): 204 setattr(new_tensor, slot, deepcopy(getattr(self, slot), memo)) 205 206 new_tensor.__dict__ = deepcopy(self.__dict__, memo) 207 208 memo[id(self)] = new_tensor 209 return new_tensor 210 211 def __reduce_ex__(self, proto): 212 materialize_fake_tensors = ( 213 torch.serialization._serialization_tls.materialize_fake_tensors 214 ) 215 state = torch._utils._get_obj_state(self) 216 # Ignore all state when using FakeTensor with skip_data(materialize_fake_tensors) because FakeTensor has 217 # some state that cannot be pickled 218 if ( 219 # TODO: remove hasattr, it's a hack to support versions of torch that 220 # don't have _subclasses 221 hasattr(torch, "_subclasses") 222 and type(self) is torch._subclasses.fake_tensor.FakeTensor 223 and materialize_fake_tensors 224 ) or (type(self) is Tensor and not state): 225 # Fast path for regular tensor without Python state. 226 return self._reduce_ex_internal(proto) 227 if has_torch_function_unary(self): 228 return handle_torch_function(Tensor.__reduce_ex__, (self,), self, proto) 229 func, args = self._reduce_ex_internal(proto) 230 return (_rebuild_from_type_v2, (func, type(self), args, state)) 231 232 def storage(self): 233 r""" 234 storage() -> torch.TypedStorage 235 236 Returns the underlying :class:`TypedStorage`. 237 238 .. warning:: 239 240 :class:`TypedStorage` is deprecated. It will be removed in the future, and 241 :class:`UntypedStorage` will be the only storage class. To access the 242 :class:`UntypedStorage` directly, use :attr:`Tensor.untyped_storage()`. 243 """ 244 if has_torch_function_unary(self): 245 return handle_torch_function(Tensor.storage, (self,), self) 246 247 torch.storage._warn_typed_storage_removal(stacklevel=2) 248 return self._typed_storage() 249 250 # For internal use only, to avoid raising deprecation warning 251 def _typed_storage(self): 252 untyped_storage = self.untyped_storage() 253 return torch.TypedStorage( 254 wrap_storage=untyped_storage, dtype=self.dtype, _internal=True 255 ) 256 257 def _reduce_ex_internal(self, proto): 258 check_serializing_named_tensor(self) 259 260 from torch.utils.hooks import warn_if_has_hooks 261 262 # See Note [Don't serialize hooks] 263 warn_if_has_hooks(self) 264 backward_hooks: Dict[Any, Any] = OrderedDict() 265 266 skip_data = torch.serialization._serialization_tls.skip_data 267 materialize_fake_tensors = ( 268 torch.serialization._serialization_tls.materialize_fake_tensors 269 ) 270 271 # Note: Numpy array is chosen to be the rebuild component for XLA, MTIA, MAIA Tensors. 272 # We considered a few options: 273 # 1. CPU tensor can't be used here. 274 # Otherwise in torch.load CPU storage is reconstructed with randomly 275 # initialized data, moved onto backend device, and then storage is updated 276 # to the serialized content. This works perfectly for CPU/CUDA but not these backends; 277 # their tensors are disconnected with storage so they don't get the update. 278 # 2. Python list is not a good fit due to performance reason. 279 # `tolist()` converts every single element in the tensor into python objects 280 # and serialize them one by one. 281 if self.device.type in ["xla", "mtia", "maia"] or ( 282 not torch._C._has_storage(self) 283 and self.device.type == torch._C._get_privateuse1_backend_name() 284 ): 285 # Convert BFloat16 tesors to Float32 before conversion to numpy, as numpy doesn't 286 # support BFloat16. The rebuild tensor from numpy takes in the original self.dtype, 287 # this would reconstruct the BFloat16 tensor from numpy. 288 if skip_data: 289 raise RuntimeError( 290 "Cannot serialize tensors on backends with no storage under skip_data context manager" 291 ) 292 numpy_tensor = ( 293 self.cpu().numpy() 294 if self.dtype != torch.bfloat16 295 else self.cpu().to(torch.float32).numpy() 296 ) 297 return ( 298 torch._utils._rebuild_device_tensor_from_numpy, 299 (numpy_tensor, self.dtype, str(self.device), self.requires_grad), 300 ) 301 if self.device.type == "meta": 302 # NB: This implementation BREAKS storage sharing. Current 303 # hypothesis is that no one cares for meta tensors. 304 if skip_data: 305 warnings.warn( 306 "Serializing tensors on the meta device under skip_data context manager is a no-op" 307 ) 308 arg_meta = ( 309 self.dtype, 310 tuple(self.size()), 311 self.stride(), 312 self.requires_grad, 313 ) 314 return (torch._utils._rebuild_meta_tensor_no_storage, arg_meta) 315 if self.is_quantized: 316 if skip_data: 317 raise RuntimeError( 318 "Cannot serialize qtensor under skip_data context manager, file an issue if you need this feature" 319 ) 320 # quantizer_params can be different type based on torch attribute 321 quantizer_params: Union[ 322 Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int] 323 ] 324 if self.qscheme() == torch.per_tensor_affine: 325 quantizer_params = ( 326 torch.per_tensor_affine, 327 self.q_scale(), 328 self.q_zero_point(), 329 ) 330 elif self.qscheme() in ( 331 torch.per_channel_affine, 332 torch.per_channel_affine_float_qparams, 333 ): 334 # convert scales and zero points to tuple to avoid recursive calls 335 # when/if we get multi-axis quantized tensors in the future, the shape 336 # is recoverable from the main tensor shape 337 quantizer_params = ( 338 torch.per_channel_affine, 339 self.q_per_channel_scales(), 340 self.q_per_channel_zero_points(), 341 self.q_per_channel_axis(), 342 ) 343 else: 344 raise RuntimeError( 345 f"Serialization is not supported for tensors of type {self.qscheme()}" 346 ) 347 # TODO: Once we decide to break serialization FC, no longer 348 # need to wrap with TypedStorage 349 args_qtensor = ( 350 torch.storage.TypedStorage( 351 wrap_storage=self._typed_storage()._untyped_storage, 352 dtype=self.dtype, 353 _internal=True, 354 ), 355 self.storage_offset(), 356 tuple(self.size()), 357 self.stride(), 358 quantizer_params, 359 self.requires_grad, 360 backward_hooks, 361 ) 362 return (torch._utils._rebuild_qtensor, args_qtensor) 363 elif self.is_sparse: 364 if self.layout == torch.sparse_coo: 365 args_sparse = ( 366 self.layout, 367 (self._indices(), self._values(), self.size(), self.is_coalesced()), 368 ) 369 else: 370 raise NotImplementedError( 371 f"sparse tensor __reduce_ex__ for layout `{self.layout}`" 372 ) 373 return (torch._utils._rebuild_sparse_tensor, args_sparse) 374 elif self.layout in { 375 torch.sparse_csr, 376 torch.sparse_csc, 377 torch.sparse_bsr, 378 torch.sparse_bsc, 379 }: 380 if self.layout in {torch.sparse_csr, torch.sparse_bsr}: 381 compressed_indices, plain_indices = ( 382 self.crow_indices(), 383 self.col_indices(), 384 ) 385 else: 386 compressed_indices, plain_indices = ( 387 self.ccol_indices(), 388 self.row_indices(), 389 ) 390 args_sparse_compressed = ( 391 self.layout, 392 ( 393 compressed_indices, 394 plain_indices, 395 self.values(), 396 self.size(), 397 ), 398 ) 399 return (torch._utils._rebuild_sparse_tensor, args_sparse_compressed) 400 elif self.is_nested: 401 if skip_data: 402 raise RuntimeError( 403 "Cannot serialize nested tensor under skip_data context manager, file an issue if you need this feature" 404 ) 405 args_nested = ( 406 # NB: values() currently returns the storage as a buffer in an unsafe way. 407 # Ideally, we'd use a private API for this instead. TODO: Switch to this if 408 # we ever get around to adding it. 409 self.values(), 410 self._nested_tensor_size(), 411 self._nested_tensor_strides(), 412 self._nested_tensor_storage_offsets(), 413 ) 414 return (torch._utils._rebuild_nested_tensor, args_nested) 415 elif ( 416 type(self) is not torch.Tensor 417 and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__ 418 and ( 419 isinstance(self, torch._subclasses.functional_tensor.FunctionalTensor) 420 or ( 421 not isinstance(self, torch._subclasses.fake_tensor.FakeTensor) 422 and self.data_ptr() == 0 423 ) 424 ) 425 ): 426 arg_wrapper_subclass = ( 427 type(self), 428 self.dtype, 429 tuple(self.size()), 430 self.stride(), 431 self.storage_offset(), 432 self.layout, 433 self.device, 434 self.requires_grad, 435 ) 436 return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass) 437 elif ( 438 type(self) is not torch.Tensor 439 and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__ 440 and ( 441 isinstance(self, torch._subclasses.fake_tensor.FakeTensor) 442 and not (skip_data and materialize_fake_tensors) 443 ) 444 ): 445 arg_wrapper_subclass = ( 446 type(self), 447 self.dtype, 448 tuple(self.size()), 449 self.stride(), 450 self.storage_offset(), 451 self.layout, 452 self.device, 453 self.requires_grad, 454 ) 455 return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass) 456 else: 457 v3_dtypes = torch.storage._new_dtypes() 458 if self.dtype in v3_dtypes: 459 rebuild_func = torch._utils._rebuild_tensor_v3 460 storage = self.untyped_storage() 461 else: 462 # TODO: Once we decide to break serialization FC, no longer 463 # need to wrap with TypedStorage 464 rebuild_func = torch._utils._rebuild_tensor_v2 # type: ignore[assignment] 465 storage = torch.storage.TypedStorage( 466 wrap_storage=self._typed_storage()._untyped_storage, 467 dtype=self.dtype, 468 _internal=True, 469 ) # type: ignore[assignment] 470 471 # TODO: remove hasattr, it's a hack to support versions of torch that 472 # don't have _subclasses 473 if ( 474 hasattr(torch, "_subclasses") 475 and isinstance(self, torch._subclasses.fake_tensor.FakeTensor) 476 and skip_data 477 ): 478 storage._fake_device = self.device 479 480 args = ( 481 storage, 482 self.storage_offset(), 483 tuple(self.size()), 484 self.stride(), 485 self.requires_grad, 486 backward_hooks, 487 ) # previously was self._backward_hooks 488 489 if isinstance(storage, torch.storage.UntypedStorage): 490 args = args + (self.dtype,) # type: ignore[assignment] 491 492 metadata = torch._utils.get_tensor_metadata(self) 493 if metadata: 494 args = args + (metadata,) # type: ignore[assignment] 495 496 return (rebuild_func, args) 497 498 def __setstate__(self, state): 499 if has_torch_function_unary(self): 500 return handle_torch_function(Tensor.__setstate__, (self,), self, state) 501 # Warning: this method is NOT called when you torch.load() a tensor; 502 # that is managed by _rebuild_tensor_v2 503 if not self.is_leaf: 504 raise RuntimeError("__setstate__ can be only called on leaf Tensors") 505 if len(state) == 4: 506 # legacy serialization of Tensor 507 self.set_(*state) 508 return 509 elif len(state) == 5: 510 # legacy serialization of Variable 511 self.data = state[0] 512 state = (state[3], state[4], state[2]) 513 # The setting of _backward_hooks is expected to be a no-op. 514 # See Note [Don't serialize hooks] 515 self.requires_grad, _, self._backward_hooks = state 516 517 def __repr__(self, *, tensor_contents=None): 518 if has_torch_function_unary(self): 519 return handle_torch_function( 520 Tensor.__repr__, (self,), self, tensor_contents=tensor_contents 521 ) 522 # All strings are unicode in Python 3. 523 return torch._tensor_str._str(self, tensor_contents=tensor_contents) 524 525 def backward( 526 self, gradient=None, retain_graph=None, create_graph=False, inputs=None 527 ): 528 r"""Computes the gradient of current tensor wrt graph leaves. 529 530 The graph is differentiated using the chain rule. If the tensor is 531 non-scalar (i.e. its data has more than one element) and requires 532 gradient, the function additionally requires specifying a ``gradient``. 533 It should be a tensor of matching type and shape, that represents 534 the gradient of the differentiated function w.r.t. ``self``. 535 536 This function accumulates gradients in the leaves - you might need to zero 537 ``.grad`` attributes or set them to ``None`` before calling it. 538 See :ref:`Default gradient layouts<default-grad-layouts>` 539 for details on the memory layout of accumulated gradients. 540 541 .. note:: 542 543 If you run any forward ops, create ``gradient``, and/or call ``backward`` 544 in a user-specified CUDA stream context, see 545 :ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`. 546 547 .. note:: 548 549 When ``inputs`` are provided and a given input is not a leaf, 550 the current implementation will call its grad_fn (though it is not strictly needed to get this gradients). 551 It is an implementation detail on which the user should not rely. 552 See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details. 553 554 Args: 555 gradient (Tensor, optional): The gradient of the function 556 being differentiated w.r.t. ``self``. 557 This argument can be omitted if ``self`` is a scalar. 558 retain_graph (bool, optional): If ``False``, the graph used to compute 559 the grads will be freed. Note that in nearly all cases setting 560 this option to True is not needed and often can be worked around 561 in a much more efficient way. Defaults to the value of 562 ``create_graph``. 563 create_graph (bool, optional): If ``True``, graph of the derivative will 564 be constructed, allowing to compute higher order derivative 565 products. Defaults to ``False``. 566 inputs (sequence of Tensor, optional): Inputs w.r.t. which the gradient will be 567 accumulated into ``.grad``. All other tensors will be ignored. If not 568 provided, the gradient is accumulated into all the leaf Tensors that were 569 used to compute the :attr:`tensors`. 570 """ 571 if has_torch_function_unary(self): 572 return handle_torch_function( 573 Tensor.backward, 574 (self,), 575 self, 576 gradient=gradient, 577 retain_graph=retain_graph, 578 create_graph=create_graph, 579 inputs=inputs, 580 ) 581 torch.autograd.backward( 582 self, gradient, retain_graph, create_graph, inputs=inputs 583 ) 584 585 def register_hook(self, hook): 586 r"""Registers a backward hook. 587 588 The hook will be called every time a gradient with respect to the 589 Tensor is computed. The hook should have the following signature:: 590 591 hook(grad) -> Tensor or None 592 593 594 The hook should not modify its argument, but it can optionally return 595 a new gradient which will be used in place of :attr:`grad`. 596 597 This function returns a handle with a method ``handle.remove()`` 598 that removes the hook from the module. 599 600 .. note:: 601 See :ref:`backward-hooks-execution` for more information on how when this hook 602 is executed, and how its execution is ordered relative to other hooks. 603 604 Example:: 605 606 >>> v = torch.tensor([0., 0., 0.], requires_grad=True) 607 >>> h = v.register_hook(lambda grad: grad * 2) # double the gradient 608 >>> v.backward(torch.tensor([1., 2., 3.])) 609 >>> v.grad 610 611 2 612 4 613 6 614 [torch.FloatTensor of size (3,)] 615 616 >>> h.remove() # removes the hook 617 """ 618 if has_torch_function_unary(self): 619 return handle_torch_function(Tensor.register_hook, (self,), self, hook) 620 if not self.requires_grad: 621 raise RuntimeError( 622 "cannot register a hook on a tensor that doesn't require gradient" 623 ) 624 if self._backward_hooks is None: 625 self._backward_hooks = OrderedDict() 626 if self.grad_fn is not None: 627 self.grad_fn._register_hook_dict(self) 628 629 from torch.utils.hooks import RemovableHandle 630 631 handle = RemovableHandle(self._backward_hooks) 632 self._backward_hooks[handle.id] = hook 633 return handle 634 635 def register_post_accumulate_grad_hook(self, hook): 636 r"""Registers a backward hook that runs after grad accumulation. 637 638 The hook will be called after all gradients for a tensor have been accumulated, 639 meaning that the .grad field has been updated on that tensor. The post 640 accumulate grad hook is ONLY applicable for leaf tensors (tensors without a 641 .grad_fn field). Registering this hook on a non-leaf tensor will error! 642 643 The hook should have the following signature:: 644 645 hook(param: Tensor) -> None 646 647 Note that, unlike other autograd hooks, this hook operates on the tensor 648 that requires grad and not the grad itself. The hook can in-place modify 649 and access its Tensor argument, including its .grad field. 650 651 This function returns a handle with a method ``handle.remove()`` 652 that removes the hook from the module. 653 654 .. note:: 655 See :ref:`backward-hooks-execution` for more information on how when this hook 656 is executed, and how its execution is ordered relative to other hooks. Since 657 this hook runs during the backward pass, it will run in no_grad mode (unless 658 create_graph is True). You can use torch.enable_grad() to re-enable autograd 659 within the hook if you need it. 660 661 Example:: 662 663 >>> v = torch.tensor([0., 0., 0.], requires_grad=True) 664 >>> lr = 0.01 665 >>> # simulate a simple SGD update 666 >>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr)) 667 >>> v.backward(torch.tensor([1., 2., 3.])) 668 >>> v 669 tensor([-0.0100, -0.0200, -0.0300], requires_grad=True) 670 671 >>> h.remove() # removes the hook 672 """ 673 if has_torch_function_unary(self): 674 return handle_torch_function( 675 Tensor.register_post_accumulate_grad_hook, (self,), self, hook 676 ) 677 if not self.requires_grad: 678 raise RuntimeError( 679 "cannot register a hook on a tensor that doesn't require gradient" 680 ) 681 if self.grad_fn is not None: 682 raise RuntimeError( 683 "post accumulate grad hooks cannot be registered on non-leaf tensors" 684 ) 685 if self._post_accumulate_grad_hooks is None: 686 self._post_accumulate_grad_hooks: Dict[Any, Any] = OrderedDict() 687 688 from torch.utils.hooks import RemovableHandle 689 690 handle = RemovableHandle(self._post_accumulate_grad_hooks) 691 self._post_accumulate_grad_hooks[handle.id] = hook 692 return handle 693 694 def reinforce(self, reward): 695 def trim(str): 696 return "\n".join([line.strip() for line in str.split("\n")]) 697 698 raise RuntimeError( 699 trim( 700 r"""reinforce() was removed. 701 Use torch.distributions instead. 702 See https://pytorch.org/docs/main/distributions.html 703 704 Instead of: 705 706 probs = policy_network(state) 707 action = probs.multinomial() 708 next_state, reward = env.step(action) 709 action.reinforce(reward) 710 action.backward() 711 712 Use: 713 714 probs = policy_network(state) 715 # NOTE: categorical is equivalent to what used to be called multinomial 716 m = torch.distributions.Categorical(probs) 717 action = m.sample() 718 next_state, reward = env.step(action) 719 loss = -m.log_prob(action) * reward 720 loss.backward() 721 """ 722 ) 723 ) 724 725 detach = _C._add_docstr( 726 _C.TensorBase.detach, 727 r""" 728 Returns a new Tensor, detached from the current graph. 729 730 The result will never require gradient. 731 732 This method also affects forward mode AD gradients and the result will never 733 have forward mode AD gradients. 734 735 .. note:: 736 737 Returned Tensor shares the same storage with the original one. 738 In-place modifications on either of them will be seen, and may trigger 739 errors in correctness checks. 740 """, 741 ) 742 743 detach_ = _C._add_docstr( 744 _C.TensorBase.detach_, 745 r""" 746 Detaches the Tensor from the graph that created it, making it a leaf. 747 Views cannot be detached in-place. 748 749 This method also affects forward mode AD gradients and the result will never 750 have forward mode AD gradients. 751 """, 752 ) 753 754 def is_shared(self): 755 r"""Checks if tensor is in shared memory. 756 757 This is always ``True`` for CUDA tensors. 758 """ 759 if has_torch_function_unary(self): 760 return handle_torch_function(Tensor.is_shared, (self,), self) 761 return self._typed_storage()._is_shared() 762 763 def share_memory_(self): 764 r"""Moves the underlying storage to shared memory. 765 766 This is a no-op if the underlying storage is already in shared memory 767 and for CUDA tensors. Tensors in shared memory cannot be resized. 768 769 See :meth:`torch.UntypedStorage.share_memory_` for more details. 770 """ 771 if has_torch_function_unary(self): 772 return handle_torch_function(Tensor.share_memory_, (self,), self) 773 self._typed_storage()._share_memory_() 774 return self 775 776 def module_load(self, other, assign=False): 777 r"""Defines how to transform ``other`` when loading it into ``self`` in :meth:`~nn.Module.load_state_dict`. 778 779 Used when :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. 780 781 It is expected that ``self`` is a parameter or buffer in an ``nn.Module`` and ``other`` is the 782 value in the state dictionary with the corresponding key, this method defines 783 how ``other`` is remapped before being swapped with ``self`` via 784 :func:`~torch.utils.swap_tensors` in :meth:`~nn.Module.load_state_dict`. 785 786 .. note:: 787 This method should always return a new object that is not ``self`` or ``other``. 788 For example, the default implementation returns ``self.copy_(other).detach()`` 789 if ``assign`` is ``False`` or ``other.detach()`` if ``assign`` is ``True``. 790 791 Args: 792 other (Tensor): value in state dict with key corresponding to ``self`` 793 assign (bool): the assign argument passed to :meth:`nn.Module.load_state_dict` 794 795 """ 796 if has_torch_function_variadic(self, other): 797 return handle_torch_function( 798 Tensor.module_load, (self, other), self, other, assign=assign 799 ) 800 801 if assign: 802 return other.detach() 803 else: 804 return self.copy_(other).detach() 805 806 def __reversed__(self): 807 r"""Reverses the tensor along dimension 0.""" 808 if has_torch_function_unary(self): 809 return handle_torch_function(Tensor.__reversed__, (self,), self) 810 if self.dim() == 0: 811 return self 812 else: 813 return self.flip(0) 814 815 def norm( 816 self, 817 p: Optional[Union[float, str]] = "fro", 818 dim=None, 819 keepdim=False, 820 dtype=None, 821 ): 822 r"""See :func:`torch.norm`""" 823 if has_torch_function_unary(self): 824 return handle_torch_function( 825 Tensor.norm, (self,), self, p=p, dim=dim, keepdim=keepdim, dtype=dtype 826 ) 827 return torch.norm(self, p, dim, keepdim, dtype=dtype) 828 829 def solve(self, other): 830 from torch._linalg_utils import solve 831 832 return solve(self, other) 833 834 def lstsq(self, other): 835 from torch._linalg_utils import lstsq 836 837 return lstsq(self, other) 838 839 def eig(self, eigenvectors=False): 840 from torch._linalg_utils import eig 841 842 return eig(self, eigenvectors=eigenvectors) 843 844 def symeig(self, eigenvectors=False): 845 from torch._linalg_utils import _symeig 846 847 return _symeig(self, eigenvectors=eigenvectors) 848 849 def lu(self, pivot=True, get_infos=False): 850 r"""See :func:`torch.lu`""" 851 # If get_infos is True, then we don't need to check for errors and vice versa 852 if has_torch_function_unary(self): 853 return handle_torch_function( 854 Tensor.lu, (self,), self, pivot=pivot, get_infos=get_infos 855 ) 856 857 LU, pivots, infos = torch._lu_with_info( 858 self, pivot=pivot, check_errors=(not get_infos) 859 ) 860 if get_infos: 861 return LU, pivots, infos 862 else: 863 return LU, pivots 864 865 def stft( 866 self, 867 n_fft: int, 868 hop_length: Optional[int] = None, 869 win_length: Optional[int] = None, 870 window: "Optional[Tensor]" = None, 871 center: bool = True, 872 pad_mode: str = "reflect", 873 normalized: bool = False, 874 onesided: Optional[bool] = None, 875 return_complex: Optional[bool] = None, 876 ): 877 r"""See :func:`torch.stft` 878 879 .. warning:: 880 This function changed signature at version 0.4.1. Calling with 881 the previous signature may cause error or return incorrect result. 882 """ 883 if has_torch_function_unary(self): 884 return handle_torch_function( 885 Tensor.stft, 886 (self,), 887 self, 888 n_fft, 889 hop_length=hop_length, 890 win_length=win_length, 891 window=window, 892 center=center, 893 pad_mode=pad_mode, 894 normalized=normalized, 895 onesided=onesided, 896 return_complex=return_complex, 897 ) 898 return torch.stft( 899 self, 900 n_fft, 901 hop_length, 902 win_length, 903 window, 904 center, 905 pad_mode, 906 normalized, 907 onesided, 908 return_complex=return_complex, 909 ) 910 911 def istft( 912 self, 913 n_fft: int, 914 hop_length: Optional[int] = None, 915 win_length: Optional[int] = None, 916 window: "Optional[Tensor]" = None, 917 center: bool = True, 918 normalized: bool = False, 919 onesided: Optional[bool] = None, 920 length: Optional[int] = None, 921 return_complex: bool = False, 922 ): 923 r"""See :func:`torch.istft`""" 924 if has_torch_function_unary(self): 925 return handle_torch_function( 926 Tensor.istft, 927 (self,), 928 self, 929 n_fft, 930 hop_length=hop_length, 931 win_length=win_length, 932 window=window, 933 center=center, 934 normalized=normalized, 935 onesided=onesided, 936 length=length, 937 return_complex=return_complex, 938 ) 939 return torch.istft( 940 self, 941 n_fft, 942 hop_length, 943 win_length, 944 window, 945 center, 946 normalized, 947 onesided, 948 length, 949 return_complex=return_complex, 950 ) 951 952 def resize(self, *sizes): 953 if has_torch_function_unary(self): 954 return handle_torch_function(Tensor.resize, (self,), self, *sizes) 955 warnings.warn("non-inplace resize is deprecated") 956 from torch.autograd._functions import Resize 957 958 return Resize.apply(self, sizes) 959 960 def resize_as(self, tensor): 961 if has_torch_function_variadic(self, tensor): 962 return handle_torch_function(Tensor.resize_as, (self, tensor), self, tensor) 963 warnings.warn("non-inplace resize_as is deprecated") 964 from torch.autograd._functions import Resize 965 966 return Resize.apply(self, tensor.size()) 967 968 def split(self, split_size, dim=0): 969 r"""See :func:`torch.split`""" 970 if has_torch_function_unary(self): 971 return handle_torch_function( 972 Tensor.split, (self,), self, split_size, dim=dim 973 ) 974 if isinstance(split_size, Tensor): 975 try: 976 split_size = int(split_size) 977 except ValueError: 978 pass 979 980 if isinstance(split_size, (int, torch.SymInt)): 981 return torch._VF.split(self, split_size, dim) # type: ignore[attr-defined] 982 else: 983 return torch._VF.split_with_sizes(self, split_size, dim) 984 985 def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=None): 986 r"""Returns the unique elements of the input tensor. 987 988 See :func:`torch.unique` 989 """ 990 if has_torch_function_unary(self): 991 return handle_torch_function( 992 Tensor.unique, 993 (self,), 994 self, 995 sorted=sorted, 996 return_inverse=return_inverse, 997 return_counts=return_counts, 998 dim=dim, 999 ) 1000 return torch.unique( 1001 self, 1002 sorted=sorted, 1003 return_inverse=return_inverse, 1004 return_counts=return_counts, 1005 dim=dim, 1006 ) 1007 1008 def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None): 1009 r"""Eliminates all but the first element from every consecutive group of equivalent elements. 1010 1011 See :func:`torch.unique_consecutive` 1012 """ 1013 if has_torch_function_unary(self): 1014 return handle_torch_function( 1015 Tensor.unique_consecutive, 1016 (self,), 1017 self, 1018 return_inverse=return_inverse, 1019 return_counts=return_counts, 1020 dim=dim, 1021 ) 1022 return torch.unique_consecutive( 1023 self, return_inverse=return_inverse, return_counts=return_counts, dim=dim 1024 ) 1025 1026 @_handle_torch_function_and_wrap_type_error_to_not_implemented 1027 def __rsub__(self, other): 1028 return _C._VariableFunctions.rsub(self, other) 1029 1030 @_handle_torch_function_and_wrap_type_error_to_not_implemented 1031 def __rdiv__(self, other): 1032 return self.reciprocal() * other 1033 1034 __rtruediv__ = __rdiv__ 1035 __itruediv__ = _C.TensorBase.__idiv__ 1036 1037 __pow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented( 1038 _C.TensorBase.pow 1039 ) 1040 __ipow__ = _handle_torch_function_and_wrap_type_error_to_not_implemented( 1041 _C.TensorBase.pow_ 1042 ) 1043 1044 @_handle_torch_function_and_wrap_type_error_to_not_implemented 1045 def __rmod__(self, other): 1046 return torch.remainder(other, self) 1047 1048 def __format__(self, format_spec): 1049 if has_torch_function_unary(self): 1050 return handle_torch_function(Tensor.__format__, (self,), self, format_spec) 1051 if self.dim() == 0 and not self.is_meta and type(self) is Tensor: 1052 return self.item().__format__(format_spec) 1053 return object.__format__(self, format_spec) 1054 1055 @_handle_torch_function_and_wrap_type_error_to_not_implemented 1056 def __rpow__(self, other): 1057 return torch.pow(other, self) 1058 1059 @_handle_torch_function_and_wrap_type_error_to_not_implemented 1060 def __floordiv__(self, other): 1061 return torch.floor_divide(self, other) 1062 1063 @_handle_torch_function_and_wrap_type_error_to_not_implemented 1064 def __rfloordiv__(self, other): 1065 return torch.floor_divide(other, self) 1066 1067 @_handle_torch_function_and_wrap_type_error_to_not_implemented 1068 def __rlshift__(self, other): 1069 return torch.bitwise_left_shift(other, self) 1070 1071 @_handle_torch_function_and_wrap_type_error_to_not_implemented 1072 def __rrshift__(self, other): 1073 return torch.bitwise_right_shift(other, self) 1074 1075 @_handle_torch_function_and_wrap_type_error_to_not_implemented 1076 def __rmatmul__(self, other): 1077 return torch.matmul(other, self) 1078 1079 __pos__ = _C.TensorBase.positive 1080 __neg__ = _C.TensorBase.neg 1081 __abs__ = _C.TensorBase.abs 1082 1083 def __len__(self): 1084 if has_torch_function_unary(self): 1085 return handle_torch_function(Tensor.__len__, (self,), self) 1086 if self.dim() == 0: 1087 raise TypeError("len() of a 0-d tensor") 1088 if torch._C._get_tracing_state(): 1089 warnings.warn( 1090 "Using len to get tensor shape might cause the trace to be incorrect. " 1091 "Recommended usage would be tensor.shape[0]. " 1092 "Passing a tensor of different shape might lead to errors or silently give " 1093 "incorrect results.", 1094 category=torch.jit.TracerWarning, 1095 stacklevel=2, 1096 ) 1097 return self.shape[0] 1098 1099 def __iter__(self): 1100 # NB: we use 'imap' and not 'map' here, so that in Python 2 we get a 1101 # generator and don't eagerly perform all the indexes. This could 1102 # save us work, and also helps keep trace ordering deterministic 1103 # (e.g., if you zip(*hiddens), the eager map will force all the 1104 # indexes of hiddens[0] before hiddens[1], while the generator 1105 # map will interleave them.) 1106 # NB: We have intentionally skipped __torch_function__ dispatch here. 1107 # See gh-54457 1108 if self.dim() == 0: 1109 raise TypeError("iteration over a 0-d tensor") 1110 if torch._C._get_tracing_state(): 1111 warnings.warn( 1112 "Iterating over a tensor might cause the trace to be incorrect. " 1113 "Passing a tensor of different shape won't change the number of " 1114 "iterations executed (and might lead to errors or silently give " 1115 "incorrect results).", 1116 category=torch.jit.TracerWarning, 1117 stacklevel=2, 1118 ) 1119 return iter(self.unbind(0)) 1120 1121 def __hash__(self): 1122 # Do NOT handle __torch_function__ here as user's default 1123 # implementation that handle most functions will most likely do it wrong. 1124 # It can be easily overridden by defining this method on the user 1125 # subclass if needed. 1126 return id(self) 1127 1128 def __dir__(self): 1129 if has_torch_function_unary(self): 1130 return handle_torch_function(Tensor.__dir__, (self,), self) 1131 tensor_methods = dir(self.__class__) 1132 tensor_methods.remove("volatile") # deprecated 1133 attrs = list(self.__dict__.keys()) 1134 keys = tensor_methods + attrs 1135 1136 # property only available dense, cuda tensors 1137 if (not self.is_cuda) or self.is_sparse: 1138 keys.remove("__cuda_array_interface__") 1139 1140 return sorted(keys) 1141 1142 # Numpy array interface, to support `numpy.asarray(tensor) -> ndarray` 1143 __array_priority__ = 1000 # prefer Tensor ops over numpy ones 1144 1145 def __array__(self, dtype=None): 1146 if has_torch_function_unary(self): 1147 return handle_torch_function(Tensor.__array__, (self,), self, dtype=dtype) 1148 if dtype is None: 1149 return self.numpy() 1150 else: 1151 return self.numpy().astype(dtype, copy=False) 1152 1153 # Wrap Numpy array again in a suitable tensor when done, to support e.g. 1154 # `numpy.sin(tensor) -> tensor` or `numpy.greater(tensor, 0) -> ByteTensor` 1155 def __array_wrap__(self, array): 1156 if has_torch_function_unary(self): 1157 return handle_torch_function( 1158 Tensor.__array_wrap__, (self,), self, array=array 1159 ) 1160 if array.dtype == bool: 1161 # Workaround, torch has no built-in bool tensor 1162 array = array.astype("uint8") 1163 return torch.from_numpy(array) 1164 1165 def __contains__(self, element: Any, /) -> bool: 1166 r"""Check if `element` is present in tensor 1167 1168 Args: 1169 element (Tensor or scalar): element to be checked 1170 for presence in current tensor" 1171 """ 1172 if has_torch_function_unary(self): 1173 return handle_torch_function(Tensor.__contains__, (self,), self, element) 1174 if isinstance( 1175 element, (torch.Tensor, Number, torch.SymInt, torch.SymFloat, torch.SymBool) 1176 ): 1177 # type hint doesn't understand the __contains__ result array 1178 return bool((element == self).any().item()) # type: ignore[union-attr] 1179 1180 raise RuntimeError( 1181 f"Tensor.__contains__ only supports Tensor or scalar, but you passed in a {type(element)}." 1182 ) 1183 1184 @property 1185 def __cuda_array_interface__(self): 1186 """Array view description for cuda tensors. 1187 1188 See: 1189 https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html 1190 """ 1191 if has_torch_function_unary(self): 1192 # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185 1193 return handle_torch_function( 1194 Tensor.__cuda_array_interface__.__get__, # type: ignore[attr-defined] 1195 (self,), 1196 self, 1197 ) 1198 1199 # raise AttributeError for unsupported tensors, so that 1200 # hasattr(cpu_tensor, "__cuda_array_interface__") is False. 1201 if not self.is_cuda: 1202 raise AttributeError( 1203 f"Can't get __cuda_array_interface__ on non-CUDA tensor type: {self.type()} " 1204 "If CUDA data is required use tensor.cuda() to copy tensor to device memory." 1205 ) 1206 1207 if self.is_sparse: 1208 raise AttributeError( 1209 f"Can't get __cuda_array_interface__ on sparse type: {self.type()} " 1210 "Use Tensor.to_dense() to convert to a dense tensor first." 1211 ) 1212 1213 # RuntimeError, matching tensor.__array__() behavior. 1214 if self.requires_grad: 1215 raise RuntimeError( 1216 "Can't get __cuda_array_interface__ on Variable that requires grad. " 1217 "If gradients aren't required, use var.detach() to get Variable that doesn't require grad." 1218 ) 1219 1220 # CUDA devices are little-endian and tensors are stored in native byte 1221 # order. 1-byte entries are endian-agnostic. 1222 typestr = { 1223 torch.complex64: "<c8", 1224 torch.complex128: "<c16", 1225 torch.bfloat16: "<f2", 1226 torch.float16: "<f2", 1227 torch.float32: "<f4", 1228 torch.float64: "<f8", 1229 torch.uint8: "|u1", 1230 torch.int8: "|i1", 1231 torch.uint16: "<u2", 1232 torch.int16: "<i2", 1233 torch.uint32: "<u4", 1234 torch.int32: "<i4", 1235 torch.uint64: "<u8", 1236 torch.int64: "<i8", 1237 torch.bool: "|b1", 1238 }[self.dtype] 1239 1240 itemsize = self.element_size() 1241 1242 shape = tuple(self.shape) 1243 if self.is_contiguous(): 1244 # __cuda_array_interface__ v2 requires the strides to be omitted 1245 # (either not set or set to None) for C-contiguous arrays. 1246 strides = None 1247 else: 1248 strides = tuple(s * itemsize for s in self.stride()) 1249 data_ptr = self.data_ptr() if self.numel() > 0 else 0 1250 data = (data_ptr, False) # read-only is false 1251 1252 return dict(typestr=typestr, shape=shape, strides=strides, data=data, version=2) 1253 1254 def storage_type(self): 1255 r"""storage_type() -> type 1256 1257 Returns the type of the underlying storage. 1258 1259 """ 1260 if has_torch_function_unary(self): 1261 return handle_torch_function(Tensor.storage_type, (self,), self) 1262 1263 torch.storage._warn_typed_storage_removal() 1264 1265 return self._typed_storage()._get_legacy_storage_class() 1266 1267 def refine_names(self, *names): 1268 r"""Refines the dimension names of :attr:`self` according to :attr:`names`. 1269 1270 Refining is a special case of renaming that "lifts" unnamed dimensions. 1271 A ``None`` dim can be refined to have any name; a named dim can only be 1272 refined to have the same name. 1273 1274 Because named tensors can coexist with unnamed tensors, refining names 1275 gives a nice way to write named-tensor-aware code that works with both 1276 named and unnamed tensors. 1277 1278 :attr:`names` may contain up to one Ellipsis (``...``). 1279 The Ellipsis is expanded greedily; it is expanded in-place to fill 1280 :attr:`names` to the same length as ``self.dim()`` using names from the 1281 corresponding indices of ``self.names``. 1282 1283 Python 2 does not support Ellipsis but one may use a string literal 1284 instead (``'...'``). 1285 1286 Args: 1287 names (iterable of str): The desired names of the output tensor. May 1288 contain up to one Ellipsis. 1289 1290 Examples:: 1291 1292 >>> imgs = torch.randn(32, 3, 128, 128) 1293 >>> named_imgs = imgs.refine_names('N', 'C', 'H', 'W') 1294 >>> named_imgs.names 1295 ('N', 'C', 'H', 'W') 1296 1297 >>> tensor = torch.randn(2, 3, 5, 7, 11) 1298 >>> tensor = tensor.refine_names('A', ..., 'B', 'C') 1299 >>> tensor.names 1300 ('A', None, None, 'B', 'C') 1301 1302 .. warning:: 1303 The named tensor API is experimental and subject to change. 1304 1305 """ 1306 if has_torch_function_unary(self): 1307 return handle_torch_function(Tensor.refine_names, (self,), self, *names) 1308 names = resolve_ellipsis(names, self.names, "refine_names") 1309 return super().refine_names(names) 1310 1311 def align_to(self, *names): 1312 r"""Permutes the dimensions of the :attr:`self` tensor to match the order 1313 specified in :attr:`names`, adding size-one dims for any new names. 1314 1315 All of the dims of :attr:`self` must be named in order to use this method. 1316 The resulting tensor is a view on the original tensor. 1317 1318 All dimension names of :attr:`self` must be present in :attr:`names`. 1319 :attr:`names` may contain additional names that are not in ``self.names``; 1320 the output tensor has a size-one dimension for each of those new names. 1321 1322 :attr:`names` may contain up to one Ellipsis (``...``). 1323 The Ellipsis is expanded to be equal to all dimension names of :attr:`self` 1324 that are not mentioned in :attr:`names`, in the order that they appear 1325 in :attr:`self`. 1326 1327 Python 2 does not support Ellipsis but one may use a string literal 1328 instead (``'...'``). 1329 1330 Args: 1331 names (iterable of str): The desired dimension ordering of the 1332 output tensor. May contain up to one Ellipsis that is expanded 1333 to all unmentioned dim names of :attr:`self`. 1334 1335 Examples:: 1336 1337 >>> tensor = torch.randn(2, 2, 2, 2, 2, 2) 1338 >>> named_tensor = tensor.refine_names('A', 'B', 'C', 'D', 'E', 'F') 1339 1340 # Move the F and E dims to the front while keeping the rest in order 1341 >>> named_tensor.align_to('F', 'E', ...) 1342 1343 .. warning:: 1344 The named tensor API is experimental and subject to change. 1345 1346 """ 1347 if has_torch_function_unary(self): 1348 return handle_torch_function(Tensor.align_to, (self,), self, *names) 1349 ellipsis_idx = single_ellipsis_index(names, "align_to") 1350 if ellipsis_idx is None: 1351 return super().align_to(names) 1352 return super().align_to( 1353 [name for name in names if not is_ellipsis(name)], ellipsis_idx 1354 ) 1355 1356 def unflatten(self, dim, sizes): 1357 r""" 1358 unflatten(dim, sizes) -> Tensor 1359 1360 See :func:`torch.unflatten`. 1361 1362 """ 1363 if has_torch_function_unary(self): 1364 return handle_torch_function(Tensor.unflatten, (self,), self, dim, sizes) 1365 1366 if not sizes: 1367 raise RuntimeError("unflatten: sizes must be non-empty") 1368 1369 names = None 1370 if isinstance(sizes, OrderedDict) or ( 1371 isinstance(sizes, (tuple, list)) and isinstance(sizes[0], (tuple, list)) 1372 ): 1373 names, sizes = unzip_namedshape(sizes) 1374 return super().unflatten(dim, sizes, names) 1375 else: 1376 return super().unflatten(dim, sizes) 1377 1378 def rename_(self, *names, **rename_map): 1379 """In-place version of :meth:`~Tensor.rename`.""" 1380 1381 if has_torch_function_unary(self): 1382 return handle_torch_function( 1383 Tensor.rename_, (self,), self, *names, **rename_map 1384 ) 1385 1386 # Note [rename_ / rename API] 1387 # The Python API for these is different from the C++ API. In Python: 1388 # 1) tensor.rename(*names) takes a vararglist of names 1389 # 2) tensor.rename(**rename_map) takes a map of names to rename. 1390 # C++ is static, making it difficult to implement similar behavior. 1391 return update_names(self, names, rename_map, inplace=True) 1392 1393 def rename(self, *names, **rename_map): 1394 """Renames dimension names of :attr:`self`. 1395 1396 There are two main usages: 1397 1398 ``self.rename(**rename_map)`` returns a view on tensor that has dims 1399 renamed as specified in the mapping :attr:`rename_map`. 1400 1401 ``self.rename(*names)`` returns a view on tensor, renaming all 1402 dimensions positionally using :attr:`names`. 1403 Use ``self.rename(None)`` to drop names on a tensor. 1404 1405 One cannot specify both positional args :attr:`names` and keyword args 1406 :attr:`rename_map`. 1407 1408 Examples:: 1409 1410 >>> imgs = torch.rand(2, 3, 5, 7, names=('N', 'C', 'H', 'W')) 1411 >>> renamed_imgs = imgs.rename(N='batch', C='channels') 1412 >>> renamed_imgs.names 1413 ('batch', 'channels', 'H', 'W') 1414 1415 >>> renamed_imgs = imgs.rename(None) 1416 >>> renamed_imgs.names 1417 (None, None, None, None) 1418 1419 >>> renamed_imgs = imgs.rename('batch', 'channel', 'height', 'width') 1420 >>> renamed_imgs.names 1421 ('batch', 'channel', 'height', 'width') 1422 1423 .. warning:: 1424 The named tensor API is experimental and subject to change. 1425 1426 """ 1427 if has_torch_function_unary(self): 1428 return handle_torch_function( 1429 Tensor.rename, (self,), self, *names, **rename_map 1430 ) 1431 1432 # See Note [rename_ / rename API] 1433 return update_names(self, names, rename_map, inplace=False) 1434 1435 def to_sparse_coo(self): 1436 """Convert a tensor to :ref:`coordinate format <sparse-coo-docs>`. 1437 1438 Examples:: 1439 1440 >>> dense = torch.randn(5, 5) 1441 >>> sparse = dense.to_sparse_coo() 1442 >>> sparse._nnz() 1443 25 1444 1445 """ 1446 return self.to_sparse() 1447 1448 def dim_order(self): 1449 """ 1450 1451 dim_order() -> tuple 1452 1453 Returns a tuple of int describing the dim order or physical layout of :attr:`self`. 1454 1455 Args: 1456 None 1457 1458 Dim order represents how dimensions are laid out in memory, 1459 starting from the outermost to the innermost dimension. 1460 1461 Example:: 1462 >>> torch.empty((2, 3, 5, 7)).dim_order() 1463 (0, 1, 2, 3) 1464 >>> torch.empty((2, 3, 5, 7), memory_format=torch.channels_last).dim_order() 1465 (0, 2, 3, 1) 1466 1467 .. warning:: 1468 The dim_order tensor API is experimental and subject to change. 1469 1470 """ 1471 if has_torch_function_unary(self): 1472 return handle_torch_function(Tensor.dim_order, (self,), self) 1473 1474 import torch._prims_common as utils 1475 1476 return tuple(utils.compute_elementwise_output_logical_to_physical_perm(self)) 1477 1478 def _update_names(self, names, inplace): 1479 if has_torch_function_unary(self): 1480 return handle_torch_function( 1481 Tensor._update_names, (self,), self, names, inplace 1482 ) 1483 1484 # See Note [rename_ / rename API] 1485 if inplace: 1486 return super().rename_(names) 1487 else: 1488 return super().rename(names) 1489 1490 @classmethod 1491 def __torch_function__(cls, func, types, args=(), kwargs=None): 1492 """ 1493 This __torch_function__ implementation wraps subclasses such that 1494 methods called on subclasses return a subclass instance instead of 1495 a ``torch.Tensor`` instance. 1496 1497 One corollary to this is that you need coverage for torch.Tensor 1498 methods if implementing __torch_function__ for subclasses. 1499 1500 We recommend always calling ``super().__torch_function__`` as the base 1501 case when doing the above. 1502 1503 While not mandatory, we recommend making `__torch_function__` a classmethod. 1504 """ 1505 if kwargs is None: 1506 kwargs = {} 1507 1508 if not all(issubclass(cls, t) for t in types): 1509 return NotImplemented 1510 1511 with _C.DisableTorchFunctionSubclass(): 1512 ret = func(*args, **kwargs) 1513 if func in get_default_nowrap_functions(): 1514 return ret 1515 else: 1516 return _convert(ret, cls) 1517 1518 __torch_dispatch__ = _C._disabled_torch_dispatch_impl 1519 1520 def __dlpack__(self, stream=None): 1521 """ 1522 Creates a DLpack `capsule https://data-apis.org/array-api/latest/design_topics/data_interchange.html#data-interchange`_ 1523 of the current tensor to be exported to other libraries. 1524 1525 This function will be called from the `from_dlpack` method 1526 of the library that will consume the capsule. `from_dlpack` passes the current 1527 stream to this method as part of the specification. 1528 1529 Args: 1530 stream (integer or None): An optional Python integer representing a 1531 pointer to a CUDA stream. The current stream is synchronized with 1532 this stream before the capsule is created, and since the capsule 1533 shares its storage with the tensor this make it safe to access from 1534 both streams. If None or -1 is passed then no synchronization is performed. 1535 If 1 (on CUDA) or 0 (on ROCM) then the default stream is used for 1536 synchronization. 1537 """ 1538 if has_torch_function_unary(self): 1539 return handle_torch_function(Tensor.__dlpack__, (self,), self, stream) 1540 1541 # DLPack capsules can't capture all of PyTorch's semantics, 1542 # so we prohibit exporting tensors that would lose their properties like 1543 # requires_grad and having the conjugate bit set. 1544 if self.requires_grad: 1545 raise RuntimeError( 1546 "Can't export tensors that require gradient, use tensor.detach()" 1547 ) 1548 if self.is_conj(): 1549 raise RuntimeError("Can't export tensors with the conjugate bit set") 1550 if self.layout != torch.strided: 1551 raise RuntimeError( 1552 "Can't export tensors with layout other than torch.strided" 1553 ) 1554 1555 if stream is not None and type(stream) is not int: 1556 # Stream pointers in CUDA/ROCm are uniquely numbered and can 1557 # be retrieved from their integer value. 1558 raise TypeError("stream must be ``int`` or ``none``") 1559 elif stream is not None and stream != -1: 1560 if self.device.type == "cuda": 1561 # NB: This logic handles the special case values for default 1562 # streams and must be kept in sync with from_dlpack in 1563 # torch/utils/dlpack.py 1564 if stream == 1 and torch.version.hip is None: 1565 stream = torch.cuda.default_stream() 1566 elif stream == 0 and torch.version.hip is not None: 1567 stream = torch.cuda.default_stream() 1568 else: 1569 stream = torch.cuda.ExternalStream(stream) 1570 # Only synchronize on different streams 1571 sync_stream = torch.cuda.current_stream() 1572 if stream != sync_stream: 1573 event = torch.cuda.Event() 1574 event.record(sync_stream) 1575 stream.wait_event(event) 1576 return torch.to_dlpack(self) 1577 1578 def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]: 1579 if has_torch_function_unary(self): 1580 return handle_torch_function(Tensor.__dlpack_device__, (self,), self) 1581 1582 from torch.utils.dlpack import DLDeviceType 1583 1584 device = self.device 1585 idx = device.index if device.index is not None else 0 1586 torch_device_type = device.type 1587 if torch_device_type == "cuda" and torch.version.hip is not None: 1588 device_type = DLDeviceType.kDLROCM 1589 elif torch_device_type == "cpu" and self.is_pinned(): 1590 device_type = DLDeviceType.kDLCPUPinned 1591 elif torch_device_type == "cuda": 1592 device_type = DLDeviceType.kDLGPU 1593 elif torch_device_type == "cpu": 1594 device_type = DLDeviceType.kDLCPU 1595 elif self.device.type == "xpu": 1596 device_type = DLDeviceType.kDLOneAPI 1597 else: 1598 raise ValueError(f"Unknown device type {torch_device_type} for Dlpack") 1599 return (device_type, idx) 1600 1601 __module__ = "torch" 1602 1603 1604def _convert(ret, cls): 1605 if cls is Tensor: 1606 return ret 1607 1608 if isinstance(ret, Tensor) and not isinstance(ret, cls): 1609 ret = ret.as_subclass(cls) 1610 1611 if isinstance(ret, (tuple, list)): 1612 # Also handles things like namedtuples 1613 ret = type(ret)(_convert(r, cls) for r in ret) 1614 1615 return ret 1616