# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Imported only for its side effects (hence the noqa); presumably this
# registers the torch.ops.et_vk.* custom operators used below -- confirm.
import executorch.backends.vulkan.custom_ops_lib  # noqa
import torch

from executorch.examples.models.llama.rope import RotaryEmbedding


class VkRotaryEmbedding(torch.nn.Module):
    """Rotary-embedding module backed by the Vulkan ``et_vk.apply_rotary_emb`` op.

    ``replace_with_vulkan_rotary_emb`` swaps this in for ``RotaryEmbedding``
    instances. It defines no parameters or buffers of its own; ``forward``
    simply hands its four tensor arguments to the custom operator.

    NOTE(review): assumes ``RotaryEmbedding.forward`` takes the same
    ``(xq, xk, freqs_cos, freqs_sin)`` arguments -- confirm against rope.py.
    """

    def forward(
        self,
        xq: torch.Tensor,
        xk: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
    ):
        """Rotate ``xq`` and ``xk`` using the Vulkan custom operator.

        Returns:
            A ``(xq_out, xk_out)`` tuple built from the op's two outputs.
        """
        rotated = torch.ops.et_vk.apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
        # Unpack explicitly so callers always receive a plain 2-tuple.
        q_rot, k_rot = rotated
        return q_rot, k_rot


def replace_with_vulkan_rotary_emb(module: torch.nn.Module) -> torch.nn.Module:
    """Swap every nested ``RotaryEmbedding`` for a fresh ``VkRotaryEmbedding``.

    The traversal mutates ``module`` in place:
    - Each direct child that is a ``RotaryEmbedding`` is replaced via
      ``setattr``. The replacement is not descended into.
    - Every other child is searched recursively.
    - ``module`` itself is never replaced, even if it is a
      ``RotaryEmbedding``.

    Args:
        module: Root module whose descendants are searched.

    Returns:
        The same ``module`` object, to allow call chaining.
    """
    # NOTE(review): named_children() yields each distinct child object only
    # once per parent, so a single RotaryEmbedding instance registered under
    # two attribute names of the same parent is swapped only under the first
    # name -- confirm this never happens in practice.
    for attr_name, submodule in module.named_children():
        if not isinstance(submodule, RotaryEmbedding):
            replace_with_vulkan_rotary_emb(submodule)
            continue
        setattr(module, attr_name, VkRotaryEmbedding())

    return module