xref: /aosp_15_r20/external/tflite-support/tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h (revision b16991f985baa50654c05c5adbb3c8bbcfb40082)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_H_
17 #define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_H_
18 
19 #include <stddef.h>
20 #include <string.h>
21 
22 #include <memory>
23 #include <string>
24 #include <vector>
25 
26 #include "absl/status/status.h"
27 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
28 #include "tensorflow/lite/c/common.h"
29 #include "tensorflow/lite/core/api/op_resolver.h"
30 #include "tensorflow/lite/kernels/register.h"
31 #include "tensorflow/lite/string_type.h"
32 #include "tensorflow_lite_support/cc/common.h"
33 #include "tensorflow_lite_support/cc/port/statusor.h"
34 #include "tensorflow_lite_support/cc/task/core/base_task_api.h"
35 #include "tensorflow_lite_support/cc/task/core/category.h"
36 #include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h"
37 
38 namespace tflite {
39 namespace task {
40 namespace text {
41 namespace nlclassifier {
42 
// Configuration telling NLClassifier which model tensors carry the input text,
// the output scores and (optionally) the output labels. Every tensor can be
// designated by name and by index; all fields have defaults.
struct NLClassifierOptions {
  // Position of the input tensor among the model inputs.
  int input_tensor_index{0};
  // Position of the score tensor among the model outputs.
  int output_score_tensor_index{0};
  // Position of the label tensor among the model outputs. The default of -1
  // means "no output label tensor"; labels can instead come from a label file
  // attached to the output score tensor metadata.
  int output_label_tensor_index{-1};
  // Name of the input tensor.
  std::string input_tensor_name{"INPUT"};
  // Name of the output score tensor.
  std::string output_score_tensor_name{"OUTPUT_SCORE"};
  // Name of the output label tensor.
  std::string output_label_tensor_name{"OUTPUT_LABEL"};
};
54 
55 // Classifier API for NLClassification tasks, categorizes string into different
56 // classes.
57 //
58 // The API expects a TFLite model with the following input/output tensor:
59 // Input tensor:
60 //   (kTfLiteString) - input of the model, accepts a string.
61 //      or
62 //   (kTfLiteInt32) - input of the model, accepts a tokenized
63 //   indices of a string input. A RegexTokenizer needs to be set up in the input
64 //   tensor's metadata.
65 // Output score tensor:
66 //   (kTfLiteUInt8/kTfLiteInt8/kTfLiteInt16/kTfLiteFloat32/
67 //    kTfLiteFloat64/kTfLiteBool)
68 //    - output scores for each class, if type is one of the Int types,
69 //      dequantize it to double, if type is kTfLiteBool, convert the values to
70 //      0.0 and 1.0 respectively
71 //    - can have an optional associated file in metadata for labels, the file
72 //      should be a plain text file with one label per line, the number of
73 //      labels should match the number of categories the model outputs.
74 // Output label tensor: optional
75 //   (kTfLiteString/kTfLiteInt32)
76 //    - output classname for each class, should be of the same length with
77 //      scores. If this tensor is not present, the API uses score indices as
78 //      classnames.
79 //    - will be ignored if output score tensor already has an associated label
80 //      file.
81 //
82 // By default the API tries to find the input/output tensors with default
83 // configurations in NLClassifierOptions, with tensor name prioritized over
84 // tensor index. The option is configurable for different TFLite models.
85 class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
86                                               const std::string&> {
87  public:
88   using BaseTaskApi::BaseTaskApi;
89 
90   // Creates a NLClassifier from TFLite model buffer.
91   static tflite::support::StatusOr<std::unique_ptr<NLClassifier>>
92   CreateFromBufferAndOptions(
93       const char* model_buffer_data, size_t model_buffer_size,
94       const NLClassifierOptions& options = {},
95       std::unique_ptr<tflite::OpResolver> resolver =
96           absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
97 
98   // Creates a NLClassifier from TFLite model file.
99   static tflite::support::StatusOr<std::unique_ptr<NLClassifier>>
100   CreateFromFileAndOptions(
101       const std::string& path_to_model, const NLClassifierOptions& options = {},
102       std::unique_ptr<tflite::OpResolver> resolver =
103           absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
104 
105   // Creates a NLClassifier from TFLite model file descriptor.
106   static tflite::support::StatusOr<std::unique_ptr<NLClassifier>>
107   CreateFromFdAndOptions(
108       int fd, const NLClassifierOptions& options = {},
109       std::unique_ptr<tflite::OpResolver> resolver =
110           absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
111 
112   // DEPRECATED (unannotated for backward compatibility).  Prefer using `ClassifyText`.
113   std::vector<core::Category> Classify(const std::string& text);
114 
115   // Performs classification on a string input, returns classified results or an
116   // error.
117   tflite::support::StatusOr<std::vector<core::Category>> ClassifyText(
118       const std::string& text);
119 
120   // Gets the model version, or "NO_VERSION_INFO" in case there is no version.
121   std::string GetModelVersion() const;
122 
123   // Gets the labels version, or "NO_VERSION_INFO" in case there is no version.
124   std::string GetLabelsVersion() const;
125 
126  protected:
127   static constexpr int kOutputTensorIndex = 0;
128   static constexpr int kOutputTensorLabelFileIndex = 0;
129 
130   absl::Status Initialize(const NLClassifierOptions& options);
131   const NLClassifierOptions& GetOptions() const;
132 
133   // Try to extract attached label file from metadata and initialize
134   // labels_vector_, return error if metadata type is incorrect or no label file
135   // is attached in metadata.
136   absl::Status TrySetLabelFromMetadata(const TensorMetadata* metadata);
137 
138   // Pass through the input text into model's input tensor.
139   absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
140                           const std::string& input) override;
141 
142   // Extract model output and create results with output label tensor or label
143   // file attached in metadata. If no output label tensor or label file is
144   // found, use output score index as labels.
145   tflite::support::StatusOr<std::vector<core::Category>> Postprocess(
146       const std::vector<const TfLiteTensor*>& output_tensors,
147       const std::string& input) override;
148 
149   std::vector<core::Category> BuildResults(const TfLiteTensor* scores,
150                                            const TfLiteTensor* labels);
151 
152   // Gets the tensor from a vector of tensors by checking tensor name first and
153   // tensor index second, return nullptr if no tensor is found.
154   template <typename TensorType>
FindTensorWithNameOrIndex(const std::vector<TensorType * > & tensors,const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>> * metadata_array,const std::string & name,int index)155   static TensorType* FindTensorWithNameOrIndex(
156       const std::vector<TensorType*>& tensors,
157       const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>*
158           metadata_array,
159       const std::string& name, int index) {
160     if (metadata_array != nullptr && metadata_array->size() == tensors.size()) {
161       for (int i = 0; i < metadata_array->size(); i++) {
162         if (strcmp(name.data(), metadata_array->Get(i)->name()->c_str()) == 0) {
163           return tensors[i];
164         }
165       }
166     }
167 
168     for (TensorType* tensor : tensors) {
169       if (tensor->name == name) {
170         return tensor;
171       }
172     }
173     return index >= 0 && index < tensors.size() ? tensors[index] : nullptr;
174   }
175 
176  private:
177   bool HasRegexTokenizerMetadata();
178   absl::Status SetupRegexTokenizer();
179 
180   NLClassifierOptions options_;
181   // labels vector initialized from output tensor's associated file, if one
182   // exists.
183   std::unique_ptr<std::vector<std::string>> labels_vector_;
184   // labels version assigned from output tensor's associated file metadata,
185   // if one exists.
186   std::string labels_version_;
187   std::unique_ptr<tflite::support::text::tokenizer::RegexTokenizer> tokenizer_;
188 };
189 
190 }  // namespace nlclassifier
191 }  // namespace text
192 }  // namespace task
193 }  // namespace tflite
194 
195 #endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_H_
196