"""Benchmark driver: times forward/backward passes of the registered runners.

Run as a module (e.g. ``python -m fastrnns.bench --rnns jit --group rnns``).
Human-readable tables go to stderr; optional JSON output goes to stdout.
"""
import argparse
import copy
import gc
import json
import sys
import time
from collections import namedtuple

import torch
from torch.autograd.profiler import record_function

from .fuser import set_fuser
from .runner import get_nn_runners


# One table row: summary statistics plus the raw per-iteration timings
# (``info_*``) for the forward and backward pass of a single runner.
BenchResult = namedtuple(
    "BenchResult",
    [
        "name",
        "avg_fwd",
        "std_fwd",
        "info_fwd",
        "avg_bwd",
        "std_bwd",
        "info_bwd",
    ],
)


def fit_str(string, colwidth=16):
    """Return ``string`` right-aligned in exactly ``colwidth`` characters.

    Shorter strings are left-padded with spaces; longer ones are truncated.
    """
    return string[:colwidth].rjust(colwidth)


def to_str(item):
    """Format ``item`` for the table: floats with 4 significant digits."""
    if isinstance(item, float):
        return f"{item:.4g}"
    return str(item)


def print_header(colwidth=16, sep=" "):
    """Return the table header line built from the ``BenchResult`` fields.

    Despite the name, nothing is printed; the caller prints the result.
    """
    # ``colwidth`` is forwarded so that a non-default width is honoured.
    return sep.join(fit_str(field, colwidth) for field in BenchResult._fields)


def pretty_print(benchresult, colwidth=16, sep=" "):
    """Return one table row for ``benchresult`` (an iterable of cells).

    Despite the name, nothing is printed; the caller prints the result.
    """
    return sep.join(fit_str(to_str(cell), colwidth) for cell in benchresult)


# shim for torch.cuda.Event when running on cpu
class Event:
    """Wall-clock stand-in exposing the subset of the CUDA event API used here."""

    def __init__(self, enable_timing):
        # ``enable_timing`` exists only for signature parity with
        # torch.cuda.Event; this shim always records a timestamp.
        pass

    def record(self):
        """Store the current ``time.perf_counter()`` reading."""
        self.time = time.perf_counter()

    def elapsed_time(self, end_event):
        """Return the milliseconds elapsed between this event and ``end_event``.

        ``time.perf_counter`` reports seconds, whereas
        ``torch.cuda.Event.elapsed_time`` reports milliseconds (and the JSON
        output labels values as "ms"), so the difference is scaled to keep
        CPU and CUDA results in the same unit.
        """
        assert isinstance(end_event, Event)
        return (end_event.time - self.time) * 1000.0


def trainbench(
    name,
    rnn_creator,
    nloops=100,
    warmup=10,
    seqLength=100,
    numLayers=1,
    inputSize=512,
    hiddenSize=512,
    miniBatch=64,
    device="cuda",
    seed=None,
):
    """Time forward and backward passes of one model.

    Args:
        name: Label stored in the returned ``BenchResult``.
        rnn_creator: Callable invoked with the keyword arguments ``seqLength``,
            ``numLayers``, ``inputSize``, ``hiddenSize``, ``miniBatch``,
            ``device`` and ``seed``. The returned object must expose
            ``forward``, ``inputs``, ``backward_setup``, ``backward`` and
            ``params``.
        nloops: Number of timed iterations.
        warmup: Number of untimed iterations executed first.
        device: ``"cuda"`` selects CUDA events and synchronisation; any other
            value uses the wall-clock ``Event`` shim.
        seqLength, numLayers, inputSize, hiddenSize, miniBatch, seed:
            Forwarded unchanged to ``rnn_creator``.

    Returns:
        A ``BenchResult`` holding the mean, standard deviation and the tensor
        of per-iteration times for both passes.
    """
    # CUDA events for timing
    timer_class = torch.cuda.Event if device == "cuda" else Event

    def train_batch(modeldef):
        """Run one forward/backward iteration and return ``(fwd, bwd)`` times."""
        fwd_start_event = timer_class(enable_timing=True)
        fwd_end_event = timer_class(enable_timing=True)
        bwd_start_event = timer_class(enable_timing=True)
        bwd_end_event = timer_class(enable_timing=True)

        # Collect garbage up front so a GC pause is less likely to land
        # inside the timed region.
        gc.collect()

        fwd_start_event.record()
        with record_function("## forward ##"):
            forward_output = modeldef.forward(*modeldef.inputs)
        fwd_end_event.record()

        # XXX: Use if need to print something
        # print(modeldef.forward.graph_for(*modeldef.inputs))

        if modeldef.backward_setup is not None:
            backward_input = modeldef.backward_setup(forward_output)
        else:
            backward_input = forward_output

        gc.collect()

        bwd_start_event.record()
        if modeldef.backward is not None:
            modeldef.backward(*backward_input)
        bwd_end_event.record()

        # Reset gradients so they do not accumulate across iterations.
        if modeldef.backward is not None:
            with torch.no_grad():
                for param in modeldef.params:
                    assert param.grad is not None
                    param.grad.zero_()

        # Wait for queued GPU work to finish before reading the events.
        if device == "cuda":
            torch.cuda.synchronize()

        fwd_time = fwd_start_event.elapsed_time(fwd_end_event)
        bwd_time = bwd_start_event.elapsed_time(bwd_end_event)
        return fwd_time, bwd_time

    creator_args = {
        "seqLength": seqLength,
        "numLayers": numLayers,
        "inputSize": inputSize,
        "hiddenSize": hiddenSize,
        "miniBatch": miniBatch,
        "device": device,
        "seed": seed,
    }

    modeldef = rnn_creator(**creator_args)

    # Warm-up iterations: executed for their side effects, timings discarded.
    for _ in range(warmup):
        train_batch(modeldef)

    results = [train_batch(modeldef) for _ in range(nloops)]
    fwd_times, bwd_times = zip(*results)

    fwd_times = torch.tensor(fwd_times)
    bwd_times = torch.tensor(bwd_times)
    return BenchResult(
        name=name,
        avg_fwd=fwd_times.mean().item(),
        std_fwd=fwd_times.std().item(),
        info_fwd=fwd_times,
        avg_bwd=bwd_times.mean().item(),
        std_bwd=bwd_times.std().item(),
        info_bwd=bwd_times,
    )


