#!/usr/bin/env python3
"""Microbenchmark a handful of small ops, comparing eager mode and
TorchScript (both wrapped via cudagraphs_inner) against TorchInductor,
and report median speedups relative to eager."""

import argparse
import inspect
import re
import sys

import numpy as np
import tabulate

import torch
import torch._inductor
from torch._dynamo.backends.cudagraphs import cudagraphs_inner
from torch._dynamo.testing import same
from torch._inductor.compile_fx import compile_fx
from torch._inductor.utils import timed


aten = torch.ops.aten

try:
    import test.test_torchinductor as tti
except ImportError:
    tti = None


def compute_speedups(args, models, example_inputs):
    # check that every variant agrees with the reference (models[0]) before timing
    expected = models[0](*example_inputs)
    for model in models[1:]:
        actual = model(*example_inputs)
        assert same(actual, expected), expected[0] - actual[0]

    timings = np.zeros((args.repeat, len(models)), np.float64)
    for rep in range(args.repeat):
        # interleave the runs to handle frequency scaling and load changes
        for m, model in enumerate(models):
            timings[rep, m] = timed(model, example_inputs)
    # speedup of each variant over models[0], computed from median wall time
    median = np.median(timings, axis=0)
    return (median[0] / median[1:]).tolist()


def microbenchmark(args, model, example_inputs):
    # three variants: eager + cudagraphs, TorchScript + cudagraphs, inductor
    compiled_fn = compile_fx(torch.fx.symbolic_trace(model), example_inputs)
    cudagraphs_eager = cudagraphs_inner(model, example_inputs, copy_outputs=False)
    cudagraphs_jit = cudagraphs_inner(
        torch.jit.trace(model, example_inputs), example_inputs, copy_outputs=False
    )
    return compute_speedups(
        args,
        [cudagraphs_eager, cudagraphs_jit, compiled_fn],
        example_inputs,
    )


class MyModel1(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Linear(1024, 1024),
            torch.nn.ReLU(),
        )

    def forward(self, input):
        # return (self.model(input) + 1,)
        return (self.model(input),)


class MyModel2(torch.nn.Module):
    def forward(self, x, y):
        # return x / (torch.abs(x) + 1.0),
        return (x + y,)


class MicroBenchmarks:
    @staticmethod
    def add(a, b):
        return (a + b,)

    @staticmethod
    def scale(x, m, d):
        return ((x - m) / torch.clip(d, 1e-4),)

    @staticmethod
    def abs_norm(x):
        return (x / (torch.abs(x) + 1),)

    @staticmethod
    def add_relu_softmax(x, a):
        return (torch.softmax(torch.relu(x + a), -1),)

    @staticmethod
    def sum(a, b):
        return ((a + b).sum(),)

    @staticmethod
    def view(x):
        return (aten.alias(x),)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--filter", "-k", action="append", help="filter benchmarks with regexp"
    )
    parser.add_argument(
        "--exclude", "-x", action="append", help="exclude benchmarks with regexp"
    )
    parser.add_argument("--devices", "-d", action="append", help="cpu or cuda")
    parser.add_argument(
        "--size", "-s", action="append", help="size n of the n x n benchmark inputs"
    )
    parser.add_argument(
        "--repeat", "-n", type=int, default=30, help="number of timing runs"
    )
    parser.add_argument(
        "--threads", "-t", type=int, help="number of threads to use for eager"
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true", help="enable verbose debug printouts"
    )
    parser.add_argument(
        "--nvfuser", action="store_true", help="enable nvfuser globally"
    )
    parser.add_argument("--transpose", action="store_true", help="transpose one input")
    parser.add_argument("--broadcast", action="store_true", help="broadcast one input")
    args = parser.parse_args()

    # defaults
    args.devices = args.devices or ["cpu", "cuda"]
    args.filter = args.filter or [r"."]
    args.exclude = args.exclude or [r"^$"]
    args.size = args.size or [64, 256, 1024, 4096, 8192]

    if args.nvfuser:
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(True)
    else:
        torch._C._jit_override_can_fuse_on_cpu(torch._C._llvm_enabled())
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(True)
        if torch.cuda.is_available():
            torch._C._jit_set_nvfuser_enabled(False)

    if args.threads:
        torch.set_num_threads(args.threads)
        torch._inductor.config.cpp.threads = args.threads

    if args.verbose:
        torch._inductor.config.debug = True

    torch._inductor.config.triton.autotune_pointwise = True

    rows = []
    for model in (MicroBenchmarks.sum, MicroBenchmarks.view):
        # apply the --filter/--exclude regexps to the benchmark name
        if not any(re.search(f, model.__name__) for f in args.filter):
            continue
        if any(re.search(x, model.__name__) for x in args.exclude):
            continue
        nargs = len(inspect.signature(model).parameters)
        for device in args.devices:
            for n in args.size:
                n = int(n)
                sys.stdout.write(f"{model.__name__:10} {device:4} {n:5} ")
                sys.stdout.flush()
                inputs = [torch.rand((n, n), device=device) for _ in range(nargs)]
                if args.broadcast:
                    inputs[-1] = torch.rand((1, n), device=device)
                if args.transpose:
                    inputs[-1] = inputs[-1].transpose(0, 1)
                result = microbenchmark(args, model, inputs)
                rows.append([model.__name__, device, str(n)] + result)
                print(" ".join(f"{v:.2f}x" for v in result))

    print(
        tabulate.tabulate(
            rows,
            headers=[
                "model",
                "dev",
                "n",
                "ts",
                "inductor",
            ],
        )
    )


if __name__ == "__main__":
    main()
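
# Example invocations (a sketch: "microbench.py" is a placeholder for whatever
# this file is saved as; every flag used below is defined in main() above):
#
#   python microbench.py                      # defaults: cpu+cuda, sizes 64..8192, 30 runs
#   python microbench.py -d cuda -n 100       # CUDA only, 100 timing runs per config
#   python microbench.py -s 512 --broadcast   # one size; broadcast the last input
#
# The "ts" and "inductor" columns in the final table are speedups of
# TorchScript and TorchInductor over the eager baseline (values > 1.00x
# mean faster than eager), as computed in compute_speedups().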