xref: /aosp_15_r20/external/executorch/examples/models/llama/tests/test_simple_sdpa.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7import copy
8import unittest
9
10import torch
11from executorch.examples.models.llama.llama_transformer import KVCache, SDPA
12from executorch.examples.models.llama.source_transformation.sdpa import SDPASimple
13
14
class SDPATest(unittest.TestCase):
    """Numerical-equivalence tests between the llama SDPA implementations."""

    def test_simple_sdpa(self):
        """Check that ``SDPASimple`` matches ``SDPA`` from ``llama_transformer.py``.

        Both modules are fed identical query/key/value/mask tensors. Each module
        gets its own deep copy of the same freshly built ``KVCache``, so any
        cache update made by one module cannot leak into the other.
        """
        max_batch_size = 1
        max_seq_length = 128
        n_heads = 8
        head_dim = 8
        dim = n_heads * head_dim  # 64
        n_rep = 1
        bsz = 1
        seqlen = 1
        n_local_heads = n_heads

        # Local, seeded generator: inputs are reproducible, so a tolerance
        # failure can be re-run and debugged instead of being flaky. Using a
        # dedicated generator avoids mutating the global torch RNG state.
        generator = torch.Generator().manual_seed(0)

        kv_cache = KVCache(
            max_batch_size=max_batch_size,
            max_seq_length=max_seq_length,
            n_heads=n_heads,
            head_dim=head_dim,
            transpose_cache=True,
            enable_dynamic_shape=False,
        )
        sdpa = SDPA(
            kv_cache=copy.deepcopy(kv_cache),
            dim=dim,
            head_dim=head_dim,
            n_rep=n_rep,
            max_seq_len=max_seq_length,
            enable_dynamic_shape=False,
        )

        # A single step (seqlen == 1) at position 0. The q/k/v shapes are
        # derived from the same bsz/seqlen values passed to the modules below.
        input_pos = torch.tensor([0])
        query = torch.randn(
            bsz, seqlen, n_local_heads, head_dim, generator=generator
        )
        key = torch.randn(bsz, seqlen, n_local_heads, head_dim, generator=generator)
        value = torch.randn(
            bsz, seqlen, n_local_heads, head_dim, generator=generator
        )
        # An arbitrary random mask suffices: both modules receive the same one.
        mask = torch.randn(max_seq_length, max_seq_length, generator=generator)

        sdpa_output = sdpa(
            input_pos,
            query,
            key,
            value,
            bsz=bsz,
            seqlen=seqlen,
            mask=mask,
        )

        simple_sdpa = SDPASimple(
            kv_cache=copy.deepcopy(kv_cache), dim=dim, head_dim=head_dim, n_rep=n_rep
        )
        simple_sdpa_output = simple_sdpa(
            input_pos, query, key, value, bsz=bsz, seqlen=seqlen, mask=mask
        )

        # Compare the outputs of the two SDPA implementations.
        # assert_close also verifies shape and dtype (allclose would silently
        # broadcast) and reports the largest mismatch on failure. rtol/atol are
        # torch.allclose's defaults. The argument order keeps the original
        # criterion: |sdpa - simple| <= atol + rtol * |simple|.
        torch.testing.assert_close(
            sdpa_output, simple_sdpa_output, rtol=1e-5, atol=1e-8
        )
67