# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import getpass
import json
import logging
import os

import sys
import time
from multiprocessing.connection import Client

import torch
from executorch.backends.qualcomm._passes.build_quant_io import BuildQuantIo

from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner

from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_matmul_16a8w,
    custom_annotate_llama_last_conv_16a8w,
)

from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
from executorch.backends.qualcomm.utils.constants import QCOM_QUANTIZED_IO
from executorch.backends.qualcomm.utils.utils import (
    capture_program,
    convert_linear_to_conv2d,
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
    get_soc_to_chipset_map,
)
from executorch.examples.qualcomm.oss_scripts.llama2.model.static_llama import (
    LlamaModel,
    ModelArgs,
)
from executorch.examples.qualcomm.utils import (
    make_output_dir,
    make_quantizer,
    setup_common_args_and_variables,
    SimpleADB,
)
from executorch.exir import EdgeCompileConfig, EdgeProgramManager
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from executorch.extension.llm.custom_ops import model_sharding
from executorch.extension.llm.export.builder import DType
from executorch.extension.llm.tokenizer.utils import get_tokenizer

from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

sys.setrecursionlimit(4096)
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)
logging.getLogger().setLevel(logging.INFO)

pte_filename = "llama3_2_qnn"


def calibrate(
    example_inputs,
    user_prompts,
    module: torch.fx.GraphModule,
    tokenizer_model_path="tokenizer.model",
):
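    """
    Greedily run the tokenized prompt through the observer-instrumented module
    to collect calibration statistics for PTQ.

    The loop feeds one token per step, rolls the per-layer k/v caches with the
    newly produced entries, and keeps generating (argmax sampling) once the
    prompt is exhausted, until EOS or the cache is full.
    """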
    sp_model = get_tokenizer(tokenizer_model_path)
    _, _, atten_mask, k_caches, v_caches = example_inputs

    # TODO: change criteria & support batch inputs if necessary
    pos = torch.tensor(0, dtype=torch.int32)
    token_list = sp_model.encode(user_prompts, bos=True, eos=False)

    with torch.no_grad():
        # 511 = max_seq_len - 1 (max_seq_len is fixed to 512 in compile())
        while token_list[-1] != sp_model.eos_id and pos < 511:
            logits, new_k_caches, new_v_caches = module(
                torch.full((1, 1), token_list[pos], dtype=torch.int32),
                torch.full((1, 1), pos),
                atten_mask,
                *k_caches,
                *v_caches,
            )
            # drop the oldest cache entry and append the newly produced k/v slice
            k_caches = [
                torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
                for i, k_cache in enumerate(k_caches)
            ]
            v_caches = [
                torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
                for i, v_cache in enumerate(v_caches)
            ]

            pos += 1
            atten_mask[0][-pos - 1] = 0
            if pos >= len(token_list):
                token_list.append(torch.argmax(logits[:, -1], dim=-1).item())

    print(f"calibration data:\n{sp_model.decode(token_list)}")


class SingleLlama:
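    """
    Thin wrapper around the static LlamaModel that handles PT2E quantization
    and lowering to a QNN (HTP) ExecuTorch program.
    """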
    def __init__(self, llama_model) -> None:
        super().__init__()
        self.llama_model = llama_model
        self.quant_dtype = None
        self.llama_meta = self.llama_model.get_metadata()
        self.has_quant_io = False
        tokens, pos_ids, atten_mask, k_caches, v_caches = self.get_example_inputs()
        self.inputs = (tokens, pos_ids, atten_mask, *k_caches, *v_caches)

    def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type, sharding_type):
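        """
        Tag the k/v cache placeholders/outputs and the sharding fallback ops
        with QCOM_QUANTIZED_IO so their graph I/O stays in the given quantized
        dtypes (kv_type / sharding_type).
        """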
        if not self.has_quant_io:
            return

        # shape of k caches and v caches
        input_cache_shape = {
            (self.llama_meta["get_head_dim"], self.llama_meta["get_max_seq_len"]),
            (self.llama_meta["get_max_seq_len"], self.llama_meta["get_head_dim"]),
        }
        for n in gm.graph.nodes:
            if (
                n.op == "placeholder"
                and len(users := list(n.users)) == 1
                and users[0].meta["val"].size()[-2:] in input_cache_shape
            ):
                n.meta[QCOM_QUANTIZED_IO] = kv_type
            elif n.op == "output":
                for a in n.args[0]:
                    if (
                        a.meta["val"].flatten().size()[0]
                        == self.llama_meta["get_head_dim"]
                    ):
                        a.meta[QCOM_QUANTIZED_IO] = kv_type

            # Tag sharding io
            if exir_ops.edge.llama.fallback.default in [
                u.target for u in list(n.users.keys())
            ] + [n.target]:
                n.meta[QCOM_QUANTIZED_IO] = sharding_type

    def quantize(self, quant_dtype, custom_annotations=()):
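        """
        PT2E post-training quantization: export the model, prepare it with the
        QNN quantizer (plus any custom annotations), calibrate it on the user
        prompt, and convert it back to a quantized module.

        Note: relies on the module-level ``args`` (prompt, tokenizer path)
        parsed in ``__main__``.
        """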
        self.quant_dtype = quant_dtype
        quantizer = make_quantizer(
            quant_dtype=quant_dtype,
            per_channel_conv=True,
            per_channel_linear=True,
            act_observer=MinMaxObserver,
        )
        quantizer.add_custom_quant_annotations(custom_annotations)

        self.has_quant_io = True
        fx_graph_module = None

        with torch.no_grad():
            fx_graph_module = torch.export.export(
                self.llama_model, self.inputs
            ).module()
            fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)
        logging.info("Quantizing the model...")
        calibrate(
            self.get_example_inputs(),
            args.prompt,
            fx_graph_module,
            tokenizer_model_path=args.tokenizer_model,
        )

        self.llama_model = convert_pt2e(fx_graph_module)

    def lowering_modules(
        self,
        work_space,
        kv_type=torch.uint8,
        sharding_type=torch.uint16,
        use_fp16=False,
        soc_model=QcomChipset.SM8650,
        num_sharding=0,
    ):
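        """
        Lower the (possibly quantized) model to the QNN HTP backend and write
        the resulting program to ``{work_space}/{pte_filename}.pte``.

        When ``num_sharding`` > 0, the graph is split into that many shards via
        the llama fallback custom op before partitioning.
        """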
        executorch_config = ExecutorchBackendConfig(
            passes=[
                BuildQuantIo(),
            ],
            # For shared buffers, the user must pass memory addresses allocated
            # via RPC memory to the executor runner, so the runtime memory
            # manager should not pre-allocate the graph inputs/outputs.
            memory_planning_pass=MemoryPlanningPass(
                alloc_graph_input=False,
                alloc_graph_output=False,
            ),
            extract_delegate_segments=True,
        )
        with torch.no_grad():
            # backend option
            backend_options = generate_htp_compiler_spec(use_fp16=use_fp16)
            compiler_specs = generate_qnn_executorch_compiler_spec(
                soc_model=soc_model,
                backend_options=backend_options,
                shared_buffer=True,
            )
            skip_node_op_set = {"llama.fallback.default"}
            partitioner = QnnPartitioner(
                compiler_specs, skip_node_op_set=skip_node_op_set
            )
            edge_prog = capture_program(
                self.llama_model, self.inputs, custom_pass_config=frozenset()
            )

            if num_sharding > 0:
                model_sharding.split_graph(
                    edge_prog.exported_program,
                    self.llama_meta["get_n_layers"],
                    shares=num_sharding,
                )

            self._tag_kv_ios(
                edge_prog.exported_program.graph_module,
                kv_type=kv_type,
                sharding_type=sharding_type,
            )
            edge_prog_mgr = EdgeProgramManager(
                edge_programs={"forward": edge_prog.exported_program},
                constant_methods=self.llama_meta,
                compile_config=EdgeCompileConfig(_check_ir_validity=False),
            )
            edge_prog_mgr = edge_prog_mgr.to_backend(partitioner)
            exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config)
            with open(f"{work_space}/{pte_filename}.pte", "wb") as file:
                exec_prog_mgr.write_to_file(file)

    def get_example_inputs(self):
        return self.llama_model.get_example_inputs()


def compile(args):
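    """
    Build the Llama 3.2 model from the checkpoint/params, optionally quantize
    it (PTQ), and lower it to a QNN .pte under ``args.artifact``.
    """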
    os.makedirs(args.artifact, exist_ok=True)
    start_ts = time.time()
    with open(args.params) as f:
        config = ModelArgs(**json.load(f))
        # TODO: support batch inputs if necessary
        config.max_batch_size = 1
        config.max_seq_len = 512
    state_dict = torch.load(
        args.checkpoint, weights_only=True, map_location="cpu", mmap=True
    )

    llama_instance = None
    with torch.device("meta"):
        llama_instance = LlamaModel(config, output_new_cache_only=True)
    if "model" in state_dict:
        state_dict = state_dict["model"]
    llama_instance.load_state_dict(
        state_dict,
        strict=False,
        assign=True,
    )
    end_load_ts = time.time()
    logging.info(f"Time for loading checkpoint: {end_load_ts - start_ts}")

    for layer in llama_instance.layers:
        if getattr(layer.attention, "prepare_sha", None):
            layer.attention.prepare_sha()

    use_fp16 = False
    if args.ptq is not None:
        kv_type = torch.uint8
        if args.ptq == "8a8w":
            sharding_type = torch.uint8
        elif args.ptq == "16a4w":
            sharding_type = torch.uint16
        else:
            assert args.ptq in [
                "8a8w",
                "16a4w",
            ], f"No support for quant type {args.ptq}. Support 8a8w and 16a4w."
        quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")
    else:
        use_fp16 = True
        kv_type = torch.float32
        sharding_type = torch.float32
    assert args.tokenizer_model is not None, "Need tokenizer model for calibration"

    if args.dtype_override is not None:
        dtype_override = DType[args.dtype_override]
        llama_instance = llama_instance.to(dtype_override.to_torch_dtype())

    llama_instance = convert_linear_to_conv2d(llama_instance)
    single_llama = SingleLlama(llama_instance.eval())

    if args.ptq is not None:
        start_quantize_ts = time.time()
        single_llama.quantize(
            quant_dtype,
            custom_annotations=(
                custom_annotate_llama_last_conv_16a8w,
                annotate_matmul_16a8w,
            ),
        )
        end_quantize_ts = time.time()
        logging.info(f"Time for quantizing: {end_quantize_ts - start_quantize_ts}")

    start_lowering_ts = time.time()
    single_llama.lowering_modules(
        args.artifact,
        kv_type=kv_type,
        sharding_type=sharding_type,
        use_fp16=use_fp16,
        soc_model=get_soc_to_chipset_map()[args.model],
        num_sharding=args.num_sharding,
    )
    end_lowering_ts = time.time()
    logging.info(f"Time for compiling: {end_lowering_ts - start_lowering_ts}")


def inference(args, pre_gen_pte=""):
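    """
    Push the compiled .pte, the tokenizer, and the device runner to the target
    via adb, execute the runner, and pull the generated text back into
    ``args.artifact/outputs``.
    """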
    workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/single_llama"

    runner_args = " ".join(
        [
            f"--model_path {pte_filename}.pte",
            "--output_path outputs/outputs.txt",
            f"--tokenizer_path {os.path.basename(args.tokenizer_model)}",
            f'--prompt "{args.prompt}"',
            f"--seq_len {args.seq_len}",
            f"--temperature {args.temperature}",
        ]
    )
    runner_cmd = " ".join(
        [
            f"cd {workspace} &&",
            f"./qnn_llama3_2_{args.model_size.lower()}_runner {runner_args}",
        ]
    )

    pte_path = (
        f"{pre_gen_pte}/{pte_filename}.pte"
        if pre_gen_pte
        else f"{args.artifact}/{pte_filename}.pte"
    )
    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=pte_path,
        workspace=workspace,
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
        runner=f"examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_{args.model_size.lower()}_runner",
    )
    # No pregen inputs, input_list is not required
    adb.push(inputs=[], input_list="", files=[args.tokenizer_model])
    adb.execute(custom_runner_cmd=runner_cmd)

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)
    outputs = []

    def post_process():
        with open(f"{args.artifact}/outputs/outputs.txt", "r") as f:
            outputs.append(f.read())

    adb.pull(output_path=args.artifact, callback=post_process)

    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(
                json.dumps(
                    {
                        "result": outputs,
                    }
                )
            )
    else:
        for idx, output in enumerate(outputs):
            logging.info(f"Results[{idx}]:\n{output}")


# flake8: noqa: C901
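# Example invocation (a sketch; the paths are placeholders and the -b/-s/-m
# flags come from setup_common_args_and_variables):
#   python llama.py -b build-android -s <device_serial> -m SM8650 \
#       --checkpoint consolidated.00.pth --params params.json \
#       --tokenizer_model tokenizer.model --model_size 1B \
#       --prompt "what is 1+1" --ptq 16a4w --seq_len 128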
if __name__ == "__main__":
    parser = setup_common_args_and_variables()
    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts and output by this example. Default ./llama3_2_qnn",
        default="./llama3_2_qnn",
        type=str,
    )

    parser.add_argument(
        "-P",
        "--ptq",
        help="If specified, will do PTQ quantization. Default is 16-bit activations and 4-bit weights. Supports 8a8w and 16a4w.",
        type=str,
    )

    parser.add_argument(
        "--checkpoint",
        help="Pass llama checkpoint.",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--params",
        help="Pass llama params json file.",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--model_size",
        help="Determine which runner will be used. For llama 3.2, only 1B and 3B are supported.",
        choices=["1B", "3B"],
        required=True,
        type=str,
    )

    parser.add_argument(
        "--tokenizer_model",
        help="Pass llama tokenizer model.",
        type=str,
        default=None,
    )

    parser.add_argument(
        "--prompt",
        help="User prompts for llama.",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--seq_len",
        help="Output sequence length for llama.",
        default=128,
        type=int,
    )

    parser.add_argument(
        "--temperature",
        help="Sampling temperature for llama.",
        default=0.8,
        type=float,
    )

    parser.add_argument(
        "-d",
        "--dtype-override",
        default="fp32",
        type=str,
        choices=["fp32", "fp16"],
        help="Override the dtype of the model. Options: fp32, fp16 (default: fp32).",
    )

    parser.add_argument(
        "--pre_gen_pte",
        help="Run the pre-generated llama in the given directory.",
        type=str,
    )

    parser.add_argument(
        "--num_sharding",
        type=int,
        default=0,
        help="Specify the number of splits by inserting the fallback custom op. The graph will be split evenly by layers.",
    )

    args = parser.parse_args()
    if args.compile_only and args.pre_gen_pte:
        exit("Cannot set both compile_only and pre_gen_pte as true")

    if args.pre_gen_pte:
        inference(args, args.pre_gen_pte)
        exit(f"Finished running the pre-generated PTE from {args.pre_gen_pte}")

    if args.compile_only:
        compile(args)
        exit(f"Finished compile_only; artifacts saved to {args.artifact}")

    try:
        compile(args)
        inference(args)
    except Exception as e:
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            raise e