/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/MatMul.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/RepeatInterleave.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Slice.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Softmax.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Transpose.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

#include <cmath>

namespace vkcompute {

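// Appends the newly projected K or V tensor to the cache at the position
// given by input_pos_symint. The shader variant is selected based on the
// storage type and dtype of the projected tensor.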
void add_kv_cache_update_node(
    ComputeGraph& graph,
    const ValueRef input_pos_symint,
    const ValueRef projected,
    const ValueRef cache) {
  std::string kernel_name("kv_cache_update");
  add_storage_type_suffix(kernel_name, graph.storage_type_of(projected));
  add_dtype_suffix(kernel_name, graph.dtype_of(projected));

  utils::uvec3 global_size;
  vkapi::ParamsBindList param_ubos;

  if (graph.is_buffer_storage(cache)) {
    global_size = graph.create_global_wg_size(projected);

    param_ubos = {
        graph.numel_ubo(projected),
        graph.strides_ubo(cache),
        graph.get_or_create_int_param_buffer(input_pos_symint)};
  } else {
    global_size = graph.logical_limits_of(projected);

    param_ubos = {
        graph.logical_limits_ubo(projected),
        graph.get_or_create_int_param_buffer(input_pos_symint)};
  }
  const utils::uvec3 local_size = graph.create_local_wg_size(global_size);

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      // Inputs and Outputs
      {{cache, vkapi::kWrite}, {projected, vkapi::kRead}},
      // Shader param buffers
      param_ubos,
      // Specialization Constants
      {},
      // Resizing Logic
      nullptr,
      {}));
}

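// Scales the attention weight by 1 / sqrt(head_dim) and applies the implicit
// causal mask in a single dispatch. No explicit mask tensor is used; see the
// is_causal and attn_mask checks in sdpa_with_kv_cache_impl below.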
void add_attn_weight_scale_and_mask_node(
    ComputeGraph& graph,
    const ValueRef input_pos_symint,
    const ValueRef q_projected,
    const ValueRef attn_weight) {
  std::string kernel_name("sdpa_attn_weight_scale_and_mask");
  add_storage_type_suffix(kernel_name, graph.storage_type_of(attn_weight));
  add_dtype_suffix(kernel_name, graph.dtype_of(attn_weight));

  const int32_t head_dim_size = graph.size_at<int32_t>(-1, q_projected);
  const float scale_val = 1.0f / std::sqrt(static_cast<float>(head_dim_size));

  utils::uvec3 global_size;
  utils::uvec3 local_size;
  vkapi::ParamsBindList param_ubos;

  if (graph.is_buffer_storage(attn_weight)) {
    global_size = {
        graph.size_at<uint32_t>(-1, attn_weight),
        graph.size_at<uint32_t>(-2, attn_weight),
        graph.size_at<uint32_t>(-3, attn_weight),
    };

    param_ubos = {
        graph.sizes_ubo(attn_weight),
        graph.strides_ubo(attn_weight),
        graph.create_params_buffer(scale_val)};
  } else {
    global_size = graph.logical_limits_of(attn_weight);

    param_ubos = {
        graph.logical_limits_ubo(attn_weight),
        graph.get_or_create_int_param_buffer(input_pos_symint),
        graph.create_params_buffer(scale_val)};
  }

  local_size = graph.create_local_wg_size(global_size);

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      // Inputs and Outputs
      {{attn_weight, vkapi::kReadWrite}},
      // Shader param buffers
      param_ubos,
      // Specialization Constants
      {},
      // Resizing Logic
      nullptr,
      {}));
}

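// Computes the sizes of a view of the cache spanning sequence positions
// [0, input_pos + q_seq_len). Only the sequence dim (dim 1) differs from the
// full cache sizes.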
std::vector<int64_t> get_cache_slice_sizes(
    ComputeGraph& graph,
    ValueRef cache,
    ValueRef input_pos_symint,
    ValueRef q_projected) {
  std::vector<int64_t> slice_sizes = graph.sizes_of(cache);

  // Cache slicing will always be in the channels dim
  const int32_t input_pos_val = graph.read_symint(input_pos_symint);
  const int64_t q_seq_len = graph.size_at<int64_t>(1, q_projected);
  slice_sizes.at(1) = input_pos_val + q_seq_len;
  return slice_sizes;
}

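// Resize callback for the cache slice view. extra_args is expected to be
// {cache, input_pos_symint, q_projected, cache_sliced}; the view is resized
// to cover the currently valid portion of the cache.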
void resize_cache_slice_view_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)args;
  std::vector<int64_t> slice_sizes = get_cache_slice_sizes(
      *graph, extra_args[0], extra_args[1], extra_args[2]);

  graph->get_tensor(extra_args[3])->virtual_resize(slice_sizes);
}

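// Creates a zero-copy view of the cache covering the valid sequence
// positions. The view is initialized to max_seq_len rows here and is shrunk
// to the actual length by resize_cache_slice_view_node when the graph is
// resized.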
void add_cache_slice_view_node(
    ComputeGraph& graph,
    ValueRef cache,
    ValueRef input_pos_symint,
    ValueRef q_projected,
    ValueRef cache_sliced,
    const int64_t max_seq_len) {
  std::vector<int64_t> slice_sizes =
      get_cache_slice_sizes(graph, cache, input_pos_symint, q_projected);
  // Initialize the slice to the maximum possible size to start
  slice_sizes.at(1) = max_seq_len;

  graph.get_tensor(cache_sliced)->virtual_resize(slice_sizes);

  graph.execute_nodes().emplace_back(new ExecuteNode(
      resize_cache_slice_view_node,
      {cache, input_pos_symint, q_projected, cache_sliced}));
}

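// Resize callback for the final output: the SDPA output always has the same
// shape as the projected query tensor.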
void resize_sdpa_out(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)args;

  int arg_idx = 0;
  const ValueRef q_projected = extra_args[arg_idx++];
  const ValueRef out = extra_args[arg_idx++];
  graph->get_tensor(out)->virtual_resize(graph->sizes_of(q_projected));
}

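// Implements scaled dot product attention with an in-place KV cache update:
//
//   1. Write k_projected / v_projected into the caches at input_pos.
//   2. Slice the caches to the valid sequence positions and repeat
//      interleave the KV heads up to the number of query heads.
//   3. attn_weight = (Q @ K^T) * (1 / sqrt(head_dim)), causally masked.
//   4. out = softmax(attn_weight) @ V.
//
// The expected layout (inferred from the checks below) is
// (batch = 1, seq_len, num_heads, head_dim) for the projected tensors and
// (batch = 1, max_seq_len, num_kv_heads, head_dim) for the caches.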
void sdpa_with_kv_cache_impl(
    ComputeGraph& graph,
    const std::vector<ValueRef>& args) {
  int arg_idx = 0;
  const ValueRef q_projected = args[arg_idx++];
  const ValueRef k_projected = args[arg_idx++];
  const ValueRef v_projected = args[arg_idx++];
  const ValueRef k_cache_data = args[arg_idx++];
  const ValueRef v_cache_data = args[arg_idx++];
  const ValueRef input_pos_symint = args[arg_idx++];
  const ValueRef sequence_len = args[arg_idx++];
  const ValueRef attn_mask = args[arg_idx++];
  const ValueRef dropout_p = args[arg_idx++];
  const ValueRef is_causal = args[arg_idx++];
  const ValueRef scale = args[arg_idx++];

  // Output tensors
  const ValueRef out = args[arg_idx++];

  // Unused variables
  (void)sequence_len;

  // Batches must be 1
  VK_CHECK_COND(graph.size_at<int32_t>(-4, q_projected) == 1);
  VK_CHECK_COND(graph.size_at<int32_t>(-4, k_projected) == 1);
  VK_CHECK_COND(graph.size_at<int32_t>(-4, v_projected) == 1);
  // k and v projected must have the same shape
  VK_CHECK_COND(graph.sizes_of(k_projected) == graph.sizes_of(v_projected));
  // head dim must match between tensors
  VK_CHECK_COND(
      graph.size_at<int32_t>(-1, q_projected) ==
      graph.size_at<int32_t>(-1, k_projected));
  // All tensors must be packed along the width (head dim) dimension
  VK_CHECK_COND(graph.packed_dim_of(q_projected) == WHCN::kWidthDim);
  VK_CHECK_COND(graph.packed_dim_of(k_projected) == WHCN::kWidthDim);
  VK_CHECK_COND(graph.packed_dim_of(v_projected) == WHCN::kWidthDim);
  // Some arguments are not supported yet
  VK_CHECK_COND(
      graph.val_is_none(dropout_p) ||
      graph.extract_scalar<double>(dropout_p) == 0);
  VK_CHECK_COND(graph.val_is_none(scale));
  // is_causal is assumed to be true in the current implementation.
  VK_CHECK_COND(
      graph.val_is_none(is_causal) || graph.extract_scalar<bool>(is_causal));
  VK_CHECK_COND(graph.val_is_none(attn_mask));

  const ValueRef k_cache =
      prepack_standard_like(graph, k_cache_data, q_projected);
  const ValueRef v_cache =
      prepack_standard_like(graph, v_cache_data, q_projected);

  const int32_t max_seq_len = graph.size_at<int32_t>(1, k_cache);

  add_kv_cache_update_node(graph, input_pos_symint, k_projected, k_cache);
  add_kv_cache_update_node(graph, input_pos_symint, v_projected, v_cache);

  // Slice caches from 0 to input_pos + sequence_len
  const ValueRef k_cache_sliced = graph.add_tensor_view(k_cache);
  const ValueRef v_cache_sliced = graph.add_tensor_view(v_cache);
  add_cache_slice_view_node(
      graph,
      k_cache,
      input_pos_symint,
      q_projected,
      k_cache_sliced,
      max_seq_len);
  add_cache_slice_view_node(
      graph,
      v_cache,
      input_pos_symint,
      q_projected,
      v_cache_sliced,
      max_seq_len);

  // Scalar values for various dims
  const ValueRef channels = graph.add_scalar<int64_t>(1);
  const ValueRef height = graph.add_scalar<int64_t>(2);
  const ValueRef width = graph.add_scalar<int64_t>(3);

  // Repeat interleave the K/V heads up to the number of query heads, which
  // supports grouped-query attention (num_heads is a multiple of
  // num_kv_heads)
  const int64_t num_heads = graph.size_at<int64_t>(2, q_projected);
  const int64_t num_kv_heads = graph.size_at<int64_t>(2, k_projected);

  const ValueRef num_repeats =
      graph.add_scalar<int64_t>(num_heads / num_kv_heads);

  std::vector<int64_t> cache_slice_repeated_sizes(graph.sizes_of(q_projected));
  cache_slice_repeated_sizes.at(1) = max_seq_len;

  TmpTensor k_cache_sliced_repeated(
      &graph, cache_slice_repeated_sizes, graph.dtype_of(k_cache_sliced));
  TmpTensor v_cache_sliced_repeated(
      &graph, cache_slice_repeated_sizes, graph.dtype_of(v_cache_sliced));

  add_repeat_interleave_node(
      graph, k_cache_sliced, num_repeats, height, k_cache_sliced_repeated);
  add_repeat_interleave_node(
      graph, v_cache_sliced, num_repeats, height, v_cache_sliced_repeated);

  // Transpose sequence and head dims
  const ValueRef q_transposed = graph.add_tensor_view(q_projected);
  const ValueRef k_transposed = graph.add_tensor_view(k_cache_sliced_repeated);
  const ValueRef v_transposed = graph.add_tensor_view(v_cache_sliced_repeated);

  add_transpose_view_node(graph, q_projected, channels, height, q_transposed);
  add_transpose_view_node(
      graph, k_cache_sliced_repeated, channels, height, k_transposed);
  add_transpose_view_node(
      graph, v_cache_sliced_repeated, channels, height, v_transposed);

  // Transpose K again to prepare for matmul
  const ValueRef k_transposed_2 = graph.add_tensor_view(k_transposed);
  add_transpose_view_node(graph, k_transposed, height, width, k_transposed_2);

  // Initialize attn_weight to the maximum possible size
  std::vector<int64_t> attn_weight_full_sizes = graph.sizes_of(q_transposed);
  attn_weight_full_sizes.at(2) = max_seq_len;
  attn_weight_full_sizes.at(3) = max_seq_len;
  TmpTensor attn_weight(
      &graph, attn_weight_full_sizes, graph.dtype_of(q_transposed));

  // Resize attn_weight to the correct dims
  std::vector<int64_t> attn_weight_sizes = attn_weight_full_sizes;
  attn_weight_sizes.at(2) = graph.size_at<int64_t>(2, q_transposed);
  attn_weight_sizes.at(3) = graph.size_at<int64_t>(2, k_transposed);
  graph.get_tensor(attn_weight)->virtual_resize(attn_weight_sizes);

  // Calculate attention weight, which is a matmul of Q and K
  const ValueRef mat2_is_transposed = graph.add_scalar<bool>(false);
  add_matmul_node(
      graph, q_transposed, k_transposed_2, attn_weight, mat2_is_transposed);

  // Apply scale and mask to the attention weight
  add_attn_weight_scale_and_mask_node(
      graph, input_pos_symint, q_projected, attn_weight);

  TmpTensor attn_weight_softmax(
      &graph, attn_weight_full_sizes, graph.dtype_of(q_transposed));
  graph.get_tensor(attn_weight_softmax)->virtual_resize(attn_weight_sizes);
  add_softmax_node(graph, attn_weight, width, attn_weight_softmax, false);

  // Calculate final output
  const ValueRef out_transposed = graph.add_tensor_view(out);
  add_transpose_view_node(graph, out, channels, height, out_transposed);
  add_matmul_node(
      graph,
      attn_weight_softmax,
      v_transposed,
      out_transposed,
      mat2_is_transposed);

  graph.execute_nodes().emplace_back(
      new ExecuteNode(resize_sdpa_out, {q_projected, out}));
}

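// Note: the registered name must match the operator target produced during
// export (assumed here to be the llama sdpa_with_kv_cache custom op); the
// ".default" suffix refers to the op's default overload.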
REGISTER_OPERATORS {
  VK_REGISTER_OP(sdpa_with_kv_cache.default, sdpa_with_kv_cache_impl);
}

} // namespace vkcompute