1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["NNC"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport numpy as np 4*da0073e9SAndroid Build Coastguard Workerimport torch 5*da0073e9SAndroid Build Coastguard Workerimport torch.nn.functional as F 6*da0073e9SAndroid Build Coastguard Workerfrom torch import nn 7*da0073e9SAndroid Build Coastguard Workerimport unittest 8*da0073e9SAndroid Build Coastguard Workerimport itertools 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs, run_tests, skipIfTorchDynamo 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase, TensorExprTestOptions 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard WorkerLLVM_ENABLED = torch._C._llvm_enabled() 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Workerclass BaseTestClass(JitTestCase): 17*da0073e9SAndroid Build Coastguard Worker def setUp(self): 18*da0073e9SAndroid Build Coastguard Worker super().setUp() 19*da0073e9SAndroid Build Coastguard Worker self.tensorexpr_options = TensorExprTestOptions() 20*da0073e9SAndroid Build Coastguard Worker self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] 21*da0073e9SAndroid Build Coastguard Worker self.dtypes = [torch.float32, torch.bfloat16] if LLVM_ENABLED else [torch.float32] 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Worker def tearDown(self): 24*da0073e9SAndroid Build Coastguard Worker self.tensorexpr_options.restore() 25*da0073e9SAndroid Build Coastguard Worker super().tearDown() 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Worker def assertLastGraphAllFused(self): 28*da0073e9SAndroid Build Coastguard Worker 
self.assertAllFused(torch.jit.last_executed_optimized_graph()) 29*da0073e9SAndroid Build Coastguard Worker 30*da0073e9SAndroid Build Coastguard Worker 31*da0073e9SAndroid Build Coastguard Workerdef warmup_and_run_forward(f, *args): 32*da0073e9SAndroid Build Coastguard Worker for _ in range(torch._C._jit_get_num_profiled_runs() + 1): 33*da0073e9SAndroid Build Coastguard Worker results = f(*args) 34*da0073e9SAndroid Build Coastguard Worker return results 35*da0073e9SAndroid Build Coastguard Worker 36*da0073e9SAndroid Build Coastguard Worker 37*da0073e9SAndroid Build Coastguard Worker@skipIfTorchDynamo() 38*da0073e9SAndroid Build Coastguard Workerclass TestTensorExprFuser(BaseTestClass): 39*da0073e9SAndroid Build Coastguard Worker def test_easy(self): 40*da0073e9SAndroid Build Coastguard Worker def easy(x, y): 41*da0073e9SAndroid Build Coastguard Worker aaa = torch.add(x, y) 42*da0073e9SAndroid Build Coastguard Worker return aaa 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(easy, (torch.rand(1024), torch.rand(1024))) 45*da0073e9SAndroid Build Coastguard Worker 46*da0073e9SAndroid Build Coastguard Worker a = torch.rand(1024) 47*da0073e9SAndroid Build Coastguard Worker b = torch.rand(1024) 48*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, a, b) 49*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 50*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(a.numpy() + b.numpy(), x.numpy()) 51*da0073e9SAndroid Build Coastguard Worker 52*da0073e9SAndroid Build Coastguard Worker def test_three_arg(self): 53*da0073e9SAndroid Build Coastguard Worker def easy(x, y, z): 54*da0073e9SAndroid Build Coastguard Worker aaa = torch.add(x, y) 55*da0073e9SAndroid Build Coastguard Worker bbb = torch.add(aaa, z) 56*da0073e9SAndroid Build Coastguard Worker return bbb 57*da0073e9SAndroid Build Coastguard Worker 58*da0073e9SAndroid Build Coastguard Worker traced = 
torch.jit.trace( 59*da0073e9SAndroid Build Coastguard Worker easy, (torch.rand(1024), torch.rand(1024), torch.rand(1024)) 60*da0073e9SAndroid Build Coastguard Worker ) 61*da0073e9SAndroid Build Coastguard Worker 62*da0073e9SAndroid Build Coastguard Worker a = torch.rand(1024) 63*da0073e9SAndroid Build Coastguard Worker b = torch.rand(1024) 64*da0073e9SAndroid Build Coastguard Worker c = torch.rand(1024) 65*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, a, b, c) 66*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 67*da0073e9SAndroid Build Coastguard Worker npr = a.numpy() + b.numpy() + c.numpy() 68*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(npr, x.numpy()) 69*da0073e9SAndroid Build Coastguard Worker 70*da0073e9SAndroid Build Coastguard Worker def test_four_arg(self): 71*da0073e9SAndroid Build Coastguard Worker def run_addcmul(x, y, z, w): 72*da0073e9SAndroid Build Coastguard Worker c = torch.addcmul(torch.add(x, y), z, w) 73*da0073e9SAndroid Build Coastguard Worker return c 74*da0073e9SAndroid Build Coastguard Worker 75*da0073e9SAndroid Build Coastguard Worker for dev in self.devices: 76*da0073e9SAndroid Build Coastguard Worker rand_a = torch.rand(1024, dtype=torch.float, device=dev) 77*da0073e9SAndroid Build Coastguard Worker rand_b = torch.rand(1024, dtype=torch.float, device=dev) 78*da0073e9SAndroid Build Coastguard Worker rand_c = torch.rand(1024, dtype=torch.float, device=dev) 79*da0073e9SAndroid Build Coastguard Worker rand_d = torch.rand(1024, dtype=torch.float, device=dev) 80*da0073e9SAndroid Build Coastguard Worker 81*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace( 82*da0073e9SAndroid Build Coastguard Worker run_addcmul, 83*da0073e9SAndroid Build Coastguard Worker ( 84*da0073e9SAndroid Build Coastguard Worker torch.zeros(1024, dtype=torch.float, device=dev), 85*da0073e9SAndroid Build Coastguard Worker torch.zeros(1024, dtype=torch.float, device=dev), 
86*da0073e9SAndroid Build Coastguard Worker torch.zeros(1024, dtype=torch.float, device=dev), 87*da0073e9SAndroid Build Coastguard Worker torch.zeros(1024, dtype=torch.float, device=dev), 88*da0073e9SAndroid Build Coastguard Worker ), 89*da0073e9SAndroid Build Coastguard Worker ) 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, rand_a, rand_b, rand_c, rand_d) 92*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 93*da0073e9SAndroid Build Coastguard Worker y = run_addcmul(rand_a, rand_b, rand_c, rand_d) 94*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy(), atol=1e-6) 95*da0073e9SAndroid Build Coastguard Worker 96*da0073e9SAndroid Build Coastguard Worker def test_three_arg2(self): 97*da0073e9SAndroid Build Coastguard Worker for device in self.devices: 98*da0073e9SAndroid Build Coastguard Worker def test(x, y, z): 99*da0073e9SAndroid Build Coastguard Worker aaa = torch.add(x, y) 100*da0073e9SAndroid Build Coastguard Worker bbb = torch.add(aaa, z) 101*da0073e9SAndroid Build Coastguard Worker return bbb 102*da0073e9SAndroid Build Coastguard Worker 103*da0073e9SAndroid Build Coastguard Worker M = 32 104*da0073e9SAndroid Build Coastguard Worker N = 32 105*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace( 106*da0073e9SAndroid Build Coastguard Worker test, 107*da0073e9SAndroid Build Coastguard Worker ( 108*da0073e9SAndroid Build Coastguard Worker torch.rand(M, N, device=device), 109*da0073e9SAndroid Build Coastguard Worker torch.rand(M, N, device=device), 110*da0073e9SAndroid Build Coastguard Worker torch.rand(M, N, device=device), 111*da0073e9SAndroid Build Coastguard Worker ), 112*da0073e9SAndroid Build Coastguard Worker ) 113*da0073e9SAndroid Build Coastguard Worker 114*da0073e9SAndroid Build Coastguard Worker a = torch.rand(M, N, device=device) 115*da0073e9SAndroid Build Coastguard Worker b = torch.rand(M, N, 
device=device) 116*da0073e9SAndroid Build Coastguard Worker c = torch.rand(M, N, device=device) 117*da0073e9SAndroid Build Coastguard Worker x = traced(a, b, c) 118*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, a, b, c) 119*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 120*da0073e9SAndroid Build Coastguard Worker npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy() 121*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(npr, x.cpu().numpy()) 122*da0073e9SAndroid Build Coastguard Worker 123*da0073e9SAndroid Build Coastguard Worker def test_broadcast3(self): 124*da0073e9SAndroid Build Coastguard Worker for device in self.devices: 125*da0073e9SAndroid Build Coastguard Worker def test_body(M, N, L, K): 126*da0073e9SAndroid Build Coastguard Worker def test(x, y, z): 127*da0073e9SAndroid Build Coastguard Worker v1 = torch.add(x, y) 128*da0073e9SAndroid Build Coastguard Worker v2 = torch.add(v1, z) 129*da0073e9SAndroid Build Coastguard Worker return v2 130*da0073e9SAndroid Build Coastguard Worker 131*da0073e9SAndroid Build Coastguard Worker a_shape = [M, N] 132*da0073e9SAndroid Build Coastguard Worker b_shape = [L, M, 1] 133*da0073e9SAndroid Build Coastguard Worker c_shape = [K, L, 1, 1] 134*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace( 135*da0073e9SAndroid Build Coastguard Worker test, 136*da0073e9SAndroid Build Coastguard Worker ( 137*da0073e9SAndroid Build Coastguard Worker torch.rand(*a_shape, device=device), 138*da0073e9SAndroid Build Coastguard Worker torch.rand(*b_shape, device=device), 139*da0073e9SAndroid Build Coastguard Worker torch.rand(*c_shape, device=device), 140*da0073e9SAndroid Build Coastguard Worker ), 141*da0073e9SAndroid Build Coastguard Worker ) 142*da0073e9SAndroid Build Coastguard Worker 143*da0073e9SAndroid Build Coastguard Worker a = torch.rand(*a_shape, device=device) 144*da0073e9SAndroid Build Coastguard Worker b = torch.rand(*b_shape, 
device=device) 145*da0073e9SAndroid Build Coastguard Worker c = torch.rand(*c_shape, device=device) 146*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, a, b, c) 147*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 148*da0073e9SAndroid Build Coastguard Worker npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy() 149*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(npr, x.cpu().numpy()) 150*da0073e9SAndroid Build Coastguard Worker 151*da0073e9SAndroid Build Coastguard Worker test_configs = [[5, 2, 7, 3], [8, 8, 8, 8]] 152*da0073e9SAndroid Build Coastguard Worker for test_config in test_configs: 153*da0073e9SAndroid Build Coastguard Worker test_body(*test_config) 154*da0073e9SAndroid Build Coastguard Worker 155*da0073e9SAndroid Build Coastguard Worker def test_all_combos(self): 156*da0073e9SAndroid Build Coastguard Worker def easy(x, y, z): 157*da0073e9SAndroid Build Coastguard Worker a = torch.add(x, y) 158*da0073e9SAndroid Build Coastguard Worker b = torch.add(a, z) 159*da0073e9SAndroid Build Coastguard Worker c = torch.add(x, b) 160*da0073e9SAndroid Build Coastguard Worker d = torch.add(c, a) 161*da0073e9SAndroid Build Coastguard Worker return d 162*da0073e9SAndroid Build Coastguard Worker 163*da0073e9SAndroid Build Coastguard Worker def np_easy(x, y, z): 164*da0073e9SAndroid Build Coastguard Worker a = x + y 165*da0073e9SAndroid Build Coastguard Worker b = a + z 166*da0073e9SAndroid Build Coastguard Worker c = x + b 167*da0073e9SAndroid Build Coastguard Worker d = c + a 168*da0073e9SAndroid Build Coastguard Worker return d 169*da0073e9SAndroid Build Coastguard Worker 170*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace( 171*da0073e9SAndroid Build Coastguard Worker easy, (torch.rand(1024), torch.rand(1024), torch.rand(1024)) 172*da0073e9SAndroid Build Coastguard Worker ) 173*da0073e9SAndroid Build Coastguard Worker 174*da0073e9SAndroid Build Coastguard Worker a = 
torch.rand(1024) 175*da0073e9SAndroid Build Coastguard Worker b = torch.rand(1024) 176*da0073e9SAndroid Build Coastguard Worker c = torch.rand(1024) 177*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, a, b, c) 178*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 179*da0073e9SAndroid Build Coastguard Worker npr = np_easy(a.numpy(), b.numpy(), c.numpy()) 180*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(npr, x.numpy()) 181*da0073e9SAndroid Build Coastguard Worker 182*da0073e9SAndroid Build Coastguard Worker def test_rank_two(self): 183*da0073e9SAndroid Build Coastguard Worker def easy(x, y, z): 184*da0073e9SAndroid Build Coastguard Worker a = torch.add(x, y) 185*da0073e9SAndroid Build Coastguard Worker b = torch.add(a, z) 186*da0073e9SAndroid Build Coastguard Worker c = torch.add(x, b) 187*da0073e9SAndroid Build Coastguard Worker d = torch.add(c, a) 188*da0073e9SAndroid Build Coastguard Worker return d 189*da0073e9SAndroid Build Coastguard Worker 190*da0073e9SAndroid Build Coastguard Worker def np_easy(x, y, z): 191*da0073e9SAndroid Build Coastguard Worker a = x + y 192*da0073e9SAndroid Build Coastguard Worker b = a + z 193*da0073e9SAndroid Build Coastguard Worker c = x + b 194*da0073e9SAndroid Build Coastguard Worker d = c + a 195*da0073e9SAndroid Build Coastguard Worker return d 196*da0073e9SAndroid Build Coastguard Worker 197*da0073e9SAndroid Build Coastguard Worker shape = 32, 32 198*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace( 199*da0073e9SAndroid Build Coastguard Worker easy, (torch.rand(shape), torch.rand(shape), torch.rand(shape)) 200*da0073e9SAndroid Build Coastguard Worker ) 201*da0073e9SAndroid Build Coastguard Worker 202*da0073e9SAndroid Build Coastguard Worker a = torch.rand(shape) 203*da0073e9SAndroid Build Coastguard Worker b = torch.rand(shape) 204*da0073e9SAndroid Build Coastguard Worker c = torch.rand(shape) 205*da0073e9SAndroid Build Coastguard Worker x = 
warmup_and_run_forward(traced, a, b, c) 206*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 207*da0073e9SAndroid Build Coastguard Worker npr = np_easy(a.numpy(), b.numpy(), c.numpy()) 208*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(npr, x.numpy()) 209*da0073e9SAndroid Build Coastguard Worker 210*da0073e9SAndroid Build Coastguard Worker def test_broadcast(self): 211*da0073e9SAndroid Build Coastguard Worker def easy(x, y, z): 212*da0073e9SAndroid Build Coastguard Worker a = torch.add(x, y) 213*da0073e9SAndroid Build Coastguard Worker b = torch.add(a, z) 214*da0073e9SAndroid Build Coastguard Worker return b 215*da0073e9SAndroid Build Coastguard Worker 216*da0073e9SAndroid Build Coastguard Worker def np_easy(x, y, z): 217*da0073e9SAndroid Build Coastguard Worker a = x + y 218*da0073e9SAndroid Build Coastguard Worker b = a + z 219*da0073e9SAndroid Build Coastguard Worker return b 220*da0073e9SAndroid Build Coastguard Worker 221*da0073e9SAndroid Build Coastguard Worker N = 32 222*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(easy, (torch.rand(N, N), torch.rand(N), torch.rand(N, N))) 223*da0073e9SAndroid Build Coastguard Worker 224*da0073e9SAndroid Build Coastguard Worker a = torch.rand(N, N) 225*da0073e9SAndroid Build Coastguard Worker b = torch.rand(N) 226*da0073e9SAndroid Build Coastguard Worker c = torch.rand(N, N) 227*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, a, b, c) 228*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 229*da0073e9SAndroid Build Coastguard Worker npr = np_easy(a.numpy(), b.numpy(), c.numpy()) 230*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(npr, x.numpy()) 231*da0073e9SAndroid Build Coastguard Worker 232*da0073e9SAndroid Build Coastguard Worker def test_broadcast_2(self): 233*da0073e9SAndroid Build Coastguard Worker zero = torch.tensor([0.0], dtype=torch.float) 234*da0073e9SAndroid Build Coastguard 
Worker 235*da0073e9SAndroid Build Coastguard Worker def foo(x, y, z): 236*da0073e9SAndroid Build Coastguard Worker aaa = torch.add(x, y) 237*da0073e9SAndroid Build Coastguard Worker bbb = torch.add(zero, aaa) 238*da0073e9SAndroid Build Coastguard Worker return torch.add(bbb, z) 239*da0073e9SAndroid Build Coastguard Worker 240*da0073e9SAndroid Build Coastguard Worker def foo_np(x, y, z): 241*da0073e9SAndroid Build Coastguard Worker a = x + y 242*da0073e9SAndroid Build Coastguard Worker b = zero.numpy() + a 243*da0073e9SAndroid Build Coastguard Worker return b + z 244*da0073e9SAndroid Build Coastguard Worker 245*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 246*da0073e9SAndroid Build Coastguard Worker y = torch.ones(3, 1) 247*da0073e9SAndroid Build Coastguard Worker z = torch.rand(4) 248*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(foo, (x, y, z)) 249*da0073e9SAndroid Build Coastguard Worker 250*da0073e9SAndroid Build Coastguard Worker r = warmup_and_run_forward(traced, x, y, z) 251*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 252*da0073e9SAndroid Build Coastguard Worker 253*da0073e9SAndroid Build Coastguard Worker rnp = foo_np(x.numpy(), y.numpy(), z.numpy()) 254*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(r, rnp) 255*da0073e9SAndroid Build Coastguard Worker 256*da0073e9SAndroid Build Coastguard Worker def test_broadcast_big2(self): 257*da0073e9SAndroid Build Coastguard Worker zero = torch.tensor([0.0], dtype=torch.float) 258*da0073e9SAndroid Build Coastguard Worker 259*da0073e9SAndroid Build Coastguard Worker def foo(x, y, z): 260*da0073e9SAndroid Build Coastguard Worker aaa = torch.add(x, y) 261*da0073e9SAndroid Build Coastguard Worker bbb = torch.add(zero, aaa) 262*da0073e9SAndroid Build Coastguard Worker return torch.add(bbb, z) 263*da0073e9SAndroid Build Coastguard Worker 264*da0073e9SAndroid Build Coastguard Worker def foo_np(x, y, z): 265*da0073e9SAndroid Build Coastguard 
Worker a = x + y 266*da0073e9SAndroid Build Coastguard Worker b = zero.numpy() + a 267*da0073e9SAndroid Build Coastguard Worker return b + z 268*da0073e9SAndroid Build Coastguard Worker 269*da0073e9SAndroid Build Coastguard Worker x = torch.rand(32, 1024) 270*da0073e9SAndroid Build Coastguard Worker y = torch.ones(32, 1) 271*da0073e9SAndroid Build Coastguard Worker z = torch.rand(1024) 272*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(foo, (x, y, z)) 273*da0073e9SAndroid Build Coastguard Worker 274*da0073e9SAndroid Build Coastguard Worker r = warmup_and_run_forward(traced, x, y, z) 275*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 276*da0073e9SAndroid Build Coastguard Worker rnp = foo_np(x.numpy(), y.numpy(), z.numpy()) 277*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(r, rnp) 278*da0073e9SAndroid Build Coastguard Worker 279*da0073e9SAndroid Build Coastguard Worker def test_alpha(self): 280*da0073e9SAndroid Build Coastguard Worker def alpha(x): 281*da0073e9SAndroid Build Coastguard Worker aaa = torch.add(x, x, alpha=2.0) 282*da0073e9SAndroid Build Coastguard Worker return aaa 283*da0073e9SAndroid Build Coastguard Worker 284*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(alpha, (torch.tensor([1.0]))) 285*da0073e9SAndroid Build Coastguard Worker 286*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([1.0]) 287*da0073e9SAndroid Build Coastguard Worker x = traced(a) 288*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(a.numpy() + 2.0 * a.numpy(), x.numpy()) 289*da0073e9SAndroid Build Coastguard Worker 290*da0073e9SAndroid Build Coastguard Worker @suppress_warnings 291*da0073e9SAndroid Build Coastguard Worker def test_constant(self): 292*da0073e9SAndroid Build Coastguard Worker def constant(x): 293*da0073e9SAndroid Build Coastguard Worker bbb = torch.tensor([1.0]) 294*da0073e9SAndroid Build Coastguard Worker aaa = torch.add(x, bbb) 295*da0073e9SAndroid Build 
Coastguard Worker return aaa 296*da0073e9SAndroid Build Coastguard Worker 297*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(constant, (torch.tensor([1.0]))) 298*da0073e9SAndroid Build Coastguard Worker 299*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([1.0]) 300*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, a) 301*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 302*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(a.numpy() + 1.0, x.numpy()) 303*da0073e9SAndroid Build Coastguard Worker 304*da0073e9SAndroid Build Coastguard Worker def test_add_sub(self): 305*da0073e9SAndroid Build Coastguard Worker def easy(x, y, z): 306*da0073e9SAndroid Build Coastguard Worker aaa = torch.add(x, y) 307*da0073e9SAndroid Build Coastguard Worker bbb = torch.sub(aaa, z) 308*da0073e9SAndroid Build Coastguard Worker return bbb 309*da0073e9SAndroid Build Coastguard Worker 310*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace( 311*da0073e9SAndroid Build Coastguard Worker easy, (torch.rand(1024), torch.rand(1024), torch.rand(1024)) 312*da0073e9SAndroid Build Coastguard Worker ) 313*da0073e9SAndroid Build Coastguard Worker 314*da0073e9SAndroid Build Coastguard Worker a = torch.rand(1024) 315*da0073e9SAndroid Build Coastguard Worker b = torch.rand(1024) 316*da0073e9SAndroid Build Coastguard Worker c = torch.rand(1024) 317*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, a, b, c) 318*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 319*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(a.numpy() + b.numpy() - c.numpy(), x.numpy()) 320*da0073e9SAndroid Build Coastguard Worker 321*da0073e9SAndroid Build Coastguard Worker def test_promotion(self): 322*da0073e9SAndroid Build Coastguard Worker def easy(x, y): 323*da0073e9SAndroid Build Coastguard Worker aaa = torch.add(x, y) 324*da0073e9SAndroid Build Coastguard 
Worker return aaa 325*da0073e9SAndroid Build Coastguard Worker 326*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace( 327*da0073e9SAndroid Build Coastguard Worker easy, 328*da0073e9SAndroid Build Coastguard Worker (torch.zeros(1024, dtype=torch.int32), torch.rand(1024, dtype=torch.float32)), 329*da0073e9SAndroid Build Coastguard Worker ) 330*da0073e9SAndroid Build Coastguard Worker 331*da0073e9SAndroid Build Coastguard Worker a = torch.zeros(1024, dtype=torch.int32) 332*da0073e9SAndroid Build Coastguard Worker b = torch.rand(1024, dtype=torch.float32) 333*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, a, b) 334*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 335*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(a.numpy() + b.numpy(), x.numpy()) 336*da0073e9SAndroid Build Coastguard Worker 337*da0073e9SAndroid Build Coastguard Worker def test_double(self): 338*da0073e9SAndroid Build Coastguard Worker TENSOR_LEN = 8 339*da0073e9SAndroid Build Coastguard Worker 340*da0073e9SAndroid Build Coastguard Worker def easy(x, y): 341*da0073e9SAndroid Build Coastguard Worker aaa = torch.add(x, y) 342*da0073e9SAndroid Build Coastguard Worker bbb = torch.mul(aaa, y) 343*da0073e9SAndroid Build Coastguard Worker return bbb 344*da0073e9SAndroid Build Coastguard Worker 345*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace( 346*da0073e9SAndroid Build Coastguard Worker easy, 347*da0073e9SAndroid Build Coastguard Worker (torch.rand(TENSOR_LEN, dtype=torch.float64), torch.full((TENSOR_LEN,), 0.5, dtype=torch.float64)), 348*da0073e9SAndroid Build Coastguard Worker ) 349*da0073e9SAndroid Build Coastguard Worker 350*da0073e9SAndroid Build Coastguard Worker a = torch.rand(TENSOR_LEN, dtype=torch.double) 351*da0073e9SAndroid Build Coastguard Worker b = torch.full((TENSOR_LEN,), 0.5, dtype=torch.double) 352*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, a, b) 
353*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 354*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy()) 355*da0073e9SAndroid Build Coastguard Worker 356*da0073e9SAndroid Build Coastguard Worker def test_short(self): 357*da0073e9SAndroid Build Coastguard Worker TENSOR_LEN = 8 358*da0073e9SAndroid Build Coastguard Worker 359*da0073e9SAndroid Build Coastguard Worker def easy(x, y): 360*da0073e9SAndroid Build Coastguard Worker aaa = torch.add(x, y) 361*da0073e9SAndroid Build Coastguard Worker bbb = torch.mul(aaa, y) 362*da0073e9SAndroid Build Coastguard Worker return bbb 363*da0073e9SAndroid Build Coastguard Worker 364*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace( 365*da0073e9SAndroid Build Coastguard Worker easy, 366*da0073e9SAndroid Build Coastguard Worker (torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16), 367*da0073e9SAndroid Build Coastguard Worker torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16)), 368*da0073e9SAndroid Build Coastguard Worker ) 369*da0073e9SAndroid Build Coastguard Worker 370*da0073e9SAndroid Build Coastguard Worker a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16) 371*da0073e9SAndroid Build Coastguard Worker b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16) 372*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, a, b) 373*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 374*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy()) 375*da0073e9SAndroid Build Coastguard Worker 376*da0073e9SAndroid Build Coastguard Worker def test_char(self): 377*da0073e9SAndroid Build Coastguard Worker TENSOR_LEN = 8 378*da0073e9SAndroid Build Coastguard Worker 379*da0073e9SAndroid Build Coastguard Worker def easy(x, y): 380*da0073e9SAndroid Build Coastguard Worker aaa = torch.add(x, y) 
381*da0073e9SAndroid Build Coastguard Worker bbb = torch.mul(aaa, y) 382*da0073e9SAndroid Build Coastguard Worker return bbb 383*da0073e9SAndroid Build Coastguard Worker 384*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace( 385*da0073e9SAndroid Build Coastguard Worker easy, 386*da0073e9SAndroid Build Coastguard Worker (torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8), 387*da0073e9SAndroid Build Coastguard Worker torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)), 388*da0073e9SAndroid Build Coastguard Worker ) 389*da0073e9SAndroid Build Coastguard Worker 390*da0073e9SAndroid Build Coastguard Worker a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8) 391*da0073e9SAndroid Build Coastguard Worker b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8) 392*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, a, b) 393*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 394*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy()) 395*da0073e9SAndroid Build Coastguard Worker 396*da0073e9SAndroid Build Coastguard Worker def test_int64_promotion(self): 397*da0073e9SAndroid Build Coastguard Worker TENSOR_LEN = 8 398*da0073e9SAndroid Build Coastguard Worker 399*da0073e9SAndroid Build Coastguard Worker def easy(x, y): 400*da0073e9SAndroid Build Coastguard Worker aaa = torch.add(x, y) 401*da0073e9SAndroid Build Coastguard Worker bbb = torch.mul(aaa, y) 402*da0073e9SAndroid Build Coastguard Worker return bbb 403*da0073e9SAndroid Build Coastguard Worker 404*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace( 405*da0073e9SAndroid Build Coastguard Worker easy, 406*da0073e9SAndroid Build Coastguard Worker (torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8), 407*da0073e9SAndroid Build Coastguard Worker torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int64)), 408*da0073e9SAndroid Build Coastguard Worker 
) 409*da0073e9SAndroid Build Coastguard Worker 410*da0073e9SAndroid Build Coastguard Worker a = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8) 411*da0073e9SAndroid Build Coastguard Worker b = torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int64) 412*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, a, b) 413*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 414*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose((a.numpy() + b.numpy()) * b.numpy(), x.numpy()) 415*da0073e9SAndroid Build Coastguard Worker 416*da0073e9SAndroid Build Coastguard Worker def test_eq(self): 417*da0073e9SAndroid Build Coastguard Worker def easy(x, y): 418*da0073e9SAndroid Build Coastguard Worker c = torch.eq(x, y) 419*da0073e9SAndroid Build Coastguard Worker return c 420*da0073e9SAndroid Build Coastguard Worker 421*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024))) 422*da0073e9SAndroid Build Coastguard Worker a = torch.zeros(1024, dtype=torch.int32) 423*da0073e9SAndroid Build Coastguard Worker b = torch.zeros(1024, dtype=torch.int32) 424*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, a, b) 425*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 426*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(np.ones(1024), x.numpy()) 427*da0073e9SAndroid Build Coastguard Worker 428*da0073e9SAndroid Build Coastguard Worker def test_ne(self): 429*da0073e9SAndroid Build Coastguard Worker def easy(x, y): 430*da0073e9SAndroid Build Coastguard Worker c = torch.ne(x, y) 431*da0073e9SAndroid Build Coastguard Worker return c 432*da0073e9SAndroid Build Coastguard Worker 433*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024))) 434*da0073e9SAndroid Build Coastguard Worker a = torch.zeros(1024, dtype=torch.int32) 435*da0073e9SAndroid Build Coastguard 
Worker b = torch.ones(1024, dtype=torch.int32) 436*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, a, b) 437*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 438*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(np.ones(1024), x.numpy()) 439*da0073e9SAndroid Build Coastguard Worker 440*da0073e9SAndroid Build Coastguard Worker def test_ge(self): 441*da0073e9SAndroid Build Coastguard Worker def easy(x, y): 442*da0073e9SAndroid Build Coastguard Worker c = torch.ge(x, y) 443*da0073e9SAndroid Build Coastguard Worker return c 444*da0073e9SAndroid Build Coastguard Worker 445*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024))) 446*da0073e9SAndroid Build Coastguard Worker aa = np.empty([1024], dtype=np.int32) 447*da0073e9SAndroid Build Coastguard Worker aa.fill(5) 448*da0073e9SAndroid Build Coastguard Worker a = torch.from_numpy(aa) 449*da0073e9SAndroid Build Coastguard Worker b = torch.zeros(1024, dtype=torch.int32) 450*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, a, b) 451*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 452*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(np.ones(1024), x.numpy()) 453*da0073e9SAndroid Build Coastguard Worker 454*da0073e9SAndroid Build Coastguard Worker def test_gt(self): 455*da0073e9SAndroid Build Coastguard Worker def easy(x, y): 456*da0073e9SAndroid Build Coastguard Worker c = torch.gt(x, y) 457*da0073e9SAndroid Build Coastguard Worker return c 458*da0073e9SAndroid Build Coastguard Worker 459*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024))) 460*da0073e9SAndroid Build Coastguard Worker a = torch.ones(1024, dtype=torch.int32) 461*da0073e9SAndroid Build Coastguard Worker b = torch.zeros(1024, dtype=torch.int32) 462*da0073e9SAndroid Build Coastguard Worker x = 
warmup_and_run_forward(traced, a, b) 463*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 464*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(np.ones(1024), x.numpy()) 465*da0073e9SAndroid Build Coastguard Worker 466*da0073e9SAndroid Build Coastguard Worker def test_le(self): 467*da0073e9SAndroid Build Coastguard Worker def easy(x, y): 468*da0073e9SAndroid Build Coastguard Worker c = torch.le(x, y) 469*da0073e9SAndroid Build Coastguard Worker return c 470*da0073e9SAndroid Build Coastguard Worker 471*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024))) 472*da0073e9SAndroid Build Coastguard Worker aa = np.empty([1024], dtype=np.int32) 473*da0073e9SAndroid Build Coastguard Worker aa.fill(5) 474*da0073e9SAndroid Build Coastguard Worker a = torch.from_numpy(aa) 475*da0073e9SAndroid Build Coastguard Worker b = torch.zeros(1024, dtype=torch.int32) 476*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, a, b) 477*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 478*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(np.zeros(1024), x.numpy()) 479*da0073e9SAndroid Build Coastguard Worker 480*da0073e9SAndroid Build Coastguard Worker def test_lt(self): 481*da0073e9SAndroid Build Coastguard Worker def easy(x, y): 482*da0073e9SAndroid Build Coastguard Worker c = torch.lt(x, y) 483*da0073e9SAndroid Build Coastguard Worker return c 484*da0073e9SAndroid Build Coastguard Worker 485*da0073e9SAndroid Build Coastguard Worker for dev in self.devices: 486*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(easy, (torch.zeros(1024, device=dev), torch.zeros(1024, device=dev))) 487*da0073e9SAndroid Build Coastguard Worker a = torch.ones(1024, dtype=torch.int32, device=dev) 488*da0073e9SAndroid Build Coastguard Worker b = torch.zeros(1024, dtype=torch.int32, device=dev) 489*da0073e9SAndroid Build Coastguard Worker 
@suppress_warnings
def test_min_max(self):
    """Check fused elementwise ``max(min(x, y), 4.0)`` against NumPy."""
    def clamp_below(x, y):
        return torch.max(torch.min(x, y), torch.tensor([4.0]))

    example_inputs = (torch.zeros(1024), torch.zeros(1024))
    traced = torch.jit.trace(clamp_below, example_inputs)

    lhs = 8.0 * torch.rand(1024)
    rhs = 8.0 * torch.rand(1024)
    actual = warmup_and_run_forward(traced, lhs, rhs)
    expected = np.maximum(np.minimum(lhs.numpy(), rhs.numpy()), [4.0])
    np.testing.assert_allclose(actual, expected)
    self.assertLastGraphAllFused()

