1"""Common backbone across multiple models""" 2 3import math 4 5import numpy as np 6import torch 7from models.llm_models.configuration_base import BaseConfig 8from models.llm_models.modeling_base import BaseModelChunk 9from torch import nn 10from torch.export import Dim 11 12torch.manual_seed(42) 13np.random.seed(42) 14 15 16# flake8: noqa: C901 17 18 19class RMSNorm(nn.Module): 20 def __init__(self, hidden_size, eps=1e-6): 21 super().__init__() 22 self.weight = nn.Parameter(torch.ones(hidden_size)) 23 self.variance_epsilon = eps 24 25 def forward(self, hidden_states): 26 variance = hidden_states.pow(2).mean(-1, keepdim=True) 27 hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 28 return self.weight * hidden_states 29 30 31class Gelu(nn.Module): 32 def __init__(self): 33 super().__init__() 34 35 def forward(self, x: torch.Tensor) -> torch.Tensor: 36 return x * torch.sigmoid(1.702 * x) 37 38 39class MLP(nn.Module): 40 def __init__(self, config: BaseConfig): 41 super().__init__() 42 hidden_size = config.hidden_size 43 intermediate_size = config.intermediate_size 44 45 self.gate_proj = nn.Linear(hidden_size, intermediate_size) 46 self.down_proj = nn.Linear(intermediate_size, hidden_size) 47 self.up_proj = nn.Linear(hidden_size, intermediate_size) 48 49 def forward(self, x): 50 gate = self.gate_proj(x) 51 up = self.up_proj(x) 52 pre_down = gate * torch.sigmoid(gate) * up 53 down = self.down_proj(pre_down) 54 55 return down 56 57 58class Attention(nn.Module): 59 def __init__(self, config: BaseConfig): 60 super().__init__() 61 self.config = config 62 self.hidden_size = config.hidden_size 63 self.num_heads = config.num_attention_heads 64 self.num_key_value_heads = config.num_key_value_heads 65 self.num_key_value_groups = self.num_heads // self.num_key_value_heads 66 self.head_dim = self.hidden_size // self.num_heads 67 self.attn_scale = math.sqrt(self.head_dim) 68 69 if config.combine_qkv: 70 self.qkv_proj = nn.Linear( 71 self.hidden_size, 72 (2 * self.num_key_value_heads * self.head_dim) + self.hidden_size, 73 ) 74 else: 75 self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim) 76 self.k_proj = nn.Linear( 77 self.hidden_size, self.num_key_value_heads * self.head_dim 78 ) 79 self.v_proj = nn.Linear( 80 self.hidden_size, self.num_key_value_heads * self.head_dim 81 ) 82 self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size) 83 84 def apply_rotary_pos_emb_mtk(self, q, k, cos, sin): 85 q1 = q[..., : q.shape[-1] // 2] 86 q2 = q[..., q.shape[-1] // 2 :] 87 q_rotated = torch.cat((-q2, q1), dim=-1) 88 k1 = k[..., : k.shape[-1] // 2] 89 k2 = k[..., k.shape[-1] // 2 :] 90 k_rotated = torch.cat((-k2, k1), dim=-1) 91 92 q_embed = q * cos + q_rotated * sin 93 k_embed = k * cos + k_rotated * sin 94 return q_embed, k_embed 95 96 def repeat_kv(self, hidden_states, batch, q_len, n_rep): 97 if isinstance(hidden_states, list): 98 output = [] 99 for hs in hidden_states: 100 output.append( 101 hs.repeat(1, 1, n_rep, 1).view(batch, 1, q_len, self.head_dim) 102 ) 103 return output 104 else: 105 hidden_states = hidden_states.repeat(1, 1, n_rep, 1) 106 return hidden_states.view(batch, self.num_heads, q_len, self.head_dim) 107 108 def forward( 109 self, 110 hidden_states, # (b, t, 4096) 111 mask, # (b, 1, t, c+t) 112 pos_emb, # (b, 2, t, head dim) 113 past_key, # (b, num kv heads, c, head dim) 114 past_value, # (b, num kv heads, c, head dim) 115 ): 116 bsz, q_len, _ = hidden_states.size() 117 c_len = past_key.size()[2] 118 119 if self.config.combine_qkv: 
            proj = self.qkv_proj(hidden_states)
            query_states = (
                proj[:, :, : self.config.hidden_size]
                .view(bsz, q_len, self.num_heads, self.head_dim)
                .transpose(1, 2)
            )
            key_states = (
                proj[
                    :,
                    :,
                    self.config.hidden_size : self.config.hidden_size
                    + self.num_key_value_heads * self.head_dim,
                ]
                .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
                .transpose(1, 2)
            )
            value_states = (
                proj[
                    :,
                    :,
                    self.config.hidden_size
                    + self.num_key_value_heads * self.head_dim :,
                ]
                .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
                .transpose(1, 2)
            )
        else:
            query_states = (
                self.q_proj(hidden_states)
                .view(bsz, q_len, self.num_heads, self.head_dim)
                .transpose(1, 2)
            )
            key_states = (
                self.k_proj(hidden_states)
                .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
                .transpose(1, 2)
            )
            value_states = (
                self.v_proj(hidden_states)
                .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
                .transpose(1, 2)
            )

        if self.config.position_embedding == "rope":
            cos, sin = torch.split(pos_emb, 1, dim=1)
            query_states, key_states = self.apply_rotary_pos_emb_mtk(
                query_states, key_states, cos, sin
            )

        key_states = torch.cat([past_key, key_states], dim=2)
        value_states = torch.cat([past_value, value_states], dim=2)
        key_states_out = key_states
        value_states_out = value_states
        if self.num_key_value_groups > 1:
            key_states = self.repeat_kv(
                key_states, bsz, q_len + c_len, self.num_key_value_groups
            )
            value_states = self.repeat_kv(
                value_states, bsz, q_len + c_len, self.num_key_value_groups
            )
        attn_weights = (
            torch.matmul(query_states, key_states.transpose(2, 3)) / self.attn_scale
        )
        attn_weights = attn_weights + mask
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        # Rolling KV cache: drop the oldest q_len entries so the cache length stays c.
        key_states_out = key_states_out[:, :, q_len:, :]
        value_states_out = value_states_out[:, :, q_len:, :]

        return attn_output, key_states_out, value_states_out


class DecoderLayer(nn.Module):
    def __init__(
        self,
        config: BaseConfig,
        return_attn=False,
        jit_trace=False,
        attn_class=Attention,
        mlp_class=MLP,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.return_attn = return_attn
        self.jit_trace = jit_trace
        self.self_attn = attn_class(config)
        self.mlp = mlp_class(config)
        if config.norm == "RMSNorm":
            self.input_norm = RMSNorm(config.hidden_size, eps=config.norm_eps).float()
            self.post_attention_norm = RMSNorm(
                config.hidden_size, eps=config.norm_eps
            ).float()
        else:
            self.input_norm = nn.LayerNorm(
                config.hidden_size, eps=config.norm_eps
            ).float()
            self.post_attention_norm = nn.LayerNorm(
                config.hidden_size, eps=config.norm_eps
            ).float()

    def forward(
        self,
        hidden_states,  # (b, t, hidden_dim)
        mask,  # (b, 1, t, c+t)
        pos_emb,  # (b, 2, t, head_dim)
        past_key,  # (b, num_kv_head, c, head_dim)
        past_value,  # (b, num_kv_head, c, head_dim)
    ):
        residual = hidden_states
        if self.jit_trace:
            hidden_states = self.input_norm(hidden_states)
        else:
            dtype = hidden_states.dtype
            hidden_states = self.input_norm(hidden_states.to(torch.float32)).to(dtype)

        layer_device = hidden_states.device

        # Self Attention
        attn_output, present_key, present_value = self.self_attn(
            hidden_states=hidden_states.to(layer_device),
            mask=mask.to(layer_device),
            pos_emb=pos_emb.to(layer_device),
            past_key=past_key.to(layer_device),
            past_value=past_value.to(layer_device),
        )
        hidden_states = residual.to(layer_device) + attn_output

        # Fully Connected
        residual = hidden_states
        if self.jit_trace:
            hidden_states = self.post_attention_norm(hidden_states)
        else:
            dtype = hidden_states.dtype
            hidden_states = self.post_attention_norm(
                hidden_states.to(torch.float32)
            ).to(dtype)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        if self.return_attn:
            return hidden_states, present_key, present_value, attn_output
        return hidden_states, present_key, present_value


class ModelChunk(BaseModelChunk):
    def __init__(
        self,
        config: BaseConfig,
        num_blocks,
        chunk_idx,
        dtype=torch.float32,
        include_tail=False,
        return_attn=False,
        jit_trace=False,
        decoder_class=DecoderLayer,
    ):
        super().__init__(
            config, num_blocks, chunk_idx, dtype, include_tail, return_attn, jit_trace
        )
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.layers = nn.ModuleList(
            [
                decoder_class(config, return_attn=return_attn, jit_trace=jit_trace)
                for _ in range(num_blocks)
            ]
        )

        if self.config.use_stable_embedding and self.chunk_idx == 0:
            self.embed_layer_norm = nn.LayerNorm(config.hidden_size).float()

        if self.include_tail:
            if config.norm == "RMSNorm":
                self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps).float()
            else:
                self.norm = nn.LayerNorm(
                    config.hidden_size, eps=config.norm_eps
                ).float()
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, inputs_embeds, mask, pos_emb, *cache):
        if not self.jit_trace:
            assert (
                len(cache) == 2 * self.num_blocks
            ), f"split cache wrong number of input caches: {len(cache)} != 2*{self.num_blocks}"
            assert (
                cache[0].shape[0] == inputs_embeds.size()[0]
            ), f"split cache batch size mismatch: {cache[0].shape[0]} != {inputs_embeds.size()[0]}"

        inputs_embeds = inputs_embeds.to(self.device_list[0])

        if self.config.use_stable_embedding and self.chunk_idx == 0:
            if self.jit_trace:
                inputs_embeds = self.embed_layer_norm(inputs_embeds)
            else:
                inputs_embeds = self.embed_layer_norm(
                    inputs_embeds.to(torch.float32)
                ).to(self.dtype)

        hidden_states = inputs_embeds

        next_key_cache = []
        next_value_cache = []
        if self.return_attn:
            attn_outputs = []

        # decoder layers
        for idx, decoder_layer in enumerate(self.layers):
            decoder_outputs = decoder_layer(
                hidden_states.to(self.device_list[idx]),
                mask=mask.to(self.device_list[idx]),
                pos_emb=pos_emb.to(self.device_list[idx]),
                past_key=cache[idx].to(self.device_list[idx]),
                past_value=cache[self.num_blocks + idx].to(self.device_list[idx]),
            )
            hidden_states = decoder_outputs[0]
            next_key_cache.append(decoder_outputs[1].to(inputs_embeds.device))
            next_value_cache.append(decoder_outputs[2].to(inputs_embeds.device))
            if self.return_attn:
                attn_outputs.append(decoder_outputs[3].to(inputs_embeds.device))

        if self.include_tail:
            if self.jit_trace:
                hidden_states = self.norm(hidden_states)
            else:
                hidden_states = self.norm(hidden_states.to(torch.float32)).to(
                    self.dtype
                )
            hidden_states = self.lm_head(hidden_states)

        if self.return_attn:
            return hidden_states, *next_key_cache, *next_value_cache, *attn_outputs
        return hidden_states, *next_key_cache, *next_value_cache

    def load_weights(self, state_dict, state_dict_start_idx):
        if state_dict is None:
            fake_weights = True
        else:
            expected_subkey = f"layers.{state_dict_start_idx}.self_attn.o_proj.weight"
            state_dict_keys = list(state_dict.keys())
            temp_key = None
            input_norm_subkey = None
            post_attention_norm_subkey = None
            for key in state_dict_keys:
                if expected_subkey in key:
                    temp_key = key
                if (
                    f"layers.{state_dict_start_idx}" in key
                    and "norm" in key
                    and "input" in key
                ):
                    input_norm_subkey = key.split(".")[-2]
                if (
                    f"layers.{state_dict_start_idx}" in key
                    and "norm" in key
                    and "post_attention" in key
                ):
                    post_attention_norm_subkey = key.split(".")[-2]
            if temp_key is None:
                raise KeyError(
                    f"Cannot find layer {state_dict_start_idx}'s o_proj weight inside state_dict. "
                    f"Please ensure o_proj weight key contains: {expected_subkey}"
                )
            if input_norm_subkey is None:
                raise KeyError(
                    f"Cannot find layer {state_dict_start_idx}'s input norm weight inside state_dict. "
                    f"Please ensure input norm weight key contains: layers.{state_dict_start_idx}, norm, and input inside"
                    " the key string."
                )
            if post_attention_norm_subkey is None:
                raise KeyError(
                    f"Cannot find layer {state_dict_start_idx}'s post attention norm weight inside state_dict."
                    f" Please ensure post attention norm weight key contains: layers.{state_dict_start_idx}, norm, and "
                    "post_attention inside the key string."
                )
            prefix = temp_key.split(expected_subkey)[0]
            fake_weights = False

        outer_layer_idx = state_dict_start_idx
        self.device_list = []
        if self.config.use_stable_embedding and self.chunk_idx == 0:
            if fake_weights:
                temp_state_dict = {
                    "embed_layer_norm.weight": torch.rand(
                        self.config.hidden_size, dtype=torch.float32
                    ),
                    "embed_layer_norm.bias": torch.zeros(
                        self.config.hidden_size, dtype=torch.float32
                    ),
                }
            else:
                temp_state_dict = {
                    "embed_layer_norm.weight": state_dict.pop(
                        f"{prefix}embed_layer_norm.weight"
                    ).to(torch.float32),
                    "embed_layer_norm.bias": state_dict.pop(
                        f"{prefix}embed_layer_norm.bias",
                        torch.zeros(self.config.hidden_size, dtype=self.dtype),
                    ).to(torch.float32),
                }
        else:
            temp_state_dict = {}

        for inner_layer_idx in range(self.num_blocks):
            if fake_weights:
                if self.config.combine_qkv:
                    # Match the combined QKV projection's declared output size
                    # (q + grouped k/v), so fake weights also load for GQA models.
                    temp_state_dict[
                        f"layers.{inner_layer_idx}.self_attn.qkv_proj.weight"
                    ] = torch.rand(
                        (2 * self.config.num_key_value_heads * self.head_dim)
                        + self.config.hidden_size,
                        self.config.hidden_size,
                        dtype=self.dtype,
                    )
                    temp_state_dict[
                        f"layers.{inner_layer_idx}.self_attn.qkv_proj.bias"
                    ] = torch.zeros(
                        (2 * self.config.num_key_value_heads * self.head_dim)
                        + self.config.hidden_size,
                        dtype=self.dtype,
                    )
                else:
                    temp_state_dict = {
                        **temp_state_dict,
                        **{
                            f"layers.{inner_layer_idx}.self_attn.q_proj.weight": torch.rand(
                                self.config.hidden_size,
                                self.config.hidden_size,
                                dtype=self.dtype,
                            ),
                            f"layers.{inner_layer_idx}.self_attn.k_proj.weight": torch.rand(
                                self.config.num_key_value_heads * self.head_dim,
                                self.config.hidden_size,
                                dtype=self.dtype,
                            ),
                            f"layers.{inner_layer_idx}.self_attn.v_proj.weight": torch.rand(
                                self.config.num_key_value_heads * self.head_dim,
                                self.config.hidden_size,
                                dtype=self.dtype,
                            ),
                            f"layers.{inner_layer_idx}.self_attn.q_proj.bias": torch.zeros(
                                self.config.hidden_size, dtype=self.dtype
                            ),
                            f"layers.{inner_layer_idx}.self_attn.k_proj.bias": torch.zeros(
                                self.config.num_key_value_heads * self.head_dim,
                                dtype=self.dtype,
                            ),
                            f"layers.{inner_layer_idx}.self_attn.v_proj.bias": torch.zeros(
                                self.config.num_key_value_heads * self.head_dim,
                                dtype=self.dtype,
                            ),
                        },
                    }
                temp_state_dict = {
                    **temp_state_dict,
                    **{
                        f"layers.{inner_layer_idx}.self_attn.o_proj.weight": torch.rand(
                            self.config.hidden_size,
                            self.config.hidden_size,
                            dtype=self.dtype,
                        ),
                        f"layers.{inner_layer_idx}.mlp.gate_proj.weight": torch.rand(
                            self.config.intermediate_size,
                            self.config.hidden_size,
                            dtype=self.dtype,
                        ),
                        f"layers.{inner_layer_idx}.mlp.down_proj.weight": torch.rand(
                            self.config.hidden_size,
                            self.config.intermediate_size,
                            dtype=self.dtype,
                        ),
                        f"layers.{inner_layer_idx}.mlp.up_proj.weight": torch.rand(
                            self.config.intermediate_size,
                            self.config.hidden_size,
                            dtype=self.dtype,
                        ),
                        f"layers.{inner_layer_idx}.input_norm.weight": torch.rand(
                            self.config.hidden_size, dtype=torch.float32
                        ),
                        f"layers.{inner_layer_idx}.post_attention_norm.weight": torch.rand(
                            self.config.hidden_size, dtype=torch.float32
                        ),
                        f"layers.{inner_layer_idx}.self_attn.o_proj.bias": torch.zeros(
                            self.config.hidden_size, dtype=self.dtype
                        ),
                        f"layers.{inner_layer_idx}.mlp.gate_proj.bias": torch.zeros(
                            self.config.intermediate_size, dtype=self.dtype
                        ),
f"layers.{inner_layer_idx}.mlp.down_proj.bias": torch.zeros( 512 self.config.hidden_size, dtype=self.dtype 513 ), 514 f"layers.{inner_layer_idx}.mlp.up_proj.bias": torch.zeros( 515 self.config.intermediate_size, dtype=self.dtype 516 ), 517 }, 518 } 519 520 if self.config.norm == "LayerNorm": 521 temp_state_dict = { 522 **temp_state_dict, 523 **{ 524 f"layers.{inner_layer_idx}.input_norm.bias": torch.zeros( 525 self.config.hidden_size, dtype=torch.float32 526 ), 527 f"layers.{inner_layer_idx}.post_attention_norm.bias": torch.zeros( 528 self.config.hidden_size, dtype=torch.float32 529 ), 530 }, 531 } 532 533 else: 534 if self.config.combine_qkv: 535 temp_state_dict[ 536 f"layers.{inner_layer_idx}.self_attn.qkv_proj.weight" 537 ] = state_dict.pop( 538 f"{prefix}layers.{outer_layer_idx}.self_attn.qkv_proj.weight" 539 ) 540 temp_state_dict[ 541 f"layers.{inner_layer_idx}.self_attn.qkv_proj.bias" 542 ] = state_dict.pop( 543 f"{prefix}layers.{outer_layer_idx}.self_attn.qkv_proj.bias", 544 torch.zeros( 545 (2 * self.config.num_key_value_heads * self.head_dim) 546 + self.config.hidden_size, 547 dtype=self.dtype, 548 ), 549 ) 550 else: 551 temp_state_dict = { 552 **temp_state_dict, 553 **{ 554 f"layers.{inner_layer_idx}.self_attn.q_proj.weight": state_dict.pop( 555 f"{prefix}layers.{outer_layer_idx}.self_attn.q_proj.weight" 556 ), 557 f"layers.{inner_layer_idx}.self_attn.k_proj.weight": state_dict.pop( 558 f"{prefix}layers.{outer_layer_idx}.self_attn.k_proj.weight" 559 ), 560 f"layers.{inner_layer_idx}.self_attn.v_proj.weight": state_dict.pop( 561 f"{prefix}layers.{outer_layer_idx}.self_attn.v_proj.weight" 562 ), 563 f"layers.{inner_layer_idx}.self_attn.q_proj.bias": state_dict.pop( 564 f"{prefix}layers.{outer_layer_idx}.self_attn.q_proj.bias", 565 torch.zeros(self.config.hidden_size, dtype=self.dtype), 566 ), 567 f"layers.{inner_layer_idx}.self_attn.k_proj.bias": state_dict.pop( 568 f"{prefix}layers.{outer_layer_idx}.self_attn.k_proj.bias", 569 torch.zeros( 570 self.config.num_key_value_heads * self.head_dim, 571 dtype=self.dtype, 572 ), 573 ), 574 f"layers.{inner_layer_idx}.self_attn.v_proj.bias": state_dict.pop( 575 f"{prefix}layers.{outer_layer_idx}.self_attn.v_proj.bias", 576 torch.zeros( 577 self.config.num_key_value_heads * self.head_dim, 578 dtype=self.dtype, 579 ), 580 ), 581 }, 582 } 583 584 temp_state_dict = { 585 **temp_state_dict, 586 **{ 587 f"layers.{inner_layer_idx}.self_attn.o_proj.weight": state_dict.pop( 588 f"{prefix}layers.{outer_layer_idx}.self_attn.o_proj.weight" 589 ), 590 f"layers.{inner_layer_idx}.mlp.gate_proj.weight": state_dict.pop( 591 f"{prefix}layers.{outer_layer_idx}.mlp.gate_proj.weight" 592 ), 593 f"layers.{inner_layer_idx}.mlp.down_proj.weight": state_dict.pop( 594 f"{prefix}layers.{outer_layer_idx}.mlp.down_proj.weight" 595 ), 596 f"layers.{inner_layer_idx}.mlp.up_proj.weight": state_dict.pop( 597 f"{prefix}layers.{outer_layer_idx}.mlp.up_proj.weight" 598 ), 599 f"layers.{inner_layer_idx}.input_norm.weight": state_dict.pop( 600 f"{prefix}layers.{outer_layer_idx}.{input_norm_subkey}.weight" 601 ).to(torch.float32), 602 f"layers.{inner_layer_idx}.post_attention_norm.weight": state_dict.pop( 603 f"{prefix}layers.{outer_layer_idx}.{post_attention_norm_subkey}.weight" 604 ).to( 605 torch.float32 606 ), 607 f"layers.{inner_layer_idx}.self_attn.o_proj.bias": state_dict.pop( 608 f"{prefix}layers.{outer_layer_idx}.self_attn.o_proj.bias", 609 torch.zeros(self.config.hidden_size, dtype=self.dtype), 610 ), 611 f"layers.{inner_layer_idx}.mlp.gate_proj.bias": state_dict.pop( 
612 f"{prefix}layers.{outer_layer_idx}.mlp.gate_proj.bias", 613 torch.zeros( 614 self.config.intermediate_size, dtype=self.dtype 615 ), 616 ), 617 f"layers.{inner_layer_idx}.mlp.down_proj.bias": state_dict.pop( 618 f"{prefix}layers.{outer_layer_idx}.mlp.down_proj.bias", 619 torch.zeros(self.config.hidden_size, dtype=self.dtype), 620 ), 621 f"layers.{inner_layer_idx}.mlp.up_proj.bias": state_dict.pop( 622 f"{prefix}layers.{outer_layer_idx}.mlp.up_proj.bias", 623 torch.zeros( 624 self.config.intermediate_size, dtype=self.dtype 625 ), 626 ), 627 }, 628 } 629 630 if self.config.norm == "LayerNorm": 631 temp_state_dict = { 632 **temp_state_dict, 633 **{ 634 f"layers.{inner_layer_idx}.input_norm.bias": state_dict.pop( 635 f"{prefix}layers.{outer_layer_idx}.{input_norm_subkey}.bias", 636 torch.zeros(self.config.hidden_size, dtype=self.dtype), 637 ).to(torch.float32), 638 f"layers.{inner_layer_idx}.post_attention_norm.bias": state_dict.pop( 639 f"{prefix}layers.{outer_layer_idx}.{post_attention_norm_subkey}.bias", 640 torch.zeros(self.config.hidden_size, dtype=self.dtype), 641 ).to( 642 torch.float32 643 ), 644 }, 645 } 646 647 if torch.cuda.device_count() == 0 or self.jit_trace: 648 self.device_list.append("cpu") 649 else: 650 device_id = outer_layer_idx // ( 651 self.config.num_hidden_layers // torch.cuda.device_count() 652 + (self.config.num_hidden_layers % torch.cuda.device_count() != 0) 653 ) 654 self.device_list.append(f"cuda:{device_id}") 655 outer_layer_idx += 1 656 if self.include_tail: 657 if fake_weights: 658 temp_state_dict = { 659 **temp_state_dict, 660 "norm.weight": torch.rand( 661 self.config.hidden_size, dtype=torch.float32 662 ), 663 "lm_head.weight": torch.rand( 664 self.config.vocab_size, 665 self.config.hidden_size, 666 dtype=self.dtype, 667 ), 668 "lm_head.bias": torch.zeros( 669 self.config.vocab_size, dtype=self.dtype 670 ), 671 } 672 if self.config.norm == "LayerNorm": 673 temp_state_dict["norm.bias"] = torch.zeros( 674 self.config.hidden_size, dtype=torch.float32 675 ) 676 else: 677 if self.config.tie_word_embeddings: 678 lm_head_weight_key = f"{prefix}embed_tokens.weight" 679 lm_head_bias_key = f"{prefix}embed_tokens.bias" 680 else: 681 lm_head_weight_key = "lm_head.weight" 682 lm_head_bias_key = "lm_head.bias" 683 temp_state_dict = { 684 **temp_state_dict, 685 **{ 686 "lm_head.weight": state_dict.pop(lm_head_weight_key), 687 "norm.weight": state_dict.pop(f"{prefix}norm.weight").to( 688 torch.float32 689 ), 690 "lm_head.bias": state_dict.pop( 691 lm_head_bias_key, 692 torch.zeros(self.config.vocab_size, dtype=self.dtype), 693 ), 694 }, 695 } 696 if self.config.norm == "LayerNorm": 697 temp_state_dict["norm.bias"] = state_dict.pop( 698 f"{prefix}norm.bias", 699 torch.zeros(self.config.hidden_size, dtype=self.dtype), 700 ).to(torch.float32) 701 702 print(f"Loading weights for chunk {self.chunk_idx}") 703 if temp_state_dict.keys() != self.state_dict().keys(): 704 temp_state_dict_only_keys = [ 705 x for x in temp_state_dict.keys() if x not in self.state_dict().keys() 706 ] 707 model_only_keys = [ 708 x for x in self.state_dict().keys() if x not in temp_state_dict.keys() 709 ] 710 raise RuntimeError( 711 f"model state dict keys don't match with state_dict to load into model.\nModel only keys:{model_only_keys}\nstate_dict only keys:{temp_state_dict_only_keys}" 712 ) 713 self.load_state_dict(temp_state_dict) 714 for i in range(self.num_blocks): 715 self.layers[i].to(self.device_list[i]) 716 if self.config.use_stable_embedding and self.chunk_idx == 0: 717 
            self.embed_layer_norm.to(self.device_list[0])
        if self.include_tail:
            self.norm.to(self.device_list[-1])
            self.lm_head.to(self.device_list[-1])
        self.eval()

        return self

    def get_example_inputs(
        self, num_token: int = 128, cache_size: int = 512, get_dym_shape=False
    ):
        head_dim = int(self.config.hidden_size / self.config.num_attention_heads)
        example_inputs = (
            torch.randn(
                1, num_token, self.config.hidden_size, device="cpu", dtype=torch.float32
            ),
            torch.randn(
                1,
                1,
                num_token,
                cache_size + num_token,
                device="cpu",
                dtype=torch.float32,
            ),
            torch.randn(1, 2, num_token, head_dim, device="cpu", dtype=torch.float32),
            *[
                torch.randn(
                    1,
                    self.config.num_key_value_heads,
                    cache_size,
                    head_dim,
                    device="cpu",
                    dtype=torch.float32,
                )
                for _ in range(2 * self.num_blocks)
            ],
        )
        # Specify dims that would be dynamic during calibration
        # Note: Assume cache size fixed shape as torch dynamic shape cannot handle dim 3 being
        # combination of 2 dynamic dims
        if get_dym_shape:
            nt = Dim("num_token", max=num_token)
            cache_dims = tuple({} for _ in range(2 * self.num_blocks))
            dynamic_shapes = (
                {0: Dim.STATIC, 1: nt, 2: Dim.STATIC},
                {0: Dim.STATIC, 1: Dim.STATIC, 2: nt, 3: nt + cache_size},
                {0: Dim.STATIC, 1: Dim.STATIC, 2: nt, 3: Dim.STATIC},
                cache_dims,
            )
            return example_inputs, dynamic_shapes

        return example_inputs
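

# The block below is an optional, illustrative smoke test and is not part of the
# module's API. It is a minimal sketch that exercises a single Attention block with
# a rolling KV cache on CPU. It uses a types.SimpleNamespace stand-in for BaseConfig
# (an assumption: only the attributes Attention actually reads are provided), so the
# real configuration_base API is not touched. It only runs when this file is executed
# directly; importing the module is unaffected.
if __name__ == "__main__":
    from types import SimpleNamespace

    # Hypothetical toy configuration, small enough to run instantly on CPU.
    cfg = SimpleNamespace(
        hidden_size=64,
        num_attention_heads=8,
        num_key_value_heads=2,
        combine_qkv=False,
        position_embedding="rope",
    )
    attn = Attention(cfg)

    bsz, q_len, c_len = 1, 4, 16
    head_dim = cfg.hidden_size // cfg.num_attention_heads
    hidden = torch.randn(bsz, q_len, cfg.hidden_size)
    mask = torch.zeros(bsz, 1, q_len, c_len + q_len)
    pos_emb = torch.randn(bsz, 2, q_len, head_dim)  # stacked (cos, sin)
    past_key = torch.randn(bsz, cfg.num_key_value_heads, c_len, head_dim)
    past_value = torch.randn(bsz, cfg.num_key_value_heads, c_len, head_dim)

    out, new_key, new_value = attn(hidden, mask, pos_emb, past_key, past_value)

    # The returned cache keeps a constant length: the oldest q_len entries are dropped.
    assert out.shape == (bsz, q_len, cfg.hidden_size)
    assert new_key.shape == past_key.shape
    assert new_value.shape == past_value.shape
    print("Attention smoke test passed:", out.shape, new_key.shape, new_value.shape)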