xref: /aosp_15_r20/external/executorch/examples/models/llama/rope.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# @lint-ignore-every LICENSELINT
2*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker#
5*523fa7a6SAndroid Build Coastguard Worker# Llama 2 is licensed under the LLAMA 2 Community License,
6*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
7*523fa7a6SAndroid Build Coastguard Worker
8*523fa7a6SAndroid Build Coastguard Worker# Different RoPE implementations
9*523fa7a6SAndroid Build Coastguard Worker
10*523fa7a6SAndroid Build Coastguard Workerimport math
11*523fa7a6SAndroid Build Coastguard Workerfrom typing import Tuple
12*523fa7a6SAndroid Build Coastguard Worker
13*523fa7a6SAndroid Build Coastguard Workerimport torch
14*523fa7a6SAndroid Build Coastguard Worker
15*523fa7a6SAndroid Build Coastguard Worker# ======================== Stock Implementation ========================
16*523fa7a6SAndroid Build Coastguard Worker
17*523fa7a6SAndroid Build Coastguard Worker
def apply_scaling(
    freqs: torch.Tensor,
    scale_factor: float = 8,
    low_freq_factor: float = 1,
    high_freq_factor: float = 4,
    old_context_len: int = 8192,
) -> torch.Tensor:
    """Rescale RoPE inverse frequencies with the Llama 3 long-context scheme.

    Each frequency is classified by its wavelength ``2 * pi / freq``:

    * ``wavelen < old_context_len / high_freq_factor``: kept unchanged.
    * ``wavelen > old_context_len / low_freq_factor``: divided by
      ``scale_factor``.
    * otherwise: linearly blended between ``freq / scale_factor`` and
      ``freq`` with weight ``smooth = (old_context_len / wavelen -
      low_freq_factor) / (high_freq_factor - low_freq_factor)``.

    The defaults are the previously hard-coded values (obtained from grid
    search; 8192 is the original llama3 context length), so callers that pass
    only ``freqs`` get the same results as before.

    Args:
        freqs: Inverse frequencies. Computed elementwise, so any shape works.
        scale_factor: Divisor applied to the low-frequency band.
        low_freq_factor: Sets the low-frequency wavelength cutoff
            ``old_context_len / low_freq_factor``.
        high_freq_factor: Sets the high-frequency wavelength cutoff
            ``old_context_len / high_freq_factor``. Must differ from
            ``low_freq_factor``.
        old_context_len: Context length the frequencies were tuned for.

    Returns:
        A new tensor of rescaled frequencies. For floating-point ``freqs`` it
        has the same shape, dtype and device as ``freqs``.

    Raises:
        ValueError: If ``high_freq_factor == low_freq_factor``, because the
            blend weight would divide by zero.
    """
    if high_freq_factor == low_freq_factor:
        raise ValueError(
            "high_freq_factor and low_freq_factor must differ, "
            f"got {high_freq_factor} for both"
        )
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    # Vectorized form of the per-element rule: every candidate is computed for
    # all elements using the same formulas and operation order, and
    # torch.where picks the one for each element's band. Boundary values
    # (wavelen exactly equal to a cutoff) fall into the blended band.
    wavelen = 2 * math.pi / freqs
    smooth = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    blended = (1 - smooth) * freqs / scale_factor + smooth * freqs
    return torch.where(
        wavelen < high_freq_wavelen,
        freqs,
        torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, blended),
    )