def test_min_max_reduction(self):
    """Check fused full reductions ``torch.min(x) + torch.max(x)``."""
    def min_plus_max(x):
        return torch.min(x) + torch.max(x)

    traced = torch.jit.trace(min_plus_max, torch.zeros(1024))

    values = 8.0 * torch.rand(1024)
    actual = warmup_and_run_forward(traced, values)
    expected = np.amin(values.numpy()) + np.amax(values.numpy())
    np.testing.assert_allclose(actual, expected)
    self.assertLastGraphAllFused()
def test_min_max_reduction2(self):
    """Check fused full reductions written in method form: ``x.min() + x.max()``."""
    def min_plus_max(x):
        return x.min() + x.max()

    traced = torch.jit.trace(min_plus_max, torch.zeros(1024))

    values = 8.0 * torch.rand(1024)
    actual = warmup_and_run_forward(traced, values)
    expected = np.amin(values.numpy()) + np.amax(values.numpy())
    np.testing.assert_allclose(actual, expected)
    self.assertLastGraphAllFused()

def test_min_max_reduction_dim1(self):
    """Check fused min/max reductions along dimension 1 of a 16x16 input."""
    def row_min_plus_max(x):
        # ``torch.min``/``torch.max`` with a dim return (values, indices);
        # only the values are summed.
        return torch.min(x, 1)[0] + torch.max(x, 1)[0]

    traced = torch.jit.trace(row_min_plus_max, torch.zeros(16, 16))

    values = 8.0 * torch.rand(16, 16)
    actual = warmup_and_run_forward(traced, values)
    expected = np.amin(values.numpy(), axis=1) + np.amax(values.numpy(), axis=1)
    np.testing.assert_allclose(actual, expected)
    self.assertLastGraphAllFused()

