# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Optional

import torch
from transformers import PretrainedConfig, StaticCache


class ETStaticCache(StaticCache):
    """
    A customized static cache implementation that overrides a few methods to make it exportable to ExecuTorch.
    This can be removed once transformers properly supports static cache for Phi-3.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        max_batch_size: int,
        max_cache_len: int,
        device,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        super().__init__(
            config=config,
            max_batch_size=max_batch_size,
            max_cache_len=max_cache_len,
            device=device,
            dtype=dtype,
        )

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        # Count the occupied positions in the statically allocated cache: a
        # position is considered filled if any value along its head dimension
        # is nonzero. The cache tensors are zero-initialized, so this starts at 0.
        # pyre-fixme[16]: `ETStaticCache` has no attribute `key_cache`.
        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum().item()

    def get_usable_length(
        self, new_seq_length: int, layer_idx: Optional[int] = 0
    ) -> int:
        # For a static cache, the usable length is simply the number of
        # positions already filled, independent of the incoming sequence length.
        return self.get_seq_length(layer_idx)
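

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original
    # module). It shows how the cache might be constructed and queried. The
    # config values below are hypothetical stand-ins; in practice the
    # PretrainedConfig would come from the model being exported (e.g. loaded
    # via AutoConfig.from_pretrained for Phi-3), and exact StaticCache
    # behavior depends on the installed transformers version.
    demo_config = PretrainedConfig(
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
        hidden_size=64,
    )
    cache = ETStaticCache(
        config=demo_config,
        max_batch_size=1,
        max_cache_len=128,
        device="cpu",
        dtype=torch.float32,
    )
    # A freshly created cache has no occupied positions.
    print("filled positions:", cache.get_seq_length())  # -> 0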