1# Owner(s): ["module: cuda graphs"]
2
3import functools
4import unittest
5
6import torch
7import torch._dynamo
8import torch._dynamo.config
9import torch._dynamo.test_case
10import torch._dynamo.testing
11from torch._dynamo.testing import same
12from torch.testing._internal.common_utils import TEST_CUDA_GRAPH


def composed(*decs):
    """Compose several decorators into one; the first decorator listed ends
    up outermost, matching the usual stacked-decorator order."""

    def deco(f):
        for dec in reversed(decs):
            f = dec(f)
        return f

    return deco


def assert_aot_autograd_counter(ok=True):
    """Run the wrapped test with cleared Dynamo counters, then assert that
    aot_autograd compilation succeeded (ok=True) or failed (ok=False) at
    least once, with no outcomes of the opposite kind."""

    def deco(f):
        @functools.wraps(f)
        def wrap(self, *args, **kwargs):
            torch._dynamo.utils.counters.clear()
            r = f(self, *args, **kwargs)
            c_ok = torch._dynamo.utils.counters["aot_autograd"]["ok"]
            c_not_ok = torch._dynamo.utils.counters["aot_autograd"]["not_ok"]
            if ok:
                self.assertGreater(c_ok, 0)
                self.assertEqual(c_not_ok, 0)
            else:
                self.assertEqual(c_ok, 0)
                self.assertGreater(c_not_ok, 0)
            return r

        return wrap

    return deco
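

# Note: the "aot_autograd" counters read above live in
# torch._dynamo.utils.counters; Dynamo's AOT Autograd backend wrapper bumps
# "ok" when a graph compiles successfully and "not_ok" when compilation
# fails and execution falls back to eager.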


def patch_all(ok=True):
    """Standard decorator stack for the tests below: verify compiled results
    against eager, enable automatic dynamic shapes, and check the
    aot_autograd counters."""
    return composed(
        torch._dynamo.config.patch(
            verify_correctness=True, automatic_dynamic_shapes=True
        ),
        assert_aot_autograd_counter(ok),
    )
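

# For reference: `composed` applies its arguments in reverse, so `@patch_all()`
# behaves like stacking `@torch._dynamo.config.patch(...)` directly above
# `@assert_aot_autograd_counter(ok)` on a test method.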


# Each compiled function is called several times so that both the initial
# CUDA graph capture and subsequent replays are exercised.
N_ITERS = 5


@unittest.skipIf(not torch.cuda.is_available(), "these tests require cuda")
class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
    @patch_all()
    def test_basic(self):
        def model(x, y):
            return (x + y) * y

        @torch._dynamo.optimize("cudagraphs")
        def fn(x, y):
            for i in range(N_ITERS):
                loss = model(x, y).sum()
                loss.backward()

        x = torch.randn(3, device="cuda", requires_grad=True)
        y = torch.randn(3, device="cuda")
        fn(x, y)

    @patch_all()
    def test_dtoh(self):
        def model(x, y):
            a = x + y
            # .cpu() forces a device-to-host copy in the middle of the model
            b = a.cpu() * 3
            return b

        @torch._dynamo.optimize("cudagraphs")
        def fn(x, y):
            for i in range(N_ITERS):
                loss = model(x, y).sum()
                loss.backward()

        x = torch.randn(3, device="cuda", requires_grad=True)
        y = torch.randn(3, device="cuda")
        fn(x, y)

    @patch_all()
    def test_htod(self):
        def model(x, y):
            a = x + y
            return a * 3

        @torch._dynamo.optimize("cudagraphs")
        def fn(x, y):
            for i in range(N_ITERS):
                loss = model(x, y).sum()
                loss.backward()

        x = torch.randn(3, device="cuda", requires_grad=True)
        # the CPU scalar operand forces a host-to-device transfer
        y = torch.randn((), device="cpu")
        fn(x, y)

    def test_mutate_input(self):
        def model(x, y):
            # in-place mutation of a graph input must stay visible to the caller
            y.add_(3)
            return x * y

        @torch._dynamo.optimize("cudagraphs")
        def fn(x, y):
            for i in range(N_ITERS):
                with self.subTest(i):
                    y_orig = y.clone()
                    loss = model(x, y).sum()
                    self.assertTrue(same(y, y_orig + 3))
                    loss.backward()

        x = torch.randn(3, device="cuda", requires_grad=True)
        y = torch.randn(3, device="cuda")
        fn(x, y)

    @patch_all()
    def test_mutate_constant(self):
        def model(x, y):
            # the constant is created and mutated fresh on every call, so the
            # mutation must not leak across graph replays
            c = torch.tensor(1)
            c.add_(2)
            return x * y * 0 + c

        @torch._dynamo.optimize("cudagraphs")
        def fn(x, y):
            for i in range(N_ITERS):
                with self.subTest(i):
                    loss = model(x, y).sum()
                    self.assertTrue(same(loss, torch.tensor(3.0, device="cuda")))
                    loss.backward()

        x = torch.randn(1, device="cuda", requires_grad=True)
        y = torch.randn(1, device="cuda")
        fn(x, y)

    @patch_all()
    def test_factory(self):
        def model(y):
            # tensor created by a factory op inside the compiled region
            x = torch.zeros(3, device="cuda:0")
            x.add_(3)
            return x * y

        @torch._dynamo.optimize("cudagraphs")
        def fn(y):
            for i in range(N_ITERS):
                with self.subTest(i):
                    loss = model(y).sum()
                    loss.backward()

        y = torch.randn(3, device="cuda:0", requires_grad=True)
        fn(y)

    @patch_all()
    def test_mutated_metadata(self):
        # resize_ changes the shape metadata of the clone, not just its data;
        # a more tortured example is at
        # https://github.com/pytorch/pytorch/issues/81385
        def model(x):
            x = x.clone()
            x.resize_(20)
            x.fill_(2)
            return x

        @torch._dynamo.optimize("cudagraphs")
        def fn(x):
            for i in range(N_ITERS):
                with self.subTest(i):
                    rx = model(x)
                    self.assertTrue(same(rx, torch.full((20,), 2.0, device="cuda:0")))

        x = torch.empty(0, device="cuda:0")
        fn(x)

    @patch_all()
    def test_dead_fill(self):
        def model(x):
            x = x.clone()
            # y is a zero-element view, so filling it is dead work
            y = x[0:0]
            x.fill_(2)
            y.fill_(3)
            return x, y

        @torch._dynamo.optimize("cudagraphs")
        def fn(x):
            for i in range(N_ITERS):
                with self.subTest(i):
                    rx, ry = model(x)
                    self.assertTrue(same(rx, torch.full((20,), 2.0, device="cuda:0")))
                    self.assertTrue(same(ry, torch.empty(0, device="cuda:0")))

        x = torch.empty(20, device="cuda:0")
        fn(x)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    if not TEST_CUDA_GRAPH:
        # CUDA graphs are unavailable in this environment; exit without failing
        import sys

        sys.exit(0)

    run_tests()