xref: /aosp_15_r20/external/executorch/examples/mediatek/aot_utils/llm_utils/utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
import os
import sys

if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())
import json
import math

import numpy as np
import torch
from safetensors.torch import load_file


# flake8: noqa: C901


def _get_embedding_weight(weight_dir, state_dict):
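    """Return the embedding weight tensor from `state_dict`, or load it from disk.

    If `state_dict` is None, the first checkpoint file found in `weight_dir`
    (pytorch_model*.bin or model*.safetensors) is loaded; if the embedding key is
    missing there, the last shard is tried as a fallback. The weight is located by
    searching for a key containing "embed_tokens.weight".
    """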
    try_last = False
    checkpoint_filename = None
    if state_dict is None:
        checkpoint_files = [
            os.path.join(weight_dir, f)
            for f in os.listdir(weight_dir)
            if (
                (f.startswith("pytorch_model") and f.endswith(".bin"))
                or (f.startswith("model") and f.endswith(".safetensors"))
            )
        ]

        for f in checkpoint_files:
            if "pytorch_model.bin" in f:
                checkpoint_filename = f
                break
            elif "pytorch_model-00001-of" in f:
                checkpoint_filename = f
                break
            elif "model.safetensors" in f:
                checkpoint_filename = f
                break
            elif "model-00001-of" in f:
                checkpoint_filename = f
                break
        if checkpoint_filename is None:
            raise FileNotFoundError(
                f"Unable to find the first checkpoint file in {weight_dir}! "
                "This folder must have either the file pytorch_model.bin or "
                "pytorch_model-00001-of-XXXXX.bin or "
                "model.safetensors or "
                "model-00001-of-XXXXX.safetensors."
            )

        if checkpoint_filename.endswith(".bin"):
            state_dict = torch.load(
                checkpoint_filename, map_location="cpu", weights_only=True
            )
        elif checkpoint_filename.endswith(".safetensors"):
            state_dict = load_file(checkpoint_filename, device="cpu")
        try_last = True

    state_dict_keys = list(state_dict.keys())

    expected_embedding_subkey = "embed_tokens.weight"

    embed_key = None
    for key in state_dict_keys:
        if expected_embedding_subkey in key:
            embed_key = key
            break
    if embed_key is None:
        if try_last:
            # checkpoint_filename is a full path, so compare against its basename
            if os.path.basename(checkpoint_filename) in (
                "pytorch_model.bin",
                "model.safetensors",
            ):
                print("state_dict keys:", state_dict_keys)
                raise KeyError(
                    f"Cannot find embedding layer weight inside {checkpoint_filename}. "
                    f"Please ensure embedding layer weight key contains {expected_embedding_subkey}"
                )
            else:
                checkpoint_filename = checkpoint_filename.replace(
                    "00001", checkpoint_filename.split("-")[-1].split(".")[0]
                )
                if checkpoint_filename.endswith(".bin"):
                    state_dict = torch.load(
                        checkpoint_filename, map_location="cpu", weights_only=True
                    )
                elif checkpoint_filename.endswith(".safetensors"):
                    state_dict = load_file(checkpoint_filename, device="cpu")
                state_dict_keys = list(state_dict.keys())
                for key in state_dict_keys:
                    if expected_embedding_subkey in key:
                        embed_key = key
                        break
                if embed_key is None:
                    print("state_dict keys:", state_dict_keys)
                    raise KeyError(
                        f"Cannot find embedding layer weight inside {checkpoint_filename}. "
                        f"Please ensure embedding layer weight key contains {expected_embedding_subkey}"
                    )
        else:
            print("state_dict keys:", state_dict_keys)
            raise KeyError(
                f"Cannot find embedding layer weight inside state dict. "
                f"Please ensure embedding layer weight key contains {expected_embedding_subkey}"
            )
    return state_dict[embed_key]


def chunk_and_tokenize_prompt(
    prompt,
    tokenizer,
    sub_responses,
    max_len,
    response_handler,
    preformatter=None,
    wikitext=False,
):
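    """Tokenize `prompt`, optionally splitting it into chunks of at most `max_len` tokens.

    With `max_len == 0` the whole prompt is tokenized at once. For wikitext input the
    prompt is assumed to be pre-tokenized and is simply sliced. Otherwise the prompt is
    split on newlines and accumulated line by line until the token count would reach
    `max_len`. Returns a tuple of (tokens for this chunk, remaining prompt or None).
    """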
    if max_len == 0:
        # No chunking
        if preformatter is not None:
            prompt_formatted = preformatter.generate_prompt(prompt, None)
        else:
            prompt_formatted = prompt

        if tokenizer is None:
            with response_handler:
                print("Prompt tokens:")
                print(prompt)
            prompt_tokens = prompt_formatted
        else:
            with response_handler:
                if preformatter is not None:
                    print(f"Prompt (with {preformatter.name} preformatter):")
                    print(prompt)
                else:
                    print("Prompt:")
                    print(prompt)
            prompt_tokens = tokenizer(prompt_formatted, return_tensors="np")[
                "input_ids"
            ].astype(np.int32)
        return prompt_tokens, None
    else:
        if wikitext:
            # Wikitext chunking, tokenized already
            if prompt.shape[1] < max_len:
                return prompt, None
            else:
                return prompt[:, :max_len], prompt[:, max_len:]

        else:
            # Oppo streaming prompt chunking
            sentences = prompt.split("\n")
            chunked = False
            curr_chunk = ""
            prev_chunk = None
            prev_chunk_tokens = None

            for sentence in sentences:
                if preformatter is not None:
                    if len(sub_responses) == 0:
                        curr_chunk_formatted = preformatter.generate_prompt(
                            curr_chunk, None
                        )
                    else:
                        curr_chunk_formatted = preformatter.generate_prompt(
                            curr_chunk, sub_responses[-1]
                        )
                else:
                    curr_chunk_formatted = curr_chunk
                if tokenizer is None:
                    curr_chunk_tokens = curr_chunk_formatted
                else:
                    curr_chunk_tokens = tokenizer(
                        curr_chunk_formatted, return_tensors="np"
                    )["input_ids"].astype(np.int32)

                if curr_chunk_tokens.shape[1] < max_len:
                    prev_chunk = curr_chunk
                    prev_chunk_tokens = curr_chunk_tokens
                    curr_chunk += sentence + "\n"
                else:
                    chunked = True
                    break

            if prev_chunk_tokens is None:
                raise RuntimeError(
                    f"Length of a single line ({curr_chunk_tokens.shape[1]}) is more than "
                    f"maximum length to chunk prompt to ({max_len})"
                )

            if chunked:
                with response_handler:
                    if preformatter is not None:
                        if len(sub_responses) == 0:
                            print(f"Prompt (with {preformatter.name} preformatter):")
                            print(prev_chunk)
                        else:
                            print(
                                f"Prompt (with {preformatter.name} preformatter with input):"
                            )
                            print(prev_chunk)
                    else:
                        print("Prompt:")
                        print(prev_chunk)
                return prev_chunk_tokens, prompt.split(prev_chunk)[1]
            else:
                with response_handler:
                    if preformatter is not None:
                        if len(sub_responses) == 0:
                            print(f"Prompt (with {preformatter.name} preformatter):")
                            print(curr_chunk)
                        else:
                            print(
                                f"Prompt (with {preformatter.name} preformatter with input):"
                            )
                            print(curr_chunk)
                    else:
                        print("Prompt:")
                        print(curr_chunk)
                return curr_chunk_tokens, None


def dump_embedding_lut_for_cmdline(weight_dir, state_dict, config):
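    """Export the embedding table as a flat fp32 binary next to the weights.

    The file is written to `<weight_dir>/embedding_<model_name>_fp32.bin` and the
    export is skipped if that file already exists.
    """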
    model_name = os.path.basename(weight_dir)
    output_path = os.path.join(weight_dir, f"embedding_{model_name}_fp32.bin")
    if not os.path.exists(output_path):
        embedding = (
            _get_embedding_weight(weight_dir, state_dict)
            .to(torch.float32)
            .cpu()
            .numpy()
        )

        with open(output_path, "wb") as f:
            f.write(embedding.flatten().tobytes())
        print(f"cmdline LUT embedding bin exported to {output_path}")


def generate_alibi(
    cache_size,
    valid_cache,
    input_length,
    valid_input,
    num_heads,
    batch_size=1,
    pytorch=False,
):
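    """Build an ALiBi bias tensor of shape (batch_size, num_heads, 1, cache_size + input_length).

    Per-head slopes follow the standard ALiBi geometric schedule; positions outside the
    valid cache and input ranges are zero-padded. Returns a torch tensor when `pytorch`
    is True, otherwise a numpy array.
    """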
    assert (
        valid_input <= input_length
    ), "valid_input must be less than or equal to input_length"
    assert (
        valid_cache <= cache_size
    ), "valid_cache must be less than or equal to cache_size"
    valid_seq_length = valid_cache + valid_input
    total_valid = np.ones((batch_size, valid_seq_length), dtype=np.int32)
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))
    powers = np.arange(1, 1 + closest_power_of_2, dtype=np.int32)
    slopes = np.power(base, powers)

    if closest_power_of_2 != num_heads:
        extra_base = 2 ** ((-(2 ** -(math.log2(2 * closest_power_of_2) - 3))))
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = np.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=np.int32)
        slopes = np.concatenate([slopes, np.power(extra_base, extra_powers)], axis=0)

    arange_tensor = ((np.cumsum(total_valid, axis=-1) - 1))[:, None, :]
    alibi = slopes[..., None] * arange_tensor
    alibi = alibi.reshape(batch_size, num_heads, 1, valid_seq_length)

    pre_pad_length = cache_size - valid_cache
    pre_pad_tensor = np.zeros(
        (batch_size, num_heads, 1, pre_pad_length), dtype=np.float32
    )
    post_pad_length = input_length - valid_input
    post_pad_tensor = np.zeros(
        (batch_size, num_heads, 1, post_pad_length), dtype=np.float32
    )
    alibi = np.concatenate([pre_pad_tensor, alibi, post_pad_tensor], axis=-1).astype(
        np.float32
    )

    if pytorch:
        return torch.from_numpy(alibi.copy())
    return alibi.copy()


def generate_mask(
    cache_size,
    valid_cache,
    input_length,
    valid_input,
    batch_size=1,
    mask_value=-100.0,
    pytorch=True,
):
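    """Build a causal attention mask of shape (batch_size, 1, input_length, cache_size + input_length).

    Valid positions are 0; masked positions (unused cache slots, future tokens, and
    padding) carry large negative values derived from `mask_value`. Returns a torch
    tensor when `pytorch` is True, otherwise a numpy array.
    """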
    assert (
        valid_cache <= cache_size
    ), "valid_cache must be less than or equal to cache_size"
    assert (
        valid_input <= input_length
    ), "valid_input must be less than or equal to input_length"
    # Cache mask portion
    valid = np.zeros((1, 1, 1, valid_cache + input_length), dtype=np.float32)
    cache_mask = np.full(
        (1, 1, 1, cache_size - valid_cache), mask_value, dtype=np.float32
    )
    cache_mask = np.concatenate((cache_mask, valid), axis=-1)
    cache_mask_final_shape = np.broadcast_to(
        cache_mask, (batch_size, 1, input_length, cache_size + input_length)
    )

    # Attention mask portion
    mask_cond = np.arange(valid_input)
    triangle = mask_cond >= (mask_cond + 1).reshape(valid_input, 1)
    small_attention_mask = triangle.astype(np.float32) * mask_value
    attention_mask = np.pad(
        small_attention_mask,
        (0, input_length - valid_input),
        "constant",
        constant_values=mask_value,
    )
    attention_mask_with_cache = np.concatenate(
        [np.zeros((input_length, cache_size), dtype=np.float32), attention_mask],
        axis=-1,
    )
    attention_mask_final_shape = np.broadcast_to(
        attention_mask_with_cache[None, None, :, :],
        (batch_size, 1, input_length, cache_size + input_length),
    )

    combined_mask = attention_mask_final_shape + cache_mask_final_shape

    if pytorch:
        return torch.from_numpy(combined_mask.copy())
    return combined_mask.copy()


def get_dest_path(output_folder, exp_name, shape, chunk_idx):
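    """Return `<output_folder>_<shape>/<exp_name>_<shape>_<chunk_idx>.pte`, creating the folder if needed."""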
    dest_folder_root = output_folder + f"_{shape}"
    os.makedirs(dest_folder_root, exist_ok=True)
    fname = f"{exp_name}_{shape}_{chunk_idx}.pte"
    dest_path = os.path.join(dest_folder_root, fname)

    return dest_path


def get_dirname(file_path):
    return os.path.dirname(file_path)


def get_exp_name(config_path):
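    """Derive an experiment name from a config path.

    The name is the weight folder name, optionally suffixed with the config filename
    stripped of "config" and ".json".
    """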
    weight_dir = get_dirname(config_path)
    weight_name = os.path.basename(weight_dir)
    config_name = os.path.basename(config_path).split(".json")[0].replace("config", "")
    if config_name == "":
        exp_name = f"{weight_name}"
    else:
        if config_name.startswith("_"):
            config_name = config_name[1:]
        exp_name = f"{weight_name}_{config_name}"
    return exp_name


def get_embedding_layer(config, weight_dir, state_dict):
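    """Return a torch.nn.Embedding initialized with the model's embedding weight cast to fp32."""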
    embedding_weight = _get_embedding_weight(weight_dir, state_dict)

    model = torch.nn.Embedding(config.vocab_size, config.hidden_size, -1)
    embed_state_dict = {}
    embed_state_dict["weight"] = embedding_weight.to(torch.float32)
    model.load_state_dict(embed_state_dict)
    return model


def get_export_shapes(shapes):
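    """Parse shape strings of the form "<num_token>t<cache_size>c".

    Returns a dict mapping each shape string to [num_token, cache_size], together with
    the maximum num_token and maximum cache_size across all shapes.
    """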
    export_shapes = {}
    max_num_token = 0
    max_cache_size = 0
    for shape in shapes:
        print(f"Shape: {shape}")
        num_token = int(shape.split("t")[0])
        cache_size = int(shape.split("t")[1].split("c")[0])
        export_shapes[shape] = [num_token, cache_size]
        max_num_token = num_token if num_token > max_num_token else max_num_token
        max_cache_size = cache_size if cache_size > max_cache_size else max_cache_size

    return export_shapes, max_num_token, max_cache_size


def get_master_rot_emb(config, dtype):
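    """Build the master rotary-embedding table for all positions.

    Returns cos and sin tables stacked along axis 1 with shape
    (1, 2, max_position_embeddings, head_dim), applying NTK scaling to the RoPE base
    when `config.ntk_scaling_factor != 1.0`. Returned as a torch tensor when `dtype`
    is a torch.dtype, otherwise as a numpy array of `dtype`.
    """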
    rot_dim = int(config.hidden_size / config.num_attention_heads)
    length = config.max_position_embeddings

    if config.ntk_scaling_factor != 1.0:
        base = (10000 * config.ntk_scaling_factor) ** (rot_dim / (rot_dim - 2))
    else:
        base = 10000

    inv_freq = 1.0 / (
        base ** (np.arange(0, rot_dim, 2, dtype=np.float32) / rot_dim)
    )  # (rot_dim/2)
    t = np.arange(length, dtype=np.float32)  # (len)
    freqs = np.einsum("i,j->ij", t, inv_freq)  # (len, rot_dim/2)
    emb = np.concatenate((freqs, freqs), axis=-1)  # (len, rot_dim)
    master_cos = np.cos(emb)[None, None, :, :]  # (1,1,len,rot_dim)
    master_sin = np.sin(emb)[None, None, :, :]  # (1,1,len,rot_dim)

    rot_emb = np.concatenate((master_cos, master_sin), axis=1)

    if isinstance(dtype, torch.dtype):
        return torch.from_numpy(rot_emb).to(dtype)
    else:
        return rot_emb.astype(dtype)


def get_normalized_config(config_filepath):
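    """Load a model config JSON and wrap it in the matching config class (llama only for now)."""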
    config_file = json.load(open(config_filepath, "r"))
    if config_file["model_type"] == "llama":
        from models.llm_models.configuration_llama import LlamaConfig as config_class
    config = config_class(**config_file, verbose=False)
    return config


def get_sorted_path_list(folder, ext=".", absolute=False):
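    """List `folder` sorted numerically by the index between the last "_" and `ext`.

    When `absolute` is True the returned entries are joined with `folder`.
    """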
    if absolute:
        sorted_list = sorted(
            os.listdir(folder), key=lambda f: int(f.rsplit("_", 1)[1].split(ext)[0])
        )
        return [os.path.join(folder, x) for x in sorted_list]
    else:
        return sorted(
            os.listdir(folder), key=lambda f: int(f.rsplit("_", 1)[1].split(ext)[0])
        )


def load_checkpoints(weight_dir):
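    """Load and merge all pytorch_model*.bin or model*.safetensors shards in `weight_dir`.

    Returns the merged state dict, or None (fake weights) if no checkpoint files are found.
    """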
    checkpoint_files = [
        os.path.join(weight_dir, f)
        for f in os.listdir(weight_dir)
        if (f.startswith("pytorch_model") and f.endswith(".bin"))
        or (f.startswith("model") and f.endswith(".safetensors"))
    ]
    if len(checkpoint_files) == 0:
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        print("!No model weight files found! Using fake weights!")
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        return None

    state_dict = {}
    print("Loading weights from disk")
    is_safetensors = checkpoint_files[0].endswith(".safetensors")
    for i in range(len(checkpoint_files)):
        if is_safetensors:
            state_dict = {**state_dict, **load_file(checkpoint_files[i], device="cpu")}
        else:
            state_dict = {
                **state_dict,
                **torch.load(
                    checkpoint_files[i], map_location="cpu", weights_only=True
                ),
            }

    return state_dict


def resolve_model_classes(
    config_filepath, bypass_tokenizer=False, response_handler=None
):
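    """Resolve the config, weight directory, tokenizer class, and model chunk class.

    Only the llama model type is currently handled. When `bypass_tokenizer` is True,
    the tokenizer class is omitted from the returned tuple.
    """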
    config_file = json.load(open(config_filepath, "r"))
    weight_dir = get_dirname(config_filepath)
    if config_file["model_type"] == "llama":
        from models.llm_models.configuration_llama import LlamaConfig as config_class
        from models.llm_models.modeling_llama import LlamaModelChunk as chunk_class
    config = config_class(**config_file, response_handler=response_handler)
    if bypass_tokenizer:
        return config, weight_dir, chunk_class
    else:
        if config.tokenizer == "default":
            if config_file["model_type"] == "llama":
                from aot_utils.llm_utils.tokenizers_.tokenization_llama import (
                    LlamaTokenizer as tokenizer_class,
                )
        else:
            if config.tokenizer == "llama":
                from aot_utils.llm_utils.tokenizers_.tokenization_llama import (
                    LlamaTokenizer as tokenizer_class,
                )
            elif config.tokenizer == "pretrained_fast":
                from aot_utils.llm_utils.tokenizers_.tokenization_utils_fast import (
                    PreTrainedTokenizerFast as tokenizer_class,
                )
        return config, weight_dir, tokenizer_class, chunk_class
495