1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <limits>
10
11 #include <executorch/extension/llm/custom_ops/op_sdpa.h> // Declares the operator
12 #include <executorch/kernels/test/TestUtil.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15
16 #include <gtest/gtest.h>
17
18 using namespace ::testing;
19 using executorch::runtime::testing::TensorFactory;
20
/**
 * Test helper: invokes torch::executor::native::sdpa_with_kv_cache_out with a
 * freshly constructed KernelRuntimeContext and returns the kernel's result.
 *
 * key_cache / value_cache are passed by non-const reference because the kernel
 * writes into them; the tests below rely on state persisting in a cache across
 * successive calls.
 *
 * ExecuTorch kernels report argument/shape errors through the runtime context
 * rather than by throwing. Checking the context's failure state here makes
 * such a rejection show up as an explicit test failure instead of only as an
 * indirect value mismatch on `out`.
 */
exec_aten::Tensor op_sdpa_with_kv_cache(
    const exec_aten::Tensor& query,
    const exec_aten::Tensor& key,
    const exec_aten::Tensor& value,
    exec_aten::Tensor& key_cache,
    exec_aten::Tensor& value_cache,
    const int64_t start_pos,
    const int64_t seq_len,
    const exec_aten::optional<exec_aten::Tensor>& attn_mask,
    double dropout_p,
    bool is_causal,
    exec_aten::optional<double> scale,
    exec_aten::Tensor& out) {
  executorch::runtime::KernelRuntimeContext context{};
  exec_aten::Tensor ret = torch::executor::native::sdpa_with_kv_cache_out(
      context,
      query,
      key,
      value,
      key_cache,
      value_cache,
      start_pos,
      seq_len,
      attn_mask,
      dropout_p,
      is_causal,
      scale,
      out);
  // Surface kernel-reported failures (ET_KERNEL_CHECK etc.) explicitly.
  EXPECT_EQ(context.failure_state(), executorch::runtime::Error::Ok)
      << "sdpa_with_kv_cache_out reported a kernel failure";
  return ret;
}
50
/*
SDPA with a KV cache is equivalent to the following code:
# q = (batch size, q seq len, num heads, head dim)
# k, v = (batch size, kv seq len, num heads, head dim)
# k_cache, v_cache = (batch size, max seq length, num heads, head dim);
#   the tests below keep one separate pair of caches per layer
# attn_mask = [max seq length, max seq length]

def sdpa_with_cache(q, k, v, k_cache, v_cache, start_pos, seq_len, attn_mask):
    attn_mask = attn_mask[start_pos].view((1, -1))
    q = q.transpose(1, 2)
    k_cache[:, start_pos : start_pos + seq_len] = k
    v_cache[:, start_pos : start_pos + seq_len] = v
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    # Only the filled prefix of the cache (positions [0, start_pos + seq_len))
    # takes part in attention.
    sliced_k_cache = k_cache[:, : start_pos + seq_len]
    sliced_v_cache = v_cache[:, : start_pos + seq_len]
    sliced_k_cache = sliced_k_cache.transpose(1, 2)
    sliced_v_cache = sliced_v_cache.transpose(1, 2)
    out = F.scaled_dot_product_attention(q, sliced_k_cache, sliced_v_cache,
                                         attn_mask=attn_mask)
    out = out.transpose(1, 2)
*/
73
/*
Missing tests:
1. Batch sizes other than 1
2. Item 1 combined with an attention_mask
3. Boolean attention_mask (only float masks are tested so far)
4. An explicit `scale` argument
5. Other dtypes: fp16, bf16, double (or expect the kernel to fail)
*/
TEST(OpScaledDotProductAttentionTest, BasicTest) {
  TensorFactory<exec_aten::ScalarType::Float> tf;

  // Shape of query/key/value and of every output.
  const std::vector<int32_t> step_shape = {1, 1, 4, 4};
  // Shape of every key/value cache.
  const std::vector<int32_t> cache_shape = {1, 5, 4, 4};

  exec_aten::Tensor query = tf.make(
      step_shape,
      {0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009, 0.2566, 0.7936,
       0.9408, 0.1332, 0.9346, 0.5936, 0.8694, 0.5677, 0.7411, 0.4294});
  exec_aten::Tensor key = tf.make(
      step_shape,
      {0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 0.4414, 0.2969, 0.8317,
       0.1053, 0.2695, 0.3588, 0.1994, 0.5472, 0.0062, 0.9516, 0.0753});
  exec_aten::Tensor value = tf.make(
      step_shape,
      {0.8860, 0.5832, 0.3376, 0.8090, 0.5779, 0.9040, 0.5547, 0.3423,
       0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895});

  // Three independent cache pairs ("layers"), all starting out as zeros.
  exec_aten::Tensor k_cache_layer0 = tf.zeros(cache_shape);
  exec_aten::Tensor v_cache_layer0 = tf.zeros(cache_shape);
  exec_aten::Tensor k_cache_layer1 = tf.zeros(cache_shape);
  exec_aten::Tensor v_cache_layer1 = tf.zeros(cache_shape);
  exec_aten::Tensor k_cache_layer2 = tf.zeros(cache_shape);
  exec_aten::Tensor v_cache_layer2 = tf.zeros(cache_shape);

  const exec_aten::optional<exec_aten::Tensor> no_mask;
  const exec_aten::optional<double> no_scale;

  // Runs a single step (seq_len 1, no dropout, non-causal, no mask, default
  // scale) at `start_pos` against the given caches, into a fresh zero output.
  auto run_step = [&](exec_aten::Tensor& k_cache,
                      exec_aten::Tensor& v_cache,
                      int64_t start_pos) {
    exec_aten::Tensor out = tf.zeros(step_shape);
    return op_sdpa_with_kv_cache(
        query,
        key,
        value,
        k_cache,
        v_cache,
        start_pos,
        /*seq_len=*/1,
        no_mask,
        /*dropout_p=*/0,
        /*is_causal=*/false,
        no_scale,
        out);
  };

  // Expected output when one attended slot is still zero and one holds
  // (key, value).
  const exec_aten::Tensor expected_one_empty_one_filled = tf.make(
      step_shape,
      {0.6486, 0.4270, 0.2472, 0.5922, 0.3669, 0.5740, 0.3522, 0.2173,
       0.3635, 0.2088, 0.4071, 0.5423, 0.5110, 0.1822, 0.5107, 0.3817});
  // Expected output when one attended slot is still zero and two hold
  // (key, value).
  const exec_aten::Tensor expected_one_empty_two_filled = tf.make(
      step_shape,
      {0.7490, 0.4930, 0.2854, 0.6838, 0.4489, 0.7021, 0.4308, 0.2659,
       0.4622, 0.2655, 0.5176, 0.6895, 0.6202, 0.2212, 0.6199, 0.4634});

  // When every attended slot holds the same (key, value) pair, the output is
  // exactly `value`.

  // start pos 0, layer 0
  {
    const exec_aten::Tensor ret = run_step(k_cache_layer0, v_cache_layer0, 0);
    EXPECT_TENSOR_CLOSE_WITH_TOL(ret, value, 1e-4, 1e-4);
  }
  // start pos 0, layer 2
  {
    const exec_aten::Tensor ret = run_step(k_cache_layer2, v_cache_layer2, 0);
    EXPECT_TENSOR_CLOSE_WITH_TOL(ret, value, 1e-4, 1e-4);
  }
  // start pos 1, layer 0: slots 0 and 1 both hold (key, value).
  {
    const exec_aten::Tensor ret = run_step(k_cache_layer0, v_cache_layer0, 1);
    EXPECT_TENSOR_CLOSE_WITH_TOL(ret, value, 1e-4, 1e-4);
  }
  // start pos 1, layer 1: slot 0 of this cache was never written.
  {
    const exec_aten::Tensor ret = run_step(k_cache_layer1, v_cache_layer1, 1);
    EXPECT_TENSOR_CLOSE_WITH_TOL(ret, expected_one_empty_one_filled, 1e-4, 1e-4);
  }
  // start pos 2, layer 1: slot 0 empty, slots 1 and 2 filled.
  {
    const exec_aten::Tensor ret = run_step(k_cache_layer1, v_cache_layer1, 2);
    EXPECT_TENSOR_CLOSE_WITH_TOL(ret, expected_one_empty_two_filled, 1e-4, 1e-4);
  }
  // start pos 2, layer 2: slot 1 empty, slots 0 and 2 filled.
  {
    const exec_aten::Tensor ret = run_step(k_cache_layer2, v_cache_layer2, 2);
    EXPECT_TENSOR_CLOSE_WITH_TOL(ret, expected_one_empty_two_filled, 1e-4, 1e-4);
  }
}
362
TEST(OpScaledDotProductAttentionTest, LargerTest) {
  TensorFactory<exec_aten::ScalarType::Float> tf;

  // Shape of query/key/value and of every output.
  const std::vector<int32_t> step_shape = {1, 1, 7, 4};
  // Shape of every key/value cache.
  const std::vector<int32_t> cache_shape = {1, 8, 7, 4};

  exec_aten::Tensor query = tf.make(
      step_shape,
      {0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009, 0.2566,
       0.7936, 0.9408, 0.1332, 0.9346, 0.5936, 0.8694, 0.5677,
       0.7411, 0.4294, 0.8854, 0.5739, 0.2666, 0.6274, 0.2696,
       0.4414, 0.2969, 0.8317, 0.1053, 0.2695, 0.3588, 0.1994});
  exec_aten::Tensor key = tf.make(
      step_shape,
      {0.5472, 0.0062, 0.9516, 0.0753, 0.8860, 0.5832, 0.3376,
       0.8090, 0.5779, 0.9040, 0.5547, 0.3423, 0.6343, 0.3644,
       0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895, 0.7539,
       0.1952, 0.0050, 0.3068, 0.1165, 0.9103, 0.6440, 0.7071});
  exec_aten::Tensor value = tf.make(
      step_shape,
      {0.6581, 0.4913, 0.8913, 0.1447, 0.5315, 0.1587, 0.6542,
       0.3278, 0.6532, 0.3958, 0.9147, 0.2036, 0.2018, 0.2018,
       0.9497, 0.6666, 0.9811, 0.0874, 0.0041, 0.1088, 0.1637,
       0.7025, 0.6790, 0.9155, 0.2418, 0.1591, 0.7653, 0.2979});

  // Three independent cache pairs ("layers"), all starting out as zeros.
  exec_aten::Tensor k_cache_layer0 = tf.zeros(cache_shape);
  exec_aten::Tensor v_cache_layer0 = tf.zeros(cache_shape);
  exec_aten::Tensor k_cache_layer1 = tf.zeros(cache_shape);
  exec_aten::Tensor v_cache_layer1 = tf.zeros(cache_shape);
  exec_aten::Tensor k_cache_layer2 = tf.zeros(cache_shape);
  exec_aten::Tensor v_cache_layer2 = tf.zeros(cache_shape);

  const exec_aten::optional<exec_aten::Tensor> no_mask;
  const exec_aten::optional<double> no_scale;

  // Runs a single step (seq_len 1, no dropout, non-causal, no mask, default
  // scale) at `start_pos` against the given caches, into a fresh zero output.
  auto run_step = [&](exec_aten::Tensor& k_cache,
                      exec_aten::Tensor& v_cache,
                      int64_t start_pos) {
    exec_aten::Tensor out = tf.zeros(step_shape);
    return op_sdpa_with_kv_cache(
        query,
        key,
        value,
        k_cache,
        v_cache,
        start_pos,
        /*seq_len=*/1,
        no_mask,
        /*dropout_p=*/0,
        /*is_causal=*/false,
        no_scale,
        out);
  };

  // Expected output when one attended slot is still zero and one holds
  // (key, value).
  const exec_aten::Tensor expected_one_empty_one_filled = tf.make(
      step_shape,
      {0.4038, 0.3015, 0.5469, 0.0888, 0.3566, 0.1065, 0.4389,
       0.2199, 0.4354, 0.2639, 0.6097, 0.1358, 0.1412, 0.1412,
       0.6645, 0.4664, 0.6599, 0.0588, 0.0027, 0.0732, 0.0929,
       0.3989, 0.3856, 0.5198, 0.1398, 0.0920, 0.4424, 0.1722});
  // Expected output when one attended slot is still zero and two hold
  // (key, value).
  const exec_aten::Tensor expected_one_empty_two_filled = tf.make(
      step_shape,
      {0.5005, 0.3737, 0.6779, 0.1101, 0.4268, 0.1275, 0.5254,
       0.2633, 0.5225, 0.3166, 0.7317, 0.1629, 0.1661, 0.1661,
       0.7819, 0.5488, 0.7891, 0.0703, 0.0033, 0.0875, 0.1185,
       0.5089, 0.4919, 0.6631, 0.1771, 0.1166, 0.5607, 0.2182});

  // When every attended slot holds the same (key, value) pair, the output is
  // exactly `value`.

  // start pos 0, layer 0
  {
    const exec_aten::Tensor ret = run_step(k_cache_layer0, v_cache_layer0, 0);
    EXPECT_TENSOR_CLOSE_WITH_TOL(ret, value, 1e-4, 1e-4);
  }
  // start pos 0, layer 2
  {
    const exec_aten::Tensor ret = run_step(k_cache_layer2, v_cache_layer2, 0);
    EXPECT_TENSOR_CLOSE_WITH_TOL(ret, value, 1e-4, 1e-4);
  }
  // start pos 1, layer 0: slots 0 and 1 both hold (key, value).
  {
    const exec_aten::Tensor ret = run_step(k_cache_layer0, v_cache_layer0, 1);
    EXPECT_TENSOR_CLOSE_WITH_TOL(ret, value, 1e-4, 1e-4);
  }
  // start pos 1, layer 1: slot 0 of this cache was never written.
  {
    const exec_aten::Tensor ret = run_step(k_cache_layer1, v_cache_layer1, 1);
    EXPECT_TENSOR_CLOSE_WITH_TOL(ret, expected_one_empty_one_filled, 1e-4, 1e-4);
  }
  // start pos 2, layer 1: slot 0 empty, slots 1 and 2 filled.
  {
    const exec_aten::Tensor ret = run_step(k_cache_layer1, v_cache_layer1, 2);
    EXPECT_TENSOR_CLOSE_WITH_TOL(ret, expected_one_empty_two_filled, 1e-4, 1e-4);
  }
  // start pos 2, layer 2: slot 1 empty, slots 0 and 2 filled.
  {
    const exec_aten::Tensor ret = run_step(k_cache_layer2, v_cache_layer2, 2);
    EXPECT_TENSOR_CLOSE_WITH_TOL(ret, expected_one_empty_two_filled, 1e-4, 1e-4);
  }
}
526
TEST(OpScaledDotProductAttentionTest, BasicTestWithAttnMask) {
  TensorFactory<exec_aten::ScalarType::Float> tf;

  // Shape of query/key/value and of every output.
  const std::vector<int32_t> step_shape = {1, 1, 4, 4};
  // Shape of every key/value cache.
  const std::vector<int32_t> cache_shape = {1, 5, 4, 4};

  exec_aten::Tensor query = tf.make(
      step_shape,
      {0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009, 0.2566, 0.7936,
       0.9408, 0.1332, 0.9346, 0.5936, 0.8694, 0.5677, 0.7411, 0.4294});
  exec_aten::Tensor key = tf.make(
      step_shape,
      {0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 0.4414, 0.2969, 0.8317,
       0.1053, 0.2695, 0.3588, 0.1994, 0.5472, 0.0062, 0.9516, 0.0753});
  exec_aten::Tensor value = tf.make(
      step_shape,
      {0.8860, 0.5832, 0.3376, 0.8090, 0.5779, 0.9040, 0.5547, 0.3423,
       0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895});

  // Three independent cache pairs ("layers"), all starting out as zeros.
  exec_aten::Tensor k_cache_layer0 = tf.zeros(cache_shape);
  exec_aten::Tensor v_cache_layer0 = tf.zeros(cache_shape);
  exec_aten::Tensor k_cache_layer1 = tf.zeros(cache_shape);
  exec_aten::Tensor v_cache_layer1 = tf.zeros(cache_shape);
  exec_aten::Tensor k_cache_layer2 = tf.zeros(cache_shape);
  exec_aten::Tensor v_cache_layer2 = tf.zeros(cache_shape);

  const exec_aten::optional<double> no_scale;

  // Runs a single step (seq_len 1, no dropout, non-causal, default scale) at
  // `start_pos` against the given caches, into a fresh zero output. The float
  // mask is all zeros with shape {1, start_pos + 1}, i.e. one entry per
  // attended cache position, so results must match the unmasked BasicTest.
  auto run_step = [&](exec_aten::Tensor& k_cache,
                      exec_aten::Tensor& v_cache,
                      int64_t start_pos) {
    exec_aten::Tensor mask =
        tf.zeros({1, static_cast<int32_t>(start_pos + 1)});
    exec_aten::Tensor out = tf.zeros(step_shape);
    return op_sdpa_with_kv_cache(
        query,
        key,
        value,
        k_cache,
        v_cache,
        start_pos,
        /*seq_len=*/1,
        mask,
        /*dropout_p=*/0,
        /*is_causal=*/false,
        no_scale,
        out);
  };

  // Expected output when one attended slot is still zero and one holds
  // (key, value).
  const exec_aten::Tensor expected_one_empty_one_filled = tf.make(
      step_shape,
      {0.6486, 0.4270, 0.2472, 0.5922, 0.3669, 0.5740, 0.3522, 0.2173,
       0.3635, 0.2088, 0.4071, 0.5423, 0.5110, 0.1822, 0.5107, 0.3817});
  // Expected output when one attended slot is still zero and two hold
  // (key, value).
  const exec_aten::Tensor expected_one_empty_two_filled = tf.make(
      step_shape,
      {0.7490, 0.4930, 0.2854, 0.6838, 0.4489, 0.7021, 0.4308, 0.2659,
       0.4622, 0.2655, 0.5176, 0.6895, 0.6202, 0.2212, 0.6199, 0.4634});

  // When every attended slot holds the same (key, value) pair, the output is
  // exactly `value`.

  // start pos 0, layer 0 (mask {1, 1})
  {
    const exec_aten::Tensor ret = run_step(k_cache_layer0, v_cache_layer0, 0);
    EXPECT_TENSOR_CLOSE_WITH_TOL(ret, value, 1e-4, 1e-4);
  }
  // start pos 0, layer 2 (mask {1, 1})
  {
    const exec_aten::Tensor ret = run_step(k_cache_layer2, v_cache_layer2, 0);
    EXPECT_TENSOR_CLOSE_WITH_TOL(ret, value, 1e-4, 1e-4);
  }
  // start pos 1, layer 0 (mask {1, 2}): slots 0 and 1 both hold (key, value).
  {
    const exec_aten::Tensor ret = run_step(k_cache_layer0, v_cache_layer0, 1);
    EXPECT_TENSOR_CLOSE_WITH_TOL(ret, value, 1e-4, 1e-4);
  }
  // start pos 1, layer 1 (mask {1, 2}): slot 0 was never written.
  {
    const exec_aten::Tensor ret = run_step(k_cache_layer1, v_cache_layer1, 1);
    EXPECT_TENSOR_CLOSE_WITH_TOL(ret, expected_one_empty_one_filled, 1e-4, 1e-4);
  }
  // start pos 2, layer 1 (mask {1, 3}): slot 0 empty, slots 1 and 2 filled.
  {
    const exec_aten::Tensor ret = run_step(k_cache_layer1, v_cache_layer1, 2);
    EXPECT_TENSOR_CLOSE_WITH_TOL(ret, expected_one_empty_two_filled, 1e-4, 1e-4);
  }
  // start pos 2, layer 2 (mask {1, 3}): slot 1 empty, slots 0 and 2 filled.
  {
    const exec_aten::Tensor ret = run_step(k_cache_layer2, v_cache_layer2, 2);
    EXPECT_TENSOR_CLOSE_WITH_TOL(ret, expected_one_empty_two_filled, 1e-4, 1e-4);
  }
}
809
TEST(OpScaledDotProductAttentionTest, SequenceTest) {
  TensorFactory<exec_aten::ScalarType::Float> tf;

  // Shape of query/key/value and of every output.
  const std::vector<int32_t> step_shape = {1, 1, 8, 4};

  // A single cache pair (layer 0) that accumulates all four steps.
  exec_aten::Tensor k_cache = tf.zeros({1, 5, 8, 4});
  exec_aten::Tensor v_cache = tf.zeros({1, 5, 8, 4});

  const exec_aten::optional<exec_aten::Tensor> no_mask;
  const exec_aten::optional<double> no_scale;

  // Feeds one new (query, key, value) token at `start_pos` into the shared
  // cache (seq_len 1, no dropout, non-causal, no mask, default scale) and
  // returns the attention output, written into a fresh zero tensor.
  auto decode_step = [&](int64_t start_pos,
                         const exec_aten::Tensor& query,
                         const exec_aten::Tensor& key,
                         const exec_aten::Tensor& value) {
    exec_aten::Tensor out = tf.zeros(step_shape);
    return op_sdpa_with_kv_cache(
        query,
        key,
        value,
        k_cache,
        v_cache,
        start_pos,
        /*seq_len=*/1,
        no_mask,
        /*dropout_p=*/0,
        /*is_causal=*/false,
        no_scale,
        out);
  };

  // start pos 0: only one slot is attended, so the output equals `value`.
  {
    const exec_aten::Tensor query = tf.make(
        step_shape,
        {0.1261, 0.5031, 0.1117, 0.3905, 0.3625, 0.9328, 0.6549, 0.4128,
         0.5845, 0.3557, 0.6965, 0.6978, 0.6343, 0.3051, 0.9266, 0.4278,
         0.3053, 0.8132, 0.9075, 0.9976, 0.6481, 0.3296, 0.7539, 0.9290,
         0.0096, 0.4381, 0.1590, 0.5932, 0.7068, 0.3967, 0.4582, 0.7251});
    const exec_aten::Tensor key = tf.make(
        step_shape,
        {0.4160, 0.0801, 0.9001, 0.2483, 0.4451, 0.5472, 0.4700, 0.0297,
         0.7294, 0.2729, 0.2407, 0.6195, 0.2391, 0.2689, 0.3315, 0.3122,
         0.2912, 0.3652, 0.6299, 0.0954, 0.1974, 0.5073, 0.5695, 0.7761,
         0.1488, 0.6596, 0.7842, 0.7776, 0.0343, 0.3092, 0.0702, 0.1836});
    const exec_aten::Tensor value = tf.make(
        step_shape,
        {0.7785, 0.4253, 0.7124, 0.2065, 0.5760, 0.1976, 0.7499, 0.2813,
         0.3746, 0.0662, 0.5017, 0.9747, 0.7427, 0.2332, 0.5067, 0.4452,
         0.0975, 0.8920, 0.5081, 0.6053, 0.2981, 0.2660, 0.5824, 0.6849,
         0.6121, 0.2590, 0.9854, 0.4264, 0.1938, 0.2661, 0.9922, 0.5000});
    const exec_aten::Tensor ret = decode_step(0, query, key, value);
    EXPECT_TENSOR_CLOSE_WITH_TOL(ret, value, 1e-4, 1e-4);
  }

  // start pos 1
  {
    const exec_aten::Tensor query = tf.make(
        step_shape,
        {0.4321, 0.2919, 0.3689, 0.0789, 0.1027, 0.7926, 0.9277, 0.9772,
         0.1390, 0.7704, 0.1905, 0.7983, 0.8608, 0.8869, 0.8600, 0.8128,
         0.5097, 0.7297, 0.3211, 0.7177, 0.3393, 0.4916, 0.0648, 0.3693,
         0.2371, 0.3313, 0.1807, 0.0503, 0.5326, 0.8245, 0.9554, 0.7918});
    const exec_aten::Tensor key = tf.make(
        step_shape,
        {0.2408, 0.0055, 0.6897, 0.7802, 0.0707, 0.6793, 0.9227, 0.5303,
         0.1988, 0.9099, 0.7135, 0.8311, 0.1619, 0.7910, 0.1585, 0.9947,
         0.2882, 0.8013, 0.6001, 0.6325, 0.4233, 0.7054, 0.2916, 0.0287,
         0.3079, 0.8918, 0.3684, 0.6572, 0.3151, 0.8751, 0.7992, 0.6765});
    const exec_aten::Tensor value = tf.make(
        step_shape,
        {0.2444, 0.0914, 0.5188, 0.2067, 0.9111, 0.0195, 0.7234, 0.9985,
         0.7504, 0.6705, 0.0189, 0.9809, 0.4145, 0.0328, 0.9936, 0.2965,
         0.4646, 0.9576, 0.1534, 0.1463, 0.5813, 0.4331, 0.6152, 0.0806,
         0.5150, 0.2776, 0.2542, 0.0422, 0.7651, 0.5963, 0.0773, 0.8968});
    const exec_aten::Tensor expected = tf.make(
        step_shape,
        {0.5203, 0.2639, 0.6188, 0.2066, 0.7836, 0.0872, 0.7335, 0.7256,
         0.5940, 0.4189, 0.2199, 0.9784, 0.5461, 0.1132, 0.7983, 0.3561,
         0.3125, 0.9305, 0.3003, 0.3364, 0.4355, 0.3471, 0.5983, 0.3918,
         0.5631, 0.2684, 0.6168, 0.2327, 0.5942, 0.4976, 0.3510, 0.7781});
    const exec_aten::Tensor ret = decode_step(1, query, key, value);
    EXPECT_TENSOR_CLOSE_WITH_TOL(ret, expected, 1e-4, 1e-4);
  }

  // start pos 2
  {
    const exec_aten::Tensor query = tf.make(
        step_shape,
        {0.6508, 0.5928, 0.2064, 0.5754, 0.9818, 0.8429, 0.1106, 0.9564,
         0.5388, 0.7405, 0.8883, 0.9263, 0.1102, 0.9378, 0.1604, 0.5375,
         0.1506, 0.3904, 0.4773, 0.4402, 0.4210, 0.5394, 0.9932, 0.7905,
         0.7797, 0.7001, 0.8871, 0.4769, 0.5398, 0.6029, 0.0639, 0.0972});
    const exec_aten::Tensor key = tf.make(
        step_shape,
        {0.5613, 0.3044, 0.4908, 0.3853, 0.5778, 0.8253, 0.3342, 0.9004,
         0.8948, 0.1163, 0.1139, 0.0955, 0.2260, 0.3054, 0.4624, 0.3784,
         0.2474, 0.3412, 0.3191, 0.9905, 0.3147, 0.1420, 0.7078, 0.4711,
         0.8828, 0.8124, 0.9594, 0.1338, 0.8214, 0.9196, 0.2531, 0.9596});
    const exec_aten::Tensor value = tf.make(
        step_shape,
        {0.8748, 0.5055, 0.7411, 0.3252, 0.0639, 0.6264, 0.6491, 0.1732,
         0.7425, 0.0729, 0.9303, 0.9842, 0.6361, 0.1863, 0.7433, 0.5852,
         0.6360, 0.6643, 0.8807, 0.2851, 0.3875, 0.6364, 0.5545, 0.9032,
         0.2374, 0.4818, 0.5934, 0.3672, 0.8409, 0.5547, 0.0379, 0.4458});
    const exec_aten::Tensor expected = tf.make(
        step_shape,
        {0.6350, 0.3426, 0.6582, 0.2484, 0.4391, 0.3419, 0.6962, 0.4399,
         0.6321, 0.3475, 0.3754, 0.9798, 0.5721, 0.1344, 0.7829, 0.4233,
         0.4122, 0.8394, 0.5040, 0.3304, 0.4066, 0.4378, 0.5820, 0.5922,
         0.4333, 0.3541, 0.6168, 0.2918, 0.6486, 0.4949, 0.2965, 0.6151});
    const exec_aten::Tensor ret = decode_step(2, query, key, value);
    EXPECT_TENSOR_CLOSE_WITH_TOL(ret, expected, 1e-4, 1e-4);
  }

  // start pos 3
  {
    const exec_aten::Tensor query = tf.make(
        step_shape,
        {0.2732, 0.5486, 0.4419, 0.0040, 0.4089, 0.4521, 0.3526, 0.9594,
         0.3909, 0.8212, 0.6239, 0.0779, 0.6175, 0.9144, 0.1729, 0.1768,
         0.9894, 0.9018, 0.2211, 0.8009, 0.4360, 0.0070, 0.5376, 0.6615,
         0.3500, 0.6739, 0.0724, 0.8465, 0.9263, 0.7757, 0.5847, 0.6647});
    const exec_aten::Tensor key = tf.make(
        step_shape,
        {0.1382, 0.3751, 0.4523, 0.2218, 0.1307, 0.8363, 0.8393, 0.0459,
         0.6591, 0.7034, 0.9750, 0.7893, 0.9597, 0.3363, 0.8502, 0.9067,
         0.0278, 0.0986, 0.6012, 0.7730, 0.2516, 0.5551, 0.4993, 0.6266,
         0.2313, 0.7820, 0.8325, 0.1531, 0.5048, 0.5014, 0.6606, 0.9658});
    const exec_aten::Tensor value = tf.make(
        step_shape,
        {0.6466, 0.3864, 0.9491, 0.3097, 0.3548, 0.5341, 0.1192, 0.5544,
         0.1608, 0.5514, 0.5479, 0.5692, 0.0784, 0.0251, 0.7301, 0.9288,
         0.0563, 0.6852, 0.1319, 0.5313, 0.9652, 0.8793, 0.1344, 0.8093,
         0.7612, 0.4992, 0.9844, 0.3014, 0.3836, 0.2473, 0.5719, 0.6324});
    const exec_aten::Tensor expected = tf.make(
        step_shape,
        {0.6441, 0.3571, 0.7319, 0.2624, 0.4506, 0.3619, 0.5749, 0.4930,
         0.4860, 0.3924, 0.4596, 0.8517, 0.4312, 0.1060, 0.7579, 0.5796,
         0.3507, 0.8063, 0.4223, 0.3597, 0.5522, 0.5558, 0.4665, 0.6486,
         0.5263, 0.3701, 0.6880, 0.2790, 0.6116, 0.4449, 0.3184, 0.6258});
    const exec_aten::Tensor ret = decode_step(3, query, key, value);
    EXPECT_TENSOR_CLOSE_WITH_TOL(ret, expected, 1e-4, 1e-4);
  }
}
988