xref: /aosp_15_r20/external/pytorch/benchmarks/compare-fastrnn-results.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport argparse
2*da0073e9SAndroid Build Coastguard Workerimport json
3*da0073e9SAndroid Build Coastguard Workerfrom collections import namedtuple
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Worker
# One comparison row: the benchmark key plus its timing in the base and diff runs.
Result = namedtuple("Result", "name base_time diff_time")
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Worker
def construct_name(fwd_bwd, test_name):
    """Build a display key of the form ``suite[test]:fwd`` or ``suite[test]:bwd``.

    Args:
        fwd_bwd: top-level key of the benchmark JSON. Every ``"-backward"``
            substring is removed to obtain the suite name.
        test_name: second-level key of the benchmark JSON.

    Returns:
        The combined key. The direction tag is ``bwd`` whenever the text
        ``"backward"`` occurs anywhere in ``fwd_bwd``, otherwise ``fwd``.
    """
    direction = "bwd" if "backward" in fwd_bwd else "fwd"
    suite = fwd_bwd.replace("-backward", "")
    return f"{suite}[{test_name}]:{direction}"
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Worker
def get_times(json_data):
    """Flatten a two-level ``{suite_key: {test_key: value}}`` mapping.

    Each inner value is stored under the key produced by
    ``construct_name(suite_key, test_key)``. If two entries map to the same
    key, the one visited later wins.

    Args:
        json_data: the parsed benchmark JSON (a mapping of mappings).

    Returns:
        A flat dict from display key to the original inner value.
    """
    return {
        construct_name(suite_key, test_key): timing
        for suite_key, tests in json_data.items()
        for test_key, timing in tests.items()
    }
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Worker
# The first positional argument of ArgumentParser is ``prog`` (the program
# name shown in usage), so the summary text must go in ``description``.
parser = argparse.ArgumentParser(description="compare two pytest jsons")
parser.add_argument("base", help="base json file")
parser.add_argument("diff", help="diff json file")
parser.add_argument(
    "--format",
    default="md",
    type=str,
    # Reject unknown formats up front, before the input files are read.
    choices=["csv", "md", "json", "table"],
    help="output format (csv, md, json, table)",
)
args = parser.parse_args()

# JSON text is UTF-8 by specification; don't depend on the locale encoding.
with open(args.base, encoding="utf-8") as base:
    base_times = get_times(json.load(base))
with open(args.diff, encoding="utf-8") as diff:
    diff_times = get_times(json.load(diff))

# Every key seen in either run gets a row, sorted by name. A side that lacks
# the key reports NaN for its time.
all_keys = set(base_times).union(diff_times)
results = [
    Result(name, base_times.get(name, float("nan")), diff_times.get(name, float("nan")))
    for name in sorted(all_keys)
]
42*da0073e9SAndroid Build Coastguard Worker
# Four-column layouts (name, base time, diff time, % change) for the
# textual output formats. header_fmt renders the column titles;
# data_fmt renders one Result row.
header_fmt = {
    "table": "{:48s} {:>13s} {:>15s} {:>10s}",
    "md": "| {:48s} | {:>13s} | {:>15s} | {:>10s} |",
    "csv": "{:s}, {:s}, {:s}, {:s}",
}
data_fmt = {
    "table": "{:48s} {:13.6f} {:15.6f} {:9.1f}%",
    "md": "| {:48s} | {:13.6f} | {:15.6f} | {:9.1f}% |",
    "csv": "{:s}, {:.6f}, {:.6f}, {:.2f}%",
}


def _percent_change(base_time, diff_time):
    """Return the change from ``base_time`` to ``diff_time`` in percent.

    A zero base time makes the ratio undefined, so NaN is returned instead of
    raising ZeroDivisionError. A single bad row no longer aborts the whole
    report. NaN inputs (keys missing from one run) propagate to NaN.
    """
    if base_time == 0:
        return float("nan")
    return (diff_time / base_time - 1.0) * 100.0


if args.format in ["table", "md", "csv"]:
    header_fmt_str = header_fmt[args.format]
    data_fmt_str = data_fmt[args.format]
    print(header_fmt_str.format("name", "base time (s)", "diff time (s)", "% change"))
    if args.format == "md":
        # Markdown alignment row: left-align the name, right-align the numbers.
        print(header_fmt_str.format(":---", "---:", "---:", "---:"))
    for r in results:
        print(
            data_fmt_str.format(
                r.name,
                r.base_time,
                r.diff_time,
                _percent_change(r.base_time, r.diff_time),
            )
        )
elif args.format == "json":
    # json.dumps writes each Result namedtuple as a plain list
    # [name, base_time, diff_time]; NaN times are emitted as the token NaN.
    print(json.dumps(results))
else:
    raise ValueError("Unknown output format: " + args.format)
73