xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/join_results.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""
2A tool to merge multiple csv files (generated by torchbench.py/etc) into a single csv file.
3Performs an outer join based on the benchmark name, filling in any missing data with zeros.
4"""
5import argparse
6import functools
7import operator
8from pathlib import Path
9
10import pandas as pd
11
12
def longest_common_prefix(strs):
    """Return the prefix shared by every string, up to the first mismatch.

    Args:
        strs: Iterable of strings. It is materialised into a list first, so
            one-shot iterables (generators) are accepted as well.

    Returns:
        ``shortest[:i]`` where ``i`` is the first index at which some string
        differs from the shortest one.  Returns ``""`` when ``strs`` is empty.

        NOTE: when no mismatch is found at all (the shortest string is a
        prefix of every input, e.g. a single string or identical strings) the
        result is also ``""`` rather than the shortest string.  This keeps the
        remainder ``s[len(prefix):]`` non-empty for each non-empty input
        string.
    """
    strs = list(strs)
    if not strs:
        # min() would raise ValueError on an empty sequence.
        return ""
    shortest_str = min(strs, key=len)
    for i, char in enumerate(shortest_str):
        if any(other[i] != char for other in strs):
            return shortest_str[:i]
    # No mismatch within the shortest string: see NOTE in the docstring.
    return ""
20
21
def main():
    """Merge one column from several benchmark CSV files into a single CSV.

    Command line:
        inputs          one or more CSV files, each with a ``name`` column
                        and the column selected by ``--field``.
        --field / -f    column to extract from every input (default
                        ``speedup``).
        --output / -o   path of the merged CSV.  Defaults to the common
                        prefix of the input file stems (surrounding ``_``
                        stripped) plus ``.csv``, or ``output.csv`` when that
                        prefix is empty.

    Each input contributes one column, labelled with the part of its file
    stem that follows the common prefix.  Frames are outer-joined on
    ``name``, missing values are replaced by 0, and rows whose value is 0 in
    every column are dropped.

    Exits through ``parser.error`` (usage message, status 2) when no input is
    given or an input lacks a required column.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--field", "-f", default="speedup", type=str)
    parser.add_argument("--output", "-o", type=str)
    # Require at least one input so argparse reports the problem instead of
    # the script failing later with a traceback.
    parser.add_argument("inputs", nargs="+")
    args = parser.parse_args()

    stems = [Path(inp).stem for inp in args.inputs]
    prefix = longest_common_prefix(stems)
    frames = []
    fields = []
    for inp, stem in zip(args.inputs, stems):
        # Column label: the part of the file stem that distinguishes this
        # input from the others.
        field = stem[len(prefix) :]
        frame = pd.read_csv(inp)
        # DataFrame.filter silently ignores absent labels, so check up front
        # to give a clear message rather than a KeyError further down.
        missing = [col for col in ("name", args.field) if col not in frame.columns]
        if missing:
            parser.error(f"{inp}: missing column(s): {', '.join(missing)}")
        fields.append(field)
        frames.append(
            frame.filter(["name", args.field]).rename(columns={args.field: field})
        )

    df = frames[0]
    for other in frames[1:]:
        df = df.merge(other, how="outer", on="name")
    df = df.fillna(0)

    # drop rows where all backends failed
    df = df[functools.reduce(operator.or_, [df[f] != 0 for f in fields])]

    prefix = prefix.strip("_") or "output"
    output = args.output or f"{prefix}.csv"
    print(f"Writing {output}")
    df.to_csv(output, index=False)
54
# Run the merge only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
57