xref: /aosp_15_r20/external/executorch/backends/vulkan/test/op_tests/sdpa_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <gtest/gtest.h>
10 
11 #include <ATen/ATen.h>
12 
13 #include <executorch/backends/vulkan/runtime/api/api.h>
14 #include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
15 #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
16 
17 #include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
18 #include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
19 #include <executorch/extension/llm/custom_ops/op_sdpa.h>
20 
21 #include <cassert>
22 #include <iostream>
23 
24 namespace torch {
25 namespace executor {
26 namespace native {
27 
28 // The below are copied from executorch/extension/llm/custom_ops/op_sdpa_aot.cpp
29 // They are needed because the original definitions are inaccessible due to
30 // being defined in an anonymous namespace.
31 
// Adapter around sdpa_with_kv_cache_out that supplies a default-constructed
// KernelRuntimeContext, so the kernel can be reached through a signature that
// has no context parameter. Every other argument is forwarded unchanged and
// the reference returned by the kernel is handed back to the caller.
Tensor& sdpa_with_kv_cache_out_no_context(
    const Tensor& q_projected,
    const Tensor& k_projected,
    const Tensor& v_projected,
    Tensor& key_cache,
    Tensor& value_cache,
    const int64_t start_pos,
    const int64_t seq_len,
    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<Tensor> attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const optional<double> scale,
    Tensor& output) {
  // Throwaway context that only lives for the duration of this call.
  executorch::runtime::KernelRuntimeContext ctx{};

  Tensor& result = torch::executor::native::sdpa_with_kv_cache_out(
      ctx,
      q_projected,
      k_projected,
      v_projected,
      key_cache,
      value_cache,
      start_pos,
      seq_len,
      attn_mask,
      dropout_p,
      is_causal,
      scale,
      output);
  return result;
}
64 
// ATen-facing entry point for the SDPA-with-KV-cache kernel. Allocates a
// result tensor with at::empty_like(q_projected), invokes the context-free
// out-variant through WRAP_TO_ATEN (the result tensor is passed as the last
// argument), and returns the result tensor.
at::Tensor sdpa_with_kv_cache_aten(
    const at::Tensor& q_projected,
    const at::Tensor& k_projected,
    const at::Tensor& v_projected,
    at::Tensor& key_cache,
    at::Tensor& value_cache,
    const int64_t start_pos,
    const int64_t seq_len,
    // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const std::optional<at::Tensor> attn_mask,
    const double dropout_p,
    const bool is_causal,
    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
    const std::optional<double> scale) {
  // Destination buffer handed to the out-variant.
  auto result = at::empty_like(q_projected);

  WRAP_TO_ATEN(sdpa_with_kv_cache_out_no_context, 11)
  (q_projected,
   k_projected,
   v_projected,
   key_cache,
   value_cache,
   start_pos,
   seq_len,
   attn_mask,
   dropout_p,
   is_causal,
   scale,
   result);

  return result;
}
96 
97 } // namespace native
98 } // namespace executor
99 } // namespace torch
100 
101 //
102 // Reference Implementation
103 //
104 
105 /*
106  * Converts a boolean mask to an additive mask. Values that are false are
107  * converted to -inf, and values that are true are converted to 0.
108  */
convert_boolean_attn_mask(const at::Tensor & attn_mask,caffe2::TypeMeta dtype)109 at::Tensor convert_boolean_attn_mask(
110     const at::Tensor& attn_mask,
111     caffe2::TypeMeta dtype) {
112   // Convert boolean mask to additive mask; need to invert mask to indicate what
113   // to mask *out*.
114   if (attn_mask.dtype() == at::kBool) {
115     return at::where(
116         attn_mask.logical_not(),
117         -std::numeric_limits<double>::infinity(),
118         at::scalar_tensor(
119             0.0, at::TensorOptions().dtype(dtype).device(attn_mask.device())));
120   }
121   // Otherwise, attn_mask represents an additive attention tensor
122   return attn_mask;
123 }
124 
125 /*
126  * Construct an attention mask for SDPA.
127  * 1. Construct a square matrix of ones with each dim equal to start_pos +
128  *    seq_len
129  * 2. Keep the lower triangular elements as 1 and set the rest to 0
130  * 3. Slice the mask to keep only seq_len rows starting from input_pos
131  * 4. Convert the mask to an additive mask
132  */
construct_attention_mask(const at::Tensor & q,const at::Tensor & k_cache,const int start_pos)133 at::Tensor construct_attention_mask(
134     const at::Tensor& q,
135     const at::Tensor& k_cache,
136     const int start_pos) {
137   const int max_seq_len = k_cache.size(1);
138   const int seq_len = q.size(1);
139 
140   const int length = start_pos + seq_len;
141   at::Tensor attn_mask_base =
142       at::ones({length, length}, q.options().dtype(at::kBool)).tril();
143 
144   at::Tensor attn_mask_sliced =
145       at::slice(attn_mask_base, 0, start_pos, start_pos + seq_len);
146 
147   attn_mask_sliced = convert_boolean_attn_mask(attn_mask_sliced, q.dtype());
148   return attn_mask_sliced;
149 }
150 
151 /*
152  * Reference implementation of SDPA
153  */
sdpa_reference_impl(const at::Tensor & q_projected,const at::Tensor & k_projected,const at::Tensor & v_projected,at::Tensor & key_cache,at::Tensor & value_cache,const int64_t start_pos,const int64_t seq_len,const std::optional<at::Tensor> __attn_mask_ignored,const double dropout_p,const bool is_causal,const std::optional<double> scale)154 at::Tensor sdpa_reference_impl(
155     const at::Tensor& q_projected,
156     const at::Tensor& k_projected,
157     const at::Tensor& v_projected,
158     at::Tensor& key_cache,
159     at::Tensor& value_cache,
160     const int64_t start_pos,
161     const int64_t seq_len,
162     const std::optional<at::Tensor> __attn_mask_ignored,
163     const double dropout_p,
164     const bool is_causal,
165     const std::optional<double> scale) {
166   at::Tensor attn_mask =
167       construct_attention_mask(q_projected, key_cache, start_pos);
168 
169   // Cache update
170   at::Tensor key_cache_updated = at::slice_scatter(
171       key_cache, k_projected, 1, start_pos, start_pos + k_projected.size(1));
172   at::Tensor value_cache_updated = at::slice_scatter(
173       value_cache, v_projected, 1, start_pos, start_pos + v_projected.size(1));
174 
175   // Write back to input
176   key_cache = key_cache_updated;
177   value_cache = value_cache_updated;
178 
179   at::Tensor key_cache_sliced =
180       at::slice(key_cache_updated, 1, 0, start_pos + q_projected.size(1));
181 
182   at::Tensor value_cache_sliced =
183       at::slice(value_cache_updated, 1, 0, start_pos + q_projected.size(1));
184 
185   // Since n_heads may not be the same as n_kv_heads, the sliced k and v cache
186   // matrices need to be "expanded" to match
187   const int num_repeats = q_projected.size(2) / key_cache.size(2);
188   at::Tensor key_cache_sliced_repeated =
189       at::repeat_interleave(key_cache_sliced, num_repeats, 2);
190   at::Tensor value_cache_sliced_repeated =
191       at::repeat_interleave(value_cache_sliced, num_repeats, 2);
192 
193   at::Tensor q_transposed = q_projected.transpose(1, 2);
194   at::Tensor k_transposed = key_cache_sliced_repeated.transpose(1, 2);
195   at::Tensor v_transposed = value_cache_sliced_repeated.transpose(1, 2);
196 
197   at::Tensor k_transposed_2 = k_transposed.transpose(-2, -1);
198   at::Tensor attn_weight_prescale = at::matmul(q_transposed, k_transposed_2);
199 
200   float scale_factor = 1.0 / sqrt(q_transposed.size(-1));
201   at::Tensor attn_weight = attn_weight_prescale * scale_factor + attn_mask;
202 
203   at::Tensor attn_weight_softmax = at::softmax(attn_weight, -1);
204   at::Tensor out = at::matmul(attn_weight_softmax, v_transposed);
205 
206   return out.transpose(1, 2);
207 }
208 
209 //
210 // Test functions
211 //
212 
test_reference_sdpa(const int start_input_pos,const int sequence_len,const int embedding_dim,const int num_heads,const int num_kv_heads,const int batch_size,const int max_seq_len,at::ScalarType dtype=at::kFloat)213 void test_reference_sdpa(
214     const int start_input_pos,
215     const int sequence_len,
216     const int embedding_dim,
217     const int num_heads,
218     const int num_kv_heads,
219     const int batch_size,
220     const int max_seq_len,
221     at::ScalarType dtype = at::kFloat) {
222   const int head_dim = embedding_dim / num_heads;
223 
224   // K and V caches. Need an extra set for the reference implementation
225 
226   at::Tensor k_cache = at::zeros(
227       {batch_size, max_seq_len, num_kv_heads, head_dim},
228       at::device(at::kCPU).dtype(dtype));
229   at::Tensor v_cache = at::zeros_like(k_cache);
230 
231   at::Tensor k_cache_ref = at::zeros_like(k_cache);
232   at::Tensor v_cache_ref = at::zeros_like(v_cache);
233 
234   for (int input_pos = start_input_pos; input_pos + sequence_len < max_seq_len;
235        input_pos += sequence_len) {
236     at::Tensor q = at::rand(
237         {batch_size, sequence_len, num_heads, head_dim},
238         at::device(at::kCPU).dtype(dtype));
239     at::Tensor k = at::rand(
240         {batch_size, sequence_len, num_kv_heads, head_dim},
241         at::device(at::kCPU).dtype(dtype));
242     at::Tensor v = at::rand_like(k);
243 
244     at::Tensor reference_impl_out = sdpa_reference_impl(
245         q, k, v, k_cache, v_cache, input_pos, sequence_len, {}, 0.0, true, {});
246 
247     at::Tensor reference_out = torch::executor::native::sdpa_with_kv_cache_aten(
248         q,
249         k,
250         v,
251         k_cache_ref,
252         v_cache_ref,
253         input_pos,
254         sequence_len,
255         {},
256         0.0,
257         true,
258         {});
259 
260     ASSERT_TRUE(at::allclose(reference_impl_out, reference_out));
261   }
262 }
263 
from_at_scalartype(c10::ScalarType at_scalartype)264 vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
265   using namespace vkcompute;
266   switch (at_scalartype) {
267     case c10::kFloat:
268       return vkapi::kFloat;
269     case c10::kHalf:
270       return vkapi::kHalf;
271     case c10::kInt:
272       return vkapi::kInt;
273     case c10::kLong:
274       return vkapi::kInt;
275     case c10::kChar:
276       return vkapi::kChar;
277     default:
278       VK_THROW("Unsupported at::ScalarType!");
279   }
280 }
281 
test_vulkan_sdpa(const int start_input_pos,const int base_sequence_len,const int embedding_dim,const int num_heads,const int num_kv_heads,const int batch_size,const int max_seq_len,const bool dynamic_seq_len=true,at::ScalarType dtype=at::kFloat)282 void test_vulkan_sdpa(
283     const int start_input_pos,
284     const int base_sequence_len,
285     const int embedding_dim,
286     const int num_heads,
287     const int num_kv_heads,
288     const int batch_size,
289     const int max_seq_len,
290     const bool dynamic_seq_len = true,
291     at::ScalarType dtype = at::kFloat) {
292   const int head_dim = embedding_dim / num_heads;
293 
294   const int init_seq_len = dynamic_seq_len ? max_seq_len : base_sequence_len;
295   // K and V caches
296 
297   at::Tensor k_cache = at::zeros(
298       {batch_size, max_seq_len, num_kv_heads, head_dim},
299       at::device(at::kCPU).dtype(dtype));
300 
301   at::Tensor v_cache = at::zeros_like(k_cache);
302 
303   // Reference input data
304   at::Tensor q = at::empty(
305       {batch_size, init_seq_len, num_heads, head_dim},
306       at::device(at::kCPU).dtype(dtype));
307   at::Tensor k = at::empty(
308       {batch_size, init_seq_len, num_kv_heads, head_dim},
309       at::device(at::kCPU).dtype(dtype));
310   at::Tensor v = at::empty_like(k);
311 
312   // Get reference output
313   at::Tensor out = at::empty_like(q);
314 
315   // Build Vulkan SDPA graph
316   using namespace vkcompute;
317 
318   GraphConfig config;
319   config.set_storage_type_override(utils::kTexture3D);
320   ComputeGraph graph(config);
321 
322   // "Data" variant for vulkan initialization
323 
324   at::Tensor k_cache_data = at::zeros_like(k_cache);
325   at::Tensor v_cache_data = at::zeros_like(v_cache);
326 
327 #define MAKE_TENSORREF_FOR(x)              \
328   ValueRef r_##x = graph.add_tensorref(    \
329       x.sizes().vec(),                     \
330       from_at_scalartype(x.scalar_type()), \
331       x.const_data_ptr());
332 
333   MAKE_TENSORREF_FOR(k_cache_data);
334   MAKE_TENSORREF_FOR(v_cache_data);
335 
336 #define MAKE_INPUT_FOR(x)                    \
337   IOValueRef r_##x = graph.add_input_tensor( \
338       x.sizes().vec(), from_at_scalartype(x.scalar_type()));
339 
340   MAKE_INPUT_FOR(q);
341   MAKE_INPUT_FOR(k);
342   MAKE_INPUT_FOR(v);
343 #undef MAKE_INPUT_FOR
344 
345   const ValueRef r_input_pos_symint = graph.add_symint(start_input_pos);
346   const ValueRef r_out = graph.add_tensor(
347       out.sizes().vec(), from_at_scalartype(out.scalar_type()));
348 
349   VK_GET_OP_FN("sdpa_with_kv_cache.default")
350   (graph,
351    {
352        r_q.value,
353        r_k.value,
354        r_v.value,
355        r_k_cache_data,
356        r_v_cache_data,
357        r_input_pos_symint,
358        kDummyValueRef, // sequence_len
359        kDummyValueRef, // attn_mask
360        kDummyValueRef, // dropout_p
361        kDummyValueRef, // is_causal
362        kDummyValueRef, // scale
363        r_out,
364    });
365 
366   ValueRef staging_out = graph.set_output_tensor(r_out);
367 
368   graph.prepare();
369   graph.encode_prepack();
370   graph.prepack();
371   graph.encode_execute();
372 
373   //
374   // Run model
375   //
376 
377 #define COPY_INPUT(x) \
378   graph.copy_into_staging(r_##x.staging, x.const_data_ptr(), x.numel());
379 
380 #define EXTRACT_TENSOR(x)                             \
381   at::Tensor vk_##x = at::zeros_like(x).contiguous(); \
382   graph.copy_from_staging(                            \
383       staging_##x, vk_##x.mutable_data_ptr(), vk_##x.numel());
384 
385   int seq_len = base_sequence_len;
386   for (int i = 0, input_pos = start_input_pos;
387        input_pos + seq_len < max_seq_len;
388        input_pos += seq_len, i++) {
389     q = at::rand(
390         {batch_size, seq_len, num_heads, head_dim},
391         at::device(at::kCPU).dtype(dtype));
392     k = at::rand(
393         {batch_size, seq_len, num_kv_heads, head_dim},
394         at::device(at::kCPU).dtype(dtype));
395     v = at::rand_like(k);
396 
397     at::Tensor reference_out = sdpa_reference_impl(
398         q, k, v, k_cache, v_cache, input_pos, seq_len, {}, 0.0, true, {});
399 
400     graph.set_symint(r_input_pos_symint, input_pos);
401     graph.resize_input(0, q.sizes().vec());
402     graph.resize_input(1, k.sizes().vec());
403     graph.resize_input(2, v.sizes().vec());
404     graph.propagate_resize();
405 
406     // Run Vulkan SDPA
407     COPY_INPUT(q);
408     COPY_INPUT(k);
409     COPY_INPUT(v);
410 
411     graph.execute();
412 
413     out = at::empty_like(q);
414     EXTRACT_TENSOR(out);
415 
416     const bool output_correct = at::allclose(reference_out, vk_out);
417     if (!output_correct) {
418       at::Tensor diffs = at::abs(reference_out - vk_out);
419 
420       std::cout << "Failed at input_pos " << input_pos << " with seq_len "
421                 << seq_len << std::endl;
422 
423       std::cout << "Maximum difference: " << std::endl;
424       std::cout << at::max(diffs).item() << std::endl;
425       std::cout << "Found at index " << std::endl;
426       std::cout << at::argmax(diffs).item() << std::endl;
427 
428       std::cout << "Maximum value observed: " << std::endl;
429       std::cout << at::max(at::abs(at::cat({reference_out, vk_out}, -1))).item()
430                 << std::endl;
431     }
432     ASSERT_TRUE(output_correct);
433 
434     if (dynamic_seq_len) {
435       seq_len = base_sequence_len + (i % 3);
436     }
437   }
438 }
439 
// Small configuration with a fixed sequence length on every step.
TEST(VulkanSDPATest, test_sdpa_op_small_params) {
  test_vulkan_sdpa(
      /*start_input_pos=*/0,
      /*base_sequence_len=*/3,
      /*embedding_dim=*/18,
      /*num_heads=*/6,
      /*num_kv_heads=*/2,
      /*batch_size=*/1,
      /*max_seq_len=*/7,
      /*dynamic_seq_len=*/false);
}
459 
// Small configuration; dynamic_seq_len is left at its default (true), so the
// sequence length varies between steps.
TEST(VulkanSDPATest, test_sdpa_op_small_params_dynamic) {
  test_vulkan_sdpa(
      /*start_input_pos=*/0,
      /*base_sequence_len=*/3,
      /*embedding_dim=*/18,
      /*num_heads=*/6,
      /*num_kv_heads=*/2,
      /*batch_size=*/1,
      /*max_seq_len=*/12);
}
478 
// Larger configuration (32 heads, 8 KV heads, embedding dim 2048);
// dynamic_seq_len is left at its default (true).
TEST(VulkanSDPATest, test_sdpa_op_llama3_params_dynamic) {
  test_vulkan_sdpa(
      /*start_input_pos=*/0,
      /*base_sequence_len=*/3,
      /*embedding_dim=*/2048,
      /*num_heads=*/32,
      /*num_kv_heads=*/8,
      /*batch_size=*/1,
      /*max_seq_len=*/128);
}
497 
// Validates the ATen-composed reference implementation against the wrapped
// sdpa_with_kv_cache kernel, using the larger configuration.
TEST(VulkanSDPATest, test_reference_impl) {
  test_reference_sdpa(
      /*start_input_pos=*/0,
      /*sequence_len=*/3,
      /*embedding_dim=*/2048,
      /*num_heads=*/32,
      /*num_kv_heads=*/8,
      /*batch_size=*/1,
      /*max_seq_len=*/128);
}
516