# mypy: allow-untyped-defs
import functools
from typing import Any, Callable, List, Optional, Tuple, Union
from typing_extensions import deprecated

import torch
from torch import Tensor
from torch.utils._pytree import _broadcast_to_and_flatten, tree_flatten, tree_unflatten


# `in_dims` is an int or a (potentially nested) tuple mirroring the inputs;
# `out_dims` is an int or a flat tuple of ints, one per output.
in_dims_t = Union[int, Tuple]
out_dims_t = Union[int, Tuple[int, ...]]


# Checks that all args-to-be-batched have the same batch dim size
def _validate_and_get_batch_size(
    flat_in_dims: List[Optional[int]],
    flat_args: List,
) -> int:
    """Return the common size of the mapped dimension across batched inputs.

    Args:
        flat_in_dims: one entry per flattened input; an int is the dimension
            to map over for that input, None means the input is not batched.
        flat_args: the flattened inputs, aligned with ``flat_in_dims``. Every
            input whose in_dim is not None must support ``.size(dim)``.

    Returns:
        The size of the mapped dimension shared by all batched inputs.

    Raises:
        ValueError: if no input is batched (every in_dim is None), or if the
            batched inputs disagree on the size of their mapped dimension.
    """
    batch_sizes = [
        arg.size(in_dim)
        for in_dim, arg in zip(flat_in_dims, flat_args)
        if in_dim is not None
    ]
    if not batch_sizes:
        # Without this check, indexing `batch_sizes[0]` below would surface
        # as an opaque IndexError when every in_dim is None.
        raise ValueError(
            "vmap: Expected at least one input to be mapped over (i.e. to have "
            "a non-None in_dim), but got in_dim=None for every input"
        )
    first_size = batch_sizes[0]
    if any(size != first_size for size in batch_sizes):
        raise ValueError(
            f"vmap: Expected all tensors to have the same size in the mapped "
            f"dimension, got sizes {batch_sizes} for the mapped dimension"
        )
    return first_size


def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
    """Return the number of outputs: the tuple length, or 1 for a non-tuple."""
    if isinstance(batched_outputs, tuple):
        return len(batched_outputs)
    return 1


# If value is a tuple, check it has length `num_elements`.
# If value is not a tuple, make a tuple with `value` repeated `num_elements` times
def _as_tuple(
    value: Any,
    num_elements: int,
    error_message_lambda: Callable[[], str],
) -> Tuple:
    """Normalize ``value`` to a tuple of length ``num_elements``.

    ``error_message_lambda`` is only invoked when a ValueError is raised, so
    that an expensive message is not built on the success path.

    Raises:
        ValueError: if ``value`` is a tuple whose length is not ``num_elements``.
    """
    if not isinstance(value, tuple):
        return (value,) * num_elements
    if len(value) != num_elements:
        raise ValueError(error_message_lambda())
    return value


