xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/import.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/mobile/import.h>
2 #include <torch/csrc/jit/mobile/parse_bytecode.h>
3 #include <torch/csrc/jit/mobile/parse_operators.h>
4 
5 #include <ATen/core/ivalue.h>
6 #include <ATen/core/qualified_name.h>
7 #include <c10/util/Exception.h>
8 #include <c10/util/ScopeExit.h>
9 #include <c10/util/irange.h>
10 #include <caffe2/serialize/in_memory_adapter.h>
11 #include <caffe2/serialize/inline_container.h>
12 #include <caffe2/serialize/istream_adapter.h>
13 #include <caffe2/serialize/versions.h>
14 #include <torch/csrc/jit/api/compilation_unit.h>
15 #include <torch/csrc/jit/mobile/file_format.h>
16 #include <torch/csrc/jit/mobile/flatbuffer_loader.h>
17 #include <torch/csrc/jit/mobile/observer.h>
18 #include <torch/csrc/jit/mobile/type_parser.h>
19 #include <torch/csrc/jit/mobile/upgrader_mobile.h>
20 #include <torch/csrc/jit/runtime/instruction.h>
21 #include <torch/csrc/jit/serialization/import_export_constants.h>
22 #include <torch/csrc/jit/serialization/import_export_functions.h>
23 #include <torch/csrc/jit/serialization/import_read.h>
24 #include <torch/custom_class.h>
25 #include <optional>
26 #include <string>
27 #include <vector>
28 
29 // The import process to serialize the bytecode package.
30 // An example for bytecode.pkl of a small mobile_module looks like:
31 // (4,  # model version number (caffe2::serialize::kProducedBytecodeVersion)
32 //  # first method
33 //  (
34 //   # function name
35 //   '__torch__.m.forward',
36 //   # code
37 //   (('instructions',
38 //     (('STOREN', 1, 2),
39 //      ('DROPR', 1, 0),
40 //      ('MOVE', 2, 0),
41 //      ('OP', 0, 0),
42 //      ('RET', 0, 0))),
43 //    ('operators', (('aten::Int', 'Tensor'),)),
44 //    ('constants', ()),
45 //    ('types', ()),
46 //    ('register_size', 2)),
47 //   # schema -- optional (forward-compatible addition to version 4)
48 //   (('arguments',
49 //     ((('name', 'x'), ('type', 'Tensor'), ('default_value', 13)),
50 //      ...)),  # more args follow here
51 //    ('returns',
52 //     ((('name', ''), ('type', 'Tensor'), ('default_value', None)),
53 //      ...)),  # more return values follow here
54 //   )),
55 //  # more methods follow here
56 //  ...)
57 
58 // In addition, the module debugging information can be saved
59 // in mobile_debug_handles.pkl. An example for it looks like:
60 // (4,
61 //  ('__torch__.m.forward',
62 //   (('module_debug_handles', 10))))
63 //   Here 10 is the debug handle.
64 // We also store separately and optionally callstack_debug_map.
65 // This serializes inlined callstack (InlinedCallStack data structure)
66 // corresponding to the debug handles.
67 // Callstack_debug_map serializes tuples of
68 // (int64_t(debug_handle), int64_t(source_range_tag), InlinedCallStack)
69 // source_range_tag maps to .debug_pkl files where this tag maps it to
70 // source range.
71 // InlinedCallStack is serialized as:
72 // IValue(InlinedCallStack) = {IValue(ModuleInstanceInfo),
73 // int64_t(source_range_tag), IValue(InlinedCallStack)} ModuleInstanceInfo is
74 // serialized as a tuple of (class_type_name, instance_name)
75 
76 // Note that currently the backward compatibility is not supported by bytecode.
77 // This format and process need to be revisited and redesigned if we want to
78 // support backward compatibility in future.
79 
80 // Note that the following function-schema fields are not supported:
81 //  - Argument::{known_length_,kwarg_only_}
82 //  - FunctionSchema::{overload_name_, is_vararg_, is_varret_}
83 
84 namespace torch::jit {
85 using caffe2::serialize::MemoryReadAdapter;
86 using caffe2::serialize::PyTorchStreamReader;
87 using caffe2::serialize::ReadAdapterInterface;
88 
89 OpCode parseOpCode(const char* str);
90 
resolveTypeNameMobile(const c10::QualifiedName & qn,const std::shared_ptr<CompilationUnit> & compilation_unit)91 TypePtr resolveTypeNameMobile(
92     const c10::QualifiedName& qn,
93     const std::shared_ptr<CompilationUnit>& compilation_unit) {
94   // HACK: first we check whether the name starts with special prefix to
95   // tell if it's a supported pytorch class type. There are two special
96   // prefixes. "__torch__" for nn module, and "torch.jit" from to_backend.
97   // This is a reliable
98   // check today, but there is no guarantee that this is the case. The
99   // real solution is to merge type parsers so we can share class
100   // resolution logic.
101   static const c10::QualifiedName torchPrefix = "__torch__";
102   static const c10::QualifiedName jitPrefix = "torch.jit";
103   if (torchPrefix.isPrefixOf(qn) || jitPrefix.isPrefixOf(qn)) {
104     if (compilation_unit->get_class(qn) == nullptr) {
105       auto typeptr = ClassType::create(qn, compilation_unit, true);
106       compilation_unit->register_type(typeptr);
107     }
108     return compilation_unit->get_class(qn);
109   } else {
110     return c10::parseType(qn.qualifiedName());
111   }
112 }
113 
c10::StrongTypePtr typeResolverMobile(
    const c10::QualifiedName& qn,
    const std::shared_ptr<CompilationUnit>& compilation_unit) {
  // Resolve the name, then pair the resulting type with the compilation
  // unit that keeps it alive.
  auto resolved_type = resolveTypeNameMobile(qn, compilation_unit);
  return c10::StrongTypePtr(compilation_unit, std::move(resolved_type));
}
120 
// Reconstructs a serialized object, trying three strategies in order:
//   1. a mobile `__setstate__` function registered under the class's
//      qualified name in `mobile_compilation_unit`,
//   2. a custom (torchbind) class that defines `__setstate__`,
//   3. attribute-by-attribute restore from a pickled dict.
c10::intrusive_ptr<c10::ivalue::Object> objLoaderMobile(
    const at::StrongTypePtr& type,
    const IValue& input,
    mobile::CompilationUnit& mobile_compilation_unit) {
  auto cls = type.type_->expect<at::ClassType>();
  auto qn = cls->name();
  c10::QualifiedName method_name(qn.value(), "__setstate__");
  auto setstate = mobile_compilation_unit.find_function(method_name);
  // Returns the registered custom class only when it also has __setstate__.
  auto find_custom_class_with_setstate = [&qn]() -> c10::ClassTypePtr {
    auto custom_class_type = torch::jit::getCustomClass(qn->qualifiedName());
    if (custom_class_type && custom_class_type->findMethod("__setstate__")) {
      return custom_class_type;
    }
    return nullptr;
  };
  if (setstate) {
    // Case 1: create an empty object and let the mobile __setstate__
    // populate it from `input`.
    auto obj = c10::ivalue::Object::create(type, 0);
    Stack stack({obj, input});
    setstate->run(stack);
    return obj;
  } else if (auto custom_class_type = find_custom_class_with_setstate()) {
    // Case 2: custom class — one slot is reserved and the class's own
    // __setstate__ method is run with (obj, input).
    auto obj = c10::ivalue::Object::create(
        c10::StrongTypePtr(nullptr, custom_class_type), 1);
    Stack stack({obj, input});
    custom_class_type->getMethod("__setstate__").run(stack);
    return obj;
  } else {
    // Case 3: `input` is a dict of attribute name -> value; copy each
    // entry into its own slot in declaration order.
    auto dict = input.toGenericDict();
    size_t ndict = dict.size();
    auto obj = c10::ivalue::Object::create(type, ndict);
    auto it = dict.begin();
    for (const auto i : c10::irange(ndict)) {
      // NOTE(review): the attribute type registered here is the *key*'s
      // type (a string type), not the value's — confirm this is intended.
      cls->addOrCheckAttribute(it->key().toStringRef(), it->key().type());
      obj->setSlot(i, it->value());
      ++it;
    }
    return obj;
  }
}
160 
isTensorInBytecodeArchive(caffe2::serialize::PyTorchStreamReader & stream_reader)161 bool isTensorInBytecodeArchive(
162     caffe2::serialize::PyTorchStreamReader& stream_reader) {
163   auto records = stream_reader.getAllRecords();
164   for (const auto& record : records) {
165     if (record.find("bytecode/") != std::string::npos) {
166       return true;
167     }
168   }
169   return false;
170 }
171 
172 namespace {
173 
tryRegisterMethod(const std::vector<c10::Argument> & args,Function & func)174 void tryRegisterMethod(const std::vector<c10::Argument>& args, Function& func) {
175   if (args.empty() || args[0].name() != "self") {
176     return;
177   }
178 
179   if (auto cls = args[0].type()->castRaw<ClassType>()) {
180     if (C10_UNLIKELY(cls->findMethod(func.name()))) {
181       return;
182     }
183     cls->addMethod(&func);
184   }
185 }
186 
// The deserializer class which loads the bytecode package from bc files.
// Owns the stream reader and a CompilationUnit used to resolve class types
// referenced by the pickled data.
class BytecodeDeserializer final {
 public:
  explicit BytecodeDeserializer(
      std::unique_ptr<PyTorchStreamReader> reader,
      uint64_t module_load_options = 0);
  // Deserialize the whole module (bytecode, debug handles, data archive).
  mobile::Module deserialize(std::optional<at::Device> device);
  // Same as above, but first populates `extra_files` from the archive.
  mobile::Module deserialize(
      std::optional<at::Device> device,
      ExtraFilesMap& extra_files);
  // Populate only the requested "extra/<name>" records; does not touch
  // bytecode or tensors.
  void deserialize_only_extra(
      std::optional<at::Device> device,
      ExtraFilesMap& extra_files);

 private:
  // Resolve a qualified type name against compilation_unit_.
  TypePtr resolveTypeName(const c10::QualifiedName& qn);
  // Append all known upgrader functions to `function`.
  void init_upgrader(mobile::Function* function);
  // Parse every method table in `vals` and register the resulting
  // mobile::Function objects into `mcu`.
  void parseMethods(
      c10::ivalue::TupleElements&& vals,
      std::optional<c10::ivalue::TupleElements>&& debug_handles,
      mobile::CompilationUnit& mcu);
  // Unpickle `<archive_name>.pkl` (plus its tensors) from the archive.
  c10::IValue readArchive(
      const std::string& archive_name,
      std::shared_ptr<mobile::CompilationUnit> mcu);
  // Parse the optional schema table of one method and attach it to
  // `function`.
  void parseFunctionSchema(
      const std::string& function_name,
      IValue* schemaTable,
      const int64_t& model_version,
      mobile::Function* function);
  std::shared_ptr<CompilationUnit> compilation_unit_;
  std::unordered_set<std::string> imported_libs_;
  std::unique_ptr<PyTorchStreamReader> reader_{};
  std::optional<at::Device> device_;
  uint64_t module_load_options_;
  // From `version` or `.data/version` in model.ptl; computed dynamically.
  // It's used for finding the minimum required runtime to run all
  // operators from the given model. If it's less than the current runtime,
  // an upgrader will be applied at loading stage.
  uint64_t operator_version_{0};
  // Bytecode format version parsed out of bytecode.pkl (see parseMethods).
  uint64_t bytecode_version_{0};
};
228 
// Takes ownership of `reader`; creates a fresh CompilationUnit that will
// hold any class types resolved during deserialization.
BytecodeDeserializer::BytecodeDeserializer(
    std::unique_ptr<PyTorchStreamReader> reader,
    uint64_t module_load_options)
    : compilation_unit_(std::make_shared<CompilationUnit>()),
      reader_(std::move(reader)),
      module_load_options_(module_load_options) {}
235 
// Thin member wrapper: resolve `qn` against this deserializer's own
// compilation unit.
TypePtr BytecodeDeserializer::resolveTypeName(const c10::QualifiedName& qn) {
  return resolveTypeNameMobile(qn, compilation_unit_);
}
239 
// It requires compilation_unit_ when parsing function schema. Keep it in
// BytecodeDeserializer. It may be refactored later to make it independent
// of the specific BytecodeDeserializer, like parsing other tables.
void BytecodeDeserializer::parseFunctionSchema(
    const std::string& function_name,
    IValue* schemaTable,
    const int64_t& model_version,
    mobile::Function* function) {
  // function schema
  if (schemaTable) { // (schema is optional for back compat)
    // Converts one tuple-of-argument-tables into a vector of c10::Argument,
    // registering `function` as a method of self's class as a side effect.
    auto parseArgList = [this,
                         function](c10::ivalue::TupleElements&& argTables) {
      std::vector<c10::Argument> args;
      for (auto& argTable : argTables) {
        auto argTableElements = std::move(argTable.toTupleRef()).elements();
        auto name =
            expect_field(argTableElements, "name", BYTECODE_INDEX_ARGUMENT_NAME)
                .toStringRef();
        c10::TypePtr type = resolveTypeName(
            (expect_field(
                 argTableElements, "type", BYTECODE_INDEX_ARGUMENT_TYPE))
                .toStringRef());
        IValue default_value = expect_field(
            argTableElements,
            "default_value",
            BYTECODE_INDEX_ARGUMENT_DEFAULT_VALUE);
        args.emplace_back(
            name,
            std::move(type),
            std::nullopt /*N*/,
            std::move(default_value));
      }
      tryRegisterMethod(args, *function);
      return args;
    };
    // The schema table is a tuple of named fields; pull out the "arguments"
    // and "returns" sub-tuples and parse each with the same helper.
    auto schemaTableElements = std::move(schemaTable->toTupleRef()).elements();
    auto arg_list = std::move(expect_field(
                                  schemaTableElements,
                                  "arguments",
                                  BYTECODE_INDEX_SCHEMA_ARGUMENTS)
                                  .toTupleRef())
                        .elements();
    auto ret_list =
        std::move(
            expect_field(
                schemaTableElements, "returns", BYTECODE_INDEX_SCHEMA_RETURNS)
                .toTupleRef())
            .elements();
    // Note: overload_name / vararg / varret are not part of the bytecode
    // format (see file header), hence the fixed values here.
    c10::FunctionSchema schema(
        function_name,
        "" /*overload_name*/,
        parseArgList(std::move(arg_list)),
        parseArgList(std::move(ret_list)),
        false /*is_varargs*/,
        false /*is_varret*/);
    function->setSchema(std::move(schema));
  }
}
298 
// Appends every known upgrader function to `function`, so that OP
// instructions can later be redirected to CALL an upgrader when the model's
// operator version is older than the runtime (see parseMethods).
void BytecodeDeserializer::init_upgrader(mobile::Function* function) {
  for (auto& byteCodeFunctionWithOperator : getUpgraderBytecodeList()) {
    function->append_function(byteCodeFunctionWithOperator.function);
  }
}
304 
// Parses every method table in `vals` (the unpickled bytecode.pkl tuple)
// and registers one mobile::Function per method into `mcu`. When
// `debug_handles` is present it must be element-wise parallel to `vals`.
void BytecodeDeserializer::parseMethods(
    c10::ivalue::TupleElements&& vals,
    std::optional<c10::ivalue::TupleElements>&& debug_handles,
    mobile::CompilationUnit& mcu) {
  TORCH_CHECK(!vals.empty(), "Bytecode has no elements. ");
  // Initialized with the version number when kProducedBytecodeVersion was
  // introduced. The old models (some of them already in production) without
  // version number are seen as version 3 (deprecated).
  constexpr uint64_t default_version = 0x3L;
  bytecode_version_ = default_version;
  size_t method_i_start = 0;
  if (vals[0].isInt()) {
    // New format: the first element is the bytecode version number.
    bytecode_version_ = vals[0].toInt();
    method_i_start = 1;
  }
  TORCH_CHECK(
      caffe2::serialize::kMinSupportedBytecodeVersion <= bytecode_version_ &&
          bytecode_version_ <= caffe2::serialize::kMaxSupportedBytecodeVersion,
      "Lite Interpreter version number does not match. ",
      "The model version must be between ",
      caffe2::serialize::kMinSupportedBytecodeVersion,
      " and ",
      caffe2::serialize::kMaxSupportedBytecodeVersion,
      " but the model version is ",
      bytecode_version_);

  if (debug_handles) {
    TORCH_CHECK(
        debug_handles->size() == vals.size(),
        "The numbers of bytecode values and debug info values do not match.");
  }

  // Process all methods in this mobile module.
  for (const auto i : c10::irange(method_i_start, vals.size())) {
    auto element = std::move(vals[i]);
    // Each method is a tuple: (function_name, code_table[, schema_table]).
    auto m_tuple = std::move(element.toTupleRef()).elements();
    const std::string& function_name = m_tuple[0].toStringRef();
    auto codeTableElements =
        std::move(std::move(m_tuple[1]).toTupleRef()).elements();
    IValue* schemaTable = // older files do not store function schema
        (bytecode_version_ > 0x4L ||
         (bytecode_version_ == 0x4L && m_tuple.size() >= 3))
        ? &m_tuple[2]
        : nullptr;
    auto function =
        std::make_unique<mobile::Function>(c10::QualifiedName(function_name));

    // Pull the named sub-tables out of the code table.
    auto ins_list =
        std::move(
            expect_field(
                codeTableElements, "instructions", BYTECODE_INDEX_INSTRUCTION)
                .toTupleRef())
            .elements();
    auto ops_list =
        std::move(expect_field(
                      codeTableElements, "operators", BYTECODE_INDEX_OPERATOR)
                      .toTupleRef())
            .elements();
    auto consts_list =
        std::move(expect_field(
                      codeTableElements, "constants", BYTECODE_INDEX_CONSTANT)
                      .toTupleRef())
            .elements();
    auto types_list =
        std::move(expect_field(codeTableElements, "types", BYTECODE_INDEX_TYPE)
                      .toTupleRef())
            .elements();
    int64_t register_size =
        expect_field(
            codeTableElements, "register_size", BYTECODE_INDEX_REGISTER_SIZE)
            .toInt();

    c10::ivalue::TupleElements debug_handles_m_tuple;
    if (debug_handles) {
      debug_handles_m_tuple =
          std::move(std::move((*debug_handles)[i]).toTupleRef()).elements();
    }
    init_upgrader(function.get());
    // 1. First pass all operators from models
    parseOperators(std::move(ops_list), module_load_options_, function.get());

    // 2. Decide if an upgrader is needed
    bool use_upgrader =
        (operator_version_ < caffe2::serialize::kProducedFileFormatVersion);

    parseInstructions(
        function_name,
        std::move(ins_list),
        debug_handles_m_tuple,
        function.get());

    // 3. If an upgrader is needed, change the OP instruction to a CALL
    // instruction (In next PR, use_upgrader will be passed to the
    // parseInstructions function and do the actual change)
    if (use_upgrader) {
      applyUpgrader(function.get(), operator_version_);
    }

    parseConstants(consts_list, function.get());

    parseTypes(types_list, function.get());

    function->set_register_size(register_size);

    parseFunctionSchema(
        function_name, schemaTable, bytecode_version_, function.get());

    mcu.register_function(std::move(function));
  }
}
415 
deserialize_only_extra(std::optional<at::Device> device,ExtraFilesMap & extra_files)416 void BytecodeDeserializer::deserialize_only_extra(
417     std::optional<at::Device> device,
418     ExtraFilesMap& extra_files) {
419   device_ = device;
420   for (const auto& kv : extra_files) {
421     const std::string& key = "extra/" + kv.first;
422     if (reader_->hasRecord(key)) {
423       auto [meta_ptr, meta_size] = reader_->getRecord(key);
424       extra_files[kv.first] =
425           std::string(static_cast<char*>(meta_ptr.get()), meta_size);
426     }
427   }
428 }
429 
// Convenience overload: populate `extra_files` from the archive first, then
// run the full deserialization.
mobile::Module BytecodeDeserializer::deserialize(
    std::optional<at::Device> device,
    ExtraFilesMap& extra_files) {
  deserialize_only_extra(device, extra_files);
  return deserialize(device);
}
436 
// Full deserialization: read bytecode (+ optional debug handles), parse all
// methods, then read the "data" archive to build the module object itself.
mobile::Module BytecodeDeserializer::deserialize(
    std::optional<at::Device> device) {
  device_ = device;
  auto mcu = std::make_shared<mobile::CompilationUnit>();

  // bvals can have 2 possible formats:
  //
  // 1. Old format: bvals is an array (Tuple) of N elements, each element being
  // itself a Tuple(method_name, method_table).
  //
  // 2. New format: bvals is an array (Tuple) of 1+N elements. The first element
  // being a Tuple (int, table), and the integer stands for the bytecode version
  // number. The rest of the elements are the same as before.
  //
  auto bvals = std::move(readArchive("bytecode", mcu).toTupleRef()).elements();

  std::optional<c10::ivalue::TupleElements> debug_handles;
  bool has_debug_handles{false};
  if (reader_->hasRecord("mobile_debug_handles.pkl")) {
    debug_handles =
        std::move(readArchive("mobile_debug_handles", mcu).toTupleRef())
            .elements();
    has_debug_handles = true;
  }
  // Must be set before parseMethods: it decides whether upgraders apply.
  operator_version_ = reader_->version();
  parseMethods(std::move(bvals), std::move(debug_handles), *mcu);
  auto m = mobile::Module(readArchive("data", mcu).toObject(), mcu);
  m.set_min_operator_version(operator_version_);
  m.set_bytecode_version(bytecode_version_);
  m.setHasDebugHandles(has_debug_handles);
#if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)
  // Only built in when debug-handle symbolication is compiled in.
  MobileDebugTable debug_table = MobileDebugTable(reader_, compilation_unit_);
  m.setDebugTable(std::move(debug_table));
#endif
  return m;
}
473 
// Unpickles `<archive_name>.pkl` (and any tensors it references) from the
// zip archive, resolving class types via compilation_unit_ and constructing
// objects into `mcu`.
c10::IValue BytecodeDeserializer::readArchive(
    const std::string& archive_name,
    std::shared_ptr<mobile::CompilationUnit> mcu) {
  auto type_resolver = [this](const c10::QualifiedName& qn) {
    return typeResolverMobile(qn, compilation_unit_);
  };

  auto obj_loader = [&](const at::StrongTypePtr& type, const IValue& input) {
    return objLoaderMobile(type, input, *mcu);
  };

  // Older exports keep bytecode tensors under "constants/" instead of the
  // bytecode archive itself; detect that and adjust the tensor prefix.
  bool bytecode_tensor_in_constants_archive =
      (archive_name == "bytecode" && !isTensorInBytecodeArchive(*reader_));

  auto ivalues = torch::jit::readArchiveAndTensors(
      archive_name,
      /*pickle_prefix=*/"",
      /*tensor_prefix=*/
      bytecode_tensor_in_constants_archive ? "constants/" : "",
      type_resolver,
      obj_loader,
      device_,
      *reader_,
      nullptr);
  return ivalues;
}
500 
// Core zip-format loading path shared by the public _load_for_mobile
// overloads. Wires up the module observer (if configured) so both successful
// and failed loads are reported with metadata from the extra files.
mobile::Module _load_for_mobile_impl(
    std::unique_ptr<ReadAdapterInterface> rai,
    std::optional<c10::Device> device,
    ExtraFilesMap& extra_files,
    uint64_t module_load_options) {
  auto observer = torch::observerConfig().getModuleObserver();
  // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand)
  auto instance_key = std::rand();

  std::unordered_map<std::string, std::string> metadata_map;
  if (observer) {
    observer->onEnterLoadModel(instance_key);
    auto defaultExtraFileList = observer->getDefaultExtraFiles();
    // Add files in defaultExtraFileList to fail_extra_files and extra_files
    for (const auto& fileName : defaultExtraFileList) {
      extra_files.insert(std::make_pair(fileName, ""));
    }
  }

  const size_t model_size = rai != nullptr ? rai->size() : 0;
  auto reader = std::make_unique<PyTorchStreamReader>(std::move(rai));
  if (module_load_options &
      MobileModuleLoadOptions::PARSE_ALL_EXTRA_FILE_MAPS) {
    // ExtraFilesMap is serialized with a "extra/", hence it is necessary to
    // account for when we de-serialize de-serialized filemap key values contain
    // prefix and we need to remove prior to construct the map. "extra/" string
    // has a length of 6 characters, hence we need only sub-string 6th position
    // of a string. Please refer to following link for a detail:
    // https://www.internalfb.com/code/fbsource/[9996fcb7a6fb]/fbcode/caffe2/torch/csrc/jit/mobile/import.cpp?lines=427-434
    std::vector<std::string> all_files = reader->getAllRecords();
    for (auto& file_name : all_files) {
      if (file_name.find("extra/") == 0) {
        extra_files[file_name.substr(6)] = "";
      }
    }
  }
  BytecodeDeserializer deserializer(std::move(reader), module_load_options);

  std::string error_message;
  // Failure path: runs only if we leave this scope without release()-ing the
  // guard (i.e. deserialization threw). Reports the failure to the observer
  // with whatever extra-file metadata can still be read.
  auto guard = c10::make_scope_exit([&]() {
    if (!observer) {
      return;
    }
    deserializer.deserialize_only_extra(device, extra_files);

    metadata_map = observer->processMetadataFromExtra(extra_files);

    observer->onFailLoadModel(
        instance_key,
        error_message.empty() ? "Unknown exception" : error_message.c_str(),
        metadata_map);
  });

  try {
    mobile::Module result = deserializer.deserialize(device, extra_files);
    if (observer) {
      // Add model_name and model_size to metadata_map
      extra_files.insert(std::make_pair("model_name", result.name()));
      extra_files.insert(
          std::make_pair("model_size", std::to_string(model_size)));
      metadata_map = observer->processMetadataFromExtra(extra_files);
      observer->onExitLoadModel(instance_key, metadata_map);
    }
    result.setMetadata(metadata_map);
    // Success: disarm the failure-reporting scope guard.
    guard.release();
    return result;
  } catch (c10::Error& error) {
    error_message = error.what();
    TORCH_RETHROW(error);
  }
}
572 
_load_mobile_from_bytes(const std::shared_ptr<char> & data,size_t size,std::optional<c10::Device> device,ExtraFilesMap & extra_files,uint64_t module_load_options)573 mobile::Module _load_mobile_from_bytes(
574     const std::shared_ptr<char>& data,
575     size_t size,
576     std::optional<c10::Device> device,
577     ExtraFilesMap& extra_files,
578     uint64_t module_load_options) {
579   TORCH_CHECK(size >= kFileFormatHeaderSize, "Format error");
580   auto format = getFileFormat(data.get());
581   switch (format) {
582     case FileFormat::ZipFileFormat: {
583       std::unique_ptr<ReadAdapterInterface> rai =
584           std::make_unique<MemoryReadAdapter>(data.get(), size);
585       return _load_for_mobile_impl(
586           std::move(rai), device, extra_files, module_load_options);
587     }
588     case FileFormat::FlatbufferFileFormat: {
589       return parse_and_initialize_mobile_module(
590           data, size, device, &extra_files);
591     }
592     default: {
593       TORCH_CHECK(false, "Format error");
594     }
595   }
596 }
597 
598 } // namespace
599 
_load_for_mobile(std::istream & in,std::optional<at::Device> device)600 mobile::Module _load_for_mobile(
601     std::istream& in,
602     std::optional<at::Device> device) {
603   ExtraFilesMap extra_files;
604   return _load_for_mobile(in, device, extra_files);
605 }
606 
_load_for_mobile(const std::string & filename,std::optional<at::Device> device)607 mobile::Module _load_for_mobile(
608     const std::string& filename,
609     std::optional<at::Device> device) {
610   ExtraFilesMap extra_files;
611   return _load_for_mobile(filename, device, extra_files);
612 }
613 
_load_for_mobile(std::unique_ptr<ReadAdapterInterface> rai,std::optional<c10::Device> device)614 mobile::Module _load_for_mobile(
615     std::unique_ptr<ReadAdapterInterface> rai,
616     std::optional<c10::Device> device) {
617   ExtraFilesMap extra_files;
618   return _load_for_mobile(std::move(rai), device, extra_files);
619 }
620 
_load_for_mobile(std::istream & in,std::optional<at::Device> device,ExtraFilesMap & extra_files,uint64_t module_load_options)621 mobile::Module _load_for_mobile(
622     std::istream& in,
623     std::optional<at::Device> device,
624     ExtraFilesMap& extra_files,
625     uint64_t module_load_options) {
626   if (getFileFormat(in) == FileFormat::FlatbufferFileFormat) {
627     auto [data, size] = get_stream_content(in);
628     return _load_mobile_from_bytes(
629         data, size, device, extra_files, module_load_options);
630   }
631   auto rai = std::make_unique<caffe2::serialize::IStreamAdapter>(&in);
632   auto module = _load_for_mobile_impl(
633       std::move(rai), device, extra_files, module_load_options);
634   return module;
635 }
636 
_load_for_mobile(const std::string & filename,std::optional<at::Device> device,ExtraFilesMap & extra_files)637 mobile::Module _load_for_mobile(
638     const std::string& filename,
639     std::optional<at::Device> device,
640     ExtraFilesMap& extra_files) {
641   return _load_for_mobile(
642       filename, device, extra_files, kDefaultMobileLoadOptions);
643 }
644 
_load_for_mobile(const std::string & filename,std::optional<at::Device> device,ExtraFilesMap & extra_files,uint64_t module_load_options)645 mobile::Module _load_for_mobile(
646     const std::string& filename,
647     std::optional<at::Device> device,
648     ExtraFilesMap& extra_files,
649     uint64_t module_load_options) {
650   auto observer = torch::observerConfig().getModuleObserver();
651   if (observer) {
652     extra_files.insert(std::make_pair("model_path", filename));
653   }
654   auto format = getFileFormat(filename);
655 
656   if (format == FileFormat::FlatbufferFileFormat) {
657     auto [data, size] = get_file_content(filename.c_str());
658     return _load_mobile_from_bytes(
659         data, size, device, extra_files, module_load_options);
660   }
661 
662   auto rai = std::make_unique<caffe2::serialize::FileAdapter>(filename);
663   return _load_for_mobile_impl(
664       std::move(rai), device, extra_files, module_load_options);
665 }
666 
// Loads a module from an arbitrary read adapter by slurping its entire
// contents into memory and dispatching on the detected format.
TORCH_API mobile::Module _load_for_mobile(
    std::unique_ptr<ReadAdapterInterface> rai,
    std::optional<c10::Device> device,
    ExtraFilesMap& extra_files,
    uint64_t module_load_options) {
  // TODO optimize file read for non-flatbuffer models
  // (zip archives could be streamed instead of fully buffered).
  auto [data, size] = get_rai_content(rai.get());
  return _load_mobile_from_bytes(
      data, size, device, extra_files, module_load_options);
}
677 
// Populates `extra_files` from a model file without building a runnable
// module (for the zip path; the flatbuffer path currently must load the
// whole module — see the TODO below).
void _load_extra_only_for_mobile(
    const std::string& filename,
    std::optional<at::Device> device,
    ExtraFilesMap& extra_files) {
  auto observer = torch::observerConfig().getModuleObserver();
  // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand)
  auto instance_key = std::rand();
  if (observer) {
    // NOTE(review): onEnterLoadModel is reported here but no matching
    // exit/fail callback is made in this function — confirm intended.
    observer->onEnterLoadModel(instance_key);
  }

  auto format = getFileFormat(filename);
  switch (format) {
    case FileFormat::ZipFileFormat: {
      auto rai = std::make_unique<caffe2::serialize::FileAdapter>(filename);
      auto reader = std::make_unique<PyTorchStreamReader>(std::move(rai));
      BytecodeDeserializer deserializer(std::move(reader));
      deserializer.deserialize_only_extra(device, extra_files);
      break;
    }
    case FileFormat::FlatbufferFileFormat: {
      // TODO: the current flatbuffers implementation will always load the
      // whole module including the extra files. Ideally it should be
      // possible to just get the extra files given data
      load_mobile_module_from_file(filename, std::nullopt, &extra_files);
      break;
    }
    default: {
      TORCH_CHECK(false, "Format error");
    }
  }
}
710 
711 namespace mobile {
712 
_export_operator_list(torch::jit::mobile::Module & module)713 std::set<std::string> _export_operator_list(
714     torch::jit::mobile::Module& module) {
715   std::set<std::string> operator_list;
716   for (Method func : module.get_methods()) {
717     const Function& function = func.function();
718     const auto& code = function.get_code();
719     // op_names below isn't a list of unique operator names. In fact
720     // it can contain the same operator name many many times, so we need
721     // to de-dup the list by adding all the operator names into
722     // an std::set<std::string>.
723     std::vector<c10::OperatorName> const& op_names = code.op_names_;
724     for (auto& op_name : op_names) {
725       operator_list.insert(toString(op_name));
726     }
727   }
728   return operator_list;
729 }
730 
731 } // namespace mobile
732 } // namespace torch::jit
733