# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
# NOTE: We allow Dynamo to see this file (via torch/_dynamo/trace_rules.py) so that it can
#       trace through functorch transforms.
#       Currently, we can't allow Dynamo to see `eager_transforms.py`/`vmap.py` as that breaks
#       a lot of things and there isn't a mechanism to selectively expose only some functions
#       (e.g. grad) from a file to Dynamo.
import functools

from torch._functorch.utils import argnums_t, exposed_in
from torch._functorch.vmap import (
    _check_out_dims_is_int_or_int_pytree,
    _check_randomness_arg,
    _chunked_vmap,
    _process_batched_inputs,
    Callable,
    in_dims_t,
    out_dims_t,
    vmap_impl,
)


# vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors,
# sends those into func, and then unwraps the output BatchedTensors. Operations
# on BatchedTensors perform the batched operations that the user is asking for.
#
# vmap's randomness behavior differs from JAX's, which would require a PRNG key
# to be passed everywhere.
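#
# Conceptually (a rough sketch that ignores randomness, chunking, and pytree
# handling), mapping over dim 0 behaves like a stacked for-loop:
#
#   vmap(f)(x)  ~  torch.stack([f(x[i]) for i in range(x.size(0))])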


@exposed_in("torch.func")
def vmap(
    func: Callable,
    in_dims: in_dims_t = 0,
    out_dims: out_dims_t = 0,
    randomness: str = "error",
    *,
    chunk_size=None,
) -> Callable:
    """
    vmap is the vectorizing map; ``vmap(func)`` returns a new function that
    maps ``func`` over some dimension of the inputs. Semantically, vmap
    pushes the map into PyTorch operations called by ``func``, effectively
    vectorizing those operations.

    vmap is useful for handling batch dimensions: one can write a function
    ``func`` that runs on examples and then lift it to a function that can
    take batches of examples with ``vmap(func)``. vmap can also be used to
    compute batched gradients when composed with autograd.

    .. note::
        :func:`torch.vmap` is aliased to :func:`torch.func.vmap` for
        convenience. Use whichever one you'd like.

    Args:
        func (function): A Python function that takes one or more arguments.
            Must return one or more Tensors.
        in_dims (int or nested structure): Specifies which dimension of the
            inputs should be mapped over. ``in_dims`` should have a
            structure like the inputs. If the ``in_dim`` for a particular
            input is None, then that indicates there is no map dimension.
            Default: 0.
        out_dims (int or Tuple[int]): Specifies where the mapped dimension
            should appear in the outputs. If ``out_dims`` is a Tuple, then
            it should have one element per output. Default: 0.
        randomness (str): Specifies whether the randomness in this
            vmap should be the same or different across batches. If 'different',
            the randomness for each batch will be different. If 'same', the
            randomness will be the same across batches. If 'error', any calls to
            random functions will error. Default: 'error'. WARNING: this flag
            only applies to random PyTorch operations and does not apply to
            Python's random module or numpy randomness.
        chunk_size (None or int): If None (default), apply a single vmap over the
            inputs. If not None, compute the vmap :attr:`chunk_size` samples at a time.
            Note that :attr:`chunk_size=1` is equivalent to computing the vmap with a
            for-loop. If you run into memory issues computing the vmap, please try a
            non-None chunk_size (see the chunked example further below).

    Returns:
        Returns a new "batched" function. It takes the same inputs as
        ``func``, except each input has an extra dimension at the index
        specified by ``in_dims``. It returns the same outputs as
        ``func``, except each output has an extra dimension at the index
        specified by ``out_dims``.

    .. warning::
        :func:`vmap` works best with functional-style code. Please do not
        perform any side-effects in ``func``, with the exception of
        in-place PyTorch operations. Examples of side-effects include mutating
        Python data structures and assigning values to variables not captured
        in ``func``.

    One example of using :func:`vmap` is to compute batched dot products. PyTorch
    doesn't provide a batched ``torch.dot`` API; instead of unsuccessfully
    rummaging through docs, use :func:`vmap` to construct a new function.

        >>> torch.dot                            # [D], [D] -> []
        >>> batched_dot = torch.func.vmap(torch.dot)  # [N, D], [N, D] -> [N]
        >>> x, y = torch.randn(2, 5), torch.randn(2, 5)
        >>> batched_dot(x, y)

    :func:`vmap` can be helpful in hiding batch dimensions, leading to a simpler
    model authoring experience.

        >>> batch_size, feature_size = 3, 5
        >>> weights = torch.randn(feature_size, requires_grad=True)
        >>>
        >>> def model(feature_vec):
        >>>     # Very simple linear model with activation
        >>>     return feature_vec.dot(weights).relu()
        >>>
        >>> examples = torch.randn(batch_size, feature_size)
        >>> result = torch.vmap(model)(examples)

    :func:`vmap` can also help vectorize computations that were previously difficult
    or impossible to batch. One example is higher-order gradient computation.
    The PyTorch autograd engine computes vjps (vector-Jacobian products).
    Computing a full Jacobian matrix for some function f: R^N -> R^N usually
    requires N calls to ``autograd.grad``, one per Jacobian row. Using :func:`vmap`,
    we can vectorize the whole computation, computing the Jacobian in a single
    call to ``autograd.grad``.

        >>> # Setup
        >>> N = 5
        >>> f = lambda x: x ** 2
        >>> x = torch.randn(N, requires_grad=True)
        >>> y = f(x)
        >>> I_N = torch.eye(N)
        >>>
        >>> # Sequential approach
        >>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
        >>>                  for v in I_N.unbind()]
        >>> jacobian = torch.stack(jacobian_rows)
        >>>
        >>> # Vectorized gradient computation
        >>> def get_vjp(v):
        >>>     return torch.autograd.grad(y, x, v)
        >>> jacobian = torch.vmap(get_vjp)(I_N)

    :func:`vmap` can also be nested, producing an output with multiple batched dimensions

        >>> torch.dot                            # [D], [D] -> []
        >>> batched_dot = torch.vmap(torch.vmap(torch.dot))  # [N1, N0, D], [N1, N0, D] -> [N1, N0]
        >>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5)
        >>> batched_dot(x, y) # tensor of size [2, 3]

    If the inputs are not batched along the first dimension, ``in_dims`` specifies
    the dimension along which each input is batched, as in

        >>> torch.dot                            # [N], [N] -> []
        >>> batched_dot = torch.vmap(torch.dot, in_dims=1)  # [N, D], [N, D] -> [D]
        >>> x, y = torch.randn(2, 5), torch.randn(2, 5)
        >>> batched_dot(x, y)   # output is [5] instead of [2] if batched along the 0th dimension

    If there are multiple inputs, each batched along a different dimension,
    ``in_dims`` must be a tuple with the batch dimension for each input, as in

        >>> torch.dot                            # [D], [D] -> []
        >>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None))  # [N, D], [D] -> [N]
        >>> x, y = torch.randn(2, 5), torch.randn(5)
        >>> batched_dot(x, y) # second arg doesn't have a batch dim because in_dim[1] was None

    If the input is a Python struct, ``in_dims`` must be a tuple containing a struct
    matching the shape of the input:

        >>> f = lambda dict: torch.dot(dict['x'], dict['y'])
        >>> x, y = torch.randn(2, 5), torch.randn(5)
        >>> input = {'x': x, 'y': y}
        >>> batched_dot = torch.vmap(f, in_dims=({'x': 0, 'y': None},))
        >>> batched_dot(input)

    By default, the output is batched along the first dimension. However, it can be batched
    along any dimension by using ``out_dims``

        >>> f = lambda x: x ** 2
        >>> x = torch.randn(2, 5)
        >>> batched_pow = torch.vmap(f, out_dims=1)
        >>> batched_pow(x) # [5, 2]

    For any function that uses kwargs, the returned function will not batch the kwargs
    but will accept them

        >>> x = torch.randn([2, 5])
        >>> def fn(x, scale=4.):
        >>>   return x * scale
        >>>
        >>> batched_pow = torch.vmap(fn)
        >>> assert torch.allclose(batched_pow(x), x * 4)
        >>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5]
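
    If ``func`` produces large intermediate tensors, the whole batch may not fit in
    memory at once. As a rough sketch (the choice of ``chunk_size=2`` here is
    arbitrary), ``chunk_size`` reduces peak memory by running the vmap over smaller
    pieces of the batch sequentially:

        >>> x, y = torch.randn(8, 5), torch.randn(8, 5)
        >>> chunked_dot = torch.vmap(torch.dot, chunk_size=2)  # processes 2 samples at a time
        >>> chunked_dot(x, y)  # same result as torch.vmap(torch.dot)(x, y), shape [8]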

    .. note::
        vmap does not provide general autobatching or handle variable-length
        sequences out of the box.
    """
    from torch._dynamo import is_compiling

    _check_randomness_arg(randomness)
    if not (chunk_size is None or chunk_size > 0):
        raise ValueError(
            f"vmap: chunk_size should be None or greater than 0. (got {chunk_size})"
        )

    def wrapped(*args, **kwargs):
        return vmap_impl(
            func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
        )

    if not is_compiling():
        wrapped = functools.wraps(func)(wrapped)

    return wrapped


def chunk_vmap(
    func: Callable,
    in_dims: in_dims_t = 0,
    out_dims: out_dims_t = 0,
    randomness: str = "error",
    chunks=2,
) -> Callable:
    """
    chunk_vmap is the vectorizing map (vmap) applied over chunks of the input data. It is a
    mix of vmap (which vectorizes everything) and map (which executes things sequentially):
    ``chunk_vmap`` splits the input into the given number of chunks and vectorizes within
    each chunk. For more details about the vectorizing map, see :func:`vmap`.

    .. note::
        Please use :func:`vmap` with the ``chunk_size`` argument instead of this API.

    Args:
        func (function): A Python function that takes one or more arguments.
            Must return one or more Tensors.
        in_dims (int or nested structure): Specifies which dimension of the
            inputs should be mapped over. ``in_dims`` should have a
            structure like the inputs. If the ``in_dim`` for a particular
            input is None, then that indicates there is no map dimension.
            Default: 0.
        out_dims (int or Tuple[int]): Specifies where the mapped dimension
            should appear in the outputs. If ``out_dims`` is a Tuple, then
            it should have one element per output. Default: 0.
        randomness (str): Specifies whether the randomness in this
            vmap should be the same or different across batches. If 'different',
            the randomness for each batch will be different. If 'same', the
            randomness will be the same across batches. If 'error', any calls to
            random functions will error. Default: 'error'. WARNING: this flag
            only applies to random PyTorch operations and does not apply to
            Python's random module or numpy randomness.
        chunks (int): Number of chunks to use to split the input data. Default is 2.
            If equal to 1, then :func:`vmap` is called.

    Returns:
        Returns a new "batched" function. It takes the same inputs as
        ``func``, except each input has an extra dimension at the index
        specified by ``in_dims``. It returns the same outputs as
        ``func``, except each output has an extra dimension at the index
        specified by ``out_dims``.
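
    A minimal sketch of usage (the choice of ``chunks=4`` here is arbitrary); the
    result matches a plain :func:`vmap` over the same inputs:

        >>> x, y = torch.randn(8, 5), torch.randn(8, 5)
        >>> batched_dot = chunk_vmap(torch.dot, chunks=4)  # 4 chunks of 2 samples each
        >>> batched_dot(x, y)  # shape [8]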
    """
    _check_randomness_arg(randomness)

    if chunks == 1:
        return vmap(func, in_dims=in_dims, out_dims=out_dims, randomness=randomness)

    def _get_chunk_flat_args(flat_args_, flat_in_dims_, chunks_):
        # Split each batched argument into `chunks_` pieces along its batch dim;
        # unbatched arguments (in_dim is None) are simply repeated for every chunk.
        flat_args_chunks = tuple(
            t.chunk(chunks_, dim=in_dim) if in_dim is not None else [t] * chunks_
            for t, in_dim in zip(flat_args_, flat_in_dims_)
        )
        # Transpose the chunk dim and flatten the structure:
        # chunks_flat_args is an iterable of flattened args, one entry per chunk.
        chunks_flat_args = zip(*flat_args_chunks)
        return chunks_flat_args

    @functools.wraps(func)
    def wrapped_with_chunks(*args, **kwargs):
        _check_out_dims_is_int_or_int_pytree(out_dims, func)
        _, flat_in_dims, flat_args, args_spec = _process_batched_inputs(
            in_dims, args, func
        )
        # Chunk flat arguments
        chunks_flat_args = _get_chunk_flat_args(flat_args, flat_in_dims, chunks)

        # Apply vmap on chunks
        return _chunked_vmap(
            func,
            flat_in_dims,
            chunks_flat_args,
            args_spec,
            out_dims,
            randomness,
            **kwargs,
        )

    return wrapped_with_chunks


@exposed_in("torch.func")
def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
    """The ``grad`` operator computes gradients of ``func`` with respect to the
    input(s) specified by ``argnums``. This operator can be nested to
    compute higher-order gradients.

    Args:
        func (Callable): A Python function that takes one or more arguments.
            Must return a single-element Tensor. If ``has_aux`` is ``True``,
            the function can return a tuple of a single-element Tensor and other
            auxiliary objects: ``(output, aux)``.
        argnums (int or Tuple[int]): Specifies arguments to compute gradients with respect to.
            ``argnums`` can be a single integer or a tuple of integers. Default: 0.
        has_aux (bool): Flag indicating that ``func`` returns a tensor and other
            auxiliary objects: ``(output, aux)``. Default: False.

    Returns:
        Function to compute gradients with respect to its inputs. By default, the output of
        the function is the gradient tensor(s) with respect to the first argument.
        If ``has_aux`` is ``True``, a tuple of gradients and the output auxiliary objects
        is returned. If ``argnums`` is a tuple of integers, a tuple of output gradients with
        respect to each ``argnums`` value is returned.

    Example of using ``grad``:

        >>> # xdoctest: +SKIP
        >>> from torch.func import grad
        >>> x = torch.randn([])
        >>> cos_x = grad(lambda x: torch.sin(x))(x)
        >>> assert torch.allclose(cos_x, x.cos())
        >>>
        >>> # Second-order gradients
        >>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
        >>> assert torch.allclose(neg_sin_x, -x.sin())

    When composed with ``vmap``, ``grad`` can be used to compute per-sample-gradients:

        >>> # xdoctest: +SKIP
        >>> from torch.func import grad, vmap
        >>> batch_size, feature_size = 3, 5
        >>>
        >>> def model(weights, feature_vec):
        >>>     # Very simple linear model with activation
        >>>     assert feature_vec.dim() == 1
        >>>     return feature_vec.dot(weights).relu()
        >>>
        >>> def compute_loss(weights, example, target):
        >>>     y = model(weights, example)
        >>>     return ((y - target) ** 2).mean()  # MSELoss
        >>>
        >>> weights = torch.randn(feature_size, requires_grad=True)
        >>> examples = torch.randn(batch_size, feature_size)
        >>> targets = torch.randn(batch_size)
        >>> inputs = (weights, examples, targets)
        >>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)

    Example of using ``grad`` with ``has_aux`` and ``argnums``:

        >>> # xdoctest: +SKIP
        >>> from torch.func import grad
        >>> def my_loss_func(y, y_pred):
        >>>    loss_per_sample = (0.5 * y_pred - y) ** 2
        >>>    loss = loss_per_sample.mean()
        >>>    return loss, (y_pred, loss_per_sample)
        >>>
        >>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True)
        >>> y_true = torch.rand(4)
        >>> y_preds = torch.rand(4, requires_grad=True)
        >>> out = fn(y_true, y_preds)
        >>> # > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample))

    .. note::
        Using PyTorch ``torch.no_grad`` together with ``grad``.

        Case 1: Using ``torch.no_grad`` inside a function:

            >>> # xdoctest: +SKIP
            >>> def f(x):
            >>>     with torch.no_grad():
            >>>         c = x ** 2
            >>>     return x - c

        In this case, ``grad(f)(x)`` will respect the inner ``torch.no_grad``.

        Case 2: Using ``grad`` inside the ``torch.no_grad`` context manager:

            >>> # xdoctest: +SKIP
            >>> with torch.no_grad():
            >>>     grad(f)(x)

        In this case, ``grad`` will respect the inner ``torch.no_grad``, but not the
        outer one. This is because ``grad`` is a "function transform": its result
        should not depend on the result of a context manager outside of ``f``.

    """
    # To avoid cyclical dependency.
    import torch._functorch.eager_transforms as eager_transforms
    from torch._dynamo import is_compiling

    def wrapper(*args, **kwargs):
        return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)

    if not is_compiling():
        wrapper = functools.wraps(func)(wrapper)

    return wrapper


@exposed_in("torch.func")
def grad_and_value(
    func: Callable, argnums: argnums_t = 0, has_aux: bool = False
) -> Callable:
    """
    Returns a function to compute a tuple of the gradient and primal, or
    forward, computation.

    Args:
        func (Callable): A Python function that takes one or more arguments.
            Must return a single-element Tensor. If ``has_aux`` is ``True``,
            the function can return a tuple of a single-element Tensor and
            other auxiliary objects: ``(output, aux)``.
        argnums (int or Tuple[int]): Specifies arguments to compute gradients
            with respect to. ``argnums`` can be a single integer or a tuple of
            integers. Default: 0.
        has_aux (bool): Flag indicating that ``func`` returns a tensor and
            other auxiliary objects: ``(output, aux)``. Default: False.

    Returns:
        Function to compute a tuple of gradients with respect to its inputs
        and the forward computation. By default, the output of the function is
        a tuple of the gradient tensor(s) with respect to the first argument
        and the primal computation. If ``has_aux`` is ``True``, a tuple of
        gradients and a tuple of the forward computation with the output
        auxiliary objects is returned. If ``argnums`` is a tuple of
        integers, a tuple of a tuple of the output gradients with respect to
        each ``argnums`` value and the forward computation is returned.

    See :func:`grad` for examples.
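
    A minimal sketch for the default case (single ``argnums``, no aux), mirroring the
    first :func:`grad` example:

        >>> # xdoctest: +SKIP
        >>> from torch.func import grad_and_value
        >>> x = torch.randn([])
        >>> grad_sin, sin_x = grad_and_value(torch.sin)(x)
        >>> assert torch.allclose(grad_sin, x.cos())
        >>> assert torch.allclose(sin_x, x.sin())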
    """
    from torch._dynamo import is_compiling
    from torch._functorch import eager_transforms

    def wrapper(*args, **kwargs):
        return eager_transforms.grad_and_value_impl(
            func, argnums, has_aux, args, kwargs
        )

    if not is_compiling():
        wrapper = functools.wraps(func)(wrapper)

    return wrapper