xref: /aosp_15_r20/external/executorch/backends/arm/arm_backend.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2023-2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6# pyre-unsafe
7
8#
9# Main implementation of AoT flow to partition and preprocess for Arm target
10# backends. Converts via TOSA as an intermediate form supported by AoT and
11# JIT compiler flows.
12#
13
14import logging
15import os
16from typing import final, List, Optional
17
18import serializer.tosa_serializer as ts
19from executorch.backends.arm.arm_vela import vela_compile
20from executorch.backends.arm.operators.node_visitor import get_node_visitors
21
22from executorch.backends.arm.tosa_specification import TosaSpecification
23from executorch.backends.arm._passes.arm_pass_manager import (
24    ArmPassManager,
25)  # usort: skip
26from executorch.backends.arm.process_node import (
27    process_call_function,
28    process_output,
29    process_placeholder,
30)
31from executorch.backends.arm.tosa_utils import dbg_fail, dbg_tosa_dump
32from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
33from executorch.exir.backend.compile_spec_schema import CompileSpec
34from torch.export.exported_program import ExportedProgram
35
# Logging for the TOSA backend. The module logger stays at WARNING unless the
# environment variable TOSA_DBG_VERBOSE is exactly "1", in which case root
# logging is configured and this logger is raised to INFO.
TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1"
logger = logging.getLogger(__name__)
if TOSA_DBG_VERBOSE:
    logging.basicConfig(level=logging.INFO)
logger.setLevel(logging.INFO if TOSA_DBG_VERBOSE else logging.WARNING)
43
44
class ArmCompileSpecBuilder:
    """
    Fluent builder for the list of CompileSpec entries read by the Arm backend.

    Pick exactly one output format with ethosu_compile_spec() or
    tosa_compile_spec(), optionally chain the other setters, then call
    build() to obtain the list of CompileSpec objects.
    """

    def __init__(self):
        self.compile_spec: List[CompileSpec] = []
        self.compiler_flags: List[str] = []
        # "vela" or "tosa" once one of the *_compile_spec methods has run.
        self.output_format = None
        self.path_for_intermediates = None
        # TODO MLETORCH-265 Remove permute_nhwc flag
        self.permute_nhwc = False
        self.quantize_io = False
        self.tosa_version = None

    def ethosu_compile_spec(
        self,
        config: str,
        system_config: str,
        memory_mode: str,
        extra_flags: Optional[str] = None,
        config_ini: Optional[str] = "Arm/vela.ini",
    ) -> "ArmCompileSpecBuilder":
        """
        Generate compile spec for Ethos-U NPU

        Args:
            config: Ethos-U accelerator configuration, e.g. ethos-u55-128
            system_config: System configuration to select from the Vela
                configuration file
            memory_mode: Memory mode to select from the Vela configuration file
            extra_flags: Extra flags for the Vela compiler
            config_ini: Vela configuration file(s) in Python ConfigParser .ini
                file format. When None, no --config flag is emitted.

        Returns:
            This builder, to allow call chaining.

        Raises:
            AssertionError: If an output format has already been selected.
        """
        assert (
            self.output_format is None
        ), f"Output format already set to {self.output_format}"
        self.output_format = "vela"
        self.compiler_flags = [f"--accelerator-config={config}"]
        # Every optional setting is only forwarded when it was actually given,
        # so a None value never ends up as a literal "None" on the command line.
        if config_ini is not None:
            self.compiler_flags.append(f"--config={config_ini}")
        if system_config is not None:
            self.compiler_flags.append(f"--system-config={system_config}")
        if memory_mode is not None:
            self.compiler_flags.append(f"--memory-mode={memory_mode}")
        if extra_flags is not None:
            self.compiler_flags.append(extra_flags)

        base_tosa_version = "TOSA-0.80.0+BI"
        # Accelerator configs are normally spelled in lower case
        # (e.g. "ethos-u55-128"), so match the U55 family case-insensitively.
        if "u55" in config.lower():
            # Add the Ethos-U55 extension marker
            base_tosa_version += "+u55"
        self.tosa_version = TosaSpecification.create_from_string(base_tosa_version)

        return self

    def tosa_compile_spec(self, tosa_version: str) -> "ArmCompileSpecBuilder":
        """
        Generate compile spec for TOSA flatbuffer output

        Args:
            tosa_version: TOSA version string, parsed with
                TosaSpecification.create_from_string

        Returns:
            This builder, to allow call chaining.

        Raises:
            AssertionError: If an output format has already been selected.
        """
        assert (
            self.output_format is None
        ), f"Output format already set: {self.output_format}"
        self.output_format = "tosa"
        self.tosa_version = TosaSpecification.create_from_string(tosa_version)
        return self

    def dump_intermediate_artifacts_to(
        self, output_path: str
    ) -> "ArmCompileSpecBuilder":
        """
        Sets a path for dumping intermediate results, such as tosa and pte,
        during compilation.
        """
        self.path_for_intermediates = output_path
        return self

    def set_permute_memory_format(
        self, set_nhwc_permutation: bool = True
    ) -> "ArmCompileSpecBuilder":
        """
        Permute to channel last in compiler and runtime. Compilation and
        runtime will convert rank 4 inputs to channel last for each sub-graph.
        """
        self.permute_nhwc = set_nhwc_permutation
        return self

    def set_quantize_io(self, quantize_io: bool = False) -> "ArmCompileSpecBuilder":
        """
        Quantization of inputs and dequantization of outputs for cases where
        whole graph is quantized and method signature is not of quantized type.
        """
        self.quantize_io = quantize_io
        return self

    def build(self) -> List[CompileSpec]:
        """
        Generate a list of compile spec objects from the builder

        Returns:
            The list of CompileSpec entries; it is also stored on
            self.compile_spec.

        Raises:
            AssertionError: If no TOSA version has been set, i.e. neither
                ethosu_compile_spec() nor tosa_compile_spec() was called.
        """
        assert self.tosa_version, (
            "No TOSA version set; call ethosu_compile_spec() or "
            "tosa_compile_spec() before build()"
        )

        # Always supply a TOSA version
        self.compile_spec = [
            CompileSpec("tosa_version", str(self.tosa_version).encode())
        ]

        if self.output_format == "vela":
            self.compile_spec += [
                CompileSpec("output_format", "vela".encode()),
                CompileSpec("compile_flags", " ".join(self.compiler_flags).encode()),
            ]
        elif self.output_format == "tosa":
            self.compile_spec.append(CompileSpec("output_format", "tosa".encode()))

        if self.path_for_intermediates is not None:
            self.compile_spec.append(
                CompileSpec("debug_artifact_path", self.path_for_intermediates.encode())
            )

        if self.permute_nhwc:
            self.compile_spec.append(
                CompileSpec("permute_memory_format", "nhwc".encode())
            )

        if self.quantize_io:
            self.compile_spec.append(CompileSpec("quantize_io", "True".encode()))

        return self.compile_spec
