# xref: /aosp_15_r20/external/executorch/examples/models/llama/source_transformation/quantized_kv_cache.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from enum import Enum

import torch
import torch.nn as nn
from executorch.examples.models.llama.llama_transformer import KVCache
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401


"""
Heavily "inspired" by AO's implementation of the same in torchao/_models/llama/model.py
"""


# Doesn't have to abide by affine quantization laws.
# However, if we do implement quantized SDPA, then this might be handy.
class QuantizedCacheType(Enum):
    AffineSymmetric = 0
    AffineAsymmetric = 1
    AffineSymmetricGroupWise = 2
    AffineAsymmetricGroupWise = 3


class QuantizedKVCache(nn.Module):
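    """
    KV cache that stores int8-quantized keys and values together with
    per-token float64 scales and, for the asymmetric cache type, int64 zero
    points. Values are quantized on update() and the whole cache is
    dequantized back to self.cache_fp_type on every read.
    """
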
    def __init__(
        self,
        max_batch_size,
        max_seq_length,
        n_heads,
        head_dim,
        cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
        transposed=False,
        enable_dynamic_shape=False,
    ):
        super().__init__()
        if cache_type not in (
            QuantizedCacheType.AffineSymmetric,
            QuantizedCacheType.AffineAsymmetric,
        ):
            raise ValueError(
                f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
            )

        # For now, only int8 is supported.
        self.quantized_cache_dtype = torch.int8
        self.cache_fp_type = torch.float32
        self.is_transposed = transposed
        self.enable_dynamic_shape = enable_dynamic_shape
        if self.is_transposed:
            cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
            scale_shape = (max_batch_size, n_heads, max_seq_length, 1)
        else:
            cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
            scale_shape = (max_batch_size, max_seq_length, n_heads, 1)
        self.register_buffer(
            "k_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
        )
        self.register_buffer(
            "v_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
        )
        self.register_buffer(
            "k_cache_scales", torch.ones(scale_shape, dtype=torch.float64)
        )
        self.register_buffer(
            "v_cache_scales", torch.ones(scale_shape, dtype=torch.float64)
        )
        if cache_type == QuantizedCacheType.AffineAsymmetric:
            self.register_buffer(
                "k_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64)
            )
            self.register_buffer(
                "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64)
            )

    def _quantize(self, value):
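        """
        Per-token asymmetric quantization of `value` to the cache dtype
        (int8). Returns the quantized tensor along with per-token scales and
        zero points, each shaped like `value` with the last dimension reduced
        to 1.
        """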
        scales, zero_points = (
            torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
                value, self.quantized_cache_dtype
            )
        )
        quantized_value = torch.ops.quantized_decomposed.quantize_per_token(
            value,
            scales,
            zero_points,
            torch.iinfo(self.quantized_cache_dtype).min,
            torch.iinfo(self.quantized_cache_dtype).max,
            self.quantized_cache_dtype,
        )
        return quantized_value, scales, zero_points

    def update(self, input_pos, k_val, v_val):
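        """
        Quantize k_val / v_val per token, write them (together with their
        scales and zero points) into the cache at input_pos, and return the
        full k / v caches dequantized back to self.cache_fp_type.
        """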
        # Quantize the incoming k_val and v_val and store them in the cache.
        quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)

        quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)

        if self.is_transposed:
            # We cannot use the update_cache op at the moment
            # if the cache is transposed.
            # Also note that we should not need separate paths
            # for dynamic shape vs. static shape;
            # it is only done this way to accommodate
            # the lowering pains of backends that work better
            # with the index_put op.
            if self.enable_dynamic_shape:
                start_pos = input_pos[0].item()
                torch._check_is_size(start_pos)
                dim_to_slice = 2 if self.is_transposed else 1
                torch._check(start_pos < self.k_cache.size(dim_to_slice))
                seq_length = k_val.size(dim_to_slice)
                narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
                narrowed_k_scales = self.k_cache_scales.narrow(
                    dim_to_slice, start_pos, seq_length
                )
                narrowed_k_zp = self.k_cache_zero_points.narrow(
                    dim_to_slice, start_pos, seq_length
                )
                narrowed_k.copy_(quantized_k_val)
                narrowed_k_scales.copy_(k_scales)
                narrowed_k_zp.copy_(k_zero_points)
                narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
                narrowed_v_scales = self.v_cache_scales.narrow(
                    dim_to_slice, start_pos, seq_length
                )
                narrowed_v_zp = self.v_cache_zero_points.narrow(
                    dim_to_slice, start_pos, seq_length
                )
                narrowed_v.copy_(quantized_v_val)
                narrowed_v_scales.copy_(v_scales)
                narrowed_v_zp.copy_(v_zero_points)
            else:
                self.k_cache[:, :, input_pos] = quantized_k_val
                self.k_cache_scales[:, :, input_pos] = k_scales
                self.k_cache_zero_points[:, :, input_pos] = k_zero_points
                self.v_cache[:, :, input_pos] = quantized_v_val
                self.v_cache_scales[:, :, input_pos] = v_scales
                self.v_cache_zero_points[:, :, input_pos] = v_zero_points
        else:
            # Right now we use custom ops on this path.
            # In the future we can update the custom op to handle the transposed
            # cache as well.
            # Note that we may have to revert this change if other ET backends,
            # such as QNN, want to use the quantized cache with dynamic shape
            # instead of quantizing on their own.
            # But until then, we opt for code simplicity.
            start_pos = input_pos[0].item()
            _ = torch.ops.llama.update_quantized_cache(
                quantized_k_val, self.k_cache, start_pos
            )
            _ = torch.ops.llama.update_quantized_cache(
                k_scales, self.k_cache_scales, start_pos
            )
            _ = torch.ops.llama.update_quantized_cache(
                k_zero_points, self.k_cache_zero_points, start_pos
            )
            _ = torch.ops.llama.update_quantized_cache(
                quantized_v_val, self.v_cache, start_pos
            )
            _ = torch.ops.llama.update_quantized_cache(
                v_scales, self.v_cache_scales, start_pos
            )
            _ = torch.ops.llama.update_quantized_cache(
                v_zero_points, self.v_cache_zero_points, start_pos
            )

        k_out = torch.ops.quantized_decomposed.dequantize_per_token(
            self.k_cache,
            self.k_cache_scales,
            self.k_cache_zero_points,
            torch.iinfo(self.quantized_cache_dtype).min,
            torch.iinfo(self.quantized_cache_dtype).max,
            self.quantized_cache_dtype,
            self.cache_fp_type,
        )
        v_out = torch.ops.quantized_decomposed.dequantize_per_token(
            self.v_cache,
            self.v_cache_scales,
            self.v_cache_zero_points,
            torch.iinfo(self.quantized_cache_dtype).min,
            torch.iinfo(self.quantized_cache_dtype).max,
            self.quantized_cache_dtype,
            self.cache_fp_type,
        )
        return k_out, v_out

    @classmethod
    def from_float(cls, kv_cache, cache_type: QuantizedCacheType):
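        """
        Build a QuantizedKVCache with the same geometry (batch size, sequence
        length, number of heads, head dim), layout, and dynamic-shape setting
        as an existing floating-point KVCache.
        """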
        cache_shape = kv_cache.k_cache.shape
        if kv_cache.is_transposed:
            max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
        else:
            max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
        return cls(
            max_batch_size,
            max_seq_length,
            n_heads,
            head_dim,
            cache_type,
            kv_cache.is_transposed,
            kv_cache.enable_dynamic_shape,
        )


def replace_kv_cache_with_quantized_kv_cache(module):
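    """
    Recursively replace every KVCache submodule of `module` with an
    AffineAsymmetric QuantizedKVCache built from it, modifying the model
    in place.
    """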
    logging.warning(
        "Replacing KVCache with QuantizedKVCache. This modifies the model in place."
    )
    for name, child in module.named_children():
        if isinstance(child, KVCache):
            setattr(
                module,
                name,
                QuantizedKVCache.from_float(child, QuantizedCacheType.AffineAsymmetric),
            )
        else:
            replace_kv_cache_with_quantized_kv_cache(child)
    return module
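

# --------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It exercises QuantizedKVCache directly with made-up sizes and uses the
# transposed, static-shape path so that only the quantized_decomposed ops
# imported above are needed; the non-transposed path additionally relies on
# the torch.ops.llama.update_quantized_cache custom op being registered.
if __name__ == "__main__":
    cache = QuantizedKVCache(
        1,  # max_batch_size (illustrative)
        16,  # max_seq_length (illustrative)
        4,  # n_heads (illustrative)
        8,  # head_dim (illustrative)
        QuantizedCacheType.AffineAsymmetric,
        True,  # transposed cache layout: (batch, heads, seq, head_dim)
        False,  # enable_dynamic_shape
    )
    # Write a single new token at position 0.
    input_pos = torch.tensor([0], dtype=torch.int64)
    k_val = torch.randn(1, 4, 1, 8)
    v_val = torch.randn(1, 4, 1, 8)
    k_out, v_out = cache.update(input_pos, k_val, v_val)
    # Both outputs are the full, dequantized caches with shape (1, 4, 16, 8).
    print(k_out.shape, v_out.shape)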