/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// A simple llama2 runner that includes preprocessing and post-processing
// logic. The module takes a string as input and emits a string as output.
// (A brief usage sketch is provided at the end of this file.)

#include <executorch/examples/models/llama/runner/runner.h>

#include <ctime>

#include <executorch/extension/llm/runner/util.h>

#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>

namespace example {

using ::executorch::extension::Module;
using ::executorch::runtime::Error;
using ::executorch::runtime::Result;

namespace llm = ::executorch::extension::llm;

namespace {
static constexpr auto kEnableDynamicShape = "enable_dynamic_shape";
static constexpr auto kBosId = "get_bos_id";
static constexpr auto kEosIds = "get_eos_ids";
static constexpr auto kMaxSeqLen = "get_max_seq_len";
static constexpr auto kVocabSize = "get_vocab_size";
static constexpr auto kUseKVCache = "use_kv_cache";
static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
} // namespace

Runner::Runner(
    const std::string& model_path,
    const std::string& tokenizer_path,
    const float temperature)
    // NOTE: we observed ~2x loading performance increase on iPhone 15
    // and a ~5% improvement on Galaxy S22 by switching to
    // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
    : temperature_(temperature),
      module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
      tokenizer_path_(tokenizer_path),
      metadata_({
          {kEnableDynamicShape, false},
          {kMaxSeqLen, 128},
          {kUseKVCache, true},
          {kUseSDPAWithKVCache, false},
      }) {
  ET_LOG(
      Info,
      "Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
      model_path.c_str(),
      tokenizer_path.c_str());
}
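
// Illustrative sketch tied to the NOTE above: how a caller could switch
// between file-based and mmap-based loading. Assumes Module::LoadMode also
// exposes MmapUseMlockIgnoreErrors (the mmap mode referenced in the NOTE);
// this helper is not used by the runner itself.
[[maybe_unused]] static std::unique_ptr<Module> make_module_for_benchmark(
    const std::string& model_path,
    bool prefer_mmap) {
  const auto load_mode = prefer_mmap
      ? Module::LoadMode::MmapUseMlockIgnoreErrors
      : Module::LoadMode::File;
  return std::make_unique<Module>(model_path, load_mode);
}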

bool Runner::is_loaded() const {
  return module_->is_loaded() && tokenizer_ && text_decoder_runner_ &&
      text_prefiller_ && text_token_generator_;
}

Error Runner::load() {
  if (is_loaded()) {
    return Error::Ok;
  }
  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
  // Load the tokenizer. Tiktoken is assumed to be the default tokenizer.
  tokenizer_ = nullptr;
  tokenizer_ = get_tiktoken_for_llama();
  Error err = tokenizer_->load(tokenizer_path_);
  // Rely on tiktoken to report an error if the artifact is incompatible; in
  // that case, fall back to the BPE tokenizer.
  if (err == Error::InvalidArgument) {
    ET_LOG(
        Info,
        "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
        tokenizer_path_.c_str());
    tokenizer_.reset();
    tokenizer_ = std::make_unique<llm::BPETokenizer>();
    tokenizer_->load(tokenizer_path_);
  }

  ET_LOG(Info, "Reading metadata from model");

  metadata_[kBosId] = tokenizer_->bos_tok();
  auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
      std::unordered_set<uint64_t>{tokenizer_->eos_tok()});
  metadata_[kVocabSize] = tokenizer_->vocab_size();

  const auto method_names =
      ET_UNWRAP(module_->method_names(), "Failed reading method names");

  for (auto& pair : metadata_) {
    const auto& method_name = pair.first;
    auto& value = pair.second;

    if (method_names.count(method_name)) {
      value = ET_UNWRAP(module_->get(method_name))
                  .toScalar()
                  .to<decltype(metadata_)::mapped_type>();
    } else {
      ET_LOG(
          Info,
          "Method %s not found, using the default value %" PRId64,
          method_name.c_str(),
          value);
    }
    ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value);
  }
  if (method_names.count(kEosIds)) {
    eos_ids->clear();
    for (const auto& eos_id : ET_UNWRAP(module_->execute(kEosIds))) {
      auto value = eos_id.toScalar().to<int64_t>();
      eos_ids->emplace(value);
      ET_LOG(Info, "eos_id = %" PRId64, value);
    }
  }
  text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
      module_.get(),
      metadata_.at(kUseKVCache),
      metadata_.at(kVocabSize),
      temperature_);
  text_prefiller_ = std::make_unique<llm::TextPrefiller>(
      text_decoder_runner_.get(),
      metadata_.at(kUseKVCache),
      metadata_.at(kEnableDynamicShape));

  text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
      tokenizer_.get(),
      text_decoder_runner_.get(),
      metadata_.at(kUseKVCache),
      std::move(eos_ids),
      &stats_);

  return Error::Ok;
}
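
// Illustrative sketch of the metadata lookup performed in the loop inside
// load(): query one exported getter method (here kMaxSeqLen) and convert the
// scalar result. This duplicates the logic above purely for exposition and is
// not called by the runner.
[[maybe_unused]] static Result<int64_t> read_max_seq_len(Module& module) {
  // ET_UNWRAP propagates the error if the method is missing or fails.
  auto value = ET_UNWRAP(module.get(kMaxSeqLen));
  return value.toScalar().to<int64_t>();
}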

// Don't print with the same priority during warmup
#define RUNNER_ET_LOG(warmup, format, ...) \
  if (warmup) {                            \
    ET_LOG(Debug, format, __VA_ARGS__);    \
  } else {                                 \
    ET_LOG(Info, format, __VA_ARGS__);     \
  }

Error Runner::generate(
    const std::string& prompt,
    int32_t seq_len,
    std::function<void(const std::string&)> token_callback,
    std::function<void(const llm::Stats&)> stats_callback,
    bool echo,
    bool warmup) {
  // Sanity-check the prompt, then lazily load the model if needed.
  ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be empty");
  if (!is_loaded()) {
    stats_.model_load_start_ms = llm::time_in_ms();
    ET_CHECK_OK_OR_RETURN_ERROR(load());
    stats_.model_load_end_ms = llm::time_in_ms();
  }

  if (warmup) {
    ET_LOG(Info, "Doing a warmup run...");
  }

  RUNNER_ET_LOG(
      warmup,
      "RSS after loading model: %f MiB (0 if unsupported)",
      llm::get_rss_bytes() / 1024.0 / 1024.0);

  // Wrap token_callback so that each decoded piece is also printed to stdout
  // (except during warmup).
  std::function<void(const std::string&)> wrapped_callback =
      [token_callback, warmup](const std::string& piece) {
        if (!warmup) {
          llm::safe_printf(piece.c_str());
          fflush(stdout);
        }
        if (token_callback) {
          token_callback(piece);
        }
      };
  // First token time only measures the time it takes to encode the prompt and
  // return a response token.

  stats_.inference_start_ms = llm::time_in_ms();
  shouldStop_ = false;

  // Set the sequence length to the max seq length if not provided.
  seq_len = (seq_len > 0 && seq_len <= metadata_.at(kMaxSeqLen))
      ? seq_len
      : metadata_.at(kMaxSeqLen);

  Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
      prompt,
      /* bos */ 0,
      /* eos */ 0);

  ET_CHECK_OK_OR_RETURN_ERROR(
      encode_res.error(), "Failed to encode prompt %s", prompt.c_str());

  // Encode the (string) prompt into a token sequence.
  std::vector<uint64_t> prompt_tokens = encode_res.get();
  int num_prompt_tokens = prompt_tokens.size();

  ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
  ET_CHECK_MSG(
      num_prompt_tokens < metadata_.at(kMaxSeqLen),
      "num_prompt_tokens %d >= max_seq_len_ %" PRId64
      ", Max seq length exceeded - please increase max seq len value in .../llama2/model.py",
      num_prompt_tokens,
      metadata_.at(kMaxSeqLen));
  ET_CHECK_MSG(
      num_prompt_tokens < seq_len,
      "num_prompt_tokens %d >= seq_len %d, Sequence length exceeded - please increase the seq_len value passed to generate()",
      num_prompt_tokens,
      seq_len);

  // Prefill first: feed all prompt tokens to the model and get the first
  // predicted token after the prompt. After that we enter the generate loop.

  // Print the prompt if echo is requested.
  if (echo) {
    wrapped_callback(prompt);
  }
  int64_t pos = 0;
  auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
  stats_.first_token_ms = llm::time_in_ms();
  stats_.prompt_eval_end_ms = llm::time_in_ms();
  ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
  uint64_t cur_token = prefill_res.get();

  // Print the first token from prefill. There is no prev_token yet, so pass
  // cur_token for it.
  wrapped_callback(ET_UNWRAP(tokenizer_->decode(cur_token, cur_token)));
  RUNNER_ET_LOG(
      warmup,
      "RSS after prompt prefill: %f MiB (0 if unsupported)",
      llm::get_rss_bytes() / 1024.0 / 1024.0);

  // Start the main generation loop.
  prompt_tokens.push_back(cur_token);
  int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
      prompt_tokens, num_prompt_tokens, seq_len, wrapped_callback));

  stats_.inference_end_ms = llm::time_in_ms();
  if (!warmup) {
    printf("\n");
  }
  RUNNER_ET_LOG(
      warmup,
      "RSS after finishing text generation: %f MiB (0 if unsupported)",
      llm::get_rss_bytes() / 1024.0 / 1024.0);

  if (num_prompt_tokens + num_generated_tokens == seq_len) {
    RUNNER_ET_LOG(warmup, "Sequence length (%i tokens) reached!", seq_len);
  }

  stats_.num_prompt_tokens = num_prompt_tokens;
  stats_.num_generated_tokens = num_generated_tokens;

  if (warmup) {
    ET_LOG(Info, "Warmup run finished!");
  } else {
    // Do not print the report during warmup.
    ::executorch::llm::print_report(stats_);
  }
  if (stats_callback) {
    stats_callback(stats_);
  }

  return Error::Ok;
}
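
// Illustrative sketch (not part of the runner API): accumulating the streamed
// pieces into a single std::string, matching the "string in, string out"
// description at the top of this file. The helper name and its parameters are
// for exposition only.
[[maybe_unused]] static Error generate_to_string(
    Runner& runner,
    const std::string& prompt,
    int32_t seq_len,
    std::string& out) {
  return runner.generate(
      prompt,
      seq_len,
      /*token_callback=*/[&out](const std::string& piece) { out += piece; },
      /*stats_callback=*/nullptr,
      /*echo=*/false,
      /*warmup=*/false);
}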

Error Runner::warmup(const std::string& prompt, int32_t seq_len) {
  Error err = generate(
      prompt,
      seq_len,
      /*token_callback=*/nullptr,
      /*stats_callback=*/nullptr,
      /*echo=*/false,
      /*warmup=*/true);
  stats_.reset();
  return err;
}

void Runner::stop() {
  if (is_loaded()) {
    text_token_generator_->stop();
  } else {
    ET_LOG(Error, "Token generator is not loaded, cannot stop");
  }
}
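
// A minimal usage sketch (illustrative only; never called by the runtime):
// the intended call sequence is construct, load(), then generate(). The paths
// and temperature used here are placeholders, not defaults.
[[maybe_unused]] static Error run_runner_example(
    const std::string& model_path,
    const std::string& tokenizer_path) {
  Runner runner(model_path, tokenizer_path, /*temperature=*/0.8f);
  ET_CHECK_OK_OR_RETURN_ERROR(runner.load());
  return runner.generate(
      "Tell me a short story:",
      /*seq_len=*/128,
      /*token_callback=*/nullptr,
      /*stats_callback=*/nullptr,
      /*echo=*/true,
      /*warmup=*/false);
}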
} // namespace example