# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Optional, Union

import torch
from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken
from executorch.extension.llm.tokenizer.tokenizer import (
    Tokenizer as SentencePieceTokenizer,
)

from lm_eval.models.huggingface import HFLM as eval_wrapper

from torch import nn


class EagerEvalWrapper(eval_wrapper):
    """
    A wrapper class based on GPTFast, providing integration with the lm-evaluation-harness library.

    The wrapper forwards tokenization to the supplied tokenizer and forward
    passes to the supplied eager ``nn.Module``. Text generation is not
    supported (see ``_model_generate``).
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
        max_seq_length: Optional[int] = None,
        use_kv_cache: bool = False,
    ):
        """
        Args:
            model: The module invoked by ``_model_call``.
            tokenizer: Tokenizer used by ``tok_encode`` / ``tok_decode`` and
                for looking up the end-of-text token id.
            max_seq_length: Value reported by ``max_length``; defaults to
                2048 when ``None``.
            use_kv_cache: When true, ``_model_call`` passes a position tensor
                to the model and truncates the input to ``max_seq_length``.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # NOTE(review): the parent is initialized with pretrained="gpt2";
        # presumably this only satisfies the HFLM constructor, since every
        # model/tokenizer call below is routed to `model` / `tokenizer`
        # instead -- confirm against lm_eval's HFLM.
        super().__init__(device=device, pretrained="gpt2")
        self._model = model
        self._tokenizer = tokenizer
        self._device = torch.device(device)
        self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
        self._use_kv_cache = use_kv_cache

    @property
    def eot_token_id(self):
        """
        The stories model does not have an EOT token, so we use the EOS token instead.
        """
        if hasattr(self._tokenizer, "eot_id"):
            return self._tokenizer.eot_id
        return self._tokenizer.eos_id

    @property
    def prefix_token_id(self):
        """Return the prefix token id, which is the same as ``eot_token_id``."""
        return self.eot_token_id

    @property
    def max_length(self) -> int:
        """Return the maximum sequence length configured at construction."""
        return self._max_seq_length

    @property
    def max_gen_toks(self) -> int:
        """Return the fixed generation-token budget (50)."""
        return 50

    @property
    def batch_size(self) -> int:
        """Return the fixed batch size (1)."""
        return 1

    @property
    def device(self) -> torch.device:
        """Return the ``torch.device`` selected in ``__init__``."""
        return self._device

    def tok_encode(self, string: str, **kwargs):  # pyre-ignore
        """Encode ``string`` with the tokenizer, adding neither BOS nor EOS.

        Extra keyword arguments are accepted and ignored.
        """
        return self._tokenizer.encode(string, bos=False, eos=False)

    def tok_decode(self, tokens):
        """Decode ``tokens`` back to text using the tokenizer."""
        return self._tokenizer.decode(tokens)

    def _model_call(self, inps):
        """Run the wrapped model on ``inps`` and return its output.

        With the KV cache enabled, the input is truncated along its second
        dimension to ``max_seq_length`` and a position tensor is passed as the
        model's second argument. Otherwise ``inps`` is passed through
        unmodified.
        """
        if self._use_kv_cache:
            # presumably the start position for the KV cache -- TODO confirm
            # against the model's forward signature.
            pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
            # Batch process the whole sequence.
            logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
            return logits
        else:
            return self._model(inps)

    def _model_generate(self, context, max_length, eos_token_id):
        """Generation is not supported by this wrapper.

        Raises:
            NotImplementedError: Always. This is a subclass of ``Exception``,
                so callers that catch ``Exception`` are unaffected.
        """
        raise NotImplementedError("unimplemented")