/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * Copyright (c) 2024 MediaTek Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

/* Copyright Statement:
 *
 * This software/firmware and related documentation ("MediaTek Software") are
 * protected under relevant copyright laws. The information contained herein
 * is confidential and proprietary to MediaTek Inc. and/or its licensors.
 * Without the prior written permission of MediaTek inc. and/or its licensors,
 * any reproduction, modification, use or disclosure of MediaTek Software,
 * and information contained herein, in whole or in part, shall be strictly
 * prohibited.
 */
/* MediaTek Inc. (C) 2024. All rights reserved.
 *
 * BY OPENING THIS FILE, RECEIVER HEREBY UNEQUIVOCALLY ACKNOWLEDGES AND AGREES
 * THAT THE SOFTWARE/FIRMWARE AND ITS DOCUMENTATIONS ("MEDIATEK SOFTWARE")
 * RECEIVED FROM MEDIATEK AND/OR ITS REPRESENTATIVES ARE PROVIDED TO RECEIVER ON
 * AN "AS-IS" BASIS ONLY. MEDIATEK EXPRESSLY DISCLAIMS ANY AND ALL WARRANTIES,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE OR NONINFRINGEMENT.
 * NEITHER DOES MEDIATEK PROVIDE ANY WARRANTY WHATSOEVER WITH RESPECT TO THE
 * SOFTWARE OF ANY THIRD PARTY WHICH MAY BE USED BY, INCORPORATED IN, OR
 * SUPPLIED WITH THE MEDIATEK SOFTWARE, AND RECEIVER AGREES TO LOOK ONLY TO SUCH
 * THIRD PARTY FOR ANY WARRANTY CLAIM RELATING THERETO. RECEIVER EXPRESSLY
 * ACKNOWLEDGES THAT IT IS RECEIVER'S SOLE RESPONSIBILITY TO OBTAIN FROM ANY
 * THIRD PARTY ALL PROPER LICENSES CONTAINED IN MEDIATEK SOFTWARE. MEDIATEK
 * SHALL ALSO NOT BE RESPONSIBLE FOR ANY MEDIATEK SOFTWARE RELEASES MADE TO
 * RECEIVER'S SPECIFICATION OR TO CONFORM TO A PARTICULAR STANDARD OR OPEN
 * FORUM. RECEIVER'S SOLE AND EXCLUSIVE REMEDY AND MEDIATEK'S ENTIRE AND
 * CUMULATIVE LIABILITY WITH RESPECT TO THE MEDIATEK SOFTWARE RELEASED HEREUNDER
 * WILL BE, AT MEDIATEK'S OPTION, TO REVISE OR REPLACE THE MEDIATEK SOFTWARE AT
 * ISSUE, OR REFUND ANY SOFTWARE LICENSE FEES OR SERVICE CHARGE PAID BY RECEIVER
 * TO MEDIATEK FOR SUCH MEDIATEK SOFTWARE AT ISSUE.
 *
 * The following software/firmware and/or related documentation ("MediaTek
 * Software") have been modified by MediaTek Inc. All revisions are subject to
 * any receiver's applicable license agreements with MediaTek Inc.
 */

#include "executorch/backends/mediatek/runtime/include/NeuronBufferAllocator.h"

#include <cmath>
#include <ctime>
#include <iostream>
#include <memory>
#include <random>

#include <gflags/gflags.h>

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/evalue_util/print_evalue.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/platform/log.h>
#include <executorch/runtime/platform/profiler.h>
#include <executorch/runtime/platform/runtime.h>

#include "llama_runner/LlamaConfig.h"
#include "llama_runner/LlamaRuntime.h"
#include "llama_runner/ModelChunk.h"
#include "llama_runner/Utils.h"
#include "llama_runner/llm_helper/include/llm_types.h"

#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
#include <executorch/extension/llm/tokenizer/tiktoken.h>

// Llama model options
DEFINE_uint64(
    prompt_token_batch_size,
    128,
    "Token batch size for prompt model.");
DEFINE_uint64(cache_size, 1024, "Model cache size.");
DEFINE_uint64(hidden_size, 4096, "Model hidden size.");
DEFINE_uint64(num_head, 32, "Number of attention heads in each layer.");
DEFINE_uint64(num_layer, 32, "Number of layers in the model.");
DEFINE_uint64(
    max_token_length,
    2048,
    "Maximum token length that the model supports.");
DEFINE_double(
    rot_emb_base,
    10000,
    "Rotary embedding base value, aka 'rope_theta'.");

// Model IO Types
DEFINE_string(input_type, "int16", "Model input type. Defaults to 'int16'.");
DEFINE_string(output_type, "int16", "Model output type. Defaults to 'int16'.");
DEFINE_string(cache_type, "int16", "Model cache type. Defaults to 'int16'.");
DEFINE_string(mask_type, "int16", "Model mask type. Defaults to 'int16'.");
DEFINE_string(
    rot_emb_type,
    "int16",
    "Model rotary embedding type. Defaults to 'int16'.");

// Model Paths
DEFINE_string(
    token_embedding_path,
    "embedding.bin",
    "Input token embedding lookup table path.");
DEFINE_string(
    prompt_model_paths,
    "model_128t.pte",
    "Comma-separated prompt model paths.");
DEFINE_string(
    gen_model_paths,
    "model_1t.pte",
    "Comma-separated generative model paths.");

// Tokenizer
DEFINE_string(tokenizer_path, "tokenizer.model", "tokenizer.model vocab path.");
DEFINE_string(
    tokenizer_type,
    "tiktoken",
    "Tokenizer type. One of ['bpe', 'tiktoken'].");
DEFINE_uint64(vocab_size, 128000, "Tokenizer vocab size.");
DEFINE_uint64(bos_token, 128000, "BOS token id.");
DEFINE_uint64(eos_token, 128001, "EOS token id.");

// Inference
DEFINE_uint64(max_response, 50, "Maximum number of tokens to generate.");
DEFINE_string(prompt_file, "", "File containing the prompt text.");

// Global BOS and EOS option for tokenization (encoding)
static constexpr int8_t kAddBos = 1;
static constexpr int8_t kAddEos = 0;

using namespace example::llm_helper;
using example::LlamaModelOptions;
using example::LlamaModelPaths;
using example::LlamaRuntime;
using example::utils::argmax;
using example::utils::read_file;
using example::utils::split;
using example::utils::Timer;
using example::utils::to_string;
using executorch::extension::llm::BPETokenizer;
using executorch::extension::llm::Tokenizer;
using executorch::runtime::Error;
using executorch::runtime::Result;

LlamaModelOptions get_model_options() {
  LlamaModelOptions options = {
      // Sizes
      .prompt_token_batch_size = FLAGS_prompt_token_batch_size,
      .cache_size = FLAGS_cache_size,
      .hidden_size = FLAGS_hidden_size,
      .num_head = FLAGS_num_head,
      .num_layer = FLAGS_num_layer,
      .max_token_length = FLAGS_max_token_length,
      .rot_emb_base = FLAGS_rot_emb_base,

      // Types
      .model_input_type = getLLMTypeFromName(FLAGS_input_type.c_str()),
      .model_output_type = getLLMTypeFromName(FLAGS_output_type.c_str()),
      .cache_type = getLLMTypeFromName(FLAGS_cache_type.c_str()),
      .mask_type = getLLMTypeFromName(FLAGS_mask_type.c_str()),
      .rot_emb_type = getLLMTypeFromName(FLAGS_rot_emb_type.c_str())};
  return options;
}

LlamaModelPaths get_model_paths() {
  LlamaModelPaths model_paths = {
      .tokenizer_path = FLAGS_tokenizer_path,
      .token_embedding_path = FLAGS_token_embedding_path,
      .prompt_model_paths = split(FLAGS_prompt_model_paths, ','),
      .gen_model_paths = split(FLAGS_gen_model_paths, ',')};
  return model_paths;
}

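// Prompt ("prefill") phase: feeds the tokenized prompt to the prompt model in
// batches and returns the first generated token, taken as the argmax over the
// logits produced by the final batch.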
Result<uint64_t> digest_prompt(
    LlamaRuntime& llama_runtime,
    const std::unique_ptr<Tokenizer>& tokenizer,
    const std::vector<uint64_t> input_tokens) {
  const auto input_token_count = input_tokens.size();
  const auto prompt_token_batch_size = llama_runtime.GetTokenBatchSize();
  size_t cur_token_index = 0;

  Timer timer_digest_prompt([=](const auto elapsed_sec) {
    // Ideal prompt size is a multiple of prompt batch size
    const size_t ideal_prompt_size =
        std::ceil(float(input_token_count) / prompt_token_batch_size) *
        prompt_token_batch_size;
    ET_LOG(
        Info,
        "Done analyzing prompt in %f sec (%f tok/s)",
        elapsed_sec,
        (float)ideal_prompt_size / elapsed_sec);
  });

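  // Consume the prompt in batches of prompt_token_batch_size. If the prompt
  // length is not a multiple of the batch size, the remainder chunk is fed
  // first so that every subsequent batch is full.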
  auto getNextTokens = [&]() {
    const size_t num_tok_remain = input_token_count - cur_token_index;
    const size_t remainder = num_tok_remain % prompt_token_batch_size;
    const size_t num_new_tokens =
        remainder ? remainder : prompt_token_batch_size;
    const auto start = cur_token_index;
    const auto end = start + num_new_tokens;
    return std::vector(
        input_tokens.begin() + start, input_tokens.begin() + end);
  };

  void* logits;
  timer_digest_prompt.Start();
  while (cur_token_index < input_token_count) {
    const auto next_tokens = getNextTokens();
    ET_LOG(
        Debug,
        "Digest next tokens (size=%zu), 1st tok=%lu",
        next_tokens.size(),
        next_tokens[0]);
    logits = llama_runtime.Run(next_tokens);
    cur_token_index += next_tokens.size();
  }
  timer_digest_prompt.End();

  const auto vocab_size = tokenizer->vocab_size();
  const auto logits_type = llama_runtime.GetModelOptions().model_output_type;
  const auto first_output_token = argmax(logits_type, logits, vocab_size);
  return first_output_token;
}

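// Generation ("decode") phase: swaps to the single-token generation model and
// autoregressively produces tokens, starting from the first generated token,
// until EOS, FLAGS_max_response, or FLAGS_max_token_length is reached.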
Error gen_response(
    LlamaRuntime& llama_runtime,
    const std::unique_ptr<Tokenizer>& tokenizer,
    const uint64_t input_token) {
  Timer timer_model_swap(
      [](const auto elapsed_sec) { ET_LOG(Info, "Model swapped."); });

  // Swap to gen mode
  timer_model_swap.Start();
  llama_runtime.SwapModel(1);
  timer_model_swap.End();

  size_t gen_tok_count = 0;
  uint64_t prev_token = input_token;
  uint64_t output_token = input_token;

  auto decode_res = tokenizer->decode(prev_token, output_token);
  ET_CHECK_OR_RETURN_ERROR(
      decode_res.ok(),
      InvalidState,
      "Tokenizer failed to decode first generated token: %lu",
      output_token);
  std::string full_response = std::move(decode_res.get());
  std::vector<uint64_t> full_response_tokens = {input_token};

  const auto vocab_size = tokenizer->vocab_size();
  const auto logits_type = llama_runtime.GetModelOptions().model_output_type;

  double gen_total_time_sec = 0;
  Timer timer_gen_token(
      [&](const auto elapsed_sec) { gen_total_time_sec += elapsed_sec; });

  // Print first output token
  std::cout << "\n[Real-time Response]" << std::endl;
  std::cout << full_response << std::flush;

  while (gen_tok_count++ < FLAGS_max_response &&
         llama_runtime.GetTokenIndex() < FLAGS_max_token_length) {
    timer_gen_token.Start();
    void* logits = llama_runtime.Run({output_token});
    timer_gen_token.End();

    prev_token = output_token;
    output_token = argmax(logits_type, logits, vocab_size);
    full_response_tokens.push_back(output_token);

    // Stop when output is EOS
    if (output_token == tokenizer->eos_tok()) {
      std::cout << "</eos>" << std::flush;
      break;
    }
    auto decode_res = tokenizer->decode(prev_token, output_token);
    ET_CHECK_OR_RETURN_ERROR(
        decode_res.ok(),
        InvalidState,
        "Tokenizer failed to decode generated token %lu",
        output_token);
    const std::string tok_str = std::move(decode_res.get());
    full_response += tok_str;
    std::cout << tok_str << std::flush;
  }

  std::cout << "\n\n[Generated Tokens]\n"
            << to_string(full_response_tokens) << std::endl;

  ET_LOG(
      Info,
      "Token generation speed: %f tok/s",
      gen_tok_count / gen_total_time_sec);

  return Error::Ok;
}

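// End-to-end inference for a single prompt: tokenize, run the prefill phase,
// then generate the response from the first predicted token.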
Error inference(
    LlamaRuntime& llama_runtime,
    const std::unique_ptr<Tokenizer>& tokenizer,
    const std::string& prompt) {
  // Tokenize input prompt
  auto encode_res = tokenizer->encode(prompt, kAddBos, kAddEos);
  ET_CHECK_OR_RETURN_ERROR(
      encode_res.ok(), InvalidState, "Tokenizer failed to encode prompt");
  const auto input_tokens = std::move(encode_res.get());

  std::cout << "\n[Input Prompt]\n" << prompt << std::endl;

  // Run prompt mode (pre-fill)
  auto prefill_res = digest_prompt(llama_runtime, tokenizer, input_tokens);
  ET_CHECK_OR_RETURN_ERROR(
      prefill_res.ok(), InvalidState, "Failed to digest prompt");
  const auto first_output_token = prefill_res.get();

  // Run generation mode (decoding)
  return gen_response(llama_runtime, tokenizer, first_output_token);
}

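// Constructs the tokenizer selected by --tokenizer_type ('bpe' or 'tiktoken')
// and loads the vocab from --tokenizer_path.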
std::unique_ptr<Tokenizer> load_tokenizer() {
  std::unique_ptr<Tokenizer> tokenizer;
  if (FLAGS_tokenizer_type == "bpe") {
    tokenizer = std::make_unique<BPETokenizer>();
  } else if (FLAGS_tokenizer_type == "tiktoken") {
    tokenizer = example::get_tiktoken_for_llama();
  }
  ET_CHECK_MSG(
      tokenizer, "Invalid tokenizer type: %s", FLAGS_tokenizer_type.c_str());
  tokenizer->load(FLAGS_tokenizer_path);
  return tokenizer;
}

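// Entry point: parses flags, loads the tokenizer and models, runs inference on
// the prompt read from --prompt_file, then releases the runtime.
//
// Example invocation (file paths are illustrative and depend on how the models
// were exported):
//   mtk_llama_executor_runner \
//     --prompt_model_paths=model_128t.pte \
//     --gen_model_paths=model_1t.pte \
//     --token_embedding_path=embedding.bin \
//     --tokenizer_path=tokenizer.model \
//     --prompt_file=prompt.txt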
int main(int argc, char** argv) {
  executorch::runtime::runtime_init();

  gflags::ParseCommandLineFlags(&argc, &argv, true);
  if (argc != 1) {
    std::string msg = "Extra commandline args:";
    for (int i = 1 /* skip argv[0] (program name) */; i < argc; i++) {
      msg += std::string(" ") + argv[i];
    }
    ET_LOG(Error, "%s", msg.c_str());
    return 1;
  }

  LlamaModelOptions model_options = get_model_options();
  LlamaModelPaths model_paths = get_model_paths();

  if (model_paths.prompt_model_paths.empty()) {
    model_options.prompt_token_batch_size = 1;
    ET_LOG(
        Info,
        "No prompt model paths provided, overriding prompt_token_batch_size to 1");
  }

  // Prepare timers
  Timer timer_init(
      [](const auto elapsed_sec) { ET_LOG(Info, "Model initialized."); });
  Timer timer_release(
      [](const auto elapsed_sec) { ET_LOG(Info, "Model released."); });

  LlamaRuntime llama_runtime;

  // Initialize model
  ET_LOG(Info, "Begin model loading.");
  timer_init.Start();
  const auto tokenizer = load_tokenizer();
  llama_runtime.Initialize(model_options, model_paths);
  timer_init.End();

  // Run model
  ET_CHECK_MSG(!FLAGS_prompt_file.empty(), "No prompt file provided.");
  std::string prompt = read_file(FLAGS_prompt_file);
  inference(llama_runtime, tokenizer, prompt);

  // Release model
  timer_release.Start();
  llama_runtime.Release();
  timer_release.End();

  return 0;
}