/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// A simple llama2 runner that includes preprocessing and post processing logic.
// The module takes in a string as input and emits a string as output.

#include <executorch/examples/models/llama/runner/runner.h>

#include <ctime>

#include <executorch/extension/llm/runner/util.h>

#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>

namespace example {

using ::executorch::extension::Module;
using ::executorch::runtime::Error;
using ::executorch::runtime::Result;

namespace llm = ::executorch::extension::llm;

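// Keys used to look up model metadata. Each constant names an optional method
// that the exported model may provide; load() queries the model for these and
// keeps the existing default (set in the constructor or derived from the
// tokenizer) when a method is absent.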
namespace {
static constexpr auto kEnableDynamicShape = "enable_dynamic_shape";
static constexpr auto kBosId = "get_bos_id";
static constexpr auto kEosIds = "get_eos_ids";
static constexpr auto kMaxSeqLen = "get_max_seq_len";
static constexpr auto kVocabSize = "get_vocab_size";
static constexpr auto kUseKVCache = "use_kv_cache";
static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
} // namespace

Runner::Runner(
    const std::string& model_path,
    const std::string& tokenizer_path,
    const float temperature)
    // NOTE: we observed ~2x loading performance increase on iPhone 15
    // and a ~5% improvement on Galaxy S22 by switching to
    // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
    : temperature_(temperature),
      module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
      tokenizer_path_(tokenizer_path),
      metadata_({
          {kEnableDynamicShape, false},
          {kMaxSeqLen, 128},
          {kUseKVCache, true},
          {kUseSDPAWithKVCache, false},
      }) {
  ET_LOG(
      Info,
      "Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
      model_path.c_str(),
      tokenizer_path.c_str());
}
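
// Minimal usage sketch (illustrative only; the paths, prompt, seq_len, and
// temperature below are placeholders, not values taken from this file):
//
//   example::Runner runner(
//       "/path/to/model.pte", "/path/to/tokenizer.model", /*temperature=*/0.8f);
//   // load() is optional here; generate() loads the model and tokenizer lazily.
//   runner.generate(
//       "Hello, world",
//       /*seq_len=*/128,
//       [](const std::string& piece) { /* stream decoded pieces */ },
//       [](const llm::Stats& stats) { /* inspect timing stats */ },
//       /*echo=*/true,
//       /*warmup=*/false);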

bool Runner::is_loaded() const {
  return module_->is_loaded() && tokenizer_ && text_decoder_runner_ &&
      text_prefiller_ && text_token_generator_;
}

Error Runner::load() {
  if (is_loaded()) {
    return Error::Ok;
  }
  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
  // Load the tokenizer. Assume tiktoken is the default tokenizer.
  tokenizer_ = nullptr;
  tokenizer_ = get_tiktoken_for_llama();
  Error err = tokenizer_->load(tokenizer_path_);
  // Rely on tiktoken to report an error if the artifact is incompatible; in
  // that case fall back to the BPE tokenizer.
  if (err == Error::InvalidArgument) {
    ET_LOG(
        Info,
        "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
        tokenizer_path_.c_str());
    tokenizer_.reset();
    tokenizer_ = std::make_unique<llm::BPETokenizer>();
    tokenizer_->load(tokenizer_path_);
  }

  ET_LOG(Info, "Reading metadata from model");

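  // Seed metadata_ with values taken from the tokenizer; if the model also
  // exports the corresponding getter methods (see the loop below), the
  // exported values take precedence.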
  metadata_[kBosId] = tokenizer_->bos_tok();
  auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
      std::unordered_set<uint64_t>{tokenizer_->eos_tok()});
  metadata_[kVocabSize] = tokenizer_->vocab_size();

  const auto method_names =
      ET_UNWRAP(module_->method_names(), "Failed reading method names");

  for (auto& pair : metadata_) {
    const auto& method_name = pair.first;
    auto& value = pair.second;

    if (method_names.count(method_name)) {
      value = ET_UNWRAP(module_->get(method_name))
                  .toScalar()
                  .to<decltype(metadata_)::mapped_type>();
    } else {
      ET_LOG(
          Info,
          "Method %s not found, using the default value %" PRId64,
          method_name.c_str(),
          value);
    }
    ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value);
  }
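  // get_eos_ids is handled outside the loop above because it returns a list
  // of token ids rather than a single scalar; when the model exports it, the
  // exported ids replace the tokenizer's single EOS token.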
  if (method_names.count(kEosIds)) {
    eos_ids->clear();
    for (const auto& eos_id : ET_UNWRAP(module_->execute(kEosIds))) {
      auto value = eos_id.toScalar().to<int64_t>();
      eos_ids->emplace(value);
      ET_LOG(Info, "eos_id = %" PRId64, value);
    }
  }
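  // Wire up the generation pipeline: the decoder runner executes the model's
  // forward method, the prefiller feeds the prompt tokens through it, and the
  // token generator emits tokens one at a time until an EOS id or the
  // sequence-length limit is reached.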
  text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
      module_.get(),
      metadata_.at(kUseKVCache),
      metadata_.at(kVocabSize),
      temperature_);
  text_prefiller_ = std::make_unique<llm::TextPrefiller>(
      text_decoder_runner_.get(),
      metadata_.at(kUseKVCache),
      metadata_.at(kEnableDynamicShape));

  text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
      tokenizer_.get(),
      text_decoder_runner_.get(),
      metadata_.at(kUseKVCache),
      std::move(eos_ids),
      &stats_);

  return Error::Ok;
}

// Don't print with the same priority during warmup
#define RUNNER_ET_LOG(warmup, format, ...) \
  if (warmup) {                            \
    ET_LOG(Debug, format, __VA_ARGS__);    \
  } else {                                 \
    ET_LOG(Info, format, __VA_ARGS__);     \
  }

Error Runner::generate(
    const std::string& prompt,
    int32_t seq_len,
    std::function<void(const std::string&)> token_callback,
    std::function<void(const llm::Stats&)> stats_callback,
    bool echo,
    bool warmup) {
  // Validate the prompt and make sure the model and tokenizer are loaded.
  ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be empty");
  if (!is_loaded()) {
    stats_.model_load_start_ms = llm::time_in_ms();
    ET_CHECK_OK_OR_RETURN_ERROR(load());
    stats_.model_load_end_ms = llm::time_in_ms();
  }

  if (warmup) {
    ET_LOG(Info, "Doing a warmup run...");
  }

  RUNNER_ET_LOG(
      warmup,
      "RSS after loading model: %f MiB (0 if unsupported)",
      llm::get_rss_bytes() / 1024.0 / 1024.0);

  // Wrap token_callback so that each generated piece is also printed to
  // stdout (printing is suppressed during warmup runs).
  std::function<void(const std::string&)> wrapped_callback =
      [token_callback, warmup](const std::string& piece) {
        if (!warmup) {
          llm::safe_printf(piece.c_str());
          fflush(stdout);
        }
        if (token_callback) {
          token_callback(piece);
        }
      };
  // First token time only measures the time it takes to encode the prompt and
  // return a response token.

  stats_.inference_start_ms = llm::time_in_ms();
  shouldStop_ = false;

  // Set the sequence length to the max seq length if not provided
  seq_len = (seq_len > 0 && seq_len <= metadata_.at(kMaxSeqLen))
      ? seq_len
      : metadata_.at(kMaxSeqLen);

  // Encode the (string) prompt into a token sequence.
  Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
      prompt,
      /* bos */ 0,
      /* eos */ 0);

  ET_CHECK_OK_OR_RETURN_ERROR(
      encode_res.error(), "Failed to encode prompt %s", prompt.c_str());

  std::vector<uint64_t> prompt_tokens = encode_res.get();
  int num_prompt_tokens = prompt_tokens.size();

  ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
  ET_CHECK_MSG(
      num_prompt_tokens < metadata_.at(kMaxSeqLen),
      "num_prompt_tokens %d >= max_seq_len_ %" PRId64
      ", Max seq length exceeded - please increase max seq len value in .../llama2/model.py",
      num_prompt_tokens,
      metadata_.at(kMaxSeqLen));
  ET_CHECK_MSG(
      num_prompt_tokens < seq_len,
      "num_prompt_tokens %d >= seq_len %d, Sequence length exceeded - please increase the seq_len value passed to generate()",
      num_prompt_tokens,
      seq_len);
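  // Both checks above are strict inequalities so that at least one token can
  // still be generated after the prompt within max_seq_len / seq_len.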

  // Prefill first: feed all prompt tokens to the model and get the next
  // predicted token after the prompt. After that we enter the generation loop.

  // Print the prompt if echo is requested.
  if (echo) {
    wrapped_callback(prompt);
  }
  int64_t pos = 0;
  auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
  stats_.first_token_ms = llm::time_in_ms();
  stats_.prompt_eval_end_ms = llm::time_in_ms();
  ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
  uint64_t cur_token = prefill_res.get();

  // Print the first token from prefill. There is no prev_token yet, so pass
  // cur_token for it.
  wrapped_callback(ET_UNWRAP(tokenizer_->decode(cur_token, cur_token)));
  RUNNER_ET_LOG(
      warmup,
      "RSS after prompt prefill: %f MiB (0 if unsupported)",
      llm::get_rss_bytes() / 1024.0 / 1024.0);

  // Start the main generation loop.
  prompt_tokens.push_back(cur_token);
  int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
      prompt_tokens, num_prompt_tokens, seq_len, wrapped_callback));

  stats_.inference_end_ms = llm::time_in_ms();
  if (!warmup) {
    printf("\n");
  }
  RUNNER_ET_LOG(
      warmup,
      "RSS after finishing text generation: %f MiB (0 if unsupported)",
      llm::get_rss_bytes() / 1024.0 / 1024.0);

  if (num_prompt_tokens + num_generated_tokens == seq_len) {
    RUNNER_ET_LOG(warmup, "Sequence length (%i tokens) reached!", seq_len);
  }

  stats_.num_prompt_tokens = num_prompt_tokens;
  stats_.num_generated_tokens = num_generated_tokens;

  if (warmup) {
    ET_LOG(Info, "Warmup run finished!");
  } else {
    // Do not print the report during warmup.
    ::executorch::llm::print_report(stats_);
  }
  if (stats_callback) {
    stats_callback(stats_);
  }

  return Error::Ok;
}

Error Runner::warmup(const std::string& prompt, int32_t seq_len) {
  Error err = generate(
      prompt,
      seq_len,
      /*token_callback=*/nullptr,
      /*stats_callback=*/nullptr,
      /*echo=*/false,
      /*warmup=*/true);
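  // Discard warmup timings so they do not skew the stats reported by the next
  // generate() call.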
  stats_.reset();
  return err;
}

void Runner::stop() {
  if (is_loaded()) {
    text_token_generator_->stop();
  } else {
    ET_LOG(Error, "Token generator is not loaded, cannot stop");
  }
}
} // namespace example