# Owner(s): ["module: functorch"]
# flake8: noqa: B950
import unittest
from collections import deque
from functools import partial
from typing import List, TYPE_CHECKING

import torch
import torch._dynamo
import torch._functorch
import torch._inductor
import torch._inductor.decomposition
from functorch.compile import (
    aot_function,
    default_decompositions,
    min_cut_rematerialization_partition,
    nop,
)
from torch._functorch.aot_autograd import aot_export_module
from torch._higher_order_ops.effects import with_effects
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import (
    _get_torch_cuda_version,
    SM70OrLater,
    SM80OrLater,
)
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    TEST_CUDA,
    TEST_WITH_ROCM,
    TestCase,
)
from torch.testing._internal.torchbind_impls import init_torchbind_implementations


if TYPE_CHECKING:
    from torch.utils.hooks import RemovableHandle

from torch.testing._internal.two_tensor import TwoTensor


def extract_graph(fx_g, _, graph_cell):
    graph_cell[0] = fx_g
    return fx_g


def get_fw_bw_graph(
    f, inps, partitioner=min_cut_rematerialization_partition, dynamic=False
):
    fw_graph_cell = [None]
    bw_graph_cell = [None]
    requires_grad = False

    def fn_req_grad(t):
        nonlocal requires_grad
        requires_grad = requires_grad or t.requires_grad
        return t

    torch.utils._pytree.tree_map_only(torch.Tensor, fn_req_grad, inps)

    out = aot_function(
        f,
        fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
        bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell)
        if requires_grad
        else nop,
        partition_fn=partitioner,
        decompositions=default_decompositions,
        dynamic=dynamic,
    )(*inps)

    if requires_grad:
        out.sum().backward()

    return (fw_graph_cell[0], bw_graph_cell[0])


def make_inputs_non_leaves(inps):
    return torch.utils._pytree.tree_map_only(torch.Tensor, lambda t: t.add(1), inps)

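# NOTE [with_effects token threading] -- editorial summary, inferred from the
# expected graphs asserted in the tests below (not an authoritative spec):
# functionalization rewrites an effectful call `op(*args)` into
#
#     new_token, *outs = torch.ops.higher_order.with_effects(token, op, *args)
#
# so the data dependency on the token preserves the original ordering of side
# effects. aot_export_module surfaces the tokens as extra graph inputs/outputs
# (see the gs.input_tokens / gs.output_tokens checks), while
# unlift_effect_tokens=True keeps them internal to the graph via
# torch.ops.prims._make_token / _sink_tokens. get_fw_bw_graph above captures
# the partitioned forward/backward graphs so the tests can assert on how
# tokens are threaded through each.
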
@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestWithEffects(TestCase):
    def setUp(self):
        init_torchbind_implementations()

    def test_print(self):
        class M(torch.nn.Module):
            def forward(self, x):
                torch.ops.aten._print("moo")
                res = x + x
                torch.ops.aten._print("moo")
                return (res,)

        inputs = (torch.randn(3),)

        # Without functionalization, print should just appear in the graph directly
        gm = make_fx(M())(*inputs)
        FileCheck().check_count("torch.ops.aten._print.default", 2, exactly=True).run(
            gm.code
        )

        # With functionalization, it should appear wrapped with with_effects()
        gm, gs = aot_export_module(M(), inputs, trace_joint=False)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1):
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.aten._print.default, 'moo'); arg0_1 = None
    getitem = with_effects[0]; with_effects = None
    add = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
    getitem_2 = with_effects_1[0]; with_effects_1 = None
    return (getitem_2, add)""",
        )
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)

        with torch._functorch.config.patch(unlift_effect_tokens=True):
            gm, gs = aot_export_module(M(), inputs, trace_joint=False)
            self.assertExpectedInline(
                str(gm.code).strip(),
                """\
def forward(self, arg1_1):
    _make_token_default = torch.ops.prims._make_token.default()
    with_effects = torch.ops.higher_order.with_effects(_make_token_default, torch.ops.aten._print.default, 'moo'); _make_token_default = None
    getitem = with_effects[0]; with_effects = None
    add = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
    getitem_2 = with_effects_1[0]; with_effects_1 = None
    _sink_tokens_default = torch.ops.prims._sink_tokens.default([getitem_2]); getitem_2 = _sink_tokens_default = None
    return [add]""",  # noqa: B950
            )

    def test_torchbind_custom_op(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                return (x + torch.ops._TorchScriptTesting.takes_foo(self.attr, x),)

        with enable_torchbind_tracing():
            gm, gs = aot_export_module(M(), (torch.ones(2, 3),), trace_joint=False)

        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1):
    _torchbind_obj0 = self._torchbind_obj0
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops._TorchScriptTesting.takes_foo.default, _torchbind_obj0, arg1_1); arg0_1 = _torchbind_obj0 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1]; with_effects = None
    add = torch.ops.aten.add.Tensor(arg1_1, getitem_1); arg1_1 = getitem_1 = None
    return (getitem, add)""",  # noqa: B950
        )
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)

    def test_print_with_buffer_mutations(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buf = torch.nn.Buffer(torch.ones(3))

            def forward(self, x):
                torch.ops.aten._print("moo")
                res = x + x
                self.buf.add_(res)
                res = self.buf + x
                torch.ops.aten._print("moo")
                return (res,)

        inputs = (torch.randn(3),)

        # With functionalization, it should appear wrapped with with_effects()
        gm, gs = aot_export_module(M(), inputs, trace_joint=False)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1):
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.aten._print.default, 'moo'); arg0_1 = None
    getitem = with_effects[0]; with_effects = None
    add = torch.ops.aten.add.Tensor(arg2_1, arg2_1)
    add_1 = torch.ops.aten.add.Tensor(arg1_1, add); arg1_1 = add = None
    add_2 = torch.ops.aten.add.Tensor(add_1, arg2_1); arg2_1 = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
    getitem_2 = with_effects_1[0]; with_effects_1 = None
    return (getitem_2, add_1, add_2)""",
        )
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)
        self.assertEqual(len(gs.buffers_to_mutate), 1)

    def test_print_with_input_mutations(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                torch.ops.aten._print("moo")
                res = x + x
                x.add_(res)
                res = x + x
                torch.ops.aten._print("moo")
                return (res,)

        inputs = (torch.randn(3),)

        # With functionalization, it should appear wrapped with with_effects()
        gm, gs = aot_export_module(M(), inputs, trace_joint=False)
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)
        self.assertEqual(len(gs.user_inputs_to_mutate), 1)

    def test_alias_op(self):
        def f(token, x):
            token, out = with_effects(token, torch.ops.aten.absolute_.default, x)
            return token, out

        with self.assertRaisesRegex(
            AssertionError, r"Ops with aliasing is not supported"
        ):
            make_fx(f)(torch.tensor([]), torch.tensor(4))

    def test_compile_aot_eager(self):
        def f(x):
            torch.ops.aten._print("moo")
            res = x + x
            torch.ops.aten._print("moo")
            return res

        inputs = (torch.randn(2, 3),)

        res = torch.compile(f, backend="aot_eager")(*inputs)
        self.assertTrue(torch.allclose(res, f(*inputs)))

    @unittest.skipIf(IS_WINDOWS, "triton")
    @unittest.skipIf(not SM70OrLater, "triton")
    def test_compile_inductor(self):
        def f(x):
            torch.ops.aten._print("moo")
            res = x + x
            torch.ops.aten._print("moo")
            return res

        inputs = (torch.randn(2, 3),)

        res = torch.compile(f, backend="inductor")(*inputs)
        self.assertTrue(torch.allclose(res, f(*inputs)))

    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
    @skipIfNoDynamoSupport
    def test_compile_inductor_external_op_return_none(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::inplace_add",
                "(Tensor input, Tensor(a!) output) -> ()",
                lib=lib,
            )

            def inplace_add(input: torch.Tensor, output: torch.Tensor) -> None:
                assert input.device == output.device
                output.add_(input)

            lib.impl("inplace_add", inplace_add, "CompositeExplicitAutograd")

            def f(x):
                out = torch.empty(3)
                out = torch.zeros_like(out)
                torch.ops.mylib.inplace_add(x, out)
                return out

            inputs = (torch.randn(3),)

            res = torch.compile(f, backend="inductor")(*inputs)
            self.assertTrue(torch.allclose(res, f(*inputs)))

    def test_compile_aot_eager_requires_grad(self):
        def f(x):
            torch.ops.aten._print("moo")
            res = x + x
            torch.ops.aten._print("moo")
            return res

        inputs = (torch.randn(2, 3, requires_grad=True),)

        res = torch.compile(f, backend="aot_eager")(*inputs)
        self.assertTrue(torch.allclose(res, f(*inputs)))

        res.sum().backward()

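    # The tests below all follow one recipe for making a custom op effectful
    # (a sketch of what the test bodies do, not a public API reference):
    #   1. define the op schema via torch.library.define / lib.define,
    #   2. provide a real kernel plus a meta/fake kernel,
    #   3. mark it ordered with
    #      _register_effectful_op(torch.ops.<ns>.<op>.default, _EffectType.ORDERED)
    # so that tracing wraps each call in with_effects and threads a token
    # through the graph; _deregister_effectful_op undoes the registration
    # where a test needs to clean up after itself.
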
    @unittest.skipIf(IS_WINDOWS, "triton")
    @unittest.skipIf(TEST_WITH_ROCM, "triton")
    @unittest.skipIf(not SM80OrLater, "triton")
    @unittest.skipIf(_get_torch_cuda_version() >= (11, 7), "triton")
    @unittest.skipIf(not TEST_CUDA, "triton")
    @skipIfNoDynamoSupport
    def test_register_effectful_custom_op(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch._dynamo.config.capture_scalar_outputs = True
            torch._dynamo.config.capture_dynamic_output_shape_ops = True

            torch.library.define(
                "mylib::record_scalar_tensor",
                "(Tensor x, str prefix) -> ()",
                lib=lib,
            )

            # shared dict to store the recorded tensor and prefix.
            recorded_dict = {}

            # PyTorch custom op implementation
            @torch.library.impl(
                "mylib::record_scalar_tensor",
                "CompositeExplicitAutograd",
                lib=lib,
            )
            def record_scalar_tensor(x, prefix):
                recorded_dict[prefix] = x.clone()
                return

            # Meta function of the custom op
            @torch.library.impl_abstract(
                "mylib::record_scalar_tensor",
                lib=lib,
            )
            def record_scalar_tensor_meta(x, prefix):
                return

            from torch._higher_order_ops.effects import (
                _EffectType,
                _register_effectful_op,
            )

            _register_effectful_op(
                torch.ops.mylib.record_scalar_tensor.default, _EffectType.ORDERED
            )

            my_config = {}
            my_config["MockModule"] = "mean"
            my_config["MockModule.linear"] = "mean"
            my_config["MockModule.relu"] = "mean"

            class MyLinear(torch.nn.Module):
                def __init__(self, in_features, out_features):
                    super().__init__()
                    self.weight = torch.nn.Parameter(
                        torch.randn(out_features, in_features), requires_grad=True
                    )
                    self.bias = torch.nn.Parameter(
                        torch.randn(out_features), requires_grad=True
                    )

                def forward(self, x):
                    return torch.nn.functional.linear(x, self.weight, self.bias)

            class MockModule(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.linear = MyLinear(10, 10)
                    self.register_buffer(
                        "buf0", torch.randn(10, 10, requires_grad=True)
                    )

                def forward(self, x):
                    return torch.nn.functional.relu(self.linear(x) + self.buf0)

            def forward_hook(
                module: torch.nn.Module,
                inputs: torch.Tensor,
                output: torch.Tensor,
                prefix: str,
                aggregate_method: str,
            ) -> torch.Tensor:
                if aggregate_method == "mean":
                    torch.ops.mylib.record_scalar_tensor(output.mean(), prefix)
                elif aggregate_method == "max":
                    torch.ops.mylib.record_scalar_tensor(output.max(), prefix)
                else:
                    # demo purpose: fall back to recording the sum
                    torch.ops.mylib.record_scalar_tensor(output.sum(), prefix)
                return output

            def add_hooks(module, config):
                handles: List[RemovableHandle] = []
                q = deque([(module.__class__.__name__, module)])
                while q:
                    name, m = q.pop()
                    children = [(name + "." + n, y) for (n, y) in m.named_children()]
                    q.extend(children)
                    aggregate_method = config.get(name, "mean")
                    prefix = name + ":" + aggregate_method
                    handle = m.register_forward_hook(
                        partial(
                            forward_hook,
                            prefix=prefix,
                            aggregate_method=aggregate_method,
                        )
                    )
                    if handle:
                        handles.append(handle)
                return handles

            x = torch.randn(10, 10, device="cuda")
            mod = MockModule().to("cuda")

            add_hooks(mod, my_config)

            opt_mod = torch.compile(backend="inductor")(mod)
            y = opt_mod(x)

            self.assertTrue(torch.allclose(y, mod(x)))
            # Ensure it works well with backward
            y.sum().backward()
            # Ensure the grad exists
            self.assertTrue(isinstance(opt_mod.linear.weight.grad, torch.Tensor))

            self.assertEqual(len(recorded_dict), 2)
            self.assertTrue("MockModule.linear:mean" in recorded_dict)
            self.assertTrue("MockModule:mean" in recorded_dict)

    @skipIfNoDynamoSupport
    def test_effectful_custom_op_with_subclasses(self):
        with torch.library._scoped_library("_mylib", "FRAGMENT") as lib:
            lib.define("zoo(Tensor x) -> Tensor")
            lib.define("zoo2(Tensor x) -> Tensor")

            d = {"fw": 0, "bw": 0}

            def reset_counter():
                d["fw"] = 0
                d["bw"] = 0

            def assert_counter(fw, bw):
                self.assertEqual(d["fw"], fw)
                self.assertEqual(d["bw"], bw)

            def foo_impl(a):
                d["fw"] = d["fw"] + 1
                return 2 * a.clone()

            def foo_meta(a):
                return a.clone()

            def foo2_impl(x):
                d["bw"] = d["bw"] + 1
                return x.clone()

            def foo2_meta(a):
                return a.clone()

            for backend in ["CPU", "CUDA"]:
                lib.impl("zoo", foo_impl, backend)
                lib.impl("zoo2", foo2_impl, backend)
            lib.impl("zoo", foo_meta, "Meta")
            lib.impl("zoo2", foo2_meta, "Meta")

            def foo_bwd(ctx, grad):
                torch.ops._mylib.zoo2(grad)
                return grad.clone()

            torch.library.register_autograd("_mylib::zoo", foo_bwd, lib=lib)

            from torch._higher_order_ops.effects import (
                _EffectType,
                _register_effectful_op,
            )

            _register_effectful_op(torch.ops._mylib.zoo.default, _EffectType.ORDERED)
            _register_effectful_op(torch.ops._mylib.zoo2.default, _EffectType.ORDERED)

            def fn(x, y):
                return torch.ops._mylib.zoo(x) + y

            def ins_sc():
                return (
                    TwoTensor(
                        torch.tensor([1.0, 2.0, 3.0]), torch.tensor([1.0, 2.0, 3.0])
                    ),
                    torch.tensor([4.0, 5.0, 6.0]),
                )

            def ins_dense():
                return torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0])

            for i, (ins_fn, expected_fw_count) in enumerate(
                zip([ins_sc, ins_dense], [2, 1])
            ):
                reset_counter()
                ref_out = fn(*ins_fn())
                assert_counter(expected_fw_count, 0)

                compiled_fn = torch.compile(fn, backend="aot_eager")
                out = compiled_fn(*ins_fn())
                reset_counter()
                out = compiled_fn(*ins_fn())
                assert_counter(expected_fw_count, 0)

                self.assertEqual(ref_out, out)

            def ins_dense_req_grad():
                return (
                    torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
                    torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
                )

            def ins_sc_req_grad():
                return (
                    TwoTensor(
                        torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
                        torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
                    ),
                    TwoTensor(
                        torch.tensor([7.0, 8.0, 9.0], requires_grad=True),
                        torch.tensor([10.0, 11.0, 12.0], requires_grad=True),
                    ),
                )

            for i, (
                ins_fn_req_grad,
                (
                    expected_fw_count,
                    expected_fw_count_after_bw,
                    expected_bw_count_after_bw,
                ),
            ) in enumerate(
                zip([ins_dense_req_grad, ins_sc_req_grad], [(1, 1, 1), (2, 2, 2)])
            ):
                ref_ins = ins_fn_req_grad()
                reset_counter()
                ref_out = fn(*ref_ins)
                assert_counter(expected_fw_count, 0)
                ref_out.sum().backward()
                assert_counter(expected_fw_count_after_bw, expected_bw_count_after_bw)

                compiled_fn = torch.compile(fn, fullgraph=True)

                ins = ins_fn_req_grad()
                out = compiled_fn(*ins)
                reset_counter()
                out = compiled_fn(*ins)
                assert_counter(expected_fw_count, 0)
                self.assertEqual(ref_out, out)
                out.sum().backward()
                assert_counter(expected_fw_count_after_bw, expected_bw_count_after_bw)
                self.assertEqual(ref_ins[1].grad, ins[1].grad)
                self.assertEqual(ref_ins[0].grad, ins[0].grad)

            fw_graph, bw_graph = get_fw_bw_graph(fn, ins_sc_req_grad())
            self.assertExpectedInline(
                fw_graph.code.strip(),
                """\
def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5):
    with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops._mylib.zoo.default, primals_2); primals_1 = primals_2 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1]; with_effects = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._mylib.zoo.default, primals_3); getitem = primals_3 = None
    getitem_2 = with_effects_1[0]
    getitem_3 = with_effects_1[1]; with_effects_1 = None
    add = torch.ops.aten.add.Tensor(getitem_1, primals_4); getitem_1 = primals_4 = None
    add_1 = torch.ops.aten.add.Tensor(getitem_3, primals_5); getitem_3 = primals_5 = None
    return (getitem_2, add, add_1)""",
            )
            self.assertExpectedInline(
                bw_graph.code.strip(),
                """\
def forward(self, tangents_1, tangents_2, tangents_token):
    with_effects_2 = torch.ops.higher_order.with_effects(tangents_token, torch.ops._mylib.zoo2.default, tangents_1); tangents_token = None
    getitem_4 = with_effects_2[0]; with_effects_2 = None
    with_effects_3 = torch.ops.higher_order.with_effects(getitem_4, torch.ops._mylib.zoo2.default, tangents_2); getitem_4 = None
    getitem_6 = with_effects_3[0]; with_effects_3 = None
    clone = torch.ops.aten.clone.default(tangents_1)
    clone_1 = torch.ops.aten.clone.default(tangents_2)
    return (clone, clone_1, tangents_1, tangents_2, getitem_6)""",
            )

    def test_effects_and_input_mutation_return(self):
        def fn(a, b):
            torch.ops.aten._print("effect")
            return torch.sin(a, out=b)

        inp = [torch.randn(3, 3), torch.ones(3, 3)]
        ref_out = fn(*inp)
        out = torch.compile(fn, fullgraph=True)(*inp)
        self.assertEqual(ref_out, out)

        fw_graph, bw_graph = get_fw_bw_graph(fn, inp)
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1):
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.aten._print.default, 'effect'); arg0_1 = None
    getitem = with_effects[0]; with_effects = None
    sin = torch.ops.aten.sin.default(arg1_1); arg1_1 = None
    return (getitem, sin, sin)""",
        )

    def test_effects_and_input_output_view_simple(self):
        def fn(a):
            return a.view(-1)

        inp = [torch.ones(2, 2, requires_grad=False).add(1)]
        ref_out = fn(*inp)
        out = torch.compile(fn, fullgraph=True)(*inp)
        self.assertEqual(ref_out, out)

        inp = [torch.ones(2, 2, requires_grad=True).add(1)]
        ref_out = fn(*inp)
        out = torch.compile(fn, fullgraph=True)(*inp)
        self.assertEqual(ref_out, out)

        fw_graph, bw_graph = get_fw_bw_graph(fn, inp)

        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, arg0_1):
    view = torch.ops.aten.view.default(arg0_1, [-1]); arg0_1 = None
    return (view,)""",
        )

    def test_effects_and_aliased_outputs(self):
        def fn(a):
            b = a.mul(2)
            torch.ops.aten._print("effect")
            c = b.view(-1)
            return b, c

        f_compiled = aot_function(fn, nop)
        for req_grad in [True, False]:
            inp = torch.ones(3, requires_grad=req_grad)
            out_ref = fn(inp)
            out_test = f_compiled(inp)
            self.assertEqual(out_ref[0], out_test[0])
            self.assertEqual(out_ref[1], out_test[1])
            # Try mutating one of the outputs, which is aliased.
            out_ref[0].mul_(3)
            out_test[0].mul_(3)
            # Assert that the aliasing relationship was preserved
            self.assertEqual(out_ref[0], out_test[0])
            self.assertEqual(out_ref[1], out_test[1])

    def test_effects_and_input_mutation_is_output(self):
        def fn(a):
            a.mul_(2)
            torch.ops.aten._print("effect")
            return a

        inp = make_inputs_non_leaves([torch.ones(3, 3, requires_grad=True)])
        ref_out = fn(*inp)
        out = torch.compile(fn, backend="aot_eager", fullgraph=True)(*inp)
        self.assertEqual(ref_out, out)

        inp = [torch.ones(3, 3, requires_grad=False)]
        ref_out = fn(*inp)
        out = torch.compile(fn, backend="aot_eager", fullgraph=True)(*inp)
        self.assertEqual(ref_out, out)

        fw_graph, bw_graph = get_fw_bw_graph(fn, inp)
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, arg0_1, arg1_1):
    mul = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.aten._print.default, 'effect'); arg0_1 = None
    getitem = with_effects[0]; with_effects = None
    return (getitem, mul, mul)""",
        )

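    # In the remaining tests the effectful op also runs in the backward pass
    # (either directly or via a custom backward formula). As the expected
    # graphs below show, the backward graph then receives a `tangents_token`
    # input and returns the final token, mirroring the token threading in the
    # forward graph.
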
    @skipIfTorchDynamo()
    def test_effectful_op_in_backward(self):
        with torch.library._scoped_library("_mylib", "FRAGMENT") as lib:
            lib.define("foo(Tensor x) -> Tensor")

            def foo_impl(a):
                return a.clone()

            def foo_bwd(ctx, grad):
                return torch.ops._mylib.foo(grad)

            for backend in ["CPU", "CUDA", "Meta"]:
                lib.impl("foo", foo_impl, backend)

            torch.library.register_autograd("_mylib::foo", foo_bwd, lib=lib)

            from torch._higher_order_ops.effects import (
                _deregister_effectful_op,
                _EffectType,
                _register_effectful_op,
            )

            _register_effectful_op(torch.ops._mylib.foo.default, _EffectType.ORDERED)
            try:

                def fn(x, y):
                    return torch.ops._mylib.foo(x) + y

                def ins_dense_req_grad():
                    return (
                        torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
                        torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
                    )

                def ins_sc_req_grad():
                    return (
                        TwoTensor(
                            torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
                            torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
                        ),
                        torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
                    )

                for i, ins_fn in enumerate([ins_dense_req_grad, ins_sc_req_grad]):
                    ref_ins = ins_fn()

                    ref_out = fn(*ref_ins)
                    ref_out.sum().backward()

                    compiled_fn = torch.compile(fn, backend="inductor", fullgraph=True)
                    ins = ins_fn()
                    out = compiled_fn(*ins)
                    self.assertEqual(ref_out, out)
                    out.sum().backward()
                    self.assertEqual(ref_ins[1].grad, ins[1].grad)
                    self.assertEqual(ref_ins[0].grad, ins[0].grad)

                    fw_graph, bw_graph = get_fw_bw_graph(fn, ins)
                    if i == 0:
                        self.assertExpectedInline(
                            fw_graph.code.strip(),
                            """\
def forward(self, primals_1, primals_2, primals_3):
    with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops._mylib.foo.default, primals_2); primals_1 = primals_2 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1]; with_effects = None
    add = torch.ops.aten.add.Tensor(getitem_1, primals_3); getitem_1 = primals_3 = None
    return (getitem, add)""",
                        )
                        self.assertExpectedInline(
                            bw_graph.code.strip(),
                            """\
def forward(self, tangents_1, tangents_token):
    with_effects_1 = torch.ops.higher_order.with_effects(tangents_token, torch.ops._mylib.foo.default, tangents_1); tangents_token = None
    getitem_2 = with_effects_1[0]
    getitem_3 = with_effects_1[1]; with_effects_1 = None
    return (getitem_3, tangents_1, getitem_2)""",
                        )
                    elif i == 1:
                        self.assertExpectedInline(
                            fw_graph.code.strip(),
                            """\
def forward(self, primals_1, primals_2, primals_3, primals_4):
    with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops._mylib.foo.default, primals_2); primals_1 = primals_2 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1]; with_effects = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._mylib.foo.default, primals_3); getitem = primals_3 = None
    getitem_2 = with_effects_1[0]
    getitem_3 = with_effects_1[1]; with_effects_1 = None
    add = torch.ops.aten.add.Tensor(getitem_1, primals_4); getitem_1 = None
    add_1 = torch.ops.aten.add.Tensor(getitem_3, primals_4); getitem_3 = primals_4 = None
    return (getitem_2, add, add_1)""",
                        )
                        self.assertExpectedInline(
                            bw_graph.code.strip(),
                            """\
def forward(self, tangents_1, tangents_2, tangents_token):
    with_effects_2 = torch.ops.higher_order.with_effects(tangents_token, torch.ops._mylib.foo.default, tangents_1); tangents_token = None
    getitem_4 = with_effects_2[0]
    getitem_5 = with_effects_2[1]; with_effects_2 = None
    with_effects_3 = torch.ops.higher_order.with_effects(getitem_4, torch.ops._mylib.foo.default, tangents_2); getitem_4 = None
    getitem_6 = with_effects_3[0]
    getitem_7 = with_effects_3[1]; with_effects_3 = None
    return (getitem_5, getitem_7, tangents_1, tangents_2, getitem_6)""",
                        )
                    else:
                        raise NotImplementedError
            finally:
                _deregister_effectful_op(torch.ops._mylib.foo.default)

    @skipIfNoDynamoSupport
    def test_regular_effectful_op_only_in_backward(self):
        from torch._higher_order_ops.effects import (
            _deregister_effectful_op,
            _EffectType,
            _register_effectful_op,
        )

        _register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
        try:

            def fn(x):
                return x.sin()

            def inps_fn():
                return (torch.tensor([1.0, 2.0, 3.0], requires_grad=True),)

            torch.compile(fn, backend="inductor", fullgraph=True)(*inps_fn())

            fw_graph, bw_graph = get_fw_bw_graph(fn, inps_fn())
            self.assertExpectedInline(
                fw_graph.code.strip(),
                """\
def forward(self, primals_1):
    sin = torch.ops.aten.sin.default(primals_1)
    return (sin, primals_1)""",
            )
            self.assertExpectedInline(
                bw_graph.code.strip(),
                """\
def forward(self, primals_1, tangents_1, tangents_token):
    with_effects = torch.ops.higher_order.with_effects(tangents_token, torch.ops.aten.cos.default, primals_1); tangents_token = primals_1 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1]; with_effects = None
    mul = torch.ops.aten.mul.Tensor(tangents_1, getitem_1); tangents_1 = getitem_1 = None
    return (mul, getitem)""",
            )

            def inps_fn_sc():
                return (
                    TwoTensor(
                        torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
                        torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
                    ),
                )

            torch.compile(fn, backend="inductor", fullgraph=True)(*inps_fn_sc())
            fw_graph, bw_graph = get_fw_bw_graph(fn, inps_fn_sc())
            self.assertExpectedInline(
                fw_graph.code.strip(),
                """\
def forward(self, primals_1, primals_2):
    sin = torch.ops.aten.sin.default(primals_1)
    sin_1 = torch.ops.aten.sin.default(primals_2)
    return (sin, sin_1, primals_1, primals_2)""",
            )
            self.assertExpectedInline(
                bw_graph.code.strip(),
                """\
def forward(self, primals_1, primals_2, tangents_1, tangents_2, tangents_token):
    with_effects = torch.ops.higher_order.with_effects(tangents_token, torch.ops.aten.cos.default, primals_1); tangents_token = primals_1 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1]; with_effects = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten.cos.default, primals_2); getitem = primals_2 = None
    getitem_2 = with_effects_1[0]
    getitem_3 = with_effects_1[1]; with_effects_1 = None
    mul = torch.ops.aten.mul.Tensor(tangents_1, getitem_1); tangents_1 = getitem_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(tangents_2, getitem_3); tangents_2 = getitem_3 = None
    return (mul, mul_1, getitem_2)""",
            )
        finally:
            _deregister_effectful_op(torch.ops.aten.cos.default)

    @skipIfNoDynamoSupport
    def test_regular_effectful_op_in_forward_and_backward(self):
        from torch._higher_order_ops.effects import (
            _deregister_effectful_op,
            _EffectType,
            _register_effectful_op,
        )

        _register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
        try:

            def fn(x):
                x = x.cos()
                return x.sin()

            inps = (torch.tensor([1.0, 2.0, 3.0], requires_grad=True),)
            torch.compile(fn, backend="inductor", fullgraph=True)(*inps)

            fw_graph, bw_graph = get_fw_bw_graph(fn, inps)
            self.assertExpectedInline(
                fw_graph.code.strip(),
                """\
def forward(self, primals_1, primals_2):
    with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops.aten.cos.default, primals_2); primals_1 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1]; with_effects = None
    sin = torch.ops.aten.sin.default(getitem_1)
    return (getitem, sin, primals_2, getitem_1)""",
            )
            self.assertExpectedInline(
                bw_graph.code.strip(),
                """\
def forward(self, primals_2, getitem_1, tangents_1, tangents_token):
    with_effects_1 = torch.ops.higher_order.with_effects(tangents_token, torch.ops.aten.cos.default, getitem_1); tangents_token = getitem_1 = None
    getitem_2 = with_effects_1[0]
    getitem_3 = with_effects_1[1]; with_effects_1 = None
    mul = torch.ops.aten.mul.Tensor(tangents_1, getitem_3); tangents_1 = getitem_3 = None
    sin_1 = torch.ops.aten.sin.default(primals_2); primals_2 = None
    neg = torch.ops.aten.neg.default(sin_1); sin_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(mul, neg); mul = neg = None
    return (mul_1, getitem_2)""",
            )
        finally:
            _deregister_effectful_op(torch.ops.aten.cos.default)


if __name__ == "__main__":
    run_tests()