xref: /aosp_15_r20/external/executorch/extension/llm/custom_ops/op_sdpa_with_kv_cache_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <limits>
10 
11 #include <executorch/extension/llm/custom_ops/op_sdpa.h> // Declares the operator
12 #include <executorch/kernels/test/TestUtil.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 
16 #include <gtest/gtest.h>
17 
18 using namespace ::testing;
19 using executorch::runtime::testing::TensorFactory;
20 
/**
 * Test helper that invokes the `sdpa_with_kv_cache_out` kernel with a fresh
 * KernelRuntimeContext and returns the tensor the kernel hands back.
 *
 * Every argument is forwarded unchanged. `key_cache` and `value_cache` are
 * taken by non-const reference because the kernel writes the new key/value
 * entries into them; the tests in this file rely on that state carrying over
 * between calls that reuse the same cache tensors.
 *
 * The kernel reports argument/shape errors by marking `context` as failed
 * rather than by throwing (presumably via ET_KERNEL_CHECK inside op_sdpa —
 * confirm), so this helper checks that no failure was recorded. Without the
 * check, a rejected call would only show up indirectly as a numerical mismatch
 * against the expected tensor.
 *
 * @return The tensor returned by the kernel (normally `out`).
 */
exec_aten::Tensor op_sdpa_with_kv_cache(
    const exec_aten::Tensor& query,
    const exec_aten::Tensor& key,
    const exec_aten::Tensor& value,
    exec_aten::Tensor& key_cache,
    exec_aten::Tensor& value_cache,
    const int64_t start_pos,
    const int64_t seq_len,
    const exec_aten::optional<exec_aten::Tensor>& attn_mask,
    double dropout_p,
    bool is_causal,
    exec_aten::optional<double> scale,
    exec_aten::Tensor& out) {
  executorch::runtime::KernelRuntimeContext context{};
  exec_aten::Tensor result = torch::executor::native::sdpa_with_kv_cache_out(
      context,
      query,
      key,
      value,
      key_cache,
      value_cache,
      start_pos,
      seq_len,
      attn_mask,
      dropout_p,
      is_causal,
      scale,
      out);
  // Surface kernel-side failures directly instead of as a value mismatch.
  EXPECT_EQ(context.failure_state(), executorch::runtime::Error::Ok);
  return result;
}
50 
/*
SDPA with a KV cache is equivalent to the following PyTorch reference code:
# q = (batch size, q seq len, num heads, head dim)
# k, v = (batch size, kv seq len, num heads, head dim)
# k cache, v cache = (batch size, max seq length, num heads, head dim); a model
#   keeps one such pair per layer (the tests below use one pair per "layer id")
# attn_mask = [max seq length, max seq length]

def sdpa_with_cache(q, k, v, k_cache, v_cache, start_pos, attn_mask):
    seq_len = k.size(1)
    # Only the first start_pos + seq_len cache positions are attended to.
    attn_mask = attn_mask[start_pos].view((1, -1))[:, : start_pos + seq_len]
    q = q.transpose(1, 2)
    k_cache[:, start_pos : start_pos + seq_len] = k
    v_cache[:, start_pos : start_pos + seq_len] = v
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    sliced_k_cache = k_cache[:, : start_pos + seq_len]
    sliced_v_cache = v_cache[:, : start_pos + seq_len]
    sliced_k_cache = sliced_k_cache.transpose(1, 2)
    sliced_v_cache = sliced_v_cache.transpose(1, 2)
    out = F.scaled_dot_product_attention(
        q, sliced_k_cache, sliced_v_cache, attn_mask=attn_mask)
    out = out.transpose(1, 2)
*/
73 
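/*
Missing tests:
1. Batch sizes other than 1
2. Different batch sizes (item 1) combined with an attention_mask
3. Boolean attention_mask (the existing mask test only uses a float mask)
4. An explicit `scale` argument (all tests pass an empty optional)
5. Other dtypes: fp16, bf16, double (or verify that the op rejects them)
*/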
TEST(OpScaledDotProductAttentionTest,BasicTest)82 TEST(OpScaledDotProductAttentionTest, BasicTest) {
83   TensorFactory<exec_aten::ScalarType::Float> tfFloat;
84 
85   exec_aten::Tensor query = tfFloat.make(
86       {1, 1, 4, 4},
87       {0.8823,
88        0.9150,
89        0.3829,
90        0.9593,
91        0.3904,
92        0.6009,
93        0.2566,
94        0.7936,
95        0.9408,
96        0.1332,
97        0.9346,
98        0.5936,
99        0.8694,
100        0.5677,
101        0.7411,
102        0.4294});
103   exec_aten::Tensor key = tfFloat.make(
104       {1, 1, 4, 4},
105       {0.8854,
106        0.5739,
107        0.2666,
108        0.6274,
109        0.2696,
110        0.4414,
111        0.2969,
112        0.8317,
113        0.1053,
114        0.2695,
115        0.3588,
116        0.1994,
117        0.5472,
118        0.0062,
119        0.9516,
120        0.0753});
121   exec_aten::Tensor value = tfFloat.make(
122       {1, 1, 4, 4},
123       {0.8860,
124        0.5832,
125        0.3376,
126        0.8090,
127        0.5779,
128        0.9040,
129        0.5547,
130        0.3423,
131        0.6343,
132        0.3644,
133        0.7104,
134        0.9464,
135        0.7890,
136        0.2814,
137        0.7886,
138        0.5895});
139   exec_aten::Tensor key_cache_0 = tfFloat.zeros({1, 5, 4, 4});
140   exec_aten::Tensor value_cache_0 = tfFloat.zeros({1, 5, 4, 4});
141   exec_aten::Tensor key_cache_1 = tfFloat.zeros({1, 5, 4, 4});
142   exec_aten::Tensor value_cache_1 = tfFloat.zeros({1, 5, 4, 4});
143   exec_aten::Tensor key_cache_2 = tfFloat.zeros({1, 5, 4, 4});
144   exec_aten::Tensor value_cache_2 = tfFloat.zeros({1, 5, 4, 4});
145   exec_aten::optional<exec_aten::Tensor> attn_mask;
146   double dropout_p = 0;
147   bool is_causal = false;
148   exec_aten::optional<double> scale;
149 
150   // start pos: 0 layer id 0
151   exec_aten::Tensor ret_expected_0 = tfFloat.make(
152       {1, 1, 4, 4},
153       {0.8860,
154        0.5832,
155        0.3376,
156        0.8090,
157        0.5779,
158        0.9040,
159        0.5547,
160        0.3423,
161        0.6343,
162        0.3644,
163        0.7104,
164        0.9464,
165        0.7890,
166        0.2814,
167        0.7886,
168        0.5895});
169 
170   std::vector<int32_t> out_size = {1, 1, 4, 4};
171   exec_aten::Tensor out = tfFloat.zeros(out_size);
172   exec_aten::Tensor ret = op_sdpa_with_kv_cache(
173       query,
174       key,
175       value,
176       key_cache_0,
177       value_cache_0,
178       0,
179       1,
180       attn_mask,
181       dropout_p,
182       is_causal,
183       scale,
184       out);
185   EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_0, 1e-4, 1e-4);
186 
187   // start pos: 0 layer id 2
188   exec_aten::Tensor ret_expected_1 = tfFloat.make(
189       {1, 1, 4, 4},
190       {0.8860,
191        0.5832,
192        0.3376,
193        0.8090,
194        0.5779,
195        0.9040,
196        0.5547,
197        0.3423,
198        0.6343,
199        0.3644,
200        0.7104,
201        0.9464,
202        0.7890,
203        0.2814,
204        0.7886,
205        0.5895});
206   out = tfFloat.zeros(out_size);
207   ret = op_sdpa_with_kv_cache(
208       query,
209       key,
210       value,
211       key_cache_2,
212       value_cache_2,
213       0,
214       1,
215       attn_mask,
216       dropout_p,
217       is_causal,
218       scale,
219       out);
220   EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_1, 1e-4, 1e-4);
221 
222   // start pos: 1 layer id 0
223   exec_aten::Tensor ret_expected_2 = tfFloat.make(
224       {1, 1, 4, 4},
225       {0.8860,
226        0.5832,
227        0.3376,
228        0.8090,
229        0.5779,
230        0.9040,
231        0.5547,
232        0.3423,
233        0.6343,
234        0.3644,
235        0.7104,
236        0.9464,
237        0.7890,
238        0.2814,
239        0.7886,
240        0.5895});
241   out = tfFloat.zeros(out_size);
242   ret = op_sdpa_with_kv_cache(
243       query,
244       key,
245       value,
246       key_cache_0,
247       value_cache_0,
248       1,
249       1,
250       attn_mask,
251       dropout_p,
252       is_causal,
253       scale,
254       out);
255   EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_2, 1e-4, 1e-4);
256 
257   // start pos: 1 layer id 1
258   exec_aten::Tensor ret_expected_3 = tfFloat.make(
259       {1, 1, 4, 4},
260       {0.6486,
261        0.4270,
262        0.2472,
263        0.5922,
264        0.3669,
265        0.5740,
266        0.3522,
267        0.2173,
268        0.3635,
269        0.2088,
270        0.4071,
271        0.5423,
272        0.5110,
273        0.1822,
274        0.5107,
275        0.3817});
276   out = tfFloat.zeros(out_size);
277   ret = op_sdpa_with_kv_cache(
278       query,
279       key,
280       value,
281       key_cache_1,
282       value_cache_1,
283       1,
284       1,
285       attn_mask,
286       dropout_p,
287       is_causal,
288       scale,
289       out);
290   EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_3, 1e-4, 1e-4);
291 
292   // start pos: 2 layer id 1
293   exec_aten::Tensor ret_expected_4 = tfFloat.make(
294       {1, 1, 4, 4},
295       {0.7490,
296        0.4930,
297        0.2854,
298        0.6838,
299        0.4489,
300        0.7021,
301        0.4308,
302        0.2659,
303        0.4622,
304        0.2655,
305        0.5176,
306        0.6895,
307        0.6202,
308        0.2212,
309        0.6199,
310        0.4634});
311   out = tfFloat.zeros(out_size);
312   ret = op_sdpa_with_kv_cache(
313       query,
314       key,
315       value,
316       key_cache_1,
317       value_cache_1,
318       2,
319       1,
320       attn_mask,
321       dropout_p,
322       is_causal,
323       scale,
324       out);
325   EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_4, 1e-4, 1e-4);
326 
327   // start pos: 2 layer id 2
328   exec_aten::Tensor ret_expected_5 = tfFloat.make(
329       {1, 1, 4, 4},
330       {0.7490,
331        0.4930,
332        0.2854,
333        0.6838,
334        0.4489,
335        0.7021,
336        0.4308,
337        0.2659,
338        0.4622,
339        0.2655,
340        0.5176,
341        0.6895,
342        0.6202,
343        0.2212,
344        0.6199,
345        0.4634});
346   out = tfFloat.zeros(out_size);
347   ret = op_sdpa_with_kv_cache(
348       query,
349       key,
350       value,
351       key_cache_2,
352       value_cache_2,
353       2,
354       1,
355       attn_mask,
356       dropout_p,
357       is_causal,
358       scale,
359       out);
360   EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_5, 1e-4, 1e-4);
361 }
362 
TEST(OpScaledDotProductAttentionTest,LargerTest)363 TEST(OpScaledDotProductAttentionTest, LargerTest) {
364   TensorFactory<exec_aten::ScalarType::Float> tfFloat;
365 
366   exec_aten::Tensor query = tfFloat.make(
367       {1, 1, 7, 4}, {0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009, 0.2566,
368                      0.7936, 0.9408, 0.1332, 0.9346, 0.5936, 0.8694, 0.5677,
369                      0.7411, 0.4294, 0.8854, 0.5739, 0.2666, 0.6274, 0.2696,
370                      0.4414, 0.2969, 0.8317, 0.1053, 0.2695, 0.3588, 0.1994});
371   exec_aten::Tensor key = tfFloat.make(
372       {1, 1, 7, 4}, {0.5472, 0.0062, 0.9516, 0.0753, 0.8860, 0.5832, 0.3376,
373                      0.8090, 0.5779, 0.9040, 0.5547, 0.3423, 0.6343, 0.3644,
374                      0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895, 0.7539,
375                      0.1952, 0.0050, 0.3068, 0.1165, 0.9103, 0.6440, 0.7071});
376   exec_aten::Tensor value = tfFloat.make(
377       {1, 1, 7, 4}, {0.6581, 0.4913, 0.8913, 0.1447, 0.5315, 0.1587, 0.6542,
378                      0.3278, 0.6532, 0.3958, 0.9147, 0.2036, 0.2018, 0.2018,
379                      0.9497, 0.6666, 0.9811, 0.0874, 0.0041, 0.1088, 0.1637,
380                      0.7025, 0.6790, 0.9155, 0.2418, 0.1591, 0.7653, 0.2979});
381   exec_aten::Tensor key_cache_0 = tfFloat.zeros({1, 8, 7, 4});
382   exec_aten::Tensor value_cache_0 = tfFloat.zeros({1, 8, 7, 4});
383   exec_aten::Tensor key_cache_1 = tfFloat.zeros({1, 8, 7, 4});
384   exec_aten::Tensor value_cache_1 = tfFloat.zeros({1, 8, 7, 4});
385   exec_aten::Tensor key_cache_2 = tfFloat.zeros({1, 8, 7, 4});
386   exec_aten::Tensor value_cache_2 = tfFloat.zeros({1, 8, 7, 4});
387   exec_aten::optional<exec_aten::Tensor> attn_mask;
388   double dropout_p = 0;
389   bool is_causal = false;
390   exec_aten::optional<double> scale;
391 
392   // start pos: 0 layer id 0
393   exec_aten::Tensor ret_expected_0 = tfFloat.make(
394       {1, 1, 7, 4}, {0.6581, 0.4913, 0.8913, 0.1447, 0.5315, 0.1587, 0.6542,
395                      0.3278, 0.6532, 0.3958, 0.9147, 0.2036, 0.2018, 0.2018,
396                      0.9497, 0.6666, 0.9811, 0.0874, 0.0041, 0.1088, 0.1637,
397                      0.7025, 0.6790, 0.9155, 0.2418, 0.1591, 0.7653, 0.2979});
398 
399   std::vector<int32_t> out_size = {1, 1, 7, 4};
400   exec_aten::Tensor out = tfFloat.zeros(out_size);
401   exec_aten::Tensor ret = op_sdpa_with_kv_cache(
402       query,
403       key,
404       value,
405       key_cache_0,
406       value_cache_0,
407       0,
408       1,
409       attn_mask,
410       dropout_p,
411       is_causal,
412       scale,
413       out);
414   EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_0, 1e-4, 1e-4);
415 
416   // start pos: 0 layer id 2
417   exec_aten::Tensor ret_expected_1 = tfFloat.make(
418       {1, 1, 7, 4}, {0.6581, 0.4913, 0.8913, 0.1447, 0.5315, 0.1587, 0.6542,
419                      0.3278, 0.6532, 0.3958, 0.9147, 0.2036, 0.2018, 0.2018,
420                      0.9497, 0.6666, 0.9811, 0.0874, 0.0041, 0.1088, 0.1637,
421                      0.7025, 0.6790, 0.9155, 0.2418, 0.1591, 0.7653, 0.2979});
422   out = tfFloat.zeros(out_size);
423   ret = op_sdpa_with_kv_cache(
424       query,
425       key,
426       value,
427       key_cache_2,
428       value_cache_2,
429       0,
430       1,
431       attn_mask,
432       dropout_p,
433       is_causal,
434       scale,
435       out);
436   EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_1, 1e-4, 1e-4);
437 
438   // start pos: 1 layer id 0
439   exec_aten::Tensor ret_expected_2 = tfFloat.make(
440       {1, 1, 7, 4}, {0.6581, 0.4913, 0.8913, 0.1447, 0.5315, 0.1587, 0.6542,
441                      0.3278, 0.6532, 0.3958, 0.9147, 0.2036, 0.2018, 0.2018,
442                      0.9497, 0.6666, 0.9811, 0.0874, 0.0041, 0.1088, 0.1637,
443                      0.7025, 0.6790, 0.9155, 0.2418, 0.1591, 0.7653, 0.2979});
444   out = tfFloat.zeros(out_size);
445   ret = op_sdpa_with_kv_cache(
446       query,
447       key,
448       value,
449       key_cache_0,
450       value_cache_0,
451       1,
452       1,
453       attn_mask,
454       dropout_p,
455       is_causal,
456       scale,
457       out);
458   EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_2, 1e-4, 1e-4);
459 
460   // start pos: 1 layer id 1
461   exec_aten::Tensor ret_expected_3 = tfFloat.make(
462       {1, 1, 7, 4}, {0.4038, 0.3015, 0.5469, 0.0888, 0.3566, 0.1065, 0.4389,
463                      0.2199, 0.4354, 0.2639, 0.6097, 0.1358, 0.1412, 0.1412,
464                      0.6645, 0.4664, 0.6599, 0.0588, 0.0027, 0.0732, 0.0929,
465                      0.3989, 0.3856, 0.5198, 0.1398, 0.0920, 0.4424, 0.1722});
466   out = tfFloat.zeros(out_size);
467   ret = op_sdpa_with_kv_cache(
468       query,
469       key,
470       value,
471       key_cache_1,
472       value_cache_1,
473       1,
474       1,
475       attn_mask,
476       dropout_p,
477       is_causal,
478       scale,
479       out);
480   EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_3, 1e-4, 1e-4);
481 
482   // start pos: 2 layer id 1
483   exec_aten::Tensor ret_expected_4 = tfFloat.make(
484       {1, 1, 7, 4}, {0.5005, 0.3737, 0.6779, 0.1101, 0.4268, 0.1275, 0.5254,
485                      0.2633, 0.5225, 0.3166, 0.7317, 0.1629, 0.1661, 0.1661,
486                      0.7819, 0.5488, 0.7891, 0.0703, 0.0033, 0.0875, 0.1185,
487                      0.5089, 0.4919, 0.6631, 0.1771, 0.1166, 0.5607, 0.2182});
488   out = tfFloat.zeros(out_size);
489   ret = op_sdpa_with_kv_cache(
490       query,
491       key,
492       value,
493       key_cache_1,
494       value_cache_1,
495       2,
496       1,
497       attn_mask,
498       dropout_p,
499       is_causal,
500       scale,
501       out);
502   EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_4, 1e-4, 1e-4);
503 
504   // start pos: 2 layer id 2
505   exec_aten::Tensor ret_expected_5 = tfFloat.make(
506       {1, 1, 7, 4}, {0.5005, 0.3737, 0.6779, 0.1101, 0.4268, 0.1275, 0.5254,
507                      0.2633, 0.5225, 0.3166, 0.7317, 0.1629, 0.1661, 0.1661,
508                      0.7819, 0.5488, 0.7891, 0.0703, 0.0033, 0.0875, 0.1185,
509                      0.5089, 0.4919, 0.6631, 0.1771, 0.1166, 0.5607, 0.2182});
510   out = tfFloat.zeros(out_size);
511   ret = op_sdpa_with_kv_cache(
512       query,
513       key,
514       value,
515       key_cache_2,
516       value_cache_2,
517       2,
518       1,
519       attn_mask,
520       dropout_p,
521       is_causal,
522       scale,
523       out);
524   EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_5, 1e-4, 1e-4);
525 }
526 
TEST(OpScaledDotProductAttentionTest,BasicTestWithAttnMask)527 TEST(OpScaledDotProductAttentionTest, BasicTestWithAttnMask) {
528   TensorFactory<exec_aten::ScalarType::Float> tfFloat;
529 
530   exec_aten::Tensor query = tfFloat.make(
531       {1, 1, 4, 4},
532       {0.8823,
533        0.9150,
534        0.3829,
535        0.9593,
536        0.3904,
537        0.6009,
538        0.2566,
539        0.7936,
540        0.9408,
541        0.1332,
542        0.9346,
543        0.5936,
544        0.8694,
545        0.5677,
546        0.7411,
547        0.4294});
548   exec_aten::Tensor key = tfFloat.make(
549       {1, 1, 4, 4},
550       {0.8854,
551        0.5739,
552        0.2666,
553        0.6274,
554        0.2696,
555        0.4414,
556        0.2969,
557        0.8317,
558        0.1053,
559        0.2695,
560        0.3588,
561        0.1994,
562        0.5472,
563        0.0062,
564        0.9516,
565        0.0753});
566   exec_aten::Tensor value = tfFloat.make(
567       {1, 1, 4, 4},
568       {0.8860,
569        0.5832,
570        0.3376,
571        0.8090,
572        0.5779,
573        0.9040,
574        0.5547,
575        0.3423,
576        0.6343,
577        0.3644,
578        0.7104,
579        0.9464,
580        0.7890,
581        0.2814,
582        0.7886,
583        0.5895});
584   exec_aten::Tensor attn_mask = tfFloat.make({1, 1}, {0});
585   exec_aten::Tensor key_cache_0 = tfFloat.zeros({1, 5, 4, 4});
586   exec_aten::Tensor value_cache_0 = tfFloat.zeros({1, 5, 4, 4});
587   exec_aten::Tensor key_cache_1 = tfFloat.zeros({1, 5, 4, 4});
588   exec_aten::Tensor value_cache_1 = tfFloat.zeros({1, 5, 4, 4});
589   exec_aten::Tensor key_cache_2 = tfFloat.zeros({1, 5, 4, 4});
590   exec_aten::Tensor value_cache_2 = tfFloat.zeros({1, 5, 4, 4});
591   double dropout_p = 0;
592   bool is_causal = false;
593   exec_aten::optional<double> scale;
594 
595   // start pos: 0 layer id 0
596   exec_aten::Tensor ret_expected_0 = tfFloat.make(
597       {1, 1, 4, 4},
598       {0.8860,
599        0.5832,
600        0.3376,
601        0.8090,
602        0.5779,
603        0.9040,
604        0.5547,
605        0.3423,
606        0.6343,
607        0.3644,
608        0.7104,
609        0.9464,
610        0.7890,
611        0.2814,
612        0.7886,
613        0.5895});
614 
615   std::vector<int32_t> out_size = {1, 1, 4, 4};
616   exec_aten::Tensor out = tfFloat.zeros(out_size);
617   exec_aten::Tensor ret = op_sdpa_with_kv_cache(
618       query,
619       key,
620       value,
621       key_cache_0,
622       value_cache_0,
623       0,
624       1,
625       attn_mask,
626       dropout_p,
627       is_causal,
628       scale,
629       out);
630   EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_0, 1e-4, 1e-4);
631 
632   // start pos: 0 layer id 2
633   exec_aten::Tensor ret_expected_1 = tfFloat.make(
634       {1, 1, 4, 4},
635       {0.8860,
636        0.5832,
637        0.3376,
638        0.8090,
639        0.5779,
640        0.9040,
641        0.5547,
642        0.3423,
643        0.6343,
644        0.3644,
645        0.7104,
646        0.9464,
647        0.7890,
648        0.2814,
649        0.7886,
650        0.5895});
651   out = tfFloat.zeros(out_size);
652   ret = op_sdpa_with_kv_cache(
653       query,
654       key,
655       value,
656       key_cache_2,
657       value_cache_2,
658       0,
659       1,
660       attn_mask,
661       dropout_p,
662       is_causal,
663       scale,
664       out);
665   EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_1, 1e-4, 1e-4);
666 
667   attn_mask = tfFloat.make({1, 2}, {0, 0});
668   // start pos: 1 layer id 0
669   exec_aten::Tensor ret_expected_2 = tfFloat.make(
670       {1, 1, 4, 4},
671       {0.8860,
672        0.5832,
673        0.3376,
674        0.8090,
675        0.5779,
676        0.9040,
677        0.5547,
678        0.3423,
679        0.6343,
680        0.3644,
681        0.7104,
682        0.9464,
683        0.7890,
684        0.2814,
685        0.7886,
686        0.5895});
687   out = tfFloat.zeros(out_size);
688   ret = op_sdpa_with_kv_cache(
689       query,
690       key,
691       value,
692       key_cache_0,
693       value_cache_0,
694       1,
695       1,
696       attn_mask,
697       dropout_p,
698       is_causal,
699       scale,
700       out);
701   EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_2, 1e-4, 1e-4);
702 
703   // start pos: 1 layer id 1
704   exec_aten::Tensor ret_expected_3 = tfFloat.make(
705       {1, 1, 4, 4},
706       {0.6486,
707        0.4270,
708        0.2472,
709        0.5922,
710        0.3669,
711        0.5740,
712        0.3522,
713        0.2173,
714        0.3635,
715        0.2088,
716        0.4071,
717        0.5423,
718        0.5110,
719        0.1822,
720        0.5107,
721        0.3817});
722   out = tfFloat.zeros(out_size);
723   ret = op_sdpa_with_kv_cache(
724       query,
725       key,
726       value,
727       key_cache_1,
728       value_cache_1,
729       1,
730       1,
731       attn_mask,
732       dropout_p,
733       is_causal,
734       scale,
735       out);
736   EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_3, 1e-4, 1e-4);
737 
738   attn_mask = tfFloat.make({1, 3}, {0, 0, 0});
739   // start pos: 2 layer id 1
740   exec_aten::Tensor ret_expected_4 = tfFloat.make(
741       {1, 1, 4, 4},
742       {0.7490,
743        0.4930,
744        0.2854,
745        0.6838,
746        0.4489,
747        0.7021,
748        0.4308,
749        0.2659,
750        0.4622,
751        0.2655,
752        0.5176,
753        0.6895,
754        0.6202,
755        0.2212,
756        0.6199,
757        0.4634});
758   out = tfFloat.zeros(out_size);
759   ret = op_sdpa_with_kv_cache(
760       query,
761       key,
762       value,
763       key_cache_1,
764       value_cache_1,
765       2,
766       1,
767       attn_mask,
768       dropout_p,
769       is_causal,
770       scale,
771       out);
772   EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_4, 1e-4, 1e-4);
773 
774   // start pos: 2 layer id 2
775   exec_aten::Tensor ret_expected_5 = tfFloat.make(
776       {1, 1, 4, 4},
777       {0.7490,
778        0.4930,
779        0.2854,
780        0.6838,
781        0.4489,
782        0.7021,
783        0.4308,
784        0.2659,
785        0.4622,
786        0.2655,
787        0.5176,
788        0.6895,
789        0.6202,
790        0.2212,
791        0.6199,
792        0.4634});
793   out = tfFloat.zeros(out_size);
794   ret = op_sdpa_with_kv_cache(
795       query,
796       key,
797       value,
798       key_cache_2,
799       value_cache_2,
800       2,
801       1,
802       attn_mask,
803       dropout_p,
804       is_causal,
805       scale,
806       out);
807   EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_5, 1e-4, 1e-4);
808 }
809 
TEST(OpScaledDotProductAttentionTest,SequenceTest)810 TEST(OpScaledDotProductAttentionTest, SequenceTest) {
811   TensorFactory<exec_aten::ScalarType::Float> tfFloat;
812 
813   exec_aten::Tensor query = tfFloat.make(
814       {1, 1, 8, 4},
815       {0.1261, 0.5031, 0.1117, 0.3905, 0.3625, 0.9328, 0.6549, 0.4128,
816        0.5845, 0.3557, 0.6965, 0.6978, 0.6343, 0.3051, 0.9266, 0.4278,
817        0.3053, 0.8132, 0.9075, 0.9976, 0.6481, 0.3296, 0.7539, 0.9290,
818        0.0096, 0.4381, 0.1590, 0.5932, 0.7068, 0.3967, 0.4582, 0.7251});
819   exec_aten::Tensor key = tfFloat.make(
820       {1, 1, 8, 4},
821       {0.4160, 0.0801, 0.9001, 0.2483, 0.4451, 0.5472, 0.4700, 0.0297,
822        0.7294, 0.2729, 0.2407, 0.6195, 0.2391, 0.2689, 0.3315, 0.3122,
823        0.2912, 0.3652, 0.6299, 0.0954, 0.1974, 0.5073, 0.5695, 0.7761,
824        0.1488, 0.6596, 0.7842, 0.7776, 0.0343, 0.3092, 0.0702, 0.1836});
825   exec_aten::Tensor value = tfFloat.make(
826       {1, 1, 8, 4},
827       {0.7785, 0.4253, 0.7124, 0.2065, 0.5760, 0.1976, 0.7499, 0.2813,
828        0.3746, 0.0662, 0.5017, 0.9747, 0.7427, 0.2332, 0.5067, 0.4452,
829        0.0975, 0.8920, 0.5081, 0.6053, 0.2981, 0.2660, 0.5824, 0.6849,
830        0.6121, 0.2590, 0.9854, 0.4264, 0.1938, 0.2661, 0.9922, 0.5000});
831 
832   exec_aten::Tensor key_cache_0 = tfFloat.zeros({1, 5, 8, 4});
833   exec_aten::Tensor value_cache_0 = tfFloat.zeros({1, 5, 8, 4});
834 
835   exec_aten::optional<exec_aten::Tensor> attn_mask;
836   double dropout_p = 0;
837   bool is_causal = false;
838   exec_aten::optional<double> scale;
839 
840   // start pos: 0 layer id 0
841   exec_aten::Tensor ret_expected_0 = tfFloat.make(
842       {1, 1, 8, 4},
843       {0.7785, 0.4253, 0.7124, 0.2065, 0.5760, 0.1976, 0.7499, 0.2813,
844        0.3746, 0.0662, 0.5017, 0.9747, 0.7427, 0.2332, 0.5067, 0.4452,
845        0.0975, 0.8920, 0.5081, 0.6053, 0.2981, 0.2660, 0.5824, 0.6849,
846        0.6121, 0.2590, 0.9854, 0.4264, 0.1938, 0.2661, 0.9922, 0.5000});
847 
848   std::vector<int32_t> out_size = {1, 1, 8, 4};
849   exec_aten::Tensor out = tfFloat.zeros(out_size);
850   exec_aten::Tensor ret = op_sdpa_with_kv_cache(
851       query,
852       key,
853       value,
854       key_cache_0,
855       value_cache_0,
856       0,
857       1,
858       attn_mask,
859       dropout_p,
860       is_causal,
861       scale,
862       out);
863   EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_0, 1e-4, 1e-4);
864 
865   // start pos: 1 layer id 0
866   query = tfFloat.make(
867       {1, 1, 8, 4},
868       {0.4321, 0.2919, 0.3689, 0.0789, 0.1027, 0.7926, 0.9277, 0.9772,
869        0.1390, 0.7704, 0.1905, 0.7983, 0.8608, 0.8869, 0.8600, 0.8128,
870        0.5097, 0.7297, 0.3211, 0.7177, 0.3393, 0.4916, 0.0648, 0.3693,
871        0.2371, 0.3313, 0.1807, 0.0503, 0.5326, 0.8245, 0.9554, 0.7918});
872   key = tfFloat.make(
873       {1, 1, 8, 4},
874       {0.2408, 0.0055, 0.6897, 0.7802, 0.0707, 0.6793, 0.9227, 0.5303,
875        0.1988, 0.9099, 0.7135, 0.8311, 0.1619, 0.7910, 0.1585, 0.9947,
876        0.2882, 0.8013, 0.6001, 0.6325, 0.4233, 0.7054, 0.2916, 0.0287,
877        0.3079, 0.8918, 0.3684, 0.6572, 0.3151, 0.8751, 0.7992, 0.6765});
878   value = tfFloat.make(
879       {1, 1, 8, 4},
880       {0.2444, 0.0914, 0.5188, 0.2067, 0.9111, 0.0195, 0.7234, 0.9985,
881        0.7504, 0.6705, 0.0189, 0.9809, 0.4145, 0.0328, 0.9936, 0.2965,
882        0.4646, 0.9576, 0.1534, 0.1463, 0.5813, 0.4331, 0.6152, 0.0806,
883        0.5150, 0.2776, 0.2542, 0.0422, 0.7651, 0.5963, 0.0773, 0.8968});
884   exec_aten::Tensor ret_expected_1 = tfFloat.make(
885       {1, 1, 8, 4},
886       {0.5203, 0.2639, 0.6188, 0.2066, 0.7836, 0.0872, 0.7335, 0.7256,
887        0.5940, 0.4189, 0.2199, 0.9784, 0.5461, 0.1132, 0.7983, 0.3561,
888        0.3125, 0.9305, 0.3003, 0.3364, 0.4355, 0.3471, 0.5983, 0.3918,
889        0.5631, 0.2684, 0.6168, 0.2327, 0.5942, 0.4976, 0.3510, 0.7781});
890   out = tfFloat.zeros(out_size);
891   ret = op_sdpa_with_kv_cache(
892       query,
893       key,
894       value,
895       key_cache_0,
896       value_cache_0,
897       1,
898       1,
899       attn_mask,
900       dropout_p,
901       is_causal,
902       scale,
903       out);
904   EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_1, 1e-4, 1e-4);
905 
906   // start pos: 2 layer id 0
907   query = tfFloat.make(
908       {1, 1, 8, 4},
909       {0.6508, 0.5928, 0.2064, 0.5754, 0.9818, 0.8429, 0.1106, 0.9564,
910        0.5388, 0.7405, 0.8883, 0.9263, 0.1102, 0.9378, 0.1604, 0.5375,
911        0.1506, 0.3904, 0.4773, 0.4402, 0.4210, 0.5394, 0.9932, 0.7905,
912        0.7797, 0.7001, 0.8871, 0.4769, 0.5398, 0.6029, 0.0639, 0.0972});
913   key = tfFloat.make(
914       {1, 1, 8, 4},
915       {0.5613, 0.3044, 0.4908, 0.3853, 0.5778, 0.8253, 0.3342, 0.9004,
916        0.8948, 0.1163, 0.1139, 0.0955, 0.2260, 0.3054, 0.4624, 0.3784,
917        0.2474, 0.3412, 0.3191, 0.9905, 0.3147, 0.1420, 0.7078, 0.4711,
918        0.8828, 0.8124, 0.9594, 0.1338, 0.8214, 0.9196, 0.2531, 0.9596});
919   value = tfFloat.make(
920       {1, 1, 8, 4},
921       {0.8748, 0.5055, 0.7411, 0.3252, 0.0639, 0.6264, 0.6491, 0.1732,
922        0.7425, 0.0729, 0.9303, 0.9842, 0.6361, 0.1863, 0.7433, 0.5852,
923        0.6360, 0.6643, 0.8807, 0.2851, 0.3875, 0.6364, 0.5545, 0.9032,
924        0.2374, 0.4818, 0.5934, 0.3672, 0.8409, 0.5547, 0.0379, 0.4458});
925   exec_aten::Tensor ret_expected_2 = tfFloat.make(
926       {1, 1, 8, 4},
927       {0.6350, 0.3426, 0.6582, 0.2484, 0.4391, 0.3419, 0.6962, 0.4399,
928        0.6321, 0.3475, 0.3754, 0.9798, 0.5721, 0.1344, 0.7829, 0.4233,
929        0.4122, 0.8394, 0.5040, 0.3304, 0.4066, 0.4378, 0.5820, 0.5922,
930        0.4333, 0.3541, 0.6168, 0.2918, 0.6486, 0.4949, 0.2965, 0.6151});
931   out = tfFloat.zeros(out_size);
932   ret = op_sdpa_with_kv_cache(
933       query,
934       key,
935       value,
936       key_cache_0,
937       value_cache_0,
938       2,
939       1,
940       attn_mask,
941       dropout_p,
942       is_causal,
943       scale,
944       out);
945   EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_2, 1e-4, 1e-4);
946 
947   // start pos: 3 layer id 0
948   query = tfFloat.make(
949       {1, 1, 8, 4},
950       {0.2732, 0.5486, 0.4419, 0.0040, 0.4089, 0.4521, 0.3526, 0.9594,
951        0.3909, 0.8212, 0.6239, 0.0779, 0.6175, 0.9144, 0.1729, 0.1768,
952        0.9894, 0.9018, 0.2211, 0.8009, 0.4360, 0.0070, 0.5376, 0.6615,
953        0.3500, 0.6739, 0.0724, 0.8465, 0.9263, 0.7757, 0.5847, 0.6647});
954   key = tfFloat.make(
955       {1, 1, 8, 4},
956       {0.1382, 0.3751, 0.4523, 0.2218, 0.1307, 0.8363, 0.8393, 0.0459,
957        0.6591, 0.7034, 0.9750, 0.7893, 0.9597, 0.3363, 0.8502, 0.9067,
958        0.0278, 0.0986, 0.6012, 0.7730, 0.2516, 0.5551, 0.4993, 0.6266,
959        0.2313, 0.7820, 0.8325, 0.1531, 0.5048, 0.5014, 0.6606, 0.9658});
960   value = tfFloat.make(
961       {1, 1, 8, 4},
962       {0.6466, 0.3864, 0.9491, 0.3097, 0.3548, 0.5341, 0.1192, 0.5544,
963        0.1608, 0.5514, 0.5479, 0.5692, 0.0784, 0.0251, 0.7301, 0.9288,
964        0.0563, 0.6852, 0.1319, 0.5313, 0.9652, 0.8793, 0.1344, 0.8093,
965        0.7612, 0.4992, 0.9844, 0.3014, 0.3836, 0.2473, 0.5719, 0.6324});
966   exec_aten::Tensor ret_expected_3 = tfFloat.make(
967       {1, 1, 8, 4},
968       {0.6441, 0.3571, 0.7319, 0.2624, 0.4506, 0.3619, 0.5749, 0.4930,
969        0.4860, 0.3924, 0.4596, 0.8517, 0.4312, 0.1060, 0.7579, 0.5796,
970        0.3507, 0.8063, 0.4223, 0.3597, 0.5522, 0.5558, 0.4665, 0.6486,
971        0.5263, 0.3701, 0.6880, 0.2790, 0.6116, 0.4449, 0.3184, 0.6258});
972   out = tfFloat.zeros(out_size);
973   ret = op_sdpa_with_kv_cache(
974       query,
975       key,
976       value,
977       key_cache_0,
978       value_cache_0,
979       3,
980       1,
981       attn_mask,
982       dropout_p,
983       is_causal,
984       scale,
985       out);
986   EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_3, 1e-4, 1e-4);
987 }
988