1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
17
18 #include <string>
19
20 #include "absl/memory/memory.h"
21 #include "absl/status/status.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_format.h"
24 #include "absl/strings/string_view.h"
25 #include "flatbuffers/flatbuffers.h"
26 #include "contrib/minizip/ioapi.h"
27 #include "contrib/minizip/unzip.h"
28 #include "tensorflow/lite/schema/schema_generated.h"
29 #include "tensorflow_lite_support/cc/common.h"
30 #include "tensorflow_lite_support/cc/port/status_macros.h"
31 #include "tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.h"
32 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
33
34 namespace tflite {
35 namespace metadata {
36
37 namespace {
// Name of the entry in the model's metadata list whose buffer is read as the
// TFLite ModelMetadata FlatBuffer.
constexpr char kMetadataBufferName[] = "TFLITE_METADATA";
39
40 using ::absl::StatusCode;
41 using ::flatbuffers::Offset;
42 using ::flatbuffers::Vector;
43 using ::tflite::TensorMetadata;
44 using ::tflite::support::CreateStatusWithPayload;
45 using ::tflite::support::TfLiteSupportStatus;
46
47 // Util to get item from src_vector specified by index.
48 template <typename T>
GetItemFromVector(const flatbuffers::Vector<flatbuffers::Offset<T>> * src_vector,int index)49 const T* GetItemFromVector(
50 const flatbuffers::Vector<flatbuffers::Offset<T>>* src_vector, int index) {
51 if (src_vector == nullptr || index < 0 || index >= src_vector->size()) {
52 return nullptr;
53 }
54 return src_vector->Get(index);
55 }
56
57 // Wrapper function around calls to unzip to avoid repeating conversion logic
58 // from error code to Status.
UnzipErrorToStatus(int error)59 absl::Status UnzipErrorToStatus(int error) {
60 if (error != UNZ_OK) {
61 return CreateStatusWithPayload(
62 StatusCode::kUnknown, "Unable to read associated file in zip archive.",
63 TfLiteSupportStatus::kMetadataAssociatedFileZipError);
64 }
65 return absl::OkStatus();
66 }
67
68 // Stores a file name, position in zip buffer and size.
69 struct ZipFileInfo {
70 std::string name;
71 ZPOS64_T position;
72 ZPOS64_T size;
73 };
74
75 // Returns the ZipFileInfo corresponding to the current file in the provided
76 // unzFile object.
GetCurrentZipFileInfo(const unzFile & zf)77 tflite::support::StatusOr<ZipFileInfo> GetCurrentZipFileInfo(const unzFile& zf) {
78 // Open file in raw mode, as data is expected to be uncompressed.
79 int method;
80 RETURN_IF_ERROR(UnzipErrorToStatus(
81 unzOpenCurrentFile2(zf, &method, /*level=*/nullptr, /*raw=*/1)));
82 if (method != Z_NO_COMPRESSION) {
83 return CreateStatusWithPayload(
84 StatusCode::kUnknown, "Expected uncompressed zip archive.",
85 TfLiteSupportStatus::kMetadataAssociatedFileZipError);
86 }
87
88 // Get file info a first time to get filename size.
89 unz_file_info64 file_info;
90 RETURN_IF_ERROR(UnzipErrorToStatus(unzGetCurrentFileInfo64(
91 zf, &file_info, /*szFileName=*/nullptr, /*szFileNameBufferSize=*/0,
92 /*extraField=*/nullptr, /*extraFieldBufferSize=*/0,
93 /*szComment=*/nullptr, /*szCommentBufferSize=*/0)));
94
95 // Second call to get file name.
96 auto file_name_size = file_info.size_filename;
97 char* c_file_name = (char*)malloc(file_name_size);
98 RETURN_IF_ERROR(UnzipErrorToStatus(unzGetCurrentFileInfo64(
99 zf, &file_info, c_file_name, file_name_size,
100 /*extraField=*/nullptr, /*extraFieldBufferSize=*/0,
101 /*szComment=*/nullptr, /*szCommentBufferSize=*/0)));
102 std::string file_name = std::string(c_file_name, file_name_size);
103 free(c_file_name);
104
105 // Get position in file.
106 auto position = unzGetCurrentFileZStreamPos64(zf);
107 if (position == 0) {
108 return CreateStatusWithPayload(
109 StatusCode::kUnknown, "Unable to read file in zip archive.",
110 TfLiteSupportStatus::kMetadataAssociatedFileZipError);
111 }
112 ZipFileInfo result = {.name = file_name,
113 .position = position,
114 .size = file_info.uncompressed_size};
115
116 // Close file and return.
117 RETURN_IF_ERROR(UnzipErrorToStatus(unzCloseCurrentFile(zf)));
118 return result;
119 }
120 } // namespace
121
122 /* static */
123 tflite::support::StatusOr<std::unique_ptr<ModelMetadataExtractor>>
CreateFromModelBuffer(const char * buffer_data,size_t buffer_size)124 ModelMetadataExtractor::CreateFromModelBuffer(const char* buffer_data,
125 size_t buffer_size) {
126 // Use absl::WrapUnique() to call private constructor:
127 // https://abseil.io/tips/126.
128 std::unique_ptr<ModelMetadataExtractor> extractor =
129 absl::WrapUnique(new ModelMetadataExtractor());
130 RETURN_IF_ERROR(extractor->InitFromModelBuffer(buffer_data, buffer_size));
131 return extractor;
132 }
133
134 /* static */
135 tflite::support::StatusOr<const tflite::ProcessUnit*>
FindFirstProcessUnit(const tflite::TensorMetadata & tensor_metadata,tflite::ProcessUnitOptions type)136 ModelMetadataExtractor::FindFirstProcessUnit(
137 const tflite::TensorMetadata& tensor_metadata,
138 tflite::ProcessUnitOptions type) {
139 const tflite::ProcessUnit* result = nullptr;
140 if (tensor_metadata.process_units() == nullptr) {
141 return result;
142 }
143 for (const auto process_unit : *tensor_metadata.process_units()) {
144 if (process_unit->options_type() == type) {
145 if (result != nullptr) {
146 return CreateStatusWithPayload(
147 StatusCode::kInvalidArgument,
148 absl::StrCat("Found multiple ProcessUnits with type=",
149 tflite::EnumNameProcessUnitOptions(type),
150 ", expected at most one."),
151 TfLiteSupportStatus::kMetadataInvalidProcessUnitsError);
152 }
153 result = process_unit;
154 }
155 }
156 return result;
157 }
158
159 /* static */
FindFirstAssociatedFileName(const tflite::TensorMetadata & tensor_metadata,tflite::AssociatedFileType type,absl::string_view locale)160 std::string ModelMetadataExtractor::FindFirstAssociatedFileName(
161 const tflite::TensorMetadata& tensor_metadata,
162 tflite::AssociatedFileType type, absl::string_view locale) {
163 if (tensor_metadata.associated_files() == nullptr) {
164 return std::string();
165 }
166 for (const auto associated_file : *tensor_metadata.associated_files()) {
167 if (associated_file->type() != type || associated_file->name() == nullptr) {
168 continue;
169 }
170 if (locale.empty() || (associated_file->locale() != nullptr &&
171 locale == associated_file->locale()->str())) {
172 return associated_file->name()->str();
173 }
174 }
175 return std::string();
176 }
177
InitFromModelBuffer(const char * buffer_data,size_t buffer_size)178 absl::Status ModelMetadataExtractor::InitFromModelBuffer(
179 const char* buffer_data, size_t buffer_size) {
180 // Rely on the simplest, base flatbuffers verifier. Here is not the place to
181 // e.g. use an OpResolver: we just want to make sure the buffer is valid to
182 // access the metadata.
183 flatbuffers::Verifier verifier = flatbuffers::Verifier(
184 reinterpret_cast<const uint8_t*>(buffer_data), buffer_size);
185 if (!tflite::VerifyModelBuffer(verifier)) {
186 return CreateStatusWithPayload(
187 StatusCode::kInvalidArgument,
188 "The model is not a valid FlatBuffer buffer.",
189 TfLiteSupportStatus::kInvalidFlatBufferError);
190 }
191 model_ = tflite::GetModel(buffer_data);
192 if (model_->metadata() == nullptr) {
193 // Not all models have metadata, which is OK. `GetModelMetadata()` then
194 // returns nullptr.
195 return absl::OkStatus();
196 }
197 // Look for the "TFLITE_METADATA" field, if any.
198 for (int i = 0; i < model_->metadata()->size(); ++i) {
199 const auto metadata = model_->metadata()->Get(i);
200 if (!metadata->name()) {
201 continue;
202 }
203 if (metadata->name()->str() != kMetadataBufferName) {
204 continue;
205 }
206 const auto buffer_index = metadata->buffer();
207 const auto metadata_buffer =
208 model_->buffers()->Get(buffer_index)->data()->data();
209 if (!tflite::ModelMetadataBufferHasIdentifier(metadata_buffer)) {
210 return CreateStatusWithPayload(
211 StatusCode::kInvalidArgument,
212 absl::StrFormat(
213 "Invalid metadata schema version: expected %s, got %s",
214 absl::string_view(tflite::ModelMetadataIdentifier())
215 .substr(
216 0, flatbuffers::FlatBufferBuilder::kFileIdentifierLength),
217 // Returned identifier is not null terminated; has to be
218 // truncated.
219 absl::string_view(
220 flatbuffers::GetBufferIdentifier(metadata_buffer))
221 .substr(
222 0,
223 flatbuffers::FlatBufferBuilder::kFileIdentifierLength)),
224 TfLiteSupportStatus::kMetadataInvalidSchemaVersionError);
225 }
226 model_metadata_ = tflite::GetModelMetadata(metadata_buffer);
227 if (model_metadata_ == nullptr) {
228 return CreateStatusWithPayload(StatusCode::kInternal,
229 "Expected Model Metadata not to be null.");
230 }
231 return ExtractAssociatedFiles(buffer_data, buffer_size);
232 break;
233 }
234 return absl::OkStatus();
235 }
236
ExtractAssociatedFiles(const char * buffer_data,size_t buffer_size)237 absl::Status ModelMetadataExtractor::ExtractAssociatedFiles(
238 const char* buffer_data, size_t buffer_size) {
239 // Create in-memory read-only zip file.
240 ZipReadOnlyMemFile mem_file = ZipReadOnlyMemFile(buffer_data, buffer_size);
241 // Open zip.
242 unzFile zf = unzOpen2_64(/*path=*/nullptr, &mem_file.GetFileFunc64Def());
243 if (zf == nullptr) {
244 // It's OK if it fails: this means there are no associated files with this
245 // model.
246 return absl::OkStatus();
247 }
248 // Get number of files.
249 unz_global_info global_info;
250 if (unzGetGlobalInfo(zf, &global_info) != UNZ_OK) {
251 return CreateStatusWithPayload(
252 StatusCode::kUnknown, "Unable to get zip archive info.",
253 TfLiteSupportStatus::kMetadataAssociatedFileZipError);
254 }
255
256 // Browse through files in archive.
257 if (global_info.number_entry > 0) {
258 int error = unzGoToFirstFile(zf);
259 while (error == UNZ_OK) {
260 ASSIGN_OR_RETURN(auto zip_file_info, GetCurrentZipFileInfo(zf));
261 // Store result in map.
262 associated_files_[zip_file_info.name] = absl::string_view(
263 buffer_data + zip_file_info.position, zip_file_info.size);
264 error = unzGoToNextFile(zf);
265 }
266 if (error != UNZ_END_OF_LIST_OF_FILE) {
267 return CreateStatusWithPayload(
268 StatusCode::kUnknown,
269 "Unable to read associated file in zip archive.",
270 TfLiteSupportStatus::kMetadataAssociatedFileZipError);
271 }
272 }
273 // Close zip.
274 if (unzClose(zf) != UNZ_OK) {
275 return CreateStatusWithPayload(
276 StatusCode::kUnknown, "Unable to close zip archive.",
277 TfLiteSupportStatus::kMetadataAssociatedFileZipError);
278 }
279 return absl::OkStatus();
280 }
281
282 tflite::support::StatusOr<absl::string_view>
GetAssociatedFile(const std::string & filename) const283 ModelMetadataExtractor::GetAssociatedFile(const std::string& filename) const {
284 auto it = associated_files_.find(filename);
285 if (it == associated_files_.end()) {
286 return CreateStatusWithPayload(
287 StatusCode::kNotFound,
288 absl::StrFormat("No associated file with name: %s", filename),
289 TfLiteSupportStatus::kMetadataAssociatedFileNotFoundError);
290 }
291 return it->second;
292 }
293
294 tflite::support::StatusOr<std::string>
GetModelVersion() const295 ModelMetadataExtractor::GetModelVersion() const {
296 if (model_metadata_ == nullptr) {
297 return CreateStatusWithPayload(
298 StatusCode::kFailedPrecondition,
299 "No model metadata",
300 TfLiteSupportStatus::kMetadataNotFoundError);
301 }
302 if (model_metadata_->version() == nullptr) {
303 return CreateStatusWithPayload(
304 StatusCode::kNotFound,
305 "No version in model metadata",
306 TfLiteSupportStatus::kMetadataNotFoundError);
307 }
308 return model_metadata_->version()->str();
309 }
310
311 const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
GetInputTensorMetadata() const312 ModelMetadataExtractor::GetInputTensorMetadata() const {
313 if (model_metadata_ == nullptr ||
314 model_metadata_->subgraph_metadata() == nullptr) {
315 return nullptr;
316 }
317 return model_metadata_->subgraph_metadata()
318 ->Get(kDefaultSubgraphIndex)
319 ->input_tensor_metadata();
320 }
321
GetInputTensorMetadata(int index) const322 const tflite::TensorMetadata* ModelMetadataExtractor::GetInputTensorMetadata(
323 int index) const {
324 return GetItemFromVector<tflite::TensorMetadata>(GetInputTensorMetadata(),
325 index);
326 }
327
GetInputTensorCount() const328 int ModelMetadataExtractor::GetInputTensorCount() const {
329 const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
330 input_tensor_metadata = GetInputTensorMetadata();
331 return input_tensor_metadata == nullptr ? 0 : input_tensor_metadata->size();
332 }
333
334 const Vector<Offset<TensorMetadata>>*
GetOutputTensorMetadata() const335 ModelMetadataExtractor::GetOutputTensorMetadata() const {
336 if (model_metadata_ == nullptr ||
337 model_metadata_->subgraph_metadata() == nullptr) {
338 return nullptr;
339 }
340 return model_metadata_->subgraph_metadata()
341 ->Get(kDefaultSubgraphIndex)
342 ->output_tensor_metadata();
343 }
344
GetOutputTensorMetadata(int index) const345 const tflite::TensorMetadata* ModelMetadataExtractor::GetOutputTensorMetadata(
346 int index) const {
347 return GetItemFromVector<tflite::TensorMetadata>(GetOutputTensorMetadata(),
348 index);
349 }
350
GetOutputTensorCount() const351 int ModelMetadataExtractor::GetOutputTensorCount() const {
352 const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
353 output_tensor_metadata = GetOutputTensorMetadata();
354 return output_tensor_metadata == nullptr ? 0 : output_tensor_metadata->size();
355 }
356
357 const Vector<flatbuffers::Offset<tflite::ProcessUnit>>*
GetInputProcessUnits() const358 ModelMetadataExtractor::GetInputProcessUnits() const {
359 if (model_metadata_ == nullptr ||
360 model_metadata_->subgraph_metadata() == nullptr) {
361 return nullptr;
362 }
363 return model_metadata_->subgraph_metadata()
364 ->Get(kDefaultSubgraphIndex)
365 ->input_process_units();
366 }
367
GetInputProcessUnit(int index) const368 const tflite::ProcessUnit* ModelMetadataExtractor::GetInputProcessUnit(
369 int index) const {
370 return GetItemFromVector<tflite::ProcessUnit>(GetInputProcessUnits(), index);
371 }
372
GetInputProcessUnitsCount() const373 int ModelMetadataExtractor::GetInputProcessUnitsCount() const {
374 const Vector<flatbuffers::Offset<tflite::ProcessUnit>>* input_process_units =
375 GetInputProcessUnits();
376 return input_process_units == nullptr ? 0 : input_process_units->size();
377 }
378
379 const Vector<flatbuffers::Offset<tflite::ProcessUnit>>*
GetOutputProcessUnits() const380 ModelMetadataExtractor::GetOutputProcessUnits() const {
381 if (model_metadata_ == nullptr ||
382 model_metadata_->subgraph_metadata() == nullptr) {
383 return nullptr;
384 }
385 return model_metadata_->subgraph_metadata()
386 ->Get(kDefaultSubgraphIndex)
387 ->output_process_units();
388 }
389
GetOutputProcessUnit(int index) const390 const tflite::ProcessUnit* ModelMetadataExtractor::GetOutputProcessUnit(
391 int index) const {
392 return GetItemFromVector<tflite::ProcessUnit>(GetOutputProcessUnits(), index);
393 }
394
GetOutputProcessUnitsCount() const395 int ModelMetadataExtractor::GetOutputProcessUnitsCount() const {
396 const Vector<flatbuffers::Offset<tflite::ProcessUnit>>* output_process_units =
397 GetOutputProcessUnits();
398 return output_process_units == nullptr ? 0 : output_process_units->size();
399 }
400
401 } // namespace metadata
402 } // namespace tflite