xref: /aosp_15_r20/external/pytorch/caffe2/serialize/inline_container.cc (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <algorithm>
#include <cerrno>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <istream>
#include <ostream>
#include <sstream>
#include <sys/stat.h>
#include <sys/types.h>
#include <thread>
12 
13 #include <c10/core/Allocator.h>
14 #include <c10/core/Backend.h>
15 #include <c10/core/CPUAllocator.h>
16 #include <c10/core/Backend.h>
17 #include <c10/util/Exception.h>
18 #include <c10/util/Logging.h>
19 #include <c10/util/hash.h>
20 
21 #include "caffe2/core/common.h"
22 #include "caffe2/serialize/file_adapter.h"
23 #include "caffe2/serialize/inline_container.h"
24 #include "caffe2/serialize/istream_adapter.h"
25 #include "caffe2/serialize/read_adapter_interface.h"
26 
27 #include "caffe2/serialize/versions.h"
28 #include "miniz.h"
29 
30 namespace caffe2 {
31 namespace serialize {
// Records ending in this suffix hold debug symbols; they are skipped by the
// reader unless load_debug_symbol_ is set.
constexpr c10::string_view kDebugPklSuffix(".debug_pkl");
33 
// Wrapper so the miniz iterator type does not have to appear in the public
// header (which only forward-declares this struct).
struct MzZipReaderIterWrapper {
  MzZipReaderIterWrapper(mz_zip_reader_extract_iter_state* iter) : impl(iter) {}
  // Raw miniz extraction iterator; freed by ChunkRecordIterator's destructor.
  mz_zip_reader_extract_iter_state* impl;
};
38 
// Iterates over a single zip record of `recordSize` bytes, yielding at most
// `chunkSize` bytes per call to next(). Takes ownership of the wrapped miniz
// extraction iterator.
ChunkRecordIterator::ChunkRecordIterator(
    size_t recordSize,
    size_t chunkSize,
    std::unique_ptr<MzZipReaderIterWrapper> iter)
    : recordSize_(recordSize),
      chunkSize_(chunkSize),
      offset_(0),
      iter_(std::move(iter)) {}
47 
ChunkRecordIterator::~ChunkRecordIterator() {
  // Release the underlying miniz extraction iterator.
  mz_zip_reader_extract_iter_free(iter_->impl);
}
51 
next(void * buf)52 size_t ChunkRecordIterator::next(void* buf){
53   size_t want_size = std::min(chunkSize_, recordSize_ - offset_);
54   if (want_size == 0) {
55     return 0;
56   }
57   size_t read_size = mz_zip_reader_extract_iter_read(iter_->impl, buf, want_size);
58   TORCH_CHECK(read_size > 0, "Read bytes should be larger than 0");
59   offset_ += read_size;
60   return read_size;
61 }
62 
istream_read_func(void * pOpaque,mz_uint64 file_ofs,void * pBuf,size_t n)63 size_t istream_read_func(void* pOpaque, mz_uint64 file_ofs, void* pBuf, size_t n) {
64   auto self = static_cast<PyTorchStreamReader*>(pOpaque);
65   return self->read(file_ofs, static_cast<char*>(pBuf), n);
66 }
67 
// Returns the file name portion of `name` (after the last '/' or '\\') with
// the extension after the last '.' stripped. A name that ends in a separator
// yields the empty string.
static std::string basename(const std::string& name) {
  const size_t sep = name.find_last_of("/\\");
  const size_t start = (sep == std::string::npos) ? 0 : sep + 1;

  if (start >= name.size()) {
    return "";
  }

  // Only strip a dot that falls inside the basename itself; a dot in the
  // directory part must not truncate the result.
  const size_t dot = name.find_last_of('.');
  const size_t end =
      (dot != std::string::npos && dot >= start) ? dot : name.size();
  return name.substr(start, end - start);
}
89 
// Returns the directory portion of `name` (everything before the last path
// separator), or "" when there is no separator.
static std::string parentdir(const std::string& name) {
  // Prefer '/' separators; fall back to '\\' only when none are present.
  size_t cut = name.find_last_of('/');
  if (cut == std::string::npos) {
    cut = name.find_last_of('\\');
  }

#ifdef WIN32
  // A Windows drive root ("C:\") keeps its trailing slash in the result.
  if (cut != std::string::npos && cut > 1 && name[cut - 1] == ':') {
    ++cut;
  }
#endif

  return (cut == std::string::npos) ? "" : name.substr(0, cut);
}
110 
// Reads `n` bytes at absolute archive offset `pos` into `buf` via the
// underlying read adapter; returns the number of bytes read.
size_t PyTorchStreamReader::read(uint64_t pos, char* buf, size_t n) {
  return in_->read(pos, buf, n, "reading file");
}
114 
// Opens and parses the zip archive at `file_name` through a FileAdapter.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
PyTorchStreamReader::PyTorchStreamReader(const std::string& file_name)
    : ar_(std::make_unique<mz_zip_archive>()),
      in_(std::make_unique<FileAdapter>(file_name)) {
  init();
}
121 
// Opens and parses a zip archive read from `in` (not owned; must outlive the
// reader) through an IStreamAdapter.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
PyTorchStreamReader::PyTorchStreamReader(std::istream* in)
    : ar_(std::make_unique<mz_zip_archive>()),
      in_(std::make_unique<IStreamAdapter>(in)) {
  init();
}
128 
// Opens and parses a zip archive backed by an arbitrary read adapter.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
PyTorchStreamReader::PyTorchStreamReader(
    std::shared_ptr<ReadAdapterInterface> in)
    : ar_(std::make_unique<mz_zip_archive>()), in_(std::move(in)) {
  init();
}
135 
// Shared constructor body. Rejects legacy "PYTORCH1" archives, initializes
// the miniz reader over the adapter, derives the archive folder name from the
// first entry, reads the optional serialization id, and validates the file
// format version against the supported [min, max] range. Throws c10::Error on
// any malformed or unsupported input.
void PyTorchStreamReader::init() {
  AT_ASSERT(in_ != nullptr);
  AT_ASSERT(ar_ != nullptr);
  memset(ar_.get(), 0, sizeof(mz_zip_archive));

  size_t size = in_->size();

  // check for the old magic number: files from the preview release start
  // with "PYTORCH1" and are not zip archives at all.
  constexpr size_t kMagicValueLength = 8;
  if (size > kMagicValueLength) {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    char buf[kMagicValueLength];
    read(0, buf, kMagicValueLength);
    valid("checking magic number");
    AT_ASSERTM(
        memcmp("PYTORCH1", buf, kMagicValueLength) != 0,
        "File is an unsupported archive format from the preview release.");
  }

  // Route all miniz reads through istream_read_func -> this->read().
  ar_->m_pIO_opaque = this;
  ar_->m_pRead = istream_read_func;

  mz_zip_reader_init(ar_.get(), size, 0);
  valid("reading zip archive");

  // figure out the archive_name (i.e. the zip folder all the other files are in)
  // all lookups to getRecord will be prefixed by this folder
  mz_uint n = mz_zip_reader_get_num_files(ar_.get());
  if (n == 0) {
    CAFFE_THROW("archive does not contain any files");
  }
  size_t name_size = mz_zip_reader_get_filename(ar_.get(), 0, nullptr, 0);
  valid("getting filename");
  std::string buf(name_size, '\0');
  mz_zip_reader_get_filename(ar_.get(), 0, &buf[0], name_size);
  valid("getting filename");
  auto pos = buf.find_first_of('/');
  if (pos == std::string::npos) {
    CAFFE_THROW("file in archive is not in a subdirectory: ", buf);
  }
  archive_name_ = buf.substr(0, pos);
  archive_name_plus_slash_ = archive_name_ + "/";

  // read serialization id (optional record; absent in older files)
  if (hasRecord(kSerializationIdRecordName)) {
    at::DataPtr serialization_id_ptr;
    size_t serialization_id_size = 0;
    std::tie(serialization_id_ptr, serialization_id_size) =
        getRecord(kSerializationIdRecordName);
    serialization_id_.assign(
        static_cast<const char*>(serialization_id_ptr.get()),
        serialization_id_size);
  }
  c10::LogAPIUsageMetadata(
      "pytorch.stream.reader.metadata",
      {{"serialization_id", serialization_id_},
       {"file_name", archive_name_},
       {"file_size", str(mz_zip_get_archive_size(ar_.get()))}});

  // version check: newer files store it as ".data/version", older as "version"
  at::DataPtr version_ptr;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  size_t version_size;
  if (hasRecord(".data/version")) {
    std::tie(version_ptr, version_size) = getRecord(".data/version");
  } else {
    TORCH_CHECK(hasRecord("version"))
    std::tie(version_ptr, version_size) = getRecord("version");
  }
  std::string version(static_cast<const char*>(version_ptr.get()), version_size);
  try {
    version_ = std::stoull(version);
  } catch (const std::invalid_argument& e) {
    CAFFE_THROW("Couldn't parse the version ",
                 version,
                 " as Long Long.");
  }
  if (version_ < static_cast<decltype(version_)>(kMinSupportedFileFormatVersion)) {
    CAFFE_THROW(
        "Attempted to read a PyTorch file with version ",
        std::to_string(version_),
        ", but the minimum supported version for reading is ",
        std::to_string(kMinSupportedFileFormatVersion),
        ". Your PyTorch script module file is too old. Please regenerate it",
        " with latest version of PyTorch to mitigate this issue.");
  }

  if (version_ > static_cast<decltype(version_)>(kMaxSupportedFileFormatVersion)) {
    CAFFE_THROW(
        "Attempted to read a PyTorch file with version ",
        version_,
        ", but the maximum supported version for reading is ",
        kMaxSupportedFileFormatVersion,
        ". The version of your PyTorch installation may be too old, ",
        "please upgrade PyTorch to latest version to mitigate this issue.");
  }
}
233 
// Throws a descriptive error if the last miniz operation on the archive
// failed; `what`/`info` are spliced into the message. No-op on success.
void PyTorchStreamReader::valid(const char* what, const char* info) {
  const auto err = mz_zip_get_last_error(ar_.get());
  TORCH_CHECK(
      err == MZ_ZIP_NO_ERROR,
      "PytorchStreamReader failed ",
      what,
      info,
      ": ",
      mz_zip_get_error_string(err));
}
244 
// Layout constants of a zip "local file header" (see PKWARE APPNOTE.TXT):
// fixed header size, plus the offsets of the little-endian u16 filename and
// extra-field length fields. Used by getRecordOffset()/getPadding().
constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30;
constexpr int MZ_ZIP_LDH_FILENAME_LEN_OFS = 26;
constexpr int MZ_ZIP_LDH_EXTRA_LEN_OFS = 28;
// Signature ("PK\x07\x08") that introduces a data-descriptor block.
constexpr int MZ_ZIP_DATA_DESCRIPTOR_ID = 0x08074b50;
249 
namespace detail {
// Computes the extra-field padding needed so a record's payload starts on a
// kFieldAlignment boundary. `cursor` is the current archive size (i.e. where
// the local header will be written), `filename_size` the record's full name
// length, and `size` the record's data size. The padding is materialized in
// `padding_buf` as a custom "FB" extra-field entry; the total extra-field
// length (padding + 4-byte entry header) is returned.
size_t getPadding(
    size_t cursor,
    size_t filename_size,
    size_t size,
    std::string& padding_buf) {
  // Offset where the payload would begin: local header + filename + the
  // 4-byte extra-field entry header we always emit.
  size_t start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_size +
      sizeof(mz_uint16) * 2;
  if (size >= MZ_UINT32_MAX || cursor >= MZ_UINT32_MAX) {
    // Large sizes/offsets add zip64 fields; these sizes presumably mirror
    // what miniz emits for each overflowed 32-bit field — keep in sync.
    start += sizeof(mz_uint16) * 2;
    if (size >= MZ_UINT32_MAX) {
      start += 2 * sizeof(mz_uint64);
    }
    if (cursor >= MZ_UINT32_MAX) {
      start += sizeof(mz_uint64);
    }
  }
  // Round `start` up to the next alignment boundary.
  size_t mod = start % kFieldAlignment;
  size_t next_offset = (mod == 0) ? start : (start + kFieldAlignment - mod);
  size_t padding_size = next_offset - start;
  size_t padding_size_plus_fbxx = padding_size + 4;
  // Grow the reusable buffer with filler bytes as needed; capacity persists
  // across calls so repeated records do not reallocate.
  if (padding_buf.size() < padding_size_plus_fbxx) {
    padding_buf.append(padding_size_plus_fbxx - padding_buf.size(), 'Z');
  }
  // zip extra encoding (key, size_of_extra_bytes)
  padding_buf[0] = 'F';
  padding_buf[1] = 'B';
  padding_buf[2] = (uint8_t)padding_size;
  padding_buf[3] = (uint8_t)(padding_size >> 8);
  return padding_size_plus_fbxx;
}
} // namespace detail
282 
hasRecord(const std::string & name)283 bool PyTorchStreamReader::hasRecord(const std::string& name) {
284   std::lock_guard<std::mutex> guard(reader_lock_);
285 
286   if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
287     return false;
288   }
289   std::string ss = archive_name_plus_slash_ + name;
290   mz_zip_reader_locate_file(ar_.get(), ss.c_str(), nullptr, 0);
291   const mz_zip_error err = mz_zip_get_last_error(ar_.get());
292 
293   if (err == MZ_ZIP_NO_ERROR) {
294     return true;
295   } else if (err == MZ_ZIP_FILE_NOT_FOUND) {
296     return false;
297   } else {
298     // A different error happened, raise it.
299     valid("attempting to locate file ", name.c_str());
300   }
301   TORCH_INTERNAL_ASSERT(false, "should not reach here");
302 }
303 
getAllRecords()304 std::vector<std::string> PyTorchStreamReader::getAllRecords() {
305   std::lock_guard<std::mutex> guard(reader_lock_);
306   mz_uint num_files = mz_zip_reader_get_num_files(ar_.get());
307   std::vector<std::string> out;
308   // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
309   char buf[MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE];
310   for (size_t i = 0; i < num_files; i++) {
311     mz_zip_reader_get_filename(ar_.get(), i, buf, MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE);
312     if (strncmp(
313             buf,
314             archive_name_plus_slash_.data(),
315             archive_name_plus_slash_.size()) != 0) {
316       CAFFE_THROW(
317           "file in archive is not in a subdirectory ",
318           archive_name_plus_slash_,
319           ": ",
320           buf);
321     }
322     if ((load_debug_symbol_) ||
323         (!c10::string_view(buf + archive_name_plus_slash_.size()).ends_with(kDebugPklSuffix))) {
324       // NOLINTNEXTLINE(modernize-use-emplace)
325       out.push_back(buf + archive_name_plus_slash_.size());
326     }
327   }
328   return out;
329 }
330 
// Returns the set of record names written so far (without the archive
// folder prefix).
const std::unordered_set<std::string>&
PyTorchStreamWriter::getAllWrittenRecords() {
  return files_written_;
}
335 
// Maps a user-visible record name to its miniz file index by prepending the
// archive folder prefix; throws via valid() if the record is absent.
// Typically called with reader_lock_ already held by the public entry point.
size_t PyTorchStreamReader::getRecordID(const std::string& name) {
  std::string ss = archive_name_plus_slash_ + name;
  size_t result = mz_zip_reader_locate_file(ar_.get(), ss.c_str(), nullptr, 0);
  valid("locating file ", name.c_str());
  return result;
}
342 
// return dataptr, size
// Decompresses record `name` entirely into a freshly allocated CPU buffer.
// Debug-symbol (*.debug_pkl) records yield (nullptr, 0) unless
// load_debug_symbol_ is set.
std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string& name) {
  std::lock_guard<std::mutex> guard(reader_lock_);
  if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
    at::DataPtr retval;
    return std::make_tuple(std::move(retval), 0);
  }
  size_t key = getRecordID(name);
  mz_zip_archive_file_stat stat;
  mz_zip_reader_file_stat(ar_.get(), key, &stat);
  valid("retrieving file meta-data for ", name.c_str());
  at::DataPtr retval = c10::GetCPUAllocator()->allocate(stat.m_uncomp_size);
  mz_zip_reader_extract_to_mem(ar_.get(), key, retval.get(), stat.m_uncomp_size, 0);
  valid("reading file ", name.c_str());

  return std::make_tuple(std::move(retval), stat.m_uncomp_size);
}
360 
// Reads record `name` (n bytes) into `dst` using additionalReaders.size()+1
// threads, each fetching one contiguous ~n/nthread slice at the record's raw
// archive offset. Thread 0 uses this reader's adapter; thread i>0 uses
// additionalReaders[i-1]. NOTE(review): this reads raw bytes starting at
// getRecordOffset(name), which presumably assumes the record is stored
// uncompressed in the zip — confirm with the writer's settings.
size_t
PyTorchStreamReader::getRecordMultiReaders(const std::string& name,
  std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
  void *dst, size_t n){

  size_t nthread = additionalReaders.size()+1;
  // getRecordOffset takes reader_lock_ itself, so it must run before the
  // lock_guard below (std::mutex is not recursive).
  size_t recordOff = getRecordOffset(name);
  std::vector<std::thread> loaderThreads;
  // Ceiling division: every thread gets an equal slice, last one may be short.
  size_t perThreadSize = (n+nthread-1)/nthread;
  std::vector<size_t> readSizes(nthread, 0);
  std::lock_guard<std::mutex> guard(reader_lock_);
  for(size_t i = 0; i < nthread ; i++){
    loaderThreads.emplace_back([this, name, i, n, recordOff, perThreadSize, dst, &additionalReaders, &readSizes]{
      size_t startPos = i*perThreadSize;
      size_t endPos = std::min((i+1)*perThreadSize,n);
      if (startPos < endPos){
        size_t threadReadSize = endPos - startPos;
        size_t size = 0;
        if (i==0){
          // Thread 0 reuses this reader's own adapter.
          size = read(recordOff+startPos, (char *)dst+startPos, threadReadSize);
        }else{
          auto reader = additionalReaders[i-1];
          size = reader->read(recordOff+startPos, (char *)dst+startPos, threadReadSize);
        }
        readSizes[i] = size;
        LOG(INFO) << "Thread " << i << " read [" << startPos << "-" << endPos << "] "
            << "from " << name << " of size " << n;
        TORCH_CHECK(
              threadReadSize == size,
              "record size ",
              threadReadSize,
              " mismatch with read size ",
              size);
      }
    });
  }

  for (auto& thread : loaderThreads) {
    thread.join();
  }
  loaderThreads.clear();

  // Sanity-check that the slices exactly cover the record.
  size_t total_read_n = 0;
  for (auto& r : readSizes){
    total_read_n += r;
  }

  TORCH_CHECK(
      n == total_read_n,
      "Multi reader total read size ",
      total_read_n,
      " mismatch with dst size ",
      n);

  return total_read_n;
}
417 
// read record with multi clients
// Parallel variant of getRecord(): splits the read across this reader plus
// `additionalReaders` when the record is at least
// additional_reader_size_threshold_ bytes; otherwise falls back to the
// single-threaded overload.
// NOTE(review): unlike getRecord(name), this overload does not hold
// reader_lock_ around the getRecordID/file_stat calls — confirm intended.
std::tuple<at::DataPtr, size_t>
PyTorchStreamReader::getRecord(const std::string& name,
  std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders) {
  if(additionalReaders.empty()){
    // No additional readers, use single threaded version
    return getRecord(name);
  }

  // Debug-symbol records read as empty unless their loading was requested.
  if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
    at::DataPtr retval;
    return std::make_tuple(std::move(retval), 0);
  }
  size_t key = getRecordID(name);
  mz_zip_archive_file_stat stat;
  mz_zip_reader_file_stat(ar_.get(), key, &stat);
  auto n = stat.m_uncomp_size;
  valid("retrieving file meta-data for ", name.c_str());
  if(n < additional_reader_size_threshold_){
    // Record too small, use single threaded version
    return getRecord(name);
  }

  at::DataPtr retval = c10::GetCPUAllocator()->allocate(stat.m_uncomp_size);
  void* dst = retval.get();
  PyTorchStreamReader::getRecordMultiReaders(name, additionalReaders, dst, n);
  return std::make_tuple(std::move(retval), stat.m_uncomp_size);
}
446 
// inplace memory writing
// Decompresses record `name` directly into the caller-provided buffer `dst`;
// `n` must equal the record's uncompressed size. Returns the bytes written
// (0 for a skipped debug-symbol record).
size_t
PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n) {
  std::lock_guard<std::mutex> guard(reader_lock_);
  if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
    return 0;
  }
  size_t key = getRecordID(name);
  mz_zip_archive_file_stat stat;
  mz_zip_reader_file_stat(ar_.get(), key, &stat);
  TORCH_CHECK(
      n == stat.m_uncomp_size,
      "record size ",
      stat.m_uncomp_size,
      " mismatch with dst size ",
      n);
  valid("retrieving file meta-data for ", name.c_str());
  mz_zip_reader_extract_to_mem(ar_.get(), key, dst, stat.m_uncomp_size, 0);
  valid("reading file ", name.c_str());

  return stat.m_uncomp_size;
}
469 
470 
// inplace memory writing, in-tensor multi-threads, can be used for large tensor.
// Combines the in-place and multi-reader paths: writes record `name` into
// `dst` using the extra readers when the record is large enough, otherwise
// delegates to the single-threaded in-place overload. `n` must equal the
// record's uncompressed size.
size_t
PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n,
  std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders) {
  if(additionalReaders.empty()){
    // No additional readers, use single threaded version
    return getRecord(name, dst, n);
  }

  // Debug-symbol records read as empty unless their loading was requested.
  if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
    return 0;
  }
  size_t key = getRecordID(name);
  mz_zip_archive_file_stat stat;
  mz_zip_reader_file_stat(ar_.get(), key, &stat);
  TORCH_CHECK(
      n == stat.m_uncomp_size,
      "record size ",
      stat.m_uncomp_size,
      " mismatch with dst size ",
      n);
  valid("retrieving file meta-data for ", name.c_str());

  if(n < additional_reader_size_threshold_){
    // Reader size too small, use single threaded version
    return getRecord(name, dst, n);
  }

  PyTorchStreamReader::getRecordMultiReaders(name, additionalReaders, dst, n);
  return stat.m_uncomp_size;
}
502 
// Streams record `name` into `dst` in chunks of `chunk_size` bytes, staging
// each chunk in `buf` (allocated internally when null) and committing it via
// `memcpy_func`. `n` must equal the record's uncompressed size. Useful when
// `dst` is not plain host memory (e.g. the copy needs a custom memcpy).
size_t PyTorchStreamReader::getRecord(
    const std::string& name,
    void* dst,
    size_t n,
    size_t chunk_size,
    void* buf,
    const std::function<void(void*, const void*, size_t)>& memcpy_func) {
  std::lock_guard<std::mutex> guard(reader_lock_);
  if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
    return 0;
  }
  // A non-positive chunk size means "copy the whole record at once".
  if (chunk_size <= 0) {
    chunk_size = n;
  }
  size_t key = getRecordID(name);
  mz_zip_archive_file_stat stat;
  mz_zip_reader_file_stat(ar_.get(), key, &stat);
  TORCH_CHECK(
      n == stat.m_uncomp_size,
      "record size ",
      stat.m_uncomp_size,
      " mismatch with dst size ",
      n);
  valid("retrieving file meta-data for ", name.c_str());

  // Fall back to an internally owned staging buffer when none was supplied.
  std::vector<uint8_t> buffer;
  if (buf == nullptr) {
    buffer.resize(chunk_size);
    buf = buffer.data();
  }

  auto chunkIterator =
      createChunkReaderIter(name, (size_t)stat.m_uncomp_size, chunk_size);
  while (auto readSize = chunkIterator.next(buf)) {
    // offset_ already includes this chunk, so back up by readSize for dst.
    memcpy_func((char*)dst + chunkIterator.offset_ - readSize, buf, readSize);
  }
  valid("reading file ", name.c_str());

  return stat.m_uncomp_size;
}
543 
// Builds a ChunkRecordIterator that streams record `name` (recordSize bytes)
// in chunkSize pieces; the returned iterator owns the miniz extract iterator.
// Does not take reader_lock_ itself — callers such as the chunked getRecord()
// hold it.
ChunkRecordIterator PyTorchStreamReader::createChunkReaderIter(
    const std::string& name,
    const size_t recordSize,
    const size_t chunkSize) {
  // Create zip reader iterator
  size_t key = getRecordID(name);
  mz_zip_reader_extract_iter_state* zipReaderIter =
      mz_zip_reader_extract_iter_new(ar_.get(), key, 0);
  TORCH_CHECK(
      zipReaderIter != nullptr,
      "Failed to create zip reader iter: ",
      mz_zip_get_error_string(mz_zip_get_last_error(ar_.get())));

  return ChunkRecordIterator(
      recordSize,
      chunkSize,
      std::make_unique<MzZipReaderIterWrapper>(zipReaderIter));
}
562 
// Decodes a little-endian unsigned 16-bit value from `buf` (2 bytes).
static int64_t read_le_16(uint8_t* buf) {
  const int64_t lo = static_cast<int64_t>(buf[0]);
  const int64_t hi = static_cast<int64_t>(buf[1]) << 8;
  return hi | lo;
}
566 
// Returns the absolute byte offset of record `name`'s payload inside the
// archive. Reads the record's zip local file header: the payload starts
// right after the fixed 30-byte header plus the variable-length filename and
// extra fields, whose lengths are little-endian u16s at offsets 26 and 28.
size_t PyTorchStreamReader::getRecordOffset(const std::string& name) {
  std::lock_guard<std::mutex> guard(reader_lock_);
  mz_zip_archive_file_stat stat;
  mz_zip_reader_file_stat(ar_.get(), getRecordID(name), &stat);
  valid("retrieving file meta-data for ", name.c_str());
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  uint8_t local_header[MZ_ZIP_LOCAL_DIR_HEADER_SIZE];
  in_->read(
      stat.m_local_header_ofs,
      local_header,
      MZ_ZIP_LOCAL_DIR_HEADER_SIZE,
      "reading file header");
  size_t filename_len = read_le_16(local_header + MZ_ZIP_LDH_FILENAME_LEN_OFS);
  size_t extra_len = read_le_16(local_header + MZ_ZIP_LDH_EXTRA_LEN_OFS);
  return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + extra_len;
}
583 
// Returns the uncompressed size of record `name`.
// NOTE(review): unlike the other accessors, this neither takes reader_lock_
// nor calls valid() after the stat — confirm callers guarantee this is safe.
size_t PyTorchStreamReader::getRecordSize(const std::string& name) {
  mz_zip_archive_file_stat stat;
  mz_zip_reader_file_stat(ar_.get(), getRecordID(name), &stat);
  return stat.m_uncomp_size;
}
589 
PyTorchStreamReader::~PyTorchStreamReader() {
  // Clear any stale miniz error so valid() below only reports problems
  // arising from closing the reader itself.
  mz_zip_clear_last_error(ar_.get());
  mz_zip_reader_end(ar_.get());
  // NOTE(review): valid() uses TORCH_CHECK and can throw from this
  // destructor — confirm that is acceptable for callers.
  valid("closing reader for archive ", archive_name_.c_str());
}
595 
// miniz write callback for PyTorchStreamWriter. Enforces strictly sequential
// writes (current_pos_ must match miniz's target offset), forwards the bytes
// to writer_func_, and records any short write in err_seen_ for the writer's
// valid() to report later.
size_t ostream_write_func(
    void* pOpaque,
    mz_uint64 file_ofs,
    const void* pBuf,
    size_t n) {
  auto self = static_cast<PyTorchStreamWriter*>(pOpaque);
  if (self->current_pos_ != file_ofs) {
    CAFFE_THROW("unexpected pos ", self->current_pos_, " vs ", file_ofs);
  }
  size_t ret = self->writer_func_(pBuf, n);
  if (n != ret) {
    self->err_seen_ = true;
  }
  self->current_pos_ += ret;

  // Get the CRC32 of uncompressed data from the data descriptor, if the written
  // data is identified as the data descriptor block.
  // See [Note: write_record_metadata] for why we check for non-null pBuf here
  if (pBuf && n >= 8 && MZ_READ_LE32(pBuf) == MZ_ZIP_DATA_DESCRIPTOR_ID) {
    const int8_t* pInt8Buf = (const int8_t*)pBuf;
    const uint32_t uncomp_crc32 = MZ_READ_LE32(pInt8Buf + 4);
    self->combined_uncomp_crc32_ =
        c10::hash_combine(self->combined_uncomp_crc32_, uncomp_crc32);
  }

  return ret;
}
623 
// Writes a new archive to `file_name`; the archive's internal folder name is
// the file's basename with the extension stripped.
PyTorchStreamWriter::PyTorchStreamWriter(const std::string& file_name)
    : archive_name_(basename(file_name)) {
  setup(file_name);
}
628 
// Streams archive bytes through `writer_func` instead of opening a file; the
// archive's internal folder name is fixed to "archive".
// NOTE(review): `writer_func` is taken by const value — a const reference
// would avoid one std::function copy.
PyTorchStreamWriter::PyTorchStreamWriter(
    const std::function<size_t(const void*, size_t)> writer_func)
    : archive_name_("archive"),
      writer_func_(writer_func) {
  setup(archive_name_);
}
635 
// Shared constructor body: zero-initializes the miniz writer state and, when
// no custom writer_func_ was supplied, opens `file_name` for binary output
// and installs a default writer over that stream. Finally initializes the
// zip64 writer with ostream_write_func as the write callback.
void PyTorchStreamWriter::setup(const string& file_name) {
  ar_ = std::make_unique<mz_zip_archive>();
  memset(ar_.get(), 0, sizeof(mz_zip_archive));
  archive_name_plus_slash_ = archive_name_ + "/"; // for writeRecord().

  if (archive_name_.size() == 0) {
    CAFFE_THROW("invalid file name: ", file_name);
  }
  if (!writer_func_) {
    file_stream_.open(
        file_name,
        std::ofstream::out | std::ofstream::trunc | std::ofstream::binary);
    valid("opening archive ", file_name.c_str());

    // Diagnose a missing parent directory explicitly; a bad stream alone
    // would give a less actionable message.
    const std::string dir_name = parentdir(file_name);
    if(!dir_name.empty()) {
      struct stat st;
      bool dir_exists = (stat(dir_name.c_str(), &st) == 0 && (st.st_mode & S_IFDIR));
      TORCH_CHECK(dir_exists, "Parent directory ", dir_name, " does not exist.");
    }
    TORCH_CHECK(file_stream_, "File ", file_name, " cannot be opened.");
    writer_func_ = [this](const void* buf, size_t nbytes) -> size_t {
      if (!buf) {
        // See [Note: write_record_metadata]
        file_stream_.seekp(nbytes, std::ios_base::cur);
      } else {
        file_stream_.write(static_cast<const char*>(buf), nbytes);
      }
      // Report 0 bytes on a failed stream so err_seen_ gets set upstream.
      return !file_stream_ ? 0 : nbytes;
    };
  }

  ar_->m_pIO_opaque = this;
  ar_->m_pWrite = ostream_write_func;

  mz_zip_writer_init_v2(ar_.get(), 0, MZ_ZIP_FLAG_WRITE_ZIP64);
  valid("initializing archive ", file_name.c_str());
}
674 
// Raises (never lowers) the file-format version that will be written out.
void PyTorchStreamWriter::setMinVersion(const uint64_t version) {
  version_ = std::max(version, version_);
}
678 
// Adds one record to the archive as "<archive_name>/<name>". `compress`
// selects MZ_BEST_COMPRESSION vs stored. "FB" extra-field padding (see
// detail::getPadding) aligns the payload to kFieldAlignment so tensors can
// be mmapped. Writing the same name twice, or after finalization, is an
// error.
void PyTorchStreamWriter::writeRecord(
    const std::string& name,
    const void* data,
    size_t size,
    bool compress) {
  AT_ASSERT(!finalized_);
  AT_ASSERT(!archive_name_plus_slash_.empty());
  TORCH_INTERNAL_ASSERT(
      files_written_.count(name) == 0, "Tried to serialize file twice: ", name);
  if (name == kSerializationIdRecordName && serialization_id_.empty()) {
    // In case of copying records from another file, skip writing a different
    // serialization_id than the one computed in this writer.
    // This is to ensure serialization_id is unique per serialization output.
    return;
  }
  std::string full_name = archive_name_plus_slash_ + name;
  size_t padding_size =
      detail::getPadding(ar_->m_archive_size, full_name.size(), size, padding_);
  uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0;
  mz_zip_writer_add_mem_ex_v2(
      /*pZip=*/ar_.get(),
      /*pArchive_name=*/full_name.c_str(),
      /*pBuf=*/data,
      /*buf_size=*/size,
      /*pComment=*/nullptr,
      /*comment_size=*/0,
      /*level_and_flags=*/flags,
      /*uncomp_size=*/0,
      /*uncomp_crc32=*/0,
      /*last_modified=*/nullptr,
      /*user_extra_data=*/padding_.c_str(),
      /*user_extra_data_len=*/padding_size,
      /*user_extra_data_central=*/nullptr,
      /*user_extra_data_central_len=*/0);
  valid("writing file ", name.c_str());
  files_written_.insert(name);
}
716 
// Finalizes the archive: writes version/byteorder records if absent, the
// serialization id, then the zip central directory, and closes the stream.
void PyTorchStreamWriter::writeEndOfFile() {
  // Ensures that finalized_ is set to true even if an
  // exception is raised during the method call.
  // I.e. even a partial call to writeEndOfFile() should mark the
  // file as finalized, otherwise a double exception raised from the
  // destructor would result in `std::terminate()`.
  // See https://github.com/pytorch/pytorch/issues/87997/
  struct Finalizer {
    Finalizer(bool& var): var_(var) {}
    ~Finalizer() {
      var_ = true;
    }
   private:
    bool& var_;
  } f(finalized_);

  auto allRecords = getAllWrittenRecords();
  // If no ".data/version" or "version" record in the output model, rewrites version info
  if(allRecords.find(".data/version") == allRecords.end() && allRecords.find("version") == allRecords.end()) {
    std::string version = std::to_string(version_);
    version.push_back('\n');
    // Format version >= 6 stores the record under ".data/"; older under root.
    if (version_ >= 0x6L) {
      writeRecord(".data/version", version.c_str(), version.size());
    } else {
      writeRecord("version", version.c_str(), version.size());
    }
  }

  // If no "byteorder" record in the output model, rewrites byteorder info
  if(allRecords.find("byteorder") == allRecords.end()) {
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
    std::string byteorder = "little";
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
    std::string byteorder = "big";
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
    writeRecord("byteorder", byteorder.c_str(), byteorder.size());
  }

  writeSerializationId();

  AT_ASSERT(!finalized_);
  finalized_ = true;

  mz_zip_writer_finalize_archive(ar_.get());
  mz_zip_writer_end(ar_.get());
  valid("writing central directory for archive ", archive_name_.c_str());
  c10::LogAPIUsageMetadata(
      "pytorch.stream.writer.metadata",
      {{"serialization_id", serialization_id_},
       {"file_name", archive_name_},
       {"file_size", str(mz_zip_get_archive_size(ar_.get()))}});
  if (file_stream_.is_open()) {
    file_stream_.close();
  }
}
774 
// Throws if the last miniz operation failed, or if an earlier write callback
// recorded a short write in err_seen_. `what`/`info` describe the operation.
void PyTorchStreamWriter::valid(const char* what, const char* info) {
  auto err = mz_zip_get_last_error(ar_.get());
  if (err != MZ_ZIP_NO_ERROR) {
    CAFFE_THROW(
        "PytorchStreamWriter failed ",
        what,
        info,
        ": ",
        mz_zip_get_error_string(err));
  }
  if (err_seen_) {
    CAFFE_THROW("PytorchStreamWriter failed ", what, info, ".");
  }
}
789 
void PyTorchStreamWriter::writeSerializationId() {
  // Serialization id is computed based on all files written, and is composed of
  // 1) a combined hash of record name hashes
  // 2) a combined crc32 of the record uncompressed data
  // This is best effort to create a fixed-length, unique and deterministic id
  // for the serialized files without incurring additional computation overhead.
  if (files_written_.find(kSerializationIdRecordName) == files_written_.end()) {
    uint64_t combined_record_name_hash = 0;
    for (const std::string& record_name : files_written_) {
      size_t record_name_hash = c10::hash<std::string>{}(record_name);
      combined_record_name_hash =
          c10::hash_combine(combined_record_name_hash, record_name_hash);
    }
    // Two zero-padded 20-digit decimal fields: name-hash then combined crc32.
    std::ostringstream serialization_id_oss;
    serialization_id_oss << std::setfill('0') << std::setw(20)
                         << combined_record_name_hash
                         << std::setfill('0') << std::setw(20)
                         << combined_uncomp_crc32_;
    serialization_id_ = serialization_id_oss.str();
    writeRecord(
        kSerializationIdRecordName,
        serialization_id_.c_str(),
        serialization_id_.size());
  }
}
815 
// NOLINTNEXTLINE(bugprone-exception-escape)
PyTorchStreamWriter::~PyTorchStreamWriter() {
  // Finalize the archive if the caller never did. writeEndOfFile() marks
  // finalized_ even on failure (via its Finalizer guard), so this cannot
  // re-enter finalization.
  if (!finalized_) {
    writeEndOfFile();
  }
}
822 
823 } // namespace serialize
824 } // namespace caffe2
825