xref: /aosp_15_r20/external/executorch/examples/models/llama/rope.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# @lint-ignore-every LICENSELINT
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# Llama 2 is licensed under the LLAMA 2 Community License,
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

# Different RoPE implementations

10import math
11from typing import Tuple
12
13import torch
14
# ======================== Stock Implementation ========================
16
17
def apply_scaling(
    freqs: torch.Tensor,
    *,
    scale_factor: float = 8,
    low_freq_factor: float = 1,
    high_freq_factor: float = 4,
    old_context_len: int = 8192,
) -> torch.Tensor:
    """Rescale RoPE inverse frequencies for long-context (Llama 3 style) RoPE.

    Each frequency is classified by its wavelength ``2 * pi / freq``:

    * wavelength < ``old_context_len / high_freq_factor``: kept unchanged;
    * wavelength > ``old_context_len / low_freq_factor``: divided by
      ``scale_factor``;
    * otherwise: blended as ``(1 - smooth) * freq / scale_factor + smooth * freq``
      with ``smooth = (old_context_len / wavelength - low_freq_factor)
      / (high_freq_factor - low_freq_factor)``.

    The defaults are the values previously hard-coded here ("obtained from grid
    search"; 8192 is the original llama3 context length), so calling with only
    ``freqs`` behaves as before.

    Args:
        freqs: Inverse-frequency tensor of any shape (1-D in this file).
        scale_factor: Divisor applied to low-frequency components.
        low_freq_factor: Lower bound of the smoothing band (as a factor of
            ``old_context_len``).
        high_freq_factor: Upper bound of the smoothing band.
        old_context_len: Context length the model was originally trained with.

    Returns:
        A tensor with the same shape, dtype and device as ``freqs``.

    Raises:
        ValueError: If ``high_freq_factor <= low_freq_factor``, which would make
            the smoothing band empty and its denominator zero or negative.
    """
    if high_freq_factor <= low_freq_factor:
        raise ValueError(
            "high_freq_factor must be greater than low_freq_factor, got "
            f"{high_freq_factor} <= {low_freq_factor}"
        )

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    # Elementwise equivalent of the former per-element Python loop. It avoids
    # one tiny tensor op per frequency and the rebuild of a tensor from a Python
    # list, and it works for any input shape.
    wavelen = 2 * math.pi / freqs
    smooth = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    smoothed = (1 - smooth) * freqs / scale_factor + smooth * freqs
    return torch.where(
        wavelen < high_freq_wavelen,
        freqs,
        torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, smoothed),
    )
41
42
def precompute_freqs_cis(
    dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
):
    """Build the cosine/sine RoPE tables consumed by ``apply_rotary_emb``.

    Args:
        dim: Rotary dimension. ``dim // 2`` frequencies are generated; for an
            odd ``dim`` the extra trailing frequency is dropped.
        end: Number of positions (rows) in each table.
        theta: RoPE base.
        use_scaled: When True, the inverse frequencies are first passed through
            ``apply_scaling``.

    Returns:
        ``(freqs_cos, freqs_sin)``. Each is a float32 tensor of shape
        ``(end, dim // 2)`` holding ``cos``/``sin`` of
        ``position * inv_freq[j]``. The frequencies are created on CPU.
    """
    half = dim // 2
    exponents = torch.arange(0, dim, 2, device="cpu")[:half].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    if use_scaled:
        inv_freq = apply_scaling(inv_freq)  # pyre-ignore
    positions = torch.arange(end, device=inv_freq.device)  # pyre-ignore
    angles = torch.outer(positions, inv_freq).float()
    return torch.cos(angles), torch.sin(angles)
