#include <cstdio>
#include <cstring>
#include <cerrno>
#include <iomanip>
#include <istream>
#include <ostream>
#include <fstream>
#include <algorithm>
#include <sstream>
#include <sys/stat.h>
#include <sys/types.h>
#include <thread>

#include <c10/core/Allocator.h>
#include <c10/core/Backend.h>
#include <c10/core/CPUAllocator.h>
#include <c10/util/Exception.h>
#include <c10/util/Logging.h>
#include <c10/util/hash.h>

#include "caffe2/core/common.h"
#include "caffe2/serialize/file_adapter.h"
#include "caffe2/serialize/inline_container.h"
#include "caffe2/serialize/istream_adapter.h"
#include "caffe2/serialize/read_adapter_interface.h"

#include "caffe2/serialize/versions.h"
#include "miniz.h"

namespace caffe2 {
namespace serialize {
constexpr c10::string_view kDebugPklSuffix(".debug_pkl");

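// Thin wrapper around miniz's extract-iterator state, so that the public
// header can hold it behind a forward declaration without pulling miniz.h
// into inline_container.h.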
struct MzZipReaderIterWrapper {
  MzZipReaderIterWrapper(mz_zip_reader_extract_iter_state* iter) : impl(iter) {}
  mz_zip_reader_extract_iter_state* impl;
};

ChunkRecordIterator::ChunkRecordIterator(
    size_t recordSize,
    size_t chunkSize,
    std::unique_ptr<MzZipReaderIterWrapper> iter)
    : recordSize_(recordSize),
      chunkSize_(chunkSize),
      offset_(0),
      iter_(std::move(iter)) {}

ChunkRecordIterator::~ChunkRecordIterator() {
  mz_zip_reader_extract_iter_free(iter_->impl);
}

size_t ChunkRecordIterator::next(void* buf) {
  size_t want_size = std::min(chunkSize_, recordSize_ - offset_);
  if (want_size == 0) {
    return 0;
  }
  size_t read_size = mz_zip_reader_extract_iter_read(iter_->impl, buf, want_size);
  TORCH_CHECK(read_size > 0, "Read bytes should be larger than 0");
  offset_ += read_size;
  return read_size;
}
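
// A minimal sketch of draining a ChunkRecordIterator; the record name and
// chunk size below are hypothetical:
//
//   auto it = reader.createChunkReaderIter(
//       "data/0", reader.getRecordSize("data/0"), /*chunkSize=*/4096);
//   std::vector<char> chunk(4096);
//   while (size_t got = it.next(chunk.data())) {
//     // consume `got` bytes from chunk.data()
//   }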

size_t istream_read_func(void* pOpaque, mz_uint64 file_ofs, void* pBuf, size_t n) {
  auto self = static_cast<PyTorchStreamReader*>(pOpaque);
  return self->read(file_ofs, static_cast<char*>(pBuf), n);
}

static std::string basename(const std::string& name) {
  size_t start = 0;
  for (size_t i = 0; i < name.size(); ++i) {
    if (name[i] == '\\' || name[i] == '/') {
      start = i + 1;
    }
  }

  if (start >= name.size()) {
    return "";
  }

  size_t end = name.size();
  for (size_t i = end; i > start; --i) {
    if (name[i - 1] == '.') {
      end = i - 1;
      break;
    }
  }
  return name.substr(start, end - start);
}

static std::string parentdir(const std::string& name) {
  size_t end = name.find_last_of('/');
  if (end == std::string::npos) {
    end = name.find_last_of('\\');
  }

#ifdef WIN32
  if (end != std::string::npos && end > 1 && name[end - 1] == ':') {
    // This is a Windows root directory, so include the slash in
    // the parent directory
    end++;
  }
#endif

  if (end == std::string::npos) {
    return "";
  }

  return name.substr(0, end);
}
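
// Examples (both path separators are handled):
//   basename("path/to/model.pt")  -> "model"   (directory and extension stripped)
//   parentdir("path/to/model.pt") -> "path/to"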

size_t PyTorchStreamReader::read(uint64_t pos, char* buf, size_t n) {
  return in_->read(pos, buf, n, "reading file");
}

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
PyTorchStreamReader::PyTorchStreamReader(const std::string& file_name)
    : ar_(std::make_unique<mz_zip_archive>()),
      in_(std::make_unique<FileAdapter>(file_name)) {
  init();
}

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
PyTorchStreamReader::PyTorchStreamReader(std::istream* in)
    : ar_(std::make_unique<mz_zip_archive>()),
      in_(std::make_unique<IStreamAdapter>(in)) {
  init();
}

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
PyTorchStreamReader::PyTorchStreamReader(
    std::shared_ptr<ReadAdapterInterface> in)
    : ar_(std::make_unique<mz_zip_archive>()), in_(std::move(in)) {
  init();
}
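
// A minimal sketch of constructing a reader through each overload; the file
// name is hypothetical:
//
//   PyTorchStreamReader from_path("model.pt");
//   std::ifstream ifs("model.pt", std::ios::binary);
//   PyTorchStreamReader from_stream(&ifs);
//   auto adapter = std::make_shared<FileAdapter>("model.pt");
//   PyTorchStreamReader from_adapter(adapter);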

void PyTorchStreamReader::init() {
  AT_ASSERT(in_ != nullptr);
  AT_ASSERT(ar_ != nullptr);
  memset(ar_.get(), 0, sizeof(mz_zip_archive));

  size_t size = in_->size();

  // Check for the old magic number.
  constexpr size_t kMagicValueLength = 8;
  if (size > kMagicValueLength) {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    char buf[kMagicValueLength];
    read(0, buf, kMagicValueLength);
    valid("checking magic number");
    AT_ASSERTM(
        memcmp("PYTORCH1", buf, kMagicValueLength) != 0,
        "File is an unsupported archive format from the preview release.");
  }

  ar_->m_pIO_opaque = this;
  ar_->m_pRead = istream_read_func;

  mz_zip_reader_init(ar_.get(), size, 0);
  valid("reading zip archive");

  // Figure out the archive_name (i.e. the zip folder all the other files are in).
  // All lookups via getRecord will be prefixed by this folder.
  mz_uint n = mz_zip_reader_get_num_files(ar_.get());
  if (n == 0) {
    CAFFE_THROW("archive does not contain any files");
  }
  size_t name_size = mz_zip_reader_get_filename(ar_.get(), 0, nullptr, 0);
  valid("getting filename");
  std::string buf(name_size, '\0');
  mz_zip_reader_get_filename(ar_.get(), 0, &buf[0], name_size);
  valid("getting filename");
  auto pos = buf.find_first_of('/');
  if (pos == std::string::npos) {
    CAFFE_THROW("file in archive is not in a subdirectory: ", buf);
  }
  archive_name_ = buf.substr(0, pos);
  archive_name_plus_slash_ = archive_name_ + "/";

  // Read the serialization id.
  if (hasRecord(kSerializationIdRecordName)) {
    at::DataPtr serialization_id_ptr;
    size_t serialization_id_size = 0;
    std::tie(serialization_id_ptr, serialization_id_size) =
        getRecord(kSerializationIdRecordName);
    serialization_id_.assign(
        static_cast<const char*>(serialization_id_ptr.get()),
        serialization_id_size);
  }
  c10::LogAPIUsageMetadata(
      "pytorch.stream.reader.metadata",
      {{"serialization_id", serialization_id_},
       {"file_name", archive_name_},
       {"file_size", str(mz_zip_get_archive_size(ar_.get()))}});

  // Version check.
  at::DataPtr version_ptr;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  size_t version_size;
  if (hasRecord(".data/version")) {
    std::tie(version_ptr, version_size) = getRecord(".data/version");
  } else {
    TORCH_CHECK(hasRecord("version"));
    std::tie(version_ptr, version_size) = getRecord("version");
  }
  std::string version(static_cast<const char*>(version_ptr.get()), version_size);
  try {
    version_ = std::stoull(version);
  } catch (const std::invalid_argument& e) {
    CAFFE_THROW("Couldn't parse the version ",
                version,
                " as an integer.");
  }
  if (version_ < static_cast<decltype(version_)>(kMinSupportedFileFormatVersion)) {
    CAFFE_THROW(
        "Attempted to read a PyTorch file with version ",
        std::to_string(version_),
        ", but the minimum supported version for reading is ",
        std::to_string(kMinSupportedFileFormatVersion),
        ". Your PyTorch script module file is too old. Please regenerate it",
        " with the latest version of PyTorch to mitigate this issue.");
  }

  if (version_ > static_cast<decltype(version_)>(kMaxSupportedFileFormatVersion)) {
    CAFFE_THROW(
        "Attempted to read a PyTorch file with version ",
        version_,
        ", but the maximum supported version for reading is ",
        kMaxSupportedFileFormatVersion,
        ". The version of your PyTorch installation may be too old; ",
        "please upgrade PyTorch to the latest version to mitigate this issue.");
  }
}

void PyTorchStreamReader::valid(const char* what, const char* info) {
  const auto err = mz_zip_get_last_error(ar_.get());
  TORCH_CHECK(
      err == MZ_ZIP_NO_ERROR,
      "PytorchStreamReader failed ",
      what,
      info,
      ": ",
      mz_zip_get_error_string(err));
}

constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30;
constexpr int MZ_ZIP_LDH_FILENAME_LEN_OFS = 26;
constexpr int MZ_ZIP_LDH_EXTRA_LEN_OFS = 28;
constexpr int MZ_ZIP_DATA_DESCRIPTOR_ID = 0x08074b50;

namespace detail {
size_t getPadding(
    size_t cursor,
    size_t filename_size,
    size_t size,
    std::string& padding_buf) {
  size_t start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_size +
      sizeof(mz_uint16) * 2;
  if (size >= MZ_UINT32_MAX || cursor >= MZ_UINT32_MAX) {
    start += sizeof(mz_uint16) * 2;
    if (size >= MZ_UINT32_MAX) {
      start += 2 * sizeof(mz_uint64);
    }
    if (cursor >= MZ_UINT32_MAX) {
      start += sizeof(mz_uint64);
    }
  }
  size_t mod = start % kFieldAlignment;
  size_t next_offset = (mod == 0) ? start : (start + kFieldAlignment - mod);
  size_t padding_size = next_offset - start;
  size_t padding_size_plus_fbxx = padding_size + 4;
  if (padding_buf.size() < padding_size_plus_fbxx) {
    padding_buf.append(padding_size_plus_fbxx - padding_buf.size(), 'Z');
  }
  // zip extra encoding (key, size_of_extra_bytes)
  padding_buf[0] = 'F';
  padding_buf[1] = 'B';
  padding_buf[2] = (uint8_t)padding_size;
  padding_buf[3] = (uint8_t)(padding_size >> 8);
  return padding_size_plus_fbxx;
}
} // namespace detail
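
// Worked example, assuming kFieldAlignment == 64 as defined in
// inline_container.h: with cursor == 0 and a 25-byte record name, the payload
// would start at 0 + 30 (local header) + 25 (name) + 4 (extra-field key and
// length) = 59. The next 64-byte boundary is 64, so 5 padding bytes are
// emitted inside a 9-byte "FB" extra field (4-byte header + 5 bytes).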

bool PyTorchStreamReader::hasRecord(const std::string& name) {
  std::lock_guard<std::mutex> guard(reader_lock_);

  if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
    return false;
  }
  std::string ss = archive_name_plus_slash_ + name;
  mz_zip_reader_locate_file(ar_.get(), ss.c_str(), nullptr, 0);
  const mz_zip_error err = mz_zip_get_last_error(ar_.get());

  if (err == MZ_ZIP_NO_ERROR) {
    return true;
  } else if (err == MZ_ZIP_FILE_NOT_FOUND) {
    return false;
  } else {
    // A different error happened, raise it.
    valid("attempting to locate file ", name.c_str());
  }
  TORCH_INTERNAL_ASSERT(false, "should not reach here");
}

std::vector<std::string> PyTorchStreamReader::getAllRecords() {
  std::lock_guard<std::mutex> guard(reader_lock_);
  mz_uint num_files = mz_zip_reader_get_num_files(ar_.get());
  std::vector<std::string> out;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  char buf[MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE];
  for (size_t i = 0; i < num_files; i++) {
    mz_zip_reader_get_filename(ar_.get(), i, buf, MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE);
    if (strncmp(
            buf,
            archive_name_plus_slash_.data(),
            archive_name_plus_slash_.size()) != 0) {
      CAFFE_THROW(
          "file in archive is not in a subdirectory ",
          archive_name_plus_slash_,
          ": ",
          buf);
    }
    if ((load_debug_symbol_) ||
        (!c10::string_view(buf + archive_name_plus_slash_.size()).ends_with(kDebugPklSuffix))) {
      // NOLINTNEXTLINE(modernize-use-emplace)
      out.push_back(buf + archive_name_plus_slash_.size());
    }
  }
  return out;
}

const std::unordered_set<std::string>&
PyTorchStreamWriter::getAllWrittenRecords() {
  return files_written_;
}

size_t PyTorchStreamReader::getRecordID(const std::string& name) {
  std::string ss = archive_name_plus_slash_ + name;
  size_t result = mz_zip_reader_locate_file(ar_.get(), ss.c_str(), nullptr, 0);
  valid("locating file ", name.c_str());
  return result;
}

// return dataptr, size
std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string& name) {
  std::lock_guard<std::mutex> guard(reader_lock_);
  if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
    at::DataPtr retval;
    return std::make_tuple(std::move(retval), 0);
  }
  size_t key = getRecordID(name);
  mz_zip_archive_file_stat stat;
  mz_zip_reader_file_stat(ar_.get(), key, &stat);
  valid("retrieving file meta-data for ", name.c_str());
  at::DataPtr retval = c10::GetCPUAllocator()->allocate(stat.m_uncomp_size);
  mz_zip_reader_extract_to_mem(ar_.get(), key, retval.get(), stat.m_uncomp_size, 0);
  valid("reading file ", name.c_str());

  return std::make_tuple(std::move(retval), stat.m_uncomp_size);
}

size_t
PyTorchStreamReader::getRecordMultiReaders(
    const std::string& name,
    std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
    void* dst,
    size_t n) {
  size_t nthread = additionalReaders.size() + 1;
  size_t recordOff = getRecordOffset(name);
  std::vector<std::thread> loaderThreads;
  size_t perThreadSize = (n + nthread - 1) / nthread;
  std::vector<size_t> readSizes(nthread, 0);
  std::lock_guard<std::mutex> guard(reader_lock_);
  for (size_t i = 0; i < nthread; i++) {
    loaderThreads.emplace_back([this, name, i, n, recordOff, perThreadSize, dst, &additionalReaders, &readSizes] {
      size_t startPos = i * perThreadSize;
      size_t endPos = std::min((i + 1) * perThreadSize, n);
      if (startPos < endPos) {
        size_t threadReadSize = endPos - startPos;
        size_t size = 0;
        if (i == 0) {
          // The first shard reuses this reader's own adapter.
          size = read(recordOff + startPos, (char*)dst + startPos, threadReadSize);
        } else {
          // Each remaining shard reads through one of the additional adapters.
          auto reader = additionalReaders[i - 1];
          size = reader->read(recordOff + startPos, (char*)dst + startPos, threadReadSize);
        }
        readSizes[i] = size;
        LOG(INFO) << "Thread " << i << " read [" << startPos << "-" << endPos << "] "
                  << "from " << name << " of size " << n;
        TORCH_CHECK(
            threadReadSize == size,
            "record size ",
            threadReadSize,
            " mismatch with read size ",
            size);
      }
    });
  }

  for (auto& thread : loaderThreads) {
    thread.join();
  }
  loaderThreads.clear();

  size_t total_read_n = 0;
  for (auto& r : readSizes) {
    total_read_n += r;
  }

  TORCH_CHECK(
      n == total_read_n,
      "Multi reader total read size ",
      total_read_n,
      " mismatch with dst size ",
      n);

  return total_read_n;
}

// Read a record with multiple clients (adapters) in parallel.
std::tuple<at::DataPtr, size_t>
PyTorchStreamReader::getRecord(
    const std::string& name,
    std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders) {
  if (additionalReaders.empty()) {
    // No additional readers, use the single-threaded version.
    return getRecord(name);
  }

  if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
    at::DataPtr retval;
    return std::make_tuple(std::move(retval), 0);
  }
  size_t key = getRecordID(name);
  mz_zip_archive_file_stat stat;
  mz_zip_reader_file_stat(ar_.get(), key, &stat);
  auto n = stat.m_uncomp_size;
  valid("retrieving file meta-data for ", name.c_str());
  if (n < additional_reader_size_threshold_) {
    // Record is below the size threshold, use the single-threaded version.
    return getRecord(name);
  }

  at::DataPtr retval = c10::GetCPUAllocator()->allocate(stat.m_uncomp_size);
  void* dst = retval.get();
  PyTorchStreamReader::getRecordMultiReaders(name, additionalReaders, dst, n);
  return std::make_tuple(std::move(retval), stat.m_uncomp_size);
}
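
// A minimal sketch of a parallel record read; the file name is hypothetical,
// and every extra adapter must open the same underlying file:
//
//   std::vector<std::shared_ptr<ReadAdapterInterface>> extras;
//   extras.push_back(std::make_shared<FileAdapter>("model.pt"));
//   extras.push_back(std::make_shared<FileAdapter>("model.pt"));
//   // Up to 3 threads (this reader + 2 extras) split the record evenly.
//   auto [ptr, size] = reader.getRecord("data/0", extras);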

// In-place memory writing.
size_t
PyTorchStreamReader::getRecord(const std::string& name, void* dst, size_t n) {
  std::lock_guard<std::mutex> guard(reader_lock_);
  if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
    return 0;
  }
  size_t key = getRecordID(name);
  mz_zip_archive_file_stat stat;
  mz_zip_reader_file_stat(ar_.get(), key, &stat);
  TORCH_CHECK(
      n == stat.m_uncomp_size,
      "record size ",
      stat.m_uncomp_size,
      " mismatch with dst size ",
      n);
  valid("retrieving file meta-data for ", name.c_str());
  mz_zip_reader_extract_to_mem(ar_.get(), key, dst, stat.m_uncomp_size, 0);
  valid("reading file ", name.c_str());

  return stat.m_uncomp_size;
}

// In-place memory writing with in-tensor multi-threading; useful for large tensors.
size_t
PyTorchStreamReader::getRecord(
    const std::string& name,
    void* dst,
    size_t n,
    std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders) {
  if (additionalReaders.empty()) {
    // No additional readers, use the single-threaded version.
    return getRecord(name, dst, n);
  }

  if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
    return 0;
  }
  size_t key = getRecordID(name);
  mz_zip_archive_file_stat stat;
  mz_zip_reader_file_stat(ar_.get(), key, &stat);
  TORCH_CHECK(
      n == stat.m_uncomp_size,
      "record size ",
      stat.m_uncomp_size,
      " mismatch with dst size ",
      n);
  valid("retrieving file meta-data for ", name.c_str());

  if (n < additional_reader_size_threshold_) {
    // Record is below the size threshold, use the single-threaded version.
    return getRecord(name, dst, n);
  }

  PyTorchStreamReader::getRecordMultiReaders(name, additionalReaders, dst, n);
  return stat.m_uncomp_size;
}

size_t PyTorchStreamReader::getRecord(
    const std::string& name,
    void* dst,
    size_t n,
    size_t chunk_size,
    void* buf,
    const std::function<void(void*, const void*, size_t)>& memcpy_func) {
  std::lock_guard<std::mutex> guard(reader_lock_);
  if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
    return 0;
  }
  if (chunk_size == 0) {
    chunk_size = n;
  }
  size_t key = getRecordID(name);
  mz_zip_archive_file_stat stat;
  mz_zip_reader_file_stat(ar_.get(), key, &stat);
  TORCH_CHECK(
      n == stat.m_uncomp_size,
      "record size ",
      stat.m_uncomp_size,
      " mismatch with dst size ",
      n);
  valid("retrieving file meta-data for ", name.c_str());

  std::vector<uint8_t> buffer;
  if (buf == nullptr) {
    buffer.resize(chunk_size);
    buf = buffer.data();
  }

  auto chunkIterator =
      createChunkReaderIter(name, (size_t)stat.m_uncomp_size, chunk_size);
  while (auto readSize = chunkIterator.next(buf)) {
    memcpy_func((char*)dst + chunkIterator.offset_ - readSize, buf, readSize);
  }
  valid("reading file ", name.c_str());

  return stat.m_uncomp_size;
}

ChunkRecordIterator PyTorchStreamReader::createChunkReaderIter(
    const std::string& name,
    const size_t recordSize,
    const size_t chunkSize) {
  // Create zip reader iterator
  size_t key = getRecordID(name);
  mz_zip_reader_extract_iter_state* zipReaderIter =
      mz_zip_reader_extract_iter_new(ar_.get(), key, 0);
  TORCH_CHECK(
      zipReaderIter != nullptr,
      "Failed to create zip reader iter: ",
      mz_zip_get_error_string(mz_zip_get_last_error(ar_.get())));

  return ChunkRecordIterator(
      recordSize,
      chunkSize,
      std::make_unique<MzZipReaderIterWrapper>(zipReaderIter));
}

static int64_t read_le_16(uint8_t* buf) {
  return buf[0] + (buf[1] << 8);
}

size_t PyTorchStreamReader::getRecordOffset(const std::string& name) {
  std::lock_guard<std::mutex> guard(reader_lock_);
  mz_zip_archive_file_stat stat;
  mz_zip_reader_file_stat(ar_.get(), getRecordID(name), &stat);
  valid("retrieving file meta-data for ", name.c_str());
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  uint8_t local_header[MZ_ZIP_LOCAL_DIR_HEADER_SIZE];
  in_->read(
      stat.m_local_header_ofs,
      local_header,
      MZ_ZIP_LOCAL_DIR_HEADER_SIZE,
      "reading file header");
  size_t filename_len = read_le_16(local_header + MZ_ZIP_LDH_FILENAME_LEN_OFS);
  size_t extra_len = read_le_16(local_header + MZ_ZIP_LDH_EXTRA_LEN_OFS);
  return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + extra_len;
}
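
// Per the ZIP spec, a record's payload starts right after its local file
// header: 30 fixed bytes, then a variable-length filename and extra field
// whose little-endian uint16 lengths sit at header offsets 26 and 28. The
// "FB" padding emitted by detail::getPadding lives in that extra field,
// which is why the offset computed above lands on an aligned boundary.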

size_t PyTorchStreamReader::getRecordSize(const std::string& name) {
  mz_zip_archive_file_stat stat;
  mz_zip_reader_file_stat(ar_.get(), getRecordID(name), &stat);
  return stat.m_uncomp_size;
}

PyTorchStreamReader::~PyTorchStreamReader() {
  mz_zip_clear_last_error(ar_.get());
  mz_zip_reader_end(ar_.get());
  valid("closing reader for archive ", archive_name_.c_str());
}

size_t ostream_write_func(
    void* pOpaque,
    mz_uint64 file_ofs,
    const void* pBuf,
    size_t n) {
  auto self = static_cast<PyTorchStreamWriter*>(pOpaque);
  if (self->current_pos_ != file_ofs) {
    CAFFE_THROW("unexpected pos ", self->current_pos_, " vs ", file_ofs);
  }
  size_t ret = self->writer_func_(pBuf, n);
  if (n != ret) {
    self->err_seen_ = true;
  }
  self->current_pos_ += ret;

  // Get the CRC32 of uncompressed data from the data descriptor, if the written
  // data is identified as the data descriptor block.
  // See [Note: write_record_metadata] for why we check for non-null pBuf here
  if (pBuf && n >= 8 && MZ_READ_LE32(pBuf) == MZ_ZIP_DATA_DESCRIPTOR_ID) {
    const int8_t* pInt8Buf = (const int8_t*)pBuf;
    const uint32_t uncomp_crc32 = MZ_READ_LE32(pInt8Buf + 4);
    self->combined_uncomp_crc32_ =
        c10::hash_combine(self->combined_uncomp_crc32_, uncomp_crc32);
  }

  return ret;
}
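
// Data descriptor layout (ZIP spec): a 4-byte signature 0x08074b50, then the
// uncompressed CRC32, then the compressed and uncompressed sizes; that is why
// the CRC32 is read 4 bytes past the signature above.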

PyTorchStreamWriter::PyTorchStreamWriter(const std::string& file_name)
    : archive_name_(basename(file_name)) {
  setup(file_name);
}

PyTorchStreamWriter::PyTorchStreamWriter(
    const std::function<size_t(const void*, size_t)> writer_func)
    : archive_name_("archive"),
      writer_func_(writer_func) {
  setup(archive_name_);
}

void PyTorchStreamWriter::setup(const std::string& file_name) {
  ar_ = std::make_unique<mz_zip_archive>();
  memset(ar_.get(), 0, sizeof(mz_zip_archive));
  archive_name_plus_slash_ = archive_name_ + "/"; // for writeRecord().

  if (archive_name_.empty()) {
    CAFFE_THROW("invalid file name: ", file_name);
  }
  if (!writer_func_) {
    file_stream_.open(
        file_name,
        std::ofstream::out | std::ofstream::trunc | std::ofstream::binary);
    valid("opening archive ", file_name.c_str());

    const std::string dir_name = parentdir(file_name);
    if (!dir_name.empty()) {
      struct stat st;
      bool dir_exists = (stat(dir_name.c_str(), &st) == 0 && (st.st_mode & S_IFDIR));
      TORCH_CHECK(dir_exists, "Parent directory ", dir_name, " does not exist.");
    }
    TORCH_CHECK(file_stream_, "File ", file_name, " cannot be opened.");
    writer_func_ = [this](const void* buf, size_t nbytes) -> size_t {
      if (!buf) {
        // See [Note: write_record_metadata]
        file_stream_.seekp(nbytes, std::ios_base::cur);
      } else {
        file_stream_.write(static_cast<const char*>(buf), nbytes);
      }
      return !file_stream_ ? 0 : nbytes;
    };
  }

  ar_->m_pIO_opaque = this;
  ar_->m_pWrite = ostream_write_func;

  mz_zip_writer_init_v2(ar_.get(), 0, MZ_ZIP_FLAG_WRITE_ZIP64);
  valid("initializing archive ", file_name.c_str());
}
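
// A minimal sketch of writer usage; the output path and record contents are
// hypothetical:
//
//   PyTorchStreamWriter writer("out.pt");
//   std::string payload = "{}";
//   writer.writeRecord("data.json", payload.data(), payload.size());
//   writer.writeEndOfFile();  // also runs from the destructor if omitted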

void PyTorchStreamWriter::setMinVersion(const uint64_t version) {
  version_ = std::max(version, version_);
}

void PyTorchStreamWriter::writeRecord(
    const std::string& name,
    const void* data,
    size_t size,
    bool compress) {
  AT_ASSERT(!finalized_);
  AT_ASSERT(!archive_name_plus_slash_.empty());
  TORCH_INTERNAL_ASSERT(
      files_written_.count(name) == 0, "Tried to serialize file twice: ", name);
  if (name == kSerializationIdRecordName && serialization_id_.empty()) {
    // In case of copying records from another file, skip writing a different
    // serialization_id than the one computed in this writer.
    // This is to ensure serialization_id is unique per serialization output.
    return;
  }
  std::string full_name = archive_name_plus_slash_ + name;
  size_t padding_size =
      detail::getPadding(ar_->m_archive_size, full_name.size(), size, padding_);
  uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0;
  mz_zip_writer_add_mem_ex_v2(
      /*pZip=*/ar_.get(),
      /*pArchive_name=*/full_name.c_str(),
      /*pBuf=*/data,
      /*buf_size=*/size,
      /*pComment=*/nullptr,
      /*comment_size=*/0,
      /*level_and_flags=*/flags,
      /*uncomp_size=*/0,
      /*uncomp_crc32=*/0,
      /*last_modified=*/nullptr,
      /*user_extra_data=*/padding_.c_str(),
      /*user_extra_data_len=*/padding_size,
      /*user_extra_data_central=*/nullptr,
      /*user_extra_data_central_len=*/0);
  valid("writing file ", name.c_str());
  files_written_.insert(name);
}

void PyTorchStreamWriter::writeEndOfFile() {
  // Ensures that finalized_ is set to true even if an exception
  // is raised during the method call, i.e. even a partial call to
  // writeEndOfFile() should mark the file as finalized; otherwise
  // a double exception raised from the destructor would result in
  // `std::terminate()`.
  // See https://github.com/pytorch/pytorch/issues/87997/
  struct Finalizer {
    Finalizer(bool& var): var_(var) {}
    ~Finalizer() {
      var_ = true;
    }
   private:
    bool& var_;
  } f(finalized_);

  auto allRecords = getAllWrittenRecords();
  // If there is no ".data/version" or "version" record in the output model, write the version info.
  if (allRecords.find(".data/version") == allRecords.end() && allRecords.find("version") == allRecords.end()) {
    std::string version = std::to_string(version_);
    version.push_back('\n');
    if (version_ >= 0x6L) {
      writeRecord(".data/version", version.c_str(), version.size());
    } else {
      writeRecord("version", version.c_str(), version.size());
    }
  }

  // If there is no "byteorder" record in the output model, write the byte order of the host.
  if (allRecords.find("byteorder") == allRecords.end()) {
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
    std::string byteorder = "little";
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
    std::string byteorder = "big";
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
    writeRecord("byteorder", byteorder.c_str(), byteorder.size());
  }

  writeSerializationId();

  AT_ASSERT(!finalized_);
  finalized_ = true;

  mz_zip_writer_finalize_archive(ar_.get());
  mz_zip_writer_end(ar_.get());
  valid("writing central directory for archive ", archive_name_.c_str());
  c10::LogAPIUsageMetadata(
      "pytorch.stream.writer.metadata",
      {{"serialization_id", serialization_id_},
       {"file_name", archive_name_},
       {"file_size", str(mz_zip_get_archive_size(ar_.get()))}});
  if (file_stream_.is_open()) {
    file_stream_.close();
  }
}

void PyTorchStreamWriter::valid(const char* what, const char* info) {
  auto err = mz_zip_get_last_error(ar_.get());
  if (err != MZ_ZIP_NO_ERROR) {
    CAFFE_THROW(
        "PytorchStreamWriter failed ",
        what,
        info,
        ": ",
        mz_zip_get_error_string(err));
  }
  if (err_seen_) {
    CAFFE_THROW("PytorchStreamWriter failed ", what, info, ".");
  }
}

void PyTorchStreamWriter::writeSerializationId() {
  // Serialization id is computed based on all files written, and is composed of
  // 1) a combined hash of record name hashes
  // 2) a combined crc32 of the record uncompressed data
  // This is a best effort to create a fixed-length, unique and deterministic id
  // for the serialized files without incurring additional computation overhead.
  if (files_written_.find(kSerializationIdRecordName) == files_written_.end()) {
    uint64_t combined_record_name_hash = 0;
    for (const std::string& record_name : files_written_) {
      size_t record_name_hash = c10::hash<std::string>{}(record_name);
      combined_record_name_hash =
          c10::hash_combine(combined_record_name_hash, record_name_hash);
    }
    std::ostringstream serialization_id_oss;
    serialization_id_oss << std::setfill('0') << std::setw(20)
                         << combined_record_name_hash
                         << std::setfill('0') << std::setw(20)
                         << combined_uncomp_crc32_;
    serialization_id_ = serialization_id_oss.str();
    writeRecord(
        kSerializationIdRecordName,
        serialization_id_.c_str(),
        serialization_id_.size());
  }
}

// NOLINTNEXTLINE(bugprone-exception-escape)
PyTorchStreamWriter::~PyTorchStreamWriter() {
  if (!finalized_) {
    writeEndOfFile();
  }
}

} // namespace serialize
} // namespace caffe2