// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/compatibility/backport_manager.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/core/ivalue.h>
#include <c10/util/Exception.h>
#include <caffe2/serialize/file_adapter.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/mobile/compatibility/backport_manager.h>
#include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <cstddef>
#include <sstream>

namespace torch::jit {

using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::PyTorchStreamWriter;

// Currently supported bytecode versions
namespace {
constexpr int64_t kBytecodeVersionV4 = 0x4L;
constexpr int64_t kBytecodeVersionV5 = 0x5L;
constexpr int64_t kBytecodeVersionV6 = 0x6L;
constexpr int64_t kBytecodeVersionV7 = 0x7L;
constexpr int64_t kBytecodeVersionV8 = 0x8L;
constexpr int64_t kBytecodeVersionV9 = 0x9L;
} // namespace

/********************** Utility Functions **********************/

// Utility functions that can be reused by backport_v{n}_to_v{n-1}(). If a
// function can be reused by multiple backport functions, move it here.
namespace {
// Copy files from source to destination, except the excluded files and dirs
void selective_copy(
    PyTorchStreamReader& reader,
    PyTorchStreamWriter& writer,
    const std::unordered_set<std::string>& excluded_files,
    const std::unordered_set<std::string>& excluded_dirs) {
  auto records = reader.getAllRecords();
  for (const auto& record : records) {
    // Don't copy archives in excluded_files, usually the archives `version`
    // and `bytecode`. The archive `version` will be written when
    // PyTorchStreamWriter finalizes and runs writeEndOfFile()

    // records is the list of all file names in the zip file, where each
    // record is one file with the path to its parent folder. Example records:
    // data.pkl
    // code/__torch__/___torch_mangle_5.py
    // code/__torch__/___torch_mangle_5.py.debug_pkl
    // constants/140245072983168.storage
    // constants.pkl
    // bytecode.pkl
    // version
    bool skip = excluded_files.count(record) > 0;

    // Skip dirs: find the last '/' and compare the parent path with record
    for (const auto& excluded_dir : excluded_dirs) {
      std::size_t found = record.find_last_of("/\\");
      auto path = record.substr(0, found);
      if (excluded_dir == path) {
        skip = true;
        break;
      }
    }
    if (!skip) {
      auto data_ptr = reader.getRecord(record);
      auto data = std::get<0>(data_ptr).get();
      auto size = std::get<1>(data_ptr);
      writer.writeRecord(record, data, size);
    }
  }
}
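
// A minimal usage sketch (hedged; the stream and writer names are
// illustrative, mirroring how the backport functions below call it):
//
//   PyTorchStreamReader reader(&input_model_stream);
//   PyTorchStreamWriter writer(writer_func);
//   selective_copy(
//       reader,
//       writer,
//       /*excluded_files=*/{"constants.pkl", "bytecode.pkl"},
//       /*excluded_dirs=*/{"constants", "bytecode"});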

// The write_archive_current function is used for bytecode from version v5
// onward; pre-v5 we serialized things differently. This write archive
// function may change in export_module.cpp, but we don't have a way to keep
// the old export function in the codebase. To be able to export the model in
// the old format, we keep a copy of the export function here.
void write_archive_current(
    PyTorchStreamWriter& writer,
    const IValue& value,
    const std::string& archive_name,
    const std::string& archive_dir,
    const std::string& tensor_dir,
    bool use_storage_context,
    SerializationStorageContext& storage_context) {
  std::vector<char> data;
  // Vector to capture the run-time class types during pickling the IValues
  std::vector<c10::ClassTypePtr> memoizedClassTypes;
  std::vector<std::string> tensor_names;
  Pickler data_pickle(
      [&](const char* buf, size_t size) {
        data.insert(data.end(), buf, buf + size);
      },
      nullptr,
      nullptr,
      &memoizedClassTypes,
      [&](const at::Tensor& tensor) {
        // returns a string to use in pickler.cpp as the storage obj key
        if (use_storage_context) {
          std::string string_id =
              std::to_string(reinterpret_cast<std::intptr_t>(
                  tensor.storage().unsafeGetStorageImpl()));
          tensor_names.push_back(string_id + ".storage");
          storage_context.getOrAddStorage(tensor.storage());
        } else {
          tensor_names.push_back(std::to_string(tensor_names.size()));
        }
        return tensor_names.back();
      });
  data_pickle.protocol();
  data_pickle.pushIValue(value);
  data_pickle.stop();
  // write out tensor data
  size_t i = 0;

  TORCH_INTERNAL_ASSERT(tensor_names.size() == data_pickle.tensorData().size());
  const std::unordered_set<std::string>& pre_serialized_files =
      writer.getAllWrittenRecords();

  for (const auto& td : data_pickle.tensorData()) {
    WriteableTensorData writable_td = getWriteableTensorData(td);
    std::string fname = tensor_dir + tensor_names[i++];
    if (use_storage_context &&
        pre_serialized_files.find(fname) != pre_serialized_files.end()) {
      // storage has been serialized already, skip
      continue;
    }
    writer.writeRecord(fname, writable_td.data(), writable_td.sizeInBytes());
  }

  std::string fname = archive_dir + archive_name + ".pkl";
  writer.writeRecord(fname, data.data(), data.size());
}
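
// A sketch of the records one call produces (hedged; the storage id is an
// illustrative pointer-derived value, as generated by the lambda above):
//   write_archive_current(writer, constants_tuple, "constants", "",
//                         "constants/", true, storage_context);
//   // writes: constants/140245072983168.storage, ..., constants.pkl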

/*
inputs: 1) bytecode tuple from bytecode.pkl, 2) the output bytecode version
return: a boolean indicating whether the bytecode tuple was updated
successfully
*/
bool update_bytecode_version(
    std::vector<at::IValue>& bytecode_values,
    const int64_t to_version) {
  if (!bytecode_values.empty() && bytecode_values[0].isInt()) {
    bytecode_values[0] = c10::IValue(to_version);
    return true;
  }
  return false;
}
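
// For example (a sketch, using the bytecode.pkl structure shown below):
//   before: (5, ('__torch__.m.forward', (...)))
//   after : (4, ('__torch__.m.forward', (...)))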

/*
inputs: 1) input model stringstream, 2) the output bytecode version
return: model stringstream with the updated bytecode version in bytecode.pkl

Example bytecode.pkl:
(${bytecode_version},
  ('__torch__.m.forward',
    (('instructions',
      (('STOREN', 1, 2),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('OP', 0, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::Int', 'Tensor'),)),
     ('constants', ()),
     ('types', ()),
     ('register_size', 2))))
*/
std::stringstream update_bytecode_version(
    std::stringstream& input_model,
    const int64_t to_version) {
  PyTorchStreamReader reader_bytecode(&input_model);
  auto constants_values =
      std::move(*readArchive(kArchiveNameConstants, reader_bytecode).toTuple())
          .elements();

  std::vector<IValue> bytecode_values = get_bytecode_ivalues(reader_bytecode);
  std::unordered_set<std::string> excluded_files{
      "constants.pkl", "bytecode.pkl"};

  std::unordered_set<std::string> excluded_dirs{
      "constants",
      "bytecode",
  };

  std::stringstream output_model_stream;
  auto writer_func = [&](const void* buf, size_t nbytes) -> size_t {
    output_model_stream.write(static_cast<const char*>(buf), nbytes);
    return !output_model_stream ? 0 : nbytes;
  };

  PyTorchStreamWriter writer_bytecode(writer_func);

  selective_copy(
      reader_bytecode, writer_bytecode, excluded_files, excluded_dirs);

  update_bytecode_version(bytecode_values, to_version);
  auto bytecode_tuple = c10::ivalue::Tuple::create(std::move(bytecode_values));
  SerializationStorageContext storage_context;
  write_archive_current(
      writer_bytecode,
      c10::ivalue::Tuple::create(std::move(constants_values)),
      /*archive_name=*/"constants",
      /*archive_dir=*/"",
      /*tensor_dir=*/"constants/",
      /*use_storage_context=*/true,
      storage_context);
  write_archive_current(
      writer_bytecode,
      bytecode_tuple,
      /*archive_name=*/"bytecode",
      /*archive_dir=*/"",
      /*tensor_dir=*/"constants/",
      /*use_storage_context=*/true,
      storage_context);

  return output_model_stream;
}
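
// Usage sketch (hedged; `model_v6` is a hypothetical stream holding a v6
// model whose bytecode.pkl should be restamped as v5):
//   std::stringstream model_v5 =
//       update_bytecode_version(model_v6, kBytecodeVersionV5);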
} // namespace

/******************** backport_v{i}_to_v{i-1} Functions **********************/

/*
 To add the next backport function, for example, backport_v{n}_to_v{n-1},
 create an anonymous namespace with a backport_v{n}_to_v{n-1} function plus
 any other necessary customized functions. If a function can be reused by
 other backport functions, move it to the utility function group. It will be
 easier to split backport_manager.cpp into smaller files when it grows too
 long.

 How to add backport_v{i}_to_v{i-1}?
 There are two options (a hedged skeleton of the first follows this comment):
 1) [Format change only, recommended] Construct a reader with the
 input_model_stream, modify the file, and use PyTorchStreamWriter to write it
 to output_model_stream. See backport_v5_to_v4.

 2) [Both format and content change] Use torch.jit.load() to load the stream,
 and save it to output_model_stream.

 The first option is preferred, because it is purely a format change: the
 model doesn't need to go through inlining again and the model content
 remains the same.

 A note on manipulating stringstreams: it's recommended to declare a new
 stringstream, tmp_stream, and swap it with the argument output_model_stream
 once it's ready, output_model_stream.swap(tmp_stream). Do not use
 output_model_stream.clear(). It only clears the error state flags
 (https://www.cplusplus.com/reference/ios/ios/clear/), while the content
 stays the same. It's cleaner to just declare a new one and swap.

*/
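
// A minimal skeleton for option 1 (a sketch only; backport_vn_to_vn_minus_1
// and the "re-emit records" step are hypothetical placeholders, not existing
// code):
//
//   std::stringstream backport_vn_to_vn_minus_1(
//       std::stringstream& input_model_stream) {
//     PyTorchStreamReader reader(&input_model_stream);
//     std::stringstream tmp_stream;
//     auto writer_func = [&](const void* buf, size_t nbytes) -> size_t {
//       tmp_stream.write(static_cast<const char*>(buf), nbytes);
//       return !tmp_stream ? 0 : nbytes;
//     };
//     PyTorchStreamWriter writer(writer_func);
//     // Copy everything except the records being rewritten, then re-emit
//     // those records in the v{n-1} format and restamp the version.
//     selective_copy(reader, writer, excluded_files, excluded_dirs);
//     return tmp_stream;
//   }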

namespace {

/*
The following functions are needed to backport a model from v5 to v4.
Bytecode v5 deduplicated the constant table. Previously, in v4, the constant
table was exported twice, in both the archive `bytecode` and the archive
`constants`, and the majority (almost all) were duplicates. Currently, in v5,
JIT and mobile share the archive `constants`, and all constant tensors are
exported in this archive. The bump was needed because the v5 bytecode exports
the tensor storage path in the schema, since the runtime code is now able to
query which archive a tensor is stored in and read from the correct archive.
For example, previously, in v4, we deserialized tensors without an archive
path, and mobile would always read tensors from the bytecode archive:
(torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage,
'0', 'cpu', 8),),
   0,
   (2, 4),
   (4, 1),
   False,
   collections.OrderedDict()),
 1)),
Currently, in v5, we deserialize the bytecode with the archive path, and
mobile can read tensors from the given path:
(torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage,
'constants/0', 'cpu', 8),),
   0,
   (2, 4),
   (4, 1),
   False,
   collections.OrderedDict()),
 1)),
Thus, the backport is necessary so that the runtime can read tensors from
the correct archive.
*/
std::stringstream backport_v5_to_v4(std::stringstream& input_model_stream) {
  // 1) read from the `bytecode` archive
  PyTorchStreamReader reader(&input_model_stream);
  std::vector<IValue> bytecode_values = get_bytecode_ivalues(reader);
  auto constants_values =
      std::move(*readArchive(kArchiveNameConstants, reader).toTuple())
          .elements();

  // 2) Copy everything to the new output, except some specific files and dirs
  // (usually version, bytecode.pkl and the bytecode folder are skipped)
  std::unordered_set<std::string> excluded_files{
      "constants.pkl", "bytecode.pkl"};

  std::unordered_set<std::string> excluded_dirs{
      "constants",
      "bytecode",
  };

  std::stringstream output_model_stream;
  auto writer_func = [&](const void* buf, size_t nbytes) -> size_t {
    output_model_stream.write(static_cast<const char*>(buf), nbytes);
    return !output_model_stream ? 0 : nbytes;
  };

  PyTorchStreamWriter writer(writer_func);

  selective_copy(reader, writer, excluded_files, excluded_dirs);

  // 3) write the `bytecode` archive
  // Update the bytecode version in bytecode.pkl
  update_bytecode_version(bytecode_values, kBytecodeVersionV4);
  // Gather the list of ivalues into a big tuple
  auto bytecode_tuple = c10::ivalue::Tuple::create(std::move(bytecode_values));

  // The export function to generate bytecode.pkl for version 4. After the
  // bytecode version bump, the old export function doesn't exist anymore, so
  // keep a copy here for backport purposes.
  auto writeArchiveV4 = [](PyTorchStreamWriter& writer,
                           const std::string& archive_name,
                           const c10::IValue& value) {
    std::vector<char> data;

    // Vector to capture the run-time class types during pickling the IValues
    std::vector<c10::ClassTypePtr> memoizedClassTypes;
    Pickler data_pickle(
        [&](const char* buf, size_t size) {
          data.insert(data.end(), buf, buf + size);
        },
        nullptr,
        nullptr,
        &memoizedClassTypes);
    data_pickle.protocol();
    data_pickle.pushIValue(value);
    data_pickle.stop();
    size_t i = 0;
    std::string prefix = archive_name + "/";

    for (const auto& td : data_pickle.tensorData()) {
      WriteableTensorData writable_td = getWriteableTensorData(td);
      std::string fname = prefix + std::to_string(i++);
      writer.writeRecord(fname, writable_td.data(), writable_td.sizeInBytes());
    }
    std::string fname = archive_name + ".pkl";
    writer.writeRecord(fname, data.data(), data.size());
  };

  // write `bytecode` archive
  writeArchiveV4(writer, kArchiveNameBytecode, bytecode_tuple);
  // write `constants` archive
  auto constants_tuple =
      c10::ivalue::Tuple::create(std::move(constants_values));
  writeArchiveV4(writer, kArchiveNameConstants, constants_tuple);
  return output_model_stream;
}
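
// Resulting v4-style layout (a sketch, assuming each archive holds two
// tensors; writeArchiveV4 numbers tensor records from 0):
//   bytecode/0, bytecode/1, bytecode.pkl
//   constants/0, constants/1, constants.pkl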

/*
Backports from bytecode v6, which introduced support for operators with
default arguments in mobile. Previously, in v5, the bytecode operator table
did not record the number of specified arguments for each operator. In v6,
operators are aware of the number of specified arguments present in the
schema.

The bump was needed because the v6 bytecode specifies the number of specified
arguments for operators in the schema, since the runtime code is now able to
query the number of specified arguments and supports default arguments.

For example, aten::foo's schema in v5 is
foo(Tensor a, Tensor b) -> Tensor
and in v6, it's
foo(Tensor a, Tensor b, int groups=1) -> Tensor

Accordingly, the operator table in v5 is:
('operators', (('aten::foo', ''),))
and in v6, it's
('operators', (('aten::foo', '', 2),))

Thus, the backport is necessary so that the bytecode operator table contains
the number of specified arguments.
*/
std::stringstream backport_v6_to_v5(std::stringstream& input_model_stream) {
  auto rai =
      std::make_shared<caffe2::serialize::IStreamAdapter>(&input_model_stream);
  auto reader = std::make_shared<PyTorchStreamReader>(rai);

  // If there are debug info files in the original model file, they should
  // also show up in the backported model
  bool hasBytecodeDebug = reader->hasRecord("mobile_debug_handles.pkl");

  // extra_files are kept
  auto records = reader->getAllRecords();
  ExtraFilesMap extra_files;
  for (const auto& record : records) {
    std::size_t found = record.find_last_of("/\\");
    auto path = record.substr(0, found);
    if ("extra" == path) {
      extra_files.emplace(record.substr(found + 1), "");
    }
  }
  // Loading the TS module is required for this backport, because the bytecode
  // needs to be re-emitted (refer to the comments below)
  Module torch_script = torch::jit::load(rai, std::nullopt, extra_files);

  // The RAII guard changes the flag emitBytecodeDefaultInputs to true, so
  // that TS stores the default argument values in the constant table and
  // emits the instructions (LOADC, for example) to push the values onto the
  // stack. It restores the behavior of v5 and before. For v6, the default arg
  // values are resolved at runtime init stage for better operator
  // compatibility.
  std::stringstream intermediate_model_stream;
  {
    BytecodeEmitModeGuard argNumGuard(
        true /*emit_default_input_instructions*/,
        false /*enable_defaults_args_with_out_args*/,
        false /*enable_emit_promoted_ops*/);
    torch_script._save_for_mobile(
        intermediate_model_stream, extra_files, hasBytecodeDebug);
  }

  // Update the bytecode version (from 6 to 5)
  std::stringstream output_model_stream =
      update_bytecode_version(intermediate_model_stream, kBytecodeVersionV5);
  return output_model_stream;
}
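
// Effect on the operator table (a sketch, following the aten::foo example
// above): re-emitting under the guard yields v5-style entries
//   ('operators', (('aten::foo', ''),))
// in place of the v6-style
//   ('operators', (('aten::foo', '', 2),))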

/*
Backports from bytecode v7, which introduced support for operators with out
arguments. Previously, in v6, operators with out arguments forced the
serialization of all arguments in the schema, even when optional arguments
were not provided (as they had default values). Currently, in v7, operators
are aware of out arguments being present in the schema (always appended),
allowing the serialization of only the required arguments (as default values
will be provided by the runtime).

The bump was needed because the v7 bytecode specifies fewer arguments for ops
with out arguments in the schema, since the runtime code is now able to query
whether an argument is of type "out" and insert the necessary default values
in the right order in the interpreter stack (i.e. before the out arguments).

For example, the schema is: torch.add(x, h, alpha=1.0, out=x). So, if the
program defines: torch.add(x, h, out=x), previously, in v6, we serialized the
bytecode to contain all 4 arguments. Currently, in v7, we serialize the
bytecode with only 3 arguments, since alpha is optional and has a default
value that the runtime will push onto the stack. Thus, the backport is
necessary so that the bytecode contains all the arguments as before.
*/
std::stringstream backport_v7_to_v6(std::stringstream& input_model_stream) {
  auto rai =
      std::make_shared<caffe2::serialize::IStreamAdapter>(&input_model_stream);
  auto reader = std::make_shared<PyTorchStreamReader>(rai);
  auto constants_values =
      std::move(*readArchive(kArchiveNameConstants, *reader).toTuple())
          .elements();

  // If there are debug info files in the original model file, they should
  // also show up in the backported model
  bool hasBytecodeDebug = reader->hasRecord("mobile_debug_handles.pkl");

  // extra_files are kept
  auto records = reader->getAllRecords();
  ExtraFilesMap extra_files;
  for (const auto& record : records) {
    std::size_t found = record.find_last_of("/\\");
    auto path = record.substr(0, found);
    if ("extra" == path) {
      extra_files.emplace(record.substr(found + 1), "");
    }
  }
  // Loading the TS module is required for this backport, because the bytecode
  // needs to be re-emitted (refer to the comments below)
  Module torch_script = torch::jit::load(rai, std::nullopt, extra_files);

  // The RAII guard keeps the flag emit_default_input_instructions false, to
  // preserve the behavior of bytecode version 6, and keeps the flag
  // enable_defaults_args_with_out_args false, so that the number of specified
  // arguments serialized for operators that allow both out arguments and
  // default arguments is #all_args, instead of (#all_args - #default_args)
  std::stringstream intermediate_model_stream;
  {
    BytecodeEmitModeGuard argNumGuard(
        false /*emit_default_input_instructions*/,
        false /*enable_defaults_args_with_out_args*/,
        false /*enable_emit_promoted_ops*/);
    torch_script._save_for_mobile(
        intermediate_model_stream, extra_files, hasBytecodeDebug);
  }

  // Update the bytecode version (from 7 to 6)
  std::stringstream output_model_stream =
      update_bytecode_version(intermediate_model_stream, kBytecodeVersionV6);
  return output_model_stream;
}
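
// Effect on the operator table (a sketch, following the torch.add example
// above; the overload name is illustrative): with the guard, the v6 backport
// serializes all 4 arguments,
//   ('operators', (('aten::add', 'out', 4),))
// instead of the v7-style 3 (alpha is left for the runtime to supply).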

std::stringstream backport_v9_to_v8(std::stringstream& input_model_stream) {
  ExtraFilesMap extra_files;
  Module torch_script =
      torch::jit::load(input_model_stream, std::nullopt, extra_files);
  std::stringstream intermediate_model_stream;
  // TODO(@pavithran): Check if debug info is available and use load/save
  // while backporting; hardcode debug info to false until supported.
  bool hasBytecodeDebug = false;
  {
    BytecodeEmitModeGuard argNumGuard(
        false /*emit_default_input_instructions*/,
        true /*enable_defaults_args_with_out_args*/,
        true /*enable_emit_promoted_ops*/);
    torch_script._save_for_mobile(
        intermediate_model_stream,
        extra_files,
        hasBytecodeDebug,
        /*use_flatbuffer=*/false);
  }
  // Update the bytecode version (from 9 to 8)
  std::stringstream output_model_stream =
      update_bytecode_version(intermediate_model_stream, kBytecodeVersionV8);

  return output_model_stream;
}

std::stringstream backport_v8_to_v7(std::stringstream& input_model_stream) {
  auto rai =
      std::make_shared<caffe2::serialize::IStreamAdapter>(&input_model_stream);
  auto reader = std::make_shared<PyTorchStreamReader>(rai);
  // extra_files are kept
  auto records = reader->getAllRecords();
  bool hasBytecodeDebug = reader->hasRecord("mobile_debug_handles.pkl");
  ExtraFilesMap extra_files;
  for (const auto& record : records) {
    std::size_t found = record.find_last_of("/\\");
    auto path = record.substr(0, found);
    if ("extra" == path) {
      extra_files.emplace(record.substr(found + 1), "");
    }
  }
  Module torch_script = torch::jit::load(rai, std::nullopt, extra_files);
  std::stringstream intermediate_model_stream;
  {
    BytecodeEmitModeGuard argNumGuard(
        false /*emit_default_input_instructions*/,
        true /*enable_defaults_args_with_out_args*/,
        false /*enable_emit_promoted_ops*/);
    torch_script._save_for_mobile(
        intermediate_model_stream, extra_files, hasBytecodeDebug);
  }

  // Update the bytecode version (from 8 to 7)
  std::stringstream output_model_stream =
      update_bytecode_version(intermediate_model_stream, kBytecodeVersionV7);

  return output_model_stream;
}

} // namespace

/********************** BackportManager **********************/

// A generic contract for backport logic to the previous bytecode version.
// Args:
// * PyTorchStreamReader has access to the input model from bytecode version N.
// * PyTorchStreamWriter has access to the output model backported to the
// previous bytecode version N-1. Returns true if successful, false otherwise.
using BytecodeBackportFunction =
    std::function<std::stringstream(std::stringstream&)>;

BackportManager::BackportManager() {
  registerBytecodeBackportFunction(kBytecodeVersionV5, backport_v5_to_v4);
  registerBytecodeBackportFunction(kBytecodeVersionV6, backport_v6_to_v5);
  registerBytecodeBackportFunction(kBytecodeVersionV7, backport_v7_to_v6);
  registerBytecodeBackportFunction(kBytecodeVersionV8, backport_v8_to_v7);
  registerBytecodeBackportFunction(kBytecodeVersionV9, backport_v9_to_v8);
}
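
// Extending the chain is a one-line registration (a sketch;
// kBytecodeVersionV10 and backport_v10_to_v9 are hypothetical and would need
// to be defined first):
//   registerBytecodeBackportFunction(kBytecodeVersionV10, backport_v10_to_v9);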

std::unordered_map<
    int64_t,
    std::function<std::stringstream(std::stringstream&)>>&
BackportManager::bytecodeBackportFunctions() const {
  static std::unordered_map<
      int64_t,
      std::function<std::stringstream(std::stringstream&)>>
      backport_functions;
  return backport_functions;
}

bool BackportManager::hasBytecodeBackportFunction(
    const int64_t from_version) const {
  return bytecodeBackportFunctions().count(from_version);
}

void BackportManager::registerBytecodeBackportFunction(
    const int64_t from_version,
    const BytecodeBackportFunction& backport_function) {
  TORCH_CHECK(
      !hasBytecodeBackportFunction(from_version),
      "Backporting from version ",
      from_version,
      " is already registered.");
  bytecodeBackportFunctions()[from_version] = backport_function;
}

// The main function to run backport from version n to version i.
// All models (file or buffer) are converted to a stream first, and
// istream_adapter has access to it. During the backport process,
// intermediate results are stored in streams.
bool BackportManager::backport(
    std::istream& oss,
    PyTorchStreamWriter& final_writer,
    int64_t from_version,
    int64_t to_version) const {
  if (from_version <= to_version) {
    TORCH_WARN(
        "backport doesn't support backporting a model to a newer version. It's trying to backport from version ",
        from_version,
        " to version ",
        to_version);
    return false;
  }
  int64_t bytecode_version = from_version;
  bool backport_success = true;

  // 1) Given an istream_adapter (an adapter with access to the input model;
  // the model can come from an istream, a file, etc.), copy all model content
  // to a stringstream
  oss.seekg(0, std::ios::beg);
  std::stringstream input_model_stream;
  input_model_stream << oss.rdbuf();
  std::stringstream output_model_stream;

  // 2) backport the model; each backport_v{i}_to_v{i-1} function takes
  // input_model_stream and returns output_model_stream
  while (bytecode_version > to_version) {
    // Swap input and output if it's not the first iteration and
    // output_model_stream has a value.
    if (!output_model_stream.str().empty()) {
      input_model_stream.swap(output_model_stream);
      // reset output_model_stream
      output_model_stream.str("");
    }

    if (!hasBytecodeBackportFunction(bytecode_version)) {
      return false;
    }

    input_model_stream.seekg(0, input_model_stream.beg);
    auto input_model_stream_version =
        _get_model_bytecode_version(input_model_stream);

    if (static_cast<int64_t>(input_model_stream_version) != bytecode_version) {
      TORCH_WARN(
          "The bytecode version of the input model stream is expected to be ",
          bytecode_version,
          ", but it is ",
          input_model_stream_version);
      return false;
    }

    // Keep backporting until the requested version
    std::stringstream backport_model_stream =
        bytecodeBackportFunctions()[bytecode_version--](input_model_stream);

    output_model_stream.swap(backport_model_stream);
    output_model_stream.seekg(0, output_model_stream.beg);
    auto output_model_stream_version =
        _get_model_bytecode_version(output_model_stream);

    if (static_cast<int64_t>(output_model_stream_version) != bytecode_version) {
      TORCH_WARN(
          "The bytecode version of the output model stream is expected to be ",
          bytecode_version,
          ", but it is ",
          output_model_stream_version);
      return false;
    }
  }

  // 3) Write the final output_model_stream to final_writer; final_writer has
  // access to the final model destination (file, ostream, etc.)
  if (output_model_stream.str().empty()) {
    TORCH_WARN("No output model from backport.");
    return false;
  }
  PyTorchStreamReader last_model_reader(&output_model_stream);
  selective_copy(
      last_model_reader,
      final_writer,
      std::unordered_set<std::string>(),
      std::unordered_set<std::string>());

  return backport_success;
}
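
// End-to-end usage sketch (hedged; the file names are illustrative, and
// _backport_for_mobile in backport.cpp is the public entry point that wraps
// this class):
//   std::ifstream ifile("model_v9.ptl", std::ios::binary);
//   PyTorchStreamWriter writer("model_v5.ptl");
//   BackportManager manager;
//   bool ok = manager.backport(
//       ifile, writer, /*from_version=*/9, /*to_version=*/5);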

} // namespace torch::jit