xref: /aosp_15_r20/external/executorch/backends/qualcomm/partition/qnn_partitioner.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6import copy
7from collections import defaultdict
8from typing import Any, Dict, List
9
10import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager
11import torch
12from executorch.backends.qualcomm.builders import node_visitor
13from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader
14from executorch.backends.qualcomm.qnn_preprocess import QnnBackend
15from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER
16
17from executorch.exir.backend.backend_details import CompileSpec
18from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
19    generate_partitions_from_list_of_nodes,
20)
21from executorch.exir.backend.partitioner import (
22    DelegationSpec,
23    Partitioner,
24    PartitionResult,
25)
26from executorch.exir.backend.utils import tag_constant_data
27from torch.fx.passes.infra.partitioner import Partition
28from torch.fx.passes.operator_support import OperatorSupportBase
29
30from .common_defs import (
31    allow_list_operator,
32    not_supported_operator,
33    to_be_implemented_operator,
34)
35from .utils import generate_qnn_executorch_option
36
37
38class QnnOperatorSupport(OperatorSupportBase):
39    def __init__(
40        self,
41        edge_program: torch.export.ExportedProgram,
42        compiler_specs,
43        skip_node_id_set: set = None,
44        skip_node_op_set: set = None,
45    ):
46        self.node_visitors = node_visitor.get_node_visitors(edge_program)
47
48        self.skip_node_op_set = skip_node_op_set
49        self.skip_node_id_set = skip_node_id_set
50        self.nodes_to_wrappers = defaultdict(dict)
51        self.qnn_manager = PyQnnManager.QnnManager(
52            generate_qnn_executorch_option(compiler_specs)
53        )
54
55        self.qnn_manager.Init()
56
57    def is_node_supported(self, _, node: torch.fx.Node) -> bool:
58        if node.op != "call_function" or node.target in not_supported_operator:
59            return False
60
61        if node.target in to_be_implemented_operator:
62            print(
63                f"[QNN Partitioner Op Support]: {node.target.__name__} | Skipped, this op can be supported, please report an issue in https://github.com/pytorch/executorch/issues"
64            )
65            return False
66
67        if (
68            node.target in allow_list_operator
69            # bypass if custom op appears
70            or OpContextLoader.namespace == node.target.namespace
71        ):
72            return True
73
74        if (
75            node.name in self.skip_node_id_set
76            or node.target.__name__ in self.skip_node_op_set
77        ):
78            print(f"[QNN Partitioner Op Support]: {node.target.__name__} | Skipped")
79            return False
80
81        supported = False
82        op_wrapper = self.node_visitors[node.target.__name__].define_node(
83            node, self.nodes_to_wrappers
84        )
85
86        op_wrapper_list = []
87        if isinstance(op_wrapper, List):
88            op_wrapper_list.extend(op_wrapper)
89        else:
90            op_wrapper_list.append(op_wrapper)
91
92        if op_wrapper is not None:
93            supported = self.qnn_manager.IsNodeSupportedByBackend(
94                [op_wrapper.GetOpWrapper() for op_wrapper in op_wrapper_list]
95            )
96
97        self.nodes_to_wrappers.clear()
98        print(f"[QNN Partitioner Op Support]: {node.target.__name__} | {supported}")
99        return supported
100
101    def __del__(self):
102        self.qnn_manager.Destroy()
103
104
class QnnPartitioner(Partitioner):
    """Partition an edge program into subgraphs delegated to ``QnnBackend``.

    Supported nodes are found with ``QnnOperatorSupport`` and grouped into
    partitions. Each partition's nodes are tagged ``qnn_<partition id>``,
    and every tag maps to one ``DelegationSpec`` for ``QnnBackend``, built
    from a snapshot of the compile specs.
    """

    def __init__(
        self,
        compiler_specs: List[CompileSpec],
        skip_node_id_set: set = None,
        skip_node_op_set: set = None,
    ):
        """
        Args:
            compiler_specs: QNN compile specs. They are deep-copied, so later
                mutation by the caller does not affect this partitioner.
            skip_node_id_set: Node names that must not be delegated.
            skip_node_op_set: Op names (``node.target.__name__``) that must
                not be delegated.
        """
        self.compiler_specs_snapshot = copy.deepcopy(compiler_specs)

        self.delegation_spec = DelegationSpec(
            QnnBackend.__name__, self.compiler_specs_snapshot
        )
        self.partition_tags: Dict[str, DelegationSpec] = {}
        self.skip_node_id_set = set() if skip_node_id_set is None else skip_node_id_set
        self.skip_node_op_set = set() if skip_node_op_set is None else skip_node_op_set

    def generate_partitions(
        self, edge_program: torch.export.ExportedProgram
    ) -> List[Partition]:
        """Group QNN-supported nodes of ``edge_program`` into partitions.

        A fresh ``QnnOperatorSupport`` is stored on ``self.op_support_checker``;
        ``partition`` deletes it afterwards.
        """
        self.op_support_checker = QnnOperatorSupport(
            edge_program,
            self.compiler_specs_snapshot,
            self.skip_node_id_set,
            self.skip_node_op_set,
        )
        return generate_partitions_from_list_of_nodes(
            edge_program.graph_module,
            op_support=self.op_support_checker,
        )

    def tag_nodes(
        self, partitions: List[Partition], edge_program: torch.export.ExportedProgram
    ) -> None:
        """Write ``delegation_tag`` metadata for every node in ``partitions``.

        Each non-empty partition gets the tag ``qnn_<partition.id>``, which is
        registered in ``self.partition_tags``. Parameter/buffer placeholders
        with no remaining users get the tag of the last non-empty partition.
        """
        last_tag = None
        for partition in partitions:
            if not partition.nodes:
                continue
            delegation_tag = f"qnn_{partition.id}"
            for node in partition.nodes:
                node.meta["delegation_tag"] = delegation_tag
            self.partition_tags[delegation_tag] = self.delegation_spec
            last_tag = delegation_tag

        # Nothing was tagged, so there is no partition to merge constants into.
        if last_tag is None:
            return

        # need to take care of consumed constants
        consumed_constants = {
            *edge_program.graph_signature.inputs_to_buffers,
            *edge_program.graph_signature.inputs_to_parameters,
        }
        for node in edge_program.graph_module.graph.nodes:
            # find placeholders as lifted_constants (no users left)
            if node.op != "placeholder" or node.users:
                continue

            if node.name in consumed_constants:
                # does no harm to merge them into last partition,
                # since they will all be removed in following stage
                node.meta["delegation_tag"] = last_tag

    # override
    def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResult:
        """Tag QNN-delegable nodes of ``edge_program`` and report the tags.

        Args:
            edge_program: Program to partition. Its nodes' ``meta`` is modified
                in place.

        Returns:
            A ``PartitionResult`` wrapping ``edge_program`` and
            ``self.partition_tags``.
        """
        try:
            partitions = self.generate_partitions(edge_program)
            if partitions:
                self.tag_nodes(partitions, edge_program)
                tag_constant_data(edge_program)
            for node in edge_program.graph_module.graph.nodes:
                if hasattr(node, "meta"):
                    # pop certain keys in meta for not affecting the passes in compilation
                    # TODO: need to put property name in common definitions
                    node.meta.pop(QCOM_AXIS_ORDER, "")
        finally:
            # Drop the checker even on failure. Once nothing else references it,
            # QnnOperatorSupport.__del__ (which calls QnnManager.Destroy) can run.
            if hasattr(self, "op_support_checker"):
                del self.op_support_checker
        return PartitionResult(
            tagged_exported_program=edge_program, partition_tags=self.partition_tags
        )
174