xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/join_results.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
"""
A tool to merge multiple csv files (generated by torchbench.py/etc) into a single csv file.
Performs an outer join on the benchmark name, filling in any missing data with zeros,
and drops rows whose value is zero in every input.
"""
5*da0073e9SAndroid Build Coastguard Workerimport argparse
6*da0073e9SAndroid Build Coastguard Workerimport functools
7*da0073e9SAndroid Build Coastguard Workerimport operator
8*da0073e9SAndroid Build Coastguard Workerfrom pathlib import Path
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Workerimport pandas as pd
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Worker
def longest_common_prefix(strs):
    """Return the leading characters shared by every string in ``strs``.

    Args:
        strs: An iterable of strings. It is materialised into a list first,
            so one-shot iterators are accepted.

    Returns:
        The common prefix up to (not including) the first position at which
        any string differs from the shortest one. Returns ``""`` when:

        * ``strs`` is empty,
        * the strings differ at position 0, or
        * no mismatch is found at all, i.e. the shortest string is itself a
          prefix of every string (this includes a single-element input and
          all-identical inputs).

    NOTE(review): the last case is historical behaviour and is kept on
    purpose. ``main`` in this file strips the returned prefix from every
    file stem to build column labels; returning the shortest string there
    would leave that stem with an empty label.
    """
    # min() and the scan below both iterate, so take a stable copy.
    strs = list(strs)
    if not strs:
        return ""

    shortest_str = min(strs, key=len)
    for i, char in enumerate(shortest_str):
        # Indexing is safe: no string is shorter than shortest_str.
        if any(other[i] != char for other in strs):
            return shortest_str[:i]

    # Every character of the shortest string matched in all strings.
    return ""
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Worker
def main():
    """Merge one column from several benchmark CSV files into a single CSV.

    Command line:
        inputs         One or more CSV files. Each must contain a ``name``
                       column and the column selected by ``--field``.
        --field, -f    Column to take from every input (default ``speedup``).
        --output, -o   Destination path. Defaults to ``<prefix>.csv`` where
                       ``<prefix>`` is the common prefix of the input file
                       stems with surrounding underscores removed, or
                       ``output.csv`` when that prefix is empty.

    Each input contributes one output column, labelled with its file stem
    minus the common prefix. Inputs are outer-joined on ``name``; missing
    values become 0, and rows that are 0 in every column are dropped.

    Raises:
        SystemExit: if no input is given (reported by argparse), if an input
            lacks a required column, or if two inputs would produce the same
            column label.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--field", "-f", default="speedup", type=str)
    parser.add_argument("--output", "-o", type=str)
    # "+" rather than "*": with zero inputs there is nothing to join, and the
    # code below would otherwise fail with an unhelpful traceback.
    parser.add_argument("inputs", nargs="+")
    args = parser.parse_args()

    stems = [Path(inp).stem for inp in args.inputs]
    prefix = longest_common_prefix(stems)
    # Column label per input: the part of the stem that distinguishes it.
    fields = [stem[len(prefix) :] for stem in stems]

    # Identical labels would be suffixed (_x/_y) by merge and then could not
    # be looked up again by label, so reject them up front.
    duplicates = sorted({f for f in fields if fields.count(f) > 1})
    if duplicates:
        raise SystemExit(
            f"Inputs produce duplicate column names {duplicates}; "
            "file stems must be unique"
        )

    frames = []
    for inp, field in zip(args.inputs, fields):
        frame = pd.read_csv(inp)
        # DataFrame.filter silently ignores absent labels, so check explicitly.
        missing = [col for col in ("name", args.field) if col not in frame.columns]
        if missing:
            raise SystemExit(f"{inp}: missing required column(s) {missing}")
        frames.append(
            frame.filter(["name", args.field]).rename(columns={args.field: field})
        )

    df = frames[0]
    for other in frames[1:]:
        df = df.merge(other, how="outer", on="name")
    df = df.fillna(0)

    # drop rows where all backends failed (every merged column is 0)
    df = df[(df[fields] != 0).any(axis=1)]

    prefix = prefix.strip("_") or "output"
    output = args.output or f"{prefix}.csv"
    print(f"Writing {output}")
    df.to_csv(output, index=False)
53*da0073e9SAndroid Build Coastguard Worker
54*da0073e9SAndroid Build Coastguard Worker
# Run the merge only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
57