xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/flatbuffer_loader.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #ifdef FLATBUFFERS_VERSION_MAJOR
2 #error "flatbuffer_loader.h must not include any flatbuffers headers"
3 #endif // FLATBUFFERS_VERSION_MAJOR
4 
5 #include <array>
6 #include <istream>
7 #include <memory>
8 #include <string>
9 #include <tuple>
10 #include <unordered_map>
11 #include <unordered_set>
12 #include <utility>
13 #include <vector>
14 
15 #include <ATen/ATen.h>
16 #include <ATen/core/dynamic_type.h>
17 #include <ATen/core/ivalue.h>
18 #include <ATen/core/qualified_name.h>
19 #include <c10/core/CPUAllocator.h>
20 #include <c10/core/impl/alloc_cpu.h>
21 #include <c10/util/Exception.h>
22 #include <c10/util/ScopeExit.h>
23 #include <caffe2/serialize/inline_container.h>
24 #include <torch/csrc/jit/mobile/file_format.h>
25 #include <torch/csrc/jit/mobile/flatbuffer_loader.h>
26 #include <torch/csrc/jit/mobile/function.h>
27 #include <torch/csrc/jit/mobile/import.h>
28 #include <torch/csrc/jit/mobile/interpreter.h>
29 #include <torch/csrc/jit/mobile/module.h>
30 #include <torch/csrc/jit/mobile/observer.h>
31 #include <torch/csrc/jit/mobile/type_parser.h>
32 #include <torch/csrc/jit/runtime/instruction.h>
33 #include <torch/csrc/jit/serialization/export_bytecode.h>
34 #include <torch/csrc/jit/serialization/import_export_constants.h>
35 #include <torch/csrc/jit/serialization/import_read.h>
36 #include <torch/custom_class.h>
37 #include <optional>
38 
39 #ifndef DISABLE_UPGRADER
40 #include <torch/csrc/jit/mobile/parse_bytecode.h>
41 #include <torch/csrc/jit/mobile/upgrader_mobile.h>
42 #endif
43 
44 #ifdef _WIN32
45 #include <malloc.h>
46 #else
47 #include <cstdlib>
48 #endif
49 
50 #if defined(FB_XPLAT_BUILD) || defined(FBCODE_CAFFE2)
51 #include <torch/csrc/jit/serialization/mobile_bytecode_generated_fbsource.h> // NOLINT
52 namespace flatbuffers = flatbuffers_fbsource;
53 #define FLATBUFFERS_MAX_ALIGNMENT FLATBUFFERS_FBSOURCE_MAX_ALIGNMENT
54 #else
55 #include <torch/csrc/jit/serialization/mobile_bytecode_generated.h> // NOLINT
56 #endif
57 
58 namespace torch::jit {
59 
// Our own alignment requirement does not need to be exactly the same as what
// flatbuffers supports, but what flatbuffers supports needs to satisfy our
// requirement.
static_assert(
    kFlatbufferDataAlignmentBytes <= FLATBUFFERS_MAX_ALIGNMENT,
    "Sizes must be compatible");
// Alignment must be a power of two so the usual mask-based alignment
// arithmetic works; x & ~(x - 1) isolates the lowest set bit, which equals
// x only when exactly one bit is set.
static_assert(
    (kFlatbufferDataAlignmentBytes & ~(kFlatbufferDataAlignmentBytes - 1)) ==
        kFlatbufferDataAlignmentBytes,
    "Must be a power of 2");
70 
71 namespace {
72 
// Type-name prefixes used by resolveType() and
// getOrCreateClassTypeForObject() to decide how a serialized type string
// should be resolved: registered custom class, (mobile) TorchScript class,
// or a plain type expression parsed by c10::parseType().
static constexpr c10::string_view kCustomClassPrefix =
    "__torch__.torch.classes";
static constexpr c10::string_view kTorchPrefix = "__torch__";
static constexpr c10::string_view kJitPrefix = "torch.jit";
77 
// Deserializes a flatbuffer-encoded mobile module. The loader owns the
// per-parse caches (parsed ivalues, class types, storages) and a dispatch
// table that maps each serialized IValue union tag to a parser function.
class FlatbufferLoader final {
 public:
  FlatbufferLoader();

  // Parser callback for one alternative of the serialized IValue union.
  typedef IValue (
      *IValueParser)(FlatbufferLoader&, const mobile::serialization::IValue&);
  void registerIValueParser(
      mobile::serialization::IValueUnion ivalue_type,
      IValueParser parser);
  // Parses the whole module. `end` points one past the last valid byte of
  // the flatbuffer buffer and is used for corruption checks.
  mobile::Module parseModule(mobile::serialization::Module* module, char* end);

  // Parses the trailing (non-mobile) ivalues: jit sources and constants.
  // Requires a successful parseModule() first.
  void extractJitSourceAndConstants(
      ExtraFilesMap* jit_sources,
      std::vector<IValue>* constants);

  // Callback that maps a serialized type string to a TypePtr.
  using TypeResolver = TypePtr (*)(
      const std::string& type_str,
      const std::shared_ptr<CompilationUnit>& cu);

  void internal_registerTypeResolver(TypeResolver type_resolver);

  // Bounds-checked access to an already-parsed ivalue.
  IValue& getIValue(uint32_t pos) {
    TORCH_CHECK(pos < all_ivalues_.size());
    return all_ivalues_[pos];
  }

  // NOTE(review): unlike getIValue/getType this does no bounds or presence
  // check — looking up an index that was never populated default-inserts a
  // null entry into the map.
  mobile::Function* getFunction(uint32_t pos) {
    return all_functions_[pos];
  }

  // Bounds-checked access to a class type slot (may still be nullptr if
  // the type has not been created yet).
  ClassTypePtr getType(uint32_t pos) {
    TORCH_CHECK(pos < all_types_.size());
    return all_types_[pos];
  }

  c10::Storage getStorage(uint32_t index);
  TypePtr getOrCreateTypeAnnotations(const flatbuffers::String* offset);
  ClassTypePtr getOrCreateClassTypeForObject(
      const mobile::serialization::Object* object);

  const mobile::serialization::Module* getCurrentFlatbufferInput() {
    return module_;
  }

  // When true, tensor bytes are copied out of the buffer; when false,
  // Storages alias the buffer directly (see getStorage()).
  void setShouldCopyTensorMemory(bool should_copy_tensor_memory) {
    should_copy_tensor_memory_ = should_copy_tensor_memory;
  }

  std::shared_ptr<mobile::CompilationUnit> mcu_;
  std::shared_ptr<CompilationUnit> cu_;

 private:
  IValue parseIValue(const mobile::serialization::IValue* ivalue);
  std::unique_ptr<mobile::Function> parseFunction(
      const mobile::serialization::Function* method);
  void parseAndPopulate(
      uint32_t i,
      const mobile::serialization::IValue* ivalue);

  // Keyed by ivalue index; the pointed-to functions are owned by mcu_.
  std::unordered_map<uint32_t, mobile::Function*> all_functions_;
  std::vector<ClassTypePtr> all_types_;
  std::unordered_set<uint32_t> initialized_types_;
  // Cache of resolved type annotations, keyed by the flatbuffer string
  // pointer (stable for the lifetime of the buffer).
  std::unordered_map<const flatbuffers::String*, TypePtr> type_annotations_;
  std::vector<bool> storage_loaded_;
  std::vector<c10::Storage> storages_;
  std::vector<IValue> all_ivalues_;
  // Dispatch table indexed by the IValue union tag byte.
  std::array<
      IValueParser,
      static_cast<uint8_t>(mobile::serialization::IValueUnion::MAX) + 1>
      ivalue_parsers_;
  TypeResolver type_resolver_ = nullptr;
  mobile::serialization::Module* module_ = nullptr;
  bool module_parsed_ = false;
  bool should_copy_tensor_memory_ = false;
  // 0 -> mobile_ivalue_size_ elements are from the mobile module.
  uint32_t mobile_ivalue_size_ = 0;
};
155 
// Forward declarations of the per-union-tag parser callbacks registered in
// FlatbufferLoader's constructor (definitions below).
IValue parseList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseTensor(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseTuple(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseDict(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseObject(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseIntList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseDoubleList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseBoolList(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseBasic(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
IValue parseEnum(
    FlatbufferLoader&,
    const mobile::serialization::IValue& ivalue);
186 
resolveType(const std::string & type_string,const std::shared_ptr<CompilationUnit> & cu)187 TypePtr resolveType(
188     const std::string& type_string,
189     const std::shared_ptr<CompilationUnit>& cu) {
190   TypePtr type;
191   c10::string_view type_str(type_string);
192   if (type_str.starts_with(kCustomClassPrefix)) {
193     type = getCustomClass(type_string);
194     TORCH_CHECK(
195         type, "The implementation of class ", type_string, " cannot be found.");
196   } else if (
197       type_str.starts_with(kTorchPrefix) || type_str.starts_with(kJitPrefix)) {
198     c10::QualifiedName qn(type_string);
199     if (cu->get_class(qn) == nullptr) {
200       auto classtype = ClassType::create(qn, cu, true);
201       cu->register_type(classtype);
202       type = classtype;
203     } else {
204       type = cu->get_class(qn);
205     }
206   } else {
207     type = c10::parseType(type_string);
208   }
209   return type;
210 }
211 
FlatbufferLoader()212 FlatbufferLoader::FlatbufferLoader()
213     : mcu_(std::make_shared<mobile::CompilationUnit>()),
214       cu_(std::make_shared<CompilationUnit>()),
215       ivalue_parsers_{nullptr} {
216   registerIValueParser(mobile::serialization::IValueUnion::NONE, &parseBasic);
217   registerIValueParser(mobile::serialization::IValueUnion::Int, &parseBasic);
218   registerIValueParser(mobile::serialization::IValueUnion::Bool, &parseBasic);
219   registerIValueParser(mobile::serialization::IValueUnion::Double, &parseBasic);
220   registerIValueParser(
221       mobile::serialization::IValueUnion::ComplexDouble, &parseBasic);
222   registerIValueParser(
223       mobile::serialization::IValueUnion::TensorMetadata, &parseTensor);
224   registerIValueParser(mobile::serialization::IValueUnion::String, &parseBasic);
225   registerIValueParser(mobile::serialization::IValueUnion::List, &parseList);
226   registerIValueParser(
227       mobile::serialization::IValueUnion::IntList, &parseIntList);
228   registerIValueParser(
229       mobile::serialization::IValueUnion::DoubleList, &parseDoubleList);
230   registerIValueParser(
231       mobile::serialization::IValueUnion::BoolList, &parseBoolList);
232   registerIValueParser(mobile::serialization::IValueUnion::Tuple, &parseTuple);
233   registerIValueParser(mobile::serialization::IValueUnion::Dict, &parseDict);
234   registerIValueParser(
235       mobile::serialization::IValueUnion::Object, &parseObject);
236   registerIValueParser(mobile::serialization::IValueUnion::Device, &parseBasic);
237   registerIValueParser(
238       mobile::serialization::IValueUnion::EnumValue, &parseEnum);
239   internal_registerTypeResolver(&resolveType);
240 }
241 
registerIValueParser(mobile::serialization::IValueUnion ivalue_type,IValueParser parser)242 void FlatbufferLoader::registerIValueParser(
243     mobile::serialization::IValueUnion ivalue_type,
244     IValueParser parser) {
245   ivalue_parsers_[static_cast<uint8_t>(ivalue_type)] = parser;
246 }
247 
// Installs the callback used by getOrCreateTypeAnnotations() to turn
// serialized type strings into TypePtrs.
void FlatbufferLoader::internal_registerTypeResolver(
    TypeResolver type_resolver) {
  type_resolver_ = type_resolver;
}
252 
parseExtraFilesFromVector(const flatbuffers::Vector<flatbuffers::Offset<torch::jit::mobile::serialization::ExtraFile>> * files,ExtraFilesMap * extra_files)253 void parseExtraFilesFromVector(
254     const flatbuffers::Vector<flatbuffers::Offset<
255         torch::jit::mobile::serialization::ExtraFile>>* files,
256     ExtraFilesMap* extra_files) {
257   for (uint32_t i = 0; i < files->size(); ++i) {
258     const auto* extra_file = files->Get(i);
259     (*extra_files)[extra_file->name()->str()] = extra_file->content()->str();
260   }
261 }
262 
parseExtraFiles(mobile::serialization::Module * module,ExtraFilesMap & extra_files)263 void parseExtraFiles(
264     mobile::serialization::Module* module,
265     ExtraFilesMap& extra_files) {
266   auto extra_files_offsets = module->extra_files();
267   parseExtraFilesFromVector(extra_files_offsets, &extra_files);
268 }
269 
parseAndPopulate(uint32_t i,const mobile::serialization::IValue * ivalue)270 void FlatbufferLoader::parseAndPopulate(
271     uint32_t i,
272     const mobile::serialization::IValue* ivalue) {
273   if (const auto* func = ivalue->val_as_Function()) {
274     auto func_ptr = parseFunction(func);
275     all_functions_[i] = func_ptr.get();
276     mcu_->register_function(std::move(func_ptr));
277   } else {
278     all_ivalues_[i] = parseIValue(ivalue);
279   }
280 }
281 
// Parses the root flatbuffer Module into a mobile::Module. `end` points
// one past the last byte of the buffer and is used for cheap corruption
// checks on the embedded offsets.
mobile::Module FlatbufferLoader::parseModule(
    mobile::serialization::Module* module,
    char* end) {
  module_ = module;
  // Reset per-parse state so the loader can be reused.
  all_ivalues_.clear();
  all_types_.clear();
  storages_.clear();
  storage_loaded_.clear();
  module_parsed_ = false;

  const auto* ivalues = module->ivalues();
  TORCH_CHECK(
      ivalues && module->object_types(),
      "Parsing flatbuffer module: Corrupted ivalues/object_types field");
  TORCH_CHECK(
      reinterpret_cast<const char*>(ivalues) < end, "Corrupted ivalues field");
  TORCH_CHECK(
      module->storage_data_size() >= 0,
      "Parsing flatbuffer module: illegal storage_data_size: ",
      module->storage_data_size(),
      ", expected to be non negative");
  all_ivalues_.resize(ivalues->size());
  all_types_.resize(module->object_types()->size());
  storages_.resize(module->storage_data_size());
  storage_loaded_.resize(module->storage_data_size(), false);

  // Only the first mobile_ivalue_size_ entries belong to the mobile module
  // proper; any remainder is jit sources/constants handled later by
  // extractJitSourceAndConstants(). A serialized value of 0 or one larger
  // than the table means "all of them".
  mobile_ivalue_size_ = module_->mobile_ivalue_size();
  if (mobile_ivalue_size_ == 0 || mobile_ivalue_size_ > ivalues->size()) {
    mobile_ivalue_size_ = ivalues->size();
  }

  for (uint32_t i = 0; i < mobile_ivalue_size_; i++) {
    const auto* ival = ivalues->Get(i);
    TORCH_CHECK(
        reinterpret_cast<const char*>(ival) < end, "Corrupted ivalue item")
    parseAndPopulate(i, ival);
  }
  IValue& module_ivalue = getIValue(module->state_obj());

  // register functions
  // Attach each parsed function to the class type it was serialized with.
  for (const auto& f : all_functions_) {
    uint32_t class_index =
        ivalues->Get(f.first)->val_as_Function()->class_type();
    ClassTypePtr class_type = all_types_[class_index];
    class_type->addMethod(f.second);
  }

  module_parsed_ = true;
  auto m = mobile::Module(module_ivalue.toObject(), mcu_);
  m.set_min_operator_version(module->operator_version());
  m.set_bytecode_version(module->bytecode_version());
  return m;
}
335 
// Appends the bundled bytecode-upgrader functions to `function` so that
// bytecode produced by older operator versions can be remapped. No-op when
// the build defines DISABLE_UPGRADER.
void appendUpgraderFunctions(mobile::Function* function) {
#ifndef DISABLE_UPGRADER
  for (auto& byteCodeFunctionWithOperator : getUpgraderBytecodeList()) {
    function->append_function(byteCodeFunctionWithOperator.function);
  }
#endif
}
343 
// Builds a mobile::Function from its serialized form: instructions,
// constants, operators, type annotations, register size and (best-effort)
// its schema.
std::unique_ptr<mobile::Function> FlatbufferLoader::parseFunction(
    const mobile::serialization::Function* method) {
  auto function = std::make_unique<mobile::Function>(
      c10::QualifiedName(method->qn()->str()));
  // TODO(qihan) add debug handle
  // const auto* debug_handle = method->debug_info()->debug_handle();
  for (const auto* inst : *method->instructions()) {
    function->append_instruction(
        static_cast<OpCode>(inst->op()), inst->x(), inst->n());
  }

  // Constants are stored as indices into the module-level ivalue table.
  for (uint32_t i : *method->constants()) {
    function->append_constant(getIValue(i));
  }

  appendUpgraderFunctions(function.get());
  // 2. Decides if upgrader is needed
  const uint32_t operator_version = module_->operator_version();
  bool use_upgrader =
      (operator_version < caffe2::serialize::kProducedFileFormatVersion);

  for (const auto* op : *method->operators()) {
    // A serialized value of -1 means "number of arguments unknown".
    std::optional<int> num_args = std::nullopt;
    if (op->num_args_serialized() > -1) {
      num_args = op->num_args_serialized();
    }

    function->append_operator(
        op->name()->str(), op->overload_name()->str(), num_args);
  }

  function->initialize_operators(true);

  for (const auto i : *method->type_annotations()) {
    function->append_type(getOrCreateTypeAnnotations(i));
  }

  // 3. If upgrader is needed, change the OP instruction to a CALL
  // instruction (In next PR, use_upgrader will be parsed to parseInstruction
  // function and do the actual change)
  if (use_upgrader) {
#ifndef DISABLE_UPGRADER
    applyUpgrader(function.get(), operator_version);
#endif
  }

  function->set_register_size(method->register_size());
  if (method->schema()) {
    try {
      // Turns a serialized argument list into c10::Arguments, resolving
      // each argument's type annotation and default value.
      auto parseArgList = [this](const auto* args_fb) {
        std::vector<c10::Argument> args;
        for (const auto* arg_tb : *args_fb) {
          IValue default_value = getIValue(arg_tb->default_value());
          TypePtr type_ptr = getOrCreateTypeAnnotations(arg_tb->type());
          auto arg = c10::Argument(
              arg_tb->name()->str(),
              std::move(type_ptr),
              std::nullopt /*N*/,
              std::move(default_value));
          args.emplace_back(std::move(arg));
        }
        return args;
      };
      c10::FunctionSchema schema(
          method->qn()->str(),
          "" /*overload_name*/,
          parseArgList(method->schema()->arguments()),
          parseArgList(method->schema()->returns()),
          false /*is_varargs*/,
          false /*is_varret*/);

      function->setSchema(std::move(schema));
    } catch (const c10::Error& e) {
      // Schema resolution is best-effort: a function without a schema is
      // still runnable, so errors here are deliberately swallowed.
    }
  }
  return function;
}
421 
// Reconstructs an enum IValue: resolves the serialized enum type by name,
// then matches the serialized value against the type's name/value pairs to
// rebuild the EnumHolder.
IValue parseEnum(
    FlatbufferLoader& loader,
    const mobile::serialization::IValue& ivalue) {
  const auto* enum_val = ivalue.val_as_EnumValue();
  auto enum_type = loader.getOrCreateTypeAnnotations(enum_val->type_name())
                       ->cast<c10::EnumType>();
  AT_ASSERT(
      enum_type,
      "Enum with type: " + enum_val->type_name()->str() + " not found.");
  IValue val = loader.getIValue(enum_val->value());
  for (const auto& p : enum_type->enumNamesValues()) {
    if (p.second == val) {
      auto enum_holder = c10::make_intrusive<at::ivalue::EnumHolder>(
          enum_type, p.first, p.second);
      return IValue(std::move(enum_holder));
    }
  }
  // Value not present in the enum's name/value list: only reachable for a
  // malformed module.
  AT_ASSERT(
      false, "Enum with type: " + enum_val->type_name()->str() + " not found.");
}
442 
parseBasic(FlatbufferLoader &,const mobile::serialization::IValue & ivalue)443 IValue parseBasic(
444     FlatbufferLoader&,
445     const mobile::serialization::IValue& ivalue) {
446   switch (ivalue.val_type()) {
447     case mobile::serialization::IValueUnion::NONE:
448       return {};
449     case mobile::serialization::IValueUnion::Int:
450       return ivalue.val_as_Int()->int_val();
451     case mobile::serialization::IValueUnion::Bool:
452       return ivalue.val_as_Bool()->bool_val();
453     case mobile::serialization::IValueUnion::Double:
454       return ivalue.val_as_Double()->double_val();
455     case mobile::serialization::IValueUnion::ComplexDouble: {
456       const auto* comp = ivalue.val_as_ComplexDouble();
457       return c10::complex<double>(comp->real(), comp->imag());
458     }
459     case mobile::serialization::IValueUnion::String:
460       return ivalue.val_as_String()->data()->str();
461     case mobile::serialization::IValueUnion::Device: {
462       return c10::Device(ivalue.val_as_Device()->str()->str());
463     }
464     default:
465       return {};
466   }
467 }
468 
// Materializes a CPU at::Tensor from its serialized metadata. Quantized
// tensors are recreated with their qscheme parameters (recursing for the
// per-channel scales/zero-points tensors); the actual bytes come from the
// shared storage table via loader->getStorage().
at::Tensor parseTensorFromMetadata(
    FlatbufferLoader* loader,
    const mobile::serialization::TensorMetadata* tensor_md) {
  at::ScalarType type = static_cast<at::ScalarType>(tensor_md->scalar_type());
  auto options = at::CPU(type).options();
  at::Tensor tensor;
  if (tensor_md->quantized_schema() != nullptr) {
    // is quantized
    const auto* schema = tensor_md->quantized_schema();
    auto qscheme_type = static_cast<at::QScheme>(schema->qscheme());
    switch (qscheme_type) {
      case at::kPerTensorAffine: {
        tensor = at::_empty_affine_quantized(
            {0}, options, schema->scale(), schema->zero_point());
      } break;
      case at::kPerChannelAffineFloatQParams:
      case at::kPerChannelAffine: {
        // Per-channel params are themselves serialized tensors.
        at::Tensor scales = parseTensorFromMetadata(loader, schema->scales());
        at::Tensor zero_points =
            parseTensorFromMetadata(loader, schema->zero_points());
        tensor = at::_empty_per_channel_affine_quantized(
            {0}, scales, zero_points, schema->axis(), options);
      } break;
      default:
        TORCH_CHECK(
            false,
            "Unsupported tensor quantization type in serialization ",
            toString(qscheme_type));
        break;
    }
  } else {
    tensor = at::empty({0}, options);
  }
  at::TensorImpl* impl = tensor.unsafeGetTensorImpl();

  // Swap in the deserialized storage (keeping the dtype chosen above),
  // then restore offset, sizes and strides from the metadata.
  c10::Storage storage;
  storage = loader->getStorage(tensor_md->storage_location_index());
  impl->set_storage_keep_dtype(storage);
  impl->set_storage_offset(tensor_md->storage_offset());

  std::vector<int64_t> size{
      tensor_md->sizes()->begin(), tensor_md->sizes()->end()};
  std::vector<int64_t> stride{
      tensor_md->strides()->begin(), tensor_md->strides()->end()};
  impl->set_sizes_and_strides(size, stride);
#ifndef MIN_EDGE_RUNTIME
  tensor = autograd::make_variable(tensor, tensor_md->requires_grad());
#endif
  return tensor;
}
519 
parseTensor(FlatbufferLoader & loader,const mobile::serialization::IValue & ivalue)520 IValue parseTensor(
521     FlatbufferLoader& loader,
522     const mobile::serialization::IValue& ivalue) {
523   const mobile::serialization::TensorMetadata* tensor_md =
524       ivalue.val_as_TensorMetadata();
525   return parseTensorFromMetadata(&loader, tensor_md);
526 }
527 
parseList(FlatbufferLoader & loader,const mobile::serialization::IValue & ivalue)528 IValue parseList(
529     FlatbufferLoader& loader,
530     const mobile::serialization::IValue& ivalue) {
531   const mobile::serialization::List* list = ivalue.val_as_List();
532   auto res = c10::impl::GenericList(AnyType::get());
533   for (auto i : *list->items()) {
534     res.emplace_back(loader.getIValue(i));
535   }
536   auto type = loader.getOrCreateTypeAnnotations(list->annotation_str());
537   res.unsafeSetElementType(type->containedType(0));
538   return res;
539 }
540 
// Copies a serialized primitive list (IntList/DoubleList/BoolList payload)
// into a std::vector<T>.
template <typename T, typename U>
std::vector<T> parseListNative(const U* list) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(list != nullptr);
  const auto* items = list->items();
  return std::vector<T>(items->begin(), items->end());
}
546 
parseIntList(FlatbufferLoader &,const mobile::serialization::IValue & ivalue)547 IValue parseIntList(
548     FlatbufferLoader&,
549     const mobile::serialization::IValue& ivalue) {
550   const auto& list = ivalue.val_as_IntList();
551   return parseListNative<int64_t>(list);
552 }
553 
parseDoubleList(FlatbufferLoader &,const mobile::serialization::IValue & ivalue)554 IValue parseDoubleList(
555     FlatbufferLoader&,
556     const mobile::serialization::IValue& ivalue) {
557   const auto& list = ivalue.val_as_DoubleList();
558   return parseListNative<double>(list);
559 }
560 
parseBoolList(FlatbufferLoader &,const mobile::serialization::IValue & ivalue)561 IValue parseBoolList(
562     FlatbufferLoader&,
563     const mobile::serialization::IValue& ivalue) {
564   const auto& list = ivalue.val_as_BoolList();
565   std::vector<uint8_t> res = parseListNative<uint8_t>(list);
566   c10::List<bool> boollist;
567   for (auto x : res) {
568     boollist.push_back(x);
569   }
570   return boollist;
571 }
572 
parseTuple(FlatbufferLoader & loader,const mobile::serialization::IValue & ivalue)573 IValue parseTuple(
574     FlatbufferLoader& loader,
575     const mobile::serialization::IValue& ivalue) {
576   const auto& tuple = ivalue.val_as_Tuple();
577   const auto items = tuple->items();
578   std::vector<IValue> res;
579   res.reserve(items->size());
580   for (auto i : *items) {
581     res.emplace_back(loader.getIValue(i));
582   }
583   return c10::ivalue::Tuple::create(std::move(res));
584 }
585 
parseDict(FlatbufferLoader & loader,const mobile::serialization::IValue & ivalue)586 IValue parseDict(
587     FlatbufferLoader& loader,
588     const mobile::serialization::IValue& ivalue) {
589   const auto* dict = ivalue.val_as_Dict();
590   auto result = c10::impl::GenericDict(AnyType::get(), AnyType::get());
591   const auto* keys = dict->keys();
592   const auto* values = dict->values();
593   for (size_t i = 0; i < keys->size(); ++i) {
594     uint32_t key = keys->Get(i);
595     uint32_t val = values->Get(i);
596     result.insert_or_assign(loader.getIValue(key), loader.getIValue(val));
597   }
598   auto type = loader.getOrCreateTypeAnnotations(dict->annotation_str());
599   result.unsafeSetKeyType(type->containedType(0));
600   result.unsafeSetValueType(type->containedType(1));
601   return result;
602 }
603 
getOrCreateClassTypeForObject(const mobile::serialization::Object * object)604 ClassTypePtr FlatbufferLoader::getOrCreateClassTypeForObject(
605     const mobile::serialization::Object* object) {
606   auto cls = getType(object->type_index());
607   const mobile::serialization::ObjectType* obj_type =
608       module_->object_types()->Get(object->type_index());
609   if (cls == nullptr) {
610     c10::string_view qn_str(
611         obj_type->type_name()->c_str(), obj_type->type_name()->size());
612     if (qn_str.starts_with(kTorchPrefix) || qn_str.starts_with(kJitPrefix)) {
613       c10::QualifiedName qn(obj_type->type_name()->str());
614       cls = cu_->get_class(qn);
615       if (cls == nullptr) {
616         cls = ClassType::create(qn, cu_, true);
617         cu_->register_type(cls);
618       }
619     } else {
620       cls = c10::parseType(std::string(qn_str))->cast<ClassType>();
621     }
622     TORCH_CHECK(object->type_index() < all_ivalues_.size());
623     all_types_[object->type_index()] = cls;
624 
625     if (obj_type->type() == mobile::serialization::TypeType::CLASS_WITH_FIELD) {
626       for (uint32_t i = 0; i < object->attrs()->size(); i++) {
627         IValue val = getIValue(object->attrs()->Get(i));
628         // Need to use concrete object's field's type to set type of field.
629         cls->addAttribute(
630             obj_type->attr_names()->Get(i)->str(),
631             val.type<c10::DynamicType>());
632       }
633     }
634     initialized_types_.insert(object->type_index());
635   }
636   return cls;
637 }
638 
// Reconstructs an object IValue. Three serialized layouts are supported:
//  - CLASS_WITH_FIELD: attribute values stored slot-by-slot;
//  - CLASS_WITH_SETSTATE: a state blob passed to the class's serialized
//    setstate function;
//  - CUSTOM_CLASS: a state blob passed to the registered custom class's
//    __setstate__ method.
IValue parseObject(
    FlatbufferLoader& loader,
    const mobile::serialization::IValue& ivalue) {
  const mobile::serialization::Object* object = ivalue.val_as_Object();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(object != nullptr);
  const auto* cur_input = loader.getCurrentFlatbufferInput();
  const mobile::serialization::ObjectType* obj_type =
      cur_input->object_types()->Get(object->type_index());
  auto cls = loader.getOrCreateClassTypeForObject(object);
  Stack stack;
  switch (obj_type->type()) {
    case mobile::serialization::TypeType::CLASS_WITH_FIELD: {
      auto obj = c10::ivalue::Object::create(
          at::StrongTypePtr(loader.cu_, cls), object->attrs()->size());
      // Attribute values are indices into the ivalue table, in slot order.
      for (uint32_t i = 0; i < object->attrs()->size(); i++) {
        IValue val = loader.getIValue(object->attrs()->Get(i));
        obj->setSlot(i, std::move(val));
      }
      return obj;
    }
    case mobile::serialization::TypeType::CLASS_WITH_SETSTATE: {
      IValue input = loader.getIValue(object->state());
      mobile::Function* setstate = loader.getFunction(object->setstate_func());
      auto obj =
          c10::ivalue::Object::create(at::StrongTypePtr(loader.cu_, cls), 0);
      // Run setstate(obj, state) to populate the freshly created object.
      stack.emplace_back(obj);
      stack.emplace_back(std::move(input));
      setstate->run(stack);
      return obj;
    }
    case mobile::serialization::TypeType::CUSTOM_CLASS: {
      auto custom_class_type =
          torch::jit::getCustomClass(cls->name()->qualifiedName());
      IValue input = loader.getIValue(object->state());
      auto obj = c10::ivalue::Object::create(
          c10::StrongTypePtr(nullptr, custom_class_type), 1);
      stack.emplace_back(obj);
      stack.emplace_back(std::move(input));
      custom_class_type->getMethod("__setstate__").run(stack);
      return obj;
    }
    default:
      AT_ASSERT(false, "need to be object");
  }
}
684 
parseIValue(const mobile::serialization::IValue * ivalue)685 IValue FlatbufferLoader::parseIValue(
686     const mobile::serialization::IValue* ivalue) {
687   return ivalue_parsers_[static_cast<uint32_t>(ivalue->val_type())](
688       *this, *ivalue);
689 }
690 
// No-op deleter for DataPtrs that alias memory owned by the flatbuffer
// buffer itself (the should_copy_tensor_memory_ == false path in
// getStorage()): the Storage must not free bytes it does not own.
void deleteNothing2(void*);
void deleteNothing2(void*) {}
693 
// Returns the Storage at `index`, creating it lazily on first access.
// Depending on should_copy_tensor_memory_ the bytes are either copied into
// freshly allocated CPU memory, or aliased directly into the flatbuffer
// buffer with a no-op deleter. NOTE(review): in the aliasing case the
// Storage is only valid while the flatbuffer buffer stays alive — callers
// must keep the buffer around (cf. set_delete_memory in the shared_ptr
// overload of parse_and_initialize_mobile_module).
c10::Storage FlatbufferLoader::getStorage(uint32_t index) {
  TORCH_CHECK(index < storage_loaded_.size());
  TORCH_CHECK(index < storages_.size());
  if (!storage_loaded_[index]) {
    auto* storage = module_->storage_data()->GetMutableObject(index);
    size_t size = storage->data()->size();

    at::DataPtr data;
    if (should_copy_tensor_memory_) {
      auto* allocator = at::GetCPUAllocator();
      data = allocator->allocate(size);
      memcpy(data.get(), storage->data()->data(), size);
    } else {
      // Zero-copy: point straight into the flatbuffer buffer.
      void* ptr = static_cast<void*>(storage->mutable_data()->data());
      data = at::DataPtr(ptr, ptr, deleteNothing2, DeviceType::CPU);
    }
    storages_[index] =
        c10::Storage(c10::Storage::use_byte_size_t(), size, std::move(data));
    storage_loaded_[index] = true;
  }
  return storages_[index];
}
716 
getOrCreateTypeAnnotations(const flatbuffers::String * offset)717 TypePtr FlatbufferLoader::getOrCreateTypeAnnotations(
718     const flatbuffers::String* offset) {
719   auto iter = type_annotations_.find(offset);
720   if (iter != type_annotations_.end()) {
721     return iter->second;
722   }
723   TypePtr type = type_resolver_(offset->str(), cu_);
724   type_annotations_[offset] = type;
725   return type;
726 }
727 
// Parses the ivalues beyond mobile_ivalue_size_ (the jit-only tail of the
// table) and fills `jit_sources` and `constants`. Must be called after a
// successful parseModule() on the same loader.
void FlatbufferLoader::extractJitSourceAndConstants(
    ExtraFilesMap* jit_sources,
    std::vector<IValue>* constants) {
  AT_ASSERT(
      module_parsed_,
      "Need to first parse a flatbuffer file before extracting jit_sources");

  const auto* ivalues = module_->ivalues();
  for (uint32_t i = mobile_ivalue_size_; i < ivalues->size(); i++) {
    const auto* ival = ivalues->Get(i);
    parseAndPopulate(i, ival);
  }
  // register functions
  // Only functions discovered in this pass (index >= mobile_ivalue_size_)
  // still need attaching; earlier ones were handled by parseModule().
  for (const auto& f : all_functions_) {
    if (f.first >= mobile_ivalue_size_) {
      uint32_t class_index =
          ivalues->Get(f.first)->val_as_Function()->class_type();
      ClassTypePtr class_type = all_types_[class_index];
      class_type->addMethod(f.second);
    }
  }
  const auto* jit_constants = module_->jit_constants();
  for (const auto i : c10::irange(jit_constants->size())) {
    constants->emplace_back(getIValue(jit_constants->Get(i)));
  }
  parseExtraFilesFromVector(module_->jit_sources(), jit_sources);
}
755 
756 } // namespace
757 
// Parses a flatbuffer-serialized mobile module from a raw buffer.
// When should_copy_tensor_memory is false, tensor storages may alias
// `data` directly, so the buffer must outlive the returned module.
mobile::Module parse_and_initialize_mobile_module(
    void* data,
    size_t size,
    std::optional<at::Device>,
    ExtraFilesMap* extra_files,
    bool should_copy_tensor_memory) {
  // TODO(T128189662): If not copying, enforce that data is aligned to
  // kFlatbufferDataAlignmentBytes, and add unit tests.

  // Validate Flatbuffer module before parsing.
  flatbuffers::Verifier verifier(reinterpret_cast<uint8_t*>(data), size);
  TORCH_CHECK(
      mobile::serialization::VerifyModuleBuffer(verifier),
      "Malformed Flatbuffer module");

  FlatbufferLoader loader;
  loader.setShouldCopyTensorMemory(should_copy_tensor_memory);

  // Flatbuffer doesn't seem to have a way to provide the buffer size when
  // interacting with the buffer.
  auto* flatbuffer_module = mobile::serialization::GetMutableModule(data);
  auto* end = static_cast<char*>(data) + size;
  mobile::Module m = loader.parseModule(flatbuffer_module, end);
  if (extra_files != nullptr) {
    parseExtraFiles(flatbuffer_module, *extra_files);
  }
  return m;
}
786 
parse_and_initialize_mobile_module(std::shared_ptr<char> data,size_t size,std::optional<at::Device> device,ExtraFilesMap * extra_files)787 mobile::Module parse_and_initialize_mobile_module(
788     std::shared_ptr<char> data,
789     size_t size,
790     std::optional<at::Device> device,
791     ExtraFilesMap* extra_files) {
792   mobile::Module m = parse_and_initialize_mobile_module(
793       data.get(),
794       size,
795       device,
796       extra_files,
797       /*should_copy_tensor_memory=*/false);
798   m.set_delete_memory(std::move(data));
799   return m;
800 }
801 
parse_and_initialize_mobile_module_for_jit(void * data,size_t size,ExtraFilesMap & jit_sources,std::vector<IValue> & jit_constants,std::optional<at::Device>,ExtraFilesMap * extra_files)802 mobile::Module parse_and_initialize_mobile_module_for_jit(
803     void* data,
804     size_t size,
805     ExtraFilesMap& jit_sources,
806     std::vector<IValue>& jit_constants,
807     std::optional<at::Device>,
808     ExtraFilesMap* extra_files) {
809   TORCH_CHECK(
810       mobile::serialization::ModuleBufferHasIdentifier(data), "Format error");
811   // TODO(T128189662): Enforce that data is aligned to
812   // kFlatbufferDataAlignmentBytes, and add unit tests.
813 
814   // Validate Flatbuffer module before parsing.
815   flatbuffers::Verifier verifier(reinterpret_cast<uint8_t*>(data), size);
816   TORCH_CHECK(
817       mobile::serialization::VerifyModuleBuffer(verifier),
818       "Malformed Flatbuffer module");
819 
820   FlatbufferLoader loader;
821   auto* flatbuffer_module = mobile::serialization::GetMutableModule(data);
822   auto* end = static_cast<char*>(data) + size;
823   mobile::Module m = loader.parseModule(flatbuffer_module, end);
824   if (extra_files != nullptr) {
825     parseExtraFiles(flatbuffer_module, *extra_files);
826   }
827 
828   loader.extractJitSourceAndConstants(&jit_sources, &jit_constants);
829   return m;
830 }
831 
load_mobile_module_from_file(const std::string & filename,std::optional<c10::Device> device,ExtraFilesMap * extra_files)832 mobile::Module load_mobile_module_from_file(
833     const std::string& filename,
834     std::optional<c10::Device> device,
835     ExtraFilesMap* extra_files) {
836   auto [data, size] = get_file_content(filename.c_str());
837   return parse_and_initialize_mobile_module(
838       std::move(data), size, device, extra_files);
839 }
840 
get_bytecode_version(std::istream & in)841 uint64_t get_bytecode_version(std::istream& in) {
842   auto [data, size] = get_stream_content(in);
843   return get_bytecode_version_from_bytes(data.get());
844 }
845 
get_bytecode_version(const std::string & filename)846 uint64_t get_bytecode_version(const std::string& filename) {
847   auto [data, size] = get_file_content(filename.c_str());
848   return get_bytecode_version_from_bytes(data.get());
849 }
850 
get_bytecode_version_from_bytes(char * flatbuffer_content)851 uint64_t get_bytecode_version_from_bytes(char* flatbuffer_content) {
852   TORCH_CHECK(
853       mobile::serialization::ModuleBufferHasIdentifier(flatbuffer_content),
854       "Format error");
855   auto* flatbuffer_module =
856       mobile::serialization::GetMutableModule(flatbuffer_content);
857   return flatbuffer_module->bytecode_version();
858 }
859 
get_module_info_from_flatbuffer(char * flatbuffer_content)860 mobile::ModuleInfo get_module_info_from_flatbuffer(char* flatbuffer_content) {
861   auto* ff_module = mobile::serialization::GetMutableModule(flatbuffer_content);
862   mobile::ModuleInfo minfo;
863   minfo.operator_version = ff_module->operator_version();
864   minfo.bytecode_version = ff_module->bytecode_version();
865 
866   uint32_t mobile_ivalue_size = ff_module->mobile_ivalue_size();
867   if (mobile_ivalue_size == 0) {
868     mobile_ivalue_size = ff_module->ivalues()->size();
869   }
870 
871   std::vector<std::string> type_name_list;
872   for (uint32_t i = 0; i < mobile_ivalue_size; i++) {
873     const auto* ival = ff_module->ivalues()->Get(i);
874     if (const auto* func = ival->val_as_Function()) {
875       minfo.function_names.insert(func->qn()->str());
876       for (const auto* op : *func->operators()) {
877         at::OperatorName opname(op->name()->str(), op->overload_name()->str());
878         minfo.opname_to_num_args[mobile::operator_str(opname)] =
879             op->num_args_serialized();
880       }
881       for (const auto* type_ann : *func->type_annotations()) {
882         type_name_list.push_back(type_ann->str());
883       }
884     }
885   }
886   c10::TypeParser parser(type_name_list);
887   parser.parseList();
888   minfo.type_names = parser.getContainedTypes();
889   return minfo;
890 }
891 
load_mobile_module_from_stream_with_copy(std::istream & in,std::optional<at::Device> device,ExtraFilesMap * extra_files)892 mobile::Module load_mobile_module_from_stream_with_copy(
893     std::istream& in,
894     std::optional<at::Device> device,
895     ExtraFilesMap* extra_files) {
896   auto [data, size] = get_stream_content(in);
897   return parse_and_initialize_mobile_module(
898       std::move(data), size, device, extra_files);
899 }
900 
parse_flatbuffer_no_object(std::shared_ptr<char> data,size_t size,std::optional<at::Device> device)901 mobile::Module parse_flatbuffer_no_object(
902     std::shared_ptr<char> data,
903     size_t size,
904     std::optional<at::Device> device) {
905   (void)device;
906   (void)size;
907 
908   // Validate Flatbuffer module before parsing.
909   flatbuffers::Verifier verifier(reinterpret_cast<uint8_t*>(data.get()), size);
910   TORCH_CHECK(
911       mobile::serialization::VerifyModuleBuffer(verifier),
912       "Malformed Flatbuffer module");
913 
914   auto* flatbuffer_module = mobile::serialization::GetMutableModule(data.get());
915   FlatbufferLoader loader;
916   // replace parserObject with to handle only class with field case
917   // function.
918   loader.registerIValueParser(
919       mobile::serialization::IValueUnion::Object,
920       +[](FlatbufferLoader& loader,
921           const mobile::serialization::IValue& ivalue) {
922         const mobile::serialization::Object* object = ivalue.val_as_Object();
923         auto cls = loader.getOrCreateClassTypeForObject(object);
924         auto obj = c10::ivalue::Object::create(
925             at::StrongTypePtr(loader.cu_, cls), object->attrs()->size());
926         for (uint32_t i = 0; i < object->attrs()->size(); i++) {
927           IValue val = loader.getIValue(object->attrs()->Get(i));
928           obj->setSlot(i, std::move(val));
929         }
930         return static_cast<c10::IValue>(obj);
931       });
932 
933   auto* end = data.get() + size;
934   mobile::Module m = loader.parseModule(flatbuffer_module, end);
935   m.set_delete_memory(std::move(data));
936   return m;
937 }
938 
// No-op registration hook: gives callers a symbol to reference so the
// linker retains this translation unit (and its flatbuffer loading
// support) in the final binary. Always succeeds.
bool register_flatbuffer_loader() {
  return true;
}
942 
943 } // namespace torch::jit
944