# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch.nn
from transformers import Phi3ForCausalLM

from .static_cache import ETStaticCache


class Phi3Mini(torch.nn.Module):
    """Wrap a ``Phi3ForCausalLM`` together with an ``ETStaticCache``.

    The cache is built once at construction time and handed to the wrapped
    model as ``past_key_values`` on every ``forward`` call.
    """

    def __init__(self, model: Phi3ForCausalLM, max_batch_size: int, max_seq_len: int):
        """Store the model and build its key/value cache.

        Args:
            model: The causal LM to wrap; its ``config``, ``device`` and
                ``dtype`` are used to configure the cache.
            max_batch_size: Forwarded to the cache as ``max_batch_size``.
            max_seq_len: Forwarded to the cache as ``max_cache_len``.
        """
        super().__init__()
        self.model = model

        # pyre-fixme[16]: `Phi3ForCausalLM` has no attribute `config`.
        model_config = model.config
        # pyre-fixme[16]: `Phi3ForCausalLM` has no attribute `device`.
        cache_device = self.model.device
        # pyre-fixme[16]: `Phi3ForCausalLM` has no attribute `dtype`.
        cache_dtype = self.model.dtype

        self.cache = ETStaticCache(
            config=model_config,
            max_batch_size=max_batch_size,
            max_cache_len=max_seq_len,
            device=cache_device,
            dtype=cache_dtype,
        )

    def forward(
        self,
        # pyre-fixme[9]: input_ids has type `LongTensor`; used as `None`.
        input_ids: torch.LongTensor = None,
    ) -> torch.FloatTensor:
        """Run the wrapped model with the shared cache and return final-step logits.

        Args:
            input_ids: Token ids passed straight through to the wrapped model.

        Returns:
            ``logits[:, -1, :]`` of the model output, i.e. the slice at the
            last index of dimension 1 (presumably the sequence dimension --
            TODO confirm against the model's output layout).
        """
        # NOTE(review): ``forward`` is invoked directly rather than through
        # ``self.model(...)``, so hooks registered on the wrapped module are
        # skipped -- presumably intentional for export; confirm.
        # pyre-fixme[16]: `Phi3ForCausalLM` has no attribute `forward`.
        outputs = self.model.forward(
            input_ids=input_ids,
            past_key_values=self.cache,
            use_cache=True,
            return_dict=True,
        )
        return outputs.logits[:, -1, :]