#include <torch/csrc/jit/mobile/import_data.h>

#include <ATen/Functions.h>
#include <ATen/core/ivalue.h>
#include <c10/util/irange.h>

#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/mobile/file_format.h>
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/import_export_common.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/observer.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/unpickler.h>
#include <torch/custom_class.h>

#include <caffe2/serialize/in_memory_adapter.h>
#include <string>
#include <vector>

namespace torch::jit {
using caffe2::serialize::PyTorchStreamReader;

namespace {

/**
 * Given a ZIP file containing a file named "data.pkl", uses Pickle to
 * deserialize the file and returns the IValue inside it.
 */
class IValueUnpickler final {
 public:
  explicit IValueUnpickler(std::unique_ptr<PyTorchStreamReader> reader);
  c10::IValue deserialize(std::optional<at::Device> device);

 private:
  c10::IValue readArchive(
      const std::string& archive_name,
      std::shared_ptr<mobile::CompilationUnit> mcu,
      std::optional<at::Device> device);

  std::shared_ptr<CompilationUnit> compilation_unit_;
  std::unique_ptr<PyTorchStreamReader> reader_;
};
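
// Illustrative usage sketch: deserialize the "data" archive from an
// already-opened reader, leaving tensors on their serialized device:
//   IValueUnpickler unpickler(std::move(reader));
//   c10::IValue value = unpickler.deserialize(std::nullopt);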

IValueUnpickler::IValueUnpickler(std::unique_ptr<PyTorchStreamReader> reader)
    : compilation_unit_(std::make_shared<CompilationUnit>()),
      reader_(std::move(reader)) {}

c10::IValue IValueUnpickler::deserialize(std::optional<at::Device> device) {
  auto mcu = std::make_shared<mobile::CompilationUnit>();

  return readArchive("data", mcu, device);
}

c10::IValue IValueUnpickler::readArchive(
    const std::string& archive_name,
    std::shared_ptr<mobile::CompilationUnit> mcu,
    std::optional<at::Device> device) {
  std::stringstream picklename;
  picklename << archive_name << ".pkl";
  at::DataPtr pickle_ptr;
  size_t pickle_size = 0;
  std::tie(pickle_ptr, pickle_size) = reader_->getRecord(picklename.str());

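  // Stream the in-memory pickle record back to the Unpickler in chunks,
  // tracking how many bytes have been consumed so far.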
  size_t bytes_read = 0;
  auto data = reinterpret_cast<const char*>(pickle_ptr.get());
  auto reader = [&](char* buffer, size_t len) -> size_t {
    if (bytes_read >= pickle_size) {
      return 0;
    }
    len = std::min(pickle_size - bytes_read, len);
    // Copy len bytes into buffer
    const char* start = data + bytes_read;
    std::memcpy(buffer, start, len);
    bytes_read += len;
    return len;
  };

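  // Resolve qualified type names from the pickle stream to TypePtrs,
  // registering mobile class types with the compilation unit on first use.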
  static const c10::QualifiedName torchPrefix = "__torch__";
  auto type_resolver = [&](const c10::QualifiedName& qn) {
    TypePtr type;
    // HACK: first we check whether the name starts with `__torch__` to tell if
    // it's "supposed" to be a class type. This is a reliable check today, but
    // there is no guarantee that this is the case. The real solution is to
    // merge type parsers so we can share class resolution logic.
    if (torchPrefix.isPrefixOf(qn)) {
      if (compilation_unit_->get_class(qn) == nullptr) {
        auto typeptr = ClassType::create(qn, compilation_unit_, true);
        compilation_unit_->register_type(typeptr);
      }
      type = compilation_unit_->get_class(qn);
    } else {
      type = c10::parseType(qn.qualifiedName());
    }
    return c10::StrongTypePtr(compilation_unit_, type);
  };

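  // Rehydrate objects from their serialized state: prefer a mobile
  // __setstate__ method, then a bound custom class's __setstate__, and
  // otherwise copy the serialized dict entries directly into attribute slots.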
  auto obj_loader = [&](const at::StrongTypePtr& type, IValue input) {
    auto cls = type.type_->expect<at::ClassType>();
    auto qn = cls->name();
    c10::QualifiedName method_name(qn.value(), "__setstate__");
    auto setstate = mcu->find_function(method_name);
    auto find_custom_class_with_setstate = [&qn]() -> c10::ClassTypePtr {
      auto custom_class_type = torch::jit::getCustomClass(qn->qualifiedName());
      if (custom_class_type && custom_class_type->findMethod("__setstate__")) {
        return custom_class_type;
      }
      return nullptr;
    };
    if (setstate) {
      auto obj = c10::ivalue::Object::create(type, 0);
      Stack stack({obj, input});
      setstate->run(stack);
      return obj;
    } else if (auto custom_class_type = find_custom_class_with_setstate()) {
      auto obj = c10::ivalue::Object::create(
          c10::StrongTypePtr(nullptr, custom_class_type), 1);
      Stack stack({obj, input});
      custom_class_type->getMethod("__setstate__").run(stack);
      return obj;
    } else {
      auto dict = std::move(input).toGenericDict();
      size_t ndict = dict.size();
      auto obj = c10::ivalue::Object::create(type, ndict);
      auto it = dict.begin();
      for (const auto i : c10::irange(ndict)) {
        std::stringstream name;
        name << it->key();
        cls->addOrCheckAttribute(name.str(), it->key().type());
        obj->setSlot(i, it->value());
        ++it;
      }
      return obj;
    }
  };

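  // Tensor payloads live in the ZIP as "<archive_name>/<record name>" entries.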
  auto read_record = [&](const std::string& name) {
    std::stringstream ss;
    ss << archive_name << "/" << name;
    return std::get<0>(reader_->getRecord(ss.str()));
  };

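  // Assemble the Unpickler from the callbacks above and parse the archive's
  // top-level IValue.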
  Unpickler unpickler(
      reader,
      std::move(type_resolver),
      std::move(obj_loader),
      std::move(read_record),
      device,
      false,
      nullptr);
  return unpickler.parse_ivalue();
}

/**
 * Extracts and returns the parameter map serialized as ZIP + Pickle in @p rai.
 */
std::map<std::string, at::Tensor> load_parameters_from_zip(
    std::unique_ptr<ReadAdapterInterface> rai,
    std::optional<c10::Device> device) {
  auto reader = std::make_unique<PyTorchStreamReader>(std::move(rai));
  IValueUnpickler unpickler(std::move(reader));
  auto result = unpickler.deserialize(device).toGenericDict();
  std::map<std::string, at::Tensor> map;
  for (const auto& e : result) {
    auto key = e.key().toStringRef();
    auto value = e.value().toTensor().tensor_data();
    map[key] = value;
  }
  return map;
}

} // namespace

/**
 * Extracts the parameter map stored in @p module. Expects a layout
 * compatible with the one created by #_save_parameters().
 */
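// Usage sketch (assumes `m` holds a mobile::Module deserialized from data
// written by #_save_parameters(); the "weight" key is illustrative):
//   auto params = mobile_module_to_parameter_map(m);
//   at::Tensor weight = params.at("weight");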
std::map<std::string, at::Tensor> mobile_module_to_parameter_map(
    const mobile::Module& module) {
  // Safely look for a slot with the expected name. Note that
  // c10::ivalue::Object::getAttr() is not safe if the attribute isn't present.
  auto obj = module._ivalue();
  const std::vector<IValue>& slots = obj->slots();
  for (const auto i : c10::irange(slots.size())) {
    if (obj->type()->getAttributeName(i) ==
        mobile::internal::kSavedParametersAttributeName) {
      // Found a slot with the right name; make sure it's a
      // Dict<string, Tensor>.
      c10::IValue data = slots[i];
      if (data.isGenericDict()) {
        auto data_dict = data.toGenericDict();

        // The key and value should be DynamicTypes that wrap String and
        // Tensor.
        c10::DynamicType* keyType =
            data_dict.keyType()->castRaw<c10::DynamicType>();
        c10::DynamicType* valueType =
            data_dict.valueType()->castRaw<c10::DynamicType>();
        if (keyType != nullptr &&
            keyType->fallback()->kind() == TypeKind::StringType &&
            valueType != nullptr &&
            valueType->fallback()->kind() == TypeKind::TensorType) {
          // Name and type are good; copy the contents to the output map.
          std::map<std::string, at::Tensor> params;
          for (const auto& e : data_dict) {
            // The source Tensor points into the flatbuffer data associated
            // with the Module. But, this Tensor needs to outlive the Module,
            // since the caller of _load_parameters() won't have a pointer to
            // the Module. So, return a deep copy.
            const auto& source = e.value().toTensor();
            at::Tensor copy = at::empty_like(source); // Must be the same shape.
            copy.copy_(source);

            params[e.key().toStringRef()] = copy;
          }
          return params;
        }
      }
    }
  }

  TORCH_CHECK(
      false,
      "Could not find Dict<string, Tensor> named '",
      mobile::internal::kSavedParametersAttributeName,
      "' in deserialized mobile::Module");
}

static std::map<std::string, at::Tensor> _load_parameters_bytes(
    const std::shared_ptr<char>& data,
    size_t size,
    std::optional<at::Device> device) {
  TORCH_CHECK(size >= kFileFormatHeaderSize, "Unrecognized data format");
  FileFormat format = getFileFormat(data.get());
  // Call the appropriate parser.
  std::map<std::string, at::Tensor> map;
  switch (format) {
    case FileFormat::FlatbufferFileFormat: {
      auto m = parse_flatbuffer_no_object(data, size, device);
      map = mobile_module_to_parameter_map(m);
      break;
    }

    case FileFormat::ZipFileFormat: {
      auto rai = std::make_unique<caffe2::serialize::MemoryReadAdapter>(
          data.get(), size);
      map = load_parameters_from_zip(std::move(rai), device);
      break;
    }

    default:
      TORCH_CHECK(false, "Unrecognized data format");
  }
  return map;
}
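
// The overloads below accept either a stream or a file path. Example
// (sketch; the path and parameter name are illustrative):
//   auto params =
//       torch::jit::_load_parameters("params.pt", c10::Device(at::kCPU));
//   at::Tensor bias = params["bias"];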

std::map<std::string, at::Tensor> _load_parameters(
    std::istream& in,
    std::optional<at::Device> device) {
  auto [data, size] = get_stream_content(in);
  return _load_parameters_bytes(data, size, device);
}

std::map<std::string, at::Tensor> _load_parameters(
    const std::string& filename,
    std::optional<at::Device> device) {
  auto [data, size] = get_file_content(filename.c_str());
  return _load_parameters_bytes(data, size, device);
}

} // namespace torch::jit