# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

#
# Main implementation of AoT flow to partition and preprocess for Arm target
# backends. Converts via TOSA as an intermediate form supported by AoT and
# JIT compiler flows.
#

import logging
import os
from typing import final, List, Optional

import serializer.tosa_serializer as ts
from executorch.backends.arm.arm_vela import vela_compile
from executorch.backends.arm.operators.node_visitor import get_node_visitors

from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm._passes.arm_pass_manager import (
    ArmPassManager,
)  # usort: skip
from executorch.backends.arm.process_node import (
    process_call_function,
    process_output,
    process_placeholder,
)
from executorch.backends.arm.tosa_utils import dbg_fail, dbg_tosa_dump
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export.exported_program import ExportedProgram

# TOSA backend debug functionality
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1"
if TOSA_DBG_VERBOSE:
    logging.basicConfig(level=logging.INFO)
    logger.setLevel(logging.INFO)


class ArmCompileSpecBuilder:
    """
    Fluent builder for the list of CompileSpec entries consumed by ArmBackend.

    Select exactly one output format, either ethosu_compile_spec() or
    tosa_compile_spec(). Optionally configure the other settings, then call
    build().
    """

    def __init__(self):
        self.compile_spec: List[CompileSpec] = []
        self.compiler_flags: List[str] = []
        self.output_format: Optional[str] = None
        self.path_for_intermediates: Optional[str] = None
        # TODO MLETORCH-265 Remove permute_nhwc flag
        self.permute_nhwc = False
        self.quantize_io = False
        self.tosa_version = None

    def ethosu_compile_spec(
        self,
        config: str,
        system_config: Optional[str],
        memory_mode: Optional[str],
        extra_flags: Optional[str] = None,
        config_ini: Optional[str] = "Arm/vela.ini",
    ) -> "ArmCompileSpecBuilder":
        """
        Generate compile spec for Ethos-U NPU

        Args:
            config: Ethos-U accelerator configuration, e.g. ethos-u55-128
            system_config: System configuration to select from the Vela
                configuration file. Omitted from the flags when None.
            memory_mode: Memory mode to select from the Vela configuration
                file. Omitted from the flags when None.
            extra_flags: Extra flags for the Vela compiler, appended verbatim
            config_ini: Vela configuration file(s) in Python ConfigParser .ini
                file format

        Returns:
            self, to allow call chaining.
        """
        assert (
            self.output_format is None
        ), f"Output format already set to {self.output_format}"
        self.output_format = "vela"
        self.compiler_flags = [
            f"--accelerator-config={config}",
            f"--config={config_ini}",
        ]
        if system_config is not None:
            self.compiler_flags.append(f"--system-config={system_config}")
        if memory_mode is not None:
            self.compiler_flags.append(f"--memory-mode={memory_mode}")
        if extra_flags is not None:
            self.compiler_flags.append(extra_flags)

        base_tosa_version = "TOSA-0.80.0+BI"
        # Accelerator configs are spelled in lower case (e.g. "ethos-u55-128"),
        # so match case-insensitively; a plain "U55" check never fires for them.
        if "u55" in config.lower():
            # Add the Ethos-U55 extension marker
            base_tosa_version += "+u55"
        self.tosa_version = TosaSpecification.create_from_string(base_tosa_version)

        return self

    def tosa_compile_spec(self, tosa_version: str) -> "ArmCompileSpecBuilder":
        """
        Generate compile spec for TOSA flatbuffer output

        Args:
            tosa_version: TOSA specification string, parsed with
                TosaSpecification.create_from_string.

        Returns:
            self, to allow call chaining.
        """
        assert (
            self.output_format is None
        ), f"Output format already set: {self.output_format}"
        self.output_format = "tosa"
        self.tosa_version = TosaSpecification.create_from_string(tosa_version)
        return self

    def dump_intermediate_artifacts_to(
        self, output_path: str
    ) -> "ArmCompileSpecBuilder":
        """
        Sets a path for dumping intermediate results, such as TOSA and PTE
        files, during compilation.
        """
        self.path_for_intermediates = output_path
        return self

    def set_permute_memory_format(
        self, set_nhwc_permutation: bool = True
    ) -> "ArmCompileSpecBuilder":
        """
        Permute to channel last in compiler and runtime. Compilation and
        runtime will convert rank 4 inputs to channel last for each sub-graph.
        """
        self.permute_nhwc = set_nhwc_permutation
        return self

    def set_quantize_io(self, quantize_io: bool = False) -> "ArmCompileSpecBuilder":
        """
        Quantization of inputs and dequantization of outputs for cases where
        whole graph is quantized and method signature is not of quantized type.
        """
        self.quantize_io = quantize_io
        return self

    def build(self) -> List[CompileSpec]:
        """
        Generate a list of compile spec objects from the builder.

        The "tosa_version" entry is always emitted first. The other entries are
        added only when their corresponding builder setting is active.
        """
        assert self.tosa_version, (
            "No TOSA version set; call ethosu_compile_spec() or "
            "tosa_compile_spec() before build()"
        )

        # Always supply a TOSA version
        self.compile_spec = [
            CompileSpec("tosa_version", str(self.tosa_version).encode())
        ]

        if self.output_format == "vela":
            self.compile_spec += [
                CompileSpec("output_format", "vela".encode()),
                CompileSpec("compile_flags", " ".join(self.compiler_flags).encode()),
            ]
        elif self.output_format == "tosa":
            self.compile_spec.append(CompileSpec("output_format", "tosa".encode()))

        if self.path_for_intermediates is not None:
            self.compile_spec.append(
                CompileSpec("debug_artifact_path", self.path_for_intermediates.encode())
            )

        if self.permute_nhwc:
            self.compile_spec.append(
                CompileSpec("permute_memory_format", "nhwc".encode())
            )

        if self.quantize_io:
            self.compile_spec.append(CompileSpec("quantize_io", "True".encode()))

        return self.compile_spec


