#include <torch/csrc/jit/serialization/export_bytecode.h>
#include <utility>

#include <torch/csrc/jit/operator_upgraders/version_map.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/export.h>

#include <c10/util/Exception.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/api/method.h>
#include <torch/csrc/jit/backends/backend_debug_handler.h>
#include <torch/csrc/jit/backends/backend_debug_info.h>
#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/ir/attributes.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/type_hashing.h>
#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/method.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_export_functions.h>
#include <torch/csrc/jit/serialization/import_export_helpers.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <torch/csrc/jit/serialization/python_print.h>
#include <torch/csrc/jit/serialization/source_range_serialization.h>
#include <torch/csrc/jit/serialization/type_name_uniquer.h>

#include <caffe2/serialize/inline_container.h>

namespace torch::jit {

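// Collect the __getstate__/__setstate__ methods (those defined as graph
// functions) for `obj` and for objects reachable through its attribute
// slots. Traversal does not descend into objects that define the
// state-method pair themselves.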
static std::vector<Method> gatherGetSetStates(const ObjectPtr& obj) {
  std::vector<Method> methods;
  // Use DFS on IValues to traverse dependencies of module._ivalue and
  // collect all __setstate__/__getstate__ pairs along the way.
  std::vector<ObjectPtr> ivalue_stack;
  ivalue_stack.emplace_back(obj);
  while (!ivalue_stack.empty()) {
    ObjectPtr cur = ivalue_stack.back();
    ivalue_stack.pop_back();
    auto type = cur->type();
    Function* setstate = type->findMethod("__setstate__");
    Function* getstate = type->findMethod("__getstate__");
    if (getstate && setstate) {
      if (setstate->isGraphFunction()) {
        methods.emplace_back(cur, setstate);
      }
      if (getstate->isGraphFunction()) {
        methods.emplace_back(cur, getstate);
      }
    } else {
      for (size_t i = 0, n = type->numAttributes(); i < n; ++i) {
        IValue field = cur->getSlot(i);
        if (field.isObject()) {
          ivalue_stack.emplace_back(field.toObject());
        }
      }
    }
  }
  return methods;
}

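// Find the methods of `module`'s submodules that `graph` invokes through
// interface-typed prim::CallMethod nodes, matching callees by method name.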
static std::vector<Method> findAllDependentFunctions(
    const Module& module,
    Graph& graph) {
  std::vector<Method> methods;
  std::unordered_set<c10::string_view> called_method_names;
  auto nodes = findAllNodes(graph, c10::prim::CallMethod, true);
  for (Node* node : nodes) {
    if (auto iface = node->input(0)->type()->castRaw<InterfaceType>()) {
      const FunctionSchema* schema = iface->getMethod(node->s(attr::name));
      called_method_names.insert(schema->name());
    }
  }

  for (const auto& submodule : module.modules()) {
    for (const auto& m : submodule.get_methods()) {
      if (called_method_names.find(m.function().qualname().name()) !=
          called_method_names.end()) {
        methods.emplace_back(m);
      }
    }
  }
  return methods;
}

// NOTE: The order of the returned functions is:
// 1. Functions originating from the methods passed in come first.
// 2. All dependent functions come afterwards.
// This order is meaningful because the mobile Module currently looks up
// methods with a linear search.
static std::vector<std::unique_ptr<GraphFunction>> inlineFunctions(
    const std::vector<Method>& initial_methods,
    bool incl_dependent_functions) {
  std::set<std::pair<std::string, Function*>> visited;
  std::deque<Method> stack;
  std::copy(
      initial_methods.begin(),
      initial_methods.end(),
      std::back_inserter(stack));
  std::vector<std::unique_ptr<GraphFunction>> inlined_functions;
  while (!stack.empty()) {
    Method cur = stack.front();
    stack.pop_front();
    auto tup = std::make_pair(
        cur.owner()._ivalue()->type()->name()->qualifiedName(),
        &cur.function());
    if (visited.find(tup) != visited.end()) {
      continue;
    }
    visited.insert(tup);
    const auto& f = toGraphFunction(cur.function());
    auto graph = f.graph()->copyUnique();
    Inline(*graph);
    c10::QualifiedName qn(*cur.owner()._ivalue()->type()->name(), f.name());

    if (incl_dependent_functions) {
      std::vector<Method> dependent_methods =
          findAllDependentFunctions(cur.owner(), *graph);
      std::copy(
          dependent_methods.begin(),
          dependent_methods.end(),
          std::back_inserter(stack));
    }
    auto inlined_func = std::make_unique<GraphFunction>(
        qn, std::move(graph), f.function_creator());
    inlined_func->setSchema(f.getSchema());
    inlined_functions.emplace_back(std::move(inlined_func));
  }
  return inlined_functions;
}

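// Lower an (already inlined) JIT graph into mobile::Code: emit instructions
// paired 1-to-1 with debug handles, resolve each operator (recording how
// many arguments were specified, when known), rewrite CALL instructions for
// interface calls into INTERFACE_CALL, and copy over constants, types, and
// the register size.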
mobile::Code compileGraphToMobileCode(
    const std::string& name,
    const std::shared_ptr<Graph>& graph,
    const CompilationOptions& compilation_options,
    BackendDebugInfoRecorder& debug_info_recorder) {
  MobileCode code(
      graph,
      name,
      compilation_options.enable_default_value_for_unspecified_arg,
      compilation_options.enable_default_args_before_out_args,
      compilation_options.enable_emit_promoted_ops);

  mobile::Code mobile_code;

  // Method names referenced by converted INTERFACE_CALL instructions
  // (appended to the constant table below) and per-op debug handles.
  std::vector<std::string> method_names;
  std::vector<int64_t> op_debug_handles;
  int next_new_op_index = 0;

  auto op_to_specified_args = code.op_to_num_specified_args();
  for (size_t i = 0; i < code.instructions().size(); ++i) {
    Instruction ins = code.instructions()[i];

    if ((ins.op == OP || ins.op == OPN) && ins.X == next_new_op_index) {
      // Found a new op (assumes new operators are ordered by ascending ins.X)
      auto node = code.instructions_source()[i];
      const c10::OperatorName& opname = node->schema().operator_name();
      auto unique_name = c10::toString(opname);
      // For an operator with varargs, adding default arguments would be
      // confusing and is not allowed. num_args = -1 means the number of
      // arguments is not available for this operator, so no backward
      // compatibility adaptation is done at runtime.
      std::optional<int> num_args = std::nullopt;
      auto it = op_to_specified_args.find(unique_name);
      if (it != op_to_specified_args.end()) {
        num_args = it->second;
      }
      mobile_code.operator_input_sizes_.emplace_back(num_args.value_or(-1));
      mobile_code.op_names_.emplace_back(opname);
      auto func = mobile::makeOperatorFunction(opname, num_args);
      TORCH_INTERNAL_ASSERT(
          func.has_value(),
          "Operator with name: ",
          toString(opname),
          " not found");
      mobile_code.operators_.emplace_back(*func);
      next_new_op_index++;
    }
    // CALL nodes at this point represent built-in (i.e. non-Graph)
    // functions that were not inlined. Here we convert the CALL
    // instructions for these functions into INTERFACE_CALL instructions
    // so that at runtime we will look up the Function* on the type of the
    // 0th argument on the stack and call it directly.
    if (ins.op == CALL) {
      auto node = code.instructions_source()[i];
      if (node->kind() == prim::CallMethod) {
        // NB: replacing instruction
        auto method_name_idx =
            code.constant_table().size() + method_names.size();
        method_names.emplace_back(node->s(attr::name));
        ins = Instruction{
            INTERFACE_CALL,
            static_cast<int32_t>(method_name_idx),
            static_cast<uint16_t>(node->inputs().size())};
      } else {
        TORCH_INTERNAL_ASSERT(
            false, "Unsupported node kind on CALL opcode for mobile");
      }
    } else if (ins.op == RET) {
      auto node = code.instructions_source()[i];
      for (const auto& input : node->inputs()) {
        const auto& input_type = input->type();
        if (input_type->kind() == TypeKind::ListType ||
            input_type->kind() == TypeKind::DictType) {
          for (const TypePtr& element_type : input_type->containedTypes()) {
            TORCH_CHECK(
                element_type->kind() != TypeKind::ClassType,
                "Returning a list or dictionary with PyTorch class type ",
                "is not supported in a mobile module ",
                "(List[Foo] or Dict[int, Foo] for class Foo(torch.nn.Module)). ",
                "Workaround: instead of using a PyTorch class as the element ",
                "type, use a combination of list, dictionary, and single types.");
          }
        }
      }
    } else {
      TORCH_CHECK(
          isOpSupportedInMobile(ins.op),
          toString(ins.op),
          " is not supported in mobile module.");
    }
    auto node = code.instructions_source()[i];
    int64_t debug_handle = debug_info_recorder.getNextDebugHandle(node);
    // Note the 1-to-1 correspondence between instructions and debug handles
    mobile_code.instructions_.emplace_back(ins);
    mobile_code.debug_handles_.emplace_back(debug_handle);
  }

  // Copy the constants and append the method names that were emitted for
  // the converted INTERFACE_CALL instructions above.
  mobile_code.constants_ = code.constant_table();
  for (auto& method_name : method_names) {
    mobile_code.constants_.emplace_back(method_name);
  }

  mobile_code.types_ = code.type_table();
  mobile_code.register_size_ = code.register_size();
  return mobile_code;
}

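// Compile a single GraphFunction into a mobile::Function, using a locally
// scoped BackendDebugInfoRecorder.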
std::unique_ptr<mobile::Function> convertJitFunctionToMobileFunction(
    const GraphFunction& function,
    const CompilationOptions& options) {
  BackendDebugInfoRecorder debug_handle;
  auto mobileCode = compileGraphToMobileCode(
      function.name(), function.graph(), options, debug_handle);
  const auto& schema = function.getSchema();
  return std::make_unique<mobile::Function>(
      function.qualname(), std::move(mobileCode), schema);
}

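// Flatten a mobile::Function's code into the nested Table/tuple form
// (instructions, operators, constants, types, register_size) used by the
// bytecode serialization format.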
IValue convertMobileFunctionToCodeTable(
    const mobile::Function& func,
    const CompilationOptions& compilation_options) {
  auto code = func.get_code();
  std::vector<IValue> instructions;
  instructions.reserve(code.instructions_.size());
  for (Instruction ins : code.instructions_) {
    instructions.emplace_back(to_tuple({toString(ins.op), ins.X, ins.N}));
  }

  std::vector<IValue> operators;
  operators.reserve(code.op_names_.size());
  for (unsigned i = 0; i < code.op_names_.size(); ++i) {
    const auto& opname = code.op_names_[i];
    const int size = code.operator_input_sizes_[i];
    if (compilation_options.enable_default_value_for_unspecified_arg) {
      operators.emplace_back(to_tuple({opname.name, opname.overload_name}));
    } else {
      operators.emplace_back(
          to_tuple({opname.name, opname.overload_name, size}));
    }
  }

  std::vector<IValue> types;
  for (const TypePtr& t : code.types_) {
    std::string type_str = t->annotation_str();
    types.emplace_back(type_str);
  }

  auto register_size = static_cast<int>(code.register_size_);
  auto codeTable = Table(
      {{"instructions", to_tuple(instructions)},
       {"operators", to_tuple(operators)},
       {"constants", to_tuple(code.constants_)},
       {"types", to_tuple(types)},
       {"register_size", register_size}});

  return codeTable;
}

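// Reject schemas that the mobile runtime cannot handle: overloads,
// Python *args, and a variable number of return values.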
static void checkSchema(const c10::FunctionSchema& schema) {
  TORCH_CHECK(
      schema.overload_name().empty(), // @TODO: is this check correct?
      "Overloads are not supported in mobile modules.");
  TORCH_CHECK(
      !schema.is_vararg(), "Python *args are not supported in mobile modules.");
  TORCH_CHECK(
      !schema.is_varret(),
      "A variable number of return values is not supported in mobile modules.");
}

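// Heuristic: a module counts as backend-lowered if any atom of its
// qualified type name is "LoweredModule".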
static bool isLoweredModule(const Module& m) {
  c10::QualifiedName type_name;
  if (m.type()->name()) {
    type_name = m.type()->name().value();
  }
  bool isLoweredModule = false;
  for (const auto& atom : type_name.atoms()) {
    if (atom == "LoweredModule") {
      isLoweredModule = true;
      break;
    }
  }
  return isLoweredModule;
}

// Check if the global static map of backend debug info contains debug
// info for this module or any of its children. If so, combine all the
// maps and accumulate them into `debug_map`.
static void getBackendDebugInfoMap(
    const Module& m,
    BackendDebugInfoMapType& debug_map) {
  if (isLoweredModule(m)) {
    auto backend_debug_info =
        m.attr("__backend_debug_info").toCustomClass<PyTorchBackendDebugInfo>();
    const auto& map = backend_debug_info->getDebugInfoMap();
    if (map) {
      debug_map.insert(map.value().begin(), map.value().end());
    }
  }
  for (const auto& c : m.children()) {
    getBackendDebugInfoMap(c, debug_map);
  }
}

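// Return the minimum operator version required by `module`: the maximum of
// the minimum supported file-format version and the latest
// bumped_at_version among the upgrader entries for the ops it uses.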
static uint64_t get_min_operator_version_from_version_map(
    const mobile::Module& module) {
  uint64_t min_version = caffe2::serialize::kMinSupportedFileFormatVersion;
  for (const auto& func : module.compilation_unit().methods()) {
    for (const auto& op_name : func->get_code().op_names_) {
      auto schema_name = op_name.overload_name.empty()
          ? op_name.name
          : op_name.name + "." + op_name.overload_name;
      auto version_entry = get_operator_version_map().find(schema_name);
      if (version_entry != get_operator_version_map().end()) {
        const auto& entry = version_entry->second;
        min_version = std::max(
            min_version, uint64_t(entry[entry.size() - 1].bumped_at_version));
      }
    }
  }
  return min_version;
}

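// Convert a JIT Module into a mobile::Module: inline and compile every
// exported method (plus reachable __getstate__/__setstate__ pairs), attach
// the merged debug-handle table, and stamp the operator and bytecode
// versions.
//
// A minimal usage sketch, assuming the default-constructed
// CompilationOptions are acceptable for the target runtime:
//
//   CompilationOptions options;
//   mobile::Module mobile_m = jitModuleToMobile(scripted_module, options);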
mobile::Module jitModuleToMobile(
    const Module& module,
    const CompilationOptions& options) {
  std::shared_ptr<mobile::CompilationUnit> mcu =
      std::make_shared<mobile::CompilationUnit>();
  BackendDebugInfoRecorder debug_info_recorder;

  std::vector<Method> methods_to_export = module.get_methods();
  std::vector<Method> getsetstates = gatherGetSetStates(module._ivalue());
  std::copy(
      getsetstates.begin(),
      getsetstates.end(),
      std::back_inserter(methods_to_export));

  for (const auto& func :
       inlineFunctions(methods_to_export, options.incl_interface_call)) {
    auto mobile_code = compileGraphToMobileCode(
        func->name(), func->graph(), options, debug_info_recorder);
    const auto& schema = func->getSchema();
    checkSchema(schema);
    auto mobile_func = std::make_unique<mobile::Function>(
        func->qualname(), std::move(mobile_code), schema);
    mcu->register_function(std::move(mobile_func));
  }

  mobile::Module m(module._ivalue(), mcu);
  m.setHasDebugHandles(true);
  BackendDebugInfoMapType backend_debug_info_map;
  getBackendDebugInfoMap(module, backend_debug_info_map);
  auto debug_handle_cs_ptr_map = debug_info_recorder.stopRecording();
  debug_handle_cs_ptr_map.insert(
      backend_debug_info_map.begin(), backend_debug_info_map.end());
  m.setDebugTable(MobileDebugTable(
      debug_handle_cs_ptr_map.begin(), debug_handle_cs_ptr_map.end()));
  m.set_min_operator_version(
      static_cast<int64_t>(get_min_operator_version_from_version_map(m)));
  m.set_bytecode_version(options.model_version);
  return m;
}

} // namespace torch::jit