1 #include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
2
3 #ifdef FLATBUFFERS_VERSION_MAJOR
4 #error "flatbuffer_serializer.h must not include any flatbuffers headers"
5 #endif // FLATBUFFERS_VERSION_MAJOR
6
7 #include <fstream>
8 #include <functional>
9 #include <string>
10 #include <unordered_map>
11 #include <utility>
12 #include <vector>
13
14 #include <ATen/ATen.h>
15 #include <c10/core/CPUAllocator.h>
16 #include <c10/util/Exception.h>
17 #include <caffe2/serialize/versions.h>
18 #include <torch/csrc/jit/mobile/code.h>
19 #include <torch/csrc/jit/mobile/train/export_data.h>
20 #include <torch/csrc/jit/passes/inliner.h>
21 #include <torch/csrc/jit/runtime/instruction.h>
22
23 #if defined(FB_XPLAT_BUILD) || defined(FBCODE_CAFFE2)
24 #include <torch/csrc/jit/serialization/mobile_bytecode_generated_fbsource.h> // NOLINT
25 namespace flatbuffers = flatbuffers_fbsource;
26 #define FLATBUFFERS_MAX_ALIGNMENT FLATBUFFERS_FBSOURCE_MAX_ALIGNMENT
27 #else
28 #include <torch/csrc/jit/serialization/mobile_bytecode_generated.h> // NOLINT
29 #endif
30
31 namespace torch::jit {
32
33 using flatbuffers::FlatBufferBuilder;
34 using mobile::serialization::CreateArg;
35 using mobile::serialization::CreateDebugInfo;
36 using mobile::serialization::CreateDict;
37 using mobile::serialization::CreateFunctionDirect;
38 using mobile::serialization::CreateIValue;
39 using mobile::serialization::CreateList;
40 using mobile::serialization::CreateModule;
41 using mobile::serialization::CreateObject;
42 using mobile::serialization::CreateOperator;
43 using mobile::serialization::CreateTensorMetadataDirect;
44 using mobile::serialization::CreateTupleDirect;
45
46 namespace {
47
// TODO: remove once caffe2::kProducedBytecodeVersion is >= 9 and flatbuffer is
// launched.
// Minimum bytecode version stamped into serialized modules; serializeModule()
// clamps the module's own version up to this.
constexpr uint32_t kMinVersion = 9;

// We will store IValue NONE in index 0 in flatbuffer.
constexpr int kNoneIndex = 0;
54
realType(TypePtr type)55 static TypePtr realType(TypePtr type) {
56 if (auto dyn = type->castRaw<c10::DynamicType>()) {
57 return dyn->fallback();
58 } else {
59 return type;
60 }
61 }
62
print_type(const c10::Type & t)63 auto print_type(const c10::Type& t) -> std::optional<std::string> {
64 auto namedType = t.cast<c10::NamedType>();
65 if (namedType && namedType->name()) {
66 return namedType->name().value().qualifiedName();
67 }
68 if (auto dyn = t.castRaw<c10::DynamicType>()) {
69 return dyn->fallback()->annotation_str();
70 }
71 return std::nullopt;
72 }
73
// Stateful helper that walks a mobile::Module and writes it into a single
// flatbuffer. One instance per serializeModule() call: the caches below
// accumulate state across the whole run and are never reset.
class FlatbufferSerializer {
 public:
  FlatbufferSerializer() = default;

  // Serializes `module` (plus optional extra files, JIT sources and JIT
  // constants) and returns the finished flatbuffer as a detached buffer.
  flatbuffers::DetachedBuffer serializeModule(
      const mobile::Module& module,
      bool include_tensor_data_in_flatbuffer,
      const ExtraFilesMap& extra_files = ExtraFilesMap(),
      const ExtraFilesMap& jit_sources = ExtraFilesMap(),
      const std::vector<IValue>& jit_constants = {});

 private:
  // Interns every IValue in [begin, end) and returns their indexes in the
  // module-wide ivalue table, in iteration order.
  template <typename It>
  std::vector<uint32_t> storeIValuesAndGetIndexes(
      flatbuffers::FlatBufferBuilder& fbb,
      It begin,
      It end) {
    std::vector<uint32_t> indexes;
    for (; begin != end; ++begin) {
      indexes.push_back(storeIValueAndGetIndex(fbb, *begin));
    }
    return indexes;
  }

  // Per-kind IValue -> flatbuffer table converters (dispatched from
  // iValueToFB below).
  flatbuffers::Offset<mobile::serialization::Tuple> tupleToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& tuple);

  flatbuffers::Offset<mobile::serialization::List> listToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& list);

  flatbuffers::Offset<mobile::serialization::Dict> dictToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& list);

  flatbuffers::Offset<mobile::serialization::Object> objectToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& ivalue);

  flatbuffers::Offset<mobile::serialization::TensorMetadata> tensorToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& ivalue);

  flatbuffers::Offset<mobile::serialization::Function> functionToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const std::string& qn,
      const mobile::Function& func);

  // Tag-dispatches an arbitrary IValue to the converter functions above.
  flatbuffers::Offset<mobile::serialization::IValue> iValueToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& ivalue);

  flatbuffers::Offset<jit::mobile::serialization::Schema> CreateFBSchema(
      flatbuffers::FlatBufferBuilder& fbb,
      const std::vector<Argument>& args,
      const std::vector<Argument>& returns,
      const c10::TypePrinter& type_printer);

  flatbuffers::Offset<mobile::serialization::ObjectType> classTypeToFB(
      flatbuffers::FlatBufferBuilder& fbb,
      const ClassTypePtr& class_ptr);

  // The store*AndGetIndex helpers intern their argument (deduplicating via
  // the caches below) and return its index in the corresponding table.
  uint32_t storeIValueAndGetIndex(
      flatbuffers::FlatBufferBuilder& fbb,
      const IValue& ivalue);
  uint32_t storeFunctionAndGetIndex(
      flatbuffers::FlatBufferBuilder& fbb,
      const std::string& qn,
      const mobile::Function& function);

  uint32_t storeClassTypeAndGetIndex(
      flatbuffers::FlatBufferBuilder& fbb,
      const ClassTypePtr& class_type);

  flatbuffers::Offset<flatbuffers::Vector<
      flatbuffers::Offset<mobile::serialization::ExtraFile>>>
  storeExtraFilesAndGetOffset(
      FlatBufferBuilder& fbb,
      const ExtraFilesMap& extra_files);

  // Appends an already-built IValue offset to the table and returns its
  // index (i.e. the table size before the push).
  uint32_t insertIValue(
      flatbuffers::Offset<mobile::serialization::IValue> ivalue) {
    uint32_t size = ivalue_offsets_.size();
    ivalue_offsets_.push_back(ivalue);
    return size;
  }

  // Tensors whose storage bytes get embedded at the very end of
  // serialization (see serializeModule).
  std::vector<at::Tensor> tensor_data_;

  // StorageImpl pointer -> index into tensor_data_, so tensors sharing
  // storage are written only once.
  std::unordered_map<const void*, uint32_t> memoized_storage_map_;

  std::vector<flatbuffers::Offset<mobile::serialization::IValue>>
      ivalue_offsets_;
  std::vector<flatbuffers::Offset<mobile::serialization::ObjectType>>
      obj_types_offset_;

  // qualified name to serialized class, type or function
  std::unordered_map<std::string, uint32_t> qn_to_serialized_values_;

  // cache of some ivalues
  struct IValueHash {
    size_t operator()(const IValue& val) const {
      return IValue::hash(val);
    }
  };

  struct IValueEqual {
    // Copy of this
    // https://www.internalfb.com/code/aros/[3b875bce7ffa2adacdcea9b3e0cb6d304737a193]/xros/third-party/caffe2/caffe2/aten/src/ATen/core/ivalue.cpp?lines=266
    // but without relying on aten::nonzero operator being present in the
    // binary.
    bool operator()(const IValue& lhs, const IValue& rhs) const {
      // The only case we don't return bool is for tensor comparison. Let's do
      // pointer comparison here.
      if (lhs.isTensor() || rhs.isTensor()) {
        if (lhs.isTensor() && rhs.isTensor()) {
          return (&lhs.toTensor()) == (&rhs.toTensor());
        }
        return false;
      }
      IValue eq = lhs.equals(rhs);
      if (eq.isBool()) {
        return eq.toBool();
      }
      return false;
    }
  };

  std::unordered_map<IValue, uint32_t, IValueHash, IValueEqual> cached_ivalues_;
  // Compilation unit of the module being serialized; set in serializeModule,
  // read by classTypeToFB to look up __setstate__/__getstate__.
  const mobile::CompilationUnit* mcu_ = nullptr;
};
206
207 flatbuffers::Offset<jit::mobile::serialization::Schema> FlatbufferSerializer::
CreateFBSchema(flatbuffers::FlatBufferBuilder & fbb,const std::vector<Argument> & args,const std::vector<Argument> & returns,const c10::TypePrinter & type_printer)208 CreateFBSchema(
209 flatbuffers::FlatBufferBuilder& fbb,
210 const std::vector<Argument>& args,
211 const std::vector<Argument>& returns,
212 const c10::TypePrinter& type_printer) {
213 std::vector<flatbuffers::Offset<jit::mobile::serialization::Arg>> arg_vec;
214 arg_vec.reserve(args.size());
215 std::vector<flatbuffers::Offset<jit::mobile::serialization::Arg>> return_vec;
216 return_vec.reserve(returns.size());
217 for (const auto& arg : args) {
218 auto index = storeIValueAndGetIndex(fbb, arg.default_value());
219 arg_vec.emplace_back(CreateArg(
220 fbb,
221 fbb.CreateSharedString(arg.name()),
222 fbb.CreateSharedString(
223 realType(arg.type())->annotation_str(type_printer)),
224 index));
225 }
226
227 for (const auto& ret : returns) {
228 auto index = storeIValueAndGetIndex(fbb, ret.default_value());
229 return_vec.emplace_back(CreateArg(
230 fbb,
231 fbb.CreateSharedString(ret.name()),
232 fbb.CreateSharedString(
233 realType(ret.type())->annotation_str(type_printer)),
234 index));
235 }
236 return CreateSchema(
237 fbb, fbb.CreateVector(arg_vec), fbb.CreateVector(return_vec));
238 }
239
240 flatbuffers::Offset<mobile::serialization::Function> FlatbufferSerializer::
functionToFB(FlatBufferBuilder & fbb,const std::string & qn,const mobile::Function & func)241 functionToFB(
242 FlatBufferBuilder& fbb,
243 const std::string& qn,
244 const mobile::Function& func) {
245 const auto& code = func.get_code();
246
247 // instructions
248 std::vector<mobile::serialization::Instruction> instruction_vector;
249 instruction_vector.reserve(code.instructions_.size());
250 for (const auto& inst : code.instructions_) {
251 instruction_vector.emplace_back(inst.op, inst.N, inst.X);
252 }
253
254 // operators
255 std::vector<flatbuffers::Offset<mobile::serialization::Operator>>
256 operator_vector;
257 operator_vector.reserve(code.op_names_.size());
258 for (const auto i : c10::irange(code.op_names_.size())) {
259 const auto& opname = code.op_names_[i];
260 const int op_size = code.operator_input_sizes_[i];
261 operator_vector.push_back(CreateOperator(
262 fbb,
263 fbb.CreateSharedString(opname.name),
264 fbb.CreateSharedString(opname.overload_name),
265 op_size));
266 }
267
268 const auto& constants = code.constants_;
269
270 std::vector<uint32_t> constant_indexes;
271 constant_indexes.reserve(constants.size());
272 for (const auto& constant : constants) {
273 constant_indexes.push_back(storeIValueAndGetIndex(fbb, constant));
274 }
275
276 // types
277 static const std::string torch_prefix("__torch__");
278 static const std::string class_prefix("__torch__.torch.classes");
279 std::vector<flatbuffers::Offset<flatbuffers::String>> type_offsets;
280
281 for (const TypePtr& t : code.types_) {
282 auto type_str = realType(t)->annotation_str();
283 if (type_str.find(torch_prefix) == 0) {
284 TORCH_CHECK(
285 type_str.find(class_prefix) == 0,
286 "__torch__ types other than custom c++ classes (__torch__.torch.classes)"
287 "are not supported in lite interpreter. ",
288 "Workaround: instead of using arbitrary class type (class Foo()), ",
289 "define a pytorch class (class Foo(torch.nn.Module)).");
290 }
291
292 type_offsets.push_back(fbb.CreateSharedString(type_str));
293 }
294
295 // since the register location is embedded into the bytecode, pass the
296 // register size
297 auto register_size = static_cast<int>(code.register_size_);
298
299 // schema
300 auto type_printer = [&](const c10::Type& t) -> std::optional<std::string> {
301 auto namedType = t.cast<c10::NamedType>();
302 if (namedType && namedType->name()) {
303 return namedType->name().value().qualifiedName();
304 }
305 if (auto dyn = t.castRaw<c10::DynamicType>()) {
306 return dyn->fallback()->annotation_str();
307 }
308 return std::nullopt;
309 };
310
311 flatbuffers::Offset<mobile::serialization::Schema> schema_offset = 0;
312 uint32_t class_index = 0;
313 if (func.hasSchema()) {
314 const auto& schema = func.getSchema();
315 TORCH_CHECK(
316 schema.overload_name().empty(), // @TODO: is this check correct?
317 "Overloads are not supported in mobile modules.");
318 TORCH_CHECK(
319 !schema.is_vararg(),
320 "Python *args are not supported in mobile modules.");
321 TORCH_CHECK(
322 !schema.is_varret(),
323 "A variable number of return values is not supported in mobile modules.");
324 schema_offset =
325 CreateFBSchema(fbb, schema.arguments(), schema.returns(), type_printer);
326 auto classtype = schema.arguments()[0].type()->cast<ClassType>();
327 class_index = storeClassTypeAndGetIndex(fbb, classtype);
328 }
329
330 auto debug_info_offset =
331 CreateDebugInfo(fbb, fbb.CreateVector(code.debug_handles_));
332
333 auto function_offset = CreateFunctionDirect(
334 fbb,
335 qn.c_str(),
336 &instruction_vector,
337 &operator_vector,
338 &constant_indexes,
339 &type_offsets,
340 register_size,
341 schema_offset,
342 debug_info_offset,
343 class_index);
344 return function_offset;
345 }
346
347 flatbuffers::Offset<
348 flatbuffers::Vector<flatbuffers::Offset<mobile::serialization::ExtraFile>>>
storeExtraFilesAndGetOffset(FlatBufferBuilder & fbb,const ExtraFilesMap & extra_files)349 FlatbufferSerializer::storeExtraFilesAndGetOffset(
350 FlatBufferBuilder& fbb,
351 const ExtraFilesMap& extra_files) {
352 std::vector<flatbuffers::Offset<mobile::serialization::ExtraFile>>
353 extra_file_offsets;
354
355 for (const auto& extra_file : extra_files) {
356 flatbuffers::Offset<mobile::serialization::ExtraFile> extra_file_offset =
357 mobile::serialization::CreateExtraFile(
358 fbb,
359 fbb.CreateSharedString(extra_file.first),
360 fbb.CreateString(extra_file.second));
361 extra_file_offsets.emplace_back(extra_file_offset);
362 }
363 return fbb.CreateVector(extra_file_offsets);
364 }
365
// Top-level entry point: serializes the module's methods, object state,
// extra files, JIT sources/constants and (optionally) raw tensor storage
// into a single flatbuffer, then detaches and returns the buffer.
flatbuffers::DetachedBuffer FlatbufferSerializer::serializeModule(
    const mobile::Module& module,
    bool include_tensor_data_in_flatbuffer,
    const ExtraFilesMap& extra_files,
    const ExtraFilesMap& jit_sources,
    const std::vector<IValue>& jit_constants) {
  FlatBufferBuilder fbb;

  mcu_ = &module.compilation_unit();

  // first element is None.
  insertIValue(CreateIValue(fbb, mobile::serialization::IValueUnion::NONE, 0));

  // Serialize every method; each yields an index into the ivalue table.
  auto methods = module.get_methods();
  std::vector<uint32_t> functions_index;
  functions_index.reserve(methods.size());
  for (const auto& method : methods) {
    auto func_offset = storeFunctionAndGetIndex(
        fbb, method.function().qualname().qualifiedName(), method.function());
    functions_index.push_back(func_offset);
  }

  auto functions_offset = fbb.CreateVector(functions_index);
  // Serialize the module object itself (its attribute tree).
  uint32_t ivalue_index = storeIValueAndGetIndex(fbb, module._ivalue());

  flatbuffers::Offset<flatbuffers::Vector<
      flatbuffers::Offset<mobile::serialization::StorageData>>>
      storage_data_offset = 0;
  auto extra_files_offset = storeExtraFilesAndGetOffset(fbb, extra_files);

  auto jit_source_offset = storeExtraFilesAndGetOffset(fbb, jit_sources);
  std::vector<uint32_t> jit_constants_indexes;
  jit_constants_indexes.reserve(jit_constants.size());
  // Ivalues stored so far belong to the mobile module proper; JIT constants
  // added below are extra, so record the boundary first.
  const uint32_t mobile_ivalue_size = ivalue_offsets_.size();
  for (const auto& ival : jit_constants) {
    jit_constants_indexes.emplace_back(storeIValueAndGetIndex(fbb, ival));
  }
  const uint32_t operator_version =
      static_cast<uint32_t>(module.min_operator_version());
  // Clamp the bytecode version up to the minimum supported by flatbuffer.
  uint32_t bytecode_version = static_cast<uint32_t>(module.bytecode_version());
  if (bytecode_version < kMinVersion) {
    bytecode_version = kMinVersion;
  }

  // NOTE: saving of storage has to be the last thing to do.
  if (include_tensor_data_in_flatbuffer) {
    std::vector<flatbuffers::Offset<mobile::serialization::StorageData>>
        storage_data;
    for (auto td : tensor_data_) {
      if (td.storage().device_type() != DeviceType::CPU) {
        // Non-CPU storage: view the whole storage as a 1-D tensor and copy
        // it to CPU so the raw bytes can be embedded.
        td = at::empty({0}, td.options())
                 .set_(
                     td.storage(),
                     /* storage_offset = */ 0,
                     /* size = */
                     {static_cast<int64_t>(
                         td.storage().nbytes() / td.element_size())},
                     /* stride = */ {1})
                 .cpu();
      }
      // Align the byte vector so tensors can be used in place after load.
      fbb.ForceVectorAlignment(
          td.storage().nbytes(), sizeof(uint8_t), FLATBUFFERS_MAX_ALIGNMENT);
      auto storage_offset = mobile::serialization::CreateStorageData(
          fbb,
          fbb.CreateVector(
              reinterpret_cast<const uint8_t*>(td.storage().data()),
              td.storage().nbytes()));
      storage_data.push_back(storage_offset);
    }
    storage_data_offset = fbb.CreateVector(storage_data);
  }

  auto mod = CreateModule(
      fbb,
      /*bytecode_version=*/bytecode_version,
      extra_files_offset, /* extra_files */
      functions_offset,
      ivalue_index,
      fbb.CreateVector(ivalue_offsets_),
      static_cast<int32_t>(tensor_data_.size()),
      storage_data_offset,
      fbb.CreateVector(obj_types_offset_),
      jit_source_offset,
      fbb.CreateVector(jit_constants_indexes),
      operator_version,
      mobile_ivalue_size);
  FinishModuleBuffer(fbb, mod);
  return fbb.Release();
}
455
456 flatbuffers::Offset<mobile::serialization::Tuple> FlatbufferSerializer::
tupleToFB(flatbuffers::FlatBufferBuilder & fbb,const IValue & tuple)457 tupleToFB(flatbuffers::FlatBufferBuilder& fbb, const IValue& tuple) {
458 const auto& elements = tuple.toTuple()->elements();
459 std::vector<uint32_t> items =
460 storeIValuesAndGetIndexes(fbb, elements.begin(), elements.end());
461 return CreateTupleDirect(fbb, &items);
462 }
463
listToFB(flatbuffers::FlatBufferBuilder & fbb,const IValue & list)464 flatbuffers::Offset<mobile::serialization::List> FlatbufferSerializer::listToFB(
465 flatbuffers::FlatBufferBuilder& fbb,
466 const IValue& list) {
467 const auto& elements = list.toList();
468 std::vector<uint32_t> items =
469 storeIValuesAndGetIndexes(fbb, elements.begin(), elements.end());
470 return CreateList(
471 fbb,
472 fbb.CreateVector(items),
473 fbb.CreateSharedString(
474 realType(list.type<c10::Type>())->annotation_str(print_type)));
475 }
476
dictToFB(flatbuffers::FlatBufferBuilder & fbb,const IValue & ivalue)477 flatbuffers::Offset<mobile::serialization::Dict> FlatbufferSerializer::dictToFB(
478 flatbuffers::FlatBufferBuilder& fbb,
479 const IValue& ivalue) {
480 const auto& dict = ivalue.toGenericDict();
481 std::vector<uint32_t> keys;
482 std::vector<uint32_t> values;
483 keys.reserve(dict.size());
484 values.reserve(dict.size());
485 for (const auto& entry : dict) {
486 auto key_index = storeIValueAndGetIndex(fbb, entry.key());
487 keys.push_back(key_index);
488 auto value_index = storeIValueAndGetIndex(fbb, entry.value());
489 values.push_back(value_index);
490 }
491
492 return CreateDict(
493 fbb,
494 fbb.CreateVector(keys),
495 fbb.CreateVector(values),
496 fbb.CreateSharedString(
497 realType(ivalue.type<c10::Type>())->annotation_str(print_type)));
498 }
499
500 flatbuffers::Offset<mobile::serialization::ObjectType> FlatbufferSerializer::
classTypeToFB(FlatBufferBuilder & fbb,const ClassTypePtr & class_ptr)501 classTypeToFB(FlatBufferBuilder& fbb, const ClassTypePtr& class_ptr) {
502 mobile::serialization::TypeType typetype =
503 mobile::serialization::TypeType::UNSET;
504
505 flatbuffers::Offset<
506 flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>>
507 names_offset = 0;
508 c10::QualifiedName setstate_name(*class_ptr->name(), "__setstate__");
509 c10::QualifiedName getstate_name(*class_ptr->name(), "__getstate__");
510 const mobile::Function* setstate = mcu_->find_function(setstate_name);
511 const mobile::Function* getstate = mcu_->find_function(getstate_name);
512 if (setstate != nullptr && getstate != nullptr) {
513 typetype = mobile::serialization::TypeType::CLASS_WITH_SETSTATE;
514 } else if (
515 class_ptr->findMethod("__setstate__") &&
516 class_ptr->findMethod("__getstate__")) {
517 typetype = mobile::serialization::TypeType::CUSTOM_CLASS;
518 } else {
519 size_t num_attr = class_ptr->numAttributes();
520 std::vector<flatbuffers::Offset<flatbuffers::String>> names;
521 std::vector<uint32_t> type_index;
522 for (size_t i = 0; i < num_attr; ++i) {
523 names.push_back(fbb.CreateSharedString(class_ptr->getAttributeName(i)));
524 }
525 names_offset = fbb.CreateVector(names);
526 typetype = mobile::serialization::TypeType::CLASS_WITH_FIELD;
527 }
528
529 auto name_offset = fbb.CreateString(class_ptr->name()->qualifiedName());
530 return CreateObjectType(fbb, name_offset, typetype, names_offset);
531 }
532
storeFunctionAndGetIndex(flatbuffers::FlatBufferBuilder & fbb,const std::string & qn,const mobile::Function & function)533 uint32_t FlatbufferSerializer::storeFunctionAndGetIndex(
534 flatbuffers::FlatBufferBuilder& fbb,
535 const std::string& qn,
536 const mobile::Function& function) {
537 auto iter = qn_to_serialized_values_.find(qn);
538 if (iter != qn_to_serialized_values_.end()) {
539 return iter->second;
540 }
541
542 auto offset = CreateIValue(
543 fbb,
544 mobile::serialization::IValueUnion::Function,
545 functionToFB(fbb, qn, function).Union());
546
547 uint32_t index = insertIValue(offset);
548 qn_to_serialized_values_[qn] = index;
549 return index;
550 }
551
storeClassTypeAndGetIndex(FlatBufferBuilder & fbb,const ClassTypePtr & class_ptr)552 uint32_t FlatbufferSerializer::storeClassTypeAndGetIndex(
553 FlatBufferBuilder& fbb,
554 const ClassTypePtr& class_ptr) {
555 const auto& type_str = class_ptr->name()->qualifiedName();
556 auto iter = qn_to_serialized_values_.find(type_str);
557 if (iter != qn_to_serialized_values_.end()) {
558 return iter->second;
559 }
560
561 auto offset = classTypeToFB(fbb, class_ptr);
562 uint32_t res = obj_types_offset_.size();
563 obj_types_offset_.push_back(offset);
564 qn_to_serialized_values_[type_str] = res;
565 return res;
566 }
567
// Serializes an object instance. If the class type declares both
// __getstate__ and __setstate__, the __getstate__ result is stored (plus the
// index of the serialized __setstate__ function if one exists); otherwise
// each attribute slot is stored positionally.
flatbuffers::Offset<mobile::serialization::Object> FlatbufferSerializer::
    objectToFB(flatbuffers::FlatBufferBuilder& fbb, const IValue& ivalue) {
  auto obj = ivalue.toObject();
  auto type = obj->type();
  // rename type?
  // check getstate

  // save state as ivalue
  flatbuffers::Offset<flatbuffers::Vector<uint32_t>> attrs = 0;
  uint32_t state_index = 0;
  uint32_t setstate_func_index = 0;
  const auto qn = type->name()->qualifiedName() + ".__setstate__";
  auto getstate = type->findMethod("__getstate__");
  auto setstate = type->findMethod("__setstate__");
  if (getstate && setstate) {
    // Invoke __getstate__ now and serialize the state it returns.
    auto state = (*getstate)({obj});
    state_index = storeIValueAndGetIndex(fbb, state);
    // If a serialized __setstate__ function exists in the cache, record its
    // index; otherwise it stays 0.
    auto func_index = qn_to_serialized_values_.find(qn);
    if (func_index != qn_to_serialized_values_.end()) {
      setstate_func_index = func_index->second;
    }
  } else {
    // No setstate protocol: store each attribute slot, in declaration order.
    size_t num_attr = type->numAttributes();
    std::vector<uint32_t> tuple_index;
    for (size_t i = 0; i < num_attr; ++i) {
      tuple_index.push_back(storeIValueAndGetIndex(fbb, obj->getSlot(i)));
    }
    attrs = fbb.CreateVector(tuple_index);
  }

  uint32_t type_index = storeClassTypeAndGetIndex(fbb, type);
  return CreateObject(fbb, type_index, state_index, attrs, setstate_func_index);
}
601
// Serializes a tensor's metadata (storage index, dtype, offset, sizes,
// strides, requires_grad, and quantization schema if quantized). The actual
// storage bytes are only registered in tensor_data_ here; they are embedded
// later by serializeModule.
flatbuffers::Offset<mobile::serialization::TensorMetadata>
FlatbufferSerializer::tensorToFB(
    flatbuffers::FlatBufferBuilder& fbb,
    const IValue& ivalue) {
  auto& tensor = ivalue.toTensor();
  bool quantized = tensor.is_quantized();
  const at::Storage& storage = tensor.storage();

  flatbuffers::Offset<mobile::serialization::QuantizedSchema> qschema_offset =
      0;
  if (quantized) {
    double scale = 0;
    int64_t zero_point = 0;
    flatbuffers::Offset<mobile::serialization::TensorMetadata> scales = 0;
    flatbuffers::Offset<mobile::serialization::TensorMetadata> zero_points = 0;
    int64_t axis = 0;

    switch (tensor.qscheme()) {
      case at::kPerTensorAffine:
        scale = tensor.q_scale();
        zero_point = tensor.q_zero_point();
        break;
      case at::kPerChannelAffineFloatQParams:
      case at::kPerChannelAffine: {
        // Per-channel scales/zero-points are tensors themselves; serialize
        // them recursively.
        scales = tensorToFB(fbb, tensor.q_per_channel_scales());
        zero_points = tensorToFB(fbb, tensor.q_per_channel_zero_points());
        axis = tensor.q_per_channel_axis();
      } break;
      default:
        TORCH_CHECK(
            false,
            "Unsupported tensor quantization type in serialization ",
            toString(tensor.qscheme()));
        break;
    }

    qschema_offset = mobile::serialization::CreateQuantizedSchema(
        fbb,
        static_cast<int8_t>(tensor.qscheme()),
        scale,
        static_cast<int32_t>(zero_point),
        scales,
        zero_points,
        static_cast<int32_t>(axis));
  }

  // Deduplicate storage by StorageImpl pointer so views/aliases share one
  // entry in tensor_data_.
  void* addr = storage.unsafeGetStorageImpl();
  uint32_t storage_index = 0;
  auto it = memoized_storage_map_.find(addr);
  if (it != memoized_storage_map_.end()) {
    storage_index = it->second;
  } else {
    storage_index = tensor_data_.size();
    memoized_storage_map_[addr] = storage_index;
    tensor_data_.push_back(tensor);
  }

  std::vector<int> sizes{tensor.sizes().begin(), tensor.sizes().end()};
  std::vector<int> strides{tensor.strides().begin(), tensor.strides().end()};

  return CreateTensorMetadataDirect(
      fbb,
      /* storage_location_index */ storage_index,
      /* scalar_type */ static_cast<int8_t>(tensor.scalar_type()),
      /* int32_t storage_offset */
      static_cast<int32_t>(tensor.storage_offset()),
      /* sizes */ &sizes,
      /* strides */ &strides,
      /* bool requires_grad */ tensor.requires_grad(),
      /* qschema */ qschema_offset);
}
673
// Interns an IValue: NONE always maps to the reserved index 0; otherwise
// check the dedup cache, serialize on a miss, and remember the new index.
uint32_t FlatbufferSerializer::storeIValueAndGetIndex(
    flatbuffers::FlatBufferBuilder& fbb,
    const IValue& ivalue) {
  if (ivalue.isNone()) {
    return kNoneIndex;
  }

  try {
    auto iter = cached_ivalues_.find(ivalue);
    if (iter != cached_ivalues_.end()) {
      return iter->second;
    }
    // NOLINTNEXTLINE(bugprone-empty-catch)
  } catch (...) {
    // Thrown if the ivalue is not hashable or lacks a proper operator==.
    // Either way we deliberately skip deduplication and serialize anew.
  }

  auto offset = iValueToFB(fbb, ivalue);
  uint32_t index = insertIValue(offset);
  try {
    cached_ivalues_[ivalue] = index;
    // NOLINTNEXTLINE(bugprone-empty-catch)
  } catch (...) {
    // Same as above: values that cannot be hashed or compared are simply
    // not cached.
  }

  return index;
}
706
707 flatbuffers::Offset<mobile::serialization::IValue> FlatbufferSerializer::
iValueToFB(flatbuffers::FlatBufferBuilder & fbb,const IValue & ivalue)708 iValueToFB(flatbuffers::FlatBufferBuilder& fbb, const IValue& ivalue) {
709 using mobile::serialization::IValueUnion;
710
711 IValueUnion ivalue_type = IValueUnion::NONE;
712 flatbuffers::Offset<void> offset = 0;
713
714 if (ivalue.isTensor()) {
715 ivalue_type = IValueUnion::TensorMetadata;
716 offset = tensorToFB(fbb, ivalue).Union();
717 } else if (ivalue.isTuple()) {
718 ivalue_type = IValueUnion::Tuple;
719 offset = tupleToFB(fbb, ivalue).Union();
720 } else if (ivalue.isDouble()) {
721 ivalue_type = IValueUnion::Double;
722 offset = fbb.CreateStruct(mobile::serialization::Double(ivalue.toDouble()))
723 .Union();
724 } else if (ivalue.isComplexDouble()) {
725 auto comp = ivalue.toComplexDouble();
726 ivalue_type = IValueUnion::ComplexDouble;
727 offset = fbb.CreateStruct(mobile::serialization::ComplexDouble(
728 comp.real(), comp.imag()))
729 .Union();
730 } else if (ivalue.isInt()) {
731 ivalue_type = IValueUnion::Int;
732 offset =
733 fbb.CreateStruct(mobile::serialization::Int(ivalue.toInt())).Union();
734 } else if (ivalue.isBool()) {
735 ivalue_type = IValueUnion::Bool;
736 offset =
737 fbb.CreateStruct(mobile::serialization::Bool(ivalue.toBool())).Union();
738 } else if (ivalue.isString()) {
739 ivalue_type = IValueUnion::String;
740 offset = mobile::serialization::CreateString(
741 fbb, fbb.CreateSharedString(ivalue.toStringRef()))
742 .Union();
743 } else if (ivalue.isGenericDict()) {
744 ivalue_type = IValueUnion::Dict;
745 offset = dictToFB(fbb, ivalue).Union();
746 } else if (ivalue.isNone()) {
747 ivalue_type = IValueUnion::NONE;
748 offset = 0;
749 } else if (ivalue.isIntList()) {
750 ivalue_type = IValueUnion::IntList;
751 offset = mobile::serialization::CreateIntList(
752 fbb, fbb.CreateVector(ivalue.toIntVector()))
753 .Union();
754 } else if (ivalue.isDoubleList()) {
755 ivalue_type = IValueUnion::DoubleList;
756 offset = mobile::serialization::CreateDoubleList(
757 fbb, fbb.CreateVector(ivalue.toDoubleVector()))
758 .Union();
759 } else if (ivalue.isBoolList()) {
760 ivalue_type = IValueUnion::BoolList;
761 auto boollist = ivalue.toBoolList();
762 std::vector<uint8_t> bool_vec(boollist.begin(), boollist.end());
763 offset =
764 mobile::serialization::CreateBoolListDirect(fbb, &bool_vec).Union();
765 } else if (ivalue.isList()) {
766 ivalue_type = IValueUnion::List;
767 offset = listToFB(fbb, ivalue).Union();
768 } else if (ivalue.isObject()) {
769 ivalue_type = IValueUnion::Object;
770 offset = objectToFB(fbb, ivalue).Union();
771 } else if (ivalue.isDevice()) {
772 ivalue_type = IValueUnion::Device;
773 offset = mobile::serialization::CreateDevice(
774 fbb, fbb.CreateSharedString(ivalue.toDevice().str()))
775 .Union();
776 } else if (ivalue.isEnum()) {
777 const auto& enum_holder = ivalue.toEnumHolder();
778 const auto& qualified_class_name =
779 enum_holder->type()->qualifiedClassName();
780 uint32_t ival_pos = storeIValueAndGetIndex(fbb, enum_holder->value());
781 ivalue_type = IValueUnion::EnumValue;
782 offset = mobile::serialization::CreateEnumValue(
783 fbb,
784 fbb.CreateSharedString(qualified_class_name.qualifiedName()),
785 ival_pos)
786 .Union();
787 } else {
788 AT_ERROR("Invalid IValue type for serialization: ", ivalue.tagKind());
789 }
790 return CreateIValue(fbb, ivalue_type, offset);
791 }
792
793 } // namespace
794
save_mobile_module(const mobile::Module & module,const std::string & filename,const ExtraFilesMap & extra_files,const ExtraFilesMap & jit_sources,const std::vector<IValue> & jit_constants)795 void save_mobile_module(
796 const mobile::Module& module,
797 const std::string& filename,
798 const ExtraFilesMap& extra_files,
799 const ExtraFilesMap& jit_sources,
800 const std::vector<IValue>& jit_constants) {
801 auto buffer = save_mobile_module_to_bytes(
802 module, extra_files, jit_sources, jit_constants);
803 std::fstream ofile(filename, std::ios::binary | std::ios::out);
804 ofile.write(
805 reinterpret_cast<char*>(buffer->data()),
806 static_cast<std::streamsize>(buffer->size()));
807 ofile.close();
808 }
809
810 /// Deletes a DetachedBuffer, along with the internal
811 /// flatbuffers::DetachedBuffer if present. Used as a custom deleter for
812 /// std::unique_ptr; see UniqueDetachedBuffer and make_unique_detached_buffer.
destroy(DetachedBuffer * buf)813 void DetachedBuffer::destroy(DetachedBuffer* buf) {
814 // May be null.
815 delete static_cast<flatbuffers::DetachedBuffer*>(buf->data_owner_);
816 delete buf;
817 }
818
/// Provides access to DetachedBuffer::destroy().
/// Exists because destroy() is private to DetachedBuffer; this friend struct
/// lets free functions in this file build the unique_ptr with it.
struct DetachedBufferFriend {
  /// Returns a UniqueDetachedBuffer that wraps the provided DetachedBuffer.
  static DetachedBuffer::UniqueDetachedBuffer make_unique_detached_buffer(
      DetachedBuffer* buf) {
    return DetachedBuffer::UniqueDetachedBuffer(buf, DetachedBuffer::destroy);
  }
};
827
save_mobile_module_to_bytes(const mobile::Module & module,const ExtraFilesMap & extra_files,const ExtraFilesMap & jit_sources,const std::vector<IValue> & jit_constants)828 DetachedBuffer::UniqueDetachedBuffer save_mobile_module_to_bytes(
829 const mobile::Module& module,
830 const ExtraFilesMap& extra_files,
831 const ExtraFilesMap& jit_sources,
832 const std::vector<IValue>& jit_constants) {
833 FlatbufferSerializer fb_serializer;
834 flatbuffers::DetachedBuffer buf = fb_serializer.serializeModule(
835 module,
836 /*include_tensor_data_in_flatbuffer=*/true,
837 extra_files,
838 jit_sources,
839 jit_constants);
840 flatbuffers::DetachedBuffer* buf_ptr =
841 new flatbuffers::DetachedBuffer(std::move(buf));
842 DetachedBuffer* ret =
843 new DetachedBuffer(buf_ptr->data(), buf_ptr->size(), buf_ptr);
844 return DetachedBufferFriend::make_unique_detached_buffer(ret);
845 }
846
save_mobile_module_to_func(const mobile::Module & module,const std::function<size_t (const void *,size_t)> & writer_func)847 void save_mobile_module_to_func(
848 const mobile::Module& module,
849 const std::function<size_t(const void*, size_t)>& writer_func) {
850 auto buffer = save_mobile_module_to_bytes(module);
851 writer_func(buffer->data(), buffer->size());
852 }
853
// Always returns true. NOTE(review): presumably a registration hook whose
// call from elsewhere forces this translation unit to be linked in — confirm
// against the callers of this symbol.
bool register_flatbuffer_serializer() {
  return true;
}
857
858 } // namespace torch::jit
859