# @lint-ignore-every LICENSELINT
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# Llama 2 is licensed under the LLAMA 2 Community License,
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

# Different RoPE implementations

import math
from typing import Tuple

import torch

# ======================== Stock Implementation ========================


def apply_scaling(freqs: torch.Tensor):
    # Values obtained from grid search
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192  # original llama3 length

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            new_freqs.append(freq / scale_factor)
        else:
            assert low_freq_wavelen != high_freq_wavelen
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


def precompute_freqs_cis(
    dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
):
    freqs = 1.0 / (
        theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
    )
    t = torch.arange(end, device=freqs.device)  # pyre-ignore
    if use_scaled:
        freqs = apply_scaling(freqs)  # pyre-ignore
    freqs = torch.outer(t, freqs).float()
    freqs_cos = torch.cos(freqs)
    freqs_sin = torch.sin(freqs)
    return freqs_cos, freqs_sin


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    freqs_cis_ndim = freqs_cis.ndim
    if freqs_cis_ndim == 3:
        # freqs_cis: (seq_len, n_heads, head_dim // 2)
        assert freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1])
        shape = [
            d if (i == ndim - 3 or i == ndim - 2 or i == ndim - 1) else 1
            for i, d in enumerate(x.shape)
        ]
    else:
        # freqs_cis: (seq_len, head_dim // 2)
        assert freqs_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(shape)


def apply_rotary_emb(
    xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)


class RotaryEmbedding(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(
        self,
        xq: torch.Tensor,
        xk: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
    ):
        xq_out, xk_out = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
        return xq_out, xk_out
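

# Illustrative sketch (not part of the original module): a minimal end-to-end use
# of the stock implementation above. The helper name `_example_stock_rope` and the
# shapes below (batch=2, seq_len=16, n_heads=4, head_dim=8) are hypothetical values
# chosen only for demonstration.
def _example_stock_rope() -> Tuple[torch.Tensor, torch.Tensor]:
    bsz, seq_len, n_heads, head_dim = 2, 16, 4, 8

    # cos/sin tables of shape (seq_len, head_dim // 2): adjacent channels of each
    # head are treated as interleaved (real, imaginary) pairs sharing one frequency.
    freqs_cos, freqs_sin = precompute_freqs_cis(head_dim, seq_len, theta=10000.0)

    # Queries/keys in (batch, seq_len, n_heads, head_dim) layout, which matches the
    # 2-D freqs_cis branch of reshape_for_broadcast (dim 1 is seq_len).
    xq = torch.randn(bsz, seq_len, n_heads, head_dim)
    xk = torch.randn(bsz, seq_len, n_heads, head_dim)

    rope = RotaryEmbedding()
    xq_out, xk_out = rope(xq, xk, freqs_cos, freqs_sin)
    assert xq_out.shape == xq.shape and xk_out.shape == xk.shape
    return xq_out, xk_out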


# ======================= HuggingFace Implementation ========================


# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L77
def hf_precompute_freqs_cis(dim: int, end: int, theta: float):
    freqs = 1.0 / (
        theta
        ** (torch.arange(0, dim, 2, device="cpu", dtype=torch.int64).float() / dim)
    )
    # pyre-ignore Undefined attribute [16]: `float` has no attribute `device`.
    t = torch.arange(end, device=freqs.device, dtype=torch.int64).type_as(
        freqs  # pyre-ignore
    )
    freqs = torch.outer(t, freqs).float()  # pyre-ignore
    emb = torch.cat((freqs, freqs), dim=-1)
    freqs_cos = torch.cos(emb)
    freqs_sin = torch.sin(emb)
    return freqs_cos, freqs_sin


# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L135
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def hf_apply_rotary_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
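

# Illustrative sketch (not part of the original module): applying the HuggingFace-style
# rotary embedding above. The helper name `_example_hf_rope` and the shapes below
# (batch=2, n_heads=4, seq_len=16, head_dim=8) are hypothetical values chosen only to
# show the (batch, heads, seq_len, head_dim) layout implied by the default unsqueeze_dim=1.
def _example_hf_rope() -> Tuple[torch.Tensor, torch.Tensor]:
    bsz, n_heads, seq_len, head_dim = 2, 4, 16, 8
    max_pos = 32  # precompute more positions than are currently needed

    # cos/sin tables of shape (max_pos, head_dim): each frequency is duplicated
    # across the two halves of the head dimension to match the rotate_half layout.
    cos, sin = hf_precompute_freqs_cis(head_dim, max_pos, 10000.0)

    # Gather the rows for the positions in use and add a broadcastable batch dim,
    # giving (1, seq_len, head_dim); unsqueeze_dim=1 then inserts the heads dim.
    position_ids = torch.arange(seq_len)
    cos = cos[position_ids].unsqueeze(0)
    sin = sin[position_ids].unsqueeze(0)

    q = torch.randn(bsz, n_heads, seq_len, head_dim)
    k = torch.randn(bsz, n_heads, seq_len, head_dim)
    q_embed, k_embed = hf_apply_rotary_emb(q, k, cos, sin, unsqueeze_dim=1)
    assert q_embed.shape == q.shape and k_embed.shape == k.shape
    return q_embed, k_embed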