# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import logging
from contextlib import contextmanager, nullcontext
from functools import singledispatch
from typing import Generator, List

import torch

from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.compile_spec_schema import CompileSpec

from executorch.exir.backend.partitioner import Partitioner, PartitionResult
from executorch.exir.backend.utils import (
    _maybe_duplicate_constant_nodes,
    is_identical_graph,
)

from executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name

from executorch.exir.graph_module import get_control_flow_submodules
from executorch.exir.lowered_backend_module import (
    _unsafe_adjust_original_program,
    create_exported_program_from_submodule,
    create_submodule_from_nodes,
    LoweredBackendModule,
)
from executorch.exir.program._fake_program import (
    get_fake_program,
    update_to_real_program,
)
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
from torch.export import ExportedProgram


@singledispatch
def to_backend(args, *other_args, **kwargs):
    """
    A generic function whose dispatch happens on the type of the first argument.
    There are currently two overloaded to_backend functions:

    Note: Python is a dynamically-typed language and therefore cannot have proper
    method overloading, as that requires the language to be able to discriminate
    between types at compile-time. @to_backend.register attaches a function to
    to_backend() based on the type of its first argument (a type annotation is
    required). However, it can't take multiple types as arguments.

    ::

        def to_backend(
            backend_id: str,
            edge_program: ExportedProgram,
            compile_specs: List[CompileSpec],
        ) -> LoweredBackendModule:

        def to_backend(
            edge_program: ExportedProgram,
            partitioner_instance: Partitioner,
        ) -> ExportedProgram:

    Raises:
        TypeError: no overload is registered for the type of the first argument.
    """
    # Fallback for unregistered first-argument types. It accepts any arity so
    # callers get a descriptive error instead of an arity mismatch (multi-arg
    # calls) or a silently returned ``None`` (single-arg calls).
    raise TypeError(
        f"to_backend() has no overload registered for a first argument of type "
        f"{type(args).__name__}"
    )


@to_backend.register
def _(
    backend_id: str,
    edge_program: ExportedProgram,
    compile_specs: List[CompileSpec],
) -> LoweredBackendModule:
    """
    Add overloaded implementations for to_backend:

    ::

        def to_backend(
            backend_id: str,
            edge_program: ExportedProgram,
            compile_specs: List[CompileSpec],
        ) -> LoweredBackendModule:


    Lowers the whole given exported program (in Edge dialect) to the backend
    identified by backend_id.

    Args:
        backend_id: The backend identifier; must equal the class name of a
            direct subclass of BackendDetails.
        edge_program: An exported program in Edge dialect to target for
            lowering to the backend.
        compile_specs: A list of backend-specific objects with static
            metadata to configure the "compilation" process (e.g. it could be
            another dictionary itself).

    Returns:
        LoweredBackendModule: A Module that has been lowered to the target backend.
        It is built from the original (un-copied) edge_program, backend_id, the
        processed_bytes produced by the backend's preprocess, and compile_specs.
        Its ``meta["debug_handle_map"]`` holds the debug handle map returned by
        preprocess.

    Raises:
        NotImplementedError: No direct subclass of BackendDetails is named
            backend_id. This exception is derived from RuntimeError and should
            be caught accordingly.
        RuntimeError: The module cannot be processed by the backend (any
            exception raised by the backend's preprocess propagates unchanged).
    """
    assert isinstance(edge_program, ExportedProgram)

    # All backend implementation are final, so we don't need to consider nested subclasses.
    for cls in BackendDetails.__subclasses__():
        if backend_id == cls.__name__:
            # preprocess receives a deep copy, so any mutation it makes does not
            # reach edge_program, which is what the lowered module stores.
            copied_edge_program = copy.deepcopy(edge_program)
            preprocess_result: PreprocessResult = cls.preprocess(
                copied_edge_program,
                compile_specs,
            )
            lowered_module = LoweredBackendModule(
                edge_program=edge_program,
                backend_id=backend_id,
                processed_bytes=preprocess_result.processed_bytes,
                compile_specs=compile_specs,
            )
            lowered_module.meta = {
                "debug_handle_map": preprocess_result.debug_handle_map
            }
            return lowered_module
    raise NotImplementedError(f"Backend {backend_id} was not found.")


