/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * Copyright (c) 2024 MediaTek Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

/* Copyright Statement:
 *
 * This software/firmware and related documentation ("MediaTek Software") are
 * protected under relevant copyright laws. The information contained herein
 * is confidential and proprietary to MediaTek Inc. and/or its licensors.
 * Without the prior written permission of MediaTek inc. and/or its licensors,
 * any reproduction, modification, use or disclosure of MediaTek Software,
 * and information contained herein, in whole or in part, shall be strictly
 * prohibited.
 */
/* MediaTek Inc. (C) 2024. All rights reserved.
 *
 * BY OPENING THIS FILE, RECEIVER HEREBY UNEQUIVOCALLY ACKNOWLEDGES AND AGREES
 * THAT THE SOFTWARE/FIRMWARE AND ITS DOCUMENTATIONS ("MEDIATEK SOFTWARE")
 * RECEIVED FROM MEDIATEK AND/OR ITS REPRESENTATIVES ARE PROVIDED TO RECEIVER ON
 * AN "AS-IS" BASIS ONLY. MEDIATEK EXPRESSLY DISCLAIMS ANY AND ALL WARRANTIES,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE OR NONINFRINGEMENT.
 * NEITHER DOES MEDIATEK PROVIDE ANY WARRANTY WHATSOEVER WITH RESPECT TO THE
 * SOFTWARE OF ANY THIRD PARTY WHICH MAY BE USED BY, INCORPORATED IN, OR
 * SUPPLIED WITH THE MEDIATEK SOFTWARE, AND RECEIVER AGREES TO LOOK ONLY TO SUCH
 * THIRD PARTY FOR ANY WARRANTY CLAIM RELATING THERETO. RECEIVER EXPRESSLY
 * ACKNOWLEDGES THAT IT IS RECEIVER'S SOLE RESPONSIBILITY TO OBTAIN FROM ANY
 * THIRD PARTY ALL PROPER LICENSES CONTAINED IN MEDIATEK SOFTWARE. MEDIATEK
 * SHALL ALSO NOT BE RESPONSIBLE FOR ANY MEDIATEK SOFTWARE RELEASES MADE TO
 * RECEIVER'S SPECIFICATION OR TO CONFORM TO A PARTICULAR STANDARD OR OPEN
 * FORUM. RECEIVER'S SOLE AND EXCLUSIVE REMEDY AND MEDIATEK'S ENTIRE AND
 * CUMULATIVE LIABILITY WITH RESPECT TO THE MEDIATEK SOFTWARE RELEASED HEREUNDER
 * WILL BE, AT MEDIATEK'S OPTION, TO REVISE OR REPLACE THE MEDIATEK SOFTWARE AT
 * ISSUE, OR REFUND ANY SOFTWARE LICENSE FEES OR SERVICE CHARGE PAID BY RECEIVER
 * TO MEDIATEK FOR SUCH MEDIATEK SOFTWARE AT ISSUE.
 *
 * The following software/firmware and/or related documentation ("MediaTek
 * Software") have been modified by MediaTek Inc. All revisions are subject to
 * any receiver's applicable license agreements with MediaTek Inc.
 */
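// MediaTek Llama ExecuTorch runner example. It loads prompt (prefill) and
// generation (decode) .pte models plus a token-embedding table, tokenizes the
// prompt read from --prompt_file, runs prefill in token batches, then
// autoregressively decodes up to --max_response tokens, streaming the output.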

#include "executorch/backends/mediatek/runtime/include/NeuronBufferAllocator.h"

#include <cmath>
#include <ctime>
#include <iostream>
#include <memory>
#include <random>

#include <gflags/gflags.h>

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/evalue_util/print_evalue.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/platform/log.h>
#include <executorch/runtime/platform/profiler.h>
#include <executorch/runtime/platform/runtime.h>

#include "llama_runner/LlamaConfig.h"
#include "llama_runner/LlamaRuntime.h"
#include "llama_runner/ModelChunk.h"
#include "llama_runner/Utils.h"
#include "llama_runner/llm_helper/include/llm_types.h"

#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
#include <executorch/extension/llm/tokenizer/tiktoken.h>

// Llama model options
DEFINE_uint64(
    prompt_token_batch_size,
    128,
    "Token batch size for prompt model.");
DEFINE_uint64(cache_size, 1024, "Model cache size.");
DEFINE_uint64(hidden_size, 4096, "Model hidden size.");
DEFINE_uint64(num_head, 32, "Number of attention heads in each layer.");
DEFINE_uint64(num_layer, 32, "Number of layers in the model.");
DEFINE_uint64(
    max_token_length,
    2048,
    "Maximum token length that the model supports.");
DEFINE_double(
    rot_emb_base,
    10000,
    "Rotary embedding base value, aka 'rope_theta'.");

// Model IO Types
DEFINE_string(input_type, "int16", "Model input type. Defaults to 'int16'.");
DEFINE_string(output_type, "int16", "Model output type. Defaults to 'int16'.");
DEFINE_string(cache_type, "int16", "Model cache type. Defaults to 'int16'.");
DEFINE_string(mask_type, "int16", "Model mask type. Defaults to 'int16'.");
DEFINE_string(
    rot_emb_type,
    "int16",
    "Model rotary embedding type. Defaults to 'int16'.");

// Model Paths
DEFINE_string(
    token_embedding_path,
    "embedding.bin",
    "Input token embedding lookup table path.");
DEFINE_string(
    prompt_model_paths,
    "model_128t.pte",
    "Comma-separated prompt model paths.");
DEFINE_string(
    gen_model_paths,
    "model_1t.pte",
    "Comma-separated generative model paths.");

// Tokenizer
DEFINE_string(tokenizer_path, "tokenizer.model", "Tokenizer vocab file path.");
DEFINE_string(
    tokenizer_type,
    "tiktoken",
    "Tokenizer type. One of ['bpe', 'tiktoken'].");
DEFINE_uint64(vocab_size, 128000, "Tokenizer vocab size.");
DEFINE_uint64(bos_token, 128000, "BOS token id.");
DEFINE_uint64(eos_token, 128001, "EOS token id.");

// Inference
DEFINE_uint64(max_response, 50, "Maximum number of tokens to generate.");
DEFINE_string(prompt_file, "", "File containing the prompt text.");
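
// Example invocation (binary name and file paths below are illustrative):
//
//   mtk_llama_executor_runner \
//       --prompt_model_paths=model_128t.pte \
//       --gen_model_paths=model_1t.pte \
//       --token_embedding_path=embedding.bin \
//       --tokenizer_path=tokenizer.model --tokenizer_type=tiktoken \
//       --prompt_file=prompt.txt --max_response=50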

// Global BOS and EOS option for tokenization (encoding)
static constexpr int8_t kAddBos = 1;
static constexpr int8_t kAddEos = 0;

using namespace example::llm_helper;
using example::LlamaModelOptions;
using example::LlamaModelPaths;
using example::LlamaRuntime;
using example::utils::argmax;
using example::utils::read_file;
using example::utils::split;
using example::utils::Timer;
using example::utils::to_string;
using executorch::extension::llm::BPETokenizer;
using executorch::extension::llm::Tokenizer;
using executorch::runtime::Error;
using executorch::runtime::Result;

LlamaModelOptions get_model_options() {
  LlamaModelOptions options = {
      // Sizes
      .prompt_token_batch_size = FLAGS_prompt_token_batch_size,
      .cache_size = FLAGS_cache_size,
      .hidden_size = FLAGS_hidden_size,
      .num_head = FLAGS_num_head,
      .num_layer = FLAGS_num_layer,
      .max_token_length = FLAGS_max_token_length,
      .rot_emb_base = FLAGS_rot_emb_base,

      // Types
      .model_input_type = getLLMTypeFromName(FLAGS_input_type.c_str()),
      .model_output_type = getLLMTypeFromName(FLAGS_output_type.c_str()),
      .cache_type = getLLMTypeFromName(FLAGS_cache_type.c_str()),
      .mask_type = getLLMTypeFromName(FLAGS_mask_type.c_str()),
      .rot_emb_type = getLLMTypeFromName(FLAGS_rot_emb_type.c_str())};
  return options;
}

LlamaModelPaths get_model_paths() {
  LlamaModelPaths model_paths = {
      .tokenizer_path = FLAGS_tokenizer_path,
      .token_embedding_path = FLAGS_token_embedding_path,
      .prompt_model_paths = split(FLAGS_prompt_model_paths, ','),
      .gen_model_paths = split(FLAGS_gen_model_paths, ',')};
  return model_paths;
}

// Prefill: run the prompt tokens through the prompt model in batches of the
// runtime's token batch size, and return the argmax of the final logits as
// the first generated token.
Result<uint64_t> digest_prompt(
    LlamaRuntime& llama_runtime,
    const std::unique_ptr<Tokenizer>& tokenizer,
    const std::vector<uint64_t> input_tokens) {
  const auto input_token_count = input_tokens.size();
  const auto prompt_token_batch_size = llama_runtime.GetTokenBatchSize();
  size_t cur_token_index = 0;

  Timer timer_digest_prompt([=](const auto elapsed_sec) {
    // Ideal prompt size is a multiple of prompt batch size
    const size_t ideal_prompt_size =
        std::ceil(float(input_token_count) / prompt_token_batch_size) *
        prompt_token_batch_size;
    ET_LOG(
        Info,
        "Done analyzing prompt in %f sec (%f tok/s)",
        elapsed_sec,
        (float)ideal_prompt_size / elapsed_sec);
  });

  // Return the next chunk of prompt tokens. The first chunk absorbs the
  // remainder so that every subsequent chunk is a full batch.
  auto getNextTokens = [&]() {
    const size_t num_tok_remain = input_token_count - cur_token_index;
    const size_t remainder = num_tok_remain % prompt_token_batch_size;
    const size_t num_new_tokens =
        remainder ? remainder : prompt_token_batch_size;
    const auto start = cur_token_index;
    const auto end = start + num_new_tokens;
    return std::vector(
        input_tokens.begin() + start, input_tokens.begin() + end);
  };

  void* logits;
  timer_digest_prompt.Start();
  while (cur_token_index < input_token_count) {
    const auto next_tokens = getNextTokens();
    ET_LOG(
        Debug,
        "Digest next tokens (size=%zu), 1st tok=%lu",
        next_tokens.size(),
        next_tokens[0]);
    logits = llama_runtime.Run(next_tokens);
    cur_token_index += next_tokens.size();
  }
  timer_digest_prompt.End();

  const auto vocab_size = tokenizer->vocab_size();
  const auto logits_type = llama_runtime.GetModelOptions().model_output_type;
  const auto first_output_token = argmax(logits_type, logits, vocab_size);
  return first_output_token;
}

// Decode: swap to the single-token generator model, then autoregressively
// generate tokens (greedy argmax), streaming the detokenized text to stdout
// until EOS or a length limit is reached.
Error gen_response(
    LlamaRuntime& llama_runtime,
    const std::unique_ptr<Tokenizer>& tokenizer,
    const uint64_t input_token) {
  Timer timer_model_swap(
      [](const auto elapsed_sec) { ET_LOG(Info, "Model swapped."); });

  // Swap to gen mode
  timer_model_swap.Start();
  llama_runtime.SwapModel(1);
  timer_model_swap.End();

  size_t gen_tok_count = 0;
  uint64_t prev_token = input_token;
  uint64_t output_token = input_token;

  auto decode_res = tokenizer->decode(prev_token, output_token);
  ET_CHECK_OR_RETURN_ERROR(
      decode_res.ok(),
      InvalidState,
      "Tokenizer failed to decode first generated token: %lu",
      output_token);
  std::string full_response = std::move(decode_res.get());
  std::vector<uint64_t> full_response_tokens = {input_token};

  const auto vocab_size = tokenizer->vocab_size();
  const auto logits_type = llama_runtime.GetModelOptions().model_output_type;

  double gen_total_time_sec = 0;
  Timer timer_gen_token(
      [&](const auto elapsed_sec) { gen_total_time_sec += elapsed_sec; });

  // Print first output token
  std::cout << "\n[Real-time Response]" << std::endl;
  std::cout << full_response << std::flush;

  while (gen_tok_count++ < FLAGS_max_response &&
         llama_runtime.GetTokenIndex() < FLAGS_max_token_length) {
    timer_gen_token.Start();
    void* logits = llama_runtime.Run({output_token});
    timer_gen_token.End();

    prev_token = output_token;
    output_token = argmax(logits_type, logits, vocab_size);
    full_response_tokens.push_back(output_token);

    // Stop when output is EOS
    if (output_token == tokenizer->eos_tok()) {
      std::cout << "</eos>" << std::flush;
      break;
    }
    auto decode_res = tokenizer->decode(prev_token, output_token);
    ET_CHECK_OR_RETURN_ERROR(
        decode_res.ok(),
        InvalidState,
        "Tokenizer failed to decode generated token %lu",
        output_token);
    const std::string tok_str = std::move(decode_res.get());
    full_response += tok_str;
    std::cout << tok_str << std::flush;
  }

  std::cout << "\n\n[Generated Tokens]\n"
            << to_string(full_response_tokens) << std::endl;

  ET_LOG(
      Info,
      "Token generation speed: %f tok/s",
      gen_tok_count / gen_total_time_sec);

  return Error::Ok;
}

Error inference(
    LlamaRuntime& llama_runtime,
    const std::unique_ptr<Tokenizer>& tokenizer,
    const std::string& prompt) {
  // Tokenize input prompt
  auto encode_res = tokenizer->encode(prompt, kAddBos, kAddEos);
  ET_CHECK_OR_RETURN_ERROR(
      encode_res.ok(), InvalidState, "Tokenizer failed to encode prompt");
  const auto input_tokens = std::move(encode_res.get());

  std::cout << "\n[Input Prompt]\n" << prompt << std::endl;

  // Run prompt mode (pre-fill)
  auto prefill_res = digest_prompt(llama_runtime, tokenizer, input_tokens);
  ET_CHECK_OR_RETURN_ERROR(
      prefill_res.ok(), InvalidState, "Failed to digest prompt");
  const auto first_output_token = prefill_res.get();

  // Run generation mode (decode)
  return gen_response(llama_runtime, tokenizer, first_output_token);
}

std::unique_ptr<Tokenizer> load_tokenizer() {
  std::unique_ptr<Tokenizer> tokenizer;
  if (FLAGS_tokenizer_type == "bpe") {
    tokenizer = std::make_unique<BPETokenizer>();
  } else if (FLAGS_tokenizer_type == "tiktoken") {
    tokenizer = example::get_tiktoken_for_llama();
  }
  ET_CHECK_MSG(
      tokenizer, "Invalid tokenizer type: %s", FLAGS_tokenizer_type.c_str());
  tokenizer->load(FLAGS_tokenizer_path);
  return tokenizer;
}

int main(int argc, char** argv) {
  executorch::runtime::runtime_init();

  gflags::ParseCommandLineFlags(&argc, &argv, true);
  if (argc != 1) {
    std::string msg = "Extra commandline args:";
    for (int i = 1 /* skip argv[0] (program name) */; i < argc; i++) {
      msg += std::string(" ") + argv[i];
    }
    ET_LOG(Error, "%s", msg.c_str());
    return 1;
  }

  LlamaModelOptions model_options = get_model_options();
  LlamaModelPaths model_paths = get_model_paths();

  if (model_paths.prompt_model_paths.empty()) {
    model_options.prompt_token_batch_size = 1;
    ET_LOG(
        Info,
        "No prompt model paths provided, overriding prompt_token_batch_size to 1");
  }

  // Prepare timers
  Timer timer_init(
      [](const auto elapsed_sec) { ET_LOG(Info, "Model initialized."); });
  Timer timer_release(
      [](const auto elapsed_sec) { ET_LOG(Info, "Model released."); });

  LlamaRuntime llama_runtime;

  // Initialize model
  ET_LOG(Info, "Begin model loading.");
  timer_init.Start();
  const auto tokenizer = load_tokenizer();
  llama_runtime.Initialize(model_options, model_paths);
  timer_init.End();

  // Run model
  ET_CHECK_MSG(!FLAGS_prompt_file.empty(), "No prompt file provided.");
  std::string prompt = read_file(FLAGS_prompt_file);
  inference(llama_runtime, tokenizer, prompt);

  // Release model
  timer_release.Start();
  llama_runtime.Release();
  timer_release.End();

  return 0;
}