1""" 2PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes 3with test_functionalization_with_native_python_assertion) 4""" 5 6# Owner(s): ["oncall: export"] 7import math 8import operator 9import unittest 10from re import escape 11from typing import List, Set 12 13import torch 14from functorch.experimental.control_flow import cond 15from torch._dynamo.eval_frame import is_dynamo_supported 16from torch._export.non_strict_utils import ( 17 _fakify_script_objects, 18 _gather_constant_attrs, 19) 20from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse 21from torch._export.passes.replace_set_grad_with_hop_pass import ( 22 _is_set_grad_enabled_node, 23 _is_set_grad_enabled_sub_mod, 24) 25from torch._export.passes.replace_view_ops_with_view_copy_ops_pass import ( 26 get_view_copy_of_view_op, 27 is_view_op, 28 ReplaceViewOpsWithViewCopyOpsPass, 29) 30from torch._export.utils import ( 31 node_inline_, 32 nodes_count, 33 nodes_filter, 34 nodes_map, 35 sequential_split, 36) 37from torch._higher_order_ops.auto_functionalize import auto_functionalized 38from torch._subclasses.fake_tensor import FakeTensorMode 39from torch.export import export 40from torch.export._remove_auto_functionalized_pass import ( 41 unsafe_remove_auto_functionalized_pass, 42) 43from torch.export._remove_effect_tokens_pass import _remove_effect_tokens 44from torch.export.passes import move_to_device_pass 45from torch.fx.experimental.symbolic_shapes import ShapeEnv 46from torch.fx.passes.infra.partitioner import Partition 47from torch.fx.passes.operator_support import OperatorSupport 48from torch.library import _scoped_library, impl 49from torch.testing._internal.common_cuda import TEST_CUDA 50from torch.testing._internal.common_utils import ( 51 IS_WINDOWS, 52 run_tests, 53 skipIfTorchDynamo, 54 TestCase, 55) 56from torch.testing._internal.torchbind_impls import init_torchbind_implementations 57from torch.utils import _pytree as pytree 58 59 60def count_call_function(graph: torch.fx.Graph, target: torch.ops.OpOverload) -> int: 61 count = 0 62 for node in graph.nodes: 63 if node.op == "call_function" and node.target == target: 64 count += 1 65 return count 66 67 68class _AddOperatorSupport(OperatorSupport): 69 def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: 70 return node.op == "call_function" and node.target in {operator.add} 71 72 73class _AtenAddOperatorSupport(OperatorSupport): 74 def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: 75 return node.op == "call_function" and node.target in {torch.ops.aten.add.Tensor} 76 77 78def _to_partition_names(partitions: List[Partition]) -> List[Set[str]]: 79 return [{n.name for n in p.nodes} for p in partitions] 80 81 82def _get_output_names(gm: torch.fx.GraphModule) -> List[str]: 83 output_node = next(n for n in gm.graph.nodes if n.op == "output") 84 args = pytree.tree_leaves(output_node.args) 85 # if isinstance(args, tuple) and len(args) == 1: 86 # args = args[0] 87 return [str(arg) for arg in args] 88 89 90class ModelsWithScriptObjectAttr: 91 class Simple(torch.nn.Module): 92 def __init__(self) -> None: 93 super().__init__() 94 self.attr = torch.classes._TorchScriptTesting._Foo(10, 20) 95 96 class SimpleWithAttrInContainer(torch.nn.Module): 97 def __init__(self) -> None: 98 super().__init__() 99 self.attr = torch.classes._TorchScriptTesting._Foo(10, 20) 100 self.pytree_attr2 = [ 101 torch.classes._TorchScriptTesting._Foo(1, 2), 102 { 103 torch.classes._TorchScriptTesting._Foo(3, 4), 104 }, 105 {"foo": 
            ]

    class NestedWithAttrInContainer(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
            self.pytree_attr2 = [
                torch.classes._TorchScriptTesting._Foo(1, 2),
                {
                    torch.classes._TorchScriptTesting._Foo(3, 4),
                },
                {"foo": torch.classes._TorchScriptTesting._Foo(5, 6)},
            ]
            self.sub_mod = ModelsWithScriptObjectAttr.Simple()
            self.sub_mod2 = ModelsWithScriptObjectAttr.SimpleWithAttrInContainer()

    class MoreNestedWithAttrInContainer(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
            self.pytree_attr2 = [
                torch.classes._TorchScriptTesting._Foo(1, 2),
                {
                    torch.classes._TorchScriptTesting._Foo(3, 4),
                },
                {"foo": torch.classes._TorchScriptTesting._Foo(5, 6)},
            ]
            self.sub_mod = ModelsWithScriptObjectAttr.Simple()
            self.sub_mod2 = ModelsWithScriptObjectAttr.NestedWithAttrInContainer()


def _set_grad_enabled_tests():
    from torch.export._trace import _export

    class SetGradOp(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            torch._C._set_grad_enabled(True)
            c = x.sin().sum()
            torch._C._set_grad_enabled(False)
            d = c + 1
            torch._C._set_grad_enabled(True)
            e = d - 1
            return d, e

    class SetGradCtxManager(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            with torch.enable_grad():
                c = x.sin().sum()
            with torch.no_grad():
                d = c + 1
            with torch.enable_grad():
                e = d - 1
            return d, e

    class SetGradCtxManagerMultiDep(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            with torch.enable_grad():
                c1 = x.sin().sum()
                c2 = x.cos().sum()
            with torch.no_grad():
                d1 = c1 + 1
                d2 = c2 + 1
            with torch.enable_grad():
                e1 = d1 - 1
                e2 = d2 - 1
            return d1, d2, e1, e2

    x = torch.randn(2, 2)

    def _get_predispatch_module(mod, args, ambient_grad_enabled=True):
        with torch.set_grad_enabled(ambient_grad_enabled):
            return _export(mod, args, pre_dispatch=True).module()

    return {
        "ctx_manager": (
            SetGradCtxManager(),
            _get_predispatch_module(SetGradCtxManager(), (x,)),
            (x,),
        ),
        "ctx_manager_under_no_grad": (
            SetGradCtxManager(),
            _get_predispatch_module(SetGradCtxManager(), (x,), False),
            (x,),
        ),
        "ctx_manager_multi_dep": (
            SetGradCtxManagerMultiDep(),
            _get_predispatch_module(SetGradCtxManagerMultiDep(), (x,)),
            (x,),
        ),
        "ctx_manager_multi_dep_no_grad": (
            SetGradCtxManagerMultiDep(),
            _get_predispatch_module(SetGradCtxManagerMultiDep(), (x,), False),
            (x,),
        ),
        "op": (SetGradOp(), _get_predispatch_module(SetGradOp(), (x,)), (x,)),
        "op_under_no_grad": (
            SetGradOp(),
            _get_predispatch_module(SetGradOp(), (x,), False),
            (x,),
        ),
    }


def _with_autocast_tests():
    from torch.export._trace import _export

    class WithAutocastOp(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            with torch.autocast(device_type="cpu", enabled=True):
                c = x.sin().sum()
            with torch.autocast(device_type="cpu", enabled=False):
                d = c + 1
            with torch.autocast(device_type="cpu", enabled=True):
                e = d - 1
            return d, e

    class WithAutocastOpMultiDep(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            with torch.autocast(device_type="cpu", enabled=True):
                c1 = x.sin().sum()
                c2 = x.cos().sum()
            with torch.autocast(device_type="cpu", enabled=False):
                d1 = c1 + 1
                d2 = c2 + 1
            with torch.autocast(device_type="cpu", enabled=True):
                e1 = d1 - 1
                e2 = d2 - 1
            return d1, d2, e1, e2

    class SplitAutocastOp(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            with torch.autocast(device_type="cpu", enabled=True):
                c = x.sin().sum()
            d = c + 1
            with torch.autocast(device_type="cpu", enabled=True):
                e = d - 1
            return d, e

    x = torch.randn(2, 2)

    def _get_predispatch_module(mod, args):
        return _export(mod, args, pre_dispatch=True).module()

    return {
        "ctx_manager": (
            WithAutocastOp(),
            _get_predispatch_module(WithAutocastOp(), (x,)),
            (x,),
        ),
        "ctx_manager_multi_dep": (
            WithAutocastOpMultiDep(),
            _get_predispatch_module(WithAutocastOpMultiDep(), (x,)),
            (x,),
        ),
        "ctx_manager_split": (
            SplitAutocastOp(),
            _get_predispatch_module(SplitAutocastOp(), (x,)),
            (x,),
        ),
    }


def _sequential_split_inline_tests():
    from torch.export._trace import _export

    class Simple(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            c = x.sin().sum()
            d = c + 1
            e = d - 1
            return d, e

    class MultiDep(torch.nn.Module):
        def forward(self, x1, x2):
            x1 = x1 + 1
            x2 = x2 + 1
            c1 = x1.sin()
            c2 = x2.cos()
            d1 = c1 + 1
            d2 = c2 + 1
            e1 = d1 - 1
            e2 = d2 - 1
            return d1, d2, e1, e2

    def _get_predispatch_module(mod, args):
        return _export(mod, args, pre_dispatch=True).module()

    def _insert_delimiter_nodes(gm: torch.fx.GraphModule, step: int = 1):
        # Insert a _set_grad_enabled call before every `step`-th call_function
        # node so the graph can later be split on these delimiter nodes.
        insert_locs = []
        for i, node in enumerate(
            nodes_filter(gm.graph.nodes, lambda n: n.op == "call_function")
        ):
            if i % step == 0:
                insert_locs.append(node)

        for i, node in enumerate(insert_locs):
            with gm.graph.inserting_before(node):
                gm.graph.call_function(
                    torch._C._set_grad_enabled, (i % 2 == 0,), {}
                )
        return gm

    x = torch.randn(2, 2)
    simple = _get_predispatch_module(Simple(), (x,))
    simple1 = _get_predispatch_module(Simple(), (x,))
    multi_dep = _get_predispatch_module(MultiDep(), (x, x.sin()))
    multi_dep1 = _get_predispatch_module(MultiDep(), (x, x.sin()))
    return {
        "simple_step1": (_insert_delimiter_nodes(simple1, 1), (x,)),
        "simple_step2": (_insert_delimiter_nodes(simple, 2), (x,)),
        "multi_dep_step2": (_insert_delimiter_nodes(multi_dep, 2), (x, x.sin())),
        "multi_dep_step3": (_insert_delimiter_nodes(multi_dep1, 3), (x, x.sin())),
    }


@skipIfTorchDynamo("recursively running dynamo on export is unlikely")
@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
class TestPasses(TestCase):
    def setUp(self):
        super().setUp()
        self.SEQUENTIAL_SPLIT_INLINE_TESTS = _sequential_split_inline_tests()
        self.SET_GRAD_ENABLED_TESTS = _set_grad_enabled_tests()
        self.WITH_AUTOCAST_TESTS = _with_autocast_tests()

        init_torchbind_implementations()

    def tearDown(self):
        self.SEQUENTIAL_SPLIT_INLINE_TESTS.clear()
        self.SET_GRAD_ENABLED_TESTS.clear()
        self.WITH_AUTOCAST_TESTS.clear()
        super().tearDown()

    def test_runtime_assert_one_dim(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x.cos()

        x = torch.zeros(2, 2, 3)

        dim1_x = torch.export.Dim("dim1_x", min=2, max=6)
        ep = torch.export.export(M(), (x,), dynamic_shapes={"x": {1: dim1_x}})

        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[0].shape[1] to be <= 6, but got 7"),
        ):
            ep.module()(torch.zeros(2, 7, 3))

        self.assertEqual(
            ep.module()(torch.ones(2, 4, 3)), M().forward(torch.ones(2, 4, 3))
        )

    def test_runtime_assert_multiple_dims(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                return x.cos().sum() + y.sin().sum()

        x = torch.zeros(4, 2, 3)
        y = torch.zeros(5, 5, 5)

        dim1_x = torch.export.Dim("dim1_x", min=2, max=6)
        dim0_x, dim0_y = torch.export.dims("dim0_x", "dim0_y", min=3)

        ep = torch.export.export(
            M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": {0: dim0_y}}
        )

        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[0].shape[1] to be <= 6, but got 7"),
        ):
            ep.module()(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[1].shape[0] to be >= 3, but got 2"),
        ):
            ep.module()(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

    def test_runtime_assert_some_dims_not_specified(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                return x.cos().sum() + y.sin().sum()

        x = torch.zeros(4, 2, 3)
        y = torch.zeros(5, 5, 5)

        dim1_x = torch.export.Dim("dim1_x", min=2, max=6)
        dim0_x = torch.export.Dim("dim0_x", min=3)

        ep = torch.export.export(
            M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": None}
        )

        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[0].shape[1] to be <= 6, but got 7"),
        ):
            ep.module()(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

        # y is specialized to 5
        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[1].shape[0] to be equal to 5, but got 2"),
        ):
            ep.module()(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

        # Since we didn't insert the constraint for x[1] >= 2, it should work for case where x[1] == 1
        gm_result_for_1_size = ep.module()(torch.ones(3, 1, 3), torch.ones(5, 5, 5))
        eager_result_for_1_size = M().forward(torch.ones(3, 1, 3), torch.ones(5, 5, 5))

        self.assertEqual(gm_result_for_1_size, eager_result_for_1_size)

    def test_runtime_assert_some_inps_not_used(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                return y.cos().sum()

        x = torch.zeros(4, 2, 3)
        y = torch.zeros(5, 5, 5)

        dim1_y = torch.export.Dim("dim1_y", min=3, max=6)
        ep = torch.export.export(
            M(), (x, y), dynamic_shapes={"x": None, "y": {1: dim1_y}}
        )

        with self.assertRaisesRegex(RuntimeError, escape("shape[1] to be equal to 2")):
            ep.module()(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

        # y is specialized to 5
        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[1].shape[0] to be equal to 5, but got 2"),
        ):
            ep.module()(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

        # Since we didn't insert the constraint for x[1] >= 2, it should work for case where x[1] == 1
        gm_result_for_1_size = ep.module()(torch.zeros(4, 2, 3), torch.ones(5, 5, 5))
        eager_result_for_1_size = M().forward(torch.zeros(4, 2, 3), torch.ones(5, 5, 5))

        self.assertEqual(gm_result_for_1_size, eager_result_for_1_size)

    def test_view_to_view_copy(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                z = x.view(x.shape)
                return z.cos().sum()

        x = torch.zeros(4, 2, 3)

        ep = export(M(), (x,))
        self.assertEqual(count_call_function(ep.graph, torch.ops.aten.view.default), 1)

        ep = ep._transform_do_not_use(ReplaceViewOpsWithViewCopyOpsPass())
        self.assertEqual(count_call_function(ep.graph, torch.ops.aten.view.default), 0)

    def test_functionalization_with_view_copy(self) -> None:
        class Module(torch.nn.Module):
            def forward(self, x):
                y = x + 4
                y.add_(4)
                z = y.view(y.shape)
                return x.cos() + z.cos()

        x = torch.zeros(4, 2, 3)
        foo = Module()
        ep = export(foo, (x,))._transform_do_not_use(
            ReplaceViewOpsWithViewCopyOpsPass()
        )
        # After this pass, there shouldn't be any view nodes in the graph
        self.assertTrue(count_call_function(ep.graph, torch.ops.aten.view.default) == 0)
        self.assertTrue(
            count_call_function(ep.graph, torch.ops.aten.view_copy.default) > 0
        )

    def test_views_op_having_view_copy(self) -> None:
        schemas = torch._C._dispatch_get_registrations_for_dispatch_key("")
        aten_schemas = [s[6:] for s in schemas if s.startswith("aten::")]

        for aten_schema in aten_schemas:
            val = aten_schema.split(".")
            assert len(val) <= 2
            name = ""
            overload = ""
            if len(val) == 1:
                name = val[0]
                overload = "default"
            else:
                name, overload = val[0], val[1]

            op_overload = getattr(getattr(torch.ops.aten, name), overload)
            if torch.Tag.core in op_overload.tags and is_view_op(op_overload._schema):
                self.assertIsNotNone(get_view_copy_of_view_op(op_overload._schema))

    def test_custom_obj_tuple_out(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(self.attr, x)
                y = a[0] + a[1]
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
                return b

        m = MyModule()
        inputs = (torch.ones(2, 3),)
        ep = torch.export.export(m, inputs, strict=False)

        inp = torch.randn(2, 3)
        orig_res = m(inp)
        ep_res = ep.module()(inp)

        without_token_ep = _remove_effect_tokens(ep)
        without_token_ep.verifier().check(without_token_ep)
        without_token_res = without_token_ep.module()(inp)

        self.assertTrue(torch.allclose(orig_res, ep_res))
        self.assertTrue(torch.allclose(orig_res, without_token_res))

    def test_remove_effect_token_kwargs(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(
                    foo=self.attr, x=x
                )
                y = a[0] + a[1]
                b = torch.ops._TorchScriptTesting.takes_foo(foo=self.attr, x=y)
                return b

        m = MyModule()
        inputs = (torch.ones(2, 3),)
        ep = torch.export.export(m, inputs, strict=False)
        without_token_ep = _remove_effect_tokens(ep)
        self.assertExpectedInline(
            without_token_ep.graph_module.code.strip(),
            """\
def forward(self, token, obj_attr, x):
    with_effects = torch.ops.higher_order.with_effects(token, torch.ops._TorchScriptTesting.takes_foo_tuple_return.default, foo = obj_attr, x = x); token = x = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1]
    getitem_2 = with_effects[2]; with_effects = None
    add = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = getitem_2 = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, foo = obj_attr, x = add); getitem = obj_attr = add = None
    getitem_3 = with_effects_1[0]
    getitem_4 = with_effects_1[1]; with_effects_1 = None
    return (getitem_3, getitem_4)""",  # noqa: B950
        )

    def test_fakify_script_objects(self):
        for m in [
            ModelsWithScriptObjectAttr.Simple(),
            ModelsWithScriptObjectAttr.SimpleWithAttrInContainer(),
            ModelsWithScriptObjectAttr.NestedWithAttrInContainer(),
            ModelsWithScriptObjectAttr.MoreNestedWithAttrInContainer(),
        ]:
            constant_attrs = _gather_constant_attrs(m)
            fake_mode = FakeTensorMode(
                shape_env=ShapeEnv(tracked_fakes=[]),
                allow_non_fake_inputs=True,
            )
            with _fakify_script_objects(m, (), {}, fake_mode) as (
                patched_mod,
                _,
                _,
                fake_constant_attrs,
                fake_to_real,
            ):
                self.assertEqual(len(fake_constant_attrs), len(constant_attrs))
                for fake_obj, fqn in fake_constant_attrs.items():
                    self.assertEqual(constant_attrs[fake_to_real[fake_obj]], fqn)

    # TODO: _gather_constants doesn't recursively look into the pytree containers.
    @unittest.expectedFailure
    def test_fakify_script_objects_properly_handle_containers(self):
        m = ModelsWithScriptObjectAttr.SimpleWithAttrInContainer()
        constant_attrs = _gather_constant_attrs(m)
        fake_mode = FakeTensorMode(
            shape_env=ShapeEnv(tracked_fakes=[]),
            allow_non_fake_inputs=True,
        )
        with _fakify_script_objects(m, (), {}, fake_mode) as (
            patched_mod,
            _,
            _,
            fake_constant_attrs,
            fake_to_real,
        ):
            self.assertTrue("attr" in fake_constant_attrs.values())
            self.assertTrue("pytree_attr2" in fake_constant_attrs.values())

    def test_runtime_assert_inline_constraints_for_item(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                b = x.item()
                torch._check(b >= 2)
                torch._check(b <= 5)
                return b

        x = torch.tensor([2])
        mod = M()
        ep = export(mod, (x,))

        with self.assertRaisesRegex(
            RuntimeError, r"Runtime assertion failed for expression u[\d+] \<\= 5"
        ):
            ep.module()(torch.tensor([6]))

        new_inp = torch.tensor([5])
        self.assertEqual(mod(new_inp), ep.module()(new_inp))

    def test_runtime_assert_inline_constraints_for_nonzero(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                b = x.nonzero()
                torch._check(b.shape[0] >= 3)
                torch._check(b.shape[0] <= 5)
                return b

        x = torch.tensor([2, 1, 2, 3, 5, 0])

        mod = M()
        dim0_x = torch.export.Dim("dim0_x")
        ep = torch.export.export(mod, (x,), dynamic_shapes={"x": {0: dim0_x}})

        num_assert = count_call_function(
            ep.graph, torch.ops.aten._assert_scalar.default
        )
        self.assertEqual(num_assert, 2)
        num_constrain_range = count_call_function(
            ep.graph, torch.ops.aten.sym_constrain_range.default
        )
        self.assertEqual(num_constrain_range, 0)

        with self.assertRaisesRegex(
            RuntimeError,
r"Runtime assertion failed for expression u[\d+] \>\= 3", 676 ): 677 ep.module()(torch.tensor([1, 1, 0, 0, 0])) 678 679 with self.assertRaisesRegex( 680 RuntimeError, 681 r"Runtime assertion failed for expression u[\d+] \<\= 5", 682 ): 683 ep.module()(torch.ones(6)) 684 685 new_inp = torch.tensor([1, 1, 1, 1]) 686 self.assertEqual(mod(new_inp), ep.module()(new_inp)) 687 688 @unittest.skipIf(IS_WINDOWS, "Windows not supported") 689 @unittest.expectedFailure 690 # TODO(pianpwk): add back runtime asserts to subgraphs 691 def test_runtime_assert_inline_constraints_for_cond(self) -> None: 692 class M(torch.nn.Module): 693 def __init__(self) -> None: 694 super().__init__() 695 696 def forward(self, pred, x, y): 697 def true_fn(x, y): 698 b = x.item() 699 torch._check(b >= 2) 700 torch._check(b <= 5) 701 return x - b 702 703 def false_fn(x, y): 704 c = y.item() 705 torch._check(c >= 2) 706 torch._check(c <= 5) 707 return y - c 708 709 ret = cond(pred, true_fn, false_fn, [x, y]) 710 return ret 711 712 x = torch.tensor([2]) 713 y = torch.tensor([5]) 714 mod = M() 715 ep = export(mod, (torch.tensor(True), x, y)) 716 717 with self.assertRaisesRegex( 718 RuntimeError, "is outside of inline constraint \\[2, 5\\]." 719 ): 720 ep.module()(torch.tensor(False), torch.tensor([6]), torch.tensor([6])) 721 722 def test_math_ops(self): 723 class Module(torch.nn.Module): 724 def forward(self, x): 725 return ( 726 torch.tensor([math.ceil(x.item())]), 727 torch.tensor([math.floor(x.item())]), 728 ) 729 730 func = Module() 731 x = torch.randn(1, dtype=torch.float32) 732 ep = torch.export.export(func, args=(x,)) 733 _ExportPassBaseDeprecatedDoNotUse()(ep.graph_module) 734 735 def test_predispatch_set_grad(self): 736 def _check_node_users_in_the_same_graph(gm): 737 for node in gm.graph.nodes: 738 for user in node.users: 739 self.assertTrue(user.graph is gm.graph) 740 741 mod_orig, mod, args = self.SET_GRAD_ENABLED_TESTS["op"] 742 _check_node_users_in_the_same_graph(mod) 743 self.assertEqual(mod_orig(*args), mod(*args)) 744 self.assertExpectedInline( 745 mod.code.strip("\n"), 746 """\ 747def forward(self, x): 748 x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) 749 add = torch.ops.aten.add.Tensor(x, 1); x = None 750 sin = torch.ops.aten.sin.default(add); add = None 751 sum_1 = torch.ops.aten.sum.default(sin); sin = None 752 submod_4 = self.submod_2 753 add_1 = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_4, sum_1); submod_4 = sum_1 = None 754 getitem = add_1[0]; add_1 = None 755 sub = torch.ops.aten.sub.Tensor(getitem, 1) 756 return pytree.tree_unflatten((getitem, sub), self._out_spec) 757 """, 758 ) 759 760 mod_orig, mod, args = self.SET_GRAD_ENABLED_TESTS["op_under_no_grad"] 761 _check_node_users_in_the_same_graph(mod) 762 self.assertEqual(mod_orig(*args), mod(*args)) 763 self.assertExpectedInline( 764 mod.code.strip("\n"), 765 """\ 766def forward(self, x): 767 x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) 768 add = torch.ops.aten.add.Tensor(x, 1); x = None 769 sin = torch.ops.aten.sin.default(add); add = None 770 sum_1 = torch.ops.aten.sum.default(sin); sin = None 771 submod_4 = self.submod_2 772 add_1 = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_4, sum_1); submod_4 = sum_1 = None 773 getitem = add_1[0]; add_1 = None 774 sub = torch.ops.aten.sub.Tensor(getitem, 1) 775 return pytree.tree_unflatten((getitem, sub), self._out_spec) 776 """, 777 ) 778 779 mod_orig, mod, args = self.SET_GRAD_ENABLED_TESTS["ctx_manager"] 780 
        _check_node_users_in_the_same_graph(mod)
        self.assertEqual(mod_orig(*args), mod(*args))
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1); x = None
    sin = torch.ops.aten.sin.default(add); add = None
    sum_1 = torch.ops.aten.sum.default(sin); sin = None
    submod_3 = self.submod_1
    add_1 = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, sum_1); submod_3 = sum_1 = None
    getitem = add_1[0]; add_1 = None
    sub = torch.ops.aten.sub.Tensor(getitem, 1)
    return pytree.tree_unflatten((getitem, sub), self._out_spec)
    """,
        )

        mod_orig, mod, args = self.SET_GRAD_ENABLED_TESTS["ctx_manager_under_no_grad"]
        _check_node_users_in_the_same_graph(mod)
        self.assertEqual(mod_orig(*args), mod(*args))
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1); x = None
    submod_5 = self.submod_1
    sum_1 = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_5, add); submod_5 = add = None
    getitem = sum_1[0]; sum_1 = None
    add_1 = torch.ops.aten.add.Tensor(getitem, 1); getitem = None
    submod_6 = self.submod_3
    sub = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_6, add_1); submod_6 = None
    getitem_1 = sub[0]; sub = None
    return pytree.tree_unflatten((add_1, getitem_1), self._out_spec)
    """,
        )

        mod_orig, mod, args = self.SET_GRAD_ENABLED_TESTS["ctx_manager_multi_dep"]
        _check_node_users_in_the_same_graph(mod)
        self.assertEqual(mod_orig(*args), mod(*args))
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1); x = None
    sin = torch.ops.aten.sin.default(add)
    sum_1 = torch.ops.aten.sum.default(sin); sin = None
    cos = torch.ops.aten.cos.default(add); add = None
    sum_2 = torch.ops.aten.sum.default(cos); cos = None
    submod_3 = self.submod_1
    wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(False, submod_3, sum_1, sum_2); submod_3 = sum_1 = sum_2 = None
    add_1 = wrap_with_set_grad_enabled[0]
    add_2 = wrap_with_set_grad_enabled[1]; wrap_with_set_grad_enabled = None
    sub = torch.ops.aten.sub.Tensor(add_1, 1)
    sub_1 = torch.ops.aten.sub.Tensor(add_2, 1)
    return pytree.tree_unflatten((add_1, add_2, sub, sub_1), self._out_spec)
    """,  # noqa: B950
        )

        mod_orig, mod, args = self.SET_GRAD_ENABLED_TESTS[
            "ctx_manager_multi_dep_no_grad"
        ]
        _check_node_users_in_the_same_graph(mod)
        self.assertEqual(mod_orig(*args), mod(*args))
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1); x = None
    submod_5 = self.submod_1
    wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_5, add); submod_5 = add = None
    sum_1 = wrap_with_set_grad_enabled[0]
    sum_2 = wrap_with_set_grad_enabled[1]; wrap_with_set_grad_enabled = None
    add_1 = torch.ops.aten.add.Tensor(sum_1, 1); sum_1 = None
    add_2 = torch.ops.aten.add.Tensor(sum_2, 1); sum_2 = None
    submod_6 = self.submod_3
    wrap_with_set_grad_enabled_1 = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_6, add_1, add_2); submod_6 = None
    sub = wrap_with_set_grad_enabled_1[0]
    sub_1 = wrap_with_set_grad_enabled_1[1]; wrap_with_set_grad_enabled_1 = None
    return pytree.tree_unflatten((add_1, add_2, sub, sub_1), self._out_spec)
    """,  # noqa: B950
        )

    def test_sequential_split(self):
        for gm, args in self.SEQUENTIAL_SPLIT_INLINE_TESTS.values():
            set_grad_counts = nodes_count(gm.graph.nodes, _is_set_grad_enabled_node)
            new_gm = sequential_split(gm, _is_set_grad_enabled_node)
            new_set_grad_counts = nodes_count(
                new_gm.graph.nodes, _is_set_grad_enabled_sub_mod
            )
            self.assertEqual(set_grad_counts, new_set_grad_counts)
            self.assertEqual(gm(*args), new_gm(*args))

    def test_sequential_split_graph(self):
        gm, args = self.SEQUENTIAL_SPLIT_INLINE_TESTS["multi_dep_step2"]

        new_gm = sequential_split(gm, _is_set_grad_enabled_node)
        self.assertEqual(gm(*args), new_gm(*args))
        self.assertExpectedInline(
            new_gm.code.strip("\n"),
            """\
def forward(self, x1, x2):
    x1, x2, = fx_pytree.tree_flatten_spec(([x1, x2], {}), self._in_spec)
    submod_1 = self.submod_1(x1, x2); x1 = x2 = None
    getitem = submod_1[0]
    getitem_1 = submod_1[1]; submod_1 = None
    submod_2 = self.submod_2(getitem, getitem_1); getitem = getitem_1 = None
    getitem_2 = submod_2[0]
    getitem_3 = submod_2[1]; submod_2 = None
    submod_3 = self.submod_3(getitem_2, getitem_3); getitem_2 = getitem_3 = None
    getitem_4 = submod_3[0]
    getitem_5 = submod_3[1]; submod_3 = None
    submod_4 = self.submod_4(getitem_4, getitem_5)
    getitem_6 = submod_4[0]
    getitem_7 = submod_4[1]; submod_4 = None
    return pytree.tree_unflatten((getitem_4, getitem_5, getitem_6, getitem_7), self._out_spec)
    """,
        )
        self.assertExpectedInline(
            new_gm.submod_1.code.strip("\n"),
            """\
def forward(self, x1, x2):
    _set_grad_enabled = torch._C._set_grad_enabled(True); _set_grad_enabled = None
    add = torch.ops.aten.add.Tensor(x1, 1); x1 = None
    add_1 = torch.ops.aten.add.Tensor(x2, 1); x2 = None
    return (add, add_1)
    """,
        )
        self.assertExpectedInline(
            new_gm.submod_2.code.strip("\n"),
            """\
def forward(self, add, add_1):
    _set_grad_enabled_1 = torch._C._set_grad_enabled(False); _set_grad_enabled_1 = None
    sin = torch.ops.aten.sin.default(add); add = None
    cos = torch.ops.aten.cos.default(add_1); add_1 = None
    return (sin, cos)
    """,
        )
        self.assertExpectedInline(
            new_gm.submod_3.code.strip("\n"),
            """\
def forward(self, sin, cos):
    _set_grad_enabled_2 = torch._C._set_grad_enabled(True); _set_grad_enabled_2 = None
    add_2 = torch.ops.aten.add.Tensor(sin, 1); sin = None
    add_3 = torch.ops.aten.add.Tensor(cos, 1); cos = None
    return (add_2, add_3)
    """,
        )

    def test_predispatch_autocast(self):
        def _check_node_users_in_the_same_graph(gm):
            for node in gm.graph.nodes:
                for user in node.users:
                    self.assertTrue(user.graph is gm.graph)

        mod_orig, mod, args = self.WITH_AUTOCAST_TESTS["ctx_manager"]
        _check_node_users_in_the_same_graph(mod)
        self.assertEqual(mod_orig(*args), mod(*args))
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1); x = None
    submod_4 = self.submod_1
    sum_1 = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_4, add); submod_4 = add = None
    getitem = sum_1[0]; sum_1 = None
    submod_5 = self.submod_2
    add_1 = torch.ops.higher_order.wrap_with_autocast('cpu', None, False, None, submod_5, getitem); submod_5 = getitem = None
    getitem_1 = add_1[0]; add_1 = None
    submod_6 = self.submod_3
    sub = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_6, getitem_1); submod_6 = None
    getitem_2 = sub[0]; sub = None
    return pytree.tree_unflatten((getitem_1, getitem_2), self._out_spec)
    """,
        )

        self.assertExpectedInline(
            mod.submod_1.code.strip("\n"),
            """\
def forward(self, add):
    sin = torch.ops.aten.sin.default(add); add = None
    sum_1 = torch.ops.aten.sum.default(sin); sin = None
    return (sum_1,)
    """,
        )

        self.assertExpectedInline(
            mod.submod_2.code.strip("\n"),
            """\
def forward(self, sum_1):
    add_1 = torch.ops.aten.add.Tensor(sum_1, 1); sum_1 = None
    return (add_1,)
    """,
        )

        self.assertExpectedInline(
            mod.submod_3.code.strip("\n"),
            """\
def forward(self, add_1):
    sub = torch.ops.aten.sub.Tensor(add_1, 1); add_1 = None
    return (sub,)
    """,
        )

        mod_orig, mod, args = self.WITH_AUTOCAST_TESTS["ctx_manager_multi_dep"]
        _check_node_users_in_the_same_graph(mod)
        self.assertEqual(mod_orig(*args), mod(*args))
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1); x = None
    submod_4 = self.submod_1
    wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_4, add); submod_4 = add = None
    sum_1 = wrap_with_autocast[0]
    sum_2 = wrap_with_autocast[1]; wrap_with_autocast = None
    submod_5 = self.submod_2
    wrap_with_autocast_1 = torch.ops.higher_order.wrap_with_autocast('cpu', None, False, None, submod_5, sum_1, sum_2); submod_5 = sum_1 = sum_2 = None
    add_1 = wrap_with_autocast_1[0]
    add_2 = wrap_with_autocast_1[1]; wrap_with_autocast_1 = None
    submod_6 = self.submod_3
    wrap_with_autocast_2 = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_6, add_1, add_2); submod_6 = None
    sub = wrap_with_autocast_2[0]
    sub_1 = wrap_with_autocast_2[1]; wrap_with_autocast_2 = None
    return pytree.tree_unflatten((add_1, add_2, sub, sub_1), self._out_spec)
    """,  # noqa: B950
        )

        self.assertExpectedInline(
            mod.submod_1.code.strip("\n"),
            """\
def forward(self, add):
    sin = torch.ops.aten.sin.default(add)
    sum_1 = torch.ops.aten.sum.default(sin); sin = None
    cos = torch.ops.aten.cos.default(add); add = None
    sum_2 = torch.ops.aten.sum.default(cos); cos = None
    return (sum_1, sum_2)
    """,
        )

        self.assertExpectedInline(
            mod.submod_2.code.strip("\n"),
            """\
def forward(self, sum_1, sum_2):
    add_1 = torch.ops.aten.add.Tensor(sum_1, 1); sum_1 = None
    add_2 = torch.ops.aten.add.Tensor(sum_2, 1); sum_2 = None
    return (add_1, add_2)
    """,
        )

        self.assertExpectedInline(
            mod.submod_3.code.strip("\n"),
            """\
def forward(self, add_1, add_2):
    sub = torch.ops.aten.sub.Tensor(add_1, 1); add_1 = None
    sub_1 = torch.ops.aten.sub.Tensor(add_2, 1); add_2 = None
    return (sub, sub_1)
    """,
        )

        mod_orig, mod, args = self.WITH_AUTOCAST_TESTS["ctx_manager_split"]
        _check_node_users_in_the_same_graph(mod)
        self.assertEqual(mod_orig(*args), mod(*args))
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1); x = None
    submod_4 = self.submod_1
    sum_1 = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_4, add); submod_4 = add = None
    getitem = sum_1[0]; sum_1 = None
    add_1 = torch.ops.aten.add.Tensor(getitem, 1); getitem = None
    submod_5 = self.submod_3
    sub = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_5, add_1); submod_5 = None
    getitem_1 = sub[0]; sub = None
    return pytree.tree_unflatten((add_1, getitem_1), self._out_spec)
    """,
        )

        self.assertExpectedInline(
            mod.submod_1.code.strip("\n"),
            """\
def forward(self, add):
    sin = torch.ops.aten.sin.default(add); add = None
    sum_1 = torch.ops.aten.sum.default(sin); sin = None
    return (sum_1,)
    """,
        )

        self.assertExpectedInline(
            mod.submod_3.code.strip("\n"),
            """\
def forward(self, add_1):
    sub = torch.ops.aten.sub.Tensor(add_1, 1); add_1 = None
    return (sub,)
    """,
        )

    def test_inline_(self):
        for gm, args in self.SEQUENTIAL_SPLIT_INLINE_TESTS.values():
            before_str = gm.print_readable(print_output=False)
            new_gm = sequential_split(gm, _is_set_grad_enabled_node)
            nodes_map(
                new_gm.graph.nodes,
                lambda node: node_inline_(node) if node.op == "call_module" else node,
            )
            after_inline_str = new_gm.print_readable(print_output=False)
            self.assertEqual(before_str, after_inline_str)
            self.assertEqual(gm(*args), new_gm(*args))

    def test_remove_auto_functionalized_pass(self) -> None:
        with _scoped_library("DO_NOT_USE_TEST_ONLY", "DEF") as lib:
            lib.define("custom_mutator(Tensor x, Tensor(a!) y) -> Tensor")

            @impl(lib, "custom_mutator", "Meta")
            def custom_mutator_meta(
                x: torch.Tensor,
                y: torch.Tensor,
            ) -> torch.Tensor:
                return torch.empty_like(x)

            @impl(lib, "custom_mutator", "CompositeExplicitAutograd")
            def custom_mutator(
                x: torch.Tensor,
                y: torch.Tensor,
            ) -> torch.Tensor:
                return x + y.add_(1)

            class M(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.state = torch.nn.Buffer(torch.zeros(1))

                def forward(self, x):
                    return torch.ops.DO_NOT_USE_TEST_ONLY.custom_mutator(x, self.state)

            mod = M()
            x = torch.randn([3, 3])
            ep = export(mod, (x,))
            inplace_ep = unsafe_remove_auto_functionalized_pass(ep)
            nodes = inplace_ep.graph.nodes
            for node in nodes:
                if node.op == "call_function":
                    self.assertFalse(node.target is auto_functionalized)
                    self.assertFalse(node.target is operator.getitem)

            for spec in inplace_ep.graph_signature.output_specs:
                self.assertFalse("getitem" in spec.arg.name)

    def test_remove_auto_functionalized_pass_tuple(self) -> None:
        with _scoped_library("DO_NOT_USE_TEST_ONLY", "DEF") as lib:
            lib.define(
                "custom_mutator_tuple(Tensor x, Tensor(a!) y) -> (Tensor, Tensor)"
y) -> (Tensor, Tensor)" 1139 ) 1140 1141 @impl(lib, "custom_mutator_tuple", "Meta") 1142 def custom_mutator_tuple_meta( 1143 x: torch.Tensor, 1144 y: torch.Tensor, 1145 ): 1146 return (torch.empty_like(x), torch.empty_like(x)) 1147 1148 @impl(lib, "custom_mutator_tuple", "CompositeExplicitAutograd") 1149 def custom_mutator_tuple( 1150 x: torch.Tensor, 1151 y: torch.Tensor, 1152 ): 1153 return (x, x + y.add_(1)) 1154 1155 class M(torch.nn.Module): 1156 def __init__(self) -> None: 1157 super().__init__() 1158 self.state = torch.nn.Buffer(torch.zeros(1)) 1159 1160 def forward(self, x): 1161 return torch.ops.DO_NOT_USE_TEST_ONLY.custom_mutator_tuple( 1162 x, self.state 1163 ) 1164 1165 mod = M() 1166 x = torch.randn([3, 3]) 1167 ep = export(mod, (x,)) 1168 inplace_ep = unsafe_remove_auto_functionalized_pass(ep) 1169 graph_text = str(inplace_ep.graph) 1170 self.assertExpectedInline( 1171 graph_text, 1172 """\ 1173graph(): 1174 %b_state : [num_users=2] = placeholder[target=b_state] 1175 %x : [num_users=1] = placeholder[target=x] 1176 %custom_mutator_tuple_default : [num_users=2] = call_function[target=torch.ops.DO_NOT_USE_TEST_ONLY.custom_mutator_tuple.\ 1177default](args = (%x, %b_state), kwargs = {}) 1178 %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%custom_mutator_tuple_default, 0), kwargs = {}) 1179 %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%custom_mutator_tuple_default, 1), kwargs = {}) 1180 return (b_state, getitem_3, getitem_4)""", 1181 ) 1182 1183 @unittest.skipIf(not TEST_CUDA, "requires cuda") 1184 def test_move_to_device_pass(self): 1185 class Model(torch.nn.Module): 1186 def __init__(self, size=4, h_dim=10): 1187 super().__init__() 1188 self.rnn = torch.nn.GRU(size, h_dim, batch_first=True) 1189 1190 def forward(self, x): 1191 _, states = self.rnn(x) 1192 return states 1193 1194 # move the exported program from cpu to cuda:0 1195 mod = Model() 1196 example_inputs = (torch.rand(1, 10, 4),) 1197 ep = export(mod, example_inputs) 1198 location = torch.device("cuda:0") 1199 ep = move_to_device_pass(ep, location=location) 1200 gm = ep.module() 1201 test_inputs = (torch.rand(1, 10, 4).to("cuda:0"),) 1202 outputs = gm(*test_inputs) 1203 self.assertEqual(outputs.device, torch.device("cuda:0")) 1204 # move it back to cpu 1205 location = "cpu" 1206 ep = move_to_device_pass(ep, location=location) 1207 gm = ep.module() 1208 test_inputs = (torch.rand(1, 10, 4).to("cpu"),) 1209 outputs = gm(*test_inputs) 1210 self.assertEqual(outputs.device, torch.device("cpu")) 1211 # move it to cuda:0 again 1212 location = {"cpu": "cuda:0"} 1213 ep = move_to_device_pass(ep, location=location) 1214 gm = ep.module() 1215 test_inputs = (torch.rand(1, 10, 4).to("cuda:0"),) 1216 outputs = gm(*test_inputs) 1217 self.assertEqual(outputs.device, torch.device("cuda:0")) 1218 1219 1220if __name__ == "__main__": 1221 run_tests() 1222