# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import unittest

import torch
import torch.nn.functional as F

from .sdpa_with_kv_cache import custom_ops_lib  # noqa
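# These tests check the custom torch.ops.llama.sdpa_with_kv_cache operator
# against the eager-mode reference implementation defined below.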


def _sdpa_with_kv_cache_ref(q, k, v, k_cache, v_cache, attn_mask, start_pos, seq_len):
    """Reference implementation: write k and v into the caches at start_pos,
    then run eager scaled_dot_product_attention over the populated prefix of
    the caches, repeating kv heads when the query has more heads (MQA/GQA)."""
    q = q.transpose(1, 2)
    k_cache[:, start_pos : start_pos + seq_len, :, :] = k
    v_cache[:, start_pos : start_pos + seq_len, :, :] = v
    sliced_k_cache = k_cache[:, : start_pos + seq_len, :, :]
    sliced_v_cache = v_cache[:, : start_pos + seq_len, :, :]
    sliced_k_cache = sliced_k_cache.transpose(1, 2)
    sliced_v_cache = sliced_v_cache.transpose(1, 2)

    num_heads_q = q.size(1)
    num_heads_kv = sliced_k_cache.size(1)
    if num_heads_q != num_heads_kv:
        assert (
            num_heads_q % num_heads_kv == 0
        ), f"{num_heads_q} not divisible by {num_heads_kv}"
    n_reps = num_heads_q // num_heads_kv
    if n_reps > 1:
        sliced_k_cache = sliced_k_cache.repeat_interleave(n_reps, dim=1)
        sliced_v_cache = sliced_v_cache.repeat_interleave(n_reps, dim=1)
    out = F.scaled_dot_product_attention(
        q, sliced_k_cache, sliced_v_cache, attn_mask=attn_mask
    )
    out = out.transpose(1, 2)
    return out


class SDPATest(unittest.TestCase):

    def setUp(self):
        torch.manual_seed(42)
        self.k_cache = torch.zeros((1, 10, 8, 4))
        self.v_cache = torch.zeros((1, 10, 8, 4))
        self.mask = torch.full(
            (10, 10),
            float("-inf"),
        )
        self.mask = torch.triu(self.mask, diagonal=1)
        self.use_mask_with_custom_op = False
        self.is_causal = False

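    # Each test below decodes a single token (seq_len == 1) at a successive
    # start_pos, comparing the custom op against the reference implementation.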
    def test_sdpa_with_cache_no_mqa_1(self):
        q = torch.rand((1, 1, 8, 4))
        k = torch.rand((1, 1, 8, 4))
        v = torch.rand((1, 1, 8, 4))
        start_pos = 0
        seq_len = q.size(1)
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]
        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )
        if self.use_mask_with_custom_op:
            attn_mask = attn_mask.contiguous()
            op_output = torch.ops.llama.sdpa_with_kv_cache(
                q,
                k,
                v,
                self.k_cache,
                self.v_cache,
                start_pos,
                seq_len,
                attn_mask,
                0,
                False,
            )
        else:
            op_output = torch.ops.llama.sdpa_with_kv_cache(
                q,
                k,
                v,
                self.k_cache,
                self.v_cache,
                start_pos,
                seq_len,
                None,
                0,
                self.is_causal,
            )
        self.assertTrue(torch.allclose(ref_output, op_output))

    def test_sdpa_with_cache_no_mqa_2(self):
        q = torch.rand((1, 1, 8, 4))
        k = torch.rand((1, 1, 8, 4))
        v = torch.rand((1, 1, 8, 4))
        start_pos = 1
        seq_len = q.size(1)
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]

        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )
        if self.use_mask_with_custom_op:
            attn_mask = attn_mask.contiguous()
            op_output = torch.ops.llama.sdpa_with_kv_cache(
                q,
                k,
                v,
                self.k_cache,
                self.v_cache,
                start_pos,
                seq_len,
                attn_mask,
                0,
                False,
            )
        else:
            op_output = torch.ops.llama.sdpa_with_kv_cache(
                q,
                k,
                v,
                self.k_cache,
                self.v_cache,
                start_pos,
                seq_len,
                None,
                0,
                self.is_causal,
            )

        self.assertTrue(torch.allclose(ref_output, op_output))

    def test_sdpa_with_cache_no_mqa_3(self):
        q = torch.rand((1, 1, 8, 4))
        k = torch.rand((1, 1, 8, 4))
        v = torch.rand((1, 1, 8, 4))
        start_pos = 2
        seq_len = q.size(1)
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]

        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )
        if self.use_mask_with_custom_op:
            attn_mask = attn_mask.contiguous()
            op_output = torch.ops.llama.sdpa_with_kv_cache(
                q,
                k,
                v,
                self.k_cache,
                self.v_cache,
                start_pos,
                seq_len,
                attn_mask,
                0,
                False,
            )
        else:
            op_output = torch.ops.llama.sdpa_with_kv_cache(
                q,
                k,
                v,
                self.k_cache,
                self.v_cache,
                start_pos,
                seq_len,
                None,
                0,
                self.is_causal,
            )
        self.assertTrue(torch.allclose(ref_output, op_output))

    def test_sdpa_with_cache_no_mqa_4(self):
        q = torch.rand((1, 1, 8, 4))
        k = torch.rand((1, 1, 8, 4))
        v = torch.rand((1, 1, 8, 4))
        start_pos = 3
        seq_len = q.size(1)
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]

        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )
        if self.use_mask_with_custom_op:
            attn_mask = attn_mask.contiguous()
            op_output = torch.ops.llama.sdpa_with_kv_cache(
                q,
                k,
                v,
                self.k_cache,
                self.v_cache,
                start_pos,
                seq_len,
                attn_mask,
                0,
                False,
            )
        else:
            op_output = torch.ops.llama.sdpa_with_kv_cache(
                q,
                k,
                v,
                self.k_cache,
                self.v_cache,
                start_pos,
                seq_len,
                None,
                0,
                self.is_causal,
            )
        self.assertTrue(torch.allclose(ref_output, op_output))


class SDPAWithAttentionMaskTest(SDPATest):

    def setUp(self):
        SDPATest.setUp(self)
        self.mask = torch.full(
            (10, 10),
            100.642,
        )
        self.use_mask_with_custom_op = True


class SDPAWithCausalTest(SDPATest):

    def setUp(self):
        SDPATest.setUp(self)
        self.is_causal = True


class SDPAWithDynamicShape(unittest.TestCase):

    def setUp(self):
        torch.manual_seed(42)
        self.k_cache = torch.zeros((1, 10, 8, 4))
        self.v_cache = torch.zeros((1, 10, 8, 4))
        self.mask = torch.full(
            (10, 10),
            float("-inf"),
        )
        self.mask = torch.triu(self.mask, diagonal=1)
        self.use_mask_with_custom_op = False
        self.is_causal = False

    def test_sdpa_with_cache_dynamic_shape_0(self):
        q = torch.rand((1, 4, 8, 4))
        k = torch.rand((1, 4, 8, 4))
        v = torch.rand((1, 4, 8, 4))
        seq_len = q.size(1)
        start_pos = 0
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]

        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )

        op_output = torch.ops.llama.sdpa_with_kv_cache(
            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
        )
        self.assertTrue(torch.allclose(ref_output, op_output))

    def test_sdpa_with_cache_dynamic_shape_2(self):
        q = torch.rand((1, 3, 8, 4))
        k = torch.rand((1, 3, 8, 4))
        v = torch.rand((1, 3, 8, 4))
        seq_len = q.size(1)
        start_pos = 2
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]

        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )

        op_output = torch.ops.llama.sdpa_with_kv_cache(
            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
        )
        self.assertTrue(torch.allclose(ref_output, op_output))

    @unittest.skip(
        "This test should fail (seq_len exceeds the cache size), but the runtime "
        "does not bubble the error up."
    )
    def test_sdpa_with_cache_dynamic_shape_4(self):
        # 11 tokens at start_pos 4 overflow the 10-entry cache allocated in setUp.
        q = torch.rand((1, 11, 8, 4))
        k = torch.rand((1, 11, 8, 4))
        v = torch.rand((1, 11, 8, 4))
        seq_len = q.size(1)
        start_pos = 4

        torch.ops.llama.sdpa_with_kv_cache(
            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
        )


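# MQA/GQA: the query has more heads than the k/v cache, so the reference
# repeats each kv head (repeat_interleave) to match the query head count.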
class SDPATestWithMQA(unittest.TestCase):

    def setup_caches(self):
        self.k_cache = torch.zeros((1, 5, self.n_heads_kv, 4))
        self.v_cache = torch.zeros((1, 5, self.n_heads_kv, 4))

    def setUp(self):
        torch.manual_seed(42)
        self.n_heads_kv = 4
        self.n_heads_q = 8
        self.setup_caches()
        self.mask = torch.full(
            (5, 5),
            float("-inf"),
        )
        self.mask = torch.triu(self.mask, diagonal=1)

    def test_sdpa_with_cache_mqa_1(self):
        q = torch.rand((1, 1, self.n_heads_q, 4))
        k = torch.rand((1, 1, self.n_heads_kv, 4))
        v = torch.rand((1, 1, self.n_heads_kv, 4))
        start_pos = 0
        seq_len = q.size(1)
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]
        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )
        op_output = torch.ops.llama.sdpa_with_kv_cache(
            q, k, v, self.k_cache, self.v_cache, 0, 1, None, 0, False
        )
        self.assertTrue(torch.allclose(ref_output, op_output))

    def test_sdpa_with_cache_mqa_2(self):
        q = torch.rand((1, 1, self.n_heads_q, 4))
        k = torch.rand((1, 1, self.n_heads_kv, 4))
        v = torch.rand((1, 1, self.n_heads_kv, 4))
        start_pos = 1
        seq_len = q.size(1)
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]
        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )
        op_output = torch.ops.llama.sdpa_with_kv_cache(
            q, k, v, self.k_cache, self.v_cache, 1, 1, None, 0, False
        )
        self.assertTrue(torch.allclose(ref_output, op_output))

    def test_sdpa_with_cache_mqa_3(self):
        self.n_heads_q = 14
        self.n_heads_kv = 7
        self.setup_caches()
        q = torch.rand((1, 1, self.n_heads_q, 4))
        k = torch.rand((1, 1, self.n_heads_kv, 4))
        v = torch.rand((1, 1, self.n_heads_kv, 4))
        start_pos = 1
        seq_len = q.size(1)
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]
        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )
        op_output = torch.ops.llama.sdpa_with_kv_cache(
            q, k, v, self.k_cache, self.v_cache, 1, 1, None, 0, False
        )
        self.assertTrue(torch.allclose(ref_output, op_output))


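# Shared harness: run a first chunk starting at position 0 ("prefill"), then a
# second chunk of next_iter_seq_len tokens starting right after it, comparing
# the custom op against the reference at both steps.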
class SDPATestCommon(unittest.TestCase):

    def setup_caches(self):
        self.k_cache = torch.zeros(
            (self.n_batch, self.max_seq_len, self.n_heads_kv, self.head_dim)
        )
        self.v_cache = torch.zeros(
            (self.n_batch, self.max_seq_len, self.n_heads_kv, self.head_dim)
        )
        self.mask = torch.full(
            (self.max_seq_len, self.max_seq_len),
            float("-inf"),
        )
        self.mask = torch.triu(self.mask, diagonal=1)

    def setUp(self):
        torch.manual_seed(42)
        self.n_batch = 5
        self.n_heads_kv = 32
        self.n_heads_q = 32
        self.head_dim = 128
        self.max_seq_len = 2048
        self.setup_caches()

    def _scale_tensor(self, tensor, min_value, max_value, scale=True):
        """Min-max scale `tensor` into [min_value, max_value]; return the
        tensor unchanged if scale is False."""
        normalized_tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())

        scaled_tensor = normalized_tensor * (max_value - min_value) + min_value

        return scaled_tensor if scale else tensor

    def _test_sdpa_common(
        self,
        n_heads_kv,
        n_heads_q,
        head_dim,
        max_seq_len,
        seq_len,
        next_iter_seq_len=1,
        scale_tensors=False,
    ):
        # Range arbitrarily chosen to reproduce a numerical error on x86 in some of the long context tests
        tensor_scale_max = 15
        tensor_scale_min = -15
        self.n_heads_kv = n_heads_kv
        self.n_heads_q = n_heads_q
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.setup_caches()
        # q uses n_heads_q so that GQA configurations (n_heads_q > n_heads_kv)
        # are actually exercised; k and v match the cache's n_heads_kv.
        q = self._scale_tensor(
            torch.rand((self.n_batch, seq_len, self.n_heads_q, self.head_dim)),
            tensor_scale_min,
            tensor_scale_max,
            scale_tensors,
        )
        k = self._scale_tensor(
            torch.rand((self.n_batch, seq_len, self.n_heads_kv, self.head_dim)),
            tensor_scale_min,
            tensor_scale_max,
            scale_tensors,
        )
        v = self._scale_tensor(
            torch.rand((self.n_batch, seq_len, self.n_heads_kv, self.head_dim)),
            tensor_scale_min,
            tensor_scale_max,
            scale_tensors,
        )

        start_pos = 0
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]
        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )
        op_output = torch.ops.llama.sdpa_with_kv_cache(
            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
        )
        self.assertTrue(torch.allclose(ref_output, op_output, atol=1e-6))

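        # Second iteration: feed next_iter_seq_len more tokens starting right
        # after the first chunk, emulating the next decoding step.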
        q = self._scale_tensor(
            torch.rand(
                (self.n_batch, next_iter_seq_len, self.n_heads_q, self.head_dim)
            ),
            tensor_scale_min,
            tensor_scale_max,
            scale_tensors,
        )
        k = self._scale_tensor(
            torch.rand(
                (self.n_batch, next_iter_seq_len, self.n_heads_kv, self.head_dim)
            ),
            tensor_scale_min,
            tensor_scale_max,
            scale_tensors,
        )
        v = self._scale_tensor(
            torch.rand(
                (self.n_batch, next_iter_seq_len, self.n_heads_kv, self.head_dim)
            ),
            tensor_scale_min,
            tensor_scale_max,
            scale_tensors,
        )

        start_pos = seq_len
        seq_len = q.size(1)
        attn_mask = self.mask[start_pos : start_pos + seq_len, :]
        attn_mask = attn_mask[:, : start_pos + seq_len]
        ref_output = _sdpa_with_kv_cache_ref(
            q, k, v, self.k_cache, self.v_cache, attn_mask, start_pos, seq_len
        )
        op_output = torch.ops.llama.sdpa_with_kv_cache(
            q, k, v, self.k_cache, self.v_cache, start_pos, seq_len, None, 0, True
        )
        self.assertTrue(torch.allclose(ref_output, op_output, atol=1e-6))


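# Long-context tests: larger prefill lengths with various head counts,
# including GQA configurations where n_heads_q > n_heads_kv.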
class SDPATestForLargeSeqLength(SDPATestCommon):

    def test_sdpa_with_cache_seq_len_130(self):
        n_heads_kv = 32
        n_heads_q = 32
        head_dim = 128
        max_seq_len = 2048
        seq_len = 130
        self._test_sdpa_common(
            n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, scale_tensors=True
        )

    def test_sdpa_with_cache_seq_len_small(self):
        n_heads_kv = 4
        n_heads_q = 4
        head_dim = 4
        max_seq_len = 8
        seq_len = 4
        self._test_sdpa_common(n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len)

    def test_sdpa_with_cache_seq_len_llava_example(self):
        n_heads_kv = 32
        n_heads_q = 32
        head_dim = 128
        max_seq_len = 2048
        seq_len = 634
        self._test_sdpa_common(n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len)

    def test_sdpa_with_cache_seq_len_130_gqa(self):
        n_heads_kv = 8
        n_heads_q = 32
        head_dim = 128
        max_seq_len = 2048
        seq_len = 130
        self._test_sdpa_common(
            n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, scale_tensors=True
        )

    def test_sdpa_with_cache_seq_len_llava_example_gqa(self):
        n_heads_kv = 16
        n_heads_q = 32
        head_dim = 128
        max_seq_len = 2048
        seq_len = 634
        self._test_sdpa_common(n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len)


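# Speculative-decode tests: the second call processes more than one token at a
# time (next_iter_seq_len > 1) on top of an already-populated cache.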
class SDPATestForSpeculativeDecode(SDPATestCommon):

    def test_sdpa_with_cache_seq_len_130(self):
        n_heads_kv = 32
        n_heads_q = 32
        head_dim = 128
        max_seq_len = 2048
        seq_len = 130
        next_iter_seq_len = 17
        self._test_sdpa_common(
            n_heads_kv,
            n_heads_q,
            head_dim,
            max_seq_len,
            seq_len,
            next_iter_seq_len,
            True,
        )

    def test_sdpa_with_cache_seq_len_llava_example(self):
        n_heads_kv = 32
        n_heads_q = 32
        head_dim = 128
        max_seq_len = 2048
        seq_len = 634
        next_iter_seq_len = 64
        self._test_sdpa_common(
            n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, next_iter_seq_len
        )

    def test_sdpa_with_cache_seq_len_130_gqa(self):
        n_heads_kv = 8
        n_heads_q = 32
        head_dim = 128
        max_seq_len = 2048
        seq_len = 130
        next_iter_seq_len = 33
        self._test_sdpa_common(
            n_heads_kv,
            n_heads_q,
            head_dim,
            max_seq_len,
            seq_len,
            next_iter_seq_len,
            True,
        )

    def test_sdpa_with_cache_seq_len_llava_example_gqa(self):
        n_heads_kv = 16
        n_heads_q = 32
        head_dim = 128
        max_seq_len = 2048
        seq_len = 634
        next_iter_seq_len = 117
        self._test_sdpa_common(
            n_heads_kv, n_heads_q, head_dim, max_seq_len, seq_len, next_iter_seq_len
        )
593