xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/check_graph_breaks.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import argparse
2import os
3import sys
4import textwrap
5
6import pandas as pd
7
8
def get_field(csv, model_name: str, field: str):
    """Return the ``field`` value of the row whose ``name`` equals ``model_name``.

    Args:
        csv: DataFrame containing a ``name`` column.
        model_name: Value to match in the ``name`` column.
        field: Column to read from the matching row.

    Returns:
        The scalar stored in that cell, or ``None`` when it cannot be located:
        the ``name`` or ``field`` column is absent (``KeyError``), or the number
        of rows matching ``model_name`` is not exactly one (``ValueError`` from
        ``Series.item()``). A cell that exists but is empty is returned as
        stored (e.g. NaN), not converted to ``None``.
    """
    try:
        return csv.loc[csv["name"] == model_name][field].item()
    except (KeyError, ValueError):
        # Only "value not found / ambiguous" is treated as a soft miss; any
        # other exception indicates a real bug and is allowed to propagate.
        return None
14
15
def check_graph_breaks(actual_csv, expected_csv, expected_filename):
    """Compare each model's graph-break count against an expected baseline.

    For every model in ``actual_csv["name"]`` the ``graph_breaks`` value is
    looked up in both tables and one status line is printed:

    * ``PASS``      -- the two values are equal.
    * ``FAIL:``     -- the current run has more graph breaks than the baseline.
    * ``IMPROVED:`` -- the current run has fewer graph breaks than the baseline.
    * ``MISSING:``  -- the values differ but cannot be ordered, because
      ``get_field`` returned ``None`` (e.g. the model is absent from the
      baseline) or one of the values is NaN.

    Args:
        actual_csv: DataFrame for the current run, with ``name`` and
            ``graph_breaks`` columns.
        expected_csv: Baseline DataFrame with the same columns.
        expected_filename: Path of the baseline file; only used in the report
            text.

    Returns:
        ``(models, msg)``. ``models`` is the first non-empty list among the
        failed, improved and missing models (an empty list when every model
        passed); ``msg`` is the report to print, or ``""`` when every model
        passed.
    """
    failed = []
    improved = []
    missing = []

    for model in actual_csv["name"]:
        graph_breaks = get_field(actual_csv, model, "graph_breaks")
        expected_graph_breaks = get_field(expected_csv, model, "graph_breaks")

        if graph_breaks == expected_graph_breaks:
            print(f"{model:34}  PASS")
            continue

        # ``None`` (lookup failed) cannot be ordered against a number -- ``>``
        # would raise TypeError -- and NaN compares False in every direction,
        # which would leave ``status`` unset or stale from the previous model.
        # Report such models explicitly instead of crashing or misreporting.
        if pd.isna(graph_breaks) or pd.isna(expected_graph_breaks):
            status = "MISSING:"
            missing.append(model)
        elif graph_breaks > expected_graph_breaks:
            status = "FAIL:"
            failed.append(model)
        else:
            status = "IMPROVED:"
            improved.append(model)
        print(
            f"{model:34}  {status:9} graph_breaks={graph_breaks}, expected={expected_graph_breaks}"
        )

    msg = ""
    if failed or improved or missing:
        if failed:
            msg += textwrap.dedent(
                f"""
            Error: {len(failed)} models have new dynamo graph breaks:
                {' '.join(failed)}

            """
            )
        if improved:
            msg += textwrap.dedent(
                f"""
            Improvement: {len(improved)} models have fixed dynamo graph breaks:
                {' '.join(improved)}

            """
            )
        if missing:
            msg += textwrap.dedent(
                f"""
            Missing: {len(missing)} models have no comparable graph_breaks baseline:
                {' '.join(missing)}

            """
            )
        sha = os.getenv("SHA1", "{your CI commit sha}")
        msg += textwrap.dedent(
            f"""
        If this change is expected, you can update `{expected_filename}` to reflect the new baseline.
        from pytorch/pytorch root, run
        `python benchmarks/dynamo/ci_expected_accuracy/update_expected.py {sha}`
        and then `git add` the resulting local changes to expected CSVs to your commit.
        """
        )
    # Any non-empty list (failed, improved or missing) means the baseline no
    # longer matches; the caller only checks truthiness.
    return failed or improved or missing, msg
67
68
def main():
    """Parse the ``--actual``/``--expected`` CSV paths and run the graph-break check.

    Prints the report and exits with status 1 when ``check_graph_breaks``
    returns a non-empty model list; otherwise returns normally.
    """
    cli = argparse.ArgumentParser()
    for flag in ("--actual", "--expected"):
        cli.add_argument(flag, type=str, required=True)
    opts = cli.parse_args()

    actual_df = pd.read_csv(opts.actual)
    expected_df = pd.read_csv(opts.expected)

    flagged_models, report = check_graph_breaks(actual_df, expected_df, opts.expected)
    if not flagged_models:
        return
    print(report)
    sys.exit(1)
82
83
if __name__ == "__main__":
    # Run the command-line check only when executed as a script, not on import.
    main()
86