# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import logging
from contextlib import contextmanager, nullcontext
from functools import singledispatch
from typing import Generator, List

import torch

from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.compile_spec_schema import CompileSpec

from executorch.exir.backend.partitioner import Partitioner, PartitionResult
from executorch.exir.backend.utils import (
    _maybe_duplicate_constant_nodes,
    is_identical_graph,
)

from executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name

from executorch.exir.graph_module import get_control_flow_submodules
from executorch.exir.lowered_backend_module import (
    _unsafe_adjust_original_program,
    create_exported_program_from_submodule,
    create_submodule_from_nodes,
    LoweredBackendModule,
)
from executorch.exir.program._fake_program import (
    get_fake_program,
    update_to_real_program,
)
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
from torch.export import ExportedProgram


@singledispatch
def to_backend(args):
    """
    A generic function the dispatch happens on the type of the first argument.
    There are currently two overloaded to_backend functions:

    Note: Python is a dynamically-typed language and therefore cannot have proper
    method overloading, as that requires the language to be able to discriminate
    between types at compile-time. @to_backend.register will attach the function
    to to_backend() based on the type of the first argument (type annotation is
    required). However, it can't take multiple types as arguments.

    ::

     def to_backend(
         backend_id: str,
         edge_program: ExportedProgram,
         compile_specs: List[CompileSpec],
     ) -> LoweredBackendModule:

     def to_backend(
         edge_program: ExportedProgram,
         partitioner_instance: Partitioner,
     ) -> ExportedProgram:
    """
    pass


@to_backend.register
def _(
    backend_id: str,
    edge_program: ExportedProgram,
    compile_specs: List[CompileSpec],
) -> LoweredBackendModule:
    """
    Add overloaded implementations for to_backend:

    ::

     def to_backend(
         backend_id: str,
         edge_program: ExportedProgram,
         compile_specs: List[CompileSpec],
     ) -> LoweredBackendModule:


    Requires the passed in exported program in Edge dialect to be executed in
    the backend identified by backend_id. The forward method of the given
    edge_program will be targeted for execution.

    Args:
        backend_id: The backend identifier. It is matched against the class
            name of each direct subclass of BackendDetails.
        edge_program: An exported program in Edge dialect to target for
            lowering to the backend.
        compile_specs: A list of backend-specific objects with static
            metadata to configure the "compilation" process (e.g. it could be
            another dictionary itself).

    Returns:
        LoweredBackendModule: A Module that has been lowered to the target backend.
        Internally, the lowered Module contains these special attributes:
        backend_id (str: backend id), __processed_module__ (str: a compiled module)
        compile_spec, original_module (original exported program)

    Raises:
        NotImplementedError: The backend is not implemented (e.g. it was not found).
            This exception is derived from RuntimeError and should be caught accordingly.
        RuntimeError: The module cannot be processed by the backend.
    """
    assert isinstance(edge_program, ExportedProgram)

    # All backend implementation are final, so we don't need to consider nested subclasses.
    for cls in BackendDetails.__subclasses__():
        if backend_id != cls.__name__:
            continue

        # The backend's preprocess() receives a deep copy, while the lowered
        # module below is built from the program that was passed in.
        copied_edge_program = copy.deepcopy(edge_program)
        preprocess_result: PreprocessResult = cls.preprocess(
            copied_edge_program,
            compile_specs,
        )
        lowered_module = LoweredBackendModule(
            edge_program=edge_program,
            backend_id=backend_id,
            processed_bytes=preprocess_result.processed_bytes,
            compile_specs=compile_specs,
        )
        lowered_module.meta = {
            "debug_handle_map": preprocess_result.debug_handle_map
        }
        return lowered_module
    raise NotImplementedError(f"Backend {backend_id} was not found.")


# Module-wide switch read by the partitioner overload of to_backend() to decide
# whether to check that the partitioner left the graph unmodified.
_ENABLE_VALIDATION: bool = True


def disable_validation() -> None:
    """Disables validation"""
    global _ENABLE_VALIDATION
    _ENABLE_VALIDATION = False


