1*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport argparse 4*da0073e9SAndroid Build Coastguard Workerimport inspect 5*da0073e9SAndroid Build Coastguard Workerimport sys 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Workerimport numpy as np 8*da0073e9SAndroid Build Coastguard Workerimport tabulate 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Workerimport torch 11*da0073e9SAndroid Build Coastguard Workerimport torch._inductor 12*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.backends.cudagraphs import cudagraphs_inner 13*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.testing import same 14*da0073e9SAndroid Build Coastguard Workerfrom torch._inductor.compile_fx import compile_fx 15*da0073e9SAndroid Build Coastguard Workerfrom torch._inductor.utils import timed 16*da0073e9SAndroid Build Coastguard Worker 17*da0073e9SAndroid Build Coastguard Worker 18*da0073e9SAndroid Build Coastguard Workeraten = torch.ops.aten 19*da0073e9SAndroid Build Coastguard Worker 20*da0073e9SAndroid Build Coastguard Workertry: 21*da0073e9SAndroid Build Coastguard Worker import test.test_torchinductor as tti 22*da0073e9SAndroid Build Coastguard Workerexcept ImportError: 23*da0073e9SAndroid Build Coastguard Worker tti = None 24*da0073e9SAndroid Build Coastguard Worker 25*da0073e9SAndroid Build Coastguard Worker 26*da0073e9SAndroid Build Coastguard Workerdef compute_speedups(args, models, example_inputs): 27*da0073e9SAndroid Build Coastguard Worker expected = models[0](*example_inputs) 28*da0073e9SAndroid Build Coastguard Worker for model in models[1:]: 29*da0073e9SAndroid Build Coastguard Worker actual = model(*example_inputs) 30*da0073e9SAndroid Build Coastguard Worker assert same(actual, expected), expected[0] - actual[0] 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Worker timings = np.zeros((args.repeat, len(models)), np.float64) 33*da0073e9SAndroid Build Coastguard Worker for rep in range(args.repeat): 34*da0073e9SAndroid Build Coastguard Worker # interleave the runs to handle frequency scaling and load changes 35*da0073e9SAndroid Build Coastguard Worker for m, model in enumerate(models): 36*da0073e9SAndroid Build Coastguard Worker timings[rep, m] = timed(model, example_inputs) 37*da0073e9SAndroid Build Coastguard Worker median = np.median(timings, axis=0) 38*da0073e9SAndroid Build Coastguard Worker return (median[0] / median[1:]).tolist() 39*da0073e9SAndroid Build Coastguard Worker 40*da0073e9SAndroid Build Coastguard Worker 41*da0073e9SAndroid Build Coastguard Workerdef microbenchmark(args, model, example_inputs): 42*da0073e9SAndroid Build Coastguard Worker compiled_fn = compile_fx(torch.fx.symbolic_trace(model), example_inputs) 43*da0073e9SAndroid Build Coastguard Worker cudagraphs_eager = cudagraphs_inner(model, example_inputs, copy_outputs=False) 44*da0073e9SAndroid Build Coastguard Worker cudagraphs_jit = cudagraphs_inner( 45*da0073e9SAndroid Build Coastguard Worker torch.jit.trace(model, example_inputs), example_inputs, copy_outputs=False 46*da0073e9SAndroid Build Coastguard Worker ) 47*da0073e9SAndroid Build Coastguard Worker return compute_speedups( 48*da0073e9SAndroid Build Coastguard Worker args, 49*da0073e9SAndroid Build Coastguard Worker [cudagraphs_eager, cudagraphs_jit, compiled_fn], 50*da0073e9SAndroid Build Coastguard Worker example_inputs, 51*da0073e9SAndroid Build Coastguard Worker ) 52*da0073e9SAndroid Build Coastguard Worker 53*da0073e9SAndroid Build Coastguard Worker 54*da0073e9SAndroid Build Coastguard Workerclass MyModel1(torch.nn.Module): 55*da0073e9SAndroid Build Coastguard Worker def __init__(self): 56*da0073e9SAndroid Build Coastguard Worker super().__init__() 57*da0073e9SAndroid Build Coastguard Worker self.model = torch.nn.Sequential( 58*da0073e9SAndroid Build Coastguard Worker torch.nn.Linear(1024, 1024), 59*da0073e9SAndroid Build Coastguard Worker torch.nn.ReLU(), 60*da0073e9SAndroid Build Coastguard Worker ) 61*da0073e9SAndroid Build Coastguard Worker 62*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 63*da0073e9SAndroid Build Coastguard Worker # return (self.model(input) + 1,) 64*da0073e9SAndroid Build Coastguard Worker return (self.model(input),) 65*da0073e9SAndroid Build Coastguard Worker 66*da0073e9SAndroid Build Coastguard Worker 67*da0073e9SAndroid Build Coastguard Workerclass MyModel2(torch.nn.Module): 68*da0073e9SAndroid Build Coastguard Worker def forward(self, x, y): 69*da0073e9SAndroid Build Coastguard Worker # return x / (torch.abs(x) + 1.0), 70*da0073e9SAndroid Build Coastguard Worker return (x + y,) 71*da0073e9SAndroid Build Coastguard Worker 72*da0073e9SAndroid Build Coastguard Worker 73*da0073e9SAndroid Build Coastguard Workerclass MicroBenchmarks: 74*da0073e9SAndroid Build Coastguard Worker @staticmethod 75*da0073e9SAndroid Build Coastguard Worker def add(a, b): 76*da0073e9SAndroid Build Coastguard Worker return (a + b,) 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker @staticmethod 79*da0073e9SAndroid Build Coastguard Worker def scale(x, m, d): 80*da0073e9SAndroid Build Coastguard Worker return ((x - m) / torch.clip(d, 1e-4),) 81*da0073e9SAndroid Build Coastguard Worker 82*da0073e9SAndroid Build Coastguard Worker @staticmethod 83*da0073e9SAndroid Build Coastguard Worker def abs_norm(x): 84*da0073e9SAndroid Build Coastguard Worker return (x / (torch.abs(x) + 1),) 85*da0073e9SAndroid Build Coastguard Worker 86*da0073e9SAndroid Build Coastguard Worker @staticmethod 87*da0073e9SAndroid Build Coastguard Worker def add_relu_softmax(x, a): 88*da0073e9SAndroid Build Coastguard Worker return (torch.softmax(torch.relu(x + a), -1),) 89*da0073e9SAndroid Build Coastguard Worker 90*da0073e9SAndroid Build Coastguard Worker @staticmethod 91*da0073e9SAndroid Build Coastguard Worker def sum(a, b): 92*da0073e9SAndroid Build Coastguard Worker return ((a + b).sum(),) 93*da0073e9SAndroid Build Coastguard Worker 94*da0073e9SAndroid Build Coastguard Worker @staticmethod 95*da0073e9SAndroid Build Coastguard Worker def view(x): 96*da0073e9SAndroid Build Coastguard Worker return (aten.alias(x),) 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker 99*da0073e9SAndroid Build Coastguard Workerdef main(): 100*da0073e9SAndroid Build Coastguard Worker parser = argparse.ArgumentParser() 101*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 102*da0073e9SAndroid Build Coastguard Worker "--filter", "-k", action="append", help="filter benchmarks with regexp" 103*da0073e9SAndroid Build Coastguard Worker ) 104*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 105*da0073e9SAndroid Build Coastguard Worker "--exclude", "-x", action="append", help="filter benchmarks with regexp" 106*da0073e9SAndroid Build Coastguard Worker ) 107*da0073e9SAndroid Build Coastguard Worker parser.add_argument("--devices", "-d", action="append", help="cpu or cuda") 108*da0073e9SAndroid Build Coastguard Worker parser.add_argument("--size", "-s", action="append", help="cpu or cuda") 109*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 110*da0073e9SAndroid Build Coastguard Worker "--repeat", "-n", type=int, default=30, help="number of timing runs" 111*da0073e9SAndroid Build Coastguard Worker ) 112*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 113*da0073e9SAndroid Build Coastguard Worker "--threads", "-t", type=int, help="number of threads to use for eager" 114*da0073e9SAndroid Build Coastguard Worker ) 115*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 116*da0073e9SAndroid Build Coastguard Worker "--verbose", "-v", action="store_true", help="enable verbose debug printouts" 117*da0073e9SAndroid Build Coastguard Worker ) 118*da0073e9SAndroid Build Coastguard Worker parser.add_argument( 119*da0073e9SAndroid Build Coastguard Worker "--nvfuser", action="store_true", help="enable nvfuser globally" 120*da0073e9SAndroid Build Coastguard Worker ) 121*da0073e9SAndroid Build Coastguard Worker parser.add_argument("--transpose", action="store_true", help="transpose one input") 122*da0073e9SAndroid Build Coastguard Worker parser.add_argument("--broadcast", action="store_true", help="broadcast one input") 123*da0073e9SAndroid Build Coastguard Worker args = parser.parse_args() 124*da0073e9SAndroid Build Coastguard Worker 125*da0073e9SAndroid Build Coastguard Worker # defaults 126*da0073e9SAndroid Build Coastguard Worker args.devices = args.devices or ["cpu", "cuda"] 127*da0073e9SAndroid Build Coastguard Worker args.filter = args.filter or [r"."] 128*da0073e9SAndroid Build Coastguard Worker args.exclude = args.exclude or [r"^$"] 129*da0073e9SAndroid Build Coastguard Worker args.size = args.size or [64, 256, 1024, 4096, 8192] 130*da0073e9SAndroid Build Coastguard Worker 131*da0073e9SAndroid Build Coastguard Worker if args.nvfuser: 132*da0073e9SAndroid Build Coastguard Worker torch._C._jit_override_can_fuse_on_cpu(False) 133*da0073e9SAndroid Build Coastguard Worker torch._C._jit_override_can_fuse_on_gpu(False) 134*da0073e9SAndroid Build Coastguard Worker torch._C._jit_set_texpr_fuser_enabled(False) 135*da0073e9SAndroid Build Coastguard Worker torch._C._jit_set_nvfuser_enabled(True) 136*da0073e9SAndroid Build Coastguard Worker else: 137*da0073e9SAndroid Build Coastguard Worker torch._C._jit_override_can_fuse_on_cpu(torch._C._llvm_enabled()) 138*da0073e9SAndroid Build Coastguard Worker torch._C._jit_override_can_fuse_on_gpu(True) 139*da0073e9SAndroid Build Coastguard Worker torch._C._jit_set_texpr_fuser_enabled(True) 140*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 141*da0073e9SAndroid Build Coastguard Worker torch._C._jit_set_nvfuser_enabled(False) 142*da0073e9SAndroid Build Coastguard Worker 143*da0073e9SAndroid Build Coastguard Worker if args.threads: 144*da0073e9SAndroid Build Coastguard Worker torch.set_num_threads(args.threads) 145*da0073e9SAndroid Build Coastguard Worker torch._inductor.config.cpp.threads = args.threads 146*da0073e9SAndroid Build Coastguard Worker 147*da0073e9SAndroid Build Coastguard Worker if args.verbose: 148*da0073e9SAndroid Build Coastguard Worker torch._inductor.config.debug = True 149*da0073e9SAndroid Build Coastguard Worker 150*da0073e9SAndroid Build Coastguard Worker torch._inductor.config.triton.autotune_pointwise = True 151*da0073e9SAndroid Build Coastguard Worker 152*da0073e9SAndroid Build Coastguard Worker rows = [] 153*da0073e9SAndroid Build Coastguard Worker for model in (MicroBenchmarks.sum, MicroBenchmarks.view): 154*da0073e9SAndroid Build Coastguard Worker nargs = len(inspect.signature(model).parameters) 155*da0073e9SAndroid Build Coastguard Worker for device in args.devices: 156*da0073e9SAndroid Build Coastguard Worker for n in args.size: 157*da0073e9SAndroid Build Coastguard Worker n = int(n) 158*da0073e9SAndroid Build Coastguard Worker sys.stdout.write(f"{model.__name__:10} {device:4} {n:5} ") 159*da0073e9SAndroid Build Coastguard Worker sys.stdout.flush() 160*da0073e9SAndroid Build Coastguard Worker inputs = [torch.rand((n, n), device=device) for _ in range(nargs)] 161*da0073e9SAndroid Build Coastguard Worker if args.broadcast: 162*da0073e9SAndroid Build Coastguard Worker inputs[-1] = torch.rand((1, n), device=device) 163*da0073e9SAndroid Build Coastguard Worker if args.transpose: 164*da0073e9SAndroid Build Coastguard Worker inputs[-1] = inputs[-1].transpose(0, 1) 165*da0073e9SAndroid Build Coastguard Worker result = microbenchmark(args, model, inputs) 166*da0073e9SAndroid Build Coastguard Worker rows.append([model.__name__, device, str(n)] + result) 167*da0073e9SAndroid Build Coastguard Worker print(" ".join(f"{v:.2f}x" for v in result)) 168*da0073e9SAndroid Build Coastguard Worker 169*da0073e9SAndroid Build Coastguard Worker print( 170*da0073e9SAndroid Build Coastguard Worker tabulate.tabulate( 171*da0073e9SAndroid Build Coastguard Worker rows, 172*da0073e9SAndroid Build Coastguard Worker headers=[ 173*da0073e9SAndroid Build Coastguard Worker "model", 174*da0073e9SAndroid Build Coastguard Worker "dev", 175*da0073e9SAndroid Build Coastguard Worker "n", 176*da0073e9SAndroid Build Coastguard Worker "ts", 177*da0073e9SAndroid Build Coastguard Worker "inductor", 178*da0073e9SAndroid Build Coastguard Worker ], 179*da0073e9SAndroid Build Coastguard Worker ) 180*da0073e9SAndroid Build Coastguard Worker ) 181*da0073e9SAndroid Build Coastguard Worker 182*da0073e9SAndroid Build Coastguard Worker 183*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 184*da0073e9SAndroid Build Coastguard Worker main() 185