import timeit
from functools import partial

import numpy as np
import pandas as pd

import torch
from functorch.compile import pointwise_operator


WRITE_CSV = False
CUDA = False
SIZES = [1, 512, 8192]  # n for the n-by-n test tensors
NUMBER = [100, 10, 1, 1]  # timeit iterations per size (only the first len(SIZES) entries are used)
REPEAT = 20  # timing runs per case; the median is reported


@pointwise_operator
def nnc_add(a, b):
    return a + b


@pointwise_operator
def nnc_addnorm(a, b, mean, std):
    return (a + b - mean) / std


def eager_addnorm(a, b, mean, std):
    return (a + b - mean) / std


def inplace_addnorm(a, b, mean, std, out):
    out = torch.add(a, b, out=out)
    torch.sub(out, mean, out=out)
    torch.div(out, std, out=out)
    return out


ts_addnorm = torch.jit.script(eager_addnorm)
ts_ip_addnorm = torch.jit.script(inplace_addnorm)


def maybe_synced(fn):
    # On CUDA, wrap fn so every timed call ends with a device sync;
    # on CPU, return fn unchanged.
    if CUDA:
        synchronize = torch.cuda.synchronize
        synchronize()  # warmup

        def _fn():
            result = fn()
            synchronize()
            return result

        return _fn
    return fn


def benchmark_loop(setup):
    result = np.zeros((REPEAT, len(SIZES), 2), dtype=np.float64)
    for s, n in enumerate(SIZES):
        nnc, aten = setup(n)
        nnc = maybe_synced(nnc)
        aten = maybe_synced(aten)

        for r in range(result.shape[0]):
            result[r, s, 0] = timeit.timeit(nnc, number=NUMBER[s])
            result[r, s, 1] = timeit.timeit(aten, number=NUMBER[s])

    result = np.median(result, axis=0)
    assert result.shape == (len(SIZES), 2)
    result = result[:, 1] / result[:, 0]  # aten time / nnc time = speedup
    print(result)
    return result


def test(make_args, nnc=nnc_add, aten=torch.add):
    # Time nnc vs aten on fresh outputs, after checking the results agree.
    def setup(n):
        args = make_args(n)
        result_aten = aten(*args)
        result_nnc = nnc(*args)
        assert result_nnc.dtype == result_aten.dtype
        assert result_nnc.size() == result_aten.size()
        assert result_nnc.stride() == result_aten.stride()
        torch.testing.assert_close(result_aten, result_nnc)
        return (lambda: nnc(*args), lambda: aten(*args))

    return benchmark_loop(setup)


def test_inplace(make_args, nnc=nnc_add, aten=torch.add):
    # Time the in-place variant, where out= aliases the first input.
    def inplace_setup(n):
        a, b = make_args(n)
        result_aten = torch.clone(a)
        result_nnc = torch.clone(a)
        nnc(result_nnc, b, out=result_nnc)
        aten(result_aten, b, out=result_aten)
        torch.testing.assert_close(result_aten, result_nnc)
        return (lambda: nnc(a, b, out=a), lambda: aten(a, b, out=a))

    return benchmark_loop(inplace_setup)


def test_out(make_args, out, nnc=nnc_add, aten=torch.add):
    # Time writes into a preallocated output (possibly strided or of a
    # different dtype, forcing a conversion).
    def out_setup(n):
        args = make_args(n)
        result_aten = out(n)
        result_nnc = out(n)
        aten(*args, out=result_aten)
        nnc(*args, out=result_nnc)
        torch.testing.assert_close(result_aten, result_nnc)
        result = out(n)
        return (lambda: nnc(*args, out=result), lambda: aten(*args, out=result))

    return benchmark_loop(out_setup)


def test_backwards(make_args, nnc=nnc_add, aten=torch.add):
    # Time the full forward + backward pass of fn(*args).sum(), after
    # checking that both paths produce the same gradient.
    def backwards_setup(n):
        args = make_args(n)
        (grad_var,) = (a for a in args if a.requires_grad)
        aten(*args).sum().backward()
        correct = grad_var.grad.clone()
        grad_var.grad.zero_()
        nnc(*args).sum().backward()
        torch.testing.assert_close(correct, grad_var.grad)
        return (
            lambda: nnc(*args).sum().backward(),
            lambda: aten(*args).sum().backward(),
        )

    return benchmark_loop(backwards_setup)


def main():
    torch.set_num_threads(1)  # TODO(jansel): add parallel support
    torch._C._jit_override_can_fuse_on_cpu(True)

    device = "cuda" if CUDA else "cpu"
    I = partial(torch.randint, 0, 100, device=device)
    R = partial(torch.randn, device=device)

    results = [
        ("add", test(lambda n: (R(n, n), R(n, n)))),
        ("broadcast1", test(lambda n: (R(n, n), R(1)))),
        ("broadcast2", test(lambda n: (R(n, n), R(n, 1)))),
        ("broadcast3", test(lambda n: (R(n, 1), R(1, n)))),
        ("inplace", test_inplace(lambda n: (R(n, n), R(n, 1)))),
        ("out=", test_out(lambda n: (R(n, n), R(n, n)), out=lambda n: R(n, n))),
        ("transposed1", test(lambda n: (R(n, n), R(n, n).transpose(0, 1)))),
        (
            "transposed2",
            test(lambda n: (R(n, n).transpose(0, 1), R(n, n).transpose(0, 1))),
        ),
        ("slice1", test(lambda n: (R(n + 1, n + 1, 2)[:n, :n, 0], R(n, n)))),
        ("slice2", test(lambda n: (R(n, n, 2)[:, :, 0], R(n, n, 2)[:, :, 0]))),
        (
            "strided out",
            test_out(
                lambda n: (R(n, n), R(n, n)),
                out=lambda n: R(n + 1, n + 1, 2)[:n, :n, 0],
            ),
        ),
        (
            "out convert",
            test_out(
                lambda n: (R(n, n), R(n, n)), out=lambda n: R(n, n, dtype=torch.float64)
            ),
        ),
        ("issue #57611 (n,32,32,2)", test(lambda n: (R(1, 32, 32, 2), R(n, 1, 1, 2)))),
        ("float+double", test(lambda n: (R(n, n), R(n, n, dtype=torch.float64)))),
        (
            "int+long",
            test(
                lambda n: (I([n, n], dtype=torch.int32), I([n, n], dtype=torch.int64))
            ),
        ),
        (
            "int+short",
            test(
                lambda n: (I([n, n], dtype=torch.int32), I([n, n], dtype=torch.int16))
            ),
        ),
        (
            "float+int",
            test(
                lambda n: (R([n, n], dtype=torch.float32), I([n, n], dtype=torch.int32))
            ),
        ),
        (
            "double+long",
            test(
                lambda n: (R([n, n], dtype=torch.float64), I([n, n], dtype=torch.int64))
            ),
        ),
        (
            "fused addnorm",
            test(
                lambda n: (R(n, n), R(n, n), R(n, n), R(n, n)),
                nnc=nnc_addnorm,
                aten=eager_addnorm,
            ),
        ),
        (
            "fused addnorm (vs TS)",
            test(
                lambda n: (R(n, n), R(n, n), R(n, n), R(n, n)),
                nnc=nnc_addnorm,
                aten=ts_addnorm,
            ),
        ),
        (
            "fused addnorm out=",
            test_out(
                lambda n: (R(n, n), R(n, n), R(n, n), R(n, n)),
                nnc=nnc_addnorm,
                aten=inplace_addnorm,
                out=lambda n: R(n, n),
            ),
        ),
        (
            "fused addnorm out= (vs TS)",
            test_out(
                lambda n: (R(n, n), R(n, n), R(n, n), R(n, n)),
                nnc=nnc_addnorm,
                aten=ts_ip_addnorm,
                out=lambda n: R(n, n),
            ),
        ),
        (
            "fused addnorm backward",
            test_backwards(
                lambda n: (R(n, n), R(n, n, requires_grad=True), R(n, n), R(n, n)),
                nnc=nnc_addnorm,
                aten=eager_addnorm,
            ),
        ),
        (
            "fused addnorm backward (vs TS)",
            test_backwards(
                lambda n: (R(n, n), R(n, n, requires_grad=True), R(n, n), R(n, n)),
                nnc=nnc_addnorm,
                aten=ts_addnorm,
            ),
        ),
    ]

    df = pd.DataFrame(
        np.stack([r for n, r in results]),
        columns=[f"{n}x{n}".rjust(9) for n in SIZES],
        index=[n for n, r in results],
    )

    if WRITE_CSV:
        df.to_csv("../operator_authoring_results.csv")
        print("wrote ../operator_authoring_results.csv")

    print()
    print("Speedups over aten")
    pd.options.display.float_format = "{:.2f}x".format
    print(df)


if __name__ == "__main__":
    main()
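

# ---------------------------------------------------------------------------
# A minimal usage sketch of the pointwise_operator API that the benchmarks
# above exercise. Illustrative only: `scale_shift` and `_demo` are names
# invented for this sketch, and nothing here is invoked by main(). It mirrors
# what the tests rely on: a single decorated Python function yields one fused
# kernel that handles broadcasting, out= destinations, and autograd.
def _demo():
    @pointwise_operator
    def scale_shift(x, weight, bias):
        return x * weight + bias

    x = torch.randn(4, 8)
    weight = torch.randn(1, 8)  # broadcasts across rows, as in broadcast2/3
    bias = torch.randn(1, 8)
    out = torch.empty(4, 8)
    scale_shift(x, weight, bias, out=out)  # out= path, as in test_out
    torch.testing.assert_close(out, x * weight + bias)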