# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass
from typing import Dict, final, List

import torch

from executorch.backends.xnnpack._passes import XNNPACKPassManager
from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
from executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
    TagImplicitQDqPass,
)
from executorch.backends.xnnpack.operators.node_visitor import get_node_visitors

from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    ConstantDataOffset,
    XNNGraph,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_serialize import (
    serialize_xnnpack_binary,
)
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
from executorch.backends.xnnpack.utils.utils import is_param_node

from executorch.backends.xnnpack.utils.xnnpack_constants import (
    XNN_VALUE_FLAG_EXTERNAL_INPUT,
    XNN_VALUE_FLAG_EXTERNAL_OUTPUT,
)

from executorch.exir.backend.backend_details import (
    BackendDetails,
    CompileSpec,
    PreprocessResult,
)
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
from torch.export.exported_program import ExportedProgram

# Debug handle passed to a node visitor when the node's meta carries none.
DEFAULT_DEBUG_HANDLE = 65535

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


@dataclass
class ExternalMeta:
    """Bookkeeping for a graph value that crosses the delegate boundary.

    Attributes:
        external_id: Index assigned to the value; it is the number of entries
            already present in the external map when the value is added.
        io_type: ``XNN_VALUE_FLAG_EXTERNAL_INPUT`` or
            ``XNN_VALUE_FLAG_EXTERNAL_OUTPUT``.
    """

    external_id: int
    io_type: int


def generate_node_to_external_map(
    exported_program: ExportedProgram,
    edge_graph_module: torch.fx.GraphModule,
) -> Dict[torch.fx.Node, ExternalMeta]:
    """Assign an external id to every runtime input and output of the graph.

    Args:
        exported_program: Program used to decide whether a placeholder is a
            parameter/buffer (via ``is_param_node``).
        edge_graph_module: Graph module whose placeholder and output nodes are
            scanned.

    Returns:
        A mapping from node to its ``ExternalMeta``. Non-parameter
        placeholders are numbered first, in graph order, followed by the
        values referenced by the ``output`` node(s).
    """
    node_to_external_map = {}

    for node in edge_graph_module.graph.nodes:
        # The order in which we visit the placeholder node is same as the *args
        # order for the forward(*args) signature for this gm. Using the order of
        # the nodes as external_id to extract the right arg from *args at runtime
        #
        # Removing parameters/buffers since they will disappear from the signature
        # at runtime
        if node.op == "placeholder" and not is_param_node(exported_program, node):
            node_to_external_map[node] = ExternalMeta(
                external_id=len(node_to_external_map),
                io_type=XNN_VALUE_FLAG_EXTERNAL_INPUT,
            )

    for node in edge_graph_module.graph.nodes:
        if node.op != "output":
            continue
        # Each entry of the output node's args is iterated as a collection of
        # output values.
        for output_nodes in node.args:
            for output_node in output_nodes:
                # NOTE(review): if output_node is already a key (e.g. a
                # placeholder returned directly, or a value listed twice) its
                # entry is overwritten here - confirm this cannot happen for
                # delegated subgraphs.
                node_to_external_map[output_node] = ExternalMeta(
                    external_id=len(node_to_external_map),
                    io_type=XNN_VALUE_FLAG_EXTERNAL_OUTPUT,
                )

    return node_to_external_map


def assert_default_dim_order(edge_graph_module: torch.fx.GraphModule) -> None:
    """Check that every placeholder tensor uses the default dim order.

    Placeholders without a ``"val"`` entry in ``node.meta``, or whose value has
    no ``dim_order`` attribute, are skipped.

    Args:
        edge_graph_module: Graph module whose placeholder nodes are checked.

    Raises:
        RuntimeError: If a placeholder's ``dim_order()`` differs from
            ``(0, 1, ..., dim - 1)``.
    """
    for node in edge_graph_module.graph.nodes:
        if node.op != "placeholder":
            continue

        # We expect the default dim order for all tensor-like inputs i.e. inputs, buffers, and params
        t = node.meta.get("val", None)
        if t is None or getattr(t, "dim_order", None) is None:
            continue

        default_dim_order = tuple(range(t.dim()))
        actual_dim_order = t.dim_order()
        if actual_dim_order != default_dim_order:
            # The two literals are concatenated, so the first one must end with
            # a space to keep the sentences separated.
            raise RuntimeError(
                "XNNPACK backend only supports contiguous memory format for inputs. "
                f"Expecting dim_order: {default_dim_order}, but got {actual_dim_order} for a placeholder node {node}."
            )