def test_min_max_reduction_dim1_2(self):
    """Check a fused dim-1 min reduction whose operand is itself computed (``x * x``)."""
    def row_min_of_square(x):
        return torch.min(x * x, 1)

    traced = torch.jit.trace(row_min_of_square, torch.zeros(16, 16))

    values = 8.0 * torch.rand(16, 16)
    # Element 0 of the traced result holds the reduced values.
    actual = warmup_and_run_forward(traced, values)[0]
    expected = np.amin((values * values).numpy(), axis=1)
    np.testing.assert_allclose(actual, expected)
    self.assertLastGraphAllFused()
def test_clamp(self):
    """Check fused ``clamp(x + 3.0, 0.0, 6.0)`` against ``np.clip`` on every device."""
    def shifted_clamp(x):
        return torch.clamp(x + 3.0, 0.0, 6.0)

    for device in self.devices:
        traced = torch.jit.trace(shifted_clamp, torch.zeros(1024, device=device))

        # Inputs are spread over [-10, 10) so both clamp bounds are exercised.
        values = 20.0 * torch.rand(1024, device=device) - 10.0
        values_np = values.cpu().numpy()
        actual = warmup_and_run_forward(traced, values).cpu()
        expected = np.clip(values_np + 3.0, 0.0, 6.0)
        np.testing.assert_allclose(actual, expected)
        self.assertLastGraphAllFused()