@contextmanager
def validation_disabled() -> Generator[None, None, None]:
    """
    Disables checking functions (ex. if the partitioned graph is identical to
    the original graph). This context manager should only be used in certain
    scenarios (such as when it has been profiled that checks are taking too
    long, and are not necessarily needed)

    The setting that was active on entry is restored on exit, including when
    the body raises.
    """
    global _ENABLE_VALIDATION
    existing_setting = _ENABLE_VALIDATION
    disable_validation()
    try:
        yield
    finally:
        _ENABLE_VALIDATION = existing_setting


def _get_node_list_with_same_tag(
    tagged_graph_module: torch.fx.GraphModule,
    tag: str,
    owning_program: ExportedProgram,
) -> List[torch.fx.Node]:
    """
    Return a list of nodes with the same tag.

    Args:
        tagged_graph_module: Graph module whose nodes may carry a
            "delegation_tag" entry in their meta.
        tag: The delegation tag to collect nodes for.
        owning_program: Exported program used to classify placeholder nodes as
            params, buffers or lifted tensor constants.

    Returns:
        The nodes whose "delegation_tag" equals ``tag``, in graph order.

    Raises:
        RuntimeError: If an output node is tagged, if a tagged placeholder is
            not a param, buffer or lifted tensor constant, or if a tagged
            constant-data placeholder has a user carrying a different tag.
    """
    node_list = []

    for node in tagged_graph_module.graph.nodes:
        if node.meta.get("delegation_tag", "") != tag:
            continue

        if node.op == "output":
            raise RuntimeError(f"output node {node} should not be tagged")

        if node.op == "placeholder":
            is_constant_data = (
                is_param(owning_program, node)
                or is_buffer(owning_program, node)
                or is_lifted_tensor_constant(owning_program, node)
            )
            if not is_constant_data:
                raise RuntimeError(
                    f"placeholder node for non-params, non-buffer, and non-tensor constants should not be tagged: {node} "
                )
            # check that the users all belong to the same tag
            for user in node.users:
                users_tag = user.meta.get("delegation_tag", None)
                if users_tag != tag:
                    raise RuntimeError(
                        f"constant data node ({node}) is tagged with ({tag}) but has user ({user}) which has tag ({users_tag})"
                    )

        node_list.append(node)
    return node_list


def _find_output_node(graph_module: torch.fx.GraphModule) -> torch.fx.Node:
    """
    Return the first node of ``graph_module`` whose op is "output".

    Raises:
        RuntimeError: If the graph contains no output node.
    """
    output_node = next(
        (node for node in graph_module.graph.nodes if node.op == "output"), None
    )
    if output_node is None:
        raise RuntimeError("Expected the graph module to contain an output node")
    return output_node


def _generate_debug_handle(ep: ExportedProgram) -> int:
    """
    Generate a debug handle for the given ExportedProgram.

    The result is one more than the largest "debug_handle" found in the meta of
    the nodes of ``ep.graph_module.graph``; nodes without one count as 0.
    """
    debug_handle = 0
    for node in ep.graph_module.graph.nodes:
        debug_handle = max(debug_handle, node.meta.get("debug_handle", 0))
    return debug_handle + 1


