1*da0073e9SAndroid Build Coastguard Workerimport logging 2*da0073e9SAndroid Build Coastguard Workerimport os 3*da0073e9SAndroid Build Coastguard Workerimport re 4*da0073e9SAndroid Build Coastguard Workerfrom collections import defaultdict 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Workerimport click 7*da0073e9SAndroid Build Coastguard Workerimport pandas as pd 8*da0073e9SAndroid Build Coastguard Workerfrom tabulate import tabulate 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Workerdef gmean(s): 12*da0073e9SAndroid Build Coastguard Worker return s.product() ** (1 / len(s)) 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Workerdef find_csv_files(path, perf_compare): 16*da0073e9SAndroid Build Coastguard Worker """ 17*da0073e9SAndroid Build Coastguard Worker Recursively search for all CSV files in directory and subdirectories whose 18*da0073e9SAndroid Build Coastguard Worker name contains a target string. 19*da0073e9SAndroid Build Coastguard Worker """ 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Worker def is_csv(f): 22*da0073e9SAndroid Build Coastguard Worker if perf_compare: 23*da0073e9SAndroid Build Coastguard Worker regex = r"training_(torchbench|huggingface|timm_models)\.csv" 24*da0073e9SAndroid Build Coastguard Worker return re.match(regex, f) is not None 25*da0073e9SAndroid Build Coastguard Worker else: 26*da0073e9SAndroid Build Coastguard Worker return f.endswith("_performance.csv") 27*da0073e9SAndroid Build Coastguard Worker 28*da0073e9SAndroid Build Coastguard Worker csv_files = [] 29*da0073e9SAndroid Build Coastguard Worker for root, dirs, files in os.walk(path): 30*da0073e9SAndroid Build Coastguard Worker for file in files: 31*da0073e9SAndroid Build Coastguard Worker if is_csv(file): 32*da0073e9SAndroid Build Coastguard Worker csv_files.append(os.path.join(root, file)) 33*da0073e9SAndroid Build Coastguard Worker return csv_files 34*da0073e9SAndroid Build Coastguard Worker 35*da0073e9SAndroid Build Coastguard Worker 36*da0073e9SAndroid Build Coastguard Worker@click.command() 37*da0073e9SAndroid Build Coastguard Worker@click.argument("directory", default="artifacts") 38*da0073e9SAndroid Build Coastguard Worker@click.option("--amp", is_flag=True) 39*da0073e9SAndroid Build Coastguard Worker@click.option("--float32", is_flag=True) 40*da0073e9SAndroid Build Coastguard Worker@click.option( 41*da0073e9SAndroid Build Coastguard Worker "--perf-compare", 42*da0073e9SAndroid Build Coastguard Worker is_flag=True, 43*da0073e9SAndroid Build Coastguard Worker help="Set if the CSVs were generated by running manually the action rather than picking them from the nightly job", 44*da0073e9SAndroid Build Coastguard Worker) 45*da0073e9SAndroid Build Coastguard Workerdef main(directory, amp, float32, perf_compare): 46*da0073e9SAndroid Build Coastguard Worker """ 47*da0073e9SAndroid Build Coastguard Worker Given a directory containing multiple CSVs from --performance benchmark 48*da0073e9SAndroid Build Coastguard Worker runs, aggregates and generates summary statistics similar to the web UI at 49*da0073e9SAndroid Build Coastguard Worker https://torchci-git-fork-huydhn-add-compilers-bench-74abf8-fbopensource.vercel.app/benchmark/compilers 50*da0073e9SAndroid Build Coastguard Worker 51*da0073e9SAndroid Build Coastguard Worker This is most useful if you've downloaded CSVs from CI and need to quickly 52*da0073e9SAndroid Build Coastguard Worker look at aggregate stats. The CSVs are expected to follow exactly the same 53*da0073e9SAndroid Build Coastguard Worker naming convention that is used in CI. 54*da0073e9SAndroid Build Coastguard Worker 55*da0073e9SAndroid Build Coastguard Worker You may also be interested in 56*da0073e9SAndroid Build Coastguard Worker https://docs.google.com/document/d/1DQQxIgmKa3eF0HByDTLlcJdvefC4GwtsklJUgLs09fQ/edit# 57*da0073e9SAndroid Build Coastguard Worker which explains how to interpret the raw csv data. 58*da0073e9SAndroid Build Coastguard Worker """ 59*da0073e9SAndroid Build Coastguard Worker dtypes = ["amp", "float32"] 60*da0073e9SAndroid Build Coastguard Worker if amp and not float32: 61*da0073e9SAndroid Build Coastguard Worker dtypes = ["amp"] 62*da0073e9SAndroid Build Coastguard Worker if float32 and not amp: 63*da0073e9SAndroid Build Coastguard Worker dtypes = ["float32"] 64*da0073e9SAndroid Build Coastguard Worker 65*da0073e9SAndroid Build Coastguard Worker dfs = defaultdict(list) 66*da0073e9SAndroid Build Coastguard Worker for f in find_csv_files(directory, perf_compare): 67*da0073e9SAndroid Build Coastguard Worker try: 68*da0073e9SAndroid Build Coastguard Worker dfs[os.path.basename(f)].append(pd.read_csv(f)) 69*da0073e9SAndroid Build Coastguard Worker except Exception: 70*da0073e9SAndroid Build Coastguard Worker logging.warning("failed parsing %s", f) 71*da0073e9SAndroid Build Coastguard Worker raise 72*da0073e9SAndroid Build Coastguard Worker 73*da0073e9SAndroid Build Coastguard Worker # dtype -> statistic -> benchmark -> compiler -> value 74*da0073e9SAndroid Build Coastguard Worker results = defaultdict( # dtype 75*da0073e9SAndroid Build Coastguard Worker lambda: defaultdict( # statistic 76*da0073e9SAndroid Build Coastguard Worker lambda: defaultdict(dict) # benchmark # compiler -> value 77*da0073e9SAndroid Build Coastguard Worker ) 78*da0073e9SAndroid Build Coastguard Worker ) 79*da0073e9SAndroid Build Coastguard Worker 80*da0073e9SAndroid Build Coastguard Worker for k, v in sorted(dfs.items()): 81*da0073e9SAndroid Build Coastguard Worker if perf_compare: 82*da0073e9SAndroid Build Coastguard Worker regex = r"training_(torchbench|huggingface|timm_models)\.csv" 83*da0073e9SAndroid Build Coastguard Worker m = re.match(regex, k) 84*da0073e9SAndroid Build Coastguard Worker assert m is not None, k 85*da0073e9SAndroid Build Coastguard Worker compiler = "inductor" 86*da0073e9SAndroid Build Coastguard Worker benchmark = m.group(1) 87*da0073e9SAndroid Build Coastguard Worker dtype = "float32" 88*da0073e9SAndroid Build Coastguard Worker mode = "training" 89*da0073e9SAndroid Build Coastguard Worker device = "cuda" 90*da0073e9SAndroid Build Coastguard Worker else: 91*da0073e9SAndroid Build Coastguard Worker regex = ( 92*da0073e9SAndroid Build Coastguard Worker "(.+)_" 93*da0073e9SAndroid Build Coastguard Worker "(torchbench|huggingface|timm_models)_" 94*da0073e9SAndroid Build Coastguard Worker "(float32|amp)_" 95*da0073e9SAndroid Build Coastguard Worker "(inference|training)_" 96*da0073e9SAndroid Build Coastguard Worker "(cpu|cuda)_" 97*da0073e9SAndroid Build Coastguard Worker r"performance\.csv" 98*da0073e9SAndroid Build Coastguard Worker ) 99*da0073e9SAndroid Build Coastguard Worker m = re.match(regex, k) 100*da0073e9SAndroid Build Coastguard Worker compiler = m.group(1) 101*da0073e9SAndroid Build Coastguard Worker benchmark = m.group(2) 102*da0073e9SAndroid Build Coastguard Worker dtype = m.group(3) 103*da0073e9SAndroid Build Coastguard Worker mode = m.group(4) 104*da0073e9SAndroid Build Coastguard Worker device = m.group(5) 105*da0073e9SAndroid Build Coastguard Worker 106*da0073e9SAndroid Build Coastguard Worker df = pd.concat(v) 107*da0073e9SAndroid Build Coastguard Worker df = df.dropna().query("speedup != 0") 108*da0073e9SAndroid Build Coastguard Worker 109*da0073e9SAndroid Build Coastguard Worker statistics = { 110*da0073e9SAndroid Build Coastguard Worker "speedup": gmean(df["speedup"]), 111*da0073e9SAndroid Build Coastguard Worker "comptime": df["compilation_latency"].mean(), 112*da0073e9SAndroid Build Coastguard Worker "memory": gmean(df["compression_ratio"]), 113*da0073e9SAndroid Build Coastguard Worker } 114*da0073e9SAndroid Build Coastguard Worker 115*da0073e9SAndroid Build Coastguard Worker if dtype not in dtypes: 116*da0073e9SAndroid Build Coastguard Worker continue 117*da0073e9SAndroid Build Coastguard Worker 118*da0073e9SAndroid Build Coastguard Worker for statistic, v in statistics.items(): 119*da0073e9SAndroid Build Coastguard Worker results[f"{device} {dtype} {mode}"][statistic][benchmark][compiler] = v 120*da0073e9SAndroid Build Coastguard Worker 121*da0073e9SAndroid Build Coastguard Worker descriptions = { 122*da0073e9SAndroid Build Coastguard Worker "speedup": "Geometric mean speedup", 123*da0073e9SAndroid Build Coastguard Worker "comptime": "Mean compilation time", 124*da0073e9SAndroid Build Coastguard Worker "memory": "Peak memory compression ratio", 125*da0073e9SAndroid Build Coastguard Worker } 126*da0073e9SAndroid Build Coastguard Worker 127*da0073e9SAndroid Build Coastguard Worker for dtype_mode, r in results.items(): 128*da0073e9SAndroid Build Coastguard Worker print(f"# {dtype_mode} performance results") 129*da0073e9SAndroid Build Coastguard Worker for statistic, data in r.items(): 130*da0073e9SAndroid Build Coastguard Worker print(f"## {descriptions[statistic]}") 131*da0073e9SAndroid Build Coastguard Worker 132*da0073e9SAndroid Build Coastguard Worker table = [] 133*da0073e9SAndroid Build Coastguard Worker for row_name in data[next(iter(data.keys()))]: 134*da0073e9SAndroid Build Coastguard Worker row = [row_name] 135*da0073e9SAndroid Build Coastguard Worker for col_name in data: 136*da0073e9SAndroid Build Coastguard Worker row.append(round(data[col_name][row_name], 2)) 137*da0073e9SAndroid Build Coastguard Worker table.append(row) 138*da0073e9SAndroid Build Coastguard Worker 139*da0073e9SAndroid Build Coastguard Worker headers = list(data.keys()) 140*da0073e9SAndroid Build Coastguard Worker print(tabulate(table, headers=headers)) 141*da0073e9SAndroid Build Coastguard Worker print() 142*da0073e9SAndroid Build Coastguard Worker 143*da0073e9SAndroid Build Coastguard Worker 144*da0073e9SAndroid Build Coastguard Workermain() 145