1 #include <cassert>
2 #include <iostream>
3 #include <memory>
4 #include <string>
5
6 #include <c10/core/MemoryFormat.h>
7 #include <c10/util/irange.h>
8
9 #include <fbjni/ByteBuffer.h>
10 #include <fbjni/fbjni.h>
11
12 #include "pytorch_jni_common.h"
13 #if defined(__ANDROID__)
14 #ifndef USE_PTHREADPOOL
15 #define USE_PTHREADPOOL
16 #endif /* USE_PTHREADPOOL */
17 #include <caffe2/utils/threadpool/pthreadpool-cpp.h>
18 #endif
19
20 namespace pytorch_jni {
21
deviceJniCodeToDeviceType(jint deviceJniCode)22 c10::DeviceType deviceJniCodeToDeviceType(jint deviceJniCode) {
23 if (deviceJniCode == kDeviceCPU) {
24 return at::kCPU;
25 } else if (deviceJniCode == kDeviceVulkan) {
26 return at::kVulkan;
27 }
28
29 facebook::jni::throwNewJavaException(
30 facebook::jni::gJavaLangIllegalArgumentException, "Unknown device");
31 }
32
33 bool Trace::is_initialized_ = false;
34
35 #if defined(TRACE_ENABLED) && defined(__ANDROID__)
36 Trace::fp_ATrace_beginSection Trace::ATrace_beginSection;
37 Trace::fp_ATrace_endSection Trace::ATrace_endSection;
38 #endif
39
init()40 void Trace::init() {
41 #if defined(TRACE_ENABLED) && defined(__ANDROID__)
42 void* lib = dlopen("libandroid.so", RTLD_NOW || RTLD_LOCAL);
43 if (lib != NULL) {
44 Trace::ATrace_beginSection = reinterpret_cast<fp_ATrace_beginSection>(
45 dlsym(lib, "ATrace_beginSection"));
46 Trace::ATrace_endSection =
47 reinterpret_cast<fp_ATrace_endSection>(dlsym(lib, "ATrace_endSection"));
48 }
49 #endif
50 }
51
// Tensor dtype codes exchanged with the Java side.
// NOTE: Codes must be kept in sync with DType.java.
// NOTE: Never serialize these, because they can change between releases.
static constexpr int kTensorDTypeUInt8 = 1;
static constexpr int kTensorDTypeInt8 = 2;
static constexpr int kTensorDTypeInt32 = 3;
static constexpr int kTensorDTypeFloat32 = 4;
static constexpr int kTensorDTypeInt64 = 5;
static constexpr int kTensorDTypeFloat64 = 6;

// Tensor memory-format codes exchanged with the Java side (returned by the
// Java Tensor's memoryFormatJniCode()); presumably mirrored in a Java
// MemoryFormat definition — TODO confirm and keep in sync.
static constexpr int kTensorMemoryFormatContiguous = 1;
static constexpr int kTensorMemoryFormatChannelsLast = 2;
static constexpr int kTensorMemoryFormatChannelsLast3d = 3;
64
65 template <typename K = jobject, typename V = jobject>
66 struct JHashMap
67 : facebook::jni::JavaClass<JHashMap<K, V>, facebook::jni::JMap<K, V>> {
68 constexpr static auto kJavaDescriptor = "Ljava/util/HashMap;";
69
70 using Super =
71 facebook::jni::JavaClass<JHashMap<K, V>, facebook::jni::JMap<K, V>>;
72
createpytorch_jni::JHashMap73 static facebook::jni::local_ref<JHashMap<K, V>> create() {
74 return Super::newInstance();
75 }
76
putpytorch_jni::JHashMap77 void put(
78 facebook::jni::alias_ref<facebook::jni::JObject::javaobject> key,
79 facebook::jni::alias_ref<facebook::jni::JObject::javaobject> value) {
80 static auto putMethod =
81 Super::javaClassStatic()
82 ->template getMethod<facebook::jni::alias_ref<
83 facebook::jni::JObject::javaobject>(
84 facebook::jni::alias_ref<facebook::jni::JObject::javaobject>,
85 facebook::jni::alias_ref<facebook::jni::JObject::javaobject>)>(
86 "put");
87 putMethod(Super::self(), key, value);
88 }
89 };
90
newAtTensor(facebook::jni::alias_ref<facebook::jni::JBuffer> jbuffer,facebook::jni::alias_ref<jlongArray> jshape,jint jdtype,jint jmemoryFormat)91 static at::Tensor newAtTensor(
92 facebook::jni::alias_ref<facebook::jni::JBuffer> jbuffer,
93 facebook::jni::alias_ref<jlongArray> jshape,
94 jint jdtype,
95 jint jmemoryFormat) {
96 const auto rank = jshape->size();
97 const auto shapeArr = jshape->getRegion(0, rank);
98 std::vector<int64_t> shapeVec{};
99 shapeVec.reserve(rank);
100 auto numel = 1;
101 for (const auto i : c10::irange(rank)) {
102 shapeVec.push_back(shapeArr[i]);
103 numel *= shapeArr[i];
104 }
105 JNIEnv* jni = facebook::jni::Environment::current();
106 caffe2::TypeMeta typeMeta{};
107 int dataElementSizeBytes = 0;
108 if (kTensorDTypeFloat32 == jdtype) {
109 dataElementSizeBytes = 4;
110 typeMeta = caffe2::TypeMeta::Make<float>();
111 } else if (kTensorDTypeInt32 == jdtype) {
112 dataElementSizeBytes = 4;
113 typeMeta = caffe2::TypeMeta::Make<int32_t>();
114 } else if (kTensorDTypeInt8 == jdtype) {
115 dataElementSizeBytes = 1;
116 typeMeta = caffe2::TypeMeta::Make<int8_t>();
117 } else if (kTensorDTypeUInt8 == jdtype) {
118 dataElementSizeBytes = 1;
119 typeMeta = caffe2::TypeMeta::Make<uint8_t>();
120 } else if (kTensorDTypeFloat64 == jdtype) {
121 dataElementSizeBytes = 8;
122 typeMeta = caffe2::TypeMeta::Make<double>();
123 } else if (kTensorDTypeInt64 == jdtype) {
124 dataElementSizeBytes = 8;
125 typeMeta = caffe2::TypeMeta::Make<int64_t>();
126 } else {
127 facebook::jni::throwNewJavaException(
128 facebook::jni::gJavaLangIllegalArgumentException,
129 "Unknown Tensor jdtype %d",
130 jdtype);
131 }
132 const auto dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get());
133 if (dataCapacity != numel) {
134 facebook::jni::throwNewJavaException(
135 facebook::jni::gJavaLangIllegalArgumentException,
136 "Tensor dimensions(elements number:%d, element byte size:%d, total "
137 "bytes:%d) inconsistent with buffer capacity(%d)",
138 numel,
139 dataElementSizeBytes,
140 numel * dataElementSizeBytes,
141 dataCapacity);
142 }
143
144 if (jmemoryFormat == kTensorMemoryFormatChannelsLast) {
145 auto sizes = torch::IntArrayRef(shapeVec);
146 return torch::from_blob(
147 jni->GetDirectBufferAddress(jbuffer.get()),
148 sizes,
149 torch::IntArrayRef(c10::get_channels_last_strides_2d(sizes)),
150 at::TensorOptions(typeMeta).memory_format(
151 at::MemoryFormat::ChannelsLast));
152 } else if (jmemoryFormat == kTensorMemoryFormatChannelsLast3d) {
153 auto sizes = torch::IntArrayRef(shapeVec);
154 return torch::from_blob(
155 jni->GetDirectBufferAddress(jbuffer.get()),
156 sizes,
157 torch::IntArrayRef(c10::get_channels_last_strides_3d(sizes)),
158 at::TensorOptions(typeMeta).memory_format(
159 at::MemoryFormat::ChannelsLast3d));
160 }
161 return torch::from_blob(
162 jni->GetDirectBufferAddress(jbuffer.get()),
163 torch::IntArrayRef(shapeVec),
164 at::TensorOptions(typeMeta));
165 }
166
167 class TensorHybrid : public facebook::jni::HybridClass<TensorHybrid> {
168 public:
169 constexpr static const char* kJavaDescriptor = "Lorg/pytorch/Tensor;";
170
TensorHybrid(at::Tensor tensor)171 explicit TensorHybrid(at::Tensor tensor) : tensor_(tensor) {}
172
initHybrid(facebook::jni::alias_ref<TensorHybrid::javaobject> jTensorThis)173 static facebook::jni::local_ref<TensorHybrid::jhybriddata> initHybrid(
174 facebook::jni::alias_ref<TensorHybrid::javaobject> jTensorThis) {
175 static auto cls = TensorHybrid::javaClassStatic();
176 static const auto jMethodDTypeCode = cls->getMethod<jint()>("dtypeJniCode");
177 static const auto jMethodMemoryFormatCode =
178 cls->getMethod<jint()>("memoryFormatJniCode");
179 static const auto jFieldShape = cls->getField<jlongArray>("shape");
180 static const auto jMethodGetDataBuffer = cls->getMethod<
181 facebook::jni::local_ref<facebook::jni::JBuffer::javaobject>()>(
182 "getRawDataBuffer");
183
184 at::Tensor tensor = newAtTensor(
185 jMethodGetDataBuffer(jTensorThis),
186 jTensorThis->getFieldValue(jFieldShape),
187 jMethodDTypeCode(jTensorThis),
188 jMethodMemoryFormatCode(jTensorThis));
189 return makeCxxInstance(std::move(tensor));
190 }
191
192 static facebook::jni::local_ref<TensorHybrid::javaobject>
newJTensorFromAtTensor(const at::Tensor & input_tensor)193 newJTensorFromAtTensor(const at::Tensor& input_tensor) {
194 // Java wrapper currently only supports contiguous tensors.
195
196 int jmemoryFormat = 0;
197 at::Tensor tensor{};
198 if (input_tensor.is_contiguous(at::MemoryFormat::ChannelsLast)) {
199 tensor = input_tensor;
200 jmemoryFormat = kTensorMemoryFormatChannelsLast;
201 } else if (input_tensor.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
202 tensor = input_tensor;
203 jmemoryFormat = kTensorMemoryFormatChannelsLast3d;
204 } else {
205 tensor = input_tensor.contiguous();
206 jmemoryFormat = kTensorMemoryFormatContiguous;
207 }
208
209 const auto scalarType = tensor.scalar_type();
210 int jdtype = 0;
211 if (at::kFloat == scalarType) {
212 jdtype = kTensorDTypeFloat32;
213 } else if (at::kInt == scalarType) {
214 jdtype = kTensorDTypeInt32;
215 } else if (at::kByte == scalarType) {
216 jdtype = kTensorDTypeUInt8;
217 } else if (at::kChar == scalarType) {
218 jdtype = kTensorDTypeInt8;
219 } else if (at::kLong == scalarType) {
220 jdtype = kTensorDTypeInt64;
221 } else if (at::kDouble == scalarType) {
222 jdtype = kTensorDTypeFloat64;
223 } else {
224 facebook::jni::throwNewJavaException(
225 facebook::jni::gJavaLangIllegalArgumentException,
226 "at::Tensor scalar type %s is not supported on java side",
227 c10::toString(scalarType));
228 }
229
230 const auto& tensorShape = tensor.sizes();
231 std::vector<jlong> tensorShapeVec;
232 for (const auto& s : tensorShape) {
233 tensorShapeVec.push_back(s);
234 }
235 facebook::jni::local_ref<jlongArray> jTensorShape =
236 facebook::jni::make_long_array(tensorShapeVec.size());
237 jTensorShape->setRegion(0, tensorShapeVec.size(), tensorShapeVec.data());
238
239 static auto cls = TensorHybrid::javaClassStatic();
240 facebook::jni::local_ref<facebook::jni::JByteBuffer> jTensorBuffer =
241 facebook::jni::JByteBuffer::wrapBytes(
242 (uint8_t*)tensor.data_ptr(), tensor.nbytes());
243 jTensorBuffer->order(facebook::jni::JByteOrder::nativeOrder());
244
245 static const auto jMethodNewTensor =
246 cls->getStaticMethod<facebook::jni::local_ref<TensorHybrid::javaobject>(
247 facebook::jni::alias_ref<facebook::jni::JByteBuffer>,
248 facebook::jni::alias_ref<jlongArray>,
249 jint,
250 jint,
251 facebook::jni::alias_ref<jhybriddata>)>("nativeNewTensor");
252 return jMethodNewTensor(
253 cls,
254 jTensorBuffer,
255 jTensorShape,
256 jdtype,
257 jmemoryFormat,
258 makeCxxInstance(tensor));
259 }
260
newAtTensorFromJTensor(facebook::jni::alias_ref<TensorHybrid::javaobject> jtensor)261 static at::Tensor newAtTensorFromJTensor(
262 facebook::jni::alias_ref<TensorHybrid::javaobject> jtensor) {
263 static auto cls = TensorHybrid::javaClassStatic();
264 static const auto dtypeMethod = cls->getMethod<jint()>("dtypeJniCode");
265 jint jdtype = dtypeMethod(jtensor);
266
267 static const auto memoryFormatMethod =
268 cls->getMethod<jint()>("memoryFormatJniCode");
269 jint jmemoryFormat = memoryFormatMethod(jtensor);
270
271 static const auto shapeField = cls->getField<jlongArray>("shape");
272 auto jshape = jtensor->getFieldValue(shapeField);
273
274 static auto dataBufferMethod = cls->getMethod<
275 facebook::jni::local_ref<facebook::jni::JBuffer::javaobject>()>(
276 "getRawDataBuffer");
277 facebook::jni::local_ref<facebook::jni::JBuffer> jbuffer =
278 dataBufferMethod(jtensor);
279 return newAtTensor(jbuffer, jshape, jdtype, jmemoryFormat);
280 }
281
tensor() const282 at::Tensor tensor() const {
283 return tensor_;
284 }
285
286 private:
287 friend HybridBase;
288 at::Tensor tensor_;
289 };
290
newJIValueFromStringDict(c10::Dict<c10::IValue,c10::IValue> dict)291 facebook::jni::local_ref<JIValue> JIValue::newJIValueFromStringDict(
292 c10::Dict<c10::IValue, c10::IValue> dict) {
293 static auto jMethodDictStringKey =
294 JIValue::javaClassStatic()
295 ->getStaticMethod<facebook::jni::local_ref<JIValue>(
296 facebook::jni::alias_ref<facebook::jni::JMap<
297 facebook::jni::alias_ref<facebook::jni::JString::javaobject>,
298 facebook::jni::alias_ref<JIValue::javaobject>>>)>(
299 "dictStringKeyFrom");
300
301 auto jmap = JHashMap<
302 facebook::jni::alias_ref<facebook::jni::JString::javaobject>,
303 facebook::jni::alias_ref<JIValue::javaobject>>::create();
304 for (auto& pair : dict) {
305 jmap->put(
306 facebook::jni::make_jstring(pair.key().toStringRef()),
307 JIValue::newJIValueFromAtIValue(pair.value()));
308 }
309 return jMethodDictStringKey(JIValue::javaClassStatic(), jmap);
310 }
311
newJIValueFromIntDict(c10::Dict<c10::IValue,c10::IValue> dict)312 facebook::jni::local_ref<JIValue> JIValue::newJIValueFromIntDict(
313 c10::Dict<c10::IValue, c10::IValue> dict) {
314 static auto jMethodDictLongKey =
315 JIValue::javaClassStatic()
316 ->getStaticMethod<facebook::jni::local_ref<JIValue>(
317 facebook::jni::alias_ref<facebook::jni::JMap<
318 facebook::jni::alias_ref<facebook::jni::JLong::javaobject>,
319 facebook::jni::alias_ref<JIValue::javaobject>>>)>(
320 "dictLongKeyFrom");
321 auto jmap = JHashMap<
322 facebook::jni::alias_ref<facebook::jni::JLong::javaobject>,
323 facebook::jni::alias_ref<JIValue::javaobject>>::create();
324 for (auto& pair : dict) {
325 jmap->put(
326 facebook::jni::JLong::valueOf(pair.key().toInt()),
327 JIValue::newJIValueFromAtIValue(pair.value()));
328 }
329 return jMethodDictLongKey(JIValue::javaClassStatic(), jmap);
330 }
331
newJIValueFromAtIValue(const at::IValue & ivalue,DictCallback stringDictCallback,DictCallback intDictCallback)332 facebook::jni::local_ref<JIValue> JIValue::newJIValueFromAtIValue(
333 const at::IValue& ivalue,
334 DictCallback stringDictCallback,
335 DictCallback intDictCallback) {
336 Trace _s{"jni::JIValue::newJIValueFromAtIValue"};
337 if (ivalue.isNone()) {
338 static auto jMethodOptionalNull =
339 JIValue::javaClassStatic()
340 ->getStaticMethod<facebook::jni::local_ref<JIValue>()>(
341 "optionalNull");
342 return jMethodOptionalNull(JIValue::javaClassStatic());
343 } else if (ivalue.isTensor()) {
344 static auto jMethodTensor =
345 JIValue::javaClassStatic()
346 ->getStaticMethod<facebook::jni::local_ref<JIValue>(
347 facebook::jni::local_ref<TensorHybrid::javaobject>)>("from");
348 const auto& tensor = ivalue.toTensor();
349 return jMethodTensor(
350 JIValue::javaClassStatic(),
351 TensorHybrid::newJTensorFromAtTensor(tensor));
352 } else if (ivalue.isBool()) {
353 static auto jMethodBool =
354 JIValue::javaClassStatic()
355 ->getStaticMethod<facebook::jni::local_ref<JIValue>(jboolean)>(
356 "from");
357 return jMethodBool(JIValue::javaClassStatic(), ivalue.toBool());
358 } else if (ivalue.isInt()) {
359 static auto jMethodInt =
360 JIValue::javaClassStatic()
361 ->getStaticMethod<facebook::jni::local_ref<JIValue>(jlong)>("from");
362 return jMethodInt(JIValue::javaClassStatic(), ivalue.toInt());
363 } else if (ivalue.isDouble()) {
364 static auto jMethodDouble =
365 JIValue::javaClassStatic()
366 ->getStaticMethod<facebook::jni::local_ref<JIValue>(jdouble)>(
367 "from");
368 return jMethodDouble(JIValue::javaClassStatic(), ivalue.toDouble());
369 } else if (ivalue.isString()) {
370 static auto jMethodString =
371 JIValue::javaClassStatic()
372 ->getStaticMethod<facebook::jni::local_ref<JIValue>(
373 facebook::jni::alias_ref<facebook::jni::JString::javaobject>)>(
374 "from");
375 return jMethodString(
376 JIValue::javaClassStatic(),
377 facebook::jni::make_jstring(ivalue.toStringRef()));
378 } else if (ivalue.isTuple()) {
379 auto elementsVec = ivalue.toTupleRef().elements();
380 static auto jMethodTupleArr =
381 JIValue::javaClassStatic()
382 ->getStaticMethod<facebook::jni::local_ref<JIValue>(
383 facebook::jni::alias_ref<facebook::jni::JArrayClass<
384 JIValue::javaobject>::javaobject>)>("tupleFrom");
385 auto jElementsArray =
386 facebook::jni::JArrayClass<JIValue::javaobject>::newArray(
387 elementsVec.size());
388 auto index = 0;
389 for (const auto& e : elementsVec) {
390 (*jElementsArray)[index++] = JIValue::newJIValueFromAtIValue(e);
391 }
392 return jMethodTupleArr(JIValue::javaClassStatic(), jElementsArray);
393 } else if (ivalue.isBoolList()) {
394 auto list = ivalue.toBoolList();
395 static auto jMethodBoolListArr =
396 JIValue::javaClassStatic()
397 ->getStaticMethod<facebook::jni::local_ref<JIValue>(
398 facebook::jni::alias_ref<jbooleanArray>)>("listFrom");
399 size_t n = list.size();
400 auto jArray = facebook::jni::make_boolean_array(n);
401 auto jArrayPinned = jArray->pin();
402 auto index = 0;
403 for (const auto& e : list) {
404 jArrayPinned[index++] = e;
405 }
406 return jMethodBoolListArr(JIValue::javaClassStatic(), jArray);
407 } else if (ivalue.isIntList()) {
408 auto list = ivalue.toIntList();
409 static auto jMethodLongListArr =
410 JIValue::javaClassStatic()
411 ->getStaticMethod<facebook::jni::local_ref<JIValue>(
412 facebook::jni::alias_ref<jlongArray>)>("listFrom");
413 size_t n = list.size();
414 auto jArray = facebook::jni::make_long_array(n);
415 auto jArrayPinned = jArray->pin();
416 auto index = 0;
417 for (const auto& e : list) {
418 jArrayPinned[index++] = e;
419 }
420 return jMethodLongListArr(JIValue::javaClassStatic(), jArray);
421 } else if (ivalue.isDoubleList()) {
422 auto list = ivalue.toDoubleList();
423 static auto jMethoDoubleListArr =
424 JIValue::javaClassStatic()
425 ->getStaticMethod<facebook::jni::local_ref<JIValue>(
426 facebook::jni::alias_ref<jdoubleArray>)>("listFrom");
427 size_t n = list.size();
428 auto jArray = facebook::jni::make_double_array(n);
429 auto jArrayPinned = jArray->pin();
430 auto index = 0;
431 for (const auto& e : list) {
432 jArrayPinned[index++] = e;
433 }
434 return jMethoDoubleListArr(JIValue::javaClassStatic(), jArray);
435 } else if (ivalue.isTensorList()) {
436 auto list = ivalue.toTensorList();
437 static auto jMethodTensorListArr =
438 JIValue::javaClassStatic()
439 ->getStaticMethod<facebook::jni::local_ref<JIValue>(
440 facebook::jni::alias_ref<facebook::jni::JArrayClass<
441 TensorHybrid::javaobject>::javaobject>)>("listFrom");
442 auto jArray =
443 facebook::jni::JArrayClass<TensorHybrid::javaobject>::newArray(
444 list.size());
445 auto index = 0;
446 for (const auto& e : list) {
447 (*jArray)[index++] = TensorHybrid::newJTensorFromAtTensor(e);
448 }
449 return jMethodTensorListArr(JIValue::javaClassStatic(), jArray);
450 } else if (ivalue.isList()) {
451 auto list = ivalue.toList();
452 static auto jMethodListArr =
453 JIValue::javaClassStatic()
454 ->getStaticMethod<facebook::jni::local_ref<JIValue>(
455 facebook::jni::alias_ref<facebook::jni::JArrayClass<
456 JIValue::javaobject>::javaobject>)>("listFrom");
457 auto jArray =
458 facebook::jni::JArrayClass<JIValue::javaobject>::newArray(list.size());
459 auto index = 0;
460 for (const auto& e : list) {
461 (*jArray)[index++] = JIValue::newJIValueFromAtIValue(e);
462 }
463 return jMethodListArr(JIValue::javaClassStatic(), jArray);
464 } else if (ivalue.isGenericDict()) {
465 auto dict = ivalue.toGenericDict();
466 const auto keyType = dict.keyType();
467
468 if (!keyType) {
469 facebook::jni::throwNewJavaException(
470 facebook::jni::gJavaLangIllegalArgumentException,
471 "Unknown IValue-Dict key type");
472 }
473
474 if (*keyType == *c10::StringType::get()) {
475 return stringDictCallback(std::move(dict));
476 } else if (*keyType == *c10::IntType::get()) {
477 return intDictCallback(std::move(dict));
478 }
479
480 facebook::jni::throwNewJavaException(
481 facebook::jni::gJavaLangIllegalArgumentException,
482 "Unsupported IValue-Dict key type: %s",
483 keyType->str().c_str());
484 }
485
486 facebook::jni::throwNewJavaException(
487 facebook::jni::gJavaLangIllegalArgumentException,
488 "Unsupported IValue type %s",
489 ivalue.tagKind().c_str());
490 }
491
JIValueToAtIValue(facebook::jni::alias_ref<JIValue> jivalue)492 at::IValue JIValue::JIValueToAtIValue(
493 facebook::jni::alias_ref<JIValue> jivalue) {
494 Trace _s{"jni::JIValue::JIValueToAtIValue"};
495 static const auto typeCodeField =
496 JIValue::javaClassStatic()->getField<jint>("mTypeCode");
497 const auto typeCode = jivalue->getFieldValue(typeCodeField);
498 if (JIValue::kTypeCodeNull == typeCode) {
499 return at::IValue{};
500 } else if (JIValue::kTypeCodeTensor == typeCode) {
501 static const auto jMethodGetTensor =
502 JIValue::javaClassStatic()
503 ->getMethod<facebook::jni::alias_ref<TensorHybrid::javaobject>()>(
504 "toTensor");
505 return TensorHybrid::newAtTensorFromJTensor(jMethodGetTensor(jivalue));
506 } else if (JIValue::kTypeCodeBool == typeCode) {
507 static const auto jMethodGetBool =
508 JIValue::javaClassStatic()->getMethod<jboolean()>("toBool");
509 // explicit cast to bool as jboolean is defined as uint8_t, IValue ctor
510 // for int will be called for jboolean
511 bool b = jMethodGetBool(jivalue);
512 return at::IValue{b};
513 } else if (JIValue::kTypeCodeLong == typeCode) {
514 static const auto jMethodGetLong =
515 JIValue::javaClassStatic()->getMethod<jlong()>("toLong");
516 return at::IValue{(int64_t)jMethodGetLong(jivalue)};
517 } else if (JIValue::kTypeCodeDouble == typeCode) {
518 static const auto jMethodGetDouble =
519 JIValue::javaClassStatic()->getMethod<jdouble()>("toDouble");
520 return at::IValue{jMethodGetDouble(jivalue)};
521 } else if (JIValue::kTypeCodeString == typeCode) {
522 static const auto jMethodGetString =
523 JIValue::javaClassStatic()->getMethod<jstring()>("toStr");
524 return at::IValue{jMethodGetString(jivalue)->toStdString()};
525 } else if (JIValue::kTypeCodeTuple == typeCode) {
526 static const auto jMethodGetTuple =
527 JIValue::javaClassStatic()
528 ->getMethod<
529 facebook::jni::JArrayClass<JIValue::javaobject>::javaobject()>(
530 "toTuple");
531 auto jarray = jMethodGetTuple(jivalue);
532 size_t n = jarray->size();
533
534 std::vector<at::IValue> elements;
535 elements.reserve(n);
536 for (const auto i : c10::irange(n)) {
537 auto jivalue_element = jarray->getElement(i);
538 auto element = JIValue::JIValueToAtIValue(jivalue_element);
539 elements.push_back(std::move(element));
540 }
541 return c10::ivalue::Tuple::create(std::move(elements));
542 } else if (JIValue::kTypeCodeBoolList == typeCode) {
543 static const auto jMethodGetBoolList =
544 JIValue::javaClassStatic()->getMethod<jbooleanArray()>("toBoolList");
545 auto jArray = jMethodGetBoolList(jivalue);
546 auto jArrayPinned = jArray->pin();
547 size_t n = jArrayPinned.size();
548 c10::List<bool> list{};
549 list.reserve(n);
550 for (const auto i : c10::irange(n)) {
551 list.push_back(jArrayPinned[i]);
552 }
553 return at::IValue{std::move(list)};
554 } else if (JIValue::kTypeCodeLongList == typeCode) {
555 static const auto jMethodGetLongList =
556 JIValue::javaClassStatic()->getMethod<jlongArray()>("toLongList");
557 auto jArray = jMethodGetLongList(jivalue);
558 auto jArrayPinned = jArray->pin();
559 size_t n = jArrayPinned.size();
560 c10::List<int64_t> list{};
561 list.reserve(n);
562 for (const auto i : c10::irange(n)) {
563 list.push_back(jArrayPinned[i]);
564 }
565 return at::IValue{std::move(list)};
566 } else if (JIValue::kTypeCodeDoubleList == typeCode) {
567 static const auto jMethodGetDoubleList =
568 JIValue::javaClassStatic()->getMethod<jdoubleArray()>("toDoubleList");
569 auto jArray = jMethodGetDoubleList(jivalue);
570 auto jArrayPinned = jArray->pin();
571 size_t n = jArrayPinned.size();
572 c10::List<double> list{};
573 list.reserve(n);
574 for (const auto i : c10::irange(n)) {
575 list.push_back(jArrayPinned[i]);
576 }
577 return at::IValue{std::move(list)};
578 } else if (JIValue::kTypeCodeTensorList == typeCode) {
579 static const auto jMethodGetTensorList =
580 JIValue::javaClassStatic()
581 ->getMethod<facebook::jni::JArrayClass<
582 TensorHybrid::javaobject>::javaobject()>("toTensorList");
583 auto jArray = jMethodGetTensorList(jivalue);
584 size_t n = jArray->size();
585 c10::List<at::Tensor> list{};
586 list.reserve(n);
587 for (const auto i : c10::irange(n)) {
588 list.push_back(
589 TensorHybrid::newAtTensorFromJTensor(jArray->getElement(i)));
590 }
591 return at::IValue{std::move(list)};
592 } else if (JIValue::kTypeCodeList == typeCode) {
593 static const auto jMethodGetList =
594 JIValue::javaClassStatic()
595 ->getMethod<
596 facebook::jni::JArrayClass<JIValue::javaobject>::javaobject()>(
597 "toList");
598 auto jarray = jMethodGetList(jivalue);
599 size_t n = jarray->size();
600 if (n == 0) {
601 return at::IValue{c10::impl::GenericList(c10::TensorType::get())};
602 }
603
604 auto jivalue_first_element = jarray->getElement(0);
605 auto first_element = JIValue::JIValueToAtIValue(jivalue_first_element);
606 c10::impl::GenericList list{c10::unshapedType(first_element.type())};
607 list.reserve(n);
608 list.push_back(first_element);
609 for (const auto i : c10::irange(1, n)) {
610 auto jivalue_element = jarray->getElement(i);
611 auto element = JIValue::JIValueToAtIValue(jivalue_element);
612 list.push_back(element);
613 }
614 return at::IValue{list};
615 } else if (JIValue::kTypeCodeDictStringKey == typeCode) {
616 static const auto jMethodGetDictStringKey =
617 JIValue::javaClassStatic()
618 ->getMethod<facebook::jni::JMap<jstring, JIValue::javaobject>::
619 javaobject()>("toDictStringKey");
620 auto jmap = jMethodGetDictStringKey(jivalue);
621 auto it = jmap->begin();
622 if (it == jmap->end()) {
623 return at::IValue{c10::impl::GenericDict(
624 c10::StringType::get(), c10::TensorType::get())};
625 }
626
627 auto firstEntryValue = JIValue::JIValueToAtIValue(it->second);
628 c10::impl::GenericDict dict{
629 c10::StringType::get(), c10::unshapedType(firstEntryValue.type())};
630 dict.insert(it->first->toStdString(), firstEntryValue);
631 it++;
632 for (; it != jmap->end(); it++) {
633 dict.insert(
634 it->first->toStdString(), JIValue::JIValueToAtIValue(it->second));
635 }
636 return at::IValue{dict};
637 } else if (JIValue::kTypeCodeDictLongKey == typeCode) {
638 static const auto jMethodGetDictLongKey =
639 JIValue::javaClassStatic()
640 ->getMethod<facebook::jni::JMap<
641 facebook::jni::JLong::javaobject,
642 JIValue::javaobject>::javaobject()>("toDictLongKey");
643 auto jmap = jMethodGetDictLongKey(jivalue);
644 auto it = jmap->begin();
645 if (it == jmap->end()) {
646 return at::IValue{
647 c10::impl::GenericDict(c10::IntType::get(), c10::TensorType::get())};
648 }
649
650 auto firstEntryValue = JIValue::JIValueToAtIValue(it->second);
651 c10::impl::GenericDict dict{
652 c10::IntType::get(), c10::unshapedType(firstEntryValue.type())};
653 dict.insert((int64_t)it->first->longValue(), firstEntryValue);
654 it++;
655 for (; it != jmap->end(); it++) {
656 dict.insert(
657 (int64_t)it->first->longValue(),
658 JIValue::JIValueToAtIValue(it->second));
659 }
660 return at::IValue{dict};
661 }
662
663 facebook::jni::throwNewJavaException(
664 facebook::jni::gJavaLangIllegalArgumentException,
665 "Unknown IValue typeCode %d",
666 typeCode);
667 }
668
669 #if defined(__ANDROID__)
670 class PyTorchAndroidJni : public facebook::jni::JavaClass<PyTorchAndroidJni> {
671 public:
672 constexpr static auto kJavaDescriptor = "Lorg/pytorch/PyTorchAndroid;";
673
registerNatives()674 static void registerNatives() {
675 javaClassStatic()->registerNatives({
676 makeNativeMethod(
677 "nativeSetNumThreads", PyTorchAndroidJni::setNumThreads),
678 });
679 }
680
setNumThreads(facebook::jni::alias_ref<jclass>,jint numThreads)681 static void setNumThreads(facebook::jni::alias_ref<jclass>, jint numThreads) {
682 caffe2::pthreadpool()->set_thread_count(numThreads);
683 }
684 };
685 #endif
686
// Registers the native methods shared by all PyTorch JNI flavours.
// Registration happens inside the initializer of a function-local static,
// which C++ runs once per process even under concurrent callers. If it
// throws, the initializer is retried on the next call. On non-Android builds
// there is nothing to register.
void common_registerNatives() {
  static const bool registered = [] {
#if defined(__ANDROID__)
    pytorch_jni::PyTorchAndroidJni::registerNatives();
#endif
    return true;
  }();
  static_cast<void>(registered);
}
696
697 } // namespace pytorch_jni
698