1 #include <ATen/core/interned_strings.h>
2 #include <caffe2/serialize/file_adapter.h>
3 #include <caffe2/serialize/in_memory_adapter.h>
4 #include <caffe2/serialize/inline_container.h>
5 #include <caffe2/serialize/istream_adapter.h>
6 #include <caffe2/serialize/read_adapter_interface.h>
7
8 #include <torch/csrc/jit/api/compilation_unit.h>
9
10 #include <ATen/core/functional.h>
11 #include <ATen/core/ivalue_inl.h>
12 #include <c10/util/Exception.h>
13 #include <c10/util/irange.h>
14 #include <torch/csrc/jit/frontend/script_type_parser.h>
15 #include <torch/csrc/jit/ir/graph_utils.h>
16 #include <torch/csrc/jit/ir/ir.h>
17 #include <torch/csrc/jit/mobile/file_format.h>
18 #include <torch/csrc/jit/mobile/flatbuffer_loader.h>
19 #include <torch/csrc/jit/operator_upgraders/upgraders_entry.h>
20 #include <torch/csrc/jit/passes/shape_analysis.h>
21 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
22 #include <torch/csrc/jit/serialization/import.h>
23 #include <torch/csrc/jit/serialization/import_export_helpers.h>
24 #include <torch/csrc/jit/serialization/import_read.h>
25 #include <torch/csrc/jit/serialization/import_source.h>
26 #include <torch/csrc/jit/serialization/source_range_serialization.h>
27 #include <torch/csrc/jit/serialization/unpickler.h>
28
29 #include <ATen/ATen.h>
30 #include <fmt/format.h>
31
32 #include <string>
33 #include <utility>
34 #include <vector>
35
36 namespace torch::jit {
37
38 using caffe2::serialize::MemoryReadAdapter;
39 using caffe2::serialize::PyTorchStreamReader;
40 using caffe2::serialize::ReadAdapterInterface;
41
// Validate an object produced by a custom `__setstate__`: every attribute
// whose declared type cannot legitimately hold None must have been assigned
// a value, otherwise raise.
static void postSetStateValidate(const IValue& v) {
  auto obj = v.toObject();
  const auto& objType = obj->type();
  for (const auto i : c10::irange(objType->numAttributes())) {
    const auto& attrType = objType->getAttribute(i);
#ifndef STRIP_ERROR_MESSAGES
    // attrName is only used in the error message below; when
    // STRIP_ERROR_MESSAGES is defined TORCH_CHECK discards its message
    // arguments, so the name lookup is compiled out entirely.
    const auto& attrName = objType->getAttributeName(i);
#endif
    const auto& slot = obj->getSlot(i);
    // Verify that all the non-optional attributes have been initialized
    // TODO: Issue #20497
    if (attrType->kind() != TypeKind::UnionType &&
        attrType->kind() != TypeKind::OptionalType &&
        attrType->kind() != TypeKind::NoneType) {
      TORCH_CHECK(
          !slot.isNone(),
          fmt::format(
              "The field '{}' was left uninitialized after '__setstate__', "
              "but expected a value of type '{}'",
              attrName,
              attrType->repr_str()));
    }
  }
}
67
68 // Decouple how to get obj from type. In this file it's dependent on
69 // Method.run() and graph executor, etc.
70 // For bytecode import we need to decouple these dependencies.
ObjLoaderFunc(const at::StrongTypePtr & type,IValue input)71 c10::intrusive_ptr<c10::ivalue::Object> ObjLoaderFunc(
72 const at::StrongTypePtr& type,
73 IValue input) {
74 auto cls = type.type_->expect<at::ClassType>();
75 auto qn = cls->name();
76 size_t n = cls->numAttributes();
77 if (checkHasValidSetGetState(cls)) {
78 auto obj = c10::ivalue::Object::create(type, n);
79 // XXX: Do not optimize __setstate__, so that we don't try to
80 // specialize the class before it is initialized.
81 GraphOptimizerEnabledGuard guard(false);
82 Function& set_state = cls->getMethod("__setstate__");
83 // since we are in the middle of unpickling we might still have lists and
84 // dicts that do not have accurate tags (e.g. they report they are
85 // List[Any]). But we need to run __setstate__ which will check the input
86 // type and may access the tags. Since setstate has a known input type, we
87 // can correctly restore the tags now by apply the input type of set_state
88 // to the state object being passed.
89 // TODO: Remove once [serialization type tags] is landed
90 restoreAccurateTypeTags(
91 input, set_state.getSchema().arguments().at(1).type());
92 set_state({obj, input});
93 postSetStateValidate(obj);
94 return obj;
95 } else {
96 auto dict = std::move(input).toGenericDict();
97 auto obj = c10::ivalue::Object::create(type, n);
98 for (const auto i : c10::irange(n)) {
99 obj->setSlot(i, dict.at(cls->getAttributeName(i)));
100 }
101 return obj;
102 }
103 }
104
105 namespace {
106
107 // This is a deserializer class which loads script modules from pt files.
108 // Content of the file is written using PyTorchStreamWriter, for details please
109 // check caffe2/serialize/inline_container.h.
110 // The module is saved in pickle. readArchive() is called to parse and construct
111 // the constant table and the script module.
class ScriptModuleDeserializer final {
 public:
  // Deserializer for a regular torch.jit.save archive: class sources live
  // under "code/", pickle archives and tensors sit at the zip root.
  ScriptModuleDeserializer(
      std::shared_ptr<CompilationUnit> cu,
      std::shared_ptr<PyTorchStreamReader> reader)
      : compilation_unit_(std::move(cu)),
        reader_(std::move(reader)),
        code_prefix_("code/"),
        pickle_dir_prefix_(""),
        tensor_dir_prefix_(""),
        source_importer_(
            compilation_unit_,
            &constants_table_,
            // Resolve a qualified class name to its source file inside the
            // archive, rooted at code_prefix_.
            [this](const std::string& qualifier) {
              return findSourceInArchiveFromQualifier(
                  *reader_, code_prefix_, qualifier);
            },
            reader_->version()) {}

  // Deserializer for the torch.package unified format: pickle/tensor
  // prefixes point into the .data/ts_code/... layout, and storage_context
  // is threaded through to readArchiveAndTensors (presumably so tensor
  // storages can be shared across modules in one package — see readArchive).
  ScriptModuleDeserializer(
      std::shared_ptr<CompilationUnit> cu,
      std::shared_ptr<PyTorchStreamReader> reader,
      std::string pickle_dir_prefix,
      std::string tensor_dir_prefix,
      std::shared_ptr<DeserializationStorageContext> storage_context)
      : compilation_unit_(std::move(cu)),
        reader_(std::move(reader)),
        storage_context_(std::move(storage_context)),
        code_prefix_(".data/ts_code/code/"),
        pickle_dir_prefix_(std::move(pickle_dir_prefix)),
        tensor_dir_prefix_(std::move(tensor_dir_prefix)),
        source_importer_(
            compilation_unit_,
            &constants_table_,
            [this](const std::string& qualifier) {
              return findSourceInArchiveFromQualifier(
                  *reader_, code_prefix_, qualifier);
            },
            reader_->version()) {}

  // Deserialize the module stored in the archive. `extra_files` is in/out:
  // keys present on entry are filled with the matching "extra/<key>" record
  // contents. `restore_shapes` re-applies recorded traced-input shapes.
  Module deserialize(
      std::optional<at::Device> device,
      ExtraFilesMap& extra_files,
      bool restore_shapes = false);

 private:
  // Unpickle one named archive ("constants", "data", "traced_inputs", ...).
  IValue readArchive(const std::string& archive_name);

  std::shared_ptr<CompilationUnit> compilation_unit_;
  std::shared_ptr<PyTorchStreamReader> reader_;
  std::shared_ptr<DeserializationStorageContext> storage_context_;
  std::optional<at::Device> device_;
  // Shared by address with source_importer_; filled by deserialize() before
  // the main "data" archive is read.
  std::vector<at::IValue> constants_table_;
  std::string code_prefix_;
  std::string pickle_dir_prefix_;
  std::string tensor_dir_prefix_;
  SourceImporter source_importer_;
};
170
readArchive(const std::string & archive_name)171 IValue ScriptModuleDeserializer::readArchive(const std::string& archive_name) {
172 auto type_resolver = [&](const c10::QualifiedName& qn) {
173 auto cls = source_importer_.loadType(qn);
174 return c10::StrongTypePtr(compilation_unit_, std::move(cls));
175 };
176
177 return readArchiveAndTensors(
178 /*archive_name=*/archive_name,
179 /*pickle_prefix=*/pickle_dir_prefix_,
180 /*tensor_prefix=*/tensor_dir_prefix_,
181 type_resolver,
182 ObjLoaderFunc,
183 device_,
184 *reader_,
185 nullptr,
186 storage_context_);
187 }
188
rewriteQuantizedConvForBC(const Module & module)189 void rewriteQuantizedConvForBC(const Module& module) {
190 const std::string& old_quantized_conv2d = R"(
191 graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point):
192 %r = quantized::conv2d(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point)
193 return (%r) )";
194
195 const std::string& old_quantized_conv2d_relu = R"(
196 graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point):
197 %r = quantized::conv2d_relu(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point)
198 return (%r) )";
199
200 const std::string& old_quantized_conv3d = R"(
201 graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point):
202 %r = quantized::conv3d(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point)
203 return (%r) )";
204
205 const std::string& old_quantized_conv3d_relu = R"(
206 graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point):
207 %r = quantized::conv3d_relu(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point)
208 return (%r) )";
209
210 const std::string& new_quantized_conv2d = R"(
211 graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point):
212 %r = quantized::conv2d(%x, %packed_params, %r_scale, %r_zero_point)
213 return (%r) )";
214
215 const std::string& new_quantized_conv2d_relu = R"(
216 graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point):
217 %r = quantized::conv2d_relu(%x, %packed_params, %r_scale, %r_zero_point)
218 return (%r) )";
219
220 const std::string& new_quantized_conv3d = R"(
221 graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point):
222 %r = quantized::conv3d(%x, %packed_params, %r_scale, %r_zero_point)
223 return (%r) )";
224
225 const std::string& new_quantized_conv3d_relu = R"(
226 graph(%x, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point):
227 %r = quantized::conv3d_relu(%x, %packed_params, %r_scale, %r_zero_point)
228 return (%r) )";
229
230 SubgraphRewriter rewriter;
231 static const std::vector<std::pair<std::string, std::string>>
232 patterns_and_replacements = {
233 {old_quantized_conv2d, new_quantized_conv2d},
234 {old_quantized_conv2d_relu, new_quantized_conv2d_relu},
235 {old_quantized_conv3d, new_quantized_conv3d},
236 {old_quantized_conv3d_relu, new_quantized_conv3d_relu},
237 };
238 for (const auto& item : patterns_and_replacements) {
239 rewriter.RegisterRewritePattern(item.first, item.second);
240 }
241 rewriter.runOnModule(module);
242
243 for (const Module& child : module.children()) {
244 rewriteQuantizedConvForBC(child);
245 }
246 }
247
deserialize(std::optional<at::Device> device,ExtraFilesMap & extra_files,bool restore_shapes)248 Module ScriptModuleDeserializer::deserialize(
249 std::optional<at::Device> device,
250 ExtraFilesMap& extra_files,
251 bool restore_shapes) {
252 // we populate the upgraders map before any load starts
253 populate_upgraders_graph_map();
254
255 C10_LOG_API_USAGE_ONCE("torch.jit.load");
256 device_ = device;
257 // Load extra files.
258 for (const auto& kv : extra_files) {
259 const std::string& key = "extra/" + kv.first;
260 if (reader_->hasRecord(key)) {
261 auto [meta_ptr, meta_size] = reader_->getRecord(key);
262 extra_files[kv.first] =
263 std::string(static_cast<char*>(meta_ptr.get()), meta_size);
264 }
265 }
266 if (reader_->hasRecord("model.json") && code_prefix_ == "code/") {
267 AT_ERROR("Legacy model format is not supported on mobile.");
268 }
269 auto tuple = readArchive("constants").toTuple();
270 for (auto constant : tuple->elements()) {
271 constants_table_.push_back(constant.toIValue());
272 }
273 auto m_ivalue = readArchive("data");
274 auto m = Module(m_ivalue.toObject());
275 rewriteQuantizedConvForBC(m);
276 // Checking for and loading saved traced inputs
277 if (restore_shapes && reader_->hasRecord("traced_inputs.pkl")) {
278 auto dict = readArchive("traced_inputs").toGenericDict();
279 for (const auto& entry : dict) {
280 auto inputs = entry.value().toList().vec();
281 auto g =
282 toGraphFunction(m.get_method(entry.key().toStringRef()).function())
283 .graph();
284 Stack stack(inputs.begin(), inputs.end());
285 // Added the module as the first input if we are missing
286 // an input as traced modules refer to self as an additional input
287 if (g->inputs().size() == stack.size() + 1) {
288 stack.insert(stack.begin(), m_ivalue);
289 }
290 setInputTensorTypes(*g, stack, /*complete=*/true);
291 PropagateInputShapes(g);
292 }
293 } else {
294 if (restore_shapes) {
295 TORCH_WARN("Cannot restore shapes as no traced inputs were stored");
296 }
297 }
298 c10::LogAPIUsageMetadata(
299 "torch.script.load.metadata",
300 {{"serialization_id", reader_->serializationId()}});
301 return m;
302 }
303 } // namespace
304
import_ir_module(std::shared_ptr<CompilationUnit> cu,std::istream & in,std::optional<at::Device> device,bool load_debug_files)305 Module import_ir_module(
306 std::shared_ptr<CompilationUnit> cu,
307 std::istream& in,
308 std::optional<at::Device> device,
309 bool load_debug_files) {
310 ExtraFilesMap extra_files;
311 return import_ir_module(
312 std::move(cu), in, device, extra_files, load_debug_files);
313 }
314
// Forward declaration: dispatches on the in-memory byte format
// (flatbuffer vs. zip) and deserializes accordingly. Defined below.
static Module _load_jit_module_from_bytes(
    const std::shared_ptr<char>& data,
    size_t size,
    std::shared_ptr<CompilationUnit> cu,
    std::optional<c10::Device> device,
    ExtraFilesMap& extra_files,
    bool restore_shapes);
