xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/summarize_perf.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport logging
2*da0073e9SAndroid Build Coastguard Workerimport os
3*da0073e9SAndroid Build Coastguard Workerimport re
4*da0073e9SAndroid Build Coastguard Workerfrom collections import defaultdict
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport click
7*da0073e9SAndroid Build Coastguard Workerimport pandas as pd
8*da0073e9SAndroid Build Coastguard Workerfrom tabulate import tabulate
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker
def gmean(s):
    """
    Return the geometric mean of the values in ``s``.

    Args:
        s: A pandas Series (anything exposing ``product()`` and ``len()``).

    Returns:
        ``s.product() ** (1 / len(s))``, or NaN when ``s`` is empty.
    """
    # An empty series would otherwise raise ZeroDivisionError in 1 / len(s).
    # NaN mirrors what Series.mean() reports for empty input.
    if len(s) == 0:
        return float("nan")
    return s.product() ** (1 / len(s))
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Worker
def find_csv_files(path, perf_compare):
    """
    Recursively collect the benchmark CSV files found under ``path``.

    Args:
        path: Root directory to walk.
        perf_compare: When true, select files named exactly
            ``training_<suite>.csv`` where ``<suite>`` is one of torchbench,
            huggingface or timm_models. Otherwise select every file whose
            name ends with ``_performance.csv``.

    Returns:
        A list of paths (directory joined with file name) in ``os.walk``
        order. Empty when nothing matches.
    """
    perf_compare_regex = r"training_(torchbench|huggingface|timm_models)\.csv"

    def is_csv(f):
        if perf_compare:
            # fullmatch rather than match: the pattern has no end anchor, so
            # match() would also accept names such as
            # "training_torchbench.csv.bak".
            return re.fullmatch(perf_compare_regex, f) is not None
        return f.endswith("_performance.csv")

    return [
        os.path.join(root, file)
        for root, _dirs, files in os.walk(path)
        for file in files
        if is_csv(file)
    ]
34*da0073e9SAndroid Build Coastguard Worker
35*da0073e9SAndroid Build Coastguard Worker
@click.command()
@click.argument("directory", default="artifacts")
@click.option("--amp", is_flag=True)
@click.option("--float32", is_flag=True)
@click.option(
    "--perf-compare",
    is_flag=True,
    help="Set if the CSVs were generated by running manually the action rather than picking them from the nightly job",
)
def main(directory, amp, float32, perf_compare):
    """
    Given a directory containing multiple CSVs from --performance benchmark
    runs, aggregates and generates summary statistics similar to the web UI at
    https://torchci-git-fork-huydhn-add-compilers-bench-74abf8-fbopensource.vercel.app/benchmark/compilers

    This is most useful if you've downloaded CSVs from CI and need to quickly
    look at aggregate stats.  The CSVs are expected to follow exactly the same
    naming convention that is used in CI.

    You may also be interested in
    https://docs.google.com/document/d/1DQQxIgmKa3eF0HByDTLlcJdvefC4GwtsklJUgLs09fQ/edit#
    which explains how to interpret the raw csv data.
    """
    # NOTE: click uses the docstring above as the --help text, so it is
    # runtime output and is intentionally left untouched.

    # Passing exactly one of --amp / --float32 restricts the report to that
    # dtype; passing neither or both keeps both.
    dtypes = ["amp", "float32"]
    if amp and not float32:
        dtypes = ["amp"]
    if float32 and not amp:
        dtypes = ["float32"]

    # Group the parsed frames by bare file name so that same-named CSVs found
    # in different subdirectories are aggregated together.
    dfs = defaultdict(list)
    for f in find_csv_files(directory, perf_compare):
        try:
            dfs[os.path.basename(f)].append(pd.read_csv(f))
        except Exception:
            # Name the offending file before propagating the parse error.
            logging.warning("failed parsing %s", f)
            raise

    # "<device> <dtype> <mode>" -> statistic -> benchmark -> compiler -> value
    results = defaultdict(
        lambda: defaultdict(
            lambda: defaultdict(dict)
        )
    )

    # Loop-invariant file name patterns.
    perf_compare_regex = r"training_(torchbench|huggingface|timm_models)\.csv"
    nightly_regex = (
        "(.+)_"
        "(torchbench|huggingface|timm_models)_"
        "(float32|amp)_"
        "(inference|training)_"
        "(cpu|cuda)_"
        r"performance\.csv"
    )

    for k, frames in sorted(dfs.items()):
        if perf_compare:
            m = re.match(perf_compare_regex, k)
            # find_csv_files selected the file with this same pattern, so a
            # miss here would be an internal inconsistency.
            assert m is not None, k
            compiler = "inductor"
            benchmark = m.group(1)
            dtype = "float32"
            mode = "training"
            device = "cuda"
        else:
            m = re.match(nightly_regex, k)
            if m is None:
                # find_csv_files accepts any "*_performance.csv", which is
                # broader than this pattern (e.g. other dtypes). Skip such
                # files instead of failing with AttributeError on m.group().
                logging.warning(
                    "skipping %s: file name does not follow the CI naming convention",
                    k,
                )
                continue
            compiler, benchmark, dtype, mode, device = m.groups()

        # Filter before doing any work on frames that will not be reported.
        if dtype not in dtypes:
            continue

        df = pd.concat(frames)
        # Drop rows with any missing value and rows with a zero speedup.
        df = df.dropna().query("speedup != 0")

        statistics = {
            "speedup": gmean(df["speedup"]),
            "comptime": df["compilation_latency"].mean(),
            "memory": gmean(df["compression_ratio"]),
        }

        for statistic, value in statistics.items():
            results[f"{device} {dtype} {mode}"][statistic][benchmark][compiler] = value

    descriptions = {
        "speedup": "Geometric mean speedup",
        "comptime": "Mean compilation time",
        "memory": "Peak memory compression ratio",
    }

    for dtype_mode, r in results.items():
        print(f"# {dtype_mode} performance results")
        for statistic, data in r.items():
            print(f"## {descriptions[statistic]}")

            # One row per compiler, one column per benchmark. Rows are the
            # union of compilers over all benchmarks (first-seen order), so a
            # compiler missing from one benchmark neither raises KeyError nor
            # hides compilers that only appear in later benchmarks.
            row_names = list(
                dict.fromkeys(
                    compiler_name
                    for per_compiler in data.values()
                    for compiler_name in per_compiler
                )
            )

            table = []
            for row_name in row_names:
                row = [row_name]
                for col_name in data:
                    cell = data[col_name].get(row_name)
                    # None is rendered by tabulate as an empty cell.
                    row.append(None if cell is None else round(cell, 2))
                table.append(row)

            headers = list(data)
            print(tabulate(table, headers=headers))
            print()
143*da0073e9SAndroid Build Coastguard Worker
# Run the CLI only when executed as a script; a plain import of this module
# must not parse sys.argv or exit the importing process.
if __name__ == "__main__":
    main()
145