# mypy: allow-untyped-defs
import functools
from typing import Any, Callable, List, Optional, Tuple, Union
from typing_extensions import deprecated

import torch
from torch import Tensor
from torch.utils._pytree import _broadcast_to_and_flatten, tree_flatten, tree_unflatten


in_dims_t = Union[int, Tuple]
out_dims_t = Union[int, Tuple[int, ...]]


# Checks that all args-to-be-batched have the same batch dim size
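# For example (illustrative values, not from the original source): with
# flat_in_dims=[0, None, 1] and flat_args holding Tensors of shapes (3, 4),
# (7,), and (4, 3), the mapped sizes are [3, 3], so this returns 3.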
def _validate_and_get_batch_size(
    flat_in_dims: List[Optional[int]],
    flat_args: List,
) -> int:
    batch_sizes = [
        arg.size(in_dim)
        for in_dim, arg in zip(flat_in_dims, flat_args)
        if in_dim is not None
    ]
    if batch_sizes and any(size != batch_sizes[0] for size in batch_sizes):
        raise ValueError(
            f"vmap: Expected all tensors to have the same size in the mapped "
            f"dimension, got sizes {batch_sizes} for the mapped dimension"
        )
    return batch_sizes[0]


def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
    if isinstance(batched_outputs, tuple):
        return len(batched_outputs)
    return 1


# If value is a tuple, check it has length `num_elements`.
# If value is not a tuple, make a tuple with `value` repeated `num_elements` times
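# For example (illustrative, where `err` is the error-message lambda):
# _as_tuple(0, 3, err) -> (0, 0, 0), _as_tuple((0, 1), 2, err) -> (0, 1),
# and _as_tuple((0, 1), 3, err) raises ValueError(err()).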
def _as_tuple(
    value: Any,
    num_elements: int,
    error_message_lambda: Callable[[], str],
) -> Tuple:
    if not isinstance(value, tuple):
        return (value,) * num_elements
    if len(value) != num_elements:
        raise ValueError(error_message_lambda())
    return value


# Creates BatchedTensors for every Tensor in args that should be batched.
# Returns the (potentially) batched arguments and the batch_size.
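# For example (an illustrative sketch, not taken from the original comments):
# with args=(x, {'y': y}), an integer in_dims such as 0 is broadcast over every
# Tensor leaf, so both x and y get a batch dim at dim 0, while a nested
# in_dims such as (0, {'y': None}) batches x along dim 0 and leaves y unbatched.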
def _create_batched_inputs(
    in_dims: in_dims_t,
    args: Tuple,
    vmap_level: int,
    func: Callable,
) -> Tuple[Tuple, int]:
    if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
        raise ValueError(
            f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
            f"expected `in_dims` to be int or a (potentially nested) tuple "
            f"matching the structure of inputs, got: {type(in_dims)}."
        )
    if len(args) == 0:
        raise ValueError(
            f"vmap({_get_name(func)})(<inputs>): got no inputs. Maybe you forgot to add "
            f"inputs, or you are trying to vmap over a function with no inputs. "
            f"The latter is unsupported."
        )

    flat_args, args_spec = tree_flatten(args)
    flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
    if flat_in_dims is None:
        raise ValueError(
            f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
            f"in_dims is not compatible with the structure of `inputs`. "
            f"in_dims has structure {tree_flatten(in_dims)[1]} but inputs "
            f"has structure {args_spec}."
        )

    for arg, in_dim in zip(flat_args, flat_in_dims):
        if not isinstance(in_dim, int) and in_dim is not None:
            raise ValueError(
                f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
                f"Got in_dim={in_dim} for an input but in_dim must be either "
                f"an integer dimension or None."
            )
        if isinstance(in_dim, int) and not isinstance(arg, Tensor):
            raise ValueError(
                f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
                f"Got in_dim={in_dim} for an input but the input is of type "
                f"{type(arg)}. We cannot vmap over non-Tensor arguments, "
                f"please use None as the respective in_dim"
            )
        if in_dim is not None and (in_dim < 0 or in_dim >= arg.dim()):
            raise ValueError(
                f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
                f"Got in_dim={in_dim} for some input, but that input is a Tensor "
                f"of dimensionality {arg.dim()} so expected in_dim to satisfy "
                f"0 <= in_dim < {arg.dim()}."
            )

    batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)
    # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    batched_inputs = [
        arg if in_dim is None else torch._add_batch_dim(arg, in_dim, vmap_level)
        for in_dim, arg in zip(flat_in_dims, flat_args)
    ]
    return tree_unflatten(batched_inputs, args_spec), batch_size


# Undoes the batching (and any batch dimensions) associated with the `vmap_level`.
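# For example (illustrative): a batched output whose per-example shape is (5,)
# with batch_size=3 becomes a plain Tensor of shape (3, 5) for out_dim=0, or
# (5, 3) for out_dim=1.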
def _unwrap_batched(
    batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
    out_dims: out_dims_t,
    vmap_level: int,
    batch_size: int,
    func: Callable,
    allow_none_pass_through: bool = False,
) -> Tuple:
    num_outputs = _num_outputs(batched_outputs)
    out_dims_as_tuple = _as_tuple(
        out_dims,
        num_outputs,
        lambda: f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must "
        f"have one dim per output (got {num_outputs} outputs) of {_get_name(func)}.",
    )

    # NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    # There is something wrong with our type bindings for functions that begin
    # with '_', see #40397.
    if isinstance(batched_outputs, Tensor):
        out_dim = out_dims_as_tuple[0]
        return torch._remove_batch_dim(batched_outputs, vmap_level, batch_size, out_dim)  # type: ignore[return-value]
    if allow_none_pass_through:
        return tuple(
            (
                torch._remove_batch_dim(out, vmap_level, batch_size, out_dim)
                if out is not None
                else None
            )
            for out, out_dim in zip(batched_outputs, out_dims_as_tuple)
        )
    else:
        return tuple(
            torch._remove_batch_dim(out, vmap_level, batch_size, out_dim)
            for out, out_dim in zip(batched_outputs, out_dims_as_tuple)
        )


# Checks that `fn` returned one or more Tensors and nothing else.
# NB: A Python function that returns multiple values actually returns a single
# tuple, so we are effectively checking that `outputs` is a single Tensor or a
# tuple of Tensors.
def _validate_outputs(outputs: Any, func: Callable) -> None:
    if isinstance(outputs, Tensor):
        return
    if not isinstance(outputs, tuple):
        raise ValueError(
            f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return "
            f"Tensors, got type {type(outputs)} as the return."
        )
    for idx, output in enumerate(outputs):
        if isinstance(output, Tensor):
            continue
        raise ValueError(
            f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return "
            f"Tensors, got type {type(output)} for return {idx}."
        )


def _check_out_dims_is_int_or_int_tuple(out_dims: out_dims_t, func: Callable) -> None:
    if isinstance(out_dims, int):
        return
    if not isinstance(out_dims, tuple) or not all(
        isinstance(out_dim, int) for out_dim in out_dims
    ):
        raise ValueError(
            f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be "
            f"an int or a tuple of int representing where in the outputs the "
            f"vmapped dimension should appear."
        )


def _get_name(func: Callable):
    if hasattr(func, "__name__"):
        return func.__name__

    # Not all callables have a __name__. A callable created via functools.partial
    # or an nn.Module instance, to name a couple of examples, doesn't have one, so
    # fall back to the repr.
    return repr(func)


# vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors,
# sends those into func, and then unwraps the output BatchedTensors. Operations
# on BatchedTensors perform the batched operations that the user is asking for.
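# For example, vmap(torch.dot)(x, y) with x and y of shape (B, N) wraps x and y
# as BatchedTensors, torch.dot dispatches to its batching rule, and the result
# is unwrapped into a plain Tensor of B dot products with shape (B,).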
@deprecated(
    "Please use `torch.vmap` instead of `torch._vmap_internals.vmap`.",
    category=FutureWarning,
)
def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
    """
    Please use torch.vmap instead of this API.
    """
    return _vmap(func, in_dims, out_dims)


# A version of vmap but without the initial "experimental prototype" warning
def _vmap(
    func: Callable,
    in_dims: in_dims_t = 0,
    out_dims: out_dims_t = 0,
    allow_none_pass_through: bool = False,
) -> Callable:
    # The `allow_none_pass_through` argument is a temporary workaround and may be
    # removed. Currently it enables us to wrap calls into the autograd engine made
    # via `autograd.grad`, which may return None if any of the inputs are unused.
    # See the issue discussing this: https://github.com/facebookresearch/functorch/issues/159.
    @functools.wraps(func)
    def wrapped(*args):
        _check_out_dims_is_int_or_int_tuple(out_dims, func)
        vmap_level = torch._C._vmapmode_increment_nesting()
        try:
            batched_inputs, batch_size = _create_batched_inputs(
                in_dims, args, vmap_level, func
            )
            batched_outputs = func(*batched_inputs)
            if not allow_none_pass_through:
                _validate_outputs(batched_outputs, func)
            return _unwrap_batched(
                batched_outputs,
                out_dims,
                vmap_level,
                batch_size,
                func,
                allow_none_pass_through=allow_none_pass_through,
            )
        finally:
            torch._C._vmapmode_decrement_nesting()

    return wrapped
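

# A minimal usage sketch (illustrative only; prefer `torch.vmap` for new code):
#
#     x = torch.randn(3, 5)
#     y = torch.randn(3, 5)
#     per_sample_dot = _vmap(torch.dot)  # torch.dot expects 1-D inputs
#     out = per_sample_dot(x, y)         # out has shape (3,)
#
#     # in_dims controls which arguments carry the mapped dimension; here only
#     # the first argument is batched:
#     batched_add = _vmap(torch.add, in_dims=(0, None))
#     out = batched_add(torch.randn(3, 5), torch.randn(5))  # out has shape (3, 5)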