def print_stderr(*args, **kwargs):
    """``print`` that always writes to ``sys.stderr``."""
    kwargs["file"] = sys.stderr
    return print(*args, **kwargs)


def print_json_oss_format(results):
    """Print ``{group: {model: avg}}`` as a single JSON document on stdout.

    ``results`` is the nested mapping produced by ``bench``; only the
    ``"avg"`` entry of each model is emitted.
    """
    oss_results = {}
    for group_name, group_val in results.items():
        # Output for OSS
        oss_results[group_name] = {
            model_name: run_time["avg"] for model_name, run_time in group_val.items()
        }

    print(json.dumps(oss_results))


def print_json_pep_format(results):
    """Print one ``Caffe2Observer`` JSON line per recorded iteration.

    ``results`` is the nested mapping produced by ``bench``; every value in
    each model's ``"info"`` tensor becomes its own output line.
    """
    # print the AI-PEP format json string for each model
    for group_name, group_val in results.items():
        for model_name, run_time in group_val.items():
            # Output for AI-PEP
            for value in run_time["info"].tolist():
                print(
                    "Caffe2Observer "
                    + json.dumps(
                        {
                            "type": "NET",
                            "metric": group_name + "-" + model_name,
                            "unit": "ms",
                            "value": str(value),
                        }
                    )
                )


def bench(rnn_runners, group_name, print_json=False, sep=" ", **params):
    """Benchmark every runner and collect the results by group.

    Args:
        rnn_runners: Iterable of ``(name, creator, context)`` triples;
            ``context()`` is entered around each benchmark run.
        group_name: Key for the forward results; backward results are stored
            under ``group_name + "-backward"``.
        print_json: When truthy, a runner that raises is skipped instead of
            aborting the whole run.
        sep: Column separator for the table written via ``print_stderr``.
        **params: Forwarded to ``trainbench``.

    Returns:
        ``{group_name: {...}, group_name + "-backward": {...}}`` where each
        inner dict maps runner name to ``{"avg", "std", "info"}``.
    """
    print_stderr(print_header(sep=sep))
    results = {}
    for name, creator, context in rnn_runners:
        with context():
            try:
                result = trainbench(name, creator, **params)
            except Exception:
                if not print_json:
                    raise
                # JSON mode is deliberately best-effort: a failing runner is
                # left out of the results rather than aborting the run.
                continue
            # Replace the value of info_fwd and info_bwd to None
            result_with_no_info = result._replace(info_fwd="None", info_bwd="None")
            print_stderr(pretty_print(result_with_no_info, sep=sep))
            results[name] = result

    return {
        group_name: {
            k: {"avg": v.avg_fwd, "std": v.std_fwd, "info": v.info_fwd}
            for k, v in results.items()
        },
        group_name
        + "-backward": {
            k: {"avg": v.avg_bwd, "std": v.std_bwd, "info": v.info_bwd}
            for k, v in results.items()
        },
    }


