1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7# pyre-unsafe 8 9# Example script for exporting Llama2 to flatbuffer 10 11import argparse 12import copy 13import json 14import logging 15import re 16import shlex 17from enum import Enum 18from json import JSONDecodeError 19from pathlib import Path 20from typing import Callable, List, Optional, Union 21 22import pkg_resources 23import torch 24 25from executorch.devtools.etrecord import generate_etrecord 26 27from executorch.extension.llm.export.builder import DType, LLMEdgeManager 28 29from executorch.extension.llm.export.partitioner_lib import ( 30 get_coreml_partitioner, 31 get_mps_partitioner, 32 get_qnn_partitioner, 33 get_vulkan_partitioner, 34 get_xnnpack_partitioner, 35) 36 37from executorch.extension.llm.export.quantizer_lib import ( 38 get_coreml_quantizer, 39 get_pt2e_quantization_params, 40 get_pt2e_quantizers, 41 get_qnn_quantizer, 42 get_vulkan_quantizer, 43) 44from executorch.util.activation_memory_profiler import generate_memory_trace 45 46from ..model_factory import EagerModelFactory 47from .source_transformation.apply_spin_quant_r1_r2 import ( 48 fuse_layer_norms, 49 get_model_with_r1_r2, 50) 51 52from .source_transformation.attention import replace_attention_to_attention_sha 53from .source_transformation.quantize import ( 54 get_quant_embedding_transform, 55 get_quant_weight_transform, 56) 57from .source_transformation.quantized_kv_cache import ( 58 replace_kv_cache_with_quantized_kv_cache, 59) 60from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm 61 62from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis 63from .source_transformation.sdpa import ( 64 replace_causal_mask, 65 replace_kv_cache_with_coreml_kv_cache, 66 replace_kv_cache_with_simple_kv_cache, 67 replace_sdpa_with_coreml_sdpa, 68 replace_sdpa_with_custom_op, 69 replace_sdpa_with_flex_sdpa, 70 replace_sdpa_with_simple_sdpa, 71) 72from .source_transformation.vulkan_rope import replace_with_vulkan_rotary_emb 73 74IS_FBCODE = True # os.environ.get("FBCODE_PLATFORM", False) 75FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" 76logging.basicConfig(level=logging.INFO, format=FORMAT) 77 78pkg_name = __name__ 79verbosity_setting = None 80 81 82EXECUTORCH_DEFINED_MODELS = ["stories110m", "llama2", "llama3", "llama3_1", "llama3_2"] 83TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"] 84 85 86class WeightType(Enum): 87 LLAMA = "LLAMA" 88 FAIRSEQ2 = "FAIRSEQ2" 89 90 91def set_pkg_name(name: str) -> None: 92 global pkg_name 93 pkg_name = name 94 95 96def get_resource_path(resource_name) -> str: 97 return pkg_resources.resource_filename(pkg_name, resource_name) 98 99 100def set_verbosity(val): 101 global verbosity_setting 102 verbosity_setting = val 103 104 105def verbose_export(): 106 return verbosity_setting 107 108 109def build_model( 110 modelname: str = "llama3", 111 extra_opts: str = "", 112 *, 113 par_local_output: bool = False, 114 resource_pkg_name: str = __name__, 115) -> str: 116 if False: # par_local_output: 117 output_dir_path = "par:." 118 else: 119 output_dir_path = "." 120 121 argString = f"--model {modelname} --checkpoint par:model_ckpt.pt --params par:model_params.json {extra_opts} --output-dir {output_dir_path}" 122 parser = build_args_parser() 123 args = parser.parse_args(shlex.split(argString)) 124 # pkg_name = resource_pkg_name 125 return export_llama(args) 126 127 128def build_args_parser() -> argparse.ArgumentParser: 129 ckpt_dir = f"{Path(__file__).absolute().parent.as_posix()}" 130 parser = argparse.ArgumentParser() 131 parser.add_argument("-o", "--output-dir", default=".", help="output directory") 132 # parser.add_argument( 133 # "-q", "--quantized_ckpt", default=None, help="quantized checkpoint file" 134 # ) 135 parser.add_argument( 136 "--model", 137 default="llama3", 138 choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS, 139 help="The Lllama model to export. stories110M, llama2, llama3, llama3_1, and llama3_2 use the same underlying LlamaTransformer architecture defined in ExecuTorch. All other models use TorchTune model definitions.", 140 ) 141 parser.add_argument( 142 "-E", 143 "--embedding-quantize", 144 default=None, 145 type=str, 146 help="type of embedding quantization, '<bitwidth>,<groupsize>', e.g., '8,1024'.", 147 ) 148 parser.add_argument( 149 "--pt2e_quantize", 150 default=None, 151 choices=[ 152 "xnnpack_dynamic", 153 "xnnpack_dynamic_qc4", 154 "qnn_8a8w", 155 "qnn_16a16w", 156 "qnn_16a4w", 157 "coreml_c4w", 158 "coreml_8a_c8w", 159 "coreml_8a_c4w", 160 "coreml_baseline_8a_c8w", 161 "coreml_baseline_8a_c4w", 162 "vulkan_8w", 163 ], 164 help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.", 165 ) 166 167 parser.add_argument( 168 "-qmode", 169 "--quantization_mode", 170 type=_qmode_type, 171 default=None, 172 help="type of quantization", 173 ) 174 175 parser.add_argument( 176 "-c", 177 "--checkpoint", 178 default=f"{ckpt_dir}/params/demo_rand_params.pth", 179 help="checkpoint path", 180 ) 181 182 parser.add_argument( 183 "--checkpoint_dir", 184 default=None, 185 help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.", 186 ) 187 188 parser.add_argument( 189 "--use_qnn_sha", 190 action="store_true", 191 help="Change multi head attention to multiple single head attention for qnn backend (Qualcomm)", 192 ) 193 194 parser.add_argument( 195 "--calibration_tasks", 196 nargs="+", 197 type=str, 198 default=None, 199 help="Tasks for GPTQ calibration from lm_eval", 200 ) 201 parser.add_argument( 202 "--calibration_limit", 203 type=int, 204 default=None, 205 help="number of samples used for calibration from lm_eval", 206 ) 207 parser.add_argument( 208 "--calibration_seq_length", 209 type=int, 210 default=None, 211 help="Sequence length for GPTQ calibration from lm_eval", 212 ) 213 parser.add_argument( 214 "--calibration_data", 215 type=str, 216 default="Once upon a time", 217 help="Calibration prompts from users", 218 ) 219 parser.add_argument( 220 "-t", 221 "--tokenizer_path", 222 default=None, 223 help="tokenizer path (Note: .model not .bin)", 224 ) 225 parser.add_argument( 226 "-kv", 227 "--use_kv_cache", 228 default=False, 229 action="store_true", 230 help="Whether or not to export a model using kv cache", 231 ) 232 parser.add_argument( 233 "--quantize_kv_cache", 234 default=False, 235 action="store_true", 236 help="Whether or not to export a model using int8 per token quantized kv cache", 237 ) 238 parser.add_argument( 239 "--num_sharding", 240 type=int, 241 default=0, 242 help="Specify the number of splits by inserting the fallback custom op. The graph will be split evenly by layers.", 243 ) 244 parser.add_argument( 245 "--use_sdpa_with_kv_cache", 246 default=False, 247 action="store_true", 248 help="Whether to use sdpa_with_kv_cache update op when using kv cache", 249 ) 250 parser.add_argument( 251 "--disable_dynamic_shape", 252 dest="enable_dynamic_shape", 253 default=True, # Enable this by default 254 action="store_false", 255 help="Enable dynamic shape along seq dim. Used for faster prefill", 256 ) 257 parser.add_argument( 258 "-p", 259 "--params", 260 default=f"{ckpt_dir}/params/demo_config.json", 261 help="config.json", 262 ) 263 parser.add_argument( 264 "--optimized_rotation_path", 265 default=None, 266 required=False, 267 help="[QNN backend] Optimized rotation checkpoint path. Just apply R1/R2 here." 268 "You can download the optimized rotation matrices from https://github.com/facebookresearch/SpinQuant/tree/main", 269 ) 270 parser.add_argument( 271 "-m", 272 "--metadata", 273 default=None, 274 help='metadata string in json format. Example {"key": 1, "key2": "value2"}', 275 ) 276 parser.add_argument( 277 "-s", 278 "--so_library", 279 default=None, 280 required=False, 281 help="shared library for quantized operators", 282 ) 283 parser.add_argument( 284 "--profile_memory", 285 required=False, 286 action="store_true", 287 help="Generate chrome trace of activation memory for intermediate tensors.", 288 ) 289 parser.add_argument( 290 "-prof", 291 "--profile_path", 292 default=None, 293 help="Use cProfile to profile model export. Results saved to profile_path as a html file.", 294 ) 295 parser.add_argument( 296 "-G", 297 "--group_size", 298 type=int, 299 default=None, 300 help="group_size for weight quantization", 301 ) 302 303 parser.add_argument( 304 "-d", 305 "--dtype-override", 306 default="fp32", 307 type=str, 308 choices=["fp32", "fp16", "bf16"], 309 help="Override the dtype of the model (default is the checkpoint dtype)." 310 "Options: fp32, fp16, bf16. Please be aware that only some backends support fp16 and bf16.", 311 ) 312 313 parser.add_argument( 314 "-n", 315 "--output_name", 316 default=None, 317 help="Override the output filename of the saved pte model file.", 318 ) 319 320 parser.add_argument( 321 "--max_seq_length", 322 type=int, 323 default=128, 324 help="maximum length sequence to evaluate", 325 ) 326 327 parser.add_argument("-2", "--fairseq2", action="store_true") 328 parser.add_argument("-v", "--verbose", action="store_true") 329 parser.add_argument( 330 "-X", 331 "--xnnpack", 332 action="store_true", 333 help="Delegate to DQLinear ops to the xnnpack backend", 334 ) 335 parser.add_argument( 336 "--xnnpack-extended-ops", 337 action="store_true", 338 help="Delegate more operators beyond DQLinear to the xnnpack backend. Requires -X or --xnnpack to be set.", 339 ) 340 parser.add_argument("-V", "--vulkan", action="store_true") 341 parser.add_argument("--mps", action="store_true") 342 parser.add_argument("--coreml", action="store_true") 343 parser.add_argument( 344 "--coreml-enable-state", 345 action="store_true", 346 help="This option is only for coreml, and is only supported for MacOS15+/iOS18+", 347 ) 348 parser.add_argument( 349 "--coreml-preserve-sdpa", 350 action="store_true", 351 help="This option is only for coreml: Preserve sdpa in torch edge program to use coreml iOS18.sdpa op", 352 ) 353 parser.add_argument( 354 "--coreml-quantize", 355 default=None, 356 choices=["b4w"], 357 help="This option is only for coreml: Use coreml quantization, e.g. b4w (for blockwise 4 bit weight)", 358 ) 359 parser.add_argument( 360 "--coreml-ios", 361 type=int, 362 default=15, 363 choices=(15, 16, 17, 18), 364 help="This option is only for coreml: The minimum iOS version to deploy", 365 ) 366 parser.add_argument( 367 "--qnn", 368 action="store_true", 369 help="Delegate llama2 to qnn backend (Qualcomm), please use it --kv_cahce=True", 370 ) 371 372 parser.add_argument( 373 "--expand_rope_table", 374 default=False, 375 action="store_true", 376 help="[Temp workaround] Expand sin/cos table in head dim to take vectorized path in optimized kernels.", 377 ) 378 379 parser.add_argument( 380 "--generate_etrecord", 381 action="store_true", 382 required=False, 383 default=False, 384 help="Generate the ETRecord debug artifact.", 385 ) 386 387 parser.add_argument( 388 "--generate_full_logits", 389 action="store_true", 390 required=False, 391 default=False, 392 help="Generate logits for all inputs.", 393 ) 394 395 parser.add_argument( 396 "--soc_model", 397 help="[QNN backend] SoC model of current device. e.g. 'SM8650' for Snapdragon 8 Gen 3.", 398 type=str, 399 required=False, 400 default="SM8650", 401 ) 402 403 parser.add_argument( 404 "-sq", 405 "--use_spin_quant", 406 type=str, 407 default=None, 408 choices=["cuda", "native"], 409 help="Use SpinQuant for better quantization performance. Only support cuda and native.", 410 ) 411 412 parser.add_argument( 413 "-qat", 414 "--use_qat", 415 default=False, 416 action="store_true", 417 help="Whether the checkpoin is pre-quantized with QAT or not.", 418 ) 419 420 parser.add_argument( 421 "-lora", 422 "--use_lora", 423 type=int, 424 default=0, 425 help="Whether the checkpoint contains LoRA adaptors or not. 0: no LoRA adaptors; " 426 "otherwise, it means the rank of LoRA adaptors. Currently it only works if QAT is enabled.", 427 ) 428 429 parser.add_argument( 430 "--preq_mode", 431 type=str, 432 default=None, 433 choices=["8da4w", "8da4w_output_8da8w"], 434 help="Quantization mode used for pre-quantized checkpoint. Only support 8da4w and 8da4w_output_8da8w right now.", 435 ) 436 437 parser.add_argument( 438 "--preq_group_size", 439 type=int, 440 default=32, 441 help="group_size for pre-quantized checkpoint weight quantization", 442 ) 443 444 parser.add_argument( 445 "--preq_embedding_quantize", 446 default="8,0", 447 type=str, 448 help="type of embedding quantization for pre-quantized checkpoint, '<bitwidth>,<groupsize>', e.g., '8,1024'.", 449 ) 450 451 parser.add_argument( 452 "--output_prune_map", 453 default=None, 454 help="path to the output pruning token mapping file (token_map.json)", 455 ) 456 457 parser.add_argument( 458 "--input_prune_map", 459 default=None, 460 help="path to the input pruning token mapping file (token_map.json)", 461 ) 462 463 parser.add_argument( 464 "--export_only", 465 default=False, 466 action="store_true", 467 help="If true, stops right after torch.export() and saves the exported model.", 468 ) 469 return parser 470 471 472def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str: 473 path = str(path) 474 475 if verbose_export(): 476 print(f"creating canonical path for {path}") 477 478 if not path.startswith("par:"): 479 return path 480 481 if not IS_FBCODE: 482 print("not FBCODE") 483 return path[4:] 484 else: 485 return_val = pkg_resources.resource_filename(pkg_name, path[4:]) 486 if verbose_export(): 487 print(f"canonical name is: {return_val}") 488 return return_val 489 490 491def export_llama(args) -> str: 492 if args.profile_path is not None: 493 try: 494 from executorch.util.python_profiler import CProfilerFlameGraph 495 496 with CProfilerFlameGraph(args.profile_path): 497 builder = _export_llama(args) 498 assert ( 499 filename := builder.get_saved_pte_filename() 500 ) is not None, "Fail to get file name from builder" 501 return filename 502 except ImportError: 503 print( 504 "Please run `pip install snakeviz` to install required dependencies for cProfiler flamegraph." 505 ) 506 return "" 507 else: 508 builder = _export_llama(args) 509 assert ( 510 filename := builder.get_saved_pte_filename() 511 ) is not None, "Fail to get file name from builder" 512 return filename 513 514 515def _prepare_for_llama_export(args) -> LLMEdgeManager: 516 """ 517 Helper function for export_llama. Loads the model from checkpoint and params, 518 and sets up a LLMEdgeManager with initial transforms and dtype conversion. 519 520 Returns a LLMEdgeManager prior to calling export_to_edge with quantizers 521 """ 522 # load model from checkpoint and params.json 523 checkpoint_path = canonical_path(args.checkpoint) if args.checkpoint else None 524 checkpoint_dir = ( 525 canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None 526 ) 527 params_path = canonical_path(args.params) 528 output_dir_path = canonical_path(args.output_dir, dir=True) 529 weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA 530 531 # dtype override 532 if args.dtype_override is not None: 533 dtype_override = DType[args.dtype_override] 534 elif args.quantization_mode in ["8da4w", "8da4w-gptq"]: 535 dtype_override = DType["fp16"] 536 else: 537 dtype_override = None 538 539 return ( 540 _load_llama_model( 541 args.model, 542 checkpoint=checkpoint_path, 543 checkpoint_dir=checkpoint_dir, 544 params_path=params_path, 545 use_kv_cache=args.use_kv_cache, 546 use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache, 547 generate_full_logits=args.generate_full_logits, 548 weight_type=weight_type, 549 enable_dynamic_shape=args.enable_dynamic_shape, 550 calibration_tasks=args.calibration_tasks, 551 calibration_limit=args.calibration_limit, 552 calibration_seq_length=args.calibration_seq_length, 553 calibration_data=args.calibration_data, 554 tokenizer_path=args.tokenizer_path, 555 verbose=args.verbose, 556 max_seq_len=args.max_seq_length, 557 input_prune_map_path=args.input_prune_map, 558 output_prune_map_path=args.output_prune_map, 559 metadata_str=args.metadata, 560 dtype_override=dtype_override, 561 args=args, 562 ) 563 .set_output_dir(output_dir_path) 564 .source_transform(_get_source_transforms(args.model, dtype_override, args)) 565 ) 566 567 568def get_quantizer_and_quant_params(args): 569 pt2e_quant_params = get_pt2e_quantization_params( 570 args.pt2e_quantize, args.quantization_mode 571 ) 572 quantizers = get_pt2e_quantizers(pt2e_quant_params, args.so_library) 573 quant_dtype = None 574 if args.qnn and args.pt2e_quantize: 575 assert len(quantizers) == 0, "Should not enable both xnnpack and qnn" 576 qnn_quantizer, quant_dtype = get_qnn_quantizer( 577 args.pt2e_quantize, args.quantization_mode 578 ) 579 quantizers.append(qnn_quantizer) 580 if args.coreml and args.pt2e_quantize: 581 assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml" 582 coreml_quantizer = get_coreml_quantizer(args.pt2e_quantize) 583 quantizers.append(coreml_quantizer) 584 if args.vulkan and args.pt2e_quantize: 585 assert ( 586 len(quantizers) == 0 587 ), "Should not enable both vulkan and other quantizers" 588 vulkan_quantizer = get_vulkan_quantizer(args.pt2e_quantize) 589 quantizers.append(vulkan_quantizer) 590 logging.info(f"Applying quantizers: {quantizers}") 591 return pt2e_quant_params, quantizers, quant_dtype 592 593 594def _qmode_type(value): 595 choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"] 596 patterns = [r"torchao:8da(\d+)w"] 597 598 if value in choices: 599 return value 600 601 for pattern in patterns: 602 matches = re.findall(pattern, value) 603 if len(matches) == 1: 604 return value 605 606 raise argparse.ArgumentTypeError( 607 f"Got qmode {value}, but expected one of {choices}, or one of the regex patterns {patterns}." 608 ) 609 610 611def _validate_args(args): 612 """ 613 TODO: Combine all the backends under --backend args 614 """ 615 if args.enable_dynamic_shape and (args.coreml or args.mps or args.qnn): 616 raise ValueError( 617 "Dynamic shape is not supported with coreml, MPS or qnn backends." 618 " Please use --disable_dynamic_shape." 619 ) 620 621 if args.num_sharding > 0 and not args.qnn: 622 raise ValueError("Model shard is only supported with qnn backend now.") 623 624 if ( 625 args.quantization_mode is not None 626 and args.quantization_mode.startswith("torchao:") 627 ) or ( 628 args.embedding_quantize is not None 629 and args.embedding_quantize.startswith("torchao:") 630 ): 631 if args.enable_dynamic_shape: 632 raise ValueError( 633 "Dynamic shape is not currently supported with torchao ops. Please use --disable_dynamic_shape." 634 "If you need this feature, please file an issue." 635 ) 636 637 638def _export_llama(args) -> LLMEdgeManager: # noqa: C901 639 _validate_args(args) 640 pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args) 641 642 # export_to_edge 643 builder_exported = _prepare_for_llama_export(args).export() 644 645 if args.export_only: 646 exit() 647 648 builder_exported_to_edge = builder_exported.pt2e_quantize( 649 quantizers 650 ).export_to_edge() 651 652 modelname = builder_exported_to_edge.modelname 653 654 # to_backend 655 partitioners = [] 656 657 # Order matters here, dynamic quantization should be applied first when both xnnpack and xnnpack_extended_ops are enabled 658 if ( 659 pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None 660 ) or (args.xnnpack): 661 partitioners.append( 662 get_xnnpack_partitioner(dynamic_quant_only_partitioner=True) 663 ) 664 665 # force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False 666 args.xnnpack = True 667 modelname = f"xnnpack_dq_{modelname}" 668 669 if args.xnnpack_extended_ops: 670 assert args.xnnpack, "xnnpack_extended_ops requires xnnpack to be enabled" 671 partitioners.append( 672 get_xnnpack_partitioner(dynamic_quant_only_partitioner=False) 673 ) 674 modelname = f"xnnpack_{modelname}" 675 676 if args.vulkan: 677 partitioners.append( 678 get_vulkan_partitioner( 679 args.dtype_override, 680 args.enable_dynamic_shape, 681 ) 682 ) 683 # Apply XNNPACK after Vulkan so that undelegated ops can be accelerated by XNNPACK 684 partitioners.append( 685 get_xnnpack_partitioner(dynamic_quant_only_partitioner=False) 686 ) 687 modelname = f"vulkan_{modelname}" 688 689 if args.mps: 690 partitioners.append(get_mps_partitioner(args.use_kv_cache)) 691 modelname = f"mps_{modelname}" 692 693 if args.coreml: 694 coreml_partitioner = get_coreml_partitioner( 695 args.coreml_ios, 696 args.embedding_quantize, 697 args.pt2e_quantize, 698 args.coreml_quantize, 699 ) 700 partitioners.append(coreml_partitioner) 701 modelname = f"coreml_{modelname}" 702 703 if args.qnn: 704 from executorch.extension.llm.custom_ops import model_sharding 705 706 partitioners.append( 707 get_qnn_partitioner( 708 args.use_kv_cache, args.pt2e_quantize, args.num_sharding, args.soc_model 709 ) 710 ) 711 # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils` 712 from executorch.backends.qualcomm.utils.utils import _transform, tag_quant_io 713 714 # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program` 715 _transform(builder_exported_to_edge.edge_manager.exported_program()) 716 717 if args.num_sharding > 0: 718 model_sharding.split_graph( 719 builder_exported_to_edge.edge_manager.exported_program(), 720 # pyre-fixme[16]: `Optional` has no attribute `__getitem__`. 721 builder_exported_to_edge.metadata["get_n_layers"], 722 shares=args.num_sharding, 723 ) 724 725 from functools import partial 726 727 # pyre-ignore 728 from executorch.backends.qualcomm.quantizer.custom_annotation import ( 729 get_custom_quant_ios_dtype, 730 ) 731 732 atten = builder_exported_to_edge.model.layers[0].attention 733 if args.use_qnn_sha: 734 cache_shape = torch.Size( 735 (atten.max_batch_size, atten.max_seq_len, atten.head_dim) 736 ) 737 else: 738 cache_shape = torch.Size( 739 ( 740 atten.max_batch_size, 741 atten.max_seq_len, 742 atten.n_kv_heads, 743 atten.head_dim, 744 ) 745 ) 746 # pyre-ignore 747 tag_quant_io( 748 builder_exported_to_edge.edge_manager.exported_program().graph_module, 749 partial(get_custom_quant_ios_dtype, cache_shape), # pyre-ignore 750 ) 751 752 logging.info("Lowering model using following partitioner(s): ") 753 for partitioner in partitioners: 754 logging.info(f"--> {partitioner.__class__.__name__}") 755 756 if args.generate_etrecord: 757 if not builder_exported_to_edge.edge_manager: 758 raise ValueError("Unable to generate etrecord due to missing edge manager.") 759 760 logging.info("Generating etrecord") 761 # Copy the edge manager which will be serialized into etrecord. This is memory-wise expensive. 762 edge_manager_copy = copy.deepcopy(builder_exported_to_edge.edge_manager) 763 builder = builder_exported_to_edge.to_backend(partitioners) 764 if args.num_sharding > 0 and args.qnn: 765 from executorch.backends.qualcomm.utils.utils import canonicalize_program 766 767 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`. 768 canonicalize_program(builder.edge_manager.exported_program()) 769 770 builder = builder.to_executorch() 771 772 # Generate ETRecord 773 if edge_manager_copy: 774 generate_etrecord( 775 et_record="etrecord.bin", 776 edge_dialect_program=edge_manager_copy, 777 executorch_program=builder.export_program, 778 ) 779 logging.info("Generated etrecord.bin") 780 else: 781 builder = builder_exported_to_edge.to_backend(partitioners) 782 if args.num_sharding > 0 and args.qnn: 783 from executorch.backends.qualcomm.utils.utils import canonicalize_program 784 785 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`. 786 canonicalize_program(builder.edge_manager.exported_program()) 787 788 builder = builder.to_executorch() 789 790 if args.profile_memory: 791 generate_memory_trace(builder.export_program, "memory_profile.json") 792 793 if builder.dtype == DType.fp16: 794 modelname = f"{modelname}_h" 795 796 if args.output_name: 797 modelname = args.output_name 798 if modelname.endswith(".pte"): 799 output_file = modelname 800 modelname = modelname[:-4] 801 print(f"modelname: {modelname}") 802 print(f"output_file: {output_file}") 803 else: 804 output_file = f"{builder.output_dir}/{modelname}.pte" 805 print(f"modelname: {modelname}") 806 print(f"output_file: {output_file}") 807 else: 808 output_file = f"{builder.output_dir}/{modelname}.pte" 809 810 builder.save_to_pte(output_file) 811 812 return builder 813 814 815def _load_llama_model_metadata( 816 weight_type: WeightType, 817 use_kv_cache: bool, 818 use_sdpa_with_kv_cache: bool, 819 enable_dynamic_shape: bool, 820 max_seq_len: int, 821 n_layers: int, 822 vocab_size: int, 823 metadata_str: Optional[str] = None, 824): 825 is_fairseq2 = weight_type == WeightType.FAIRSEQ2 826 metadata = { 827 "get_bos_id": 3 if is_fairseq2 else 1, 828 "get_eos_ids": [3] if is_fairseq2 else [2], 829 "get_max_seq_len": max_seq_len, 830 "get_n_layers": n_layers, 831 "get_vocab_size": vocab_size, 832 "use_kv_cache": use_kv_cache, 833 "use_sdpa_with_kv_cache": use_sdpa_with_kv_cache, 834 "enable_dynamic_shape": enable_dynamic_shape, 835 } 836 if metadata_str: 837 try: 838 extra = json.loads(metadata_str) 839 for k, v in extra.items(): 840 metadata[k] = v 841 except JSONDecodeError: 842 logging.error("Invalid metadata, should be a valid JSON string") 843 return metadata 844 845 846def _load_llama_model( 847 modelname: str = "llama3", 848 *, 849 checkpoint: Optional[str] = None, 850 checkpoint_dir: Optional[str] = None, 851 params_path: str, 852 use_kv_cache: bool = False, 853 use_sdpa_with_kv_cache: bool = False, 854 generate_full_logits: bool = False, 855 weight_type: WeightType = WeightType.LLAMA, 856 enable_dynamic_shape: bool = False, 857 calibration_tasks: Optional[List[str]] = None, 858 calibration_limit: Optional[int] = None, 859 calibration_seq_length: Optional[int] = None, 860 calibration_data: Optional[str] = None, 861 tokenizer_path: Optional[str] = None, 862 verbose: bool = False, 863 max_seq_len: int = 128, 864 input_prune_map_path: Optional[str] = None, 865 output_prune_map_path: Optional[str] = None, 866 metadata_str: Optional[str] = None, 867 dtype_override: Optional[DType] = None, 868 args, 869) -> "LLMEdgeManager": 870 """ 871 A helper util that builds a Llama2 model. It returns a LLMEdgeManager that 872 can help further lower the model to ExecuTorch. 873 Returns: 874 An instance of LLMEdgeManager which contains the eager mode model. 875 """ 876 877 assert ( 878 checkpoint or checkpoint_dir 879 ) and params_path, "Both checkpoint/checkpoint_dir and params can't be empty" 880 logging.info( 881 f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}" 882 ) 883 884 if modelname in EXECUTORCH_DEFINED_MODELS: 885 module_name = "llama" 886 model_class_name = "Llama2Model" # TODO: Change to "LlamaModel" in examples/models/llama/model.py. 887 elif modelname in TORCHTUNE_DEFINED_MODELS: 888 if modelname == "llama3_2_vision": 889 module_name = "llama3_2_vision" 890 model_class_name = "Llama3_2Decoder" 891 else: 892 raise ValueError(f"{modelname} is not a valid Llama model.") 893 else: 894 raise ValueError(f"{modelname} is not a valid Llama model.") 895 896 model, example_inputs, example_kwarg_inputs, dynamic_shapes = ( 897 EagerModelFactory.create_model( 898 module_name, 899 model_class_name, 900 checkpoint=checkpoint, 901 checkpoint_dir=checkpoint_dir, 902 params=params_path, 903 use_kv_cache=use_kv_cache, 904 use_sdpa_with_kv_cache=use_sdpa_with_kv_cache, 905 generate_full_logits=generate_full_logits, 906 fairseq2=weight_type == WeightType.FAIRSEQ2, 907 max_seq_len=max_seq_len, 908 enable_dynamic_shape=enable_dynamic_shape, 909 input_prune_map_path=input_prune_map_path, 910 output_prune_map_path=output_prune_map_path, 911 args=args, 912 ) 913 ) 914 if dtype_override: 915 assert isinstance( 916 dtype_override, DType 917 ), "Override dtype needs to be of type <DType>" 918 torch_dtype = dtype_override.to_torch_dtype() 919 logging.info(f"model.to {torch_dtype}") 920 model = model.to(dtype=torch_dtype) 921 dtype = dtype_override 922 else: 923 state_dict = model.state_dict() 924 dtype = state_dict[next(iter(state_dict))].dtype 925 assert dtype in [ 926 torch.bfloat16, 927 torch.float16, 928 torch.float32, 929 ], f"Only support bfloat16, fp16 or fp32 got {dtype}" 930 logging.info(f"Loaded model with dtype={dtype}") 931 932 if dtype == torch.bfloat16: 933 dtype = DType.bf16 934 elif dtype == torch.float16: 935 dtype = DType.fp16 936 elif dtype == torch.float32: 937 dtype = DType.fp32 938 else: 939 raise ValueError(f"Unsupported dtype {dtype}") 940 941 return LLMEdgeManager( 942 model=model, 943 modelname=modelname, 944 max_seq_len=model.max_seq_len, 945 dtype=dtype, 946 use_kv_cache=use_kv_cache, 947 generate_full_logits=generate_full_logits, 948 example_inputs=example_inputs, 949 example_kwarg_inputs=example_kwarg_inputs, 950 dynamic_shapes=dynamic_shapes, 951 enable_dynamic_shape=enable_dynamic_shape, 952 calibration_tasks=calibration_tasks, 953 calibration_limit=calibration_limit, 954 calibration_seq_length=calibration_seq_length, 955 calibration_data=calibration_data, 956 tokenizer_path=tokenizer_path, 957 verbose=verbose, 958 metadata=_load_llama_model_metadata( 959 weight_type, 960 use_kv_cache, 961 use_sdpa_with_kv_cache, 962 enable_dynamic_shape, 963 # pyre-fixme[6]: For 5th argument expected `ModelArgs` but got 964 # `Union[Tensor, Module]`. 965 model.max_seq_len, 966 # pyre-fixme[6]: For 6th argument expected `int` but got `Union[Tensor, 967 # Module]`. 968 model.n_layers, 969 # pyre-fixme[6]: For 7th argument expected `int` but got `Union[Tensor, 970 # Module]`. 971 model.vocab_size, 972 metadata_str, 973 ), 974 args=args, 975 ) 976 977 978def _get_source_transforms( # noqa 979 modelname: str, dtype_override: Optional[DType], args 980) -> List[Callable[[torch.nn.Module], torch.nn.Module]]: 981 transforms = [] 982 983 if args.use_spin_quant: 984 if args.use_spin_quant == "cuda": 985 from .source_transformation.spin_quant import ( 986 inject_fast_hadamard_transform_cuda_for_spin_quant, 987 ) 988 989 transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant) 990 elif args.use_spin_quant == "native": 991 from .source_transformation.spin_quant import ( 992 inject_fast_hadamard_transform_native_for_spin_quant, 993 ) 994 995 transforms.append(inject_fast_hadamard_transform_native_for_spin_quant) 996 997 if args.quantization_mode: 998 """ 999 When this option is selected, it finds all linear layers and transforms 1000 into quantized linear equivalent module. 1001 1002 There are cases where the checkpoint is already quantized, for example 1003 on use_spin_quant is enabled. In that case, it will do the appropriate 1004 transformations based on the given checkpoint first. In those cases, 1005 if quantization_mode is enabled, it will quantize any remaining linear 1006 ops that is not quantized. 1007 1008 There are cases where this may be a no-op, namely, if all linears are 1009 quantized in the checkpoint. 1010 """ 1011 modelname = f"{modelname}_q" 1012 transforms.append( 1013 get_quant_weight_transform(args, dtype_override, verbose_export()) 1014 ) 1015 1016 if args.embedding_quantize: 1017 """ 1018 When this option is selected, it finds all embedding layers and transforms 1019 into quantized embedding equivalent module. 1020 1021 There are cases where the checkpoint is already quantized, for example 1022 on use_spin_quant is enabled. In that case, it will do the appropriate 1023 transformations based on the given checkpoint first. In those cases, 1024 this wil be a no-op. 1025 """ 1026 modelname = f"{modelname}_e" 1027 transforms.append(get_quant_embedding_transform(args)) 1028 1029 if args.expand_rope_table: 1030 transforms.append(materialze_broadcast_of_rope_freq_cis) 1031 1032 if args.use_sdpa_with_kv_cache: 1033 transforms.append(replace_sdpa_with_custom_op) 1034 1035 if args.quantize_kv_cache: 1036 assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True" 1037 transforms.append(replace_kv_cache_with_quantized_kv_cache) 1038 1039 if args.use_kv_cache: 1040 if args.qnn: 1041 from executorch.backends.qualcomm.utils.utils import ( 1042 convert_linear_to_conv2d, 1043 ) 1044 1045 if args.use_qnn_sha: 1046 if args.optimized_rotation_path: 1047 transforms.append(fuse_layer_norms) 1048 transforms.append( 1049 get_model_with_r1_r2(args.optimized_rotation_path) 1050 ) 1051 transforms.append(replace_attention_to_attention_sha) 1052 transforms.append(replace_causal_mask) 1053 transforms.append(replace_rms_norm_with_native_rms_norm) 1054 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`. 1055 transforms.append(convert_linear_to_conv2d) 1056 else: 1057 transforms.append(replace_kv_cache_with_simple_kv_cache) 1058 transforms.append(replace_sdpa_with_flex_sdpa) 1059 transforms.append(replace_causal_mask) 1060 transforms.append(replace_rms_norm_with_native_rms_norm) 1061 if args.optimized_rotation_path: 1062 transforms.append(fuse_layer_norms) 1063 transforms.append( 1064 get_model_with_r1_r2(args.optimized_rotation_path) 1065 ) 1066 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`. 1067 transforms.append(convert_linear_to_conv2d) 1068 1069 elif args.mps: 1070 # Currently mps doesn't support sdpa op, use the simpler decomposition 1071 # to get free perf gain. 1072 transforms.append(replace_sdpa_with_simple_sdpa) 1073 transforms.append(replace_causal_mask) 1074 1075 elif args.coreml: 1076 # iOS 18 introduced fused sdpa op 1077 if args.coreml_ios >= 18: 1078 transforms.append(replace_sdpa_with_coreml_sdpa) 1079 else: 1080 transforms.append(replace_sdpa_with_simple_sdpa) 1081 transforms.append(replace_kv_cache_with_coreml_kv_cache) 1082 1083 if args.vulkan: 1084 transforms.append(replace_with_vulkan_rotary_emb) 1085 1086 return transforms 1087