56
57
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Return a view of ``freqs_cis`` shaped so it broadcasts against ``x``.

    Two table layouts are accepted:

    * 3-D ``(seq_len, n_heads, head_dim // 2)``: must equal the last three
      dims of ``x``. Every leading dim of ``x`` becomes 1 in the view.
    * any other rank (expected 2-D ``(seq_len, head_dim // 2)``): must equal
      ``(x.shape[1], x.shape[-1])``. Every other dim of ``x`` becomes 1.

    Shape mismatches raise ``AssertionError``. NOTE: these are plain ``assert``
    statements, so the checks are skipped under ``python -O``.
    """
    rank = x.ndim
    if freqs_cis.ndim != 3:
        # freqs_cis: (seq_len, head_dim // 2) -> keep x's dim 1 and last dim.
        assert freqs_cis.shape == (x.shape[1], x.shape[-1])
        kept_axes = {1, rank - 1}
    else:
        # freqs_cis: (seq_len, n_heads, head_dim // 2) -> keep x's last 3 dims.
        assert freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1])
        kept_axes = {rank - 3, rank - 2, rank - 1}
    target = [size if axis in kept_axes else 1 for axis, size in enumerate(x.shape)]
    return freqs_cis.view(target)
73
74
def apply_rotary_emb(
    xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to queries and keys (interleaved layout).

    The last dimension of ``xq``/``xk`` is split into adjacent pairs
    ``(x[..., 2j], x[..., 2j + 1])``. Each pair is rotated as
    ``(r, i) -> (r * cos - i * sin, r * sin + i * cos)``. The math runs in
    float32 and the results are cast back to the input dtypes.

    Args:
        xq: Query tensor. Its last dimension must be even. NOTE(review): with a
            2-D ``freqs_cos``, ``reshape_for_broadcast`` matches dim 1 against
            seq_len, so this presumably expects
            ``(batch, seq_len, n_heads, head_dim)`` — confirm against callers.
        xk: Key tensor. The tables are reshaped against ``xq`` only, so ``xk``
            must broadcast against that same view.
        freqs_cos: Cosine table, ``(seq_len, head_dim // 2)`` or
            ``(seq_len, n_heads, head_dim // 2)``.
        freqs_sin: Sine table with the same shape as ``freqs_cos``.

    Returns:
        ``(xq_out, xk_out)`` with the same shapes and dtypes as ``xq``/``xk``.
    """
    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    # Re-interleave the (real, imag) pairs. Flattening the last two dims, rather
    # than starting from a fixed dim 3, restores the input shape for any rank.
    # For 4-D inputs it is identical to the former `.flatten(3)`.
    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(-2)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(-2)

    return xq_out.type_as(xq), xk_out.type_as(xk)
93
94
class RotaryEmbedding(torch.nn.Module):
    """Stateless ``nn.Module`` wrapper around ``apply_rotary_emb``.

    It registers no parameters or buffers. The cos/sin tables are passed in on
    every call.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        xq: torch.Tensor,
        xk: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
    ):
        """Return ``(rotated_xq, rotated_xk)`` as produced by ``apply_rotary_emb``."""
        rotated_q, rotated_k = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
        return rotated_q, rotated_k
108
109
# ======================= HuggingFace Implementation ========================
111
112
# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L77
def hf_precompute_freqs_cis(dim: int, end: int, theta: float):
    """Build HuggingFace-layout cosine/sine RoPE tables.

    The per-position angles ``position * inv_freq[j]`` are concatenated with
    themselves along the last axis. This matches the "rotate half" convention
    used by ``hf_apply_rotary_emb``.

    Args:
        dim: Rotary dimension. Frequencies are built from
            ``arange(0, dim, 2)`` without truncation, so an odd ``dim`` yields
            a last axis of size ``dim + 1``.
        end: Number of positions (rows) in each table.
        theta: RoPE base.

    Returns:
        ``(freqs_cos, freqs_sin)``. Each is a float32 tensor of shape
        ``(end, dim)`` for even ``dim``. The frequencies are created on CPU.
    """
    even_idx = torch.arange(0, dim, 2, device="cpu", dtype=torch.int64).float()
    inv_freq = 1.0 / (theta ** (even_idx / dim))
    # pyre-ignore Undefined attribute [16]: `float` has no attribute `device`.
    positions = torch.arange(end, device=inv_freq.device, dtype=torch.int64).type_as(
        inv_freq  # pyre-ignore
    )
    angles = torch.outer(positions, inv_freq).float()  # pyre-ignore
    doubled = torch.cat([angles, angles], dim=-1)
    return doubled.cos(), doubled.sin()
128
129
# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L135
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Return ``cat(-second_half, first_half)`` along the last dimension of ``x``.

    The split point is ``x.shape[-1] // 2``, so for an odd size the first half
    is the shorter one.
    """
    split = x.shape[-1] // 2
    front = x[..., :split]
    back = x[..., split:]
    return torch.cat((back.neg(), front), dim=-1)
136
137
def hf_apply_rotary_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embedding ("rotate half" layout) to ``q`` and ``k``.

    A size-1 axis is first inserted into ``cos`` and ``sin`` at
    ``unsqueeze_dim``. Then each of ``q`` and ``k`` is transformed as
    ``t * cos + rotate_half(t) * sin``.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table, presumably from ``hf_precompute_freqs_cis``
            (``(seq_len, head_dim)``) — TODO confirm against callers.
        sin: Sine table with the same shape as ``cos``.
        position_ids: Deprecated and unused. It is kept only so existing call
            signatures keep working; no indexing by position happens here.
        unsqueeze_dim: Axis at which the size-1 dim is inserted into
            ``cos``/``sin`` so they broadcast against ``q``/``k``. For example,
            a 2-D ``(seq_len, head_dim)`` table with the default ``1`` becomes
            ``(seq_len, 1, head_dim)``. That lines up with ``q``/``k`` laid out
            as ``(..., seq_len, n_heads, head_dim)``. NOTE(review): pick the
            value that matches the caller's q/k layout.

    Returns:
        ``(q_embed, k_embed)``: the rotated query and key tensors.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    q_embed, k_embed = ((t * cos_b) + (rotate_half(t) * sin_b) for t in (q, k))
    return q_embed, k_embed
163