xref: /aosp_15_r20/external/executorch/examples/models/phi-3-mini/static_cache.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7
8from typing import Optional
9
10import torch
11from transformers import PretrainedConfig, StaticCache
12
13
class ETStaticCache(StaticCache):
    """
    Static KV-cache variant made exportable to ExecuTorch.

    Thin wrapper around ``transformers.StaticCache`` that replaces the
    length-query methods with tensor-op-only implementations so the model
    can be traced/exported. Intended as a stopgap: remove once transformers
    ships proper static-cache support for Phi3.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        max_batch_size: int,
        max_cache_len: int,
        device,
        dtype=torch.float32,
    ) -> None:
        # Pure pass-through construction; all buffers are allocated by the base class.
        super().__init__(
            config=config,
            max_batch_size=max_batch_size,
            max_cache_len=max_cache_len,
            device=device,
            dtype=dtype,
        )

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """
        Return the number of cache positions already written for ``layer_idx``.

        Counts slots whose key vector is nonzero along the head dimension.
        NOTE(review): this presumes the cache is zero-initialized and that a
        genuinely written key vector is never all-zeros — confirm against the
        base-class allocation.
        """
        # pyre-fixme[16]: `ETStaticCache` has no attribute `key_cache`.
        occupied = self.key_cache[layer_idx][0, 0].any(dim=-1)
        return occupied.sum().item()

    def get_usable_length(
        self, new_seq_length: int, layer_idx: Optional[int] = 0
    ) -> int:
        """
        Return the usable cache length, which for a static cache is simply
        the current sequence length; ``new_seq_length`` is accepted for
        interface compatibility but unused.
        """
        return self.get_seq_length(layer_idx)
44