# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch
import torch.fx
from executorch.exir.emit._emitter import (
    _DelegateDebugIdentifierMap,
    _EmitterState,
    _ProgramState,
    _TopLevelEmitter,
)
from executorch.exir.error import ExportError, ExportErrorType

from executorch.exir.schema import Buffer, Program, SubsegmentOffsets
from executorch.exir.version import EXECUTORCH_SCHEMA_VERSION
from torch.export.exported_program import ExportedProgram, OutputKind
from torch.utils import _pytree as pytree


@dataclass
class EmitterOutput:
    """
    The outputs of program emission. Contains the executorch program object as well as
    a mapping of instruction ids to debug handles.
    """

    # The ExecuTorch program
    program: Program

    # Maps each instruction id to its debug handle, or to a list of debug
    # handles in the case of delegate calls.
    debug_handle_map: Dict[int, Union[int, List[int]]]

    # Maps each method name to a dict that maps a delegate instruction id to
    # its corresponding delegate name and delegate debug identifier mapping.
    method_to_delegate_debug_id_map: Dict[
        str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]
    ]

    # Mutable buffers collected during emission, if any.
    mutable_data: Optional[List[Buffer]]


def _remove_non_user_outputs(exported_program: ExportedProgram) -> torch.fx.GraphModule:
    """Rewrite the graph's output node so that only user outputs are returned,
    dropping the values produced by buffer mutations."""
    gm = exported_program.graph_module

    # Find the graph's output node (a full scan; the last match wins).
    out_node = None
    for candidate in gm.graph.nodes:
        if candidate.op == "output":
            out_node = candidate
    assert out_node is not None

    # For each output spec, record the mutated buffer's target name, or None
    # when the output is an ordinary user output.
    mutation_targets: List[Optional[str]] = [
        spec.target if spec.kind == OutputKind.BUFFER_MUTATION else None
        for spec in exported_program.graph_signature.output_specs
    ]
    flat_returns = pytree.tree_flatten(out_node.args)[0]

    # Keep only the returned values that do not correspond to a mutation.
    kept_returns = [
        ret for ret, target in zip(flat_returns, mutation_targets) if target is None
    ]

    with gm.graph.inserting_before(out_node):
        # Only return user outputs.
        replacement = gm.graph.output(tuple(kept_returns))
        replacement.meta = out_node.meta.copy()
        out_node.replace_all_uses_with(replacement)
        gm.graph.erase_node(out_node)

    return gm


# For each entry point in the model, determine if it is a joint graph,
# and if it is, return a map of the indices in the model output where the
# gradient outputs start and where the parameter outputs start.
def _get_training_metadata(
    methods: Dict[str, ExportedProgram],
) -> Dict[str, Union[int, List[str]]]:
    """
    For each entry point in the model, determine if it is a joint graph, and
    if it is, record where in the model output the gradient outputs and the
    parameter outputs start, plus the fully qualified names of the parameters
    the gradients correspond to.

    Returns a dict with entries (per training method ``name``):
      "__et_training_gradients_index_<name>"  -> index of the first gradient output
      "__et_training_parameters_index_<name>" -> index of the first parameter output
      "__et_training_fqn_<name>"              -> list of parameter FQNs
    """
    gradients_method_prefix = "__et_training_gradients_index_"
    parameters_method_prefix = "__et_training_parameters_index_"
    fqn_method_prefix = "__et_training_fqn_"
    # Values are either int indices or lists of FQN strings, hence the Union
    # in the return annotation (the previous `Dict[str, int]` was incorrect
    # for the fqn entries).
    training_metadata: Dict[str, Union[int, List[str]]] = {}
    for name, method in methods.items():
        found_grad = False
        found_param = False
        fqns = []
        i = 0
        for output_spec in method.graph_signature.output_specs:
            if output_spec.kind == OutputKind.GRADIENT_TO_PARAMETER:
                if not found_grad:
                    # Index of the first gradient output for this method.
                    training_metadata[gradients_method_prefix + name] = i
                    found_grad = True
                fqns.append(output_spec.target)
            elif output_spec.kind == OutputKind.TOKEN and not found_param:
                assert found_grad  # Params must come after gradients
                training_metadata[parameters_method_prefix + name] = i
                found_param = True
            i += 1
        if len(fqns) > 0:
            training_metadata[fqn_method_prefix + name] = fqns
    return training_metadata


