# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Example script for exporting Llama2 to flatbuffer

import argparse
import copy
import json
import logging
import re
import shlex
from enum import Enum
from json import JSONDecodeError
from pathlib import Path
from typing import Callable, List, Optional, Union

import pkg_resources
import torch

from executorch.devtools.etrecord import generate_etrecord

from executorch.extension.llm.export.builder import DType, LLMEdgeManager

from executorch.extension.llm.export.partitioner_lib import (
    get_coreml_partitioner,
    get_mps_partitioner,
    get_qnn_partitioner,
    get_vulkan_partitioner,
    get_xnnpack_partitioner,
)

from executorch.extension.llm.export.quantizer_lib import (
    get_coreml_quantizer,
    get_pt2e_quantization_params,
    get_pt2e_quantizers,
    get_qnn_quantizer,
    get_vulkan_quantizer,
)
from executorch.util.activation_memory_profiler import generate_memory_trace

from ..model_factory import EagerModelFactory
from .source_transformation.apply_spin_quant_r1_r2 import (
    fuse_layer_norms,
    get_model_with_r1_r2,
)

from .source_transformation.attention import replace_attention_to_attention_sha
from .source_transformation.quantize import (
    get_quant_embedding_transform,
    get_quant_weight_transform,
)
from .source_transformation.quantized_kv_cache import (
    replace_kv_cache_with_quantized_kv_cache,
)
from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm

from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis
from .source_transformation.sdpa import (
    replace_causal_mask,
    replace_kv_cache_with_coreml_kv_cache,
    replace_kv_cache_with_simple_kv_cache,
    replace_sdpa_with_coreml_sdpa,
    replace_sdpa_with_custom_op,
    replace_sdpa_with_flex_sdpa,
    replace_sdpa_with_simple_sdpa,
)
from .source_transformation.vulkan_rope import replace_with_vulkan_rotary_emb

IS_FBCODE = True  #  os.environ.get("FBCODE_PLATFORM", False)
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)

pkg_name = __name__
verbosity_setting = None


EXECUTORCH_DEFINED_MODELS = ["stories110m", "llama2", "llama3", "llama3_1", "llama3_2"]
TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]


class WeightType(Enum):
    LLAMA = "LLAMA"
    FAIRSEQ2 = "FAIRSEQ2"


def set_pkg_name(name: str) -> None:
    global pkg_name
    pkg_name = name


def get_resource_path(resource_name) -> str:
    return pkg_resources.resource_filename(pkg_name, resource_name)


def set_verbosity(val):
    global verbosity_setting
    verbosity_setting = val


def verbose_export():
    return verbosity_setting


def build_model(
    modelname: str = "llama3",
    extra_opts: str = "",
    *,
    par_local_output: bool = False,
    resource_pkg_name: str = __name__,
) -> str:
    if False:  # par_local_output:
        output_dir_path = "par:."
    else:
        output_dir_path = "."

    argString = f"--model {modelname} --checkpoint par:model_ckpt.pt --params par:model_params.json {extra_opts} --output-dir {output_dir_path}"
    parser = build_args_parser()
    args = parser.parse_args(shlex.split(argString))
    # pkg_name = resource_pkg_name
    return export_llama(args)


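# A minimal, hypothetical usage sketch for build_model() above. It assumes the
# calling package bundles model_ckpt.pt and model_params.json as resources (the
# "par:" paths are resolved via canonical_path()/pkg_resources) and that
# extra_opts only contains flags defined in build_args_parser() below.
def _example_build_model_usage() -> str:
    """Hypothetical example: export a llama3 model with KV cache to XNNPACK."""
    return build_model(
        modelname="llama3",
        extra_opts="-kv -X -d fp32",
    )

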
def build_args_parser() -> argparse.ArgumentParser:
    ckpt_dir = f"{Path(__file__).absolute().parent.as_posix()}"
    parser = argparse.ArgumentParser()
    parser.add_argument("-o", "--output-dir", default=".", help="output directory")
    # parser.add_argument(
    #     "-q", "--quantized_ckpt", default=None, help="quantized checkpoint file"
    # )
    parser.add_argument(
        "--model",
        default="llama3",
        choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
        help="The Llama model to export. stories110M, llama2, llama3, llama3_1, and llama3_2 use the same underlying LlamaTransformer architecture defined in ExecuTorch. All other models use TorchTune model definitions.",
    )
    parser.add_argument(
        "-E",
        "--embedding-quantize",
        default=None,
        type=str,
        help="type of embedding quantization, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
    )
    parser.add_argument(
        "--pt2e_quantize",
        default=None,
        choices=[
            "xnnpack_dynamic",
            "xnnpack_dynamic_qc4",
            "qnn_8a8w",
            "qnn_16a16w",
            "qnn_16a4w",
            "coreml_c4w",
            "coreml_8a_c8w",
            "coreml_8a_c4w",
            "coreml_baseline_8a_c8w",
            "coreml_baseline_8a_c4w",
            "vulkan_8w",
        ],
        help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
    )

    parser.add_argument(
        "-qmode",
        "--quantization_mode",
        type=_qmode_type,
        default=None,
        help="type of quantization",
    )

    parser.add_argument(
        "-c",
        "--checkpoint",
        default=f"{ckpt_dir}/params/demo_rand_params.pth",
        help="checkpoint path",
    )

    parser.add_argument(
        "--checkpoint_dir",
        default=None,
        help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
    )

    parser.add_argument(
        "--use_qnn_sha",
        action="store_true",
        help="Change multi head attention to multiple single head attention for qnn backend (Qualcomm)",
    )

    parser.add_argument(
        "--calibration_tasks",
        nargs="+",
        type=str,
        default=None,
        help="Tasks for GPTQ calibration from lm_eval",
    )
    parser.add_argument(
        "--calibration_limit",
        type=int,
        default=None,
        help="number of samples used for calibration from lm_eval",
    )
    parser.add_argument(
        "--calibration_seq_length",
        type=int,
        default=None,
        help="Sequence length for GPTQ calibration from lm_eval",
    )
    parser.add_argument(
        "--calibration_data",
        type=str,
        default="Once upon a time",
        help="Calibration prompts from users",
    )
    parser.add_argument(
        "-t",
        "--tokenizer_path",
        default=None,
        help="tokenizer path (Note: .model not .bin)",
    )
    parser.add_argument(
        "-kv",
        "--use_kv_cache",
        default=False,
        action="store_true",
        help="Whether or not to export a model using kv cache",
    )
    parser.add_argument(
        "--quantize_kv_cache",
        default=False,
        action="store_true",
        help="Whether or not to export a model using int8 per token quantized kv cache",
    )
    parser.add_argument(
        "--num_sharding",
        type=int,
        default=0,
        help="Specify the number of splits by inserting the fallback custom op. The graph will be split evenly by layers.",
    )
    parser.add_argument(
        "--use_sdpa_with_kv_cache",
        default=False,
        action="store_true",
        help="Whether to use sdpa_with_kv_cache update op when using kv cache",
    )
    parser.add_argument(
        "--disable_dynamic_shape",
        dest="enable_dynamic_shape",
        default=True,  # Enable this by default
        action="store_false",
        help="Enable dynamic shape along seq dim. Used for faster prefill",
    )
    parser.add_argument(
        "-p",
        "--params",
        default=f"{ckpt_dir}/params/demo_config.json",
        help="Path to the params JSON file (model config).",
    )
    parser.add_argument(
        "--optimized_rotation_path",
        default=None,
        required=False,
        help="[QNN backend] Optimized rotation checkpoint path. Just apply R1/R2 here. "
        "You can download the optimized rotation matrices from https://github.com/facebookresearch/SpinQuant/tree/main",
    )
    parser.add_argument(
        "-m",
        "--metadata",
        default=None,
        help='metadata string in json format. Example {"key": 1, "key2": "value2"}',
    )
    parser.add_argument(
        "-s",
        "--so_library",
        default=None,
        required=False,
        help="shared library for quantized operators",
    )
    parser.add_argument(
        "--profile_memory",
        required=False,
        action="store_true",
        help="Generate chrome trace of activation memory for intermediate tensors.",
    )
    parser.add_argument(
        "-prof",
        "--profile_path",
        default=None,
        help="Use cProfile to profile model export. Results saved to profile_path as an HTML file.",
    )
    parser.add_argument(
        "-G",
        "--group_size",
        type=int,
        default=None,
        help="group_size for weight quantization",
    )

    parser.add_argument(
        "-d",
        "--dtype-override",
        default="fp32",
        type=str,
        choices=["fp32", "fp16", "bf16"],
        help="Override the dtype of the model (default is fp32). "
        "Options: fp32, fp16, bf16. Please be aware that only some backends support fp16 and bf16.",
    )

    parser.add_argument(
        "-n",
        "--output_name",
        default=None,
        help="Override the output filename of the saved pte model file.",
    )

    parser.add_argument(
        "--max_seq_length",
        type=int,
        default=128,
        help="maximum sequence length to evaluate",
    )

    parser.add_argument("-2", "--fairseq2", action="store_true")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument(
        "-X",
        "--xnnpack",
        action="store_true",
        help="Delegate DQLinear ops to the xnnpack backend",
    )
    parser.add_argument(
        "--xnnpack-extended-ops",
        action="store_true",
        help="Delegate more operators beyond DQLinear to the xnnpack backend. Requires -X or --xnnpack to be set.",
    )
    parser.add_argument("-V", "--vulkan", action="store_true")
    parser.add_argument("--mps", action="store_true")
    parser.add_argument("--coreml", action="store_true")
    parser.add_argument(
        "--coreml-enable-state",
        action="store_true",
        help="This option is only for coreml, and is only supported for macOS 15+/iOS 18+",
    )
    parser.add_argument(
        "--coreml-preserve-sdpa",
        action="store_true",
        help="This option is only for coreml: Preserve sdpa in torch edge program to use coreml iOS18.sdpa op",
    )
    parser.add_argument(
        "--coreml-quantize",
        default=None,
        choices=["b4w"],
        help="This option is only for coreml: Use coreml quantization, e.g. b4w (for blockwise 4 bit weight)",
    )
    parser.add_argument(
        "--coreml-ios",
        type=int,
        default=15,
        choices=(15, 16, 17, 18),
        help="This option is only for coreml: The minimum iOS version to deploy",
    )
    parser.add_argument(
        "--qnn",
        action="store_true",
        help="Delegate llama2 to the qnn backend (Qualcomm); please use it with --use_kv_cache",
    )

    parser.add_argument(
        "--expand_rope_table",
        default=False,
        action="store_true",
        help="[Temp workaround] Expand sin/cos table in head dim to take vectorized path in optimized kernels.",
    )

    parser.add_argument(
        "--generate_etrecord",
        action="store_true",
        required=False,
        default=False,
        help="Generate the ETRecord debug artifact.",
    )

    parser.add_argument(
        "--generate_full_logits",
        action="store_true",
        required=False,
        default=False,
        help="Generate logits for all inputs.",
    )

    parser.add_argument(
        "--soc_model",
        help="[QNN backend] SoC model of current device. e.g. 'SM8650' for Snapdragon 8 Gen 3.",
        type=str,
        required=False,
        default="SM8650",
    )

    parser.add_argument(
        "-sq",
        "--use_spin_quant",
        type=str,
        default=None,
        choices=["cuda", "native"],
        help="Use SpinQuant for better quantization performance. Only support cuda and native.",
    )

    parser.add_argument(
        "-qat",
        "--use_qat",
        default=False,
        action="store_true",
        help="Whether the checkpoint is pre-quantized with QAT or not.",
    )

    parser.add_argument(
        "-lora",
        "--use_lora",
        type=int,
        default=0,
        help="Whether the checkpoint contains LoRA adaptors or not. 0: no LoRA adaptors; "
        "otherwise, it means the rank of LoRA adaptors. Currently it only works if QAT is enabled.",
    )

    parser.add_argument(
        "--preq_mode",
        type=str,
        default=None,
        choices=["8da4w", "8da4w_output_8da8w"],
        help="Quantization mode used for pre-quantized checkpoint. Only support 8da4w and 8da4w_output_8da8w right now.",
    )

    parser.add_argument(
        "--preq_group_size",
        type=int,
        default=32,
        help="group_size for pre-quantized checkpoint weight quantization",
    )

    parser.add_argument(
        "--preq_embedding_quantize",
        default="8,0",
        type=str,
        help="type of embedding quantization for pre-quantized checkpoint, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
    )

    parser.add_argument(
        "--output_prune_map",
        default=None,
        help="path to the output pruning token mapping file (token_map.json)",
    )

    parser.add_argument(
        "--input_prune_map",
        default=None,
        help="path to the input pruning token mapping file (token_map.json)",
    )

    parser.add_argument(
        "--export_only",
        default=False,
        action="store_true",
        help="If true, stops right after torch.export() and saves the exported model.",
    )
    return parser

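# A few illustrative invocations of the flags defined above, assuming the parser
# is driven through the export_llama entry point in examples/models/llama (the
# exact wrapper module name is an assumption; adjust to your setup):
#
#   # fp32 export with KV cache, the SDPA custom op, and XNNPACK delegation:
#   python -m examples.models.llama.export_llama \
#       -c <checkpoint.pth> -p <params.json> -kv --use_sdpa_with_kv_cache -X
#
#   # 8-bit-activation / 4-bit-weight linear quantization (group size 128) plus
#   # 4-bit, group-size-32 embedding quantization:
#   python -m examples.models.llama.export_llama \
#       -c <checkpoint.pth> -p <params.json> -kv -X -qmode 8da4w -G 128 -E "4,32"
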

def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
    path = str(path)

    if verbose_export():
        print(f"creating canonical path for {path}")

    if not path.startswith("par:"):
        return path

    if not IS_FBCODE:
        print("not FBCODE")
        return path[4:]
    else:
        return_val = pkg_resources.resource_filename(pkg_name, path[4:])
        if verbose_export():
            print(f"canonical name is: {return_val}")
        return return_val

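# Illustrative behavior of canonical_path() above (hypothetical file names):
# paths without the "par:" prefix pass through unchanged; "par:" paths are
# stripped outside FBCODE or resolved against the current package otherwise.
#
#   canonical_path("model_params.json")      # -> "model_params.json"
#   canonical_path("par:model_params.json")  # -> pkg_resources.resource_filename(
#                                            #        pkg_name, "model_params.json")
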

def export_llama(args) -> str:
    if args.profile_path is not None:
        try:
            from executorch.util.python_profiler import CProfilerFlameGraph

            with CProfilerFlameGraph(args.profile_path):
                builder = _export_llama(args)
                assert (
                    filename := builder.get_saved_pte_filename()
                ) is not None, "Failed to get file name from builder"
                return filename
        except ImportError:
            print(
                "Please run `pip install snakeviz` to install required dependencies for cProfiler flamegraph."
            )
            return ""
    else:
        builder = _export_llama(args)
        assert (
            filename := builder.get_saved_pte_filename()
        ) is not None, "Failed to get file name from builder"
        return filename


def _prepare_for_llama_export(args) -> LLMEdgeManager:
    """
    Helper function for export_llama. Loads the model from checkpoint and params,
    and sets up an LLMEdgeManager with initial transforms and dtype conversion.

    Returns an LLMEdgeManager prior to calling export_to_edge with quantizers.
    """
    # load model from checkpoint and params.json
    checkpoint_path = canonical_path(args.checkpoint) if args.checkpoint else None
    checkpoint_dir = (
        canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
    )
    params_path = canonical_path(args.params)
    output_dir_path = canonical_path(args.output_dir, dir=True)
    weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA

    # dtype override
    if args.dtype_override is not None:
        dtype_override = DType[args.dtype_override]
    elif args.quantization_mode in ["8da4w", "8da4w-gptq"]:
        dtype_override = DType["fp16"]
    else:
        dtype_override = None

    return (
        _load_llama_model(
            args.model,
            checkpoint=checkpoint_path,
            checkpoint_dir=checkpoint_dir,
            params_path=params_path,
            use_kv_cache=args.use_kv_cache,
            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
            generate_full_logits=args.generate_full_logits,
            weight_type=weight_type,
            enable_dynamic_shape=args.enable_dynamic_shape,
            calibration_tasks=args.calibration_tasks,
            calibration_limit=args.calibration_limit,
            calibration_seq_length=args.calibration_seq_length,
            calibration_data=args.calibration_data,
            tokenizer_path=args.tokenizer_path,
            verbose=args.verbose,
            max_seq_len=args.max_seq_length,
            input_prune_map_path=args.input_prune_map,
            output_prune_map_path=args.output_prune_map,
            metadata_str=args.metadata,
            dtype_override=dtype_override,
            args=args,
        )
        .set_output_dir(output_dir_path)
        .source_transform(_get_source_transforms(args.model, dtype_override, args))
    )


def get_quantizer_and_quant_params(args):
    pt2e_quant_params = get_pt2e_quantization_params(
        args.pt2e_quantize, args.quantization_mode
    )
    quantizers = get_pt2e_quantizers(pt2e_quant_params, args.so_library)
    quant_dtype = None
    if args.qnn and args.pt2e_quantize:
        assert len(quantizers) == 0, "Should not enable both xnnpack and qnn"
        qnn_quantizer, quant_dtype = get_qnn_quantizer(
            args.pt2e_quantize, args.quantization_mode
        )
        quantizers.append(qnn_quantizer)
    if args.coreml and args.pt2e_quantize:
        assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
        coreml_quantizer = get_coreml_quantizer(args.pt2e_quantize)
        quantizers.append(coreml_quantizer)
    if args.vulkan and args.pt2e_quantize:
        assert (
            len(quantizers) == 0
        ), "Should not enable both vulkan and other quantizers"
        vulkan_quantizer = get_vulkan_quantizer(args.pt2e_quantize)
        quantizers.append(vulkan_quantizer)
    logging.info(f"Applying quantizers: {quantizers}")
    return pt2e_quant_params, quantizers, quant_dtype


def _qmode_type(value):
    choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
    patterns = [r"torchao:8da(\d+)w"]

    if value in choices:
        return value

    for pattern in patterns:
        matches = re.findall(pattern, value)
        if len(matches) == 1:
            return value

    raise argparse.ArgumentTypeError(
        f"Got qmode {value}, but expected one of {choices}, or one of the regex patterns {patterns}."
    )

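# Illustrative values accepted or rejected by _qmode_type() above:
#
#   _qmode_type("8da4w")           # plain choice, returned as-is
#   _qmode_type("torchao:8da4w")   # matches the torchao:8da(\d+)w pattern
#   _qmode_type("int4")            # raises argparse.ArgumentTypeError
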

def _validate_args(args):
    """
    TODO: Combine all the backends under --backend args
    """
    if args.enable_dynamic_shape and (args.coreml or args.mps or args.qnn):
        raise ValueError(
            "Dynamic shape is not supported with coreml, MPS or qnn backends."
            " Please use --disable_dynamic_shape."
        )

    if args.num_sharding > 0 and not args.qnn:
        raise ValueError("Model sharding is only supported with the qnn backend right now.")

    if (
        args.quantization_mode is not None
        and args.quantization_mode.startswith("torchao:")
    ) or (
        args.embedding_quantize is not None
        and args.embedding_quantize.startswith("torchao:")
    ):
        if args.enable_dynamic_shape:
            raise ValueError(
                "Dynamic shape is not currently supported with torchao ops. Please use --disable_dynamic_shape. "
                "If you need this feature, please file an issue."
            )


def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
    _validate_args(args)
    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)

    # export_to_edge
    builder_exported = _prepare_for_llama_export(args).export()

    if args.export_only:
        exit()

    builder_exported_to_edge = builder_exported.pt2e_quantize(
        quantizers
    ).export_to_edge()

    modelname = builder_exported_to_edge.modelname

    # to_backend
    partitioners = []

    # Order matters here, dynamic quantization should be applied first when both xnnpack and xnnpack_extended_ops are enabled
    if (
        pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None
    ) or (args.xnnpack):
        partitioners.append(
            get_xnnpack_partitioner(dynamic_quant_only_partitioner=True)
        )

        # force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
        args.xnnpack = True
        modelname = f"xnnpack_dq_{modelname}"

    if args.xnnpack_extended_ops:
        assert args.xnnpack, "xnnpack_extended_ops requires xnnpack to be enabled"
        partitioners.append(
            get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
        )
        modelname = f"xnnpack_{modelname}"

    if args.vulkan:
        partitioners.append(
            get_vulkan_partitioner(
                args.dtype_override,
                args.enable_dynamic_shape,
            )
        )
        # Apply XNNPACK after Vulkan so that undelegated ops can be accelerated by XNNPACK
        partitioners.append(
            get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
        )
        modelname = f"vulkan_{modelname}"

    if args.mps:
        partitioners.append(get_mps_partitioner(args.use_kv_cache))
        modelname = f"mps_{modelname}"

    if args.coreml:
        coreml_partitioner = get_coreml_partitioner(
            args.coreml_ios,
            args.embedding_quantize,
            args.pt2e_quantize,
            args.coreml_quantize,
        )
        partitioners.append(coreml_partitioner)
        modelname = f"coreml_{modelname}"

    if args.qnn:
        from executorch.extension.llm.custom_ops import model_sharding

        partitioners.append(
            get_qnn_partitioner(
                args.use_kv_cache, args.pt2e_quantize, args.num_sharding, args.soc_model
            )
        )
        # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
        from executorch.backends.qualcomm.utils.utils import _transform, tag_quant_io

        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program`
        _transform(builder_exported_to_edge.edge_manager.exported_program())

        if args.num_sharding > 0:
            model_sharding.split_graph(
                builder_exported_to_edge.edge_manager.exported_program(),
                # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
                builder_exported_to_edge.metadata["get_n_layers"],
                shares=args.num_sharding,
            )

        from functools import partial

        # pyre-ignore
        from executorch.backends.qualcomm.quantizer.custom_annotation import (
            get_custom_quant_ios_dtype,
        )

        atten = builder_exported_to_edge.model.layers[0].attention
        if args.use_qnn_sha:
            cache_shape = torch.Size(
                (atten.max_batch_size, atten.max_seq_len, atten.head_dim)
            )
        else:
            cache_shape = torch.Size(
                (
                    atten.max_batch_size,
                    atten.max_seq_len,
                    atten.n_kv_heads,
                    atten.head_dim,
                )
            )
        # pyre-ignore
        tag_quant_io(
            builder_exported_to_edge.edge_manager.exported_program().graph_module,
            partial(get_custom_quant_ios_dtype, cache_shape),  # pyre-ignore
        )

    logging.info("Lowering model using the following partitioner(s): ")
    for partitioner in partitioners:
        logging.info(f"--> {partitioner.__class__.__name__}")

    if args.generate_etrecord:
        if not builder_exported_to_edge.edge_manager:
            raise ValueError("Unable to generate etrecord due to missing edge manager.")

        logging.info("Generating etrecord")
        # Copy the edge manager which will be serialized into etrecord. This is memory-wise expensive.
        edge_manager_copy = copy.deepcopy(builder_exported_to_edge.edge_manager)
        builder = builder_exported_to_edge.to_backend(partitioners)
        if args.num_sharding > 0 and args.qnn:
            from executorch.backends.qualcomm.utils.utils import canonicalize_program

            # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
            canonicalize_program(builder.edge_manager.exported_program())

        builder = builder.to_executorch()

        # Generate ETRecord
        if edge_manager_copy:
            generate_etrecord(
                et_record="etrecord.bin",
                edge_dialect_program=edge_manager_copy,
                executorch_program=builder.export_program,
            )
            logging.info("Generated etrecord.bin")
    else:
        builder = builder_exported_to_edge.to_backend(partitioners)
        if args.num_sharding > 0 and args.qnn:
            from executorch.backends.qualcomm.utils.utils import canonicalize_program

            # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
            canonicalize_program(builder.edge_manager.exported_program())

        builder = builder.to_executorch()

    if args.profile_memory:
        generate_memory_trace(builder.export_program, "memory_profile.json")

    if builder.dtype == DType.fp16:
        modelname = f"{modelname}_h"

    if args.output_name:
        modelname = args.output_name
        if modelname.endswith(".pte"):
            output_file = modelname
            modelname = modelname[:-4]
            print(f"modelname: {modelname}")
            print(f"output_file: {output_file}")
        else:
            output_file = f"{builder.output_dir}/{modelname}.pte"
            print(f"modelname: {modelname}")
            print(f"output_file: {output_file}")
    else:
        output_file = f"{builder.output_dir}/{modelname}.pte"

    builder.save_to_pte(output_file)

    return builder

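# For reference, the happy path through _export_llama() above reduces to the
# following chain of LLMEdgeManager calls (a simplified sketch; the etrecord,
# sharding, and QNN-specific branches are omitted):
#
#   builder = (
#       _prepare_for_llama_export(args)
#       .export()
#       .pt2e_quantize(quantizers)
#       .export_to_edge()
#       .to_backend(partitioners)
#       .to_executorch()
#   )
#   builder.save_to_pte(output_file)
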

def _load_llama_model_metadata(
    weight_type: WeightType,
    use_kv_cache: bool,
    use_sdpa_with_kv_cache: bool,
    enable_dynamic_shape: bool,
    max_seq_len: int,
    n_layers: int,
    vocab_size: int,
    metadata_str: Optional[str] = None,
):
    is_fairseq2 = weight_type == WeightType.FAIRSEQ2
    metadata = {
        "get_bos_id": 3 if is_fairseq2 else 1,
        "get_eos_ids": [3] if is_fairseq2 else [2],
        "get_max_seq_len": max_seq_len,
        "get_n_layers": n_layers,
        "get_vocab_size": vocab_size,
        "use_kv_cache": use_kv_cache,
        "use_sdpa_with_kv_cache": use_sdpa_with_kv_cache,
        "enable_dynamic_shape": enable_dynamic_shape,
    }
    if metadata_str:
        try:
            extra = json.loads(metadata_str)
            for k, v in extra.items():
                metadata[k] = v
        except JSONDecodeError:
            logging.error("Invalid metadata, should be a valid JSON string")
    return metadata

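# Illustrative example of how metadata_str interacts with the defaults built
# above (the values are hypothetical): extra keys passed via --metadata override
# matching defaults and leave the rest untouched.
#
#   _load_llama_model_metadata(
#       WeightType.LLAMA, True, True, True, 128, 32, 32000,
#       metadata_str='{"get_bos_id": 128000}',
#   )["get_bos_id"]  # -> 128000 instead of the default 1
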

def _load_llama_model(
    modelname: str = "llama3",
    *,
    checkpoint: Optional[str] = None,
    checkpoint_dir: Optional[str] = None,
    params_path: str,
    use_kv_cache: bool = False,
    use_sdpa_with_kv_cache: bool = False,
    generate_full_logits: bool = False,
    weight_type: WeightType = WeightType.LLAMA,
    enable_dynamic_shape: bool = False,
    calibration_tasks: Optional[List[str]] = None,
    calibration_limit: Optional[int] = None,
    calibration_seq_length: Optional[int] = None,
    calibration_data: Optional[str] = None,
    tokenizer_path: Optional[str] = None,
    verbose: bool = False,
    max_seq_len: int = 128,
    input_prune_map_path: Optional[str] = None,
    output_prune_map_path: Optional[str] = None,
    metadata_str: Optional[str] = None,
    dtype_override: Optional[DType] = None,
    args,
) -> "LLMEdgeManager":
    """
    A helper util that builds a Llama2 model. It returns an LLMEdgeManager that
    can help further lower the model to ExecuTorch.
    Returns:
        An instance of LLMEdgeManager which contains the eager mode model.
    """

    assert (
        checkpoint or checkpoint_dir
    ) and params_path, "checkpoint (or checkpoint_dir) and params can't be empty"
    logging.info(
        f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
    )

    if modelname in EXECUTORCH_DEFINED_MODELS:
        module_name = "llama"
        model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
    elif modelname in TORCHTUNE_DEFINED_MODELS:
        if modelname == "llama3_2_vision":
            module_name = "llama3_2_vision"
            model_class_name = "Llama3_2Decoder"
        else:
            raise ValueError(f"{modelname} is not a valid Llama model.")
    else:
        raise ValueError(f"{modelname} is not a valid Llama model.")

    model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
        EagerModelFactory.create_model(
            module_name,
            model_class_name,
            checkpoint=checkpoint,
            checkpoint_dir=checkpoint_dir,
            params=params_path,
            use_kv_cache=use_kv_cache,
            use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
            generate_full_logits=generate_full_logits,
            fairseq2=weight_type == WeightType.FAIRSEQ2,
            max_seq_len=max_seq_len,
            enable_dynamic_shape=enable_dynamic_shape,
            input_prune_map_path=input_prune_map_path,
            output_prune_map_path=output_prune_map_path,
            args=args,
        )
    )
    if dtype_override:
        assert isinstance(
            dtype_override, DType
        ), "Override dtype needs to be of type <DType>"
        torch_dtype = dtype_override.to_torch_dtype()
        logging.info(f"model.to {torch_dtype}")
        model = model.to(dtype=torch_dtype)
        dtype = dtype_override
    else:
        state_dict = model.state_dict()
        dtype = state_dict[next(iter(state_dict))].dtype
        assert dtype in [
            torch.bfloat16,
            torch.float16,
            torch.float32,
        ], f"Only bfloat16, fp16, or fp32 are supported; got {dtype}"
        logging.info(f"Loaded model with dtype={dtype}")

        if dtype == torch.bfloat16:
            dtype = DType.bf16
        elif dtype == torch.float16:
            dtype = DType.fp16
        elif dtype == torch.float32:
            dtype = DType.fp32
        else:
            raise ValueError(f"Unsupported dtype {dtype}")

    return LLMEdgeManager(
        model=model,
        modelname=modelname,
        max_seq_len=model.max_seq_len,
        dtype=dtype,
        use_kv_cache=use_kv_cache,
        generate_full_logits=generate_full_logits,
        example_inputs=example_inputs,
        example_kwarg_inputs=example_kwarg_inputs,
        dynamic_shapes=dynamic_shapes,
        enable_dynamic_shape=enable_dynamic_shape,
        calibration_tasks=calibration_tasks,
        calibration_limit=calibration_limit,
        calibration_seq_length=calibration_seq_length,
        calibration_data=calibration_data,
        tokenizer_path=tokenizer_path,
        verbose=verbose,
        metadata=_load_llama_model_metadata(
            weight_type,
            use_kv_cache,
            use_sdpa_with_kv_cache,
            enable_dynamic_shape,
            # pyre-fixme[6]: For 5th argument expected `ModelArgs` but got
            #  `Union[Tensor, Module]`.
            model.max_seq_len,
            # pyre-fixme[6]: For 6th argument expected `int` but got `Union[Tensor,
            #  Module]`.
            model.n_layers,
            # pyre-fixme[6]: For 7th argument expected `int` but got `Union[Tensor,
            #  Module]`.
            model.vocab_size,
            metadata_str,
        ),
        args=args,
    )


def _get_source_transforms(  # noqa
    modelname: str, dtype_override: Optional[DType], args
) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
    transforms = []

    if args.use_spin_quant:
        if args.use_spin_quant == "cuda":
            from .source_transformation.spin_quant import (
                inject_fast_hadamard_transform_cuda_for_spin_quant,
            )

            transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
        elif args.use_spin_quant == "native":
            from .source_transformation.spin_quant import (
                inject_fast_hadamard_transform_native_for_spin_quant,
            )

            transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)

    if args.quantization_mode:
        """
        When this option is selected, it finds all linear layers and transforms
        them into equivalent quantized linear modules.

        There are cases where the checkpoint is already quantized, for example
        when use_spin_quant is enabled. In that case, it will do the appropriate
        transformations based on the given checkpoint first. In those cases,
        if quantization_mode is enabled, it will quantize any remaining linear
        ops that are not quantized.

        There are cases where this may be a no-op, namely, if all linears are
        quantized in the checkpoint.
        """
        modelname = f"{modelname}_q"
        transforms.append(
            get_quant_weight_transform(args, dtype_override, verbose_export())
        )

    if args.embedding_quantize:
        """
        When this option is selected, it finds all embedding layers and transforms
        them into equivalent quantized embedding modules.

        There are cases where the checkpoint is already quantized, for example
        when use_spin_quant is enabled. In that case, it will do the appropriate
        transformations based on the given checkpoint first. In those cases,
        this will be a no-op.
        """
        modelname = f"{modelname}_e"
        transforms.append(get_quant_embedding_transform(args))

    if args.expand_rope_table:
        transforms.append(materialze_broadcast_of_rope_freq_cis)

    if args.use_sdpa_with_kv_cache:
        transforms.append(replace_sdpa_with_custom_op)

    if args.quantize_kv_cache:
        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
        transforms.append(replace_kv_cache_with_quantized_kv_cache)

    if args.use_kv_cache:
        if args.qnn:
            from executorch.backends.qualcomm.utils.utils import (
                convert_linear_to_conv2d,
            )

            if args.use_qnn_sha:
                if args.optimized_rotation_path:
                    transforms.append(fuse_layer_norms)
                    transforms.append(
                        get_model_with_r1_r2(args.optimized_rotation_path)
                    )
                transforms.append(replace_attention_to_attention_sha)
                transforms.append(replace_causal_mask)
                transforms.append(replace_rms_norm_with_native_rms_norm)
                # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
                transforms.append(convert_linear_to_conv2d)
            else:
                transforms.append(replace_kv_cache_with_simple_kv_cache)
                transforms.append(replace_sdpa_with_flex_sdpa)
                transforms.append(replace_causal_mask)
                transforms.append(replace_rms_norm_with_native_rms_norm)
                if args.optimized_rotation_path:
                    transforms.append(fuse_layer_norms)
                    transforms.append(
                        get_model_with_r1_r2(args.optimized_rotation_path)
                    )
                # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
                transforms.append(convert_linear_to_conv2d)

        elif args.mps:
            # Currently mps doesn't support sdpa op, use the simpler decomposition
            # to get free perf gain.
            transforms.append(replace_sdpa_with_simple_sdpa)
            transforms.append(replace_causal_mask)

        elif args.coreml:
            # iOS 18 introduced fused sdpa op
            if args.coreml_ios >= 18:
                transforms.append(replace_sdpa_with_coreml_sdpa)
            else:
                transforms.append(replace_sdpa_with_simple_sdpa)
            transforms.append(replace_kv_cache_with_coreml_kv_cache)

    if args.vulkan:
        transforms.append(replace_with_vulkan_rotary_emb)

    return transforms

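# The transforms returned above are plain Module -> Module callables that
# LLMEdgeManager applies in order via source_transform() (see
# _prepare_for_llama_export). A minimal sketch of the equivalent behavior,
# assuming `model` is the eager Llama module and `args` the parsed CLI args:
#
#   transforms = _get_source_transforms(args.model, None, args)
#   for transform in transforms:
#       model = transform(model)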