1# Copyright (c) Qualcomm Innovation Center, Inc. 2# All rights reserved 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7import argparse 8import os 9import subprocess 10import sys 11from pathlib import Path 12 13from typing import Callable, List, Optional 14 15import numpy as np 16 17import torch 18from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner 19from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype 20from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset 21from executorch.backends.qualcomm.utils.utils import ( 22 capture_program, 23 generate_htp_compiler_spec, 24 generate_qnn_executorch_compiler_spec, 25 get_soc_to_arch_map, 26) 27from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge 28from executorch.exir.backend.backend_api import to_backend 29from executorch.exir.capture._config import ExecutorchBackendConfig 30from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass 31from torch.ao.quantization.observer import MovingAverageMinMaxObserver 32from torch.ao.quantization.quantize_pt2e import ( 33 convert_pt2e, 34 prepare_pt2e, 35 prepare_qat_pt2e, 36) 37 38 39class SimpleADB: 40 """ 41 A wrapper class for communicating with Android device 42 43 Attributes: 44 qnn_sdk (str): QNN SDK path setup in environment variable 45 build_path (str): Path where artifacts were built 46 pte_path (str): Path where executorch binary was stored 47 workspace (str): Folder for storing artifacts on android device 48 device_id (str): Serial number of android device 49 soc_model (str): Chipset of device 50 host_id (str): Hostname of machine where device connects 51 error_only (bool): Redirect stdio and leave error messages only 52 shared_buffer (bool): Apply zero-copy mechanism in runtime 53 runner (str): Runtime executor binary 54 """ 55 56 def __init__( 57 self, 58 qnn_sdk, 59 build_path, 60 pte_path, 61 workspace, 62 device_id, 63 soc_model, 64 host_id=None, 65 error_only=False, 66 shared_buffer=False, 67 dump_intermediate_outputs=False, 68 runner="examples/qualcomm/executor_runner/qnn_executor_runner", 69 ): 70 self.qnn_sdk = qnn_sdk 71 self.build_path = build_path 72 self.pte_path = pte_path if isinstance(pte_path, list) else [pte_path] 73 self.workspace = workspace 74 self.device_id = device_id 75 self.host_id = host_id 76 self.working_dir = Path(self.pte_path[0]).parent.absolute() 77 self.input_list_filename = "input_list.txt" 78 self.etdump_path = f"{self.workspace}/etdump.etdp" 79 self.dump_intermediate_outputs = dump_intermediate_outputs 80 self.debug_output_path = f"{self.workspace}/debug_output.bin" 81 self.output_folder = f"{self.workspace}/outputs" 82 self.htp_arch = get_soc_to_arch_map()[soc_model] 83 self.error_only = error_only 84 self.shared_buffer = shared_buffer 85 self.runner = runner 86 87 def _adb(self, cmd): 88 if not self.host_id: 89 cmds = ["adb", "-s", self.device_id] 90 else: 91 cmds = ["adb", "-H", self.host_id, "-s", self.device_id] 92 cmds.extend(cmd) 93 94 subprocess.run( 95 cmds, stdout=subprocess.DEVNULL if self.error_only else sys.stdout 96 ) 97 98 def push(self, inputs=None, input_list=None, files=None): 99 self._adb(["shell", f"rm -rf {self.workspace}"]) 100 self._adb(["shell", f"mkdir -p {self.workspace}"]) 101 102 # necessary artifacts 103 artifacts = [ 104 *self.pte_path, 105 f"{self.qnn_sdk}/lib/aarch64-android/libQnnHtp.so", 106 ( 107 f"{self.qnn_sdk}/lib/hexagon-v{self.htp_arch}/" 108 f"unsigned/libQnnHtpV{self.htp_arch}Skel.so" 109 ), 110 ( 111 f"{self.qnn_sdk}/lib/aarch64-android/" 112 f"libQnnHtpV{self.htp_arch}Stub.so" 113 ), 114 f"{self.qnn_sdk}/lib/aarch64-android/libQnnHtpPrepare.so", 115 f"{self.qnn_sdk}/lib/aarch64-android/libQnnSystem.so", 116 f"{self.build_path}/{self.runner}", 117 f"{self.build_path}/backends/qualcomm/libqnn_executorch_backend.so", 118 ] 119 input_list_file, input_files = generate_inputs( 120 self.working_dir, self.input_list_filename, inputs, input_list 121 ) 122 123 if input_list_file is not None: 124 # prepare input list 125 artifacts.append(input_list_file) 126 127 for artifact in artifacts: 128 self._adb(["push", artifact, self.workspace]) 129 130 # input data 131 for file_name in input_files: 132 self._adb(["push", file_name, self.workspace]) 133 134 # custom files 135 if files is not None: 136 for file_name in files: 137 self._adb(["push", file_name, self.workspace]) 138 139 def execute(self, custom_runner_cmd=None, method_index=0): 140 self._adb(["shell", f"mkdir -p {self.output_folder}"]) 141 # run the delegation 142 if custom_runner_cmd is None: 143 qnn_executor_runner_args = " ".join( 144 [ 145 f"--model_path {os.path.basename(self.pte_path[0])}", 146 f"--output_folder_path {self.output_folder}", 147 f"--input_list_path {self.input_list_filename}", 148 f"--etdump_path {self.etdump_path}", 149 "--shared_buffer" if self.shared_buffer else "", 150 f"--debug_output_path {self.debug_output_path}", 151 ( 152 "--dump_intermediate_outputs" 153 if self.dump_intermediate_outputs 154 else "" 155 ), 156 f"--method_index {method_index}", 157 ] 158 ) 159 qnn_executor_runner_cmds = " ".join( 160 [ 161 f"cd {self.workspace} &&", 162 f"./qnn_executor_runner {qnn_executor_runner_args}", 163 ] 164 ) 165 else: 166 qnn_executor_runner_cmds = custom_runner_cmd 167 168 self._adb(["shell", f"{qnn_executor_runner_cmds}"]) 169 170 def pull(self, output_path, callback=None): 171 self._adb(["pull", "-a", self.output_folder, output_path]) 172 if callback: 173 callback() 174 175 def pull_etdump(self, output_path, callback=None): 176 self._adb(["pull", self.etdump_path, output_path]) 177 if callback: 178 callback() 179 180 def pull_debug_output(self, etdump_path, debug_ouput_path, callback=None): 181 self._adb(["pull", self.etdump_path, etdump_path]) 182 self._adb(["pull", self.debug_output_path, debug_ouput_path]) 183 if callback: 184 callback() 185 186 187def ptq_calibrate(captured_model, quantizer, dataset): 188 annotated_model = prepare_pt2e(captured_model, quantizer) 189 print("Quantizing(PTQ) the model...") 190 # calibration 191 if callable(dataset): 192 dataset(annotated_model) 193 else: 194 for data in dataset: 195 annotated_model(*data) 196 return annotated_model 197 198 199def qat_train(ori_model, captured_model, quantizer, dataset): 200 data, targets = dataset 201 annotated_model = torch.ao.quantization.move_exported_model_to_train( 202 prepare_qat_pt2e(captured_model, quantizer) 203 ) 204 optimizer = torch.optim.SGD(annotated_model.parameters(), lr=0.00001) 205 criterion = torch.nn.CrossEntropyLoss() 206 for i, d in enumerate(data): 207 print(f"Epoch {i}") 208 if i > 3: 209 # Freeze quantizer parameters 210 annotated_model.apply(torch.ao.quantization.disable_observer) 211 if i > 2: 212 # Freeze batch norm mean and variance estimates 213 annotated_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats) 214 215 output = annotated_model(*d) 216 loss = criterion(output, targets[i]) 217 optimizer.zero_grad() 218 loss.backward() 219 optimizer.step() 220 221 return torch.ao.quantization.quantize_pt2e.convert_pt2e( 222 torch.ao.quantization.move_exported_model_to_eval(annotated_model) 223 ) 224 225 226def make_quantizer( 227 quant_dtype: Optional[QuantDtype] = QuantDtype.use_8a8w, 228 custom_annotations=(), 229 per_channel_conv=True, 230 per_channel_linear=False, 231 act_observer=MovingAverageMinMaxObserver, 232 is_qat=False, 233): 234 quantizer = QnnQuantizer() 235 quantizer.add_custom_quant_annotations(custom_annotations) 236 quantizer.set_per_channel_conv_quant(per_channel_conv) 237 quantizer.set_per_channel_linear_quant(per_channel_linear) 238 quantizer.set_quant_config(quant_dtype, is_qat, act_observer) 239 return quantizer 240 241 242# TODO: refactor to support different backends 243def build_executorch_binary( 244 model, # noqa: B006 245 inputs, # noqa: B006 246 soc_model, 247 file_name, 248 dataset: List[torch.Tensor] | Callable[[torch.fx.GraphModule], None], 249 skip_node_id_set=None, 250 skip_node_op_set=None, 251 quant_dtype: Optional[QuantDtype] = None, 252 custom_quantizer=None, 253 shared_buffer=False, 254 metadata=None, 255 dump_intermediate_outputs=False, 256 custom_pass_config=frozenset(), 257 qat_training_data=None, 258): 259 if quant_dtype is not None: 260 captured_model = torch.export.export(model, inputs).module() 261 if qat_training_data: 262 quantizer = custom_quantizer or make_quantizer( 263 quant_dtype=quant_dtype, is_qat=True 264 ) 265 # qat training 266 annotated_model = qat_train( 267 model, captured_model, quantizer, qat_training_data 268 ) 269 else: 270 quantizer = custom_quantizer or make_quantizer(quant_dtype=quant_dtype) 271 # ptq calibration 272 annotated_model = ptq_calibrate(captured_model, quantizer, dataset) 273 274 quantized_model = convert_pt2e(annotated_model) 275 edge_prog = capture_program(quantized_model, inputs, custom_pass_config) 276 else: 277 edge_prog = capture_program(model, inputs, custom_pass_config) 278 279 backend_options = generate_htp_compiler_spec( 280 use_fp16=False if quant_dtype else True 281 ) 282 qnn_partitioner = QnnPartitioner( 283 generate_qnn_executorch_compiler_spec( 284 soc_model=getattr(QcomChipset, soc_model), 285 backend_options=backend_options, 286 shared_buffer=shared_buffer, 287 dump_intermediate_outputs=dump_intermediate_outputs, 288 ), 289 skip_node_id_set, 290 skip_node_op_set, 291 ) 292 293 executorch_config = ExecutorchBackendConfig( 294 # For shared buffer, user must pass the memory address 295 # which is allocated by RPC memory to executor runner. 296 # Therefore, won't want to pre-allocate 297 # by memory manager in runtime. 298 memory_planning_pass=MemoryPlanningPass( 299 alloc_graph_input=not shared_buffer, 300 alloc_graph_output=not shared_buffer, 301 ), 302 ) 303 304 if metadata is None: 305 exported_program = to_backend(edge_prog.exported_program, qnn_partitioner) 306 exported_program.graph_module.graph.print_tabular() 307 exec_prog = to_edge(exported_program).to_executorch(config=executorch_config) 308 with open(f"{file_name}.pte", "wb") as file: 309 file.write(exec_prog.buffer) 310 else: 311 edge_prog_mgr = EdgeProgramManager( 312 edge_programs={"forward": edge_prog.exported_program}, 313 constant_methods=metadata, 314 compile_config=EdgeCompileConfig(_check_ir_validity=False), 315 ) 316 317 edge_prog_mgr = edge_prog_mgr.to_backend(qnn_partitioner) 318 exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config) 319 with open(f"{file_name}.pte", "wb") as file: 320 file.write(exec_prog_mgr.buffer) 321 322 323def make_output_dir(path: str): 324 if os.path.exists(path): 325 for f in os.listdir(path): 326 os.remove(os.path.join(path, f)) 327 os.removedirs(path) 328 os.makedirs(path) 329 330 331def topk_accuracy(predictions, targets, k): 332 def solve(prob, target, k): 333 _, indices = torch.topk(prob, k=k, sorted=True) 334 golden = torch.reshape(target, [-1, 1]) 335 correct = (golden == indices) * 1.0 336 top_k_accuracy = torch.mean(correct) * k 337 return top_k_accuracy 338 339 cnt = 0 340 for index, pred in enumerate(predictions): 341 cnt += solve(torch.from_numpy(pred), targets[index], k) 342 343 return cnt * 100.0 / len(predictions) 344 345 346def segmentation_metrics(predictions, targets, classes): 347 def make_confusion(goldens, predictions, num_classes): 348 def histogram(golden, predict): 349 mask = golden < num_classes 350 hist = np.bincount( 351 num_classes * golden[mask].astype(int) + predict[mask], 352 minlength=num_classes**2, 353 ).reshape(num_classes, num_classes) 354 return hist 355 356 confusion = np.zeros((num_classes, num_classes)) 357 for g, p in zip(goldens, predictions): 358 confusion += histogram(g.flatten(), p.flatten()) 359 360 return confusion 361 362 eps = 1e-6 363 confusion = make_confusion(targets, predictions, len(classes)) 364 pa = np.diag(confusion).sum() / (confusion.sum() + eps) 365 mpa = np.mean(np.diag(confusion) / (confusion.sum(axis=1) + eps)) 366 iou = np.diag(confusion) / ( 367 confusion.sum(axis=1) + confusion.sum(axis=0) - np.diag(confusion) + eps 368 ) 369 miou = np.mean(iou) 370 cls_iou = dict(zip(classes, iou)) 371 return (pa, mpa, miou, cls_iou) 372 373 374def get_imagenet_dataset( 375 dataset_path, data_size, image_shape, crop_size=None, shuffle=True 376): 377 from torchvision import datasets, transforms 378 379 def get_data_loader(): 380 preprocess = transforms.Compose( 381 [ 382 transforms.Resize(image_shape), 383 transforms.CenterCrop(crop_size or image_shape[0]), 384 transforms.ToTensor(), 385 transforms.Normalize( 386 mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 387 ), 388 ] 389 ) 390 imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess) 391 return torch.utils.data.DataLoader( 392 imagenet_data, 393 shuffle=shuffle, 394 ) 395 396 # prepare input data 397 inputs, targets, input_list = [], [], "" 398 data_loader = get_data_loader() 399 for index, data in enumerate(data_loader): 400 if index >= data_size: 401 break 402 feature, target = data 403 inputs.append((feature,)) 404 targets.append(target) 405 input_list += f"input_{index}_0.raw\n" 406 407 return inputs, targets, input_list 408 409 410def setup_common_args_and_variables(): 411 parser = argparse.ArgumentParser() 412 413 parser.add_argument( 414 "-m", 415 "--model", 416 help="SoC model of current device. e.g. 'SM8550' for Snapdragon 8 Gen 2.", 417 type=str, 418 required=True, 419 ) 420 421 parser.add_argument( 422 "-b", 423 "--build_folder", 424 help="path to cmake binary directory for android, e.g., /path/to/build-android", 425 type=str, 426 required=True, 427 ) 428 429 parser.add_argument( 430 "-H", 431 "--host", 432 help="hostname where android device is connected.", 433 default=None, 434 type=str, 435 ) 436 437 parser.add_argument( 438 "--ip", 439 help="IPC address for delivering execution result", 440 default="", 441 type=str, 442 ) 443 444 parser.add_argument( 445 "--port", 446 help="IPC port for delivering execution result", 447 default=-1, 448 type=int, 449 ) 450 451 parser.add_argument( 452 "-S", 453 "--skip_delegate_node_ids", 454 help="If specified, skip delegation for the specified node based on node ids. Node ids should be seperated by comma. e.g., aten_relu_default_10,aten_relu_default_2", 455 default=None, 456 type=str, 457 ) 458 459 parser.add_argument( 460 "-f", 461 "--skip_delegate_node_ops", 462 help="If specified, skip delegation for the specified op. Node ops should be seperated by comma. e.g., aten.add.Tensor,aten.relu.default", 463 default=None, 464 type=str, 465 ) 466 467 parser.add_argument( 468 "-c", 469 "--compile_only", 470 help="If specified, only compile the model.", 471 action="store_true", 472 default=False, 473 ) 474 475 parser.add_argument( 476 "-s", 477 "--device", 478 help="serial number for android device communicated via ADB.", 479 type=str, 480 ) 481 482 parser.add_argument( 483 "-z", 484 "--shared_buffer", 485 help="Enables usage of shared buffer between application and backend for graph I/O.", 486 action="store_true", 487 ) 488 489 parser.add_argument( 490 "--skip_push", 491 help="If specified, skip pushing files to device.", 492 action="store_true", 493 default=False, 494 ) 495 496 parser.add_argument( 497 "--dump_intermediate_outputs", 498 help="If specified, enable dump intermediate outputs", 499 action="store_true", 500 default=False, 501 ) 502 503 # QNN_SDK_ROOT might also be an argument, but it is used in various places. 504 # So maybe it's fine to just use the environment. 505 if "QNN_SDK_ROOT" not in os.environ: 506 raise RuntimeError("Environment variable QNN_SDK_ROOT must be set") 507 print(f"QNN_SDK_ROOT={os.getenv('QNN_SDK_ROOT')}") 508 509 return parser 510 511 512def parse_skip_delegation_node(args): 513 skip_node_id_set = set() 514 skip_node_op_set = set() 515 516 if args.skip_delegate_node_ids is not None: 517 skip_node_id_set = set(map(str, args.skip_delegate_node_ids.split(","))) 518 print("Skipping following node ids: ", skip_node_id_set) 519 520 if args.skip_delegate_node_ops is not None: 521 skip_node_op_set = set(map(str, args.skip_delegate_node_ops.split(","))) 522 print("Skipping following node ops: ", skip_node_op_set) 523 524 return skip_node_id_set, skip_node_op_set 525 526 527def generate_inputs(dest_path: str, file_name: str, inputs=None, input_list=None): 528 input_list_file = None 529 input_files = [] 530 531 # Prepare input list 532 if input_list is not None: 533 input_list_file = f"{dest_path}/{file_name}" 534 with open(input_list_file, "w") as f: 535 f.write(input_list) 536 f.flush() 537 538 # Prepare input data 539 if inputs is not None: 540 for idx, data in enumerate(inputs): 541 for i, d in enumerate(data): 542 file_name = f"{dest_path}/input_{idx}_{i}.raw" 543 d.detach().numpy().tofile(file_name) 544 input_files.append(file_name) 545 546 return input_list_file, input_files 547