def emit_program(
    methods: Union[ExportedProgram, Dict[str, ExportedProgram]],
    emit_stacktrace: bool = False,
    prim_getters: Optional[Dict[str, Any]] = None,
) -> EmitterOutput:
    """
    Given an exported program, it returns the program in the format
    of the Python version of the flatbuffer Program schema.

    Args:
        methods: Either the exported program (ExportedProgram) that we want to
            emit into the flatbuffer, or a dictionary of method names to
            ExportedPrograms.
        emit_stacktrace: Flag to enable emission of a stacktrace for each
            instruction for debugging purposes.
        prim_getters: Optional mapping of primitive getter names to the
            constant values they return; emitted as extra execution plans.

    Return:
        The program in a Python class which mimics the flatbuffer schema.

    Raises:
        ExportError: If any value in ``methods`` is not an ExportedProgram.
    """

    if isinstance(methods, ExportedProgram):
        methods = {"forward": methods}

    # Validate that every entry point is actually an ExportedProgram.
    bad_methods = [
        name
        for name, exported_program in methods.items()
        if not isinstance(exported_program, ExportedProgram)
    ]
    if bad_methods:
        raise ExportError(
            ExportErrorType.INVALID_INPUT_TYPE,
            f"Did not receive ExportedProgram for the following methods {str(bad_methods)}",
        )

    plans = []
    debug_handle_map = {}
    method_to_delegate_debug_id_map = {}
    # Program-level state is shared across methods so constant buffers and
    # delegate data are deduplicated across entry points.
    program_state = _ProgramState()

    # Emit each entry point in order according to name.
    for name, exported_program in sorted(methods.items()):
        # Create empty per-method emitter state.
        emitter_state = _EmitterState(
            values=[],
            operators=[],
            delegates=[],
            operator_cache={},
            delegate_cache={},
            emit_stacktrace=emit_stacktrace,
        )

        gm = _remove_non_user_outputs(exported_program)

        emitter = _TopLevelEmitter(
            name, exported_program, gm, program_state, emitter_state
        )

        emitter.run()
        plans.append(emitter.plan())

        debug_handle_map[name] = emitter.debug_handle_map
        method_to_delegate_debug_id_map[name] = (
            emitter.instr_id_to_delegate_debug_id_map
        )

    # Emit training metadata (gradient/parameter indices and FQNs) as prim
    # getters, reusing the emitter from the last method emitted above.
    training_metadata = _get_training_metadata(methods)
    if training_metadata:
        plans.extend(emitter._emit_prim_getters(training_metadata))

    # Emit any caller-supplied primitive getters.
    if prim_getters is not None:
        plans.extend(emitter._emit_prim_getters(prim_getters))

    return EmitterOutput(
        debug_handle_map=debug_handle_map,
        method_to_delegate_debug_id_map=method_to_delegate_debug_id_map,
        program=Program(
            version=EXECUTORCH_SCHEMA_VERSION,
            execution_plan=plans,
            constant_buffer=program_state.constant_buffer,
            backend_delegate_data=program_state.backend_delegate_data,
            # Segments may be added at serialization time.
            segments=[],
            # Subsegment offsets may be added at serialization time.
            constant_segment=SubsegmentOffsets(segment_index=0, offsets=[]),
            mutable_data_segments=None,  # Will be filled in during serialization
        ),
        mutable_data=(
            program_state.mutable_buffer
            # NOTE(review): `> 1` suggests index 0 of mutable_buffer is a
            # placeholder entry — confirm against _ProgramState before changing.
            if len(program_state.mutable_buffer) > 1
            else None
        ),
    )