# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import codecs
import getpass
import json
import os
import time
from multiprocessing.connection import Client

import torch
from executorch.backends.qualcomm._passes.build_quant_io import BuildQuantIo

from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner

from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
from executorch.backends.qualcomm.utils.constants import QCOM_QUANTIZED_IO
from executorch.backends.qualcomm.utils.utils import (
    capture_program,
    convert_linear_to_conv2d,
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
    get_soc_to_chipset_map,
)
from executorch.examples.qualcomm.oss_scripts.llama2.model.static_llama import (
    LlamaModel,
    ModelArgs,
)
from executorch.examples.qualcomm.utils import (
    make_output_dir,
    make_quantizer,
    setup_common_args_and_variables,
    SimpleADB,
)
from executorch.exir import EdgeCompileConfig, EdgeProgramManager
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from executorch.extension.llm.export.builder import DType

from sentencepiece import SentencePieceProcessor
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


pte_filename = "llama2_qnn"


def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:
    """
    Annotate matmul ops with a 16a8w (16-bit activation, 8-bit weight)
    quantization config; the chain feeding matmul's second input is
    annotated separately with an 8a8w config.
    """

    from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY
    from executorch.backends.qualcomm.quantizer.quantizer import (
        get_16a8w_qnn_ptq_config,
        get_8a8w_qnn_ptq_config,
        QuantizationConfig,
    )
    from torch.ao.quantization.quantizer import (
        QuantizationAnnotation,
        SharedQuantizationSpec,
    )
    from torch.fx import Node

    def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
        input_qspec_map = {}
        input_act = node.args[0]
        input_spec = quantization_config.input_activation
        input_qspec_map[input_act] = input_spec

        input_act1 = node.args[1]
        input_spec1 = quantization_config.weight
        input_qspec_map[input_act1] = input_spec1

        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )

    def annotate_cat(node: Node, quantization_config: QuantizationConfig):
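        # Inputs to cat after the first share the first input's quantization
        # parameters via SharedQuantizationSpec; the cat output shares them too.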
        input_nodes = node.args[0]

        first_input_node = input_nodes[0]
        input_qspec_map = {}
        input_qspec_map[first_input_node] = quantization_config.input_activation
        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
            (first_input_node, node)
        )

        for input_node in input_nodes[1:]:
            if input_node not in input_qspec_map:
                input_qspec_map[input_node] = share_qparams_with_input_act0_qspec

        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=share_qparams_with_input_act0_qspec,
            _annotated=True,
        )

    def annotate_single_in_single_out(
        node: Node, quantization_config: QuantizationConfig
    ) -> None:

        input_qspec_map = {}
        input_act = node.args[0]
        input_qspec_map[input_act] = quantization_config.input_activation

        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )

    def annotate_matmul_input1(node: Node):
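        # Walk backwards along matmul's second input (typically the cached
        # key/value path), annotating transpose/permute and cat ops with the
        # 8a8w config.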
        quantization_config_8a8w = get_8a8w_qnn_ptq_config(act_symmetric=True)
        while isinstance(node, Node) and node.op == "call_function":
            if node.target in [
                torch.ops.aten.permute.default,
                torch.ops.aten.transpose.int,
            ]:
                annotate_single_in_single_out(node, quantization_config_8a8w)
                node = node.args[0]
            elif node.target == torch.ops.aten.cat.default:
                annotate_cat(node, quantization_config_8a8w)
                node = node.args[0][0]
            else:
                node = node.args[0]

    quantization_config_16a8w = get_16a8w_qnn_ptq_config()

    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
            annotate_matmul(node, quantization_config_16a8w)
            annotate_matmul_input1(node.args[1])


def annotate_linear_16a8w_in_affine_layer(gm: torch.fx.GraphModule) -> None:
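    """
    Annotate the final output projection (llama.output), which has been
    rewritten from linear to conv2d, with a 16a8w per-channel quantization
    config.
    """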
    from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY
    from executorch.backends.qualcomm.quantizer.quantizer import (
        get_ptq_per_channel_quant_config,
        QuantizationConfig,
    )
    from torch.ao.quantization.quantizer import QuantizationAnnotation
    from torch.fx import Node

    def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
        input_qspec_map = {}
        input_act = node.args[0]
        input_spec = quantization_config.input_activation
        input_qspec_map[input_act] = input_spec

        weight = node.args[1]
        input_qspec_map[weight] = quantization_config.weight

        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )

    quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config(
        torch.uint16, weight_dtype=torch.int8
    )
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default:
            if "nn_module_stack" in node.meta:
                module_values_list = list(node.meta["nn_module_stack"].values())
                full_qualified_name = module_values_list[0][0]
                if full_qualified_name == "L['self'].llama.output":
                    annotate_conv2d(
                        node, quantization_config=quantization_config_16a8w_per_channel
                    )


def calibrate(
    example_inputs,
    user_prompts,
    module: torch.fx.GraphModule,
    tokenizer_model_path="tokenizer.model",
):
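    """
    Run the observer-instrumented graph module autoregressively over the
    user prompt so the quantization observers collect activation statistics
    before convert_pt2e.
    """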
    sp_model = SentencePieceProcessor(model_file=tokenizer_model_path)
    _, _, atten_mask, k_caches, v_caches = example_inputs

    # TODO: change criteria & support batch inputs if necessary
    pos = torch.tensor(0, dtype=torch.int32)
    token_list = [sp_model.bos_id()]
    for prompt in user_prompts.split():
        token_list += sp_model.encode(prompt)

    def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
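        # Nucleus (top-p) sampling: keep the smallest set of tokens whose
        # cumulative probability covers top_p, renormalize, then sample.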
        probs_sort, probs_indices = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        mask = probs_sum - probs_sort > top_p
        probs_sort[mask] = 0
        probs_sort /= probs_sort.sum(dim=-1, keepdim=True)
        next_token = torch.multinomial(probs_sort, num_samples=1)
        return probs_indices.gather(dim=-1, index=next_token)

    with torch.no_grad():
        while token_list[-1] != sp_model.eos_id() and pos < 128:
            logits, new_k_caches, new_v_caches = module(
                torch.full((1, 1), token_list[pos]),
                torch.full((1, 1), pos),
                atten_mask,
                *k_caches,
                *v_caches,
            )
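            # Slide the cache window: drop the oldest entry and append the
            # newly produced k/v for this position.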
            k_caches = [
                torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
                for i, k_cache in enumerate(k_caches)
            ]
            v_caches = [
                torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
                for i, v_cache in enumerate(v_caches)
            ]

            pos += 1
            atten_mask[0][-pos - 1] = 0
            if pos >= len(token_list):
                probs = torch.softmax(logits[:, -1] / 0.8, dim=-1)
                token_list.append(sample_top_p(probs, 0.9).item())

    print(f"calibration data:\n{sp_model.decode(token_list)}")


class SingleLlama:
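    """
    Wraps a static LlamaModel, quantizes it, and lowers it to a QNN
    ExecuTorch program (.pte) with quantized k/v-cache I/O.
    """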
    def __init__(self, llama_model) -> None:
        super().__init__()
        self.llama_model = llama_model
        self.quant_dtype = None
        self.llama_meta = self.llama_model.get_metadata()
        self.has_quant_io = False
        tokens, pos_ids, atten_mask, k_caches, v_caches = self.get_example_inputs()
        self.inputs = (tokens, pos_ids, atten_mask, *k_caches, *v_caches)

    def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type):
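        """
        Tag k/v-cache placeholders and outputs with QCOM_QUANTIZED_IO so the
        cache tensors keep the quantized dtype (kv_type) at the I/O boundary.
        """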
        if not self.has_quant_io:
            return

        # shape of k caches and v caches
        input_cache_shape = {
            (self.llama_meta["get_head_dim"], self.llama_meta["get_max_seq_len"]),
            (self.llama_meta["get_max_seq_len"], self.llama_meta["get_head_dim"]),
        }
        for n in gm.graph.nodes:
            if (
                n.op == "placeholder"
                and len(users := list(n.users)) == 1
                and users[0].meta["val"].size()[-2:] in input_cache_shape
            ):
                n.meta[QCOM_QUANTIZED_IO] = kv_type
            elif n.op == "output":
                for a in n.args[0]:
                    if (
                        a.meta["val"].flatten().size()[0]
                        == self.llama_meta["get_head_dim"]
                    ):
                        a.meta[QCOM_QUANTIZED_IO] = kv_type

    def quantize(self, quant_dtype, custom_annotations=()):
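        """
        PTQ flow: export the model, prepare it with the QNN quantizer (plus
        any custom annotations), calibrate on the user prompt, then convert.
        """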
        self.quant_dtype = quant_dtype
        quantizer = make_quantizer(
            quant_dtype=quant_dtype,
            per_channel_conv=True,
            per_channel_linear=True,
            act_observer=MinMaxObserver,
        )
        quantizer.add_custom_quant_annotations(custom_annotations)

        self.has_quant_io = True
        fx_graph_module = None

        with torch.no_grad():
            fx_graph_module = torch.export.export(
                self.llama_model, self.inputs
            ).module()
            fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)
        print("Quantizing the model...")
        calibrate(
            self.get_example_inputs(),
            args.prompt,
            fx_graph_module,
            tokenizer_model_path=args.tokenizer_model,
        )

        self.llama_model = convert_pt2e(fx_graph_module)

    def lowering_modules(
        self, work_space, kv_type=torch.uint8, soc_model=QcomChipset.SM8650
    ):
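        """
        Capture, partition, and lower the quantized model to the QNN backend,
        then serialize the ExecuTorch program to {work_space}/{pte_filename}.pte.
        """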
        executorch_config = ExecutorchBackendConfig(
            passes=[
                BuildQuantIo(),
            ],
            # With shared buffers, the user passes RPC-allocated memory
            # addresses to the executor runner, so the runtime memory manager
            # should not pre-allocate graph inputs and outputs.
            memory_planning_pass=MemoryPlanningPass(
                alloc_graph_input=False,
                alloc_graph_output=False,
            ),
            extract_delegate_segments=True,
        )
        with torch.no_grad():
            # backend option
            backend_options = generate_htp_compiler_spec(use_fp16=False)
            compiler_specs = generate_qnn_executorch_compiler_spec(
                soc_model=soc_model,
                backend_options=backend_options,
                shared_buffer=True,
            )
            partitioner = QnnPartitioner(compiler_specs)
            edge_prog = capture_program(self.llama_model, self.inputs)
            self._tag_kv_ios(edge_prog.exported_program.graph_module, kv_type=kv_type)
            edge_prog_mgr = EdgeProgramManager(
                edge_programs={"forward": edge_prog.exported_program},
                constant_methods=self.llama_meta,
                compile_config=EdgeCompileConfig(_check_ir_validity=False),
            )
            edge_prog_mgr = edge_prog_mgr.to_backend(partitioner)
            exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config)
            with open(f"{work_space}/{pte_filename}.pte", "wb") as file:
                exec_prog_mgr.write_to_file(file)

    def get_example_inputs(self):
        return self.llama_model.get_example_inputs()


def compile(args):
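    """
    Load the checkpoint, quantize the static llama model, and lower it to a
    QNN .pte under args.artifact.
    """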
    os.makedirs(args.artifact, exist_ok=True)
    start_ts = time.time()
    with open(args.params) as f:
        config = ModelArgs(**json.load(f))
        # TODO: support batch inputs if necessary
        config.max_batch_size = 1
        config.max_seq_len = 1024
    state_dict = torch.load(
        args.checkpoint, weights_only=True, map_location="cpu", mmap=True
    )
    end_load_ts = time.time()
    print("torch.load checkpoint", end_load_ts - start_ts)

    llama_instance = None
    with torch.device("meta"):
        llama_instance = LlamaModel(config, output_new_cache_only=True)
    if "model" in state_dict:
        state_dict = state_dict["model"]
    llama_instance.load_state_dict(
        state_dict,
        strict=False,
        assign=True,
    )
    end_load_state_dict_ts = time.time()
    print("instance.load_state_dict", end_load_state_dict_ts - end_load_ts)

    for layer in llama_instance.layers:
        if getattr(layer.attention, "prepare_sha", None):
            layer.attention.prepare_sha()

    kv_type = torch.uint8
    assert args.ptq in [
        "8a8w",
        "16a4w",
    ], f"Unsupported quantization type {args.ptq}; supported types are 8a8w and 16a4w."
    quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")
    assert args.tokenizer_model is not None, "Need tokenizer model for calibration"

    if args.dtype_override is not None:
        dtype_override = DType[args.dtype_override]
        llama_instance = llama_instance.to(dtype_override.to_torch_dtype())

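    # Rewrite all linear layers as conv2d before quantization; the conv2d
    # form is what annotate_linear_16a8w_in_affine_layer targets.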
    llama_instance = convert_linear_to_conv2d(llama_instance)
    single_llama = SingleLlama(llama_instance.eval())

    start_quantize_ts = time.time()
    single_llama.quantize(
        quant_dtype,
        custom_annotations=(
            annotate_matmul_16a8w,
            annotate_linear_16a8w_in_affine_layer,
        ),
    )
    end_quantize_ts = time.time()
    print("single_llama.quantize(quant_dtype)", end_quantize_ts - start_quantize_ts)
    single_llama.lowering_modules(
        args.artifact, kv_type=kv_type, soc_model=get_soc_to_chipset_map()[args.model]
    )
    end_lowering_ts = time.time()
    print("Complete Compile", end_lowering_ts - end_quantize_ts)


def inference(args, pre_gen_pte=""):
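    """
    Push the compiled .pte and tokenizer to the device, run qnn_llama_runner
    over ADB, and collect the generated text from the output folder.
    """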
    workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/single_llama"

    runner_args = " ".join(
        [
            f"--model_path {pte_filename}.pte",
            "--output_folder_path outputs",
            f"--tokenizer_path {os.path.basename(args.tokenizer_bin)}",
            f'--prompt "{args.prompt}"',
            f"--seq_len {args.seq_len}",
            f"--temperature {args.temperature}",
        ]
    )
    runner_cmd = " ".join(
        [
            f"cd {workspace} &&",
            f"./qnn_llama_runner {runner_args}",
        ]
    )

    pte_path = (
        f"{pre_gen_pte}/{pte_filename}.pte"
        if pre_gen_pte
        else f"{args.artifact}/{pte_filename}.pte"
    )
    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=pte_path,
        workspace=workspace,
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
        runner="examples/qualcomm/oss_scripts/llama2/qnn_llama_runner",
    )
    # No pre-generated inputs, so input_list is not required
    adb.push(inputs=[], input_list="", files=[args.tokenizer_bin])
    adb.execute(custom_runner_cmd=runner_cmd)

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)
    outputs = []

    def post_process():
        for f in sorted(
            os.listdir(output_data_folder), key=lambda f: int(f.split("_")[1])
        ):
            with codecs.open(
                os.path.join(output_data_folder, f),
                "r",
                encoding="utf-8",
                errors="replace",
            ) as fdata:
                outputs.append(fdata.read())

    adb.pull(output_path=args.artifact, callback=post_process)

    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(
                json.dumps(
                    {
                        "result": outputs,
                    }
                )
            )
    else:
        for idx, output in enumerate(outputs):
            print(f"Results[{idx}]:\n{output}")


# flake8: noqa: C901
if __name__ == "__main__":
    parser = setup_common_args_and_variables()
    parser.add_argument(
        "-a",
        "--artifact",
        help="Path for storing artifacts and outputs generated by this example. Default: ./llama2_qnn",
        default="./llama2_qnn",
        type=str,
    )

    parser.add_argument(
        "-P",
        "--ptq",
        help="PTQ quantization scheme. Default is 16-bit activations and 4-bit weights (16a4w). Supported: 8a8w and 16a4w.",
        default="16a4w",
    )

    parser.add_argument(
        "--checkpoint",
        help="Pass llama2 checkpoint.",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--params",
        help="Pass llama2 params json file.",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--tokenizer_bin",
        help="Pass llama2 tokenizer binary.",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--tokenizer_model",
        help="Pass llama2 tokenizer model.",
        type=str,
        default=None,
    )

    parser.add_argument(
        "--prompt",
        help="User prompts for llama2.",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--seq_len",
        help="Output sequence length for llama2.",
        default=128,
        type=int,
    )

    parser.add_argument(
        "--temperature",
        help="Sampling temperature for llama2.",
        default=0.8,
        type=float,
    )

    parser.add_argument(
        "-d",
        "--dtype-override",
        default="fp32",
        type=str,
        choices=["fp32", "fp16"],
        help="Override the dtype of the model. Default: fp32. Options: fp32, fp16.",
    )

    parser.add_argument(
        "--pre_gen_pte",
        help="Run the pre-generated llama2 .pte in the given directory.",
        type=str,
    )

    args = parser.parse_args()
    if args.compile_only and args.pre_gen_pte:
        exit("Cannot set both compile_only and pre_gen_pte as true")

    if args.pre_gen_pte:
        inference(args, args.pre_gen_pte)
        exit(f"Finished running the pre-generated pte from {args.pre_gen_pte}")

    if args.compile_only:
        compile(args)
        exit(f"Finished compile_only; saved artifacts to {args.artifact}")

    try:
        compile(args)
        inference(args)
    except Exception as e:
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            raise