1# Copyright (c) Qualcomm Innovation Center, Inc. 2# All rights reserved 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6import collections 7import copy 8import os 9import subprocess 10import tempfile 11import unittest 12from typing import Callable, Dict, List, Optional, Tuple 13 14import numpy as np 15import torch 16 17from executorch import exir 18from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner 19from executorch.backends.qualcomm.qnn_preprocess import QnnBackend 20from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype 21from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset 22from executorch.backends.qualcomm.utils.utils import ( 23 capture_program, 24 get_soc_to_chipset_map, 25) 26from executorch.devtools import generate_etrecord, Inspector 27from executorch.examples.qualcomm.utils import ( 28 generate_inputs, 29 make_output_dir, 30 SimpleADB, 31) 32 33from executorch.exir.backend.backend_api import to_backend 34from executorch.exir.backend.compile_spec_schema import CompileSpec 35from executorch.exir.dialects._ops import ops as exir_ops 36from executorch.exir.pass_base import ExportPass 37from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass 38from executorch.exir.program import ExecutorchProgram, ExecutorchProgramManager 39from torch.ao.quantization.quantize_pt2e import ( 40 convert_pt2e, 41 prepare_pt2e, 42 prepare_qat_pt2e, 43) 44 45 46def generate_context_binary( 47 module: torch.nn.Module, 48 inputs: Dict[str, torch.Tensor], 49 quantized: bool, 50 artifact_dir: str, 51): 52 # we also expect clang showing in PATH or context may fail to generate 53 qnn_sdk = os.environ.get("QNN_SDK_ROOT", None) 54 ndk = os.environ.get("ANDROID_NDK_ROOT", None) 55 assert qnn_sdk, "QNN_SDK_ROOT was not found in environment variable" 56 assert ndk, "ANDROID_NDK_ROOT was not found in environment variable" 57 58 inputs_tup = tuple(inputs.values()) 59 jit_module = torch.jit.trace(module, inputs_tup) 60 torch.jit.save(jit_module, f"{artifact_dir}/jit_module.pt") 61 62 # input data 63 if quantized: 64 input_list = [] 65 for name, data in inputs.items(): 66 file_name = f"{artifact_dir}/{name}.raw" 67 data.detach().numpy().tofile(file_name) 68 input_list.append(file_name) 69 70 with open(f"{artifact_dir}/input_list.txt", "w") as f: 71 f.write(" ".join(input_list)) 72 73 # flow of qnn tools 74 target = "x86_64-linux-clang" 75 inputs_str = [ 76 f"-d '{k}' {str(tuple(v.shape)).replace(' ', '')[1:-1]}" 77 for k, v in inputs.items() 78 ] 79 cmds = [ 80 # setup qnn env 81 f"source {qnn_sdk}/bin/envsetup.sh;" 82 # qnn-pytorch-converter 83 f"{qnn_sdk}/bin/{target}/qnn-pytorch-converter", 84 f"-i {artifact_dir}/jit_module.pt", 85 *inputs_str, 86 f"--input_list {artifact_dir}/input_list.txt" if quantized else "", 87 "--preserve_io", 88 f"-o {artifact_dir}/model.cpp;", 89 # qnn-model-lib-generator 90 f"{qnn_sdk}/bin/{target}/qnn-model-lib-generator", 91 f"-c {artifact_dir}/model.cpp", 92 f"-t {target}", 93 "-l model", 94 f"-o {artifact_dir}/model_libs;", 95 # qnn-context-binary-generator 96 f"{qnn_sdk}/bin/{target}/qnn-context-binary-generator", 97 f"--model {artifact_dir}/model_libs/{target}/libmodel.so", 98 f"--backend {qnn_sdk}/lib/{target}/libQnnHtp.so", 99 "--binary_file model_ctx", 100 f"--output_dir {artifact_dir};", 101 ] 102 result = subprocess.run( 103 " ".join(cmds), 104 shell=True, 105 executable="/bin/bash", 106 capture_output=True, 107 ) 108 assert os.path.isfile(f"{artifact_dir}/model_ctx.bin"), print(result.stderr) 109 110 111class TestQNN(unittest.TestCase): 112 rtol: float = 0 113 atol: float = 0 114 host: str = "" 115 device: str = "" 116 build_folder: str = "" 117 model: QcomChipset = None 118 compiler_specs: List[CompileSpec] = None 119 chipset_table = get_soc_to_chipset_map() 120 error_only = False 121 ip = "localhost" 122 port = 8080 123 executorch_root: str = "" 124 artifact_dir: str = "" 125 image_dataset: str = "" 126 pretrained_weight: str = "" 127 enable_profile: bool = False 128 online_prepare: bool = False 129 use_8a8w: str = "8a8w" 130 use_16a16w: str = "16a16w" 131 use_16a4w: str = "16a4w" 132 shared_buffer: bool = False 133 enable_x86_64: bool = False 134 135 def _assert_outputs_equal(self, model_output, ref_output): 136 self.assertTrue(len(ref_output) == len(model_output)) 137 for i in range(len(ref_output)): 138 self.assertTrue( 139 torch.allclose( 140 model_output[i], ref_output[i], atol=self.atol, rtol=self.rtol 141 ), 142 msg=f"ref_output:\n{ref_output[i]}\n\nmodel_output:\n{model_output[i]}", 143 ) 144 145 def _save_model_and_expected_output( 146 self, 147 module: torch.nn.Module, 148 buffer: exir.ExirExportedProgram, 149 inputs: Tuple[torch.Tensor], 150 dir_name: str, 151 ) -> None: 152 # Save the input data list to be executed 153 input_list = "" 154 for idx, _ in enumerate(inputs): 155 input_name = f"input_0_{idx}.raw" 156 input_list += input_name + " " 157 input_list = input_list.strip() + "\n" 158 159 ref_output = module(*inputs) 160 161 # Save the expected output data to be verified 162 ref_outputs = [] 163 if isinstance(ref_output, collections.OrderedDict): 164 ref_outputs.append(ref_output["out"].detach()) 165 elif isinstance(ref_output, (list, tuple)): 166 for output in ref_output: 167 ref_outputs.append(output.detach()) 168 else: 169 ref_outputs.append(ref_output.detach()) 170 171 pte_fname = f"{dir_name}/qnn_executorch_test.pte" 172 with open(pte_fname, "wb") as file: 173 file.write(buffer) 174 175 return input_list, ref_outputs, pte_fname 176 177 def verify_output( # noqa: C901 178 self, 179 module: torch.nn.Module, 180 sample_inputs: Tuple[torch.Tensor], 181 executorch_prog: ExecutorchProgram | ExecutorchProgramManager, 182 etrecord_path: str = "etrecord.bin", 183 expected_profile_events: int = -1, 184 expected_intermediate_events: int = -1, 185 method_index: int = 0, 186 ): 187 with tempfile.TemporaryDirectory() as tmp_dir: 188 ( 189 input_list, 190 ref_outputs, 191 pte_fname, 192 ) = self._save_model_and_expected_output( 193 module, 194 executorch_prog.buffer, 195 sample_inputs, 196 tmp_dir, 197 ) 198 199 output_dir = f"{tmp_dir}/outputs" 200 outputs = [] 201 etdump_path = f"{tmp_dir}/etdump.etdp" 202 debug_output_path = f"{tmp_dir}/debug_output.bin" 203 204 def post_process(): 205 for i, f in enumerate(sorted(os.listdir(output_dir))): 206 filename = os.path.join(output_dir, f) 207 output = np.fromfile(filename, dtype=ref_outputs[i].numpy().dtype) 208 output = torch.from_numpy(output).reshape(ref_outputs[i].shape) 209 outputs.append(output) 210 211 def validate_profile(): 212 inspector = Inspector(etdump_path=etdump_path, etrecord=etrecord_path) 213 self.assertTrue( 214 len(inspector.to_dataframe().index) == expected_profile_events 215 ) 216 217 def validate_intermediate_tensor(): 218 inspector = Inspector( 219 etdump_path=etdump_path, debug_buffer_path=debug_output_path 220 ) 221 for event_block in inspector.event_blocks: 222 if event_block.name == "Execute": 223 self.assertTrue( 224 len(event_block.events) == expected_intermediate_events 225 ) 226 227 if self.enable_x86_64: 228 generate_inputs(tmp_dir, "input_list.txt", [sample_inputs], input_list) 229 make_output_dir(output_dir) 230 231 target = "x86_64-linux-clang" 232 qnn_sdk = os.environ.get("QNN_SDK_ROOT", None) 233 assert qnn_sdk, "QNN_SDK_ROOT was not found in environment variable" 234 235 build_folder = self.build_folder 236 if os.path.isabs(self.build_folder): 237 # obey user's opinion 238 pass 239 else: 240 # ok, assuming the user give a relative path to cwd 241 build_folder = os.path.join(os.getcwd(), self.build_folder) 242 243 cmd = [ 244 # qnn_executor_runner 245 f"{build_folder}/examples/qualcomm/executor_runner/qnn_executor_runner", 246 "--model_path", 247 pte_fname, 248 "--input_list_path", 249 f"{tmp_dir}/input_list.txt", 250 "--output_folder_path", 251 output_dir, 252 "--method_index", 253 str(method_index), 254 ] 255 if expected_intermediate_events != -1: 256 cmd.append("--dump_intermediate_outputs") 257 258 env = dict(os.environ) 259 env["LD_LIBRARY_PATH"] = f"{qnn_sdk}/lib/{target}/:{build_folder}/lib" 260 proc = subprocess.run( 261 cmd, 262 stdout=subprocess.PIPE, 263 stderr=subprocess.STDOUT, 264 env=env, 265 cwd=tmp_dir, 266 ) 267 268 self.assertEqual( 269 proc.returncode, 270 0, 271 f"The process running qnn_executorch_runner return {proc.returncode}, " 272 "STDOUT=\n" 273 f"{proc.stdout.decode('utf-8')}", 274 ) 275 276 # Verify the outputs 277 post_process() 278 self._assert_outputs_equal(outputs, ref_outputs) 279 280 # Verify the etdump 281 if expected_profile_events != -1: 282 validate_profile() 283 284 if expected_intermediate_events != -1: 285 validate_intermediate_tensor() 286 else: 287 adb = SimpleADB( 288 qnn_sdk=os.getenv("QNN_SDK_ROOT"), 289 build_path=self.build_folder, 290 pte_path=pte_fname, 291 workspace="/data/local/tmp/qnn_executorch_test", 292 device_id=self.device, 293 host_id=self.host, 294 soc_model=self.model, 295 error_only=self.error_only, 296 dump_intermediate_outputs=( 297 True if expected_intermediate_events != -1 else False 298 ), 299 ) 300 adb.push(inputs=[sample_inputs], input_list=input_list) 301 adb.execute(method_index=method_index) 302 adb.pull(output_path=tmp_dir, callback=post_process) 303 self._assert_outputs_equal(outputs, ref_outputs) 304 305 if expected_profile_events != -1: 306 adb.pull_etdump(etdump_path, callback=validate_profile) 307 308 if expected_intermediate_events != -1: 309 adb.pull_debug_output( 310 etdump_path, 311 debug_output_path, 312 callback=validate_intermediate_tensor, 313 ) 314 315 def lower_module_and_test_output( 316 self, 317 module: torch.nn.Module, 318 sample_inputs: Tuple[torch.Tensor], 319 expected_partitions: int = 1, 320 expected_profile_events: int = -1, 321 expected_intermediate_events: int = -1, 322 assert_output_equal: bool = True, 323 skip_node_id_set: set = None, 324 skip_node_op_set: set = None, 325 ): 326 qnn_partitioner = QnnPartitioner( 327 self.compiler_specs, skip_node_id_set, skip_node_op_set 328 ) 329 delegated_program = capture_program(module, sample_inputs) 330 331 # this is needed for the ETRecord as lowering modifies the graph in-place 332 edge_copy = copy.deepcopy(delegated_program) 333 334 delegated_program.exported_program = to_backend( 335 delegated_program.exported_program, qnn_partitioner 336 ) 337 exec_prog = delegated_program.to_executorch( 338 exir.ExecutorchBackendConfig( 339 # For shared buffer, user must pass the memory address 340 # which is allocated by RPC memory to executor runner. 341 # Therefore, won't want to pre-allocate 342 # by memory manager in runtime. 343 memory_planning_pass=MemoryPlanningPass( 344 alloc_graph_input=not self.shared_buffer, 345 alloc_graph_output=not self.shared_buffer, 346 ), 347 ) 348 ) 349 350 # Assert the backend name is qnn 351 self.assertEqual( 352 len(exec_prog.program.execution_plan[0].delegates), 353 expected_partitions, 354 ) 355 for i in range(expected_partitions): 356 self.assertEqual( 357 exec_prog.program.execution_plan[0].delegates[i].id, 358 QnnBackend.__name__, 359 ) 360 361 etrecord_path = "etrecord.bin" 362 if self.enable_profile: 363 generate_etrecord(etrecord_path, edge_copy, exec_prog) 364 # Check numerics 365 if ( 366 assert_output_equal 367 or expected_profile_events != -1 368 or expected_intermediate_events != -1 369 ): 370 self.verify_output( 371 module, 372 sample_inputs, 373 exec_prog, 374 etrecord_path, 375 expected_profile_events, 376 expected_intermediate_events, 377 ) 378 379 def get_qdq_module( 380 self, 381 module: torch.nn.Module, 382 inputs: Tuple[torch.Tensor], 383 is_conv_per_channel: Optional[bool] = True, 384 is_linear_per_channel: Optional[bool] = False, 385 custom_quant_annotations: Tuple[Callable] = (), 386 quant_dtype: QuantDtype = QuantDtype.use_8a8w, 387 ) -> torch.fx.GraphModule: 388 m = torch.export.export(module, inputs).module() 389 390 quantizer = QnnQuantizer() 391 quantizer.add_custom_quant_annotations(custom_quant_annotations) 392 quantizer.set_per_channel_conv_quant(is_conv_per_channel) 393 quantizer.set_per_channel_linear_quant(is_linear_per_channel) 394 quantizer.set_quant_config(quant_dtype) 395 396 prepared = prepare_pt2e(m, quantizer) 397 prepared(*inputs) 398 quantized_module = convert_pt2e(prepared) 399 nodes = {node.target for node in quantized_module.graph.nodes} 400 q_and_dq = { 401 torch.ops.quantized_decomposed.quantize_per_tensor.default, 402 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 403 torch.ops.quantized_decomposed.quantize_per_channel.default, 404 torch.ops.quantized_decomposed.dequantize_per_channel.default, 405 } 406 self.assertTrue(nodes.intersection(q_and_dq)) 407 return quantized_module 408 409 def get_prepared_qat_module( 410 self, 411 module: torch.nn.Module, 412 inputs: Tuple[torch.Tensor], 413 is_conv_per_channel: Optional[bool] = True, 414 is_linear_per_channel: Optional[bool] = False, 415 custom_quant_annotations: Tuple[Callable] = (), 416 quant_dtype: QuantDtype = QuantDtype.use_8a8w, 417 ) -> torch.fx.GraphModule: 418 m = torch.export.export_for_training(module, inputs).module() 419 420 quantizer = QnnQuantizer() 421 quantizer.add_custom_quant_annotations(custom_quant_annotations) 422 quantizer.set_per_channel_conv_quant(is_conv_per_channel) 423 quantizer.set_per_channel_linear_quant(is_linear_per_channel) 424 425 if quant_dtype == QuantDtype.use_8a8w: 426 quantizer.set_quant_config(quant_dtype, is_qat=True) 427 else: 428 raise RuntimeError("Shuld not be here") 429 430 prepared = prepare_qat_pt2e(m, quantizer) 431 return torch.ao.quantization.move_exported_model_to_train(prepared) 432 433 def get_converted_sgd_trained_module( 434 self, 435 ori_module: torch.nn.Module, 436 prepared: torch.nn.Module, 437 inputs: Tuple[torch.Tensor], 438 ) -> torch.fx.GraphModule: 439 optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001) 440 criterion = torch.nn.CrossEntropyLoss() 441 output = prepared(*inputs) 442 loss = criterion(output, ori_module(*inputs)) 443 optimizer.zero_grad() 444 loss.backward() 445 optimizer.step() 446 return torch.ao.quantization.quantize_pt2e.convert_pt2e(prepared) 447 448 def split_graph(self, graph_module: torch.fx.GraphModule, division: int): 449 class SplitGraph(ExportPass): 450 """ 451 Split graph based on number of nodes. 452 """ 453 454 def __init__(self, shares): 455 super().__init__() 456 self.shares = shares 457 458 def _insert_clone( 459 self, graph_module: torch.fx.GraphModule 460 ) -> torch.fx.GraphModule: 461 num_graph_nodes = 0 462 for node in graph_module.graph.nodes: 463 num_graph_nodes += 1 if node.op == "call_function" else 0 464 465 if num_graph_nodes % self.shares != 0 or node.op != "call_function": 466 continue 467 468 with graph_module.graph.inserting_after(node): 469 users = list(node.users.keys()) 470 inserted_node = graph_module.graph.create_node( 471 "call_function", 472 exir_ops.edge.aten.clone.default, 473 (node,), 474 ) 475 inserted_node.meta["val"] = node.meta["val"] 476 if "quant_attrs" in node.meta: 477 inserted_node.meta["quant_attrs"] = node.meta["quant_attrs"] 478 for user in users: 479 user.replace_input_with(node, inserted_node) 480 481 def call(self, graph_module: torch.fx.GraphModule): 482 self._insert_clone(graph_module) 483 graph_module.recompile() 484 485 num_graph_nodes = 0 486 for node in graph_module.graph.nodes: 487 num_graph_nodes += 1 if node.op == "call_function" else 0 488 489 SplitGraph(-(num_graph_nodes // -division))(graph_module) 490