// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/serialization/pickle.cpp
// (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/jit/serialization/pickle.h>

#include <ATen/core/ivalue.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/serialization/import_read.h>

#include <algorithm>
#include <cstring>

namespace torch::jit {

namespace {

c10::StrongTypePtr customClassResolver(const c10::QualifiedName& qn) {
  at::TypePtr type = nullptr;
  if (c10::QualifiedName("__torch__").isPrefixOf(qn)) {
    type = torch::getCustomClass(qn.qualifiedName());
  } else {
    // This is a regular type, fall back to the default type parser
    torch::jit::ScriptTypeParser parser;
    type = parser.parseType(qn.qualifiedName());
    return c10::StrongTypePtr(nullptr, std::move(type));
  }
  if (type == nullptr) {
    TORCH_CHECK(
        false,
        "Couldn't resolve type '",
        qn.qualifiedName(),
        "', did you forget to add its build dependency?");
  }
  // Passing nullptr is a little bit sus, but should be fine:
  // 1. The lifetime of the class type is not tied to a specific
  //    CompilationUnit but rather to the global custom class registry.
  // 2. We will not access the `cu_` field, and this StrongTypePtr is
  //    discarded immediately post-deserialization.
  return c10::StrongTypePtr(nullptr, std::move(type));
}

} // namespace

void pickle(
    std::function<void(const char* data_start, size_t data_len)> writer,
    const IValue& ivalue,
    std::vector<at::Tensor>* tensor_table) {
  Pickler pickler(std::move(writer), tensor_table, nullptr, nullptr);
  pickler.protocol();
  pickler.pushIValue(ivalue);
  pickler.stop();
}

std::vector<char> pickle(
    const IValue& ivalue,
    std::vector<at::Tensor>* tensor_table) {
  std::vector<char> data;

  pickle(
      [&](const char* bytes, size_t len) {
        data.insert(data.end(), bytes, bytes + len);
      },
      ivalue,
      tensor_table);

  return data;
}
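
// Usage sketch (illustrative, not part of the original file): pickling an
// IValue into an in-memory byte buffer. With a null tensor_table, tensor
// data is written inline into the pickle stream.
//
//   at::IValue value(std::vector<int64_t>{1, 2, 3});
//   std::vector<char> bytes =
//       torch::jit::pickle(value, /*tensor_table=*/nullptr);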

// This has to live here, rather than in the C++ API, to mirror torch.save,
// since the mobile build excludes the C++ API.
std::vector<char> pickle_save(const at::IValue& ivalue) {
#ifndef C10_MOBILE
  // Pickle the IValue into an array of bytes
  std::vector<char> pickle_data;
  Pickler pickler([&](const char* buf, size_t size) {
    pickle_data.insert(pickle_data.end(), buf, buf + size);
  });
  pickler.protocol();
  pickler.pushIValue(ivalue);
  pickler.stop();

  std::vector<char> container_data;
  container_data.reserve(pickle_data.size());

  caffe2::serialize::PyTorchStreamWriter writer(
      [&](const void* void_bytes, size_t len) {
        const char* bytes = reinterpret_cast<const char*>(void_bytes);
        container_data.insert(container_data.end(), bytes, bytes + len);
        return len;
      });

  // Write the generated bytes and the associated tensors into a data.pkl file
  // and data/0, data/1, data/2... files for each of the tensors
  writeArchiveAndTensors(
      "data",
      pickle_data.data(),
      pickle_data.size(),
      pickler.tensorData(),
      writer);
  return container_data;
#else
  AT_ERROR(
      "pickle_save not supported on mobile "
      "(see https://github.com/pytorch/pytorch/pull/30108)");
#endif
}
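
// Usage sketch (illustrative, not part of the original file): pickle_save
// emits a complete zip container mirroring the torch.save layout, so the
// bytes can be written to disk and loaded from Python with torch.load.
// Assumes <fstream> for std::ofstream.
//
//   at::IValue value(at::ones({2, 2}));
//   std::vector<char> container = torch::jit::pickle_save(value);
//   std::ofstream out("value.pt", std::ios::binary);
//   out.write(container.data(),
//             static_cast<std::streamsize>(container.size()));
//   // Python: torch.load("value.pt")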

#ifndef C10_MOBILE
size_t VectorReader::read(uint64_t pos, void* buf, size_t n, const char* what)
    const {
  std::copy(
      data_.data() + pos, data_.data() + pos + n, reinterpret_cast<char*>(buf));
  return n;
}

size_t StringViewReader::read(
    uint64_t pos,
    void* buf,
    size_t n,
    const char* what) const {
  std::copy(
      data_.data() + pos, data_.data() + pos + n, reinterpret_cast<char*>(buf));
  return n;
}
#endif

IValue pickle_load(const std::vector<char>& data) {
  // Read in the pickle data
#ifndef C10_MOBILE
  caffe2::serialize::PyTorchStreamReader reader(
      std::make_unique<VectorReader>(data));

  return readArchiveAndTensors(
      "data",
      /*pickle_prefix=*/"",
      /*tensor_prefix=*/"",
      /*type_resolver=*/std::nullopt,
      /*obj_loader=*/std::nullopt,
      /*device=*/std::nullopt,
      reader);
#else
  AT_ERROR(
      "pickle_load not supported on mobile "
      "(see https://github.com/pytorch/pytorch/pull/30108)");
#endif
}
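
// Usage sketch (illustrative, not part of the original file): pickle_load is
// the inverse of pickle_save, reading the zip container back into an IValue.
//
//   at::IValue original(std::vector<double>{1.0, 2.0});
//   std::vector<char> container = torch::jit::pickle_save(original);
//   at::IValue restored = torch::jit::pickle_load(container);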

// A specialized version of pickle_load that can load custom objects.
c10::IValue pickle_load_obj(std::string_view data) {
#ifndef C10_MOBILE
  caffe2::serialize::PyTorchStreamReader reader(
      std::make_unique<torch::jit::StringViewReader>(data));
  return torch::jit::readArchiveAndTensors(
      "data",
      /*pickle_prefix=*/"",
      /*tensor_prefix=*/"",
      /*type_resolver=*/customClassResolver,
      /*obj_loader=*/torch::jit::ObjLoaderFunc,
      /*device=*/std::nullopt,
      reader);
#else
  AT_ERROR(
      "pickle_load_obj not supported on mobile "
      "(see https://github.com/pytorch/pytorch/pull/30108)");
#endif
}
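
// Usage sketch (illustrative; MyStack is a hypothetical torchbind class):
// pickle_load_obj can restore a custom class instance, provided the class was
// registered in this process (e.g. via TORCH_LIBRARY) before loading.
//
//   std::string_view payload = /* bytes previously produced for the object */;
//   c10::IValue obj = torch::jit::pickle_load_obj(payload);
//   c10::intrusive_ptr<MyStack> stack = obj.toCustomClass<MyStack>();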

IValue unpickle(
    std::function<size_t(char*, size_t)> reader,
    TypeResolver type_resolver,
    c10::ArrayRef<at::Tensor> tensor_table,
    c10::TypePtr (*type_parser)(const std::string&),
    ObjLoader obj_loader) {
  Unpickler unpickler(
      std::move(reader),
      std::move(type_resolver),
      tensor_table,
      std::move(obj_loader),
      type_parser);
  return unpickler.parse_ivalue();
}
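
// Usage sketch (illustrative, not part of the original file; assumes the
// default arguments for type_parser and obj_loader declared in pickle.h):
// the reader overload lets unpickle consume bytes incrementally, e.g. from a
// stream, returning how many bytes were actually produced on each call.
//
//   std::istringstream stream(pickled_string);  // needs <sstream>
//   at::IValue value = torch::jit::unpickle(
//       [&](char* buffer, size_t len) -> size_t {
//         stream.read(buffer, static_cast<std::streamsize>(len));
//         return static_cast<size_t>(stream.gcount());
//       },
//       /*type_resolver=*/nullptr,
//       /*tensor_table=*/{});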

IValue unpickle(
    const char* data,
    size_t size,
    TypeResolver type_resolver,
    c10::ArrayRef<at::Tensor> tensor_table,
    c10::TypePtr (*type_parser)(const std::string&)) {
  return unpickle(
      data, size, nullptr, std::move(type_resolver), tensor_table, type_parser);
}

IValue unpickle(
    const char* data,
    size_t size,
    ObjLoader obj_loader,
    TypeResolver type_resolver,
    c10::ArrayRef<at::Tensor> tensor_table,
    c10::TypePtr (*type_parser)(const std::string&)) {
  size_t bytes_read = 0;
  return unpickle(
      [&](char* buffer, size_t len) -> size_t {
        if (bytes_read >= size) {
          return 0;
        }
        len = std::min(size - bytes_read, len);
        // Copy len bytes into buffer
        const char* start = data + bytes_read;
        std::memcpy(buffer, start, len);
        bytes_read += len;
        return len;
      },
      std::move(type_resolver),
      tensor_table,
      type_parser,
      std::move(obj_loader));
}
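
// Usage sketch (illustrative, not part of the original file; assumes the
// default type_parser declared in pickle.h): round-tripping with an external
// tensor table. pickle() records indices into the table, and unpickle()
// reads the tensors back out of the same table.
//
//   std::vector<at::Tensor> table;
//   at::IValue value(at::rand({3}));
//   std::vector<char> bytes = torch::jit::pickle(value, &table);
//   at::IValue restored = torch::jit::unpickle(
//       bytes.data(), bytes.size(), /*type_resolver=*/nullptr, table);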

} // namespace torch::jit