# Creates BatchedTensors for every Tensor in arg that should be batched.
# Returns the (potentially) batched arguments and the batch_size.
def _create_batched_inputs(
    in_dims: in_dims_t,
    args: Tuple,
    vmap_level: int,
    func: Callable,
) -> Tuple[Tuple, int]:
    """Wrap every input that should be batched in a BatchedTensor.

    Args:
        in_dims: an int (applied to every input) or a (potentially nested)
            tuple matching the structure of ``args``. Each leaf is the
            dimension to map over for that input, or None to pass it through.
        args: the positional inputs given to the vmapped function.
        vmap_level: the vmap nesting level passed to ``torch._add_batch_dim``.
        func: the function being vmapped; only used in error messages.

    Returns:
        ``(batched_args, batch_size)``: ``args`` with the selected Tensors
        replaced by batched ones (same structure as ``args``), and the size of
        the mapped dimension.

    Raises:
        ValueError: if ``in_dims`` has the wrong type or structure, ``args`` is
            empty, an in_dim is neither an int nor None, an int in_dim is given
            for a non-Tensor input, an in_dim is out of range for its input, or
            the mapped dimensions of the inputs disagree in size.
    """
    if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
        raise ValueError(
            f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
            f"expected `in_dims` to be int or a (potentially nested) tuple "
            f"matching the structure of inputs, got: {type(in_dims)}."
        )
    if len(args) == 0:
        raise ValueError(
            f"vmap({_get_name(func)})(<inputs>): got no inputs. Maybe you forgot to add "
            f"inputs, or you are trying to vmap over a function with no inputs. "
            f"The latter is unsupported."
        )

    flat_args, args_spec = tree_flatten(args)
    # Broadcast an int (or a tuple prefix) of in_dims to one entry per leaf;
    # None signals that the two structures are incompatible.
    flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
    if flat_in_dims is None:
        raise ValueError(
            f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
            f"in_dims is not compatible with the structure of `inputs`. "
            f"in_dims has structure {tree_flatten(in_dims)[1]} but inputs "
            f"has structure {args_spec}."
        )

    for arg, in_dim in zip(flat_args, flat_in_dims):
        if not isinstance(in_dim, int) and in_dim is not None:
            raise ValueError(
                f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
                f"Got in_dim={in_dim} for an input but in_dim must be either "
                f"an integer dimension or None."
            )
        if isinstance(in_dim, int) and not isinstance(arg, Tensor):
            raise ValueError(
                f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
                f"Got in_dim={in_dim} for an input but the input is of type "
                f"{type(arg)}. We cannot vmap over non-Tensor arguments, "
                f"please use None as the respective in_dim"
            )
        if in_dim is not None and (in_dim < 0 or in_dim >= arg.dim()):
            raise ValueError(
                f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
                f"Got in_dim={in_dim} for some input, but that input is a Tensor "
                f"of dimensionality {arg.dim()} so expected in_dim to satisfy "
                f"0 <= in_dim < {arg.dim()}."
            )

    batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)
    # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    batched_inputs = [
        arg if in_dim is None else torch._add_batch_dim(arg, in_dim, vmap_level)
        for in_dim, arg in zip(flat_in_dims, flat_args)
    ]
    return tree_unflatten(batched_inputs, args_spec), batch_size


# Undoes the batching (and any batch dimensions) associated with the `vmap_level`.
def _unwrap_batched(
    batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
    out_dims: out_dims_t,
    vmap_level: int,
    batch_size: int,
    func: Callable,
    allow_none_pass_through: bool = False,
) -> Tuple:
    """Remove the ``vmap_level`` batch dimension from each output.

    Args:
        batched_outputs: a single output or a tuple of outputs of ``func``.
        out_dims: an int (applied to every output) or a tuple with one entry
            per output, giving where the mapped dimension is placed.
        vmap_level: the vmap nesting level whose batch dim is removed.
        batch_size: the size of the mapped dimension.
        func: the function being vmapped; only used in error messages.
        allow_none_pass_through: when True, None outputs are returned as None
            instead of being passed to ``torch._remove_batch_dim``.

    Returns:
        A single unbatched Tensor when ``batched_outputs`` is a Tensor (or
        None for a lone None output when ``allow_none_pass_through``);
        otherwise a tuple with one unbatched entry per output.

    Raises:
        ValueError: if ``out_dims`` is a tuple whose length differs from the
            number of outputs.
    """
    num_outputs = _num_outputs(batched_outputs)
    out_dims_as_tuple = _as_tuple(
        out_dims,
        num_outputs,
        lambda: f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must "
        f"have one dim per output (got {num_outputs} outputs) of {_get_name(func)}.",
    )

    # NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    # There is something wrong with our type bindings for functions that begin
    # with '_', see #40397.
    if isinstance(batched_outputs, Tensor):
        out_dim = out_dims_as_tuple[0]
        return torch._remove_batch_dim(batched_outputs, vmap_level, batch_size, out_dim)  # type: ignore[return-value]
    if batched_outputs is None and allow_none_pass_through:
        # A lone None output passes through just as a None inside a tuple
        # does; iterating over it below would otherwise raise a TypeError.
        return None  # type: ignore[return-value]
    # When pass-through is disabled, None entries still reach
    # torch._remove_batch_dim, exactly as before.
    return tuple(
        (
            None
            if out is None and allow_none_pass_through
            else torch._remove_batch_dim(out, vmap_level, batch_size, out_dim)
        )
        for out, out_dim in zip(batched_outputs, out_dims_as_tuple)
    )