322
parse_and_initialize_jit_module(const std::shared_ptr<char> & data,size_t size,ExtraFilesMap & extra_files,std::optional<at::Device> device)323 Module parse_and_initialize_jit_module(
324 const std::shared_ptr<char>& data,
325 size_t size,
326 ExtraFilesMap& extra_files,
327 std::optional<at::Device> device) {
328 populate_upgraders_graph_map();
329 ExtraFilesMap jit_files;
330 std::vector<IValue> jit_constants;
331 mobile::Module mobilem = parse_and_initialize_mobile_module_for_jit(
332 data.get(), size, jit_files, jit_constants, device, &extra_files);
333
334 Module m = jitModuleFromSourceAndConstants(
335 mobilem._ivalue(),
336 jit_files,
337 jit_constants,
338 static_cast<int32_t>(mobilem.bytecode_version()));
339 m.set_delete_memory(data);
340 return m;
341 }
342
load_jit_module_from_file(const std::string & filename,ExtraFilesMap & extra_files,std::optional<at::Device> device)343 Module load_jit_module_from_file(
344 const std::string& filename,
345 ExtraFilesMap& extra_files,
346 std::optional<at::Device> device) {
347 auto data = get_file_content(filename.c_str());
348 return parse_and_initialize_jit_module(
349 std::get<0>(data), std::get<1>(data), extra_files, device);
350 }
351
load_jit_module_from_stream(std::istream & in,ExtraFilesMap & extra_files,std::optional<at::Device> device)352 Module load_jit_module_from_stream(
353 std::istream& in,
354 ExtraFilesMap& extra_files,
355 std::optional<at::Device> device) {
356 auto data = get_stream_content(in);
357 return parse_and_initialize_jit_module(
358 std::get<0>(data), std::get<1>(data), extra_files, device);
359 }
360
import_ir_module(std::shared_ptr<CompilationUnit> cu,std::istream & in,std::optional<at::Device> device,ExtraFilesMap & extra_files,bool load_debug_files,bool restore_shapes)361 Module import_ir_module(
362 std::shared_ptr<CompilationUnit> cu,
363 std::istream& in,
364 std::optional<at::Device> device,
365 ExtraFilesMap& extra_files,
366 bool load_debug_files,
367 bool restore_shapes) {
368 in.seekg(0, in.beg);
369 // NOTE: Zipformat can be large files. So using stream version directly
370 // instead of reading the file all at once.
371 if (getFileFormat(in) != FileFormat::FlatbufferFileFormat) {
372 auto reader = std::make_unique<PyTorchStreamReader>(&in);
373 reader->setShouldLoadDebugSymbol(load_debug_files);
374 ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
375 return deserializer.deserialize(device, extra_files, restore_shapes);
376 }
377 auto [data, size] = get_stream_content(in);
378 return _load_jit_module_from_bytes(
379 data, size, cu, device, extra_files, restore_shapes);
380 }
381
382 // For reading unified serialization format from torch.Package.
import_ir_module(std::shared_ptr<CompilationUnit> cu,std::shared_ptr<PyTorchStreamReader> reader,std::shared_ptr<DeserializationStorageContext> storage_context,std::optional<at::Device> device,const std::string & ts_id)383 Module import_ir_module(
384 std::shared_ptr<CompilationUnit> cu,
385 std::shared_ptr<PyTorchStreamReader> reader,
386 std::shared_ptr<DeserializationStorageContext> storage_context,
387 std::optional<at::Device> device,
388 const std::string& ts_id) {
389 ScriptModuleDeserializer deserializer(
390 std::move(cu),
391 std::move(reader),
392 /* pickle_dir_prefix = */ ".data/ts_code/" + ts_id + "/",
393 /* tensor_dir_prefix = */ ".data/",
394 std::move(storage_context));
395 ExtraFilesMap extra_files;
396 return deserializer.deserialize(device, extra_files);
397 }
398
import_ir_module(std::shared_ptr<CompilationUnit> cu,const std::string & filename,std::optional<at::Device> device,bool load_debug_files)399 Module import_ir_module(
400 std::shared_ptr<CompilationUnit> cu,
401 const std::string& filename,
402 std::optional<at::Device> device,
403 bool load_debug_files) {
404 ExtraFilesMap extra_files;
405 return import_ir_module(
406 std::move(cu), filename, device, extra_files, load_debug_files);
407 }
408
import_ir_module(std::shared_ptr<CompilationUnit> cu,const std::string & filename,std::optional<at::Device> device,ExtraFilesMap & extra_files,bool load_debug_files,bool restore_shapes)409 Module import_ir_module(
410 std::shared_ptr<CompilationUnit> cu,
411 const std::string& filename,
412 std::optional<at::Device> device,
413 ExtraFilesMap& extra_files,
414 bool load_debug_files,
415 bool restore_shapes) {
416 // NOTE: Zipformat can be large files. So using stream version directly
417 // instead of reading the file all at once.
418 if (getFileFormat(filename) != FileFormat::FlatbufferFileFormat) {
419 auto reader = std::make_unique<PyTorchStreamReader>(filename);
420 reader->setShouldLoadDebugSymbol(load_debug_files);
421 ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
422 return deserializer.deserialize(device, extra_files, restore_shapes);
423 }
424 auto [data, size] = get_file_content(filename.c_str());
425 return _load_jit_module_from_bytes(
426 data, size, cu, device, extra_files, restore_shapes);
427 }
428
import_ir_module(std::shared_ptr<CompilationUnit> cu,std::unique_ptr<ReadAdapterInterface> rai,std::optional<at::Device> device,bool load_debug_files)429 Module import_ir_module(
430 std::shared_ptr<CompilationUnit> cu,
431 std::unique_ptr<ReadAdapterInterface> rai,
432 std::optional<at::Device> device,
433 bool load_debug_files) {
434 ExtraFilesMap extra_files;
435 return import_ir_module(
436 std::move(cu), std::move(rai), device, extra_files, load_debug_files);
437 }
438
import_ir_module(std::shared_ptr<CompilationUnit> cu,std::unique_ptr<ReadAdapterInterface> rai,std::optional<at::Device> device,ExtraFilesMap & extra_files,bool load_debug_files)439 Module import_ir_module(
440 std::shared_ptr<CompilationUnit> cu,
441 std::unique_ptr<ReadAdapterInterface> rai,
442 std::optional<at::Device> device,
443 ExtraFilesMap& extra_files,
444 bool load_debug_files) {
445 std::shared_ptr<ReadAdapterInterface> rai_shared = std::move(rai);
446 return import_ir_module(
447 std::move(cu), rai_shared, device, extra_files, load_debug_files);
448 }
449
import_ir_module(std::shared_ptr<CompilationUnit> cu,std::shared_ptr<ReadAdapterInterface> rai,std::optional<at::Device> device,ExtraFilesMap & extra_files,bool load_debug_files)450 Module import_ir_module(
451 std::shared_ptr<CompilationUnit> cu,
452 std::shared_ptr<ReadAdapterInterface> rai,
453 std::optional<at::Device> device,
454 ExtraFilesMap& extra_files,
455 bool load_debug_files) {
456 auto reader = std::make_shared<PyTorchStreamReader>(std::move(rai));
457 reader->setShouldLoadDebugSymbol(load_debug_files);
458 ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
459 return deserializer.deserialize(device, extra_files);
460 }
461
load(std::istream & in,std::optional<at::Device> device,bool load_debug_files)462 Module load(
463 std::istream& in,
464 std::optional<at::Device> device,
465 bool load_debug_files) {
466 auto cu = std::make_shared<CompilationUnit>();
467 return import_ir_module(std::move(cu), in, device, load_debug_files);
468 }
469
load(std::istream & in,std::optional<at::Device> device,ExtraFilesMap & extra_files,bool load_debug_files)470 Module load(
471 std::istream& in,
472 std::optional<at::Device> device,
473 ExtraFilesMap& extra_files,
474 bool load_debug_files) {
475 auto cu = std::make_shared<CompilationUnit>();
476 return import_ir_module(
477 std::move(cu), in, device, extra_files, load_debug_files);
478 }
479
load(const std::string & filename,std::optional<at::Device> device,bool load_debug_files)480 Module load(
481 const std::string& filename,
482 std::optional<at::Device> device,
483 bool load_debug_files) {
484 auto cu = std::make_shared<CompilationUnit>();
485 return import_ir_module(std::move(cu), filename, device, load_debug_files);
486 }
487
load(const std::string & filename,std::optional<at::Device> device,ExtraFilesMap & extra_files,bool load_debug_files)488 Module load(
489 const std::string& filename,
490 std::optional<at::Device> device,
491 ExtraFilesMap& extra_files,
492 bool load_debug_files) {
493 auto cu = std::make_shared<CompilationUnit>();
494 return import_ir_module(
495 std::move(cu), filename, device, extra_files, load_debug_files);
496 }
497
load(std::shared_ptr<ReadAdapterInterface> rai,std::optional<c10::Device> device,bool load_debug_files)498 Module load(
499 std::shared_ptr<ReadAdapterInterface> rai,
500 std::optional<c10::Device> device,
501 bool load_debug_files) {
502 auto cu = std::make_shared<CompilationUnit>();
503 ExtraFilesMap extra_files;
504 return import_ir_module(
505 std::move(cu), std::move(rai), device, extra_files, load_debug_files);
506 }
507
load(std::shared_ptr<ReadAdapterInterface> rai,std::optional<c10::Device> device,ExtraFilesMap & extra_files,bool load_debug_files)508 Module load(
509 std::shared_ptr<ReadAdapterInterface> rai,
510 std::optional<c10::Device> device,
511 ExtraFilesMap& extra_files,
512 bool load_debug_files) {
513 auto cu = std::make_shared<CompilationUnit>();
514 return import_ir_module(
515 std::move(cu), std::move(rai), device, extra_files, load_debug_files);
516 }
517
_load_jit_module_from_bytes(const std::shared_ptr<char> & data,size_t size,std::shared_ptr<CompilationUnit> cu,std::optional<c10::Device> device,ExtraFilesMap & extra_files,bool restore_shapes)518 Module _load_jit_module_from_bytes(
519 const std::shared_ptr<char>& data,
520 size_t size,
521 std::shared_ptr<CompilationUnit> cu,
522 std::optional<c10::Device> device,
523 ExtraFilesMap& extra_files,
524 bool restore_shapes) {
525 TORCH_CHECK(size >= kFileFormatHeaderSize, "Unrecognized data format");
526 auto format = getFileFormat(data.get());
527 switch (format) {
528 case FileFormat::FlatbufferFileFormat: {
529 return parse_and_initialize_jit_module(data, size, extra_files, device);
530 }
531 case FileFormat::ZipFileFormat: {
532 auto rai = std::make_unique<MemoryReadAdapter>(data.get(), size);
533 auto reader = std::make_unique<PyTorchStreamReader>(std::move(rai));
534 ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
535 return deserializer.deserialize(device, extra_files, restore_shapes);
536 }
537
538 default:
539 TORCH_CHECK(false, "Unrecognized data format");
540 }
541 }
542
// Replace an object with a newly created but equivalent object.
// The goal is to replace the object's methods. However, since an object's
// methods are attached to its type, we need to replace its type.
// Non-objects are unchanged; however, nested structures such as lists and
// dicts are also reconstructed because they might contain an object.
recreateObject(IValue ivalue,const TypeResolver & resolver)548 static IValue recreateObject(IValue ivalue, const TypeResolver& resolver) {
549 if (ivalue.isObject()) {
550 auto obj = ivalue.toObject();
551 auto classtype_old = obj->type();
552 auto newtype = resolver(*classtype_old->name());
553 size_t n = classtype_old->numAttributes();
554 auto newobj = c10::ivalue::Object::create(newtype, n);
555 for (const auto i : c10::irange(n)) {
556 newobj->setSlot(i, recreateObject(obj->getSlot(i), resolver));
557 }
558 return newobj;
559 } else if (ivalue.isList()) {
560 auto res = c10::impl::GenericList(ivalue.type()->containedType(0));
561 for (const auto& ival : ivalue.toList()) {
562 res.emplace_back(recreateObject(ival, resolver));
563 }
564 return res;
565 } else if (ivalue.isGenericDict()) {
566 auto result = c10::impl::GenericDict(
567 ivalue.type()->containedType(0), ivalue.type()->containedType(1));
568 for (const auto& kv : ivalue.toGenericDict()) {
569 result.insert_or_assign(
570 recreateObject(kv.key(), resolver),
571 recreateObject(kv.value(), resolver));
572 }
573 return result;
574 } else if (ivalue.isTuple()) {
575 std::vector<IValue> res;
576 for (const auto& ival : ivalue.toTuple()->elements()) {
577 res.push_back(recreateObject(ival, resolver));
578 }
579 return c10::ivalue::Tuple::create(res);
580 }
581 // Leaf types are returned verbatim.
582 return ivalue;
583 }
584
// Rebuild a full JIT Module from a deserialized module object plus its
// class sources and constant table (used by the flatbuffer/mobile path).
Module jitModuleFromSourceAndConstants(
    const IValue& ivalue,
    const ExtraFilesMap& source,
    const std::vector<IValue>& constants,
    int32_t version) {
  auto compilation_unit = std::make_shared<CompilationUnit>();
  SourceImporter importer(
      compilation_unit,
      &constants,
      // Look up a class's source text by qualified name; returning nullptr
      // signals "not found" and lets the importer report the failure.
      [&source](const std::string& qualifier) -> std::shared_ptr<Source> {
        auto source_iter = source.find(qualifier);
        if (source_iter == source.end()) {
          return nullptr;
        }
        return std::make_shared<Source>(
            source_iter->second, qualifier, 1, nullptr, Source::COPIES_STRING);
      },
      version);
  auto type_resolver = [&](const c10::QualifiedName& qn) {
    auto cls = importer.loadType(qn);
    return c10::StrongTypePtr(compilation_unit, std::move(cls));
  };
  // Swap every contained object's type for the freshly imported one so its
  // methods resolve against `compilation_unit`.
  auto newIvalue = recreateObject(ivalue, type_resolver).toObject();
  Module m(newIvalue);
  rewriteQuantizedConvForBC(m);
  return m;
}
612
613 } // namespace torch::jit
614