1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates. 2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved. 3*523fa7a6SAndroid Build Coastguard Worker# 4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the 5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree. 6*523fa7a6SAndroid Build Coastguard Worker 7*523fa7a6SAndroid Build Coastguard Worker# pyre-unsafe 8*523fa7a6SAndroid Build Coastguard Worker 9*523fa7a6SAndroid Build Coastguard Workerimport copy 10*523fa7a6SAndroid Build Coastguard Workerimport io 11*523fa7a6SAndroid Build Coastguard Workerimport logging 12*523fa7a6SAndroid Build Coastguard Workerfrom typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Union 13*523fa7a6SAndroid Build Coastguard Worker 14*523fa7a6SAndroid Build Coastguard Workerimport torch 15*523fa7a6SAndroid Build Coastguard Workerimport torch._export 16*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir._serialize import _serialize_pte_binary 17*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir._serialize._cord import Cord 18*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir._warnings import experimental 19*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.backend_api import to_backend 20*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.partitioner import Partitioner 21*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig 22*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.emit import emit_program, EmitterOutput 23*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.emit._emitter import _DelegateDebugIdentifierMap 24*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.error import ExportError 25*523fa7a6SAndroid Build Coastguard Workerfrom 
executorch.exir.graph_module import get_control_flow_submodules 26*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.pass_base import PassBase 27*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.pass_manager import PassType 28*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes import ( 29*523fa7a6SAndroid Build Coastguard Worker base_post_op_replace_passes, 30*523fa7a6SAndroid Build Coastguard Worker base_pre_op_replace_passes, 31*523fa7a6SAndroid Build Coastguard Worker dead_code_elimination_pass, 32*523fa7a6SAndroid Build Coastguard Worker EdgeToBackendOpsPass, 33*523fa7a6SAndroid Build Coastguard Worker MemoryFormatOpsPass, 34*523fa7a6SAndroid Build Coastguard Worker OpReplacePass, 35*523fa7a6SAndroid Build Coastguard Worker) 36*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.insert_write_back_for_buffers_pass import ( 37*523fa7a6SAndroid Build Coastguard Worker insert_write_back_for_buffers_pass, 38*523fa7a6SAndroid Build Coastguard Worker) 39*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.normalize_view_copy_base_pass import ( 40*523fa7a6SAndroid Build Coastguard Worker NormalizeViewCopyBasePass, 41*523fa7a6SAndroid Build Coastguard Worker) 42*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass 43*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators 44*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.replace_aten_with_edge_pass import aten_to_edge 45*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.replace_view_copy_with_view_pass import ( 46*523fa7a6SAndroid Build Coastguard Worker ReplaceViewCopyWithViewPass, 47*523fa7a6SAndroid Build Coastguard Worker) 48*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.passes.spec_prop_pass import SpecPropPass 49*523fa7a6SAndroid Build Coastguard 
Workerfrom executorch.exir.passes.weights_to_outputs_pass import weights_to_outputs_pass 50*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.print_program import pretty_print, print_program 51*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.schema import Program 52*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.tracer import _default_decomposition_table 53*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.verification.verifier import ( 54*523fa7a6SAndroid Build Coastguard Worker EXIRATenDialectVerifier, 55*523fa7a6SAndroid Build Coastguard Worker EXIREdgeDialectVerifier, 56*523fa7a6SAndroid Build Coastguard Worker get_aten_verifier, 57*523fa7a6SAndroid Build Coastguard Worker) 58*523fa7a6SAndroid Build Coastguard Workerfrom torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass 59*523fa7a6SAndroid Build Coastguard Workerfrom torch.export import ExportedProgram 60*523fa7a6SAndroid Build Coastguard Workerfrom torch.export._remove_auto_functionalized_pass import ( 61*523fa7a6SAndroid Build Coastguard Worker unsafe_remove_auto_functionalized_pass, 62*523fa7a6SAndroid Build Coastguard Worker) 63*523fa7a6SAndroid Build Coastguard Workerfrom torch.export.exported_program import ( 64*523fa7a6SAndroid Build Coastguard Worker ConstantArgument, 65*523fa7a6SAndroid Build Coastguard Worker ExportGraphSignature, 66*523fa7a6SAndroid Build Coastguard Worker InputKind, 67*523fa7a6SAndroid Build Coastguard Worker InputSpec, 68*523fa7a6SAndroid Build Coastguard Worker OutputSpec, 69*523fa7a6SAndroid Build Coastguard Worker TensorArgument, 70*523fa7a6SAndroid Build Coastguard Worker) 71*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx import _pytree as fx_pytree 72*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx._compatibility import compatibility 73*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.passes.infra.pass_manager import PassManager 74*523fa7a6SAndroid Build Coastguard Workerfrom torch.utils import 
_pytree as pytree 75*523fa7a6SAndroid Build Coastguard Worker 76*523fa7a6SAndroid Build Coastguard WorkerVal = Any 77*523fa7a6SAndroid Build Coastguard Worker 78*523fa7a6SAndroid Build Coastguard Workerfrom torch.library import Library 79*523fa7a6SAndroid Build Coastguard Worker 80*523fa7a6SAndroid Build Coastguard Worker# This is the reserved namespace that is used to register ops to that will 81*523fa7a6SAndroid Build Coastguard Worker# be prevented from being decomposed during to_edge_transform_and_lower. 82*523fa7a6SAndroid Build Coastguard Workeredge_no_decomp_namespace = "EDGE_DO_NOT_DECOMP" 83*523fa7a6SAndroid Build Coastguard Workerlib = Library(edge_no_decomp_namespace, "DEF") 84*523fa7a6SAndroid Build Coastguard Worker# Map from aten ops to the transformed ops registered in the edge_no_decomp_namespace. 85*523fa7a6SAndroid Build Coastguard Workeraten_op_to_transform_op = {} 86*523fa7a6SAndroid Build Coastguard Worker# Map from the transformed ops registered in the edge_no_decomp_namespace to aten ops. 
87*523fa7a6SAndroid Build Coastguard Workertransform_op_to_aten_op = {} 88*523fa7a6SAndroid Build Coastguard Worker 89*523fa7a6SAndroid Build Coastguard Worker 90*523fa7a6SAndroid Build Coastguard Workerdef _get_updated_range_constraints(gm): 91*523fa7a6SAndroid Build Coastguard Worker def get_shape_env(gm): 92*523fa7a6SAndroid Build Coastguard Worker vals = [ 93*523fa7a6SAndroid Build Coastguard Worker node.meta["val"] 94*523fa7a6SAndroid Build Coastguard Worker for node in gm.graph.nodes 95*523fa7a6SAndroid Build Coastguard Worker if node.meta.get("val", None) is not None 96*523fa7a6SAndroid Build Coastguard Worker ] 97*523fa7a6SAndroid Build Coastguard Worker from torch._guards import detect_fake_mode # type: ignore[21] 98*523fa7a6SAndroid Build Coastguard Worker 99*523fa7a6SAndroid Build Coastguard Worker fake_mode = detect_fake_mode(vals) 100*523fa7a6SAndroid Build Coastguard Worker if fake_mode is not None: 101*523fa7a6SAndroid Build Coastguard Worker return fake_mode.shape_env 102*523fa7a6SAndroid Build Coastguard Worker for v in vals: 103*523fa7a6SAndroid Build Coastguard Worker if isinstance(v, torch.SymInt): 104*523fa7a6SAndroid Build Coastguard Worker return v.node.shape_env 105*523fa7a6SAndroid Build Coastguard Worker 106*523fa7a6SAndroid Build Coastguard Worker shape_env = get_shape_env(gm) 107*523fa7a6SAndroid Build Coastguard Worker if shape_env is None: 108*523fa7a6SAndroid Build Coastguard Worker return {} 109*523fa7a6SAndroid Build Coastguard Worker range_constraints = { 110*523fa7a6SAndroid Build Coastguard Worker k: v 111*523fa7a6SAndroid Build Coastguard Worker for k, v in shape_env.var_to_range.items() 112*523fa7a6SAndroid Build Coastguard Worker if k not in shape_env.replacements 113*523fa7a6SAndroid Build Coastguard Worker } 114*523fa7a6SAndroid Build Coastguard Worker # Only when we have an unbacked symint, and it's used as constructor inputs, 115*523fa7a6SAndroid Build Coastguard Worker # runtime_var_to_range will make a difference 
compated to var_to_range. 116*523fa7a6SAndroid Build Coastguard Worker # e.g. [2, oo) -> [0, oo) 117*523fa7a6SAndroid Build Coastguard Worker for k, v in shape_env.var_to_range.items(): 118*523fa7a6SAndroid Build Coastguard Worker if k not in shape_env.replacements: 119*523fa7a6SAndroid Build Coastguard Worker range_constraints[k] = v 120*523fa7a6SAndroid Build Coastguard Worker return range_constraints 121*523fa7a6SAndroid Build Coastguard Worker 122*523fa7a6SAndroid Build Coastguard Worker 123*523fa7a6SAndroid Build Coastguard Workerdef _get_updated_graph_signature( 124*523fa7a6SAndroid Build Coastguard Worker old_signature: ExportGraphSignature, 125*523fa7a6SAndroid Build Coastguard Worker new_gm: torch.fx.GraphModule, 126*523fa7a6SAndroid Build Coastguard Worker) -> ExportGraphSignature: 127*523fa7a6SAndroid Build Coastguard Worker """ 128*523fa7a6SAndroid Build Coastguard Worker Update the graph signature's user_input/user_outputs. 129*523fa7a6SAndroid Build Coastguard Worker """ 130*523fa7a6SAndroid Build Coastguard Worker new_input_specs = [] 131*523fa7a6SAndroid Build Coastguard Worker i = 0 132*523fa7a6SAndroid Build Coastguard Worker for node in new_gm.graph.nodes: 133*523fa7a6SAndroid Build Coastguard Worker if node.op != "placeholder": 134*523fa7a6SAndroid Build Coastguard Worker continue 135*523fa7a6SAndroid Build Coastguard Worker 136*523fa7a6SAndroid Build Coastguard Worker assert i < len( 137*523fa7a6SAndroid Build Coastguard Worker old_signature.input_specs 138*523fa7a6SAndroid Build Coastguard Worker ), "Number of inputs changed after transformation" 139*523fa7a6SAndroid Build Coastguard Worker old_input_spec = old_signature.input_specs[i] 140*523fa7a6SAndroid Build Coastguard Worker arg = ( 141*523fa7a6SAndroid Build Coastguard Worker old_input_spec.arg 142*523fa7a6SAndroid Build Coastguard Worker if isinstance(old_input_spec.arg, ConstantArgument) 143*523fa7a6SAndroid Build Coastguard Worker # pyre-fixme[20]: Argument `class_fqn` expected. 
144*523fa7a6SAndroid Build Coastguard Worker else type(old_input_spec.arg)(node.name) 145*523fa7a6SAndroid Build Coastguard Worker ) 146*523fa7a6SAndroid Build Coastguard Worker new_input_specs.append( 147*523fa7a6SAndroid Build Coastguard Worker InputSpec( 148*523fa7a6SAndroid Build Coastguard Worker old_input_spec.kind, 149*523fa7a6SAndroid Build Coastguard Worker arg, 150*523fa7a6SAndroid Build Coastguard Worker old_input_spec.target, 151*523fa7a6SAndroid Build Coastguard Worker persistent=old_input_spec.persistent, 152*523fa7a6SAndroid Build Coastguard Worker ) 153*523fa7a6SAndroid Build Coastguard Worker ) 154*523fa7a6SAndroid Build Coastguard Worker i += 1 155*523fa7a6SAndroid Build Coastguard Worker 156*523fa7a6SAndroid Build Coastguard Worker output_node = list(new_gm.graph.nodes)[-1] 157*523fa7a6SAndroid Build Coastguard Worker assert output_node.op == "output" 158*523fa7a6SAndroid Build Coastguard Worker 159*523fa7a6SAndroid Build Coastguard Worker new_output_specs = [] 160*523fa7a6SAndroid Build Coastguard Worker for i, node in enumerate(output_node.args[0]): 161*523fa7a6SAndroid Build Coastguard Worker assert i < len( 162*523fa7a6SAndroid Build Coastguard Worker old_signature.output_specs 163*523fa7a6SAndroid Build Coastguard Worker ), "Number of outputs changed after transformation" 164*523fa7a6SAndroid Build Coastguard Worker old_output_spec = old_signature.output_specs[i] 165*523fa7a6SAndroid Build Coastguard Worker arg = ( 166*523fa7a6SAndroid Build Coastguard Worker old_output_spec.arg 167*523fa7a6SAndroid Build Coastguard Worker if isinstance(old_output_spec.arg, ConstantArgument) 168*523fa7a6SAndroid Build Coastguard Worker # pyre-fixme[20]: Argument `class_fqn` expected. 
def _transform(self, *passes: PassType) -> "ExportedProgram":
    """
    Run ``passes`` over an exported program and rebuild it around the result.

    Args:
        self: The ``ExportedProgram`` to transform (this is a free function
            that takes the program as its first argument).
        *passes: Passes handed to a ``PassManager`` and applied to
            ``self.graph_module``.

    Returns:
        ``self`` when the pass manager reports no modification and hands back
        the same graph module; otherwise a new ``ExportedProgram`` built from
        the transformed graph module, with a refreshed graph signature and
        range constraints and with the metadata of both graph modules merged
        (the transformed module's entries win).
    """
    pm = PassManager(list(passes))
    res = pm(self.graph_module)
    transformed_gm = res.graph_module if res is not None else self.graph_module
    assert transformed_gm is not None
    # `res` may be None (see the fallback above), so never dereference it
    # unguarded: a missing result is treated as "nothing was modified".
    modified = res is not None and res.modified

    if transformed_gm is self.graph_module and not modified:
        return self

    transformed_ep = ExportedProgram(
        root=transformed_gm,
        graph=transformed_gm.graph,
        graph_signature=_get_updated_graph_signature(
            self.graph_signature, transformed_gm
        ),
        state_dict=self.state_dict,
        range_constraints=_get_updated_range_constraints(transformed_gm),
        module_call_graph=copy.deepcopy(self._module_call_graph),
        example_inputs=self.example_inputs,
        constants=self.constants,
        verifiers=[self.verifier],
    )
    transformed_ep.graph_module.meta.update(self.graph_module.meta)
    # `transformed_gm` is `res.graph_module` whenever this point is reached.
    transformed_ep.graph_module.meta.update(transformed_gm.meta)
    return transformed_ep
state_dict=self.state_dict, 197*523fa7a6SAndroid Build Coastguard Worker range_constraints=_get_updated_range_constraints(transformed_gm), 198*523fa7a6SAndroid Build Coastguard Worker module_call_graph=copy.deepcopy(self._module_call_graph), 199*523fa7a6SAndroid Build Coastguard Worker example_inputs=self.example_inputs, 200*523fa7a6SAndroid Build Coastguard Worker constants=self.constants, 201*523fa7a6SAndroid Build Coastguard Worker verifiers=[self.verifier], 202*523fa7a6SAndroid Build Coastguard Worker ) 203*523fa7a6SAndroid Build Coastguard Worker transformed_ep.graph_module.meta.update(self.graph_module.meta) 204*523fa7a6SAndroid Build Coastguard Worker transformed_ep.graph_module.meta.update(res.graph_module.meta) 205*523fa7a6SAndroid Build Coastguard Worker return transformed_ep 206*523fa7a6SAndroid Build Coastguard Worker 207*523fa7a6SAndroid Build Coastguard Worker 208*523fa7a6SAndroid Build Coastguard Workerdef _copy_module(new_prog, new_gm): 209*523fa7a6SAndroid Build Coastguard Worker new_prog.meta.update(new_gm.meta) 210*523fa7a6SAndroid Build Coastguard Worker new_prog.graph = new_gm.graph 211*523fa7a6SAndroid Build Coastguard Worker submodules = [name for name, _ in new_prog.named_children()] 212*523fa7a6SAndroid Build Coastguard Worker for name in submodules: 213*523fa7a6SAndroid Build Coastguard Worker delattr(new_prog, name) 214*523fa7a6SAndroid Build Coastguard Worker for name, mod in new_gm.named_children(): 215*523fa7a6SAndroid Build Coastguard Worker setattr(new_prog, name, mod) 216*523fa7a6SAndroid Build Coastguard Worker for node in new_gm.graph.nodes: 217*523fa7a6SAndroid Build Coastguard Worker if node.op == "get_attr": 218*523fa7a6SAndroid Build Coastguard Worker t = getattr(new_gm, node.target, None) 219*523fa7a6SAndroid Build Coastguard Worker if isinstance(t, torch.Tensor): 220*523fa7a6SAndroid Build Coastguard Worker setattr(new_prog, node.target, t) 221*523fa7a6SAndroid Build Coastguard Worker 222*523fa7a6SAndroid Build Coastguard 
Worker 223*523fa7a6SAndroid Build Coastguard Workerdef lift_constant_tensor_pass(ep): 224*523fa7a6SAndroid Build Coastguard Worker """ 225*523fa7a6SAndroid Build Coastguard Worker Takes an ExportedProgram and returns the ExportedProgram modified in-place, 226*523fa7a6SAndroid Build Coastguard Worker with the constant tensors as buffers. 227*523fa7a6SAndroid Build Coastguard Worker """ 228*523fa7a6SAndroid Build Coastguard Worker if len([node for node in ep.graph.nodes if node.op == "placeholder"]) == 0: 229*523fa7a6SAndroid Build Coastguard Worker return ep 230*523fa7a6SAndroid Build Coastguard Worker 231*523fa7a6SAndroid Build Coastguard Worker graph_signature = ep.graph_signature 232*523fa7a6SAndroid Build Coastguard Worker buffers = list(graph_signature.buffers) 233*523fa7a6SAndroid Build Coastguard Worker 234*523fa7a6SAndroid Build Coastguard Worker fake_mode = list(ep.graph.nodes)[0].meta["val"].fake_mode 235*523fa7a6SAndroid Build Coastguard Worker first_user_input = None 236*523fa7a6SAndroid Build Coastguard Worker lifted_constants = [] 237*523fa7a6SAndroid Build Coastguard Worker for node in ep.graph.nodes: 238*523fa7a6SAndroid Build Coastguard Worker if node.op == "placeholder" and node.name in graph_signature.user_inputs: 239*523fa7a6SAndroid Build Coastguard Worker first_user_input = node 240*523fa7a6SAndroid Build Coastguard Worker break 241*523fa7a6SAndroid Build Coastguard Worker 242*523fa7a6SAndroid Build Coastguard Worker for node in ep.graph.nodes: 243*523fa7a6SAndroid Build Coastguard Worker if node.op == "get_attr": 244*523fa7a6SAndroid Build Coastguard Worker constant_tensor = getattr(ep.graph_module, node.target) 245*523fa7a6SAndroid Build Coastguard Worker if not isinstance(constant_tensor, torch.Tensor): 246*523fa7a6SAndroid Build Coastguard Worker continue 247*523fa7a6SAndroid Build Coastguard Worker 248*523fa7a6SAndroid Build Coastguard Worker constant_tensor_fqn = f"_lifted_tensor_constant{len(buffers)}" 249*523fa7a6SAndroid Build 
Coastguard Worker 250*523fa7a6SAndroid Build Coastguard Worker with ep.graph.inserting_before(first_user_input): 251*523fa7a6SAndroid Build Coastguard Worker # Insert the constant node before the first user input 252*523fa7a6SAndroid Build Coastguard Worker const_placeholder_node = ep.graph.placeholder(constant_tensor_fqn) 253*523fa7a6SAndroid Build Coastguard Worker for k, v in node.meta.items(): 254*523fa7a6SAndroid Build Coastguard Worker const_placeholder_node.meta[k] = v 255*523fa7a6SAndroid Build Coastguard Worker if fake_mode is not None: 256*523fa7a6SAndroid Build Coastguard Worker const_placeholder_node.meta["val"] = fake_mode.from_tensor( 257*523fa7a6SAndroid Build Coastguard Worker constant_tensor, static_shapes=True 258*523fa7a6SAndroid Build Coastguard Worker ) 259*523fa7a6SAndroid Build Coastguard Worker else: 260*523fa7a6SAndroid Build Coastguard Worker const_placeholder_node.meta["val"] = constant_tensor 261*523fa7a6SAndroid Build Coastguard Worker const_placeholder_node.meta["val"].constant = constant_tensor 262*523fa7a6SAndroid Build Coastguard Worker node.replace_all_uses_with(const_placeholder_node) 263*523fa7a6SAndroid Build Coastguard Worker ep.graph.erase_node(node) 264*523fa7a6SAndroid Build Coastguard Worker 265*523fa7a6SAndroid Build Coastguard Worker # Add the constant as a buffer to the graph signature 266*523fa7a6SAndroid Build Coastguard Worker lifted_constants.append( 267*523fa7a6SAndroid Build Coastguard Worker InputSpec( 268*523fa7a6SAndroid Build Coastguard Worker kind=InputKind.BUFFER, 269*523fa7a6SAndroid Build Coastguard Worker arg=TensorArgument(name=const_placeholder_node.name), 270*523fa7a6SAndroid Build Coastguard Worker target=constant_tensor_fqn, 271*523fa7a6SAndroid Build Coastguard Worker persistent=True, 272*523fa7a6SAndroid Build Coastguard Worker ) 273*523fa7a6SAndroid Build Coastguard Worker ) 274*523fa7a6SAndroid Build Coastguard Worker buffers.append(constant_tensor_fqn) 275*523fa7a6SAndroid Build Coastguard Worker 
def transform_exported_program(ep, *passes: PassType) -> ExportedProgram:
    """
    Apply ``passes`` to ``ep`` through whichever transform entry point it has.

    Migration stub: the private ``_transform`` method is used when ``ep`` has
    one, otherwise the call falls back to the public ``transform`` method.
    """
    entry_point = "_transform" if hasattr(ep, "_transform") else "transform"
    return getattr(ep, entry_point)(*passes)
