#pragma once

#include <cerrno>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <istream>
#include <mutex>
#include <ostream>
#include <unordered_set>

#include <c10/core/Allocator.h>
#include <c10/core/Backend.h>

#include "caffe2/serialize/istream_adapter.h"
#include "caffe2/serialize/read_adapter_interface.h"
#include "caffe2/serialize/versions.h"


extern "C" {
typedef struct mz_zip_archive mz_zip_archive;
}

// PyTorch containers are a special zip archive with the following layout
// archive_name.zip contains:
//    archive_name/
//        version # a file with a single decimal number written in ascii,
//                # used to establish the version of the archive format
//        model.json # overall model description, this is a json output of
//                   # ModelDef from torch.proto
//        # the following names are by convention only, model.json will
//        # refer to these files by full names
//        tensors/
//          0 # flat storage for tensor data, meta-data about shapes, etc. is
//            # in model.json
//          1
//          ...
//        # code entries will only exist for modules that have methods attached
//        code/
//          archive_name.py # serialized torch script code (python syntax,
//                          # using PythonPrint)
//          archive_name_my_submodule.py # submodules have separate files
//
// The PyTorchStreamWriter also ensures additional useful properties for these
// files:
// 1. All files are stored uncompressed.
// 2. All files in the archive are aligned to 64 byte boundaries such that
//    it is possible to mmap the entire file and get an aligned pointer to
//    tensor data.
// 3. We universally write in ZIP64 format for consistency.

// The PyTorchStreamReader also provides additional properties:
// 1. It can read zip files that are created with common
//    zip tools. This means that even though our writer doesn't compress files,
//    the reader can still read files that were compressed.
// 2. It provides a getRecordOffset function which returns the offset into the
//    raw file where file data lives. If the file was written with
//    PyTorchStreamWriter it is guaranteed to be 64 byte aligned.
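//
// For example, combining property (2) of the writer with getRecordOffset, a
// reader that has mmap'd the whole archive can obtain an aligned pointer to a
// tensor's bytes directly. An illustrative sketch (the file and record names
// are hypothetical):
//
//   PyTorchStreamReader reader("model.pt");
//   size_t off = reader.getRecordOffset("archive_name/tensors/0");
//   // If `base` is the pointer returned by mmap(2) for the whole file,
//   // `base + off` is 64-byte aligned when the archive was produced by
//   // PyTorchStreamWriter.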

// PyTorchReader/Writer handle checking the version number on the archive
// format and ensure that all files are written to an archive_name directory
// so they unzip cleanly.

// When developing this format we want to pay particular attention to the
// following use cases:
//
// -- Reading --
// 1) Reading with full random access
//   a) Reading with file APIs such as fread()
//   b) mmaping the file and jumping around the mapped region
// 2) Reading with 1-pass sequential access
//      -> A reader will need to build up a data structure of parsed structures
//         as it reads
//
// -- Writing --
// 1) Writing with full random access
// 2) Writing with 1-pass sequential access
//      -> We must take care not to require updating values that have already
//         been written. We place the variable-length index at the end and do
//         not put any indices into the header to fulfill this constraint.

// The model.json, which contains all the metadata information,
// should be written as the last file. One reason is that the size of tensor
// data is usually stable: as long as the shape and type of the tensor do not
// change, the size of the data won't change. The size of the serialized model,
// on the other hand, is likely to change, so we store it as the last record
// and don't need to move previous records when updating the model data.

// The zip format is sufficiently flexible to handle the above use cases:
// it puts its central directory at the end of the archive, and we write
// model.json as the last file, after we have accumulated all other
// information.
namespace caffe2 {
namespace serialize {

static constexpr const char* kSerializationIdRecordName = ".data/serialization_id";

struct MzZipReaderIterWrapper;

class TORCH_API ChunkRecordIterator {
 public:
  ~ChunkRecordIterator();

  // Reads at most `chunkSize` bytes into `buf`. Returns the number of bytes
  // actually read.
  size_t next(void* buf);
  size_t recordSize() const { return recordSize_; }

 private:
  ChunkRecordIterator(
      size_t recordSize,
      size_t chunkSize,
      std::unique_ptr<MzZipReaderIterWrapper> iter);

  const size_t recordSize_;
  const size_t chunkSize_;
  size_t offset_;
  std::unique_ptr<MzZipReaderIterWrapper> iter_;

  friend class PyTorchStreamReader;
};

class TORCH_API PyTorchStreamReader final {
 public:
  explicit PyTorchStreamReader(const std::string& file_name);
  explicit PyTorchStreamReader(std::istream* in);
  explicit PyTorchStreamReader(std::shared_ptr<ReadAdapterInterface> in);

  // Returns a tuple of (data pointer, size).
  std::tuple<at::DataPtr, size_t> getRecord(const std::string& name);
  // Multi-threaded getRecord.
  std::tuple<at::DataPtr, size_t> getRecord(
      const std::string& name,
      std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
  // In-place memory writing.
  size_t getRecord(const std::string& name, void* dst, size_t n);
  // In-place memory writing, multi-threaded.
  // When additionalReaders is empty, the default behavior is to call
  // getRecord(name, dst, n) with the default reader.
  // This approach can be used for reading large tensors.
  size_t getRecord(
      const std::string& name,
      void* dst,
      size_t n,
      std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
  size_t getRecord(
      const std::string& name,
      void* dst,
      size_t n,
      size_t chunk_size,
      void* buf,
      const std::function<void(void*, const void*, size_t)>& memcpy_func = nullptr);
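
  // An illustrative sketch of the basic overload (file and record names are
  // hypothetical):
  //
  //   PyTorchStreamReader reader("model.pt");
  //   auto [data, size] = reader.getRecord("data.pkl");
  //   // `data` is an at::DataPtr owning the bytes; `size` is the record
  //   // length in bytes.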

  // Concurrent reading of a record with multiple readers.
  // additionalReaders are additional clients that access the underlying
  // record at different offsets and write to different chunks of the buffer.
  // For example, if the overall size of the tensor is 10 bytes and there are
  // two additional readers, the work is split three ways: the default reader
  // reads [0, 4) and writes buffer[0, 4), the first additional reader reads
  // [4, 8) and writes buffer[4, 8), and the second additional reader reads
  // [8, 10) and writes buffer[8, 10).
  // When additionalReaders is empty, the default behavior is to call
  // getRecord(name) with the default reader.
  // This approach can be used for reading large tensors.
  size_t getRecordMultiReaders(
      const std::string& name,
      std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
      void* dst,
      size_t n);
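
  // A minimal sketch of the assumed split (ceiling division; not necessarily
  // the exact implementation):
  //
  //   size_t numReaders = additionalReaders.size() + 1;     // e.g. 3
  //   size_t perReader = (n + numReaders - 1) / numReaders; // (10 + 2) / 3 == 4
  //   // Reader i covers dst[i * perReader, min((i + 1) * perReader, n)).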

  size_t getRecordSize(const std::string& name);

  size_t getRecordOffset(const std::string& name);
  bool hasRecord(const std::string& name);
  std::vector<std::string> getAllRecords();

  ChunkRecordIterator createChunkReaderIter(
      const std::string& name,
      const size_t recordSize,
      const size_t chunkSize);
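
  // An illustrative sketch of chunked reading (record name and chunk size are
  // hypothetical; `next` is assumed to return 0 once the record is exhausted):
  //
  //   auto it = reader.createChunkReaderIter(
  //       "data/0", reader.getRecordSize("data/0"), /*chunkSize=*/4096);
  //   std::vector<char> buf(4096);
  //   while (size_t nread = it.next(buf.data())) {
  //     // consume `nread` bytes of the record from `buf`
  //   }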

  ~PyTorchStreamReader();
  uint64_t version() const {
    return version_;
  }
  const std::string& serializationId() {
    return serialization_id_;
  }

  void setShouldLoadDebugSymbol(bool should_load_debug_symbol) {
    load_debug_symbol_ = should_load_debug_symbol;
  }
  void setAdditionalReaderSizeThreshold(const size_t& size) {
    additional_reader_size_threshold_ = size;
  }
 private:
  void init();
  size_t read(uint64_t pos, char* buf, size_t n);
  void valid(const char* what, const char* info = "");
  size_t getRecordID(const std::string& name);

  friend size_t
  istream_read_func(void* pOpaque, uint64_t file_ofs, void* pBuf, size_t n);
  std::unique_ptr<mz_zip_archive> ar_;
  std::string archive_name_;
  std::string archive_name_plus_slash_;
  std::shared_ptr<ReadAdapterInterface> in_;
  int64_t version_;
  std::mutex reader_lock_;
  bool load_debug_symbol_ = true;
  std::string serialization_id_;
  size_t additional_reader_size_threshold_;
};

class TORCH_API PyTorchStreamWriter final {
 public:
  explicit PyTorchStreamWriter(const std::string& archive_name);
  explicit PyTorchStreamWriter(
      const std::function<size_t(const void*, size_t)> writer_func);

  void setMinVersion(const uint64_t version);

  void writeRecord(
      const std::string& name,
      const void* data,
      size_t size,
      bool compress = false);
  void writeEndOfFile();
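
  // An illustrative sketch of writing an archive (file name, record name, and
  // payload are hypothetical):
  //
  //   PyTorchStreamWriter writer("model.pt");
  //   std::string payload = "...";
  //   writer.writeRecord("data.pkl", payload.data(), payload.size());
  //   writer.writeEndOfFile();  // writes the central directory and finalizes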

  const std::unordered_set<std::string>& getAllWrittenRecords();

  bool finalized() const {
    return finalized_;
  }

  const std::string& archiveName() {
    return archive_name_;
  }

  const std::string& serializationId() {
    return serialization_id_;
  }

  ~PyTorchStreamWriter();

 private:
  void setup(const std::string& file_name);
  void valid(const char* what, const char* info = "");
  void writeSerializationId();
  size_t current_pos_ = 0;
  std::unordered_set<std::string> files_written_;
  std::unique_ptr<mz_zip_archive> ar_;
  std::string archive_name_;
  std::string archive_name_plus_slash_;
  std::string padding_;
  std::ofstream file_stream_;
  std::function<size_t(const void*, size_t)> writer_func_;
  uint64_t combined_uncomp_crc32_ = 0;
  std::string serialization_id_;

  // This number will be updated when the model has operators
  // that have valid upgraders.
  uint64_t version_ = kMinProducedFileFormatVersion;
  bool finalized_ = false;
  bool err_seen_ = false;
  friend size_t ostream_write_func(
      void* pOpaque,
      uint64_t file_ofs,
      const void* pBuf,
      size_t n);
};

namespace detail {
// Writer-specific constants
constexpr uint64_t kFieldAlignment = 64;

// Computes the padding to be appended to the local user extra data entry so
// that the record's data begins at a kFieldAlignment-byte boundary. The
// padding bytes are written into `padding_buf` and their count is returned.
size_t getPadding(
    size_t cursor,
    size_t filename_size,
    size_t size,
    std::string& padding_buf);
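
// For example, with kFieldAlignment == 64: if a record's data would otherwise
// begin at archive offset 100, 28 bytes of padding move the data start to
// 128. Conceptually (an assumed simplification; the real computation also
// accounts for the local header and extra-field overhead):
//
//   padding = (kFieldAlignment - offset % kFieldAlignment) % kFieldAlignment
//   // (64 - 100 % 64) % 64 == (64 - 36) % 64 == 28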
} // namespace detail

} // namespace serialize
} // namespace caffe2