# @lint-ignore-every LICENSELINT
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# Llama 2 is licensed under the LLAMA 2 Community License,
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

# Different RoPE implementations

import math
from typing import Tuple

import torch

# ======================== Stock Implementation ========================


def apply_scaling(freqs: torch.Tensor):
    # Values obtained from grid search
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192  # original llama3 length

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            new_freqs.append(freq / scale_factor)
        else:
            assert low_freq_wavelen != high_freq_wavelen
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


def precompute_freqs_cis(
    dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
):
    freqs = 1.0 / (
        theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
    )
    t = torch.arange(end, device=freqs.device)  # pyre-ignore
    if use_scaled:
        freqs = apply_scaling(freqs)  # pyre-ignore
    freqs = torch.outer(t, freqs).float()
    freqs_cos = torch.cos(freqs)
    freqs_sin = torch.sin(freqs)
    return freqs_cos, freqs_sin


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    freqs_cis_ndim = freqs_cis.ndim
    if freqs_cis_ndim == 3:
        # freqs_cis: (seq_len, n_heads, head_dim // 2)
        assert freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1])
        shape = [
            d if (i == ndim - 3 or i == ndim - 2 or i == ndim - 1) else 1
            for i, d in enumerate(x.shape)
        ]
    else:
        # freqs_cis: (seq_len, head_dim // 2)
        assert freqs_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(shape)


def apply_rotary_emb(
    xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)


class RotaryEmbedding(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(
        self,
        xq: torch.Tensor,
        xk: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
    ):
        xq_out, xk_out = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
        return xq_out, xk_out


# ======================= HuggingFace Implementation ========================


# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L77
def hf_precompute_freqs_cis(dim: int, end: int, theta: float):
    freqs = 1.0 / (
        theta
        ** (torch.arange(0, dim, 2, device="cpu", dtype=torch.int64).float() / dim)
    )
    # pyre-ignore Undefined attribute [16]: `float` has no attribute `device`.
    t = torch.arange(end, device=freqs.device, dtype=torch.int64).type_as(
        freqs  # pyre-ignore
    )
    freqs = torch.outer(t, freqs).float()  # pyre-ignore
    emb = torch.cat((freqs, freqs), dim=-1)
    freqs_cos = torch.cos(emb)
    freqs_sin = torch.sin(emb)
    return freqs_cos, freqs_sin


# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L135
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def hf_apply_rotary_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to
            unsqueeze cos[position_ids] and sin[position_ids] so that they can be
            properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape
            [batch_size, seq_len, head_dim]. Then, if q and k have the shape
            [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1
            makes cos[position_ids] and sin[position_ids] broadcastable to the
            shapes of q and k. Similarly, if q and k have the shape
            [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated
        using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
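

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the upstream module): runs both RoPE
# variants on random tensors to show the expected input/output shapes. The
# batch size, sequence length, head count, and head dimension below are
# arbitrary assumptions chosen only for this demonstration. Note that the two
# variants pair channels differently (the stock code rotates adjacent
# (even, odd) channel pairs, while the HuggingFace rotate_half code pairs
# channel i with channel i + head_dim // 2), so their outputs are not expected
# to match element-wise for the same inputs.
if __name__ == "__main__":
    bsz, seq_len, n_heads, head_dim = 2, 16, 4, 64  # assumed toy dimensions

    # Stock implementation: xq/xk are (bsz, seq_len, n_heads, head_dim) and the
    # precomputed tables are (seq_len, head_dim // 2).
    xq = torch.randn(bsz, seq_len, n_heads, head_dim)
    xk = torch.randn(bsz, seq_len, n_heads, head_dim)
    freqs_cos, freqs_sin = precompute_freqs_cis(head_dim, seq_len, use_scaled=False)
    xq_out, xk_out = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
    print("stock:", xq_out.shape, xk_out.shape)  # -> (2, 16, 4, 64) each

    # HuggingFace-style implementation: q/k are (bsz, n_heads, seq_len, head_dim)
    # and cos/sin are (seq_len, head_dim); add a leading batch dim so that
    # unsqueeze_dim=1 broadcasts them across the head dimension.
    q = xq.transpose(1, 2)
    k = xk.transpose(1, 2)
    hf_cos, hf_sin = hf_precompute_freqs_cis(head_dim, seq_len, theta=10000.0)
    q_embed, k_embed = hf_apply_rotary_emb(q, k, hf_cos[None], hf_sin[None])
    print("hf:", q_embed.shape, k_embed.shape)  # -> (2, 4, 16, 64) each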