1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <gtest/gtest.h>
10
11 #include <ATen/ATen.h>
12
13 #include <executorch/backends/vulkan/runtime/api/api.h>
14 #include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
15 #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
16
17 #include <cassert>
18
19 //
20 // Reference Implementations
21 //
22
rotary_embedding_impl(const at::Tensor & xq,const at::Tensor & xk,const at::Tensor & freqs_cos,const at::Tensor & freqs_sin)23 std::pair<at::Tensor, at::Tensor> rotary_embedding_impl(
24 const at::Tensor& xq,
25 const at::Tensor& xk,
26 const at::Tensor& freqs_cos,
27 const at::Tensor& freqs_sin) {
28 std::vector<at::Tensor> xq_even_odd = at::unbind(
29 xq.reshape({xq.size(0), xq.size(1), xq.size(2), xq.size(3) / 2, 2}), -1);
30 at::Tensor& xq_r = xq_even_odd[0];
31 at::Tensor& xq_i = xq_even_odd[1];
32
33 std::vector<at::Tensor> xk_even_odd = at::unbind(
34 xk.reshape({xk.size(0), xk.size(1), xk.size(2), xk.size(3) / 2, 2}), -1);
35 at::Tensor& xk_r = xk_even_odd[0];
36 at::Tensor& xk_i = xk_even_odd[1];
37
38 at::Tensor freqs_cos_reshape =
39 freqs_cos.reshape({1, freqs_cos.size(0), 1, freqs_cos.size(1)});
40 at::Tensor freqs_sin_reshape =
41 freqs_sin.reshape({1, freqs_sin.size(0), 1, freqs_sin.size(1)});
42
43 at::Tensor xq_out_r = xq_r * freqs_cos_reshape - xq_i * freqs_sin_reshape;
44 at::Tensor xq_out_i = xq_r * freqs_sin_reshape + xq_i * freqs_cos_reshape;
45 at::Tensor xk_out_r = xk_r * freqs_cos_reshape - xk_i * freqs_sin_reshape;
46 at::Tensor xk_out_i = xk_r * freqs_sin_reshape + xk_i * freqs_cos_reshape;
47
48 at::Tensor xq_out = at::flatten(at::stack({xq_out_r, xq_out_i}, -1), 3);
49 at::Tensor xk_out = at::flatten(at::stack({xk_out_r, xk_out_i}, -1), 3);
50
51 return std::make_pair(xq_out, xk_out);
52 }
53
54 //
55 // Test functions
56 //
57
from_at_scalartype(c10::ScalarType at_scalartype)58 vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
59 using namespace vkcompute;
60 switch (at_scalartype) {
61 case c10::kFloat:
62 return vkapi::kFloat;
63 case c10::kHalf:
64 return vkapi::kHalf;
65 case c10::kInt:
66 return vkapi::kInt;
67 case c10::kLong:
68 return vkapi::kInt;
69 case c10::kChar:
70 return vkapi::kChar;
71 case c10::kByte:
72 return vkapi::kByte;
73 default:
74 VK_THROW("Unsupported at::ScalarType!");
75 }
76 }
77
test_reference(const int n_heads=4,const int n_kv_heads=2,const int dim=32,const int seq_len=1)78 void test_reference(
79 const int n_heads = 4,
80 const int n_kv_heads = 2,
81 const int dim = 32,
82 const int seq_len = 1) {
83 const int head_dim = dim / n_heads;
84
85 at::Tensor xq = at::rand(
86 {1, seq_len, n_heads, head_dim}, at::device(at::kCPU).dtype(at::kFloat));
87 at::Tensor xk = at::rand(
88 {1, seq_len, n_kv_heads, head_dim},
89 at::device(at::kCPU).dtype(at::kFloat));
90 at::Tensor freqs_cos =
91 at::rand({seq_len, head_dim / 2}, at::device(at::kCPU).dtype(at::kFloat));
92 at::Tensor freqs_sin =
93 at::rand({seq_len, head_dim / 2}, at::device(at::kCPU).dtype(at::kFloat));
94
95 std::pair<at::Tensor, at::Tensor> outs =
96 rotary_embedding_impl(xq, xk, freqs_cos, freqs_sin);
97 at::Tensor& xq_out = outs.first;
98 at::Tensor& xk_out = outs.second;
99
100 // Build Vulkan graph
101 using namespace vkcompute;
102
103 GraphConfig config;
104 config.set_storage_type_override(utils::kTexture3D);
105 ComputeGraph graph(config);
106
107 #define MAKE_INPUT_FOR(x) \
108 IOValueRef r_##x = graph.add_input_tensor( \
109 x.sizes().vec(), from_at_scalartype(x.scalar_type()));
110
111 MAKE_INPUT_FOR(xq);
112 MAKE_INPUT_FOR(xk);
113 MAKE_INPUT_FOR(freqs_cos);
114 MAKE_INPUT_FOR(freqs_sin);
115
116 const ValueRef r_xq_out = graph.add_tensor(
117 xq_out.sizes().vec(), from_at_scalartype(xq_out.scalar_type()));
118 const ValueRef r_xk_out = graph.add_tensor(
119 xk_out.sizes().vec(), from_at_scalartype(xk_out.scalar_type()));
120
121 VK_GET_OP_FN("et_vk.apply_rotary_emb.default")
122 (graph,
123 {r_xq.value,
124 r_xk.value,
125 r_freqs_cos.value,
126 r_freqs_sin.value,
127 graph.add_value_list({r_xq_out, r_xk_out})});
128
129 ValueRef staging_xq_out = graph.set_output_tensor(r_xq_out);
130 ValueRef staging_xk_out = graph.set_output_tensor(r_xk_out);
131
132 graph.prepare();
133 graph.encode_prepack();
134 graph.prepack();
135 graph.encode_execute();
136
137 //
138 // Run model
139 //
140
141 graph.propagate_resize();
142 graph.copy_into_staging(r_xq.staging, xq.const_data_ptr(), xq.numel());
143 graph.copy_into_staging(r_xk.staging, xk.const_data_ptr(), xk.numel());
144 graph.copy_into_staging(
145 r_freqs_cos.staging, freqs_cos.const_data_ptr(), freqs_cos.numel());
146 graph.copy_into_staging(
147 r_freqs_sin.staging, freqs_sin.const_data_ptr(), freqs_sin.numel());
148
149 graph.execute();
150
151 at::Tensor vk_xq_out = at::empty_like(xq_out);
152 graph.copy_from_staging(
153 staging_xq_out, vk_xq_out.mutable_data_ptr(), vk_xq_out.numel());
154
155 at::Tensor vk_xk_out = at::empty_like(xk_out);
156 graph.copy_from_staging(
157 staging_xk_out, vk_xk_out.mutable_data_ptr(), vk_xk_out.numel());
158
159 EXPECT_TRUE(at::allclose(xq_out, vk_xq_out, 1e-4, 1e-4));
160 EXPECT_TRUE(at::allclose(xk_out, vk_xk_out, 1e-4, 1e-4));
161 }
162
TEST(VulkanRotaryEmbeddingTest,rotary_embedding_test)163 TEST(VulkanRotaryEmbeddingTest, rotary_embedding_test) {
164 test_reference();
165 }
166
TEST(VulkanRotaryEmbeddingTest,rotary_embedding_llama3_params_test)167 TEST(VulkanRotaryEmbeddingTest, rotary_embedding_llama3_params_test) {
168 test_reference(
169 /*n_heads=*/32,
170 /*n_kv_heads=*/8,
171 /*dim=*/2048);
172 }
173
TEST(VulkanRotaryEmbeddingTest,rotary_embedding_llama3_params_test_seq_len_3)174 TEST(VulkanRotaryEmbeddingTest, rotary_embedding_llama3_params_test_seq_len_3) {
175 test_reference(
176 /*n_heads=*/32,
177 /*n_kv_heads=*/8,
178 /*dim=*/2048,
179 /*seq_len=*/3);
180 }
181