xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/serialization/import.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/interned_strings.h>
2 #include <caffe2/serialize/file_adapter.h>
3 #include <caffe2/serialize/in_memory_adapter.h>
4 #include <caffe2/serialize/inline_container.h>
5 #include <caffe2/serialize/istream_adapter.h>
6 #include <caffe2/serialize/read_adapter_interface.h>
7 
8 #include <torch/csrc/jit/api/compilation_unit.h>
9 
10 #include <ATen/core/functional.h>
11 #include <ATen/core/ivalue_inl.h>
12 #include <c10/util/Exception.h>
13 #include <c10/util/irange.h>
14 #include <torch/csrc/jit/frontend/script_type_parser.h>
15 #include <torch/csrc/jit/ir/graph_utils.h>
16 #include <torch/csrc/jit/ir/ir.h>
17 #include <torch/csrc/jit/mobile/file_format.h>
18 #include <torch/csrc/jit/mobile/flatbuffer_loader.h>
19 #include <torch/csrc/jit/operator_upgraders/upgraders_entry.h>
20 #include <torch/csrc/jit/passes/shape_analysis.h>
21 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
22 #include <torch/csrc/jit/serialization/import.h>
23 #include <torch/csrc/jit/serialization/import_export_helpers.h>
24 #include <torch/csrc/jit/serialization/import_read.h>
25 #include <torch/csrc/jit/serialization/import_source.h>
26 #include <torch/csrc/jit/serialization/source_range_serialization.h>
27 #include <torch/csrc/jit/serialization/unpickler.h>
28 
29 #include <ATen/ATen.h>
30 #include <fmt/format.h>
31 
32 #include <string>
33 #include <utility>
34 #include <vector>
35 
36 namespace torch::jit {
37 
38 using caffe2::serialize::MemoryReadAdapter;
39 using caffe2::serialize::PyTorchStreamReader;
40 using caffe2::serialize::ReadAdapterInterface;
41 
// Validates an object freshly produced by running `__setstate__`: every
// attribute whose declared type is not Optional/Union/None must have been
// assigned a real value (i.e. its slot must not still be None).
// Throws c10::Error (via TORCH_CHECK) naming the offending field.
static void postSetStateValidate(const IValue& v) {
  auto obj = v.toObject();
  const auto& objType = obj->type();
  for (const auto i : c10::irange(objType->numAttributes())) {
    const auto& attrType = objType->getAttribute(i);
#ifndef STRIP_ERROR_MESSAGES
    // Only needed for the error message below; skip the lookup entirely in
    // builds that strip error messages (where it would be an unused local).
    const auto& attrName = objType->getAttributeName(i);
#endif
    const auto& slot = obj->getSlot(i);
    // const auto attrType = objType->getAttribute(i);
    // Verify that all the non-optional attributes have been initialized
    // TODO: Issue #20497
    if (attrType->kind() != TypeKind::UnionType &&
        attrType->kind() != TypeKind::OptionalType &&
        attrType->kind() != TypeKind::NoneType) {
      TORCH_CHECK(
          !slot.isNone(),
          fmt::format(
              "The field '{}' was left uninitialized after '__setstate__', "
              "but expected a value of type '{}'",
              attrName,
              attrType->repr_str()));
    }
  }
}
67 
68 // Decouple how to get obj from type. In this file it's dependent on
69 // Method.run() and graph executor, etc.
70 // For bytecode import we need to decouple these dependencies.
ObjLoaderFunc(const at::StrongTypePtr & type,IValue input)71 c10::intrusive_ptr<c10::ivalue::Object> ObjLoaderFunc(
72     const at::StrongTypePtr& type,
73     IValue input) {
74   auto cls = type.type_->expect<at::ClassType>();
75   auto qn = cls->name();
76   size_t n = cls->numAttributes();
77   if (checkHasValidSetGetState(cls)) {
78     auto obj = c10::ivalue::Object::create(type, n);
79     // XXX: Do not optimize __setstate__, so that we don't try to
80     // specialize the class before it is initialized.
81     GraphOptimizerEnabledGuard guard(false);
82     Function& set_state = cls->getMethod("__setstate__");
83     // since we are in the middle of unpickling we might still have lists and
84     // dicts that do not have accurate tags (e.g. they report they are
85     // List[Any]). But we need to run __setstate__ which will check the input
86     // type and may access the tags. Since setstate has a known input type, we
87     // can correctly restore the tags now by apply the input type of set_state
88     // to the state object being passed.
89     // TODO: Remove once [serialization type tags] is landed
90     restoreAccurateTypeTags(
91         input, set_state.getSchema().arguments().at(1).type());
92     set_state({obj, input});
93     postSetStateValidate(obj);
94     return obj;
95   } else {
96     auto dict = std::move(input).toGenericDict();
97     auto obj = c10::ivalue::Object::create(type, n);
98     for (const auto i : c10::irange(n)) {
99       obj->setSlot(i, dict.at(cls->getAttributeName(i)));
100     }
101     return obj;
102   }
103 }
104 
105 namespace {
106 
107 // This is a deserializer class which loads script modules from pt files.
108 // Content of the file is written using PyTorchStreamWriter, for details please
109 // check caffe2/serialize/inline_container.h.
110 // The module is saved in pickle. readArchive() is called to parse and construct
111 // the constant table and the script module.
class ScriptModuleDeserializer final {
 public:
  // Constructor for standalone TorchScript archives: code lives under
  // "code/" and pickle/tensor records sit at the archive root.
  ScriptModuleDeserializer(
      std::shared_ptr<CompilationUnit> cu,
      std::shared_ptr<PyTorchStreamReader> reader)
      : compilation_unit_(std::move(cu)),
        reader_(std::move(reader)),
        code_prefix_("code/"),
        pickle_dir_prefix_(""),
        tensor_dir_prefix_(""),
        source_importer_(
            compilation_unit_,
            &constants_table_,
            // Resolves a qualified class name to its source record in the
            // zip archive under code_prefix_.
            [this](const std::string& qualifier) {
              return findSourceInArchiveFromQualifier(
                  *reader_, code_prefix_, qualifier);
            },
            reader_->version()) {}

  // Constructor for the unified torch.package format: TorchScript code lives
  // under ".data/ts_code/", and tensor storages may be shared across modules
  // through `storage_context`.
  ScriptModuleDeserializer(
      std::shared_ptr<CompilationUnit> cu,
      std::shared_ptr<PyTorchStreamReader> reader,
      std::string pickle_dir_prefix,
      std::string tensor_dir_prefix,
      std::shared_ptr<DeserializationStorageContext> storage_context)
      : compilation_unit_(std::move(cu)),
        reader_(std::move(reader)),
        storage_context_(std::move(storage_context)),
        code_prefix_(".data/ts_code/code/"),
        pickle_dir_prefix_(std::move(pickle_dir_prefix)),
        tensor_dir_prefix_(std::move(tensor_dir_prefix)),
        source_importer_(
            compilation_unit_,
            &constants_table_,
            [this](const std::string& qualifier) {
              return findSourceInArchiveFromQualifier(
                  *reader_, code_prefix_, qualifier);
            },
            reader_->version()) {}

  // Runs the full deserialization: extra files, constants table, module
  // object, BC rewrites, and (optionally) traced-input shape restoration.
  Module deserialize(
      std::optional<at::Device> device,
      ExtraFilesMap& extra_files,
      bool restore_shapes = false);

 private:
  // Reads and unpickles one named archive (e.g. "constants", "data").
  IValue readArchive(const std::string& archive_name);

  std::shared_ptr<CompilationUnit> compilation_unit_;
  std::shared_ptr<PyTorchStreamReader> reader_;
  // Shared tensor-storage context for torch.package loads; null otherwise.
  std::shared_ptr<DeserializationStorageContext> storage_context_;
  // Target device requested by the caller; recorded in deserialize().
  std::optional<at::Device> device_;
  // Constants loaded from the "constants" archive, consumed by the importer.
  std::vector<at::IValue> constants_table_;
  std::string code_prefix_;
  std::string pickle_dir_prefix_;
  std::string tensor_dir_prefix_;
  SourceImporter source_importer_;
};
170 
readArchive(const std::string & archive_name)171 IValue ScriptModuleDeserializer::readArchive(const std::string& archive_name) {
172   auto type_resolver = [&](const c10::QualifiedName& qn) {
173     auto cls = source_importer_.loadType(qn);
174     return c10::StrongTypePtr(compilation_unit_, std::move(cls));
175   };
176 
177   return readArchiveAndTensors(
178       /*archive_name=*/archive_name,
179       /*pickle_prefix=*/pickle_dir_prefix_,
180       /*tensor_prefix=*/tensor_dir_prefix_,
181       type_resolver,
182       ObjLoaderFunc,
183       device_,
184       *reader_,
185       nullptr,
186       storage_context_);
187 }
188 
rewriteQuantizedConvForBC(const Module & module)189 void rewriteQuantizedConvForBC(const Module& module) {
190   const std::string& old_quantized_conv2d = R"(
191 graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point):
192          %r = quantized::conv2d(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point)
193          return (%r) )";
194 
195   const std::string& old_quantized_conv2d_relu = R"(
196 graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point):
197          %r = quantized::conv2d_relu(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point)
198          return (%r) )";
199 
200   const std::string& old_quantized_conv3d = R"(
201 graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point):
202          %r = quantized::conv3d(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point)
203          return (%r) )";
204 
205   const std::string& old_quantized_conv3d_relu = R"(
206 graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point):
207          %r = quantized::conv3d_relu(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point)
208          return (%r) )";
209 
210   const std::string& new_quantized_conv2d = R"(
211 graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point):
212          %r = quantized::conv2d(%x, %packed_params, %r_scale, %r_zero_point)
213          return (%r) )";
214 
215   const std::string& new_quantized_conv2d_relu = R"(
216 graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point):
217          %r = quantized::conv2d_relu(%x, %packed_params, %r_scale, %r_zero_point)
218          return (%r) )";
219 
220   const std::string& new_quantized_conv3d = R"(
221 graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point):
222          %r = quantized::conv3d(%x, %packed_params, %r_scale, %r_zero_point)
223          return (%r) )";
224 
225   const std::string& new_quantized_conv3d_relu = R"(
226 graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point):
227          %r = quantized::conv3d_relu(%x, %packed_params, %r_scale, %r_zero_point)
228          return (%r) )";
229 
230   SubgraphRewriter rewriter;
231   static const std::vector<std::pair<std::string, std::string>>
232       patterns_and_replacements = {
233           {old_quantized_conv2d, new_quantized_conv2d},
234           {old_quantized_conv2d_relu, new_quantized_conv2d_relu},
235           {old_quantized_conv3d, new_quantized_conv3d},
236           {old_quantized_conv3d_relu, new_quantized_conv3d_relu},
237       };
238   for (const auto& item : patterns_and_replacements) {
239     rewriter.RegisterRewritePattern(item.first, item.second);
240   }
241   rewriter.runOnModule(module);
242 
243   for (const Module& child : module.children()) {
244     rewriteQuantizedConvForBC(child);
245   }
246 }
247 
deserialize(std::optional<at::Device> device,ExtraFilesMap & extra_files,bool restore_shapes)248 Module ScriptModuleDeserializer::deserialize(
249     std::optional<at::Device> device,
250     ExtraFilesMap& extra_files,
251     bool restore_shapes) {
252   // we populate the upgraders map before any load starts
253   populate_upgraders_graph_map();
254 
255   C10_LOG_API_USAGE_ONCE("torch.jit.load");
256   device_ = device;
257   // Load extra files.
258   for (const auto& kv : extra_files) {
259     const std::string& key = "extra/" + kv.first;
260     if (reader_->hasRecord(key)) {
261       auto [meta_ptr, meta_size] = reader_->getRecord(key);
262       extra_files[kv.first] =
263           std::string(static_cast<char*>(meta_ptr.get()), meta_size);
264     }
265   }
266   if (reader_->hasRecord("model.json") && code_prefix_ == "code/") {
267     AT_ERROR("Legacy model format is not supported on mobile.");
268   }
269   auto tuple = readArchive("constants").toTuple();
270   for (auto constant : tuple->elements()) {
271     constants_table_.push_back(constant.toIValue());
272   }
273   auto m_ivalue = readArchive("data");
274   auto m = Module(m_ivalue.toObject());
275   rewriteQuantizedConvForBC(m);
276   // Checking for and loading saved traced inputs
277   if (restore_shapes && reader_->hasRecord("traced_inputs.pkl")) {
278     auto dict = readArchive("traced_inputs").toGenericDict();
279     for (const auto& entry : dict) {
280       auto inputs = entry.value().toList().vec();
281       auto g =
282           toGraphFunction(m.get_method(entry.key().toStringRef()).function())
283               .graph();
284       Stack stack(inputs.begin(), inputs.end());
285       // Added the module as the first input if we are missing
286       // an input as traced modules refer to self as an additional input
287       if (g->inputs().size() == stack.size() + 1) {
288         stack.insert(stack.begin(), m_ivalue);
289       }
290       setInputTensorTypes(*g, stack, /*complete=*/true);
291       PropagateInputShapes(g);
292     }
293   } else {
294     if (restore_shapes) {
295       TORCH_WARN("Cannot restore shapes as no traced inputs were stored");
296     }
297   }
298   c10::LogAPIUsageMetadata(
299       "torch.script.load.metadata",
300       {{"serialization_id", reader_->serializationId()}});
301   return m;
302 }
303 } // namespace
304 
import_ir_module(std::shared_ptr<CompilationUnit> cu,std::istream & in,std::optional<at::Device> device,bool load_debug_files)305 Module import_ir_module(
306     std::shared_ptr<CompilationUnit> cu,
307     std::istream& in,
308     std::optional<at::Device> device,
309     bool load_debug_files) {
310   ExtraFilesMap extra_files;
311   return import_ir_module(
312       std::move(cu), in, device, extra_files, load_debug_files);
313 }
314 
// Forward declaration: dispatches on the in-memory byte format (flatbuffer
// vs. zip) and deserializes accordingly. Defined near the end of this file.
static Module _load_jit_module_from_bytes(
    const std::shared_ptr<char>& data,
    size_t size,
    std::shared_ptr<CompilationUnit> cu,
    std::optional<c10::Device> device,
    ExtraFilesMap& extra_files,
    bool restore_shapes);
