# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import getpass
import json
import logging
import os
import sys
import time
from multiprocessing.connection import Client

import torch
from executorch.backends.qualcomm._passes.build_quant_io import BuildQuantIo
from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_matmul_16a8w,
    custom_annotate_llama_last_conv_16a8w,
)
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
from executorch.backends.qualcomm.utils.constants import QCOM_QUANTIZED_IO
from executorch.backends.qualcomm.utils.utils import (
    capture_program,
    convert_linear_to_conv2d,
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
    get_soc_to_chipset_map,
)
from executorch.examples.qualcomm.oss_scripts.llama2.model.static_llama import (
    LlamaModel,
    ModelArgs,
)
from executorch.examples.qualcomm.utils import (
    make_output_dir,
    make_quantizer,
    setup_common_args_and_variables,
    SimpleADB,
)
from executorch.exir import EdgeCompileConfig, EdgeProgramManager
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from executorch.extension.llm.custom_ops import model_sharding
from executorch.extension.llm.export.builder import DType
from executorch.extension.llm.tokenizer.utils import get_tokenizer

from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

sys.setrecursionlimit(4096)
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)
logging.getLogger().setLevel(logging.INFO)

pte_filename = "llama3_2_qnn"


def calibrate(
    example_inputs,
    user_prompts,
    module: torch.fx.GraphModule,
    tokenizer_model_path="tokenizer.model",
):
    sp_model = get_tokenizer(tokenizer_model_path)
    _, _, atten_mask, k_caches, v_caches = example_inputs

    # TODO: change criteria & support batch inputs if necessary
    pos = torch.tensor(0, dtype=torch.int32)
    token_list = sp_model.encode(user_prompts, bos=True, eos=False)

    with torch.no_grad():
        while token_list[-1] != sp_model.eos_id and pos < 511:
            logits, new_k_caches, new_v_caches = module(
                torch.full((1, 1), token_list[pos], dtype=torch.int32),
                torch.full((1, 1), pos),
                atten_mask,
                *k_caches,
                *v_caches,
            )
            k_caches = [
                torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
                for i, k_cache in enumerate(k_caches)
            ]
            v_caches = [
                torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
                for i, v_cache in enumerate(v_caches)
            ]

            pos += 1
            atten_mask[0][-pos - 1] = 0
            if pos >= len(token_list):
                token_list.append(torch.argmax(logits[:, -1], dim=-1).item())

    print(f"calibration data:\n{sp_model.decode(token_list)}")


class SingleLlama:
    def __init__(self, llama_model) -> None:
        super().__init__()
        self.llama_model = llama_model
        self.quant_dtype = None
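        # Model metadata (e.g. get_n_layers, get_head_dim, get_max_seq_len) is
        # kept here; it is later embedded into the .pte as constant methods so
        # the runner can query it at runtime.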
        self.llama_meta = self.llama_model.get_metadata()
        self.has_quant_io = False
        tokens, pos_ids, atten_mask, k_caches, v_caches = self.get_example_inputs()
        self.inputs = (tokens, pos_ids, atten_mask, *k_caches, *v_caches)

    def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type, sharding_type):
        if not self.has_quant_io:
            return

        # Shapes of k caches and v caches
        input_cache_shape = {
            (self.llama_meta["get_head_dim"], self.llama_meta["get_max_seq_len"]),
            (self.llama_meta["get_max_seq_len"], self.llama_meta["get_head_dim"]),
        }
        for n in gm.graph.nodes:
            if (
                n.op == "placeholder"
                and len(users := list(n.users)) == 1
                and users[0].meta["val"].size()[-2:] in input_cache_shape
            ):
                n.meta[QCOM_QUANTIZED_IO] = kv_type
            elif n.op == "output":
                for a in n.args[0]:
                    if (
                        a.meta["val"].flatten().size()[0]
                        == self.llama_meta["get_head_dim"]
                    ):
                        a.meta[QCOM_QUANTIZED_IO] = kv_type

            # Tag sharding I/O
            if exir_ops.edge.llama.fallback.default in [
                u.target for u in list(n.users.keys())
            ] + [n.target]:
                n.meta[QCOM_QUANTIZED_IO] = sharding_type

    def quantize(self, quant_dtype, custom_annotations=()):
        self.quant_dtype = quant_dtype
        quantizer = make_quantizer(
            quant_dtype=quant_dtype,
            per_channel_conv=True,
            per_channel_linear=True,
            act_observer=MinMaxObserver,
        )
        quantizer.add_custom_quant_annotations(custom_annotations)

        self.has_quant_io = True
        fx_graph_module = None

        with torch.no_grad():
            fx_graph_module = torch.export.export(
                self.llama_model, self.inputs
            ).module()
            fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)
        logging.info("Quantizing the model...")
        calibrate(
            self.get_example_inputs(),
            args.prompt,
            fx_graph_module,
            tokenizer_model_path=args.tokenizer_model,
        )

        self.llama_model = convert_pt2e(fx_graph_module)

    def lowering_modules(
        self,
        work_space,
        kv_type=torch.uint8,
        sharding_type=torch.uint16,
        use_fp16=False,
        soc_model=QcomChipset.SM8650,
        num_sharding=0,
    ):
        executorch_config = ExecutorchBackendConfig(
            passes=[
                BuildQuantIo(),
            ],
            # For shared buffer, the user must pass the memory address that is
            # allocated by RPC memory to the executor runner, so the runtime
            # memory manager should not pre-allocate that memory.
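            # Hence alloc_graph_input/alloc_graph_output are disabled below:
            # graph inputs and outputs are excluded from memory planning and
            # are bound to the shared (RPC) buffer at runtime instead.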
            memory_planning_pass=MemoryPlanningPass(
                alloc_graph_input=False,
                alloc_graph_output=False,
            ),
            extract_delegate_segments=True,
        )
        with torch.no_grad():
            # Backend options
            backend_options = generate_htp_compiler_spec(use_fp16=use_fp16)
            compiler_specs = generate_qnn_executorch_compiler_spec(
                soc_model=soc_model,
                backend_options=backend_options,
                shared_buffer=True,
            )
            skip_node_op_set = {"llama.fallback.default"}
            partitioner = QnnPartitioner(
                compiler_specs, skip_node_op_set=skip_node_op_set
            )
            edge_prog = capture_program(
                self.llama_model, self.inputs, custom_pass_config=frozenset()
            )

            if num_sharding > 0:
                model_sharding.split_graph(
                    edge_prog.exported_program,
                    self.llama_meta["get_n_layers"],
                    shares=num_sharding,
                )

            self._tag_kv_ios(
                edge_prog.exported_program.graph_module,
                kv_type=kv_type,
                sharding_type=sharding_type,
            )
            edge_prog_mgr = EdgeProgramManager(
                edge_programs={"forward": edge_prog.exported_program},
                constant_methods=self.llama_meta,
                compile_config=EdgeCompileConfig(_check_ir_validity=False),
            )
            edge_prog_mgr = edge_prog_mgr.to_backend(partitioner)
            exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config)
            with open(f"{work_space}/{pte_filename}.pte", "wb") as file:
                exec_prog_mgr.write_to_file(file)

    def get_example_inputs(self):
        return self.llama_model.get_example_inputs()


def compile(args):
    os.makedirs(args.artifact, exist_ok=True)
    start_ts = time.time()
    with open(args.params) as f:
        config = ModelArgs(**json.load(f))
        # TODO: support batch inputs if necessary
        config.max_batch_size = 1
        config.max_seq_len = 512
    state_dict = torch.load(
        args.checkpoint, weights_only=True, map_location="cpu", mmap=True
    )

    llama_instance = None
    with torch.device("meta"):
        llama_instance = LlamaModel(config, output_new_cache_only=True)
    if "model" in state_dict:
        state_dict = state_dict["model"]
    llama_instance.load_state_dict(
        state_dict,
        strict=False,
        assign=True,
    )
    end_load_ts = time.time()
    logging.info(f"Time for loading checkpoint: {end_load_ts - start_ts}")

    for layer in llama_instance.layers:
        if getattr(layer.attention, "prepare_sha", None):
            layer.attention.prepare_sha()

    use_fp16 = False
    if args.ptq is not None:
        kv_type = torch.uint8
        if args.ptq == "8a8w":
            sharding_type = torch.uint8
        elif args.ptq == "16a4w":
            sharding_type = torch.uint16
        else:
            assert args.ptq in [
                "8a8w",
                "16a4w",
            ], f"Unsupported quant type {args.ptq}. Supported types: 8a8w and 16a4w."
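        # Map the CLI string onto the corresponding QuantDtype member,
        # e.g. "16a4w" -> QuantDtype.use_16a4w.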
        quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")
    else:
        use_fp16 = True
        kv_type = torch.float32
        sharding_type = torch.float32
    assert args.tokenizer_model is not None, "Need tokenizer model for calibration"

    if args.dtype_override is not None:
        dtype_override = DType[args.dtype_override]
        llama_instance = llama_instance.to(dtype_override.to_torch_dtype())

    llama_instance = convert_linear_to_conv2d(llama_instance)
    single_llama = SingleLlama(llama_instance.eval())

    if args.ptq is not None:
        start_quantize_ts = time.time()
        single_llama.quantize(
            quant_dtype,
            custom_annotations=(
                custom_annotate_llama_last_conv_16a8w,
                annotate_matmul_16a8w,
            ),
        )
        end_quantize_ts = time.time()
        logging.info(f"Time for quantizing: {end_quantize_ts - start_quantize_ts}")

    start_lowering_ts = time.time()
    single_llama.lowering_modules(
        args.artifact,
        kv_type=kv_type,
        sharding_type=sharding_type,
        use_fp16=use_fp16,
        soc_model=get_soc_to_chipset_map()[args.model],
        num_sharding=args.num_sharding,
    )
    end_lowering_ts = time.time()
    logging.info(f"Time for compiling: {end_lowering_ts - start_lowering_ts}")


def inference(args, pre_gen_pte=""):
    workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/single_llama"

    runner_args = " ".join(
        [
            f"--model_path {pte_filename}.pte",
            "--output_path outputs/outputs.txt",
            f"--tokenizer_path {os.path.basename(args.tokenizer_model)}",
            f'--prompt "{args.prompt}"',
            f"--seq_len {args.seq_len}",
            f"--temperature {args.temperature}",
        ]
    )
    runner_cmd = " ".join(
        [
            f"cd {workspace} &&",
            f"./qnn_llama3_2_{args.model_size.lower()}_runner {runner_args}",
        ]
    )

    pte_path = (
        f"{pre_gen_pte}/{pte_filename}.pte"
        if pre_gen_pte
        else f"{args.artifact}/{pte_filename}.pte"
    )
    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=pte_path,
        workspace=workspace,
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
        runner=f"examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_{args.model_size.lower()}_runner",
    )
    # No pregen inputs, so input_list is not required
    adb.push(inputs=[], input_list="", files=[args.tokenizer_model])
    adb.execute(custom_runner_cmd=runner_cmd)

    # Collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)
    outputs = []

    def post_process():
        with open(f"{args.artifact}/outputs/outputs.txt", "r") as f:
            outputs.append(f.read())

    adb.pull(output_path=args.artifact, callback=post_process)

    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(
                json.dumps(
                    {
                        "result": outputs,
                    }
                )
            )
    else:
        for idx, output in enumerate(outputs):
            logging.info(f"Results[{idx}]:\n{output}")


# flake8: noqa: C901
if __name__ == "__main__":
    parser = setup_common_args_and_variables()
    parser.add_argument(
        "-a",
        "--artifact",
        help="Path for storing generated artifacts and outputs by this example. Default: ./llama3_2_qnn",
        default="./llama3_2_qnn",
        type=str,
    )

    parser.add_argument(
        "-P",
        "--ptq",
        help="If specified, run post-training quantization. Supported schemes: 8a8w and 16a4w.",
        type=str,
    )

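    # Illustrative invocation (device/SoC related flags such as the target SoC,
    # build folder, and device serial come from setup_common_args_and_variables
    # and are omitted here; file names are placeholders):
    #   python llama.py --checkpoint consolidated.00.pth --params params.json \
    #       --tokenizer_model tokenizer.model --model_size 1B \
    #       --prompt "What is 1+1?" --ptq 16a4w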
    parser.add_argument(
        "--checkpoint",
        help="Pass llama checkpoint.",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--params",
        help="Pass llama params json file.",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--model_size",
        help="Determine which runner is used. For Llama 3.2, only 1B and 3B are supported.",
        choices=["1B", "3B"],
        required=True,
        type=str,
    )

    parser.add_argument(
        "--tokenizer_model",
        help="Pass llama tokenizer model.",
        type=str,
        default=None,
    )

    parser.add_argument(
        "--prompt",
        help="User prompts for llama.",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--seq_len",
        help="Output sequence length for llama.",
        default=128,
        type=int,
    )

    parser.add_argument(
        "--temperature",
        help="Sampling temperature for llama.",
        default=0.8,
        type=float,
    )

    parser.add_argument(
        "-d",
        "--dtype-override",
        default="fp32",
        type=str,
        choices=["fp32", "fp16"],
        help="Override the dtype of the model. Options: fp32, fp16. Default: fp32.",
    )

    parser.add_argument(
        "--pre_gen_pte",
        help="Run the pre-generated llama .pte in the given directory.",
        type=str,
    )

    parser.add_argument(
        "--num_sharding",
        type=int,
        default=0,
        help="Specify the number of splits by inserting the fallback custom op. The graph will be split evenly by layers.",
    )

    args = parser.parse_args()
    if args.compile_only and args.pre_gen_pte:
        exit("Cannot set both compile_only and pre_gen_pte.")

    if args.pre_gen_pte:
        inference(args, args.pre_gen_pte)
        exit(f"Finished running the pre-generated PTE from {args.pre_gen_pte}")

    if args.compile_only:
        compile(args)
        exit(f"Finished compile_only; artifacts saved to {args.artifact}")

    try:
        compile(args)
        inference(args)
    except Exception as e:
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            raise