1 #pragma once 2 3 #include <ATen/core/ivalue.h> 4 #include <c10/util/ArrayRef.h> 5 #include <caffe2/serialize/inline_container.h> 6 #include <torch/csrc/Export.h> 7 #include <torch/csrc/jit/frontend/script_type_parser.h> 8 #include <torch/csrc/jit/serialization/pickler.h> 9 10 namespace torch::jit { 11 12 using TypeResolver = 13 std::function<c10::StrongTypePtr(const c10::QualifiedName&)>; 14 15 using ObjLoader = std::function< 16 c10::intrusive_ptr<c10::ivalue::Object>(const at::StrongTypePtr&, IValue)>; 17 18 class DeserializationStorageContext; 19 20 // [unpickler refactor] there is some cruft around PickleOpCode::BUILD, 21 // PickleOpCode::NEWOBJ, and the last_opcode_ member below that should be 22 // deleted at some point, the Pickler doesn't produce it and it's only around to 23 // support models saved before 1.1 24 class TORCH_API Unpickler { 25 AT_DISALLOW_COPY_AND_ASSIGN(Unpickler); 26 27 using TypeParserT = c10::TypePtr (*)(const std::string&); 28 29 public: 30 // tensors inside the pickle are references to the tensor_table. 31 // class_resolver is to resolve strong class type, type_resolver_ is 32 // to resolve any JIT type. class_resolver and type_resolver are not merged 33 // here because some use cases need to get strong class type that 34 // type_resolver_ can not return. 35 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) 36 Unpickler( 37 std::function<size_t(char*, size_t)> reader, 38 TypeResolver type_resolver, 39 c10::ArrayRef<at::Tensor> tensor_table, 40 TypeParserT type_parser = defaultTypeParser) reader_(std::move (reader))41 : reader_(std::move(reader)), 42 tensor_table_(tensor_table), 43 type_resolver_(std::move(type_resolver)), 44 use_storage_device_(false), 45 type_parser_(type_parser), 46 version_(caffe2::serialize::kProducedFileFormatVersion) {} 47 48 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) 49 Unpickler( 50 std::function<size_t(char*, size_t)> reader, 51 TypeResolver type_resolver, 52 c10::ArrayRef<at::Tensor> tensor_table, 53 ObjLoader obj_loader, 54 TypeParserT type_parser = defaultTypeParser) reader_(std::move (reader))55 : reader_(std::move(reader)), 56 tensor_table_(tensor_table), 57 type_resolver_(std::move(type_resolver)), 58 obj_loader_(std::move(obj_loader)), 59 use_storage_device_(false), 60 type_parser_(type_parser), 61 version_(caffe2::serialize::kProducedFileFormatVersion) {} 62 63 // tensors inside the pickle contain meta-data, the raw tensor 64 // dead is retrieved by calling `read_record`. 65 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) 66 Unpickler( 67 std::function<size_t(char*, size_t)> reader, 68 TypeResolver type_resolver, 69 ObjLoader obj_loader, 70 std::function<at::DataPtr(const std::string&)> read_record, 71 std::optional<at::Device> device, 72 bool use_storage_device = false, 73 TypeParserT type_parser = defaultTypeParser, 74 std::shared_ptr<DeserializationStorageContext> storage_context = nullptr) reader_(std::move (reader))75 : reader_(std::move(reader)), 76 tensor_table_(), 77 type_resolver_(std::move(type_resolver)), 78 obj_loader_(std::move(obj_loader)), 79 read_record_(std::move(read_record)), 80 device_(device), 81 use_storage_device_(use_storage_device), 82 type_parser_(type_parser), 83 storage_context_(std::move(storage_context)), 84 version_(caffe2::serialize::kProducedFileFormatVersion) {} 85 86 // consume the pickle stream, producing an IValue from the contents. 87 // Type Tags: the pickler will restore the type tags on 88 // List and Dict objects when possible IValue is an Object. 89 // Otherwise, Dict and List objects will end up with Any as their tag. 90 // If you know the type of the ivalue, tags can be restored with 91 // restoreAccurateTypeTags 92 IValue parse_ivalue(); 93 94 // [type tag serialization] 95 // This is used to determine whether to restore type tags be recursively 96 // descending into the returned stack object (if version_number <= 2), or 97 // if version_number >= 3, to use the type strings included in the pickle 98 // archive for container types. By default this is set to 99 // `kProducedFileFormatVersion` so unless you're loading a pickle file 100 // from alongside a corresponding `version` file, you don't need to set 101 // the version manually. set_version(uint64_t version_number)102 void set_version(uint64_t version_number) { 103 version_ = version_number; 104 } 105 defaultTypeParser(const std::string & str)106 static c10::TypePtr defaultTypeParser(const std::string& str) { 107 ScriptTypeParser parser; 108 return parser.parseType(str); 109 } 110 111 private: 112 // No arguments ensures that a template argument must be specified 113 // so that the number of bytes read / type read is explicit 114 template <typename T> read()115 T read() { 116 T item; 117 if (sizeof(T) <= buffer_remaining_) { 118 // Fast path: entirely from buffer. 119 memcpy(&item, buffer_.data() + buffer_pos_, sizeof(T)); 120 buffer_remaining_ -= sizeof(T); 121 buffer_pos_ += sizeof(T); 122 } else { 123 // Don't over-template the slow path, to avoid code size bloat. 124 readSlowWithBuffer(reinterpret_cast<char*>(&item), sizeof(T)); 125 } 126 return item; 127 } 128 void readSlowWithBuffer(char* dest, size_t sz); 129 std::string readBytes(size_t num_bytes); 130 131 double readFloat(); 132 void readGlobal( 133 const std::string& module_name, 134 const std::string& class_name); 135 void rebuildTensor(bool quantized); 136 void rebuildTensorFromTypeV2(); 137 void rebuildSparseTensor(); 138 #ifdef USE_DISTRIBUTED 139 void rebuildRRef(); 140 #endif 141 PickleOpCode readInstruction(); readOpCode()142 PickleOpCode readOpCode() { 143 return static_cast<PickleOpCode>(read<uint8_t>()); 144 } 145 std::string readString(); 146 void readList(IValue list_ivalue); 147 void readListElements(IValue list_ivalue, size_t start); 148 void setInput(size_t memo_id); 149 void run(); 150 151 // Returns the number of bytes read. This should statefully 152 // remember the position. Don't call reader_ directly. 153 std::function<size_t(char*, size_t)> reader_; 154 // Small buffer to avoid calling reader_ on a per-byte basis. 155 std::array<char, 256> buffer_; 156 size_t buffer_pos_{0}; 157 size_t buffer_remaining_{0}; 158 159 std::vector<IValue> stack_; 160 161 // globals are represented on the stack as IValue integer indices 162 // into this list 163 std::vector<std::function<void(void)>> globals_; 164 std::vector<IValue> memo_table_; 165 std::vector<size_t> marks_; 166 c10::ArrayRef<at::Tensor> tensor_table_; 167 168 // When deserializing types on lists and dicts, cache the type here 169 // so we don't have to parse the same type multiple times. Strings 170 // are already de-duplicated and replaced with BINGETs in the 171 // pickler, so we can just use the actual data pointer of each string. 172 std::unordered_map<std::string, c10::TypePtr> type_cache_; 173 174 // optionally nullptr, needs to be present for creating classes 175 TypeResolver type_resolver_; 176 ObjLoader obj_loader_; 177 IValue empty_tuple_; 178 179 std::function<at::DataPtr(const std::string&)> read_record_; 180 std::optional<at::Device> device_; 181 // When set to true, Unpickler will ignore the pickled device and use the 182 // device of the DataPtr returned by the read_record_ function. The default 183 // value of this flag is false. 184 const bool use_storage_device_; 185 186 TypeParserT type_parser_{defaultTypeParser}; 187 188 // Used for torch.package to enable sharing of storages across 189 // ScriptModules and eager modules 190 std::shared_ptr<DeserializationStorageContext> storage_context_; 191 192 // See [type tag serialization] 193 uint64_t version_; 194 195 // See [NOTE] skip_next_read_global 196 uint8_t skip_next_read_global = 0; 197 }; 198 199 void restoreAccurateTypeTags(const IValue& root, const c10::TypePtr& type_tag); 200 201 } // namespace torch::jit 202