# Checks that `fn` returned one or more Tensors and nothing else.
# NB: A python function that return multiple arguments returns a single tuple,
# so we are effectively checking that `outputs` is a single Tensor or a tuple of
# Tensors.
def _validate_outputs(outputs: Any, func: Callable) -> None:
    """Raise ValueError unless ``outputs`` is a Tensor or a tuple of Tensors."""
    if isinstance(outputs, Tensor):
        return
    if not isinstance(outputs, tuple):
        raise ValueError(
            f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return "
            f"Tensors, got type {type(outputs)} as the return."
        )
    for idx, output in enumerate(outputs):
        if isinstance(output, Tensor):
            continue
        raise ValueError(
            f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return "
            f"Tensors, got type {type(output)} for return {idx}."
        )


def _check_out_dims_is_int_or_int_tuple(out_dims: out_dims_t, func: Callable) -> None:
    """Raise ValueError unless ``out_dims`` is an int or a flat tuple of ints."""
    if isinstance(out_dims, int):
        return
    if not isinstance(out_dims, tuple) or not all(
        isinstance(out_dim, int) for out_dim in out_dims
    ):
        raise ValueError(
            f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be "
            f"an int or a tuple of int representing where in the outputs the "
            f"vmapped dimension should appear."
        )


def _get_name(func: Callable):
    """Return ``func.__name__`` if present, else ``repr(func)`` (for messages)."""
    if hasattr(func, "__name__"):
        return func.__name__

    # Not all callables have __name__, in fact, only static functions/methods do.
    # A callable created via functools.partial or an nn.Module, to name some
    # examples, don't have a __name__.
    return repr(func)


# vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors,
# sends those into func, and then unwraps the output BatchedTensors. Operations
# on BatchedTensors perform the batched operations that the user is asking for.
@deprecated(
    "Please use `torch.vmap` instead of `torch._vmap_internals.vmap`.",
    category=FutureWarning,
)
def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
    """
    Please use torch.vmap instead of this API.

    Deprecated legacy vmap: returns a function that maps ``func`` over the
    input dimensions selected by ``in_dims`` and places the mapped dimension
    of each output at ``out_dims``.
    """
    return _vmap(func, in_dims, out_dims)


# Same as `vmap`, but not wrapped in `@deprecated`, so it does not emit the
# FutureWarning that the public `vmap` does.
def _vmap(
    func: Callable,
    in_dims: in_dims_t = 0,
    out_dims: out_dims_t = 0,
    allow_none_pass_through: bool = False,
) -> Callable:
    # The `allow_none_pass_through` argument is a temporary workaround that may be removed.
    # Currently it enables us to wrap the call in `autograd.grad` to the autograd engine,
    # which may return None if any of the inputs are unused. See the issue discussing this:
    # https://github.com/facebookresearch/functorch/issues/159.
    @functools.wraps(func)
    def wrapped(*args):
        _check_out_dims_is_int_or_int_tuple(out_dims, func)
        vmap_level = torch._C._vmapmode_increment_nesting()
        try:
            batched_inputs, batch_size = _create_batched_inputs(
                in_dims, args, vmap_level, func
            )
            batched_outputs = func(*batched_inputs)
            # Output validation is skipped in pass-through mode because None
            # outputs are allowed there.
            if not allow_none_pass_through:
                _validate_outputs(batched_outputs, func)
            return _unwrap_batched(
                batched_outputs,
                out_dims,
                vmap_level,
                batch_size,
                func,
                allow_none_pass_through=allow_none_pass_through,
            )
        finally:
            # Always pop the nesting level, even if func or validation raised.
            torch._C._vmapmode_decrement_nesting()

    return wrapped