1# mypy: ignore-errors 2 3# Copyright (c) Facebook, Inc. and its affiliates. 4# All rights reserved. 5# 6# This source code is licensed under the BSD-style license found in the 7# LICENSE file in the root directory of this source tree. 8 9import contextlib 10from functools import partial, wraps 11from typing import Any, Callable, List, Optional, Tuple, Union 12 13import torch 14import torch.autograd.forward_ad as fwAD 15from torch._C._functorch import ( 16 _assert_wrapped_functional, 17 _func_decrement_nesting, 18 _func_increment_nesting, 19 _grad_decrement_nesting, 20 _grad_increment_nesting, 21 _jvp_decrement_nesting, 22 _jvp_increment_nesting, 23 _propagate_functional_input_mutation, 24 _unwrap_for_grad, 25 _unwrap_functional_tensor, 26 _wrap_for_grad, 27 _wrap_functional_tensor, 28 get_inplace_requires_grad_allowed, 29 set_inplace_requires_grad_allowed, 30) 31from torch._functorch.utils import argnums_t, exposed_in 32from torch._subclasses.functional_tensor import FunctionalTensor 33from torch.fx.experimental import const_fold 34from torch.fx.experimental.proxy_tensor import make_fx 35from torch.utils import _pytree as pytree 36from torch.utils._pytree import ( 37 tree_flatten, 38 tree_map, 39 tree_map_, 40 tree_map_only, 41 tree_unflatten, 42 treespec_pprint, 43) 44 45from .apis import vmap 46from .vmap import doesnt_support_saved_tensors_hooks, get_chunk_sizes 47 48 49def lazy_dynamo_disallow(func): 50 import torch._dynamo 51 52 return torch._dynamo.disallow_in_graph(func) 53 54 55@contextlib.contextmanager 56def enable_inplace_requires_grad(enabled): 57 prev_state = get_inplace_requires_grad_allowed() 58 set_inplace_requires_grad_allowed(enabled) 59 try: 60 yield 61 finally: 62 set_inplace_requires_grad_allowed(prev_state) 63 64 65def _vjp_treespec_compare(primals_out, cotangents): 66 # Revert this once #116264 gets fixed 67 _, primals_out_spec = tree_flatten(primals_out) 68 _, cotangents_spec = tree_flatten(cotangents) 69 # Dynamo fails to trace operator.ne below. To bypass this limitation, this 70 # function is not inlined. 71 if primals_out_spec != cotangents_spec: 72 raise RuntimeError( 73 f"Expected pytree structure of cotangents to be the same " 74 f"as pytree structure of outputs to the function. " 75 f"cotangents: {treespec_pprint(cotangents_spec)}, " 76 f"primal output: {treespec_pprint(primals_out_spec)}" 77 ) 78 79 80def _jvp_treespec_compare(primals, tangents): 81 # Revert this once #116264 gets fixed 82 _, primals_spec = tree_flatten(primals) 83 _, tangents_spec = tree_flatten(tangents) 84 if primals_spec != tangents_spec: 85 raise RuntimeError( 86 f"{jvp_str}: Expected primals and tangents to have the same python " 87 f"structure. For example, if primals is a tuple of 3 tensors, " 88 f"tangents also must be. 
Got primals with structure {primals_spec} "
            f"and tangents with structure {tangents_spec}"
        )


def _linearize_treespec_compare(primals, tangents):
    # Revert this once #116264 gets fixed
    _, primals_argspec = tree_flatten(primals)
    _, tangent_argspec = tree_flatten(tangents)
    if tangent_argspec != primals_argspec:
        raise RuntimeError(
            f"Expected the tangents {tangent_argspec} to have "
            f"the same argspec as the primals {primals_argspec}"
        )


def _set_tensor_requires_grad(x):
    # avoid graph-break on x.requires_grad_()
    # https://github.com/pytorch/pytorch/pull/110053
    return x.requires_grad_()


def _create_differentiable(inps, level=None):
    def create_differentiable(x):
        if isinstance(x, torch.Tensor):
            with enable_inplace_requires_grad(True):
                return _set_tensor_requires_grad(x)
        raise ValueError(
            f"Thing passed to transform API must be Tensor, got {type(x)}"
        )

    return tree_map(create_differentiable, inps)


def _undo_create_differentiable(inps, level=None):
    def unwrap_tensors(x):
        if isinstance(x, torch.Tensor):
            return _unwrap_for_grad(x, level)
        # TODO: Remove the following hack for namedtuples
        if isinstance(x, tuple):
            return tree_map(unwrap_tensors, tuple(x))

        raise RuntimeError(f"Expected tensors, got unsupported type {type(x)}")

    return tree_map(unwrap_tensors, inps)


def _is_differentiable(maybe_tensor):
    if not isinstance(maybe_tensor, torch.Tensor):
        return False
    return maybe_tensor.requires_grad


def _any_differentiable(tensor_or_tuple_of_tensors):
    flat_args, _ = tree_flatten(tensor_or_tuple_of_tensors)
    return any(tuple(map(_is_differentiable, flat_args)))


def _wrap_tensor_for_grad(maybe_tensor, level):
    if not isinstance(maybe_tensor, torch.Tensor):
        return maybe_tensor
    return _wrap_for_grad(maybe_tensor, level)


def _wrap_all_tensors(tensor_pytree, level):
    return tree_map(partial(_wrap_tensor_for_grad, level=level), tensor_pytree)


def _as_tuple(val):
    if isinstance(val, tuple):
        return val
    return (val,)


# Version of autograd.grad that handles outputs that don't depend on inputs


def _autograd_grad(
    outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True
):
    if grad_outputs is None:
        diff_outputs = tuple(out for out in outputs if out.requires_grad)
    else:
        result = tuple(
            (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad
        )
        if len(result) == 0:
            diff_outputs, grad_outputs = (), ()
        else:
            diff_outputs, grad_outputs = zip(*result)
    if len(diff_outputs) == 0:
        return tuple(torch.zeros_like(inp) for inp in inputs)
    grad_inputs = torch.autograd.grad(
        diff_outputs,
        inputs,
        grad_outputs,
        retain_graph=retain_graph,
        create_graph=create_graph,
        allow_unused=True,
    )
    grad_inputs = tuple(
        torch.zeros_like(inp) if gi is None else gi
        for gi, inp in zip(grad_inputs, inputs)
    )
    return grad_inputs


# NOTE [grad and vjp interaction with no_grad]
#
# def f(x):
#     with torch.no_grad():
#         c = x ** 2
#     return x - c
#
# The thing to consider is if enable_grad is on/off before grad gets called.
#
# Case 1: enable_grad is on.
# grad(f)(x)
# In this case, `grad` should respect the inner torch.no_grad.
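# (Illustrative sketch, not part of the original note: with f as above, c is
# computed under no_grad and is therefore treated as a constant, so the
# gradient is 1 everywhere.)
#
# >>> x = torch.tensor(2.0)
# >>> torch.func.grad(f)(x)   # tensor(1.), since d/dx (x - const) == 1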
207# 208# Case 2: enable_grad is off 209# with torch.no_grad(): 210# grad(f)(x) 211# In this case, `grad` should respect the inner torch.no_grad, but not the 212# outer one. This is because `grad` is a "function transform": its result 213# should not depend on the result of a context manager outside of `f`. 214# 215# This gives us the following desired behavior: 216# - (nested) grad transforms must obey torch.no_grad inside them 217# - (nested) grad transforms should not obey torch.no_grad outside them 218# 219# To achieve this behavior, upon entering grad/vjp: 220# - we save the current ("previous") is_grad_enabled (*) 221# - we unconditionally enable grad. 222# 223# Inside DynamicLayerBackFallback, when we're temporarily popping `grad` layer 224# off the stack: 225# - if grad_mode is disabled, then we do nothing. (there is a torch.no_grad 226# active, all subsequent grad transforms must obey it). 227# - if grad_mode is enabled, and the previous is_grad_enabled (*) is False, 228# then we temporarily restore the previous `is_grad_enabled`. This is 229# because we're crossing the boundary from a `grad` outside the 230# no_grad to a `grad` inside the no_grad. 231# 232# NB: vjp has some interesting behavior because the vjp's callable can be called 233# under a different grad_mode than the forward computation... 234# 235# NB: forward-mode AD: forward-mode AD doesn't respect torch.no_grad, but 236# it respects c10::AutoFwGradMode. We've implemented the same logic for 237# our jvp transform (it will have special handling if FwGradMode is disabled). 238 239 240# How do we increment and decrement the nesting? I don't think we can. 241@exposed_in("torch.func") 242def vjp(func: Callable, *primals, has_aux: bool = False): 243 """ 244 Standing for the vector-Jacobian product, returns a tuple containing the 245 results of ``func`` applied to ``primals`` and a function that, when 246 given ``cotangents``, computes the reverse-mode Jacobian of ``func`` with 247 respect to ``primals`` times ``cotangents``. 248 249 Args: 250 func (Callable): A Python function that takes one or more arguments. Must 251 return one or more Tensors. 252 primals (Tensors): Positional arguments to ``func`` that must all be 253 Tensors. The returned function will also be computing the 254 derivative with respect to these arguments 255 has_aux (bool): Flag indicating that ``func`` returns a 256 ``(output, aux)`` tuple where the first element is the output of 257 the function to be differentiated and the second element is 258 other auxiliary objects that will not be differentiated. 259 Default: False. 260 261 Returns: 262 Returns a ``(output, vjp_fn)`` tuple containing the output of ``func`` 263 applied to ``primals`` and a function that computes the vjp of 264 ``func`` with respect to all ``primals`` using the cotangents passed 265 to the returned function. If ``has_aux is True``, then instead returns a 266 ``(output, vjp_fn, aux)`` tuple. 267 The returned ``vjp_fn`` function will return a tuple of each VJP. 
268 269 When used in simple cases, :func:`vjp` behaves the same as :func:`grad` 270 271 >>> x = torch.randn([5]) 272 >>> f = lambda x: x.sin().sum() 273 >>> (_, vjpfunc) = torch.func.vjp(f, x) 274 >>> grad = vjpfunc(torch.tensor(1.))[0] 275 >>> assert torch.allclose(grad, torch.func.grad(f)(x)) 276 277 However, :func:`vjp` can support functions with multiple outputs by 278 passing in the cotangents for each of the outputs 279 280 >>> x = torch.randn([5]) 281 >>> f = lambda x: (x.sin(), x.cos()) 282 >>> (_, vjpfunc) = torch.func.vjp(f, x) 283 >>> vjps = vjpfunc((torch.ones([5]), torch.ones([5]))) 284 >>> assert torch.allclose(vjps[0], x.cos() + -x.sin()) 285 286 :func:`vjp` can even support outputs being Python structs 287 288 >>> x = torch.randn([5]) 289 >>> f = lambda x: {'first': x.sin(), 'second': x.cos()} 290 >>> (_, vjpfunc) = torch.func.vjp(f, x) 291 >>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])} 292 >>> vjps = vjpfunc(cotangents) 293 >>> assert torch.allclose(vjps[0], x.cos() + -x.sin()) 294 295 The function returned by :func:`vjp` will compute the partials with 296 respect to each of the ``primals`` 297 298 >>> x, y = torch.randn([5, 4]), torch.randn([4, 5]) 299 >>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y) 300 >>> cotangents = torch.randn([5, 5]) 301 >>> vjps = vjpfunc(cotangents) 302 >>> assert len(vjps) == 2 303 >>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1))) 304 >>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents)) 305 306 ``primals`` are the positional arguments for ``f``. All kwargs use their 307 default value 308 309 >>> x = torch.randn([5]) 310 >>> def f(x, scale=4.): 311 >>> return x * scale 312 >>> 313 >>> (_, vjpfunc) = torch.func.vjp(f, x) 314 >>> vjps = vjpfunc(torch.ones_like(x)) 315 >>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.)) 316 317 .. note:: 318 Using PyTorch ``torch.no_grad`` together with ``vjp``. 319 Case 1: Using ``torch.no_grad`` inside a function: 320 321 >>> def f(x): 322 >>> with torch.no_grad(): 323 >>> c = x ** 2 324 >>> return x - c 325 326 In this case, ``vjp(f)(x)`` will respect the inner ``torch.no_grad``. 327 328 Case 2: Using ``vjp`` inside ``torch.no_grad`` context manager: 329 330 >>> # xdoctest: +SKIP(failing) 331 >>> with torch.no_grad(): 332 >>> vjp(f)(x) 333 334 In this case, ``vjp`` will respect the inner ``torch.no_grad``, but not the 335 outer one. This is because ``vjp`` is a "function transform": its result 336 should not depend on the result of a context manager outside of ``f``. 
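
    The returned ``vjp_fn`` also composes with :func:`vmap`. A short
    illustrative sketch that batches one ``vjp_fn`` over several cotangents
    (this mirrors what :func:`jacrev` does internally):

    >>> x = torch.randn([5])
    >>> (_, vjpfunc) = torch.func.vjp(torch.sin, x)
    >>> cotangents = torch.randn([3, 5])
    >>> batched_vjps = torch.func.vmap(vjpfunc)(cotangents)
    >>> assert batched_vjps[0].shape == (3, 5)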
337 """ 338 return _vjp_with_argnums(func, *primals, has_aux=has_aux) 339 340 341@contextlib.contextmanager 342def grad_increment_nesting(): 343 try: 344 grad_level = _grad_increment_nesting() 345 yield grad_level 346 finally: 347 _grad_decrement_nesting() 348 349 350def enter_jvp_nesting(): 351 global JVP_NESTING 352 jvp_level = _jvp_increment_nesting() 353 JVP_NESTING += 1 354 return jvp_level 355 356 357def exit_jvp_nesting(): 358 global JVP_NESTING 359 _jvp_decrement_nesting() 360 JVP_NESTING -= 1 361 362 363@contextlib.contextmanager 364def jvp_increment_nesting(): 365 try: 366 yield enter_jvp_nesting() 367 finally: 368 exit_jvp_nesting() 369 370 371@doesnt_support_saved_tensors_hooks 372def _vjp_with_argnums( 373 func: Callable, *primals, argnums: Optional[argnums_t] = None, has_aux: bool = False 374): 375 # This is the same function as vjp but also accepts an argnums argument 376 # All args are the same as vjp except for the added argument 377 # argnums (Optional[int or tuple[int]]): Optional, specifies the argument(s) to compute gradients with respect to. 378 # If None, computes the gradients with respect to all inputs (used for vjp). Default: None 379 # 380 # WARN: Users should NOT call this function directly and should just be calling vjp. 381 # It is only separated so that inputs passed to jacrev but not differentiated get the correct wrappers. 382 # 383 # NOTE: All error messages are produced as if vjp was being called, even if this was called by jacrev 384 # 385 # Returns the same two elements as :func:`vjp` but the function returned, vjp_fn, returns a tuple of VJPs 386 # for only the primal elements given by argnums. 387 with grad_increment_nesting() as level: 388 # See NOTE [grad and vjp interaction with no_grad] 389 with torch.enable_grad(): 390 primals = _wrap_all_tensors(primals, level) 391 # Note for the reviewer: This is extremely odd but it passes the 392 # assertion "len(self.block_stack) == 1" on symbolic_convert.py 393 # The equivalent "if argnums is None" fails for some reason 394 if not isinstance(argnums, int) and not argnums: 395 diff_primals = _create_differentiable(primals, level) 396 else: 397 diff_primals = _slice_argnums(primals, argnums, as_tuple=False) 398 tree_map_(partial(_create_differentiable, level=level), diff_primals) 399 primals_out = func(*primals) 400 401 if has_aux: 402 if not (isinstance(primals_out, tuple) and len(primals_out) == 2): 403 raise RuntimeError( 404 "vjp(f, *primals): output of function f should be a tuple: (output, aux) " 405 "if has_aux is True" 406 ) 407 primals_out, aux = primals_out 408 aux = _undo_create_differentiable(aux, level) 409 410 flat_primals_out, primals_out_spec = tree_flatten(primals_out) 411 assert_non_empty_tensor_output(flat_primals_out, "vjp(f, *primals)") 412 flat_diff_primals, primals_spec = tree_flatten(diff_primals) 413 results = _undo_create_differentiable(primals_out, level) 414 415 for primal_out in flat_primals_out: 416 assert isinstance(primal_out, torch.Tensor) 417 if primal_out.is_floating_point() or primal_out.is_complex(): 418 continue 419 raise RuntimeError( 420 "vjp(f, ...): All outputs of f must be " 421 "floating-point or complex Tensors, got Tensor " 422 f"with dtype {primal_out.dtype}" 423 ) 424 425 def wrapper(cotangents, retain_graph=True, create_graph=None): 426 if create_graph is None: 427 create_graph = torch.is_grad_enabled() 428 flat_cotangents, cotangents_spec = tree_flatten(cotangents) 429 _vjp_treespec_compare(primals_out, cotangents) 430 result = _autograd_grad( 431 
flat_primals_out, 432 flat_diff_primals, 433 flat_cotangents, 434 retain_graph=retain_graph, 435 create_graph=create_graph, 436 ) 437 return tree_unflatten(result, primals_spec) 438 439 if has_aux: 440 return results, wrapper, aux 441 else: 442 return results, wrapper 443 444 445def _safe_zero_index(x): 446 assert len(x) == 1 447 return x[0] 448 449 450# jacrev and jacfwd don't support complex functions 451# Helper function to throw appropriate error. 452def error_if_complex(func_name, args, is_input): 453 flat_args = pytree.tree_leaves(args) 454 for idx, arg in enumerate(flat_args): 455 if isinstance(arg, torch.Tensor) and arg.dtype.is_complex: 456 input_or_output = "inputs" if is_input else "outputs" 457 err_msg = ( 458 f"{func_name}: Expected all {input_or_output} " 459 f"to be real but received complex tensor at flattened input idx: {idx}" 460 ) 461 raise RuntimeError(err_msg) 462 463 464@exposed_in("torch.func") 465def jacrev( 466 func: Callable, 467 argnums: Union[int, Tuple[int]] = 0, 468 *, 469 has_aux=False, 470 chunk_size: Optional[int] = None, 471 _preallocate_and_copy=False, 472): 473 """ 474 Computes the Jacobian of ``func`` with respect to the arg(s) at index 475 ``argnum`` using reverse mode autodiff 476 477 .. note:: 478 Using :attr:`chunk_size=1` is equivalent to computing the jacobian 479 row-by-row with a for-loop i.e. the constraints of :func:`vmap` are 480 not applicable. 481 482 Args: 483 func (function): A Python function that takes one or more arguments, 484 one of which must be a Tensor, and returns one or more Tensors 485 argnums (int or Tuple[int]): Optional, integer or tuple of integers, 486 saying which arguments to get the Jacobian with respect to. 487 Default: 0. 488 has_aux (bool): Flag indicating that ``func`` returns a 489 ``(output, aux)`` tuple where the first element is the output of 490 the function to be differentiated and the second element is 491 auxiliary objects that will not be differentiated. 492 Default: False. 493 chunk_size (None or int): If None (default), use the maximum chunk size 494 (equivalent to doing a single vmap over vjp to compute the jacobian). 495 If 1, then compute the jacobian row-by-row with a for-loop. 496 If not None, then compute the jacobian :attr:`chunk_size` rows at a time 497 (equivalent to doing multiple vmap over vjp). If you run into memory issues computing 498 the jacobian, please try to specify a non-None chunk_size. 499 500 Returns: 501 Returns a function that takes in the same inputs as ``func`` and 502 returns the Jacobian of ``func`` with respect to the arg(s) at 503 ``argnums``. If ``has_aux is True``, then the returned function 504 instead returns a ``(jacobian, aux)`` tuple where ``jacobian`` 505 is the Jacobian and ``aux`` is auxiliary objects returned by ``func``. 

    A basic usage with a pointwise, unary operation will give a diagonal array
    as the Jacobian

    >>> from torch.func import jacrev
    >>> x = torch.randn(5)
    >>> jacobian = jacrev(torch.sin)(x)
    >>> expected = torch.diag(torch.cos(x))
    >>> assert torch.allclose(jacobian, expected)

    If you would like to compute the output of the function as well as the
    jacobian of the function, use the ``has_aux`` flag to return the output
    as an auxiliary object:

    >>> from torch.func import jacrev
    >>> x = torch.randn(5)
    >>>
    >>> def f(x):
    >>>     return x.sin()
    >>>
    >>> def g(x):
    >>>     result = f(x)
    >>>     return result, result
    >>>
    >>> jacobian_f, f_x = jacrev(g, has_aux=True)(x)
    >>> assert torch.allclose(f_x, f(x))

    :func:`jacrev` can be composed with vmap to produce batched
    Jacobians:

    >>> from torch.func import jacrev, vmap
    >>> x = torch.randn(64, 5)
    >>> jacobian = vmap(jacrev(torch.sin))(x)
    >>> assert jacobian.shape == (64, 5, 5)

    Additionally, :func:`jacrev` can be composed with itself to produce
    Hessians

    >>> from torch.func import jacrev
    >>> def f(x):
    >>>     return x.sin().sum()
    >>>
    >>> x = torch.randn(5)
    >>> hessian = jacrev(jacrev(f))(x)
    >>> assert torch.allclose(hessian, torch.diag(-x.sin()))

    By default, :func:`jacrev` computes the Jacobian with respect to the first
    input. However, it can compute the Jacobian with respect to a different
    argument by using ``argnums``:

    >>> from torch.func import jacrev
    >>> def f(x, y):
    >>>     return x + y ** 2
    >>>
    >>> x, y = torch.randn(5), torch.randn(5)
    >>> jacobian = jacrev(f, argnums=1)(x, y)
    >>> expected = torch.diag(2 * y)
    >>> assert torch.allclose(jacobian, expected)

    Additionally, passing a tuple to ``argnums`` will compute the Jacobian
    with respect to multiple arguments

    >>> from torch.func import jacrev
    >>> def f(x, y):
    >>>     return x + y ** 2
    >>>
    >>> x, y = torch.randn(5), torch.randn(5)
    >>> jacobian = jacrev(f, argnums=(0, 1))(x, y)
    >>> expectedX = torch.diag(torch.ones_like(x))
    >>> expectedY = torch.diag(2 * y)
    >>> assert torch.allclose(jacobian[0], expectedX)
    >>> assert torch.allclose(jacobian[1], expectedY)

    .. note::
        Using PyTorch ``torch.no_grad`` together with ``jacrev``.
        Case 1: Using ``torch.no_grad`` inside a function:

        >>> def f(x):
        >>>     with torch.no_grad():
        >>>         c = x ** 2
        >>>     return x - c

        In this case, ``jacrev(f)(x)`` will respect the inner ``torch.no_grad``.

        Case 2: Using ``jacrev`` inside ``torch.no_grad`` context manager:

        >>> with torch.no_grad():
        >>>     jacrev(f)(x)

        In this case, ``jacrev`` will respect the inner ``torch.no_grad``, but not the
        outer one. This is because ``jacrev`` is a "function transform": its result
        should not depend on the result of a context manager outside of ``f``.
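
    The ``chunk_size`` argument only changes how many rows of the Jacobian are
    computed per :func:`vmap` call (and therefore peak memory), not the
    result. An illustrative sketch:

    >>> from torch.func import jacrev
    >>> x = torch.randn(5)
    >>> full = jacrev(torch.sin)(x)
    >>> chunked = jacrev(torch.sin, chunk_size=2)(x)
    >>> assert torch.allclose(full, chunked)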
598 """ 599 if not (chunk_size is None or chunk_size > 0): 600 raise ValueError("jacrev: `chunk_size` should be greater than 0.") 601 602 def wrapper_fn(*args): 603 error_if_complex("jacrev", args, is_input=True) 604 vjp_out = _vjp_with_argnums(func, *args, argnums=argnums, has_aux=has_aux) 605 if has_aux: 606 output, vjp_fn, aux = vjp_out 607 else: 608 output, vjp_fn = vjp_out 609 610 # See NOTE: [Computing jacobian with vmap and vjp for multiple outputs] 611 flat_output, output_spec = tree_flatten(output) 612 613 error_if_complex("jacrev", flat_output, is_input=False) 614 615 # NB: vjp already checks that all outputs are tensors 616 # Step 1: Construct grad_outputs by splitting the standard basis 617 flat_output_numels = tuple(out.numel() for out in flat_output) 618 619 primals = _slice_argnums(args, argnums) 620 flat_primals, primals_spec = tree_flatten(primals) 621 622 def compute_jacobian_stacked(): 623 # Helper function to compute chunked Jacobian 624 # The intermediate chunked calculation are only 625 # scoped at this function level. 626 chunked_results = [] 627 for flat_basis_chunk in _chunked_standard_basis_for_( 628 flat_output, flat_output_numels, chunk_size=chunk_size 629 ): 630 if chunk_size == 1: 631 # sanity check. 632 for t in flat_basis_chunk: 633 assert t.size(0) == 1 634 635 flat_basis_chunk = tree_map( 636 lambda t: torch.squeeze(t, 0), flat_basis_chunk 637 ) 638 639 basis = tree_unflatten(flat_basis_chunk, output_spec) 640 641 if chunk_size == 1: 642 # Behaviour with `chunk_size=1` is same as `for-loop` 643 # i.e. user shouldn't deal with the limitations of vmap. 644 chunked_result = vjp_fn(basis) 645 else: # chunk_size is None or chunk_size != 1 646 chunked_result = vmap(vjp_fn)(basis) 647 648 flat_results = pytree.tree_leaves(chunked_result) 649 650 if chunk_size == 1: 651 flat_results = tree_map( 652 lambda t: torch.unsqueeze(t, 0), flat_results 653 ) 654 655 chunked_results.append(flat_results) 656 657 if len(chunked_results) == 1: 658 # Short-circuit if we used a single chunk 659 return chunked_results[0] 660 661 # Concatenate chunks. 662 flat_results = [] 663 # Iterate and concat the jacobians of different 664 # inputs. 665 for idx in range(len(flat_primals)): 666 r = tuple(r_[idx] for r_ in chunked_results) 667 flat_results.append(torch.cat(r, 0)) 668 669 return flat_results 670 671 def compute_jacobian_preallocate_and_copy(): 672 # Helper function to compute chunked Jacobian 673 # The intermediate chunked calculation are only 674 # scoped at this function level. 675 out_vec_size = sum(flat_output_numels) 676 677 # Don't pre-allocate if we have a single chunk. 678 if not (chunk_size is None or chunk_size >= out_vec_size): 679 stacked_results = [ 680 primal.new_zeros(out_vec_size, *primal.shape) 681 for primal in flat_primals 682 ] 683 684 for idx, flat_basis_chunk in enumerate( 685 _chunked_standard_basis_for_( 686 flat_output, flat_output_numels, chunk_size=chunk_size 687 ) 688 ): 689 if chunk_size == 1: 690 # sanity check. 691 for t in flat_basis_chunk: 692 assert t.size(0) == 1 693 694 flat_basis_chunk = [torch.squeeze(t, 0) for t in flat_basis_chunk] 695 696 basis = tree_unflatten(flat_basis_chunk, output_spec) 697 698 if chunk_size == 1: 699 # Behaviour with `chunk_size=1` is same as `for-loop` 700 # i.e. user shouldn't deal with the limitations of vmap. 
701 chunked_result = vjp_fn(basis) 702 else: # chunk_size is None or chunk_size != 1 703 chunked_result = vmap(vjp_fn)(basis) 704 705 flat_results = pytree.tree_leaves(chunked_result) 706 707 # Short-circuit if we have a single chunk. 708 if chunk_size is None or chunk_size >= out_vec_size: 709 if chunk_size == 1: # and out_vec_size == 1 710 # Since we squeezed the output dim 711 flat_results = tree_map( 712 lambda t: torch.unsqueeze(t, 0), flat_results 713 ) 714 return flat_results 715 716 for r, sr in zip(flat_results, stacked_results): 717 sr[idx * chunk_size : (idx + 1) * chunk_size].copy_(r) 718 719 return stacked_results 720 721 if _preallocate_and_copy: 722 flat_jacobians_per_input = compute_jacobian_preallocate_and_copy() 723 else: 724 flat_jacobians_per_input = compute_jacobian_stacked() 725 726 # Step 2: The returned jacobian is one big tensor per input. In this step, 727 # we split each Tensor by output. 728 flat_jacobians_per_input = [ 729 result.split(flat_output_numels, dim=0) 730 for result in flat_jacobians_per_input 731 ] 732 flat_input_flat_output = [ 733 tuple( 734 split.view(out.shape + primal.shape) 735 for split, out in zip(splits, flat_output) 736 ) 737 for splits, primal in zip(flat_jacobians_per_input, flat_primals) 738 ] 739 740 # Step 3: Right now, `jacobian` is a List[List[Tensor]]. 741 # The outer List corresponds to the number of primals, 742 # the inner List corresponds to the number of outputs. 743 # We need to: 744 # a. Exchange the order of the outer List and inner List 745 # b. tree_unflatten the inner Lists (which correspond to the primals) 746 # c. handle the argnums=int case 747 # d. tree_unflatten the outer List (which corresponds to the outputs) 748 flat_output_flat_input = tuple(zip(*flat_input_flat_output)) 749 750 flat_output_input = tuple( 751 tree_unflatten(flat_input, primals_spec) 752 for flat_input in flat_output_flat_input 753 ) 754 755 if isinstance(argnums, int): 756 flat_output_input = tuple( 757 _safe_zero_index(flat_input) for flat_input in flat_output_input 758 ) 759 output_input = tree_unflatten(flat_output_input, output_spec) 760 if has_aux: 761 return output_input, aux 762 return output_input 763 764 # Dynamo does not support HOP composition if their inner function is 765 # annotated with @functools.wraps(...). We circumvent this issue by applying 766 # wraps only if we're not tracing with dynamo. 767 if not torch._dynamo.is_compiling(): 768 wrapper_fn = wraps(func)(wrapper_fn) 769 770 return wrapper_fn 771 772 773# NOTE: [Computing jacobian with vmap and vjp for multiple outputs] 774# 775# Let's consider f(x) = (x**2, x.sum()) and let x = torch.randn(3). 776# It turns out we can compute the jacobian of this function with a single 777# call to autograd.grad by using vmap over the correct grad_outputs. 778# 779# Firstly, one way to compute the jacobian is to stack x**2 and x.sum() 780# into a 4D vector. E.g., use g(x) = torch.stack([x**2, x.sum()]) 781# 782# To get the first row of the jacobian, we call 783# >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([1, 0, 0, 0])) 784# To get the 2nd row of the jacobian, we call 785# >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([0, 1, 0, 0])) 786# and so on. 787# 788# Using vmap, we can vectorize all 4 of these computations into one by 789# passing the standard basis for R^4 as the grad_output. 790# vmap(partial(autograd.grad, g(x), x))(torch.eye(4)). 791# 792# Now, how do we compute the jacobian *without stacking the output*? 
793# We can just split the standard basis across the outputs. So to 794# compute the jacobian of f(x), we'd use 795# >>> autograd.grad(f(x), x, grad_outputs=_construct_standard_basis_for(...)) 796# The grad_outputs looks like the following: 797# ( torch.tensor([[1, 0, 0], 798# [0, 1, 0], 799# [0, 0, 1], 800# [0, 0, 0]]), 801# torch.tensor([[0], 802# [0], 803# [0], 804# [1]]) ) 805# 806# But we're not done yet! 807# >>> vmap(partial(autograd.grad(f(x), x, grad_outputs=...))) 808# returns a Tensor of shape [4, 3]. We have to remember to split the 809# jacobian of shape [4, 3] into two: 810# - one of shape [3, 3] for the first output 811# - one of shape [ 3] for the second output 812 813 814def _chunked_standard_basis_for_(tensors, tensor_numels, chunk_size=None): 815 # This function: 816 # - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix. 817 # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`. 818 # - Each chunk corresponds to one tensor. The chunk has the same dtype and 819 # device as the tensor 820 # 821 # For example, with tensor_numels = [1, 2, 1], this function returns: 822 # ( tensor([[1], tensor([[0, 0], tensor([[0], 823 # [0], [1, 0], [0], 824 # [0], [0, 1], [0], 825 # [0]]) , [0, 0]]) , [1]]) ) 826 # 827 # Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors) 828 # Precondition: tensors always has at least one element. 829 # 830 # See NOTE: [Computing jacobian with vmap and grad for multiple tensors] 831 # for context behind this function. 832 # NOTE: Argument `chunk_size` is used to generate chunked basis instead of 833 # one huge basis matrix. `chunk_size` dictates the maximum size of the 834 # basis matrix along dim=0. 835 assert len(tensors) == len(tensor_numels) 836 assert len(tensors) > 0 837 assert chunk_size is None or chunk_size > 0 838 total_numel = sum(tensor_numels) 839 if chunk_size and chunk_size < total_numel: 840 chunk_numels = get_chunk_sizes(total_numel, chunk_size) 841 else: # chunk_size is None or chunk_size >= total_numel 842 chunk_size = total_numel 843 chunk_numels = [total_numel] 844 845 diag_start_indices = ( 846 0, 847 *torch.tensor(tensor_numels).cumsum(dim=0)[:-1].neg().unbind(), 848 ) 849 850 for chunk_idx, total_numel in enumerate(chunk_numels): 851 chunks = tuple( 852 tensor.new_zeros(total_numel, tensor_numel) 853 for tensor, tensor_numel in zip(tensors, tensor_numels) 854 ) 855 856 for chunk, diag_start_idx in zip(chunks, diag_start_indices): 857 chunk.diagonal(diag_start_idx + chunk_idx * chunk_size).fill_(1) 858 chunks = tuple( 859 chunk.view(total_numel, *tensor.shape) 860 for chunk, tensor in zip(chunks, tensors) 861 ) 862 yield chunks 863 864 865def _construct_standard_basis_for(tensors, tensor_numels): 866 for basis in _chunked_standard_basis_for_(tensors, tensor_numels, chunk_size=None): 867 return basis 868 869 870def _validate_and_wrap_argnum(argnum, num_args): 871 if not isinstance(argnum, int): 872 raise RuntimeError(f"argnum must be int, got: {type(argnum)}") 873 if argnum >= 0 and argnum < num_args: 874 return argnum 875 if argnum < 0 and argnum >= -num_args: 876 return argnum + num_args 877 raise RuntimeError(f"Got argnum={argnum}, but only {num_args} positional inputs") 878 879 880def _check_unique_non_empty(argnums): 881 if isinstance(argnums, tuple): 882 if len(argnums) == 0: 883 raise RuntimeError("argnums must be non-empty") 884 if len(set(argnums)) != len(argnums): 885 raise RuntimeError(f"argnums elements must be unique, got 
{argnums}") 886 887 888def _replace_args(old_args, new_args, argnums): 889 if isinstance(argnums, int): 890 if len(new_args) != 1: 891 raise RuntimeError( 892 f"new_args should be of size 1, was of size {len(new_args)}" 893 ) 894 return tuple( 895 new_args[0] if i == argnums else old_args[i] for i in range(len(old_args)) 896 ) 897 if isinstance(argnums, tuple): 898 if len(new_args) != len(argnums): 899 raise RuntimeError( 900 "new_args should have the same size as argnums. " 901 f"Argnums size {len(argnums)}, new_args size {len(new_args)}" 902 ) 903 904 def get_right_elem(i): 905 return new_args[argnums.index(i)] if i in argnums else old_args[i] 906 907 return tuple(get_right_elem(i) for i in range(len(old_args))) 908 raise RuntimeError(f"argnums must be int or Tuple[int, ...], got: {type(argnums)}") 909 910 911def _validate_and_wrap_argnums(argnums, num_args): 912 if isinstance(argnums, int): 913 return _validate_and_wrap_argnum(argnums, num_args) 914 if isinstance(argnums, tuple): 915 return tuple(_validate_and_wrap_argnum(argnum, num_args) for argnum in argnums) 916 raise AssertionError("Should never get here") 917 918 919def _slice_argnums(args, argnums, as_tuple=True): 920 if not isinstance(argnums, int) and not isinstance(argnums, tuple): 921 raise RuntimeError( 922 f"argnums must be int or Tuple[int, ...], got: {type(argnums)}" 923 ) 924 argnums = _validate_and_wrap_argnums(argnums, len(args)) 925 _check_unique_non_empty(argnums) 926 if isinstance(argnums, int): 927 if as_tuple: 928 return (args[argnums],) 929 else: 930 return args[argnums] 931 return tuple(args[i] for i in argnums) 932 933 934JVP_NESTING = 0 935 936 937def assert_flat_tuple_of_tensors(elts: Any, api: str, argname: str) -> None: 938 if not isinstance(elts, tuple): 939 raise RuntimeError( 940 f"{api}: Expected {argname} to be a tuple of Tensors, got {type(elts)}" 941 ) 942 for elt in elts: 943 if isinstance(elt, torch.Tensor): 944 continue 945 raise RuntimeError( 946 f"{api}: Expected {argname} to be a tuple of Tensors, got " 947 f"a tuple with an element of type {type(elt)}" 948 ) 949 if len(elts) == 0: 950 raise RuntimeError( 951 f"{api}: Expected {argname} to be a non-empty tuple of Tensors." 952 ) 953 954 955def assert_non_empty_tensor_output(output: List[Any], api: str) -> None: 956 if (len(output) == 1 and output[0] is None) or len(output) < 1: 957 raise RuntimeError( 958 f"{api}: Expected f to be a function that has non-empty output (got output = {output})" 959 ) 960 for o in output: 961 if not isinstance(o, torch.Tensor): 962 raise RuntimeError( 963 f"{api}: expected f(*primals) to return only tensors" 964 f", got unsupported type {type(o)}" 965 ) 966 967 968def assert_output_is_tensor_or_tensors(output: Any, api: str) -> None: 969 if isinstance(output, torch.Tensor): 970 return 971 if not isinstance(output, tuple): 972 raise RuntimeError( 973 f"{api}: Expected output of f to be a Tensor or Tensors, got " 974 f"{type(output)}" 975 ) 976 if len(output) == 0: 977 raise RuntimeError( 978 f"{api}: Expected output of f to be a non-empty tuple of Tensors." 
979 ) 980 for out in output: 981 if isinstance(out, torch.Tensor): 982 continue 983 raise RuntimeError( 984 f"{api}: Expected output of f to be a Tensor or Tensors, got " 985 f"{type(out)} as an output" 986 ) 987 988 989def assert_non_empty_list_of_tensors( 990 output: List[torch.Tensor], api: str, argname: str 991) -> None: 992 if len(output) == 0: 993 raise RuntimeError(f"{api}: Expected {argname} to contain at least one Tensor.") 994 for out in output: 995 if isinstance(out, torch.Tensor): 996 continue 997 raise RuntimeError( 998 f"{api}: Expected {argname} to only contain Tensors, got " f"{type(out)}" 999 ) 1000 1001 1002jvp_str = "jvp(f, primals, tangents)" 1003 1004 1005def safe_unpack_dual(dual, strict): 1006 if not isinstance(dual, torch.Tensor): 1007 raise RuntimeError( 1008 f"{jvp_str}: expected f(*args) to return only tensors" 1009 f", got unsupported type {type(dual)}" 1010 ) 1011 1012 primal, tangent = fwAD.unpack_dual(dual) 1013 if tangent is None: 1014 if strict: 1015 raise RuntimeError( 1016 "jvp(f, primals, tangents, strict=True): " 1017 "The output of f is independent of " 1018 "the inputs. This is not allowed with strict=True." 1019 ) 1020 tangent = torch.zeros_like(primal) 1021 return primal, tangent 1022 1023 1024@exposed_in("torch.func") 1025def jvp( 1026 func: Callable, 1027 primals: Any, 1028 tangents: Any, 1029 *, 1030 strict: bool = False, 1031 has_aux: bool = False, 1032): 1033 """ 1034 Standing for the Jacobian-vector product, returns a tuple containing 1035 the output of `func(*primals)` and the "Jacobian of ``func`` evaluated at 1036 ``primals``" times ``tangents``. This is also known as forward-mode autodiff. 1037 1038 Args: 1039 func (function): A Python function that takes one or more arguments, 1040 one of which must be a Tensor, and returns one or more Tensors 1041 primals (Tensors): Positional arguments to ``func`` that must all be 1042 Tensors. The returned function will also be computing the 1043 derivative with respect to these arguments 1044 tangents (Tensors): The "vector" for which Jacobian-vector-product is 1045 computed. Must be the same structure and sizes as the inputs to 1046 ``func``. 1047 has_aux (bool): Flag indicating that ``func`` returns a 1048 ``(output, aux)`` tuple where the first element is the output of 1049 the function to be differentiated and the second element is 1050 other auxiliary objects that will not be differentiated. 1051 Default: False. 1052 1053 Returns: 1054 Returns a ``(output, jvp_out)`` tuple containing the output of ``func`` 1055 evaluated at ``primals`` and the Jacobian-vector product. 1056 If ``has_aux is True``, then instead returns a ``(output, jvp_out, aux)`` tuple. 1057 1058 .. note:: 1059 You may see this API error out with "forward-mode AD not implemented 1060 for operator X". If so, please file a bug report and we will prioritize it. 
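
    When ``has_aux=True``, ``func`` may return extra values that are passed
    through without being differentiated. An illustrative sketch (the
    auxiliary value here is an arbitrary example):

    >>> from torch.func import jvp
    >>> x = torch.randn(5)
    >>> f = lambda x: (x.sin(), x.cos())
    >>> value, jvp_out, aux = jvp(f, (x,), (torch.ones(5),), has_aux=True)
    >>> assert torch.allclose(jvp_out, x.cos())
    >>> assert torch.allclose(aux, x.cos())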
1061 1062 jvp is useful when you wish to compute gradients of a function R^1 -> R^N 1063 1064 >>> from torch.func import jvp 1065 >>> x = torch.randn([]) 1066 >>> f = lambda x: x * torch.tensor([1., 2., 3]) 1067 >>> value, grad = jvp(f, (x,), (torch.tensor(1.),)) 1068 >>> assert torch.allclose(value, f(x)) 1069 >>> assert torch.allclose(grad, torch.tensor([1., 2, 3])) 1070 1071 :func:`jvp` can support functions with multiple inputs by passing in the 1072 tangents for each of the inputs 1073 1074 >>> from torch.func import jvp 1075 >>> x = torch.randn(5) 1076 >>> y = torch.randn(5) 1077 >>> f = lambda x, y: (x * y) 1078 >>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5))) 1079 >>> assert torch.allclose(output, x + y) 1080 1081 """ 1082 1083 return _jvp_with_argnums( 1084 func, primals, tangents, argnums=None, strict=strict, has_aux=has_aux 1085 ) 1086 1087 1088def _jvp_with_argnums( 1089 func: Callable, 1090 primals: Any, 1091 tangents: Any, 1092 argnums: Optional[argnums_t], 1093 *, 1094 strict: bool = False, 1095 has_aux: bool, 1096): 1097 # This is the same function as jvp but also accepts an argnums argument 1098 # Most args are the same as jvp except for the added argument 1099 # argnums (Optional[int or tuple[int]]): Optional, specifies the argument(s) to compute gradients with respect to. 1100 # If None, computes the gradients with respect to all inputs (used for jvp). Default: None 1101 # Because of this, tangents must be of length argnums and matches up to the corresponding primal whose index is 1102 # given by argnums 1103 # 1104 # WARN: Users should NOT call this function directly and should just be calling jvp. 1105 # It is only separated so that inputs passed to jacfwd but not differentiated get the correct wrappers. 1106 # 1107 # NOTE: All error messages are produced as if jvp was being called, even if this was called by jacfwd 1108 # 1109 # Returns the same two elements as :func:`jvp` but the returned tuple, ``jvp_out``, only has JVPs with respect to 1110 # the primals given by argnums 1111 if not isinstance(primals, tuple): 1112 raise RuntimeError( 1113 f"{jvp_str}: Expected primals to be a tuple. " 1114 f"E.g. it should be valid to call f(*primals)." 
1115 ) 1116 diff_args = primals if argnums is None else _slice_argnums(primals, argnums) 1117 flat_primals, primals_spec = tree_flatten(diff_args) 1118 flat_tangents, tangents_spec = tree_flatten(tangents) 1119 _jvp_treespec_compare(diff_args, tangents) 1120 assert_non_empty_list_of_tensors(flat_primals, jvp_str, "primals") 1121 assert_non_empty_list_of_tensors(flat_tangents, jvp_str, "tangents") 1122 1123 global JVP_NESTING 1124 1125 with jvp_increment_nesting() as level: 1126 with fwAD._set_fwd_grad_enabled(True): 1127 ctx = fwAD.dual_level if JVP_NESTING == 1 else contextlib.nullcontext 1128 with ctx(): 1129 flat_duals = tuple( 1130 fwAD.make_dual(p, t) for p, t in zip(flat_primals, flat_tangents) 1131 ) 1132 duals = tree_unflatten(flat_duals, primals_spec) 1133 # Note for the reviewer: This is extremely odd but it passes the 1134 # assertion "len(self.block_stack) == 1" on symbolic_convert.py 1135 # The equivalent "if argnums is not None" fails for some reason 1136 if isinstance(argnums, (int, tuple)): 1137 primals = _wrap_all_tensors(primals, level) 1138 duals = _replace_args(primals, duals, argnums) 1139 result_duals = func(*duals) 1140 if has_aux: 1141 if not (isinstance(result_duals, tuple) and len(result_duals) == 2): 1142 raise RuntimeError( 1143 f"{jvp_str}: output of function f should be a tuple: (output, aux) " 1144 "if has_aux is True" 1145 ) 1146 result_duals, aux = result_duals 1147 aux = _undo_create_differentiable(aux, level) 1148 1149 result_duals, spec = tree_flatten(result_duals) 1150 assert_non_empty_tensor_output(result_duals, jvp_str) 1151 1152 primals_out, tangents_out = zip( 1153 *[safe_unpack_dual(dual, strict) for dual in result_duals] 1154 ) 1155 primals_out = tree_map( 1156 partial(_undo_create_differentiable, level=level), primals_out 1157 ) 1158 tangents_out = tree_map( 1159 partial(_undo_create_differentiable, level=level), tangents_out 1160 ) 1161 1162 primals_out_unflatten = tree_unflatten(primals_out, spec) 1163 tangents_out_unflatten = tree_unflatten(tangents_out, spec) 1164 if has_aux: 1165 return primals_out_unflatten, tangents_out_unflatten, aux 1166 1167 return primals_out_unflatten, tangents_out_unflatten 1168 1169 1170def safe_unflatten(tensor, dim, shape): 1171 if len(shape) == 0: 1172 assert tensor.shape[dim] == 1 1173 return tensor.squeeze(dim) 1174 return tensor.unflatten(dim, shape) 1175 1176 1177@exposed_in("torch.func") 1178def jacfwd( 1179 func: Callable, 1180 argnums: argnums_t = 0, 1181 has_aux: bool = False, 1182 *, 1183 randomness: str = "error", 1184): 1185 """ 1186 Computes the Jacobian of ``func`` with respect to the arg(s) at index 1187 ``argnum`` using forward-mode autodiff 1188 1189 Args: 1190 func (function): A Python function that takes one or more arguments, 1191 one of which must be a Tensor, and returns one or more Tensors 1192 argnums (int or Tuple[int]): Optional, integer or tuple of integers, 1193 saying which arguments to get the Jacobian with respect to. 1194 Default: 0. 1195 has_aux (bool): Flag indicating that ``func`` returns a 1196 ``(output, aux)`` tuple where the first element is the output of 1197 the function to be differentiated and the second element is 1198 auxiliary objects that will not be differentiated. 1199 Default: False. 1200 randomness(str): Flag indicating what type of randomness to use. 1201 See :func:`vmap` for more detail. Allowed: "different", "same", "error". 
        Default: "error"

    Returns:
        Returns a function that takes in the same inputs as ``func`` and
        returns the Jacobian of ``func`` with respect to the arg(s) at
        ``argnums``. If ``has_aux is True``, then the returned function
        instead returns a ``(jacobian, aux)`` tuple where ``jacobian``
        is the Jacobian and ``aux`` is auxiliary objects returned by ``func``.

    .. note::
        You may see this API error out with "forward-mode AD not implemented
        for operator X". If so, please file a bug report and we will prioritize it.
        An alternative is to use :func:`jacrev`, which has better operator coverage.

    A basic usage with a pointwise, unary operation will give a diagonal array
    as the Jacobian

    >>> from torch.func import jacfwd
    >>> x = torch.randn(5)
    >>> jacobian = jacfwd(torch.sin)(x)
    >>> expected = torch.diag(torch.cos(x))
    >>> assert torch.allclose(jacobian, expected)

    :func:`jacfwd` can be composed with vmap to produce batched
    Jacobians:

    >>> from torch.func import jacfwd, vmap
    >>> x = torch.randn(64, 5)
    >>> jacobian = vmap(jacfwd(torch.sin))(x)
    >>> assert jacobian.shape == (64, 5, 5)

    If you would like to compute the output of the function as well as the
    jacobian of the function, use the ``has_aux`` flag to return the output
    as an auxiliary object:

    >>> from torch.func import jacfwd
    >>> x = torch.randn(5)
    >>>
    >>> def f(x):
    >>>     return x.sin()
    >>>
    >>> def g(x):
    >>>     result = f(x)
    >>>     return result, result
    >>>
    >>> jacobian_f, f_x = jacfwd(g, has_aux=True)(x)
    >>> assert torch.allclose(f_x, f(x))

    Additionally, :func:`jacfwd` can be composed with itself or :func:`jacrev`
    to produce Hessians

    >>> from torch.func import jacfwd, jacrev
    >>> def f(x):
    >>>     return x.sin().sum()
    >>>
    >>> x = torch.randn(5)
    >>> hessian = jacfwd(jacrev(f))(x)
    >>> assert torch.allclose(hessian, torch.diag(-x.sin()))

    By default, :func:`jacfwd` computes the Jacobian with respect to the first
    input. However, it can compute the Jacobian with respect to a different
    argument by using ``argnums``:

    >>> from torch.func import jacfwd
    >>> def f(x, y):
    >>>     return x + y ** 2
    >>>
    >>> x, y = torch.randn(5), torch.randn(5)
    >>> jacobian = jacfwd(f, argnums=1)(x, y)
    >>> expected = torch.diag(2 * y)
    >>> assert torch.allclose(jacobian, expected)

    Additionally, passing a tuple to ``argnums`` will compute the Jacobian
    with respect to multiple arguments

    >>> from torch.func import jacfwd
    >>> def f(x, y):
    >>>     return x + y ** 2
    >>>
    >>> x, y = torch.randn(5), torch.randn(5)
    >>> jacobian = jacfwd(f, argnums=(0, 1))(x, y)
    >>> expectedX = torch.diag(torch.ones_like(x))
    >>> expectedY = torch.diag(2 * y)
    >>> assert torch.allclose(jacobian[0], expectedX)
    >>> assert torch.allclose(jacobian[1], expectedY)

    """

    def wrapper_fn(*args):
        error_if_complex("jacfwd", args, is_input=True)
        primals = args if argnums is None else _slice_argnums(args, argnums)
        flat_primals, primals_spec = tree_flatten(primals)
        flat_primals_numels = tuple(p.numel() for p in flat_primals)
        flat_basis = _construct_standard_basis_for(flat_primals, flat_primals_numels)
        basis = tree_unflatten(flat_basis, primals_spec)

        def push_jvp(basis):
            output = _jvp_with_argnums(
                func, args, basis, argnums=argnums, has_aux=has_aux
            )
            # output[0] is the output of `func(*args)`
            error_if_complex("jacfwd", output[0], is_input=False)
            if has_aux:
                _, jvp_out, aux = output
                return jvp_out, aux
            _, jvp_out = output
            return jvp_out

        results = vmap(push_jvp, randomness=randomness)(basis)
        if has_aux:
            results, aux = results
            # aux is in the standard basis format, e.g. NxN matrix
            # We need to fetch the first element as original `func` output
            flat_aux, aux_spec = tree_flatten(aux)
            flat_aux = [value[0] for value in flat_aux]
            aux = tree_unflatten(flat_aux, aux_spec)

        jac_outs, spec = tree_flatten(results)
        # Most probably below output check can never raise an error
        # as jvp should test the output before
        # assert_non_empty_output(jac_outs, 'jacfwd(f, ...)(*args)')

        jac_outs_ins = tuple(
            tuple(
                safe_unflatten(jac_out_in, -1, primal.shape)
                for primal, jac_out_in in zip(
                    flat_primals,
                    jac_out.movedim(0, -1).split(flat_primals_numels, dim=-1),
                )
            )
            for jac_out in jac_outs
        )
        jac_outs_ins = tuple(
            tree_unflatten(jac_ins, primals_spec) for jac_ins in jac_outs_ins
        )

        if isinstance(argnums, int):
            jac_outs_ins = tuple(jac_ins[0] for jac_ins in jac_outs_ins)
        if has_aux:
            return tree_unflatten(jac_outs_ins, spec), aux
        return tree_unflatten(jac_outs_ins, spec)

    # Dynamo does not support HOP composition if their inner function is
    # annotated with @functools.wraps(...). We circumvent this issue by applying
    # wraps only if we're not tracing with dynamo.
    if not torch._dynamo.is_compiling():
        wrapper_fn = wraps(func)(wrapper_fn)

    return wrapper_fn


@exposed_in("torch.func")
def hessian(func, argnums=0):
    """
    Computes the Hessian of ``func`` with respect to the arg(s) at index
    ``argnum`` via a forward-over-reverse strategy.

    The forward-over-reverse strategy (composing ``jacfwd(jacrev(func))``) is
    a good default for good performance.
It is possible to compute Hessians 1361 through other compositions of :func:`jacfwd` and :func:`jacrev` like 1362 ``jacfwd(jacfwd(func))`` or ``jacrev(jacrev(func))``. 1363 1364 Args: 1365 func (function): A Python function that takes one or more arguments, 1366 one of which must be a Tensor, and returns one or more Tensors 1367 argnums (int or Tuple[int]): Optional, integer or tuple of integers, 1368 saying which arguments to get the Hessian with respect to. 1369 Default: 0. 1370 1371 Returns: 1372 Returns a function that takes in the same inputs as ``func`` and 1373 returns the Hessian of ``func`` with respect to the arg(s) at 1374 ``argnums``. 1375 1376 .. note:: 1377 You may see this API error out with "forward-mode AD not implemented 1378 for operator X". If so, please file a bug report and we will prioritize it. 1379 An alternative is to use ``jacrev(jacrev(func))``, which has better 1380 operator coverage. 1381 1382 A basic usage with a R^N -> R^1 function gives a N x N Hessian: 1383 1384 >>> from torch.func import hessian 1385 >>> def f(x): 1386 >>> return x.sin().sum() 1387 >>> 1388 >>> x = torch.randn(5) 1389 >>> hess = hessian(f)(x) # equivalent to jacfwd(jacrev(f))(x) 1390 >>> assert torch.allclose(hess, torch.diag(-x.sin())) 1391 1392 """ 1393 return jacfwd(jacrev(func, argnums), argnums) 1394 1395 1396@doesnt_support_saved_tensors_hooks 1397def grad_and_value_impl(func, argnums, has_aux, args, kwargs) -> Callable: 1398 with grad_increment_nesting() as level: 1399 output, aux, grad_input = None, None, None 1400 # See NOTE [grad and vjp interaction with no_grad] 1401 with torch.enable_grad(): 1402 args = _wrap_all_tensors(args, level) 1403 kwargs = _wrap_all_tensors(kwargs, level) 1404 diff_args = _slice_argnums(args, argnums, as_tuple=False) 1405 tree_map_(partial(_create_differentiable, level=level), diff_args) 1406 1407 output = func(*args, **kwargs) 1408 if has_aux: 1409 if not (isinstance(output, tuple) and len(output) == 2): 1410 raise RuntimeError( 1411 "grad_and_value(f)(*args): output of function f should be a tuple: (output, aux) " 1412 "if has_aux is True" 1413 ) 1414 output, aux = output 1415 1416 if not isinstance(output, torch.Tensor): 1417 raise RuntimeError( 1418 "grad_and_value(f)(*args): Expected f(*args) " 1419 f"to return a Tensor, got {type(output)}" 1420 ) 1421 if output.dim() != 0: 1422 raise RuntimeError( 1423 "grad_and_value(f)(*args): Expected f(*args) " 1424 "to return a scalar Tensor, got tensor with " 1425 f"{output.dim()} dims. Maybe you wanted to " 1426 "use the vjp or jacrev APIs instead?" 
1427 ) 1428 1429 flat_diff_args, spec = tree_flatten(diff_args) 1430 1431 # NB: need create_graph so that backward pass isn't run in no_grad mode 1432 flat_outputs = _as_tuple(output) 1433 flat_grad_input = _autograd_grad( 1434 flat_outputs, flat_diff_args, create_graph=True 1435 ) 1436 grad_input = tree_unflatten(flat_grad_input, spec) 1437 1438 grad_input = _undo_create_differentiable(grad_input, level) 1439 output = _undo_create_differentiable(output, level) 1440 if has_aux: 1441 aux = _undo_create_differentiable(aux, level) 1442 1443 if has_aux: 1444 return grad_input, (output, aux) 1445 return grad_input, output 1446 1447 1448def grad_impl(func: Callable, argnums: argnums_t, has_aux: bool, args, kwargs): 1449 results = grad_and_value_impl(func, argnums, has_aux, args, kwargs) 1450 if has_aux: 1451 grad, (_, aux) = results 1452 return grad, aux 1453 grad, _ = results 1454 return grad 1455 1456 1457def _maybe_wrap_functional_tensor( 1458 maybe_tensor, level, *, _python_functionalize: bool = False 1459): 1460 if not isinstance(maybe_tensor, torch.Tensor): 1461 return maybe_tensor 1462 wrapped = _wrap_functional_tensor(maybe_tensor, level) 1463 _assert_wrapped_functional(maybe_tensor, wrapped) 1464 if _python_functionalize: 1465 out = FunctionalTensor(wrapped) 1466 torch._mirror_autograd_meta_to(maybe_tensor, out) 1467 return out 1468 return wrapped 1469 1470 1471def _wrap_all_tensors_to_functional( 1472 tensor_pytree, level, *, _python_functionalize: bool = False 1473): 1474 return tree_map( 1475 partial( 1476 lambda x: _maybe_wrap_functional_tensor( 1477 x, level, _python_functionalize=_python_functionalize 1478 ) 1479 ), 1480 tensor_pytree, 1481 ) 1482 1483 1484def _maybe_unwrap_functional_tensor(maybe_tensor, *, reapply_views: bool): 1485 if not isinstance(maybe_tensor, torch.Tensor): 1486 return maybe_tensor 1487 if isinstance(maybe_tensor, FunctionalTensor): 1488 maybe_tensor = maybe_tensor.elem 1489 1490 if not torch._is_functional_tensor(maybe_tensor): 1491 # If it's not a functional tensor, just return it. 1492 # This can happen if we functionalize a fn that returns a global, 1493 # which was never wrapped properly. 1494 return maybe_tensor 1495 # Sync any pending updates on the output tensor 1496 torch._sync(maybe_tensor) 1497 return _unwrap_functional_tensor(maybe_tensor, reapply_views) 1498 1499 1500def _unwrap_all_tensors_from_functional(tensor_pytree, *, reapply_views: bool): 1501 return tree_map( 1502 lambda t: _maybe_unwrap_functional_tensor(t, reapply_views=reapply_views), 1503 tensor_pytree, 1504 ) 1505 1506 1507@exposed_in("torch.func") 1508def functionalize(func: Callable, *, remove: str = "mutations") -> Callable: 1509 """ 1510 functionalize is a transform that can be used to remove (intermediate) 1511 mutations and aliasing from a function, while preserving the function's 1512 semantics. 1513 1514 ``functionalize(func)`` returns a new function with the same semantics 1515 as ``func``, but with all intermediate mutations removed. 1516 Every inplace operation performed on an intermediate tensor: 1517 ``intermediate.foo_()`` 1518 gets replaced by its out-of-place equivalent: 1519 ``intermediate_updated = intermediate.foo()``. 1520 1521 functionalize is useful for shipping a pytorch program off to 1522 backends or compilers that aren't able to easily represent 1523 mutations or aliasing operators. 1524 1525 Args: 1526 func (Callable): A Python function that takes one or more arguments. 
1527 remove (str): An optional string argument, that takes on either 1528 the value 'mutations' or 'mutations_and_views'. 1529 If 'mutations' is passed in then all mutating operators 1530 will be replaced with their non-mutating equivalents. 1531 If 'mutations_and_views' is passed in, then additionally, all aliasing 1532 operators will be replaced with their non-aliasing equivalents. 1533 Default: 'mutations'. 1534 1535 Returns: 1536 Returns a new "functionalized" function. It takes the same inputs as 1537 ``func``, and has the same behavior, but any mutations 1538 (and optionally aliasing) performed on intermediate tensors 1539 in the function will be removed. 1540 1541 functionalize will also remove mutations (and views) that were performed on function inputs. 1542 However to preserve semantics, functionalize will "fix up" the mutations after 1543 the transform has finished running, by detecting if any tensor inputs "should have" 1544 been mutated, and copying the new data back to the inputs if necessary. 1545 1546 1547 Example:: 1548 1549 >>> # xdoctest: +SKIP 1550 >>> import torch 1551 >>> from torch.fx.experimental.proxy_tensor import make_fx 1552 >>> from torch.func import functionalize 1553 >>> 1554 >>> # A function that uses mutations and views, but only on intermediate tensors. 1555 >>> def f(a): 1556 ... b = a + 1 1557 ... c = b.view(-1) 1558 ... c.add_(1) 1559 ... return b 1560 ... 1561 >>> inpt = torch.randn(2) 1562 >>> 1563 >>> out1 = f(inpt) 1564 >>> out2 = functionalize(f)(inpt) 1565 >>> 1566 >>> # semantics are the same (outputs are equivalent) 1567 >>> print(torch.allclose(out1, out2)) 1568 True 1569 >>> 1570 >>> f_traced = make_fx(f)(inpt) 1571 >>> f_no_mutations_traced = make_fx(functionalize(f))(inpt) 1572 >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt) 1573 >>> 1574 >>> print(f_traced.code) 1575 1576 1577 1578 def forward(self, a_1): 1579 add = torch.ops.aten.add(a_1, 1); a_1 = None 1580 view = torch.ops.aten.view(add, [-1]) 1581 add_ = torch.ops.aten.add_(view, 1); view = None 1582 return add 1583 1584 >>> print(f_no_mutations_traced.code) 1585 1586 1587 1588 def forward(self, a_1): 1589 add = torch.ops.aten.add(a_1, 1); a_1 = None 1590 view = torch.ops.aten.view(add, [-1]); add = None 1591 add_1 = torch.ops.aten.add(view, 1); view = None 1592 view_1 = torch.ops.aten.view(add_1, [2]); add_1 = None 1593 return view_1 1594 1595 >>> print(f_no_mutations_and_views_traced.code) 1596 1597 1598 1599 def forward(self, a_1): 1600 add = torch.ops.aten.add(a_1, 1); a_1 = None 1601 view_copy = torch.ops.aten.view_copy(add, [-1]); add = None 1602 add_1 = torch.ops.aten.add(view_copy, 1); view_copy = None 1603 view_copy_1 = torch.ops.aten.view_copy(add_1, [2]); add_1 = None 1604 return view_copy_1 1605 1606 1607 >>> # A function that mutates its input tensor 1608 >>> def f(a): 1609 ... b = a.view(-1) 1610 ... b.add_(1) 1611 ... return a 1612 ... 1613 >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt) 1614 >>> # 1615 >>> # All mutations and views have been removed, 1616 >>> # but there is an extra copy_ in the graph to correctly apply the mutation to the input 1617 >>> # after the function has completed. 
        >>> print(f_no_mutations_and_views_traced.code)



        def forward(self, a_1):
            view_copy = torch.ops.aten.view_copy(a_1, [-1])
            add = torch.ops.aten.add(view_copy, 1); view_copy = None
            view_copy_1 = torch.ops.aten.view_copy(add, [2]); add = None
            copy_ = torch.ops.aten.copy_(a_1, view_copy_1); a_1 = None
            return view_copy_1


    There are a few "failure modes" for functionalize that are worth calling out:
      (1) Like other torch.func transforms, `functionalize()` doesn't work with functions
          that directly use `.backward()`. The same is true for torch.autograd.grad.
          If you want to use autograd, you can compute gradients directly
          with `functionalize(grad(f))`.
      (2) Like other torch.func transforms, `functionalize()` doesn't work with global state.
          If you call `functionalize(f)` on a function that takes views / mutations of
          non-local state, functionalization will simply no-op and pass the view/mutation
          calls directly to the backend.
          One way to work around this is to ensure that any non-local state creation
          is wrapped into a larger function, which you then call functionalize on.
      (3) `resize_()` has some limitations: functionalize will only work on programs
          that use `resize_()` as long as the tensor being resized is not a view.
      (4) `as_strided()` has some limitations: functionalize will not work on
          `as_strided()` calls that result in tensors with overlapping memory.


    Finally, a helpful mental model for understanding functionalization is that
    most user pytorch programs are written with the public torch API.
    When executed, torch operators are generally decomposed into
    our internal C++ "ATen" API.
    The logic for functionalization happens entirely at the level of ATen.
    Functionalization knows how to take every aliasing operator in ATen,
    and map it to its non-aliasing equivalent
    (e.g. ``tensor.view({-1})`` -> ``at::view_copy(tensor, {-1})``),
    and how to take every mutating operator in ATen,
    and map it to its non-mutating equivalent
    (e.g. ``tensor.add_(1)`` -> ``at::add(tensor, 1)``),
    while tracking aliases and mutations out-of-line to know when to fix things up.
    Information about which ATen operators are aliasing or mutating all comes from
    https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml.
    """
    if remove == "mutations":
        reapply_views = True
    elif remove == "mutations_and_views":
        reapply_views = False
    else:
        raise RuntimeError(
            f"functionalize(f, remove='mutations'): received invalid argument for remove={remove}."
            " Valid options are:\n"
            " remove='mutations': all inplace and out= operators will be removed from the program, and replaced"
            " with their out-of-place equivalents.\n"
            " remove='mutations_and_views': In addition to the above, all aliasing operators {view} will be"
            " replaced with their non-aliasing counterparts, {view}_copy.\n"
        )

    @wraps(func)
    def wrapped(*args, **kwargs):
        try:
            func_level = _func_increment_nesting(reapply_views)
            func_args = _wrap_all_tensors_to_functional(args, func_level)
            func_kwargs = _wrap_all_tensors_to_functional(kwargs, func_level)

            flattened_unwrapped_args = pytree.arg_tree_leaves(*args)
            flattened_wrapped_args = pytree.arg_tree_leaves(*func_args)
            flattened_unwrapped_kwargs = pytree.arg_tree_leaves(**kwargs)
            flattened_wrapped_kwargs = pytree.arg_tree_leaves(**func_kwargs)

            func_outputs = func(*func_args, **func_kwargs)
            outputs = _unwrap_all_tensors_from_functional(
                func_outputs, reapply_views=reapply_views
            )
            flat_outputs, func_out_spec = tree_flatten(outputs)

            for a in flattened_wrapped_args + flattened_wrapped_kwargs:
                if isinstance(a, torch.Tensor):
                    # Call sync_() on the inputs, to ensure that any pending mutations have been applied.
                    torch._sync(a)

            # And if any mutations were applied to the inputs, we need to propagate them back to the user.
            for unwrapped, wrapped in zip(
                flattened_unwrapped_args, flattened_wrapped_args
            ):
                if isinstance(unwrapped, torch.Tensor) and isinstance(
                    wrapped, torch.Tensor
                ):
                    _propagate_functional_input_mutation(unwrapped, wrapped)
            for unwrapped, wrapped in zip(
                flattened_unwrapped_kwargs, flattened_wrapped_kwargs
            ):
                if isinstance(unwrapped, torch.Tensor) and isinstance(
                    wrapped, torch.Tensor
                ):
                    _propagate_functional_input_mutation(unwrapped, wrapped)

            return outputs
        finally:
            _func_decrement_nesting()

    return wrapped


@exposed_in("torch.func")
def linearize(func: Callable, *primals) -> Tuple[Any, Callable]:
    """
    Returns the value of ``func`` at ``primals`` and a linear approximation
    at ``primals``.

    Args:
        func (Callable): A Python function that takes one or more arguments.
        primals (Tensors): Positional arguments to ``func`` that must all be
            Tensors. These are the values at which the function is linearly approximated.

    Returns:
        Returns a ``(output, jvp_fn)`` tuple containing the output of ``func``
        applied to ``primals`` and a function that computes the jvp of
        ``func`` evaluated at ``primals``.

    linearize is useful if jvp is to be computed multiple times at ``primals``. However,
    to achieve this, linearize saves intermediate computation and has higher memory requirements
    than directly applying `jvp`. So, if all the ``tangents`` are known, it may be more efficient
    to compute vmap(jvp) instead of using linearize.

    .. note::
        linearize evaluates ``func`` twice. Please file an issue for an implementation
        with a single evaluation.

    Example::
        >>> import torch
        >>> from torch.func import linearize
        >>> def fn(x):
        ...     return x.sin()
        ...
        >>> output, jvp_fn = linearize(fn, torch.zeros(3, 3))
        >>> jvp_fn(torch.ones(3, 3))
        tensor([[1., 1., 1.],
                [1., 1., 1.],
                [1., 1., 1.]])
        >>>

    """
    # Note: We evaluate `func` twice.
    # Once for returning the output and once more while
    # tracing the graph.
    # If this becomes a bottleneck, we should update
    # make_fx such that it also returns the output.

    output = func(*primals)
    _, output_spec = tree_flatten(output)

    flat_primals, primals_argspec = tree_flatten(primals)

    # tangents for tracing
    flat_tangents = tuple(p.new_empty(()).expand_as(p) for p in flat_primals)

    # function to trace
    def trace_fn(flat_tangents):
        with fwAD.dual_level():
            flat_duals = tuple(
                fwAD.make_dual(p, t) for p, t in zip(flat_primals, flat_tangents)
            )
            duals = tree_unflatten(flat_duals, primals_argspec)
            output = func(*duals)
            tangents = tree_map_only(
                torch.Tensor, lambda dual: safe_unpack_dual(dual, False)[1], output
            )

        return tangents

    jvp_graph = lazy_dynamo_disallow(make_fx)(trace_fn)(flat_tangents)
    const_folded_jvp_graph = lazy_dynamo_disallow(const_fold.split_const_subgraphs)(
        jvp_graph
    )

    # Hold only the metadata regarding the primals.
    flat_primals_shape = tuple(p.shape for p in flat_primals)
    flat_primals_device = tuple(p.device for p in flat_primals)
    flat_primals_dtype = tuple(p.dtype for p in flat_primals)

    def forward_ad_checks(flat_tangents):
        for idx, t in enumerate(flat_tangents):
            if t.shape != flat_primals_shape[idx]:
                msg = (
                    f"tangent:{idx} with shape {t.shape} in flattened "
                    f"pytree doesn't match the shape {flat_primals_shape[idx]} "
                    "of the corresponding primal."
                )
                raise RuntimeError(msg)

            if t.device != flat_primals_device[idx]:
                msg = (
                    f"tangent:{idx} with device {t.device} in flattened "
                    f"pytree doesn't match the device {flat_primals_device[idx]} "
                    "of the corresponding primal."
                )
                raise RuntimeError(msg)

            if t.dtype != flat_primals_dtype[idx]:
                msg = (
                    f"tangent:{idx} with dtype {t.dtype} in flattened "
                    f"pytree doesn't match the dtype {flat_primals_dtype[idx]} "
                    "of the corresponding primal."
                )
                raise RuntimeError(msg)

    # jvp_fn: the callable to return.
    # It checks the argspec of the tangents, calls the const-folded fx graph,
    # and unflattens the fx graph's output.
    def jvp_fn(*tangents):
        flat_tangents, tangent_argspec = tree_flatten(tangents)
        _linearize_treespec_compare(primals, tangents)

        forward_ad_checks(flat_tangents)

        flat_output = const_folded_jvp_graph(*flat_tangents)
        # The const-folded graph returns a flat output,
        # so unflatten it back into the expected pytree structure.
        return tree_unflatten(flat_output, output_spec)

    return output, jvp_fn
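

# Editor's sketch (hypothetical helper; not part of the torch.func API surface and
# never called by this module): a minimal usage example, under the behavior
# documented in the docstrings above, of composing ``functionalize`` with
# ``make_fx`` and of reusing the ``jvp_fn`` returned by ``linearize`` for several
# tangents. It relies only on names already defined or imported in this file.
def _functionalize_and_linearize_usage_sketch():
    # A function with an intermediate mutation and view, as in the
    # ``functionalize`` docstring example above.
    def f(a):
        b = a + 1
        c = b.view(-1)
        c.add_(1)
        return b

    inpt = torch.randn(2)
    # With remove='mutations_and_views', the traced graph should contain only
    # out-of-place ops (aten.add) and *_copy view variants (aten.view_copy).
    traced = make_fx(functionalize(f, remove="mutations_and_views"))(inpt)
    print(traced.code)

    def fn(x):
        return x.sin()

    primal = torch.randn(3, 3)
    output, jvp_fn = linearize(fn, primal)
    # jvp_fn replays only the const-folded fx graph, so it can be called
    # repeatedly with fresh tangents without re-tracing ``fn``.
    tangent = torch.ones(3, 3)
    # The jvp of sin at ``primal`` is cos(primal) * tangent, which is linear in
    # the tangent: jvp(2 * t) == 2 * jvp(t).
    assert torch.allclose(jvp_fn(2 * tangent), 2 * jvp_fn(tangent))
    return output, traced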