/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// A simple llama2 runner that includes preprocessing and post-processing
// logic. The module takes in a string as input and emits a string as output.

#pragma once

#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
#include <executorch/extension/llm/runner/irunner.h>
#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
#include <executorch/extension/llm/tokenizer/tiktoken.h>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "llama_runner/LlamaConfig.h"
#include "llama_runner/LlamaRuntime.h"

using Stats = ::executorch::llm::Stats;

using example::LlamaModelOptions;
using example::LlamaModelPaths;
using example::LlamaRuntime;
using executorch::extension::llm::Tokenizer;
using executorch::runtime::Error;
using executorch::runtime::Result;

class MTKLlamaRunner : public executorch::extension::llm::IRunner {
 public:
  explicit MTKLlamaRunner(
      const std::string& model_path,
      const std::string& tokenizer_path,
      const float temperature = 0.8f);

  bool is_loaded() const;
  Error load();
  Error generate(
      const std::string& prompt,
      int32_t seq_len = 128,
      std::function<void(const std::string&)> token_callback = {},
      std::function<void(const Stats&)> stats_callback = {},
      bool echo = true,
      bool warming = false);
  void stop();

  LlamaModelOptions get_model_options();
  LlamaModelPaths get_model_paths();
  // Runs prefill over the tokenized prompt and returns the first generated
  // token.
  Result<uint64_t> digest_prompt(
      LlamaRuntime& llama_runtime,
      const std::unique_ptr<Tokenizer>& tokenizer,
      const std::vector<uint64_t> input_tokens);
  // Autoregressively decodes starting from `input_token`, invoking
  // `token_callback` for each decoded piece.
  Error gen_response(
      LlamaRuntime& llama_runtime,
      const std::unique_ptr<Tokenizer>& tokenizer,
      const uint64_t input_token,
      std::function<void(const std::string&)> token_callback);
  // End-to-end helper: tokenizes `prompt`, digests it, then generates the
  // response.
  Error inference(
      LlamaRuntime& llama_runtime,
      const std::unique_ptr<Tokenizer>& tokenizer,
      const std::string& prompt,
      std::function<void(const std::string&)> token_callback);
  std::unique_ptr<Tokenizer> load_tokenizer();

 private:
  // Model configuration and file paths, fixed at construction time.
  const LlamaModelOptions modeloptions_;
  const LlamaModelPaths modelpaths_;
  std::unique_ptr<Tokenizer> tokenizer_;
  std::unique_ptr<LlamaRuntime> runtime_;
};
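
// Example usage (a minimal sketch, not part of the API): the file paths below
// are placeholders, and this assumes the model and tokenizer files were
// exported as in the MediaTek llama examples.
//
//   MTKLlamaRunner runner(
//       "/data/local/tmp/llama.pte", "/data/local/tmp/tokenizer.model");
//   if (runner.load() == Error::Ok) {
//     runner.generate(
//         "Tell me a story:",
//         /*seq_len=*/128,
//         [](const std::string& piece) { printf("%s", piece.c_str()); },
//         [](const Stats& stats) { /* inspect prefill/decode timings */ });
//   }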