1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7import unittest 8 9import torch 10 11from executorch.examples.models.llama.llama_transformer import KVCache 12 13from executorch.examples.models.llama.source_transformation.quantized_kv_cache import ( 14 QuantizedCacheType, 15 QuantizedKVCache, 16) 17 18from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom 19 20 21class SDPAWithQuantizedKVCacheTest(unittest.TestCase): 22 23 def _init_cache(self): 24 self.kv_cache = KVCache( 25 self.max_batch_size, 26 self.max_seq_len, 27 self.n_kv_heads, 28 self.head_dim, 29 False, 30 self.enable_dynamic_shape, 31 dtype=self.dtype, 32 ) 33 self.quantized_kv_cache = QuantizedKVCache.from_float( 34 self.kv_cache, QuantizedCacheType.AffineAsymmetric 35 ) 36 37 def _init_kv(self): 38 kv_shape = (1, self.seq_len, self.n_kv_heads, self.head_dim) 39 q_shape = (1, self.seq_len, self.n_heads, self.head_dim) 40 q = torch.rand(q_shape, dtype=self.dtype) 41 k = torch.rand(kv_shape, dtype=self.dtype) 42 v = torch.rand(kv_shape, dtype=self.dtype) 43 return q, k, v 44 45 def setUp(self): 46 torch.manual_seed(42) 47 self.max_batch_size = 1 48 self.max_seq_len = 5 49 self.n_kv_heads = 4 50 self.n_heads = 8 51 self.head_dim = 17 52 self.dim = self.n_heads * self.head_dim 53 self.enable_dynamic_shape = False 54 self.dtype = torch.float32 55 56 def test_simple(self, is_dynamic_shape=False): 57 self.enable_dynamic_shape = is_dynamic_shape 58 input_pos = torch.tensor([0], dtype=torch.int64) 59 self.seq_len = 3 60 self._init_cache() 61 q, k, v = self._init_kv() 62 self.float_sdpa = SDPACustom(self.kv_cache, self.dim) 63 self.quantized_sdpa = SDPACustom(self.quantized_kv_cache, self.dim) 64 float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None) 65 quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None) 66 torch.testing.assert_close( 67 float_out, 68 quantized_out, 69 # had to adjust rtol because switching to using custom_sdpa means we 70 # will use dequantized k and v instead of original k and v 71 # this leads to larger differences in the output. 72 # subsequent diff in the stack will address this issue. 73 rtol=1e-01, 74 atol=1e-03, 75 ) 76 77 input_pos = torch.tensor([3], dtype=torch.int64) 78 self.seq_len = 1 79 q, k, v = self._init_kv() 80 float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None) 81 quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None) 82 torch.testing.assert_close( 83 float_out, 84 quantized_out, 85 rtol=1e-03, 86 atol=1e-03, 87 ) 88