# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
from torchtune.modules.kv_cache import KVCache as TuneKVCache


class KVCache(TuneKVCache):
    """
    An export-friendly KVCache implementation adapted from the torchtune KVCache:
    https://github.com/pytorch/torchtune/blob/main/torchtune/modules/kv_cache.py
    In addition, it supports both transposed and un-transposed cache layouts.
    Standalone ``nn.Module`` containing a kv-cache to cache past keys and values during inference.

    Args:
        batch_size (int): batch size the model will be run with
        max_seq_len (int): maximum sequence length the model will be run with
        num_kv_heads (int): number of key/value heads
        head_dim (int): per-attention-head embedding dimension
        dtype (torch.dtype): dtype for the caches
        transpose_cache (bool): if True, store the cache as
            ``[batch_size, num_kv_heads, max_seq_len, head_dim]``; if False, as
            ``[batch_size, max_seq_len, num_kv_heads, head_dim]``. Default: True.
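
    Example:
        An illustrative sketch of the tensor layout each ``transpose_cache``
        setting expects (shapes follow the ``cache_shape`` logic in ``__init__``):

        >>> # transpose_cache=True (default): update() takes [B, H, S, D] tensors
        >>> cache = KVCache(batch_size=2, max_seq_len=16, num_kv_heads=4, head_dim=32,
        ...                 dtype=torch.float32, transpose_cache=True)
        >>> k_out, v_out = cache.update(torch.ones((2, 4, 8, 32)), torch.ones((2, 4, 8, 32)))
        >>> # transpose_cache=False: update() takes [B, S, H, D] tensors
        >>> cache = KVCache(batch_size=2, max_seq_len=16, num_kv_heads=4, head_dim=32,
        ...                 dtype=torch.float32, transpose_cache=False)
        >>> k_out, v_out = cache.update(torch.ones((2, 8, 4, 32)), torch.ones((2, 8, 4, 32)))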
27    """
28
29    def __init__(
30        self,
31        batch_size: int,
32        max_seq_len: int,
33        num_kv_heads: int,
34        head_dim: int,
35        dtype: torch.dtype,
36        transpose_cache: bool = True,
37    ) -> None:
38        super().__init__(
39            batch_size=batch_size,
40            max_seq_len=max_seq_len,
41            num_kv_heads=num_kv_heads,
42            head_dim=head_dim,
43            dtype=dtype,
44        )
45        self.transpose_cache = transpose_cache
46        self.max_seq_len = max_seq_len
47        if self.transpose_cache:
48            cache_shape = (batch_size, num_kv_heads, max_seq_len, head_dim)
49        else:
50            cache_shape = (batch_size, max_seq_len, num_kv_heads, head_dim)
51
52        self.register_buffer(
53            "k_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
54        )
55        self.register_buffer(
56            "v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
57        )
58        self.register_buffer(
59            "cache_pos", torch.arange(0, self.max_seq_len), persistent=False
60        )
61        self.batch_size = batch_size
62
    def update(
        self, k_val: torch.Tensor, v_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
66        """Update KV cache with the new ``k_val``, ``v_val`` and return the updated cache.
67
68        Note:
69            When updating the KV cache, it is assumed that subsequent updates should update key-value
70            positions in consecutive sequence positions. If you wish to update cache values which have
71            already been filled, use ``.reset()``, which will reset the cache to the zero-th position.
72
73        Example:
74            >>> cache = KVCache(batch_size=2, max_seq_len=16, num_kv_heads=4, head_dim=32, dtype=torch.bfloat16)
75            >>> keys, values = torch.ones((2, 4, 8, 32)), torch.ones((2, 4, 8, 32))
76            >>> cache.update(keys, values)
77            >>> # now positions 0 through 7 are filled
78            >>> cache.size
79            >>> 8
80            >>> keys, values = torch.ones((2, 4, 1, 32)), torch.ones((2, 4, 1, 32))
81            >>> cache.update(keys, values)
82            >>> # this will fill at position 8
83            >>> cache.size
84            >>> 9
85
86        Args:
87            k_val (torch.Tensor): Current key tensor with shape [B, H, S, D]
88            v_val (torch.Tensor): Current value tensor with shape [B, H, S, D]
89
90        Returns:
91            Tuple[torch.Tensor, torch.Tensor]: Updated key and value cache tensors, respectively.
92
93        Raises:
94            AssertionError: if the sequence length of ``k_val`` is longer than the maximum cache sequence length.
95            ValueError: if the batch size of the new key (or value) tensor is greater than the batch size
96                used during cache setup.
97        """
        # Infer batch size and incoming sequence length from the layout-dependent shape.
        if self.transpose_cache:
            bsz, _, seq_len, _ = k_val.shape
        else:
            bsz, seq_len, _, _ = k_val.shape
        if bsz > self.k_cache.shape[0]:
            raise ValueError(
                f"The current cache has been set up with a batch size of {self.k_cache.shape[0]}"
                f", but found new key tensors with batch size {k_val.shape[0]}!"
            )

        assert (self.cache_pos[0] + seq_len) <= self.max_seq_len

        k_out = self.k_cache
        v_out = self.v_cache

        if self.transpose_cache:
            k_out[:, :, self.cache_pos[:seq_len]] = k_val
            v_out[:, :, self.cache_pos[:seq_len]] = v_val
        else:
            k_out[:, self.cache_pos[:seq_len]] = k_val
            v_out[:, self.cache_pos[:seq_len]] = v_val

        # Advance cache_pos by seq_len positions.
        # cache_pos starts at (0, 1, 2, 3, 4, 5, ...); an update of
        # seq_len = 5 tokens brings it to (5, 6, 7, 8, 9, ...).
        # This lets us track the current position in the cache after the last
        # update in a compile-friendly way without any dynamism, e.g. without
        # relying on an int size tracker or re-creating cache_pos every time.
        self.cache_pos.add_(seq_len)

        return k_out, v_out

    def clone(self) -> "KVCache":
        """Create a clone of the KVCache."""
        if self.transpose_cache:
            num_kv_heads = self.k_cache.shape[1]
        else:
            num_kv_heads = self.k_cache.shape[2]
        clone = KVCache(
            batch_size=self.batch_size,
            max_seq_len=self.max_seq_len,
            num_kv_heads=num_kv_heads,
            head_dim=self.k_cache.shape[3],
            dtype=self.k_cache.dtype,
            transpose_cache=self.transpose_cache,
        )
        clone.k_cache.copy_(self.k_cache)
        clone.v_cache.copy_(self.v_cache)
        clone.cache_pos.copy_(self.cache_pos)
        return clone
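

# A minimal, illustrative smoke test: a sketch of the un-transposed layout,
# the cache_pos bookkeeping described in update(), and clone(); the shapes
# and sizes below are arbitrary example values, not part of the module.
if __name__ == "__main__":
    cache = KVCache(
        batch_size=1,
        max_seq_len=8,
        num_kv_heads=2,
        head_dim=4,
        dtype=torch.float32,
        transpose_cache=False,
    )
    # Un-transposed layout: [B, S, H, D] with S=3 new positions.
    k = torch.ones((1, 3, 2, 4))
    v = torch.ones((1, 3, 2, 4))
    k_out, v_out = cache.update(k, v)
    # After one update of 3 tokens, the next write position is 3.
    assert cache.cache_pos[0].item() == 3
    cloned = cache.clone()
    assert torch.equal(cloned.k_cache, cache.k_cache)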