def test_relu(self):
    """Check fused ``clamp(relu(x), 0, 0.5)`` against NumPy on every device."""
    def clamped_relu(x):
        return torch.clamp(F.relu(x), 0, 0.5)

    for device in self.devices:
        traced = torch.jit.trace(clamped_relu, torch.zeros(1024, device=device))

        values = 20.0 * torch.rand(1024, device=device) - 10.0
        values_np = values.cpu().numpy()
        actual = warmup_and_run_forward(traced, values).cpu()
        expected = np.clip(np.maximum(0, values_np), 0, 0.5)
        np.testing.assert_allclose(actual, expected)
        self.assertLastGraphAllFused()
def test_reps(self):
    """Run the same traced add repeatedly and check the result every time.

    No fusion assertion is made here; only the numerical result is checked.
    """
    def add_fn(x, y):
        return torch.add(x, y)

    example_inputs = (torch.rand(1024), torch.rand(1024))
    traced = torch.jit.trace(add_fn, example_inputs)

    repetitions = 32
    for _ in range(repetitions):
        ones = torch.ones(1024)
        zeros = torch.zeros(1024)
        result = warmup_and_run_forward(traced, ones, zeros)
        np.testing.assert_allclose(np.ones(1024), result.numpy())

def test_add_const_rhs(self):
    """Check fused addition of a Python float constant on the right-hand side."""
    def add_three(x):
        return x + 3.0

    traced = torch.jit.trace(add_three, torch.rand(4))

    values = torch.rand(4)
    result = warmup_and_run_forward(traced, values)
    self.assertLastGraphAllFused()
    np.testing.assert_allclose(values.numpy() + 3.0, result.numpy())

def test_int_output(self):
    """Check a fused three-way product of int32 tensors against NumPy."""
    def product(x, y, z):
        return x * y * z

    # Three int32 tensors derived from uniform samples scaled into [1, 4).
    operands = [(torch.rand(4) * 3 + 1).to(torch.int32) for _ in range(3)]
    x, y, z = operands
    xn, yn, zn = (t.numpy() for t in operands)

    traced = torch.jit.trace(product, (x, y, z))
    result = warmup_and_run_forward(traced, x, y, z)
    self.assertLastGraphAllFused()
    np.testing.assert_allclose(xn * yn * zn, result.numpy())
def test_binary_ops(self):
    """Check fused binary ops against a reference for every device/dtype pair.

    Each nested function feeds ``torch.add(x, y)`` (or a variant of it) into
    one binary op so the traced graph contains two fusible operations.
    """
    def test_atan2(x, y):
        return torch.atan2(torch.add(x, y), y)

    def test_gt(x, y):
        return torch.gt(torch.add(x, y), y)

    def test_ge(x, y):
        return torch.ge(torch.add(x, y), y)

    def test_lt(x, y):
        return torch.lt(torch.add(x, y), y)

    def test_le(x, y):
        return torch.le(torch.add(x, y), y)

    def test_lerp(x, y):
        return torch.lerp(torch.add(x, 1), x, 2.0)

    def test_mul(x, y):
        return torch.mul(torch.add(x, y), y)

    def test_ne(x, y):
        return torch.ne(torch.add(x, y), y)

    def test_div(x, y):
        return torch.div(torch.add(x, y), 2)

    def test_eq(x, y):
        return torch.eq(torch.add(x, y), y)

    def test_fmod(x, y):
        return torch.fmod(torch.add(x, y), 2)

    def test_sub(x, y):
        return torch.sub(torch.add(x, y), x)

    def test_remainder(x, y):
        return torch.remainder(torch.add(x, y), 3.0)

    def test_pow(x, y):
        return torch.pow(torch.add(x, y), 2.0)

    def test_type_as(x, y):
        return x.type_as(torch.add(x, y))

    # Comparison ops: their results are not cast back to bfloat16 below.
    cmp_fns = {test_gt, test_ge, test_lt, test_le, test_ne, test_eq}
    non_cmp_fns = {
        test_atan2,
        test_lerp,
        test_mul,
        test_div,
        test_fmod,
        test_sub,
        test_remainder,
        test_pow,
        test_type_as,
    }

    cases = itertools.product(cmp_fns | non_cmp_fns, self.devices, self.dtypes)
    for fn, device, dtype in cases:
        is_bf16 = dtype is torch.bfloat16
        if is_bf16 and fn is test_lerp:
            continue

        lhs = torch.rand(1024, dtype=dtype, device=device)
        rhs = torch.rand(1024, dtype=dtype, device=device)
        trace_lhs = 20 * torch.rand(1024, dtype=dtype, device=device)
        trace_rhs = 20 * torch.rand(1024, dtype=dtype, device=device)

        traced = torch.jit.trace(fn, (trace_lhs, trace_rhs))
        actual = warmup_and_run_forward(traced, lhs, rhs)
        self.assertLastGraphAllFused()

        if is_bf16:
            # Compared to ATen, NNC can save additional BF16/FP32 conversions.
            # Taking d = a + b - c as an example, ATen works at operator level:
            #     tmp = to_bf16(to_fp32(a) + to_fp32(b))
            #     d = to_bf16(to_fp32(tmp) + to_fp32(c))
            # NNC can fuse the expression and remove the redundant
            # conversions, so the final statement is:
            #     d = to_bf16(to_fp32(a) + to_fp32(b) + to_fp32(c))
            # Hence we simulate the NNC computation by feeding fp32 tensors
            # and converting the result tensor back to bf16. The simulation
            # avoids the numeric deviation and simplifies the result
            # comparison.
            expected = warmup_and_run_forward(traced, lhs.float(), rhs.float())
            if fn not in cmp_fns:
                expected = expected.bfloat16()
        else:
            expected = fn(lhs, rhs)

        # bfloat16 gets a looser absolute tolerance.
        atol = 2e-2 if is_bf16 else 2e-3
        self.assertEqual(actual.cpu(), expected.cpu(), atol=atol, rtol=1e-5)
