1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <chrono>
10 #include <cstdint>
11 #include <memory>
12 #include <string>
13 #include <unordered_map>
14 #include <vector>
15
16 #include <executorch/examples/models/llama/runner/runner.h>
17 #include <executorch/examples/models/llava/runner/llava_runner.h>
18 #include <executorch/extension/llm/runner/image.h>
19 #include <executorch/extension/llm/runner/irunner.h>
20 #include <executorch/runtime/platform/log.h>
21 #include <executorch/runtime/platform/platform.h>
22 #include <executorch/runtime/platform/runtime.h>
23
24 #if defined(ET_USE_THREADPOOL)
25 #include <executorch/extension/threadpool/cpuinfo_utils.h>
26 #include <executorch/extension/threadpool/threadpool.h>
27 #endif
28
29 #include <fbjni/ByteBuffer.h>
30 #include <fbjni/fbjni.h>
31
32 #if defined(EXECUTORCH_BUILD_MEDIATEK)
33 #include <executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>
34 #endif
35
36 namespace llm = ::executorch::extension::llm;
37 using ::executorch::runtime::Error;
38
namespace {

/**
 * Checks whether the first `length` bytes of `str` form a complete, valid
 * UTF-8 byte stream.
 *
 * Tokens emitted by the runner may end in the middle of a multi-byte
 * character, so the caller uses a `false` return to keep buffering until
 * the sequence becomes complete.
 *
 * Note: only structural validity (lead byte followed by the expected number
 * of continuation bytes) is checked; overlong encodings and surrogate code
 * points are not rejected.
 */
bool utf8_check_validity(const char* str, size_t length) {
  size_t pos = 0;
  while (pos < length) {
    const uint8_t lead = static_cast<uint8_t>(str[pos]);
    if (lead < 0x80) { // Plain ASCII byte.
      ++pos;
      continue;
    }
    // Classify the lead byte: how many continuation bytes must follow?
    size_t trailing;
    if ((lead & 0xE0) == 0xC0) {
      trailing = 1; // 2-byte sequence
    } else if ((lead & 0xF0) == 0xE0) {
      trailing = 2; // 3-byte sequence
    } else if ((lead & 0xF8) == 0xF0) {
      trailing = 3; // 4-byte sequence
    } else {
      return false; // Stray continuation byte or invalid lead byte.
    }
    if (pos + trailing >= length) {
      return false; // Sequence is truncated; wait for more bytes.
    }
    for (size_t k = 1; k <= trailing; ++k) {
      if ((static_cast<uint8_t>(str[pos + k]) & 0xC0) != 0x80) {
        return false; // Expected a 10xxxxxx continuation byte.
      }
    }
    pos += trailing + 1;
  }
  return true; // Every byte belongs to a complete, well-formed sequence.
}

// Accumulates partial tokens until they form valid UTF-8 before they are
// forwarded to the Java callback.
std::string token_buffer;
} // namespace
75
76 namespace executorch_jni {
77
78 class ExecuTorchLlamaCallbackJni
79 : public facebook::jni::JavaClass<ExecuTorchLlamaCallbackJni> {
80 public:
81 constexpr static const char* kJavaDescriptor =
82 "Lorg/pytorch/executorch/LlamaCallback;";
83
onResult(std::string result) const84 void onResult(std::string result) const {
85 static auto cls = ExecuTorchLlamaCallbackJni::javaClassStatic();
86 static const auto method =
87 cls->getMethod<void(facebook::jni::local_ref<jstring>)>("onResult");
88
89 token_buffer += result;
90 if (!utf8_check_validity(token_buffer.c_str(), token_buffer.size())) {
91 ET_LOG(
92 Info, "Current token buffer is not valid UTF-8. Waiting for more.");
93 return;
94 }
95 result = token_buffer;
96 token_buffer = "";
97 facebook::jni::local_ref<jstring> s = facebook::jni::make_jstring(result);
98 method(self(), s);
99 }
100
onStats(const llm::Stats & result) const101 void onStats(const llm::Stats& result) const {
102 static auto cls = ExecuTorchLlamaCallbackJni::javaClassStatic();
103 static const auto method = cls->getMethod<void(jfloat)>("onStats");
104 double eval_time =
105 (double)(result.inference_end_ms - result.prompt_eval_end_ms);
106
107 float tps = result.num_generated_tokens / eval_time *
108 result.SCALING_FACTOR_UNITS_PER_SECOND;
109
110 method(self(), tps);
111 }
112 };
113
114 class ExecuTorchLlamaJni
115 : public facebook::jni::HybridClass<ExecuTorchLlamaJni> {
116 private:
117 friend HybridBase;
118 int model_type_category_;
119 std::unique_ptr<llm::IRunner> runner_;
120 std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
121
122 public:
123 constexpr static auto kJavaDescriptor =
124 "Lorg/pytorch/executorch/LlamaModule;";
125
126 constexpr static int MODEL_TYPE_CATEGORY_LLM = 1;
127 constexpr static int MODEL_TYPE_CATEGORY_MULTIMODAL = 2;
128 constexpr static int MODEL_TYPE_MEDIATEK_LLAMA = 3;
129
initHybrid(facebook::jni::alias_ref<jclass>,jint model_type_category,facebook::jni::alias_ref<jstring> model_path,facebook::jni::alias_ref<jstring> tokenizer_path,jfloat temperature)130 static facebook::jni::local_ref<jhybriddata> initHybrid(
131 facebook::jni::alias_ref<jclass>,
132 jint model_type_category,
133 facebook::jni::alias_ref<jstring> model_path,
134 facebook::jni::alias_ref<jstring> tokenizer_path,
135 jfloat temperature) {
136 return makeCxxInstance(
137 model_type_category, model_path, tokenizer_path, temperature);
138 }
139
ExecuTorchLlamaJni(jint model_type_category,facebook::jni::alias_ref<jstring> model_path,facebook::jni::alias_ref<jstring> tokenizer_path,jfloat temperature)140 ExecuTorchLlamaJni(
141 jint model_type_category,
142 facebook::jni::alias_ref<jstring> model_path,
143 facebook::jni::alias_ref<jstring> tokenizer_path,
144 jfloat temperature) {
145 #if defined(ET_USE_THREADPOOL)
146 // Reserve 1 thread for the main thread.
147 uint32_t num_performant_cores =
148 ::executorch::extension::cpuinfo::get_num_performant_cores() - 1;
149 if (num_performant_cores > 0) {
150 ET_LOG(Info, "Resetting threadpool to %d threads", num_performant_cores);
151 ::executorch::extension::threadpool::get_threadpool()
152 ->_unsafe_reset_threadpool(num_performant_cores);
153 }
154 #endif
155
156 model_type_category_ = model_type_category;
157 if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
158 multi_modal_runner_ = std::make_unique<example::LlavaRunner>(
159 model_path->toStdString().c_str(),
160 tokenizer_path->toStdString().c_str(),
161 temperature);
162 } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
163 runner_ = std::make_unique<example::Runner>(
164 model_path->toStdString().c_str(),
165 tokenizer_path->toStdString().c_str(),
166 temperature);
167 #if defined(EXECUTORCH_BUILD_MEDIATEK)
168 } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
169 runner_ = std::make_unique<MTKLlamaRunner>(
170 model_path->toStdString().c_str(),
171 tokenizer_path->toStdString().c_str(),
172 temperature);
173 // Interpret the model type as LLM
174 model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
175 #endif
176 }
177 }
178
generate(facebook::jni::alias_ref<jintArray> image,jint width,jint height,jint channels,facebook::jni::alias_ref<jstring> prompt,jint seq_len,facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback,jboolean echo)179 jint generate(
180 facebook::jni::alias_ref<jintArray> image,
181 jint width,
182 jint height,
183 jint channels,
184 facebook::jni::alias_ref<jstring> prompt,
185 jint seq_len,
186 facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback,
187 jboolean echo) {
188 if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
189 auto image_size = image->size();
190 std::vector<llm::Image> images;
191 if (image_size != 0) {
192 std::vector<jint> image_data_jint(image_size);
193 std::vector<uint8_t> image_data(image_size);
194 image->getRegion(0, image_size, image_data_jint.data());
195 for (int i = 0; i < image_size; i++) {
196 image_data[i] = image_data_jint[i];
197 }
198 llm::Image image_runner{image_data, width, height, channels};
199 images.push_back(image_runner);
200 }
201 multi_modal_runner_->generate(
202 std::move(images),
203 prompt->toStdString(),
204 seq_len,
205 [callback](std::string result) { callback->onResult(result); },
206 [callback](const llm::Stats& result) { callback->onStats(result); },
207 echo);
208 } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
209 runner_->generate(
210 prompt->toStdString(),
211 seq_len,
212 [callback](std::string result) { callback->onResult(result); },
213 [callback](const llm::Stats& result) { callback->onStats(result); },
214 echo);
215 }
216 return 0;
217 }
218
219 // Returns a tuple of (error, start_pos)
220 // Contract is valid within an AAR (JNI + corresponding Java code)
221 // If the first element is not Error::Ok, the other element is undefined.
prefill_prompt(facebook::jni::alias_ref<jstring> prompt,jlong start_pos,jint bos,jint eos)222 facebook::jni::local_ref<jlongArray> prefill_prompt(
223 facebook::jni::alias_ref<jstring> prompt,
224 jlong start_pos,
225 jint bos,
226 jint eos) {
227 facebook::jni::local_ref<jlongArray> tuple_result =
228 facebook::jni::make_long_array(2);
229 if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
230 tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
231 return tuple_result;
232 }
233
234 auto&& result = multi_modal_runner_->prefill_prompt(
235 prompt->toStdString(), start_pos, bos, eos);
236 tuple_result->pin()[0] = static_cast<jint>(Error::Ok);
237 if (result.ok()) {
238 tuple_result->pin()[1] = static_cast<jlong>(start_pos);
239 }
240 return tuple_result;
241 }
242
243 // Returns a tuple of (error, start_pos)
244 // Contract is valid within an AAR (JNI + corresponding Java code)
245 // If the first element is not Error::Ok, the other element is undefined.
246
prefill_images(facebook::jni::alias_ref<jintArray> image,jint width,jint height,jint channels,jlong start_pos)247 facebook::jni::local_ref<jlongArray> prefill_images(
248 facebook::jni::alias_ref<jintArray> image,
249 jint width,
250 jint height,
251 jint channels,
252 jlong start_pos) {
253 facebook::jni::local_ref<jlongArray> tuple_result =
254 facebook::jni::make_long_array(2);
255
256 if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
257 tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
258 return tuple_result;
259 }
260
261 auto image_size = image->size();
262 std::vector<llm::Image> images;
263 if (image_size != 0) {
264 std::vector<jint> image_data_jint(image_size);
265 std::vector<uint8_t> image_data(image_size);
266 image->getRegion(0, image_size, image_data_jint.data());
267 for (int i = 0; i < image_size; i++) {
268 image_data[i] = image_data_jint[i];
269 }
270 llm::Image image_runner{image_data, width, height, channels};
271 images.push_back(image_runner);
272 }
273 // TODO(hsz): make start_pos a reference and update it here
274 jint result = static_cast<jint>(
275 multi_modal_runner_->prefill_images(images, start_pos));
276 tuple_result->pin()[0] = result;
277 tuple_result->pin()[1] = static_cast<jlong>(start_pos);
278 return tuple_result;
279 }
280
generate_from_pos(facebook::jni::alias_ref<jstring> prompt,jint seq_len,jlong start_pos,facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback,jboolean echo)281 jint generate_from_pos(
282 facebook::jni::alias_ref<jstring> prompt,
283 jint seq_len,
284 jlong start_pos,
285 facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback,
286 jboolean echo) {
287 if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
288 return static_cast<jint>(Error::NotSupported);
289 }
290 return static_cast<jint>(multi_modal_runner_->generate_from_pos(
291 prompt->toStdString(),
292 seq_len,
293 start_pos,
294 [callback](const std::string& result) { callback->onResult(result); },
295 [callback](const llm::Stats& stats) { callback->onStats(stats); },
296 echo));
297 }
298
stop()299 void stop() {
300 if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
301 multi_modal_runner_->stop();
302 } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
303 runner_->stop();
304 }
305 }
306
load()307 jint load() {
308 if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
309 return static_cast<jint>(multi_modal_runner_->load());
310 } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
311 return static_cast<jint>(runner_->load());
312 }
313 return static_cast<jint>(Error::InvalidArgument);
314 }
315
registerNatives()316 static void registerNatives() {
317 registerHybrid({
318 makeNativeMethod("initHybrid", ExecuTorchLlamaJni::initHybrid),
319 makeNativeMethod("generate", ExecuTorchLlamaJni::generate),
320 makeNativeMethod("stop", ExecuTorchLlamaJni::stop),
321 makeNativeMethod("load", ExecuTorchLlamaJni::load),
322 makeNativeMethod(
323 "prefillImagesNative", ExecuTorchLlamaJni::prefill_images),
324 makeNativeMethod(
325 "prefillPromptNative", ExecuTorchLlamaJni::prefill_prompt),
326 makeNativeMethod(
327 "generateFromPos", ExecuTorchLlamaJni::generate_from_pos),
328 });
329 }
330 };
331
332 } // namespace executorch_jni
333
// Registers the LlamaModule native methods with the JVM; intended to be
// called from the library's JNI_OnLoad entry point.
void register_natives_for_llama() {
  executorch_jni::ExecuTorchLlamaJni::registerNatives();
}
337