def _partition_and_lower_one_graph_module(
    tagged_graph_module: torch.fx.GraphModule,
    partition_result: PartitionResult,
    owning_program: ExportedProgram,
    is_submodule: bool,
) -> torch.fx.GraphModule:
    """
    Partitioned and lowered the graph module based on the partition tag, this is to handle one graph module.

    Args:
        tagged_graph_module: The graph module to partition; it is modified in
            place.
        partition_result: Provides ``partition_tags``, mapping each tag to the
            delegation spec (backend id and compile specs) used to lower it.
        owning_program: The exported program the graph module belongs to.
        is_submodule: Whether ``tagged_graph_module`` is a control flow
            submodule rather than the top-level graph module of
            ``owning_program``.

    Returns:
        ``tagged_graph_module``, with each tagged partition replaced by an
        ``executorch_call_delegate`` call.
    """
    for tag, delegation_spec in partition_result.partition_tags.items():
        # Create partition with nodes containing this tag. There should only be
        # one contained submodule per tag
        node_list = _get_node_list_with_same_tag(
            tagged_graph_module, tag, owning_program
        )

        if not node_list:
            logging.debug("Did not find any nodes for tag %s", tag)
            continue

        logging.debug("For tag %s, found nodes %s", tag, node_list)
        # Tag the nodes that are params as buffers, so we can order the submodule as (Parms + Buffers) (User Inputs)

        # The graph signature's replace hook is only installed for the
        # top-level graph module; submodules run under a no-op context.
        replace_ctx = (
            tagged_graph_module._set_replace_hook(
                owning_program.graph_signature.get_replace_hook()
            )
            if not is_submodule
            else nullcontext()
        )
        with replace_ctx:
            submodule, call_module_node = create_submodule_from_nodes(
                tagged_graph_module, node_list, tag
            )

        tagged_graph_module_output_node = _find_output_node(tagged_graph_module)
        submodule_output_node = _find_output_node(submodule)
        # Copy the output node meta from the original output node, because
        # create_submodule_from_nodes doesn't cover the meta field
        # NOTE(review): this assigns the same dict object to both nodes rather
        # than a copy; left as-is to preserve existing behavior.
        submodule_output_node.meta = tagged_graph_module_output_node.meta
        # Lazy %-style arguments avoid stringifying the whole graph module when
        # debug logging is disabled.
        logging.debug("Partitioned graph module: %s", tagged_graph_module)

        (
            submodule_program,
            toplevel_input_specs_to_delete,
            toplevel_output_specs_to_delete,
        ) = create_exported_program_from_submodule(
            submodule,
            owning_program,
            tag,
            call_module_node,
            is_submodule,
        )

        lowered_submodule = to_backend(
            delegation_spec.backend_id,
            submodule_program,
            delegation_spec.compile_specs,
        )

        # call delegate args should only use user_inputs
        # Index the call_module inputs by name once; setdefault keeps the first
        # node seen for a given name.
        input_nodes_by_name = {}
        for inp_node in call_module_node.all_input_nodes:
            input_nodes_by_name.setdefault(inp_node.name, inp_node)
        # Preserve input order as user_inputs; names without a matching input
        # node are skipped.
        call_delegate_args = [
            input_nodes_by_name[inp_name]
            for inp_name in submodule_program.graph_signature.user_inputs
            if inp_name in input_nodes_by_name
        ]

        # Replace the partitioned submodule with a lowered submodule
        # Add call_method node with function "forward"
        with tagged_graph_module.graph.inserting_before(call_module_node):
            lowered_name = get_lowered_module_name(
                tagged_graph_module, lowered_submodule
            )
            lowered_node = tagged_graph_module.graph.get_attr(lowered_name)
            call_delegate_node = tagged_graph_module.graph.call_function(
                executorch_call_delegate,
                (lowered_node,) + tuple(call_delegate_args),
                call_module_node.kwargs,
            )
            # Computed per delegate call, after the node has been created.
            call_delegate_node.meta["debug_handle"] = _generate_debug_handle(
                owning_program
            )
            call_delegate_node.meta["val"] = submodule_output_node.meta["val"]
            call_module_node.replace_all_uses_with(call_delegate_node)
            tagged_graph_module.graph.erase_node(call_module_node)

        if is_submodule:
            assert len(toplevel_input_specs_to_delete) == 0
            assert len(toplevel_output_specs_to_delete) == 0
        elif toplevel_input_specs_to_delete or toplevel_output_specs_to_delete:
            _unsafe_adjust_original_program(
                owning_program,
                call_delegate_node,
                toplevel_input_specs_to_delete,
                toplevel_output_specs_to_delete,
            )

    return tagged_graph_module