def is_permute_memory(compile_spec: List[CompileSpec]) -> bool:
    """
    Return True if the first "permute_memory_format" entry is "nhwc".

    Returns False when no such entry exists.
    """
    for spec in compile_spec:
        if spec.key == "permute_memory_format":
            return spec.value.decode() == "nhwc"
    return False


def is_tosa(compile_spec: List[CompileSpec]) -> bool:
    """
    Return True if the first "output_format" entry is "tosa".

    Returns False when no such entry exists.
    """
    for spec in compile_spec:
        if spec.key == "output_format":
            return spec.value.decode() == "tosa"
    return False


def get_intermediate_path(compile_spec: List[CompileSpec]) -> Optional[str]:
    """Return the decoded "debug_artifact_path" entry, or None if absent."""
    for spec in compile_spec:
        if spec.key == "debug_artifact_path":
            return spec.value.decode()
    return None


def _get_first_delegation_tag(graph_module) -> str | None:
    """Get the first delegation tag from the graph_module or return None."""
    for node in graph_module.graph.nodes:
        tag = node.meta.get("delegation_tag")
        if tag:
            return tag

    logger.debug("No delegation tag found in partition.")
    return None


@final
class ArmBackend(BackendDetails):
    """
    Backend that lowers a partition to TOSA.

    Depending on the "output_format" compile spec, the result is either
    compiled further with Vela ("vela") or emitted as a TOSA flatbuffer
    ("tosa").
    """

    @staticmethod
    def preprocess(  # noqa: C901
        edge_program: ExportedProgram,
        compile_spec: List[CompileSpec],
    ) -> PreprocessResult:
        """
        Lower ``edge_program`` to the output format requested in ``compile_spec``.

        Args:
            edge_program: The program to lower.
            compile_spec: Compile spec list, e.g. as produced by
                ArmCompileSpecBuilder.build().

        Returns:
            PreprocessResult whose processed_bytes holds one of:
            - the output of vela_compile, for the "vela" output format;
            - the serialized TOSA flatbuffer, for the "tosa" output format.

        Raises:
            RuntimeError: If the output format is missing or unknown.
            RuntimeError: If "vela" output is requested without compile flags.
        """
        logger.info("ArmBackend::preprocess")

        # if a debug/test build capture output files from TOSA stage
        artifact_path = None
        output_format = ""
        compile_flags = []
        for spec in compile_spec:
            if spec.key == "debug_artifact_path":
                artifact_path = spec.value.decode()
            if spec.key == "output_format":
                output_format = spec.value.decode()
            if spec.key == "compile_flags":
                compile_flags.append(spec.value.decode())

        # Check that the output format is set in the compile spec
        if not output_format:
            raise RuntimeError("output format is required")

        tosa_spec = TosaSpecification.create_from_compilespecs(compile_spec)
        assert (
            tosa_spec is not None
        ), "TOSA backend needs a TOSA version specified in the CompileSpec!"

        if output_format == "vela" and len(compile_flags) == 0:
            # Not testing for compile_flags correctness here, just that they are
            # present. The compiler will give errors if they are not valid.
            raise RuntimeError("compile flags are required for vela output format")

        logger.info("Converting ExportedProgram to TOSA: %s", tosa_spec)

        # Converted output for this subgraph, serializer needs path early as it emits
        # const data directly. Path created and data written only in debug builds.
        tosa_graph = ts.TosaSerializer(artifact_path)
        graph_module = ArmPassManager().transform_to_backend_pipeline(
            exported_program=edge_program, compile_spec=compile_spec
        )

        node_visitors = get_node_visitors(edge_program, tosa_spec)

        for node in graph_module.graph.nodes:
            if node.op == "call_function":
                process_call_function(node, tosa_graph, node_visitors, tosa_spec)
            elif node.op == "placeholder":
                process_placeholder(node, tosa_graph, edge_program, tosa_spec)
            elif node.op == "output":
                process_output(node, tosa_graph)
            else:
                # This will only happen if an unpartitioned graph is passed without
                # any checking of compatibility.
                dbg_fail(node, tosa_graph, artifact_path)

        # TODO: It would be awesome if this dump could somehow be done on top level and not here.
        # Problem is that the desc.json has to be created on the tosa_graph object, which we can't
        # access from top level.
        if artifact_path:
            tag = _get_first_delegation_tag(graph_module)
            dbg_tosa_dump(
                tosa_graph,
                artifact_path,
                suffix=f"_{tag}" if tag else "",
            )

        # Serialize and return the program. While we have always produced TOSA
        # output as an intermediate, some flows compile to device binaries in
        # preprocess and some consume TOSA fb directly.
        if output_format == "vela":
            # Emit vela_bin_stream format
            binary = vela_compile(tosa_graph, compile_flags)
        elif output_format == "tosa":
            # Emit TOSA flatbuffer
            binary = bytes(tosa_graph.serialize())
        else:
            raise RuntimeError(f"Unknown format {output_format}")

        # TODO: Investigate whether tosa_graph can be used to build a debug
        # handle map to pass to PreprocessResult.
        return PreprocessResult(processed_bytes=binary)