import os
import sys

if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())
import json
import math

import numpy as np
import torch
from safetensors.torch import load_file


# flake8: noqa: C901


def _get_embedding_weight(weight_dir, state_dict):
    # Locate the token-embedding weight either in an already-loaded state dict
    # or by loading the first (and, if needed, the last) checkpoint shard.
    try_last = False
    checkpoint_filename = None
    if state_dict is None:
        checkpoint_files = [
            os.path.join(weight_dir, f)
            for f in os.listdir(weight_dir)
            if (
                (f.startswith("pytorch_model") and f.endswith(".bin"))
                or (f.startswith("model") and f.endswith(".safetensors"))
            )
        ]

        for f in checkpoint_files:
            if "pytorch_model.bin" in f:
                checkpoint_filename = f
                break
            elif "pytorch_model-00001-of" in f:
                checkpoint_filename = f
                break
            elif "model.safetensors" in f:
                checkpoint_filename = f
                break
            elif "model-00001-of" in f:
                checkpoint_filename = f
                break
        if checkpoint_filename is None:
            raise FileNotFoundError(
                f"Unable to find the first checkpoint file in {weight_dir}! "
                "This folder must have either pytorch_model.bin, "
                "pytorch_model-00001-of-XXXXX.bin, model.safetensors, or "
                "model-00001-of-XXXXX.safetensors."
            )

        if checkpoint_filename.endswith(".bin"):
            state_dict = torch.load(
                checkpoint_filename, map_location="cpu", weights_only=True
            )
        elif checkpoint_filename.endswith(".safetensors"):
            state_dict = load_file(checkpoint_filename, device="cpu")
        try_last = True

    state_dict_keys = list(state_dict.keys())

    expected_embedding_subkey = "embed_tokens.weight"

    embed_key = None
    for key in state_dict_keys:
        if expected_embedding_subkey in key:
            embed_key = key
            break
    if embed_key is None:
        if try_last:
            # Single-file checkpoints have no other shard to fall back to.
            if os.path.basename(checkpoint_filename) in (
                "pytorch_model.bin",
                "model.safetensors",
            ):
                print("state_dict keys:", state_dict_keys)
                raise KeyError(
                    f"Cannot find embedding layer weight inside {checkpoint_filename}. "
                    f"Please ensure embedding layer weight key contains {expected_embedding_subkey}"
                )
            else:
                # Sharded checkpoint: retry with the last shard before giving up.
                checkpoint_filename = checkpoint_filename.replace(
                    "00001", checkpoint_filename.split("-")[-1].split(".")[0]
                )
                if checkpoint_filename.endswith(".bin"):
                    state_dict = torch.load(
                        checkpoint_filename, map_location="cpu", weights_only=True
                    )
                elif checkpoint_filename.endswith(".safetensors"):
                    state_dict = load_file(checkpoint_filename, device="cpu")
                state_dict_keys = list(state_dict.keys())
                for key in state_dict_keys:
                    if expected_embedding_subkey in key:
                        embed_key = key
                        break
                if embed_key is None:
                    print("state_dict keys:", state_dict_keys)
                    raise KeyError(
                        f"Cannot find embedding layer weight inside {checkpoint_filename}. "
                        f"Please ensure embedding layer weight key contains {expected_embedding_subkey}"
                    )
        else:
            print("state_dict keys:", state_dict_keys)
            raise KeyError(
                f"Cannot find embedding layer weight inside state dict. "
                f"Please ensure embedding layer weight key contains {expected_embedding_subkey}"
            )
    return state_dict[embed_key]


def chunk_and_tokenize_prompt(
    prompt,
    tokenizer,
    sub_responses,
    max_len,
    response_handler,
    preformatter=None,
    wikitext=False,
):
    if max_len == 0:
        # No chunking
        if preformatter is not None:
            prompt_formatted = preformatter.generate_prompt(prompt, None)
        else:
            prompt_formatted = prompt

        if tokenizer is None:
            with response_handler:
                print("Prompt tokens:")
                print(prompt)
            prompt_tokens = prompt_formatted
        else:
            with response_handler:
                if preformatter is not None:
                    print(f"Prompt (with {preformatter.name} preformatter):")
                    print(prompt)
                else:
                    print("Prompt:")
                    print(prompt)
            prompt_tokens = tokenizer(prompt_formatted, return_tensors="np")[
                "input_ids"
            ].astype(np.int32)
        return prompt_tokens, None
    else:
        if wikitext:
            # Wikitext chunking, tokenized already
            if prompt.shape[1] < max_len:
                return prompt, None
            else:
                return prompt[:, :max_len], prompt[:, max_len:]

        else:
            # Oppo streaming prompt chunking
            sentences = prompt.split("\n")
            chunked = False
            curr_chunk = ""
            prev_chunk = None
            prev_chunk_tokens = None

            for sentence in sentences:
                if preformatter is not None:
                    if len(sub_responses) == 0:
                        curr_chunk_formatted = preformatter.generate_prompt(
                            curr_chunk, None
                        )
                    else:
                        curr_chunk_formatted = preformatter.generate_prompt(
                            curr_chunk, sub_responses[-1]
                        )
                else:
                    curr_chunk_formatted = curr_chunk
                if tokenizer is None:
                    curr_chunk_tokens = curr_chunk_formatted
                else:
                    curr_chunk_tokens = tokenizer(
                        curr_chunk_formatted, return_tensors="np"
                    )["input_ids"].astype(np.int32)

                if curr_chunk_tokens.shape[1] < max_len:
                    prev_chunk = curr_chunk
                    prev_chunk_tokens = curr_chunk_tokens
                    curr_chunk += sentence + "\n"
                else:
                    chunked = True
                    break

            if prev_chunk_tokens is None:
                raise RuntimeError(
                    f"Length of a single line ({curr_chunk_tokens.shape[1]}) is more "
                    f"than the maximum length to chunk the prompt to ({max_len})"
                )

            if chunked:
                with response_handler:
                    if preformatter is not None:
                        if len(sub_responses) == 0:
                            print(f"Prompt (with {preformatter.name} preformatter):")
                            print(prev_chunk)
                        else:
                            print(
                                f"Prompt (with {preformatter.name} preformatter with input):"
                            )
                            print(prev_chunk)
                    else:
                        print("Prompt:")
                        print(prev_chunk)
                return prev_chunk_tokens, prompt.split(prev_chunk)[1]
            else:
                with response_handler:
                    if preformatter is not None:
                        if len(sub_responses) == 0:
                            print(f"Prompt (with {preformatter.name} preformatter):")
                            print(curr_chunk)
                        else:
                            print(
                                f"Prompt (with {preformatter.name} preformatter with input):"
                            )
                            print(curr_chunk)
                    else:
                        print("Prompt:")
                        print(curr_chunk)
                return curr_chunk_tokens, None


def dump_embedding_lut_for_cmdline(weight_dir, state_dict, config):
    # Export the embedding table as a flat fp32 binary next to the weights.
    model_name = os.path.basename(weight_dir)
    output_path = os.path.join(weight_dir, f"embedding_{model_name}_fp32.bin")
    if not os.path.exists(output_path):
        embedding = (
            _get_embedding_weight(weight_dir, state_dict)
            .to(torch.float32)
            .cpu()
            .numpy()
        )

        with open(output_path, "wb") as f:
            f.write(embedding.flatten().tobytes())
        print(f"cmdline LUT embedding bin exported to {output_path}")
def generate_alibi(
    cache_size,
    valid_cache,
    input_length,
    valid_input,
    num_heads,
    batch_size=1,
    pytorch=False,
):
    # Build per-head ALiBi position biases, zero-padded to the full
    # (cache_size + input_length) layout.
    assert (
        valid_input <= input_length
    ), "valid_input must be less than or equal to input_length"
    assert (
        valid_cache <= cache_size
    ), "valid_cache must be less than or equal to cache_size"
    valid_seq_length = valid_cache + valid_input
    total_valid = np.ones((batch_size, valid_seq_length), dtype=np.int32)
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))
    powers = np.arange(1, 1 + closest_power_of_2, dtype=np.int32)
    slopes = np.power(base, powers)

    if closest_power_of_2 != num_heads:
        extra_base = 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3)))
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = np.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=np.int32)
        slopes = np.concatenate([slopes, np.power(extra_base, extra_powers)], axis=0)

    arange_tensor = (np.cumsum(total_valid, axis=-1) - 1)[:, None, :]
    alibi = slopes[..., None] * arange_tensor
    alibi = alibi.reshape(batch_size, num_heads, 1, valid_seq_length)

    pre_pad_length = cache_size - valid_cache
    pre_pad_tensor = np.zeros(
        (batch_size, num_heads, 1, pre_pad_length), dtype=np.float32
    )
    post_pad_length = input_length - valid_input
    post_pad_tensor = np.zeros(
        (batch_size, num_heads, 1, post_pad_length), dtype=np.float32
    )
    alibi = np.concatenate([pre_pad_tensor, alibi, post_pad_tensor], axis=-1).astype(
        np.float32
    )

    if pytorch:
        return torch.from_numpy(alibi.copy())
    return alibi.copy()


def generate_mask(
    cache_size,
    valid_cache,
    input_length,
    valid_input,
    batch_size=1,
    mask_value=-100.0,
    pytorch=True,
):
    # Build a causal attention mask over [cache | input], marking unused cache
    # slots, future positions, and padding with mask_value.
    assert (
        valid_cache <= cache_size
    ), "valid_cache must be less than or equal to cache_size"
    assert (
        valid_input <= input_length
    ), "valid_input must be less than or equal to input_length"
    # Cache mask portion
    valid = np.zeros((1, 1, 1, valid_cache + input_length), dtype=np.float32)
    cache_mask = np.full(
        (1, 1, 1, cache_size - valid_cache), mask_value, dtype=np.float32
    )
    cache_mask = np.concatenate((cache_mask, valid), axis=-1)
    cache_mask_final_shape = np.broadcast_to(
        cache_mask, (batch_size, 1, input_length, cache_size + input_length)
    )

    # Attention mask portion
    mask_cond = np.arange(valid_input)
    triangle = mask_cond >= (mask_cond + 1).reshape(valid_input, 1)
    small_attention_mask = triangle.astype(np.float32) * mask_value
    attention_mask = np.pad(
        small_attention_mask,
        (0, input_length - valid_input),
        "constant",
        constant_values=mask_value,
    )
    attention_mask_with_cache = np.concatenate(
        [np.zeros((input_length, cache_size), dtype=np.float32), attention_mask],
        axis=-1,
    )
    attention_mask_final_shape = np.broadcast_to(
        attention_mask_with_cache[None, None, :, :],
        (batch_size, 1, input_length, cache_size + input_length),
    )

    combined_mask = attention_mask_final_shape + cache_mask_final_shape

    if pytorch:
        return torch.from_numpy(combined_mask.copy())
    return combined_mask.copy()
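# Illustrative usage sketch for the two helpers above: both lay their output out
# as [cache | input] along the last axis, so a 512-slot cache with a 32-token
# input gives a last dimension of 544. The sizes below are arbitrary example
# values, not defaults from this module.
def _example_mask_and_alibi_shapes():
    mask = generate_mask(cache_size=512, valid_cache=0, input_length=32, valid_input=4)
    assert tuple(mask.shape) == (1, 1, 32, 544)  # torch tensor by default

    alibi = generate_alibi(
        cache_size=512, valid_cache=0, input_length=32, valid_input=4, num_heads=8
    )
    assert alibi.shape == (1, 8, 1, 544)  # numpy array by default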
def get_dest_path(output_folder, exp_name, shape, chunk_idx):
    # Destination .pte path for one exported chunk at a given shape.
    dest_folder_root = output_folder + f"_{shape}"
    os.makedirs(dest_folder_root, exist_ok=True)
    fname = f"{exp_name}_{shape}_{chunk_idx}.pte"
    dest_path = os.path.join(dest_folder_root, fname)

    return dest_path


def get_dirname(file_path):
    return os.path.dirname(file_path)


def get_exp_name(config_path):
    # Experiment name derived from the weight folder and the config filename.
    weight_dir = get_dirname(config_path)
    weight_name = os.path.basename(weight_dir)
    config_name = os.path.basename(config_path).split(".json")[0].replace("config", "")
    if config_name == "":
        exp_name = f"{weight_name}"
    else:
        if config_name.startswith("_"):
            config_name = config_name[1:]
        exp_name = f"{weight_name}_{config_name}"
    return exp_name


def get_embedding_layer(config, weight_dir, state_dict):
    embedding_weight = _get_embedding_weight(weight_dir, state_dict)

    model = torch.nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=-1)
    embed_state_dict = {}
    embed_state_dict["weight"] = embedding_weight.to(torch.float32)
    model.load_state_dict(embed_state_dict)
    return model


def get_export_shapes(shapes):
    # Parse "<num_token>t<cache_size>c" strings, e.g. "32t512c" -> [32, 512].
    export_shapes = {}
    max_num_token = 0
    max_cache_size = 0
    for shape in shapes:
        print(f"Shape: {shape}")
        num_token = int(shape.split("t")[0])
        cache_size = int(shape.split("t")[1].split("c")[0])
        export_shapes[shape] = [num_token, cache_size]
        max_num_token = max(max_num_token, num_token)
        max_cache_size = max(max_cache_size, cache_size)

    return export_shapes, max_num_token, max_cache_size


def get_master_rot_emb(config, dtype):
    # Precompute rotary cos/sin tables for all positions, stacked on axis 1.
    rot_dim = int(config.hidden_size / config.num_attention_heads)
    length = config.max_position_embeddings

    if config.ntk_scaling_factor != 1.0:
        base = (10000 * config.ntk_scaling_factor) ** (rot_dim / (rot_dim - 2))
    else:
        base = 10000

    inv_freq = 1.0 / (
        base ** (np.arange(0, rot_dim, 2, dtype=np.float32) / rot_dim)
    )  # (rot_dim/2)
    t = np.arange(length, dtype=np.float32)  # (len)
    freqs = np.einsum("i,j->ij", t, inv_freq)  # (len, rot_dim/2)
    emb = np.concatenate((freqs, freqs), axis=-1)  # (len, rot_dim)
    master_cos = np.cos(emb)[None, None, :, :]  # (1, 1, len, rot_dim)
    master_sin = np.sin(emb)[None, None, :, :]  # (1, 1, len, rot_dim)

    rot_emb = np.concatenate((master_cos, master_sin), axis=1)  # (1, 2, len, rot_dim)

    if isinstance(dtype, torch.dtype):
        return torch.from_numpy(rot_emb).to(dtype)
    else:
        return rot_emb.astype(dtype)


def get_normalized_config(config_filepath):
    with open(config_filepath, "r") as f:
        config_file = json.load(f)
    if config_file["model_type"] == "llama":
        from models.llm_models.configuration_llama import LlamaConfig as config_class
    config = config_class(**config_file, verbose=False)
    return config


def get_sorted_path_list(folder, ext=".", absolute=False):
    # Sort files by the integer between the last "_" and the extension,
    # e.g. "model_3.pte" sorts on 3.
    if absolute:
        sorted_list = sorted(
            os.listdir(folder), key=lambda f: int(f.rsplit("_", 1)[1].split(ext)[0])
        )
        return [os.path.join(folder, x) for x in sorted_list]
    else:
        return sorted(
            os.listdir(folder), key=lambda f: int(f.rsplit("_", 1)[1].split(ext)[0])
        )
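# Illustrative usage sketch: export shape strings follow the
# "<num_token>t<cache_size>c" convention parsed by get_export_shapes above.
# The shape list here is an arbitrary example.
def _example_export_shapes():
    export_shapes, max_num_token, max_cache_size = get_export_shapes(
        ["1t128c", "32t512c"]
    )
    assert export_shapes == {"1t128c": [1, 128], "32t512c": [32, 512]}
    assert (max_num_token, max_cache_size) == (32, 512)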
def load_checkpoints(weight_dir):
    # Merge all checkpoint shards in weight_dir into a single state dict.
    # Returns None when no weight files are present.
    checkpoint_files = [
        os.path.join(weight_dir, f)
        for f in os.listdir(weight_dir)
        if (f.startswith("pytorch_model") and f.endswith(".bin"))
        or (f.startswith("model") and f.endswith(".safetensors"))
    ]
    if len(checkpoint_files) == 0:
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        print("!No model weight files found! Using fake weights!")
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        return None

    state_dict = {}
    print("Loading weights from disk")
    is_safetensors = checkpoint_files[0].endswith(".safetensors")
    for checkpoint_file in checkpoint_files:
        if is_safetensors:
            state_dict.update(load_file(checkpoint_file, device="cpu"))
        else:
            state_dict.update(
                torch.load(checkpoint_file, map_location="cpu", weights_only=True)
            )

    return state_dict


def resolve_model_classes(
    config_filepath, bypass_tokenizer=False, response_handler=None
):
    # Resolve the config, weight directory, tokenizer class, and model-chunk
    # class for the model type named in the config file.
    with open(config_filepath, "r") as f:
        config_file = json.load(f)
    weight_dir = get_dirname(config_filepath)
    if config_file["model_type"] == "llama":
        from models.llm_models.configuration_llama import LlamaConfig as config_class
        from models.llm_models.modeling_llama import LlamaModelChunk as chunk_class

    config = config_class(**config_file, response_handler=response_handler)
    if bypass_tokenizer:
        return config, weight_dir, chunk_class
    else:
        if config.tokenizer == "default":
            if config_file["model_type"] == "llama":
                from aot_utils.llm_utils.tokenizers_.tokenization_llama import (
                    LlamaTokenizer as tokenizer_class,
                )
        else:
            if config.tokenizer == "llama":
                from aot_utils.llm_utils.tokenizers_.tokenization_llama import (
                    LlamaTokenizer as tokenizer_class,
                )
            elif config.tokenizer == "pretrained_fast":
                from aot_utils.llm_utils.tokenizers_.tokenization_utils_fast import (
                    PreTrainedTokenizerFast as tokenizer_class,
                )
        return config, weight_dir, tokenizer_class, chunk_class
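# Illustrative usage sketch: one way to wire these helpers together when
# preparing a model for export. The config path below is hypothetical; the
# config JSON is assumed to sit in the same folder as the model weights.
def _example_export_flow(config_path="path/to/llama_weights/config.json"):
    config, weight_dir, tokenizer_class, chunk_class = resolve_model_classes(
        config_path
    )
    state_dict = load_checkpoints(weight_dir)  # merged state dict of all shards
    embedding_layer = get_embedding_layer(config, weight_dir, state_dict)
    return config, embedding_layer, tokenizer_class, chunk_class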