# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from enum import Enum

import torch
import torch.nn as nn
from executorch.examples.models.llama.llama_transformer import KVCache
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401


"""
Heavily "inspired" by AO's implementation of the same in torchao/_models/llama/model.py
"""


# Does not have to abide by affine quantization laws.
# However, if we do implement quantized SDPA, then this might be handy.
class QuantizedCacheType(Enum):
    AffineSymmetric = 0
    AffineAsymmetric = 1
    AffineSymmetricGroupWise = 2
    AffineAsymmetricGroupWise = 3


class QuantizedKVCache(nn.Module):
    """int8 per-token quantized KV cache: keys and values are stored as int8
    alongside per-token float64 scales (and int64 zero points for the
    asymmetric variant), and are dequantized back to float32 on read."""

    def __init__(
        self,
        max_batch_size,
        max_seq_length,
        n_heads,
        head_dim,
        cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
        transposed=False,
        enable_dynamic_shape=False,
    ):
        super().__init__()
        if cache_type not in (
            QuantizedCacheType.AffineSymmetric,
            QuantizedCacheType.AffineAsymmetric,
        ):
            raise ValueError(
                f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
            )

        # For now supporting int8 only
        self.quantized_cache_dtype = torch.int8
        self.cache_fp_type = torch.float32
        self.is_transposed = transposed
        self.enable_dynamic_shape = enable_dynamic_shape
        if self.is_transposed:
            cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
            scale_shape = (max_batch_size, n_heads, max_seq_length, 1)
        else:
            cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
            scale_shape = (max_batch_size, max_seq_length, n_heads, 1)
        self.register_buffer(
            "k_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
        )
        self.register_buffer(
            "v_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
        )
        self.register_buffer(
            "k_cache_scales", torch.ones(scale_shape, dtype=torch.float64)
        )
        self.register_buffer(
            "v_cache_scales", torch.ones(scale_shape, dtype=torch.float64)
        )
        if cache_type == QuantizedCacheType.AffineAsymmetric:
            self.register_buffer(
                "k_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64)
            )
            self.register_buffer(
                "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64)
            )

    def _quantize(self, value):
        # Per-token asymmetric quantization: one (scale, zero_point) pair per
        # token, i.e. per slice along the last (head_dim) dimension.
        scales, zero_points = (
            torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(
                value, self.quantized_cache_dtype
            )
        )
        quantized_value = torch.ops.quantized_decomposed.quantize_per_token(
            value,
            scales,
            zero_points,
            torch.iinfo(self.quantized_cache_dtype).min,
            torch.iinfo(self.quantized_cache_dtype).max,
            self.quantized_cache_dtype,
        )
        return quantized_value, scales, zero_points

    def update(self, input_pos, k_val, v_val):
        # Quantize the current k_val and v_val and store them in the cache
        quantized_k_val, k_scales, k_zero_points = self._quantize(k_val)
        quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)

        if self.is_transposed:
            # We cannot use the update_cache op at the moment
            # if the cache is transposed.
            # Also note that we should not need separate paths
            # for dynamic vs. static shapes.
            # The only reason it is done this way is to accommodate
            # the lowering pains of backends that work better
            # with the index_put op.
            if self.enable_dynamic_shape:
                start_pos = input_pos[0].item()
                torch._check_is_size(start_pos)
                dim_to_slice = 2 if self.is_transposed else 1
                torch._check(start_pos < self.k_cache.size(dim_to_slice))
                seq_length = k_val.size(dim_to_slice)
                narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
                narrowed_k_scales = self.k_cache_scales.narrow(
                    dim_to_slice, start_pos, seq_length
                )
                narrowed_k_zp = self.k_cache_zero_points.narrow(
                    dim_to_slice, start_pos, seq_length
                )
                narrowed_k.copy_(quantized_k_val)
                narrowed_k_scales.copy_(k_scales)
                narrowed_k_zp.copy_(k_zero_points)
                narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
                narrowed_v_scales = self.v_cache_scales.narrow(
                    dim_to_slice, start_pos, seq_length
                )
                narrowed_v_zp = self.v_cache_zero_points.narrow(
                    dim_to_slice, start_pos, seq_length
                )
                narrowed_v.copy_(quantized_v_val)
                narrowed_v_scales.copy_(v_scales)
                narrowed_v_zp.copy_(v_zero_points)
            else:
                self.k_cache[:, :, input_pos] = quantized_k_val
                self.k_cache_scales[:, :, input_pos] = k_scales
                self.k_cache_zero_points[:, :, input_pos] = k_zero_points
                self.v_cache[:, :, input_pos] = quantized_v_val
                self.v_cache_scales[:, :, input_pos] = v_scales
                self.v_cache_zero_points[:, :, input_pos] = v_zero_points
        else:
            # Right now we use custom ops on this path.
            # In the future we can update the custom op to handle the
            # transposed cache as well.
            # Note that we may have to revert this change if other ET
            # backends such as QNN want to use the quantized cache with dynamic
            # shape instead of quantizing on their own.
            # But until then, we opt for code simplicity.
            start_pos = input_pos[0].item()
            _ = torch.ops.llama.update_quantized_cache(
                quantized_k_val, self.k_cache, start_pos
            )
            _ = torch.ops.llama.update_quantized_cache(
                k_scales, self.k_cache_scales, start_pos
            )
            _ = torch.ops.llama.update_quantized_cache(
                k_zero_points, self.k_cache_zero_points, start_pos
            )
            _ = torch.ops.llama.update_quantized_cache(
                quantized_v_val, self.v_cache, start_pos
            )
            _ = torch.ops.llama.update_quantized_cache(
                v_scales, self.v_cache_scales, start_pos
            )
            _ = torch.ops.llama.update_quantized_cache(
                v_zero_points, self.v_cache_zero_points, start_pos
            )

        k_out = torch.ops.quantized_decomposed.dequantize_per_token(
            self.k_cache,
            self.k_cache_scales,
            self.k_cache_zero_points,
            torch.iinfo(self.quantized_cache_dtype).min,
            torch.iinfo(self.quantized_cache_dtype).max,
            self.quantized_cache_dtype,
            self.cache_fp_type,
        )
        v_out = torch.ops.quantized_decomposed.dequantize_per_token(
            self.v_cache,
            self.v_cache_scales,
            self.v_cache_zero_points,
            torch.iinfo(self.quantized_cache_dtype).min,
            torch.iinfo(self.quantized_cache_dtype).max,
            self.quantized_cache_dtype,
            self.cache_fp_type,
        )
        return k_out, v_out

    @classmethod
    def from_float(cls, kv_cache, cache_type: QuantizedCacheType):
        cache_shape = kv_cache.k_cache.shape
        if kv_cache.is_transposed:
            max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
        else:
            max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
        return cls(
            max_batch_size,
            max_seq_length,
            n_heads,
            head_dim,
            cache_type,
            kv_cache.is_transposed,
            kv_cache.enable_dynamic_shape,
        )


def replace_kv_cache_with_quantized_kv_cache(module):
    logging.warning(
        "Replacing KVCache with QuantizedKVCache. This modifies the model in place."
    )
    for name, child in module.named_children():
        if isinstance(child, KVCache):
            setattr(
                module,
                name,
                QuantizedKVCache.from_float(child, QuantizedCacheType.AffineAsymmetric),
            )
        else:
            replace_kv_cache_with_quantized_kv_cache(child)
    return module
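

# Minimal usage sketch (illustrative only, not part of the library's API): it
# round-trips one chunk of keys/values through the cache in eager mode to show
# how update() quantizes on write and dequantizes on read. The transposed,
# statically shaped path is chosen so that only the decomposed quantized ops
# are exercised and the torch.ops.llama.update_quantized_cache custom op does
# not need to be loaded. All shapes below are arbitrary example values.
if __name__ == "__main__":
    bsz, n_heads, head_dim = 1, 8, 64
    max_seq_length, seq_length = 128, 16
    cache = QuantizedKVCache(
        bsz,
        max_seq_length,
        n_heads,
        head_dim,
        cache_type=QuantizedCacheType.AffineAsymmetric,
        transposed=True,
    )
    k_val = torch.randn(bsz, n_heads, seq_length, head_dim)
    v_val = torch.randn(bsz, n_heads, seq_length, head_dim)
    input_pos = torch.arange(seq_length)
    k_out, v_out = cache.update(input_pos, k_val, v_val)
    # k_out/v_out are dequantized copies of the full cache; the first
    # seq_length positions should closely match the inputs after the int8
    # per-token round trip.
    print(torch.max(torch.abs(k_out[:, :, :seq_length] - k_val)))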