# @lint-ignore-every LICENSELINT
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# Llama 2 is licensed under the LLAMA 2 Community License,
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

# Different RoPE implementations

import math
from typing import Tuple

import torch

# ======================== Stock Implementation ========================


def apply_scaling(
    freqs: torch.Tensor,
    scale_factor: float = 8,
    low_freq_factor: float = 1,
    high_freq_factor: float = 4,
    old_context_len: int = 8192,
) -> torch.Tensor:
    """Rescale RoPE inverse frequencies for long-context use.

    Each frequency is classified by its wavelength ``2 * pi / freq``:

    * wavelength below ``old_context_len / high_freq_factor``: kept as-is;
    * wavelength above ``old_context_len / low_freq_factor``: divided by
      ``scale_factor``;
    * otherwise: blended as ``(1 - smooth) * freq / scale_factor +
      smooth * freq`` with ``smooth = (old_context_len / wavelen -
      low_freq_factor) / (high_freq_factor - low_freq_factor)``.

    The defaults are the values previously hard-coded here (obtained from a
    grid search; ``old_context_len`` is the original llama3 context length).

    Args:
        freqs: Tensor of inverse frequencies. The computation is elementwise,
            so any shape is accepted.
        scale_factor: Divisor applied to low-frequency components.
        low_freq_factor: ``old_context_len / low_freq_factor`` is the
            wavelength above which components are fully scaled.
        high_freq_factor: ``old_context_len / high_freq_factor`` is the
            wavelength below which components are left untouched.
        old_context_len: Reference context length for the thresholds.

    Returns:
        A new tensor with the same shape, dtype and device as ``freqs``.

    Raises:
        ValueError: if ``low_freq_factor == high_freq_factor`` (the blend
            weight would divide by zero).
    """
    if low_freq_factor == high_freq_factor:
        raise ValueError("low_freq_factor and high_freq_factor must differ")

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    # Same elementwise arithmetic as the former per-element Python loop,
    # evaluated for the whole tensor at once.
    wavelen = 2 * math.pi / freqs
    smooth = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    blended = (1 - smooth) * freqs / scale_factor + smooth * freqs

    # Select per element: high-frequency -> unchanged, low-frequency ->
    # fully scaled, in-between band -> blended.
    scaled = torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, blended)
    result = torch.where(wavelen < high_freq_wavelen, freqs, scaled)
    return result.to(dtype=freqs.dtype, device=freqs.device)


