xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/file_format.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <array>
4 #include <cerrno>
5 #include <cstddef>
6 #include <cstring>
7 #include <fstream>
8 #include <istream>
9 #include <memory>
10 
11 #include <c10/core/CPUAllocator.h>
12 #include <c10/core/impl/alloc_cpu.h>
13 #include <caffe2/serialize/read_adapter_interface.h>
14 
15 #if defined(HAVE_MMAP)
16 #include <fcntl.h>
17 #include <sys/mman.h>
18 #include <sys/stat.h>
19 #include <sys/types.h>
20 #include <unistd.h>
21 #endif
22 
/**
 * @file
 *
 * Helpers for identifying file formats when reading serialized data, and for
 * loading the raw bytes of a file, stream, or ReadAdapterInterface into memory.
 *
 * Note that these functions are declared inline because they will typically
 * only be called from one or two locations per binary.
 */
31 
32 namespace torch::jit {
33 
/**
 * Serialization container formats that getFileFormat() can tell apart.
 */
enum class FileFormat {
  /// The header matched none of the known magic byte sequences.
  UnknownFileFormat = 0,
  /// A Flatbuffer-serialized mobile Module ("PTMF" at bytes 4..7).
  FlatbufferFileFormat = 1,
  /// A ZIP archive ("PK\x03\x04" at bytes 0..3).
  ZipFileFormat = 2,
};
42 
/// Number of leading bytes examined by #getFileFormat(); a buffer handed to
/// the `const char*` overload must hold at least this many bytes.
constexpr std::size_t kFileFormatHeaderSize{8};
/// The heap-allocating content loaders in this file size their buffers to a
/// multiple of this many bytes (always leaving at least one padding byte).
constexpr std::size_t kMaxAlignment{16};
46 
47 /**
48  * Returns the likely file format based on the magic header bytes in @p header,
49  * which should contain the first bytes of a file or data stream.
50  */
51 // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
getFileFormat(const char * data)52 static inline FileFormat getFileFormat(const char* data) {
53   // The size of magic strings to look for in the buffer.
54   static constexpr size_t kMagicSize = 4;
55 
56   // Bytes 4..7 of a Flatbuffer-encoded file produced by
57   // `flatbuffer_serializer.h`. (The first four bytes contain an offset to the
58   // actual Flatbuffer data.)
59   static constexpr std::array<char, kMagicSize> kFlatbufferMagicString = {
60       'P', 'T', 'M', 'F'};
61   static constexpr size_t kFlatbufferMagicOffset = 4;
62 
63   // The first four bytes of a ZIP file.
64   static constexpr std::array<char, kMagicSize> kZipMagicString = {
65       'P', 'K', '\x03', '\x04'};
66 
67   // Note that we check for Flatbuffer magic first. Since the first four bytes
68   // of flatbuffer data contain an offset to the root struct, it's theoretically
69   // possible to construct a file whose offset looks like the ZIP magic. On the
70   // other hand, bytes 4-7 of ZIP files are constrained to a small set of values
71   // that do not typically cross into the printable ASCII range, so a ZIP file
72   // should never have a header that looks like a Flatbuffer file.
73   if (std::memcmp(
74           data + kFlatbufferMagicOffset,
75           kFlatbufferMagicString.data(),
76           kMagicSize) == 0) {
77     // Magic header for a binary file containing a Flatbuffer-serialized mobile
78     // Module.
79     return FileFormat::FlatbufferFileFormat;
80   } else if (std::memcmp(data, kZipMagicString.data(), kMagicSize) == 0) {
81     // Magic header for a zip file, which we use to store pickled sub-files.
82     return FileFormat::ZipFileFormat;
83   }
84   return FileFormat::UnknownFileFormat;
85 }
86 
87 /**
88  * Returns the likely file format based on the magic header bytes of @p data.
89  * If the stream position changes while inspecting the data, this function will
90  * restore the stream position to its original offset before returning.
91  */
92 // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
getFileFormat(std::istream & data)93 static inline FileFormat getFileFormat(std::istream& data) {
94   FileFormat format = FileFormat::UnknownFileFormat;
95   std::streampos orig_pos = data.tellg();
96   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
97   std::array<char, kFileFormatHeaderSize> header;
98   data.read(header.data(), header.size());
99   if (data.good()) {
100     format = getFileFormat(header.data());
101   }
102   data.seekg(orig_pos, data.beg);
103   return format;
104 }
105 
106 /**
107  * Returns the likely file format based on the magic header bytes of the file
108  * named @p filename.
109  */
110 // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
getFileFormat(const std::string & filename)111 static inline FileFormat getFileFormat(const std::string& filename) {
112   std::ifstream data(filename, std::ifstream::binary);
113   return getFileFormat(data);
114 }
115 
116 // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
file_not_found_error()117 static void file_not_found_error() {
118   std::stringstream message;
119   message << "Error while opening file: ";
120   if (errno == ENOENT) {
121     message << "no such file or directory" << '\n';
122   } else {
123     message << "error no is: " << errno << '\n';
124   }
125   TORCH_CHECK(false, message.str());
126 }
127 
128 // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
get_file_content(const char * filename)129 static inline std::tuple<std::shared_ptr<char>, size_t> get_file_content(
130     const char* filename) {
131 #if defined(HAVE_MMAP)
132   int fd = open(filename, O_RDONLY);
133   if (fd < 0) {
134     // failed to open file, chances are it's no such file or directory.
135     file_not_found_error();
136   }
137   struct stat statbuf {};
138   fstat(fd, &statbuf);
139   size_t size = statbuf.st_size;
140   void* ptr = mmap(nullptr, statbuf.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
141   close(fd);
142   auto deleter = [statbuf](char* ptr) { munmap(ptr, statbuf.st_size); };
143   std::shared_ptr<char> data(reinterpret_cast<char*>(ptr), deleter);
144 #else
145   FILE* f = fopen(filename, "rb");
146   if (f == nullptr) {
147     file_not_found_error();
148   }
149   fseek(f, 0, SEEK_END);
150   size_t size = ftell(f);
151   fseek(f, 0, SEEK_SET);
152   // make sure buffer size is multiple of alignment
153   size_t buffer_size = (size / kMaxAlignment + 1) * kMaxAlignment;
154   std::shared_ptr<char> data(
155       static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu);
156   fread(data.get(), size, 1, f);
157   fclose(f);
158 #endif
159   return std::make_tuple(data, size);
160 }
161 
162 // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
get_stream_content(std::istream & in)163 static inline std::tuple<std::shared_ptr<char>, size_t> get_stream_content(
164     std::istream& in) {
165   // get size of the stream and reset to orig
166   std::streampos orig_pos = in.tellg();
167   in.seekg(orig_pos, std::ios::end);
168   const long size = in.tellg();
169   in.seekg(orig_pos, in.beg);
170 
171   // read stream
172   // NOLINT make sure buffer size is multiple of alignment
173   size_t buffer_size = (size / kMaxAlignment + 1) * kMaxAlignment;
174   std::shared_ptr<char> data(
175       static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu);
176   in.read(data.get(), size);
177 
178   // reset stream to original position
179   in.seekg(orig_pos, in.beg);
180   return std::make_tuple(data, size);
181 }
182 
183 // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
get_rai_content(caffe2::serialize::ReadAdapterInterface * rai)184 static inline std::tuple<std::shared_ptr<char>, size_t> get_rai_content(
185     caffe2::serialize::ReadAdapterInterface* rai) {
186   size_t buffer_size = (rai->size() / kMaxAlignment + 1) * kMaxAlignment;
187   std::shared_ptr<char> data(
188       static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu);
189   rai->read(
190       0, data.get(), rai->size(), "Loading ReadAdapterInterface to bytes");
191   return std::make_tuple(data, buffer_size);
192 }
193 
194 } // namespace torch::jit
195