xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/check_accuracy.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import argparse
2import os
3import sys
4import textwrap
5
6import pandas as pd
7
8
# Poor man's DISABLED_TEST: models whose accuracy status is known to be flaky.
# A mismatch against the expected status is still printed for these models,
# but it is never counted as a regression or an improvement.
flaky_models = {
    "XGLMForCausalLM",  # discovered in https://github.com/pytorch/pytorch/pull/128148
    "detectron2_maskrcnn_r_101_c4",
    "gluon_inception_v3",
    "yolov3",
}
17
18
def get_field(csv, model_name: str, field: str):
    """Look up ``field`` for the row of ``csv`` whose "name" is ``model_name``.

    The lookup is best-effort: any exception raised while selecting the row
    or extracting the value (for instance when ``.item()`` cannot reduce the
    selection to a single value) is swallowed and ``None`` is returned.

    Args:
        csv: Table indexable by column label, with a "name" column.
        model_name: Value to match in the "name" column.
        field: Label of the column to read.

    Returns:
        The value of ``field`` for the matching row, or ``None`` if the
        lookup fails.
    """
    try:
        matching_rows = csv.loc[csv["name"] == model_name]
        column = matching_rows[field]
        return column.item()
    except Exception:
        # Deliberate catch-all: callers treat "could not look up" as None.
        return None
24
25
def check_accuracy(actual_csv, expected_csv, expected_filename):
    """Compare per-model accuracy statuses against the expected baseline.

    Prints one status line for every model named in ``actual_csv`` and
    collects the models whose "accuracy" value differs from the one recorded
    in ``expected_csv``. Models listed in ``flaky_models`` are reported but
    never collected.

    Args:
        actual_csv: Results of the current run; read through its "name" and
            "accuracy" columns.
        expected_csv: Baseline results, read through the same columns.
        expected_filename: Name of the baseline file, quoted in the update
            instructions of the returned message.

    Returns:
        A ``(changed, msg)`` tuple. ``changed`` is the list of regressed
        models if there are any, otherwise the list of improved models (empty
        when nothing changed). ``msg`` is a human-readable summary with
        instructions for updating the baseline, or ``""`` when nothing
        changed.
    """
    failed = []
    improved = []

    for model in actual_csv["name"]:
        observed = get_field(actual_csv, model, "accuracy")
        baseline = get_field(expected_csv, model, "accuracy")

        if observed == baseline:
            # Matches the baseline: an expected "pass" is a PASS, any other
            # expected status is an expected failure.
            label = "PASS" if baseline == "pass" else "XFAIL"
            print(f"{model:34}  {label}")
            continue

        if model in flaky_models:
            # Flaky models are reported but not counted. A "pass" here means
            # the model passed although it was marked as failing; anything
            # else means it failed although it was marked as passing.
            label = "PASS_BUT_FLAKY:" if observed == "pass" else "FAIL_BUT_FLAKY:"
        elif observed == "pass":
            label = "IMPROVED:"
            improved.append(model)
        else:
            label = "FAIL:"
            failed.append(model)

        print(f"{model:34}  {label:9} accuracy={observed}, expected={baseline}")

    # Assemble the summary from independent sections, then join them.
    sections = []
    if failed:
        regressed_names = " ".join(failed)
        sections.append(
            textwrap.dedent(
                f"""
            Error: {len(failed)} models have accuracy status regressed:
                {regressed_names}

            """
            )
        )
    if improved:
        improved_names = " ".join(improved)
        sections.append(
            textwrap.dedent(
                f"""
            Improvement: {len(improved)} models have accuracy status improved:
                {improved_names}

            """
            )
        )
    if failed or improved:
        sha = os.getenv("SHA1", "{your CI commit sha}")
        sections.append(
            textwrap.dedent(
                f"""
        If this change is expected, you can update `{expected_filename}` to reflect the new baseline.
        from pytorch/pytorch root, run
        `python benchmarks/dynamo/ci_expected_accuracy/update_expected.py {sha}`
        and then `git add` the resulting local changes to expected CSVs to your commit.
        """
            )
        )

    msg = "".join(sections)
    return failed or improved, msg
83
84
def main():
    """Command-line entry point.

    Reads the CSV files given by the required ``--actual`` and ``--expected``
    options, compares them with ``check_accuracy`` and, if any model's
    accuracy status regressed or improved, prints the summary message and
    exits with status 1.
    """
    cli = argparse.ArgumentParser()
    for flag in ("--actual", "--expected"):
        cli.add_argument(flag, type=str, required=True)
    options = cli.parse_args()

    actual_results = pd.read_csv(options.actual)
    expected_results = pd.read_csv(options.expected)

    changed, report = check_accuracy(
        actual_results, expected_results, options.expected
    )
    if not changed:
        return

    # Any difference from the baseline (regression or improvement) fails the
    # run so that the expected file gets updated.
    print(report)
    sys.exit(1)


if __name__ == "__main__":
    main()
102