#include <array>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_set>
#include <vector>

#include <gtest/gtest.h>

#include "caffe2/serialize/inline_container.h"
#include <c10/util/Logging.h>
#include "c10/util/irange.h"
11
12 namespace caffe2 {
13 namespace serialize {
14 namespace {
15
// Round-trips two records through PyTorchStreamWriter/PyTorchStreamReader and
// validates record lookup, field alignment, and the plain, in-place, and
// chunked getRecord() APIs.
TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
  constexpr int64_t kFieldAlignment = 64L;

  std::ostringstream oss;
  // Write records through the writer; the sink lambda appends into `oss`.
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 127> data1;
  // Scratch buffer reused by the chunked getRecord() calls below.
  std::vector<uint8_t> buf(data1.size());

  for (auto i : c10::irange(data1.size())) {
    data1[i] = data1.size() - i;
  }
  writer.writeRecord("key1", data1.data(), data1.size());

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 64> data2;
  for (auto i : c10::irange(data2.size())) {
    data2[i] = data2.size() - i;
  }
  writer.writeRecord("key2", data2.data(), data2.size());

  const std::unordered_set<std::string>& written_records =
      writer.getAllWrittenRecords();
  ASSERT_EQ(written_records.size(), 2);
  ASSERT_EQ(written_records.count("key1"), 1);
  ASSERT_EQ(written_records.count("key2"), 1);

  writer.writeEndOfFile();
  // writeEndOfFile() appends the serialization-id record.
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);

  std::string the_file = oss.str();
  const char* file_name = "output.zip";
  // Binary mode: the zip payload must be written byte-for-byte
  // (text mode would mangle it on Windows).
  std::ofstream foo(file_name, std::ios::binary);
  foo.write(the_file.c_str(), the_file.size());
  foo.close();

  std::istringstream iss(the_file);

  // Read the records back through the reader.
  PyTorchStreamReader reader(&iss);
  ASSERT_TRUE(reader.hasRecord("key1"));
  ASSERT_TRUE(reader.hasRecord("key2"));
  ASSERT_FALSE(reader.hasRecord("key2000"));
  at::DataPtr data_ptr;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t size;
  std::tie(data_ptr, size) = reader.getRecord("key1");
  size_t off1 = reader.getRecordOffset("key1");
  ASSERT_EQ(size, data1.size());
  ASSERT_EQ(memcmp(data_ptr.get(), data1.data(), data1.size()), 0);
  // The record payload is stored uncompressed at the reported offset.
  ASSERT_EQ(memcmp(the_file.c_str() + off1, data1.data(), data1.size()), 0);
  ASSERT_EQ(off1 % kFieldAlignment, 0);
  // In-place getRecord() test.
  std::vector<uint8_t> dst(size);
  size_t ret = reader.getRecord("key1", dst.data(), size);
  ASSERT_EQ(ret, size);
  ASSERT_EQ(memcmp(dst.data(), data1.data(), size), 0);
  // Chunked getRecord() test (chunk size 3, user-supplied memcpy callback).
  ret = reader.getRecord(
      "key1",
      dst.data(),
      size,
      3,
      buf.data(),
      [](void* dst, const void* src, size_t n) { memcpy(dst, src, n); });
  ASSERT_EQ(ret, size);
  ASSERT_EQ(memcmp(dst.data(), data1.data(), size), 0);

  std::tie(data_ptr, size) = reader.getRecord("key2");
  size_t off2 = reader.getRecordOffset("key2");
  ASSERT_EQ(off2 % kFieldAlignment, 0);

  ASSERT_EQ(size, data2.size());
  ASSERT_EQ(memcmp(data_ptr.get(), data2.data(), data2.size()), 0);
  ASSERT_EQ(memcmp(the_file.c_str() + off2, data2.data(), data2.size()), 0);
  // In-place getRecord() test.
  dst.resize(size);
  ret = reader.getRecord("key2", dst.data(), size);
  ASSERT_EQ(ret, size);
  ASSERT_EQ(memcmp(dst.data(), data2.data(), size), 0);
  // Chunked getRecord() test.
  ret = reader.getRecord(
      "key2",
      dst.data(),
      size,
      3,
      buf.data(),
      [](void* dst, const void* src, size_t n) { memcpy(dst, src, n); });
  ASSERT_EQ(ret, size);
  ASSERT_EQ(memcmp(dst.data(), data2.data(), size), 0);
  // Clean up the temporary file.
  remove(file_name);
}
108
// Exercises the multi-reader getRecord() overloads: repeated lookups with a
// growing pool of additional readers must return the same bytes as the
// single-reader path.
TEST(PyTorchStreamWriterAndReader, LoadWithMultiThreads) {
  std::ostringstream oss;
  // Write records through the writer; the sink lambda appends into `oss`.
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 127> data1;
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 64> data2;
  for (auto i : c10::irange(data1.size())) {
    data1[i] = data1.size() - i;
  }
  writer.writeRecord("key1", data1.data(), data1.size());

  for (auto i : c10::irange(data2.size())) {
    data2[i] = data2.size() - i;
  }
  writer.writeRecord("key2", data2.data(), data2.size());

  const std::unordered_set<std::string>& written_records =
      writer.getAllWrittenRecords();
  ASSERT_EQ(written_records.size(), 2);
  ASSERT_EQ(written_records.count("key1"), 1);
  ASSERT_EQ(written_records.count("key2"), 1);

  writer.writeEndOfFile();
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);

  std::string the_file = oss.str();
  const char* file_name = "output.zip";
  // Binary mode: the zip payload must round-trip byte-for-byte.
  std::ofstream foo(file_name, std::ios::binary);
  foo.write(the_file.c_str(), the_file.size());
  foo.close();

  // Read records back through PyTorchStreamReader.
  std::istringstream iss(the_file);
  PyTorchStreamReader reader(&iss);
  // Threshold 0 forces the additional-reader code path even for tiny records.
  reader.setAdditionalReaderSizeThreshold(0);
  // Sanity check: single-reader reads succeed before the multi-reader tests.
  int64_t size1 = 0;
  int64_t size2 = 0;
  int64_t ret = 0;
  at::DataPtr data_ptr;
  std::tie(data_ptr, size1) = reader.getRecord("key1");
  std::tie(data_ptr, size2) = reader.getRecord("key2");

  // Test getRecord(name, additional_readers).
  std::vector<std::shared_ptr<ReadAdapterInterface>> additionalReader;
  for (int i = 0; i < 10; ++i) {
    // Test various sized additional-reader pools (empty on first iteration).
    std::tie(data_ptr, ret) = reader.getRecord("key1", additionalReader);
    ASSERT_EQ(ret, size1);
    ASSERT_EQ(memcmp(data_ptr.get(), data1.data(), size1), 0);

    std::tie(data_ptr, ret) = reader.getRecord("key2", additionalReader);
    ASSERT_EQ(ret, size2);
    ASSERT_EQ(memcmp(data_ptr.get(), data2.data(), size2), 0);
  }

  // In-place multi-threaded getRecord(name, dst, n, additional_readers) test.
  additionalReader.clear();
  std::vector<uint8_t> dst1(size1), dst2(size2);
  for (int i = 0; i < 10; ++i) {
    // Grow the reader pool each iteration to cover various thread counts.
    additionalReader.push_back(std::make_unique<IStreamAdapter>(&iss));

    ret = reader.getRecord("key1", dst1.data(), size1, additionalReader);
    ASSERT_EQ(ret, size1);
    ASSERT_EQ(memcmp(dst1.data(), data1.data(), size1), 0);

    ret = reader.getRecord("key2", dst2.data(), size2, additionalReader);
    ASSERT_EQ(ret, size2);
    ASSERT_EQ(memcmp(dst2.data(), data2.data(), size2), 0);
  }
  // Clean up the temporary file.
  remove(file_name);
}
188
// Requesting a record name that was never written must throw c10::Error from
// every getRecord() overload, and the reader must remain usable afterwards.
TEST(PytorchStreamWriterAndReader, GetNonexistentRecordThrows) {
  std::ostringstream oss;
  // Write records through the writer; the sink lambda appends into `oss`.
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 127> data1;

  // Scratch buffer for the chunked getRecord() overload. NOTE(review): it is
  // intentionally empty — the "key3" lookup throws before the buffer is
  // touched, so buf.data() (nullptr) is never dereferenced.
  std::vector<uint8_t> buf;

  for (auto i : c10::irange(data1.size())) {
    data1[i] = data1.size() - i;
  }
  writer.writeRecord("key1", data1.data(), data1.size());

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 64> data2;
  for (auto i : c10::irange(data2.size())) {
    data2[i] = data2.size() - i;
  }
  writer.writeRecord("key2", data2.data(), data2.size());

  const std::unordered_set<std::string>& written_records =
      writer.getAllWrittenRecords();
  ASSERT_EQ(written_records.size(), 2);
  ASSERT_EQ(written_records.count("key1"), 1);
  ASSERT_EQ(written_records.count("key2"), 1);

  writer.writeEndOfFile();
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);

  std::string the_file = oss.str();
  const char* file_name = "output2.zip";
  // Binary mode: the zip payload must round-trip byte-for-byte.
  std::ofstream foo(file_name, std::ios::binary);
  foo.write(the_file.c_str(), the_file.size());
  foo.close();

  std::istringstream iss(the_file);

  // Read records back through the reader.
  PyTorchStreamReader reader(&iss);
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  EXPECT_THROW(reader.getRecord("key3"), c10::Error);
  std::vector<uint8_t> dst(data1.size());
  EXPECT_THROW(reader.getRecord("key3", dst.data(), data1.size()), c10::Error);
  EXPECT_THROW(
      reader.getRecord(
          "key3",
          dst.data(),
          data1.size(),
          3,
          buf.data(),
          [](void* dst, const void* src, size_t n) { memcpy(dst, src, n); }),
      c10::Error);

  // The reader must still work after throwing.
  EXPECT_TRUE(reader.hasRecord("key1"));
  // Clean up the temporary file.
  remove(file_name);
}
252
// With setShouldLoadDebugSymbol(false), *.debug_pkl records are hidden:
// hasRecord() reports false and every getRecord() overload returns size 0.
TEST(PytorchStreamWriterAndReader, SkipDebugRecords) {
  std::ostringstream oss;
  // Write records through the writer; the sink lambda appends into `oss`.
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 127> data1;
  // Scratch buffer for the chunked getRecord() call below.
  std::vector<uint8_t> buf(data1.size());

  for (auto i : c10::irange(data1.size())) {
    data1[i] = data1.size() - i;
  }
  writer.writeRecord("key1.debug_pkl", data1.data(), data1.size());

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 64> data2;
  for (auto i : c10::irange(data2.size())) {
    data2[i] = data2.size() - i;
  }
  writer.writeRecord("key2.debug_pkl", data2.data(), data2.size());

  const std::unordered_set<std::string>& written_records =
      writer.getAllWrittenRecords();
  ASSERT_EQ(written_records.size(), 2);
  ASSERT_EQ(written_records.count("key1.debug_pkl"), 1);
  ASSERT_EQ(written_records.count("key2.debug_pkl"), 1);
  writer.writeEndOfFile();
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);

  std::string the_file = oss.str();
  const char* file_name = "output3.zip";
  // Binary mode: the zip payload must round-trip byte-for-byte.
  std::ofstream foo(file_name, std::ios::binary);
  foo.write(the_file.c_str(), the_file.size());
  foo.close();

  std::istringstream iss(the_file);

  // Read records back through the reader with debug symbols disabled.
  PyTorchStreamReader reader(&iss);
  reader.setShouldLoadDebugSymbol(false);
  EXPECT_FALSE(reader.hasRecord("key1.debug_pkl"));
  at::DataPtr ptr;
  size_t size = 0;
  std::tie(ptr, size) = reader.getRecord("key1.debug_pkl");
  EXPECT_EQ(size, 0);
  std::vector<uint8_t> dst(data1.size());
  size_t ret = reader.getRecord("key1.debug_pkl", dst.data(), data1.size());
  EXPECT_EQ(ret, 0);
  ret = reader.getRecord(
      "key1.debug_pkl",
      dst.data(),
      data1.size(),
      3,
      buf.data(),
      [](void* dst, const void* src, size_t n) { memcpy(dst, src, n); });
  EXPECT_EQ(ret, 0);
  // Clean up the temporary file.
  remove(file_name);
}
316
// The serialization id written by the writer must be read back verbatim, and
// writing identical content twice must produce the same serialization id.
TEST(PytorchStreamWriterAndReader, ValidSerializationId) {
  std::ostringstream oss;
  // Write records through the writer; the sink lambda appends into `oss`.
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 127> data1;

  for (auto i : c10::irange(data1.size())) {
    data1[i] = data1.size() - i;
  }
  writer.writeRecord("key1.debug_pkl", data1.data(), data1.size());
  writer.writeEndOfFile();
  auto writer_serialization_id = writer.serializationId();

  std::string the_file = oss.str();
  std::istringstream iss(the_file);

  // Read back through the reader and compare serialization ids.
  PyTorchStreamReader reader(&iss);
  EXPECT_EQ(reader.serializationId(), writer_serialization_id);

  // Write the same content a second time; only serializationId() is compared,
  // so appending to the same `oss` sink is harmless here.
  PyTorchStreamWriter writer2([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });
  writer2.writeRecord("key1.debug_pkl", data1.data(), data1.size());
  writer2.writeEndOfFile();
  auto writer2_serialization_id = writer2.serializationId();

  // Identical content must yield an identical (deterministic) id.
  EXPECT_EQ(writer_serialization_id, writer2_serialization_id);
}
355
// A user-written record under kSerializationIdRecordName must be ignored: it
// does not count as a written record, and writeEndOfFile() emits the writer's
// own serialization id, which the reader then reports.
TEST(PytorchStreamWriterAndReader, SkipDuplicateSerializationIdRecords) {
  std::ostringstream oss;
  // Write records through the writer; the sink lambda appends into `oss`.
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });

  std::string dup_serialization_id = "dup-serialization-id";
  writer.writeRecord(
      kSerializationIdRecordName,
      dup_serialization_id.c_str(),
      dup_serialization_id.size());

  const std::unordered_set<std::string>& written_records =
      writer.getAllWrittenRecords();
  // The duplicate serialization-id record is skipped, not recorded.
  ASSERT_EQ(written_records.size(), 0);
  writer.writeEndOfFile();
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);
  auto writer_serialization_id = writer.serializationId();

  std::string the_file = oss.str();
  const char* file_name = "output4.zip";
  // Binary mode: the zip payload must round-trip byte-for-byte.
  std::ofstream foo(file_name, std::ios::binary);
  foo.write(the_file.c_str(), the_file.size());
  foo.close();

  std::istringstream iss(the_file);

  // Read back through the reader; it must see the writer's own id, not the
  // user-supplied duplicate.
  PyTorchStreamReader reader(&iss);
  EXPECT_EQ(reader.serializationId(), writer_serialization_id);
  // Clean up the temporary file.
  remove(file_name);
}
389
// Both the writer (on writeEndOfFile) and the reader (on construction) must
// emit one API-usage metadata log entry with the expected context and fields.
TEST(PytorchStreamWriterAndReader, LogAPIUsageMetadata) {
  std::map<std::string, std::map<std::string, std::string>> logs;

  // Capture metadata logs keyed by context.
  SetAPIUsageMetadataLogger(
      [&](const std::string& context,
          const std::map<std::string, std::string>& metadata_map) {
        logs.insert({context, metadata_map});
      });
  std::ostringstream oss;
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });
  writer.writeEndOfFile();

  std::istringstream iss(oss.str());
  // Constructing the reader triggers the reader-side metadata log.
  PyTorchStreamReader reader(&iss);

  // One entry each from the writer and the reader.
  ASSERT_EQ(logs.size(), 2);
  std::map<std::string, std::map<std::string, std::string>> expected_logs = {
      {"pytorch.stream.writer.metadata",
       {{"serialization_id", writer.serializationId()},
        {"file_name", "archive"},
        {"file_size", str(oss.str().length())}}},
      {"pytorch.stream.reader.metadata",
       {{"serialization_id", writer.serializationId()},
        {"file_name", "archive"},
        {"file_size", str(iss.str().length())}}}};
  ASSERT_EQ(expected_logs, logs);

  // Reset the logger so later tests are not affected.
  SetAPIUsageMetadataLogger(
      [&](const std::string& context,
          const std::map<std::string, std::string>& metadata_map) {});
}
427
428 class ChunkRecordIteratorTest : public ::testing::TestWithParam<int64_t> {};
429 INSTANTIATE_TEST_SUITE_P(
430 ChunkRecordIteratorTestGroup,
431 ChunkRecordIteratorTest,
432 testing::Values(100, 150, 1010));
433
// Writes one 1000-byte record to a zip file, then reads it back through the
// chunk-record iterator with the parameterized chunk size and checks that the
// concatenated chunks reproduce the full payload.
TEST_P(ChunkRecordIteratorTest, ChunkRead) {
  auto chunkSize = GetParam();
  std::string zipFileName =
      "output_chunk_" + std::to_string(chunkSize) + ".zip";
  const char* fileName = zipFileName.c_str();
  const std::string recordName = "key1";
  const size_t tensorDataSizeInBytes = 1000;

  // Write the record through the writer; the sink lambda appends into `oss`.
  std::ostringstream oss(std::ios::binary);
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });

  auto tensorData = std::vector<uint8_t>(tensorDataSizeInBytes, 1);
  auto dataPtr = tensorData.data();
  writer.writeRecord(recordName, dataPtr, tensorDataSizeInBytes);
  const std::unordered_set<std::string>& written_records =
      writer.getAllWrittenRecords();
  ASSERT_EQ(written_records.size(), 1);
  ASSERT_EQ(written_records.count(recordName), 1);
  writer.writeEndOfFile();
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);

  std::string the_file = oss.str();
  std::ofstream foo(fileName, std::ios::binary);
  foo.write(the_file.c_str(), the_file.size());
  foo.close();
  LOG(INFO) << "Finished saving tensor into zip file " << fileName;

  LOG(INFO) << "Testing chunk size " << chunkSize;
  PyTorchStreamReader reader(fileName);
  ASSERT_TRUE(reader.hasRecord(recordName));
  auto chunkIterator = reader.createChunkReaderIter(
      recordName, tensorDataSizeInBytes, chunkSize);
  std::vector<uint8_t> buffer(chunkSize);
  size_t totalReadSize = 0;
  // next() returns the number of bytes read, 0 at end of record.
  while (auto readSize = chunkIterator.next(buffer.data())) {
    auto expectedData = std::vector<uint8_t>(readSize, 1);
    ASSERT_EQ(memcmp(expectedData.data(), buffer.data(), readSize), 0);
    totalReadSize += readSize;
  }
  // All chunks together must cover the whole record exactly once.
  ASSERT_EQ(totalReadSize, tensorDataSizeInBytes);
  // Clean up the temporary file.
  remove(fileName);
}
480
481 } // namespace
482 } // namespace serialize
483 } // namespace caffe2
484