1 #include <torch/serialize/input-archive.h>
2
3 #include <torch/types.h>
4 #include <torch/utils.h>
5
6 #include <c10/util/Exception.h>
7 #include <caffe2/serialize/read_adapter_interface.h>
8 #include <torch/csrc/jit/api/module.h>
9 #include <torch/csrc/jit/serialization/import.h>
10
11 #include <istream>
12 #include <memory>
13 #include <string>
14 #include <utility>
15
16 namespace torch {
17 namespace serialize {
18
InputArchive()19 InputArchive::InputArchive()
20 : module_("Module", std::make_shared<jit::CompilationUnit>()) {}
21
read(const std::string & key,c10::IValue & ivalue)22 void InputArchive::read(const std::string& key, c10::IValue& ivalue) {
23 ivalue = module_.attr(key);
24 }
25
try_read(const std::string & key,c10::IValue & ivalue)26 bool InputArchive::try_read(const std::string& key, c10::IValue& ivalue) {
27 if (!module_.hasattr(key)) {
28 return false;
29 }
30 ivalue = module_.attr(key);
31 return true;
32 }
33
try_read(const std::string & key,Tensor & tensor,bool is_buffer)34 bool InputArchive::try_read(
35 const std::string& key,
36 Tensor& tensor,
37 bool is_buffer) {
38 if (!module_.hasattr(key)) {
39 return false;
40 }
41 auto iv = module_.attr(key);
42 if (!iv.isTensor()) {
43 return false;
44 }
45 auto read_tensor = iv.toTensor();
46 // clang-format on
47 if (tensor.defined()) {
48 torch::NoGradGuard guard;
49 if (tensor.device() != read_tensor.device()) {
50 tensor.set_data(read_tensor);
51 } else {
52 tensor.set_(read_tensor);
53 }
54 } else {
55 tensor = std::move(read_tensor);
56 }
57 return true;
58 }
59
read(const std::string & key,Tensor & tensor,bool is_buffer)60 void InputArchive::read(
61 const std::string& key,
62 Tensor& tensor,
63 bool is_buffer) {
64 TORCH_CHECK(
65 try_read(key, tensor, is_buffer),
66 "No such serialized tensor '",
67 hierarchy_prefix_,
68 key,
69 "'");
70 }
71
try_read(const std::string & key,InputArchive & archive)72 bool InputArchive::try_read(const std::string& key, InputArchive& archive) {
73 if (!module_.hasattr(key)) {
74 return false;
75 }
76 auto iv = module_.attr(key);
77 if (!iv.isModule()) {
78 return false;
79 }
80 archive.module_ = iv.toModule();
81 archive.hierarchy_prefix_ = hierarchy_prefix_ + key + ".";
82 return true;
83 }
84
read(const std::string & key,InputArchive & archive)85 void InputArchive::read(const std::string& key, InputArchive& archive) {
86 TORCH_CHECK(
87 try_read(key, archive),
88 "No such serialized submodule: '",
89 hierarchy_prefix_,
90 key,
91 "'");
92 }
93
load_from(const std::string & filename,std::optional<torch::Device> device)94 void InputArchive::load_from(
95 const std::string& filename,
96 std::optional<torch::Device> device /*= std::nullopt*/) {
97 module_ = torch::jit::load(filename, std::move(device));
98 }
99
load_from(std::istream & stream,std::optional<torch::Device> device)100 void InputArchive::load_from(
101 std::istream& stream,
102 std::optional<torch::Device> device /*= std::nullopt*/) {
103 module_ = torch::jit::load(stream, std::move(device));
104 }
105
load_from(const char * data,size_t size,std::optional<torch::Device> device)106 void InputArchive::load_from(
107 const char* data,
108 size_t size,
109 std::optional<torch::Device> device /*= std::nullopt*/) {
110 using caffe2::serialize::ReadAdapterInterface;
111 class OurAdapter : public ReadAdapterInterface {
112 public:
113 OurAdapter(const char* data, size_t size) : data_(data), size_(size) {}
114 size_t size() const override {
115 return size_;
116 }
117 size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
118 const override {
119 (void)what;
120 if (pos >= size_) {
121 return 0;
122 }
123 size_t nread = std::min(static_cast<size_t>(pos) + n, size_) - pos;
124 memcpy(buf, data_ + pos, nread);
125 return nread;
126 }
127
128 private:
129 const char* data_;
130 size_t size_;
131 };
132 module_ = torch::jit::load(
133 std::make_unique<OurAdapter>(data, size), std::move(device));
134 }
135
load_from(const std::function<size_t (uint64_t,void *,size_t)> & read_func,const std::function<size_t (void)> & size_func,std::optional<torch::Device> device)136 void InputArchive::load_from(
137 const std::function<size_t(uint64_t, void*, size_t)>& read_func,
138 const std::function<size_t(void)>& size_func,
139 std::optional<torch::Device> device /*= std::nullopt*/) {
140 using caffe2::serialize::ReadAdapterInterface;
141 class OurAdapter : public ReadAdapterInterface {
142 public:
143 OurAdapter(
144 const std::function<size_t(uint64_t, void*, size_t)>& read_func,
145 const std::function<size_t(void)>& size_func)
146 : read_func_(read_func), size_func_(size_func) {}
147 size_t size() const override {
148 return size_func_();
149 }
150 size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
151 const override {
152 (void)what;
153 return read_func_(pos, buf, n);
154 }
155
156 private:
157 const std::function<size_t(uint64_t, void*, size_t)>& read_func_;
158 const std::function<size_t(void)>& size_func_;
159 };
160 module_ = torch::jit::load(
161 std::make_unique<OurAdapter>(read_func, size_func), std::move(device));
162 }
163
keys()164 std::vector<std::string> InputArchive::keys() {
165 std::vector<std::string> all_keys;
166 all_keys.reserve(module_.named_attributes(/*recurse=*/false).size());
167
168 for (const torch::jit::NameValue& s :
169 module_.named_attributes(/*recurse=*/false)) {
170 all_keys.push_back(s.name);
171 }
172
173 return all_keys;
174 }
175
176 } // namespace serialize
177 } // namespace torch
178