# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
from collections import defaultdict
from typing import Any, Dict, List

import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager
import torch
from executorch.backends.qualcomm.builders import node_visitor
from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader
from executorch.backends.qualcomm.qnn_preprocess import QnnBackend
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER

from executorch.exir.backend.backend_details import CompileSpec
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
    generate_partitions_from_list_of_nodes,
)
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data
from torch.fx.passes.infra.partitioner import Partition
from torch.fx.passes.operator_support import OperatorSupportBase

from .common_defs import (
    allow_list_operator,
    not_supported_operator,
    to_be_implemented_operator,
)
from .utils import generate_qnn_executorch_option


class QnnOperatorSupport(OperatorSupportBase):
    """Decide, node by node, whether an FX node can be lowered to QNN.

    Nodes are first accepted or rejected using the static operator lists from
    ``common_defs`` and the user-provided skip sets. Any remaining node is
    converted into QNN op wrappers by its registered node visitor and then
    validated with ``QnnManager.IsNodeSupportedByBackend``.
    """

    def __init__(
        self,
        edge_program: torch.export.ExportedProgram,
        compiler_specs,
        skip_node_id_set: set = None,
        skip_node_op_set: set = None,
    ):
        """Build node visitors and an initialized QnnManager.

        Args:
            edge_program: program whose nodes will be queried.
            compiler_specs: compile specs turned into QnnManager options via
                ``generate_qnn_executorch_option``.
            skip_node_id_set: node names (``node.name``) never to delegate.
            skip_node_op_set: op names (``node.target.__name__``) never to
                delegate.
        """
        self.node_visitors = node_visitor.get_node_visitors(edge_program)

        # Normalize ``None`` to empty sets: ``is_node_supported`` performs
        # membership tests on these, which would raise TypeError on ``None``.
        self.skip_node_op_set = (
            set() if skip_node_op_set is None else skip_node_op_set
        )
        self.skip_node_id_set = (
            set() if skip_node_id_set is None else skip_node_id_set
        )
        # Scratch storage handed to node visitors while building wrappers for
        # a single query; cleared after every ``is_node_supported`` call.
        self.nodes_to_wrappers = defaultdict(dict)
        self.qnn_manager = PyQnnManager.QnnManager(
            generate_qnn_executorch_option(compiler_specs)
        )

        self.qnn_manager.Init()

    def is_node_supported(self, _, node: torch.fx.Node) -> bool:
        """Return whether ``node`` can be delegated to the QNN backend.

        Args:
            _: submodule mapping required by ``OperatorSupportBase``; unused.
            node: the FX node to check.

        Returns:
            True if the node is allow-listed, belongs to the
            ``OpContextLoader`` custom-op namespace, or is accepted by the QNN
            backend; False otherwise (including explicitly skipped nodes).
        """
        if node.op != "call_function" or node.target in not_supported_operator:
            return False

        if node.target in to_be_implemented_operator:
            print(
                f"[QNN Partitioner Op Support]: {node.target.__name__} | Skipped, this op can be supported, please report an issue in https://github.com/pytorch/executorch/issues"
            )
            return False

        if (
            node.target in allow_list_operator
            # bypass if custom op appears; not every call_function target has
            # a ``namespace`` attribute, so check before reading it.
            or (
                hasattr(node.target, "namespace")
                and OpContextLoader.namespace == node.target.namespace
            )
        ):
            return True

        if (
            node.name in self.skip_node_id_set
            or node.target.__name__ in self.skip_node_op_set
        ):
            print(f"[QNN Partitioner Op Support]: {node.target.__name__} | Skipped")
            return False

        supported = False
        try:
            op_wrapper = self.node_visitors[node.target.__name__].define_node(
                node, self.nodes_to_wrappers
            )
            if op_wrapper is not None:
                # A visitor may return a single wrapper or a list of them.
                wrappers = op_wrapper if isinstance(op_wrapper, list) else [op_wrapper]
                supported = self.qnn_manager.IsNodeSupportedByBackend(
                    [wrapper.GetOpWrapper() for wrapper in wrappers]
                )
        finally:
            # Wrappers are only needed for this one query; always drop them so
            # state cannot leak into the next node's check, even on error.
            self.nodes_to_wrappers.clear()

        print(f"[QNN Partitioner Op Support]: {node.target.__name__} | {supported}")
        return supported

    def __del__(self):
        """Release the QnnManager created in ``__init__``."""
        # ``__init__`` may have raised before ``qnn_manager`` was assigned.
        qnn_manager = getattr(self, "qnn_manager", None)
        if qnn_manager is not None:
            qnn_manager.Destroy()


