import os
import sys

if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())
import argparse
import struct
import warnings

import torch

from aot_utils.llm_utils.preformatter import Preformatter
from aot_utils.llm_utils.sanity_checks import (
    check_all_chunks_same_num_layer,
    check_between_inclusive,
    check_exist,
    check_ext,
    check_old_arg,
    check_shapes,
    check_supported_model,
    check_supported_tokenizer,
    check_tokenizer_exist,
    check_weights_exist,
)
from aot_utils.llm_utils.utils import (
    dump_embedding_lut_for_cmdline,
    generate_mask,
    get_dest_path,
    get_dirname,
    get_embedding_layer,
    get_exp_name,
    get_export_shapes,
    get_master_rot_emb,
    get_normalized_config,
    load_checkpoints,
    resolve_model_classes,
)
from datasets import load_dataset
from executorch import exir
from executorch.backends.mediatek import (
    NeuropilotPartitioner,
    NeuropilotQuantizer,
    Precision,
)
from executorch.exir.backend.backend_details import CompileSpec
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from tqdm import tqdm

warnings.filterwarnings("ignore")

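# Example invocation (illustrative only; the config path, dataset file and shape
# strings are placeholders to adapt to your setup):
#   python llama.py path/to/config.json -p A16W8 -n 4 -d calibration.txt -shapes 32t512c 1t512c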
def get_argument_parser():
    parser = argparse.ArgumentParser(
        description="Run Export to ET for supported LLM models.", allow_abbrev=False
    )
    parser.add_argument(
        "config",
        type=str,
        help="[Required] Model config json file. "
        "Model config must be in same directory as all model weight bins and tokenizer files.",
    )
    parser.add_argument(
        "-p",
        "--precision",
        type=str,
        default="A16W8",
        choices=["A16W4", "A16W8", "A16W16", "A8W4", "A8W8"],
        help="Precision to quantize entire model to.",
    )
    parser.add_argument(
        "-d",
        "--dataset",
        type=str,
        default=None,
        help="Calibration dataset name or path to dataset. Defaults to None, which uses random inputs.",
    )
    parser.add_argument(
        "-n",
        "--num_chunks",
        type=int,
        default=4,
        help="Number of chunks to cut the model into. Defaults to 4.",
    )
    parser.add_argument(
        "-r",
        "--response_cap",
        type=int,
        default=9,
        help="Maximum number of response tokens to save during calibration. Defaults to 9.",
    )
    parser.add_argument(
        "--preformatter",
        type=str,
        default=None,
        help="Preformatter Template to use to wrap input with. Defaults to None.",
    )
    parser.add_argument(
        "-shapes",
        nargs="+",
        help="[Required] Expected input shapes to export the model chunks with. Space separated list of "
        "shapes in the format: xtyc (e.g. 32t512c)",
    )

    return parser


# flake8: noqa: F405
def args_sanity_checks(args):
    check_old_arg(args.config)
    check_exist(args.config, "Config file")
    check_ext(args.config, ".json", "Config file")
    config = get_normalized_config(args.config)

    weight_dir = get_dirname(args.config)
    check_tokenizer_exist(weight_dir)
    check_weights_exist(weight_dir)

    check_supported_model(config)
    check_supported_tokenizer(config)

    if args.preformatter is not None:
        check_exist(args.preformatter, "Preformatter json file")
        check_ext(args.preformatter, ".json", "preformatter")

    if args.dataset is not None:
        check_exist(args.dataset)

    check_between_inclusive(args.num_chunks, 1, config.num_hidden_layers, "num_chunks")

    check_shapes(args.shapes)


def print_args(args, exp_name):
    print("Please check if all arguments are correct:")
    print(f"Config file:                  {args.config}")
    print(f"Output pte folder:            pte/{exp_name}")
    print(f"Quantization precision:       {args.precision}")
    print(f"Preformatter:                 {args.preformatter}")
    print(f"Calibration Dataset:          {args.dataset}")
    print(f"Max Response Tokens:          {args.response_cap}")
    print(f"Number of chunks:             {args.num_chunks}")
    print(f"Export shape(s):              {args.shapes}")
    print()


def apply_preformatter(inp, preformatter=None):
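    """Wrap the raw prompt text with the preformatter template (datasets.map callback)."""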
    formatted_text = preformatter.generate_prompt(inp["text"])
    inp["text"] = formatted_text
    print(f"Formatted Prompt:\n{formatted_text}")
    return inp


def tokenize_dataset(inp, tokenizer):
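    """Tokenize the prompt text into int32 input_ids (datasets.map callback)."""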
    text = inp["text"]
    inp_encoded = tokenizer(text, return_tensors="pt")  # dict
    inp_encoded.pop("attention_mask")
    inp_encoded = inp_encoded["input_ids"]
    inp_encoded = inp_encoded.to(torch.int32)
    inp["input_ids"] = inp_encoded
    inp.pop("text")
    return inp


def reset_cache(
    num_chunks, num_key_value_heads, num_blocks_per_chunk, head_dim, max_cache_size
):
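    """Create a zero-initialized KV cache for every chunk.

    Each chunk's cache has shape
    (2 * num_blocks_in_chunk, num_key_value_heads, max_cache_size, head_dim):
    one key and one value slice per transformer block, with a fixed cache length.
    """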
    cache = []
    for i in range(num_chunks):
        curr_chunk_cache = torch.zeros(
            (
                2 * num_blocks_per_chunk[i],
                num_key_value_heads,
                max_cache_size,  # fixed cache size: torch dynamic shapes cannot handle two dynamic dims
                head_dim,
            ),
            dtype=torch.float32,
        )
        cache.append(curr_chunk_cache)
    return cache


def forward_and_save(
    models,
    hidden_state,
    cache,
    mask,
    pos_emb,
    model_input_dict,
    num_blocks_per_chunk,
    batch_name,
):
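    """Run a single forward pass through every model chunk.

    Each chunk's inputs are recorded in model_input_dict under batch_name so they can
    be replayed later for calibration; the returned hidden_state and the updated
    per-chunk KV caches feed the next decoding step.
    """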
    for chunk_idx in range(len(models)):
        cache_in = cache[chunk_idx]

        try:
            model_input_dict[str(chunk_idx)] = {
                **model_input_dict[str(chunk_idx)],
                batch_name: {
                    "hidden_state": hidden_state,
                    "mask": mask,
                    "pos_emb": pos_emb,
                    "cache": cache_in,
                },
            }
        except TypeError:  # entry is still None before the first recorded batch
            model_input_dict[str(chunk_idx)] = {
                batch_name: {
                    "hidden_state": hidden_state,
                    "mask": mask,
                    "pos_emb": pos_emb,
                    "cache": cache_in,
                }
            }
        with torch.no_grad():
            model_out = models[chunk_idx](
                hidden_state, mask, pos_emb, *torch.split(cache_in, 1, dim=0)
            )
        hidden_state = model_out[0]
        cache[chunk_idx] = torch.cat(
            model_out[1 : 1 + 2 * num_blocks_per_chunk[chunk_idx]], dim=0
        ).clone()
    return hidden_state, cache


def prepare_model_inputs(
    inp,
    models,
    embedding_layer,
    master_rot_emb,
    num_blocks_per_chunk,
    num_key_value_heads,
    head_dim,
    max_cache_size,
    eos_token_id_tensor,
    response_cap,
):
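    """Collect per-chunk calibration inputs for one dataset sample (datasets.map callback).

    Runs the prompt through the chunked model, then greedily decodes response tokens,
    recording every chunk's inputs for the prompt and for each response step until EOS
    is produced or response_cap tokens have been generated.
    """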
    model_input_dict = {str(i): None for i in range(len(models))}
    input_ids = inp.pop("input_ids")
    hidden_state = embedding_layer(torch.tensor(input_ids))
    input_length = hidden_state.shape[1]
    # Assume fixed cache size
    mask = generate_mask(max_cache_size, 0, input_length, input_length)
    pos_emb = master_rot_emb[:, :, :input_length, :]
    # cache shape: num chunks of 2*num_block, num kv heads, c, head dim
    cache = reset_cache(
        len(models), num_key_value_heads, num_blocks_per_chunk, head_dim, max_cache_size
    )  # empty kv
    logits, cache = forward_and_save(
        models,
        hidden_state,
        cache,
        mask,
        pos_emb,
        model_input_dict,
        num_blocks_per_chunk,
        "prompt",
    )
    next_token_logits = logits[:, -1, :]  # logits for the last token position
    next_token = torch.argmax(next_token_logits, dim=-1)
    response_count = 0
    seq_length = input_length
    while True:
        curr_input_id = next_token[:, None].to(torch.int32)
        input_length = curr_input_id.shape[1]
        hidden_state = embedding_layer(curr_input_id)
        mask = generate_mask(max_cache_size, seq_length, input_length, input_length)
        pos_emb = master_rot_emb[:, :, seq_length : seq_length + input_length, :]
        logits, cache = forward_and_save(
            models,
            hidden_state,
            cache,
            mask,
            pos_emb,
            model_input_dict,
            num_blocks_per_chunk,
            f"response{response_count}",
        )
        next_token_logits = logits[:, -1, :]
        next_token = torch.argmax(next_token_logits, dim=-1)
        if next_token == eos_token_id_tensor:
            print(f"Found EOS on batch: {response_count}")
            break

        response_count += 1
        seq_length += input_length
        if response_count == response_cap:
            break

    return model_input_dict


def calibrate_model(model, cal_dataset, chunk_idx: str):
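    """Replay the recorded prompt/response batches of one chunk through the prepared
    (observer-instrumented) graph so the quantization observers see realistic values.
    """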
    with torch.no_grad():
        for inp in tqdm(cal_dataset, desc="Calibrating Model: "):
            # pass prompt and response
            for batch in tqdm(inp[chunk_idx].keys(), desc="Batch: "):
                if inp[chunk_idx][batch] is not None:
                    inputs_embeds = torch.tensor(inp[chunk_idx][batch]["hidden_state"])
                    mask = torch.tensor(inp[chunk_idx][batch]["mask"])
                    pos_emb = torch.tensor(inp[chunk_idx][batch]["pos_emb"])
                    cache = torch.tensor(inp[chunk_idx][batch]["cache"])
                    model(inputs_embeds, mask, pos_emb, *torch.split(cache, 1, dim=0))


def export_to_et_ir(
    output_folder,
    exp_name,
    model,
    precision,
    max_num_token,
    max_cache_size,
    chunk_idx,
    export_shapes,
    cal_dataset=None,
):
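    """Quantize one model chunk with the Neuropilot PT2E flow and export it as .pte files.

    Pipeline: export_for_training -> prepare_pt2e -> calibration (dataset or dummy) ->
    convert_pt2e -> fixed-shape torch.export per requested shape -> to_edge ->
    delegate to the Neuropilot backend -> to_executorch -> write the buffer to disk.
    """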
    print(f"Exporting Chunk {chunk_idx} to PTE")
    example_inputs, dynamic_shapes = model.get_example_inputs(
        max_num_token, max_cache_size, True
    )
    print("Getting pre autograd ATen Dialect Graph")
    pre_autograd_aten_dialect = torch.export.export_for_training(
        model, example_inputs, dynamic_shapes=dynamic_shapes
    ).module()  # NOTE: Will be replaced with export
    quantizer = NeuropilotQuantizer()
    quantizer.setup_precision(getattr(Precision, precision))
    prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer)
    # at this point quant min max are inf
    if cal_dataset is not None:
        calibrate_model(prepared_graph, cal_dataset, str(chunk_idx))
    else:
        prepared_graph(*example_inputs)  # dummy calibration
    converted_graph = convert_pt2e(prepared_graph, fold_quantize=False)

    print("Getting ATen Dialect Graph")
    # Fixed Shape Export Here
    for shape, ntok_and_cache in export_shapes.items():
        dest_path = get_dest_path(output_folder, exp_name, shape, chunk_idx)
        print(f"Exporting Shape {shape} to:\n{dest_path}")
        example_inputs = model.get_example_inputs(*ntok_and_cache)
        aten_dialect: exir.ExportedProgram = torch.export.export(
            converted_graph, example_inputs
        )

        print("Lowering to Edge Dialect Graph")
        edge_program: exir.EdgeProgramManager = exir.to_edge(
            aten_dialect,
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        del aten_dialect

        print("Delegating Edge Program to Neuropilot Backend")
        compile_spec = [
            CompileSpec("gno", struct.pack("3s", b"LTS")),
            CompileSpec("gno-exp", struct.pack("0s", b"")),
            CompileSpec("gno-non-4d-tiling", struct.pack("0s", b"")),
            CompileSpec("ImportForever", struct.pack("?", True)),
        ]
        partitioner = NeuropilotPartitioner(compile_spec)
        delegated_program = edge_program.to_backend(partitioner)
        print("Exported Delegated Program:")
        print(delegated_program.exported_program())
        del edge_program

        print("Transforming delegated program to executorch backend")
        executorch_program = delegated_program.to_executorch(
            config=exir.ExecutorchBackendConfig(
                memory_planning_pass=exir.passes.MemoryPlanningPass(
                    alloc_graph_input=False,
                    alloc_graph_output=False,
                ),
                extract_delegate_segments=True,
            )
        )

        print(f"ET Model Dest: {dest_path}\n")
        os.makedirs(dest_path.rsplit("/", 1)[0], exist_ok=True)
        with open(dest_path, "wb") as file:
            file.write(executorch_program.buffer)


def main():
    parser = get_argument_parser()
    args = parser.parse_args()
    args_sanity_checks(args)
    if args.dataset is None:
        exp_name = f"{get_exp_name(args.config)}_{args.precision}_dummy_cal_{args.num_chunks}_chunks"
    else:
        exp_name = (
            f"{get_exp_name(args.config)}_{args.precision}_{args.num_chunks}_chunks"
        )
    print_args(args, exp_name)

    config, weight_dir, tokenizer_class, chunk_class = resolve_model_classes(
        args.config
    )
    tokenizer = tokenizer_class.from_pretrained(weight_dir)
    if args.preformatter is not None:
        preformatter = Preformatter(args.preformatter)

    head_dim = int(config.hidden_size / config.num_attention_heads)

    # Evenly distribute the layers across chunks.
    num_blocks_per_chunk = [
        (config.num_hidden_layers // args.num_chunks)
        + (i < (config.num_hidden_layers % args.num_chunks))
        for i in range(args.num_chunks)
    ]
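    # e.g. 32 hidden layers with --num_chunks 4 gives num_blocks_per_chunk == [8, 8, 8, 8]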
    check_all_chunks_same_num_layer(num_blocks_per_chunk)  # noqa: F405

    output_folder = os.path.join("pte", exp_name)

    # Load all collected checkpoint files into one giant state_dict
    state_dict = load_checkpoints(weight_dir)

    dump_embedding_lut_for_cmdline(weight_dir, state_dict, config)

    export_shapes, max_num_token, max_cache_size = get_export_shapes(args.shapes)
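    # Shape strings follow the "<tokens>t<cache>c" pattern from -shapes, e.g. 32t512c
    # means 32 input tokens per forward pass with a 512-entry cache.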
    print(f"export shapes: {export_shapes}")
    print(f"Max Num Token: {max_num_token}")
    print(f"Max Cache Size: {max_cache_size}")

    if args.dataset is not None:
        embedding_layer = get_embedding_layer(config, weight_dir, state_dict)

    # Instantiate model chunks
    print("Instantiating submodels")
    models = []
    for chunk_idx, num_blocks in enumerate(num_blocks_per_chunk):
        chunk = chunk_class(
            config,
            num_blocks,
            chunk_idx=chunk_idx,
            dtype=torch.float32,
            include_tail=(chunk_idx == args.num_chunks - 1),
            jit_trace=True,
        )
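        # sum(num_blocks_per_chunk[:chunk_idx]) is the index of this chunk's first
        # decoder layer within the full model's stack of hidden layers.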
        chunk = chunk.load_weights(state_dict, sum(num_blocks_per_chunk[:chunk_idx]))
        models.append(chunk)

    cal_dataset = None
    if args.dataset is not None:
        cal_dataset = load_dataset("text", data_files=args.dataset, split="train")
        master_rot_emb = get_master_rot_emb(config, dtype=torch.float32)
        if args.preformatter is not None:
            cal_dataset = cal_dataset.map(
                apply_preformatter, fn_kwargs={"preformatter": preformatter}
            )
        cal_dataset = cal_dataset.map(
            tokenize_dataset, fn_kwargs={"tokenizer": tokenizer}
        )
        print("Preparing Model Calibration Inputs...")
        cal_dataset = cal_dataset.map(
            prepare_model_inputs,
            fn_kwargs={
                "models": models,
                "embedding_layer": embedding_layer,
                "master_rot_emb": master_rot_emb,
                "num_blocks_per_chunk": num_blocks_per_chunk,
                "num_key_value_heads": config.num_key_value_heads,
                "head_dim": head_dim,
                "max_cache_size": max_cache_size,
                "eos_token_id_tensor": torch.tensor(tokenizer.eos_token_id),
                "response_cap": args.response_cap,
            },
        )

    for chunk_idx, chunk in enumerate(models):
        export_to_et_ir(
            output_folder,
            exp_name,
            chunk,
            args.precision,
            max_num_token,
            max_cache_size,
            chunk_idx,
            export_shapes,
            cal_dataset,
        )


if __name__ == "__main__":
    main()