1 #include <ATen/core/ivalue.h>
2 #include <torch/csrc/jit/mobile/code.h>
3 #include <torch/csrc/jit/mobile/parse_bytecode.h>
4 #include <torch/csrc/jit/mobile/type_parser.h>
5 #include <torch/csrc/jit/mobile/upgrader_mobile.h>
6 #include <torch/csrc/jit/runtime/instruction.h>
7 #include <torch/csrc/jit/serialization/import_export_constants.h>
8 #include <torch/csrc/jit/serialization/import_export_functions.h>
9 #include <torch/custom_class_detail.h>
10
11 namespace torch::jit {
12 OpCode parseOpCode(const char* str);
13 using c10::IValue;
14
expect_field(c10::ivalue::TupleElements & elements,const std::string & expected_name,size_t entry)15 IValue expect_field(
16 c10::ivalue::TupleElements& elements,
17 const std::string& expected_name,
18 size_t entry) {
19 auto row = std::move(elements.at(entry)).toTuple();
20 TORCH_INTERNAL_ASSERT(
21 row->elements().at(0).toStringRef() == expected_name,
22 "Expected ",
23 expected_name,
24 " found ",
25 row->elements().at(0).toStringRef());
26 return std::move(row)->elements().at(1);
27 }
28
29 namespace mobile {
30
31 namespace {
32 #define COUNT_OPCODE(_, _a) 1 +
33 constexpr size_t numOpcodes = FORALL_OPCODES(COUNT_OPCODE) 0;
34 #undef COUNT_OPCODE
35
36 // Pickled strings are memoized, so we can cache a mapping from
37 // pointers to parsed OpCodes to speed up parsing.
38 class OpCodeCache {
39 private:
40 // We store as void* to emphasize that we care only about the
41 // address and should not be dereferencing these pointers.
42 std::array<const void*, numOpcodes> keys_{};
43 std::array<OpCode, numOpcodes> values_{};
44 size_t usedEntries_ = 0;
45
46 public:
OpCodeCache()47 OpCodeCache() {
48 memset(keys_.data(), 0, keys_.size() * sizeof(keys_[0]));
49 }
50
parse(const c10::ivalue::ConstantString & s)51 OpCode parse(const c10::ivalue::ConstantString& s) {
52 const auto endIt = keys_.begin() + usedEntries_;
53 auto it = std::find_if(
54 keys_.begin(), endIt, [&s](const void* k) { return k == &s; });
55 if (it == endIt) {
56 OpCode result = parseOpCode(s.string().c_str());
57 if (usedEntries_ < numOpcodes) {
58 keys_[usedEntries_] = &s;
59 values_[usedEntries_++] = result;
60 }
61 return result;
62 }
63 // NOTE: I tried implementing the transpose heuristic here to
64 // speed up the search, but it removed the benefit of this cache.
65 return values_[it - keys_.begin()];
66 }
67 };
68 } // namespace
69
applyUpgrader(mobile::Function * function,uint64_t operator_version)70 void applyUpgrader(mobile::Function* function, uint64_t operator_version) {
71 Code& code = function->get_code();
72 auto& operator_version_map = getOperatorVersionMapForMobile();
73 for (size_t i = 0; i < code.instructions_.size(); i++) {
74 Instruction& inst = code.instructions_[i];
75 if (inst.op == OpCode::OP) {
76 std::string operator_name = code.op_names_[inst.X].name +
77 (code.op_names_[inst.X].overload_name.empty()
78 ? ""
79 : "." + code.op_names_[inst.X].overload_name);
80
81 auto it = operator_version_map.find(operator_name);
82 // Find out if there is an upgrader for this operator
83 if (it != operator_version_map.end()) {
84 auto upgrader_list = it->second;
85 // Loop all upgraders for this operator, and find out if there exists a
86 // valid upgrader. Use iteration here instead of other faster search
87 // algorithm, because the number of upgrader per operator will be just a
88 // few and tend to keep the code light-weight from binary size concern.
89 for (const auto& upgrader : upgrader_list) {
90 if (static_cast<int>(operator_version) <= upgrader.max_version &&
91 static_cast<int>(operator_version) >= upgrader.min_version) {
92 // If there exists a valid upgrader, change the instruction OP to
93 // CALL, and the index will point to the according upgrader
94 // function. All upgrader function are available in
95 // function->get_code().functions_. It's a vector of function
96 // pointer and they are initialized in the same order as the global
97 // vector kUpgraderBytecode.
98 // Instruction new_inst = inst;
99 // new_inst.op = OpCode::CALL;
100 // new_inst.X = upgrader.index;
101 // code->instructions_[i] = new_inst;
102 TORCH_CHECK(
103 upgrader.index < static_cast<int>(code.functions_.size()),
104 "upgrader index is, ",
105 upgrader.index,
106 " and it's larger than the upgrader function list length ",
107 code.functions_.size());
108 inst.op = OpCode::CALL;
109 inst.X = upgrader.index;
110 }
111 }
112 }
113 }
114 }
115 }
116
// Parses the serialized instruction list for one function and appends each
// instruction (opcode, X, N) to `function`. If a per-module debug-handle
// table is present, its handles are attached one-to-one to the
// instructions. `ins_list` is consumed: its tuples are moved out.
void parseInstructions(
    const std::string& function_name,
    c10::ivalue::TupleElements&& ins_list,
    c10::ivalue::TupleElements& debug_handles_m_tuple,
    mobile::Function* function) {
  c10::List<int64_t> debug_handles_list;
  if (!debug_handles_m_tuple.empty()) {
    // Element 0 is the function name the debug table belongs to; it must
    // match the bytecode table's function.
    const std::string& debug_info_function_name =
        debug_handles_m_tuple[0].toStringRef();
    TORCH_CHECK(
        debug_info_function_name == function_name,
        "The function names in the bytecode table and the debug info table do not match.");
    IValue& debug_handles_table = debug_handles_m_tuple[1];
    // Move the table's elements out so expect_field can consume them.
    auto debugHandlesTableElements =
        std::move(*std::move(debug_handles_table).toTuple()).elements();
    // The "function_debug_handles" row holds a tuple whose first element
    // is the per-instruction handle list.
    debug_handles_list = (expect_field(
                              debugHandlesTableElements,
                              "function_debug_handles",
                              BYTECODE_INDEX_MODULE_DEBUG_HANDLES)
                              .toTupleRef()
                              .elements())[0]
                             .toIntList();
    TORCH_CHECK(
        debug_handles_list.size() == ins_list.size(),
        "The numbers of instructions and debug handles strings do not match.");
  }

  // NOTE: this won't perform particularly well if the ins_list IValue
  // didn't come from unpickler and thus have its strings
  // interned. Consider adding a flag to bypass the cache if that
  // becomes an important use case.
  OpCodeCache opCodeCache;
  for (const auto j : c10::irange(ins_list.size())) {
    // Each instruction is a (name, X, N) tuple; move it out of the list.
    auto ins_tuple = std::move(ins_list[j]).toTuple();
    c10::ArrayRef<IValue> ins_item = ins_tuple->elements();
    TORCH_CHECK(
        ins_item.size() == 3,
        "There should be three parts in an instruction. The function name is ",
        function_name);
    OpCode op_code = opCodeCache.parse(*ins_item[0].toString());
    auto X = ins_item[1].toInt();
    auto N = ins_item[2].toInt();

    if (!debug_handles_list.empty()) {
      // Handles line up positionally with instructions (checked above).
      int64_t debug_handle = debug_handles_list[j];
      function->append_instruction(op_code, X, N, debug_handle);
    } else {
      function->append_instruction(op_code, X, N);
    }
  }
}
168
parseConstants(const c10::ivalue::TupleElements & consts_list,mobile::Function * function)169 void parseConstants(
170 const c10::ivalue::TupleElements& consts_list,
171 mobile::Function* function) {
172 for (const auto& constant : consts_list) {
173 function->append_constant(constant);
174 }
175 }
parseTypes(const c10::ivalue::TupleElements & types_list,mobile::Function * function)176 void parseTypes(
177 const c10::ivalue::TupleElements& types_list,
178 mobile::Function* function) {
179 std::vector<std::string> types_string_list;
180 types_string_list.resize(types_list.size());
181 for (size_t i = 0; i < types_list.size(); i++) {
182 types_string_list[i] = types_list[i].toStringRef();
183 }
184
185 std::vector<c10::TypePtr> types_ptr_list = c10::parseType(types_string_list);
186 for (auto& type_ptr : types_ptr_list) {
187 function->append_type(type_ptr);
188 }
189 }
190
// Records the register-file size the function's bytecode needs at runtime.
void parseRegisterSize(size_t rsize, mobile::Function* function) {
  function->set_register_size(rsize);
}
194
195 } // namespace mobile
196 } // namespace torch::jit
197