xref: /aosp_15_r20/external/pytorch/android/pytorch_android/src/main/cpp/pytorch_jni_common.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/util/FunctionRef.h>
4 #include <fbjni/fbjni.h>
5 #include <torch/csrc/api/include/torch/types.h>
6 #include "caffe2/serialize/read_adapter_interface.h"
7 
8 #include "cmake_macros.h"
9 
10 #ifdef __ANDROID__
11 #include <android/log.h>
12 #define ALOGI(...) \
13   __android_log_print(ANDROID_LOG_INFO, "pytorch-jni", __VA_ARGS__)
14 #define ALOGE(...) \
15   __android_log_print(ANDROID_LOG_ERROR, "pytorch-jni", __VA_ARGS__)
16 #endif
17 
18 #if defined(TRACE_ENABLED) && defined(__ANDROID__)
19 #include <android/trace.h>
20 #include <dlfcn.h>
21 #endif
22 
23 namespace pytorch_jni {
24 
25 constexpr static int kDeviceCPU = 1;
26 constexpr static int kDeviceVulkan = 2;
27 
28 c10::DeviceType deviceJniCodeToDeviceType(jint deviceJniCode);
29 
/// RAII helper for trace sections.
///
/// Constructing a Trace begins a section named by the constructor argument;
/// destroying it ends that section. When the build defines both
/// TRACE_ENABLED and __ANDROID__, begin/end go through the ATrace_* function
/// pointers declared below; otherwise the only work done is running init()
/// once.
///
/// NOTE(review): is_initialized_ is a plain bool, so a first use from several
/// threads at once is not synchronized here — confirm callers tolerate init()
/// possibly running more than once.
class Trace {
 public:
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
  // Pointer types for the NDK's ATrace_beginSection / ATrace_endSection.
  // Both functions return void; calling them through a pointer type with a
  // different return type (previously `void*`) is undefined behavior.
  typedef void (*fp_ATrace_beginSection)(const char* sectionName);
  typedef void (*fp_ATrace_endSection)(void);

  // Presumably resolved by init() (defined out of line). Static storage is
  // zero-initialized, so these stay null if resolution never happens or
  // fails; the call sites below check for that.
  static fp_ATrace_beginSection ATrace_beginSection;
  static fp_ATrace_endSection ATrace_endSection;
#endif

  /// Runs init() the first time it is called; later calls do nothing.
  static void ensureInit() {
    if (!Trace::is_initialized_) {
      init();
      Trace::is_initialized_ = true;
    }
  }

  /// Begins a trace section named `name` (a no-op unless tracing is
  /// compiled in and the begin function pointer is available).
  static void beginSection(const char* name) {
    Trace::ensureInit();
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
    // Skip instead of calling through a null pointer if symbol lookup failed.
    if (ATrace_beginSection != nullptr) {
      ATrace_beginSection(name);
    }
#else
    (void)name;
#endif
  }

  /// Ends the most recently begun trace section (a no-op unless tracing is
  /// compiled in and the end function pointer is available).
  static void endSection() {
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
    if (ATrace_endSection != nullptr) {
      ATrace_endSection();
    }
#endif
  }

  /// Begins a section that lasts for the lifetime of this object.
  /// beginSection() already performs ensureInit().
  explicit Trace(const char* name) {
    beginSection(name);
  }

  ~Trace() {
    endSection();
  }

  // A copy would end the same section a second time in its destructor.
  Trace(const Trace&) = delete;
  Trace& operator=(const Trace&) = delete;

 private:
  // One-time setup; defined out of line.
  static void init();
  // Set by ensureInit() after init() has run.
  static bool is_initialized_;
};
73 
74 class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface {
75  public:
MemoryReadAdapter(const void * data,off_t size)76   explicit MemoryReadAdapter(const void* data, off_t size)
77       : data_(data), size_(size){};
78 
size()79   size_t size() const override {
80     return size_;
81   }
82 
83   size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
84       const override {
85     memcpy(buf, (int8_t*)(data_) + pos, n);
86     return n;
87   }
88 
~MemoryReadAdapter()89   ~MemoryReadAdapter() {}
90 
91  private:
92   const void* data_;
93   off_t size_;
94 };
95 
96 class JIValue : public facebook::jni::JavaClass<JIValue> {
97   using DictCallback = c10::function_ref<facebook::jni::local_ref<JIValue>(
98       c10::Dict<c10::IValue, c10::IValue>)>;
99 
100  public:
101   constexpr static const char* kJavaDescriptor = "Lorg/pytorch/IValue;";
102 
103   constexpr static int kTypeCodeNull = 1;
104 
105   constexpr static int kTypeCodeTensor = 2;
106   constexpr static int kTypeCodeBool = 3;
107   constexpr static int kTypeCodeLong = 4;
108   constexpr static int kTypeCodeDouble = 5;
109   constexpr static int kTypeCodeString = 6;
110 
111   constexpr static int kTypeCodeTuple = 7;
112   constexpr static int kTypeCodeBoolList = 8;
113   constexpr static int kTypeCodeLongList = 9;
114   constexpr static int kTypeCodeDoubleList = 10;
115   constexpr static int kTypeCodeTensorList = 11;
116   constexpr static int kTypeCodeList = 12;
117 
118   constexpr static int kTypeCodeDictStringKey = 13;
119   constexpr static int kTypeCodeDictLongKey = 14;
120 
121   static facebook::jni::local_ref<JIValue> newJIValueFromAtIValue(
122       const at::IValue& ivalue,
123       DictCallback stringDictCallback = newJIValueFromStringDict,
124       DictCallback intDictCallback = newJIValueFromIntDict);
125 
126   static at::IValue JIValueToAtIValue(
127       facebook::jni::alias_ref<JIValue> jivalue);
128 
129  private:
130   static facebook::jni::local_ref<JIValue> newJIValueFromStringDict(
131       c10::Dict<c10::IValue, c10::IValue>);
132   static facebook::jni::local_ref<JIValue> newJIValueFromIntDict(
133       c10::Dict<c10::IValue, c10::IValue>);
134 };
135 
136 void common_registerNatives();
137 } // namespace pytorch_jni
138