xref: /aosp_15_r20/external/executorch/examples/qualcomm/qaihub_scripts/llama/runner/runner.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Qualcomm Innovation Center, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
// A simple llama2/3 runner that includes preprocessing and postprocessing
// logic. The module takes in a string as input and emits a string as output.
11 
12 #pragma once
13 
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <executorch/examples/qualcomm/qaihub_scripts/llama/runner/io_memory.h>
#include <executorch/extension/llm/sampler/sampler.h>
#include <executorch/extension/llm/tokenizer/tokenizer.h>
#include <executorch/extension/module/module.h>
24 
25 namespace example {
26 
27 class Runner {
28  public:
29   explicit Runner(
30       const std::vector<std::string>& models_path,
31       const std::vector<std::string>& pos_embs_path,
32       const std::vector<int>& shard_layers,
33       const std::string& tokenizer_path,
34       const int eval_mode,
35       const float temperature,
36       const float logits_scale,
37       const int logits_offset);
38 
39   struct Stats {
40     // Scaling factor for timestamps - in this case, we use ms.
41     const long SCALING_FACTOR_UNITS_PER_SECOND = 1000;
42     // Time stamps for the different stages of the execution
43     // model_load_start_ms: Start of model loading.
44     long model_load_start_ms;
45     // model_load_end_ms: End of model loading.
46     long model_load_end_ms;
47     // inference_start_ms: Immediately after the model is loaded (or we check
48     // for model load), measure the inference time.
49     long inference_start_ms;
50     // prompt_eval_end_ms: Prompt array allocation and tokenization. Ends right
51     // before the inference loop starts
52     long prompt_eval_end_ms;
53     // first_token: Timestamp when the first generated token is emitted
54     long first_token_ms;
55     // inference_end_ms: End of inference/generation.
56     long inference_end_ms;
57     // Keep a running total of the time spent in sampling.
58     long aggregate_sampling_time_ms;
59     // Token count from prompt
60     int64_t num_prompt_tokens;
61     // Token count from generated (total - prompt)
62     int64_t num_generated_tokens;
63   };
64 
65   bool is_loaded() const;
66   executorch::runtime::Error load();
67   executorch::runtime::Error generate(
68       const std::string& prompt,
69       const std::string& system_prompt,
70       int32_t seq_len,
71       std::function<void(const std::string&)> token_callback = {},
72       std::function<void(const Stats&)> stats_callback = {});
73   void stop();
74   std::vector<executorch::runtime::Result<executorch::runtime::MethodMeta>>
75   get_methods_meta();
76 
77  private:
78   enum EvalMode {
79     kBert = 0,
80     kKVCached,
81     kUnsupported,
82   };
83 
84   enum LlamaVersion {
85     kLlama2 = 0,
86     kLlama3,
87   };
88 
89   int32_t logitsToToken(const executorch::aten::Tensor& logits_tensor);
90   void run_model_step(
91       std::vector<std::vector<executorch::runtime::EValue>>& inputs);
92   // metadata
93   int32_t bos_id_;
94   std::unordered_set<uint64_t> eos_id_;
95   const int32_t n_bos_;
96   const int32_t n_eos_;
97   const int32_t vocab_size_;
98   const int32_t max_seq_len_;
99   int32_t eval_mode_;
100   std::vector<std::shared_ptr<executorch::extension::Module>> modules_;
101   std::vector<std::string> method_names_;
102   std::string tokenizer_path_;
103   float temperature_;
104   std::unique_ptr<executorch::extension::llm::Tokenizer> tokenizer_;
105   std::unique_ptr<executorch::extension::llm::Sampler> sampler_;
106   Stats stats_;
107   std::unique_ptr<Memory> io_mem_;
108   const float logits_scale_;
109   const int32_t logits_offset_;
110   LlamaVersion version_;
111 };
112 
113 } // namespace example
114