/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cassert>
#include <chrono>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include "jni_layer_constants.h"

#include <executorch/extension/android/jni/log.h>
#include <executorch/extension/module/module.h>
#include <executorch/extension/runner_util/inputs.h>
#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/core/portable_type/tensor_impl.h>
#include <executorch/runtime/platform/log.h>
#include <executorch/runtime/platform/platform.h>
#include <executorch/runtime/platform/runtime.h>

#ifdef ET_USE_THREADPOOL
#include <cpuinfo.h>
#include <executorch/extension/threadpool/threadpool.h>
#endif

#include <fbjni/ByteBuffer.h>
#include <fbjni/fbjni.h>

using namespace executorch::extension;
using namespace torch::executor;

namespace executorch::extension {
class TensorHybrid : public facebook::jni::HybridClass<TensorHybrid> {
 public:
  constexpr static const char* kJavaDescriptor =
      "Lorg/pytorch/executorch/Tensor;";

  explicit TensorHybrid(exec_aten::Tensor tensor) {}

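  // Wraps a native tensor in a Java org.pytorch.executorch.Tensor. The shape
  // and dtype are copied, but the element data is exposed zero-copy through a
  // direct ByteBuffer that aliases the native storage.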
  static facebook::jni::local_ref<TensorHybrid::javaobject>
  newJTensorFromTensor(const exec_aten::Tensor& tensor) {
    // Java wrapper currently only supports contiguous tensors.

    const auto scalarType = tensor.scalar_type();

    if (scalar_type_to_java_dtype.count(scalarType) == 0) {
      facebook::jni::throwNewJavaException(
          facebook::jni::gJavaLangIllegalArgumentException,
          "exec_aten::Tensor scalar type %d is not supported on java side",
          scalarType);
    }
    int jdtype = scalar_type_to_java_dtype.at(scalarType);

    const auto& tensor_shape = tensor.sizes();
    std::vector<jlong> tensor_shape_vec;
    for (const auto& s : tensor_shape) {
      tensor_shape_vec.push_back(s);
    }
    facebook::jni::local_ref<jlongArray> jTensorShape =
        facebook::jni::make_long_array(tensor_shape_vec.size());
    jTensorShape->setRegion(
        0, tensor_shape_vec.size(), tensor_shape_vec.data());

    static auto cls = TensorHybrid::javaClassStatic();
    // Note: this is safe as long as the data stored in tensor is valid; the
    // data won't go out of scope as long as the Method for the inference is
    // valid and there is no other inference call. Java layer picks up this
    // value immediately so the data is valid.
    facebook::jni::local_ref<facebook::jni::JByteBuffer> jTensorBuffer =
        facebook::jni::JByteBuffer::wrapBytes(
            (uint8_t*)tensor.data_ptr(), tensor.nbytes());
    jTensorBuffer->order(facebook::jni::JByteOrder::nativeOrder());

    static const auto jMethodNewTensor =
        cls->getStaticMethod<facebook::jni::local_ref<TensorHybrid::javaobject>(
            facebook::jni::alias_ref<facebook::jni::JByteBuffer>,
            facebook::jni::alias_ref<jlongArray>,
            jint,
            facebook::jni::alias_ref<jhybriddata>)>("nativeNewTensor");
    return jMethodNewTensor(
        cls, jTensorBuffer, jTensorShape, jdtype, makeCxxInstance(tensor));
  }

 private:
  friend HybridBase;
};

class JEValue : public facebook::jni::JavaClass<JEValue> {
 public:
  constexpr static const char* kJavaDescriptor =
      "Lorg/pytorch/executorch/EValue;";

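  // Type codes mirror the tags used by the Java-side EValue; they are assumed
  // to match the constants defined in org.pytorch.executorch.EValue.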
  constexpr static int kTypeCodeTensor = 1;
  constexpr static int kTypeCodeString = 2;
  constexpr static int kTypeCodeDouble = 3;
  constexpr static int kTypeCodeInt = 4;
  constexpr static int kTypeCodeBool = 5;

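  // Converts a native EValue into a Java EValue by dispatching to the
  // matching EValue.from(...) overload for the payload type.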
  static facebook::jni::local_ref<JEValue> newJEValueFromEValue(EValue evalue) {
    if (evalue.isTensor()) {
      static auto jMethodTensor =
          JEValue::javaClassStatic()
              ->getStaticMethod<facebook::jni::local_ref<JEValue>(
                  facebook::jni::local_ref<TensorHybrid::javaobject>)>("from");
      return jMethodTensor(
          JEValue::javaClassStatic(),
          TensorHybrid::newJTensorFromTensor(evalue.toTensor()));
    } else if (evalue.isInt()) {
      static auto jMethodTensor =
          JEValue::javaClassStatic()
              ->getStaticMethod<facebook::jni::local_ref<JEValue>(jlong)>(
                  "from");
      return jMethodTensor(JEValue::javaClassStatic(), evalue.toInt());
    } else if (evalue.isDouble()) {
      static auto jMethodTensor =
          JEValue::javaClassStatic()
              ->getStaticMethod<facebook::jni::local_ref<JEValue>(jdouble)>(
                  "from");
      return jMethodTensor(JEValue::javaClassStatic(), evalue.toDouble());
    } else if (evalue.isBool()) {
      static auto jMethodTensor =
          JEValue::javaClassStatic()
              ->getStaticMethod<facebook::jni::local_ref<JEValue>(jboolean)>(
                  "from");
      return jMethodTensor(JEValue::javaClassStatic(), evalue.toBool());
    } else if (evalue.isString()) {
      static auto jMethodTensor =
          JEValue::javaClassStatic()
              ->getStaticMethod<facebook::jni::local_ref<JEValue>(
                  facebook::jni::local_ref<jstring>)>("from");
      std::string str =
          std::string(evalue.toString().begin(), evalue.toString().end());
      return jMethodTensor(
          JEValue::javaClassStatic(), facebook::jni::make_jstring(str));
    }
    facebook::jni::throwNewJavaException(
        facebook::jni::gJavaLangIllegalArgumentException,
        "Unsupported EValue type: %d",
        evalue.tag);
  }

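  // Converts a Java EValue that wraps a Tensor into a native TensorPtr. The
  // returned tensor aliases the Java tensor's direct ByteBuffer (zero-copy),
  // so the Java buffer must remain alive while the native tensor is in use.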
  static TensorPtr JEValueToTensorImpl(
      facebook::jni::alias_ref<JEValue> JEValue) {
    static const auto typeCodeField =
        JEValue::javaClassStatic()->getField<jint>("mTypeCode");
    const auto typeCode = JEValue->getFieldValue(typeCodeField);
    if (JEValue::kTypeCodeTensor == typeCode) {
      static const auto jMethodGetTensor =
          JEValue::javaClassStatic()
              ->getMethod<facebook::jni::alias_ref<TensorHybrid::javaobject>()>(
                  "toTensor");
      auto jtensor = jMethodGetTensor(JEValue);

      static auto cls = TensorHybrid::javaClassStatic();
      static const auto dtypeMethod = cls->getMethod<jint()>("dtypeJniCode");
      jint jdtype = dtypeMethod(jtensor);

      static const auto shapeField = cls->getField<jlongArray>("shape");
      auto jshape = jtensor->getFieldValue(shapeField);

      static auto dataBufferMethod = cls->getMethod<
          facebook::jni::local_ref<facebook::jni::JBuffer::javaobject>()>(
          "getRawDataBuffer");
      facebook::jni::local_ref<facebook::jni::JBuffer> jbuffer =
          dataBufferMethod(jtensor);

      const auto rank = jshape->size();

      const auto shapeArr = jshape->getRegion(0, rank);
      std::vector<exec_aten::SizesType> shape_vec;
      shape_vec.reserve(rank);

      auto numel = 1;
      for (int i = 0; i < rank; i++) {
        shape_vec.push_back(shapeArr[i]);
        numel *= shapeArr[i];
      }
      JNIEnv* jni = facebook::jni::Environment::current();
      if (java_dtype_to_scalar_type.count(jdtype) == 0) {
        facebook::jni::throwNewJavaException(
            facebook::jni::gJavaLangIllegalArgumentException,
            "Unknown Tensor jdtype %d",
            jdtype);
      }
      ScalarType scalar_type = java_dtype_to_scalar_type.at(jdtype);
      const auto dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get());
      if (dataCapacity != numel) {
        facebook::jni::throwNewJavaException(
            facebook::jni::gJavaLangIllegalArgumentException,
            "Tensor dimensions (element count %d) inconsistent with buffer capacity (%d)",
            numel,
            dataCapacity);
      }
      return from_blob(
          jni->GetDirectBufferAddress(jbuffer.get()), shape_vec, scalar_type);
    }
    facebook::jni::throwNewJavaException(
        facebook::jni::gJavaLangIllegalArgumentException,
        "Unknown EValue typeCode %d",
        typeCode);
  }
};

class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
 private:
  friend HybridBase;
  std::unique_ptr<Module> module_;

 public:
  constexpr static auto kJavaDescriptor = "Lorg/pytorch/executorch/NativePeer;";

  static facebook::jni::local_ref<jhybriddata> initHybrid(
      facebook::jni::alias_ref<jclass>,
      facebook::jni::alias_ref<jstring> modelPath,
      jint loadMode) {
    return makeCxxInstance(modelPath, loadMode);
  }

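  // The loadMode integer passed from Java selects the Module::LoadMode:
  // 0 = File, 1 = Mmap (default), 2 = MmapUseMlock,
  // 3 = MmapUseMlockIgnoreErrors.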
  ExecuTorchJni(facebook::jni::alias_ref<jstring> modelPath, jint loadMode) {
    Module::LoadMode load_mode = Module::LoadMode::Mmap;
    if (loadMode == 0) {
      load_mode = Module::LoadMode::File;
    } else if (loadMode == 1) {
      load_mode = Module::LoadMode::Mmap;
    } else if (loadMode == 2) {
      load_mode = Module::LoadMode::MmapUseMlock;
    } else if (loadMode == 3) {
      load_mode = Module::LoadMode::MmapUseMlockIgnoreErrors;
    }

    module_ = std::make_unique<Module>(modelPath->toStdString(), load_mode);

#ifdef ET_USE_THREADPOOL
    // Default to using cores/2 threadpool threads. The long-term plan is to
    // improve detection of performant cores in CPUInfo, but for now cores/2
    // is a sane default.
    //
    // Based on testing, this is almost universally faster than using all
    // cores, as efficiency cores can be quite slow. In extreme cases, using
    // all cores can be 10x slower than using cores/2.
    //
    // TODO Allow overriding this default from Java.
    auto threadpool = executorch::extension::threadpool::get_threadpool();
    if (threadpool) {
      int thread_count = cpuinfo_get_processors_count() / 2;
      if (thread_count > 0) {
        threadpool->_unsafe_reset_threadpool(thread_count);
      }
    }
#endif
  }

  facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> forward(
      facebook::jni::alias_ref<
          facebook::jni::JArrayClass<JEValue::javaobject>::javaobject>
          jinputs) {
    return execute_method("forward", jinputs);
  }

  facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> execute(
      facebook::jni::alias_ref<jstring> methodName,
      facebook::jni::alias_ref<
          facebook::jni::JArrayClass<JEValue::javaobject>::javaobject>
          jinputs) {
    return execute_method(methodName->toStdString(), jinputs);
  }

  jint load_method(facebook::jni::alias_ref<jstring> methodName) {
    return static_cast<jint>(module_->load_method(methodName->toStdString()));
  }

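  // Runs the named method: converts the Java EValue inputs to native EValues,
  // executes the method on module_, and converts the outputs back into a Java
  // EValue array. On failure, a Java exception is thrown or an empty result is
  // returned.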
  facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> execute_method(
      std::string method,
      facebook::jni::alias_ref<
          facebook::jni::JArrayClass<JEValue::javaobject>::javaobject>
          jinputs) {
    // If no inputs are given, run with sample inputs (ones).
    if (jinputs->size() == 0) {
      if (module_->load_method(method) != Error::Ok) {
        return {};
      }
      auto&& underlying_method = module_->methods_[method].method;
      auto&& buf = prepare_input_tensors(*underlying_method);
      auto result = underlying_method->execute();
      if (result != Error::Ok) {
        return {};
      }
      facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> jresult =
          facebook::jni::JArrayClass<JEValue>::newArray(
              underlying_method->outputs_size());

      for (int i = 0; i < underlying_method->outputs_size(); i++) {
        auto jevalue =
            JEValue::newJEValueFromEValue(underlying_method->get_output(i));
        jresult->setElement(i, *jevalue);
      }
      return jresult;
    }

    std::vector<EValue> evalues;
    std::vector<TensorPtr> tensors;

    static const auto typeCodeField =
        JEValue::javaClassStatic()->getField<jint>("mTypeCode");

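    // Convert each Java EValue into a native EValue. Tensor inputs are kept
    // alive in `tensors` so their storage outlives the execute call.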
    for (int i = 0; i < jinputs->size(); i++) {
      auto jevalue = jinputs->getElement(i);
      const auto typeCode = jevalue->getFieldValue(typeCodeField);
      if (typeCode == JEValue::kTypeCodeTensor) {
        tensors.emplace_back(JEValue::JEValueToTensorImpl(jevalue));
        evalues.emplace_back(tensors.back());
      } else if (typeCode == JEValue::kTypeCodeInt) {
        int64_t value = jevalue->getFieldValue(typeCodeField);
        evalues.emplace_back(value);
      } else if (typeCode == JEValue::kTypeCodeDouble) {
        double value = jevalue->getFieldValue(typeCodeField);
        evalues.emplace_back(value);
      } else if (typeCode == JEValue::kTypeCodeBool) {
        bool value = jevalue->getFieldValue(typeCodeField);
        evalues.emplace_back(value);
      }
    }

#ifdef EXECUTORCH_ANDROID_PROFILING
    auto start = std::chrono::high_resolution_clock::now();
    auto result = module_->execute(method, evalues);
    auto end = std::chrono::high_resolution_clock::now();
    auto duration =
        std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
            .count();
    ET_LOG(Debug, "Execution time: %lld ms.", duration);

#else
    auto result = module_->execute(method, evalues);

#endif

    if (!result.ok()) {
      facebook::jni::throwNewJavaException(
          "java/lang/Exception",
          "Execution of method %s failed with status 0x%" PRIx32,
          method.c_str(),
          static_cast<error_code_t>(result.error()));
      return {};
    }

    facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> jresult =
        facebook::jni::JArrayClass<JEValue>::newArray(result.get().size());

    for (int i = 0; i < result.get().size(); i++) {
      auto jevalue = JEValue::newJEValueFromEValue(result.get()[i]);
      jresult->setElement(i, *jevalue);
    }

    return jresult;
  }

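  // Drains the in-process log buffer and returns each entry as a formatted
  // Java string. Only populated on Android builds; elsewhere an empty array
  // is returned.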
  facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>>
  readLogBuffer() {
#ifdef __ANDROID__

    facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> ret;

    access_log_buffer([&](std::vector<log_entry>& buffer) {
      const auto size = buffer.size();
      ret = facebook::jni::JArrayClass<jstring>::newArray(size);
      for (auto i = 0u; i < size; i++) {
        const auto& entry = buffer[i];
        // Format the log entry as "[TIMESTAMP FUNCTION FILE:LINE] LEVEL
        // MESSAGE".
        std::stringstream ss;
        ss << "[" << entry.timestamp << " " << entry.function << " "
           << entry.filename << ":" << entry.line << "] "
           << static_cast<char>(entry.level) << " " << entry.message;

        facebook::jni::local_ref<facebook::jni::JString> jstr_message =
            facebook::jni::make_jstring(ss.str().c_str());
        (*ret)[i] = jstr_message;
      }
    });

    return ret;
#else
    return facebook::jni::JArrayClass<jstring>::newArray(0);
#endif
  }

  static void registerNatives() {
    registerHybrid({
        makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid),
        makeNativeMethod("forward", ExecuTorchJni::forward),
        makeNativeMethod("execute", ExecuTorchJni::execute),
        makeNativeMethod("loadMethod", ExecuTorchJni::load_method),
        makeNativeMethod("readLogBuffer", ExecuTorchJni::readLogBuffer),
    });
  }
};
} // namespace executorch::extension

#ifdef EXECUTORCH_BUILD_LLAMA_JNI
extern void register_natives_for_llama();
#else
// No op if we don't build llama
void register_natives_for_llama() {}
#endif
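// JNI entry point: called by the JVM when the native library is loaded.
// Registers the ExecuTorch native methods and, when built, the llama ones.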
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
  return facebook::jni::initialize(vm, [] {
    executorch::extension::ExecuTorchJni::registerNatives();
    register_natives_for_llama();
  });
}