xref: /aosp_15_r20/external/pytorch/test/dynamo/test_nops.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"]
2*da0073e9SAndroid Build Coastguard Workerimport torch
3*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case
4*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing
5*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo import eval_frame
6*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.hooks import Hooks
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Worker
# Module-level global read by fn1 and by the closure inside fn2. The expected
# values asserted in NopTests (-7 for fn1(1, 2), 27 for fn2(1, 2)) assume
# c == 10.
c = 10
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
def fn1(a, b):
    """Return ``a + b - c``, where ``c`` is the module-level global."""
    total = a + b
    return total - c
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Worker
def fn2(a, b):
    """Add ``a + b + c`` to a counter twice through a closure, then add 1.

    The nested function rebinds the enclosing local ``x`` with ``nonlocal``,
    so this fixture exercises closure cell variables inside a loop.
    """
    y = 1
    x = 0

    def bump():
        nonlocal x
        x += a + b + c

    for _step in range(2):
        bump()

    return x + y
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Worker
def fn3():
    """Generator yielding 1 and then 2, after which it is exhausted."""
    for value in (1, 2):
        yield value
33*da0073e9SAndroid Build Coastguard Worker
34*da0073e9SAndroid Build Coastguard Worker
# Wrapper/decorator built with dynamo's private ``_optimize_catch_errors``.
# It uses the ``torch._dynamo.testing.debug_insert_nops`` callback and empty
# hooks (``Hooks(None, None)``). NopTests applies it to test methods and to an
# eval'd lambda.
# NOTE(review): debug_insert_nops presumably rewrites each intercepted frame's
# bytecode with extra NOP instructions to stress jump-target fixups. Confirm
# against torch._dynamo.testing.
with_debug_nops = eval_frame._optimize_catch_errors(
    torch._dynamo.testing.debug_insert_nops, Hooks(None, None)
)
38*da0073e9SAndroid Build Coastguard Worker
39*da0073e9SAndroid Build Coastguard Worker
class NopTests(torch._dynamo.test_case.TestCase):
    """Correctness checks for code executed under ``with_debug_nops``.

    The three numbered tests are decorated with ``with_debug_nops`` and call
    the module-level fixtures:

    - ``fn1``: plain arithmetic on a global.
    - ``fn2``: a ``nonlocal`` closure driven by a loop.
    - ``fn3``: a generator.

    ``test_extended_args`` instead wraps a very large eval'd lambda.
    """

    @with_debug_nops
    def test1(self):
        # 1 + 2 - c with c == 10. Repeated so a second invocation is checked.
        for _attempt in range(2):
            self.assertEqual(fn1(1, 2), -7)

    @with_debug_nops
    def test2(self):
        # The closure adds (1 + 2 + 10) twice, giving 26, then y == 1 is added.
        for _attempt in range(2):
            self.assertEqual(fn2(1, 2), 27)

    @with_debug_nops
    def test3(self):
        gen = fn3()
        self.assertEqual(next(gen), 1)
        self.assertEqual(next(gen), 2)
        # The generator is exhausted, so a third next() must raise.
        self.assertRaises(StopIteration, next, gen)

    def test_extended_args(self):
        # Builds "a+b+a+b+..." with 512 operands.
        # NOTE(review): presumably sized so the lambda's bytecode needs
        # EXTENDED_ARG (per the test name). Confirm with dis.
        adds = "+".join(["a", "b"] * 256)
        source = f"lambda a, b: ({adds}+a if a.sum() > 0 else {adds} - b)"
        fn = eval(source)  # trusted: the source string is built locally above
        a, b = torch.ones(1), torch.ones(1)
        wrapped = with_debug_nops(fn)
        # a.sum() > 0, so the first branch runs: 512 ones plus one more ``a``.
        self.assertEqual(wrapped(a, b).sum(), 513)
69*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Run this module's tests through dynamo's test runner when executed
    # directly as a script.
    torch._dynamo.test_case.run_tests()
74