# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import copy
import os
import tempfile
import unittest
from typing import List, Optional, Tuple

import executorch.exir as exir

# Import passes
import executorch.exir.memory_planning  # noqa
import torch
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, memory, to_edge
from executorch.exir.dialects._ops import bind_pattern_to_op, ops, ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.emit import emit_program
from executorch.exir.graph_module import get_control_flow_submodules
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import (
    dead_code_elimination_pass,
    DebugPass,
    HintBasedSymShapeEvalPass,
    MemoryPlanningPass,
    propagate_dynamic_shape,
    RemoveNoopPass,
    ReplaceSymSizeOpPass,
    ToOutVarPass,
)
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
from executorch.exir.passes.debug_handle_generator_pass import (
    DebugHandleGeneratorPass,
    generate_missing_debug_handles,
)
from executorch.exir.passes.insert_write_back_for_buffers_pass import (
    insert_write_back_for_buffers_pass,
)

from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass
from executorch.exir.passes.normalize_view_copy_base_pass import (
    NormalizeViewCopyBasePass,
)
from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators
from executorch.exir.passes.replace_edge_with_backend_pass import EdgeToBackendOpsPass
from executorch.exir.passes.replace_view_copy_with_view_pass import (
    ReplaceViewCopyWithViewPass,
)
from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from executorch.exir.passes.sym_to_tensor_pass import SymToTensorPass
from executorch.exir.program._program import lift_constant_tensor_pass
from executorch.exir.schema import TensorShapeDynamism
from executorch.exir.tensor import TensorSpec
from executorch.exir.tests.common import register_additional_test_aten_ops
from executorch.exir.tests.control_flow_models import FTCondDeadCode, FTMapBasic
from executorch.exir.tests.models import MLP, Mul
from functorch.experimental import control_flow

from torch import nn

from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.export import export
from torch.export.graph_signature import InputKind, InputSpec, TensorArgument
from torch.fx import GraphModule, subgraph_rewriter
from torch.fx.experimental.proxy_tensor import make_fx
from torch.library import impl, Library
from torch.testing import FileCheck
from torch.utils import _pytree as pytree


# pyre-ignore
def collect_ops(gm: torch.fx.GraphModule):
    """
    Collect all targets for call_function nodes from the graph module recursively.
    """
    ops = set()
    for subgm in gm.modules():
        if not isinstance(subgm, torch.fx.GraphModule):
            continue
        for node in subgm.graph.nodes:
            if node.op == "call_function":
                ops.add(node.target)
    return ops


lib = Library("DO_NOT_USE_TEST_ONLY", "DEF")

lib.define("foo(Tensor self) -> (Tensor, Tensor)")
lib.define("add_relu(Tensor self, Tensor other) -> Tensor")


@impl(lib, "foo", "CompositeExplicitAutograd")
def foo(a: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    return a + 1, None


lib.define(
    "foo.out(Tensor self, *, Tensor(a!) out1, Tensor(b!) out2) -> (Tensor(a!), Tensor(b!))"
)


@impl(lib, "foo.out", "CompositeExplicitAutograd")
def foo_out(
    a: torch.Tensor, out1: torch.Tensor, out2: torch.Tensor
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    return a + 1, None


class TestPasses(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        register_additional_test_aten_ops()

    def test_remove_mixed_type_operators(self) -> None:
        class Add(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return (x + y) + x

        add = Add()

        int_tensor = torch.tensor([[1, 2, 3]])
        float_tensor = torch.tensor([[1.0, 2.0, 3.0]])
        edge_prog = to_edge(
            export(
                add,
                (int_tensor, float_tensor),
            )
        )

        new_prog = edge_prog.transform([RemoveMixedTypeOperators()])
        new_graph_module = new_prog.exported_program().graph_module
        self.assertIsNotNone(new_graph_module)

        add_count = 0

        for node in new_graph_module.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == exir_ops.edge.aten.add.Tensor
            ):
                add_count += 1
                node_args = node.args
                for arg in node_args:
                    self.assertEqual(arg.meta["val"].dtype, torch.float)

        self.assertEqual(add_count, 2)

        double_tensor = torch.tensor([[1.0, 2.0, 3.0]])
        double_tensor = double_tensor.to(torch.double)

        double_prog = to_edge(export(add, (int_tensor, double_tensor)))

        double_prog.transform([RemoveMixedTypeOperators()])
        new_graph_module_double = double_prog.exported_program().graph_module
        self.assertIsNotNone(new_graph_module_double)

        add_count_double = 0

        for node in new_graph_module_double.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == exir_ops.edge.aten.add.Tensor
            ):
                add_count_double += 1
                node_args = node.args
                for arg in node_args:
                    self.assertEqual(arg.meta["val"].dtype, torch.double)

        self.assertEqual(add_count_double, 2)

        class Mult(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x * y

        mult = Mult()

        float_tensor_vert = float_tensor.T
        mult_prog = to_edge(
            export(
                mult,
                (int_tensor, float_tensor_vert),
            )
        )

        # graph_module_mult.graph.print_tabular()

        mult_prog = mult_prog.transform([RemoveMixedTypeOperators()])
        new_graph_module_mult = mult_prog.exported_program().graph_module
        self.assertIsNotNone(new_graph_module_mult)

        mult_count = 0

        for node in new_graph_module_mult.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == exir_ops.edge.aten.mul.Tensor
            ):
                mult_count += 1
                node_args = node.args
                for arg in node_args:
                    self.assertEqual(arg.meta["val"].dtype, torch.float)

        self.assertEqual(mult_count, 1)

    def test_remove_noop_pass(self) -> None:
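        # A to() that does not change the dtype is a no-op and is expected to
        # be dropped by RemoveNoopPass.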
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x.to(dtype=torch.float32)

        foo = Foo()

        # Turn off functionalization so that we can get the actual to.dtype op
        edge_prog = to_edge(
            export(
                foo,
                (torch.ones(1, dtype=torch.float32),),
            )
        )
        edge_prog = edge_prog.transform([RemoveNoopPass()])
        self.assertIsNotNone(edge_prog.exported_program().graph_module)
        new_graph_module = edge_prog.exported_program().graph_module
        for node in new_graph_module.graph.nodes:
            if node.op == "call_function":
                self.assertNotEqual(node.target, torch.ops.aten.to.dtype)

    def test_redundant_slice_copy_removal(self) -> None:
        class FooWithNoSlice(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x[:, :, :]

        foo_with_no_slice = FooWithNoSlice()

        class FooWithOneSlice(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x[:1, :, :]

        foo_with_one_slice = FooWithOneSlice()

        class FooWithAllSlices(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x[:1, :2, 2:4]

        foo_with_all_slices = FooWithAllSlices()

        # Turn off functionalization so that we can get the actual to.dtype op
        x = torch.ones((3, 8, 8))
        prog = to_edge(
            export(
                foo_with_no_slice,
                (x,),
            )
        )
        prog = prog.transform([RemoveNoopPass()])
        new_graph_module = prog.exported_program().graph_module
        FileCheck().check_count(
            "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor", 0, exactly=True
        ).run(new_graph_module.code)

        prog = to_edge(
            export(
                foo_with_one_slice,
                (x,),
            )
        )
        prog = prog.transform([RemoveNoopPass()])
        new_graph_module = prog.exported_program().graph_module
        FileCheck().check_count(
            "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor", 1, exactly=True
        ).run(new_graph_module.code)

        prog = to_edge(
            export(
                foo_with_all_slices,
                (x,),
            )
        )
        prog = prog.transform([RemoveNoopPass()])
        new_graph_module = prog.exported_program().graph_module
        FileCheck().check_count(
            "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor", 3, exactly=True
        ).run(new_graph_module.code)

    def test_compile_to_edge(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x * 2

        f = Foo()

        x = (torch.randn(2, 3),)

        to_edge(
            export(
                f,
                x,
            )
        ).exported_program().graph_module
        # TODO(angelayi): Add a utility function that verifies a model is in
        # the edge dialect

    def test_to_out_variant_none_output(self) -> None:
        class CompositeModel(torch.nn.Module):
            def __init__(self, _weight):
                super().__init__()
                self.weight = _weight
                self.lstm = torch.nn.LSTM(
                    input_size=32,
                    hidden_size=32,
                    num_layers=1,
                )

            def forward(self, x_raw, h, c):
                output, (hn, cn) = self.lstm(x_raw, (h, c))
                return output

        # Prepare input and trace it
        input_x = torch.ones([1, 32])
        input_h = torch.ones([1, 32])
        input_c = torch.ones([1, 32])
        inputs = (input_x, input_h, input_c)

        composite_m = CompositeModel(3)

        edge_prog = to_edge(
            export(
                composite_m,
                inputs,
            ),
            # torch._ops.aten.t.default
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
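
        # SpecPropPass attaches a TensorSpec to each node so that ToOutVarPass
        # can rewrite functional ops into out-variants backed by memory.alloc.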
        new_prog = edge_prog.transform([SpecPropPass()])

        new_gm_res = ToOutVarPass()(new_prog.exported_program().graph_module)
        self.assertIsNotNone(new_gm_res)
        new_gm = new_gm_res.graph_module
        for node in new_gm.graph.nodes:
            if node.op == "call_function" and node.target in [
                torch.ops.DO_NOT_USE_TEST_ONLY.foo.out,
                torch.ops.my_awesome_3rdparty_ns.awesome_op.out,
            ]:
                self.assertEqual(len(node.kwargs), 2)
                out1_node = node.kwargs["out1"]
                self.assertEqual(out1_node.op, "call_function")
                self.assertIs(out1_node.target, memory.alloc)
                self.assertIs(node.kwargs["out2"], None)

        new_gm_res = MemoryPlanningPass()(new_gm)
        self.assertIsNotNone(new_gm_res)
        new_gm = new_gm_res.graph_module
        new_prog.exported_program().graph_module.graph = new_gm.graph
        emit_program(new_prog.exported_program())

    def test_to_out_variant_singleton_tensor_list(self) -> None:
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.split(x, 10)

            def get_random_inputs(self):
                return (torch.randn(10),)

        model = MyModel()
        inputs = model.get_random_inputs()
        prog = to_edge(
            export(
                model,
                inputs,
            ),
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )  # TODO(larryliu): fix split_copy
        new_gm_res = ToOutVarPass()(prog.exported_program().graph_module)
        self.assertIsNotNone(new_gm_res)
        new_gm = new_gm_res.graph_module

        for nd in new_gm.graph.nodes:
            if nd.target is exir_ops.edge.aten.split_copy.Tensor_out:
                break

        val = nd.meta["val"]

        # We must return a spec which is a list with a single TensorSpec item.
        # Returning the TensorSpec item directly makes downstream getitem ops fail.
        self.assertTrue(isinstance(val, (tuple, list)))
        self.assertEqual(1, len(val))

    def test_to_out_variant_multiple_out(self) -> None:
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.topk(x, 5)

            def get_random_inputs(self):
                return (torch.randn(10),)

        model = MyModel()
        inputs = model.get_random_inputs()
        prog = to_edge(
            export(
                model,
                inputs,
            ),
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )  # TODO(larryliu): fix topk
        new_gm_res = ToOutVarPass()(prog.exported_program().graph_module)
        self.assertIsNotNone(new_gm_res)
        new_gm = new_gm_res.graph_module

        for nd in new_gm.graph.nodes:
            if nd.target is torch.ops.aten.topk.values:
                break

        val = nd.meta["val"]

        # We must return a spec which is a list of TensorSpec items.
        # Returning a TensorSpec directly makes downstream getitem ops fail.
        self.assertTrue(isinstance(val, (tuple, list)))
        self.assertEqual(2, len(val))

    def test_to_out_variant_to_copy(self) -> None:
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x.to(torch.int32)

        model = Module()

        inputs = torch.tensor(1.0, dtype=torch.float)
        model_res = model(inputs)

        edge_dialect = to_edge(
            export(
                model,
                (inputs,),
            )
        )
        edge_res = edge_dialect.exported_program().module()(inputs)
        self.assertTrue(torch.allclose(model_res, edge_res))

    def test_export_pass(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
                y = torch.cat([x, x])
                return torch.ops.aten.tensor_split.sections(y, 2)

        f = Foo()

        class NullPass(ExportPass):
            pass

        prog = to_edge(
            export(
                f,
                (torch.ones(3, 2),),
            ),
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )  # TODO(larryliu): fix cat
        new_prog = prog.transform([NullPass()])
        new_nodes = new_prog.exported_program().graph_module.graph.nodes
        for node in new_nodes:
            if node.op != "call_function":
                continue
            self.assertTrue(hasattr(node, "stack_trace"))
            self.assertIsNotNone(node.stack_trace)

        old_nodes = prog.exported_program().graph_module.graph.nodes
        self.assertEqual(len(new_nodes), len(old_nodes))
        for new_node, old_node in zip(new_nodes, old_nodes):
            self.assertEqual(new_node.op, old_node.op)
            self.assertEqual(new_node.target, old_node.target)

    def test_export_pass_pt2(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
                y = torch.cat([x, x])
                return torch.ops.aten.tensor_split.sections(y, 2)

        f = Foo()

        class NullPass(ExportPass):
            pass

        prog = to_edge(
            export(
                f,
                (torch.ones(3, 2),),
            ),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        new_prog = prog.transform([NullPass()])
        new_nodes = new_prog.exported_program().graph_module.graph.nodes
        for node in new_nodes:
            if node.op != "call_function":
                continue
            self.assertTrue(hasattr(node, "stack_trace"))
            self.assertIsNotNone(node.stack_trace)

        old_nodes = prog.exported_program().graph_module.graph.nodes
        self.assertEqual(len(new_nodes), len(old_nodes))
        for new_node, old_node in zip(new_nodes, old_nodes):
            self.assertEqual(new_node.op, old_node.op)
            self.assertEqual(new_node.target, old_node.target)

    def test_export_scalar_to_tensor_pass(self) -> None:
        class Mul(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x * 3.14

        mul = Mul()

        expo_prog = to_edge(export(mul, (torch.ones(1),)))
        new_prog = expo_prog.transform([ScalarToTensorPass()])
        self.assertIsNotNone(new_prog.exported_program().graph_module)
        new_graph_module = new_prog.exported_program().graph_module

        inp = torch.zeros(1)
        self.assertTrue(
            torch.allclose(
                expo_prog.exported_program().module()(inp),
                new_prog.exported_program().module()(inp),
            )
        )
        for node in new_graph_module.graph.nodes:
            if node.op == "call_function":
                for arg in node.args + tuple(node.kwargs.values()):
                    self.assertFalse(isinstance(arg, float))

    def test_remove_mixed_types_symfloats(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
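                # Computing the output size from x.shape exercises the sym-size
                # and mixed-type handling targeted by the passes below.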
                return torch.nn.functional.interpolate(
                    x,
                    size=(x.shape[2] * 2, x.shape[3] * 3),
                    mode="bilinear",
                    align_corners=False,
                    antialias=False,
                )

        f = Foo()

        example_inputs = (torch.randn(2, 3, 4, 5),)

        gm = to_edge(
            export(
                f,
                example_inputs,
            )
        )
        new_gm = gm.transform(
            [ReplaceSymSizeOpPass(), ScalarToTensorPass(), RemoveMixedTypeOperators()]
        )
        self.assertIsNotNone(new_gm.exported_program().graph_module)

        self.assertTrue(
            torch.allclose(
                gm.exported_program().module()(*example_inputs),
                new_gm.exported_program().module()(*example_inputs),
            )
        )

    def test_spec_prop_pass(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x + x

        f = Foo()

        gm = (
            to_edge(
                export(
                    f,
                    (torch.ones(3, 2),),
                )
            )
            .exported_program()
            .graph_module
        )
        new_gm = SpecPropPass()(gm)
        self.assertIsNotNone(new_gm)
        new_nodes = new_gm.graph_module.graph.nodes
        counter = 0
        for node in new_nodes:
            if node.op != "output":
                continue
            counter += 1
            self.assertIs(node.meta["spec"][0], node.args[0][0].meta["spec"])

        self.assertEqual(counter, 1)

    def test_spec_prop_pass_tuple_output(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
                return (x + x,)

        f = Foo()

        gm = (
            to_edge(
                export(
                    f,
                    (torch.ones(3, 2),),
                )
            )
            .exported_program()
            .graph_module
        )
        new_gm = SpecPropPass()(gm)
        self.assertIsNotNone(new_gm)
        new_nodes = new_gm.graph_module.graph.nodes
        counter = 0
        for node in new_nodes:
            if node.op != "output":
                continue
            counter += 1
            self.assertIs(node.meta["spec"][0], node.args[0][0].meta["spec"])

        self.assertEqual(counter, 1)

    def test_compile_fix_broken_ops(self) -> None:
        # When passing an input of more than 4 dimensions to Linear,
        # aten._unsafe_view is used under the hood
        x = torch.randn([2, 3, 4, 5])
        model: torch.nn.Linear = torch.nn.Linear(5, 5)

        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.model = model

            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                return self.model(inp)

        f = Foo()

        # ReplaceBrokenOpsWithFunctionalOpsPass is used in to_edge()
        prog = to_edge(
            export(
                f,
                (x,),
            ),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        gm = prog.exported_program().graph_module
        count_after = 0
        for node in gm.graph.nodes:
            if node.target == torch.ops.aten._unsafe_view.default:
                count_after += 1
        self.assertEqual(count_after, 0)
        self.assertTrue(torch.allclose(prog.exported_program().module()(x), f(x)))

    def test_convert_symb_ops(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.add(x, x.shape[0] - 1)

        f = Foo()

        # Mark the 0th dimension of x as dynamic with a max value of 3.
        dim_x = torch.export.Dim("dim_x", max=3)

        prog = to_edge(
            export(
                f,
                (torch.ones(3, 2),),
                dynamic_shapes={"x": {0: dim_x}},
            ),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        new_prog = prog.transform([EdgeToBackendOpsPass()])
        self.assertIsNotNone(new_prog.exported_program().graph_module)
        converted_gm = new_prog.exported_program().graph_module

        FileCheck().check("torch.ops.aten.sym_size.int").check(
            "executorch_exir_dialects_backend__ops_executorch_prim_sub_Scalar"
        ).check_not("operator.sub").run(converted_gm.code)

    def test_alloc_node_spec(self) -> None:
        """
        Make sure every memory.alloc node, including those in sub graph modules,
        has a TensorSpec.
        """
        eager_model = FTMapBasic()
        inputs = eager_model.get_random_inputs()
        prog = to_edge(
            export(
                eager_model,
                inputs,
            ),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        passes = [
            SpecPropPass(),
            HintBasedSymShapeEvalPass(),
        ]
        new_prog = prog.transform(passes)

        new_gm_res = ToOutVarPass()(new_prog.exported_program().graph_module)
        self.assertIsNotNone(new_gm_res)
        new_gm = new_gm_res.graph_module

        new_gm_res = MemoryPlanningPass()(new_gm)
        self.assertIsNotNone(new_gm_res)
        new_gm = new_gm_res.graph_module

        alloc_nodes = []
        for subgm in new_gm.modules():
            if isinstance(subgm, torch.fx.GraphModule):
                for node in subgm.graph.nodes:
                    if node.target == memory.alloc:
                        alloc_nodes.append(node)
        self.assertTrue(len(alloc_nodes) > 0)
        for node in alloc_nodes:
            self.assertTrue(isinstance(node.meta.get("spec", None), TensorSpec))

    def test_debug_pass_file_log(self) -> None:
        eager_model = Mul()
        inputs = eager_model.get_random_inputs()

        # the debug pass works with a graph generated with make_fx directly
        gm = make_fx(eager_model)(*inputs)

        try:
            fd, path = tempfile.mkstemp()

            print(f"Write DebugPass output to {path}")
            DebugPass(log_filename=path)(gm)
            with open(path) as f:
                file_cont = f.read()
            self.assertTrue("torch.ops.aten.mul" in file_cont)
        finally:
            os.close(fd)
            os.unlink(path)

    def test_dce_recursive(self) -> None:
        eager_model = FTCondDeadCode()
        inputs = eager_model.get_random_inputs()
        gm = export(
            eager_model,
            inputs,
        ).graph_module

        self.assertTrue(torch.ops.aten.sub.Tensor in collect_ops(gm))
        dead_code_elimination_pass(gm)
        gm.print_readable()
        self.assertFalse(torch.ops.aten.sub.Tensor in collect_ops(gm))

    def test_propagate_dynamic_shape(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                y = x
                for _ in range(2):
                    y = y + x
                return y

        f = Foo()

        prog = to_edge(
            export(
                f,
                (torch.rand(5),),
            ),
            # missing dispatch key
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        ).transform(propagate_dynamic_shape())
        gm = prog.exported_program().graph_module
        nspec = 0
        for n in gm.graph.nodes:
            for spec in pytree.tree_flatten(n.meta["spec"])[0]:
                self.assertTrue(all(isinstance(x, int) for x in spec.shape))
                nspec += 1

        self.assertTrue(nspec > 0)

    def test_losing_symbolic_info(self) -> None:
        """
        Guard against an issue where, after calling ConvertSymbolicOpsPass(),
        a subsequent ExportPass would lose symbolic shape information.
        """

        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.add(x, x.shape[0] - 1)

        f = Foo()

        dim_x = torch.export.Dim("dim_x", max=3)
        prog = to_edge(
            export(
                f,
                (torch.ones(3, 2),),
                dynamic_shapes={"x": {0: dim_x}},
            ),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )

        new_prog = prog.transform([EdgeToBackendOpsPass()])
        gm = new_prog.exported_program().graph_module
        gm.print_readable()
        *_, ones, out = gm.graph.nodes
        print(f"Before ExportPass: {ones.format_node()}")
        self.assertTrue(isinstance(ones.meta["val"].shape[0], torch.SymInt))
        self.assertTrue(len(ones.meta["val"].shape[0].node.expr.free_symbols) > 0)

        new_prog = new_prog.transform([ExportPass()])
        gm = new_prog.exported_program().graph_module
        gm.print_readable()
        *_, ones, out = gm.graph.nodes
        print(f"After ExportPass: {ones.format_node()}")
        self.assertTrue(isinstance(ones.meta["val"].shape[0], torch.SymInt))
        self.assertTrue(len(ones.meta["val"].shape[0].node.expr.free_symbols) > 0)

    def test_to_edge_with_edge_ops(self) -> None:
        x = torch.randn([2, 3, 4, 5])

        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x + x

        f = Foo()

        gm = (
            to_edge(
                export(
                    f,
                    (x,),
                )
            )
            .exported_program()
            .graph_module
        )
        for node in gm.graph.nodes:
            if node.op == "call_function":
                self.assertEqual(type(node.target), EdgeOpOverload)

    # TODO(T143084047)
    @unittest.expectedFailure
    def test_backend_fused_op_retraceable(self) -> None:
        """This test makes sure the backend op is still retraceable, with the pattern being registered as a kernel."""

        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                z = x + y
                return torch.ops.aten.relu.default(z)

        f = Foo()

        gm = export(
            f,
            (
                torch.randn(2, 2),
                torch.randn(2, 2),
            ),
        )
        # should look like:
        # graph():
        # %ph_0 : [#users=1] = placeholder[target=ph_0]
        # %ph_1 : [#users=1] = placeholder[target=ph_1]
        # %add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, %ph_1), kwargs = {})
        # %relu_default : [#users=1] = call_function[target=torch.ops.aten.relu.default](args = (%add_tensor,), kwargs = {})
        # return [relu_default]
        FileCheck().check("torch.ops.aten.add.Tensor").check(
            "torch.ops.aten.relu.default"
        ).run(gm.graph_module.code)

        class AddReluFusionPass(ExportPass):
            def call(self, graph_module: GraphModule) -> PassResult:
                # The decorator registers this pattern as a CompositeExplicitAutograd
                # kernel, since no kernel was registered for it before.
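                # bind_pattern_to_op also ties the pattern to the
                # DO_NOT_USE_TEST_ONLY::add_relu schema defined at module scope.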
                @bind_pattern_to_op(lib, "add_relu")
                def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                    z = torch.ops.aten.add.Tensor(x, y)
                    out = torch.ops.aten.relu.default(z)
                    return out

                def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                    return ops.backend.DO_NOT_USE_TEST_ONLY.add_relu.default(x, y)

                subgraph_rewriter.replace_pattern(graph_module, pattern, replacement)
                return PassResult(graph_module, True)

        # TODO: larryliu this pass needs to be in to_executorch()
        class OpReplacePass(ExportPass):
            def call_operator(self, op, args, kwargs, meta):
                if op == torch.ops.DO_NOT_USE_TEST_ONLY.add_relu.default:
                    return super().call_operator(
                        ops.backend.DO_NOT_USE_TEST_ONLY.add_relu.default,
                        args,
                        kwargs,
                        meta,
                    )
                return super().call_operator(op, args, kwargs, meta)

        gm_lowered = to_edge(
            gm,
            compile_config=EdgeCompileConfig(
                _check_ir_validity=False,
            ),
        ).transform([AddReluFusionPass(), OpReplacePass()])

        FileCheck().check(
            "executorch_exir_dialects_backend__ops_DO_NOT_USE_TEST_ONLY_add_relu_default"
        ).run(gm_lowered.exported_program().graph_module.code)
        # lowered module:
        # def forward(self, ph_0, ph_1):
        #     do_not_use_test_only_add_relu_default = executorch_exir_dialects_backend__ops_DO_NOT_USE_TEST_ONLY_add_relu_default(ph_0, ph_1); ph_0 = ph_1 = None
        #     return [do_not_use_test_only_add_relu_default]

        # Retrace:
        # If this were not a backend op, retracing would error out because no
        # CPU/CompositeExplicitAutograd kernel is registered.
        gm_retraced = to_edge(
            export(
                gm_lowered.exported_program().module(),
                (
                    torch.randn(2, 2),
                    torch.randn(2, 2),
                ),
            )
        )
        # Retraceable: the graph is promoted back to the ATen dialect, so add
        # and relu show up again, as expected.
        FileCheck().check("torch.ops.aten.add.Tensor").check(
            "torch.ops.aten.relu.default"
        ).run(gm_retraced.exported_program().graph_module.code)

    def test_debug_handle_generator_pass(self) -> None:
        eager_model = MLP(2, output_size=4)
        inputs = eager_model.get_random_inputs()

        graph_module = (
            to_edge(
                export(
                    eager_model,
                    inputs,
                )
            )
            .exported_program()
            .graph_module
        )
        for node in graph_module.graph.nodes:
            self.assertIn("debug_handle", node.meta)
        ScalarToTensorPass()(graph_module)
        for node in graph_module.graph.nodes:
            self.assertIn("debug_handle", node.meta)

    def test_generate_missing_debug_handles(self) -> None:
        eager_model = MLP(2, output_size=4)
        inputs = eager_model.get_random_inputs()

        ep = to_edge(
            export(
                eager_model,
                inputs,
            )
        ).exported_program()

        list(ep.graph.nodes)[0].meta.pop("debug_handle")
        self.assertTrue(list(ep.graph.nodes)[0].meta.get("debug_handle") is None)
        generate_missing_debug_handles(ep)
        self.assertTrue(list(ep.graph.nodes)[0].meta.get("debug_handle") is not None)

    def test_debug_handle_generator_pass_with_control_flow(self) -> None:
        def true_nested(y: torch.Tensor) -> torch.Tensor:
            y = y + y
            y = torch.mm(y, y)
            return y

        def false_nested(y: torch.Tensor) -> torch.Tensor:
            return torch.mm(y, y)

        def true_fn(x: torch.Tensor, pred2: torch.Tensor) -> torch.Tensor:
            z = control_flow.cond(pred2, true_nested, false_nested, [x])
            return x + z

        def false_fn(x: torch.Tensor, _) -> torch.Tensor:
            return x.cos()

        def map_fn(
            x: torch.Tensor, pred1: torch.Tensor, pred2: torch.Tensor, y: torch.Tensor
        ) -> torch.Tensor:
            x = x.cos()
            y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2])
            x = x + y
            return x.sin()

        class Foo(torch.nn.Module):
            def forward(
                self,
                xs: torch.Tensor,
                pred1: torch.Tensor,
                pred2: torch.Tensor,
                y: torch.Tensor,
            ) -> torch.Tensor:
                y = torch.mm(y, y)
                return control_flow.map(map_fn, xs, pred1, pred2, y)

        f = Foo()

        inputs = (
            torch.ones(2, 2),
            torch.tensor([False]),
            torch.tensor([False]),
            torch.ones(2, 2),
        )

        ep = to_edge(
            export(
                f,
                inputs,
            )
        ).exported_program()
        graph_module = ep.graph_module

        def check_debug_handle_metadata(graph_module: torch.fx.GraphModule) -> None:
            queue = [graph_module]
            while queue:
                current_graph_module = queue.pop(0)
                for node in current_graph_module.graph.nodes:
                    self.assertIn("debug_handle", node.meta)
                control_flow_submodules = [
                    submodule
                    for _, submodule, _ in get_control_flow_submodules(
                        current_graph_module
                    )
                ]
                queue.extend(control_flow_submodules)

        DebugHandleGeneratorPass()(graph_module)
        check_debug_handle_metadata(graph_module)
        generate_missing_debug_handles(ep)

        # Check that debug handles are still preserved after ScalarToTensorPass
        ScalarToTensorPass()(graph_module)
        check_debug_handle_metadata(graph_module)

    def test_symint_conversion(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.add(x, x.shape[0] - 1)

        f = Foo()

        dim_x = torch.export.Dim("dim_x", max=3)
        prog = to_edge(
            export(
                f,
                (torch.ones(3, 2),),
                dynamic_shapes={"x": {0: dim_x}},
            ),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        prog = prog.transform([SymToTensorPass()])

        FileCheck().check("torch.ops.aten.scalar_tensor.default").run(
            prog.exported_program().graph_module.code
        )
        self.assertTrue(
            torch.allclose(
                f(torch.ones(3, 2)), prog.exported_program().module()(torch.ones(3, 2))
            )
        )
        self.assertTrue(
            torch.allclose(
                f(torch.zeros(3, 2)),
                prog.exported_program().module()(torch.zeros(3, 2)),
            )
        )

    def test_remove_assert_pass(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                assert x.shape[0] == 5
                return x * x

        f = Foo()

        gm = to_edge(
            export(
                f,
                (torch.randn(5),),
            ),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        new_gm = gm.transform([RemoveGraphAssertsPass()])
        num_asserts = [
            node
            for node in new_gm.exported_program().graph.nodes
            if node.op == "call_function"
            and node.target == torch.ops.aten._assert_async.msg
        ]
        self.assertEqual(len(num_asserts), 0)

    def test_arange(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.ones(2)

            def forward(self, x):
                return torch.arange(start=0, end=2) + x

        _ = to_edge(
            export(
                M(),
                (torch.randn(2),),
            )
        ).to_executorch()

    def test_replace_slice(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.ones(10)

            def forward(self, x):
                return self.a[:2] + x

        gm = (
            to_edge(
                export(
                    M(),
                    (torch.randn(2),),
                )
            )
            .exported_program()
            .graph_module
        )
        FileCheck().check(
            "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor"
        ).run(gm.code)

    def test_constant_prop_pass_for_add(self) -> None:
        class Add(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x + 3

        add = Add()

        edge = to_edge(
            export(add, (torch.ones(1),)),
            compile_config=EdgeCompileConfig(_skip_dim_order=False),
        )
        edge = edge.transform([ScalarToTensorPass(), RemoveMixedTypeOperators()])
        exported_program = lift_constant_tensor_pass(edge.exported_program())

        # Check that there is a lifted tensor followed by a to_copy node
        FileCheck().check("_lifted_tensor_constant0").check(
            "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
        ).run(exported_program.graph_module.code)

        new_ep = constant_prop_pass(exported_program)

        # Check that the (_lifted_tensor_constant + to_copy) pair is replaced by
        # the propagated tensor
        FileCheck().check_not("_lifted_tensor_constant").check(
            "_prop_tensor_constant0"
        ).check_not(
            "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
        ).run(
            new_ep.graph_module.code
        )

    def test_constant_prop_pass_for_parameter(self) -> None:
        def count_additions(gm: torch.fx.GraphModule) -> int:
            return sum(
                (node.target == torch.ops.aten.add.Tensor) for node in gm.graph.nodes
            )

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.nn.Parameter(torch.ones(1, 2, 3))

            def forward(self, x):
                b = self.a + self.a
                c = torch.cat([self.a, b])
                return (c + c) + x

        aten = export(
            M(),
            (torch.zeros(2, 2, 3),),
        )
        self.assertEqual(count_additions(aten.graph_module), 3)
        new_ep = constant_prop_pass(aten)
        self.assertEqual(count_additions(new_ep.graph_module), 1)

    def test_constant_prop_pass_graph_signature(self) -> None:
        def count_additions(gm: torch.fx.GraphModule) -> int:
            return sum(
                (node.target == torch.ops.aten.add.Tensor) for node in gm.graph.nodes
            )

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.nn.Parameter(torch.ones(1, 2, 3))

            def forward(self, x):
                b = self.a + self.a
                c = torch.cat([self.a, b])
                return (c + c) + x

        aten = export(
            M(),
            (torch.zeros(2, 2, 3),),
        )
        # Input signature will have two entries:
        # (1) parameter `a` and (2) user input `x`.
        self.assertEqual(len(aten.graph_signature.input_specs), 2)
        new_ep = constant_prop_pass(aten)
        # Check that there are exactly two input specs: (1) the propagated
        # constant and (2) the user input.
        self.assertEqual(
            new_ep.graph_signature.input_specs,
            [
                InputSpec(
                    kind=InputKind.CONSTANT_TENSOR,
                    arg=TensorArgument(name="_prop_tensor_constant0"),
                    target="_prop_tensor_constant0",
                    persistent=True,
                ),
                # User input graph signature.
                aten.graph_signature.input_specs[-1],
            ],
        )

    def test_constant_prop_pass_for_parameter_slice(self) -> None:
        def count_slice(gm: torch.fx.GraphModule) -> int:
            return sum(
                (node.target == torch.ops.aten.slice_copy.Tensor)
                for node in gm.graph.nodes
            )

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.nn.Parameter(torch.ones(3, 2, 2))

            def forward(self, x):
                # Create a slice of shape (1, 2, 2)
                slice_tensor = torch.slice_copy(self.a, dim=0, start=0, end=1)
                return torch.cat([x, slice_tensor])

        aten = export(
            M(),
            (torch.zeros(2, 2, 2),),
        )
        self.assertIn("a", aten.state_dict)
        self.assertEqual(count_slice(aten.graph_module), 1)

        new_ep = constant_prop_pass(aten)
        # Check there is a propagated tensor.
        FileCheck().check("_prop_tensor_constant0").run(aten.graph_module.code)
        self.assertIn("_prop_tensor_constant0", new_ep.constants)
        self.assertNotIn("a", new_ep.state_dict)
        # No more slice copy.
        self.assertEqual(count_slice(new_ep.graph_module), 0)

    def test_constant_prop_pass_no_propagate(self) -> None:
        def count_placeholder(gm: torch.fx.GraphModule) -> int:
            return sum((node.op == "placeholder") for node in gm.graph.nodes)

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.nn.Parameter(torch.ones(3, 2, 4))

            def forward(self, x, y):
                # y is unused.
                return x + self.a

        aten = export(
            M(),
            (torch.zeros(3, 2, 4), torch.zeros(3, 2, 4)),
        )
        self.assertIn("a", aten.state_dict)
        self.assertEqual(count_placeholder(aten.graph_module), 3)

        new_ep = constant_prop_pass(aten)
        # Check there is no propagated tensor.
        FileCheck().check("p_a").check("x").check("y").run(aten.graph_module.code)
        self.assertNotIn("_prop_tensor_constant0", new_ep.constants)
        self.assertIn("a", new_ep.state_dict)
        self.assertEqual(count_placeholder(new_ep.graph_module), 3)

    def test_constant_prop_pass_for_control_flow(self) -> None:
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)

            def t(self, val):
                return val + 1

            def f(self, val):
                return val - 1

            def true_fn(self, val):
                return self.linear(val) + self.t(val)

            def false_fn(self, val):
                return self.linear(val) - self.f(val)

            def forward(self, pred, x):
                return torch.ops.higher_order.cond(
                    pred, self.true_fn, self.false_fn, [x]
                )

        mod = Module()
        x = torch.randn([3, 3])
        pred = torch.tensor(x[0][0].item() < 0)
        edge = to_edge(
            export(mod, (pred, x)),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        error_msg = r"constant_prop_pass for control flow is not supported yet."

        # TODO(chenlai): enable constant prop pass for control flow
        with self.assertRaisesRegex(
            RuntimeError,
            error_msg,
        ):
            _ = constant_prop_pass(edge.exported_program())

    def test_mutable_buffers(self) -> None:
        def count_copies(gm: torch.fx.GraphModule) -> int:
            return sum(
                (node.target == torch.ops.aten.copy_.default) for node in gm.graph.nodes
            )

        class MutableStateModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("state", torch.zeros(1))

            def forward(self, x):
                y = x + self.state
                self.state.add_(1)
                return y

        model = to_edge(
            export(
                MutableStateModule(),
                (torch.zeros(1),),
            )
        )
        self.assertEqual(count_copies(model.exported_program().graph_module), 0)
        # Before
        # graph():
        # %arg0_1 : [num_users=2] = placeholder[target=arg0_1]
        # %_lifted_tensor_constant1 : [num_users=1] = placeholder[target=_lifted_tensor_constant1]
        # %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
        # %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg1_1, %arg0_1), kwargs = {})
        # %aten__to_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%_lifted_tensor_constant1,), kwargs = {dtype: torch.float32})
        # %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg0_1, %aten__to_copy_default), kwargs = {})
        # return (aten_add_tensor_1, aten_add_tensor)
        gm, _ = insert_write_back_for_buffers_pass(model.exported_program())

        # After
        # graph():
        # %arg0_1 : [num_users=3] = placeholder[target=arg0_1]
        # %_lifted_tensor_constant1 : [num_users=1] = placeholder[target=_lifted_tensor_constant1]
        # %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
        # %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg1_1, %arg0_1), kwargs = {})
        # %aten__to_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%_lifted_tensor_constant1,), kwargs = {dtype: torch.float32})
        # %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg0_1, %aten__to_copy_default), kwargs = {})
        # %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %aten_add_tensor_1), kwargs = {})
        # return (copy__default, aten_add_tensor)
        self.assertEqual(count_copies(gm), 1)

    def test_remove_quantized_op_noop_pass(self) -> None:
        class TestAddSliceNoop(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = x + x
                x = x + x[:]
                return x

        class TestAddSliceNotNoop(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = x + x
                x = x + x[:1]
                return x

        def count_dq_nodes(gm: torch.fx.GraphModule) -> int:
            return sum(
                (
                    node.target
                    in (
                        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
                        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
                    )
                )
                for node in gm.graph.nodes
            )

        def count_q_nodes(gm: torch.fx.GraphModule) -> int:
            return sum(
                (
                    node.target
                    in (
                        torch.ops.quantized_decomposed.quantize_per_tensor.default,
                        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
                    )
                )
                for node in gm.graph.nodes
            )

        def quantize_model(
            m_eager: torch.nn.Module, example_inputs: Tuple[torch.Tensor]
        ) -> Tuple[EdgeProgramManager, int, int]:
            # program capture
            m = torch.export.export_for_training(
                m_eager,
                example_inputs,
            ).module()

            quantizer = XNNPACKQuantizer()
            quantization_config = get_symmetric_quantization_config()
            quantizer.set_global(quantization_config)
            m = prepare_pt2e(m, quantizer)  # pyre-fixme[6]
            m = convert_pt2e(m, fold_quantize=True)
            ep = torch.export.export(m, example_inputs)
            dq_nodes_pre = count_dq_nodes(ep.graph_module)
            q_nodes_pre = count_q_nodes(ep.graph_module)
            edge = to_edge(
                ep, compile_config=EdgeCompileConfig(_check_ir_validity=False)
            )
            return edge, dq_nodes_pre, q_nodes_pre

        example_inputs = (torch.randn(9, 8),)
        model = TestAddSliceNoop()
        m_eager = model.eval()
        edge, dq_nodes_pre, q_nodes_pre = quantize_model(m_eager, example_inputs)

        dq_nodes_post = count_dq_nodes(edge.exported_program().graph_module)
        q_nodes_post = count_q_nodes(edge.exported_program().graph_module)
        # One dq and one q node around the slice copy should have been removed.
        self.assertEqual(dq_nodes_pre - dq_nodes_post, 1)
        self.assertEqual(q_nodes_pre - q_nodes_post, 1)

        # Check that the slice_copy is removed by the RemoveNoopPass.
        for node in edge.exported_program().graph_module.graph.nodes:
            self.assertFalse("slice" in str(node.target))

        model = TestAddSliceNotNoop()
        m_eager = model.eval()
        edge, dq_nodes_pre, q_nodes_pre = quantize_model(m_eager, example_inputs)

        dq_nodes_post = count_dq_nodes(edge.exported_program().graph_module)
        q_nodes_post = count_q_nodes(edge.exported_program().graph_module)
        # Neither the dq nor the q node around the slice copy should have been removed.
        self.assertEqual(dq_nodes_pre, dq_nodes_post)
        self.assertEqual(q_nodes_pre, q_nodes_post)

        # Check that the slice_copy is not removed by the RemoveNoopPass.
        self.assertTrue(
            any(
                "slice" in str(node.target)
                for node in edge.exported_program().graph_module.graph.nodes
            )
        )

    def test_dq_q_no_op_pass(self) -> None:
        class TestDqQ(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
                    x, 1.0, 0, -128, 127, torch.int8
                )
                q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
                    dq, 1.0, 0, -128, 127, torch.int8
                )
                return q

        model = TestDqQ()
        m_eager = model.eval()
        ep = torch.export.export(m_eager, (torch.randn(9, 8),))
        edge = to_edge(ep)
        # Check that the dq and q nodes are not touched by the RemoveNoopPass.
        self.assertTrue(
            any(
                "dequantize" in str(node.target)
                for node in edge.exported_program().graph_module.graph.nodes
            )
        )
        self.assertTrue(
            any(
                "quantize" in str(node.target)
                for node in edge.exported_program().graph_module.graph.nodes
            )
        )

    def test_dq_q_different_qparams(self) -> None:
        class TestDqQDifferentQParam(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
                    x, 1.0, 0, -128, 127, torch.int8
                )
                slice_copy_output = torch.ops.aten.slice_copy.Tensor(dq, 0, 0)
                q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
                    slice_copy_output, 1.0, 0, -127, 127, torch.int8
                )
                return q

        model = TestDqQDifferentQParam()
        m_eager = model.eval()
        ep = torch.export.export(m_eager, (torch.randn(9, 8),))
        edge = to_edge(ep)
        print(edge.exported_program().graph_module.graph)
        # Check that the dq and q nodes are not touched by the RemoveNoopPass.
        self.assertTrue(
            any(
                "dequantize" in str(node.target)
                for node in edge.exported_program().graph_module.graph.nodes
            )
        )
        self.assertTrue(
            any(
                "quantize" in str(node.target)
                for node in edge.exported_program().graph_module.graph.nodes
            )
        )
        self.assertFalse(
            any(
                "slice" in str(node.target)
                for node in edge.exported_program().graph_module.graph.nodes
            )
        )

    def test_normalize_view_copy_base_pass(self) -> None:

        class ViewChain(torch.nn.Module):
            def forward(self, x):
                x = torch.ops.aten.view_copy.default(x, [30, 1])
                x = torch.ops.aten.view_copy.default(x, [5, 6])
                x = torch.ops.aten.view_copy.default(x, [2, 15])
                x = torch.ops.aten.view_copy.default(x, [3, -1])
                return x

        def is_view_copy(node: torch.fx.Node) -> bool:
            return (
                node.op == "call_function"
                and node.target == torch.ops.aten.view_copy.default
            )

        gm = export(ViewChain(), (torch.ones(30),)).graph_module

        # Check before transformation
        n_view_copy_before = 0
        n_view_copy_bases_before = 0
        for node in gm.graph.nodes:
            if is_view_copy(node):
                n_view_copy_before += 1
                base = node.args[0]
                if is_view_copy(base):
                    n_view_copy_bases_before += 1

        self.assertEqual(n_view_copy_before, 4)
        self.assertEqual(n_view_copy_bases_before, 3)

        # Do transformation
        p = NormalizeViewCopyBasePass()
        gm_res = p(gm)
        assert gm_res is not None
        gm = gm_res.graph_module

        # Check after transformation
        n_view_copy_after = 0
        n_view_copy_bases_after = 0
        for node in gm.graph.nodes:
            if is_view_copy(node):
                n_view_copy_after += 1
                base = node.args[0]
                if is_view_copy(base):
                    n_view_copy_bases_after += 1

        self.assertEqual(n_view_copy_after, 4)
        self.assertEqual(n_view_copy_bases_after, 0)

    def test_replace_view_copy_with_view_pass(self) -> None:  # noqa: C901

        # Helper functions
        def is_view_copy(node: torch.fx.Node) -> bool:
            return (
                node.op == "call_function"
                and node.target == torch.ops.aten.view_copy.default
            )

        def is_memory_view(node: torch.fx.Node) -> bool:
            return node.op == "call_function" and node.target == memory.view

        # Test example set up
        class TestViewCopies(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.parameter = torch.nn.Parameter(torch.ones(1))

            def forward(self, x):
                o1 = torch.ops.aten.view_copy.default(x, [1])
                o2 = torch.ops.aten.view_copy.default(self.parameter, [1])
                # view_copys at the end of a function are not replaced, so add
                # a computation before the end of the graph.
                return torch.ops.aten.add.Tensor(o1, o2)

        ep = torch.export.export(
            TestViewCopies(),
            args=(torch.ones(1),),
        )
        for node in ep.graph.nodes:
            if node.op == "placeholder":
                node.meta["spec"] = TensorSpec.from_tensor(torch.empty(1))
                node.meta["spec"].shape_dynamism = TensorShapeDynamism.STATIC

        # Run tests
        gm = ep.graph_module

        # Check before transformation
        FileCheck().check_count(
            "torch.ops.aten.view_copy.default", 2, exactly=True
        ).run(gm.code)
        FileCheck().check_count("executorch_exir_memory_view", 0, exactly=True).run(
            gm.code
        )

        # Do transformation
        p = ReplaceViewCopyWithViewPass()
        gm_res = p(gm)
        assert gm_res is not None
        gm = gm_res.graph_module

        # Check after transformation
        FileCheck().check_count(
            "torch.ops.aten.view_copy.default", 0, exactly=True
        ).run(gm.code)
        FileCheck().check_count("executorch_exir_memory_view", 2, exactly=True).run(
            gm.code
        )

    def test_constant_prop_pass_for_no_grad(self) -> None:
        class LSTM(torch.nn.Module):
            def __init__(self, input_size, hidden_size, num_layers):
                super(LSTM, self).__init__()
                self.hidden_size = hidden_size
                self.num_layers = num_layers
                self.lstm = torch.nn.LSTM(
                    input_size, hidden_size, num_layers, batch_first=True
                )

            def forward(self, text_tokens):
                # input: (seq_len, batch, input_size)
                lstm_out, (new_hidden_state, new_cell_state) = self.lstm(
                    input=text_tokens, hx=None
                )
                return lstm_out

        lstm = LSTM(input_size=200, hidden_size=203, num_layers=2)
        example_input = (torch.rand(2, 10, 200),)

        aten = torch.export.export(lstm, example_input, strict=False)
        _EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(
            _check_ir_validity=True,
            _skip_dim_order=True,  # TODO(T189114319): Reuse dim order op after solving the iOS OSS issue
        )

        edge_manager: EdgeProgramManager = to_edge(
            aten,
            compile_config=_EDGE_COMPILE_CONFIG,
        )
        new_ep = constant_prop_pass(edge_manager._edge_programs["forward"])
        _ = copy.deepcopy(new_ep.module_call_graph)

    def test_dim_order_revert_pass(self) -> None:
        aten_op_str = "torch.ops.aten._to_copy.default"
        edge_aten_op_str = "executorch_exir_dialects_edge__ops_aten__to_copy_default"
        edge_dim_order_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"

        class Module(torch.nn.Module):
            """
            A simple module whose forward converts its input to channels_last and
            back to contiguous, twice. A contiguous input is assumed.
            """

            def __init__(self):
                super().__init__()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x.to(memory_format=torch.channels_last).to(
                    memory_format=torch.contiguous_format
                ) + x.to(memory_format=torch.channels_last).to(
                    memory_format=torch.contiguous_format
                )

            @staticmethod
            def to_copy_count():
                return 4

        def _do_checks(
            test_str: str, allowed: str, allowed_count: int, not_allowed_list: List[str]
        ) -> None:
            for not_allowed in not_allowed_list:
                FileCheck().check_count(allowed, allowed_count, exactly=True).check_not(
                    not_allowed
                ).run(test_str)

        m = Module()
        n = m.to_copy_count()
        input = torch.randn([2, 3, 4, 5]).to(memory_format=torch.contiguous_format)

        # 1. vanilla export, no edge ops
        ep = export(
            m,
            (input,),
        ).run_decompositions({})
        _do_checks(
            ep.graph_module.code,
            aten_op_str,
            n,
            [edge_aten_op_str, edge_dim_order_op_str],
        )

        # 2a. to edge without dim orders, we should see edge aten ops but not dim order ops
        edge_prog = to_edge(
            ep, compile_config=exir.EdgeCompileConfig(_skip_dim_order=True)
        )._edge_programs["forward"]
        _do_checks(
            edge_prog.graph_module.code,
            edge_aten_op_str,
            n,
            [aten_op_str, edge_dim_order_op_str],
        )

        # 3a. expect no change after the pass, we should see edge aten ops but not dim order ops
        new_res = DimOrderOpsRevertPass()(edge_prog.graph_module)
        self.assertIsNotNone(new_res)
        _do_checks(
            new_res.graph_module.code,
            edge_aten_op_str,
            n,
            [aten_op_str, edge_dim_order_op_str],
        )

        # 2b. now with dim order enabled, we should see edge dim order ops but not edge aten ops
        edge_prog_dim_order = to_edge(
            ep, compile_config=exir.EdgeCompileConfig(_skip_dim_order=False)
        )._edge_programs["forward"]
        _do_checks(
            edge_prog_dim_order.graph_module.code,
            edge_dim_order_op_str,
            n,
            [aten_op_str, edge_aten_op_str],
        )

        # 3b. expect edge aten ops after the pass, and we should no longer see the edge dim order ops
        new_res_dim_order = DimOrderOpsRevertPass()(edge_prog_dim_order.graph_module)
        self.assertIsNotNone(new_res_dim_order)
        _do_checks(
            new_res_dim_order.graph_module.code,
            edge_aten_op_str,
            n,
            [aten_op_str, edge_dim_order_op_str],
        )

        output_no_dim_order = new_res.graph_module(input)
        output_no_dim_order_revert = new_res_dim_order.graph_module(input)
        self.assertTrue(
            torch.allclose(output_no_dim_order[0], output_no_dim_order_revert[0])
        )