# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

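"""Minimal runner: loads an exported Llama .pte file via the ExecuTorch
pybindings and prints a text completion for the given prompt."""
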
import argparse
import json
from typing import Optional

import torch

from executorch.examples.models.llama.export_llama_lib import (
    EXECUTORCH_DEFINED_MODELS,
    TORCHTUNE_DEFINED_MODELS,
)

from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Load custom ops and quantized ops.
from executorch.extension.pybindings import portable_lib  # noqa # usort: skip

from executorch.examples.models.llama.runner.generation import LlamaRunner

# Note: import this after portable_lib.
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
from executorch.kernels import quantized  # noqa


class NativeLlamaRunner(LlamaRunner):
    """
    Runs a Llama model exported to a .pte file via the ExecuTorch pybindings.
    """

    def __init__(self, args):
        with open(args.params, "r") as f:
            params = json.load(f)
        super().__init__(
            tokenizer_path=args.tokenizer,
            max_seq_len=args.max_len,
            max_batch_size=1,
            use_kv_cache=args.kv_cache,
            vocab_size=params["vocab_size"],
        )
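        # Load the exported program; the custom/quantized ops imported at
        # module scope need to be registered before a .pte that uses them
        # can be loaded.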
        self.model = _load_for_executorch(args.pte)

    def forward(
        self,
        tokens: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
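        """
        Run a single forward pass of the exported program.

        With a KV cache the program takes (tokens, input_pos); without one it
        takes only (tokens,). The pybindings return a list of outputs, so [0]
        selects the logits tensor.
        """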
        return (
            self.model.forward((tokens, input_pos))
            if input_pos is not None
            else self.model.forward((tokens,))
        )[0]


def build_args_parser() -> argparse.ArgumentParser:
    # TODO: merge these with build_args_parser from export_llama_lib.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model",
        default="llama3",
        choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
        help="Model architecture to run.",
    )

    parser.add_argument(
        "-f",
        "--pte",
        type=str,
        default=None,
        help="Path to the exported ExecuTorch .pte file.",
    )

    parser.add_argument(
        "-p",
        "--params",
        type=str,
        default=None,
        help="Path to the model params JSON file (e.g. vocab_size).",
    )

    parser.add_argument(
        "-t",
        "--tokenizer",
        type=str,
        default=None,
        help="Path to the tokenizer model file.",
    )

    parser.add_argument(
        "--prompt",
        type=str,
        default="Hello",
        help="Prompt to complete.",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        default=0.6,
        help="Sampling temperature; higher values give more varied output.",
    )

    parser.add_argument(
        "-kv",
        "--kv_cache",
        action="store_true",
        help="Set if the model was exported with a KV cache.",
    )

    parser.add_argument(
        "--max_len",
        type=int,
        default=128,
        help="Maximum sequence length; passed to the runner as max_seq_len.",
    )

    return parser


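# Example invocation (file paths are illustrative):
#
#   python native.py --model llama3 --pte llama3.pte --params params.json \
#       --tokenizer tokenizer.model --prompt "Hello" -kv --max_len 128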
def main() -> None:
    parser = build_args_parser()
    args = parser.parse_args()
    runner = NativeLlamaRunner(args)
    generated_tokens = runner.text_completion(
        prompt=args.prompt,
        temperature=args.temperature,
    )
    print(f"Response: {generated_tokens}")


if __name__ == "__main__":
    main()  # pragma: no cover