# Owner(s): ["oncall: export"]
# flake8: noqa
import copy
import dataclasses
import io
import logging
import operator
import re
import unittest
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from re import escape
from typing import Dict, List

import torch
import torch._dynamo as torchdynamo
import torch.nn.functional as F
from functorch.experimental.control_flow import cond, map
from torch import Tensor
from torch._decomp import get_decompositions
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import normalize_gm
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
from torch._export.utils import (
    get_buffer,
    get_param,
    is_buffer,
    is_param,
    register_dataclass_as_pytree_node,
)
from torch._higher_order_ops.hints_wrap import hints_wrapper
from torch._inductor.compile_fx import split_const_gm
from torch._subclasses import FakeTensorMode
from torch.export import Dim, export, unflatten
from torch.export._trace import (
    _export,
    _export_to_torch_ir,
    DEFAULT_EXPORT_DYNAMO_CONFIG,
)
from torch.export.graph_signature import (
    ExportGraphSignature,
    InputKind,
    OutputKind,
    OutputSpec,
    TensorArgument,
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    SM90OrLater,
)
from torch.testing._internal.common_device_type import onlyCPU, onlyCUDA
from torch.testing._internal.common_utils import (
    find_library_location,
    IS_FBCODE,
    IS_MACOS,
    IS_SANDCASTLE,
    IS_WINDOWS,
    run_tests,
    TEST_TRANSFORMERS,
    TestCase as TorchTestCase,
)
from torch.utils._pytree import (
    LeafSpec,
    tree_flatten,
    tree_map,
    tree_unflatten,
    TreeSpec,
    treespec_dumps,
    treespec_loads,
)


try:
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    HAS_TORCHREC = True
except ImportError:
    HAS_TORCHREC = False

try:
    from . import testing
except ImportError:
    import testing
# The following import pattern matters: `test_export.export` is patched
# in other files (like test_export_nonstrict.py), and using
# `torch.export.export` directly would invalidate the patch.
from torch.export import export

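# Custom operators used by the tests below: testlib::foo mutates both of its
# tensor inputs, while foo_mutated and foo_functional are
# CompositeImplicitAutograd wrappers that call it.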
torch.library.define("testlib::returns_tensor_symint", "(Tensor x) -> (Tensor, SymInt)")
torch.library.define(
    "testlib::foo",
    "(Tensor(a!) x, Tensor(b!) z) -> (Tensor, Tensor, Tensor)",
    tags=torch.Tag.pt2_compliant_tag,
)
torch.library.define(
    "testlib::foo_mutated",
    "(Tensor(a!) x) -> (Tensor, Tensor)",
    tags=torch.Tag.pt2_compliant_tag,
)
torch.library.define(
    "testlib::foo_functional",
    "(Tensor x) -> (Tensor)",
    tags=torch.Tag.pt2_compliant_tag,
)
torch.library.define(
    "testlib::foo_unbacked",
    "(Scalar x) -> (Tensor)",
    tags=torch.Tag.pt2_compliant_tag,
)


@torch.library.impl("testlib::returns_tensor_symint", "cpu")
@torch.library.impl_abstract("testlib::returns_tensor_symint")
def returns_tensor_symint_impl(x):
    return x, x.shape[0]


@torch.library.impl("testlib::foo", "cpu")
@torch._dynamo.disable
def foo_impl(x, z):
    x.add_(5)
    z.add_(5)
    return x, z, x + z


@torch.library.impl_abstract("testlib::foo")
def foo_abstract(x, z):
    return x, z, x + z


@torch.library.impl("testlib::foo_mutated", "CompositeImplicitAutograd")
def foo_mutated(x):
    a, b, c = torch.ops.testlib.foo(x, x.cos())
    return a, a.cos()


@torch.library.impl("testlib::foo_functional", "CompositeImplicitAutograd")
def foo_functional(x):
    a, b, c = torch.ops.testlib.foo(x.cos(), x.cos())
    return a.cos()


@torch.library.impl("testlib::foo_unbacked", "CompositeImplicitAutograd")
def foo_unbacked(x):
    if x > 2:
        return torch.ones(4, 4)
    if x < 6:
        return torch.ones(4, 4)
    return torch.ones(4, 4)


@dataclass
class Inp:
    x: Tensor
    y: List[Tensor]
    z: Dict[str, Tensor]


NON_STRICT_SUFFIX = "_non_strict"
RETRACEABILITY_SUFFIX = "_retraceability"
SERDES_SUFFIX = "_serdes"
PREDISPATCH_SUFFIX = "_pre_dispatch"
TRAINING_IR_DECOMP_STRICT_SUFFIX = "_training_ir_to_decomp"
TRAINING_IR_DECOMP_NON_STRICT_SUFFIX = "_training_ir_to_decomp_non_strict"


def is_non_strict_test(test_name):
    return test_name.endswith(NON_STRICT_SUFFIX)


def is_retracebility_test(test_name):
    return test_name.endswith(RETRACEABILITY_SUFFIX)


def is_serdes_test(test_name):
    return test_name.endswith(SERDES_SUFFIX)


def is_training_ir_test(test_name):
    return test_name.endswith(TRAINING_IR_DECOMP_STRICT_SUFFIX) or test_name.endswith(
        TRAINING_IR_DECOMP_NON_STRICT_SUFFIX
    )

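# Helper: returns the schema inferred for the first higher-order-operator node
# found in the exported program's graph.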
dynamic_shapes={"x": {0: dim_x}}, 233 ) 234 235 def test_export_slice_maxsize(self): 236 class Slice(torch.nn.Module): 237 def forward(self, *args): 238 return torch.ops.aten.slice.Tensor(*args) 239 240 inp = (torch.rand((10, 3, 224, 224)), 0, 0, 9223372036854775807) 241 dynamic_shapes = (({0: Dim("dim")}, None, None, None),) 242 torch.export.export( 243 Slice(), 244 inp, 245 dynamic_shapes=dynamic_shapes, 246 ) 247 248 def test_export_constraints_error(self): 249 class ConflictingConstraints(torch.nn.Module): 250 def forward(self, x): 251 b = x.item() 252 torch._check_is_size(b) 253 torch._check(b >= 4) 254 torch._check(b <= 5) 255 torch._check(b <= 5) 256 torch._check(True) 257 return torch.full((b, 1), 1) 258 259 inp = (torch.tensor([3]),) 260 ep = export(ConflictingConstraints(), inp) 261 262 with self.assertRaisesRegex( 263 RuntimeError, r"Runtime assertion failed for expression u[\d+] \>\= 4" 264 ): 265 ep.module()(torch.tensor([3])) 266 267 def test_export_assume_static_by_default(self): 268 class Module(torch.nn.Module): 269 def forward(self, x: torch.Tensor): 270 if x.shape[0] == 4: 271 return x + 1 272 else: 273 return x 274 275 branch_on_shape = Module() 276 inp = (torch.rand(4, 5),) 277 278 # Being able to export means shape is preserved as static 279 export(branch_on_shape, inp) 280 281 282@unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case") 283@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support") 284class TestExport(TestCase): 285 def _test_export_same_as_eager(self, f, args, kwargs=None): 286 kwargs = kwargs or {} 287 exported_program = export(f, args, kwargs) 288 self.assertEqual(exported_program.module()(*args, **kwargs), f(*args, **kwargs)) 289 # this is not supported by .module() 290 # reversed_kwargs = {key: kwargs[key] for key in reversed(kwargs)} 291 # self.assertEqual( 292 # exported_program.module()(*args, **reversed_kwargs), f(*args, **reversed_kwargs) 293 # ) 294 295 def _check_dynamic_shapes_specs_and_shapes( 296 self, model, inputs, specs, passing_shapes, failing_shapes, test_serdes=False 297 ): 298 from torch._export.serde.dynamic_shapes import ( 299 _dump_dynamic_shapes, 300 _load_dynamic_shapes, 301 ) 302 from torch.utils._pytree import tree_map 303 304 def _construct_inputs(shapes): 305 def _is_tensor_leaf(x): 306 return isinstance(x, tuple) and all(isinstance(y, int) for y in x) 307 308 return tree_map( 309 lambda x: torch.randn(*x) if _is_tensor_leaf(x) else x, 310 shapes, 311 is_leaf=_is_tensor_leaf, 312 ) 313 314 # exports with a list of equivalent dynamic shapes specs, 315 # then tests for pass/fail on list of shapes 316 for _specs in specs: 317 ep = export(model, inputs, dynamic_shapes=_specs) 318 eps = [ep] 319 if test_serdes: 320 # test dynamic shapes serialization 321 # test that behavior remains the same when exporting with ser/des specs: 322 # serialize + deserialize original specs, and export. 
        from torch._export.serde.dynamic_shapes import (
            _dump_dynamic_shapes,
            _load_dynamic_shapes,
        )
        from torch.utils._pytree import tree_map

        def _construct_inputs(shapes):
            def _is_tensor_leaf(x):
                return isinstance(x, tuple) and all(isinstance(y, int) for y in x)

            return tree_map(
                lambda x: torch.randn(*x) if _is_tensor_leaf(x) else x,
                shapes,
                is_leaf=_is_tensor_leaf,
            )

        # exports with a list of equivalent dynamic shapes specs,
        # then tests for pass/fail on list of shapes
        for _specs in specs:
            ep = export(model, inputs, dynamic_shapes=_specs)
            eps = [ep]
            if test_serdes:
                # test dynamic shapes serialization
                # test that behavior remains the same when exporting with ser/des specs:
                # serialize + deserialize original specs, and export.
                ep_serdes = export(
                    model,
                    inputs,
                    dynamic_shapes=_load_dynamic_shapes(
                        _dump_dynamic_shapes(_specs, inputs)
                    ),
                )
                eps.append(ep_serdes)

            for ep in eps:
                for shapes in passing_shapes:
                    test_inputs = _construct_inputs(shapes)
                    ep.module()(*test_inputs)
                for shapes in failing_shapes:
                    test_inputs = _construct_inputs(shapes)
                    with self.assertRaises(RuntimeError):
                        ep.module()(*test_inputs)

    def test_basic(self):
        class Module(torch.nn.Module):
            def forward(self, x, y):
                return x[0] + y

        f = Module()
        inp = ([torch.ones(1, 3)], torch.ones(1, 3))
        self._test_export_same_as_eager(f, inp)

    def test_no_tensor_computation(self):
        class Module(torch.nn.Module):
            def forward(self, x, y):
                return y

        f = Module()
        inp = ([torch.ones(1, 3)], 1)
        ep = export(f, inp)
        self.assertEqual(ep.module()(*inp), f(*inp))
        self.assertExpectedInline(
            str(ep.graph).strip(),
            """\
graph():
    %x_0 : [num_users=0] = placeholder[target=x_0]
    %y : [num_users=0] = placeholder[target=y]
    return (1,)""",
        )

    def test_no_tensor_computation_2(self):
        class Module(torch.nn.Module):
            def forward(self, x, y):
                return x

        f = Module()
        inp = (torch.randn(3), 1)
        ep = export(f, inp)
        self.assertEqual(ep.module()(*inp), f(*inp))
        self.assertExpectedInline(
            str(ep.graph).strip(),
            """\
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=0] = placeholder[target=y]
    return (x,)""",
        )

    def test_no_tensor_computation_3(self):
        class Module(torch.nn.Module):
            def forward(self, x, y):
                return 5

        f = Module()
        inp = (2, 1)
        ep = export(f, inp)
        self.assertEqual(ep.module()(*inp), f(*inp))
        self.assertExpectedInline(
            str(ep.graph).strip(),
            """\
graph():
    %x : [num_users=0] = placeholder[target=x]
    %y : [num_users=0] = placeholder[target=y]
    return (5,)""",
        )

    def test_no_tensor_computation_4(self):
        class Module(torch.nn.Module):
            def forward(self, x, y):
                return x

        f = Module()
        inp = ([torch.randn(3)], 1)
        ep = export(f, inp)
        self.assertEqual(ep.module()(*inp), f(*inp))
        self.assertExpectedInline(
            str(ep.graph).strip(),
            """\
graph():
    %x_0 : [num_users=1] = placeholder[target=x_0]
    %y : [num_users=0] = placeholder[target=y]
    return (x_0,)""",
        )

    def test_not_registered_parameter(self):
        class Basic(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.params = {"foo": torch.nn.Parameter(torch.ones(3, 3))}

            def forward(self, x):
                return x + self.params["foo"]

        f = Basic()
        args = (torch.randn(1, 3),)
        # strict-mode will error out because foo is registered as a parameter
        # in dynamo (a behavior that's different from eager). We decided to
        # follow eager behavior.
        ep = export(f, args, strict=False)
        gm = ep.module()
        self.assertEqual(len(ep.graph_signature.lifted_tensor_constants), 1)
        self.assertEqual(len(ep.graph_signature.parameters), 0)
        # check foo is not a parameter in the final graph
        self.assertEqual(len(list(gm.named_parameters())), 0)
        self.assertEqual(gm(*args), f(*args))
        self.assertExpectedInline(
            str(gm.graph).strip(),
            """\
graph():
    %lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
    %x : [num_users=1] = placeholder[target=x]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %lifted_tensor_0), kwargs = {})
    return (add,)""",
        )

    def test_external_call_non_strict_real_tensor(self):
        class ExternalMethod:
            def add(self, x):
                return x + x

        class Basic(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.external_add = ExternalMethod().add

            def forward(self, x):
                return self.external_add(x)

        f = Basic()
        args = (torch.randn(1, 3),)
        ep = export(f, args, strict=False)
        self.assertEqual(ep.module()(*args), f(*args))

    def test_colon_parameter(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.register_parameter("foo:bar", torch.nn.Parameter(torch.ones(3, 3)))

            def forward(self, x):
                return x + getattr(self, "foo:bar")

        ep = export(M(), (torch.randn(3, 3),))
        x = torch.randn(3, 3)
        self.assertEqual(ep.module()(x), M()(x))

    def test_conv_dynamic(self):
        # Simple module for demonstration
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=3, out_channels=32, kernel_size=3, padding=1
                )
                self.relu = torch.nn.ReLU()
                self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                a = self.conv(x)
                a.add_(y)
                return self.maxpool(self.relu(a))

        example_args = (torch.randn(2, 3, 256, 256), torch.ones(2, 32, 256, 256))
        dynamic_shapes = {"x": {0: Dim("batch")}, "y": {0: Dim("batch")}}
        m = M()
        exported_program: torch.export.ExportedProgram = export(
            m, args=example_args, dynamic_shapes=dynamic_shapes
        )

        args = (torch.randn(17, 3, 256, 256), torch.ones(17, 32, 256, 256))
        self.assertEqual(exported_program.module()(*args), m(*args))
        args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
        self.assertEqual(exported_program.module()(*args), m(*args))

        from torch._export import capture_pre_autograd_graph

        gm: torch.fx.GraphModule = capture_pre_autograd_graph(
            m, args=example_args, dynamic_shapes=dynamic_shapes
        )

        args = (torch.randn(17, 3, 256, 256), torch.ones(17, 32, 256, 256))
        self.assertEqual(gm(*args), m(*args))
        args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
        self.assertEqual(gm(*args), m(*args))

    def test_masked_select_dynamic(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                mask = x.ge(0.5)
                return torch.masked_select(x, mask)

        example_args = (torch.randn(3, 4, 5),)
        dim0_x_max, dim1_x_max = 100, 7
        dynamic_shapes = {
            "x": {
                0: Dim("dim0_x", max=dim0_x_max),
                1: Dim("dim1_x_max", max=dim1_x_max),
            }
        }
        m = M()
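        # masked_select has a data-dependent output size; the exported program
        # should bound its symbolic size by the product of the input's maximum
        # dimension sizes (dim0_x_max * dim1_x_max * 5), which is checked below.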
        exported_program: torch.export.ExportedProgram = export(
            m, args=example_args, dynamic_shapes=dynamic_shapes
        )

        # Test that the expected upper bound is among the range constraints.
        expected_upper_bound = dim0_x_max * dim1_x_max * 5
        vr_upper_bounds = [
            vr.upper for vr in exported_program.range_constraints.values()
        ]
        self.assertTrue(expected_upper_bound in set(vr_upper_bounds))
        # Test that none of the upper bounds are larger.
        for vr_upper in vr_upper_bounds:
            self.assertTrue(vr_upper <= expected_upper_bound)

    def test_setgrad_lifted_tensor(self):
        class M(torch.nn.Module):
            def forward(self, x, y):
                with torch.enable_grad():
                    c = torch.tensor(4)
                    z = c + x + y

                return z * z

        m = M()
        x = torch.randn(4)
        y = torch.randn(4)
        # Need to surround export with no_grad to bypass AutogradStateOpsFailSafeguard.
        with torch.no_grad():
            ep = export(m, (x, y))
        self.assertEqual(ep.module()(x, y), m(x, y))

    def test_basic_non_strict_real_tensor(self):
        class Basic(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.randn(1, 3))

            def forward(self, x, y):
                return x[0] + y - self.param

        f = Basic()
        args = ([torch.randn(1, 3)], torch.randn(1, 3))
        ep = export(f, args, strict=False)
        self.assertEqual(ep.module()(*args), f(*args))

    def test_basic_non_strict_fake_tensor(self):
        class Basic(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.randn(3, 2))

            def forward(self, x, y):
                return x[0] + y - self.param

        fake_mode = FakeTensorMode(shape_env=ShapeEnv(tracked_fakes=[]))
        f = Basic()
        with fake_mode:
            args = ([torch.empty(3, 2)], torch.empty(3, 2))
        ep = export(f, args, strict=False)
        inputs = ([torch.randn(3, 2)], torch.randn(3, 2))
        self.assertEqual(ep.module()(*inputs), f(*inputs))

    def test_non_strict_dynamic_shapes(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.u = torch.nn.Buffer(torch.ones(1))
                self.v = torch.nn.Buffer(torch.ones(1))

            def forward(self, x, ys, zs, c):
                y = ys[0] + ys[1] + zs["a"] + zs["b"]
                self.v.add_(3)
                w = self.u - self.v
                if x.shape[0] < 3 and c.shape[0] != 4:
                    return x + w, x + y
                else:
                    return x - w, x - y

        foo = Foo()

        inp = (
            torch.ones(5),
            [torch.zeros(5), torch.ones(5)],
            {"a": torch.zeros(5), "b": torch.ones(5)},
            torch.ones(4),
        )
        dim = torch.export.Dim("dim", min=3)
        dynamic_shapes = (
            {0: dim},
            [{0: dim}, {0: dim}],
            {"a": {0: dim}, "b": {0: dim}},
            None,
        )

        ep_ns = torch.export.export(
            foo, inp, dynamic_shapes=dynamic_shapes, strict=False
        )

        bad_runtime_inp1 = (
            torch.ones(6),
            [torch.zeros(5), torch.ones(5)],
            {"a": torch.zeros(5), "b": torch.ones(5)},
            torch.ones(4),
        )
        with self.assertRaisesRegex(
            RuntimeError,
            escape(
                "Expected input at *args[1][0].shape[0] to be equal to 6, but got 5"
            ),
        ):
            ep_ns.module()(*bad_runtime_inp1)

        bad_runtime_inp2 = (
            torch.ones(5),
            [torch.zeros(5), torch.ones(5)],
            {"a": torch.zeros(5), "b": torch.ones(5)},
            torch.ones(6),
        )
        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[3].shape[0] to be equal to 4, but got 6"),
        ):
            ep_ns.module()(*bad_runtime_inp2)

        good_runtime_inp = (
            torch.ones(7),
            [torch.zeros(7), torch.ones(7)],
            {"a": torch.zeros(7), "b": torch.ones(7)},
            torch.ones(4),
        )
        ep_ns.module()(*good_runtime_inp)

        bad_example_inp = (
            torch.ones(2),
            [torch.zeros(2), torch.ones(2)],
            {"a": torch.zeros(2), "b": torch.ones(2)},
            torch.ones(4),
        )
        with self.assertRaisesRegex(
            torch.fx.experimental.symbolic_shapes.ConstraintViolationError,
            "2 not in range.*3,",
        ):
            ep_ns = torch.export.export(
                foo, bad_example_inp, dynamic_shapes=dynamic_shapes, strict=False
            )

    def test_non_strict_dynamic_shapes_suggested_fixes(self):
        class Foo(torch.nn.Module):
            def forward(self, x, c):
                if x.shape[0] <= 6:
                    return x + 1, c + 2
                else:
                    return x - 1, c - 2

        foo = Foo()

        bad_example_inp = (
            torch.ones(5),
            torch.ones(4),
        )
        dim = torch.export.Dim("dim", min=3)
        dynamic_shapes = (
            {0: dim},
            None,
        )

        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            "Constraints violated \\(dim\\)!(.*\n)*.*"
            "Not all values of dim.*satisfy the generated guard(.*\n)*.*"
            "Suggested fixes:(.*\n)*.*"
            "dim = Dim\\('dim', min=3, max=6\\)",
        ):
            torch.export.export(
                foo, bad_example_inp, dynamic_shapes=dynamic_shapes, strict=False
            )

    def test_unbacked_to_cond(self):
        class M(torch.nn.Module):
            def forward(self, a):
                az = a.nonzero()

                def true_fn(x):
                    return (x + 1).sum()

                def false_fn(x):
                    return (x + 3).sum()

                r = torch.cond(az.size(0) > 3, true_fn, false_fn, (az,))
                return r * 2

        M()(torch.randn(7))
        torch.export.export(M(), (torch.randn(7),))

    def test_unbacked_to_cond_passthrough(self):
        class M(torch.nn.Module):
            def forward(self, a):
                az = a.nonzero()

                def true_fn(x):
                    return x + 1

                def false_fn(x):
                    return x + 3

                r = torch.cond(az.size(0) > 3, true_fn, false_fn, (az,))
                return r * 2

        M()(torch.randn(7))
        torch.export.export(M(), (torch.randn(7),))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_cond_contains_unbacked_no_escape(self):
        class M(torch.nn.Module):
            def forward(self, a, b1, b2, c):
                def true_fn(x):
                    return x * b1.item()

                def false_fn(x):
                    return x * b2.item()

                r = torch.cond(a, true_fn, false_fn, (c,))
                return r * 2

        args = (
            torch.tensor(True),
            torch.tensor([4]),
            torch.tensor([4]),
            torch.randn(10, requires_grad=True),
        )
        torch.export.export(M(), args)

    def test_state_tensors(self):
        class M(torch.nn.Module):  # simple with register buffer
            def __init__(self) -> None:
                super().__init__()
                self.buf = torch.nn.Buffer(torch.ones(2, 3), persistent=False)

            def forward(self, x):
                # x = 2
                y = self.buf
                # y = 1
                w1 = self.buf + 3
                w2 = self.buf + 4
                w3 = self.buf + 5
                self.buf = w1
                z = self.buf
                self.buf = w3
                # z = 4
                return x + y + z + w2

        ep = torch.export.export(M(), (torch.randn(2, 3),), strict=False)
        self.assertEqual(ep.graph_signature.buffers_to_mutate, {"add_2": "buf"})
        self.assertTrue(
            torch.allclose(ep.module()(torch.ones(2, 3) + 1), torch.ones(2, 3) * 12)
        )

        class M(torch.nn.Module):  # simple without register buffer
            def __init__(self) -> None:
                super().__init__()
                self.buf = torch.ones(2, 3)

            def forward(self, x):
                # x = 2
                y = self.buf
                # y = 1
                self.buf = self.buf + 3
                z = self.buf
                # z = 3
                return x + y + z

        with self.assertRaisesRegex(
            ValueError,
            "The tensor attribute self.buf was assigned during export",
        ):
            torch.export.export(M(), (torch.randn(2, 3),), strict=False)

        class M(torch.nn.Module):  # complex with register buffer
            def __init__(self) -> None:
                super().__init__()
                tensors = [torch.ones(2, 3), torch.ones(2, 3)]
                for i, tensor in enumerate(tensors):
                    self.register_buffer(f"buf_{i}", tensor, persistent=False)

            def get_tensor(self, i):
                return getattr(self, f"buf_{i}")

            def set_tensor(self, i, val):
                setattr(self, f"buf_{i}", val)

            def forward(self, x):
                # x = 2
                y = self.get_tensor(0) + self.get_tensor(1)
                # y = 1 + 1
                self.set_tensor(0, torch.ones(2, 3) + 2)
                self.set_tensor(1, torch.ones(2, 3) + 2)
                z = self.get_tensor(0) + self.get_tensor(1)
                # z = 3 + 3
                return x + y + z

        ep = torch.export.export(M(), (torch.randn(2, 3),), strict=False)
        self.assertEqual(
            ep.graph_signature.buffers_to_mutate, {"add_1": "buf_0", "add_2": "buf_1"}
        )
        self.assertTrue(
            torch.allclose(ep.module()(torch.ones(2, 3) + 1), torch.ones(2, 3) * 10)
        )

        class M(torch.nn.Module):  # complex without register buffer
            def __init__(self) -> None:
                super().__init__()
                self.tensors = [torch.ones(2, 3), torch.ones(2, 3)]

            def get_tensor(self, i):
                return self.tensors[i]

            def set_tensor(self, i, val):
                self.tensors[i] = val

            def forward(self, x):
                # x = 2
                y = self.get_tensor(0) + self.get_tensor(1)
                # y = 1 + 1
                self.set_tensor(0, torch.ones(2, 3) + 2)
                self.set_tensor(1, torch.ones(2, 3) + 2)
                z = self.get_tensor(0) + self.get_tensor(1)
                # z = 3 + 3
                return x + y + z

        with self.assertRaisesRegex(
            ValueError,
            "The tensor attributes self.tensors\\[0\\], self.tensors\\[1\\] were assigned during export",
        ):
            torch.export.export(M(), (torch.randn(2, 3),), strict=False)

    def test_state_primitives(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = 1
                self.y = {"k": 2}
                self.z = (3,)

            def forward(self, x):
                self.x = self.x + 4
                self.y["k"] = self.y["k"] + 5
                self.z = (self.z[0] + 6,)
                return x + self.x + self.y["k"] + self.z[0]

        ep = export(M(), (torch.randn(2, 3),))
        self.assertTrue(
            torch.allclose(ep.module()(torch.zeros(2, 3)), torch.ones(2, 3) * 21)
        )

    def test_export_script_module(self):
        class Foo(torch.nn.Module):
            def forward(self, rv: torch.Tensor, t: torch.Tensor):
                i = t.item()
                return rv + i

        foo = Foo()
        foo_script = torch.jit.script(foo)
        inp = (torch.zeros(3, 4), torch.tensor(7))

        with self.assertRaisesRegex(
            ValueError, "Exporting a ScriptModule is not supported"
        ):
            export(foo_script, inp)

        from torch._export.converter import TS2EPConverter

        TS2EPConverter(foo_script, inp).convert()

    def test_torch_fn(self):
        class M1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.linear(x)
                x = self.linear(x)
                x = self.relu(x)
                x = x + x
                return x

        ep1 = export(M1(), (torch.randn(3, 3),)).run_decompositions()
        expected_result = [
            ("linear_1", "builtin_function_or_method.linear"),
            ("linear_1", "builtin_function_or_method.linear"),
"builtin_function_or_method.linear"), 933 ("linear_2", "builtin_function_or_method.linear"), 934 ("relu_1", "function.relu"), 935 ("add_1", "method_descriptor.add"), 936 ] 937 actual_result = [] 938 for i, node in enumerate(ep1.graph.nodes): 939 if node.op == "call_function": 940 actual_result.append(node.meta.get("torch_fn")) 941 self.assertEqual(actual_result, expected_result) 942 943 class M2(torch.nn.Module): 944 def __init__(self) -> None: 945 super().__init__() 946 947 def forward(self, x, weight, bias): 948 x = torch.nn.functional.linear(x, weight, bias) 949 x = torch.nn.functional.relu(x) 950 x = torch.add(x, x) 951 return x 952 953 ep2 = export( 954 M2(), (torch.randn(3, 3), torch.randn(3, 3), torch.randn(3)) 955 ).run_decompositions() 956 expected_result = [ 957 ("linear_1", "builtin_function_or_method.linear"), 958 ("linear_1", "builtin_function_or_method.linear"), 959 ("relu_1", "function.relu"), 960 ("add_1", "builtin_function_or_method.add"), 961 ] 962 actual_result = [] 963 for i, node in enumerate(ep2.graph.nodes): 964 if node.op == "call_function": 965 actual_result.append(node.meta.get("torch_fn")) 966 self.assertEqual(actual_result, expected_result) 967 968 @testing.expectedFailureSerDer # failed serializing SymInt nodes in subgraph (known issue) 969 def test_hoo_inline_users_issue(self): 970 # This came from an issue where replace_with_hop passes would inline subgraphs, 971 # and mess up node.users for nodes present in multiple subgraphs (e.g. _x in SetGradCase 972 # below, since it's used in both set_grad_enabled HOO modules). 973 # This checks that node.users and node.args are in correspondence. 974 def check_users_for_graph(graph): 975 def _tuple_contains(_tuple, val): 976 # check nested, since output node args have format ((x, y, ...),) 977 return any( 978 _tuple_contains(x, val) if isinstance(x, tuple) else x == val 979 for x in _tuple 980 ) 981 982 for node in graph.nodes: 983 # check node.users 984 for user in node.users.keys(): 985 assert _tuple_contains(user.args, node) 986 # check node.args 987 for arg in node.args: 988 if isinstance(arg, torch.fx.Node): 989 assert _tuple_contains(arg.users, node) 990 991 # check set grad enabled 992 class SetGradCase(torch.nn.Module): 993 def forward(self, x): 994 _x = x.shape[0] + 2 995 _xx = _x + 2 996 with torch.no_grad(): 997 y = _x * 4 998 return _xx, y 999 1000 ep = export( 1001 SetGradCase(), 1002 (torch.randn(6),), 1003 dynamic_shapes={"x": (Dim("dx"),)}, 1004 strict=False, 1005 ) 1006 check_users_for_graph(ep.graph) 1007 1008 def test_export_predispatch_custom_ops_warnings(self): 1009 @torch.library.custom_op("mylib::foo", mutates_args={}) 1010 def foo(x: torch.Tensor) -> torch.Tensor: 1011 return x.sin() 1012 1013 @foo.register_fake 1014 def _(x): 1015 return torch.empty_like(x) 1016 1017 class Foo(torch.nn.Module): 1018 def forward(self, x): 1019 return foo(x) 1020 1021 x = torch.randn(3) 1022 1023 # Assert no warnings 1024 with warnings.catch_warnings(): 1025 warnings.simplefilter("error") 1026 torch.export.export(Foo(), (x,)) 1027 1028 # Assert warning for CompositeImplictAutograd op 1029 with torch.library._scoped_library("mylib", "FRAGMENT") as lib: 1030 lib.define("foo123(Tensor x) -> Tensor") 1031 lib.impl("foo123", lambda x: x.sin(), "CompositeImplicitAutograd") 1032 1033 class Bar(torch.nn.Module): 1034 def forward(self, x): 1035 return torch.ops.mylib.foo123(x) 1036 1037 with self.assertWarnsRegex( 1038 UserWarning, "CompositeImplicitAutograd and have functional schema" 1039 ): 1040 with 
                with warnings.catch_warnings():
                    warnings.simplefilter("always")
                    torch.export.export(Bar(), (x,))

    def test_export_preserve_linear_at_aot_level(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)

            def forward(self, x):
                x = self.linear(x)
                return torch.ops.aten.chunk.default(x, 3, 0)

        gm = (
            torch.export.export(
                Foo(),
                (torch.randn(3, 3),),
            )
            .run_decompositions({}, _preserve_ops=(torch.ops.aten.linear.default,))
            .graph_module
        )
        # linear is a CompositeImplicitAutograd functional op, so we preserve it;
        # chunk is a CompositeImplicitAutograd non-functional op that we decompose.
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, p_linear_weight, p_linear_bias, x):
    linear = torch.ops.aten.linear.default(x, p_linear_weight, p_linear_bias); x = p_linear_weight = p_linear_bias = None
    split = torch.ops.aten.split.Tensor(linear, 1); linear = None
    getitem = split[0]
    getitem_1 = split[1]
    getitem_2 = split[2]; split = None
    return (getitem, getitem_1, getitem_2)""",
        )

    def test_export_cond_preserve_torch_fn_for_subgraphs(self):
        class MySubModule(torch.nn.Module):
            def foo(self, x):
                return x.cos()

            def forward(self, x):
                return self.foo(x)

        class CondBranchClassMethod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.subm = MySubModule()

            def bar(self, x):
                return x.sin()

            def forward(self, x):
                return cond(x.sum() <= 2, self.subm.forward, self.bar, [x])

        example_inputs = (torch.randn(1, 3, 3, 3),)
        m = CondBranchClassMethod()
        m.eval()
        gm = export(m, example_inputs).module()

        actual_torch_fns = []
        for mod in gm.modules():
            for node in mod.graph.nodes:
                if node.name in {"sin", "cos"}:
                    torch_fn = node.meta.get("torch_fn")
                    print(torch_fn)
                    actual_torch_fns.append(torch_fn)
        exp_torch_fns = [
            ("cos_1", "method_descriptor.cos"),
            ("sin_1", "method_descriptor.sin"),
        ]
        self.assertEqual(actual_torch_fns, exp_torch_fns)

    def test_derived_dim_basic(self):
        class Foo(torch.nn.Module):
            def forward(self, x, y):
                return x + y[1:]

        foo = Foo()

        x, y = torch.randn(5), torch.randn(6)
        dimx = torch.export.Dim("dimx", min=3, max=6)

        dimy = torch.export.Dim("dimy", min=4, max=7)  # doesn't work
        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            (
                "Constraints violated \\(dimy\\)!(.*\n)*.*"
                "The values of dimy.*must always be related to the values of dimx.*by.*(.*\n)*.*"
                "Suggested fixes:(.*\n)*.*"
                "dimy = dimx \\+ 1"
            ),
        ):
            export(
                foo,
                (x, y),
                dynamic_shapes=({0: dimx}, {0: dimy}),
            )

        dimy = dimx * 2  # doesn't work
        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            "Expected input.*size.* to be equal to 2\\*dimx, where dimx = 5, but got 6",
        ):
            export(
                foo,
                (x, y),
                dynamic_shapes=({0: dimx}, {0: dimy}),
            )

        dimy = dimx + 1  # works
        ep = export(
            foo,
            (x, y),
            dynamic_shapes=({0: dimx}, {0: dimy}),
        )
        with self.assertRaisesRegex(
            RuntimeError,
            "Expected input.*shape.*to be equal to 5, but got 6",
        ):
            ep.module()(torch.randn(4), torch.randn(6))

        self.assertEqual(ep.module()(torch.randn(4), torch.randn(5)).size()[0], 4)

    def test_derived_dim_nested(self):
        class Foo(torch.nn.Module):
            def forward(self, x, y):
                return x + y[1::2]

        foo = Foo()

        x, y = torch.randn(5), torch.randn(11)
        dimx = torch.export.Dim("dimx", min=3, max=6)
        dimy = dimx * 2 + 1  # works
        ep = export(
            foo,
            (x, y),
            dynamic_shapes=({0: dimx}, {0: dimy}),
        )
        self.assertEqual(ep.module()(torch.randn(4), torch.randn(9)).size()[0], 4)

        class Foo(torch.nn.Module):
            def forward(self, z, y):
                return z[1:] + y[1::2]

        foo = Foo()

        z, y = torch.randn(6), torch.randn(11)

        dimz = dimx
        dimy = dimx * 2 - 1  # works
        ep = export(
            foo,
            (z, y),
            dynamic_shapes=({0: dimz}, {0: dimy}),
        )
        self.assertEqual(ep.module()(torch.randn(5), torch.randn(9)).size()[0], 4)

        dimz = dimx + 1
        dimy = dimx * 2 - 1  # doesn't work

        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            "Expected input.*size.*to be equal to 2\\*dimx - 1, where dimx = 5, but got 11",
        ):
            export(
                foo,
                (z, y),
                dynamic_shapes=({0: dimz}, {0: dimy}),
            )

        dimy = dimx * 2 + 1  # works
        ep = export(
            foo,
            (z, y),
            dynamic_shapes=({0: dimz}, {0: dimy}),
        )
        with self.assertRaisesRegex(
            RuntimeError, "Expected input.*shape.*to be <= 7, but got 8"
        ):
            ep.module()(torch.randn(8), torch.randn(15))
        with self.assertRaisesRegex(
            RuntimeError,
            "Expected input.*shape.*to be equal to 9, but got 8",
        ):
            ep.module()(torch.randn(5), torch.randn(8))

        self.assertEqual(ep.module()(torch.randn(5), torch.randn(9)).size()[0], 4)

    def test_derived_dim_integer(self):
        class Foo(torch.nn.Module):
            def forward(self, w):
                if w.shape[0] % 2 == 0:
                    return w[::2]
                else:
                    return w[1:-1:2]

        foo = Foo()

        w = torch.randn(10)
        dimx = torch.export.Dim("dimx", min=3, max=6)
        dimw = dimx * 2 + 1  # doesn't work
        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            "Expected shape.*= 10 of input Tensor to be "
            "of the form 2\\*dimx \\+ 1, where dimx is an integer",
        ):
            export(
                foo,
                (w,),
                dynamic_shapes=({0: dimw},),
            )

        dimw = dimx * 2  # works
        ep = export(
            foo,
            (w,),
            dynamic_shapes=({0: dimw},),
        )
        with self.assertRaisesRegex(
            RuntimeError,
            "Expected input.*shape.*= 9 to be "
            "of the form 2\\*s1, where s1 is an integer",
        ):
            ep.module()(torch.randn(9))

        self.assertEqual(ep.module()(torch.randn(8)).size()[0], 4)
        with self.assertRaisesRegex(
            RuntimeError,
            "Expected input.*shape.*to be <= 12, but got 14",
        ):
            ep.module()(torch.randn(14))

    def test_derived_dim_repeat_derived(self):
        class Foo(torch.nn.Module):
            def forward(self, u, v):
                return u[::2] + v[::2]

        foo = Foo()

        u, v = torch.randn(10), torch.randn(10)
        dimx = torch.export.Dim("dimx", min=3, max=6)
        dimw = dimx * 2  # works
        ep = export(
            foo,
            (u, v),
            dynamic_shapes=({0: dimw}, {0: dimw}),
        )
        self.assertEqual(ep.module()(torch.randn(8), torch.randn(8)).size()[0], 4)

    def test_derived_dim_out_of_order(self):
        dimy = torch.export.Dim("dimy", min=5, max=7)
        dimx = dimy - 1  # out of order, effectively dimy = dimx + 1
        dimz = dimy + 1  # out of order, effectively dimz = dimx + 2

        class Foo(torch.nn.Module):
            def forward(self, x, y, z):
                return x + y[1:] + z[2:]

        foo = Foo()

        u, v, w = torch.randn(5), torch.randn(6), torch.randn(7)
        ep = export(
            foo,
            (u, v, w),
            dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}),
        )
        with self.assertRaisesRegex(
            RuntimeError,
            "Expected input.*shape.*to be equal to 8, but got 5",
        ):
            ep.module()(torch.randn(6), torch.randn(7), torch.randn(5))

        self.assertEqual(
            ep.module()(torch.randn(6), torch.randn(7), torch.randn(8)).size()[0], 6
        )

    def test_derived_dim_out_of_order_repeat_derived(self):
        dimy = torch.export.Dim("dimy", min=5, max=7)
        dimx = dimy - 1  # out of order, effectively dimy = dimx + 1
        dimz = dimy + 1  # out of order, effectively dimz = dimx + 2
        dimx1 = dimx
        dimx2 = dimz - 2  # works, effectively = dimx

        class Foo(torch.nn.Module):
            def forward(self, x, y, z, x1, x2):
                return x + y[1:] + z[2:] + x1 + x2

        foo = Foo()

        u, v, w, u1, u2 = (
            torch.randn(5),
            torch.randn(6),
            torch.randn(7),
            torch.randn(5),
            torch.randn(5),
        )
        ep = export(
            foo,
            (u, v, w, u1, u2),
            dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}, {0: dimx1}, {0: dimx2}),
        )
        with self.assertRaisesRegex(
            RuntimeError,
            "Expected input.*shape.*to be equal to 6, but got 5",
        ):
            ep.module()(
                torch.randn(6),
                torch.randn(7),
                torch.randn(8),
                torch.randn(6),
                torch.randn(5),
            )

        self.assertEqual(
            ep.module()(
                torch.randn(6),
                torch.randn(7),
                torch.randn(8),
                torch.randn(6),
                torch.randn(6),
            ).size()[0],
            6,
        )

        ep = export(
            foo,
            (u, v, w, u, u),  # reused inputs
            dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}, {0: dimx1}, {0: dimx2}),
        )
        with self.assertRaisesRegex(
            RuntimeError,
            "Expected input.*shape.*to be equal to 6, but got 5",
        ):
            ep.module()(
                torch.randn(6),
                torch.randn(7),
                torch.randn(8),
                torch.randn(6),
                torch.randn(5),
            )

        self.assertEqual(
            ep.module()(
                torch.randn(6),
                torch.randn(7),
                torch.randn(8),
                torch.randn(6),
                torch.randn(6),
            ).size()[0],
            6,
        )

    def test_specialize_derived_dim_roots(self):
        # dim & derived dim both specialize
        class Foo(torch.nn.Module):
            def forward(self, x, y):
                return x.reshape([-1]) + y

        dy = Dim("dy", min=6)
        x, y = torch.randn(6, 2), torch.randn(12)
        dynamic_shapes = {
            "x": (dy - 6, 2),
            "y": (dy,),
        }
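        # x.reshape([-1]) has 2 * (dy - 6) elements and is added to y of size
        # dy, which forces dy == 12; export should report the specialization on
        # the root dim dy rather than on the derived dim dy - 6.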
        try:
            export(Foo(), (x, y), dynamic_shapes=dynamic_shapes)
            raise Exception(
                "export() call should have failed with dynamic shapes error."
            )
        except torch._dynamo.exc.UserError as exc:
            expected_error_msg = (
                "Specializations unexpectedly required \(dy\)!(.*\n)*.*"
                ".*solving the guards generated for dy - 6.*resulted in a specialized value of 6(.*\n)*.*"
                "Suggested fixes(.*\n)*.*"
                ".*dy = 12(.*\n)*.*"
            )
            self.assertTrue(re.search(expected_error_msg, exc.args[0]) is not None)
            self.assertTrue(
                "dy - 6 = 6" not in exc.args[0]
            )  # don't suggest fix for non-root dim

    def test_keep_composite_ops_invalid(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)

            def forward(self, x):
                x = self.linear(x)
                return torch.ops.aten.chunk.default(x, 3, 0)

        with self.assertRaisesRegex(
            RuntimeError, "aten.chunk.default is a mutating/aliasing op"
        ):
            _ = torch.export.export(
                Foo(),
                (torch.randn(3, 3),),
            ).run_decompositions({}, _preserve_ops=(torch.ops.aten.chunk.default,))

        with self.assertRaisesRegex(
            RuntimeError, "aten.sym_size.default is a metadata query function"
        ):
            _ = torch.export.export(
                Foo(),
                (torch.randn(3, 3),),
            ).run_decompositions({}, _preserve_ops=(torch.ops.aten.sym_size.default,))

        with self.assertRaisesRegex(
            RuntimeError,
            "We can't detect aten.native_batch_norm.default as a functional op statically",
        ):
            _ = torch.export.export(
                Foo(),
                (torch.randn(3, 3),),
            ).run_decompositions(
                {}, _preserve_ops=(torch.ops.aten.native_batch_norm.default,)
            )

    def test_keep_composite_ops_linear_convd(self):
        class MyLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.randn(20, 98)
                self.bias = torch.randn(20)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.weight, self.bias)

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(16, 33, 3)
                self.conv1d = torch.nn.Conv1d(16, 33, 3)
                self.linear = MyLinear()

            def forward(self, x, y):
                x_conv = self.conv(x)
                y_conv_1d = self.conv1d(y)
                x_linear = self.linear(x_conv)
                return x_linear.cos() + y_conv_1d.sum()

        ep = torch.export.export(
            Foo(), (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50))
        )
        ep_has_linear_convd = ep.run_decompositions(
            decomp_table={},
            _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY,
        )
        self.assertExpectedInline(
            str(ep_has_linear_convd.graph_module.code).strip(),
            """\
def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_linear_weight, c_linear_bias, x, y):
    conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None
    conv1d = torch.ops.aten.conv1d.default(y, p_conv1d_weight, p_conv1d_bias); y = p_conv1d_weight = p_conv1d_bias = None
    linear = torch.ops.aten.linear.default(conv2d, c_linear_weight, c_linear_bias); conv2d = c_linear_weight = c_linear_bias = None
    cos = torch.ops.aten.cos.default(linear); linear = None
    sum_1 = torch.ops.aten.sum.default(conv1d); conv1d = None
    add = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
    return (add,)""",
        )

        ep_has_convd = ep.run_decompositions(
            decomp_table=None,
            _preserve_ops=[
                torch.ops.aten.conv2d.default,
                torch.ops.aten.conv1d.default,
            ],
        )
        self.assertExpectedInline(
            str(ep_has_convd.graph_module.code).strip(),
            """\
def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_linear_weight, c_linear_bias, x, y):
    conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None
    conv1d = torch.ops.aten.conv1d.default(y, p_conv1d_weight, p_conv1d_bias); y = p_conv1d_weight = p_conv1d_bias = None
    view = torch.ops.aten.view.default(conv2d, [31680, 98]); conv2d = None
    permute = torch.ops.aten.permute.default(c_linear_weight, [1, 0]); c_linear_weight = None
    addmm = torch.ops.aten.addmm.default(c_linear_bias, view, permute); c_linear_bias = view = permute = None
    view_1 = torch.ops.aten.view.default(addmm, [20, 33, 48, 20]); addmm = None
    cos = torch.ops.aten.cos.default(view_1); view_1 = None
    sum_1 = torch.ops.aten.sum.dim_IntList(conv1d, []); conv1d = None
    add = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
    return (add,)""",
        )

        ep_has_convd = ep_has_convd.run_decompositions(
            decomp_table=None, _preserve_ops=[torch.ops.aten.conv2d.default]
        )
        self.assertExpectedInline(
            str(ep_has_convd.graph_module.code).strip(),
            """\
def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_linear_weight, c_linear_bias, x, y):
    conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None
    convolution = torch.ops.aten.convolution.default(y, p_conv1d_weight, p_conv1d_bias, [1], [0], [1], False, [0], 1); y = p_conv1d_weight = p_conv1d_bias = None
    view = torch.ops.aten.view.default(conv2d, [31680, 98]); conv2d = None
    permute = torch.ops.aten.permute.default(c_linear_weight, [1, 0]); c_linear_weight = None
    addmm = torch.ops.aten.addmm.default(c_linear_bias, view, permute); c_linear_bias = view = permute = None
    view_1 = torch.ops.aten.view.default(addmm, [20, 33, 48, 20]); addmm = None
    cos = torch.ops.aten.cos.default(view_1); view_1 = None
    sum_1 = torch.ops.aten.sum.dim_IntList(convolution, []); convolution = None
    add = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
    return (add,)""",
        )

    def test_keep_composite_ops_linear_convd_for_training_ir(self):
        class MyLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Buffer(torch.randn(20, 98))
                self.bias = torch.nn.Buffer(torch.randn(20))

            def forward(self, x):
                return torch.nn.functional.linear(x, self.weight, self.bias)

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(16, 33, 3)
                self.conv1d = torch.nn.Conv1d(16, 33, 3)
                self.linear = MyLinear()

            def forward(self, x, y):
                x_conv = self.conv(x)
                y_conv_1d = self.conv1d(y)
                x_linear = self.linear(x_conv)
                return x_linear.cos() + y_conv_1d.sum()

        ep = torch.export.export_for_training(
            Foo(), (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50))
        )
        ep_has_linear_convd = ep.run_decompositions(
            decomp_table={},
            _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY,
        )

        self.assertExpectedInline(
            str(ep_has_linear_convd.graph_module.code).strip(),
            """\
def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_linear_weight, b_linear_bias, x, y):
    conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None
    conv1d = torch.ops.aten.conv1d.default(y, p_conv1d_weight, p_conv1d_bias); y = p_conv1d_weight = p_conv1d_bias = None
    linear = torch.ops.aten.linear.default(conv2d, b_linear_weight, b_linear_bias); conv2d = b_linear_weight = b_linear_bias = None
    cos = torch.ops.aten.cos.default(linear); linear = None
    sum_1 = torch.ops.aten.sum.default(conv1d); conv1d = None
    add = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
    return (add,)""",
        )

        ep_has_convd = ep.run_decompositions(
            decomp_table=None,
            _preserve_ops=[
                torch.ops.aten.conv2d.default,
                torch.ops.aten.conv1d.default,
            ],
        )

        self.assertExpectedInline(
            str(ep_has_convd.graph_module.code).strip(),
            """\
def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_linear_weight, b_linear_bias, x, y):
    conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None
    conv1d = torch.ops.aten.conv1d.default(y, p_conv1d_weight, p_conv1d_bias); y = p_conv1d_weight = p_conv1d_bias = None
    view = torch.ops.aten.view.default(conv2d, [31680, 98]); conv2d = None
    permute = torch.ops.aten.permute.default(b_linear_weight, [1, 0]); b_linear_weight = None
    addmm = torch.ops.aten.addmm.default(b_linear_bias, view, permute); b_linear_bias = view = permute = None
    view_1 = torch.ops.aten.view.default(addmm, [20, 33, 48, 20]); addmm = None
    cos = torch.ops.aten.cos.default(view_1); view_1 = None
    sum_1 = torch.ops.aten.sum.dim_IntList(conv1d, []); conv1d = None
    add = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
    return (add,)""",
        )

        ep_has_convd = ep_has_convd.run_decompositions(
            decomp_table=None, _preserve_ops=[torch.ops.aten.conv2d.default]
        )

        self.assertExpectedInline(
            str(ep_has_convd.graph_module.code).strip(),
            """\
def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_linear_weight, b_linear_bias, x, y):
    conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None
    convolution = torch.ops.aten.convolution.default(y, p_conv1d_weight, p_conv1d_bias, [1], [0], [1], False, [0], 1); y = p_conv1d_weight = p_conv1d_bias = None
    view = torch.ops.aten.view.default(conv2d, [31680, 98]); conv2d = None
    permute = torch.ops.aten.permute.default(b_linear_weight, [1, 0]); b_linear_weight = None
    addmm = torch.ops.aten.addmm.default(b_linear_bias, view, permute); b_linear_bias = view = permute = None
    view_1 = torch.ops.aten.view.default(addmm, [20, 33, 48, 20]); addmm = None
    cos = torch.ops.aten.cos.default(view_1); view_1 = None
    sum_1 = torch.ops.aten.sum.dim_IntList(convolution, []); convolution = None
    add = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
    return (add,)""",
        )

    def test_set_grad_empty(self):
        class M(torch.nn.Module):
            def forward(self, x):
                with torch.no_grad():
                    x = x + 1
                    return x, None

        ep = export(M(), (torch.ones(3, 3),))
        inp = torch.randn(3, 3)
        self.assertTrue(torch.allclose(ep.module()(inp)[0], inp + 1))

    def test_derived_dim_out_of_order_simplified(self):
        _dimz = torch.export.Dim("_dimz", min=6, max=8)
        dimy = _dimz - 1
        dimx = dimy - 1
        dimz = torch.export.Dim("dimz", min=6, max=8)  # doesn't work, should be = _dimz

        class Foo(torch.nn.Module):
            def forward(self, x, y, z):
                return x + y[1:] + z[2:]

        foo = Foo()
        u, v, w = torch.randn(5), torch.randn(6), torch.randn(7)
        try:
            export(
                foo,
                (u, v, w),
                dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}),
            )
        except torch._dynamo.exc.UserError as exc:
            expected_error_msg = (
                "Constraints violated \(dimz\)!(.*\n)*.*"
                "The values of dimz.*must always be related to the values of _dimz - 2.*by.*(.*\n)*.*"
                "Suggested fixes:(.*\n)*.*"
                "dimz = _dimz"
            )
            self.assertTrue(re.search(expected_error_msg, exc.args[0]) is not None)
            # don't suggest fix for non-root dims, and no need to update root here
            self.assertTrue("_dimz - 2 = Dim(" not in exc.args[0])
            self.assertTrue("_dimz - 1 = _dimz - 1" not in exc.args[0])
            self.assertTrue("_dimz = Dim(" not in exc.args[0])

        dimz = dimx + 2  # works, effectively = _dimz
        ep = export(
            foo,
            (u, v, w),
            dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}),
        )
        with self.assertRaisesRegex(
            RuntimeError,
            "Expected input.*shape.*to be equal to 8, but got 5",
        ):
            ep.module()(torch.randn(6), torch.randn(7), torch.randn(5))

        self.assertEqual(
            ep.module()(torch.randn(6), torch.randn(7), torch.randn(8)).size()[0], 6
        )

    def test_simple_export_for_training(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                return self.linear(x)

        eager_model = Foo()
        ep_for_training = torch.export.export_for_training(
            eager_model, (torch.ones(2, 2),)
        )
        self.assertExpectedInline(
            str(ep_for_training.graph_module.code).strip(),
            """\
def forward(self, p_linear_weight, p_linear_bias, x):
    linear = torch.ops.aten.linear.default(x, p_linear_weight, p_linear_bias); x = p_linear_weight = p_linear_bias = None
    return (linear,)""",
        )
        gm = ep_for_training.module()
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    linear_weight = self.linear.weight
    linear_bias = self.linear.bias
    linear = torch.ops.aten.linear.default(x, linear_weight, linear_bias); x = linear_weight = linear_bias = None
    return pytree.tree_unflatten((linear,), self._out_spec)""",
        )

        self.assertTrue(
            torch.allclose(gm(torch.ones(2, 2)), eager_model(torch.ones(2, 2)))
        )

    def test_export_for_training_with_mutation(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer = torch.nn.Buffer(torch.ones(4, 4))

            def forward(self, x):
                x.add_(5)
                self.buffer.add_(5)
                return x + self.buffer

        eager_model_for_export = Foo()
        eager_model_for_testing = Foo()
        ep_for_training = torch.export.export_for_training(
            eager_model_for_export, (torch.ones(4, 4),)
        )
        self.assertExpectedInline(
            str(ep_for_training.graph_module.code).strip(),
            """\
def forward(self, b_buffer, x):
    add_ = torch.ops.aten.add_.Tensor(x, 5); x = None
    add__1 = torch.ops.aten.add_.Tensor(b_buffer, 5); b_buffer = None
    add = torch.ops.aten.add.Tensor(add_, add__1); add_ = add__1 = None
    return (add,)""",
        )
        gm = ep_for_training.module()
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    buffer = self.buffer
    add_ = torch.ops.aten.add_.Tensor(x, 5); x = None
    add__1 = torch.ops.aten.add_.Tensor(buffer, 5); buffer = None
    add = torch.ops.aten.add.Tensor(add_, add__1); add_ = add__1 = None
    return pytree.tree_unflatten((add,), self._out_spec)""",
        )

        self.assertTrue(
            torch.allclose(
                gm(torch.ones(4, 4)), eager_model_for_testing(torch.ones(4, 4))
            )
        )

    def test_export_for_training_with_dynamic_shapes(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer = torch.nn.Buffer(torch.ones(4, 4))

            def forward(self, x):
                x.add_(5)
                self.buffer.add_(5)
                return x + self.buffer.sum()

        eager_model_for_export_training = Foo()
        eager_model_for_export_inference = Foo()
        eager_model_for_testing = Foo()
        ep_for_training = torch.export.export_for_training(
            eager_model_for_export_training,
            (torch.ones(4, 4),),
            dynamic_shapes=({0: Dim("x")},),
        )

        self.assertTrue(
            torch.allclose(
                ep_for_training.module()(torch.ones(2, 4)),
                eager_model_for_testing(torch.ones(2, 4)),
            )
        )

        ep_for_real = export(
            eager_model_for_export_inference,
            (torch.ones(4, 4),),
            dynamic_shapes=({0: Dim("x")},),
        )

        self.assertEqual(
            str(ep_for_training.range_constraints), str(ep_for_real.range_constraints)
        )

    def test_export_for_training_with_container_type(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer = torch.nn.Buffer(torch.ones(4, 4))

            def forward(self, container):
                x = container[0][0]
                y = container[0][1]
                x.add_(5)
                y.add_(5)
                return x + y + self.buffer.sum()

        eager_model = Foo()
        ep_for_training = torch.export.export_for_training(
            eager_model,
            ([torch.ones(4, 4), torch.ones(4, 4)],),
        )

        self.assertTrue(
            torch.allclose(
                ep_for_training.module()(
                    ([torch.ones(4, 4), torch.ones(4, 4)]),
                ),
                eager_model(([torch.ones(4, 4), torch.ones(4, 4)])),
            )
        )

    def test_export_for_training_run_decomp(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer = torch.nn.Buffer(torch.ones(2, 2))
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                self.buffer.add_(5)
                return self.linear(x) + self.buffer.sum()

        eager_model = Foo()
        ep_for_training = torch.export.export_for_training(
            eager_model,
            (torch.ones(2, 2),),
        )
        ep_for_inference = ep_for_training.run_decompositions()
        self.assertExpectedInline(
            str(ep_for_inference.graph_module.code).strip(),
            """\
def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
    add = torch.ops.aten.add.Tensor(b_buffer, 5); b_buffer = None
    permute = torch.ops.aten.permute.default(p_linear_weight, [1, 0]); p_linear_weight = None
    addmm = torch.ops.aten.addmm.default(p_linear_bias, x, permute); p_linear_bias = x = permute = None
    sum_1 = torch.ops.aten.sum.dim_IntList(add, [])
    add_1 = torch.ops.aten.add.Tensor(addmm, sum_1); addmm = sum_1 = None
    return (add, add_1)""",
        )
test_derived_dim_out_of_order_simplified_repeat_non_derived(self): 1865 class Foo(torch.nn.Module): 1866 def forward(self, x, y, y1, z): 1867 return x + y[1:] + y1[1:] + z[2:] 1868 1869 foo = Foo() 1870 1871 u, v, v1, w = torch.randn(5), torch.randn(6), torch.randn(6), torch.randn(7) 1872 _dimz = torch.export.Dim("_dimz", min=6, max=8) 1873 dimy = _dimz - 1 1874 dimx = dimy - 1 1875 dimz = dimx + 2 # works, effectively = _dimz 1876 ep = export( 1877 foo, 1878 (u, v, v1, w), 1879 dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimy}, {0: dimz}), 1880 ) 1881 with self.assertRaisesRegex( 1882 RuntimeError, 1883 "Expected input.*shape.*to be equal to 7, but got 5", 1884 ): 1885 ep.module()( 1886 torch.randn(6), 1887 torch.randn(7), 1888 torch.randn(5), 1889 torch.randn(8), 1890 ) 1891 1892 self.assertEqual( 1893 ep.module()( 1894 torch.randn(6), 1895 torch.randn(7), 1896 torch.randn(7), 1897 torch.randn(8), 1898 ).size()[0], 1899 6, 1900 ) 1901 1902 def test_static_dim_constraints(self): 1903 class Foo(torch.nn.Module): 1904 def __init__(self) -> None: 1905 super().__init__() 1906 self.l = torch.nn.Linear(6, 4) 1907 1908 def forward(self, x, y, z): 1909 x0 = self.l(x) + y[1:] 1910 return x0, z * 2.0 1911 1912 foo = Foo() 1913 inputs = (torch.randn(4, 6), torch.randn(5, 4), torch.randn(3, 3)) 1914 dx = Dim("dx", min=3, max=6) 1915 dy = dx + 1 1916 dz = Dim("dz", min=3, max=6) 1917 1918 # test that tweaking shapes fails 1919 wrong_shape_inputs = [ 1920 (torch.randn(4, 7), torch.randn(5, 4), torch.randn(3, 3)), 1921 (torch.randn(4, 6), torch.randn(5, 5), torch.randn(3, 3)), 1922 (torch.randn(4, 6), torch.randn(5, 4), torch.randn(3, 4)), 1923 ] 1924 1925 # all of these should be fine 1926 for dynamic_shapes in [ 1927 ({0: dx, 1: 6}, {0: dy, 1: 4}, {0: dz, 1: 3}), 1928 ((dx, None), (dy, 4), (dz, 3)), 1929 ((None, 6), (5, None), (None, None)), 1930 ((4, 6), {0: None, 1: 4}, {0: None, 1: 3}), 1931 (None, None, (Dim.STATIC, Dim.STATIC)), 1932 ]: 1933 ep = export(foo, inputs, dynamic_shapes=dynamic_shapes) 1934 self.assertEqual(foo(*inputs), ep.module()(*inputs)) 1935 for wrong_inputs in wrong_shape_inputs: 1936 with self.assertRaises(RuntimeError): 1937 ep.module()(*wrong_inputs) 1938 1939 # check range_constraints - static dims shouldn't be present 1940 ep = export(foo, inputs, dynamic_shapes=((dx, None), (dy, 4), (dz, 3))) 1941 self.assertEqual(len(ep.range_constraints), 3) 1942 for vr in ep.range_constraints.values(): 1943 self.assertTrue(vr.lower < vr.upper) 1944 1945 # check raised errors 1946 with self.assertRaisesRegex( 1947 ( 1948 torch.fx.experimental.symbolic_shapes.ConstraintViolationError, 1949 torch._dynamo.exc.UserError, 1950 ), 1951 "Static shape constraint of 5 does not match input size of 4, for .*", 1952 ): 1953 _ = export(foo, inputs, dynamic_shapes=((5, None), None, None)) 1954 with self.assertRaisesRegex( 1955 ( 1956 torch.fx.experimental.symbolic_shapes.ConstraintViolationError, 1957 torch._dynamo.exc.UserError, 1958 ), 1959 "Static shape constraint of 9 does not match input size of 6, for .*", 1960 ): 1961 _ = export(foo, inputs, dynamic_shapes=((dx, 9), (dy, 4), (3, 3))) 1962 1963 def test_dim_1_2(self): 1964 class Foo(torch.nn.Module): 1965 def forward(self, x): 1966 return x * 2 1967 1968 dx = Dim("dx", min=1, max=2) 1969 ep = export(Foo(), (torch.randn(2, 2),), dynamic_shapes=({0: dx, 1: None},)) 1970 ep.module()(torch.randn(1, 2)) 1971 ep.module()(torch.randn(2, 2)) 1972 with self.assertRaisesRegex( 1973 RuntimeError, "Expected input at .* to be <= 2, but got 3" 1974 ): 1975 
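# A minimal standalone sketch (not part of the test suite) of the
# dynamic_shapes forms accepted in test_static_dim_constraints for the same
# input: a dict mapping dim index to a Dim or an int, a tuple mixing Dim/None,
# or Dim.STATIC. The module, names, and sizes are illustrative assumptions.
import torch
from torch.export import Dim, export


class _Scale(torch.nn.Module):
    def forward(self, x):
        return x * 2.0


_x = torch.randn(4, 6)
_d = Dim("d", min=2, max=16)
for _spec in (
    ({0: _d, 1: 6},),             # dict form: dim 0 dynamic, dim 1 pinned to 6
    ((_d, None),),                # tuple form: None leaves dim 1 static
    ((Dim.STATIC, Dim.STATIC),),  # fully static
):
    _ep = export(_Scale(), (_x,), dynamic_shapes=_spec)
    _ = _ep.module()(_x)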
ep.module()(torch.randn(3, 2)) 1976 vr = list(ep.range_constraints.values())[0] 1977 self.assertEqual(vr.lower, 1) 1978 self.assertEqual(vr.upper, 2) 1979 1980 def test_derived_dim_1_2(self): 1981 class Bar(torch.nn.Module): 1982 def forward(self, x, y): 1983 return x + y[1:] 1984 1985 dx = Dim("dx", min=1, max=2) 1986 ep = export( 1987 Bar(), 1988 (torch.randn(2, 2), torch.randn(3, 2)), 1989 dynamic_shapes=({0: dx, 1: None}, {0: dx + 1, 1: None}), 1990 ) 1991 ep.module()(torch.randn(1, 2), torch.randn(2, 2)) 1992 range_lower_bounds = sorted(vr.lower for vr in ep.range_constraints.values()) 1993 range_upper_bounds = sorted(vr.upper for vr in ep.range_constraints.values()) 1994 self.assertEqual(range_lower_bounds, [1, 2]) 1995 self.assertEqual(range_upper_bounds, [2, 3]) 1996 1997 def test_dynamic_shapes_builder_basic(self): 1998 class M(torch.nn.Module): 1999 def forward(self, x, y, z): 2000 return x + y[0] + z["k"] 2001 2002 m = M() 2003 2004 x = torch.randn(4) 2005 y = [torch.randn(4)] 2006 z = {"k": torch.randn(4)} 2007 args = (x, y, z) 2008 2009 shapes_collection = torch.export.ShapesCollection() 2010 dim = torch.export.Dim("dim", max=10) 2011 shapes_collection[x] = (dim,) 2012 shapes_collection[y[0]] = (dim,) 2013 shapes_collection[z["k"]] = (dim,) 2014 2015 ep = export(m, args, dynamic_shapes=shapes_collection) 2016 sym = next(iter(ep.range_constraints.keys())) 2017 for node in ep.graph.nodes: 2018 if node.op == "placeholder": 2019 self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)") 2020 2021 def test_dynamic_shapes_builder_kwargs(self): 2022 class M(torch.nn.Module): 2023 def forward(self, x, y, z): 2024 return x + y[0] + z["k"] 2025 2026 m = M() 2027 2028 x = torch.randn(4) 2029 y = [torch.randn(4)] 2030 z = {"k": torch.randn(4)} 2031 args = (x,) 2032 kwargs = {"z": z, "y": y} 2033 2034 shapes_collection = torch.export.ShapesCollection() 2035 dim = torch.export.Dim("dim", max=10) 2036 shapes_collection[x] = (dim,) 2037 shapes_collection[y[0]] = (dim,) 2038 shapes_collection[z["k"]] = (dim,) 2039 2040 ep = export(m, args, kwargs=kwargs, dynamic_shapes=shapes_collection) 2041 sym = next(iter(ep.range_constraints.keys())) 2042 for node in ep.graph.nodes: 2043 if node.op == "placeholder": 2044 self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)") 2045 2046 # retracing doesn't seem to like dataclass registration, 2047 # raising a dynamo error in fx_pytree.tree_flatten_spec 2048 @testing.expectedFailureRetraceability 2049 def test_dynamic_shapes_builder_pytree(self): 2050 torch.export.register_dataclass( 2051 Inp, 2052 serialized_type_name="test_dynamic_shapes_builder_pytree.Inp", 2053 ) 2054 2055 class M(torch.nn.Module): 2056 def forward(self, inp: Inp): 2057 return inp.x + inp.y[0] + inp.z["k"] 2058 2059 m = M() 2060 x = torch.randn(4) 2061 y = [torch.randn(4)] 2062 z = {"k": torch.randn(4)} 2063 args = (Inp(x, y, z),) 2064 2065 shapes_collection = torch.export.ShapesCollection() 2066 dim = torch.export.Dim("dim", max=10) 2067 shapes_collection[x] = (dim,) 2068 shapes_collection[y[0]] = (dim,) 2069 shapes_collection[z["k"]] = (dim,) 2070 2071 ep = export(m, args, dynamic_shapes=shapes_collection.dynamic_shapes(m, args)) 2072 sym = next(iter(ep.range_constraints.keys())) 2073 for node in ep.graph.nodes: 2074 if node.op == "placeholder": 2075 self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)") 2076 2077 def test_mismatched_dynamic_shapes(self): 2078 AUTO, STATIC = Dim.AUTO, Dim.STATIC 2079 2080 class M(torch.nn.Module): 2081 def forward(self, x): 
2082 return x["k"]["k"][0] + x["k"]["k"][1] 2083 2084 inputs = ({"k": {"k": [torch.rand(4), torch.rand(4)]}},) 2085 dim = torch.export.Dim("dim") 2086 2087 dynamic_shapes = { 2088 "k": {"k": [dim, dim]} 2089 } # ValueError: Node keys mismatch; missing key(s): {'x'}; extra key(s): {'k'}. 2090 with self.assertRaisesRegex( 2091 torch._dynamo.exc.UserError, 2092 re.escape( 2093 "When `dynamic_shapes` is specified as a dict, its top-level keys " 2094 "must be the arg names ['x'] of `inputs`, but here they are ['k']. " 2095 "Since here `inputs` is a list/tuple enclosing a single dict, " 2096 "maybe you just forgot to enclose `dynamic_shapes` in a list/tuple?" 2097 ), 2098 ): 2099 export(M(), inputs, dynamic_shapes=dynamic_shapes) 2100 2101 dynamic_shapes = ( 2102 {"k": {"k": [dim, dim]}}, 2103 ) # torch._dynamo.exc.UserError: Unexpected dynamic_shape .*dim.* of Tensor, try None instead 2104 with self.assertRaisesRegex( 2105 torch._dynamo.exc.UserError, 2106 "Unexpected input tensor shape .*dim.* " 2107 + re.escape( 2108 "specified at `dynamic_shapes[0]['k']['k'][0]` " 2109 "(expected either a list/tuple of dimensions, or a dict mapping indices to dimensions," 2110 " where each dimension is an int, a Dim, Dim.AUTO, or Dim.STATIC)" 2111 ), 2112 ): 2113 export(M(), inputs, dynamic_shapes=dynamic_shapes) 2114 2115 dynamic_shapes = ( 2116 {"k": {"k": (dim, dim)}}, 2117 ) # ValueError: Node type mismatch; expected <class 'list'>, but got <class 'tuple'>. 2118 with self.assertRaisesRegex( 2119 torch._dynamo.exc.UserError, 2120 re.escape( 2121 "Detected mismatch between the structure of `inputs` and `dynamic_shapes`: " 2122 "`inputs[0]['k']['k']` is a <class 'list'>, but `dynamic_shapes[0]['k']['k']` is a <class 'tuple'>" 2123 ), 2124 ): 2125 export(M(), inputs, dynamic_shapes=dynamic_shapes) 2126 2127 dynamic_shapes = ({"k": {"k": [(dim,), (dim,)]}},) # ok 2128 export(M(), inputs, dynamic_shapes=dynamic_shapes) 2129 2130 dynamic_shapes = ( 2131 {"k": {"k": dim}}, 2132 ) # ValueError: Node type mismatch; expected <class 'list'>, but got .*_Dim.*. 2133 with self.assertRaisesRegex( 2134 torch._dynamo.exc.UserError, 2135 re.escape( 2136 "Detected mismatch between the structure of `inputs` and `dynamic_shapes`: " 2137 "`inputs[0]['k']['k']` is a <class 'list'>, but `dynamic_shapes[0]['k']['k']` is not" 2138 ), 2139 ): 2140 export(M(), inputs, dynamic_shapes=dynamic_shapes) 2141 2142 dynamic_shapes = { 2143 "x": {"k": [(dim,), (dim,)]}, 2144 "k": {"k": [(dim,), (dim,)]}, 2145 } # ValueError: Node arity mismatch; expected 1, but got 2. 2146 with self.assertRaisesRegex( 2147 torch._dynamo.exc.UserError, 2148 re.escape( 2149 "When `dynamic_shapes` is specified as a dict, its top-level keys " 2150 "must be the arg names ['x'] of `inputs`, but here they are ['x', 'k']. " 2151 "Alternatively, you could also ignore arg names entirely " 2152 "and specify `dynamic_shapes` as a list/tuple matching `inputs`." 2153 ), 2154 ): 2155 export(M(), inputs, dynamic_shapes=dynamic_shapes) 2156 2157 dynamic_shapes = ( 2158 {"k": {"k": [(dim,), (dim,), (dim,)]}}, 2159 ) # ValueError: Node arity mismatch; expected 2, but got 3. 
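# A minimal standalone sketch (not part of the test suite) of the rule these
# error cases exercise: a dynamic_shapes spec must mirror the pytree structure
# of the inputs (same container types, keys, and arity), with one top-level
# entry per argument. The module, key names, and sizes are illustrative
# assumptions.
import torch
from torch.export import Dim, export


class _Nested(torch.nn.Module):
    def forward(self, x):
        return x["k"][0] + x["k"][1]


_inputs = ({"k": [torch.rand(4), torch.rand(4)]},)
_d = Dim("d")
_dynamic_shapes = ({"k": [(_d,), (_d,)]},)  # dict/list nesting matches _inputs
_ep = export(_Nested(), _inputs, dynamic_shapes=_dynamic_shapes)
_ = _ep.module()({"k": [torch.rand(6), torch.rand(6)]})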
2160 with self.assertRaisesRegex( 2161 torch._dynamo.exc.UserError, 2162 re.escape( 2163 "Detected mismatch between the structure of `inputs` and `dynamic_shapes`: " 2164 "`inputs[0]['k']['k']` has 2 elements, but `dynamic_shapes[0]['k']['k']` has 3 elements" 2165 ), 2166 ): 2167 export(M(), inputs, dynamic_shapes=dynamic_shapes) 2168 2169 dynamic_shapes = ( 2170 {"k": {"K": [(dim,), (dim,), (dim,)]}}, 2171 ) # ValueError: Node keys mismatch; missing key(s): {'k'}; extra key(s): {'K'}. 2172 with self.assertRaisesRegex( 2173 torch._dynamo.exc.UserError, 2174 re.escape( 2175 "Detected mismatch between the structure of `inputs` and `dynamic_shapes`: " 2176 "`inputs[0]['k']` has keys ['k'], but `dynamic_shapes[0]['k']` has keys ['K']" 2177 ), 2178 ): 2179 export(M(), inputs, dynamic_shapes=dynamic_shapes) 2180 2181 dynamic_shapes = { 2182 "x": {"k": {"k": [(dim,), (AUTO,)]}} 2183 } # mixing AUTO and Dims is not well supported. 2184 with self.assertRaisesRegex( 2185 torch._dynamo.exc.UserError, 2186 re.escape( 2187 "Specifying both `Dim.AUTO` and `Dim` or `DerivedDim` in `dynamic_shapes` is not well supported at the moment, " 2188 "and can easily lead to constraint violation errors or obscure errors in torch.export." 2189 ), 2190 ): 2191 export(M(), inputs, dynamic_shapes=dynamic_shapes) 2192 2193 class N(torch.nn.Module): 2194 def forward(self, x): 2195 return x["k"]["k1"][0] + x["k"]["k2"][0] 2196 2197 inputs = ({"k": {"k1": [torch.rand(4)], "k2": [torch.rand(4)]}},) 2198 dim = torch.export.Dim("dim") 2199 2200 dynamic_shapes = ({"k": {"k2": [(dim,)], "k1": [(dim,)]}},) # ok 2201 export(N(), inputs, dynamic_shapes=dynamic_shapes) 2202 2203 def test_torch_check_eq_commutativity(self): 2204 class M1(torch.nn.Module): 2205 def forward(self, x1, x2, x3, y): 2206 z1 = x1.item() 2207 z2 = x2.item() 2208 z3 = x3.item() 2209 # instead of: torch._check((z2 + z3) == z1) 2210 torch._check(z1 == (z2 + z3)) 2211 if z2 + z3 == z1: 2212 return y * 2 2213 else: 2214 return y + 3 2215 2216 export( 2217 M1(), 2218 (torch.tensor(6), torch.tensor(3), torch.tensor(3), torch.randn(1)), 2219 ) 2220 2221 class M2(torch.nn.Module): 2222 def forward(self, x1, x2, x3, y): 2223 z1 = x1.item() 2224 z2 = x2.item() 2225 z3 = x3.item() 2226 # instead of: torch._check((z2 + z3) != z1) 2227 torch._check(z1 != (z2 + z3)) 2228 if z2 + z3 == z1: 2229 return y * 2 2230 else: 2231 return y + 3 2232 2233 export( 2234 M2(), 2235 (torch.tensor(6), torch.tensor(6), torch.tensor(6), torch.randn(1)), 2236 ) 2237 2238 def test_raise_user_error_when_guard_on_data_dependent_operation(self): 2239 class M(torch.nn.Module): 2240 def forward(self, x): 2241 y = x.nonzero() 2242 z = y.shape[0] 2243 if z > 2: 2244 return x.cos() 2245 else: 2246 return x.sin() 2247 2248 with self.assertRaisesRegex( 2249 ( 2250 torchdynamo.exc.UserError, 2251 torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode, 2252 ), 2253 "Could not guard on data-dependent expression", 2254 ): 2255 _ = export(M(), (torch.tensor([2, 3, 5]),)) 2256 2257 def test_suggested_fixes_for_data_dependent_errors_basic(self): 2258 # suggested fixes for data-dependent errors only work in non-strict mode 2259 strict = False 2260 error_type = torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode 2261 2262 # Just to introduce some indirection: N is a top-level module N that calls 2263 # module M, defined next. 
2264 class N(torch.nn.Module): 2265 def __init__(self) -> None: 2266 super().__init__() 2267 self.m = M() 2268 2269 def forward(self, t): 2270 return self.m(t) + 1 2271 2272 # example input 2273 t = torch.tensor([1, 4, 4], dtype=torch.int32) 2274 2275 # We define a series of versions of M() below. Each version has 2276 # raises a data-dependent error that the next version fixes, by 2277 # copy-pasting a suggested fix in the error message. The fix is 2278 # always a torch.check() on an unresolved condition (or its negation) 2279 # on unbacked symints mentioned in the error message. 2280 # Note that the suggested fixes are in terms of local variables 2281 # near the location of error that "contain" the unbacked symints 2282 # in the unresolved condition (either directly or indirectly, e.g., 2283 # inside a list or inside the shape of a tensor). 2284 2285 class M_v0(torch.nn.Module): 2286 def forward(self, t): 2287 items = [t[i].item() for i in range(t.numel())] 2288 r = torch.randn([items[0], items[1]]) 2289 # Could not guard on data-dependent expression Eq(u2, -1) 2290 return r.view(items[0], items[2]) 2291 2292 M = M_v0 2293 with self.assertRaisesRegex( 2294 error_type, 2295 "The following call raised this error(.*\n)+" 2296 f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+" 2297 "To fix the error, insert one of the following checks before this call.*:\n" 2298 f".*{re.escape('torch._check(items[2] == (-1))')}.*\n" 2299 f".*{re.escape('torch._check(items[2] != (-1))')}(.*\n)+" 2300 f".*{re.escape('(These suggested fixes were derived by replacing `u2` with items[2] in Eq(u2, -1) and its negation.)')}", 2301 ): 2302 export(N(), (t,), strict=strict) 2303 2304 class M_v1(torch.nn.Module): 2305 def forward(self, t): 2306 items = [t[i].item() for i in range(t.numel())] 2307 r = torch.randn([items[0], items[1]]) 2308 # Could not guard on data-dependent expression Eq(u2, -1) 2309 torch._check(items[2] != -1) 2310 # Could not guard on data-dependent expression u2 >= 0 2311 return r.view(items[0], items[2]) 2312 2313 M = M_v1 2314 with self.assertRaisesRegex( 2315 error_type, 2316 "The following call raised this error(.*\n)+" 2317 f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+" 2318 "To fix the error, insert one of the following checks before this call.*:\n" 2319 f".*{re.escape('torch._check(items[2] >= 0)')}.*\n" 2320 f".*{re.escape('torch._check(items[2] < 0)')}(.*\n)+" 2321 f".*{re.escape('(These suggested fixes were derived by replacing `u2` with items[2] in u2 >= 0 and its negation.)')}", 2322 ): 2323 export(N(), (t,), strict=strict) 2324 2325 class M_v2(torch.nn.Module): 2326 def forward(self, t): 2327 items = [t[i].item() for i in range(t.numel())] 2328 r = torch.randn([items[0], items[1]]) 2329 # Could not guard on data-dependent expression Eq(u2, -1) 2330 torch._check(items[2] != -1) 2331 # Could not guard on data-dependent expression u2 >= 0 2332 torch._check(items[2] >= 0) 2333 # Could not guard on data-dependent expression Eq(u1, u2) 2334 return r.view(items[0], items[2]) 2335 2336 M = M_v2 2337 with self.assertRaisesRegex( 2338 error_type, 2339 "The following call raised this error(.*\n)+" 2340 f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+" 2341 "To fix the error, insert one of the following checks before this call.*:\n" 2342 f".*{re.escape('torch._check(items[2] == items[1])')}.*\n" 2343 f".*{re.escape('torch._check(items[2] != items[1])')}(.*\n)+" 2344 f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1] or r.shape[1], 
`u2` with items[2] in Eq(u2, u1) and its negation.)')}", 2345 ): 2346 export(N(), (t,), strict=strict) 2347 2348 class M_v3(torch.nn.Module): 2349 def forward(self, t): 2350 items = [t[i].item() for i in range(t.numel())] 2351 r = torch.randn([items[0], items[1]]) 2352 # Could not guard on data-dependent expression Eq(u2, -1) 2353 torch._check(items[2] != -1) 2354 # Could not guard on data-dependent expression u2 >= 0 2355 torch._check(items[2] >= 0) 2356 # Could not guard on data-dependent expression Eq(u1, u2) 2357 torch._check(items[2] == r.shape[1]) 2358 return r.view(items[0], items[2]) 2359 2360 M = M_v3 2361 export(N(), (t,), strict=strict) 2362 2363 @testing.expectedFailureSerDer # T195866111 2364 def test_suggested_fixes_for_data_dependent_errors_puzzlers(self): 2365 # suggested fixes for data-dependent errors only work in non-strict mode 2366 strict = False 2367 error_type = torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode 2368 2369 def retry_export(m, inp, fixes): 2370 # API that applies a series of fixes, retrying export after applying each fix, 2371 # and asserting the applied fix was suggested in the previous try. 2372 # Using this API avoids the need to define multiple versions of the same test 2373 # module, as in `test_suggested_fixes_for_data_dependent_errors_basic` above. 2374 def code(snippets): 2375 return f"[{', '.join(snippets)}]" 2376 2377 for i in range(len(fixes)): 2378 with self.assertRaisesRegex(error_type, re.escape(fixes[i])): 2379 export(m, (*inp, code(fixes[:i])), strict=strict) 2380 export(m, (*inp, code(fixes)), strict=strict) 2381 2382 # The following examples are lifted from @ezyang's "Data-dependent shape puzzlers" 2383 # notebook at https://www.internalfb.com/intern/anp/view/?id=5330476 2384 2385 # These test modules are written in a way that works well with retry_export above. 2386 # Specifically, they take an extra `fixes` argument and `eval` it at the location 2387 # that is expected to raise errors. 
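# A minimal standalone sketch (not part of the test suite) of the pattern the
# puzzlers below exercise: when a guard on an unbacked symint (here, the value
# of idx.item()) cannot be decided, inserting torch._check on a local that
# carries that symint lets non-strict export proceed. The module and sizes are
# illustrative assumptions.
import torch
from torch.export import export


class _NarrowAt(torch.nn.Module):
    def forward(self, idx, y):
        i = idx.item()
        torch._check(i >= 0)         # resolves "u0 < 0"-style guards
        torch._check(i < y.size(0))  # keeps the narrow in bounds
        return y.narrow(0, i, 1).squeeze()


_ep = export(_NarrowAt(), (torch.tensor(2), torch.randn(10)), strict=False)
_ = _ep.module()(torch.tensor(3), torch.randn(10))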
2388 2389 class cf_implicitsize(torch.nn.Module): 2390 def forward(self, x, y, fixes): 2391 i = x.item() 2392 eval(fixes) 2393 # instead of y[i] 2394 return y.narrow(0, i, 1).squeeze() 2395 2396 retry_export( 2397 cf_implicitsize(), 2398 (torch.tensor(2), torch.randn(10)), 2399 fixes=[ 2400 # Could not guard on data-dependent expression u0 < 0 2401 "torch._check(i >= 0)", 2402 ], 2403 ) 2404 2405 class cf_nomemo(torch.nn.Module): 2406 def forward(self, x, y, fixes): 2407 i = y[0].item() 2408 eval(fixes) 2409 return x.unsqueeze(1).expand(-1, i) 2410 2411 retry_export( 2412 cf_nomemo(), 2413 (torch.randn(8), torch.tensor([2])), 2414 fixes=[ 2415 # Could not guard on data-dependent expression Eq(u0, 1) 2416 "torch._check(i != 1)", 2417 # Could not guard on data-dependent expression Ne(u0, -1) 2418 "torch._check(i != (-1))", 2419 ], 2420 ) 2421 2422 class cf_changevar(torch.nn.Module): 2423 def forward(self, x, fixes): 2424 i = x.item() 2425 eval(fixes) 2426 r = torch.arange(i // 2) 2427 return r + r 2428 2429 retry_export( 2430 cf_changevar(), 2431 (torch.tensor(20),), 2432 fixes=[ 2433 # Could not guard on data-dependent expression Eq((u0//2), 0) 2434 "torch._check(((i//2)) != 0)", 2435 # Could not guard on data-dependent expression Eq((u0//2), 1) 2436 "torch._check(((i//2)) != 1)", 2437 ], 2438 ) 2439 2440 class cf_stacklist(torch.nn.Module): 2441 def forward(self, xs, y, fixes): 2442 i = y.item() 2443 eval(fixes) 2444 # instead of xs[i] 2445 return torch.stack(xs, 0).narrow(0, i, 1).squeeze() 2446 2447 retry_export( 2448 cf_stacklist(), 2449 ([torch.ones(5) * i for i in range(10)], torch.tensor(2)), 2450 fixes=[ 2451 # Could not guard on data-dependent expression u0 < 0 2452 "torch._check(i >= 0)", 2453 ], 2454 ) 2455 2456 class cf_tensorsplit(torch.nn.Module): 2457 def forward(self, x, offsets_t, fixes): 2458 lengths = torch.diff(offsets_t).tolist() 2459 rs = [] 2460 start = 0 2461 for length in lengths: 2462 eval(fixes) 2463 rs.append(x.narrow(0, start, length)) 2464 start += length 2465 return rs 2466 2467 retry_export( 2468 cf_tensorsplit(), 2469 (torch.arange(10), torch.tensor([0, 2, 5, 7, 10])), 2470 fixes=[], # nothing to fix! 
2471 ) 2472 2473 def test_no_suggested_fixes_for_data_dependent_errors(self): 2474 # suggested fixes for data-dependent errors only work in non-strict mode 2475 strict = False 2476 error_type = torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode 2477 2478 class cf_stacklist(torch.nn.Module): 2479 def forward(self, xs, y): 2480 # y.item() is not a local, so we can't suggest a fix 2481 return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze() 2482 2483 with self.assertRaisesRegex( 2484 error_type, 2485 "Could not guard on data-dependent expression u0 < 0", 2486 ): 2487 export( 2488 cf_stacklist(), 2489 ([torch.ones(5) * i for i in range(10)], torch.tensor(2)), 2490 strict=strict, 2491 ) 2492 2493 def test_tolist(self): 2494 class M(torch.nn.Module): 2495 def forward(self, x): 2496 return x.tolist() 2497 2498 ep = export(M(), (torch.ones(3, dtype=torch.int),)) 2499 self.assertEqual(ep.module()(torch.tensor([1, 2, 3])), [1, 2, 3]) 2500 2501 def test_if_functional(self): 2502 class Module(torch.nn.Module): 2503 def forward(self, x): 2504 z = x + 4 2505 z.add_(4) 2506 y = z.view(x.shape) 2507 return x.cos() + y.cos() 2508 2509 foo = Module() 2510 gm = export(foo, (torch.tensor([2, 3, 5]),)) 2511 2512 view_count = 0 2513 for node in gm.graph.nodes: 2514 if node.op == "call_function" and node.target == torch.ops.aten.add_.Tensor: 2515 # No more inplace mutation 2516 self.assertNotEqual( 2517 node.target, 2518 torch.ops.aten.add_.Tensor, 2519 "There shouldn't be any inplace mutation node in the graph.", 2520 ) 2521 if ( 2522 node.op == "call_function" 2523 and node.target == torch.ops.aten.view.default 2524 ): 2525 view_count += 1 2526 2527 # There should be nonzero view nodes in the graph 2528 self.assertTrue(view_count > 0) 2529 2530 def test_solver_unsupported_sympy_function(self): 2531 # repro of https://github.com/pytorch/pytorch/issues/131897 2532 2533 class MyModule(torch.nn.Module): 2534 def __init__(self): 2535 super().__init__() 2536 2537 def forward(self, x, y): 2538 x = torch.nn.functional.interpolate( 2539 x, scale_factor=0.5, mode="bilinear" 2540 ) 2541 x = torch.nn.functional.interpolate( 2542 x, scale_factor=2.0, mode="bilinear" 2543 ) 2544 x = x + y 2545 return x 2546 2547 model = MyModule().eval() 2548 2549 inputs = ( 2550 torch.rand((1, 1, 32, 32)), 2551 torch.rand((1, 1, 32, 32)), 2552 ) 2553 2554 dim = torch.export.Dim("Dim", min=16, max=64) 2555 dynamic_shapes = {"x": {2: dim, 3: dim}, "y": {2: dim, 3: dim}} 2556 2557 exported_program = export(model, inputs, dynamic_shapes=dynamic_shapes) 2558 self.assertEqual(exported_program.module()(*inputs), model(*inputs)) 2559 2560 def test_export_mod_constraints(self): 2561 class BasicDynamiShapeModel(torch.nn.Module): 2562 def forward(self, x: torch.Tensor) -> torch.Tensor: 2563 return x.view(x.shape[0] - 1, -1) 2564 2565 m = BasicDynamiShapeModel() 2566 a = torch.randn(3, 4) 2567 dim0_x = torch.export.Dim("dim0_x", min=3) 2568 dim1_x = torch.export.Dim("dim1_x", max=8000) 2569 dynamic_shapes = {"x": (dim0_x, dim1_x)} 2570 em = torch.export._trace._export( 2571 m, 2572 (a,), 2573 dynamic_shapes=dynamic_shapes, 2574 allow_complex_guards_as_runtime_asserts=True, 2575 ) 2576 em.module()(torch.randn(4, 3)) 2577 with self.assertRaisesRegex( 2578 RuntimeError, 2579 r"Runtime assertion failed for expression Eq\(Mod\(s0\*s1, s0 \- 1\), 0\)", 2580 ): 2581 em.module()(torch.randn(4, 5)) 2582 2583 dim0_x = None 2584 dim1_x = 2 * torch.export.Dim("_dim1_x", max=4000) 2585 dynamic_shapes = {"x": (dim0_x, dim1_x)} 2586 em = 
torch.export.export(m, (a,), dynamic_shapes=dynamic_shapes) 2587 x = torch.randn(3, 5) 2588 with self.assertRaisesRegex( 2589 RuntimeError, 2590 "Expected.*shape\\[1\\] = 5 to be of the form 2\\*s1, where s1 is an integer", 2591 ): 2592 em.module()(x) 2593 2594 def test_mark_and_auto_dynamic(self): 2595 # for this use case, mark_dynamic() and AUTO should have same effect. 2596 # check that same symbol gets allocated to both dims without raising constraint violation. 2597 AUTO, STATIC = Dim.AUTO, Dim.STATIC 2598 2599 class Foo(torch.nn.Module): 2600 def forward(self, x, y): 2601 torch._check(x.shape[0] == y.shape[0]) 2602 torch._check(x.shape[0] <= 64) 2603 return x + 2, y + 2 2604 2605 inputs = (torch.randn(4, 4), torch.randn(4, 4)) 2606 ep_auto = torch.export.export( 2607 Foo(), inputs, dynamic_shapes={"x": (AUTO, None), "y": (AUTO, None)} 2608 ) 2609 torch._dynamo.mark_dynamic(inputs[0], 0) 2610 torch._dynamo.mark_dynamic(inputs[1], 0) 2611 ep_dynamic = torch.export.export(Foo(), inputs) 2612 2613 # test both programs have same effect 2614 for ep in [ep_auto, ep_dynamic]: 2615 gm = ep.module() 2616 gm(torch.randn(32, 4), torch.randn(32, 4)) 2617 gm(torch.randn(1, 4), torch.randn(1, 4)) 2618 with self.assertRaises(RuntimeError): 2619 gm(torch.randn(33, 4), torch.randn(32, 4)) 2620 gm(torch.randn(128, 4), torch.randn(128, 4)) 2621 2622 def test_dont_duck_size_for_auto_dynamic(self): 2623 # for this use case, mark_dynamic() and AUTO should have same effect. 2624 # check that same symbol gets allocated to both dims without raising constraint violation. 2625 AUTO, STATIC = Dim.AUTO, Dim.STATIC 2626 2627 class Foo(torch.nn.Module): 2628 def forward(self, x, y): 2629 # x: [s0, s1], y: [s0 + 1, 4] 2630 assert y.shape[1] == 4 2631 assert x.shape[0] == y.shape[0] - 1 2632 return x * 2, y * 2 2633 2634 # duck sizing would make all static based on these sample inputs 2635 inputs = (torch.randn(4, 4), torch.randn(5, 4)) 2636 shapes = { 2637 "x": (AUTO, AUTO), 2638 "y": (AUTO, AUTO), 2639 } 2640 ep = export(Foo(), inputs, dynamic_shapes=shapes) 2641 ep.module()(torch.randn(6, 3), torch.randn(7, 4)) 2642 2643 @testing.expectedFailureRetraceability # T183144629 2644 def test_map(self): 2645 class Module(torch.nn.Module): 2646 def forward(self, xs, y, z): 2647 def body(x, y, z): 2648 return x + y + z 2649 2650 return map(body, xs, y, z) 2651 2652 list_tensor_map = Module() 2653 inps = (torch.ones(6, 4), torch.tensor(5), torch.tensor(4)) 2654 self._test_export_same_as_eager(list_tensor_map, inps) 2655 2656 @unittest.expectedFailure 2657 def test_crop_like(self): 2658 # https://fb.workplace.com/groups/1405155842844877/posts/8195050017188725/ 2659 2660 # Minimal crop code copied from https://github.com/pytorch/vision/blob/main/torchvision/transforms/v2/functional 2661 class CropLike(torch.nn.Module): 2662 def forward(self, image, crop_height, crop_width): 2663 c, image_height, image_width = image.shape 2664 crop_top = int(round((image_height - crop_height) / 2.0)) 2665 crop_left = int(round((image_width - crop_width) / 2.0)) 2666 return image[ 2667 ..., 2668 crop_top : crop_top + crop_height, 2669 crop_left : crop_left + crop_width, 2670 ] 2671 2672 crop = CropLike() 2673 imagew = Dim("width") 2674 imageh = Dim("height") 2675 dynamic_dims = { 2676 "image": {0: None, 1: imageh, 2: imagew}, 2677 "crop_height": None, 2678 "crop_width": None, 2679 } 2680 args = (torch.rand(3, 512, 512), 150, 150) 2681 ecrop = export(crop, args=args, dynamic_shapes=dynamic_dims) 2682 2683 args = (torch.rand(3, 700, 700), 150, 
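# A minimal standalone sketch (not part of the test suite) of the two
# equivalent ways shown in test_mark_and_auto_dynamic to request a dynamic
# batch dimension: Dim.AUTO in dynamic_shapes, or torch._dynamo.mark_dynamic
# on the sample inputs. The module and sizes are illustrative assumptions.
import torch
from torch.export import Dim, export


class _AddTwo(torch.nn.Module):
    def forward(self, x, y):
        return x + 2, y + 2


_inputs = (torch.randn(4, 4), torch.randn(4, 4))

_ep_auto = export(
    _AddTwo(), _inputs, dynamic_shapes={"x": (Dim.AUTO, None), "y": (Dim.AUTO, None)}
)

torch._dynamo.mark_dynamic(_inputs[0], 0)
torch._dynamo.mark_dynamic(_inputs[1], 0)
_ep_marked = export(_AddTwo(), _inputs)

for _ep in (_ep_auto, _ep_marked):
    _ = _ep.module()(torch.randn(8, 4), torch.randn(8, 4))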
150) 2684 self.assertEqual(ecrop.module()(*args), ecrop(*args)) 2685 2686 def test_export_func_with_kwargs(self): 2687 class Module(torch.nn.Module): 2688 def forward(self, arg1, arg2, kw1, kw2): 2689 return arg1 + arg2, kw1 + kw2 2690 2691 kw_func = Module() 2692 args = (torch.ones(6, 4), torch.ones(1, 1)) 2693 kwargs = {"kw1": torch.ones(1, 1), "kw2": torch.ones(6, 4)} 2694 self._test_export_same_as_eager(kw_func, args, kwargs) 2695 2696 def test_export_func_with_pytree_kwargs(self): 2697 class Module(torch.nn.Module): 2698 def forward(self, arg1, arg2, a, b): 2699 return arg1 + a["kw1"] + b[0], arg2 + a["kw2"] + b[1] 2700 2701 kw_func = Module() 2702 args = (torch.ones(2, 3), torch.ones(3, 4)) 2703 kwargs = { 2704 "a": {"kw1": torch.ones(2, 3), "kw2": torch.ones(3, 4)}, 2705 "b": [torch.ones(2, 3), torch.ones(3, 4)], 2706 } 2707 self._test_export_same_as_eager(kw_func, args, kwargs) 2708 2709 def test_export_func_with_default_kwargs(self): 2710 class Module(torch.nn.Module): 2711 def forward(self, arg1, arg2, a, b=1): 2712 return arg1 + arg2, a["kw1"] + a["kw2"] + b 2713 2714 kw_func = Module() 2715 2716 class Module2(torch.nn.Module): 2717 def forward(self, arg1, arg2, a=1, b=2): 2718 return arg1 + a, arg2 + b 2719 2720 kw_func2 = Module2() 2721 2722 args = (torch.ones(6, 4), torch.ones(1, 1)) 2723 kwargs1 = {"a": {"kw1": torch.ones(1, 1), "kw2": torch.ones(6, 4)}} 2724 kwargs2 = {"a": {"kw1": torch.ones(1, 1), "kw2": torch.ones(6, 4)}, "b": 2} 2725 self._test_export_same_as_eager(kw_func, args, kwargs1) 2726 self._test_export_same_as_eager(kw_func, args, kwargs2) 2727 kwargs3 = {"b": 1} 2728 self._test_export_same_as_eager(kw_func2, args, kwargs3) 2729 2730 def test_export_func_with_var_postional_args(self): 2731 class Module(torch.nn.Module): 2732 def forward(self, arg1, arg2, *args): 2733 return arg1 + args[0], arg2 + args[1] 2734 2735 kw_func = Module() 2736 args = (torch.ones(2, 3), torch.ones(3, 4), torch.ones(2, 3), torch.ones(3, 4)) 2737 self._test_export_same_as_eager(kw_func, args) 2738 2739 def test_export_func_with_keyword_only_args(self): 2740 class Module(torch.nn.Module): 2741 def forward(self, arg1, arg2, *args, kw1, kw2): 2742 return arg1 + args[0] + kw1, arg2 + args[1] + kw2 2743 2744 kw_func = Module() 2745 args = (torch.ones(2, 3), torch.ones(3, 4), torch.ones(2, 3), torch.ones(3, 4)) 2746 kwargs = {"kw1": torch.ones(2, 3), "kw2": torch.ones(3, 4)} 2747 self._test_export_same_as_eager(kw_func, args, kwargs) 2748 2749 def test_export_func_with_var_keyword_args(self): 2750 class Module(torch.nn.Module): 2751 def forward(self, arg1, arg2, *args, kw1, kw2, **kwargs): 2752 return ( 2753 arg1 + args[0] + kw1 + kwargs["kw3"], 2754 arg2 + args[1] + kw2 + kwargs["kw4"], 2755 ) 2756 2757 kw_func = Module() 2758 args = (torch.ones(2, 3), torch.ones(3, 4), torch.ones(2, 3), torch.ones(3, 4)) 2759 kwargs = { 2760 "kw1": torch.ones(2, 3), 2761 "kw2": torch.ones(3, 4), 2762 "kw3": torch.ones(2, 3), 2763 "kw4": torch.ones(3, 4), 2764 } 2765 self._test_export_same_as_eager(kw_func, args, kwargs) 2766 2767 def test_unbacked_slice(self): 2768 class M(torch.nn.Module): 2769 def forward(self, scores, score_thr, topk: torch.Tensor, results=None): 2770 valid_mask = scores > score_thr 2771 scores = scores[valid_mask] 2772 valid_idxs = torch.nonzero(valid_mask).to(scores.device) 2773 2774 num_topk = torch.minimum(topk, torch.tensor(valid_idxs.shape[0])).item() 2775 torch._check_is_size(num_topk) 2776 torch._check(scores.shape[0] >= num_topk) 2777 scores, idxs = scores.sort(descending=True) 
2778 scores = scores[:num_topk] 2779 topk_idxs = valid_idxs[idxs[:num_topk]] 2780 keep_idxs, labels = topk_idxs.unbind(dim=1) 2781 2782 return scores, labels, keep_idxs 2783 2784 score = torch.tensor( 2785 [[0.1, 0.3, 0.2], [0.12, 0.7, 0.9], [0.02, 0.8, 0.08], [0.4, 0.1, 0.08]] 2786 ) 2787 bbox_pred = torch.tensor([[0.2, 0.3], [0.4, 0.7], [0.1, 0.1], [0.5, 0.1]]) 2788 score_thr = 0.15 2789 nms_pre = torch.tensor(4) 2790 inputs = (score, score_thr, nms_pre, dict(bbox_pred=bbox_pred)) 2791 2792 ep = torch.export.export(M(), inputs) 2793 orig_res = M()(*inputs) 2794 ep_res = ep.module()(*inputs) 2795 self.assertTrue(torch.allclose(orig_res[0], ep_res[0])) 2796 self.assertTrue(torch.allclose(orig_res[1], ep_res[1])) 2797 self.assertTrue(torch.allclose(orig_res[2], ep_res[2])) 2798 2799 def test_unflatten_asserts(self): 2800 # TODO: strict-export fails 2801 class M1(torch.nn.Module): 2802 def forward(self, x, y): 2803 b = x.item() 2804 2805 torch._check_is_size(b) 2806 torch._check(b < y.size(0)) 2807 return y[:b] 2808 2809 class M3(torch.nn.Module): 2810 def forward(self, x, y): 2811 b = x.item() 2812 2813 torch._check_is_size(b) 2814 torch._check(b < y.size(0) * 2) 2815 return y[:b] 2816 2817 class M2(torch.nn.Module): 2818 def __init__(self) -> None: 2819 super().__init__() 2820 self.m1 = M1() 2821 self.m3 = M3() 2822 2823 def forward(self, x, y): 2824 return self.m1(x, y) + self.m3(x, y) 2825 2826 inputs = (torch.tensor(3), torch.randn(10)) 2827 2828 ep = torch.export.export( 2829 M2(), inputs, dynamic_shapes={"x": None, "y": (Dim("moo"),)}, strict=False 2830 ) 2831 orig_res = M2()(*inputs) 2832 ep_res = ep.module()(*inputs) 2833 self.assertTrue(torch.allclose(orig_res[0], ep_res[0])) 2834 self.assertTrue(torch.allclose(orig_res[1], ep_res[1])) 2835 self.assertTrue(torch.allclose(orig_res[2], ep_res[2])) 2836 2837 unflattened = torch.export.unflatten(ep) 2838 ep_res = unflattened(*inputs) 2839 self.assertTrue(torch.allclose(orig_res[0], ep_res[0])) 2840 self.assertTrue(torch.allclose(orig_res[1], ep_res[1])) 2841 self.assertTrue(torch.allclose(orig_res[2], ep_res[2])) 2842 2843 def test_export_func_with_var_keyword_pytree_args(self): 2844 class Module(torch.nn.Module): 2845 def forward(self, arg1, arg2, *args, kw1, kw2, **kwargs): 2846 return ( 2847 arg1 + arg2[0][0] + args[0] + kw1[0] + kwargs["kw3"][0], 2848 arg2[1] + args[1] + kw2 + kwargs["kw4"], 2849 ) 2850 2851 kw_func = Module() 2852 args = ( 2853 torch.ones(2, 3), 2854 [(torch.ones(2, 3),), torch.ones(3, 4)], 2855 torch.ones(2, 3), 2856 torch.ones(3, 4), 2857 ) 2858 kwargs = { 2859 "kw1": (torch.ones(2, 3),), 2860 "kw2": torch.ones(3, 4), 2861 "kw3": (torch.ones(2, 3), torch.ones(3, 4)), 2862 "kw4": torch.ones(3, 4), 2863 } 2864 self._test_export_same_as_eager(kw_func, args, kwargs) 2865 2866 @testing.expectedFailureSerDer # we don't save placeholder metadata 2867 @testing.expectedFailureNonStrict 2868 @testing.expectedFailureTrainingIRToRunDecompNonStrict # source_fn_stack failure 2869 def test_linear_conv(self): 2870 class MyLinear(torch.nn.Module): 2871 def __init__(self) -> None: 2872 super().__init__() 2873 self.weight = torch.randn(20, 98) 2874 self.bias = torch.randn(20) 2875 2876 def forward(self, x): 2877 return torch.nn.functional.linear(x, self.weight, self.bias) 2878 2879 class Foo(torch.nn.Module): 2880 def __init__(self) -> None: 2881 super().__init__() 2882 self.conv = torch.nn.Conv2d(16, 33, 3) 2883 self.linear = MyLinear() 2884 2885 def forward(self, x): 2886 x_conv = self.conv(x) 2887 x_linear = self.linear(x_conv) 
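# A minimal standalone sketch (not part of the test suite) of
# torch.export.unflatten as used in test_unflatten_asserts: the flat exported
# graph is rebuilt into a module hierarchy that mirrors the original
# submodules, and the unflattened module is directly callable. The modules
# here are illustrative assumptions.
import torch
from torch.export import export, unflatten


class _Inner(torch.nn.Module):
    def forward(self, x):
        return x + 1


class _Outer(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.inner = _Inner()

    def forward(self, x):
        return self.inner(x) * 2


_x = torch.randn(3)
_ep = export(_Outer(), (_x,))
_unflat = unflatten(_ep)
assert torch.allclose(_unflat(_x), _Outer()(_x))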
2888 return x_linear.cos() 2889 2890 ep = export(Foo(), (torch.randn(20, 16, 50, 100),)) 2891 for node in ep.graph.nodes: 2892 if ( 2893 node.op == "placeholder" 2894 and node.name in ep.graph_signature.inputs_to_buffers 2895 or node.name in ep.graph_signature.inputs_to_parameters 2896 ): 2897 self.assertTrue("source_fn_stack" in node.meta) 2898 2899 def test_export_api_with_dynamic_shapes(self): 2900 from torch.export import Dim, dims, export 2901 2902 # pass dynamic shapes of inputs [args] 2903 class Foo(torch.nn.Module): 2904 def forward(self, x, y): 2905 return torch.matmul(x, y) 2906 2907 foo = Foo() 2908 inputs = (torch.randn(10, 2, 3), torch.randn(10, 3, 4)) 2909 batch = Dim("batch") 2910 efoo = export( 2911 foo, 2912 inputs, 2913 dynamic_shapes={k: {0: batch} for k in ["x", "y"]}, 2914 ) 2915 self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape) 2916 2917 foo = Foo() 2918 inputs = (torch.randn(10, 2, 3),) 2919 kwinputs = {"y": torch.randn(10, 3, 4)} 2920 batch = Dim("batch") 2921 efoo = export( 2922 foo, inputs, kwinputs, dynamic_shapes={k: {0: batch} for k in ["x", "y"]} 2923 ) 2924 self.assertEqual( 2925 efoo.module()(*inputs, **kwinputs).shape, foo(*inputs, **kwinputs).shape 2926 ) 2927 2928 # pass dynamic shapes of inputs [partial, error] 2929 foo = Foo() 2930 inputs = (torch.randn(10, 2, 3),) 2931 kwinputs = {"y": torch.randn(10, 3, 4)} 2932 batch = Dim("batch") 2933 with self.assertRaisesRegex( 2934 torch._dynamo.exc.UserError, 2935 ( 2936 "Constraints violated \\(batch\\)!(.*\n)*.*" 2937 "batch was inferred to be a constant(.*\n)*.*" 2938 "Suggested fixes:(.*\n)*.*" 2939 "batch = 10" 2940 ), 2941 ): 2942 export( 2943 foo, 2944 inputs, 2945 kwinputs, 2946 dynamic_shapes={"x": {0: batch}, "y": None}, 2947 ) 2948 2949 # pass dynamic shapes of inputs [module] 2950 foo = Foo() 2951 inputs = (torch.randn(10, 2, 3), torch.randn(10, 3, 4)) 2952 batch = Dim("batch") 2953 efoo = export( 2954 foo, 2955 inputs, 2956 dynamic_shapes={"x": {0: batch}, "y": {0: batch}}, 2957 ) 2958 self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape) 2959 2960 # pass dynamic shapes of inputs [bounds, mostly shared] 2961 foo = Foo() 2962 inputs = (torch.randn(10, 3, 3), torch.randn(10, 3, 3)) 2963 batch = Dim("batch", min=8, max=64) 2964 size = Dim("size") 2965 efoo = export( 2966 foo, 2967 inputs, 2968 dynamic_shapes={ 2969 "x": (batch, size, size), 2970 "y": (batch, size, size), 2971 }, 2972 ) 2973 self.assertEqual( 2974 [ 2975 str(node.meta["val"].shape) 2976 for node in efoo.graph_module.graph.nodes 2977 if node.op == "placeholder" 2978 ], 2979 ["torch.Size([s0, s1, s1])", "torch.Size([s0, s1, s1])"], 2980 ) 2981 self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape) 2982 2983 # pass dynamic shapes of inputs [multiple, mostly distinct] 2984 inputs = (torch.randn(10, 2, 3), torch.randn(10, 3, 4)) 2985 batch, M, K, N = dims("batch", "M", "K", "N") 2986 efoo = export( 2987 Foo(), 2988 inputs, 2989 dynamic_shapes={"x": (batch, M, K), "y": (batch, K, N)}, 2990 ) 2991 self.assertEqual( 2992 [ 2993 str(node.meta["val"].shape) 2994 for node in efoo.graph_module.graph.nodes 2995 if node.op == "placeholder" 2996 ], 2997 ["torch.Size([s0, s1, s2])", "torch.Size([s0, s2, s5])"], 2998 ) 2999 self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape) 3000 3001 # pass dynamic shapes of inputs [dict] 3002 class Foo(torch.nn.Module): 3003 def forward(self, inputs): 3004 return torch.matmul(inputs["x"], inputs["y"]) 3005 3006 foo = Foo() 3007 inputs = ({"x": torch.randn(10, 
2, 3), "y": torch.randn(10, 3, 4)},) 3008 batch = Dim("batch") 3009 efoo = export( 3010 foo, inputs, dynamic_shapes={"inputs": {k: {0: batch} for k in ["x", "y"]}} 3011 ) 3012 self.assertEqual( 3013 [ 3014 str(node.meta["val"].shape) 3015 for node in efoo.graph_module.graph.nodes 3016 if node.op == "placeholder" 3017 ], 3018 ["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"], 3019 ) 3020 self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape) 3021 3022 # pass dynamic shapes of inputs [list] 3023 class Foo(torch.nn.Module): 3024 def forward(self, inputs): 3025 return torch.matmul(inputs[0], inputs[1]) 3026 3027 foo = Foo() 3028 inputs = ([torch.randn(10, 2, 3), torch.randn(10, 3, 4)],) 3029 batch = Dim("batch") 3030 efoo = export( 3031 foo, inputs, dynamic_shapes={"inputs": [{0: batch} for _ in range(2)]} 3032 ) 3033 self.assertEqual( 3034 [ 3035 str(node.meta["val"].shape) 3036 for node in efoo.graph_module.graph.nodes 3037 if node.op == "placeholder" 3038 ], 3039 ["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"], 3040 ) 3041 self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape) 3042 3043 # pass dynamic shapes of inputs [dataclass] 3044 3045 # TODO(avik): This part of the test should have failed both serde and retracing 3046 # but these failures are hidden because of the local import of `export` in this test. 3047 # The serde failure is benign, and easily avoided by moving the dataclass definition 3048 # to the top-level. OTOH the retracing failure needs further investigation. 3049 @dataclass 3050 class DataClass: 3051 a: Tensor 3052 b: Tensor 3053 3054 register_dataclass_as_pytree_node( 3055 DataClass, 3056 serialized_type_name="test_export_api_with_dynamic_shapes.DataClass", 3057 ) 3058 3059 class Foo(torch.nn.Module): 3060 def forward(self, inputs): 3061 return torch.matmul(inputs.a, inputs.b) 3062 3063 foo = Foo() 3064 inputs = (DataClass(a=torch.randn(10, 2, 3), b=torch.randn(10, 3, 4)),) 3065 batch = Dim("batch") 3066 efoo = export( 3067 foo, 3068 inputs, 3069 dynamic_shapes={"inputs": [{0: batch}, {0: batch}]}, 3070 ) 3071 self.assertEqual( 3072 [ 3073 str(node.meta["val"].shape) 3074 for node in efoo.graph_module.graph.nodes 3075 if node.op == "placeholder" 3076 ], 3077 ["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"], 3078 ) 3079 3080 # pass dynamic shapes of inputs [pytree-registered classes] 3081 if HAS_TORCHREC: 3082 # skipping tests if torchrec not available 3083 class Foo(torch.nn.Module): 3084 def forward(self, kjt) -> torch.Tensor: 3085 return kjt.values() + 0, kjt.offsets() + 0 3086 3087 foo = Foo() 3088 kjt = KeyedJaggedTensor( 3089 values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), 3090 keys=["index_0", "index_1"], 3091 lengths=torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3]), 3092 offsets=torch.IntTensor([0, 0, 2, 2, 3, 4, 5, 5, 8]), 3093 ) 3094 inputs = (kjt,) 3095 dim = Dim("dim") 3096 dim_plus_one = Dim("dim_plus_one") 3097 efoo = torch.export.export( 3098 foo, 3099 inputs, 3100 dynamic_shapes={"kjt": [{0: dim}, None, {0: dim}, {0: dim_plus_one}]}, 3101 ) 3102 self.assertEqual( 3103 [out.shape for out in efoo.module()(*inputs)], 3104 [out.shape for out in foo(*inputs)], 3105 ) 3106 3107 # pass dynamic shapes of inputs [distinct, error] 3108 class Foo(torch.nn.Module): 3109 def forward(self, x, y): 3110 return torch.matmul(x, y) 3111 3112 foo = Foo() 3113 inputs = (torch.randn(10, 2, 3), torch.randn(10, 3, 4)) 3114 batch, M, K1, K2, N = dims("batch", "M", "K1", "K2", "N") 3115 with self.assertRaisesRegex( 3116 
torch._dynamo.exc.UserError, 3117 ( 3118 "Constraints violated \\(K2\\)!(.*\n)*.*" 3119 "K2.*and.*K1.*must always be equal(.*\n)*.*" 3120 "Suggested fixes:(.*\n)*.*" 3121 "K2 = K1" 3122 ), 3123 ): 3124 export( 3125 foo, 3126 inputs, 3127 dynamic_shapes={"x": (batch, M, K1), "y": (batch, K2, N)}, 3128 ) 3129 3130 # pass dynamic shapes of inputs [specialized, error] 3131 foo = Foo() 3132 inputs = (torch.randn(10, 2, 3), torch.randn(10, 3, 4)) 3133 batch, M, K1, N = dims("batch", "M", "K1", "N") 3134 with self.assertRaisesRegex( 3135 torch._dynamo.exc.UserError, 3136 ( 3137 "Constraints violated \\(K1\\)!(.*\n)*.*" 3138 "K1 was inferred to be a constant(.*\n)*.*" 3139 "Suggested fixes:(.*\n)*.*" 3140 "K1 = 3" 3141 ), 3142 ): 3143 export( 3144 foo, 3145 inputs, 3146 dynamic_shapes={"x": (batch, M, K1), "y": (batch, None, N)}, 3147 ) 3148 3149 # pass dynamic shapes of inputs [guards, error] 3150 class Foo(torch.nn.Module): 3151 def forward(self, x, y): 3152 if x.shape[0] < 16 and y.shape[1] % 3 == 0: 3153 return torch.matmul(x, y) 3154 else: 3155 return x + y 3156 3157 foo = Foo() 3158 inputs = (torch.randn(10, 2, 3), torch.randn(10, 3, 4)) 3159 batch, M, K, N = dims("batch", "M", "K", "N") 3160 with self.assertRaisesRegex( 3161 torch._dynamo.exc.UserError, 3162 ( 3163 "Constraints violated.*!(.*\n)*.*" 3164 "Not all values of K.*satisfy the generated guard(.*\n)*.*" 3165 "Not all values of batch.*satisfy the generated guard(.*\n)*.*" 3166 "Suggested fixes:(.*\n)*.*" 3167 "batch = Dim\\('batch', max=15\\)(.*\n)*.*" 3168 "K = 3\\*_K" 3169 ), 3170 ): 3171 export( 3172 foo, 3173 inputs, 3174 dynamic_shapes={"x": (batch, M, K), "y": (batch, K, N)}, 3175 ) 3176 3177 def test_suggested_fixes_new_roots(self): 3178 from torch.export import dims 3179 3180 # suggested fixes should introduce new root dim for modulo guard 3181 class Foo(torch.nn.Module): 3182 def forward(self, x, y, z): 3183 # dy = 3 * _dx 3184 # dx = 3 * _dx - 1 3185 # dz = 3 * _dx + 2 3186 # suggested fixes results will look something like 3187 # {"dx": {"eq": 3*_dx-1, "min": 5, "max": 36}, "dy": {"eq": dx+1}, ...} 3188 if x.shape[0] >= 5 and x.shape[0] <= 36 and y.shape[0] % 3 == 0: 3189 return x + y[1:] + z[3:] 3190 3191 foo = Foo() 3192 inputs = ( 3193 torch.randn( 3194 11, 3195 ), 3196 torch.randn( 3197 12, 3198 ), 3199 torch.randn( 3200 14, 3201 ), 3202 ) 3203 dx, dy, dz = dims("dx", "dy", "dz") 3204 dynamic_shapes = { 3205 "x": (dx,), 3206 "y": (dy,), 3207 "z": (dz,), 3208 } 3209 with self.assertRaisesRegex( # figure out regex later 3210 torch._dynamo.exc.UserError, 3211 ( 3212 "Constraints violated.*!(.*\n)*.*" 3213 "Suggested fixes(.*\n)*.*" 3214 "_dx = Dim\(\\'_dx\\', max=12\)(.*\n)*.*" 3215 "dx = 3\*_dx - 1(.*\n)*.*" 3216 "dy = 3\*_dx(.*\n)*.*" 3217 "dz = 3\*_dx \+ 2" 3218 ), 3219 ): 3220 export(Foo(), inputs, dynamic_shapes=dynamic_shapes) 3221 # retry export 3222 _dx = Dim("_dx", min=2, max=12) 3223 dynamic_shapes = {"x": (3 * _dx - 1,), "y": (3 * _dx,), "z": (3 * _dx + 2,)} 3224 export(Foo(), inputs, dynamic_shapes=dynamic_shapes) 3225 3226 def test_refine_dynamic_shapes_from_suggested_fixes(self): 3227 from torch.export.dynamic_shapes import ( 3228 refine_dynamic_shapes_from_suggested_fixes, 3229 ) 3230 3231 def helper(model, inputs, dynamic_shapes): 3232 # export, fail, parse & refine suggested fixes, re-export 3233 try: 3234 export(Foo(), inps, dynamic_shapes=dynamic_shapes) 3235 raise Exception("should have raised constraint violation error") 3236 except torch._dynamo.exc.UserError as exc: 3237 new_shapes = 
refine_dynamic_shapes_from_suggested_fixes( 3238 exc.msg, dynamic_shapes 3239 ) 3240 export(Foo(), inps, dynamic_shapes=new_shapes) 3241 return new_shapes 3242 3243 # specialize dims + derived dims 3244 class Foo(torch.nn.Module): 3245 def forward(self, x, y, z): 3246 x0 = x + y[1:] + z[2:] 3247 x1 = x @ torch.randn(4, 4) 3248 return x0, x1 3249 3250 inps = ( 3251 torch.randn( 3252 4, 3253 ), 3254 torch.randn( 3255 5, 3256 ), 3257 torch.randn( 3258 6, 3259 ), 3260 ) 3261 dx = Dim("dx", max=16) 3262 dynamic_shapes = {"x": (dx,), "y": (dx + 1,), "z": (dx + 2,)} 3263 new_shapes = helper(Foo(), inps, dynamic_shapes) 3264 self.assertEqual(new_shapes["x"][0], 4) 3265 self.assertEqual(new_shapes["z"][0], 6) 3266 3267 # refine lower, upper bound 3268 class Foo(torch.nn.Module): 3269 def forward(self, x, y): 3270 if x.shape[0] >= 6 and y.shape[0] <= 16: 3271 return x * 2.0, y + 1 3272 3273 inps = (torch.randn(16), torch.randn(12)) 3274 dynamic_shapes = {"x": (Dim("dx"),), "y": (Dim("dy"),)} 3275 new_shapes = helper(Foo(), inps, dynamic_shapes) 3276 self.assertEqual(new_shapes["x"][0].min, 6) 3277 self.assertEqual(new_shapes["y"][0].max, 16) 3278 3279 # divisiblity, will introduce new root 3280 class Foo(torch.nn.Module): 3281 def forward(self, x): 3282 if x.shape[0] >= 9: 3283 return x.reshape([-1, 3]) 3284 3285 inps = ( 3286 torch.randn( 3287 15, 3288 ), 3289 ) 3290 dynamic_shapes = ((Dim("dx"),),) 3291 new_shapes = helper(Foo(), inps, dynamic_shapes) 3292 dim = new_shapes[0][0] 3293 root = dim.root 3294 self.assertEqual(dim.fn(2), 6) 3295 self.assertEqual(root.min, 3) 3296 3297 # turn dim into derived dim/relation 3298 class Foo(torch.nn.Module): 3299 def forward(self, x, y): 3300 return x + y[4:] 3301 3302 inps = (torch.randn(6, 4), torch.randn(10, 4)) 3303 dynamic_shapes = { 3304 "x": (Dim("dx0"), Dim("dx1")), 3305 "y": (Dim("dy0"), Dim("dy1")), 3306 } 3307 new_shapes = helper(Foo(), inps, dynamic_shapes) 3308 self.assertEqual(new_shapes["x"][0], new_shapes["y"][0].root) # dy0 = dx0 + 4 3309 self.assertEqual(new_shapes["y"][0].fn(5), 9) 3310 self.assertEqual(new_shapes["x"][1], new_shapes["y"][1]) # dx1 = dy1 3311 3312 # nested dynamic shapes spec 3313 class Foo(torch.nn.Module): 3314 def forward(self, x, y): 3315 x0 = x[0]["data"] + x[1] + x[2][2:] 3316 x1 = y["a"] @ torch.randn(4, 4) 3317 x2 = y["b"] @ torch.randn(6, 6) 3318 return x0, x1, x2 3319 3320 inps = ( 3321 [ 3322 {"data": torch.randn(4, 4)}, 3323 torch.randn(4, 4), 3324 torch.randn(6, 4), 3325 ], 3326 { 3327 "a": torch.randn(8, 4), 3328 "b": torch.randn(9, 6), 3329 }, 3330 ) 3331 dynamic_shapes = { 3332 "x": [ 3333 {"data": (Dim("dx00"), Dim("dx01"))}, 3334 (Dim("dx10"), Dim("dx11")), 3335 (Dim("dx20"), Dim("dx21")), 3336 ], 3337 "y": { 3338 "a": (Dim("dya0"), Dim("dya1")), 3339 "b": (Dim("dyb0"), Dim("dyb1")), 3340 }, 3341 } 3342 new_shapes = helper(Foo(), inps, dynamic_shapes) 3343 self.assertEqual( 3344 new_shapes["x"][0]["data"][0], new_shapes["x"][1][0] 3345 ) # dx10 = dx00 3346 self.assertEqual( 3347 new_shapes["x"][2][0].root, new_shapes["x"][0]["data"][0] 3348 ) # dx20 = dx00 + 2 3349 self.assertEqual(new_shapes["x"][2][0].fn(10), 12) 3350 self.assertEqual( 3351 new_shapes["x"][0]["data"][1], new_shapes["x"][1][1] 3352 ) # dx11 = dx01 3353 self.assertEqual(new_shapes["y"]["a"][1], 4) 3354 self.assertEqual(new_shapes["y"]["b"][1], 6) 3355 self.assertEqual(new_shapes["y"]["b"][0].__name__, "dyb0") # unchanged 3356 3357 def test_dynamic_shapes_spec_with_pytree(self): 3358 from torch.export import Dim, export 3359 from 
torch.utils._pytree import tree_map 3360 3361 inputs = { 3362 "tensor": torch.randn(3), 3363 "dict_of_tensors": {k: torch.randn(3) for k in ["A", "B", "C", "D"]}, 3364 "list_of_tensors": [torch.randn(3) for _ in range(4)], 3365 } 3366 3367 batch = Dim("batch") 3368 # uniformly specify dynamic shapes for all inputs 3369 spec = tree_map(lambda x: {0: batch}, inputs) 3370 3371 class Foo(torch.nn.Module): 3372 def forward(self, inputs): 3373 return ( 3374 inputs["tensor"] 3375 + inputs["dict_of_tensors"]["A"] 3376 + inputs["list_of_tensors"][0] 3377 ) 3378 3379 ep = export(Foo(), (inputs,), dynamic_shapes={"inputs": spec}) 3380 input_shapes = [ 3381 str(node.meta["val"].shape) 3382 for node in ep.graph_module.graph.nodes 3383 if node.op == "placeholder" 3384 ] 3385 self.assertEqual(len(input_shapes), 9) 3386 self.assertTrue(all(shape == "torch.Size([s0])" for shape in input_shapes)) 3387 3388 def test_error_does_not_reference_eager_fallback(self): 3389 class Module(torch.nn.Module): 3390 def forward(self, x): 3391 y = x.nonzero() 3392 z = y.shape[0] 3393 if z > 2: 3394 return x.cos() 3395 else: 3396 return x.sin() 3397 3398 fn_ddo = Module() 3399 if is_non_strict_test(self._testMethodName): 3400 error = torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode 3401 error_msg = r"Could not guard on data-dependent expression" 3402 else: 3403 error = torchdynamo.exc.UserError 3404 error_msg = r"^(?!.*fall back to eager).*" 3405 with self.assertRaisesRegex(error, error_msg): 3406 _ = export(fn_ddo, (torch.tensor([2, 3, 5]),)) 3407 3408 def test_pytree_register_data_class(self): 3409 @dataclass 3410 class MyDataClass: 3411 x: int 3412 y: int 3413 z: int = None 3414 3415 dt = MyDataClass(x=3, y=4) 3416 flat, spec = tree_flatten(dt) 3417 self.assertTrue(spec, LeafSpec()) 3418 self.assertTrue(len(flat) == 1) 3419 3420 register_dataclass_as_pytree_node( 3421 MyDataClass, 3422 serialized_type_name="test_pytree_register_data_class.MyDataClass", 3423 ) 3424 3425 flat, spec = tree_flatten(dt) 3426 self.assertEqual( 3427 spec, 3428 TreeSpec(MyDataClass, [["x", "y"], ["z"]], [LeafSpec(), LeafSpec()]), 3429 ) 3430 self.assertEqual(flat, [3, 4]) 3431 3432 orig_dt = tree_unflatten(flat, spec) 3433 self.assertTrue(isinstance(orig_dt, MyDataClass)) 3434 self.assertEqual(orig_dt.x, 3) 3435 self.assertEqual(orig_dt.y, 4) 3436 self.assertEqual(orig_dt.z, None) 3437 3438 roundtrip_spec = treespec_loads(treespec_dumps(spec)) 3439 self.assertEqual(roundtrip_spec, spec) 3440 3441 @dataclass 3442 class MyOtherDataClass: # the pytree registration don't allow registering the same class twice 3443 x: int 3444 y: int 3445 z: int = None 3446 3447 # Override the registration with keep none fields 3448 register_dataclass_as_pytree_node( 3449 MyOtherDataClass, 3450 return_none_fields=True, 3451 serialized_type_name="test_pytree_regster_data_class.MyOtherDataClass", 3452 ) 3453 3454 dt = MyOtherDataClass(x=3, y=4) 3455 flat, spec = tree_flatten(dt) 3456 self.assertEqual( 3457 spec, 3458 TreeSpec( 3459 MyOtherDataClass, 3460 [["x", "y", "z"], []], 3461 [LeafSpec(), LeafSpec(), LeafSpec()], 3462 ), 3463 ) 3464 self.assertEqual(flat, [3, 4, None]) 3465 3466 orig_dt = tree_unflatten(flat, spec) 3467 self.assertTrue(isinstance(orig_dt, MyOtherDataClass)) 3468 self.assertEqual(orig_dt.x, 3) 3469 self.assertEqual(orig_dt.y, 4) 3470 self.assertEqual(orig_dt.z, None) 3471 3472 roundtrip_spec = treespec_loads(treespec_dumps(spec)) 3473 self.assertEqual(roundtrip_spec, spec) 3474 3475 def 
test_pytree_register_nested_data_class(self): 3476 @dataclass 3477 class Inner: 3478 x: int 3479 y: int 3480 3481 @dataclass 3482 class Outer: 3483 xy: Inner 3484 ab: Inner 3485 3486 xy = Inner(1, 2) 3487 ab = Inner(3, 4) 3488 dt = Outer(xy, ab) 3489 inp = {"dt1": (dt, ({},)), "dt2": ((torch.ones(1),), dt)} 3490 3491 register_dataclass_as_pytree_node( 3492 Inner, serialized_type_name="test_pytree_register_nested_data_class.Inner" 3493 ) 3494 register_dataclass_as_pytree_node( 3495 Outer, serialized_type_name="test_pytree_register_nested_data_class.Outer" 3496 ) 3497 3498 flat, spec = tree_flatten(inp) 3499 self.assertEqual(flat, [1, 2, 3, 4, torch.ones(1), 1, 2, 3, 4]) 3500 3501 unflat = tree_unflatten(flat, spec) 3502 self.assertEqual(unflat, inp) 3503 3504 roundtrip_spec = treespec_loads(treespec_dumps(spec)) 3505 self.assertEqual(roundtrip_spec, spec) 3506 3507 def test_param_util(self): 3508 class Basic(torch.nn.Module): 3509 def __init__(self) -> None: 3510 super().__init__() 3511 self.lin = torch.nn.Linear(10, 1) 3512 3513 def forward(self, x): 3514 return self.lin(x) 3515 3516 ep = export(Basic(), (torch.randn(5, 10),)) 3517 num_params = 0 3518 params = [] 3519 for node in ep.graph.nodes: 3520 if is_param(ep, node): 3521 num_params += 1 3522 params.append(get_param(ep, node)) 3523 self.assertEqual(num_params, 2) 3524 self.assertEqual(params[0].shape, [1, 10]) # weight 3525 self.assertEqual(params[1].shape, [1]) # bias 3526 3527 def test_buffer_util(self): 3528 ep = export( 3529 torch.nn.BatchNorm2d(100, affine=False), (torch.ones(20, 100, 35, 45),) 3530 ) 3531 num_buffer = 0 3532 buffer = [] 3533 3534 for node in ep.graph.nodes: 3535 if is_buffer(ep, node): 3536 num_buffer += 1 3537 buffer.append(get_buffer(ep, node)) 3538 self.assertEqual(num_buffer, 3) 3539 3540 self.assertEqual(buffer[0].shape, torch.Size([100])) # running_mean 3541 self.assertEqual(buffer[1].shape, torch.Size([100])) # running_var 3542 self.assertEqual(buffer[2].shape, torch.Size([])) # num_batches_tracked 3543 3544 def test_export_dynamo_config(self): 3545 class MyModule(torch.nn.Module): 3546 def __init__(self) -> None: 3547 super().__init__() 3548 self.lstm = torch.nn.LSTM(input_size=4, hidden_size=5, num_layers=1) 3549 3550 def forward(self, inputs: torch.Tensor) -> torch.Tensor: 3551 return self.lstm(inputs) 3552 3553 config = DEFAULT_EXPORT_DYNAMO_CONFIG 3554 mod = MyModule() 3555 3556 @contextmanager 3557 def _patch_config(kwargs): 3558 orig_config_dict = dataclasses.asdict(config) 3559 3560 try: 3561 for k, v in kwargs.items(): 3562 setattr(config, k, v) 3563 yield 3564 finally: 3565 for k, v in orig_config_dict.items(): 3566 setattr(config, k, v) 3567 3568 inp = (torch.rand(5, 4),) 3569 exported_program = export(mod, inp, strict=True) 3570 3571 with _patch_config({"allow_rnn": False}): 3572 with self.assertRaisesRegex( 3573 torch._dynamo.exc.Unsupported, 3574 "TorchDynamo purposely graph breaks on RNN, GRU, LSTMs", 3575 ): 3576 _ = export(mod, inp, strict=True) 3577 3578 def test_device_to_static(self): 3579 class Module(torch.nn.Module): 3580 def forward(self, x): 3581 return x.to("cpu") 3582 3583 ep = export(Module(), (torch.tensor(1, device="cpu"),)) 3584 ops = [] 3585 for node in ep.graph.nodes: 3586 if node.op == "call_function": 3587 ops.append(node.target) 3588 self.assertGreater(len(ops), 0) 3589 for op in ops: 3590 self.assertIn(op, (torch.ops.aten._to_copy.default,)) 3591 3592 def test_device_to_dynamic(self): 3593 class Module(torch.nn.Module): 3594 def forward(self, x): 3595 return 
x.to("cpu") 3596 3597 ep = export( 3598 Module(), 3599 (torch.tensor([1, 2], device="cpu"),), 3600 dynamic_shapes={"x": {0: Dim("i")}}, 3601 ) 3602 ops = [] 3603 for node in ep.graph.nodes: 3604 if node.op == "call_function": 3605 ops.append(node.target) 3606 self.assertGreater(len(ops), 0) 3607 for op in ops: 3608 self.assertIn(op, (torch.ops.aten._to_copy.default,)) 3609 3610 def test_device_to_mutation(self): 3611 class Module(torch.nn.Module): 3612 def forward(self, x): 3613 y = x.to("cpu") 3614 y.add_(1) 3615 return y, x 3616 3617 with self.assertRaisesRegex( 3618 RuntimeError, "cannot mutate tensors with frozen storage" 3619 ): 3620 export(Module(), (torch.tensor(1, device="cpu"),)) 3621 3622 def test_float_conversion(self): 3623 class Module(torch.nn.Module): 3624 def forward(self, x): 3625 return x.float() 3626 3627 ep = export(Module(), (torch.tensor(1, dtype=torch.float),)) 3628 ops = [] 3629 for node in ep.graph.nodes: 3630 if node.op == "call_function": 3631 ops.append(node.target) 3632 self.assertGreater(len(ops), 0) 3633 for op in ops: 3634 self.assertIn(op, (torch.ops.aten._to_copy.default,)) 3635 3636 def test_device_to_mutation_float(self): 3637 class Module(torch.nn.Module): 3638 def forward(self, x): 3639 y = x.float() 3640 y.add_(1) 3641 return y, x 3642 3643 with self.assertRaisesRegex( 3644 RuntimeError, "cannot mutate tensors with frozen storage" 3645 ): 3646 export(Module(), (torch.tensor(1, dtype=torch.float),)) 3647 3648 def test_module(self): 3649 class MyLinear(torch.nn.Module): 3650 def __init__(self) -> None: 3651 super().__init__() 3652 self.weight = torch.randn(20, 98) 3653 self.bias = torch.randn(20) 3654 3655 def forward(self, x): 3656 return torch.nn.functional.linear(x, self.weight, self.bias) 3657 3658 class Foo(torch.nn.Module): 3659 def __init__(self) -> None: 3660 super().__init__() 3661 self.conv = torch.nn.Conv2d(16, 33, 3) 3662 self.linear = MyLinear() 3663 3664 def forward(self, x): 3665 a, b = x 3666 a_conv = self.conv(a) 3667 a_linear = self.linear(a_conv) 3668 b_conv = self.conv(b) 3669 b_linear = self.linear(b_conv) 3670 return ( 3671 a_linear.cos() + b_linear.sin(), 3672 a_linear.sin() + b_linear.cos(), 3673 ) 3674 3675 inp_container = ((torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)),) 3676 3677 ep = export(Foo(), inp_container) 3678 ep_rexported = export(ep.module(), inp_container) 3679 3680 inp_test = ((torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)),) 3681 3682 self.assertTrue( 3683 torch.allclose( 3684 ep.module()(*inp_test)[0], ep_rexported.module()(*inp_test)[0] 3685 ) 3686 ) 3687 self.assertTrue( 3688 torch.allclose( 3689 ep.module()(*inp_test)[1], ep_rexported.module()(*inp_test)[1] 3690 ) 3691 ) 3692 3693 def test_use_embedding_twice(self): 3694 class Foo(torch.nn.Module): 3695 def __init__(self): 3696 super().__init__() 3697 self.embed = torch.nn.Embedding(4, 4) 3698 3699 def forward(self, x): 3700 return self.embed(x) + self.embed.weight[x] 3701 3702 inputs = (torch.tensor([0, 1, 2, 3]),) 3703 ep = export(Foo(), inputs) 3704 3705 def test_module_with_dict_container_inp_out(self): 3706 class MyLinear(torch.nn.Module): 3707 def __init__(self) -> None: 3708 super().__init__() 3709 self.weight = torch.randn(20, 98) 3710 self.bias = torch.randn(20) 3711 3712 def forward(self, x): 3713 return torch.nn.functional.linear(x, self.weight, self.bias) 3714 3715 class Foo(torch.nn.Module): 3716 def __init__(self) -> None: 3717 super().__init__() 3718 self.conv = torch.nn.Conv2d(16, 33, 3) 3719 self.linear = MyLinear() 
3720 3721 def forward(self, x): 3722 a1, a2 = x["a"] 3723 b = x["b"] 3724 a1_conv = self.conv(a1) 3725 a1_linear = self.linear(a1_conv) 3726 a2_conv = self.conv(a2) 3727 a2_linear = self.linear(a2_conv) 3728 b_conv = self.conv(b) 3729 b_linear = self.linear(b_conv) 3730 return { 3731 "a": a1_linear.cos() + b_linear.sin(), 3732 "b": a2_linear.sin() + b_linear.cos(), 3733 } 3734 3735 inp_container = ( 3736 { 3737 "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)), 3738 "b": torch.randn(20, 16, 50, 100), 3739 }, 3740 ) 3741 3742 ep = export(Foo(), inp_container) 3743 ep_rexported = export(ep.module(), inp_container) 3744 3745 inp_test = ( 3746 { 3747 "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)), 3748 "b": torch.randn(20, 16, 50, 100), 3749 }, 3750 ) 3751 3752 self.assertTrue( 3753 torch.allclose( 3754 ep.module()(*inp_test)["a"], ep_rexported.module()(*inp_test)["a"] 3755 ) 3756 ) 3757 self.assertTrue( 3758 torch.allclose( 3759 ep.module()(*inp_test)["b"], ep_rexported.module()(*inp_test)["b"] 3760 ) 3761 ) 3762 3763 def test_args_type_checked(self): 3764 class M(torch.nn.Module): 3765 def forward(self, x): 3766 return x + 1 3767 3768 inp = torch.rand(2, 2) 3769 with self.assertRaisesRegex(torch._dynamo.exc.UserError, "to be a tuple"): 3770 # Intentionally not wrapping `inp` in a tuple to trigger the error 3771 _ = export(M(), inp) 3772 3773 def test_decomp_batch_norm_functional_predispatch(self): 3774 class ConvBatchnorm(torch.nn.Module): 3775 def __init__(self) -> None: 3776 super().__init__() 3777 self.conv = torch.nn.Conv2d(1, 3, 1, 1) 3778 self.bn = torch.nn.BatchNorm2d(3) 3779 3780 def forward(self, x): 3781 x = self.conv(x) 3782 x = self.bn(x) 3783 return (x,) 3784 3785 mod = ConvBatchnorm() 3786 mod.eval() 3787 inp = torch.randn(1, 1, 3, 3) 3788 3789 gm = torch.export._trace._export(mod, (inp,), pre_dispatch=True).module() 3790 self.assertExpectedInline( 3791 str(gm.code).strip(), 3792 """\ 3793def forward(self, x): 3794 x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) 3795 conv_weight = self.conv.weight 3796 conv_bias = self.conv.bias 3797 bn_weight = self.bn.weight 3798 bn_bias = self.bn.bias 3799 bn_running_mean = self.bn.running_mean 3800 bn_running_var = self.bn.running_var 3801 bn_num_batches_tracked = self.bn.num_batches_tracked; bn_num_batches_tracked = None 3802 conv2d = torch.ops.aten.conv2d.default(x, conv_weight, conv_bias); x = conv_weight = conv_bias = None 3803 _native_batch_norm_legit_no_training = torch.ops.aten._native_batch_norm_legit_no_training.default(conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, 0.1, 1e-05); conv2d = bn_weight = bn_bias = bn_running_mean = bn_running_var = None 3804 getitem = _native_batch_norm_legit_no_training[0]; _native_batch_norm_legit_no_training = None 3805 return pytree.tree_unflatten((getitem,), self._out_spec)""", 3806 ) 3807 3808 mod.train() 3809 gm_train = _export(mod, (inp,), pre_dispatch=True).module() 3810 self.assertExpectedInline( 3811 str(gm_train.code).strip(), 3812 """\ 3813def forward(self, x): 3814 x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) 3815 conv_weight = self.conv.weight 3816 conv_bias = self.conv.bias 3817 bn_weight = self.bn.weight 3818 bn_bias = self.bn.bias 3819 bn_running_mean = self.bn.running_mean 3820 bn_running_var = self.bn.running_var 3821 bn_num_batches_tracked = self.bn.num_batches_tracked 3822 conv2d = torch.ops.aten.conv2d.default(x, conv_weight, conv_bias); x = conv_weight = conv_bias = None 3823 add = 
torch.ops.aten.add.Tensor(bn_num_batches_tracked, 1) 3824 _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, True, 0.1, 1e-05); conv2d = bn_weight = bn_bias = None 3825 getitem = _native_batch_norm_legit_functional[0] 3826 getitem_3 = _native_batch_norm_legit_functional[3] 3827 getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None 3828 copy__default = torch.ops.aten.copy_.default(bn_running_mean, getitem_3); bn_running_mean = getitem_3 = copy__default = None 3829 copy__default_1 = torch.ops.aten.copy_.default(bn_running_var, getitem_4); bn_running_var = getitem_4 = copy__default_1 = None 3830 copy__default_2 = torch.ops.aten.copy_.default(bn_num_batches_tracked, add); bn_num_batches_tracked = add = copy__default_2 = None 3831 return pytree.tree_unflatten((getitem,), self._out_spec)""", 3832 ) 3833 3834 def test_constrain_size_in_eager(self): 3835 class Module(torch.nn.Module): 3836 def forward(self, x, y): 3837 n = x.max().item() 3838 torch._check_is_size(n) 3839 return y + n 3840 3841 fn = Module() 3842 ep = export( 3843 fn, 3844 (torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3))), 3845 ) 3846 test_inp = (torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3))) 3847 self.assertTrue(torch.allclose(ep.module()(*test_inp), fn(*test_inp))) 3848 3849 def test_constrain_size_with_constrain_value(self): 3850 class Module(torch.nn.Module): 3851 def forward(self, x, y): 3852 n = x.max().item() 3853 torch._check(n >= 2) 3854 torch._check(n <= 10) 3855 torch._check_is_size(n) 3856 return y + n 3857 3858 fn = Module() 3859 with self.assertRaisesRegex( 3860 RuntimeError, r"Expected cond to be True, but got False" 3861 ): 3862 _ = fn(torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3))) 3863 3864 ep = export( 3865 fn, 3866 (torch.randint(3, 4, (2, 2)), torch.randint(3, 5, (2, 3))), 3867 ) 3868 with self.assertRaisesRegex( 3869 RuntimeError, r"Runtime assertion failed for expression u[\d+] \>\= 2" 3870 ): 3871 test_inp = (torch.randint(1, 2, (2, 2)), torch.randint(3, 5, (2, 3))) 3872 _ = ep.module()(*test_inp) 3873 3874 def test_constrain_size_with_various_cases(self): 3875 class Module1(torch.nn.Module): 3876 def forward(self, x, y): 3877 n = x.item() 3878 torch._check_is_size(n) 3879 torch._check(n >= 0) 3880 return y.sum() + torch.ones(n, 5).sum() 3881 3882 case1 = Module1() 3883 3884 class Module2(torch.nn.Module): 3885 def forward(self, x, y): 3886 n = x.item() 3887 torch._check_is_size(n) 3888 torch._check(n >= 0) 3889 torch._check(n <= 6) 3890 return y.sum() + torch.ones(n, 5).sum() 3891 3892 case2 = Module2() 3893 3894 class Module3(torch.nn.Module): 3895 def forward(self, x, y): 3896 n = x.item() 3897 torch._check_is_size(n) 3898 torch._check(n >= 0) 3899 torch._check(n <= 1) 3900 return y.sum() + torch.ones(n, 5).sum() 3901 3902 case3 = Module3() 3903 3904 class Module4(torch.nn.Module): 3905 def forward(self, x, y): 3906 n = x.item() 3907 torch._check_is_size(n) 3908 torch._check(n >= 2) 3909 return y.sum() + torch.ones(n, 5).sum() 3910 3911 case4 = Module4() 3912 3913 class Module5(torch.nn.Module): 3914 def forward(self, x, y): 3915 n = x.item() 3916 torch._check_is_size(n) 3917 torch._check(n >= 1) 3918 return y.sum() + torch.ones(n, 5).sum() 3919 3920 case5 = Module5() 3921 3922 ep = export(case1, (torch.tensor(1), torch.ones(4, 5))) 3923 3924 with self.assertRaisesRegex( 3925 RuntimeError, r"Expected cond to be True, but got 
False" 3926 ): 3927 _ = case1(torch.tensor(-1), torch.randn(4, 5)) 3928 3929 self.assertTrue( 3930 torch.allclose( 3931 ep.module()(torch.tensor(1), torch.ones(4, 5)), 3932 case1(torch.tensor(1), torch.ones(4, 5)), 3933 ) 3934 ) 3935 3936 ep = export(case2, (torch.tensor(5), torch.randn(4, 5))) 3937 3938 with self.assertRaisesRegex( 3939 RuntimeError, 3940 r"Expected cond to be True, but got False", 3941 ): 3942 _ = case2(torch.tensor(7), torch.randn(4, 5)) 3943 3944 with self.assertRaisesRegex( 3945 RuntimeError, 3946 r"Expected cond to be True, but got False", 3947 ): 3948 _ = case2(torch.tensor(9), torch.randn(4, 5)) 3949 3950 self.assertTrue( 3951 torch.allclose( 3952 ep.module()(torch.tensor(5), torch.ones(4, 5)), 3953 case2(torch.tensor(5), torch.ones(4, 5)), 3954 ) 3955 ) 3956 3957 _ = case3(torch.tensor(1), torch.randn(4, 5)) 3958 3959 with self.assertRaisesRegex( 3960 RuntimeError, 3961 r"Expected cond to be True, but got False", 3962 ): 3963 _ = case4(torch.tensor(1), torch.randn(4, 5)) 3964 3965 ep = export(case4, (torch.tensor(5), torch.randn(4, 5))) 3966 3967 with self.assertRaisesRegex( 3968 RuntimeError, 3969 r"Expected cond to be True, but got False", 3970 ): 3971 _ = case4(torch.tensor(1), torch.randn(4, 5)) 3972 3973 self.assertTrue( 3974 torch.allclose( 3975 ep.module()(torch.tensor(5), torch.ones(4, 5)), 3976 case4(torch.tensor(5), torch.ones(4, 5)), 3977 ) 3978 ) 3979 3980 ep = export(case5, (torch.tensor(5), torch.randn(4, 5))) 3981 3982 with self.assertRaisesRegex( 3983 RuntimeError, 3984 r"Expected cond to be True, but got False", 3985 ): 3986 _ = case5(torch.tensor(0), torch.randn(4, 5)) 3987 3988 self.assertTrue( 3989 torch.allclose( 3990 ep.module()(torch.tensor(5), torch.ones(4, 5)), 3991 case5(torch.tensor(5), torch.ones(4, 5)), 3992 ) 3993 ) 3994 3995 def test_automatic_constrain_size(self): 3996 class M(torch.nn.Module): 3997 def forward(self, x, y): 3998 n = x.item() 3999 return y.sum() + torch.ones(n, 5).sum() 4000 4001 ep = export(M(), (torch.tensor(1), torch.ones(4, 5))) 4002 4003 # This is because we insert sym_constrain_range in the graph now 4004 error_msg = r"Invalid value range for -1 between" 4005 with self.assertRaisesRegex(RuntimeError, error_msg): 4006 _ = ep.module()(torch.tensor(-1), torch.randn(4, 5)) 4007 4008 self.assertTrue( 4009 torch.allclose( 4010 ep.module()(torch.tensor(1), torch.ones(4, 5)), 4011 M()(torch.tensor(1), torch.ones(4, 5)), 4012 ) 4013 ) 4014 4015 def test_constrain_decomp(self) -> None: 4016 class M(torch.nn.Module): 4017 def __init__(self) -> None: 4018 super().__init__() 4019 self.freq = torch.ones(5, 5) 4020 4021 def forward(self, start_pos: torch.Tensor): 4022 pos = start_pos.item() 4023 torch._check_is_size(pos) 4024 torch._check(pos >= 0) 4025 torch._check(pos <= 4) 4026 return self.freq[pos] * self.freq[pos] 4027 4028 ep = torch.export.export(M(), (torch.tensor(1),)) 4029 FileCheck().check_count( 4030 "torch.ops.aten._assert_scalar.default", 2, exactly=True 4031 ).run(ep.graph_module.code) 4032 FileCheck().check_count( 4033 "torch.ops.aten.sym_constrain_range_for_size.default", 1, exactly=True 4034 ).run(ep.graph_module.code) 4035 4036 decompose_ep = ep.run_decompositions() 4037 FileCheck().check_count( 4038 "torch.ops.aten._assert_scalar.default", 2, exactly=True 4039 ).run(ep.graph_module.code) 4040 FileCheck().check_count( 4041 "torch.ops.aten.sym_constrain_range_for_size.default", 1, exactly=True 4042 ).run(ep.graph_module.code) 4043 4044 def test_mixed_input(self): 4045 class Module(torch.nn.Module): 4046 def 
forward(self, a, b, alpha: int): 4047 return torch.add(a, b, alpha=alpha) 4048 4049 func = Module() 4050 4051 a = torch.rand(1, 2) 4052 b = torch.rand(1, 2) 4053 alpha = 10 4054 4055 exported = export(func, (a, b, alpha)) 4056 for node in exported.graph_module.graph.nodes: 4057 if node.op == "placeholder": 4058 self.assertTrue(isinstance(node.meta["val"], (Tensor, int))) 4059 4060 def test_export_with_inline_constraints(self): 4061 class Module(torch.nn.Module): 4062 def forward(self, x): 4063 a = x.item() 4064 torch._check(a >= 4) 4065 torch._check(a <= 7) 4066 return torch.empty((a, 4)) 4067 4068 f = Module() 4069 ep = export(f, (torch.tensor([5]),)) 4070 self.assertEqual(ep.module()(torch.tensor([6])).shape, (6, 4)) 4071 4072 FileCheck().check_count( 4073 "torch.ops.aten._assert_scalar.default", 2, exactly=True 4074 ).run(ep.graph_module.code) 4075 FileCheck().check_count( 4076 "torch.ops.aten.sym_constrain_range.default", 0, exactly=True 4077 ).run(ep.graph_module.code) 4078 FileCheck().check_count( 4079 "torch.ops.aten.sym_constrain_range_for_size.default", 1, exactly=True 4080 ).run(ep.graph_module.code) 4081 4082 with self.assertRaisesRegex( 4083 RuntimeError, 4084 r"Runtime assertion failed for expression u[\d+] \<\= 7", 4085 ) as cm: 4086 ep.module()(torch.tensor([30])) 4087 4088 def test_export_with_inline_constraints_complex(self): 4089 class Module(torch.nn.Module): 4090 def forward(self, x): 4091 a = x.item() 4092 torch._check(a >= 4) 4093 torch._check(a <= 7) 4094 empty = torch.empty((a, 4)) 4095 4096 return torch.cat((empty.transpose(0, 1), torch.zeros(6, a)), 0) 4097 4098 f = Module() 4099 ep = export(f, (torch.tensor([6]),)) 4100 self.assertEqual(ep.module()(torch.tensor([5])).shape, (10, 5)) 4101 FileCheck().check_count( 4102 "torch.ops.aten._assert_scalar.default", 2, exactly=True 4103 ).run(ep.graph_module.code) 4104 FileCheck().check_count( 4105 "torch.ops.aten.sym_constrain_range.default", 0, exactly=True 4106 ).run(ep.graph_module.code) 4107 FileCheck().check_count( 4108 "torch.ops.aten.sym_constrain_range_for_size.default", 1, exactly=True 4109 ).run(ep.graph_module.code) 4110 4111 def test_to_module_with_mutated_buffer(self): 4112 class Foo(torch.nn.Module): 4113 def __init__(self) -> None: 4114 super().__init__() 4115 self.buf = torch.nn.Buffer(torch.zeros(1)) 4116 4117 def forward(self, x): 4118 self.buf.add_(1) 4119 return x.sum() + self.buf.sum() 4120 4121 exported = export(Foo(), (torch.ones(5, 5),)) 4122 stateful_gm = exported.module() 4123 export_return_val = stateful_gm(torch.ones(5, 5)) 4124 eager = Foo() 4125 eager_return_val = eager(torch.ones(5, 5)) 4126 self.assertTrue(torch.allclose(eager_return_val, export_return_val)) 4127 4128 for name, buffer in stateful_gm.named_buffers(): 4129 self.assertTrue(torch.allclose(torch.ones(1), buffer)) 4130 4131 changed = stateful_gm.graph.eliminate_dead_code() 4132 self.assertFalse(changed) 4133 self.assertTrue( 4134 torch.allclose(stateful_gm(torch.ones(5, 5)), eager(torch.ones(5, 5))) 4135 ) 4136 4137 for name, buffer in stateful_gm.named_buffers(): 4138 self.assertTrue(torch.allclose(torch.tensor(2, dtype=torch.float), buffer)) 4139 4140 def test_to_module_with_mutated_buffer_multiple(self): 4141 class Bar(torch.nn.Module): 4142 def __init__(self) -> None: 4143 super().__init__() 4144 self.buf = torch.nn.Buffer(torch.ones(1)) 4145 4146 def forward(self, x): 4147 self.buf.add_(1) 4148 return x.sum() + self.buf.sum() 4149 4150 class Foo(torch.nn.Module): 4151 def __init__(self) -> None: 4152 super().__init__() 4153 
self.buf = torch.nn.Buffer(torch.zeros(1)) 4154 self.bar = Bar() 4155 4156 def forward(self, x): 4157 self.buf.add_(1) 4158 self.bar.buf.add_(2) 4159 bar = self.bar(x) 4160 return bar.sum() + self.buf.sum() 4161 4162 exported = export(Foo(), (torch.ones(5, 5),)) 4163 stateful_gm = exported.module() 4164 export_return_val = stateful_gm(torch.ones(5, 5)) 4165 eager = Foo() 4166 eager_return_val = eager(torch.ones(5, 5)) 4167 self.assertTrue(torch.allclose(eager_return_val, export_return_val)) 4168 4169 for name, buffer in stateful_gm.named_buffers(): 4170 if name == "L__self___buf": 4171 self.assertTrue(torch.allclose(torch.ones(1), buffer)) 4172 if name == "L__self___bar_buf": 4173 self.assertTrue( 4174 torch.allclose(torch.tensor(4, dtype=torch.float), buffer) 4175 ) 4176 4177 changed = stateful_gm.graph.eliminate_dead_code() 4178 self.assertFalse(changed) 4179 self.assertTrue( 4180 torch.allclose(stateful_gm(torch.ones(5, 5)), eager(torch.ones(5, 5))) 4181 ) 4182 4183 for name, buffer in stateful_gm.named_buffers(): 4184 if name == "L__self___buf": 4185 self.assertTrue( 4186 torch.allclose(torch.tensor(2, dtype=torch.float), buffer) 4187 ) 4188 if name == "L__self___bar_buf": 4189 self.assertTrue( 4190 torch.allclose(torch.tensor(7, dtype=torch.float), buffer) 4191 ) 4192 4193 def test_runtime_assert_for_prim(self): 4194 class Foo(torch.nn.Module): 4195 def forward(self, x, y): 4196 return x + y 4197 4198 foo = Foo() 4199 tensor_inp = torch.ones(7, 5) 4200 dim0_x = torch.export.Dim("dim0_x", min=6) 4201 dynamic_shapes = {"x": {0: dim0_x}, "y": None} 4202 exported = torch.export.export( 4203 foo, (tensor_inp, 5), dynamic_shapes=dynamic_shapes 4204 ) 4205 self.assertTrue( 4206 torch.allclose( 4207 exported.module()(torch.ones(8, 5), 5), foo(torch.ones(8, 5), 5) 4208 ) 4209 ) 4210 with self.assertRaisesRegex( 4211 RuntimeError, 4212 escape("Expected input at *args[1] to be equal to 5, but got 6"), 4213 ): 4214 _ = exported.module()(torch.ones(8, 5), 6) 4215 4216 exported = torch.export.export( 4217 foo, (tensor_inp, 5.0), dynamic_shapes=dynamic_shapes 4218 ) 4219 with self.assertRaisesRegex( 4220 RuntimeError, 4221 escape("Expected input at *args[1] to be equal to 5.0, but got 6.0"), 4222 ): 4223 _ = exported.module()(torch.ones(7, 5), 6.0) 4224 4225 def test_runtime_assert_for_prm_str(self): 4226 class Foo(torch.nn.Module): 4227 def forward(self, a, b, mode): 4228 return torch.div(a, b, rounding_mode=mode) 4229 4230 foo = Foo() 4231 inps = (torch.randn(4, 4), torch.randn(4), "trunc") 4232 exported = export(foo, inps) 4233 with self.assertRaisesRegex( 4234 RuntimeError, "to be equal to trunc, but got floor" 4235 ): 4236 _ = exported.module()(torch.randn(4, 4), torch.randn(4), "floor") 4237 self.assertTrue(torch.allclose(exported.module()(*inps), foo(*inps))) 4238 4239 def test_redundant_assert_max_upper_bound(self): 4240 class M(torch.nn.Module): 4241 def forward(self, x): 4242 b = x.nonzero() 4243 torch._check(b.shape[0] >= 3) 4244 return b 4245 4246 m = M() 4247 inp = (torch.tensor([1, 1, 1, 0, 1]),) 4248 dim = torch.export.Dim("dim") 4249 ep = export(m, inp, dynamic_shapes=((dim,),)) 4250 FileCheck().check_count( 4251 "torch.ops.aten._assert_scalar.default", 1, exactly=True 4252 ).run(ep.graph_module.code) 4253 4254 def test_to_module_with_mutated_buffer_multiple_update_sub_later(self): 4255 class Bar(torch.nn.Module): 4256 def __init__(self) -> None: 4257 super().__init__() 4258 self.buf = torch.nn.Buffer(torch.ones(1)) 4259 4260 def forward(self, x): 4261 self.buf.add_(1) 4262 return 
x.sum() + self.buf.sum() 4263 4264 class Foo(torch.nn.Module): 4265 def __init__(self) -> None: 4266 super().__init__() 4267 self.buf = torch.nn.Buffer(torch.zeros(1)) 4268 self.bar = Bar() 4269 4270 def forward(self, x): 4271 self.buf.add_(1) 4272 bar = self.bar(x) 4273 self.bar.buf.add_(2) 4274 return bar.sum() + self.buf.sum() 4275 4276 exported = export(Foo(), (torch.ones(5, 5),)) 4277 stateful_gm = exported.module() 4278 export_return_val = stateful_gm(torch.ones(5, 5)) 4279 eager = Foo() 4280 eager_return_val = eager(torch.ones(5, 5)) 4281 self.assertTrue(torch.allclose(eager_return_val, export_return_val)) 4282 4283 for name, buffer in stateful_gm.named_buffers(): 4284 if name == "L__self___buf": 4285 self.assertTrue(torch.allclose(torch.ones(1), buffer)) 4286 if name == "L__self___bar_buf": 4287 self.assertTrue( 4288 torch.allclose(torch.tensor(4, dtype=torch.float), buffer) 4289 ) 4290 4291 changed = stateful_gm.graph.eliminate_dead_code() 4292 self.assertFalse(changed) 4293 self.assertTrue( 4294 torch.allclose(stateful_gm(torch.ones(5, 5)), eager(torch.ones(5, 5))) 4295 ) 4296 4297 for name, buffer in stateful_gm.named_buffers(): 4298 if name == "L__self___buf": 4299 self.assertTrue( 4300 torch.allclose(torch.tensor(2, dtype=torch.float), buffer) 4301 ) 4302 if name == "L__self___bar_buf": 4303 self.assertTrue( 4304 torch.allclose(torch.tensor(7, dtype=torch.float), buffer) 4305 ) 4306 4307 def test_retracable_ep(self): 4308 class Bar(torch.nn.Module): 4309 def __init__(self) -> None: 4310 super().__init__() 4311 self.buf = torch.nn.Buffer(torch.ones(1)) 4312 4313 def forward(self, x): 4314 self.buf.add_(1) 4315 return x.sum() + self.buf.sum() 4316 4317 class Foo(torch.nn.Module): 4318 def __init__(self) -> None: 4319 super().__init__() 4320 self.buf = torch.nn.Buffer(torch.zeros(1)) 4321 self.bar = Bar() 4322 4323 def forward(self, x): 4324 self.buf.add_(1) 4325 bar = self.bar(x) 4326 self.bar.buf.add_(2) 4327 return bar.sum() + self.buf.sum() 4328 4329 inp = torch.ones(5, 5) 4330 exported = torch.export.export(Foo(), (inp,)) 4331 reexported = torch.export.export(exported.module(), (inp,)) 4332 4333 self.assertTrue(torch.allclose(Foo()(inp), reexported.module()(inp))) 4334 4335 dim0_x = torch.export.Dim("dim0_x") 4336 exported = torch.export.export(Foo(), (inp,), dynamic_shapes=({0: dim0_x},)) 4337 reexported = torch.export.export(exported.module(), (inp,)) 4338 with self.assertRaisesRegex( 4339 RuntimeError, "shape\[0\] to be equal to 5, but got 7" 4340 ): 4341 reexported.module()(torch.ones(7, 5)) 4342 4343 reexported = torch.export.export( 4344 exported.module(), (inp,), dynamic_shapes=({0: dim0_x},) 4345 ) 4346 self.assertTrue( 4347 torch.allclose( 4348 Foo()(torch.ones(7, 5)), reexported.module()(torch.ones(7, 5)) 4349 ) 4350 ) 4351 4352 # can't retrace with invalid inputs with respect to the original ExportedProgram 4353 dim0_x_v2 = torch.export.Dim("dim0_x_v2", min=3) 4354 exported_v2 = torch.export.export( 4355 Foo(), (inp,), dynamic_shapes={"x": {0: dim0_x_v2}} 4356 ) 4357 with self.assertRaisesRegex( 4358 RuntimeError, 4359 escape("Expected input at *args[0].shape[0] to be >= 3, but got 2"), 4360 ): 4361 torch.export.export(exported_v2.module(), (torch.randn(2, 2),)) 4362 4363 def test_export_cond_symbool_pred(self): 4364 class A(torch.nn.Module): 4365 def __init__(self) -> None: 4366 super().__init__() 4367 self.buffer = torch.nn.Buffer(torch.ones(6, 4)) 4368 4369 def forward(self): 4370 return self.buffer.cos() 4371 4372 class Foo(torch.nn.Module): 4373 def 
__init__(self) -> None: 4374 super().__init__() 4375 self.a = A() 4376 4377 def forward(self, x): 4378 def true_fn(x): 4379 return x.cos() + self.a().sum() 4380 4381 def false_fn(x): 4382 return x.sin() 4383 4384 return cond(x.shape[0] > 4, true_fn, false_fn, [x]) 4385 4386 dim0 = torch.export.Dim("dim0", min=3) 4387 inp = torch.ones(6, 4) 4388 ep = export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0}}) 4389 schema = get_hop_schema(ep) 4390 self.assertExpectedInline( 4391 str(schema), 4392 """cond(SymBool pred, GraphModule true_fn, GraphModule false_fn, Tensor[2] operands) -> Tensor[1]""", 4393 ) 4394 self.assertExpectedInline( 4395 ep.graph_module.code.strip(), 4396 """\ 4397def forward(self, b_a_buffer, x): 4398 sym_size_int_1 = torch.ops.aten.sym_size.int(x, 0) 4399 gt = sym_size_int_1 > 4; sym_size_int_1 = None 4400 true_graph_0 = self.true_graph_0 4401 false_graph_0 = self.false_graph_0 4402 cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x, b_a_buffer]); gt = true_graph_0 = false_graph_0 = x = b_a_buffer = None 4403 getitem = cond[0]; cond = None 4404 return (getitem,)""", 4405 ) 4406 self.assertTrue( 4407 torch.allclose(ep.module()(torch.ones(6, 4)), Foo()(torch.ones(6, 4))) 4408 ) 4409 4410 def test_aten_lift_fresh_copy(self): 4411 class M(torch.nn.Module): 4412 def forward(self, x): 4413 return torch.ops.aten.lift_fresh_copy(x) 4414 4415 ep = export(M(), (torch.ones(6, 4),)) 4416 found = False 4417 4418 op = "torch.ops.aten.clone.default" 4419 FileCheck().check_count(op, 1, exactly=True).run(ep.graph_module.code) 4420 4421 def test_cond_buffers(self): 4422 class M(torch.nn.Module): 4423 def __init__(self) -> None: 4424 super().__init__() 4425 self.register_parameter( 4426 "param", torch.nn.Parameter(torch.ones(2, 3), requires_grad=False) 4427 ) 4428 self.buffer = torch.nn.Buffer(torch.ones(2, 3) + 1) 4429 4430 def true_fn(self, x): 4431 return x + self.param 4432 4433 def false_fn(self, x): 4434 return x + self.buffer 4435 4436 def forward(self, x): 4437 return cond(x.shape[0] == 4, self.true_fn, self.false_fn, [x]) 4438 4439 inp = torch.ones(2, 3) 4440 ep = torch.export.export(M(), (inp,)) 4441 inp = torch.randn(2, 3) 4442 epm = ep.module() 4443 self.assertTrue(torch.allclose(epm(inp), M()(inp))) 4444 4445 for gm in epm.named_modules(): 4446 if not isinstance(gm, torch.fx.GraphModule): 4447 continue 4448 self.assertEqual( 4449 len([node for node in gm.graph.nodes if node.op == "placeholder"]), 1 4450 ) 4451 4452 # map_fn references module outside the module hierarchy 4453 @unittest.expectedFailure 4454 def test_map_buffers(self): 4455 class M1(torch.nn.Module): 4456 def __init__(self) -> None: 4457 super().__init__() 4458 self.register_parameter( 4459 "param", torch.nn.Parameter(torch.tensor(5), requires_grad=False) 4460 ) 4461 self.buffer = torch.nn.Buffer(torch.tensor(6) + 1) 4462 4463 m1 = M1() 4464 4465 def map_fn(x, y): 4466 z = x + y + m1.param + m1.buffer 4467 z.add_(4) 4468 return z 4469 4470 class M(torch.nn.Module): 4471 def forward(self, xs, y): 4472 return map(map_fn, xs, y) 4473 4474 example_inputs = (torch.ones(3, 2), torch.tensor(3)) 4475 ep = torch.export.export(M(), example_inputs) 4476 example_inputs = (torch.randn(3, 2), torch.tensor(3)) 4477 epm = ep.module() 4478 self.assertTrue(torch.allclose(epm(*example_inputs), M()(*example_inputs))) 4479 4480 for gm in epm.named_modules(): 4481 if not isinstance(gm, torch.fx.GraphModule): 4482 continue 4483 self.assertEqual( 4484 len([node for node in gm.graph.nodes if node.op == "placeholder"]), 2 
4485 ) 4486 4487 def test_check_is_size_error(self): 4488 class Module(torch.nn.Module): 4489 def forward(self, x): 4490 a = x.item() 4491 # We cannot automatically infer a is a size here because view 4492 # accepts -1 4493 return torch.randn(24).view(a, 4) 4494 4495 f = Module() 4496 if is_non_strict_test(self._testMethodName): 4497 error = torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode 4498 else: 4499 error = torch._dynamo.exc.UserError 4500 error_msg = r"Could not guard on data-dependent expression" 4501 with self.assertRaisesRegex(error, error_msg): 4502 _ = export(f, (torch.tensor(6),)) 4503 4504 def test_train_eval_on_exported_preautograd_module(self): 4505 class Foo(torch.nn.Module): 4506 def __init__(self) -> None: 4507 super().__init__() 4508 4509 def forward(self, x): 4510 if x.shape[0] > 4: 4511 return x.cos() 4512 return x.sin() 4513 4514 graph_module = _export(Foo(), (torch.ones(7, 5),), pre_dispatch=True).module() 4515 with self.assertRaisesRegex( 4516 NotImplementedError, r"Calling train\(\) is not supported yet." 4517 ): 4518 graph_module.train() 4519 4520 with self.assertRaisesRegex( 4521 NotImplementedError, r"Calling eval\(\) is not supported yet." 4522 ): 4523 graph_module.eval() 4524 4525 def test_lifted_constants(self) -> None: 4526 class Module(torch.nn.Module): 4527 def forward(self, x): 4528 return x + torch.tensor(3) 4529 4530 f = Module() 4531 ep = export(f, (torch.tensor(1),)) 4532 4533 self.assertEqual(len(ep.graph_signature.input_specs), 2) 4534 self.assertEqual(len(ep.constants), 1) 4535 4536 class Foo(torch.nn.Module): 4537 def __init__(self) -> None: 4538 super().__init__() 4539 self.a = torch.tensor(3) 4540 4541 def forward(self, x): 4542 list_tensor = [torch.tensor(3), torch.tensor(4)] 4543 return x + self.a + list_tensor[0] + list_tensor[1] 4544 4545 ep = export(Foo(), (torch.tensor(1),)) 4546 4547 self.assertEqual(len(ep.graph_signature.input_specs), 4) 4548 self.assertEqual(len(ep.state_dict), 0) 4549 self.assertEqual(len(ep.constants), 3) 4550 4551 inp = (torch.tensor(5),) 4552 self.assertTrue(torch.allclose(ep.module()(*inp), Foo()(*inp))) 4553 4554 transform = ep.run_decompositions() 4555 self.assertEqual(len(ep.graph_signature.input_specs), 4) 4556 self.assertTrue(torch.allclose(ep.module()(*inp), transform.module()(*inp))) 4557 4558 def test_tensor_attribute_zero_args(self): 4559 class Foo(torch.nn.Module): 4560 def __init__(self, value): 4561 super().__init__() 4562 self.x = torch.tensor(value) 4563 4564 def forward(self): 4565 return self.x.clone() 4566 4567 m = Foo([1, 2]) 4568 ep = export(m, ()) 4569 self.assertEqual(ep.graph_signature.lifted_tensor_constants, ["x"]) 4570 4571 def test_preserve_shape_dynamism_for_unused_inputs(self): 4572 @dataclass 4573 class Input: 4574 f: torch.Tensor 4575 p: torch.Tensor 4576 4577 torch._export.utils.register_dataclass_as_pytree_node( 4578 Input, 4579 serialized_type_name="test_preserve_shape_dynamism_for_unused_inputs.Input", 4580 ) 4581 4582 class Module(torch.nn.Module): 4583 def forward(self, x: Input): 4584 return x.f + 1 4585 4586 mod = Module() 4587 example_inputs = (Input(f=torch.ones(10, 4), p=torch.zeros(10, 4)),) 4588 ep_static = torch.export.export(mod, example_inputs) 4589 for node in ep_static.graph.nodes: 4590 if node.op == "placeholder": 4591 for s in node.meta["val"].shape: 4592 self.assertIsInstance(s, int) 4593 4594 dim0_x_f, dim0_x_p = torch.export.dims("dim0_x_f", "dim0_x_p") 4595 dynamic_shapes = {"x": [{0: dim0_x_f}, {0: dim0_x_p}]} 4596 ep_dynamic = 
torch.export.export( 4597 mod, example_inputs, dynamic_shapes=dynamic_shapes 4598 ) 4599 for node in ep_dynamic.graph.nodes: 4600 if node.op == "placeholder": 4601 for i, s in enumerate(node.meta["val"].shape): 4602 if i == 0: 4603 self.assertIsInstance(s, torch.SymInt) 4604 else: 4605 self.assertIsInstance(s, int) 4606 4607 def test_multiple_definitions_same_name_dim(self): 4608 class Foo(torch.nn.Module): 4609 def forward(self, x, y): 4610 return torch.matmul(x, y) 4611 4612 A = torch.export.Dim("C", min=3) 4613 B = torch.export.Dim("C", max=12) 4614 with self.assertRaisesRegex( 4615 torch._dynamo.exc.UserError, 4616 "Found different definitions Dim\\(.*min=3\\) and Dim\\(.*max=12\\) " 4617 "for the same symbolic dimension", 4618 ): 4619 torch.export.export( 4620 Foo(), 4621 (torch.randn(10, 10), torch.randn(10, 10)), 4622 dynamic_shapes={"x": (A, B), "y": (B, A)}, 4623 ) 4624 4625 def test_export_with_wrong_inputs(self): 4626 class MyModule(torch.nn.Module): 4627 def forward(self, x): 4628 return x + x 4629 4630 exported_program = export(MyModule(), (torch.rand(2, 3),), {}) 4631 with self.assertRaisesRegex(ValueError, "Trying to flatten user inputs"): 4632 exported_program.module()(torch.rand(2, 3), torch.rand(2, 3)) 4633 4634 def test_export_decomps_simple(self): 4635 class M(torch.nn.Module): 4636 def __init__(self) -> None: 4637 super().__init__() 4638 self.lin = torch.nn.Linear(10, 1) 4639 4640 def forward(self, x): 4641 return self.lin(x) 4642 4643 inp = (torch.randn(5, 10),) 4644 m = M() 4645 ep = export(m, inp) 4646 state_dict = ep.state_dict 4647 4648 self.assertTrue(torch.allclose(ep.module()(*inp), m(*inp))) 4649 4650 core_aten_ep = ep.run_decompositions() 4651 FileCheck().check_count("torch.ops.aten.permute.default", 1, exactly=True).run( 4652 core_aten_ep.graph_module.code 4653 ) 4654 FileCheck().check_count("torch.ops.aten.t.default", 0, exactly=True).run( 4655 core_aten_ep.graph_module.code 4656 ) 4657 self.assertTrue(torch.allclose(core_aten_ep.module()(*inp), m(*inp))) 4658 self.assertEqual(id(state_dict), id(ep.state_dict)) 4659 4660 def test_export_decomps_dynamic(self): 4661 class M(torch.nn.Module): 4662 def __init__(self) -> None: 4663 super().__init__() 4664 self.lin = torch.nn.Linear(10, 1) 4665 4666 def forward(self, x): 4667 return self.lin(x) 4668 4669 inp = (torch.randn(5, 10),) 4670 m = M() 4671 ep = export(m, inp, dynamic_shapes={"x": {0: Dim("batch")}}) 4672 4673 core_aten_ep = ep.run_decompositions() 4674 4675 input_node = [ 4676 node for node in core_aten_ep.graph.nodes if node.op == "placeholder" 4677 ][-1] 4678 self.assertTrue(isinstance(input_node.meta["val"].shape[0], torch.SymInt)) 4679 4680 FileCheck().check_count("torch.ops.aten.permute.default", 1, exactly=True).run( 4681 core_aten_ep.graph_module.code 4682 ) 4683 FileCheck().check_count("torch.ops.aten.t.default", 0, exactly=True).run( 4684 core_aten_ep.graph_module.code 4685 ) 4686 self.assertTrue(torch.allclose(core_aten_ep.module()(*inp), m(*inp))) 4687 4688 def test_nonzero_2(self): 4689 class Module(torch.nn.Module): 4690 def forward(self, x): 4691 return torch.nonzero(x) 4692 4693 f = Module() 4694 ep = export(f, (torch.ones(2),)) 4695 inp = torch.randn(2) 4696 self.assertTrue(torch.allclose(ep.module()(inp), torch.nonzero(inp))) 4697 4698 def test_redundant_asserts(self): 4699 class Foo(torch.nn.Module): 4700 def forward(self, x): 4701 y = x.item() 4702 torch._check_is_size(y) 4703 return torch.zeros(y) 4704 4705 f = Foo() 4706 4707 ep = export(f, (torch.tensor([3]),)) 4708 4709 
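        # _check_is_size on an unbacked .item() value is expected to lower to a single
        # sym_constrain_range_for_size plus a single _assert_scalar; the checks below
        # (and the re-checks after run_decompositions) guard against duplicating them.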
FileCheck().check_count( 4710 "torch.ops.aten.sym_constrain_range_for_size.default", 1, exactly=True 4711 ).run(ep.graph_module.code) 4712 FileCheck().check_count( 4713 "torch.ops.aten._assert_scalar.default", 1, exactly=True 4714 ).run(ep.graph_module.code) 4715 4716 ep = ep.run_decompositions() 4717 4718 FileCheck().check_count( 4719 "torch.ops.aten.sym_constrain_range_for_size.default", 1, exactly=True 4720 ).run(ep.graph_module.code) 4721 FileCheck().check_count( 4722 "torch.ops.aten._assert_scalar.default", 1, exactly=True 4723 ).run(ep.graph_module.code) 4724 4725 def test_non_arg_name_dynamic_shapes_api(self): 4726 class Foo(torch.nn.Module): 4727 def forward(self, a, b): 4728 return a.sum() + b.sum() 4729 4730 foo = Foo() 4731 dim = torch.export.Dim("dim") 4732 ep = torch.export.export( 4733 foo, 4734 (torch.randn(4, 4), torch.randn(4, 4)), 4735 dynamic_shapes=(None, {0: dim}), 4736 ) 4737 4738 test_inp = (torch.randn(4, 4), torch.randn(7, 4)) 4739 self.assertEqual(ep.module()(*test_inp), foo(*test_inp)) 4740 4741 ep_v2 = torch.export.export( 4742 foo, 4743 (torch.randn(4, 4), torch.randn(4, 4)), 4744 dynamic_shapes=(None, None), 4745 ) 4746 with self.assertRaisesRegex( 4747 RuntimeError, "shape\[0\] to be equal to 4, but got 7" 4748 ): 4749 ep_v2.module()(*test_inp) 4750 4751 def test_constant_output(self): 4752 class ModuleConstant(torch.nn.Module): 4753 def __init__(self) -> None: 4754 super().__init__() 4755 self.b = torch.randn(3, 2) 4756 4757 def forward(self): 4758 return self.b 4759 4760 class ModuleNestedConstant(torch.nn.Module): 4761 def __init__(self) -> None: 4762 super().__init__() 4763 self.bff = torch.randn(3, 2) 4764 4765 def forward(self, x, y): 4766 return {"prediction": (x + y, self.bff)} 4767 4768 mod = ModuleConstant() 4769 ep = torch.export.export(mod, ()) 4770 self.assertEqual(ep.module()(), mod()) 4771 4772 args = (torch.randn(3, 2), torch.randn(3, 2)) 4773 mod = ModuleNestedConstant() 4774 ep = torch.export.export(mod, args) 4775 self.assertEqual(ep.module()(*args), mod(*args)) 4776 4777 def test_non_arg_name_dynamic_shapes_api_with_kwarg(self): 4778 class Foo(torch.nn.Module): 4779 def forward(self, a, b, kw1, kw2): 4780 return a.sum() + b.sum() + kw1.sum() - kw2.sum() 4781 4782 foo = Foo() 4783 dim = torch.export.Dim("dim") 4784 dim_for_kw1 = torch.export.Dim("dim_for_kw1") 4785 ep = torch.export.export( 4786 foo, 4787 (torch.randn(4, 4), torch.randn(4, 4)), 4788 {"kw2": torch.ones(4, 4), "kw1": torch.zeros(4, 4)}, 4789 # We are specifying dynamism on the first kwarg even though user passed in 4790 # different order 4791 dynamic_shapes=(None, {0: dim}, {0: dim_for_kw1}, None), 4792 ) 4793 4794 test_inp = (torch.randn(4, 4), torch.randn(7, 4)) 4795 test_kwargs = {"kw2": torch.ones(4, 4), "kw1": torch.zeros(9, 4)} 4796 # This should work even if the kwarg order are flipped. 
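        # (Illustrative sketch, not used by the test: assuming the dict form of
        # dynamic_shapes keyed by parameter name is equivalent to the tuple form above,
        # the spec pairs `dim` with `b` and `dim_for_kw1` with `kw1` regardless of the
        # order in which the kwargs were passed.)
        _equivalent_spec = {"a": None, "b": {0: dim}, "kw1": {0: dim_for_kw1}, "kw2": None}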
4797 self.assertEqual( 4798 ep.module()(*test_inp, **test_kwargs), foo(*test_inp, **test_kwargs) 4799 ) 4800 4801 def test_non_arg_name_dynamic_shapes_api_with_container_type(self): 4802 class Foo(torch.nn.Module): 4803 def forward(self, a, b): 4804 return a[0].sum() + a[1].sum() + b.sum() 4805 4806 inp_a = (torch.randn(4, 4), torch.randn(4, 4)) 4807 inp_b = torch.randn(4, 4) 4808 inp = (inp_a, inp_b) 4809 4810 count = 0 4811 4812 def dynamify_inp(x): 4813 # Mark the second input a[1] dynamic 4814 nonlocal count 4815 if count == 1: 4816 dim = torch.export.Dim("dim", min=3) 4817 count += 1 4818 return {0: dim} 4819 count += 1 4820 return None 4821 4822 dynamic_shapes = tree_map(dynamify_inp, inp) 4823 4824 foo = Foo() 4825 ep = torch.export.export(foo, inp, dynamic_shapes=dynamic_shapes) 4826 4827 test_inp = ((torch.randn(4, 4), torch.randn(2, 4)), torch.randn(4, 4)) 4828 with self.assertRaisesRegex(RuntimeError, "shape\[0\] to be >= 3, but got 2"): 4829 ep.module()(*test_inp) 4830 4831 def test_nested_module(self): 4832 class M1(torch.nn.Module): 4833 def forward(self, x): 4834 return x + x 4835 4836 class M2(torch.nn.Module): 4837 def forward(self, x): 4838 m = M1() 4839 return m(x) * x 4840 4841 inps = (torch.randn(3, 3),) 4842 ep = export(M2(), inps) 4843 self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps))) 4844 4845 add_nodes = [ 4846 node 4847 for node in ep.graph.nodes 4848 if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor 4849 ] 4850 self.assertEqual(len(add_nodes), 1) 4851 add_node = add_nodes[0] 4852 self.assertEqual(len(add_node.meta["nn_module_stack"]), 1) 4853 self.assertTrue("M2" in list(add_node.meta["nn_module_stack"].values())[0][1]) 4854 4855 self.assertExpectedInline( 4856 str(ep.graph).strip(), 4857 """\ 4858graph(): 4859 %x : [num_users=2] = placeholder[target=x] 4860 %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %x), kwargs = {}) 4861 %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {}) 4862 return (mul,)""", 4863 ) 4864 4865 unflattened = unflatten(ep) 4866 self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps))) 4867 4868 def test_nested_module_with_init_buffer(self): 4869 class M1(torch.nn.Module): 4870 def __init__(self) -> None: 4871 super().__init__() 4872 self.b = torch.ones(3, 3) 4873 4874 def forward(self, x): 4875 return x + self.b 4876 4877 class M2(torch.nn.Module): 4878 def forward(self, x): 4879 m = M1() 4880 return m(x) * x 4881 4882 inps = (torch.randn(3, 3),) 4883 ep = export(M2(), inps) 4884 self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps))) 4885 4886 self.assertEqual(len(ep.state_dict), 0) 4887 self.assertEqual(len(ep.constants), 0) 4888 4889 self.assertExpectedInline( 4890 str(ep.graph).strip(), 4891 """\ 4892graph(): 4893 %x : [num_users=2] = placeholder[target=x] 4894 %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([3, 3],), kwargs = {device: cpu, pin_memory: False}) 4895 %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %ones), kwargs = {}) 4896 %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {}) 4897 return (mul,)""", 4898 ) 4899 4900 unflattened = unflatten(ep) 4901 self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps))) 4902 4903 @testing.expectedFailureRetraceability # Retracing tensor constants results in buffers 4904 def test_nested_module_with_constant_buffer(self): 4905 
class M1(torch.nn.Module): 4906 def __init__(self) -> None: 4907 super().__init__() 4908 self.b = torch.tensor(5) 4909 4910 def forward(self, x): 4911 return x + self.b 4912 4913 class M2(torch.nn.Module): 4914 def forward(self, x): 4915 m = M1() 4916 return m(x) * x 4917 4918 inps = (torch.randn(3, 3),) 4919 ep = export(M2(), inps) 4920 self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps))) 4921 4922 self.assertEqual(len(ep.state_dict), 0) 4923 self.assertEqual(len(ep.constants), 1) 4924 4925 if is_training_ir_test(self._testMethodName): 4926 self.assertExpectedInline( 4927 str(ep.graph).strip(), 4928 """\ 4929graph(): 4930 %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0] 4931 %x : [num_users=2] = placeholder[target=x] 4932 %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {}) 4933 %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %lift_fresh_copy), kwargs = {}) 4934 %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {}) 4935 return (mul,)""", 4936 ) 4937 else: 4938 self.assertExpectedInline( 4939 str(ep.graph).strip(), 4940 """\ 4941graph(): 4942 %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0] 4943 %x : [num_users=2] = placeholder[target=x] 4944 %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {}) 4945 %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%lift_fresh_copy,), kwargs = {}) 4946 %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %detach), kwargs = {}) 4947 %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {}) 4948 return (mul,)""", 4949 ) 4950 4951 unflattened = unflatten(ep) 4952 self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps))) 4953 4954 def test_nested_module_with_parameter(self): 4955 class M1(torch.nn.Module): 4956 def __init__(self) -> None: 4957 super().__init__() 4958 self.a = torch.nn.Parameter(torch.ones(3, 3)) 4959 self.b = torch.nn.Parameter(torch.tensor(5.0)) 4960 4961 def forward(self, x): 4962 return x + self.a * self.b 4963 4964 class M2(torch.nn.Module): 4965 def forward(self, x): 4966 m = M1() 4967 return m(x) * x 4968 4969 inps = (torch.randn(3, 3),) 4970 # Strict export segfaults (Issue #128109) 4971 ep = torch.export.export(M2(), inps, strict=False) 4972 self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps))) 4973 4974 self.assertEqual(len(ep.state_dict), 0) 4975 self.assertEqual(len(ep.constants), 1) 4976 4977 self.assertExpectedInline( 4978 str(ep.graph).strip(), 4979 """\ 4980graph(): 4981 %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0] 4982 %x : [num_users=2] = placeholder[target=x] 4983 %ones : [num_users=1] = call_function[target=torch.ops.aten.ones.default](args = ([3, 3],), kwargs = {device: cpu, pin_memory: False}) 4984 %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%ones,), kwargs = {}) 4985 %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {}) 4986 %detach_1 : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%lift_fresh_copy,), kwargs = {}) 4987 %detach_2 : [num_users=1] = 
call_function[target=torch.ops.aten.detach.default](args = (%detach_1,), kwargs = {}) 4988 %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%detach, %detach_2), kwargs = {}) 4989 %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %mul), kwargs = {}) 4990 %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %x), kwargs = {}) 4991 return (mul_1,)""", 4992 ) 4993 4994 unflattened = unflatten(ep) 4995 self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps))) 4996 4997 def test_lazy_module_kwargs(self): 4998 class LazyModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module): 4999 def initialize_parameters(self, *args, **kwargs): 5000 pass 5001 5002 def forward(self, x, y): 5003 return x + y 5004 5005 m = LazyModule() 5006 ep = torch.export.export( 5007 m, (), {"x": torch.randn(3, 3), "y": torch.randn(3, 3)} 5008 ) 5009 inputs = {"x": torch.randn(3, 3), "y": torch.randn(3, 3)} 5010 self.assertEqual(ep.module()(**inputs), m(**inputs)) 5011 5012 def test_retrace_pre_autograd(self): 5013 class Foo(torch.nn.Module): 5014 def __init__(self) -> None: 5015 super().__init__() 5016 self.buffer = torch.nn.Buffer(torch.ones(4, 4)) 5017 5018 def forward(self, x): 5019 self.buffer.add_(4) 5020 return x.sum() + self.buffer.sum() 5021 5022 inp = torch.randn(4, 4) 5023 gm = _export( 5024 Foo(), 5025 (inp,), 5026 dynamic_shapes=({0: torch.export.Dim("dim", min=3)},), 5027 pre_dispatch=True, 5028 ).module() 5029 5030 with self.assertRaisesRegex( 5031 RuntimeError, escape("Expected input at *args[0].shape[0]") 5032 ): 5033 gm(torch.randn(2, 2)) 5034 5035 with self.assertRaisesRegex( 5036 RuntimeError, escape("Expected input at *args[0].shape[0]") 5037 ): 5038 torch.export.export(gm, (torch.randn(2, 2),)) 5039 5040 ep = torch.export.export( 5041 gm, 5042 (torch.randn(5, 4),), 5043 dynamic_shapes=({0: torch.export.Dim("dim", min=3)},), 5044 ) 5045 5046 test_inp = torch.ones(8, 4) 5047 self.assertTrue(torch.allclose(ep.module()(test_inp), Foo().forward(test_inp))) 5048 5049 def test_runtime_assert_with_size(self): 5050 class M(torch.nn.Module): 5051 def forward(self, x, y): 5052 a = x.item() 5053 torch._check_is_size(a) 5054 torch._check(a <= y.size(0)) 5055 return y[:a] 5056 5057 ep = export( 5058 M(), 5059 (torch.tensor(5), torch.ones(10)), 5060 dynamic_shapes={"x": None, "y": {0: torch.export.Dim("t")}}, 5061 ) 5062 inp = (torch.tensor(6), torch.randn(13)) 5063 self.assertTrue(torch.allclose(ep.module()(*inp), M()(*inp))) 5064 5065 @unittest.skip("Test is only supposed to work with non-strict mode") 5066 def test_issue_113041(self): 5067 class TestModule(torch.nn.Module): 5068 def __init__(self) -> None: 5069 super().__init__() 5070 self.a = torch.tensor(1.0) 5071 5072 def forward(self, x: torch.Tensor) -> torch.Tensor: 5073 return x + self.a 5074 5075 def forward_hook(module: torch.nn.Module, inputs, output) -> torch.Tensor: 5076 return 2 * output 5077 5078 seq = torch.nn.Sequential(TestModule()).eval() 5079 seq.b = torch.tensor(2) 5080 handle = seq.register_forward_hook(forward_hook) 5081 5082 class M(torch.nn.Module): 5083 def __init__(self) -> None: 5084 super().__init__() 5085 self.seq = seq 5086 5087 def forward(self, x): 5088 return self.seq(x) + self.seq.b 5089 5090 inp = (torch.randn(2, 8),) 5091 ep = export(M(), inp) # This errors because dynamo adds an extra input 5092 5093 def test_export_with_fake_tensor_inputs(self): 5094 fake_mode = torch._subclasses.fake_tensor.FakeTensorMode() 5095 5096 
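        # Creating both the module and its input under FakeTensorMode (on the meta
        # device) lets export trace the program without allocating real storage; the
        # assertions below then check that every node's example value keeps the fake
        # device and shares a single fake mode.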
class Model(torch.nn.Module): 5097 def __init__(self) -> None: 5098 super().__init__() 5099 self.linear = torch.nn.Linear(2, 2) 5100 5101 def forward(self, x): 5102 out = self.linear(x) 5103 return out 5104 5105 # Put the inputs on a device 5106 with fake_mode, torch.device("meta"): 5107 x = torch.rand(5, 2, 2) 5108 model = Model() 5109 5110 exported_program = torch.export.export(model, (x,)) 5111 export_res = exported_program.module()(x) 5112 exp_res = model(x) 5113 all_meta_val = [ 5114 node.meta["val"] 5115 for node in exported_program.graph_module.graph.nodes 5116 if "val" in node.meta 5117 ] 5118 self.assertTrue(export_res.size() == exp_res.size()) 5119 self.assertTrue(all(val.device == x.device for val in all_meta_val)) 5120 self.assertTrue( 5121 all(val.fake_mode is all_meta_val[0].fake_mode for val in all_meta_val) 5122 ) 5123 decomposed_ep = exported_program.run_decompositions() 5124 export_res = decomposed_ep.module()(x) 5125 self.assertTrue(export_res.size() == exp_res.size()) 5126 5127 def test_export_with_fake_tensor_inputs_on_cuda_devices(self): 5128 fake_mode = torch._subclasses.fake_tensor.FakeTensorMode() 5129 5130 class Model(torch.nn.Module): 5131 def __init__(self) -> None: 5132 super().__init__() 5133 self.linear = torch.nn.Linear(2, 2) 5134 5135 def forward(self, x): 5136 out = self.linear(x) 5137 return out 5138 5139 # Put the inputs on a device 5140 with fake_mode, torch.device("meta"): 5141 x = torch.rand(5, 2, 2) 5142 model = Model() 5143 5144 # Manually set the fake_device of fake tensors. 5145 x.fake_device = torch.device("cuda:0") 5146 for n, p in model.named_parameters(): 5147 p.fake_device = torch.device("cuda:0") 5148 5149 # Need to set requires_grad to False on all the tensors, because fake tensors with a CUDA device 5150 # don't work well with aot_autograd right now: some logic fails 5151 # the getDeviceGuardImpl check in InputMetadata.
5152 x.requires_grad = False 5153 for n, p in model.named_parameters(): 5154 p.requires_grad = False 5155 5156 def check_device_and_fake_mode(): 5157 exported_program = torch.export.export(model, (x,)) 5158 export_res = exported_program.module()(x) 5159 exp_res = model(x) 5160 all_meta_val = [ 5161 node.meta["val"] 5162 for node in exported_program.graph_module.graph.nodes 5163 if "val" in node.meta 5164 ] 5165 self.assertTrue(export_res.size() == exp_res.size()) 5166 self.assertTrue(all(val.device == x.device for val in all_meta_val)) 5167 self.assertTrue( 5168 all(val.fake_mode is all_meta_val[0].fake_mode for val in all_meta_val) 5169 ) 5170 5171 check_device_and_fake_mode() 5172 5173 def test_run_decomposition_supports_user_input_mutation(self): 5174 class SingleOp(torch.nn.Module): 5175 def __init__(self) -> None: 5176 super().__init__() 5177 self.op = torch.ops.aten.native_batch_norm 5178 5179 def forward( 5180 self, 5181 input, 5182 weight, 5183 bias, 5184 running_mean, 5185 running_var, 5186 training, 5187 momentum, 5188 eps, 5189 **kwargs, 5190 ): 5191 return self.op( 5192 input, 5193 weight, 5194 bias, 5195 running_mean, 5196 running_var, 5197 training, 5198 momentum, 5199 eps, 5200 **kwargs, 5201 ) 5202 5203 input = torch.randn(5, 5, 5) 5204 weight = torch.randn(5) 5205 bias = torch.randn(5) 5206 running_mean = torch.randn(5) 5207 running_var = torch.randn(5) 5208 training = True 5209 momentum = 0.5 5210 eps = 0.6 5211 5212 model = SingleOp() 5213 output = model( 5214 input, weight, bias, running_mean, running_var, training, momentum, eps 5215 ) 5216 5217 ep = torch.export.export( 5218 model, 5219 args=( 5220 input, 5221 weight, 5222 bias, 5223 running_mean, 5224 running_var, 5225 training, 5226 momentum, 5227 eps, 5228 ), 5229 ) 5230 ep.run_decompositions(decomp_table=torch._decomp.decomposition_table) 5231 self.assertEqual( 5232 ep.module()( 5233 input, weight, bias, running_mean, running_var, training, momentum, eps 5234 ), 5235 output, 5236 ) 5237 5238 def test_export_graph_with_no_inputs(self): 5239 # We saw this pattern when users want to export 5240 # a graph that initlizes the states of a model. 
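        # A minimal sketch of that pattern (illustrative only, not part of the
        # original test): a zero-argument module can be exported directly, and its
        # states are produced by calling the exported module with no inputs.
        class _InitStates(torch.nn.Module):
            def forward(self):
                return torch.zeros(2, 2), torch.ones(2, 2)

        _state_a, _state_b = torch.export.export(_InitStates(), ()).module()()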
5241 class Module(torch.nn.Module): 5242 def forward(self): 5243 return torch.randn(3, 4), torch.randn(3, 4) 5244 5245 f = Module() 5246 ep = torch.export.export(f, ()) 5247 a, b = ep.module()() 5248 self.assertEqual(a.size(), torch.Size([3, 4])) 5249 self.assertEqual(b.size(), torch.Size([3, 4])) 5250 5251 # Contains unbacked symint 5252 class M(torch.nn.Module): 5253 def forward(self): 5254 full = torch.full((), 11) 5255 i0 = full.item() 5256 return (torch.full((i0,), 0.0),) 5257 5258 f = M() 5259 ep = torch.export.export(f, ()) 5260 a = ep.module()()[0] 5261 self.assertEqual(a.size(), torch.Size([11])) 5262 self.assertEqual(a, torch.zeros(11)) 5263 5264 def test_pad_sequence(self): 5265 class Module(torch.nn.Module): 5266 def forward(self, x): 5267 return torch._C._nn.pad_sequence([x]) 5268 5269 m0 = Module() 5270 inputs = (torch.randn(3, 2),) 5271 ep = torch.export.export( 5272 m0, inputs, dynamic_shapes={"x": {0: Dim("batch_size")}} 5273 ) 5274 self.assertEqual(ep.module()(*inputs), m0(*inputs)) 5275 5276 class ModuleBatchFirst(torch.nn.Module): 5277 def forward(self, x): 5278 return torch._C._nn.pad_sequence([x], batch_first=True) 5279 5280 m1 = ModuleBatchFirst() 5281 inputs = (torch.randn(3, 2),) 5282 ep = torch.export.export( 5283 m1, inputs, dynamic_shapes={"x": {0: Dim("batch_size")}} 5284 ) 5285 self.assertEqual(ep.module()(*inputs), m1(*inputs)) 5286 5287 class ModuleMulti(torch.nn.Module): 5288 def forward(self, x, y, z): 5289 return torch._C._nn.pad_sequence([x, y, z]) 5290 5291 m2 = ModuleMulti() 5292 inputs = (torch.randn(5, 2), torch.randn(4, 2), torch.randn(3, 2)) 5293 ep = torch.export.export( 5294 m2, 5295 inputs, 5296 dynamic_shapes={ 5297 "x": {0: Dim("batch_size")}, 5298 "y": {0: Dim("y")}, 5299 "z": {0: Dim("z")}, 5300 }, 5301 ) 5302 self.assertEqual(ep.module()(*inputs), m2(*inputs)) 5303 5304 class ModuleMultiBatchFirst(torch.nn.Module): 5305 def forward(self, x, y, z): 5306 return torch._C._nn.pad_sequence([x, y, z], batch_first=True) 5307 5308 m3 = ModuleMulti() 5309 inputs = (torch.randn(5, 2), torch.randn(4, 2), torch.randn(3, 2)) 5310 ep = torch.export.export( 5311 m2, 5312 inputs, 5313 dynamic_shapes={ 5314 "x": {0: Dim("batch_size")}, 5315 "y": {0: Dim("y")}, 5316 "z": {0: Dim("z")}, 5317 }, 5318 ) 5319 self.assertEqual(ep.module()(*inputs), m3(*inputs)) 5320 5321 def test_export_then_compile_tensor_ctor(self): 5322 class M(torch.nn.Module): 5323 def forward(self, scores, mask): 5324 scores = scores.masked_fill( 5325 mask, torch.tensor(torch.finfo(scores.dtype).min) 5326 ) # (bs, n_heads, q_length, k_length) 5327 return scores 5328 5329 tensor_cpu = torch.randn(2, 4) 5330 mask_cpu = torch.BoolTensor( 5331 [[False, True, False, False], [False, False, False, False]] 5332 ) 5333 5334 m = M().eval() 5335 # res_ref = m(tensor_cpu, mask_cpu) 5336 # print("res_ref is: {}".format(res_ref), flush=True) 5337 5338 exported_model = _export(m, (tensor_cpu, mask_cpu), pre_dispatch=True).module() 5339 optimized_model = torch.compile(exported_model) 5340 optimized_model(tensor_cpu, mask_cpu) 5341 5342 def test_export_input_mutation_static_shape(self): 5343 class MutationModel(torch.nn.Module): 5344 def forward(self, x, y): 5345 x.view(3, 2, -1).add_(y) 5346 return x 5347 5348 inputs = (torch.randn(12), torch.tensor(2)) 5349 model = MutationModel() 5350 ep = export(model, inputs) 5351 inputs_export = copy.deepcopy(inputs) 5352 inputs_model = copy.deepcopy(inputs) 5353 self.assertEqual(ep.module()(*inputs_export), model(*inputs_model)) 5354 self.assertEqual(inputs[0] + 
torch.tensor(2), inputs_model[0]) 5355 self.assertEqual(inputs[0] + torch.tensor(2), inputs_export[0]) 5356 5357 def test_export_input_mutation_dynamic_shape(self): 5358 class MutationModel(torch.nn.Module): 5359 def forward(self, x, y): 5360 x[0].mul_(y) 5361 return x 5362 5363 inputs = ((torch.randn(12), torch.randn(3, 2)), 2.0) 5364 model = MutationModel() 5365 ep = torch.export.export( 5366 model, 5367 inputs, 5368 dynamic_shapes={"x": ({0: torch.export.Dim("dim")}, None), "y": None}, 5369 ) 5370 nodes = list(ep.graph.nodes) 5371 self.assertEqual(nodes[0].op, "placeholder") 5372 self.assertIsInstance(nodes[0].meta["val"], torch.Tensor) 5373 self.assertIsInstance(nodes[0].meta["val"].shape[0], torch.SymInt) 5374 5375 inputs_export = copy.deepcopy(inputs) 5376 inputs_model = copy.deepcopy(inputs) 5377 self.assertEqual(ep.module()(*inputs_export), model(*inputs_model)) 5378 self.assertEqual(inputs[0][0] * 2.0, inputs_model[0][0]) 5379 self.assertEqual(inputs[0][0] * 2.0, inputs_export[0][0]) 5380 5381 def test_export_input_mutation_bug(self): 5382 class M(torch.nn.Module): 5383 def forward(self, x): 5384 x[:, :2, :] = x[:, :2, :] + 1 5385 return x 5386 5387 inputs = (torch.ones(4, 4, 4),) 5388 ep = torch.export.export(M(), inputs) 5389 m = ep.module() 5390 5391 # Make the name conflict with a placeholder name that we get from 5392 # aot_export 5393 for i, node in enumerate(m.graph.nodes): 5394 if node.op == "placeholder": 5395 node.name = f"arg0_{i + 1}" 5396 m.recompile() 5397 5398 ep = torch.export.export(m, inputs) 5399 5400 inputs = (torch.randn(4, 4, 4),) 5401 self.assertEqual( 5402 ep.module()(*copy.deepcopy(inputs)), M()(*copy.deepcopy(inputs)) 5403 ) 5404 5405 def test__scaled_dot_product_flash_attention(self): 5406 class Module(torch.nn.Module): 5407 def forward(self, q, k, v): 5408 res = torch.nn.functional.scaled_dot_product_attention(q, k, v) 5409 return res[0] 5410 5411 m = Module() 5412 inputs = ( 5413 torch.randn(5, 4, 3, 2), 5414 torch.randn(5, 4, 3, 2), 5415 torch.randn(5, 4, 3, 2), 5416 ) 5417 ep = export(m, inputs) 5418 self.assertEqual(ep.module()(*inputs), m(*inputs)) 5419 5420 @testing.expectedFailureSerDer # symfloat nyi 5421 def test_sym_sqrt(self): 5422 import math 5423 5424 class M(torch.nn.Module): 5425 def forward(self, x): 5426 return x / torch.sym_sqrt(x.shape[0]) 5427 5428 ep = export(M(), (torch.ones(16, 4),), dynamic_shapes={"x": {0: Dim("dim")}}) 5429 _ExportPassBaseDeprecatedDoNotUse()(ep.graph_module) 5430 FileCheck().check_count("torch._sym_sqrt", 1, exactly=True).run( 5431 ep.graph_module.code 5432 ) 5433 5434 def test_check_specialized_int(self): 5435 class SingleOp(torch.nn.Module): 5436 def __init__(self) -> None: 5437 super().__init__() 5438 self.op = torch.ops.aten.scatter_add 5439 5440 def forward(self, t, dim, index, src, **kwargs): 5441 return self.op(t, dim, index, src, **kwargs) 5442 5443 t = torch.randn(10, 5) 5444 dim = -1 5445 index = torch.tensor( 5446 [ 5447 [2, 4, 3, 1, 0], 5448 [0, 2, 1, 4, 3], 5449 [3, 1, 4, 2, 0], 5450 [4, 0, 3, 1, 2], 5451 [3, 0, 4, 1, 2], 5452 ] 5453 ) 5454 src = torch.randn(5, 5) 5455 5456 model = SingleOp() 5457 output = model(t, dim, index, src) 5458 5459 ep = torch.export.export(model, args=(t, dim, index, src)) 5460 ep.run_decompositions(decomp_table=torch._decomp.decomposition_table) 5461 self.assertEqual(ep.module()(t, dim, index, src), output) 5462 5463 def test_fqn(self): 5464 class NestedChild(torch.nn.Module): 5465 def forward(self, x): 5466 return x / x 5467 5468 class Child1(torch.nn.Module): 5469 def 
__init__(self) -> None: 5470 super().__init__() 5471 self.nested = NestedChild() 5472 self.register_parameter( 5473 "child1param", torch.nn.Parameter(torch.ones(2, 3)) 5474 ) 5475 5476 def forward(self, x): 5477 x = self.nested(x) 5478 return x + self.child1param 5479 5480 class Child2(torch.nn.Module): 5481 def __init__(self) -> None: 5482 super().__init__() 5483 self.child2buffer = torch.nn.Buffer(torch.ones(2, 3)) 5484 5485 def forward(self, x): 5486 return x - self.child2buffer 5487 5488 class MyModule(torch.nn.Module): 5489 def __init__(self) -> None: 5490 super().__init__() 5491 self.foo = Child1() 5492 self.bar = Child2() 5493 self.register_parameter( 5494 "rootparam", torch.nn.Parameter(torch.ones(2, 3)) 5495 ) 5496 5497 def forward(self, x): 5498 x = x * self.rootparam 5499 x = self.foo(x) 5500 x = self.bar(x) 5501 return x 5502 5503 orig_eager = MyModule() 5504 test_inp = torch.randn(2, 3) 5505 5506 torch_gm = _export_to_torch_ir(orig_eager, (torch.rand(2, 3),), {}) 5507 for k, v in orig_eager.state_dict().items(): 5508 normalized_k = k.replace(".", "_") 5509 self.assertIn(normalized_k, torch_gm.state_dict()) 5510 self.assertEqual(v, torch_gm.state_dict()[normalized_k]) 5511 self.assertTrue(torch.allclose(torch_gm(test_inp), orig_eager(test_inp))) 5512 5513 pre_autograd_gm = torch.export._trace._export( 5514 orig_eager, (torch.rand(2, 3),), {}, pre_dispatch=True 5515 ).module() 5516 for k, v in orig_eager.state_dict().items(): 5517 normalized_k = k.replace(".", "_") 5518 self.assertIn(k, pre_autograd_gm.state_dict()) 5519 self.assertEqual(v, pre_autograd_gm.state_dict()[k]) 5520 self.assertTrue(torch.allclose(pre_autograd_gm(test_inp), orig_eager(test_inp))) 5521 5522 ep = export(orig_eager, (torch.rand(2, 3),), {}) 5523 for k, v in orig_eager.state_dict().items(): 5524 # We do not need to normalize the key here because the exported 5525 # program's state dict preserves the fully qualified submodule names.
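            # (For reference, matching the assertions in this test: _export_to_torch_ir
            # flattens "foo.child1param" to "foo_child1param", while the pre-dispatch
            # module and ExportedProgram.state_dict keep the dotted FQN.)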
5526 self.assertIn(k, ep.state_dict) 5527 self.assertEqual(v, ep.state_dict[k]) 5528 self.assertTrue(torch.allclose(ep.module()(test_inp), orig_eager(test_inp))) 5529 5530 def test_nn_module_stack(self): 5531 class Leaf(torch.nn.Module): 5532 def __init__(self) -> None: 5533 super().__init__() 5534 self.linear = torch.nn.Linear(4, 4) 5535 5536 def forward(self, x): 5537 return self.linear(x) 5538 5539 class Bar(torch.nn.Module): 5540 def __init__(self) -> None: 5541 super().__init__() 5542 self.leaf = Leaf() 5543 self.buffer = torch.nn.Buffer(torch.randn(4, 4)) 5544 5545 def forward(self, x): 5546 return self.buffer.sum() + self.leaf(x).sum() 5547 5548 class Foo(torch.nn.Module): 5549 def __init__(self) -> None: 5550 super().__init__() 5551 self.bar = Bar() 5552 5553 def forward(self, x): 5554 y = self.bar.buffer + x 5555 return (self.bar(x) + y.sum(),) 5556 5557 inp = (torch.randn(4, 4),) 5558 mod = Foo() 5559 ep_strict = torch.export.export(mod, inp).run_decompositions() 5560 ep_non_strict = torch.export.export(mod, inp, strict=False).run_decompositions() 5561 5562 gm_unflat_non_strict = unflatten(ep_non_strict) 5563 self.assertTrue(hasattr(gm_unflat_non_strict, "bar")) 5564 self.assertTrue(hasattr(gm_unflat_non_strict.bar, "buffer")) 5565 self.assertTrue(hasattr(gm_unflat_non_strict.bar, "leaf")) 5566 5567 gm_unflat_strict = unflatten(ep_strict) 5568 5569 self.assertEqual(gm_unflat_non_strict(*inp), gm_unflat_strict(*inp)) 5570 self.assertExpectedInline( 5571 str(gm_unflat_non_strict.bar.leaf.linear.graph).strip(), 5572 """\ 5573graph(): 5574 %x : [num_users=1] = placeholder[target=x] 5575 %weight : [num_users=1] = get_attr[target=weight] 5576 %bias : [num_users=1] = get_attr[target=bias] 5577 %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%weight, [1, 0]), kwargs = {}) 5578 %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%bias, %x, %permute), kwargs = {}) 5579 return addmm""", 5580 ) 5581 5582 gm_flat_non_strict = ep_non_strict.module() 5583 gm_flat_strict = ep_strict.module() 5584 5585 self.assertEqual(gm_flat_non_strict(*inp), gm_flat_strict(*inp)) 5586 5587 def test_nn_module_stack_shared_submodule(self): 5588 class Leaf(torch.nn.Module): 5589 def __init__(self) -> None: 5590 super().__init__() 5591 self.linear = torch.nn.Linear(4, 4) 5592 5593 def forward(self, x): 5594 return self.linear(x) 5595 5596 class Bar(torch.nn.Module): 5597 def __init__(self) -> None: 5598 super().__init__() 5599 self.leaf = Leaf() 5600 self.buffer = torch.nn.Buffer(torch.randn(4, 4)) 5601 5602 def forward(self, x): 5603 return self.buffer.sum() + self.leaf(x).sum() 5604 5605 class BarDifferent(torch.nn.Module): 5606 def __init__(self) -> None: 5607 super().__init__() 5608 self.leaf = Leaf() 5609 5610 def forward(self, x): 5611 a = self.leaf(x).sum() 5612 b = self.leaf(x).sum() 5613 return a + b 5614 5615 class Foo(torch.nn.Module): 5616 def __init__(self) -> None: 5617 super().__init__() 5618 self.bar = Bar() 5619 self.bar_different = BarDifferent() 5620 5621 def forward(self, x): 5622 y = self.bar.buffer + x 5623 return ( 5624 self.bar(x) + self.bar_different(x + 2), 5625 y.sum(), 5626 ) 5627 5628 inp = (torch.randn(4, 4),) 5629 mod = Foo() 5630 ep_strict = torch.export.export(mod, inp) 5631 ep_non_strict = torch.export.export(mod, inp, strict=False) 5632 5633 gm_unflat_non_strict = unflatten(ep_non_strict) 5634 self.assertTrue(hasattr(gm_unflat_non_strict, "bar")) 5635 self.assertTrue(hasattr(gm_unflat_non_strict.bar, 
"buffer")) 5636 self.assertTrue(hasattr(gm_unflat_non_strict.bar, "leaf")) 5637 self.assertTrue(hasattr(gm_unflat_non_strict.bar_different, "leaf")) 5638 5639 gm_unflat_strict = unflatten(ep_strict) 5640 5641 self.assertEqual(gm_unflat_non_strict(*inp), gm_unflat_strict(*inp)) 5642 self.assertExpectedInline( 5643 str(gm_unflat_non_strict.bar.leaf.linear.graph).strip(), 5644 """\ 5645graph(): 5646 %x : [num_users=1] = placeholder[target=x] 5647 %weight : [num_users=1] = get_attr[target=weight] 5648 %bias : [num_users=1] = get_attr[target=bias] 5649 %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %weight, %bias), kwargs = {}) 5650 return linear""", 5651 ) 5652 self.assertExpectedInline( 5653 str(gm_unflat_non_strict.bar_different.leaf.linear.graph).strip(), 5654 """\ 5655graph(): 5656 %add_2 : [num_users=1] = placeholder[target=add_2] 5657 %weight : [num_users=1] = get_attr[target=weight] 5658 %bias : [num_users=1] = get_attr[target=bias] 5659 %linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%add_2, %weight, %bias), kwargs = {}) 5660 return linear_1""", 5661 ) 5662 5663 gm_flat_non_strict = ep_non_strict.module() 5664 gm_flat_strict = ep_strict.module() 5665 5666 self.assertEqual(gm_flat_non_strict(*inp), gm_flat_strict(*inp)) 5667 5668 def test_stack_trace(self): 5669 class Foo(torch.nn.Module): 5670 def __init__(self) -> None: 5671 super().__init__() 5672 self.linear = torch.nn.Linear(4, 4) 5673 5674 def forward(self, x): 5675 x = self.linear(x) 5676 x *= 2.0 5677 return x 5678 5679 ep = export( 5680 Foo(), 5681 (torch.randn(4, 4),), 5682 ) 5683 # check correct lines are in stack trace 5684 trace_mul = [node for node in ep.graph.nodes if node.name == "mul"][0].meta.get( 5685 "stack_trace", "" 5686 ) 5687 self.assertTrue( 5688 re.search(r"test_export.py.*in forward\n.*x \*= 2.0", trace_mul) 5689 ) 5690 trace_addmm = [ 5691 node for node in ep.graph.nodes if node.name in ["addmm", "linear"] 5692 ][0].meta.get("stack_trace", "") 5693 self.assertTrue( 5694 re.search( 5695 r"test_export.py.*in forward\n.*x = self.linear\(x\)", trace_addmm 5696 ) 5697 ) 5698 5699 def test_cond_with_module_stack_export_with(self): 5700 class Bar(torch.nn.Module): 5701 def __init__(self) -> None: 5702 super().__init__() 5703 self.linear = torch.nn.Linear(4, 4) 5704 5705 def forward(self, x): 5706 def true_fn(x): 5707 return self.linear(x).cos() 5708 5709 def false_fn(x): 5710 return self.linear(x).sin() 5711 5712 return torch.cond(x.sum() > 4, true_fn, false_fn, [x]) 5713 5714 class CondExport(torch.nn.Module): 5715 def __init__(self) -> None: 5716 super().__init__() 5717 self.bar = Bar() 5718 5719 def forward(self, x): 5720 return x.cos() + self.bar(x) 5721 5722 inp = (torch.randn(4, 4),) 5723 ep = torch.export.export(CondExport(), inp, strict=False) 5724 self.assertExpectedInline( 5725 ep.graph_module.code.strip(), 5726 """\ 5727def forward(self, p_bar_linear_weight, p_bar_linear_bias, x): 5728 cos = torch.ops.aten.cos.default(x) 5729 sum_1 = torch.ops.aten.sum.default(x) 5730 gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None 5731 true_graph_0 = self.true_graph_0 5732 false_graph_0 = self.false_graph_0 5733 cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [p_bar_linear_bias, p_bar_linear_weight, x]); gt = true_graph_0 = false_graph_0 = p_bar_linear_bias = p_bar_linear_weight = x = None 5734 getitem = cond[0]; cond = None 5735 add = torch.ops.aten.add.Tensor(cos, getitem); cos = getitem = None 5736 return 
(add,)""", 5737 ) 5738 schema = get_hop_schema(ep) 5739 self.assertExpectedInline( 5740 str(schema), 5741 """cond(Tensor pred, GraphModule true_fn, GraphModule false_fn, Tensor[3] operands) -> Tensor[1]""", 5742 ) 5743 5744 cond_top_level_nn_module_stack = [ 5745 node.meta["nn_module_stack"] 5746 for node in ep.graph.nodes 5747 if node.name == "true_graph_0" 5748 ][0] 5749 5750 self.assertTrue( 5751 "test_cond_with_module_stack_export_with.<locals>.Bar" 5752 in str(cond_top_level_nn_module_stack) 5753 ) 5754 5755 # TODO: See https://github.com/pytorch/pytorch/issues/115790 5756 @unittest.expectedFailure 5757 def test_cond_with_module_stack_export_with_unflatten(self): 5758 class Bar(torch.nn.Module): 5759 def __init__(self) -> None: 5760 super().__init__() 5761 self.linear = torch.nn.Linear(4, 4) 5762 5763 def forward(self, x): 5764 def true_fn(x): 5765 return self.linear(x).cos() 5766 5767 def false_fn(x): 5768 return self.linear(x).sin() 5769 5770 return torch.cond(x.shape[0] > 4, true_fn, false_fn, [x]) 5771 5772 class CondExport(torch.nn.Module): 5773 def __init__(self) -> None: 5774 super().__init__() 5775 self.bar = Bar() 5776 5777 def forward(self, x): 5778 return x.cos() + self.bar(x) 5779 5780 inp = (torch.randn(4, 4),) 5781 ep = torch.export.export(CondExport(), inp, strict=False) 5782 5783 cond_top_level_nn_module_stack = [ 5784 node.meta["nn_module_stack"] 5785 for node in ep.graph.nodes 5786 if node.name == "true_graph_0" 5787 ][0] 5788 5789 # we can't preserve nn_module_stack for the subgraphs for now. 5790 for node in ep.graph_module.true_graph_0.graph.nodes: 5791 self.assertEqual( 5792 node.meta["nn_module_stack"], cond_top_level_nn_module_stack 5793 ) 5794 5795 # this doesn't work today 5796 gm_unflat_strict = unflatten(ep) 5797 5798 def test_predispatch_cond(self): 5799 class Model(torch.nn.Module): 5800 def __init__(self) -> None: 5801 super().__init__() 5802 self.pred = torch.nn.Buffer(torch.tensor(False)) 5803 self.t = torch.nn.Buffer(torch.tensor(10)) 5804 5805 def forward(self, x, y): 5806 def true_fn(x, y): 5807 with torch.enable_grad(): 5808 return x - 1 + self.t + y 5809 5810 return torch.cond( 5811 self.pred, 5812 true_fn, 5813 lambda x, y: x + 1 - self.t + y, 5814 [x, y], 5815 ) 5816 5817 model = Model() 5818 with torch.no_grad(): 5819 exported_program = torch.export._trace._export( 5820 model, 5821 (torch.tensor(10), torch.tensor(12)), 5822 {}, 5823 dynamic_shapes=None, 5824 pre_dispatch=True, 5825 strict=False, 5826 ) 5827 5828 schema = get_hop_schema(exported_program) 5829 self.assertExpectedInline( 5830 str(schema), 5831 """cond(Tensor pred, GraphModule true_fn, GraphModule false_fn, Tensor[3] operands) -> Tensor[1]""", # noqa: B950 5832 ) 5833 5834 self.assertExpectedInline( 5835 str(exported_program.graph_module.code.strip()), 5836 """\ 5837def forward(self, b_pred, b_t, x, y): 5838 true_graph_0 = self.true_graph_0 5839 false_graph_0 = self.false_graph_0 5840 cond = torch.ops.higher_order.cond(b_pred, true_graph_0, false_graph_0, [b_t, x, y]); b_pred = true_graph_0 = false_graph_0 = b_t = x = y = None 5841 getitem = cond[0]; cond = None 5842 return (getitem,)""", 5843 ) # noqa: B950 5844 5845 self.assertExpectedInline( 5846 str(exported_program.graph_module.true_graph_0.code.strip()), 5847 """\ 5848def forward(self, b_t, x, y): 5849 submod_3 = self.submod_1 5850 add_1 = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_3, x, b_t, y); submod_3 = x = b_t = y = None 5851 getitem = add_1[0]; add_1 = None 5852 return (getitem,)""", 5853 ) 5854 5855 
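# submod_1 below is the region wrapped by wrap_with_set_grad_enabled(True, ...);
# it holds the actual x - 1 + self.t + y computation from true_fn.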
self.assertExpectedInline( 5856 str(exported_program.graph_module.true_graph_0.submod_1.code.strip()), 5857 """\ 5858def forward(self, x, b_t, y): 5859 sub = torch.ops.aten.sub.Tensor(x, 1); x = None 5860 add = torch.ops.aten.add.Tensor(sub, b_t); sub = b_t = None 5861 add_1 = torch.ops.aten.add.Tensor(add, y); add = y = None 5862 return (add_1,)""", 5863 ) 5864 5865 def test_predispatch_grad_wrappers(self): 5866 class Model(torch.nn.Module): 5867 def forward(self, x, y): 5868 with torch.enable_grad(): 5869 x = x - y 5870 with torch.no_grad(): 5871 x = x + y 5872 return x 5873 5874 # no grad 5875 model = Model() 5876 with torch.no_grad(): 5877 ep_nograd = torch.export._trace._export( 5878 model, 5879 (torch.tensor(10), torch.tensor(12)), 5880 {}, 5881 dynamic_shapes=None, 5882 pre_dispatch=True, 5883 strict=False, 5884 ) 5885 # check that only sub op is wrapped with grad_enabled 5886 getattr_nodes = [ 5887 node for node in ep_nograd.graph.nodes if node.op == "get_attr" 5888 ] 5889 self.assertEqual(len(getattr_nodes), 1) 5890 grad_subgraph = getattr(ep_nograd.graph_module, getattr_nodes[0].target) 5891 op_node = [ 5892 node for node in grad_subgraph.graph.nodes if node.op == "call_function" 5893 ][0] 5894 self.assertEqual(op_node.target._name, "aten::sub.Tensor") 5895 5896 # enable grad 5897 model = Model() 5898 ep_grad = torch.export._trace._export( 5899 model, 5900 (torch.tensor(10), torch.tensor(12)), 5901 {}, 5902 dynamic_shapes=None, 5903 pre_dispatch=True, 5904 strict=False, 5905 ) 5906 # check that only add op is wrapped with grad_enabled 5907 getattr_nodes = [node for node in ep_grad.graph.nodes if node.op == "get_attr"] 5908 self.assertEqual(len(getattr_nodes), 1) 5909 grad_subgraph = getattr(ep_grad.graph_module, getattr_nodes[0].target) 5910 op_node = [ 5911 node for node in grad_subgraph.graph.nodes if node.op == "call_function" 5912 ][0] 5913 self.assertEqual(op_node.target._name, "aten::add.Tensor") 5914 5915 @testing.expectedFailureRetraceability 5916 def test_layer_sharing(self): 5917 N, C, H, W = 1, 2, 2, 3 5918 5919 class Module(torch.nn.Module): 5920 def __init__(self) -> None: 5921 super().__init__() 5922 layer = torch.nn.LayerNorm([C, H, W]) 5923 self.norms = torch.nn.ModuleList( 5924 [ 5925 layer, 5926 layer, 5927 ] 5928 ) 5929 5930 def forward(self, x): 5931 for norm in self.norms: 5932 x = norm(x) 5933 return x 5934 5935 m = Module() 5936 copied_m = copy.deepcopy(m) 5937 ep = export(copied_m, (torch.randn(N, C, H, W),)) 5938 self.assertEqual(copied_m.state_dict(), m.state_dict()) 5939 self.assertEqual(ep.state_dict, m.state_dict()) 5940 5941 def test_non_persistent_buffer(self): 5942 class MyModule(torch.nn.Module): 5943 def __init__(self) -> None: 5944 super().__init__() 5945 self.foo = torch.nn.Buffer(torch.rand(2, 3), persistent=False) 5946 5947 def forward(self, x): 5948 return self.foo + x 5949 5950 class MyOuterModule(torch.nn.Module): 5951 def __init__(self) -> None: 5952 super().__init__() 5953 self.inner = MyModule() 5954 5955 def forward(self, x): 5956 return self.inner(x) 5957 5958 inp = torch.rand(2, 3) 5959 5960 def _test(m, non_persistent_buffer): 5961 ep = export(m, (inp,), {}) 5962 5963 self.assertEqual(ep.module()(inp), m(inp)) 5964 # Non-persistent buffers should not show up in the state dict 5965 self.assertNotIn(non_persistent_buffer, ep.state_dict) 5966 named_buffers = {name: buffer for (name, buffer) in ep.named_buffers()} 5967 # But they should show up in named_buffers() 5968 self.assertIn(non_persistent_buffer, named_buffers) 5969 
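# Instead, they are lifted into the exported program's constants: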
self.assertIn(non_persistent_buffer, ep.constants) 5970 self.assertEqual(len(ep.constants), 1) 5971 5972 # Check the same properties of the unlifted module 5973 mod = ep.module() 5974 self.assertNotIn(non_persistent_buffer, mod.state_dict()) 5975 mod_named_buffers = {name: buffer for (name, buffer) in mod.named_buffers()} 5976 self.assertIn(non_persistent_buffer, mod_named_buffers) 5977 self.assertIn(non_persistent_buffer, ep.constants) 5978 self.assertEqual(len(ep.constants), 1) 5979 self.assertEqual(mod(inp), m(inp)) 5980 5981 _test(MyModule(), "foo") 5982 _test(MyOuterModule(), "inner.foo") 5983 5984 def test_export_with_set_grad_enabled(self): 5985 class Model(torch.nn.Module): 5986 def __init__(self) -> None: 5987 super().__init__() 5988 self.linear = torch.nn.Linear(4, 4) 5989 5990 def forward(self, x): 5991 with torch.no_grad(): 5992 return self.linear(x) 5993 5994 model = Model() 5995 ep = export(model, (torch.randn(4, 4),), {}) 5996 # _export_for_training uses pre_dispatch=False, 5997 # so the set_grad calls are not replaced with a hop. 5998 if not is_training_ir_test(self._testMethodName): 5999 self.assertIn( 6000 "torch.ops.higher_order.wrap_with_set_grad_enabled", 6001 ep.graph_module.code, 6002 ) 6003 6004 def test_export_as_backend(self): 6005 def f(x, y): 6006 return x + y 6007 6008 def my_custom_backend(gm, example_inputs): 6009 gm = ( 6010 torch.export.export(gm, tuple(example_inputs), strict=False) 6011 .run_decompositions() 6012 .module() 6013 ) 6014 return gm 6015 6016 inp = (torch.randn(3, 3), torch.randn(3, 3)) 6017 new_res = torch.compile(f, backend=my_custom_backend)(*inp) 6018 self.assertTrue(torch.allclose(f(*inp), new_res)) 6019 6020 def test_nonstrict_retrace_preserves_metadata(self): 6021 class MyModule(torch.nn.Module): 6022 def __init__(self) -> None: 6023 super().__init__() 6024 self.linear = torch.nn.Linear(4, 4) 6025 6026 def forward(self, x): 6027 return self.linear(x) 6028 6029 inp = torch.randn(4, 4) 6030 m = MyModule() 6031 ep = torch.export.export(m, (inp,), {}, strict=False) 6032 # retrace 6033 ep2 = torch.export.export(ep.module(), (inp,), {}, strict=False) 6034 6035 for n1, n2 in zip(list(ep.graph.nodes), list(ep2.graph.nodes)): 6036 self.assertEqual(n1.meta.get("stack_trace"), n2.meta.get("stack_trace")) 6037 6038 def test_fake_weights(self): 6039 class MyModule(torch.nn.Module): 6040 def __init__(self) -> None: 6041 super().__init__() 6042 self.foo = torch.nn.Parameter(torch.randn(4, 4)) 6043 self.bar = torch.nn.Buffer(torch.randn(4, 4), persistent=False) 6044 self.baz = torch.nn.Buffer(torch.randn(4, 4), persistent=True) 6045 6046 def forward(self, x): 6047 return self.foo + x + self.bar + self.baz 6048 6049 fake_mode = torch._subclasses.FakeTensorMode( 6050 shape_env=ShapeEnv(tracked_fakes=[]) 6051 ) 6052 with fake_mode: 6053 m = MyModule() 6054 inp = torch.randn(4, 4) 6055 ep = export(m, (inp,)) 6056 # Can't compare outputs because the module has fake weights.
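# (weights created under FakeTensorMode have no real storage, so this only
# checks that export itself succeeds)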
6057 6058 def test_fake_inputs(self): 6059 class MyModule(torch.nn.Module): 6060 def __init__(self) -> None: 6061 super().__init__() 6062 self.foo = torch.nn.Parameter(torch.randn(4, 4)) 6063 6064 def forward(self, x): 6065 return self.foo + x 6066 6067 fake_mode = torch._subclasses.FakeTensorMode( 6068 shape_env=ShapeEnv(tracked_fakes=[]) 6069 ) 6070 m = MyModule() 6071 with fake_mode: 6072 inp = torch.randn(4, 4) 6073 6074 ep = export(m, (inp,)) 6075 self.assertEqual(ep.module()(torch.ones(4, 4)), m(torch.ones(4, 4))) 6076 6077 def test_trace_under_fake(self): 6078 class MyModule(torch.nn.Module): 6079 def __init__(self) -> None: 6080 super().__init__() 6081 self.foo = torch.nn.Parameter(torch.randn(4, 4)) 6082 6083 def forward(self, x): 6084 return self.foo + x 6085 6086 fake_mode = torch._subclasses.FakeTensorMode( 6087 shape_env=ShapeEnv(tracked_fakes=[]) 6088 ) 6089 with fake_mode: 6090 m = MyModule() 6091 inp = torch.randn(4, 4) 6092 # Can't use unqualified export() as it will attempt to deserialize 6093 # under a new FakeTensorMode. 6094 ep = torch.export.export(m, (inp,)) 6095 6096 def test_compiling_state(self): 6097 class TestModule1(torch.nn.Module): 6098 def forward(self, x): 6099 if torch._dynamo.is_compiling(): 6100 return x * 2 6101 else: 6102 return x * 3 6103 6104 class TestModule2(torch.nn.Module): 6105 def forward(self, x): 6106 if torch._utils.is_compiling(): 6107 return x * 2 6108 else: 6109 return x * 3 6110 6111 class TestModule3(torch.nn.Module): 6112 def forward(self, x): 6113 if torch.compiler.is_compiling(): 6114 return x * 2 6115 else: 6116 return x * 3 6117 6118 for m in [TestModule1(), TestModule2(), TestModule3()]: 6119 input = torch.randn(5) 6120 ep_strict = export(m, (input,), strict=True) 6121 ep_non_strict = export(m, (input,), strict=False) 6122 6123 self.assertTrue(torch.allclose(input * 3, m(input))) 6124 self.assertTrue(torch.allclose(input * 2, ep_strict.module()(input))) 6125 self.assertTrue(torch.allclose(input * 2, ep_non_strict.module()(input))) 6126 6127 def test_user_input_and_buffer_mutation(self): 6128 class MyModule(torch.nn.Module): 6129 def __init__(self) -> None: 6130 super().__init__() 6131 self.foo = torch.nn.Buffer(torch.randn(4, 4)) 6132 6133 def forward(self, x): 6134 self.foo.add_(1) 6135 x.add_(1) 6136 return self.foo + x 6137 6138 mod = MyModule() 6139 mod_copy = copy.deepcopy(mod) 6140 ep = export(mod_copy, (torch.rand(4, 4),)) 6141 6142 self.assertEqual(mod.foo, ep.module().foo) 6143 self.assertEqual(mod(torch.ones(4, 4)), ep.module()(torch.ones(4, 4))) 6144 6145 def test_symint_tensor_return(self): 6146 class Module(torch.nn.Module): 6147 def forward(self, x): 6148 return torch.ops.testlib.returns_tensor_symint(x)[0] 6149 6150 self._test_export_same_as_eager(Module(), (torch.randn(4, 4),)) 6151 6152 def test_custom_op_auto_functionalize(self): 6153 class M(torch.nn.Module): 6154 def __init__(self) -> None: 6155 super().__init__() 6156 6157 def forward(self, x, z): 6158 return torch.ops.testlib.foo(x, z) 6159 6160 inps = (torch.ones(5), torch.ones(5)) 6161 inps_for_export = (torch.ones(5), torch.ones(5)) 6162 inps_for_export_with_decomp = (torch.ones(5), torch.ones(5)) 6163 6164 ep = torch.export.export(M(), inps_for_export) 6165 x_new_eager, z_new_eager, legit_eager = M()(*inps) 6166 x_new_export, z_new_export, legit_export = ep.module()(*inps_for_export) 6167 self.assertTrue(torch.allclose(x_new_eager, x_new_export)) 6168 self.assertTrue(torch.allclose(z_new_eager, z_new_export)) 6169 
self.assertTrue(torch.allclose(legit_eager, legit_export)) 6170 6171 ep = ep.run_decompositions() 6172 x_new_export, z_new_export, legit_export = ep.module()( 6173 *inps_for_export_with_decomp 6174 ) 6175 self.assertTrue(torch.allclose(x_new_eager, x_new_export)) 6176 self.assertTrue(torch.allclose(z_new_eager, z_new_export)) 6177 self.assertTrue(torch.allclose(legit_eager, legit_export)) 6178 6179 def test_custom_op_auto_functionalize_pre_dispatch(self): 6180 class M(torch.nn.Module): 6181 def __init__(self) -> None: 6182 super().__init__() 6183 6184 def forward(self, x): 6185 return torch.ops.testlib.foo_mutated(x) 6186 6187 inps = (torch.ones(5),) 6188 6189 ep = torch.export.export(M(), inps) 6190 self.assertExpectedInline( 6191 str(ep.graph_module.code.strip()), 6192 """\ 6193def forward(self, x): 6194 cos = torch.ops.aten.cos.default(x) 6195 auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = x, z = cos); x = cos = None 6196 getitem_3 = auto_functionalized[3]; auto_functionalized = None 6197 cos_1 = torch.ops.aten.cos.default(getitem_3) 6198 return (getitem_3, getitem_3, cos_1)""", 6199 ) 6200 6201 ep = torch.export._trace._export(M(), inps, pre_dispatch=True) 6202 self.assertExpectedInline( 6203 str(ep.graph_module.code.strip()), 6204 """\ 6205def forward(self, x): 6206 cos = torch.ops.aten.cos.default(x) 6207 auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = x, z = cos); x = cos = None 6208 getitem_3 = auto_functionalized[3]; auto_functionalized = None 6209 cos_1 = torch.ops.aten.cos.default(getitem_3) 6210 return (getitem_3, getitem_3, cos_1)""", 6211 ) 6212 6213 def test_custom_op_auto_warn_pre_dispatch(self): 6214 class M(torch.nn.Module): 6215 def __init__(self) -> None: 6216 super().__init__() 6217 6218 def forward(self, x): 6219 return torch.ops.testlib.foo_functional(x) 6220 6221 inps = (torch.ones(5),) 6222 6223 ep = torch.export.export(M(), inps).run_decompositions() 6224 self.assertExpectedInline( 6225 str(ep.graph_module.code.strip()), 6226 """\ 6227def forward(self, x): 6228 cos = torch.ops.aten.cos.default(x) 6229 cos_1 = torch.ops.aten.cos.default(x); x = None 6230 auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = cos, z = cos_1); cos = cos_1 = None 6231 getitem_3 = auto_functionalized[3]; auto_functionalized = None 6232 cos_2 = torch.ops.aten.cos.default(getitem_3); getitem_3 = None 6233 return (cos_2,)""", 6234 ) 6235 6236 ep = torch.export._trace._export(M(), inps, pre_dispatch=True) 6237 self.assertExpectedInline( 6238 str(ep.graph_module.code.strip()), 6239 """\ 6240def forward(self, x): 6241 foo_functional = torch.ops.testlib.foo_functional.default(x); x = None 6242 return (foo_functional,)""", 6243 ) 6244 6245 def test_placeholder_naming_collisions(self): 6246 # test collisions between nested user inputs 6247 class Foo(torch.nn.Module): 6248 def forward(self, x, x_foo, x_foo_0): 6249 return x["foo"][0] + x_foo[0] + x_foo_0 6250 6251 inputs = ( 6252 {"foo": [torch.randn(4, 4)]}, 6253 (torch.randn(4, 4),), 6254 torch.randn(4, 4), 6255 ) 6256 ep = export(Foo(), inputs) 6257 expected_names = ["x_foo_0", "x_foo_0_1", "x_foo_0_2"] 6258 real_names = [spec.arg.name for spec in ep.graph_signature.input_specs] 6259 self.assertEqual(expected_names, real_names) 6260 6261 # test collisions between user inputs and params, buffers, constants 6262 class Foo(torch.nn.Module): 6263 def __init__(self) -> None: 6264 super().__init__() 
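# attribute and argument names below intentionally collide with the p_/b_/c_
# prefixes export uses for lifted parameters, buffers and constants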
6265 self.param = torch.nn.Parameter(torch.randn(4)) 6266 self.alpha = torch.nn.Buffer(torch.randn(4), persistent=True) 6267 self.beta = torch.nn.Buffer(torch.randn(4), persistent=False) 6268 self.gamma = torch.randn(4) 6269 6270 def forward(self, p, b_alpha, b, c_gamma): 6271 p = p["param"] + self.param 6272 b = self.alpha + self.beta + b_alpha + b["beta"] 6273 c = self.gamma + c_gamma 6274 return p, b, c 6275 6276 inputs = ( 6277 {"param": torch.randn(4)}, 6278 torch.randn(4), 6279 {"beta": torch.randn(4)}, 6280 torch.randn(4), 6281 ) 6282 ep = export(Foo(), inputs) 6283 expected_names = [ # user inputs should be prioritized, unprefixed 6284 ("p_param_1", InputKind.PARAMETER), 6285 ("b_alpha_1", InputKind.BUFFER), 6286 ("b_beta_1", InputKind.BUFFER), 6287 ("c_gamma_1", InputKind.CONSTANT_TENSOR), 6288 ("p_param", InputKind.USER_INPUT), 6289 ("b_alpha", InputKind.USER_INPUT), 6290 ("b_beta", InputKind.USER_INPUT), 6291 ("c_gamma", InputKind.USER_INPUT), 6292 ] 6293 real_names = [ 6294 (spec.arg.name, spec.kind) for spec in ep.graph_signature.input_specs 6295 ] 6296 self.assertEqual(expected_names, real_names) 6297 6298 # test collisions between user inputs & call_function nodes 6299 class Foo(torch.nn.Module): 6300 def forward(self, mul, add, add_1): 6301 return mul * mul + add * add_1 6302 6303 ep = export(Foo(), (torch.randn(4, 4), torch.randn(4, 4), torch.randn(4, 4))) 6304 expected_names_and_ops = [ 6305 ("mul", "placeholder"), 6306 ("add", "placeholder"), 6307 ("add_1", "placeholder"), 6308 ("mul_1", "call_function"), 6309 ("mul_2", "call_function"), 6310 ("add_2", "call_function"), 6311 ("output", "output"), 6312 ] 6313 real_names_and_ops = [(node.name, node.op) for node in ep.graph.nodes] 6314 self.assertEqual(expected_names_and_ops, real_names_and_ops) 6315 6316 def test_placeholder_naming_collisions_hoo_subgraphs(self): 6317 # test collisions between user inputs, top-level nodes, and HOO subgraph nodes 6318 class Foo(torch.nn.Module): 6319 def forward(self, x, mul, mul_1): 6320 _mul = x * x 6321 y = cond( 6322 _mul.sum() > 0, 6323 lambda x, y, z: x * y * z, 6324 lambda x, y, z: x + y + z, 6325 [_mul, mul, mul_1], 6326 ) 6327 with torch.enable_grad(): 6328 y = y * y 6329 return y 6330 6331 with torch.no_grad(): 6332 ep = torch.export._trace._export( 6333 Foo(), 6334 (torch.randn(4), torch.randn(4), torch.randn(4)), 6335 pre_dispatch=True, 6336 ) 6337 6338 schema = get_hop_schema(ep) 6339 self.assertExpectedInline( 6340 str(schema), 6341 """cond(Tensor pred, GraphModule true_fn, GraphModule false_fn, Tensor[3] operands) -> Tensor[1]""", 6342 ) 6343 # test cond subgraph 6344 expected_names_and_ops = [ 6345 ("mul_2", "placeholder"), 6346 ("mul", "placeholder"), 6347 ("mul_1", "placeholder"), 6348 ("mul_3", "call_function"), 6349 ("mul_4", "call_function"), 6350 ("output", "output"), 6351 ] 6352 real_names_and_ops = [ 6353 (node.name, node.op) for node in ep.graph_module.true_graph_0.graph.nodes 6354 ] 6355 self.assertEqual(expected_names_and_ops, real_names_and_ops) 6356 # test set_grad_enabled subgraph 6357 expected_names_and_ops = [ 6358 ("getitem", "placeholder"), 6359 ("mul_1", "call_function"), 6360 ("output", "output"), 6361 ] 6362 real_names_and_ops = [ 6363 (node.name, node.op) for node in ep.graph_module.submod_1.graph.nodes 6364 ] 6365 self.assertEqual(expected_names_and_ops, real_names_and_ops) 6366 6367 # test collisions between user inputs & higher order op subgraphs 6368 # (please never do this) 6369 class Foo(torch.nn.Module): 6370 def forward(self, input, true_graph, 
body_graph): 6371 x = input + true_graph[0] + true_graph[1] 6372 x = cond(x.sum() > 0, lambda x: x * 2.0, lambda x: x + 2.0, [x]) 6373 x = cond(x.sum() > 0, lambda x: x * 2.0, lambda x: x + 2.0, [x]) 6374 return x 6375 6376 inputs = ( 6377 torch.randn(10, 4), 6378 (torch.randn(4), torch.randn(4)), 6379 (torch.randn(4),), 6380 ) 6381 ep = export(Foo(), inputs) 6382 expected_getattr_names = [ 6383 "true_graph_2", 6384 "false_graph_0", 6385 "true_graph_3", 6386 "false_graph_1", 6387 ] 6388 real_getattr_names = [ 6389 node.name for node in ep.graph.nodes if node.op == "get_attr" 6390 ] 6391 self.assertEqual(expected_getattr_names, real_getattr_names) 6392 6393 def test_constant_input_naming(self): 6394 class Foo(torch.nn.Module): 6395 def forward(self, x, y, div="floor"): 6396 return torch.div(x, y, rounding_mode=div) 6397 6398 f = Foo() 6399 inputs = (torch.randn(4), torch.randn(4), "floor") 6400 ep = export(f, inputs) 6401 div_spec = ep.graph_signature.input_specs[2] 6402 self.assertEqual(div_spec.arg.name, "div") 6403 self.assertEqual(div_spec.arg.value, "floor") 6404 6405 def test_unbacked_deferred_runtime_retrace(self): 6406 class Foo(torch.nn.Module): 6407 def forward(self, x, y): 6408 y_sum = y.sin().sum() 6409 with torch.no_grad(): 6410 a = x.item() 6411 torch._check_is_size(a) 6412 torch._check(a > 2) 6413 torch._check(a < 6) 6414 unbacked_shape = torch.ops.testlib.foo_unbacked(a) 6415 return y + y_sum + unbacked_shape.sum() 6416 6417 inps = (torch.tensor(4), torch.randn(5, 5)) 6418 from torch.export import _trace 6419 6420 ep_pre = _trace._export(Foo(), inps, pre_dispatch=True, strict=False) 6421 self.assertExpectedInline( 6422 str(ep_pre.graph_module.submod_1.code).strip(), 6423 """\ 6424def forward(self, x): 6425 item = torch.ops.aten.item.default(x); x = None 6426 sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item); sym_constrain_range_for_size_default = None 6427 ge_1 = item >= 3 6428 _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 3 on node 'ge_1'"); ge_1 = _assert_scalar_default = None 6429 le = item <= 5 6430 _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u1 <= 5 on node 'le'"); le = _assert_scalar_default_1 = None 6431 gt_1 = item > 2 6432 _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 2 < u1 on node 'gt_1'"); gt_1 = _assert_scalar_default_2 = None 6433 lt_1 = item < 6 6434 _assert_scalar_default_3 = torch.ops.aten._assert_scalar.default(lt_1, "Runtime assertion failed for expression u1 < 6 on node 'lt_1'"); lt_1 = _assert_scalar_default_3 = None 6435 foo_unbacked = torch.ops.testlib.foo_unbacked.default(item); item = None 6436 return (foo_unbacked,)""", 6437 ) 6438 ep_aot = ep_pre.run_decompositions() 6439 self.assertExpectedInline( 6440 str(ep_aot.graph_module.code).strip(), 6441 """\ 6442def forward(self, x, y): 6443 sin = torch.ops.aten.sin.default(y) 6444 sum_1 = torch.ops.aten.sum.dim_IntList(sin, []); sin = None 6445 _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x); x = None 6446 sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense); sym_constrain_range_for_size_default = None 6447 ge_1 = _local_scalar_dense >= 3 6448 _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u3 >= 3 on node 'ge_1'"); ge_1 
= _assert_scalar_default = None 6449 le_1 = _local_scalar_dense <= 5; _local_scalar_dense = None 6450 _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u3 <= 5 on node 'le_1'"); le_1 = _assert_scalar_default_1 = None 6451 full = torch.ops.aten.full.default([4, 4], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False) 6452 add = torch.ops.aten.add.Tensor(y, sum_1); y = sum_1 = None 6453 sum_2 = torch.ops.aten.sum.dim_IntList(full, []); full = None 6454 add_1 = torch.ops.aten.add.Tensor(add, sum_2); add = sum_2 = None 6455 return (add_1,)""", 6456 ) 6457 6458 def test_nested_dynamic_shapes_spec(self): 6459 class Foo(torch.nn.Module): 6460 def forward(self, x): 6461 (a0, a1), (b0, b1), (c0, c1, c2) = x 6462 return a0 + a1 + b0 + b1 + c0 + c1 + c2 6463 6464 f = Foo() 6465 inputs = ( 6466 (1, 2), 6467 ( 6468 torch.randn(4, 4), 6469 torch.randn(4, 4), 6470 ), 6471 ( 6472 torch.randn(4, 4), 6473 torch.randn(4, 4), 6474 torch.randn(4, 4), 6475 ), 6476 ) 6477 # make sure this gets parsed correctly as 7 individual inputs, not 3 tensors 6478 dynamic_shapes = { 6479 "x": ( 6480 (None, None), 6481 (None, None), 6482 (None, None, None), 6483 ) 6484 } 6485 export(f, (inputs,), dynamic_shapes=dynamic_shapes) 6486 6487 def test_disable_forced_specializations_ok(self): 6488 # check that we don't force specialization, and defer to runtime asserts 6489 # with allow_complex_guards_as_runtime_asserts=True to successfully export 6490 # case 1: modulo guards 6491 from torch.export import dims 6492 6493 class Mod4Reshape(torch.nn.Module): 6494 def forward(self, x): 6495 return x.reshape(x.shape[0] - 1, 4, -1) # Mod(s0*s1, 4*(s0-1)) = 0 6496 6497 inputs = (torch.randn(10, 72),) 6498 dx, dy = dims("dx", "dy") 6499 ep = torch.export._trace._export( 6500 Mod4Reshape(), 6501 inputs, 6502 dynamic_shapes={"x": (dx, dy)}, 6503 allow_complex_guards_as_runtime_asserts=True, 6504 ) 6505 out1 = ep.module()(torch.randn(8, 7)) 6506 self.assertEqual(out1.shape, torch.ones(7, 4, 2).shape) 6507 out2 = ep.module()(torch.randn(12, 11)) 6508 self.assertEqual(out2.shape, torch.ones(11, 4, 3).shape) 6509 with self.assertRaisesRegex( 6510 RuntimeError, 6511 r"Runtime assertion failed for expression Eq\(Mod\(s0\*s1, 4\*s0 \- 4\), 0\) on node 'eq.*'", 6512 ): 6513 ep.module()(torch.randn(8, 8)) # fail 6514 6515 # case 2: 2d reshape 6516 class FreeReshape(torch.nn.Module): 6517 def forward(self, x, y, z): 6518 return x.reshape([-1]) + y.reshape([-1]) + z # s0*s1 = s2*s3 = s4 6519 6520 inputs = ( 6521 torch.randn(6, 8), 6522 torch.randn(3, 16), 6523 torch.randn(48), 6524 ) 6525 dynamic_shapes = { 6526 "x": [Dim(f"dx{i}", min=2) for i in range(2)], 6527 "y": [Dim(f"dy{i}", min=2) for i in range(2)], 6528 "z": [Dim(f"dz{i}", min=4) for i in range(1)], 6529 } 6530 ep = torch.export._trace._export( 6531 FreeReshape(), 6532 inputs, 6533 dynamic_shapes=dynamic_shapes, 6534 allow_complex_guards_as_runtime_asserts=True, 6535 ) 6536 ep = export(FreeReshape(), inputs, dynamic_shapes=dynamic_shapes) 6537 out1 = ep.module()(torch.randn(48, 1), torch.randn(4, 12), torch.randn(48)) 6538 self.assertEqual(out1.shape, torch.ones(48).shape) 6539 out2 = ep.module()(torch.randn(5, 8), torch.randn(4, 10), torch.randn(40)) 6540 self.assertEqual(out2.shape, torch.ones(40).shape) 6541 with self.assertRaisesRegex( 6542 RuntimeError, 6543 r"Runtime assertion failed for expression Eq\(s0\*s1, s2\*s3\) on node 'eq.*'", 6544 ): # fail only at runtime 6545 
ep.module()(torch.randn(5, 8), torch.randn(4, 5), torch.randn(30)) # fail 6546 6547 # case 3: 3d reshape (previously failing with different issue) 6548 class Reshape3d(torch.nn.Module): 6549 def forward(self, x, y): 6550 return x.reshape([-1]) + y # s0*s1*s2 = s3 6551 6552 inputs = ( 6553 torch.randn(4, 3, 2), 6554 torch.randn(24), 6555 ) 6556 dynamic_shapes = { 6557 "x": (Dim("dx0", min=2), Dim("dx1", min=2), Dim("dx2", min=2)), 6558 "y": (Dim("dy", min=8),), 6559 } 6560 ep = torch.export._trace._export( 6561 Reshape3d(), 6562 inputs, 6563 dynamic_shapes=dynamic_shapes, 6564 allow_complex_guards_as_runtime_asserts=True, 6565 ) 6566 out1 = ep.module()(torch.randn(9, 7, 2), torch.randn(126)) 6567 self.assertEqual(out1.shape, torch.ones(126).shape) 6568 with self.assertRaisesRegex( 6569 RuntimeError, 6570 r"Runtime assertion failed for expression Eq\(s0\*s1\*s2, s3\) on node 'eq.*'", 6571 ): # fail only at runtime 6572 ep.module()(torch.randn(4, 3, 2), torch.randn(10)) # fail 6573 6574 def test_disable_forced_specializations_errors(self): 6575 # check error messages with hybrid symints 6576 class Foo(torch.nn.Module): 6577 def forward(self, w, x, y, z): 6578 return w.reshape([-1]) + x, y + z # simple: s0*s1 = s2, s3 = s4 6579 6580 inputs = ( 6581 torch.randn(3, 4), 6582 torch.randn(12), 6583 torch.randn(4), 6584 torch.randn(4), 6585 ) 6586 dynamic_shapes = { 6587 "w": [Dim(f"dw{i}") for i in range(2)], 6588 "x": [Dim(f"dx{i}") for i in range(1)], 6589 "y": [Dim("dy")], # y & z incorrect, export is supposed to fail. 6590 "z": [Dim("dz")], # suggested fix should be to match these up. 6591 } 6592 with self.assertRaisesRegex( # if disable=True, suggested fixes should not specialize. 6593 torch._dynamo.exc.UserError, 6594 r".*Constraints violated(.*\n)*" 6595 r"Suggested fixes:(.*\n)*" 6596 r".*dz = dy(.*\n)*", 6597 ) as msg: 6598 export( 6599 Foo(), 6600 inputs, 6601 dynamic_shapes=dynamic_shapes, 6602 strict=False, 6603 ) 6604 6605 # TODO requires_grad doesn't seem to work with serialization. 6606 @testing.expectedFailureSerDer 6607 def test_preserve_requires_grad_placeholders(self): 6608 class Module(torch.nn.Module): 6609 def __init__(self) -> None: 6610 super().__init__() 6611 self.p = torch.nn.Parameter(torch.randn(3, 3)) 6612 6613 def forward(self, x, y): 6614 return self.p + x + y 6615 6616 m = Module() 6617 ep = export(m, (torch.randn(3, 3), torch.randn(3, 3, requires_grad=True))) 6618 placeholders = [ 6619 node for node in ep.graph_module.graph.nodes if node.op == "placeholder" 6620 ] 6621 self.assertTrue(placeholders[0].meta["val"].requires_grad) 6622 self.assertFalse(placeholders[1].meta["val"].requires_grad) 6623 self.assertTrue(placeholders[2].meta["val"].requires_grad) 6624 6625 def test_reshape_view_helper(self): 6626 # see: https://github.com/pytorch/pytorch/issues/126607 6627 class Model(torch.nn.Module): 6628 def __init__(self) -> None: 6629 super().__init__() 6630 6631 def forward(self, x): 6632 x = x.view(x.size(1), -1) 6633 # torch/_refs/__init__/_reshape_view_helper() will generate guards on reshape kernel(?) 
6634 # Ne(s0, 20), so that reshape isn't no-op 6635 # Ne(Mod(s0, 20), 0), so that reshape needs to first flatten [s0, 20, 16] -> [s0*20, 16] 6636 # then split_dim -> [20, s0, 16] 6637 # check that these show up in graph 6638 return torch.nn.functional.softmax( 6639 x, dim=0 6640 ) # don't think softmax actually creates any issues, just part of original test 6641 6642 model = Model() 6643 x = torch.rand(1024, 20, 16) 6644 dynamic_shapes = {"x": {0: Dim("batch")}} 6645 ep = torch.export._trace._export( 6646 model, 6647 (x,), 6648 dynamic_shapes=dynamic_shapes, 6649 allow_complex_guards_as_runtime_asserts=True, 6650 ) 6651 with self.assertRaisesRegex( 6652 RuntimeError, 6653 r"Runtime assertion failed for expression Ne\(s0, 20\)", 6654 ): 6655 ep.module()(torch.randn(20, 20, 16)) 6656 with self.assertRaisesRegex( 6657 RuntimeError, 6658 r"Runtime assertion failed for expression Ne\(Mod\(s0, 20\), 0\)", 6659 ): 6660 ep.module()(torch.randn(400, 20, 16)) 6661 ep.module()(torch.randn(42, 20, 16)) 6662 6663 def test_allow_explicit_guards_as_runtime_asserts(self): 6664 # check that explicit guards are treated as runtime assertions 6665 class Foo(torch.nn.Module): 6666 def forward(self, x, y): 6667 # check that negation of first guard also shows up as runtime assertion 6668 if x.shape[0] == y.shape[0]: # False 6669 return x + y 6670 elif x.shape[0] == y.shape[0] ** 3: # False 6671 return x + 2, y + 3 6672 elif x.shape[0] ** 2 == y.shape[0] * 3: # True 6673 return x * 2.0, y * 3.0 6674 6675 inputs = (torch.randn(6), torch.randn(12)) 6676 dynamic_shapes = {"x": [Dim("dx", min=4)], "y": [Dim("dy", min=4)]} 6677 ep = torch.export._trace._export( 6678 Foo(), 6679 inputs, 6680 dynamic_shapes=dynamic_shapes, 6681 allow_complex_guards_as_runtime_asserts=True, 6682 ) 6683 # check forward pass 6684 out0, out1 = ep.module()(torch.randn(9), torch.randn(27)) 6685 self.assertEqual(out0.shape, torch.ones(9).shape) 6686 self.assertEqual(out1.shape, torch.ones(27).shape) 6687 with self.assertRaisesRegex( 6688 RuntimeError, 6689 r"Runtime assertion failed for expression Ne\(s0, s1\)", 6690 ): # fail only at runtime 6691 ep.module()(torch.randn(4), torch.randn(4)) # fail 6692 with self.assertRaisesRegex( 6693 RuntimeError, 6694 r"Runtime assertion failed for expression Ne\(s0, s1\**3\)", 6695 ): 6696 ep.module()(torch.randn(64), torch.randn(4)) # fail 6697 with self.assertRaisesRegex( 6698 RuntimeError, 6699 r"Runtime assertion failed for expression Eq\(s0\**2, 3\*s1\)", 6700 ): 6701 ep.module()(torch.randn(10), torch.randn(9)) # fail 6702 6703 # this should be set with command line flag TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1, 6704 # but dynamo checks that at torch import time, so setting os.environ makes no difference 6705 # instead, manually patch dynamo config and test. 
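# (the env var corresponds to torch._dynamo.config.do_not_emit_runtime_asserts,
# which is what gets patched below)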
6706 # test that setting this flag removes runtime asserts 6707 from torch._dynamo import config as _dynamo_config 6708 6709 with _dynamo_config.patch( 6710 do_not_emit_runtime_asserts=True, 6711 ): 6712 ep = torch.export._trace._export( 6713 Foo(), 6714 inputs, 6715 dynamic_shapes=dynamic_shapes, 6716 allow_complex_guards_as_runtime_asserts=True, 6717 ).run_decompositions() 6718 6719 self.assertEqual( 6720 [ 6721 node.target == torch.ops.aten._assert_scalar.default 6722 for node in ep.graph.nodes 6723 ].count(True), 6724 0, 6725 ) 6726 6727 def test_constant_aliasing(self): 6728 class M1(torch.nn.Module): 6729 def __init__(self, m2, foo): 6730 super().__init__() 6731 self.m2 = m2 6732 self.foo = foo 6733 6734 def forward(self, x): 6735 return x + self.foo + self.m2(x) 6736 6737 class M2(torch.nn.Module): 6738 def __init__(self) -> None: 6739 super().__init__() 6740 self.foo = torch.ones(3, 3) 6741 6742 def forward(self, x): 6743 return x + self.foo 6744 6745 m2 = M2() 6746 m1 = M1(m2, m2.foo) 6747 inps = (torch.ones(3, 3),) 6748 ep = torch.export.export(m1, inps, strict=False) 6749 # check both constants appear in list 6750 self.assertEqual(sorted(list(ep.constants)), ["foo", "m2.foo"]) 6751 # check only one input spec exists 6752 num_constant_inputs = [ 6753 spec.kind == InputKind.CONSTANT_TENSOR 6754 for spec in ep.graph_signature.input_specs 6755 ].count(True) 6756 self.assertEqual(num_constant_inputs, 1) 6757 # unflatten 6758 unflattened = unflatten(ep) 6759 self.assertTrue(torch.allclose(m1(*inps), unflattened(*inps))) 6760 6761 @testing.expectedFailureRetraceability 6762 def test_unused_aliases(self): 6763 class Foo(torch.nn.Module): 6764 def __init__(self) -> None: 6765 super().__init__() 6766 # param 6767 self.alpha = torch.nn.Parameter(torch.randn(4)) 6768 self.beta = self.alpha 6769 self.gamma = self.alpha 6770 6771 def forward(self, x): 6772 return x + self.gamma 6773 6774 inps = (torch.randn(4),) 6775 ep = export(Foo(), inps) 6776 # placeholder nodes will be deduplicated in strict-mode, 6777 # but check that all params still appear in state dict 6778 for param in ["alpha", "beta", "gamma"]: 6779 self.assertTrue(param in ep.state_dict) 6780 6781 # check that they also appear in unflattened state dict 6782 unep = unflatten(ep) 6783 for param in ["alpha", "beta", "gamma"]: 6784 self.assertTrue(param in unep.state_dict()) 6785 6786 def test_intermediate_shape_comp(self): 6787 class Foo(torch.nn.Module): 6788 def forward(self, x, y): 6789 z = torch.cat([x, x], dim=0) 6790 w = z.repeat(y.shape[0]) 6791 return w.shape[0] + x.shape[0] 6792 6793 inputs = (torch.randn(6), torch.randn(4)) 6794 shapes = { 6795 "x": (Dim("dx0"),), 6796 "y": (Dim("dy"),), 6797 } 6798 ep = export( 6799 Foo(), 6800 inputs, 6801 dynamic_shapes=shapes, 6802 ) 6803 # test that shape is from size compute, not sym_size call 6804 add_node = [node for node in ep.graph.nodes if node.target == operator.add][0] 6805 self.assertTrue(add_node.args[0].target == operator.mul) 6806 # test sym_size calls only happen on placeholders 6807 sym_size_nodes = [ 6808 node 6809 for node in ep.graph.nodes 6810 if node.target == torch.ops.aten.sym_size.int 6811 ] 6812 self.assertEqual(len(sym_size_nodes), 2) 6813 self.assertTrue( 6814 all(node.args[0].op == "placeholder" for node in sym_size_nodes) 6815 ) 6816 # dynamo will DCE the repeat node, AOTAutograd will leave it 6817 # training IR will also DCE due to retracing 6818 repeat_nodes = [ 6819 node 6820 for node in ep.graph.nodes 6821 if node.target == torch.ops.aten.repeat.default 
6822 ] 6823 self.assertEqual( 6824 len(repeat_nodes), 6825 1 6826 if is_non_strict_test(self._testMethodName) 6827 and not is_training_ir_test(self._testMethodName) 6828 else 0, 6829 ) 6830 6831 def test_checks_to_constrain_range(self): 6832 class Foo(torch.nn.Module): 6833 def forward(self, x, y): 6834 n = y.item() 6835 m = y.item() 6836 torch._check_is_size(n) 6837 torch._check(m >= 0) 6838 torch._check(n >= 3) 6839 torch._check(-m >= -9) # m <= 9 6840 torch._check(n <= 6) 6841 # n has range [3, 9] 6842 return x[:n] 6843 6844 inputs = (torch.randn(10), torch.tensor(6)) 6845 ep = export(Foo(), inputs) 6846 FileCheck().check_count( 6847 "torch.ops.aten._assert_scalar.default", 2, exactly=True 6848 ).run(ep.graph_module.code) 6849 FileCheck().check_count( 6850 "torch.ops.aten.sym_constrain_range.default", 0, exactly=True 6851 ).run(ep.graph_module.code) 6852 FileCheck().check_count( 6853 "torch.ops.aten.sym_constrain_range_for_size.default", 1, exactly=True 6854 ).run(ep.graph_module.code) 6855 6856 ep = ep.run_decompositions() 6857 FileCheck().check_count( 6858 "torch.ops.aten._assert_scalar.default", 2, exactly=True 6859 ).run(ep.graph_module.code) 6860 FileCheck().check_count( 6861 "torch.ops.aten.sym_constrain_range.default", 0, exactly=True 6862 ).run(ep.graph_module.code) 6863 FileCheck().check_count( 6864 "torch.ops.aten.sym_constrain_range_for_size.default", 1, exactly=True 6865 ).run(ep.graph_module.code) 6866 6867 # check runtime 6868 ep.module()(torch.randn(10), torch.tensor(5)) 6869 with self.assertRaisesRegex( 6870 RuntimeError, 6871 r"Runtime assertion failed for expression u[\d+] \>\= 3", 6872 ): 6873 ep.module()(torch.randn(10), torch.tensor(2)) 6874 6875 def test_cse_for_symint(self): 6876 class Foo(torch.nn.Module): 6877 # check sym ops only get computed once 6878 def forward(self, x, y): 6879 if ( 6880 x.shape[0] ** 2 - y.shape[0] ** 2 >= 4 # 16 6881 and x.shape[0] ** 2 - y.shape[0] ** 2 <= 20 6882 and x.shape[0] ** 2 - y.shape[0] ** 2 != 15 6883 ): 6884 return x * 2, y * 2 6885 6886 inputs = (torch.randn(5), torch.randn(3)) 6887 shapes = {"x": (Dim("dx"),), "y": (Dim("dy"),)} 6888 ep = torch.export._trace._export( 6889 Foo(), 6890 inputs, 6891 dynamic_shapes=shapes, 6892 allow_complex_guards_as_runtime_asserts=True, 6893 ) 6894 # count 2 pow nodes, 2 sym_size.int nodes 6895 self.assertEqual( 6896 [node.target for node in ep.graph.nodes].count( 6897 operator.pow, 6898 ), 6899 2, 6900 ) 6901 FileCheck().check_count("torch.ops.aten.sym_size.int", 2, exactly=True).run( 6902 ep.graph_module.code 6903 ) 6904 6905 ep = ep.run_decompositions() 6906 self.assertEqual( 6907 [node.target for node in ep.graph.nodes].count( 6908 operator.pow, 6909 ), 6910 2, 6911 ) 6912 FileCheck().check_count("torch.ops.aten.sym_size.int", 2, exactly=True).run( 6913 ep.graph_module.code 6914 ) 6915 6916 def test_slice_with_floordiv(self): 6917 # slice operation emits runtime assert s0//2 <= s1 6918 class M1(torch.nn.Module): 6919 def forward(self, x, y): 6920 d = x.size(0) // 2 6921 return y[d:] 6922 6923 class M(torch.nn.Module): 6924 def __init__(self) -> None: 6925 super().__init__() 6926 self.m1 = M1() 6927 6928 def forward(self, x, y): 6929 d = x.size(0) // 2 6930 m1_res = self.m1(x, y) 6931 return y[d:] + m1_res 6932 6933 inputs = (torch.ones(10), torch.ones(10)) 6934 d0 = torch.export.Dim("d0", max=2048) 6935 d1 = torch.export.Dim("d1", max=2048) 6936 ep = export( 6937 M(), 6938 inputs, 6939 dynamic_shapes=((d0,), (d1,)), 6940 ) 6941 ep.module()(torch.ones(8), torch.ones(4)) 6942 
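# x has 8 elements, so s0 // 2 == 4 and any y with s1 >= 4 passes the assert: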
ep.module()(torch.ones(8), torch.ones(5)) 6943 with self.assertRaisesRegex( 6944 RuntimeError, 6945 r"Runtime assertion failed for expression \(s0//2\) \<\= s1", 6946 ): 6947 ep.module()(torch.ones(10), torch.ones(4)) 6948 6949 def test_split_const_gm_with_lifted_constants(self): 6950 class Model(torch.nn.Module): 6951 def __init__(self) -> None: 6952 super().__init__() 6953 self.w_pre = torch.randn(4, 4) 6954 self.b = torch.randn(4) 6955 6956 def forward(self, x): 6957 w_transpose = torch.transpose(self.w_pre, 0, 1) 6958 w_relu = torch.nn.functional.relu(w_transpose) 6959 w = w_relu + self.b 6960 return torch.matmul(x, w) 6961 6962 example_inputs = (torch.randn(4, 4),) 6963 mod = Model() 6964 ep = torch.export.export(mod, example_inputs) 6965 new_gm = copy.deepcopy(ep.graph_module) 6966 new_sig = copy.deepcopy(ep.graph_signature) 6967 placeholder_nodes = [ 6968 node for node in new_gm.graph.nodes if node.op == "placeholder" 6969 ] 6970 constants = {**ep.state_dict, **ep.constants} 6971 lifted_constants = { 6972 n.name: constants[spec.target] 6973 for n, spec in zip(placeholder_nodes, new_sig.input_specs) 6974 if spec.target is not None 6975 } 6976 const_gm, _ = split_const_gm(new_gm, lifted_constants) 6977 counter = 0 6978 for node in const_gm.graph.nodes: 6979 if node.op == "call_function": 6980 counter += 1 6981 self.assertTrue(counter > 0) 6982 test_input = torch.randn(4, 4) 6983 expected = new_gm(None, None, test_input)[0] 6984 actual = mod(test_input) 6985 self.assertEqual(actual, expected) 6986 const_gm, _ = split_const_gm(ep.graph_module, lifted_constants, lambda x: True) 6987 counter = 0 6988 for node in const_gm.graph.nodes: 6989 if node.op == "call_function": 6990 self.assertTrue(False) 6991 6992 @testing.expectedFailureTrainingIRToRunDecomp # T200904004 6993 @testing.expectedFailureTrainingIRToRunDecompNonStrict 6994 def test_istft_op(self): 6995 class istft_class(torch.nn.Module): 6996 def forward(self, spec): 6997 window = torch.hann_window(1024).type(torch.FloatTensor) 6998 return torch.istft( 6999 spec, 7000 n_fft=1024, 7001 hop_length=512, 7002 window=window, 7003 length=144000, 7004 ) 7005 7006 model = istft_class() 7007 real_part = torch.randn(1, 513, 282, dtype=torch.float32) 7008 imaginary_part = torch.randn(1, 513, 282, dtype=torch.float32) 7009 spec = torch.complex(real_part, imaginary_part) 7010 export(model, (spec,)) 7011 7012 def test_automatic_dynamic_shapes_simple_equality(self): 7013 # The next 3 test cases tests for automatic dynamic shapes specs, verifying that automatic dynamism 7014 # leads to replacement symbols being set for equalities, and inferred relationships being checked 7015 # with runtime asserts. Check that we specialize to static values when the program says so. 
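# Dim.AUTO lets export infer each dimension's dynamism from the guards it sees,
# while Dim.STATIC (or None) pins a dimension to the example input's size,
# e.g. dynamic_shapes={"x": (Dim.AUTO, Dim.STATIC)} makes dim 0 dynamic if possible
# and dim 1 static.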
7016 AUTO, STATIC = Dim.AUTO, Dim.STATIC 7017 7018 # case 1: direct equality between symbols 7019 class SimpleEquality(torch.nn.Module): 7020 def forward(self, x, y, z): 7021 # all inputs should have shape [s0, s1] 7022 return x + y + z 7023 7024 inputs = tuple(torch.randn(6, 3) for _ in range(3)) 7025 # fully dynamic 7026 self._check_dynamic_shapes_specs_and_shapes( 7027 SimpleEquality(), 7028 inputs, 7029 specs=[ 7030 ((AUTO, AUTO), (AUTO, AUTO), (AUTO, AUTO)), 7031 [[AUTO, AUTO], [AUTO, AUTO], [AUTO, AUTO]], 7032 {"x": (AUTO, AUTO), "y": (AUTO, AUTO), "z": (AUTO, AUTO)}, 7033 ], 7034 passing_shapes=[ 7035 ((4, 4), (4, 4), (4, 4)), 7036 ((1, 1), (1, 1), (1, 1)), 7037 ((0, 9), (0, 9), (0, 9)), 7038 ], 7039 failing_shapes=[ 7040 ((4, 4), (4, 4), (4, 3)), 7041 ((4, 4), (5, 4), (4, 5)), 7042 ], 7043 test_serdes=True, 7044 ) 7045 # static s1 7046 self._check_dynamic_shapes_specs_and_shapes( 7047 # specifying just one dimension as static should be enough to specialize all s1 7048 SimpleEquality(), 7049 inputs, 7050 specs=[ 7051 [{0: AUTO, 1: AUTO}, {0: AUTO, 1: AUTO}, (AUTO, None)], 7052 {"x": (AUTO, AUTO), "y": (AUTO, AUTO), "z": (AUTO, None)}, 7053 ], 7054 passing_shapes=[ 7055 ((4, 3), (4, 3), (4, 3)), 7056 ((1, 3), (1, 3), (1, 3)), 7057 ((0, 3), (0, 3), (0, 3)), 7058 ], 7059 failing_shapes=[ 7060 ((4, 4), (4, 4), (4, 4)), 7061 ((1, 1), (1, 1), (1, 1)), 7062 ((0, 9), (0, 9), (0, 9)), 7063 ], 7064 test_serdes=True, 7065 ) 7066 # fully static 7067 self._check_dynamic_shapes_specs_and_shapes( 7068 # this should specialize all 7069 SimpleEquality(), 7070 inputs, 7071 specs=[{"x": (None, AUTO), "y": (AUTO, AUTO), "z": (AUTO, None)}], 7072 passing_shapes=[ 7073 ((6, 3), (6, 3), (6, 3)), 7074 ], 7075 failing_shapes=[ 7076 ((6, 4), (6, 4), (6, 4)), 7077 ((1, 3), (1, 3), (1, 3)), 7078 ((0, 9), (0, 9), (0, 9)), 7079 ], 7080 test_serdes=True, 7081 ) 7082 7083 def test_automatic_dynamic_shapes_constant_relation(self): 7084 AUTO, STATIC = Dim.AUTO, Dim.STATIC 7085 7086 # case 2: related by constant: s0 + 4 = s1 7087 class OffBy4(torch.nn.Module): 7088 def forward(self, x, y): 7089 return x + y[4:] 7090 7091 inputs = (torch.randn(6), torch.randn(10)) 7092 # fully dynamic 7093 self._check_dynamic_shapes_specs_and_shapes( 7094 OffBy4(), 7095 inputs, 7096 specs=[ 7097 ((AUTO,), (AUTO,)), 7098 {"x": (AUTO,), "y": (AUTO,)}, 7099 ], 7100 passing_shapes=[ 7101 ((10,), (14,)), 7102 ((3,), (7,)), 7103 ((2,), (6,)), 7104 ], 7105 failing_shapes=[ 7106 ((10,), (13,)), 7107 ], 7108 test_serdes=True, 7109 ) 7110 # static s1 should specialize s0 7111 self._check_dynamic_shapes_specs_and_shapes( 7112 OffBy4(), 7113 inputs, 7114 specs=[ 7115 {"x": (AUTO,), "y": (None,)}, 7116 ], 7117 passing_shapes=[ 7118 ((6,), (10,)), 7119 ], 7120 failing_shapes=[ 7121 ((10,), (14,)), 7122 ((3,), (7,)), 7123 ((2,), (6,)), 7124 ], 7125 test_serdes=True, 7126 ) 7127 7128 def test_automatic_dynamic_shapes_linear_relation(self): 7129 AUTO, STATIC = Dim.AUTO, Dim.STATIC 7130 7131 # case 3: linear relation 7132 class LinearRel(torch.nn.Module): 7133 def forward(self, x, y): 7134 # x: [s0], y: [s1] 7135 # relation seems to be (s0 + 2) // 4 == s1 7136 return x[1::4] + y 7137 7138 inputs = (torch.randn(21), torch.randn(5)) 7139 7140 # fully dynamic 7141 self._check_dynamic_shapes_specs_and_shapes( 7142 LinearRel(), 7143 inputs, 7144 specs=[ 7145 ((AUTO,), (AUTO,)), 7146 {"x": (AUTO,), "y": (AUTO,)}, 7147 ], 7148 passing_shapes=[ 7149 ((33,), (8,)), 7150 ((32,), (8,)), 7151 ((31,), (8,)), 7152 ((30,), (8,)), 7153 ], 7154 failing_shapes=[ 7155 
((34,), (8,)), 7156 ((22,), (5,)), 7157 ], 7158 test_serdes=False, 7159 ) 7160 # static s1 shouldn't actually specialize s0 (guard: (s0 + 2) // 4 == 5) 7161 self._check_dynamic_shapes_specs_and_shapes( 7162 LinearRel(), 7163 inputs, 7164 specs=[ 7165 ((AUTO,), None), 7166 {"x": (AUTO,), "y": None}, 7167 ], 7168 passing_shapes=[ 7169 ((21,), (5,)), 7170 ((20,), (5,)), 7171 ((19,), (5,)), 7172 ((18,), (5,)), 7173 ], 7174 failing_shapes=[ 7175 ((33,), (8,)), 7176 ], 7177 test_serdes=False, 7178 ) 7179 # but static s0 will definitely specialize s1 (guard: (21 + 2) // 4 == s1 -> 5 == s1) 7180 self._check_dynamic_shapes_specs_and_shapes( 7181 LinearRel(), 7182 inputs, 7183 specs=[ 7184 (None, (AUTO,)), 7185 ], 7186 passing_shapes=[ 7187 ((21,), (5,)), 7188 ], 7189 failing_shapes=[ 7190 ((22,), (5,)), 7191 ], 7192 test_serdes=True, 7193 ) 7194 7195 def test_dynamic_shapes_serdes_generic(self): 7196 from torch._export.serde.dynamic_shapes import ( 7197 _dump_dynamic_shapes, 7198 _load_dynamic_shapes, 7199 ) 7200 7201 class Foo(torch.nn.Module): 7202 def forward(self, a, b, c, d): 7203 if d == "hello": 7204 x = a[0] + a[1][1:] 7205 b = torch.cat([b, b], dim=0).reshape([-1, 1]) 7206 return x + b, c * 2 7207 7208 # test de/serialization on some generic specs 7209 dz = Dim("dz", min=4, max=16) 7210 dx = 2 * dz 7211 dy = dx + 1 7212 inputs = ( 7213 [ 7214 torch.randn(8, 4), 7215 torch.randn(9, 4), 7216 ], 7217 torch.randn(4), 7218 torch.randn(4, 4), 7219 "hello", 7220 ) 7221 dynamic_shapes = { 7222 "a": [ 7223 (dx, 4), 7224 (dy, 4), 7225 ], 7226 "b": (dz,), 7227 "c": None, 7228 "d": None, 7229 } 7230 ep = export(Foo(), inputs, dynamic_shapes=dynamic_shapes) 7231 self._check_dynamic_shapes_specs_and_shapes( 7232 Foo(), 7233 inputs, 7234 [dynamic_shapes], 7235 [ 7236 ([(16, 4), (17, 4)], (8,), (4, 4), "hello"), 7237 ([(24, 4), (25, 4)], (12,), (4, 4), "hello"), 7238 ], 7239 [ 7240 ([(16, 4), (17, 4)], (8,), (5, 5), "hello"), 7241 ], 7242 test_serdes=True, 7243 ) 7244 self.assertExpectedInline( 7245 _dump_dynamic_shapes(dynamic_shapes, inputs), 7246 """DynamicShapesSpec(dynamic_shapes=([['2*dz', 4], ['2*dz + 1', 4]], ['dz'], ['_DimHint.STATIC', '_DimHint.STATIC'], None), dims={'dz': RootDim(min=4, max=16, derived=['2*dz', '2*dz + 1'])})""", 7247 ) 7248 self.assertExpectedInline( 7249 _dump_dynamic_shapes(dynamic_shapes, inputs, to_dict=True), 7250 """{'dynamic_shapes': ([['2*dz', 4], ['2*dz + 1', 4]], ['dz'], ['_DimHint.STATIC', '_DimHint.STATIC'], None), 'dims': {'dz': {'min': 4, 'max': 16, 'derived': ['2*dz', '2*dz + 1']}}}""", 7251 ) 7252 ((dx, _), (dy, _)), (dz,), (_, _), _ = _load_dynamic_shapes( 7253 _dump_dynamic_shapes(dynamic_shapes, inputs) 7254 ) 7255 self.assertEqual(dx.root, dz) 7256 self.assertEqual(dy.root, dz) 7257 7258 def test_dynamic_shapes_serdes_various(self): 7259 # serialization for dataclass inputs, Dim.AUTO/STATIC, and kwargs 7260 from torch._export.serde.dynamic_shapes import ( 7261 _dump_dynamic_shapes, 7262 _load_dynamic_shapes, 7263 ) 7264 7265 auto, static = Dim.AUTO, Dim.STATIC 7266 7267 @dataclass 7268 class Input: 7269 a: Tensor 7270 b: Tensor 7271 7272 register_dataclass_as_pytree_node( 7273 Input, 7274 serialized_type_name="test_dynamic_shapes_serdes_various.Input", 7275 ) 7276 7277 class Foo(torch.nn.Module): 7278 def forward(self, x, y, z): 7279 return x - torch.randn(4), y.a + y.b + z[1:] 7280 7281 args = (torch.randn(4, 4),) 7282 kwargs = { 7283 "y": Input(a=torch.randn(8, 8), b=torch.randn(8, 8)), 7284 "z": torch.randn(9, 8), 7285 } 7286 dynamic_shapes = { 7287 "x": 
    def test_dynamic_shapes_serdes_generic(self):
        from torch._export.serde.dynamic_shapes import (
            _dump_dynamic_shapes,
            _load_dynamic_shapes,
        )

        class Foo(torch.nn.Module):
            def forward(self, a, b, c, d):
                if d == "hello":
                    x = a[0] + a[1][1:]
                    b = torch.cat([b, b], dim=0).reshape([-1, 1])
                    return x + b, c * 2

        # test de/serialization on some generic specs
        dz = Dim("dz", min=4, max=16)
        dx = 2 * dz
        dy = dx + 1
        inputs = (
            [
                torch.randn(8, 4),
                torch.randn(9, 4),
            ],
            torch.randn(4),
            torch.randn(4, 4),
            "hello",
        )
        dynamic_shapes = {
            "a": [
                (dx, 4),
                (dy, 4),
            ],
            "b": (dz,),
            "c": None,
            "d": None,
        }
        ep = export(Foo(), inputs, dynamic_shapes=dynamic_shapes)
        self._check_dynamic_shapes_specs_and_shapes(
            Foo(),
            inputs,
            [dynamic_shapes],
            [
                ([(16, 4), (17, 4)], (8,), (4, 4), "hello"),
                ([(24, 4), (25, 4)], (12,), (4, 4), "hello"),
            ],
            [
                ([(16, 4), (17, 4)], (8,), (5, 5), "hello"),
            ],
            test_serdes=True,
        )
        self.assertExpectedInline(
            _dump_dynamic_shapes(dynamic_shapes, inputs),
            """DynamicShapesSpec(dynamic_shapes=([['2*dz', 4], ['2*dz + 1', 4]], ['dz'], ['_DimHint.STATIC', '_DimHint.STATIC'], None), dims={'dz': RootDim(min=4, max=16, derived=['2*dz', '2*dz + 1'])})""",
        )
        self.assertExpectedInline(
            _dump_dynamic_shapes(dynamic_shapes, inputs, to_dict=True),
            """{'dynamic_shapes': ([['2*dz', 4], ['2*dz + 1', 4]], ['dz'], ['_DimHint.STATIC', '_DimHint.STATIC'], None), 'dims': {'dz': {'min': 4, 'max': 16, 'derived': ['2*dz', '2*dz + 1']}}}""",
        )
        ((dx, _), (dy, _)), (dz,), (_, _), _ = _load_dynamic_shapes(
            _dump_dynamic_shapes(dynamic_shapes, inputs)
        )
        self.assertEqual(dx.root, dz)
        self.assertEqual(dy.root, dz)

    def test_dynamic_shapes_serdes_various(self):
        # serialization for dataclass inputs, Dim.AUTO/STATIC, and kwargs
        from torch._export.serde.dynamic_shapes import (
            _dump_dynamic_shapes,
            _load_dynamic_shapes,
        )

        auto, static = Dim.AUTO, Dim.STATIC

        @dataclass
        class Input:
            a: Tensor
            b: Tensor

        register_dataclass_as_pytree_node(
            Input,
            serialized_type_name="test_dynamic_shapes_serdes_various.Input",
        )

        class Foo(torch.nn.Module):
            def forward(self, x, y, z):
                return x - torch.randn(4), y.a + y.b + z[1:]

        args = (torch.randn(4, 4),)
        kwargs = {
            "y": Input(a=torch.randn(8, 8), b=torch.randn(8, 8)),
            "z": torch.randn(9, 8),
        }
        dynamic_shapes = {
            "x": (auto, static),
            "y": [(auto, auto), (auto, auto)],
            "z": (auto, 8),
        }

        # dump dynamic_shapes
        self.assertExpectedInline(
            _dump_dynamic_shapes(dynamic_shapes, args, kwargs),
            """DynamicShapesSpec(dynamic_shapes=(['_DimHint.AUTO', '_DimHint.STATIC'], [['_DimHint.AUTO', '_DimHint.AUTO'], ['_DimHint.AUTO', '_DimHint.AUTO']], ['_DimHint.AUTO', 8]), dims={})""",
        )
        self.assertExpectedInline(
            _dump_dynamic_shapes(dynamic_shapes, args, kwargs, to_dict=True),
            """{'dynamic_shapes': (['_DimHint.AUTO', '_DimHint.STATIC'], [['_DimHint.AUTO', '_DimHint.AUTO'], ['_DimHint.AUTO', '_DimHint.AUTO']], ['_DimHint.AUTO', 8]), 'dims': {}}""",
        )

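    # For orientation, the round trip exercised by the two serdes tests above is
    # (a sketch, mirroring the calls already made in those tests):
    #
    #     spec = _dump_dynamic_shapes(dynamic_shapes, inputs)  # -> DynamicShapesSpec
    #     loaded = _load_dynamic_shapes(spec)                  # -> Dim-based spec
    #
    # Derived dims such as 2*dz and 2*dz + 1 are reconstructed from their root Dim
    # on load, which is what the dx.root/dy.root assertions above verify.
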
    def test_dynamic_shapes_serdes_user_errors(self):
        # check error messages for dynamic shapes de/serialization
        from torch._export.serde.dynamic_shapes import (
            _dump_dynamic_shapes,
            _load_dynamic_shapes,
            DynamicShapesSpec,
            RootDim,
        )
        from torch._export.serde.serialize import _dataclass_to_dict

        # this stuff should be well tested in `test_mismatched_dynamic_shapes`
        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            re.escape(
                "Detected mismatch between the structure of `inputs` and `dynamic_shapes`: `inputs[0]['k']` "
                "is a <class 'list'>, but `dynamic_shapes[0]['k']` is a <class 'tuple'>"
            ),
        ):
            dynamic_shapes = {"x": {"k": (Dim("dx"), Dim("dy"))}}
            _dump_dynamic_shapes(dynamic_shapes, ({"k": [torch.randn(4, 4)]},))

        # loading with from_dict=True/False
        spec = DynamicShapesSpec(
            dynamic_shapes=[["dx"]],
            dims={"dx": RootDim(min=4, max=16, derived=[])},
        )
        spec_dict = _dataclass_to_dict(spec)
        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            re.escape(
                "With from_dict=True, expected `spec` to be a dict, "
                "got <class 'torch._export.serde.dynamic_shapes.DynamicShapesSpec'>"
            ),
        ):
            _load_dynamic_shapes(spec, from_dict=True)

        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            re.escape("Expected `spec` to be a DynamicShapesSpec, got <class 'dict'>"),
        ):
            _load_dynamic_shapes(spec_dict, from_dict=False)

        self.assertExpectedInline(
            _load_dynamic_shapes(spec, from_dict=False),
            """[[<class 'torch._export.serde.dynamic_shapes.dx'>]]""",
        )

        # check incorrect info in dims
        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            re.escape(
                "Expected dims in `spec['dims']` to map `min` to an int, got dx: None"
            ),
        ):
            spec = {
                "dynamic_shapes": [["dx"]],
                "dims": {
                    "dx": {
                        "min": None,
                        "max": 4,
                        "derived": [],
                    },
                },
            }
            _load_dynamic_shapes(spec, from_dict=True)

        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            re.escape(
                "Expected dims in `spec['dynamic_shapes']` to be tracked in `spec['dims']`, "
                "got dx which is not in dict_keys(['dy'])"
            ),
        ):
            spec = {
                "dynamic_shapes": [["dx"]],
                "dims": {
                    "dy": {
                        "min": 2,
                        "max": 4,
                        "derived": [],
                    },
                },
            }
            _load_dynamic_shapes(spec, from_dict=True)

        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            re.escape(
                "Expected derived expressions to be linear expressions, got dx**2 + 4"
            ),
        ):
            spec = {
                "dynamic_shapes": [["dx"]],
                "dims": {
                    "dx": {
                        "min": 2,
                        "max": 4,
                        "derived": ["dx**2 + 4"],
                    },
                },
            }
            _load_dynamic_shapes(spec, from_dict=True)

    @testing.expectedFailureNonStrict
    @testing.expectedFailureTrainingIRToRunDecompNonStrict  # unbacked symint not tracked?
    @testing.expectedFailureSerDer  # T195866111
    def test_hints_wrapper(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                x = x + y

                def inner_body_fn(x, y):
                    x = torch.relu(x)
                    x = x + y
                    return x

                def outer_body_fn(x, y):
                    x = hints_wrapper(
                        inner_body_fn, (x, y), {}, hints={"inner_body": True}
                    )
                    x = torch.abs(x)
                    return x

                res = hints_wrapper(
                    outer_body_fn, (x, y), {}, hints={"outer_body": True}
                )
                return res

        x = torch.randn(2, 4)
        y = torch.ones(4)

        ep = export(M(), (x, y))
        export_res = ep.module()(x, y)
        ref_res = M()(x, y)
        self.assertEqual(export_res, ref_res)
        self.assertExpectedInline(
            normalize_gm(ep.graph_module.print_readable(print_output=False)),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, x: "f32[2, 4]", y: "f32[4]"):
        add: "f32[2, 4]" = torch.ops.aten.add.Tensor(x, y); x = None

        hints_wrapper_body_graph_0 = self.hints_wrapper_body_graph_0
        hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_graph_0, (add, y), {}, hints = {'outer_body': True}); hints_wrapper_body_graph_0 = add = y = None
        getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
        return (getitem,)

    class hints_wrapper_body_graph_0(torch.nn.Module):
        def forward(self, arg0_1: "f32[2, 4]", arg1_1: "f32[4]"):
            hints_wrapper_body_graph_0 = self.hints_wrapper_body_graph_0
            hints_wrapper = torch.ops.higher_order.hints_wrapper(hints_wrapper_body_graph_0, (arg0_1, arg1_1), {}, hints = {'inner_body': True}); hints_wrapper_body_graph_0 = arg0_1 = arg1_1 = None
            getitem: "f32[2, 4]" = hints_wrapper[0]; hints_wrapper = None
            abs_1: "f32[2, 4]" = torch.ops.aten.abs.default(getitem); getitem = None
            return (abs_1,)

        class hints_wrapper_body_graph_0(torch.nn.Module):
            def forward(self, arg0_1: "f32[2, 4]", arg1_1: "f32[4]"):
                relu: "f32[2, 4]" = torch.ops.aten.relu.default(arg0_1); arg0_1 = None
                add: "f32[2, 4]" = torch.ops.aten.add.Tensor(relu, arg1_1); relu = arg1_1 = None
                return (add,)
""",
        )


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestOneOffModelExportResult(TestCase):
    def test_scaled_dot_product_attention_cpu(self):
        """
        This test makes sure we are always getting the same decomposition result for SDPA.
        As of now _scaled_dot_product_flash_attention_for_cpu is expected to show up in
        the export() result. Some downstream backends then further decompose it into core
        ATen ops in torch/_decomp/decompositions.py (search for
        _scaled_dot_product_flash_attention_for_cpu).

        Export decomposes based on the CompositeImplicitAutograd kernel implementation
        of SDPA. If this test fails, it means the kernel has been modified. In that case
        we strongly encourage you to update the decomposition rule in
        torch/_decomp/decompositions.py along with the kernel changes, so that
        downstream backends are not affected.
        """

        class ScaledDotProductAttention(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, q, k, v):
                attn_output = F.scaled_dot_product_attention(
                    q, k, v, None, dropout_p=0.0, is_causal=True
                )
                return attn_output

        q = torch.randn(1, 1, 8, 8, device="cpu")
        k = torch.randn(1, 1, 8, 8, device="cpu")
        v = torch.randn(1, 1, 8, 8, device="cpu")

        from torch.nn.attention import SDPBackend

        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]):
            ep = torch.export.export(ScaledDotProductAttention(), (q, k, v))
            print(ep.graph)
            ep.run_decompositions()
            print(ep.graph)

        # self.assertExpectedInline(ep.graph_module.code.strip(), """\
        # def forward(self, arg0_1, arg1_1, arg2_1):
        #     _scaled_dot_product_flash_attention_for_cpu = torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default(arg0_1, arg1_1, arg2_1, 0.0, True); arg0_1 = arg1_1 = arg2_1 = None
        #     getitem = _scaled_dot_product_flash_attention_for_cpu[0]; _scaled_dot_product_flash_attention_for_cpu = None
        #     return (getitem,)""")

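    # If a hard assertion on the exported CPU graph above were ever wanted, a
    # minimal sketch (assuming _scaled_dot_product_flash_attention_for_cpu still
    # shows up in the export() result, as the docstring states) would be:
    #
    #     FileCheck().check_count(
    #         "torch.ops.aten._scaled_dot_product_flash_attention_for_cpu", 1, exactly=True
    #     ).run(ep.graph_module.code)
    #
    # The inline expectation in the test body is left commented out, presumably so
    # the test does not hard-fail whenever the decomposition drifts.
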
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
        "Can't run fused SDPA on this platform",
    )
    def test_scaled_dot_product_attention_cuda(self):
        """
        This test makes sure we are always getting the same decomposition result for SDPA.
        As of now _scaled_dot_product_flash_attention is expected to show up in the
        export() result (when GPU tensors are given). Currently no downstream backend
        relies on this export result, so if this test fails, feel free to change it to
        the latest export() result.
        """

        class ScaledDotProductAttention(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, q, k, v):
                attn_output = F.scaled_dot_product_attention(
                    q, k, v, None, dropout_p=0.0, is_causal=True
                )
                return attn_output

        q = torch.randn(1, 16, 16, 64, dtype=torch.bfloat16, device="cuda")
        k = torch.randn(1, 16, 16, 64, dtype=torch.bfloat16, device="cuda")
        v = torch.randn(1, 16, 16, 64, dtype=torch.bfloat16, device="cuda")

        ep = torch.export.export(
            ScaledDotProductAttention(), (q, k, v)
        ).run_decompositions()
        code_str = """\
def forward(self, q, k, v):
    _scaled_dot_product_flash_attention = torch.ops.aten._scaled_dot_product_flash_attention.default(q, k, v, 0.0, True, scale = 0.125); q = k = v = None
    getitem = _scaled_dot_product_flash_attention[0]; _scaled_dot_product_flash_attention = None
    return (getitem,)"""
        if SM90OrLater and not torch.version.hip:
            code_str = """\
def forward(self, q, k, v):
    _scaled_dot_product_cudnn_attention = torch.ops.aten._scaled_dot_product_cudnn_attention.default(q, k, v, None, False, 0.0, True); q = k = v = None
    getitem = _scaled_dot_product_cudnn_attention[0]; _scaled_dot_product_cudnn_attention = None
    return (getitem,)"""
        self.assertExpectedInline(
            ep.graph_module.code.strip(),
            code_str,
        )

    def test_int_list_output(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return [((1, 3), [x + x, x * x])]

        ep = torch.export.export(M(), (torch.ones(2, 3),))
        res = ep.module()(torch.ones(2, 3))
        self.assertEqual(res[0][0], (1, 3))

    def test_primitive_constant_output(self):
        class Z(torch.nn.Module):
            def forward(self, x, y):
                with torch.no_grad():
                    return y * x, "moo"

        ep = torch.export.export(Z(), (torch.tensor(3), 5))
        res = ep.module()(torch.tensor(4), 5)
        self.assertEqual(res[0], torch.tensor(20))
        self.assertEqual(res[1], "moo")

        class B(torch.nn.Module):
            def forward(self, x, y):
                return y * x, y

        ep = torch.export.export(B(), (torch.tensor(3), 5))
        res = ep.module()(torch.tensor(4), 5)
        self.assertEqual(res[0], torch.tensor(20))
        self.assertEqual(res[1], 5)

        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[1] to be equal to 5, but got 20"),
        ):
            res = ep.module()(torch.tensor(4), 20)

        class F(torch.nn.Module):
            def forward(self, x):
                # return a constant of primitive type
                y = 5
                return y * x, y

        ep = torch.export.export(F(), (torch.tensor(3),))
        res = ep.module()(torch.tensor(4))
        self.assertEqual(res[0], torch.tensor(20))
        self.assertEqual(res[1], 5)

        class Q(torch.nn.Module):
            def forward(self, x, y):
                return y * x, y - 1

        ep = torch.export.export(Q(), (torch.tensor(3), 5))
        res = ep.module()(torch.tensor(4), 5)
        self.assertEqual(res[0], torch.tensor(20))
        self.assertEqual(res[1], 4)

    def test_unbacked_sdpa(self):
        import torch
        from torch.nn.attention import sdpa_kernel, SDPBackend
        from torch.nn.functional import scaled_dot_product_attention

        class Module(torch.nn.Module):
            def forward(
                self, query: torch.Tensor, cache: torch.Tensor, start_pos: torch.Tensor
            ) -> torch.Tensor:
                # x.sizes(): 1, 128, 16, 128
                sp = start_pos.item()
                torch._check_is_size(sp)
                torch._check(sp >= 0)
                torch._check(sp <= 126)
                key = cache[:, : sp + 1, :, :]  # 1, sp+1, 16, 128
                value = cache[:, : sp + 1, :, :]  # 1, sp+1, 16, 128
                query = query.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
                key = key.transpose(1, 2)
                value = value.transpose(1, 2)
                # https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/attention.cpp#L732
                return scaled_dot_product_attention(query, key, value)

        cache = torch.randn(1, 128, 16, 128, dtype=torch.float16)
        query = torch.randn(1, 1, 16, 128, dtype=torch.float16)
        start_pos = torch.tensor([0])
        with sdpa_kernel(SDPBackend.MATH), torch.no_grad():
            ep = torch.export.export(Module(), (query, cache, start_pos))
            args = (query, cache, start_pos)
            self.assertEqual(ep.module()(*args), Module()(*args))
            args = (query, cache, torch.tensor([3]))
            self.assertEqual(ep.module()(*args), Module()(*args))
            args = (query, cache, torch.tensor([126]))
            self.assertEqual(ep.module()(*args), Module()(*args))

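    # Note on test_unbacked_sdpa above: start_pos.item() yields an unbacked SymInt,
    # so the torch._check_is_size / torch._check calls are what bound sp to
    # [0, 126] and allow the data-dependent slice cache[:, : sp + 1] to export
    # without specializing on the concrete value. The three calls that follow
    # simply re-run the exported module at the low, middle, and high ends of that
    # range (start_pos = 0, 3, 126).
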
    def test_none_input_output(self):
        class Z(torch.nn.Module):
            def forward(self, x, y):
                return x * x

        ep = torch.export.export(Z(), (torch.tensor(3), None))
        res = ep.module()(torch.tensor(4), None)
        self.assertEqual(res, torch.tensor(16))

        class B(torch.nn.Module):
            def forward(self, x, y):
                return x * x, y

        ep = torch.export.export(B(), (torch.tensor(3), None))
        res = ep.module()(torch.tensor(4), None)
        self.assertEqual(res[0], torch.tensor(16))
        self.assertEqual(res[1], None)

        decomp = ep.run_decompositions()
        gm = decomp.module()
        res = gm(torch.tensor(4), None)
        self.assertEqual(res[0], torch.tensor(16))
        self.assertEqual(res[1], None)

    def test_print(self):
        class M(torch.nn.Module):
            def forward(self, x):
                print("start")
                x1 = x + x
                print(x1)
                x2 = x1 * x1
                print(1, 2, 3)
                x3 = x2 + x2
                return (x1, x3)

        gm = export(M(), (torch.randn(3, 3),)).graph_module
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x):
    add = torch.ops.aten.add.Tensor(x, x); x = None
    mul = torch.ops.aten.mul.Tensor(add, add)
    add_1 = torch.ops.aten.add.Tensor(mul, mul); mul = None
    return (add, add_1)""",
        )

    def test_logging_logger(self):
        logger = logging.getLogger(__name__)

        class M(torch.nn.Module):
            def forward(self, x):
                logger.log("start")
                x1 = x + x
                logger.debug(x1)
                x2 = x1 * x1
                logger.info(1, 2, 3)
                x3 = x2 + x2
                return (x1, x3)

        gm = export(M(), (torch.randn(3, 3),)).graph_module
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x):
    add = torch.ops.aten.add.Tensor(x, x); x = None
    mul = torch.ops.aten.mul.Tensor(add, add)
    add_1 = torch.ops.aten.add.Tensor(mul, mul); mul = None
    return (add, add_1)""",
        )

    @unittest.skipIf(not TEST_TRANSFORMERS, "No transformers")
    def test_hf_logging_logger(self):
        import transformers

        logger = transformers.utils.logging.get_logger(__name__)

        class M(torch.nn.Module):
            def forward(self, x):
                logger.warning_once("start")
                x1 = x + x
                x2 = x1 * x1
                x3 = x2 + x2
                return (x1, x3)

        gm = export(M(), (torch.randn(3, 3),)).graph_module
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x):
    add = torch.ops.aten.add.Tensor(x, x); x = None
    mul = torch.ops.aten.mul.Tensor(add, add)
    add_1 = torch.ops.aten.add.Tensor(mul, mul); mul = None
    return (add, add_1)""",
        )

    def test_warning(self):
        class M(torch.nn.Module):
            def forward(self, x):
                warnings.warn("moo")
                res = x + x
                warnings.warn(f"{res}")
                return res

        gm = export(M(), (torch.randn(3, 3),)).graph_module
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x):
    add = torch.ops.aten.add.Tensor(x, x); x = None
    return (add,)""",
        )

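    # test_print, test_logging_logger, test_hf_logging_logger, and test_warning
    # above all verify the same property from different angles: Python-level side
    # effects (print, logger calls, warnings.warn) are dropped during tracing, so
    # only the tensor operations remain in the exported graph code.
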
    def test_constant_fqn(self):
        class Nested(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.constant = torch.rand(2, 3)
                self.parameter = torch.nn.Parameter(torch.rand(2, 3))

            def forward(self, x):
                return x + self.constant

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = Nested()

            def forward(self, x):
                return self.nested(x) + self.nested.constant + self.nested.parameter

        m = Mod()
        ep = export(m, (torch.rand(2, 3),), strict=True)
        self.assertEqual(ep.constants["nested.constant"], m.nested.constant)
        self.assertEqual(ep.module()(torch.ones(2, 3)), m(torch.ones(2, 3)))

    def test_constant_name(self):
        class Nested(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.constant = torch.rand(2, 3)
                self.parameter = torch.nn.Parameter(torch.rand(2, 3))

            def forward(self, x):
                return x + self.constant

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested_1 = Nested()
                self.nested_2 = Nested()

            def forward(self, x):
                return (
                    self.nested_1(x)
                    + self.nested_2(x)
                    + self.nested_1.constant
                    + self.nested_2.constant
                    + self.nested_1.parameter
                    + self.nested_2.parameter
                )

        m = Mod()
        ep = export(m, (torch.rand(2, 3),), strict=False)
        self.assertEqual(ep.module()(torch.ones(2, 3)), m(torch.ones(2, 3)))

        # check constant fqn when there are multiple instances of the same class
        self.assertEqual(ep.constants["nested_1.constant"], m.nested_1.constant)
        self.assertEqual(ep.constants["nested_2.constant"], m.nested_2.constant)

        # check constant_name in the graph
        placeholders = [
            node for node in ep.graph_module.graph.nodes if node.op == "placeholder"
        ]
        self.assertEqual(len(placeholders), 5)
        self.assertTrue(all(ph.name == ph.target for ph in placeholders))
        # suffix should be added to duplicated constant_name
        self.assertEqual(placeholders[2].name, "c_nested_1_constant")
        self.assertEqual(placeholders[3].name, "c_nested_2_constant")

    def test_nested_retrace(self):
        class Nested(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.randn(3))

            def forward(self, x):
                return x + self.param

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = Nested()

            def forward(self, x):
                return x + self.nested(x)

        # first export
        foo = Foo().to("meta")
        inputs = (torch.ones(3, device="meta"),)
        foo(*inputs)
        ep = torch.export.export(foo, inputs, strict=False)

        # second export
        foo_1 = ep.module()
        ep_1 = torch.export.export(foo_1, inputs, strict=False)

        for node1, node2 in zip(ep.graph.nodes, ep_1.graph.nodes):
            nn_module_stack_1 = node1.meta.get("nn_module_stack", None)
            nn_module_stack_2 = node2.meta.get("nn_module_stack", None)

            if nn_module_stack_1 is None:
                self.assertTrue(nn_module_stack_2 is None)
            else:
                for v1, v2 in zip(
                    nn_module_stack_1.values(), nn_module_stack_2.values()
                ):
                    self.assertEqual(v1, v2)

    def test_duplicated_getitem(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                return torch.topk(x, 2)

        foo = Foo()
        inputs = (torch.randn(3),)
        ep = torch.export.export(foo, inputs, strict=False)

        graph_module = copy.deepcopy(ep.graph_module)

        call_function_node = None
        num_getitems = 0
        for node in graph_module.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == torch.ops.aten.topk.default
            ):
                call_function_node = node
            elif node.op == "call_function" and node.target == operator.getitem:
                self.assertIs(node.args[0], call_function_node)
                num_getitems += 1

        self.assertIsNotNone(call_function_node)
        self.assertEqual(num_getitems, 2)

        output_node = list(graph_module.graph.nodes)[-1]

        nodes = []
        with graph_module.graph.inserting_before(output_node):
            nodes.append(
                graph_module.graph.call_function(
                    operator.getitem, (call_function_node, 1)
                )
            )
            nodes.append(
                graph_module.graph.call_function(
                    operator.getitem, (call_function_node, 0)
                )
            )
            nodes.append(
                graph_module.graph.call_function(
                    operator.getitem, (call_function_node, 0)
                )
            )
            nodes.append(
                graph_module.graph.call_function(
                    operator.getitem, (call_function_node, 1)
                )
            )
        signature = ExportGraphSignature(
            input_specs=ep.graph_signature.input_specs,
            output_specs=ep.graph_signature.output_specs
            + [
                OutputSpec(
                    kind=OutputKind.USER_OUTPUT,
                    arg=TensorArgument(name=node.name),
                    target=None,
                )
                for node in nodes
            ],
        )
        output_node.args = (output_node.args[0] + tuple(nodes),)
        graph_module.recompile()
        new_ep = ep._update(graph_module, signature)

        new_num_getitems = 0
        for node in new_ep.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == torch.ops.aten.topk.default
            ):
                call_function_node = node
            elif node.op == "call_function" and node.target == operator.getitem:
                self.assertIs(node.args[0], call_function_node)
                new_num_getitems += 1
        self.assertEqual(num_getitems, new_num_getitems)
        self.assertEqual(
            len(list(new_ep.graph.nodes)[-1].args[0]), len(signature.output_specs)
        )


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestExportCustomClass(TorchTestCase):
    def setUp(self):
        if IS_FBCODE:
            lib_file_path = "//caffe2/test/cpp/jit:test_custom_class_registrations"
        elif IS_SANDCASTLE or IS_MACOS:
            raise unittest.SkipTest("non-portable load_library call used in test")
        elif IS_WINDOWS:
            lib_file_path = find_library_location("torchbind_test.dll")
        else:
            lib_file_path = find_library_location("libtorchbind_test.so")
        torch.ops.load_library(str(lib_file_path))

    def test_lift_custom_obj(self):
        # TODO: fix this test once custom class tracing is implemented

        custom_obj = torch.classes._TorchScriptTesting._PickleTester([3, 4])

        class Foo(torch.nn.Module):
            def forward(self, x):
                return x + x

        f = Foo()

        inputs = (torch.zeros(4, 4),)
        ep = export(f, inputs)

        # Replace one of the values with an instance of our custom class
        for node in ep.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                with ep.graph.inserting_before(node):
                    setattr(ep.graph_module, "custom_obj", custom_obj)
                    getattr_node = ep.graph.get_attr("custom_obj")
                    # Copy over an nn_module_stack as they are required.
                    getattr_node.meta["nn_module_stack"] = node.meta["nn_module_stack"]
                    custom_node = ep.graph.call_function(
                        torch.ops._TorchScriptTesting.take_an_instance.default,
                        (getattr_node,),
                    )
                    custom_node.meta["val"] = torch.ones(4, 4)
                    # Copy over an nn_module_stack as they are required.
                    custom_node.meta["nn_module_stack"] = node.meta["nn_module_stack"]
                    custom_node.meta["torch_fn"] = (
                        "custom_op",
                        "torch.ops._TorchScriptTesting.take_an_instance.default",
                    )
                    arg0, _ = node.args
                    node.args = (arg0, custom_node)

        from torch._export.passes.lift_constants_pass import lift_constants_pass
        from torch._export.serde.serialize import deserialize, serialize

        constants = lift_constants_pass(ep.graph_module, ep.graph_signature, {})
        for k, v in constants.items():
            assert k not in ep.constants
            ep._constants[k] = v
        serialized_vals = serialize(ep)
        deserialized_ep = deserialize(serialized_vals)

        for node in deserialized_ep.graph.nodes:
            if (
                node.op == "call_function"
                and node.target
                == torch.ops._TorchScriptTesting.take_an_instance.default
            ):
                arg = node.args[0]
                self.assertTrue(arg.op == "placeholder")

    def test_preserve_non_cia_op(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return torch.nn.functional.elu(x)

        ep = export(M(), (torch.randn(2, 3, 4, 5),))
        FileCheck().check_count("torch.ops.aten.elu.default", 1, exactly=True).run(
            ep.graph_module.code
        )

        ep = ep.run_decompositions(
            decomp_table=get_decompositions([torch.ops.aten.elu.default]),
            _preserve_ops=[torch.ops.aten.elu.default],
        )
        FileCheck().check_count("torch.ops.aten.elu.default", 1, exactly=True).run(
            ep.graph_module.code
        )

    def test_preserve_cia_op(self):
        class StaticResizeBilinear2dModule(torch.nn.Module):
            def forward(self, x):
                a = torch.nn.functional.interpolate(
                    x,
                    size=(x.shape[2] * 2, x.shape[3] * 3),
                    mode="bilinear",
                    align_corners=False,
                    antialias=False,
                )
                return a

        ep = export(StaticResizeBilinear2dModule(), (torch.randn(2, 3, 4, 5),))
        FileCheck().check_count(
            "torch.ops.aten.upsample_bilinear2d.vec", 1, exactly=True
        ).run(ep.graph_module.code)

        decomp_table = get_decompositions([torch.ops.aten.upsample_bilinear2d.vec])
        ep = ep.run_decompositions(
            decomp_table=decomp_table,
            _preserve_ops=[torch.ops.aten.upsample_bilinear2d.vec],
        )
        assert torch.ops.aten.upsample_bilinear2d.vec in decomp_table
        FileCheck().check_count(
            "torch.ops.aten.upsample_bilinear2d.vec", 1, exactly=True
        ).run(ep.graph_module.code)


if __name__ == "__main__":
    run_tests()