def precompute_freqs_cis(
    dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Precompute the cosine/sine tables for the stock RoPE implementation.

    Args:
        dim: Rotary dimension; ``dim // 2`` inverse frequencies
            ``theta ** (-2i / dim)`` are produced.
        end: Number of positions (``0 .. end - 1``) to tabulate.
        theta: RoPE base.
        use_scaled: If True, pass the inverse frequencies through
            ``apply_scaling`` (default hyper-parameters) first.

    Returns:
        ``(freqs_cos, freqs_sin)``, float32 CPU tensors of shape
        ``(end, dim // 2)``.
    """
    freqs = 1.0 / (
        theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
    )
    t = torch.arange(end, device=freqs.device)  # pyre-ignore
    if use_scaled:
        freqs = apply_scaling(freqs)  # pyre-ignore
    # Angle for every (position, frequency) pair.
    freqs = torch.outer(t, freqs).float()
    freqs_cos = torch.cos(freqs)
    freqs_sin = torch.sin(freqs)
    return freqs_cos, freqs_sin
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """View a RoPE cos/sin table so that it broadcasts against ``x``.

    Two table layouts are accepted:

    * 3-D ``(seq_len, n_heads, head_dim // 2)``: must equal the last three
      dimensions of ``x``; all leading dimensions of ``x`` become size 1.
    * any other rank, expected ``(seq_len, head_dim // 2)``: must equal
      ``(x.shape[1], x.shape[-1])``; every other dimension becomes size 1.

    Raises:
        AssertionError: if ``freqs_cis`` does not match ``x`` as described.
    """
    ndim = x.ndim
    if freqs_cis.ndim == 3:
        # freqs_cis: (seq_len, n_heads, head_dim // 2)
        assert freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1])
        shape = [d if i >= ndim - 3 else 1 for i, d in enumerate(x.shape)]
    else:
        # freqs_cis: (seq_len, head_dim // 2)
        assert freqs_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(shape)


def apply_rotary_emb(
    xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query/key tensors with the stock (interleaved-pair) RoPE.

    The last dimension of ``xq``/``xk`` is read as consecutive
    ``(real, imag)`` pairs. Each pair is rotated by the angle whose cosine and
    sine are given by ``freqs_cos``/``freqs_sin``. Math is done in float32
    and the results are cast back to the input dtypes.

    Args:
        xq: Query tensor with an even last dimension.
        xk: Key tensor with an even last dimension; it must broadcast against
            the tables as viewed for ``xq``.
        freqs_cos: Cosine table in a layout accepted by
            ``reshape_for_broadcast`` (validated against ``xq``).
        freqs_sin: Sine table, same layout as ``freqs_cos``.

    Returns:
        ``(xq_out, xk_out)`` with the dtypes of ``xq``/``xk``. The last two
        (pair) axes are flattened back, so the input rank is preserved.
    """
    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

    # Complex multiply (r + i*j) * (cos + sin*j) in real arithmetic.
    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    # Re-interleave the pairs. Flatten the trailing (head_dim // 2, 2) axes
    # rather than "from dim 3" so non-4-D inputs keep their shape.
    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(-2)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(-2)

    return xq_out.type_as(xq), xk_out.type_as(xk)


class RotaryEmbedding(torch.nn.Module):
    """``nn.Module`` wrapper around ``apply_rotary_emb``; holds no parameters."""

    def __init__(self):
        super().__init__()

    def forward(
        self,
        xq: torch.Tensor,
        xk: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)``."""
        xq_out, xk_out = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
        return xq_out, xk_out
# ======================= HuggingFace Implementation ========================


# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L77
def hf_precompute_freqs_cis(dim: int, end: int, theta: float):
    """Build the HuggingFace-layout RoPE cosine/sine tables.

    Inverse frequencies ``theta ** (-k / dim)`` for ``k = 0, 2, 4, ...`` are
    multiplied by every position ``0 .. end - 1``. The angle table is then
    concatenated with itself along the last axis (the "rotate half" layout)
    before cos/sin are taken.

    Returns:
        ``(freqs_cos, freqs_sin)``: float32 CPU tensors of shape
        ``(end, 2 * len(range(0, dim, 2)))``, i.e. ``(end, dim)`` for even
        ``dim``.
    """
    exponents = torch.arange(0, dim, 2, device="cpu", dtype=torch.int64).float() / dim
    inv_freq = 1.0 / (theta**exponents)
    # pyre-ignore Undefined attribute [16]: `float` has no attribute `device`.
    positions = torch.arange(end, device=inv_freq.device, dtype=torch.int64)
    positions = positions.type_as(inv_freq)  # pyre-ignore
    angles = torch.outer(positions, inv_freq).float()  # pyre-ignore
    doubled = torch.cat((angles, angles), dim=-1)
    return torch.cos(doubled), torch.sin(doubled)


# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L135
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    Splits the last dimension into a front half and a back half and returns
    ``cat(-back, front)``. For an odd size, the front half is the shorter one.
    """
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat((-back, front), dim=-1)


def hf_apply_rotary_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Computes ``x * cos + rotate_half(x) * sin`` for ``x`` in ``(q, k)`` after
    inserting a size-1 axis into ``cos``/``sin`` at ``unsqueeze_dim``.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused; ``cos``/``sin`` are used exactly as given.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis at which a size-1 dimension is inserted into ``cos`` and
            ``sin`` so they broadcast against ``q`` and ``k``. For example,
            with ``cos``/``sin`` of shape [batch_size, seq_len, head_dim] and
            ``q``/``k`` of shape [batch_size, heads, seq_len, head_dim], use
            1. With ``q``/``k`` of shape [batch_size, seq_len, heads,
            head_dim], use 2.

    Returns:
        `tuple(torch.Tensor)` of the rotated query and key tensors.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    rotated_q = q * cos_b + rotate_half(q) * sin_b
    rotated_k = k * cos_b + rotate_half(k) * sin_b
    return rotated_q, rotated_k