xref: /aosp_15_r20/external/executorch/examples/models/llama/source_transformation/vulkan_rope.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import executorch.backends.vulkan.custom_ops_lib  # noqa
8import torch
9
10from executorch.examples.models.llama.rope import RotaryEmbedding
11
12
class VkRotaryEmbedding(torch.nn.Module):
    """Rotary-embedding module that delegates to the ``et_vk`` custom op.

    The module holds no parameters or buffers of its own; ``forward`` simply
    hands its arguments to ``torch.ops.et_vk.apply_rotary_emb``.

    NOTE(review): the ``et_vk`` operator is presumably registered by the
    ``executorch.backends.vulkan.custom_ops_lib`` import at the top of this
    file -- confirm before reordering or removing that import.
    """

    def __init__(self):
        """Initialise the underlying ``torch.nn.Module``; no state is added."""
        super().__init__()

    def forward(
        self,
        xq: torch.Tensor,
        xk: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
    ):
        """Apply the rotary embedding op to the query and key tensors.

        Args:
            xq: Query tensor, passed to the op unchanged.
            xk: Key tensor, passed to the op unchanged.
            freqs_cos: Cosine frequency tensor, passed to the op unchanged.
            freqs_sin: Sine frequency tensor, passed to the op unchanged.

        Returns:
            A ``(xq_out, xk_out)`` pair: the two results produced by
            ``torch.ops.et_vk.apply_rotary_emb``, in that order.
        """
        rotated_q, rotated_k = torch.ops.et_vk.apply_rotary_emb(
            xq,
            xk,
            freqs_cos,
            freqs_sin,
        )
        return rotated_q, rotated_k
26
27
def replace_with_vulkan_rotary_emb(module: torch.nn.Module):
    """Swap every ``RotaryEmbedding`` submodule for a ``VkRotaryEmbedding``.

    The module tree is walked depth-first through ``named_children()``. Each
    child that is a ``RotaryEmbedding`` instance is replaced, under the same
    attribute name, by a freshly constructed ``VkRotaryEmbedding``; every
    other child is descended into recursively. The swap is done in place on
    the parent module.

    Note that ``module`` itself is never replaced, only its descendants.

    Args:
        module: Root of the module tree to transform (mutated in place).

    Returns:
        The same ``module`` object that was passed in.
    """
    for attr_name, submodule in module.named_children():
        if not isinstance(submodule, RotaryEmbedding):
            # Not a match: keep searching further down this branch.
            replace_with_vulkan_rotary_emb(submodule)
            continue

        # Re-assigning an existing child name swaps the registered submodule.
        setattr(module, attr_name, VkRotaryEmbedding())

    return module
37