"""Benchmark the overhead of RecordFunction (profiler callbacks) on JIT models.

Runs a scripted LSTM and/or a traced ResNet-50 forward pass with and without
RecordFunction enabled, across several intra-op thread counts, and prints a
comparison table.
"""

import argparse
import sys

from benchmarks.fastrnns.factory import lstm_creator

from torchvision.models import resnet50

import torch
import torch.utils.benchmark as benchmark_utils


def prepare_lstm_jit(bench_args):
    model_def = lstm_creator(
        script=True,
        seqLength=bench_args.lstmSeqLength,
        numLayers=bench_args.lstmNumLayers,
        inputSize=bench_args.lstmInputSize,
        hiddenSize=bench_args.lstmHiddenSize,
        miniBatch=bench_args.lstmMiniBatch,
        device="cpu",
    )
    return model_def.inputs, model_def.forward


def prepare_resnet50_jit(bench_args):
    model = resnet50()
    inputs = (torch.randn(32, 3, 224, 224),)
    model = torch.jit.trace(model, inputs)
    return inputs, model


MODELS = {
    "resnet50_jit": prepare_resnet50_jit,
    "lstm_jit": prepare_lstm_jit,
}

# Intra-op thread counts to sweep.
NUM_THREADS = [1, 2, 4, 8, 16, 32]


def run_bench(model_names, bench_args):
    results = []
    for model_name in model_names:
        model_creator = MODELS[model_name]
        inputs, model = model_creator(bench_args)

        print("Benchmarking RecordFunction overhead for", model_name)
        print("Running warmup...", end=" ")
        sys.stdout.flush()
        for _ in range(bench_args.warmup):
            model(*inputs)
        print("finished")

        for num_threads in NUM_THREADS:
            for with_rec_fn in [True, False]:
                # Toggle RecordFunction; when enabled, install an empty test
                # observer so the run exercises the callback machinery without
                # doing any real profiling work.
                torch.autograd._enable_record_function(with_rec_fn)
                torch.autograd._clear_callbacks()
                if with_rec_fn:
                    torch.autograd._set_empty_test_observer(True, 0.0001)

                print(
                    "Running {} RecordFunction, num threads {} ...".format(
                        "with" if with_rec_fn else "without", num_threads
                    ),
                    end=" ",
                )
                sys.stdout.flush()
                timer = benchmark_utils.Timer(
                    stmt="model(*inputs)",
                    globals={"model": model, "inputs": inputs},
                    description=model_name,
                    label="Record function overhead",
                    sub_label=f"with{'' if with_rec_fn else 'out'}_rec_fn, num_threads {num_threads}",
                    num_threads=num_threads,
                )
                result = timer.blocked_autorange(
                    min_run_time=bench_args.timer_min_run_time
                )
                print("finished")
                print(result)
                sys.stdout.flush()
                results.append(result)

    comparison = benchmark_utils.Compare(results)
    comparison.trim_significant_figures()
    comparison.highlight_warnings()
    comparison.print()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark RecordFunction overhead for ResNet and LSTM models"
    )

    parser.add_argument(
        "--models",
        nargs="*",
        default=["lstm_jit"],
        help="Models to run: " + ", ".join(MODELS.keys()),
    )

    parser.add_argument("--lstmSeqLength", default=100, type=int)
    parser.add_argument("--lstmNumLayers", default=1, type=int)
    parser.add_argument("--lstmInputSize", default=512, type=int)
    parser.add_argument("--lstmHiddenSize", default=512, type=int)
    parser.add_argument("--lstmMiniBatch", default=64, type=int)
    parser.add_argument("--warmup", default=2, type=int)
    parser.add_argument("--nloops", default=50, type=int)
    parser.add_argument(
        "--timer-min-run-time", "--timer_min_run_time", default=120, type=int
    )

    args = parser.parse_args()

    # Validate the requested models, then run the benchmark once over all of
    # them.
    models = args.models or list(MODELS.keys())
    for model in models:
        assert model in MODELS, f"Unknown model: {model}"
    run_bench(models, args)
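

# Example invocation: a sketch, assuming this file is saved as
# record_function_bench.py inside the PyTorch repo (it imports
# benchmarks.fastrnns.factory, so it must run from the repo root);
# the filename and flag values here are illustrative, not prescriptive.
#
#   python record_function_bench.py --models lstm_jit resnet50_jit \
#       --warmup 2 --timer-min-run-time 60
#
# The resulting Compare table groups timings by model (description) and
# "with/without_rec_fn, num_threads N" (sub_label), so the RecordFunction
# overhead is the gap between each paired pair of rows.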