# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import json
from typing import Optional

import torch

from executorch.examples.models.llama.export_llama_lib import (
    EXECUTORCH_DEFINED_MODELS,
    TORCHTUNE_DEFINED_MODELS,
)

from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Load custom ops and quantized ops.
from executorch.extension.pybindings import portable_lib  # noqa # usort: skip

from executorch.examples.models.llama.runner.generation import LlamaRunner

# Note: import this after portable_lib.
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
from executorch.kernels import quantized  # noqa


class NativeLlamaRunner(LlamaRunner):
    """
    Runs llama via ExecuTorch with the provided .pte file.
    """

    def __init__(self, args):
        with open(args.params, "r") as f:
            params = json.load(f)
        super().__init__(
            tokenizer_path=args.tokenizer,
            max_seq_len=args.max_len,
            max_batch_size=1,
            use_kv_cache=args.kv_cache,
            vocab_size=params["vocab_size"],
        )
        self.model = _load_for_executorch(args.pte)

    def forward(
        self,
        tokens: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # A model exported with a KV cache takes (tokens, input_pos);
        # one exported without takes only (tokens,). Either way, the
        # logits are the first output.
        return (
            self.model.forward((tokens, input_pos))
            if input_pos is not None
            else self.model.forward((tokens,))
        )[0]


def build_args_parser() -> argparse.ArgumentParser:
    # TODO: merge these with build_args_parser from export_llama_lib.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model",
        default="llama3",
        choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
        help="model name; should match the model used at export time",
    )

    parser.add_argument(
        "-f",
        "--pte",
        type=str,
        default=None,
        help="path to exported executorch .pte file",
    )

    parser.add_argument(
        "-p", "--params", type=str, default=None, help="model params file"
    )

    parser.add_argument(
        "-t",
        "--tokenizer",
        type=str,
        default=None,
        help="path to the tokenizer file",
    )

    parser.add_argument(
        "--prompt",
        type=str,
        default="Hello",
        help="prompt to complete",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        default=0.6,
        help="sampling temperature",
    )

    parser.add_argument(
        "-kv",
        "--kv_cache",
        action="store_true",
        help="use a KV cache; must match how the .pte was exported",
    )

    parser.add_argument(
        "--max_len",
        type=int,
        default=128,
        help="Maximum length of the generated response sequence.",
    )

    return parser


def main() -> None:
    parser = build_args_parser()
    args = parser.parse_args()
    runner = NativeLlamaRunner(args)
    generated_tokens = runner.text_completion(
        prompt=args.prompt,
        temperature=args.temperature,
    )
    print(f"Response: {generated_tokens}")


if __name__ == "__main__":
    main()  # pragma: no cover
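
# For reference, a minimal sketch of the prefill/decode loop that
# `text_completion` ultimately drives through `NativeLlamaRunner.forward`
# when the model was exported with a KV cache. Assumptions not in this
# file: greedy argmax instead of the real temperature sampling, no
# stop-token handling, and `prompt_token_ids` as a hypothetical list of
# prompt token ids; the actual loop lives in LlamaRunner.generate.
#
#   tokens = torch.tensor([prompt_token_ids], dtype=torch.long)
#   # Prefill: the whole prompt at position 0.
#   logits = runner.forward(tokens, input_pos=torch.tensor([0], dtype=torch.long))
#   next_token = torch.argmax(logits[:, -1, :], dim=-1)  # greedy pick
#   pos = tokens.shape[1]
#   while pos < args.max_len:
#       # Decode: one new token at the current position.
#       logits = runner.forward(
#           next_token.unsqueeze(0),  # shape (1, 1)
#           input_pos=torch.tensor([pos], dtype=torch.long),
#       )
#       next_token = torch.argmax(logits[:, -1, :], dim=-1)
#       pos += 1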
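
# Example invocation (a sketch: the script name and every path below are
# hypothetical placeholders for artifacts produced by your own export run,
# not files shipped with the repo):
#
#   python native.py \
#       --model llama3 \
#       --pte llama3.pte \
#       --params params.json \
#       --tokenizer tokenizer.model \
#       --prompt "Hello" \
#       -kv
#
# Pass -kv only if the .pte was exported with a KV cache, since it changes
# the input signature that forward() uses.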