322 
parse_and_initialize_jit_module(const std::shared_ptr<char> & data,size_t size,ExtraFilesMap & extra_files,std::optional<at::Device> device)323 Module parse_and_initialize_jit_module(
324     const std::shared_ptr<char>& data,
325     size_t size,
326     ExtraFilesMap& extra_files,
327     std::optional<at::Device> device) {
328   populate_upgraders_graph_map();
329   ExtraFilesMap jit_files;
330   std::vector<IValue> jit_constants;
331   mobile::Module mobilem = parse_and_initialize_mobile_module_for_jit(
332       data.get(), size, jit_files, jit_constants, device, &extra_files);
333 
334   Module m = jitModuleFromSourceAndConstants(
335       mobilem._ivalue(),
336       jit_files,
337       jit_constants,
338       static_cast<int32_t>(mobilem.bytecode_version()));
339   m.set_delete_memory(data);
340   return m;
341 }
342 
load_jit_module_from_file(const std::string & filename,ExtraFilesMap & extra_files,std::optional<at::Device> device)343 Module load_jit_module_from_file(
344     const std::string& filename,
345     ExtraFilesMap& extra_files,
346     std::optional<at::Device> device) {
347   auto data = get_file_content(filename.c_str());
348   return parse_and_initialize_jit_module(
349       std::get<0>(data), std::get<1>(data), extra_files, device);
350 }
351 
load_jit_module_from_stream(std::istream & in,ExtraFilesMap & extra_files,std::optional<at::Device> device)352 Module load_jit_module_from_stream(
353     std::istream& in,
354     ExtraFilesMap& extra_files,
355     std::optional<at::Device> device) {
356   auto data = get_stream_content(in);
357   return parse_and_initialize_jit_module(
358       std::get<0>(data), std::get<1>(data), extra_files, device);
359 }
360 
// Stream overload with extra-files reporting. Sniffs the file format and
// either streams a zip archive through ScriptModuleDeserializer or slurps
// the bytes and parses them as a flatbuffer.
Module import_ir_module(
    std::shared_ptr<CompilationUnit> cu,
    std::istream& in,
    std::optional<at::Device> device,
    ExtraFilesMap& extra_files,
    bool load_debug_files,
    bool restore_shapes) {
  // Rewind: format sniffing and the readers below expect the stream to be
  // positioned at the beginning.
  in.seekg(0, in.beg);
  // NOTE: Zipformat can be large files. So using stream version directly
  // instead of reading the file all at once.
  if (getFileFormat(in) != FileFormat::FlatbufferFileFormat) {
    auto reader = std::make_unique<PyTorchStreamReader>(&in);
    reader->setShouldLoadDebugSymbol(load_debug_files);
    ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
    return deserializer.deserialize(device, extra_files, restore_shapes);
  }
  // Flatbuffer path: read the remaining bytes into memory and dispatch.
  auto [data, size] = get_stream_content(in);
  return _load_jit_module_from_bytes(
      data, size, cu, device, extra_files, restore_shapes);
}
381 
382 // For reading unified serialization format from torch.Package.
// For reading unified serialization format from torch.Package: TorchScript
// code and pickles live under ".data/ts_code/<ts_id>/" and tensors under
// ".data/", with storages shared through `storage_context`.
Module import_ir_module(
    std::shared_ptr<CompilationUnit> cu,
    std::shared_ptr<PyTorchStreamReader> reader,
    std::shared_ptr<DeserializationStorageContext> storage_context,
    std::optional<at::Device> device,
    const std::string& ts_id) {
  ScriptModuleDeserializer deserializer(
      std::move(cu),
      std::move(reader),
      /* pickle_dir_prefix = */ ".data/ts_code/" + ts_id + "/",
      /* tensor_dir_prefix = */ ".data/",
      std::move(storage_context));
  // torch.package does not surface extra files through this entry point.
  ExtraFilesMap extra_files;
  return deserializer.deserialize(device, extra_files);
}
398 
import_ir_module(std::shared_ptr<CompilationUnit> cu,const std::string & filename,std::optional<at::Device> device,bool load_debug_files)399 Module import_ir_module(
400     std::shared_ptr<CompilationUnit> cu,
401     const std::string& filename,
402     std::optional<at::Device> device,
403     bool load_debug_files) {
404   ExtraFilesMap extra_files;
405   return import_ir_module(
406       std::move(cu), filename, device, extra_files, load_debug_files);
407 }
408 
// Filename overload with extra-files reporting. Sniffs the file format and
// either streams the zip archive or reads the file into memory for the
// flatbuffer path.
Module import_ir_module(
    std::shared_ptr<CompilationUnit> cu,
    const std::string& filename,
    std::optional<at::Device> device,
    ExtraFilesMap& extra_files,
    bool load_debug_files,
    bool restore_shapes) {
  // NOTE: Zipformat can be large files. So using stream version directly
  // instead of reading the file all at once.
  if (getFileFormat(filename) != FileFormat::FlatbufferFileFormat) {
    auto reader = std::make_unique<PyTorchStreamReader>(filename);
    reader->setShouldLoadDebugSymbol(load_debug_files);
    ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
    return deserializer.deserialize(device, extra_files, restore_shapes);
  }
  // Flatbuffer path: load the whole file into memory and dispatch on bytes.
  auto [data, size] = get_file_content(filename.c_str());
  return _load_jit_module_from_bytes(
      data, size, cu, device, extra_files, restore_shapes);
}
428 
import_ir_module(std::shared_ptr<CompilationUnit> cu,std::unique_ptr<ReadAdapterInterface> rai,std::optional<at::Device> device,bool load_debug_files)429 Module import_ir_module(
430     std::shared_ptr<CompilationUnit> cu,
431     std::unique_ptr<ReadAdapterInterface> rai,
432     std::optional<at::Device> device,
433     bool load_debug_files) {
434   ExtraFilesMap extra_files;
435   return import_ir_module(
436       std::move(cu), std::move(rai), device, extra_files, load_debug_files);
437 }
438 
import_ir_module(std::shared_ptr<CompilationUnit> cu,std::unique_ptr<ReadAdapterInterface> rai,std::optional<at::Device> device,ExtraFilesMap & extra_files,bool load_debug_files)439 Module import_ir_module(
440     std::shared_ptr<CompilationUnit> cu,
441     std::unique_ptr<ReadAdapterInterface> rai,
442     std::optional<at::Device> device,
443     ExtraFilesMap& extra_files,
444     bool load_debug_files) {
445   std::shared_ptr<ReadAdapterInterface> rai_shared = std::move(rai);
446   return import_ir_module(
447       std::move(cu), rai_shared, device, extra_files, load_debug_files);
448 }
449 
import_ir_module(std::shared_ptr<CompilationUnit> cu,std::shared_ptr<ReadAdapterInterface> rai,std::optional<at::Device> device,ExtraFilesMap & extra_files,bool load_debug_files)450 Module import_ir_module(
451     std::shared_ptr<CompilationUnit> cu,
452     std::shared_ptr<ReadAdapterInterface> rai,
453     std::optional<at::Device> device,
454     ExtraFilesMap& extra_files,
455     bool load_debug_files) {
456   auto reader = std::make_shared<PyTorchStreamReader>(std::move(rai));
457   reader->setShouldLoadDebugSymbol(load_debug_files);
458   ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
459   return deserializer.deserialize(device, extra_files);
460 }
461 
load(std::istream & in,std::optional<at::Device> device,bool load_debug_files)462 Module load(
463     std::istream& in,
464     std::optional<at::Device> device,
465     bool load_debug_files) {
466   auto cu = std::make_shared<CompilationUnit>();
467   return import_ir_module(std::move(cu), in, device, load_debug_files);
468 }
469 
load(std::istream & in,std::optional<at::Device> device,ExtraFilesMap & extra_files,bool load_debug_files)470 Module load(
471     std::istream& in,
472     std::optional<at::Device> device,
473     ExtraFilesMap& extra_files,
474     bool load_debug_files) {
475   auto cu = std::make_shared<CompilationUnit>();
476   return import_ir_module(
477       std::move(cu), in, device, extra_files, load_debug_files);
478 }
479 
load(const std::string & filename,std::optional<at::Device> device,bool load_debug_files)480 Module load(
481     const std::string& filename,
482     std::optional<at::Device> device,
483     bool load_debug_files) {
484   auto cu = std::make_shared<CompilationUnit>();
485   return import_ir_module(std::move(cu), filename, device, load_debug_files);
486 }
487 
load(const std::string & filename,std::optional<at::Device> device,ExtraFilesMap & extra_files,bool load_debug_files)488 Module load(
489     const std::string& filename,
490     std::optional<at::Device> device,
491     ExtraFilesMap& extra_files,
492     bool load_debug_files) {
493   auto cu = std::make_shared<CompilationUnit>();
494   return import_ir_module(
495       std::move(cu), filename, device, extra_files, load_debug_files);
496 }
497 
load(std::shared_ptr<ReadAdapterInterface> rai,std::optional<c10::Device> device,bool load_debug_files)498 Module load(
499     std::shared_ptr<ReadAdapterInterface> rai,
500     std::optional<c10::Device> device,
501     bool load_debug_files) {
502   auto cu = std::make_shared<CompilationUnit>();
503   ExtraFilesMap extra_files;
504   return import_ir_module(
505       std::move(cu), std::move(rai), device, extra_files, load_debug_files);
506 }
507 
load(std::shared_ptr<ReadAdapterInterface> rai,std::optional<c10::Device> device,ExtraFilesMap & extra_files,bool load_debug_files)508 Module load(
509     std::shared_ptr<ReadAdapterInterface> rai,
510     std::optional<c10::Device> device,
511     ExtraFilesMap& extra_files,
512     bool load_debug_files) {
513   auto cu = std::make_shared<CompilationUnit>();
514   return import_ir_module(
515       std::move(cu), std::move(rai), device, extra_files, load_debug_files);
516 }
517 
_load_jit_module_from_bytes(const std::shared_ptr<char> & data,size_t size,std::shared_ptr<CompilationUnit> cu,std::optional<c10::Device> device,ExtraFilesMap & extra_files,bool restore_shapes)518 Module _load_jit_module_from_bytes(
519     const std::shared_ptr<char>& data,
520     size_t size,
521     std::shared_ptr<CompilationUnit> cu,
522     std::optional<c10::Device> device,
523     ExtraFilesMap& extra_files,
524     bool restore_shapes) {
525   TORCH_CHECK(size >= kFileFormatHeaderSize, "Unrecognized data format");
526   auto format = getFileFormat(data.get());
527   switch (format) {
528     case FileFormat::FlatbufferFileFormat: {
529       return parse_and_initialize_jit_module(data, size, extra_files, device);
530     }
531     case FileFormat::ZipFileFormat: {
532       auto rai = std::make_unique<MemoryReadAdapter>(data.get(), size);
533       auto reader = std::make_unique<PyTorchStreamReader>(std::move(rai));
534       ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
535       return deserializer.deserialize(device, extra_files, restore_shapes);
536     }
537 
538     default:
539       TORCH_CHECK(false, "Unrecognized data format");
540   }
541 }
542 
543 // Replace object with a newly created but equivalent object.
544 // The goal is to replace object's methods. However, since object's
545 // methods are attached to type; we need to replace it's type.
546 // Non-objects are unchanged; however, nested structures such as list, dict
547 // are also reconstructed because they might contain an object.
// Recursively rebuilds `ivalue` so every contained object is re-created with
// the type produced by `resolver` (same qualified name, new class/methods).
// Lists, dicts, and tuples are reconstructed because they may hold objects;
// leaf values pass through unchanged.
static IValue recreateObject(IValue ivalue, const TypeResolver& resolver) {
  if (ivalue.isObject()) {
    auto obj = ivalue.toObject();
    auto classtype_old = obj->type();
    // Look up the replacement class type by qualified name.
    auto newtype = resolver(*classtype_old->name());
    size_t n = classtype_old->numAttributes();
    auto newobj = c10::ivalue::Object::create(newtype, n);
    // Copy slots positionally, recursing into each value.
    for (const auto i : c10::irange(n)) {
      newobj->setSlot(i, recreateObject(obj->getSlot(i), resolver));
    }
    return newobj;
  } else if (ivalue.isList()) {
    // Preserve the list's element type while rebuilding its contents.
    auto res = c10::impl::GenericList(ivalue.type()->containedType(0));
    for (const auto& ival : ivalue.toList()) {
      res.emplace_back(recreateObject(ival, resolver));
    }
    return res;
  } else if (ivalue.isGenericDict()) {
    // Preserve key/value types; both keys and values may contain objects.
    auto result = c10::impl::GenericDict(
        ivalue.type()->containedType(0), ivalue.type()->containedType(1));
    for (const auto& kv : ivalue.toGenericDict()) {
      result.insert_or_assign(
          recreateObject(kv.key(), resolver),
          recreateObject(kv.value(), resolver));
    }
    return result;
  } else if (ivalue.isTuple()) {
    std::vector<IValue> res;
    for (const auto& ival : ivalue.toTuple()->elements()) {
      res.push_back(recreateObject(ival, resolver));
    }
    return c10::ivalue::Tuple::create(res);
  }
  // Leaf types are returned verbatim.
  return ivalue;
}
584 
// Rebuilds a JIT Module from an object tree plus its TorchScript sources and
// constants (as extracted from a flatbuffer/mobile payload): compiles the
// sources into a fresh CompilationUnit, then re-creates the object tree with
// the newly compiled class types so methods are attached.
Module jitModuleFromSourceAndConstants(
    const IValue& ivalue,
    const ExtraFilesMap& source,
    const std::vector<IValue>& constants,
    int32_t version) {
  auto compilation_unit = std::make_shared<CompilationUnit>();
  SourceImporter importer(
      compilation_unit,
      &constants,
      // Source lookup: serve qualified names straight from the `source` map.
      [&source](const std::string& qualifier) -> std::shared_ptr<Source> {
        auto source_iter = source.find(qualifier);
        if (source_iter == source.end()) {
          return nullptr;
        }
        // COPIES_STRING: the Source must outlive the map entry it came from.
        return std::make_shared<Source>(
            source_iter->second, qualifier, 1, nullptr, Source::COPIES_STRING);
      },
      version);
  auto type_resolver = [&](const c10::QualifiedName& qn) {
    auto cls = importer.loadType(qn);
    return c10::StrongTypePtr(compilation_unit, std::move(cls));
  };
  // Swap every object's type for the freshly compiled equivalent.
  auto newIvalue = recreateObject(ivalue, type_resolver).toObject();
  Module m(newIvalue);
  // Same backward-compat rewrite applied by the regular deserializer path.
  rewriteQuantizedConvForBC(m);
  return m;
}
612 
613 } // namespace torch::jit
614