# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import unittest

import torch
from executorch.examples.models.llama.llama_transformer import KVCache, SDPA
from executorch.examples.models.llama.source_transformation.sdpa import SDPASimple


class SDPATest(unittest.TestCase):
    def test_simple_sdpa(self):
        """SDPASimple must match the reference SDPA module from
        llama_transformer.py for a single-token forward pass."""
        # Fixed, minimal geometry for the check: batch of 1, one new token,
        # an 8-head / 8-dim-per-head attention with a 128-slot cache.
        max_batch_size = 1
        max_seq_length = 128
        n_heads = 8
        n_local_heads = n_heads
        head_dim = 8
        dim = 64
        n_rep = 1
        bsz = 1
        seqlen = 1

        # One freshly-initialized cache; each module under test receives its
        # own deep copy so the two forward passes cannot influence each other.
        kv_cache = KVCache(
            max_batch_size=max_batch_size,
            max_seq_length=max_seq_length,
            n_heads=n_heads,
            head_dim=head_dim,
            transpose_cache=True,
            enable_dynamic_shape=False,
        )

        reference = SDPA(
            kv_cache=copy.deepcopy(kv_cache),
            dim=dim,
            head_dim=head_dim,
            n_rep=n_rep,
            max_seq_len=max_seq_length,
            enable_dynamic_shape=False,
        )

        # A single-position query/key/value triple plus a random mask,
        # shared verbatim by both implementations.
        input_pos = torch.tensor([0])
        query, key, value = (
            torch.randn(1, 1, n_local_heads, head_dim) for _ in range(3)
        )
        mask = torch.randn(max_seq_length, max_seq_length)

        reference_out = reference(
            input_pos,
            query,
            key,
            value,
            bsz=bsz,
            seqlen=seqlen,
            mask=mask,
        )

        simplified = SDPASimple(
            kv_cache=copy.deepcopy(kv_cache), dim=dim, head_dim=head_dim, n_rep=n_rep
        )
        simplified_out = simplified(
            input_pos, query, key, value, bsz=bsz, seqlen=seqlen, mask=mask
        )

        # The two implementations must agree numerically.
        self.assertTrue(torch.allclose(reference_out, simplified_out))