xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/compatibility/backport.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/ivalue.h>
2 #include <caffe2/serialize/file_adapter.h>
3 #include <caffe2/serialize/inline_container.h>
4 #include <torch/csrc/jit/mobile/compatibility/backport.h>
5 #include <torch/csrc/jit/mobile/compatibility/backport_manager.h>
6 #include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
7 
8 #include <string>
9 
10 namespace torch::jit {
11 
12 using caffe2::serialize::IStreamAdapter;
13 using caffe2::serialize::PyTorchStreamWriter;
14 
15 const static BackportManager backportManager;
16 
17 // Forward declare so that _backport_for_mobile() overloads can
18 // call this method directly.
19 bool _backport_for_mobile_impl(
20     std::istream& oss,
21     PyTorchStreamWriter& writer,
22     const int64_t to_version);
23 
_backport_for_mobile(std::istream & in,std::ostream & out,const int64_t to_version)24 bool _backport_for_mobile(
25     std::istream& in,
26     std::ostream& out,
27     const int64_t to_version) {
28   auto writer_func = [&](const void* buf, size_t nbytes) -> size_t {
29     out.write(static_cast<const char*>(buf), nbytes);
30     return !out ? 0 : nbytes;
31   };
32   PyTorchStreamWriter writer(writer_func);
33   return _backport_for_mobile_impl(in, writer, to_version);
34 }
35 
_backport_for_mobile(std::istream & in,const std::string & output_filename,const int64_t to_version)36 bool _backport_for_mobile(
37     std::istream& in,
38     const std::string& output_filename,
39     const int64_t to_version) {
40   PyTorchStreamWriter writer(output_filename);
41   return _backport_for_mobile_impl(in, writer, to_version);
42 }
43 
_backport_for_mobile(const std::string & input_filename,std::ostream & out,const int64_t to_version)44 bool _backport_for_mobile(
45     const std::string& input_filename,
46     std::ostream& out,
47     const int64_t to_version) {
48   std::ifstream file_stream;
49   std::unique_ptr<IStreamAdapter> istream_adapter;
50   file_stream.open(input_filename, std::ifstream::in | std::ifstream::binary);
51   if (!file_stream) {
52     AT_ERROR("open file failed, file path: ", input_filename);
53   }
54   auto writer_func = [&](const void* buf, size_t nbytes) -> size_t {
55     out.write(static_cast<const char*>(buf), nbytes);
56     return !out ? 0 : nbytes;
57   };
58 
59   PyTorchStreamWriter writer(writer_func);
60   return _backport_for_mobile_impl(file_stream, writer, to_version);
61 }
62 
_backport_for_mobile(const std::string & input_filename,const std::string & output_filename,const int64_t to_version)63 bool _backport_for_mobile(
64     const std::string& input_filename,
65     const std::string& output_filename,
66     const int64_t to_version) {
67   std::ifstream file_stream;
68   file_stream.open(input_filename, std::ifstream::in | std::ifstream::binary);
69   if (!file_stream) {
70     AT_ERROR("open file failed, file path: ", input_filename);
71   }
72 
73   PyTorchStreamWriter writer(output_filename);
74   return _backport_for_mobile_impl(file_stream, writer, to_version);
75 }
76 
_backport_for_mobile_impl(std::istream & oss,PyTorchStreamWriter & writer,const int64_t to_version)77 bool _backport_for_mobile_impl(
78     std::istream& oss,
79     PyTorchStreamWriter& writer,
80     const int64_t to_version) {
81   if (!backportManager.hasBytecodeBackportFunction(to_version + 1)) {
82     return false;
83   }
84   oss.seekg(0, oss.beg);
85   auto from_version = _get_model_bytecode_version(oss);
86   return backportManager.backport(oss, writer, from_version, to_version);
87 }
88 
89 } // namespace torch::jit
90