import time
from argparse import ArgumentParser
from collections import defaultdict
from typing import Any, Callable, List, NamedTuple

import torch
from torch.autograd import functional


try:
    import functorch as ft

    has_functorch = True
    print(f"Found functorch: {ft.__version__}")
except ImportError:
    has_functorch = False

import audio_text_models
import ppl_models
import vision_models

from utils import GetterType, InputsType, TimingResultType, to_markdown_table, VType


def get_task_func(task: str) -> Callable:
    # Vectorized variants of the autograd.functional APIs. The `strict`
    # argument is ignored here; it only keeps the signature uniform with the
    # non-vectorized tasks dispatched via getattr below.
    def hessian_fwdrev(model, inp, strict=None):
        return functional.hessian(
            model,
            inp,
            strict=False,
            vectorize=True,
            outer_jacobian_strategy="forward-mode",
        )

    def hessian_revrev(model, inp, strict=None):
        return functional.hessian(model, inp, strict=False, vectorize=True)

    def jacfwd(model, inp, strict=None):
        return functional.jacobian(
            model, inp, strict=False, vectorize=True, strategy="forward-mode"
        )

    def jacrev(model, inp, strict=None):
        return functional.jacobian(model, inp, strict=False, vectorize=True)

    if task == "hessian_fwdrev":
        return hessian_fwdrev
    elif task == "hessian_revrev":
        return hessian_revrev
    elif task == "jacfwd":
        return jacfwd
    elif task == "jacrev":
        return jacrev
    else:
        return getattr(functional, task)


def get_task_functorch(task: str) -> Callable:
    @torch.no_grad()
    def vjp(model, inp, v=None, strict=None):
        assert v is not None
        out, vjpfunc = ft.vjp(model, *inp)
        return out, vjpfunc(v)

    @torch.no_grad()
    def jvp(model, inp, v=None, strict=None):
        assert v is not None
        return ft.jvp(model, inp, v)

    @torch.no_grad()
    def vhp(model, inp, v=None, strict=None):
        assert v is not None
        argnums = tuple(range(len(inp)))
        _, vjpfunc, aux = ft.vjp(ft.grad_and_value(model, argnums), *inp, has_aux=True)
        return aux, vjpfunc(v)

    @torch.no_grad()
    def hvp(model, inp, v=None, strict=None):
        assert v is not None
        argnums = tuple(range(len(inp)))
        _, hvp_out, aux = ft.jvp(
            ft.grad_and_value(model, argnums), inp, v, has_aux=True
        )
        return aux, hvp_out

    @torch.no_grad()
    def jacfwd(model, inp, v=None, strict=None):
        argnums = tuple(range(len(inp)))
        return ft.jacfwd(model, argnums)(*inp)

    @torch.no_grad()
    def jacrev(model, inp, v=None, strict=None):
        argnums = tuple(range(len(inp)))
        return ft.jacrev(model, argnums)(*inp)

    @torch.no_grad()
    def hessian(model, inp, v=None, strict=None):
        argnums = tuple(range(len(inp)))
        return ft.hessian(model, argnums=argnums)(*inp)

    @torch.no_grad()
    def hessian_fwdrev(model, inp, v=None, strict=None):
        argnums = tuple(range(len(inp)))
        return ft.jacfwd(ft.jacrev(model, argnums=argnums), argnums=argnums)(*inp)

    @torch.no_grad()
    def hessian_revrev(model, inp, v=None, strict=None):
        argnums = tuple(range(len(inp)))
        return ft.jacrev(ft.jacrev(model, argnums=argnums), argnums=argnums)(*inp)

    if task in locals():
        return locals()[task]
    elif task == "jacobian":
        raise RuntimeError(
            "functorch has no equivalent of autograd.functional.jacobian with vectorize=False yet"
        )
    else:
        raise RuntimeError(f"Unsupported task: {task}")


# Listing of the different tasks
FAST_TASKS_NO_DOUBLE_BACK = [
    "vjp",
]

FAST_TASKS = FAST_TASKS_NO_DOUBLE_BACK + [
    "vhp",
    "jvp",
]

ALL_TASKS_NON_VECTORIZED = FAST_TASKS + ["hvp", "jacobian", "hessian"]

DOUBLE_BACKWARD_TASKS = ["jvp", "hvp", "vhp", "hessian"]

VECTORIZED_TASKS = ["hessian_fwdrev", "hessian_revrev", "jacfwd", "jacrev"]

ALL_TASKS = ALL_TASKS_NON_VECTORIZED + VECTORIZED_TASKS


# Model definition which contains:
# - name: a string with the model name.
# - getter: a function to get the model. It takes as input the device on which the model
#   will run. It should return the forward function and the parameters (Tensors) used as
#   input for the forward function. Note that the forward must *not* have any side effect.
#   (See the hedged sketch of this contract at the bottom of this file.)
# - tasks: the list of recommended tasks that can run in a reasonable amount of time with this model.
# - unsupported: the list of tasks that this model cannot run.
class ModelDef(NamedTuple):
    name: str
    getter: GetterType
    tasks: List[str]
    unsupported: List[str]


MODELS = [
    ModelDef("resnet18", vision_models.get_resnet18, FAST_TASKS, []),
    ModelDef("fcn_resnet", vision_models.get_fcn_resnet, FAST_TASKS, []),
    ModelDef("detr", vision_models.get_detr, FAST_TASKS, []),
    ModelDef("ppl_simple_reg", ppl_models.get_simple_regression, ALL_TASKS, []),
    ModelDef("ppl_robust_reg", ppl_models.get_robust_regression, ALL_TASKS, []),
    ModelDef("wav2letter", audio_text_models.get_wav2letter, FAST_TASKS, []),
    ModelDef(
        "deepspeech",
        audio_text_models.get_deepspeech,
        FAST_TASKS_NO_DOUBLE_BACK,
        DOUBLE_BACKWARD_TASKS,
    ),
    ModelDef("transformer", audio_text_models.get_transformer, FAST_TASKS, []),
    ModelDef("multiheadattn", audio_text_models.get_multiheadattn, FAST_TASKS, []),
]


def get_v_for(model: Callable, inp: InputsType, task: str) -> VType:
    v: VType

    if task in ["vjp"]:
        out = model(*inp)
        v = torch.rand_like(out)
    elif task in ["jvp", "hvp", "vhp"]:
        if isinstance(inp, tuple):
            v = tuple(torch.rand_like(i) for i in inp)
        else:
            v = torch.rand_like(inp)
    else:
        v = None

    return v


def run_once(model: Callable, inp: InputsType, task: str, v: VType, **kwargs) -> None:
    func = get_task_func(task)

    if v is not None:
        res = func(model, inp, v=v, strict=True)
    else:
        res = func(model, inp, strict=True)


def run_once_functorch(
    model: Callable, inp: InputsType, task: str, v: VType, maybe_check_consistency=False
) -> None:
    func = get_task_functorch(task)

    if v is not None:
        res = func(model, inp, v=v, strict=True)
    else:
        res = func(model, inp, strict=True)

    if maybe_check_consistency:
        af_func = get_task_func(task)
        if v is not None:
            expected = af_func(model, inp, v=v, strict=True)
        else:
            expected = af_func(model, inp, strict=True)
        atol = 1e-2 if task == "vhp" else 5e-3
        torch.testing.assert_close(
            res,
            expected,
            rtol=1e-5,
            atol=atol,
            msg=f"Consistency fail for task '{task}'",
        )


def run_model(
    model_getter: GetterType, args: Any, task: str, run_once_fn: Callable = run_once
) -> List[float]:
    if args.gpu == -1:
        device = torch.device("cpu")

        def noop():
            pass

        do_sync = noop
    else:
        device = torch.device(f"cuda:{args.gpu}")
        do_sync = torch.cuda.synchronize

    model, inp = model_getter(device)

    v = get_v_for(model, inp, task)

    # Warmup
    # maybe_check_consistency=True checks consistency between functorch and
    # autograd.functional; it is honored by run_once_functorch only.
    run_once_fn(model, inp, task, v, maybe_check_consistency=True)

    elapsed = []
    for it in range(args.num_iters):
        do_sync()
        start = time.time()
        run_once_fn(model, inp, task, v)
        do_sync()
        elapsed.append(time.time() - start)

    return elapsed


def main():
    parser = ArgumentParser(
        description="Main script to benchmark the functional API of autograd."
    )
    parser.add_argument(
        "--output", type=str, default="", help="Text file where to write the output"
    )
    parser.add_argument("--num-iters", type=int, default=10)
    parser.add_argument(
        "--gpu",
        type=int,
        default=-2,
        help="GPU to use, -1 for CPU and -2 for auto-detect",
    )
    parser.add_argument(
        "--run-slow-tasks", action="store_true", help="Run even the slow tasks"
    )
    parser.add_argument(
        "--model-filter",
        type=str,
        default="",
        help="Only run the models in this filter",
    )
    parser.add_argument(
        "--task-filter", type=str, default="", help="Only run the tasks in this filter"
    )
    parser.add_argument(
        "--num-threads",
        type=int,
        default=10,
        help="Number of concurrent threads to use when running on CPU",
    )
    parser.add_argument("--seed", type=int, default=0, help="The random seed to use.")
    args = parser.parse_args()

    results: TimingResultType = defaultdict(defaultdict)
    torch.set_num_threads(args.num_threads)
    torch.set_num_interop_threads(args.num_threads)

    # This automatically seeds CUDA if it is available
    torch.manual_seed(args.seed)

    if args.gpu == -2:
        args.gpu = 0 if torch.cuda.is_available() else -1

    for name, model_getter, recommended_tasks, unsupported_tasks in MODELS:
        if args.model_filter and name not in args.model_filter:
            continue
        tasks = ALL_TASKS if args.run_slow_tasks else recommended_tasks
        for task in tasks:
            if task in unsupported_tasks:
                continue
            if args.task_filter and task not in args.task_filter:
                continue
            runtimes = run_model(model_getter, args, task)

            runtimes = torch.tensor(runtimes)
            mean, var = runtimes.mean(), runtimes.var()
            results[name][task] = (mean.item(), var.item())
            print(f"Results for model {name} on task {task}: {mean}s (var: {var})")

            if has_functorch:
                try:
                    runtimes = run_model(
                        model_getter, args, task, run_once_fn=run_once_functorch
                    )
                except RuntimeError as e:
                    print(
                        f"Failed model using Functorch: {name}, task: {task}, Error message: \n\t",
                        e,
                    )
                    continue

                runtimes = torch.tensor(runtimes)
                mean, var = runtimes.mean(), runtimes.var()
                results[name][f"functorch {task}"] = (mean.item(), var.item())
                print(
                    f"Results for model {name} on task {task} using Functorch: {mean}s (var: {var})"
                )

    if args.output:
        with open(args.output, "w") as f:
            f.write(to_markdown_table(results))


if __name__ == "__main__":
    main()
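

# ---------------------------------------------------------------------------
# Illustration only: a minimal sketch of the getter contract that ModelDef
# expects, referenced from the ModelDef comment above. The name
# "_get_toy_model" is hypothetical and is not used by the benchmark; the real
# getters live in vision_models, ppl_models and audio_text_models. A getter
# takes a device and returns a side-effect-free forward function together with
# the tuple of input Tensors that the tasks differentiate with respect to.
def _get_toy_model(device: torch.device):
    # Fixed data captured by closure; only the Tensors returned in `inputs`
    # are treated as differentiable inputs by the tasks.
    x = torch.rand(8, 3, device=device)

    def forward(weights: torch.Tensor) -> torch.Tensor:
        # Pure function of its inputs: no in-place mutation, no global state.
        # Returns a scalar so that all tasks (vjp, vhp, hessian, ...) apply.
        return (x @ weights).pow(2).sum()

    inputs = (torch.rand(3, device=device),)
    return forward, inputs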
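

# Example invocations (a sketch: the script name below assumes this file is
# saved as functional_autograd_benchmark.py; the flags are the ones defined
# in main() above):
#
#   # CPU-only run of the recommended fast tasks for a single model:
#   python functional_autograd_benchmark.py --gpu -1 --model-filter resnet18
#
#   # Auto-detect the device, time only the "vjp" task across all models, and
#   # write the (mean, variance) markdown table to results.md:
#   python functional_autograd_benchmark.py --task-filter vjp --output results.md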