def _partition_and_lower(
    tagged_graph_module: torch.fx.GraphModule,
    partition_result: PartitionResult,
    owning_program: ExportedProgram,
    is_submodule: bool = False,
) -> torch.fx.GraphModule:
    """
    Partitions the graph module into submodules based on tags, and then lowered the nodes with the same tag as one lowered module, including the submodule from control flow

    Args:
        tagged_graph_module: The graph module to partition; it is modified in
            place.
        partition_result: Provides the tag to delegation spec mapping.
        owning_program: The exported program the graph module belongs to.
        is_submodule: Whether ``tagged_graph_module`` is a control flow
            submodule; recursive calls always pass True.

    Returns:
        ``tagged_graph_module`` after it and its control flow submodules have
        been partitioned and lowered.
    """

    partitioned_module = _partition_and_lower_one_graph_module(
        tagged_graph_module, partition_result, owning_program, is_submodule
    )

    # Recursively partition and lower for submodules
    for name, submod, _node in get_control_flow_submodules(partitioned_module):
        partitioned_submodule = _partition_and_lower(
            submod, partition_result, owning_program, is_submodule=True
        )
        tagged_graph_module.add_module(name, partitioned_submodule)

    return tagged_graph_module


@to_backend.register
def _(
    edge_program: ExportedProgram,
    partitioner_instance: Partitioner,
) -> ExportedProgram:
    """
    Add overloaded implementations for to_backend:

    ::

     def to_backend(
         edge_program: ExportedProgram,
         partitioner: Partitioner,
     ) -> ExportedProgram:

    Returns a semantically-equivalent program to the one given as input (represented
    as a graph module in Edge dialect), but with portions of the program targeted for
    delegation as determined by the partitioner.

    Args:
        edge_program: Program in Edge dialect.

        partitioner_instance: An instance of the partitioner, in charge with tagging
            portions of the input program for delegation. A valid partitioner must return a PartitionResult
            including both the tagged exported program and partition_tags: Dict[str, DelegationSpec], where each key is a tag name and
            the nodes with the same tag will be fused as one subgraph and delegated to the backend specified in the delegation spec.


    Returns:
        ExportedProgram: The input program, with some portions targeted for delegation.
    """
    edge_program._validate()

    # Use fake program, with FakeTensors in the state dict, to avoid copying large constant values.
    # Fall back to deepcopy if no fake mode is found. TODO(T182910699): Remove this fallback.
    try:
        fake_edge_program = get_fake_program(edge_program)
    except Exception as e:
        logging.warning(
            "Error in get_fake_program for graph %s, fallback to deepcopy: %s",
            edge_program.graph_module,
            e,
        )
        fake_edge_program = copy.deepcopy(edge_program)
    partitioner_result = partitioner_instance(fake_edge_program)
    tagged_exported_program = partitioner_result.tagged_exported_program

    # Check that the partitioner did not modify the original graph
    if _ENABLE_VALIDATION:
        assert is_identical_graph(
            tagged_exported_program.graph_module,
            edge_program.graph_module,
        ), f"The partitioner {partitioner_instance} should not modify the graph module"
    else:
        logging.warning("Disabled validating the partitioner.")

    assert (
        partitioner_result.partition_tags is not None
    ), f"Partitioner {partitioner_instance} needs a `partition_tags` field containing a mapping of tags to delegate spec"

    update_to_real_program(tagged_exported_program, edge_program)

    # Only the tag names are needed here, so iterate the mapping's keys.
    for tag in partitioner_result.partition_tags:
        _maybe_duplicate_constant_nodes(tagged_exported_program, tag)

    tagged_graph_module = _partition_and_lower(
        tagged_exported_program.graph_module,
        partitioner_result,
        tagged_exported_program,
    )

    return ExportedProgram(
        root=tagged_graph_module,
        graph=tagged_graph_module.graph,
        graph_signature=tagged_exported_program.graph_signature,
        state_dict=tagged_exported_program.state_dict,
        range_constraints=copy.deepcopy(tagged_exported_program.range_constraints),
        module_call_graph=copy.deepcopy(tagged_exported_program.module_call_graph),
        example_inputs=None,
        constants=tagged_exported_program.constants,
        verifiers=[tagged_exported_program.verifier],
    )