# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator
import unittest
from typing import Dict, List

import executorch.exir as exir
import torch
from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)

# import the backend implementation
from executorch.exir.backend.test.backend_with_compiler_demo import (
    BackendWithCompilerDemo,
)
from executorch.exir.backend.test.hta_partitioner_demo import (
    HTAPartitionerMultiplePatternsDemo,
    HTAPartitionerOnePatternDemo,
)
from executorch.exir.backend.test.op_partitioner_demo import (
    AddAttributePartitionerDemo,
    AddMulPartitionerDemo,
)
from executorch.exir.backend.test.qnn_backend_demo import QnnBackend

from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.graph_module import get_control_flow_submodules
from executorch.exir.lowered_backend_module import get_lowered_submodules
from executorch.exir.print_program import print_program
from executorch.exir.schema import (
    BackendDelegate,
    BackendDelegateDataReference,
    DataLocation,
    DelegateCall,
    Program,
)

from executorch.extension.pybindings.portable_lib import (  # @manual
    _load_for_executorch_from_buffer,
)
from executorch.extension.pytree import tree_flatten

from functorch.experimental import control_flow
from torch.ao.quantization import get_default_qconfig_mapping  # @manual
from torch.ao.quantization.backend_config.executorch import (
    get_executorch_backend_config,
)
from torch.ao.quantization.quantize_fx import (
    _convert_to_reference_decomposed_fx,
    prepare_fx,
)
from torch.export import ExportedProgram
from torch.testing import FileCheck


def vary_segments(test_method):
    """A decorator that calls the test method with `extract_delegate_segments` set to
    True and False.

    Decorated test methods must expect a boolean parameter named
    `extract_delegate_segments`, and they should pass that value to to_executorch() like:

        m.to_executorch(
            config=exir.ExecutorchBackendConfig(extract_delegate_segments=extract_delegate_segments)
        )

    This will cause the delegate data blobs to be extracted from the program and
    serialized as separate, freeable program segments. Backends should detect no
    difference at runtime.
    """

    def wrapper(self):
        for extract_delegate_segments in [False, True]:
            # subTest will create a different top-level test entry for each
            # value, whose full names have a suffix like
            # "(extract_delegate_segments=True)".
            with self.subTest(extract_delegate_segments=extract_delegate_segments):
                test_method(self, extract_delegate_segments=extract_delegate_segments)

    return wrapper


class TestBackends(unittest.TestCase):
    def check_delegate_input(
        self, delegate: LoweredBackendModule, input_len: int
    ) -> None:
        counter = 0
        for node in delegate.original_module.graph.nodes:
            if node.op == "placeholder":
                counter += 1
        self.assertEqual(counter, input_len)

    def check_backend_delegate(
        self,
        program: Program,
        delegate: BackendDelegate,
        expected_id: str,
        expected_processed: bytes,
    ) -> None:
        self.assertEqual(delegate.id, expected_id)
        processed: BackendDelegateDataReference = delegate.processed
        self.assertEqual(processed.location, DataLocation.INLINE)
        self.assertLess(processed.index, len(program.backend_delegate_data))
        self.assertEqual(
            program.backend_delegate_data[processed.index].data, expected_processed
        )

    @vary_segments
    def test_backend_with_compiler(self, extract_delegate_segments: bool):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            # TODO(chenlai): add a test with a different method name when
            # it's resolved on the compiler side.
            def forward(self, x):
                return torch.sin(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = exir.capture(
            sin_module, model_inputs, exir.CaptureConfig()
        ).to_edge()
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_sin_module = to_backend(
            "BackendWithCompilerDemo", edgeir_m.exported_program, compile_specs
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_linear_sin = lowered_sin_module

            def forward(self, x):
                return self.lowered_linear_sin(x)

        composite_model = CompositeModule()
        model_inputs = (torch.ones(1),)

        composite_model(*model_inputs)

        exec_prog = (
            exir.capture(composite_model, model_inputs, exir.CaptureConfig())
            .to_edge()
            .to_executorch(
                config=exir.ExecutorchBackendConfig(
                    extract_delegate_segments=extract_delegate_segments
                )
            )
        )
        graph_module = exec_prog.dump_graph_module()

        # Check that there is not an aten.sin node.
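        # (sin was consumed by the lowered module, so only the delegate call remains)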
        self.assertTrue(
            exir_ops.edge.aten.sin
            not in {node.target for node in graph_module.graph.nodes}
        )

        # Check that there exists a call_delegate, representing the call to the
        # delegated function
        FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
            graph_module.code
        )
        lowered_submodules = get_lowered_submodules(graph_module)
        self.assertEqual(len(lowered_submodules), 1)

        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target == executorch_call_delegate:
                # Check that first arg is lowered_module_{unique_id}
                self.assertEqual(node.args[0].target, "lowered_module_0")

        program = exec_prog.program

        # Check the program can be printed
        print_program(program)

        # Check the backend delegate
        self.check_backend_delegate(
            program=program,
            delegate=program.execution_plan[0].delegates[0],
            expected_id=BackendWithCompilerDemo.__name__,
            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
        )

        # Check the delegate instruction
        self.assertTrue(
            isinstance(
                program.execution_plan[0].chains[0].instructions[0].instr_args,
                DelegateCall,
            )
        )
        buff = exec_prog.buffer

        executorch_module = _load_for_executorch_from_buffer(buff)
        model_inputs = torch.ones(1)
        model_outputs = executorch_module.forward([model_inputs])
        self.assertEqual(
            model_inputs,
            torch.ones(1),
        )
        expected_output = 0.8333 * torch.ones(1)

        self.assertTrue(
            torch.allclose(model_outputs[0], expected_output, atol=1e-03, rtol=1e-03)
        )

    @vary_segments
    def test_lowered_add_mul(self, extract_delegate_segments: bool):
        class AddMulModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, x, b):
                y = torch.mm(a, x)
                z = torch.add(y, b)
                return z

        add_mul_module = AddMulModule()
        model_inputs = (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2))
        edge_graph_module = exir.capture(
            add_mul_module, model_inputs, exir.CaptureConfig()
        ).to_edge()
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_add_mul = to_backend(
            "BackendWithCompilerDemo", edge_graph_module.exported_program, compile_specs
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_add_mul = lowered_add_mul

            def forward(self, a, x, b):
                return self.lowered_add_mul(a, x, b)

        composite_model = CompositeModule()

        composite_model(*model_inputs)

        exec_prog = (
            exir.capture(composite_model, model_inputs, exir.CaptureConfig())
            .to_edge()
            .to_executorch(
                config=exir.ExecutorchBackendConfig(
                    extract_delegate_segments=extract_delegate_segments
                )
            )
        )
        buff = exec_prog.buffer

        executorch_module = _load_for_executorch_from_buffer(buff)

        # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
        inputs_flattened, _ = tree_flatten(model_inputs)
        model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
        ref_output = add_mul_module(*model_inputs)

        self.assertTrue(
            torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03)
        )

    def run_model_in_unsupported_backend(self, extract_delegate_segments: bool):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.sin(x)

        sin_module = SinModule()
        # the backend only accepts shape <= 4
        model_inputs = (torch.ones(6),)
        edgeir_m = exir.capture(
            sin_module, model_inputs, exir.CaptureConfig()
        ).to_edge()
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_sin_module = to_backend(
            "BackendWithCompilerDemo", edgeir_m.exported_program, compile_specs
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_linear_sin = lowered_sin_module

            def forward(self, x):
                return self.lowered_linear_sin(x)

        composite_model = CompositeModule()
        model_inputs = (torch.zeros(6),)

        composite_model(*model_inputs)

        exec_prog = (
            exir.capture(composite_model, model_inputs, exir.CaptureConfig())
            .to_edge()
            .to_executorch(
                config=exir.ExecutorchBackendConfig(
                    extract_delegate_segments=extract_delegate_segments
                ),
            )
        )

        buff = exec_prog.buffer

        # This line should raise an exception like
        # RuntimeError: failed with error 0x12
        _load_for_executorch_from_buffer(buff)

    @vary_segments
    def test_backend_with_compiler_out_of_range(self, extract_delegate_segments: bool):
        with self.assertRaisesRegex(
            RuntimeError,
            "loading method forward failed with error 0x12",
        ):
            self.run_model_in_unsupported_backend(
                extract_delegate_segments=extract_delegate_segments
            )

    @vary_segments
    def test_backend_with_compiler_delegate_and_operator(
        self, extract_delegate_segments: bool
    ):
        # Test includes both delegates and operators
        # import the backend implementation
        from executorch.exir.backend.test.backend_with_compiler_demo import (
            BackendWithCompilerDemo,
        )

        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            # TODO(chenlai): add a test with a different method name when
            # it's resolved on the compiler side.
            def forward(self, x):
                return [torch.sin(x)]

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = exir.capture(
            sin_module, model_inputs, exir.CaptureConfig()
        ).to_edge()
        max_value = model_inputs[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_sin_module = to_backend(
            "BackendWithCompilerDemo", edgeir_m.exported_program, compile_specs
        )

        class CompositeModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered_linear_sin = lowered_sin_module

            def forward(self, x):
                a = self.lowered_linear_sin(x)[0]
                b = self.lowered_linear_sin(x)[0]
                return torch.add(a, b)

        composite_model = CompositeModule()
        model_inputs = (torch.ones(1),)

        composite_model(*model_inputs)

        exec_prog = (
            exir.capture(composite_model, model_inputs, exir.CaptureConfig())
            .to_edge()
            .to_executorch(
                config=exir.ExecutorchBackendConfig(
                    extract_delegate_segments=extract_delegate_segments
                ),
            )
        )
        graph_module = exec_prog.dump_graph_module()
        program = exec_prog.program
        buff = exec_prog.buffer

        # Check that there is not an aten.sin node.
        self.assertTrue(
            exir_ops.edge.aten.sin.default
            not in {node.target for node in graph_module.graph.nodes}
        )

        # Check that there exists a call_delegate op, representing the call to the
        # delegated function
        FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
            graph_module.code
        )

        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target == executorch_call_delegate:
                # Check that first arg is lowered_module_{unique_id}
                self.assertEqual(node.args[0].target, "lowered_module_0")

        # Check the backend delegate
        self.check_backend_delegate(
            program=program,
            delegate=program.execution_plan[0].delegates[0],
            expected_id=BackendWithCompilerDemo.__name__,
            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
        )

        # Check the delegate instruction
        self.assertTrue(
            isinstance(
                program.execution_plan[0].chains[0].instructions[0].instr_args,
                DelegateCall,
            )
        )

        executorch_module = _load_for_executorch_from_buffer(buff)
        model_inputs = torch.ones(1)

        model_outputs = executorch_module.forward([model_inputs])

        self.assertEqual(
            model_inputs,
            torch.ones(1),
        )
        expected_output = 1.666667 * torch.ones(1)

        self.assertTrue(
            torch.allclose(model_outputs[0], expected_output, atol=1e-03, rtol=1e-03)
        )

    def test_backend_with_compiler_backend_runtime_exception(self):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            # TODO(chenlai): add a test with a different method name when
            # it's resolved on the compiler side.
            def forward(self, x):
                return torch.sin(x) + torch.cos(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = exir.capture(
            sin_module, model_inputs, exir.CaptureConfig()
        ).to_edge()
        error_msg = r"call_function aten.cos.default is not supported in backend BackendWithCompilerDemo"

        with self.assertRaisesRegex(
            RuntimeError,
            error_msg,
        ):
            _ = to_backend("BackendWithCompilerDemo", edgeir_m.exported_program, [])

    def test_backend_with_compiler_backend_not_found_exception(self):
        class SinModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            # TODO(chenlai): add a test with a different method name when
            # it's resolved on the compiler side.
            def forward(self, x):
                return torch.sin(x) + torch.cos(x)

        sin_module = SinModule()
        model_inputs = (torch.ones(1),)
        edgeir_m = exir.capture(
            sin_module, model_inputs, exir.CaptureConfig()
        ).to_edge()
        error_msg = r"Backend FakeBackendWithCompilerDemo was not found."

        with self.assertRaisesRegex(
            NotImplementedError,
            error_msg,
        ):
            _ = to_backend("FakeBackendWithCompilerDemo", edgeir_m.exported_program, [])

    @vary_segments
    def test_backend_with_compiler_delegate_and_operator_with_two_modules(
        self, extract_delegate_segments: bool
    ):
        # The submodule runs on a specific backend; in this example, the
        # `BackendWithCompilerDemo` backend.
        class LowerableSubModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.sin(x)

        # to_be_lowered is a plain nn.Module
        to_be_lowered = LowerableSubModel()
        example_input = (torch.ones(1),)
        to_be_lowered_exir_submodule = exir.capture(
            to_be_lowered, example_input, exir.CaptureConfig()
        ).to_edge()

        max_value = example_input[0].shape[0]
        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
        lowered_module = to_backend(
            "BackendWithCompilerDemo",
            to_be_lowered_exir_submodule.exported_program,
            compile_specs,
        )

        class NonLowerableSubModel(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.bias = bias

            def forward(self, a, b):
                return torch.add(torch.add(a, b), self.bias)

        # The composite module, including the lowered part and the non-lowered part
        class CompositeModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.non_lowerable = NonLowerableSubModel(torch.ones(1) * 0.3)
                self.lowerable = lowered_module

            def forward(self, x):
                a = self.lowerable(x)
                b = self.lowerable(a)
                ret = self.non_lowerable(a, b)
                return a, b, ret

        composite_model = CompositeModel()

        # Prepare the model input
        model_inputs = (torch.ones(1),)

        # Verify the input works with eager module
        composite_model(*model_inputs)

        exec_prog = (
            exir.capture(composite_model, model_inputs, exir.CaptureConfig())
            .to_edge()
            .to_executorch(
                config=exir.ExecutorchBackendConfig(
                    extract_delegate_segments=extract_delegate_segments
                ),
            )
        )
        flatbuffer = exec_prog.buffer

        executorch_module = _load_for_executorch_from_buffer(flatbuffer)
        model_outputs = executorch_module.forward([*model_inputs])

        expected_outputs = [
            0.8333 * torch.ones(1),
            0.7369 * torch.ones(1),
            1.8702 * torch.ones(1),
        ]

        for index, expected_output in enumerate(expected_outputs):
            self.assertTrue(
                torch.allclose(
                    model_outputs[index], expected_output, atol=1e-03, rtol=1e-03
                )
            )

    @vary_segments
    def test_partition_delegate_graph_with_multiple_patterns(
        self, extract_delegate_segments: bool
    ):
        class CompositeModel(torch.nn.Module):
            def __init__(self, _weight):
                super().__init__()
                self.weight = _weight
                self.lstm = torch.nn.LSTM(
                    input_size=32,
                    hidden_size=32,
                    num_layers=1,
                )
                self.conv = torch.nn.Conv1d(1, 1, 1, stride=2)

            def forward(self, x_raw, h, c):
                output, (hn, cn) = self.lstm(x_raw, (h, c))
                k = self.conv(output)
                x = output
                y = cn
                a = torch.sub(x, y)
                b = torch.sub(x, a)
                c = torch.sub(x, b)
                d = torch.add(x, self.weight)
                e = torch.mul(c, d)
                return e, hn, k

        # Prepare input and trace it
        input_x = torch.ones([1, 32])
        input_h = torch.ones([1, 32])
        input_c = torch.ones([1, 32])
        inputs = (input_x, input_h, input_c)

        composite_m = CompositeModel(3)
        orig_res = composite_m(*inputs)

        traced = exir.capture(composite_m, inputs, exir.CaptureConfig()).to_edge(
            # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
            exir.EdgeCompileConfig(_check_ir_validity=False)
        )

        program_without_delegates = (
            exir.capture(CompositeModel(3), inputs)
            .to_edge(
                # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
                exir.EdgeCompileConfig(_check_ir_validity=False)
            )
            .to_executorch(
                config=exir.ExecutorchBackendConfig(
                    extract_delegate_segments=extract_delegate_segments
                ),
            )
        )
        # After this step, part of the graph will be lowered to the backend, depending on
        # HTAPartitionerDemo's rule.
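        # to_backend(exported_program, partitioner) returns an exported program in
        # which every subgraph tagged by the partitioner has been replaced by a
        # lowered module invoked through an executorch_call_delegate node.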
        program_with_delegates = traced
        program_with_delegates.exported_program = to_backend(
            traced.exported_program, HTAPartitionerMultiplePatternsDemo()
        )
        program_with_delegates = program_with_delegates.to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        new_res = program_with_delegates.dump_graph_module()(*inputs)
        for t1, t2 in zip(new_res, orig_res, strict=True):
            self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))

        # Check the backend delegate
        self.check_backend_delegate(
            program=program_with_delegates.program,
            delegate=program_with_delegates.program.execution_plan[0].delegates[0],
            expected_id=QnnBackend.__name__,
            expected_processed=b"imqnncompiled",
        )

        # Check sub is not in the program with delegates
        self.assertEqual(
            0,
            len(
                [
                    op
                    for op in program_with_delegates.program.execution_plan[0].operators
                    if op.name == "aten::sub"
                ]
            ),
        )

        # Check convolution is not in the program with delegates
        self.assertEqual(
            0,
            len(
                [
                    op
                    for op in program_with_delegates.program.execution_plan[0].operators
                    if op.name == "aten::convolution"
                ]
            ),
        )

        # Check convolution is in the program without delegates
        self.assertEqual(
            1,
            len(
                [
                    op
                    for op in program_without_delegates.program.execution_plan[
                        0
                    ].operators
                    if op.name == "aten::convolution"
                ]
            ),
        )

    @vary_segments
    def test_partition_delegate_graph_with_one_patterns(
        self, extract_delegate_segments: bool
    ):
        class CompositeModel(torch.nn.Module):
            def __init__(self, _weight):
                super().__init__()
                self.weight = _weight
                self.lstm = torch.nn.LSTM(
                    input_size=32,
                    hidden_size=32,
                    num_layers=1,
                )
                self.conv = torch.nn.Conv1d(1, 1, 1, stride=2)

            def forward(self, x_raw, h, c):
                output, (hn, cn) = self.lstm(x_raw, (h, c))
                k = self.conv(output)
                x = output
                y = cn
                a = torch.sub(x, y)
                b = torch.sub(x, a)
                c = torch.sub(x, b)
                d = torch.add(x, self.weight)
                e = torch.mul(c, d)
                return e, hn, k

        # Prepare input and trace it
        input_x = torch.ones([1, 32])
        input_h = torch.ones([1, 32])
        input_c = torch.ones([1, 32])
        inputs = (input_x, input_h, input_c)

        composite_m = CompositeModel(3)
        orig_res = composite_m(*inputs)

        traced = exir.capture(
            composite_m,
            inputs,
            exir.CaptureConfig(),
        ).to_edge(
            # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
            exir.EdgeCompileConfig(_check_ir_validity=False)
        )

        program_without_delegates = (
            exir.capture(
                CompositeModel(3),
                (input_x, input_h, input_c),
                exir.CaptureConfig(),
            )
            .to_edge(
                # torch._export.verifier.SpecViolationError: Operator torch._ops.aten.mkldnn_rnn_layer.default is not Aten Canonical.
                exir.EdgeCompileConfig(_check_ir_validity=False)
            )
            .to_executorch(
                config=exir.ExecutorchBackendConfig(
                    extract_delegate_segments=extract_delegate_segments
                ),
            )
        )
        # After this step, part of the graph will be lowered to the backend, depending on
        # HTAPartitionerDemo's rule.
        traced_with_delegate = traced
        traced_with_delegate.exported_program = to_backend(
            traced.exported_program, HTAPartitionerOnePatternDemo()
        )

        new_res = traced_with_delegate(*inputs)
        for t1, t2 in zip(new_res, orig_res, strict=True):
            self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))

        program_with_delegates = traced_with_delegate.to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        # TODO(T143084047): Currently not retraceable
        # Retracing is not needed, but keeping this here to make sure the result
        # of to_backend is retraceable
        # graph_module_with_delegate = exir.capture(
        #     traced_with_delegate,
        #     (input_x, input_h, input_c),
        #     exir.CaptureConfig(),
        # ).to_edge()

        # program_with_delegates = graph_module_with_delegate.to_executorch(
        #     config=exir.ExecutorchBackendConfig(extract_delegate_segments=extract_delegate_segments),
        # )

        new_res = program_with_delegates.dump_graph_module()(*inputs)
        for t1, t2 in zip(new_res, orig_res, strict=True):
            self.assertTrue(torch.allclose(t1, t2, atol=1e-03, rtol=1e-03))

        # Check the backend delegate
        self.check_backend_delegate(
            program=program_with_delegates.program,
            delegate=program_with_delegates.program.execution_plan[0].delegates[0],
            expected_id=QnnBackend.__name__,
            expected_processed=b"imqnncompiled",
        )

        # Check sub is in the program with delegates
        self.assertEqual(
            1,
            len(
                [
                    op
                    for op in program_with_delegates.program.execution_plan[0].operators
                    if op.name == "aten::sub"
                ]
            ),
        )

        # Check convolution is not in the program with delegates
        self.assertEqual(
            0,
            len(
                [
                    op
                    for op in program_with_delegates.program.execution_plan[0].operators
                    if op.name == "aten::convolution"
                ]
            ),
        )

        # Check convolution is in the program without delegates
        self.assertEqual(
            1,
            len(
                [
                    op
                    for op in program_without_delegates.program.execution_plan[
                        0
                    ].operators
                    if op.name == "aten::convolution"
                ]
            ),
        )

    @vary_segments
    def test_add_mul_partitioner(self, extract_delegate_segments: bool):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, x, b):
                y = torch.mm(a, x)
                z = y + b
                a = z - a
                y = torch.mm(a, x)
                z = y + b
                return z

        m = Model()
        inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
        orig_res = m(*inputs)

        ep = exir.capture(m, inputs, exir.CaptureConfig()).to_edge()
        executorch_prog = ep
        executorch_prog.exported_program = to_backend(
            ep.exported_program, AddMulPartitionerDemo()
        )

        for node in executorch_prog.exported_program.graph.nodes:
            if node.op == "call_function" and node.target is executorch_call_delegate:
                for user in node.users:
                    self.assertTrue(
                        user.op == "call_function" and user.target == operator.getitem
                    )
                    self.assertTrue(user.meta.get("source_fn_stack", None) is None)
                    self.assertTrue(user.meta.get("nn_module_stack", None) is None)

        executorch_prog = executorch_prog.to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        new_res = executorch_prog.dump_graph_module()(*inputs)
        self.assertTrue(torch.allclose(new_res[0], orig_res))

        counter = 0
        for node in executorch_prog.dump_graph_module().graph.nodes:
            if node.op == "get_attr":
                self.assertEqual(node.target, f"lowered_module_{counter}")
                counter += 1
        # There should be 2 delegated modules
        self.assertEqual(counter, 2)

        executorch_module = _load_for_executorch_from_buffer(executorch_prog.buffer)
        # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
        inputs_flattened, _ = tree_flatten(inputs)
        model_output = executorch_module.run_method("forward", tuple(inputs_flattened))
        ref_output = m(*inputs)

        self.assertTrue(
            torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03),
        )

    @vary_segments
    def test_partitioner_with_attributes(self, extract_delegate_segments: bool):
        """
        Check that if we tag the getattr nodes, the attributes will be added to
        the lowered submodule rather than being passed into the delegate as
        inputs.
        """

        class AddOne(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.one = torch.ones(1, 3)

            def forward(self, x):
                return x + self.one

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.add_one = AddOne()

            def forward(self, x, y):
                x = self.add_one(x) * y
                return self.add_one(x), self.add_one(y)

        inputs = (torch.randn(1, 3), torch.randn(1, 3))
        orig_res = Model()(*inputs)
        ep = exir.capture(Model(), inputs, exir.CaptureConfig()).to_edge()
        executorch_prog = ep
        executorch_prog.exported_program = to_backend(
            ep.exported_program, AddAttributePartitionerDemo()
        )

        for node in executorch_prog.exported_program.graph.nodes:
            if node.op == "call_function" and node.target is executorch_call_delegate:
                for user in node.users:
                    self.assertTrue(
                        user.op == "call_function" and user.target == operator.getitem
                    )
                    self.assertTrue(user.meta.get("source_fn_stack", None) is None)
                    self.assertTrue(user.meta.get("nn_module_stack", None) is None)

        executorch_prog = executorch_prog.to_executorch(
            config=exir.ExecutorchBackendConfig(
                extract_delegate_segments=extract_delegate_segments
            ),
        )

        # Check the delegated submodules
        lowered_submodules = get_lowered_submodules(executorch_prog.dump_graph_module())
        self.assertEqual(len(lowered_submodules), 2)
        # Attributes should be stored in the lowered module
        self.check_delegate_input(lowered_submodules[0][1], 1)
        self.check_delegate_input(lowered_submodules[1][1], 2)

        executorch_prog.buffer

        new_res = executorch_prog.dump_graph_module()(*inputs)
        self.assertTrue(torch.allclose(orig_res[0], new_res[0]))
        self.assertTrue(torch.allclose(orig_res[1], new_res[1]))

    def test_bad_partitioner(self):
        """
        Checks that we throw an error if the user-provided partitioner modifies
        the graph module.
        """
        inputs = (torch.randn(1, 3), torch.randn(1, 3))

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                x = x + y
                x = x * y
                x = x - y
                x = x / y
                x = x * y
                x = x + y
                return x

        class BadPartitioner(Partitioner):
            def partition(self, exported_program: ExportedProgram) -> PartitionResult:
                # A partitioner should not modify the given graph module
                for node in exported_program.graph.nodes:
                    if (
                        node.op == "call_function"
                        and node.target == exir_ops.edge.aten.add.Tensor
                    ):
                        node.target = exir_ops.edge.aten.mul.Tensor
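                # Return the mutated program; to_backend() is expected to detect
                # that the partitioner changed the graph and raise (see the
                # assertRaises(AssertionError) below).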
                return PartitionResult(
                    tagged_exported_program=exported_program,
                    partition_tags={
                        "tag1": DelegationSpec("BackendWithCompilerDemo", [])
                    },
                )

        ep = exir.capture(Model(), inputs, exir.CaptureConfig()).to_edge()
        with self.assertRaises(AssertionError):
            _ = to_backend(ep.exported_program, BadPartitioner())

    def test_quantized_with_delegate(self) -> None:
        torch.ops.load_library(
            "//executorch/kernels/quantized:custom_ops_generated_lib"
        )
        qconfig_mapping = get_default_qconfig_mapping("qnnpack")
        in_size = 2
        input_size = 3
        output_size = 4
        linear = torch.nn.Linear(input_size, output_size).eval()
        example_inputs = (torch.ones(in_size, input_size),)
        prepared_linear = prepare_fx(
            linear,
            qconfig_mapping,
            example_inputs,
            backend_config=get_executorch_backend_config(),
        )
        converted_linear: torch.nn.Module = _convert_to_reference_decomposed_fx(
            prepared_linear,
        )

        # fails to trace here
        converted_linear_gm = exir.capture(
            converted_linear,
            example_inputs,
            exir.CaptureConfig(
                enable_aot=True,
            ),
        ).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
        FileCheck().check_count("quantize_per_tensor_default", 3).check("addmm").run(
            converted_linear_gm.exported_program.graph_module.code
        )

    def test_partition_with_control_flow(self) -> None:
        def true_fn(x, y):
            x = x - y
            x = x + y
            x = x - y
            return x

        def false_fn(x, y):
            x = x - y
            x = torch.mm(x, y)
            x = x - y
            return x

        def f(x, y):
            x = x + y
            x = torch.ops.higher_order.cond(x[0][0] == 1, true_fn, false_fn, [x, y])
            x = x - y
            return x

        inputs = (torch.ones(2, 2), torch.ones(2, 2))
        orig_res = f(*inputs)
        orig = exir.capture(
            f,
            inputs,
            exir.CaptureConfig(),
        ).to_edge()
        partitioned = orig
        partitioned.exported_program = to_backend(
            orig.exported_program, AddMulPartitionerDemo()
        )

        new_res = partitioned(*inputs)
        self.assertTrue(torch.allclose(orig_res, new_res[0]))

        toplevel_lowered = get_lowered_submodules(
            partitioned.exported_program.graph_module
        )
        self.assertEqual(len(toplevel_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run(
            toplevel_lowered[0][1].original_module.graph_module.code
        )

        # Toplevel module only has the cond submodules
        partitioned_submodules = get_control_flow_submodules(
            partitioned.exported_program.graph_module
        )
        self.assertEqual(len(partitioned_submodules), 2)

        true_gm = partitioned_submodules[0][1]
        true_lowered = get_lowered_submodules(true_gm)
        self.assertEqual(len(true_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run(
            true_lowered[0][1].original_module.graph_module.code
        )

        false_gm = partitioned_submodules[1][1]
        false_lowered = get_lowered_submodules(false_gm)
        self.assertEqual(len(false_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run(
            false_lowered[0][1].original_module.graph_module.code
        )

    def test_partition_with_map(self) -> None:
        def map_fn(x, y):
            x = x - y
            x = x + y
            return x

        def f(xs, y):
            y = torch.mm(y, y)
            return control_flow.map(map_fn, xs, y)

        inputs = (torch.ones(2, 2), torch.ones(2, 2))
        orig_res = f(*inputs)
        orig = exir.capture(
            f,
            inputs,
            exir.CaptureConfig(),
        ).to_edge()
        partitioned = orig
        partitioned.exported_program = to_backend(
            orig.exported_program, AddMulPartitionerDemo()
        )

        toplevel_lowered = get_lowered_submodules(
            partitioned.exported_program.graph_module
        )
        self.assertEqual(len(toplevel_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run(
            toplevel_lowered[0][1].original_module.graph_module.code
        )

        # Toplevel module only has the map submodule
        partitioned_submodules = get_control_flow_submodules(
            partitioned.exported_program.graph_module
        )
        self.assertEqual(len(partitioned_submodules), 1)

        map_fn_gm = partitioned_submodules[0][1]
        map_fn_lowered = get_lowered_submodules(map_fn_gm)
        self.assertEqual(len(map_fn_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run(
            map_fn_lowered[0][1].original_module.graph_module.code
        )

        new_res = partitioned(*inputs)

        self.assertTrue(torch.allclose(orig_res, new_res[0]))

    def test_partition_with_nested_control_flow(self) -> None:
        """
        Partitions the add and mul ops, including the ones inside the submodules
        """

        def true_nested(y):
            y = y + y
            y = torch.mm(y, y)
            return y

        def false_nested(y):
            return torch.mm(y, y)

        def true_fn(x, pred2):
            z = control_flow.cond(pred2, true_nested, false_nested, [x])
            return x + z

        def false_fn(x, _):
            return x.cos()

        def map_fn(x, pred1, pred2, y):
            x = x.cos()
            y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2])
            x = x + y
            return x.sin()

        def f(xs, pred1, pred2, y):
            y = torch.mm(y, y)
            return control_flow.map(map_fn, xs, pred1, pred2, y)

        inputs = (
            torch.ones(2, 2),
            torch.tensor([False]),
            torch.Tensor([False]),
            torch.ones(2, 2),
        )

        orig_res = f(*inputs)
        orig = exir.capture(
            f,
            inputs,
            exir.CaptureConfig(),
        ).to_edge()
        partitioned = orig
        partitioned.exported_program = to_backend(
            orig.exported_program, AddMulPartitionerDemo()
        )

        new_res = partitioned(*inputs)
        self.assertTrue(torch.allclose(orig_res, new_res[0]))

        toplevel_lowered = get_lowered_submodules(
            partitioned.exported_program.graph_module
        )
        self.assertEqual(len(toplevel_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run(
            toplevel_lowered[0][1].original_module.graph_module.code
        )

        # Toplevel module only has the map submodule
        partitioned_submodules = get_control_flow_submodules(
            partitioned.exported_program.graph_module
        )
        self.assertEqual(len(partitioned_submodules), 1)

        # Map module has the cond submodules
        map_submodules = get_control_flow_submodules(partitioned_submodules[0][1])
        self.assertEqual(len(map_submodules), 2)

        # True module
        true_module = map_submodules[0][1]
        true_lowered = get_lowered_submodules(true_module)
        self.assertEqual(len(true_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").run(
            true_lowered[0][1].original_module.graph_module.code
        )

        # False module
        false_lowered = get_lowered_submodules(map_submodules[1][1])
        self.assertEqual(len(false_lowered), 0)

        # True module has the nested cond submodules
        true_submodules = get_control_flow_submodules(true_module)
        self.assertEqual(len(true_submodules), 2)

        # Nested True module
        true_true_lowered = get_lowered_submodules(true_submodules[0][1])
        self.assertEqual(len(true_true_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_add_Tensor").check(
            "executorch_exir_dialects_edge__ops_aten_mm_default"
        ).run(true_true_lowered[0][1].original_module.graph_module.code)

        # Nested False module
        true_false_lowered = get_lowered_submodules(true_submodules[1][1])
        self.assertEqual(len(true_false_lowered), 1)
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default").run(
            true_false_lowered[0][1].original_module.graph_module.code
        )

    def test_list_input(self):
        def f(x: List[torch.Tensor]):
            y = x[0] + x[1]
            return y

        inputs = ([torch.randn(2, 2), torch.randn(2, 2)],)
        edge_prog = exir.capture(f, inputs, exir.CaptureConfig()).to_edge()
        lowered_gm = to_backend(
            BackendWithCompilerDemo.__name__, edge_prog.exported_program, []
        )

        class ComposedM(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered = lowered_gm

            def forward(self, x: List[torch.Tensor]):
                return self.lowered(x)

        gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge()
        gm(*inputs)

    def test_dict_input(self):
        class M(torch.nn.Module):
            def forward(self, x: Dict[str, torch.Tensor]):
                y = x["a"] + x["b"]
                return y

        inputs = ({"a": torch.randn(2, 2), "b": torch.randn(2, 2)},)
        edge_prog = exir.to_edge(torch.export.export(M(), inputs))
        lowered_gm = to_backend(
            BackendWithCompilerDemo.__name__, edge_prog.exported_program(), []
        )

        class ComposedM(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowered = lowered_gm

            def forward(self, x: List[torch.Tensor]):
                return self.lowered(x)

        gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge()
        gm(*inputs)
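

# Optional convenience entry point: lets this test module also be run directly
# with the standard unittest runner (e.g. `python -m unittest` against this file),
# in addition to the usual test infrastructure.
if __name__ == "__main__":
    unittest.main()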