1# mypy: allow-untyped-decorators 2# mypy: allow-untyped-defs 3# NOTE: We allow Dynamo to see this file (via torch/_dynamo/trace_rules.py) so that it can 4# trace through functorch transforms. 5# Currently, we can't allow Dynamo to see `eager_transforms.py`/`vmap.py` as that break a lot of thing 6# and there isn't a mechanism to selectively expose only some functions (eg. grad) from a file 7# to Dynamo. 8import functools 9 10from torch._functorch.utils import argnums_t, exposed_in 11from torch._functorch.vmap import ( 12 _check_out_dims_is_int_or_int_pytree, 13 _check_randomness_arg, 14 _chunked_vmap, 15 _process_batched_inputs, 16 Callable, 17 in_dims_t, 18 out_dims_t, 19 vmap_impl, 20) 21 22 23# vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors, 24# sends those into func, and then unwraps the output BatchedTensors. Operations 25# on BatchedTensors perform the batched operations that the user is asking for. 26# 27# vmap's randomness behavior differs from JAX's, which would require a PRNG key 28# to be passed everywhere. 29 30 31@exposed_in("torch.func") 32def vmap( 33 func: Callable, 34 in_dims: in_dims_t = 0, 35 out_dims: out_dims_t = 0, 36 randomness: str = "error", 37 *, 38 chunk_size=None, 39) -> Callable: 40 """ 41 vmap is the vectorizing map; ``vmap(func)`` returns a new function that 42 maps ``func`` over some dimension of the inputs. Semantically, vmap 43 pushes the map into PyTorch operations called by ``func``, effectively 44 vectorizing those operations. 45 46 vmap is useful for handling batch dimensions: one can write a function 47 ``func`` that runs on examples and then lift it to a function that can 48 take batches of examples with ``vmap(func)``. vmap can also be used to 49 compute batched gradients when composed with autograd. 50 51 .. note:: 52 :func:`torch.vmap` is aliased to :func:`torch.func.vmap` for 53 convenience. Use whichever one you'd like. 54 55 Args: 56 func (function): A Python function that takes one or more arguments. 57 Must return one or more Tensors. 58 in_dims (int or nested structure): Specifies which dimension of the 59 inputs should be mapped over. ``in_dims`` should have a 60 structure like the inputs. If the ``in_dim`` for a particular 61 input is None, then that indicates there is no map dimension. 62 Default: 0. 63 out_dims (int or Tuple[int]): Specifies where the mapped dimension 64 should appear in the outputs. If ``out_dims`` is a Tuple, then 65 it should have one element per output. Default: 0. 66 randomness (str): Specifies whether the randomness in this 67 vmap should be the same or different across batches. If 'different', 68 the randomness for each batch will be different. If 'same', the 69 randomness will be the same across batches. If 'error', any calls to 70 random functions will error. Default: 'error'. WARNING: this flag 71 only applies to random PyTorch operations and does not apply to 72 Python's random module or numpy randomness. 73 chunk_size (None or int): If None (default), apply a single vmap over inputs. 74 If not None, then compute the vmap :attr:`chunk_size` samples at a time. 75 Note that :attr:`chunk_size=1` is equivalent to computing the vmap with a for-loop. 76 If you run into memory issues computing the vmap, please try a non-None chunk_size. 77 78 Returns: 79 Returns a new "batched" function. It takes the same inputs as 80 ``func``, except each input has an extra dimension at the index 81 specified by ``in_dims``. It takes returns the same outputs as 82 ``func``, except each output has an extra dimension at the index 83 specified by ``out_dims``. 84 85 .. warning: 86 :func:`vmap` works best with functional-style code. Please do not 87 perform any side-effects in ``func``, with the exception of 88 in-place PyTorch operations. Examples of side-effects include mutating 89 Python data structures and assigning values to variables not captured 90 in ``func``. 91 92 One example of using :func:`vmap` is to compute batched dot products. PyTorch 93 doesn't provide a batched ``torch.dot`` API; instead of unsuccessfully 94 rummaging through docs, use :func:`vmap` to construct a new function. 95 96 >>> torch.dot # [D], [D] -> [] 97 >>> batched_dot = torch.func.vmap(torch.dot) # [N, D], [N, D] -> [N] 98 >>> x, y = torch.randn(2, 5), torch.randn(2, 5) 99 >>> batched_dot(x, y) 100 101 :func:`vmap` can be helpful in hiding batch dimensions, leading to a simpler 102 model authoring experience. 103 104 >>> batch_size, feature_size = 3, 5 105 >>> weights = torch.randn(feature_size, requires_grad=True) 106 >>> 107 >>> def model(feature_vec): 108 >>> # Very simple linear model with activation 109 >>> return feature_vec.dot(weights).relu() 110 >>> 111 >>> examples = torch.randn(batch_size, feature_size) 112 >>> result = torch.vmap(model)(examples) 113 114 :func:`vmap` can also help vectorize computations that were previously difficult 115 or impossible to batch. One example is higher-order gradient computation. 116 The PyTorch autograd engine computes vjps (vector-Jacobian products). 117 Computing a full Jacobian matrix for some function f: R^N -> R^N usually 118 requires N calls to ``autograd.grad``, one per Jacobian row. Using :func:`vmap`, 119 we can vectorize the whole computation, computing the Jacobian in a single 120 call to ``autograd.grad``. 121 122 >>> # Setup 123 >>> N = 5 124 >>> f = lambda x: x ** 2 125 >>> x = torch.randn(N, requires_grad=True) 126 >>> y = f(x) 127 >>> I_N = torch.eye(N) 128 >>> 129 >>> # Sequential approach 130 >>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0] 131 >>> for v in I_N.unbind()] 132 >>> jacobian = torch.stack(jacobian_rows) 133 >>> 134 >>> # vectorized gradient computation 135 >>> def get_vjp(v): 136 >>> return torch.autograd.grad(y, x, v) 137 >>> jacobian = torch.vmap(get_vjp)(I_N) 138 139 :func:`vmap` can also be nested, producing an output with multiple batched dimensions 140 141 >>> torch.dot # [D], [D] -> [] 142 >>> batched_dot = torch.vmap(torch.vmap(torch.dot)) # [N1, N0, D], [N1, N0, D] -> [N1, N0] 143 >>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5) 144 >>> batched_dot(x, y) # tensor of size [2, 3] 145 146 If the inputs are not batched along the first dimension, ``in_dims`` specifies 147 the dimension that each inputs are batched along as 148 149 >>> torch.dot # [N], [N] -> [] 150 >>> batched_dot = torch.vmap(torch.dot, in_dims=1) # [N, D], [N, D] -> [D] 151 >>> x, y = torch.randn(2, 5), torch.randn(2, 5) 152 >>> batched_dot(x, y) # output is [5] instead of [2] if batched along the 0th dimension 153 154 If there are multiple inputs each of which is batched along different dimensions, 155 ``in_dims`` must be a tuple with the batch dimension for each input as 156 157 >>> torch.dot # [D], [D] -> [] 158 >>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None)) # [N, D], [D] -> [N] 159 >>> x, y = torch.randn(2, 5), torch.randn(5) 160 >>> batched_dot(x, y) # second arg doesn't have a batch dim because in_dim[1] was None 161 162 If the input is a Python struct, ``in_dims`` must be a tuple containing a struct 163 matching the shape of the input: 164 165 >>> f = lambda dict: torch.dot(dict['x'], dict['y']) 166 >>> x, y = torch.randn(2, 5), torch.randn(5) 167 >>> input = {'x': x, 'y': y} 168 >>> batched_dot = torch.vmap(f, in_dims=({'x': 0, 'y': None},)) 169 >>> batched_dot(input) 170 171 By default, the output is batched along the first dimension. However, it can be batched 172 along any dimension by using ``out_dims`` 173 174 >>> f = lambda x: x ** 2 175 >>> x = torch.randn(2, 5) 176 >>> batched_pow = torch.vmap(f, out_dims=1) 177 >>> batched_pow(x) # [5, 2] 178 179 For any function that uses kwargs, the returned function will not batch the kwargs but will 180 accept kwargs 181 182 >>> x = torch.randn([2, 5]) 183 >>> def fn(x, scale=4.): 184 >>> return x * scale 185 >>> 186 >>> batched_pow = torch.vmap(fn) 187 >>> assert torch.allclose(batched_pow(x), x * 4) 188 >>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5] 189 190 .. note:: 191 vmap does not provide general autobatching or handle variable-length 192 sequences out of the box. 193 """ 194 from torch._dynamo import is_compiling 195 196 _check_randomness_arg(randomness) 197 if not (chunk_size is None or chunk_size > 0): 198 raise ValueError( 199 f"vmap: chunk_size should be None or greater than 0. (got {chunk_size})" 200 ) 201 202 def wrapped(*args, **kwargs): 203 return vmap_impl( 204 func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs 205 ) 206 207 if not is_compiling(): 208 wrapped = functools.wraps(func)(wrapped) 209 210 return wrapped 211 212 213def chunk_vmap( 214 func: Callable, 215 in_dims: in_dims_t = 0, 216 out_dims: out_dims_t = 0, 217 randomness: str = "error", 218 chunks=2, 219) -> Callable: 220 """ 221 chunk_vmap is the vectorizing map (vmap) using chunks of input data. It is a mix of vmap (which vectorizes 222 everything) and map (which executes things sequentially). ``chunk_vmap`` vectorizes the input with number of 223 chunks at a time. For more details about vectorizing map, see :func:`vmap`. 224 225 .. note:: 226 Please use :func:`vmap` with ``chunk_size`` argument instead of this API. 227 228 Args: 229 func (function): A Python function that takes one or more arguments. 230 Must return one or more Tensors. 231 in_dims (int or nested structure): Specifies which dimension of the 232 inputs should be mapped over. ``in_dims`` should have a 233 structure like the inputs. If the ``in_dim`` for a particular 234 input is None, then that indicates there is no map dimension. 235 Default: 0. 236 out_dims (int or Tuple[int]): Specifies where the mapped dimension 237 should appear in the outputs. If ``out_dims`` is a Tuple, then 238 it should have one element per output. Default: 0. 239 randomness (str): Specifies whether the randomness in this 240 vmap should be the same or different across batches. If 'different', 241 the randomness for each batch will be different. If 'same', the 242 randomness will be the same across batches. If 'error', any calls to 243 random functions will error. Default: 'error'. WARNING: this flag 244 only applies to random PyTorch operations and does not apply to 245 Python's random module or numpy randomness. 246 chunks (int): Number of chunks to use to split the input data. Default is 2. 247 If equals to 1 then :func:`vmap` is called. 248 249 Returns: 250 Returns a new "batched" function. It takes the same inputs as 251 ``func``, except each input has an extra dimension at the index 252 specified by ``in_dims``. It takes returns the same outputs as 253 ``func``, except each output has an extra dimension at the index 254 specified by ``out_dims``. 255 """ 256 _check_randomness_arg(randomness) 257 258 if chunks == 1: 259 return vmap(func, in_dims=in_dims, out_dims=out_dims, randomness=randomness) 260 261 def _get_chunk_flat_args(flat_args_, flat_in_dims_, chunks_): 262 flat_args_chunks = tuple( 263 t.chunk(chunks_, dim=in_dim) 264 if in_dim is not None 265 else [ 266 t, 267 ] 268 * chunks_ 269 for t, in_dim in zip(flat_args_, flat_in_dims_) 270 ) 271 # transpose chunk dim and flatten structure 272 # chunks_flat_args is a list of flatten args 273 chunks_flat_args = zip(*flat_args_chunks) 274 return chunks_flat_args 275 276 @functools.wraps(func) 277 def wrapped_with_chunks(*args, **kwargs): 278 _check_out_dims_is_int_or_int_pytree(out_dims, func) 279 _, flat_in_dims, flat_args, args_spec = _process_batched_inputs( 280 in_dims, args, func 281 ) 282 # Chunk flat arguments 283 chunks_flat_args = _get_chunk_flat_args(flat_args, flat_in_dims, chunks) 284 285 # Apply vmap on chunks 286 return _chunked_vmap( 287 func, 288 flat_in_dims, 289 chunks_flat_args, 290 args_spec, 291 out_dims, 292 randomness, 293 **kwargs, 294 ) 295 296 return wrapped_with_chunks 297 298 299@exposed_in("torch.func") 300def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable: 301 """``grad`` operator helps computing gradients of ``func`` with respect to the 302 input(s) specified by ``argnums``. This operator can be nested to 303 compute higher-order gradients. 304 305 Args: 306 func (Callable): A Python function that takes one or more arguments. 307 Must return a single-element Tensor. If specified ``has_aux`` equals ``True``, 308 function can return a tuple of single-element Tensor and other auxiliary objects: 309 ``(output, aux)``. 310 argnums (int or Tuple[int]): Specifies arguments to compute gradients with respect to. 311 ``argnums`` can be single integer or tuple of integers. Default: 0. 312 has_aux (bool): Flag indicating that ``func`` returns a tensor and other 313 auxiliary objects: ``(output, aux)``. Default: False. 314 315 Returns: 316 Function to compute gradients with respect to its inputs. By default, the output of 317 the function is the gradient tensor(s) with respect to the first argument. 318 If specified ``has_aux`` equals ``True``, tuple of gradients and output auxiliary objects 319 is returned. If ``argnums`` is a tuple of integers, a tuple of output gradients with 320 respect to each ``argnums`` value is returned. 321 322 Example of using ``grad``: 323 324 >>> # xdoctest: +SKIP 325 >>> from torch.func import grad 326 >>> x = torch.randn([]) 327 >>> cos_x = grad(lambda x: torch.sin(x))(x) 328 >>> assert torch.allclose(cos_x, x.cos()) 329 >>> 330 >>> # Second-order gradients 331 >>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x) 332 >>> assert torch.allclose(neg_sin_x, -x.sin()) 333 334 When composed with ``vmap``, ``grad`` can be used to compute per-sample-gradients: 335 336 >>> # xdoctest: +SKIP 337 >>> from torch.func import grad, vmap 338 >>> batch_size, feature_size = 3, 5 339 >>> 340 >>> def model(weights, feature_vec): 341 >>> # Very simple linear model with activation 342 >>> assert feature_vec.dim() == 1 343 >>> return feature_vec.dot(weights).relu() 344 >>> 345 >>> def compute_loss(weights, example, target): 346 >>> y = model(weights, example) 347 >>> return ((y - target) ** 2).mean() # MSELoss 348 >>> 349 >>> weights = torch.randn(feature_size, requires_grad=True) 350 >>> examples = torch.randn(batch_size, feature_size) 351 >>> targets = torch.randn(batch_size) 352 >>> inputs = (weights, examples, targets) 353 >>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs) 354 355 Example of using ``grad`` with ``has_aux`` and ``argnums``: 356 357 >>> # xdoctest: +SKIP 358 >>> from torch.func import grad 359 >>> def my_loss_func(y, y_pred): 360 >>> loss_per_sample = (0.5 * y_pred - y) ** 2 361 >>> loss = loss_per_sample.mean() 362 >>> return loss, (y_pred, loss_per_sample) 363 >>> 364 >>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True) 365 >>> y_true = torch.rand(4) 366 >>> y_preds = torch.rand(4, requires_grad=True) 367 >>> out = fn(y_true, y_preds) 368 >>> # > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample)) 369 370 .. note:: 371 Using PyTorch ``torch.no_grad`` together with ``grad``. 372 373 Case 1: Using ``torch.no_grad`` inside a function: 374 375 >>> # xdoctest: +SKIP 376 >>> def f(x): 377 >>> with torch.no_grad(): 378 >>> c = x ** 2 379 >>> return x - c 380 381 In this case, ``grad(f)(x)`` will respect the inner ``torch.no_grad``. 382 383 Case 2: Using ``grad`` inside ``torch.no_grad`` context manager: 384 385 >>> # xdoctest: +SKIP 386 >>> with torch.no_grad(): 387 >>> grad(f)(x) 388 389 In this case, ``grad`` will respect the inner ``torch.no_grad``, but not the 390 outer one. This is because ``grad`` is a "function transform": its result 391 should not depend on the result of a context manager outside of ``f``. 392 393 """ 394 # To avoid cyclical dependency. 395 import torch._functorch.eager_transforms as eager_transforms 396 from torch._dynamo import is_compiling 397 398 def wrapper(*args, **kwargs): 399 return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs) 400 401 if not is_compiling(): 402 wrapper = functools.wraps(func)(wrapper) 403 404 return wrapper 405 406 407@exposed_in("torch.func") 408def grad_and_value( 409 func: Callable, argnums: argnums_t = 0, has_aux: bool = False 410) -> Callable: 411 """ 412 Returns a function to compute a tuple of the gradient and primal, or 413 forward, computation. 414 415 Args: 416 func (Callable): A Python function that takes one or more arguments. 417 Must return a single-element Tensor. If specified ``has_aux`` 418 equals ``True``, function can return a tuple of single-element 419 Tensor and other auxiliary objects: ``(output, aux)``. 420 argnums (int or Tuple[int]): Specifies arguments to compute gradients 421 with respect to. ``argnums`` can be single integer or tuple of 422 integers. Default: 0. 423 has_aux (bool): Flag indicating that ``func`` returns a tensor and 424 other auxiliary objects: ``(output, aux)``. Default: False. 425 426 Returns: 427 Function to compute a tuple of gradients with respect to its inputs 428 and the forward computation. By default, the output of the function is 429 a tuple of the gradient tensor(s) with respect to the first argument 430 and the primal computation. If specified ``has_aux`` equals 431 ``True``, tuple of gradients and tuple of the forward computation with 432 output auxiliary objects is returned. If ``argnums`` is a tuple of 433 integers, a tuple of a tuple of the output gradients with respect to 434 each ``argnums`` value and the forward computation is returned. 435 436 See :func:`grad` for examples 437 """ 438 from torch._dynamo import is_compiling 439 from torch._functorch import eager_transforms 440 441 def wrapper(*args, **kwargs): 442 return eager_transforms.grad_and_value_impl( 443 func, argnums, has_aux, args, kwargs 444 ) 445 446 if not is_compiling(): 447 wrapper = functools.wraps(func)(wrapper) 448 449 return wrapper 450