def bench_group(model_list, bench_name, bench_group, bench_args):
    """Benchmark the runners named in ``model_list`` and return their results.

    Args:
        model_list: Runner names passed to ``get_nn_runners``.
        bench_name: Human-readable name used in the progress message.
        bench_group: Group key under which ``bench`` stores the results.
        bench_args: Keyword arguments forwarded to ``bench``.
    """
    print_stderr(f"Benchmarking {bench_name}s...")
    nn_results = bench(get_nn_runners(*model_list), bench_group, **bench_args)
    print_stderr("")
    return nn_results


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Profile RNNs")

    # groups help control which test group you want to run
    # if you only want to run one/two benchmark, run it with
    # e.g: python -m fastrnns.bench --rnns jit and --group rnns
    default_groups = ["cnns", "rnns"]

    parser.add_argument("--seqLength", default=100, type=int)
    parser.add_argument("--numLayers", default=1, type=int)
    parser.add_argument("--inputSize", default=512, type=int)
    parser.add_argument("--hiddenSize", default=512, type=int)
    parser.add_argument("--miniBatch", default=64, type=int)
    parser.add_argument("--warmup", default=10, type=int)
    parser.add_argument("--nloops", default=100, type=int)
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument(
        "--variable-lstms",
        "--variable_lstms",
        action="store_true",
        help="Also benchmark variable sequence length lstms "
        "Note that some of these run really slowly "
        "and that the `seqLength` flag will be ignored.",
    )
    parser.add_argument("--sep", default=" ", type=str)
    parser.add_argument("--print-json", nargs="?", default=None, const="oss")
    parser.add_argument("--rnns", nargs="*", help="What to run. cudnn, aten, jit, etc")
    parser.add_argument(
        "--cnns", nargs="*", help="What to run. resnet18, resnet18_jit, resnet50, etc"
    )
    parser.add_argument(
        "--group",
        nargs="*",
        default=default_groups,
        help="Which group to run. cnns, rnns, etc.",
    )
    parser.add_argument(
        "--fuser",
        default="te",
        type=str,
        help="The fuser backend to use. One of: te, old, or none",
    )
    parser.add_argument(
        "--executor",
        default=None,
        type=str,
        help="The executor to use. One of: legacy, simple, profiling",
    )
    parser.add_argument(
        "--cuda-pointwise-loop-level",
        "--cuda_pointwise_loop_level",
        default=None,
        type=int,
    )
    parser.add_argument(
        "--cuda-pointwise-block-count",
        "--cuda_pointwise_block_count",
        default=None,
        type=int,
    )
    parser.add_argument(
        "--cuda-pointwise-block-size",
        "--cuda_pointwise_block_size",
        default=None,
        type=int,
    )

    args = parser.parse_args()
    set_fuser(args.fuser, args.executor)

    if args.cuda_pointwise_loop_level:
        torch._C._jit_set_te_cuda_pointwise_loop_levels(args.cuda_pointwise_loop_level)
    if args.cuda_pointwise_block_count:
        torch._C._jit_set_te_cuda_pointwise_block_count(args.cuda_pointwise_block_count)
    if args.cuda_pointwise_block_size:
        torch._C._jit_set_te_cuda_pointwise_block_size(args.cuda_pointwise_block_size)

    rnns = args.rnns or [
        "cudnn",
        "aten",
        "jit",
        "jit_premul",
        "jit_premul_bias",
        "jit_simple",
        "jit_multilayer",
        "py",
    ]
    cnns = args.cnns or ["resnet18", "resnet18_jit", "resnet50", "resnet50_jit"]
    # TODO: Maybe add a separate section for the layernorm/dropout lstms
    # 'cudnn_layernorm', 'jit_layernorm', 'jit_layernom_decom',
    # 'jit', 'jit_dropout', 'cudnn_dropout'
    vlrnns = ["vl_cudnn", "vl_jit", "vl_py"]

    if args.print_json:
        # Silence the human-readable table so that only JSON is emitted.
        print_stderr = lambda *args, **kwargs: None  # noqa: E731,F811
    print_stderr(args)

    bench_args = copy.deepcopy(vars(args))
    should_bench_varlen_lstms = args.variable_lstms
    # Drop the options that are consumed here and are not accepted by
    # bench()/trainbench().
    for consumed_key in (
        "group",
        "rnns",
        "cnns",
        "variable_lstms",
        "fuser",
        "executor",
        "cuda_pointwise_loop_level",
        "cuda_pointwise_block_count",
        "cuda_pointwise_block_size",
    ):
        del bench_args[consumed_key]

    results = {}
    if should_bench_varlen_lstms:
        if args.nloops + args.warmup > 30:
            print_stderr(
                "WARNING: some of the variable sequence length lstms are "
                "very unoptimized and therefore take forever to run."
            )
        results.update(
            bench_group(vlrnns, "variable-length sequence LSTM", "vl_lstm", bench_args)
        )

    if "rnns" in args.group:
        results.update(bench_group(rnns, "LSTM", "lstm", bench_args))
    if "cnns" in args.group:
        results.update(bench_group(cnns, "ResNet", "resnet", bench_args))

    if args.print_json == "oss":
        print_json_oss_format(results)
    elif args.print_json == "pep":
        print_json_pep_format(results)