# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
from multiprocessing.connection import Client

import torch
from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
from executorch.backends.qualcomm.utils.utils import (
    ExecutorchBackendConfig,
    from_context_binary,
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
    get_soc_to_chipset_map,
)
from executorch.examples.qualcomm.qaihub_scripts.utils.utils import (
    gen_pte_from_ctx_bin,
    get_encoding,
)
from executorch.examples.qualcomm.utils import (
    setup_common_args_and_variables,
    SimpleADB,
)
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass


def main(args):
    os.makedirs(args.artifact, exist_ok=True)

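    # Llama2 7B from QAI Hub ships as four context-binary shards; pick the
    # prompt-processor or token-generator variant depending on the mode.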
    target_names = (
        [
            f"llama_v2_7b_chat_quantized_PromptProcessor_{i}_Quantized.bin"
            for i in range(1, 5)
        ]
        if args.use_prompt_processor
        else [
            f"llama_v2_7b_chat_quantized_TokenGenerator_{i}_Quantized.bin"
            for i in range(1, 5)
        ]
    )

    # common part for compile & inference
    backend_options = generate_htp_compiler_spec(
        use_fp16=False,
        use_multi_contexts=True,
    )
    compiler_specs = generate_qnn_executorch_compiler_spec(
        soc_model=getattr(QcomChipset, args.model),
        backend_options=backend_options,
        is_from_context_binary=True,
    )

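    # Pick the artifact name and the last shard's I/O counts; the counts are
    # used below to look up the quantization encoding of the logits output.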
    if args.use_prompt_processor:
        pte_name = "qaihub_llama2_7b_prompt"
        last_shard_num_inputs = 4
        last_shard_num_outputs = 513
    else:
        pte_name = "qaihub_llama2_7b_token"
        last_shard_num_inputs = 516
        last_shard_num_outputs = 513

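    # Lower each context binary into its own .pte file unless pre-compiled
    # artifacts were supplied via --pre_gen_pte.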
    if args.pre_gen_pte is None:
        # create custom operators as context loader
        soc_model = get_soc_to_chipset_map()[args.model]
        bundle_programs = [
            from_context_binary(
                ctx_path=f"{args.context_binaries}/{target}",
                op_name=f"ctx_loader_{i}",
                soc_model=soc_model,
            )
            for i, target in enumerate(target_names)
        ]
        pte_names = [f"{pte_name}_{i}" for i in range(len(target_names))]
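        # Graph inputs/outputs are managed by the runner, so exclude them from
        # ExecuTorch memory planning.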
        memory_planning_pass = MemoryPlanningPass(
            alloc_graph_input=False,
            alloc_graph_output=False,
        )
        pte_files = gen_pte_from_ctx_bin(
            artifact=args.artifact,
            pte_names=pte_names,
            bundle_programs=bundle_programs,
            backend_config=ExecutorchBackendConfig(
                memory_planning_pass=memory_planning_pass
            ),
        )
    else:
        pte_files = [f"{args.pre_gen_pte}/{pte_name}_{i}.pte" for i in range(4)]

    if args.compile_only:
        return

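    # Prepare the ADB session used to push artifacts and launch the on-device runner.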
    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=args.build_folder,
        pte_path=pte_files,
        workspace=f"/data/local/tmp/executorch/{pte_name}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        runner="examples/qualcomm/qaihub_scripts/llama/qaihub_llama2_7b_runner",
    )
    output_file = "result.txt"
    pos_embs_file = ["freq_cos", "freq_sin"]
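    # Read the quantization encoding (scale/offset) of the last shard's outputs
    # so the runner can dequantize the logits.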
    encoding = get_encoding(
        path_to_shard=f"{args.context_binaries}/{target_names[-1]}",
        compiler_specs=compiler_specs,
        get_input=False,
        get_output=True,
        num_input=last_shard_num_inputs,
        num_output=last_shard_num_outputs,
    )[0]
    scale = encoding["scale"][-1]
    offset = encoding["offset"][-1]
    outputs = []
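    # Assemble the command line for the on-device qaihub_llama2_7b_runner.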
    runner_args = [
        *[
            f"--sharded_{i+1}_path {os.path.basename(pte_file)}"
            for i, pte_file in enumerate(pte_files)
        ],
        *[f"--{fname}_path {fname}.raw" for fname in pos_embs_file],
        f"--output_path {adb.output_folder}/{output_file}",
        f"--tokenizer_path {os.path.basename(args.tokenizer_bin)}",
        f"--prompt '{args.prompt}'",
        f"--temperature {args.temperature}",
        f"--seq_len {args.seq_len}",
        f"--eval_mode {0 if args.use_prompt_processor else 1}",
        f"--logits_scale {scale}",
        f"--logits_offset {-offset}",
    ]
    runner_cmds = " ".join(
        [
            f"cd {adb.workspace} &&",
            f"./qaihub_llama2_7b_runner {' '.join(runner_args)}",
        ]
    )

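    # Precompute the rotary position embedding (RoPE) cos/sin tables that are
    # fed to the model as the freq_cos/freq_sin inputs.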
    def compute_pos_embedding():
        head_dim, max_seq_len, theta = 128, 1024, 10000.0
        base = torch.arange(0, head_dim, 2)
        freqs = 1.0 / (theta ** (base[: (head_dim // 2)].float() / head_dim))
        t = torch.arange(max_seq_len * 2)
        freqs = torch.outer(t, freqs).float()
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        freqs_cis = freqs_cis[0:max_seq_len]
        freqs_real = torch.view_as_real(freqs_cis)
        return freqs_real[:, :, 0], freqs_real[:, :, 1]

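    # Collect the generated text pulled back from the device.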
    def post_process():
        with open(f"{args.artifact}/outputs/{output_file}", "r") as f:
            outputs.append(f.read())

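    # Quantize the RoPE tables to uint16 and dump them as raw files, pushed to
    # the device alongside the tokenizer.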
    custom_files = [args.tokenizer_bin]
    for var_name, freq in zip(pos_embs_file, compute_pos_embedding()):
        custom_files.append(f"{adb.working_dir}/{var_name}.raw")
        scale, offset = (freq.max() - freq.min()) / 65535, 32768
        freq = (freq / scale + offset).clip(min=0, max=65535).detach()
        freq.to(dtype=torch.uint16).numpy().tofile(custom_files[-1])

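    # Push artifacts, run inference on device, and pull the result back.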
    if not args.skip_push:
        adb.push(files=custom_files)
    adb.execute(custom_runner_cmd=runner_cmds)
    adb.pull(args.artifact, callback=post_process)
    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(
                json.dumps(
                    {
                        "result": outputs[0],
                    }
                )
            )
    else:
        print(outputs[0])


if __name__ == "__main__":
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing artifacts generated by this example. Default ./llama2_qai_hub",
        default="./llama2_qai_hub",
        type=str,
    )

    parser.add_argument(
        "--context_binaries",
        help="path to context binaries generated from qai_hub",
        required=True,
    )

    parser.add_argument(
        "--use_prompt_processor",
        help="evaluate all prompt tokens at once with the prompt processor instead of one at a time with the token generator",
        default=False,
        action="store_true",
    )

    parser.add_argument(
        "--tokenizer_bin",
        help="llama2 tokenizer binary",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--seq_len",
        help="output sequence length for llama2",
        default=128,
        type=int,
    )

    parser.add_argument(
        "--temperature",
        help="sampling temperature for llama2",
        default=0.0,
        type=float,
    )

    parser.add_argument(
        "--prompt",
        help="user prompt for llama2",
        required=True,
        type=str,
    )

    parser.add_argument(
        "--pre_gen_pte",
        help="folder path to pre-compiled ptes",
        default=None,
        type=str,
    )

    args = parser.parse_args()

    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            raise e