# mypy: allow-untyped-defs
import abc
import collections
import contextlib
import functools
import logging
import threading
import weakref
from collections import defaultdict, namedtuple
from typing import (
    Any,
    Callable,
    cast,
    Deque,
    Dict,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import torch
from torch.autograd.variable import Variable
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.hooks import RemovableHandle

log = logging.getLogger(__name__)


__all__ = [
    "saved_tensors_hooks",
    "save_on_cpu",
    "disable_saved_tensors_hooks",
    "register_multi_grad_hook",
    "allow_mutation_on_saved_tensors",
    "Node",
    "GradientEdge",
    "get_gradient_edge",
    "increment_version",
]


class Node(abc.ABC):
    @abc.abstractmethod
    def name(self) -> str:
        r"""Return the name.

        Example::

            >>> import torch
            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> print(b.grad_fn.name())
            CloneBackward0
        """
        ...

    @property
    @abc.abstractmethod
    def next_functions(self) -> Tuple[Tuple[Optional["Node"], int], ...]:
        ...

    @abc.abstractmethod
    def metadata(self) -> dict:
        r"""Return the metadata."""
        ...

    @abc.abstractmethod
    def _register_hook_dict(self, tensor: torch.Tensor) -> None:
        ...

    @abc.abstractmethod
    def register_hook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Register a backward hook.

        The hook will be called every time a gradient with respect to the
        Node is computed. The hook should have the following signature::

            hook(grad_inputs: Tuple[Tensor], grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None

        The hook should not modify its argument, but it can optionally return
        a new gradient which will be used in place of :attr:`grad_inputs`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the Node.

        .. note::
            See :ref:`backward-hooks-execution` for more information on when this hook
            is executed, and how its execution is ordered relative to other hooks.

        Example::

            >>> import torch
            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_hook(lambda gI, gO: (gO[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([2., 2., 2.])
            >>> handle.remove()  # Removes the hook
            >>> a.grad = None
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([1., 1., 1.])
        """
        ...

    @abc.abstractmethod
    def register_prehook(self, fn: Callable[..., Any]) -> RemovableHandle:
        r"""Register a backward pre-hook.

        The hook will be called every time a gradient with respect to the
        Node is computed. The hook should have the following signature::

            hook(grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None

        The hook should not modify its argument, but it can optionally return
        a new gradient which will be used in place of :attr:`grad_outputs`.

        This function returns a handle with a method ``handle.remove()``
        that removes the hook from the Node.

        .. note::
            See :ref:`backward-hooks-execution` for more information on when this hook
            is executed, and how its execution is ordered relative to other hooks.

        Example::

            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
            >>> b = a.clone()
            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
            >>> handle = b.grad_fn.register_prehook(lambda gI: (gI[0] * 2,))
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([2., 2., 2.])
            >>> handle.remove()
            >>> a.grad = None
            >>> b.sum().backward(retain_graph=True)
            >>> print(a.grad)
            tensor([1., 1., 1.])
        """
        ...

    @classmethod
    def __subclasshook__(cls, C):
        if cls is Node:
            if (
                C is not None and C is getattr(torch._C._functions, C.__name__, None)
            ) or issubclass(C, torch.autograd.function.BackwardCFunction):
                return True
        return NotImplemented
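

# A minimal sketch of walking the graph through ``next_functions`` (illustrative only;
# the tensor names are placeholders): each entry is a ``(Node, output_nr)`` pair, and
# leaf tensors terminate in ``AccumulateGrad`` nodes.
#
#     >>> import torch
#     >>> a = torch.tensor([1.0, 2.0], requires_grad=True)
#     >>> b = (a * 2).sum()
#     >>> print(b.grad_fn.name())
#     SumBackward0
#     >>> print([type(fn).__name__ for fn, _ in b.grad_fn.next_functions])
#     ['MulBackward0']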


def _get_grad_fn_or_grad_acc(t):
    if t.requires_grad and t.grad_fn is None:
        with torch.enable_grad():
            return t.view_as(t).grad_fn.next_functions[0][0]
    else:
        return t.grad_fn


GradientEdge = namedtuple("GradientEdge", ("node", "output_nr"))
GradientEdge.__doc__ = """\
Object representing a given gradient edge within the autograd graph.
To get the gradient edge where a given Tensor gradient will be computed,
you can do ``edge = autograd.graph.get_gradient_edge(tensor)``.
"""


def get_gradient_edge(tensor):
    """Get the gradient edge for computing the gradient of the given Tensor.

    In particular, calling ``g = autograd.grad(loss, input)`` is equivalent to calling
    ``g = autograd.grad(loss, get_gradient_edge(input))``.
    """
    if not tensor.requires_grad:
        raise RuntimeError(
            "It is not possible to get the gradient edge for a Tensor that does not require gradients"
        )
    grad_fn = _get_grad_fn_or_grad_acc(tensor)

    # Note that output_nr defaults to 0, which is the right value
    # for the AccumulateGrad node.
    return GradientEdge(grad_fn, tensor.output_nr)


def increment_version(tensor):
    """Update autograd metadata tracking whether the given Tensor was modified in place.

    This is to enable more accurate error checking within the autograd engine.
    It is already done automatically by PyTorch functions and within custom Functions
    when mark_dirty() is called appropriately, so you only need to call this explicitly
    if you are performing an in-place operation on the Tensor data in a way that PyTorch
    doesn't know about; for example, a custom kernel that reads the Tensor data_ptr and
    modifies the memory in place based on this pointer.

    Note that incrementing the version counter multiple times for a single in-place
    operation is not problematic.
    """
    torch._C._increment_version(tensor)
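

# A minimal usage sketch for ``get_gradient_edge`` (illustrative only; the variable
# names are placeholders): per the docstring above, passing the edge to
# ``autograd.grad`` is equivalent to passing the tensor itself.
#
#     >>> import torch
#     >>> inp = torch.ones(3, requires_grad=True)
#     >>> loss = (inp * 2).sum()
#     >>> edge = torch.autograd.graph.get_gradient_edge(inp)
#     >>> (g,) = torch.autograd.grad(loss, edge)
#     >>> print(g)
#     tensor([2., 2., 2.])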


class saved_tensors_hooks:
    """Context-manager that sets a pair of pack / unpack hooks for saved tensors.

    Use this context-manager to define how intermediary results of an operation
    should be packed before saving, and unpacked on retrieval.

    In that context, the ``pack_hook`` function will be called every time an
    operation saves a tensor for backward (this includes intermediary results
    saved using
    :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but
    also those recorded by a PyTorch-defined operation). The output of
    ``pack_hook`` is then stored in the computation graph instead of the
    original tensor.

    The ``unpack_hook`` is called when the saved tensor needs to be accessed,
    namely when executing :func:`torch.Tensor.backward()` or
    :func:`torch.autograd.grad()`. It takes as argument the *packed* object
    returned by ``pack_hook`` and should return a tensor which has the same
    content as the original tensor (passed as input to the corresponding
    ``pack_hook``).

    The hooks should have the following signatures:

        pack_hook(tensor: Tensor) -> Any

        unpack_hook(Any) -> Tensor

    where the return value of ``pack_hook`` is a valid input to ``unpack_hook``.

    In general, you want ``unpack_hook(pack_hook(t))`` to be equal to ``t`` in terms
    of value, size, dtype and device.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pack_hook(x):
        ...     print("Packing", x)
        ...     return x
        >>>
        >>> def unpack_hook(x):
        ...     print("Unpacking", x)
        ...     return x
        >>>
        >>> a = torch.ones(5, requires_grad=True)
        >>> b = torch.ones(5, requires_grad=True) * 2
        >>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
        ...     y = a * b
        Packing tensor([1., 1., 1., 1., 1.], requires_grad=True)
        Packing tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)
        >>> y.sum().backward()
        Unpacking tensor([1., 1., 1., 1., 1.], requires_grad=True)
        Unpacking tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>)

    .. warning ::
        Performing an inplace operation on the input to either hook may lead
        to undefined behavior.

    .. warning ::
        Only one pair of hooks is allowed at a time. When recursively nesting this
        context-manager, only the inner-most pair of hooks will be applied.
    """

    def __init__(
        self,
        pack_hook: Callable[[torch.Tensor], Any],
        unpack_hook: Callable[[Any], torch.Tensor],
    ):
        self.pack_hook = pack_hook
        self.unpack_hook = unpack_hook

    def __enter__(self):
        torch._C._autograd._push_saved_tensors_default_hooks(
            self.pack_hook, self.unpack_hook
        )

    def __exit__(self, *args: object):
        torch._C._autograd._pop_saved_tensors_default_hooks()
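

# A minimal sketch of a custom pack/unpack pair using the context manager above
# (illustrative only, assuming temporary files are an acceptable storage target and
# ignoring cleanup of those files): ``pack`` writes each saved tensor to disk and
# ``unpack`` reads it back, trading backward-pass I/O for memory, similar in spirit
# to ``save_on_cpu`` below.
#
#     >>> import os, tempfile, uuid
#     >>> def pack(t):
#     ...     path = os.path.join(tempfile.gettempdir(), f"saved_{uuid.uuid4().hex}.pt")
#     ...     torch.save(t, path)
#     ...     return path
#     >>> def unpack(path):
#     ...     return torch.load(path)
#     >>> x = torch.randn(5, requires_grad=True)
#     >>> with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
#     ...     y = (x * x).sum()
#     >>> y.backward()  # saved tensors are loaded back from disk here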


class save_on_cpu(saved_tensors_hooks):
    """Context manager under which tensors saved by the forward pass will be stored on cpu, then retrieved for backward.

    When performing operations within this context manager, intermediary
    results saved in the graph during the forward pass will be moved to CPU,
    then copied back to the original device when needed for the backward pass.
    If the graph was already on CPU, no tensor copy is performed.

    Use this context-manager to trade compute for GPU memory usage (e.g.
    when your model doesn't fit in GPU memory during training).

    Args:
        pin_memory (bool): If ``True``, tensors will be saved to CPU pinned memory
                           during packing and copied to GPU asynchronously during unpacking.
                           Defaults to ``False``.
                           Also see :ref:`cuda-memory-pinning`.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> a = torch.randn(5, requires_grad=True, device="cuda")
        >>> b = torch.randn(5, requires_grad=True, device="cuda")
        >>> c = torch.randn(5, requires_grad=True, device="cuda")
        >>>
        >>> def f(a, b, c):
        ...     prod_1 = a * b           # a and b are saved on GPU
        ...     with torch.autograd.graph.save_on_cpu():
        ...         prod_2 = prod_1 * c  # prod_1 and c are saved on CPU
        ...     y = prod_2 * a           # prod_2 and a are saved on GPU
        ...     return y
        >>>
        >>> y = f(a, b, c)
        >>> del a, b, c  # for illustration only
        >>> # the content of a, b, and prod_2 are still alive on GPU
        >>> # the content of prod_1 and c only live on CPU
        >>> y.sum().backward()  # all CPU tensors are moved back to GPU, for backward
        >>> # all intermediary tensors are released (deleted) after the call to backward

    """

    def __init__(self, pin_memory=False, device_type="cuda"):
        device_module = getattr(torch, device_type, torch.cuda)

        def pack_to_cpu(tensor):
            if not pin_memory:
                return (tensor.device, tensor.cpu())
            packed = torch.empty(
                tensor.size(),
                dtype=tensor.dtype,
                layout=tensor.layout,
                pin_memory=(device_module.is_available() and not tensor.is_sparse),
            )
            packed.copy_(tensor)
            return (tensor.device, packed)

        def unpack_from_cpu(packed):
            device, tensor = packed
            return tensor.to(device, non_blocking=pin_memory)

        super().__init__(pack_to_cpu, unpack_from_cpu)


@contextlib.contextmanager
def disable_saved_tensors_hooks(error_message):
    """Context-manager that disables the saved tensors default hooks feature.

    Useful if you are creating a feature that does not work with saved
    tensors default hooks.

    Args:
        error_message (str): If saved tensors default hooks are used while they
                             are disabled, a RuntimeError with this error message
                             is raised.

    Example::

        >>> # xdoctest: +SKIP(failing)
        >>> message = "saved tensors default hooks are disabled"
        >>> with torch.autograd.graph.disable_saved_tensors_hooks(message):
        ...     # Raises RuntimeError: saved tensors default hooks are disabled
        ...     with torch.autograd.graph.save_on_cpu():
        ...         pass

    """
    try:
        maybe_prev_message = (
            torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
        )
        torch._C._autograd._saved_tensors_hooks_disable(error_message)
        yield
    finally:
        # See NOTE: [disabled_error_message invariant]
        if maybe_prev_message is None:
            torch._C._autograd._saved_tensors_hooks_enable()
        else:
            torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message)


class _MultiHandle(RemovableHandle):
    handles: Tuple[RemovableHandle, ...]

    def __init__(self, handles: Tuple[RemovableHandle, ...]):
        self.handles = handles

    def remove(self):
        for handle in self.handles:
            handle.remove()

    def __getstate__(self):
        return self.handles

    def __setstate__(self, state):
        self.handles = state


def register_multi_grad_hook(
    tensors: Sequence[torch.Tensor],
    fn: Union[
        Callable[[Sequence[Optional[torch.Tensor]]], None],
        Callable[[torch.Tensor], None],
    ],
    *,
    mode: str = "all",
):
    r"""Register a multi-grad backward hook.

    There are two supported modes: ``"all"`` and ``"any"``.

    Under the ``"all"`` mode, the hook will be called after gradients with respect to every tensor in
    :attr:`tensors` have been computed. If a tensor is in :attr:`tensors` but
    is not part of the graph, or if a tensor is not needed to compute the gradients
    for any ``inputs`` specified for the current ``.backward()`` or ``.grad()`` call,
    this tensor will be ignored and the hook will not wait for its gradient to be
    computed.

    After every non-ignored tensor's gradient has been computed, :attr:`fn` will be
    called with those gradients. ``None`` will be passed for tensors that did not
    have their gradients computed.

    Under the ``"any"`` mode, the hook will be called after the first gradient
    with respect to a tensor in :attr:`tensors` has been computed. The hook
    will be called with that gradient as its argument.

    The hook should not modify its arguments.

    This function returns a handle with a method ``handle.remove()`` that removes the hook.

    .. note::
        See :ref:`backward-hooks-execution` for more information on when this hook
        is executed, and how its execution is ordered relative to other hooks.

    Example::

        >>> import torch
        >>>
        >>> a = torch.rand(2, 3, requires_grad=True)
        >>> b = torch.rand(2, 3, requires_grad=True)
        >>> c = a * b
        >>> d = a * b
        >>>
        >>> def fn(grads):
        ...     print([g is not None for g in grads])
        ...
        >>> torch.autograd.graph.register_multi_grad_hook((a, b, c, d), fn)
        >>>
        >>> c.sum().backward(retain_graph=True)
        [True, True, True, False]
        >>> c.sum().backward(inputs=(a,), retain_graph=True)
        [True, False, True, False]
        >>>
    """
    supported_modes = ("all", "any")
    if mode not in supported_modes:
        raise ValueError(f"Expects mode to be one of {supported_modes} but got {mode}")

    if mode == "all":
        count: Dict[int, int] = dict()
        nb_calls = None
        buffer: Dict[int, List[Optional[torch.Tensor]]] = dict()

        grad_fns = list(map(_get_grad_fn_or_grad_acc, tensors))
        len_tensors = len(tensors)

        def get_inner_hook(idx):
            def inner_hook(grad: torch.Tensor):
                nonlocal count, nb_calls, buffer, fn
                id = torch._C._current_graph_task_id()
                assert (
                    id != -1
                ), "expected this hook to be called inside a backward call"
                count[id] = count.get(id, 0)
                buffer[id] = buffer.get(id, [None] * len_tensors)

                if count[id] == 0:
                    # On the first call, compute the actual nb_calls and buffer
                    nb_calls = sum(torch._C._will_engine_execute_node(g) for g in grad_fns)  # type: ignore[attr-defined]

                buffer[id][idx] = grad
                count[id] += 1

                if count[id] == nb_calls:
                    fn = cast(Callable[[Sequence[Optional[torch.Tensor]]], None], fn)
                    fn(buffer[id])
                    del count[id]
                    del buffer[id]

            return inner_hook

        handles: Tuple[RemovableHandle, ...] = tuple(
            t.register_hook(get_inner_hook(i)) for i, t in enumerate(tensors)
        )
    elif mode == "any":
        fn = cast(Callable[[torch.Tensor], None], fn)
        lock = threading.Lock()
        ran_hook: Dict[int, bool] = defaultdict(bool)

        @functools.wraps(fn)
        def wrapped_fn(grad: torch.Tensor):
            nonlocal ran_hook
            id = torch._C._current_graph_task_id()
            assert id != -1, "expected this hook to be called inside a backward call"
            with lock:
                prev, ran_hook[id] = ran_hook[id], True
            if prev:
                return
            fn(grad)

        handles = tuple(
            tensor.register_hook(wrapped_fn)
            for tensor in tensors
            if tensor.requires_grad
        )

    return _MultiHandle(handles)  # type: ignore[possibly-undefined]
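

# A minimal sketch of the ``"any"`` mode described above (illustrative only; names are
# placeholders): the hook fires once per backward call, with whichever gradient among
# ``tensors`` is computed first.
#
#     >>> a = torch.rand(2, 3, requires_grad=True)
#     >>> b = torch.rand(2, 3, requires_grad=True)
#     >>> c = a * b
#     >>> def hook(grad):
#     ...     print("first grad computed, shape:", tuple(grad.shape))
#     >>> handle = torch.autograd.graph.register_multi_grad_hook((a, b), hook, mode="any")
#     >>> c.sum().backward()
#     first grad computed, shape: (2, 3)
#     >>> handle.remove()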


# NOTE [Allow mutation on tensors saved for backward]
#
# 1. Tensor gets saved for backward
#    - remember the python object id and the version of the tensor
#    - remember aliasing information (data_ptr of base + version)
#    - save the original so we control its lifetime
# 2. Any time a tensor gets in-placed
#    - for each tensor aliased to it:
#      - check using its object id and version to see if it has been saved
#      - if it has been saved, clone it
#      - delete the reference to the original
# 3. during backward
#    - if the clone exists, the tensor must've been modified in-place
_allow_mutation_on_saved_tensors_enabled = False


def _get_tid(t) -> Tuple[int, int, int]:
    # FIXME: This is almost definitely a bug.
    if isinstance(
        t,
        (
            torch._subclasses.fake_tensor.FakeTensor,
            torch._subclasses.functional_tensor.FunctionalTensor,
        ),
    ):
        data_ptr = 0
    else:
        data_ptr = t.data_ptr()
    return (id(t), data_ptr, t._version)


def _get_sid(t) -> Tuple[int, int]:
    # FIXME: This is almost definitely a bug.
    if isinstance(
        t,
        (
            torch._subclasses.fake_tensor.FakeTensor,
            torch._subclasses.functional_tensor.FunctionalTensor,
        ),
    ):
        data_ptr = 0
    else:
        data_ptr = t.data_ptr()
    return (data_ptr, t._version)


class _Handle:
    pass
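

# A small illustration of the identifiers above (a sketch, assuming ordinary strided
# tensors): a tensor and a view that starts at the same address share an sid (same
# data_ptr and version) but have distinct tids (different Python object ids), which is
# the aliasing information referred to in the note above.
#
#     >>> base = torch.randn(4)
#     >>> alias = base.view(2, 2)
#     >>> _get_sid(base) == _get_sid(alias), _get_tid(base) == _get_tid(alias)
#     (True, False)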


class _swap_with_cloned(saved_tensors_hooks):
    def __init__(self, ctx):
        def pack_hook(t):
            tid = _get_tid(t)
            sid = _get_sid(t)
            # Tensors saved for backward have an entry in _tid_to_weakhandle
            handle: Optional[_Handle] = None

            # Save aliasing information
            ctx.sid_to_tid[sid].add(tid)

            # NB: The same tensor (of the same version) can be saved multiple times
            if tid not in ctx.tid_to_weakhandle:
                handle = _Handle()
                ctx.tid_to_weakhandle[tid] = handle
                ctx.original[handle] = t
            else:
                # Store an additional strong reference to the handle
                handle = ctx.tid_to_weakhandle[tid]
            return handle

        def unpack_hook(tup):
            handle = tup
            error_msg = (
                "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context "
                "in which the graph was originally recorded."
            )
            assert _allow_mutation_on_saved_tensors_enabled, error_msg
            if handle in ctx.cloned:
                res = ctx.cloned[handle]
            else:
                assert handle in ctx.original, error_msg
                res = ctx.original[handle]
            return res

        super().__init__(pack_hook, unpack_hook)


class _CloneArgBeforeMutateMode(TorchDispatchMode):
    def __init__(self, ctx):
        self.ctx = ctx

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        for idx, arg in enumerate(func._schema.arguments):
            if arg.alias_info is not None and arg.alias_info.is_write:
                t = kwargs["out"] if arg.is_out else args[idx]
                tid = _get_tid(t)
                sid = _get_sid(t)
                ctx = self.ctx
                if sid in ctx.sid_to_tid:
                    for tid in ctx.sid_to_tid[sid]:
                        if tid not in ctx.tid_to_weakhandle:
                            # We know that if tid is in sid_to_tid, then it must also be in
                            # tid_to_weakhandle. However, it is possible for the tensor to be
                            # saved at one point, but cleared by backward before it is modified
                            # in-place. Consider the following example:
                            #
                            # >>> a = torch.randn(2, 3, requires_grad=True).clone()
                            # >>> out = (a**2).sum()
                            # >>> out.backward()
                            # >>> a.sin_()
                            continue
                        handle = ctx.tid_to_weakhandle[tid]
                        if handle in ctx.cloned:
                            # The same exact tensor has been cloned already
                            continue
                        ctx.cloned[handle] = ctx.original[handle].clone()
                        del ctx.original[handle]

        rs = func(*args, **kwargs)
        return rs


class _AllowMutationOnSavedContext:
    def __init__(self):
        self.cloned: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self.original: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self.tid_to_weakhandle: weakref.WeakValueDictionary = (
            weakref.WeakValueDictionary()
        )
        self.sid_to_tid: Dict[Tuple[int, int], Set[Tuple[int, int, int]]] = defaultdict(
            set
        )

    def clear(self):
        self.cloned.clear()
        self.original.clear()
        self.tid_to_weakhandle.clear()
        self.sid_to_tid.clear()


@contextlib.contextmanager
def allow_mutation_on_saved_tensors():
    """Context manager under which mutating tensors saved for backward is allowed.

    Under this context manager, tensors saved for backward are cloned on mutation,
    so the original version can still be used during backward. Normally, mutating a tensor
    saved for backward will result in an error raised when it's used during backward.

    To ensure the correct behavior, both the forward and backward should be run under
    the same context manager.

    Returns:
        An _AllowMutationOnSavedContext object storing the state managed by this
        context manager. This object can be useful for debugging purposes. The state
        managed by the context manager is automatically cleared upon exiting.

    Example::

        >>> import torch
        >>> with torch.autograd.graph.allow_mutation_on_saved_tensors():
        ...     # forward
        ...     a = torch.ones(2, 3, requires_grad=True)
        ...     b = a.clone()
        ...     out = (b**2).sum()
        ...     b.sin_()
        ...     # backward
        ...     out.sum().backward()
        ...
        tensor([[0.8415, 0.8415, 0.8415],
                [0.8415, 0.8415, 0.8415]], grad_fn=<SinBackward0>)
    """
    global _allow_mutation_on_saved_tensors_enabled

    ctx = _AllowMutationOnSavedContext()

    with _swap_with_cloned(ctx), _CloneArgBeforeMutateMode(ctx):
        try:
            if _allow_mutation_on_saved_tensors_enabled:
                raise RuntimeError(
                    "allow_mutation_on_saved_tensors contexts cannot be nested"
                )
            _allow_mutation_on_saved_tensors_enabled = True
            yield ctx
        finally:
            ctx.clear()
            _allow_mutation_on_saved_tensors_enabled = False


def _register_logging_hooks_on_whole_graph(t_outputs: List[torch.Tensor]):
    grad_fns = list(map(_get_grad_fn_or_grad_acc, t_outputs))

    def iter_graph(roots):
        if not roots:
            return
        seen = set()
        q: Deque = collections.deque()
        for node in roots:
            if node is not None:
                seen.add(node)
                q.append(node)

        while q:
            node = q.popleft()
            for fn, _idx in node.next_functions:
                if fn in seen or fn is None:
                    continue
                seen.add(fn)
                q.append(fn)

            yield node

    def fmt(t):
        # Avoid circular import
        from torch.testing._internal.common_utils import dtype_abbrs

        if t is None:
            return "None"
        return f"{dtype_abbrs[t.dtype]}[{', '.join(map(str, t.shape))}]"

    def prehook(grad_outputs):
        node = torch._C._current_autograd_node()
        grad_outputs_str = f"[{','.join(fmt(t) for t in grad_outputs)}]"
        log_str = f"Executing: {node} with grad_outputs: {grad_outputs_str}"
        log.debug(log_str)

    handles = []
    for node in iter_graph(grad_fns):
        handles.append(node.register_prehook(prehook))

    def unregister_hooks():
        for handle in handles:
            handle.remove()

    return unregister_hooks


def _engine_run_backward(t_outputs, *args, **kwargs):
    attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG
    if attach_logging_hooks:
        unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    try:
        return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
            t_outputs, *args, **kwargs
        )
    finally:
        if attach_logging_hooks:
            unregister_hooks()  # type: ignore[possibly-undefined]
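

# Note on the logging helpers above (a sketch, not an additional API): the backward
# entry point only attaches the graph-wide prehooks when this module's logger is at
# DEBUG level, so per-node execution logs can be obtained roughly as follows,
# assuming a logging handler is configured and that this module's fully qualified
# name is ``torch.autograd.graph``:
#
#     >>> import logging
#     >>> logging.basicConfig(level=logging.DEBUG)
#     >>> logging.getLogger("torch.autograd.graph").setLevel(logging.DEBUG)
#     >>> # subsequent .backward() calls log "Executing: <node> with grad_outputs: ..."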