# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
import operator
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import torch
import torch.utils._pytree as pytree
from executorch.exir._serialize import _serialize_pte_binary
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name
from executorch.exir.emit import emit_program

from executorch.exir.graph_module import _get_submodule

from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from executorch.exir.passes.spec_prop_pass import make_spec, SpecPropPass
from executorch.exir.schema import Program

from executorch.exir.tracer import Value
from torch._library.fake_class_registry import FakeScriptObject

from torch._subclasses import FakeTensor
from torch.export.exported_program import (
    ConstantArgument,
    ExportedProgram,
    ExportGraphSignature,
    InputKind,
    InputSpec,
    ModuleCallEntry,
    ModuleCallSignature,
    OutputKind,
    OutputSpec,
    TensorArgument,
)
from torch.fx.passes.utils.fuser_utils import (
    erase_nodes,
    fuse_as_graphmodule,
    insert_subgm,
    legalize_graph,
    NodeList,
    topo_sort,
)


class LoweredBackendModule(torch.nn.Module):
    """
    A subclass of nn.Module that is generated for modules containing
    delegated functions. This can be created by calling `to_backend`.
    """

    _backend_id: str  # The backend's name
    _processed_bytes: bytes  # The delegate blobs created from backend.preprocess
    _compile_specs: List[
        CompileSpec
    ]  # A list of backend-specific objects with static metadata to configure the "compilation" process.
    _original_exported_program: ExportedProgram  # The original EXIR module

    def __init__(
        self,
        edge_program: ExportedProgram,
        backend_id: str,
        processed_bytes: bytes,
        compile_specs: List[CompileSpec],
    ) -> None:
        super().__init__()
        self._original_exported_program = edge_program
        self._backend_id = backend_id
        self._processed_bytes = processed_bytes
        self._compile_specs = compile_specs
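
    # Note: instances are normally produced by the `to_backend` API rather than
    # constructed by hand. A minimal construction sketch (the backend id and
    # payload below are illustrative; a real payload comes from
    # backend.preprocess):
    #
    #   lowered = LoweredBackendModule(
    #       edge_program=edge_program,  # an Edge-dialect ExportedProgram
    #       backend_id="ExampleBackend",
    #       processed_bytes=b"\x00\x01...",
    #       compile_specs=[CompileSpec("max_batch", b"1")],
    #   )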

    # pyre-ignore
    def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "LoweredBackendModule":
        # Copy the exported program
        copied_program = ExportedProgram(
            root=copy.deepcopy(self._original_exported_program.graph_module),
            graph=copy.deepcopy(self._original_exported_program.graph),
            graph_signature=copy.deepcopy(
                self._original_exported_program.graph_signature
            ),
            state_dict=self._original_exported_program.state_dict,
            range_constraints=copy.deepcopy(
                self._original_exported_program.range_constraints
            ),
            module_call_graph=copy.deepcopy(
                self._original_exported_program.module_call_graph
            ),
            constants=self._original_exported_program.constants,
            verifiers=[copy.deepcopy(self._original_exported_program.verifier)],
        )

        res = LoweredBackendModule(
            edge_program=copied_program,
            backend_id=self._backend_id,
            processed_bytes=self._processed_bytes,
            compile_specs=copy.deepcopy(self._compile_specs, memo),
        )
        # pyre-fixme[16]: `LoweredBackendModule` has no attribute `meta`.
        res.meta = copy.copy(getattr(self, "meta", {}))
        return res

    @property
    def backend_id(self) -> str:
        """
        Returns the backend's name.
        """
        return self._backend_id

    @property
    def processed_bytes(self) -> bytes:
        """
        Returns the delegate blob created from backend.preprocess
        """
        return self._processed_bytes

    @property
    def compile_specs(self) -> List[CompileSpec]:
        """
        Returns a list of backend-specific objects with static metadata to configure the "compilation" process.
        """
        return self._compile_specs

    @property
    def original_module(self) -> ExportedProgram:
        """
        Returns the original EXIR module
        """
        return self._original_exported_program

    # TODO(chenlai): consolidate the serialization config with the serialize_to_flatbuffer api
    def buffer(
        self,
        extract_delegate_segments: bool = False,
        segment_alignment: int = 128,
        constant_tensor_alignment: Optional[int] = None,
        delegate_alignment: Optional[int] = None,
        memory_planning: MemoryPlanningPass = None,  # pyre-fixme[9]
    ) -> bytes:
        """
        Returns a buffer containing the serialized ExecuTorch binary.
        """
        # TODO(T181463742): avoid calling bytes(..) which incurs large copies.
        out = bytes(
            _serialize_pte_binary(
                program=self.program(memory_planning=memory_planning),
                extract_delegate_segments=extract_delegate_segments,
                segment_alignment=segment_alignment,
                constant_tensor_alignment=constant_tensor_alignment,
                delegate_alignment=delegate_alignment,
            )
        )
        return out
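
    # A minimal serialization sketch (the path is illustrative; `.pte` is the
    # ExecuTorch file convention):
    #
    #   with open("/tmp/lowered.pte", "wb") as f:
    #       f.write(lowered.buffer())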

    # TODO(chenlai): re-consider recapture instead of manually constructing the program because
    # the metadata construction is done manually.
    def program(
        self,
        emit_stacktrace: bool = False,
        memory_planning: MemoryPlanningPass = None,  # pyre-fixme[9]
    ) -> Program:
        # Local import to avoid the cyclic dependency introduced by autodeps:
        # program -> verifier -> lowered_backend_module -> program
        # @manual
        from executorch.exir.program._program import (
            _get_updated_graph_signature,
            _transform,
        )

        """
        Returns the object that represents the ExecuTorch binary before serialization.
        """
        # Creates a new module based on the original module. The original module will
        # look something like the following:
        #
        # opcode         name                 target  args                                        kwargs
        # -------------  -------------------  ------  ------------------------------------------  --------
        # placeholder    arg0_1               arg0_1  ()                                          {}
        # placeholder    arg1_1               arg1_1  ()                                          {}
        # call_function  aten_repeat_default  *       (arg1_1, [4, 1])                            {}
        # call_function  aten_mul_tensor      *       (aten_repeat_default, aten_repeat_default)  {}
        # call_function  aten_add_tensor      *       (arg1_1, arg1_1)                            {}
        # output         output               output  ([aten_mul_tensor, aten_add_tensor],)       {}
        #
        # if the whole module is lowered, the resulting lowered module looks like
        #
        # opcode         name                      target                       args                                kwargs
        # -------------  ------------------------  ---------------------------  ----------------------------------  --------
        # placeholder    arg0_1                    arg0_1                       ()                                  {}
        # placeholder    arg1_1                    arg1_1                       ()                                  {}
        # get_attr       lowered_module_0          lowered_module_0             ()                                  {}
        # call_function  executorch_call_delegate  executorch_call_delegate     (lowered_module_0, arg0_1, arg1_1)  {}
        # call_function  getitem                   <built-in function getitem>  (executorch_call_delegate, 0)       {}
        # call_function  getitem_1                 <built-in function getitem>  (executorch_call_delegate, 1)       {}
        # output         output_1                  output                       ([getitem, getitem_1],)             {}
        #
        # We'll remove all call_function nodes, insert a call_delegate node, insert getitem nodes
        # to get the results of the call_delegate node, and return the list of getitems as the output

        lowered_exported_program = copy.deepcopy(self._original_exported_program)

        # The real input nodes are the ones that are not buffers or parameters
        all_input_nodes = [
            node
            for node in lowered_exported_program.graph.nodes
            if (
                node.op == "placeholder"
                and node.name
                not in lowered_exported_program.graph_signature.inputs_to_buffers
                and node.name
                not in lowered_exported_program.graph_signature.inputs_to_parameters
            )
        ]

        output_node = [
            node for node in lowered_exported_program.graph.nodes if node.op == "output"
        ]
        assert len(output_node) == 1, "There should be only one output node"

        # Step 1. Clean up the graph before inserting the call_delegate node
        # Remove the original output node
        lowered_exported_program.graph.erase_node(output_node[0])

        # Remove everything else except the inputs
        for node in reversed(lowered_exported_program.graph.nodes):
            if node.op != "placeholder":
                lowered_exported_program.graph.erase_node(node)

        # Find placeholders that are parameters or buffers, remove them from the main graph
        for node in lowered_exported_program.graph.nodes:
            if node.op == "placeholder" and (
                node.name in lowered_exported_program.graph_signature.inputs_to_buffers
                or node.name
                in lowered_exported_program.graph_signature.inputs_to_parameters
            ):
                lowered_exported_program.graph.erase_node(node)

        # Step 2. Start constructing the graph
        lowered_name = get_lowered_module_name(
            lowered_exported_program.graph_module, self
        )
        # Insert the lowered module into the graph module as an attribute
        lowered_node = lowered_exported_program.graph.get_attr(lowered_name)

        # Insert a call_delegate node into the graph module, with arguments from the arg list
        delegate_node = lowered_exported_program.graph.call_function(
            executorch_call_delegate, (lowered_node, *all_input_nodes)
        )
        # Get the output list. Since the output node is a tuple containing a list,
        # like ([aten_mul_tensor, aten_add_tensor],), we add some handling logic
        # to get the list `[aten_mul_tensor, aten_add_tensor]` properly
        original_output_nodes = [
            node
            for node in self._original_exported_program.graph.nodes
            if node.op == "output"
        ][0].args[0]

        delegate_node.meta["spec"] = tuple(
            [make_spec(node.meta["val"]) for node in original_output_nodes]
        )
        delegate_node.meta["val"] = tuple(
            [node.meta["val"] for node in original_output_nodes]
        )

        # The getitem nodes that are going to be inserted into the lowered graph module
        getitem_nodes = []
        for i in range(len(original_output_nodes)):
            getitem_node = lowered_exported_program.graph.call_function(
                operator.getitem,
                args=(delegate_node, i),
            )
            getitem_node.meta["val"] = delegate_node.meta["val"][i]
            getitem_nodes.append(getitem_node)
        lowered_exported_program.graph.output(getitem_nodes)

        lowered_exported_program.graph_module.recompile()
        lowered_exported_program.graph.lint()

        # The user outputs will be the getitem nodes instead
        output_specs = [
            OutputSpec(
                kind=OutputKind.USER_OUTPUT,
                arg=TensorArgument(name=getitem_node.name),
                target=None,
            )
            for getitem_node in getitem_nodes
        ]
        # All data is consumed by the delegate, so it should be removed from the state dict.
        inputs_to_parameters = (
            lowered_exported_program.graph_signature.inputs_to_parameters
        )
        inputs_to_buffers = lowered_exported_program.graph_signature.inputs_to_buffers
        input_specs = [
            InputSpec(
                kind=InputKind.USER_INPUT,
                arg=TensorArgument(name=user_input),
                target=None,
            )
            for user_input in lowered_exported_program.graph_signature.user_inputs
            if user_input not in inputs_to_parameters
            and user_input not in inputs_to_buffers
        ]

        # Double check that the ExportedProgram data (especially everything except the graph) is good
        exported_program = ExportedProgram(
            root=lowered_exported_program.graph_module,
            graph=lowered_exported_program.graph,
            graph_signature=_get_updated_graph_signature(
                ExportGraphSignature(
                    input_specs=input_specs, output_specs=output_specs
                ),
                lowered_exported_program.graph_module,
            ),
            # TODO: May need to set lowered_exported_program.call_spec = CallSpec(None, None)
            # somewhere as we should pass it a list of tensors to the lowered module and output a
            # list of tensors. Putting call_spec=lowered_exported_program.call_spec is correct here as the
            # inputs/outputs to the toplevel program will be in the format of the eager module.
            state_dict={},  # Empty because all data is consumed by the delegate
            range_constraints=lowered_exported_program.range_constraints,
            module_call_graph=lowered_exported_program.module_call_graph,
            example_inputs=None,
            verifiers=[lowered_exported_program.verifier],
        )
        if memory_planning is None:
            memory_planning = MemoryPlanningPass()
        exported_program = _transform(exported_program, SpecPropPass(), memory_planning)
        emitted_program = emit_program(
            exported_program, emit_stacktrace=emit_stacktrace
        ).program
        return emitted_program

    # Used to patch each delegated function with a call_delegate call
    # @staticmethod
    def forward(
        self,
        *args: Value,
        **kwargs: Tuple[Value, ...],
    ) -> Value:
        return executorch_call_delegate(self, *args)
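

# Note: `forward` simply forwards its positional args to
# `executorch_call_delegate`, so a lowered module is callable like any
# nn.Module. A minimal sketch:
#
#   out = lowered(x)  # equivalent to executorch_call_delegate(lowered, x)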


# TODO(zhxchen17) Try ExportPass
def _fixup_output_node(gm: torch.fx.GraphModule) -> None:
    for node in reversed(gm.graph.nodes):
        if node.op == "output":
            with gm.graph.inserting_before(node):
                assert len(node.args) == 1
                outputs = node.args[0]
                if isinstance(outputs, torch.fx.Node):
                    val = outputs.meta.get("val")
                    if isinstance(val, list):
                        # If a list is returned, in some cases it is represented as a
                        # single node, like `split_copy_tensor`, but EXIR will return an
                        # opened-up list like `[getitem1, getitem2]`
                        outputs = [
                            torch.fx.Proxy(outputs)[i].node for i in range(len(val))
                        ]
            returns, out_spec = pytree.tree_flatten(outputs)
            node.args = (returns,)
            return


def arrange_graph_placeholders(
    gm: torch.fx.GraphModule, owning_program: ExportedProgram
) -> torch.fx.GraphModule:
    """
    Modifies the graph of the given graph module into one that contains the same nodes as the
    original, but with placeholders in the order (Params + Buffers) (User Inputs)

    This is used by the delegate api, which disturbs the placeholder ordering when creating a
    submodule from partitioned nodes

    Args:
        gm: The graph module that we want arranged
        owning_program: ExportedProgram that the submodule (gm) belongs to

    Returns:
        The graph module, arranged in place
    """
    new_graph = torch.fx.Graph()

    node_map = {}  # mapping of nodes from old graph to new graph

    graph_sign = owning_program.graph_signature

    # Add all placeholders into the graph first:
    param_nodes = []
    buffer_nodes = []
    input_nodes = []
    for node in gm.graph.nodes:
        if node.op != "placeholder":
            continue

        if node.name in graph_sign.inputs_to_parameters:
            param_nodes.append(node)
        elif node.name in graph_sign.inputs_to_buffers:
            buffer_nodes.append(node)
        else:
            input_nodes.append(node)

    for param_node in param_nodes:
        new_node = new_graph.node_copy(param_node, lambda x: node_map[x])
        node_map[param_node] = new_node
    for buffer_node in buffer_nodes:
        new_node = new_graph.node_copy(buffer_node, lambda x: node_map[x])
        node_map[buffer_node] = new_node
    for input_node in input_nodes:
        new_node = new_graph.node_copy(input_node, lambda x: node_map[x])
        node_map[input_node] = new_node

    # Now add all the other nodes in order
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            continue

        new_node = new_graph.node_copy(node, lambda x: node_map[x])
        node_map[node] = new_node

    # lint to ensure correctness
    new_graph.lint()

    new_graph._codegen = gm.graph._codegen
    gm.graph = new_graph

    return gm
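

# For example (a sketch): if a partitioned submodule's placeholders arrive in
# the order (x, weight, y, buf), where `weight` is a parameter and `buf` is a
# buffer, the rewritten graph orders them as (weight, buf, x, y), i.e.
# params + buffers first, then user inputs, as described above.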


# TODO Don't regenerate the new signature manually.
def _get_new_signature(  # noqa: C901
    original_program: ExportedProgram,
    gm: torch.fx.GraphModule,
    call_module_node: torch.fx.Node,
    tag: str,
    is_submodule: bool = False,
) -> Tuple[
    ExportGraphSignature,
    Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
    Dict[str, Union[torch.Tensor, FakeScriptObject, torch.ScriptObject]],
    Dict[str, InputSpec],
    Dict[str, OutputSpec],
]:
    """
    Args:
        original_program: The original program that we are partitioning
        gm: The partitioned graph module.
        call_module_node: The node in the original program that is calling the
            partitioned graph module.
        tag: The tag being used for this partitioned submodule. This is used to
            tell if a particular parameter/buffer/constant node is tagged, aka
            consumed by the delegate.
        is_submodule: True if we are currently partitioning inside of a
            submodule (like cond's submodule). If we are inside of a submodule,
            we do not care about consuming params/buffers.

    Returns:
        new_signature (ExportGraphSignature): The new signature for the
            partitioned graph module.
        new_state_dict (Dict[str, Union[torch.Tensor, torch.nn.Parameter]]): The
            new state dict containing the consumed params/buffers.
        new_constants (Dict[str, Union[torch.Tensor, FakeScriptObject,
            torch.ScriptObject]]): The new constants table containing the
            consumed constants.
        input_specs_to_delete (Dict[str, InputSpec]): The input specs that have
            been consumed by the delegate (param/buffer input nodes) and should
            be removed from the toplevel ExportedProgram.
        output_specs_to_delete (Dict[str, OutputSpec]): The output specs that have
            been consumed by the delegate (buffer mutation nodes) and should be
            removed from the toplevel ExportedProgram.
    """
    old_signature = original_program.graph_signature

    input_specs = []
    output_specs = []
    input_specs_to_delete = {}
    output_specs_to_delete = {}
    new_state_dict = {}
    new_constants = {}

    # If we are within a submodule, we do not need to care about consuming
    # params/buffers
    input_node_to_sig: Dict[str, InputSpec] = (
        {input_spec.arg.name: input_spec for input_spec in old_signature.input_specs}
        if not is_submodule
        else {}
    )

    toplevel_output_node_to_sig: Dict[str, List[OutputSpec]] = defaultdict(list)
    if not is_submodule:
        for output_spec in old_signature.output_specs:
            toplevel_output_node_to_sig[output_spec.arg.name].append(output_spec)

    for node in gm.graph.nodes:
        if node.op == "placeholder":
            if node.name not in input_node_to_sig:
                input_specs.append(
                    InputSpec(
                        kind=InputKind.USER_INPUT,
                        arg=TensorArgument(name=node.name),
                        target=None,
                    )
                )
                continue

            orig_input_spec = input_node_to_sig[node.name]

            if not isinstance(orig_input_spec.arg, TensorArgument):
                input_specs.append(orig_input_spec)

            elif node.meta.get("delegation_tag", None) == tag:
                input_specs.append(orig_input_spec)

                if orig_input_spec.kind == InputKind.USER_INPUT:
                    continue

                # The following input specs are all attributes that should be
                # consumed by the delegate, so we want to remove them from the
                # toplevel module input/output
                input_specs_to_delete[node.name] = orig_input_spec

                input_target = orig_input_spec.target
                if input_target in original_program.state_dict:
                    assert orig_input_spec.kind in (
                        InputKind.PARAMETER,
                        InputKind.BUFFER,
                    )

                    new_state_dict[input_target] = original_program.state_dict[
                        input_target
                    ]
                elif input_target in original_program.constants:
                    assert orig_input_spec.kind in (
                        InputKind.CONSTANT_TENSOR,
                        InputKind.CUSTOM_OBJ,
                        InputKind.BUFFER,
                    )

                    new_constants[input_target] = original_program.constants[
                        input_target
                    ]
                else:
                    raise RuntimeError(f"Invalid input spec {orig_input_spec} received")

            else:
                input_specs.append(
                    InputSpec(
                        kind=InputKind.USER_INPUT,
                        arg=TensorArgument(name=node.name),
                        target=None,
                    )
                )

        if node.op == "output":
            buffer_mutation_idxs: Dict[int, List[OutputSpec]] = defaultdict(list)
            for user in call_module_node.users.keys():
                if user.name in toplevel_output_node_to_sig:
                    assert (
                        user.op == "call_function" and user.target == operator.getitem
                    ), f"Invalid user {user}, node.op is {user.op} and node.target is {user.target}"
                    getitem_idx = user.args[1]
                    assert isinstance(
                        getitem_idx, int
                    ), f"Invalid getitem type: {type(getitem_idx)}"
                    buffer_mutation_idxs[getitem_idx].extend(
                        toplevel_output_node_to_sig[user.name]
                    )

            for i, output_node in enumerate(node.args[0]):
                if i in buffer_mutation_idxs:
                    assert isinstance(output_node, torch.fx.Node)
                    orig_output_specs = buffer_mutation_idxs[i]

                    if any(
                        orig_output_spec.kind == OutputKind.BUFFER_MUTATION
                        and orig_output_spec.target in new_state_dict
                        for orig_output_spec in orig_output_specs
                    ):
                        # If the delegate wants to consume the buffer, then the
                        # delegate should also consume the buffer mutation
                        # (output spec would be a BUFFER_MUTATION). Otherwise
                        # the delegate will just return the result of the
                        # mutation as a USER_OUTPUT.

                        orig_output_spec = [
                            orig_output_spec
                            for orig_output_spec in orig_output_specs
                            if orig_output_spec.kind == OutputKind.BUFFER_MUTATION
                            and orig_output_spec.target in new_state_dict
                        ][0]

                        assert len(orig_output_specs) == 1, (
                            f"Constant {orig_output_spec.target} was tagged to be "
                            "consumed by the buffer, and was found to also contain "
                            "a buffer mutation. However this buffer mutation node "
                            "was found to also be used as other types of outputs "
                            "which is currently not supported. Please file an "
                            "issue on Github. \n\n"
                            f"The toplevel program: {original_program}\n"
                        )
                        output_specs.append(
                            OutputSpec(
                                kind=OutputKind.BUFFER_MUTATION,
                                arg=TensorArgument(name=output_node.name),
                                target=orig_output_spec.target,
                            )
                        )
                        output_specs_to_delete[orig_output_spec.arg.name] = (
                            orig_output_spec
                        )
                    else:
                        output_specs.append(
                            OutputSpec(
                                kind=OutputKind.USER_OUTPUT,
                                arg=TensorArgument(name=output_node.name),
                                target=None,
                            )
                        )

                elif not isinstance(output_node, torch.fx.Node):
                    output_specs.append(
                        OutputSpec(
                            kind=OutputKind.USER_OUTPUT,
                            arg=ConstantArgument(name="", value=output_node),
                            target=None,
                        )
                    )

                else:
                    output_specs.append(
                        OutputSpec(
                            kind=OutputKind.USER_OUTPUT,
                            arg=TensorArgument(name=output_node.name),
                            target=None,
                        )
                    )

    new_signature = ExportGraphSignature(
        input_specs=input_specs, output_specs=output_specs
    )

    return (
        new_signature,
        new_state_dict,
        new_constants,
        input_specs_to_delete,
        output_specs_to_delete,
    )
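

# For example (a sketch): if the partition tagged "tag0" consumes a parameter
# `linear.weight`, the returned signature keeps its InputSpec(kind=PARAMETER,
# target="linear.weight"), its tensor is copied into new_state_dict, and the
# same spec is recorded in input_specs_to_delete so the caller can drop it
# from the toplevel program.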


def create_exported_program_from_submodule(
    submodule: torch.fx.GraphModule,
    owning_program: ExportedProgram,
    tag: str,
    call_module_node: torch.fx.Node,
    is_submodule: bool,
) -> Tuple[ExportedProgram, Dict[str, InputSpec], Dict[str, OutputSpec]]:
    """
    Creates an ExportedProgram from the given submodule using the parameters and buffers
    from the top-level owning program

    Args:
        submodule: submodule to create an exported program from
        owning_program: exported program containing the parameters and buffers used within
            the submodule

    Returns:
        The ExportedProgram created from the submodule
        input_specs_to_delete (Dict[str, InputSpec]): The input specs that have
            been consumed by the delegate (param/buffer input nodes) and should
            be removed from the toplevel ExportedProgram.
        output_specs_to_delete (Dict[str, OutputSpec]): The output specs that have
            been consumed by the delegate (buffer mutation nodes) and should be
            removed from the toplevel ExportedProgram.
    """
    # Arrange the submodule's placeholders in order
    submodule = arrange_graph_placeholders(submodule, owning_program)

    # TODO: we probably need to arrange the outputs wrt buffer mutations.

    # Get the updated graph signature
    (
        subgraph_signature,
        subgraph_state_dict,
        subgraph_constants,
        toplevel_input_specs_to_delete,
        toplevel_output_specs_to_delete,
    ) = _get_new_signature(
        owning_program, submodule, call_module_node, tag, is_submodule
    )

    in_spec = pytree.tree_flatten((tuple(subgraph_signature.user_inputs), {}))[1]
    out_spec = pytree.tree_flatten(subgraph_signature.user_outputs)[1]

    return (
        ExportedProgram(
            root=submodule,
            graph=submodule.graph,
            graph_signature=subgraph_signature,
            state_dict=subgraph_state_dict,
            range_constraints=copy.deepcopy(owning_program.range_constraints),
            module_call_graph=[
                ModuleCallEntry(
                    "",
                    ModuleCallSignature(
                        inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec
                    ),
                )
            ],
            constants=subgraph_constants,
            verifiers=[owning_program.verifier],
        ),
        toplevel_input_specs_to_delete,
        toplevel_output_specs_to_delete,
    )
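

# A sketch of typical use during partitioning (names illustrative):
#
#   sub_gm, call_module_node = create_submodule_from_nodes(gm, nodes, "tag0")
#   sub_ep, input_specs_to_delete, output_specs_to_delete = (
#       create_exported_program_from_submodule(
#           sub_gm, owning_program, "tag0", call_module_node, is_submodule=False
#       )
#   )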
475 """ 476 old_signature = original_program.graph_signature 477 478 input_specs = [] 479 output_specs = [] 480 input_specs_to_delete = {} 481 output_specs_to_delete = {} 482 new_state_dict = {} 483 new_constants = {} 484 485 # If we are within a submodule, we do not need to care about consuming 486 # parameter/buffers 487 input_node_to_sig: Dict[str, InputSpec] = ( 488 {input_spec.arg.name: input_spec for input_spec in old_signature.input_specs} 489 if not is_submodule 490 else {} 491 ) 492 493 toplevel_output_node_to_sig: Dict[str, List[OutputSpec]] = defaultdict(list) 494 if not is_submodule: 495 for output_spec in old_signature.output_specs: 496 toplevel_output_node_to_sig[output_spec.arg.name].append(output_spec) 497 498 for node in gm.graph.nodes: 499 if node.op == "placeholder": 500 501 if node.name not in input_node_to_sig: 502 input_specs.append( 503 InputSpec( 504 kind=InputKind.USER_INPUT, 505 arg=TensorArgument(name=node.name), 506 target=None, 507 ) 508 ) 509 continue 510 511 orig_input_spec = input_node_to_sig[node.name] 512 513 if not isinstance(orig_input_spec.arg, TensorArgument): 514 input_specs.append(orig_input_spec) 515 516 elif node.meta.get("delegation_tag", None) == tag: 517 input_specs.append(orig_input_spec) 518 519 if orig_input_spec.kind == InputKind.USER_INPUT: 520 continue 521 522 # The following input specs are all attributes that should be 523 # consumed by the delegate, so we want to remove it from the 524 # toplevel module input/output 525 input_specs_to_delete[node.name] = orig_input_spec 526 527 input_target = orig_input_spec.target 528 if input_target in original_program.state_dict: 529 assert orig_input_spec.kind in ( 530 InputKind.PARAMETER, 531 InputKind.BUFFER, 532 ) 533 534 new_state_dict[input_target] = original_program.state_dict[ 535 input_target 536 ] 537 elif input_target in original_program.constants: 538 assert orig_input_spec.kind in ( 539 InputKind.CONSTANT_TENSOR, 540 InputKind.CUSTOM_OBJ, 541 InputKind.BUFFER, 542 ) 543 544 new_constants[input_target] = original_program.constants[ 545 input_target 546 ] 547 else: 548 raise RuntimeError(f"Invalid input spec {orig_input_spec} received") 549 550 else: 551 input_specs.append( 552 InputSpec( 553 kind=InputKind.USER_INPUT, 554 arg=TensorArgument(name=node.name), 555 target=None, 556 ) 557 ) 558 559 if node.op == "output": 560 buffer_mutation_idxs: Dict[int, List[OutputSpec]] = defaultdict(list) 561 for user in call_module_node.users.keys(): 562 if user.name in toplevel_output_node_to_sig: 563 assert ( 564 user.op == "call_function" and user.target == operator.getitem 565 ), f"Invalid user {user}, node.op is {user.op} and node.target is {user.target}" 566 getitem_idx = user.args[1] 567 assert isinstance( 568 getitem_idx, int 569 ), f"Invalid getitem type: {type(getitem_idx)}" 570 buffer_mutation_idxs[getitem_idx].extend( 571 toplevel_output_node_to_sig[user.name] 572 ) 573 574 for i, output_node in enumerate(node.args[0]): 575 if i in buffer_mutation_idxs: 576 assert isinstance(output_node, torch.fx.Node) 577 orig_output_specs = buffer_mutation_idxs[i] 578 579 if any( 580 orig_output_spec.kind == OutputKind.BUFFER_MUTATION 581 and orig_output_spec.target in new_state_dict 582 for orig_output_spec in orig_output_specs 583 ): 584 # If the delegate wants to consume the buffer, then the 585 # delegate should also consume the buffer mutation 586 # (output spec would be a BUFFER_MUTATION). Otherwise 587 # the delegate will just return the result of the 588 # mutation as a USER_OUTPUT. 


def get_lowered_submodules(
    graph_module: torch.fx.GraphModule,
) -> List[Tuple[str, LoweredBackendModule, torch.fx.Node]]:
    """
    Returns a list of lowered modules that are in the given graph (does not look
    into submodules). Specifically, the returned value is a list of tuples of
    (name of the lowered module that's stored in the graph module, the lowered
    module itself, and the fx node that called this lowered module).
    """
    lowered_submodules = []
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target == executorch_call_delegate:
            name, module, node = _get_submodule(graph_module, node, 0)
            assert isinstance(module, LoweredBackendModule)
            lowered_submodules.append((name, module, node))
    return lowered_submodules
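

# A minimal sketch of inspecting the delegates in a graph after to_backend:
#
#   for name, lowered, call_node in get_lowered_submodules(graph_module):
#       print(name, lowered.backend_id, len(lowered.processed_bytes))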


def get_lowered_backend_modules(
    graph_module: torch.fx.GraphModule,
) -> List[LoweredBackendModule]:
    """
    Returns a list of the LoweredBackendModules in the given graph, i.e. the
    programs that were lowered by backend delegates
    """
    lowered_programs = []
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target == executorch_call_delegate:
            lowered_backend_module = getattr(graph_module, node.args[0].name)
            lowered_programs.append(lowered_backend_module)

    return lowered_programs
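

# e.g. (a sketch):
#
#   backend_ids = [m.backend_id for m in get_lowered_backend_modules(graph_module)]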


def _unsafe_adjust_original_program(  # noqa: C901
    original_program: ExportedProgram,
    call_delegate_node: torch.fx.Node,
    input_specs_to_delete: Dict[str, InputSpec],
    output_specs_to_delete: Dict[str, OutputSpec],
) -> None:
    """
    Directly modify the original exported program's signature and state dict
    based on the params/buffers consumed by the delegate.
    """
    original_program._graph_signature.input_specs = [
        input_spec
        for input_spec in original_program.graph_signature.input_specs
        if input_spec.arg.name not in input_specs_to_delete
    ]

    currently_used_targets: Set[str] = {
        input_spec.target
        for input_spec in original_program._graph_signature.input_specs
        if input_spec.target is not None
    }

    original_program._graph_signature.output_specs = [
        output_spec
        for output_spec in original_program.graph_signature.output_specs
        if output_spec.arg.name not in output_specs_to_delete
    ]

    # Delete all parameters/buffers consumed by the created exported program
    # from the graph signature, state dict, and constants table
    for node in original_program.graph.nodes:
        if node.op == "placeholder":
            if node.name in input_specs_to_delete:
                assert len(node.users) == 0
                original_program.graph.erase_node(node)
        else:
            break

    for input_spec in input_specs_to_delete.values():
        input_target = input_spec.target
        assert input_target is not None

        if input_target in currently_used_targets:
            continue

        if input_spec.kind == InputKind.PARAMETER:
            del original_program._state_dict[input_target]
        elif input_spec.kind == InputKind.BUFFER:
            if input_spec.persistent:
                del original_program._state_dict[input_target]
            else:
                del original_program._constants[input_spec.target]
        elif input_spec.kind == InputKind.CONSTANT_TENSOR:
            del original_program._constants[input_spec.target]
        else:
            raise RuntimeError(f"Invalid input spec {input_spec} received")

    # Delete buffer mutations from the output which were consumed by the delegate
    toplevel_output_node = None
    for node in reversed(original_program.graph.nodes):
        if node.op == "output":
            toplevel_output_node = node
            break

    assert toplevel_output_node is not None
    assert (
        len(toplevel_output_node.args) == 1
    ), f"Invalid output node: {toplevel_output_node} with args {toplevel_output_node.args}"

    new_output_args = [
        arg
        for arg in toplevel_output_node.args[0]
        if not isinstance(arg, torch.fx.Node) or arg.name not in output_specs_to_delete
    ]
    toplevel_output_node.args = (tuple(new_output_args),)

    # Delete the buffer mutation getitem nodes
    getitem_idxs: List[int] = []
    user_nodes = list(call_delegate_node.users.keys())
    for user in user_nodes:
        if user.name in output_specs_to_delete:
            assert (
                user.op == "call_function" and user.target == operator.getitem
            ), f"Invalid user {user}, node.op is {user.op} and node.target is {user.target}"
            user_idx = user.args[1]
            assert isinstance(user_idx, int), f"Invalid getitem type: {type(user_idx)}"
            getitem_idxs.append(user_idx)
            original_program.graph.erase_node(user)

    getitem_idxs.sort(reverse=True)
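
    # For example (a sketch): if getitems at indices [1, 3] were deleted
    # (stored reverse-sorted as [3, 1]), a surviving getitem at index 4 first
    # matches idx=3 (i=0) and is shifted down by len(getitem_idxs) - 0 = 2,
    # becoming index 2; a surviving getitem at index 2 matches idx=1 (i=1) and
    # shifts down by 1, becoming index 1.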
853 """ 854 original_program._graph_signature.input_specs = [ 855 input_spec 856 for input_spec in original_program.graph_signature.input_specs 857 if input_spec.arg.name not in input_specs_to_delete 858 ] 859 860 currently_used_targets: Set[str] = { 861 input_spec.target 862 for input_spec in original_program._graph_signature.input_specs 863 if input_spec.target is not None 864 } 865 866 original_program._graph_signature.output_specs = [ 867 output_spec 868 for output_spec in original_program.graph_signature.output_specs 869 if output_spec.arg.name not in output_specs_to_delete 870 ] 871 872 # Delete all parameters/buffers consumed by the created exported program 873 # from the graph signature, state dict, constants table 874 for node in original_program.graph.nodes: 875 if node.op == "placeholder": 876 if node.name in input_specs_to_delete: 877 assert len(node.users) == 0 878 original_program.graph.erase_node(node) 879 else: 880 break 881 882 for input_spec in input_specs_to_delete.values(): 883 input_target = input_spec.target 884 assert input_target is not None 885 886 if input_target in currently_used_targets: 887 continue 888 889 if input_spec.kind == InputKind.PARAMETER: 890 del original_program._state_dict[input_target] 891 elif input_spec.kind == InputKind.BUFFER: 892 if input_spec.persistent: 893 del original_program._state_dict[input_target] 894 else: 895 del original_program._constants[input_spec.target] 896 elif input_spec.kind == InputKind.CONSTANT_TENSOR: 897 del original_program._constants[input_spec.target] 898 else: 899 raise RuntimeError(f"Invalid input spec {input_spec} received") 900 901 # Delete buffer mutations from the output which were consumed by the delegate 902 toplevel_output_node = None 903 for node in reversed(original_program.graph.nodes): 904 if node.op == "output": 905 toplevel_output_node = node 906 break 907 908 assert toplevel_output_node is not None 909 assert ( 910 len(toplevel_output_node.args) == 1 911 ), f"Invalid output node: {toplevel_output_node} with args {toplevel_output_node.args}" 912 913 new_output_args = [ 914 arg 915 for arg in toplevel_output_node.args[0] 916 if not isinstance(arg, torch.fx.Node) or arg.name not in output_specs_to_delete 917 ] 918 toplevel_output_node.args = (tuple(new_output_args),) 919 920 # Delete the buffer mutation getitem nodes 921 getitem_idxs: List[int] = [] 922 user_nodes = list(call_delegate_node.users.keys()) 923 for user in user_nodes: 924 if user.name in output_specs_to_delete: 925 assert ( 926 user.op == "call_function" and user.target == operator.getitem 927 ), f"Invalid user {user}, node.op is {node.op} and node.target is {node.target}" 928 user_idx = user.args[1] 929 assert isinstance(user_idx, int), f"Invalid getitem type: {type(user_idx)}" 930 getitem_idxs.append(user_idx) 931 original_program.graph.erase_node(user) 932 933 getitem_idxs.sort(reverse=True) 934 935 # Adjust all the getitem indices after the deleted getitems 936 user_nodes = list(call_delegate_node.users.keys()) 937 for user in user_nodes: 938 assert user.op == "call_function" and user.target == operator.getitem 939 user_idx = user.args[1] 940 assert isinstance(user_idx, int) 941 for i, idx in enumerate(getitem_idxs): 942 if user_idx > idx: 943 user.args = (user.args[0], user_idx - (len(getitem_idxs) - i)) 944 break 945 946 original_program._validate() 947