1 #include <torch/csrc/jit/serialization/pickle.h>
2
3 #include <ATen/core/ivalue.h>
4 #include <caffe2/serialize/inline_container.h>
5 #include <torch/csrc/Export.h>
6 #include <torch/csrc/jit/serialization/export.h>
7 #include <torch/csrc/jit/serialization/import.h>
8 #include <torch/csrc/jit/serialization/import_read.h>
9
10 namespace torch::jit {
11
12 namespace {
13
// Resolves a pickled qualified type name to a type.
//
// Names outside the `__torch__` namespace are treated as regular types and
// parsed with the default ScriptTypeParser. Names under `__torch__` are looked
// up with torch::getCustomClass; an unresolvable custom class raises a
// c10::Error via TORCH_CHECK.
c10::StrongTypePtr customClassResolver(const c10::QualifiedName& qn) {
  if (!c10::QualifiedName("__torch__").isPrefixOf(qn)) {
    // This is a regular type, fall back to the default type parser.
    torch::jit::ScriptTypeParser parser;
    return c10::StrongTypePtr(nullptr, parser.parseType(qn.qualifiedName()));
  }

  at::TypePtr type = torch::getCustomClass(qn.qualifiedName());
  // TORCH_CHECK concatenates its message arguments (it does not perform
  // `{}`-style substitution), so the type name is spliced in explicitly.
  TORCH_CHECK(
      type != nullptr,
      "Couldn't resolve type '",
      qn.qualifiedName(),
      "', did you forget to add its build dependency?");

  // A null CompilationUnit is passed on purpose. This should be acceptable
  // because:
  // 1. The lifetime of the class type is tied to the global custom class
  //    registry rather than to a specific CompilationUnit.
  // 2. The `cu_` field is not accessed, and this StrongTypePtr is discarded
  //    right after deserialization.
  return c10::StrongTypePtr(nullptr, std::move(type));
}
38
39 } // namespace
40
pickle(std::function<void (const char * data_start,size_t data_len)> writer,const IValue & ivalue,std::vector<at::Tensor> * tensor_table)41 void pickle(
42 std::function<void(const char* data_start, size_t data_len)> writer,
43 const IValue& ivalue,
44 std::vector<at::Tensor>* tensor_table) {
45 Pickler pickler(std::move(writer), tensor_table, nullptr, nullptr);
46 pickler.protocol();
47 pickler.pushIValue(ivalue);
48 pickler.stop();
49 }
50
pickle(const IValue & ivalue,std::vector<at::Tensor> * tensor_table)51 std::vector<char> pickle(
52 const IValue& ivalue,
53 std::vector<at::Tensor>* tensor_table) {
54 std::vector<char> data;
55
56 pickle(
57 [&](const char* bytes, size_t len) {
58 data.insert(data.end(), bytes, bytes + len);
59 },
60 ivalue,
61 tensor_table);
62
63 return data;
64 }
65
66 // This has to live here instead of the C++ API to mirror torch.save since the
67 // mobile build excludes the C++ API
pickle_save(const at::IValue & ivalue)68 std::vector<char> pickle_save(const at::IValue& ivalue) {
69 #ifndef C10_MOBILE
70 // Pickle the IValue into an array of bytes
71 std::vector<char> pickle_data;
72 Pickler pickler([&](const char* buf, size_t size) {
73 pickle_data.insert(pickle_data.end(), buf, buf + size);
74 });
75 pickler.protocol();
76 pickler.pushIValue(ivalue);
77 pickler.stop();
78
79 std::vector<char> container_data;
80 container_data.reserve(pickle_data.size());
81
82 caffe2::serialize::PyTorchStreamWriter writer(
83 [&](const void* void_bytes, size_t len) {
84 const char* bytes = reinterpret_cast<const char*>(void_bytes);
85 container_data.insert(container_data.end(), bytes, bytes + len);
86 return len;
87 });
88
89 // Write the generated bytes and the associated tensors into a data.pkl file
90 // and data/0, data/1, data/2... files for each of the tensors
91 writeArchiveAndTensors(
92 "data",
93 pickle_data.data(),
94 pickle_data.size(),
95 pickler.tensorData(),
96 writer);
97 return container_data;
98 #else
99 AT_ERROR(
100 "pickle_save not supported on mobile "
101 "(see https://github.com/pytorch/pytorch/pull/30108)");
102 #endif
103 }
104
105 #ifndef C10_MOBILE
read(uint64_t pos,void * buf,size_t n,const char * what) const106 size_t VectorReader::read(uint64_t pos, void* buf, size_t n, const char* what)
107 const {
108 std::copy(
109 data_.data() + pos, data_.data() + pos + n, reinterpret_cast<char*>(buf));
110 return n;
111 }
112
read(uint64_t pos,void * buf,size_t n,const char * what) const113 size_t StringViewReader::read(
114 uint64_t pos,
115 void* buf,
116 size_t n,
117 const char* what) const {
118 std::copy(
119 data_.data() + pos, data_.data() + pos + n, reinterpret_cast<char*>(buf));
120 return n;
121 }
122 #endif
123
pickle_load(const std::vector<char> & data)124 IValue pickle_load(const std::vector<char>& data) {
125 // Read in the pickle data
126 #ifndef C10_MOBILE
127 caffe2::serialize::PyTorchStreamReader reader(
128 std::make_unique<VectorReader>(data));
129
130 return readArchiveAndTensors(
131 "data",
132 /*pickle_prefix=*/"",
133 /*tensor_prefix=*/"",
134 /*type_resolver=*/std::nullopt,
135 /*obj_loader=*/std::nullopt,
136 /*device=*/std::nullopt,
137 reader);
138 #else
139 AT_ERROR(
140 "pickle_load not supported on mobile "
141 "(see https://github.com/pytorch/pytorch/pull/30108)");
142 #endif
143 };
144
145 // A specialized version of pickle_load that can load custom objects.
pickle_load_obj(std::string_view data)146 c10::IValue pickle_load_obj(std::string_view data) {
147 #ifndef C10_MOBILE
148 caffe2::serialize::PyTorchStreamReader reader(
149 std::make_unique<torch::jit::StringViewReader>(data));
150 return torch::jit::readArchiveAndTensors(
151 "data",
152 /*pickle_prefix=*/"",
153 /*tensor_prefix=*/"",
154 /*type_resolver=*/customClassResolver,
155 /*obj_loader=*/torch::jit::ObjLoaderFunc,
156 /*device=*/c10::nullopt,
157 reader);
158 #else
159 AT_ERROR(
160 "pickle_load not supported on mobile "
161 "(see https://github.com/pytorch/pytorch/pull/30108)");
162 #endif
163 }
164
unpickle(std::function<size_t (char *,size_t)> reader,TypeResolver type_resolver,c10::ArrayRef<at::Tensor> tensor_table,c10::TypePtr (* type_parser)(const std::string &),ObjLoader obj_loader)165 IValue unpickle(
166 std::function<size_t(char*, size_t)> reader,
167 TypeResolver type_resolver,
168 c10::ArrayRef<at::Tensor> tensor_table,
169 c10::TypePtr (*type_parser)(const std::string&),
170 ObjLoader obj_loader) {
171 Unpickler unpickler(
172 std::move(reader),
173 std::move(type_resolver),
174 tensor_table,
175 std::move(obj_loader),
176 type_parser);
177 return unpickler.parse_ivalue();
178 }
179
unpickle(const char * data,size_t size,TypeResolver type_resolver,c10::ArrayRef<at::Tensor> tensor_table,c10::TypePtr (* type_parser)(const std::string &))180 IValue unpickle(
181 const char* data,
182 size_t size,
183 TypeResolver type_resolver,
184 c10::ArrayRef<at::Tensor> tensor_table,
185 c10::TypePtr (*type_parser)(const std::string&)) {
186 return unpickle(
187 data, size, nullptr, std::move(type_resolver), tensor_table, type_parser);
188 }
189
unpickle(const char * data,size_t size,ObjLoader obj_loader,TypeResolver type_resolver,c10::ArrayRef<at::Tensor> tensor_table,c10::TypePtr (* type_parser)(const std::string &))190 IValue unpickle(
191 const char* data,
192 size_t size,
193 ObjLoader obj_loader,
194 TypeResolver type_resolver,
195 c10::ArrayRef<at::Tensor> tensor_table,
196 c10::TypePtr (*type_parser)(const std::string&)) {
197 size_t bytes_read = 0;
198 return unpickle(
199 [&](char* buffer, size_t len) -> size_t {
200 if (bytes_read >= size) {
201 return 0;
202 }
203 len = std::min(size - bytes_read, len);
204 // Copy len bytes into buffer
205 const char* start = data + bytes_read;
206 std::memcpy(buffer, start, len);
207 bytes_read += len;
208 return len;
209 },
210 std::move(type_resolver),
211 tensor_table,
212 type_parser,
213 std::move(obj_loader));
214 }
215
216 } // namespace torch::jit
217