/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/MatMul.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/RepeatInterleave.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Slice.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Softmax.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Transpose.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

// Needed for std::sqrt used below.
#include <cmath>

namespace vkcompute {

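// Adds a compute node that writes the newly projected K or V tensor into the
// corresponding cache tensor at the sequence position given by
// input_pos_symint. The dispatch parameters differ depending on whether the
// cache uses buffer or texture storage.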
void add_kv_cache_update_node(
    ComputeGraph& graph,
    const ValueRef input_pos_symint,
    const ValueRef projected,
    const ValueRef cache) {
  std::string kernel_name("kv_cache_update");
  add_storage_type_suffix(kernel_name, graph.storage_type_of(projected));
  add_dtype_suffix(kernel_name, graph.dtype_of(projected));

  utils::uvec3 global_size;
  vkapi::ParamsBindList param_ubos;

  if (graph.is_buffer_storage(cache)) {
    global_size = graph.create_global_wg_size(projected);

    param_ubos = {
        graph.numel_ubo(projected),
        graph.strides_ubo(cache),
        graph.get_or_create_int_param_buffer(input_pos_symint)};
  } else {
    global_size = graph.logical_limits_of(projected);

    param_ubos = {
        graph.logical_limits_ubo(projected),
        graph.get_or_create_int_param_buffer(input_pos_symint)};
  }
  const utils::uvec3 local_size = graph.create_local_wg_size(global_size);

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      // Inputs and Outputs
      {{cache, vkapi::kWrite}, {projected, vkapi::kRead}},
      // Shader param buffers
      param_ubos,
      // Specialization Constants
      {},
      // Resizing Logic
      nullptr,
      {}));
}

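// Adds a compute node that, in place, scales the attention weight tensor by
// 1 / sqrt(head_dim) and applies the causal mask based on the current input
// position.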
void add_attn_weight_scale_and_mask_node(
    ComputeGraph& graph,
    const ValueRef input_pos_symint,
    const ValueRef q_projected,
    const ValueRef attn_weight) {
  std::string kernel_name("sdpa_attn_weight_scale_and_mask");
  add_storage_type_suffix(kernel_name, graph.storage_type_of(attn_weight));
  add_dtype_suffix(kernel_name, graph.dtype_of(attn_weight));

  const int32_t head_dim_size = graph.size_at<int32_t>(-1, q_projected);
  const float scale_val = 1.0f / std::sqrt(static_cast<float>(head_dim_size));

  utils::uvec3 global_size;
  utils::uvec3 local_size;
  vkapi::ParamsBindList param_ubos;

  if (graph.is_buffer_storage(attn_weight)) {
    global_size = {
        graph.size_at<uint32_t>(-1, attn_weight),
        graph.size_at<uint32_t>(-2, attn_weight),
        graph.size_at<uint32_t>(-3, attn_weight),
    };

    param_ubos = {
        graph.sizes_ubo(attn_weight),
        graph.strides_ubo(attn_weight),
        graph.create_params_buffer(scale_val)};
  } else {
    global_size = graph.logical_limits_of(attn_weight);

    param_ubos = {
        graph.logical_limits_ubo(attn_weight),
        graph.get_or_create_int_param_buffer(input_pos_symint),
        graph.create_params_buffer(scale_val)};
  }

  local_size = graph.create_local_wg_size(global_size);

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      // Inputs and Outputs
      {{attn_weight, vkapi::kReadWrite}},
      // Shader param buffers
      param_ubos,
      // Specialization Constants
      {},
      // Resizing Logic
      nullptr,
      {}));
}

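// Computes the sizes of the cache slice that covers sequence positions
// [0, input_pos + q_seq_len) along the channels dim.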
std::vector<int64_t> get_cache_slice_sizes(
    ComputeGraph& graph,
    ValueRef cache,
    ValueRef input_pos_symint,
    ValueRef q_projected) {
  std::vector<int64_t> slice_sizes = graph.sizes_of(cache);

  // Cache slicing will always be in the channels dim
  const int32_t input_pos_val = graph.read_symint(input_pos_symint);
  const int64_t q_seq_len = graph.size_at<int64_t>(1, q_projected);
  slice_sizes.at(1) = input_pos_val + q_seq_len;
  return slice_sizes;
}

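// Resize callback for the cache slice views; recomputes the slice sizes from
// the current input position and query sequence length.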
void resize_cache_slice_view_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)args;
  std::vector<int64_t> slice_sizes = get_cache_slice_sizes(
      *graph, extra_args[0], extra_args[1], extra_args[2]);

  graph->get_tensor(extra_args[3])->virtual_resize(slice_sizes);
}

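// Creates a zero-copy view of the cache tensor that exposes only the valid
// portion of the sequence dim. The view is initially sized to max_seq_len and
// resized at execution time via resize_cache_slice_view_node.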
void add_cache_slice_view_node(
    ComputeGraph& graph,
    ValueRef cache,
    ValueRef input_pos_symint,
    ValueRef q_projected,
    ValueRef cache_sliced,
    const int64_t max_seq_len) {
  std::vector<int64_t> slice_sizes =
      get_cache_slice_sizes(graph, cache, input_pos_symint, q_projected);
  // Initialize the slice to the maximum possible size to start
  slice_sizes.at(1) = max_seq_len;

  graph.get_tensor(cache_sliced)->virtual_resize(slice_sizes);

  graph.execute_nodes().emplace_back(new ExecuteNode(
      resize_cache_slice_view_node,
      {cache, input_pos_symint, q_projected, cache_sliced}));
}

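// Resize callback for the output tensor; the SDPA output has the same shape
// as the projected query tensor.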
void resize_sdpa_out(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)args;

  int arg_idx = 0;
  const ValueRef q_projected = extra_args[arg_idx++];
  const ValueRef out = extra_args[arg_idx++];
  graph->get_tensor(out)->virtual_resize(graph->sizes_of(q_projected));
}

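// Implements scaled dot product attention with an in-graph KV cache:
//   1. Write the projected K and V tensors into their caches at input_pos.
//   2. Create sliced views of the caches covering positions
//      [0, input_pos + seq_len) and repeat-interleave them along the head dim
//      so the number of K/V heads matches the number of query heads.
//   3. Transpose the sequence and head dims, compute attn_weight = Q x K^T,
//      apply the 1/sqrt(head_dim) scale and the causal mask, then softmax.
//   4. Compute the output as attn_weight_softmax x V and write it back in the
//      original (seq, heads, head_dim) layout.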
void sdpa_with_kv_cache_impl(
    ComputeGraph& graph,
    const std::vector<ValueRef>& args) {
  int arg_idx = 0;
  const ValueRef q_projected = args[arg_idx++];
  const ValueRef k_projected = args[arg_idx++];
  const ValueRef v_projected = args[arg_idx++];
  const ValueRef k_cache_data = args[arg_idx++];
  const ValueRef v_cache_data = args[arg_idx++];
  const ValueRef input_pos_symint = args[arg_idx++];
  const ValueRef sequence_len = args[arg_idx++];
  const ValueRef attn_mask = args[arg_idx++];
  const ValueRef dropout_p = args[arg_idx++];
  const ValueRef is_causal = args[arg_idx++];
  const ValueRef scale = args[arg_idx++];

  // Output tensors
  const ValueRef out = args[arg_idx++];

  // Unused variables
  (void)sequence_len;

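  // The current implementation only supports batch size 1, width-packed
  // tensors, no dropout, the default scale of 1 / sqrt(head_dim), causal
  // masking, and no explicit attention mask.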
  // Batches must be 1
  VK_CHECK_COND(graph.size_at<int32_t>(-4, q_projected) == 1);
  VK_CHECK_COND(graph.size_at<int32_t>(-4, k_projected) == 1);
  VK_CHECK_COND(graph.size_at<int32_t>(-4, v_projected) == 1);
  // k and v projected must have the same shape
  VK_CHECK_COND(graph.sizes_of(k_projected) == graph.sizes_of(v_projected));
  // head dim must match between tensors
  VK_CHECK_COND(
      graph.size_at<int32_t>(-1, q_projected) ==
      graph.size_at<int32_t>(-1, k_projected));
  // All tensors must have the packed dim be the width (head) dimension
  VK_CHECK_COND(graph.packed_dim_of(q_projected) == WHCN::kWidthDim);
  VK_CHECK_COND(graph.packed_dim_of(k_projected) == WHCN::kWidthDim);
  VK_CHECK_COND(graph.packed_dim_of(v_projected) == WHCN::kWidthDim);
  // Some variables are not supported yet
  VK_CHECK_COND(
      graph.val_is_none(dropout_p) ||
      graph.extract_scalar<double>(dropout_p) == 0);
  VK_CHECK_COND(graph.val_is_none(scale));
  // is_causal is assumed to be true in the current implementation.
  VK_CHECK_COND(
      graph.val_is_none(is_causal) || graph.extract_scalar<bool>(is_causal));
  VK_CHECK_COND(graph.val_is_none(attn_mask));

  const ValueRef k_cache =
      prepack_standard_like(graph, k_cache_data, q_projected);
  const ValueRef v_cache =
      prepack_standard_like(graph, v_cache_data, q_projected);

  const int32_t max_seq_len = graph.size_at<int32_t>(1, k_cache);

  add_kv_cache_update_node(graph, input_pos_symint, k_projected, k_cache);
  add_kv_cache_update_node(graph, input_pos_symint, v_projected, v_cache);

  // Slice caches from 0 to input_pos + sequence_len
  const ValueRef k_cache_sliced = graph.add_tensor_view(k_cache);
  const ValueRef v_cache_sliced = graph.add_tensor_view(v_cache);
  add_cache_slice_view_node(
      graph,
      k_cache,
      input_pos_symint,
      q_projected,
      k_cache_sliced,
      max_seq_len);
  add_cache_slice_view_node(
      graph,
      v_cache,
      input_pos_symint,
      q_projected,
      v_cache_sliced,
      max_seq_len);

  // Scalar values for various dims
  const ValueRef channels = graph.add_scalar<int64_t>(1);
  const ValueRef height = graph.add_scalar<int64_t>(2);
  const ValueRef width = graph.add_scalar<int64_t>(3);

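  // Q may have more heads than K/V (grouped-query / multi-query attention);
  // compute the repeat factor needed to expand the K/V cache slices to match
  // the number of query heads.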
  // Repeat interleave
  const int64_t num_heads = graph.size_at<int64_t>(2, q_projected);
  const int64_t num_kv_heads = graph.size_at<int64_t>(2, k_projected);

  const ValueRef num_repeats =
      graph.add_scalar<int64_t>(num_heads / num_kv_heads);

  std::vector<int64_t> cache_slice_repeated_sizes(graph.sizes_of(q_projected));
  cache_slice_repeated_sizes.at(1) = max_seq_len;

  TmpTensor k_cache_sliced_repeated(
      &graph, cache_slice_repeated_sizes, graph.dtype_of(k_cache_sliced));
  TmpTensor v_cache_sliced_repeated(
      &graph, cache_slice_repeated_sizes, graph.dtype_of(v_cache_sliced));

  add_repeat_interleave_node(
      graph, k_cache_sliced, num_repeats, height, k_cache_sliced_repeated);
  add_repeat_interleave_node(
      graph, v_cache_sliced, num_repeats, height, v_cache_sliced_repeated);

  // Transpose sequence and head dims
  const ValueRef q_transposed = graph.add_tensor_view(q_projected);
  const ValueRef k_transposed = graph.add_tensor_view(k_cache_sliced_repeated);
  const ValueRef v_transposed = graph.add_tensor_view(v_cache_sliced_repeated);

  add_transpose_view_node(graph, q_projected, channels, height, q_transposed);
  add_transpose_view_node(
      graph, k_cache_sliced_repeated, channels, height, k_transposed);
  add_transpose_view_node(
      graph, v_cache_sliced_repeated, channels, height, v_transposed);

  // Transpose K again to prepare for matmul
  const ValueRef k_transposed_2 = graph.add_tensor_view(k_transposed);
  add_transpose_view_node(graph, k_transposed, height, width, k_transposed_2);

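  // attn_weight is allocated once at the maximum possible size; below it is
  // virtually resized to the current sequence lengths so the same backing
  // memory can be reused across decode steps as input_pos grows.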
  // Initialize attn_weight to the maximum possible size
  std::vector<int64_t> attn_weight_full_sizes = graph.sizes_of(q_transposed);
  attn_weight_full_sizes.at(2) = max_seq_len;
  attn_weight_full_sizes.at(3) = max_seq_len;
  TmpTensor attn_weight(
      &graph, attn_weight_full_sizes, graph.dtype_of(q_transposed));

  // Resize attn_weight to the correct dim
  std::vector<int64_t> attn_weight_sizes = attn_weight_full_sizes;
  attn_weight_sizes.at(2) = graph.size_at<int64_t>(2, q_transposed);
  attn_weight_sizes.at(3) = graph.size_at<int64_t>(2, k_transposed);
  graph.get_tensor(attn_weight)->virtual_resize(attn_weight_sizes);

  // Calculate attention weight, which is a matmul of Q and K
  const ValueRef mat2_is_transposed = graph.add_scalar<bool>(false);
  add_matmul_node(
      graph, q_transposed, k_transposed_2, attn_weight, mat2_is_transposed);

  // Apply scale and mask to the attention weight
  add_attn_weight_scale_and_mask_node(
      graph, input_pos_symint, q_projected, attn_weight);

  TmpTensor attn_weight_softmax(
      &graph, attn_weight_full_sizes, graph.dtype_of(q_transposed));
  graph.get_tensor(attn_weight_softmax)->virtual_resize(attn_weight_sizes);
  add_softmax_node(graph, attn_weight, width, attn_weight_softmax, false);

  // Calculate final output
  const ValueRef out_transposed = graph.add_tensor_view(out);
  add_transpose_view_node(graph, out, channels, height, out_transposed);
  add_matmul_node(
      graph,
      attn_weight_softmax,
      v_transposed,
      out_transposed,
      mat2_is_transposed);

  graph.execute_nodes().emplace_back(
      new ExecuteNode(resize_sdpa_out, {q_projected, out}));
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(sdpa_with_kv_cache.default, sdpa_with_kv_cache_impl);
}

} // namespace vkcompute