170
171
def is_permute_memory(compile_spec: List[CompileSpec]) -> bool:
    """
    Report whether the first "permute_memory_format" entry decodes to "nhwc".

    Returns False when the list holds no such entry.
    """
    verdicts = (
        entry.value.decode() == "nhwc"
        for entry in compile_spec
        if entry.key == "permute_memory_format"
    )
    # Only the first matching entry is inspected; later ones are ignored.
    return next(verdicts, False)
177
178
def is_tosa(compile_spec: List[CompileSpec]) -> bool:
    """
    Report whether the first "output_format" entry decodes to "tosa".

    Returns False when the list holds no such entry.
    """
    verdicts = (
        entry.value.decode() == "tosa"
        for entry in compile_spec
        if entry.key == "output_format"
    )
    # Only the first matching entry is inspected; later ones are ignored.
    return next(verdicts, False)
184
185
def get_intermediate_path(compile_spec: List[CompileSpec]) -> Optional[str]:
    """
    Return the decoded value of the first "debug_artifact_path" entry.

    Returns None when the list holds no such entry.
    """
    paths = (
        entry.value.decode()
        for entry in compile_spec
        if entry.key == "debug_artifact_path"
    )
    return next(paths, None)
191
192
def _get_first_delegation_tag(graph_module) -> str | None:
    """
    Return the first truthy "delegation_tag" found in the nodes' meta.

    Nodes are scanned in graph order. When no node carries a truthy tag, a
    debug message is logged and None is returned.
    """
    candidates = (
        node.meta.get("delegation_tag") for node in graph_module.graph.nodes
    )
    # filter(None, ...) drops missing and empty tags alike.
    first_tag = next(filter(None, candidates), None)
    if first_tag is None:
        logger.debug("No delegation tag found in partition.")
    return first_tag
