xref: /aosp_15_r20/external/executorch/examples/models/llama/evaluate/eager_eval.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7
8from typing import Optional, Union
9
10import torch
11from executorch.examples.models.llama.tokenizer.tiktoken import Tokenizer as Tiktoken
12from executorch.extension.llm.tokenizer.tokenizer import (
13    Tokenizer as SentencePieceTokenizer,
14)
15
16from lm_eval.models.huggingface import HFLM as eval_wrapper
17
18from torch import nn
19
20
class EagerEvalWrapper(eval_wrapper):
    """
    A wrapper class based on GPTFast, providing integration with the lm-evaluation-harness library.

    The wrapper holds an eager ``nn.Module`` together with its tokenizer and
    supplies the tokenization, model-call and sizing members defined below.
    Text generation is not supported: ``_model_generate`` always raises.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
        max_seq_length: Optional[int] = None,
        use_kv_cache: bool = False,
    ):
        """
        Store the model/tokenizer pair and the evaluation settings.

        Args:
            model: Module invoked by ``_model_call``.
            tokenizer: Tokenizer used by ``tok_encode`` / ``tok_decode``; it
                must expose ``eos_id`` and may additionally expose ``eot_id``.
            max_seq_length: Value reported by ``max_length``; falls back to
                2048 when ``None``.
            use_kv_cache: When True, ``_model_call`` truncates the input to
                ``max_seq_length`` tokens and passes a position tensor to the
                model as a second argument.
        """
        # Prefer CUDA when it is available, otherwise run on the CPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # NOTE(review): the parent is constructed with a fixed
        # pretrained="gpt2"; presumably this only satisfies the parent
        # constructor, since the attributes below are assigned afterwards
        # and every method in this class uses them -- confirm.
        super().__init__(device=device, pretrained="gpt2")
        self._model = model
        self._tokenizer = tokenizer
        self._device = torch.device(device)
        self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
        self._use_kv_cache = use_kv_cache

    @property
    def eot_token_id(self):
        """
        The stories model does not have an EOT token, so we use the EOS token instead.

        Returns the tokenizer's ``eot_id`` when that attribute exists and its
        ``eos_id`` otherwise.
        """
        if hasattr(self._tokenizer, "eot_id"):
            return self._tokenizer.eot_id
        return self._tokenizer.eos_id

    @property
    def prefix_token_id(self):
        """Return the same token id as ``eot_token_id``."""
        return self.eot_token_id

    @property
    def max_length(self):
        """Return the sequence length configured at construction time."""
        return self._max_seq_length

    @property
    def max_gen_toks(self):
        """Return the fixed generation budget of 50 tokens."""
        return 50

    @property
    def batch_size(self):
        """Return the fixed batch size of 1."""
        return 1

    @property
    def device(self):
        """Return the ``torch.device`` selected in ``__init__``."""
        return self._device

    def tok_encode(self, string: str, **kwargs):  # pyre-ignore
        """
        Encode ``string`` with the wrapped tokenizer, without BOS/EOS markers.

        Extra keyword arguments are accepted and ignored.
        """
        return self._tokenizer.encode(string, bos=False, eos=False)

    def tok_decode(self, tokens):
        """Decode ``tokens`` with the wrapped tokenizer."""
        return self._tokenizer.decode(tokens)

    def _model_call(self, inps):
        """
        Run the wrapped model on ``inps`` and return its output.

        Without a KV cache the input is forwarded unchanged. With a KV cache
        the input is truncated along its second dimension to
        ``max_seq_length`` entries and the model additionally receives a
        position tensor holding the single value 0.
        """
        if not self._use_kv_cache:
            return self._model(inps)

        pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
        # Batch process the whole sequence.
        return self._model(inps[:, : self._max_seq_length], pos_tensor)

    def _model_generate(self, context, max_length, eos_token_id):
        """
        Generation is not supported by this wrapper.

        Raises:
            NotImplementedError: Always. It is a subclass of ``Exception``,
                so callers that catch ``Exception`` are unaffected.
        """
        raise NotImplementedError("unimplemented")
87