import argparse
import os
import sys
import textwrap

import pandas as pd


def get_field(csv, model_name: str, field: str):
    """Return the value of ``field`` in the row whose ``name`` equals ``model_name``.

    Args:
        csv: DataFrame read from a benchmark CSV; looked up via its ``name`` column.
        model_name: Value to match in the ``name`` column.
        field: Column to read from the matching row.

    Returns:
        The single matching cell as a Python scalar, or ``None`` when a required
        column is absent or the model does not match exactly one row.
    """
    try:
        return csv.loc[csv["name"] == model_name][field].item()
    except (KeyError, ValueError):
        # KeyError: the ``name`` or ``field`` column does not exist.
        # ValueError: ``.item()`` saw zero or several matching rows.
        return None


def _build_message(failed, improved, missing, expected_filename):
    """Assemble the report printed when results differ from the baseline ("" if none)."""
    if not (failed or improved or missing):
        return ""

    msg = ""
    if failed:
        msg += textwrap.dedent(
            f"""
            Error: {len(failed)} models have new dynamo graph breaks:
            {' '.join(failed)}

            """
        )
    if improved:
        msg += textwrap.dedent(
            f"""
            Improvement: {len(improved)} models have fixed dynamo graph breaks:
            {' '.join(improved)}

            """
        )
    if missing:
        msg += textwrap.dedent(
            f"""
            Missing baseline: {len(missing)} models have no expected graph_breaks entry:
            {' '.join(missing)}

            """
        )
    sha = os.getenv("SHA1", "{your CI commit sha}")
    msg += textwrap.dedent(
        f"""
        If this change is expected, you can update `{expected_filename}` to reflect the new baseline.
        from pytorch/pytorch root, run
        `python benchmarks/dynamo/ci_expected_accuracy/update_expected.py {sha}`
        and then `git add` the resulting local changes to expected CSVs to your commit.
        """
    )
    return msg


def check_graph_breaks(actual_csv, expected_csv, expected_filename):
    """Compare per-model ``graph_breaks`` counts against the expected baseline.

    Prints one status line per model listed in ``actual_csv``.

    Args:
        actual_csv: DataFrame with ``name`` and ``graph_breaks`` columns from this run.
        expected_csv: DataFrame with the same columns holding the baseline.
        expected_filename: Path of the baseline file, quoted in the report.

    Returns:
        ``(flagged, msg)``. ``flagged`` is the list of regressed models if any,
        else the improved models if any, else the models lacking a baseline
        (an empty list when every model passed). ``msg`` is the report text,
        or ``""`` when nothing differs.
    """
    failed = []
    improved = []
    missing = []

    for model in actual_csv["name"]:
        graph_breaks = get_field(actual_csv, model, "graph_breaks")
        expected_graph_breaks = get_field(expected_csv, model, "graph_breaks")

        # pd.isna covers both None (lookup failed) and NaN (empty CSV cell);
        # ordering comparisons against either would raise or always be False.
        if pd.isna(expected_graph_breaks):
            status = "MISSING:"
            missing.append(model)
        elif pd.isna(graph_breaks):
            status = "FAIL:"
            failed.append(model)
        elif graph_breaks == expected_graph_breaks:
            print(f"{model:34} PASS")
            continue
        elif graph_breaks > expected_graph_breaks:
            status = "FAIL:"
            failed.append(model)
        else:
            status = "IMPROVED:"
            improved.append(model)
        print(
            f"{model:34} {status:9} graph_breaks={graph_breaks}, expected={expected_graph_breaks}"
        )

    msg = _build_message(failed, improved, missing, expected_filename)
    return failed or improved or missing, msg


def main():
    """Parse CLI args, compare the two CSVs, and exit with status 1 on any difference."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--actual", type=str, required=True)
    parser.add_argument("--expected", type=str, required=True)
    args = parser.parse_args()

    actual = pd.read_csv(args.actual)
    expected = pd.read_csv(args.expected)

    flagged, msg = check_graph_breaks(actual, expected, args.expected)
    if flagged:
        print(msg)
        sys.exit(1)


if __name__ == "__main__":
    main()