#include <android/log.h>
#include <pthread.h>
#include <unistd.h>

#include <cassert>
#include <cmath>
#include <exception>
#include <sstream>
#include <vector>
7 #define ALOGI(...) \
8 __android_log_print(ANDROID_LOG_INFO, "PyTorchTestAppJni", __VA_ARGS__)
9 #define ALOGE(...) \
10 __android_log_print(ANDROID_LOG_ERROR, "PyTorchTestAppJni", __VA_ARGS__)
11
12 #include "jni.h"
13
14 #include <torch/script.h>
15
16 namespace pytorch_testapp_jni {
17 namespace {
18
19 template <typename T>
log(const char * m,T t)20 void log(const char* m, T t) {
21 std::ostringstream os;
22 os << t << std::endl;
23 ALOGI("%s %s", m, os.str().c_str());
24 }
25
26 struct JITCallGuard {
27 c10::InferenceMode guard;
28 torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
29 };
30 } // namespace
31
loadAndForwardModel(JNIEnv * env,jclass,jstring jModelPath)32 static void loadAndForwardModel(JNIEnv* env, jclass, jstring jModelPath) {
33 const char* modelPath = env->GetStringUTFChars(jModelPath, 0);
34 assert(modelPath);
35
36 // To load torchscript model for mobile we need set these guards,
37 // because mobile build doesn't support features like autograd for smaller
38 // build size which is placed in `struct JITCallGuard` in this example. It may
39 // change in future, you can track the latest changes keeping an eye in
40 // android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp
41 JITCallGuard guard;
42 torch::jit::Module module = torch::jit::load(modelPath);
43 module.eval();
44 torch::Tensor t = torch::randn({1, 3, 224, 224});
45 log("input tensor:", t);
46 c10::IValue t_out = module.forward({t});
47 log("output tensor:", t_out);
48 env->ReleaseStringUTFChars(jModelPath, modelPath);
49 }
50 } // namespace pytorch_testapp_jni
51
JNI_OnLoad(JavaVM * vm,void *)52 JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void*) {
53 JNIEnv* env;
54 if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6) != JNI_OK) {
55 return JNI_ERR;
56 }
57
58 jclass c =
59 env->FindClass("org/pytorch/testapp/LibtorchNativeClient$NativePeer");
60 if (c == nullptr) {
61 return JNI_ERR;
62 }
63
64 static const JNINativeMethod methods[] = {
65 {"loadAndForwardModel",
66 "(Ljava/lang/String;)V",
67 (void*)pytorch_testapp_jni::loadAndForwardModel},
68 };
69 int rc = env->RegisterNatives(
70 c, methods, sizeof(methods) / sizeof(JNINativeMethod));
71
72 if (rc != JNI_OK) {
73 return rc;
74 }
75
76 return JNI_VERSION_1_6;
77 }
78