#include <torch/csrc/jit/serialization/export.h>

#include <c10/util/Exception.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/backends/backend_debug_handler.h>
#include <torch/csrc/jit/backends/backend_debug_info.h>
#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/ir/attributes.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/type_hashing.h>
#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/method.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
#include <torch/csrc/jit/serialization/export_bytecode.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_export_functions.h>
#include <torch/csrc/jit/serialization/import_export_helpers.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <torch/csrc/jit/serialization/python_print.h>
#include <torch/csrc/jit/serialization/source_range_serialization.h>
#include <torch/csrc/jit/serialization/type_name_uniquer.h>

#include <caffe2/serialize/inline_container.h>

#include <ATen/ATen.h>

#include <ATen/core/jit_type.h>
#include <ATen/core/qualified_name.h>
#include <cerrno>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace torch::jit {

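// Snapshot the thread-local BytecodeEmitMode flags, plus the produced bytecode
// version, into the CompilationOptions that jitModuleToMobile() consumes
// below.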
CompilationOptions getOptionsFromGlobal() {
  CompilationOptions compilation_options;
  compilation_options.enable_default_args_before_out_args =
      BytecodeEmitMode::is_default_args_before_out_args_enabled();
  compilation_options.enable_default_value_for_unspecified_arg =
      BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled();
  compilation_options.enable_emit_promoted_ops =
      BytecodeEmitMode::is_emit_promoted_ops_enabled();
  compilation_options.incl_interface_call = getMobileInterfaceCallExport();
  compilation_options.model_version =
      caffe2::serialize::kProducedBytecodeVersion;
  return compilation_options;
}

static IValue to_tuple(std::initializer_list<IValue> ivalues) {
  return c10::ivalue::Tuple::create(ivalues);
}

IValue to_tuple(std::vector<IValue> ivalues) {
  return c10::ivalue::Tuple::create(std::move(ivalues));
}

IValue Table(const std::vector<std::pair<std::string, IValue>>& entries) {
  std::vector<IValue> ivalue_entries;
  ivalue_entries.reserve(entries.size());
  for (const auto& e : entries) {
    ivalue_entries.push_back(to_tuple({e.first, e.second}));
  }
  return to_tuple(std::move(ivalue_entries));
}
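// For illustration (not part of the original source):
//   Table({{"name", IValue("x")}, {"type", IValue("Tensor")}})
// yields the nested tuple (("name", "x"), ("type", "Tensor")); tables in
// bytecode.pkl are tuples of (key, value) pairs rather than dicts.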

namespace {

ExportModuleExtraFilesHook& GetExtraFilesHook() {
  static ExportModuleExtraFilesHook func = nullptr;
  return func;
}

/**
 * If the type is not a NamedTuple, return default_type_str. If the type
 * is a NamedTuple, return a string with the following structure describing
 * the contents of the NamedTuple: "qualified_name[ NamedTuple, [ [field_name_1,
 * field_type_1], [field_name_2, field_type_2]
 *   ]
 * ]"
 * Example NamedTuple type:
 *  "__torch__.base_models.sparse_nn.pytorch_preproc_types.PreprocOutputType[
 *     NamedTuple, [
 *         [float_features, Tensor],
 *         [id_list_features, List[Tensor]],
 *         [label, Tensor],
 *         [weight, Tensor],
 *         ]
 *     ]"
 *
 * @param compilation_unit JIT compilation unit used to look up the function
 * schema.
 * @param type_ptr A type pointer; it can be any type.
 * @param default_type_str The default string representation. The string can
 * come from type_ptr->str(), type_ptr->annotation_str(), or
 * type_ptr->repr_str(); in some scenarios these differ. For example, a Tensor
 * type can print as "Tensor", "Tensor (inferred)", or "Tensor[]", and we only
 * want "Tensor". It is passed in as an argument so it can serve as the default
 * return value when type_ptr is not a NamedTuple.
 * @return The string representation.
 */
std::string get_named_tuple_str_or_default(
    const CompilationUnit& compilation_unit,
    const TypePtr& type_ptr,
    std::string default_type_str) {
  if (type_ptr->kind() == TypeKind::TupleType) {
    // For simple tuple types like (Tensor, Tensor), the mobile type parser
    // can parse them directly and the compilation unit won't have a
    // definition for them. The default type string is returned instead.
    auto named_tuple_ptr = compilation_unit.get_named_tuple(type_ptr->str());
    if (named_tuple_ptr != nullptr) {
      std::string named_tuple_str = type_ptr->str();
      named_tuple_str.append("[NamedTuple, [");
      std::vector<IValue> name_type_pairs;

      // Get the field name and field type for the NamedTuple
      for (auto it = named_tuple_ptr->schema()->arguments().begin();
           it != named_tuple_ptr->schema()->arguments().end();
           it++) {
        const std::string named_tuple_name = it->name();
        const c10::TypePtr& named_tuple_type = it->type();
        // When it->type() is a Tensor type: in Python, if the type is
        // inferred, str() returns "Tensor" and repr_str() returns
        // "Tensor (inferred)"; if it is not inferred, str() returns
        // "Tensor[]" and repr_str() returns "Tensor". In C++, repr_str()
        // always returns "Tensor" regardless of inference. When exporting a
        // custom type in bytecode, "Tensor" is the preferred spelling for
        // deserializing the Tensor type.
        std::string named_tuple_type_str = it->is_inferred_type()
            ? named_tuple_type->str()
            : named_tuple_type->repr_str();
        // The field type can itself be a NamedTuple; parse it recursively to
        // get its string representation.
        named_tuple_type_str = get_named_tuple_str_or_default(
            compilation_unit, named_tuple_type, named_tuple_type_str);
        name_type_pairs.emplace_back(
            c10::ivalue::Tuple::create({it->name(), named_tuple_type_str}));

        named_tuple_str.append("[")
            .append(named_tuple_name)
            .append(", ")
            .append(named_tuple_type_str)
            .append("]");
        if (it != named_tuple_ptr->schema()->arguments().end() - 1) {
          named_tuple_str.append(",");
        }
      }
      named_tuple_str.append("]]");
      return named_tuple_str;
    }
  }
  return default_type_str;
}
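// For example, a plain Tuple[Tensor, Tensor] is not registered as a NamedTuple
// in the compilation unit, so the lookup above fails and default_type_str is
// returned unchanged.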

std::pair<IValue, IValue> getFunctionTuple(
    const CompilationUnit& compilation_unit,
    const mobile::Function& func,
    BackendDebugInfoRecorder& debug_info_recorder,
    TypeNameUniquer& type_name_uniquer_) {
  const auto& mobile_code = func.get_code();

  // instructions
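  // Each instruction is serialized as a 3-tuple (opcode name, X, N); X and N
  // are the two integer operands whose meaning depends on the opcode.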
  std::vector<IValue> instructions;
  instructions.reserve(mobile_code.instructions_.size());
  for (Instruction ins : mobile_code.instructions_) {
    instructions.emplace_back(to_tuple({toString(ins.op), ins.X, ins.N}));
  }

  // operators
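  // An operator is serialized either as (name, overload_name) or, once
  // default values for unspecified args are no longer emitted (bytecode v6+),
  // as (name, overload_name, num_specified_args).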
  std::vector<IValue> operators;
  operators.reserve(mobile_code.op_names_.size());
  for (const auto i : c10::irange(mobile_code.op_names_.size())) {
    const auto& opname = mobile_code.op_names_[i];
    const int size = mobile_code.operator_input_sizes_[i];
    if (BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled()) {
      operators.emplace_back(to_tuple({opname.name, opname.overload_name}));
    } else {
      operators.emplace_back(
          to_tuple({opname.name, opname.overload_name, size}));
    }
  }

  // types
  std::vector<IValue> types;
  types.reserve(mobile_code.types_.size());
  static const std::string torch_prefix("__torch__");
  static const std::string class_prefix("__torch__.torch.classes");

  for (const TypePtr& ty : mobile_code.types_) {
    auto t = ty;
    if (auto dyn = t->castRaw<c10::DynamicType>()) {
      t = dyn->fallback();
    }
    std::string type_str = t->annotation_str();
    if (t->kind() == TypeKind::DictType) {
      // For a DictType, t->containedTypes() holds two items: the key type
      // first and the value type second. Either of them can be a NamedTuple
      // type.
      const TypePtr& key_type = t->containedTypes()[0];
      const TypePtr& value_type = t->containedTypes()[1];
      std::string key_type_str = get_named_tuple_str_or_default(
          compilation_unit, key_type, key_type->annotation_str());
      std::string value_type_str = get_named_tuple_str_or_default(
          compilation_unit, value_type, value_type->annotation_str());

      // Construct the dict representation once both the key and the value
      // have a correct string representation, e.g.
      // "Dict[str,__torch__.dper3.core.pytorch_schema_utils.IdScoreListFeatureTuple[NamedTuple,
      // [[lengths, Tensor],[values,
      // __torch__.dper3.core.pytorch_schema_utils.IdScoreTuple[NamedTuple,
      // [[ids, Tensor],[scores, Tensor]]]],[offsets, Optional[Tensor]]]]]"
      std::string dict_str;
      dict_str.append("Dict[")
          .append(key_type_str)
          .append(",")
          .append(value_type_str)
          .append("]");
      types.emplace_back(dict_str);
      continue;
    } else if (t->kind() == TypeKind::TupleType) {
      std::string named_tuple_str =
          get_named_tuple_str_or_default(compilation_unit, t, type_str);
      types.emplace_back(named_tuple_str);
      continue;
    } else if (type_str.find(torch_prefix) == 0) {
      TORCH_CHECK(
          type_str.find(class_prefix) == 0,
          "__torch__ types other than custom c++ classes (__torch__.torch.classes) "
          "are not supported in lite interpreter. ",
          "Workaround: instead of using arbitrary class type (class Foo()), ",
          "define a pytorch class (class Foo(torch.nn.Module)). The problematic type is: ",
          type_str);
    }
    types.emplace_back(type_str);
  }

  // Since the register location is embedded into the bytecode, pass the
  // register size along.
  auto register_size = static_cast<int>(mobile_code.register_size_);

  auto codeTable = Table(
      {{"instructions", to_tuple(instructions)},
       {"operators", to_tuple(operators)},
       {"constants", to_tuple(mobile_code.constants_)},
       {"types", to_tuple(types)},
       {"register_size", register_size}});

  // schema
  const auto& schema = func.getSchema();
  auto type_printer = [&](const c10::Type& t) -> std::optional<std::string> {
    auto namedType = t.cast<c10::NamedType>();
    if (namedType && namedType->name()) {
      return type_name_uniquer_.getUniqueName(namedType).qualifiedName();
    }
    return std::nullopt;
  };

  auto makeArgTuple = [&](const std::vector<Argument>& args) {
    std::vector<IValue> argTables;
    for (auto&& arg : args) {
      TORCH_CHECK(
          !arg.N(),
          "Arguments with known list lengths are not supported in mobile modules.");
      TORCH_CHECK(
          !arg.kwarg_only(),
          "Keyword-only arguments are not supported in mobile modules.");
      /*
        This part adds the argument's name, type, and default_value to
        `bytecode.pkl`. It has to be consistent with the `code/` directory,
        which holds the annotated Python code of the entire module.
        `type_printer` uses `TypeNameUniquer` to get the mangled name of the
        argument's type. This helps in having the right object reference when
        a class method is called using the `self` argument.

        arg.type()->annotation_str(type_printer) => mangled unique name of the
        module/submodule
      */
      auto arg_type = arg.type();
      if (auto dyn = arg_type->castRaw<c10::DynamicType>()) {
        arg_type = dyn->fallback();
      }
      argTables.emplace_back(Table({
          {"name", arg.name()},
          {"type", arg_type->annotation_str(type_printer)},
          {"default_value", arg.default_value()},
      }));
    }
    return to_tuple(argTables);
  };
  auto schemaTable = Table({
      {"arguments", makeArgTuple(schema.arguments())},
      {"returns", makeArgTuple(schema.returns())},
  });

  // function tuple
  std::string qn;
  if (func.name() == "__setstate__" || func.name() == "__getstate__") {
    auto classtype = func.getSchema().arguments()[0].type()->cast<ClassType>();
    TORCH_INTERNAL_ASSERT(
        classtype, "class is null ", func.qualname().qualifiedName());
    qn = c10::QualifiedName(
             type_name_uniquer_.getUniqueName(classtype), func.name())
             .qualifiedName();
  } else {
    qn = func.qualname().qualifiedName();
  }
  auto bytecode_vals = to_tuple({qn, codeTable, schemaTable});

  // Module debug info.
  // This is just a set of debug handles; we always save debug handles.
  // Debug handles generated by the debug_handle_manager correspond to
  // {source_range, inlinedCallStackPtr}, which we serialize separately.
  IValue module_debug_tuple =
      c10::ivalue::Tuple::create(mobile_code.debug_handles_);
  auto function_debug_info =
      Table({{"function_debug_handles", module_debug_tuple}});
  auto debug_info_vals = to_tuple({qn, function_debug_info});
  return std::make_pair(bytecode_vals, debug_info_vals);
}

void pushMobileFunctionsToIValues(
    const CompilationUnit& compilation_unit,
    const mobile::Module& module,
    std::vector<c10::IValue>& elements,
    std::vector<c10::IValue>& debugInfoElements,
    BackendDebugInfoRecorder& recorder,
    TypeNameUniquer& uniquer) {
  for (const auto& method : module.get_methods()) {
    auto tuple = getFunctionTuple(
        compilation_unit, method.function(), recorder, uniquer);
    elements.push_back(std::move(tuple.first));
    debugInfoElements.push_back(std::move(tuple.second));
  }
}

struct ModuleMethod {
  ModuleMethod(Module m, const GraphFunction& f, c10::QualifiedName n)
      : module(std::move(m)), function(f), exportName(std::move(n)) {}
  Module module;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const GraphFunction& function;
  c10::QualifiedName exportName;
};

bool isLoweredModule(const Module& m) {
  c10::QualifiedName type_name;
  if (m.type()->name()) {
    type_name = m.type()->name().value();
  }
  bool isLoweredModule = false;
  for (const auto& atom : type_name.atoms()) {
    if (atom == "LoweredModule") {
      isLoweredModule = true;
      break;
    }
  }
  return isLoweredModule;
}

// Check if the global static map of backend debug info contains debug info
// for this module or any of its children. If so, merge all such maps into
// debug_map.
void getBackendDebugInfoMap(
    const Module& m,
    BackendDebugInfoMapType& debug_map) {
  if (isLoweredModule(m)) {
    auto backend_debug_info =
        m.attr("__backend_debug_info").toCustomClass<PyTorchBackendDebugInfo>();
    const auto& map = backend_debug_info->getDebugInfoMap();
    if (map) {
      debug_map.insert(map.value().begin(), map.value().end());
    }
  }
  for (const auto& c : m.children()) {
    getBackendDebugInfoMap(c, debug_map);
  }
}

SourceRangeRecords getBackendSourceRanges(const Module& m) {
  SourceRangeRecords sr_records;
  if (isLoweredModule(m)) {
    constexpr size_t kSourceRange = 1;
    auto backend_debug_info =
        m.attr("__backend_debug_info").toCustomClass<PyTorchBackendDebugInfo>();
    const auto& map = backend_debug_info->getDebugInfoMap();
    if (map) {
      const auto& map_val = map.value();
      // This maps debug handles to DebugInfoTuple =
      // <source range, op name, inlined call stack ptr>.
      for (const auto& it : map_val) {
        auto& source_range =
            std::get<kDebugInfoTupleSourceRangeIndex>(it.second);
        sr_records.emplace_back(
            std::numeric_limits<size_t>::max(), source_range);
        const auto& cs_ptr = std::get<kDebugInfoTupleInlinedCSIndex>(it.second);
        if (cs_ptr) {
          for (const auto& e : cs_ptr->vec()) {
            const auto& sr = std::get<kSourceRange>(e);
            sr_records.emplace_back(std::numeric_limits<size_t>::max(), sr);
          }
        }
      }
    }
  }
  for (const auto& c : m.children()) {
    // Take the child records by value so std::move below can actually move
    // them (moving from a const reference would silently copy).
    auto child_sr_records = getBackendSourceRanges(c);
    sr_records.reserve(sr_records.size() + child_sr_records.size());
    std::move(
        child_sr_records.begin(),
        child_sr_records.end(),
        std::back_inserter(sr_records));
  }
  return sr_records;
}

// TODO: remove mobileInterfaceCallExport as it is no longer needed.
// This function was introduced to guard the usage of `InterfaceCall`, and
// the support for `InterfaceCall` should now be mature enough.
auto& mobileInterfaceCallExport() {
  static std::atomic<bool> flag{true};
  return flag;
}

} // namespace

TORCH_API void enableMobileInterfaceCallExport() {
  mobileInterfaceCallExport().store(true, std::memory_order_relaxed);
}

bool getMobileInterfaceCallExport() {
  return mobileInterfaceCallExport().load(std::memory_order_relaxed);
}

void SetExportModuleExtraFilesHook(ExportModuleExtraFilesHook hook) {
  GetExtraFilesHook() = std::move(hook);
}

void ScriptModuleSerializer::serialize(
    const Module& module,
    const ExtraFilesMap& extra_files,
    bool bytecode_format,
    bool save_mobile_debug_info) {
  C10_LOG_API_USAGE_ONCE("torch.jit.save");
  writeExtraFiles(module, extra_files);
  // Serialize the model object
  writeArchive(
      module._ivalue(),
      /*archive_name=*/"data",
      /*archive_dir=*/"",
      /*tensor_dir=*/"data/");
  // Then we serialize all code info.
  convertTypes(module.type());
  writeFiles("code/");
  // The tensor constants from the code are written to a separate archive
  // so loading the code does not depend on loading the data.
  std::vector<IValue> ivalue_constants(
      constant_table_.begin(), constant_table_.end());
  if (bytecode_format) {
    writeArchive(
        c10::ivalue::Tuple::create(ivalue_constants),
        /*archive_name=*/"constants",
        /*archive_dir=*/"",
        /*tensor_dir=*/"constants/",
        /*use_storage_context=*/true);

    writeByteCode(module, save_mobile_debug_info);
  } else {
    writeArchive(
        c10::ivalue::Tuple::create(ivalue_constants),
        /*archive_name=*/"constants",
        /*archive_dir=*/"",
        /*tensor_dir=*/"constants/");
  }
  if (!module.retrieve_traced_inputs().empty()) {
    writeArchive(
        module.retrieve_traced_inputs(),
        /*archive_name=*/"traced_inputs",
        /*archive_dir=*/"",
        /*tensor_dir=*/"traced_inputs/",
        /*use_storage_context=*/false,
        /*skip_tensor_data=*/true);
  }
  // Acquire and set the minimum (dynamic) version.
  for (auto& item : file_streams_) {
    writer_.setMinVersion(item.value().minVersion());
  }
}
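
// A sketch of the resulting zip layout (exact records depend on the flags
// above and on the module):
//   extra/<name>               -- user- and hook-provided extra files
//   data.pkl, data/*           -- the pickled module object and its tensors
//   code/*                     -- printed Python source plus *.debug_pkl ranges
//   constants.pkl, constants/* -- tensor constants referenced by the code
//   bytecode.pkl               -- mobile bytecode (only when bytecode_format)
//   traced_inputs.pkl          -- recorded example inputs, if any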

void ScriptModuleSerializer::writeArchive(
    const IValue& value,
    const std::string& archive_name,
    const std::string& archive_dir,
    const std::string& tensor_dir,
    bool use_storage_context,
    bool skip_tensor_data) {
  std::vector<char> data;
  // Vector to capture the run-time class types during pickling of the IValues.
  std::vector<c10::ClassTypePtr> memoizedClassTypes;
  std::vector<std::string> tensor_names;
  // Tensors that have already been serialized in the storage context.
  std::unordered_set<std::string> serialized_tensors;
  Pickler data_pickle(
      [&](const char* buf, size_t size) {
        data.insert(data.end(), buf, buf + size);
      },
      nullptr,
      [&](const c10::ClassTypePtr& t) {
        return type_name_uniquer_.getUniqueName(t);
      },
      &memoizedClassTypes,
      [&](const at::Tensor& tensor) {
        // Returns a string to use in pickler.cpp as the storage object key.
        if (use_storage_context) {
          bool already_serialized =
              storage_context_.hasStorage(tensor.storage());
          std::string tensor_name =
              std::to_string(
                  storage_context_.getOrAddStorage(tensor.storage())) +
              ".storage";
          if (already_serialized) {
            // This case is hit when the storage has already been serialized
            // from a torch.package context.
            serialized_tensors.insert(tensor_name);
          }
          tensor_names.push_back(tensor_name);
        } else {
          tensor_names.push_back(std::to_string(tensor_names.size()));
        }
        return tensor_names.back();
      });
  data_pickle.protocol();
  data_pickle.pushIValue(value);
  data_pickle.stop();
  // Write out the tensor data.
  size_t i = 0;

  TORCH_INTERNAL_ASSERT(tensor_names.size() == data_pickle.tensorData().size());

  for (const auto& td : data_pickle.tensorData()) {
    std::string tensor_name = tensor_names[i++];
    if (td.is_meta() || skip_tensor_data) {
      writer_.writeRecord(tensor_dir + tensor_name, nullptr, 0);
      continue;
    }
    WriteableTensorData writable_td = getWriteableTensorData(td);
    if (use_storage_context && serialized_tensors.count(tensor_name)) {
      // The storage has been serialized already; skip it.
      continue;
    }
    writer_.writeRecord(
        tensor_dir + tensor_name,
        writable_td.data(),
        writable_td.sizeInBytes());
  }

  std::string fname = archive_dir + archive_name + ".pkl";
  writer_.writeRecord(fname, data.data(), data.size());

  // Serialize all the captured run-time class types.
  for (const c10::ClassTypePtr& wroteType : memoizedClassTypes) {
    convertNamedType(wroteType);
  }
}

void ScriptModuleSerializer::writeExtraFiles(
    const Module& module,
    const ExtraFilesMap& extra_files) {
  // Write out extra files.
  for (const auto& kv : extra_files) {
    const std::string key = "extra/" + kv.first;
    writer_.writeRecord(key, kv.second.data(), kv.second.size());
  }
  auto hook = GetExtraFilesHook();
  if (hook) {
    ExtraFilesMap hook_files = hook(module);
    for (const auto& kv : hook_files) {
      // If the hooked file has already been written as an extra file,
      // skip it and warn.
      if (extra_files.find(kv.first) != extra_files.end()) {
        TORCH_WARN_ONCE(
            "An extra files hook attempted to write ",
            kv.first,
            " but ",
            "this is already written in extra files and so will be skipped. ",
            "This warning will only appear once per process.");
        continue;
      }
      const std::string key = "extra/" + kv.first;
      writer_.writeRecord(key, kv.second.data(), kv.second.size());
    }
  }
}

void ScriptModuleSerializer::updateSourceRangeTags(
    const SourceRangeRecords& ranges) {
  for (const auto& range : ranges) {
    if (source_range_tags_.find(range.range) == source_range_tags_.end()) {
      source_range_tags_[range.range] = current_source_range_tag_;
      current_source_range_tag_++;
    }
  }
}
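// The tags assigned above are what the source-range and inlined-callstack
// picklers use to refer back to a source range, so updateSourceRangeTags()
// must run over a record set before that set is pickled (see writeFiles and
// writeByteCode below).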

void ScriptModuleSerializer::convertTypes(const at::NamedTypePtr& root_type) {
  class_deps_.add(root_type);
  for (size_t i = 0; i < class_deps_.size(); ++i) {
    // Note: convertNamedType may extend class_deps_, so re-checking .size()
    // is necessary.
    convertNamedType(class_deps_[i]);
  }
}

void ScriptModuleSerializer::writeFiles(const std::string& code_dir) {
  current_source_range_tag_ = 0;
  // file_streams_ maps filename => src. We need this because multiple classes
  // may go in the same file (e.g. foo.bar.Baz and foo.bar.Qux).
  for (auto& item : file_streams_) {
    const std::string filename = qualifierToArchivePath(item.key(), code_dir);

    std::string src = item.value().str();

    // Only compress these records if they're not tiny: the CPU cost of
    // generating the zip data structures and compressing isn't well spent
    // on very small records.
    static constexpr size_t kMinToCompress = 200;

    writer_.writeRecord(
        filename,
        src.c_str(),
        src.size(),
        src.size() > kMinToCompress /*compress*/);

    // Write out the debug information.
    std::string debugFilename = filename + ".debug_pkl";
    SourceRangePickler source_range_pickler;
    updateSourceRangeTags(item.value().ranges());
    auto range_data =
        source_range_pickler.pickle(item.value().ranges(), source_range_tags_);
    writer_.writeRecord(
        debugFilename,
        range_data.data(),
        range_data.size(),
        range_data.size() > kMinToCompress /*compress*/);
  }
}

void ScriptModuleSerializer::writeByteCode(
    const Module& module,
    const bool save_mobile_debug_info) {
  std::vector<c10::IValue> elements;
  BackendDebugInfoRecorder debug_info_recorder;
  int64_t version_to_write = caffe2::serialize::kProducedBytecodeVersion;

  elements.emplace_back(static_cast<int64_t>(version_to_write));
  std::vector<c10::IValue> debug_info_elements;
  // Always save debug handles.
  debug_info_elements.emplace_back(static_cast<int64_t>(version_to_write));

  mobile::Module mobile_module =
      jitModuleToMobile(module, getOptionsFromGlobal());

  pushMobileFunctionsToIValues(
      *module._ivalue()->compilation_unit(),
      mobile_module,
      elements,
      debug_info_elements,
      debug_info_recorder,
      type_name_uniquer_);

  auto telements = to_tuple(std::move(elements));
  writeArchive(
      telements,
      /*archive_name=*/"bytecode",
      /*archive_dir=*/"",
      /*tensor_dir=*/"constants/",
      /*use_storage_context=*/true);

  auto debug_info_telements = to_tuple(std::move(debug_info_elements));

  // For now this feature is kept experimental, since we have not evaluated
  // how it affects model size and have not built any utility to strip off
  // the debug info when desired.
  // TODO: Build a utility to strip off the debug map. It should also do the
  // same for debug_pkl files.
  if (save_mobile_debug_info) {
    // Note that stripping off the debug map will not strip off debug handles.
    // We save debug handles conditionally so that we don't end up with a
    // model that has debug handles but no debug map to correlate the handles
    // with. Once we have a model with both handles and a debug map, we can
    // strip off the debug map and serve a lean model in production. If an
    // exception occurs, we still have a model with a debug map that can be
    // used to symbolicate the debug handles.
    writeArchive(
        debug_info_telements,
        /*archive_name=*/"mobile_debug_handles",
        /*archive_dir=*/"",
        /*tensor_dir=*/"mobile_debug_handles/");
    static constexpr size_t kMinToCompress = 200;
    // For delegated backends, get the source ranges that are in the debug
    // info map. Since delegated backends replace the original module with a
    // lowered module, we will not serialize the original module's code, which
    // is what would have contained the source ranges. Since we don't have
    // that anymore, extract the source ranges out of the delegated module and
    // store them in a separate archive. Note that we must do this first,
    // because appropriate source_range_tags must have been generated before
    // the inlined call stacks can be serialized.
    auto backend_source_range_records = getBackendSourceRanges(module);
    SourceRangePickler source_range_pickler;
    updateSourceRangeTags(backend_source_range_records);
    auto range_data = source_range_pickler.pickle(
        backend_source_range_records, source_range_tags_);
    std::string debugFilename = "delegated_backends.debug_pkl";
    writer_.writeRecord(
        debugFilename,
        range_data.data(),
        range_data.size(),
        range_data.size() > kMinToCompress /*compress*/);

    // For delegated backends, get the debug_info_map. This is merged with the
    // debug_info_map of the other modules that were not delegated.
    BackendDebugInfoMapType backend_debug_info_map;
    getBackendDebugInfoMap(module, backend_debug_info_map);
    // Now get the debug-handle-to-inlined-callstack-pointer map and serialize
    // it in a separate archive.
    const auto& debug_info = mobile_module.getDebugTable().getCallStackPtrMap();
    BackendDebugInfoMapType debug_handle_cs_ptr_map(
        debug_info.begin(), debug_info.end());
    CallStackDebugInfoPickler cs_debug_info_pickler;
    auto cs_data = cs_debug_info_pickler.pickle(
        debug_handle_cs_ptr_map, source_range_tags_);
    // Write out the map: [debug-handle, {source range, InlinedCallStack}].
    std::string filename = "callstack_debug_map.pkl";
    writer_.writeRecord(
        filename,
        cs_data.data(),
        cs_data.size(),
        cs_data.size() > kMinToCompress /*compress*/);
  }
}
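
// A sketch of the "bytecode" archive written above: element 0 is the bytecode
// version, followed by one (qualified name, code table, schema table) tuple
// per method. The parallel "mobile_debug_handles" archive carries one
// (qualified name, function_debug_handles table) tuple per method.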

namespace {

std::optional<std::string> type_printer(
    const c10::Type& type,
    torch::jit::TypeNameUniquer& type_name_uniquer) {
  if (auto dyn = type.castRaw<c10::DynamicType>()) {
    return dyn->fallback()->annotation_str(
        [&](auto&& t) { return type_printer(t, type_name_uniquer); });
  }
  auto namedType = type.cast<c10::NamedType>();
  if (namedType && namedType->name()) {
    return type_name_uniquer.getUniqueName(namedType).qualifiedName();
  }
  return std::nullopt;
}

} // namespace

void ScriptModuleSerializer::convertNamedType(
    const c10::NamedTypePtr& class_type) {
  if (converted_types_.count(class_type)) {
    return;
  }
  converted_types_.insert(class_type);
  auto qualname = type_name_uniquer_.getUniqueName(class_type);
  std::string qualifier = qualname.prefix();
  PythonPrint* pp = file_streams_.find(qualifier);

  if (!pp) {
    pp = &file_streams_.insert(
        std::move(qualifier),
        PythonPrint(
            constant_table_,
            class_deps_,
            [&](const c10::Type& t) {
              return type_printer(t, type_name_uniquer_);
            },
            /*enforce_importable=*/true));
  }
  pp->printNamedType(class_type);
}

void ScriptModuleSerializer::serialize_unified_format(
    Module& module,
    uint64_t script_module_id) {
  const std::string archive_dir =
      ".data/ts_code/" + std::to_string(script_module_id) + "/";

  // Serialize the model object.
  writeArchive(
      module._ivalue(),
      "data",
      archive_dir,
      /*tensor_dir=*/".data/",
      /*use_storage_context=*/true);
  // Then we serialize all code info.
  convertTypes(module.type());
  // The tensor constants from the code are written to a separate archive
  // so loading the code does not depend on loading the data.
  std::vector<IValue> ivalue_constants(
      constant_table_.begin(), constant_table_.end());
  writeArchive(
      c10::ivalue::Tuple::create(ivalue_constants),
      "constants",
      archive_dir,
      /*tensor_dir=*/".data/",
      /*use_storage_context=*/true);

  // Note: writeFiles() must also be called for the code to actually be
  // saved; this function saves the tensors.
}

SerializationStorageContext& ScriptModuleSerializer::storage_context() {
  return storage_context_;
}

void ExportModule(
    const Module& module,
    std::ostream& out,
    const ExtraFilesMap& extra_files,
    bool bytecode_format,
    bool save_mobile_debug_info,
    bool use_flatbuffer) {
  auto writer_func = [&](const void* buf, size_t nbytes) -> size_t {
    out.write(
        static_cast<const char*>(buf), static_cast<std::streamsize>(nbytes));
    return !out ? 0 : nbytes;
  };
  ExportModule(
      module,
      writer_func,
      extra_files,
      bytecode_format,
      save_mobile_debug_info,
      use_flatbuffer);
}

void ExportModule(
    const Module& module,
    const std::string& filename,
    const ExtraFilesMap& extra_files,
    bool bytecode_format,
    bool save_mobile_debug_info,
    bool use_flatbuffer) {
  if (!use_flatbuffer) {
    // The zip archive needs to know the file path.
    caffe2::serialize::PyTorchStreamWriter writer(filename);
    ScriptModuleSerializer serializer(writer);
    serializer.serialize(
        module, extra_files, bytecode_format, save_mobile_debug_info);
    return;
  }
  std::ofstream ofile;
  ofile.open(filename, std::ios::binary | std::ios::out);
  if (ofile.fail()) {
    std::stringstream message;
    if (errno == ENOENT) {
      message << "Parent directory of " << filename << " does not exist.\n";
    } else {
      message << "Error while opening file: " << errno << '\n';
    }
    TORCH_CHECK(false, message.str());
  }
  ExportModule(
      module,
      ofile,
      extra_files,
      bytecode_format,
      save_mobile_debug_info,
      use_flatbuffer);
}

void save_jit_module(
    const Module& module,
    const std::string& filename,
    const ExtraFilesMap& extra_files) {
  auto buffer = save_jit_module_to_bytes(module, extra_files);
  std::fstream ofile(filename, std::ios::binary | std::ios::out);
  ofile.write(
      reinterpret_cast<char*>(buffer->data()),
      static_cast<std::streamsize>(buffer->size()));
  ofile.close();
}

DetachedBuffer::UniqueDetachedBuffer save_jit_module_to_bytes(
    const Module& module,
    const ExtraFilesMap& extra_files) {
  ExtraFilesMap jitfiles;
  std::vector<IValue> constants;
  jitModuleToPythonCodeAndConstants(module, &jitfiles, &constants);
  CompilationOptions options = getOptionsFromGlobal();
  mobile::Module mobilem = jitModuleToMobile(module, options);
  return save_mobile_module_to_bytes(mobilem, extra_files, jitfiles, constants);
}

void save_jit_module_to_write_func(
    const Module& module,
    const ExtraFilesMap& extra_files,
    bool save_mobile_debug_info,
    const std::function<size_t(const void*, size_t)>& writer_func) {
  (void)save_mobile_debug_info;
  auto buffer = save_jit_module_to_bytes(module, extra_files);
  writer_func(reinterpret_cast<void*>(buffer->data()), buffer->size());
}

void ExportModule(
    const Module& module,
    const std::function<size_t(const void*, size_t)>& writer_func,
    const ExtraFilesMap& extra_files,
    bool bytecode_format,
    bool save_mobile_debug_info,
    bool use_flatbuffer) {
  if (use_flatbuffer) {
    save_jit_module_to_write_func(
        module, extra_files, save_mobile_debug_info, writer_func);
  } else {
    caffe2::serialize::PyTorchStreamWriter writer(writer_func);
    ScriptModuleSerializer serializer(writer);
    serializer.serialize(
        module, extra_files, bytecode_format, save_mobile_debug_info);
  }
}

namespace {
void export_opnames(const script::Module& m, std::set<std::string>& opnames) {
  mobile::Module mobile_m = jitModuleToMobile(m, getOptionsFromGlobal());
  for (const auto& method : mobile_m.get_methods()) {
    for (const auto& op : method.function().get_code().op_names_) {
      opnames.emplace(
          op.overload_name.empty() ? op.name
                                   : op.name + "." + op.overload_name);
    }
  }
}
} // namespace

std::vector<std::string> export_opnames(const script::Module& m) {
  std::set<std::string> names;
  export_opnames(m, names);
  return std::vector<std::string>(names.begin(), names.end());
}

// Thread-local flag (set only during export, i.e. on the server side)
// controlling whether instructions for bytecode default inputs are emitted.
// This is the major difference between bytecode v5 and v6.
thread_local bool emitBytecodeDefaultInputs =
    caffe2::serialize::kProducedBytecodeVersion <= 5;
bool BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled() {
  return emitBytecodeDefaultInputs;
}
void BytecodeEmitMode::set_default_value_for_unspecified_arg_enabled(
    bool enabled) {
  emitBytecodeDefaultInputs = enabled;
}

thread_local bool emitDefaultArgsWithOutArgs =
    caffe2::serialize::kProducedBytecodeVersion > 6;
bool BytecodeEmitMode::is_default_args_before_out_args_enabled() {
  return emitDefaultArgsWithOutArgs;
}
void BytecodeEmitMode::set_default_args_before_out_args_enabled(bool enabled) {
  emitDefaultArgsWithOutArgs = enabled;
}

thread_local bool emitDefaultEmitPromotedOps =
    caffe2::serialize::kProducedBytecodeVersion > 7;
bool BytecodeEmitMode::is_emit_promoted_ops_enabled() {
  return emitDefaultEmitPromotedOps;
}
void BytecodeEmitMode::set_default_emit_promoted_ops_enabled(bool enabled) {
  emitDefaultEmitPromotedOps = enabled;
}
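
// A minimal usage sketch (not part of the original source): to emit operators
// the way bytecode v5 did, flip the thread-local mode before exporting.
//
//   BytecodeEmitMode::set_default_value_for_unspecified_arg_enabled(true);
//   ExportModule(
//       module, "model.pt", /*extra_files=*/{}, /*bytecode_format=*/true);
//
// Each flag defaults to the behavior matching kProducedBytecodeVersion, so
// exports follow the current version unless explicitly overridden.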

} // namespace torch::jit