# Source: /aosp_15_r20/external/executorch/examples/mediatek/models/llm_models/modeling_common.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1"""Common backbone across multiple models"""
2
3import math
4
5import numpy as np
6import torch
7from models.llm_models.configuration_base import BaseConfig
8from models.llm_models.modeling_base import BaseModelChunk
9from torch import nn
10from torch.export import Dim
11
12torch.manual_seed(42)
13np.random.seed(42)
14
15
16# flake8: noqa: C901
17
18
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: scale-only affine, no mean subtraction."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states


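# Minimal usage sketch (not part of the original module): checks the scale-only
# normalization above against a hand-written reference. Shapes are illustrative.
def _rmsnorm_reference_check():
    x = torch.randn(2, 4, 8)
    norm = RMSNorm(hidden_size=8)
    ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + norm.variance_epsilon)
    assert torch.allclose(norm(x), norm.weight * ref)

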
class Gelu(nn.Module):
    """Sigmoid approximation of GELU: x * sigmoid(1.702 * x)."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(1.702 * x)


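# Sketch (illustrative, not part of the original module): the sigmoid form above
# tracks exact GELU closely; the 0.05 tolerance is a loosely chosen assumption.
def _gelu_approx_check():
    x = torch.linspace(-6.0, 6.0, steps=1001)
    approx = Gelu()(x)
    exact = torch.nn.functional.gelu(x)
    assert (approx - exact).abs().max() < 0.05

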
class MLP(nn.Module):
    """Gated feed-forward block: down_proj(silu(gate_proj(x)) * up_proj(x))."""

    def __init__(self, config: BaseConfig):
        super().__init__()
        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size

        self.gate_proj = nn.Linear(hidden_size, intermediate_size)
        self.down_proj = nn.Linear(intermediate_size, hidden_size)
        self.up_proj = nn.Linear(hidden_size, intermediate_size)

    def forward(self, x):
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        # gate * sigmoid(gate) is SiLU applied to the gate branch.
        pre_down = gate * torch.sigmoid(gate) * up
        down = self.down_proj(pre_down)

        return down


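# Sketch (illustrative, not part of the original module): the explicit
# gate * sigmoid(gate) above is the same function as torch.nn.functional.silu,
# written out so it lowers to simple elementwise ops.
def _silu_gate_check():
    g = torch.randn(4, 16)
    assert torch.allclose(g * torch.sigmoid(g), torch.nn.functional.silu(g), atol=1e-6)

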
class Attention(nn.Module):
    def __init__(self, config: BaseConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.attn_scale = math.sqrt(self.head_dim)

        if config.combine_qkv:
            self.qkv_proj = nn.Linear(
                self.hidden_size,
                (2 * self.num_key_value_heads * self.head_dim) + self.hidden_size,
            )
        else:
            self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
            self.k_proj = nn.Linear(
                self.hidden_size, self.num_key_value_heads * self.head_dim
            )
            self.v_proj = nn.Linear(
                self.hidden_size, self.num_key_value_heads * self.head_dim
            )
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)

    def apply_rotary_pos_emb_mtk(self, q, k, cos, sin):
        # Standard rotate-half RoPE: split the head dim in two, rotate, then
        # combine with the precomputed cos/sin tables.
        q1 = q[..., : q.shape[-1] // 2]
        q2 = q[..., q.shape[-1] // 2 :]
        q_rotated = torch.cat((-q2, q1), dim=-1)
        k1 = k[..., : k.shape[-1] // 2]
        k2 = k[..., k.shape[-1] // 2 :]
        k_rotated = torch.cat((-k2, k1), dim=-1)

        q_embed = q * cos + q_rotated * sin
        k_embed = k * cos + k_rotated * sin
        return q_embed, k_embed

    def repeat_kv(self, hidden_states, batch, q_len, n_rep):
        # Expand grouped KV heads so every query head has a matching KV head.
        if isinstance(hidden_states, list):
            output = []
            for hs in hidden_states:
                output.append(
                    hs.repeat(1, 1, n_rep, 1).view(batch, 1, q_len, self.head_dim)
                )
            return output
        else:
            hidden_states = hidden_states.repeat(1, 1, n_rep, 1)
            return hidden_states.view(batch, self.num_heads, q_len, self.head_dim)

    def forward(
        self,
        hidden_states,  # (b, t, hidden_size)
        mask,  # (b, 1, t, c+t)
        pos_emb,  # (b, 2, t, head_dim)
        past_key,  # (b, num_kv_heads, c, head_dim)
        past_value,  # (b, num_kv_heads, c, head_dim)
    ):
        bsz, q_len, _ = hidden_states.size()
        c_len = past_key.size()[2]

        if self.config.combine_qkv:
            proj = self.qkv_proj(hidden_states)
            query_states = (
                proj[:, :, : self.config.hidden_size]
                .view(bsz, q_len, self.num_heads, self.head_dim)
                .transpose(1, 2)
            )
            key_states = (
                proj[
                    :,
                    :,
                    self.config.hidden_size : self.config.hidden_size
                    + self.num_key_value_heads * self.head_dim,
                ]
                .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
                .transpose(1, 2)
            )
            value_states = (
                proj[
                    :,
                    :,
                    self.config.hidden_size
                    + self.num_key_value_heads * self.head_dim :,
                ]
                .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
                .transpose(1, 2)
            )
        else:
            query_states = (
                self.q_proj(hidden_states)
                .view(bsz, q_len, self.num_heads, self.head_dim)
                .transpose(1, 2)
            )
            key_states = (
                self.k_proj(hidden_states)
                .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
                .transpose(1, 2)
            )
            value_states = (
                self.v_proj(hidden_states)
                .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
                .transpose(1, 2)
            )

        if self.config.position_embedding == "rope":
            cos, sin = torch.split(pos_emb, 1, dim=1)
            query_states, key_states = self.apply_rotary_pos_emb_mtk(
                query_states, key_states, cos, sin
            )

        key_states = torch.cat([past_key, key_states], dim=2)
        value_states = torch.cat([past_value, value_states], dim=2)
        key_states_out = key_states
        value_states_out = value_states
        if self.num_key_value_groups > 1:
            key_states = self.repeat_kv(
                key_states, bsz, q_len + c_len, self.num_key_value_groups
            )
            value_states = self.repeat_kv(
                value_states, bsz, q_len + c_len, self.num_key_value_groups
            )
        attn_weights = (
            torch.matmul(query_states, key_states.transpose(2, 3)) / self.attn_scale
        )
        attn_weights = attn_weights + mask
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        # Return a rolling cache: drop the oldest q_len entries so the cache
        # length stays at c.
        key_states_out = key_states_out[:, :, q_len:, :]
        value_states_out = value_states_out[:, :, q_len:, :]

        return attn_output, key_states_out, value_states_out


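# Sketch (an assumption about how pos_emb is produced; the exporter's actual table
# builder and rope theta may differ): builds the (b, 2, t, head_dim) tensor that
# Attention.forward splits into cos and sin along dim=1.
def _build_rope_pos_emb(position_ids, head_dim, theta=10000.0):
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = position_ids[:, :, None].float() * inv_freq[None, None, :]  # (b, t, head_dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)  # (b, t, head_dim)
    return torch.stack((emb.cos(), emb.sin()), dim=1)  # (b, 2, t, head_dim)

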
class DecoderLayer(nn.Module):
    def __init__(
        self,
        config: BaseConfig,
        return_attn=False,
        jit_trace=False,
        attn_class=Attention,
        mlp_class=MLP,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.return_attn = return_attn
        self.jit_trace = jit_trace
        self.self_attn = attn_class(config)
        self.mlp = mlp_class(config)
        if config.norm == "RMSNorm":
            self.input_norm = RMSNorm(config.hidden_size, eps=config.norm_eps).float()
            self.post_attention_norm = RMSNorm(
                config.hidden_size, eps=config.norm_eps
            ).float()
        else:
            self.input_norm = nn.LayerNorm(
                config.hidden_size, eps=config.norm_eps
            ).float()
            self.post_attention_norm = nn.LayerNorm(
                config.hidden_size, eps=config.norm_eps
            ).float()

    def forward(
        self,
        hidden_states,  # (b, t, hidden_dim)
        mask,  # (b, 1, t, c+t)
        pos_emb,  # (b, 2, t, head_dim)
        past_key,  # (b, num_kv_head, c, head_dim)
        past_value,  # (b, num_kv_head, c, head_dim)
    ):
        residual = hidden_states
        if self.jit_trace:
            hidden_states = self.input_norm(hidden_states)
        else:
            # Norms are kept in float32; cast in and out around them.
            dtype = hidden_states.dtype
            hidden_states = self.input_norm(hidden_states.to(torch.float32)).to(dtype)

        layer_device = hidden_states.device

        # Self Attention
        attn_output, present_key, present_value = self.self_attn(
            hidden_states=hidden_states.to(layer_device),
            mask=mask.to(layer_device),
            pos_emb=pos_emb.to(layer_device),
            past_key=past_key.to(layer_device),
            past_value=past_value.to(layer_device),
        )
        hidden_states = residual.to(layer_device) + attn_output

        # Fully Connected
        residual = hidden_states
        if self.jit_trace:
            hidden_states = self.post_attention_norm(hidden_states)
        else:
            dtype = hidden_states.dtype
            hidden_states = self.post_attention_norm(
                hidden_states.to(torch.float32)
            ).to(dtype)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        if self.return_attn:
            return hidden_states, present_key, present_value, attn_output
        return hidden_states, present_key, present_value


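# Shape sketch (illustrative; the SimpleNamespace below stands in for BaseConfig
# and only carries the attributes this file actually reads; all sizes are made up):
# runs one decoder layer against a length-c cache and checks the rolling-cache shapes.
def _decoder_layer_shape_sketch():
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        hidden_size=64,
        intermediate_size=128,
        num_attention_heads=8,
        num_key_value_heads=4,
        combine_qkv=False,
        position_embedding="rope",
        norm="RMSNorm",
        norm_eps=1e-6,
    )
    layer = DecoderLayer(cfg)
    b, t, c, head_dim = 1, 4, 16, cfg.hidden_size // cfg.num_attention_heads
    out, key_out, value_out = layer(
        torch.randn(b, t, cfg.hidden_size),
        torch.zeros(b, 1, t, c + t),
        torch.randn(b, 2, t, head_dim),
        torch.randn(b, cfg.num_key_value_heads, c, head_dim),
        torch.randn(b, cfg.num_key_value_heads, c, head_dim),
    )
    assert out.shape == (b, t, cfg.hidden_size)
    assert key_out.shape == value_out.shape == (b, cfg.num_key_value_heads, c, head_dim)

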
class ModelChunk(BaseModelChunk):
    def __init__(
        self,
        config: BaseConfig,
        num_blocks,
        chunk_idx,
        dtype=torch.float32,
        include_tail=False,
        return_attn=False,
        jit_trace=False,
        decoder_class=DecoderLayer,
    ):
        super().__init__(
            config, num_blocks, chunk_idx, dtype, include_tail, return_attn, jit_trace
        )
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.layers = nn.ModuleList(
            [
                decoder_class(config, return_attn=return_attn, jit_trace=jit_trace)
                for _ in range(num_blocks)
            ]
        )

        if self.config.use_stable_embedding and self.chunk_idx == 0:
            self.embed_layer_norm = nn.LayerNorm(config.hidden_size).float()

        if self.include_tail:
            if config.norm == "RMSNorm":
                self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps).float()
            else:
                self.norm = nn.LayerNorm(
                    config.hidden_size, eps=config.norm_eps
                ).float()
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, inputs_embeds, mask, pos_emb, *cache):
        if not self.jit_trace:
            assert (
                len(cache) == 2 * self.num_blocks
            ), f"split cache wrong number of input caches: {len(cache)} != 2*{self.num_blocks}"
            assert (
                cache[0].shape[0] == inputs_embeds.size()[0]
            ), f"split cache batch size mismatch: {cache[0].shape[0]} != {inputs_embeds.size()[0]}"

        inputs_embeds = inputs_embeds.to(self.device_list[0])

        if self.config.use_stable_embedding and self.chunk_idx == 0:
            if self.jit_trace:
                inputs_embeds = self.embed_layer_norm(inputs_embeds)
            else:
                inputs_embeds = self.embed_layer_norm(
                    inputs_embeds.to(torch.float32)
                ).to(self.dtype)

        hidden_states = inputs_embeds

        next_key_cache = []
        next_value_cache = []
        if self.return_attn:
            attn_outputs = []

        # decoder layers
        for idx, decoder_layer in enumerate(self.layers):
            decoder_outputs = decoder_layer(
                hidden_states.to(self.device_list[idx]),
                mask=mask.to(self.device_list[idx]),
                pos_emb=pos_emb.to(self.device_list[idx]),
                past_key=cache[idx].to(self.device_list[idx]),
                past_value=cache[self.num_blocks + idx].to(self.device_list[idx]),
            )
            hidden_states = decoder_outputs[0]
            next_key_cache.append(decoder_outputs[1].to(inputs_embeds.device))
            next_value_cache.append(decoder_outputs[2].to(inputs_embeds.device))
            if self.return_attn:
                attn_outputs.append(decoder_outputs[3].to(inputs_embeds.device))

        if self.include_tail:
            if self.jit_trace:
                hidden_states = self.norm(hidden_states)
            else:
                hidden_states = self.norm(hidden_states.to(torch.float32)).to(
                    self.dtype
                )
            hidden_states = self.lm_head(hidden_states)

        if self.return_attn:
            return hidden_states, *next_key_cache, *next_value_cache, *attn_outputs
        return hidden_states, *next_key_cache, *next_value_cache

    def load_weights(self, state_dict, state_dict_start_idx):
        if state_dict is None:
            # No checkpoint given: synthesize random ("fake") weights below.
            fake_weights = True
        else:
            expected_subkey = f"layers.{state_dict_start_idx}.self_attn.o_proj.weight"
            state_dict_keys = list(state_dict.keys())
            temp_key = None
            input_norm_subkey = None
            post_attention_norm_subkey = None
            for key in state_dict_keys:
                if expected_subkey in key:
                    temp_key = key
                if (
                    f"layers.{state_dict_start_idx}" in key
                    and "norm" in key
                    and "input" in key
                ):
                    input_norm_subkey = key.split(".")[-2]
                if (
                    f"layers.{state_dict_start_idx}" in key
                    and "norm" in key
                    and "post_attention" in key
                ):
                    post_attention_norm_subkey = key.split(".")[-2]
            if temp_key is None:
                raise KeyError(
                    f"Cannot find layer {state_dict_start_idx}'s o_proj weight inside state_dict. "
                    f"Please ensure the o_proj weight key contains: {expected_subkey}"
                )
            if input_norm_subkey is None:
                raise KeyError(
                    f"Cannot find layer {state_dict_start_idx}'s input norm weight inside state_dict. "
                    f"Please ensure the input norm weight key contains layers.{state_dict_start_idx}, norm, and input"
                    " inside the key string."
                )
            if post_attention_norm_subkey is None:
                raise KeyError(
                    f"Cannot find layer {state_dict_start_idx}'s post attention norm weight inside state_dict."
                    f" Please ensure the post attention norm weight key contains layers.{state_dict_start_idx}, norm, and "
                    "post_attention inside the key string."
                )
            prefix = temp_key.split(expected_subkey)[0]
            fake_weights = False

        outer_layer_idx = state_dict_start_idx
        self.device_list = []
        if self.config.use_stable_embedding and self.chunk_idx == 0:
            if fake_weights:
                temp_state_dict = {
                    "embed_layer_norm.weight": torch.rand(
                        self.config.hidden_size, dtype=torch.float32
                    ),
                    "embed_layer_norm.bias": torch.zeros(
                        self.config.hidden_size, dtype=torch.float32
                    ),
                }
            else:
                temp_state_dict = {
                    "embed_layer_norm.weight": state_dict.pop(
                        f"{prefix}embed_layer_norm.weight"
                    ).to(torch.float32),
                    "embed_layer_norm.bias": state_dict.pop(
                        f"{prefix}embed_layer_norm.bias",
                        torch.zeros(self.config.hidden_size, dtype=self.dtype),
                    ).to(torch.float32),
                }
        else:
            temp_state_dict = {}

        for inner_layer_idx in range(self.num_blocks):
            if fake_weights:
                if self.config.combine_qkv:
                    # Shape matches qkv_proj: (2 * kv_heads * head_dim + hidden_size, hidden_size).
                    temp_state_dict[
                        f"layers.{inner_layer_idx}.self_attn.qkv_proj.weight"
                    ] = torch.rand(
                        (2 * self.config.num_key_value_heads * self.head_dim)
                        + self.config.hidden_size,
                        self.config.hidden_size,
                        dtype=self.dtype,
                    )
                    temp_state_dict[
                        f"layers.{inner_layer_idx}.self_attn.qkv_proj.bias"
                    ] = torch.zeros(
                        (2 * self.config.num_key_value_heads * self.head_dim)
                        + self.config.hidden_size,
                        dtype=self.dtype,
                    )
                else:
                    temp_state_dict = {
                        **temp_state_dict,
                        **{
                            f"layers.{inner_layer_idx}.self_attn.q_proj.weight": torch.rand(
                                self.config.hidden_size,
                                self.config.hidden_size,
                                dtype=self.dtype,
                            ),
                            f"layers.{inner_layer_idx}.self_attn.k_proj.weight": torch.rand(
                                self.config.num_key_value_heads * self.head_dim,
                                self.config.hidden_size,
                                dtype=self.dtype,
                            ),
                            f"layers.{inner_layer_idx}.self_attn.v_proj.weight": torch.rand(
                                self.config.num_key_value_heads * self.head_dim,
                                self.config.hidden_size,
                                dtype=self.dtype,
                            ),
                            f"layers.{inner_layer_idx}.self_attn.q_proj.bias": torch.zeros(
                                self.config.hidden_size, dtype=self.dtype
                            ),
                            f"layers.{inner_layer_idx}.self_attn.k_proj.bias": torch.zeros(
                                self.config.num_key_value_heads * self.head_dim,
                                dtype=self.dtype,
                            ),
                            f"layers.{inner_layer_idx}.self_attn.v_proj.bias": torch.zeros(
                                self.config.num_key_value_heads * self.head_dim,
                                dtype=self.dtype,
                            ),
                        },
                    }
                temp_state_dict = {
                    **temp_state_dict,
                    **{
                        f"layers.{inner_layer_idx}.self_attn.o_proj.weight": torch.rand(
                            self.config.hidden_size,
                            self.config.hidden_size,
                            dtype=self.dtype,
                        ),
                        f"layers.{inner_layer_idx}.mlp.gate_proj.weight": torch.rand(
                            self.config.intermediate_size,
                            self.config.hidden_size,
                            dtype=self.dtype,
                        ),
                        f"layers.{inner_layer_idx}.mlp.down_proj.weight": torch.rand(
                            self.config.hidden_size,
                            self.config.intermediate_size,
                            dtype=self.dtype,
                        ),
                        f"layers.{inner_layer_idx}.mlp.up_proj.weight": torch.rand(
                            self.config.intermediate_size,
                            self.config.hidden_size,
                            dtype=self.dtype,
                        ),
                        f"layers.{inner_layer_idx}.input_norm.weight": torch.rand(
                            self.config.hidden_size, dtype=torch.float32
                        ),
                        f"layers.{inner_layer_idx}.post_attention_norm.weight": torch.rand(
                            self.config.hidden_size, dtype=torch.float32
                        ),
                        f"layers.{inner_layer_idx}.self_attn.o_proj.bias": torch.zeros(
                            self.config.hidden_size, dtype=self.dtype
                        ),
                        f"layers.{inner_layer_idx}.mlp.gate_proj.bias": torch.zeros(
                            self.config.intermediate_size, dtype=self.dtype
                        ),
                        f"layers.{inner_layer_idx}.mlp.down_proj.bias": torch.zeros(
                            self.config.hidden_size, dtype=self.dtype
                        ),
                        f"layers.{inner_layer_idx}.mlp.up_proj.bias": torch.zeros(
                            self.config.intermediate_size, dtype=self.dtype
                        ),
                    },
                }

                if self.config.norm == "LayerNorm":
                    temp_state_dict = {
                        **temp_state_dict,
                        **{
                            f"layers.{inner_layer_idx}.input_norm.bias": torch.zeros(
                                self.config.hidden_size, dtype=torch.float32
                            ),
                            f"layers.{inner_layer_idx}.post_attention_norm.bias": torch.zeros(
                                self.config.hidden_size, dtype=torch.float32
                            ),
                        },
                    }

            else:
                if self.config.combine_qkv:
                    temp_state_dict[
                        f"layers.{inner_layer_idx}.self_attn.qkv_proj.weight"
                    ] = state_dict.pop(
                        f"{prefix}layers.{outer_layer_idx}.self_attn.qkv_proj.weight"
                    )
                    temp_state_dict[
                        f"layers.{inner_layer_idx}.self_attn.qkv_proj.bias"
                    ] = state_dict.pop(
                        f"{prefix}layers.{outer_layer_idx}.self_attn.qkv_proj.bias",
                        torch.zeros(
                            (2 * self.config.num_key_value_heads * self.head_dim)
                            + self.config.hidden_size,
                            dtype=self.dtype,
                        ),
                    )
                else:
                    temp_state_dict = {
                        **temp_state_dict,
                        **{
                            f"layers.{inner_layer_idx}.self_attn.q_proj.weight": state_dict.pop(
                                f"{prefix}layers.{outer_layer_idx}.self_attn.q_proj.weight"
                            ),
                            f"layers.{inner_layer_idx}.self_attn.k_proj.weight": state_dict.pop(
                                f"{prefix}layers.{outer_layer_idx}.self_attn.k_proj.weight"
                            ),
                            f"layers.{inner_layer_idx}.self_attn.v_proj.weight": state_dict.pop(
                                f"{prefix}layers.{outer_layer_idx}.self_attn.v_proj.weight"
                            ),
                            f"layers.{inner_layer_idx}.self_attn.q_proj.bias": state_dict.pop(
                                f"{prefix}layers.{outer_layer_idx}.self_attn.q_proj.bias",
                                torch.zeros(self.config.hidden_size, dtype=self.dtype),
                            ),
                            f"layers.{inner_layer_idx}.self_attn.k_proj.bias": state_dict.pop(
                                f"{prefix}layers.{outer_layer_idx}.self_attn.k_proj.bias",
                                torch.zeros(
                                    self.config.num_key_value_heads * self.head_dim,
                                    dtype=self.dtype,
                                ),
                            ),
                            f"layers.{inner_layer_idx}.self_attn.v_proj.bias": state_dict.pop(
                                f"{prefix}layers.{outer_layer_idx}.self_attn.v_proj.bias",
                                torch.zeros(
                                    self.config.num_key_value_heads * self.head_dim,
                                    dtype=self.dtype,
                                ),
                            ),
                        },
                    }

                temp_state_dict = {
                    **temp_state_dict,
                    **{
                        f"layers.{inner_layer_idx}.self_attn.o_proj.weight": state_dict.pop(
                            f"{prefix}layers.{outer_layer_idx}.self_attn.o_proj.weight"
                        ),
                        f"layers.{inner_layer_idx}.mlp.gate_proj.weight": state_dict.pop(
                            f"{prefix}layers.{outer_layer_idx}.mlp.gate_proj.weight"
                        ),
                        f"layers.{inner_layer_idx}.mlp.down_proj.weight": state_dict.pop(
                            f"{prefix}layers.{outer_layer_idx}.mlp.down_proj.weight"
                        ),
                        f"layers.{inner_layer_idx}.mlp.up_proj.weight": state_dict.pop(
                            f"{prefix}layers.{outer_layer_idx}.mlp.up_proj.weight"
                        ),
                        f"layers.{inner_layer_idx}.input_norm.weight": state_dict.pop(
                            f"{prefix}layers.{outer_layer_idx}.{input_norm_subkey}.weight"
                        ).to(torch.float32),
                        f"layers.{inner_layer_idx}.post_attention_norm.weight": state_dict.pop(
                            f"{prefix}layers.{outer_layer_idx}.{post_attention_norm_subkey}.weight"
                        ).to(torch.float32),
                        f"layers.{inner_layer_idx}.self_attn.o_proj.bias": state_dict.pop(
                            f"{prefix}layers.{outer_layer_idx}.self_attn.o_proj.bias",
                            torch.zeros(self.config.hidden_size, dtype=self.dtype),
                        ),
                        f"layers.{inner_layer_idx}.mlp.gate_proj.bias": state_dict.pop(
                            f"{prefix}layers.{outer_layer_idx}.mlp.gate_proj.bias",
                            torch.zeros(
                                self.config.intermediate_size, dtype=self.dtype
                            ),
                        ),
                        f"layers.{inner_layer_idx}.mlp.down_proj.bias": state_dict.pop(
                            f"{prefix}layers.{outer_layer_idx}.mlp.down_proj.bias",
                            torch.zeros(self.config.hidden_size, dtype=self.dtype),
                        ),
                        f"layers.{inner_layer_idx}.mlp.up_proj.bias": state_dict.pop(
                            f"{prefix}layers.{outer_layer_idx}.mlp.up_proj.bias",
                            torch.zeros(
                                self.config.intermediate_size, dtype=self.dtype
                            ),
                        ),
                    },
                }

                if self.config.norm == "LayerNorm":
                    temp_state_dict = {
                        **temp_state_dict,
                        **{
                            f"layers.{inner_layer_idx}.input_norm.bias": state_dict.pop(
                                f"{prefix}layers.{outer_layer_idx}.{input_norm_subkey}.bias",
                                torch.zeros(self.config.hidden_size, dtype=self.dtype),
                            ).to(torch.float32),
                            f"layers.{inner_layer_idx}.post_attention_norm.bias": state_dict.pop(
                                f"{prefix}layers.{outer_layer_idx}.{post_attention_norm_subkey}.bias",
                                torch.zeros(self.config.hidden_size, dtype=self.dtype),
                            ).to(torch.float32),
                        },
                    }

            if torch.cuda.device_count() == 0 or self.jit_trace:
                self.device_list.append("cpu")
            else:
                # Spread layers evenly across the available CUDA devices.
                device_id = outer_layer_idx // (
                    self.config.num_hidden_layers // torch.cuda.device_count()
                    + (self.config.num_hidden_layers % torch.cuda.device_count() != 0)
                )
                self.device_list.append(f"cuda:{device_id}")
            outer_layer_idx += 1
        if self.include_tail:
            if fake_weights:
                temp_state_dict = {
                    **temp_state_dict,
                    "norm.weight": torch.rand(
                        self.config.hidden_size, dtype=torch.float32
                    ),
                    "lm_head.weight": torch.rand(
                        self.config.vocab_size,
                        self.config.hidden_size,
                        dtype=self.dtype,
                    ),
                    "lm_head.bias": torch.zeros(
                        self.config.vocab_size, dtype=self.dtype
                    ),
                }
                if self.config.norm == "LayerNorm":
                    temp_state_dict["norm.bias"] = torch.zeros(
                        self.config.hidden_size, dtype=torch.float32
                    )
            else:
                if self.config.tie_word_embeddings:
                    lm_head_weight_key = f"{prefix}embed_tokens.weight"
                    lm_head_bias_key = f"{prefix}embed_tokens.bias"
                else:
                    lm_head_weight_key = "lm_head.weight"
                    lm_head_bias_key = "lm_head.bias"
                temp_state_dict = {
                    **temp_state_dict,
                    **{
                        "lm_head.weight": state_dict.pop(lm_head_weight_key),
                        "norm.weight": state_dict.pop(f"{prefix}norm.weight").to(
                            torch.float32
                        ),
                        "lm_head.bias": state_dict.pop(
                            lm_head_bias_key,
                            torch.zeros(self.config.vocab_size, dtype=self.dtype),
                        ),
                    },
                }
                if self.config.norm == "LayerNorm":
                    temp_state_dict["norm.bias"] = state_dict.pop(
                        f"{prefix}norm.bias",
                        torch.zeros(self.config.hidden_size, dtype=self.dtype),
                    ).to(torch.float32)

        print(f"Loading weights for chunk {self.chunk_idx}")
        if temp_state_dict.keys() != self.state_dict().keys():
            temp_state_dict_only_keys = [
                x for x in temp_state_dict.keys() if x not in self.state_dict().keys()
            ]
            model_only_keys = [
                x for x in self.state_dict().keys() if x not in temp_state_dict.keys()
            ]
            raise RuntimeError(
                "Model state dict keys don't match the state_dict to load into the model.\n"
                f"Model-only keys: {model_only_keys}\n"
                f"state_dict-only keys: {temp_state_dict_only_keys}"
            )
        self.load_state_dict(temp_state_dict)
        for i in range(self.num_blocks):
            self.layers[i].to(self.device_list[i])
        if self.config.use_stable_embedding and self.chunk_idx == 0:
            self.embed_layer_norm.to(self.device_list[0])
        if self.include_tail:
            self.norm.to(self.device_list[-1])
            self.lm_head.to(self.device_list[-1])
        self.eval()

        return self

    def get_example_inputs(
        self, num_token: int = 128, cache_size: int = 512, get_dym_shape=False
    ):
        head_dim = int(self.config.hidden_size / self.config.num_attention_heads)
        example_inputs = (
            torch.randn(
                1, num_token, self.config.hidden_size, device="cpu", dtype=torch.float32
            ),
            torch.randn(
                1,
                1,
                num_token,
                cache_size + num_token,
                device="cpu",
                dtype=torch.float32,
            ),
            torch.randn(1, 2, num_token, head_dim, device="cpu", dtype=torch.float32),
            *[
                torch.randn(
                    1,
                    self.config.num_key_value_heads,
                    cache_size,
                    head_dim,
                    device="cpu",
                    dtype=torch.float32,
                )
                for _ in range(2 * self.num_blocks)
            ],
        )
        # Specify the dims that should be dynamic during calibration.
        # Note: the cache size is kept static because torch dynamic shapes cannot
        # express dim 3 of the mask as a combination of two dynamic dims.
        if get_dym_shape:
            nt = Dim("num_token", max=num_token)
            cache_dims = tuple(({} for _ in range(2 * self.num_blocks)))
            dynamic_shapes = (
                {0: Dim.STATIC, 1: nt, 2: Dim.STATIC},
                {0: Dim.STATIC, 1: Dim.STATIC, 2: nt, 3: nt + cache_size},
                {0: Dim.STATIC, 1: Dim.STATIC, 2: nt, 3: Dim.STATIC},
                cache_dims,
            )
            return example_inputs, dynamic_shapes

        return example_inputs
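

# The sketches below are illustrative additions, not part of the original module.

# Cache layout: ModelChunk.forward consumes the flattened cache as
# (key_0, ..., key_{n-1}, value_0, ..., value_{n-1}). This hypothetical helper
# just restates that indexing for callers holding a flat tuple.
def _split_flat_cache(cache, num_blocks):
    return cache[:num_blocks], cache[num_blocks : 2 * num_blocks]


# Dynamic-shape sketch: mirrors the dynamic_shapes spec format produced by
# get_example_inputs (a Dim for the token axis, Dim.STATIC elsewhere) on a tiny
# stand-in module, so torch.export can be exercised without a real checkpoint.
def _dynamic_export_sketch():
    class Tiny(nn.Module):
        def forward(self, x):
            return x * 2.0

    nt = Dim("num_token", max=128)
    example = (torch.randn(1, 16, 8),)
    return torch.export.export(
        Tiny(), example, dynamic_shapes=({0: Dim.STATIC, 1: nt, 2: Dim.STATIC},)
    )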