/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// A simple llama2 runner that includes preprocessing and post-processing
// logic. The module takes in a string as input and emits a string as output.
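//
// A minimal usage sketch (the file paths and prompt below are illustrative
// placeholders, not files provided by this example):
//
//   MTKLlamaRunner runner("path/to/model", "path/to/tokenizer.model");
//   runner.load();
//   runner.generate(
//       "What is the capital of France?",
//       /*seq_len=*/128,
//       [](const std::string& piece) { std::cout << piece << std::flush; },
//       [](const Stats& stats) { /* timing and throughput stats */ });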

#pragma once

#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
#include <executorch/extension/llm/runner/irunner.h>
#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
#include <executorch/extension/llm/tokenizer/tiktoken.h>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "llama_runner/LlamaConfig.h"
#include "llama_runner/LlamaRuntime.h"

using Stats = ::executorch::llm::Stats;

using example::LlamaModelOptions;
using example::LlamaModelPaths;
using example::LlamaRuntime;
using executorch::extension::llm::Tokenizer;
using executorch::runtime::Error;
using executorch::runtime::Result;

class MTKLlamaRunner : public executorch::extension::llm::IRunner {
 public:
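  // Creates a runner for the model at `model_path`, using the tokenizer at
  // `tokenizer_path`; `temperature` sets the sampling temperature.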
  explicit MTKLlamaRunner(
      const std::string& model_path,
      const std::string& tokenizer_path,
      const float temperature = 0.8f);

  bool is_loaded() const;
  Error load();
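  // Generates up to `seq_len` tokens for `prompt`. Each decoded text piece is
  // passed to `token_callback`, and `stats_callback` receives run statistics
  // when generation completes. `echo` controls whether the prompt is echoed
  // back through the callback; `warming` marks a warm-up run.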
  Error generate(
      const std::string& prompt,
      int32_t seq_len = 128,
      std::function<void(const std::string&)> token_callback = {},
      std::function<void(const Stats&)> stats_callback = {},
      bool echo = true,
      bool warming = false);
  void stop();

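  // MediaTek-specific pipeline helpers (summary inferred from the interface):
  // digest_prompt() runs the prefill pass over the tokenized prompt and
  // returns the first generated token, gen_response() autoregressively
  // decodes from that token while streaming text through token_callback,
  // and inference() chains tokenization, prefill, and decoding for a full
  // prompt-to-response run.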
  LlamaModelOptions get_model_options();
  LlamaModelPaths get_model_paths();
  Result<uint64_t> digest_prompt(
      LlamaRuntime& llama_runtime,
      const std::unique_ptr<Tokenizer>& tokenizer,
      const std::vector<uint64_t> input_tokens);
  Error gen_response(
      LlamaRuntime& llama_runtime,
      const std::unique_ptr<Tokenizer>& tokenizer,
      const uint64_t input_token,
      std::function<void(const std::string&)> token_callback);
  Error inference(
      LlamaRuntime& llama_runtime,
      const std::unique_ptr<Tokenizer>& tokenizer,
      const std::string& prompt,
      std::function<void(const std::string&)> token_callback);
  std::unique_ptr<Tokenizer> load_tokenizer();

 private:
  // Model configuration, tokenizer, and runtime state.
  const LlamaModelOptions modeloptions_;
  const LlamaModelPaths modelpaths_;
  std::unique_ptr<Tokenizer> tokenizer_;
  std::unique_ptr<LlamaRuntime> runtime_;
};