41*523fa7a6SAndroid Build Coastguard Worker
42*523fa7a6SAndroid Build Coastguard Worker
43*523fa7a6SAndroid Build Coastguard Workerdef precompute_freqs_cis(
44*523fa7a6SAndroid Build Coastguard Worker    dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
45*523fa7a6SAndroid Build Coastguard Worker):
46*523fa7a6SAndroid Build Coastguard Worker    freqs = 1.0 / (
47*523fa7a6SAndroid Build Coastguard Worker        theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
48*523fa7a6SAndroid Build Coastguard Worker    )
49*523fa7a6SAndroid Build Coastguard Worker    t = torch.arange(end, device=freqs.device)  # pyre-ignore
50*523fa7a6SAndroid Build Coastguard Worker    if use_scaled:
51*523fa7a6SAndroid Build Coastguard Worker        freqs = apply_scaling(freqs)  # pyre-ignore
52*523fa7a6SAndroid Build Coastguard Worker    freqs = torch.outer(t, freqs).float()
53*523fa7a6SAndroid Build Coastguard Worker    freqs_cos = torch.cos(freqs)
54*523fa7a6SAndroid Build Coastguard Worker    freqs_sin = torch.sin(freqs)
55*523fa7a6SAndroid Build Coastguard Worker    return freqs_cos, freqs_sin
56*523fa7a6SAndroid Build Coastguard Worker
57*523fa7a6SAndroid Build Coastguard Worker
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Return a view of ``freqs_cis`` that broadcasts against ``x``.

    Every axis of ``x`` that ``freqs_cis`` does not cover becomes size 1.

    * 3-D ``freqs_cis`` must equal the last three dims of ``x`` (for example
      ``(seq_len, n_heads, head_dim // 2)``). Those three axes are kept.
    * Otherwise ``freqs_cis`` must be ``(x.shape[1], x.shape[-1])`` (for
      example ``(seq_len, head_dim // 2)``). Axis 1 and the last axis are kept.

    Shape mismatches are reported with ``assert``.
    """
    rank = x.ndim
    if freqs_cis.ndim == 3:
        assert freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1])
        kept_axes = {rank - 3, rank - 2, rank - 1}
    else:
        assert freqs_cis.shape == (x.shape[1], x.shape[-1])
        kept_axes = {1, rank - 1}
    target_shape = [
        size if axis in kept_axes else 1 for axis, size in enumerate(x.shape)
    ]
    return freqs_cis.view(target_shape)
73*523fa7a6SAndroid Build Coastguard Worker
74*523fa7a6SAndroid Build Coastguard Worker
def apply_rotary_emb(
    xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to query and key tensors.

    The last dimension of ``xq``/``xk`` is treated as interleaved pairs: the
    elements at ``2i`` and ``2i + 1`` form (real, imag) and are rotated by the
    angle whose cosine/sine are stored in ``freqs_cos``/``freqs_sin``.

    Args:
        xq: Query tensor. Its last dim must be even.
        xk: Key tensor. Its last dim must be even.
        freqs_cos: Cosine table. It must satisfy ``reshape_for_broadcast``
            against the half-size query view ``xq.shape[:-1] +
            (head_dim // 2,)``. That means either 2-D ``(xq.shape[1],
            head_dim // 2)``, or 3-D matching that view's last three dims.
        freqs_sin: Sine table with the same shape as ``freqs_cos``.

    Returns:
        ``(xq_out, xk_out)`` with the same shapes and dtypes as ``xq``/``xk``.
    """
    # Work in float32; cast back to the input dtypes at the end.
    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

    # The broadcast layout is derived from the query view only. This assumes
    # xk matches xq along the axes that freqs_cos/freqs_sin cover.
    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

    # Complex multiply: (r + i*j) * (cos + sin*j).
    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    # Re-interleave the pairs: stack gives (..., head_dim // 2, 2), then the
    # last two axes merge back to head_dim. flatten(-2) replaces a fixed
    # flatten(3), which was only correct for 4-D inputs.
    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(-2)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(-2)

    return xq_out.type_as(xq), xk_out.type_as(xk)
93*523fa7a6SAndroid Build Coastguard Worker
94*523fa7a6SAndroid Build Coastguard Worker
class RotaryEmbedding(torch.nn.Module):
    """Stateless ``nn.Module`` wrapper around ``apply_rotary_emb``.

    It registers no parameters or buffers. ``forward`` delegates directly to
    the functional implementation in this file.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        xq: torch.Tensor,
        xk: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
    ):
        """Rotate ``xq`` and ``xk``; returns the tuple ``(xq_out, xk_out)``."""
        return apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
108*523fa7a6SAndroid Build Coastguard Worker
109*523fa7a6SAndroid Build Coastguard Worker
110*523fa7a6SAndroid Build Coastguard Worker# ======================= HuggingFace Implementation ========================
111*523fa7a6SAndroid Build Coastguard Worker
112*523fa7a6SAndroid Build Coastguard Worker
113*523fa7a6SAndroid Build Coastguard Worker# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L77
114*523fa7a6SAndroid Build Coastguard Workerdef hf_precompute_freqs_cis(dim: int, end: int, theta: float):
115*523fa7a6SAndroid Build Coastguard Worker    freqs = 1.0 / (
116*523fa7a6SAndroid Build Coastguard Worker        theta
117*523fa7a6SAndroid Build Coastguard Worker        ** (torch.arange(0, dim, 2, device="cpu", dtype=torch.int64).float() / dim)
118*523fa7a6SAndroid Build Coastguard Worker    )
119*523fa7a6SAndroid Build Coastguard Worker    # pyre-ignore Undefined attribute [16]: `float` has no attribute `device`.
120*523fa7a6SAndroid Build Coastguard Worker    t = torch.arange(end, device=freqs.device, dtype=torch.int64).type_as(
121*523fa7a6SAndroid Build Coastguard Worker        freqs  # pyre-ignore
122*523fa7a6SAndroid Build Coastguard Worker    )
123*523fa7a6SAndroid Build Coastguard Worker    freqs = torch.outer(t, freqs).float()  # pyre-ignore
124*523fa7a6SAndroid Build Coastguard Worker    emb = torch.cat((freqs, freqs), dim=-1)
125*523fa7a6SAndroid Build Coastguard Worker    freqs_cos = torch.cos(emb)
126*523fa7a6SAndroid Build Coastguard Worker    freqs_sin = torch.sin(emb)
127*523fa7a6SAndroid Build Coastguard Worker    return freqs_cos, freqs_sin
128*523fa7a6SAndroid Build Coastguard Worker
129*523fa7a6SAndroid Build Coastguard Worker
130*523fa7a6SAndroid Build Coastguard Worker# Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L135
def rotate_half(x):
    """Rotate half the hidden dims of the input.

    The last axis is split at ``n // 2`` into ``(x1, x2)`` and the result is
    ``cat(-x2, x1)``. For an odd last dim, ``x2`` is the longer half.
    """
    split = x.shape[-1] // 2
    first_half = x[..., :split]
    second_half = x[..., split:]
    return torch.cat((second_half.neg(), first_half), dim=-1)
136*523fa7a6SAndroid Build Coastguard Worker
137*523fa7a6SAndroid Build Coastguard Worker
def hf_apply_rotary_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embedding to query and key tensors (HF layout).

    Each output is ``x * cos + rotate_half(x) * sin``. Because ``rotate_half``
    splits the last axis in two halves, element ``j`` is paired with
    ``j + head_dim // 2``. This differs from ``apply_rotary_emb``, which pairs
    adjacent (interleaved) elements.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused. It is accepted only for signature
            compatibility and is never read.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis at which a size-1 dim is inserted into ``cos`` and ``sin`` so
            they broadcast against ``q`` and ``k``. For example, with ``cos``
            of shape ``[batch_size, seq_len, head_dim]``, use 1 when q/k are
            ``[batch_size, heads, seq_len, head_dim]`` and 2 when they are
            ``[batch_size, seq_len, heads, head_dim]``.

    Returns:
        `tuple(torch.Tensor)` of the rotated query and key tensors.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    q_rotated = q * cos_b + rotate_half(q) * sin_b
    k_rotated = k * cos_b + rotate_half(k) * sin_b
    return q_rotated, k_rotated
163