xref: /aosp_15_r20/external/pytorch/benchmarks/functional_autograd_benchmark/compare.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import argparse
2from collections import defaultdict
3
4from utils import from_markdown_table, to_markdown_table
5
6
def _read_results(path):
    """Read one timing-table file and parse it with ``from_markdown_table``.

    Args:
        path: Path of the text file holding the markdown table.

    Returns:
        The parsed results, indexed by this module as
        ``results[model][task] -> (mean, var)``.
    """
    with open(path) as f:
        return from_markdown_table(f.read())


def _compare_results(res_before, res_after):
    """Pair the timings of two runs per (model, task).

    Args:
        res_before: Mapping ``model -> task -> (mean, var)`` for the base run.
        res_after: Mapping ``model -> task -> (mean, var)`` for the new run.

    Returns:
        A mapping ``model -> task -> (speedup, mean_before, var_before,
        mean_after, var_after)`` with ``speedup = mean_before / mean_after``.
        Values for a side that lacks the (model, task) entry are ``None``, as
        is the speedup in that case. The speedup is also ``None`` when
        ``mean_after`` is zero, because the ratio is undefined.
        Models from ``res_before`` come first (in their order), followed by
        models/tasks that only appear in ``res_after``.
    """
    diff = defaultdict(dict)

    for model, tasks_before in res_before.items():
        # .get() instead of [] so a model missing from the new run neither
        # raises KeyError (plain dict) nor gets inserted into the caller's
        # mapping (defaultdict).
        tasks_after = res_after.get(model, {})
        for task, (mean_before, var_before) in tasks_before.items():
            if task in tasks_after:
                mean_after, var_after = tasks_after[task]
                # Guard the ratio: a zero "after" mean would otherwise abort
                # the whole comparison with ZeroDivisionError.
                speedup = mean_before / mean_after if mean_after else None
                diff[model][task] = (
                    speedup,
                    mean_before,
                    var_before,
                    mean_after,
                    var_after,
                )
            else:
                diff[model][task] = (None, mean_before, var_before, None, None)

    # Entries only present in the new run.
    for model, tasks_after in res_after.items():
        tasks_before = res_before.get(model, {})
        for task, (mean_after, var_after) in tasks_after.items():
            if task not in tasks_before:
                diff[model][task] = (None, None, None, mean_after, var_after)

    return diff


def main():
    """Compare two benchmark timing tables and print the result.

    Reads the ``--before`` and ``--after`` markdown tables, pairs their
    entries per (model, task) with the speedup ``mean_before / mean_after``,
    prints the comparison as a markdown table and, when ``--output`` is set,
    also writes it to that file.
    """
    parser = argparse.ArgumentParser(
        # Must be passed by keyword: the first positional parameter of
        # ArgumentParser is ``prog``, which would make this sentence the
        # program name shown in the usage line.
        description="Main script to compare results from the benchmarks"
    )
    parser.add_argument(
        "--before",
        type=str,
        default="before.txt",
        help="Text file containing the times to use as base",
    )
    parser.add_argument(
        "--after",
        type=str,
        default="after.txt",
        help="Text file containing the times to use as new version",
    )
    parser.add_argument(
        "--output", type=str, default="", help="Text file where to write the output"
    )
    args = parser.parse_args()

    res_before = _read_results(args.before)
    res_after = _read_results(args.after)
    diff = _compare_results(res_before, res_after)

    header = (
        "model",
        "task",
        "speedup",
        "mean (before)",
        "var (before)",
        "mean (after)",
        "var (after)",
    )
    out = to_markdown_table(diff, header=header)

    print(out)
    if args.output:
        with open(args.output, "w") as f:
            f.write(out)
72
73
if __name__ == "__main__":
    # Run the comparison only when executed as a script, not when imported.
    main()
76