// Source: /aosp_15_r20/external/pytorch/torch/csrc/jit/serialization/unpickler.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>

#include <ATen/core/ivalue.h>
#include <c10/util/ArrayRef.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/frontend/script_type_parser.h>
#include <torch/csrc/jit/serialization/pickler.h>
9 
10 namespace torch::jit {
11 
12 using TypeResolver =
13     std::function<c10::StrongTypePtr(const c10::QualifiedName&)>;
14 
15 using ObjLoader = std::function<
16     c10::intrusive_ptr<c10::ivalue::Object>(const at::StrongTypePtr&, IValue)>;
17 
18 class DeserializationStorageContext;
19 
20 // [unpickler refactor] there is some cruft around PickleOpCode::BUILD,
21 // PickleOpCode::NEWOBJ, and the last_opcode_ member below that should be
22 // deleted at some point, the Pickler doesn't produce it and it's only around to
23 // support models saved before 1.1
24 class TORCH_API Unpickler {
25   AT_DISALLOW_COPY_AND_ASSIGN(Unpickler);
26 
27   using TypeParserT = c10::TypePtr (*)(const std::string&);
28 
29  public:
30   // tensors inside the pickle are references to the tensor_table.
31   // class_resolver is to resolve strong class type, type_resolver_ is
32   // to resolve any JIT type. class_resolver and type_resolver are not merged
33   // here because some use cases need to get strong class type that
34   // type_resolver_ can not return.
35   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
36   Unpickler(
37       std::function<size_t(char*, size_t)> reader,
38       TypeResolver type_resolver,
39       c10::ArrayRef<at::Tensor> tensor_table,
40       TypeParserT type_parser = defaultTypeParser)
reader_(std::move (reader))41       : reader_(std::move(reader)),
42         tensor_table_(tensor_table),
43         type_resolver_(std::move(type_resolver)),
44         use_storage_device_(false),
45         type_parser_(type_parser),
46         version_(caffe2::serialize::kProducedFileFormatVersion) {}
47 
48   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
49   Unpickler(
50       std::function<size_t(char*, size_t)> reader,
51       TypeResolver type_resolver,
52       c10::ArrayRef<at::Tensor> tensor_table,
53       ObjLoader obj_loader,
54       TypeParserT type_parser = defaultTypeParser)
reader_(std::move (reader))55       : reader_(std::move(reader)),
56         tensor_table_(tensor_table),
57         type_resolver_(std::move(type_resolver)),
58         obj_loader_(std::move(obj_loader)),
59         use_storage_device_(false),
60         type_parser_(type_parser),
61         version_(caffe2::serialize::kProducedFileFormatVersion) {}
62 
63   // tensors inside the pickle contain meta-data, the raw tensor
64   // dead is retrieved by calling `read_record`.
65   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
66   Unpickler(
67       std::function<size_t(char*, size_t)> reader,
68       TypeResolver type_resolver,
69       ObjLoader obj_loader,
70       std::function<at::DataPtr(const std::string&)> read_record,
71       std::optional<at::Device> device,
72       bool use_storage_device = false,
73       TypeParserT type_parser = defaultTypeParser,
74       std::shared_ptr<DeserializationStorageContext> storage_context = nullptr)
reader_(std::move (reader))75       : reader_(std::move(reader)),
76         tensor_table_(),
77         type_resolver_(std::move(type_resolver)),
78         obj_loader_(std::move(obj_loader)),
79         read_record_(std::move(read_record)),
80         device_(device),
81         use_storage_device_(use_storage_device),
82         type_parser_(type_parser),
83         storage_context_(std::move(storage_context)),
84         version_(caffe2::serialize::kProducedFileFormatVersion) {}
85 
86   // consume the pickle stream, producing an IValue from the contents.
87   // Type Tags: the pickler will restore the type tags on
88   // List and Dict objects when possible IValue is an Object.
89   // Otherwise, Dict and List objects will end up with Any as their tag.
90   // If you know the type of the ivalue, tags can be restored with
91   // restoreAccurateTypeTags
92   IValue parse_ivalue();
93 
94   // [type tag serialization]
95   // This is used to determine whether to restore type tags be recursively
96   // descending into the returned stack object (if version_number <= 2), or
97   // if version_number >= 3, to use the type strings included in the pickle
98   // archive for container types. By default this is set to
99   // `kProducedFileFormatVersion` so unless you're loading a pickle file
100   // from alongside a corresponding `version` file, you don't need to set
101   // the version manually.
set_version(uint64_t version_number)102   void set_version(uint64_t version_number) {
103     version_ = version_number;
104   }
105 
defaultTypeParser(const std::string & str)106   static c10::TypePtr defaultTypeParser(const std::string& str) {
107     ScriptTypeParser parser;
108     return parser.parseType(str);
109   }
110 
111  private:
112   // No arguments ensures that a template argument must be specified
113   // so that the number of bytes read / type read is explicit
114   template <typename T>
read()115   T read() {
116     T item;
117     if (sizeof(T) <= buffer_remaining_) {
118       // Fast path: entirely from buffer.
119       memcpy(&item, buffer_.data() + buffer_pos_, sizeof(T));
120       buffer_remaining_ -= sizeof(T);
121       buffer_pos_ += sizeof(T);
122     } else {
123       // Don't over-template the slow path, to avoid code size bloat.
124       readSlowWithBuffer(reinterpret_cast<char*>(&item), sizeof(T));
125     }
126     return item;
127   }
128   void readSlowWithBuffer(char* dest, size_t sz);
129   std::string readBytes(size_t num_bytes);
130 
131   double readFloat();
132   void readGlobal(
133       const std::string& module_name,
134       const std::string& class_name);
135   void rebuildTensor(bool quantized);
136   void rebuildTensorFromTypeV2();
137   void rebuildSparseTensor();
138 #ifdef USE_DISTRIBUTED
139   void rebuildRRef();
140 #endif
141   PickleOpCode readInstruction();
readOpCode()142   PickleOpCode readOpCode() {
143     return static_cast<PickleOpCode>(read<uint8_t>());
144   }
145   std::string readString();
146   void readList(IValue list_ivalue);
147   void readListElements(IValue list_ivalue, size_t start);
148   void setInput(size_t memo_id);
149   void run();
150 
151   // Returns the number of bytes read. This should statefully
152   // remember the position. Don't call reader_ directly.
153   std::function<size_t(char*, size_t)> reader_;
154   // Small buffer to avoid calling reader_ on a per-byte basis.
155   std::array<char, 256> buffer_;
156   size_t buffer_pos_{0};
157   size_t buffer_remaining_{0};
158 
159   std::vector<IValue> stack_;
160 
161   // globals are represented on the stack as IValue integer indices
162   // into this list
163   std::vector<std::function<void(void)>> globals_;
164   std::vector<IValue> memo_table_;
165   std::vector<size_t> marks_;
166   c10::ArrayRef<at::Tensor> tensor_table_;
167 
168   // When deserializing types on lists and dicts, cache the type here
169   // so we don't have to parse the same type multiple times. Strings
170   // are already de-duplicated and replaced with BINGETs in the
171   // pickler, so we can just use the actual data pointer of each string.
172   std::unordered_map<std::string, c10::TypePtr> type_cache_;
173 
174   // optionally nullptr, needs to be present for creating classes
175   TypeResolver type_resolver_;
176   ObjLoader obj_loader_;
177   IValue empty_tuple_;
178 
179   std::function<at::DataPtr(const std::string&)> read_record_;
180   std::optional<at::Device> device_;
181   // When set to true, Unpickler will ignore the pickled device and use the
182   // device of the DataPtr returned by the read_record_ function. The default
183   // value of this flag is false.
184   const bool use_storage_device_;
185 
186   TypeParserT type_parser_{defaultTypeParser};
187 
188   // Used for torch.package to enable sharing of storages across
189   // ScriptModules and eager modules
190   std::shared_ptr<DeserializationStorageContext> storage_context_;
191 
192   // See [type tag serialization]
193   uint64_t version_;
194 
195   // See [NOTE] skip_next_read_global
196   uint8_t skip_next_read_global = 0;
197 };
198 
199 void restoreAccurateTypeTags(const IValue& root, const c10::TypePtr& type_tag);
200 
201 } // namespace torch::jit
202