/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_H_
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_H_

#include <stddef.h>
#include <string.h>

#include <memory>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/string_type.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/base_task_api.h"
#include "tensorflow_lite_support/cc/task/core/category.h"
#include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h"

namespace tflite {
namespace task {
namespace text {
namespace nlclassifier {

// Options to identify the input and output tensors of the model.
struct NLClassifierOptions {
  int input_tensor_index = 0;
  int output_score_tensor_index = 0;
  // By default there is no output label tensor. The label file can be attached
  // to the output score tensor metadata.
  int output_label_tensor_index = -1;
  std::string input_tensor_name = "INPUT";
  std::string output_score_tensor_name = "OUTPUT_SCORE";
  std::string output_label_tensor_name = "OUTPUT_LABEL";
};

// Classifier API for NLClassification tasks; categorizes a string into
// different classes.
//
// The API expects a TFLite model with the following input/output tensors:
//   Input tensor:
//     (kTfLiteString) - input of the model, accepts a string.
//       or
//     (kTfLiteInt32) - input of the model, accepts the tokenized indices of a
//       string input. A RegexTokenizer needs to be set up in the input
//       tensor's metadata.
//   Output score tensor:
//     (kTfLiteUInt8/kTfLiteInt8/kTfLiteInt16/kTfLiteFloat32/
//      kTfLiteFloat64/kTfLiteBool)
//     - output scores for each class; if the type is one of the Int types,
//       dequantize it to double; if the type is kTfLiteBool, convert the
//       values to 0.0 and 1.0 respectively.
//     - can have an optional associated file in metadata for labels; the file
//       should be a plain text file with one label per line, and the number
//       of labels should match the number of categories the model outputs.
//   Output label tensor: optional
//     (kTfLiteString/kTfLiteInt32)
//     - output classname for each class, should be of the same length as the
//       scores. If this tensor is not present, the API uses score indices as
//       classnames.
//     - will be ignored if the output score tensor already has an associated
//       label file.
//
// By default the API tries to find the input/output tensors with default
// configurations in NLClassifierOptions, with tensor name prioritized over
// tensor index. The option is configurable for different TFLite models.
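//
// A minimal usage sketch (the model path and input string below are
// illustrative placeholders; error handling is elided):
//
//   auto classifier_or =
//       NLClassifier::CreateFromFileAndOptions("/path/to/model.tflite");
//   if (classifier_or.ok()) {
//     auto categories_or =
//         classifier_or.value()->ClassifyText("This is a great movie.");
//     // On success, categories_or.value() holds one core::Category per class.
//   }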
class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
                                              const std::string&> {
 public:
  using BaseTaskApi::BaseTaskApi;

  // Creates an NLClassifier from a TFLite model buffer.
  static tflite::support::StatusOr<std::unique_ptr<NLClassifier>>
  CreateFromBufferAndOptions(
      const char* model_buffer_data, size_t model_buffer_size,
      const NLClassifierOptions& options = {},
      std::unique_ptr<tflite::OpResolver> resolver =
          absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());

  // Creates an NLClassifier from a TFLite model file.
  static tflite::support::StatusOr<std::unique_ptr<NLClassifier>>
  CreateFromFileAndOptions(
      const std::string& path_to_model,
      const NLClassifierOptions& options = {},
      std::unique_ptr<tflite::OpResolver> resolver =
          absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());

  // Creates an NLClassifier from a TFLite model file descriptor.
  static tflite::support::StatusOr<std::unique_ptr<NLClassifier>>
  CreateFromFdAndOptions(
      int fd, const NLClassifierOptions& options = {},
      std::unique_ptr<tflite::OpResolver> resolver =
          absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());

  // DEPRECATED (unannotated for backward compatibility). Prefer using
  // `ClassifyText`.
  std::vector<core::Category> Classify(const std::string& text);

  // Performs classification on a string input and returns the classified
  // results or an error.
  tflite::support::StatusOr<std::vector<core::Category>> ClassifyText(
      const std::string& text);

  // Gets the model version, or "NO_VERSION_INFO" if there is no version.
  std::string GetModelVersion() const;

  // Gets the labels version, or "NO_VERSION_INFO" if there is no version.
  std::string GetLabelsVersion() const;

 protected:
  static constexpr int kOutputTensorIndex = 0;
  static constexpr int kOutputTensorLabelFileIndex = 0;

  absl::Status Initialize(const NLClassifierOptions& options);
  const NLClassifierOptions& GetOptions() const;

  // Tries to extract an attached label file from metadata and initialize
  // labels_vector_; returns an error if the metadata type is incorrect or no
  // label file is attached in metadata.
  absl::Status TrySetLabelFromMetadata(const TensorMetadata* metadata);

  // Passes the input text through into the model's input tensor.
  absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
                          const std::string& input) override;

  // Extracts the model output and creates results with the output label tensor
  // or the label file attached in metadata. If no output label tensor or label
  // file is found, uses the output score indices as labels.
  tflite::support::StatusOr<std::vector<core::Category>> Postprocess(
      const std::vector<const TfLiteTensor*>& output_tensors,
      const std::string& input) override;

  // Builds the category results from the score tensor and the optional label
  // tensor.
  std::vector<core::Category> BuildResults(const TfLiteTensor* scores,
                                           const TfLiteTensor* labels);

  // Gets the tensor from a vector of tensors by checking the tensor name first
  // and the tensor index second; returns nullptr if no tensor is found.
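  // For illustration (with hypothetical metadata names "ids" and "mask"):
  // looking up name == "mask" returns the second tensor; if neither the
  // metadata nor the tensors carry a matching name, the tensor at `index` is
  // returned when that index is in range, and nullptr otherwise.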
  template <typename TensorType>
  static TensorType* FindTensorWithNameOrIndex(
      const std::vector<TensorType*>& tensors,
      const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>*
          metadata_array,
      const std::string& name, int index) {
    if (metadata_array != nullptr && metadata_array->size() == tensors.size()) {
      for (int i = 0; i < metadata_array->size(); i++) {
        if (strcmp(name.data(), metadata_array->Get(i)->name()->c_str()) == 0) {
          return tensors[i];
        }
      }
    }

    for (TensorType* tensor : tensors) {
      if (tensor->name == name) {
        return tensor;
      }
    }
    return index >= 0 && index < tensors.size() ? tensors[index] : nullptr;
  }

 private:
  bool HasRegexTokenizerMetadata();
  absl::Status SetupRegexTokenizer();

  NLClassifierOptions options_;
  // Labels vector initialized from the output tensor's associated file, if one
  // exists.
  std::unique_ptr<std::vector<std::string>> labels_vector_;
  // Labels version assigned from the output tensor's associated file metadata,
  // if one exists.
  std::string labels_version_;
  std::unique_ptr<tflite::support::text::tokenizer::RegexTokenizer> tokenizer_;
};

}  // namespace nlclassifier
}  // namespace text
}  // namespace task
}  // namespace tflite

#endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_H_