# Module-level switch read by the partitioner overload of to_backend: when True,
# it asserts that the partitioner did not modify the graph module. Toggled by
# disable_validation() and validation_disabled().
_ENABLE_VALIDATION: bool = True
def disable_validation() -> None:
    """
    Disable partitioner validation by setting the module-level
    ``_ENABLE_VALIDATION`` flag to False. It stays disabled until something
    sets the flag back (e.g. on exit from ``validation_disabled``).
    """
    global _ENABLE_VALIDATION
    _ENABLE_VALIDATION = False


@contextmanager
def validation_disabled() -> Generator[None, None, None]:
    """
    Disables checking functions (ex. if the partitioned graph is identical to
    the original graph) for the duration of the ``with`` block. This context
    manager should only be used in certain scenarios (such as when it has been
    profiled that checks are taking too long, and are not necessarily needed).

    The previous setting is restored on exit, even if the body raises.
    """
    global _ENABLE_VALIDATION
    existing_setting = _ENABLE_VALIDATION
    disable_validation()
    try:
        yield
    finally:
        _ENABLE_VALIDATION = existing_setting


def _get_node_list_with_same_tag(
    tagged_graph_module: torch.fx.GraphModule,
    tag: str,
    owning_program: ExportedProgram,
) -> List[torch.fx.Node]:
    """
    Return, in graph order, the nodes of ``tagged_graph_module`` whose
    ``meta["delegation_tag"]`` equals ``tag``.

    Raises:
        RuntimeError: if the output node is tagged; if a tagged placeholder is
            not a parameter, buffer, or lifted tensor constant of
            ``owning_program``; or if a tagged constant placeholder has a user
            whose tag differs from ``tag``.
    """
    node_list = []

    for node in tagged_graph_module.graph.nodes:
        if node.meta.get("delegation_tag", "") != tag:
            continue
        if node.op == "output":
            raise RuntimeError(f"output node {node} should not be tagged")
        if node.op == "placeholder":
            if (
                not is_param(owning_program, node)
                and not is_buffer(owning_program, node)
                and not is_lifted_tensor_constant(owning_program, node)
            ):
                raise RuntimeError(
                    f"placeholder node for non-params, non-buffer, and non-tensor constants should not be tagged: {node} "
                )
            # A tagged constant may only be consumed by nodes of the same partition.
            for user in node.users:
                users_tag = user.meta.get("delegation_tag", None)
                if users_tag != tag:
                    raise RuntimeError(
                        f"constant data node ({node}) is tagged with ({tag}) but has user ({user}) which has tag ({users_tag})"
                    )
        node_list.append(node)
    return node_list


def _generate_debug_handle(ep: ExportedProgram) -> int:
    """
    Return one more than the largest ``debug_handle`` found in the metadata of
    the nodes of ``ep``'s top-level graph (nodes without one count as 0).
    """
    debug_handle = 0
    for node in ep.graph_module.graph.nodes:
        debug_handle = max(debug_handle, node.meta.get("debug_handle", 0))
    return debug_handle + 1