class QnnPartitioner(Partitioner):
    """Partitioner that tags QNN-supported subgraphs for ``QnnBackend``."""

    def __init__(
        self,
        compiler_specs: List[CompileSpec],
        skip_node_id_set: set = None,
        skip_node_op_set: set = None,
    ):
        """
        Args:
            compiler_specs: compile specs; deep-copied so later mutation by
                the caller does not affect this partitioner.
            skip_node_id_set: node names never to delegate.
            skip_node_op_set: op names (``node.target.__name__``) never to
                delegate.
        """
        self.compiler_specs_snapshot = copy.deepcopy(compiler_specs)

        self.delegation_spec = DelegationSpec(
            QnnBackend.__name__, self.compiler_specs_snapshot
        )
        self.partition_tags: Dict[str, DelegationSpec] = {}
        self.skip_node_id_set = set() if skip_node_id_set is None else skip_node_id_set
        self.skip_node_op_set = set() if skip_node_op_set is None else skip_node_op_set

    def generate_partitions(
        self, edge_program: torch.export.ExportedProgram
    ) -> List[Any]:
        """Group QNN-supported nodes of ``edge_program`` into partitions.

        The created ``QnnOperatorSupport`` is kept on ``self`` so ``partition``
        can release it (and its QnnManager) afterwards.
        """
        self.op_support_checker = QnnOperatorSupport(
            edge_program,
            self.compiler_specs_snapshot,
            self.skip_node_id_set,
            self.skip_node_op_set,
        )
        return generate_partitions_from_list_of_nodes(
            edge_program.graph_module,
            op_support=self.op_support_checker,
        )

    def tag_nodes(
        self, partitions: List[Partition], edge_program: torch.export.ExportedProgram
    ) -> None:
        """Write ``delegation_tag`` metadata for every node in ``partitions``.

        Unused placeholders that are lifted buffers/parameters are also tagged
        with the last partition's tag.
        """
        delegation_tag = None
        for partition in partitions:
            for node in partition.nodes:
                delegation_tag = f"qnn_{partition.id}"
                node.meta["delegation_tag"] = delegation_tag
                self.partition_tags[delegation_tag] = self.delegation_spec

        if delegation_tag is None:
            # No node was tagged, so there is no partition to merge
            # consumed constants into.
            return

        # need to take care of consumed constants
        consumed_constants = {
            *edge_program.graph_signature.inputs_to_buffers,
            *edge_program.graph_signature.inputs_to_parameters,
        }
        for node in edge_program.graph_module.graph.nodes:
            # find placeholders as lifted_constants
            if node.op != "placeholder" or len(node.users) != 0:
                continue

            if node.name in consumed_constants:
                # does no harm to merge them into last partition,
                # since they will all be removed in following stage
                node.meta["delegation_tag"] = delegation_tag

    # override
    def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResult:
        """Tag ``edge_program`` for QNN delegation and return the result.

        Also strips ``QCOM_AXIS_ORDER`` from node metadata and releases the
        operator-support checker created by ``generate_partitions``.
        """
        partitions = self.generate_partitions(edge_program)
        if partitions:
            self.tag_nodes(partitions, edge_program)
            tag_constant_data(edge_program)
        for node in edge_program.graph_module.graph.nodes:
            if hasattr(node, "meta"):
                # pop certain keys in meta for not affecting the passes in compilation
                # TODO: need to put property name in common definitions
                node.meta.pop(QCOM_AXIS_ORDER, "")
        # Dropping the checker triggers its __del__, which destroys the
        # QnnManager it owns.
        del self.op_support_checker
        return PartitionResult(
            tagged_exported_program=edge_program, partition_tags=self.partition_tags
        )