// xref: /aosp_15_r20/external/executorch/backends/vulkan/test/op_tests/rotary_embedding_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 #include <gtest/gtest.h>
10 
11 #include <ATen/ATen.h>
12 
13 #include <executorch/backends/vulkan/runtime/api/api.h>
14 #include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
15 #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
16 
17 #include <cassert>
18 
//
// Reference Implementations
//
22 
rotary_embedding_impl(const at::Tensor & xq,const at::Tensor & xk,const at::Tensor & freqs_cos,const at::Tensor & freqs_sin)23 std::pair<at::Tensor, at::Tensor> rotary_embedding_impl(
24     const at::Tensor& xq,
25     const at::Tensor& xk,
26     const at::Tensor& freqs_cos,
27     const at::Tensor& freqs_sin) {
28   std::vector<at::Tensor> xq_even_odd = at::unbind(
29       xq.reshape({xq.size(0), xq.size(1), xq.size(2), xq.size(3) / 2, 2}), -1);
30   at::Tensor& xq_r = xq_even_odd[0];
31   at::Tensor& xq_i = xq_even_odd[1];
32 
33   std::vector<at::Tensor> xk_even_odd = at::unbind(
34       xk.reshape({xk.size(0), xk.size(1), xk.size(2), xk.size(3) / 2, 2}), -1);
35   at::Tensor& xk_r = xk_even_odd[0];
36   at::Tensor& xk_i = xk_even_odd[1];
37 
38   at::Tensor freqs_cos_reshape =
39       freqs_cos.reshape({1, freqs_cos.size(0), 1, freqs_cos.size(1)});
40   at::Tensor freqs_sin_reshape =
41       freqs_sin.reshape({1, freqs_sin.size(0), 1, freqs_sin.size(1)});
42 
43   at::Tensor xq_out_r = xq_r * freqs_cos_reshape - xq_i * freqs_sin_reshape;
44   at::Tensor xq_out_i = xq_r * freqs_sin_reshape + xq_i * freqs_cos_reshape;
45   at::Tensor xk_out_r = xk_r * freqs_cos_reshape - xk_i * freqs_sin_reshape;
46   at::Tensor xk_out_i = xk_r * freqs_sin_reshape + xk_i * freqs_cos_reshape;
47 
48   at::Tensor xq_out = at::flatten(at::stack({xq_out_r, xq_out_i}, -1), 3);
49   at::Tensor xk_out = at::flatten(at::stack({xk_out_r, xk_out_i}, -1), 3);
50 
51   return std::make_pair(xq_out, xk_out);
52 }
53 
//
// Test functions
//
57 
from_at_scalartype(c10::ScalarType at_scalartype)58 vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
59   using namespace vkcompute;
60   switch (at_scalartype) {
61     case c10::kFloat:
62       return vkapi::kFloat;
63     case c10::kHalf:
64       return vkapi::kHalf;
65     case c10::kInt:
66       return vkapi::kInt;
67     case c10::kLong:
68       return vkapi::kInt;
69     case c10::kChar:
70       return vkapi::kChar;
71     case c10::kByte:
72       return vkapi::kByte;
73     default:
74       VK_THROW("Unsupported at::ScalarType!");
75   }
76 }
77 
test_reference(const int n_heads=4,const int n_kv_heads=2,const int dim=32,const int seq_len=1)78 void test_reference(
79     const int n_heads = 4,
80     const int n_kv_heads = 2,
81     const int dim = 32,
82     const int seq_len = 1) {
83   const int head_dim = dim / n_heads;
84 
85   at::Tensor xq = at::rand(
86       {1, seq_len, n_heads, head_dim}, at::device(at::kCPU).dtype(at::kFloat));
87   at::Tensor xk = at::rand(
88       {1, seq_len, n_kv_heads, head_dim},
89       at::device(at::kCPU).dtype(at::kFloat));
90   at::Tensor freqs_cos =
91       at::rand({seq_len, head_dim / 2}, at::device(at::kCPU).dtype(at::kFloat));
92   at::Tensor freqs_sin =
93       at::rand({seq_len, head_dim / 2}, at::device(at::kCPU).dtype(at::kFloat));
94 
95   std::pair<at::Tensor, at::Tensor> outs =
96       rotary_embedding_impl(xq, xk, freqs_cos, freqs_sin);
97   at::Tensor& xq_out = outs.first;
98   at::Tensor& xk_out = outs.second;
99 
100   // Build Vulkan graph
101   using namespace vkcompute;
102 
103   GraphConfig config;
104   config.set_storage_type_override(utils::kTexture3D);
105   ComputeGraph graph(config);
106 
107 #define MAKE_INPUT_FOR(x)                    \
108   IOValueRef r_##x = graph.add_input_tensor( \
109       x.sizes().vec(), from_at_scalartype(x.scalar_type()));
110 
111   MAKE_INPUT_FOR(xq);
112   MAKE_INPUT_FOR(xk);
113   MAKE_INPUT_FOR(freqs_cos);
114   MAKE_INPUT_FOR(freqs_sin);
115 
116   const ValueRef r_xq_out = graph.add_tensor(
117       xq_out.sizes().vec(), from_at_scalartype(xq_out.scalar_type()));
118   const ValueRef r_xk_out = graph.add_tensor(
119       xk_out.sizes().vec(), from_at_scalartype(xk_out.scalar_type()));
120 
121   VK_GET_OP_FN("et_vk.apply_rotary_emb.default")
122   (graph,
123    {r_xq.value,
124     r_xk.value,
125     r_freqs_cos.value,
126     r_freqs_sin.value,
127     graph.add_value_list({r_xq_out, r_xk_out})});
128 
129   ValueRef staging_xq_out = graph.set_output_tensor(r_xq_out);
130   ValueRef staging_xk_out = graph.set_output_tensor(r_xk_out);
131 
132   graph.prepare();
133   graph.encode_prepack();
134   graph.prepack();
135   graph.encode_execute();
136 
137   //
138   // Run model
139   //
140 
141   graph.propagate_resize();
142   graph.copy_into_staging(r_xq.staging, xq.const_data_ptr(), xq.numel());
143   graph.copy_into_staging(r_xk.staging, xk.const_data_ptr(), xk.numel());
144   graph.copy_into_staging(
145       r_freqs_cos.staging, freqs_cos.const_data_ptr(), freqs_cos.numel());
146   graph.copy_into_staging(
147       r_freqs_sin.staging, freqs_sin.const_data_ptr(), freqs_sin.numel());
148 
149   graph.execute();
150 
151   at::Tensor vk_xq_out = at::empty_like(xq_out);
152   graph.copy_from_staging(
153       staging_xq_out, vk_xq_out.mutable_data_ptr(), vk_xq_out.numel());
154 
155   at::Tensor vk_xk_out = at::empty_like(xk_out);
156   graph.copy_from_staging(
157       staging_xk_out, vk_xk_out.mutable_data_ptr(), vk_xk_out.numel());
158 
159   EXPECT_TRUE(at::allclose(xq_out, vk_xq_out, 1e-4, 1e-4));
160   EXPECT_TRUE(at::allclose(xk_out, vk_xk_out, 1e-4, 1e-4));
161 }
162 
TEST(VulkanRotaryEmbeddingTest, rotary_embedding_test) {
  // Small configuration. The values equal test_reference's defaults and are
  // spelled out so the shape under test is visible at the call site.
  test_reference(
      /*n_heads=*/4,
      /*n_kv_heads=*/2,
      /*dim=*/32,
      /*seq_len=*/1);
}
166 
TEST(VulkanRotaryEmbeddingTest, rotary_embedding_llama3_params_test) {
  // Head counts and model dim named for a Llama 3 configuration, evaluated
  // at a single sequence position.
  test_reference(
      /*n_heads=*/32,
      /*n_kv_heads=*/8,
      /*dim=*/2048,
      /*seq_len=*/1);
}
173 
TEST(VulkanRotaryEmbeddingTest, rotary_embedding_llama3_params_test_seq_len_3) {
  // Same head configuration as rotary_embedding_llama3_params_test, but with
  // three positions so the frequency tables have more than one row.
  constexpr int kNHeads = 32;
  constexpr int kNKvHeads = 8;
  constexpr int kDim = 2048;
  constexpr int kSeqLen = 3;
  test_reference(kNHeads, kNKvHeads, kDim, kSeqLen);
}
181