1 #ifdef FLATBUFFERS_VERSION_MAJOR
2 #error "flatbuffer_loader.h must not include any flatbuffers headers"
3 #endif // FLATBUFFERS_VERSION_MAJOR
4
5 #include <array>
6 #include <istream>
7 #include <memory>
8 #include <string>
9 #include <tuple>
10 #include <unordered_map>
11 #include <unordered_set>
12 #include <utility>
13 #include <vector>
14
15 #include <ATen/ATen.h>
16 #include <ATen/core/dynamic_type.h>
17 #include <ATen/core/ivalue.h>
18 #include <ATen/core/qualified_name.h>
19 #include <c10/core/CPUAllocator.h>
20 #include <c10/core/impl/alloc_cpu.h>
21 #include <c10/util/Exception.h>
22 #include <c10/util/ScopeExit.h>
23 #include <caffe2/serialize/inline_container.h>
24 #include <torch/csrc/jit/mobile/file_format.h>
25 #include <torch/csrc/jit/mobile/flatbuffer_loader.h>
26 #include <torch/csrc/jit/mobile/function.h>
27 #include <torch/csrc/jit/mobile/import.h>
28 #include <torch/csrc/jit/mobile/interpreter.h>
29 #include <torch/csrc/jit/mobile/module.h>
30 #include <torch/csrc/jit/mobile/observer.h>
31 #include <torch/csrc/jit/mobile/type_parser.h>
32 #include <torch/csrc/jit/runtime/instruction.h>
33 #include <torch/csrc/jit/serialization/export_bytecode.h>
34 #include <torch/csrc/jit/serialization/import_export_constants.h>
35 #include <torch/csrc/jit/serialization/import_read.h>
36 #include <torch/custom_class.h>
37 #include <optional>
38
39 #ifndef DISABLE_UPGRADER
40 #include <torch/csrc/jit/mobile/parse_bytecode.h>
41 #include <torch/csrc/jit/mobile/upgrader_mobile.h>
42 #endif
43
44 #ifdef _WIN32
45 #include <malloc.h>
46 #else
47 #include <cstdlib>
48 #endif
49
50 #if defined(FB_XPLAT_BUILD) || defined(FBCODE_CAFFE2)
51 #include <torch/csrc/jit/serialization/mobile_bytecode_generated_fbsource.h> // NOLINT
52 namespace flatbuffers = flatbuffers_fbsource;
53 #define FLATBUFFERS_MAX_ALIGNMENT FLATBUFFERS_FBSOURCE_MAX_ALIGNMENT
54 #else
55 #include <torch/csrc/jit/serialization/mobile_bytecode_generated.h> // NOLINT
56 #endif
57
58 namespace torch::jit {
59
// Our own alignment requirement does not need to be exactly the same as what
// flatbuffers supports, but what flatbuffers supports needs to satisfy our
// requirement.
static_assert(
    kFlatbufferDataAlignmentBytes <= FLATBUFFERS_MAX_ALIGNMENT,
    "Sizes must be compatible");
static_assert(
    (kFlatbufferDataAlignmentBytes & ~(kFlatbufferDataAlignmentBytes - 1)) ==
        kFlatbufferDataAlignmentBytes,
    "Must be a power of 2");

namespace {

// Qualified-name prefixes used to decide how a serialized type string is
// resolved (see resolveType below).
static constexpr c10::string_view kCustomClassPrefix =
    "__torch__.torch.classes";
static constexpr c10::string_view kTorchPrefix = "__torch__";
static constexpr c10::string_view kJitPrefix = "torch.jit";
77
// One-shot loader that materializes a mobile::Module from a flatbuffer
// buffer. All parsing state (ivalues, class types, storages) lives on the
// instance, so one loader should process one buffer at a time.
class FlatbufferLoader final {
 public:
  FlatbufferLoader();

  // Parser callback for a single IValueUnion tag.
  typedef IValue (
      *IValueParser)(FlatbufferLoader&, const mobile::serialization::IValue&);
  void registerIValueParser(
      mobile::serialization::IValueUnion ivalue_type,
      IValueParser parser);
  // Parses `module`. `end` is one-past-the-end of the backing buffer and is
  // used for corruption checks on internal offsets.
  mobile::Module parseModule(mobile::serialization::Module* module, char* end);

  // Parses the jit-only tail of the ivalue table; requires a prior
  // successful parseModule() call.
  void extractJitSourceAndConstants(
      ExtraFilesMap* jit_sources,
      std::vector<IValue>* constants);

  using TypeResolver = TypePtr (*)(
      const std::string& type_str,
      const std::shared_ptr<CompilationUnit>& cu);

  void internal_registerTypeResolver(TypeResolver type_resolver);

  // Returns the already-parsed ivalue at index `pos` (bounds-checked).
  IValue& getIValue(uint32_t pos) {
    TORCH_CHECK(pos < all_ivalues_.size());
    return all_ivalues_[pos];
  }

  // Returns the function parsed from ivalue index `pos`. NOTE(review):
  // unordered_map operator[] default-inserts a nullptr entry when `pos` was
  // never populated — callers are expected to pass valid indices.
  mobile::Function* getFunction(uint32_t pos) {
    return all_functions_[pos];
  }

  // Returns the class type slot at `pos` (may still be nullptr if the type
  // has not been created yet; see getOrCreateClassTypeForObject).
  ClassTypePtr getType(uint32_t pos) {
    TORCH_CHECK(pos < all_types_.size());
    return all_types_[pos];
  }

  c10::Storage getStorage(uint32_t index);
  TypePtr getOrCreateTypeAnnotations(const flatbuffers::String* offset);
  ClassTypePtr getOrCreateClassTypeForObject(
      const mobile::serialization::Object* object);

  const mobile::serialization::Module* getCurrentFlatbufferInput() {
    return module_;
  }

  // When true, tensor storages are copied out of the flatbuffer buffer
  // instead of aliasing it (see getStorage).
  void setShouldCopyTensorMemory(bool should_copy_tensor_memory) {
    should_copy_tensor_memory_ = should_copy_tensor_memory;
  }

  std::shared_ptr<mobile::CompilationUnit> mcu_;
  std::shared_ptr<CompilationUnit> cu_;

 private:
  IValue parseIValue(const mobile::serialization::IValue* ivalue);
  std::unique_ptr<mobile::Function> parseFunction(
      const mobile::serialization::Function* method);
  void parseAndPopulate(
      uint32_t i,
      const mobile::serialization::IValue* ivalue);

  // Keyed by ivalue index; populated only for Function union members.
  std::unordered_map<uint32_t, mobile::Function*> all_functions_;
  // Indexed by object_types() index; slots filled lazily.
  std::vector<ClassTypePtr> all_types_;
  std::unordered_set<uint32_t> initialized_types_;
  // Memoized type-string -> TypePtr, keyed by flatbuffer string pointer
  // (stable for the lifetime of the buffer).
  std::unordered_map<const flatbuffers::String*, TypePtr> type_annotations_;
  std::vector<bool> storage_loaded_;
  std::vector<c10::Storage> storages_;
  std::vector<IValue> all_ivalues_;
  // Dispatch table indexed by IValueUnion tag value.
  std::array<
      IValueParser,
      static_cast<uint8_t>(mobile::serialization::IValueUnion::MAX) + 1>
      ivalue_parsers_;
  TypeResolver type_resolver_ = nullptr;
  mobile::serialization::Module* module_ = nullptr;
  bool module_parsed_ = false;
  bool should_copy_tensor_memory_ = false;
  // 0 -> mobile_ivalue_size_ elements are from the mobile module.
  uint32_t mobile_ivalue_size_ = 0;
};
155
// Forward declarations of the per-tag IValue parsers registered in the
// FlatbufferLoader constructor; definitions follow below.
IValue parseList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseTensor(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseTuple(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseDict(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseObject(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseIntList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseDoubleList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseBoolList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseBasic(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseEnum(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
186
resolveType(const std::string & type_string,const std::shared_ptr<CompilationUnit> & cu)187 TypePtr resolveType(
188 const std::string& type_string,
189 const std::shared_ptr<CompilationUnit>& cu) {
190 TypePtr type;
191 c10::string_view type_str(type_string);
192 if (type_str.starts_with(kCustomClassPrefix)) {
193 type = getCustomClass(type_string);
194 TORCH_CHECK(
195 type, "The implementation of class ", type_string, " cannot be found.");
196 } else if (
197 type_str.starts_with(kTorchPrefix) || type_str.starts_with(kJitPrefix)) {
198 c10::QualifiedName qn(type_string);
199 if (cu->get_class(qn) == nullptr) {
200 auto classtype = ClassType::create(qn, cu, true);
201 cu->register_type(classtype);
202 type = classtype;
203 } else {
204 type = cu->get_class(qn);
205 }
206 } else {
207 type = c10::parseType(type_string);
208 }
209 return type;
210 }
211
// Sets up the default parser dispatch table (one entry per IValueUnion tag)
// and the default type resolver. Scalar-like tags all route to parseBasic.
FlatbufferLoader::FlatbufferLoader()
    : mcu_(std::make_shared<mobile::CompilationUnit>()),
      cu_(std::make_shared<CompilationUnit>()),
      ivalue_parsers_{nullptr} {
  registerIValueParser(mobile::serialization::IValueUnion::NONE, &parseBasic);
  registerIValueParser(mobile::serialization::IValueUnion::Int, &parseBasic);
  registerIValueParser(mobile::serialization::IValueUnion::Bool, &parseBasic);
  registerIValueParser(mobile::serialization::IValueUnion::Double, &parseBasic);
  registerIValueParser(
      mobile::serialization::IValueUnion::ComplexDouble, &parseBasic);
  registerIValueParser(
      mobile::serialization::IValueUnion::TensorMetadata, &parseTensor);
  registerIValueParser(mobile::serialization::IValueUnion::String, &parseBasic);
  registerIValueParser(mobile::serialization::IValueUnion::List, &parseList);
  registerIValueParser(
      mobile::serialization::IValueUnion::IntList, &parseIntList);
  registerIValueParser(
      mobile::serialization::IValueUnion::DoubleList, &parseDoubleList);
  registerIValueParser(
      mobile::serialization::IValueUnion::BoolList, &parseBoolList);
  registerIValueParser(mobile::serialization::IValueUnion::Tuple, &parseTuple);
  registerIValueParser(mobile::serialization::IValueUnion::Dict, &parseDict);
  registerIValueParser(
      mobile::serialization::IValueUnion::Object, &parseObject);
  registerIValueParser(mobile::serialization::IValueUnion::Device, &parseBasic);
  registerIValueParser(
      mobile::serialization::IValueUnion::EnumValue, &parseEnum);
  internal_registerTypeResolver(&resolveType);
}
241
registerIValueParser(mobile::serialization::IValueUnion ivalue_type,IValueParser parser)242 void FlatbufferLoader::registerIValueParser(
243 mobile::serialization::IValueUnion ivalue_type,
244 IValueParser parser) {
245 ivalue_parsers_[static_cast<uint8_t>(ivalue_type)] = parser;
246 }
247
// Replaces the callback used to turn serialized type strings into TypePtrs
// (see getOrCreateTypeAnnotations).
void FlatbufferLoader::internal_registerTypeResolver(
    TypeResolver type_resolver) {
  type_resolver_ = type_resolver;
}
252
parseExtraFilesFromVector(const flatbuffers::Vector<flatbuffers::Offset<torch::jit::mobile::serialization::ExtraFile>> * files,ExtraFilesMap * extra_files)253 void parseExtraFilesFromVector(
254 const flatbuffers::Vector<flatbuffers::Offset<
255 torch::jit::mobile::serialization::ExtraFile>>* files,
256 ExtraFilesMap* extra_files) {
257 for (uint32_t i = 0; i < files->size(); ++i) {
258 const auto* extra_file = files->Get(i);
259 (*extra_files)[extra_file->name()->str()] = extra_file->content()->str();
260 }
261 }
262
parseExtraFiles(mobile::serialization::Module * module,ExtraFilesMap & extra_files)263 void parseExtraFiles(
264 mobile::serialization::Module* module,
265 ExtraFilesMap& extra_files) {
266 auto extra_files_offsets = module->extra_files();
267 parseExtraFilesFromVector(extra_files_offsets, &extra_files);
268 }
269
parseAndPopulate(uint32_t i,const mobile::serialization::IValue * ivalue)270 void FlatbufferLoader::parseAndPopulate(
271 uint32_t i,
272 const mobile::serialization::IValue* ivalue) {
273 if (const auto* func = ivalue->val_as_Function()) {
274 auto func_ptr = parseFunction(func);
275 all_functions_[i] = func_ptr.get();
276 mcu_->register_function(std::move(func_ptr));
277 } else {
278 all_ivalues_[i] = parseIValue(ivalue);
279 }
280 }
281
// Parses the whole flatbuffer module: materializes the first
// mobile_ivalue_size_ ivalues, attaches parsed functions to their class
// types, and wraps the state object into a mobile::Module.
mobile::Module FlatbufferLoader::parseModule(
    mobile::serialization::Module* module,
    char* end) {
  module_ = module;
  // Reset per-parse state so a loader instance can be reused.
  all_ivalues_.clear();
  all_types_.clear();
  storages_.clear();
  storage_loaded_.clear();
  module_parsed_ = false;

  const auto* ivalues = module->ivalues();
  TORCH_CHECK(
      ivalues && module->object_types(),
      "Parsing flatbuffer module: Corrupted ivalues/object_types field");
  TORCH_CHECK(
      reinterpret_cast<const char*>(ivalues) < end, "Corrupted ivalues field");
  TORCH_CHECK(
      module->storage_data_size() >= 0,
      "Parsing flatbuffer module: illegal storage_data_size: ",
      module->storage_data_size(),
      ", expected to be non negative");
  all_ivalues_.resize(ivalues->size());
  all_types_.resize(module->object_types()->size());
  storages_.resize(module->storage_data_size());
  storage_loaded_.resize(module->storage_data_size(), false);

  // Only the first mobile_ivalue_size_ entries belong to the mobile module;
  // the rest (if any) are jit-only and handled later by
  // extractJitSourceAndConstants(). 0 or an out-of-range value means "all".
  mobile_ivalue_size_ = module_->mobile_ivalue_size();
  if (mobile_ivalue_size_ == 0 || mobile_ivalue_size_ > ivalues->size()) {
    mobile_ivalue_size_ = ivalues->size();
  }

  for (uint32_t i = 0; i < mobile_ivalue_size_; i++) {
    const auto* ival = ivalues->Get(i);
    // Sanity-check the offset stays inside the buffer before parsing it.
    TORCH_CHECK(
        reinterpret_cast<const char*>(ival) < end, "Corrupted ivalue item")
    parseAndPopulate(i, ival);
  }
  IValue& module_ivalue = getIValue(module->state_obj());

  // register functions
  for (const auto& f : all_functions_) {
    uint32_t class_index =
        ivalues->Get(f.first)->val_as_Function()->class_type();
    ClassTypePtr class_type = all_types_[class_index];
    class_type->addMethod(f.second);
  }

  module_parsed_ = true;
  auto m = mobile::Module(module_ivalue.toObject(), mcu_);
  m.set_min_operator_version(module->operator_version());
  m.set_bytecode_version(module->bytecode_version());
  return m;
}
335
// Registers the built-in upgrader bytecode functions on `function` so that
// upgrader CALL targets can resolve. Compiles to a no-op when the upgrader
// is disabled.
void appendUpgraderFunctions(mobile::Function* function) {
#ifndef DISABLE_UPGRADER
  for (auto& byteCodeFunctionWithOperator : getUpgraderBytecodeList()) {
    function->append_function(byteCodeFunctionWithOperator.function);
  }
#endif
}
343
// Builds a mobile::Function from its serialized form: instructions,
// constants, operators, type annotations, register size and (optionally) a
// schema. Also applies bytecode upgraders when the file was produced under
// an older operator version.
std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
    const mobile::serialization::Function* method) {
  auto function = std::make_unique<mobile::Function>(
      c10::QualifiedName(method->qn()->str()));
  // TODO(qihan) add debug handle
  // const auto* debug_handle = method->debug_info()->debug_handle();
  for (const auto* inst : *method->instructions()) {
    function->append_instruction(
        static_cast<OpCode>(inst->op()), inst->x(), inst->n());
  }

  // Constants are serialized as indices into the module-level ivalue table.
  for (uint32_t i : *method->constants()) {
    function->append_constant(getIValue(i));
  }

  appendUpgraderFunctions(function.get());
  // 2. Decides if upgrader is needed
  const uint32_t operator_version = module_->operator_version();
  bool use_upgrader =
      (operator_version < caffe2::serialize::kProducedFileFormatVersion);

  for (const auto* op : *method->operators()) {
    // A serialized num_args of -1 means "not recorded".
    std::optional<int> num_args = std::nullopt;
    if (op->num_args_serialized() > -1) {
      num_args = op->num_args_serialized();
    }

    function->append_operator(
        op->name()->str(), op->overload_name()->str(), num_args);
  }

  function->initialize_operators(true);

  for (const auto i : *method->type_annotations()) {
    function->append_type(getOrCreateTypeAnnotations(i));
  }

  // 3. If upgrader is needed, change change the OP instrunction to CALL
  // instruction (In next PR, use_upgrader will be parsed to parseInstruction
  // function and do the actual change)
  if (use_upgrader) {
#ifndef DISABLE_UPGRADER
    applyUpgrader(function.get(), operator_version);
#endif
  }

  function->set_register_size(method->register_size());
  if (method->schema()) {
    try {
      // Converts a serialized argument list (indices into the ivalue table
      // for defaults, strings for types) into c10::Arguments.
      auto parseArgList = [this](const auto* args_fb) {
        std::vector<c10::Argument> args;
        for (const auto* arg_tb : *args_fb) {
          IValue default_value = getIValue(arg_tb->default_value());
          TypePtr type_ptr = getOrCreateTypeAnnotations(arg_tb->type());
          auto arg = c10::Argument(
              arg_tb->name()->str(),
              std::move(type_ptr),
              std::nullopt /*N*/,
              std::move(default_value));
          args.emplace_back(std::move(arg));
        }
        return args;
      };
      c10::FunctionSchema schema(
          method->qn()->str(),
          "" /*overload_name*/,
          parseArgList(method->schema()->arguments()),
          parseArgList(method->schema()->returns()),
          false /*is_varargs*/,
          false /*is_varret*/);

      function->setSchema(std::move(schema));
    } catch (const c10::Error& e) {
      // Schema attachment is best-effort: a schema that fails to resolve
      // leaves the function without one instead of failing the whole load.
    }
  }
  return function;
}
421
// Reconstructs an enum IValue: resolves the enum type from its serialized
// type name, then matches the serialized value against the type's
// name/value pairs.
IValue parseEnum(
    FlatbufferLoader& loader,
    const mobile::serialization::IValue& ivalue) {
  const auto* enum_val = ivalue.val_as_EnumValue();
  auto enum_type = loader.getOrCreateTypeAnnotations(enum_val->type_name())
                       ->cast<c10::EnumType>();
  AT_ASSERT(
      enum_type,
      "Enum with type: " + enum_val->type_name()->str() + " not found.");
  IValue val = loader.getIValue(enum_val->value());
  for (const auto& p : enum_type->enumNamesValues()) {
    if (p.second == val) {
      auto enum_holder = c10::make_intrusive<at::ivalue::EnumHolder>(
          enum_type, p.first, p.second);
      return IValue(std::move(enum_holder));
    }
  }
  // No member of the resolved enum type carries the serialized value.
  AT_ASSERT(
      false, "Enum with type: " + enum_val->type_name()->str() + " not found.");
}
442
// Handles the scalar-like union members (NONE, Int, Bool, Double,
// ComplexDouble, String, Device). Any other tag falls through to an empty
// (None) IValue.
IValue parseBasic(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue) {
  switch (ivalue.val_type()) {
    case mobile::serialization::IValueUnion::NONE:
      return {};
    case mobile::serialization::IValueUnion::Int:
      return ivalue.val_as_Int()->int_val();
    case mobile::serialization::IValueUnion::Bool:
      return ivalue.val_as_Bool()->bool_val();
    case mobile::serialization::IValueUnion::Double:
      return ivalue.val_as_Double()->double_val();
    case mobile::serialization::IValueUnion::ComplexDouble: {
      const auto* comp = ivalue.val_as_ComplexDouble();
      return c10::complex<double>(comp->real(), comp->imag());
    }
    case mobile::serialization::IValueUnion::String:
      return ivalue.val_as_String()->data()->str();
    case mobile::serialization::IValueUnion::Device: {
      return c10::Device(ivalue.val_as_Device()->str()->str());
    }
    default:
      return {};
  }
}
468
// Reconstructs a tensor from serialized metadata. Quantized tensors are
// created through their qscheme-specific factory; everything else starts as
// an empty CPU tensor. The storage is then attached from the loader (aliased
// or copied depending on loader settings) and sizes/strides are restored.
at::Tensor parseTensorFromMetadata(
    FlatbufferLoader* loader,
    const mobile::serialization::TensorMetadata* tensor_md) {
  at::ScalarType type = static_cast<at::ScalarType>(tensor_md->scalar_type());
  auto options = at::CPU(type).options();
  at::Tensor tensor;
  if (tensor_md->quantized_schema() != nullptr) {
    // is quantized
    const auto* schema = tensor_md->quantized_schema();
    auto qscheme_type = static_cast<at::QScheme>(schema->qscheme());
    switch (qscheme_type) {
      case at::kPerTensorAffine: {
        tensor = at::_empty_affine_quantized(
            {0}, options, schema->scale(), schema->zero_point());
      } break;
      case at::kPerChannelAffineFloatQParams:
      case at::kPerChannelAffine: {
        // Per-channel scales/zero-points are themselves serialized tensors;
        // recurse to materialize them first.
        at::Tensor scales = parseTensorFromMetadata(loader, schema->scales());
        at::Tensor zero_points =
            parseTensorFromMetadata(loader, schema->zero_points());
        tensor = at::_empty_per_channel_affine_quantized(
            {0}, scales, zero_points, schema->axis(), options);
      } break;
      default:
        TORCH_CHECK(
            false,
            "Unsupported tensor quantization type in serialization ",
            toString(qscheme_type));
        break;
    }
  } else {
    tensor = at::empty({0}, options);
  }
  at::TensorImpl* impl = tensor.unsafeGetTensorImpl();

  // Swap in the real storage (the tensor was created with size {0}).
  c10::Storage storage;
  storage = loader->getStorage(tensor_md->storage_location_index());
  impl->set_storage_keep_dtype(storage);
  impl->set_storage_offset(tensor_md->storage_offset());

  std::vector<int64_t> size{
      tensor_md->sizes()->begin(), tensor_md->sizes()->end()};
  std::vector<int64_t> stride{
      tensor_md->strides()->begin(), tensor_md->strides()->end()};
  impl->set_sizes_and_strides(size, stride);
#ifndef MIN_EDGE_RUNTIME
  tensor = autograd::make_variable(tensor, tensor_md->requires_grad());
#endif
  return tensor;
}
519
parseTensor(FlatbufferLoader & loader,const mobile::serialization::IValue & ivalue)520 IValue parseTensor(
521 FlatbufferLoader& loader,
522 const mobile::serialization::IValue& ivalue) {
523 const mobile::serialization::TensorMetadata* tensor_md =
524 ivalue.val_as_TensorMetadata();
525 return parseTensorFromMetadata(&loader, tensor_md);
526 }
527
parseList(FlatbufferLoader & loader,const mobile::serialization::IValue & ivalue)528 IValue parseList(
529 FlatbufferLoader& loader,
530 const mobile::serialization::IValue& ivalue) {
531 const mobile::serialization::List* list = ivalue.val_as_List();
532 auto res = c10::impl::GenericList(AnyType::get());
533 for (auto i : *list->items()) {
534 res.emplace_back(loader.getIValue(i));
535 }
536 auto type = loader.getOrCreateTypeAnnotations(list->annotation_str());
537 res.unsafeSetElementType(type->containedType(0));
538 return res;
539 }
540
// Copies the flatbuffer-native items() of `list` into a std::vector<T>.
// U must expose items() with begin()/end() iterators.
template <typename T, typename U>
std::vector<T> parseListNative(const U* list) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(list != nullptr);
  const auto* items = list->items();
  std::vector<T> out(items->begin(), items->end());
  return out;
}
546
parseIntList(FlatbufferLoader &,const mobile::serialization::IValue & ivalue)547 IValue parseIntList(
548 FlatbufferLoader&,
549 const mobile::serialization::IValue& ivalue) {
550 const auto& list = ivalue.val_as_IntList();
551 return parseListNative<int64_t>(list);
552 }
553
parseDoubleList(FlatbufferLoader &,const mobile::serialization::IValue & ivalue)554 IValue parseDoubleList(
555 FlatbufferLoader&,
556 const mobile::serialization::IValue& ivalue) {
557 const auto& list = ivalue.val_as_DoubleList();
558 return parseListNative<double>(list);
559 }
560
parseBoolList(FlatbufferLoader &,const mobile::serialization::IValue & ivalue)561 IValue parseBoolList(
562 FlatbufferLoader&,
563 const mobile::serialization::IValue& ivalue) {
564 const auto& list = ivalue.val_as_BoolList();
565 std::vector<uint8_t> res = parseListNative<uint8_t>(list);
566 c10::List<bool> boollist;
567 for (auto x : res) {
568 boollist.push_back(x);
569 }
570 return boollist;
571 }
572
parseTuple(FlatbufferLoader & loader,const mobile::serialization::IValue & ivalue)573 IValue parseTuple(
574 FlatbufferLoader& loader,
575 const mobile::serialization::IValue& ivalue) {
576 const auto& tuple = ivalue.val_as_Tuple();
577 const auto items = tuple->items();
578 std::vector<IValue> res;
579 res.reserve(items->size());
580 for (auto i : *items) {
581 res.emplace_back(loader.getIValue(i));
582 }
583 return c10::ivalue::Tuple::create(std::move(res));
584 }
585
parseDict(FlatbufferLoader & loader,const mobile::serialization::IValue & ivalue)586 IValue parseDict(
587 FlatbufferLoader& loader,
588 const mobile::serialization::IValue& ivalue) {
589 const auto* dict = ivalue.val_as_Dict();
590 auto result = c10::impl::GenericDict(AnyType::get(), AnyType::get());
591 const auto* keys = dict->keys();
592 const auto* values = dict->values();
593 for (size_t i = 0; i < keys->size(); ++i) {
594 uint32_t key = keys->Get(i);
595 uint32_t val = values->Get(i);
596 result.insert_or_assign(loader.getIValue(key), loader.getIValue(val));
597 }
598 auto type = loader.getOrCreateTypeAnnotations(dict->annotation_str());
599 result.unsafeSetKeyType(type->containedType(0));
600 result.unsafeSetValueType(type->containedType(1));
601 return result;
602 }
603
getOrCreateClassTypeForObject(const mobile::serialization::Object * object)604 ClassTypePtr FlatbufferLoader::getOrCreateClassTypeForObject(
605 const mobile::serialization::Object* object) {
606 auto cls = getType(object->type_index());
607 const mobile::serialization::ObjectType* obj_type =
608 module_->object_types()->Get(object->type_index());
609 if (cls == nullptr) {
610 c10::string_view qn_str(
611 obj_type->type_name()->c_str(), obj_type->type_name()->size());
612 if (qn_str.starts_with(kTorchPrefix) || qn_str.starts_with(kJitPrefix)) {
613 c10::QualifiedName qn(obj_type->type_name()->str());
614 cls = cu_->get_class(qn);
615 if (cls == nullptr) {
616 cls = ClassType::create(qn, cu_, true);
617 cu_->register_type(cls);
618 }
619 } else {
620 cls = c10::parseType(std::string(qn_str))->cast<ClassType>();
621 }
622 TORCH_CHECK(object->type_index() < all_ivalues_.size());
623 all_types_[object->type_index()] = cls;
624
625 if (obj_type->type() == mobile::serialization::TypeType::CLASS_WITH_FIELD) {
626 for (uint32_t i = 0; i < object->attrs()->size(); i++) {
627 IValue val = getIValue(object->attrs()->Get(i));
628 // Need to use concrete object's field's type to set type of field.
629 cls->addAttribute(
630 obj_type->attr_names()->Get(i)->str(),
631 val.type<c10::DynamicType>());
632 }
633 }
634 initialized_types_.insert(object->type_index());
635 }
636 return cls;
637 }
638
// Reconstructs an object ivalue. Dispatch on the serialized type kind:
// - CLASS_WITH_FIELD: attributes are filled slot-by-slot;
// - CLASS_WITH_SETSTATE: an empty object is created and its __setstate__
//   (a mobile function referenced by index) is run on the serialized state;
// - CUSTOM_CLASS: same protocol, but via the registered custom class's
//   __setstate__ method.
IValue parseObject(
    FlatbufferLoader& loader,
    const mobile::serialization::IValue& ivalue) {
  const mobile::serialization::Object* object = ivalue.val_as_Object();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(object != nullptr);
  const auto* cur_input = loader.getCurrentFlatbufferInput();
  const mobile::serialization::ObjectType* obj_type =
      cur_input->object_types()->Get(object->type_index());
  auto cls = loader.getOrCreateClassTypeForObject(object);
  Stack stack;
  switch (obj_type->type()) {
    case mobile::serialization::TypeType::CLASS_WITH_FIELD: {
      auto obj = c10::ivalue::Object::create(
          at::StrongTypePtr(loader.cu_, cls), object->attrs()->size());
      for (uint32_t i = 0; i < object->attrs()->size(); i++) {
        IValue val = loader.getIValue(object->attrs()->Get(i));
        obj->setSlot(i, std::move(val));
      }
      return obj;
    }
    case mobile::serialization::TypeType::CLASS_WITH_SETSTATE: {
      IValue input = loader.getIValue(object->state());
      mobile::Function* setstate = loader.getFunction(object->setstate_func());
      auto obj =
          c10::ivalue::Object::create(at::StrongTypePtr(loader.cu_, cls), 0);
      // __setstate__ convention: stack is (self, state).
      stack.emplace_back(obj);
      stack.emplace_back(std::move(input));
      setstate->run(stack);
      return obj;
    }
    case mobile::serialization::TypeType::CUSTOM_CLASS: {
      auto custom_class_type =
          torch::jit::getCustomClass(cls->name()->qualifiedName());
      IValue input = loader.getIValue(object->state());
      auto obj = c10::ivalue::Object::create(
          c10::StrongTypePtr(nullptr, custom_class_type), 1);
      stack.emplace_back(obj);
      stack.emplace_back(std::move(input));
      custom_class_type->getMethod("__setstate__").run(stack);
      return obj;
    }
    default:
      AT_ASSERT(false, "need to be object");
  }
}
684
parseIValue(const mobile::serialization::IValue * ivalue)685 IValue FlatbufferLoader::parseIValue(
686 const mobile::serialization::IValue* ivalue) {
687 return ivalue_parsers_[static_cast<uint32_t>(ivalue->val_type())](
688 *this, *ivalue);
689 }
690
// No-op deleter used when tensor storage aliases the flatbuffer buffer
// directly (should_copy_tensor_memory_ == false); the buffer's owner is
// responsible for keeping it alive (see set_delete_memory on the module).
void deleteNothing2(void*);
void deleteNothing2(void*) {}
693
// Returns the storage at `index`, materializing it on first access.
// Depending on should_copy_tensor_memory_, the bytes are either copied into
// a fresh CPU allocation or aliased directly from the flatbuffer buffer
// (with a no-op deleter — the buffer must then outlive the storage).
c10::Storage FlatbufferLoader::getStorage(uint32_t index) {
  TORCH_CHECK(index < storage_loaded_.size());
  TORCH_CHECK(index < storages_.size());
  if (!storage_loaded_[index]) {
    auto* storage = module_->storage_data()->GetMutableObject(index);
    size_t size = storage->data()->size();

    at::DataPtr data;
    if (should_copy_tensor_memory_) {
      auto* allocator = at::GetCPUAllocator();
      data = allocator->allocate(size);
      memcpy(data.get(), storage->data()->data(), size);
    } else {
      // Alias the flatbuffer bytes in place; deleteNothing2 frees nothing.
      void* ptr = static_cast<void*>(storage->mutable_data()->data());
      data = at::DataPtr(ptr, ptr, deleteNothing2, DeviceType::CPU);
    }
    storages_[index] =
        c10::Storage(c10::Storage::use_byte_size_t(), size, std::move(data));
    storage_loaded_[index] = true;
  }
  return storages_[index];
}
716
getOrCreateTypeAnnotations(const flatbuffers::String * offset)717 TypePtr FlatbufferLoader::getOrCreateTypeAnnotations(
718 const flatbuffers::String* offset) {
719 auto iter = type_annotations_.find(offset);
720 if (iter != type_annotations_.end()) {
721 return iter->second;
722 }
723 TypePtr type = type_resolver_(offset->str(), cu_);
724 type_annotations_[offset] = type;
725 return type;
726 }
727
// Parses the jit-only tail of the ivalue table (indices >=
// mobile_ivalue_size_), registers any functions found there, and copies out
// the jit constants and jit source files. Must run after parseModule().
void FlatbufferLoader::extractJitSourceAndConstants(
    ExtraFilesMap* jit_sources,
    std::vector<IValue>* constants) {
  AT_ASSERT(
      module_parsed_,
      "Need to first parse a flatbuffer file before extracting jit_sources");

  const auto* ivalues = module_->ivalues();
  for (uint32_t i = mobile_ivalue_size_; i < ivalues->size(); i++) {
    const auto* ival = ivalues->Get(i);
    parseAndPopulate(i, ival);
  }
  // register functions
  for (const auto& f : all_functions_) {
    // Skip functions already registered by parseModule().
    if (f.first >= mobile_ivalue_size_) {
      uint32_t class_index =
          ivalues->Get(f.first)->val_as_Function()->class_type();
      ClassTypePtr class_type = all_types_[class_index];
      class_type->addMethod(f.second);
    }
  }
  const auto* jit_constants = module_->jit_constants();
  for (const auto i : c10::irange(jit_constants->size())) {
    constants->emplace_back(getIValue(jit_constants->Get(i)));
  }
  parseExtraFilesFromVector(module_->jit_sources(), jit_sources);
}
755
756 } // namespace
757
// Parses a flatbuffer-serialized mobile module held in `data` (`size`
// bytes). The buffer is verified before parsing. When
// should_copy_tensor_memory is false, returned tensor storages alias
// `data`, which must then outlive the module.
mobile::Module parse_and_initialize_mobile_module(
    void* data,
    size_t size,
    std::optional<at::Device>,
    ExtraFilesMap* extra_files,
    bool should_copy_tensor_memory) {
  // TODO(T128189662): If not copying, enforce that data is aligned to
  // kFlatbufferDataAlignmentBytes, and add unit tests.

  // Validate Flatbuffer module before parsing.
  flatbuffers::Verifier verifier(reinterpret_cast<uint8_t*>(data), size);
  TORCH_CHECK(
      mobile::serialization::VerifyModuleBuffer(verifier),
      "Malformed Flatbuffer module");

  FlatbufferLoader loader;
  loader.setShouldCopyTensorMemory(should_copy_tensor_memory);

  // Flatbuffer doesn't seem to have a way to provide the buffer size when
  // interacting with the buffer.
  auto* flatbuffer_module = mobile::serialization::GetMutableModule(data);
  auto* end = static_cast<char*>(data) + size;
  mobile::Module m = loader.parseModule(flatbuffer_module, end);
  if (extra_files != nullptr) {
    parseExtraFiles(flatbuffer_module, *extra_files);
  }
  return m;
}
786
parse_and_initialize_mobile_module(std::shared_ptr<char> data,size_t size,std::optional<at::Device> device,ExtraFilesMap * extra_files)787 mobile::Module parse_and_initialize_mobile_module(
788 std::shared_ptr<char> data,
789 size_t size,
790 std::optional<at::Device> device,
791 ExtraFilesMap* extra_files) {
792 mobile::Module m = parse_and_initialize_mobile_module(
793 data.get(),
794 size,
795 device,
796 extra_files,
797 /*should_copy_tensor_memory=*/false);
798 m.set_delete_memory(std::move(data));
799 return m;
800 }
801
// Like parse_and_initialize_mobile_module, but additionally extracts the
// jit-only sources and constants stored past the mobile ivalue range.
// Tensor memory is not copied, so `data` must outlive the module.
mobile::Module parse_and_initialize_mobile_module_for_jit(
    void* data,
    size_t size,
    ExtraFilesMap& jit_sources,
    std::vector<IValue>& jit_constants,
    std::optional<at::Device>,
    ExtraFilesMap* extra_files) {
  TORCH_CHECK(
      mobile::serialization::ModuleBufferHasIdentifier(data), "Format error");
  // TODO(T128189662): Enforce that data is aligned to
  // kFlatbufferDataAlignmentBytes, and add unit tests.

  // Validate Flatbuffer module before parsing.
  flatbuffers::Verifier verifier(reinterpret_cast<uint8_t*>(data), size);
  TORCH_CHECK(
      mobile::serialization::VerifyModuleBuffer(verifier),
      "Malformed Flatbuffer module");

  FlatbufferLoader loader;
  auto* flatbuffer_module = mobile::serialization::GetMutableModule(data);
  auto* end = static_cast<char*>(data) + size;
  mobile::Module m = loader.parseModule(flatbuffer_module, end);
  if (extra_files != nullptr) {
    parseExtraFiles(flatbuffer_module, *extra_files);
  }

  loader.extractJitSourceAndConstants(&jit_sources, &jit_constants);
  return m;
}
831
load_mobile_module_from_file(const std::string & filename,std::optional<c10::Device> device,ExtraFilesMap * extra_files)832 mobile::Module load_mobile_module_from_file(
833 const std::string& filename,
834 std::optional<c10::Device> device,
835 ExtraFilesMap* extra_files) {
836 auto [data, size] = get_file_content(filename.c_str());
837 return parse_and_initialize_mobile_module(
838 std::move(data), size, device, extra_files);
839 }
840
get_bytecode_version(std::istream & in)841 uint64_t get_bytecode_version(std::istream& in) {
842 auto [data, size] = get_stream_content(in);
843 return get_bytecode_version_from_bytes(data.get());
844 }
845
get_bytecode_version(const std::string & filename)846 uint64_t get_bytecode_version(const std::string& filename) {
847 auto [data, size] = get_file_content(filename.c_str());
848 return get_bytecode_version_from_bytes(data.get());
849 }
850
get_bytecode_version_from_bytes(char * flatbuffer_content)851 uint64_t get_bytecode_version_from_bytes(char* flatbuffer_content) {
852 TORCH_CHECK(
853 mobile::serialization::ModuleBufferHasIdentifier(flatbuffer_content),
854 "Format error");
855 auto* flatbuffer_module =
856 mobile::serialization::GetMutableModule(flatbuffer_content);
857 return flatbuffer_module->bytecode_version();
858 }
859
get_module_info_from_flatbuffer(char * flatbuffer_content)860 mobile::ModuleInfo get_module_info_from_flatbuffer(char* flatbuffer_content) {
861 auto* ff_module = mobile::serialization::GetMutableModule(flatbuffer_content);
862 mobile::ModuleInfo minfo;
863 minfo.operator_version = ff_module->operator_version();
864 minfo.bytecode_version = ff_module->bytecode_version();
865
866 uint32_t mobile_ivalue_size = ff_module->mobile_ivalue_size();
867 if (mobile_ivalue_size == 0) {
868 mobile_ivalue_size = ff_module->ivalues()->size();
869 }
870
871 std::vector<std::string> type_name_list;
872 for (uint32_t i = 0; i < mobile_ivalue_size; i++) {
873 const auto* ival = ff_module->ivalues()->Get(i);
874 if (const auto* func = ival->val_as_Function()) {
875 minfo.function_names.insert(func->qn()->str());
876 for (const auto* op : *func->operators()) {
877 at::OperatorName opname(op->name()->str(), op->overload_name()->str());
878 minfo.opname_to_num_args[mobile::operator_str(opname)] =
879 op->num_args_serialized();
880 }
881 for (const auto* type_ann : *func->type_annotations()) {
882 type_name_list.push_back(type_ann->str());
883 }
884 }
885 }
886 c10::TypeParser parser(type_name_list);
887 parser.parseList();
888 minfo.type_names = parser.getContainedTypes();
889 return minfo;
890 }
891
load_mobile_module_from_stream_with_copy(std::istream & in,std::optional<at::Device> device,ExtraFilesMap * extra_files)892 mobile::Module load_mobile_module_from_stream_with_copy(
893 std::istream& in,
894 std::optional<at::Device> device,
895 ExtraFilesMap* extra_files) {
896 auto [data, size] = get_stream_content(in);
897 return parse_and_initialize_mobile_module(
898 std::move(data), size, device, extra_files);
899 }
900
parse_flatbuffer_no_object(std::shared_ptr<char> data,size_t size,std::optional<at::Device> device)901 mobile::Module parse_flatbuffer_no_object(
902 std::shared_ptr<char> data,
903 size_t size,
904 std::optional<at::Device> device) {
905 (void)device;
906 (void)size;
907
908 // Validate Flatbuffer module before parsing.
909 flatbuffers::Verifier verifier(reinterpret_cast<uint8_t*>(data.get()), size);
910 TORCH_CHECK(
911 mobile::serialization::VerifyModuleBuffer(verifier),
912 "Malformed Flatbuffer module");
913
914 auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
915 FlatbufferLoader loader;
916 // replace parserObject with to handle only class with field case
917 // function.
918 loader.registerIValueParser(
919 mobile::serialization::IValueUnion::Object,
920 +[](FlatbufferLoader& loader,
921 const mobile::serialization::IValue& ivalue) {
922 const mobile::serialization::Object* object = ivalue.val_as_Object();
923 auto cls = loader.getOrCreateClassTypeForObject(object);
924 auto obj = c10::ivalue::Object::create(
925 at::StrongTypePtr(loader.cu_, cls), object->attrs()->size());
926 for (uint32_t i = 0; i < object->attrs()->size(); i++) {
927 IValue val = loader.getIValue(object->attrs()->Get(i));
928 obj->setSlot(i, std::move(val));
929 }
930 return static_cast<c10::IValue>(obj);
931 });
932
933 auto* end = data.get() + size;
934 mobile::Module m = loader.parseModule(flatbuffer_module, end);
935 m.set_delete_memory(std::move(data));
936 return m;
937 }
938
/// No-op registration hook. Linking this translation unit is what makes the
/// flatbuffer loader available; callers invoke this to force the linker to
/// keep the TU, and the constant return lets it be used in initializers.
bool register_flatbuffer_loader() {
  return true;
}
942
943 } // namespace torch::jit
944