#pragma once

#include <cerrno>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <istream>
#include <mutex>
#include <ostream>
#include <unordered_set>

#include <c10/core/Allocator.h>
#include <c10/core/Backend.h>

#include "caffe2/serialize/istream_adapter.h"
#include "caffe2/serialize/read_adapter_interface.h"
#include "caffe2/serialize/versions.h"


extern "C" {
typedef struct mz_zip_archive mz_zip_archive;
}

// PyTorch containers are a special zip archive with the following layout
// archive_name.zip contains:
//    archive_name/
//        version # a file with a single decimal number written in ascii,
//                # used to establish the version of the archive format
//        model.json # overall model description, this is a json output of
//                   # ModelDef from torch.proto
//        # the following names are by convention only, model.json will
//        # refer to these files by full names
//        tensors/
//          0 # flat storage for tensor data, meta-data about shapes, etc. is
//            # in model.json
//          1
//          ...
//        # code entries will only exist for modules that have methods attached
//        code/
//          archive_name.py # serialized torch script code (python syntax,
//                          # using PythonPrint)
//          archive_name_my_submodule.py # submodules have separate files
//
// The PyTorchStreamWriter also ensures additional useful properties for these
// files
// 1. All files are stored uncompressed.
// 2. All files in the archive are aligned to 64 byte boundaries such that
//    it is possible to mmap the entire file and get an aligned pointer to
//    tensor data.
// 3. We universally write in ZIP64 format for consistency.

// The PyTorchStreamReader also provides additional properties:
// 1. It can read zip files that are created with common
//    zip tools. This means that even though our writer doesn't compress files,
//    the reader can still read files that were compressed.
// 2. It provides a getRecordOffset function which returns the offset into the
//    raw file where file data lives. If the file was written with
//    PyTorchStreamWriter it is guaranteed to be 64 byte aligned.

// PyTorchReader/Writer handle checking the version number on the archive format
// and ensure that all files are written to an archive_name directory so they
// unzip cleanly.

// When developing this format we want to pay particular attention to the
// following use cases:
//
// -- Reading --
// 1) Reading with full random access
//   a) Reading with file APIs such as fread()
//   b) mmapping the file and jumping around the mapped region
// 2) Reading with 1-pass sequential access
//      -> A reader will need to build up a data structure of parsed structures
//         as it reads
//
// -- Writing --
// 1) Writing with full random access
// 2) Writing with 1-pass sequential access
//      -> We must take care not to require updating values that have already
//         been written. We place the variable-length index at the end and do
//         not put any indices into the header to fulfill this constraint.

// The model.json, which contains all the metadata information,
// should be written as the last file. One reason is that the size of tensor
// data is usually stable. As long as the shape and type of the tensor do not
// change, the size of the data won't change. On the other hand, the size of the
// serialized model is likely to change, so we store it as the last record, and
// we don't need to move previous records when updating the model data.

// The zip format is sufficiently flexible to handle the above use-case.
// It puts its central directory at the end of the archive and we write
// model.json as the last file when writing after we have accumulated all
// other information.

namespace caffe2 {
namespace serialize {

// Name of the archive record associated with the serialization id.
// NOTE(review): presumably the record that stores the value exposed by
// serializationId() on the reader/writer -- confirm against the .cc file.
static constexpr const char* kSerializationIdRecordName = ".data/serialization_id";

// Opaque helper type; only forward-declared in this header.
struct MzZipReaderIterWrapper;

// Iterator that hands out the contents of one record in pieces of at most
// `chunkSize` bytes. The constructor is private and PyTorchStreamReader is a
// friend, so instances are obtained from the reader
// (see PyTorchStreamReader::createChunkReaderIter) rather than built directly.
class TORCH_API ChunkRecordIterator {
 public:
  // Declared here and defined out of line; MzZipReaderIterWrapper is an
  // incomplete type in this header.
  ~ChunkRecordIterator();

  // Read at most `chunkSize` into `buf`. Return the number of actual bytes read.
  size_t next(void* buf);
  // Returns the record size this iterator was constructed with.
  size_t recordSize() const { return recordSize_; }

 private:
  // recordSize: value later reported by recordSize().
  // chunkSize:  upper bound on the bytes produced by a single next() call.
  // iter:       underlying reader state; ownership is transferred to the
  //             iterator.
  ChunkRecordIterator(
      size_t recordSize,
      size_t chunkSize,
      std::unique_ptr<MzZipReaderIterWrapper> iter);

  const size_t recordSize_;
  const size_t chunkSize_;
  // NOTE(review): presumably the number of bytes consumed so far; it has no
  // in-class initializer, so the constructor is expected to set it -- confirm.
  size_t offset_;
  std::unique_ptr<MzZipReaderIterWrapper> iter_;

