1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"] 2*da0073e9SAndroid Build Coastguard Workerimport torch 3*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case 4*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing 5*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo import eval_frame 6*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.hooks import Hooks 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Workerc = 10 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Workerdef fn1(a, b): 13*da0073e9SAndroid Build Coastguard Worker return a + b - c 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Workerdef fn2(a, b): 17*da0073e9SAndroid Build Coastguard Worker x = 0 18*da0073e9SAndroid Build Coastguard Worker y = 1 19*da0073e9SAndroid Build Coastguard Worker 20*da0073e9SAndroid Build Coastguard Worker def modify(): 21*da0073e9SAndroid Build Coastguard Worker nonlocal x 22*da0073e9SAndroid Build Coastguard Worker x += a + b + c 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Worker for _ in range(2): 25*da0073e9SAndroid Build Coastguard Worker modify() 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Worker return x + y 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard Worker 30*da0073e9SAndroid Build Coastguard Workerdef fn3(): 31*da0073e9SAndroid Build Coastguard Worker yield 1 32*da0073e9SAndroid Build Coastguard Worker yield 2 33*da0073e9SAndroid Build Coastguard Worker 34*da0073e9SAndroid Build Coastguard Worker 35*da0073e9SAndroid Build Coastguard Workerwith_debug_nops = eval_frame._optimize_catch_errors( 36*da0073e9SAndroid Build Coastguard Worker torch._dynamo.testing.debug_insert_nops, Hooks(None, None) 37*da0073e9SAndroid Build Coastguard Worker) 38*da0073e9SAndroid Build Coastguard Worker 39*da0073e9SAndroid Build Coastguard Worker 40*da0073e9SAndroid Build Coastguard Workerclass NopTests(torch._dynamo.test_case.TestCase): 41*da0073e9SAndroid Build Coastguard Worker @with_debug_nops 42*da0073e9SAndroid Build Coastguard Worker def test1(self): 43*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn1(1, 2), -7) 44*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn1(1, 2), -7) 45*da0073e9SAndroid Build Coastguard Worker 46*da0073e9SAndroid Build Coastguard Worker @with_debug_nops 47*da0073e9SAndroid Build Coastguard Worker def test2(self): 48*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn2(1, 2), 27) 49*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn2(1, 2), 27) 50*da0073e9SAndroid Build Coastguard Worker 51*da0073e9SAndroid Build Coastguard Worker @with_debug_nops 52*da0073e9SAndroid Build Coastguard Worker def test3(self): 53*da0073e9SAndroid Build Coastguard Worker t = fn3() 54*da0073e9SAndroid Build Coastguard Worker self.assertEqual(next(t), 1) 55*da0073e9SAndroid Build Coastguard Worker self.assertEqual(next(t), 2) 56*da0073e9SAndroid Build Coastguard Worker self.assertRaises(StopIteration, lambda: next(t)) 57*da0073e9SAndroid Build Coastguard Worker 58*da0073e9SAndroid Build Coastguard Worker def test_extended_args(self): 59*da0073e9SAndroid Build Coastguard Worker too_many_adds = "+".join(["a", "b"] * 256) 60*da0073e9SAndroid Build Coastguard Worker source = ( 61*da0073e9SAndroid Build Coastguard Worker f"lambda a, b: ({too_many_adds}+a if a.sum() > 0 else {too_many_adds} - b)" 62*da0073e9SAndroid Build Coastguard Worker ) 63*da0073e9SAndroid Build Coastguard Worker fn = eval(source) 64*da0073e9SAndroid Build Coastguard Worker a = torch.ones(1) 65*da0073e9SAndroid Build Coastguard Worker b = torch.ones(1) 66*da0073e9SAndroid Build Coastguard Worker fn = with_debug_nops(fn) 67*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(a, b).sum(), 513) 68*da0073e9SAndroid Build Coastguard Worker 69*da0073e9SAndroid Build Coastguard Worker 70*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 71*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.test_case import run_tests 72*da0073e9SAndroid Build Coastguard Worker 73*da0073e9SAndroid Build Coastguard Worker run_tests() 74