xref: /aosp_15_r20/external/pytorch/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <cassert>
2 #include <iostream>
3 #include <memory>
4 #include <string>
5 
6 #include <fbjni/ByteBuffer.h>
7 #include <fbjni/fbjni.h>
8 
9 #include <ATen/record_function.h>
10 #include <torch/csrc/jit/runtime/print_handler.h>
11 #include <torch/script.h>
12 #include "caffe2/serialize/read_adapter_interface.h"
13 
14 #include "pytorch_jni_common.h"
15 
16 #ifdef __ANDROID__
17 #include <android/asset_manager.h>
18 #include <android/asset_manager_jni.h>
19 #include <android/log.h>
20 #endif
21 
22 namespace pytorch_jni {
23 
24 namespace {
25 
26 struct JITCallGuard {
27   // Inference only workload.
28   c10::InferenceMode guard;
29   // Disable graph optimizer to ensure list of unused ops are not changed for
30   // custom mobile build.
31   torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
32 };
33 
34 } // namespace
35 
36 class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
37  private:
38   friend HybridBase;
39   torch::jit::Module module_;
40   c10::DeviceType deviceType_;
41 
42  public:
43   constexpr static auto kJavaDescriptor = "Lorg/pytorch/NativePeer;";
44 
initHybrid(facebook::jni::alias_ref<jclass>,facebook::jni::alias_ref<jstring> modelPath,facebook::jni::alias_ref<facebook::jni::JMap<facebook::jni::JString,facebook::jni::JString>> extraFiles,jint device)45   static facebook::jni::local_ref<jhybriddata> initHybrid(
46       facebook::jni::alias_ref<jclass>,
47       facebook::jni::alias_ref<jstring> modelPath,
48       facebook::jni::alias_ref<
49           facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>>
50           extraFiles,
51       jint device) {
52     return makeCxxInstance(modelPath, extraFiles, device);
53   }
54 
55 #ifdef __ANDROID__
initHybridAndroidAsset(facebook::jni::alias_ref<jclass>,facebook::jni::alias_ref<jstring> assetName,facebook::jni::alias_ref<jobject> assetManager,jint device)56   static facebook::jni::local_ref<jhybriddata> initHybridAndroidAsset(
57       facebook::jni::alias_ref<jclass>,
58       facebook::jni::alias_ref<jstring> assetName,
59       facebook::jni::alias_ref<jobject> assetManager,
60       jint device) {
61     return makeCxxInstance(assetName, assetManager, device);
62   }
63 #endif
64 
65 #ifdef TRACE_ENABLED
onFunctionEnter(const at::RecordFunction & fn)66   static std::unique_ptr<at::ObserverContext> onFunctionEnter(
67       const at::RecordFunction& fn) {
68     Trace::beginSection(fn.name().str());
69     return nullptr;
70   }
71 
onFunctionExit(const at::RecordFunction &,at::ObserverContext *)72   static void onFunctionExit(const at::RecordFunction&, at::ObserverContext*) {
73     Trace::endSection();
74   }
75 #endif
76 
preModuleLoadSetupOnce()77   static void preModuleLoadSetupOnce() {
78     auto qengines = at::globalContext().supportedQEngines();
79     if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) !=
80         qengines.end()) {
81       at::globalContext().setQEngine(at::QEngine::QNNPACK);
82     }
83 
84 #ifdef __ANDROID__
85     torch::jit::setPrintHandler([](const std::string& s) {
86       __android_log_print(ANDROID_LOG_DEBUG, "pytorch-print", "%s", s.c_str());
87     });
88 #endif
89 
90 #ifdef TRACE_ENABLED
91     at::addGlobalCallback(
92         at::RecordFunctionCallback(&onFunctionEnter, &onFunctionExit)
93             .scopes({RecordScope::FUNCTION, RecordScope::USER_SCOPE}));
94 #endif
95   }
96 
preModuleLoadSetup()97   void preModuleLoadSetup() {
98     static const int once = []() {
99       preModuleLoadSetupOnce();
100       return 0;
101     }();
102     ((void)once);
103   }
104 
PytorchJni(facebook::jni::alias_ref<jstring> modelPath,facebook::jni::alias_ref<facebook::jni::JMap<facebook::jni::JString,facebook::jni::JString>> extraFiles,jint device)105   PytorchJni(
106       facebook::jni::alias_ref<jstring> modelPath,
107       facebook::jni::alias_ref<
108           facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>>
109           extraFiles,
110       jint device) {
111     preModuleLoadSetup();
112     JITCallGuard guard;
113     std::unordered_map<std::string, std::string> extra_files;
114     const auto has_extra = extraFiles && extraFiles->size() > 0;
115     if (has_extra) {
116       for (const auto& e : *extraFiles) {
117         extra_files[e.first->toStdString()] = "";
118       }
119     }
120     deviceType_ = deviceJniCodeToDeviceType(device);
121     module_ = torch::jit::load(
122         std::move(modelPath->toStdString()), std::nullopt, extra_files);
123     if (has_extra) {
124       static auto putMethod =
125           facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>::
126               javaClassStatic()
127                   ->template getMethod<facebook::jni::alias_ref<jobject>(
128                       facebook::jni::alias_ref<jobject>,
129                       facebook::jni::alias_ref<jobject>)>("put");
130       for (const auto& ef : extra_files) {
131         putMethod(
132             extraFiles,
133             facebook::jni::make_jstring(ef.first),
134             facebook::jni::make_jstring(ef.second));
135       }
136     }
137 
138     module_.eval();
139   }
140 
141 #ifdef __ANDROID__
PytorchJni(facebook::jni::alias_ref<jstring> assetName,facebook::jni::alias_ref<jobject> assetManager,jint device)142   PytorchJni(
143       facebook::jni::alias_ref<jstring> assetName,
144       facebook::jni::alias_ref<jobject> assetManager,
145       jint device) {
146     preModuleLoadSetup();
147     JNIEnv* env = facebook::jni::Environment::current();
148     AAssetManager* mgr = AAssetManager_fromJava(env, assetManager.get());
149     if (!mgr) {
150       facebook::jni::throwNewJavaException(
151           facebook::jni::gJavaLangIllegalArgumentException,
152           "Unable to get asset manager");
153     }
154     AAsset* asset = AAssetManager_open(
155         mgr, assetName->toStdString().c_str(), AASSET_MODE_BUFFER);
156     if (!asset) {
157       facebook::jni::throwNewJavaException(
158           facebook::jni::gJavaLangIllegalArgumentException,
159           "Failed to open asset '%s'",
160           assetName->toStdString().c_str());
161     }
162     auto assetBuffer = AAsset_getBuffer(asset);
163     if (!assetBuffer) {
164       facebook::jni::throwNewJavaException(
165           facebook::jni::gJavaLangIllegalArgumentException,
166           "Could not get buffer for asset '%s'",
167           assetName->toStdString().c_str());
168     }
169     JITCallGuard guard;
170     module_ = torch::jit::load(std::make_unique<MemoryReadAdapter>(
171         assetBuffer, AAsset_getLength(asset)));
172     AAsset_close(asset);
173     module_.eval();
174     deviceType_ = deviceJniCodeToDeviceType(device);
175   }
176 #endif
177 
registerNatives()178   static void registerNatives() {
179     registerHybrid({
180         makeNativeMethod("initHybrid", PytorchJni::initHybrid),
181 #ifdef __ANDROID__
182         makeNativeMethod(
183             "initHybridAndroidAsset", PytorchJni::initHybridAndroidAsset),
184 #endif
185         makeNativeMethod("forward", PytorchJni::forward),
186         makeNativeMethod("runMethod", PytorchJni::runMethod),
187     });
188   }
189 
forward(facebook::jni::alias_ref<facebook::jni::JArrayClass<JIValue::javaobject>::javaobject> jinputs)190   facebook::jni::local_ref<JIValue> forward(
191       facebook::jni::alias_ref<
192           facebook::jni::JArrayClass<JIValue::javaobject>::javaobject>
193           jinputs) {
194     Trace _s{"jni::Module::forward"};
195     std::vector<at::IValue> inputs{};
196     size_t n = jinputs->size();
197     inputs.reserve(n);
198     for (const auto i : c10::irange(n)) {
199       at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
200       inputs.push_back(std::move(atIValue));
201     }
202     auto output = [&]() {
203       JITCallGuard guard;
204       return module_.forward(std::move(inputs));
205     }();
206     return JIValue::newJIValueFromAtIValue(output);
207   }
208 
runMethod(facebook::jni::alias_ref<facebook::jni::JString::javaobject> jmethodName,facebook::jni::alias_ref<facebook::jni::JArrayClass<JIValue::javaobject>::javaobject> jinputs)209   facebook::jni::local_ref<JIValue> runMethod(
210       facebook::jni::alias_ref<facebook::jni::JString::javaobject> jmethodName,
211       facebook::jni::alias_ref<
212           facebook::jni::JArrayClass<JIValue::javaobject>::javaobject>
213           jinputs) {
214     std::string methodName = jmethodName->toStdString();
215 
216     std::vector<at::IValue> inputs{};
217     size_t n = jinputs->size();
218     inputs.reserve(n);
219     for (const auto i : c10::irange(n)) {
220       at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
221       inputs.push_back(std::move(atIValue));
222     }
223     if (auto method = module_.find_method(methodName)) {
224       auto output = [&]() {
225         JITCallGuard guard;
226         return (*method)(std::move(inputs));
227       }();
228       return JIValue::newJIValueFromAtIValue(output);
229     }
230 
231     facebook::jni::throwNewJavaException(
232         facebook::jni::gJavaLangIllegalArgumentException,
233         "Undefined method %s",
234         methodName.c_str());
235   }
236 };
237 
238 } // namespace pytorch_jni
239 
JNI_OnLoad(JavaVM * vm,void *)240 JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
241   return facebook::jni::initialize(vm, [] {
242     pytorch_jni::common_registerNatives();
243     pytorch_jni::PytorchJni::registerNatives();
244   });
245 }
246