# mypy: allow-untyped-defs
import functools
import inspect
import itertools
import warnings
from collections import OrderedDict
from typing import Any, List, Optional, Tuple
from typing_extensions import deprecated

import torch
import torch._C as _C
import torch._functorch as _functorch
import torch.utils.hooks as hooks
from torch._C import _functions
from torch._functorch.autograd_function import custom_function_call


__all__ = [
    "FunctionCtx",
    "BackwardCFunction",
    "FunctionMeta",
    "Function",
    "once_differentiable",
    "InplaceFunction",
    "NestedIOFunction",
]

# Unique id provider for each class inheriting from Function
# This is incremented in FunctionMeta during class definition
AUTOGRAD_FUNCTION_COUNTER = itertools.count()


# Formerly known as: _ContextMethodMixin
class FunctionCtx:
    def save_for_backward(self, *tensors: torch.Tensor):
        r"""Save given tensors for a future call to :func:`~Function.backward`.

        ``save_for_backward`` should be called at most once, in either the
        :func:`setup_context` or :func:`forward` methods, and only with tensors.

        All tensors intended to be used in the backward pass should be saved
        with ``save_for_backward`` (as opposed to directly on ``ctx``) to prevent
        incorrect gradients and memory leaks, and enable the application of saved
        tensor hooks. See :class:`torch.autograd.graph.saved_tensors_hooks`.

        Note that if intermediary tensors, tensors that are neither inputs
        nor outputs of :func:`forward`, are saved for backward, your custom Function
        may not support double backward.
        Custom Functions that do not support double backward should decorate their
        :func:`backward` method with ``@once_differentiable`` so that performing
        double backward raises an error. If you'd like to support double backward,
        you can either recompute intermediaries based on the inputs during backward
        or return the intermediaries as the outputs of the custom Function. See the
        `double backward tutorial <https://pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html>`_
        for more details.

        In :func:`backward`, saved tensors can be accessed through the :attr:`saved_tensors`
        attribute. Before returning them to the user, a check is made to ensure
        they weren't used in any in-place operation that modified their content.

        Arguments can also be ``None``. This is a no-op.

        See :ref:`extending-autograd` for more details on how to use this method.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
            >>> class Func(Function):
            >>>     @staticmethod
            >>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
            >>>         w = x * z
            >>>         out = x * y + y * z + w * y
            >>>         ctx.save_for_backward(x, y, w, out)
            >>>         ctx.z = z  # z is not a tensor
            >>>         return out
            >>>
            >>>     @staticmethod
            >>>     @once_differentiable
            >>>     def backward(ctx, grad_out):
            >>>         x, y, w, out = ctx.saved_tensors
            >>>         z = ctx.z
            >>>         gx = grad_out * (y + y * z)
            >>>         gy = grad_out * (x + z + w)
            >>>         gz = None
            >>>         return gx, gy, gz
            >>>
            >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
            >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
            >>> c = 4
            >>> d = Func.apply(a, b, c)

        """
        self.to_save = tensors

    def save_for_forward(self, *tensors: torch.Tensor):
        r"""Save given tensors for a future call to :func:`~Function.jvp`.

        ``save_for_forward`` should be called at most once, in either the
        :func:`setup_context` or :func:`forward` methods, and all arguments
        should be tensors.

        In :func:`jvp`, saved objects can be accessed through the :attr:`saved_tensors`
        attribute.

        Arguments can also be ``None``. This is a no-op.

        See :ref:`extending-autograd` for more details on how to use this method.

        Example::
            >>> # xdoctest: +SKIP
            >>> class Func(torch.autograd.Function):
            >>>     @staticmethod
            >>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
            >>>         ctx.save_for_backward(x, y)
            >>>         ctx.save_for_forward(x, y)
            >>>         ctx.z = z
            >>>         return x * y * z
            >>>
            >>>     @staticmethod
            >>>     def jvp(ctx, x_t, y_t, _):
            >>>         x, y = ctx.saved_tensors
            >>>         z = ctx.z
            >>>         return z * (y * x_t + x * y_t)
            >>>
            >>>     @staticmethod
            >>>     def vjp(ctx, grad_out):
            >>>         x, y = ctx.saved_tensors
            >>>         z = ctx.z
            >>>         return z * grad_out * y, z * grad_out * x, None
            >>>
            >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
            >>> t = torch.tensor(1., dtype=torch.double)
            >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
            >>> c = 4
            >>>
            >>> with fwAD.dual_level():
            >>>     a_dual = fwAD.make_dual(a, t)
            >>>     d = Func.apply(a_dual, b, c)

        """
        for tensor in tensors:
            assert isinstance(tensor, torch.Tensor) or tensor is None, (
                "save_for_forward expects all arguments to be tensors; you should "
                "save non-tensors as attributes on ctx."
            )

        self.saved_for_forward = tensors

    def mark_dirty(self, *args: torch.Tensor):
        r"""Mark given tensors as modified in an in-place operation.

        This should be called at most once, in either the :func:`setup_context`
        or :func:`forward` methods, and all arguments should be inputs.

        Every tensor that's been modified in-place in a call to :func:`forward`
        should be given to this function, to ensure correctness of our checks.
        It doesn't matter whether the function is called before or after
        modification.

        Examples::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
            >>> class Inplace(Function):
            >>>     @staticmethod
            >>>     def forward(ctx, x):
            >>>         x_npy = x.numpy()  # x_npy shares storage with x
            >>>         x_npy += 1
            >>>         ctx.mark_dirty(x)
            >>>         return x
            >>>
            >>>     @staticmethod
            >>>     @once_differentiable
            >>>     def backward(ctx, grad_output):
            >>>         return grad_output
            >>>
            >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()
            >>> b = a * a
            >>> Inplace.apply(a)  # This would lead to wrong gradients!
            >>> # but the engine would not know unless we mark_dirty
            >>> # xdoctest: +SKIP
            >>> b.backward()  # RuntimeError: one of the variables needed for gradient
            >>> # computation has been modified by an inplace operation

        """
        self.dirty_tensors = args

    @deprecated(
        "`mark_shared_storage` is deprecated. "
        "Tensors with shared storages are automatically tracked. "
        "Note that calls to `set_()` are not tracked",
        category=FutureWarning,
    )
    def mark_shared_storage(self, *pairs):
        pass

    def mark_non_differentiable(self, *args: torch.Tensor):
        r"""Mark outputs as non-differentiable.

        This should be called at most once, in either the :func:`setup_context`
        or :func:`forward` methods, and all arguments should be tensor outputs.

        This will mark outputs as not requiring gradients, increasing the
        efficiency of backward computation. You still need to accept a gradient
        for each output in :meth:`~Function.backward`, but it's always going to
        be a zero tensor with the same shape as the corresponding output.

        This is used e.g. for indices returned from a sort. See example::
            >>> class Func(Function):
            >>>     @staticmethod
            >>>     def forward(ctx, x):
            >>>         sorted, idx = x.sort()
            >>>         ctx.mark_non_differentiable(idx)
            >>>         ctx.save_for_backward(x, idx)
            >>>         return sorted, idx
            >>>
            >>>     @staticmethod
            >>>     @once_differentiable
            >>>     def backward(ctx, g1, g2):  # still need to accept g2
            >>>         x, idx = ctx.saved_tensors
            >>>         grad_input = torch.zeros_like(x)
            >>>         grad_input.index_add_(0, idx, g1)
            >>>         return grad_input

        """
        self.non_differentiable = args

    def set_materialize_grads(self, value: bool):
        r"""Set whether to materialize grad tensors. Default is ``True``.

        This should be called only from either the :func:`setup_context` or
        :func:`forward` methods.

        If ``True``, undefined grad tensors will be expanded to tensors full of zeros
        prior to calling the :func:`backward` and :func:`jvp` methods.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
            >>> class SimpleFunc(Function):
            >>>     @staticmethod
            >>>     def forward(ctx, x):
            >>>         return x.clone(), x.clone()
            >>>
            >>>     @staticmethod
            >>>     @once_differentiable
            >>>     def backward(ctx, g1, g2):
            >>>         return g1 + g2  # No check for None necessary
            >>>
            >>> # We modify SimpleFunc to handle non-materialized grad outputs
            >>> class Func(Function):
            >>>     @staticmethod
            >>>     def forward(ctx, x):
            >>>         ctx.set_materialize_grads(False)
            >>>         ctx.save_for_backward(x)
            >>>         return x.clone(), x.clone()
            >>>
            >>>     @staticmethod
            >>>     @once_differentiable
            >>>     def backward(ctx, g1, g2):
            >>>         x, = ctx.saved_tensors
            >>>         grad_input = torch.zeros_like(x)
            >>>         if g1 is not None:  # We must check for None now
            >>>             grad_input += g1
            >>>         if g2 is not None:
            >>>             grad_input += g2
            >>>         return grad_input
            >>>
            >>> a = torch.tensor(1., requires_grad=True)
            >>> b, _ = Func.apply(a)  # induces g2 to be undefined

        """
        self.materialize_grads = value


# DO NOT USE: This is only defined to be able to load old serialized models
_ContextMethodMixin = FunctionCtx


class _HookMixin:
    @staticmethod
    def _register_hook(backward_hooks, hook):
        if backward_hooks is None:
            backward_hooks = OrderedDict()
        handle = hooks.RemovableHandle(backward_hooks)
        backward_hooks[handle.id] = hook
        return backward_hooks, handle


class BackwardCFunction(_C._FunctionBase, FunctionCtx, _HookMixin):
    r"""
    This class is used for internal autograd work. Do not use.
    """

    def apply(self, *args):
        r"""
        Apply method used when executing this Node during the backward
        """
        # _forward_cls is defined by derived class
        # The user should define either backward or vjp but never both.
        backward_fn = self._forward_cls.backward  # type: ignore[attr-defined]
        vjp_fn = self._forward_cls.vjp  # type: ignore[attr-defined]
        if backward_fn is not Function.backward and vjp_fn is not Function.vjp:
            raise RuntimeError(
                "Implementing both 'backward' and 'vjp' for a custom "
                "Function is not allowed. You should only implement one "
                "of them."
            )
        user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
        return user_fn(self, *args)

    def apply_jvp(self, *args):
        r"""
        Apply method used when executing forward mode AD during the forward
        """
        # _forward_cls is defined by derived class
        return self._forward_cls.jvp(self, *args)  # type: ignore[attr-defined]

    def _compiled_autograd_key(self):
        return self._forward_cls._compiled_autograd_key(self)  # type: ignore[attr-defined]


class FunctionMeta(type):
    """Function metaclass.

    This metaclass sets up the following properties:
        _backward_cls: The Function class corresponding to the differentiated
            version of this function (which is generated on the fly by this
            metaclass).
    """

    def __init__(cls, name, bases, attrs):
        backward_fn = type(
            name + "Backward", (BackwardCFunction,), {"_forward_cls": cls}
        )
        backward_fn._autograd_function_id = next(AUTOGRAD_FUNCTION_COUNTER)  # type: ignore[attr-defined]
        backward_fn._compiled_autograd_should_lift = attrs.get(  # type: ignore[attr-defined]
            "_compiled_autograd_should_lift", True
        )
        cls._backward_cls = backward_fn

        super().__init__(name, bases, attrs)


class _SingleLevelFunction(
    _C._FunctionBase, FunctionCtx, _HookMixin, metaclass=FunctionMeta
):
    @staticmethod
    def forward(*args: Any, **kwargs: Any) -> Any:
        r"""Define the forward of the custom autograd Function.

        This function is to be overridden by all subclasses.
        There are two ways to define forward:

        Usage 1 (Combined forward and ctx)::

            @staticmethod
            def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
                pass

        - It must accept a context ctx as the first argument, followed by any
          number of arguments (tensors or other types).
        - See :ref:`combining-forward-context` for more details

        Usage 2 (Separate forward and ctx)::

            @staticmethod
            def forward(*args: Any, **kwargs: Any) -> Any:
                pass

            @staticmethod
            def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
                pass

        - The forward no longer accepts a ctx argument.
        - Instead, you must also override the :meth:`torch.autograd.Function.setup_context`
          staticmethod to handle setting up the ``ctx`` object.
          ``output`` is the output of the forward, ``inputs`` are a Tuple of inputs
          to the forward.
        - See :ref:`extending-autograd` for more details

        The context can be used to store arbitrary data that can then be
        retrieved during the backward pass. Tensors should not be stored
        directly on ``ctx`` (though this is not currently enforced for
        backward compatibility). Instead, tensors should be saved either with
        :func:`ctx.save_for_backward` if they are intended to be used in
        ``backward`` (equivalently, ``vjp``) or :func:`ctx.save_for_forward`
        if they are intended to be used in ``jvp``.
        """
        raise NotImplementedError(
            "You must implement the forward function for custom autograd.Function."
        )

    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> Any:
        r"""There are two ways to define the forward pass of an autograd.Function.

        Either:

        1. Override forward with the signature ``forward(ctx, *args, **kwargs)``.
           ``setup_context`` is not overridden. Setting up the ctx for backward
           happens inside the ``forward``.
        2. Override forward with the signature ``forward(*args, **kwargs)`` and
           override ``setup_context``. Setting up the ctx for backward happens
           inside ``setup_context`` (as opposed to inside the ``forward``).

        See :meth:`torch.autograd.Function.forward` and :ref:`extending-autograd` for more details.
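
        Example (an illustrative sketch of the second pattern; ``Scale`` and its
        arguments are hypothetical, not part of this API)::

            >>> # xdoctest: +SKIP
            >>> class Scale(Function):
            >>>     @staticmethod
            >>>     def forward(x, alpha):
            >>>         # no ctx here; forward sees only the inputs
            >>>         return x * alpha
            >>>
            >>>     @staticmethod
            >>>     def setup_context(ctx, inputs, output):
            >>>         _, alpha = inputs
            >>>         ctx.alpha = alpha  # non-tensor, stored as an attribute
            >>>
            >>>     @staticmethod
            >>>     def backward(ctx, grad_out):
            >>>         return grad_out * ctx.alpha, None
            >>>
            >>> out = Scale.apply(torch.randn(3, requires_grad=True), 2.0)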
        """
        raise NotImplementedError("setup_context is not implemented.")

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        r"""Define a formula for differentiating the operation with backward mode automatic differentiation.

        This function is to be overridden by all subclasses.
        (Defining this function is equivalent to defining the ``vjp`` function.)

        It must accept a context :attr:`ctx` as the first argument, followed by
        as many outputs as the :func:`forward` returned (None will be passed in
        for non-tensor outputs of the forward function),
        and it should return as many tensors as there were inputs to
        :func:`forward`. Each argument is the gradient w.r.t. the given output,
        and each returned value should be the gradient w.r.t. the
        corresponding input. If an input is not a Tensor or is a Tensor not
        requiring grads, you can just pass None as a gradient for that input.

        The context can be used to retrieve tensors saved during the forward
        pass. It also has an attribute :attr:`ctx.needs_input_grad` as a tuple
        of booleans representing whether each input needs gradient. E.g.,
        :func:`backward` will have ``ctx.needs_input_grad[0] = True`` if the
        first input to :func:`forward` needs gradient computed w.r.t. the
        output.
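
        Example (illustrative sketch; ``Mul`` is a hypothetical Function, not
        part of this module)::

            >>> # xdoctest: +SKIP
            >>> class Mul(Function):
            >>>     @staticmethod
            >>>     def forward(ctx, x, y):
            >>>         ctx.save_for_backward(x, y)
            >>>         return x * y
            >>>
            >>>     @staticmethod
            >>>     def backward(ctx, grad_out):
            >>>         x, y = ctx.saved_tensors
            >>>         # Return one gradient per input to forward; skip work for
            >>>         # inputs that do not need a gradient.
            >>>         grad_x = grad_out * y if ctx.needs_input_grad[0] else None
            >>>         grad_y = grad_out * x if ctx.needs_input_grad[1] else None
            >>>         return grad_x, grad_y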
        """
        raise NotImplementedError(
            "You must implement either the backward or vjp method for "
            "your custom autograd.Function to use it with backward "
            "mode AD."
        )

    # vjp and backward are alias of each other
    vjp = backward

    @staticmethod
    def jvp(ctx: Any, *grad_inputs: Any) -> Any:
        r"""Define a formula for differentiating the operation with forward mode automatic differentiation.

        This function is to be overridden by all subclasses.
        It must accept a context :attr:`ctx` as the first argument, followed by
        as many inputs as the :func:`forward` got (None will be passed in
        for non-tensor inputs of the forward function),
        and it should return as many tensors as there were outputs to
        :func:`forward`. Each argument is the gradient w.r.t. the given input,
        and each returned value should be the gradient w.r.t. the
        corresponding output. If an output is not a Tensor or the function is not
        differentiable with respect to that output, you can just pass None as a
        gradient for that output.

        You can use the :attr:`ctx` object to pass any value from the forward to this
        function.
        """
        raise NotImplementedError(
            "You must implement the jvp function for custom "
            "autograd.Function to use it with forward mode AD."
        )


class Function(_SingleLevelFunction):
    r"""Base class to create custom `autograd.Function`.

    To create a custom `autograd.Function`, subclass this class and implement
    the :meth:`forward` and :meth:`backward` static methods. Then, to use your custom
    op in the forward pass, call the class method ``apply``. Do not call
    :meth:`forward` directly.

    To ensure correctness and best performance, make sure you are calling the
    correct methods on ``ctx`` and validating your backward function using
    :func:`torch.autograd.gradcheck`.

    See :ref:`extending-autograd` for more details on how to use this class.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> class Exp(Function):
        >>>     @staticmethod
        >>>     def forward(ctx, i):
        >>>         result = i.exp()
        >>>         ctx.save_for_backward(result)
        >>>         return result
        >>>
        >>>     @staticmethod
        >>>     def backward(ctx, grad_output):
        >>>         result, = ctx.saved_tensors
        >>>         return grad_output * result
        >>>
        >>> # Use it by calling the apply method:
        >>> # xdoctest: +SKIP
        >>> output = Exp.apply(input)
    """

    def __init__(self, *args, **kwargs):
        warnings.warn(
            f"{self.__class__} should not be instantiated. Methods on autograd functions "
            "are all static, so you should invoke them on the class itself. "
            "Instantiating an autograd function will raise an "
            "error in a future version of PyTorch.",
            DeprecationWarning,
            stacklevel=2,
        )

    def __call__(self, *args, **kwargs):
        raise RuntimeError(
            "Legacy autograd function with non-static forward method is deprecated. "
            "Please use new-style autograd function with static forward method. "
            "(Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)"
        )

    """
    Bool that specifies if PyTorch should attempt to autogenerate
    :func:`torch.vmap` support for this autograd.Function. You may set this to
    True only if this autograd.Function's forward, backward, and jvp (if they
    exist) are written using PyTorch operations; otherwise, please override
    :meth:`torch.autograd.Function.vmap` to add support for :func:`torch.vmap`.

    Please see :ref:`func-autograd-function` for more details.
    """
    generate_vmap_rule = False

    @staticmethod
    def vmap(info, in_dims, *args):
        r"""Define the behavior for this autograd.Function underneath :func:`torch.vmap`.

        For a :func:`torch.autograd.Function` to support
        :func:`torch.vmap`, you must either override this static method, or set
        ``generate_vmap_rule`` to ``True`` (you may not do both).

        If you choose to override this staticmethod: it must accept

        - an ``info`` object as the first argument. ``info.batch_size``
          specifies the size of the dimension being vmapped over,
          while ``info.randomness`` is the randomness option passed to
          :func:`torch.vmap`.
        - an ``in_dims`` tuple as the second argument.
          For each arg in ``args``, ``in_dims`` has a corresponding
          ``Optional[int]``. It is ``None`` if the arg is not a Tensor or if
          the arg is not being vmapped over, otherwise, it is an integer
          specifying what dimension of the Tensor is being vmapped over.
        - ``*args``, which is the same as the args to :meth:`~Function.forward`.

        The return of the vmap staticmethod is a tuple of ``(output, out_dims)``.
        Similar to ``in_dims``, ``out_dims`` should be of the same structure as
        ``output`` and contain one ``out_dim`` per output that specifies if the
        output has the vmapped dimension and what index it is in.

        Please see :ref:`func-autograd-function` for more details.
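
        Example (an illustrative sketch; ``MyExp`` is a hypothetical, pointwise
        Function, so the batched dimension can simply be propagated)::

            >>> # xdoctest: +SKIP
            >>> class MyExp(Function):
            >>>     @staticmethod
            >>>     def forward(x):
            >>>         return x.exp()
            >>>
            >>>     @staticmethod
            >>>     def setup_context(ctx, inputs, output):
            >>>         ctx.save_for_backward(output)
            >>>
            >>>     @staticmethod
            >>>     def backward(ctx, grad_out):
            >>>         (result,) = ctx.saved_tensors
            >>>         return grad_out * result
            >>>
            >>>     @staticmethod
            >>>     def vmap(info, in_dims, x):
            >>>         # exp is elementwise, so re-apply the Function on the
            >>>         # batched tensor and report where the batch dim lives.
            >>>         (x_bdim,) = in_dims
            >>>         return MyExp.apply(x), x_bdim
            >>>
            >>> x = torch.randn(3, 4)
            >>> out = torch.vmap(MyExp.apply)(x)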
        """
        raise NotImplementedError(
            "To use autograd.Function with vmap, you must either override the "
            "vmap staticmethod or set generate_vmap_rule=True."
        )

    @classmethod
    def apply(cls, *args, **kwargs):
        def bind_default_args(func, *args, **kwargs):
            signature = inspect.signature(func)
            bound_args = signature.bind(*args, **kwargs)
            bound_args.apply_defaults()

            return bound_args.args

        is_setup_ctx_defined = _is_setup_context_defined(cls.setup_context)
        if is_setup_ctx_defined:
            args = bind_default_args(cls.forward, *args, **kwargs)

        if not torch._C._are_functorch_transforms_active():
            # See NOTE: [functorch vjp and autograd interaction]
            args = _functorch.utils.unwrap_dead_wrappers(args)
            return super().apply(*args, **kwargs)  # type: ignore[misc]

        if not is_setup_ctx_defined:
            raise RuntimeError(
                "In order to use an autograd.Function with functorch transforms "
                "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
                "staticmethod. For more details, please see "
                "https://pytorch.org/docs/main/notes/extending.func.html"
            )

        return custom_function_call(cls, *args, **kwargs)

    @staticmethod
    def _compiled_autograd_key(ctx):
        return (ctx._autograd_function_id,)


def _is_setup_context_defined(fn):
    return fn != _SingleLevelFunction.setup_context


def once_differentiable(fn):
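    r"""Decorate a :meth:`~Function.backward` (or ``vjp``) that must not be differentiated again.

    The wrapped method is run under :func:`torch.no_grad`, and its outputs are
    attached to a node that raises an error if a second backward pass tries to
    differentiate through them.

    Example (illustrative sketch; ``Double`` is a hypothetical Function)::

        >>> # xdoctest: +SKIP
        >>> class Double(Function):
        >>>     @staticmethod
        >>>     def forward(ctx, x):
        >>>         return x * 2
        >>>
        >>>     @staticmethod
        >>>     @once_differentiable  # note: applied beneath @staticmethod
        >>>     def backward(ctx, grad_out):
        >>>         return grad_out * 2
    """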
    @functools.wraps(fn)
    def wrapper(ctx, *args):
        with torch.no_grad():
            outputs = fn(ctx, *args)

        if not torch.is_grad_enabled():
            return outputs

        # If any of the inputs have requires_grad=True, we force the outputs
        # to have requires_grad=True but point to a grad_fn which throws an
        # error message during (double) back-propagation.
        # XXX: this is only an approximation of requires_grad - there's no way
        # to figure out if fn didn't use ctx.saved_tensors and as a result
        # some Tensors might require grad, even if no args do.
        # Unfortunately, this leads to unexpected error messages ("no nodes
        # require computing gradients"), but I don't have a better idea.
        # These functions would raise an error in backward anyway.
        requires_grad = any(
            isinstance(arg, torch.Tensor) and arg.requires_grad for arg in args
        )
        if not requires_grad:
            return outputs

        if not isinstance(outputs, tuple):
            outputs = (outputs,)

        err_fn = _functions.DelayedError(
            b"trying to differentiate twice a function that was marked "
            b"with @once_differentiable",
            len(outputs),
        )

        # Create aliases of each output that has requires_grad=True. We need
        # at least one of the inputs to err_fn to require grad so that the
        # output will have a grad_fn.
        def fake_requires_grad(var):
            if var is not None:
                var = var.detach()
                var.requires_grad = True
            return var

        return err_fn(*[fake_requires_grad(v) for v in outputs])

    return wrapper


class InplaceFunction(Function):
    r"""
    This class is here only for backward compatibility reasons.
    Use :class:`Function` instead of this for any new use case.
    """

    def __init__(self, inplace=False):
        super().__init__()
        self.inplace = inplace


def _nested_map(condition, fn, condition_msg=None):
    def _map(obj):
        if condition(obj):
            return fn(obj)
        elif obj is None:
            return None
        elif isinstance(obj, (list, tuple)):
            mapped = (_map(x) for x in obj)
            if hasattr(obj, "_fields"):
                # obj is namedtuple
                return type(obj)(*mapped)
            return type(obj)(mapped)
        elif isinstance(obj, dict):
            return {x: _map(obj[x]) for x in obj}
        else:
            raise ValueError(
                "Auto nesting doesn't know how to process "
                "an input object of type "
                + torch.typename(obj)
                + (
                    ". Accepted types: " + condition_msg + ", or lists/tuples of them"
                    if condition_msg
                    else ""
                )
            )

    return _map


def _jit_unwrap_structured(obj):
    if hasattr(obj, "_jit_unwrap"):
        return obj._jit_unwrap()
    return obj


def _iter_filter(condition, allow_unknown=False, condition_msg=None, conversion=None):
    def _iter(obj):
        if conversion is not None:
            obj = conversion(obj)
        if condition(obj):
            yield obj
        elif obj is None:
            return
        elif isinstance(obj, (list, tuple)):
            for o in obj:
                yield from _iter(o)
        elif isinstance(obj, dict):
            # We only accept primitive key types, so we needn't inspect them
            for o in obj.values():
                yield from _iter(o)
        elif allow_unknown:
            yield obj
        else:
            raise ValueError(
                "Auto nesting doesn't know how to process "
                "an input object of type "
                + torch.typename(obj)
                + (
                    ". Accepted types: " + condition_msg + ", or lists/tuples of them"
                    if condition_msg
                    else ""
                )
            )

    return _iter


def _unflatten(input, proto):
    # unflatten a list or tuple input into a nested list/tuple structure
    # specified by proto
    def unflatten_helper(input, proto):
        res: List[Optional[torch.Tensor]] = []
        if hasattr(proto, "_jit_wrap"):
            return proto._jit_wrap(input)
        if not isinstance(proto, (list, tuple)):
            return input[0], input[1:]
        for e in proto:
            if e is None:
                res.append(e)
            else:
                res_e, input = unflatten_helper(input, e)
                res.append(res_e)
        return type(proto)(res), input

    return unflatten_helper(input, proto)[0]


_iter_jit_values = _iter_filter(
    lambda o: o is None or isinstance(o, torch._C.Value),
    condition_msg="jit's Values or None",
)
_iter_tensors = _iter_filter(
    lambda x: isinstance(x, torch.Tensor),
    condition_msg="Tensors",
    conversion=_jit_unwrap_structured,
)
_iter_tensors_permissive = _iter_filter(
    lambda x: isinstance(x, torch.Tensor),
    allow_unknown=True,
    condition_msg="Tensors (permissive)",
)
_iter_None_tensors = _iter_filter(
    lambda o: o is None or isinstance(o, torch.Tensor), condition_msg="Tensors or None"
)
_map_tensor_data = _nested_map(
    lambda x: isinstance(x, torch.Tensor), lambda o: o.data, condition_msg="Tensors"
)


class NestedIOFunction(Function):
    r"""
    This class is here only for backward compatibility reasons.
    Use :class:`Function` instead of this for any new use case.
    """
    # The 'type: ignore' statements are needed here because these functions are declared as '@staticmethod' in the
    # superclass (Function) but are instance methods here, which mypy reports as incompatible.

    def _do_forward(self, *input):
        self._nested_input = input
        flat_input = tuple(_iter_tensors(input))
        flat_output = super()._do_forward(*flat_input)  # type: ignore[misc]
        nested_output = self._nested_output
        nested_tensors = _unflatten(flat_output, self._nested_output)
        return nested_tensors

    def _do_backward(self, gradients, retain_variables):
        self.retain_variables = retain_variables
        result = super()._do_backward(gradients, retain_variables)  # type: ignore[misc]
        if not retain_variables:
            del self._nested_output
            del self._to_save_nested
        return result

    def backward(self, *gradients: Any) -> Any:  # type: ignore[override]
        r"""
        Shared backward utility.
        """
        nested_gradients = _unflatten(gradients, self._nested_output)
        result = self.backward_extended(*nested_gradients)  # type: ignore[func-returns-value]
        return tuple(_iter_None_tensors(result))

    __call__ = _do_forward

    def forward(self, *args: Any) -> Any:  # type: ignore[override]
        r"""
        Shared forward utility.
        """
        nested_tensors = _map_tensor_data(self._nested_input)
        result = self.forward_extended(*nested_tensors)  # type: ignore[func-returns-value]
        del self._nested_input
        self._nested_output = result
        return tuple(_iter_tensors(result))

    def save_for_backward(self, *args: Any) -> None:
        r"""
        See :meth:`Function.save_for_backward`.
        """
        self.to_save = tuple(_iter_tensors(args))
        self._to_save_nested = args

    @property
    def saved_tensors(self):
        r"""
        See :meth:`Function.saved_tensors`.
        """
        flat_tensors = super().saved_tensors  # type: ignore[misc]
        return _unflatten(flat_tensors, self._to_save_nested)

    def mark_dirty(self, *args: Any, **kwargs: Any) -> None:
        r"""
        See :meth:`Function.mark_dirty`.
        """
        self.dirty_tensors = tuple(_iter_tensors((args, kwargs)))

    def mark_non_differentiable(self, *args: Any, **kwargs: Any) -> None:
        r"""
        See :meth:`Function.mark_non_differentiable`.
        """
        self.non_differentiable = tuple(_iter_tensors((args, kwargs)))

    def forward_extended(self, *input: Any) -> None:
        r"""
        User defined forward.
        """
        raise NotImplementedError

    def backward_extended(self, *grad_output: Any) -> None:
        r"""
        User defined backward.
        """
        raise NotImplementedError