xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/serialization/flatbuffer_serializer.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
2 
3 #ifdef FLATBUFFERS_VERSION_MAJOR
4 #error "flatbuffer_serializer.h must not include any flatbuffers headers"
5 #endif // FLATBUFFERS_VERSION_MAJOR
6 
7 #include <fstream>
8 #include <functional>
9 #include <string>
10 #include <unordered_map>
11 #include <utility>
12 #include <vector>
13 
14 #include <ATen/ATen.h>
15 #include <c10/core/CPUAllocator.h>
16 #include <c10/util/Exception.h>
17 #include <caffe2/serialize/versions.h>
18 #include <torch/csrc/jit/mobile/code.h>
19 #include <torch/csrc/jit/mobile/train/export_data.h>
20 #include <torch/csrc/jit/passes/inliner.h>
21 #include <torch/csrc/jit/runtime/instruction.h>
22 
23 #if defined(FB_XPLAT_BUILD) || defined(FBCODE_CAFFE2)
24 #include <torch/csrc/jit/serialization/mobile_bytecode_generated_fbsource.h> // NOLINT
25 namespace flatbuffers = flatbuffers_fbsource;
26 #define FLATBUFFERS_MAX_ALIGNMENT FLATBUFFERS_FBSOURCE_MAX_ALIGNMENT
27 #else
28 #include <torch/csrc/jit/serialization/mobile_bytecode_generated.h> // NOLINT
29 #endif
30 
31 namespace torch::jit {
32 
33 using flatbuffers::FlatBufferBuilder;
34 using mobile::serialization::CreateArg;
35 using mobile::serialization::CreateDebugInfo;
36 using mobile::serialization::CreateDict;
37 using mobile::serialization::CreateFunctionDirect;
38 using mobile::serialization::CreateIValue;
39 using mobile::serialization::CreateList;
40 using mobile::serialization::CreateModule;
41 using mobile::serialization::CreateObject;
42 using mobile::serialization::CreateOperator;
43 using mobile::serialization::CreateTensorMetadataDirect;
44 using mobile::serialization::CreateTupleDirect;
45 
46 namespace {
47 
// TODO: remove once caffe2::kProducedBytecodeVersion is >= 9 and flatbuffer is
// launched.
// Minimum bytecode version stamped into a serialized flatbuffer module; a
// module reporting an older version is bumped up to this value on save (see
// serializeModule()).
constexpr uint32_t kMinVersion = 9;

// We will store IValue NONE in index 0 in flatbuffer.
// storeIValueAndGetIndex() returns this for every None, so all Nones share
// a single ivalue-table slot.
constexpr int kNoneIndex = 0;
54 
realType(TypePtr type)55 static TypePtr realType(TypePtr type) {
56   if (auto dyn = type->castRaw<c10::DynamicType>()) {
57     return dyn->fallback();
58   } else {
59     return type;
60   }
61 }
62 
print_type(const c10::Type & t)63 auto print_type(const c10::Type& t) -> std::optional<std::string> {
64   auto namedType = t.cast<c10::NamedType>();
65   if (namedType && namedType->name()) {
66     return namedType->name().value().qualifiedName();
67   }
68   if (auto dyn = t.castRaw<c10::DynamicType>()) {
69     return dyn->fallback()->annotation_str();
70   }
71   return std::nullopt;
72 }
73 
// Converts a mobile::Module into a flatbuffer-encoded module.
//
// The serializer is stateful: the store*AndGetIndex helpers append to the
// member tables below and memoize by qualified name / ivalue equality, so
// an instance is intended for a single serializeModule() call.
class FlatbufferSerializer {
 public:
  FlatbufferSerializer() = default;

  // Serializes `module` and returns the finished buffer. Tensor storages
  // are embedded in the buffer only when
  // `include_tensor_data_in_flatbuffer` is true; `extra_files`,
  // `jit_sources` and `jit_constants` are serialized alongside the mobile
  // bytecode.
  flatbuffers::DetachedBuffer serializeModule(
      const mobile::Module& module,
      bool include_tensor_data_in_flatbuffer,
      const ExtraFilesMap& extra_files = ExtraFilesMap(),
      const ExtraFilesMap& jit_sources = ExtraFilesMap(),
      const std::vector<IValue>& jit_constants = {});

 private:
  // Stores each ivalue in [begin, end) and returns the assigned
  // ivalue-table indexes in iteration order.
  template <typename It>
  std::vector<uint32_t> storeIValuesAndGetIndexes(
      flatbuffers::FlatBufferBuilder& fbb,
      It begin,
      It end) {
    std::vector<uint32_t> indexes;
    for (; begin != end; ++begin) {
      indexes.push_back(storeIValueAndGetIndex(fbb, *begin));
    }
    return indexes;
  }

  // Per-kind converters from an IValue to its flatbuffer table offset.
  flatbuffers::Offset<mobile::serialization::Tuple> tupleToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& tuple);

  flatbuffers::Offset<mobile::serialization::List> listToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& list);

  // NOTE: the parameter is the dict ivalue despite being named `list`.
  flatbuffers::Offset<mobile::serialization::Dict> dictToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& list);

  flatbuffers::Offset<mobile::serialization::Object> objectToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& ivalue);

  flatbuffers::Offset<mobile::serialization::TensorMetadata> tensorToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& ivalue);

  // Serializes one mobile function (instructions, operators, constants,
  // types, schema, debug handles) under the qualified name `qn`.
  flatbuffers::Offset<mobile::serialization::Function> functionToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const std::string& qn,
      const mobile::Function& func);

  // Dispatches on the ivalue's tag to the converters above.
  flatbuffers::Offset<mobile::serialization::IValue> iValueToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& ivalue);

  flatbuffers::Offset<jit::mobile::serialization::Schema> CreateFBSchema(
      flatbuffers::FlatBufferBuilder& fbb,
      const std::vector<Argument>& args,
      const std::vector<Argument>& returns,
      const c10::TypePrinter& type_printer);

  flatbuffers::Offset<mobile::serialization::ObjectType> classTypeToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const ClassTypePtr& class_ptr);

  // store*AndGetIndex helpers: serialize on first sight, memoize, and
  // return the index of the entity in its respective table.
  uint32_t storeIValueAndGetIndex(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& ivalue);
  uint32_t storeFunctionAndGetIndex(
      flatbuffers::FlatBufferBuilder& fbb,
      const std::string& qn,
      const mobile::Function& function);

  uint32_t storeClassTypeAndGetIndex(
      flatbuffers::FlatBufferBuilder& fbb,
      const ClassTypePtr& class_type);

  flatbuffers::Offset<flatbuffers::Vector<
      flatbuffers::Offset<mobile::serialization::ExtraFile>>>
  storeExtraFilesAndGetOffset(
      FlatBufferBuilder& fbb,
      const ExtraFilesMap& extra_files);

  // Appends an already-built IValue offset to the ivalue table and returns
  // its index.
  uint32_t insertIValue(
      flatbuffers::Offset<mobile::serialization::IValue> ivalue) {
    uint32_t size = ivalue_offsets_.size();
    ivalue_offsets_.push_back(ivalue);
    return size;
  }

  // Tensors whose storage bytes still need to be written out (done last,
  // in serializeModule); position here is the storage_location_index.
  std::vector<at::Tensor> tensor_data_;

  // StorageImpl pointer -> index into tensor_data_, so tensors sharing a
  // storage also share one serialized storage entry.
  std::unordered_map<const void*, uint32_t> memoized_storage_map_;

  std::vector<flatbuffers::Offset<mobile::serialization::IValue>>
      ivalue_offsets_;
  std::vector<flatbuffers::Offset<mobile::serialization::ObjectType>>
      obj_types_offset_;

  // qualified name to serialized class, type or function
  std::unordered_map<std::string, uint32_t> qn_to_serialized_values_;

  // cache of some ivalues
  struct IValueHash {
    size_t operator()(const IValue& val) const {
      return IValue::hash(val);
    }
  };

  struct IValueEqual {
    // Copy of this
    // https://www.internalfb.com/code/aros/[3b875bce7ffa2adacdcea9b3e0cb6d304737a193]/xros/third-party/caffe2/caffe2/aten/src/ATen/core/ivalue.cpp?lines=266
    // but without relying on aten::nonzero operator being present in the
    // binary.
    bool operator()(const IValue& lhs, const IValue& rhs) const {
      // The only case we don't return bool is for tensor comparison. Lets do
      // pointer comparison here.
      if (lhs.isTensor() || rhs.isTensor()) {
        if (lhs.isTensor() && rhs.isTensor()) {
          return (&lhs.toTensor()) == (&rhs.toTensor());
        }
        return false;
      }
      IValue eq = lhs.equals(rhs);
      if (eq.isBool()) {
        return eq.toBool();
      }
      return false;
    }
  };

  // Dedup cache: equal (hashable) ivalues map to one table index.
  std::unordered_map<IValue, uint32_t, IValueHash, IValueEqual> cached_ivalues_;
  // Compilation unit of the module being serialized; set by
  // serializeModule() and read by classTypeToFB().
  const mobile::CompilationUnit* mcu_ = nullptr;
};
206 
207 flatbuffers::Offset<jit::mobile::serialization::Schema> FlatbufferSerializer::
CreateFBSchema(flatbuffers::FlatBufferBuilder & fbb,const std::vector<Argument> & args,const std::vector<Argument> & returns,const c10::TypePrinter & type_printer)208     CreateFBSchema(
209         flatbuffers::FlatBufferBuilder& fbb,
210         const std::vector<Argument>& args,
211         const std::vector<Argument>& returns,
212         const c10::TypePrinter& type_printer) {
213   std::vector<flatbuffers::Offset<jit::mobile::serialization::Arg>> arg_vec;
214   arg_vec.reserve(args.size());
215   std::vector<flatbuffers::Offset<jit::mobile::serialization::Arg>> return_vec;
216   return_vec.reserve(returns.size());
217   for (const auto& arg : args) {
218     auto index = storeIValueAndGetIndex(fbb, arg.default_value());
219     arg_vec.emplace_back(CreateArg(
220         fbb,
221         fbb.CreateSharedString(arg.name()),
222         fbb.CreateSharedString(
223             realType(arg.type())->annotation_str(type_printer)),
224         index));
225   }
226 
227   for (const auto& ret : returns) {
228     auto index = storeIValueAndGetIndex(fbb, ret.default_value());
229     return_vec.emplace_back(CreateArg(
230         fbb,
231         fbb.CreateSharedString(ret.name()),
232         fbb.CreateSharedString(
233             realType(ret.type())->annotation_str(type_printer)),
234         index));
235   }
236   return CreateSchema(
237       fbb, fbb.CreateVector(arg_vec), fbb.CreateVector(return_vec));
238 }
239 
// Serializes one mobile function: instructions, operator table, constant
// ivalue indexes, type annotation strings, register size, (optional)
// schema and debug handles. Rejects __torch__ types other than custom C++
// classes, since the lite interpreter cannot resolve them.
flatbuffers::Offset<mobile::serialization::Function> FlatbufferSerializer::
    functionToFB(
        FlatBufferBuilder& fbb,
        const std::string& qn,
        const mobile::Function& func) {
  const auto& code = func.get_code();

  // instructions
  std::vector<mobile::serialization::Instruction> instruction_vector;
  instruction_vector.reserve(code.instructions_.size());
  for (const auto& inst : code.instructions_) {
    instruction_vector.emplace_back(inst.op, inst.N, inst.X);
  }

  // operators: name, overload name, and the declared input arity.
  std::vector<flatbuffers::Offset<mobile::serialization::Operator>>
      operator_vector;
  operator_vector.reserve(code.op_names_.size());
  for (const auto i : c10::irange(code.op_names_.size())) {
    const auto& opname = code.op_names_[i];
    const int op_size = code.operator_input_sizes_[i];
    operator_vector.push_back(CreateOperator(
        fbb,
        fbb.CreateSharedString(opname.name),
        fbb.CreateSharedString(opname.overload_name),
        op_size));
  }

  const auto& constants = code.constants_;

  // Constants are stored through the shared ivalue table; only their
  // indexes are kept on the function.
  std::vector<uint32_t> constant_indexes;
  constant_indexes.reserve(constants.size());
  for (const auto& constant : constants) {
    constant_indexes.push_back(storeIValueAndGetIndex(fbb, constant));
  }

  // types
  static const std::string torch_prefix("__torch__");
  static const std::string class_prefix("__torch__.torch.classes");
  std::vector<flatbuffers::Offset<flatbuffers::String>> type_offsets;

  for (const TypePtr& t : code.types_) {
    auto type_str = realType(t)->annotation_str();
    if (type_str.find(torch_prefix) == 0) {
      // Only custom C++ classes under __torch__.torch.classes are allowed.
      TORCH_CHECK(
          type_str.find(class_prefix) == 0,
          "__torch__ types other than custom c++ classes (__torch__.torch.classes)"
          "are not supported in lite interpreter. ",
          "Workaround: instead of using arbitrary class type (class Foo()), ",
          "define a pytorch class (class Foo(torch.nn.Module)).");
    }

    type_offsets.push_back(fbb.CreateSharedString(type_str));
  }

  // since the register location is embedded into the bytecode, pass the
  // register size
  auto register_size = static_cast<int>(code.register_size_);

  // schema
  // Local duplicate of print_type (same rendering rules).
  auto type_printer = [&](const c10::Type& t) -> std::optional<std::string> {
    auto namedType = t.cast<c10::NamedType>();
    if (namedType && namedType->name()) {
      return namedType->name().value().qualifiedName();
    }
    if (auto dyn = t.castRaw<c10::DynamicType>()) {
      return dyn->fallback()->annotation_str();
    }
    return std::nullopt;
  };

  flatbuffers::Offset<mobile::serialization::Schema> schema_offset = 0;
  uint32_t class_index = 0;
  if (func.hasSchema()) {
    const auto& schema = func.getSchema();
    TORCH_CHECK(
        schema.overload_name().empty(), // @TODO: is this check correct?
        "Overloads are not supported in mobile modules.");
    TORCH_CHECK(
        !schema.is_vararg(),
        "Python *args are not supported in mobile modules.");
    TORCH_CHECK(
        !schema.is_varret(),
        "A variable number of return values is not supported in mobile modules.");
    schema_offset =
        CreateFBSchema(fbb, schema.arguments(), schema.returns(), type_printer);
    // Argument 0 is the receiver; record its class in the type table.
    auto classtype = schema.arguments()[0].type()->cast<ClassType>();
    class_index = storeClassTypeAndGetIndex(fbb, classtype);
  }

  auto debug_info_offset =
      CreateDebugInfo(fbb, fbb.CreateVector(code.debug_handles_));

  auto function_offset = CreateFunctionDirect(
      fbb,
      qn.c_str(),
      &instruction_vector,
      &operator_vector,
      &constant_indexes,
      &type_offsets,
      register_size,
      schema_offset,
      debug_info_offset,
      class_index);
  return function_offset;
}
346 
347 flatbuffers::Offset<
348     flatbuffers::Vector<flatbuffers::Offset<mobile::serialization::ExtraFile>>>
storeExtraFilesAndGetOffset(FlatBufferBuilder & fbb,const ExtraFilesMap & extra_files)349 FlatbufferSerializer::storeExtraFilesAndGetOffset(
350     FlatBufferBuilder& fbb,
351     const ExtraFilesMap& extra_files) {
352   std::vector<flatbuffers::Offset<mobile::serialization::ExtraFile>>
353       extra_file_offsets;
354 
355   for (const auto& extra_file : extra_files) {
356     flatbuffers::Offset<mobile::serialization::ExtraFile> extra_file_offset =
357         mobile::serialization::CreateExtraFile(
358             fbb,
359             fbb.CreateSharedString(extra_file.first),
360             fbb.CreateString(extra_file.second));
361     extra_file_offsets.emplace_back(extra_file_offset);
362   }
363   return fbb.CreateVector(extra_file_offsets);
364 }
365 
// Top-level entry point: builds the flatbuffer Module table.
// Layout invariants established here:
//  - ivalue 0 is NONE (kNoneIndex),
//  - ivalues stored before jit constants belong to the mobile module
//    proper (mobile_ivalue_size marks that boundary),
//  - tensor storage bytes are written last so vector alignment can be
//    forced right before each storage.
flatbuffers::DetachedBuffer FlatbufferSerializer::serializeModule(
    const mobile::Module& module,
    bool include_tensor_data_in_flatbuffer,
    const ExtraFilesMap& extra_files,
    const ExtraFilesMap& jit_sources,
    const std::vector<IValue>& jit_constants) {
  FlatBufferBuilder fbb;

  mcu_ = &module.compilation_unit();

  // first element is None.
  insertIValue(CreateIValue(fbb, mobile::serialization::IValueUnion::NONE, 0));

  auto methods = module.get_methods();
  std::vector<uint32_t> functions_index;
  functions_index.reserve(methods.size());
  for (const auto& method : methods) {
    // Despite the name, this is an index into the ivalue table, not a raw
    // flatbuffer offset.
    auto func_offset = storeFunctionAndGetIndex(
        fbb, method.function().qualname().qualifiedName(), method.function());
    functions_index.push_back(func_offset);
  }

  auto functions_offset = fbb.CreateVector(functions_index);
  // Serialize the module object tree (recursively stores its attributes).
  uint32_t ivalue_index = storeIValueAndGetIndex(fbb, module._ivalue());

  flatbuffers::Offset<flatbuffers::Vector<
      flatbuffers::Offset<mobile::serialization::StorageData>>>
      storage_data_offset = 0;
  auto extra_files_offset = storeExtraFilesAndGetOffset(fbb, extra_files);

  auto jit_source_offset = storeExtraFilesAndGetOffset(fbb, jit_sources);
  std::vector<uint32_t> jit_constants_indexes;
  jit_constants_indexes.reserve(jit_constants.size());
  // Everything stored so far belongs to the mobile module; indexes at or
  // beyond this value are jit-only constants.
  const uint32_t mobile_ivalue_size = ivalue_offsets_.size();
  for (const auto& ival : jit_constants) {
    jit_constants_indexes.emplace_back(storeIValueAndGetIndex(fbb, ival));
  }
  const uint32_t operator_version =
      static_cast<uint32_t>(module.min_operator_version());
  uint32_t bytecode_version = static_cast<uint32_t>(module.bytecode_version());
  if (bytecode_version < kMinVersion) {
    // Never stamp a flatbuffer module older than kMinVersion.
    bytecode_version = kMinVersion;
  }

  // NOTE: saving of storage has to be the last thing to do.
  if (include_tensor_data_in_flatbuffer) {
    std::vector<flatbuffers::Offset<mobile::serialization::StorageData>>
        storage_data;
    for (auto td : tensor_data_) {
      if (td.storage().device_type() != DeviceType::CPU) {
        // Non-CPU storage: view the whole storage as a flat 1-D tensor and
        // copy it to CPU before embedding its bytes.
        td = at::empty({0}, td.options())
                 .set_(
                     td.storage(),
                     /* storage_offset = */ 0,
                     /* size = */
                     {static_cast<int64_t>(
                         td.storage().nbytes() / td.element_size())},
                     /* stride = */ {1})
                 .cpu();
      }
      // Align the upcoming byte vector — presumably so storages can be
      // used in place when the buffer is mapped; confirm with the loader.
      fbb.ForceVectorAlignment(
          td.storage().nbytes(), sizeof(uint8_t), FLATBUFFERS_MAX_ALIGNMENT);
      auto storage_offset = mobile::serialization::CreateStorageData(
          fbb,
          fbb.CreateVector(
              reinterpret_cast<const uint8_t*>(td.storage().data()),
              td.storage().nbytes()));
      storage_data.push_back(storage_offset);
    }
    storage_data_offset = fbb.CreateVector(storage_data);
  }

  auto mod = CreateModule(
      fbb,
      /*bytecode_version=*/bytecode_version,
      extra_files_offset, /* extra_files */
      functions_offset,
      ivalue_index,
      fbb.CreateVector(ivalue_offsets_),
      static_cast<int32_t>(tensor_data_.size()),
      storage_data_offset,
      fbb.CreateVector(obj_types_offset_),
      jit_source_offset,
      fbb.CreateVector(jit_constants_indexes),
      operator_version,
      mobile_ivalue_size);
  FinishModuleBuffer(fbb, mod);
  return fbb.Release();
}
455 
456 flatbuffers::Offset<mobile::serialization::Tuple> FlatbufferSerializer::
tupleToFB(flatbuffers::FlatBufferBuilder & fbb,const IValue & tuple)457     tupleToFB(flatbuffers::FlatBufferBuilder& fbb, const IValue& tuple) {
458   const auto& elements = tuple.toTuple()->elements();
459   std::vector<uint32_t> items =
460       storeIValuesAndGetIndexes(fbb, elements.begin(), elements.end());
461   return CreateTupleDirect(fbb, &items);
462 }
463 
listToFB(flatbuffers::FlatBufferBuilder & fbb,const IValue & list)464 flatbuffers::Offset<mobile::serialization::List> FlatbufferSerializer::listToFB(
465     flatbuffers::FlatBufferBuilder& fbb,
466     const IValue& list) {
467   const auto& elements = list.toList();
468   std::vector<uint32_t> items =
469       storeIValuesAndGetIndexes(fbb, elements.begin(), elements.end());
470   return CreateList(
471       fbb,
472       fbb.CreateVector(items),
473       fbb.CreateSharedString(
474           realType(list.type<c10::Type>())->annotation_str(print_type)));
475 }
476 
dictToFB(flatbuffers::FlatBufferBuilder & fbb,const IValue & ivalue)477 flatbuffers::Offset<mobile::serialization::Dict> FlatbufferSerializer::dictToFB(
478     flatbuffers::FlatBufferBuilder& fbb,
479     const IValue& ivalue) {
480   const auto& dict = ivalue.toGenericDict();
481   std::vector<uint32_t> keys;
482   std::vector<uint32_t> values;
483   keys.reserve(dict.size());
484   values.reserve(dict.size());
485   for (const auto& entry : dict) {
486     auto key_index = storeIValueAndGetIndex(fbb, entry.key());
487     keys.push_back(key_index);
488     auto value_index = storeIValueAndGetIndex(fbb, entry.value());
489     values.push_back(value_index);
490   }
491 
492   return CreateDict(
493       fbb,
494       fbb.CreateVector(keys),
495       fbb.CreateVector(values),
496       fbb.CreateSharedString(
497           realType(ivalue.type<c10::Type>())->annotation_str(print_type)));
498 }
499 
500 flatbuffers::Offset<mobile::serialization::ObjectType> FlatbufferSerializer::
classTypeToFB(FlatBufferBuilder & fbb,const ClassTypePtr & class_ptr)501     classTypeToFB(FlatBufferBuilder& fbb, const ClassTypePtr& class_ptr) {
502   mobile::serialization::TypeType typetype =
503       mobile::serialization::TypeType::UNSET;
504 
505   flatbuffers::Offset<
506       flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
507       names_offset = 0;
508   c10::QualifiedName setstate_name(*class_ptr->name(), "__setstate__");
509   c10::QualifiedName getstate_name(*class_ptr->name(), "__getstate__");
510   const mobile::Function* setstate = mcu_->find_function(setstate_name);
511   const mobile::Function* getstate = mcu_->find_function(getstate_name);
512   if (setstate != nullptr && getstate != nullptr) {
513     typetype = mobile::serialization::TypeType::CLASS_WITH_SETSTATE;
514   } else if (
515       class_ptr->findMethod("__setstate__") &&
516       class_ptr->findMethod("__getstate__")) {
517     typetype = mobile::serialization::TypeType::CUSTOM_CLASS;
518   } else {
519     size_t num_attr = class_ptr->numAttributes();
520     std::vector<flatbuffers::Offset<flatbuffers::String>> names;
521     std::vector<uint32_t> type_index;
522     for (size_t i = 0; i < num_attr; ++i) {
523       names.push_back(fbb.CreateSharedString(class_ptr->getAttributeName(i)));
524     }
525     names_offset = fbb.CreateVector(names);
526     typetype = mobile::serialization::TypeType::CLASS_WITH_FIELD;
527   }
528 
529   auto name_offset = fbb.CreateString(class_ptr->name()->qualifiedName());
530   return CreateObjectType(fbb, name_offset, typetype, names_offset);
531 }
532 
storeFunctionAndGetIndex(flatbuffers::FlatBufferBuilder & fbb,const std::string & qn,const mobile::Function & function)533 uint32_t FlatbufferSerializer::storeFunctionAndGetIndex(
534     flatbuffers::FlatBufferBuilder& fbb,
535     const std::string& qn,
536     const mobile::Function& function) {
537   auto iter = qn_to_serialized_values_.find(qn);
538   if (iter != qn_to_serialized_values_.end()) {
539     return iter->second;
540   }
541 
542   auto offset = CreateIValue(
543       fbb,
544       mobile::serialization::IValueUnion::Function,
545       functionToFB(fbb, qn, function).Union());
546 
547   uint32_t index = insertIValue(offset);
548   qn_to_serialized_values_[qn] = index;
549   return index;
550 }
551 
storeClassTypeAndGetIndex(FlatBufferBuilder & fbb,const ClassTypePtr & class_ptr)552 uint32_t FlatbufferSerializer::storeClassTypeAndGetIndex(
553     FlatBufferBuilder& fbb,
554     const ClassTypePtr& class_ptr) {
555   const auto& type_str = class_ptr->name()->qualifiedName();
556   auto iter = qn_to_serialized_values_.find(type_str);
557   if (iter != qn_to_serialized_values_.end()) {
558     return iter->second;
559   }
560 
561   auto offset = classTypeToFB(fbb, class_ptr);
562   uint32_t res = obj_types_offset_.size();
563   obj_types_offset_.push_back(offset);
564   qn_to_serialized_values_[type_str] = res;
565   return res;
566 }
567 
// Serializes an object ivalue. Objects with __getstate__/__setstate__ are
// stored as their __getstate__ result plus a reference to the serialized
// __setstate__ function; plain objects are stored as a vector of attribute
// ivalue indexes (positional, matching the ObjectType attribute names).
flatbuffers::Offset<mobile::serialization::Object> FlatbufferSerializer::
    objectToFB(flatbuffers::FlatBufferBuilder& fbb, const IValue& ivalue) {
  auto obj = ivalue.toObject();
  auto type = obj->type();
  // rename type?
  // check getstate

  // save state as ivalue
  flatbuffers::Offset<flatbuffers::Vector<uint32_t>> attrs = 0;
  uint32_t state_index = 0;
  uint32_t setstate_func_index = 0;
  const auto qn = type->name()->qualifiedName() + ".__setstate__";
  auto getstate = type->findMethod("__getstate__");
  auto setstate = type->findMethod("__setstate__");
  if (getstate && setstate) {
    // Run __getstate__ now and serialize its result as the object's state.
    auto state = (*getstate)({obj});
    state_index = storeIValueAndGetIndex(fbb, state);
    // If __setstate__ was already serialized, record its index; otherwise
    // it stays 0 — presumably resolved another way at load time (confirm
    // with the deserializer).
    auto func_index = qn_to_serialized_values_.find(qn);
    if (func_index != qn_to_serialized_values_.end()) {
      setstate_func_index = func_index->second;
    }
  } else {
    // No state methods: serialize every attribute slot in order.
    size_t num_attr = type->numAttributes();
    std::vector<uint32_t> tuple_index;
    for (size_t i = 0; i < num_attr; ++i) {
      tuple_index.push_back(storeIValueAndGetIndex(fbb, obj->getSlot(i)));
    }
    attrs = fbb.CreateVector(tuple_index);
  }

  uint32_t type_index = storeClassTypeAndGetIndex(fbb, type);
  return CreateObject(fbb, type_index, state_index, attrs, setstate_func_index);
}
601 
// Serializes tensor metadata (plus, for quantized tensors, the
// quantization schema). The storage bytes themselves are NOT written here:
// the tensor is appended to tensor_data_ and only its storage index is
// recorded, deduplicated by StorageImpl pointer so views share an entry.
// NOTE(review): the repeated `FlatbufferSerializer::` qualifier below is
// redundant but legal (injected class name).
flatbuffers::Offset<mobile::serialization::TensorMetadata> FlatbufferSerializer::
    FlatbufferSerializer::tensorToFB(
        flatbuffers::FlatBufferBuilder& fbb,
        const IValue& ivalue) {
  auto& tensor = ivalue.toTensor();
  bool quantized = tensor.is_quantized();
  const at::Storage& storage = tensor.storage();

  flatbuffers::Offset<mobile::serialization::QuantizedSchema> qschema_offset =
      0;
  if (quantized) {
    double scale = 0;
    int64_t zero_point = 0;
    flatbuffers::Offset<mobile::serialization::TensorMetadata> scales = 0;
    flatbuffers::Offset<mobile::serialization::TensorMetadata> zero_points = 0;
    int64_t axis = 0;

    switch (tensor.qscheme()) {
      case at::kPerTensorAffine:
        // Per-tensor: a single scale / zero point pair.
        scale = tensor.q_scale();
        zero_point = tensor.q_zero_point();
        break;
      case at::kPerChannelAffineFloatQParams:
      case at::kPerChannelAffine: {
        // Per-channel: scales and zero points are themselves tensors,
        // serialized recursively.
        scales = tensorToFB(fbb, tensor.q_per_channel_scales());
        zero_points = tensorToFB(fbb, tensor.q_per_channel_zero_points());
        axis = tensor.q_per_channel_axis();
      } break;
      default:
        TORCH_CHECK(
            false,
            "Unsupported tensor quantization type in serialization ",
            toString(tensor.qscheme()));
        break;
    }

    qschema_offset = mobile::serialization::CreateQuantizedSchema(
        fbb,
        static_cast<int8_t>(tensor.qscheme()),
        scale,
        static_cast<int32_t>(zero_point),
        scales,
        zero_points,
        static_cast<int32_t>(axis));
  }

  // Deduplicate storages by StorageImpl address.
  void* addr = storage.unsafeGetStorageImpl();
  uint32_t storage_index = 0;
  auto it = memoized_storage_map_.find(addr);
  if (it != memoized_storage_map_.end()) {
    storage_index = it->second;
  } else {
    storage_index = tensor_data_.size();
    memoized_storage_map_[addr] = storage_index;
    tensor_data_.push_back(tensor);
  }

  // Sizes/strides are narrowed from int64_t to int for the schema.
  std::vector<int> sizes{tensor.sizes().begin(), tensor.sizes().end()};
  std::vector<int> strides{tensor.strides().begin(), tensor.strides().end()};

  return CreateTensorMetadataDirect(
      fbb,
      /* storage_location_index */ storage_index,
      /* scalar_type */ static_cast<int8_t>(tensor.scalar_type()),
      /* int32_t storage_offset */
      static_cast<int32_t>(tensor.storage_offset()),
      /* sizes */ &sizes,
      /* strides */ &strides,
      /* bool requires_grad */ tensor.requires_grad(),
      /* qschema */ qschema_offset);
}
673 
// Returns the ivalue-table index for `ivalue`, serializing it on first
// sight. None always maps to kNoneIndex; other hashable values are
// memoized in cached_ivalues_ so equal values share one table slot.
uint32_t FlatbufferSerializer::storeIValueAndGetIndex(
    flatbuffers::FlatBufferBuilder& fbb,
    const IValue& ivalue) {
  if (ivalue.isNone()) {
    return kNoneIndex;
  }

  try {
    auto iter = cached_ivalues_.find(ivalue);
    if (iter != cached_ivalues_.end()) {
      return iter->second;
    }
    // NOLINTNEXTLINE(bugprone-empty-catch)
  } catch (...) {
    // Thrown when the ivalue is not hashable or lacks a usable operator==;
    // either way we just skip the cache lookup and serialize it fresh.
  }

  auto offset = iValueToFB(fbb, ivalue);
  uint32_t index = insertIValue(offset);
  try {
    cached_ivalues_[ivalue] = index;
    // NOLINTNEXTLINE(bugprone-empty-catch)
  } catch (...) {
    // Same as above: values that cannot be hashed/compared simply are not
    // cached; the freshly assigned index is still returned.
  }

  return index;
}
706 
// Dispatches on the ivalue's runtime tag and builds the matching flatbuffer
// union member. Scalars become inline structs; containers recurse through
// the to*FB converters. The typed-list checks (isIntList/isDoubleList/
// isBoolList) come before the generic isList() check so typed lists get
// their compact representations.
flatbuffers::Offset<mobile::serialization::IValue> FlatbufferSerializer::
    iValueToFB(flatbuffers::FlatBufferBuilder& fbb, const IValue& ivalue) {
  using mobile::serialization::IValueUnion;

  IValueUnion ivalue_type = IValueUnion::NONE;
  flatbuffers::Offset<void> offset = 0;

  if (ivalue.isTensor()) {
    ivalue_type = IValueUnion::TensorMetadata;
    offset = tensorToFB(fbb, ivalue).Union();
  } else if (ivalue.isTuple()) {
    ivalue_type = IValueUnion::Tuple;
    offset = tupleToFB(fbb, ivalue).Union();
  } else if (ivalue.isDouble()) {
    ivalue_type = IValueUnion::Double;
    offset = fbb.CreateStruct(mobile::serialization::Double(ivalue.toDouble()))
                 .Union();
  } else if (ivalue.isComplexDouble()) {
    auto comp = ivalue.toComplexDouble();
    ivalue_type = IValueUnion::ComplexDouble;
    offset = fbb.CreateStruct(mobile::serialization::ComplexDouble(
                                  comp.real(), comp.imag()))
                 .Union();
  } else if (ivalue.isInt()) {
    ivalue_type = IValueUnion::Int;
    offset =
        fbb.CreateStruct(mobile::serialization::Int(ivalue.toInt())).Union();
  } else if (ivalue.isBool()) {
    ivalue_type = IValueUnion::Bool;
    offset =
        fbb.CreateStruct(mobile::serialization::Bool(ivalue.toBool())).Union();
  } else if (ivalue.isString()) {
    ivalue_type = IValueUnion::String;
    offset = mobile::serialization::CreateString(
                 fbb, fbb.CreateSharedString(ivalue.toStringRef()))
                 .Union();
  } else if (ivalue.isGenericDict()) {
    ivalue_type = IValueUnion::Dict;
    offset = dictToFB(fbb, ivalue).Union();
  } else if (ivalue.isNone()) {
    // None is normally short-circuited to kNoneIndex by the caller; this
    // branch covers the sentinel entry created in serializeModule().
    ivalue_type = IValueUnion::NONE;
    offset = 0;
  } else if (ivalue.isIntList()) {
    ivalue_type = IValueUnion::IntList;
    offset = mobile::serialization::CreateIntList(
                 fbb, fbb.CreateVector(ivalue.toIntVector()))
                 .Union();
  } else if (ivalue.isDoubleList()) {
    ivalue_type = IValueUnion::DoubleList;
    offset = mobile::serialization::CreateDoubleList(
                 fbb, fbb.CreateVector(ivalue.toDoubleVector()))
                 .Union();
  } else if (ivalue.isBoolList()) {
    ivalue_type = IValueUnion::BoolList;
    auto boollist = ivalue.toBoolList();
    // Widen bools to uint8_t for the flatbuffer vector.
    std::vector<uint8_t> bool_vec(boollist.begin(), boollist.end());
    offset =
        mobile::serialization::CreateBoolListDirect(fbb, &bool_vec).Union();
  } else if (ivalue.isList()) {
    ivalue_type = IValueUnion::List;
    offset = listToFB(fbb, ivalue).Union();
  } else if (ivalue.isObject()) {
    ivalue_type = IValueUnion::Object;
    offset = objectToFB(fbb, ivalue).Union();
  } else if (ivalue.isDevice()) {
    ivalue_type = IValueUnion::Device;
    offset = mobile::serialization::CreateDevice(
                 fbb, fbb.CreateSharedString(ivalue.toDevice().str()))
                 .Union();
  } else if (ivalue.isEnum()) {
    // Enums are stored as the class name plus the index of their value.
    const auto& enum_holder = ivalue.toEnumHolder();
    const auto& qualified_class_name =
        enum_holder->type()->qualifiedClassName();
    uint32_t ival_pos = storeIValueAndGetIndex(fbb, enum_holder->value());
    ivalue_type = IValueUnion::EnumValue;
    offset = mobile::serialization::CreateEnumValue(
                 fbb,
                 fbb.CreateSharedString(qualified_class_name.qualifiedName()),
                 ival_pos)
                 .Union();
  } else {
    AT_ERROR("Invalid IValue type for serialization: ", ivalue.tagKind());
  }
  return CreateIValue(fbb, ivalue_type, offset);
}
792 
793 } // namespace
794 
save_mobile_module(const mobile::Module & module,const std::string & filename,const ExtraFilesMap & extra_files,const ExtraFilesMap & jit_sources,const std::vector<IValue> & jit_constants)795 void save_mobile_module(
796     const mobile::Module& module,
797     const std::string& filename,
798     const ExtraFilesMap& extra_files,
799     const ExtraFilesMap& jit_sources,
800     const std::vector<IValue>& jit_constants) {
801   auto buffer = save_mobile_module_to_bytes(
802       module, extra_files, jit_sources, jit_constants);
803   std::fstream ofile(filename, std::ios::binary | std::ios::out);
804   ofile.write(
805       reinterpret_cast<char*>(buffer->data()),
806       static_cast<std::streamsize>(buffer->size()));
807   ofile.close();
808 }
809 
810 /// Deletes a DetachedBuffer, along with the internal
811 /// flatbuffers::DetachedBuffer if present. Used as a custom deleter for
812 /// std::unique_ptr; see UniqueDetachedBuffer and make_unique_detached_buffer.
destroy(DetachedBuffer * buf)813 void DetachedBuffer::destroy(DetachedBuffer* buf) {
814   // May be null.
815   delete static_cast<flatbuffers::DetachedBuffer*>(buf->data_owner_);
816   delete buf;
817 }
818 
/// Provides access to DetachedBuffer::destroy().
/// NOTE(review): presumably DetachedBuffer declares this struct as a
/// friend so the destroy() hook stays private elsewhere — confirm against
/// the header.
struct DetachedBufferFriend {
  /// Returns a UniqueDetachedBuffer that wraps the provided DetachedBuffer.
  static DetachedBuffer::UniqueDetachedBuffer make_unique_detached_buffer(
      DetachedBuffer* buf) {
    return DetachedBuffer::UniqueDetachedBuffer(buf, DetachedBuffer::destroy);
  }
};
827 
save_mobile_module_to_bytes(const mobile::Module & module,const ExtraFilesMap & extra_files,const ExtraFilesMap & jit_sources,const std::vector<IValue> & jit_constants)828 DetachedBuffer::UniqueDetachedBuffer save_mobile_module_to_bytes(
829     const mobile::Module& module,
830     const ExtraFilesMap& extra_files,
831     const ExtraFilesMap& jit_sources,
832     const std::vector<IValue>& jit_constants) {
833   FlatbufferSerializer fb_serializer;
834   flatbuffers::DetachedBuffer buf = fb_serializer.serializeModule(
835       module,
836       /*include_tensor_data_in_flatbuffer=*/true,
837       extra_files,
838       jit_sources,
839       jit_constants);
840   flatbuffers::DetachedBuffer* buf_ptr =
841       new flatbuffers::DetachedBuffer(std::move(buf));
842   DetachedBuffer* ret =
843       new DetachedBuffer(buf_ptr->data(), buf_ptr->size(), buf_ptr);
844   return DetachedBufferFriend::make_unique_detached_buffer(ret);
845 }
846 
save_mobile_module_to_func(const mobile::Module & module,const std::function<size_t (const void *,size_t)> & writer_func)847 void save_mobile_module_to_func(
848     const mobile::Module& module,
849     const std::function<size_t(const void*, size_t)>& writer_func) {
850   auto buffer = save_mobile_module_to_bytes(module);
851   writer_func(buffer->data(), buffer->size());
852 }
853 
// Always returns true.
// NOTE(review): the body does no actual registration — presumably callers
// take a dependency on this symbol to force this translation unit to be
// linked in; confirm with call sites.
bool register_flatbuffer_serializer() {
  return true;
}
857 
858 } // namespace torch::jit
859