xref: /aosp_15_r20/external/executorch/examples/models/llama/runner/eager.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import json
9from typing import Optional, Type
10
11import torch
12
13from executorch.examples.models.llama.export_llama_lib import (
14    _prepare_for_llama_export,
15    build_args_parser as _build_args_parser,
16)
17from executorch.examples.models.llama.runner.generation import LlamaRunner
18from executorch.extension.llm.export.builder import LLMEdgeManager
19
20
class EagerLlamaRunner(LlamaRunner):
    """
    Runs llama in eager mode with the provided checkpoint file.

    The model is built with ``_prepare_for_llama_export`` and then executed
    directly through PyTorch (``self.model.forward``) rather than being lowered.
    """

    def __init__(self, args):
        """
        Build the runner from parsed command-line arguments.

        Args:
            args: Namespace produced by ``build_args_parser()``. This method
                reads ``params`` (path to a JSON file that must contain a
                ``"vocab_size"`` entry), ``tokenizer_path``, ``max_seq_length``
                and ``use_kv_cache``; the whole namespace is also forwarded to
                ``_prepare_for_llama_export``.

        Raises:
            OSError: if the params file cannot be opened.
            json.JSONDecodeError: if the params file is not valid JSON.
            KeyError: if the params JSON has no ``"vocab_size"`` entry.
        """
        # JSON text is UTF-8 (RFC 8259); do not depend on the platform's
        # locale encoding when reading the params file.
        with open(args.params, "r", encoding="utf-8") as f:
            params = json.load(f)
        super().__init__(
            tokenizer_path=args.tokenizer_path,
            max_seq_len=args.max_seq_length,
            max_batch_size=1,
            use_kv_cache=args.use_kv_cache,
            vocab_size=params["vocab_size"],
            # Prefer the GPU when one is visible, otherwise fall back to CPU.
            device="cuda" if torch.cuda.is_available() else "cpu",
        )
        manager: LLMEdgeManager = _prepare_for_llama_export(args)
        # Put the model in inference mode and move it to the selected device.
        # NOTE(review): relies on the LlamaRunner base storing the ``device``
        # constructor argument as ``self.device`` — confirm in generation.py.
        self.model = manager.model.eval().to(device=self.device)

    def forward(
        self,
        tokens: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run one forward pass of the eager model.

        Args:
            tokens: Token ids to feed the model. Shape is decided by the caller;
                presumably (batch, seq_len) — TODO confirm.
            input_pos: Optional position tensor passed through unchanged to the
                model; presumably used for KV-cache indexing when
                ``use_kv_cache`` is enabled — verify against the model.

        Returns:
            Whatever ``self.model.forward`` returns (presumably logits; confirm).
        """
        return self.model.forward(tokens=tokens, input_pos=input_pos)
46
47
def build_args_parser() -> argparse.ArgumentParser:
    """
    Return the export-llama argument parser extended with runtime options.

    On top of the parser from ``export_llama_lib.build_args_parser`` this adds
    the generation-time flags ``--prompt``, ``--temperature``,
    ``--show_tokens`` and ``--chat``.
    """
    parser = _build_args_parser()

    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="Prompt for single-shot text completion (not used with --chat)",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        # argparse does not run ``type`` on non-string defaults, so an int
        # default (0) would leave args.temperature as an int when the flag is
        # omitted. Use a float so the attribute type is always consistent.
        default=0.0,
        help="Sampling temperature passed to the runner (default: 0.0)",
    )

    parser.add_argument(
        "--show_tokens",
        action="store_true",
        default=False,
        help="Show the tokens that were generated",
    )

    parser.add_argument(
        "--chat",
        action="store_true",
        default=False,
        help="Have multi-turn chat with the model",
    )

    return parser
78
79
def execute_runner(runner_class: Type[LlamaRunner]) -> None:
    """
    Parse command-line arguments, build ``runner_class`` and generate text.

    In ``--chat`` mode this calls ``runner.chat_completion``; otherwise it calls
    ``runner.text_completion`` on ``--prompt`` with ``echo=True``. With
    ``--show_tokens`` the returned token list is printed afterwards.

    Args:
        runner_class: ``LlamaRunner`` subclass whose constructor takes the
            parsed argument namespace as its only argument.

    Raises:
        SystemExit: via ``parser.error`` (exit status 2) when neither
            ``--chat`` nor ``--prompt`` is given.
    """
    parser = build_args_parser()
    args = parser.parse_args()

    # Fail fast with a usage message: without a prompt, text completion has
    # nothing to complete, and building the runner (loading the model) is
    # expensive, so validate before doing it.
    if not args.chat and args.prompt is None:
        parser.error("--prompt is required unless --chat is specified")

    # Inference only: disable autograd bookkeeping for the whole generation.
    with torch.no_grad():
        runner = runner_class(args)  # pyre-ignore: Missing argument [20]
        if args.chat:
            generated_tokens = runner.chat_completion(temperature=args.temperature)
        else:
            generated_tokens = runner.text_completion(
                prompt=args.prompt,
                temperature=args.temperature,
                echo=True,
            )
        if args.show_tokens:
            print(f"Generated {len(generated_tokens)} tokens: {generated_tokens}")
97
98
def main() -> None:
    """Command-line entry point: run generation with the eager-mode runner."""
    runner_cls: Type[LlamaRunner] = EagerLlamaRunner
    execute_runner(runner_cls)
101
102
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    # The pragma keeps this script-only line out of coverage reports.
    main()  # pragma: no cover
105