1 #include <ATen/core/ivalue.h>
2 #include <caffe2/serialize/file_adapter.h>
3 #include <caffe2/serialize/inline_container.h>
4 #include <torch/csrc/jit/mobile/compatibility/backport.h>
5 #include <torch/csrc/jit/mobile/compatibility/backport_manager.h>
6 #include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
7
8 #include <string>
9
10 namespace torch::jit {
11
12 using caffe2::serialize::IStreamAdapter;
13 using caffe2::serialize::PyTorchStreamWriter;
14
15 const static BackportManager backportManager;
16
17 // Forward declare so that _backport_for_mobile() overloads can
18 // call this method directly.
19 bool _backport_for_mobile_impl(
20 std::istream& oss,
21 PyTorchStreamWriter& writer,
22 const int64_t to_version);
23
_backport_for_mobile(std::istream & in,std::ostream & out,const int64_t to_version)24 bool _backport_for_mobile(
25 std::istream& in,
26 std::ostream& out,
27 const int64_t to_version) {
28 auto writer_func = [&](const void* buf, size_t nbytes) -> size_t {
29 out.write(static_cast<const char*>(buf), nbytes);
30 return !out ? 0 : nbytes;
31 };
32 PyTorchStreamWriter writer(writer_func);
33 return _backport_for_mobile_impl(in, writer, to_version);
34 }
35
_backport_for_mobile(std::istream & in,const std::string & output_filename,const int64_t to_version)36 bool _backport_for_mobile(
37 std::istream& in,
38 const std::string& output_filename,
39 const int64_t to_version) {
40 PyTorchStreamWriter writer(output_filename);
41 return _backport_for_mobile_impl(in, writer, to_version);
42 }
43
_backport_for_mobile(const std::string & input_filename,std::ostream & out,const int64_t to_version)44 bool _backport_for_mobile(
45 const std::string& input_filename,
46 std::ostream& out,
47 const int64_t to_version) {
48 std::ifstream file_stream;
49 std::unique_ptr<IStreamAdapter> istream_adapter;
50 file_stream.open(input_filename, std::ifstream::in | std::ifstream::binary);
51 if (!file_stream) {
52 AT_ERROR("open file failed, file path: ", input_filename);
53 }
54 auto writer_func = [&](const void* buf, size_t nbytes) -> size_t {
55 out.write(static_cast<const char*>(buf), nbytes);
56 return !out ? 0 : nbytes;
57 };
58
59 PyTorchStreamWriter writer(writer_func);
60 return _backport_for_mobile_impl(file_stream, writer, to_version);
61 }
62
_backport_for_mobile(const std::string & input_filename,const std::string & output_filename,const int64_t to_version)63 bool _backport_for_mobile(
64 const std::string& input_filename,
65 const std::string& output_filename,
66 const int64_t to_version) {
67 std::ifstream file_stream;
68 file_stream.open(input_filename, std::ifstream::in | std::ifstream::binary);
69 if (!file_stream) {
70 AT_ERROR("open file failed, file path: ", input_filename);
71 }
72
73 PyTorchStreamWriter writer(output_filename);
74 return _backport_for_mobile_impl(file_stream, writer, to_version);
75 }
76
_backport_for_mobile_impl(std::istream & oss,PyTorchStreamWriter & writer,const int64_t to_version)77 bool _backport_for_mobile_impl(
78 std::istream& oss,
79 PyTorchStreamWriter& writer,
80 const int64_t to_version) {
81 if (!backportManager.hasBytecodeBackportFunction(to_version + 1)) {
82 return false;
83 }
84 oss.seekg(0, oss.beg);
85 auto from_version = _get_model_bytecode_version(oss);
86 return backportManager.backport(oss, writer, from_version, to_version);
87 }
88
89 } // namespace torch::jit
90