@final
class XnnpackBackend(BackendDetails):
    """Backend that lowers an edge ``ExportedProgram`` to a serialized XNNPACK graph."""

    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        """Run the XNNPACK passes and serialize the resulting graph.

        Args:
            edge_program: Program to lower.
            compile_specs: Compile specs; a spec whose key is
                ``"dqlinear_partitioner"`` selects ``ConvertToLinearPass`` and
                ``TagImplicitQDqPass`` instead of the pass manager's default
                (``passes=None``).

        Returns:
            A ``PreprocessResult`` holding the serialized XNNPACK binary and an
            empty debug handle map.

        Raises:
            RuntimeError: If a node's op, or a ``call_function`` target, has no
                XNNPACK support, or if a placeholder does not use the default
                dim order.
        """
        xnnpack_edge_compile_config = get_xnnpack_edge_compile_config()

        # Need to wrap EP here because xnnpack does addmm to linear
        # transforms. This makes resulting graph not aten compliant
        # as aten.linear is not a core aten op.
        # Ideal fix would be to have XNNPACK verifier that bypass
        # most checks but the base Verifier itself has some strict changes
        # and to bypass those, we would basically copy what EdgeDialectVerifier
        # does. So for now instead of copy pasting that, just instantiate
        # EdgeDialectVerifier, but disable it.
        # TODO (task link) to implement NullVerifier or something similar
        ep = ExportedProgram(
            root=edge_program.graph_module,
            graph=edge_program.graph,
            graph_signature=edge_program.graph_signature,
            state_dict=edge_program.state_dict,
            range_constraints=edge_program.range_constraints,
            module_call_graph=edge_program.module_call_graph,
            example_inputs=edge_program.example_inputs,
            constants=edge_program.constants,
            verifiers=[
                EXIREdgeDialectVerifier(
                    edge_compile_config=xnnpack_edge_compile_config, class_only=True
                )
            ],
        )

        passes = []
        for spec in compile_specs:
            if spec.key == "dqlinear_partitioner":
                passes.append(ConvertToLinearPass)
                passes.append(TagImplicitQDqPass)

        # An empty selection is passed as None rather than as an empty list.
        passes = passes if passes else None
        # XNNPACK Delegate Specific Passes
        ep = XNNPACKPassManager(ep, passes=passes).transform()
        graph_module = ep.graph_module

        node_to_external_map = generate_node_to_external_map(ep, graph_module)

        # Make sure all inputs are contiguous_format or NCHW or default dim order
        assert_default_dim_order(graph_module)

        # TODO retrace the graph module to lift the new params may have
        # been added to the graph in passes

        vals_to_ids = {}
        xnnpack_graph = XNNGraph(
            version="0",
            xnodes=[],
            xvalues=[],
            num_externs=len(node_to_external_map),
            input_ids=[],
            output_ids=[],
            constant_data=[ConstantDataOffset(0, 0)],
        )

        constant_data_bytes = bytearray()
        node_visitors = get_node_visitors(ep, node_to_external_map, constant_data_bytes)

        for node in graph_module.graph.nodes:
            # These node kinds are not visited directly by this loop.
            if node.op in ("get_attr", "placeholder", "output"):
                continue
            if node.op != "call_function":
                raise RuntimeError(f"{node.op} is not supported in XNNPACK")

            target_name = node.target.__name__
            # Lazy %-style arguments: nothing is formatted unless INFO is enabled.
            logger.info("Visiting: %s, %s", node, target_name)
            if target_name not in node_visitors:
                raise RuntimeError(
                    f"For {node}, {node.op}:{target_name} is not supported in XNNPACK Delegate"
                )
            node_visitors[target_name].define_node(
                node,
                xnnpack_graph,
                vals_to_ids,
                node.meta.get("debug_handle", DEFAULT_DEBUG_HANDLE),
            )

        return PreprocessResult(
            processed_bytes=serialize_xnnpack_binary(
                xnnpack_graph, constant_data_bytes
            ),
            debug_handle_map={},
        )