1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <gtest/gtest.h>
10
11 #include <ATen/ATen.h>
12
13 #include <executorch/backends/vulkan/runtime/api/api.h>
14 #include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
15 #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
16
17 #include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
18 #include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
19 #include <executorch/extension/llm/custom_ops/op_sdpa.h>
20
#include <cassert>
#include <cmath>
#include <iostream>
23
24 namespace torch {
25 namespace executor {
26 namespace native {
27
// The functions below are copied from
// executorch/extension/llm/custom_ops/op_sdpa_aot.cpp. They are needed here
// because the original definitions live in an anonymous namespace and are
// therefore not accessible from this file.
31
// Adapter that gives the ExecuTorch kernel sdpa_with_kv_cache_out a signature
// without a runtime context, so that it can be wrapped by WRAP_TO_ATEN. A
// default-constructed KernelRuntimeContext is supplied for the call, every
// other argument is forwarded unchanged, and the kernel's result is returned.
Tensor& sdpa_with_kv_cache_out_no_context(
    const Tensor& q_projected,
    const Tensor& k_projected,
    const Tensor& v_projected,
    Tensor& key_cache,
    Tensor& value_cache,
    const int64_t start_pos,
    const int64_t seq_len,
    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<Tensor> attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<double> scale,
    Tensor& output) {
  // The kernel expects a runtime context as its first argument; the test
  // harness has none, so a default one is used.
  executorch::runtime::KernelRuntimeContext ctx{};
  return torch::executor::native::sdpa_with_kv_cache_out(
      ctx,
      q_projected,
      k_projected,
      v_projected,
      key_cache,
      value_cache,
      start_pos,
      seq_len,
      attn_mask,
      dropout_p,
      is_causal,
      scale,
      output);
}
64
// ATen-facing entry point. It allocates an output tensor like q_projected,
// runs sdpa_with_kv_cache_out_no_context through WRAP_TO_ATEN (from
// make_aten_functor_from_et_functor.h), and returns that output tensor.
// key_cache and value_cache are passed through by reference.
at::Tensor sdpa_with_kv_cache_aten(
    const at::Tensor& q_projected,
    const at::Tensor& k_projected,
    const at::Tensor& v_projected,
    at::Tensor& key_cache,
    at::Tensor& value_cache,
    const int64_t start_pos,
    const int64_t seq_len,
    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const std::optional<at::Tensor> attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const std::optional<double> scale) {
  at::Tensor result = at::empty_like(q_projected);
  // 11 is the (0-based) position of the `output` out-parameter in
  // sdpa_with_kv_cache_out_no_context's parameter list.
  WRAP_TO_ATEN(sdpa_with_kv_cache_out_no_context, 11)
  (q_projected,
   k_projected,
   v_projected,
   key_cache,
   value_cache,
   start_pos,
   seq_len,
   attn_mask,
   dropout_p,
   is_causal,
   scale,
   result);
  return result;
}
96
97 } // namespace native
98 } // namespace executor
99 } // namespace torch
100
101 //
102 // Reference Implementation
103 //
104
105 /*
106 * Converts a boolean mask to an additive mask. Values that are false are
107 * converted to -inf, and values that are true are converted to 0.
108 */
convert_boolean_attn_mask(const at::Tensor & attn_mask,caffe2::TypeMeta dtype)109 at::Tensor convert_boolean_attn_mask(
110 const at::Tensor& attn_mask,
111 caffe2::TypeMeta dtype) {
112 // Convert boolean mask to additive mask; need to invert mask to indicate what
113 // to mask *out*.
114 if (attn_mask.dtype() == at::kBool) {
115 return at::where(
116 attn_mask.logical_not(),
117 -std::numeric_limits<double>::infinity(),
118 at::scalar_tensor(
119 0.0, at::TensorOptions().dtype(dtype).device(attn_mask.device())));
120 }
121 // Otherwise, attn_mask represents an additive attention tensor
122 return attn_mask;
123 }
124
125 /*
126 * Construct an attention mask for SDPA.
127 * 1. Construct a square matrix of ones with each dim equal to start_pos +
128 * seq_len
129 * 2. Keep the lower triangular elements as 1 and set the rest to 0
130 * 3. Slice the mask to keep only seq_len rows starting from input_pos
131 * 4. Convert the mask to an additive mask
132 */
construct_attention_mask(const at::Tensor & q,const at::Tensor & k_cache,const int start_pos)133 at::Tensor construct_attention_mask(
134 const at::Tensor& q,
135 const at::Tensor& k_cache,
136 const int start_pos) {
137 const int max_seq_len = k_cache.size(1);
138 const int seq_len = q.size(1);
139
140 const int length = start_pos + seq_len;
141 at::Tensor attn_mask_base =
142 at::ones({length, length}, q.options().dtype(at::kBool)).tril();
143
144 at::Tensor attn_mask_sliced =
145 at::slice(attn_mask_base, 0, start_pos, start_pos + seq_len);
146
147 attn_mask_sliced = convert_boolean_attn_mask(attn_mask_sliced, q.dtype());
148 return attn_mask_sliced;
149 }
150
151 /*
152 * Reference implementation of SDPA
153 */
sdpa_reference_impl(const at::Tensor & q_projected,const at::Tensor & k_projected,const at::Tensor & v_projected,at::Tensor & key_cache,at::Tensor & value_cache,const int64_t start_pos,const int64_t seq_len,const std::optional<at::Tensor> __attn_mask_ignored,const double dropout_p,const bool is_causal,const std::optional<double> scale)154 at::Tensor sdpa_reference_impl(
155 const at::Tensor& q_projected,
156 const at::Tensor& k_projected,
157 const at::Tensor& v_projected,
158 at::Tensor& key_cache,
159 at::Tensor& value_cache,
160 const int64_t start_pos,
161 const int64_t seq_len,
162 const std::optional<at::Tensor> __attn_mask_ignored,
163 const double dropout_p,
164 const bool is_causal,
165 const std::optional<double> scale) {
166 at::Tensor attn_mask =
167 construct_attention_mask(q_projected, key_cache, start_pos);
168
169 // Cache update
170 at::Tensor key_cache_updated = at::slice_scatter(
171 key_cache, k_projected, 1, start_pos, start_pos + k_projected.size(1));
172 at::Tensor value_cache_updated = at::slice_scatter(
173 value_cache, v_projected, 1, start_pos, start_pos + v_projected.size(1));
174
175 // Write back to input
176 key_cache = key_cache_updated;
177 value_cache = value_cache_updated;
178
179 at::Tensor key_cache_sliced =
180 at::slice(key_cache_updated, 1, 0, start_pos + q_projected.size(1));
181
182 at::Tensor value_cache_sliced =
183 at::slice(value_cache_updated, 1, 0, start_pos + q_projected.size(1));
184
185 // Since n_heads may not be the same as n_kv_heads, the sliced k and v cache
186 // matrices need to be "expanded" to match
187 const int num_repeats = q_projected.size(2) / key_cache.size(2);
188 at::Tensor key_cache_sliced_repeated =
189 at::repeat_interleave(key_cache_sliced, num_repeats, 2);
190 at::Tensor value_cache_sliced_repeated =
191 at::repeat_interleave(value_cache_sliced, num_repeats, 2);
192
193 at::Tensor q_transposed = q_projected.transpose(1, 2);
194 at::Tensor k_transposed = key_cache_sliced_repeated.transpose(1, 2);
195 at::Tensor v_transposed = value_cache_sliced_repeated.transpose(1, 2);
196
197 at::Tensor k_transposed_2 = k_transposed.transpose(-2, -1);
198 at::Tensor attn_weight_prescale = at::matmul(q_transposed, k_transposed_2);
199
200 float scale_factor = 1.0 / sqrt(q_transposed.size(-1));
201 at::Tensor attn_weight = attn_weight_prescale * scale_factor + attn_mask;
202
203 at::Tensor attn_weight_softmax = at::softmax(attn_weight, -1);
204 at::Tensor out = at::matmul(attn_weight_softmax, v_transposed);
205
206 return out.transpose(1, 2);
207 }
208
209 //
210 // Test functions
211 //
212
test_reference_sdpa(const int start_input_pos,const int sequence_len,const int embedding_dim,const int num_heads,const int num_kv_heads,const int batch_size,const int max_seq_len,at::ScalarType dtype=at::kFloat)213 void test_reference_sdpa(
214 const int start_input_pos,
215 const int sequence_len,
216 const int embedding_dim,
217 const int num_heads,
218 const int num_kv_heads,
219 const int batch_size,
220 const int max_seq_len,
221 at::ScalarType dtype = at::kFloat) {
222 const int head_dim = embedding_dim / num_heads;
223
224 // K and V caches. Need an extra set for the reference implementation
225
226 at::Tensor k_cache = at::zeros(
227 {batch_size, max_seq_len, num_kv_heads, head_dim},
228 at::device(at::kCPU).dtype(dtype));
229 at::Tensor v_cache = at::zeros_like(k_cache);
230
231 at::Tensor k_cache_ref = at::zeros_like(k_cache);
232 at::Tensor v_cache_ref = at::zeros_like(v_cache);
233
234 for (int input_pos = start_input_pos; input_pos + sequence_len < max_seq_len;
235 input_pos += sequence_len) {
236 at::Tensor q = at::rand(
237 {batch_size, sequence_len, num_heads, head_dim},
238 at::device(at::kCPU).dtype(dtype));
239 at::Tensor k = at::rand(
240 {batch_size, sequence_len, num_kv_heads, head_dim},
241 at::device(at::kCPU).dtype(dtype));
242 at::Tensor v = at::rand_like(k);
243
244 at::Tensor reference_impl_out = sdpa_reference_impl(
245 q, k, v, k_cache, v_cache, input_pos, sequence_len, {}, 0.0, true, {});
246
247 at::Tensor reference_out = torch::executor::native::sdpa_with_kv_cache_aten(
248 q,
249 k,
250 v,
251 k_cache_ref,
252 v_cache_ref,
253 input_pos,
254 sequence_len,
255 {},
256 0.0,
257 true,
258 {});
259
260 ASSERT_TRUE(at::allclose(reference_impl_out, reference_out));
261 }
262 }
263
from_at_scalartype(c10::ScalarType at_scalartype)264 vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
265 using namespace vkcompute;
266 switch (at_scalartype) {
267 case c10::kFloat:
268 return vkapi::kFloat;
269 case c10::kHalf:
270 return vkapi::kHalf;
271 case c10::kInt:
272 return vkapi::kInt;
273 case c10::kLong:
274 return vkapi::kInt;
275 case c10::kChar:
276 return vkapi::kChar;
277 default:
278 VK_THROW("Unsupported at::ScalarType!");
279 }
280 }
281
test_vulkan_sdpa(const int start_input_pos,const int base_sequence_len,const int embedding_dim,const int num_heads,const int num_kv_heads,const int batch_size,const int max_seq_len,const bool dynamic_seq_len=true,at::ScalarType dtype=at::kFloat)282 void test_vulkan_sdpa(
283 const int start_input_pos,
284 const int base_sequence_len,
285 const int embedding_dim,
286 const int num_heads,
287 const int num_kv_heads,
288 const int batch_size,
289 const int max_seq_len,
290 const bool dynamic_seq_len = true,
291 at::ScalarType dtype = at::kFloat) {
292 const int head_dim = embedding_dim / num_heads;
293
294 const int init_seq_len = dynamic_seq_len ? max_seq_len : base_sequence_len;
295 // K and V caches
296
297 at::Tensor k_cache = at::zeros(
298 {batch_size, max_seq_len, num_kv_heads, head_dim},
299 at::device(at::kCPU).dtype(dtype));
300
301 at::Tensor v_cache = at::zeros_like(k_cache);
302
303 // Reference input data
304 at::Tensor q = at::empty(
305 {batch_size, init_seq_len, num_heads, head_dim},
306 at::device(at::kCPU).dtype(dtype));
307 at::Tensor k = at::empty(
308 {batch_size, init_seq_len, num_kv_heads, head_dim},
309 at::device(at::kCPU).dtype(dtype));
310 at::Tensor v = at::empty_like(k);
311
312 // Get reference output
313 at::Tensor out = at::empty_like(q);
314
315 // Build Vulkan SDPA graph
316 using namespace vkcompute;
317
318 GraphConfig config;
319 config.set_storage_type_override(utils::kTexture3D);
320 ComputeGraph graph(config);
321
322 // "Data" variant for vulkan initialization
323
324 at::Tensor k_cache_data = at::zeros_like(k_cache);
325 at::Tensor v_cache_data = at::zeros_like(v_cache);
326
327 #define MAKE_TENSORREF_FOR(x) \
328 ValueRef r_##x = graph.add_tensorref( \
329 x.sizes().vec(), \
330 from_at_scalartype(x.scalar_type()), \
331 x.const_data_ptr());
332
333 MAKE_TENSORREF_FOR(k_cache_data);
334 MAKE_TENSORREF_FOR(v_cache_data);
335
336 #define MAKE_INPUT_FOR(x) \
337 IOValueRef r_##x = graph.add_input_tensor( \
338 x.sizes().vec(), from_at_scalartype(x.scalar_type()));
339
340 MAKE_INPUT_FOR(q);
341 MAKE_INPUT_FOR(k);
342 MAKE_INPUT_FOR(v);
343 #undef MAKE_INPUT_FOR
344
345 const ValueRef r_input_pos_symint = graph.add_symint(start_input_pos);
346 const ValueRef r_out = graph.add_tensor(
347 out.sizes().vec(), from_at_scalartype(out.scalar_type()));
348
349 VK_GET_OP_FN("sdpa_with_kv_cache.default")
350 (graph,
351 {
352 r_q.value,
353 r_k.value,
354 r_v.value,
355 r_k_cache_data,
356 r_v_cache_data,
357 r_input_pos_symint,
358 kDummyValueRef, // sequence_len
359 kDummyValueRef, // attn_mask
360 kDummyValueRef, // dropout_p
361 kDummyValueRef, // is_causal
362 kDummyValueRef, // scale
363 r_out,
364 });
365
366 ValueRef staging_out = graph.set_output_tensor(r_out);
367
368 graph.prepare();
369 graph.encode_prepack();
370 graph.prepack();
371 graph.encode_execute();
372
373 //
374 // Run model
375 //
376
377 #define COPY_INPUT(x) \
378 graph.copy_into_staging(r_##x.staging, x.const_data_ptr(), x.numel());
379
380 #define EXTRACT_TENSOR(x) \
381 at::Tensor vk_##x = at::zeros_like(x).contiguous(); \
382 graph.copy_from_staging( \
383 staging_##x, vk_##x.mutable_data_ptr(), vk_##x.numel());
384
385 int seq_len = base_sequence_len;
386 for (int i = 0, input_pos = start_input_pos;
387 input_pos + seq_len < max_seq_len;
388 input_pos += seq_len, i++) {
389 q = at::rand(
390 {batch_size, seq_len, num_heads, head_dim},
391 at::device(at::kCPU).dtype(dtype));
392 k = at::rand(
393 {batch_size, seq_len, num_kv_heads, head_dim},
394 at::device(at::kCPU).dtype(dtype));
395 v = at::rand_like(k);
396
397 at::Tensor reference_out = sdpa_reference_impl(
398 q, k, v, k_cache, v_cache, input_pos, seq_len, {}, 0.0, true, {});
399
400 graph.set_symint(r_input_pos_symint, input_pos);
401 graph.resize_input(0, q.sizes().vec());
402 graph.resize_input(1, k.sizes().vec());
403 graph.resize_input(2, v.sizes().vec());
404 graph.propagate_resize();
405
406 // Run Vulkan SDPA
407 COPY_INPUT(q);
408 COPY_INPUT(k);
409 COPY_INPUT(v);
410
411 graph.execute();
412
413 out = at::empty_like(q);
414 EXTRACT_TENSOR(out);
415
416 const bool output_correct = at::allclose(reference_out, vk_out);
417 if (!output_correct) {
418 at::Tensor diffs = at::abs(reference_out - vk_out);
419
420 std::cout << "Failed at input_pos " << input_pos << " with seq_len "
421 << seq_len << std::endl;
422
423 std::cout << "Maximum difference: " << std::endl;
424 std::cout << at::max(diffs).item() << std::endl;
425 std::cout << "Found at index " << std::endl;
426 std::cout << at::argmax(diffs).item() << std::endl;
427
428 std::cout << "Maximum value observed: " << std::endl;
429 std::cout << at::max(at::abs(at::cat({reference_out, vk_out}, -1))).item()
430 << std::endl;
431 }
432 ASSERT_TRUE(output_correct);
433
434 if (dynamic_seq_len) {
435 seq_len = base_sequence_len + (i % 3);
436 }
437 }
438 }
439
// Small configuration with a fixed chunk length: dynamic sequence lengths are
// disabled, so every chunk has base_sequence_len tokens.
TEST(VulkanSDPATest, test_sdpa_op_small_params) {
  test_vulkan_sdpa(
      /*start_input_pos=*/0,
      /*base_sequence_len=*/3,
      /*embedding_dim=*/18,
      /*num_heads=*/6,
      /*num_kv_heads=*/2,
      /*batch_size=*/1,
      /*max_seq_len=*/7,
      /*dynamic_seq_len=*/false);
}
459
// Small configuration with varying chunk lengths (dynamic_seq_len defaults to
// true in test_vulkan_sdpa).
TEST(VulkanSDPATest, test_sdpa_op_small_params_dynamic) {
  test_vulkan_sdpa(
      /*start_input_pos=*/0,
      /*base_sequence_len=*/3,
      /*embedding_dim=*/18,
      /*num_heads=*/6,
      /*num_kv_heads=*/2,
      /*batch_size=*/1,
      /*max_seq_len=*/12);
}
478
// Larger configuration with 32 query heads sharing 8 KV heads and varying
// chunk lengths (dynamic_seq_len defaults to true).
TEST(VulkanSDPATest, test_sdpa_op_llama3_params_dynamic) {
  test_vulkan_sdpa(
      /*start_input_pos=*/0,
      /*base_sequence_len=*/3,
      /*embedding_dim=*/2048,
      /*num_heads=*/32,
      /*num_kv_heads=*/8,
      /*batch_size=*/1,
      /*max_seq_len=*/128);
}
497
// Validates the CPU reference implementation itself against the ExecuTorch
// kernel, using the same sizes as the llama3 Vulkan test.
TEST(VulkanSDPATest, test_reference_impl) {
  test_reference_sdpa(
      /*start_input_pos=*/0,
      /*sequence_len=*/3,
      /*embedding_dim=*/2048,
      /*num_heads=*/32,
      /*num_kv_heads=*/8,
      /*batch_size=*/1,
      /*max_seq_len=*/128);
}
516