# @lint-ignore-every LICENSELINT
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# Llama 2 is licensed under the LLAMA 2 Community License,
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

# Please refer to README.md in the same folder for more information.

from dataclasses import dataclass
from functools import partial
from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F

from executorch.examples.models.llama.rope import (
    hf_apply_rotary_emb,
    hf_precompute_freqs_cis,
    precompute_freqs_cis,
    RotaryEmbedding,
)

from torch import nn

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.

        """
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.

        """
        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.

        """
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

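# Usage sketch (illustrative, not part of the model): RMSNorm scales by the inverse
# root-mean-square of the last dimension, i.e. y = x * rsqrt(mean(x * x) + eps) * weight.
#   norm = RMSNorm(dim=8)
#   y = norm(torch.randn(2, 4, 8))  # output has the same shape as the input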

def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)

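# Example: find_multiple(10, 4) == 12 and find_multiple(12, 4) == 12, i.e. n is rounded
# up to the nearest multiple of k. ModelArgs uses this below to round the SwiGLU hidden
# dimension up to a multiple of `multiple_of`.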

@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1  # defined later by tokenizer
    hidden_dim: Optional[int] = None
    multiple_of: int = 256  # make SwiGLU hidden layer size a multiple of a large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    max_batch_size: int = 32
    max_seq_len: int = 2048
    moe: bool = False  # True to enable the MoE (Mixture of Experts)
    num_experts: int = 8  # Number of experts
    num_activated_experts: int = 2  # Number of experts to activate
    use_kv_cache: bool = False  # Use key/value cache
    use_sdpa_with_kv_cache_op: bool = (
        False  # Use custom sdpa op that updates kv cache in-place
    )
    # Generate logits for all inputs. When True, this significantly increases memory
    # usage at runtime. Enable it only when necessary (e.g., for perplexity tools that
    # require logits for all input tokens).
    generate_full_logits: bool = False
    enable_dynamic_shape: bool = False  # export model with dynamic shape support
    # A dictionary mapping from pruned token-id to original token-id
    input_prune_map: Optional[Dict[int, int]] = None
    # A dictionary mapping from pruned token-id to original token-id
    output_prune_map: Optional[Dict[int, int]] = None
    use_hf_rope: bool = False  # Use HuggingFace's RoPE implementation
    rope_theta: Optional[float] = (
        None  # The official name for the RoPE base frequency; overrides rope_freq_base when set.
    )
    rope_freq_base: float = 10000.0  # The base frequency for RoPE. Keep it for BC.
    use_scaled_rope: bool = False  # Use scaled RoPE, introduced in llama3.1.
    # Additional Model Metadata needed at runtime
    bos_idx: int = 1
    eos_idx: int = 3
    bos_count: int = -1  # i.e., a single EOS is used as BOS
    eos_count: int = 2

    quantization_args: Optional[dict] = None
    lora_args: Optional[dict] = None

    def __post_init__(self):
        if self.n_kv_heads is None:
            self.n_kv_heads = self.n_heads

        # rope_theta overrides rope_freq_base since it's the official name.
        if self.rope_theta is not None:
            self.rope_freq_base = self.rope_theta

        if self.use_sdpa_with_kv_cache_op:
            assert self.use_kv_cache, "use_sdpa_with_kv_cache_op requires use_kv_cache"

        if self.hidden_dim is None:
            # If hidden_dim is not explicitly set in the ModelArgs, derive it from dim
            # and round it up to a multiple of `self.multiple_of`.
            multiple_of = self.multiple_of
            hidden_dim = 4 * self.dim
            hidden_dim = int(2 * hidden_dim / 3)
            if self.ffn_dim_multiplier is not None:
                hidden_dim = int(self.ffn_dim_multiplier * hidden_dim)
            self.hidden_dim = find_multiple(hidden_dim, multiple_of)
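            # Worked example (hypothetical values): dim=4096, multiple_of=256 and no
            # ffn_dim_multiplier gives 4 * 4096 = 16384 -> int(2 * 16384 / 3) = 10922,
            # which find_multiple rounds up to hidden_dim = 11008.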


class KVCache(nn.Module):
    def __init__(
        self,
        max_batch_size: int,
        max_seq_length: int,
        n_heads: int,
        head_dim: int,
        transpose_cache: bool,
        enable_dynamic_shape: bool,
        dtype=torch.float32,
    ):
        super().__init__()
        self.max_seq_length = max_seq_length
        self.is_transposed = transpose_cache
        if transpose_cache:
            cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        else:
            cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)

        self.max_batch_size = max_batch_size
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.transpose_cache = transpose_cache
        self.enable_dynamic_shape = enable_dynamic_shape
        self.register_buffer(
            "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
        )
        self.register_buffer(
            "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
        )

    def update(
        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
        if self.enable_dynamic_shape:
            start_pos = input_pos[0].item()
            torch._check_is_size(start_pos)
            torch._check(start_pos < self.max_seq_length)
            dim_to_slice = 2 if self.transpose_cache else 1
            seq_length = k_val.size(dim_to_slice)
            # Replace the entry in the cache for this token
            # The following lines are equivalent to:
            # cache_k[:bsz, start_pos : start_pos + seqlen] = xk
            # cache_v[:bsz, start_pos : start_pos + seqlen] = xv
            # when dim_to_slice is 1
            # We use .narrow() here to make the compiler happy
            # pyre-ignore: Incompatible parameter type [6]
            narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
            # pyre-ignore: Incompatible parameter type [6]
            narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)

            narrowed_k.copy_(k_val)
            narrowed_v.copy_(v_val)
            return self.k_cache, self.v_cache
        else:
            k_out = self.k_cache
            v_out = self.v_cache
            if self.transpose_cache:
                k_out[:, :, input_pos] = k_val
                v_out[:, :, input_pos] = v_val
            else:
                k_out[:, input_pos] = k_val
                v_out[:, input_pos] = v_val

            return k_out, v_out

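# Usage sketch (illustrative, hypothetical sizes): with transpose_cache=True the caches
# have shape (max_batch_size, n_heads, max_seq_length, head_dim), and one decode step
# writes a single position, e.g.
#   cache = KVCache(1, 128, n_heads=8, head_dim=64, transpose_cache=True,
#                   enable_dynamic_shape=False)
#   k, v = cache.update(torch.tensor([pos]), k_step, v_step)  # k_step/v_step: (1, 8, 1, 64)
# where `pos` is the integer write position; the full caches are returned.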

class SDPA(nn.Module):
    def __init__(
        self,
        kv_cache: KVCache,
        dim: int,
        head_dim: int,
        n_rep: int,
        max_seq_len: int,
        enable_dynamic_shape: bool,
    ):
        super().__init__()
        self.kv_cache = kv_cache
        self.dim = dim
        self.head_dim = head_dim
        self.n_rep = n_rep
        self.max_seq_len = max_seq_len
        self.enable_dynamic_shape = enable_dynamic_shape

    def forward(
        self,
        input_pos: torch.Tensor,
        q: torch.Tensor,  # Rotary embeddings already applied. (bs, seqlen, n_local_heads, head_dim)
        k: torch.Tensor,  # Rotary embeddings already applied. (bs, seqlen, n_local_kv_heads, head_dim)
        v: torch.Tensor,  # (bs, seqlen, n_local_kv_heads, head_dim)
        bsz,
        seqlen,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        k, v = self.kv_cache.update(input_pos, k, v)
        if self.enable_dynamic_shape:
            start_pos = input_pos[-1].item()
            torch._check_is_size(start_pos)
            torch._check(start_pos < self.max_seq_len)
            seq_length = q.size(2)
            # pyre-ignore: Incompatible parameter type [6]
            attn_mask = mask.narrow(0, start_pos, seq_length)
        else:
            attn_mask = mask[None, None, input_pos]

        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)

        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

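# Illustrative note: for grouped-query attention the KV heads are expanded to match the
# query heads via repeat_interleave, e.g. n_heads=32 and n_kv_heads=8 give n_rep=4, so
# k/v go from (bs, 8, cache_len, head_dim) to (bs, 32, cache_len, head_dim). The rows of
# the causal mask for the current positions are selected either with narrow() (dynamic
# shape) or by indexing with input_pos.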

class Attention(nn.Module):
    def __init__(self, args: ModelArgs, layer_id: int):
        super().__init__()
        self.use_kv_cache = args.use_kv_cache
        self.n_heads = args.n_heads
        self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert self.n_heads % self.n_kv_heads == 0
        model_parallel_size = 1
        self.n_local_heads = self.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // self.n_heads
        self.max_batch_size = args.max_batch_size
        self.max_seq_len = args.max_seq_len
        self.dim = args.dim
        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

        self.layer_id = layer_id

        causal_mask = torch.tril(
            torch.ones(
                self.max_seq_len,
                self.max_seq_len,
                dtype=torch.bool,
                device="cpu",
            )
        )
        self.register_buffer("mask", causal_mask, persistent=False)

        if self.use_kv_cache:
            self.kv_cache = KVCache(
                args.max_batch_size,
                args.max_seq_len,
                self.n_kv_heads,
                self.head_dim,
                not args.use_sdpa_with_kv_cache_op,  # If we are using the custom op, don't transpose the cache; it expects untransposed q, k, v.
                args.enable_dynamic_shape,
            )
            self.SDPA = SDPA(
                kv_cache=self.kv_cache,
                dim=self.dim,
                head_dim=self.head_dim,
                n_rep=self.n_rep,
                max_seq_len=self.max_seq_len,
                enable_dynamic_shape=args.enable_dynamic_shape,
            )
        if args.use_hf_rope:
            self.apply_rotary_emb = hf_apply_rotary_emb
        else:
            self.apply_rotary_emb = RotaryEmbedding()

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
    ):
        bsz, seqlen, _ = x.shape

        # QKV
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        # We need view_copy elimination
        q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # RoPE relative positional embeddings
        q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)

        if self.use_kv_cache:
            assert input_pos is not None
            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
            return self.wo(output)

        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # grouped multiquery attention: expand out keys and values
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)

        assert hasattr(self, "mask")

        mask = self.mask[:seqlen, :seqlen]

        output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        output = self.wo(output)

        return output

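# Shape sketch for the non-KV-cache path (hypothetical sizes): with dim=4096, n_heads=32
# and head_dim=128, an input x of shape (1, 16, 4096) yields q/k/v of shape
# (1, 16, heads, 128); after RoPE and transposing heads to dim 1, SDPA with the
# (16, 16) causal mask returns (1, 32, 16, 128), which is flattened back to
# (1, 16, 4096) before the output projection wo.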

class FeedForward(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        assert args.hidden_dim is not None
        hidden_dim: int = args.hidden_dim
        self.w1 = nn.Linear(args.dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, args.dim, bias=False)
        self.w3 = nn.Linear(args.dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

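# Illustrative note: this is the SwiGLU-style feed-forward used by the Llama architecture,
#   FFN(x) = w2(silu(w1(x)) * w3(x))
# with hidden_dim derived in ModelArgs (e.g. 11008 for dim=4096 and multiple_of=256).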

class ConditionalFeedForward(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        hidden_dim = args.hidden_dim
        if hidden_dim is None:
            # If hidden_dim is not explicitly set in the ModelArgs, derive it from dim
            # and round it up to a multiple of `args.multiple_of`.
            multiple_of = args.multiple_of
            hidden_dim = 4 * self.dim
            hidden_dim = int(2 * hidden_dim / 3)
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim))
        self.w2 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim))
        self.w3 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim))
        self.num_experts = args.num_experts

    def forward(self, x: torch.Tensor, expert_indices: torch.Tensor) -> torch.Tensor:
        w1_weights = self.w1[expert_indices].transpose(-1, -2)  # [T, A, D, H]
        w3_weights = self.w3[expert_indices].transpose(-1, -2)  # [T, A, D, H]
        w2_weights = self.w2[expert_indices]  # [T, A, H, D]
        x1 = F.silu(torch.einsum("ti,taio -> tao", x, w1_weights))
        x3 = torch.einsum("ti, taio -> tao", x, w3_weights)
        expert_outs = torch.einsum("tao, taoi -> tai", (x1 * x3), w2_weights)
        return expert_outs

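# Illustrative note (T = tokens, A = activated experts, D = model dim, H = hidden dim):
# indexing w1/w2/w3 with expert_indices gathers, for every token, the weights of its A
# selected experts, and the einsums above apply the SwiGLU feed-forward per token and
# per selected expert, producing expert_outs of shape [T, A, D].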

class MOEFeedForward(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.gate = nn.Linear(config.dim, config.num_experts, bias=False)
        self.cond_ffn = ConditionalFeedForward(config)
        self.dim = config.dim
        self.num_activated_experts = config.num_activated_experts

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.view(-1, self.dim)
        # T = num_tokens, E = num_experts, D = model (embedding) dim, A = activated experts
        # x: [T, D]
        scores = self.gate(x)  # [T, E]
        expert_weights, expert_indices = torch.topk(
            scores, self.num_activated_experts, dim=-1
        )  # [T, A], [T, A]
        expert_weights = expert_weights.softmax(dim=-1)  # [T, A]
        expert_outs = self.cond_ffn(x, expert_indices)
        return torch.einsum("tai,ta -> ti", expert_outs, expert_weights)

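# Illustrative note: the gate scores each token against all E experts, topk selects the
# A = num_activated_experts best experts per token, and their outputs are combined with
# softmax-normalized weights. The result is returned flattened as [T, D] with
# T = bsz * seqlen, which broadcasts correctly against the residual stream when bsz == 1.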

class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.use_kv_cache = args.use_kv_cache
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args, layer_id)
        if args.moe:
            self.block_sparse_moe = MOEFeedForward(args)
        else:
            self.feed_forward = FeedForward(args)
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x, freqs_cos, freqs_sin, input_pos=None):  # x: 1xN
        h = self.attention.forward(
            self.attention_norm(x), freqs_cos, freqs_sin, input_pos
        )

        h = x + h
        if hasattr(self, "block_sparse_moe"):
            out = h + self.block_sparse_moe(self.ffn_norm(h))
        else:
            out = h + self.feed_forward(self.ffn_norm(h))
        return out

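# Illustrative note: each TransformerBlock is a pre-norm residual block,
#   h   = x + Attention(RMSNorm(x))
#   out = h + FFN(RMSNorm(h))
# where FFN is the dense SwiGLU FeedForward, or MOEFeedForward when args.moe is True.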

class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        self.use_kv_cache = params.use_kv_cache
        self.generate_full_logits = params.generate_full_logits
        self.max_seq_len = params.max_seq_len
        self.input_prune_map = params.input_prune_map
        self.output_prune_map = params.output_prune_map
        if params.use_hf_rope:
            self.precompute_freqs_cis = hf_precompute_freqs_cis
        else:
            self.precompute_freqs_cis = partial(
                precompute_freqs_cis, use_scaled=params.use_scaled_rope
            )
        freqs_cos, freqs_sin = self.precompute_freqs_cis(
            params.dim // params.n_heads,
            (
                params.max_seq_len  # Normal llama2.
                if params.ffn_dim_multiplier is None
                else params.max_seq_len * 2  # Sharded checkpoint.
            ),
            params.rope_freq_base,
        )
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

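    # Note on forward() below: exactly one of `tokens` (token ids) or `h` (precomputed
    # embeddings) must be provided, and `input_pos` is required if and only if
    # use_kv_cache is True.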
    def forward(
        self,
        tokens: Optional[torch.LongTensor] = None,  # tokens
        input_pos: Optional[
            torch.LongTensor
        ] = None,  # Scalar tensor indicating size of window of the caches
        h: Optional[torch.FloatTensor] = None,  # embeddings
    ) -> torch.Tensor:
        if (tokens is None) ^ (h is not None):
            raise ValueError(
                "You must specify exactly one of tokens or h (not both, not neither)"
            )
        if tokens is not None and h is None:
            h = self.tok_embeddings(tokens)
        seqlen = h.shape[1]

        if self.use_kv_cache:
            assert (
                input_pos is not None
            ), "input_pos must be provided when use_kv_cache is True"

            if self.params.enable_dynamic_shape:
                # When the KV cache is used, seqlen is most likely 1. We want to slice from start_pos.
                input_pos_item = input_pos[-1].item()
                torch._check_is_size(input_pos_item)
                torch._check(input_pos_item < self.params.max_seq_len)
                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
                # pyre-ignore: Incompatible parameter type [6]
                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
            else:
                # When not using dynamic shape, calling .item() would produce symints
                # because it reads data out of the tensor. This path avoids that for the
                # MPS backend, although the MPS backend may be able to support dynamic
                # shape.
                freqs_cos = self.freqs_cos[input_pos]
                freqs_sin = self.freqs_sin[input_pos]

        else:
            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
            freqs_cos = self.freqs_cos[:seqlen]
            freqs_sin = self.freqs_sin[:seqlen]

        for layer in self.layers:
            h = layer(
                h,
                freqs_cos,
                freqs_sin,
                input_pos,
            )

        if not self.generate_full_logits:
            # Only the last logit is used for the new generated token
            h = h[:, -1, :]

        h = self.norm(h)

        logits = self.output(h)

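        # Illustrative example of the expansion below (hypothetical map): with
        # output_prune_map = {0: 5, 1: 9} and vocab_size = 32000, the two pruned
        # logits are scattered into columns 5 and 9 of a (-inf)-filled tensor of
        # the original vocabulary size.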
        if self.output_prune_map is not None:
            # expand to original size so that downstream applications can use the logits as-is.
            if self.generate_full_logits:
                # (1, seq_len, pruned_size) -> (1, seq_len, original_size)
                expanded_logits = torch.full(
                    [logits.shape[0], logits.shape[1], self.vocab_size],
                    float("-inf"),
                    device=logits.device,
                    dtype=logits.dtype,
                )
                expanded_logits[:, :, list(self.output_prune_map.values())] = logits
            else:
                # (1, pruned_size) -> (1, original_size)
                expanded_logits = torch.full(
                    [logits.shape[0], self.vocab_size],
                    float("-inf"),
                    device=logits.device,
                    dtype=logits.dtype,
                )
                expanded_logits[:, list(self.output_prune_map.values())] = logits
            logits = expanded_logits

        return logits

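# Usage sketch (illustrative, hypothetical small config; weights are random here):
#   args = ModelArgs(dim=64, n_layers=2, n_heads=4, vocab_size=256)
#   model = Transformer(args)
#   tokens = torch.randint(0, 256, (1, 16))
#   logits = model(tokens=tokens)  # (1, 256) since generate_full_logits defaults to False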