xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/parse_bytecode.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/ivalue.h>
2 #include <torch/csrc/jit/mobile/code.h>
3 #include <torch/csrc/jit/mobile/parse_bytecode.h>
4 #include <torch/csrc/jit/mobile/type_parser.h>
5 #include <torch/csrc/jit/mobile/upgrader_mobile.h>
6 #include <torch/csrc/jit/runtime/instruction.h>
7 #include <torch/csrc/jit/serialization/import_export_constants.h>
8 #include <torch/csrc/jit/serialization/import_export_functions.h>
9 #include <torch/custom_class_detail.h>
10 
11 namespace torch::jit {
12 OpCode parseOpCode(const char* str);
13 using c10::IValue;
14 
expect_field(c10::ivalue::TupleElements & elements,const std::string & expected_name,size_t entry)15 IValue expect_field(
16     c10::ivalue::TupleElements& elements,
17     const std::string& expected_name,
18     size_t entry) {
19   auto row = std::move(elements.at(entry)).toTuple();
20   TORCH_INTERNAL_ASSERT(
21       row->elements().at(0).toStringRef() == expected_name,
22       "Expected ",
23       expected_name,
24       " found ",
25       row->elements().at(0).toStringRef());
26   return std::move(row)->elements().at(1);
27 }
28 
29 namespace mobile {
30 
31 namespace {
32 #define COUNT_OPCODE(_, _a) 1 +
33 constexpr size_t numOpcodes = FORALL_OPCODES(COUNT_OPCODE) 0;
34 #undef COUNT_OPCODE
35 
36 // Pickled strings are memoized, so we can cache a mapping from
37 // pointers to parsed OpCodes to speed up parsing.
38 class OpCodeCache {
39  private:
40   // We store as void* to emphasize that we care only about the
41   // address and should not be dereferencing these pointers.
42   std::array<const void*, numOpcodes> keys_{};
43   std::array<OpCode, numOpcodes> values_{};
44   size_t usedEntries_ = 0;
45 
46  public:
OpCodeCache()47   OpCodeCache() {
48     memset(keys_.data(), 0, keys_.size() * sizeof(keys_[0]));
49   }
50 
parse(const c10::ivalue::ConstantString & s)51   OpCode parse(const c10::ivalue::ConstantString& s) {
52     const auto endIt = keys_.begin() + usedEntries_;
53     auto it = std::find_if(
54         keys_.begin(), endIt, [&s](const void* k) { return k == &s; });
55     if (it == endIt) {
56       OpCode result = parseOpCode(s.string().c_str());
57       if (usedEntries_ < numOpcodes) {
58         keys_[usedEntries_] = &s;
59         values_[usedEntries_++] = result;
60       }
61       return result;
62     }
63     // NOTE: I tried implementing the transpose heuristic here to
64     // speed up the search, but it removed the benefit of this cache.
65     return values_[it - keys_.begin()];
66   }
67 };
68 } // namespace
69 
applyUpgrader(mobile::Function * function,uint64_t operator_version)70 void applyUpgrader(mobile::Function* function, uint64_t operator_version) {
71   Code& code = function->get_code();
72   auto& operator_version_map = getOperatorVersionMapForMobile();
73   for (size_t i = 0; i < code.instructions_.size(); i++) {
74     Instruction& inst = code.instructions_[i];
75     if (inst.op == OpCode::OP) {
76       std::string operator_name = code.op_names_[inst.X].name +
77           (code.op_names_[inst.X].overload_name.empty()
78                ? ""
79                : "." + code.op_names_[inst.X].overload_name);
80 
81       auto it = operator_version_map.find(operator_name);
82       // Find out if there is an upgrader for this operator
83       if (it != operator_version_map.end()) {
84         auto upgrader_list = it->second;
85         // Loop all upgraders for this operator, and find out if there exists a
86         // valid upgrader. Use iteration here instead of other faster search
87         // algorithm, because the number of upgrader per operator will be just a
88         // few and tend to keep the code light-weight from binary size concern.
89         for (const auto& upgrader : upgrader_list) {
90           if (static_cast<int>(operator_version) <= upgrader.max_version &&
91               static_cast<int>(operator_version) >= upgrader.min_version) {
92             // If there exists a valid upgrader, change the instruction OP to
93             // CALL, and the index will point to the according upgrader
94             // function. All upgrader function are available in
95             // function->get_code().functions_. It's a vector of function
96             // pointer and they are initialized in the same order as the global
97             // vector kUpgraderBytecode.
98             // Instruction new_inst = inst;
99             // new_inst.op = OpCode::CALL;
100             // new_inst.X = upgrader.index;
101             // code->instructions_[i] = new_inst;
102             TORCH_CHECK(
103                 upgrader.index < static_cast<int>(code.functions_.size()),
104                 "upgrader index is, ",
105                 upgrader.index,
106                 " and it's larger than the upgrader function list length ",
107                 code.functions_.size());
108             inst.op = OpCode::CALL;
109             inst.X = upgrader.index;
110           }
111         }
112       }
113     }
114   }
115 }
116 
parseInstructions(const std::string & function_name,c10::ivalue::TupleElements && ins_list,c10::ivalue::TupleElements & debug_handles_m_tuple,mobile::Function * function)117 void parseInstructions(
118     const std::string& function_name,
119     c10::ivalue::TupleElements&& ins_list,
120     c10::ivalue::TupleElements& debug_handles_m_tuple,
121     mobile::Function* function) {
122   c10::List<int64_t> debug_handles_list;
123   if (!debug_handles_m_tuple.empty()) {
124     const std::string& debug_info_function_name =
125         debug_handles_m_tuple[0].toStringRef();
126     TORCH_CHECK(
127         debug_info_function_name == function_name,
128         "The function names in the bytecode table and the debug info table do not match.");
129     IValue& debug_handles_table = debug_handles_m_tuple[1];
130     auto debugHandlesTableElements =
131         std::move(*std::move(debug_handles_table).toTuple()).elements();
132     debug_handles_list = (expect_field(
133                               debugHandlesTableElements,
134                               "function_debug_handles",
135                               BYTECODE_INDEX_MODULE_DEBUG_HANDLES)
136                               .toTupleRef()
137                               .elements())[0]
138                              .toIntList();
139     TORCH_CHECK(
140         debug_handles_list.size() == ins_list.size(),
141         "The numbers of instructions and debug handles strings do not match.");
142   }
143 
144   // NOTE: this won't perform particularly well if the ins_list IValue
145   // didn't come from unpickler and thus have its strings
146   // interned. Consider adding a flag to bypass the cache if that
147   // becomes an important use case.
148   OpCodeCache opCodeCache;
149   for (const auto j : c10::irange(ins_list.size())) {
150     auto ins_tuple = std::move(ins_list[j]).toTuple();
151     c10::ArrayRef<IValue> ins_item = ins_tuple->elements();
152     TORCH_CHECK(
153         ins_item.size() == 3,
154         "There should be three parts in an instruction. The function name is ",
155         function_name);
156     OpCode op_code = opCodeCache.parse(*ins_item[0].toString());
157     auto X = ins_item[1].toInt();
158     auto N = ins_item[2].toInt();
159 
160     if (!debug_handles_list.empty()) {
161       int64_t debug_handle = debug_handles_list[j];
162       function->append_instruction(op_code, X, N, debug_handle);
163     } else {
164       function->append_instruction(op_code, X, N);
165     }
166   }
167 }
168 
parseConstants(const c10::ivalue::TupleElements & consts_list,mobile::Function * function)169 void parseConstants(
170     const c10::ivalue::TupleElements& consts_list,
171     mobile::Function* function) {
172   for (const auto& constant : consts_list) {
173     function->append_constant(constant);
174   }
175 }
parseTypes(const c10::ivalue::TupleElements & types_list,mobile::Function * function)176 void parseTypes(
177     const c10::ivalue::TupleElements& types_list,
178     mobile::Function* function) {
179   std::vector<std::string> types_string_list;
180   types_string_list.resize(types_list.size());
181   for (size_t i = 0; i < types_list.size(); i++) {
182     types_string_list[i] = types_list[i].toStringRef();
183   }
184 
185   std::vector<c10::TypePtr> types_ptr_list = c10::parseType(types_string_list);
186   for (auto& type_ptr : types_ptr_list) {
187     function->append_type(type_ptr);
188   }
189 }
190 
parseRegisterSize(size_t rsize,mobile::Function * function)191 void parseRegisterSize(size_t rsize, mobile::Function* function) {
192   function->set_register_size(rsize);
193 }
194 
195 } // namespace mobile
196 } // namespace torch::jit
197