# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import codecs
import getpass
import json
import os
import time
from multiprocessing.connection import Client

import torch
from executorch.backends.qualcomm._passes.build_quant_io import BuildQuantIo

from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner

from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
from executorch.backends.qualcomm.utils.constants import QCOM_QUANTIZED_IO
from executorch.backends.qualcomm.utils.utils import (
    capture_program,
    convert_linear_to_conv2d,
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
    get_soc_to_chipset_map,
)
from executorch.examples.qualcomm.oss_scripts.llama2.model.static_llama import (
    LlamaModel,
    ModelArgs,
)
from executorch.examples.qualcomm.utils import (
    make_output_dir,
    make_quantizer,
    setup_common_args_and_variables,
    SimpleADB,
)
from executorch.exir import EdgeCompileConfig, EdgeProgramManager
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from executorch.extension.llm.export.builder import DType

from sentencepiece import SentencePieceProcessor
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


pte_filename = "llama2_qnn"


def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:
    """
    This function is specific to the matmul op with a 16a8w quantization config
    (16-bit activations, 8-bit weights).
    """
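    # The matmul node itself gets the 16a8w config; annotate_matmul_input1
    # then walks the producer chain of the second matmul input
    # (permute / transpose / cat nodes) and annotates it with an 8a8w config,
    # with cat inputs sharing the qparams of their first input.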
    from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY
    from executorch.backends.qualcomm.quantizer.quantizer import (
        get_16a8w_qnn_ptq_config,
        get_8a8w_qnn_ptq_config,
        QuantizationConfig,
    )
    from torch.ao.quantization.quantizer import (
        QuantizationAnnotation,
        SharedQuantizationSpec,
    )
    from torch.fx import Node

    def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
        input_qspec_map = {}
        input_act = node.args[0]
        input_spec = quantization_config.input_activation
        input_qspec_map[input_act] = input_spec

        input_act1 = node.args[1]
        input_spec1 = quantization_config.weight
        input_qspec_map[input_act1] = input_spec1

        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )

    def annotate_cat(node: Node, quantization_config: QuantizationConfig):
        input_nodes = node.args[0]

        first_input_node = input_nodes[0]
        input_qspec_map = {}
        input_qspec_map[first_input_node] = quantization_config.input_activation
        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
            (first_input_node, node)
        )

        for input_node in input_nodes[1:]:
            if input_node not in input_qspec_map:
                input_qspec_map[input_node] = share_qparams_with_input_act0_qspec

        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=share_qparams_with_input_act0_qspec,
            _annotated=True,
        )

    def annotate_single_in_single_out(
        node: Node, quantization_config: QuantizationConfig
    ) -> None:
        input_qspec_map = {}
        input_act = node.args[0]
        input_qspec_map[input_act] = quantization_config.input_activation

        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )

    def annotate_matmul_input1(node: Node):
        quantization_config_8a8w = get_8a8w_qnn_ptq_config(act_symmetric=True)
        while isinstance(node, Node) and node.op == "call_function":
            if node.target in [
                torch.ops.aten.permute.default,
                torch.ops.aten.transpose.int,
            ]:
                annotate_single_in_single_out(node, quantization_config_8a8w)
                node = node.args[0]
            elif node.target == torch.ops.aten.cat.default:
                annotate_cat(node, quantization_config_8a8w)
                node = node.args[0][0]
            else:
                node = node.args[0]

    quantization_config_16a8w = get_16a8w_qnn_ptq_config()

    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
            annotate_matmul(node, quantization_config_16a8w)
            annotate_matmul_input1(node.args[1])


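# Annotate the output projection (lm head) with a 16a8w per-channel config.
# compile() runs convert_linear_to_conv2d() before quantization, so the
# projection appears here as an aten.conv2d node whose nn_module_stack entry
# is "L['self'].llama.output".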
def annotate_linear_16a8w_in_affine_layer(gm: torch.fx.GraphModule) -> None:
    from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY
    from executorch.backends.qualcomm.quantizer.quantizer import (
        get_ptq_per_channel_quant_config,
        QuantizationConfig,
    )
    from torch.ao.quantization.quantizer import QuantizationAnnotation
    from torch.fx import Node

    def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
        input_qspec_map = {}
        input_act = node.args[0]
        input_spec = quantization_config.input_activation
        input_qspec_map[input_act] = input_spec

        weight = node.args[1]
        input_qspec_map[weight] = quantization_config.weight

        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )

    quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config(
        torch.uint16, weight_dtype=torch.int8
    )
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default:
            if "nn_module_stack" in node.meta:
                module_values_list = list(node.meta["nn_module_stack"].values())
                full_qualified_name = module_values_list[0][0]
                if full_qualified_name == "L['self'].llama.output":
                    annotate_conv2d(
                        node, quantization_config=quantization_config_16a8w_per_channel
                    )


def calibrate(
    example_inputs,
    user_prompts,
    module: torch.fx.GraphModule,
    tokenizer_model_path="tokenizer.model",
):
    sp_model = SentencePieceProcessor(model_file=tokenizer_model_path)
    _, _, atten_mask, k_caches, v_caches = example_inputs

    # TODO: change criteria & support batch inputs if necessary
    pos = torch.tensor(0, dtype=torch.int32)
    token_list = [sp_model.bos_id()]
    for prompt in user_prompts.split():
        token_list += sp_model.encode(prompt)

    def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
        probs_sort, probs_indices = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        mask = probs_sum - probs_sort > top_p
        probs_sort[mask] = 0
        probs_sort /= probs_sort.sum(dim=-1, keepdim=True)
        next_token = torch.multinomial(probs_sort, num_samples=1)
        return probs_indices.gather(dim=-1, index=next_token)

    with torch.no_grad():
        while token_list[-1] != sp_model.eos_id() and pos < 128:
            logits, new_k_caches, new_v_caches = module(
                torch.full((1, 1), token_list[pos]),
                torch.full((1, 1), pos),
                atten_mask,
                *k_caches,
                *v_caches,
            )
            k_caches = [
                torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
                for i, k_cache in enumerate(k_caches)
            ]
            v_caches = [
                torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
                for i, v_cache in enumerate(v_caches)
            ]

            pos += 1
            atten_mask[0][-pos - 1] = 0
            if pos >= len(token_list):
                probs = torch.softmax(logits[:, -1] / 0.8, dim=-1)
                token_list.append(sample_top_p(probs, 0.9).item())

    print(f"calibration data:\n{sp_model.decode(token_list)}")


class SingleLlama:
    def __init__(self, llama_model) -> None:
        super().__init__()
        self.llama_model = llama_model
        self.quant_dtype = None
        self.llama_meta = self.llama_model.get_metadata()
        self.has_quant_io = False
        tokens, pos_ids, atten_mask, k_caches, v_caches = self.get_example_inputs()
        self.inputs = (tokens, pos_ids, atten_mask, *k_caches, *v_caches)

    def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type):
        if not self.has_quant_io:
            return

        # shape of k caches and v caches
        input_cache_shape = {
            (self.llama_meta["get_head_dim"], self.llama_meta["get_max_seq_len"]),
            (self.llama_meta["get_max_seq_len"], self.llama_meta["get_head_dim"]),
        }
        for n in gm.graph.nodes:
            if (
                n.op == "placeholder"
                and len(users := list(n.users)) == 1
                and users[0].meta["val"].size()[-2:] in input_cache_shape
            ):
                n.meta[QCOM_QUANTIZED_IO] = kv_type
            elif n.op == "output":
                for a in n.args[0]:
                    if (
                        a.meta["val"].flatten().size()[0]
                        == self.llama_meta["get_head_dim"]
                    ):
                        a.meta[QCOM_QUANTIZED_IO] = kv_type

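    # PTQ flow: export the model, prepare it with the QNN quantizer, run a
    # short generation loop on the user prompt for calibration, then convert
    # the observers into quantize/dequantize ops.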
    def quantize(self, quant_dtype, custom_annotations=()):
        self.quant_dtype = quant_dtype
        quantizer = make_quantizer(
            quant_dtype=quant_dtype,
            per_channel_conv=True,
            per_channel_linear=True,
            act_observer=MinMaxObserver,
        )
        quantizer.add_custom_quant_annotations(custom_annotations)

        self.has_quant_io = True
        fx_graph_module = None

        with torch.no_grad():
            fx_graph_module = torch.export.export(
                self.llama_model, self.inputs
            ).module()
            fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)
        print("Quantizing the model...")
        calibrate(
            self.get_example_inputs(),
            args.prompt,
            fx_graph_module,
            tokenizer_model_path=args.tokenizer_model,
        )

        self.llama_model = convert_pt2e(fx_graph_module)

    def lowering_modules(
        self, work_space, kv_type=torch.uint8, soc_model=QcomChipset.SM8650
    ):
        executorch_config = ExecutorchBackendConfig(
            passes=[
                BuildQuantIo(),
            ],
            # For shared buffers, the user must pass memory addresses allocated
            # by RPC memory to the executor runner, so we do not want the
            # runtime memory manager to pre-allocate graph I/O.
            memory_planning_pass=MemoryPlanningPass(
                alloc_graph_input=False,
                alloc_graph_output=False,
            ),
            extract_delegate_segments=True,
        )
        with torch.no_grad():
            # backend option
            backend_options = generate_htp_compiler_spec(use_fp16=False)
            compiler_specs = generate_qnn_executorch_compiler_spec(
                soc_model=soc_model,
                backend_options=backend_options,
                shared_buffer=True,
            )
            partitioner = QnnPartitioner(compiler_specs)
            edge_prog = capture_program(self.llama_model, self.inputs)
            self._tag_kv_ios(edge_prog.exported_program.graph_module, kv_type=kv_type)
            edge_prog_mgr = EdgeProgramManager(
                edge_programs={"forward": edge_prog.exported_program},
                constant_methods=self.llama_meta,
                compile_config=EdgeCompileConfig(_check_ir_validity=False),
            )
            edge_prog_mgr = edge_prog_mgr.to_backend(partitioner)
            exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config)
            with open(f"{work_space}/{pte_filename}.pte", "wb") as file:
                exec_prog_mgr.write_to_file(file)

    def get_example_inputs(self):
        return self.llama_model.get_example_inputs()


def compile(args):
    os.makedirs(args.artifact, exist_ok=True)
    start_ts = time.time()
    with open(args.params) as f:
        config = ModelArgs(**json.load(f))
        # TODO: support batch inputs if necessary
        config.max_batch_size = 1
        config.max_seq_len = 1024
    state_dict = torch.load(
        args.checkpoint, weights_only=True, map_location="cpu", mmap=True
    )
    end_load_ts = time.time()
    print("torch.load checkpoint", end_load_ts - start_ts)

    llama_instance = None
    with torch.device("meta"):
        llama_instance = LlamaModel(config, output_new_cache_only=True)
    if "model" in state_dict:
        state_dict = state_dict["model"]
    llama_instance.load_state_dict(
        state_dict,
        strict=False,
        assign=True,
    )
    end_load_state_dict_ts = time.time()
    print("instance.load_state_dict", end_load_state_dict_ts - end_load_ts)

    for layer in llama_instance.layers:
        if getattr(layer.attention, "prepare_sha", None):
            layer.attention.prepare_sha()

    kv_type = torch.uint8
    assert args.ptq in [
        "8a8w",
        "16a4w",
    ], f"No support for quant type {args.ptq}. Supported types are 8a8w and 16a4w."
    quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")
    assert args.tokenizer_model is not None, "Need tokenizer model for calibration"

    if args.dtype_override is not None:
        dtype_override = DType[args.dtype_override]
        llama_instance = llama_instance.to(dtype_override.to_torch_dtype())

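    # convert_linear_to_conv2d rewrites nn.Linear modules as equivalent conv2d
    # ops; annotate_linear_16a8w_in_affine_layer above relies on this to find
    # the output projection as an aten.conv2d node.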
    llama_instance = convert_linear_to_conv2d(llama_instance)
    single_llama = SingleLlama(llama_instance.eval())

    start_quantize_ts = time.time()
    single_llama.quantize(
        quant_dtype,
        custom_annotations=(
            annotate_matmul_16a8w,
            annotate_linear_16a8w_in_affine_layer,
        ),
    )
    end_quantize_ts = time.time()
    print("single_llama.quantize(quant_dtype)", end_quantize_ts - start_quantize_ts)
    single_llama.lowering_modules(
        args.artifact, kv_type=kv_type, soc_model=get_soc_to_chipset_map()[args.model]
    )
    end_lowering_ts = time.time()
    print("Complete Compile", end_lowering_ts - end_quantize_ts)


def inference(args, pre_gen_pte=""):
    workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/single_llama"

    runner_args = " ".join(
        [
            f"--model_path {pte_filename}.pte",
            "--output_folder_path outputs",
            f"--tokenizer_path {os.path.basename(args.tokenizer_bin)}",
            f'--prompt "{args.prompt}"',
            f"--seq_len {args.seq_len}",
            f"--temperature {args.temperature}",
        ]
    )
    runner_cmd = " ".join(
        [
            f"cd {workspace} &&",
            f"./qnn_llama_runner {runner_args}",
        ]
    )

    pte_path = (
        f"{pre_gen_pte}/{pte_filename}.pte"
        if pre_gen_pte
        else f"{args.artifact}/{pte_filename}.pte"
    )
    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=pte_path,
        workspace=workspace,
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
        runner="examples/qualcomm/oss_scripts/llama2/qnn_llama_runner",
    )
    # No pregen inputs, input_list is not required
    adb.push(inputs=[], input_list="", files=[args.tokenizer_bin])
    adb.execute(custom_runner_cmd=runner_cmd)

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)
    outputs = []

    def post_process():
        for f in sorted(
            os.listdir(output_data_folder), key=lambda f: int(f.split("_")[1])
        ):
            with codecs.open(
                os.path.join(output_data_folder, f),
                "r",
                encoding="utf-8",
                errors="replace",
            ) as fdata:
                outputs.append(fdata.read())

    adb.pull(output_path=args.artifact, callback=post_process)

    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(
                json.dumps(
                    {
                        "result": outputs,
                    }
                )
            )
    else:
        for idx, output in enumerate(outputs):
            print(f"Results[{idx}]:\n{output}")


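# Script entry point: parse arguments, then compile the model to a .pte,
# run it on device via adb, or both, depending on --compile_only and
# --pre_gen_pte.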
# flake8: noqa: C901
if __name__ == "__main__":
    parser = setup_common_args_and_variables()
    parser.add_argument(
        "-a",
        "--artifact",
        help="Path for storing generated artifacts and outputs of this example. Default: ./llama2_qnn",
        default="./llama2_qnn",
        type=str,
    )

    parser.add_argument(
        "-P",
        "--ptq",
        help="If specified, run PTQ quantization. Default is 16-bit activations and 4-bit weights (16a4w). Supported options: 8a8w and 16a4w.",
        default="16a4w",
    )

    parser.add_argument(
        "--checkpoint",
        help="Pass llama2 checkpoint.",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--params",
        help="Pass llama2 params json file.",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--tokenizer_bin",
        help="Pass llama2 tokenizer binary.",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--tokenizer_model",
        help="Pass llama2 tokenizer model.",
        type=str,
        default=None,
    )

    parser.add_argument(
        "--prompt",
        help="User prompts for llama2.",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--seq_len",
        help="Output sequence length for llama2.",
        default=128,
        type=int,
    )

    parser.add_argument(
        "--temperature",
        help="Sampling temperature for llama2.",
        default=0.8,
        type=float,
    )

    parser.add_argument(
        "-d",
        "--dtype-override",
        default="fp32",
        type=str,
        choices=["fp32", "fp16"],
        help="Override the dtype of the model. Options: fp32, fp16. Default: fp32.",
    )

    parser.add_argument(
        "--pre_gen_pte",
        help="Run the pre-generated llama2 .pte in the given directory.",
        type=str,
    )

    args = parser.parse_args()
    if args.compile_only and args.pre_gen_pte:
        exit("Cannot set both compile_only and pre_gen_pte as true")

    if args.pre_gen_pte:
        inference(args, args.pre_gen_pte)
        exit(f"Finished running pre_gen_pte from {args.pre_gen_pte}")

    if args.compile_only:
        compile(args)
        exit(f"Finished compile_only; saved to {args.artifact}")

    try:
        compile(args)
        inference(args)
    except Exception as e:
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            raise Exception(e)