xref: /aosp_15_r20/external/pytorch/caffe2/serialize/inline_container_test.cc (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <array>
2 #include <cstdio>
3 #include <cstring>
4 #include <string>
5 
6 #include <gtest/gtest.h>
7 
8 #include "caffe2/serialize/inline_container.h"
9 #include <c10/util/Logging.h>
10 #include "c10/util/irange.h"
11 
12 namespace caffe2 {
13 namespace serialize {
14 namespace {
15 
TEST(PyTorchStreamWriterAndReader,SaveAndLoad)16 TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
17   int64_t kFieldAlignment = 64L;
18 
19   std::ostringstream oss;
20   // write records through writers
21   PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
22     oss.write(static_cast<const char*>(b), n);
23     return oss ? n : 0;
24   });
25   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
26   std::array<char, 127> data1;
27   // Inplace memory buffer
28   std::vector<uint8_t> buf(data1.size());
29 
30   for (auto i : c10::irange(data1.size())) {
31     data1[i] = data1.size() - i;
32   }
33   writer.writeRecord("key1", data1.data(), data1.size());
34 
35   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
36   std::array<char, 64> data2;
37   for (auto i : c10::irange(data2.size())) {
38     data2[i] = data2.size() - i;
39   }
40   writer.writeRecord("key2", data2.data(), data2.size());
41 
42   const std::unordered_set<std::string>& written_records =
43       writer.getAllWrittenRecords();
44   ASSERT_EQ(written_records.size(), 2);
45   ASSERT_EQ(written_records.count("key1"), 1);
46   ASSERT_EQ(written_records.count("key2"), 1);
47 
48   writer.writeEndOfFile();
49   ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);
50 
51   std::string the_file = oss.str();
52   const char* file_name = "output.zip";
53   std::ofstream foo(file_name);
54   foo.write(the_file.c_str(), the_file.size());
55   foo.close();
56 
57   std::istringstream iss(the_file);
58 
59   // read records through readers
60   PyTorchStreamReader reader(&iss);
61   ASSERT_TRUE(reader.hasRecord("key1"));
62   ASSERT_TRUE(reader.hasRecord("key2"));
63   ASSERT_FALSE(reader.hasRecord("key2000"));
64   at::DataPtr data_ptr;
65   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
66   int64_t size;
67   std::tie(data_ptr, size) = reader.getRecord("key1");
68   size_t off1 = reader.getRecordOffset("key1");
69   ASSERT_EQ(size, data1.size());
70   ASSERT_EQ(memcmp(data_ptr.get(), data1.data(), data1.size()), 0);
71   ASSERT_EQ(memcmp(the_file.c_str() + off1, data1.data(), data1.size()), 0);
72   ASSERT_EQ(off1 % kFieldAlignment, 0);
73   // inplace getRecord() test
74   std::vector<uint8_t> dst(size);
75   size_t ret = reader.getRecord("key1", dst.data(), size);
76   ASSERT_EQ(ret, size);
77   ASSERT_EQ(memcmp(dst.data(), data1.data(), size), 0);
78   // chunked getRecord() test
79   ret = reader.getRecord(
80       "key1", dst.data(), size, 3, buf.data(), [](void* dst, const void* src, size_t n) {
81         memcpy(dst, src, n);
82       });
83   ASSERT_EQ(ret, size);
84   ASSERT_EQ(memcmp(dst.data(), data1.data(), size), 0);
85 
86   std::tie(data_ptr, size) = reader.getRecord("key2");
87   size_t off2 = reader.getRecordOffset("key2");
88   ASSERT_EQ(off2 % kFieldAlignment, 0);
89 
90   ASSERT_EQ(size, data2.size());
91   ASSERT_EQ(memcmp(data_ptr.get(), data2.data(), data2.size()), 0);
92   ASSERT_EQ(memcmp(the_file.c_str() + off2, data2.data(), data2.size()), 0);
93   // inplace getRecord() test
94   dst.resize(size);
95   ret = reader.getRecord("key2", dst.data(), size);
96   ASSERT_EQ(ret, size);
97   ASSERT_EQ(memcmp(dst.data(), data2.data(), size), 0);
98   // chunked getRecord() test
99   ret = reader.getRecord(
100       "key2", dst.data(), size, 3, buf.data(), [](void* dst, const void* src, size_t n) {
101         memcpy(dst, src, n);
102       });
103   ASSERT_EQ(ret, size);
104   ASSERT_EQ(memcmp(dst.data(), data2.data(), size), 0);
105   // clean up
106   remove(file_name);
107 }
108 
TEST(PyTorchStreamWriterAndReader,LoadWithMultiThreads)109 TEST(PyTorchStreamWriterAndReader, LoadWithMultiThreads) {
110 
111   std::ostringstream oss;
112   // write records through writers
113   PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
114     oss.write(static_cast<const char*>(b), n);
115     return oss ? n : 0;
116   });
117 
118   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
119   std::array<char, 127> data1;
120   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
121   std::array<char, 64> data2;
122   for (auto i : c10::irange(data1.size())) {
123     data1[i] = data1.size() - i;
124   }
125   writer.writeRecord("key1", data1.data(), data1.size());
126 
127   for (auto i : c10::irange(data2.size())) {
128     data2[i] = data2.size() - i;
129   }
130   writer.writeRecord("key2", data2.data(), data2.size());
131 
132   const std::unordered_set<std::string>& written_records =
133       writer.getAllWrittenRecords();
134   ASSERT_EQ(written_records.size(), 2);
135   ASSERT_EQ(written_records.count("key1"), 1);
136   ASSERT_EQ(written_records.count("key2"), 1);
137 
138   writer.writeEndOfFile();
139   ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);
140 
141   std::string the_file = oss.str();
142   const char* file_name = "output.zip";
143   std::ofstream foo(file_name);
144   foo.write(the_file.c_str(), the_file.size());
145   foo.close();
146 
147   // read records through pytorchStreamReader
148   std::istringstream iss(the_file);
149   PyTorchStreamReader reader(&iss);
150   reader.setAdditionalReaderSizeThreshold(0);
151   // before testing, sanity check
152   int64_t size1, size2, ret;
153   at::DataPtr data_ptr;
154   std::tie(data_ptr, size1) = reader.getRecord("key1");
155   std::tie(data_ptr, size2) = reader.getRecord("key2");
156 
157   // Test getRecord(name, additional_readers)
158   std::vector<std::shared_ptr<ReadAdapterInterface>> additionalReader;
159   for(int i=0; i<10; ++i){
160     // Test various sized additional readers.
161     std::tie(data_ptr, ret) = reader.getRecord("key1", additionalReader);
162     ASSERT_EQ(ret, size1);
163     ASSERT_EQ(memcmp(data_ptr.get(), data1.data(), size1), 0);
164 
165     std::tie(data_ptr, ret) = reader.getRecord("key2", additionalReader);
166     ASSERT_EQ(ret, size2);
167     ASSERT_EQ(memcmp(data_ptr.get(), data2.data(), size2), 0);
168   }
169 
170   // Inplace multi-threading getRecord(name, dst, n, additional_readers) test
171   additionalReader.clear();
172   std::vector<uint8_t> dst1(size1), dst2(size2);
173   for(int i=0; i<10; ++i){
174     // Test various sizes of read threads
175     additionalReader.push_back(std::make_unique<IStreamAdapter>(&iss));
176 
177     ret = reader.getRecord("key1", dst1.data(), size1, additionalReader);
178     ASSERT_EQ(ret, size1);
179     ASSERT_EQ(memcmp(dst1.data(), data1.data(), size1), 0);
180 
181     ret = reader.getRecord("key2", dst2.data(), size2, additionalReader);
182     ASSERT_EQ(ret, size2);
183     ASSERT_EQ(memcmp(dst2.data(), data2.data(), size2), 0);
184   }
185   // clean up
186   remove(file_name);
187 }
188 
TEST(PytorchStreamWriterAndReader,GetNonexistentRecordThrows)189 TEST(PytorchStreamWriterAndReader, GetNonexistentRecordThrows) {
190   std::ostringstream oss;
191   // write records through writers
192   PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
193     oss.write(static_cast<const char*>(b), n);
194     return oss ? n : 0;
195   });
196   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
197   std::array<char, 127> data1;
198 
199   // Inplace memory buffer
200   std::vector<uint8_t> buf;
201 
202   for (auto i : c10::irange(data1.size())) {
203     data1[i] = data1.size() - i;
204   }
205   writer.writeRecord("key1", data1.data(), data1.size());
206 
207   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
208   std::array<char, 64> data2;
209   for (auto i : c10::irange(data2.size())) {
210     data2[i] = data2.size() - i;
211   }
212   writer.writeRecord("key2", data2.data(), data2.size());
213 
214   const std::unordered_set<std::string>& written_records =
215       writer.getAllWrittenRecords();
216   ASSERT_EQ(written_records.size(), 2);
217   ASSERT_EQ(written_records.count("key1"), 1);
218   ASSERT_EQ(written_records.count("key2"), 1);
219 
220   writer.writeEndOfFile();
221   ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);
222 
223   std::string the_file = oss.str();
224   const char* file_name = "output2.zip";
225   std::ofstream foo(file_name);
226   foo.write(the_file.c_str(), the_file.size());
227   foo.close();
228 
229   std::istringstream iss(the_file);
230 
231   // read records through readers
232   PyTorchStreamReader reader(&iss);
233   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
234   EXPECT_THROW(reader.getRecord("key3"), c10::Error);
235   std::vector<uint8_t> dst(data1.size());
236   EXPECT_THROW(reader.getRecord("key3", dst.data(), data1.size()), c10::Error);
237   EXPECT_THROW(
238       reader.getRecord(
239           "key3",
240           dst.data(),
241           data1.size(),
242           3,
243           buf.data(),
244           [](void* dst, const void* src, size_t n) { memcpy(dst, src, n); }),
245       c10::Error);
246 
247   // Reader should still work after throwing
248   EXPECT_TRUE(reader.hasRecord("key1"));
249   // clean up
250   remove(file_name);
251 }
252 
TEST(PytorchStreamWriterAndReader,SkipDebugRecords)253 TEST(PytorchStreamWriterAndReader, SkipDebugRecords) {
254   std::ostringstream oss;
255   PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
256     oss.write(static_cast<const char*>(b), n);
257     return oss ? n : 0;
258   });
259   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
260   std::array<char, 127> data1;
261   // Inplace memory buffer
262   std::vector<uint8_t> buf(data1.size());
263 
264   for (auto i : c10::irange(data1.size())) {
265     data1[i] = data1.size() - i;
266   }
267   writer.writeRecord("key1.debug_pkl", data1.data(), data1.size());
268 
269   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
270   std::array<char, 64> data2;
271   for (auto i : c10::irange(data2.size())) {
272     data2[i] = data2.size() - i;
273   }
274   writer.writeRecord("key2.debug_pkl", data2.data(), data2.size());
275 
276   const std::unordered_set<std::string>& written_records =
277       writer.getAllWrittenRecords();
278   ASSERT_EQ(written_records.size(), 2);
279   ASSERT_EQ(written_records.count("key1.debug_pkl"), 1);
280   ASSERT_EQ(written_records.count("key2.debug_pkl"), 1);
281   writer.writeEndOfFile();
282   ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);
283 
284   std::string the_file = oss.str();
285   const char* file_name = "output3.zip";
286   std::ofstream foo(file_name);
287   foo.write(the_file.c_str(), the_file.size());
288   foo.close();
289 
290   std::istringstream iss(the_file);
291 
292   // read records through readers
293   PyTorchStreamReader reader(&iss);
294   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
295 
296   reader.setShouldLoadDebugSymbol(false);
297   EXPECT_FALSE(reader.hasRecord("key1.debug_pkl"));
298   at::DataPtr ptr;
299   size_t size;
300   std::tie(ptr, size) = reader.getRecord("key1.debug_pkl");
301   EXPECT_EQ(size, 0);
302   std::vector<uint8_t> dst(data1.size());
303   size_t ret = reader.getRecord("key1.debug_pkl", dst.data(), data1.size());
304   EXPECT_EQ(ret, 0);
305   ret = reader.getRecord(
306       "key1.debug_pkl",
307       dst.data(),
308       data1.size(),
309       3,
310       buf.data(),
311       [](void* dst, const void* src, size_t n) { memcpy(dst, src, n); });
312   EXPECT_EQ(ret, 0);
313   // clean up
314   remove(file_name);
315 }
316 
// Serialization ids must be stable: a reader sees the id its writer produced,
// and writing the identical payload twice yields the identical id.
TEST(PytorchStreamWriterAndReader, ValidSerializationId) {
  std::ostringstream archive;
  PyTorchStreamWriter writer([&](const void* bytes, size_t count) -> size_t {
    archive.write(static_cast<const char*>(bytes), count);
    return archive ? count : 0;
  });

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
  std::array<char, 127> payload;
  for (auto idx : c10::irange(payload.size())) {
    payload[idx] = payload.size() - idx;
  }
  writer.writeRecord("key1.debug_pkl", payload.data(), payload.size());
  writer.writeEndOfFile();
  auto first_id = writer.serializationId();

  // A reader over the serialized bytes must report the writer's id.
  std::istringstream input(archive.str());
  PyTorchStreamReader reader(&input);
  EXPECT_EQ(reader.serializationId(), first_id);

  // Serialize the identical payload again; the id must not change.
  PyTorchStreamWriter repeat_writer([&](const void* bytes, size_t count) -> size_t {
    archive.write(static_cast<const char*>(bytes), count);
    return archive ? count : 0;
  });
  repeat_writer.writeRecord("key1.debug_pkl", payload.data(), payload.size());
  repeat_writer.writeEndOfFile();
  EXPECT_EQ(first_id, repeat_writer.serializationId());
}
355 
// A user-written record under kSerializationIdRecordName must be dropped;
// writeEndOfFile() writes the writer's own id, which the reader then reports.
TEST(PytorchStreamWriterAndReader, SkipDuplicateSerializationIdRecords) {
  std::ostringstream oss;
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });

  std::string dup_serialization_id = "dup-serialization-id";
  writer.writeRecord(kSerializationIdRecordName, dup_serialization_id.c_str(), dup_serialization_id.size());

  const std::unordered_set<std::string>& written_records =
      writer.getAllWrittenRecords();
  // The duplicate was skipped, so nothing has been recorded yet.
  ASSERT_EQ(written_records.size(), 0);
  writer.writeEndOfFile();
  ASSERT_EQ(written_records.count(kSerializationIdRecordName), 1);
  auto writer_serialization_id = writer.serializationId();

  std::string the_file = oss.str();
  const char* file_name = "output4.zip";
  // Binary mode: avoid newline translation corrupting the zip on Windows.
  std::ofstream foo(file_name, std::ios::binary);
  foo.write(the_file.c_str(), the_file.size());
  foo.close();

  std::istringstream iss(the_file);

  // read records through readers
  PyTorchStreamReader reader(&iss);

  EXPECT_EQ(reader.serializationId(), writer_serialization_id);
  // clean up
  remove(file_name);
}
389 
// Writer and reader each emit one API-usage metadata event; capture them via
// the metadata logger hook and compare against the expected payloads.
TEST(PytorchStreamWriterAndReader, LogAPIUsageMetadata) {
  std::map<std::string, std::map<std::string, std::string>> captured;

  SetAPIUsageMetadataLogger(
      [&](const std::string& context,
          const std::map<std::string, std::string>& metadata_map) {
        captured.insert({context, metadata_map});
      });
  std::ostringstream archive;
  PyTorchStreamWriter writer([&](const void* bytes, size_t count) -> size_t {
    archive.write(static_cast<const char*>(bytes), count);
    return archive ? count : 0;
  });
  writer.writeEndOfFile();

  std::istringstream input(archive.str());
  // Constructing the reader fires the reader-side metadata event.
  PyTorchStreamReader reader(&input);

  ASSERT_EQ(captured.size(), 2);
  std::map<std::string, std::map<std::string, std::string>> expected = {
      {"pytorch.stream.writer.metadata",
       {{"serialization_id", writer.serializationId()},
        {"file_name", "archive"},
        {"file_size", str(archive.str().length())}}},
      {"pytorch.stream.reader.metadata",
       {{"serialization_id", writer.serializationId()},
        {"file_name", "archive"},
        {"file_size", str(input.str().length())}}}};
  ASSERT_EQ(expected, captured);

  // Restore a no-op logger so later tests are unaffected.
  SetAPIUsageMetadataLogger(
      [&](const std::string& context,
          const std::map<std::string, std::string>& metadata_map) {});
}
427 
428 class ChunkRecordIteratorTest : public ::testing::TestWithParam<int64_t> {};
429 INSTANTIATE_TEST_SUITE_P(
430     ChunkRecordIteratorTestGroup,
431     ChunkRecordIteratorTest,
432     testing::Values(100, 150, 1010));
433 
// Writes a 1000-byte record of 0x01 bytes, then reads it back through the
// chunk iterator and checks that the chunks reassemble the whole payload.
TEST_P(ChunkRecordIteratorTest, ChunkRead) {
  auto chunkSize = GetParam();
  std::string zipFileName = "output_chunk_" + std::to_string(chunkSize) + ".zip";
  const char* fileName = zipFileName.c_str();
  const std::string recordName = "key1";
  const size_t tensorDataSizeInBytes = 1000;

  // Serialize the record into an in-memory archive first.
  std::ostringstream byteSink(std::ios::binary);
  PyTorchStreamWriter writer([&](const void* src, size_t len) -> size_t {
    byteSink.write(static_cast<const char*>(src), len);
    return byteSink ? len : 0;
  });

  std::vector<uint8_t> payload(tensorDataSizeInBytes, 1);
  writer.writeRecord(recordName, payload.data(), tensorDataSizeInBytes);
  const std::unordered_set<std::string>& recorded =
      writer.getAllWrittenRecords();
  ASSERT_EQ(recorded.size(), 1);
  ASSERT_EQ(recorded.count(recordName), 1);
  writer.writeEndOfFile();
  ASSERT_EQ(recorded.count(kSerializationIdRecordName), 1);

  // Persist the archive so the reader can open it by file name.
  std::string archiveBytes = byteSink.str();
  std::ofstream zipOut(fileName, std::ios::binary);
  zipOut.write(archiveBytes.c_str(), archiveBytes.size());
  zipOut.close();
  LOG(INFO) << "Finished saving tensor into zip file " << fileName;

  LOG(INFO) << "Testing chunk size " << chunkSize;
  PyTorchStreamReader reader(fileName);
  ASSERT_TRUE(reader.hasRecord(recordName));
  auto chunkIterator = reader.createChunkReaderIter(
      recordName, tensorDataSizeInBytes, chunkSize);
  std::vector<uint8_t> chunk(chunkSize);
  size_t bytesSeen = 0;
  while (auto readSize = chunkIterator.next(chunk.data())) {
    // Every byte of the record was written as 1.
    std::vector<uint8_t> expected(readSize, 1);
    ASSERT_EQ(memcmp(expected.data(), chunk.data(), readSize), 0);
    bytesSeen += readSize;
  }
  ASSERT_EQ(bytesSeen, tensorDataSizeInBytes);
  // clean up
  remove(fileName);
}
480 
481 } // namespace
482 } // namespace serialize
483 } // namespace caffe2
484