xref: /aosp_15_r20/external/executorch/extension/android/jni/jni_layer.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <cassert>
10 #include <chrono>
11 #include <iostream>
12 #include <memory>
13 #include <sstream>
14 #include <string>
15 #include <unordered_map>
16 #include <vector>
17 
18 #include "jni_layer_constants.h"
19 
20 #include <executorch/extension/android/jni/log.h>
21 #include <executorch/extension/module/module.h>
22 #include <executorch/extension/runner_util/inputs.h>
23 #include <executorch/extension/tensor/tensor.h>
24 #include <executorch/runtime/core/portable_type/tensor_impl.h>
25 #include <executorch/runtime/platform/log.h>
26 #include <executorch/runtime/platform/platform.h>
27 #include <executorch/runtime/platform/runtime.h>
28 
29 #ifdef ET_USE_THREADPOOL
30 #include <cpuinfo.h>
31 #include <executorch/extension/threadpool/threadpool.h>
32 #endif
33 
34 #include <fbjni/ByteBuffer.h>
35 #include <fbjni/fbjni.h>
36 
37 using namespace executorch::extension;
38 using namespace torch::executor;
39 
40 namespace executorch::extension {
41 class TensorHybrid : public facebook::jni::HybridClass<TensorHybrid> {
42  public:
43   constexpr static const char* kJavaDescriptor =
44       "Lorg/pytorch/executorch/Tensor;";
45 
TensorHybrid(exec_aten::Tensor tensor)46   explicit TensorHybrid(exec_aten::Tensor tensor) {}
47 
48   static facebook::jni::local_ref<TensorHybrid::javaobject>
newJTensorFromTensor(const exec_aten::Tensor & tensor)49   newJTensorFromTensor(const exec_aten::Tensor& tensor) {
50     // Java wrapper currently only supports contiguous tensors.
51 
52     const auto scalarType = tensor.scalar_type();
53 
54     if (scalar_type_to_java_dtype.count(scalarType) == 0) {
55       facebook::jni::throwNewJavaException(
56           facebook::jni::gJavaLangIllegalArgumentException,
57           "exec_aten::Tensor scalar type %d is not supported on java side",
58           scalarType);
59     }
60     int jdtype = scalar_type_to_java_dtype.at(scalarType);
61 
62     const auto& tensor_shape = tensor.sizes();
63     std::vector<jlong> tensor_shape_vec;
64     for (const auto& s : tensor_shape) {
65       tensor_shape_vec.push_back(s);
66     }
67     facebook::jni::local_ref<jlongArray> jTensorShape =
68         facebook::jni::make_long_array(tensor_shape_vec.size());
69     jTensorShape->setRegion(
70         0, tensor_shape_vec.size(), tensor_shape_vec.data());
71 
72     static auto cls = TensorHybrid::javaClassStatic();
73     // Note: this is safe as long as the data stored in tensor is valid; the
74     // data won't go out of scope as long as the Method for the inference is
75     // valid and there is no other inference call. Java layer picks up this
76     // value immediately so the data is valid.
77     facebook::jni::local_ref<facebook::jni::JByteBuffer> jTensorBuffer =
78         facebook::jni::JByteBuffer::wrapBytes(
79             (uint8_t*)tensor.data_ptr(), tensor.nbytes());
80     jTensorBuffer->order(facebook::jni::JByteOrder::nativeOrder());
81 
82     static const auto jMethodNewTensor =
83         cls->getStaticMethod<facebook::jni::local_ref<TensorHybrid::javaobject>(
84             facebook::jni::alias_ref<facebook::jni::JByteBuffer>,
85             facebook::jni::alias_ref<jlongArray>,
86             jint,
87             facebook::jni::alias_ref<jhybriddata>)>("nativeNewTensor");
88     return jMethodNewTensor(
89         cls, jTensorBuffer, jTensorShape, jdtype, makeCxxInstance(tensor));
90   }
91 
92  private:
93   friend HybridBase;
94 };
95 
96 class JEValue : public facebook::jni::JavaClass<JEValue> {
97  public:
98   constexpr static const char* kJavaDescriptor =
99       "Lorg/pytorch/executorch/EValue;";
100 
101   constexpr static int kTypeCodeTensor = 1;
102   constexpr static int kTypeCodeString = 2;
103   constexpr static int kTypeCodeDouble = 3;
104   constexpr static int kTypeCodeInt = 4;
105   constexpr static int kTypeCodeBool = 5;
106 
newJEValueFromEValue(EValue evalue)107   static facebook::jni::local_ref<JEValue> newJEValueFromEValue(EValue evalue) {
108     if (evalue.isTensor()) {
109       static auto jMethodTensor =
110           JEValue::javaClassStatic()
111               ->getStaticMethod<facebook::jni::local_ref<JEValue>(
112                   facebook::jni::local_ref<TensorHybrid::javaobject>)>("from");
113       return jMethodTensor(
114           JEValue::javaClassStatic(),
115           TensorHybrid::newJTensorFromTensor(evalue.toTensor()));
116     } else if (evalue.isInt()) {
117       static auto jMethodTensor =
118           JEValue::javaClassStatic()
119               ->getStaticMethod<facebook::jni::local_ref<JEValue>(jlong)>(
120                   "from");
121       return jMethodTensor(JEValue::javaClassStatic(), evalue.toInt());
122     } else if (evalue.isDouble()) {
123       static auto jMethodTensor =
124           JEValue::javaClassStatic()
125               ->getStaticMethod<facebook::jni::local_ref<JEValue>(jdouble)>(
126                   "from");
127       return jMethodTensor(JEValue::javaClassStatic(), evalue.toDouble());
128     } else if (evalue.isBool()) {
129       static auto jMethodTensor =
130           JEValue::javaClassStatic()
131               ->getStaticMethod<facebook::jni::local_ref<JEValue>(jboolean)>(
132                   "from");
133       return jMethodTensor(JEValue::javaClassStatic(), evalue.toBool());
134     } else if (evalue.isString()) {
135       static auto jMethodTensor =
136           JEValue::javaClassStatic()
137               ->getStaticMethod<facebook::jni::local_ref<JEValue>(
138                   facebook::jni::local_ref<jstring>)>("from");
139       std::string str =
140           std::string(evalue.toString().begin(), evalue.toString().end());
141       return jMethodTensor(
142           JEValue::javaClassStatic(), facebook::jni::make_jstring(str));
143     }
144     facebook::jni::throwNewJavaException(
145         facebook::jni::gJavaLangIllegalArgumentException,
146         "Unsupported EValue type: %d",
147         evalue.tag);
148   }
149 
JEValueToTensorImpl(facebook::jni::alias_ref<JEValue> JEValue)150   static TensorPtr JEValueToTensorImpl(
151       facebook::jni::alias_ref<JEValue> JEValue) {
152     static const auto typeCodeField =
153         JEValue::javaClassStatic()->getField<jint>("mTypeCode");
154     const auto typeCode = JEValue->getFieldValue(typeCodeField);
155     if (JEValue::kTypeCodeTensor == typeCode) {
156       static const auto jMethodGetTensor =
157           JEValue::javaClassStatic()
158               ->getMethod<facebook::jni::alias_ref<TensorHybrid::javaobject>()>(
159                   "toTensor");
160       auto jtensor = jMethodGetTensor(JEValue);
161 
162       static auto cls = TensorHybrid::javaClassStatic();
163       static const auto dtypeMethod = cls->getMethod<jint()>("dtypeJniCode");
164       jint jdtype = dtypeMethod(jtensor);
165 
166       static const auto shapeField = cls->getField<jlongArray>("shape");
167       auto jshape = jtensor->getFieldValue(shapeField);
168 
169       static auto dataBufferMethod = cls->getMethod<
170           facebook::jni::local_ref<facebook::jni::JBuffer::javaobject>()>(
171           "getRawDataBuffer");
172       facebook::jni::local_ref<facebook::jni::JBuffer> jbuffer =
173           dataBufferMethod(jtensor);
174 
175       const auto rank = jshape->size();
176 
177       const auto shapeArr = jshape->getRegion(0, rank);
178       std::vector<exec_aten::SizesType> shape_vec;
179       shape_vec.reserve(rank);
180 
181       auto numel = 1;
182       for (int i = 0; i < rank; i++) {
183         shape_vec.push_back(shapeArr[i]);
184       }
185       for (int i = rank - 1; i >= 0; --i) {
186         numel *= shapeArr[i];
187       }
188       JNIEnv* jni = facebook::jni::Environment::current();
189       if (java_dtype_to_scalar_type.count(jdtype) == 0) {
190         facebook::jni::throwNewJavaException(
191             facebook::jni::gJavaLangIllegalArgumentException,
192             "Unknown Tensor jdtype %d",
193             jdtype);
194       }
195       ScalarType scalar_type = java_dtype_to_scalar_type.at(jdtype);
196       const auto dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get());
197       if (dataCapacity != numel) {
198         facebook::jni::throwNewJavaException(
199             facebook::jni::gJavaLangIllegalArgumentException,
200             "Tensor dimensions(elements number:%d inconsistent with buffer capacity(%d)",
201             numel,
202             dataCapacity);
203       }
204       return from_blob(
205           jni->GetDirectBufferAddress(jbuffer.get()), shape_vec, scalar_type);
206     }
207     facebook::jni::throwNewJavaException(
208         facebook::jni::gJavaLangIllegalArgumentException,
209         "Unknown EValue typeCode %d",
210         typeCode);
211   }
212 };
213 
214 class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
215  private:
216   friend HybridBase;
217   std::unique_ptr<Module> module_;
218 
219  public:
220   constexpr static auto kJavaDescriptor = "Lorg/pytorch/executorch/NativePeer;";
221 
initHybrid(facebook::jni::alias_ref<jclass>,facebook::jni::alias_ref<jstring> modelPath,jint loadMode)222   static facebook::jni::local_ref<jhybriddata> initHybrid(
223       facebook::jni::alias_ref<jclass>,
224       facebook::jni::alias_ref<jstring> modelPath,
225       jint loadMode) {
226     return makeCxxInstance(modelPath, loadMode);
227   }
228 
ExecuTorchJni(facebook::jni::alias_ref<jstring> modelPath,jint loadMode)229   ExecuTorchJni(facebook::jni::alias_ref<jstring> modelPath, jint loadMode) {
230     Module::LoadMode load_mode = Module::LoadMode::Mmap;
231     if (loadMode == 0) {
232       load_mode = Module::LoadMode::File;
233     } else if (loadMode == 1) {
234       load_mode = Module::LoadMode::Mmap;
235     } else if (loadMode == 2) {
236       load_mode = Module::LoadMode::MmapUseMlock;
237     } else if (loadMode == 3) {
238       load_mode = Module::LoadMode::MmapUseMlockIgnoreErrors;
239     }
240 
241     module_ = std::make_unique<Module>(modelPath->toStdString(), load_mode);
242 
243 #ifdef ET_USE_THREADPOOL
244     // Default to using cores/2 threadpool threads. The long-term plan is to
245     // improve performant core detection in CPUInfo, but for now we can use
246     // cores/2 as a sane default.
247     //
248     // Based on testing, this is almost universally faster than using all
249     // cores, as efficiency cores can be quite slow. In extreme cases, using
250     // all cores can be 10x slower than using cores/2.
251     //
252     // TODO Allow overriding this default from Java.
253     auto threadpool = executorch::extension::threadpool::get_threadpool();
254     if (threadpool) {
255       int thread_count = cpuinfo_get_processors_count() / 2;
256       if (thread_count > 0) {
257         threadpool->_unsafe_reset_threadpool(thread_count);
258       }
259     }
260 #endif
261   }
262 
forward(facebook::jni::alias_ref<facebook::jni::JArrayClass<JEValue::javaobject>::javaobject> jinputs)263   facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> forward(
264       facebook::jni::alias_ref<
265           facebook::jni::JArrayClass<JEValue::javaobject>::javaobject>
266           jinputs) {
267     return execute_method("forward", jinputs);
268   }
269 
execute(facebook::jni::alias_ref<jstring> methodName,facebook::jni::alias_ref<facebook::jni::JArrayClass<JEValue::javaobject>::javaobject> jinputs)270   facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> execute(
271       facebook::jni::alias_ref<jstring> methodName,
272       facebook::jni::alias_ref<
273           facebook::jni::JArrayClass<JEValue::javaobject>::javaobject>
274           jinputs) {
275     return execute_method(methodName->toStdString(), jinputs);
276   }
277 
load_method(facebook::jni::alias_ref<jstring> methodName)278   jint load_method(facebook::jni::alias_ref<jstring> methodName) {
279     return static_cast<jint>(module_->load_method(methodName->toStdString()));
280   }
281 
execute_method(std::string method,facebook::jni::alias_ref<facebook::jni::JArrayClass<JEValue::javaobject>::javaobject> jinputs)282   facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> execute_method(
283       std::string method,
284       facebook::jni::alias_ref<
285           facebook::jni::JArrayClass<JEValue::javaobject>::javaobject>
286           jinputs) {
287     // If no inputs is given, it will run with sample inputs (ones)
288     if (jinputs->size() == 0) {
289       if (module_->load_method(method) != Error::Ok) {
290         return {};
291       }
292       auto&& underlying_method = module_->methods_[method].method;
293       auto&& buf = prepare_input_tensors(*underlying_method);
294       auto result = underlying_method->execute();
295       if (result != Error::Ok) {
296         return {};
297       }
298       facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> jresult =
299           facebook::jni::JArrayClass<JEValue>::newArray(
300               underlying_method->outputs_size());
301 
302       for (int i = 0; i < underlying_method->outputs_size(); i++) {
303         auto jevalue =
304             JEValue::newJEValueFromEValue(underlying_method->get_output(i));
305         jresult->setElement(i, *jevalue);
306       }
307       return jresult;
308     }
309 
310     std::vector<EValue> evalues;
311     std::vector<TensorPtr> tensors;
312 
313     static const auto typeCodeField =
314         JEValue::javaClassStatic()->getField<jint>("mTypeCode");
315 
316     for (int i = 0; i < jinputs->size(); i++) {
317       auto jevalue = jinputs->getElement(i);
318       const auto typeCode = jevalue->getFieldValue(typeCodeField);
319       if (typeCode == JEValue::kTypeCodeTensor) {
320         tensors.emplace_back(JEValue::JEValueToTensorImpl(jevalue));
321         evalues.emplace_back(tensors.back());
322       } else if (typeCode == JEValue::kTypeCodeInt) {
323         int64_t value = jevalue->getFieldValue(typeCodeField);
324         evalues.emplace_back(value);
325       } else if (typeCode == JEValue::kTypeCodeDouble) {
326         double value = jevalue->getFieldValue(typeCodeField);
327         evalues.emplace_back(value);
328       } else if (typeCode == JEValue::kTypeCodeBool) {
329         bool value = jevalue->getFieldValue(typeCodeField);
330         evalues.emplace_back(value);
331       }
332     }
333 
334 #ifdef EXECUTORCH_ANDROID_PROFILING
335     auto start = std::chrono::high_resolution_clock::now();
336     auto result = module_->execute(method, evalues);
337     auto end = std::chrono::high_resolution_clock::now();
338     auto duration =
339         std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
340             .count();
341     ET_LOG(Debug, "Execution time: %lld ms.", duration);
342 
343 #else
344     auto result = module_->execute(method, evalues);
345 
346 #endif
347 
348     if (!result.ok()) {
349       facebook::jni::throwNewJavaException(
350           "java/lang/Exception",
351           "Execution of method %s failed with status 0x%" PRIx32,
352           method.c_str(),
353           static_cast<error_code_t>(result.error()));
354       return {};
355     }
356 
357     facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> jresult =
358         facebook::jni::JArrayClass<JEValue>::newArray(result.get().size());
359 
360     for (int i = 0; i < result.get().size(); i++) {
361       auto jevalue = JEValue::newJEValueFromEValue(result.get()[i]);
362       jresult->setElement(i, *jevalue);
363     }
364 
365     return jresult;
366   }
367 
368   facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>>
readLogBuffer()369   readLogBuffer() {
370 #ifdef __ANDROID__
371 
372     facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> ret;
373 
374     access_log_buffer([&](std::vector<log_entry>& buffer) {
375       const auto size = buffer.size();
376       ret = facebook::jni::JArrayClass<jstring>::newArray(size);
377       for (auto i = 0u; i < size; i++) {
378         const auto& entry = buffer[i];
379         // Format the log entry as "[TIMESTAMP FUNCTION FILE:LINE] LEVEL
380         // MESSAGE".
381         std::stringstream ss;
382         ss << "[" << entry.timestamp << " " << entry.function << " "
383            << entry.filename << ":" << entry.line << "] "
384            << static_cast<char>(entry.level) << " " << entry.message;
385 
386         facebook::jni::local_ref<facebook::jni::JString> jstr_message =
387             facebook::jni::make_jstring(ss.str().c_str());
388         (*ret)[i] = jstr_message;
389       }
390     });
391 
392     return ret;
393 #else
394     return facebook::jni::JArrayClass<String>::newArray(0);
395 #endif
396   }
397 
registerNatives()398   static void registerNatives() {
399     registerHybrid({
400         makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid),
401         makeNativeMethod("forward", ExecuTorchJni::forward),
402         makeNativeMethod("execute", ExecuTorchJni::execute),
403         makeNativeMethod("loadMethod", ExecuTorchJni::load_method),
404         makeNativeMethod("readLogBuffer", ExecuTorchJni::readLogBuffer),
405     });
406   }
407 };
408 } // namespace executorch::extension
409 
#ifndef EXECUTORCH_BUILD_LLAMA_JNI
// Llama JNI bindings are not part of this build, so registering them is a
// no-op.
void register_natives_for_llama() {}
#else
// Provided by the llama JNI translation unit when it is built.
extern void register_natives_for_llama();
#endif
JNI_OnLoad(JavaVM * vm,void *)416 JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
417   return facebook::jni::initialize(vm, [] {
418     executorch::extension::ExecuTorchJni::registerNatives();
419     register_natives_for_llama();
420   });
421 }
422