# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from executorch.examples.models.llama.llama_transformer import KVCache

from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
    QuantizedCacheType,
    QuantizedKVCache,
)

from executorch.examples.models.llama.source_transformation.sdpa import SDPACustom


class SDPAWithQuantizedKVCacheTest(unittest.TestCase):
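    """Compares SDPACustom outputs when backed by a float KVCache versus the
    same cache converted to a QuantizedKVCache (affine asymmetric)."""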

    def _init_cache(self):
        self.kv_cache = KVCache(
            self.max_batch_size,
            self.max_seq_len,
            self.n_kv_heads,
            self.head_dim,
            False,
            self.enable_dynamic_shape,
            dtype=self.dtype,
        )
        self.quantized_kv_cache = QuantizedKVCache.from_float(
            self.kv_cache, QuantizedCacheType.AffineAsymmetric
        )

    def _init_kv(self):
        kv_shape = (1, self.seq_len, self.n_kv_heads, self.head_dim)
        q_shape = (1, self.seq_len, self.n_heads, self.head_dim)
        q = torch.rand(q_shape, dtype=self.dtype)
        k = torch.rand(kv_shape, dtype=self.dtype)
        v = torch.rand(kv_shape, dtype=self.dtype)
        return q, k, v

    def setUp(self):
        torch.manual_seed(42)
        self.max_batch_size = 1
        self.max_seq_len = 5
        self.n_kv_heads = 4
        self.n_heads = 8
        self.head_dim = 17
        self.dim = self.n_heads * self.head_dim
        self.enable_dynamic_shape = False
        self.dtype = torch.float32

    def test_simple(self, is_dynamic_shape=False):
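        """Runs a 3-token prefill followed by a single-token decode and checks
        that the quantized-cache output stays close to the float-cache output."""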
        self.enable_dynamic_shape = is_dynamic_shape
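        # Prefill: write a 3-token sequence at position 0 into both caches.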
        input_pos = torch.tensor([0], dtype=torch.int64)
        self.seq_len = 3
        self._init_cache()
        q, k, v = self._init_kv()
        self.float_sdpa = SDPACustom(self.kv_cache, self.dim)
        self.quantized_sdpa = SDPACustom(self.quantized_kv_cache, self.dim)
        float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
        quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
        torch.testing.assert_close(
            float_out,
            quantized_out,
            # rtol is loosened because custom_sdpa attends over the dequantized
            # k and v rather than the original float values, which leads to
            # larger differences in the output. A subsequent diff in the stack
            # will address this.
            rtol=1e-01,
            atol=1e-03,
        )

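        # Decode: append a single token at position 3 and compare again.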
        input_pos = torch.tensor([3], dtype=torch.int64)
        self.seq_len = 1
        q, k, v = self._init_kv()
        float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
        quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
        torch.testing.assert_close(
            float_out,
            quantized_out,
            rtol=1e-03,
            atol=1e-03,
        )