xref: /aosp_15_r20/external/pytorch/torch/_functorch/eager_transforms.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3# Copyright (c) Facebook, Inc. and its affiliates.
4# All rights reserved.
5#
6# This source code is licensed under the BSD-style license found in the
7# LICENSE file in the root directory of this source tree.
8
9import contextlib
10from functools import partial, wraps
11from typing import Any, Callable, List, Optional, Tuple, Union
12
13import torch
14import torch.autograd.forward_ad as fwAD
15from torch._C._functorch import (
16    _assert_wrapped_functional,
17    _func_decrement_nesting,
18    _func_increment_nesting,
19    _grad_decrement_nesting,
20    _grad_increment_nesting,
21    _jvp_decrement_nesting,
22    _jvp_increment_nesting,
23    _propagate_functional_input_mutation,
24    _unwrap_for_grad,
25    _unwrap_functional_tensor,
26    _wrap_for_grad,
27    _wrap_functional_tensor,
28    get_inplace_requires_grad_allowed,
29    set_inplace_requires_grad_allowed,
30)
31from torch._functorch.utils import argnums_t, exposed_in
32from torch._subclasses.functional_tensor import FunctionalTensor
33from torch.fx.experimental import const_fold
34from torch.fx.experimental.proxy_tensor import make_fx
35from torch.utils import _pytree as pytree
36from torch.utils._pytree import (
37    tree_flatten,
38    tree_map,
39    tree_map_,
40    tree_map_only,
41    tree_unflatten,
42    treespec_pprint,
43)
44
45from .apis import vmap
46from .vmap import doesnt_support_saved_tensors_hooks, get_chunk_sizes
47
48
49def lazy_dynamo_disallow(func):
50    import torch._dynamo
51
52    return torch._dynamo.disallow_in_graph(func)
53
54
55@contextlib.contextmanager
56def enable_inplace_requires_grad(enabled):
57    prev_state = get_inplace_requires_grad_allowed()
58    set_inplace_requires_grad_allowed(enabled)
59    try:
60        yield
61    finally:
62        set_inplace_requires_grad_allowed(prev_state)
63
64
65def _vjp_treespec_compare(primals_out, cotangents):
66    # Revert this once #116264 gets fixed
67    _, primals_out_spec = tree_flatten(primals_out)
68    _, cotangents_spec = tree_flatten(cotangents)
69    # Dynamo fails to trace operator.ne below. To bypass this limitation, this
70    # function is not inlined.
71    if primals_out_spec != cotangents_spec:
72        raise RuntimeError(
73            f"Expected pytree structure of cotangents to be the same "
74            f"as pytree structure of outputs to the function. "
75            f"cotangents: {treespec_pprint(cotangents_spec)}, "
76            f"primal output: {treespec_pprint(primals_out_spec)}"
77        )
78
79
80def _jvp_treespec_compare(primals, tangents):
81    # Revert this once #116264 gets fixed
82    _, primals_spec = tree_flatten(primals)
83    _, tangents_spec = tree_flatten(tangents)
84    if primals_spec != tangents_spec:
85        raise RuntimeError(
86            f"{jvp_str}: Expected primals and tangents to have the same python "
87            f"structure. For example, if primals is a tuple of 3 tensors, "
88            f"tangents also must be. Got primals with structure {primals_spec} "
89            f"and tangents with structure {tangents_spec}"
90        )
91
92
93def _linearize_treespec_compare(primals, tangents):
94    # Revert this once #116264 gets fixed
95    _, primals_argspec = tree_flatten(primals)
96    _, tangent_argspec = tree_flatten(tangents)
97    if tangent_argspec != primals_argspec:
98        raise RuntimeError(
99            f"Expected the tangents {tangent_argspec} to have "
100            f"the same argspec as the primals {primals_argspec}"
101        )
102
103
104def _set_tensor_requires_grad(x):
105    # avoid graph-break on x.requires_grad_()
106    # https://github.com/pytorch/pytorch/pull/110053
107    return x.requires_grad_()
108
109
110def _create_differentiable(inps, level=None):
111    def create_differentiable(x):
112        if isinstance(x, torch.Tensor):
113            with enable_inplace_requires_grad(True):
114                return _set_tensor_requires_grad(x)
115        raise ValueError(
116            f"Thing passed to transform API must be Tensor, " f"got {type(x)}"
117        )
118
119    return tree_map(create_differentiable, inps)
120
121
122def _undo_create_differentiable(inps, level=None):
123    def unwrap_tensors(x):
124        if isinstance(x, torch.Tensor):
125            return _unwrap_for_grad(x, level)
126        # TODO: Remove the following hack for namedtuples
127        if isinstance(x, tuple):
128            return tree_map(unwrap_tensors, tuple(x))
129
130        raise RuntimeError(f"Expected tensors, got unsupported type {type(x)}")
131
132    return tree_map(unwrap_tensors, inps)
133
134
135def _is_differentiable(maybe_tensor):
136    if not isinstance(maybe_tensor, torch.Tensor):
137        return False
138    return maybe_tensor.requires_grad
139
140
141def _any_differentiable(tensor_or_tuple_of_tensors):
142    flat_args, _ = tree_flatten(tensor_or_tuple_of_tensors)
143    return any(tuple(map(_is_differentiable, flat_args)))
144
145
146def _wrap_tensor_for_grad(maybe_tensor, level):
147    if not isinstance(maybe_tensor, torch.Tensor):
148        return maybe_tensor
149    return _wrap_for_grad(maybe_tensor, level)
150
151
152def _wrap_all_tensors(tensor_pytree, level):
153    return tree_map(partial(_wrap_tensor_for_grad, level=level), tensor_pytree)
154
155
156def _as_tuple(val):
157    if isinstance(val, tuple):
158        return val
159    return (val,)
160
161
162# Version of autograd.grad that handles outputs that don't depend on inputs
163
164
165def _autograd_grad(
166    outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True
167):
168    if grad_outputs is None:
169        diff_outputs = tuple(out for out in outputs if out.requires_grad)
170    else:
171        result = tuple(
172            (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad
173        )
174        if len(result) == 0:
175            diff_outputs, grad_outputs = (), ()
176        else:
177            diff_outputs, grad_outputs = zip(*result)
178    if len(diff_outputs) == 0:
179        return tuple(torch.zeros_like(inp) for inp in inputs)
180    grad_inputs = torch.autograd.grad(
181        diff_outputs,
182        inputs,
183        grad_outputs,
184        retain_graph=retain_graph,
185        create_graph=create_graph,
186        allow_unused=True,
187    )
188    grad_inputs = tuple(
189        torch.zeros_like(inp) if gi is None else gi
190        for gi, inp in zip(grad_inputs, inputs)
191    )
192    return grad_inputs
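
# Illustrative sketch (hypothetical values) of _autograd_grad's behavior:
# an output that does not depend on an input yields a zero gradient of that
# input's shape instead of None.
# >>> x = torch.randn(3, requires_grad=True)
# >>> y = torch.randn(3, requires_grad=True)
# >>> gx, gy = _autograd_grad((x.sum(),), (x, y), grad_outputs=(torch.tensor(1.),))
# >>> # gx == torch.ones(3); gy == torch.zeros(3) because the output ignores y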
193
194
195# NOTE [grad and vjp interaction with no_grad]
196#
197# def f(x):
198#   with torch.no_grad():
199#     c = x ** 2
200#   return x - c
201#
202# The thing to consider is if enable_grad is on/off before grad gets called.
203#
204# Case 1: enable_grad is on.
205# grad(f)(x)
206# In this case, `grad` should respect the inner torch.no_grad.
207#
208# Case 2: enable_grad is off
209# with torch.no_grad():
210#   grad(f)(x)
211# In this case, `grad` should respect the inner torch.no_grad, but not the
212# outer one. This is because `grad` is a "function transform": its result
213# should not depend on the result of a context manager outside of `f`.
214#
215# This gives us the following desired behavior:
216# - (nested) grad transforms must obey torch.no_grad inside them
217# - (nested) grad transforms should not obey torch.no_grad outside them
218#
219# To achieve this behavior, upon entering grad/vjp:
220# - we save the current ("previous") is_grad_enabled (*)
221# - we unconditionally enable grad.
222#
223# Inside DynamicLayerBackFallback, when we're temporarily popping `grad` layer
224# off the stack:
225# - if grad_mode is disabled, then we do nothing. (there is a torch.no_grad
226#   active, all subsequent grad transforms must obey it).
227# - if grad_mode is enabled, and the previous is_grad_enabled (*) is False,
228#   then we temporarily restore the previous `is_grad_enabled`. This is
229#   because we're crossing the boundary from a `grad` outside the
230#   no_grad to a `grad` inside the no_grad.
231#
232# NB: vjp has some interesting behavior because the vjp's callable can be called
233# under a different grad_mode than the forward computation...
234#
235# NB: forward-mode AD: forward-mode AD doesn't respect torch.no_grad, but
236# it respects c10::AutoFwGradMode. We've implemented the same logic for
237# our jvp transform (it will have special handling if FwGradMode is disabled).
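#
# A minimal sketch of the two cases above (illustrative only, reusing the f
# defined at the top of this note):
# >>> x = torch.randn([])
# >>> torch.func.grad(f)(x)          # Case 1: the inner no_grad is respected -> 1.0
# >>> with torch.no_grad():
# ...     torch.func.grad(f)(x)      # Case 2: the outer no_grad is ignored -> still 1.0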
238
239
240# How do we increment and decrement the nesting? I don't think we can.
241@exposed_in("torch.func")
242def vjp(func: Callable, *primals, has_aux: bool = False):
243    """
244    Standing for the vector-Jacobian product, returns a tuple containing the
245    results of ``func`` applied to ``primals`` and a function that, when
246    given ``cotangents``, computes the reverse-mode Jacobian of ``func`` with
247    respect to ``primals`` times ``cotangents``.
248
249    Args:
250        func (Callable): A Python function that takes one or more arguments. Must
251            return one or more Tensors.
252        primals (Tensors): Positional arguments to ``func`` that must all be
253            Tensors. The returned function will also compute the
254            derivative with respect to these arguments.
255        has_aux (bool): Flag indicating that ``func`` returns a
256            ``(output, aux)`` tuple where the first element is the output of
257            the function to be differentiated and the second element is
258            other auxiliary objects that will not be differentiated.
259            Default: False.
260
261    Returns:
262        Returns a ``(output, vjp_fn)`` tuple containing the output of ``func``
263        applied to ``primals`` and a function that computes the vjp of
264        ``func`` with respect to all ``primals`` using the cotangents passed
265        to the returned function. If ``has_aux is True``, then instead returns a
266        ``(output, vjp_fn, aux)`` tuple.
267        The returned ``vjp_fn`` returns a tuple of the VJPs, one per primal.
268
269    When used in simple cases, :func:`vjp` behaves the same as :func:`grad`
270
271        >>> x = torch.randn([5])
272        >>> f = lambda x: x.sin().sum()
273        >>> (_, vjpfunc) = torch.func.vjp(f, x)
274        >>> grad = vjpfunc(torch.tensor(1.))[0]
275        >>> assert torch.allclose(grad, torch.func.grad(f)(x))
276
277    However, :func:`vjp` can support functions with multiple outputs by
278    passing in the cotangents for each of the outputs
279
280        >>> x = torch.randn([5])
281        >>> f = lambda x: (x.sin(), x.cos())
282        >>> (_, vjpfunc) = torch.func.vjp(f, x)
283        >>> vjps = vjpfunc((torch.ones([5]), torch.ones([5])))
284        >>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
285
286    :func:`vjp` can even support outputs being Python structs
287
288        >>> x = torch.randn([5])
289        >>> f = lambda x: {'first': x.sin(), 'second': x.cos()}
290        >>> (_, vjpfunc) = torch.func.vjp(f, x)
291        >>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])}
292        >>> vjps = vjpfunc(cotangents)
293        >>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
294
295    The function returned by :func:`vjp` will compute the partials with
296    respect to each of the ``primals``
297
298        >>> x, y = torch.randn([5, 4]), torch.randn([4, 5])
299        >>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y)
300        >>> cotangents = torch.randn([5, 5])
301        >>> vjps = vjpfunc(cotangents)
302        >>> assert len(vjps) == 2
303        >>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))
304        >>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))
305
306    ``primals`` are the positional arguments for ``f``. All kwargs use their
307    default value
308
309        >>> x = torch.randn([5])
310        >>> def f(x, scale=4.):
311        >>>   return x * scale
312        >>>
313        >>> (_, vjpfunc) = torch.func.vjp(f, x)
314        >>> vjps = vjpfunc(torch.ones_like(x))
315        >>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.))
316
317    .. note::
318        Using PyTorch ``torch.no_grad`` together with ``vjp``.
319        Case 1: Using ``torch.no_grad`` inside a function:
320
321            >>> def f(x):
322            >>>     with torch.no_grad():
323            >>>         c = x ** 2
324            >>>     return x - c
325
326        In this case, ``vjp(f)(x)`` will respect the inner ``torch.no_grad``.
327
328        Case 2: Using ``vjp`` inside ``torch.no_grad`` context manager:
329
330            >>> # xdoctest: +SKIP(failing)
331            >>> with torch.no_grad():
332            >>>     vjp(f)(x)
333
334        In this case, ``vjp`` will respect the inner ``torch.no_grad``, but not the
335        outer one. This is because ``vjp`` is a "function transform": its result
336        should not depend on the result of a context manager outside of ``f``.
337    """
338    return _vjp_with_argnums(func, *primals, has_aux=has_aux)
339
340
341@contextlib.contextmanager
342def grad_increment_nesting():
343    try:
344        grad_level = _grad_increment_nesting()
345        yield grad_level
346    finally:
347        _grad_decrement_nesting()
348
349
350def enter_jvp_nesting():
351    global JVP_NESTING
352    jvp_level = _jvp_increment_nesting()
353    JVP_NESTING += 1
354    return jvp_level
355
356
357def exit_jvp_nesting():
358    global JVP_NESTING
359    _jvp_decrement_nesting()
360    JVP_NESTING -= 1
361
362
363@contextlib.contextmanager
364def jvp_increment_nesting():
365    try:
366        yield enter_jvp_nesting()
367    finally:
368        exit_jvp_nesting()
369
370
371@doesnt_support_saved_tensors_hooks
372def _vjp_with_argnums(
373    func: Callable, *primals, argnums: Optional[argnums_t] = None, has_aux: bool = False
374):
375    # This is the same function as vjp but also accepts an argnums argument
376    # All args are the same as vjp except for the added argument
377    # argnums (Optional[int or tuple[int]]): Optional, specifies the argument(s) to compute gradients with respect to.
378    #         If None, computes the gradients with respect to all inputs (used for vjp). Default: None
379    #
380    # WARN: Users should NOT call this function directly and should just be calling vjp.
381    # It is only separated so that inputs passed to jacrev but not differentiated get the correct wrappers.
382    #
383    # NOTE: All error messages are produced as if vjp was being called, even if this was called by jacrev
384    #
385    # Returns the same two elements as :func:`vjp` but the function returned, vjp_fn, returns a tuple of VJPs
386    # for only the primal elements given by argnums.
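    #
    # Illustrative sketch (hypothetical call): with argnums=1,
    #   _vjp_with_argnums(f, x, y, argnums=1)
    # wraps both x and y at the current grad level but marks only y as
    # differentiable, so the returned vjp_fn yields gradients for y alone.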
387    with grad_increment_nesting() as level:
388        # See NOTE [grad and vjp interaction with no_grad]
389        with torch.enable_grad():
390            primals = _wrap_all_tensors(primals, level)
391            # Note for the reviewer: This is extremely odd but it passes the
392            # assertion "len(self.block_stack) == 1" on symbolic_convert.py
393            # The equivalent "if argnums is None" fails for some reason
394            if not isinstance(argnums, int) and not argnums:
395                diff_primals = _create_differentiable(primals, level)
396            else:
397                diff_primals = _slice_argnums(primals, argnums, as_tuple=False)
398                tree_map_(partial(_create_differentiable, level=level), diff_primals)
399            primals_out = func(*primals)
400
401            if has_aux:
402                if not (isinstance(primals_out, tuple) and len(primals_out) == 2):
403                    raise RuntimeError(
404                        "vjp(f, *primals): output of function f should be a tuple: (output, aux) "
405                        "if has_aux is True"
406                    )
407                primals_out, aux = primals_out
408                aux = _undo_create_differentiable(aux, level)
409
410            flat_primals_out, primals_out_spec = tree_flatten(primals_out)
411            assert_non_empty_tensor_output(flat_primals_out, "vjp(f, *primals)")
412            flat_diff_primals, primals_spec = tree_flatten(diff_primals)
413            results = _undo_create_differentiable(primals_out, level)
414
415            for primal_out in flat_primals_out:
416                assert isinstance(primal_out, torch.Tensor)
417                if primal_out.is_floating_point() or primal_out.is_complex():
418                    continue
419                raise RuntimeError(
420                    "vjp(f, ...): All outputs of f must be "
421                    "floating-point or complex Tensors, got Tensor "
422                    f"with dtype {primal_out.dtype}"
423                )
424
425        def wrapper(cotangents, retain_graph=True, create_graph=None):
426            if create_graph is None:
427                create_graph = torch.is_grad_enabled()
428            flat_cotangents, cotangents_spec = tree_flatten(cotangents)
429            _vjp_treespec_compare(primals_out, cotangents)
430            result = _autograd_grad(
431                flat_primals_out,
432                flat_diff_primals,
433                flat_cotangents,
434                retain_graph=retain_graph,
435                create_graph=create_graph,
436            )
437            return tree_unflatten(result, primals_spec)
438
439    if has_aux:
440        return results, wrapper, aux
441    else:
442        return results, wrapper
443
444
445def _safe_zero_index(x):
446    assert len(x) == 1
447    return x[0]
448
449
450# jacrev and jacfwd don't support complex functions
451# Helper function to throw appropriate error.
452def error_if_complex(func_name, args, is_input):
453    flat_args = pytree.tree_leaves(args)
454    for idx, arg in enumerate(flat_args):
455        if isinstance(arg, torch.Tensor) and arg.dtype.is_complex:
456            input_or_output = "inputs" if is_input else "outputs"
457            err_msg = (
458                f"{func_name}: Expected all {input_or_output} "
459                f"to be real but received complex tensor at flattened input idx: {idx}"
460            )
461            raise RuntimeError(err_msg)
462
463
464@exposed_in("torch.func")
465def jacrev(
466    func: Callable,
467    argnums: Union[int, Tuple[int]] = 0,
468    *,
469    has_aux=False,
470    chunk_size: Optional[int] = None,
471    _preallocate_and_copy=False,
472):
473    """
474    Computes the Jacobian of ``func`` with respect to the arg(s) at index
475    ``argnums`` using reverse-mode autodiff.
476
477    .. note::
478        Using :attr:`chunk_size=1` is equivalent to computing the jacobian
479        row-by-row with a for-loop, i.e. the constraints of :func:`vmap` are
480        not applicable.
481
482    Args:
483        func (function): A Python function that takes one or more arguments,
484            one of which must be a Tensor, and returns one or more Tensors
485        argnums (int or Tuple[int]): Optional, integer or tuple of integers,
486            saying which arguments to get the Jacobian with respect to.
487            Default: 0.
488        has_aux (bool): Flag indicating that ``func`` returns a
489            ``(output, aux)`` tuple where the first element is the output of
490            the function to be differentiated and the second element is
491            auxiliary objects that will not be differentiated.
492            Default: False.
493        chunk_size (None or int): If None (default), use the maximum chunk size
494            (equivalent to doing a single vmap over vjp to compute the jacobian).
495            If 1, then compute the jacobian row-by-row with a for-loop.
496            If not None, then compute the jacobian :attr:`chunk_size` rows at a time
497            (equivalent to doing multiple vmap over vjp). If you run into memory issues computing
498            the jacobian, please try to specify a non-None chunk_size.
499
500    Returns:
501        Returns a function that takes in the same inputs as ``func`` and
502        returns the Jacobian of ``func`` with respect to the arg(s) at
503        ``argnums``. If ``has_aux is True``, then the returned function
504        instead returns a ``(jacobian, aux)`` tuple where ``jacobian``
505        is the Jacobian and ``aux`` is auxiliary objects returned by ``func``.
506
507    A basic usage with a pointwise, unary operation will give a diagonal array
508    as the Jacobian
509
510        >>> from torch.func import jacrev
511        >>> x = torch.randn(5)
512        >>> jacobian = jacrev(torch.sin)(x)
513        >>> expected = torch.diag(torch.cos(x))
514        >>> assert torch.allclose(jacobian, expected)
515
516    If you would like to compute the output of the function as well as the
517    jacobian of the function, use the ``has_aux`` flag to return the output
518    as an auxiliary object:
519
520        >>> from torch.func import jacrev
521        >>> x = torch.randn(5)
522        >>>
523        >>> def f(x):
524        >>>   return x.sin()
525        >>>
526        >>> def g(x):
527        >>>   result = f(x)
528        >>>   return result, result
529        >>>
530        >>> jacobian_f, f_x = jacrev(g, has_aux=True)(x)
531        >>> assert torch.allclose(f_x, f(x))
532
533    :func:`jacrev` can be composed with vmap to produce batched
534    Jacobians:
535
536        >>> from torch.func import jacrev, vmap
537        >>> x = torch.randn(64, 5)
538        >>> jacobian = vmap(jacrev(torch.sin))(x)
539        >>> assert jacobian.shape == (64, 5, 5)
540
541    Additionally, :func:`jacrev` can be composed with itself to produce
542    Hessians
543
544        >>> from torch.func import jacrev
545        >>> def f(x):
546        >>>   return x.sin().sum()
547        >>>
548        >>> x = torch.randn(5)
549        >>> hessian = jacrev(jacrev(f))(x)
550        >>> assert torch.allclose(hessian, torch.diag(-x.sin()))
551
552    By default, :func:`jacrev` computes the Jacobian with respect to the first
553    input. However, it can compute the Jacobian with respect to a different
554    argument by using ``argnums``:
555
556        >>> from torch.func import jacrev
557        >>> def f(x, y):
558        >>>   return x + y ** 2
559        >>>
560        >>> x, y = torch.randn(5), torch.randn(5)
561        >>> jacobian = jacrev(f, argnums=1)(x, y)
562        >>> expected = torch.diag(2 * y)
563        >>> assert torch.allclose(jacobian, expected)
564
565    Additionally, passing a tuple to ``argnums`` will compute the Jacobian
566    with respect to multiple arguments
567
568        >>> from torch.func import jacrev
569        >>> def f(x, y):
570        >>>   return x + y ** 2
571        >>>
572        >>> x, y = torch.randn(5), torch.randn(5)
573        >>> jacobian = jacrev(f, argnums=(0, 1))(x, y)
574        >>> expectedX = torch.diag(torch.ones_like(x))
575        >>> expectedY = torch.diag(2 * y)
576        >>> assert torch.allclose(jacobian[0], expectedX)
577        >>> assert torch.allclose(jacobian[1], expectedY)
578
579    .. note::
580        Using PyTorch ``torch.no_grad`` together with ``jacrev``.
581        Case 1: Using ``torch.no_grad`` inside a function:
582
583            >>> def f(x):
584            >>>     with torch.no_grad():
585            >>>         c = x ** 2
586            >>>     return x - c
587
588        In this case, ``jacrev(f)(x)`` will respect the inner ``torch.no_grad``.
589
590        Case 2: Using ``jacrev`` inside ``torch.no_grad`` context manager:
591
592            >>> with torch.no_grad():
593            >>>     jacrev(f)(x)
594
595        In this case, ``jacrev`` will respect the inner ``torch.no_grad``, but not the
596        outer one. This is because ``jacrev`` is a "function transform": its result
597        should not depend on the result of a context manager outside of ``f``.
598    """
599    if not (chunk_size is None or chunk_size > 0):
600        raise ValueError("jacrev: `chunk_size` should be greater than 0.")
601
602    def wrapper_fn(*args):
603        error_if_complex("jacrev", args, is_input=True)
604        vjp_out = _vjp_with_argnums(func, *args, argnums=argnums, has_aux=has_aux)
605        if has_aux:
606            output, vjp_fn, aux = vjp_out
607        else:
608            output, vjp_fn = vjp_out
609
610        # See NOTE: [Computing jacobian with vmap and vjp for multiple outputs]
611        flat_output, output_spec = tree_flatten(output)
612
613        error_if_complex("jacrev", flat_output, is_input=False)
614
615        # NB: vjp already checks that all outputs are tensors
616        # Step 1: Construct grad_outputs by splitting the standard basis
617        flat_output_numels = tuple(out.numel() for out in flat_output)
618
619        primals = _slice_argnums(args, argnums)
620        flat_primals, primals_spec = tree_flatten(primals)
621
622        def compute_jacobian_stacked():
623            # Helper function to compute chunked Jacobian
624            # The intermediate chunked calculations are only
625            # scoped at this function level.
626            chunked_results = []
627            for flat_basis_chunk in _chunked_standard_basis_for_(
628                flat_output, flat_output_numels, chunk_size=chunk_size
629            ):
630                if chunk_size == 1:
631                    # sanity check.
632                    for t in flat_basis_chunk:
633                        assert t.size(0) == 1
634
635                    flat_basis_chunk = tree_map(
636                        lambda t: torch.squeeze(t, 0), flat_basis_chunk
637                    )
638
639                basis = tree_unflatten(flat_basis_chunk, output_spec)
640
641                if chunk_size == 1:
642                    # Behaviour with `chunk_size=1` is the same as a for-loop,
643                    # i.e. the user shouldn't have to deal with the limitations of vmap.
644                    chunked_result = vjp_fn(basis)
645                else:  # chunk_size is None or chunk_size != 1
646                    chunked_result = vmap(vjp_fn)(basis)
647
648                flat_results = pytree.tree_leaves(chunked_result)
649
650                if chunk_size == 1:
651                    flat_results = tree_map(
652                        lambda t: torch.unsqueeze(t, 0), flat_results
653                    )
654
655                chunked_results.append(flat_results)
656
657            if len(chunked_results) == 1:
658                # Short-circuit if we used a single chunk
659                return chunked_results[0]
660
661            # Concatenate chunks.
662            flat_results = []
663            # Iterate and concat the jacobians of different
664            # inputs.
665            for idx in range(len(flat_primals)):
666                r = tuple(r_[idx] for r_ in chunked_results)
667                flat_results.append(torch.cat(r, 0))
668
669            return flat_results
670
671        def compute_jacobian_preallocate_and_copy():
672            # Helper function to compute chunked Jacobian
673            # The intermediate chunked calculations are only
674            # scoped at this function level.
675            out_vec_size = sum(flat_output_numels)
676
677            # Don't pre-allocate if we have a single chunk.
678            if not (chunk_size is None or chunk_size >= out_vec_size):
679                stacked_results = [
680                    primal.new_zeros(out_vec_size, *primal.shape)
681                    for primal in flat_primals
682                ]
683
684            for idx, flat_basis_chunk in enumerate(
685                _chunked_standard_basis_for_(
686                    flat_output, flat_output_numels, chunk_size=chunk_size
687                )
688            ):
689                if chunk_size == 1:
690                    # sanity check.
691                    for t in flat_basis_chunk:
692                        assert t.size(0) == 1
693
694                    flat_basis_chunk = [torch.squeeze(t, 0) for t in flat_basis_chunk]
695
696                basis = tree_unflatten(flat_basis_chunk, output_spec)
697
698                if chunk_size == 1:
699                    # Behaviour with `chunk_size=1` is the same as a for-loop,
700                    # i.e. the user shouldn't have to deal with the limitations of vmap.
701                    chunked_result = vjp_fn(basis)
702                else:  # chunk_size is None or chunk_size != 1
703                    chunked_result = vmap(vjp_fn)(basis)
704
705                flat_results = pytree.tree_leaves(chunked_result)
706
707                # Short-circuit if we have a single chunk.
708                if chunk_size is None or chunk_size >= out_vec_size:
709                    if chunk_size == 1:  # and out_vec_size == 1
710                        # Since we squeezed the output dim
711                        flat_results = tree_map(
712                            lambda t: torch.unsqueeze(t, 0), flat_results
713                        )
714                    return flat_results
715
716                for r, sr in zip(flat_results, stacked_results):
717                    sr[idx * chunk_size : (idx + 1) * chunk_size].copy_(r)
718
719            return stacked_results
720
721        if _preallocate_and_copy:
722            flat_jacobians_per_input = compute_jacobian_preallocate_and_copy()
723        else:
724            flat_jacobians_per_input = compute_jacobian_stacked()
725
726        # Step 2: The returned jacobian is one big tensor per input. In this step,
727        # we split each Tensor by output.
728        flat_jacobians_per_input = [
729            result.split(flat_output_numels, dim=0)
730            for result in flat_jacobians_per_input
731        ]
732        flat_input_flat_output = [
733            tuple(
734                split.view(out.shape + primal.shape)
735                for split, out in zip(splits, flat_output)
736            )
737            for splits, primal in zip(flat_jacobians_per_input, flat_primals)
738        ]
739
740        # Step 3: Right now, `jacobian` is a List[List[Tensor]].
741        # The outer List corresponds to the number of primals,
742        # the inner List corresponds to the number of outputs.
743        # We need to:
744        # a. Exchange the order of the outer List and inner List
745        # b. tree_unflatten the inner Lists (which correspond to the primals)
746        # c. handle the argnums=int case
747        # d. tree_unflatten the outer List (which corresponds to the outputs)
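        #
        # Shape-level sketch (hypothetical shapes): for argnums=(0, 1) with
        # primals of shapes (2,) and (3,) and outputs of shapes (4,) and (5,),
        # flat_input_flat_output is
        #   [ ((4, 2), (5, 2)),    # Jacobians w.r.t. the first primal
        #     ((4, 3), (5, 3)) ]   # Jacobians w.r.t. the second primal
        # and the zip below regroups it per output:
        #   ( ((4, 2), (4, 3)), ((5, 2), (5, 3)) )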
748        flat_output_flat_input = tuple(zip(*flat_input_flat_output))
749
750        flat_output_input = tuple(
751            tree_unflatten(flat_input, primals_spec)
752            for flat_input in flat_output_flat_input
753        )
754
755        if isinstance(argnums, int):
756            flat_output_input = tuple(
757                _safe_zero_index(flat_input) for flat_input in flat_output_input
758            )
759        output_input = tree_unflatten(flat_output_input, output_spec)
760        if has_aux:
761            return output_input, aux
762        return output_input
763
764    # Dynamo does not support HOP composition if their inner function is
765    # annotated with @functools.wraps(...). We circumvent this issue by applying
766    # wraps only if we're not tracing with dynamo.
767    if not torch._dynamo.is_compiling():
768        wrapper_fn = wraps(func)(wrapper_fn)
769
770    return wrapper_fn
771
772
773# NOTE: [Computing jacobian with vmap and vjp for multiple outputs]
774#
775# Let's consider f(x) = (x**2, x.sum()) and let x = torch.randn(3).
776# It turns out we can compute the jacobian of this function with a single
777# call to autograd.grad by using vmap over the correct grad_outputs.
778#
779# Firstly, one way to compute the jacobian is to stack x**2 and x.sum()
780# into a 4D vector. E.g., use g(x) = torch.stack([x**2, x.sum()])
781#
782# To get the first row of the jacobian, we call
783# >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([1, 0, 0, 0]))
784# To get the 2nd row of the jacobian, we call
785# >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([0, 1, 0, 0]))
786# and so on.
787#
788# Using vmap, we can vectorize all 4 of these computations into one by
789# passing the standard basis for R^4 as the grad_output.
790# vmap(partial(autograd.grad, g(x), x))(torch.eye(4)).
791#
792# Now, how do we compute the jacobian *without stacking the output*?
793# We can just split the standard basis across the outputs. So to
794# compute the jacobian of f(x), we'd use
795# >>> autograd.grad(f(x), x, grad_outputs=_construct_standard_basis_for(...))
796# The grad_outputs looks like the following:
797# ( torch.tensor([[1, 0, 0],
798#                 [0, 1, 0],
799#                 [0, 0, 1],
800#                 [0, 0, 0]]),
801#   torch.tensor([[0],
802#                 [0],
803#                 [0],
804#                 [1]]) )
805#
806# But we're not done yet!
807# >>> vmap(partial(autograd.grad(f(x), x, grad_outputs=...)))
808# returns a Tensor of shape [4, 3]. We have to remember to split the
809# jacobian of shape [4, 3] into two:
810# - one of shape [3, 3] for the first output
811# - one of shape [   3] for the second output
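#
# A minimal end-to-end sketch of the idea above (illustrative only):
# >>> x = torch.randn(3)
# >>> _, vjp_fn = torch.func.vjp(lambda x: (x ** 2, x.sum()), x)
# >>> basis = (torch.eye(4)[:, :3], torch.eye(4)[:, 3])
# >>> (jac_rows,) = vmap(vjp_fn)(basis)  # shape [4, 3]
# >>> # rows 0-2 are the [3, 3] jacobian of x**2; row 3 is the [3] jacobian of x.sum()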
812
813
814def _chunked_standard_basis_for_(tensors, tensor_numels, chunk_size=None):
815    # This function:
816    # - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix.
817    # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`.
818    # - Each chunk corresponds to one tensor. The chunk has the same dtype and
819    #   device as the tensor
820    #
821    # For example, with tensor_numels = [1, 2, 1], this function returns:
822    # ( tensor([[1],     tensor([[0, 0],      tensor([[0],
823    #           [0],             [1, 0],              [0],
824    #           [0],             [0, 1],              [0],
825    #           [0]])  ,         [0, 0]])  ,          [1]])  )
826    #
827    # Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors)
828    # Precondition: tensors always has at least one element.
829    #
830    # See NOTE: [Computing jacobian with vmap and vjp for multiple outputs]
831    # for context behind this function.
832    # NOTE: Argument `chunk_size` is used to generate chunked basis instead of
833    #       one huge basis matrix. `chunk_size` dictates the maximum size of the
834    #       basis matrix along dim=0.
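    #
    # Illustrative sketch (hypothetical values): with two tensors of numels
    # (2, 1) and chunk_size=2, this yields two chunks of basis rows with shapes
    #   ((2, 2), (2, 1)) for rows 0-1 and ((1, 2), (1, 1)) for row 2,
    # assuming get_chunk_sizes splits the total of 3 rows as [2, 1].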
835    assert len(tensors) == len(tensor_numels)
836    assert len(tensors) > 0
837    assert chunk_size is None or chunk_size > 0
838    total_numel = sum(tensor_numels)
839    if chunk_size and chunk_size < total_numel:
840        chunk_numels = get_chunk_sizes(total_numel, chunk_size)
841    else:  # chunk_size is None or chunk_size >= total_numel
842        chunk_size = total_numel
843        chunk_numels = [total_numel]
844
845    diag_start_indices = (
846        0,
847        *torch.tensor(tensor_numels).cumsum(dim=0)[:-1].neg().unbind(),
848    )
849
850    for chunk_idx, total_numel in enumerate(chunk_numels):
851        chunks = tuple(
852            tensor.new_zeros(total_numel, tensor_numel)
853            for tensor, tensor_numel in zip(tensors, tensor_numels)
854        )
855
856        for chunk, diag_start_idx in zip(chunks, diag_start_indices):
857            chunk.diagonal(diag_start_idx + chunk_idx * chunk_size).fill_(1)
858        chunks = tuple(
859            chunk.view(total_numel, *tensor.shape)
860            for chunk, tensor in zip(chunks, tensors)
861        )
862        yield chunks
863
864
865def _construct_standard_basis_for(tensors, tensor_numels):
866    for basis in _chunked_standard_basis_for_(tensors, tensor_numels, chunk_size=None):
867        return basis
868
869
870def _validate_and_wrap_argnum(argnum, num_args):
871    if not isinstance(argnum, int):
872        raise RuntimeError(f"argnum must be int, got: {type(argnum)}")
873    if argnum >= 0 and argnum < num_args:
874        return argnum
875    if argnum < 0 and argnum >= -num_args:
876        return argnum + num_args
877    raise RuntimeError(f"Got argnum={argnum}, but only {num_args} positional inputs")
878
879
880def _check_unique_non_empty(argnums):
881    if isinstance(argnums, tuple):
882        if len(argnums) == 0:
883            raise RuntimeError("argnums must be non-empty")
884        if len(set(argnums)) != len(argnums):
885            raise RuntimeError(f"argnums elements must be unique, got {argnums}")
886
887
888def _replace_args(old_args, new_args, argnums):
889    if isinstance(argnums, int):
890        if len(new_args) != 1:
891            raise RuntimeError(
892                f"new_args should be of size 1, was of size {len(new_args)}"
893            )
894        return tuple(
895            new_args[0] if i == argnums else old_args[i] for i in range(len(old_args))
896        )
897    if isinstance(argnums, tuple):
898        if len(new_args) != len(argnums):
899            raise RuntimeError(
900                "new_args should have the same size as argnums. "
901                f"Argnums size {len(argnums)}, new_args size {len(new_args)}"
902            )
903
904        def get_right_elem(i):
905            return new_args[argnums.index(i)] if i in argnums else old_args[i]
906
907        return tuple(get_right_elem(i) for i in range(len(old_args)))
908    raise RuntimeError(f"argnums must be int or Tuple[int, ...], got: {type(argnums)}")
909
910
911def _validate_and_wrap_argnums(argnums, num_args):
912    if isinstance(argnums, int):
913        return _validate_and_wrap_argnum(argnums, num_args)
914    if isinstance(argnums, tuple):
915        return tuple(_validate_and_wrap_argnum(argnum, num_args) for argnum in argnums)
916    raise AssertionError("Should never get here")
917
918
919def _slice_argnums(args, argnums, as_tuple=True):
920    if not isinstance(argnums, int) and not isinstance(argnums, tuple):
921        raise RuntimeError(
922            f"argnums must be int or Tuple[int, ...], got: {type(argnums)}"
923        )
924    argnums = _validate_and_wrap_argnums(argnums, len(args))
925    _check_unique_non_empty(argnums)
926    if isinstance(argnums, int):
927        if as_tuple:
928            return (args[argnums],)
929        else:
930            return args[argnums]
931    return tuple(args[i] for i in argnums)
932
933
934JVP_NESTING = 0
935
936
937def assert_flat_tuple_of_tensors(elts: Any, api: str, argname: str) -> None:
938    if not isinstance(elts, tuple):
939        raise RuntimeError(
940            f"{api}: Expected {argname} to be a tuple of Tensors, got {type(elts)}"
941        )
942    for elt in elts:
943        if isinstance(elt, torch.Tensor):
944            continue
945        raise RuntimeError(
946            f"{api}: Expected {argname} to be a tuple of Tensors, got "
947            f"a tuple with an element of type {type(elt)}"
948        )
949    if len(elts) == 0:
950        raise RuntimeError(
951            f"{api}: Expected {argname} to be a non-empty tuple of Tensors."
952        )
953
954
955def assert_non_empty_tensor_output(output: List[Any], api: str) -> None:
956    if (len(output) == 1 and output[0] is None) or len(output) < 1:
957        raise RuntimeError(
958            f"{api}: Expected f to be a function that has non-empty output (got output = {output})"
959        )
960    for o in output:
961        if not isinstance(o, torch.Tensor):
962            raise RuntimeError(
963                f"{api}: expected f(*primals) to return only tensors"
964                f", got unsupported type {type(o)}"
965            )
966
967
968def assert_output_is_tensor_or_tensors(output: Any, api: str) -> None:
969    if isinstance(output, torch.Tensor):
970        return
971    if not isinstance(output, tuple):
972        raise RuntimeError(
973            f"{api}: Expected output of f to be a Tensor or Tensors, got "
974            f"{type(output)}"
975        )
976    if len(output) == 0:
977        raise RuntimeError(
978            f"{api}: Expected output of f to be a non-empty tuple of Tensors."
979        )
980    for out in output:
981        if isinstance(out, torch.Tensor):
982            continue
983        raise RuntimeError(
984            f"{api}: Expected output of f to be a Tensor or Tensors, got "
985            f"{type(out)} as an output"
986        )
987
988
989def assert_non_empty_list_of_tensors(
990    output: List[torch.Tensor], api: str, argname: str
991) -> None:
992    if len(output) == 0:
993        raise RuntimeError(f"{api}: Expected {argname} to contain at least one Tensor.")
994    for out in output:
995        if isinstance(out, torch.Tensor):
996            continue
997        raise RuntimeError(
998            f"{api}: Expected {argname} to only contain Tensors, got " f"{type(out)}"
999        )
1000
1001
1002jvp_str = "jvp(f, primals, tangents)"
1003
1004
1005def safe_unpack_dual(dual, strict):
1006    if not isinstance(dual, torch.Tensor):
1007        raise RuntimeError(
1008            f"{jvp_str}: expected f(*args) to return only tensors"
1009            f", got unsupported type {type(dual)}"
1010        )
1011
1012    primal, tangent = fwAD.unpack_dual(dual)
1013    if tangent is None:
1014        if strict:
1015            raise RuntimeError(
1016                "jvp(f, primals, tangents, strict=True): "
1017                "The output of f is independent of "
1018                "the inputs. This is not allowed with strict=True."
1019            )
1020        tangent = torch.zeros_like(primal)
1021    return primal, tangent
1022
1023
1024@exposed_in("torch.func")
1025def jvp(
1026    func: Callable,
1027    primals: Any,
1028    tangents: Any,
1029    *,
1030    strict: bool = False,
1031    has_aux: bool = False,
1032):
1033    """
1034    Standing for the Jacobian-vector product, returns a tuple containing
1035    the output of ``func(*primals)`` and the "Jacobian of ``func`` evaluated at
1036    ``primals``" times ``tangents``. This is also known as forward-mode autodiff.
1037
1038    Args:
1039        func (function): A Python function that takes one or more arguments,
1040            one of which must be a Tensor, and returns one or more Tensors
1041        primals (Tensors): Positional arguments to ``func`` that must all be
1042            Tensors. The Jacobian-vector product will be computed with
1043            respect to these arguments.
1044        tangents (Tensors): The "vector" for which Jacobian-vector-product is
1045            computed. Must be the same structure and sizes as the inputs to
1046            ``func``.
1047        has_aux (bool): Flag indicating that ``func`` returns a
1048            ``(output, aux)`` tuple where the first element is the output of
1049            the function to be differentiated and the second element is
1050            other auxiliary objects that will not be differentiated.
1051            Default: False.
1052
1053    Returns:
1054        Returns a ``(output, jvp_out)`` tuple containing the output of ``func``
1055        evaluated at ``primals`` and the Jacobian-vector product.
1056        If ``has_aux is True``, then instead returns a ``(output, jvp_out, aux)`` tuple.
1057
1058    .. note::
1059        You may see this API error out with "forward-mode AD not implemented
1060        for operator X". If so, please file a bug report and we will prioritize it.
1061
1062    jvp is useful when you wish to compute gradients of a function R^1 -> R^N
1063
1064        >>> from torch.func import jvp
1065        >>> x = torch.randn([])
1066        >>> f = lambda x: x * torch.tensor([1., 2., 3])
1067        >>> value, grad = jvp(f, (x,), (torch.tensor(1.),))
1068        >>> assert torch.allclose(value, f(x))
1069        >>> assert torch.allclose(grad, torch.tensor([1., 2, 3]))
1070
1071    :func:`jvp` can support functions with multiple inputs by passing in the
1072    tangents for each of the inputs
1073
1074        >>> from torch.func import jvp
1075        >>> x = torch.randn(5)
1076        >>> y = torch.randn(5)
1077        >>> f = lambda x, y: (x * y)
1078        >>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
1079        >>> assert torch.allclose(output, x + y)
1080
1081    """
1082
1083    return _jvp_with_argnums(
1084        func, primals, tangents, argnums=None, strict=strict, has_aux=has_aux
1085    )
1086
1087
1088def _jvp_with_argnums(
1089    func: Callable,
1090    primals: Any,
1091    tangents: Any,
1092    argnums: Optional[argnums_t],
1093    *,
1094    strict: bool = False,
1095    has_aux: bool,
1096):
1097    # This is the same function as jvp but also accepts an argnums argument
1098    # Most args are the same as jvp except for the added argument
1099    # argnums (Optional[int or tuple[int]]): Optional, specifies the argument(s) to compute gradients with respect to.
1100    #         If None, computes the gradients with respect to all inputs (used for jvp). Default: None
1101    # Because of this, tangents must have one entry per argnum, each matching up with the corresponding primal
1102    # whose index is given by argnums
1103    #
1104    # WARN: Users should NOT call this function directly and should just be calling jvp.
1105    # It is only separated so that inputs passed to jacfwd but not differentiated get the correct wrappers.
1106    #
1107    # NOTE: All error messages are produced as if jvp was being called, even if this was called by jacfwd
1108    #
1109    # Returns the same two elements as :func:`jvp` but the returned tuple, ``jvp_out``, only has JVPs with respect to
1110    # the primals given by argnums
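    #
    # Illustrative sketch (hypothetical call): with argnums=1,
    #   _jvp_with_argnums(f, (x, y), (ty,), argnums=1, has_aux=False)
    # takes a single tangent ty for y; x is held fixed, so jvp_out is the
    # directional derivative of f's outputs along ty only.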
1111    if not isinstance(primals, tuple):
1112        raise RuntimeError(
1113            f"{jvp_str}: Expected primals to be a tuple. "
1114            f"E.g. it should be valid to call f(*primals)."
1115        )
1116    diff_args = primals if argnums is None else _slice_argnums(primals, argnums)
1117    flat_primals, primals_spec = tree_flatten(diff_args)
1118    flat_tangents, tangents_spec = tree_flatten(tangents)
1119    _jvp_treespec_compare(diff_args, tangents)
1120    assert_non_empty_list_of_tensors(flat_primals, jvp_str, "primals")
1121    assert_non_empty_list_of_tensors(flat_tangents, jvp_str, "tangents")
1122
1123    global JVP_NESTING
1124
1125    with jvp_increment_nesting() as level:
1126        with fwAD._set_fwd_grad_enabled(True):
1127            ctx = fwAD.dual_level if JVP_NESTING == 1 else contextlib.nullcontext
1128            with ctx():
1129                flat_duals = tuple(
1130                    fwAD.make_dual(p, t) for p, t in zip(flat_primals, flat_tangents)
1131                )
1132                duals = tree_unflatten(flat_duals, primals_spec)
1133                # Note for the reviewer: This is extremely odd but it passes the
1134                # assertion "len(self.block_stack) == 1" on symbolic_convert.py
1135                # The equivalent "if argnums is not None" fails for some reason
1136                if isinstance(argnums, (int, tuple)):
1137                    primals = _wrap_all_tensors(primals, level)
1138                    duals = _replace_args(primals, duals, argnums)
1139                result_duals = func(*duals)
1140                if has_aux:
1141                    if not (isinstance(result_duals, tuple) and len(result_duals) == 2):
1142                        raise RuntimeError(
1143                            f"{jvp_str}: output of function f should be a tuple: (output, aux) "
1144                            "if has_aux is True"
1145                        )
1146                    result_duals, aux = result_duals
1147                    aux = _undo_create_differentiable(aux, level)
1148
1149                result_duals, spec = tree_flatten(result_duals)
1150                assert_non_empty_tensor_output(result_duals, jvp_str)
1151
1152                primals_out, tangents_out = zip(
1153                    *[safe_unpack_dual(dual, strict) for dual in result_duals]
1154                )
1155                primals_out = tree_map(
1156                    partial(_undo_create_differentiable, level=level), primals_out
1157                )
1158                tangents_out = tree_map(
1159                    partial(_undo_create_differentiable, level=level), tangents_out
1160                )
1161
1162                primals_out_unflatten = tree_unflatten(primals_out, spec)
1163                tangents_out_unflatten = tree_unflatten(tangents_out, spec)
1164                if has_aux:
1165                    return primals_out_unflatten, tangents_out_unflatten, aux
1166
1167                return primals_out_unflatten, tangents_out_unflatten
1168
1169
1170def safe_unflatten(tensor, dim, shape):
1171    if len(shape) == 0:
1172        assert tensor.shape[dim] == 1
1173        return tensor.squeeze(dim)
1174    return tensor.unflatten(dim, shape)
1175
1176
1177@exposed_in("torch.func")
1178def jacfwd(
1179    func: Callable,
1180    argnums: argnums_t = 0,
1181    has_aux: bool = False,
1182    *,
1183    randomness: str = "error",
1184):
1185    """
1186    Computes the Jacobian of ``func`` with respect to the arg(s) at index
1187    ``argnums`` using forward-mode autodiff.
1188
1189    Args:
1190        func (function): A Python function that takes one or more arguments,
1191            one of which must be a Tensor, and returns one or more Tensors
1192        argnums (int or Tuple[int]): Optional, integer or tuple of integers,
1193            saying which arguments to get the Jacobian with respect to.
1194            Default: 0.
1195        has_aux (bool): Flag indicating that ``func`` returns a
1196            ``(output, aux)`` tuple where the first element is the output of
1197            the function to be differentiated and the second element is
1198            auxiliary objects that will not be differentiated.
1199            Default: False.
1200        randomness (str): Flag indicating what type of randomness to use.
1201            See :func:`vmap` for more detail. Allowed: "different", "same", "error".
1202            Default: "error"
1203
1204    Returns:
1205        Returns a function that takes in the same inputs as ``func`` and
1206        returns the Jacobian of ``func`` with respect to the arg(s) at
1207        ``argnums``. If ``has_aux is True``, then the returned function
1208        instead returns a ``(jacobian, aux)`` tuple where ``jacobian``
1209        is the Jacobian and ``aux`` is auxiliary objects returned by ``func``.
1210
1211    .. note::
1212        You may see this API error out with "forward-mode AD not implemented
1213        for operator X". If so, please file a bug report and we will prioritize it.
1214        An alternative is to use :func:`jacrev`, which has better operator coverage.
1215
1216    A basic usage with a pointwise, unary operation will give a diagonal array
1217    as the Jacobian
1218
1219        >>> from torch.func import jacfwd
1220        >>> x = torch.randn(5)
1221        >>> jacobian = jacfwd(torch.sin)(x)
1222        >>> expected = torch.diag(torch.cos(x))
1223        >>> assert torch.allclose(jacobian, expected)
1224
1225    :func:`jacfwd` can be composed with vmap to produce batched
1226    Jacobians:
1227
1228        >>> from torch.func import jacfwd, vmap
1229        >>> x = torch.randn(64, 5)
1230        >>> jacobian = vmap(jacfwd(torch.sin))(x)
1231        >>> assert jacobian.shape == (64, 5, 5)
1232
1233    If you would like to compute the output of the function as well as the
1234    jacobian of the function, use the ``has_aux`` flag to return the output
1235    as an auxiliary object:
1236
1237        >>> from torch.func import jacfwd
1238        >>> x = torch.randn(5)
1239        >>>
1240        >>> def f(x):
1241        >>>   return x.sin()
1242        >>>
1243        >>> def g(x):
1244        >>>   result = f(x)
1245        >>>   return result, result
1246        >>>
1247        >>> jacobian_f, f_x = jacfwd(g, has_aux=True)(x)
1248        >>> assert torch.allclose(f_x, f(x))
1249
1250    Additionally, :func:`jacfwd` can be composed with itself or :func:`jacrev`
1251    to produce Hessians
1252
1253        >>> from torch.func import jacfwd, jacrev
1254        >>> def f(x):
1255        >>>   return x.sin().sum()
1256        >>>
1257        >>> x = torch.randn(5)
1258        >>> hessian = jacfwd(jacrev(f))(x)
1259        >>> assert torch.allclose(hessian, torch.diag(-x.sin()))
1260
1261    By default, :func:`jacfwd` computes the Jacobian with respect to the first
1262    input. However, it can compute the Jacobian with respect to a different
1263    argument by using ``argnums``:
1264
1265        >>> from torch.func import jacfwd
1266        >>> def f(x, y):
1267        >>>   return x + y ** 2
1268        >>>
1269        >>> x, y = torch.randn(5), torch.randn(5)
1270        >>> jacobian = jacfwd(f, argnums=1)(x, y)
1271        >>> expected = torch.diag(2 * y)
1272        >>> assert torch.allclose(jacobian, expected)
1273
1274    Additionally, passing a tuple to ``argnums`` will compute the Jacobian
1275    with respect to multiple arguments
1276
1277        >>> from torch.func import jacfwd
1278        >>> def f(x, y):
1279        >>>   return x + y ** 2
1280        >>>
1281        >>> x, y = torch.randn(5), torch.randn(5)
1282        >>> jacobian = jacfwd(f, argnums=(0, 1))(x, y)
1283        >>> expectedX = torch.diag(torch.ones_like(x))
1284        >>> expectedY = torch.diag(2 * y)
1285        >>> assert torch.allclose(jacobian[0], expectedX)
1286        >>> assert torch.allclose(jacobian[1], expectedY)
1287
1288    """
1289
1290    def wrapper_fn(*args):
1291        error_if_complex("jacfwd", args, is_input=True)
1292        primals = args if argnums is None else _slice_argnums(args, argnums)
1293        flat_primals, primals_spec = tree_flatten(primals)
1294        flat_primals_numels = tuple(p.numel() for p in flat_primals)
1295        flat_basis = _construct_standard_basis_for(flat_primals, flat_primals_numels)
1296        basis = tree_unflatten(flat_basis, primals_spec)
1297
1298        def push_jvp(basis):
1299            output = _jvp_with_argnums(
1300                func, args, basis, argnums=argnums, has_aux=has_aux
1301            )
1302            # output[0] is the output of `func(*args)`
1303            error_if_complex("jacfwd", output[0], is_input=False)
1304            if has_aux:
1305                _, jvp_out, aux = output
1306                return jvp_out, aux
1307            _, jvp_out = output
1308            return jvp_out
1309
1310        results = vmap(push_jvp, randomness=randomness)(basis)
1311        if has_aux:
1312            results, aux = results
1313            # aux is replicated along the vmapped (standard-basis) dimension,
1314            # so we take the first element to recover the original `func` aux output
1315            flat_aux, aux_spec = tree_flatten(aux)
1316            flat_aux = [value[0] for value in flat_aux]
1317            aux = tree_unflatten(flat_aux, aux_spec)
1318
1319        jac_outs, spec = tree_flatten(results)
1320        # The output check below most likely can never raise an error,
1321        # as jvp already validates the output
1322        # assert_non_empty_output(jac_outs, 'jacfwd(f, ...)(*args)')
1323
1324        jac_outs_ins = tuple(
1325            tuple(
1326                safe_unflatten(jac_out_in, -1, primal.shape)
1327                for primal, jac_out_in in zip(
1328                    flat_primals,
1329                    jac_out.movedim(0, -1).split(flat_primals_numels, dim=-1),
1330                )
1331            )
1332            for jac_out in jac_outs
1333        )
1334        jac_outs_ins = tuple(
1335            tree_unflatten(jac_ins, primals_spec) for jac_ins in jac_outs_ins
1336        )
1337
1338        if isinstance(argnums, int):
1339            jac_outs_ins = tuple(jac_ins[0] for jac_ins in jac_outs_ins)
1340        if has_aux:
1341            return tree_unflatten(jac_outs_ins, spec), aux
1342        return tree_unflatten(jac_outs_ins, spec)
1343
1344    # Dynamo does not support HOP composition if their inner function is
1345    # annotated with @functools.wraps(...). We circumvent this issue by applying
1346    # wraps only if we're not tracing with dynamo.
1347    if not torch._dynamo.is_compiling():
1348        wrapper_fn = wraps(func)(wrapper_fn)
1349
1350    return wrapper_fn
1351
1352
1353@exposed_in("torch.func")
1354def hessian(func, argnums=0):
1355    """
1356    Computes the Hessian of ``func`` with respect to the arg(s) at index
1357    ``argnums`` via a forward-over-reverse strategy.
1358
1359    The forward-over-reverse strategy (composing ``jacfwd(jacrev(func))``) is
1360    a good default for performance. It is possible to compute Hessians
1361    through other compositions of :func:`jacfwd` and :func:`jacrev` like
1362    ``jacfwd(jacfwd(func))`` or ``jacrev(jacrev(func))``.
1363
1364    Args:
1365        func (function): A Python function that takes one or more arguments,
1366            one of which must be a Tensor, and returns one or more Tensors
1367        argnums (int or Tuple[int]): Optional, integer or tuple of integers,
1368            saying which arguments to get the Hessian with respect to.
1369            Default: 0.
1370
1371    Returns:
1372        Returns a function that takes in the same inputs as ``func`` and
1373        returns the Hessian of ``func`` with respect to the arg(s) at
1374        ``argnums``.
1375
1376    .. note::
1377        You may see this API error out with "forward-mode AD not implemented
1378        for operator X". If so, please file a bug report and we will prioritize it.
1379        An alternative is to use ``jacrev(jacrev(func))``, which has better
1380        operator coverage.
1381
1382    A basic usage with an R^N -> R^1 function gives an N x N Hessian:
1383
1384        >>> from torch.func import hessian
1385        >>> def f(x):
1386        >>>   return x.sin().sum()
1387        >>>
1388        >>> x = torch.randn(5)
1389        >>> hess = hessian(f)(x)  # equivalent to jacfwd(jacrev(f))(x)
1390        >>> assert torch.allclose(hess, torch.diag(-x.sin()))
1391
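    As a minimal sketch of the ``jacrev(jacrev(func))`` alternative mentioned in
    the note above (reusing ``f``, ``x`` and ``hess`` from the example):

        >>> from torch.func import jacrev
        >>> hess_rev_rev = jacrev(jacrev(f))(x)  # reverse-over-reverse alternative
        >>> assert torch.allclose(hess_rev_rev, hess)
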
1392    """
1393    return jacfwd(jacrev(func, argnums), argnums)
1394
1395
1396@doesnt_support_saved_tensors_hooks
1397def grad_and_value_impl(func, argnums, has_aux, args, kwargs) -> Callable:
1398    with grad_increment_nesting() as level:
1399        output, aux, grad_input = None, None, None
1400        # See NOTE [grad and vjp interaction with no_grad]
1401        with torch.enable_grad():
1402            args = _wrap_all_tensors(args, level)
1403            kwargs = _wrap_all_tensors(kwargs, level)
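            # Only the args selected by `argnums` are marked as differentiable
            # (requires_grad) at this grad level; everything else is treated as
            # constant.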
1404            diff_args = _slice_argnums(args, argnums, as_tuple=False)
1405            tree_map_(partial(_create_differentiable, level=level), diff_args)
1406
1407            output = func(*args, **kwargs)
1408            if has_aux:
1409                if not (isinstance(output, tuple) and len(output) == 2):
1410                    raise RuntimeError(
1411                        "grad_and_value(f)(*args): output of function f should be a tuple: (output, aux) "
1412                        "if has_aux is True"
1413                    )
1414                output, aux = output
1415
1416            if not isinstance(output, torch.Tensor):
1417                raise RuntimeError(
1418                    "grad_and_value(f)(*args): Expected f(*args) "
1419                    f"to return a Tensor, got {type(output)}"
1420                )
1421            if output.dim() != 0:
1422                raise RuntimeError(
1423                    "grad_and_value(f)(*args): Expected f(*args) "
1424                    "to return a scalar Tensor, got tensor with "
1425                    f"{output.dim()} dims. Maybe you wanted to "
1426                    "use the vjp or jacrev APIs instead?"
1427                )
1428
1429            flat_diff_args, spec = tree_flatten(diff_args)
1430
1431            # NB: need create_graph so that backward pass isn't run in no_grad mode
1432            flat_outputs = _as_tuple(output)
1433            flat_grad_input = _autograd_grad(
1434                flat_outputs, flat_diff_args, create_graph=True
1435            )
1436            grad_input = tree_unflatten(flat_grad_input, spec)
1437
1438            grad_input = _undo_create_differentiable(grad_input, level)
1439            output = _undo_create_differentiable(output, level)
1440            if has_aux:
1441                aux = _undo_create_differentiable(aux, level)
1442
1443        if has_aux:
1444            return grad_input, (output, aux)
1445        return grad_input, output
1446
1447
1448def grad_impl(func: Callable, argnums: argnums_t, has_aux: bool, args, kwargs):
1449    results = grad_and_value_impl(func, argnums, has_aux, args, kwargs)
1450    if has_aux:
1451        grad, (_, aux) = results
1452        return grad, aux
1453    grad, _ = results
1454    return grad
1455
1456
1457def _maybe_wrap_functional_tensor(
1458    maybe_tensor, level, *, _python_functionalize: bool = False
1459):
1460    if not isinstance(maybe_tensor, torch.Tensor):
1461        return maybe_tensor
1462    wrapped = _wrap_functional_tensor(maybe_tensor, level)
1463    _assert_wrapped_functional(maybe_tensor, wrapped)
1464    if _python_functionalize:
1465        out = FunctionalTensor(wrapped)
1466        torch._mirror_autograd_meta_to(maybe_tensor, out)
1467        return out
1468    return wrapped
1469
1470
1471def _wrap_all_tensors_to_functional(
1472    tensor_pytree, level, *, _python_functionalize: bool = False
1473):
1474    return tree_map(
1475        lambda x: _maybe_wrap_functional_tensor(
1476            x, level, _python_functionalize=_python_functionalize
1477        ),
1480        tensor_pytree,
1481    )
1482
1483
1484def _maybe_unwrap_functional_tensor(maybe_tensor, *, reapply_views: bool):
1485    if not isinstance(maybe_tensor, torch.Tensor):
1486        return maybe_tensor
1487    if isinstance(maybe_tensor, FunctionalTensor):
1488        maybe_tensor = maybe_tensor.elem
1489
1490    if not torch._is_functional_tensor(maybe_tensor):
1491        # If it's not a functional tensor, just return it.
1492        # This can happen if we functionalize a fn that returns a global,
1493        # which was never wrapped properly.
1494        return maybe_tensor
1495    # Sync any pending updates on the output tensor
1496    torch._sync(maybe_tensor)
1497    return _unwrap_functional_tensor(maybe_tensor, reapply_views)
1498
1499
1500def _unwrap_all_tensors_from_functional(tensor_pytree, *, reapply_views: bool):
1501    return tree_map(
1502        lambda t: _maybe_unwrap_functional_tensor(t, reapply_views=reapply_views),
1503        tensor_pytree,
1504    )
1505
1506
1507@exposed_in("torch.func")
1508def functionalize(func: Callable, *, remove: str = "mutations") -> Callable:
1509    """
1510    functionalize is a transform that can be used to remove (intermediate)
1511    mutations and aliasing from a function, while preserving the function's
1512    semantics.
1513
1514    ``functionalize(func)`` returns a new function with the same semantics
1515    as ``func``, but with all intermediate mutations removed.
1516    Every inplace operation performed on an intermediate tensor:
1517    ``intermediate.foo_()``
1518    gets replaced by its out-of-place equivalent:
1519    ``intermediate_updated = intermediate.foo()``.
1520
1521    functionalize is useful for shipping a pytorch program off to
1522    backends or compilers that aren't able to easily represent
1523    mutations or aliasing operators.
1524
1525    Args:
1526        func (Callable): A Python function that takes one or more arguments.
1527        remove (str): An optional string argument, that takes on either
1528            the value 'mutations' or 'mutations_and_views'.
1529            If 'mutations' is passed in then all mutating operators
1530            will be replaced with their non-mutating equivalents.
1531            If 'mutations_and_views' is passed in, then additionally, all aliasing
1532            operators will be replaced with their non-aliasing equivalents.
1533            Default: 'mutations'.
1534
1535    Returns:
1536        Returns a new "functionalized" function. It takes the same inputs as
1537        ``func``, and has the same behavior, but any mutations
1538        (and optionally aliasing) performed on intermediate tensors
1539        in the function will be removed.
1540
1541    functionalize will also remove mutations (and views) that were performed on function inputs.
1542    However, to preserve semantics, functionalize will "fix up" the mutations after
1543    the transform has finished running, by detecting if any tensor inputs "should have"
1544    been mutated, and copying the new data back to the inputs if necessary.
1545
1546
1547    Example::
1548
1549        >>> # xdoctest: +SKIP
1550        >>> import torch
1551        >>> from torch.fx.experimental.proxy_tensor import make_fx
1552        >>> from torch.func import functionalize
1553        >>>
1554        >>> # A function that uses mutations and views, but only on intermediate tensors.
1555        >>> def f(a):
1556        ...     b = a + 1
1557        ...     c = b.view(-1)
1558        ...     c.add_(1)
1559        ...     return b
1560        ...
1561        >>> inpt = torch.randn(2)
1562        >>>
1563        >>> out1 = f(inpt)
1564        >>> out2 = functionalize(f)(inpt)
1565        >>>
1566        >>> # semantics are the same (outputs are equivalent)
1567        >>> print(torch.allclose(out1, out2))
1568        True
1569        >>>
1570        >>> f_traced = make_fx(f)(inpt)
1571        >>> f_no_mutations_traced = make_fx(functionalize(f))(inpt)
1572        >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
1573        >>>
1574        >>> print(f_traced.code)
1575
1576
1577
1578        def forward(self, a_1):
1579            add = torch.ops.aten.add(a_1, 1);  a_1 = None
1580            view = torch.ops.aten.view(add, [-1])
1581            add_ = torch.ops.aten.add_(view, 1);  view = None
1582            return add
1583
1584        >>> print(f_no_mutations_traced.code)
1585
1586
1587
1588        def forward(self, a_1):
1589            add = torch.ops.aten.add(a_1, 1);  a_1 = None
1590            view = torch.ops.aten.view(add, [-1]);  add = None
1591            add_1 = torch.ops.aten.add(view, 1);  view = None
1592            view_1 = torch.ops.aten.view(add_1, [2]);  add_1 = None
1593            return view_1
1594
1595        >>> print(f_no_mutations_and_views_traced.code)
1596
1597
1598
1599        def forward(self, a_1):
1600            add = torch.ops.aten.add(a_1, 1);  a_1 = None
1601            view_copy = torch.ops.aten.view_copy(add, [-1]);  add = None
1602            add_1 = torch.ops.aten.add(view_copy, 1);  view_copy = None
1603            view_copy_1 = torch.ops.aten.view_copy(add_1, [2]);  add_1 = None
1604            return view_copy_1
1605
1606
1607        >>> # A function that mutates its input tensor
1608        >>> def f(a):
1609        ...     b = a.view(-1)
1610        ...     b.add_(1)
1611        ...     return a
1612        ...
1613        >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
1614        >>> #
1615        >>> # All mutations and views have been removed,
1616        >>> # but there is an extra copy_ in the graph to correctly apply the mutation to the input
1617        >>> # after the function has completed.
1618        >>> print(f_no_mutations_and_views_traced.code)
1619
1620
1621
1622        def forward(self, a_1):
1623            view_copy = torch.ops.aten.view_copy(a_1, [-1])
1624            add = torch.ops.aten.add(view_copy, 1);  view_copy = None
1625            view_copy_1 = torch.ops.aten.view_copy(add, [2]);  add = None
1626            copy_ = torch.ops.aten.copy_(a_1, view_copy_1);  a_1 = None
1627            return view_copy_1
1628
1629
1630    There are a few "failure modes" for functionalize that are worth calling out:
1631      (1) Like other torch.func transforms, `functionalize()` doesn't work with functions
1632          that directly use `.backward()`. The same is true for torch.autograd.grad.
1633          If you want to use autograd, you can compute gradients directly
1634          with `functionalize(grad(f))`; a minimal sketch follows this list.
1635      (2) Like other torch.func transforms, `functionalize()` doesn't work with global state.
1636          If you call `functionalize(f)` on a function that takes views / mutations of
1637          non-local state, functionalization will simply no-op and pass the view/mutation
1638          calls directly to the backend.
1639          One way to work around this is to ensure that any non-local state creation
1640          is wrapped into a larger function, which you then call functionalize on.
1641      (3) `resize_()` has some limitations: functionalize will only work on programs
1642          that use `resize_()` as long as the tensor being resized is not a view.
1643      (4) `as_strided()` has some limitations: functionalize will not work on
1644          `as_strided()` calls that result in tensors with overlapping memory.
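
    As a minimal sketch for (1) (the function ``g`` below is only illustrative),
    gradients can be taken by composing the transforms instead of calling
    ``.backward()`` inside the function:

        >>> # xdoctest: +SKIP
        >>> from torch.func import functionalize, grad
        >>> def g(x):
        ...     y = x.clone()
        ...     y.mul_(2)
        ...     return y.sum()
        ...
        >>> x = torch.randn(3)
        >>> assert torch.allclose(functionalize(grad(g))(x), torch.full_like(x, 2.0))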
1645
1646
1647    Finally, a helpful mental model for understanding functionalization is that
1648    most user pytorch programs are written with the public torch API.
1649    When executed, torch operators are generally decomposed into
1650    our internal C++ "ATen" API.
1651    The logic for functionalization happens entirely at the level of ATen.
1652    Functionalization knows how to take every aliasing operator in ATen,
1653    and map it to its non-aliasing equivalent
1654    (e.g. ``tensor.view({-1})`` -> ``at::view_copy(tensor, {-1})``),
1655    and how to take every mutating operator in ATen,
1656    and map it to its non-mutating equivalent
1657    (e.g. ``tensor.add_(1)`` -> ``at::add(tensor, 1)``),
1658    while tracking aliases and mutations out-of-line to know when to fix things up.
1659    Information about which ATen operators are aliasing or mutating all comes from
1660    https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml.
1661    """
1662    if remove == "mutations":
1663        reapply_views = True
1664    elif remove == "mutations_and_views":
1665        reapply_views = False
1666    else:
1667        raise RuntimeError(
1668            f"functionalize(f, remove='mutations'): received invalid argument for remove={remove}."
1669            " Valid options are:\n"
1670            "     remove='mutations': all inplace and out= operators will be removed from the program, and replaced"
1671            " with their out-of-place equivalents.\n"
1672            "     remove='mutations_and_views': In addition to the above, all aliasing operators {view} will be"
1673            " replaced with their non-aliasing counterparts, {view}_copy.\n"
1674        )
1675
1676    @wraps(func)
1677    def wrapped(*args, **kwargs):
1678        try:
1679            func_level = _func_increment_nesting(reapply_views)
1680            func_args = _wrap_all_tensors_to_functional(args, func_level)
1681            func_kwargs = _wrap_all_tensors_to_functional(kwargs, func_level)
1682
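            # Keep the unwrapped and wrapped leaves side by side so that any input
            # mutations detected on the wrapped tensors can be copied back to the
            # original user tensors after `func` runs.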
1683            flattened_unwrapped_args = pytree.arg_tree_leaves(*args)
1684            flattened_wrapped_args = pytree.arg_tree_leaves(*func_args)
1685            flattened_unwrapped_kwargs = pytree.arg_tree_leaves(**kwargs)
1686            flattened_wrapped_kwargs = pytree.arg_tree_leaves(**func_kwargs)
1687
1688            func_outputs = func(*func_args, **func_kwargs)
1689            outputs = _unwrap_all_tensors_from_functional(
1690                func_outputs, reapply_views=reapply_views
1691            )
1692            flat_outputs, func_out_spec = tree_flatten(outputs)
1693
1694            for a in flattened_wrapped_args + flattened_wrapped_kwargs:
1695                if isinstance(a, torch.Tensor):
1696                    # Call sync_() on the inputs, to ensure that any pending mutations have been applied.
1697                    torch._sync(a)
1698
1699            # And if any mutations were applied to the inputs, we need to propagate them back to the user.
1700            for unwrapped, wrapped in zip(
1701                flattened_unwrapped_args, flattened_wrapped_args
1702            ):
1703                if isinstance(unwrapped, torch.Tensor) and isinstance(
1704                    wrapped, torch.Tensor
1705                ):
1706                    _propagate_functional_input_mutation(unwrapped, wrapped)
1707            for unwrapped, wrapped in zip(
1708                flattened_unwrapped_kwargs, flattened_wrapped_kwargs
1709            ):
1710                if isinstance(unwrapped, torch.Tensor) and isinstance(
1711                    wrapped, torch.Tensor
1712                ):
1713                    _propagate_functional_input_mutation(unwrapped, wrapped)
1714
1715            return outputs
1716        finally:
1717            _func_decrement_nesting()
1718
1719    return wrapped
1720
1721
1722@exposed_in("torch.func")
1723def linearize(func: Callable, *primals) -> Tuple[Any, Callable]:
1724    """
1725    Returns the value of ``func`` at ``primals`` and a linear approximation
1726    at ``primals``.
1727
1728    Args:
1729        func (Callable): A Python function that takes one or more arguments.
1730        primals (Tensors): Positional arguments to ``func`` that must all be
1731            Tensors. These are the values at which the function is linearly approximated.
1732
1733    Returns:
1734        Returns a ``(output, jvp_fn)`` tuple containing the output of ``func``
1735        applied to ``primals`` and a function that computes the jvp of
1736        ``func`` evaluated at ``primals``.
1737
1738    linearize is useful if jvp is to be computed multiple times at ``primals``. However,
1739    to achieve this, linearize saves intermediate computation and has higher memory requirements
1740    than directly applying `jvp`. So, if all the ``tangents`` are known, it may be more efficient
1741    to compute ``vmap(jvp)`` instead of using linearize.
1742
1743    .. note::
1744        linearize evaluates ``func`` twice. If you need an implementation with a
1745        single evaluation, please file an issue.
1746
1747    Example::
1748        >>> import torch
1749        >>> from torch.func import linearize
1750        >>> def fn(x):
1751        ...     return x.sin()
1752        ...
1753        >>> output, jvp_fn = linearize(fn, torch.zeros(3, 3))
1754        >>> jvp_fn(torch.ones(3, 3))
1755        tensor([[1., 1., 1.],
1756                [1., 1., 1.],
1757                [1., 1., 1.]])
1758        >>>
1759
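    Since the linearization is computed once, ``jvp_fn`` may then be reused for
    several tangents; a minimal sketch continuing the example above:

        >>> # Each call reuses the saved linearization instead of re-evaluating `fn`.
        >>> for t in (torch.ones(3, 3), 2 * torch.ones(3, 3)):
        ...     _ = jvp_fn(t)
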
1760    """
1761    # Note: We evaluate `func` twice:
1762    # once to return the output and once while
1763    # tracing the graph.
1764    # If this becomes a bottleneck, we should update
1765    # make_fx such that it also returns the output.
1766
1767    output = func(*primals)
1768    _, output_spec = tree_flatten(output)
1769
1770    flat_primals, primals_argspec = tree_flatten(primals)
1771
1772    # tangents for tracing
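    # `new_empty(()).expand_as(p)` creates a zero-stride placeholder that matches
    # p's shape/dtype/device without allocating p.numel() elements; its values are
    # irrelevant here since the tangents only serve as example inputs for tracing.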
1773    flat_tangents = tuple(p.new_empty(()).expand_as(p) for p in flat_primals)
1774
1775    # function to trace
1776    def trace_fn(flat_tangents):
1777        with fwAD.dual_level():
1778            flat_duals = tuple(
1779                fwAD.make_dual(p, t) for p, t in zip(flat_primals, flat_tangents)
1780            )
1781            duals = tree_unflatten(flat_duals, primals_argspec)
1782            output = func(*duals)
1783            tangents = tree_map_only(
1784                torch.Tensor, lambda dual: safe_unpack_dual(dual, False)[1], output
1785            )
1786
1787        return tangents
1788
1789    jvp_graph = lazy_dynamo_disallow(make_fx)(trace_fn)(flat_tangents)
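    # `split_const_subgraphs` folds the parts of the traced jvp graph that depend
    # only on the captured primals into a constant subgraph, so that work is done
    # once and reused across calls to `jvp_fn` below.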
1790    const_folded_jvp_graph = lazy_dynamo_disallow(const_fold.split_const_subgraphs)(
1791        jvp_graph
1792    )
1793
1794    # Hold only the metadata about the primals for the checks below.
1795    flat_primals_shape = tuple(p.shape for p in flat_primals)
1796    flat_primals_device = tuple(p.device for p in flat_primals)
1797    flat_primals_dtype = tuple(p.dtype for p in flat_primals)
1798
1799    def forward_ad_checks(flat_tangents):
1800        for idx, t in enumerate(flat_tangents):
1801            if t.shape != flat_primals_shape[idx]:
1802                msg = (
1803                    f"tangent:{idx} with shape {t.shape} in flattened "
1804                    f"pytree doesn't match the shape {flat_primals_shape[idx]} "
1805                    "of the corresponding primal."
1806                )
1807                raise RuntimeError(msg)
1808
1809            if t.device != flat_primals_device[idx]:
1810                msg = (
1811                    f"tangent:{idx} with device {t.device} in flattened "
1812                    f"pytree doesn't match the device {flat_primals_device[idx]} "
1813                    "of the corresponding primal."
1814                )
1815                raise RuntimeError(msg)
1816
1817            if t.dtype != flat_primals_dtype[idx]:
1818                msg = (
1819                    f"tangent:{idx} with dtype {t.dtype} in flattened "
1820                    f"pytree doesn't match the dtype {flat_primals_dtype[idx]} "
1821                    "of the corresponding primal."
1822                )
1823                raise RuntimeError(msg)
1824
1825    # jvp_fn: the callable to return.
1826    #   It checks the argspec of the tangents, calls the const-folded fx graph
1827    #   and unflattens the fx graph output.
1828    def jvp_fn(*tangents):
1829        flat_tangents, tangent_argspec = tree_flatten(tangents)
1830        _linearize_treespec_compare(primals, tangents)
1831
1832        forward_ad_checks(flat_tangents)
1833
1834        flat_output = const_folded_jvp_graph(*flat_tangents)
1835        # the const-folded graph returns a flat output,
1836        # so unflatten it back into the original output structure.
1837        return tree_unflatten(flat_output, output_spec)
1838
1839    return output, jvp_fn
1840