"""Extract TorchScript IR from a JIT log file and optionally benchmark it.

Usage:
1. Run your script with graph-fuser logging enabled and redirect its output
   into a log file:
     PYTORCH_JIT_LOG_LEVEL=">>graph_fuser" python3 my_test.py &> log.txt
2. Run log_extract on that log:
     log_extract.py log.txt --nvfuser --nnc-dynamic --nnc-static

You can also print the list of extracted IR graphs:
  log_extract.py log.txt --output

Passing --graphs 0 2 will only run (and output) graphs 0 and 2.
"""

import argparse
import functools
import traceback
from typing import Callable, List, Optional, Tuple

from torch.utils.jit.log_extract import (
    extract_ir,
    load_graph_and_inputs,
    run_baseline_no_fusion,
    run_nnc,
    run_nvfuser,
)


def test_runners(
    graphs: List[str],
    runners: List[Tuple[str, Callable]],
    graph_set: Optional[List[int]],
):
    """Benchmark each selected graph with every runner and print the timings.

    Args:
        graphs: IR strings, as returned by ``extract_ir``.
        runners: ``(name, fn)`` pairs; ``fn(ir, inputs)`` returns a time that
            is printed in ms. Runners execute in list order.
        graph_set: indices of the graphs to run; ``None`` or empty runs all.

    Every runner after the first successful one is also reported with its
    percentage improvement over the previous successful runner. A graph that
    fails to load, or a runner that raises ``RuntimeError``, is reported with
    its traceback and skipped; the remaining graphs/runners still run.
    """
    selected = set(graph_set) if graph_set else None
    for i, ir in enumerate(graphs):
        # Filter before loading: parsing the IR and building inputs is wasted
        # work (and may fail) for graphs the caller did not ask for.
        if selected is not None and i not in selected:
            continue

        try:
            _, inputs = load_graph_and_inputs(ir)
        except RuntimeError:
            print(f" Graph {i} failed to load :", traceback.format_exc())
            continue

        print(f"Running Graph {i}")
        prev_result: Optional[float] = None
        prev_runner_name = None
        for runner_name, runner_fn in runners:
            try:
                result = runner_fn(ir, inputs)
            except RuntimeError:
                print(f" Graph {i} failed for {runner_name} :", traceback.format_exc())
                continue

            # `is not None` so a previous 0.0 ms timing still counts as a
            # baseline; skip the ratio when `result` is 0 to avoid dividing
            # by zero.
            if prev_result is not None and result:
                # Positive means this runner was faster than the previous one.
                improvement = (prev_result / result - 1) * 100
                print(
                    f"{runner_name} : {result:.6f} ms improvement over {prev_runner_name}: improvement: {improvement:.2f}%"
                )
            else:
                print(f"{runner_name} : {result:.6f} ms")
            prev_result = result
            prev_runner_name = runner_name


# The name starts with ``test_``, so pytest would otherwise collect this helper
# as a test and fail looking for fixtures named after its parameters.
test_runners.__test__ = False


def _add_toggle(
    parser: argparse.ArgumentParser, name: str, help_on: str, help_off: str
) -> None:
    """Register a ``--<name>`` / ``--no-<name>`` flag pair that defaults to False."""
    dest = name.replace("-", "_")
    parser.add_argument(f"--{name}", dest=dest, action="store_true", help=help_on)
    parser.add_argument(
        f"--no-{name}", dest=dest, action="store_false", help=help_off
    )
    parser.set_defaults(**{dest: False})


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Extracts torchscript IR from log files and, optionally, benchmarks it or outputs the IR"
    )
    parser.add_argument("filename", help="Filename of log file")
    _add_toggle(parser, "nvfuser", "benchmark nvfuser", "DON'T benchmark nvfuser")
    _add_toggle(
        parser, "nnc-static", "benchmark nnc static", "DON'T benchmark nnc static"
    )
    _add_toggle(
        parser,
        "nnc-dynamic",
        "nnc with dynamic shapes",
        "DON'T benchmark nnc with dynamic shapes",
    )
    _add_toggle(parser, "baseline", "benchmark baseline", "DON'T benchmark baseline")
    _add_toggle(parser, "output", "Output graph IR", "DON'T output graph IR")
    parser.add_argument(
        "--graphs", nargs="+", type=int, help="Run only specified graph indices"
    )
    return parser


def run():
    """Command-line entry point: extract IR from a log, then benchmark and/or print it."""
    args = _build_arg_parser().parse_args()
    graphs = extract_ir(args.filename)
    graph_set = args.graphs or None

    # Runners execute in this order; each one is compared against the
    # previous successful runner.
    options: List[Tuple[str, Callable]] = []
    if args.baseline:
        options.append(("Baseline no fusion", run_baseline_no_fusion))
    if args.nnc_dynamic:
        options.append(("NNC Dynamic", functools.partial(run_nnc, dynamic=True)))
    if args.nnc_static:
        options.append(("NNC Static", functools.partial(run_nnc, dynamic=False)))
    if args.nvfuser:
        options.append(("NVFuser", run_nvfuser))

    # With no runner selected (e.g. `--output` only) there is nothing to
    # benchmark, so skip loading the graphs entirely.
    if options:
        test_runners(graphs, options, graph_set)

    if args.output:
        selected = set(graph_set) if graph_set else None
        quoted = [
            '"""' + ir + '"""'
            for i, ir in enumerate(graphs)
            if selected is None or i in selected
        ]
        print("[" + ", ".join(quoted) + "]")


if __name__ == "__main__":
    run()