1 #pragma once
2
3 #include <array>
4 #include <cerrno>
5 #include <cstddef>
6 #include <cstring>
7 #include <fstream>
8 #include <istream>
9 #include <memory>
10
11 #include <c10/core/CPUAllocator.h>
12 #include <c10/core/impl/alloc_cpu.h>
13 #include <caffe2/serialize/read_adapter_interface.h>
14
15 #if defined(HAVE_MMAP)
16 #include <fcntl.h>
17 #include <sys/mman.h>
18 #include <sys/stat.h>
19 #include <sys/types.h>
20 #include <unistd.h>
21 #endif
22
23 /**
24 * @file
25 *
 * Helpers for identifying file formats and for loading serialized data into
 * memory.
27 *
28 * Note that these functions are declared inline because they will typically
29 * only be called from one or two locations per binary.
30 */
31
32 namespace torch::jit {
33
/**
 * Serialization format of a file or data stream, as determined from its
 * leading magic bytes by getFileFormat().
 */
enum class FileFormat : int {
  /// No recognized magic bytes were found.
  UnknownFileFormat = 0,
  /// Binary file containing a Flatbuffer-serialized mobile Module.
  FlatbufferFileFormat = 1,
  /// ZIP archive (used to store pickled sub-files).
  ZipFileFormat = 2,
};
42
/// Number of leading bytes, in bytes, that a buffer passed to #getFileFormat()
/// must provide.
constexpr std::size_t kFileFormatHeaderSize{8};
/// Granularity, in bytes, to which the buffers allocated by the
/// get_*_content() helpers are rounded up.
constexpr std::size_t kMaxAlignment{16};
46
47 /**
48 * Returns the likely file format based on the magic header bytes in @p header,
49 * which should contain the first bytes of a file or data stream.
50 */
51 // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
getFileFormat(const char * data)52 static inline FileFormat getFileFormat(const char* data) {
53 // The size of magic strings to look for in the buffer.
54 static constexpr size_t kMagicSize = 4;
55
56 // Bytes 4..7 of a Flatbuffer-encoded file produced by
57 // `flatbuffer_serializer.h`. (The first four bytes contain an offset to the
58 // actual Flatbuffer data.)
59 static constexpr std::array<char, kMagicSize> kFlatbufferMagicString = {
60 'P', 'T', 'M', 'F'};
61 static constexpr size_t kFlatbufferMagicOffset = 4;
62
63 // The first four bytes of a ZIP file.
64 static constexpr std::array<char, kMagicSize> kZipMagicString = {
65 'P', 'K', '\x03', '\x04'};
66
67 // Note that we check for Flatbuffer magic first. Since the first four bytes
68 // of flatbuffer data contain an offset to the root struct, it's theoretically
69 // possible to construct a file whose offset looks like the ZIP magic. On the
70 // other hand, bytes 4-7 of ZIP files are constrained to a small set of values
71 // that do not typically cross into the printable ASCII range, so a ZIP file
72 // should never have a header that looks like a Flatbuffer file.
73 if (std::memcmp(
74 data + kFlatbufferMagicOffset,
75 kFlatbufferMagicString.data(),
76 kMagicSize) == 0) {
77 // Magic header for a binary file containing a Flatbuffer-serialized mobile
78 // Module.
79 return FileFormat::FlatbufferFileFormat;
80 } else if (std::memcmp(data, kZipMagicString.data(), kMagicSize) == 0) {
81 // Magic header for a zip file, which we use to store pickled sub-files.
82 return FileFormat::ZipFileFormat;
83 }
84 return FileFormat::UnknownFileFormat;
85 }
86
87 /**
88 * Returns the likely file format based on the magic header bytes of @p data.
89 * If the stream position changes while inspecting the data, this function will
90 * restore the stream position to its original offset before returning.
91 */
92 // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
getFileFormat(std::istream & data)93 static inline FileFormat getFileFormat(std::istream& data) {
94 FileFormat format = FileFormat::UnknownFileFormat;
95 std::streampos orig_pos = data.tellg();
96 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
97 std::array<char, kFileFormatHeaderSize> header;
98 data.read(header.data(), header.size());
99 if (data.good()) {
100 format = getFileFormat(header.data());
101 }
102 data.seekg(orig_pos, data.beg);
103 return format;
104 }
105
106 /**
107 * Returns the likely file format based on the magic header bytes of the file
108 * named @p filename.
109 */
110 // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
getFileFormat(const std::string & filename)111 static inline FileFormat getFileFormat(const std::string& filename) {
112 std::ifstream data(filename, std::ifstream::binary);
113 return getFileFormat(data);
114 }
115
116 // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
file_not_found_error()117 static void file_not_found_error() {
118 std::stringstream message;
119 message << "Error while opening file: ";
120 if (errno == ENOENT) {
121 message << "no such file or directory" << '\n';
122 } else {
123 message << "error no is: " << errno << '\n';
124 }
125 TORCH_CHECK(false, message.str());
126 }
127
128 // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
get_file_content(const char * filename)129 static inline std::tuple<std::shared_ptr<char>, size_t> get_file_content(
130 const char* filename) {
131 #if defined(HAVE_MMAP)
132 int fd = open(filename, O_RDONLY);
133 if (fd < 0) {
134 // failed to open file, chances are it's no such file or directory.
135 file_not_found_error();
136 }
137 struct stat statbuf {};
138 fstat(fd, &statbuf);
139 size_t size = statbuf.st_size;
140 void* ptr = mmap(nullptr, statbuf.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
141 close(fd);
142 auto deleter = [statbuf](char* ptr) { munmap(ptr, statbuf.st_size); };
143 std::shared_ptr<char> data(reinterpret_cast<char*>(ptr), deleter);
144 #else
145 FILE* f = fopen(filename, "rb");
146 if (f == nullptr) {
147 file_not_found_error();
148 }
149 fseek(f, 0, SEEK_END);
150 size_t size = ftell(f);
151 fseek(f, 0, SEEK_SET);
152 // make sure buffer size is multiple of alignment
153 size_t buffer_size = (size / kMaxAlignment + 1) * kMaxAlignment;
154 std::shared_ptr<char> data(
155 static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu);
156 fread(data.get(), size, 1, f);
157 fclose(f);
158 #endif
159 return std::make_tuple(data, size);
160 }
161
162 // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
get_stream_content(std::istream & in)163 static inline std::tuple<std::shared_ptr<char>, size_t> get_stream_content(
164 std::istream& in) {
165 // get size of the stream and reset to orig
166 std::streampos orig_pos = in.tellg();
167 in.seekg(orig_pos, std::ios::end);
168 const long size = in.tellg();
169 in.seekg(orig_pos, in.beg);
170
171 // read stream
172 // NOLINT make sure buffer size is multiple of alignment
173 size_t buffer_size = (size / kMaxAlignment + 1) * kMaxAlignment;
174 std::shared_ptr<char> data(
175 static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu);
176 in.read(data.get(), size);
177
178 // reset stream to original position
179 in.seekg(orig_pos, in.beg);
180 return std::make_tuple(data, size);
181 }
182
183 // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
get_rai_content(caffe2::serialize::ReadAdapterInterface * rai)184 static inline std::tuple<std::shared_ptr<char>, size_t> get_rai_content(
185 caffe2::serialize::ReadAdapterInterface* rai) {
186 size_t buffer_size = (rai->size() / kMaxAlignment + 1) * kMaxAlignment;
187 std::shared_ptr<char> data(
188 static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu);
189 rai->read(
190 0, data.get(), rai->size(), "Loading ReadAdapterInterface to bytes");
191 return std::make_tuple(data, buffer_size);
192 }
193
194 } // namespace torch::jit
195