1# Owner(s): ["module: inductor"] 2import contextlib 3import functools 4import gc 5import importlib 6import sys 7import unittest 8import warnings 9from unittest import mock 10 11import torch 12import torch._dynamo.config as dynamo_config 13import torch.nn as nn 14from torch._dynamo.utils import counters 15from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache 16from torch._inductor import config 17from torch._inductor.codecache import FxGraphCache 18from torch._inductor.compile_fx import compile_fx_inner 19from torch._inductor.cudagraph_trees import cudagraphify_impl as tree_cudagraphify_impl 20from torch._inductor.cudagraph_utils import FunctionID 21from torch._inductor.test_case import TestCase as InductorTestCase 22from torch.fx.experimental.proxy_tensor import make_fx 23from torch.testing import FileCheck 24from torch.testing._internal.common_cuda import TEST_MULTIGPU 25from torch.testing._internal.common_utils import ( 26 instantiate_parametrized_tests, 27 IS_CI, 28 IS_LINUX, 29 IS_WINDOWS, 30 parametrize, 31 skipIfRocm, 32 TEST_CUDA_GRAPH, 33 TEST_WITH_ASAN, 34) 35from torch.utils._python_dispatch import TorchDispatchMode 36 37 38if IS_WINDOWS and IS_CI: 39 sys.stderr.write( 40 "Windows CI does not have necessary dependencies for test_torchinductor yet\n" 41 ) 42 if __name__ == "__main__": 43 sys.exit(0) 44 raise unittest.SkipTest("requires sympy/functorch/filelock") 45 46importlib.import_module("functorch") 47importlib.import_module("filelock") 48 49from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA 50 51 52aten = torch.ops.aten 53requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") 54requires_multigpu = functools.partial( 55 unittest.skipIf, not TEST_MULTIGPU, "requires multiple cuda devices" 56) 57from io import StringIO 58 59 60def get_compile_fn(backend): 61 if backend == "cudagraphs": 62 return functools.partial(torch.compile, backend="cudagraphs") 63 else: 64 return functools.partial(torch.compile, mode="reduce-overhead") 65 66 67class capture_stderr(list): 68 """ 69 Replace sys.stderr with a temporary StringIO 70 """ 71 72 def __enter__(self): 73 self.sys_stderr = sys.stderr 74 self.stringio = StringIO() 75 sys.stderr = self.stringio 76 return self 77 78 def __exit__(self, *args): 79 self.append(str(self.stringio.getvalue())) 80 del self.stringio 81 sys.stderr = self.sys_stderr 82 83 84def cdata(t): 85 return t.untyped_storage()._cdata 86 87 88class TestCase(InductorTestCase): 89 @classmethod 90 def setUpClass(cls): 91 super().setUpClass() 92 cls._stack = contextlib.ExitStack() 93 cls._stack.enter_context( 94 config.patch( 95 { 96 "debug": True, 97 "cpp.min_chunk_size": 1, 98 "triton.autotune_pointwise": False, # too slow 99 "implicit_fallbacks": False, 100 } 101 ) 102 ) 103 104 @classmethod 105 def tearDownClass(cls): 106 cls._stack.close() 107 super().tearDownClass() 108 109 def setUp(self): 110 torch._dynamo.reset() 111 super().setUp() 112 113 def tearDown(self): 114 super().tearDown() 115 torch._dynamo.reset() 116 117 118if HAS_CUDA and not TEST_WITH_ASAN: 119 120 def get_all_cudagraph_segments(): 121 segments = torch.cuda.memory_snapshot() 122 return [segment for segment in segments if segment["segment_pool_id"] != (0, 0)] 123 124 def all_live_blocks(): 125 blocks_addrs = [] 126 for segment in get_all_cudagraph_segments(): 127 addr = segment["address"] 128 for block in segment["blocks"]: 129 if block["state"] == "active_allocated": 130 blocks_addrs.append(addr) 131 addr += block["size"] 132 133 return blocks_addrs 134 135 
def all_live_block_count(): 136 return len(all_live_blocks()) 137 138 class CudaGraphTreeTests(TestCase): 139 def setUp(self): 140 super().setUp() 141 self.graph_stack = contextlib.ExitStack() 142 self.graph_stack.enter_context( 143 config.patch( 144 { 145 "triton.cudagraphs": True, 146 "triton.cudagraph_trees": True, 147 "triton.fast_path_cudagraph_asserts": True, # too slow 148 "triton.slow_path_cudagraph_asserts": True, 149 } 150 ) 151 ) 152 self.graph_stack.enter_context( 153 dynamo_config.patch(automatic_dynamic_shapes=True) 154 ) 155 self.device_idx = torch.rand([0], device="cuda").device.index 156 warnings.filterwarnings("ignore") 157 158 def tearDown(self): 159 super().tearDown() 160 torch._dynamo.reset() 161 gc.collect() 162 torch.cuda.empty_cache() 163 self.graph_stack.close() 164 165 self.assertIsNone(self.get_manager()) 166 self.assertEqual(all_live_block_count(), 0) 167 self.assertEqual(len(get_all_cudagraph_segments()), 0) 168 warnings.resetwarnings() 169 170 def get_manager(self, device_index=None): 171 return torch._inductor.cudagraph_trees.get_container( 172 self.device_idx if not device_index else device_index 173 ).tree_manager 174 175 def get_roots(self): 176 return self.get_manager().get_roots() 177 178 def curr_node(self): 179 return self.get_manager().current_node 180 181 def get_root_children(self): 182 return [root.num_descendants() for root in self.get_roots()] 183 184 def cudagraphify_impl( 185 self, *args, is_inference=True, is_backward=False, **kwargs 186 ): 187 return tree_cudagraphify_impl( 188 *args, 189 **kwargs, 190 device_index=self.device_idx, 191 is_inference=is_inference, 192 is_backward=is_backward, 193 ) 194 195 @staticmethod 196 def run_twc(fn, *args, **kwargs): 197 fn(*args, **kwargs) 198 return fn(*args, **kwargs) 199 200 def num_checkpoints(self): 201 return self.get_manager().debug_checkpointing_counter 202 203 def test_run_simple(self): 204 def foo(x): 205 return x * x * x 206 207 foo_opt = torch.compile(foo) 208 ones = torch.ones([4, 4], device="cuda") 209 zeros = torch.zeros([5, 5], device="cuda") 210 self.run_twc(foo_opt, ones) 211 self.run_twc(foo_opt, zeros) 212 self.assertEqual(self.get_root_children(), [0, 0]) 213 214 def check_rng(self): 215 @torch.compile(mode="reduce-overhead") 216 def foo(): 217 return torch.rand([20]) 218 219 torch.manual_seed(0) 220 221 out = foo() 222 out2 = foo() 223 out3 = foo() 224 225 torch.manual_seed(0) 226 227 self.assertEqual(out, foo()) 228 self.assertEqual(out2, foo()) 229 self.assertEqual(out3, foo()) 230 231 @torch._inductor.config.patch("fallback_random", True) 232 def test_rng_trees(self): 233 self.check_rng() 234 235 @torch._inductor.config.patch("triton.cudagraph_trees", False) 236 @torch._inductor.config.patch("fallback_random", True) 237 def test_rng_non_trees(self): 238 self.check_rng() 239 240 def test_mutation_reinplaced(self): 241 import torch.nn as nn 242 243 class Model(nn.Module): 244 def __init__(self) -> None: 245 super().__init__() 246 247 def forward(self, input, other, out): 248 input = torch.logical_xor(input=input, other=other, out=out) 249 return input 250 251 x = torch.rand([1, 2, 1, 4, 9, 7], dtype=torch.float32).cuda() 252 y = torch.rand([1, 2, 1, 4, 9, 7], dtype=torch.float32).cuda() 253 z = torch.rand([1, 2, 1, 4, 9, 7], dtype=torch.float16).cuda() 254 255 model = Model().cuda() 256 eag = model(x, y, z) 257 with capture_stderr() as captured_output: 258 opt = torch.compile(model.forward, mode="reduce-overhead")(x, y, z) 259 260 FileCheck().check( 261 "skipping cudagraphs due to 
mutated inputs (1 instances). Found from" 262 ).check("torch.logical_xor").run(captured_output[0]) 263 self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) 264 265 @requires_multigpu() 266 @parametrize("backend", ("inductor", "cudagraphs")) 267 def test_multiple_devices_msg(self, backend): 268 def foo(x, y): 269 return (x + 1, y + 2) 270 271 foo = get_compile_fn(backend)(foo) 272 with capture_stderr() as captured_output: 273 foo(torch.ones([10], device="cuda"), torch.ones([20])) 274 275 FileCheck().check( 276 "skipping cudagraphs due to cpu device (arg1_1). Found from" 277 ).check("y + 2").run(captured_output[0]) 278 self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) 279 280 with capture_stderr() as captured_output: 281 foo( 282 torch.ones([10], device="cuda:0"), torch.ones([10], device="cuda:1") 283 ) 284 285 FileCheck().check("skipping cudagraphs due to multiple devices").run( 286 captured_output[0] 287 ) 288 self.assertEqual(counters["inductor"]["cudagraph_skips"], 2) 289 290 @torch._inductor.config.patch("triton.cudagraph_skip_dynamic_graphs", True) 291 def test_skip_symbolic(self): 292 @torch.compile(dynamic=True) 293 def foo(x, y): 294 return x + y 295 296 with capture_stderr() as captured_output: 297 foo(torch.rand([10], device="cuda"), torch.rand([10], device="cuda")) 298 299 FileCheck().check( 300 "skipping cudagraphs due to graph with symbolic shapes inputs" 301 ).check("x + y").run(captured_output[0]) 302 self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) 303 304 @parametrize("backend", ("inductor", "cudagraphs")) 305 @torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True) 306 @torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", True) 307 @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True) 308 def test_mutation_on_inp(self, backend): 309 def foo(x): 310 x.add_(2) 311 return x 312 313 foo = get_compile_fn(backend)(foo) 314 315 def inp(): 316 return torch.ones([10], device="cuda") 317 318 with capture_stderr() as captured_output: 319 foo(inp()) 320 321 FileCheck().check( 322 "skipping cudagraphs due to mutated inputs (1 instances). 
Found from" 323 ).check(".add_(2)").run(captured_output[0]) 324 self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) 325 326 # mutation on inp doesnt hit cudagraphs 327 self.assertEqual(len(self.get_manager().roots), 0) 328 329 # mutation on parameters/buffers hits cudagraphs 330 class Mod(torch.nn.Module): 331 def __init__(self) -> None: 332 super().__init__() 333 self.buf = torch.ones([10], device="cuda") 334 335 def forward(self, x): 336 self.buf.add_(x) 337 return self.buf + x 338 339 def foo(mod, x): 340 return mod(x) 341 342 foo = get_compile_fn(backend)(foo) 343 mod = Mod() 344 mod2 = Mod() 345 346 for _ in range(3): 347 self.assertEqual(foo(mod, inp()), mod2(inp())) 348 self.assertEqual(mod.buf, mod2.buf) 349 350 self.assertIsNotNone(self.get_manager()) 351 352 @parametrize("backend", ("inductor", "cudagraphs")) 353 @torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True) 354 @torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", False) 355 @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False) 356 def test_mutation_cudagraph_managed_tensors_config(self, backend): 357 def foo(x): 358 return x + 1 359 360 def mut(x): 361 x.add_(2) 362 return x 363 364 def non_mut(x): 365 return x.add(2) 366 367 mut = get_compile_fn(backend)(mut) 368 foo = get_compile_fn(backend)(foo) 369 370 with capture_stderr() as captured_output: 371 for i in range(3): 372 torch.compiler.cudagraph_mark_step_begin() 373 inp = torch.rand([4], device="cuda") 374 375 tmp = foo(inp) 376 mut_out = mut(tmp) 377 self.assertEqual(mut_out, non_mut(foo(inp))) 378 FileCheck().check_count( 379 "skipping cudagraphs due to mutated inputs (1 instances). Found from", 380 1, 381 exactly=True, 382 ).run(captured_output[0]) 383 384 @parametrize("backend", ("inductor", "cudagraphs")) 385 @torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True) 386 @torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", True) 387 @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True) 388 def test_mutation_cudagraph_managed_tensors(self, backend): 389 def foo(x): 390 return x + 1 391 392 def mut(x): 393 x.add_(2) 394 return x 395 396 def non_mut(x): 397 return x.add(2) 398 399 mut = get_compile_fn(backend)(mut) 400 foo = get_compile_fn(backend)(foo) 401 402 with capture_stderr() as captured_output: 403 for i in range(3): 404 torch.compiler.cudagraph_mark_step_begin() 405 inp = torch.rand([4], device="cuda") 406 407 tmp = foo(inp) 408 mut_out = mut(tmp) 409 self.assertEqual(mut_out, non_mut(foo(inp))) 410 FileCheck().check_count( 411 "skipping cudagraphs due to mutated inputs (1 instances). Found from", 412 0, 413 exactly=True, 414 ).run(captured_output[0]) 415 self.assertTrue("cudagraph_skips" not in counters["inductor"]) 416 417 torch.compiler.cudagraph_mark_step_begin() 418 inp = torch.rand([4], device="cuda") 419 tmp = foo(inp) 420 mut_inp = tmp.clone() 421 # in this case, what previously a mutated cudagraph managed tensor is no longer, 422 # now its an input from eager we should fallback to inductor without cudagraphs 423 with capture_stderr() as captured_output: 424 mut(mut_inp) 425 FileCheck().check( 426 "skipping cudagraphs due to mutated inputs (1 instances). 
Found from" 427 ).check("x.add_(2)").run(captured_output[0]) 428 self.assertEqual(mut_inp, non_mut(foo(inp))) 429 self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) 430 431 @parametrize("backend", ("inductor", "cudagraphs")) 432 @torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True) 433 @torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", True) 434 @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True) 435 def test_mutation_cudagraph_managed_tensor_warn(self, backend): 436 def foo(x): 437 return x.add_(1) 438 439 def fee(y, z): 440 return z.add(3) 441 442 def inp(): 443 return torch.rand([4], device="cuda") 444 445 foo = get_compile_fn(backend)(foo) 446 fee = get_compile_fn(backend)(fee) 447 448 with capture_stderr() as captured_output: 449 for _ in range(3): 450 torch.compiler.cudagraph_mark_step_begin() 451 fee(inp(), foo(inp())) 452 FileCheck().check_count( 453 "skipping cudagraphs due to mutated inputs (1 instances). Found from", 454 1, 455 exactly=True, 456 ).run(captured_output[0]) 457 self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) 458 459 @parametrize("backend", ("inductor", "cudagraphs")) 460 @torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True) 461 @torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", True) 462 @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True) 463 def test_mutation_cudagraph_managed_tensor_warn_only_once(self, backend): 464 def foo(x): 465 return x + 1 466 467 def mut(x): 468 x.add_(2) 469 return x 470 471 def inp(): 472 return torch.rand([4], device="cuda") 473 474 mut = get_compile_fn(backend)(mut) 475 foo = get_compile_fn(backend)(foo) 476 477 with capture_stderr() as captured_output: 478 # Should warn for current_node=None 479 mut(inp()) 480 481 for i in range(3): 482 torch.compiler.cudagraph_mark_step_begin() 483 tmp = foo(inp()) 484 mut(tmp) # should not warn 485 486 mut_inp = tmp.clone() 487 mut(mut_inp) # should not warn since mut has warned 488 489 FileCheck().check_count( 490 "skipping cudagraphs due to mutated inputs (1 instances). Found from", 491 1, 492 exactly=True, 493 ).run(captured_output[0]) 494 self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) 495 496 def test_function_compiled_multiple_times(self): 497 def foo(x): 498 y = foo2(x) 499 y2 = foo2(y) 500 return y + y2 501 502 def foo2(x): 503 torch._dynamo.graph_break() 504 return x * x * x 505 506 foo_opt = torch.compile(foo) 507 ones = torch.ones([4, 4], device="cuda") 508 foo(ones) 509 foo_opt(ones) 510 foo_opt(ones) 511 self.assertEqual(foo_opt(ones), foo(ones)) 512 # paths 513 children = self.get_root_children() 514 # one root with two children 515 self.assertEqual(children, [2]) 516 517 def test_end_recording_early(self): 518 def foo(x): 519 y = x * x * x 520 torch._dynamo.graph_break() 521 z = x + y 522 return z 523 524 @torch.compile 525 def foo2(x): 526 return x + 4 527 528 foo_opt = torch.compile(foo) 529 530 for _ in range(3): 531 out = foo_opt(torch.ones([4, 4], device="cuda")) 532 del out 533 534 # when I tried inducing separate recordings via graph break, 535 # the frame kept interferring by keeping outputs alive 536 # this isnt great by simulates the logic. 
            from torch._dynamo.mutation_guard import GenerationTracker

            GenerationTracker.generation -= 1

            out = foo2(torch.ones([4, 4], device="cuda"))
            del out

            foo_opt(torch.ones([4, 4], device="cuda"))

            # Two separate traces - one has a child, one doesn't
            self.assertEqual(self.get_root_children(), [1, 0])

        def test_execution_into_recording(self):
            def foo(x):
                y = x + x

                if y.sum() > 0:
                    return y + 10
                else:
                    return y - 10

            foo_opt = torch.compile(foo)
            inp = torch.zeros([4, 4], dtype=torch.float, device="cuda")
            self.assertEqual(foo_opt(inp), foo(inp))
            self.assertEqual(foo_opt(inp), foo(inp))

            inp.add_(1)
            out_eager = foo(inp)
            out_warmup = foo_opt(inp)
            self.assertEqual(out_warmup, out_eager)
            # warmup should have the storage deallocator hooked on
            self.assertEqual(all_live_block_count(), 1)

            out_live = foo_opt(inp)
            self.assertEqual(out_live, out_eager)

            # should be in recording mode, with storage deallocator hooked on
            self.assertEqual(all_live_block_count(), 1)
            # warmup should have been freed
            del out_warmup
            # should be in recording mode, with storage deallocator hooked on
            self.assertEqual(all_live_block_count(), 1)

            del out_live
            self.assertEqual(all_live_block_count(), 0)

            out = foo_opt(inp)
            self.assertEqual(foo(inp), out)

            # should be in execution mode
            self.assertEqual(all_live_block_count(), 0)

        def test_forward_with_skipped_cudagraphed_backward(self):
            @torch.compile(mode="reduce-overhead")
            def foo(x):
                return x * x * x

            for _ in range(3):
                inp = torch.rand([20, 20], device="cuda", requires_grad=True)
                out = foo(inp)

                def complex_memory_overlap_new(t):
                    return True

                try:
                    prev = torch._inductor.compile_fx.complex_memory_overlap
                    torch._inductor.compile_fx.complex_memory_overlap = (
                        complex_memory_overlap_new
                    )
                    back_inp = torch.empty_strided([20, 20], [0, 1], device="cuda")
                    out.backward(back_inp)
                finally:
                    torch._inductor.compile_fx.complex_memory_overlap = prev

            # we should not have cudagraph'd the backwards
            new_id = self.get_manager().new_graph_id().id
            self.assertEqual(new_id, 1)

            self.assertFalse(self.get_manager().running_forwards_with_pending_backwards)

        @torch._functorch.config.patch("enable_autograd_cache", True)
        @torch._inductor.config.patch("fx_graph_cache", True)
        @torch._inductor.config.patch("fx_graph_remote_cache", False)
        def test_cache_hit_forward_miss_backward(self):
            # Test that we don't cache cudagraphs, skipping cudagraphs on the backward on a cache miss

            @torch.compile(mode="reduce-overhead")
            def foo(x):
                return x * x * x

            def complex_memory_overlap_new(t):
                return True

            # Run forwards, fx graph should cache miss
            for _ in range(3):
                torch._dynamo.reset()
                counters.clear()
                FxGraphCache.clear()
                AOTAutogradCache.clear()

                with mock.patch(
                    "torch._inductor.compile_fx.complex_memory_overlap",
                    new=complex_memory_overlap_new,
                ):
                    inp = torch.rand([20, 20], device="cuda", requires_grad=True)
                    out = foo(inp)
                    self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)

                # Reset dynamo and related caches except for FXGraphCache
                torch._dynamo.reset()
                # Forwards should be a cache hit now, we still skip cudagraphs
                inp = torch.rand([20, 20], device="cuda", requires_grad=True)
                out = foo(inp)
self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) 651 self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1) 652 653 # Run backward without complex memory overlap being set 654 655 # Run the backward without complex memory overlap reason 656 # cache should miss, but cudagraphs should not run 657 # because forward skipped it 658 back_inp = torch.empty_strided([20, 20], [0, 1], device="cuda") 659 out.backward(back_inp) 660 self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2) 661 662 # Run it one more time, this time AOTAutogradCache will hit 663 self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2) 664 self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) 665 666 torch._dynamo.reset() 667 inp = torch.rand([20, 20], device="cuda", requires_grad=True) 668 out = foo(inp) 669 back_inp = torch.empty_strided([20, 20], [0, 1], device="cuda") 670 out.backward(back_inp) 671 672 self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) 673 674 # we should not have cudagraph'd anything 675 assert self.get_manager() is None 676 677 @torch._functorch.config.patch("enable_autograd_cache", True) 678 @torch._inductor.config.patch("fx_graph_cache", True) 679 @torch._inductor.config.patch("fx_graph_remote_cache", False) 680 def test_backward_gets_cached_cudagraphs(self): 681 # We pass cpu tensors to foo and save that into the cache 682 # On a subsequent run in a new process, cudagraphs should be 683 # disabled properly on both forward and backwards runs. 684 685 @torch.compile(mode="reduce-overhead") 686 def foo(x): 687 return x * x * x 688 689 torch._dynamo.reset() 690 counters.clear() 691 FxGraphCache.clear() 692 AOTAutogradCache.clear() 693 694 # Use cpu device to disable cudagraphs during compilation 695 inp = torch.rand([20, 20], device="cpu", requires_grad=True) 696 out = foo(inp) 697 self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) 698 699 back_inp = torch.empty_strided([20, 20], [0, 1], device="cpu") 700 out.backward(back_inp) 701 self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2) 702 703 # Run again on new process 704 torch._dynamo.reset() 705 706 # Forward and backward should also disable cudagraphs without compilation 707 inp = torch.rand([20, 20], device="cpu", requires_grad=True) 708 out = foo(inp) 709 # AOTAutogradCache will load the forward and the backward from cache immediately, so fx_graph_cache_hit will equal 2 710 self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2) 711 self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) 712 torch._dynamo.reset() 713 714 back_inp = torch.empty_strided([20, 20], [0, 1], device="cpu") 715 out.backward(back_inp) 716 717 # we should not have cudagraph'd anything 718 assert self.get_manager() is None 719 720 @torch._inductor.config.patch("triton.skip_cudagraph_warmup", True) 721 @torch._functorch.config.patch("enable_autograd_cache", True) 722 @torch._inductor.config.patch("fx_graph_cache", True) 723 @torch._inductor.config.patch("fx_graph_remote_cache", False) 724 def test_cached_forward_backward(self): 725 counters.clear() 726 AOTAutogradCache.clear() 727 FxGraphCache.clear() 728 729 @torch.compile 730 def foo(x): 731 torch.manual_seed(0) 732 y = x * 2 733 return torch.sin(y) * torch.nn.functional.dropout(x, p=0.4) 734 735 inp = torch.rand([4, 4], requires_grad=True, device="cuda") 736 inp2 = inp.detach().clone().requires_grad_(True) 737 out = foo(inp) 738 739 out.sum().backward() 740 741 self.assertEqual(self.get_root_children(), 
[1]) 742 743 # the three saved tensors should die in the backward 744 # we kept alive the output 745 self.assertEqual(self.curr_node().expected_dead_indices_before_graph, []) 746 self.assertEqual( 747 self.curr_node().expected_dead_indices_after_graph, 748 [(0, 1), (0, 2)], 749 ) 750 self.assertFalse(self.get_manager().new_graph_id().id == 0) 751 self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) 752 753 # Reset dynamo and rerun. We should see a cache hit now 754 torch._dynamo.reset() 755 756 out2 = foo(inp2) 757 out2.sum().backward() 758 self.assertEqual(out, out2) 759 self.assertEqual(inp.grad, inp2.grad) 760 761 self.assertEqual(self.get_root_children(), [1]) 762 self.assertFalse(self.get_manager().new_graph_id().id == 0) 763 self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) 764 765 @parametrize("backend", ("inductor", "cudagraphs")) 766 def test_forward_backward_not_called(self, backend): 767 def foo(x, y): 768 x_out = x * x * x 769 torch._dynamo.graph_break() 770 y_out = y * y * y 771 return x_out, y_out 772 773 foo = get_compile_fn(backend)(foo) 774 775 for _ in range(3): 776 inps = [ 777 torch.rand([20, 20], requires_grad=True, device="cuda") 778 for _ in range(2) 779 ] 780 x_out, y_out = foo(inps[0], inps[1]) 781 x_out.sum().backward() 782 783 self.assertFalse(self.get_manager().running_forwards_with_pending_backwards) 784 785 # we should not have cudagraph'd the y backward 786 new_id = self.get_manager().new_graph_id().id 787 self.assertEqual(new_id, 3) 788 789 def _test_unaligned_static_input_impl(self, expected_clones): 790 def fn(x, y): 791 return (x + y,) 792 793 def get_aligned_inputs(): 794 return [torch.rand([5, 5], device="cuda") for _ in range(2)] 795 796 mod = make_fx(fn)(*get_aligned_inputs()) 797 798 mode = torch._subclasses.FakeTensorMode() 799 800 with mode: 801 inps = [torch.rand([6, 5], device="cuda")[1:] for _ in range(2)] 802 803 compiled_f = compile_fx_inner( 804 mod, inps, static_input_idxs=[0], cudagraphs=True 805 ) 806 807 def get_unaligned_inputs(): 808 return [torch.rand([6, 5], device="cuda")[1:] for _ in range(2)] 809 810 class CloneCounterMode(TorchDispatchMode): 811 def __init__(self) -> None: 812 self.count = 0 813 814 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 815 kwargs = {} if kwargs is None else kwargs 816 self.count += func is torch.ops.aten.clone.default 817 return func(*args, **kwargs) 818 819 for _ in range(3): 820 with CloneCounterMode() as m: 821 compiled_f(get_unaligned_inputs()) 822 self.assertEqual(m.count, expected_clones) 823 824 compiled_f(get_aligned_inputs()) 825 self.assertEqual(m.count, expected_clones) 826 827 def test_unaligned_static_input_trees(self): 828 self._test_unaligned_static_input_impl(expected_clones=0) 829 830 @torch._inductor.config.patch("triton.cudagraph_trees", False) 831 def test_unaligned_static_input_non_trees(self): 832 self._test_unaligned_static_input_impl(expected_clones=0) 833 834 @torch._inductor.config.patch("triton.cudagraphs", False) 835 def test_unaligned_static_input_no_cudagraphs(self): 836 self._test_unaligned_static_input_impl(expected_clones=0) 837 838 def test_sparsity(self): 839 def foo(view_6, buf31): 840 return aten._sparse_coo_tensor_with_dims_and_tensors( 841 1, 842 1, 843 [1000000, 64], 844 view_6, 845 buf31, 846 dtype=torch.float32, 847 layout=torch.sparse_coo, 848 device="cuda", 849 pin_memory=None, 850 ) 851 852 foo_opt = torch.compile(foo) 853 854 view_6 = torch.zeros([1, 102397], dtype=torch.int64, device="cuda") 855 buf31 = 
torch.rand([102397, 64], device="cuda") 856 857 for _ in range(3): 858 self.assertEqual(foo_opt(view_6, buf31), foo(view_6, buf31)) 859 860 def test_accumulate_multiple_recordings(self): 861 def foo(x): 862 y = x + x + x 863 torch._dynamo.graph_break() 864 if y.sum() <= 0: 865 return y 866 else: 867 return y * 10 868 869 foo_opt = torch.compile(foo) 870 871 # two separate compilations & recordings 872 out1 = self.run_twc(foo_opt, torch.zeros([5], device="cuda")) 873 874 # out1 gets manually freed 875 out2 = self.run_twc(foo_opt, torch.zeros([6], device="cuda")) 876 877 self.assertEqual(all_live_block_count(), 1) 878 879 out3 = self.run_twc(foo_opt, torch.ones([5], device="cuda")) 880 881 self.assertEqual(out3, foo(torch.ones([5], device="cuda"))) 882 883 self.assertEqual(all_live_block_count(), 1) 884 del out1, out2 885 self.assertEqual(all_live_block_count(), 1) 886 887 del out3 888 gc.collect() 889 self.assertEqual(all_live_block_count(), 0) 890 891 @torch._inductor.config.patch("freezing", True) 892 def test_constant_output(self): 893 class Mod(torch.nn.Module): 894 def __init__(self) -> None: 895 super().__init__() 896 self.param = torch.nn.Parameter( 897 torch.tensor([float(i) for i in range(10)], device="cuda") 898 ) 899 900 def forward(self, inp): 901 return self.param, self.param[0:2], inp + 2 902 903 inp = torch.tensor([2], device="cuda") 904 m = Mod() 905 with torch.no_grad(): 906 out_eager = m(inp) 907 908 m_comp = torch.compile(m) 909 for _ in range(3): 910 self.assertEqual(out_eager, m_comp(inp)) 911 912 def test_live_outputs_multiple_graphs(self): 913 def foo(x): 914 x = x + x + x 915 y = x + 1 916 torch._dynamo.graph_break() 917 z = x * x 918 if z.sum() > 0: 919 return y + 1 920 else: 921 return y 922 923 foo_opt = torch.compile(foo) 924 925 self.run_twc(foo_opt, torch.zeros([5], device="cuda")) 926 self.assertEqual(self.num_checkpoints(), 0) 927 out = self.run_twc(foo_opt, torch.ones([5], device="cuda")) 928 929 self.assertEqual(all_live_block_count(), 1) 930 931 del out 932 self.assertEqual(all_live_block_count(), 0) 933 934 # we need to checkpoint from function to warmup y + 1, 935 # and then again to record it 936 self.assertEqual(self.num_checkpoints(), 2) 937 938 def test_expanded_inputs(self): 939 x = torch.rand(1, 512, device="cuda").expand(4, 512) 940 941 def foo(x): 942 return x + 4 + torch.ones([4, 512], device="cuda") 943 944 foo_opt = torch.compile()(foo) 945 946 for _ in range(3): 947 self.assertEqual(foo_opt(x), foo(x)) 948 949 self.assertFalse(self.get_manager().new_graph_id().id == 0) 950 951 @torch._inductor.config.patch("triton.skip_cudagraph_warmup", True) 952 def test_tensor_dies_between_checkpoint(self): 953 def foo(args): 954 x = args[0] 955 args.clear() 956 return x + 1, x + 2 957 958 inp = torch.rand([4], device="cuda") 959 inp_list = [inp] 960 foo_cg = self.cudagraphify_impl(foo, inp_list, ()) 961 foo_cg(inp_list) 962 foo_cg([inp]) 963 964 out1, out2 = foo_cg([inp]) 965 inp = [out1] 966 967 del out1, out2 968 969 def foo2(args): 970 x = args[0] 971 args.clear() 972 return [x * x * x] 973 974 self.assertEqual(self.num_checkpoints(), 0) 975 foo2_cg = self.cudagraphify_impl(foo2, inp, ()) 976 977 x = foo2_cg(inp)[0] 978 979 self.assertEqual(self.num_checkpoints(), 1) 980 # out2 dies between the previous recording and the new one, 981 # need to be manually deallocated after the checkpoint 982 983 self.assertEqual(all_live_block_count(), 1) 984 del x 985 self.assertEqual(all_live_block_count(), 0) 986 987 def test_aliased_storage_single_weakref(self): 
988 @torch.compile(mode="reduce-overhead") 989 def foo(x): 990 x = x * 20 991 x_alias = x[0] 992 y = x * 10 993 y_alias = y[0] 994 torch._dynamo.graph_break() 995 ind = torch.tensor(4, device="cuda") 996 x_alias2 = x[ind:] 997 y_alias2 = y[ind:] 998 return x, x_alias, x_alias2, y_alias, y_alias2 999 1000 for _ in range(4): 1001 outs = foo(torch.rand([20, 20], device="cuda")) 1002 1003 ptr_to_ref = { 1004 out.untyped_storage().data_ptr(): out.untyped_storage()._cdata 1005 for out in outs 1006 } 1007 1008 self.assertEqual(len(ptr_to_ref), 2) 1009 for out in outs: 1010 self.assertEqual( 1011 ptr_to_ref[out.untyped_storage().data_ptr()], 1012 out.untyped_storage()._cdata, 1013 ) 1014 del outs 1015 del out 1016 1017 node = self.get_manager().current_node 1018 self.assertEqual(len(list(node.path_live_weakrefs())), 0) 1019 self.assertFalse(self.get_manager().new_graph_id().id == 0) 1020 1021 def test_aliasing_static_ref(self): 1022 class Mod(torch.nn.Linear): 1023 def forward(self, x): 1024 return self.weight.T @ x, self.weight.T, self.weight[0:4] 1025 1026 m = Mod(10, 10).cuda() 1027 1028 @torch.compile(mode="reduce-overhead") 1029 def foo(mod, x): 1030 return mod(x) 1031 1032 @torch.compile(mode="reduce-overhead") 1033 def foo2(x): 1034 return x[2:] 1035 1036 x = torch.rand([10, 10], device="cuda", requires_grad=True) 1037 param_c = cdata(m.weight) 1038 for _ in range(3): 1039 out1, alias_1, alias_2 = foo(m, x) 1040 self.assertEqual(len({param_c, cdata(alias_1), cdata(alias_2)}), 1) 1041 1042 out2 = foo2(out1) 1043 out2.sum().backward() 1044 self.assertEqual(cdata(out1), cdata(out2)) 1045 1046 node = self.curr_node() 1047 first_node = next(node._path_from_root) 1048 self.assertFalse(first_node.unaliased_in_all_paths[0]) 1049 self.assertTrue(first_node.cached_tensor_outputs[0] is None) 1050 1051 @skipIfRocm 1052 def test_checkpointing_resets_persistent_refs(self): 1053 @torch.compile(mode="reduce-overhead") 1054 def foo(x): 1055 return x @ x 1056 1057 def inp(): 1058 return torch.rand([20, 20], device="cuda", requires_grad=False) 1059 1060 for _ in range(3): 1061 foo(inp()) 1062 1063 self.assertEqual(self.num_checkpoints(), 0) 1064 1065 out = foo(inp()) 1066 out_id = id(out) 1067 del out 1068 self.assertEqual(id(foo(inp())), out_id) 1069 1070 @torch.compile(mode="reduce-overhead") 1071 def foo2(x): 1072 return x[0], x @ x 1073 1074 for i in range(2): 1075 out = foo(inp()) 1076 1077 from torch._dynamo.mutation_guard import GenerationTracker 1078 1079 GenerationTracker.generation -= 1 1080 1081 out_alias, out2 = foo2(out) 1082 del out_alias 1083 1084 self.assertEqual(all_live_block_count(), 2) 1085 del out 1086 self.assertEqual(all_live_block_count(), 1) 1087 del out2 1088 self.assertEqual(all_live_block_count(), 0) 1089 1090 self.assertEqual(self.num_checkpoints(), i + 1) 1091 1092 new_out = foo(inp()) 1093 curr_node = self.curr_node() 1094 self.assertFalse(curr_node.unaliased_in_all_paths[0]) 1095 self.assertFalse(out_id == id(new_out)) 1096 1097 def test_aliased_static_parameter(self): 1098 inp = torch.rand([20, 20], device="cuda") 1099 1100 def foo(args): 1101 x = args[0] 1102 args.clear() 1103 return (x[0],) 1104 1105 foo_cg = self.cudagraphify_impl(foo, [inp], (0,)) 1106 1107 for _ in range(3): 1108 out = foo_cg([inp])[0] 1109 self.assertEqual(cdata(inp), cdata(out)) 1110 1111 node = self.curr_node() 1112 self.assertEqual(node.cached_tensor_outputs, [None]) 1113 self.assertEqual(node.unaliased_in_all_paths, [False]) 1114 1115 def test_warmup_stream_sync(self): 1116 def foo(args): 1117 x = 
args[0] 1118 args.clear() 1119 x_orig = x 1120 for _ in range(100): 1121 x = x @ x 1122 return (x,) 1123 1124 inp = torch.rand([4096, 4096], device="cuda") 1125 ref = foo([inp])[0] 1126 torch.cuda.synchronize() 1127 1128 user_stream = torch.cuda.Stream() 1129 with torch.cuda.stream(user_stream): 1130 foo_cg = self.cudagraphify_impl(foo, [inp], (0,)) 1131 out = foo_cg([inp])[0] 1132 y = out + 1 1133 self.assertEqual(y, ref + 1) 1134 1135 def test_unaligned_static_parameter(self): 1136 def gen_inp(): 1137 inp = torch.ones([20], device="cuda") 1138 return [inp[1:]] 1139 1140 def foo(args): 1141 x = args[0] 1142 args.clear() 1143 return (x + x,) 1144 1145 foo_cg = self.cudagraphify_impl(foo, gen_inp(), (0,)) 1146 1147 for _ in range(3): 1148 out = foo_cg(gen_inp()) 1149 self.assertEqual(out, foo(gen_inp())) 1150 del out 1151 1152 node = self.curr_node() 1153 self.assertEqual(node.static_input_data_ptrs, [None]) 1154 1155 def test_amp_cache_disabled(self): 1156 @torch.compile() 1157 def foo(x): 1158 return x + x 1159 1160 for _ in range(3): 1161 out = foo(torch.rand([4, 4], device="cuda", requires_grad=True)) 1162 1163 # amp cache for cudagraph outputs should be disabled 1164 t2 = torch.rand([4, 4], device="cuda") 1165 1166 with torch.cuda.amp.autocast(): 1167 run_once = out @ t2 1168 1169 out.detach().zero_() 1170 1171 run_twice = out @ t2 1172 1173 self.assertNotEqual(run_once, run_twice) 1174 1175 def test_remove_hooks_on_cached_tensors(self): 1176 @torch.compile() 1177 def foo(x): 1178 return x * x 1179 1180 inp = torch.rand([4], device="cuda", requires_grad=True) 1181 1182 for _ in range(5): 1183 out = foo(inp) 1184 self.assertIsNone(out._backward_hooks) 1185 out.register_hook(lambda: None) 1186 1187 # today, torch.compile never outputs a leaf tensor which is the only 1188 # tensor that can register _post_accumulate_grad_hooks 1189 # add this as a preventative test 1190 1191 @torch.compile() 1192 def foo(x): 1193 return torch.rand([4], device="cuda", requires_grad=True) 1194 1195 for _ in range(5): 1196 out = foo(inp) 1197 self.assertIsNone(out._post_accumulate_grad_hooks) 1198 out.register_post_accumulate_grad_hook(lambda: None) 1199 1200 def test_multiple_insert_removal_caching(self): 1201 torch._C._set_cached_tensors_enabled(True) 1202 try: 1203 x = torch.rand([4], device="cuda") 1204 1205 torch._C._add_cached_tensor(x) 1206 self.assertTrue(torch._C._is_cached_tensor(x)) 1207 1208 torch._C._add_cached_tensor(x) 1209 torch._C._remove_cached_tensor(x) 1210 1211 self.assertFalse(torch._C._is_cached_tensor(x)) 1212 finally: 1213 torch._C._set_cached_tensors_enabled(False) 1214 1215 def test_accumulate_grad(self): 1216 # cudagraph trees shouldnt interfere with accumulation logic 1217 1218 def compute_grad(grad_output, create_graph): 1219 x = torch.randn(5, 5, requires_grad=True, device="cuda") 1220 1221 @torch.compile() 1222 def foo(x): 1223 return x + 2 1224 1225 y = foo(x) 1226 y.backward(grad_output, retain_graph=True) 1227 x_grad = x.grad 1228 x_grad_clone = x.grad.clone() 1229 y.backward(grad_output, create_graph=create_graph) 1230 return x_grad, x_grad_clone 1231 1232 for _ in range(3): 1233 grad_output = torch.ones(5, 5, device="cuda") 1234 1235 # Accumulate in-place when create_graph is False 1236 x_grad, x_grad_clone = compute_grad(grad_output, create_graph=False) 1237 self.assertEqual(x_grad, x_grad_clone * 2) 1238 1239 # Accumulate out-of-place when create_graph is False 1240 x_grad, x_grad_clone = compute_grad(grad_output, create_graph=True) 1241 self.assertEqual(x_grad, 
x_grad_clone) 1242 1243 def test_frozen_fn(self): 1244 @torch.compile() 1245 def foo(x): 1246 return x @ x 1247 1248 for _ in range(3): 1249 out = foo(torch.rand([10, 10], device="cuda")) 1250 1251 self.assertTrue(self.get_manager().new_graph_id().id == 1) 1252 frozen = torch._dynamo.run(foo) 1253 1254 for _ in range(3): 1255 out = frozen(torch.rand([10, 10], device="cuda")) 1256 1257 # didnt do additional recordings 1258 self.assertTrue(self.get_manager().new_graph_id().id == 2) 1259 1260 def test_empty_cpu_tensor(self): 1261 def foo(x): 1262 return x @ x, torch.tensor([]) 1263 1264 foo_opt = torch.compile(foo) 1265 x = torch.rand([4], device="cuda") 1266 1267 for _ in range(3): 1268 out_opt = foo_opt(x) 1269 self.assertEqual(foo(x), out_opt) 1270 1271 self.assertTrue(self.get_manager().new_graph_id().id == 1) 1272 1273 def test_output_alias(self): 1274 inp = torch.rand([20, 20], device="cuda") 1275 1276 def foo(args): 1277 x = args[0] 1278 args.clear() 1279 out = x + x 1280 return (x, x[0]) 1281 1282 foo_cg = self.cudagraphify_impl(foo, [inp], ()) 1283 1284 for _ in range(3): 1285 out_1, out_2 = foo_cg([inp]) 1286 self.assertEqual(cdata(out_1), cdata(out_2)) 1287 del out_1, out_2 1288 self.assertEqual(len(list(self.curr_node().path_live_weakrefs())), 0) 1289 1290 self.assertEqual(self.curr_node().cached_tensor_outputs, [None, None]) 1291 1292 def test_empty_storage(self): 1293 @torch.compile(mode="reduce-overhead") 1294 def foo(x): 1295 return ( 1296 (x + x + x), 1297 torch.zeros([0], device="cuda"), 1298 torch.zeros([100], device="cuda")[0:0], 1299 ) 1300 1301 inp = torch.rand([4], device="cuda") 1302 for _ in range(3): 1303 out = foo(inp) 1304 node = self.curr_node() 1305 self.assertEqual(len(list(node.path_live_weakrefs())), 1) 1306 1307 @torch.compile(mode="reduce-overhead") 1308 def foo(x): 1309 return (x + x + x), torch.rand([4], device="cuda") + 10 1310 1311 inp = torch.rand([0], device="cuda") 1312 for _ in range(3): 1313 out = foo(inp) 1314 node = self.curr_node() 1315 self.assertEqual(len(list(node.path_live_weakrefs())), 1) 1316 1317 @torch._inductor.config.patch("triton.skip_cudagraph_warmup", True) 1318 def test_aliased_output_checkpoint(self): 1319 def foo(args): 1320 x = args[0] 1321 args.clear() 1322 y = x + 2 1323 return x + 1, y, y[0] 1324 1325 inp = torch.rand([4, 4], device="cuda") 1326 foo_cg = self.cudagraphify_impl(foo, [inp], ()) 1327 foo_cg([inp]) 1328 foo_cg([inp]) 1329 1330 out1, out2, out3 = foo_cg([inp]) 1331 inp = [out1] 1332 1333 del out1, out2, out3 1334 1335 def foo2(args): 1336 x = args[0] 1337 args.clear() 1338 return [x * x * x] 1339 1340 self.assertEqual(self.num_checkpoints(), 0) 1341 foo2_cg = self.cudagraphify_impl(foo2, inp, ()) 1342 1343 x = foo2_cg(inp)[0] 1344 1345 self.assertEqual(self.num_checkpoints(), 1) 1346 # out2 and out3 dies between the previous recording and the new one, 1347 # need to be manually deallocated after the checkpoint 1348 1349 self.assertEqual(all_live_block_count(), 1) 1350 del x 1351 self.assertEqual(all_live_block_count(), 0) 1352 1353 @skipIfRocm 1354 @unittest.skipIf(not IS_LINUX, "cpp contexts are linux only") 1355 @torch._inductor.config.patch("triton.cudagraph_trees_history_recording", True) 1356 def test_workspace_allocation_error(self): 1357 torch._C._cuda_clearCublasWorkspaces() 1358 1359 prev = torch._inductor.cudagraph_trees.clear_cublas_manager 1360 1361 try: 1362 torch._inductor.cudagraph_trees.clear_cublas_manager = ( 1363 contextlib.nullcontext 1364 ) 1365 1366 @torch.compile() 1367 def foo(x, y): 1368 
return x @ x 1369 1370 inps = [torch.rand([400, 400], device="cuda") for _ in range(2)] 1371 1372 thrown = False 1373 try: 1374 foo(*inps) 1375 except Exception as e: 1376 thrown = True 1377 self.assertTrue( 1378 "at::cuda::blas::gemm<float>" in str(e) 1379 or "at::cuda::blas::gemm_internal_cublas<float>" in str(e) 1380 ) 1381 self.assertTrue( 1382 "getCurrentCUDABlasHandle" in str(e) 1383 or "getNewWorkspace" in str(e) 1384 ) 1385 1386 self.assertTrue(thrown) 1387 1388 finally: 1389 torch._C._cuda_clearCublasWorkspaces() 1390 torch._inductor.cudagraph_trees.clear_cublas_manager = prev 1391 torch._inductor.cudagraph_trees.get_container( 1392 self.device_idx 1393 ).tree_manager = None 1394 1395 def test_peristed_output_livenes(self): 1396 @torch.compile 1397 def foo(x): 1398 return x + x 1399 1400 for _ in range(3): 1401 foo(torch.rand([2, 2], device="cuda")) 1402 1403 node = self.get_manager().current_node 1404 self.assertEqual(len(list(node.path_live_weakrefs())), 0) 1405 1406 out = foo(torch.rand([2, 2], device="cuda")) 1407 self.assertTrue(out is node.cached_tensor_outputs[0]) 1408 self.assertEqual(len(list(node.path_live_weakrefs())), 1) 1409 1410 out_ref = out[0:] 1411 del out 1412 self.assertEqual(len(list(node.path_live_weakrefs())), 1) 1413 1414 del out_ref 1415 self.assertEqual(len(list(node.path_live_weakrefs())), 0) 1416 1417 @torch._inductor.config.patch("triton.skip_cudagraph_warmup", True) 1418 def test_tensor_no_longer_in_pool(self): 1419 def foo(args): 1420 x = args[0] 1421 args.clear() 1422 return x + 1, x + 2 1423 1424 inp = torch.rand([4], device="cuda") 1425 inp_list = [inp] 1426 foo_cg = self.cudagraphify_impl(foo, inp_list, ()) 1427 x1, x2 = foo_cg(inp_list) 1428 1429 def foo2(args): 1430 x = args[0] 1431 args.clear() 1432 return [x * x * x] 1433 1434 inp_list = [x1] 1435 foo2_cg = self.cudagraphify_impl(foo2, inp_list, ()) 1436 foo2_cg(inp_list) 1437 1438 del x1, x2 1439 # TODO make configurable 1440 1441 x1, x2 = foo_cg([inp]) 1442 self.assertEqual(self.num_checkpoints(), 0) 1443 1444 # input location has changed, should force recompile and checkpointing 1445 foo2_cg([torch.zeros_like(x1)]) 1446 1447 self.assertEqual(self.num_checkpoints(), 1) 1448 self.assertEqual(self.get_root_children(), [2]) 1449 1450 @torch._inductor.config.patch("triton.skip_cudagraph_warmup", True) 1451 def test_checkpoint_shared_output_storage_deallocation(self): 1452 def foo(args): 1453 x = args[0] 1454 args.clear() 1455 x_tmp = x + 1 1456 return x[0], x[1] 1457 1458 inp = torch.rand([2, 2], device="cuda") 1459 inp_list = [inp] 1460 foo_cg = self.cudagraphify_impl(foo, inp_list, ()) 1461 foo_cg(inp_list) 1462 foo_cg([inp]) 1463 1464 x1, x2 = foo_cg([inp]) 1465 inp = [x1] 1466 1467 def foo2(args): 1468 x = args[0] 1469 args.clear() 1470 y = x * x 1471 return y[0], y[1] 1472 1473 foo2_cg = self.cudagraphify_impl(foo2, inp, ()) 1474 foo2_cg(inp) 1475 1476 self.assertEqual(self.num_checkpoints(), 1) 1477 self.assertEqual( 1478 x1.untyped_storage().data_ptr(), x2.untyped_storage().data_ptr() 1479 ) 1480 self.assertEqual(all_live_block_count(), 1) 1481 del x1 1482 self.assertEqual(all_live_block_count(), 1) 1483 del x2 1484 self.assertEqual(all_live_block_count(), 0) 1485 1486 @torch._inductor.config.patch("triton.skip_cudagraph_warmup", True) 1487 def test_cleanup(self): 1488 def test_closure(): 1489 @torch.compile 1490 def foo(x): 1491 return x + 1 + 2, x * 10 1492 1493 foo(torch.rand([4], device="cuda")) 1494 return foo(torch.rand([4], device="cuda")) 1495 1496 out1, out2 = test_closure() 1497 
torch._dynamo.reset() 1498 1499 # TODO - deallocate on tensor deallocation 1500 # self.assertTrue(self.get_manager() is not None) 1501 # del out1 1502 # self.assertTrue(self.get_manager() is not None) 1503 # del out2 1504 self.assertTrue(self.get_manager() is None) 1505 1506 @torch._inductor.config.patch("triton.skip_cudagraph_warmup", True) 1507 def test_forward_backward(self): 1508 @torch.compile 1509 def foo(x): 1510 y = x * 2 1511 return torch.sin(y) * torch.nn.functional.dropout(x, p=0.4) 1512 1513 inp = torch.rand([4, 4], requires_grad=True, device="cuda") 1514 out = foo(inp) 1515 out.sum().backward() 1516 1517 self.assertEqual(self.get_root_children(), [1]) 1518 1519 # the three saved tensors should die in the backward 1520 # we kept alive the output 1521 self.assertEqual(self.curr_node().expected_dead_indices_before_graph, []) 1522 self.assertEqual( 1523 self.curr_node().expected_dead_indices_after_graph, 1524 [(0, 1), (0, 2)], 1525 ) 1526 self.assertFalse(self.get_manager().new_graph_id().id == 0) 1527 1528 def test_separate_recordings(self): 1529 def foo_unopt(x, y): 1530 return (x + 1) @ y 1531 1532 foo = torch.compile(foo_unopt) 1533 1534 foo_unopt( 1535 torch.ones([20, 20], device="cuda"), torch.ones([20, 20], device="cuda") 1536 ) 1537 1538 inps = [ 1539 torch.ones([20, 20], device="cuda", requires_grad=False) 1540 for _ in range(2) 1541 ] 1542 1543 out = foo(*inps) 1544 torch.cuda.synchronize() 1545 foo(*inps) 1546 torch.cuda.synchronize() 1547 foo(*inps) 1548 torch.cuda.synchronize() 1549 1550 foo_unopt( 1551 torch.ones([20, 20], device="cuda"), torch.ones([20, 20], device="cuda") 1552 ) 1553 1554 inps2 = [ 1555 torch.rand([40, 40], device="cuda", requires_grad=False) 1556 for _ in range(2) 1557 ] 1558 1559 foo(*inps2) 1560 foo(*inps2) 1561 foo(*inps2) 1562 1563 # two separate roots 1564 self.assertEqual(self.get_root_children(), [0, 0]) 1565 1566 def test_alias_of_parameter(self): 1567 class AliasMod(nn.Module): 1568 def __init__(self) -> None: 1569 super().__init__() 1570 self.param = torch.nn.Parameter(torch.rand([20, 20], device="cuda")) 1571 1572 def forward(self, x): 1573 return self.param[0], self.param, self.param + x 1574 1575 @torch.compile(mode="reduce-overhead") 1576 def foo(mod, inp): 1577 return mod(inp) 1578 1579 inp = torch.rand([20, 20], device="cuda") 1580 mod = AliasMod() 1581 1582 storage_ref = torch.multiprocessing.reductions.StorageWeakRef( 1583 mod.param.untyped_storage() 1584 ) 1585 1586 for _ in range(3): 1587 outs = foo(mod, inp) 1588 1589 self.assertEqual(mod(inp), outs) 1590 1591 self.assertFalse(storage_ref.expired()) 1592 1593 node = self.get_manager().current_node 1594 self.assertEqual(len(list(node.path_live_weakrefs())), 1) 1595 1596 @torch._dynamo.config.patch("inline_inbuilt_nn_modules", False) 1597 @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False) 1598 def test_unstable_ptr(self): 1599 import torch 1600 1601 @torch.compile(mode="reduce-overhead") 1602 def foo(m, inp): 1603 return m(inp) 1604 1605 def f(): 1606 l = [] 1607 m = torch.nn.Linear(20, 20).cuda() 1608 for _ in range(4): 1609 inp = torch.rand([20, 20], device="cuda") 1610 foo(m, inp) 1611 m.weight.data = torch.rand([20, 20], device="cuda") 1612 1613 self.assertRaises(RuntimeError, f) 1614 1615 @requires_multigpu() 1616 def test_manager_per_device(self): 1617 def test(): 1618 def foo(args): 1619 x = args[0] 1620 args.clear() 1621 return (x + 3,) 1622 1623 inp = torch.rand([20, 20], device="cuda:1") 1624 1625 inp_list = [inp] 1626 foo_cg = 
tree_cudagraphify_impl( 1627 foo, 1628 inp_list, 1629 (), 1630 device_index=1, 1631 is_backward=False, 1632 is_inference=True, 1633 ) 1634 for _ in range(3): 1635 self.assertEqual(foo_cg([inp]), foo([inp])) 1636 1637 self.assertTrue(self.get_manager(device_index=0) is None) 1638 self.assertFalse(self.get_manager(device_index=1) is None) 1639 1640 test() 1641 self.assertTrue(self.get_manager(device_index=1) is None) 1642 1643 def test_error_on_dealloc_use(self): 1644 @torch.compile() 1645 def foo(x): 1646 return x * x * x 1647 1648 inp = torch.rand([4], device="cuda") 1649 out = foo(inp) 1650 out2 = foo(inp) 1651 1652 with self.assertRaisesRegex(Exception, "overwritten by a subsequent run."): 1653 out + out 1654 1655 foo(inp) 1656 1657 with self.assertRaisesRegex(Exception, "overwritten by a subsequent run."): 1658 out2 + out2 1659 1660 @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn") 1661 def test_conv_benchmark(self): 1662 with torch.backends.cudnn.flags( 1663 enabled=True, benchmark=True, deterministic=False 1664 ): 1665 m = torch.nn.Conv2d(5, 6, [3, 3]).cuda() 1666 inp = torch.randn([2, 5, 16, 16]).cuda() 1667 1668 @torch.compile() 1669 def foo(m, inp): 1670 return m(inp) 1671 1672 foo(m, inp) 1673 1674 def test_single_stream_use(self): 1675 @torch.compile() 1676 def foo(x): 1677 return (x * x * x).relu() 1678 1679 inp = torch.rand([4], device="cuda", requires_grad=True) 1680 streams = set() 1681 streams_init = {seg["stream"] for seg in get_all_cudagraph_segments()} 1682 for _ in range(4): 1683 foo(inp).sum().backward() 1684 1685 streams = { 1686 seg["stream"] for seg in get_all_cudagraph_segments() 1687 } - streams_init 1688 self.assertEqual(len(streams), 1) 1689 self.assertFalse(self.get_manager().new_graph_id().id == 0) 1690 1691 @torch._dynamo.config.patch("assume_static_by_default", False) 1692 def test_dynamic_backward(self): 1693 def foo(x): 1694 x = torch.cat([x, x]) 1695 return torch.addmm(x, x, x).relu(), x.size(0) 1696 1697 opt_foo = torch.compile(mode="reduce-overhead")(foo) 1698 1699 def run_test(foo, inp): 1700 r, s = foo(inp) 1701 r.sum().backward() 1702 g = inp.grad.clone() 1703 inp.grad = None 1704 r = r.clone() 1705 return r, s, g 1706 1707 def run_big_test(inp): 1708 r0, s0, g0 = run_test(foo, inp) 1709 r1, s1, g1 = run_test(opt_foo, inp) 1710 r2, s2, g2 = run_test(opt_foo, inp) 1711 self.assertEqual(r0, r1) 1712 self.assertEqual(r0, r2) 1713 self.assertEqual(s0, s1) 1714 self.assertEqual(s0, s2) 1715 self.assertEqual(g0, g1) 1716 self.assertEqual(g0, g2) 1717 1718 inp = torch.randn(2, 4, device="cuda", requires_grad=True) 1719 run_big_test(inp) 1720 1721 inp = torch.randn(3, 6, device="cuda", requires_grad=True) 1722 run_big_test(inp) 1723 1724 def test_dynamic_warmup(self): 1725 COUNTER = 0 1726 1727 def f(inps): 1728 i, x = inps 1729 inps.clear() 1730 nonlocal COUNTER 1731 COUNTER += 1 1732 return x * 2 1733 1734 x = torch.randn(2, device="cuda") 1735 inp_list = [2, x] 1736 foo_cg = self.cudagraphify_impl(f, inp_list, ()) 1737 foo_cg(inp_list) # warmup 1738 foo_cg([2, x]) # record 1739 foo_cg([2, x]) # replay 1740 self.assertEqual(COUNTER, 2) 1741 1742 # Switching the size will require a warmup again 1743 x = torch.randn(3, device="cuda") 1744 inp_list = [3, x] 1745 foo_cg(inp_list) # warmup 1746 foo_cg([3, x]) # record 1747 foo_cg([3, x]) # replay 1748 self.assertEqual(COUNTER, 4) 1749 1750 def test_forward_generation(self): 1751 def foo(x): 1752 return x * x * x 1753 1754 def foo2(x): 1755 return x * 12 1756 1757 foo_opt = 
torch.compile(foo) 1758 foo2_opt = torch.compile(foo2) 1759 ones = torch.ones([4, 4], device="cuda", requires_grad=True) 1760 1761 out = foo_opt(ones) 1762 out2 = foo2_opt(out) 1763 1764 self.assertEqual(all_live_block_count(), 2) 1765 1766 self.assertTrue(self.get_manager().running_forwards_with_pending_backwards) 1767 1768 out2.sum().backward() 1769 self.assertFalse(self.get_manager().running_forwards_with_pending_backwards) 1770 1771 del out 1772 del out2 1773 1774 foo2_opt(foo_opt(ones)).sum().backward() 1775 1776 out = foo_opt(ones.detach()) 1777 self.assertFalse(self.get_manager().running_forwards_with_pending_backwards) 1778 self.assertFalse(self.get_manager().new_graph_id().id == 0) 1779 1780 def test_warn_on_pending_backward(self): 1781 @torch.compile 1782 def foo(x): 1783 return x * x * x 1784 1785 out = foo(torch.rand([4, 4], device="cuda", requires_grad=True)) 1786 out = foo(torch.rand([4, 4], device="cuda", requires_grad=True)) 1787 1788 warnings.resetwarnings() 1789 with warnings.catch_warnings(record=True) as w: 1790 out = foo(torch.rand([4, 4], device="cuda", requires_grad=True)) 1791 1792 FileCheck().check( 1793 "Unable to hit fast path of CUDAGraphs because of pending" 1794 ).run(str(w[0])) 1795 self.assertTrue(self.get_manager().new_graph_id().id == 0) 1796 1797 def test_mark_step(self): 1798 @torch.compile 1799 def foo(x): 1800 return x * x * x 1801 1802 torch.compiler.cudagraph_mark_step_begin() 1803 out = foo(torch.rand([4, 4], device="cuda", requires_grad=True)) 1804 1805 torch.compiler.cudagraph_mark_step_begin() 1806 out = foo(torch.rand([4, 4], device="cuda", requires_grad=True)) 1807 self.assertFalse(self.get_manager().new_graph_id().id == 0) 1808 1809 @torch._dynamo.config.patch("capture_scalar_outputs", True) 1810 def test_incompatible_cudagraph_ops_item(self): 1811 @torch.compile(mode="reduce-overhead") 1812 def foo(x): 1813 return x.item() 1814 1815 # NB: This doesn't work with float, because float unbacked codegen 1816 # is currently broken. But testing the float case here is also 1817 # awkward, because we plan to Tensor-ify the float compute, and as 1818 # a result we'd actually expect this to work with cuda graphs! 1819 with capture_stderr() as captured_output: 1820 self.assertEqual(foo(torch.tensor(3, device="cuda")), 3) 1821 self.assertEqual(foo(torch.tensor(6, device="cuda")), 6) 1822 1823 # NOTE: this test is named after incompatible ops, but is not skipping due to incompatible ops. 1824 # This should get fixed. 
1825 FileCheck().check( 1826 "skipping cudagraphs due to cpu device (_local_scalar_dense)" 1827 ).run(captured_output[0]) 1828 self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) 1829 1830 @torch._dynamo.config.patch("compiled_autograd", True) 1831 def test_compiled_autograd_static_input_params(self): 1832 @torch.compile(mode="reduce-overhead") 1833 def bwd(loss): 1834 loss.backward() 1835 1836 model = torch.nn.Linear(10, 10, bias=False, device="cuda") 1837 x = torch.randn(10, 10, device="cuda") 1838 for i in range(5): 1839 out = model(x) 1840 bwd(out.sum()) 1841 model.weight.grad = None 1842 1843 # i=0, 0 copies (warmup) 1844 # i=1, 2 copies (record, 1/3 inputs marked as static) 1845 # i>1, 0 copies (run) 1846 self.assertEqual( 1847 counters["inductor"]["cudagraph_recorded_non_static_inputs"], 2 1848 ) 1849 1850 @torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True) 1851 def test_incompatible_cudagraph_ops_nonzero(self): 1852 @torch.compile(mode="reduce-overhead") 1853 def foo(x): 1854 return x.nonzero() 1855 1856 with capture_stderr() as captured_output: 1857 self.assertEqual( 1858 foo(torch.tensor([1, 0, 2], device="cuda")), 1859 torch.tensor([[0], [2]]), 1860 ) 1861 self.assertEqual( 1862 foo(torch.tensor([1, 0, 0], device="cuda")), torch.tensor([[0]]) 1863 ) 1864 1865 FileCheck().check("skipping cudagraphs due to ['incompatible ops']").run( 1866 captured_output[0] 1867 ) 1868 self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) 1869 1870 @torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True) 1871 def test_incompatible_cudagraph_ops_nonzero_graph_breaks(self): 1872 @torch.compile(mode="reduce-overhead") 1873 def foo(x): 1874 y = x.nonzero() # skip 1875 torch._dynamo.graph_break() 1876 return y.nonzero() # skip 2 times (due to recompile) 1877 1878 foo(torch.tensor([1, 0, 2], device="cuda")) 1879 foo(torch.tensor([1, 0, 0], device="cuda")) 1880 1881 self.assertEqual(counters["inductor"]["cudagraph_skips"], 3) 1882 1883 @torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True) 1884 def test_incompatible_cudagraph_ops_nonzero_backend(self): 1885 @torch.compile(backend="cudagraphs") 1886 def foo(x): 1887 return x.nonzero() 1888 1889 with capture_stderr() as captured_output: 1890 self.assertEqual( 1891 foo(torch.tensor([1, 0, 2], device="cuda")), 1892 torch.tensor([[0], [2]]), 1893 ) 1894 self.assertEqual( 1895 foo(torch.tensor([1, 0, 0], device="cuda")), torch.tensor([[0]]) 1896 ) 1897 1898 FileCheck().check( 1899 "skipping cudagraphs due to incompatible op (nonzero)" 1900 ).run(captured_output[0]) 1901 self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) 1902 1903 def test_storage_access_error(self): 1904 x = torch.rand([4], device="cuda") 1905 torch._C._set_storage_access_error_msg(x, "custom error msg") 1906 1907 with self.assertRaisesRegex(Exception, "custom error msg"): 1908 device = x.untyped_storage() 1909 1910 @torch._dynamo.config.patch("inline_inbuilt_nn_modules", False) 1911 @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False) 1912 def test_static_inputs_address_mutation_log(self): 1913 class Goo(torch.nn.Module): 1914 def __init__(self) -> None: 1915 super().__init__() 1916 self.linear = torch.nn.Linear(2, 2, device="cuda") 1917 1918 def forward(self, x) -> torch.Tensor: 1919 return self.linear(x) 1920 1921 class Foo(torch.nn.Module): 1922 def __init__(self) -> None: 1923 super().__init__() 1924 self.static_tensor = torch.zeros((2, 2), device="cuda") 1925 self.goo = Goo() 1926 1927 def 
        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
        @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False)
        def test_static_inputs_address_mutation_log(self):
            class Goo(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.linear = torch.nn.Linear(2, 2, device="cuda")

                def forward(self, x) -> torch.Tensor:
                    return self.linear(x)

            class Foo(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.static_tensor = torch.zeros((2, 2), device="cuda")
                    self.goo = Goo()

                def forward(self, x) -> torch.Tensor:
                    self.static_tensor.add_(torch.ones((2, 2), device="cuda"))
                    return self.static_tensor + x + self.goo(x)

            foo = Foo()
            foo = torch.compile(foo, mode="reduce-overhead")
            inp = torch.rand((2, 2), device="cuda")

            for _ in range(3):
                foo(inp)

            # mutates static input tensors' addresses
            foo.static_tensor = torch.ones((2, 2), device="cuda")
            foo.goo.linear.bias = torch.nn.Parameter(torch.ones((2,), device="cuda"))

            with self.assertRaisesRegex(
                Exception,
                r"static input data pointer changed.\n"
                r"input name: primals_2. data pointer changed from .* to .*. input stack trace:(?s).*"
                r"input name: primals_3. data pointer changed from .* to .*. input stack trace:.*,"
                r" in forward\n.* self.static_tensor.add\_\(torch.ones\(\(2, 2\), device=\"cuda\"\)\).*\n",
            ):
                self.curr_node().run(
                    [foo.goo.linear.weight, foo.goo.linear.bias, foo.static_tensor, inp]
                )

        def _run_iter(self, param, fn):
            fwd_output = fn(torch.ones(2, 2), param)
            fwd_output.sum().backward()
            grad_output = param.grad.clone().detach()
            param.grad = None
            return fwd_output, grad_output

        def _assert_equal_multi_loop(self, param, fn_eager, fn_compiled):
            exp_output, exp_grad = self._run_iter(param, fn_eager)
            for _ in range(5):
                compiled_output, compiled_grad = self._run_iter(param, fn_compiled)
                self.assertEqual(exp_output, compiled_output)
                self.assertEqual(exp_grad, compiled_grad)

        def run_static_input_param_test(self, fn_eager, num_graphs):
            with torch.device("cuda"):
                fn_compiled = torch.compile(fn_eager, mode="reduce-overhead")

                p1 = torch.nn.Parameter(torch.rand([2, 2]))
                self._assert_equal_multi_loop(p1, fn_eager, fn_compiled)

                p2 = torch.nn.Parameter(torch.rand([2, 2]))
                self._assert_equal_multi_loop(p2, fn_eager, fn_compiled)

                # Run p1 again to ensure we reuse the previous recording
                self._assert_equal_multi_loop(p1, fn_eager, fn_compiled)

                self.assertEqual(self.get_manager().new_graph_id().id, num_graphs)

        def _module_test(self, mod, name="weight", param_wrapping=True):
            with torch.device("cuda"):

                def fn(x, mod):
                    return mod(x)

                fn_compiled = torch.compile(fn, mode="reduce-overhead", fullgraph=True)

                def run_test_iter(mod, fn):
                    fwd_output = fn(torch.ones(2, 2), mod)
                    fwd_output.sum().backward()
                    grad_output = mod.weight.grad.clone().detach()
                    mod.zero_grad()
                    return fwd_output, grad_output

                def run_test():
                    exp_output, exp_grad = run_test_iter(mod, fn)
                    for _ in range(5):
                        compiled_output, compiled_grad = run_test_iter(mod, fn_compiled)
                        self.assertEqual(exp_output, compiled_output)
                        self.assertEqual(exp_grad, compiled_grad)

                run_test()
                old_attr = getattr(mod, name)
                modified_attr = torch.rand_like(old_attr)
                if param_wrapping:
                    modified_attr = torch.nn.Parameter(modified_attr)
                setattr(mod, name, modified_attr)
                run_test()
                # Run original version to verify we reuse the other recording
                setattr(mod, name, old_attr)
                run_test()

            # Fwd + bwd graphs for each version of the function => 4 graphs
            self.assertEqual(self.get_manager().new_graph_id().id, 4)

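        # The "multi dispatch" tests below use the helpers above to verify that a
        # single compiled function can dispatch to multiple recorded cudagraphs when
        # the address of a parameter or buffer input changes, without triggering a
        # dynamo recompile (error_on_recompile is enabled for these tests).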
        @torch._dynamo.config.patch("error_on_recompile", True)
        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
        def test_multi_dispatch_single_compile_param_inputs(self):
            # Verify that we can record multiple cudagraphs for a single
            # compiled function with param inputs
            def fn(x, y):
                return x * y

            # Fwd + bwd graphs for each version of the function => 4 graphs
            self.run_static_input_param_test(fn, 4)

        @torch._dynamo.config.patch("error_on_recompile", True)
        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
        def test_multi_dispatch_single_compile_builtin_module(self):
            # Verify that we don't recompile when changing the param of a builtin module
            # and that we record another cudagraph.
            # Note: Linear is a builtin module, so we enable inline_inbuilt_nn_modules above
            self._module_test(torch.nn.Linear(2, 3, device="cuda"))

        @torch._dynamo.config.patch("error_on_recompile", True)
        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
        def test_multi_dispatch_single_compile_builtin_module_buffers(self):
            # Verify that we don't recompile when changing the buffer of a builtin module
            # and that we record another cudagraph
            self._module_test(
                torch.nn.BatchNorm1d(2, device="cuda"),
                name="running_mean",
                param_wrapping=False,
            )

        @torch._inductor.config.patch("triton.cudagraphs", True)
        @torch._dynamo.config.patch("error_on_recompile", True)
        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
        def test_multi_dispatch_custom_module(self):
            # Test that we can correctly dispatch multiple graphs
            # if params of a custom module change
            class TestModule(torch.nn.Module):
                def __init__(self, param) -> None:
                    super().__init__()
                    self.weight = param

                def forward(self, x):
                    return self.weight * x

            self._module_test(
                TestModule(torch.nn.Parameter(torch.rand([2, 2], device="cuda")))
            )

        @torch._dynamo.config.patch("error_on_recompile", True)
        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
        def test_multi_dispatch_custom_module_buffer(self):
            # Test that we can correctly dispatch multiple graphs
            # if buffers of a custom module change
            class TestModule(torch.nn.Module):
                def __init__(self, param, buf) -> None:
                    super().__init__()
                    self.weight = param
                    self.buf = torch.nn.Buffer(buf)

                def forward(self, x):
                    return x * self.weight + self.buf

            self._module_test(
                TestModule(
                    torch.nn.Parameter(torch.rand([2, 2], device="cuda")),
                    torch.rand([2, 2], device="cuda"),
                ),
                name="buf",
                param_wrapping=False,
            )

        @torch._inductor.config.patch("triton.cudagraphs", True)
        @torch._dynamo.config.patch("error_on_recompile", True)
        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
        def test_multi_dispatch_child_node(self):
            # Test that we can correctly dispatch multiple graphs if the stable input
            # pointers of a child node in the tree change
            def fn(x, p):
                # Graph 1
                y = x * x
                torch._dynamo.graph_break()
                # Graph 2
                return y * p

            # We have 5 graphs here
            #             Graph 1
            #            /       \
            #  Graph 2 w/ p1   Graph 2 w/ p2
            # and then two backward graphs
            self.run_static_input_param_test(fn, 5)

        @torch._dynamo.config.patch("error_on_recompile", True)
        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
        def test_multi_dispatch_parent_node(self):
            def fn(x, p):
                # Graph 1
                y = x * p
                torch._dynamo.graph_break()
                # Graph 2
                return y + x

            # We have 6 graphs here
            #  Graph 1 w/ p1    Graph 1 w/ p2
            #        |                |
            #   Graph 2 (v1)     Graph 2 (v2)
            # There are two versions of Graph 2 because we re-record it due to the
            # different memory state after running the two versions of Graph 1,
            # and then two backward graphs
            self.run_static_input_param_test(fn, 6)

        @torch._dynamo.config.patch("error_on_recompile", True)
        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
        @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
        @torch._inductor.config.patch("triton.cudagraph_unexpected_rerecord_limit", 0)
        def test_fallback_to_eager_if_recompiling_too_many_times(self):
            class Foo(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.param = torch.nn.Parameter(torch.rand([2, 2], device="cuda"))

                def forward(self, x):
                    return x * self.param

            with capture_stderr() as captured_output:
                # We have 3 graphs here
                #             None
                #      /               \
                # (fwd w/ p1, Graph 0)  (bwd w/ p2, Graph 2)
                # (bwd w/ p1, Graph 1)
                # All other graphs are skipped because we hit the max recording limit
                # (=0 for each node and function pair)
                fn_compiled = torch.compile(Foo(), mode="reduce-overhead")
                for _ in range(3):
                    fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward()

                # Change static tensor address
                fn_compiled.param.data = torch.rand([2, 2], device="cuda")
                fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward()
                self.assertEqual(self.get_manager().new_graph_id().id, 3)

            FileCheck().check(
                "skipping cudagraph due to function 0 exceeding max re-recording limit (=0) "
                "on cudagraph node None due to static input data pointer changed."
            ).run(captured_output[0])
            self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)

        @torch._dynamo.config.patch("error_on_recompile", True)
        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
        @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
        @torch._inductor.config.patch("triton.cudagraph_unexpected_rerecord_limit", 0)
        def test_fallback_to_eager_if_recompiling_too_many_times_warn_only_once(self):
            class Foo(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.param = torch.nn.Parameter(torch.rand([2, 2], device="cuda"))

                def forward(self, x):
                    return x * self.param

            with capture_stderr() as captured_output:
                with torch.device("cuda"):
                    # We have 3 graphs here
                    #             None
                    #      /               \
                    # (fwd w/ p1, Graph 0)  (bwd w/ p2, Graph 2)
                    # (bwd w/ p1, Graph 1)
                    # All other graphs are skipped because we hit the max recording limit
                    # (=0 for each node and function pair)
                    fn_compiled = torch.compile(Foo(), mode="reduce-overhead")
                    for _ in range(3):
                        fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward()

                    for _ in range(5):
                        # Change static tensor address
                        fn_compiled.param.data = torch.rand([2, 2], device="cuda")
                        fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward()

            FileCheck().check_count(
                "skipping cudagraph due to function 0 exceeding max re-recording limit (=0) "
                "on cudagraph node None due to static input data pointer changed.",
                1,
                exactly=True,
            ).check_count(
                "skipping cudagraph due to function 1 exceeding max re-recording limit (=0) "
                "on cudagraph node None due to static input data pointer changed.",
                1,
                exactly=True,
            ).run(
                captured_output[0]
            )
            self.assertEqual(counters["inductor"]["cudagraph_skips"], 2)

        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", False)
        @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
        @torch._inductor.config.patch("triton.cudagraph_unexpected_rerecord_limit", 0)
        def test_fallback_to_eager_if_recompiling_too_many_times_due_to_cudagraph_managed_tensor(
            self,
        ):
            # By setting triton.cudagraph_support_input_mutation=True, we force re-record
            # if cudagraph managed tensor addresses changed.
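            # Here, y = foo(x) below is a cudagraph managed tensor: it lives in the
            # cudagraph memory pool of foo's graph. Cloning it before passing it to goo
            # moves the data out of that pool, so goo sees a different data pointer on
            # each iteration and, with the re-record limit at 0, falls back to eager.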
            @torch.compile(mode="reduce-overhead")
            def foo(x):
                return x + 1

            @torch.compile(mode="reduce-overhead")
            def goo(x):
                return x * 2

            for _ in range(3):
                torch.compiler.cudagraph_mark_step_begin()
                inp = torch.rand((2, 3), device="cuda")
                y = foo(inp)
                z = goo(y)

            with capture_stderr() as captured_output:
                torch.compiler.cudagraph_mark_step_begin()
                x = torch.rand(2, 3, device="cuda")
                y = foo(x)
                y_clone = y.clone()
                z = goo(y_clone)

            # eager function should run successfully
            for _ in range(5):
                torch.compiler.cudagraph_mark_step_begin()
                x = torch.rand(2, 3, device="cuda")
                y = foo(x)
                y_clone = y.clone()
                z = goo(y_clone)

            FileCheck().check_count(
                "skipping cudagraph due to function 1 exceeding max re-recording limit (=0) "
                "on cudagraph node 0 due to cudagraph managed tensor data pointer changed",
                1,
                exactly=True,
            ).run(captured_output[0])
            self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)

        @torch._dynamo.config.patch("error_on_recompile", True)
        @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
        @torch._inductor.config.patch("triton.cudagraph_unexpected_rerecord_limit", 1)
        def test_not_fallback_to_eager_if_have_not_recompiling_too_many_times(self):
            def fn(x, y):
                return x * y

            # We have 4 graphs here
            #             None
            #      /               \
            # (fwd w/ p1, Graph 0)  (fwd w/ p2, Graph 2)
            # (bwd w/ p1, Graph 1)  (bwd w/ p2, Graph 3)
            self.run_static_input_param_test(fn, 4)
            self.assertEqual(counters["inductor"]["cudagraph_skips"], 0)

        def test_tensor_constant_mutation(self):
            class Foo(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.tensor_constant = torch.ones((2, 3), device="cuda")

                def forward(self, x: torch.Tensor) -> torch.Tensor:
                    self.tensor_constant += 1
                    return x + self.tensor_constant

            foo = Foo()
            foo = torch.compile(foo, mode="reduce-overhead")
            inp = torch.rand((2, 3), device="cuda")
            for _ in range(3):
                foo(inp)

        @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
        def test_rerecord_if_static_input_address_changed(self):
            # By setting triton.cudagraph_support_input_mutation=True, we force re-record
            # if static tensor addresses changed.
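            # Unlike test_static_inputs_address_mutation_log above, which runs with
            # cudagraph_support_input_mutation=False and expects an error, changing the
            # static input addresses here triggers a re-record, so the manager ends up
            # with two recorded graphs (see the new_graph_id assertion at the end).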
            class Goo(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.linear = torch.nn.Linear(2, 2, device="cuda")

                def forward(self, x) -> torch.Tensor:
                    return self.linear(x)

            class Foo(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.register_buffer(
                        "static_tensor", torch.zeros((2, 2), device="cuda")
                    )
                    self.goo = Goo()

                def forward(self, x) -> torch.Tensor:
                    self.static_tensor.add_(torch.ones((2, 2), device="cuda"))
                    return self.static_tensor + x + self.goo(x)

            foo = Foo()
            foo = torch.compile(foo, mode="reduce-overhead")
            inp = torch.rand((2, 2), device="cuda")

            for _ in range(3):
                foo(inp)

            # mutates static input tensors' addresses
            foo.static_tensor = torch.ones((2, 2), device="cuda")
            foo.goo.linear.bias = torch.nn.Parameter(torch.ones((2,), device="cuda"))

            if torch._dynamo.config.inline_inbuilt_nn_modules:
                for _ in range(3):
                    foo(inp)
            else:
                # Run with specific function id to avoid dynamo recompiling
                self.get_manager().run(
                    [
                        foo.goo.linear.weight,
                        foo.goo.linear.bias,
                        foo.static_tensor,
                        inp,
                    ],
                    FunctionID(0),
                )

            self.assertEqual(self.get_manager().new_graph_id().id, 2)

        @torch._inductor.config.patch("triton.cudagraph_dynamic_shape_warn_limit", 1)
        def test_skip_if_dynamic_shape_limit_reached1(self):
            class Mod(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.linear = torch.nn.Linear(3, 3, device="cuda")

                def forward(self, x: torch.Tensor) -> torch.Tensor:
                    return self.linear(x)

            def iter(batch_size: int, mod: torch.nn.Module):
                x = torch.rand((batch_size, 3), device="cuda")
                for _ in range(3):
                    mod(x)

            mod = torch.compile(Mod(), mode="reduce-overhead")

            with capture_stderr() as captured_output:
                for batch_size in range(10, 40, 10):
                    iter(batch_size, mod)

            FileCheck().check(
                "CUDAGraph supports dynamic shapes by recording a new graph for each "
                "distinct input size. Recording too many CUDAGraphs may lead to "
                "extra overhead. We have observed 2 distinct sizes. "
                "Please consider the following options for better performance: "
                "a) padding inputs to a few fixed number of shapes; or b) set "
                "torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True. "
                "Set torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit=None "
                "to silence this warning."
            ).run("\n".join(captured_output))

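        # The warning above offers "padding inputs to a few fixed number of shapes"
        # as option (a). A minimal, illustrative sketch of that idea (not used by the
        # tests in this file; the bucket sizes are arbitrary):
        #
        #     def pad_batch_to_bucket(x, buckets=(16, 32, 64)):
        #         target = next(b for b in buckets if b >= x.shape[0])
        #         out = x.new_zeros((target, *x.shape[1:]))
        #         out[: x.shape[0]] = x
        #         return out
        #
        # Padding this way means each bucket is recorded at most once, instead of
        # recording one graph per distinct batch size.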
        @torch._inductor.config.patch("triton.cudagraph_dynamic_shape_warn_limit", 1)
        def test_skip_if_dynamic_shape_limit_reached2(self):
            class Mod(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.attn = torch.nn.MultiheadAttention(
                        embed_dim=3, num_heads=3, device="cuda"
                    )

                def forward(
                    self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
                ) -> torch.Tensor:
                    return self.attn(q, k, v)

            mod = torch.compile(Mod(), mode="reduce-overhead")

            def iter(batch_size: int, length: int):
                q = torch.rand((batch_size, length, 3), device="cuda")
                k = torch.rand((batch_size, length, 3), device="cuda")
                v = torch.rand((batch_size, length, 3), device="cuda")
                for _ in range(3):
                    mod(q, k, v)

            with capture_stderr() as captured_output:
                for batch_size in range(10, 40, 10):
                    for length in range(10, 30, 10):
                        iter(batch_size, length)

            print(captured_output)
            FileCheck().check(
                "CUDAGraph supports dynamic shapes by recording a new graph for each "
                "distinct input size. Recording too many CUDAGraphs may lead to "
                "extra overhead. We have observed 2 distinct sizes. "
                "Please consider the following options for better performance: "
                "a) padding inputs to a few fixed number of shapes; or b) set "
                "torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True. "
                "Set torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit=None "
                "to silence this warning."
            ).run(captured_output[0])

        @torch._inductor.config.patch("triton.cudagraph_dynamic_shape_warn_limit", 1)
        def test_warn_once_if_dynamic_shape_limit_reached(self):
            class Mod(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.linear = torch.nn.Linear(3, 3, device="cuda")

                def forward(self, x: torch.Tensor) -> torch.Tensor:
                    return self.linear(x)

            def iter(batch_size: int, mod: torch.nn.Module):
                x = torch.rand((batch_size, 3), device="cuda")
                for _ in range(3):
                    mod(x)

            mod = torch.compile(Mod(), mode="reduce-overhead")

            with capture_stderr() as captured_output:
                for batch_size in range(10, 200, 10):
                    iter(batch_size, mod)

            print(captured_output)

            FileCheck().check_count(
                "CUDAGraph supports dynamic shapes by recording a new graph for each "
                "distinct input size. Recording too many CUDAGraphs may lead to "
                "extra overhead. We have observed 2 distinct sizes. "
                "Please consider the following options for better performance: "
                "a) padding inputs to a few fixed number of shapes; or b) set "
                "torch._inductor.config.triton.cudagraph_skip_dynamic_graphs=True. "
" 2442 "Set torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit=None " 2443 "to silence this warning.", 2444 1, 2445 exactly=True, 2446 ).run("\n".join(captured_output)) 2447 2448 @torch._inductor.config.patch("cpp_wrapper", 1) 2449 def test_cpp_wrapper(self): 2450 def f(x): 2451 return torch.sin(x) 2452 2453 compiled = torch.compile(f, mode="reduce-overhead") 2454 example_input = torch.randn(10, device="cuda") 2455 compiled_result = self.run_twc(compiled, example_input) 2456 eager_result = f(example_input) 2457 self.assertEqual(compiled_result, eager_result) 2458 2459 instantiate_parametrized_tests(CudaGraphTreeTests) 2460 2461if __name__ == "__main__": 2462 from torch._inductor.test_case import run_tests 2463 2464 if not TEST_CUDA_GRAPH: 2465 if __name__ == "__main__": 2466 sys.exit(0) 2467 raise unittest.SkipTest("cuda graph test is skipped") 2468 2469 if HAS_CPU or HAS_CUDA: 2470 run_tests(needs="filelock") 2471