202
203
@final
class ArmBackend(BackendDetails):
    """Backend entry point that lowers an ExportedProgram through TOSA."""

    @staticmethod
    def preprocess(  # noqa: C901
        edge_program: ExportedProgram,
        compile_spec: List[CompileSpec],
    ) -> PreprocessResult:
        """
        Convert edge_program to a TOSA graph and serialize it.

        Args:
            edge_program: The program to convert.
            compile_spec: Settings read here by key: "debug_artifact_path",
                "output_format" and "compile_flags". The full list is also
                handed to TosaSpecification and the pass manager.

        Returns:
            A PreprocessResult whose processed_bytes is the vela_compile
            output for "vela", or the serialized TOSA graph for "tosa".

        Raises:
            RuntimeError: If no output format is given, if "vela" is selected
                without compile flags, or if the output format is unknown.
            AssertionError: If no TOSA specification can be created from
                compile_spec.
        """
        logger.info("ArmBackend::preprocess")

        # Pull the settings this function needs out of the compile spec. In a
        # debug/test build the artifact path captures output from the TOSA stage.
        artifact_path = None
        output_format = ""
        compile_flags = []
        for entry in compile_spec:
            key = entry.key
            if key == "debug_artifact_path":
                artifact_path = entry.value.decode()
            elif key == "output_format":
                output_format = entry.value.decode()
            elif key == "compile_flags":
                compile_flags.append(entry.value.decode())

        # The output format must be set in the compile spec
        if not output_format:
            raise RuntimeError("output format is required")

        tosa_spec = TosaSpecification.create_from_compilespecs(compile_spec)
        assert (
            tosa_spec is not None
        ), "TOSA backend needs a TOSA version specified in the CompileSpec!"

        # Only the presence of compile flags is checked, not their validity;
        # the compiler will give errors if they are not valid.
        if output_format == "vela" and not compile_flags:
            raise RuntimeError("compile flags are required for vela output format")

        logger.info(f"Converting ExportedProgram to TOSA: {tosa_spec}")

        # Converted output for this subgraph, serializer needs path early as it emits
        # const data directly. Path created and data written only in debug builds.
        tosa_graph = ts.TosaSerializer(artifact_path)
        graph_module = ArmPassManager().transform_to_backend_pipeline(
            exported_program=edge_program, compile_spec=compile_spec
        )

        visitors = get_node_visitors(edge_program, tosa_spec)

        for node in graph_module.graph.nodes:
            op_kind = node.op
            if op_kind == "call_function":
                process_call_function(node, tosa_graph, visitors, tosa_spec)
            elif op_kind == "placeholder":
                process_placeholder(node, tosa_graph, edge_program, tosa_spec)
            elif op_kind == "output":
                process_output(node, tosa_graph)
            else:
                # This will only happen if an unpartitioned graph is passed without
                # any checking of compatibility.
                dbg_fail(node, tosa_graph, artifact_path)

        # TODO: It would be awesome if this dump could somehow be done on top level and not here.
        # Problem is that the desc.json has to be created on the tosa_graph object, which we can't
        # access from top level.
        if artifact_path:
            tag = _get_first_delegation_tag(graph_module)
            dump_suffix = f"_{tag}" if tag else ""
            dbg_tosa_dump(tosa_graph, artifact_path, suffix=dump_suffix)

        # Serialize and return the program. While we have always produced TOSA
        # output as an intermediate, some flows compile to device binaries in
        # preprocess and some consume TOSA fb directly.
        if output_format not in ("vela", "tosa"):
            raise RuntimeError(f"Unknown format {output_format}")
        if output_format == "vela":
            # Emit vela_bin_stream format
            processed = vela_compile(tosa_graph, compile_flags)
        else:
            # Emit TOSA flatbuffer
            processed = bytes(tosa_graph.serialize())

        # Continuing from above. Can I put tosa_graph into this function?
        # debug_handle_map = ...
        return PreprocessResult(processed_bytes=processed)
288