# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import unittest

import torch


class UpdateQuantizedKVCacheTest(unittest.TestCase):
    """Compare ``torch.ops.llama.update_quantized_cache`` with slice assignment.

    The fixture keeps six reference caches (quantized values, scales and
    zero points, each for K and V).  Every check clones those caches, applies
    the same update once through plain slice assignment (the reference) and
    once through the custom op (on the clones), and then requires the two
    results to be identical.
    """

    def _reset(self):
        """Re-create all six reference caches, zero-filled, from the fixture dims."""
        values_shape = (
            self.batch_size,
            self.seq_len,
            self.num_heads,
            self.head_dim,
        )
        # Scales and zero points carry one entry per (batch, position, head).
        qparams_shape = (self.batch_size, self.seq_len, self.num_heads, 1)

        self.quantized_k_cache = torch.zeros(values_shape, dtype=torch.int8)
        self.quantized_v_cache = torch.zeros(values_shape, dtype=torch.int8)
        self.k_scales_cache = torch.zeros(qparams_shape, dtype=torch.float64)
        self.v_scales_cache = torch.zeros(qparams_shape, dtype=torch.float64)
        self.k_zero_points_cache = torch.zeros(qparams_shape, dtype=torch.int64)
        self.v_zero_points_cache = torch.zeros(qparams_shape, dtype=torch.int64)

    def setUp(self):
        """Seed the RNG and build single-batch caches before each test."""
        torch.manual_seed(42)
        self.batch_size = 1
        self.seq_len = 10
        self.num_heads = 8
        self.head_dim = 4
        self._reset()

    def _update_k(self, start_pos, value, scales, zero_points):
        """Reference K update: write the new entries at ``start_pos`` on dim 1."""
        end_pos = start_pos + value.size(1)
        self.quantized_k_cache[:, start_pos:end_pos, :, :] = value
        self.k_scales_cache[:, start_pos:end_pos, :, :] = scales
        self.k_zero_points_cache[:, start_pos:end_pos, :, :] = zero_points

    def _update_v(self, start_pos, value, scales, zero_points):
        """Reference V update: write the new entries at ``start_pos`` on dim 1."""
        end_pos = start_pos + value.size(1)
        self.quantized_v_cache[:, start_pos:end_pos, :, :] = value
        self.v_scales_cache[:, start_pos:end_pos, :, :] = scales
        self.v_zero_points_cache[:, start_pos:end_pos, :, :] = zero_points

    def _make_inputs(self, seq_len=1):
        """Draw random K/V values, scales and zero points for one update.

        Shapes follow the current fixture dimensions, with ``seq_len`` new
        positions on dim 1.  Tensors are drawn in the order they are returned
        so that results stay reproducible under the seed set in ``setUp``.

        Returns:
            Tuple ``(k, v, k_scales, v_scales, k_zero_points, v_zero_points)``.
        """
        values_shape = (self.batch_size, seq_len, self.num_heads, self.head_dim)
        qparams_shape = (self.batch_size, seq_len, self.num_heads, 1)

        k = torch.randint(0, 50, values_shape, dtype=torch.int8)
        v = torch.randint(0, 50, values_shape, dtype=torch.int8)
        k_scales = torch.rand(qparams_shape, dtype=torch.float64)
        v_scales = torch.rand(qparams_shape, dtype=torch.float64)
        k_zero_points = torch.randint(0, 20, qparams_shape, dtype=torch.int64)
        v_zero_points = torch.randint(0, 20, qparams_shape, dtype=torch.int64)
        return k, v, k_scales, v_scales, k_zero_points, v_zero_points

    def _get_update_op(self):
        """Return the custom op, skipping the running test if it is unavailable."""
        try:
            return torch.ops.llama.update_quantized_cache
        except (AttributeError, RuntimeError):
            # Without the custom-op library the lookup itself fails; report a
            # skip with an actionable reason instead of an opaque error.
            self.skipTest(
                "torch.ops.llama.update_quantized_cache is not registered; "
                "load the llama custom ops library to run this test"
            )

    def _update_and_validate(
        self, k, v, k_scales, v_scales, k_zero_points, v_zero_points, start_pos
    ):
        """Apply one update via the reference and via the op, then compare.

        Args:
            k, v: new quantized K/V entries.
            k_scales, v_scales: scales matching ``k`` / ``v``.
            k_zero_points, v_zero_points: zero points matching ``k`` / ``v``.
            start_pos: index on dim 1 at which the new entries are written.
        """
        update_op = self._get_update_op()

        # (reference cache attribute, new entries), in the order the op is run.
        updates = [
            ("quantized_k_cache", k),
            ("k_scales_cache", k_scales),
            ("k_zero_points_cache", k_zero_points),
            ("quantized_v_cache", v),
            ("v_scales_cache", v_scales),
            ("v_zero_points_cache", v_zero_points),
        ]

        # Snapshot the caches *before* the reference update so the op starts
        # from the same state the reference started from.
        op_caches = {name: getattr(self, name).clone() for name, _ in updates}

        self._update_k(start_pos, k, k_scales, k_zero_points)
        self._update_v(start_pos, v, v_scales, v_zero_points)

        for name, new_entries in updates:
            update_op(new_entries, op_caches[name], start_pos)

        # The update is a pure copy, so the match must be exact (shape and
        # every element), not merely within a floating-point tolerance.
        for name, actual in op_caches.items():
            self.assertTrue(
                torch.equal(actual, getattr(self, name)),
                msg=f"{name} written by the custom op differs from the reference",
            )

    def test_update_kv_cache_simple(self):
        """Single-position update at the start of the cache."""
        self._update_and_validate(*self._make_inputs(), 0)

    def test_update_kv_cache_large_update(self):
        """Three positions written in one update."""
        self._update_and_validate(*self._make_inputs(seq_len=3), 0)

    def test_update_kv_cache_update_nonzero_offset(self):
        """Single-position update at a non-zero start position."""
        self._update_and_validate(*self._make_inputs(), 2)

    def test_update_kv_cache_more_updates(self):
        """Two consecutive updates at different positions on the same caches."""
        for start_pos in (2, 4):
            self._update_and_validate(*self._make_inputs(), start_pos)

    def test_batched_update_kv_cache_more_updates(self):
        """Two consecutive updates with a batch size greater than one."""
        self.batch_size = 7
        # The caches built in setUp are single-batch; rebuild for the new size.
        self._reset()
        for start_pos in (2, 4):
            self._update_and_validate(*self._make_inputs(), start_pos)