def test_unary_ops(self):
    """Check fused unary ops against a reference for every device/dtype pair.

    Each nested function applies one unary op to ``torch.add(x, y)``. For
    bfloat16 the reference is the traced function run on float32 inputs and
    cast back to bfloat16, except for the functions in ``gpu_only_fns``,
    whose reference is always the untraced function.
    """
    def test_round(x, y):
        return torch.round(torch.add(x, y))

    def test_sin(x, y):
        return torch.sin(torch.add(x, y))

    def test_asin(x, y):
        return torch.asin(torch.add(x, y))

    def test_sinh(x, y):
        return torch.sinh(torch.add(x, y))

    def test_cos(x, y):
        return torch.cos(torch.add(x, y))

    def test_acos(x, y):
        return torch.acos(torch.add(x, y))

    def test_cosh(x, y):
        return torch.cosh(torch.add(x, y))

    def test_tan(x, y):
        return torch.tan(torch.add(x, y))

    def test_atan(x, y):
        return torch.atan(torch.add(x, y))

    def test_tanh(x, y):
        return torch.tanh(torch.add(x, y))

    def test_sqrt(x, y):
        return torch.sqrt(torch.add(x, y))

    def test_rsqrt(x, y):
        return torch.rsqrt(torch.add(x, y))

    def test_floor(x, y):
        return torch.floor(torch.add(x, y))

    def test_ceil(x, y):
        return torch.ceil(torch.add(x, y))

    def test_trunc(x, y):
        return torch.trunc(torch.add(x, y))

    def test_abs(x, y):
        return torch.abs(torch.add(x, y))

    def test_log(x, y):
        return torch.log(torch.add(x, y))

    def test_log2(x, y):
        return torch.log2(torch.add(x, y))

    def test_log10(x, y):
        return torch.log10(torch.add(x, y))

    def test_log1p(x, y):
        return torch.log1p(torch.add(x, y))

    def test_erf(x, y):
        return torch.erf(torch.add(x, y))

    def test_exp(x, y):
        return torch.exp(torch.add(x, y))

    def test_expm1(x, y):
        return torch.expm1(torch.add(x, y))

    def test_erfc(x, y):
        return torch.erfc(torch.add(x, y))

    def test_frac(x, y):
        return torch.frac(torch.add(x, y))

    def test_lgamma(x, y):
        return torch.lgamma(torch.add(x, y))

    def test_sigmoid(x, y):
        return torch.sigmoid(torch.add(x, y))

    def test_reciprocal(x, y):
        return torch.reciprocal(torch.add(x, y))

    def test_neg(x, y):
        return torch.neg(torch.add(x, y))

    def test_relu(x, y):
        return torch.relu(torch.add(x, y))

    def test_hardtanh(x, y):
        return F.hardtanh(torch.add(x, y), -1.0, 1.0)

    def test_threshold(x, y):
        return F.threshold(torch.add(x, y), 0.5, 10)

    # Despite the name, these are run on every device in ``self.devices``;
    # membership only selects the untraced reference path for bfloat16.
    gpu_only_fns = {
        test_erf,
        test_erfc,
    }
    fns = {
        test_round,
        test_sin,
        test_asin,
        test_sinh,
        test_cos,
        test_acos,
        test_cosh,
        test_tan,
        test_atan,
        test_sqrt,
        test_floor,
        test_ceil,
        test_trunc,
        test_abs,
        test_log,
        test_log2,
        test_log10,
        test_log1p,
        test_rsqrt,
        test_exp,
        test_expm1,
        test_frac,
        test_lgamma,
        test_reciprocal,
        test_neg,
        test_threshold,
        test_relu,
        test_tanh,
        test_hardtanh,
        test_sigmoid,
    }
    fn_dev_dtype = itertools.product(gpu_only_fns.union(fns), self.devices, self.dtypes)

    torch.manual_seed(0)
    for torch_fn, dev, data_type in fn_dev_dtype:
        if torch_fn is test_lgamma and dev == "cuda":
            # lgamma_cuda does not support BF16.
            # NOTE(review): this skips lgamma for every dtype on CUDA, not
            # only bfloat16 -- confirm whether float32 should be covered.
            continue
        rand_a = torch.rand(1024, dtype=data_type, device=dev)
        rand_b = torch.rand(1024, dtype=data_type, device=dev)

        # Separate example inputs are used for tracing only.
        ins = 20 * torch.rand(1024, dtype=data_type, device=dev)
        traced = torch.jit.trace(torch_fn, (ins, ins))
        x = warmup_and_run_forward(traced, rand_a, rand_b)
        self.assertLastGraphAllFused()

        is_bf16 = data_type is torch.bfloat16
        _atol = 5e-3 if is_bf16 else 2e-3
        _rtol = 1e-5
        if is_bf16 and torch_fn not in gpu_only_fns:
            # Reference: run the traced function in float32 and round the
            # result to bfloat16.
            y = warmup_and_run_forward(traced, rand_a.float(), rand_b.float())
            y = y.bfloat16()
        else:
            y = torch_fn(rand_a, rand_b)

        self.assertEqual(x.cpu(), y.cpu(), atol=_atol, rtol=_rtol)
        # TODO: re-enable NaN-input coverage (compare the traced result
        # against ``torch_fn`` called with an all-NaN first argument). It was
        # disabled because all of those checks failed.
def test_round_2(self):
    """Check that fused ``torch.round`` matches eager mode, including .5 ties.

    The input deliberately contains 2.5 and 3.5 so that the tie-breaking
    behaviour of the fused kernel is compared against eager ``torch.round``.
    """
    # Named ``round_fn`` to avoid shadowing the builtin ``round``.
    def round_fn(x):
        return torch.round(x)

    for data_type in [torch.float32, torch.double]:
        a = torch.tensor([0.2, 1.6, 2.5, 3.5]).to(data_type)
        traced = torch.jit.trace(round_fn, a)
        x = warmup_and_run_forward(traced, a)
        self.assertLastGraphAllFused()
        # The reference must be computed from the *input*. Rounding the
        # fused output again would pass for any integer-valued result,
        # because rounding is idempotent.
        y = round_fn(a)
        self.assertEqual(x, y)
def test_rand_like(self):
    """Check that fused ``torch.rand_like`` output looks uniform on [0, 1).

    The first three raw moments of the generated samples are compared with
    the moments of the uniform distribution (1/2, 1/3, 1/4) at 2% relative
    tolerance.
    """
    N = 1 << 16

    def run_rand_like(x, y):
        return torch.rand_like(torch.add(x, y))

    for device in self.devices:
        x = torch.rand(N, device=device)
        # check_trace=False: the output is random, so a re-run during trace
        # checking would not reproduce the same values.
        traced = torch.jit.trace(run_rand_like, (x, x), check_trace=False)

        for data_type in self.dtypes:
            _x = x.to(dtype=data_type)
            x_v = warmup_and_run_forward(traced, _x, _x)
            self.assertLastGraphAllFused()

            # Inspect the *generated* tensor, not the input ``x``. Upcast
            # first because NumPy has no bfloat16 dtype.
            samples = x_v.detach().to(torch.float64).cpu().numpy()
            for power in (1, 2, 3):
                np.testing.assert_allclose(
                    np.mean(samples ** power), 1. / (power + 1), rtol=2e-2
                )
/ 4, rtol=2e-2) 964*da0073e9SAndroid Build Coastguard Worker 965*da0073e9SAndroid Build Coastguard Worker def test_nans(self): 966*da0073e9SAndroid Build Coastguard Worker def test_max(x, y): 967*da0073e9SAndroid Build Coastguard Worker return torch.max(2 * x, 2 * y) 968*da0073e9SAndroid Build Coastguard Worker 969*da0073e9SAndroid Build Coastguard Worker def test_min(x, y): 970*da0073e9SAndroid Build Coastguard Worker return torch.min(2 * x, 2 * y) 971*da0073e9SAndroid Build Coastguard Worker 972*da0073e9SAndroid Build Coastguard Worker tmax = torch.jit.trace(test_max, (torch.rand(1), torch.rand(1))) 973*da0073e9SAndroid Build Coastguard Worker tmin = torch.jit.trace(test_min, (torch.rand(1), torch.rand(1))) 974*da0073e9SAndroid Build Coastguard Worker 975*da0073e9SAndroid Build Coastguard Worker for data_type in self.dtypes: 976*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([np.nan]).to(dtype=data_type) 977*da0073e9SAndroid Build Coastguard Worker y = torch.tensor([1.0]).to(dtype=data_type) 978*da0073e9SAndroid Build Coastguard Worker 979*da0073e9SAndroid Build Coastguard Worker assert np.isnan(warmup_and_run_forward(tmin, x, y).float().item()) 980*da0073e9SAndroid Build Coastguard Worker assert np.isnan(warmup_and_run_forward(tmin, y, x).float().item()) 981*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 982*da0073e9SAndroid Build Coastguard Worker assert np.isnan(warmup_and_run_forward(tmax, x, y).float().item()) 983*da0073e9SAndroid Build Coastguard Worker assert np.isnan(warmup_and_run_forward(tmax, y, x).float().item()) 984*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 985*da0073e9SAndroid Build Coastguard Worker 986*da0073e9SAndroid Build Coastguard Worker def test_double_intrinsics(self): 987*da0073e9SAndroid Build Coastguard Worker def do_pow(x): 988*da0073e9SAndroid Build Coastguard Worker return torch.pow(x, 7) 989*da0073e9SAndroid Build Coastguard Worker 990*da0073e9SAndroid Build Coastguard 
Worker for device in self.devices: 991*da0073e9SAndroid Build Coastguard Worker x = torch.rand(10, dtype=torch.double, device=device) 992*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(do_pow, (x)) 993*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, x) 994*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 995*da0073e9SAndroid Build Coastguard Worker 996*da0073e9SAndroid Build Coastguard Worker def test_remainder(self): 997*da0073e9SAndroid Build Coastguard Worker def run_remainder(x, y): 998*da0073e9SAndroid Build Coastguard Worker c = torch.remainder(torch.add(x, y), x) 999*da0073e9SAndroid Build Coastguard Worker return c 1000*da0073e9SAndroid Build Coastguard Worker 1001*da0073e9SAndroid Build Coastguard Worker for data_type in self.dtypes: 1002*da0073e9SAndroid Build Coastguard Worker a = torch.rand(1024, dtype=data_type) 1003*da0073e9SAndroid Build Coastguard Worker b = torch.rand(1024, dtype=data_type) 1004*da0073e9SAndroid Build Coastguard Worker zeros = torch.zeros(1024, dtype=data_type) 1005*da0073e9SAndroid Build Coastguard Worker cc = np.array(1024, dtype=float) 1006*da0073e9SAndroid Build Coastguard Worker cc.fill(np.nan) 1007*da0073e9SAndroid Build Coastguard Worker nans = torch.from_numpy(cc).to(dtype=data_type) 1008*da0073e9SAndroid Build Coastguard Worker 1009*da0073e9SAndroid Build Coastguard Worker # random floats 1010*da0073e9SAndroid Build Coastguard Worker zeros1 = torch.zeros(1024, dtype=data_type) 1011*da0073e9SAndroid Build Coastguard Worker zeros2 = torch.zeros(1024, dtype=data_type) 1012*da0073e9SAndroid Build Coastguard Worker 1013*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(run_remainder, (zeros1, zeros2)) 1014*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, a, b) 1015*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 1016*da0073e9SAndroid Build Coastguard Worker y = run_remainder(a, b) 
1017*da0073e9SAndroid Build Coastguard Worker if data_type is torch.bfloat16: 1018*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, y, atol=4e-3, rtol=2e-3) 1019*da0073e9SAndroid Build Coastguard Worker else: 1020*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, y) 1021*da0073e9SAndroid Build Coastguard Worker 1022*da0073e9SAndroid Build Coastguard Worker # div by 0 1023*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(run_remainder, (zeros1, zeros2)) 1024*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, zeros, a) 1025*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 1026*da0073e9SAndroid Build Coastguard Worker y = run_remainder(zeros, a) 1027*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, y) 1028*da0073e9SAndroid Build Coastguard Worker 1029*da0073e9SAndroid Build Coastguard Worker # numerators and denominatos are nan 1030*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(run_remainder, (zeros1, zeros2)) 1031*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, nans, a) 1032*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 1033*da0073e9SAndroid Build Coastguard Worker y = run_remainder(nans, a) 1034*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, y) 1035*da0073e9SAndroid Build Coastguard Worker 1036*da0073e9SAndroid Build Coastguard Worker def test_multioutput(self): 1037*da0073e9SAndroid Build Coastguard Worker def easy(x): 1038*da0073e9SAndroid Build Coastguard Worker b = x + 1 1039*da0073e9SAndroid Build Coastguard Worker c = b + b 1040*da0073e9SAndroid Build Coastguard Worker return (b, c) 1041*da0073e9SAndroid Build Coastguard Worker 1042*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(easy, (torch.zeros(1024))) 1043*da0073e9SAndroid Build Coastguard Worker 1044*da0073e9SAndroid Build Coastguard Worker a = torch.zeros(1024) 1045*da0073e9SAndroid Build Coastguard 
Worker b, c = warmup_and_run_forward(traced, a) 1046*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 1047*da0073e9SAndroid Build Coastguard Worker bp = a.numpy() + 1 1048*da0073e9SAndroid Build Coastguard Worker cp = bp + bp 1049*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(b.numpy(), bp) 1050*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(c.numpy(), cp) 1051*da0073e9SAndroid Build Coastguard Worker 1052*da0073e9SAndroid Build Coastguard Worker def test_chunk(self): 1053*da0073e9SAndroid Build Coastguard Worker def easy(x): 1054*da0073e9SAndroid Build Coastguard Worker y = x + 1 1055*da0073e9SAndroid Build Coastguard Worker aaa, bbb = torch.chunk(y, 2) 1056*da0073e9SAndroid Build Coastguard Worker return aaa + bbb 1057*da0073e9SAndroid Build Coastguard Worker 1058*da0073e9SAndroid Build Coastguard Worker for data_type in self.dtypes: 1059*da0073e9SAndroid Build Coastguard Worker trace_input = torch.zeros(1024, 1024, dtype=data_type) 1060*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(easy, (trace_input)) 1061*da0073e9SAndroid Build Coastguard Worker 1062*da0073e9SAndroid Build Coastguard Worker a = torch.zeros(32, 32, dtype=data_type) 1063*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, a) 1064*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 1065*da0073e9SAndroid Build Coastguard Worker npr = a.float().numpy() 1066*da0073e9SAndroid Build Coastguard Worker npr2 = npr + 1 1067*da0073e9SAndroid Build Coastguard Worker npr_a, npr_b = np.array_split(npr2, 2) 1068*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(npr_a + npr_b, x.float().numpy()) 1069*da0073e9SAndroid Build Coastguard Worker 1070*da0073e9SAndroid Build Coastguard Worker def test_cat(self): 1071*da0073e9SAndroid Build Coastguard Worker for device in self.devices: 1072*da0073e9SAndroid Build Coastguard Worker _dim = 1 1073*da0073e9SAndroid Build 
Coastguard Worker 1074*da0073e9SAndroid Build Coastguard Worker def foo(*args): 1075*da0073e9SAndroid Build Coastguard Worker args_2 = [v + i for i, v in enumerate(args)] 1076*da0073e9SAndroid Build Coastguard Worker v = torch.cat(args_2, dim=_dim) 1077*da0073e9SAndroid Build Coastguard Worker return v * v 1078*da0073e9SAndroid Build Coastguard Worker 1079*da0073e9SAndroid Build Coastguard Worker for data_type in self.dtypes: 1080*da0073e9SAndroid Build Coastguard Worker M = 16 1081*da0073e9SAndroid Build Coastguard Worker Ns = [128, 16, 1] 1082*da0073e9SAndroid Build Coastguard Worker values = [torch.zeros(M, N, dtype=data_type, device=device) for N in Ns] 1083*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(foo, values) 1084*da0073e9SAndroid Build Coastguard Worker 1085*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, *values) 1086*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 1087*da0073e9SAndroid Build Coastguard Worker ref = foo(*values) 1088*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(ref.cpu().float().numpy(), x.cpu().float().numpy()) 1089*da0073e9SAndroid Build Coastguard Worker 1090*da0073e9SAndroid Build Coastguard Worker # Test channels-last 1091*da0073e9SAndroid Build Coastguard Worker for _cur_dim in range(4): 1092*da0073e9SAndroid Build Coastguard Worker _dim = _cur_dim 1093*da0073e9SAndroid Build Coastguard Worker values = [torch.randn((2, 3, 4, 5), device=device).to(memory_format=torch.channels_last) for _ in range(10)] 1094*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(foo, values) 1095*da0073e9SAndroid Build Coastguard Worker 1096*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, *values) 1097*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 1098*da0073e9SAndroid Build Coastguard Worker ref = foo(*values) 1099*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, x) 
# This test checks that we correctly handle fusion group with just aten::cat in it.
# Note that the test only makes sense with min_fusion_group=1, otherwise no
# fusion groups would be formed at all.
# TODO: Fix and re-enable the test.
@unittest.skip("cat is broken with fusion group inlining disabled")
def test_cat_only(self):
    def foo(*args):
        args_2 = [v + i for i, v in enumerate(args)]
        v = torch.cat(args_2, dim=1)
        return v

    for device in self.devices:
        M = 16
        Ns = [128, 16, 1]
        values = [torch.zeros(M, N, device=device) for N in Ns]
        traced = torch.jit.trace(foo, values)

        x = warmup_and_run_forward(traced, *values)
        self.assertLastGraphAllFused()
        ref = foo(*values)
        np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())