def _partition_and_lower_one_graph_module(
    tagged_graph_module: torch.fx.GraphModule,
    partition_result: PartitionResult,
    owning_program: ExportedProgram,
    is_submodule: bool,
) -> torch.fx.GraphModule:
    """
    Partition and lower a single graph module based on its delegation tags.

    For every tag in ``partition_result.partition_tags``, the nodes carrying
    that tag are fused into one submodule. That submodule is turned into its own
    exported program and lowered with ``to_backend`` to the backend named by the
    tag's delegation spec. In ``tagged_graph_module`` it is then replaced by a
    ``get_attr`` of the lowered module plus an ``executorch_call_delegate`` call.
    Control-flow submodules are not visited here; ``_partition_and_lower`` does
    the recursion.

    Args:
        tagged_graph_module: Graph module whose nodes carry
            ``meta["delegation_tag"]``. It is modified in place.
        partition_result: Partitioner output holding ``partition_tags``.
        owning_program: The exported program that owns the top-level graph.
        is_submodule: True when ``tagged_graph_module`` is a control-flow
            submodule rather than the top-level graph.

    Returns:
        ``tagged_graph_module`` itself, after modification.
    """
    for tag, delegation_spec in partition_result.partition_tags.items():
        # Create partition with nodes containing this tag. There should only be
        # one contained submodule per tag
        node_list = _get_node_list_with_same_tag(
            tagged_graph_module, tag, owning_program
        )

        if len(node_list) == 0:
            logging.debug("Did not find any nodes for tag %s", tag)
            continue

        logging.debug("For tag %s, found nodes %s", tag, node_list)

        # Only the top-level graph installs the graph signature's replace hook
        # while fusing. NOTE(review): presumably this keeps owning_program's
        # signature in sync with node replacements; confirm.
        replace_ctx = (
            tagged_graph_module._set_replace_hook(
                owning_program.graph_signature.get_replace_hook()
            )
            if not is_submodule
            else nullcontext()
        )
        with replace_ctx:
            submodule, call_module_node = create_submodule_from_nodes(
                tagged_graph_module, node_list, tag
            )

        tagged_graph_module_output_node = next(
            node for node in tagged_graph_module.graph.nodes if node.op == "output"
        )
        submodule_output_node = next(
            node for node in submodule.graph.nodes if node.op == "output"
        )
        # Copy the output node meta from the original output node, because
        # create_submodule_from_nodes doesn't cover the meta field.
        # NOTE(review): this shares the same dict object between both output
        # nodes (no copy); confirm that aliasing is intended.
        submodule_output_node.meta = tagged_graph_module_output_node.meta
        logging.debug("Partitioned graph module: %s", tagged_graph_module)

        (
            submodule_program,
            toplevel_input_specs_to_delete,
            toplevel_output_specs_to_delete,
        ) = create_exported_program_from_submodule(
            submodule,
            owning_program,
            tag,
            call_module_node,
            is_submodule,
        )

        lowered_submodule = to_backend(
            delegation_spec.backend_id,
            submodule_program,
            delegation_spec.compile_specs,
        )

        # Call delegate args should only use user_inputs, in user_inputs order.
        # Keep the first node seen per name (names are unique within a graph).
        input_nodes_by_name = {}
        for inp_node in call_module_node.all_input_nodes:
            input_nodes_by_name.setdefault(inp_node.name, inp_node)
        call_delegate_args = [
            input_nodes_by_name[inp_name]
            for inp_name in submodule_program.graph_signature.user_inputs
            if inp_name in input_nodes_by_name
        ]

        # Replace the call_module node with get_attr(lowered module) followed by
        # a call_function to executorch_call_delegate.
        with tagged_graph_module.graph.inserting_before(call_module_node):
            lowered_name = get_lowered_module_name(
                tagged_graph_module, lowered_submodule
            )
            lowered_node = tagged_graph_module.graph.get_attr(lowered_name)
            call_delegate_node = tagged_graph_module.graph.call_function(
                executorch_call_delegate,
                (lowered_node,) + tuple(call_delegate_args),
                call_module_node.kwargs,
            )
            call_delegate_node.meta["debug_handle"] = _generate_debug_handle(
                owning_program
            )
            call_delegate_node.meta["val"] = submodule_output_node.meta["val"]
            call_module_node.replace_all_uses_with(call_delegate_node)
            tagged_graph_module.graph.erase_node(call_module_node)

        if is_submodule:
            assert len(toplevel_input_specs_to_delete) == 0
            assert len(toplevel_output_specs_to_delete) == 0
        elif (
            len(toplevel_input_specs_to_delete) > 0
            or len(toplevel_output_specs_to_delete) > 0
        ):
            _unsafe_adjust_original_program(
                owning_program,
                call_delegate_node,
                toplevel_input_specs_to_delete,
                toplevel_output_specs_to_delete,
            )

    return tagged_graph_module


