# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
from torchtune.modules.kv_cache import KVCache as TuneKVCache


class KVCache(TuneKVCache):
    """
    An export-friendly KVCache implementation adapted from the torchtune KVCache:
    https://github.com/pytorch/torchtune/blob/main/torchtune/modules/kv_cache.py
    It supports both transposed ([B, H, S, D]) and un-transposed ([B, S, H, D]) cache layouts.
    Standalone ``nn.Module`` containing a kv-cache to cache past keys and values during inference.

    Args:
        batch_size (int): batch size the model will be run with
        max_seq_len (int): maximum sequence length the model will be run with
        num_kv_heads (int): number of key/value heads
        head_dim (int): per-attention-head embedding dimension
        dtype (torch.dtype): dtype for the caches
        transpose_cache (bool): whether the cache is stored in the transposed
            [B, H, S, D] layout; if False, the [B, S, H, D] layout is used
    """

    def __init__(
        self,
        batch_size: int,
        max_seq_len: int,
        num_kv_heads: int,
        head_dim: int,
        dtype: torch.dtype,
        transpose_cache: bool = True,
    ) -> None:
        super().__init__(
            batch_size=batch_size,
            max_seq_len=max_seq_len,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            dtype=dtype,
        )
        self.transpose_cache = transpose_cache
        self.max_seq_len = max_seq_len
        if self.transpose_cache:
            cache_shape = (batch_size, num_kv_heads, max_seq_len, head_dim)
        else:
            cache_shape = (batch_size, max_seq_len, num_kv_heads, head_dim)

        self.register_buffer(
            "k_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
        )
        self.register_buffer(
            "v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
        )
        self.register_buffer(
            "cache_pos", torch.arange(0, self.max_seq_len), persistent=False
        )
        self.batch_size = batch_size

    def update(
        self, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update KV cache with the new ``k_val``, ``v_val`` and return the updated cache.

        Note:
            When updating the KV cache, it is assumed that subsequent updates should update key-value
            positions in consecutive sequence positions. If you wish to update cache values which have
            already been filled, use ``.reset()``, which will reset the cache to the zero-th position.

        Example:
            >>> cache = KVCache(batch_size=2, max_seq_len=16, num_kv_heads=4, head_dim=32, dtype=torch.bfloat16)
            >>> keys, values = torch.ones((2, 4, 8, 32)), torch.ones((2, 4, 8, 32))
            >>> cache.update(keys, values)
            >>> # now positions 0 through 7 are filled
            >>> cache.size
            8
            >>> keys, values = torch.ones((2, 4, 1, 32)), torch.ones((2, 4, 1, 32))
            >>> cache.update(keys, values)
            >>> # this will fill at position 8
            >>> cache.size
            9

        Args:
            k_val (torch.Tensor): Current key tensor with shape [B, H, S, D] if ``transpose_cache``
                is True, otherwise [B, S, H, D]
            v_val (torch.Tensor): Current value tensor with the same shape as ``k_val``

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Updated key and value cache tensors, respectively.

        Raises:
            AssertionError: if the update would write past the maximum cache sequence length.
            ValueError: if the batch size of the new key (or value) tensor is greater than the batch size
                used during cache setup.
        """
        if self.transpose_cache:
            bsz, _, seq_len, _ = k_val.shape
        else:
            bsz, seq_len, _, _ = k_val.shape
        if bsz > self.k_cache.shape[0]:
            raise ValueError(
                f"The current cache has been set up with a batch size of {self.k_cache.shape[0]}"
                f", but found new key tensors with batch size {k_val.shape[0]}!"
            )

        assert (self.cache_pos[0] + seq_len) <= self.max_seq_len

        k_out = self.k_cache
        v_out = self.v_cache

        if self.transpose_cache:
            k_out[:, :, self.cache_pos[:seq_len]] = k_val
            v_out[:, :, self.cache_pos[:seq_len]] = v_val
        else:
            k_out[:, self.cache_pos[:seq_len]] = k_val
            v_out[:, self.cache_pos[:seq_len]] = v_val

        # Advance cache_pos by seq_len positions.
        # cache_pos starts at (0, 1, 2, 3, 4, 5, ...); an update of
        # seq_len = 5 tokens brings it to (5, 6, 7, 8, 9, ...).
        # This lets us track the current position in the cache after the last
        # update in a compile-friendly way without any dynamism, e.g. without
        # relying on an int size tracker or re-creating cache_pos every time.
        self.cache_pos.add_(seq_len)

        return k_out, v_out

    def clone(self) -> "KVCache":
        """Create a clone of the KVCache."""
        if self.transpose_cache:
            num_kv_heads = self.k_cache.shape[1]
        else:
            num_kv_heads = self.k_cache.shape[2]
        clone = KVCache(
            batch_size=self.batch_size,
            max_seq_len=self.max_seq_len,
            num_kv_heads=num_kv_heads,
            head_dim=self.k_cache.shape[3],
            dtype=self.k_cache.dtype,
            transpose_cache=self.transpose_cache,
        )
        clone.k_cache.copy_(self.k_cache)
        clone.v_cache.copy_(self.v_cache)
        clone.cache_pos.copy_(self.cache_pos)
        return clone
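

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): exercises the
    # export-friendly cache with the un-transposed [B, S, H, D] layout.
    # The shapes and dtype below are illustrative assumptions only.
    cache = KVCache(
        batch_size=1,
        max_seq_len=16,
        num_kv_heads=4,
        head_dim=32,
        dtype=torch.float32,
        transpose_cache=False,
    )
    # Prefill 8 positions, then decode one more token; cache_pos advances by
    # seq_len on every update, so the second update lands at position 8.
    k, v = torch.randn(1, 8, 4, 32), torch.randn(1, 8, 4, 32)
    cache.update(k, v)
    k, v = torch.randn(1, 1, 4, 32), torch.randn(1, 1, 4, 32)
    k_out, v_out = cache.update(k, v)
    # The full cache tensors are returned; cache_pos[0] now points at 9.
    print(k_out.shape, int(cache.cache_pos[0]))  # torch.Size([1, 16, 4, 32]) 9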