# torch.cat with a negative dim (dim=-1) followed by an elementwise op.
def test_cat_negative_dim(self):
    def foo(*args):
        v = torch.cat(args, dim=-1)
        return v * v

    for device in self.devices:
        M = 16
        Ns = [128, 16, 1]
        values = [torch.randn(M, N, device=device) for N in Ns]
        traced = torch.jit.trace(foo, values)

        x = warmup_and_run_forward(traced, *values)
        self.assertLastGraphAllFused()
        ref = foo(*values)
        np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())


# torch.cat over inputs of different dtypes (half, float32, double).
def test_cat_promote_inputs(self):
    def foo(*args):
        v = torch.cat(args, dim=1)
        return v * v

    for device in self.devices:
        M = 16
        Ns = [128, 16, 1]
        dtypes = [torch.half, torch.float32, torch.double]
        values = [torch.randn(M, N, device=device, dtype=dt) for N, dt in zip(Ns, dtypes)]
        traced = torch.jit.trace(foo, values)

        x = warmup_and_run_forward(traced, *values)
        self.assertLastGraphAllFused()
        ref = foo(*values)
        np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())


# torch.cat where one input, and then every input, is an empty tensor.
def test_cat_empty_tensors(self):
    def foo(*args):
        v = torch.cat(args, dim=1)
        return v * v

    for device in self.devices:
        M = 16
        Ns = [128, 16, 1]
        empty = torch.tensor([], device=device, dtype=torch.double)
        values = [empty] + [torch.randn(M, N, device=device) for N in Ns]
        traced = torch.jit.trace(foo, values)

        x = warmup_and_run_forward(traced, *values)
        self.assertLastGraphAllFused()
        ref = foo(*values)
        np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())

        # now test with only empty tensors
        values = [empty] * 3
        traced = torch.jit.trace(foo, values)
        x = warmup_and_run_forward(traced, *values)
        self.assertLastGraphAllFused()
        ref = foo(*values)
        np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())


# Two chained torch.cat calls (the second over a single-element list), with
# an empty tensor among the inputs.
def test_cat_with_constant_dim(self):
    def foo(*args):
        v1 = torch.cat(args, dim=1)
        v2 = torch.cat([v1], dim=1)
        return v2 * v2

    for device in self.devices:
        empty = torch.tensor([], device=device, dtype=torch.float32)
        inputs = [empty] + [torch.randn(1, 64, device=device), torch.randn(1, 64, device=device)]
        traced = torch.jit.trace(foo, inputs)

        x = warmup_and_run_forward(traced, *inputs)
        self.assertLastGraphAllFused()
        ref = foo(*inputs)
        np.testing.assert_allclose(ref.cpu().numpy(), x.cpu().numpy())


# Scripted functions taking float and int scalars as the `alpha` of torch.add.
def test_scalar(self):
    @torch.jit.script
    def test_float(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: float, b: float) -> torch.Tensor:
        return torch.add(torch.add(x, y, alpha=a), z, alpha=b)

    @torch.jit.script
    def test_int(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: int, b: int) -> torch.Tensor:
        return torch.add(torch.add(x, y, alpha=a), z, alpha=b)

    for test in (test_float, test_int):
        for data_type in self.dtypes:
            x, y, z = (torch.rand(4, dtype=data_type) for _ in range(3))
            a, b = 1, 2
            # Call twice and check the result of the second call.
            test(x, y, z, a, b)
            r = test(x, y, z, a, b)
            self.assertEqual(r, x + y * a + z * b)


# A scripted function containing a Python loop; it is called twice and the
# second result is checked.
def test_loop(self):
    @torch.jit.script
    def test(x: torch.Tensor, y: torch.Tensor, z: int) -> torch.Tensor:
        b = y
        for i in range(0, z):
            a = x + y
            b = b + y
        return b

    x, y, z = (torch.zeros(32, 32), torch.ones(32, 32), 4)
    test(x, y, z)
    r = test(x, y, z)
    # The loop adds y to b (which starts as y) z times, so b == y * (z + 1).
    self.assertEqual(r, y * (z + 1))


# Strided slicing (x[0:512:2]) of both inputs inside a traced function.
def test_slice(self):
    def easy(x, y):
        a = x[0:512:2]
        b = y[0:512:2]
        return a + b

    traced = torch.jit.trace(easy, (torch.ones(1024, 1024), torch.zeros(1024, 1024)))

    a = torch.ones(1024, 1024)
    x = traced(a, a)
    sliced = a[0:512:2]
    expected = sliced + sliced
    np.testing.assert_allclose(expected.numpy(), x.numpy())


# torch.unsqueeze of both inputs inside a traced function, checked against
# np.expand_dims.
def test_unsqueeze(self, N=256):
    def easy(x, y):
        a = torch.unsqueeze(x, 0)
        b = torch.unsqueeze(y, 0)
        return a + b

    traced = torch.jit.trace(easy, (torch.ones(N, N), torch.zeros(N, N)))

    a = torch.rand(N, N)
    x = traced(a, a)
    expanded = np.expand_dims(a.numpy(), 0)
    np.testing.assert_allclose(expanded + expanded, x.numpy())


# Shared body of the softmax tests: softmax / log_softmax along positive and
# negative dims with TensorExpr reductions enabled, compared with eager mode.
def _test_softmax(self, device):
    def test_softmax(x, y):
        a = F.softmax(x, dim=0, dtype=torch.float32)
        b = F.softmax(y, dim=0, dtype=torch.float32)
        c = F.softmax(x, dim=1, dtype=torch.float32)
        d = F.softmax(y, dim=1, dtype=torch.float32)
        return a + b + c + d

    def test_softmax_neg_index(x, y):
        a = F.softmax(x, dim=-2, dtype=torch.float32)
        b = F.softmax(y, dim=-2, dtype=torch.float32)
        c = F.softmax(x, dim=-1, dtype=torch.float32)
        d = F.softmax(y, dim=-1, dtype=torch.float32)
        return a + b + c + d

    def test_log_softmax(x, y):
        a = F.log_softmax(x, dim=0, dtype=torch.float32)
        b = F.log_softmax(y, dim=0, dtype=torch.float32)
        c = F.log_softmax(x, dim=1, dtype=torch.float32)
        d = F.log_softmax(y, dim=1, dtype=torch.float32)
        return a + b + c + d

    for test in (test_softmax, test_log_softmax, test_softmax_neg_index):
        for data_type in self.dtypes:
            old = torch._C._jit_set_texpr_reductions_enabled(True)
            try:
                traced_input = torch.randn(2, 3, dtype=data_type, device=device)
                traced = torch.jit.trace(test, (traced_input, traced_input))
                inp = torch.randn(2, 3, dtype=data_type, device=device)
                res = traced(inp, inp)
                # Use eager mode as reference.
                ref = test(inp, inp)
                np.testing.assert_allclose(
                    ref.cpu().numpy(), res.cpu().numpy(), rtol=1e-06, atol=1e-06
                )
            finally:
                # Restore the global flag even when an assertion fails, so a
                # failure here cannot leak the setting into later tests.
                torch._C._jit_set_texpr_reductions_enabled(old)


