xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/serialization/export_module.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/serialization/export.h>
2 
3 #include <c10/util/Exception.h>
4 #include <torch/csrc/jit/api/function_impl.h>
5 #include <torch/csrc/jit/backends/backend_debug_handler.h>
6 #include <torch/csrc/jit/backends/backend_debug_info.h>
7 #include <torch/csrc/jit/frontend/source_range.h>
8 #include <torch/csrc/jit/ir/attributes.h>
9 #include <torch/csrc/jit/ir/ir.h>
10 #include <torch/csrc/jit/ir/type_hashing.h>
11 #include <torch/csrc/jit/mobile/function.h>
12 #include <torch/csrc/jit/mobile/interpreter.h>
13 #include <torch/csrc/jit/mobile/method.h>
14 #include <torch/csrc/jit/mobile/module.h>
15 #include <torch/csrc/jit/passes/inliner.h>
16 #include <torch/csrc/jit/runtime/instruction.h>
17 #include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
18 #include <torch/csrc/jit/serialization/export_bytecode.h>
19 #include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
20 #include <torch/csrc/jit/serialization/import_export_constants.h>
21 #include <torch/csrc/jit/serialization/import_export_functions.h>
22 #include <torch/csrc/jit/serialization/import_export_helpers.h>
23 #include <torch/csrc/jit/serialization/pickle.h>
24 #include <torch/csrc/jit/serialization/python_print.h>
25 #include <torch/csrc/jit/serialization/source_range_serialization.h>
26 #include <torch/csrc/jit/serialization/type_name_uniquer.h>
27 
28 #include <caffe2/serialize/inline_container.h>
29 
30 #include <ATen/ATen.h>
31 
32 #include <ATen/core/jit_type.h>
33 #include <ATen/core/qualified_name.h>
34 #include <cerrno>
35 #include <sstream>
36 #include <string>
37 #include <unordered_map>
38 #include <unordered_set>
39 #include <utility>
40 #include <vector>
41 
42 namespace torch::jit {
43 
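// Snapshot the current (thread-local) BytecodeEmitMode flags, the interface-call
// export flag, and the produced bytecode version into a CompilationOptions that
// is passed to jitModuleToMobile below.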
44 CompilationOptions getOptionsFromGlobal() {
45   CompilationOptions compilation_options;
46   compilation_options.enable_default_args_before_out_args =
47       BytecodeEmitMode::is_default_args_before_out_args_enabled();
48   compilation_options.enable_default_value_for_unspecified_arg =
49       BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled();
50   compilation_options.enable_emit_promoted_ops =
51       BytecodeEmitMode::is_emit_promoted_ops_enabled();
52   compilation_options.incl_interface_call = getMobileInterfaceCallExport();
53   compilation_options.model_version =
54       caffe2::serialize::kProducedBytecodeVersion;
55   return compilation_options;
56 }
57 
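// Small helpers for building the nested tuple structure stored in bytecode.pkl:
// to_tuple wraps a list of IValues in an ivalue::Tuple, and Table (below) builds
// a tuple of (name, value) pairs.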
58 static IValue to_tuple(std::initializer_list<IValue> ivalues) {
59   return c10::ivalue::Tuple::create(ivalues);
60 }
61 
62 IValue to_tuple(std::vector<IValue> ivalues) {
63   return c10::ivalue::Tuple::create(std::move(ivalues));
64 }
65 
66 IValue Table(const std::vector<std::pair<std::string, IValue>>& entries) {
67   std::vector<IValue> ivalue_entries;
68   ivalue_entries.reserve(entries.size());
69   for (const auto& e : entries) {
70     ivalue_entries.push_back(to_tuple({e.first, e.second}));
71   }
72   return to_tuple(std::move(ivalue_entries));
73 }
74 
75 namespace {
76 
77 ExportModuleExtraFilesHook& GetExtraFilesHook() {
78   static ExportModuleExtraFilesHook func = nullptr;
79   return func;
80 }
81 
82 /**
83  * If the type is not NamedTuple, it will return default_type_str. If the type
84  * is a NamedTuple, it will return a string with the following structure to
85  * describe the content of the NamedTuple: "qualified_name[ NamedTuple, [ [field_name_1,
86  * field_type_1], [field_name_2, field_type_2]
87  *   ]
88  * ]"
89  *  Example NamedTuple type:
90  *  "__torch__.base_models.sparse_nn.pytorch_preproc_types.PreprocOutputType[
91  *     NamedTuple, [
92  *         [float_features, Tensor],
93  *         [id_list_features, List[Tensor]],
94  *         [label,  Tensor],
95  *         [weight, Tensor],
96  *         ]
97  *     ]"
98  *
99  * @param compilation_unit Jit compilation unit to look up function schema.
100  * @param type_ptr A type pointer; it can be any type.
101  * @param default_type_str The default string representation. The string can
102  * come from type_ptr->str(), type_ptr->annotation_str(), or
103  * type_ptr->repr_str(). In some cases they differ: for example, a Tensor type
104  * can be rendered as "Tensor", "Tensor (inferred)" or "Tensor[]", and we only
105  * want "Tensor". It is passed in as an argument and returned unchanged when
106  * type_ptr is not a NamedTuple.
107  * @return string representation.
108  */
109 std::string get_named_tuple_str_or_default(
110     const CompilationUnit& compilation_unit,
111     const TypePtr& type_ptr,
112     std::string default_type_str) {
113   if (type_ptr->kind() == TypeKind::TupleType) {
114     // For simple tuple types like (Tensor, Tensor), the mobile type parser
115     // can parse them and the compilation unit won't have their definition.
116     // The default type string is returned instead.
117     if (compilation_unit.get_named_tuple(type_ptr->str())) {
118       auto named_tuple_ptr = compilation_unit.get_named_tuple(type_ptr->str());
119       if (named_tuple_ptr != nullptr) {
120         std::string named_tuple_str = type_ptr->str();
121         named_tuple_str.append("[NamedTuple, [");
122         std::vector<IValue> name_type_pairs;
123 
124         // Get the field name and field type for the NamedTuple
125         for (auto it = named_tuple_ptr->schema()->arguments().begin();
126              it != named_tuple_ptr->schema()->arguments().end();
127              it++) {
128           const std::string named_tuple_name = it->name();
129           const c10::TypePtr& named_tuple_type = it->type();
130           // When it->type() is a Tensor type in Python: if the type is
131           // inferred, str() returns "Tensor" and repr_str() returns
132           // "Tensor (inferred)"; if it is not inferred, str() returns
133           // "Tensor[]" and repr_str() returns "Tensor". In C++, repr_str()
134           // always returns "Tensor" regardless of inferred type. When exporting
135           // a custom type in bytecode, "Tensor" is the preferred form.
136           std::string named_tuple_type_str = it->is_inferred_type()
137               ? named_tuple_type->str()
138               : named_tuple_type->repr_str();
139           // The type can itself be a NamedTuple; parse it recursively to get
140           // its string representation.
141           named_tuple_type_str = get_named_tuple_str_or_default(
142               compilation_unit, named_tuple_type, named_tuple_type_str);
143           name_type_pairs.emplace_back(
144               c10::ivalue::Tuple::create({it->name(), named_tuple_type_str}));
145 
146           named_tuple_str.append("[")
147               .append(named_tuple_name)
148               .append(", ")
149               .append(named_tuple_type_str)
150               .append("]");
151           if (it != named_tuple_ptr->schema()->arguments().end() - 1) {
152             named_tuple_str.append(",");
153           }
154         }
155         named_tuple_str.append("]]");
156         return named_tuple_str;
157       }
158     }
159   }
160   return default_type_str;
161 }
162 
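// Serializes one mobile function into a pair of IValues: the bytecode entry
// (qualified name, a code table with instructions/operators/constants/types/
// register_size, and a schema table) and the matching debug-info entry
// (qualified name, function debug handles).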
163 std::pair<IValue, IValue> getFunctionTuple(
164     const CompilationUnit& compilation_unit,
165     const mobile::Function& func,
166     BackendDebugInfoRecorder& debug_info_recorder,
167     TypeNameUniquer& type_name_uniquer_) {
168   const auto& mobile_code = func.get_code();
169 
170   // instructions
171   std::vector<IValue> instructions;
172   instructions.reserve(mobile_code.instructions_.size());
173   for (Instruction ins : mobile_code.instructions_) {
174     instructions.emplace_back(to_tuple({toString(ins.op), ins.X, ins.N}));
175   }
176 
177   // operators
178   std::vector<IValue> operators;
179   operators.reserve(mobile_code.op_names_.size());
180   for (const auto i : c10::irange(mobile_code.op_names_.size())) {
181     const auto& opname = mobile_code.op_names_[i];
182     const int size = mobile_code.operator_input_sizes_[i];
183     if (BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled()) {
184       operators.emplace_back(to_tuple({opname.name, opname.overload_name}));
185     } else {
186       operators.emplace_back(
187           to_tuple({opname.name, opname.overload_name, size}));
188     }
189   }
190 
191   // types
192   std::vector<IValue> types;
193   types.reserve(mobile_code.types_.size());
194   static const std::string torch_prefix("__torch__");
195   static const std::string class_prefix("__torch__.torch.classes");
196 
197   for (const TypePtr& ty : mobile_code.types_) {
198     auto t = ty;
199     if (auto dyn = t->castRaw<c10::DynamicType>()) {
200       t = dyn->fallback();
201     }
202     std::string type_str = t->annotation_str();
203     if (t->kind() == TypeKind::DictType) {
204       // For DictType, t->containedTypes() has two items: the first is the key
205       // type and the second is the value type. Either of them can be a
206       // NamedTuple type.
207       const TypePtr& key_type = t->containedTypes()[0];
208       const TypePtr& value_type = t->containedTypes()[1];
209       std::string key_type_str = get_named_tuple_str_or_default(
210           compilation_unit, key_type, key_type->annotation_str());
211       std::string value_type_str = get_named_tuple_str_or_default(
212           compilation_unit, value_type, value_type->annotation_str());
213 
214       // Construct the dict representation once the correct string
215       // representations of both key and value have been obtained, like
216       // "Dict[str,__torch__.dper3.core.pytorch_schema_utils.IdScoreListFeatureTuple[NamedTuple,
217       // [[lengths, Tensor],[values,
218       // __torch__.dper3.core.pytorch_schema_utils.IdScoreTuple[NamedTuple,
219       // [[ids, Tensor],[scores, Tensor]]]],[offsets, Optional[Tensor]]]]]"
220       std::string dict_str;
221       dict_str.append("Dict[")
222           .append(key_type_str)
223           .append(",")
224           .append(value_type_str)
225           .append("]");
226       types.emplace_back(dict_str);
227       continue;
228     } else if (t->kind() == TypeKind::TupleType) {
229       std::string named_tuple_str =
230           get_named_tuple_str_or_default(compilation_unit, t, type_str);
231       types.emplace_back(named_tuple_str);
232       continue;
233     } else if (type_str.find(torch_prefix) == 0) {
234       TORCH_CHECK(
235           type_str.find(class_prefix) == 0,
236           "__torch__ types other than custom c++ classes (__torch__.torch.classes)"
237           "are not supported in lite interpreter. ",
238           "Workaround: instead of using arbitrary class type (class Foo()), ",
239           "define a pytorch class (class Foo(torch.nn.Module)). The problematic type is: ",
240           type_str);
241     }
242     types.emplace_back(type_str);
243   }
244 
245   // since the register location is embedded into the bytecode, pass the
246   // register size
247   auto register_size = static_cast<int>(mobile_code.register_size_);
248 
249   auto codeTable = Table(
250       {{"instructions", to_tuple(instructions)},
251        {"operators", to_tuple(operators)},
252        {"constants", to_tuple(mobile_code.constants_)},
253        {"types", to_tuple(types)},
254        {"register_size", register_size}});
255 
256   // schema
257   const auto& schema = func.getSchema();
258   auto type_printer = [&](const c10::Type& t) -> std::optional<std::string> {
259     auto namedType = t.cast<c10::NamedType>();
260     if (namedType && namedType->name()) {
261       return type_name_uniquer_.getUniqueName(namedType).qualifiedName();
262     }
263     return std::nullopt;
264   };
265 
266   auto makeArgTuple = [&](const std::vector<Argument>& args) {
267     std::vector<IValue> argTables;
268     for (auto&& arg : args) {
269       TORCH_CHECK(
270           !arg.N(),
271           "Arguments with known list lengths are not supported in mobile modules.");
272       TORCH_CHECK(
273           !arg.kwarg_only(),
274           "Keyword-only arguments are not supported in mobile modules.");
275       /*
276         This part adds the argument's name, type, and default_value to
277         `bytecode.pkl`. This has to be consistent with the `code/` directory,
278         which holds the annotated Python code of the entire module. `type_printer`
279         uses `TypeNameUniquer` to get the mangled name of the argument. This helps
280         in having the right object reference when a class method is called using
281         the `self` argument.
282 
283         arg.type()->annotation_str(type_printer) => mangled unique name of the
284         module/submodule
285       */
286       auto arg_type = arg.type();
287       if (auto dyn = arg_type->castRaw<c10::DynamicType>()) {
288         arg_type = dyn->fallback();
289       }
290       argTables.emplace_back(Table({
291           {"name", arg.name()},
292           {"type", arg_type->annotation_str(type_printer)},
293           {"default_value", arg.default_value()},
294       }));
295     }
296     return to_tuple(argTables);
297   };
298   auto schemaTable = Table({
299       {"arguments", makeArgTuple(schema.arguments())},
300       {"returns", makeArgTuple(schema.returns())},
301   });
302 
303   // function tuple
304   std::string qn;
305   if (func.name() == "__setstate__" || func.name() == "__getstate__") {
306     auto classtype = func.getSchema().arguments()[0].type()->cast<ClassType>();
307     TORCH_INTERNAL_ASSERT(
308         classtype, "class is null ", func.qualname().qualifiedName());
309     qn = c10::QualifiedName(
310              type_name_uniquer_.getUniqueName(classtype), func.name())
311              .qualifiedName();
312   } else {
313     qn = func.qualname().qualifiedName();
314   }
315   auto bytecode_vals = to_tuple({qn, codeTable, schemaTable});
316 
317   std::optional<IValue> debug_info_vals;
318   // module debug info
319   // This is just a set of debug handles.
320   // We always save debug handles.
321   // debug handles generated by debug_handle_manager
322   // will correspond to {source_range, inlinedCallStackPtr} which we will
323   // serialize separately.
324   IValue module_debug_tuple =
325       c10::ivalue::Tuple::create(mobile_code.debug_handles_);
326   auto function_debug_info =
327       Table({{"function_debug_handles", module_debug_tuple}});
328   debug_info_vals = to_tuple({qn, function_debug_info});
329   return std::make_pair(bytecode_vals, debug_info_vals);
330 }
331 
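// Appends the bytecode tuple and the debug-info tuple of every method in the
// mobile module to `elements` and `debugInfoElements`, respectively.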
332 void pushMobileFunctionsToIValues(
333     const CompilationUnit& compilation_unit,
334     const mobile::Module& module,
335     std::vector<c10::IValue>& elements,
336     std::vector<c10::IValue>& debugInfoElements,
337     BackendDebugInfoRecorder& recorder,
338     TypeNameUniquer& uniquer) {
339   for (const auto& method : module.get_methods()) {
340     auto tuple = getFunctionTuple(
341         compilation_unit, method.function(), recorder, uniquer);
342     elements.push_back(std::move(tuple.first));
343     debugInfoElements.push_back(std::move(tuple.second));
344   }
345 }
346 
347 struct ModuleMethod {
348   ModuleMethod(Module m, const GraphFunction& f, c10::QualifiedName n)
349       : module(std::move(m)), function(f), exportName(std::move(n)) {}
350   Module module;
351   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
352   const GraphFunction& function;
353   c10::QualifiedName exportName;
354 };
355 
356 bool isLoweredModule(const Module& m) {
357   c10::QualifiedName type_name;
358   if (m.type()->name()) {
359     type_name = m.type()->name().value();
360   }
361   bool isLoweredModule = false;
362   for (const auto& atom : type_name.atoms()) {
363     if (atom == "LoweredModule") {
364       isLoweredModule = true;
365       break;
366     }
367   }
368   return isLoweredModule;
369 }
370 
371 // Check if the global static map of backend debug info
372 // contains debug info for this module and any of its children.
373 // If so, combine all the maps together into debug_map.
374 void getBackendDebugInfoMap(
375     const Module& m,
376     BackendDebugInfoMapType& debug_map) {
377   if (isLoweredModule(m)) {
378     auto backend_debug_info =
379         m.attr("__backend_debug_info").toCustomClass<PyTorchBackendDebugInfo>();
380     const auto& map = backend_debug_info->getDebugInfoMap();
381     if (map) {
382       debug_map.insert(map.value().begin(), map.value().end());
383     }
384   }
385   for (const auto& c : m.children()) {
386     getBackendDebugInfoMap(c, debug_map);
387   }
388 }
389 
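// Recursively collect the source ranges recorded in the backend debug info of
// this module and its children; only LoweredModules carry __backend_debug_info.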
390 SourceRangeRecords getBackendSourceRanges(const Module& m) {
391   SourceRangeRecords sr_records;
392   if (isLoweredModule(m)) {
393     constexpr size_t kSourceRange = 1;
394     auto backend_debug_info =
395         m.attr("__backend_debug_info").toCustomClass<PyTorchBackendDebugInfo>();
396     const auto& map = backend_debug_info->getDebugInfoMap();
397     if (map) {
398       const auto& map_val = map.value();
399       // This map maps debug handles to DebugInfoTuples, where
400       // DebugInfoTuple = <source range, op name, inlined_cs_ptr>.
401       for (const auto& it : map_val) {
402         auto& source_range =
403             std::get<kDebugInfoTupleSourceRangeIndex>(it.second);
404         sr_records.emplace_back(
405             std::numeric_limits<size_t>::max(), source_range);
406         const auto& cs_ptr = std::get<kDebugInfoTupleInlinedCSIndex>(it.second);
407         if (cs_ptr) {
408           for (const auto& e : cs_ptr->vec()) {
409             const auto& sr = std::get<kSourceRange>(e);
410             sr_records.emplace_back(std::numeric_limits<size_t>::max(), sr);
411           }
412         }
413       }
414     }
415   }
416   for (const auto& c : m.children()) {
417     const auto& child_sr_records = getBackendSourceRanges(c);
418     sr_records.reserve(sr_records.size() + child_sr_records.size());
419     std::move(
420         child_sr_records.begin(),
421         child_sr_records.end(),
422         std::back_inserter(sr_records));
423   }
424   return sr_records;
425 }
426 
427 // TODO: remove mobileInterfaceCallExport as it is no longer needed.
428 // It was introduced to guard the use of `InterfaceCall`, and support for
429 // `InterfaceCall` is now mature enough that the guard is unnecessary.
430 auto& mobileInterfaceCallExport() {
431   static std::atomic<bool> flag{true};
432   return flag;
433 }
434 
435 } // namespace
436 
437 TORCH_API void enableMobileInterfaceCallExport() {
438   mobileInterfaceCallExport().store(true, std::memory_order_relaxed);
439 }
440 bool getMobileInterfaceCallExport() {
441   return mobileInterfaceCallExport().load(std::memory_order_relaxed);
442 }
443 
444 void SetExportModuleExtraFilesHook(ExportModuleExtraFilesHook hook) {
445   GetExtraFilesHook() = std::move(hook);
446 }
447 
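// Writes the TorchScript zip layout: the "data" archive for the module object,
// Python source under code/, the "constants" archive, the "bytecode" archive
// (plus optional mobile debug info) when bytecode_format is set, and a
// "traced_inputs" archive when traced inputs were recorded.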
448 void ScriptModuleSerializer::serialize(
449     const Module& module,
450     const ExtraFilesMap& extra_files,
451     bool bytecode_format,
452     bool save_mobile_debug_info) {
453   C10_LOG_API_USAGE_ONCE("torch.jit.save");
454   writeExtraFiles(module, extra_files);
455   // Serialize the model object
456   writeArchive(
457       module._ivalue(),
458       /*archive_name=*/"data",
459       /*archive_dir=*/"",
460       /*tensor_dir=*/"data/");
461   // Then we serialize all code info.
462   convertTypes(module.type());
463   writeFiles("code/");
464   // The tensor constants from the code are written to a separate archive
465   // so loading the code does not depend on loading the data
466   std::vector<IValue> ivalue_constants(
467       constant_table_.begin(), constant_table_.end());
468   if (bytecode_format) {
469     writeArchive(
470         c10::ivalue::Tuple::create(ivalue_constants),
471         /*archive_name=*/"constants",
472         /*archive_dir=*/"",
473         /*tensor_dir=*/"constants/",
474         /*use_storage_context=*/true);
475 
476     writeByteCode(module, save_mobile_debug_info);
477   } else {
478     writeArchive(
479         c10::ivalue::Tuple::create(ivalue_constants),
480         /*archive_name=*/"constants",
481         /*archive_dir=*/"",
482         /*tensor_dir=*/"constants/");
483   }
484   if (!module.retrieve_traced_inputs().empty()) {
485     writeArchive(
486         module.retrieve_traced_inputs(),
487         /*archive_name=*/"traced_inputs",
488         /*archive_dir=*/"",
489         /*tensor_dir=*/"traced_inputs/",
490         /*use_storage_context*/ false,
491         /*skip_tensor_data*/ true);
492   }
493   // Acquires and sets minimum (dynamic) version
494   for (auto& item : file_streams_) {
495     writer_.setMinVersion(item.value().minVersion());
496   }
497 }
498 
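// Pickles `value` into <archive_dir><archive_name>.pkl and writes each
// referenced tensor under `tensor_dir`. With use_storage_context, storages are
// named by their id in the shared SerializationStorageContext so storages that
// were already written (e.g. from a torch.package context) are not duplicated;
// skip_tensor_data writes empty records instead of the tensor bytes. Class
// types encountered while pickling are converted afterwards so their code is
// also emitted.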
499 void ScriptModuleSerializer::writeArchive(
500     const IValue& value,
501     const std::string& archive_name,
502     const std::string& archive_dir,
503     const std::string& tensor_dir,
504     bool use_storage_context,
505     bool skip_tensor_data) {
506   std::vector<char> data;
507   // Vector to capture the run-time class types during pickling the IValues
508   std::vector<c10::ClassTypePtr> memoizedClassTypes;
509   std::vector<std::string> tensor_names;
510   // tensors that are already serialized in use_storage_context
511   std::unordered_set<std::string> serialized_tensors;
512   Pickler data_pickle(
513       [&](const char* buf, size_t size) {
514         data.insert(data.end(), buf, buf + size);
515       },
516       nullptr,
517       [&](const c10::ClassTypePtr& t) {
518         return type_name_uniquer_.getUniqueName(t);
519       },
520       &memoizedClassTypes,
521       [&](const at::Tensor& tensor) {
522         // returns a string to use in pickler.cpp as the storage obj key
523         if (use_storage_context) {
524           bool already_serialized =
525               storage_context_.hasStorage(tensor.storage());
526           std::string tensor_name =
527               std::to_string(
528                   storage_context_.getOrAddStorage(tensor.storage())) +
529               ".storage";
530           if (already_serialized) {
531             // this case is hit when storage has been serialized already
532             // from a torch.package context
533             serialized_tensors.insert(tensor_name);
534           }
535           tensor_names.push_back(tensor_name);
536         } else {
537           tensor_names.push_back(std::to_string(tensor_names.size()));
538         }
539         return tensor_names.back();
540       });
541   data_pickle.protocol();
542   data_pickle.pushIValue(value);
543   data_pickle.stop();
544   // write out tensor data
545   size_t i = 0;
546 
547   TORCH_INTERNAL_ASSERT(tensor_names.size() == data_pickle.tensorData().size());
548 
549   for (const auto& td : data_pickle.tensorData()) {
550     std::string tensor_name = tensor_names[i++];
551     if (td.is_meta() || skip_tensor_data) {
552       writer_.writeRecord(tensor_dir + tensor_name, nullptr, 0);
553       continue;
554     }
555     WriteableTensorData writable_td = getWriteableTensorData(td);
556     if (use_storage_context && serialized_tensors.count(tensor_name)) {
557       // storage has been serialized already, skip
558       continue;
559     }
560     writer_.writeRecord(
561         tensor_dir + tensor_name,
562         writable_td.data(),
563         writable_td.sizeInBytes());
564   }
565 
566   std::string fname = archive_dir + archive_name + ".pkl";
567   writer_.writeRecord(fname, data.data(), data.size());
568 
569   // serialize all the captured run-time class types
570   for (const c10::ClassTypePtr& wroteType : memoizedClassTypes) {
571     convertNamedType(wroteType);
572   }
573 }
574 
575 void ScriptModuleSerializer::writeExtraFiles(
576     const Module& module,
577     const ExtraFilesMap& extra_files) {
578   // Write out extra files.
579   for (const auto& kv : extra_files) {
580     const std::string key = "extra/" + kv.first;
581     writer_.writeRecord(key, kv.second.data(), kv.second.size());
582   }
583   auto hook = GetExtraFilesHook();
584   if (hook) {
585     ExtraFilesMap hook_files = hook(module);
586     for (const auto& kv : hook_files) {
587       // Checks if the hooked file is already written in extra files,
588       //   if so, skips it and warns
589       if (extra_files.find(kv.first) != extra_files.end()) {
590         TORCH_WARN_ONCE(
591             "An extra files hook attempted to write ",
592             kv.first,
593             " but ",
594             "this is already written in extra files and so will be skipped. ",
595             "This warning will only appear once per process.");
596         continue;
597       }
598       const std::string key = "extra/" + kv.first;
599       writer_.writeRecord(key, kv.second.data(), kv.second.size());
600     }
601   }
602 }
603 
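// Assigns a fresh, monotonically increasing tag to every source range that has
// not been seen before; the tags are later used to reference ranges from the
// pickled callstack debug info.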
604 void ScriptModuleSerializer::updateSourceRangeTags(
605     const SourceRangeRecords& ranges) {
606   for (const auto& range : ranges) {
607     if (source_range_tags_.find(range.range) == source_range_tags_.end()) {
608       source_range_tags_[range.range] = current_source_range_tag_;
609       current_source_range_tag_++;
610     }
611   }
612 }
613 
614 void ScriptModuleSerializer::convertTypes(const at::NamedTypePtr& root_type) {
615   class_deps_.add(root_type);
616   for (size_t i = 0; i < class_deps_.size(); ++i) {
617     // note: convertNamedType may extend class_deps_, so re-checking .size() is
618     // necessary
619     convertNamedType(class_deps_[i]);
620   }
621 }
622 
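// Writes the accumulated Python source under `code_dir`, one record per
// qualifier, plus a .debug_pkl record with the pickled source ranges for each
// file; records above kMinToCompress bytes are compressed.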
623 void ScriptModuleSerializer::writeFiles(const std::string& code_dir) {
624   current_source_range_tag_ = 0;
625   // Mapping of filename => src. We need this because multiple classes may go
626   // in the same file (e.g. foo.bar.Baz and foo.bar.Qux)
627   for (auto& item : file_streams_) {
628     const std::string filename = qualifierToArchivePath(item.key(), code_dir);
629 
630     std::string src = item.value().str();
631 
632     // Only compress these records if they're not tiny.
633     // The cpu cost of generating zip datastructs and compressing isn't
634     // well-spent for very small records.
635     static constexpr size_t kMinToCompress = 200;
636 
637     writer_.writeRecord(
638         filename,
639         src.c_str(),
640         src.size(),
641         src.size() > kMinToCompress /*compress*/);
642 
643     // Write out the debug information
644     std::string debugFilename = filename + ".debug_pkl";
645     SourceRangePickler source_range_pickler;
646     updateSourceRangeTags(item.value().ranges());
647     auto range_data =
648         source_range_pickler.pickle(item.value().ranges(), source_range_tags_);
649     writer_.writeRecord(
650         debugFilename,
651         range_data.data(),
652         range_data.size(),
653         range_data.size() > kMinToCompress /*compress*/);
654   }
655 }
656 
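// Builds the "bytecode" archive: element 0 is the bytecode version, followed by
// one tuple per method as produced by getFunctionTuple. When
// save_mobile_debug_info is set, the matching debug handles are written to the
// "mobile_debug_handles" archive together with delegated-backend source ranges
// and the callstack debug map.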
657 void ScriptModuleSerializer::writeByteCode(
658     const Module& module,
659     const bool save_mobile_debug_info) {
660   std::vector<c10::IValue> elements;
661   BackendDebugInfoRecorder debug_info_recorder;
662   int64_t version_to_write = caffe2::serialize::kProducedBytecodeVersion;
663 
664   elements.emplace_back(static_cast<int64_t>(version_to_write));
665   std::vector<c10::IValue> debug_info_elements;
666   // Always save debug handles
667   debug_info_elements.emplace_back(static_cast<int64_t>(version_to_write));
668 
669   mobile::Module mobile_module =
670       jitModuleToMobile(module, getOptionsFromGlobal());
671 
672   pushMobileFunctionsToIValues(
673       *module._ivalue()->compilation_unit(),
674       mobile_module,
675       elements,
676       debug_info_elements,
677       debug_info_recorder,
678       type_name_uniquer_);
679 
680   auto telements = to_tuple(std::move(elements));
681   writeArchive(
682       telements,
683       /*archive_name=*/"bytecode",
684       /*archive_dir=*/"",
685       /*tensor_dir=*/"constants/",
686       /*use_storage_context=*/true);
687 
688   auto debug_info_telements = to_tuple(std::move(debug_info_elements));
689 
690   // At the moment this feature is kept experimental,
691   // since we have not evaluated how it affects model size
692   // and we have not built any utility to strip off debug info
693   // when desired.
694   // TODO: Build utility to strip off debug map. It should also do the
695   // same for debug_pkl files
696   if (save_mobile_debug_info) {
697     // Note that stripping off the debug map will not strip off
698     // debug handles.
699     // The reason we save debug handles conditionally is so that
700     // we don't end up with a model that has debug handles but no
701     // debug map to correlate the debug handles with.
702     // Once we have a model with both handles and debug map, we can
703     // strip off the debug map and serve a lean model to production.
704     // If an exception occurs, we have a model with a debug map that can
705     // be used to symbolicate debug handles.
706     writeArchive(
707         debug_info_telements,
708         /*archive_name=*/"mobile_debug_handles",
709         /*archive_dir=*/"",
710         /*tensor_dir=*/"mobile_debug_handles/");
711     static constexpr size_t kMinToCompress = 200;
712     // For delegated backends, get the source ranges that are in the debug info
713     // map. Since a delegated backend replaces the original module with a lowered
714     // module, we will not serialize the original module's code, which is what
715     // would have contained the source ranges. Since we don't have that anymore,
716     // extract the source ranges out of the delegated module and store them in a
717     // separate archive. Note that we must do this first because, in order to
718     // serialize inlined CS, appropriate source_range_tags must have been generated.
719     auto backend_source_range_records = getBackendSourceRanges(module);
720     SourceRangePickler source_range_pickler;
721     updateSourceRangeTags(backend_source_range_records);
722     auto range_data = source_range_pickler.pickle(
723         backend_source_range_records, source_range_tags_);
724     std::string debugFilename = "delegated_backends.debug_pkl";
725     writer_.writeRecord(
726         debugFilename,
727         range_data.data(),
728         range_data.size(),
729         range_data.size() > kMinToCompress /*compress*/);
730 
731     // For delegated backends, get the debug_info_map.
732     // This is merged with the debug_info_maps of other modules
733     // which were not delegated.
734     BackendDebugInfoMapType backend_debug_info_map;
735     getBackendDebugInfoMap(module, backend_debug_info_map);
736     // Now get the debug-handles-to-inlined-cs-ptr-map
737     // And serialize that in a separate archive
738     const auto& debug_info = mobile_module.getDebugTable().getCallStackPtrMap();
739     BackendDebugInfoMapType debug_handle_cs_ptr_map(
740         debug_info.begin(), debug_info.end());
741     CallStackDebugInfoPickler cs_debug_info_pickler;
742     auto cs_data = cs_debug_info_pickler.pickle(
743         debug_handle_cs_ptr_map, source_range_tags_);
744     // Write out map: [debug-handle, {source range, InlinedCallStack}]
745     std::string filename = "callstack_debug_map.pkl";
746     writer_.writeRecord(
747         filename,
748         cs_data.data(),
749         cs_data.size(),
750         cs_data.size() > kMinToCompress /*compress*/);
751   }
752 }
753 
754 namespace {
755 
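// Maps a type to the unique mangled name used in the printed source; a
// DynamicType is resolved to its fallback first. Returning nullopt lets the
// default printer handle the type.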
756 std::optional<std::string> type_printer(
757     const c10::Type& type,
758     torch::jit::TypeNameUniquer& type_name_uniquer) {
759   if (auto dyn = type.castRaw<c10::DynamicType>()) {
760     return dyn->fallback()->annotation_str(
761         [&](auto&& t) { return type_printer(t, type_name_uniquer); });
762   }
763   auto namedType = type.cast<c10::NamedType>();
764   if (namedType && namedType->name()) {
765     return type_name_uniquer.getUniqueName(namedType).qualifiedName();
766   }
767   return std::nullopt;
768 }
769 
770 } // namespace
771 
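// Prints the Python source of a named type (once) into the PythonPrint stream
// for its qualifier; all types sharing a qualifier end up in the same file
// under code/.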
772 void ScriptModuleSerializer::convertNamedType(
773     const c10::NamedTypePtr& class_type) {
774   if (converted_types_.count(class_type)) {
775     return;
776   }
777   converted_types_.insert(class_type);
778   auto qualname = type_name_uniquer_.getUniqueName(class_type);
779   std::string qualifier = qualname.prefix();
780   PythonPrint* pp = file_streams_.find(qualifier);
781 
782   if (!pp) {
783     pp = &file_streams_.insert(
784         std::move(qualifier),
785         PythonPrint(
786             constant_table_,
787             class_deps_,
788             [&](const c10::Type& t) {
789               return type_printer(t, type_name_uniquer_);
790             },
791             /*enforce_importable=*/true));
792   }
793   pp->printNamedType(class_type);
794 }
795 
796 void ScriptModuleSerializer::serialize_unified_format(
797     Module& module,
798     uint64_t script_module_id) {
799   const std::string archive_dir =
800       ".data/ts_code/" + std::to_string(script_module_id) + "/";
801 
802   // Serialize the model object
803   writeArchive(
804       module._ivalue(),
805       "data",
806       archive_dir,
807       /*tensor_dir=*/".data/",
808       /*use_storage_context=*/true);
809   // Then we serialize all code info.
810   convertTypes(module.type());
811   // The tensor constants from the code are written to a separate archive
812   // so loading the code does not depend on loading the data
813   std::vector<IValue> ivalue_constants(
814       constant_table_.begin(), constant_table_.end());
815   writeArchive(
816       c10::ivalue::Tuple::create(ivalue_constants),
817       "constants",
818       archive_dir,
819       /*tensor_dir=*/".data/",
820       /*use_storage_context=*/true);
821 
822   // Note: writeFiles() needs to be called in addition to this function for
823   // the code to actually be saved (tensors are saved here).
824 }
825 
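// Accessor for the storage context shared with callers so tensor storages can
// be deduplicated across archives (see the torch.package note in writeArchive).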
826 SerializationStorageContext& ScriptModuleSerializer::storage_context() {
827   return storage_context_;
828 }
829 
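// Illustrative call (a sketch only; the declarations and any default arguments
// live in torch/csrc/jit/serialization/export.h, and "my_module" / "model.pt"
// are placeholders):
//   Module my_module = ...;  // a scripted module
//   ExportModule(my_module, "model.pt",
//                /*extra_files=*/{},
//                /*bytecode_format=*/true,
//                /*save_mobile_debug_info=*/false,
//                /*use_flatbuffer=*/false);  // ZIP path; true takes the
//                                            // flatbuffer path instead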
830 void ExportModule(
831     const Module& module,
832     std::ostream& out,
833     const ExtraFilesMap& extra_files,
834     bool bytecode_format,
835     bool save_mobile_debug_info,
836     bool use_flatbuffer) {
837   auto writer_func = [&](const void* buf, size_t nbytes) -> size_t {
838     out.write(
839         static_cast<const char*>(buf), static_cast<std::streamsize>(nbytes));
840     return !out ? 0 : nbytes;
841   };
842   ExportModule(
843       module,
844       writer_func,
845       extra_files,
846       bytecode_format,
847       save_mobile_debug_info,
848       use_flatbuffer);
849 }
850 
851 void ExportModule(
852     const Module& module,
853     const std::string& filename,
854     const ExtraFilesMap& extra_files,
855     bool bytecode_format,
856     bool save_mobile_debug_info,
857     bool use_flatbuffer) {
858   if (!use_flatbuffer) {
859     // the zip archive needs to know the file path
860     caffe2::serialize::PyTorchStreamWriter writer(filename);
861     ScriptModuleSerializer serializer(writer);
862     serializer.serialize(
863         module, extra_files, bytecode_format, save_mobile_debug_info);
864     return;
865   }
866   std::ofstream ofile;
867   ofile.open(filename, std::ios::binary | std::ios::out);
868   if (ofile.fail()) {
869     std::stringstream message;
870     if (errno == ENOENT) {
871       message << "Parent directory of " << filename << " does not exist.\n";
872     } else {
873       message << "Error while opening file: " << errno << '\n';
874     }
875     TORCH_CHECK(false, message.str());
876   }
877   ExportModule(
878       module,
879       ofile,
880       extra_files,
881       bytecode_format,
882       save_mobile_debug_info,
883       use_flatbuffer);
884 }
885 
886 void save_jit_module(
887     const Module& module,
888     const std::string& filename,
889     const ExtraFilesMap& extra_files) {
890   auto buffer = save_jit_module_to_bytes(module, extra_files);
891   std::fstream ofile(filename, std::ios::binary | std::ios::out);
892   ofile.write(
893       reinterpret_cast<char*>(buffer->data()),
894       static_cast<std::streamsize>(buffer->size()));
895   ofile.close();
896 }
897 
898 DetachedBuffer::UniqueDetachedBuffer save_jit_module_to_bytes(
899     const Module& module,
900     const ExtraFilesMap& extra_files) {
901   ExtraFilesMap jitfiles;
902   std::vector<IValue> constants;
903   jitModuleToPythonCodeAndConstants(module, &jitfiles, &constants);
904   CompilationOptions options = getOptionsFromGlobal();
905   mobile::Module mobilem = jitModuleToMobile(module, options);
906   return save_mobile_module_to_bytes(mobilem, extra_files, jitfiles, constants);
907 }
908 
909 void save_jit_module_to_write_func(
910     const Module& module,
911     const ExtraFilesMap& extra_files,
912     bool save_mobile_debug_info,
913     const std::function<size_t(const void*, size_t)>& writer_func) {
914   (void)save_mobile_debug_info;
915   auto buffer = save_jit_module_to_bytes(module, extra_files);
916   writer_func(reinterpret_cast<void*>(buffer->data()), buffer->size());
917 }
918 
919 void ExportModule(
920     const Module& module,
921     const std::function<size_t(const void*, size_t)>& writer_func,
922     const ExtraFilesMap& extra_files,
923     bool bytecode_format,
924     bool save_mobile_debug_info,
925     bool use_flatbuffer) {
926   if (use_flatbuffer) {
927     save_jit_module_to_write_func(
928         module, extra_files, save_mobile_debug_info, writer_func);
929   } else {
930     caffe2::serialize::PyTorchStreamWriter writer(writer_func);
931     ScriptModuleSerializer serializer(writer);
932     serializer.serialize(
933         module, extra_files, bytecode_format, save_mobile_debug_info);
934   }
935 }
936 
937 namespace {
938 void export_opnames(const script::Module& m, std::set<std::string>& opnames) {
939   mobile::Module mobile_m = jitModuleToMobile(m, getOptionsFromGlobal());
940   for (const auto& method : mobile_m.get_methods()) {
941     for (const auto& op : method.function().get_code().op_names_) {
942       opnames.emplace(
943           op.overload_name.empty() ? op.name
944                                    : op.name + "." + op.overload_name);
945     }
946   }
947 }
948 } // namespace
949 
950 std::vector<std::string> export_opnames(const script::Module& m) {
951   std::set<std::string> names;
952   export_opnames(m, names);
953   return std::vector<std::string>(names.begin(), names.end());
954 }
955 
956 // Thread-local flag (only relevant during export, i.e. on the server side)
957 // controlling whether instructions for bytecode default inputs are emitted
958 // or not. It's the major difference between bytecode v5 and v6.
959 thread_local bool emitBytecodeDefaultInputs =
960     caffe2::serialize::kProducedBytecodeVersion <= 5;
961 bool BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled() {
962   return emitBytecodeDefaultInputs;
963 }
964 void BytecodeEmitMode::set_default_value_for_unspecified_arg_enabled(
965     bool enabled) {
966   emitBytecodeDefaultInputs = enabled;
967 }
968 
969 thread_local bool emitDefaultArgsWithOutArgs =
970     caffe2::serialize::kProducedBytecodeVersion > 6;
971 bool BytecodeEmitMode::is_default_args_before_out_args_enabled() {
972   return emitDefaultArgsWithOutArgs;
973 }
974 void BytecodeEmitMode::set_default_args_before_out_args_enabled(bool enabled) {
975   emitDefaultArgsWithOutArgs = enabled;
976 }
977 
978 thread_local bool emitDefaultEmitPromotedOps =
979     caffe2::serialize::kProducedBytecodeVersion > 7;
980 bool BytecodeEmitMode::is_emit_promoted_ops_enabled() {
981   return emitDefaultEmitPromotedOps;
982 }
983 void BytecodeEmitMode::set_default_emit_promoted_ops_enabled(bool enabled) {
984   emitDefaultEmitPromotedOps = enabled;
985 }
986 
987 } // namespace torch::jit
988