import time
from argparse import ArgumentParser
from collections import defaultdict
from typing import Any, Callable, List, NamedTuple

import torch
from torch.autograd import functional


try:
    import functorch as ft

    has_functorch = True
    print(f"Found functorch: {ft.__version__}")
except ImportError:
    has_functorch = False

import audio_text_models
import ppl_models
import vision_models

from utils import GetterType, InputsType, TimingResultType, to_markdown_table, VType


def get_task_func(task: str) -> Callable:
    # These wrappers accept `strict` only to keep a uniform signature with the
    # other tasks; it is deliberately not forwarded to the vectorized calls.
    def hessian_fwdrev(model, inp, strict=None):
        return functional.hessian(
            model,
            inp,
            strict=False,
            vectorize=True,
            outer_jacobian_strategy="forward-mode",
        )

    def hessian_revrev(model, inp, strict=None):
        return functional.hessian(model, inp, strict=False, vectorize=True)

    def jacfwd(model, inp, strict=None):
        return functional.jacobian(
            model, inp, strict=False, vectorize=True, strategy="forward-mode"
        )

    def jacrev(model, inp, strict=None):
        return functional.jacobian(model, inp, strict=False, vectorize=True)

    if task == "hessian_fwdrev":
        return hessian_fwdrev
    elif task == "hessian_revrev":
        return hessian_revrev
    elif task == "jacfwd":
        return jacfwd
    elif task == "jacrev":
        return jacrev
    else:
        return getattr(functional, task)


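# For reference (illustrative, not used by the benchmark itself): any other
# task name resolves straight to torch.autograd.functional, so
# get_task_func("vjp") is functional.vjp and would be invoked as
#
#     out, vjp_val = get_task_func("vjp")(model, inp, v=v, strict=True)

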
def get_task_functorch(task: str) -> Callable:
    @torch.no_grad()
    def vjp(model, inp, v=None, strict=None):
        assert v is not None
        out, vjpfunc = ft.vjp(model, *inp)
        return out, vjpfunc(v)

    @torch.no_grad()
    def jvp(model, inp, v=None, strict=None):
        assert v is not None
        return ft.jvp(model, inp, v)

    @torch.no_grad()
    def vhp(model, inp, v=None, strict=None):
        assert v is not None
        argnums = tuple(range(len(inp)))
        _, vjpfunc, aux = ft.vjp(ft.grad_and_value(model, argnums), *inp, has_aux=True)
        return aux, vjpfunc(v)

    @torch.no_grad()
    def hvp(model, inp, v=None, strict=None):
        assert v is not None
        argnums = tuple(range(len(inp)))
        _, hvp_out, aux = ft.jvp(
            ft.grad_and_value(model, argnums), inp, v, has_aux=True
        )
        return aux, hvp_out

    @torch.no_grad()
    def jacfwd(model, inp, v=None, strict=None):
        argnums = tuple(range(len(inp)))
        return ft.jacfwd(model, argnums)(*inp)

    @torch.no_grad()
    def jacrev(model, inp, v=None, strict=None):
        argnums = tuple(range(len(inp)))
        return ft.jacrev(model, argnums)(*inp)

    @torch.no_grad()
    def hessian(model, inp, v=None, strict=None):
        argnums = tuple(range(len(inp)))
        return ft.hessian(model, argnums=argnums)(*inp)

    @torch.no_grad()
    def hessian_fwdrev(model, inp, v=None, strict=None):
        argnums = tuple(range(len(inp)))
        return ft.jacfwd(ft.jacrev(model, argnums=argnums), argnums=argnums)(*inp)

    @torch.no_grad()
    def hessian_revrev(model, inp, v=None, strict=None):
        argnums = tuple(range(len(inp)))
        return ft.jacrev(ft.jacrev(model, argnums=argnums), argnums=argnums)(*inp)

    # Dispatch on the local wrapper that shares the task's name.
    if task in locals():
        return locals()[task]
    elif task == "jacobian":
        raise RuntimeError(
            "functorch has no equivalent of autograd.functional.jacobian with vectorize=False yet"
        )
    else:
        raise RuntimeError(f"Unsupported task: {task}")


# Listing of the different tasks
FAST_TASKS_NO_DOUBLE_BACK = [
    "vjp",
]

FAST_TASKS = FAST_TASKS_NO_DOUBLE_BACK + [
    "vhp",
    "jvp",
]

ALL_TASKS_NON_VECTORIZED = FAST_TASKS + ["hvp", "jacobian", "hessian"]

DOUBLE_BACKWARD_TASKS = ["jvp", "hvp", "vhp", "hessian"]

VECTORIZED_TASKS = ["hessian_fwdrev", "hessian_revrev", "jacfwd", "jacrev"]

ALL_TASKS = ALL_TASKS_NON_VECTORIZED + VECTORIZED_TASKS


# Model definition which contains:
# - name: a string with the model name.
# - getter: a function to get the model. It takes as input the device on which the model
#     will run. It should return the forward function and the parameters (Tensors) used as
#     input for the forward function. Note that the forward must *not* have any side effects.
#     (An illustrative sketch of such a getter follows the class definition below.)
# - tasks: the list of recommended tasks that can run in a reasonable amount of time with this model.
# - unsupported: the list of tasks that this model cannot run.
class ModelDef(NamedTuple):
    name: str
    getter: GetterType
    tasks: List[str]
    unsupported: List[str]


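# A minimal sketch (hypothetical, for illustration only) of a getter
# conforming to the contract above; the real getters live in vision_models,
# ppl_models and audio_text_models:
#
#     def get_toy_model(device: torch.device):
#         weight = torch.rand(3, 3, device=device)
#
#         def forward(inp):
#             # Pure function of its inputs: no in-place updates, no globals.
#             return (inp @ weight).sum()
#
#         inp = (torch.rand(2, 3, device=device),)
#         return forward, inp

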
MODELS = [
    ModelDef("resnet18", vision_models.get_resnet18, FAST_TASKS, []),
    ModelDef("fcn_resnet", vision_models.get_fcn_resnet, FAST_TASKS, []),
    ModelDef("detr", vision_models.get_detr, FAST_TASKS, []),
    ModelDef("ppl_simple_reg", ppl_models.get_simple_regression, ALL_TASKS, []),
    ModelDef("ppl_robust_reg", ppl_models.get_robust_regression, ALL_TASKS, []),
    ModelDef("wav2letter", audio_text_models.get_wav2letter, FAST_TASKS, []),
    ModelDef(
        "deepspeech",
        audio_text_models.get_deepspeech,
        FAST_TASKS_NO_DOUBLE_BACK,
        DOUBLE_BACKWARD_TASKS,
    ),
    ModelDef("transformer", audio_text_models.get_transformer, FAST_TASKS, []),
    ModelDef("multiheadattn", audio_text_models.get_multiheadattn, FAST_TASKS, []),
]


def get_v_for(model: Callable, inp: InputsType, task: str) -> VType:
    # "vjp" takes a cotangent shaped like the model output; "jvp", "hvp" and
    # "vhp" take tangents shaped like the model inputs; other tasks need no v.
    v: VType

    if task in ["vjp"]:
        out = model(*inp)
        v = torch.rand_like(out)
    elif task in ["jvp", "hvp", "vhp"]:
        if isinstance(inp, tuple):
            v = tuple(torch.rand_like(i) for i in inp)
        else:
            v = torch.rand_like(inp)
    else:
        v = None

    return v


def run_once(model: Callable, inp: InputsType, task: str, v: VType, **kwargs) -> None:
    func = get_task_func(task)

    if v is not None:
        func(model, inp, v=v, strict=True)
    else:
        func(model, inp, strict=True)


def run_once_functorch(
    model: Callable, inp: InputsType, task: str, v: VType, maybe_check_consistency=False
) -> None:
    func = get_task_functorch(task)

    if v is not None:
        res = func(model, inp, v=v, strict=True)
    else:
        res = func(model, inp, strict=True)

    if maybe_check_consistency:
        af_func = get_task_func(task)
        if v is not None:
            expected = af_func(model, inp, v=v, strict=True)
        else:
            expected = af_func(model, inp, strict=True)
        atol = 1e-2 if task == "vhp" else 5e-3
        torch.testing.assert_close(
            res,
            expected,
            rtol=1e-5,
            atol=atol,
            msg=f"Consistency fail for task '{task}'",
        )


def run_model(
    model_getter: GetterType, args: Any, task: str, run_once_fn: Callable = run_once
) -> List[float]:
    if args.gpu == -1:
        device = torch.device("cpu")

        def noop():
            pass

        do_sync = noop
    else:
        device = torch.device(f"cuda:{args.gpu}")
        do_sync = torch.cuda.synchronize

    model, inp = model_getter(device)

    v = get_v_for(model, inp, task)

    # Warmup. maybe_check_consistency=True checks consistency between
    # functorch and autograd.functional; only run_once_functorch honors it.
    run_once_fn(model, inp, task, v, maybe_check_consistency=True)

    elapsed = []
    for it in range(args.num_iters):
        do_sync()
        start = time.time()
        run_once_fn(model, inp, task, v)
        do_sync()
        elapsed.append(time.time() - start)

    return elapsed


def main():
    parser = ArgumentParser("Main script to benchmark the functional API of autograd.")
    parser.add_argument(
        "--output", type=str, default="", help="Text file to write the output to"
    )
    parser.add_argument("--num-iters", type=int, default=10)
    parser.add_argument(
        "--gpu",
        type=int,
        default=-2,
        help="GPU to use, -1 for CPU and -2 for auto-detect",
    )
    parser.add_argument(
        "--run-slow-tasks", action="store_true", help="Run even the slow tasks"
    )
    parser.add_argument(
        "--model-filter",
        type=str,
        default="",
        help="Only run the models in this filter",
    )
    parser.add_argument(
        "--task-filter", type=str, default="", help="Only run the tasks in this filter"
    )
    parser.add_argument(
        "--num-threads",
        type=int,
        default=10,
        help="Number of concurrent threads to use when running on CPU",
    )
    parser.add_argument("--seed", type=int, default=0, help="The random seed to use.")
    args = parser.parse_args()

    results: TimingResultType = defaultdict(defaultdict)
    torch.set_num_threads(args.num_threads)
    torch.set_num_interop_threads(args.num_threads)

    # This automatically seeds CUDA if it is available
    torch.manual_seed(args.seed)

    if args.gpu == -2:
        args.gpu = 0 if torch.cuda.is_available() else -1

    for name, model_getter, recommended_tasks, unsupported_tasks in MODELS:
        if args.model_filter and name not in args.model_filter:
            continue
        tasks = ALL_TASKS if args.run_slow_tasks else recommended_tasks
        for task in tasks:
            if task in unsupported_tasks:
                continue
            if args.task_filter and task not in args.task_filter:
                continue
            runtimes = run_model(model_getter, args, task)

            runtimes = torch.tensor(runtimes)
            mean, var = runtimes.mean(), runtimes.var()
            results[name][task] = (mean.item(), var.item())
            print(f"Results for model {name} on task {task}: {mean}s (var: {var})")

            if has_functorch:
                try:
                    runtimes = run_model(
                        model_getter, args, task, run_once_fn=run_once_functorch
                    )
                except RuntimeError as e:
                    print(
                        f"Failed model using functorch: {name}, task: {task}, Error message:\n\t",
                        e,
                    )
                    continue

                runtimes = torch.tensor(runtimes)
                mean, var = runtimes.mean(), runtimes.var()
                results[name][f"functorch {task}"] = (mean.item(), var.item())
                print(
                    f"Results for model {name} on task {task} using functorch: {mean}s (var: {var})"
                )

    if args.output:
        with open(args.output, "w") as f:
            f.write(to_markdown_table(results))


if __name__ == "__main__":
    main()
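
# Example invocation (script name illustrative; flags as defined in main()):
#     python functional_autograd_benchmark.py --gpu -1 --num-iters 20 \
#         --model-filter resnet18 --task-filter vjp --output results.md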