/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <chrono>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include <executorch/examples/models/llama/runner/runner.h>
#include <executorch/examples/models/llava/runner/llava_runner.h>
#include <executorch/extension/llm/runner/image.h>
#include <executorch/extension/llm/runner/irunner.h>
#include <executorch/runtime/platform/log.h>
#include <executorch/runtime/platform/platform.h>
#include <executorch/runtime/platform/runtime.h>

#if defined(ET_USE_THREADPOOL)
#include <executorch/extension/threadpool/cpuinfo_utils.h>
#include <executorch/extension/threadpool/threadpool.h>
#endif

#include <fbjni/ByteBuffer.h>
#include <fbjni/fbjni.h>

#if defined(EXECUTORCH_BUILD_MEDIATEK)
#include <executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>
#endif

namespace llm = ::executorch::extension::llm;
using ::executorch::runtime::Error;

namespace {
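// Returns true if the first `length` bytes of `str` form a sequence of
// complete UTF-8 code points. Only sequence lengths and continuation-byte
// markers are checked; overlong encodings and invalid code points (e.g.
// surrogates) are not rejected. That is sufficient here, where the goal is
// only to detect token text that ends mid-character.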
bool utf8_check_validity(const char* str, size_t length) {
  for (size_t i = 0; i < length; ++i) {
    uint8_t byte = static_cast<uint8_t>(str[i]);
    if (byte >= 0x80) { // Non-ASCII byte
      if (i + 1 >= length) { // Incomplete sequence
        return false;
      }
      uint8_t next_byte = static_cast<uint8_t>(str[i + 1]);
      if ((byte & 0xE0) == 0xC0 &&
          (next_byte & 0xC0) == 0x80) { // 2-byte sequence
        i += 1;
      } else if (
          (byte & 0xF0) == 0xE0 && (next_byte & 0xC0) == 0x80 &&
          (i + 2 < length) &&
          (static_cast<uint8_t>(str[i + 2]) & 0xC0) ==
              0x80) { // 3-byte sequence
        i += 2;
      } else if (
          (byte & 0xF8) == 0xF0 && (next_byte & 0xC0) == 0x80 &&
          (i + 2 < length) &&
          (static_cast<uint8_t>(str[i + 2]) & 0xC0) == 0x80 &&
          (i + 3 < length) &&
          (static_cast<uint8_t>(str[i + 3]) & 0xC0) ==
              0x80) { // 4-byte sequence
        i += 3;
      } else {
        return false; // Invalid sequence
      }
    }
  }
  return true; // All bytes were valid
}

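// Accumulates token text across onResult calls until it forms valid UTF-8.
// Callbacks for a given runner are assumed to arrive on a single thread, so
// no synchronization is done here.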
std::string token_buffer;
} // namespace

namespace executorch_jni {

class ExecuTorchLlamaCallbackJni
    : public facebook::jni::JavaClass<ExecuTorchLlamaCallbackJni> {
 public:
  constexpr static const char* kJavaDescriptor =
      "Lorg/pytorch/executorch/LlamaCallback;";

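  // Forwards a generated token to the Java LlamaCallback. A single token may
  // end in the middle of a multi-byte UTF-8 character, so text is buffered
  // and only emitted once the buffer is valid UTF-8.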
  void onResult(std::string result) const {
    static auto cls = ExecuTorchLlamaCallbackJni::javaClassStatic();
    static const auto method =
        cls->getMethod<void(facebook::jni::local_ref<jstring>)>("onResult");

    token_buffer += result;
    if (!utf8_check_validity(token_buffer.c_str(), token_buffer.size())) {
      ET_LOG(
          Info, "Current token buffer is not valid UTF-8. Waiting for more.");
      return;
    }
    result = token_buffer;
    token_buffer = "";
    facebook::jni::local_ref<jstring> s = facebook::jni::make_jstring(result);
    method(self(), s);
  }

  void onStats(const llm::Stats& result) const {
    static auto cls = ExecuTorchLlamaCallbackJni::javaClassStatic();
    static const auto method = cls->getMethod<void(jfloat)>("onStats");
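    // Decode throughput in tokens/second: generated tokens divided by the
    // wall-clock decode time (inference end minus prompt-eval end), scaled
    // from the Stats time base (ms) to seconds.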
    double eval_time =
        (double)(result.inference_end_ms - result.prompt_eval_end_ms);

    float tps = result.num_generated_tokens / eval_time *
        result.SCALING_FACTOR_UNITS_PER_SECOND;

    method(self(), tps);
  }
};

class ExecuTorchLlamaJni
    : public facebook::jni::HybridClass<ExecuTorchLlamaJni> {
 private:
  friend HybridBase;
  int model_type_category_;
  std::unique_ptr<llm::IRunner> runner_;
  std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;

 public:
  constexpr static auto kJavaDescriptor =
      "Lorg/pytorch/executorch/LlamaModule;";

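  // Model categories passed in from Java; these values are expected to match
  // the constants declared on the LlamaModule Java side.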
  constexpr static int MODEL_TYPE_CATEGORY_LLM = 1;
  constexpr static int MODEL_TYPE_CATEGORY_MULTIMODAL = 2;
  constexpr static int MODEL_TYPE_MEDIATEK_LLAMA = 3;

  static facebook::jni::local_ref<jhybriddata> initHybrid(
      facebook::jni::alias_ref<jclass>,
      jint model_type_category,
      facebook::jni::alias_ref<jstring> model_path,
      facebook::jni::alias_ref<jstring> tokenizer_path,
      jfloat temperature) {
    return makeCxxInstance(
        model_type_category, model_path, tokenizer_path, temperature);
  }

  ExecuTorchLlamaJni(
      jint model_type_category,
      facebook::jni::alias_ref<jstring> model_path,
      facebook::jni::alias_ref<jstring> tokenizer_path,
      jfloat temperature) {
#if defined(ET_USE_THREADPOOL)
    // Reserve one core for the main thread. Subtract only after checking the
    // count, to avoid unsigned underflow when cpuinfo reports zero
    // performant cores.
    uint32_t num_performant_cores =
        ::executorch::extension::cpuinfo::get_num_performant_cores();
    if (num_performant_cores > 1) {
      num_performant_cores -= 1;
      ET_LOG(Info, "Resetting threadpool to %u threads", num_performant_cores);
      ::executorch::extension::threadpool::get_threadpool()
          ->_unsafe_reset_threadpool(num_performant_cores);
    }
#endif

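    // Pick the runner implementation based on the model category: LLaVA for
    // multimodal models, the Llama example runner for plain LLMs, and the
    // MediaTek runner when that backend is compiled in.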
    model_type_category_ = model_type_category;
    if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
      multi_modal_runner_ = std::make_unique<example::LlavaRunner>(
          model_path->toStdString().c_str(),
          tokenizer_path->toStdString().c_str(),
          temperature);
    } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
      runner_ = std::make_unique<example::Runner>(
          model_path->toStdString().c_str(),
          tokenizer_path->toStdString().c_str(),
          temperature);
#if defined(EXECUTORCH_BUILD_MEDIATEK)
    } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
      runner_ = std::make_unique<MTKLlamaRunner>(
          model_path->toStdString().c_str(),
          tokenizer_path->toStdString().c_str(),
          temperature);
      // Interpret the model type as LLM
      model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
#endif
    }
  }

  jint generate(
      facebook::jni::alias_ref<jintArray> image,
      jint width,
      jint height,
      jint channels,
      facebook::jni::alias_ref<jstring> prompt,
      jint seq_len,
      facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback,
      jboolean echo) {
    if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
      auto image_size = image->size();
      std::vector<llm::Image> images;
      if (image_size != 0) {
        std::vector<jint> image_data_jint(image_size);
        std::vector<uint8_t> image_data(image_size);
        image->getRegion(0, image_size, image_data_jint.data());
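        // Java has no unsigned byte type, so pixel data arrives as jint and
        // is narrowed to uint8_t here; values are expected to be in [0, 255].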
        for (int i = 0; i < image_size; i++) {
          image_data[i] = image_data_jint[i];
        }
        llm::Image image_runner{image_data, width, height, channels};
        images.push_back(image_runner);
      }
      multi_modal_runner_->generate(
          std::move(images),
          prompt->toStdString(),
          seq_len,
          [callback](std::string result) { callback->onResult(result); },
          [callback](const llm::Stats& result) { callback->onStats(result); },
          echo);
    } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
      runner_->generate(
          prompt->toStdString(),
          seq_len,
          [callback](std::string result) { callback->onResult(result); },
          [callback](const llm::Stats& result) { callback->onStats(result); },
          echo);
    }
    return 0;
  }

  // Returns a tuple of (error, start_pos).
  // The contract is only valid within an AAR (JNI + corresponding Java code).
  // If the first element is not Error::Ok, the second element is undefined.
  facebook::jni::local_ref<jlongArray> prefill_prompt(
      facebook::jni::alias_ref<jstring> prompt,
      jlong start_pos,
      jint bos,
      jint eos) {
    facebook::jni::local_ref<jlongArray> tuple_result =
        facebook::jni::make_long_array(2);
    if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
      tuple_result->pin()[0] = static_cast<jlong>(Error::NotSupported);
      return tuple_result;
    }

    auto&& result = multi_modal_runner_->prefill_prompt(
        prompt->toStdString(), start_pos, bos, eos);
    // Propagate the runner's error code rather than reporting Ok
    // unconditionally; start_pos is only meaningful on success.
    tuple_result->pin()[0] = static_cast<jlong>(result.error());
    if (result.ok()) {
      tuple_result->pin()[1] = static_cast<jlong>(start_pos);
    }
    return tuple_result;
  }

  // Returns a tuple of (error, start_pos).
  // The contract is only valid within an AAR (JNI + corresponding Java code).
  // If the first element is not Error::Ok, the second element is undefined.
  facebook::jni::local_ref<jlongArray> prefill_images(
      facebook::jni::alias_ref<jintArray> image,
      jint width,
      jint height,
      jint channels,
      jlong start_pos) {
    facebook::jni::local_ref<jlongArray> tuple_result =
        facebook::jni::make_long_array(2);

    if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
      tuple_result->pin()[0] = static_cast<jlong>(Error::NotSupported);
      return tuple_result;
    }

    auto image_size = image->size();
    std::vector<llm::Image> images;
    if (image_size != 0) {
      std::vector<jint> image_data_jint(image_size);
      std::vector<uint8_t> image_data(image_size);
      image->getRegion(0, image_size, image_data_jint.data());
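      // As in generate(), narrow each jint pixel component to uint8_t.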
      for (int i = 0; i < image_size; i++) {
        image_data[i] = image_data_jint[i];
      }
      llm::Image image_runner{image_data, width, height, channels};
      images.push_back(image_runner);
    }
    // TODO(hsz): make start_pos a reference and update it here
    jlong result = static_cast<jlong>(
        multi_modal_runner_->prefill_images(images, start_pos));
    tuple_result->pin()[0] = result;
    tuple_result->pin()[1] = static_cast<jlong>(start_pos);
    return tuple_result;
  }

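  // Resumes generation from an existing position in the KV cache (e.g. after
  // prefill_prompt / prefill_images), streaming results to the callback.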
  jint generate_from_pos(
      facebook::jni::alias_ref<jstring> prompt,
      jint seq_len,
      jlong start_pos,
      facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback,
      jboolean echo) {
    if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
      return static_cast<jint>(Error::NotSupported);
    }
    return static_cast<jint>(multi_modal_runner_->generate_from_pos(
        prompt->toStdString(),
        seq_len,
        start_pos,
        [callback](const std::string& result) { callback->onResult(result); },
        [callback](const llm::Stats& stats) { callback->onStats(stats); },
        echo));
  }

  void stop() {
    if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
      multi_modal_runner_->stop();
    } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
      runner_->stop();
    }
  }

  jint load() {
    if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
      return static_cast<jint>(multi_modal_runner_->load());
    } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
      return static_cast<jint>(runner_->load());
    }
    return static_cast<jint>(Error::InvalidArgument);
  }

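  // Binds the native methods declared in the Java LlamaModule class to the
  // implementations above.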
  static void registerNatives() {
    registerHybrid({
        makeNativeMethod("initHybrid", ExecuTorchLlamaJni::initHybrid),
        makeNativeMethod("generate", ExecuTorchLlamaJni::generate),
        makeNativeMethod("stop", ExecuTorchLlamaJni::stop),
        makeNativeMethod("load", ExecuTorchLlamaJni::load),
        makeNativeMethod(
            "prefillImagesNative", ExecuTorchLlamaJni::prefill_images),
        makeNativeMethod(
            "prefillPromptNative", ExecuTorchLlamaJni::prefill_prompt),
        makeNativeMethod(
            "generateFromPos", ExecuTorchLlamaJni::generate_from_pos),
    });
  }
};

} // namespace executorch_jni

void register_natives_for_llama() {
  executorch_jni::ExecuTorchLlamaJni::registerNatives();
}