def test_softmax_cpu(self):
    self._test_softmax('cpu')


@unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
@unittest.skip("global allocs are not supported yet.")
def test_softmax_cuda(self):
    self._test_softmax('cuda')


# An erf-based bias + GELU expression on half tensors must be fully fused.
# The device list is empty without CUDA, in which case nothing runs.
def test_half_gelu(self):
    devices = ["cuda"] if torch.cuda.is_available() else []

    @torch.jit.script
    def bias_gelu(bias, y):
        x = bias + y
        return x * 0.5 * (1.0 + torch.erf(x / 1.41421))

    for device in devices:
        a = torch.rand(1024, dtype=torch.half, device=device)
        b = torch.rand(1024, dtype=torch.half, device=device)
        traced = torch.jit.trace(bias_gelu, (a, b))
        warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()


# batch_norm followed by relu on half tensors must be fully fused.
# The device list is empty without CUDA, in which case nothing runs.
def test_half_bn_relu(self):
    devices = ["cuda"] if torch.cuda.is_available() else []

    def foo(a, b, c):
        y = torch.nn.functional.batch_norm(a, b, c)
        z = y.relu()
        return z

    for device in devices:
        a = torch.rand(16, 16, dtype=torch.half, device=device)
        b = torch.rand(16, dtype=torch.half, device=device)
        c = torch.rand(16, dtype=torch.half, device=device)
        traced = torch.jit.trace(foo, (a, b, c))
        warmup_and_run_forward(traced, a, b, c)
        self.assertLastGraphAllFused()


# Multiplications combined with torch.pow on double inputs must be fully
# fused.
def test_exp_pow(self):
    @torch.jit.script
    def do_exp(x, y, z):
        return ((x * y) * 2) * torch.pow(z, 2)

    for device in self.devices:
        x = torch.rand(10, dtype=torch.double, device=device)
        y = torch.rand(10, dtype=torch.double, device=device)
        z = torch.rand(10, dtype=torch.double, device=device)
        traced = torch.jit.trace(do_exp, (x, y, z))
        warmup_and_run_forward(traced, x, y, z)
        self.assertLastGraphAllFused()


# sin(pow(x, 0)) for several 1-D sizes and every dtype in self.dtypes,
# compared with eager mode.
def test_sin_pow(self):
    def test(x):
        return torch.sin(torch.pow(x, 0))

    for data_type, shape in itertools.product(self.dtypes, [[3], [5], [10]]):
        x = torch.rand(shape, dtype=data_type)
        scripted = torch.jit.script(test)
        out = warmup_and_run_forward(scripted, x)
        self.assertLastGraphAllFused()
        self.assertEqual(out, test(x))


# A transposed input added to two contiguous tensors.  The first and the
# second call of the same scripted function must agree.
def test_transpose(self):
    @torch.jit.script
    def test(x, y, z):
        return x.transpose(0, 1) + y + z

    x = torch.rand(4, 5, 2, 3)
    y = torch.rand(5, 4, 2, 3)
    z = torch.rand(5, 4, 2, 3)
    first = test(x, y, z)
    second = test(x, y, z)
    np.testing.assert_allclose(first.numpy(), second.numpy())
# A strided (every-other-row) view added to two contiguous tensors.  The first
# and the second call of the same scripted function must agree.
def test_sliced_stride(self):
    @torch.jit.script
    def test(x, y, z):
        return x + y + z

    x = torch.rand(16, 4, 2, 3)[::2]
    y = torch.rand(8, 4, 2, 3)
    z = torch.rand(8, 4, 2, 3)
    first = test(x, y, z)
    second = test(x, y, z)
    np.testing.assert_allclose(first.numpy(), second.numpy())


# Runs one scripted function on CUDA inputs of changing shapes: a different
# leading dimension, broadcasting inputs, and finally mismatched shapes.
@unittest.skip("dynamic shapes are not quite there yet")
@unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
def test_dynamic_shape(self):
    with num_profiled_runs(2):
        @torch.jit.script
        def test(x, y, z):
            return x * y * z

        x, y, z = (torch.rand(4, 8).cuda() for _ in range(3))
        ref = test(x, y, z)
        _ = test(*[torch.rand(6, 8).cuda() for _ in range(3)])
        res = test(x, y, z)
        np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy())

        # A wild broadcast appears.
        x = torch.rand(4, 8).cuda()
        y = torch.rand(1, 8).cuda()
        z = torch.rand(4, 1).cuda()
        res = test(x, y, z)
        xn, yn, zn = (t.cpu().numpy() for t in (x, y, z))
        np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn)

        # Mismatched shapes shouldn't reach codegen.
        x = torch.rand(4, 8).cuda()
        y = torch.rand(4, 8).cuda()
        z = torch.rand(5, 8).cuda()
        # (4, 8) and (5, 8) cannot be broadcast, so the call has to raise.
        # assertRaisesRegex also fails the test when nothing is raised,
        # which a try/except with an assert in the handler does not.
        with self.assertRaisesRegex(RuntimeError, r"The size of tensor a \(4\) must match"):
            test(x, y, z)

        # Changing a static dimension fails guards.
1406*da0073e9SAndroid Build Coastguard Worker # x, y, z = [torch.rand(4, 7).cuda() for _ in range(3)] 1407*da0073e9SAndroid Build Coastguard Worker # xn, yn, zn = [t.cpu().numpy() for t in (x, y, z)] 1408*da0073e9SAndroid Build Coastguard Worker # res = test(x, y, z) 1409*da0073e9SAndroid Build Coastguard Worker # print(test.graph_for(x, y, z)) 1410*da0073e9SAndroid Build Coastguard Worker # np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn) 1411*da0073e9SAndroid Build Coastguard Worker 1412*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA") 1413*da0073e9SAndroid Build Coastguard Worker def test_guard_fails(self): 1414*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 1415*da0073e9SAndroid Build Coastguard Worker def test(x, y, z): 1416*da0073e9SAndroid Build Coastguard Worker return x * y * z 1417*da0073e9SAndroid Build Coastguard Worker r1 = test(*[torch.rand(4).cuda() for _ in range(3)]) 1418*da0073e9SAndroid Build Coastguard Worker r2 = test(*[torch.rand(4).cuda() for _ in range(3)]) 1419*da0073e9SAndroid Build Coastguard Worker r3 = test(*[torch.rand(4).cuda() for _ in range(3)]) 1420*da0073e9SAndroid Build Coastguard Worker r4 = test(*[torch.rand(7).cuda() for _ in range(3)]) 1421*da0073e9SAndroid Build Coastguard Worker 1422*da0073e9SAndroid Build Coastguard Worker def test_bitwise_ops(self): 1423*da0073e9SAndroid Build Coastguard Worker def run_and(x, y): 1424*da0073e9SAndroid Build Coastguard Worker return x & (x & y) 1425*da0073e9SAndroid Build Coastguard Worker 1426*da0073e9SAndroid Build Coastguard Worker def run_or(x, y): 1427*da0073e9SAndroid Build Coastguard Worker return x & (x | y) 1428*da0073e9SAndroid Build Coastguard Worker 1429*da0073e9SAndroid Build Coastguard Worker def run_xor(x, y): 1430*da0073e9SAndroid Build Coastguard Worker return x ^ (x ^ y) 1431*da0073e9SAndroid Build Coastguard Worker 1432*da0073e9SAndroid Build Coastguard Worker def run_lshift(x, y): 
1433*da0073e9SAndroid Build Coastguard Worker return x & (x << y) 1434*da0073e9SAndroid Build Coastguard Worker 1435*da0073e9SAndroid Build Coastguard Worker def run_rshift(x, y): 1436*da0073e9SAndroid Build Coastguard Worker return x & (x >> y) 1437*da0073e9SAndroid Build Coastguard Worker 1438*da0073e9SAndroid Build Coastguard Worker fns = {run_and, run_or, run_xor, run_lshift, run_rshift} 1439*da0073e9SAndroid Build Coastguard Worker 1440*da0073e9SAndroid Build Coastguard Worker for device in self.devices: 1441*da0073e9SAndroid Build Coastguard Worker for fn in fns: 1442*da0073e9SAndroid Build Coastguard Worker a = torch.ones(128, dtype=torch.int32, device=device) 1443*da0073e9SAndroid Build Coastguard Worker b = torch.zeros(128, dtype=torch.int32, device=device) 1444*da0073e9SAndroid Build Coastguard Worker inp = torch.ones(128, dtype=torch.int32, device=device) 1445*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(fn, (inp, inp)) 1446*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, a, b) 1447*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 1448*da0073e9SAndroid Build Coastguard Worker y = fn(a, b) 1449*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy()) 1450*da0073e9SAndroid Build Coastguard Worker 1451*da0073e9SAndroid Build Coastguard Worker def test_where(self): 1452*da0073e9SAndroid Build Coastguard Worker def run_where(x, y): 1453*da0073e9SAndroid Build Coastguard Worker return torch.where(torch.gt(x, y), x, y) 1454*da0073e9SAndroid Build Coastguard Worker 1455*da0073e9SAndroid Build Coastguard Worker for data_type in self.dtypes: 1456*da0073e9SAndroid Build Coastguard Worker a = torch.rand(1024, dtype=data_type) 1457*da0073e9SAndroid Build Coastguard Worker b = torch.rand(1024, dtype=data_type) 1458*da0073e9SAndroid Build Coastguard Worker zeros = torch.zeros(1024, dtype=data_type) 1459*da0073e9SAndroid Build Coastguard Worker traced = 
torch.jit.trace(run_where, (zeros, zeros)) 1460*da0073e9SAndroid Build Coastguard Worker x = warmup_and_run_forward(traced, a, b) 1461*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 1462*da0073e9SAndroid Build Coastguard Worker y = run_where(a, b) 1463*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(x.float().numpy(), y.float().numpy()) 1464*da0073e9SAndroid Build Coastguard Worker 1465*da0073e9SAndroid Build Coastguard Worker def test_multi_rand(self): 1466*da0073e9SAndroid Build Coastguard Worker for device in self.devices: 1467*da0073e9SAndroid Build Coastguard Worker def test(x): 1468*da0073e9SAndroid Build Coastguard Worker y = torch.rand_like(x) 1469*da0073e9SAndroid Build Coastguard Worker return (x + y) - (y - x) 1470*da0073e9SAndroid Build Coastguard Worker 1471*da0073e9SAndroid Build Coastguard Worker _atol = 2e-3 1472*da0073e9SAndroid Build Coastguard Worker _rtol = 1e-5 1473*da0073e9SAndroid Build Coastguard Worker for data_type in self.dtypes: 1474*da0073e9SAndroid Build Coastguard Worker if data_type is torch.bfloat16: 1475*da0073e9SAndroid Build Coastguard Worker _atol = 2e-2 1476*da0073e9SAndroid Build Coastguard Worker a = torch.rand(4, dtype=data_type, device=device) 1477*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.script(test) 1478*da0073e9SAndroid Build Coastguard Worker out = warmup_and_run_forward(scripted, a) 1479*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 1480*da0073e9SAndroid Build Coastguard Worker assert torch.allclose(out, 2 * a, atol=_atol, rtol=_rtol) 1481*da0073e9SAndroid Build Coastguard Worker 1482*da0073e9SAndroid Build Coastguard Worker def test_mask(self): 1483*da0073e9SAndroid Build Coastguard Worker def test(x): 1484*da0073e9SAndroid Build Coastguard Worker return x.unsqueeze(1) == 0 1485*da0073e9SAndroid Build Coastguard Worker 1486*da0073e9SAndroid Build Coastguard Worker for d in self.devices: 1487*da0073e9SAndroid Build Coastguard 
Worker for data_type in self.dtypes: 1488*da0073e9SAndroid Build Coastguard Worker x = torch.rand(4, dtype=data_type, device=d) > 0.5 1489*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.script(test) 1490*da0073e9SAndroid Build Coastguard Worker out = warmup_and_run_forward(scripted, x) 1491*da0073e9SAndroid Build Coastguard Worker self.assertLastGraphAllFused() 1492*da0073e9SAndroid Build Coastguard Worker assert torch.equal(out, test(x)) 1493*da0073e9SAndroid Build Coastguard Worker 1494*da0073e9SAndroid Build Coastguard Worker def test_simple_add(self): 1495*da0073e9SAndroid Build Coastguard Worker val = torch._C._jit_get_te_generate_block_code() 1496*da0073e9SAndroid Build Coastguard Worker torch._C._jit_set_te_generate_block_code(True) 1497*da0073e9SAndroid Build Coastguard Worker fall_bk = torch._C._jit_texpr_fallback_allowed() 1498*da0073e9SAndroid Build Coastguard Worker torch._C._jit_texpr_set_fallback_allowed(True) 1499*da0073e9SAndroid Build Coastguard Worker 1500*da0073e9SAndroid Build Coastguard Worker def simple(a, b): 1501*da0073e9SAndroid Build Coastguard Worker return torch.add(a, b) 1502*da0073e9SAndroid Build Coastguard Worker 1503*da0073e9SAndroid Build Coastguard Worker a = torch.ones(256, 256) 1504*da0073e9SAndroid Build Coastguard Worker b = torch.ones(256, 256) 1505*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(simple, 1506*da0073e9SAndroid Build Coastguard Worker (torch.ones(256, 256), torch.ones(256, 256))) 1507*da0073e9SAndroid Build Coastguard Worker f = traced(a, b) 1508*da0073e9SAndroid Build Coastguard Worker f_test = np.full((256, 256), 2, dtype=float) 1509*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose(f.numpy(), f_test) 1510*da0073e9SAndroid Build Coastguard Worker torch._C._jit_set_te_generate_block_code(val) 1511*da0073e9SAndroid Build Coastguard Worker torch._C._jit_texpr_set_fallback_allowed(fall_bk) 1512*da0073e9SAndroid Build Coastguard Worker 1513*da0073e9SAndroid Build 
# NOTE(review): the line-number/blame gutter that was fused into this part of
# the file has been stripped and the indentation rebuilt.  Every ``def`` below
# that takes ``self`` is a test method of ``TestTensorExprFuser``; re-indent
# them under that class header when restoring the file layout.

