# Owner(s): ["module: cuda graphs"] import functools import unittest import torch import torch._dynamo import torch._dynamo.config import torch._dynamo.test_case import torch._dynamo.testing from torch._dynamo.testing import same from torch.testing._internal.common_utils import TEST_CUDA_GRAPH def composed(*decs): def deco(f): for dec in reversed(decs): f = dec(f) return f return deco def assert_aot_autograd_counter(ok=True): def deco(f): @functools.wraps(f) def wrap(self, *args, **kwargs): torch._dynamo.utils.counters.clear() r = f(self, *args, **kwargs) c_ok = torch._dynamo.utils.counters["aot_autograd"]["ok"] c_not_ok = torch._dynamo.utils.counters["aot_autograd"]["not_ok"] if ok: self.assertGreater(c_ok, 0) self.assertEqual(c_not_ok, 0) else: self.assertEqual(c_ok, 0) self.assertGreater(c_not_ok, 0) return r return wrap return deco def patch_all(ok=True): return composed( torch._dynamo.config.patch( verify_correctness=True, automatic_dynamic_shapes=True ), assert_aot_autograd_counter(ok), ) N_ITERS = 5 @unittest.skipIf(not torch.cuda.is_available(), "these tests require cuda") class TestAotCudagraphs(torch._dynamo.test_case.TestCase): @patch_all() def test_basic(self): def model(x, y): return (x + y) * y @torch._dynamo.optimize("cudagraphs") def fn(x, y): for i in range(N_ITERS): loss = model(x, y).sum() loss.backward() x = torch.randn(3, device="cuda", requires_grad=True) y = torch.randn(3, device="cuda") fn(x, y) @patch_all() def test_dtoh(self): def model(x, y): a = x + y b = a.cpu() * 3 return b @torch._dynamo.optimize("cudagraphs") def fn(x, y): for i in range(N_ITERS): loss = model(x, y).sum() loss.backward() x = torch.randn(3, device="cuda", requires_grad=True) y = torch.randn(3, device="cuda") fn(x, y) @patch_all() def test_htod(self): def model(x, y): a = x + y return a * 3 @torch._dynamo.optimize("cudagraphs") def fn(x, y): for i in range(N_ITERS): loss = model(x, y).sum() loss.backward() x = torch.randn(3, device="cuda", requires_grad=True) y = torch.randn((), device="cpu") fn(x, y) def test_mutate_input(self): def model(x, y): y.add_(3) return x * y @torch._dynamo.optimize("cudagraphs") def fn(x, y): for i in range(N_ITERS): with self.subTest(i): y_orig = y.clone() loss = model(x, y).sum() self.assertTrue(same(y, y_orig + 3)) loss.backward() x = torch.randn(3, device="cuda", requires_grad=True) y = torch.randn(3, device="cuda") fn(x, y) @patch_all() def test_mutate_constant(self): def model(x, y): c = torch.tensor(1) c.add_(2) return x * y * 0 + c @torch._dynamo.optimize("cudagraphs") def fn(x, y): for i in range(N_ITERS): with self.subTest(i): loss = model(x, y).sum() self.assertTrue(same(loss, torch.tensor(3.0, device="cuda"))) loss.backward() x = torch.randn(1, device="cuda", requires_grad=True) y = torch.randn(1, device="cuda") fn(x, y) @patch_all() def test_factory(self): def model(y): x = torch.zeros(3, device="cuda:0") x.add_(3) return x * y @torch._dynamo.optimize("cudagraphs") def fn(y): for i in range(N_ITERS): with self.subTest(i): loss = model(y).sum() loss.backward() y = torch.randn(3, device="cuda:0", requires_grad=True) fn(y) @patch_all() def test_mutated_metadata(self): # more tortured example at # https://github.com/pytorch/pytorch/issues/81385 def model(x): x = x.clone() x.resize_(20) x.fill_(2) return x @torch._dynamo.optimize("cudagraphs") def fn(x): for i in range(N_ITERS): with self.subTest(i): rx = model(x) self.assertTrue(same(rx, torch.full((20,), 2.0, device="cuda:0"))) x = torch.empty(0, device="cuda:0") fn(x) @patch_all() def 
test_dead_fill(self): def model(x): x = x.clone() y = x[0:0] x.fill_(2) y.fill_(3) return x, y @torch._dynamo.optimize("cudagraphs") def fn(x): for i in range(N_ITERS): with self.subTest(i): rx, ry = model(x) self.assertTrue(same(rx, torch.full((20,), 2.0, device="cuda:0"))) self.assertTrue(same(ry, torch.empty(0, device="cuda:0"))) x = torch.empty(20, device="cuda:0") fn(x) if __name__ == "__main__": from torch._dynamo.test_case import run_tests if not TEST_CUDA_GRAPH: if __name__ == "__main__": import sys sys.exit(0) raise unittest.SkipTest("cuda graph test is skipped") run_tests()