"""Compare a benchmark run's accuracy CSV against an expected-baseline CSV.

Exits with status 1 (after printing update instructions) when any non-flaky
model's accuracy status differs from the baseline, whether it regressed or
improved.
"""

import argparse
import os
import sys
import textwrap

import pandas as pd


# Hack to have something similar to DISABLED_TEST. These models are flaky:
# a status change for them is printed but never counted as a regression or
# an improvement.
flaky_models = {
    "yolov3",
    "gluon_inception_v3",
    "detectron2_maskrcnn_r_101_c4",
    "XGLMForCausalLM",  # discovered in https://github.com/pytorch/pytorch/pull/128148
}


def get_field(csv, model_name: str, field: str):
    """Return ``field`` of the single row of ``csv`` whose ``name`` is ``model_name``.

    Returns ``None`` when the lookup cannot produce exactly one value: the
    ``name`` or ``field`` column is missing (``KeyError``), or the number of
    matching rows is not exactly one (``Series.item`` raises ``ValueError``).
    """
    try:
        return csv.loc[csv["name"] == model_name][field].item()
    except (KeyError, ValueError):
        return None


def check_accuracy(actual_csv, expected_csv, expected_filename):
    """Compare per-model ``accuracy`` values between two result tables.

    Prints one status line per model listed in ``actual_csv``. Models absent
    from ``expected_csv`` are compared against an expected value of ``None``.

    Args:
        actual_csv: DataFrame with ``name`` and ``accuracy`` columns from this run.
        expected_csv: DataFrame with the same columns holding the baseline.
        expected_filename: Path of the baseline file, quoted in the message.

    Returns:
        ``(changed, msg)``. ``changed`` is the list of regressed models if
        there are any, otherwise the list of improved models; it is empty
        (falsy) when nothing changed. ``msg`` is a human-readable report with
        baseline-update instructions, or ``""`` when nothing changed.
    """
    failed = []
    improved = []

    for model in actual_csv["name"]:
        accuracy = get_field(actual_csv, model, "accuracy")
        expected_accuracy = get_field(expected_csv, model, "accuracy")

        if accuracy == expected_accuracy:
            status = "PASS" if expected_accuracy == "pass" else "XFAIL"
            print(f"{model:34} {status}")
            continue
        elif model in flaky_models:
            if accuracy == "pass":
                # model passed but marked xfailed
                status = "PASS_BUT_FLAKY:"
            else:
                # model did not pass and its status differs from the baseline
                status = "FAIL_BUT_FLAKY:"
        elif accuracy != "pass":
            # Any change to a non-pass status (including between two
            # different failure kinds) counts as a regression.
            status = "FAIL:"
            failed.append(model)
        else:
            status = "IMPROVED:"
            improved.append(model)
        print(
            f"{model:34} {status:9} accuracy={accuracy}, expected={expected_accuracy}"
        )

    msg = ""
    if failed or improved:
        if failed:
            msg += textwrap.dedent(
                f"""
            Error: {len(failed)} models have accuracy status regressed:
                {' '.join(failed)}

            """
            )
        if improved:
            msg += textwrap.dedent(
                f"""
            Improvement: {len(improved)} models have accuracy status improved:
                {' '.join(improved)}

            """
            )
        sha = os.getenv("SHA1", "{your CI commit sha}")
        msg += textwrap.dedent(
            f"""
        If this change is expected, you can update `{expected_filename}` to reflect the new baseline.
        from pytorch/pytorch root, run
        `python benchmarks/dynamo/ci_expected_accuracy/update_expected.py {sha}`
        and then `git add` the resulting local changes to expected CSVs to your commit.
        """
        )
    # Improvements also fail the check so the baseline gets updated.
    return failed or improved, msg


def main():
    """Parse ``--actual``/``--expected`` CSV paths and exit 1 on any status change."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--actual", type=str, required=True)
    parser.add_argument("--expected", type=str, required=True)
    args = parser.parse_args()

    actual = pd.read_csv(args.actual)
    expected = pd.read_csv(args.expected)

    changed, msg = check_accuracy(actual, expected, args.expected)
    if changed:
        print(msg)
        sys.exit(1)


if __name__ == "__main__":
    main()