/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h"

#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/category.h"
#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
#include "tensorflow_lite_support/cc/task/core/task_utils.h"
#include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h"
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
#include "tensorflow_lite_support/cc/utils/common_utils.h"

namespace tflite {
namespace task {
namespace text {
namespace nlclassifier {

using ::absl::StatusCode;
using ::flatbuffers::Offset;
using ::flatbuffers::Vector;
using ::tflite::TensorMetadata;
using ::tflite::support::CreateStatusWithPayload;
using ::tflite::support::StatusOr;
using ::tflite::support::TfLiteSupportStatus;
using ::tflite::support::text::tokenizer::RegexTokenizer;
using ::tflite::support::text::tokenizer::Tokenizer;
using ::tflite::support::text::tokenizer::TokenizerResult;
using ::tflite::support::utils::LoadVocabFromBuffer;
using ::tflite::task::core::Category;
using ::tflite::task::core::Dequantize;
using ::tflite::task::core::GetStringAtIndex;
using ::tflite::task::core::PopulateTensor;

namespace {
constexpr int kRegexTokenizerInputTensorIndex = 0;
constexpr int kRegexTokenizerProcessUnitIndex = 0;
constexpr char kNoVersionInfo[] = "NO_VERSION_INFO";

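// Returns the contents of the first associated file in `associated_files`, or
// an error if the vector is empty or its first entry has no name.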
StatusOr<absl::string_view> CheckAndLoadFirstAssociatedFile(
    const flatbuffers::Vector<flatbuffers::Offset<tflite::AssociatedFile>>*
        associated_files,
    const tflite::metadata::ModelMetadataExtractor* metadata_extractor) {
  if (associated_files == nullptr || associated_files->size() < 1 ||
      associated_files->Get(0)->name() == nullptr) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "Invalid vocab_file from input process unit.",
        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
  }
  ASSIGN_OR_RETURN(absl::string_view vocab_buffer,
                   metadata_extractor->GetAssociatedFile(
                       associated_files->Get(0)->name()->str()));
  return vocab_buffer;
}

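// Builds a RegexTokenizer from the RegexTokenizerOptions in
// `tokenizer_process_unit`, validating that the vocabulary, the delimiter
// regex pattern, and the <UNKNOWN> and <PAD> tokens are all present.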
StatusOr<std::unique_ptr<RegexTokenizer>> CreateRegexTokenizerFromProcessUnit(
    const tflite::ProcessUnit* tokenizer_process_unit,
    const tflite::metadata::ModelMetadataExtractor* metadata_extractor) {
  if (metadata_extractor == nullptr || tokenizer_process_unit == nullptr) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "No metadata or input process unit found.",
        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
  }

  if (tokenizer_process_unit->options_type() !=
      ProcessUnitOptions_RegexTokenizerOptions) {
    return CreateStatusWithPayload(
        absl::StatusCode::kNotFound,
        absl::StrCat(
            "Incorrect options_type:", tokenizer_process_unit->options_type(),
            " need RegexTokenizerOptions."),
        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
  }

  const tflite::RegexTokenizerOptions* options =
      tokenizer_process_unit->options_as<RegexTokenizerOptions>();
  ASSIGN_OR_RETURN(absl::string_view vocab_buffer,
                   CheckAndLoadFirstAssociatedFile(options->vocab_file(),
                                                   metadata_extractor));
  if (options->delim_regex_pattern() == nullptr) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "Invalid delim_regex_pattern from input process unit.",
        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
  }

  std::unique_ptr<RegexTokenizer> regex_tokenizer =
      absl::make_unique<RegexTokenizer>(options->delim_regex_pattern()->str(),
                                        vocab_buffer.data(),
                                        vocab_buffer.size());

  int unknown_token_id = 0;
  if (!regex_tokenizer->GetUnknownToken(&unknown_token_id)) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "RegexTokenizer doesn't have <UNKNOWN> token.",
        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
  }

  int pad_token_id = 0;
  if (!regex_tokenizer->GetPadToken(&pad_token_id)) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "RegexTokenizer doesn't have <PAD> token.",
        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
  }
  return regex_tokenizer;
}

}  // namespace

const NLClassifierOptions& NLClassifier::GetOptions() const { return options_; }

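// Tries to populate labels_vector_ (and labels_version_) from the
// TENSOR_AXIS_LABELS file attached to the output tensor metadata.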
absl::Status NLClassifier::TrySetLabelFromMetadata(
    const TensorMetadata* metadata) {
  if (metadata == nullptr) {
    return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
                                   "Metadata not found for output tensor",
                                   TfLiteSupportStatus::kMetadataNotFoundError);
  }
  const auto* associated_files = metadata->associated_files();
  if (associated_files == nullptr || associated_files->size() == 0) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "No label file found for tensor metadata.",
        TfLiteSupportStatus::kMetadataMissingLabelsError);
  }
  const tflite::AssociatedFile* associated_file =
      associated_files->Get(kOutputTensorLabelFileIndex);
  if (associated_file->type() != AssociatedFileType_TENSOR_AXIS_LABELS) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "Incorrect label type found for tensor metadata.",
        TfLiteSupportStatus::kMetadataMissingLabelsError);
  }
  tflite::support::StatusOr<absl::string_view> label_buffer =
      GetMetadataExtractor()->GetAssociatedFile(
          associated_files->Get(kOutputTensorIndex)->name()->str());
  if (label_buffer.ok()) {
    labels_vector_ =
        absl::make_unique<std::vector<std::string>>(LoadVocabFromBuffer(
            label_buffer.value().data(), label_buffer.value().size()));
    if (associated_file->version() == nullptr) {
      labels_version_ = kNoVersionInfo;
    } else {
      labels_version_ = associated_file->version()->str();
    }
    return absl::OkStatus();
  } else {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "Failed to extract label file from metadata.",
        TfLiteSupportStatus::kMetadataMissingLabelsError);
  }
}

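// Convenience wrapper around ClassifyText() that swallows errors: on failure
// it returns an empty vector instead of a status.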
std::vector<Category> NLClassifier::Classify(const std::string& text) {
  StatusOr<std::vector<Category>> infer_result = ClassifyText(text);
  if (!infer_result.ok()) {
    return {};
  }
  return infer_result.value();
}

StatusOr<std::vector<Category>> NLClassifier::ClassifyText(
    const std::string& text) {
  return Infer(text);
}

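// Returns the model version recorded in the metadata, or kNoVersionInfo if
// the metadata carries no version.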
std::string NLClassifier::GetModelVersion() const {
  tflite::support::StatusOr<std::string> model_version =
      GetMetadataExtractor()->GetModelVersion();
  if (model_version.ok()) {
    return model_version.value();
  }
  return kNoVersionInfo;
}

std::string NLClassifier::GetLabelsVersion() const {
  return labels_version_;
}

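// Fills the input tensor from `input`: models with RegexTokenizer metadata
// receive a fixed-length vector of token ids; all other models receive the
// raw string.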
absl::Status NLClassifier::Preprocess(
    const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
  TfLiteTensor* input_tensor = FindTensorWithNameOrIndex(
      input_tensors, GetMetadataExtractor()->GetInputTensorMetadata(),
      options_.input_tensor_name, options_.input_tensor_index);
  if (input_tensor == nullptr) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "No input tensor found from NLClassifierOptions.",
        TfLiteSupportStatus::kInputTensorNotFoundError);
  }

  if (HasRegexTokenizerMetadata()) {
    //                |<----------sentence_length---------->|
    // input_tensor:  <START>, t1, t2, ...  <PAD>, <PAD>, ...
    // <START> is optional; tokens t1, t2, ... are replaced by <UNKNOWN> when
    // they are not found in the tokenizer vocab.
    TokenizerResult result = tokenizer_->Tokenize(input);

    size_t max_sentence_length = input_tensor->dims->size == 2
                                     ? input_tensor->dims->data[1]
                                     : input_tensor->dims->data[0];

    int unknown_token_id = 0;
    tokenizer_->GetUnknownToken(&unknown_token_id);

    int pad_token_id = 0;
    tokenizer_->GetPadToken(&pad_token_id);

    std::vector<int> input_tokens(max_sentence_length, pad_token_id);
    int start_token_id = 0;
    size_t input_token_index = 0;
    if (tokenizer_->GetStartToken(&start_token_id)) {
      input_tokens[0] = start_token_id;
      input_token_index = 1;
    }

    for (size_t i = 0; (i < result.subwords.size()) &&
                       (input_token_index < max_sentence_length);
         ++i, ++input_token_index) {
      const std::string& token = result.subwords[i];
      int token_id = 0;
      if (tokenizer_->LookupId(token, &token_id)) {
        input_tokens[input_token_index] = token_id;
      } else {
        input_tokens[input_token_index] = unknown_token_id;
      }
    }

    PopulateTensor(input_tokens, input_tensor);
  } else {
    PopulateTensor(input, input_tensor);
  }
  return absl::OkStatus();
}

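// Looks up the score (and optional label) tensors among the outputs and
// converts them into a list of categories.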
StatusOr<std::vector<Category>> NLClassifier::Postprocess(
    const std::vector<const TfLiteTensor*>& output_tensors,
    const std::string& /*input*/) {
  return BuildResults(
      FindTensorWithNameOrIndex(
          output_tensors, GetMetadataExtractor()->GetOutputTensorMetadata(),
          options_.output_score_tensor_name,
          options_.output_score_tensor_index),
      FindTensorWithNameOrIndex(
          output_tensors, GetMetadataExtractor()->GetOutputTensorMetadata(),
          options_.output_label_tensor_name,
          options_.output_label_tensor_index));
}

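// Pairs each score with a label, resolving labels from (in priority order)
// the label file loaded from metadata, the optional output label tensor, or
// the category index itself.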
std::vector<Category> NLClassifier::BuildResults(const TfLiteTensor* scores,
                                                 const TfLiteTensor* labels) {
  bool use_index_as_labels = (labels_vector_ == nullptr) && (labels == nullptr);
  // Some models output scores with transposed shape [1, categories].
  int categories =
      scores->dims->size == 2 ? scores->dims->data[1] : scores->dims->data[0];

  std::vector<Category> predictions;
  predictions.reserve(categories);

  bool should_dequantize = scores->type == kTfLiteUInt8 ||
                           scores->type == kTfLiteInt8 ||
                           scores->type == kTfLiteInt16;
  for (int index = 0; index < categories; index++) {
    std::string label;
    if (use_index_as_labels) {
      label = std::to_string(index);
    } else if (labels_vector_ == nullptr) {
      if (labels->type == kTfLiteString) {
        label = GetStringAtIndex(labels, index);
      } else if (labels->type == kTfLiteInt32) {
        label = std::to_string(GetTensorData<int>(labels)[index]);
      }
    } else {
      label = (*labels_vector_)[index];
    }
    if (should_dequantize) {
      predictions.push_back(Category(label, Dequantize(*scores, index)));
    } else if (scores->type == kTfLiteBool) {
      predictions.push_back(
          Category(label, GetTensorData<bool>(scores)[index] ? 1.0 : 0.0));
    } else {
      predictions.push_back(
          Category(label, scores->type == kTfLiteFloat32
                              ? GetTensorData<float>(scores)[index]
                              : GetTensorData<double>(scores)[index]));
    }
  }

  return predictions;
}
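
// Validates the input/output tensors against `options`, sets up the regex
// tokenizer when the metadata requests one, and loads the label file if the
// output tensor metadata provides it.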
absl::Status NLClassifier::Initialize(const NLClassifierOptions& options) {
  options_ = options;
  // The input tensor should be of type STRING.
  auto input_tensor = FindTensorWithNameOrIndex(
      GetInputTensors(), GetMetadataExtractor()->GetInputTensorMetadata(),
      options.input_tensor_name, options.input_tensor_index);
  if (input_tensor == nullptr) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrCat("No input tensor found with name ",
                     options.input_tensor_name, " or at index ",
                     options.input_tensor_index),
        TfLiteSupportStatus::kInputTensorNotFoundError);
  }
  if (HasRegexTokenizerMetadata()) {
    if (input_tensor->type != kTfLiteInt32) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrCat("Type mismatch for input tensor ", input_tensor->name,
                       ". Requested INT32, got ",
                       TfLiteTypeGetName(input_tensor->type), "."),
          TfLiteSupportStatus::kInvalidInputTensorTypeError);
    }
    RETURN_IF_ERROR(SetupRegexTokenizer());
  } else {
    if (input_tensor->type != kTfLiteString) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrCat("Type mismatch for input tensor ", input_tensor->name,
                       ". Requested STRING, got ",
                       TfLiteTypeGetName(input_tensor->type), "."),
          TfLiteSupportStatus::kInvalidInputTensorTypeError);
    }
  }

  // The output score tensor should be of type UINT8/INT8/INT16 (quantized),
  // FLOAT32/FLOAT64 (dequantized), or BOOL.
  std::vector<const TfLiteTensor*> output_tensors = GetOutputTensors();
  const Vector<Offset<TensorMetadata>>* output_tensor_metadatas =
      GetMetadataExtractor()->GetOutputTensorMetadata();

  const auto scores = FindTensorWithNameOrIndex(
      output_tensors, output_tensor_metadatas, options.output_score_tensor_name,
      options.output_score_tensor_index);
  if (scores == nullptr) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrCat("No output score tensor found with name ",
                     options.output_score_tensor_name, " or at index ",
                     options.output_score_tensor_index),
        TfLiteSupportStatus::kOutputTensorNotFoundError);
  }
  static constexpr TfLiteType valid_types[] = {kTfLiteUInt8,   kTfLiteInt8,
                                               kTfLiteInt16,   kTfLiteFloat32,
                                               kTfLiteFloat64, kTfLiteBool};
  if (!absl::c_linear_search(valid_types, scores->type)) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrCat("Type mismatch for score tensor ", scores->name,
                     ". Requested one of these types: "
                     "INT8/UINT8/INT16/FLOAT32/FLOAT64/BOOL, got ",
                     TfLiteTypeGetName(scores->type), "."),
        TfLiteSupportStatus::kInvalidOutputTensorTypeError);
  }

  // Extract the associated label file from the output score tensor if one
  // exists; well-formed metadata should have the same number of tensors as
  // the model.
  if (output_tensor_metadatas &&
      output_tensor_metadatas->size() == output_tensors.size()) {
    for (int i = 0; i < output_tensor_metadatas->size(); ++i) {
      const tflite::TensorMetadata* metadata = output_tensor_metadatas->Get(i);
      if ((metadata->name() && metadata->name()->string_view() ==
                                   options.output_score_tensor_name) ||
          i == options.output_score_tensor_index) {
        if (TrySetLabelFromMetadata(metadata).ok()) {
          return absl::OkStatus();
        }
      }
    }
  }

  // If labels_vector_ was not populated from metadata, try to register the
  // output label tensor from the options.
  if (labels_vector_ == nullptr) {
    // The output label tensor, if present, should be of type STRING or INT32.
    auto labels = FindTensorWithNameOrIndex(
        output_tensors, output_tensor_metadatas,
        options.output_label_tensor_name, options.output_label_tensor_index);
    if (labels != nullptr && labels->type != kTfLiteString &&
        labels->type != kTfLiteInt32) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          absl::StrCat("Type mismatch for label tensor ", labels->name,
                       ". Requested STRING or INT32, got ",
                       TfLiteTypeGetName(labels->type), "."),
          TfLiteSupportStatus::kInvalidOutputTensorTypeError);
    }
  }
  return absl::OkStatus();
}

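// The factory methods below create an NLClassifier from a flatbuffer held in
// memory, a model file path, or an open file descriptor, then validate it
// with Initialize() under the given options. Example usage, as a minimal
// sketch ("model.tflite" is a placeholder path and BuiltinOpResolver is one
// possible resolver choice):
//
//   auto classifier_or = NLClassifier::CreateFromFileAndOptions(
//       "model.tflite", NLClassifierOptions{},
//       absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
//   if (classifier_or.ok()) {
//     std::vector<Category> categories =
//         (*classifier_or)->Classify("input text");
//   }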
StatusOr<std::unique_ptr<NLClassifier>>
NLClassifier::CreateFromBufferAndOptions(
    const char* model_buffer_data, size_t model_buffer_size,
    const NLClassifierOptions& options,
    std::unique_ptr<tflite::OpResolver> resolver) {
  std::unique_ptr<NLClassifier> nl_classifier;
  ASSIGN_OR_RETURN(
      nl_classifier,
      core::TaskAPIFactory::CreateFromBuffer<NLClassifier>(
          model_buffer_data, model_buffer_size, std::move(resolver)));
  RETURN_IF_ERROR(nl_classifier->Initialize(options));
  return std::move(nl_classifier);
}

StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFileAndOptions(
    const std::string& path_to_model, const NLClassifierOptions& options,
    std::unique_ptr<tflite::OpResolver> resolver) {
  std::unique_ptr<NLClassifier> nl_classifier;
  ASSIGN_OR_RETURN(nl_classifier,
                   core::TaskAPIFactory::CreateFromFile<NLClassifier>(
                       path_to_model, std::move(resolver)));
  RETURN_IF_ERROR(nl_classifier->Initialize(options));
  return std::move(nl_classifier);
}

StatusOr<std::unique_ptr<NLClassifier>> NLClassifier::CreateFromFdAndOptions(
    int fd, const NLClassifierOptions& options,
    std::unique_ptr<tflite::OpResolver> resolver) {
  std::unique_ptr<NLClassifier> nl_classifier;
  ASSIGN_OR_RETURN(nl_classifier,
                   core::TaskAPIFactory::CreateFromFileDescriptor<NLClassifier>(
                       fd, std::move(resolver)));
  RETURN_IF_ERROR(nl_classifier->Initialize(options));
  return std::move(nl_classifier);
}

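// Returns true if the metadata of the first input tensor declares a
// RegexTokenizer process unit.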
bool NLClassifier::HasRegexTokenizerMetadata() {
  const TensorMetadata* input_tensor_metadata =
      GetMetadataExtractor()->GetInputTensorMetadata(
          kRegexTokenizerInputTensorIndex);
  if (input_tensor_metadata == nullptr) {
    return false;
  }
  tflite::support::StatusOr<const tflite::ProcessUnit*> status =
      GetMetadataExtractor()->FindFirstProcessUnit(
          *input_tensor_metadata, ProcessUnitOptions_RegexTokenizerOptions);
  return status.ok() ? status.value() != nullptr : false;
}

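// Builds tokenizer_ from the RegexTokenizer process unit declared in the
// input tensor metadata; only called when HasRegexTokenizerMetadata() is
// true.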
absl::Status NLClassifier::SetupRegexTokenizer() {
  ASSIGN_OR_RETURN(
      tokenizer_,
      CreateRegexTokenizerFromProcessUnit(
          GetMetadataExtractor()
              ->GetInputTensorMetadata(kRegexTokenizerInputTensorIndex)
              ->process_units()
              ->Get(kRegexTokenizerProcessUnitIndex),
          GetMetadataExtractor()));

  return absl::OkStatus();
}

}  // namespace nlclassifier
}  // namespace text
}  // namespace task
}  // namespace tflite