Build Coastguard Worker range_constraints, 305*523fa7a6SAndroid Build Coastguard Worker module_call_graph, 306*523fa7a6SAndroid Build Coastguard Worker example_inputs, 307*523fa7a6SAndroid Build Coastguard Worker verifier, 308*523fa7a6SAndroid Build Coastguard Worker ): 309*523fa7a6SAndroid Build Coastguard Worker super().__init__( 310*523fa7a6SAndroid Build Coastguard Worker root=root, 311*523fa7a6SAndroid Build Coastguard Worker graph=graph, 312*523fa7a6SAndroid Build Coastguard Worker graph_signature=graph_signature, 313*523fa7a6SAndroid Build Coastguard Worker state_dict=state_dict, 314*523fa7a6SAndroid Build Coastguard Worker range_constraints=range_constraints, 315*523fa7a6SAndroid Build Coastguard Worker module_call_graph=module_call_graph, 316*523fa7a6SAndroid Build Coastguard Worker example_inputs=example_inputs, 317*523fa7a6SAndroid Build Coastguard Worker verifier=verifier, 318*523fa7a6SAndroid Build Coastguard Worker ) 319*523fa7a6SAndroid Build Coastguard Worker 320*523fa7a6SAndroid Build Coastguard Worker def __call__(self, *args: Any, **kwargs: Any) -> Any: 321*523fa7a6SAndroid Build Coastguard Worker import torch._export.error as error 322*523fa7a6SAndroid Build Coastguard Worker 323*523fa7a6SAndroid Build Coastguard Worker if self.call_spec.in_spec is not None: 324*523fa7a6SAndroid Build Coastguard Worker user_args = args 325*523fa7a6SAndroid Build Coastguard Worker try: 326*523fa7a6SAndroid Build Coastguard Worker args = fx_pytree.tree_flatten_spec(user_args, self.call_spec.in_spec) # type: ignore[assignment] 327*523fa7a6SAndroid Build Coastguard Worker except Exception: 328*523fa7a6SAndroid Build Coastguard Worker _, received_spec = pytree.tree_flatten(user_args) 329*523fa7a6SAndroid Build Coastguard Worker raise error.InternalError( 330*523fa7a6SAndroid Build Coastguard Worker "Trying to flatten user inputs with exported input tree spec: \n" 331*523fa7a6SAndroid Build Coastguard Worker f"{self.call_spec.in_spec}\n" 332*523fa7a6SAndroid Build 
Coastguard Worker "but actually got inputs with tree spec of: \n" 333*523fa7a6SAndroid Build Coastguard Worker f"{received_spec}" 334*523fa7a6SAndroid Build Coastguard Worker ) 335*523fa7a6SAndroid Build Coastguard Worker 336*523fa7a6SAndroid Build Coastguard Worker ordered_params = tuple( 337*523fa7a6SAndroid Build Coastguard Worker self.state_dict[name] for name in self.graph_signature.parameters 338*523fa7a6SAndroid Build Coastguard Worker ) 339*523fa7a6SAndroid Build Coastguard Worker ordered_buffers = tuple( 340*523fa7a6SAndroid Build Coastguard Worker self.state_dict[name] for name in self.graph_signature.buffers 341*523fa7a6SAndroid Build Coastguard Worker ) 342*523fa7a6SAndroid Build Coastguard Worker 343*523fa7a6SAndroid Build Coastguard Worker with torch.no_grad(): 344*523fa7a6SAndroid Build Coastguard Worker # NOTE: calling convention is first params, then buffers, then args as user supplied them. 345*523fa7a6SAndroid Build Coastguard Worker # See: torch/_functorch/aot_autograd.py#L1034 346*523fa7a6SAndroid Build Coastguard Worker res = torch.fx.Interpreter(self.graph_module).run( 347*523fa7a6SAndroid Build Coastguard Worker *ordered_params, *ordered_buffers, *args, enable_io_processing=False 348*523fa7a6SAndroid Build Coastguard Worker ) 349*523fa7a6SAndroid Build Coastguard Worker 350*523fa7a6SAndroid Build Coastguard Worker if self.call_spec.out_spec is not None: 351*523fa7a6SAndroid Build Coastguard Worker mutation = self.graph_signature.buffers_to_mutate 352*523fa7a6SAndroid Build Coastguard Worker num_mutated = len(mutation) 353*523fa7a6SAndroid Build Coastguard Worker mutated_buffers = res[:num_mutated] 354*523fa7a6SAndroid Build Coastguard Worker 355*523fa7a6SAndroid Build Coastguard Worker # Exclude dependency token from final result. 
356*523fa7a6SAndroid Build Coastguard Worker assertion_dep_token = self.graph_signature.assertion_dep_token 357*523fa7a6SAndroid Build Coastguard Worker if assertion_dep_token is not None: 358*523fa7a6SAndroid Build Coastguard Worker assertion_dep_token_index = list(assertion_dep_token.keys())[0] 359*523fa7a6SAndroid Build Coastguard Worker res = res[:assertion_dep_token_index] 360*523fa7a6SAndroid Build Coastguard Worker 361*523fa7a6SAndroid Build Coastguard Worker res = res[num_mutated:] 362*523fa7a6SAndroid Build Coastguard Worker try: 363*523fa7a6SAndroid Build Coastguard Worker res = pytree.tree_unflatten(res, self.call_spec.out_spec) 364*523fa7a6SAndroid Build Coastguard Worker except Exception: 365*523fa7a6SAndroid Build Coastguard Worker _, received_spec = pytree.tree_flatten(res) 366*523fa7a6SAndroid Build Coastguard Worker raise error.InternalError( 367*523fa7a6SAndroid Build Coastguard Worker "Trying to flatten user outputs with exported output tree spec: \n" 368*523fa7a6SAndroid Build Coastguard Worker f"{self.call_spec.out_spec}\n" 369*523fa7a6SAndroid Build Coastguard Worker "but actually got outputs with tree spec of: \n" 370*523fa7a6SAndroid Build Coastguard Worker f"{received_spec}" 371*523fa7a6SAndroid Build Coastguard Worker ) 372*523fa7a6SAndroid Build Coastguard Worker finally: 373*523fa7a6SAndroid Build Coastguard Worker ix = 0 374*523fa7a6SAndroid Build Coastguard Worker for buffer in self.graph_signature.buffers_to_mutate.values(): 375*523fa7a6SAndroid Build Coastguard Worker self.state_dict[buffer] = mutated_buffers[ix] 376*523fa7a6SAndroid Build Coastguard Worker ix += 1 377*523fa7a6SAndroid Build Coastguard Worker return res 378*523fa7a6SAndroid Build Coastguard Worker 379*523fa7a6SAndroid Build Coastguard Worker 380*523fa7a6SAndroid Build Coastguard Worker@compatibility(is_backward_compatible=False) 381*523fa7a6SAndroid Build Coastguard Workerclass ExirExportedProgram: 382*523fa7a6SAndroid Build Coastguard Worker def __init__( 
@compatibility(is_backward_compatible=False)
class ExirExportedProgram:
    """
    Wrapper around an ``ExportedProgram`` that tracks whether the edge passes
    have been applied and offers the edge / executorch lowering entry points.
    """

    def __init__(
        self,
        exported_program: ExportedProgram,
        after_to_edge_passes: bool,
    ):
        self.exported_program = exported_program

        # Remembers whether to_edge has been called on this program, so that
        # calling to_executorch directly without to_edge can be detected.
        self.after_to_edge_passes = after_to_edge_passes

    def transform(self, *passes: PassType) -> "ExirExportedProgram":
        """Apply ``passes`` to the wrapped program and return this wrapper."""
        updated_program = _transform(self.exported_program, *passes)
        self.exported_program = updated_program
        return self

    def __call__(self, *args: Any) -> Any:
        """Run the wrapped program's module on ``args``."""
        runnable = self.exported_program.module()
        return runnable(*args)

    # TODO(ycao): Change this to a composable function.
    def to_edge(
        self, config: Optional[EdgeCompileConfig] = None
    ) -> "ExirExportedProgram":
        """Lower to the edge dialect, using a default config if none is given."""
        if not config:
            config = EdgeCompileConfig()
        graph_module = self.exported_program.graph_module
        assert isinstance(
            graph_module, torch.fx.GraphModule
        ), f"type is instead: {type(graph_module).__name__}"

        return _to_edge(self, config)

    def dump(self) -> None:
        """Print the graph of the wrapped program."""
        print(self.exported_program.graph_module.graph)

    def to_executorch(
        self,
        config: Optional[ExecutorchBackendConfig] = None,
    ) -> "ExecutorchProgram":
        """
        Run the edge-to-executorch passes and memory planning, then wrap the
        result in an ``ExecutorchProgram``.

        Raises:
            RuntimeError: if ``to_edge`` has not been run on this program.
        """
        if not self.after_to_edge_passes:
            raise RuntimeError("Must run to_edge before to_executorch.")
        if not config:
            config = ExecutorchBackendConfig()

        lowered_gm = self.exported_program.graph_module
        for lowering_pass in edge_to_executorch_passes(config):
            pass_result = lowering_pass(lowered_gm)
            assert pass_result is not None
            lowered_gm = pass_result.graph_module

        # Memory planning runs last and is invoked directly on the graph
        # module. NOTE: the original author marks this call path as tech debt
        # to be cleaned up once exir.capture is gone.
        planning_result = config.memory_planning_pass(lowered_gm)  # pyre-ignore[29]
        assert planning_result is not None
        lowered_gm = planning_result.graph_module

        # Work on a deep copy so the wrapped program itself is not replaced.
        program_copy = ExirExportedProgram(
            copy.deepcopy(self.exported_program), self.after_to_edge_passes
        )
        _copy_module(program_copy.exported_program.graph_module, lowered_gm)
        executorch_prog = ExecutorchProgram(
            program_copy,
            emit_stacktrace=config.emit_stacktrace,
            extract_delegate_segments=config.extract_delegate_segments,
            segment_alignment=config.segment_alignment,
            constant_tensor_alignment=config.constant_tensor_alignment,
            delegate_alignment=config.delegate_alignment,
        )
        # Metadata of the original graph module is merged last, so its entries
        # take precedence over those of the lowered module.
        executorch_prog.graph_module.meta.update(lowered_gm.meta)
        executorch_prog.graph_module.meta.update(
            self.exported_program.graph_module.meta
        )
        return executorch_prog

    def __deepcopy__(
        self, memo: Optional[Dict[int, Any]] = None
    ) -> "ExirExportedProgram":
        """Deep-copy the wrapped program; the edge flag is carried over."""
        program_clone = copy.deepcopy(self.exported_program, memo)
        return ExirExportedProgram(program_clone, self.after_to_edge_passes)
executorch_prog.graph_module.meta.update(new_gm.meta) 448*523fa7a6SAndroid Build Coastguard Worker executorch_prog.graph_module.meta.update( 449*523fa7a6SAndroid Build Coastguard Worker self.exported_program.graph_module.meta 450*523fa7a6SAndroid Build Coastguard Worker ) 451*523fa7a6SAndroid Build Coastguard Worker return executorch_prog 452*523fa7a6SAndroid Build Coastguard Worker 453*523fa7a6SAndroid Build Coastguard Worker def __deepcopy__( 454*523fa7a6SAndroid Build Coastguard Worker self, memo: Optional[Dict[int, Any]] = None 455*523fa7a6SAndroid Build Coastguard Worker ) -> "ExirExportedProgram": 456*523fa7a6SAndroid Build Coastguard Worker new_eep = ExirExportedProgram( 457*523fa7a6SAndroid Build Coastguard Worker copy.deepcopy(self.exported_program, memo), 458*523fa7a6SAndroid Build Coastguard Worker self.after_to_edge_passes, 459*523fa7a6SAndroid Build Coastguard Worker ) 460*523fa7a6SAndroid Build Coastguard Worker return new_eep 461*523fa7a6SAndroid Build Coastguard Worker 462*523fa7a6SAndroid Build Coastguard Worker 463*523fa7a6SAndroid Build Coastguard Worker@compatibility(is_backward_compatible=False) 464*523fa7a6SAndroid Build Coastguard Workerclass ExecutorchProgram: 465*523fa7a6SAndroid Build Coastguard Worker def __init__( 466*523fa7a6SAndroid Build Coastguard Worker self, 467*523fa7a6SAndroid Build Coastguard Worker exir_exported_program: ExirExportedProgram, 468*523fa7a6SAndroid Build Coastguard Worker emit_stacktrace: bool, 469*523fa7a6SAndroid Build Coastguard Worker extract_delegate_segments: bool, 470*523fa7a6SAndroid Build Coastguard Worker segment_alignment: int, 471*523fa7a6SAndroid Build Coastguard Worker constant_tensor_alignment: Optional[int] = None, 472*523fa7a6SAndroid Build Coastguard Worker delegate_alignment: Optional[int] = None, 473*523fa7a6SAndroid Build Coastguard Worker ) -> None: 474*523fa7a6SAndroid Build Coastguard Worker if not exir_exported_program.after_to_edge_passes: 475*523fa7a6SAndroid Build Coastguard Worker 
raise RuntimeError( 476*523fa7a6SAndroid Build Coastguard Worker "Need to call prog.to_edge prior to constructing ExecutorchProgram." 477*523fa7a6SAndroid Build Coastguard Worker ) 478*523fa7a6SAndroid Build Coastguard Worker self.exported_program = exir_exported_program.exported_program 479*523fa7a6SAndroid Build Coastguard Worker self._pte_data: Optional[Cord] = None 480*523fa7a6SAndroid Build Coastguard Worker self._buffer: Optional[bytes] = None 481*523fa7a6SAndroid Build Coastguard Worker self._emitter_output: Optional[EmitterOutput] = None 482*523fa7a6SAndroid Build Coastguard Worker self._emit_stacktrace: bool = emit_stacktrace 483*523fa7a6SAndroid Build Coastguard Worker self._extract_delegate_segments: bool = extract_delegate_segments 484*523fa7a6SAndroid Build Coastguard Worker self._segment_alignment: int = segment_alignment 485*523fa7a6SAndroid Build Coastguard Worker self._constant_tensor_alignment: Optional[int] = constant_tensor_alignment 486*523fa7a6SAndroid Build Coastguard Worker self._delegate_alignment: Optional[int] = delegate_alignment 487*523fa7a6SAndroid Build Coastguard Worker 488*523fa7a6SAndroid Build Coastguard Worker def _get_pte_data(self) -> Cord: 489*523fa7a6SAndroid Build Coastguard Worker if self._pte_data is None: 490*523fa7a6SAndroid Build Coastguard Worker self._pte_data = _serialize_pte_binary( 491*523fa7a6SAndroid Build Coastguard Worker program=self.program, 492*523fa7a6SAndroid Build Coastguard Worker extract_delegate_segments=self._extract_delegate_segments, 493*523fa7a6SAndroid Build Coastguard Worker segment_alignment=self._segment_alignment, 494*523fa7a6SAndroid Build Coastguard Worker constant_tensor_alignment=self._constant_tensor_alignment, 495*523fa7a6SAndroid Build Coastguard Worker delegate_alignment=self._delegate_alignment, 496*523fa7a6SAndroid Build Coastguard Worker ) 497*523fa7a6SAndroid Build Coastguard Worker return self._pte_data 498*523fa7a6SAndroid Build Coastguard Worker 499*523fa7a6SAndroid Build 
Coastguard Worker @property 500*523fa7a6SAndroid Build Coastguard Worker def buffer(self) -> bytes: 501*523fa7a6SAndroid Build Coastguard Worker """Returns the serialized ExecuTorch binary as a byte string. 502*523fa7a6SAndroid Build Coastguard Worker 503*523fa7a6SAndroid Build Coastguard Worker Note that the call to `buffer` may allocate a very large amount of 504*523fa7a6SAndroid Build Coastguard Worker contiguous memory, depending on the model size. If writing to a file, 505*523fa7a6SAndroid Build Coastguard Worker use `write_to_file` which won't incur additional copies. 506*523fa7a6SAndroid Build Coastguard Worker """ 507*523fa7a6SAndroid Build Coastguard Worker # TODO(T181494963): update pybinding to remove buffer cache, which can consume large 508*523fa7a6SAndroid Build Coastguard Worker # amounts of memory longer than necessary. 509*523fa7a6SAndroid Build Coastguard Worker if self._buffer is None: 510*523fa7a6SAndroid Build Coastguard Worker self._buffer = bytes(self._get_pte_data()) 511*523fa7a6SAndroid Build Coastguard Worker return self._buffer 512*523fa7a6SAndroid Build Coastguard Worker 513*523fa7a6SAndroid Build Coastguard Worker @property 514*523fa7a6SAndroid Build Coastguard Worker def program(self) -> Program: 515*523fa7a6SAndroid Build Coastguard Worker if self._emitter_output is None: 516*523fa7a6SAndroid Build Coastguard Worker self._emitter_output = emit_program( 517*523fa7a6SAndroid Build Coastguard Worker self.exported_program, self._emit_stacktrace 518*523fa7a6SAndroid Build Coastguard Worker ) 519*523fa7a6SAndroid Build Coastguard Worker return self._emitter_output.program 520*523fa7a6SAndroid Build Coastguard Worker 521*523fa7a6SAndroid Build Coastguard Worker @property 522*523fa7a6SAndroid Build Coastguard Worker def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]: 523*523fa7a6SAndroid Build Coastguard Worker if self._emitter_output: 524*523fa7a6SAndroid Build Coastguard Worker return self._emitter_output.debug_handle_map 
525*523fa7a6SAndroid Build Coastguard Worker return {} 526*523fa7a6SAndroid Build Coastguard Worker 527*523fa7a6SAndroid Build Coastguard Worker @property 528*523fa7a6SAndroid Build Coastguard Worker def delegate_map( 529*523fa7a6SAndroid Build Coastguard Worker self, 530*523fa7a6SAndroid Build Coastguard Worker ) -> Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]: 531*523fa7a6SAndroid Build Coastguard Worker if self._emitter_output: 532*523fa7a6SAndroid Build Coastguard Worker return self._emitter_output.method_to_delegate_debug_id_map 533*523fa7a6SAndroid Build Coastguard Worker return {} 534*523fa7a6SAndroid Build Coastguard Worker 535*523fa7a6SAndroid Build Coastguard Worker @property 536*523fa7a6SAndroid Build Coastguard Worker def graph_module(self) -> torch.fx.GraphModule: 537*523fa7a6SAndroid Build Coastguard Worker return self.exported_program.graph_module 538*523fa7a6SAndroid Build Coastguard Worker 539*523fa7a6SAndroid Build Coastguard Worker # TODO (zhxchen17) Change this to property. 540*523fa7a6SAndroid Build Coastguard Worker def dump_graph_module(self) -> torch.fx.GraphModule: 541*523fa7a6SAndroid Build Coastguard Worker return self.exported_program.graph_module 542*523fa7a6SAndroid Build Coastguard Worker 543*523fa7a6SAndroid Build Coastguard Worker def dump_exported_program(self) -> ExportedProgram: 544*523fa7a6SAndroid Build Coastguard Worker return self.exported_program 545*523fa7a6SAndroid Build Coastguard Worker 546*523fa7a6SAndroid Build Coastguard Worker def write_to_file(self, open_file: io.BufferedIOBase) -> None: 547*523fa7a6SAndroid Build Coastguard Worker """ 548*523fa7a6SAndroid Build Coastguard Worker Writes the serialized ExecuTorch binary to the file at `open_file`. Prefer to use this over 549*523fa7a6SAndroid Build Coastguard Worker `buffer`, as it writes to file without copying into a contiguous block of memory first, 550*523fa7a6SAndroid Build Coastguard Worker reducing the peak memory usage. 
551*523fa7a6SAndroid Build Coastguard Worker """ 552*523fa7a6SAndroid Build Coastguard Worker self._get_pte_data().write_to_file(open_file) 553*523fa7a6SAndroid Build Coastguard Worker 554*523fa7a6SAndroid Build Coastguard Worker 555*523fa7a6SAndroid Build Coastguard Workerdef _get_aten_to_edge_passes(config: EdgeCompileConfig): 556*523fa7a6SAndroid Build Coastguard Worker # TODO: the last two passes for aten_to_edge need to be eliminated_dead_code -> debug_handle_generator. After enable 557*523fa7a6SAndroid Build Coastguard Worker # use_edge_op it can be moved to aten_to_edge_passes before eliminated_dead_code pass. Also ExportPass doesn't play 558*523fa7a6SAndroid Build Coastguard Worker # well with node.meta, meaning after some passes permuting operators, we may lose some information in node.meta. 559*523fa7a6SAndroid Build Coastguard Worker # It might be regenerated in SpecPropPass so it may not be visiable. However debug handle will be lost. 560*523fa7a6SAndroid Build Coastguard Worker 561*523fa7a6SAndroid Build Coastguard Worker pre_op_replace_passes = base_pre_op_replace_passes + ( 562*523fa7a6SAndroid Build Coastguard Worker [] if config._skip_type_promotion else [RemoveMixedTypeOperators()] 563*523fa7a6SAndroid Build Coastguard Worker ) 564*523fa7a6SAndroid Build Coastguard Worker 565*523fa7a6SAndroid Build Coastguard Worker post_op_replace_passes = base_post_op_replace_passes 566*523fa7a6SAndroid Build Coastguard Worker 567*523fa7a6SAndroid Build Coastguard Worker return pre_op_replace_passes, post_op_replace_passes 568*523fa7a6SAndroid Build Coastguard Worker 569*523fa7a6SAndroid Build Coastguard Worker 570*523fa7a6SAndroid Build Coastguard Workerdef _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram": 571*523fa7a6SAndroid Build Coastguard Worker if config._check_ir_validity: 572*523fa7a6SAndroid Build Coastguard Worker try: 573*523fa7a6SAndroid Build Coastguard Worker EXIRATenDialectVerifier()(ep.exported_program.graph_module) 
574*523fa7a6SAndroid Build Coastguard Worker except ExportError: 575*523fa7a6SAndroid Build Coastguard Worker logging.info( 576*523fa7a6SAndroid Build Coastguard Worker "If a particular operator failed core ATen IR check, please consider adding it to the exception list. " 577*523fa7a6SAndroid Build Coastguard Worker "Add the operator to _core_aten_ops_exception_list in EdgeCompileConfig. This is the recommended way " 578*523fa7a6SAndroid Build Coastguard Worker "to resolve this type of failure, so that the rest of the IR validation check can still be performed.\n" 579*523fa7a6SAndroid Build Coastguard Worker "If you'd like to disable IR validation checking, please set _check_ir_validity in EdgeCompileConfig, " 580*523fa7a6SAndroid Build Coastguard Worker "like *.to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))." 581*523fa7a6SAndroid Build Coastguard Worker ) 582*523fa7a6SAndroid Build Coastguard Worker raise 583*523fa7a6SAndroid Build Coastguard Worker 584*523fa7a6SAndroid Build Coastguard Worker dialect = ep.exported_program.dialect 585*523fa7a6SAndroid Build Coastguard Worker if dialect == "ATEN": 586*523fa7a6SAndroid Build Coastguard Worker ep = ExirExportedProgram( 587*523fa7a6SAndroid Build Coastguard Worker ExportedProgram( 588*523fa7a6SAndroid Build Coastguard Worker root=ep.exported_program.graph_module, 589*523fa7a6SAndroid Build Coastguard Worker graph=ep.exported_program.graph_module.graph, 590*523fa7a6SAndroid Build Coastguard Worker graph_signature=ep.exported_program.graph_signature, 591*523fa7a6SAndroid Build Coastguard Worker state_dict=ep.exported_program.state_dict, 592*523fa7a6SAndroid Build Coastguard Worker range_constraints=ep.exported_program.range_constraints, 593*523fa7a6SAndroid Build Coastguard Worker module_call_graph=ep.exported_program.module_call_graph, 594*523fa7a6SAndroid Build Coastguard Worker example_inputs=ep.exported_program.example_inputs, 595*523fa7a6SAndroid Build Coastguard Worker 
constants=ep.exported_program.constants, 596*523fa7a6SAndroid Build Coastguard Worker verifiers=[ 597*523fa7a6SAndroid Build Coastguard Worker get_aten_verifier( 598*523fa7a6SAndroid Build Coastguard Worker config=config, 599*523fa7a6SAndroid Build Coastguard Worker ) 600*523fa7a6SAndroid Build Coastguard Worker ], 601*523fa7a6SAndroid Build Coastguard Worker ), 602*523fa7a6SAndroid Build Coastguard Worker False, 603*523fa7a6SAndroid Build Coastguard Worker ) 604*523fa7a6SAndroid Build Coastguard Worker pre_op_replace_passes, post_op_replace_passes = _get_aten_to_edge_passes(config) 605*523fa7a6SAndroid Build Coastguard Worker 606*523fa7a6SAndroid Build Coastguard Worker new_ep = copy.deepcopy(ep).transform(*pre_op_replace_passes) 607*523fa7a6SAndroid Build Coastguard Worker if dialect == "ATEN": 608*523fa7a6SAndroid Build Coastguard Worker new_ep.exported_program = lift_constant_tensor_pass(new_ep.exported_program) 609*523fa7a6SAndroid Build Coastguard Worker 610*523fa7a6SAndroid Build Coastguard Worker new_gm = new_ep.exported_program.graph_module 611*523fa7a6SAndroid Build Coastguard Worker if config._use_edge_ops: 612*523fa7a6SAndroid Build Coastguard Worker new_gm_res = OpReplacePass()(new_gm) 613*523fa7a6SAndroid Build Coastguard Worker assert new_gm_res is not None 614*523fa7a6SAndroid Build Coastguard Worker new_gm = new_gm_res.graph_module 615*523fa7a6SAndroid Build Coastguard Worker if not config._skip_dim_order: 616*523fa7a6SAndroid Build Coastguard Worker new_gm_res = MemoryFormatOpsPass()(new_gm) 617*523fa7a6SAndroid Build Coastguard Worker assert new_gm_res is not None 618*523fa7a6SAndroid Build Coastguard Worker new_gm = new_gm_res.graph_module 619*523fa7a6SAndroid Build Coastguard Worker 620*523fa7a6SAndroid Build Coastguard Worker for p in post_op_replace_passes: 621*523fa7a6SAndroid Build Coastguard Worker new_gm_res = p(new_gm) 622*523fa7a6SAndroid Build Coastguard Worker assert new_gm_res is not None 623*523fa7a6SAndroid Build Coastguard Worker 
new_gm = new_gm_res.graph_module 624*523fa7a6SAndroid Build Coastguard Worker 625*523fa7a6SAndroid Build Coastguard Worker new_ep.exported_program = ExportedProgram( 626*523fa7a6SAndroid Build Coastguard Worker root=new_gm, 627*523fa7a6SAndroid Build Coastguard Worker graph=new_gm.graph, 628*523fa7a6SAndroid Build Coastguard Worker graph_signature=_get_updated_graph_signature( 629*523fa7a6SAndroid Build Coastguard Worker new_ep.exported_program.graph_signature, new_gm 630*523fa7a6SAndroid Build Coastguard Worker ), 631*523fa7a6SAndroid Build Coastguard Worker state_dict=new_ep.exported_program.state_dict, 632*523fa7a6SAndroid Build Coastguard Worker range_constraints=new_ep.exported_program.range_constraints, 633*523fa7a6SAndroid Build Coastguard Worker module_call_graph=new_ep.exported_program.module_call_graph, 634*523fa7a6SAndroid Build Coastguard Worker example_inputs=new_ep.exported_program.example_inputs, 635*523fa7a6SAndroid Build Coastguard Worker constants=new_ep.exported_program.constants, 636*523fa7a6SAndroid Build Coastguard Worker verifiers=[ 637*523fa7a6SAndroid Build Coastguard Worker EXIREdgeDialectVerifier( 638*523fa7a6SAndroid Build Coastguard Worker edge_compile_config=config, 639*523fa7a6SAndroid Build Coastguard Worker class_only=True, 640*523fa7a6SAndroid Build Coastguard Worker ) 641*523fa7a6SAndroid Build Coastguard Worker ], 642*523fa7a6SAndroid Build Coastguard Worker ) 643*523fa7a6SAndroid Build Coastguard Worker new_ep.after_to_edge_passes = True 644*523fa7a6SAndroid Build Coastguard Worker return new_ep 645*523fa7a6SAndroid Build Coastguard Worker 646*523fa7a6SAndroid Build Coastguard Worker 647*523fa7a6SAndroid Build Coastguard Workerdef pre_memory_planning_passes( 648*523fa7a6SAndroid Build Coastguard Worker config: ExecutorchBackendConfig, name: Optional[str] = None 649*523fa7a6SAndroid Build Coastguard Worker) -> List[PassType]: 650*523fa7a6SAndroid Build Coastguard Worker """ 651*523fa7a6SAndroid Build Coastguard Worker Returns a 
list of passes to run before memory planning. 652*523fa7a6SAndroid Build Coastguard Worker Get the sym shape eval pass based on the method name, if the pass is not in the dict, use the default pass. 653*523fa7a6SAndroid Build Coastguard Worker """ 654*523fa7a6SAndroid Build Coastguard Worker # Handle symbolic shape eval pass 655*523fa7a6SAndroid Build Coastguard Worker if isinstance(config.sym_shape_eval_pass, dict): 656*523fa7a6SAndroid Build Coastguard Worker default_pass = ExecutorchBackendConfig().sym_shape_eval_pass 657*523fa7a6SAndroid Build Coastguard Worker if not name: 658*523fa7a6SAndroid Build Coastguard Worker sym_shape_eval_pass = default_pass 659*523fa7a6SAndroid Build Coastguard Worker # pyre-ignore: Undefined attribute [16] 660*523fa7a6SAndroid Build Coastguard Worker sym_shape_eval_pass = config.sym_shape_eval_pass.get(name, default_pass) 661*523fa7a6SAndroid Build Coastguard Worker elif isinstance(config.sym_shape_eval_pass, PassBase): 662*523fa7a6SAndroid Build Coastguard Worker sym_shape_eval_pass = config.sym_shape_eval_pass 663*523fa7a6SAndroid Build Coastguard Worker else: 664*523fa7a6SAndroid Build Coastguard Worker raise RuntimeError( 665*523fa7a6SAndroid Build Coastguard Worker f"sym_shape_eval_pass must be a dict or a PassBase, got {config.sym_shape_eval_pass}" 666*523fa7a6SAndroid Build Coastguard Worker ) 667*523fa7a6SAndroid Build Coastguard Worker if config.remove_view_copy: 668*523fa7a6SAndroid Build Coastguard Worker return [ 669*523fa7a6SAndroid Build Coastguard Worker NormalizeViewCopyBasePass(), 670*523fa7a6SAndroid Build Coastguard Worker dead_code_elimination_pass, 671*523fa7a6SAndroid Build Coastguard Worker ReplaceViewCopyWithViewPass(), 672*523fa7a6SAndroid Build Coastguard Worker sym_shape_eval_pass, 673*523fa7a6SAndroid Build Coastguard Worker config.to_out_var_pass, 674*523fa7a6SAndroid Build Coastguard Worker ] 675*523fa7a6SAndroid Build Coastguard Worker else: 676*523fa7a6SAndroid Build Coastguard Worker return [ 
677*523fa7a6SAndroid Build Coastguard Worker sym_shape_eval_pass, 678*523fa7a6SAndroid Build Coastguard Worker config.to_out_var_pass, 679*523fa7a6SAndroid Build Coastguard Worker ] 680*523fa7a6SAndroid Build Coastguard Worker 681*523fa7a6SAndroid Build Coastguard Worker 682*523fa7a6SAndroid Build Coastguard Workerdef edge_to_executorch_passes( 683*523fa7a6SAndroid Build Coastguard Worker config: ExecutorchBackendConfig, name: Optional[str] = None 684*523fa7a6SAndroid Build Coastguard Worker) -> List[PassType]: 685*523fa7a6SAndroid Build Coastguard Worker """ 686*523fa7a6SAndroid Build Coastguard Worker Returns a list of passes to lower from edge to executorch. 687*523fa7a6SAndroid Build Coastguard Worker Get the pre memory planning passes based on the method name, if the pass is not in the dict, use the default pass. 688*523fa7a6SAndroid Build Coastguard Worker """ 689*523fa7a6SAndroid Build Coastguard Worker passes: List[PassType] = [ 690*523fa7a6SAndroid Build Coastguard Worker *config.passes, 691*523fa7a6SAndroid Build Coastguard Worker SpecPropPass(), 692*523fa7a6SAndroid Build Coastguard Worker # ExecuTorch backend ops are unable to handle unbacked symints. So after 693*523fa7a6SAndroid Build Coastguard Worker # this pass, passes cannot be Interpreter-based, because it will fail if 694*523fa7a6SAndroid Build Coastguard Worker # there exists an unbacked symint operation. 
695*523fa7a6SAndroid Build Coastguard Worker EdgeToBackendOpsPass(), 696*523fa7a6SAndroid Build Coastguard Worker RemoveGraphAssertsPass(), 697*523fa7a6SAndroid Build Coastguard Worker ] + pre_memory_planning_passes(config, name) 698*523fa7a6SAndroid Build Coastguard Worker 699*523fa7a6SAndroid Build Coastguard Worker return passes 700*523fa7a6SAndroid Build Coastguard Worker 701*523fa7a6SAndroid Build Coastguard Worker 702*523fa7a6SAndroid Build Coastguard Workerdef _generate_edge_program( 703*523fa7a6SAndroid Build Coastguard Worker name: str, 704*523fa7a6SAndroid Build Coastguard Worker config: EdgeCompileConfig, 705*523fa7a6SAndroid Build Coastguard Worker program: ExportedProgram, 706*523fa7a6SAndroid Build Coastguard Worker ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None, 707*523fa7a6SAndroid Build Coastguard Worker) -> ExportedProgram: 708*523fa7a6SAndroid Build Coastguard Worker if config._check_ir_validity: 709*523fa7a6SAndroid Build Coastguard Worker try: 710*523fa7a6SAndroid Build Coastguard Worker EXIRATenDialectVerifier( 711*523fa7a6SAndroid Build Coastguard Worker edge_compile_config=config, 712*523fa7a6SAndroid Build Coastguard Worker class_only=False, 713*523fa7a6SAndroid Build Coastguard Worker exception_list=ops_set_to_not_decompose, 714*523fa7a6SAndroid Build Coastguard Worker )(program.graph_module) 715*523fa7a6SAndroid Build Coastguard Worker except ExportError as e: 716*523fa7a6SAndroid Build Coastguard Worker logging.info(f"Input program {name} is not in ATen dialect.") 717*523fa7a6SAndroid Build Coastguard Worker raise e 718*523fa7a6SAndroid Build Coastguard Worker 719*523fa7a6SAndroid Build Coastguard Worker pre_op_replace_passes, post_op_replace_passes = _get_aten_to_edge_passes(config) 720*523fa7a6SAndroid Build Coastguard Worker 721*523fa7a6SAndroid Build Coastguard Worker passes = [] 722*523fa7a6SAndroid Build Coastguard Worker passes.append( 723*523fa7a6SAndroid Build Coastguard Worker 
ReplaceViewOpsWithViewCopyOpsPass() 724*523fa7a6SAndroid Build Coastguard Worker ) # TODO move inside aten_to_edge passes after all users are migrated off v1 capture 725*523fa7a6SAndroid Build Coastguard Worker passes.extend(pre_op_replace_passes) 726*523fa7a6SAndroid Build Coastguard Worker if config._use_edge_ops: 727*523fa7a6SAndroid Build Coastguard Worker passes.append(OpReplacePass()) 728*523fa7a6SAndroid Build Coastguard Worker if not config._skip_dim_order: 729*523fa7a6SAndroid Build Coastguard Worker passes.append(MemoryFormatOpsPass()) 730*523fa7a6SAndroid Build Coastguard Worker 731*523fa7a6SAndroid Build Coastguard Worker gm = program.graph_module 732*523fa7a6SAndroid Build Coastguard Worker for p in passes: 733*523fa7a6SAndroid Build Coastguard Worker gm_res = p(gm) 734*523fa7a6SAndroid Build Coastguard Worker assert gm_res is not None 735*523fa7a6SAndroid Build Coastguard Worker gm = gm_res.graph_module 736*523fa7a6SAndroid Build Coastguard Worker 737*523fa7a6SAndroid Build Coastguard Worker edge_program = ExportedProgram( 738*523fa7a6SAndroid Build Coastguard Worker root=gm, 739*523fa7a6SAndroid Build Coastguard Worker graph=gm.graph, 740*523fa7a6SAndroid Build Coastguard Worker graph_signature=_get_updated_graph_signature(program.graph_signature, gm), 741*523fa7a6SAndroid Build Coastguard Worker state_dict=program.state_dict, 742*523fa7a6SAndroid Build Coastguard Worker range_constraints=program.range_constraints, 743*523fa7a6SAndroid Build Coastguard Worker module_call_graph=program.module_call_graph, 744*523fa7a6SAndroid Build Coastguard Worker example_inputs=program.example_inputs, 745*523fa7a6SAndroid Build Coastguard Worker constants=program.constants, 746*523fa7a6SAndroid Build Coastguard Worker verifiers=[ 747*523fa7a6SAndroid Build Coastguard Worker EXIREdgeDialectVerifier( 748*523fa7a6SAndroid Build Coastguard Worker edge_compile_config=config, 749*523fa7a6SAndroid Build Coastguard Worker class_only=True, 750*523fa7a6SAndroid Build 
Coastguard Worker exception_list=ops_set_to_not_decompose, 751*523fa7a6SAndroid Build Coastguard Worker ) 752*523fa7a6SAndroid Build Coastguard Worker ], 753*523fa7a6SAndroid Build Coastguard Worker ) 754*523fa7a6SAndroid Build Coastguard Worker # Lift the tensor constants created in ScalarToTensorPass 755*523fa7a6SAndroid Build Coastguard Worker edge_program = lift_constant_tensor_pass(edge_program) 756*523fa7a6SAndroid Build Coastguard Worker edge_program = _transform(edge_program, *post_op_replace_passes) 757*523fa7a6SAndroid Build Coastguard Worker 758*523fa7a6SAndroid Build Coastguard Worker return edge_program 759*523fa7a6SAndroid Build Coastguard Worker 760*523fa7a6SAndroid Build Coastguard Worker 761*523fa7a6SAndroid Build Coastguard Workerdef _replace_aten_ops_with_transformed_ops( 762*523fa7a6SAndroid Build Coastguard Worker name: str, 763*523fa7a6SAndroid Build Coastguard Worker program: ExportedProgram, 764*523fa7a6SAndroid Build Coastguard Worker partitioner, 765*523fa7a6SAndroid Build Coastguard Worker): 766*523fa7a6SAndroid Build Coastguard Worker ops_to_not_decompose = set() 767*523fa7a6SAndroid Build Coastguard Worker partitioners = partitioner.get(name) 768*523fa7a6SAndroid Build Coastguard Worker if partitioners is None: 769*523fa7a6SAndroid Build Coastguard Worker return 770*523fa7a6SAndroid Build Coastguard Worker 771*523fa7a6SAndroid Build Coastguard Worker # Iterate through the graph and replace the aten ops with the corresponding 772*523fa7a6SAndroid Build Coastguard Worker # transformed ops. 
773*523fa7a6SAndroid Build Coastguard Worker for partitioner in partitioners: 774*523fa7a6SAndroid Build Coastguard Worker ops_set_to_not_decompose, check_op_support = partitioner.ops_to_not_decompose( 775*523fa7a6SAndroid Build Coastguard Worker program 776*523fa7a6SAndroid Build Coastguard Worker ) 777*523fa7a6SAndroid Build Coastguard Worker 778*523fa7a6SAndroid Build Coastguard Worker for op_aten in ops_set_to_not_decompose: 779*523fa7a6SAndroid Build Coastguard Worker _register_no_decomp_op(op_aten) 780*523fa7a6SAndroid Build Coastguard Worker 781*523fa7a6SAndroid Build Coastguard Worker for node in program.graph.nodes: 782*523fa7a6SAndroid Build Coastguard Worker is_op_supported = check_op_support(node) if check_op_support else True 783*523fa7a6SAndroid Build Coastguard Worker if ( 784*523fa7a6SAndroid Build Coastguard Worker node.op == "call_function" 785*523fa7a6SAndroid Build Coastguard Worker and node.target in ops_set_to_not_decompose 786*523fa7a6SAndroid Build Coastguard Worker and is_op_supported 787*523fa7a6SAndroid Build Coastguard Worker ): 788*523fa7a6SAndroid Build Coastguard Worker ops_to_not_decompose.add(node.target) 789*523fa7a6SAndroid Build Coastguard Worker node.target = aten_op_to_transform_op[node.target] 790*523fa7a6SAndroid Build Coastguard Worker 791*523fa7a6SAndroid Build Coastguard Worker for _, submod, _ in get_control_flow_submodules(program.graph_module): 792*523fa7a6SAndroid Build Coastguard Worker for node in submod.graph.nodes: 793*523fa7a6SAndroid Build Coastguard Worker is_op_supported = check_op_support(node) if check_op_support else True 794*523fa7a6SAndroid Build Coastguard Worker if ( 795*523fa7a6SAndroid Build Coastguard Worker node.op == "call_function" 796*523fa7a6SAndroid Build Coastguard Worker and node.target in ops_set_to_not_decompose 797*523fa7a6SAndroid Build Coastguard Worker and is_op_supported 798*523fa7a6SAndroid Build Coastguard Worker ): 799*523fa7a6SAndroid Build Coastguard Worker 
ops_to_not_decompose.add(node.target) 800*523fa7a6SAndroid Build Coastguard Worker node.target = aten_op_to_transform_op[node.target] 801*523fa7a6SAndroid Build Coastguard Worker 802*523fa7a6SAndroid Build Coastguard Worker return ops_to_not_decompose 803*523fa7a6SAndroid Build Coastguard Worker 804*523fa7a6SAndroid Build Coastguard Worker 805*523fa7a6SAndroid Build Coastguard Workerdef _restore_transformed_ops_to_aten_ops(program: ExportedProgram): 806*523fa7a6SAndroid Build Coastguard Worker # Iterate through the graph and replace back the transformed ops with their 807*523fa7a6SAndroid Build Coastguard Worker # corresponding aten ops. 808*523fa7a6SAndroid Build Coastguard Worker for node in program.graph.nodes: 809*523fa7a6SAndroid Build Coastguard Worker if node.op == "call_function" and str(node.target) in transform_op_to_aten_op: 810*523fa7a6SAndroid Build Coastguard Worker node.target = transform_op_to_aten_op[str(node.target)] 811*523fa7a6SAndroid Build Coastguard Worker for _, submod, _ in get_control_flow_submodules(program.graph_module): 812*523fa7a6SAndroid Build Coastguard Worker for node in submod.graph.nodes: 813*523fa7a6SAndroid Build Coastguard Worker if ( 814*523fa7a6SAndroid Build Coastguard Worker node.op == "call_function" 815*523fa7a6SAndroid Build Coastguard Worker and str(node.target) in transform_op_to_aten_op 816*523fa7a6SAndroid Build Coastguard Worker ): 817*523fa7a6SAndroid Build Coastguard Worker node.target = transform_op_to_aten_op[str(node.target)] 818*523fa7a6SAndroid Build Coastguard Worker 819*523fa7a6SAndroid Build Coastguard Worker 820*523fa7a6SAndroid Build Coastguard Worker# Returns the op in edge_no_decomp_namespace namespace for the aten 821*523fa7a6SAndroid Build Coastguard Worker# op that is passed in. 
822*523fa7a6SAndroid Build Coastguard Workerdef _get_transformed_op(op_aten): 823*523fa7a6SAndroid Build Coastguard Worker op_name = op_aten._schema.name.split("::")[1] 824*523fa7a6SAndroid Build Coastguard Worker overload_name = op_aten._schema.overload_name 825*523fa7a6SAndroid Build Coastguard Worker assert hasattr( 826*523fa7a6SAndroid Build Coastguard Worker torch.ops, edge_no_decomp_namespace 827*523fa7a6SAndroid Build Coastguard Worker ), f"Couldn't find {edge_no_decomp_namespace} in torch.ops. Please make sure the Library has been registered." 828*523fa7a6SAndroid Build Coastguard Worker op_namespace = getattr(torch.ops, edge_no_decomp_namespace) 829*523fa7a6SAndroid Build Coastguard Worker op = getattr(op_namespace, op_name) 830*523fa7a6SAndroid Build Coastguard Worker return getattr(op, overload_name) 831*523fa7a6SAndroid Build Coastguard Worker 832*523fa7a6SAndroid Build Coastguard Worker 833*523fa7a6SAndroid Build Coastguard Worker# Registers the op in edge_no_decomp_namespace namespace for the aten 834*523fa7a6SAndroid Build Coastguard Worker# op that is passed in if it is not already cached in the table. 835*523fa7a6SAndroid Build Coastguard Workerdef _register_no_decomp_op(op_aten): 836*523fa7a6SAndroid Build Coastguard Worker # Check if the op is already cached in the table. If not, then we need to 837*523fa7a6SAndroid Build Coastguard Worker # create a new op in the edge_no_decomp_namespace namespace. 838*523fa7a6SAndroid Build Coastguard Worker if aten_op_to_transform_op.get(op_aten) is None and isinstance( 839*523fa7a6SAndroid Build Coastguard Worker op_aten, torch._ops.OpOverload 840*523fa7a6SAndroid Build Coastguard Worker ): 841*523fa7a6SAndroid Build Coastguard Worker # Extract the schema from the aten op. 
842*523fa7a6SAndroid Build Coastguard Worker op_schema = str(op_aten._schema).split("::")[1] 843*523fa7a6SAndroid Build Coastguard Worker op_name = op_aten._schema.name.split("::")[1] 844*523fa7a6SAndroid Build Coastguard Worker # Define an op in the edge_no_decomp_namespace namespace with the aten schema. 845*523fa7a6SAndroid Build Coastguard Worker lib.define(op_schema) 846*523fa7a6SAndroid Build Coastguard Worker # Define the implementation of the op in the edge_no_decomp_namespace namespace. 847*523fa7a6SAndroid Build Coastguard Worker # Important to note that the implementation of the op is the same as the aten op. 848*523fa7a6SAndroid Build Coastguard Worker 849*523fa7a6SAndroid Build Coastguard Worker overload_name = op_aten._schema.overload_name 850*523fa7a6SAndroid Build Coastguard Worker if overload_name != "": 851*523fa7a6SAndroid Build Coastguard Worker op_name += "." + overload_name 852*523fa7a6SAndroid Build Coastguard Worker lib.impl(op_name, op_aten, "CompositeExplicitAutograd") 853*523fa7a6SAndroid Build Coastguard Worker 854*523fa7a6SAndroid Build Coastguard Worker # Cache the aten op and transformed op in their corresponding tables for future use. 
855*523fa7a6SAndroid Build Coastguard Worker aten_op_to_transform_op[op_aten] = _get_transformed_op(op_aten) 856*523fa7a6SAndroid Build Coastguard Worker transform_op_to_aten_op[str(aten_op_to_transform_op[op_aten])] = op_aten 857*523fa7a6SAndroid Build Coastguard Worker 858*523fa7a6SAndroid Build Coastguard Worker 859*523fa7a6SAndroid Build Coastguard Workerdef _sanity_check_graph_for_non_decomp_ops( 860*523fa7a6SAndroid Build Coastguard Worker name: str, 861*523fa7a6SAndroid Build Coastguard Worker program: ExportedProgram, 862*523fa7a6SAndroid Build Coastguard Worker ops_set_to_not_decompose, 863*523fa7a6SAndroid Build Coastguard Worker check_op_support, 864*523fa7a6SAndroid Build Coastguard Worker generate_error=False, 865*523fa7a6SAndroid Build Coastguard Worker partitioner_name=None, 866*523fa7a6SAndroid Build Coastguard Worker): 867*523fa7a6SAndroid Build Coastguard Worker warning_str = f"Found {ops_set_to_not_decompose} in edge dialect program {name}." 868*523fa7a6SAndroid Build Coastguard Worker if partitioner_name is not None: 869*523fa7a6SAndroid Build Coastguard Worker warning_str += f" This op was registered by the partitioner {partitioner_name} to not be decomposed." 870*523fa7a6SAndroid Build Coastguard Worker 871*523fa7a6SAndroid Build Coastguard Worker # Check that the ops that were registered to not be decomposed are not present in the 872*523fa7a6SAndroid Build Coastguard Worker # graph anymore as the transform passes and backends should have consumed them by now. 
873*523fa7a6SAndroid Build Coastguard Worker ops_set_to_not_decompose = { 874*523fa7a6SAndroid Build Coastguard Worker aten_to_edge(op) for op in ops_set_to_not_decompose 875*523fa7a6SAndroid Build Coastguard Worker }.union(ops_set_to_not_decompose) 876*523fa7a6SAndroid Build Coastguard Worker for node in program.graph_module.graph.nodes: 877*523fa7a6SAndroid Build Coastguard Worker is_op_supported = check_op_support(node) if check_op_support else True 878*523fa7a6SAndroid Build Coastguard Worker if ( 879*523fa7a6SAndroid Build Coastguard Worker node.op == "call_function" and node.target in ops_set_to_not_decompose 880*523fa7a6SAndroid Build Coastguard Worker ) and is_op_supported: 881*523fa7a6SAndroid Build Coastguard Worker if generate_error: 882*523fa7a6SAndroid Build Coastguard Worker raise RuntimeError(warning_str) 883*523fa7a6SAndroid Build Coastguard Worker else: 884*523fa7a6SAndroid Build Coastguard Worker logging.warning(warning_str) 885*523fa7a6SAndroid Build Coastguard Worker for _, submod, _ in get_control_flow_submodules(program.graph_module): 886*523fa7a6SAndroid Build Coastguard Worker for node in submod.graph.nodes: 887*523fa7a6SAndroid Build Coastguard Worker is_op_supported = check_op_support(node) if check_op_support else True 888*523fa7a6SAndroid Build Coastguard Worker if ( 889*523fa7a6SAndroid Build Coastguard Worker node.op == "call_function" and node.target in ops_set_to_not_decompose 890*523fa7a6SAndroid Build Coastguard Worker ) and is_op_supported: 891*523fa7a6SAndroid Build Coastguard Worker if generate_error: 892*523fa7a6SAndroid Build Coastguard Worker raise RuntimeError(warning_str) 893*523fa7a6SAndroid Build Coastguard Worker else: 894*523fa7a6SAndroid Build Coastguard Worker logging.warning(warning_str) 895*523fa7a6SAndroid Build Coastguard Worker 896*523fa7a6SAndroid Build Coastguard Worker 897*523fa7a6SAndroid Build Coastguard Workerdef _gen_edge_manager_for_partitioners( 898*523fa7a6SAndroid Build Coastguard Worker partitioner: 
Dict[str, List[Partitioner]], 899*523fa7a6SAndroid Build Coastguard Worker aten_programs: Dict[str, ExportedProgram], 900*523fa7a6SAndroid Build Coastguard Worker config: EdgeCompileConfig, 901*523fa7a6SAndroid Build Coastguard Worker constant_methods: Optional[Dict[str, Any]], 902*523fa7a6SAndroid Build Coastguard Worker) -> "EdgeProgramManager": 903*523fa7a6SAndroid Build Coastguard Worker """ 904*523fa7a6SAndroid Build Coastguard Worker Generates EdgeProgramManager for subsequent lowering to the 905*523fa7a6SAndroid Build Coastguard Worker partitioners specified by partitioner. The EdgeProgramManager is generated from 906*523fa7a6SAndroid Build Coastguard Worker aten_programs. 907*523fa7a6SAndroid Build Coastguard Worker 908*523fa7a6SAndroid Build Coastguard Worker Partitioners specify what nodes should not be decomposed from the original aten programs. 909*523fa7a6SAndroid Build Coastguard Worker This is done through two passes of run_decompositions. 910*523fa7a6SAndroid Build Coastguard Worker - First pass preserves all aten_targets specified by partitioners to preserve 911*523fa7a6SAndroid Build Coastguard Worker them from nested decompositions 912*523fa7a6SAndroid Build Coastguard Worker - Second pass uses check_op fn provided by partitioners to perform additional checks 913*523fa7a6SAndroid Build Coastguard Worker on nodes with preserved aten targets. 
They are then replaces with transformed ops to 914*523fa7a6SAndroid Build Coastguard Worker keep them through the second pass of decompositions 915*523fa7a6SAndroid Build Coastguard Worker """ 916*523fa7a6SAndroid Build Coastguard Worker ops_set_to_not_decompose_by_program = {} 917*523fa7a6SAndroid Build Coastguard Worker edge_programs: Dict[str, ExportedProgram] = {} 918*523fa7a6SAndroid Build Coastguard Worker for name, program in aten_programs.items(): 919*523fa7a6SAndroid Build Coastguard Worker if partitioner is not None: 920*523fa7a6SAndroid Build Coastguard Worker # preserve all ops listed by all partitioners first 921*523fa7a6SAndroid Build Coastguard Worker all_ops_no_decomp = set() 922*523fa7a6SAndroid Build Coastguard Worker for curr_partitioner in partitioner.get(name, []): 923*523fa7a6SAndroid Build Coastguard Worker curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program) 924*523fa7a6SAndroid Build Coastguard Worker all_ops_no_decomp |= set(curr_ops_no_decomp) 925*523fa7a6SAndroid Build Coastguard Worker 926*523fa7a6SAndroid Build Coastguard Worker table = _default_decomposition_table() 927*523fa7a6SAndroid Build Coastguard Worker 928*523fa7a6SAndroid Build Coastguard Worker for op in all_ops_no_decomp: 929*523fa7a6SAndroid Build Coastguard Worker table.pop(op, None) 930*523fa7a6SAndroid Build Coastguard Worker 931*523fa7a6SAndroid Build Coastguard Worker program = program.run_decompositions(table) 932*523fa7a6SAndroid Build Coastguard Worker # Among all the preserved aten ops, use the check_op_fn to do an additional 933*523fa7a6SAndroid Build Coastguard Worker # check on which ops need to be preserved and which ops need to be decomposed 934*523fa7a6SAndroid Build Coastguard Worker # Those which are truly preserved will be replaced with transformed ops 935*523fa7a6SAndroid Build Coastguard Worker ops_set_to_not_decompose_by_program[name] = ( 936*523fa7a6SAndroid Build Coastguard Worker _replace_aten_ops_with_transformed_ops(name, 
program, partitioner) or [] 937*523fa7a6SAndroid Build Coastguard Worker ) 938*523fa7a6SAndroid Build Coastguard Worker program = program.run_decompositions(_default_decomposition_table()) 939*523fa7a6SAndroid Build Coastguard Worker 940*523fa7a6SAndroid Build Coastguard Worker _restore_transformed_ops_to_aten_ops(program) 941*523fa7a6SAndroid Build Coastguard Worker 942*523fa7a6SAndroid Build Coastguard Worker edge_programs[name] = program 943*523fa7a6SAndroid Build Coastguard Worker 944*523fa7a6SAndroid Build Coastguard Worker edge_programs[name] = _generate_edge_program( 945*523fa7a6SAndroid Build Coastguard Worker name, 946*523fa7a6SAndroid Build Coastguard Worker config, 947*523fa7a6SAndroid Build Coastguard Worker program, 948*523fa7a6SAndroid Build Coastguard Worker list(ops_set_to_not_decompose_by_program.get(name, [])), 949*523fa7a6SAndroid Build Coastguard Worker ) 950*523fa7a6SAndroid Build Coastguard Worker 951*523fa7a6SAndroid Build Coastguard Worker edge_manager = EdgeProgramManager( 952*523fa7a6SAndroid Build Coastguard Worker edge_programs, 953*523fa7a6SAndroid Build Coastguard Worker constant_methods, 954*523fa7a6SAndroid Build Coastguard Worker config, 955*523fa7a6SAndroid Build Coastguard Worker list(set().union(*ops_set_to_not_decompose_by_program.values())), 956*523fa7a6SAndroid Build Coastguard Worker ) 957*523fa7a6SAndroid Build Coastguard Worker return edge_manager 958*523fa7a6SAndroid Build Coastguard Worker 959*523fa7a6SAndroid Build Coastguard Worker 960*523fa7a6SAndroid Build Coastguard Workerdef to_edge_transform_and_lower( 961*523fa7a6SAndroid Build Coastguard Worker programs: Union[ExportedProgram, Dict[str, ExportedProgram]], 962*523fa7a6SAndroid Build Coastguard Worker transform_passes: Optional[ 963*523fa7a6SAndroid Build Coastguard Worker Union[Sequence[PassType], Dict[str, Sequence[PassType]]] 964*523fa7a6SAndroid Build Coastguard Worker ] = None, 965*523fa7a6SAndroid Build Coastguard Worker partitioner: Optional[ 
966*523fa7a6SAndroid Build Coastguard Worker Union[List[Partitioner], Dict[str, List[Partitioner]]] 967*523fa7a6SAndroid Build Coastguard Worker ] = None, 968*523fa7a6SAndroid Build Coastguard Worker constant_methods: Optional[Dict[str, Any]] = None, 969*523fa7a6SAndroid Build Coastguard Worker compile_config: Optional[EdgeCompileConfig] = None, 970*523fa7a6SAndroid Build Coastguard Worker) -> "EdgeProgramManager": 971*523fa7a6SAndroid Build Coastguard Worker """ 972*523fa7a6SAndroid Build Coastguard Worker :func:`to_edge_transform_and_lower` constructs an EdgeProgramManager from a set of 973*523fa7a6SAndroid Build Coastguard Worker exported programs in ATen dialect. It differs fundamentally from to_edge in that it 974*523fa7a6SAndroid Build Coastguard Worker combines the conversion of the ATen dialect to the edge dialect program, then running 975*523fa7a6SAndroid Build Coastguard Worker the transformation passes and then subsequently lowering the programs to their 976*523fa7a6SAndroid Build Coastguard Worker corresponding backends all into a single API. 977*523fa7a6SAndroid Build Coastguard Worker 978*523fa7a6SAndroid Build Coastguard Worker This is fundamentally useful for lowering to backends that have ops registered that they 979*523fa7a6SAndroid Build Coastguard Worker do not want to be decomposed and thus rely on matching with these non-decomposed ops. For 980*523fa7a6SAndroid Build Coastguard Worker these sorts of backends this is the *only* API that should be used to lower to the edge 981*523fa7a6SAndroid Build Coastguard Worker dialect. Using a combination of to_edge(...) and to_backend(...) will result in inconsistent 982*523fa7a6SAndroid Build Coastguard Worker or wrong behavior. 983*523fa7a6SAndroid Build Coastguard Worker 984*523fa7a6SAndroid Build Coastguard Worker This API is the primary recommended way to lower to the CPU based XNNPack backend. 
985*523fa7a6SAndroid Build Coastguard Worker 986*523fa7a6SAndroid Build Coastguard Worker Args: 987*523fa7a6SAndroid Build Coastguard Worker programs: Can be a single ExportedProgram or a dictionary mapping function names 988*523fa7a6SAndroid Build Coastguard Worker to their corresponding ExportedPrograms. If only a single ExportedProgram is 989*523fa7a6SAndroid Build Coastguard Worker provided it will be assigned the name "forward". 990*523fa7a6SAndroid Build Coastguard Worker 991*523fa7a6SAndroid Build Coastguard Worker transform_passes: The passes can either be a list of passes, or a dictionary 992*523fa7a6SAndroid Build Coastguard Worker mapping method names to lists of passes. If it is just a list of passes, all methods 993*523fa7a6SAndroid Build Coastguard Worker in the given EdgeProgramManager will be transformed with the provided passes. If it 994*523fa7a6SAndroid Build Coastguard Worker is a dictionary, only method names specified in the dictionary will be transformed 995*523fa7a6SAndroid Build Coastguard Worker with their corresponding passes. 996*523fa7a6SAndroid Build Coastguard Worker 997*523fa7a6SAndroid Build Coastguard Worker partitioner: The partitioner can either be a Partitioner subclass instance, or a 998*523fa7a6SAndroid Build Coastguard Worker dictionary mapping method names to Partitioner subclass instance. If it is a 999*523fa7a6SAndroid Build Coastguard Worker Partitioner subclass, all programs in the given EdgeProgramManager will be lowered 1000*523fa7a6SAndroid Build Coastguard Worker using the given partitioner. If it is a dictionary, only method names specified in 1001*523fa7a6SAndroid Build Coastguard Worker the dictionary will be lowered with the given partitioner. 1002*523fa7a6SAndroid Build Coastguard Worker 1003*523fa7a6SAndroid Build Coastguard Worker constant_methods: An optional dictionary of method name to the constant value 1004*523fa7a6SAndroid Build Coastguard Worker returned by that method in eager mode. 
Often used to store config information on 1005*523fa7a6SAndroid Build Coastguard Worker Edge models. 1006*523fa7a6SAndroid Build Coastguard Worker 1007*523fa7a6SAndroid Build Coastguard Worker compile_config: An optional argument used to provide greater control over the 1008*523fa7a6SAndroid Build Coastguard Worker transformation to edge dialect process. 1009*523fa7a6SAndroid Build Coastguard Worker 1010*523fa7a6SAndroid Build Coastguard Worker Returns: 1011*523fa7a6SAndroid Build Coastguard Worker EdgeProgramManager 1012*523fa7a6SAndroid Build Coastguard Worker """ 1013*523fa7a6SAndroid Build Coastguard Worker assert not isinstance(constant_methods, EdgeCompileConfig) 1014*523fa7a6SAndroid Build Coastguard Worker config = compile_config or EdgeCompileConfig() 1015*523fa7a6SAndroid Build Coastguard Worker if not isinstance(programs, dict): 1016*523fa7a6SAndroid Build Coastguard Worker aten_programs = {"forward": programs} 1017*523fa7a6SAndroid Build Coastguard Worker else: 1018*523fa7a6SAndroid Build Coastguard Worker aten_programs = programs 1019*523fa7a6SAndroid Build Coastguard Worker 1020*523fa7a6SAndroid Build Coastguard Worker if not isinstance(partitioner, dict) and partitioner is not None: 1021*523fa7a6SAndroid Build Coastguard Worker partitioner = {name: partitioner for name in aten_programs.keys()} 1022*523fa7a6SAndroid Build Coastguard Worker elif partitioner is None: 1023*523fa7a6SAndroid Build Coastguard Worker partitioner = {name: [] for name in aten_programs.keys()} 1024*523fa7a6SAndroid Build Coastguard Worker 1025*523fa7a6SAndroid Build Coastguard Worker edge_manager = _gen_edge_manager_for_partitioners( 1026*523fa7a6SAndroid Build Coastguard Worker partitioner, aten_programs, config, constant_methods 1027*523fa7a6SAndroid Build Coastguard Worker ) 1028*523fa7a6SAndroid Build Coastguard Worker 1029*523fa7a6SAndroid Build Coastguard Worker if transform_passes is not None: 1030*523fa7a6SAndroid Build Coastguard Worker edge_manager = 
edge_manager.transform(transform_passes) 1031*523fa7a6SAndroid Build Coastguard Worker 1032*523fa7a6SAndroid Build Coastguard Worker if partitioner is not None: 1033*523fa7a6SAndroid Build Coastguard Worker for name, partitioner_list in partitioner.items(): 1034*523fa7a6SAndroid Build Coastguard Worker for curr_partitioner in partitioner_list: 1035*523fa7a6SAndroid Build Coastguard Worker edge_manager = edge_manager.to_backend({name: curr_partitioner}) 1036*523fa7a6SAndroid Build Coastguard Worker 1037*523fa7a6SAndroid Build Coastguard Worker for name, program in edge_manager._edge_programs.items(): 1038*523fa7a6SAndroid Build Coastguard Worker ops_set_to_not_decompose: Set[torch._ops.OpOverload] = set() 1039*523fa7a6SAndroid Build Coastguard Worker partitioners = partitioner.get(name, []) 1040*523fa7a6SAndroid Build Coastguard Worker for curr_partitioner in partitioners: 1041*523fa7a6SAndroid Build Coastguard Worker curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose( 1042*523fa7a6SAndroid Build Coastguard Worker program 1043*523fa7a6SAndroid Build Coastguard Worker ) 1044*523fa7a6SAndroid Build Coastguard Worker ops_set_to_not_decompose = ops_set_to_not_decompose.union(curr_op_set) 1045*523fa7a6SAndroid Build Coastguard Worker _sanity_check_graph_for_non_decomp_ops( 1046*523fa7a6SAndroid Build Coastguard Worker name, 1047*523fa7a6SAndroid Build Coastguard Worker program, 1048*523fa7a6SAndroid Build Coastguard Worker ops_set_to_not_decompose, 1049*523fa7a6SAndroid Build Coastguard Worker check_op_support, 1050*523fa7a6SAndroid Build Coastguard Worker partitioner_name=curr_partitioner.__class__.__name__, 1051*523fa7a6SAndroid Build Coastguard Worker generate_error=True, 1052*523fa7a6SAndroid Build Coastguard Worker ) 1053*523fa7a6SAndroid Build Coastguard Worker 1054*523fa7a6SAndroid Build Coastguard Worker if config._check_ir_validity: 1055*523fa7a6SAndroid Build Coastguard Worker EXIREdgeDialectVerifier( 1056*523fa7a6SAndroid Build Coastguard 
Worker edge_compile_config=config, 1057*523fa7a6SAndroid Build Coastguard Worker class_only=True, 1058*523fa7a6SAndroid Build Coastguard Worker exception_list=list(ops_set_to_not_decompose), 1059*523fa7a6SAndroid Build Coastguard Worker )()(program.graph_module) 1060*523fa7a6SAndroid Build Coastguard Worker 1061*523fa7a6SAndroid Build Coastguard Worker return edge_manager 1062*523fa7a6SAndroid Build Coastguard Worker 1063*523fa7a6SAndroid Build Coastguard Worker 1064*523fa7a6SAndroid Build Coastguard Worker@experimental( 1065*523fa7a6SAndroid Build Coastguard Worker """ 1066*523fa7a6SAndroid Build Coastguard Worker This is an experimental API which overloads to_edge by preserving specified ops to not be decomposed. 1067*523fa7a6SAndroid Build Coastguard Worker This function will be combined with to_edge in the future. 1068*523fa7a6SAndroid Build Coastguard Worker """ 1069*523fa7a6SAndroid Build Coastguard Worker) 1070*523fa7a6SAndroid Build Coastguard Workerdef to_edge_with_preserved_ops( 1071*523fa7a6SAndroid Build Coastguard Worker programs: Union[ExportedProgram, Dict[str, ExportedProgram]], 1072*523fa7a6SAndroid Build Coastguard Worker constant_methods: Optional[Dict[str, Any]] = None, 1073*523fa7a6SAndroid Build Coastguard Worker compile_config: Optional[EdgeCompileConfig] = None, 1074*523fa7a6SAndroid Build Coastguard Worker preserve_ops: Tuple[torch._ops.OpOverload, ...] = (), 1075*523fa7a6SAndroid Build Coastguard Worker) -> "EdgeProgramManager": 1076*523fa7a6SAndroid Build Coastguard Worker """ 1077*523fa7a6SAndroid Build Coastguard Worker :func:`to_edge` constructs an EdgeProgramManager from a set of exported programs in 1078*523fa7a6SAndroid Build Coastguard Worker ATen dialect. Upon construction those programs are transformed into edge dialect. 
1079*523fa7a6SAndroid Build Coastguard Worker 1080*523fa7a6SAndroid Build Coastguard Worker Args: 1081*523fa7a6SAndroid Build Coastguard Worker programs: Can be a single ExportedProgram or a dictionary mapping function names to their corresponding ExportedPrograms. If only a single ExportedProgram is provided it will be assigned the name "forward". 1082*523fa7a6SAndroid Build Coastguard Worker constant_methods: An optional dictionary of method name to the constant value returned by that method in eager mode. Often used to store config information on Edge models. 1083*523fa7a6SAndroid Build Coastguard Worker compile_config: An optional argument used to provide greater control over the transformation to edge dialect process. 1084*523fa7a6SAndroid Build Coastguard Worker preserve_ops: An argument used to specify ops that should not be decomposed. 1085*523fa7a6SAndroid Build Coastguard Worker 1086*523fa7a6SAndroid Build Coastguard Worker Returns: 1087*523fa7a6SAndroid Build Coastguard Worker EdgeProgramManager 1088*523fa7a6SAndroid Build Coastguard Worker """ 1089*523fa7a6SAndroid Build Coastguard Worker assert not isinstance(constant_methods, EdgeCompileConfig) 1090*523fa7a6SAndroid Build Coastguard Worker config = compile_config or EdgeCompileConfig() 1091*523fa7a6SAndroid Build Coastguard Worker if not isinstance(programs, dict): 1092*523fa7a6SAndroid Build Coastguard Worker aten_programs = {"forward": programs} 1093*523fa7a6SAndroid Build Coastguard Worker else: 1094*523fa7a6SAndroid Build Coastguard Worker aten_programs = programs 1095*523fa7a6SAndroid Build Coastguard Worker 1096*523fa7a6SAndroid Build Coastguard Worker edge_programs: Dict[str, ExportedProgram] = {} 1097*523fa7a6SAndroid Build Coastguard Worker 1098*523fa7a6SAndroid Build Coastguard Worker for name, program in aten_programs.items(): 1099*523fa7a6SAndroid Build Coastguard Worker # Decompose to Core ATen 1100*523fa7a6SAndroid Build Coastguard Worker table = _default_decomposition_table() 
1101*523fa7a6SAndroid Build Coastguard Worker for op in preserve_ops: 1102*523fa7a6SAndroid Build Coastguard Worker table.pop(op, None) 1103*523fa7a6SAndroid Build Coastguard Worker program = program.run_decompositions(table) 1104*523fa7a6SAndroid Build Coastguard Worker edge_programs[name] = _generate_edge_program( 1105*523fa7a6SAndroid Build Coastguard Worker name, config, program, list(preserve_ops) 1106*523fa7a6SAndroid Build Coastguard Worker ) 1107*523fa7a6SAndroid Build Coastguard Worker 1108*523fa7a6SAndroid Build Coastguard Worker return EdgeProgramManager( 1109*523fa7a6SAndroid Build Coastguard Worker edge_programs, constant_methods, config, list(preserve_ops) 1110*523fa7a6SAndroid Build Coastguard Worker ) 1111*523fa7a6SAndroid Build Coastguard Worker 1112*523fa7a6SAndroid Build Coastguard Worker 1113*523fa7a6SAndroid Build Coastguard Workerdef to_edge( 1114*523fa7a6SAndroid Build Coastguard Worker programs: Union[ExportedProgram, Dict[str, ExportedProgram]], 1115*523fa7a6SAndroid Build Coastguard Worker constant_methods: Optional[Dict[str, Any]] = None, 1116*523fa7a6SAndroid Build Coastguard Worker compile_config: Optional[EdgeCompileConfig] = None, 1117*523fa7a6SAndroid Build Coastguard Worker) -> "EdgeProgramManager": 1118*523fa7a6SAndroid Build Coastguard Worker """ 1119*523fa7a6SAndroid Build Coastguard Worker :func:`to_edge` constructs an EdgeProgramManager from a set of exported programs in 1120*523fa7a6SAndroid Build Coastguard Worker ATen dialect. Upon construction those programs are transformed into edge dialect. 1121*523fa7a6SAndroid Build Coastguard Worker 1122*523fa7a6SAndroid Build Coastguard Worker Args: 1123*523fa7a6SAndroid Build Coastguard Worker programs: Can be a single ExportedProgram or a dictionary mapping function names to their corresponding ExportedPrograms. If only a single ExportedProgram is provided it will be assigned the name "forward". 
1124*523fa7a6SAndroid Build Coastguard Worker 1125*523fa7a6SAndroid Build Coastguard Worker constant_methods: An optional dictionary of method name to the constant value returned by that method in eager mode. Often used to store config information on Edge models. 1126*523fa7a6SAndroid Build Coastguard Worker 1127*523fa7a6SAndroid Build Coastguard Worker compile_config: An optional argument used to provide greater control over the transformation to edge dialect process. 1128*523fa7a6SAndroid Build Coastguard Worker 1129*523fa7a6SAndroid Build Coastguard Worker Returns: 1130*523fa7a6SAndroid Build Coastguard Worker EdgeProgramManager 1131*523fa7a6SAndroid Build Coastguard Worker """ 1132*523fa7a6SAndroid Build Coastguard Worker assert not isinstance(constant_methods, EdgeCompileConfig) 1133*523fa7a6SAndroid Build Coastguard Worker config = compile_config or EdgeCompileConfig() 1134*523fa7a6SAndroid Build Coastguard Worker if not isinstance(programs, dict): 1135*523fa7a6SAndroid Build Coastguard Worker aten_programs = {"forward": programs} 1136*523fa7a6SAndroid Build Coastguard Worker else: 1137*523fa7a6SAndroid Build Coastguard Worker aten_programs = programs 1138*523fa7a6SAndroid Build Coastguard Worker 1139*523fa7a6SAndroid Build Coastguard Worker edge_programs: Dict[str, ExportedProgram] = {} 1140*523fa7a6SAndroid Build Coastguard Worker 1141*523fa7a6SAndroid Build Coastguard Worker for name, program in aten_programs.items(): 1142*523fa7a6SAndroid Build Coastguard Worker # Decompose to Core ATen 1143*523fa7a6SAndroid Build Coastguard Worker program = program.run_decompositions(_default_decomposition_table()) 1144*523fa7a6SAndroid Build Coastguard Worker edge_programs[name] = _generate_edge_program(name, config, program) 1145*523fa7a6SAndroid Build Coastguard Worker 1146*523fa7a6SAndroid Build Coastguard Worker return EdgeProgramManager(edge_programs, constant_methods, config) 1147*523fa7a6SAndroid Build Coastguard Worker 1148*523fa7a6SAndroid Build Coastguard 
Worker 1149*523fa7a6SAndroid Build Coastguard Workerclass EdgeProgramManager: 1150*523fa7a6SAndroid Build Coastguard Worker """ 1151*523fa7a6SAndroid Build Coastguard Worker Package of one or more `ExportedPrograms` in Edge dialect. Designed to simplify 1152*523fa7a6SAndroid Build Coastguard Worker lowering to ExecuTorch. See: https://pytorch.org/executorch/stable/ir-exir.html 1153*523fa7a6SAndroid Build Coastguard Worker 1154*523fa7a6SAndroid Build Coastguard Worker Allows easy applications of transforms across a collection of exported programs 1155*523fa7a6SAndroid Build Coastguard Worker including the delegation of subgraphs. 1156*523fa7a6SAndroid Build Coastguard Worker 1157*523fa7a6SAndroid Build Coastguard Worker Manages the second link in the lowering chain of ATen -> Edge -> ExecuTorch. 1158*523fa7a6SAndroid Build Coastguard Worker """ 1159*523fa7a6SAndroid Build Coastguard Worker 1160*523fa7a6SAndroid Build Coastguard Worker def __init__( 1161*523fa7a6SAndroid Build Coastguard Worker self, 1162*523fa7a6SAndroid Build Coastguard Worker edge_programs: Union[ExportedProgram, Dict[str, ExportedProgram]], 1163*523fa7a6SAndroid Build Coastguard Worker constant_methods: Optional[Dict[str, Any]] = None, 1164*523fa7a6SAndroid Build Coastguard Worker compile_config: Optional[EdgeCompileConfig] = None, 1165*523fa7a6SAndroid Build Coastguard Worker ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None, 1166*523fa7a6SAndroid Build Coastguard Worker ): 1167*523fa7a6SAndroid Build Coastguard Worker """ 1168*523fa7a6SAndroid Build Coastguard Worker Should not be called directly by users. User should use :func:'to_edge' instead. 1169*523fa7a6SAndroid Build Coastguard Worker 1170*523fa7a6SAndroid Build Coastguard Worker Constructs an EdgeProgramManager from an existing set of exported programs in edge dialect. 
1171*523fa7a6SAndroid Build Coastguard Worker """ 1172*523fa7a6SAndroid Build Coastguard Worker self.compile_config = compile_config or EdgeCompileConfig() 1173*523fa7a6SAndroid Build Coastguard Worker if not isinstance(edge_programs, dict): 1174*523fa7a6SAndroid Build Coastguard Worker edge_programs = {"forward": edge_programs} 1175*523fa7a6SAndroid Build Coastguard Worker 1176*523fa7a6SAndroid Build Coastguard Worker for name, program in edge_programs.items(): 1177*523fa7a6SAndroid Build Coastguard Worker try: 1178*523fa7a6SAndroid Build Coastguard Worker EXIREdgeDialectVerifier( 1179*523fa7a6SAndroid Build Coastguard Worker edge_compile_config=self.compile_config, 1180*523fa7a6SAndroid Build Coastguard Worker exception_list=ops_set_to_not_decompose, 1181*523fa7a6SAndroid Build Coastguard Worker )(program.graph_module) 1182*523fa7a6SAndroid Build Coastguard Worker except ExportError as e: 1183*523fa7a6SAndroid Build Coastguard Worker logging.info(f"Input program {name} is not in aten dialect.") 1184*523fa7a6SAndroid Build Coastguard Worker raise e 1185*523fa7a6SAndroid Build Coastguard Worker 1186*523fa7a6SAndroid Build Coastguard Worker self._edge_programs: Dict[str, ExportedProgram] = edge_programs 1187*523fa7a6SAndroid Build Coastguard Worker self._config_methods = constant_methods 1188*523fa7a6SAndroid Build Coastguard Worker 1189*523fa7a6SAndroid Build Coastguard Worker @property 1190*523fa7a6SAndroid Build Coastguard Worker def methods(self) -> Set[str]: 1191*523fa7a6SAndroid Build Coastguard Worker """ 1192*523fa7a6SAndroid Build Coastguard Worker Returns the set of methods in this EdgeProgramManager. 
1193*523fa7a6SAndroid Build Coastguard Worker """ 1194*523fa7a6SAndroid Build Coastguard Worker return set(self._edge_programs.keys()) 1195*523fa7a6SAndroid Build Coastguard Worker 1196*523fa7a6SAndroid Build Coastguard Worker @property 1197*523fa7a6SAndroid Build Coastguard Worker def config_methods(self) -> Set[str]: 1198*523fa7a6SAndroid Build Coastguard Worker """ 1199*523fa7a6SAndroid Build Coastguard Worker Returns the set of config methods in this EdgeProgramManager. 1200*523fa7a6SAndroid Build Coastguard Worker """ 1201*523fa7a6SAndroid Build Coastguard Worker return set(self._config_methods.keys()) if self._config_methods else set() 1202*523fa7a6SAndroid Build Coastguard Worker 1203*523fa7a6SAndroid Build Coastguard Worker def exported_program(self, method_name: str = "forward") -> ExportedProgram: 1204*523fa7a6SAndroid Build Coastguard Worker """ 1205*523fa7a6SAndroid Build Coastguard Worker Returns the ExportedProgram specified by 'method_name'. 1206*523fa7a6SAndroid Build Coastguard Worker """ 1207*523fa7a6SAndroid Build Coastguard Worker return self._edge_programs[method_name] 1208*523fa7a6SAndroid Build Coastguard Worker 1209*523fa7a6SAndroid Build Coastguard Worker def transform( 1210*523fa7a6SAndroid Build Coastguard Worker self, 1211*523fa7a6SAndroid Build Coastguard Worker passes: Union[Sequence[PassType], Dict[str, Sequence[PassType]]], 1212*523fa7a6SAndroid Build Coastguard Worker compile_config: Optional[EdgeCompileConfig] = None, 1213*523fa7a6SAndroid Build Coastguard Worker ) -> "EdgeProgramManager": 1214*523fa7a6SAndroid Build Coastguard Worker """ 1215*523fa7a6SAndroid Build Coastguard Worker Transforms the program according to the provided passes. 1216*523fa7a6SAndroid Build Coastguard Worker 1217*523fa7a6SAndroid Build Coastguard Worker Args: 1218*523fa7a6SAndroid Build Coastguard Worker passes: The passes can either be a list of passes, or a 1219*523fa7a6SAndroid Build Coastguard Worker dictionary mapping method names to lists of passes. 
If it is 1220*523fa7a6SAndroid Build Coastguard Worker just a list of passes, all methods in the given EdgeProgramManager 1221*523fa7a6SAndroid Build Coastguard Worker will be transformed with the provided passes. If it is a 1222*523fa7a6SAndroid Build Coastguard Worker dictionary, only method names specified in the dictionary will be 1223*523fa7a6SAndroid Build Coastguard Worker transformed with their corresponding passes. 1224*523fa7a6SAndroid Build Coastguard Worker compile_config: Compile config to use for veriy the correctness of model 1225*523fa7a6SAndroid Build Coastguard Worker graph after each pass. If not specified, the compile config of the 1226*523fa7a6SAndroid Build Coastguard Worker calling EdgeProgramManager will be used. It will be used in as compile 1227*523fa7a6SAndroid Build Coastguard Worker config of returned EdgeProgramManager. 1228*523fa7a6SAndroid Build Coastguard Worker 1229*523fa7a6SAndroid Build Coastguard Worker Returns: 1230*523fa7a6SAndroid Build Coastguard Worker EdgeProgramManager: A copy of the calling EdgeProgramManager with the 1231*523fa7a6SAndroid Build Coastguard Worker transformations applied. 
1232*523fa7a6SAndroid Build Coastguard Worker """ 1233*523fa7a6SAndroid Build Coastguard Worker compile_config = compile_config or self.compile_config 1234*523fa7a6SAndroid Build Coastguard Worker new_programs: Dict[str, ExportedProgram] = {} 1235*523fa7a6SAndroid Build Coastguard Worker if isinstance(passes, dict): 1236*523fa7a6SAndroid Build Coastguard Worker for name, program in self._edge_programs.items(): 1237*523fa7a6SAndroid Build Coastguard Worker if name in passes.keys(): 1238*523fa7a6SAndroid Build Coastguard Worker new_programs[name] = _transform(program, *passes[name]) 1239*523fa7a6SAndroid Build Coastguard Worker EXIREdgeDialectVerifier(edge_compile_config=compile_config)( 1240*523fa7a6SAndroid Build Coastguard Worker new_programs[name].graph_module 1241*523fa7a6SAndroid Build Coastguard Worker ) 1242*523fa7a6SAndroid Build Coastguard Worker else: 1243*523fa7a6SAndroid Build Coastguard Worker new_programs[name] = copy.deepcopy(program) 1244*523fa7a6SAndroid Build Coastguard Worker 1245*523fa7a6SAndroid Build Coastguard Worker else: # apply passes to every method 1246*523fa7a6SAndroid Build Coastguard Worker for name, program in self._edge_programs.items(): 1247*523fa7a6SAndroid Build Coastguard Worker new_programs[name] = _transform(program, *passes) 1248*523fa7a6SAndroid Build Coastguard Worker EXIREdgeDialectVerifier(edge_compile_config=compile_config)( 1249*523fa7a6SAndroid Build Coastguard Worker new_programs[name].graph_module 1250*523fa7a6SAndroid Build Coastguard Worker ) 1251*523fa7a6SAndroid Build Coastguard Worker 1252*523fa7a6SAndroid Build Coastguard Worker return EdgeProgramManager( 1253*523fa7a6SAndroid Build Coastguard Worker new_programs, copy.deepcopy(self._config_methods), compile_config 1254*523fa7a6SAndroid Build Coastguard Worker ) 1255*523fa7a6SAndroid Build Coastguard Worker 1256*523fa7a6SAndroid Build Coastguard Worker def to_backend( 1257*523fa7a6SAndroid Build Coastguard Worker self, partitioner: Union[Partitioner, Dict[str, 
def to_backend(
    self, partitioner: Union[Partitioner, Dict[str, Partitioner]]
) -> "EdgeProgramManager":
    """
    Delegate portions of the managed edge programs to backends.

    Each selected program is handed to the module-level ``to_backend``
    together with its partitioner; the partitioner decides which parts of
    the program are tagged for delegation.

    Args:
        partitioner: Either a single Partitioner instance, which is applied
            to every method in this manager, or a dictionary mapping method
            names to Partitioner instances. With a dictionary, only the
            methods named in it are lowered; every other method is carried
            over as the very same program object (it is not copied).

            A valid partitioner must return a PartitionerResult containing
            partition_tags: Dict[str, DelegationSpec]. Each key is a tag
            name; nodes sharing a tag are fused into one subgraph and
            delegated to the backend specified in the delegation spec.

    Returns:
        EdgeProgramManager: A new manager holding the lowered programs and
        a deep copy of this manager's config methods. It is created with
        ``EdgeCompileConfig(_check_ir_validity=False)``.
    """
    per_method = isinstance(partitioner, dict)

    lowered: Dict[str, ExportedProgram] = {}
    for method_name, edge_program in self._edge_programs.items():
        if not per_method:
            # A single partitioner applies to every method.
            lowered[method_name] = to_backend(edge_program, partitioner)
        elif method_name in partitioner:
            lowered[method_name] = to_backend(edge_program, partitioner[method_name])
        else:
            # Methods without a partitioner pass through untouched.
            lowered[method_name] = edge_program

    return EdgeProgramManager(
        lowered,
        copy.deepcopy(self._config_methods),
        EdgeCompileConfig(_check_ir_validity=False),
    )
def to_executorch(
    self,
    config: Optional[ExecutorchBackendConfig] = None,
) -> "ExecutorchProgramManager":
    """
    Transforms every managed edge program to the ExecuTorch backend.

    For each method the program goes through, in this order:
    ``weights_to_outputs_pass``, ``unsafe_remove_auto_functionalized_pass``,
    ``insert_write_back_for_buffers_pass``, the passes returned by
    ``edge_to_executorch_passes``, and finally the memory planning pass.

    Args:
        config: An optional argument used to provide greater control over
            the transformation. A falsy value is replaced by a
            default-constructed ExecutorchBackendConfig.

    Returns:
        ExecutorchProgramManager: A manager built from the transformed
        programs, this manager's config methods (handed over without
        copying) and the effective backend config.
    """
    backend_config = config or ExecutorchBackendConfig()

    lowered: Dict[str, ExportedProgram] = {}
    for method_name, edge_program in self._edge_programs.items():
        program = weights_to_outputs_pass(edge_program)
        program = unsafe_remove_auto_functionalized_pass(program)
        # Only the updated signature is needed below; the graph module is
        # read back from `program` itself.
        _, new_signature = insert_write_back_for_buffers_pass(program)

        graph_module = program.graph_module
        for lowering_pass in edge_to_executorch_passes(backend_config, method_name):
            pass_result = lowering_pass(graph_module)
            assert pass_result is not None
            graph_module = pass_result.graph_module
            if isinstance(lowering_pass, SpecPropPass):
                # This is a hacky workaround: placeholder nodes that
                # correspond to parameters of the graph module shall not
                # participate in memory planning, since that increases the
                # runtime memory footprint.
                # The proper way would be for ExportPass to work with
                # ExportedProgram instead of GraphModule, because
                # ExportPass should operate on the export artifact of
                # torch.export, which is ExportedProgram. A GraphModule
                # does not carry all the information contained in the
                # ExportedProgram.
                # TODO(who?)
                lowering_pass.update_placeholder_tensor_specs(program, graph_module)

        # The config holds either one planner for all methods or a
        # per-method dictionary; methods missing from the dictionary get
        # the planner of a default ExecutorchBackendConfig.
        planner = backend_config.memory_planning_pass
        if isinstance(planner, dict):
            planner = planner.get(
                method_name, ExecutorchBackendConfig().memory_planning_pass
            )

        # TODO(jakeszwe): Follow up with compiler on if the deepcopy is necessary and if so how to make it work
        if hasattr(planner, "run"):
            pass_result = planner.run(  # pyre-ignore[16]
                graph_module, new_signature
            )
        else:
            pass_result = planner(graph_module)  # pyre-ignore[29]
        assert pass_result is not None
        graph_module = pass_result.graph_module

        _copy_module(program.graph_module, graph_module)
        lowered[method_name] = program

    return ExecutorchProgramManager(lowered, self._config_methods, backend_config)
class ExecutorchProgramManager:
    """
    Package of one or more `ExportedPrograms` in Execution dialect. Designed to simplify
    lowering to ExecuTorch. See: https://pytorch.org/executorch/stable/ir-exir.html

    On construction, the ExportedPrograms in execution dialect are emitted into the
    ExecuTorch program (a process called emission) and the result is serialized, ready
    to be fetched as bytes or written to a file.

    Manages the final link in the lowering chain of ATen -> Edge -> ExecuTorch.
    """

    def __init__(
        self,
        execution_programs: Dict[str, ExportedProgram],
        config_methods: Optional[Dict[str, Any]] = None,
        backend_config: Optional[ExecutorchBackendConfig] = None,
    ):
        """
        End users should not call this constructor directly. Instead, they should use
        :func:`to_executorch` to construct an ExecutorchProgramManager.

        Constructs an ExecutorchProgramManager from a set of exported programs in
        execution dialect.

        Args:
            execution_programs: A dictionary of method name to the corresponding
                ExportedProgram.

            config_methods: A dictionary of method name to the config value returned
                by that method in eager mode.

            backend_config: An optional argument used to provide greater control over
                the emission and serialization. A falsy value is replaced by a
                default-constructed ExecutorchBackendConfig.
        """
        # Methods and config methods are stored as given (no copies are made).
        self._execution_programs: Dict[str, ExportedProgram] = execution_programs
        self._config_methods: Optional[Dict[str, Any]] = config_methods

        effective_config = backend_config if backend_config else ExecutorchBackendConfig()

        # Emission: turn the execution-dialect programs into the emitter output.
        self._emitter_output: EmitterOutput = emit_program(
            self._execution_programs,
            effective_config.emit_stacktrace,
            self._config_methods,
        )

        # Serialization: the result is kept as a Cord; the contiguous `bytes`
        # form is only produced on demand by the `buffer` property.
        emitted = self._emitter_output
        self._pte_data: Cord = _serialize_pte_binary(
            program=emitted.program,
            mutable_data=emitted.mutable_data,
            extract_delegate_segments=effective_config.extract_delegate_segments,
            segment_alignment=effective_config.segment_alignment,
            constant_tensor_alignment=effective_config.constant_tensor_alignment,
            delegate_alignment=effective_config.delegate_alignment,
        )
        self._buffer: Optional[bytes] = None

    @property
    def methods(self) -> Set[str]:
        """
        Returns the names of the methods held by this ExecutorchProgramManager.
        """
        return set(self._execution_programs)

    @property
    def config_methods(self) -> Set[str]:
        """
        Returns the names of the config methods held by this ExecutorchProgramManager.

        The result is an empty set when no config methods were provided.
        """
        if not self._config_methods:
            return set()
        return set(self._config_methods)

    def exported_program(self, method_name: str = "forward") -> ExportedProgram:
        """
        Returns the ExportedProgram stored under 'method_name'.
        """
        return self._execution_programs[method_name]

    def dump_executorch_program(
        self, verbose: bool = False, out: Optional[TextIO] = None
    ) -> None:
        """
        Prints the ExecuTorch binary in a human readable format.

        Args:
            verbose (bool):
                If False prints the binary in a condensed format.
                If True prints the binary 1-1 with the specification in the schema.
            out:
                If None, prints to stdout.
                If non-None, writes the string to that stream object. It can be
                a file object, a StringIO object, or any other TextIO subclass.
        """
        printer = pretty_print if verbose else print_program
        printer(self._emitter_output.program, out=out)

    @property
    def debug_handle_map(self) -> Dict[int, Union[int, List[int]]]:
        """Returns the debug handle map recorded in the emitter output."""
        return self._emitter_output.debug_handle_map

    @property
    def delegate_map(
        self,
    ) -> Dict[str, Dict[int, Dict[str, Union[str, _DelegateDebugIdentifierMap]]]]:
        """Returns the per-method delegate debug identifier map from the emitter output."""
        return self._emitter_output.method_to_delegate_debug_id_map

    @property
    def executorch_program(self) -> Program:
        """
        Returns the object that represents the ExecuTorch binary before serialization.
        """
        return self._emitter_output.program

    @property
    def buffer(self) -> bytes:
        """Returns the serialized ExecuTorch binary as a byte string.

        The bytes are materialized on first access and cached on the instance.
        Note that this may allocate a very large amount of contiguous memory,
        depending on the model size. If writing to a file, use `write_to_file`
        which won't incur additional copies.
        """
        # TODO(T181494963): update pybinding to remove buffer cache, which can consume large
        # amounts of memory longer than necessary.
        cached = self._buffer
        if cached is None:
            cached = bytes(self._pte_data)
            self._buffer = cached
        return cached

    def write_to_file(self, open_file: io.BufferedIOBase) -> None:
        """
        Writes the serialized ExecuTorch binary to the file at `open_file`. Prefer to use this over
        `buffer`, as it writes to file without copying into a contiguous block of memory first,
        reducing the peak memory usage.
        """
        self._pte_data.write_to_file(open_file)