def _partition_and_lower(
    tagged_graph_module: torch.fx.GraphModule,
    partition_result: PartitionResult,
    owning_program: ExportedProgram,
    is_submodule: bool = False,
) -> torch.fx.GraphModule:
    """
    Partition the graph module into submodules based on tags, and lower the
    nodes with the same tag as one lowered module. Then recurse into every
    control-flow submodule (as reported by ``get_control_flow_submodules``) and
    re-register the processed submodule on ``tagged_graph_module`` under the
    same name.

    Returns ``tagged_graph_module``, modified in place.
    """

    partitioned_module = _partition_and_lower_one_graph_module(
        tagged_graph_module, partition_result, owning_program, is_submodule
    )

    # Recursively partition and lower for submodules
    for name, submod, _node in get_control_flow_submodules(partitioned_module):
        partitioned_submodule = _partition_and_lower(
            submod, partition_result, owning_program, is_submodule=True
        )
        tagged_graph_module.add_module(name, partitioned_submodule)

    return tagged_graph_module


@to_backend.register
def _(
    edge_program: ExportedProgram,
    partitioner_instance: Partitioner,
) -> ExportedProgram:
    """
    Add overloaded implementations for to_backend:

    ::

        def to_backend(
            edge_program: ExportedProgram,
            partitioner_instance: Partitioner,
        ) -> ExportedProgram:

    Returns a semantically-equivalent program to the one given as input (represented
    as a graph module in Edge dialect), but with portions of the program targeted for
    delegation as determined by the partitioner.

    The partitioner runs on a copy of edge_program: a fake program (FakeTensors in
    the state dict) when one can be built, otherwise a deep copy. When validation is
    enabled (the module-level ``_ENABLE_VALIDATION`` flag), the tagged graph module
    is asserted to be identical to the original.

    Args:
        edge_program: Program in Edge dialect.

        partitioner_instance: An instance of the partitioner, in charge of tagging
            portions of the input program for delegation. A valid partitioner must
            return a PartitionResult including both the tagged exported program and
            partition_tags: Dict[str, DelegationSpec], where each key is a tag name
            and the nodes with the same tag will be fused as one subgraph and
            delegated to the backend specified in the delegation spec.


    Returns:
        ExportedProgram: The input program, with some portions targeted for delegation.
    """
    edge_program._validate()

    # Use fake program, with FakeTensors in the state dict, to avoid copying large constant values.
    # Fall back to deepcopy if no fake mode is found. TODO(T182910699): Remove this fallback.
    try:
        fake_edge_program = get_fake_program(edge_program)
    except Exception as e:
        logging.warning(
            f"Error in get_fake_program for graph {edge_program.graph_module}, fallback to deepcopy: {e}"
        )
        fake_edge_program = copy.deepcopy(edge_program)
    partitioner_result = partitioner_instance(fake_edge_program)
    tagged_exported_program = partitioner_result.tagged_exported_program

    # Check that the partitioner did not modify the original graph
    if _ENABLE_VALIDATION:
        assert is_identical_graph(
            tagged_exported_program.graph_module,
            edge_program.graph_module,
        ), f"The partitioner {partitioner_instance} should not modify the graph module"
    else:
        logging.warning("Disabled validating the partitioner.")

    assert (
        partitioner_result.partition_tags is not None
    ), f"Partitioner {partitioner_instance} needs a `partition_tags` field containing a mapping of tags to delegate spec"

    update_to_real_program(tagged_exported_program, edge_program)

    for tag in partitioner_result.partition_tags:
        _maybe_duplicate_constant_nodes(tagged_exported_program, tag)

    tagged_graph_module = _partition_and_lower(
        tagged_exported_program.graph_module,
        partitioner_result,
        tagged_exported_program,
    )

    return ExportedProgram(
        root=tagged_graph_module,
        graph=tagged_graph_module.graph,
        graph_signature=tagged_exported_program.graph_signature,
        state_dict=tagged_exported_program.state_dict,
        range_constraints=copy.deepcopy(tagged_exported_program.range_constraints),
        module_call_graph=copy.deepcopy(tagged_exported_program.module_call_graph),
        example_inputs=None,
        constants=tagged_exported_program.constants,
        verifiers=[tagged_exported_program.verifier],
    )