1""" 2PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes 3with test_sym_bool) 4""" 5 6 7# Owner(s): ["oncall: export"] 8import copy 9import io 10import tempfile 11import unittest 12import zipfile 13from pathlib import Path 14 15import torch 16import torch._dynamo as torchdynamo 17import torch.export._trace 18import torch.utils._pytree as pytree 19from torch._export.db.case import ExportCase, SupportLevel 20from torch._export.db.examples import all_examples 21from torch._export.serde.serialize import ( 22 canonicalize, 23 deserialize, 24 ExportedProgramDeserializer, 25 ExportedProgramSerializer, 26 serialize, 27 SerializeError, 28) 29from torch._higher_order_ops.torchbind import enable_torchbind_tracing 30from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode 31from torch.export import Dim, export, load, save 32from torch.fx.experimental.symbolic_shapes import is_concrete_int, ValueRanges 33from torch.testing._internal.common_utils import ( 34 instantiate_parametrized_tests, 35 IS_WINDOWS, 36 parametrize, 37 run_tests, 38 TemporaryFileName, 39 TestCase, 40) 41from torch.testing._internal.torchbind_impls import init_torchbind_implementations 42 43 44def get_filtered_export_db_tests(): 45 return [ 46 (name, case) 47 for name, case in all_examples().items() 48 if case.support_level == SupportLevel.SUPPORTED 49 ] 50 51 52@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support") 53class TestSerialize(TestCase): 54 def test_export_with_extension_op_serialization(self): 55 class TestModule(torch.nn.Module): 56 def forward(self, x): 57 return x + x 58 59 class FooExtensionOp: 60 def __hash__(self): 61 return 0 62 63 def __eq__(self, other): 64 return type(other) == type(self) 65 66 def __call__(self, *args, **kwargs): 67 return torch.ops.aten.add.Tensor(*args, **kwargs) 68 69 @property 70 def __name__(self): 71 return "foo.my_op" 72 73 class ExtensionVerifier(torch._export.verifier.Verifier): 74 dialect = 
"FOO" 75 76 def allowed_op_types(self): 77 return super().allowed_op_types() + (FooExtensionOp,) 78 79 class FooExtensionHandler(torch._export.serde.serialize.ExtensionHandler): 80 @classmethod 81 def namespace(cls): 82 return "foo" 83 84 @classmethod 85 def to_op_name(cls, op): 86 return "my_op" 87 88 @classmethod 89 def from_op_name(cls, name: str): 90 self.assertEqual(name, "my_op") 91 return FooExtensionOp() 92 93 @classmethod 94 def op_schema(cls, op): 95 return torch.ops.aten.add.Tensor._schema 96 97 inp = (torch.ones(10),) 98 ep = export(TestModule(), inp) 99 100 # Register the custom op handler. 101 foo_custom_op = FooExtensionOp() 102 torch._export.serde.serialize.register_extension( 103 FooExtensionOp, FooExtensionHandler 104 ) 105 106 new_gm = copy.deepcopy(ep.graph_module) 107 # Inject the custom operator. 108 for node in new_gm.graph.nodes: 109 if node.name == "add": 110 node.target = foo_custom_op 111 112 new_ep = ep._update(new_gm, ep.graph_signature, verifiers=[ExtensionVerifier]) 113 serialized = serialize(new_ep) 114 deserialized = deserialize(serialized) 115 self.assertEqual( 116 len( 117 deserialized.graph.find_nodes(op="call_function", target=foo_custom_op) 118 ), 119 1, 120 ) 121 122 def test_predispatch_export_with_autograd_op(self): 123 class Foo(torch.nn.Module): 124 def __init__(self) -> None: 125 super().__init__() 126 127 def forward(self, x): 128 with torch.enable_grad(): 129 return x + x 130 131 inp = (torch.ones(10),) 132 with torch.no_grad(): 133 from torch.export._trace import _export 134 135 ep = _export(Foo(), inp, pre_dispatch=True) 136 137 buffer = io.BytesIO() 138 torch.export.save(ep, buffer) 139 buffer.seek(0) 140 loaded_ep = torch.export.load(buffer) 141 142 exp_out = ep.module()(*inp) 143 actual_out = loaded_ep.module()(*inp) 144 self.assertEqual(exp_out, actual_out) 145 self.assertEqual(exp_out.requires_grad, actual_out.requires_grad) 146 147 def test_export_example_inputs_preserved(self): 148 class 
MyModule(torch.nn.Module): 149 """A test module with that has multiple args and uses kwargs""" 150 151 def __init__(self) -> None: 152 super().__init__() 153 self.p = torch.nn.Parameter(torch.ones(2, 3)) 154 155 def forward(self, x, y, use_p=False): 156 out = x + y 157 if use_p: 158 out += self.p 159 return out 160 161 model = MyModule().eval() 162 random_inputs = (torch.rand([2, 3]), torch.rand([2, 3])) 163 exp_program = torch.export.export(model, random_inputs, {"use_p": True}) 164 165 output_buffer = io.BytesIO() 166 # Tests that example inputs are preserved when saving and loading module. 167 torch.export.save(exp_program, output_buffer) 168 loaded_model = torch.export.load(output_buffer) 169 # Extract the example inputs from before and after saving. 170 orig_args, orig_kwargs = exp_program.example_inputs 171 loaded_args, loaded_kwargs = loaded_model.example_inputs 172 # Run both modules and confirm that outputs match. 173 orig_out = exp_program.module()(*orig_args, **orig_kwargs) 174 loaded_out = loaded_model.module()(*loaded_args, **loaded_kwargs) 175 self.assertEqual(orig_out, loaded_out) 176 177 def test_metadata_parsing_with_layer_split(self): 178 # Tests that modules with more complicated layer patterns can be serialized 179 # and deserialized correctly. 180 class MyModule(torch.nn.Module): 181 def __init__(self) -> None: 182 super().__init__() 183 self.layers = torch.nn.Sequential( 184 torch.nn.SiLU(), 185 torch.nn.SiLU(), 186 torch.nn.SiLU(), 187 ) 188 189 def forward(self, x): 190 # Splitting layers of a sequential stack introduces commas and parens 191 # into metadata trace. 192 out_start, out_rest = self.layers[0], self.layers[1:] 193 h = out_start(x) 194 h = out_rest(h) 195 return h 196 197 inp = (torch.ones(10),) 198 # Module will only be able to roundtrip if metadata 199 # can be correctly parsed. 
200 ep = export(MyModule(), inp) 201 buffer = io.BytesIO() 202 save(ep, buffer) 203 loaded_ep = load(buffer) 204 205 # Check that both modules run to confirm load was successful. 206 exp_out = ep.module()(*inp) 207 actual_out = loaded_ep.module()(*inp) 208 self.assertEqual(exp_out, actual_out) 209 210 def test_serialize_constant_outputs(self): 211 class MyModule(torch.nn.Module): 212 def __init__(self) -> None: 213 super().__init__() 214 215 def forward(self, x): 216 # Along with tensor output, return Nonetype 217 # and constant. Although these outputs aren't 218 # very useful, they do show up in graphs. 219 return x + 1, None, 1024 220 221 # Check that module can be roundtripped, thereby confirming proper deserialization. 222 inp = (torch.ones(10),) 223 ep = export(MyModule(), inp) 224 buffer = io.BytesIO() 225 save(ep, buffer) 226 loaded_ep = load(buffer) 227 228 exp_out = ep.module()(*inp) 229 actual_out = loaded_ep.module()(*inp) 230 self.assertEqual(exp_out, actual_out) 231 232 def test_serialize_multiple_returns_from_node(self) -> None: 233 class MyModule(torch.nn.Module): 234 def __init__(self) -> None: 235 super().__init__() 236 237 def forward(self, x, w, b): 238 return torch.nn.functional.layer_norm( 239 x, 240 x.size()[1:], 241 weight=w, 242 bias=b, 243 eps=1e-5, 244 ) 245 246 exported_module = export( 247 MyModule(), 248 ( 249 torch.ones([512, 512], requires_grad=True), 250 torch.ones([512]), 251 torch.ones([512]), 252 ), 253 ).run_decompositions() 254 255 serialized = ExportedProgramSerializer().serialize(exported_module) 256 node = serialized.exported_program.graph_module.graph.nodes[-1] 257 self.assertEqual(node.target, "torch.ops.aten.native_layer_norm.default") 258 # aten::native_layer_norm returns 3 tensors 259 self.assertEqual(len(node.outputs), 3) 260 261 # check the names are unique 262 seen = set() 263 for output in node.outputs: 264 name = output.as_tensor.name 265 self.assertNotIn(name, seen) 266 seen.add(name) 267 268 def 
test_serialize_sym_int(self) -> None: 269 class DynamicShapeSimpleModel(torch.nn.Module): 270 def __init__(self): 271 super().__init__() 272 273 def forward(self, a, b, c) -> torch.Tensor: 274 d = (torch.matmul(a, b) + c) / 2 275 d_s0 = d.shape[0] 276 d_s1 = d.shape[1] 277 d_s3 = d_s0 * d_s1 278 e = d.view(d_s3) 279 return torch.cat([e, e]) 280 281 inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7)) 282 dim0_ac = torch.export.Dim("dim0_ac") 283 dim1_bc = torch.export.Dim("dim1_b") 284 dynamic_shapes = { 285 "a": {0: dim0_ac}, 286 "b": {1: dim1_bc}, 287 "c": {0: dim0_ac, 1: dim1_bc}, 288 } 289 exported_module = export( 290 DynamicShapeSimpleModel(), inputs, dynamic_shapes=dynamic_shapes 291 ).run_decompositions() 292 serialized = ExportedProgramSerializer().serialize(exported_module) 293 sym_size_nodes = [ 294 node 295 for node in serialized.exported_program.graph_module.graph.nodes 296 if node.target == "torch.ops.aten.sym_size.int" 297 ] 298 for node in sym_size_nodes: 299 self.assertEqual(node.inputs[0].name, "self") 300 self.assertEqual(node.inputs[1].name, "dim") 301 302 def test_serialize_list_returns(self) -> None: 303 class MyModule(torch.nn.Module): 304 def __init__(self) -> None: 305 super().__init__() 306 307 def forward(self, x): 308 return torch.split(x, 2) 309 310 input = torch.arange(10.0).reshape(5, 2) 311 exported_module = export(MyModule(), (input,)).run_decompositions() 312 313 serialized = ExportedProgramSerializer().serialize(exported_module) 314 node = serialized.exported_program.graph_module.graph.nodes[-1] 315 # split.Tensor gets decomposed to split_with_sizes by the core ATen decomposition table 316 self.assertEqual(node.target, "torch.ops.aten.split_with_sizes.default") 317 self.assertEqual(len(node.outputs), 1) 318 # Input looks like: 319 # tensor([[0, 1], 320 # [2, 3], 321 # [4, 5], 322 # [6, 7], 323 # [8, 9]]) 324 # Output looks like: 325 # (tensor([[0, 1], 326 # [2, 3]]), 327 # tensor([[4, 5], 328 # [6, 7]]), 329 # 
tensor([[8, 9]])) 330 self.assertEqual(len(node.outputs[0].as_tensors), 3) 331 332 # check the names are unique 333 seen = set() 334 for output in node.outputs[0].as_tensors: 335 name = output.name 336 self.assertNotIn(name, seen) 337 seen.add(name) 338 339 def test_multi_return_some_unused(self) -> None: 340 """ 341 Make sure the serialized output matches the op schema, even if some of 342 the arguments are never used in the graph. 343 """ 344 345 class MyModule(torch.nn.Module): 346 def __init__(self) -> None: 347 super().__init__() 348 349 def forward(self, x): 350 return torch.ops.aten.var_mean.correction(x, [1])[0] 351 352 exported_module = export( 353 MyModule(), 354 (torch.ones([512, 512], requires_grad=True),), 355 ).run_decompositions() 356 357 serialized = ExportedProgramSerializer().serialize(exported_module) 358 node = serialized.exported_program.graph_module.graph.nodes[-1] 359 self.assertEqual(node.target, "torch.ops.aten.var_mean.correction") 360 self.assertEqual(len(node.outputs), 2) 361 362 # check the names are unique 363 seen = set() 364 for output in node.outputs: 365 name = output.as_tensor.name 366 self.assertNotIn(name, seen) 367 seen.add(name) 368 369 def test_rational_ranges(self) -> None: 370 class M(torch.nn.Module): 371 def forward(self, x): 372 return x + x 373 374 ep = torch.export.export( 375 M(), (torch.randn(4),), dynamic_shapes=({0: Dim("temp")},) 376 ) 377 378 range_constraints = list(ep.range_constraints.keys()) 379 assert len(range_constraints) == 1 380 symint = range_constraints[0] 381 382 import sympy 383 384 upper_range = sympy.Rational(10, 3) 385 lower_range = sympy.Rational(10, 6) 386 ep.range_constraints[symint] = ValueRanges(lower=lower_range, upper=upper_range) 387 388 serialized = ExportedProgramSerializer().serialize(ep) 389 self.assertEqual(serialized.exported_program.range_constraints["s0"].min_val, 2) 390 self.assertEqual(serialized.exported_program.range_constraints["s0"].max_val, 3) 391 392 def 
test_kwargs_default(self) -> None: 393 """ 394 Tests that the kwargs default values are serialized even if they are not 395 specified 396 """ 397 398 class Foo(torch.nn.Module): 399 def forward(self, x: torch.Tensor) -> torch.Tensor: 400 values = torch.randn(3, 2) 401 return torch.searchsorted(x, values, side="right", right=True) 402 403 f = Foo() 404 405 x, _ = torch.sort(torch.randn(3, 4)) 406 exported_module = export(f, (x,)).run_decompositions() 407 serialized = ExportedProgramSerializer().serialize(exported_module) 408 409 node = serialized.exported_program.graph_module.graph.nodes[-1] 410 self.assertEqual(node.target, "torch.ops.aten.searchsorted.Tensor") 411 self.assertEqual(len(node.inputs), 4) 412 self.assertEqual(node.inputs[2].name, "right") 413 self.assertEqual(node.inputs[2].arg.as_bool, True) 414 self.assertEqual(node.inputs[3].name, "side") 415 self.assertEqual(node.inputs[3].arg.as_string, "right") 416 417 def test_canonicalize(self) -> None: 418 class Module(torch.nn.Module): 419 def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 420 a = y + x 421 b = x + y 422 return b + a 423 424 ep = torch.export.export(Module(), (torch.randn(3, 2), torch.randn(3, 2))) 425 s = ExportedProgramSerializer().serialize(ep) 426 c = canonicalize(s.exported_program) 427 g = c.graph_module.graph 428 self.assertLess( 429 g.nodes[0].inputs[0].arg.as_tensor.name, 430 g.nodes[1].inputs[0].arg.as_tensor.name, 431 ) 432 433 def test_int_list(self) -> None: 434 class M(torch.nn.Module): 435 def forward(self, x): 436 return torch.ops.aten.sum.dim_IntList(x, []) 437 438 ep = torch.export.export(M(), (torch.randn(3, 2),)) 439 serialized = ExportedProgramSerializer().serialize(ep) 440 for node in serialized.exported_program.graph_module.graph.nodes: 441 if "aten.sum.dim_IntList" in node.target: 442 self.assertEqual(node.inputs[1].arg.type, "as_ints") 443 444 445@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test") 446@unittest.skipIf(not 
torchdynamo.is_dynamo_supported(), "dynamo doesn't support") 447class TestDeserialize(TestCase): 448 def setUp(self): 449 super().setUp() 450 init_torchbind_implementations() 451 452 def _check_graph_nodes(self, gm1, gm2, _check_meta=True): 453 # TODO: The _check_meta flag bypasses checking for 454 # source_fn/nn_module_stack as there is an issue with 455 # roundtripping the source_fn value on torch.ops.map nodes 456 # original source_fn: <functorch.experimental._map.MapWrapper object at 0x7f80a0549930> 457 # deserialized source_fn: 'functorch.experimental._map.map' 458 459 self.assertEqual(len(gm1.graph.nodes), len(gm2.graph.nodes)) 460 461 for node1, node2 in zip(gm1.graph.nodes, gm2.graph.nodes): 462 self.assertEqual(node1.op, node2.op) 463 if node1.op == "call_function": 464 # Check "val" metadata 465 val1 = node1.meta.get("val", None) 466 val2 = node2.meta.get("val", None) 467 if val1 is None or val2 is None: 468 # Either both are None 469 self.assertEqual(val1, val2) 470 elif isinstance(val1, FakeTensor) and isinstance(val2, FakeTensor): 471 # Or both are fake tensors with the same shape/dtype 472 self.assertEqual(len(val1.shape), len(val2.shape)) 473 for s1, s2 in zip(val1.shape, val2.shape): 474 if is_concrete_int(s1) and is_concrete_int(s2): 475 self.assertEqual(s1, s2) 476 else: 477 self.assertEqual(str(s1), str(s2)) 478 self.assertEqual(val1.dtype, val2.dtype) 479 elif isinstance(val1, (list, tuple)) and isinstance( 480 val2, (list, tuple) 481 ): 482 # Or both are fake tensors lists with one element and with the 483 # same shape/dtype 484 for v1, v2 in zip( 485 pytree.tree_leaves(val1), pytree.tree_leaves(val2) 486 ): 487 if isinstance(v1, FakeTensor): 488 self.assertEqual(v1.shape, v2.shape) 489 self.assertEqual(v1.dtype, v2.dtype) 490 else: 491 # For expressions like 's0 < 10' can only compare through string 492 self.assertEqual(str(val1), str(val2)) 493 494 # Check "stack_trace" metadata 495 self.assertEqual( 496 node1.meta.get("stack_trace", None), 
497 node2.meta.get("stack_trace", None), 498 ) 499 500 if node1.target == torch.ops.higher_order.cond: 501 true_graph1 = getattr(gm1, node1.args[1].target) 502 true_graph2 = getattr(gm2, node2.args[1].target) 503 self._check_graph_nodes(true_graph1, true_graph2) 504 505 false_graph1 = getattr(gm1, node1.args[2].target) 506 false_graph2 = getattr(gm2, node2.args[2].target) 507 self._check_graph_nodes(false_graph1, false_graph2) 508 elif node1.target == torch.ops.higher_order.map_impl: 509 map_graph1 = getattr(gm1, node1.args[0].target) 510 map_graph2 = getattr(gm2, node2.args[0].target) 511 self._check_graph_nodes(map_graph1, map_graph2, False) 512 513 if _check_meta and node1.op not in ("get_attr", "placeholder", "output"): 514 # Check "nn_module_stack" metadata 515 self.assertEqual( 516 node1.meta.get("nn_module_stack", None), 517 node2.meta.get("nn_module_stack", None), 518 ) 519 # Check "source_fn_stack" metadata 520 self.assertEqual( 521 node1.meta.get("source_fn_stack", None), 522 node2.meta.get("source_fn_stack", None), 523 ) 524 525 def check_graph( 526 self, 527 fn, 528 inputs, 529 dynamic_shapes=None, 530 _check_meta=True, 531 use_pre_dispatch=True, 532 strict=True, 533 ) -> None: 534 """Export a graph, serialize it, deserialize it, and compare the results.""" 535 536 def _deepcopy_inputs(inputs): 537 # copy.deepcopy(deepcopy) can fail if tensor inputs have attribute (i.e. __dict__). 538 # we remove __dict__ when deepcopying. 539 dict_mapping = dict() 540 inputs_clone = () 541 for idx, i in enumerate(inputs): 542 if isinstance(i, torch.Tensor) and hasattr(inputs[0], "__dict__"): 543 dict_mapping[idx] = i.__dict__ 544 i.__dict__ = {} 545 inputs_clone += (copy.deepcopy(i),) 546 547 # Add __dict__ back. 
548 for k, v in dict_mapping.items(): 549 inputs[k].__dict__ = v 550 inputs_clone[k].__dict__ = v 551 return inputs_clone 552 553 def _check_graph(pre_dispatch): 554 if pre_dispatch: 555 ep = torch.export._trace._export( 556 fn, 557 _deepcopy_inputs(inputs), 558 {}, 559 dynamic_shapes=dynamic_shapes, 560 pre_dispatch=True, 561 strict=strict, 562 ) 563 else: 564 ep = torch.export.export( 565 fn, 566 _deepcopy_inputs(inputs), 567 {}, 568 dynamic_shapes=dynamic_shapes, 569 strict=strict, 570 ) 571 ep.graph.eliminate_dead_code() 572 573 serialized_artifact = serialize(ep, opset_version={"aten": 0}) 574 deserialized_ep = deserialize( 575 serialized_artifact, expected_opset_version={"aten": 0} 576 ) 577 deserialized_ep.graph.eliminate_dead_code() 578 579 orig_outputs = ep.module()(*_deepcopy_inputs(inputs)) 580 loaded_outputs = deserialized_ep.module()(*_deepcopy_inputs(inputs)) 581 582 flat_orig_outputs = pytree.tree_leaves(orig_outputs) 583 flat_loaded_outputs = pytree.tree_leaves(loaded_outputs) 584 585 for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs): 586 self.assertEqual(type(orig), type(loaded)) 587 if isinstance(orig, torch.Tensor): 588 if orig.is_meta: 589 self.assertEqual(orig, loaded) 590 else: 591 self.assertTrue(torch.allclose(orig, loaded)) 592 else: 593 self.assertEqual(orig, loaded) 594 self._check_graph_nodes( 595 ep.graph_module, deserialized_ep.graph_module, _check_meta 596 ) 597 598 if use_pre_dispatch: 599 _check_graph(pre_dispatch=True) 600 _check_graph(pre_dispatch=False) 601 else: 602 _check_graph(pre_dispatch=False) 603 604 def test_optional_tuple(self): 605 with torch.library._scoped_library("mylib", "FRAGMENT") as lib: 606 torch.library.define( 607 "mylib::foo", 608 "(Tensor a, Tensor b, Tensor? 
c) -> (Tensor, Tensor?)", 609 tags=torch.Tag.pt2_compliant_tag, 610 lib=lib, 611 ) 612 613 @torch.library.impl("mylib::foo", "cpu", lib=lib) 614 @torch.library.impl_abstract("mylib::foo") 615 def foo_impl(a, b, c): 616 res2 = None 617 if c is not None: 618 res2 = c + a + b 619 return a + b, res2 620 621 class M(torch.nn.Module): 622 def forward(self, a, b, c): 623 return torch.ops.mylib.foo(a, b, c) 624 625 self.check_graph(M(), (torch.randn(3), torch.randn(3), torch.randn(3))) 626 627 def test_auto_functionalize(self): 628 with torch.library._scoped_library("mylib", "FRAGMENT") as lib: 629 torch.library.define( 630 "mylib::foo1", 631 "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> Tensor", 632 tags=torch.Tag.pt2_compliant_tag, 633 lib=lib, 634 ) 635 torch.library.define( 636 "mylib::foo2", 637 "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> (Tensor, Tensor)", 638 tags=torch.Tag.pt2_compliant_tag, 639 lib=lib, 640 ) 641 torch.library.define( 642 "mylib::foo3", 643 "(Tensor(a!) x, Tensor[] y, Tensor(b!) 
z, SymInt w, Tensor n) -> ()", 644 tags=torch.Tag.pt2_compliant_tag, 645 lib=lib, 646 ) 647 648 @torch.library.impl("mylib::foo1", "cpu", lib=lib) 649 @torch.library.impl_abstract("mylib::foo1") 650 def foo1_impl(x, y, z, w, n): 651 x.add_(y[0] + w) 652 z.add_(y[1] + n) 653 return n + n 654 655 @torch.library.impl("mylib::foo2", "cpu", lib=lib) 656 @torch.library.impl_abstract("mylib::foo2") 657 def foo2_impl(x, y, z, w, n): 658 x.add_(y[0] + w) 659 z.add_(y[1] + n) 660 return (n + n, n * n) 661 662 @torch.library.impl("mylib::foo3", "cpu", lib=lib) 663 @torch.library.impl_abstract("mylib::foo3") 664 def foo3_impl(x, y, z, w, n): 665 x.add_(y[0] + w) 666 z.add_(y[1] + n) 667 return 668 669 class M(torch.nn.Module): 670 def forward(self, x, y, z, n): 671 n = torch.ops.mylib.foo1(x, y, z, 2, n) 672 torch.ops.mylib.foo3(x, y, z, 2, n) 673 return torch.ops.mylib.foo2(x, y, z, 2, n) 674 675 x = torch.randn(3) 676 y = (torch.randn(3), torch.randn(3)) 677 z = torch.randn(3) 678 n = torch.randn(3) 679 orig_args = (x, y, z, n) 680 681 # TODO Auto_functionalize is not supported on pre_dispatch IR 682 self.check_graph(M(), orig_args, use_pre_dispatch=False) 683 684 def test_multi_return(self) -> None: 685 """ 686 Test multiple return from a single node (ex. 
layer_norm has 2 outputs) 687 """ 688 689 class MyModule(torch.nn.Module): 690 def __init__(self) -> None: 691 super().__init__() 692 693 def forward(self, x, w, b): 694 return torch.nn.functional.layer_norm( 695 x, 696 x.size()[1:], 697 weight=w, 698 bias=b, 699 eps=1e-5, 700 ) 701 702 inputs = ( 703 torch.ones([512, 512], requires_grad=True), 704 torch.ones([512]), 705 torch.ones([512]), 706 ) 707 self.check_graph(MyModule(), inputs) 708 709 def test_basic(self) -> None: 710 class MyModule(torch.nn.Module): 711 def __init__(self) -> None: 712 super().__init__() 713 714 def forward(self, x): 715 x = x + x 716 x = x * x 717 x = x / x 718 return x, x.clone() 719 720 inputs = (torch.ones([512], requires_grad=True),) 721 self.check_graph(MyModule(), inputs) 722 723 def test_dynamic(self) -> None: 724 class DynamicShapeSimpleModel(torch.nn.Module): 725 def __init__(self) -> None: 726 super().__init__() 727 728 def forward(self, a, b, c) -> torch.Tensor: 729 d = (torch.matmul(a, b) + c) / 2 730 d_s0 = d.shape[0] 731 d_s1 = d.shape[1] 732 d_s3 = d_s0 * d_s1 733 e = d.view(d_s3) 734 return torch.cat([e, e]) 735 736 inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7)) 737 dim0_ac = torch.export.Dim("dim0_ac") 738 dynamic_shapes = {"a": {0: dim0_ac}, "b": None, "c": {0: dim0_ac}} 739 self.check_graph(DynamicShapeSimpleModel(), inputs, dynamic_shapes) 740 741 def test_sym_bool(self): 742 class Module(torch.nn.Module): 743 def forward(self, x, y): 744 assert x.size(0) in y 745 return x + y 746 747 f = Module() 748 self.check_graph(f, (torch.ones(1), torch.ones(3))) 749 750 def test_shape(self): 751 class Foo(torch.nn.Module): 752 def forward(self, x): 753 z, y = x.size() 754 return z + y + x[0], z 755 756 inputs = (torch.ones(2, 3),) 757 dim0_x, dim1_x = torch.export.dims("dim0_x", "dim1_x") 758 dynamic_shapes = {"x": (dim0_x, dim1_x)} 759 self.check_graph(Foo(), inputs, dynamic_shapes) 760 761 def test_module(self): 762 class M(torch.nn.Module): 763 def 
__init__(self) -> None: 764 super().__init__() 765 self.linear1 = torch.nn.Linear(3, 3) 766 self.relu = torch.nn.ReLU() 767 self.linear2 = torch.nn.Linear(3, 5) 768 769 def forward(self, x): 770 x = self.linear1(x) 771 x = self.linear1(x) 772 x = torch.nn.functional.relu(x) 773 x = self.linear2(x) 774 return x 775 776 inputs = (torch.randn(3, 3),) 777 self.check_graph(M(), inputs) 778 779 def test_module_meta(self): 780 class M(torch.nn.Module): 781 def __init__(self) -> None: 782 super().__init__() 783 self.p = torch.nn.Parameter(torch.ones(3, 3)) 784 785 def forward(self, x): 786 return self.p + x 787 788 with torch.device("meta"): 789 mod = M() 790 791 inputs = (torch.randn(3, 3, device="meta"),) 792 self.check_graph(mod, inputs) 793 794 def test_cond(self): 795 from functorch.experimental.control_flow import cond 796 797 inputs = torch.ones(4, 3), torch.zeros(4, 3) 798 799 class M(torch.nn.Module): 800 def forward(self, x, y): 801 def t(x, y): 802 return x + y 803 804 def f(x, y): 805 return x - y 806 807 return cond(x[0][0] > 4, t, f, [x, y]) 808 809 self.check_graph(M(), inputs) 810 811 def test_map(self): 812 from functorch.experimental import control_flow 813 814 def f(x, y): 815 return x + y 816 817 class Module(torch.nn.Module): 818 def forward(self, xs, y): 819 return control_flow.map(f, xs, y) 820 821 g = Module() 822 inputs = (torch.ones(3, 2, 2), torch.ones(2)) 823 self.check_graph(g, inputs, _check_meta=False) 824 825 def test_tensor_tensor_list(self): 826 with torch.library._scoped_library("_export", "FRAGMENT") as lib: 827 lib.define( 828 "_test_tensor_tensor_list_output(Tensor x, Tensor y) -> (Tensor, Tensor[])", 829 tags=torch.Tag.pt2_compliant_tag, 830 ) 831 832 def _test_tensor_tensor_list_output(x, y): 833 return y, [x] 834 835 lib.impl( 836 "_test_tensor_tensor_list_output", 837 _test_tensor_tensor_list_output, 838 "CPU", 839 ) 840 lib.impl( 841 "_test_tensor_tensor_list_output", 842 _test_tensor_tensor_list_output, 843 "Meta", 844 ) 845 846 
class M(torch.nn.Module): 847 def forward(self, x, y): 848 a, b = torch.ops._export._test_tensor_tensor_list_output.default( 849 x, y 850 ) 851 return a + b[0] 852 853 self.check_graph(M(), (torch.rand(3, 2), torch.rand(3, 2))) 854 855 def test_list_of_optional_tensors(self) -> None: 856 class MyModule(torch.nn.Module): 857 def __init__(self) -> None: 858 super().__init__() 859 860 def forward(self, x, y, z): 861 indices = [None, None, torch.tensor([1, 3, 5, 7])] 862 indexed = torch.ops.aten.index.Tensor(x + y, indices) 863 return indexed + z 864 865 inputs = (torch.rand(8, 8, 8), torch.rand(8, 8, 8), torch.rand(8, 8, 4)) 866 self.check_graph(MyModule(), inputs) 867 868 def test_sym_ite(self): 869 class Foo(torch.nn.Module): 870 def forward(self, x): 871 b = x.shape[0] == 5 872 ret = torch.sym_ite(b, x.shape[0], x.shape[1]) 873 return ret 874 875 dynamic_shapes = {"x": {0: Dim("dim0"), 1: Dim("dim1")}} 876 self.check_graph(Foo(), (torch.ones(4, 5),), dynamic_shapes=dynamic_shapes) 877 878 def test_multiple_getitem(self): 879 class M(torch.nn.Module): 880 def forward(self, x): 881 a, b = torch.topk(x, 2) 882 a = a * 2 883 return a, b 884 885 ep = torch.export.export(M(), (torch.ones(3),)) 886 887 # insert another getitem node 888 for node in ep.graph.nodes: 889 if node.op == "call_function" and node.target == torch.ops.aten.mul.Tensor: 890 getitem_0 = node.args[0] 891 with ep.graph.inserting_before(getitem_0): 892 getitem_copy = ep.graph.node_copy(getitem_0) 893 mul_node = ep.graph.call_function( 894 torch.ops.aten.mul.Tensor, (getitem_copy, 2) 895 ) 896 mul_node.meta = copy.copy(getitem_copy.meta) 897 node.args = (getitem_0, mul_node) 898 899 deserialized_ep = deserialize(serialize(ep)) 900 901 inp = (torch.randn(3),) 902 orig_res = ep.module()(*inp) 903 res = deserialized_ep.module()(*inp) 904 self.assertTrue(torch.allclose(orig_res[0], res[0])) 905 self.assertTrue(torch.allclose(orig_res[1], res[1])) 906 907 # The deserialized graph should have deduped getitem 
calls 908 self.assertExpectedInline( 909 deserialized_ep.graph_module.code.strip("\n"), 910 """\ 911def forward(self, x): 912 topk_default = torch.ops.aten.topk.default(x, 2); x = None 913 getitem = topk_default[0] 914 getitem_1 = topk_default[1]; topk_default = None 915 mul_tensor = torch.ops.aten.mul.Tensor(getitem, 2) 916 mul = torch.ops.aten.mul.Tensor(getitem, mul_tensor); getitem = mul_tensor = None 917 return (mul, getitem_1) 918 """, 919 ) 920 921 @parametrize( 922 "name,case", 923 get_filtered_export_db_tests(), 924 name_fn=lambda name, case: f"case_{name}", 925 ) 926 def test_exportdb_supported(self, name: str, case: ExportCase) -> None: 927 model = case.model 928 _check_meta = "map" not in name 929 self.check_graph(model, case.example_args, _check_meta=_check_meta) 930 931 def test_constraints(self): 932 class Module(torch.nn.Module): 933 def forward(self, x, y): 934 n = x.item() 935 torch._check_is_size(n) 936 return y.sum() + torch.ones(n, 5).sum() 937 938 f = Module() 939 self.check_graph(f, (torch.tensor(3), torch.randn(4, 5))) 940 941 def test_get_attr(self) -> None: 942 class Module(torch.nn.Module): 943 def forward(self, x): 944 return x + torch.tensor(3) 945 946 f = Module() 947 self.check_graph(f, (torch.tensor(3),)) 948 949 def test_get_attr_list(self) -> None: 950 class Module(torch.nn.Module): 951 def forward(self, x): 952 return torch.cat([x, torch.tensor([1, 1])]) 953 954 f = Module() 955 self.check_graph(f, (torch.tensor([1, 1]),)) 956 957 @unittest.skipIf(not torch.cuda.is_available(), "Requires cuda") 958 def test_device(self) -> None: 959 class MyModule(torch.nn.Module): 960 def __init__(self) -> None: 961 super().__init__() 962 self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True) 963 self.relu = torch.nn.ReLU() 964 965 def forward(self, x): 966 conv = self.conv(x) 967 relu = self.relu(conv) 968 mul = relu * 0.5 969 return mul 970 971 inp = torch.randn((1, 3, 224, 224), dtype=torch.float).to("cuda") 972 model = 
MyModule().eval().cuda() 973 self.check_graph(model, (inp,)) 974 975 def test_custom_obj_tuple_out(self): 976 class MyModule(torch.nn.Module): 977 def __init__(self) -> None: 978 super().__init__() 979 self.attr = torch.classes._TorchScriptTesting._Foo(10, 20) 980 981 def forward(self, x): 982 a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(self.attr, x) 983 y = a[0] + a[1] 984 b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y) 985 return x + b 986 987 m = MyModule() 988 inputs = (torch.ones(2, 3),) 989 self.check_graph(m, inputs, strict=False) 990 991 def test_custom_obj(self): 992 class MyModule(torch.nn.Module): 993 def __init__(self) -> None: 994 super().__init__() 995 self.attr = torch.classes._TorchScriptTesting._Foo(10, 20) 996 997 def forward(self, x): 998 a = torch.ops._TorchScriptTesting.takes_foo(self.attr, x) 999 b = torch.ops._TorchScriptTesting.takes_foo(self.attr, a) 1000 return x + b 1001 1002 m = MyModule() 1003 inputs = (torch.ones(2, 3),) 1004 self.check_graph(m, inputs, strict=False) 1005 1006 def test_custom_obj_list_out(self): 1007 class MyModule(torch.nn.Module): 1008 def __init__(self) -> None: 1009 super().__init__() 1010 self.attr = torch.classes._TorchScriptTesting._Foo(10, 20) 1011 1012 def forward(self, x): 1013 a = torch.ops._TorchScriptTesting.takes_foo_list_return(self.attr, x) 1014 y = a[0] + a[1] + a[2] 1015 b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y) 1016 return x + b 1017 1018 m = MyModule() 1019 inputs = (torch.ones(2, 3),) 1020 self.check_graph(m, inputs, strict=False) 1021 1022 def test_export_no_inputs(self): 1023 class M(torch.nn.Module): 1024 def __init__(self) -> None: 1025 super().__init__() 1026 self.p = torch.ones(3, 3) 1027 1028 def forward(self): 1029 return self.p * self.p 1030 1031 ep = torch.export.export(M(), ()) 1032 ep._example_inputs = None 1033 roundtrip_ep = deserialize(serialize(ep)) 1034 self.assertTrue(torch.allclose(ep.module()(), roundtrip_ep.module()())) 1035 1036 
instantiate_parametrized_tests(TestDeserialize)


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestSchemaVersioning(TestCase):
    """Checks that deserialization rejects a program with a mismatched schema version."""

    def test_error(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                return x + x

        f = Module()
        ep = export(f, (torch.randn(1, 3),))

        serialized_program = ExportedProgramSerializer().serialize(ep)
        # Force an impossible major version so it cannot match the current schema.
        serialized_program.exported_program.schema_version.major = -1
        with self.assertRaisesRegex(
            SerializeError, r"Serialized schema version .* does not match our current"
        ):
            ExportedProgramDeserializer().deserialize(
                serialized_program.exported_program,
                serialized_program.state_dict,
                serialized_program.constants,
                serialized_program.example_inputs,
            )


# Kwargs inputs are not wired up for the exportdb round-trip tests yet, so this
# case is expected to fail. unittest.expectedFailure marks the function object in
# place, so its return value does not need to be stored.
unittest.expectedFailure(TestDeserialize.test_exportdb_supported_case_fn_with_kwargs)


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestSaveLoad(TestCase):
    """Round-trip tests for torch.export.save / torch.export.load."""

    def test_save_buffer(self):
        inp = (torch.tensor([0.1, 0.1]),)

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                x = x + 1
                y = x.t()
                y = y.relu()
                y = self.linear(y)
                return y

        ep = export(Module(), inp)

        buffer = io.BytesIO()
        save(ep, buffer)
        buffer.seek(0)
        loaded_ep = load(buffer)

        self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp)))

    def test_save_file(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                return x * x

        f = Foo()

        inp = (torch.randn(2, 2),)
        ep = export(f, inp)

        # Give the file handle its own name rather than rebinding the module ``f``.
        with tempfile.NamedTemporaryFile() as tmp:
            save(ep, tmp)
            tmp.seek(0)
            loaded_ep = load(tmp)

        self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp)))

    def test_save_path(self):
        class Foo(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        f = Foo()

        inp = (torch.tensor([6]), torch.tensor([7]))
        ep = export(f, inp)

        with TemporaryFileName() as fname:
            path = Path(fname)
            save(ep, path)
            loaded_ep = load(path)

        self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp)))

    def test_save_extra(self):
        inp = (torch.tensor([0.1, 0.1]),)

        class Foo(torch.nn.Module):
            def forward(self, x):
                return x * x + x

        f = Foo()

        ep = export(f, inp)

        buffer = io.BytesIO()
        save(ep, buffer, extra_files={"extra.txt": "moo"})
        buffer.seek(0)
        # load() fills in the values of the keys requested in this dict.
        extra_files = {"extra.txt": ""}
        loaded_ep = load(buffer, extra_files=extra_files)

        self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp)))
        self.assertEqual(extra_files["extra.txt"], "moo")

    def test_version_error(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                return x + x

        f = Foo()

        ep = export(f, (torch.randn(1, 3),))

        with tempfile.NamedTemporaryFile() as tmp:
            save(ep, tmp)
            tmp.seek(0)

            # Append a second "version" entry holding a bogus value. zipfile warns
            # about the duplicate name; when the archive is read back, the later
            # entry takes precedence.
            with zipfile.ZipFile(tmp, "a") as zipf:
                zipf.writestr("version", "-1.1")

            with self.assertRaisesRegex(
                RuntimeError, r"Serialized version .* does not match our current"
            ):
                tmp.seek(0)
                load(tmp)

    def test_save_constants(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor(3)

            def forward(self, x):
                list_tensor = [torch.tensor(3), torch.tensor(4)]
                return x + self.a + list_tensor[0] + list_tensor[1]

        ep = export(Foo(), (torch.tensor(1),))
        buffer = io.BytesIO()
        save(ep, buffer)
        buffer.seek(0)
        loaded_ep = load(buffer)

        inp = (torch.tensor(1),)
        self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp)))


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestSerializeCustomClass(TestCase):
    """Serialization of torchbind custom objects and of custom node/graph metadata."""

    def setUp(self):
        super().setUp()
        # Presumably registers the _TorchScriptTesting classes/ops used below.
        init_torchbind_implementations()

    def test_custom_class(self):
        custom_obj = torch.classes._TorchScriptTesting._PickleTester([3, 4])

        class Foo(torch.nn.Module):
            def forward(self, x):
                return x + x

        f = Foo()

        inputs = (torch.zeros(4, 4),)
        ep = export(f, inputs)

        # Replace one of the values with an instance of our custom class
        for node in ep.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                with ep.graph.inserting_before(node):
                    custom_node = ep.graph.call_function(
                        torch.ops._TorchScriptTesting.take_an_instance.default,
                        (custom_obj,),
                    )
                custom_node.meta["val"] = torch.ones(4, 4)
                custom_node.meta["torch_fn"] = (
                    "take_an_instance",
                    "take_an_instance",
                )
                arg0, _ = node.args
                node.args = (arg0, custom_node)

        serialized_vals = serialize(ep)

        ep_str = serialized_vals.exported_program.decode("utf-8")
        # unittest assertions, unlike bare ``assert``, survive ``python -O`` and
        # report the operands on failure (pytest rewriting is disabled here).
        self.assertIn("class_fqn", ep_str)
        self.assertIn(custom_obj._type().qualified_name(), ep_str)

        deserialized_ep = deserialize(serialized_vals)

        num_custom_nodes = 0
        for node in deserialized_ep.graph.nodes:
            if (
                node.op == "call_function"
                and node.target
                == torch.ops._TorchScriptTesting.take_an_instance.default
            ):
                num_custom_nodes += 1
                arg = node.args[0]
                self.assertIsInstance(arg, torch._C.ScriptObject)
                self.assertEqual(arg._type(), custom_obj._type())
                self.assertEqual(arg.__getstate__(), custom_obj.__getstate__())
                self.assertEqual(arg.top(), 7)
        # Without this check the loop above passes vacuously if the node is lost.
        self.assertEqual(num_custom_nodes, 1)

    def test_custom_class_containing_fake_tensor(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.custom_obj = torch.classes._TorchScriptTesting._ContainsTensor(
                    torch.rand(2, 3)
                )

            def forward(self, x):
                return x + self.custom_obj.get()

        with FakeTensorMode():
            f = Foo()

        inputs = (torch.zeros(2, 3),)
        with enable_torchbind_tracing():
            ep = export(f, inputs, strict=False)

        serialized_vals = serialize(ep)
        ep = deserialize(serialized_vals)
        self.assertIsInstance(ep.constants["custom_obj"].get(), FakeTensor)

    def test_custom_tag_metadata_serialization(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                return x + x

        f = Foo()

        inputs = (torch.zeros(4, 4),)
        ep = export(f, inputs)

        new_gm = copy.deepcopy(ep.graph_module)
        new_gm.meta["custom"] = {}
        new_gm.meta["custom"]["f"] = "bar"

        for node in new_gm.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                node.meta["custom"] = {}
                node.meta["custom"]["quantization_tag"] = "foo"

        new_ep = ep._update(new_gm, ep.graph_signature)
        serialized_vals = serialize(new_ep)
        new_ep = deserialize(serialized_vals)

        self.assertEqual(new_ep.graph_module.meta["custom"]["f"], "bar")
        counter = 0
        for node in new_ep.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                counter += 1
                self.assertEqual(node.meta["custom"]["quantization_tag"], "foo")
        self.assertEqual(counter, 1)

    def test_custom_tag_metadata_decomp(self):
        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                return self.linear(x)

        f = Foo()

        inputs = (torch.ones(2, 2),)
        ep = export(f, inputs)

        new_gm = copy.deepcopy(ep.graph_module)
        new_gm.meta["custom"] = {}
        new_gm.meta["custom"]["f"] = "bar"

        counter = 0
        for node in new_gm.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == torch.ops.aten.linear.default
            ):
                counter += 1
                node.meta["custom"] = {}
                node.meta["custom"]["quantization_tag"] = "foo"
        self.assertEqual(counter, 1)

        new_ep = ep._update(new_gm, ep.graph_signature)
        new_ep = new_ep.run_decompositions()

        self.assertEqual(new_ep.graph_module.meta["custom"]["f"], "bar")
        # Every call_function node left after decomposition is expected to carry the
        # tag from the single original linear node.
        counter = 0
        for node in new_ep.graph.nodes:
            if node.op == "call_function":
                counter += 1
                self.assertEqual(node.meta["custom"]["quantization_tag"], "foo")
        self.assertGreater(counter, 1)

    # TODO: This test fails on Windows only, for reasons not yet understood. It is
    # kept commented out until that is resolved.
    # def test_custom_tag_metadata_reexport(self):
    #     class Foo(torch.nn.Module):
    #         def forward(self, x):
    #             return x + x
    #
    #     f = Foo()
    #
    #     inputs = (torch.zeros(4, 4),)
    #     ep = export(f, inputs)
    #
    #     new_gm = copy.deepcopy(ep.graph_module)
    #     new_gm.meta["custom"] = {}
    #     new_gm.meta["custom"]["f"] = "bar"
    #
    #     for node in new_gm.graph.nodes:
    #         if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
    #             node.meta["custom"] = {}
    #             node.meta["custom"]["quantization_tag"] = "foo"
    #
    #     new_ep = ep._update(new_gm, ep.graph_signature)
    #     new_ep = torch.export.export(new_ep.module(), inputs)
    #
    #     self.assertEqual(new_ep.graph_module.meta["custom"]["f"], "bar")
    #     counter = 0
    #     for node in new_ep.graph.nodes:
    #         if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
    #             counter += 1
    #             self.assertTrue(node.meta["custom"]["quantization_tag"] == "foo")
    #     self.assertEqual(counter, 1)

    def test_custom_tag_metadata_copy(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                return x + x

        f = Foo()

        inputs = (torch.zeros(4, 4),)
        ep = export(f, inputs)

        new_gm = copy.deepcopy(ep.graph_module)
        new_gm.meta["custom"] = {}
        new_gm.meta["custom"]["f"] = "bar"

        for node in new_gm.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                node.meta["custom"] = {}
                node.meta["custom"]["quantization_tag"] = "foo"

        new_gm = copy.deepcopy(new_gm)

        self.assertEqual(new_gm.meta["custom"]["f"], "bar")
        counter = 0
        for node in new_gm.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                counter += 1
                self.assertEqual(node.meta["custom"]["quantization_tag"], "foo")
        self.assertEqual(counter, 1)


if __name__ == "__main__":
    run_tests()