1 #pragma once 2 3 #include <c10/util/FunctionRef.h> 4 #include <fbjni/fbjni.h> 5 #include <torch/csrc/api/include/torch/types.h> 6 #include "caffe2/serialize/read_adapter_interface.h" 7 8 #include "cmake_macros.h" 9 10 #ifdef __ANDROID__ 11 #include <android/log.h> 12 #define ALOGI(...) \ 13 __android_log_print(ANDROID_LOG_INFO, "pytorch-jni", __VA_ARGS__) 14 #define ALOGE(...) \ 15 __android_log_print(ANDROID_LOG_ERROR, "pytorch-jni", __VA_ARGS__) 16 #endif 17 18 #if defined(TRACE_ENABLED) && defined(__ANDROID__) 19 #include <android/trace.h> 20 #include <dlfcn.h> 21 #endif 22 23 namespace pytorch_jni { 24 25 constexpr static int kDeviceCPU = 1; 26 constexpr static int kDeviceVulkan = 2; 27 28 c10::DeviceType deviceJniCodeToDeviceType(jint deviceJniCode); 29 30 class Trace { 31 public: 32 #if defined(TRACE_ENABLED) && defined(__ANDROID__) 33 typedef void* (*fp_ATrace_beginSection)(const char* sectionName); 34 typedef void* (*fp_ATrace_endSection)(void); 35 36 static fp_ATrace_beginSection ATrace_beginSection; 37 static fp_ATrace_endSection ATrace_endSection; 38 #endif 39 ensureInit()40 static void ensureInit() { 41 if (!Trace::is_initialized_) { 42 init(); 43 Trace::is_initialized_ = true; 44 } 45 } 46 beginSection(const char * name)47 static void beginSection(const char* name) { 48 Trace::ensureInit(); 49 #if defined(TRACE_ENABLED) && defined(__ANDROID__) 50 ATrace_beginSection(name); 51 #endif 52 } 53 endSection()54 static void endSection() { 55 #if defined(TRACE_ENABLED) && defined(__ANDROID__) 56 ATrace_endSection(); 57 #endif 58 } 59 Trace(const char * name)60 Trace(const char* name) { 61 ensureInit(); 62 beginSection(name); 63 } 64 ~Trace()65 ~Trace() { 66 endSection(); 67 } 68 69 private: 70 static void init(); 71 static bool is_initialized_; 72 }; 73 74 class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface { 75 public: MemoryReadAdapter(const void * data,off_t size)76 explicit MemoryReadAdapter(const void* data, off_t size) 77 : data_(data), size_(size){}; 78 size()79 size_t size() const override { 80 return size_; 81 } 82 83 size_t read(uint64_t pos, void* buf, size_t n, const char* what = "") 84 const override { 85 memcpy(buf, (int8_t*)(data_) + pos, n); 86 return n; 87 } 88 ~MemoryReadAdapter()89 ~MemoryReadAdapter() {} 90 91 private: 92 const void* data_; 93 off_t size_; 94 }; 95 96 class JIValue : public facebook::jni::JavaClass<JIValue> { 97 using DictCallback = c10::function_ref<facebook::jni::local_ref<JIValue>( 98 c10::Dict<c10::IValue, c10::IValue>)>; 99 100 public: 101 constexpr static const char* kJavaDescriptor = "Lorg/pytorch/IValue;"; 102 103 constexpr static int kTypeCodeNull = 1; 104 105 constexpr static int kTypeCodeTensor = 2; 106 constexpr static int kTypeCodeBool = 3; 107 constexpr static int kTypeCodeLong = 4; 108 constexpr static int kTypeCodeDouble = 5; 109 constexpr static int kTypeCodeString = 6; 110 111 constexpr static int kTypeCodeTuple = 7; 112 constexpr static int kTypeCodeBoolList = 8; 113 constexpr static int kTypeCodeLongList = 9; 114 constexpr static int kTypeCodeDoubleList = 10; 115 constexpr static int kTypeCodeTensorList = 11; 116 constexpr static int kTypeCodeList = 12; 117 118 constexpr static int kTypeCodeDictStringKey = 13; 119 constexpr static int kTypeCodeDictLongKey = 14; 120 121 static facebook::jni::local_ref<JIValue> newJIValueFromAtIValue( 122 const at::IValue& ivalue, 123 DictCallback stringDictCallback = newJIValueFromStringDict, 124 DictCallback intDictCallback = newJIValueFromIntDict); 125 126 static at::IValue JIValueToAtIValue( 127 facebook::jni::alias_ref<JIValue> jivalue); 128 129 private: 130 static facebook::jni::local_ref<JIValue> newJIValueFromStringDict( 131 c10::Dict<c10::IValue, c10::IValue>); 132 static facebook::jni::local_ref<JIValue> newJIValueFromIntDict( 133 c10::Dict<c10::IValue, c10::IValue>); 134 }; 135 136 void common_registerNatives(); 137 } // namespace pytorch_jni 138