xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/train/export_data.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/jit/mobile/train/export_data.h>

#include <torch/csrc/jit/mobile/import_export_common.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/jit/serialization/type_name_uniquer.h>

#include <caffe2/serialize/inline_container.h>

#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>

#include <cstddef>
#include <fstream>
#include <functional>
#include <map>
#include <ostream>
#include <string>
#include <vector>
17 
18 namespace torch::jit {
19 namespace mobile {
20 
21 char const* toString(OpCode op);
22 
23 namespace {
24 
25 /**
26  * Serializes an IValue using Pickle, and puts it in a file named "data.pkl"
27  * in a ZIP wrapper.
28  */
29 class IValuePickler final {
30  public:
IValuePickler(const std::string & filename)31   explicit IValuePickler(const std::string& filename) : writer_(filename) {}
32 
IValuePickler(const std::function<size_t (const void *,size_t)> & writer_func)33   explicit IValuePickler(
34       const std::function<size_t(const void*, size_t)>& writer_func)
35       : writer_(writer_func) {}
36 
serialize(const IValue & object)37   void serialize(const IValue& object) {
38     // Serialize just the data
39     writeArchive("data", object);
40   }
41 
42  private:
writeArchive(const std::string & archive_name,const IValue & value)43   void writeArchive(const std::string& archive_name, const IValue& value) {
44     std::vector<char> data;
45     // Vector to capture the run-time class types during pickling the IValues
46     std::vector<c10::ClassTypePtr> memoizedClassTypes;
47     Pickler data_pickle(
48         [&](const char* buf, size_t size) {
49           data.insert(data.end(), buf, buf + size);
50         },
51         nullptr,
52         [&](const c10::ClassTypePtr& t) {
53           return type_name_uniquer_.getUniqueName(t);
54         },
55         &memoizedClassTypes);
56     data_pickle.protocol();
57     data_pickle.pushIValue(value);
58     data_pickle.stop();
59     size_t i = 0;
60     std::string prefix = archive_name + "/";
61     for (const auto& td : data_pickle.tensorData()) {
62       WriteableTensorData writable_td = getWriteableTensorData(td);
63       std::string fname = prefix + std::to_string(i++);
64       writer_.writeRecord(fname, writable_td.data(), writable_td.sizeInBytes());
65     }
66     std::string fname = archive_name + ".pkl";
67     writer_.writeRecord(fname, data.data(), data.size());
68   }
69 
70   caffe2::serialize::PyTorchStreamWriter writer_;
71   TypeNameUniquer type_name_uniquer_;
72 };
73 
74 } // namespace
75 
76 /**
77  * Converts a map of named tensors to a c10::Dict.
78  */
tensor_map_to_dict(const std::map<std::string,at::Tensor> & map)79 c10::Dict<std::string, at::Tensor> tensor_map_to_dict(
80     const std::map<std::string, at::Tensor>& map) {
81   c10::Dict<std::string, at::Tensor> dict;
82   for (const auto& e : map) {
83     dict.insert(e.first, e.second);
84   }
85   return dict;
86 }
87 
88 /**
89  * Returns a Module with a single attribute, with the attribute name specified
90  * by #internal::kSavedParametersAttributeName, whose value is the provided
91  * dict.
92  */
tensor_dict_to_mobile(const c10::Dict<std::string,at::Tensor> & dict)93 mobile::Module tensor_dict_to_mobile(
94     const c10::Dict<std::string, at::Tensor>& dict) {
95   // Create an Object to back the Module, with an attribute to hold the dict.
96   auto cu = std::make_shared<torch::jit::CompilationUnit>();
97   // Note that the name doesn't really matter, but it must begin with
98   // "__torch__." to be treated as a valid class when being imported.
99   auto cls = c10::ClassType::create(
100       "__torch__.SavedParameters", cu, /*is_module=*/true);
101   cls->addAttribute(
102       internal::kSavedParametersAttributeName,
103       c10::DictType::create(dict.keyType(), dict.valueType()));
104   auto object = c10::ivalue::Object::create(
105       c10::StrongTypePtr(std::move(cu), std::move(cls)), /*numSlots=*/1);
106 
107   // Add the dict as an attribute.
108   object->setAttr(internal::kSavedParametersAttributeName, dict);
109 
110   // Wrap the Object in a Module.
111   auto mcu = std::make_shared<mobile::CompilationUnit>();
112   return mobile::Module(object, mcu);
113 }
114 
115 } // namespace mobile
116 
// Optional hook for writing a mobile::Module through a caller-supplied
// byte sink. NOTE(review): not referenced anywhere in this file —
// presumably installed/used by external callers for backward
// compatibility; confirm before removing.
void (*_save_mobile_module_to)(
    const mobile::Module& module,
    const std::function<size_t(const void*, size_t)>& writer_func) = nullptr;
120 
_save_parameters(const std::map<std::string,at::Tensor> & map,std::ostream & out,bool use_flatbuffer)121 void _save_parameters(
122     const std::map<std::string, at::Tensor>& map,
123     std::ostream& out,
124     bool use_flatbuffer) {
125   auto dict = mobile::tensor_map_to_dict(map);
126 
127   auto write_func = [&out](const void* buf, size_t nbytes) -> size_t {
128     out.write(
129         static_cast<const char*>(buf), static_cast<std::streamsize>(nbytes));
130     return !out ? 0 : nbytes;
131   };
132 
133   if (use_flatbuffer) {
134     save_mobile_module_to_func(mobile::tensor_dict_to_mobile(dict), write_func);
135   } else {
136     // For Pickle, we only serialize the dict itself.
137     mobile::IValuePickler pickler(write_func);
138     pickler.serialize(dict);
139   }
140 }
141 
_save_parameters(const std::map<std::string,at::Tensor> & map,const std::string & filename,bool use_flatbuffer)142 void _save_parameters(
143     const std::map<std::string, at::Tensor>& map,
144     const std::string& filename,
145     bool use_flatbuffer) {
146   auto dict = mobile::tensor_map_to_dict(map);
147 
148   std::ofstream ifile(filename);
149   _save_parameters(map, ifile, use_flatbuffer);
150 }
151 
152 } // namespace torch::jit
153