xref: /aosp_15_r20/external/executorch/backends/xnnpack/xnnpack_preprocess.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerimport logging
8*523fa7a6SAndroid Build Coastguard Workerfrom dataclasses import dataclass
9*523fa7a6SAndroid Build Coastguard Workerfrom typing import Dict, final, List
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Workerimport torch
12*523fa7a6SAndroid Build Coastguard Worker
13*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.xnnpack._passes import XNNPACKPassManager
14*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
15*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.xnnpack._passes.tag_implicit_q_dq_pass import (
16*523fa7a6SAndroid Build Coastguard Worker    TagImplicitQDqPass,
17*523fa7a6SAndroid Build Coastguard Worker)
18*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.xnnpack.operators.node_visitor import get_node_visitors
19*523fa7a6SAndroid Build Coastguard Worker
20*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
21*523fa7a6SAndroid Build Coastguard Worker    ConstantDataOffset,
22*523fa7a6SAndroid Build Coastguard Worker    XNNGraph,
23*523fa7a6SAndroid Build Coastguard Worker)
24*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.xnnpack.serialization.xnnpack_graph_serialize import (
25*523fa7a6SAndroid Build Coastguard Worker    serialize_xnnpack_binary,
26*523fa7a6SAndroid Build Coastguard Worker)
27*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
28*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.xnnpack.utils.utils import is_param_node
29*523fa7a6SAndroid Build Coastguard Worker
30*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.xnnpack.utils.xnnpack_constants import (
31*523fa7a6SAndroid Build Coastguard Worker    XNN_VALUE_FLAG_EXTERNAL_INPUT,
32*523fa7a6SAndroid Build Coastguard Worker    XNN_VALUE_FLAG_EXTERNAL_OUTPUT,
33*523fa7a6SAndroid Build Coastguard Worker)
34*523fa7a6SAndroid Build Coastguard Worker
35*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.backend.backend_details import (
36*523fa7a6SAndroid Build Coastguard Worker    BackendDetails,
37*523fa7a6SAndroid Build Coastguard Worker    CompileSpec,
38*523fa7a6SAndroid Build Coastguard Worker    PreprocessResult,
39*523fa7a6SAndroid Build Coastguard Worker)
40*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.verification.verifier import EXIREdgeDialectVerifier
41*523fa7a6SAndroid Build Coastguard Workerfrom torch.export.exported_program import ExportedProgram
42*523fa7a6SAndroid Build Coastguard Worker
# Fallback debug handle (0xFFFF == 65535) handed to node visitors when a
# node's meta carries no "debug_handle" entry.
DEFAULT_DEBUG_HANDLE = 0xFFFF

# Module logger, explicitly pinned at WARNING so INFO-level traces from this
# module are suppressed unless the level is lowered by the caller.
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
47*523fa7a6SAndroid Build Coastguard Worker
48*523fa7a6SAndroid Build Coastguard Worker
@dataclass
class ExternalMeta:
    """Record describing how a graph value is exposed outside the delegate.

    Fields are accepted positionally or by keyword, in the order
    ``external_id`` then ``io_type``.
    """

    # Position of the value in the combined (inputs first, then outputs)
    # external numbering assigned by generate_node_to_external_map.
    external_id: int
    # XNNPACK value flag: external input or external output.
    io_type: int
53*523fa7a6SAndroid Build Coastguard Worker
54*523fa7a6SAndroid Build Coastguard Worker
def generate_node_to_external_map(
    exported_program: ExportedProgram,
    edge_graph_module: torch.fx.GraphModule,
) -> Dict[torch.fx.Node, ExternalMeta]:
    """Assign external ids to the user inputs and outputs of a graph module.

    User inputs are numbered first, in placeholder order, skipping any
    placeholder that ``is_param_node`` reports as a parameter/buffer of
    ``exported_program``. Outputs are numbered afterwards, continuing the same
    counter, in the order they appear in the output node's arguments.

    Args:
        exported_program: Program used to distinguish parameters/buffers from
            real user inputs.
        edge_graph_module: Graph whose placeholders and outputs are mapped.

    Returns:
        A dict mapping each external node to its ``ExternalMeta``.
    """
    externals: Dict[torch.fx.Node, ExternalMeta] = {}
    graph_nodes = edge_graph_module.graph.nodes

    # Placeholders are visited in the same order as the forward(*args)
    # signature, so the running count doubles as the runtime argument index.
    # Parameters/buffers are dropped because they vanish from the runtime
    # signature.
    user_inputs = [
        candidate
        for candidate in graph_nodes
        if candidate.op == "placeholder"
        and not is_param_node(exported_program, candidate)
    ]
    for input_node in user_inputs:
        externals[input_node] = ExternalMeta(
            external_id=len(externals),
            io_type=XNN_VALUE_FLAG_EXTERNAL_INPUT,
        )

    # Outputs pick up numbering where the inputs left off.
    # NOTE(review): ids are derived from len(externals). If an output node is
    # already a key (returned twice, or a placeholder returned directly), its
    # entry is overwritten without growing the dict, so the next output would
    # reuse the same id. Confirm the partitioner never produces such graphs.
    output_ops = (candidate for candidate in graph_nodes if candidate.op == "output")
    for output_op in output_ops:
        for returned_group in output_op.args:
            for result_node in returned_group:
                externals[result_node] = ExternalMeta(
                    external_id=len(externals),
                    io_type=XNN_VALUE_FLAG_EXTERNAL_OUTPUT,
                )
    return externals
