xref: /aosp_15_r20/external/pytorch/android/test_app/app/src/main/cpp/pytorch_testapp_jni.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <android/log.h>
#include <pthread.h>
#include <unistd.h>
#include <cassert>
#include <cmath>
#include <exception>
#include <vector>
7 #define ALOGI(...) \
8   __android_log_print(ANDROID_LOG_INFO, "PyTorchTestAppJni", __VA_ARGS__)
9 #define ALOGE(...) \
10   __android_log_print(ANDROID_LOG_ERROR, "PyTorchTestAppJni", __VA_ARGS__)
11 
12 #include "jni.h"
13 
14 #include <torch/script.h>
15 
16 namespace pytorch_testapp_jni {
17 namespace {
18 
19 template <typename T>
log(const char * m,T t)20 void log(const char* m, T t) {
21   std::ostringstream os;
22   os << t << std::endl;
23   ALOGI("%s %s", m, os.str().c_str());
24 }
25 
26 struct JITCallGuard {
27   c10::InferenceMode guard;
28   torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
29 };
30 } // namespace
31 
loadAndForwardModel(JNIEnv * env,jclass,jstring jModelPath)32 static void loadAndForwardModel(JNIEnv* env, jclass, jstring jModelPath) {
33   const char* modelPath = env->GetStringUTFChars(jModelPath, 0);
34   assert(modelPath);
35 
36   // To load torchscript model for mobile we need set these guards,
37   // because mobile build doesn't support features like autograd for smaller
38   // build size which is placed in `struct JITCallGuard` in this example. It may
39   // change in future, you can track the latest changes keeping an eye in
40   // android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp
41   JITCallGuard guard;
42   torch::jit::Module module = torch::jit::load(modelPath);
43   module.eval();
44   torch::Tensor t = torch::randn({1, 3, 224, 224});
45   log("input tensor:", t);
46   c10::IValue t_out = module.forward({t});
47   log("output tensor:", t_out);
48   env->ReleaseStringUTFChars(jModelPath, modelPath);
49 }
50 } // namespace pytorch_testapp_jni
51 
JNI_OnLoad(JavaVM * vm,void *)52 JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void*) {
53   JNIEnv* env;
54   if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6) != JNI_OK) {
55     return JNI_ERR;
56   }
57 
58   jclass c =
59       env->FindClass("org/pytorch/testapp/LibtorchNativeClient$NativePeer");
60   if (c == nullptr) {
61     return JNI_ERR;
62   }
63 
64   static const JNINativeMethod methods[] = {
65       {"loadAndForwardModel",
66        "(Ljava/lang/String;)V",
67        (void*)pytorch_testapp_jni::loadAndForwardModel},
68   };
69   int rc = env->RegisterNatives(
70       c, methods, sizeof(methods) / sizeof(JNINativeMethod));
71 
72   if (rc != JNI_OK) {
73     return rc;
74   }
75 
76   return JNI_VERSION_1_6;
77 }
78