/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * Copyright (c) 2024 MediaTek Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

/* Copyright Statement:
 *
 * This software/firmware and related documentation ("MediaTek Software") are
 * protected under relevant copyright laws. The information contained herein
 * is confidential and proprietary to MediaTek Inc. and/or its licensors.
 * Without the prior written permission of MediaTek inc. and/or its licensors,
 * any reproduction, modification, use or disclosure of MediaTek Software,
 * and information contained herein, in whole or in part, shall be strictly
 * prohibited.
 */
/* MediaTek Inc. (C) 2024. All rights reserved.
 *
 * BY OPENING THIS FILE, RECEIVER HEREBY UNEQUIVOCALLY ACKNOWLEDGES AND AGREES
 * THAT THE SOFTWARE/FIRMWARE AND ITS DOCUMENTATIONS ("MEDIATEK SOFTWARE")
 * RECEIVED FROM MEDIATEK AND/OR ITS REPRESENTATIVES ARE PROVIDED TO RECEIVER ON
 * AN "AS-IS" BASIS ONLY. MEDIATEK EXPRESSLY DISCLAIMS ANY AND ALL WARRANTIES,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE OR NONINFRINGEMENT.
 * NEITHER DOES MEDIATEK PROVIDE ANY WARRANTY WHATSOEVER WITH RESPECT TO THE
 * SOFTWARE OF ANY THIRD PARTY WHICH MAY BE USED BY, INCORPORATED IN, OR
 * SUPPLIED WITH THE MEDIATEK SOFTWARE, AND RECEIVER AGREES TO LOOK ONLY TO SUCH
 * THIRD PARTY FOR ANY WARRANTY CLAIM RELATING THERETO. RECEIVER EXPRESSLY
 * ACKNOWLEDGES THAT IT IS RECEIVER'S SOLE RESPONSIBILITY TO OBTAIN FROM ANY
 * THIRD PARTY ALL PROPER LICENSES CONTAINED IN MEDIATEK SOFTWARE. MEDIATEK
 * SHALL ALSO NOT BE RESPONSIBLE FOR ANY MEDIATEK SOFTWARE RELEASES MADE TO
 * RECEIVER'S SPECIFICATION OR TO CONFORM TO A PARTICULAR STANDARD OR OPEN
 * FORUM. RECEIVER'S SOLE AND EXCLUSIVE REMEDY AND MEDIATEK'S ENTIRE AND
 * CUMULATIVE LIABILITY WITH RESPECT TO THE MEDIATEK SOFTWARE RELEASED HEREUNDER
 * WILL BE, AT MEDIATEK'S OPTION, TO REVISE OR REPLACE THE MEDIATEK SOFTWARE AT
 * ISSUE, OR REFUND ANY SOFTWARE LICENSE FEES OR SERVICE CHARGE PAID BY RECEIVER
 * TO MEDIATEK FOR SUCH MEDIATEK SOFTWARE AT ISSUE.
 *
 * The following software/firmware and/or related documentation ("MediaTek
 * Software") have been modified by MediaTek Inc. All revisions are subject to
 * any receiver's applicable license agreements with MediaTek Inc.
 */

#include <executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>
#include "executorch/backends/mediatek/runtime/include/NeuronBufferAllocator.h"

#include <ctime>
#include <iostream>
#include <memory>
#include <random>

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/evalue_util/print_evalue.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/platform/log.h>
#include <executorch/runtime/platform/profiler.h>
#include <executorch/runtime/platform/runtime.h>
// #include <executorch/util/util.h>
#include <executorch/extension/llm/runner/util.h>
#include <executorch/runtime/core/result.h>

#include "llama_runner/ModelChunk.h"
#include "llama_runner/Utils.h"
#include "llama_runner/llm_helper/include/llama_runner_values.h"
#include "llama_runner/llm_helper/include/llm_types.h"

static uint64_t MAX_RESPONSE = 50; // Maximum number of tokens to generate.
// Global BOS and EOS options for tokenization (encoding)
static constexpr int8_t kAddBos = 1;
static constexpr int8_t kAddEos = 0;

using namespace example::llm_helper;
using example::utils::argmax;
using example::utils::split;
using example::utils::Timer;
using example::utils::to_string;
using namespace mtk::vars;

namespace llm = ::executorch::extension::llm;

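// Constructor: initializes the ExecuTorch runtime and captures the model
// configuration. The model_path, tokenizer_path, and temperature arguments are
// currently unused; the runner resolves its own configuration via
// get_model_options() and get_model_paths().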
MTKLlamaRunner::MTKLlamaRunner(
    const std::string& model_path,
    const std::string& tokenizer_path,
    const float temperature)
    : modeloptions_(get_model_options()), modelpaths_(get_model_paths()) {
  executorch::runtime::runtime_init();
  ET_LOG(
      Info,
      "Creating MTK Llama runner. Currently it will self-load .pte, .bin, and .so files. Initiated runtime_init().");
}

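// Loads the tokenizer and the MTK Llama runtime. Returns immediately with
// Error::Ok if everything is already loaded.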
Error MTKLlamaRunner::load() {
  if (is_loaded()) {
    return Error::Ok;
  }

  // Load tokenizer
  ET_LOG(Info, "Loading tokenizer.");
  tokenizer_ = load_tokenizer();
  ET_LOG(Info, "Completed loading tokenizer.");

  // Load prompt model
  runtime_ = std::make_unique<LlamaRuntime>();
  ET_LOG(Info, "Loading prompt model.");
  runtime_->Initialize(modeloptions_, modelpaths_);
  ET_LOG(Info, "Completed loading prompt model.");

  return Error::Ok;
}

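// Returns true once both the tokenizer and the Llama runtime have been loaded.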
bool MTKLlamaRunner::is_loaded() const {
  return tokenizer_ && runtime_;
}

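// Generates a response for the given prompt, streaming each decoded piece to
// stdout and to token_callback. The seq_len, stats_callback, echo, and warming
// parameters are accepted for interface compatibility but are not used here.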
Error MTKLlamaRunner::generate(
    const std::string& prompt,
    int32_t seq_len,
    std::function<void(const std::string&)> token_callback,
    std::function<void(const Stats&)> stats_callback,
    bool echo,
    bool warming) {
  if (!is_loaded()) {
    ET_CHECK_OK_OR_RETURN_ERROR(load());
  }

  // Wrap the token_callback with a print function
  std::function<void(const std::string&)> wrapped_callback =
      [token_callback](const std::string& piece) {
        llm::safe_printf(piece.c_str());
        fflush(stdout);
        if (token_callback) {
          token_callback(piece);
        }
      };

  ET_LOG(Info, "Starting inference from MTKLlamaRunner");
  ET_CHECK_OK_OR_RETURN_ERROR(
      inference(*runtime_, tokenizer_, prompt, wrapped_callback));
  ET_LOG(Info, "Completed inference from MTKLlamaRunner");

  return Error::Ok;
}

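// Releases the Llama runtime resources, if loaded.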
void MTKLlamaRunner::stop() {
  if (is_loaded()) {
    runtime_->Release();
  } else {
    ET_LOG(Error, "Llama Runtime is not loaded, cannot stop");
  }
}

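// Assembles model options from the compile-time constants provided by
// llama_runner_values.h (namespace mtk::vars).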
LlamaModelOptions MTKLlamaRunner::get_model_options() {
  LlamaModelOptions options = {
      // Sizes
      .prompt_token_batch_size = PROMPT_TOKEN_BATCH_SIZE,
      .cache_size = CACHE_SIZE,
      .hidden_size = HIDDEN_SIZE,
      .num_head = NUM_HEAD,
      .num_layer = NUM_LAYER,
      .max_token_length = MAX_TOKEN_LENGTH,
      .rot_emb_base = ROT_EMB_BASE,

      // Types
      .model_input_type = MODEL_INPUT_TYPE,
      .model_output_type = MODEL_OUTPUT_TYPE,
      .cache_type = CACHE_TYPE,
      .mask_type = MASK_TYPE,
      .rot_emb_type = ROT_EMB_TYPE};
  ET_LOG(Info, "Completed get_model_options");
  return options;
}

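// Assembles tokenizer, token-embedding, and model file paths from the
// compile-time constants; the prompt and gen model path lists are
// comma-separated.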
LlamaModelPaths MTKLlamaRunner::get_model_paths() {
  LlamaModelPaths model_paths = {
      .tokenizer_path = TOKENIZER_PATH,
      .token_embedding_path = TOKEN_EMBEDDING_PATH,
      .prompt_model_paths = split(PROMPT_MODEL_PATHS, ','),
      .gen_model_paths = split(GEN_MODEL_PATHS, ',')};
  ET_LOG(Info, "Completed get_model_paths");
  return model_paths;
}

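// Prompt (pre-fill) phase: feeds the tokenized prompt to the model in batches
// of the prompt token batch size and returns the first generated token,
// selected by argmax over the final logits.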
Result<uint64_t> MTKLlamaRunner::digest_prompt(
    LlamaRuntime& llama_runtime,
    const std::unique_ptr<Tokenizer>& tokenizer,
    const std::vector<uint64_t> input_tokens) {
  const auto input_token_count = input_tokens.size();
  const auto prompt_token_batch_size = llama_runtime.GetTokenBatchSize();
  size_t cur_token_index = 0;

  Timer timer_digest_prompt([=](const auto elapsed_sec) {
    // Ideal prompt size is a multiple of the prompt batch size
    const size_t ideal_prompt_size =
        std::ceil(float(input_token_count) / prompt_token_batch_size) *
        prompt_token_batch_size;
    ET_LOG(
        Info,
        "Done analyzing prompt in %f sec (%f tok/s)",
        elapsed_sec,
        (float)ideal_prompt_size / elapsed_sec);
  });

  // Returns the next batch of prompt tokens. Any remainder (when the prompt
  // size is not a multiple of the batch size) is consumed first.
  auto getNextTokens = [&]() {
    const size_t num_tok_remain = input_token_count - cur_token_index;
    const size_t remainder = num_tok_remain % prompt_token_batch_size;
    const size_t num_new_tokens =
        remainder ? remainder : prompt_token_batch_size;
    const auto start = cur_token_index;
    const auto end = start + num_new_tokens;
    return std::vector(
        input_tokens.begin() + start, input_tokens.begin() + end);
  };

  void* logits = nullptr;
  timer_digest_prompt.Start();
  while (cur_token_index < input_token_count) {
    const auto next_tokens = getNextTokens();
    ET_LOG(
        Debug,
        "Digest next tokens (size=%zu), 1st tok=%lu",
        next_tokens.size(),
        next_tokens[0]);
    logits = llama_runtime.Run(next_tokens);
    cur_token_index += next_tokens.size();
  }
  timer_digest_prompt.End();

  const auto vocab_size = tokenizer->vocab_size();
  const auto logits_type = llama_runtime.GetModelOptions().model_output_type;
  const auto first_output_token = argmax(logits_type, logits, vocab_size);
  return first_output_token;
}

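// Generation (decode) phase: swaps to the gen model, then greedily decodes one
// token at a time (argmax) until EOS is produced, MAX_RESPONSE tokens have
// been generated, or the maximum token length is reached, streaming each
// decoded piece to token_callback.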
Error MTKLlamaRunner::gen_response(
    LlamaRuntime& llama_runtime,
    const std::unique_ptr<Tokenizer>& tokenizer,
    const uint64_t input_token,
    std::function<void(const std::string&)> token_callback) {
  Timer timer_model_swap(
      [](const auto elapsed_sec) { ET_LOG(Info, "Model swapped."); });

  // Swap to gen mode
  timer_model_swap.Start();
  llama_runtime.SwapModel(1);
  timer_model_swap.End();

  size_t gen_tok_count = 0;
  uint64_t prev_token = input_token;
  uint64_t output_token = input_token;

  auto decode_res = tokenizer->decode(prev_token, output_token);
  ET_CHECK_OR_RETURN_ERROR(
      decode_res.ok(),
      InvalidState,
      "Tokenizer failed to decode first generated token: %lu",
      output_token);
  std::string full_response = std::move(decode_res.get());
  std::vector<uint64_t> full_response_tokens = {input_token};

  const auto vocab_size = tokenizer->vocab_size();
  const auto logits_type = llama_runtime.GetModelOptions().model_output_type;

  double gen_total_time_sec = 0;
  Timer timer_gen_token(
      [&](const auto elapsed_sec) { gen_total_time_sec += elapsed_sec; });

  // Print first output token
  token_callback(full_response);

  while (gen_tok_count++ < MAX_RESPONSE &&
         llama_runtime.GetTokenIndex() < modeloptions_.max_token_length) {
    timer_gen_token.Start();
    void* logits = llama_runtime.Run({output_token});
    timer_gen_token.End();

    prev_token = output_token;
    output_token = argmax(logits_type, logits, vocab_size);
    full_response_tokens.push_back(output_token);

    // Stop when output is EOS
    if (output_token == tokenizer->eos_tok()) {
      token_callback("</eos>");
      break;
    }
    auto decode_res = tokenizer->decode(prev_token, output_token);
    ET_CHECK_OR_RETURN_ERROR(
        decode_res.ok(),
        InvalidState,
        "Tokenizer failed to decode generated token %lu",
        output_token);
    const std::string tok_str = std::move(decode_res.get());
    full_response += tok_str;
    token_callback(tok_str);
  }

  std::cout << "\n\n[Generated Tokens]\n"
            << to_string(full_response_tokens) << std::endl;

  ET_LOG(
      Info,
      "Token generation speed: %f tok/s",
      gen_tok_count / gen_total_time_sec);

  return Error::Ok;
}

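// Full inference flow: tokenize the prompt, run the pre-fill phase, then
// generate and stream the response.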
Error MTKLlamaRunner::inference(
    LlamaRuntime& llama_runtime,
    const std::unique_ptr<Tokenizer>& tokenizer,
    const std::string& prompt,
    std::function<void(const std::string&)> token_callback) {
  // Tokenize input prompt
  auto encode_res = tokenizer->encode(prompt, kAddBos, kAddEos);
  ET_CHECK_OR_RETURN_ERROR(
      encode_res.ok(), InvalidState, "Tokenizer failed to encode prompt");
  const auto input_tokens = std::move(encode_res.get());

  // Run prompt mode (pre-fill)
  auto prefill_res = digest_prompt(llama_runtime, tokenizer, input_tokens);
  ET_CHECK_OR_RETURN_ERROR(
      prefill_res.ok(), InvalidState, "Failed to digest prompt");
  const auto first_output_token = prefill_res.get();

  // Run generation mode (decoding)
  return gen_response(
      llama_runtime, tokenizer, first_output_token, token_callback);
}

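// Creates and loads the tokenizer from modelpaths_.tokenizer_path.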
std::unique_ptr<Tokenizer> MTKLlamaRunner::load_tokenizer() {
  std::unique_ptr<Tokenizer> tokenizer;
  // Assumes that tokenizer type is Tiktoken
  tokenizer = example::get_tiktoken_for_llama();
  tokenizer->load(modelpaths_.tokenizer_path);
  return tokenizer;
}
339