1# Owner(s): ["module: inductor"] 2import contextlib 3import dataclasses 4import functools 5import io 6import itertools 7import logging 8import os 9import re 10import subprocess 11import sys 12import unittest 13from importlib.machinery import SourceFileLoader 14from pathlib import Path 15from unittest import mock 16 17import torch 18import torch.nn as nn 19import torch.nn.functional as F 20from torch import _inductor as inductor 21from torch._dynamo import compiled_autograd, config 22from torch._dynamo.backends.debugging import aot_eager 23from torch._dynamo.utils import counters 24from torch._inductor import config as inductor_config 25from torch._inductor.test_case import run_tests, TestCase 26from torch.testing._internal.common_utils import skipIfWindows 27from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA 28from torch.testing._internal.logging_utils import logs_to_string 29 30 31# note: these tests are not run on windows due to inductor_utils.HAS_CPU 32 33 34def make_compiler_fn(fullgraph=True, dynamic=True, backend="inductor"): 35 assert backend in ["inductor", "aot_eager"] 36 37 def _compiler_fn(gm): 38 """Same as torch.compile() but counts number of compiles""" 39 40 def _inner_compiler(gm_, example_inputs_): 41 counters["compiled_autograd"]["compiles"] += 1 42 if backend == "inductor": 43 return inductor.compile(gm_, example_inputs_) 44 elif backend == "aot_eager": 45 return aot_eager(gm_, example_inputs_) 46 47 return torch.compile( 48 gm, backend=_inner_compiler, fullgraph=fullgraph, dynamic=dynamic 49 ) 50 51 return _compiler_fn 52 53 54compiler_fn = make_compiler_fn() 55 56 57# TODO(jansel): hooks as lambdas creates recompiles in dynamo, we should fix that 58def hook1(grad): 59 return grad * 2 60 61 62def hook2(grads): 63 return (grads[0] + 1,) 64 65 66def hook3(gI, gO): 67 return (torch.sin(gI[0]) + gO[0],) 68 69 70class TestCompiledAutograd(TestCase): 71 def setUp(self) -> None: 72 super().setUp() 73 torch._logging.set_logs(compiled_autograd_verbose=False) 74 config.compiled_autograd = False 75 compiled_autograd.reset() 76 77 def tearDown(self) -> None: 78 super().tearDown() 79 torch._logging.set_logs(compiled_autograd_verbose=False) 80 config.compiled_autograd = False 81 compiled_autograd.reset() 82 83 def check_output_and_recompiles( 84 self, fn, count=1, compiler_fn=compiler_fn, compile_fn=False 85 ): 86 if isinstance(count, list): 87 captures, compiles = count 88 else: 89 captures, compiles = count, count 90 with torch.autograd.set_multithreading_enabled(False): 91 torch._dynamo.reset() 92 counters["compiled_autograd"].clear() 93 torch.manual_seed(123) 94 expected = list(fn()) 95 torch.manual_seed(123) 96 with compiled_autograd.enable(compiler_fn): 97 opt_fn = torch.compile(fn) if compile_fn else fn 98 actual = list(opt_fn()) 99 self.assertEqual(expected, actual) 100 self.assertEqual(counters["compiled_autograd"]["captures"], captures) 101 self.assertEqual(counters["compiled_autograd"]["compiles"], compiles) 102 103 def run_as_subprocess(self, script) -> bytes: 104 try: 105 return subprocess.check_output( 106 [sys.executable, "-c", script], 107 stderr=subprocess.STDOUT, 108 # On Windows, opening the subprocess with the default CWD makes `import torch` 109 # fail, so just set CWD to this script's directory 110 cwd=os.path.dirname(os.path.realpath(__file__)), 111 ) 112 except subprocess.CalledProcessError as e: 113 self.fail(f"Subprocess exited with return code: {e.returncode}") 114 115 def test_dynamo_flaky_segfault(self): 116 script = """ 117import torch 
118 119def main(): 120 def compiler_fn(gm): 121 return torch.compile(gm, backend="eager") 122 123 def inner(): 124 x = torch.randn(1000, 3000) 125 w = torch.randn(1000, 3000, requires_grad=True) 126 def model(i): 127 return torch.nn.functional.linear(i, w) 128 out = model(x) 129 loss = out.sum() 130 with torch._dynamo.compiled_autograd.enable(compiler_fn): 131 loss.backward() 132 assert(w.grad is not None) 133 134 inner() 135 torch._dynamo.reset() 136 inner() 137 138main() 139 """ 140 # Run it three times to catch bad dynamo state resets 141 for _ in range(3): 142 self.run_as_subprocess(script) 143 144 def test_basic(self): 145 def fn(): 146 model = torch.nn.Sequential( 147 torch.nn.Linear(4, 4), 148 torch.nn.ReLU(), 149 torch.nn.Linear(4, 4), 150 torch.nn.ReLU(), 151 ) 152 x = torch.randn([2, 4]) 153 result = model(x).sum() 154 result.backward() 155 yield model[0].weight.grad 156 yield model[0].bias.grad 157 yield model[2].weight.grad 158 yield model[2].bias.grad 159 160 self.check_output_and_recompiles(fn) 161 162 def test_cache_hit(self): 163 def fn(): 164 for _ in range(3): 165 model = torch.nn.Sequential( 166 torch.nn.Linear(4, 4), 167 torch.nn.ReLU(), 168 torch.nn.Linear(4, 4), 169 torch.nn.ReLU(), 170 ) 171 x = torch.randn([2, 4]) 172 result = model(x).sum() 173 result.backward() 174 yield model[0].weight.grad 175 yield model[0].bias.grad 176 yield model[2].weight.grad 177 yield model[2].bias.grad 178 179 self.check_output_and_recompiles(fn) 180 181 def test_graph_break_custom_op(self): 182 @torch.library.custom_op("mylib::sin", mutates_args={}) 183 def sin(x: torch.Tensor) -> torch.Tensor: 184 return x.sin() 185 186 def setup_context(ctx, inputs, output): 187 (x,) = inputs 188 ctx.save_for_backward(x) 189 190 def backward(ctx, grad): 191 (x,) = ctx.saved_tensors 192 return grad * x.cos() 193 194 sin.register_autograd(backward, setup_context=setup_context) 195 196 x = torch.randn(3, requires_grad=True) 197 y = sin(x.clone()).sum() 198 with compiled_autograd.enable(compiler_fn): 199 y.backward() 200 201 def test_tensor_grad_hook1(self): 202 def fn(): 203 for _ in range(3): 204 model = torch.nn.Sequential( 205 torch.nn.Linear(4, 4), 206 torch.nn.ReLU(), 207 ) 208 x = torch.randn([2, 4]) 209 210 model[0].weight.register_hook(hook1) 211 212 result = model(x).sum() 213 result.backward() 214 yield model[0].weight.grad 215 yield model[0].bias.grad 216 217 self.check_output_and_recompiles(fn) 218 219 def test_tensor_grad_hook2(self): 220 def fn(): 221 for _ in range(3): 222 model = torch.nn.Sequential( 223 torch.nn.Linear(4, 4), 224 torch.nn.ReLU(), 225 ) 226 x = torch.randn([1, 4]) 227 228 result = model(x).sum() 229 result.grad_fn.register_prehook(hook2) 230 result.backward() 231 yield model[0].weight.grad 232 yield model[0].bias.grad 233 234 self.check_output_and_recompiles(fn) 235 236 def test_tensor_grad_hook3(self): 237 def fn(): 238 for _ in range(3): 239 model = torch.nn.Sequential( 240 torch.nn.Linear(4, 4), 241 torch.nn.ReLU(), 242 ) 243 x = torch.randn([1, 4]) 244 245 result = model(x).sum() 246 result.grad_fn.register_hook(hook3) 247 result.backward() 248 yield model[0].weight.grad 249 yield model[0].bias.grad 250 251 self.check_output_and_recompiles(fn) 252 253 def test_torch_compile(self): 254 def fn(): 255 model = torch.nn.Sequential( 256 torch.nn.Linear(4, 4), 257 torch.nn.Sigmoid(), 258 ) 259 opt_model = torch.compile(model, fullgraph=True) 260 261 for _ in range(3): 262 x = torch.randn([1, 4]) 263 264 result = opt_model(x).sum() 265 result.backward() 266 yield 
model[0].weight.grad 267 yield model[0].bias.grad 268 model.zero_grad() 269 270 self.check_output_and_recompiles(fn) 271 272 def test_torch_compile_api_inductor(self): 273 def fn(): 274 torch.manual_seed(123) 275 model = torch.nn.Sequential( 276 torch.nn.Linear(4, 4), 277 torch.nn.Sigmoid(), 278 ) 279 280 res = [] 281 for _ in range(3): 282 x = torch.randn([1, 4]) 283 284 result = model(x).sum() 285 result.backward() 286 res.append(model[0].weight.grad) 287 res.append(model[0].bias.grad) 288 model.zero_grad() 289 return res 290 291 expected = fn() 292 with config.patch(compiled_autograd=True): 293 compiled_fn = torch.compile(fn) 294 actual = compiled_fn() 295 self.assertEqual(expected, actual) 296 self.assertEqual(counters["compiled_autograd"]["captures"], 1) 297 298 def test_torch_compile_api_aot_eager(self): 299 def fn(): 300 torch.manual_seed(123) 301 model = torch.nn.Sequential( 302 torch.nn.Linear(4, 4), 303 torch.nn.Sigmoid(), 304 ) 305 306 res = [] 307 for _ in range(3): 308 x = torch.randn([1, 4]) 309 310 result = model(x).sum() 311 result.backward() 312 res.append(model[0].weight.grad) 313 res.append(model[0].bias.grad) 314 model.zero_grad() 315 return res 316 317 expected = fn() 318 with config.patch(compiled_autograd=True): 319 compiled_fn = torch.compile(fn, backend="aot_eager") 320 actual = compiled_fn() 321 self.assertEqual(expected, actual) 322 self.assertEqual(counters["compiled_autograd"]["captures"], 1) 323 324 def test_torch_compile_api_eager(self): 325 def fn(): 326 torch.manual_seed(123) 327 model = torch.nn.Sequential( 328 torch.nn.Linear(4, 4), 329 torch.nn.Sigmoid(), 330 ) 331 332 res = [] 333 for _ in range(3): 334 x = torch.randn([1, 4]) 335 336 result = model(x).sum() 337 result.backward() 338 res.append(model[0].weight.grad) 339 res.append(model[0].bias.grad) 340 model.zero_grad() 341 return res 342 343 expected = fn() 344 with config.patch(compiled_autograd=True): 345 compiled_fn = torch.compile(fn, backend="eager") 346 actual = compiled_fn() 347 self.assertEqual(expected, actual) 348 self.assertEqual(counters["compiled_autograd"]["captures"], 1) 349 350 def test_multiple_torch_compile(self): 351 model = torch.nn.Sequential( 352 torch.nn.Linear(4, 4), 353 torch.nn.Sigmoid(), 354 ) 355 x = torch.randn([1, 4]) 356 357 def fn(): 358 result = model(x).sum() 359 result.backward() 360 361 model2 = torch.nn.Linear(4, 4) 362 x2 = torch.randn([1, 4]) 363 364 def fn2(): 365 result = model2(x2).sum() 366 result.backward() 367 368 no_ca1 = torch.compile(fn) 369 no_ca1() 370 self.assertEqual(counters["compiled_autograd"]["captures"], 0) 371 counters.clear() 372 373 with config.patch(compiled_autograd=True): 374 with_ca = torch.compile(fn2) 375 with_ca() 376 self.assertEqual(counters["compiled_autograd"]["captures"], 1) 377 counters.clear() 378 379 no_ca2 = torch.compile(fn) 380 no_ca2() 381 self.assertEqual(counters["compiled_autograd"]["captures"], 0) 382 383 def test_torch_compile_graph_break(self): 384 model = torch.nn.Sequential( 385 torch.nn.Linear(4, 4), 386 torch.nn.Sigmoid(), 387 ) 388 x = torch.randn([1, 4]) 389 390 @torch._dynamo.disable() 391 def fn(): 392 result = model(x).sum() 393 result.backward() 394 395 with config.patch(compiled_autograd=True): 396 opt_fn = torch.compile(fn) 397 opt_fn() 398 399 self.assertEqual(counters["compiled_autograd"]["captures"], 1) 400 401 def test_torch_compile_graph_break2(self): 402 model = torch.nn.Sequential( 403 torch.nn.Linear(4, 4), 404 torch.nn.Sigmoid(), 405 ) 406 x = torch.randn([1, 4]) 407 408 @torch._dynamo.disable() 
409 def inner_fn(loss): 410 loss.backward() 411 412 def fn(): 413 result = model(x).sum() 414 inner_fn(result) 415 416 with config.patch(compiled_autograd=True): 417 opt_fn = torch.compile(fn) 418 opt_fn() 419 420 self.assertEqual(counters["compiled_autograd"]["captures"], 1) 421 422 def test_torch_compile_only_backward_call(self): 423 model = torch.nn.Sequential( 424 torch.nn.Linear(4, 4), 425 torch.nn.Sigmoid(), 426 ) 427 x = torch.randn([1, 4]) 428 429 result = model(x).sum() 430 with config.patch(compiled_autograd=True): 431 opt_bwd = torch.compile(lambda: result.backward()) 432 opt_bwd() 433 434 self.assertEqual(counters["compiled_autograd"]["captures"], 1) 435 436 def test_dynamo_boxed(self): 437 def get_placeholders(gm_): 438 placeholders = [] 439 for node in gm_.graph.nodes: 440 if node.op == "placeholder": 441 placeholders.append(node) 442 return placeholders 443 444 def eager_with_check(gm, is_bwd): 445 def inner_compiler(gm_, example_inputs_): 446 placeholders = get_placeholders(gm_) 447 if is_bwd: 448 # should be boxed inputs 449 assert len(placeholders) == 1 450 else: 451 assert len(placeholders) > 1 452 453 return gm_ 454 455 return torch.compile(gm, backend=inner_compiler) 456 457 fwd_compiler_fn = functools.partial(eager_with_check, is_bwd=False) 458 bwd_compiler_fn = functools.partial(eager_with_check, is_bwd=True) 459 460 def fn(inputs): 461 args_0, args_1, args_2 = inputs 462 out = torch.mm(args_0, args_1) 463 out = torch.mm(out, args_2) 464 loss = out.sum() 465 with compiled_autograd.enable(bwd_compiler_fn): 466 loss.backward() 467 yield args_0.grad 468 yield args_1.grad 469 yield args_2.grad 470 471 inputs = [ 472 torch.randn([1, 2], requires_grad=True), 473 torch.randn([2, 3], requires_grad=True), 474 torch.randn([3, 4], requires_grad=True), 475 ] 476 477 compiled_fn = eager_with_check(fn, is_bwd=False) 478 grads = list(compiled_fn(inputs)) 479 self.assertEqual(len(grads), 3) 480 self.assertNotEqual(grads[0], None) 481 self.assertNotEqual(grads[1], None) 482 self.assertNotEqual(grads[2], None) 483 484 def test_inputs_aliasing_bytecode_attr_mutations(self): 485 # Freeze compiled autograd graph 486 compiler = torch._dynamo.compiled_autograd.AutogradCompilerInstance(compiler_fn) 487 param = torch.ones(100) 488 activ = torch.ones(100) * 2 489 inputs = [param, activ] 490 proxies, _, _ = compiler.begin_capture(inputs=inputs, sizes=[], scalars=[]) 491 param_proxy, activ_proxy = proxies 492 buf = activ_proxy * 2 493 torch.ops.inductor.accumulate_grad_.default(param_proxy, buf) 494 runtime_wrapper, compiled_fn = compiler.end_capture(buf) 495 496 def bytecode_hook(code, out_code): 497 import dis 498 import sys 499 500 if sys.version_info < (3, 11): 501 call_op = "CALL_FUNCTION" 502 else: 503 call_op = "CALL" 504 505 insts = list(dis.get_instructions(out_code)) 506 call_graph_idx = next( 507 i for i, inst in enumerate(insts) if inst.opname == call_op 508 ) 509 # pre-graph should alias: inputs_ref_0 = inputs[0] 510 matches = [ 511 inst 512 for inst in insts[:call_graph_idx] 513 if inst.opname == "STORE_FAST" and inst.argval == "inputs_ref_0" 514 ] 515 self.assertTrue(len(matches) == 1) 516 # post-graph should access inputs_ref_0 instead of inputs 517 matches = [ 518 inst for inst in insts[call_graph_idx:] if inst.argval == "inputs" 519 ] 520 self.assertTrue(len(matches) == 0) 521 matches = [ 522 inst 523 for inst in insts[call_graph_idx:] 524 if inst.opname == "LOAD_FAST" and inst.argval == "inputs_ref_0" 525 ] 526 self.assertTrue(len(matches) == 1) 527 528 torch._dynamo.reset() 
529 handle = torch._dynamo.convert_frame.register_bytecode_hook(bytecode_hook) 530 try: 531 runtime_wrapper( 532 compiled_fn=compiled_fn, 533 inputs=[param, activ], 534 sizes=(), 535 scalars=(), 536 hooks=(), 537 ) 538 finally: 539 handle.remove() 540 541 def test_inputs_aliasing_bytecode_stack_restore(self): 542 logging.getLogger().setLevel(logging.WARNING) 543 from torch.testing._internal.logging_tensor import LoggingTensor 544 545 # Create a graph that allows inputs stealing 546 def forward(inputs): 547 add = inputs[0] + 1 548 add_1 = add + inputs[1] # handled in suffix for tensor subclass 549 out = add_1.cpu() 550 return (out,) 551 552 gm = torch.fx.symbolic_trace(forward) 553 torch._dynamo.utils.set_locals_to_steal(gm, ["inputs"]) 554 compiled_fn = torch.compile(gm) 555 556 inputs = [ 557 torch.ones(1000000, dtype=torch.float32), 558 LoggingTensor(torch.ones(1)), 559 ] 560 561 def bytecode_hook(code, out_code): 562 import dis 563 import sys 564 565 if sys.version_info < (3, 11): 566 call_op = "CALL_FUNCTION" 567 else: 568 call_op = "CALL" 569 570 insts = list(dis.get_instructions(out_code)) 571 call_graph_idx = next( 572 i for i, inst in enumerate(insts) if inst.opname == call_op 573 ) 574 # pre-graph should alias: inputs_ref_0 = inputs[0] 575 matches = [ 576 inst 577 for inst in insts[:call_graph_idx] 578 if inst.opname == "STORE_FAST" and inst.argval == "inputs_ref_0" 579 ] 580 self.assertTrue(len(matches) == 1) 581 # post-graph should access inputs_ref_0 instead of inputs 582 matches = [ 583 inst for inst in insts[call_graph_idx:] if inst.argval == "inputs" 584 ] 585 self.assertTrue(len(matches) == 0) 586 matches = [ 587 inst 588 for inst in insts[call_graph_idx:] 589 if inst.opname == "LOAD_FAST" and inst.argval == "inputs_ref_0" 590 ] 591 self.assertTrue(len(matches) == 1) 592 593 torch._dynamo.reset() 594 handle = torch._dynamo.convert_frame.register_bytecode_hook(bytecode_hook) 595 try: 596 out = compiled_fn(inputs) 597 self.assertTrue(len(inputs) == 0) 598 finally: 599 handle.remove() 600 601 def test_implicit_add(self): 602 def fn(): 603 y = torch.randn(1, 4, requires_grad=True) 604 605 def model(x): 606 # y is used multiple times, gradients get added 607 return torch.sigmoid(x * y + torch.sin(y) + torch.cos(y)) 608 609 for _ in range(3): 610 x = torch.randn([1, 4]) 611 612 result = model(x).sum() 613 result.backward() 614 yield result 615 yield y.grad 616 y.grad = None 617 618 self.check_output_and_recompiles(fn) 619 620 def test_output_nodes_all_leaves(self): 621 def fn(): 622 y = torch.randn(1, 4, requires_grad=True) 623 z = torch.randn(1, 4, requires_grad=True) 624 625 def model(x): 626 return torch.sigmoid(x * z + torch.sin(y) + torch.cos(y)) 627 628 for _ in range(3): 629 x = torch.randn([1, 4]) 630 631 result = model(x).sum() 632 gy, gz = torch.autograd.grad(result, inputs=[y, z]) 633 assert y.grad is None 634 assert z.grad is None 635 yield gy 636 yield gz 637 638 self.check_output_and_recompiles(fn) 639 640 def test_output_nodes_some_leaves(self): 641 def fn(): 642 class UnreachableBwd(torch.autograd.Function): 643 @staticmethod 644 def forward(ctx, x): 645 return x 646 647 @staticmethod 648 def backward(ctx, gO): 649 raise RuntimeError 650 651 y = torch.randn(1, 4, requires_grad=True) 652 z = torch.randn(1, 4, requires_grad=True) 653 654 def model(x): 655 return torch.sigmoid(UnreachableBwd.apply(y) * z) 656 657 for _ in range(3): 658 x = torch.randn([1, 4]) 659 660 result = model(x).sum() 661 gz = torch.autograd.grad(result, inputs=[z]) 662 assert y.grad is None 
663 assert z.grad is None 664 yield gz 665 666 self.check_output_and_recompiles(fn) 667 668 def test_no_output_nodes_all_leaves(self): 669 def fn(): 670 y = torch.randn(1, 4, requires_grad=True) 671 z = torch.randn(1, 4, requires_grad=True) 672 673 def model(x): 674 return torch.sigmoid(x * z + torch.sin(y) + torch.cos(y)) 675 676 for _ in range(3): 677 x = torch.randn([1, 4]) 678 result = model(x).sum() 679 out = result.backward() 680 assert out is None 681 assert y.grad is not None 682 assert z.grad is not None 683 yield y.grad 684 yield z.grad 685 y.grad = None 686 z.grad = None 687 688 self.check_output_and_recompiles(fn) 689 690 def test_no_output_nodes_some_leaves(self): 691 def fn(): 692 class UnreachableBwd(torch.autograd.Function): 693 @staticmethod 694 def forward(ctx, x): 695 return x 696 697 @staticmethod 698 def backward(ctx, gO): 699 raise RuntimeError 700 701 y = torch.randn(1, 4, requires_grad=True) 702 z = torch.randn(1, 4, requires_grad=True) 703 a = torch.randn(1, 4, requires_grad=True) 704 705 def model(x): 706 return torch.sigmoid(x * y * z * UnreachableBwd.apply(a)) 707 708 for _ in range(3): 709 x = torch.randn([1, 4]) 710 result = model(x).sum() 711 out = result.backward(inputs=[y, z]) 712 assert out is None 713 assert y.grad is not None 714 assert z.grad is not None 715 assert a.grad is None 716 yield y.grad 717 yield z.grad 718 y.grad = None 719 z.grad = None 720 721 self.check_output_and_recompiles(fn) 722 723 def test_no_output_nodes_different_leaves_will_recompile(self): 724 def fn(): 725 def fwd(x, y, z): 726 out = x * y # MulBackward0 727 out2 = out * z # MulBackward0 728 return out2.sum() # SumBackward0 729 730 x = torch.randn(5, requires_grad=True) 731 y = torch.randn(5, requires_grad=True) 732 z = torch.randn(5, requires_grad=True) 733 loss = fwd(x, y, z) 734 torch.compile(lambda: torch.autograd.backward(loss, inputs=[x]))() 735 yield x.grad 736 x.grad = None 737 738 loss = fwd(x, y, z) 739 torch.compile(lambda: torch.autograd.backward(loss, inputs=[y]))() 740 yield y.grad 741 742 # Guarded by TensorArg id, mismatch on last MulBackward0 743 self.check_output_and_recompiles(fn, 2) 744 745 def test_dynamic_shapes(self): 746 def fn(): 747 model = torch.nn.Sequential( 748 torch.nn.Linear(4, 4), 749 torch.nn.ReLU(), 750 torch.nn.Linear(4, 4), 751 torch.nn.ReLU(), 752 ) 753 opt_model = torch.compile(model, dynamic=True) 754 755 for b in range(10, 100, 10): 756 x = torch.randn([b, 4]) 757 result = opt_model(x).sum() 758 result.backward() 759 yield model[0].weight.grad 760 yield model[0].bias.grad 761 yield model[2].weight.grad 762 yield model[2].bias.grad 763 model.zero_grad() 764 765 # TODO(jansel): we should be able to get this count to 1 766 self.check_output_and_recompiles(fn, count=2) 767 768 def test_accumulate_without_zero(self): 769 def fn(): 770 model = torch.nn.Sequential( 771 torch.nn.Linear(4, 4), 772 torch.nn.ReLU(), 773 torch.nn.Linear(4, 4), 774 torch.nn.ReLU(), 775 ) 776 opt_model = torch.compile(model, dynamic=True) 777 778 for _ in range(10): 779 x = torch.randn([10, 4]) 780 result = opt_model(x).sum() 781 result.backward() 782 yield model[0].weight.grad.clone() 783 yield model[0].bias.grad.clone() 784 yield model[2].weight.grad.clone() 785 yield model[2].bias.grad.clone() 786 787 self.check_output_and_recompiles(fn, count=2) 788 789 def test_inplace_grad_update(self): 790 def fn(): 791 model = torch.nn.Sequential( 792 torch.nn.Linear(4, 4), 793 torch.nn.ReLU(), 794 ) 795 opt_model = torch.compile(model, dynamic=True) 796 797 for _ in range(10): 
798 w_grad = torch.rand_like(model[0].weight) 799 b_grad = torch.rand_like(model[0].bias) 800 model[0].weight.grad = w_grad 801 model[0].bias.grad = b_grad 802 803 x = torch.randn([10, 4]) 804 result = opt_model(x).sum() 805 result.backward() 806 assert model[0].weight.grad is w_grad 807 assert model[0].bias.grad is b_grad 808 yield w_grad.clone() 809 yield b_grad.clone() 810 811 self.check_output_and_recompiles(fn, count=1) 812 813 @unittest.skipIf(not HAS_CUDA, "requires cuda") 814 def test_issue106555(self): 815 DEVICE = torch.device("cuda:0") 816 NUM_FEATURES = 256 817 818 def bias_sigmoid_mul(x1, x2, bias): 819 x2 = torch.sigmoid(x2 + bias) 820 y = x1 * x2 821 return y 822 823 bias_sigmoid_mul_jit = torch.compile(bias_sigmoid_mul) 824 825 class ModuleWithJit(nn.Module): 826 def __init__(self) -> None: 827 super().__init__() 828 self.linear_1 = nn.Linear(NUM_FEATURES, NUM_FEATURES, bias=True) 829 self.linear_2 = nn.Linear(NUM_FEATURES, NUM_FEATURES, bias=False) 830 self.linear_2_bias = nn.Parameter(torch.zeros(NUM_FEATURES)) 831 832 def forward(self, input_tensor): 833 x1 = self.linear_1(input_tensor) 834 x2 = self.linear_2(input_tensor) 835 output = bias_sigmoid_mul_jit(x1, x2, self.linear_2_bias) 836 return output 837 838 class Model(nn.Module): 839 def __init__(self) -> None: 840 super().__init__() 841 self.module_with_jit_1 = ModuleWithJit() 842 self.module_with_jit_2 = ModuleWithJit() 843 844 def forward(self, x, gradient_checkpointing: bool): 845 if gradient_checkpointing: 846 y = torch.utils.checkpoint.checkpoint( 847 self._forward, x, use_reentrant=True 848 ) 849 else: 850 y = self._forward(x) 851 return y 852 853 def _forward(self, x): 854 x = x + self.module_with_jit_1(x) 855 x = x + self.module_with_jit_2(x.transpose(-2, -3)).transpose(-2, -3) 856 return x 857 858 torch.cuda.set_device(device=DEVICE) 859 torch.manual_seed(1234567890) 860 model = Model() 861 model.train() 862 model.to(device=DEVICE) 863 model_parameters = list(model.parameters()) 864 865 torch.manual_seed(1234567890) 866 input_tensor = torch.randn(1, 128, 256, NUM_FEATURES).to(device=DEVICE) 867 input_tensor.requires_grad = True 868 target_tensor = torch.randn(1, 128, 256, NUM_FEATURES).to( 869 dtype=input_tensor.dtype, device=DEVICE 870 ) 871 872 for iteration in range(10): 873 for param in model_parameters: 874 param.grad = None 875 output_tensor = model( 876 x=input_tensor.clone(), 877 gradient_checkpointing=True, 878 ) 879 loss = torch.mean(torch.abs(target_tensor - output_tensor)) 880 loss.backward() 881 882 def test_keep_graph_simple(self): 883 x = torch.tensor([2.0], requires_grad=True) 884 y = x**2 885 886 # First backward pass; keep the computation graph 887 y.backward(retain_graph=True) 888 self.assertEqual(x.grad, torch.Tensor([4])) # dy/dx at x=2 is 4 889 890 # Note - this will run under both the eager and compiled regime. 
891 def fn(): 892 # Reset the gradients 893 x.grad = torch.tensor([0.0]) 894 # Second and Third backward pass; keep the computation graph 895 y.backward(retain_graph=True) 896 self.assertEqual(x.grad, torch.Tensor([4])) # dy/dx at x=2 is 4 897 return x.grad 898 899 self.check_output_and_recompiles(fn, count=1) 900 901 def test_keep_graph_usage_after_compiled(self): 902 x = torch.tensor([2.0], requires_grad=True) 903 y = x**2 904 905 # First backward pass; keep the computation graph 906 def eager_check(): 907 y.backward(retain_graph=True) 908 self.assertEqual(x.grad, torch.Tensor([4])) # dy/dx at x=2 is 4 909 x.grad = torch.tensor([0.0]) 910 911 eager_check() 912 913 for i in range(0, 5): 914 with compiled_autograd.enable(compiler_fn): 915 eager_check() 916 917 eager_check() 918 919 def test_custom_fn_saved_tensors(self): 920 def fn(): 921 class MySin(torch.autograd.Function): 922 @staticmethod 923 def forward(ctx, x): 924 ctx.save_for_backward(x) 925 return torch.sin(x) 926 927 @staticmethod 928 def backward(ctx, gO): 929 (x,) = ctx.saved_tensors 930 return gO * torch.cos(x) 931 932 for i in [10, 100, 10, 15, 20, 25]: 933 x = torch.arange(0.0, i, requires_grad=True) 934 out = MySin.apply(x) 935 loss = out.sum() 936 loss.backward() 937 yield x.grad 938 939 self.check_output_and_recompiles(fn, count=2) 940 941 def test_custom_fn_saved_multiple_tensors(self): 942 def fn(): 943 class MyFn(torch.autograd.Function): 944 @staticmethod 945 def forward(ctx, x, y): 946 ctx.save_for_backward(x, y) 947 return torch.sin(x), torch.sin(y) 948 949 @staticmethod 950 def backward(ctx, gO_x, gO_y): 951 (x, y) = ctx.saved_tensors 952 return gO_x * torch.cos(x), gO_y * torch.cos(y) 953 954 for i in [10, 100, 10, 15, 20, 25]: 955 x = torch.arange(0.0, i, requires_grad=True) 956 y = torch.arange(0.0, i, requires_grad=True) 957 out1, out2 = MyFn.apply(x, y) 958 loss = (out1 * out2).sum() 959 loss.backward() 960 yield x.grad 961 962 self.check_output_and_recompiles(fn, count=2) 963 964 def test_custom_fn_saved_multiple_tensors_dedup(self): 965 def fn(): 966 class MyFn(torch.autograd.Function): 967 @staticmethod 968 def forward(ctx, x): 969 ctx.save_for_backward(x, x) 970 return torch.sin(x) 971 972 @staticmethod 973 def backward(ctx, gO): 974 (x1, x2) = ctx.saved_tensors 975 return gO * torch.cos(x1) * torch.cos(x2) 976 977 for i in [10, 100, 10, 15, 20, 25]: 978 x = torch.arange(0.0, i, requires_grad=True) 979 out = MyFn.apply(x) 980 loss = out.sum() 981 loss.backward() 982 yield x.grad 983 984 self.check_output_and_recompiles(fn, count=2) 985 986 def test_custom_fn_saved_shape_tensor(self): 987 def fn(): 988 class MyFn(torch.autograd.Function): 989 @staticmethod 990 def forward(ctx, x): 991 ctx.save_for_backward(x) 992 return x 993 994 @staticmethod 995 def backward(ctx, gO): 996 (x,) = ctx.saved_tensors 997 return gO * x.shape[0] 998 999 for i in [10, 100, 10, 15, 20, 25]: 1000 x = torch.arange(0.0, i, requires_grad=True) 1001 out = MyFn.apply(x) 1002 loss = out.sum() 1003 loss.backward() 1004 yield x.grad 1005 1006 self.check_output_and_recompiles(fn, count=2) 1007 1008 def test_custom_fn_saved_attr(self): 1009 def fn(): 1010 class MyFn(torch.autograd.Function): 1011 @staticmethod 1012 def forward(ctx, x): 1013 ctx.shape = x.shape 1014 return x 1015 1016 @staticmethod 1017 def backward(ctx, gO): 1018 x_shape = ctx.shape[0] 1019 return gO * x_shape 1020 1021 for i in [10, 100, 10, 15, 20, 25]: 1022 x = torch.arange(0.0, i, requires_grad=True) 1023 out = MyFn.apply(x) 1024 loss = out.sum() 1025 loss.backward() 
1026 yield x.grad 1027 1028 self.check_output_and_recompiles( 1029 fn, count=2, compiler_fn=make_compiler_fn(fullgraph=False) 1030 ) 1031 1032 def test_custom_fn_multiple_grads(self): 1033 def fn(): 1034 class MyFn(torch.autograd.Function): 1035 @staticmethod 1036 def forward(ctx, x, y): 1037 return x + y, y 1038 1039 @staticmethod 1040 def backward(ctx, gO_1, gO_2): 1041 return gO_1, gO_2 1042 1043 for i in [10, 100, 10, 15, 20, 25]: 1044 x = torch.arange(0.0, i, requires_grad=True) 1045 y = torch.arange(0.0, i, requires_grad=True) 1046 out1, out2 = MyFn.apply(x, y) 1047 loss = (out1 + out2).sum() 1048 loss.backward() 1049 yield x.grad 1050 yield y.grad 1051 1052 self.check_output_and_recompiles(fn, count=2) 1053 1054 def test_custom_fn_non_variable_input(self): 1055 def fn(): 1056 class MyFn(torch.autograd.Function): 1057 @staticmethod 1058 def forward(ctx, x, y, z): 1059 return x * 2, y * 3, z * 4 1060 1061 @staticmethod 1062 def backward(ctx, gO_1, gO_2, gO_3): 1063 return gO_1, gO_2, gO_3 1064 1065 for i in [10, 100, 10, 15, 20, 25]: 1066 x = torch.arange(0.0, i, requires_grad=True) 1067 y = 1 1068 z = torch.arange(0.0, i, requires_grad=True) 1069 out1, out2, out3 = MyFn.apply(x, y, z) 1070 loss = (out1 + out2 + out3).sum() 1071 loss.backward() 1072 yield x 1073 yield y 1074 yield z 1075 1076 self.check_output_and_recompiles(fn, count=2) 1077 1078 @unittest.skipIf(not HAS_CUDA, "requires cuda") 1079 def test_logging_tensor_flaky(self) -> None: 1080 # when you first run some test using triton and then run test_inputs_aliasing_bytecode_stack_restore 1081 # resulting in: 1082 # - pytest: `TypeError: unsupported operand type(s) for +: 'Tensor' and 'LoggingTensor'` 1083 # - python: `TypeError: not all arguments converted during string formatting` 1084 1085 # 1. some triton involving test 1086 def fn(): 1087 def _fn(x): 1088 return x 1089 1090 x = torch.arange( 1091 1, 10, requires_grad=True, dtype=torch.float16, device="cuda" 1092 ) 1093 out = _fn(x) 1094 loss = out.sum() 1095 loss.backward() 1096 1097 with compiled_autograd.enable(compiler_fn): 1098 fn() 1099 1100 logging.getLogger().setLevel( 1101 logging.WARNING 1102 ) # triton setup overwrote it to INFO 1103 # 2. test_inputs_aliasing_bytecode_stack_restore 1104 from torch.testing._internal.logging_tensor import LoggingTensor 1105 1106 def forward(inputs): 1107 add = inputs[0] + 1 1108 add_1 = add + inputs[1] 1109 out = add_1.cpu() 1110 return (out,) 1111 1112 gm = torch.fx.symbolic_trace(forward) 1113 print(gm.print_readable()) 1114 torch._dynamo.utils.set_locals_to_steal(gm, ["inputs"]) 1115 compiled_fn = torch.compile(gm) 1116 1117 inputs = [ 1118 torch.ones(1000000, dtype=torch.float32), 1119 LoggingTensor(torch.ones(1)), 1120 ] 1121 1122 compiled_fn(inputs) 1123 1124 @unittest.skipIf(not HAS_CUDA, "requires cuda") 1125 def test_custom_fn_output_metadata(self): 1126 def my_compiler_fn(gm): 1127 for node in gm.graph.nodes: 1128 if isinstance(node.target, torch._ops.OpOverload): 1129 assert ( 1130 node.target._name != "aten::_to_copy" 1131 ), "there should be no implicit copies (e.g. 
dtype casting)" 1132 1133 def inner_compiler(gm_, example_inputs_): 1134 counters["compiled_autograd"]["compiles"] += 1 1135 return inductor.compile(gm_, example_inputs_) 1136 1137 return torch.compile( 1138 gm, backend=inner_compiler, fullgraph=True, dynamic=True 1139 ) 1140 1141 def fn(): 1142 class MyFn(torch.autograd.Function): 1143 @staticmethod 1144 def forward(ctx, x): 1145 return x 1146 1147 @staticmethod 1148 def backward(ctx, gO): 1149 return gO 1150 1151 x = torch.arange( 1152 1, 10, requires_grad=True, dtype=torch.float16, device="cuda" 1153 ) 1154 x_view = x.view(3, 3) 1155 out = MyFn.apply(x_view) 1156 loss = out.sum() 1157 loss.backward() 1158 yield x.dtype 1159 yield x.device 1160 yield x.grad 1161 1162 self.check_output_and_recompiles(fn, count=1) 1163 1164 def test_custom_fn_with_same_graph(self): 1165 def fn(): 1166 class MyFn1(torch.autograd.Function): 1167 @staticmethod 1168 def forward(ctx, x): 1169 return x 1170 1171 @staticmethod 1172 def backward(ctx, gO): 1173 return gO 1174 1175 # same as MyFn1, but different autograd function id 1176 # should not be using same graph as MyFn1 1177 class MyFn2(torch.autograd.Function): 1178 @staticmethod 1179 def forward(ctx, x): 1180 return x 1181 1182 @staticmethod 1183 def backward(ctx, gO): 1184 return gO 1185 1186 for myfn in [MyFn1, MyFn2, MyFn1, MyFn2]: 1187 x = torch.arange(0.0, 10, requires_grad=True) 1188 out = myfn.apply(x) 1189 loss = out.sum() 1190 loss.backward() 1191 yield x.grad 1192 1193 self.check_output_and_recompiles( 1194 fn, count=2 1195 ) # should compile once for MyFn1 and once for MyFn2 1196 1197 def test_custom_fn_dynamically_defined_class(self): 1198 def fn(): 1199 def create_class(multiplier: int): 1200 class DynamicFn(torch.autograd.Function): 1201 @staticmethod 1202 def forward(ctx, x): 1203 return x * multiplier 1204 1205 @staticmethod 1206 def backward(ctx, gO): 1207 return gO * multiplier 1208 1209 return DynamicFn 1210 1211 for multiplier in [10, 20, 30]: 1212 x = torch.arange(0.0, 10, requires_grad=True) 1213 out = create_class(multiplier).apply(x) 1214 loss = out.sum() 1215 loss.backward() 1216 yield x.grad 1217 1218 self.check_output_and_recompiles(fn, count=3) 1219 1220 def test_custom_fn_bw_graph_break(self): 1221 def fn(): 1222 class MySin(torch.autograd.Function): 1223 @staticmethod 1224 def forward(ctx, x): 1225 ctx.save_for_backward(x) 1226 return torch.sin(x) 1227 1228 @staticmethod 1229 def backward(ctx, gO): 1230 print("graph break") 1231 (x,) = ctx.saved_tensors 1232 print("graph break") 1233 return gO * torch.cos(x) 1234 1235 for i in [10, 100, 10, 15, 20, 25]: 1236 x = torch.arange(0.0, i, requires_grad=True) 1237 out = MySin.apply(x) 1238 loss = out.sum() 1239 loss.backward() 1240 yield x.grad 1241 1242 self.check_output_and_recompiles( 1243 fn, count=[2, 6], compiler_fn=make_compiler_fn(fullgraph=False) 1244 ) 1245 1246 def test_custom_fn_compiled_fw_graph_break(self): 1247 def fn(): 1248 class MySin(torch.autograd.Function): 1249 @staticmethod 1250 def forward(ctx, x): 1251 print("graph break") 1252 ctx.save_for_backward(x) 1253 return torch.sin(x) 1254 1255 @staticmethod 1256 def backward(ctx, gO): 1257 (x,) = ctx.saved_tensors 1258 return gO * torch.cos(x) 1259 1260 opt_model = torch.compile(MySin.apply) 1261 for i in [10, 100, 10, 15, 20, 25]: 1262 x = torch.arange(0.0, i, requires_grad=True) 1263 out = opt_model(x) 1264 loss = out.sum() 1265 loss.backward() 1266 yield x.grad 1267 1268 self.check_output_and_recompiles( 1269 fn, count=2, 
compiler_fn=make_compiler_fn(fullgraph=False) 1270 ) 1271 self.assertEqual(counters["stats"]["unique_graphs"], 5) # 3 fw, 2 bw 1272 1273 def test_custom_fn_compiled_fw_bw_graph_break(self): 1274 def fn(): 1275 class MySin(torch.autograd.Function): 1276 @staticmethod 1277 def forward(ctx, x): 1278 print("graph break") 1279 ctx.save_for_backward(x) 1280 return torch.sin(x) 1281 1282 @staticmethod 1283 def backward(ctx, gO): 1284 print("graph break") 1285 (x,) = ctx.saved_tensors 1286 return gO * torch.cos(x) 1287 1288 opt_model = torch.compile(MySin.apply) 1289 for i in [10, 100, 10, 15, 20, 25]: 1290 x = torch.arange(0.0, i, requires_grad=True) 1291 out = opt_model(x) 1292 loss = out.sum() 1293 loss.backward() 1294 yield x.grad 1295 1296 self.check_output_and_recompiles( 1297 fn, count=[2, 6], compiler_fn=make_compiler_fn(fullgraph=False) 1298 ) 1299 self.assertEqual(counters["stats"]["unique_graphs"], 9) # 3 fw, 6 bw 1300 1301 def test_mismatch_fake_tensor_mode(self, dynamic_shape=False): 1302 """ 1303 Repro the failure of training nanogpt with both compiled-autograd 1304 and _LazyGraphModule. Check https://github.com/pytorch/pytorch/pull/118981 1305 for more context. 1306 """ 1307 B = 8 1308 x = torch.rand(B, 16) 1309 y = torch.rand(B, 16, requires_grad=True) 1310 1311 if dynamic_shape: 1312 torch._dynamo.mark_dynamic(x, 0) 1313 torch._dynamo.mark_dynamic(y, 0) 1314 1315 def f(): 1316 y.grad = None 1317 out = x + y 1318 1319 # make sure the backward call does not trigger any error when 1320 # compiling the backward graph 1321 out.sum().backward() 1322 return out, y.grad 1323 1324 self.check_output_and_recompiles(f, compile_fn=True) 1325 1326 def test_mismatch_fake_tensor_mode_dynamic_shape(self): 1327 self.test_mismatch_fake_tensor_mode(dynamic_shape=True) 1328 1329 def test_accumulate_grad_accuracy(self): 1330 def fn(): 1331 model = torch.nn.Sequential( 1332 torch.nn.Linear(2, 1, bias=False), 1333 torch.nn.Linear(1, 2, bias=False), 1334 ) 1335 x = torch.randn(2, 2) 1336 1337 out = model(x) 1338 loss = out.sum() 1339 torch.manual_seed(0) 1340 loss.backward() 1341 1342 yield model[0].weight.grad 1343 yield model[1].weight.grad 1344 1345 self.check_output_and_recompiles(fn, 1) 1346 1347 def test_trace_run_with_rng_state(self): 1348 def sdpa(xq, xk): 1349 return F.scaled_dot_product_attention(xq, xk, xk, is_causal=True) 1350 1351 def g(xq_1, xk_1, xq_2, xk_2): 1352 # xq: (bs, n_local_heads, seqlen, head_dim) 1353 # xk: (bs, n_local_heads, cache_len + seqlen, head_dim) 1354 y1 = sdpa(xq_1, xk_1) 1355 y2 = torch.utils.checkpoint.checkpoint( 1356 sdpa, xq_2, xk_2, use_reentrant=False 1357 ) 1358 y = torch.mul(y1, y2) 1359 z = torch.matmul(y, y) 1360 return z 1361 1362 def f(): 1363 bs = 1 1364 n_local_heads = 1 1365 seqlen = 2 1366 head_dim = 2 1367 cache_len = 2 1368 xq_list = [ 1369 torch.ones( 1370 (bs, n_local_heads, seqlen, head_dim), 1371 requires_grad=True, 1372 device="cpu", 1373 ) 1374 for _ in range(2) 1375 ] 1376 xk_list = [ 1377 torch.ones( 1378 (bs, n_local_heads, cache_len + seqlen, head_dim), 1379 requires_grad=True, 1380 device="cpu", 1381 ) 1382 for _ in range(2) 1383 ] 1384 out = torch.compile(g, fullgraph=True)( 1385 xq_list[0], xk_list[0], xq_list[1], xk_list[1] 1386 ) 1387 out.sum().backward() 1388 return out, *[x.grad for x in xq_list + xk_list] 1389 1390 """ 1391 Walkthrough of what happens with `run_with_rng_state`: 1392 1. `run_with_rng_state` only shows up in the backward graph (this op is inserted by the partitioner). 1393 2. 
The Dynamo graph captured by Compiled Autograd looks like: 1394 ``` 1395 ===== __compiled_fn_3 ===== 1396 torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module): 1397 def forward(self, L_inputs_ : list): 1398 ... 1399 run_with_rng_state = torch.ops.higher_order.run_with_rng_state( 1400 getitem_8, 1401 torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default, 1402 getitem_3, getitem_4, getitem_4, 0.0, True, 1403 ) 1404 ... 1405 ``` 1406 3. We want to preserve this `run_with_rng_state` op when going through AOTAutograd. We do it by having special handling 1407 in `run_with_rng_state` op's py_functionalize_impl. 1408 """ 1409 1410 def _run_with_rng_state_op_check(inductor_post_grad_graph): 1411 # Checks that `run_with_rng_state` op exists in Compiled Autograd's Inductor post-grad graph. 1412 op_set = {node.target for node in inductor_post_grad_graph.nodes} 1413 if torch.ops.higher_order.run_and_save_rng_state not in op_set: 1414 # This is backward graph, so check existence of `run_with_rng_state` op 1415 self.assertTrue(torch.ops.higher_order.run_with_rng_state in op_set) 1416 1417 with torch._inductor.config.patch( 1418 post_grad_custom_post_pass=_run_with_rng_state_op_check 1419 ): 1420 compiler_fn = make_compiler_fn(fullgraph=True) 1421 1422 def make_compiler_fn_with_op_check(): 1423 def _compiler_fn(gm): 1424 # Checks that `run_with_rng_state` op exists in Compiled Autograd's Dynamo graph. 1425 self.assertTrue( 1426 any( 1427 node.target is torch.ops.higher_order.run_with_rng_state 1428 for node in gm.graph.nodes 1429 ) 1430 ) 1431 return compiler_fn(gm) 1432 1433 return _compiler_fn 1434 1435 compiler_fn_with_op_check = make_compiler_fn_with_op_check() 1436 self.check_output_and_recompiles( 1437 f, compiler_fn=compiler_fn_with_op_check, compile_fn=False 1438 ) 1439 1440 def test_trace_auto_functionalized(self): 1441 torch.library.define( 1442 "testlib::foo", 1443 "(Tensor(a!) x) -> (Tensor)", 1444 tags=torch.Tag.pt2_compliant_tag, 1445 ) 1446 torch.library.define( 1447 "testlib::foo_mutated", 1448 "(Tensor(a!) 
x) -> (Tensor)", 1449 tags=torch.Tag.pt2_compliant_tag, 1450 ) 1451 1452 @torch.library.impl("testlib::foo", "cpu") 1453 def foo(x): 1454 x.add_(5) 1455 return x 1456 1457 @torch.library.impl("testlib::foo", "Meta") 1458 def foo_meta(x): 1459 return x 1460 1461 @torch.library.impl("testlib::foo_mutated", "CompositeImplicitAutograd") 1462 def foo_mutated(x): 1463 return torch.ops.testlib.foo(x) 1464 1465 def _get_custom_policy(must_recompute_list=None): 1466 def _custom_policy(ctx, func, *args, **kwargs): 1467 if must_recompute_list is not None and func in must_recompute_list: 1468 return torch.utils.checkpoint.CheckpointPolicy.MUST_RECOMPUTE 1469 else: 1470 return torch.utils.checkpoint.CheckpointPolicy.PREFER_RECOMPUTE 1471 1472 return _custom_policy 1473 1474 def context_fn(): 1475 must_recompute_list = [ 1476 torch.ops.higher_order.auto_functionalized, 1477 ] 1478 return torch.utils.checkpoint.create_selective_checkpoint_contexts( 1479 _get_custom_policy( 1480 must_recompute_list=must_recompute_list, 1481 ), 1482 ) 1483 1484 def g(x): 1485 x = torch.matmul(x, x) 1486 torch.ops.testlib.foo_mutated(x) 1487 return torch.matmul(x, x) 1488 1489 def g_cp(x): 1490 return torch.utils.checkpoint.checkpoint( 1491 g, x, use_reentrant=False, context_fn=context_fn 1492 ) 1493 1494 def f(): 1495 inps = (torch.randn(4, 4, requires_grad=True),) 1496 output = torch.compile(g_cp, backend="aot_eager", fullgraph=True)(*inps) 1497 output.sum().backward() 1498 return output, inps[0].grad 1499 1500 """ 1501 Walkthrough of what happens with `auto_functionalized`: 1502 1. `auto_functionalized` op is inserted into the graph during AOTAutograd functionalization. 1503 We force the op to be recomputed (by using SAC), so it appears in the backward graph. 1504 2. The AOT backward graph looks like: 1505 ``` 1506 ===== Backward graph 0 ===== 1507 def forward(self, primals_1: "f32[4, 4][4, 1]cpu", tangents_1: "f32[4, 4][4, 1]cpu"): 1508 ... 1509 X = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = mm) 1510 ... 1511 return (add_1,) 1512 ``` 1513 3. The Compiled Autograd graph looks like: 1514 ``` 1515 ===== Compiled autograd graph ===== 1516 def forward(self, inputs, sizes, scalars, hooks): 1517 ... 1518 X = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = aot0_mm) 1519 ... 1520 return [] 1521 ``` 1522 4. The Dynamo graph captured by Compiled Autograd looks like: 1523 ``` 1524 ===== __compiled_fn_3 ===== 1525 def forward(self, L_inputs_ : list): 1526 ... 1527 X = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = aot0_mm) 1528 ... 1529 return (new_grad,) 1530 ``` 1531 5. The Compiled Autograd's AOT "forward-only" graph looks like: 1532 ``` 1533 ===== Forward graph 1 ===== 1534 def forward(self, arg0_1: "f32[][]cpu", arg1_1: "f32[4, 4][4, 1]cpu"): 1535 ... 1536 X = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = mm) 1537 ... 1538 return (clone_1,) 1539 ``` 1540 6. The `auto_functionalized` op should then be lowered using the normal lowering path in Inductor. 1541 """ 1542 1543 compiler_fn = make_compiler_fn(fullgraph=True, backend="aot_eager") 1544 1545 def make_compiler_fn_with_op_check(): 1546 def _compiler_fn(gm): 1547 # Checks that `auto_functionalized` op exists in Compiled Autograd's Dynamo graph. 
1548 self.assertTrue( 1549 any( 1550 node.target is torch.ops.higher_order.auto_functionalized 1551 for node in gm.graph.nodes 1552 ), 1553 f"`torch.ops.higher_order.auto_functionalized` op not found in {gm.graph}", 1554 ) 1555 return compiler_fn(gm) 1556 1557 return _compiler_fn 1558 1559 compiler_fn_with_op_check = make_compiler_fn_with_op_check() 1560 self.check_output_and_recompiles( 1561 f, compiler_fn=compiler_fn_with_op_check, compile_fn=False 1562 ) 1563 1564 def test_non_traceable_autograd_cpp_node(self): 1565 cpp_source = """ 1566struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> { 1567 static constexpr bool is_traceable = false; 1568 1569 static torch::Tensor forward( 1570 torch::autograd::AutogradContext* ctx, 1571 const torch::Tensor& x) { 1572 return x; 1573 } 1574 1575 static torch::autograd::variable_list backward( 1576 torch::autograd::AutogradContext *ctx, 1577 torch::autograd::variable_list grad_output) { 1578 return grad_output; 1579 } 1580}; 1581 1582torch::Tensor custom_op_backed_by_autograd_fn(torch::Tensor x) { 1583 return CustomOpAutogradFunction::apply(x); 1584} 1585 1586TORCH_LIBRARY(test_non_traceable_autograd_cpp_node, m) { 1587 m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn); 1588} 1589 """ 1590 1591 module = torch.utils.cpp_extension.load_inline( 1592 name="test_non_traceable_autograd_cpp_node", 1593 cpp_sources=cpp_source, 1594 functions="custom_op_backed_by_autograd_fn", 1595 verbose=True, 1596 ) 1597 1598 def fn(): 1599 x = torch.ones(10, 10, requires_grad=True) 1600 out = torch.ops.test_non_traceable_autograd_cpp_node.custom_op_backed_by_autograd_fn( 1601 x 1602 ) 1603 loss = out.sum() 1604 loss.backward() 1605 1606 with self.assertRaisesRegex( 1607 RuntimeError, 1608 "https://docs.google.com/document/d/11VucFBEewzqgkABIjebZIzMvrXr3BtcY1aGKpX61pJY/", 1609 ), compiled_autograd.enable(compiler_fn): 1610 fn() 1611 1612 @unittest.skip("Flaky, cache from test ordering affects test. 
#135369") 1613 def test_autograd_cpp_node(self): 1614 cpp_source = """ 1615struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> { 1616 static constexpr bool is_traceable = true; 1617 1618 static torch::Tensor forward( 1619 torch::autograd::AutogradContext* ctx, 1620 const torch::Tensor& x) { 1621 return x; 1622 } 1623 1624 static torch::autograd::variable_list backward( 1625 torch::autograd::AutogradContext *ctx, 1626 torch::autograd::variable_list grad_output) { 1627 return grad_output; 1628 } 1629}; 1630 1631torch::Tensor custom_op_backed_by_autograd_fn(torch::Tensor x) { 1632 return CustomOpAutogradFunction::apply(x); 1633} 1634 1635TORCH_LIBRARY(test_autograd_cpp_node, m) { 1636 m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn); 1637} 1638 """ 1639 1640 module = torch.utils.cpp_extension.load_inline( 1641 name="test_autograd_cpp_node", 1642 cpp_sources=cpp_source, 1643 functions="custom_op_backed_by_autograd_fn", 1644 verbose=True, 1645 ) 1646 1647 def fn(): 1648 for i in [10, 100, 10, 20, 10]: 1649 x = torch.ones(i, i, requires_grad=True) 1650 out = torch.ops.test_autograd_cpp_node.custom_op_backed_by_autograd_fn( 1651 x 1652 ) 1653 loss = out.sum() 1654 loss.backward() 1655 yield x.grad 1656 1657 # compiles for 10 (static) and 100 (dynamic) 1658 self.check_output_and_recompiles(fn, 2) 1659 1660 def test_autograd_cpp_node_id(self): 1661 cpp_source = """ 1662struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> { 1663 static constexpr bool is_traceable = true; 1664 1665 static torch::Tensor forward( 1666 torch::autograd::AutogradContext* ctx, 1667 const torch::Tensor& x) { 1668 return x; 1669 } 1670 1671 static torch::autograd::variable_list backward( 1672 torch::autograd::AutogradContext *ctx, 1673 torch::autograd::variable_list grad_output) { 1674 return grad_output; 1675 } 1676}; 1677 1678struct CustomOpAutogradFunction2 : public torch::autograd::Function<CustomOpAutogradFunction2> { 1679 static constexpr bool is_traceable = true; 1680 1681 static torch::Tensor forward( 1682 torch::autograd::AutogradContext* ctx, 1683 const torch::Tensor& x) { 1684 return x; 1685 } 1686 1687 static torch::autograd::variable_list backward( 1688 torch::autograd::AutogradContext *ctx, 1689 torch::autograd::variable_list grad_output) { 1690 return grad_output; 1691 } 1692}; 1693 1694torch::Tensor custom_op_backed_by_autograd_fn(torch::Tensor x) { 1695 return CustomOpAutogradFunction::apply(x); 1696} 1697 1698torch::Tensor custom_op_backed_by_autograd_fn2(torch::Tensor x) { 1699 return CustomOpAutogradFunction2::apply(x); 1700} 1701 1702TORCH_LIBRARY(test_autograd_cpp_node_id, m) { 1703 m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn); 1704 m.def("custom_op_backed_by_autograd_fn2", custom_op_backed_by_autograd_fn2); 1705} 1706 """ 1707 1708 module = torch.utils.cpp_extension.load_inline( 1709 name="test_autograd_cpp_node_id", 1710 cpp_sources=cpp_source, 1711 functions="custom_op_backed_by_autograd_fn", 1712 verbose=True, 1713 ) 1714 1715 def same_autograd_fn(): 1716 def fn(): 1717 x = torch.ones(10, 10, requires_grad=True) 1718 out = ( 1719 torch.ops.test_autograd_cpp_node_id.custom_op_backed_by_autograd_fn( 1720 x 1721 ) 1722 ) 1723 loss = out.sum() 1724 loss.backward() 1725 yield x.grad 1726 1727 yield from fn() # compile 1728 yield from fn() # reuse 1729 yield from fn() # reuse 1730 yield from fn() # reuse 1731 1732 self.check_output_and_recompiles(same_autograd_fn, 1) 1733 
1734 def different_autograd_fn(): 1735 def fn(op): 1736 x = torch.ones(10, 10, requires_grad=True) 1737 out = op(x) 1738 loss = out.sum() 1739 loss.backward() 1740 yield x.grad 1741 1742 op1 = torch.ops.test_autograd_cpp_node_id.custom_op_backed_by_autograd_fn 1743 op2 = torch.ops.test_autograd_cpp_node_id.custom_op_backed_by_autograd_fn2 1744 yield from fn(op1) # compile 1745 yield from fn(op2) # compile 1746 yield from fn(op1) # reuse 1747 yield from fn(op2) # reuse 1748 1749 self.check_output_and_recompiles(different_autograd_fn, 2) 1750 1751 def test_autograd_cpp_node_saved(self): 1752 cpp_source = """ 1753struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> { 1754 static constexpr bool is_traceable = true; 1755 1756 static torch::Tensor forward( 1757 torch::autograd::AutogradContext* ctx, 1758 const torch::Tensor& x, 1759 const torch::Tensor& y, 1760 const torch::Tensor& fixed) { 1761 ctx->save_for_backward({x, y}); 1762 ctx->saved_data["fixed_tensor"] = fixed; 1763 ctx->saved_data["bool"] = true; 1764 ctx->saved_data["int"] = 1; 1765 c10::List<std::string> list({"string"}); 1766 ctx->saved_data["list"] = std::move(list); 1767 c10::Dict<std::string, double> dict; 1768 dict.insert("string", 1.0); 1769 ctx->saved_data["dict"] = std::move(dict); 1770 return x; 1771 } 1772 1773 static torch::autograd::variable_list backward( 1774 torch::autograd::AutogradContext *ctx, 1775 torch::autograd::variable_list grad_output) { 1776 const auto& saved_variables = ctx->get_saved_variables(); 1777 assert(saved_variables.size() == 2); 1778 torch::Tensor x = saved_variables[0]; 1779 torch::Tensor y = saved_variables[1]; 1780 torch::Tensor fixed = ctx->saved_data["fixed_tensor"].toTensor(); 1781 assert(ctx->saved_data["bool"].isBool()); 1782 c10::SymInt i = ctx->saved_data["int"].toSymInt(); 1783 c10::List<c10::IValue> list = ctx->saved_data["list"].toList(); 1784 assert(list.size() == 1); 1785 assert(list.get(0).toStringRef() == "string"); 1786 c10::Dict<c10::IValue, c10::IValue> dict = ctx->saved_data["dict"].toGenericDict(); 1787 assert(dict.size() == 1); 1788 assert(dict.at("string") == 1.0); 1789 1790 torch::autograd::variable_list grad_inputs(3); 1791 grad_inputs[0] = x + y + torch::sum(fixed) + i; 1792 return grad_inputs; 1793 } 1794}; 1795 1796torch::Tensor custom_op_backed_by_autograd_fn(const torch::Tensor& x, const torch::Tensor& y, const torch::Tensor& fixed) { 1797 return CustomOpAutogradFunction::apply(x, y, fixed); 1798} 1799 1800TORCH_LIBRARY(test_autograd_cpp_node_saved, m) { 1801 m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn); 1802} 1803 """ 1804 1805 module = torch.utils.cpp_extension.load_inline( 1806 name="test_autograd_cpp_node_saved", 1807 cpp_sources=cpp_source, 1808 functions="custom_op_backed_by_autograd_fn", 1809 verbose=True, 1810 ) 1811 1812 def fn(): 1813 fixed = torch.ones(2, 2) 1814 for i in [10, 100, 10, 20, 10]: 1815 x = torch.ones(i, i, requires_grad=True) 1816 y = torch.randn(i, i) 1817 out = torch.ops.test_autograd_cpp_node_saved.custom_op_backed_by_autograd_fn( 1818 x, y, fixed 1819 ) 1820 loss = out.sum() 1821 loss.backward() 1822 yield x.grad 1823 1824 self.check_output_and_recompiles(fn, 2) 1825 1826 def test_autograd_cpp_node_saved_dynamic(self): 1827 cpp_source = """ 1828struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> { 1829 static constexpr bool is_traceable = true; 1830 1831 static torch::Tensor forward( 1832 torch::autograd::AutogradContext* ctx, 
1833 const torch::Tensor& x) { 1834 ctx->save_for_backward({x}); 1835 ctx->saved_data["dynamic"] = x.view(-1); 1836 return x; 1837 } 1838 1839 static torch::autograd::variable_list backward( 1840 torch::autograd::AutogradContext *ctx, 1841 torch::autograd::variable_list grad_output) { 1842 const auto& saved_variables = ctx->get_saved_variables(); 1843 assert(saved_variables.size() == 1); 1844 torch::Tensor x = saved_variables[0]; 1845 torch::Tensor z = ctx->saved_data["dynamic"].toTensor(); 1846 1847 torch::autograd::variable_list grad_inputs(1); 1848 grad_inputs[0] = x + torch::sum(z); 1849 return grad_inputs; 1850 } 1851}; 1852 1853torch::Tensor custom_op_backed_by_autograd_fn(const torch::Tensor& x) { 1854 return CustomOpAutogradFunction::apply(x); 1855} 1856 1857TORCH_LIBRARY(test_autograd_cpp_node_saved_dynamic, m) { 1858 m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn); 1859} 1860 """ 1861 1862 module = torch.utils.cpp_extension.load_inline( 1863 name="test_autograd_cpp_node_saved_dynamic", 1864 cpp_sources=cpp_source, 1865 functions="custom_op_backed_by_autograd_fn", 1866 verbose=True, 1867 ) 1868 1869 def fn(): 1870 for i in [10, 100, 10, 20, 10]: 1871 x = torch.ones(i, i, requires_grad=True) 1872 out = torch.ops.test_autograd_cpp_node_saved_dynamic.custom_op_backed_by_autograd_fn( 1873 x 1874 ) 1875 loss = out.sum() 1876 loss.backward() 1877 yield x.grad 1878 1879 # compiles for 10 (static) and 100 (dynamic) 1880 self.check_output_and_recompiles(fn, 2) 1881 1882 def test_autograd_cpp_node_saved_int(self): 1883 cpp_source = """ 1884struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> { 1885 static constexpr bool is_traceable = true; 1886 1887 static torch::Tensor forward( 1888 torch::autograd::AutogradContext* ctx, 1889 const torch::Tensor& x, 1890 int64_t y) { 1891 ctx->save_for_backward({x}); 1892 ctx->saved_data["int"] = y; 1893 ctx->saved_data["symint"] = c10::SymInt(y); 1894 return x; 1895 } 1896 1897 static torch::autograd::variable_list backward( 1898 torch::autograd::AutogradContext *ctx, 1899 torch::autograd::variable_list grad_output) { 1900 const auto& saved_variables = ctx->get_saved_variables(); 1901 assert(saved_variables.size() == 1); 1902 torch::Tensor x = saved_variables[0]; 1903 c10::SymInt y = ctx->saved_data["int"].toSymInt(); 1904 c10::SymInt ys = ctx->saved_data["symint"].toSymInt(); 1905 1906 torch::autograd::variable_list grad_inputs(2); 1907 grad_inputs[0] = x + y + ys; 1908 return grad_inputs; 1909 } 1910}; 1911 1912torch::Tensor custom_op_backed_by_autograd_fn(const torch::Tensor& x, int64_t y) { 1913 return CustomOpAutogradFunction::apply(x, y); 1914} 1915 1916TORCH_LIBRARY(test_autograd_cpp_node_saved_int, m) { 1917 m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn); 1918} 1919 """ 1920 1921 module = torch.utils.cpp_extension.load_inline( 1922 name="test_autograd_cpp_node_saved_int", 1923 cpp_sources=cpp_source, 1924 functions="custom_op_backed_by_autograd_fn", 1925 verbose=True, 1926 ) 1927 1928 def fn(): 1929 for y in [1, 2, 3, 1]: 1930 x = torch.ones(10, 10, requires_grad=True) 1931 out = torch.ops.test_autograd_cpp_node_saved_int.custom_op_backed_by_autograd_fn( 1932 x, y 1933 ) 1934 loss = out.sum() 1935 loss.backward() 1936 yield x.grad 1937 1938 self.check_output_and_recompiles(fn, 1) 1939 1940 def test_autograd_cpp_node_saved_float(self): 1941 cpp_source = """ 1942struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> { 
1943 static constexpr bool is_traceable = true; 1944 1945 static torch::Tensor forward( 1946 torch::autograd::AutogradContext* ctx, 1947 const torch::Tensor& x, 1948 double z) { 1949 ctx->save_for_backward({x}); 1950 ctx->saved_data["float"] = z; 1951 ctx->saved_data["symfloat"] = c10::SymFloat(z); 1952 return x; 1953 } 1954 1955 static torch::autograd::variable_list backward( 1956 torch::autograd::AutogradContext *ctx, 1957 torch::autograd::variable_list grad_output) { 1958 const auto& saved_variables = ctx->get_saved_variables(); 1959 assert(saved_variables.size() == 1); 1960 torch::Tensor x = saved_variables[0]; 1961 c10::SymFloat z = ctx->saved_data["float"].toSymFloat(); 1962 c10::SymFloat zs = ctx->saved_data["symfloat"].toSymFloat(); 1963 1964 torch::autograd::variable_list grad_inputs(2); 1965 grad_inputs[0] = x + z + zs; 1966 return grad_inputs; 1967 } 1968}; 1969 1970torch::Tensor custom_op_backed_by_autograd_fn(const torch::Tensor& x, double z) { 1971 return CustomOpAutogradFunction::apply(x, z); 1972} 1973 1974TORCH_LIBRARY(test_autograd_cpp_node_saved_float, m) { 1975 m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn); 1976} 1977 """ 1978 1979 module = torch.utils.cpp_extension.load_inline( 1980 name="test_autograd_cpp_node_saved_float", 1981 cpp_sources=cpp_source, 1982 functions="custom_op_backed_by_autograd_fn", 1983 verbose=True, 1984 ) 1985 1986 def fn(): 1987 for z in [1.1, 2.2, 3.3, 1.1]: 1988 x = torch.ones(10, 10, requires_grad=True) 1989 out = torch.ops.test_autograd_cpp_node_saved_float.custom_op_backed_by_autograd_fn( 1990 x, z 1991 ) 1992 loss = out.sum() 1993 loss.backward() 1994 yield x.grad 1995 1996 # compiled autograd and dynamo both support symfloat, but not backend 1997 self.check_output_and_recompiles(fn, [1, 3]) 1998 1999 def test_autograd_cpp_node_data_dependent(self): 2000 cpp_source = """ 2001struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> { 2002 static constexpr bool is_traceable = true; 2003 static int iteration; 2004 2005 static torch::autograd::variable_list forward( 2006 torch::autograd::AutogradContext* ctx, 2007 const torch::Tensor& x, 2008 const torch::Tensor& y) { 2009 ctx->save_for_backward({x, y}); 2010 ctx->saved_data["bool"] = true; 2011 ctx->saved_data["int"] = 1; 2012 2013 switch (iteration) { 2014 case 0: { 2015 break; 2016 } 2017 case 1: { 2018 // recompile 2019 ctx->saved_data["forces_recompile"] = iteration; 2020 break; 2021 } 2022 case 2: { 2023 // recompile 2024 ctx->set_materialize_grads(false); 2025 break; 2026 } 2027 case 3: { 2028 // reuse 2029 break; 2030 } 2031 default: { 2032 throw std::runtime_error("unexpected iteration"); 2033 } 2034 } 2035 iteration++; 2036 return {x, y}; 2037 } 2038 2039 static torch::autograd::variable_list backward( 2040 torch::autograd::AutogradContext *ctx, 2041 torch::autograd::variable_list grad_output) { 2042 const auto& saved_variables = ctx->get_saved_variables(); 2043 assert(saved_variables.size() == 2); 2044 torch::Tensor x = saved_variables[0]; 2045 torch::Tensor y = saved_variables[1]; 2046 c10::SymInt i = ctx->saved_data["int"].toSymInt(); 2047 2048 torch::autograd::variable_list grad_inputs(2); 2049 grad_inputs[0] = x + y + i; 2050 return grad_inputs; 2051 } 2052}; 2053 2054int CustomOpAutogradFunction::iteration = 0; 2055 2056torch::autograd::variable_list custom_op_backed_by_autograd_fn(const torch::Tensor& x, const torch::Tensor& y) { 2057 return CustomOpAutogradFunction::apply(x, y); 2058} 2059 2060void reset() { 2061 
CustomOpAutogradFunction::iteration = 0; 2062} 2063 2064TORCH_LIBRARY(test_autograd_cpp_node_data_dependent, m) { 2065 m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn); 2066 m.def("reset", reset); 2067} 2068 """ 2069 2070 module = torch.utils.cpp_extension.load_inline( 2071 name="test_autograd_cpp_node_data_dependent", 2072 cpp_sources=cpp_source, 2073 functions="custom_op_backed_by_autograd_fn", 2074 verbose=True, 2075 ) 2076 2077 def fn(): 2078 torch.ops.test_autograd_cpp_node_data_dependent.reset() 2079 for i in [10, 10, 10, 10]: 2080 x = torch.ones(i, i, requires_grad=True) 2081 y = torch.randn(i, i) 2082 ( 2083 out1, 2084 out2, 2085 ) = torch.ops.test_autograd_cpp_node_data_dependent.custom_op_backed_by_autograd_fn( 2086 x, y 2087 ) 2088 loss = (out1 + out2).sum() 2089 loss.backward() 2090 yield x.grad 2091 2092 self.check_output_and_recompiles(fn, 3) 2093 2094 @unittest.skipIf(not HAS_CUDA, "requires cuda") 2095 def test_free_activation_memory(self): 2096 script = """ 2097import torch 2098 2099def main(): 2100 assert(torch.cuda.memory_allocated() == 0) 2101 2102 # Use an op to check that the memory is freed by the time the op is executed 2103 def assertion_impl(to_clone): 2104 mem_allocated = torch.cuda.memory_allocated() 2105 assert mem_allocated < 4000000 # some activations should be freed 2106 return to_clone.clone() 2107 2108 with torch.library._scoped_library("test_compiled_autograd", "FRAGMENT") as lib: 2109 lib.define( 2110 "assertion_op(Tensor x) -> Tensor", tags=(torch.Tag.pt2_compliant_tag,) 2111 ) 2112 lib.impl("assertion_op", assertion_impl, "CPU") 2113 lib.impl("assertion_op", lambda x: x.clone(), "Meta") 2114 2115 # Create a graph that allows inputs stealing 2116 def forward(activations): 2117 add = activations[0] + 1 2118 out = add.cpu() 2119 cloned_out = torch.ops.test_compiled_autograd.assertion_op(out) 2120 return (cloned_out,) 2121 2122 gm = torch.fx.symbolic_trace(forward) 2123 torch._dynamo.utils.set_locals_to_steal(gm, ["activations"]) 2124 compiled_fn = torch.compile(gm) 2125 2126 # allocate at least 4,000,000 bytes (1,000,000 * 4 bytes) 2127 activations = [torch.ones(1000000, dtype=torch.float32, device="cuda")] 2128 assert torch.cuda.memory_allocated() > 4000000 2129 2130 out = compiled_fn(activations) 2131 assert len(activations) == 0 2132 2133main() 2134 """ 2135 self.run_as_subprocess(script) 2136 2137 @unittest.skipIf(not HAS_CUDA, "requires cuda") 2138 def test_free_activation_memory_subclass(self): 2139 # cover the case when aot inputs have subclasses, resulting in a different runtime wrapper 2140 2141 script = """ 2142import torch 2143 2144def main(): 2145 assert torch.cuda.memory_allocated() == 0 2146 2147 # Use an op to check that the memory is freed by the time the op is executed 2148 def assertion_impl(to_clone): 2149 mem_allocated = torch.cuda.memory_allocated() 2150 assert mem_allocated < 1200000 # some activations should be freed 2151 assert mem_allocated > 800000 # currently subclasses don't seem to be freed in inductor 2152 return to_clone.clone() 2153 2154 with torch.library._scoped_library("test_compiled_autograd", "FRAGMENT") as lib: 2155 lib.define( 2156 "assertion_op(Tensor x) -> Tensor", tags=(torch.Tag.pt2_compliant_tag,) 2157 ) 2158 lib.impl("assertion_op", assertion_impl, "CPU") 2159 lib.impl("assertion_op", lambda x: x.clone(), "Meta") 2160 lib.impl("assertion_op", lambda x: x.clone(), "NestedTensor") 2161 2162 def fn(inputs): 2163 _, y = inputs 2164 out = y.cpu() 2165 cloned_out = 
torch.ops.test_compiled_autograd.assertion_op(out)
            return cloned_out

        gm = torch.fx.symbolic_trace(fn)
        torch._dynamo.utils.set_locals_to_steal(gm, ["inputs"])
        compiled_fn = torch.compile(gm)

        from torch.nested._internal.nested_tensor import jagged_from_list

        activations = [
            jagged_from_list(
                [
                    torch.ones((1, 100000), device="cuda"),  # 400,000 bytes
                    torch.ones((1, 100000), device="cuda"),  # 400,000 bytes
                ],
                None,
            )[
                0
            ],  # NestedTensor
            torch.ones((1, 100000), device="cuda"),  # 400,000 bytes
        ]
        # 1,200,000 bytes (3 * 4 * 100,000 bytes)
        assert torch.cuda.memory_allocated() > 1200000

        out = compiled_fn(activations)
        assert len(activations) == 0

main()
        """
        # execute the script in a fresh process, like test_free_activation_memory above
        self.run_as_subprocess(script)

    def test_callback_graph_break_throws_error(self):
        called = [0]

        def callback_final():
            called[0] += 1

        class MyFunc(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                return input

            @staticmethod
            @torch.autograd.function.once_differentiable
            def backward(ctx, grad):
                torch.autograd.Variable._execution_engine.queue_callback(callback_final)
                torch._dynamo.graph_break()
                return grad

        a = torch.rand((3, 3), requires_grad=True)
        with self.assertRaisesRegex(
            AssertionError,
            "only supported when Compiled Autograd is enabled with fullgraph=True",
        ):
            with compiled_autograd.enable(make_compiler_fn(fullgraph=False)):
                b = MyFunc.apply(a)
                b.sum().backward()

    @unittest.skipIf(not HAS_CUDA, "requires cuda")
    def test_cudagraphs_cpu_division(self):
        from torch._dynamo.testing import reduce_to_scalar_loss

        model = torch.nn.Linear(10, 10, dtype=torch.float16).cuda()
        inputs = torch.randn(10, 10, dtype=torch.float16).cuda()
        out = model(inputs)
        loss = reduce_to_scalar_loss(out)

        stderr_msgs = io.StringIO()
        with mock.patch("sys.stderr", stderr_msgs), compiled_autograd.enable(
            compiler_fn
        ):
            torch._inductor.config.triton.cudagraphs = True
            loss.backward()
            torch._inductor.config.triton.cudagraphs = False

        self.assertFalse("skipping cudagraphs" in stderr_msgs.getvalue())

    def test_cudagraphs_cpu_graph(self):
        from torch._dynamo.testing import reduce_to_scalar_loss

        model = torch.nn.Linear(10, 10, dtype=torch.float16)
        inputs = torch.randn(10, 10, dtype=torch.float16)
        out = model(inputs)
        loss = reduce_to_scalar_loss(out)

        with compiled_autograd.enable(compiler_fn):
            torch._inductor.config.triton.cudagraphs = True
            loss.backward()
            torch._inductor.config.triton.cudagraphs = False

        self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)

    @unittest.skipIf(not HAS_CUDA, "requires cuda")
    def test_cudagraphs_sdpa(self):
        query = torch.rand(
            32, 8, 128, 64, dtype=torch.float16, device="cuda", requires_grad=True
        )
        key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
        value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
        out = torch.nn.functional.scaled_dot_product_attention(query, key, value)

        with config.patch(compiled_autograd=True), inductor_config.patch(
            "triton.cudagraphs", True
        ):
            opt_bwd = torch.compile(lambda: out.sum().backward())
            opt_bwd()

        self.assertEqual(counters["compiled_autograd"]["captures"], 1)
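        # every tensor in this test lives on CUDA, so nothing should force inductor
        # to fall back and skip cudagraphs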
self.assertEqual(counters["inductor"]["cudagraph_skips"], 0) 2273 2274 @unittest.skipIf(not HAS_CUDA, "requires cuda") 2275 def test_cudagraphs_cpu_scalar_used_in_python_custom_op(self): 2276 class MyFn(torch.autograd.Function): 2277 @staticmethod 2278 def forward(ctx, x): 2279 cpu_tensor = torch.tensor(5) 2280 ctx.save_for_backward(x, cpu_tensor) # visible to c++/autograd 2281 ctx.cpu_scalar = 5 # opaque to c++/autograd 2282 return x.sum() 2283 2284 @staticmethod 2285 def backward(ctx, gO): 2286 x, cpu_tensor = ctx.saved_tensors 2287 expand = gO * torch.ones_like(x) 2288 return expand * cpu_tensor * ctx.cpu_scalar 2289 2290 x = torch.randn(10, requires_grad=True, device="cuda") 2291 out = MyFn.apply(x) 2292 with config.patch(compiled_autograd=True), inductor_config.patch( 2293 "triton.cudagraphs", True 2294 ): 2295 opt_bwd = torch.compile(lambda: out.backward()) 2296 opt_bwd() 2297 2298 self.assertEqual(counters["compiled_autograd"]["captures"], 1) 2299 # Compiled autograd lifts custom autograd.Function bwd instead of tracing it. 2300 # Must skip since we do not know if the cpu scalar will be used only in ATen/prim ops. 2301 self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) 2302 2303 @unittest.skipIf(not HAS_CUDA, "requires cuda") 2304 def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self): 2305 cpp_source = """ 2306struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> { 2307 static constexpr bool is_traceable = true; 2308 2309 static torch::Tensor forward( 2310 torch::autograd::AutogradContext* ctx, 2311 const torch::Tensor& x) { 2312 const auto& cpu_tensor = torch::tensor(1); 2313 ctx->save_for_backward({x, cpu_tensor}); 2314 ctx->saved_data["cpu_scalar"] = 1; 2315 return x; 2316 } 2317 2318 static torch::autograd::variable_list backward( 2319 torch::autograd::AutogradContext *ctx, 2320 torch::autograd::variable_list grad_output) { 2321 const auto& saved_variables = ctx->get_saved_variables(); 2322 assert(saved_variables.size() == 2); 2323 torch::Tensor x = saved_variables[0]; 2324 torch::Tensor cpu_tensor = saved_variables[1]; 2325 int cpu_scalar = ctx->saved_data["cpu_scalar"].toInt(); 2326 auto expand = grad_output[0] * torch::ones_like(x); 2327 torch::autograd::variable_list grad_inputs(1); 2328 grad_inputs[0] = expand * cpu_tensor * cpu_scalar; // autograd engine asserts that tensors are on same device 2329 return grad_inputs; 2330 } 2331}; 2332 2333torch::Tensor custom_op_backed_by_autograd_fn(const torch::Tensor& x) { 2334 return CustomOpAutogradFunction::apply(x); 2335} 2336 2337TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) { 2338 m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn); 2339} 2340 """ 2341 2342 module = torch.utils.cpp_extension.load_inline( 2343 name="test_cudagraphs_cpu_scalar_used_in_cpp_custom_op", 2344 cpp_sources=cpp_source, 2345 functions="custom_op_backed_by_autograd_fn", 2346 verbose=True, 2347 ) 2348 2349 x = torch.randn(2, 2, requires_grad=True, device="cuda") 2350 with config.patch(compiled_autograd=True), inductor_config.patch( 2351 "triton.cudagraphs", True 2352 ): 2353 out = torch.ops.test_cudagraphs_cpu_scalar_used_in_cpp_custom_op.custom_op_backed_by_autograd_fn( 2354 x 2355 ) 2356 opt_bwd = torch.compile(lambda: out.sum().backward()) 2357 opt_bwd() 2358 2359 self.assertEqual(counters["compiled_autograd"]["captures"], 1) 2360 # always safe to move, since we trace into the autograd::function bwd and can see if it's only used by aten ops 2361 
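        # (contrast with test_cudagraphs_cpu_scalar_used_in_python_custom_op above,
        # where the lifted backward is opaque and cudagraphs must be skipped)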
self.assertEqual(counters["inductor"]["cudagraph_skips"], 0) 2362 2363 def test_logs(self): 2364 logs, ctx = logs_to_string( 2365 torch._dynamo.compiled_autograd.__name__, "compiled_autograd" 2366 ) 2367 with compiled_autograd.enable(compiler_fn), ctx(): 2368 torch.randn(4, 4, requires_grad=True).sum().backward() 2369 2370 self.assertEqual(counters["compiled_autograd"]["captures"], 1) 2371 self.assertEqual(counters["compiled_autograd"]["compiles"], 1) 2372 assert "torch::autograd::AccumulateGrad (NodeCall" in logs.getvalue() 2373 assert ( 2374 "Cache miss due to new autograd node: torch::autograd::GraphRoot" 2375 not in logs.getvalue() 2376 ) 2377 2378 def test_verbose_logs_graph(self): 2379 def fn(): 2380 model = torch.nn.Sequential( 2381 torch.nn.Linear(4, 4), 2382 torch.nn.ReLU(), 2383 torch.nn.Linear(4, 4), 2384 torch.nn.ReLU(), 2385 ) 2386 x = torch.randn([2, 4]) 2387 result = model(x).sum() 2388 result.backward() 2389 yield model[0].weight.grad 2390 yield model[0].bias.grad 2391 yield model[2].weight.grad 2392 yield model[2].bias.grad 2393 2394 logs, ctx = logs_to_string( 2395 torch._dynamo.compiled_autograd.__name__, "compiled_autograd_verbose" 2396 ) 2397 with ctx(): 2398 self.check_output_and_recompiles(fn) 2399 2400 expected_logs = [ 2401 "SumBackward0 (NodeCall 1)", 2402 "ReluBackward0 (NodeCall 2)", 2403 "AddmmBackward0 (NodeCall 3)", 2404 "TBackward0 (NodeCall 4)", 2405 "torch::autograd::AccumulateGrad (NodeCall 5)", 2406 "ReluBackward0 (NodeCall 6)", 2407 "AddmmBackward0 (NodeCall 7)", 2408 "TBackward0 (NodeCall 8)", 2409 "torch::autograd::AccumulateGrad (NodeCall 9)", 2410 "torch::autograd::AccumulateGrad (NodeCall 10)", 2411 "torch::autograd::AccumulateGrad (NodeCall 11)", 2412 ] 2413 2414 self.assertEqual( 2415 sum(1 for e in expected_logs if e in logs.getvalue()), len(expected_logs) 2416 ) 2417 2418 @mock.patch( 2419 "torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count 2420 ) 2421 @mock.patch("torch._dynamo.config.inline_inbuilt_nn_modules", True) 2422 def test_verbose_logs_aot_id(self, _): 2423 def fn(): 2424 model = torch.nn.Sequential( 2425 torch.nn.Linear(4, 4), 2426 torch.nn.ReLU(), 2427 torch.nn.Linear(4, 4), 2428 torch.nn.ReLU(), 2429 ) 2430 x = torch.randn([2, 4]) 2431 2432 @torch.compile 2433 def forward(model, x): 2434 return model(x) 2435 2436 result = forward(model, x).sum() 2437 result.backward() 2438 yield model[0].weight.grad 2439 yield model[0].bias.grad 2440 yield model[2].weight.grad 2441 yield model[2].bias.grad 2442 2443 logs, ctx = logs_to_string( 2444 torch._dynamo.compiled_autograd.__name__, "compiled_autograd_verbose" 2445 ) 2446 with ctx(): 2447 self.check_output_and_recompiles(fn) 2448 2449 self.assertTrue("CompiledFunctionBackward0" in logs.getvalue()) 2450 2451 @mock.patch( 2452 "torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count 2453 ) 2454 def test_verbose_logs_aot_dispatcher_nodes(self, _): 2455 def fn(): 2456 @torch.compile 2457 def f(x): 2458 tmp1 = x.sin() 2459 tmp2 = x.cos() 2460 torch._dynamo.graph_break() 2461 return tmp1.sin() + tmp2.cos() 2462 2463 x = torch.randn(4, requires_grad=True) 2464 out = f(x) 2465 out.sum().backward() 2466 yield x.grad 2467 2468 logs, ctx = logs_to_string( 2469 torch._dynamo.compiled_autograd.__name__, "compiled_autograd_verbose" 2470 ) 2471 with ctx(): 2472 self.check_output_and_recompiles(fn) 2473 2474 expected_logs = [ 2475 "CompiledFunctionBackward1", 2476 "aot1_tangents_1", 2477 "aot1_sin_1", 2478 "aot1_primals_2", 2479 "aot1_neg", 2480 "aot0_tangents_2", 2481 
"aot1_cos_1", 2482 "aot1_primals_1", 2483 "aot0_tangents_1", 2484 "CompiledFunctionBackward0", 2485 "aot0_neg", 2486 "aot0_sin", 2487 "aot0_mul", 2488 "aot0_mul_1", 2489 "aot0_cos", 2490 "aot0_add", 2491 ] 2492 2493 self.assertEqual( 2494 sum(1 for e in expected_logs if e in logs.getvalue()), len(expected_logs) 2495 ) 2496 2497 @mock.patch( 2498 "torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count 2499 ) 2500 def test_verbose_logs_aot_dispatcher_nodes_hop(self, _): 2501 @dataclasses.dataclass 2502 class CustomObj: 2503 val: torch.Tensor 2504 2505 def fn(x, obj): 2506 y = x.sin() 2507 closure_var = y + 1 2508 y.register_hook(lambda grad: grad + obj.val + closure_var) 2509 z = y.sin() 2510 return z 2511 2512 opt_fn = torch.compile(fn) 2513 2514 x = torch.ones(4, requires_grad=True) 2515 y = torch.ones(4, requires_grad=True) 2516 obj = CustomObj(torch.tensor(88)) 2517 fn(x, obj).sum().backward() 2518 2519 logs, ctx = logs_to_string( 2520 torch._dynamo.compiled_autograd.__name__, "compiled_autograd_verbose" 2521 ) 2522 with ctx(), compiled_autograd.enable(compiler_fn): 2523 opt_fn(y, obj).sum().backward() 2524 self.assertEqual(x.grad, y.grad) 2525 2526 expected_logs = [ 2527 "CompiledFunctionBackward0", 2528 "aot0_primals_2", 2529 "aot0_tangents_2", 2530 "aot0_tangents_1", 2531 "aot0_sin", 2532 "aot0_cos", 2533 "aot0_mul", 2534 "aot0_add_1", 2535 "aot0_trace_wrapped", 2536 "aot0_cos_1", 2537 "aot0_mul_1", 2538 ] 2539 2540 self.assertEqual( 2541 sum(1 for e in expected_logs if e in logs.getvalue()), len(expected_logs) 2542 ) 2543 2544 @skipIfWindows(msg="AssertionError: Scalars are not equal!") 2545 def test_verbose_logs_cpp(self): 2546 torch._logging.set_logs(compiled_autograd_verbose=True) 2547 2548 def fn(): 2549 model = torch.nn.Sequential( 2550 torch.nn.Linear(4, 4), 2551 torch.nn.ReLU(), 2552 torch.nn.Linear(4, 4), 2553 torch.nn.ReLU(), 2554 ) 2555 for i in [10, 11, 12]: 2556 model.zero_grad() 2557 x = torch.randn([i, 4]) 2558 result = model(x).sum() 2559 result.backward() 2560 yield model[0].weight.grad 2561 yield model[0].bias.grad 2562 yield model[2].weight.grad 2563 yield model[2].bias.grad 2564 2565 logs, ctx = logs_to_string( 2566 torch._dynamo.compiled_autograd.__name__, "compiled_autograd_verbose" 2567 ) 2568 with ctx(): 2569 self.check_output_and_recompiles(fn, count=2) 2570 2571 patterns1 = [ 2572 r".*Cache miss due to new autograd node: torch::autograd::GraphRoot \(NodeCall 0\) with key size (\d+), " 2573 r"previous key sizes=\[\]\n", 2574 ] 2575 2576 # recompile 2577 patterns2 = [ 2578 r".*Cache miss due to changed shapes: marking size idx (\d+) of torch::autograd::GraphRoot \(NodeCall 0\) as dynamic\n", 2579 r".*Cache miss due to changed shapes: marking size idx (\d+) of SumBackward0 \(NodeCall 1\) as dynamic\n", 2580 r".*Cache miss due to changed shapes: marking size idx (\d+) of SumBackward0 \(NodeCall 1\) as dynamic\n", 2581 r".*Cache miss due to changed shapes: marking size idx (\d+) of ReluBackward0 \(NodeCall 2\) as dynamic\n", 2582 r".*Cache miss due to changed shapes: marking size idx (\d+) of AddmmBackward0 \(NodeCall 3\) as dynamic\n", 2583 r".*Cache miss due to changed shapes: marking size idx (\d+) of torch::autograd::AccumulateGrad " 2584 r"\(NodeCall 5\) as dynamic\n", 2585 r".*Cache miss due to changed shapes: marking size idx (\d+) of ReluBackward0 \(NodeCall 6\) as dynamic\n", 2586 ] 2587 2588 all_logs = logs.getvalue() 2589 2590 pattern1 = r"".join(patterns1) 2591 matches1 = re.findall(pattern1, all_logs) 2592 
        self.assertEqual(len(matches1), 1)
        assert isinstance(
            matches1[0], str
        )  # for a single match: matches1=['match'], for multiple matches: matches1=[('match1', 'match2')]...
        self.assertEqual(len(matches1), len(patterns1))

        pattern2 = r"".join(patterns2)
        matches2 = re.findall(pattern2, all_logs)
        self.assertEqual(len(matches2), 1)
        self.assertEqual(len(matches2[0]), len(patterns2))

    def test_verbose_logs_snapshot(self):
        def fn():
            model = torch.nn.Sequential(
                torch.nn.Linear(4, 4),
                torch.nn.ReLU(),
                torch.nn.Linear(4, 4),
                torch.nn.ReLU(),
            )
            x = torch.randn([2, 4])
            result = model(x).sum()
            result.backward()
            yield model[0].weight.grad
            yield model[0].bias.grad
            yield model[2].weight.grad
            yield model[2].bias.grad

        logs, ctx = logs_to_string(
            torch._dynamo.compiled_autograd.__name__, "compiled_autograd_verbose"
        )
        with ctx():
            with compiled_autograd.enable(compiler_fn):
                # has no effect: the verbose level was already snapshot when the
                # context manager was entered
                torch._logging.set_logs(compiled_autograd_verbose=True)
                # fn is a generator, so it must be consumed for backward() to run
                list(fn())

        unexpected_logs = [
            "Cache miss due to new autograd node: torch::autograd::GraphRoot (NodeCall 0)"
        ]

        self.assertEqual(sum(1 for e in unexpected_logs if e in logs.getvalue()), 0)

    @unittest.expectedFailure
    def test_saved_tensor_unpack_hook_ordering(self):
        # not the correct behaviour, I'm just preventing this from changing silently
        def f(x, y):
            return x * y

        pack_count = 0
        unpack_count = 0

        def pack_hook(x):
            nonlocal pack_count
            pack_count += 1
            return x

        def unpack_hook(x):
            nonlocal unpack_count
            unpack_count += 1
            return x

        def tensor_hook(_):
            # in eager, tensor_hook is fired before unpack_hook
            # but in compiled autograd, tensor_hook is lifted whereas unpack_hook is not
            self.assertEqual(unpack_count, 0)

        x = torch.ones(4, requires_grad=True)
        y = torch.ones(4, requires_grad=False)
        with torch.autograd.graph.saved_tensors_hooks(
            pack_hook, unpack_hook
        ), compiled_autograd.enable(make_compiler_fn(fullgraph=False)):
            out_test = f(x, y)
            self.assertEqual(pack_count, 1)
            self.assertEqual(unpack_count, 0)
            loss = out_test.sum()
            loss.register_hook(tensor_hook)
            loss.backward()
            self.assertEqual(pack_count, 1)
            self.assertEqual(unpack_count, 1)

    def test_reentrant_checkpointing(self):
        def fn(x):
            y = x.sin()
            z = y.cos()
            return (y * z).sum()

        inp = torch.rand(10, 10, requires_grad=True)
        out = torch.utils.checkpoint.checkpoint(fn, inp, use_reentrant=True)
        with self.assertRaisesRegex(
            RuntimeError,
            r"\(e.g. 
reentrant checkpointing\), this is not supported yet\.", 2683 ), torch._dynamo.compiled_autograd.enable(torch.compile): 2684 out.backward() 2685 2686 2687def load_test_module(name): 2688 testdir = Path(__file__).absolute().parent.parent 2689 with mock.patch("sys.path", [*sys.path, str(testdir)]): 2690 return SourceFileLoader( 2691 name, str(testdir / f"{name.replace('.', '/')}.py") 2692 ).load_module() 2693 2694 2695def make_wrapped(fn, ctxs): 2696 @functools.wraps(fn) 2697 def wrapped(self): 2698 torch._dynamo.reset() 2699 stack = contextlib.ExitStack() 2700 for ctx in ctxs: 2701 stack.enter_context(ctx) 2702 out = fn(self) 2703 stack.close() 2704 return out 2705 2706 return wrapped 2707 2708 2709def wrap_test_class(orig_cls): 2710 dct = orig_cls.__dict__.copy() 2711 for name in list(dct.keys()): 2712 fn = dct[name] 2713 if not callable(fn) or name in skipped_tests: 2714 continue 2715 elif known_failures_re.match(name) or name in known_failing_tests: 2716 dct[name] = unittest.expectedFailure 2717 elif name.startswith("test_"): 2718 fullgraph = name not in known_graph_breaks_tests 2719 ctxs = [ 2720 compiled_autograd.enable(make_compiler_fn(fullgraph=fullgraph)), 2721 test_contexts.get(name, contextlib.nullcontext()), 2722 ] 2723 dct[name] = make_wrapped(fn, ctxs) 2724 2725 cls = type( 2726 orig_cls.__name__ + "WithCompiledAutograd", 2727 orig_cls.__bases__, 2728 dct, 2729 ) 2730 cls.__file__ = __file__ 2731 return cls 2732 2733 2734known_graph_breaks_tests = { 2735 "test_hook_none", # uses assert in hook 2736 "test_post_accumulate_grad_hook_e2e", # optim.Adam manually graph breaks 2737 "test_tensor_hooks_inplace", # uses assert in hook 2738 "test_tensor_hooks_inplace_over_view", # uses assert in hook 2739 "test_grad_fn_prehooks", # uses assert in hook 2740 "test_grad_fn_prehooks_multiple_outputs", # uses assert in hook 2741 "test_grad_fn_prehooks_remove_hooks", # uses handle.remove() in hook 2742 "test_tensor_hooks_inplace_multiple_outputs", # uses assert in hook 2743 "test_hooks", # uses assert in hook 2744 "test_accumulate_grad_posthooks_can_observe_tensor_prehook", # allclose 2745 "test_saved_tensors_hook_version_counter_not_shared", # assertEqual 2746 "test_post_accumulate_grad_hook_returns_not_None", # throws 2747 "test_custom_function_cycle", # assertEqual 2748 "test_mark_non_differentiable_mixed", # assertTrue 2749 "test_materialize_grads", # assertEqual 2750 "test_return_leaf", # assertEqual 2751 "test_save_none_for_backward", # assertIsNone 2752 "test_saved_variables_deprecated", # warnings.warn 2753 "test_autograd_node_isinstance", # assertIsInstance 2754 "test_set_materialize_non_diff_grads", # assertIsNone 2755 "test_backward_dict_grad_for_nontensor", # torch/_custom_op/autograd.py in skip files 2756 "test_backward_dict_invalid_keys", # torch/_custom_op/autograd.py in skip files 2757 "test_backward_dict_requires_keys_for_input_optional_tensors", # torch/_custom_op/autograd.py in skip files 2758 "test_backward_dict_requires_keys_for_input_tensors", # torch/_custom_op/autograd.py in skip files 2759 "test_backward_grads_are_tensor_or_none", # torch/_custom_op/autograd.py in skip files 2760 "test_backward_impl_on_existing_op", # torch/_custom_op/autograd.py in skip files 2761 "test_backward_returns_dict", # torch/_custom_op/autograd.py in skip files 2762 "test_backward_tensorlist_input_requires_list_grads", # torch/_custom_op/autograd.py in skip files 2763 "test_backward_tensorlist_input_requires_list_grads_none_or_Tensor", # torch/_custom_op/autograd.py in skip files 2764 
"test_backward_tensorlist_input_requires_list_grads_with_same_numel", # torch/_custom_op/autograd.py in skip files 2765 "test_save_for_backward_inputs_are_namedtuple", # torch/_custom_op/autograd.py in skip files 2766} 2767 2768test_contexts = { 2769 "test_setitem_mask": config.patch(capture_dynamic_output_shape_ops=True), 2770 "test_index_backward_does_not_save_tensor": config.patch( 2771 capture_dynamic_output_shape_ops=True 2772 ), 2773} 2774 2775# These groups of tests aren't supported yet 2776known_failures_re = re.compile( 2777 r"^test_(sparse|profiler|gradcheck|checkpoint|named_tensor)" 2778) 2779 2780# Bugs needing investigation: 2781skipped_tests = { 2782 "test_callback_propagates_errors_from_device_thread", # fullgraph for queue_callback, but graph break for RuntimeError 2783} 2784 2785known_failing_tests = { 2786 # Category: Compiled autograd 2787 "test_current_graph_task_execution_order", # nodes are already freed by the time dynamo traces the lifted hook 2788 "test_reentrant_with_leaf_variable_hook", # hangs when enabled with graph breaks 2789 "test_reentrant_with_non_leaf_variable_hook", # hangs when enabled with graph breaks 2790 "test_anomaly_grad_warnings", # does not support anomaly mode 2791 "test_autograd_inplace_views_cross_dtype", # view_fn not supported by compiled autograd 2792 "test_current_node", # TorchDispatchMode not yet implemented for compiled autograd 2793 "test_post_accumulate_grad_hook_ordering", # accuracy error 2794 "test_retain_grad_cycle", # retains_grad_hooks 2795 "test_retain_grad_inplace", # retains_grad_hooks 2796 "test_retain_grad_inplace_over_view", # retains_grad_hooks 2797 "test_retains_grad_can_always_observe_tensor_prehook", # retains_grad_hooks 2798 "test_retains_grad_inplace_multiple_outputs", # retains_grad_hooks 2799 "test_reentrant_child_error", # hangs when enabled with graph breaks 2800 "test_accumulate_grad", # create_graph 2801 "test_anomaly_assign_parent_cleanup", # create_graph 2802 "test_anomaly_mode_no_check_nan", # anomaly mode 2803 "test_backward_create_graph_warns", # create_graph 2804 "test_backward_with_nonleaf_inputs", # create_graph 2805 "test_create_graph_and_full_backward_hook_cycle", # create_graph 2806 "test_current_graph_task_id", # autograd state already cleared once dynamo is called 2807 "test_custom_autograd_repeated_grad_grad", # create_graph 2808 "test_custom_function_forward_mode_forward_is_no_op", # forward AD 2809 "test_custom_function_forward_mode_inplace_checks", # forward AD 2810 "test_custom_function_forward_mode_view_checks", # forward AD 2811 "test_custom_function_forward_mode_wrong_formula", # forward AD 2812 "test_default_saved_tensors_hooks_double_backward", # create_graph 2813 "test_node_post_hook_registered_during_unpack_hook", # 'NoneType' object has no attribute 'register_hook' 2814 "test_full_backward_hook_double_backward", # create_graph 2815 "test_function", # create_graph 2816 "test_grad", # create_graph 2817 "test_grad_materialize_grads", # create_graph 2818 "test_grad_nonleaf", # create_graph 2819 "test_grad_nonleaf_many_outputs", # create_graph 2820 "test_hessian_vector", # create_graph 2821 "test_hook_edge_case_when_called_with_grad", # retains_grad_hooks 2822 "test_inplace_on_view_backward", # create_graph 2823 "test_multi_grad_any_hooks", # register_multi_grad_hook 2824 "test_multi_grad_all_hooks", # retains_grad_hooks 2825 "test_nested_anomaly_detect_nan", # create_graph 2826 "test_nested_anomaly_printstack_cleanup", # create_graph 2827 "test_once_differentiable", # create_graph 2828 
"test_prehook_ordering", # retains_grad_hooks 2829 "test_retain_grad", # retains_grad_hooks 2830 "test_saved_variable_packing_unpacking_saved_original_with_hooks", # create_graph 2831 "test_select_sum", # create_graph, also needs graph breaks 2832 "test_will_engine_execute_node", # retains_grad_hooks 2833 "test_backward_to_node", # retains_grad_hooks NYI 2834 "test_anomaly_detect_nan", # anomaly mode 2835 "test_custom_autograd_no_early_free", # create_graph 2836 "test_custom_function_error", # vjp 2837 "test_custom_function_save_for_forward", # vjp 2838 "test_deep_reentrant", # hangs with graph breaks 2839 "test_dont_materialize_grads", # undefined grad 2840 "test_grad_mode_restored_reentrant", # hangs with graph breaks 2841 "test_no_grad_copy", # setting static member in lifted backward 2842 "test_no_grad_copy_sparse", # setting static member in lifted backward 2843 "test_reentrant_priority", # hangs with graph breaks 2844 "test_reentrant_with_callbacks_both_depths", # hangs with graph breaks 2845 "test_reentrant_with_callbacks_depth_0", # probably hangs with graph breaks 2846 "test_reentrant_with_callbacks_depth_1", # probably hangs with graph breaks 2847 "test_save_output_nr", # output_nr grad passed as None 2848 "test_setup_context_when_forward_has_default_args", # autograd.Function with class methods 2849 "test_simple_reentrant", # hangs with graph breaks 2850 "test_lobpcg", # create_graph 2851 "test_grad_nonleaf_register_hook", # IndexError: list index out of range (NB: x.grad = y where both x and y are input tensors) 2852 "test_backward_twice_without_saved_values", # https://github.com/pytorch/pytorch/issues/129938 2853 # Category: Dynamo 2854 "test_accumulate_grad_tensor_reference", # Out of bounds: frame_state_entry.stride[i] is None 2855 "test_custom_function_exception", # torch.no_grad(), torch._dynamo.exc.Unsupported: missing: WITH_EXCEPT_START 2856 "test_to_sparse_backward", # Out of bounds: frame_state_entry.stride[i] is None 2857 "test_autograd_simple_views_python", # gradient is None 2858 "test_function_returns_undefined_tensor", # gradient is None 2859 "test_naughty_autograd_function_stashing_ctx", # bytecode issue 2860 "test_unrelated_inputs", # gradient batching rule not implemented for aten::sym_size.int 2861 "test_custom_function_non_tensor_inputs_outputs", # gradient batching rule not implemented for aten::sym_size.int 2862 "test_return_duplicate", # gradient batching rule not implemented for aten::sym_size.int 2863 "test_return_duplicate_inplace", # gradient batching rule not implemented for aten::sym_size.int 2864 "test_setitem", # CopySlices accuracy error 2865 # Category: Inductor 2866 "test_input_buffer_accum", # does not support sparse_grad=True: https://github.com/pytorch/pytorch/issues/120267 2867 "test_graph_save_on_cpu", # does not support pin_memory: https://github.com/pytorch/pytorch/issues/134173 2868 # Category: FakeTensor 2869 "test_saving_variable_to_disk", # torch.save should no-op and be recorded in the graph 2870 "test_wrapped_number_saved_tensors_hooks", # Proxy tensor should carryover is_wrapped_number_ of its original 2871 "test_grad_batched_grad", # torch._subclasses.fake_tensor.UnsupportedFakeTensorException: meta converter nyi 2872 "test_scalar_grad_mixed_device", # Fake Tensors aren't propagating device properly for 0-dim grads 2873 # Category: Divergence from eager 2874 "test_invalid_gradients", # can't give autograd error due to inaccurate output metadata of lifted backward 2875 "test_autograd_node_isinstance", # backward ctx is a fake cls 
and not directly a Node instance
    # Uncategorized
}

if not HAS_CUDA:
    # Found Tesla M60 which is too old to be supported by the triton GPU compiler
    known_failing_tests.add("test_type_conversions")

test_autograd = load_test_module("test_autograd")
test_custom_ops = load_test_module("test_custom_ops")

TestAutogradWithCompiledAutograd = wrap_test_class(test_autograd.TestAutograd)
TestCustomOpWithCompiledAutograd = wrap_test_class(test_custom_ops.TestCustomOp)

if __name__ == "__main__":
    if HAS_CPU:
        run_tests(needs="filelock")
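# A minimal sketch of the compiled autograd entry point exercised throughout this
# file (kept as a comment so it is not collected or executed by the test runner):
#
#   model = torch.nn.Linear(4, 4)
#   loss = model(torch.randn(2, 4)).sum()
#   with compiled_autograd.enable(torch.compile):
#       loss.backward()  # the backward pass is captured as a graph and compiled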