1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h"
17
18 #include <stddef.h>
19
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24
25 #include "absl/status/status.h"
26 #include "absl/strings/ascii.h"
27 #include "absl/strings/str_format.h"
28 #include "tensorflow/lite/c/common.h"
29 #include "tensorflow/lite/core/api/op_resolver.h"
30 #include "tensorflow/lite/string_type.h"
31 #include "tensorflow_lite_support/cc/common.h"
32 #include "tensorflow_lite_support/cc/port/status_macros.h"
33 #include "tensorflow_lite_support/cc/task/core/category.h"
34 #include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
35 #include "tensorflow_lite_support/cc/task/core/task_utils.h"
36 #include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h"
37 #include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
38 #include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h"
39 #include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
40
41 namespace tflite {
42 namespace task {
43 namespace text {
44 namespace nlclassifier {
45
46 using ::tflite::support::CreateStatusWithPayload;
47 using ::tflite::support::StatusOr;
48 using ::tflite::support::TfLiteSupportStatus;
49 using ::tflite::support::text::tokenizer::CreateTokenizerFromProcessUnit;
50 using ::tflite::support::text::tokenizer::TokenizerResult;
51 using ::tflite::task::core::FindTensorByName;
52 using ::tflite::task::core::PopulateTensor;
53
namespace {
// Metadata names of the three BERT input tensors, used to locate them via
// FindTensorByName in Preprocess/InitializeFromMetadata.
constexpr char kIdsTensorName[] = "ids";
constexpr char kMaskTensorName[] = "mask";
constexpr char kSegmentIdsTensorName[] = "segment_ids";
// Indices of the three input tensors, used when resizing dynamic inputs.
constexpr int kIdsTensorIndex = 0;
constexpr int kMaskTensorIndex = 1;
constexpr int kSegmentIdsTensorIndex = 2;
// Metadata name of the output tensor holding the classification scores.
constexpr char kScoreTensorName[] = "probability";
// Special tokens wrapped around the tokenized input: [CLS] <tokens> [SEP].
constexpr char kClassificationToken[] = "[CLS]";
constexpr char kSeparator[] = "[SEP]";
// Index of the tokenizer process unit in the input metadata.
constexpr int kTokenizerProcessUnitIndex = 0;
}  // namespace
66
Preprocess(const std::vector<TfLiteTensor * > & input_tensors,const std::string & input)67 absl::Status BertNLClassifier::Preprocess(
68 const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
69 auto* input_tensor_metadatas =
70 GetMetadataExtractor()->GetInputTensorMetadata();
71 auto* ids_tensor =
72 FindTensorByName(input_tensors, input_tensor_metadatas, kIdsTensorName);
73 auto* mask_tensor =
74 FindTensorByName(input_tensors, input_tensor_metadatas, kMaskTensorName);
75 auto* segment_ids_tensor = FindTensorByName(
76 input_tensors, input_tensor_metadatas, kSegmentIdsTensorName);
77
78 std::string processed_input = input;
79 absl::AsciiStrToLower(&processed_input);
80
81 TokenizerResult input_tokenize_results;
82 input_tokenize_results = tokenizer_->Tokenize(processed_input);
83
84 // Offset by 2 to account for [CLS] and [SEP]
85 int input_tokens_size =
86 static_cast<int>(input_tokenize_results.subwords.size()) + 2;
87 int input_tensor_length = input_tokens_size;
88 if (!input_tensors_are_dynamic_) {
89 input_tokens_size = std::min(kMaxSeqLen, input_tokens_size);
90 input_tensor_length = kMaxSeqLen;
91 } else {
92 GetTfLiteEngine()->interpreter()->ResizeInputTensorStrict(kIdsTensorIndex,
93 {1, input_tensor_length});
94 GetTfLiteEngine()->interpreter()->ResizeInputTensorStrict(kMaskTensorIndex,
95 {1, input_tensor_length});
96 GetTfLiteEngine()->interpreter()->ResizeInputTensorStrict(kSegmentIdsTensorIndex,
97 {1, input_tensor_length});
98 GetTfLiteEngine()->interpreter()->AllocateTensors();
99 }
100
101 std::vector<std::string> input_tokens;
102 input_tokens.reserve(input_tokens_size);
103 input_tokens.push_back(std::string(kClassificationToken));
104 for (int i = 0; i < input_tokens_size - 2; ++i) {
105 input_tokens.push_back(std::move(input_tokenize_results.subwords[i]));
106 }
107 input_tokens.push_back(std::string(kSeparator));
108
109 std::vector<int> input_ids(input_tensor_length, 0);
110 std::vector<int> input_mask(input_tensor_length, 0);
111 // Convert tokens back into ids and set mask
112 for (int i = 0; i < input_tokens.size(); ++i) {
113 tokenizer_->LookupId(input_tokens[i], &input_ids[i]);
114 input_mask[i] = 1;
115 }
116 // |<--------input_tensor_length------->|
117 // input_ids [CLS] s1 s2... sn [SEP] 0 0... 0
118 // input_masks 1 1 1... 1 1 0 0... 0
119 // segment_ids 0 0 0... 0 0 0 0... 0
120
121 PopulateTensor(input_ids, ids_tensor);
122 PopulateTensor(input_mask, mask_tensor);
123 PopulateTensor(std::vector<int>(input_tensor_length, 0), segment_ids_tensor);
124
125 return absl::OkStatus();
126 }
127
Postprocess(const std::vector<const TfLiteTensor * > & output_tensors,const std::string &)128 StatusOr<std::vector<core::Category>> BertNLClassifier::Postprocess(
129 const std::vector<const TfLiteTensor*>& output_tensors,
130 const std::string& /*input*/) {
131 if (output_tensors.size() != 1) {
132 return CreateStatusWithPayload(
133 absl::StatusCode::kInvalidArgument,
134 absl::StrFormat("BertNLClassifier models are expected to have only 1 "
135 "output, found %d",
136 output_tensors.size()),
137 TfLiteSupportStatus::kInvalidNumOutputTensorsError);
138 }
139 const TfLiteTensor* scores = FindTensorByName(
140 output_tensors, GetMetadataExtractor()->GetOutputTensorMetadata(),
141 kScoreTensorName);
142
143 // optional labels extracted from metadata
144 return BuildResults(scores, /*labels=*/nullptr);
145 }
146
147 StatusOr<std::unique_ptr<BertNLClassifier>>
CreateFromFile(const std::string & path_to_model_with_metadata,std::unique_ptr<tflite::OpResolver> resolver)148 BertNLClassifier::CreateFromFile(
149 const std::string& path_to_model_with_metadata,
150 std::unique_ptr<tflite::OpResolver> resolver) {
151 std::unique_ptr<BertNLClassifier> bert_nl_classifier;
152 ASSIGN_OR_RETURN(bert_nl_classifier,
153 core::TaskAPIFactory::CreateFromFile<BertNLClassifier>(
154 path_to_model_with_metadata, std::move(resolver)));
155 RETURN_IF_ERROR(bert_nl_classifier->InitializeFromMetadata());
156 return std::move(bert_nl_classifier);
157 }
158
159 StatusOr<std::unique_ptr<BertNLClassifier>>
CreateFromBuffer(const char * model_with_metadata_buffer_data,size_t model_with_metadata_buffer_size,std::unique_ptr<tflite::OpResolver> resolver)160 BertNLClassifier::CreateFromBuffer(
161 const char* model_with_metadata_buffer_data,
162 size_t model_with_metadata_buffer_size,
163 std::unique_ptr<tflite::OpResolver> resolver) {
164 std::unique_ptr<BertNLClassifier> bert_nl_classifier;
165 ASSIGN_OR_RETURN(bert_nl_classifier,
166 core::TaskAPIFactory::CreateFromBuffer<BertNLClassifier>(
167 model_with_metadata_buffer_data,
168 model_with_metadata_buffer_size, std::move(resolver)));
169 RETURN_IF_ERROR(bert_nl_classifier->InitializeFromMetadata());
170 return std::move(bert_nl_classifier);
171 }
172
CreateFromFd(int fd,std::unique_ptr<tflite::OpResolver> resolver)173 StatusOr<std::unique_ptr<BertNLClassifier>> BertNLClassifier::CreateFromFd(
174 int fd, std::unique_ptr<tflite::OpResolver> resolver) {
175 std::unique_ptr<BertNLClassifier> bert_nl_classifier;
176 ASSIGN_OR_RETURN(
177 bert_nl_classifier,
178 core::TaskAPIFactory::CreateFromFileDescriptor<BertNLClassifier>(
179 fd, std::move(resolver)));
180 RETURN_IF_ERROR(bert_nl_classifier->InitializeFromMetadata());
181 return std::move(bert_nl_classifier);
182 }
183
// One-time setup driven by the model metadata:
//   1. Builds the mandatory tokenizer from the input process unit.
//   2. Tries to load the optional label list for the output tensor.
//   3. Validates the three input tensors (ids / mask / segment_ids): rank 2,
//      batch size 1, equal sequence length.
//   4. Detects whether the inputs are dynamic (dims_signature length of -1)
//      and records the result in input_tensors_are_dynamic_.
absl::Status BertNLClassifier::InitializeFromMetadata() {
  // Set up mandatory tokenizer.
  const ProcessUnit* tokenizer_process_unit =
      GetMetadataExtractor()->GetInputProcessUnit(kTokenizerProcessUnitIndex);
  if (tokenizer_process_unit == nullptr) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "No input process unit found from metadata.",
        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
  }
  ASSIGN_OR_RETURN(tokenizer_,
                   CreateTokenizerFromProcessUnit(tokenizer_process_unit,
                                                  GetMetadataExtractor()));

  // Set up optional label vector. Failure is tolerated: models without a
  // label file still work, producing index-based categories.
  TrySetLabelFromMetadata(
      GetMetadataExtractor()->GetOutputTensorMetadata(kOutputTensorIndex))
      .IgnoreError();

  // Locate the three input tensors by their metadata names.
  // NOTE(review): the FindTensorByName results are dereferenced without null
  // checks — this assumes all three tensors exist in the model; verify
  // behavior for malformed models.
  auto* input_tensor_metadatas =
      GetMetadataExtractor()->GetInputTensorMetadata();
  const auto& input_tensors = GetInputTensors();
  const auto& ids_tensor = *FindTensorByName(input_tensors, input_tensor_metadatas,
                                             kIdsTensorName);
  const auto& mask_tensor = *FindTensorByName(input_tensors, input_tensor_metadatas,
                                              kMaskTensorName);
  const auto& segment_ids_tensor = *FindTensorByName(input_tensors, input_tensor_metadatas,
                                                     kSegmentIdsTensorName);
  // All three inputs must be rank 2, i.e. [batch, sequence_length].
  if (ids_tensor.dims->size != 2 || mask_tensor.dims->size != 2 ||
      segment_ids_tensor.dims->size != 2) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInternal,
        absl::StrFormat(
            "The three input tensors in Bert models are expected to have dim "
            "2, but got ids_tensor (%d), mask_tensor (%d), segment_ids_tensor "
            "(%d).",
            ids_tensor.dims->size, mask_tensor.dims->size,
            segment_ids_tensor.dims->size),
        TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
  }
  // The batch dimension must be 1 for all three inputs.
  if (ids_tensor.dims->data[0] != 1 || mask_tensor.dims->data[0] != 1 ||
      segment_ids_tensor.dims->data[0] != 1) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInternal,
        absl::StrFormat(
            "The three input tensors in Bert models are expected to have same "
            "batch size 1, but got ids_tensor (%d), mask_tensor (%d), "
            "segment_ids_tensor (%d).",
            ids_tensor.dims->data[0], mask_tensor.dims->data[0],
            segment_ids_tensor.dims->data[0]),
        TfLiteSupportStatus::kInvalidInputTensorSizeError);
  }
  // All three inputs must share the same sequence length.
  if (ids_tensor.dims->data[1] != mask_tensor.dims->data[1] ||
      ids_tensor.dims->data[1] != segment_ids_tensor.dims->data[1]) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInternal,
        absl::StrFormat("The three input tensors in Bert models are "
                        "expected to have same length, but got ids_tensor "
                        "(%d), mask_tensor (%d), segment_ids_tensor (%d).",
                        ids_tensor.dims->data[1], mask_tensor.dims->data[1],
                        segment_ids_tensor.dims->data[1]),
        TfLiteSupportStatus::kInvalidInputTensorSizeError);
  }

  // If some tensor does not have a size 2 dims_signature, then we
  // assume the input is not dynamic.
  if (ids_tensor.dims_signature->size != 2 ||
      mask_tensor.dims_signature->size != 2 ||
      segment_ids_tensor.dims_signature->size != 2) {
    return absl::OkStatus();
  }

  // A sequence-length signature of -1 means the dimension is dynamic. Either
  // all three inputs are dynamic, or none of them is; a mix is rejected.
  if (ids_tensor.dims_signature->data[1] == -1 &&
      mask_tensor.dims_signature->data[1] == -1 &&
      segment_ids_tensor.dims_signature->data[1] == -1) {
    input_tensors_are_dynamic_ = true;
  } else if (ids_tensor.dims_signature->data[1] == -1 ||
             mask_tensor.dims_signature->data[1] == -1 ||
             segment_ids_tensor.dims_signature->data[1] == -1) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInternal,
        "Input tensors contain a mix of static and dynamic tensors",
        TfLiteSupportStatus::kInvalidInputTensorSizeError);
  }

  return absl::OkStatus();
}
271
272 } // namespace nlclassifier
273 } // namespace text
274 } // namespace task
275 } // namespace tflite
276