xref: /aosp_15_r20/external/pytorch/caffe2/serialize/inline_container.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <cerrno>
4*da0073e9SAndroid Build Coastguard Worker #include <cstdio>
5*da0073e9SAndroid Build Coastguard Worker #include <cstring>
6*da0073e9SAndroid Build Coastguard Worker #include <fstream>
7*da0073e9SAndroid Build Coastguard Worker #include <istream>
8*da0073e9SAndroid Build Coastguard Worker #include <mutex>
9*da0073e9SAndroid Build Coastguard Worker #include <ostream>
10*da0073e9SAndroid Build Coastguard Worker #include <unordered_set>
11*da0073e9SAndroid Build Coastguard Worker 
12*da0073e9SAndroid Build Coastguard Worker #include <c10/core/Allocator.h>
13*da0073e9SAndroid Build Coastguard Worker #include <c10/core/Backend.h>
14*da0073e9SAndroid Build Coastguard Worker 
15*da0073e9SAndroid Build Coastguard Worker #include "caffe2/serialize/istream_adapter.h"
16*da0073e9SAndroid Build Coastguard Worker #include "caffe2/serialize/read_adapter_interface.h"
17*da0073e9SAndroid Build Coastguard Worker #include "caffe2/serialize/versions.h"
18*da0073e9SAndroid Build Coastguard Worker 
19*da0073e9SAndroid Build Coastguard Worker 
// Opaque forward declaration of the zip archive handle (presumably miniz's
// mz_zip_archive) so this header does not have to include the zip library.
// The reader and writer below only hold it through std::unique_ptr, and their
// destructors are declared out of line, where the full type is visible.
extern "C" {
typedef struct mz_zip_archive mz_zip_archive;
}
23*da0073e9SAndroid Build Coastguard Worker 
24*da0073e9SAndroid Build Coastguard Worker // PyTorch containers are a special zip archive with the following layout
25*da0073e9SAndroid Build Coastguard Worker // archive_name.zip contains:
26*da0073e9SAndroid Build Coastguard Worker //    archive_name/
27*da0073e9SAndroid Build Coastguard Worker //        version # a file with a single decimal number written in ascii,
28*da0073e9SAndroid Build Coastguard Worker //                # used to establish the version of the archive format
29*da0073e9SAndroid Build Coastguard Worker //        model.json # overall model description, this is a json output of
30*da0073e9SAndroid Build Coastguard Worker //                   # ModelDef from torch.proto
31*da0073e9SAndroid Build Coastguard Worker //        # the following names are by convention only, model.json will
32*da0073e9SAndroid Build Coastguard Worker //        # refer to these files by full names
33*da0073e9SAndroid Build Coastguard Worker //        tensors/
34*da0073e9SAndroid Build Coastguard Worker //          0 # flat storage for tensor data, meta-data about shapes, etc. is
35*da0073e9SAndroid Build Coastguard Worker //            # in model.json
36*da0073e9SAndroid Build Coastguard Worker //          1
37*da0073e9SAndroid Build Coastguard Worker //          ...
38*da0073e9SAndroid Build Coastguard Worker //        # code entries will only exist for modules that have methods attached
39*da0073e9SAndroid Build Coastguard Worker //        code/
//          archive_name.py # serialized TorchScript code (Python syntax,
//                          # produced by PythonPrint)
//          archive_name_my_submodule.py # submodules have separate files
43*da0073e9SAndroid Build Coastguard Worker //
// The PyTorchStreamWriter also guarantees the following additional properties
// for these files:
46*da0073e9SAndroid Build Coastguard Worker // 1. All files are stored uncompressed.
47*da0073e9SAndroid Build Coastguard Worker // 2. All files in the archive are aligned to 64 byte boundaries such that
48*da0073e9SAndroid Build Coastguard Worker //    it is possible to mmap the entire file and get an aligned pointer to
49*da0073e9SAndroid Build Coastguard Worker //    tensor data.
50*da0073e9SAndroid Build Coastguard Worker // 3. We universally write in ZIP64 format for consistency.
51*da0073e9SAndroid Build Coastguard Worker 
52*da0073e9SAndroid Build Coastguard Worker // The PyTorchStreamReader also provides additional properties:
53*da0073e9SAndroid Build Coastguard Worker // 1. It can read zip files that are created with common
54*da0073e9SAndroid Build Coastguard Worker //    zip tools. This means that even though our writer doesn't compress files,
55*da0073e9SAndroid Build Coastguard Worker //    the reader can still read files that were compressed.
56*da0073e9SAndroid Build Coastguard Worker // 2. It provides a getRecordOffset function which returns the offset into the
57*da0073e9SAndroid Build Coastguard Worker //    raw file where file data lives. If the file was written with
58*da0073e9SAndroid Build Coastguard Worker //    PyTorchStreamWriter it is guaranteed to be 64 byte aligned.
59*da0073e9SAndroid Build Coastguard Worker 
// PyTorchStreamReader/Writer handle checking the version number of the archive
// format and ensure that all files are written to an archive_name/ directory so
// they unzip cleanly.
63*da0073e9SAndroid Build Coastguard Worker 
64*da0073e9SAndroid Build Coastguard Worker // When developing this format we want to pay particular attention to the
65*da0073e9SAndroid Build Coastguard Worker // following use cases:
66*da0073e9SAndroid Build Coastguard Worker //
67*da0073e9SAndroid Build Coastguard Worker // -- Reading --
68*da0073e9SAndroid Build Coastguard Worker // 1) Reading with full random access
69*da0073e9SAndroid Build Coastguard Worker //   a) Reading with file api's such as fread()
70*da0073e9SAndroid Build Coastguard Worker //   b) mmaping the file and jumping around the mapped region
71*da0073e9SAndroid Build Coastguard Worker // 2) Reading with 1-pass sequential access
72*da0073e9SAndroid Build Coastguard Worker //      -> A reader will need to build up a data structure of parsed structures
73*da0073e9SAndroid Build Coastguard Worker //         as it reads
74*da0073e9SAndroid Build Coastguard Worker //
75*da0073e9SAndroid Build Coastguard Worker // -- Writing --
76*da0073e9SAndroid Build Coastguard Worker // 1) Writing with full random access
77*da0073e9SAndroid Build Coastguard Worker // 2) Writing with 1-pass sequential access
78*da0073e9SAndroid Build Coastguard Worker //      -> We must take care not to require updating values that have already
//         not put any indices into the header to fulfill this constraint.
80*da0073e9SAndroid Build Coastguard Worker //         not put any indicies into the header to fulfill this constraint.
81*da0073e9SAndroid Build Coastguard Worker 
82*da0073e9SAndroid Build Coastguard Worker // The model.json, which contains all the metadata information,
83*da0073e9SAndroid Build Coastguard Worker // should be written as the last file. One reason is that the size of tensor
// change, the size of the data won't change. On the other side, the size of the
85*da0073e9SAndroid Build Coastguard Worker // change, the size of the data won't change. On the other sied, the size of the
86*da0073e9SAndroid Build Coastguard Worker // serialized model is likely to change, so we store it as the last record, and
87*da0073e9SAndroid Build Coastguard Worker // we don't need to move previous records when updating the model data.
88*da0073e9SAndroid Build Coastguard Worker 
// The zip format is sufficiently flexible to handle the above use cases:
// it puts its central directory at the end of the archive, and we write
// model.json as the last file, after we have accumulated all other
// information.
93*da0073e9SAndroid Build Coastguard Worker 
94*da0073e9SAndroid Build Coastguard Worker namespace caffe2 {
95*da0073e9SAndroid Build Coastguard Worker namespace serialize {
96*da0073e9SAndroid Build Coastguard Worker 
// Archive-relative name of the record that stores the serialization id
// (presumably the value surfaced by serializationId() on the reader/writer).
static constexpr char const* kSerializationIdRecordName =
    ".data/serialization_id";

// Opaque wrapper around the zip reader's iterator state; its definition lives
// outside this header, so it is only usable through pointers here.
struct MzZipReaderIterWrapper;
100*da0073e9SAndroid Build Coastguard Worker 
101*da0073e9SAndroid Build Coastguard Worker class TORCH_API ChunkRecordIterator {
102*da0073e9SAndroid Build Coastguard Worker  public:
103*da0073e9SAndroid Build Coastguard Worker   ~ChunkRecordIterator();
104*da0073e9SAndroid Build Coastguard Worker 
105*da0073e9SAndroid Build Coastguard Worker   // Read at most `chunkSize` into `buf`. Return the number of actual bytes read.
106*da0073e9SAndroid Build Coastguard Worker   size_t next(void* buf);
recordSize()107*da0073e9SAndroid Build Coastguard Worker   size_t recordSize() const { return recordSize_; }
108*da0073e9SAndroid Build Coastguard Worker 
109*da0073e9SAndroid Build Coastguard Worker  private:
110*da0073e9SAndroid Build Coastguard Worker  ChunkRecordIterator(
111*da0073e9SAndroid Build Coastguard Worker       size_t recordSize,
112*da0073e9SAndroid Build Coastguard Worker       size_t chunkSize,
113*da0073e9SAndroid Build Coastguard Worker       std::unique_ptr<MzZipReaderIterWrapper> iter);
114*da0073e9SAndroid Build Coastguard Worker 
115*da0073e9SAndroid Build Coastguard Worker   const size_t recordSize_;
116*da0073e9SAndroid Build Coastguard Worker   const size_t chunkSize_;
117*da0073e9SAndroid Build Coastguard Worker   size_t offset_;
118*da0073e9SAndroid Build Coastguard Worker   std::unique_ptr<MzZipReaderIterWrapper> iter_;
119*da0073e9SAndroid Build Coastguard Worker 
120*da0073e9SAndroid Build Coastguard Worker   friend class PyTorchStreamReader;
121*da0073e9SAndroid Build Coastguard Worker };
122*da0073e9SAndroid Build Coastguard Worker 
123*da0073e9SAndroid Build Coastguard Worker class TORCH_API PyTorchStreamReader final {
124*da0073e9SAndroid Build Coastguard Worker  public:
125*da0073e9SAndroid Build Coastguard Worker   explicit PyTorchStreamReader(const std::string& file_name);
126*da0073e9SAndroid Build Coastguard Worker   explicit PyTorchStreamReader(std::istream* in);
127*da0073e9SAndroid Build Coastguard Worker   explicit PyTorchStreamReader(std::shared_ptr<ReadAdapterInterface> in);
128*da0073e9SAndroid Build Coastguard Worker 
129*da0073e9SAndroid Build Coastguard Worker   // return dataptr, size
130*da0073e9SAndroid Build Coastguard Worker   std::tuple<at::DataPtr, size_t> getRecord(const std::string& name);
131*da0073e9SAndroid Build Coastguard Worker   // multi-thread getRecord
132*da0073e9SAndroid Build Coastguard Worker   std::tuple<at::DataPtr, size_t> getRecord(const std::string& name, std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
133*da0073e9SAndroid Build Coastguard Worker   // inplace memory writing
134*da0073e9SAndroid Build Coastguard Worker   size_t getRecord(const std::string& name, void* dst, size_t n);
135*da0073e9SAndroid Build Coastguard Worker   // inplace memory writing, multi-threads.
136*da0073e9SAndroid Build Coastguard Worker   // When additionalReaders is empty, the default behavior is call getRecord(name, dst, n) with default reader
137*da0073e9SAndroid Build Coastguard Worker   // This approach can be used for reading large tensors.
138*da0073e9SAndroid Build Coastguard Worker   size_t getRecord(const std::string& name, void* dst, size_t n,
139*da0073e9SAndroid Build Coastguard Worker     std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders);
140*da0073e9SAndroid Build Coastguard Worker   size_t getRecord(
141*da0073e9SAndroid Build Coastguard Worker       const std::string& name,
142*da0073e9SAndroid Build Coastguard Worker       void* dst,
143*da0073e9SAndroid Build Coastguard Worker       size_t n,
144*da0073e9SAndroid Build Coastguard Worker       size_t chunk_size,
145*da0073e9SAndroid Build Coastguard Worker       void* buf,
146*da0073e9SAndroid Build Coastguard Worker       const std::function<void(void*, const void*, size_t)>& memcpy_func = nullptr);
147*da0073e9SAndroid Build Coastguard Worker 
148*da0073e9SAndroid Build Coastguard Worker   // Concurrent reading records with multiple readers.
149*da0073e9SAndroid Build Coastguard Worker   // additionalReaders are additional clients to access the underlying record at different offsets
150*da0073e9SAndroid Build Coastguard Worker   // and write to different trunks of buffers.
151*da0073e9SAndroid Build Coastguard Worker   // If the overall size of the tensor is 10, and size of additionalReader is 2.
152*da0073e9SAndroid Build Coastguard Worker   // The default thread will read [0,4), the additional reader will read [4,8).
153*da0073e9SAndroid Build Coastguard Worker   // The default reader will read [8,10).
154*da0073e9SAndroid Build Coastguard Worker   // The default reader will write to buffer[0,4), the additional reader will write to buffer[4,8),
155*da0073e9SAndroid Build Coastguard Worker   // the additional reader will write to buffer[8,10).
156*da0073e9SAndroid Build Coastguard Worker   // When additionalReaders is empty, the default behavior is call getRecord(name) with default reader
157*da0073e9SAndroid Build Coastguard Worker   // This approach can be used for reading large tensors.
158*da0073e9SAndroid Build Coastguard Worker   size_t getRecordMultiReaders(const std::string& name,
159*da0073e9SAndroid Build Coastguard Worker   std::vector<std::shared_ptr<ReadAdapterInterface>>& additionalReaders,
160*da0073e9SAndroid Build Coastguard Worker   void *dst, size_t n);
161*da0073e9SAndroid Build Coastguard Worker 
162*da0073e9SAndroid Build Coastguard Worker   size_t getRecordSize(const std::string& name);
163*da0073e9SAndroid Build Coastguard Worker 
164*da0073e9SAndroid Build Coastguard Worker   size_t getRecordOffset(const std::string& name);
165*da0073e9SAndroid Build Coastguard Worker   bool hasRecord(const std::string& name);
166*da0073e9SAndroid Build Coastguard Worker   std::vector<std::string> getAllRecords();
167*da0073e9SAndroid Build Coastguard Worker 
168*da0073e9SAndroid Build Coastguard Worker   ChunkRecordIterator createChunkReaderIter(
169*da0073e9SAndroid Build Coastguard Worker       const std::string& name,
170*da0073e9SAndroid Build Coastguard Worker       const size_t recordSize,
171*da0073e9SAndroid Build Coastguard Worker       const size_t chunkSize);
172*da0073e9SAndroid Build Coastguard Worker 
173*da0073e9SAndroid Build Coastguard Worker   ~PyTorchStreamReader();
version()174*da0073e9SAndroid Build Coastguard Worker   uint64_t version() const {
175*da0073e9SAndroid Build Coastguard Worker     return version_;
176*da0073e9SAndroid Build Coastguard Worker   }
serializationId()177*da0073e9SAndroid Build Coastguard Worker   const std::string& serializationId() {
178*da0073e9SAndroid Build Coastguard Worker     return serialization_id_;
179*da0073e9SAndroid Build Coastguard Worker   }
180*da0073e9SAndroid Build Coastguard Worker 
setShouldLoadDebugSymbol(bool should_load_debug_symbol)181*da0073e9SAndroid Build Coastguard Worker   void setShouldLoadDebugSymbol(bool should_load_debug_symbol) {
182*da0073e9SAndroid Build Coastguard Worker     load_debug_symbol_ = should_load_debug_symbol;
183*da0073e9SAndroid Build Coastguard Worker   }
setAdditionalReaderSizeThreshold(const size_t & size)184*da0073e9SAndroid Build Coastguard Worker   void setAdditionalReaderSizeThreshold(const size_t& size){
185*da0073e9SAndroid Build Coastguard Worker     additional_reader_size_threshold_ = size;
186*da0073e9SAndroid Build Coastguard Worker   }
187*da0073e9SAndroid Build Coastguard Worker  private:
188*da0073e9SAndroid Build Coastguard Worker   void init();
189*da0073e9SAndroid Build Coastguard Worker   size_t read(uint64_t pos, char* buf, size_t n);
190*da0073e9SAndroid Build Coastguard Worker   void valid(const char* what, const char* info = "");
191*da0073e9SAndroid Build Coastguard Worker   size_t getRecordID(const std::string& name);
192*da0073e9SAndroid Build Coastguard Worker 
193*da0073e9SAndroid Build Coastguard Worker   friend size_t
194*da0073e9SAndroid Build Coastguard Worker   istream_read_func(void* pOpaque, uint64_t file_ofs, void* pBuf, size_t n);
195*da0073e9SAndroid Build Coastguard Worker   std::unique_ptr<mz_zip_archive> ar_;
196*da0073e9SAndroid Build Coastguard Worker   std::string archive_name_;
197*da0073e9SAndroid Build Coastguard Worker   std::string archive_name_plus_slash_;
198*da0073e9SAndroid Build Coastguard Worker   std::shared_ptr<ReadAdapterInterface> in_;
199*da0073e9SAndroid Build Coastguard Worker   int64_t version_;
200*da0073e9SAndroid Build Coastguard Worker   std::mutex reader_lock_;
201*da0073e9SAndroid Build Coastguard Worker   bool load_debug_symbol_ = true;
202*da0073e9SAndroid Build Coastguard Worker   std::string serialization_id_;
203*da0073e9SAndroid Build Coastguard Worker   size_t additional_reader_size_threshold_;
204*da0073e9SAndroid Build Coastguard Worker };
205*da0073e9SAndroid Build Coastguard Worker 
206*da0073e9SAndroid Build Coastguard Worker class TORCH_API PyTorchStreamWriter final {
207*da0073e9SAndroid Build Coastguard Worker  public:
208*da0073e9SAndroid Build Coastguard Worker   explicit PyTorchStreamWriter(const std::string& archive_name);
209*da0073e9SAndroid Build Coastguard Worker   explicit PyTorchStreamWriter(
210*da0073e9SAndroid Build Coastguard Worker       const std::function<size_t(const void*, size_t)> writer_func);
211*da0073e9SAndroid Build Coastguard Worker 
212*da0073e9SAndroid Build Coastguard Worker   void setMinVersion(const uint64_t version);
213*da0073e9SAndroid Build Coastguard Worker 
214*da0073e9SAndroid Build Coastguard Worker   void writeRecord(
215*da0073e9SAndroid Build Coastguard Worker       const std::string& name,
216*da0073e9SAndroid Build Coastguard Worker       const void* data,
217*da0073e9SAndroid Build Coastguard Worker       size_t size,
218*da0073e9SAndroid Build Coastguard Worker       bool compress = false);
219*da0073e9SAndroid Build Coastguard Worker   void writeEndOfFile();
220*da0073e9SAndroid Build Coastguard Worker 
221*da0073e9SAndroid Build Coastguard Worker   const std::unordered_set<std::string>& getAllWrittenRecords();
222*da0073e9SAndroid Build Coastguard Worker 
finalized()223*da0073e9SAndroid Build Coastguard Worker   bool finalized() const {
224*da0073e9SAndroid Build Coastguard Worker     return finalized_;
225*da0073e9SAndroid Build Coastguard Worker   }
226*da0073e9SAndroid Build Coastguard Worker 
archiveName()227*da0073e9SAndroid Build Coastguard Worker   const std::string& archiveName() {
228*da0073e9SAndroid Build Coastguard Worker     return archive_name_;
229*da0073e9SAndroid Build Coastguard Worker   }
230*da0073e9SAndroid Build Coastguard Worker 
serializationId()231*da0073e9SAndroid Build Coastguard Worker   const std::string& serializationId() {
232*da0073e9SAndroid Build Coastguard Worker     return serialization_id_;
233*da0073e9SAndroid Build Coastguard Worker   }
234*da0073e9SAndroid Build Coastguard Worker 
235*da0073e9SAndroid Build Coastguard Worker   ~PyTorchStreamWriter();
236*da0073e9SAndroid Build Coastguard Worker 
237*da0073e9SAndroid Build Coastguard Worker  private:
238*da0073e9SAndroid Build Coastguard Worker   void setup(const std::string& file_name);
239*da0073e9SAndroid Build Coastguard Worker   void valid(const char* what, const char* info = "");
240*da0073e9SAndroid Build Coastguard Worker   void writeSerializationId();
241*da0073e9SAndroid Build Coastguard Worker   size_t current_pos_ = 0;
242*da0073e9SAndroid Build Coastguard Worker   std::unordered_set<std::string> files_written_;
243*da0073e9SAndroid Build Coastguard Worker   std::unique_ptr<mz_zip_archive> ar_;
244*da0073e9SAndroid Build Coastguard Worker   std::string archive_name_;
245*da0073e9SAndroid Build Coastguard Worker   std::string archive_name_plus_slash_;
246*da0073e9SAndroid Build Coastguard Worker   std::string padding_;
247*da0073e9SAndroid Build Coastguard Worker   std::ofstream file_stream_;
248*da0073e9SAndroid Build Coastguard Worker   std::function<size_t(const void*, size_t)> writer_func_;
249*da0073e9SAndroid Build Coastguard Worker   uint64_t combined_uncomp_crc32_ = 0;
250*da0073e9SAndroid Build Coastguard Worker   std::string serialization_id_;
251*da0073e9SAndroid Build Coastguard Worker 
252*da0073e9SAndroid Build Coastguard Worker   // This number will be updated when the model has operators
253*da0073e9SAndroid Build Coastguard Worker   // that have valid upgraders.
254*da0073e9SAndroid Build Coastguard Worker   uint64_t version_ = kMinProducedFileFormatVersion;
255*da0073e9SAndroid Build Coastguard Worker   bool finalized_ = false;
256*da0073e9SAndroid Build Coastguard Worker   bool err_seen_ = false;
257*da0073e9SAndroid Build Coastguard Worker   friend size_t ostream_write_func(
258*da0073e9SAndroid Build Coastguard Worker       void* pOpaque,
259*da0073e9SAndroid Build Coastguard Worker       uint64_t file_ofs,
260*da0073e9SAndroid Build Coastguard Worker       const void* pBuf,
261*da0073e9SAndroid Build Coastguard Worker       size_t n);
262*da0073e9SAndroid Build Coastguard Worker };
263*da0073e9SAndroid Build Coastguard Worker 
namespace detail {

// Writer-only constant: byte boundary on which every record's data must start.
constexpr uint64_t kFieldAlignment{64};

// Builds, in `padding_buf`, the record to append to a local header's user
// extra field so that the data that follows begins on a kFieldAlignment
// boundary. `cursor` is the current write position, and `filename_size` and
// `size` describe the record being written. The returned size_t is presumably
// the resulting padding length -- confirm in the implementation.
size_t getPadding(
    size_t cursor,
    size_t filename_size,
    size_t size,
    std::string& padding_buf);

} // namespace detail
276*da0073e9SAndroid Build Coastguard Worker 
277*da0073e9SAndroid Build Coastguard Worker } // namespace serialize
278*da0073e9SAndroid Build Coastguard Worker } // namespace caffe2
279