1# Owner(s): ["module: cuda graphs"] 2 3import functools 4import unittest 5 6import torch 7import torch._dynamo 8import torch._dynamo.config 9import torch._dynamo.test_case 10import torch._dynamo.testing 11from torch._dynamo.testing import same 12from torch.testing._internal.common_utils import TEST_CUDA_GRAPH 13 14 15def composed(*decs): 16 def deco(f): 17 for dec in reversed(decs): 18 f = dec(f) 19 return f 20 21 return deco 22 23 24def assert_aot_autograd_counter(ok=True): 25 def deco(f): 26 @functools.wraps(f) 27 def wrap(self, *args, **kwargs): 28 torch._dynamo.utils.counters.clear() 29 r = f(self, *args, **kwargs) 30 c_ok = torch._dynamo.utils.counters["aot_autograd"]["ok"] 31 c_not_ok = torch._dynamo.utils.counters["aot_autograd"]["not_ok"] 32 if ok: 33 self.assertGreater(c_ok, 0) 34 self.assertEqual(c_not_ok, 0) 35 else: 36 self.assertEqual(c_ok, 0) 37 self.assertGreater(c_not_ok, 0) 38 return r 39 40 return wrap 41 42 return deco 43 44 45def patch_all(ok=True): 46 return composed( 47 torch._dynamo.config.patch( 48 verify_correctness=True, automatic_dynamic_shapes=True 49 ), 50 assert_aot_autograd_counter(ok), 51 ) 52 53 54N_ITERS = 5 55 56 57@unittest.skipIf(not torch.cuda.is_available(), "these tests require cuda") 58class TestAotCudagraphs(torch._dynamo.test_case.TestCase): 59 @patch_all() 60 def test_basic(self): 61 def model(x, y): 62 return (x + y) * y 63 64 @torch._dynamo.optimize("cudagraphs") 65 def fn(x, y): 66 for i in range(N_ITERS): 67 loss = model(x, y).sum() 68 loss.backward() 69 70 x = torch.randn(3, device="cuda", requires_grad=True) 71 y = torch.randn(3, device="cuda") 72 fn(x, y) 73 74 @patch_all() 75 def test_dtoh(self): 76 def model(x, y): 77 a = x + y 78 b = a.cpu() * 3 79 return b 80 81 @torch._dynamo.optimize("cudagraphs") 82 def fn(x, y): 83 for i in range(N_ITERS): 84 loss = model(x, y).sum() 85 loss.backward() 86 87 x = torch.randn(3, device="cuda", requires_grad=True) 88 y = torch.randn(3, device="cuda") 89 fn(x, y) 90 91 @patch_all() 92 def test_htod(self): 93 def model(x, y): 94 a = x + y 95 return a * 3 96 97 @torch._dynamo.optimize("cudagraphs") 98 def fn(x, y): 99 for i in range(N_ITERS): 100 loss = model(x, y).sum() 101 loss.backward() 102 103 x = torch.randn(3, device="cuda", requires_grad=True) 104 y = torch.randn((), device="cpu") 105 fn(x, y) 106 107 def test_mutate_input(self): 108 def model(x, y): 109 y.add_(3) 110 return x * y 111 112 @torch._dynamo.optimize("cudagraphs") 113 def fn(x, y): 114 for i in range(N_ITERS): 115 with self.subTest(i): 116 y_orig = y.clone() 117 loss = model(x, y).sum() 118 self.assertTrue(same(y, y_orig + 3)) 119 loss.backward() 120 121 x = torch.randn(3, device="cuda", requires_grad=True) 122 y = torch.randn(3, device="cuda") 123 fn(x, y) 124 125 @patch_all() 126 def test_mutate_constant(self): 127 def model(x, y): 128 c = torch.tensor(1) 129 c.add_(2) 130 return x * y * 0 + c 131 132 @torch._dynamo.optimize("cudagraphs") 133 def fn(x, y): 134 for i in range(N_ITERS): 135 with self.subTest(i): 136 loss = model(x, y).sum() 137 self.assertTrue(same(loss, torch.tensor(3.0, device="cuda"))) 138 loss.backward() 139 140 x = torch.randn(1, device="cuda", requires_grad=True) 141 y = torch.randn(1, device="cuda") 142 fn(x, y) 143 144 @patch_all() 145 def test_factory(self): 146 def model(y): 147 x = torch.zeros(3, device="cuda:0") 148 x.add_(3) 149 return x * y 150 151 @torch._dynamo.optimize("cudagraphs") 152 def fn(y): 153 for i in range(N_ITERS): 154 with self.subTest(i): 155 loss = model(y).sum() 156 loss.backward() 157 158 y = torch.randn(3, device="cuda:0", requires_grad=True) 159 fn(y) 160 161 @patch_all() 162 def test_mutated_metadata(self): 163 # more tortured example at 164 # https://github.com/pytorch/pytorch/issues/81385 165 def model(x): 166 x = x.clone() 167 x.resize_(20) 168 x.fill_(2) 169 return x 170 171 @torch._dynamo.optimize("cudagraphs") 172 def fn(x): 173 for i in range(N_ITERS): 174 with self.subTest(i): 175 rx = model(x) 176 self.assertTrue(same(rx, torch.full((20,), 2.0, device="cuda:0"))) 177 178 x = torch.empty(0, device="cuda:0") 179 fn(x) 180 181 @patch_all() 182 def test_dead_fill(self): 183 def model(x): 184 x = x.clone() 185 y = x[0:0] 186 x.fill_(2) 187 y.fill_(3) 188 return x, y 189 190 @torch._dynamo.optimize("cudagraphs") 191 def fn(x): 192 for i in range(N_ITERS): 193 with self.subTest(i): 194 rx, ry = model(x) 195 self.assertTrue(same(rx, torch.full((20,), 2.0, device="cuda:0"))) 196 self.assertTrue(same(ry, torch.empty(0, device="cuda:0"))) 197 198 x = torch.empty(20, device="cuda:0") 199 fn(x) 200 201 202if __name__ == "__main__": 203 from torch._dynamo.test_case import run_tests 204 205 if not TEST_CUDA_GRAPH: 206 if __name__ == "__main__": 207 import sys 208 209 sys.exit(0) 210 raise unittest.SkipTest("cuda graph test is skipped") 211 212 run_tests() 213