xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/import_data.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/mobile/import_data.h>
2 
3 #include <ATen/Functions.h>
4 #include <ATen/core/ivalue.h>
5 #include <c10/util/irange.h>
6 
7 #include <torch/csrc/jit/api/compilation_unit.h>
8 #include <torch/csrc/jit/mobile/file_format.h>
9 #include <torch/csrc/jit/mobile/flatbuffer_loader.h>
10 #include <torch/csrc/jit/mobile/import.h>
11 #include <torch/csrc/jit/mobile/import_export_common.h>
12 #include <torch/csrc/jit/mobile/module.h>
13 #include <torch/csrc/jit/mobile/observer.h>
14 #include <torch/csrc/jit/mobile/type_parser.h>
15 #include <torch/csrc/jit/runtime/instruction.h>
16 #include <torch/csrc/jit/serialization/unpickler.h>
17 #include <torch/custom_class.h>
18 
19 #include <caffe2/serialize/in_memory_adapter.h>
20 #include <string>
21 #include <vector>
22 
23 namespace torch::jit {
24 using caffe2::serialize::PyTorchStreamReader;
25 
26 namespace {
27 
/**
 * Given a ZIP file containing a file named "data.pkl", uses Pickle to
 * deserialize the file and returns the IValue inside it.
 */
class IValueUnpickler final {
 public:
  // Takes ownership of `reader`, which wraps the ZIP archive to read from.
  explicit IValueUnpickler(std::unique_ptr<PyTorchStreamReader> reader);
  // Unpickles and returns the IValue stored in the "data" archive.
  // `device` is forwarded to the Unpickler (target device for tensors).
  c10::IValue deserialize(std::optional<at::Device> device);

 private:
  // Unpickles "<archive_name>.pkl" from the ZIP file. Tensor payload records
  // are read from "<archive_name>/<record>". `mcu` supplies any mobile
  // __setstate__ methods used to reconstruct objects.
  c10::IValue readArchive(
      const std::string& archive_name,
      std::shared_ptr<mobile::CompilationUnit> mcu,
      std::optional<at::Device> device);

  // Owns class types created/resolved while unpickling.
  std::shared_ptr<CompilationUnit> compilation_unit_;
  // Source ZIP archive.
  std::unique_ptr<PyTorchStreamReader> reader_;
};
46 
IValueUnpickler(std::unique_ptr<PyTorchStreamReader> reader)47 IValueUnpickler::IValueUnpickler(std::unique_ptr<PyTorchStreamReader> reader)
48     : compilation_unit_(std::make_shared<CompilationUnit>()),
49       reader_(std::move(reader)) {}
50 
deserialize(std::optional<at::Device> device)51 c10::IValue IValueUnpickler::deserialize(std::optional<at::Device> device) {
52   auto mcu = std::make_shared<mobile::CompilationUnit>();
53 
54   return readArchive("data", mcu, device);
55 }
56 
c10::IValue IValueUnpickler::readArchive(
    const std::string& archive_name,
    std::shared_ptr<mobile::CompilationUnit> mcu,
    std::optional<at::Device> device) {
  // The pickle program lives in "<archive_name>.pkl" inside the ZIP file.
  std::stringstream picklename;
  picklename << archive_name << ".pkl";
  at::DataPtr pickle_ptr;
  size_t pickle_size = 0;
  std::tie(pickle_ptr, pickle_size) = reader_->getRecord(picklename.str());

  // Streaming reader over the in-memory pickle bytes: copies up to `len`
  // bytes into `buffer` and returns the number copied (0 once exhausted).
  size_t bytes_read = 0;
  auto data = reinterpret_cast<const char*>(pickle_ptr.get());
  auto reader = [&](char* buffer, size_t len) -> size_t {
    if (bytes_read >= pickle_size) {
      return 0;
    }
    len = std::min(pickle_size - bytes_read, len);
    // Copy len bytes into buffer
    const char* start = data + bytes_read;
    std::memcpy(buffer, start, len);
    bytes_read += len;
    return len;
  };

  static const c10::QualifiedName torchPrefix = "__torch__";
  // Maps a qualified type name from the pickle stream to a resolved type.
  auto type_resolver = [&](const c10::QualifiedName& qn) {
    TypePtr type;
    // HACK: first we check whether the name starts with `__torch__` to tell if
    // it's "supposed" to be a class type. This is a reliable check today, but
    // there is no guarantee that this is the case. The real solution is to
    // merge type parsers so we can share class resolution logic.
    if (torchPrefix.isPrefixOf(qn)) {
      // Lazily create and register an empty class type on first sight.
      if (compilation_unit_->get_class(qn) == nullptr) {
        auto typeptr = ClassType::create(qn, compilation_unit_, true);
        compilation_unit_->register_type(typeptr);
      }
      type = compilation_unit_->get_class(qn);
    } else {
      // Non-__torch__ names go through the mobile type parser.
      type = c10::parseType(qn.qualifiedName());
    }
    return c10::StrongTypePtr(compilation_unit_, type);
  };

  // Reconstructs an object from its pickled state, trying in order:
  // 1) a __setstate__ method from `mcu`, 2) a registered custom class with
  // __setstate__, 3) attribute-by-attribute restoration from a state dict.
  auto obj_loader = [&](const at::StrongTypePtr& type, IValue input) {
    auto cls = type.type_->expect<at::ClassType>();
    auto qn = cls->name();
    c10::QualifiedName method_name(qn.value(), "__setstate__");
    auto setstate = mcu->find_function(method_name);
    auto find_custom_class_with_setstate = [&qn]() -> c10::ClassTypePtr {
      auto custom_class_type = torch::jit::getCustomClass(qn->qualifiedName());
      if (custom_class_type && custom_class_type->findMethod("__setstate__")) {
        return custom_class_type;
      }
      return nullptr;
    };
    if (setstate) {
      auto obj = c10::ivalue::Object::create(type, 0);
      Stack stack({obj, input});
      setstate->run(stack);
      return obj;
    } else if (auto custom_class_type = find_custom_class_with_setstate()) {
      auto obj = c10::ivalue::Object::create(
          c10::StrongTypePtr(nullptr, custom_class_type), 1);
      Stack stack({obj, input});
      custom_class_type->getMethod("__setstate__").run(stack);
      return obj;
    } else {
      // No __setstate__: treat the input as a dict of attribute name ->
      // value and copy each entry into a freshly created object slot.
      auto dict = std::move(input).toGenericDict();
      size_t ndict = dict.size();
      auto obj = c10::ivalue::Object::create(type, ndict);
      auto it = dict.begin();
      for (const auto i : c10::irange(ndict)) {
        std::stringstream name;
        name << it->key();
        cls->addOrCheckAttribute(name.str(), it->key().type());
        obj->setSlot(i, it->value());
        ++it;
      }
      return obj;
    }
  };

  // Tensor payload records are stored alongside the pickle under
  // "<archive_name>/<record_name>".
  auto read_record = [&](const std::string& name) {
    std::stringstream ss;
    ss << archive_name << "/" << name;
    return std::get<0>(reader_->getRecord(ss.str()));
  };

  Unpickler unpickler(
      reader,
      std::move(type_resolver),
      std::move(obj_loader),
      std::move(read_record),
      device,
      false,
      nullptr);
  return unpickler.parse_ivalue();
}
155 
156 /**
157  * Extracts and returns the parameter map serialized as ZIP + Pickle in @p rai.
158  */
load_parameters_from_zip(std::unique_ptr<ReadAdapterInterface> rai,std::optional<c10::Device> device)159 std::map<std::string, at::Tensor> load_parameters_from_zip(
160     std::unique_ptr<ReadAdapterInterface> rai,
161     std::optional<c10::Device> device) {
162   auto reader = std::make_unique<PyTorchStreamReader>(std::move(rai));
163   IValueUnpickler unpickler(std::move(reader));
164   auto result = unpickler.deserialize(device).toGenericDict();
165   std::map<std::string, at::Tensor> map;
166   for (const auto& e : result) {
167     auto key = e.key().toStringRef();
168     auto value = e.value().toTensor().tensor_data();
169     map[key] = value;
170   }
171   return map;
172 }
173 
174 } // namespace
175 
/**
 * Extracts the parameter map stored in @p module. Expects a layout
 * compatible with the one created by #_save_parameters().
 *
 * Fails with TORCH_CHECK if no attribute with the expected name and a
 * Dict<string, Tensor> value is found.
 */
std::map<std::string, at::Tensor> mobile_module_to_parameter_map(
    const mobile::Module& module) {
  // Safely look for a slot with the expected name. Note that
  // c10::ivalue::Object::getAttr() is not safe if the attribute isn't present.
  auto obj = module._ivalue();
  const std::vector<IValue>& slots = obj->slots();
  for (const auto i : c10::irange(slots.size())) {
    if (obj->type()->getAttributeName(i) ==
        mobile::internal::kSavedParametersAttributeName) {
      // Found a slot with the right name; make sure it's a
      // Dict<string, Tensor>.
      c10::IValue data = slots[i];
      if (data.isGenericDict()) {
        auto data_dict = data.toGenericDict();

        // The key and value should be DynamicTypes that wrap String and Tensor.
        c10::DynamicType* keyType =
            data_dict.keyType()->castRaw<c10::DynamicType>();
        c10::DynamicType* valueType =
            data_dict.valueType()->castRaw<c10::DynamicType>();
        if (keyType != nullptr &&
            keyType->fallback()->kind() == TypeKind::StringType &&
            valueType != nullptr &&
            valueType->fallback()->kind() == TypeKind::TensorType) {
          // Name and type are good; copy the contents to the output map.
          std::map<std::string, at::Tensor> params;
          for (const auto& e : data_dict) {
            // The source Tensor points into the flatbuffer data associated with
            // the Module. But, this Tensor needs to outlive the Module, since
            // the caller of _load_parameters() won't have a pointer to the
            // Module. So, return a deep copy.
            const auto& source = e.value().toTensor();
            at::Tensor copy = at::empty_like(source); // Must be the same shape.
            copy.copy_(source);

            params[e.key().toStringRef()] = copy;
          }
          return params;
        }
      }
    }
  }

  // No matching slot found (or it had the wrong type): hard failure.
  TORCH_CHECK(
      false,
      "Could not find Dict<string, Tensor> named '",
      mobile::internal::kSavedParametersAttributeName,
      "' in deserialized mobile::Module");
}
229 
_load_parameters_bytes(const std::shared_ptr<char> & data,size_t size,std::optional<at::Device> device)230 static std::map<std::string, at::Tensor> _load_parameters_bytes(
231     const std::shared_ptr<char>& data,
232     size_t size,
233     std::optional<at::Device> device) {
234   TORCH_CHECK(size >= kFileFormatHeaderSize, "Unrecognized data format");
235   FileFormat format = getFileFormat(data.get());
236   // Call the appropriate parser.
237   std::map<std::string, at::Tensor> map;
238   switch (format) {
239     case FileFormat::FlatbufferFileFormat: {
240       auto m = parse_flatbuffer_no_object(data, size, device);
241       map = mobile_module_to_parameter_map(m);
242       break;
243     }
244 
245     case FileFormat::ZipFileFormat: {
246       auto rai = std::make_unique<caffe2::serialize::MemoryReadAdapter>(
247           data.get(), size);
248       map = load_parameters_from_zip(std::move(rai), device);
249       break;
250     }
251 
252     default:
253       TORCH_CHECK(false, "Unrecognized data format");
254   }
255   return map;
256 }
257 
_load_parameters(std::istream & in,std::optional<at::Device> device)258 std::map<std::string, at::Tensor> _load_parameters(
259     std::istream& in,
260     std::optional<at::Device> device) {
261   auto [data, size] = get_stream_content(in);
262   return _load_parameters_bytes(data, size, device);
263 }
264 
_load_parameters(const std::string & filename,std::optional<at::Device> device)265 std::map<std::string, at::Tensor> _load_parameters(
266     const std::string& filename,
267     std::optional<at::Device> device) {
268   auto [data, size] = get_file_content(filename.c_str());
269   return _load_parameters_bytes(data, size, device);
270 }
271 
272 } // namespace torch::jit
273