1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h"
17 
18 #include <cstddef>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/algorithm/container.h"
25 #include "absl/status/status.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/string_view.h"
28 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
29 #include "tensorflow/lite/c/common.h"
30 #include "tensorflow/lite/core/api/op_resolver.h"
31 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
32 #include "tensorflow_lite_support/cc/common.h"
33 #include "tensorflow_lite_support/cc/port/status_macros.h"
34 #include "tensorflow_lite_support/cc/port/statusor.h"
35 #include "tensorflow_lite_support/cc/task/core/category.h"
36 #include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
37 #include "tensorflow_lite_support/cc/task/core/task_utils.h"
38 #include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h"
39 #include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
40 #include "tensorflow_lite_support/cc/utils/common_utils.h"
41 
42 namespace tflite {
43 namespace task {
44 namespace text {
45 namespace nlclassifier {
46 
47 using ::absl::StatusCode;
48 using ::flatbuffers::Offset;
49 using ::flatbuffers::Vector;
50 using ::tflite::TensorMetadata;
51 using ::tflite::support::CreateStatusWithPayload;
52 using ::tflite::support::StatusOr;
53 using ::tflite::support::TfLiteSupportStatus;
54 using ::tflite::support::text::tokenizer::RegexTokenizer;
55 using ::tflite::support::text::tokenizer::Tokenizer;
56 using ::tflite::support::text::tokenizer::TokenizerResult;
57 using ::tflite::support::utils::LoadVocabFromBuffer;
58 using ::tflite::task::core::Category;
59 using ::tflite::task::core::Dequantize;
60 using ::tflite::task::core::GetStringAtIndex;
61 using ::tflite::task::core::PopulateTensor;
62 
namespace {
// Index of the input tensor that receives regex-tokenized ids.
constexpr int kRegexTokenizerInputTensorIndex = 0;
// Index of the process unit (within the input tensor metadata) that holds the
// RegexTokenizerOptions.
constexpr int kRegexTokenizerProcessUnitIndex = 0;
// Sentinel returned when the model/label metadata carries no version string.
constexpr char kNoVersionInfo[] = "NO_VERSION_INFO";
67 
CheckAndLoadFirstAssociatedFile(const flatbuffers::Vector<flatbuffers::Offset<tflite::AssociatedFile>> * associated_files,const tflite::metadata::ModelMetadataExtractor * metadata_extractor)68 StatusOr<absl::string_view> CheckAndLoadFirstAssociatedFile(
69     const flatbuffers::Vector<flatbuffers::Offset<tflite::AssociatedFile>>*
70         associated_files,
71     const tflite::metadata::ModelMetadataExtractor* metadata_extractor) {
72   if (associated_files == nullptr || associated_files->size() < 1 ||
73       associated_files->Get(0)->name() == nullptr) {
74     return CreateStatusWithPayload(
75         absl::StatusCode::kInvalidArgument,
76         "Invalid vocab_file from input process unit.",
77         TfLiteSupportStatus::kMetadataInvalidTokenizerError);
78   }
79   ASSIGN_OR_RETURN(absl::string_view vocab_buffer,
80                    metadata_extractor->GetAssociatedFile(
81                        associated_files->Get(0)->name()->str()));
82   return vocab_buffer;
83 }
84 
CreateRegexTokenizerFromProcessUnit(const tflite::ProcessUnit * tokenizer_process_unit,const tflite::metadata::ModelMetadataExtractor * metadata_extractor)85 StatusOr<std::unique_ptr<RegexTokenizer>> CreateRegexTokenizerFromProcessUnit(
86     const tflite::ProcessUnit* tokenizer_process_unit,
87     const tflite::metadata::ModelMetadataExtractor* metadata_extractor) {
88   if (metadata_extractor == nullptr || tokenizer_process_unit == nullptr) {
89     return CreateStatusWithPayload(
90         absl::StatusCode::kInvalidArgument,
91         "No metadata or input process unit found.",
92         TfLiteSupportStatus::kMetadataInvalidTokenizerError);
93   }
94 
95   if (tokenizer_process_unit->options_type() !=
96       ProcessUnitOptions_RegexTokenizerOptions) {
97     return CreateStatusWithPayload(
98         absl::StatusCode::kNotFound,
99         absl::StrCat(
100             "Incorrect options_type:", tokenizer_process_unit->options_type(),
101             " need RegexTokenizerOptions."),
102         TfLiteSupportStatus::kMetadataInvalidTokenizerError);
103   }
104 
105   const tflite::RegexTokenizerOptions* options =
106       tokenizer_process_unit->options_as<RegexTokenizerOptions>();
107   ASSIGN_OR_RETURN(absl::string_view vocab_buffer,
108                    CheckAndLoadFirstAssociatedFile(options->vocab_file(),
109                                                    metadata_extractor));
110   if (options->delim_regex_pattern() == nullptr) {
111     return CreateStatusWithPayload(
112         absl::StatusCode::kInvalidArgument,
113         "Invalid delim_regex_pattern from input process unit.",
114         TfLiteSupportStatus::kMetadataInvalidTokenizerError);
115   }
116 
117   std::unique_ptr<RegexTokenizer> regex_tokenizer =
118       absl::make_unique<RegexTokenizer>(options->delim_regex_pattern()->str(),
119                                         vocab_buffer.data(),
120                                         vocab_buffer.size());
121 
122   int unknown_token_id = 0;
123   if (!regex_tokenizer->GetUnknownToken(&unknown_token_id)) {
124     return CreateStatusWithPayload(
125         absl::StatusCode::kInvalidArgument,
126         "RegexTokenizer doesn't have <UNKNOWN> token.",
127         TfLiteSupportStatus::kMetadataInvalidTokenizerError);
128   }
129 
130   int pad_token_id = 0;
131   if (!regex_tokenizer->GetPadToken(&pad_token_id)) {
132     return CreateStatusWithPayload(
133         absl::StatusCode::kInvalidArgument,
134         "RegexTokenizer doesn't have <PAD> token.",
135         TfLiteSupportStatus::kMetadataInvalidTokenizerError);
136   }
137   return regex_tokenizer;
138 }
139 
140 }  // namespace
141 
GetOptions() const142 const NLClassifierOptions& NLClassifier::GetOptions() const { return options_; }
143 
TrySetLabelFromMetadata(const TensorMetadata * metadata)144 absl::Status NLClassifier::TrySetLabelFromMetadata(
145     const TensorMetadata* metadata) {
146   if (metadata == nullptr) {
147     return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
148                                    "Metadata not found for output tensor",
149                                    TfLiteSupportStatus::kMetadataNotFoundError);
150   }
151   const auto* associated_files = metadata->associated_files();
152   if (associated_files == nullptr || associated_files->size() == 0) {
153     return CreateStatusWithPayload(
154         absl::StatusCode::kInvalidArgument,
155         "No label file found for tensor metadata.",
156         TfLiteSupportStatus::kMetadataMissingLabelsError);
157   }
158   const tflite::AssociatedFile* associated_file =
159       associated_files->Get(kOutputTensorLabelFileIndex);
160   if (associated_file->type() != AssociatedFileType_TENSOR_AXIS_LABELS) {
161     return CreateStatusWithPayload(
162         absl::StatusCode::kInvalidArgument,
163         "Incorrect label type found for tensor metadata.",
164         TfLiteSupportStatus::kMetadataMissingLabelsError);
165   }
166   tflite::support::StatusOr<absl::string_view> label_buffer =
167       GetMetadataExtractor()->GetAssociatedFile(
168           associated_files->Get(kOutputTensorIndex)->name()->str());
169   if (label_buffer.ok()) {
170     labels_vector_ =
171         absl::make_unique<std::vector<std::string>>(LoadVocabFromBuffer(
172             label_buffer.value().data(), label_buffer.value().size()));
173     if (associated_file->version() == nullptr) {
174       labels_version_ = kNoVersionInfo;
175     } else {
176       labels_version_ = associated_file->version()->str();
177     }
178     return absl::OkStatus();
179   } else {
180     return CreateStatusWithPayload(
181         absl::StatusCode::kInvalidArgument,
182         "Failed to extract label file from metadata.",
183         TfLiteSupportStatus::kMetadataMissingLabelsError);
184   }
185 }
186 
Classify(const std::string & text)187 std::vector<Category> NLClassifier::Classify(const std::string& text) {
188   StatusOr<std::vector<Category>> infer_result = ClassifyText(text);
189   if (!infer_result.ok()) {
190     return {};
191   }
192   return infer_result.value();
193 }
194 
ClassifyText(const std::string & text)195 StatusOr<std::vector<Category>> NLClassifier::ClassifyText(
196     const std::string& text) {
197   return Infer(text);
198 }
199 
GetModelVersion() const200 std::string NLClassifier::GetModelVersion() const {
201   tflite::support::StatusOr<std::string> model_version =
202       GetMetadataExtractor()->GetModelVersion();
203   if (model_version.ok()) {
204     return model_version.value();
205   }
206   return kNoVersionInfo;
207 }
208 
GetLabelsVersion() const209 std::string NLClassifier::GetLabelsVersion() const {
210   return labels_version_;
211 }
212 
// Writes `input` into the model's input tensor.
//
// Two model flavors are supported:
//  * Regex-tokenizer models (metadata present): `input` is tokenized and the
//    resulting token ids are written into the INT32 input tensor, left-padded
//    with an optional <START> id and right-padded with <PAD> ids.
//  * Plain models: `input` is populated directly into the STRING tensor.
//
// Returns kInvalidArgument if no input tensor matches the configured
// name/index.
absl::Status NLClassifier::Preprocess(
    const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
  TfLiteTensor* input_tensor = FindTensorWithNameOrIndex(
      input_tensors, GetMetadataExtractor()->GetInputTensorMetadata(),
      options_.input_tensor_name, options_.input_tensor_index);
  if (input_tensor == nullptr) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "No input tensor found from NLClassifierOptions.",
        TfLiteSupportStatus::kInputTensorNotFoundError);
  }

  if (HasRegexTokenizerMetadata()) {
    //                              |<-------sentence_length-------->|
    // input_tensor                 <START>, t1, t2... <PAD>, <PAD>...
    // <START> is optional, t1, t2... will be replaced by <UNKNOWN> if it's not
    // found in tokenizer vocab.
    TokenizerResult result = tokenizer_->Tokenize(input);

    // The tensor may be shaped [1, len] or [len]; take the sentence length
    // from the last dimension accordingly.
    size_t max_sentence_length = input_tensor->dims->size == 2
                                     ? input_tensor->dims->data[1]
                                     : input_tensor->dims->data[0];

    // Presence of <UNKNOWN>/<PAD> was verified when the tokenizer was built
    // (CreateRegexTokenizerFromProcessUnit), so the lookups below succeed.
    int unknown_token_id = 0;
    tokenizer_->GetUnknownToken(&unknown_token_id);

    int pad_token_id = 0;
    tokenizer_->GetPadToken(&pad_token_id);

    // Pre-fill with <PAD> so any slots past the tokenized text stay padded.
    std::vector<int> input_tokens(max_sentence_length, pad_token_id);
    int start_token_id = 0;
    size_t input_token_index = 0;
    if (tokenizer_->GetStartToken(&start_token_id)) {
      // Optional <START> token occupies slot 0; real tokens begin at slot 1.
      input_tokens[0] = start_token_id;
      input_token_index = 1;
    }

    // Copy token ids until either the text or the tensor capacity runs out;
    // out-of-vocab tokens map to <UNKNOWN>.
    for (size_t i = 0; (i < result.subwords.size()) &&
                       (input_token_index < max_sentence_length);
         ++i, ++input_token_index) {
      const std::string& token = result.subwords[i];
      int token_id = 0;
      if (tokenizer_->LookupId(token, &token_id)) {
        input_tokens[input_token_index] = token_id;
      } else {
        input_tokens[input_token_index] = unknown_token_id;
      }
    }

    PopulateTensor(input_tokens, input_tensor);
  } else {
    PopulateTensor(input, input_tensor);
  }
  return absl::OkStatus();
}
268 
Postprocess(const std::vector<const TfLiteTensor * > & output_tensors,const std::string &)269 StatusOr<std::vector<Category>> NLClassifier::Postprocess(
270     const std::vector<const TfLiteTensor*>& output_tensors,
271     const std::string& /*input*/) {
272   return BuildResults(
273       FindTensorWithNameOrIndex(
274           output_tensors, GetMetadataExtractor()->GetOutputTensorMetadata(),
275           options_.output_score_tensor_name,
276           options_.output_score_tensor_index),
277       FindTensorWithNameOrIndex(
278           output_tensors, GetMetadataExtractor()->GetInputTensorMetadata(),
279           options_.output_label_tensor_name,
280           options_.output_label_tensor_index));
281 }
282 
// Assembles Category results from the score tensor and an optional label
// source. Label precedence: labels_vector_ (from metadata) > `labels` tensor
// (STRING or INT32) > the category index itself as a string.
// `scores` must be non-null; `labels` may be null.
std::vector<Category> NLClassifier::BuildResults(const TfLiteTensor* scores,
                                                 const TfLiteTensor* labels) {
  bool use_index_as_labels = (labels_vector_ == nullptr) && (labels == nullptr);
  // Some models output scores with transposed shape [1, categories]
  int categories =
      scores->dims->size == 2 ? scores->dims->data[1] : scores->dims->data[0];

  std::vector<Category> predictions;
  predictions.reserve(categories);

  // Quantized score types need dequantization before being reported.
  bool should_dequantize = scores->type == kTfLiteUInt8 ||
                           scores->type == kTfLiteInt8 ||
                           scores->type == kTfLiteInt16;
  for (int index = 0; index < categories; index++) {
    std::string label;
    if (use_index_as_labels) {
      label = std::to_string(index);
    } else if (labels_vector_ == nullptr) {
      // No metadata labels: read the label from the output label tensor.
      if (labels->type == kTfLiteString) {
        label = GetStringAtIndex(labels, index);
      } else if (labels->type == kTfLiteInt32) {
        label = std::to_string(GetTensorData<int>(labels)[index]);
      }
    } else {
      label = (*labels_vector_)[index];
    }
    if (should_dequantize) {
      predictions.push_back(Category(label, Dequantize(*scores, index)));
    } else if (scores->type == kTfLiteBool) {
      // Map boolean outputs to 1.0/0.0 scores.
      predictions.push_back(
          Category(label, GetTensorData<bool>(scores)[index] ? 1.0 : 0.0));
    } else {
      // Remaining valid types (checked in Initialize) are FLOAT32/FLOAT64.
      predictions.push_back(
          Category(label, scores->type == kTfLiteFloat32
                              ? GetTensorData<float>(scores)[index]
                              : GetTensorData<double>(scores)[index]));
    }
  }

  return predictions;
}
Initialize(const NLClassifierOptions & options)324 absl::Status NLClassifier::Initialize(const NLClassifierOptions& options) {
325   options_ = options;
326   // input tensor should be type STRING
327   auto input_tensor = FindTensorWithNameOrIndex(
328       GetInputTensors(), GetMetadataExtractor()->GetInputTensorMetadata(),
329       options.input_tensor_name, options.input_tensor_index);
330   if (input_tensor == nullptr) {
331     return CreateStatusWithPayload(
332         StatusCode::kInvalidArgument,
333         absl::StrCat("No input tensor found with name ",
334                      options.input_tensor_name, " or at index ",
335                      options.input_tensor_index),
336         TfLiteSupportStatus::kInputTensorNotFoundError);
337   }
338   if (HasRegexTokenizerMetadata()) {
339     if (input_tensor->type != kTfLiteInt32) {
340       return CreateStatusWithPayload(
341           StatusCode::kInvalidArgument,
342           absl::StrCat("Type mismatch for input tensor ", input_tensor->name,
343                        ". Requested INT32, got ",
344                        TfLiteTypeGetName(input_tensor->type), "."),
345           TfLiteSupportStatus::kInvalidInputTensorTypeError);
346     }
347     RETURN_IF_ERROR(SetupRegexTokenizer());
348   } else {
349     if (input_tensor->type != kTfLiteString) {
350       return CreateStatusWithPayload(
351           StatusCode::kInvalidArgument,
352           absl::StrCat("Type mismatch for input tensor ", input_tensor->name,
353                        ". Requested STRING, got ",
354                        TfLiteTypeGetName(input_tensor->type), "."),
355           TfLiteSupportStatus::kInvalidInputTensorTypeError);
356     }
357   }
358 
359   // output score tensor should be type
360   // UINT8/INT8/INT16(quantized) or FLOAT32/FLOAT64(dequantized) or BOOL
361   std::vector<const TfLiteTensor*> output_tensors = GetOutputTensors();
362   const Vector<Offset<TensorMetadata>>* output_tensor_metadatas =
363       GetMetadataExtractor()->GetOutputTensorMetadata();
364 
365   const auto scores = FindTensorWithNameOrIndex(
366       output_tensors, output_tensor_metadatas, options.output_score_tensor_name,
367       options.output_score_tensor_index);
368   if (scores == nullptr) {
369     return CreateStatusWithPayload(
370         StatusCode::kInvalidArgument,
371         absl::StrCat("No output score tensor found with name ",
372                      options.output_score_tensor_name, " or at index ",
373                      options.output_score_tensor_index),
374         TfLiteSupportStatus::kOutputTensorNotFoundError);
375   }
376   static constexpr TfLiteType valid_types[] = {kTfLiteUInt8,   kTfLiteInt8,
377                                                kTfLiteInt16,   kTfLiteFloat32,
378                                                kTfLiteFloat64, kTfLiteBool};
379   if (!absl::c_linear_search(valid_types, scores->type)) {
380     return CreateStatusWithPayload(
381         StatusCode::kInvalidArgument,
382         absl::StrCat("Type mismatch for score tensor ", scores->name,
383                      ". Requested one of these types: "
384                      "INT8/UINT8/INT16/FLOAT32/FLOAT64/BOOL, got ",
385                      TfLiteTypeGetName(scores->type), "."),
386         TfLiteSupportStatus::kInvalidOutputTensorTypeError);
387   }
388 
389   // Extract associated label file from output score tensor if one exists, a
390   // well-formatted metadata should have same number of tensors with the model.
391   if (output_tensor_metadatas &&
392       output_tensor_metadatas->size() == output_tensors.size()) {
393     for (int i = 0; i < output_tensor_metadatas->size(); ++i) {
394       const tflite::TensorMetadata* metadata = output_tensor_metadatas->Get(i);
395       if ((metadata->name() && metadata->name()->string_view() ==
396                                    options.output_score_tensor_name) ||
397           i == options.output_score_tensor_index) {
398         if (TrySetLabelFromMetadata(metadata).ok()) {
399           return absl::OkStatus();
400         }
401       }
402     }
403   }
404 
405   // If labels_vector_ is not set up from metadata, try register output label
406   // tensor from options.
407   if (labels_vector_ == nullptr) {
408     // output label tensor should be type STRING or INT32 if the one exists
409     auto labels = FindTensorWithNameOrIndex(
410         output_tensors, output_tensor_metadatas,
411         options.output_label_tensor_name, options.output_label_tensor_index);
412     if (labels != nullptr && labels->type != kTfLiteString &&
413         labels->type != kTfLiteInt32) {
414       return CreateStatusWithPayload(
415           StatusCode::kInvalidArgument,
416           absl::StrCat("Type mismatch for label tensor ", scores->name,
417                        ". Requested STRING or INT32, got ",
418                        TfLiteTypeGetName(scores->type), "."),
419           TfLiteSupportStatus::kInvalidOutputTensorTypeError);
420     }
421   }
422   return absl::OkStatus();
423 }
424 
425 StatusOr<std::unique_ptr<NLClassifier>>
CreateFromBufferAndOptions(const char * model_buffer_data,size_t model_buffer_size,const NLClassifierOptions & options,std::unique_ptr<tflite::OpResolver> resolver)426 NLClassifier::CreateFromBufferAndOptions(
427     const char* model_buffer_data, size_t model_buffer_size,
428     const NLClassifierOptions& options,
429     std::unique_ptr<tflite::OpResolver> resolver) {
430   std::unique_ptr<NLClassifier> nl_classifier;
431   ASSIGN_OR_RETURN(
432       nl_classifier,
433       core::TaskAPIFactory::CreateFromBuffer<NLClassifier>(
434           model_buffer_data, model_buffer_size, std::move(resolver)));
435   RETURN_IF_ERROR(nl_classifier->Initialize(options));
436   return std::move(nl_classifier);
437 }
438 
CreateFromFileAndOptions(const std::string & path_to_model,const NLClassifierOptions & options,std::unique_ptr<tflite::OpResolver> resolver)439 StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFileAndOptions(
440     const std::string& path_to_model, const NLClassifierOptions& options,
441     std::unique_ptr<tflite::OpResolver> resolver) {
442   std::unique_ptr<NLClassifier> nl_classifier;
443   ASSIGN_OR_RETURN(nl_classifier,
444                    core::TaskAPIFactory::CreateFromFile<NLClassifier>(
445                        path_to_model, std::move(resolver)));
446   RETURN_IF_ERROR(nl_classifier->Initialize(options));
447   return std::move(nl_classifier);
448 }
449 
CreateFromFdAndOptions(int fd,const NLClassifierOptions & options,std::unique_ptr<tflite::OpResolver> resolver)450 StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFdAndOptions(
451     int fd, const NLClassifierOptions& options,
452     std::unique_ptr<tflite::OpResolver> resolver) {
453   std::unique_ptr<NLClassifier> nl_classifier;
454   ASSIGN_OR_RETURN(nl_classifier,
455                    core::TaskAPIFactory::CreateFromFileDescriptor<NLClassifier>(
456                        fd, std::move(resolver)));
457   RETURN_IF_ERROR(nl_classifier->Initialize(options));
458   return std::move(nl_classifier);
459 }
460 
HasRegexTokenizerMetadata()461 bool NLClassifier::HasRegexTokenizerMetadata() {
462   const TensorMetadata* input_tensor_metadata =
463       GetMetadataExtractor()->GetInputTensorMetadata(
464           kRegexTokenizerInputTensorIndex);
465   if (input_tensor_metadata == nullptr) {
466     return false;
467   }
468   tflite::support::StatusOr<const tflite::ProcessUnit*> status =
469       GetMetadataExtractor()->FindFirstProcessUnit(
470           *input_tensor_metadata, ProcessUnitOptions_RegexTokenizerOptions);
471   return status.ok() ? status.value() != nullptr : false;
472 }
473 
// Builds tokenizer_ from the RegexTokenizerOptions process unit attached to
// the first input tensor's metadata.
//
// Precondition: HasRegexTokenizerMetadata() returned true (callers check this
// in Initialize()), so the metadata/process_units chain dereferenced below is
// non-null.
absl::Status NLClassifier::SetupRegexTokenizer() {
  ASSIGN_OR_RETURN(
      tokenizer_,
      CreateRegexTokenizerFromProcessUnit(
          GetMetadataExtractor()
              ->GetInputTensorMetadata(kRegexTokenizerInputTensorIndex)
              ->process_units()
              ->Get(kRegexTokenizerProcessUnitIndex),
          GetMetadataExtractor()));

  return absl::OkStatus();
}
486 
487 }  // namespace nlclassifier
488 }  // namespace text
489 }  // namespace task
490 }  // namespace tflite
491