/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.h"

#include <stddef.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_format.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/string_type.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/task/core/category.h"
#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
#include "tensorflow_lite_support/cc/task/core/task_utils.h"
#include "tensorflow_lite_support/cc/task/text/nlclassifier/nl_classifier.h"
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h"
#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"

namespace tflite {
namespace task {
namespace text {
namespace nlclassifier {

using ::tflite::support::CreateStatusWithPayload;
using ::tflite::support::StatusOr;
using ::tflite::support::TfLiteSupportStatus;
using ::tflite::support::text::tokenizer::CreateTokenizerFromProcessUnit;
using ::tflite::support::text::tokenizer::TokenizerResult;
using ::tflite::task::core::FindTensorByName;
using ::tflite::task::core::PopulateTensor;

namespace {
constexpr char kIdsTensorName[] = "ids";
constexpr char kMaskTensorName[] = "mask";
constexpr char kSegmentIdsTensorName[] = "segment_ids";
constexpr int kIdsTensorIndex = 0;
constexpr int kMaskTensorIndex = 1;
constexpr int kSegmentIdsTensorIndex = 2;
constexpr char kScoreTensorName[] = "probability";
constexpr char kClassificationToken[] = "[CLS]";
constexpr char kSeparator[] = "[SEP]";
constexpr int kTokenizerProcessUnitIndex = 0;
}  // namespace

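// Preprocess flow, as implemented below: lowercase the input, run the
// tokenizer created from model metadata, wrap the resulting subwords with
// [CLS]/[SEP], then fill the ids, mask, and segment_ids input tensors. With
// static input tensors the sequence is truncated/padded to kMaxSeqLen; with
// dynamic tensors the interpreter inputs are resized to the actual token
// count instead.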
absl::Status BertNLClassifier::Preprocess(
    const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
  auto* input_tensor_metadatas =
      GetMetadataExtractor()->GetInputTensorMetadata();
  auto* ids_tensor =
      FindTensorByName(input_tensors, input_tensor_metadatas, kIdsTensorName);
  auto* mask_tensor =
      FindTensorByName(input_tensors, input_tensor_metadatas, kMaskTensorName);
  auto* segment_ids_tensor = FindTensorByName(
      input_tensors, input_tensor_metadatas, kSegmentIdsTensorName);

  std::string processed_input = input;
  absl::AsciiStrToLower(&processed_input);

  TokenizerResult input_tokenize_results =
      tokenizer_->Tokenize(processed_input);

  // Offset by 2 to account for [CLS] and [SEP].
  int input_tokens_size =
      static_cast<int>(input_tokenize_results.subwords.size()) + 2;
  int input_tensor_length = input_tokens_size;
  if (!input_tensors_are_dynamic_) {
    // Static input tensors: truncate the tokens to the fixed sequence length
    // and pad the remainder of each tensor with zeros.
    input_tokens_size = std::min(kMaxSeqLen, input_tokens_size);
    input_tensor_length = kMaxSeqLen;
  } else {
    // Dynamic input tensors: resize all three inputs to the actual token
    // count and reallocate before populating them.
    auto* interpreter = GetTfLiteEngine()->interpreter();
    interpreter->ResizeInputTensorStrict(kIdsTensorIndex,
                                         {1, input_tensor_length});
    interpreter->ResizeInputTensorStrict(kMaskTensorIndex,
                                         {1, input_tensor_length});
    interpreter->ResizeInputTensorStrict(kSegmentIdsTensorIndex,
                                         {1, input_tensor_length});
    interpreter->AllocateTensors();
  }

  std::vector<std::string> input_tokens;
  input_tokens.reserve(input_tokens_size);
  input_tokens.push_back(std::string(kClassificationToken));
  for (int i = 0; i < input_tokens_size - 2; ++i) {
    input_tokens.push_back(std::move(input_tokenize_results.subwords[i]));
  }
  input_tokens.push_back(std::string(kSeparator));

  std::vector<int> input_ids(input_tensor_length, 0);
  std::vector<int> input_mask(input_tensor_length, 0);
  // Convert tokens back into ids and set the mask.
  for (size_t i = 0; i < input_tokens.size(); ++i) {
    tokenizer_->LookupId(input_tokens[i], &input_ids[i]);
    input_mask[i] = 1;
  }
  //                           |<--------input_tensor_length------->|
  // input_ids                 [CLS] s1  s2...  sn [SEP]  0  0...  0
  // input_masks                 1    1   1...  1    1    0  0...  0
  // segment_ids                 0    0   0...  0    0    0  0...  0

  PopulateTensor(input_ids, ids_tensor);
  PopulateTensor(input_mask, mask_tensor);
  PopulateTensor(std::vector<int>(input_tensor_length, 0), segment_ids_tensor);

  return absl::OkStatus();
}

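// Postprocess validates that the model produced exactly one output tensor,
// locates the score tensor by its metadata name, and converts the raw scores
// into a list of categories via BuildResults.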
StatusOr<std::vector<core::Category>> BertNLClassifier::Postprocess(
    const std::vector<const TfLiteTensor*>& output_tensors,
    const std::string& /*input*/) {
  if (output_tensors.size() != 1) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat("BertNLClassifier models are expected to have only 1 "
                        "output, found %d",
                        output_tensors.size()),
        TfLiteSupportStatus::kInvalidNumOutputTensorsError);
  }
  const TfLiteTensor* scores = FindTensorByName(
      output_tensors, GetMetadataExtractor()->GetOutputTensorMetadata(),
      kScoreTensorName);

  // BuildResults picks up the optional label vector extracted from metadata.
  return BuildResults(scores, /*labels=*/nullptr);
}

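// A minimal usage sketch (hypothetical model path; assumes the resolver
// argument has a default in the header and that Classify() is inherited from
// NLClassifier):
//
//   auto classifier_or =
//       BertNLClassifier::CreateFromFile("/path/to/model.tflite");
//   if (classifier_or.ok()) {
//     std::vector<core::Category> categories =
//         classifier_or.value()->Classify("hello world");
//   }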
StatusOr<std::unique_ptr<BertNLClassifier>>
BertNLClassifier::CreateFromFile(
    const std::string& path_to_model_with_metadata,
    std::unique_ptr<tflite::OpResolver> resolver) {
  std::unique_ptr<BertNLClassifier> bert_nl_classifier;
  ASSIGN_OR_RETURN(bert_nl_classifier,
                   core::TaskAPIFactory::CreateFromFile<BertNLClassifier>(
                       path_to_model_with_metadata, std::move(resolver)));
  RETURN_IF_ERROR(bert_nl_classifier->InitializeFromMetadata());
  return std::move(bert_nl_classifier);
}

StatusOr<std::unique_ptr<BertNLClassifier>>
BertNLClassifier::CreateFromBuffer(
    const char* model_with_metadata_buffer_data,
    size_t model_with_metadata_buffer_size,
    std::unique_ptr<tflite::OpResolver> resolver) {
  std::unique_ptr<BertNLClassifier> bert_nl_classifier;
  ASSIGN_OR_RETURN(bert_nl_classifier,
                   core::TaskAPIFactory::CreateFromBuffer<BertNLClassifier>(
                       model_with_metadata_buffer_data,
                       model_with_metadata_buffer_size, std::move(resolver)));
  RETURN_IF_ERROR(bert_nl_classifier->InitializeFromMetadata());
  return std::move(bert_nl_classifier);
}

StatusOr<std::unique_ptr<BertNLClassifier>> BertNLClassifier::CreateFromFd(
    int fd, std::unique_ptr<tflite::OpResolver> resolver) {
  std::unique_ptr<BertNLClassifier> bert_nl_classifier;
  ASSIGN_OR_RETURN(
      bert_nl_classifier,
      core::TaskAPIFactory::CreateFromFileDescriptor<BertNLClassifier>(
          fd, std::move(resolver)));
  RETURN_IF_ERROR(bert_nl_classifier->InitializeFromMetadata());
  return std::move(bert_nl_classifier);
}

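// InitializeFromMetadata wires up the tokenizer declared in the model
// metadata, optionally loads the output label vector, and validates the three
// BERT input tensors: each must be rank 2 with batch size 1 and a shared
// sequence length. It also records whether all inputs are dynamic, which
// Preprocess uses to decide between padding to kMaxSeqLen and resizing the
// inputs.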
absl::Status BertNLClassifier::InitializeFromMetadata() {
  // Set up mandatory tokenizer.
  const ProcessUnit* tokenizer_process_unit =
      GetMetadataExtractor()->GetInputProcessUnit(kTokenizerProcessUnitIndex);
  if (tokenizer_process_unit == nullptr) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "No input process unit found from metadata.",
        TfLiteSupportStatus::kMetadataInvalidTokenizerError);
  }
  ASSIGN_OR_RETURN(tokenizer_,
                   CreateTokenizerFromProcessUnit(tokenizer_process_unit,
                                                  GetMetadataExtractor()));

  // Set up optional label vector.
  TrySetLabelFromMetadata(
      GetMetadataExtractor()->GetOutputTensorMetadata(kOutputTensorIndex))
      .IgnoreError();

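  // Validate the shapes of the three input tensors, located by their metadata
  // names. The checks below operate on dims (the allocated shape); the
  // dynamic-tensor detection further down inspects dims_signature instead.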
  auto* input_tensor_metadatas =
      GetMetadataExtractor()->GetInputTensorMetadata();
  const auto& input_tensors = GetInputTensors();
  const auto& ids_tensor = *FindTensorByName(
      input_tensors, input_tensor_metadatas, kIdsTensorName);
  const auto& mask_tensor = *FindTensorByName(
      input_tensors, input_tensor_metadatas, kMaskTensorName);
  const auto& segment_ids_tensor = *FindTensorByName(
      input_tensors, input_tensor_metadatas, kSegmentIdsTensorName);
  if (ids_tensor.dims->size != 2 || mask_tensor.dims->size != 2 ||
      segment_ids_tensor.dims->size != 2) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInternal,
        absl::StrFormat(
            "The three input tensors in Bert models are expected to have dim "
            "2, but got ids_tensor (%d), mask_tensor (%d), segment_ids_tensor "
            "(%d).",
            ids_tensor.dims->size, mask_tensor.dims->size,
            segment_ids_tensor.dims->size),
        TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
  }
  if (ids_tensor.dims->data[0] != 1 || mask_tensor.dims->data[0] != 1 ||
      segment_ids_tensor.dims->data[0] != 1) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInternal,
        absl::StrFormat(
            "The three input tensors in Bert models are expected to have same "
            "batch size 1, but got ids_tensor (%d), mask_tensor (%d), "
            "segment_ids_tensor (%d).",
            ids_tensor.dims->data[0], mask_tensor.dims->data[0],
            segment_ids_tensor.dims->data[0]),
        TfLiteSupportStatus::kInvalidInputTensorSizeError);
  }
  if (ids_tensor.dims->data[1] != mask_tensor.dims->data[1] ||
      ids_tensor.dims->data[1] != segment_ids_tensor.dims->data[1]) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInternal,
        absl::StrFormat("The three input tensors in Bert models are "
                        "expected to have same length, but got ids_tensor "
                        "(%d), mask_tensor (%d), segment_ids_tensor (%d).",
                        ids_tensor.dims->data[1], mask_tensor.dims->data[1],
                        segment_ids_tensor.dims->data[1]),
        TfLiteSupportStatus::kInvalidInputTensorSizeError);
  }

  // If any input tensor lacks a rank-2 dims_signature, assume the inputs are
  // static (fixed-length).
  if (ids_tensor.dims_signature->size != 2 ||
      mask_tensor.dims_signature->size != 2 ||
      segment_ids_tensor.dims_signature->size != 2) {
    return absl::OkStatus();
  }

  // A -1 in the sequence dimension of dims_signature marks a dynamic tensor;
  // either all three inputs must be dynamic or none of them may be.
  if (ids_tensor.dims_signature->data[1] == -1 &&
      mask_tensor.dims_signature->data[1] == -1 &&
      segment_ids_tensor.dims_signature->data[1] == -1) {
    input_tensors_are_dynamic_ = true;
  } else if (ids_tensor.dims_signature->data[1] == -1 ||
             mask_tensor.dims_signature->data[1] == -1 ||
             segment_ids_tensor.dims_signature->data[1] == -1) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInternal,
        "Input tensors contain a mix of static and dynamic tensors",
        TfLiteSupportStatus::kInvalidInputTensorSizeError);
  }

  return absl::OkStatus();
}

}  // namespace nlclassifier
}  // namespace text
}  // namespace task
}  // namespace tflite