/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * Copyright (c) 2024 MediaTek Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

/* Copyright Statement:
 *
 * This software/firmware and related documentation ("MediaTek Software") are
 * protected under relevant copyright laws. The information contained herein
 * is confidential and proprietary to MediaTek Inc. and/or its licensors.
 * Without the prior written permission of MediaTek inc. and/or its licensors,
 * any reproduction, modification, use or disclosure of MediaTek Software,
 * and information contained herein, in whole or in part, shall be strictly
 * prohibited.
 */
/* MediaTek Inc. (C) 2024. All rights reserved.
 *
 * BY OPENING THIS FILE, RECEIVER HEREBY UNEQUIVOCALLY ACKNOWLEDGES AND AGREES
 * THAT THE SOFTWARE/FIRMWARE AND ITS DOCUMENTATIONS ("MEDIATEK SOFTWARE")
 * RECEIVED FROM MEDIATEK AND/OR ITS REPRESENTATIVES ARE PROVIDED TO RECEIVER ON
 * AN "AS-IS" BASIS ONLY. MEDIATEK EXPRESSLY DISCLAIMS ANY AND ALL WARRANTIES,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE OR NONINFRINGEMENT.
 * NEITHER DOES MEDIATEK PROVIDE ANY WARRANTY WHATSOEVER WITH RESPECT TO THE
 * SOFTWARE OF ANY THIRD PARTY WHICH MAY BE USED BY, INCORPORATED IN, OR
 * SUPPLIED WITH THE MEDIATEK SOFTWARE, AND RECEIVER AGREES TO LOOK ONLY TO SUCH
 * THIRD PARTY FOR ANY WARRANTY CLAIM RELATING THERETO. RECEIVER EXPRESSLY
 * ACKNOWLEDGES THAT IT IS RECEIVER'S SOLE RESPONSIBILITY TO OBTAIN FROM ANY
 * THIRD PARTY ALL PROPER LICENSES CONTAINED IN MEDIATEK SOFTWARE. MEDIATEK
 * SHALL ALSO NOT BE RESPONSIBLE FOR ANY MEDIATEK SOFTWARE RELEASES MADE TO
 * RECEIVER'S SPECIFICATION OR TO CONFORM TO A PARTICULAR STANDARD OR OPEN
 * FORUM. RECEIVER'S SOLE AND EXCLUSIVE REMEDY AND MEDIATEK'S ENTIRE AND
 * CUMULATIVE LIABILITY WITH RESPECT TO THE MEDIATEK SOFTWARE RELEASED HEREUNDER
 * WILL BE, AT MEDIATEK'S OPTION, TO REVISE OR REPLACE THE MEDIATEK SOFTWARE AT
 * ISSUE, OR REFUND ANY SOFTWARE LICENSE FEES OR SERVICE CHARGE PAID BY RECEIVER
 * TO MEDIATEK FOR SUCH MEDIATEK SOFTWARE AT ISSUE.
 *
 * The following software/firmware and/or related documentation ("MediaTek
 * Software") have been modified by MediaTek Inc. All revisions are subject to
 * any receiver's applicable license agreements with MediaTek Inc.
 */

#include <executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>
#include "executorch/backends/mediatek/runtime/include/NeuronBufferAllocator.h"

#include <cmath>
#include <cstdio>
#include <ctime>
#include <iostream>
#include <memory>
#include <random>

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/evalue_util/print_evalue.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/platform/log.h>
#include <executorch/runtime/platform/profiler.h>
#include <executorch/runtime/platform/runtime.h>
// #include <executorch/util/util.h>
#include <executorch/extension/llm/runner/util.h>
#include <executorch/runtime/core/result.h>

#include "llama_runner/ModelChunk.h"
#include "llama_runner/Utils.h"
#include "llama_runner/llm_helper/include/llama_runner_values.h"
#include "llama_runner/llm_helper/include/llm_types.h"

static uint64_t MAX_RESPONSE = 50; // Maximum number of tokens to generate.
// Global BOS and EOS option for tokenization (encoding)
static constexpr int8_t kAddBos = 1;
static constexpr int8_t kAddEos = 0;

using namespace example::llm_helper;
using example::utils::argmax;
using example::utils::split;
using example::utils::Timer;
using example::utils::to_string;
using namespace mtk::vars;

namespace llm = ::executorch::extension::llm;

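// Example usage (illustrative sketch only; the calling code below is
// hypothetical and not part of this file). As this translation unit stands,
// the constructor arguments are not used for loading: the model and tokenizer
// paths come from the compile-time values in llama_runner_values.h.
//
//   MTKLlamaRunner runner("llama.pte", "tokenizer.bin", /*temperature=*/0.0f);
//   runner.load();
//   runner.generate(
//       "What is the capital of France?",
//       /*seq_len=*/128,
//       [](const std::string& piece) { /* stream the decoded piece */ },
//       /*stats_callback=*/nullptr,
//       /*echo=*/true,
//       /*warming=*/false);
//   runner.stop();
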
MTKLlamaRunner::MTKLlamaRunner(
    const std::string& model_path,
    const std::string& tokenizer_path,
    const float temperature)
    : modeloptions_(get_model_options()), modelpaths_(get_model_paths()) {
  executorch::runtime::runtime_init();
  ET_LOG(
      Info,
      "Creating MTK Llama runner. Currently it will self-load .pte, .bin, and .so files. Initiated runtime_init().");
}

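// Loads the tokenizer and the prompt-model runtime. Idempotent: returns
// Error::Ok immediately if both are already loaded.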
Error MTKLlamaRunner::load() {
  if (is_loaded()) {
    return Error::Ok;
  }

  // Load tokenizer
  ET_LOG(Info, "Loading tokenizer.");
  tokenizer_ = load_tokenizer();
  ET_LOG(Info, "Complete loading tokenizer.");

  // Load prompt model
  runtime_ = std::make_unique<LlamaRuntime>();
  ET_LOG(Info, "Loading prompt model.");
  runtime_->Initialize(modeloptions_, modelpaths_);
  ET_LOG(Info, "Complete loading prompt model.");

  return Error::Ok;
}

bool MTKLlamaRunner::is_loaded() const {
  return tokenizer_ && runtime_;
}

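// Generates a response for `prompt`. Lazily loads the model if needed, wraps
// `token_callback` so each decoded piece is also printed to stdout, then runs
// the full prefill + decode pipeline via inference(). Note that `seq_len`,
// `stats_callback`, `echo`, and `warming` are currently unused here.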
Error MTKLlamaRunner::generate(
    const std::string& prompt,
    int32_t seq_len,
    std::function<void(const std::string&)> token_callback,
    std::function<void(const Stats&)> stats_callback,
    bool echo,
    bool warming) {
  if (!is_loaded()) {
    ET_CHECK_OK_OR_RETURN_ERROR(load());
  }

  // Wrap the token_callback with print function
  std::function<void(const std::string&)> wrapped_callback =
      [token_callback](const std::string& piece) {
        llm::safe_printf(piece.c_str());
        fflush(stdout);
        if (token_callback) {
          token_callback(piece);
        }
      };

  ET_LOG(Info, "Starting inference from MTKLlamaRunner");
  inference(*runtime_.get(), tokenizer_, prompt, wrapped_callback);
  ET_LOG(Info, "Completed inference from MTKLlamaRunner");

  return Error::Ok;
}

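// Releases the underlying Llama runtime resources if they have been loaded.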
void MTKLlamaRunner::stop() {
  if (is_loaded()) {
    runtime_->Release();
  } else {
    ET_LOG(Error, "Llama Runtime is not loaded, cannot stop");
  }
}

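// The options and paths below are assembled from the compile-time values
// pulled in from llama_runner_values.h (namespace mtk::vars).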
LlamaModelOptions MTKLlamaRunner::get_model_options() {
  LlamaModelOptions options = {
      // Sizes
      .prompt_token_batch_size = PROMPT_TOKEN_BATCH_SIZE,
      .cache_size = CACHE_SIZE,
      .hidden_size = HIDDEN_SIZE,
      .num_head = NUM_HEAD,
      .num_layer = NUM_LAYER,
      .max_token_length = MAX_TOKEN_LENGTH,
      .rot_emb_base = ROT_EMB_BASE,

      // Types
      .model_input_type = MODEL_INPUT_TYPE,
      .model_output_type = MODEL_OUTPUT_TYPE,
      .cache_type = CACHE_TYPE,
      .mask_type = MASK_TYPE,
      .rot_emb_type = ROT_EMB_TYPE};
  ET_LOG(Info, "Completed get_model_options");
  return options;
}

LlamaModelPaths MTKLlamaRunner::get_model_paths() {
  LlamaModelPaths model_paths = {
      .tokenizer_path = TOKENIZER_PATH,
      .token_embedding_path = TOKEN_EMBEDDING_PATH,
      .prompt_model_paths = split(PROMPT_MODEL_PATHS, ','),
      .gen_model_paths = split(GEN_MODEL_PATHS, ',')};
  ET_LOG(Info, "Completed get_model_paths");
  return model_paths;
}

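// Prefill stage: feeds the prompt tokens to the prompt model in batches of
// GetTokenBatchSize() (a partial leading batch first if the prompt length is
// not a multiple of the batch size), then returns the argmax of the final
// logits as the first generated token.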
Result<uint64_t> MTKLlamaRunner::digest_prompt(
    LlamaRuntime& llama_runtime,
    const std::unique_ptr<Tokenizer>& tokenizer,
    const std::vector<uint64_t> input_tokens) {
  const auto input_token_count = input_tokens.size();
  const auto prompt_token_batch_size = llama_runtime.GetTokenBatchSize();
  size_t cur_token_index = 0;

  Timer timer_digest_prompt([=](const auto elapsed_sec) {
    // Ideal prompt size is a multiple of prompt batch size
    const size_t ideal_prompt_size =
        std::ceil(float(input_token_count) / prompt_token_batch_size) *
        prompt_token_batch_size;
    ET_LOG(
        Info,
        "Done analyzing prompt in %f sec (%f tok/s)",
        elapsed_sec,
        (float)ideal_prompt_size / elapsed_sec);
  });

  auto getNextTokens = [&]() {
    const size_t num_tok_remain = input_token_count - cur_token_index;
    const size_t remainder = num_tok_remain % prompt_token_batch_size;
    const size_t num_new_tokens =
        remainder ? remainder : prompt_token_batch_size;
    const auto start = cur_token_index;
    const auto end = start + num_new_tokens;
    return std::vector(
        input_tokens.begin() + start, input_tokens.begin() + end);
  };

  void* logits;
  timer_digest_prompt.Start();
  while (cur_token_index < input_token_count) {
    const auto next_tokens = getNextTokens();
    ET_LOG(
        Debug,
        "Digest next tokens (size=%zu), 1st tok=%lu",
        next_tokens.size(),
        next_tokens[0]);
    logits = llama_runtime.Run(next_tokens);
    cur_token_index += next_tokens.size();
  }
  timer_digest_prompt.End();

  const auto vocab_size = tokenizer->vocab_size();
  const auto logits_type = llama_runtime.GetModelOptions().model_output_type;
  const auto first_output_token = argmax(logits_type, logits, vocab_size);
  return first_output_token;
}

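// Decode stage: swaps to the generation model, then greedily (argmax) emits
// one token at a time until EOS, MAX_RESPONSE, or the model's maximum token
// length is reached, streaming each decoded piece through token_callback.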
Error MTKLlamaRunner::gen_response(
    LlamaRuntime& llama_runtime,
    const std::unique_ptr<Tokenizer>& tokenizer,
    const uint64_t input_token,
    std::function<void(const std::string&)> token_callback) {
  Timer timer_model_swap(
      [](const auto elapsed_sec) { ET_LOG(Info, "Model swapped."); });

  // Swap to gen mode
  timer_model_swap.Start();
  llama_runtime.SwapModel(1);
  timer_model_swap.End();

  size_t gen_tok_count = 0;
  uint64_t prev_token = input_token;
  uint64_t output_token = input_token;

  auto decode_res = tokenizer->decode(prev_token, output_token);
  ET_CHECK_OR_RETURN_ERROR(
      decode_res.ok(),
      InvalidState,
      "Tokenizer failed to decode first generated token: %lu",
      output_token);
  std::string full_response = std::move(decode_res.get());
  std::vector<uint64_t> full_response_tokens = {input_token};

  const auto vocab_size = tokenizer->vocab_size();
  const auto logits_type = llama_runtime.GetModelOptions().model_output_type;

  double gen_total_time_sec = 0;
  Timer timer_gen_token(
      [&](const auto elapsed_sec) { gen_total_time_sec += elapsed_sec; });

  // Print first output token
  token_callback(full_response);

  while (gen_tok_count++ < MAX_RESPONSE &&
         llama_runtime.GetTokenIndex() < modeloptions_.max_token_length) {
    timer_gen_token.Start();
    void* logits = llama_runtime.Run({output_token});
    timer_gen_token.End();

    prev_token = output_token;
    output_token = argmax(logits_type, logits, vocab_size);
    full_response_tokens.push_back(output_token);

    // Stop when output is EOS
    if (output_token == tokenizer->eos_tok()) {
      token_callback("</eos>");
      break;
    }
    auto decode_res = tokenizer->decode(prev_token, output_token);
    ET_CHECK_OR_RETURN_ERROR(
        decode_res.ok(),
        InvalidState,
        "Tokenizer failed to decode generated token %lu",
        output_token);
    const std::string tok_str = std::move(decode_res.get());
    full_response += tok_str;
    token_callback(tok_str);
  }

  std::cout << "\n\n[Generated Tokens]\n"
            << to_string(full_response_tokens) << std::endl;

  ET_LOG(
      Info,
      "Token generation speed: %f tok/s",
      gen_tok_count / gen_total_time_sec);

  return Error::Ok;
}

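// End-to-end flow: tokenize the prompt, run the prefill stage
// (digest_prompt), then stream the response from the decode stage
// (gen_response).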
Error MTKLlamaRunner::inference(
    LlamaRuntime& llama_runtime,
    const std::unique_ptr<Tokenizer>& tokenizer,
    const std::string& prompt,
    std::function<void(const std::string&)> token_callback) {
  // Tokenize input prompt
  auto encode_res = tokenizer->encode(prompt, kAddBos, kAddEos);
  ET_CHECK_OR_RETURN_ERROR(
      encode_res.ok(), InvalidState, "Tokenizer failed to encode prompt");
  const auto input_tokens = std::move(encode_res.get());

  // Run prompt mode (pre-fill)
  auto prefill_res = digest_prompt(llama_runtime, tokenizer, input_tokens);
  ET_CHECK_OR_RETURN_ERROR(
      prefill_res.ok(), InvalidState, "Failed to digest prompt");
  const auto first_output_token = prefill_res.get();

  // Run generation mode (decoding)
  return gen_response(
      llama_runtime, tokenizer, first_output_token, token_callback);
}

std::unique_ptr<Tokenizer> MTKLlamaRunner::load_tokenizer() {
  std::unique_ptr<Tokenizer> tokenizer;
  // Assumes that tokenizer type is Tiktoken
  tokenizer = example::get_tiktoken_for_llama();
  tokenizer->load(modelpaths_.tokenizer_path);
  return tokenizer;
}