/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h"
17
18 #include "absl/strings/str_join.h"
19 #include "absl/strings/str_split.h"
20 #include "tensorflow_lite_support/cc/port/status_macros.h"
21 #include "tensorflow_lite_support/cc/task/core/task_utils.h"
22 #include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
23 #include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h"
24 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
25
namespace tflite {
namespace task {
namespace text {
namespace qa {

// Tensor names used to locate the model's input/output tensors via the
// metadata. When metadata is absent, Preprocess()/Postprocess() fall back
// to a fixed tensor order instead.
constexpr char kIdsTensorName[] = "ids";
constexpr char kMaskTensorName[] = "mask";
constexpr char kSegmentIdsTensorName[] = "segment_ids";
constexpr char kEndLogitsTensorName[] = "end_logits";
constexpr char kStartLogitsTensorName[] = "start_logits";

using ::tflite::support::CreateStatusWithPayload;
using ::tflite::support::StatusOr;
using ::tflite::support::TfLiteSupportStatus;
using ::tflite::support::text::tokenizer::BertTokenizer;
using ::tflite::support::text::tokenizer::CreateTokenizerFromProcessUnit;
using ::tflite::support::text::tokenizer::SentencePieceTokenizer;
using ::tflite::support::text::tokenizer::TokenizerResult;
using ::tflite::task::core::FindTensorByName;
using ::tflite::task::core::PopulateTensor;
using ::tflite::task::core::PopulateVector;
using ::tflite::task::core::ReverseSortIndices;

namespace {
// Index of the tokenizer process unit in the model metadata's input
// process units.
constexpr int kTokenizerProcessUnitIndex = 0;
}
52
53 StatusOr<std::unique_ptr<QuestionAnswerer>>
CreateFromFile(const std::string & path_to_model_with_metadata)54 BertQuestionAnswerer::CreateFromFile(
55 const std::string& path_to_model_with_metadata) {
56 std::unique_ptr<BertQuestionAnswerer> api_to_init;
57 ASSIGN_OR_RETURN(
58 api_to_init,
59 core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>(
60 path_to_model_with_metadata,
61 absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
62 kNumLiteThreads));
63 RETURN_IF_ERROR(api_to_init->InitializeFromMetadata());
64 return api_to_init;
65 }
66
67 StatusOr<std::unique_ptr<QuestionAnswerer>>
CreateFromBuffer(const char * model_with_metadata_buffer_data,size_t model_with_metadata_buffer_size)68 BertQuestionAnswerer::CreateFromBuffer(
69 const char* model_with_metadata_buffer_data,
70 size_t model_with_metadata_buffer_size) {
71 std::unique_ptr<BertQuestionAnswerer> api_to_init;
72 ASSIGN_OR_RETURN(
73 api_to_init,
74 core::TaskAPIFactory::CreateFromBuffer<BertQuestionAnswerer>(
75 model_with_metadata_buffer_data, model_with_metadata_buffer_size,
76 absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
77 kNumLiteThreads));
78 RETURN_IF_ERROR(api_to_init->InitializeFromMetadata());
79 return api_to_init;
80 }
81
CreateFromFd(int fd)82 StatusOr<std::unique_ptr<QuestionAnswerer>> BertQuestionAnswerer::CreateFromFd(
83 int fd) {
84 std::unique_ptr<BertQuestionAnswerer> api_to_init;
85 ASSIGN_OR_RETURN(
86 api_to_init,
87 core::TaskAPIFactory::CreateFromFileDescriptor<BertQuestionAnswerer>(
88 fd, absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
89 kNumLiteThreads));
90 RETURN_IF_ERROR(api_to_init->InitializeFromMetadata());
91 return api_to_init;
92 }
93
94 StatusOr<std::unique_ptr<QuestionAnswerer>>
CreateBertQuestionAnswererFromFile(const std::string & path_to_model,const std::string & path_to_vocab)95 BertQuestionAnswerer::CreateBertQuestionAnswererFromFile(
96 const std::string& path_to_model, const std::string& path_to_vocab) {
97 std::unique_ptr<BertQuestionAnswerer> api_to_init;
98 ASSIGN_OR_RETURN(
99 api_to_init,
100 core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>(
101 path_to_model,
102 absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
103 kNumLiteThreads));
104 api_to_init->InitializeBertTokenizer(path_to_vocab);
105 return api_to_init;
106 }
107
108 StatusOr<std::unique_ptr<QuestionAnswerer>>
CreateBertQuestionAnswererFromBuffer(const char * model_buffer_data,size_t model_buffer_size,const char * vocab_buffer_data,size_t vocab_buffer_size)109 BertQuestionAnswerer::CreateBertQuestionAnswererFromBuffer(
110 const char* model_buffer_data, size_t model_buffer_size,
111 const char* vocab_buffer_data, size_t vocab_buffer_size) {
112 std::unique_ptr<BertQuestionAnswerer> api_to_init;
113 ASSIGN_OR_RETURN(
114 api_to_init,
115 core::TaskAPIFactory::CreateFromBuffer<BertQuestionAnswerer>(
116 model_buffer_data, model_buffer_size,
117 absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
118 kNumLiteThreads));
119 api_to_init->InitializeBertTokenizerFromBinary(vocab_buffer_data,
120 vocab_buffer_size);
121 return api_to_init;
122 }
123
124 StatusOr<std::unique_ptr<QuestionAnswerer>>
CreateAlbertQuestionAnswererFromFile(const std::string & path_to_model,const std::string & path_to_spmodel)125 BertQuestionAnswerer::CreateAlbertQuestionAnswererFromFile(
126 const std::string& path_to_model, const std::string& path_to_spmodel) {
127 std::unique_ptr<BertQuestionAnswerer> api_to_init;
128 ASSIGN_OR_RETURN(
129 api_to_init,
130 core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>(
131 path_to_model,
132 absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
133 kNumLiteThreads));
134 api_to_init->InitializeSentencepieceTokenizer(path_to_spmodel);
135 return api_to_init;
136 }
137
138 StatusOr<std::unique_ptr<QuestionAnswerer>>
CreateAlbertQuestionAnswererFromBuffer(const char * model_buffer_data,size_t model_buffer_size,const char * spmodel_buffer_data,size_t spmodel_buffer_size)139 BertQuestionAnswerer::CreateAlbertQuestionAnswererFromBuffer(
140 const char* model_buffer_data, size_t model_buffer_size,
141 const char* spmodel_buffer_data, size_t spmodel_buffer_size) {
142 std::unique_ptr<BertQuestionAnswerer> api_to_init;
143 ASSIGN_OR_RETURN(
144 api_to_init,
145 core::TaskAPIFactory::CreateFromBuffer<BertQuestionAnswerer>(
146 model_buffer_data, model_buffer_size,
147 absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
148 kNumLiteThreads));
149 api_to_init->InitializeSentencepieceTokenizerFromBinary(spmodel_buffer_data,
150 spmodel_buffer_size);
151 return api_to_init;
152 }
153
Answer(const std::string & context,const std::string & question)154 std::vector<QaAnswer> BertQuestionAnswerer::Answer(
155 const std::string& context, const std::string& question) {
156 // The BertQuestionAnswererer implementation for Preprocess() and
157 // Postprocess() never returns errors: just call value().
158 return Infer(context, question).value();
159 }
160
// Converts the (context, query) pair into BERT input features and writes
// them into the model's input tensors.
//
// Feature layout: [CLS] query tokens [SEP] context sub-tokens [SEP],
// zero-padded to kMaxSeqLen. Also populates two members consumed later by
// Postprocess()/ConvertIndexToString():
//   - orig_tokens_: the context split on spaces (answers are recovered
//     from these original-cased words);
//   - token_to_orig_map_: feature-token position -> index into orig_tokens_.
absl::Status BertQuestionAnswerer::Preprocess(
    const std::vector<TfLiteTensor*>& input_tensors, const std::string& context,
    const std::string& query) {
  auto* input_tensor_metadatas =
      GetMetadataExtractor()->GetInputTensorMetadata();
  // With metadata present, locate tensors by name; otherwise fall back to
  // the fixed order: ids, mask, segment_ids.
  TfLiteTensor* ids_tensor =
      input_tensor_metadatas
          ? FindTensorByName(input_tensors, input_tensor_metadatas,
                             kIdsTensorName)
          : input_tensors[0];
  TfLiteTensor* mask_tensor =
      input_tensor_metadatas
          ? FindTensorByName(input_tensors, input_tensor_metadatas,
                             kMaskTensorName)
          : input_tensors[1];
  TfLiteTensor* segment_ids_tensor =
      input_tensor_metadatas
          ? FindTensorByName(input_tensors, input_tensor_metadatas,
                             kSegmentIdsTensorName)
          : input_tensors[2];

  token_to_orig_map_.clear();

  // The orig_tokens is used for recovering the answer string from the index,
  // while the processed_tokens is lower-cased and used to generate input of
  // the model.
  orig_tokens_ = absl::StrSplit(context, absl::ByChar(' '), absl::SkipEmpty());
  std::vector<std::string> processed_tokens(orig_tokens_);

  std::string processed_query = query;
  if (kUseLowerCase) {
    for (auto& token : processed_tokens) {
      absl::AsciiStrToLower(&token);
    }
    absl::AsciiStrToLower(&processed_query);
  }

  TokenizerResult query_tokenize_results;
  query_tokenize_results = tokenizer_->Tokenize(processed_query);

  // Queries longer than kMaxQueryLen sub-tokens are truncated.
  std::vector<std::string> query_tokens = query_tokenize_results.subwords;
  if (query_tokens.size() > kMaxQueryLen) {
    query_tokens.resize(kMaxQueryLen);
  }

  // Example:
  // context: tokenize me please
  // all_doc_tokens: token ##ize me plea ##se
  // token_to_orig_index: [0, 0, 1, 2, 2]

  // Sub-tokenize each context word, remembering which original word every
  // sub-token came from.
  std::vector<std::string> all_doc_tokens;
  std::vector<int> token_to_orig_index;
  for (size_t i = 0; i < processed_tokens.size(); i++) {
    const std::string& token = processed_tokens[i];
    std::vector<std::string> sub_tokens = tokenizer_->Tokenize(token).subwords;
    for (const std::string& sub_token : sub_tokens) {
      token_to_orig_index.emplace_back(i);
      all_doc_tokens.emplace_back(sub_token);
    }
  }

  // -3 accounts for [CLS], [SEP] and [SEP].
  int max_context_len = kMaxSeqLen - query_tokens.size() - 3;
  if (all_doc_tokens.size() > max_context_len) {
    all_doc_tokens.resize(max_context_len);
  }

  std::vector<std::string> tokens;
  tokens.reserve(3 + query_tokens.size() + all_doc_tokens.size());
  std::vector<int> segment_ids;
  segment_ids.reserve(kMaxSeqLen);

  // Start of generating the features.
  tokens.emplace_back("[CLS]");
  segment_ids.emplace_back(0);

  // For query input. Query tokens carry segment id 0.
  for (const auto& query_token : query_tokens) {
    tokens.emplace_back(query_token);
    segment_ids.emplace_back(0);
  }

  // For Separation.
  tokens.emplace_back("[SEP]");
  segment_ids.emplace_back(0);

  // For Text Input. Context tokens carry segment id 1.
  for (int i = 0; i < all_doc_tokens.size(); i++) {
    auto& doc_token = all_doc_tokens[i];
    tokens.emplace_back(doc_token);
    segment_ids.emplace_back(1);
    // Note: tokens.size() here is (position of doc_token) + 1; the same
    // +1 shift is applied on lookup via kOutputOffset in Postprocess().
    token_to_orig_map_[tokens.size()] = token_to_orig_index[i];
  }

  // For ending mark.
  tokens.emplace_back("[SEP]");
  segment_ids.emplace_back(1);

  std::vector<int> input_ids(tokens.size());
  input_ids.reserve(kMaxSeqLen);
  // Convert tokens back into ids
  for (int i = 0; i < tokens.size(); i++) {
    auto& token = tokens[i];
    tokenizer_->LookupId(token, &input_ids[i]);
  }

  // Attention mask: 1 for real tokens, 0 (added below) for padding.
  std::vector<int> input_mask;
  input_mask.reserve(kMaxSeqLen);
  input_mask.insert(input_mask.end(), tokens.size(), 1);

  // Zero-pad all three feature vectors out to kMaxSeqLen.
  int zeros_to_pad = kMaxSeqLen - input_ids.size();
  input_ids.insert(input_ids.end(), zeros_to_pad, 0);
  input_mask.insert(input_mask.end(), zeros_to_pad, 0);
  segment_ids.insert(segment_ids.end(), zeros_to_pad, 0);

  // input_ids INT32[1, 384]
  PopulateTensor(input_ids, ids_tensor);
  // input_mask INT32[1, 384]
  PopulateTensor(input_mask, mask_tensor);
  // segment_ids INT32[1, 384]
  PopulateTensor(segment_ids, segment_ids_tensor);

  return absl::OkStatus();
}
285
// Decodes the model's start/end logits into up to kPredictAnsNum candidate
// answers, sorted by score (sum of start and end logits). Relies on
// token_to_orig_map_ and orig_tokens_ filled in by Preprocess().
StatusOr<std::vector<QaAnswer>> BertQuestionAnswerer::Postprocess(
    const std::vector<const TfLiteTensor*>& output_tensors,
    const std::string& /*lowercased_context*/,
    const std::string& /*lowercased_query*/) {
  auto* output_tensor_metadatas =
      GetMetadataExtractor()->GetOutputTensorMetadata();

  // With metadata present, locate tensors by name; otherwise fall back to
  // the fixed order: end_logits, start_logits.
  const TfLiteTensor* end_logits_tensor =
      output_tensor_metadatas
          ? FindTensorByName(output_tensors, output_tensor_metadatas,
                             kEndLogitsTensorName)
          : output_tensors[0];
  const TfLiteTensor* start_logits_tensor =
      output_tensor_metadatas
          ? FindTensorByName(output_tensors, output_tensor_metadatas,
                             kStartLogitsTensorName)
          : output_tensors[1];

  std::vector<float> end_logits;
  std::vector<float> start_logits;

  // end_logits FLOAT[1, 384]
  PopulateVector(end_logits_tensor, &end_logits);
  // start_logits FLOAT[1, 384]
  PopulateVector(start_logits_tensor, &start_logits);

  // Indices of logits sorted from highest to lowest score.
  auto start_indices = ReverseSortIndices(start_logits);
  auto end_indices = ReverseSortIndices(end_logits);

  // Consider the top kPredictAnsNum start positions crossed with the top
  // kPredictAnsNum end positions, keeping only well-formed spans.
  std::vector<QaAnswer::Pos> orig_results;
  for (int start_index = 0; start_index < kPredictAnsNum; start_index++) {
    for (int end_index = 0; end_index < kPredictAnsNum; end_index++) {
      int start = start_indices[start_index];
      int end = end_indices[end_index];

      // Skip spans that fall outside the context tokens (e.g. in the query
      // or padding), are inverted, or exceed the max answer length.
      if (!token_to_orig_map_.contains(start + kOutputOffset) ||
          !token_to_orig_map_.contains(end + kOutputOffset) || end < start ||
          (end - start + 1) > kMaxAnsLen) {
        continue;
      }
      orig_results.emplace_back(
          QaAnswer::Pos(start, end, start_logits[start] + end_logits[end]));
    }
  }

  // QaAnswer::Pos ordering puts the best-scoring spans first.
  std::sort(orig_results.begin(), orig_results.end());

  std::vector<QaAnswer> answers;
  for (int i = 0; i < orig_results.size() && i < kPredictAnsNum; i++) {
    auto orig_pos = orig_results[i];
    answers.emplace_back(
        orig_pos.start > 0 ? ConvertIndexToString(orig_pos.start, orig_pos.end)
                           : "",
        orig_pos);
  }

  return answers;
}
344
ConvertIndexToString(int start,int end)345 std::string BertQuestionAnswerer::ConvertIndexToString(int start, int end) {
346 int start_index = token_to_orig_map_[start + kOutputOffset];
347 int end_index = token_to_orig_map_[end + kOutputOffset];
348
349 return absl::StrJoin(orig_tokens_.begin() + start_index,
350 orig_tokens_.begin() + end_index + 1, " ");
351 }
352
InitializeFromMetadata()353 absl::Status BertQuestionAnswerer::InitializeFromMetadata() {
354 const ProcessUnit* tokenizer_process_unit =
355 GetMetadataExtractor()->GetInputProcessUnit(kTokenizerProcessUnitIndex);
356 if (tokenizer_process_unit == nullptr) {
357 return CreateStatusWithPayload(
358 absl::StatusCode::kInvalidArgument,
359 "No input process unit found from metadata.",
360 TfLiteSupportStatus::kMetadataInvalidTokenizerError);
361 }
362 ASSIGN_OR_RETURN(tokenizer_,
363 CreateTokenizerFromProcessUnit(tokenizer_process_unit,
364 GetMetadataExtractor()));
365 return absl::OkStatus();
366 }
367
InitializeBertTokenizer(const std::string & path_to_vocab)368 void BertQuestionAnswerer::InitializeBertTokenizer(
369 const std::string& path_to_vocab) {
370 tokenizer_ = absl::make_unique<BertTokenizer>(path_to_vocab);
371 }
372
InitializeBertTokenizerFromBinary(const char * vocab_buffer_data,size_t vocab_buffer_size)373 void BertQuestionAnswerer::InitializeBertTokenizerFromBinary(
374 const char* vocab_buffer_data, size_t vocab_buffer_size) {
375 tokenizer_ =
376 absl::make_unique<BertTokenizer>(vocab_buffer_data, vocab_buffer_size);
377 }
378
InitializeSentencepieceTokenizer(const std::string & path_to_spmodel)379 void BertQuestionAnswerer::InitializeSentencepieceTokenizer(
380 const std::string& path_to_spmodel) {
381 tokenizer_ = absl::make_unique<SentencePieceTokenizer>(path_to_spmodel);
382 }
383
InitializeSentencepieceTokenizerFromBinary(const char * spmodel_buffer_data,size_t spmodel_buffer_size)384 void BertQuestionAnswerer::InitializeSentencepieceTokenizerFromBinary(
385 const char* spmodel_buffer_data, size_t spmodel_buffer_size) {
386 tokenizer_ = absl::make_unique<SentencePieceTokenizer>(spmodel_buffer_data,
387 spmodel_buffer_size);
388 }
389
390 } // namespace qa
391 } // namespace text
392 } // namespace task
393 } // namespace tflite
394