1 #include <ATen/core/ivalue.h>
2 #include <c10/util/Exception.h>
3 #include <caffe2/serialize/file_adapter.h>
4 #include <caffe2/serialize/inline_container.h>
5 #include <torch/csrc/jit/mobile/compatibility/backport_manager.h>
6 #include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
7 #include <torch/csrc/jit/mobile/import.h>
8 #include <torch/csrc/jit/mobile/module.h>
9 #include <torch/csrc/jit/serialization/export.h>
10 #include <torch/csrc/jit/serialization/import.h>
11 #include <torch/csrc/jit/serialization/pickler.h>
12 #include <cstddef>
13 #include <sstream>
14
15 namespace torch::jit {
16
17 using caffe2::serialize::PyTorchStreamReader;
18 using caffe2::serialize::PyTorchStreamWriter;
19
// Currently supported bytecode versions
21 namespace {
22 constexpr int64_t kBytecodeVersionV4 = 0x4L;
23 constexpr int64_t kBytecodeVersionV5 = 0x5L;
24 constexpr int64_t kBytecodeVersionV6 = 0x6L;
25 constexpr int64_t kBytecodeVersionV7 = 0x7L;
26 constexpr int64_t kBytecodeVersionV8 = 0x8L;
27 constexpr int64_t kBytecodeVersionV9 = 0x9L;
28 } // namespace
29
30 /********************** Utility Functions **********************/
31
32 // Utility function that can be reused by backport_vn_to_vn-1(). If any utility
33 // function can be reused by other backport function, move it here.
34 namespace {
35 // Copy files from source to destination except the files and dirs
selective_copy(PyTorchStreamReader & reader,PyTorchStreamWriter & writer,const std::unordered_set<std::string> & excluded_files,const std::unordered_set<std::string> & excluded_dirs)36 void selective_copy(
37 PyTorchStreamReader& reader,
38 PyTorchStreamWriter& writer,
39 const std::unordered_set<std::string>& excluded_files,
40 const std::unordered_set<std::string>& excluded_dirs) {
41 auto records = reader.getAllRecords();
42 for (const auto& record : records) {
43 // Don't copy archive in excluded_files, usually archive `version` and
44 // `bytecode`. Archive `version` will be written when PyTorchStreamWriter is
45 // going to finalize and run writeEndOfFile()
46
47 // records is the list of all files names in the zip file, and each record
48 // is one file with path to parent folder, the example records is:
49 // data.pkl
50 // code/__torch__/___torch_mangle_5.py
51 // code/__torch__/___torch_mangle_5.py.debug_pkl
52 // constants/140245072983168.storage
53 // constants.pkl
54 // bytecode.pkl
55 // version
56 bool skip = excluded_files.count(record) > 0;
57
58 // Skip dirs, find the last '/' and compare it with record
59 for (const auto& excluded_dir : excluded_dirs) {
60 std::size_t found = record.find_last_of("/\\");
61 auto path = record.substr(0, found);
62 if (excluded_dir == path) {
63 skip = true;
64 break;
65 }
66 }
67 if (!skip) {
68 auto data_ptr = reader.getRecord(record);
69 auto data = std::get<0>(data_ptr).get();
70 auto size = std::get<1>(data_ptr);
71 writer.writeRecord(record, data, size);
72 }
73 }
74 }
75
// The write_archive_current function is used for bytecode from version v5
// onward; pre-v5 we serialized things differently.
// This write archive function may change in export_module.cpp, however we don't
// have a way to keep the old export function in the codebase. To be able to
// export the model in an old format, we keep a record of the export function
// here.
// Pickle `value` and write it as {archive_dir}{archive_name}.pkl, writing each
// referenced tensor's storage under `tensor_dir`. When `use_storage_context`
// is true, storages are named by pointer identity and deduplicated via
// `storage_context` so they are serialized at most once across archives.
void write_archive_current(
    PyTorchStreamWriter& writer,
    const IValue& value,
    const std::string& archive_name,
    const std::string& archive_dir,
    const std::string& tensor_dir,
    bool use_storage_context,
    SerializationStorageContext& storage_context) {
  std::vector<char> data;
  // Vector to capture the run-time class types during pickling the IValues
  std::vector<c10::ClassTypePtr> memoizedClassTypes;
  std::vector<std::string> tensor_names;
  Pickler data_pickle(
      [&](const char* buf, size_t size) {
        data.insert(data.end(), buf, buf + size);
      },
      nullptr,
      nullptr,
      &memoizedClassTypes,
      [&](const at::Tensor& tensor) {
        // returns a string to use in picker.cpp as storage obj key
        if (use_storage_context) {
          // Name the record after the storage pointer so the same underlying
          // storage always maps to the same file name.
          std::string string_id =
              std::to_string(reinterpret_cast<std::intptr_t>(
                  tensor.storage().unsafeGetStorageImpl()));
          tensor_names.push_back(string_id + ".storage");
          storage_context.getOrAddStorage(tensor.storage());
        } else {
          // Without a shared storage context, fall back to sequential names.
          tensor_names.push_back(std::to_string(tensor_names.size()));
        }
        return tensor_names.back();
      });
  data_pickle.protocol();
  data_pickle.pushIValue(value);
  data_pickle.stop();
  // write out tensor data
  size_t i = 0;

  // One name was recorded per pickled tensor, so the two lists must align.
  TORCH_INTERNAL_ASSERT(tensor_names.size() == data_pickle.tensorData().size());
  const std::unordered_set<std::string>& pre_serialized_files =
      writer.getAllWrittenRecords();

  for (const auto& td : data_pickle.tensorData()) {
    WriteableTensorData writable_td = getWriteableTensorData(td);
    std::string fname = tensor_dir + tensor_names[i++];
    if (use_storage_context &&
        pre_serialized_files.find(fname) != pre_serialized_files.end()) {
      // storage has been serialized already, skip
      continue;
    }
    writer.writeRecord(fname, writable_td.data(), writable_td.sizeInBytes());
  }

  std::string fname = archive_dir + archive_name + ".pkl";
  writer.writeRecord(fname, data.data(), data.size());
}
137
138 /*
139 inputs: 1) bytecode tuple from bytecode.pkl 2) the output bytecode version,
140 return: A boolean to indicate whether bytecode tuple is updated successfully
141 */
update_bytecode_version(std::vector<at::IValue> & bytecode_values,const int64_t to_version)142 bool update_bytecode_version(
143 std::vector<at::IValue>& bytecode_values,
144 const int64_t to_version) {
145 if (!bytecode_values.empty() && bytecode_values[0].isInt()) {
146 bytecode_values[0] = c10::IValue(to_version);
147 return true;
148 }
149 return false;
150 }
151
152 /*
153 inputs: 1) input model stringstream 2) the output bytecode version,
154 return: model stringstream with updated bytecode version in bytecode.pkl
155
156 Example bytecode.pkl:
157 (${bytecode_version},
158 ('__torch__.m.forward',
159 (('instructions',
160 (('STOREN', 1, 2),
161 ('DROPR', 1, 0),
162 ('MOVE', 2, 0),
163 ('OP', 0, 0),
164 ('RET', 0, 0))),
165 ('operators', (('aten::Int', 'Tensor'),)),
166 ('constants', ()),
167 ('types', ()),
168 ('register_size', 2))))
169 */
std::stringstream update_bytecode_version(
    std::stringstream& input_model_stream,
    const int64_t to_version) {
  PyTorchStreamReader reader_bytecode(&input_model_stream);
  // Both the constants and bytecode archives are re-written below, so read
  // them out of the input model first.
  auto constants_values =
      std::move(*readArchive(kArchiveNameConstants, reader_bytecode).toTuple())
          .elements();

  std::vector<IValue> bytecode_values = get_bytecode_ivalues(reader_bytecode);
  std::unordered_set<std::string> excluded_files{
      "constants.pkl", "bytecode.pkl"};

  std::unordered_set<std::string> excluded_dirs{
      "constants",
      "bytecode",
  };

  std::stringstream output_model_stream;
  // Writer callback: append to output_model_stream; returning 0 signals a
  // write failure to PyTorchStreamWriter.
  auto writer_func = [&](const void* buf, size_t nbytes) -> size_t {
    output_model_stream.write(static_cast<const char*>(buf), nbytes);
    return !output_model_stream ? 0 : nbytes;
  };

  PyTorchStreamWriter writer_bytecode(writer_func);

  // Copy every record except the constants/bytecode archives, which are
  // re-emitted below with the updated version number.
  selective_copy(
      reader_bytecode, writer_bytecode, excluded_files, excluded_dirs);

  // Overwrite the leading version element of the bytecode tuple.
  update_bytecode_version(bytecode_values, to_version);
  auto bytecode_tuple = c10::ivalue::Tuple::create(std::move(bytecode_values));
  // A shared storage context so tensors referenced by both archives are
  // serialized only once, under `constants/`.
  SerializationStorageContext storage_context;
  write_archive_current(
      writer_bytecode,
      c10::ivalue::Tuple::create(std::move(constants_values)),
      /*archive_name=*/"constants",
      /*archive_dir=*/"",
      /*tensor_dir=*/"constants/",
      /*use_storage_context=*/true,
      storage_context);
  write_archive_current(
      writer_bytecode,
      bytecode_tuple,
      /*archive_name=*/"bytecode",
      /*archive_dir=*/"",
      /*tensor_dir=*/"constants/",
      /*use_storage_context=*/true,
      storage_context);

  return output_model_stream;
}
220 } // namespace
221
222 /******************** backport_v{i}_to_v{i-1} Functions **********************/
223
224 /*
225 To add next backport function, for example, backport_vn_to_vn-1, create an
226 anonymous namespace with a backport_vn_to_vn-1 function + other necessary
227 customized function. If a function can be reused by other backport functions,
228 move it to the utility function group. It will be easier to split out
229 backport_manager.cpp to smaller files when it grows too long.
230
231 How to add backport_v{i}_to_v{i-1} ?
232 There are two options:
1) [Format change only, recommended] Construct a reader with the
input_model_stream, modify the file, and use PyTorchWriter to write it to
output_model_stream. See backport_v5_to_v4.

2) [Both format and content change] Use torch.jit.load() to load the stream,
and save it to output_model_stream.
239
240 The first option is preferred, because it will be purely format change, and
241 the model doesn't need to go through inline again and model content will
242 remain the same.
243
A note on manipulating stringstreams: it's recommended to declare a new
245 stringstream, tmp_stream, and swap it with the argument output_model_stream
246 once it's ready, output_model_stream.swap(tmp_stream). Do not use
247 output_model_stream.clear(). It only clears out error state flag
248 (https://www.cplusplus.com/reference/ios/ios/clear/), while the content is the
249 same. It's cleaner to just declare a new one and swap.
250
251 */
252
253 namespace {
254
255 /*
The following functions are needed for backporting a model from v5 to v4.
Backport function for bytecode v5, which deduplicated the constants table.
258 Previously, in v4, constant table will be exported twice, in both archive
259 bytecode and archive constants, and majority (almost all) are duplicates.
260 Currently, in v5, JIT and mobile will share archive constants, and all
261 constant tensors will be exported in this archive. The bump was needed
262 because the v5 bytecode export the tensor storage path in the schema, since
263 the runtime code is now able to query which archive this tensor is stored at
264 and query the correct archive.
265 For example, Previously, in v4, we deserialize tensor as without archive
266 path, and mobile will always read tensor from bytecode archive:
267 (torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage,
268 '0', 'cpu', 8),),
269 0,
270 (2, 4),
271 (4, 1),
272 False,
273 collections.OrderedDict()),
274 1)),
275 So, if the program defines: torch.add(x, h, out=x)
276 Currently, in v5, we deserialize the bytecode with the archive path, and
277 mobile can read tensor from the given path:
278 (torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage,
279 'constants/0', 'cpu', 8),),
280 0,
281 (2, 4),
282 (4, 1),
283 False,
284 collections.OrderedDict()),
285 1)),
286 Thus, the backport is necessary such that the runtime can read tensor from
287 the correct archive.
288 */
std::stringstream backport_v5_to_v4(std::stringstream& input_model_stream) {
  // 1) read from archive `bytecode` archive
  PyTorchStreamReader reader(&input_model_stream);
  std::vector<IValue> bytecode_values = get_bytecode_ivalues(reader);
  auto constants_values =
      std::move(*readArchive(kArchiveNameConstants, reader).toTuple())
          .elements();

  // 2) Copy everything to new output, except some specific files and dirs
  // (usually version, bytecode.pkl and bytecode folder are skipped)
  std::unordered_set<std::string> excluded_files{
      "constants.pkl", "bytecode.pkl"};

  std::unordered_set<std::string> excluded_dirs{
      "constants",
      "bytecode",
  };

  std::stringstream output_model_stream;
  // Writer callback: append bytes to output_model_stream; returning 0 signals
  // a write failure to PyTorchStreamWriter.
  auto writer_func = [&](const void* buf, size_t nbytes) -> size_t {
    output_model_stream.write(static_cast<const char*>(buf), nbytes);
    return !output_model_stream ? 0 : nbytes;
  };

  PyTorchStreamWriter writer(writer_func);

  selective_copy(reader, writer, excluded_files, excluded_dirs);

  // 3) write `bytecode` archive
  // Update the bytecode version in bytecode.pkl
  update_bytecode_version(bytecode_values, kBytecodeVersionV4);
  // Construct the list of ivalues to a big tuple
  auto bytecode_tuple = c10::ivalue::Tuple::create(std::move(bytecode_values));

  // The export function to generate bytecode.pkl for version 4. After bytecode
  // version bump, the old export function doesn't exist anymore, so keep a copy
  // here for backport purpose.
  auto writeArchiveV4 = [](PyTorchStreamWriter& writer,
                           const std::string& archive_name,
                           const c10::IValue& value) {
    std::vector<char> data;

    // Vector to capture the run-time class types during pickling the IValues
    std::vector<c10::ClassTypePtr> memoizedClassTypes;
    Pickler data_pickle(
        [&](const char* buf, size_t size) {
          data.insert(data.end(), buf, buf + size);
        },
        nullptr,
        nullptr,
        &memoizedClassTypes);
    data_pickle.protocol();
    data_pickle.pushIValue(value);
    data_pickle.stop();
    size_t i = 0;
    std::string prefix = archive_name + "/";

    // In v4, tensors live under the archive's own folder with sequential
    // numeric names (no shared `constants/` storage paths).
    for (const auto& td : data_pickle.tensorData()) {
      WriteableTensorData writable_td = getWriteableTensorData(td);
      std::string fname = prefix + std::to_string(i++);
      writer.writeRecord(fname, writable_td.data(), writable_td.sizeInBytes());
    }
    std::string fname = archive_name + ".pkl";
    writer.writeRecord(fname, data.data(), data.size());
  };

  // write `bytecode` archive
  writeArchiveV4(writer, kArchiveNameBytecode, bytecode_tuple);
  // write `constants` archive
  auto constants_tuple =
      c10::ivalue::Tuple::create(std::move(constants_values));
  writeArchiveV4(writer, kArchiveNameConstants, constants_tuple);
  return output_model_stream;
}
363
364 /*
365 Backport function bytecode v6 that introduced support for operators with default
366 arguments in mobile. Previously, in v5, there is no number of specified
367 arguments for operators in bytecode operator table. In v6, operators are aware
368 of the number of specified arguments being present in the schema.
369
370 The bump was needed because the v6 bytecode specifies number of specified
371 arguments for operators in the schema, since the runtime code is now able to
372 query the number of specified arguments and supports default arguments.
373
374 For example, aten::foo's schema in v5 is
375 foo(Tensor a, Tensor b) -> Tensor
376 and in v6, it's
377 foo(Tensor a, Tensor b, int groups=1) -> Tensor
378
379 Accordingly, the operator table in v5 is:
380 ('operators', (('aten::foo', ''),))
381 and in v6, it's
382 ('operators', (('aten::foo', '', 2),))
383
384 Thus, the backport is necessary such that the bytecode operator table contains
385 number of specified arguments.
386 */
std::stringstream backport_v6_to_v5(std::stringstream& input_model_stream) {
  auto rai =
      std::make_shared<caffe2::serialize::IStreamAdapter>(&input_model_stream);
  auto reader = std::make_shared<PyTorchStreamReader>(rai);

  // If there are debug info files in the original model file, it should also
  // show up in the backported model
  bool hasBytecodeDebug = reader->hasRecord("mobile_debug_handles.pkl");

  // extra_files are kept: collect the names of all records under `extra/` so
  // they survive the load/save round trip below.
  auto records = reader->getAllRecords();
  ExtraFilesMap extra_files;
  for (const auto& record : records) {
    std::size_t found = record.find_last_of("/\\");
    auto path = record.substr(0, found);
    if ("extra" == path) {
      extra_files.emplace(record.substr(found + 1), "");
    }
  }
  // Loading the TS module is required for this backport, because bytecode needs
  // to be re-emitted (refer to the comments below)
  Module torch_script = torch::jit::load(rai, std::nullopt, extra_files);

  // The RAII guard to change the flag, emitBytecodeDefaultInputs, to true, so
  // that TS stores the default argument values in the constant table, and emits
  // the instructions (LOADC, for example), to push the values to the stack. It
  // restores the behavior of V5 and before. For V6, the default arg values are
  // resolved at runtime init stage for better operator compatibility.
  std::stringstream intermediate_model_stream;
  {
    BytecodeEmitModeGuard argNumGuard(
        true /*emit_default_input_instructions*/,
        false /*enable_defaults_args_with_out_args*/,
        false /*enable_emit_promoted_ops*/);
    torch_script._save_for_mobile(
        intermediate_model_stream, extra_files, hasBytecodeDebug);
  }

  // Update the bytecode version (from 6 to 5)
  std::stringstream output_model_stream =
      update_bytecode_version(intermediate_model_stream, kBytecodeVersionV5);
  return output_model_stream;
}
430
431 /*
432 Backport function bytecode v7 that introduced support for operators with out
433 arguments. Previously, in v6, operators with out arguments forced the
434 serialization of all arguments in the schema, even when optional arguments
435 were not provided (as they had default values). Currently, in v7, operators
436 are aware of out arguments being present in the schema (always appended),
437 allowing the serialization of only required arguments (as default values will
438 be provided by the runtime).
439
440 The bump was needed because the v7 bytecode specifies less arguments for ops
441 with out arguments in the schema, since the runtime code is now able to query
442 whether an argument is of type "out" and insert the necessary default values in
443 the right order in the interpreter stack (i.e. before the out arguments).
444
445 For example schema is: torch.add(x, h, alpha=1.0, out=x) So, if the program
446 defines: torch.add(x, h, out=x) Previously, in v6, we serialized the bytecode to
447 contain all 4 arguments. Currently, in v7, we serialize the bytecode with only 3
448 arguments, since alpha is optional and has a default value that the runtime will
449 push in the stack. Thus, the backport is necessary such that the bytecode
450 contains all the arguments as before.
451 */
std::stringstream backport_v7_to_v6(std::stringstream& input_model_stream) {
  auto rai =
      std::make_shared<caffe2::serialize::IStreamAdapter>(&input_model_stream);
  auto reader = std::make_shared<PyTorchStreamReader>(rai);
  // NOTE(review): constants_values is never used below in this function;
  // presumably readArchive is kept for parity with the other backports or for
  // its side effects on the reader — confirm before removing.
  auto constants_values =
      std::move(*readArchive(kArchiveNameConstants, *reader).toTuple())
          .elements();

  // If there are debug info files in the original model file, it should also
  // show up in the backported model
  bool hasBytecodeDebug = reader->hasRecord("mobile_debug_handles.pkl");

  // extra_files are kept: collect the names of all records under `extra/` so
  // they survive the load/save round trip below.
  auto records = reader->getAllRecords();
  ExtraFilesMap extra_files;
  for (const auto& record : records) {
    std::size_t found = record.find_last_of("/\\");
    auto path = record.substr(0, found);
    if ("extra" == path) {
      extra_files.emplace(record.substr(found + 1), "");
    }
  }
  // Loading the TS module is required for this backport, because bytecode needs
  // to be re-emitted (refer to the comments below)
  Module torch_script = torch::jit::load(rai, std::nullopt, extra_files);

  // The RAII guard to change the flag, emit_default_input_instructions, to
  // false to keep the same behavior in bytecode version 6. Change the flag,
  // enable_defaults_args_with_out_args, to deserialized the number of specified
  // operators which allowing both out arguments and default arguments to
  // #all_args, instead of (#all_args - #default_args)
  std::stringstream intermediate_model_stream;
  {
    BytecodeEmitModeGuard argNumGuard(
        false /*emit_default_input_instructions*/,
        false /*enable_defaults_args_with_out_args*/,
        false /*enable_emit_promoted_ops*/);
    torch_script._save_for_mobile(
        intermediate_model_stream, extra_files, hasBytecodeDebug);
  }

  // Update the bytecode version (from 7 to 6)
  std::stringstream output_model_stream =
      update_bytecode_version(intermediate_model_stream, kBytecodeVersionV6);
  return output_model_stream;
}
498
backport_v9_to_v8(std::stringstream & input_model_stream)499 std::stringstream backport_v9_to_v8(std::stringstream& input_model_stream) {
500 ExtraFilesMap extra_files;
501 Module torch_script =
502 torch::jit::load(input_model_stream, std::nullopt, extra_files);
503 std::stringstream intermediate_model_stream;
504 // TODO(@pavithran) : Check if debug info is available and use load/save while
505 // backporting hardcode debaug info to be false untill supported.
506 bool hasBytecodeDebug = false;
507 {
508 BytecodeEmitModeGuard argNumGuard(
509 false /*emit_default_input_instructions*/,
510 true /*enable_defaults_args_with_out_args*/,
511 true /*enable_emit_promoted_ops*/);
512 torch_script._save_for_mobile(
513 intermediate_model_stream,
514 extra_files,
515 hasBytecodeDebug,
516 /*use_flatbuffer=*/false);
517 }
518 // Update the bytecode version (from 9 to 8)
519 std::stringstream output_model_stream =
520 update_bytecode_version(intermediate_model_stream, kBytecodeVersionV8);
521
522 return output_model_stream;
523 }
524
backport_v8_to_v7(std::stringstream & input_model_stream)525 std::stringstream backport_v8_to_v7(std::stringstream& input_model_stream) {
526 auto rai =
527 std::make_shared<caffe2::serialize::IStreamAdapter>(&input_model_stream);
528 auto reader = std::make_shared<PyTorchStreamReader>(rai);
529 // extra_files are kept
530 auto records = reader->getAllRecords();
531 bool hasBytecodeDebug = reader->hasRecord("mobile_debug_handles.pkl");
532 ExtraFilesMap extra_files;
533 for (const auto& record : records) {
534 std::size_t found = record.find_last_of("/\\");
535 auto path = record.substr(0, found);
536 if ("extra" == path) {
537 extra_files.emplace(record.substr(found + 1), "");
538 }
539 }
540 Module torch_script = torch::jit::load(rai, std::nullopt, extra_files);
541 std::stringstream intermediate_model_stream;
542 {
543 BytecodeEmitModeGuard argNumGuard(
544 false /*emit_default_input_instructions*/,
545 true /*enable_defaults_args_with_out_args*/,
546 false /*enable_emit_promoted_ops*/);
547 torch_script._save_for_mobile(
548 intermediate_model_stream, extra_files, hasBytecodeDebug);
549 }
550
551 // Update the bytecode version (from 8 to 7)
552 std::stringstream output_model_stream =
553 update_bytecode_version(intermediate_model_stream, kBytecodeVersionV7);
554
555 return output_model_stream;
556 }
557
558 } // namespace
559
560 /********************** BackportManager **********************/
561
// A generic contract for backport logic to the previous bytecode version.
// Args:
// * input stringstream holds the model at bytecode version N.
// Returns a stringstream holding the same model backported to the previous
// bytecode version N-1.
567 using BytecodeBackportFunction =
568 std::function<std::stringstream(std::stringstream&)>;
569
BackportManager()570 BackportManager::BackportManager() {
571 registerBytecodeBackportFunction(kBytecodeVersionV5, backport_v5_to_v4);
572 registerBytecodeBackportFunction(kBytecodeVersionV6, backport_v6_to_v5);
573 registerBytecodeBackportFunction(kBytecodeVersionV7, backport_v7_to_v6);
574 registerBytecodeBackportFunction(kBytecodeVersionV8, backport_v8_to_v7);
575 registerBytecodeBackportFunction(kBytecodeVersionV9, backport_v9_to_v8);
576 }
577
578 std::unordered_map<
579 int64_t,
580 std::function<std::stringstream(std::stringstream&)>>&
bytecodeBackportFunctions() const581 BackportManager::bytecodeBackportFunctions() const {
582 static std::unordered_map<
583 int64_t,
584 std::function<std::stringstream(std::stringstream&)>>
585 backport_functions;
586 return backport_functions;
587 }
588
hasBytecodeBackportFunction(const int64_t from_version) const589 bool BackportManager::hasBytecodeBackportFunction(
590 const int64_t from_version) const {
591 return bytecodeBackportFunctions().count(from_version);
592 }
593
registerBytecodeBackportFunction(const int64_t from_version,const BytecodeBackportFunction & backport_function)594 void BackportManager::registerBytecodeBackportFunction(
595 const int64_t from_version,
596 const BytecodeBackportFunction& backport_function) {
597 TORCH_CHECK(
598 !hasBytecodeBackportFunction(from_version),
599 "Backporting from version ",
600 from_version,
601 " is already registered.");
602 bytecodeBackportFunctions()[from_version] = backport_function;
603 }
604
605 // The main function to run backport from version n to version i.
606 // All models (file or buffer) will be converted stream first, and
607 // istream_adapter has access to it. During the backport process,
608 // the intermediate result will be stored with stream.
backport(std::istream & oss,PyTorchStreamWriter & final_writer,int64_t from_version,int64_t to_version) const609 bool BackportManager::backport(
610 std::istream& oss,
611 PyTorchStreamWriter& final_writer,
612 int64_t from_version,
613 int64_t to_version) const {
614 if (from_version <= to_version) {
615 TORCH_WARN(
616 "backport donesn't support backporting model to new version. It's trying to backport from version ",
617 from_version,
618 " to version ",
619 to_version);
620 return false;
621 }
622 int64_t bytecode_version = from_version;
623 bool backport_success = true;
624
625 // 1) Given an istream_adapter (an adapter with access to the input model, the
626 // model can be from istream, file and etc), copy all model content to
627 // stringstream
628 oss.seekg(0, std::ios::beg);
629 std::stringstream input_model_stream;
630 input_model_stream << oss.rdbuf();
631 std::stringstream output_model_stream;
632
633 // 2) backport model, backport_v{i}_to_v{i-1} function's argurment is
634 // (input_model_stream and output_model_stream)
635 while (bytecode_version > to_version) {
636 // Swap input and output if it's not the first time and output_model_stream
637 // has value.
638 if (!output_model_stream.str().empty()) {
639 input_model_stream.swap(output_model_stream);
640 // reset output_model_stream
641 output_model_stream.str("");
642 }
643
644 if (!hasBytecodeBackportFunction(bytecode_version)) {
645 return false;
646 }
647
648 input_model_stream.seekg(0, input_model_stream.beg);
649 auto input_model_stream_version =
650 _get_model_bytecode_version(input_model_stream);
651
652 if (static_cast<int64_t>(input_model_stream_version) != bytecode_version) {
653 TORCH_WARN(
654 "The bytecode version of input model stream is supposed to be ",
655 bytecode_version,
656 ", but it gets ",
657 input_model_stream_version);
658 return false;
659 }
660
661 // Keep backporting till request version
662 std::stringstream backport_model_stream =
663 bytecodeBackportFunctions()[bytecode_version--](input_model_stream);
664
665 output_model_stream.swap(backport_model_stream);
666 output_model_stream.seekg(0, output_model_stream.beg);
667 auto output_model_stream_version =
668 _get_model_bytecode_version(output_model_stream);
669
670 if (static_cast<int64_t>(output_model_stream_version) != bytecode_version) {
671 TORCH_WARN(
672 "The bytecode version of output model stream is supposed to be ",
673 bytecode_version,
674 ", but it gets ",
675 output_model_stream_version);
676 return false;
677 }
678 }
679
680 // 3) Write the final output_model_stream to final_writer, final_writer has
681 // access to the final model destination (file, ostream and etc)
682 if (output_model_stream.str().empty()) {
683 TORCH_WARN("No output model from backport.");
684 return false;
685 }
686 PyTorchStreamReader last_model_reader(&output_model_stream);
687 selective_copy(
688 last_model_reader,
689 final_writer,
690 std::unordered_set<std::string>(),
691 std::unordered_set<std::string>());
692
693 return backport_success;
694 }
695
696 } // namespace torch::jit
697