1 #include <torch/csrc/jit/mobile/train/export_data.h>
2
3 #include <torch/csrc/jit/mobile/import_export_common.h>
4 #include <torch/csrc/jit/mobile/module.h>
5 #include <torch/csrc/jit/runtime/instruction.h>
6 #include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
7 #include <torch/csrc/jit/serialization/pickler.h>
8 #include <torch/csrc/jit/serialization/type_name_uniquer.h>
9
10 #include <caffe2/serialize/inline_container.h>
11
12 #include <ATen/core/ivalue.h>
13 #include <ATen/core/jit_type.h>
14
#include <fstream>
#include <functional>
#include <map>
#include <ostream>
#include <string>
#include <vector>
17
18 namespace torch::jit {
19 namespace mobile {
20
21 char const* toString(OpCode op);
22
23 namespace {
24
25 /**
26 * Serializes an IValue using Pickle, and puts it in a file named "data.pkl"
27 * in a ZIP wrapper.
28 */
29 class IValuePickler final {
30 public:
IValuePickler(const std::string & filename)31 explicit IValuePickler(const std::string& filename) : writer_(filename) {}
32
IValuePickler(const std::function<size_t (const void *,size_t)> & writer_func)33 explicit IValuePickler(
34 const std::function<size_t(const void*, size_t)>& writer_func)
35 : writer_(writer_func) {}
36
serialize(const IValue & object)37 void serialize(const IValue& object) {
38 // Serialize just the data
39 writeArchive("data", object);
40 }
41
42 private:
writeArchive(const std::string & archive_name,const IValue & value)43 void writeArchive(const std::string& archive_name, const IValue& value) {
44 std::vector<char> data;
45 // Vector to capture the run-time class types during pickling the IValues
46 std::vector<c10::ClassTypePtr> memoizedClassTypes;
47 Pickler data_pickle(
48 [&](const char* buf, size_t size) {
49 data.insert(data.end(), buf, buf + size);
50 },
51 nullptr,
52 [&](const c10::ClassTypePtr& t) {
53 return type_name_uniquer_.getUniqueName(t);
54 },
55 &memoizedClassTypes);
56 data_pickle.protocol();
57 data_pickle.pushIValue(value);
58 data_pickle.stop();
59 size_t i = 0;
60 std::string prefix = archive_name + "/";
61 for (const auto& td : data_pickle.tensorData()) {
62 WriteableTensorData writable_td = getWriteableTensorData(td);
63 std::string fname = prefix + std::to_string(i++);
64 writer_.writeRecord(fname, writable_td.data(), writable_td.sizeInBytes());
65 }
66 std::string fname = archive_name + ".pkl";
67 writer_.writeRecord(fname, data.data(), data.size());
68 }
69
70 caffe2::serialize::PyTorchStreamWriter writer_;
71 TypeNameUniquer type_name_uniquer_;
72 };
73
74 } // namespace
75
76 /**
77 * Converts a map of named tensors to a c10::Dict.
78 */
tensor_map_to_dict(const std::map<std::string,at::Tensor> & map)79 c10::Dict<std::string, at::Tensor> tensor_map_to_dict(
80 const std::map<std::string, at::Tensor>& map) {
81 c10::Dict<std::string, at::Tensor> dict;
82 for (const auto& e : map) {
83 dict.insert(e.first, e.second);
84 }
85 return dict;
86 }
87
88 /**
89 * Returns a Module with a single attribute, with the attribute name specified
90 * by #internal::kSavedParametersAttributeName, whose value is the provided
91 * dict.
92 */
tensor_dict_to_mobile(const c10::Dict<std::string,at::Tensor> & dict)93 mobile::Module tensor_dict_to_mobile(
94 const c10::Dict<std::string, at::Tensor>& dict) {
95 // Create an Object to back the Module, with an attribute to hold the dict.
96 auto cu = std::make_shared<torch::jit::CompilationUnit>();
97 // Note that the name doesn't really matter, but it must begin with
98 // "__torch__." to be treated as a valid class when being imported.
99 auto cls = c10::ClassType::create(
100 "__torch__.SavedParameters", cu, /*is_module=*/true);
101 cls->addAttribute(
102 internal::kSavedParametersAttributeName,
103 c10::DictType::create(dict.keyType(), dict.valueType()));
104 auto object = c10::ivalue::Object::create(
105 c10::StrongTypePtr(std::move(cu), std::move(cls)), /*numSlots=*/1);
106
107 // Add the dict as an attribute.
108 object->setAttr(internal::kSavedParametersAttributeName, dict);
109
110 // Wrap the Object in a Module.
111 auto mcu = std::make_shared<mobile::CompilationUnit>();
112 return mobile::Module(object, mcu);
113 }
114
115 } // namespace mobile
116
// Optional hook for writing a mobile::Module through a raw writer callback.
// NOTE(review): this pointer is never read or assigned within this file —
// presumably it is set/used by a serializer translation unit that may not be
// linked in; confirm its consumer before removing.
void (*_save_mobile_module_to)(
    const mobile::Module& module,
    const std::function<size_t(const void*, size_t)>& writer_func) = nullptr;
120
_save_parameters(const std::map<std::string,at::Tensor> & map,std::ostream & out,bool use_flatbuffer)121 void _save_parameters(
122 const std::map<std::string, at::Tensor>& map,
123 std::ostream& out,
124 bool use_flatbuffer) {
125 auto dict = mobile::tensor_map_to_dict(map);
126
127 auto write_func = [&out](const void* buf, size_t nbytes) -> size_t {
128 out.write(
129 static_cast<const char*>(buf), static_cast<std::streamsize>(nbytes));
130 return !out ? 0 : nbytes;
131 };
132
133 if (use_flatbuffer) {
134 save_mobile_module_to_func(mobile::tensor_dict_to_mobile(dict), write_func);
135 } else {
136 // For Pickle, we only serialize the dict itself.
137 mobile::IValuePickler pickler(write_func);
138 pickler.serialize(dict);
139 }
140 }
141
_save_parameters(const std::map<std::string,at::Tensor> & map,const std::string & filename,bool use_flatbuffer)142 void _save_parameters(
143 const std::map<std::string, at::Tensor>& map,
144 const std::string& filename,
145 bool use_flatbuffer) {
146 auto dict = mobile::tensor_map_to_dict(map);
147
148 std::ofstream ifile(filename);
149 _save_parameters(map, ifile, use_flatbuffer);
150 }
151
152 } // namespace torch::jit
153