# @lint-ignore-every LICENSELINT
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# Llama 2 is licensed under the LLAMA 2 Community License,
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

# Please refer to README.md in the same folder for more information.

from dataclasses import dataclass
from functools import partial
from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F

from executorch.examples.models.llama.rope import (
    hf_apply_rotary_emb,
    hf_precompute_freqs_cis,
    precompute_freqs_cis,
    RotaryEmbedding,
)

from torch import nn


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.

        """
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.

        """
        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.

        """
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

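# NOTE: Illustrative sketch, not part of the original module. It is a reference
# check for the RMSNorm math above: each vector along the last dimension is scaled
# by the reciprocal root-mean-square of its elements, then by the learned `weight`.
# The helper name `_rmsnorm_reference_example` is made up for demonstration.
def _rmsnorm_reference_example() -> None:
    torch.manual_seed(0)
    x = torch.randn(2, 4, 8)
    norm = RMSNorm(dim=8, eps=1e-6)
    expected = x * torch.rsqrt((x * x).mean(-1, keepdim=True) + 1e-6) * norm.weight
    assert torch.allclose(norm(x), expected, atol=1e-6)
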
def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)


@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1  # defined later by tokenizer
    hidden_dim: Optional[int] = None
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    max_batch_size: int = 32
    max_seq_len: int = 2048
    moe: bool = False  # True to enable the MoE (Mixture of Experts)
    num_experts: int = 8  # Number of experts
    num_activated_experts: int = 2  # Number of experts to activate
    use_kv_cache: bool = False  # Use key/value cache
    use_sdpa_with_kv_cache_op: bool = (
        False  # Use custom sdpa op that updates kv cache in-place
    )
    # Generate logits for all inputs. When True, this significantly increases memory
    # usage at runtime. Enable it only when necessary (e.g., when using perplexity
    # tools that require logits for all input tokens).
    generate_full_logits: bool = False
    enable_dynamic_shape: bool = False  # export model with dynamic shape support
    # A dictionary mapping from pruned token-id to original token-id
    input_prune_map: Optional[Dict[int, int]] = None
    # A dictionary mapping from pruned token-id to original token-id
    output_prune_map: Optional[Dict[int, int]] = None
    use_hf_rope: bool = False  # Use HuggingFace's RoPE implementation
    rope_theta: Optional[float] = (
        None  # The official name to override self.rope_freq_base.
    )
    rope_freq_base: float = 10000.0  # The base frequency for RoPE. Keep it for BC.
    use_scaled_rope: bool = False  # Use scaled RoPE, introduced in llama3.1.
    # Additional Model Metadata needed at runtime
    bos_idx: int = 1
    eos_idx: int = 3
    bos_count: int = -1  # i.e., a single EOS is used as BOS
    eos_count: int = 2

    quantization_args: Optional[dict] = None
    lora_args: Optional[dict] = None

    def __post_init__(self):
        if self.n_kv_heads is None:
            self.n_kv_heads = self.n_heads

        # rope_theta overrides rope_freq_base since it's the official name.
        if self.rope_theta is not None:
            self.rope_freq_base = self.rope_theta

        if self.use_sdpa_with_kv_cache_op:
            assert self.use_kv_cache, "use_sdpa_with_kv_cache_op requires use_kv_cache"

        if self.hidden_dim is None:
            # If hidden_dim is not explicitly set in the ModelArgs, derive it
            # implicitly from dim and round it up to a multiple of `args.multiple_of`.
            multiple_of = self.multiple_of
            hidden_dim = 4 * self.dim
            hidden_dim = int(2 * hidden_dim / 3)
            if self.ffn_dim_multiplier is not None:
                hidden_dim = int(self.ffn_dim_multiplier * hidden_dim)
            self.hidden_dim = find_multiple(hidden_dim, multiple_of)

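# NOTE: Illustrative sketch, not part of the original module. It walks through how
# __post_init__ above derives the SwiGLU hidden dimension when `hidden_dim` is left
# unset: for the default dim=4096, 4 * 4096 = 16384, two thirds of that is 10922,
# and rounding up to the next multiple of 256 gives 11008. The helper name and the
# vocab_size value are made up for demonstration.
def _example_default_hidden_dim() -> int:
    args = ModelArgs(vocab_size=32000)  # vocab_size is normally set by the tokenizer
    assert args.hidden_dim == find_multiple(int(2 * 4 * args.dim / 3), args.multiple_of)
    assert args.hidden_dim == 11008
    return args.hidden_dim
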
class KVCache(nn.Module):
    def __init__(
        self,
        max_batch_size: int,
        max_seq_length: int,
        n_heads: int,
        head_dim: int,
        transpose_cache: bool,
        enable_dynamic_shape: bool,
        dtype=torch.float32,
    ):
        super().__init__()
        self.max_seq_length = max_seq_length
        self.is_transposed = transpose_cache
        if transpose_cache:
            cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        else:
            cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)

        self.max_batch_size = max_batch_size
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.transpose_cache = transpose_cache
        self.enable_dynamic_shape = enable_dynamic_shape
        self.register_buffer(
            "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
        )
        self.register_buffer(
            "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu")
        )

    def update(
        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
        if self.enable_dynamic_shape:
            start_pos = input_pos[0].item()
            torch._check_is_size(start_pos)
            torch._check(start_pos < self.max_seq_length)
            dim_to_slice = 2 if self.transpose_cache else 1
            seq_length = k_val.size(dim_to_slice)
            # Replace the entry in the cache for this token
            # The following lines are equivalent to:
            # cache_k[:bsz, start_pos : start_pos + seqlen] = xk
            # cache_v[:bsz, start_pos : start_pos + seqlen] = xv
            # when dim_to_slice is 1
            # We use .narrow() here to make the compiler happy
            # pyre-ignore: Incompatible parameter type [6]
            narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
            # pyre-ignore: Incompatible parameter type [6]
            narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)

            narrowed_k.copy_(k_val)
            narrowed_v.copy_(v_val)
            return self.k_cache, self.v_cache
        else:
            k_out = self.k_cache
            v_out = self.v_cache
            if self.transpose_cache:
                k_out[:, :, input_pos] = k_val
                v_out[:, :, input_pos] = v_val
            else:
                k_out[:, input_pos] = k_val
                v_out[:, input_pos] = v_val

            return k_out, v_out

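# NOTE: Illustrative sketch, not part of the original module. It shows a single
# decode-step write into a transposed (batch, heads, seq, head_dim) cache:
# `input_pos` carries the write position, and update() returns the full caches,
# which the attention op then reads from. The helper name is made up.
def _example_kv_cache_update() -> None:
    cache = KVCache(
        max_batch_size=1,
        max_seq_length=16,
        n_heads=2,
        head_dim=4,
        transpose_cache=True,
        enable_dynamic_shape=False,
    )
    k_val = torch.randn(1, 2, 1, 4)  # (bs, n_heads, seqlen=1, head_dim)
    v_val = torch.randn(1, 2, 1, 4)
    k_out, v_out = cache.update(torch.tensor([5]), k_val, v_val)
    assert torch.equal(k_out[:, :, 5:6], k_val)
    assert torch.equal(v_out[:, :, 5:6], v_val)
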
class SDPA(nn.Module):
    def __init__(
        self,
        kv_cache: KVCache,
        dim: int,
        head_dim: int,
        n_rep: int,
        max_seq_len: int,
        enable_dynamic_shape: bool,
    ):
        super().__init__()
        self.kv_cache = kv_cache
        self.dim = dim
        self.head_dim = head_dim
        self.n_rep = n_rep
        self.max_seq_len = max_seq_len
        self.enable_dynamic_shape = enable_dynamic_shape

    def forward(
        self,
        input_pos: torch.Tensor,
        q: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
        k: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
        v: torch.Tensor,  # (bs, seqlen, n_local_kv_heads, head_dim)
        bsz,
        seqlen,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        k, v = self.kv_cache.update(input_pos, k, v)
        if self.enable_dynamic_shape:
            start_pos = input_pos[-1].item()
            torch._check_is_size(start_pos)
            torch._check(start_pos < self.max_seq_len)
            seq_length = q.size(2)
            # pyre-ignore: Incompatible parameter type [6]
            attn_mask = mask.narrow(0, start_pos, seq_length)
        else:
            attn_mask = mask[None, None, input_pos]

        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)

        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)


class Attention(nn.Module):
    def __init__(self, args: ModelArgs, layer_id: int):
        super().__init__()
        self.use_kv_cache = args.use_kv_cache
        self.n_heads = args.n_heads
        self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert self.n_heads % self.n_kv_heads == 0
        model_parallel_size = 1
        self.n_local_heads = self.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // self.n_heads
        self.max_batch_size = args.max_batch_size
        self.max_seq_len = args.max_seq_len
        self.dim = args.dim
        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

        self.layer_id = layer_id

        causal_mask = torch.tril(
            torch.ones(
                self.max_seq_len,
                self.max_seq_len,
                dtype=torch.bool,
                device="cpu",
            )
        )
        self.register_buffer("mask", causal_mask, persistent=False)

        if self.use_kv_cache:
            self.kv_cache = KVCache(
                args.max_batch_size,
                args.max_seq_len,
                self.n_kv_heads,
                self.head_dim,
                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op, don't transpose the cache. Expect untransposed q k v
                args.enable_dynamic_shape,
            )
            self.SDPA = SDPA(
                kv_cache=self.kv_cache,
                dim=self.dim,
                head_dim=self.head_dim,
                n_rep=self.n_rep,
                max_seq_len=self.max_seq_len,
                enable_dynamic_shape=args.enable_dynamic_shape,
            )
        if args.use_hf_rope:
            self.apply_rotary_emb = hf_apply_rotary_emb
        else:
            self.apply_rotary_emb = RotaryEmbedding()

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
    ):
        bsz, seqlen, _ = x.shape

        # QKV
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        # We need view_copy elimination
        q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # RoPE relative positional embeddings
        q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)

        if self.use_kv_cache:
            assert input_pos is not None
            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
            return self.wo(output)

        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # grouped multiquery attention: expand out keys and values
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)

        assert hasattr(self, "mask")

        mask = self.mask[:seqlen, :seqlen]

        output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        output = self.wo(output)

        return output

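# NOTE: Illustrative sketch, not part of the original module. It demonstrates the
# grouped-query expansion used by Attention and SDPA above: with 8 query heads and
# 2 KV heads, n_rep = 4, and repeat_interleave(n_rep, dim=1) copies each KV head so
# that every group of 4 query heads attends over the same keys/values.
def _example_grouped_query_expansion() -> None:
    n_heads, n_kv_heads = 8, 2
    n_rep = n_heads // n_kv_heads  # 4
    k = torch.randn(1, n_kv_heads, 16, 64)  # (bs, n_kv_heads, seqlen, head_dim)
    k_expanded = k.repeat_interleave(n_rep, dim=1)
    assert k_expanded.shape == (1, n_heads, 16, 64)
    # Heads 0..3 of the expanded tensor are copies of KV head 0.
    assert torch.equal(k_expanded[:, 0], k[:, 0])
    assert torch.equal(k_expanded[:, 3], k[:, 0])
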
class FeedForward(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        assert args.hidden_dim is not None
        hidden_dim: int = args.hidden_dim
        self.w1 = nn.Linear(args.dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, args.dim, bias=False)
        self.w3 = nn.Linear(args.dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class ConditionalFeedForward(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        hidden_dim = args.hidden_dim
        if hidden_dim is None:
            # If hidden_dim is not explicitly set in the ModelArgs, derive it
            # implicitly from dim and round it up to a multiple of `args.multiple_of`.
            multiple_of = args.multiple_of
            hidden_dim = 4 * self.dim
            hidden_dim = int(2 * hidden_dim / 3)
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim))
        self.w2 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim))
        self.w3 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim))
        self.num_experts = args.num_experts

    def forward(self, x: torch.Tensor, expert_indices: torch.Tensor) -> torch.Tensor:
        w1_weights = self.w1[expert_indices].transpose(-1, -2)  # [T, A, dim, hidden_dim]
        w3_weights = self.w3[expert_indices].transpose(-1, -2)  # [T, A, dim, hidden_dim]
        w2_weights = self.w2[expert_indices]  # [T, A, hidden_dim, dim]
        x1 = F.silu(torch.einsum("ti,taio -> tao", x, w1_weights))
        x3 = torch.einsum("ti, taio -> tao", x, w3_weights)
        expert_outs = torch.einsum("tao, taoi -> tai", (x1 * x3), w2_weights)
        return expert_outs


class MOEFeedForward(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.gate = nn.Linear(config.dim, config.num_experts, bias=False)
        self.cond_ffn = ConditionalFeedForward(config)
        self.dim = config.dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.view(-1, self.dim)
        # T = num_tokens, E = num_experts, D = model dim, A = activated experts
        # x: [T, D]
        scores = self.gate(x)  # [T, E]
        expert_weights, expert_indices = torch.topk(scores, 2, dim=-1)  # [T, A], [T, A]
        expert_weights = expert_weights.softmax(dim=-1)  # [T, A]
        expert_outs = self.cond_ffn(x, expert_indices)
        return torch.einsum("tai,ta -> ti", expert_outs, expert_weights)

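# NOTE: Illustrative sketch, not part of the original module. It isolates the top-2
# routing performed in MOEFeedForward.forward above: each token keeps only its two
# highest-scoring experts, renormalizes those two scores with a softmax, and mixes
# the corresponding expert outputs with the resulting weights.
def _example_top2_routing() -> None:
    num_tokens, num_experts = 3, 8
    scores = torch.randn(num_tokens, num_experts)  # stands in for self.gate(x)
    expert_weights, expert_indices = torch.topk(scores, 2, dim=-1)  # [T, 2], [T, 2]
    expert_weights = expert_weights.softmax(dim=-1)  # per-token weights sum to 1
    assert torch.allclose(expert_weights.sum(dim=-1), torch.ones(num_tokens))
    assert expert_indices.shape == (num_tokens, 2)
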
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.use_kv_cache = args.use_kv_cache
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args, layer_id)
        if args.moe:
            self.block_sparse_moe = MOEFeedForward(args)
        else:
            self.feed_forward = FeedForward(args)
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x, freqs_cos, freqs_sin, input_pos=None):  # x: 1xN
        h = self.attention.forward(
            self.attention_norm(x), freqs_cos, freqs_sin, input_pos
        )

        h = x + h
        if hasattr(self, "block_sparse_moe"):
            out = h + self.block_sparse_moe(self.ffn_norm(h))
        else:
            out = h + self.feed_forward(self.ffn_norm(h))
        return out


class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        self.use_kv_cache = params.use_kv_cache
        self.generate_full_logits = params.generate_full_logits
        self.max_seq_len = params.max_seq_len
        self.input_prune_map = params.input_prune_map
        self.output_prune_map = params.output_prune_map
        if params.use_hf_rope:
            self.precompute_freqs_cis = hf_precompute_freqs_cis
        else:
            self.precompute_freqs_cis = partial(
                precompute_freqs_cis, use_scaled=params.use_scaled_rope
            )
        freqs_cos, freqs_sin = self.precompute_freqs_cis(
            params.dim // params.n_heads,
            (
                params.max_seq_len  # Normal llama2.
                if params.ffn_dim_multiplier is None
                else params.max_seq_len * 2  # Sharded checkpoint.
            ),
            params.rope_freq_base,
        )
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(
        self,
        tokens: Optional[torch.LongTensor] = None,  # tokens
        input_pos: Optional[
            torch.LongTensor
        ] = None,  # Scalar tensor indicating the size of the window of the caches
        h: Optional[torch.FloatTensor] = None,  # embeddings
    ) -> torch.Tensor:
        if (tokens is None) ^ (h is not None):
            raise ValueError(
                "You must specify exactly one of tokens and h (not both, and not neither)"
            )
        if tokens is not None and h is None:
            h = self.tok_embeddings(tokens)
        seqlen = h.shape[1]

        if self.use_kv_cache:
            assert (
                input_pos is not None
            ), "input_pos must be provided when use_kv_cache is True"

            if self.params.enable_dynamic_shape:
                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
                input_pos_item = input_pos[-1].item()
                torch._check_is_size(input_pos_item)
                torch._check(input_pos_item < self.params.max_seq_len)
                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
                # pyre-ignore: Incompatible parameter type [6]
                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
            else:
                # When not using dynamic shape, calling .item() would result in
                # symints because it queries data from the tensor.
                # This path avoids that for the MPS backend, although the MPS
                # backend can probably support dynamic shape.
                freqs_cos = self.freqs_cos[input_pos]
                freqs_sin = self.freqs_sin[input_pos]

        else:
            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
            freqs_cos = self.freqs_cos[:seqlen]
            freqs_sin = self.freqs_sin[:seqlen]

        for layer in self.layers:
            h = layer(
                h,
                freqs_cos,
                freqs_sin,
                input_pos,
            )

        if not self.generate_full_logits:
            # Only the last logit is used for the new generated token
            h = h[:, -1, :]

        h = self.norm(h)

        logits = self.output(h)

        if self.output_prune_map is not None:
            # expand to original size so that downstream applications can use the logits as-is.
            if self.generate_full_logits:
                # (1, seq_len, pruned_size) -> (1, seq_len, original_size)
                expanded_logits = torch.full(
                    [logits.shape[0], logits.shape[1], self.vocab_size],
                    float("-inf"),
                    device=logits.device,
                    dtype=logits.dtype,
                )
                expanded_logits[:, :, list(self.output_prune_map.values())] = logits
            else:
                # (1, pruned_size) -> (1, original_size)
                expanded_logits = torch.full(
                    [logits.shape[0], self.vocab_size],
                    float("-inf"),
                    device=logits.device,
                    dtype=logits.dtype,
                )
                expanded_logits[:, list(self.output_prune_map.values())] = logits
            logits = expanded_logits

        return logits

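# NOTE: Illustrative sketch, not part of the original module. It builds a tiny,
# randomly initialized model and runs one forward pass without a KV cache. The
# configuration values below are made up for demonstration; real checkpoints
# provide their own ModelArgs.
def _example_tiny_forward() -> torch.Tensor:
    args = ModelArgs(
        dim=64,
        n_layers=2,
        n_heads=4,
        vocab_size=128,
        max_seq_len=32,
        use_kv_cache=False,
    )
    model = Transformer(args)
    tokens = torch.randint(0, args.vocab_size, (1, 8))
    logits = model(tokens)
    # generate_full_logits defaults to False, so only the last position's logits remain.
    assert logits.shape == (1, args.vocab_size)
    return logits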