  friend class PyTorchStreamReader;
};

123*da0073e9SAndroid Build Coastguard Worker class TORCH_API PyTorchStreamReader final { 124*da0073e9SAndroid Build Coastguard Worker public: 125*da0073e9SAndroid Build Coastguard Worker explicit PyTorchStreamReader(const std::string& file_name); 126*da0073e9SAndroid Build Coastguard Worker explicit PyTorchStreamReader(std::istream* in); 127*da0073e9SAndroid Build Coastguard Worker explicit PyTorchStreamReader(std::shared_ptr<ReadAdapterInterface> in); 128*da0073e9SAndroid Build Coastguard Worker 129*da0073e9SAndroid Build Coastguard Worker // return dataptr, size 130*da0073e9SAndroid Build Coastguard Worker std::tuple<at::DataPtr, size_t> getRecord(const std::string& name); 131*da0073e9SAndroid Build Coastguard Worker // multi-thread getRecord 132*da0073e9SAndroid Build Coastguard Worker std::tuple<at::DataPtr, size_t> getRecord(const std::string& name, std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders); 133*da0073e9SAndroid Build Coastguard Worker // inplace memory writing 134*da0073e9SAndroid Build Coastguard Worker size_t getRecord(const std::string& name, void* dst, size_t n); 135*da0073e9SAndroid Build Coastguard Worker // inplace memory writing, multi-threads. 136*da0073e9SAndroid Build Coastguard Worker // When additionalReaders is empty, the default behavior is call getRecord(name, dst, n) with default reader 137*da0073e9SAndroid Build Coastguard Worker // This approach can be used for reading large tensors. 
138*da0073e9SAndroid Build Coastguard Worker size_t getRecord(const std::string& name, void* dst, size_t n, 139*da0073e9SAndroid Build Coastguard Worker std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders); 140*da0073e9SAndroid Build Coastguard Worker size_t getRecord( 141*da0073e9SAndroid Build Coastguard Worker const std::string& name, 142*da0073e9SAndroid Build Coastguard Worker void* dst, 143*da0073e9SAndroid Build Coastguard Worker size_t n, 144*da0073e9SAndroid Build Coastguard Worker size_t chunk_size, 145*da0073e9SAndroid Build Coastguard Worker void* buf, 146*da0073e9SAndroid Build Coastguard Worker const std::function<void(void*, const void*, size_t)>& memcpy_func = nullptr); 147*da0073e9SAndroid Build Coastguard Worker 148*da0073e9SAndroid Build Coastguard Worker // Concurrent reading records with multiple readers. 149*da0073e9SAndroid Build Coastguard Worker // additionalReaders are additional clients to access the underlying record at different offsets 150*da0073e9SAndroid Build Coastguard Worker // and write to different trunks of buffers. 151*da0073e9SAndroid Build Coastguard Worker // If the overall size of the tensor is 10, and size of additionalReader is 2. 152*da0073e9SAndroid Build Coastguard Worker // The default thread will read [0,4), the additional reader will read [4,8). 153*da0073e9SAndroid Build Coastguard Worker // The default reader will read [8,10). 154*da0073e9SAndroid Build Coastguard Worker // The default reader will write to buffer[0,4), the additional reader will write to buffer[4,8), 155*da0073e9SAndroid Build Coastguard Worker // the additional reader will write to buffer[8,10). 156*da0073e9SAndroid Build Coastguard Worker // When additionalReaders is empty, the default behavior is call getRecord(name) with default reader 157*da0073e9SAndroid Build Coastguard Worker // This approach can be used for reading large tensors. 
158*da0073e9SAndroid Build Coastguard Worker size_t getRecordMultiReaders(const std::string& name, 159*da0073e9SAndroid Build Coastguard Worker std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders, 160*da0073e9SAndroid Build Coastguard Worker void *dst, size_t n); 161*da0073e9SAndroid Build Coastguard Worker 162*da0073e9SAndroid Build Coastguard Worker size_t getRecordSize(const std::string& name); 163*da0073e9SAndroid Build Coastguard Worker 164*da0073e9SAndroid Build Coastguard Worker size_t getRecordOffset(const std::string& name); 165*da0073e9SAndroid Build Coastguard Worker bool hasRecord(const std::string& name); 166*da0073e9SAndroid Build Coastguard Worker std::vector<std::string> getAllRecords(); 167*da0073e9SAndroid Build Coastguard Worker 168*da0073e9SAndroid Build Coastguard Worker ChunkRecordIterator createChunkReaderIter( 169*da0073e9SAndroid Build Coastguard Worker const std::string& name, 170*da0073e9SAndroid Build Coastguard Worker const size_t recordSize, 171*da0073e9SAndroid Build Coastguard Worker const size_t chunkSize); 172*da0073e9SAndroid Build Coastguard Worker 173*da0073e9SAndroid Build Coastguard Worker ~PyTorchStreamReader(); version()174*da0073e9SAndroid Build Coastguard Worker uint64_t version() const { 175*da0073e9SAndroid Build Coastguard Worker return version_; 176*da0073e9SAndroid Build Coastguard Worker } serializationId()177*da0073e9SAndroid Build Coastguard Worker const std::string& serializationId() { 178*da0073e9SAndroid Build Coastguard Worker return serialization_id_; 179*da0073e9SAndroid Build Coastguard Worker } 180*da0073e9SAndroid Build Coastguard Worker setShouldLoadDebugSymbol(bool should_load_debug_symbol)181*da0073e9SAndroid Build Coastguard Worker void setShouldLoadDebugSymbol(bool should_load_debug_symbol) { 182*da0073e9SAndroid Build Coastguard Worker load_debug_symbol_ = should_load_debug_symbol; 183*da0073e9SAndroid Build Coastguard Worker } setAdditionalReaderSizeThreshold(const size_t & 
size)184*da0073e9SAndroid Build Coastguard Worker void setAdditionalReaderSizeThreshold(const size_t& size){ 185*da0073e9SAndroid Build Coastguard Worker additional_reader_size_threshold_ = size; 186*da0073e9SAndroid Build Coastguard Worker } 187*da0073e9SAndroid Build Coastguard Worker private: 188*da0073e9SAndroid Build Coastguard Worker void init(); 189*da0073e9SAndroid Build Coastguard Worker size_t read(uint64_t pos, char* buf, size_t n); 190*da0073e9SAndroid Build Coastguard Worker void valid(const char* what, const char* info = ""); 191*da0073e9SAndroid Build Coastguard Worker size_t getRecordID(const std::string& name); 192*da0073e9SAndroid Build Coastguard Worker 193*da0073e9SAndroid Build Coastguard Worker friend size_t 194*da0073e9SAndroid Build Coastguard Worker istream_read_func(void* pOpaque, uint64_t file_ofs, void* pBuf, size_t n); 195*da0073e9SAndroid Build Coastguard Worker std::unique_ptr<mz_zip_archive> ar_; 196*da0073e9SAndroid Build Coastguard Worker std::string archive_name_; 197*da0073e9SAndroid Build Coastguard Worker std::string archive_name_plus_slash_; 198*da0073e9SAndroid Build Coastguard Worker std::shared_ptr<ReadAdapterInterface> in_; 199*da0073e9SAndroid Build Coastguard Worker int64_t version_; 200*da0073e9SAndroid Build Coastguard Worker std::mutex reader_lock_; 201*da0073e9SAndroid Build Coastguard Worker bool load_debug_symbol_ = true; 202*da0073e9SAndroid Build Coastguard Worker std::string serialization_id_; 203*da0073e9SAndroid Build Coastguard Worker size_t additional_reader_size_threshold_; 204*da0073e9SAndroid Build Coastguard Worker }; 205*da0073e9SAndroid Build Coastguard Worker 206*da0073e9SAndroid Build Coastguard Worker class TORCH_API PyTorchStreamWriter final { 207*da0073e9SAndroid Build Coastguard Worker public: 208*da0073e9SAndroid Build Coastguard Worker explicit PyTorchStreamWriter(const std::string& archive_name); 209*da0073e9SAndroid Build Coastguard Worker explicit PyTorchStreamWriter( 210*da0073e9SAndroid 
Build Coastguard Worker const std::function<size_t(const void*, size_t)> writer_func); 211*da0073e9SAndroid Build Coastguard Worker 212*da0073e9SAndroid Build Coastguard Worker void setMinVersion(const uint64_t version); 213*da0073e9SAndroid Build Coastguard Worker 214*da0073e9SAndroid Build Coastguard Worker void writeRecord( 215*da0073e9SAndroid Build Coastguard Worker const std::string& name, 216*da0073e9SAndroid Build Coastguard Worker const void* data, 217*da0073e9SAndroid Build Coastguard Worker size_t size, 218*da0073e9SAndroid Build Coastguard Worker bool compress = false); 219*da0073e9SAndroid Build Coastguard Worker void writeEndOfFile(); 220*da0073e9SAndroid Build Coastguard Worker 221*da0073e9SAndroid Build Coastguard Worker const std::unordered_set<std::string>& getAllWrittenRecords(); 222*da0073e9SAndroid Build Coastguard Worker finalized()223*da0073e9SAndroid Build Coastguard Worker bool finalized() const { 224*da0073e9SAndroid Build Coastguard Worker return finalized_; 225*da0073e9SAndroid Build Coastguard Worker } 226*da0073e9SAndroid Build Coastguard Worker archiveName()227*da0073e9SAndroid Build Coastguard Worker const std::string& archiveName() { 228*da0073e9SAndroid Build Coastguard Worker return archive_name_; 229*da0073e9SAndroid Build Coastguard Worker } 230*da0073e9SAndroid Build Coastguard Worker serializationId()231*da0073e9SAndroid Build Coastguard Worker const std::string& serializationId() { 232*da0073e9SAndroid Build Coastguard Worker return serialization_id_; 233*da0073e9SAndroid Build Coastguard Worker } 234*da0073e9SAndroid Build Coastguard Worker 235*da0073e9SAndroid Build Coastguard Worker ~PyTorchStreamWriter(); 236*da0073e9SAndroid Build Coastguard Worker 237*da0073e9SAndroid Build Coastguard Worker private: 238*da0073e9SAndroid Build Coastguard Worker void setup(const std::string& file_name); 239*da0073e9SAndroid Build Coastguard Worker void valid(const char* what, const char* info = ""); 240*da0073e9SAndroid Build Coastguard 
Worker void writeSerializationId(); 241*da0073e9SAndroid Build Coastguard Worker size_t current_pos_ = 0; 242*da0073e9SAndroid Build Coastguard Worker std::unordered_set<std::string> files_written_; 243*da0073e9SAndroid Build Coastguard Worker std::unique_ptr<mz_zip_archive> ar_; 244*da0073e9SAndroid Build Coastguard Worker std::string archive_name_; 245*da0073e9SAndroid Build Coastguard Worker std::string archive_name_plus_slash_; 246*da0073e9SAndroid Build Coastguard Worker std::string padding_; 247*da0073e9SAndroid Build Coastguard Worker std::ofstream file_stream_; 248*da0073e9SAndroid Build Coastguard Worker std::function<size_t(const void*, size_t)> writer_func_; 249*da0073e9SAndroid Build Coastguard Worker uint64_t combined_uncomp_crc32_ = 0; 250*da0073e9SAndroid Build Coastguard Worker std::string serialization_id_; 251*da0073e9SAndroid Build Coastguard Worker 252*da0073e9SAndroid Build Coastguard Worker // This number will be updated when the model has operators 253*da0073e9SAndroid Build Coastguard Worker // that have valid upgraders. 
254*da0073e9SAndroid Build Coastguard Worker uint64_t version_ = kMinProducedFileFormatVersion; 255*da0073e9SAndroid Build Coastguard Worker bool finalized_ = false; 256*da0073e9SAndroid Build Coastguard Worker bool err_seen_ = false; 257*da0073e9SAndroid Build Coastguard Worker friend size_t ostream_write_func( 258*da0073e9SAndroid Build Coastguard Worker void* pOpaque, 259*da0073e9SAndroid Build Coastguard Worker uint64_t file_ofs, 260*da0073e9SAndroid Build Coastguard Worker const void* pBuf, 261*da0073e9SAndroid Build Coastguard Worker size_t n); 262*da0073e9SAndroid Build Coastguard Worker }; 263*da0073e9SAndroid Build Coastguard Worker 264*da0073e9SAndroid Build Coastguard Worker namespace detail { 265*da0073e9SAndroid Build Coastguard Worker // Writer-specific constants 266*da0073e9SAndroid Build Coastguard Worker constexpr uint64_t kFieldAlignment = 64; 267*da0073e9SAndroid Build Coastguard Worker 268*da0073e9SAndroid Build Coastguard Worker // Returns a record to be appended to the local user extra data entry in order 269*da0073e9SAndroid Build Coastguard Worker // to make data beginning aligned at kFieldAlignment bytes boundary. 270*da0073e9SAndroid Build Coastguard Worker size_t getPadding( 271*da0073e9SAndroid Build Coastguard Worker size_t cursor, 272*da0073e9SAndroid Build Coastguard Worker size_t filename_size, 273*da0073e9SAndroid Build Coastguard Worker size_t size, 274*da0073e9SAndroid Build Coastguard Worker std::string& padding_buf); 275*da0073e9SAndroid Build Coastguard Worker } // namespace detail 276*da0073e9SAndroid Build Coastguard Worker 277*da0073e9SAndroid Build Coastguard Worker } // namespace serialize 278*da0073e9SAndroid Build Coastguard Worker } // namespace caffe2 279