# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import json
from typing import Optional, Type

import torch

from executorch.examples.models.llama.export_llama_lib import (
    _prepare_for_llama_export,
    build_args_parser as _build_args_parser,
)
from executorch.examples.models.llama.runner.generation import LlamaRunner
from executorch.extension.llm.export.builder import LLMEdgeManager


class EagerLlamaRunner(LlamaRunner):
    """
    Runs llama in eager mode with provided checkpoint file.
    """

    def __init__(self, args):
        """Build the runner from parsed command-line arguments.

        Args:
            args: Parsed argument namespace. This constructor reads
                ``params`` (path to a JSON file containing ``vocab_size``),
                ``tokenizer_path``, ``max_seq_length`` and ``use_kv_cache``;
                the whole namespace is also handed to
                ``_prepare_for_llama_export`` to build the model.
        """
        with open(args.params, "r", encoding="utf-8") as f:
            params = json.load(f)
        super().__init__(
            tokenizer_path=args.tokenizer_path,
            max_seq_len=args.max_seq_length,
            max_batch_size=1,
            use_kv_cache=args.use_kv_cache,
            vocab_size=params["vocab_size"],
            device="cuda" if torch.cuda.is_available() else "cpu",
        )
        manager: LLMEdgeManager = _prepare_for_llama_export(args)
        # Switch to eval mode and move the model to the runner's device.
        self.model = manager.model.eval().to(device=self.device)

    def forward(
        self,
        tokens: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Delegate to the wrapped model's ``forward`` with the same arguments."""
        return self.model.forward(tokens=tokens, input_pos=input_pos)


def build_args_parser() -> argparse.ArgumentParser:
    """Return the llama export argument parser extended with runner options.

    Adds ``--prompt``, ``--temperature``, ``--show_tokens`` and ``--chat`` on
    top of the parser produced by the export library.
    """
    parser = _build_args_parser()

    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="Prompt passed to text completion (used when --chat is not set)",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        # Float literal so the default has the same type as parsed values.
        default=0.0,
        help="Temperature forwarded to the runner's completion call",
    )

    parser.add_argument(
        "--show_tokens",
        action="store_true",
        default=False,
        help="Show the tokens that were generated",
    )

    parser.add_argument(
        "--chat",
        action="store_true",
        default=False,
        help="Have multi-turn chat with the model",
    )

    return parser


def execute_runner(runner_class: Type[LlamaRunner]) -> None:
    """Parse the command line, build ``runner_class`` and run one generation.

    Runs ``chat_completion`` when ``--chat`` is given, otherwise
    ``text_completion`` on ``--prompt`` with ``echo=True``. Everything runs
    under ``torch.no_grad()``. With ``--show_tokens`` the generated tokens are
    printed afterwards.

    Args:
        runner_class: Runner type to instantiate; it is called with the parsed
            argument namespace as its only argument.

    Raises:
        SystemExit: Raised by ``parser.error`` when neither ``--chat`` nor
            ``--prompt`` is supplied.
    """
    parser = build_args_parser()
    args = parser.parse_args()

    # Validate before constructing the runner so a missing prompt is reported
    # as a usage error up front rather than after the model has been built.
    if not args.chat and args.prompt is None:
        parser.error("--prompt is required unless --chat is specified")

    with torch.no_grad():
        runner = runner_class(args)  # pyre-ignore: Missing argument [20]
        generated_tokens = (
            runner.chat_completion(temperature=args.temperature)
            if args.chat
            else runner.text_completion(
                prompt=args.prompt,
                temperature=args.temperature,
                echo=True,
            )
        )
        if args.show_tokens:
            print(f"Generated {len(generated_tokens)} tokens: {generated_tokens}")


def main() -> None:
    """Command-line entry point: run the eager llama runner."""
    execute_runner(EagerLlamaRunner)


if __name__ == "__main__":
    main()  # pragma: no cover