xref: /aosp_15_r20/external/tflite-support/tensorflow_lite_support/metadata/cc/metadata_extractor.cc (revision b16991f985baa50654c05c5adbb3c8bbcfb40082)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
17 
18 #include <string>
19 
20 #include "absl/memory/memory.h"
21 #include "absl/status/status.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_format.h"
24 #include "absl/strings/string_view.h"
25 #include "flatbuffers/flatbuffers.h"
26 #include "contrib/minizip/ioapi.h"
27 #include "contrib/minizip/unzip.h"
28 #include "tensorflow/lite/schema/schema_generated.h"
29 #include "tensorflow_lite_support/cc/common.h"
30 #include "tensorflow_lite_support/cc/port/status_macros.h"
31 #include "tensorflow_lite_support/metadata/cc/utils/zip_readonly_mem_file.h"
32 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
33 
34 namespace tflite {
35 namespace metadata {
36 
37 namespace {
38 constexpr char kMetadataBufferName[] = "TFLITE_METADATA";
39 
40 using ::absl::StatusCode;
41 using ::flatbuffers::Offset;
42 using ::flatbuffers::Vector;
43 using ::tflite::TensorMetadata;
44 using ::tflite::support::CreateStatusWithPayload;
45 using ::tflite::support::TfLiteSupportStatus;
46 
47 // Util to get item from src_vector specified by index.
48 template <typename T>
GetItemFromVector(const flatbuffers::Vector<flatbuffers::Offset<T>> * src_vector,int index)49 const T* GetItemFromVector(
50     const flatbuffers::Vector<flatbuffers::Offset<T>>* src_vector, int index) {
51   if (src_vector == nullptr || index < 0 || index >= src_vector->size()) {
52     return nullptr;
53   }
54   return src_vector->Get(index);
55 }
56 
57 // Wrapper function around calls to unzip to avoid repeating conversion logic
58 // from error code to Status.
UnzipErrorToStatus(int error)59 absl::Status UnzipErrorToStatus(int error) {
60   if (error != UNZ_OK) {
61     return CreateStatusWithPayload(
62         StatusCode::kUnknown, "Unable to read associated file in zip archive.",
63         TfLiteSupportStatus::kMetadataAssociatedFileZipError);
64   }
65   return absl::OkStatus();
66 }
67 
68 // Stores a file name, position in zip buffer and size.
69 struct ZipFileInfo {
70   std::string name;
71   ZPOS64_T position;
72   ZPOS64_T size;
73 };
74 
75 // Returns the ZipFileInfo corresponding to the current file in the provided
76 // unzFile object.
GetCurrentZipFileInfo(const unzFile & zf)77 tflite::support::StatusOr<ZipFileInfo> GetCurrentZipFileInfo(const unzFile& zf) {
78   // Open file in raw mode, as data is expected to be uncompressed.
79   int method;
80   RETURN_IF_ERROR(UnzipErrorToStatus(
81       unzOpenCurrentFile2(zf, &method, /*level=*/nullptr, /*raw=*/1)));
82   if (method != Z_NO_COMPRESSION) {
83     return CreateStatusWithPayload(
84         StatusCode::kUnknown, "Expected uncompressed zip archive.",
85         TfLiteSupportStatus::kMetadataAssociatedFileZipError);
86   }
87 
88   // Get file info a first time to get filename size.
89   unz_file_info64 file_info;
90   RETURN_IF_ERROR(UnzipErrorToStatus(unzGetCurrentFileInfo64(
91       zf, &file_info, /*szFileName=*/nullptr, /*szFileNameBufferSize=*/0,
92       /*extraField=*/nullptr, /*extraFieldBufferSize=*/0,
93       /*szComment=*/nullptr, /*szCommentBufferSize=*/0)));
94 
95   // Second call to get file name.
96   auto file_name_size = file_info.size_filename;
97   char* c_file_name = (char*)malloc(file_name_size);
98   RETURN_IF_ERROR(UnzipErrorToStatus(unzGetCurrentFileInfo64(
99       zf, &file_info, c_file_name, file_name_size,
100       /*extraField=*/nullptr, /*extraFieldBufferSize=*/0,
101       /*szComment=*/nullptr, /*szCommentBufferSize=*/0)));
102   std::string file_name = std::string(c_file_name, file_name_size);
103   free(c_file_name);
104 
105   // Get position in file.
106   auto position = unzGetCurrentFileZStreamPos64(zf);
107   if (position == 0) {
108     return CreateStatusWithPayload(
109         StatusCode::kUnknown, "Unable to read file in zip archive.",
110         TfLiteSupportStatus::kMetadataAssociatedFileZipError);
111   }
112   ZipFileInfo result = {.name = file_name,
113                         .position = position,
114                         .size = file_info.uncompressed_size};
115 
116   // Close file and return.
117   RETURN_IF_ERROR(UnzipErrorToStatus(unzCloseCurrentFile(zf)));
118   return result;
119 }
120 }  // namespace
121 
122 /* static */
123 tflite::support::StatusOr<std::unique_ptr<ModelMetadataExtractor>>
CreateFromModelBuffer(const char * buffer_data,size_t buffer_size)124 ModelMetadataExtractor::CreateFromModelBuffer(const char* buffer_data,
125                                               size_t buffer_size) {
126   // Use absl::WrapUnique() to call private constructor:
127   // https://abseil.io/tips/126.
128   std::unique_ptr<ModelMetadataExtractor> extractor =
129       absl::WrapUnique(new ModelMetadataExtractor());
130   RETURN_IF_ERROR(extractor->InitFromModelBuffer(buffer_data, buffer_size));
131   return extractor;
132 }
133 
134 /* static */
135 tflite::support::StatusOr<const tflite::ProcessUnit*>
FindFirstProcessUnit(const tflite::TensorMetadata & tensor_metadata,tflite::ProcessUnitOptions type)136 ModelMetadataExtractor::FindFirstProcessUnit(
137     const tflite::TensorMetadata& tensor_metadata,
138     tflite::ProcessUnitOptions type) {
139   const tflite::ProcessUnit* result = nullptr;
140   if (tensor_metadata.process_units() == nullptr) {
141     return result;
142   }
143   for (const auto process_unit : *tensor_metadata.process_units()) {
144     if (process_unit->options_type() == type) {
145       if (result != nullptr) {
146         return CreateStatusWithPayload(
147             StatusCode::kInvalidArgument,
148             absl::StrCat("Found multiple ProcessUnits with type=",
149                          tflite::EnumNameProcessUnitOptions(type),
150                          ", expected at most one."),
151             TfLiteSupportStatus::kMetadataInvalidProcessUnitsError);
152       }
153       result = process_unit;
154     }
155   }
156   return result;
157 }
158 
159 /* static */
FindFirstAssociatedFileName(const tflite::TensorMetadata & tensor_metadata,tflite::AssociatedFileType type,absl::string_view locale)160 std::string ModelMetadataExtractor::FindFirstAssociatedFileName(
161     const tflite::TensorMetadata& tensor_metadata,
162     tflite::AssociatedFileType type, absl::string_view locale) {
163   if (tensor_metadata.associated_files() == nullptr) {
164     return std::string();
165   }
166   for (const auto associated_file : *tensor_metadata.associated_files()) {
167     if (associated_file->type() != type || associated_file->name() == nullptr) {
168       continue;
169     }
170     if (locale.empty() || (associated_file->locale() != nullptr &&
171                            locale == associated_file->locale()->str())) {
172       return associated_file->name()->str();
173     }
174   }
175   return std::string();
176 }
177 
InitFromModelBuffer(const char * buffer_data,size_t buffer_size)178 absl::Status ModelMetadataExtractor::InitFromModelBuffer(
179     const char* buffer_data, size_t buffer_size) {
180   // Rely on the simplest, base flatbuffers verifier. Here is not the place to
181   // e.g. use an OpResolver: we just want to make sure the buffer is valid to
182   // access the metadata.
183   flatbuffers::Verifier verifier = flatbuffers::Verifier(
184       reinterpret_cast<const uint8_t*>(buffer_data), buffer_size);
185   if (!tflite::VerifyModelBuffer(verifier)) {
186     return CreateStatusWithPayload(
187         StatusCode::kInvalidArgument,
188         "The model is not a valid FlatBuffer buffer.",
189         TfLiteSupportStatus::kInvalidFlatBufferError);
190   }
191   model_ = tflite::GetModel(buffer_data);
192   if (model_->metadata() == nullptr) {
193     // Not all models have metadata, which is OK. `GetModelMetadata()` then
194     // returns nullptr.
195     return absl::OkStatus();
196   }
197   // Look for the "TFLITE_METADATA" field, if any.
198   for (int i = 0; i < model_->metadata()->size(); ++i) {
199     const auto metadata = model_->metadata()->Get(i);
200     if (!metadata->name()) {
201       continue;
202     }
203     if (metadata->name()->str() != kMetadataBufferName) {
204       continue;
205     }
206     const auto buffer_index = metadata->buffer();
207     const auto metadata_buffer =
208         model_->buffers()->Get(buffer_index)->data()->data();
209     if (!tflite::ModelMetadataBufferHasIdentifier(metadata_buffer)) {
210       return CreateStatusWithPayload(
211           StatusCode::kInvalidArgument,
212           absl::StrFormat(
213               "Invalid metadata schema version: expected %s, got %s",
214               absl::string_view(tflite::ModelMetadataIdentifier())
215                   .substr(
216                       0, flatbuffers::FlatBufferBuilder::kFileIdentifierLength),
217               // Returned identifier is not null terminated; has to be
218               // truncated.
219               absl::string_view(
220                   flatbuffers::GetBufferIdentifier(metadata_buffer))
221                   .substr(
222                       0,
223                       flatbuffers::FlatBufferBuilder::kFileIdentifierLength)),
224           TfLiteSupportStatus::kMetadataInvalidSchemaVersionError);
225     }
226     model_metadata_ = tflite::GetModelMetadata(metadata_buffer);
227     if (model_metadata_ == nullptr) {
228       return CreateStatusWithPayload(StatusCode::kInternal,
229                                      "Expected Model Metadata not to be null.");
230     }
231     return ExtractAssociatedFiles(buffer_data, buffer_size);
232     break;
233   }
234   return absl::OkStatus();
235 }
236 
ExtractAssociatedFiles(const char * buffer_data,size_t buffer_size)237 absl::Status ModelMetadataExtractor::ExtractAssociatedFiles(
238     const char* buffer_data, size_t buffer_size) {
239   // Create in-memory read-only zip file.
240   ZipReadOnlyMemFile mem_file = ZipReadOnlyMemFile(buffer_data, buffer_size);
241   // Open zip.
242   unzFile zf = unzOpen2_64(/*path=*/nullptr, &mem_file.GetFileFunc64Def());
243   if (zf == nullptr) {
244     // It's OK if it fails: this means there are no associated files with this
245     // model.
246     return absl::OkStatus();
247   }
248   // Get number of files.
249   unz_global_info global_info;
250   if (unzGetGlobalInfo(zf, &global_info) != UNZ_OK) {
251     return CreateStatusWithPayload(
252         StatusCode::kUnknown, "Unable to get zip archive info.",
253         TfLiteSupportStatus::kMetadataAssociatedFileZipError);
254   }
255 
256   // Browse through files in archive.
257   if (global_info.number_entry > 0) {
258     int error = unzGoToFirstFile(zf);
259     while (error == UNZ_OK) {
260       ASSIGN_OR_RETURN(auto zip_file_info, GetCurrentZipFileInfo(zf));
261       // Store result in map.
262       associated_files_[zip_file_info.name] = absl::string_view(
263           buffer_data + zip_file_info.position, zip_file_info.size);
264       error = unzGoToNextFile(zf);
265     }
266     if (error != UNZ_END_OF_LIST_OF_FILE) {
267       return CreateStatusWithPayload(
268           StatusCode::kUnknown,
269           "Unable to read associated file in zip archive.",
270           TfLiteSupportStatus::kMetadataAssociatedFileZipError);
271     }
272   }
273   // Close zip.
274   if (unzClose(zf) != UNZ_OK) {
275     return CreateStatusWithPayload(
276         StatusCode::kUnknown, "Unable to close zip archive.",
277         TfLiteSupportStatus::kMetadataAssociatedFileZipError);
278   }
279   return absl::OkStatus();
280 }
281 
282 tflite::support::StatusOr<absl::string_view>
GetAssociatedFile(const std::string & filename) const283 ModelMetadataExtractor::GetAssociatedFile(const std::string& filename) const {
284   auto it = associated_files_.find(filename);
285   if (it == associated_files_.end()) {
286     return CreateStatusWithPayload(
287         StatusCode::kNotFound,
288         absl::StrFormat("No associated file with name: %s", filename),
289         TfLiteSupportStatus::kMetadataAssociatedFileNotFoundError);
290   }
291   return it->second;
292 }
293 
294 tflite::support::StatusOr<std::string>
GetModelVersion() const295 ModelMetadataExtractor::GetModelVersion() const {
296   if (model_metadata_ == nullptr) {
297     return CreateStatusWithPayload(
298       StatusCode::kFailedPrecondition,
299       "No model metadata",
300       TfLiteSupportStatus::kMetadataNotFoundError);
301   }
302   if (model_metadata_->version() == nullptr) {
303     return CreateStatusWithPayload(
304       StatusCode::kNotFound,
305       "No version in model metadata",
306       TfLiteSupportStatus::kMetadataNotFoundError);
307   }
308   return model_metadata_->version()->str();
309 }
310 
311 const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
GetInputTensorMetadata() const312 ModelMetadataExtractor::GetInputTensorMetadata() const {
313   if (model_metadata_ == nullptr ||
314       model_metadata_->subgraph_metadata() == nullptr) {
315     return nullptr;
316   }
317   return model_metadata_->subgraph_metadata()
318       ->Get(kDefaultSubgraphIndex)
319       ->input_tensor_metadata();
320 }
321 
GetInputTensorMetadata(int index) const322 const tflite::TensorMetadata* ModelMetadataExtractor::GetInputTensorMetadata(
323     int index) const {
324   return GetItemFromVector<tflite::TensorMetadata>(GetInputTensorMetadata(),
325                                                    index);
326 }
327 
GetInputTensorCount() const328 int ModelMetadataExtractor::GetInputTensorCount() const {
329   const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
330       input_tensor_metadata = GetInputTensorMetadata();
331   return input_tensor_metadata == nullptr ? 0 : input_tensor_metadata->size();
332 }
333 
334 const Vector<Offset<TensorMetadata>>*
GetOutputTensorMetadata() const335 ModelMetadataExtractor::GetOutputTensorMetadata() const {
336   if (model_metadata_ == nullptr ||
337       model_metadata_->subgraph_metadata() == nullptr) {
338     return nullptr;
339   }
340   return model_metadata_->subgraph_metadata()
341       ->Get(kDefaultSubgraphIndex)
342       ->output_tensor_metadata();
343 }
344 
GetOutputTensorMetadata(int index) const345 const tflite::TensorMetadata* ModelMetadataExtractor::GetOutputTensorMetadata(
346     int index) const {
347   return GetItemFromVector<tflite::TensorMetadata>(GetOutputTensorMetadata(),
348                                                    index);
349 }
350 
GetOutputTensorCount() const351 int ModelMetadataExtractor::GetOutputTensorCount() const {
352   const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMetadata>>*
353       output_tensor_metadata = GetOutputTensorMetadata();
354   return output_tensor_metadata == nullptr ? 0 : output_tensor_metadata->size();
355 }
356 
357 const Vector<flatbuffers::Offset<tflite::ProcessUnit>>*
GetInputProcessUnits() const358 ModelMetadataExtractor::GetInputProcessUnits() const {
359   if (model_metadata_ == nullptr ||
360       model_metadata_->subgraph_metadata() == nullptr) {
361     return nullptr;
362   }
363   return model_metadata_->subgraph_metadata()
364       ->Get(kDefaultSubgraphIndex)
365       ->input_process_units();
366 }
367 
GetInputProcessUnit(int index) const368 const tflite::ProcessUnit* ModelMetadataExtractor::GetInputProcessUnit(
369     int index) const {
370   return GetItemFromVector<tflite::ProcessUnit>(GetInputProcessUnits(), index);
371 }
372 
GetInputProcessUnitsCount() const373 int ModelMetadataExtractor::GetInputProcessUnitsCount() const {
374   const Vector<flatbuffers::Offset<tflite::ProcessUnit>>* input_process_units =
375       GetInputProcessUnits();
376   return input_process_units == nullptr ? 0 : input_process_units->size();
377 }
378 
379 const Vector<flatbuffers::Offset<tflite::ProcessUnit>>*
GetOutputProcessUnits() const380 ModelMetadataExtractor::GetOutputProcessUnits() const {
381   if (model_metadata_ == nullptr ||
382       model_metadata_->subgraph_metadata() == nullptr) {
383     return nullptr;
384   }
385   return model_metadata_->subgraph_metadata()
386       ->Get(kDefaultSubgraphIndex)
387       ->output_process_units();
388 }
389 
GetOutputProcessUnit(int index) const390 const tflite::ProcessUnit* ModelMetadataExtractor::GetOutputProcessUnit(
391     int index) const {
392   return GetItemFromVector<tflite::ProcessUnit>(GetOutputProcessUnits(), index);
393 }
394 
GetOutputProcessUnitsCount() const395 int ModelMetadataExtractor::GetOutputProcessUnitsCount() const {
396   const Vector<flatbuffers::Offset<tflite::ProcessUnit>>* output_process_units =
397       GetOutputProcessUnits();
398   return output_process_units == nullptr ? 0 : output_process_units->size();
399 }
400 
401 }  // namespace metadata
402 }  // namespace tflite