import os
import sys

if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())
import argparse
import struct
import warnings

import torch

from aot_utils.llm_utils.preformatter import Preformatter
from aot_utils.llm_utils.sanity_checks import (
    check_all_chunks_same_num_layer,
    check_between_inclusive,
    check_exist,
    check_ext,
    check_old_arg,
    check_shapes,
    check_supported_model,
    check_supported_tokenizer,
    check_tokenizer_exist,
    check_weights_exist,
)
from aot_utils.llm_utils.utils import (
    dump_embedding_lut_for_cmdline,
    generate_mask,
    get_dest_path,
    get_dirname,
    get_embedding_layer,
    get_exp_name,
    get_export_shapes,
    get_master_rot_emb,
    get_normalized_config,
    load_checkpoints,
    resolve_model_classes,
)
from datasets import load_dataset
from executorch import exir
from executorch.backends.mediatek import (
    NeuropilotPartitioner,
    NeuropilotQuantizer,
    Precision,
)
from executorch.exir.backend.backend_details import CompileSpec
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from tqdm import tqdm

warnings.filterwarnings("ignore")


def get_argument_parser():
    parser = argparse.ArgumentParser(
        description="Run Export to ET for supported LLM models.", allow_abbrev=False
    )
    parser.add_argument(
        "config",
        type=str,
        help="[Required] Model config json file. "
        "Model config must be in same directory as all model weight bins and tokenizer files.",
    )
    parser.add_argument(
        "-p",
        "--precision",
        type=str,
        default="A16W8",
        choices=["A16W4", "A16W8", "A16W16", "A8W4", "A8W8"],
        help="Precision to quantize entire model to.",
    )
    parser.add_argument(
        "-d",
        "--dataset",
        type=str,
        default=None,
        help="Calibration dataset name or path to dataset. Defaults to None to use random inputs.",
    )
    parser.add_argument(
        "-n",
        "--num_chunks",
        type=int,
        default=4,
        help="Number of chunks to cut the model into. Defaults to 4.",
    )
    parser.add_argument(
        "-r",
        "--response_cap",
        type=int,
        default=9,
        help="Max number of response tokens to save during calibration. Defaults to 9.",
    )
    parser.add_argument(
        "--preformatter",
        type=str,
        default=None,
        help="Preformatter template to wrap the input with. Defaults to None.",
    )
    parser.add_argument(
        "-shapes",
        nargs="+",
        help="[Required] Expected input shapes to reconfigure TFLites to. Space separated list of "
        "shapes in the format: xtyc (e.g. 32t512c)",
    )

    return parser

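# A minimal sketch of a typical invocation (script and file names here are
# hypothetical; the config directory is expected to contain the weight bins and
# tokenizer files):
#
#   python export_llm.py path/to/config.json -p A16W4 -n 4 -shapes 32t512c 1t512c \
#       -d calib_prompts.txt --preformatter alpaca_template.json
#
# Each shape follows the xtyc pattern, where x is presumably the number of input
# tokens and y the cache size (e.g. 32t512c).
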
# flake8: noqa: F405
def args_sanity_checks(args):
    check_old_arg(args.config)
    check_exist(args.config, "Config file")
    check_ext(args.config, ".json", "Config file")
    config = get_normalized_config(args.config)

    weight_dir = get_dirname(args.config)
    check_tokenizer_exist(weight_dir)
    check_weights_exist(weight_dir)

    check_supported_model(config)
    check_supported_tokenizer(config)

    if args.preformatter is not None:
        check_exist(args.preformatter, "Preformatter json file")
        check_ext(args.preformatter, ".json", "preformatter")

    if args.dataset is not None:
        check_exist(args.dataset)

    check_between_inclusive(args.num_chunks, 1, config.num_hidden_layers, "num_chunks")

    check_shapes(args.shapes)


def print_args(args, exp_name):
    print("Please check if all arguments are correct:")
    print(f"Config file: {args.config}")
    print(f"Output pte folder: pte/{exp_name}")
    print(f"Quantization precision: {args.precision}")
    print(f"Preformatter: {args.preformatter}")
    print(f"Calibration Dataset: {args.dataset}")
    print(f"Max Response Tokens: {args.response_cap}")
    print(f"Number of chunks: {args.num_chunks}")
    print(f"Export shape(s): {args.shapes}")
    print()


def apply_preformatter(inp, preformatter=None):
    formatted_text = preformatter.generate_prompt(inp["text"])
    inp["text"] = formatted_text
    print(f"Formatted Prompt:\n{formatted_text}")
    return inp


def tokenize_dataset(inp, tokenizer):
    text = inp["text"]
    inp_encoded = tokenizer(text, return_tensors="pt")  # dict
    inp_encoded.pop("attention_mask")
    inp_encoded = inp_encoded["input_ids"]
    inp_encoded = inp_encoded.to(torch.int32)
    inp["input_ids"] = inp_encoded
    inp.pop("text")
    return inp


def reset_cache(
    num_chunks, num_key_value_heads, num_blocks_per_chunk, head_dim, max_cache_size
):
    cache = []
    for i in range(num_chunks):
        curr_chunk_cache = torch.zeros(
            (
                2 * num_blocks_per_chunk[i],
                num_key_value_heads,
                max_cache_size,  # generate fixed cache as torch dynamic shape cannot handle 2 dynamic dims
                head_dim,
            ),
            dtype=torch.float32,
        )
        cache.append(curr_chunk_cache)
    return cache


def forward_and_save(
    models,
    hidden_state,
    cache,
    mask,
    pos_emb,
    model_input_dict,
    num_blocks_per_chunk,
    batch_name,
):
    for chunk_idx in range(len(models)):
        cache_in = cache[chunk_idx]

        # Record this chunk's inputs under batch_name; the entry starts out as None
        # on the first (prompt) batch and is extended on subsequent batches.
        if model_input_dict[str(chunk_idx)] is not None:
            model_input_dict[str(chunk_idx)] = {
                **model_input_dict[str(chunk_idx)],
                batch_name: {
                    "hidden_state": hidden_state,
                    "mask": mask,
                    "pos_emb": pos_emb,
                    "cache": cache_in,
                },
            }
        else:
            model_input_dict[str(chunk_idx)] = {
                batch_name: {
                    "hidden_state": hidden_state,
                    "mask": mask,
                    "pos_emb": pos_emb,
                    "cache": cache_in,
                }
            }
        with torch.no_grad():
            model_out = models[chunk_idx](
                hidden_state, mask, pos_emb, *torch.split(cache_in, 1, dim=0)
            )
        hidden_state = model_out[0]
        cache[chunk_idx] = torch.cat(
            model_out[1 : 1 + 2 * num_blocks_per_chunk[chunk_idx]], dim=0
        ).clone()
    return hidden_state, cache

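# Note (inferred from reset_cache / forward_and_save above): each chunk's cache is a
# single tensor of shape (2 * num_blocks, num_kv_heads, max_cache_size, head_dim);
# torch.split(cache_in, 1, dim=0) feeds the chunk one key or value tensor per block,
# and the chunk's outputs after the hidden state are concatenated back into that
# layout. The inputs recorded under batch_name are replayed later by calibrate_model.
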
def prepare_model_inputs(
    inp,
    models,
    embedding_layer,
    master_rot_emb,
    num_blocks_per_chunk,
    num_key_value_heads,
    head_dim,
    max_cache_size,
    eos_token_id_tensor,
    response_cap,
):
    model_input_dict = {str(i): None for i in range(len(models))}
    input_ids = inp.pop("input_ids")
    hidden_state = embedding_layer(torch.tensor(input_ids))
    input_length = hidden_state.shape[1]
    # Assume fixed cache size
    mask = generate_mask(max_cache_size, 0, input_length, input_length)
    pos_emb = master_rot_emb[:, :, :input_length, :]
    # cache shape: num chunks of (2 * num_blocks, num kv heads, cache size, head dim)
    cache = reset_cache(
        len(models), num_key_value_heads, num_blocks_per_chunk, head_dim, max_cache_size
    )  # empty kv
    logits, cache = forward_and_save(
        models,
        hidden_state,
        cache,
        mask,
        pos_emb,
        model_input_dict,
        num_blocks_per_chunk,
        "prompt",
    )
    next_token_logits = logits[:, -1, :]  # logits at the last token position
    next_token = torch.argmax(next_token_logits, dim=-1)
    response_count = 0
    seq_length = input_length
    while True:
        curr_input_id = next_token[:, None].to(torch.int32)
        input_length = curr_input_id.shape[1]
        hidden_state = embedding_layer(curr_input_id)
        mask = generate_mask(max_cache_size, seq_length, input_length, input_length)
        pos_emb = master_rot_emb[:, :, seq_length : seq_length + input_length, :]
        logits, cache = forward_and_save(
            models,
            hidden_state,
            cache,
            mask,
            pos_emb,
            model_input_dict,
            num_blocks_per_chunk,
            f"response{response_count}",
        )
        next_token_logits = logits[:, -1, :]
        next_token = torch.argmax(next_token_logits, dim=-1)
        if next_token == eos_token_id_tensor:
            print(f"Found EOS on batch: {response_count}")
            break

        response_count += 1
        seq_length += input_length
        if response_count == response_cap:
            break

    return model_input_dict


def calibrate_model(model, cal_dataset, chunk_idx: str):
    with torch.no_grad():
        for inp in tqdm(cal_dataset, desc="Calibrating Model: "):
            # replay the recorded prompt and response batches through the prepared graph
            for batch in tqdm(inp[chunk_idx].keys(), desc="Batch: "):
                if inp[chunk_idx][batch] is not None:
                    inputs_embeds = torch.tensor(inp[chunk_idx][batch]["hidden_state"])
                    mask = torch.tensor(inp[chunk_idx][batch]["mask"])
                    pos_emb = torch.tensor(inp[chunk_idx][batch]["pos_emb"])
                    cache = torch.tensor(inp[chunk_idx][batch]["cache"])
                    model(inputs_embeds, mask, pos_emb, *torch.split(cache, 1, dim=0))

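# Lowering flow implemented in export_to_et_ir below:
#   torch.export.export_for_training -> prepare_pt2e (NeuropilotQuantizer)
#   -> calibration (dataset batches or a single dummy pass) -> convert_pt2e
#   -> per-shape torch.export.export -> exir.to_edge
#   -> to_backend(NeuropilotPartitioner) -> to_executorch -> write .pte file.
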
def export_to_et_ir(
    output_folder,
    exp_name,
    model,
    precision,
    max_num_token,
    max_cache_size,
    chunk_idx,
    export_shapes,
    cal_dataset=None,
):
    print(f"Exporting Chunk {chunk_idx} to PTE")
    example_inputs, dynamic_shapes = model.get_example_inputs(
        max_num_token, max_cache_size, True
    )
    print("Getting pre autograd ATen Dialect Graph")
    pre_autograd_aten_dialect = torch.export.export_for_training(
        model, example_inputs, dynamic_shapes=dynamic_shapes
    ).module()  # NOTE: Will be replaced with export
    quantizer = NeuropilotQuantizer()
    quantizer.setup_precision(getattr(Precision, precision))
    prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer)
    # at this point quant min max are inf
    if cal_dataset is not None:
        calibrate_model(prepared_graph, cal_dataset, str(chunk_idx))
    else:
        prepared_graph(*example_inputs)  # dummy calibration
    converted_graph = convert_pt2e(prepared_graph, fold_quantize=False)

    print("Getting ATen Dialect Graph")
    # Fixed Shape Export Here
    for shape, ntok_and_cache in export_shapes.items():
        dest_path = get_dest_path(output_folder, exp_name, shape, chunk_idx)
        print(f"Exporting Shape {shape} to:\n{dest_path}")
        example_inputs = model.get_example_inputs(*ntok_and_cache)
        aten_dialect: exir.ExportedProgram = torch.export.export(
            converted_graph, example_inputs
        )

        print("Lowering to Edge Dialect Graph")
        edge_program: exir.EdgeProgramManager = exir.to_edge(
            aten_dialect,
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        del aten_dialect

        print("Delegating Edge Program to Neuropilot Backend")
        compile_spec = [
            CompileSpec("gno", struct.pack("3s", b"LTS")),
            CompileSpec("gno-exp", struct.pack("0s", b"")),
            CompileSpec("gno-non-4d-tiling", struct.pack("0s", b"")),
            CompileSpec("ImportForever", struct.pack("?", True)),
        ]
        partitioner = NeuropilotPartitioner(compile_spec)
        delegated_program = edge_program.to_backend(partitioner)
        print("Exported Delegated Program:")
        print(delegated_program.exported_program())
        del edge_program

        print("Transforming delegated program to executorch backend")
        executorch_program = delegated_program.to_executorch(
            config=exir.ExecutorchBackendConfig(
                memory_planning_pass=exir.passes.MemoryPlanningPass(
                    alloc_graph_input=False,
                    alloc_graph_output=False,
                ),
                extract_delegate_segments=True,
            )
        )

        print(f"ET Model Dest: {dest_path}\n")
        os.makedirs(dest_path.rsplit("/", 1)[0], exist_ok=True)
        with open(dest_path, "wb") as file:
            file.write(executorch_program.buffer)


def main():
    parser = get_argument_parser()
    args = parser.parse_args()
    args_sanity_checks(args)
    if args.dataset is None:
        exp_name = f"{get_exp_name(args.config)}_{args.precision}_dummy_cal_{args.num_chunks}_chunks"
    else:
        exp_name = (
            f"{get_exp_name(args.config)}_{args.precision}_{args.num_chunks}_chunks"
        )
    print_args(args, exp_name)

    config, weight_dir, tokenizer_class, chunk_class = resolve_model_classes(
        args.config
    )
    tokenizer = tokenizer_class.from_pretrained(weight_dir)
    if args.preformatter is not None:
        preformatter = Preformatter(args.preformatter)

    head_dim = int(config.hidden_size / config.num_attention_heads)

    # Evenly distribute the layers across chunks.
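    # e.g. 32 hidden layers with --num_chunks 4 -> [8, 8, 8, 8]; an uneven split such
    # as 30 layers over 4 chunks would yield [8, 8, 7, 7], which
    # check_all_chunks_same_num_layer below rejects.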
    num_blocks_per_chunk = [
        (config.num_hidden_layers // args.num_chunks)
        + (i < (config.num_hidden_layers % args.num_chunks))
        for i in range(args.num_chunks)
    ]
    check_all_chunks_same_num_layer(num_blocks_per_chunk)  # noqa: F405

    output_folder = os.path.join("pte", exp_name)

    # Load all collected checkpoint files into one giant state_dict
    state_dict = load_checkpoints(weight_dir)

    dump_embedding_lut_for_cmdline(weight_dir, state_dict, config)

    export_shapes, max_num_token, max_cache_size = get_export_shapes(args.shapes)
    print(f"Export shapes: {export_shapes}")
    print(f"Max Num Token: {max_num_token}")
    print(f"Max Cache Size: {max_cache_size}")

    if args.dataset is not None:
        embedding_layer = get_embedding_layer(config, weight_dir, state_dict)

    # Instantiate model chunks
    print("Instantiating submodels")
    models = []
    for chunk_idx, num_blocks in enumerate(num_blocks_per_chunk):
        chunk = chunk_class(
            config,
            num_blocks,
            chunk_idx=chunk_idx,
            dtype=torch.float32,
            include_tail=(chunk_idx == args.num_chunks - 1),
            jit_trace=True,
        )
        chunk = chunk.load_weights(state_dict, sum(num_blocks_per_chunk[:chunk_idx]))
        models.append(chunk)

    cal_dataset = None
    if args.dataset is not None:
        cal_dataset = load_dataset("text", data_files=args.dataset, split="train")
        master_rot_emb = get_master_rot_emb(config, dtype=torch.float32)
        if args.preformatter is not None:
            cal_dataset = cal_dataset.map(
                apply_preformatter, fn_kwargs={"preformatter": preformatter}
            )
        cal_dataset = cal_dataset.map(
            tokenize_dataset, fn_kwargs={"tokenizer": tokenizer}
        )
        print("Preparing Model Calibration Inputs...")
        cal_dataset = cal_dataset.map(
            prepare_model_inputs,
            fn_kwargs={
                "models": models,
                "embedding_layer": embedding_layer,
                "master_rot_emb": master_rot_emb,
                "num_blocks_per_chunk": num_blocks_per_chunk,
                "num_key_value_heads": config.num_key_value_heads,
                "head_dim": head_dim,
                "max_cache_size": max_cache_size,
                "eos_token_id_tensor": torch.tensor(tokenizer.eos_token_id),
                "response_cap": args.response_cap,
            },
        )

    for chunk_idx, chunk in enumerate(models):
        export_to_et_ir(
            output_folder,
            exp_name,
            chunk,
            args.precision,
            max_num_token,
            max_cache_size,
            chunk_idx,
            export_shapes,
            cal_dataset,
        )


if __name__ == "__main__":
    main()