# xref: /aosp_15_r20/external/executorch/extension/llm/custom_ops/test_update_quantized_cache.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import unittest

import torch


class UpdateQuantizedKVCacheTest(unittest.TestCase):
    """Compare ``torch.ops.llama.update_quantized_cache`` with a reference.

    Every test applies the same update twice: once to the caches stored on
    ``self`` using plain slice assignment (the reference), and once to clones
    of those caches through the custom op. Both sets of caches must end up
    identical.

    NOTE(review): nothing in this class registers the ``llama`` op namespace,
    so the op library is presumably loaded by the test environment -- confirm
    before running this file on its own.
    """

    def _reset(self):
        """Replace all six reference caches with freshly zeroed tensors.

        Shapes are derived from ``batch_size``, ``seq_len``, ``num_heads`` and
        ``head_dim``, so a test that changes one of those must call this again.
        """
        value_shape = (self.batch_size, self.seq_len, self.num_heads, self.head_dim)
        # Scales and zero points carry a single entry per (batch, position, head).
        param_shape = (self.batch_size, self.seq_len, self.num_heads, 1)

        self.quantized_k_cache = torch.zeros(value_shape, dtype=torch.int8)
        self.quantized_v_cache = torch.zeros(value_shape, dtype=torch.int8)
        self.k_scales_cache = torch.zeros(param_shape, dtype=torch.float64)
        self.v_scales_cache = torch.zeros(param_shape, dtype=torch.float64)
        self.k_zero_points_cache = torch.zeros(param_shape, dtype=torch.int64)
        self.v_zero_points_cache = torch.zeros(param_shape, dtype=torch.int64)

    def setUp(self):
        """Seed the RNG and build zeroed caches before every test."""
        torch.manual_seed(42)
        self.batch_size = 1
        self.seq_len = 10
        self.num_heads = 8
        self.head_dim = 4
        self._reset()

    def _random_inputs(self, seq_len):
        """Generate one random update for every cache.

        Args:
            seq_len: Number of positions (size of dim 1) the update covers.

        Returns:
            Tuple ``(k, v, k_scales, v_scales, k_zero_points, v_zero_points)``
            in the positional order ``_update_and_validate`` expects. The
            remaining dimensions come from ``batch_size``, ``num_heads`` and
            ``head_dim`` on ``self``.
        """
        value_shape = (self.batch_size, seq_len, self.num_heads, self.head_dim)
        param_shape = (self.batch_size, seq_len, self.num_heads, 1)
        # Keep this draw order fixed: with the seed set in setUp it determines
        # which values each tensor receives.
        k = torch.randint(0, 50, value_shape, dtype=torch.int8)
        v = torch.randint(0, 50, value_shape, dtype=torch.int8)
        k_scales = torch.rand(param_shape, dtype=torch.float64)
        v_scales = torch.rand(param_shape, dtype=torch.float64)
        k_zero_points = torch.randint(0, 20, param_shape, dtype=torch.int64)
        v_zero_points = torch.randint(0, 20, param_shape, dtype=torch.int64)
        return k, v, k_scales, v_scales, k_zero_points, v_zero_points

    def _update_k(self, start_pos, value, scales, zero_points):
        """Reference update of the key caches via slice assignment on dim 1."""
        rows = slice(start_pos, start_pos + value.size(1))
        self.quantized_k_cache[:, rows, :, :] = value
        self.k_scales_cache[:, rows, :, :] = scales
        self.k_zero_points_cache[:, rows, :, :] = zero_points

    def _update_v(self, start_pos, value, scales, zero_points):
        """Reference update of the value caches via slice assignment on dim 1."""
        rows = slice(start_pos, start_pos + value.size(1))
        self.quantized_v_cache[:, rows, :, :] = value
        self.v_scales_cache[:, rows, :, :] = scales
        self.v_zero_points_cache[:, rows, :, :] = zero_points

    def _update_and_validate(
        self, k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos
    ):
        """Apply one update through the op and through the reference, then compare.

        The reference caches on ``self`` keep the update afterwards, so calling
        this repeatedly accumulates state exactly as the clones handed to the
        op do.

        Args:
            k, v: Quantized key/value tensors to write.
            k_scales, v_scales: Scales to write alongside ``k`` / ``v``.
            k_zero_points, v_zero_points: Zero points to write alongside them.
            start_pos: Index along dim 1 at which the update begins.
        """
        # (reference cache attribute, tensor written into it), in the order
        # the op is invoked.
        updates = (
            ("quantized_k_cache", k),
            ("k_scales_cache", k_scales),
            ("k_zero_points_cache", k_zero_points),
            ("quantized_v_cache", v),
            ("v_scales_cache", v_scales),
            ("v_zero_points_cache", v_zero_points),
        )

        # Snapshot the pre-update state first: the op must start from the same
        # contents the reference starts from.
        op_caches = {name: getattr(self, name).clone() for name, _ in updates}

        self._update_k(start_pos, k, k_scales, k_zero_points)
        self._update_v(start_pos, v, v_scales, v_zero_points)

        for name, value in updates:
            torch.ops.llama.update_quantized_cache(value, op_caches[name], start_pos)

        for name, _ in updates:
            with self.subTest(cache=name):
                # The reference is a plain copy, so require exact equality; on
                # a mismatch assert_close reports what differs and where.
                torch.testing.assert_close(
                    op_caches[name], getattr(self, name), rtol=0, atol=0
                )

    def test_update_kv_cache_simple(self):
        """Single-position update at the start of the cache."""
        self._update_and_validate(*self._random_inputs(seq_len=1), start_pos=0)

    def test_update_kv_cache_large_update(self):
        """Three-position update at the start of the cache."""
        self._update_and_validate(*self._random_inputs(seq_len=3), start_pos=0)

    def test_update_kv_cache_update_nonzero_offset(self):
        """Single-position update at a non-zero offset."""
        self._update_and_validate(*self._random_inputs(seq_len=1), start_pos=2)

    def test_update_kv_cache_more_updates(self):
        """Two successive single-position updates without resetting in between."""
        for start_pos in (2, 4):
            self._update_and_validate(
                *self._random_inputs(seq_len=1), start_pos=start_pos
            )

    def test_batched_update_kv_cache_more_updates(self):
        """Two successive updates with a batch size larger than one."""
        self.batch_size = 7
        # Caches built in setUp used batch_size == 1; rebuild them.
        self._reset()
        for start_pos in (2, 4):
            self._update_and_validate(
                *self._random_inputs(seq_len=1), start_pos=start_pos
            )
185