#pragma once

#include <cerrno>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <istream>
#include <mutex>
#include <ostream>
#include <unordered_set>

#include <c10/core/Allocator.h>
#include <c10/core/Backend.h>

#include "caffe2/serialize/istream_adapter.h"
#include "caffe2/serialize/read_adapter_interface.h"
#include "caffe2/serialize/versions.h"

extern "C" {
typedef struct mz_zip_archive mz_zip_archive;
}

// PyTorch containers are a special zip archive with the following layout:
// archive_name.zip contains:
//   archive_name/
//     version # a file with a single decimal number written in ascii,
//             # used to establish the version of the archive format
//     model.json # overall model description; this is a json output of
//                # ModelDef from torch.proto
//     # the following names are by convention only; model.json will
//     # refer to these files by full names
//     tensors/
//       0 # flat storage for tensor data; metadata about shapes, etc. is
//         # in model.json
//       1
//       ...
//     # code entries will only exist for modules that have methods attached
//     code/
//       archive_name.py # serialized torch script code (python syntax, using
//                       # PythonPrint)
//       archive_name_my_submodule.py # submodules have separate files
//
// The PyTorchStreamWriter also ensures additional useful properties for these
// files:
// 1. All files are stored uncompressed.
// 2. All files in the archive are aligned to 64-byte boundaries such that
//    it is possible to mmap the entire file and get an aligned pointer to
//    tensor data.
// 3. We universally write in ZIP64 format for consistency.

// The PyTorchStreamReader also provides additional properties:
// 1. It can read zip files that are created with common zip tools. This
//    means that even though our writer doesn't compress files, the reader
//    can still read files that were compressed.
// 2. It provides a getRecordOffset function which returns the offset into
//    the raw file where file data lives. If the file was written with
//    PyTorchStreamWriter, it is guaranteed to be 64-byte aligned.

// PyTorchReader/Writer handle checking the version number on the archive
// format and ensure that all files are written to an archive_name directory
// so they unzip cleanly.

// When developing this format we want to pay particular attention to the
// following use cases:
//
// -- Reading --
// 1) Reading with full random access
//    a) Reading with file APIs such as fread()
//    b) mmaping the file and jumping around the mapped region
// 2) Reading with 1-pass sequential access
//    -> A reader will need to build up a data structure of parsed structures
//       as it reads
//
// -- Writing --
// 1) Writing with full random access
// 2) Writing with 1-pass sequential access
//    -> We must take care not to require updating values that have already
//       been written. We place the variable-length index at the end and do
//       not put any indices into the header to fulfill this constraint.
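//
// As a concrete illustration of the 1-pass constraint, a writer emits
// records strictly in order and finalizes the index only at the end. A
// minimal sketch (the file name and record name are illustrative, and error
// handling is omitted):
//
//   caffe2::serialize::PyTorchStreamWriter writer("model.pt");
//   std::vector<char> tensor_bytes(256); // hypothetical tensor payload
//   // Records are appended sequentially; earlier bytes are never rewritten.
//   writer.writeRecord("tensors/0", tensor_bytes.data(), tensor_bytes.size());
//   // The variable-length central directory (the index) is written last.
//   writer.writeEndOfFile();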
// The model.json, which contains all the metadata information,
// should be written as the last file. One reason is that the size of tensor
// data is usually stable. As long as the shape and type of the tensor do not
// change, the size of the data won't change. On the other hand, the size of
// the serialized model is likely to change, so we store it as the last
// record, and then we don't need to move previous records when updating the
// model data.

// The zip format is sufficiently flexible to handle the above use cases:
// it puts its central directory at the end of the archive, and we write
// model.json as the last file when writing, after we have accumulated all
// other information.

namespace caffe2 {
namespace serialize {

static constexpr const char* kSerializationIdRecordName =
    ".data/serialization_id";

struct MzZipReaderIterWrapper;

class TORCH_API ChunkRecordIterator {
 public:
  ~ChunkRecordIterator();

  // Reads at most `chunkSize` bytes into `buf`. Returns the number of bytes
  // actually read.
  size_t next(void* buf);
  size_t recordSize() const {
    return recordSize_;
  }

 private:
  ChunkRecordIterator(
      size_t recordSize,
      size_t chunkSize,
      std::unique_ptr<MzZipReaderIterWrapper> iter);

  const size_t recordSize_;
  const size_t chunkSize_;
  size_t offset_;
  std::unique_ptr<MzZipReaderIterWrapper> iter_;

  friend class PyTorchStreamReader;
};

class TORCH_API PyTorchStreamReader final {
 public:
  explicit PyTorchStreamReader(const std::string& file_name);
  explicit PyTorchStreamReader(std::istream* in);
  explicit PyTorchStreamReader(std::shared_ptr<ReadAdapterInterface> in);

  // Returns the data pointer and the size of the record.
  std::tuple<at::DataPtr, size_t> getRecord(const std::string& name);
  // Multi-threaded getRecord.
  std::tuple<at::DataPtr, size_t> getRecord(
      const std::string& name,
      std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
  // In-place memory writing.
  size_t getRecord(const std::string& name, void* dst, size_t n);
  // In-place memory writing, multi-threaded.
  // When additionalReaders is empty, the default behavior is to call
  // getRecord(name, dst, n) with the default reader. This approach can be
  // used for reading large tensors.
  size_t getRecord(
      const std::string& name,
      void* dst,
      size_t n,
      std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
  // In-place memory writing in chunks of at most `chunk_size` bytes, staged
  // through the caller-provided scratch buffer `buf`, optionally using a
  // custom `memcpy_func` for the copies into `dst`.
  size_t getRecord(
      const std::string& name,
      void* dst,
      size_t n,
      size_t chunk_size,
      void* buf,
      const std::function<void(void*, const void*, size_t)>& memcpy_func =
          nullptr);
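
  // For illustration only, the chunked overload above can stream a large
  // record through a small scratch buffer (the record name and sizes below
  // are hypothetical):
  //
  //   std::vector<char> dst(reader.getRecordSize("tensors/0"));
  //   std::vector<char> scratch(1 << 20); // copy through 1 MiB chunks
  //   reader.getRecord(
  //       "tensors/0", dst.data(), dst.size(), scratch.size(), scratch.data());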
  // Concurrent reading of a record with multiple readers.
  // additionalReaders are additional clients that access the underlying
  // record at different offsets and write to different chunks of the buffer.
  // For example, if the overall size of the tensor is 10 and there are two
  // additionalReaders, the default reader reads [0, 4) and writes to
  // buffer[0, 4), the first additional reader reads [4, 8) and writes to
  // buffer[4, 8), and the second additional reader reads [8, 10) and writes
  // to buffer[8, 10).
  // When additionalReaders is empty, the default behavior is to call
  // getRecord(name) with the default reader. This approach can be used for
  // reading large tensors.
  size_t getRecordMultiReaders(
      const std::string& name,
      std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
      void* dst,
      size_t n);

  size_t getRecordSize(const std::string& name);

  size_t getRecordOffset(const std::string& name);
  bool hasRecord(const std::string& name);
  std::vector<std::string> getAllRecords();

  ChunkRecordIterator createChunkReaderIter(
      const std::string& name,
      const size_t recordSize,
      const size_t chunkSize);

  ~PyTorchStreamReader();
  uint64_t version() const {
    return version_;
  }
  const std::string& serializationId() {
    return serialization_id_;
  }

  void setShouldLoadDebugSymbol(bool should_load_debug_symbol) {
    load_debug_symbol_ = should_load_debug_symbol;
  }
  void setAdditionalReaderSizeThreshold(const size_t& size) {
    additional_reader_size_threshold_ = size;
  }

 private:
  void init();
  size_t read(uint64_t pos, char* buf, size_t n);
  void valid(const char* what, const char* info = "");
  size_t getRecordID(const std::string& name);

  friend size_t
  istream_read_func(void* pOpaque, uint64_t file_ofs, void* pBuf, size_t n);
  std::unique_ptr<mz_zip_archive> ar_;
  std::string archive_name_;
  std::string archive_name_plus_slash_;
  std::shared_ptr<ReadAdapterInterface> in_;
  int64_t version_;
  std::mutex reader_lock_;
  bool load_debug_symbol_ = true;
  std::string serialization_id_;
  size_t additional_reader_size_threshold_;
};
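
// A minimal read-side sketch (assumes "model.pt" was produced by
// PyTorchStreamWriter; error handling omitted):
//
//   caffe2::serialize::PyTorchStreamReader reader("model.pt");
//   for (const auto& name : reader.getAllRecords()) {
//     auto [data_ptr, size] = reader.getRecord(name);
//     // For writer-produced archives, record data is 64-byte aligned:
//     // reader.getRecordOffset(name) % 64 == 0.
//   }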
class TORCH_API PyTorchStreamWriter final {
 public:
  explicit PyTorchStreamWriter(const std::string& archive_name);
  explicit PyTorchStreamWriter(
      const std::function<size_t(const void*, size_t)> writer_func);

  void setMinVersion(const uint64_t version);

  void writeRecord(
      const std::string& name,
      const void* data,
      size_t size,
      bool compress = false);
  void writeEndOfFile();

  const std::unordered_set<std::string>& getAllWrittenRecords();

  bool finalized() const {
    return finalized_;
  }

  const std::string& archiveName() {
    return archive_name_;
  }

  const std::string& serializationId() {
    return serialization_id_;
  }

  ~PyTorchStreamWriter();

 private:
  void setup(const std::string& file_name);
  void valid(const char* what, const char* info = "");
  void writeSerializationId();
  size_t current_pos_ = 0;
  std::unordered_set<std::string> files_written_;
  std::unique_ptr<mz_zip_archive> ar_;
  std::string archive_name_;
  std::string archive_name_plus_slash_;
  std::string padding_;
  std::ofstream file_stream_;
  std::function<size_t(const void*, size_t)> writer_func_;
  uint64_t combined_uncomp_crc32_ = 0;
  std::string serialization_id_;

  // This number will be updated when the model has operators
  // that have valid upgraders.
  uint64_t version_ = kMinProducedFileFormatVersion;
  bool finalized_ = false;
  bool err_seen_ = false;

  friend size_t ostream_write_func(
      void* pOpaque,
      uint64_t file_ofs,
      const void* pBuf,
      size_t n);
};

namespace detail {
// Writer-specific constants
constexpr uint64_t kFieldAlignment = 64;

// Returns (via padding_buf) a record to be appended to the local user extra
// data entry in order to align the beginning of the record's data to a
// kFieldAlignment-byte boundary.
size_t getPadding(
    size_t cursor,
    size_t filename_size,
    size_t size,
    std::string& padding_buf);
} // namespace detail

} // namespace serialize
} // namespace caffe2
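
// Worked example of the alignment computation behind detail::getPadding (a
// sketch: the 30-byte local-file-header size and the 4-byte extra-field
// header come from the ZIP format, and the concrete numbers below are
// illustrative):
//
//   Suppose the next local file header starts at cursor = 0 and the stored
//   filename "archive/tensors/0" is 17 bytes. The record's data would begin
//   at 0 + 30 (local header) + 17 (filename) + 4 (extra-field header) = 51.
//   Since 51 % 64 == 51, a padding payload of 64 - 51 = 13 bytes is appended
//   to the extra field, so the data begins at byte 64, satisfying
//   kFieldAlignment.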