1# mypy: allow-untyped-defs 2""" 3The APIs in this file are exposed as `functorch.*`. They are thin wrappers 4around the torch.func.* APIs that have deprecation warnings -- we're trying 5to move people to the torch.func.* equivalents. 6 7NB: We don't use *args, **kwargs in the signatures because that changes the 8documentation. 9""" 10 11import textwrap 12import warnings 13from typing import Any, Callable, Optional, Tuple, Union 14 15import torch._functorch.apis as apis 16import torch._functorch.eager_transforms as _impl 17import torch._functorch.make_functional as _nn_impl 18import torch.nn as nn 19from torch._functorch.eager_transforms import argnums_t 20from torch._functorch.vmap import in_dims_t, out_dims_t 21 22 23def get_warning(api, new_api=None, replace_newlines=False): 24 if new_api is None: 25 new_api = f"torch.func.{api}" 26 warning = ( 27 f"We've integrated functorch into PyTorch. As the final step of the \n" 28 f"integration, `functorch.{api}` is deprecated as of PyTorch \n" 29 f"2.0 and will be deleted in a future version of PyTorch >= 2.3. \n" 30 f"Please use `{new_api}` instead; see the PyTorch 2.0 release notes \n" 31 f"and/or the `torch.func` migration guide for more details \n" 32 f"https://pytorch.org/docs/main/func.migrating.html" 33 ) 34 if replace_newlines: 35 warning = warning.replace("\n", "") 36 return warning 37 38 39def warn_deprecated(api, new_api=None): 40 warning = get_warning(api, new_api, replace_newlines=True) 41 warnings.warn(warning, FutureWarning, stacklevel=3) 42 43 44def setup_docs(functorch_api, torch_func_api=None, new_api_name=None): 45 api_name = functorch_api.__name__ 46 if torch_func_api is None: 47 torch_func_api = getattr(_impl, api_name) 48 # See https://docs.python.org/3/using/cmdline.html#cmdoption-OO 49 if torch_func_api.__doc__ is None: 50 return 51 52 warning = get_warning(api_name, new_api_name) 53 warning_note = "\n.. warning::\n\n" + textwrap.indent(warning, " ") 54 warning_note = textwrap.indent(warning_note, " ") 55 functorch_api.__doc__ = torch_func_api.__doc__ + warning_note 56 57 58def vmap( 59 func: Callable, 60 in_dims: in_dims_t = 0, 61 out_dims: out_dims_t = 0, 62 randomness: str = "error", 63 *, 64 chunk_size=None, 65) -> Callable: 66 warn_deprecated("vmap", "torch.vmap") 67 return apis.vmap(func, in_dims, out_dims, randomness, chunk_size=chunk_size) 68 69 70def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable: 71 warn_deprecated("grad") 72 return apis.grad(func, argnums, has_aux) 73 74 75def grad_and_value( 76 func: Callable, argnums: argnums_t = 0, has_aux: bool = False 77) -> Callable: 78 warn_deprecated("grad_and_value") 79 return apis.grad_and_value(func, argnums, has_aux) 80 81 82def vjp(func: Callable, *primals, has_aux: bool = False): 83 warn_deprecated("vjp") 84 return _impl.vjp(func, *primals, has_aux=has_aux) 85 86 87def jvp( 88 func: Callable, 89 primals: Any, 90 tangents: Any, 91 *, 92 strict: bool = False, 93 has_aux: bool = False, 94): 95 warn_deprecated("jvp") 96 return _impl.jvp(func, primals, tangents, strict=strict, has_aux=has_aux) 97 98 99def jacrev( 100 func: Callable, 101 argnums: Union[int, Tuple[int]] = 0, 102 *, 103 has_aux=False, 104 chunk_size: Optional[int] = None, 105 _preallocate_and_copy=False, 106): 107 warn_deprecated("jacrev") 108 return _impl.jacrev( 109 func, 110 argnums, 111 has_aux=has_aux, 112 chunk_size=chunk_size, 113 _preallocate_and_copy=_preallocate_and_copy, 114 ) 115 116 117def jacfwd( 118 func: Callable, 119 argnums: argnums_t = 0, 120 has_aux: bool = False, 121 *, 122 randomness: str = "error", 123): 124 warn_deprecated("jacfwd") 125 return _impl.jacfwd(func, argnums, has_aux, randomness=randomness) 126 127 128def hessian(func, argnums=0): 129 warn_deprecated("hessian") 130 return _impl.hessian(func, argnums=argnums) 131 132 133def functionalize(func: Callable, *, remove: str = "mutations") -> Callable: 134 warn_deprecated("functionalize") 135 return _impl.functionalize(func, remove=remove) 136 137 138def make_functional(model: nn.Module, disable_autograd_tracking: bool = False): 139 warn_deprecated("make_functional", "torch.func.functional_call") 140 return _nn_impl.make_functional(model, disable_autograd_tracking) 141 142 143def make_functional_with_buffers( 144 model: nn.Module, disable_autograd_tracking: bool = False 145): 146 warn_deprecated("make_functional_with_buffers", "torch.func.functional_call") 147 return _nn_impl.make_functional_with_buffers(model, disable_autograd_tracking) 148 149 150def combine_state_for_ensemble(models): 151 warn_deprecated("combine_state_for_ensemble", "torch.func.stack_module_state") 152 return _nn_impl.combine_state_for_ensemble(models) 153 154 155setup_docs(vmap, apis.vmap, "torch.vmap") 156setup_docs(grad, apis.grad) 157setup_docs(grad_and_value, apis.grad_and_value) 158setup_docs(vjp) 159setup_docs(jvp) 160setup_docs(jacrev) 161setup_docs(jacfwd) 162setup_docs(hessian) 163setup_docs(functionalize) 164setup_docs(make_functional, _nn_impl.make_functional, "torch.func.functional_call") 165setup_docs( 166 make_functional_with_buffers, _nn_impl.make_functional, "torch.func.functional_call" 167) 168setup_docs( 169 combine_state_for_ensemble, 170 _nn_impl.combine_state_for_ensemble, 171 "torch.func.stack_module_state", 172) 173