81*523fa7a6SAndroid Build Coastguard Worker
82*523fa7a6SAndroid Build Coastguard Worker
def assert_default_dim_order(edge_graph_module: torch.fx.GraphModule) -> None:
    """Raise if any placeholder value in the graph has a non-default dim order.

    Every placeholder (user inputs, buffers and parameters alike) whose
    ``meta["val"]`` exposes a ``dim_order`` attribute must report the identity
    order ``(0, 1, ..., ndim - 1)``. Placeholders with no ``"val"`` entry, or
    whose value has no ``dim_order`` attribute, are skipped.

    Args:
        edge_graph_module: Graph whose placeholder nodes are checked.

    Raises:
        RuntimeError: If a placeholder's dim order differs from the default.
    """
    for node in edge_graph_module.graph.nodes:
        if node.op != "placeholder":
            continue

        val = node.meta.get("val", None)
        # Skip values that are missing or not tensor-like (no dim_order).
        if val is None or getattr(val, "dim_order", None) is None:
            continue

        expected = tuple(range(val.dim()))
        actual = val.dim_order()
        if actual != expected:
            # The trailing space keeps the two sentences from running together.
            raise RuntimeError(
                "XNNPACK backend only supports contiguous memory format for inputs. "
                f"Expecting dim_order: {expected}, but got {actual} "
                f"for a placeholder node {node}."
            )
97*523fa7a6SAndroid Build Coastguard Worker
98*523fa7a6SAndroid Build Coastguard Worker
@final
class XnnpackBackend(BackendDetails):
    """Backend that lowers an edge-dialect ``ExportedProgram`` to an XNNPACK graph."""

    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        """Run XNNPACK passes on ``edge_program`` and serialize the result.

        Args:
            edge_program: The edge-dialect program (delegated subgraph) to lower.
            compile_specs: Backend options. If any spec has key
                ``"dqlinear_partitioner"``, ``ConvertToLinearPass`` and
                ``TagImplicitQDqPass`` are handed to ``XNNPACKPassManager``;
                otherwise ``passes=None`` is passed (presumably selecting the
                manager's default pass list — verify in XNNPACKPassManager).

        Returns:
            A ``PreprocessResult`` holding the serialized XNNPACK binary and an
            empty debug handle map.

        Raises:
            RuntimeError: If a placeholder has a non-default dim order, a
                ``call_function`` target has no registered node visitor, or
                the graph contains a node op other than ``call_function``,
                ``get_attr``, ``placeholder`` or ``output``.
        """
        xnnpack_edge_compile_config = get_xnnpack_edge_compile_config()

        # Need to wrap EP here because xnnpack does addmm to linear
        # transforms. This makes resulting graph not aten compliant
        # as aten.linear is not a core aten op.
        # Ideal fix would be to have XNNPACK verifier that bypass
        # most checks but the base Verifier itself has some strict changes
        # and to bypass those, we would basically copy what EdgeDialectVerifier
        # does. So for now instead of copy pasting that, just instantiate
        # EdgeDialectVerifier, but disable it.
        # TODO (task link) to implement NullVerifier or something similar
        ep = ExportedProgram(
            root=edge_program.graph_module,
            graph=edge_program.graph,
            graph_signature=edge_program.graph_signature,
            state_dict=edge_program.state_dict,
            range_constraints=edge_program.range_constraints,
            module_call_graph=edge_program.module_call_graph,
            example_inputs=edge_program.example_inputs,
            constants=edge_program.constants,
            verifiers=[
                EXIREdgeDialectVerifier(
                    edge_compile_config=xnnpack_edge_compile_config, class_only=True
                )
            ],
        )

        # Add the dqlinear passes once, however many matching specs are
        # present; repeating them would only re-run the same transforms.
        passes = None
        if any(spec.key == "dqlinear_partitioner" for spec in compile_specs):
            passes = [ConvertToLinearPass, TagImplicitQDqPass]

        # XNNPACK Delegate Specific Passes
        ep = XNNPACKPassManager(ep, passes=passes).transform()
        graph_module = ep.graph_module

        node_to_external_map = generate_node_to_external_map(ep, graph_module)

        # Make sure all inputs are contiguous_format or NCHW or default dim order
        assert_default_dim_order(graph_module)

        # TODO retrace the graph module to lift the new params may have
        # been added to the graph in passes

        vals_to_ids = {}
        xnnpack_graph = XNNGraph(
            version="0",
            xnodes=[],
            xvalues=[],
            num_externs=len(node_to_external_map),
            input_ids=[],
            output_ids=[],
            constant_data=[ConstantDataOffset(0, 0)],
        )

        # Node visitors append serialized constant data into this buffer.
        constant_data_bytes = bytearray()
        node_visitors = get_node_visitors(ep, node_to_external_map, constant_data_bytes)

        for node in graph_module.graph.nodes:
            if node.op == "call_function":
                target_name = node.target.__name__
                # Lazy %-args: the logger is pinned at WARNING, so avoid
                # formatting this message for every node.
                logger.info("Visiting: %s, %s", node, target_name)
                if target_name not in node_visitors:
                    raise RuntimeError(
                        f"For {node}, {node.op}:{target_name} is not supported in XNNPACK Delegate"
                    )
                node_visitors[target_name].define_node(
                    node,
                    xnnpack_graph,
                    vals_to_ids,
                    node.meta.get("debug_handle", DEFAULT_DEBUG_HANDLE),
                )
            elif node.op not in ("get_attr", "placeholder", "output"):
                raise RuntimeError(f"{node.op} is not supported in XNNPACK")

        return PreprocessResult(
            processed_bytes=serialize_xnnpack_binary(
                xnnpack_graph, constant_data_bytes
            ),
            debug_handle_map={},
        )
194*523fa7a6SAndroid Build Coastguard Worker        )
195