def test_strided_output_preserved(self):
    """Scripted ``a + b - a`` must match eager in both values and strides."""
    def foo(a, b):
        return a + b - a

    def check_matches_eager(x):
        foo_script = torch.jit.script(foo)
        # Two extra calls before the one whose output is compared.
        foo_script(x, x)
        foo_script(x, x)
        out_s = foo_script(x, x)
        out_eager = foo(x, x)
        self.assertEqual(out_s, out_eager)
        self.assertEqual(out_s.stride(), out_eager.stride())
        self.assertLastGraphAllFused()

    # smaller, easier to debug example
    x = torch.arange(6)
    x = torch.as_strided(x, (2, 3), (1, 2))
    # Fill through the strided view so element [i, j] holds i * 3 + j.
    total = 0
    for i in range(2):
        for j in range(3):
            x[i, j] = total
            total += 1
    check_matches_eager(x)

    # more dims
    N, C, H, W = 2, 3, 4, 5
    x = torch.rand(N, C, H, W).to(memory_format=torch.channels_last)
    check_matches_eager(x)


def test_alias_analysis_module(self):
    """Scripted and eager modules must agree when two attributes alias.

    ``forward`` updates ``self.b`` in place between two reads of ``self.a``.
    After ``a`` is rebound to the same tensor as ``b`` that update is
    visible through ``self.a``, and the scripted module must reproduce it.
    """
    class AliasModule(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # Fixed seed: every instance starts with identical tensors.
            torch.manual_seed(1337)
            self.a = torch.randn(128, 128)
            self.b = torch.randn(128, 128)
            self.c = torch.randn(128, 128)

        def forward(self, x, y, z):
            z = z + self.a
            self.b.add_(y)
            w = z + self.a
            z = w + x
            return z

    x = torch.randn(128, 128)

    def getModule(script):
        am = AliasModule()
        if script:
            return torch.jit.script(am)
        return am

    am = getModule(False)
    am_s = getModule(True)
    ref = am(x, x, x)
    test = am_s(x, x, x)
    torch.testing.assert_close(ref, test)

    # Now do the aliasing
    am.a = am.b
    ref = am(x, x, x)

    am_s.a = am_s.b
    test = am_s(x, x, x)

    torch.testing.assert_close(ref, test)


def test_alias_analysis_inputs(self):
    """Scripted and eager modules must agree when all inputs alias.

    The same tensor is passed as ``x``, ``y`` and ``z``, so the in-place
    ``x.add_(y)`` in ``forward`` also changes ``y`` and ``z``.
    """
    class AliasModule(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # Fixed seed: every instance starts with identical tensors.
            torch.manual_seed(1337)
            self.a = torch.randn(128, 128)
            self.b = torch.randn(128, 128)
            self.c = torch.randn(128, 128)

        def forward(self, x, y, z):
            x.add_(y)
            w = z + self.a
            z = w + x
            return z

    def getModule(script):
        am = AliasModule()
        if script:
            return torch.jit.script(am)
        return am

    am = getModule(False)
    am_s = getModule(True)

    # The input is mutated by forward, so it is regenerated from the same
    # seed for the second run.
    torch.manual_seed(1337)
    x = torch.randn(128, 128)
    ref = am(x, x, x)

    torch.manual_seed(1337)
    x = torch.randn(128, 128)
    test = am_s(x, x, x)

    torch.testing.assert_close(ref, test)


def test_alias_analysis_input_and_module(self):
    """Scripted and eager modules must agree when an input aliases ``self.b``.

    ``b`` is rebound to the input tensor before the call, so the in-place
    ``x.add_(y)`` in ``forward`` is also visible through ``self.b``.
    """
    class AliasModule(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # Fixed seed: every instance starts with identical tensors.
            torch.manual_seed(1337)
            self.a = torch.randn(128, 128)
            self.b = torch.randn(128, 128)
            self.c = torch.randn(128, 128)

        def forward(self, x, y, z):
            x.add_(y)
            w = z + self.b
            z = w + x
            return z

    def getModule(script):
        am = AliasModule()
        if script:
            return torch.jit.script(am)
        return am

    am = getModule(False)
    am_s = getModule(True)

    # The input is mutated by forward, so it is regenerated from the same
    # seed for the second run.
    torch.manual_seed(1337)
    x = torch.randn(128, 128)
    am.b = x
    ref = am(x, x, x)

    torch.manual_seed(1337)
    x = torch.randn(128, 128)
    am_s.b = x
    test = am_s(x, x, x)

    torch.testing.assert_close(ref, test)


def test_multiple_outputs(self):
    """Traced fn with three outputs, a strided input and an int64 input."""
    for device in self.devices:
        # A bug reported internally similar to the one reported in #48533
        def foo(a, b, c):
            t_next = c + 1
            t5 = t_next * b
            t6 = torch.unsqueeze(t_next, 1)
            t7 = a * t6
            return (t7, t5, t_next)

        for data_type in self.dtypes:
            a = torch.rand(20, 20, dtype=data_type, device=device)
            # 20 elements taken with stride 29 from a flat buffer.
            b = torch.rand(20 * 29, dtype=data_type, device=device).as_strided([20], [29])
            c = torch.ones(20, dtype=torch.int64, device=device)
            traced = torch.jit.trace(foo, (a, b, c))
            ref = foo(a, b, c)
            # The first traced call is discarded; the second is compared.
            exp = traced(a, b, c)
            exp = traced(a, b, c)
            self.assertEqual(ref, exp)


def test_propagated_mem_layout(self):
    """Compare traced and eager results across memory-format combinations.

    Every combination of function, shape, per-input memory format and
    permutation is run under both the "STATIC" and the "DYNAMIC" fusion
    strategy.
    """
    def foo(a, b, c):
        t_next = c + 1
        t5 = t_next * b
        t7 = a * t5
        return t7

    def foo_multi_outputs(a, b, c):
        t_next = c + 1
        t5 = b * t_next
        t7 = a * t5
        return (t7, t5, t_next)

    def foo_multi_outputs_i_nhwc_o_nchw(a, b, c):
        t_next = c + 1
        t5 = b * t_next
        t7 = a * t5
        t8 = t7.to(memory_format=torch.contiguous_format)
        return (t8, t7, t5, t_next)

    def run_foo_case(foo, a, b, c):
        traced_contiguous = torch.jit.trace(foo, (a, b, c))
        ref = foo(a, b, c)
        # The first traced call is discarded; the second is compared.
        exp = traced_contiguous(a, b, c)
        exp = traced_contiguous(a, b, c)
        self.assertEqual(ref, exp)

    mem_layouts = list(itertools.product([torch.contiguous_format, torch.channels_last], repeat=3))
    shapes = [(2, 3, 4, 5), (2, 1, 1, 5), (1, 1, 1, 1)]
    permutes = [(0, 3, 2, 1), (0, 3, 1, 2)]
    funcs = [foo, foo_multi_outputs, foo_multi_outputs_i_nhwc_o_nchw]
    # Materialized as a list: a bare itertools.product iterator is
    # single-pass, so the first strategy would exhaust it and the second
    # strategy would run no cases at all.
    configs = list(itertools.product(funcs, shapes, mem_layouts, permutes))
    for strategy in ["STATIC", "DYNAMIC"]:
        old_strategy = torch.jit.set_fusion_strategy([(strategy, 10)])
        try:
            for _func, _shape, _mem_layouts, _permute in configs:
                a = torch.rand(_shape, dtype=torch.float32).to(memory_format=_mem_layouts[0])
                b = torch.rand(_shape, dtype=torch.float32).to(memory_format=_mem_layouts[1])
                c = torch.rand(_shape, dtype=torch.float32).to(memory_format=_mem_layouts[2])
                run_foo_case(_func, a, b, c)

                a = a.permute(dims=_permute)
                b = b.permute(dims=_permute)
                c = c.permute(dims=_permute)
                run_foo_case(_func, a, b, c)
        finally:
            # Restore the previous strategy even when a case fails.
            torch.jit.set_fusion_strategy(old_strategy)


if __name__ == '__main__':
    run_tests()